当前位置: 首页 > news >正文

深度学习(onnx量化)

onnx中的动态量化和静态量化概念与pytorch中的核心思想一致,但实现工具、流程和具体api有所不同。

onnx量化通常依赖onnxruntime来执行量化模型,并使用onnx工具库进行模型转换。

除了pytorch量化和onnx量化,实际工作中一般像英伟达、地平线、昇腾等不同的芯片都会有各自独特的工具链和加速算子,按照官方教程使用即可。

下面同样给了两个例子,可以验证一下。结合上篇代码可以做个对比。

动态量化:

import torch
import torch.nn as nn
import warnings
import numpy as np
import onnxruntime as ort

warnings.filterwarnings("ignore")
from onnxruntime.quantization import QuantType, quantize_dynamic


class SimpleLSTM(nn.Module):
    """A small LSTM network — a typical candidate for dynamic quantization."""

    def __init__(self, input_size=10, hidden_size=50, num_layers=2, output_size=15):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        # Recurrent stack; batch_first means input is (batch, seq, feature).
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
        )
        # Projection from the hidden state to the output size.
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        # Fresh zero hidden/cell states on every call.
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size)
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size)
        out, _ = self.lstm(x, (h0, c0))
        # Only the last time step feeds the classifier head.
        return self.fc(out[:, -1, :])


def quantize_onnx_model(input_model_path, output_model_path):
    """Apply dynamic (weight-only) INT8 quantization to an ONNX model file."""
    quantize_dynamic(
        input_model_path,
        output_model_path,
        weight_type=QuantType.QInt8,
    )
    print(f"ONNX模型已动态量化并保存到: {output_model_path}")


if __name__ == '__main__':
    model = SimpleLSTM()
    x = torch.randn(1, 10, 10)  # dummy input used only for tracing
    torch.onnx.export(
        model,                  # model being run
        x,                      # model input (or a tuple for multiple inputs)
        "simple_lstm.onnx",     # where to save the model
        export_params=True,     # store the trained weights inside the model file
        opset_version=12,       # ONNX opset to export against
        input_names=['input'],
        output_names=['output'],
    )
    quantize_onnx_model("simple_lstm.onnx", "simple_lstm_quantized.onnx")

    # Compare the FP32 model with its dynamically quantized counterpart
    # on the same random input.
    x = np.random.randn(1, 10, 10).astype(np.float32)
    ort_session = ort.InferenceSession("simple_lstm.onnx")
    ort_session_quantized = ort.InferenceSession("simple_lstm_quantized.onnx")

    inputs = {ort_session.get_inputs()[0].name: x}
    outputs = ort_session.run(None, inputs)
    print("ONNX模型输出:\n", outputs[0])

    inputs = {ort_session_quantized.get_inputs()[0].name: x}
    outputs_quantized = ort_session_quantized.run(None, inputs)
    print("动态量化后的ONNX模型输出:\n", outputs_quantized[0])

 静态量化:

import torch
import numpy as np
import warnings
import onnx
import onnxruntime as ort
from onnxruntime.quantization import QuantFormat, QuantType, quantize_static, CalibrationDataReader
warnings.filterwarnings("ignore")


class Model(torch.nn.Module):
    """Four stacked 1x1 conv layers with interleaved ReLUs — the FP32 network
    to be statically quantized.

    The commented-out Quant/DeQuant stubs mark where PyTorch eager-mode
    quantization would insert its conversion points; they are unused here
    because quantization happens post-export, on the ONNX graph.
    """

    def __init__(self):
        super().__init__()
        # self.quant = torch.ao.quantization.QuantStub()
        self.conv = torch.nn.Conv2d(1, 100, 1)
        self.conv1 = torch.nn.Conv2d(100, 100, 1)
        self.conv2 = torch.nn.Conv2d(100, 100, 1)
        self.conv3 = torch.nn.Conv2d(100, 1, 1)
        self.relu1 = torch.nn.ReLU()
        self.relu2 = torch.nn.ReLU()
        # self.dequant = torch.ao.quantization.DeQuantStub()

    def forward(self, x):
        # x = self.quant(x)
        x = self.conv(x)
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.conv2(x)
        x = self.relu2(x)
        x = self.conv3(x)
        # x = self.dequant(x)
        return x
class CustomCalibrationDataReader(CalibrationDataReader):
    """Feeds calibration samples from a .npz archive to onnxruntime's
    static quantizer.

    Args:
        calibration_data_path: path to a .npz file whose first stored array
            holds the calibration samples.
        input_name: name of the model input the samples are fed to.
    """

    def __init__(self, calibration_data_path, input_name):
        self.data = np.load(calibration_data_path)
        self.input_name = input_name
        # BUG FIX: the original computed len(self.data.files[0]) — the length
        # of the first array's *name string* — instead of the sample count.
        self.datasize = len(self.data[self.data.files[0]])
        self.enum_data = iter(self.data[self.data.files[0]])

    def get_next(self):
        """Return the next {input_name: sample} dict, or None when exhausted."""
        try:
            batch = next(self.enum_data)
            # Prepend the batch axis the model expects.
            return {self.input_name: np.expand_dims(batch, axis=0)}
        except StopIteration:
            return None

    def rewind(self):
        """Reset the iterator so calibration can run again from the start."""
        self.enum_data = iter(self.data[self.data.files[0]])
def quantize_onnx_model_static(original_model_path, quantized_model_path, calibration_data_path):
    """Statically quantize an ONNX model to INT8 using a calibration dataset.

    Args:
        original_model_path: path of the FP32 source model.
        quantized_model_path: where the quantized model is written.
        calibration_data_path: calibration dataset path (.npz format).
    """
    # The model is loaded only to discover the name of its first graph input.
    graph_input_name = onnx.load(original_model_path).graph.input[0].name

    reader = CustomCalibrationDataReader(calibration_data_path, graph_input_name)

    quantize_static(
        model_input=original_model_path,
        model_output=quantized_model_path,
        calibration_data_reader=reader,
        quant_format=QuantFormat.QDQ,       # QDQ or QOperator
        per_channel=True,                   # per-channel weight scales
        reduce_range=True,                  # reduced range; required by some CPUs
        activation_type=QuantType.QInt8,    # activation quantization type
        weight_type=QuantType.QInt8,        # weight quantization type
    )
    print(f"量化完成!量化模型已保存至: {quantized_model_path}")
def generate_calibration_data(output_path, num_samples=100, sample_shape=(1, 4, 4)):
    """Create a random calibration dataset and save it as a .npz archive.

    Args:
        output_path: destination path for the archive (.npz).
        num_samples: number of samples to generate.
        sample_shape: shape of each individual sample, *without* the batch
            axis (the data reader adds it later). Defaults to the (1, 4, 4)
            shape expected by the example model, so existing callers are
            unaffected.
    """
    # One vectorized draw replaces the original append loop; the stored
    # array has the same (num_samples, *sample_shape) float32 layout.
    samples = np.random.randn(num_samples, *sample_shape).astype(np.float32)
    np.savez(output_path, calibration_data=samples)
    print(f"已生成 {num_samples} 个校准样本到: {output_path}")
if __name__ == "__main__":
    MODEL_FP32 = 'model_fp32.onnx'
    MODEL_INT8 = 'model_int8.onnx'

    # Export the FP32 reference model to ONNX via tracing.
    net = Model()
    dummy = torch.randn(1, 1, 4, 4)  # tracing input
    torch.onnx.export(net, dummy, MODEL_FP32,
                      input_names=['input'],
                      output_names=['output'])

    # Step 1: synthesize calibration data (skip when a real set exists).
    generate_calibration_data("calibration_data.npz", num_samples=100)

    # Step 2: run static quantization.
    quantize_onnx_model_static(
        original_model_path=MODEL_FP32,
        quantized_model_path=MODEL_INT8,
        calibration_data_path="calibration_data.npz",
    )

    # Step 3 (optional): compare FP32 vs INT8 outputs on one random input.
    sample = np.random.randn(1, 1, 4, 4).astype(np.float32)
    fp32_sess = ort.InferenceSession(MODEL_FP32)
    int8_sess = ort.InferenceSession(MODEL_INT8)

    fp32_out = fp32_sess.run(None, {fp32_sess.get_inputs()[0].name: sample})
    print("ONNX模型输出:\n", fp32_out[0])

    int8_out = int8_sess.run(None, {int8_sess.get_inputs()[0].name: sample})
    print("静态量化后的ONNX模型输出:\n", int8_out[0])
http://www.sczhlp.com/news/749.html

相关文章:

  • Redisson
  • P13493 【MX-X14-T3】心电感应 题解
  • uni-app项目跑APP报useStore报错
  • DE_aemmprty 草稿纸合集
  • 22天
  • 基于 Python 的简易验证码识别系统设计与实现
  • java语法的学习笔记
  • 机械运动
  • 【2025.7.28】模拟赛T4
  • 《构建之法》读后感
  • 亚马逊发布TEACh数据集训练家用机器人
  • 日记
  • 完全使用TRAE和AI 开发一款完整的应用----第一周
  • CentOS Stream 9上部署FTP应用服务的两种方法(传统安装和docker-compose)
  • SeuratExtend 可视化教程(1):单细胞分析的高颜值绘图指南
  • SpringBoot 默认配置
  • 暑假7.28
  • 计算机硬件:RAID 0、1、5、6、10简单介绍
  • nest基础学习流程图
  • grabcad
  • 2025.7.28总结 - A
  • Python 实现基于图像处理的验证码识别
  • 2025最新程序员面试题集合 包括各大厂面试规范,面试问题
  • 浅谈基环树
  • Day 28
  • 2025.7.28
  • 《叔向贺贫》
  • 2025总结
  • AI绘画提示词
  • 记一个由tinyint类型引发的低级错误