
Deep Learning (PyTorch Quantization)

Dynamic quantization and static quantization are the two main model quantization techniques in PyTorch. Both aim to shrink model size and speed up inference by replacing high-precision data types (such as float32) with low-precision ones (such as int8).

Dynamic quantization: the quantization parameters (scale and zero_point) of the activations are computed on the fly at inference time. The weights are typically quantized when the model is loaded or before the first run.

Static quantization: before deployment, a representative calibration dataset is used to pre-compute the quantization parameters (scale and zero_point) for all weights and all activations in the network. These parameters are then fixed (static) during inference.
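To make the scale and zero_point concrete, here is a minimal sketch of affine int8 quantization derived from an observed min/max range. This is an illustration of the math, not PyTorch's internal observer implementation:

```python
import torch

# Affine quantization: map an observed float range [xmin, xmax]
# onto the signed int8 range [-128, 127].
x = torch.tensor([-1.0, 0.0, 0.5, 2.0])

qmin, qmax = -128, 127
xmin, xmax = x.min().item(), x.max().item()

scale = (xmax - xmin) / (qmax - qmin)
zero_point = int(round(qmin - xmin / scale))

# Quantize: round to the nearest integer step, then clamp to int8.
q = torch.clamp(torch.round(x / scale) + zero_point, qmin, qmax).to(torch.int8)
# Dequantize: recover an approximation of the original floats.
x_hat = (q.float() - zero_point) * scale
```

The round trip introduces at most one quantization step of error (|x_hat - x| <= scale), which is the precision cost traded for the 4x smaller storage.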

For deployment, static quantization is generally the more common choice. Two examples follow, which you can run to verify the behavior.

Dynamic quantization:

import torch
import torch.nn as nn
import warnings
warnings.filterwarnings("ignore")
# torch.serialization.add_safe_globals([torch.ScriptObject])

class SimpleLSTM(nn.Module):
    """A simple LSTM model, well suited to dynamic quantization."""
    def __init__(self, input_size=10, hidden_size=50, num_layers=2, output_size=15):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        # LSTM layers
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True
        )
        # Fully connected layer
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        # Initialize the hidden and cell states
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size)
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size)
        # LSTM forward pass
        out, _ = self.lstm(x, (h0, c0))
        # Keep only the last time step
        out = self.fc(out[:, -1, :])
        return out

def save_fp32_model(model_fp32, x):
    model_fp32.eval()
    y = model_fp32(x)
    print("FP32 model output:", y)
    torch.save(model_fp32.state_dict(), 'model_fp32.pth')

def load_fp32_model(x):
    model_fp32 = SimpleLSTM()
    model_fp32.load_state_dict(torch.load('model_fp32.pth'))
    model_fp32.eval()
    y_fp32 = model_fp32(x)
    print("Loaded FP32 model output:", y_fp32)
    return model_fp32

def save_int8_model(model_fp32, x):
    model_int8 = torch.quantization.quantize_dynamic(
        model_fp32,
        {nn.LSTM, nn.Linear},
        dtype=torch.qint8
    )
    model_int8.eval()
    y_int8 = model_int8(x)
    print("INT8 model output:", y_int8)
    torch.save(model_int8.state_dict(), 'model_int8.pth')

def load_int8_model(x):
    # Rebuild the quantized module structure first, then load the weights
    model_fp32 = SimpleLSTM()
    model_int8 = torch.quantization.quantize_dynamic(
        model_fp32,
        {nn.LSTM, nn.Linear},
        dtype=torch.qint8
    )
    model_int8.load_state_dict(torch.load('model_int8.pth', weights_only=False))
    model_int8.eval()
    y_int8 = model_int8(x)
    print("Loaded INT8 model output:", y_int8)
    return model_int8

if __name__ == '__main__':
    x = torch.randn(1, 10, 10)
    model_fp32 = SimpleLSTM()
    save_fp32_model(model_fp32, x)
    save_int8_model(model_fp32, x)
    load_fp32_model(x)
    load_int8_model(x)

Static quantization:

import torch
import numpy as np
import warnings
warnings.filterwarnings("ignore")

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # QuantStub converts tensors from floating point to quantized
        self.quant = torch.ao.quantization.QuantStub()
        self.conv = torch.nn.Conv2d(1, 100, 1)
        self.conv1 = torch.nn.Conv2d(100, 100, 1)
        self.conv2 = torch.nn.Conv2d(100, 100, 1)
        self.conv3 = torch.nn.Conv2d(100, 1, 1)
        self.relu1 = torch.nn.ReLU()
        self.relu2 = torch.nn.ReLU()
        # DeQuantStub converts tensors from quantized to floating point
        self.dequant = torch.ao.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.conv(x)
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.conv2(x)
        x = self.relu2(x)
        x = self.conv3(x)
        x = self.dequant(x)
        return x

def save_fp32_model(model_fp32, x):
    model_fp32.eval()
    y = model_fp32(x)
    print("FP32 model output:", y)
    torch.save(model_fp32.state_dict(), 'model_fp32.pth')
    torch.onnx.export(model_fp32, x, 'model_fp32.onnx',
                      input_names=['input'], output_names=['output'])

def load_fp32_model(x):
    model_fp32 = Model()
    model_fp32.load_state_dict(torch.load('model_fp32.pth'))
    model_fp32.eval()
    y_fp32 = model_fp32(x)
    print("Loaded FP32 model output:", y_fp32)
    return model_fp32

def save_int8_model(model_fp32, x):
    model_fp32.eval()
    model_fp32.qconfig = torch.ao.quantization.get_default_qconfig('x86')
    model_fp32_fused = torch.ao.quantization.fuse_modules(
        model_fp32, [['conv1', 'relu1'], ['conv2', 'relu2']])
    model_fp32_prepared = torch.ao.quantization.prepare(model_fp32_fused)
    # Calibration: run representative data through the prepared model
    # so the observers can record activation ranges
    with torch.no_grad():
        for i in range(100):
            input_data = torch.randn(1, 1, 4, 4)
            model_fp32_prepared(input_data)
    model_int8 = torch.ao.quantization.convert(model_fp32_prepared)
    model_int8.eval()
    y_int8 = model_int8(x)
    print("INT8 model output:", y_int8)
    torch.save(model_int8.state_dict(), 'model_int8.pth')
    torch.onnx.export(model_int8, x, 'model_int8.onnx',
                      input_names=['input'], output_names=['output'])

def load_int8_model(x):
    # Rebuild the same fused and converted structure, then load the weights
    model_fp32 = Model()
    model_fp32.qconfig = torch.ao.quantization.get_default_qconfig('x86')
    model_fp32_fused = torch.ao.quantization.fuse_modules(
        model_fp32, [['conv1', 'relu1'], ['conv2', 'relu2']])
    model_fp32_prepared = torch.ao.quantization.prepare(model_fp32_fused)
    model_int8 = torch.ao.quantization.convert(model_fp32_prepared)
    model_int8.load_state_dict(torch.load('model_int8.pth'))
    model_int8.eval()
    y_int8 = model_int8(x)
    print("Loaded INT8 model output:", y_int8)
    return model_int8

if __name__ == '__main__':
    x = np.array([[0.1, 0.2, 0.3, 0.4],
                  [0.5, 0.6, 0.7, 0.8],
                  [0.9, 0.1, 0.2, 0.3],
                  [0.4, 0.5, 0.6, 0.7]], dtype=np.float32)
    x = torch.from_numpy(x).unsqueeze(0).unsqueeze(0)
    model_fp32 = Model()
    save_fp32_model(model_fp32, x)
    save_int8_model(model_fp32, x)
    load_fp32_model(x)
    load_int8_model(x)