
Deep Learning (ACNet Reparameterization)

Like RepVGG, ACNet uses reparameterization to improve inference performance.

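RepVGG places a 3x3 branch, a 1x1 branch and an identity (shortcut) branch in parallel. ACNet instead places a 3x3 branch, a 3x1 branch and a 1x3 branch in parallel, and at inference time the three branches are fused into a single 3x3 convolution.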

The structure is shown in the figure below:

[Figure: screenshot of the ACNet block structure, source: blog.csdn.net]
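The fusion described above relies on convolution being linear in its kernel: a 3x1 or 1x3 kernel can be zero-padded to 3x3, and the three kernels can then simply be added. The following is a minimal single-channel sketch of that idea in plain Python, without BatchNorm. The helper names (`conv2d_same`, `pad_to_3x3`, `add`) and the sample values are assumptions made for illustration and are not part of the original post.

```python
# Single-channel illustration of kernel additivity (no BatchNorm, no bias).
# conv2d_same, pad_to_3x3 and add are hypothetical helpers for this sketch.

def conv2d_same(image, kernel):
    """Cross-correlate a 2D list with a kernel, zero-padding so the output keeps the input size."""
    kh, kw = len(kernel), len(kernel[0])
    ph, pw = kh // 2, kw // 2  # (1, 1) for 3x3, (1, 0) for 3x1, (0, 1) for 1x3
    h, w = len(image), len(image[0])
    out = [[0.0] * w for _ in range(h)]
    for i in range(h):
        for j in range(w):
            acc = 0.0
            for a in range(kh):
                for b in range(kw):
                    y, x = i + a - ph, j + b - pw
                    if 0 <= y < h and 0 <= x < w:
                        acc += image[y][x] * kernel[a][b]
            out[i][j] = acc
    return out


def pad_to_3x3(kernel):
    """Place a 3x1, 1x3 or 3x3 kernel at the center of a zero 3x3 kernel."""
    kh, kw = len(kernel), len(kernel[0])
    top, left = (3 - kh) // 2, (3 - kw) // 2
    padded = [[0.0] * 3 for _ in range(3)]
    for a in range(kh):
        for b in range(kw):
            padded[top + a][left + b] = kernel[a][b]
    return padded


def add(m1, m2):
    """Element-wise sum of two equally sized 2D lists."""
    return [[p + q for p, q in zip(r1, r2)] for r1, r2 in zip(m1, m2)]


image = [
    [1.0, 2.0, 0.0, -1.0],
    [0.5, -2.0, 3.0, 1.0],
    [4.0, 0.0, -1.5, 2.0],
    [1.0, 1.0, 2.0, -3.0],
]
k3x3 = [[1.0, 0.0, -1.0], [2.0, 0.5, -2.0], [1.0, 0.0, -1.0]]
k3x1 = [[0.5], [1.0], [-0.5]]
k1x3 = [[-1.0, 2.0, 0.25]]

# Training-time form: three parallel branches whose outputs are summed.
branch_sum = add(add(conv2d_same(image, k3x3), conv2d_same(image, k3x1)),
                 conv2d_same(image, k1x3))

# Inference-time form: one 3x3 convolution with the summed (padded) kernels.
fused_kernel = add(add(k3x3, pad_to_3x3(k3x1)), pad_to_3x3(k1x3))
fused_out = conv2d_same(image, fused_kernel)

max_diff = max(abs(p - q)
               for r1, r2 in zip(branch_sum, fused_out)
               for p, q in zip(r1, r2))
print("max difference:", max_diff)
```

The 3x1 kernel lands in the middle column and the 1x3 kernel in the middle row of the fused kernel, which mirrors the `w[:,:,:,1]` and `w[:,:,1,:]` assignments in the PyTorch code.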

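The code below is a reparameterized block implemented according to my own understanding. It has a training branch and a deploy branch, and the outputs of the two branches pass a torch.allclose check.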

import torch
import torch.nn as nn


class AcNetBlock(nn.Module):
    def __init__(self, channels, deploy):
        super(AcNetBlock, self).__init__()
        self.deploy = deploy
        self.channels = channels

        # Training-time branches: 3x3, 3x1 and 1x3 convolutions, each followed by BN.
        self.conv3x3 = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1, bias=True)
        self.bn3x3 = nn.BatchNorm2d(channels)

        self.conv3x1 = nn.Conv2d(channels, channels, kernel_size=(3, 1), stride=1, padding=(1, 0), bias=True)
        self.bn3x1 = nn.BatchNorm2d(channels)

        self.conv1x3 = nn.Conv2d(channels, channels, kernel_size=(1, 3), stride=1, padding=(0, 1), bias=True)
        self.bn1x3 = nn.BatchNorm2d(channels)

        if not deploy:
            # Random parameters stand in for trained weights in this test.
            # The BN running statistics keep their defaults (mean 0, var 1).
            self.conv3x3.weight.data = torch.randn(channels, channels, 3, 3)
            self.conv3x3.bias.data = torch.randn(channels)
            self.bn3x3.weight.data = torch.randn(channels)
            self.bn3x3.bias.data = torch.randn(channels)

            self.conv3x1.weight.data = torch.randn(channels, channels, 3, 1)
            self.conv3x1.bias.data = torch.randn(channels)
            self.bn3x1.weight.data = torch.randn(channels)
            self.bn3x1.bias.data = torch.randn(channels)

            self.conv1x3.weight.data = torch.randn(channels, channels, 1, 3)
            self.conv1x3.bias.data = torch.randn(channels)
            self.bn1x3.weight.data = torch.randn(channels)
            self.bn1x3.bias.data = torch.randn(channels)

        # Fusion conv: the single 3x3 convolution used in deploy mode.
        self.fusion_conv = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1, bias=True)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        if not self.deploy:
            x1 = self.bn3x3(self.conv3x3(x))
            x2 = self.bn3x1(self.conv3x1(x))
            x3 = self.bn1x3(self.conv1x3(x))
            x = x1 + x2 + x3
        else:
            x = self.fusion_conv(x)
        return self.relu(x)

    def reparam3x3(self):
        # Fold BN (inference statistics) into the 3x3 convolution.
        conv_w = self.conv3x3.weight
        conv_b = self.conv3x3.bias
        bn_w = self.bn3x3.weight
        bn_b = self.bn3x3.bias

        bn_w = bn_w.div(torch.sqrt(self.bn3x3.eps + self.bn3x3.running_var))
        fusion_w = torch.mm(torch.diag(bn_w), conv_w.view(self.channels, -1)).view(self.channels, self.channels, 3, 3)
        fusion_b = bn_w * (conv_b - self.bn3x3.running_mean) + bn_b
        print(fusion_w.shape, fusion_b.shape)
        return fusion_w, fusion_b

    def reparam3x1(self):
        # Fold BN into the 3x1 convolution, then embed it in the middle column of a 3x3 kernel.
        conv_w = self.conv3x1.weight
        conv_b = self.conv3x1.bias
        bn_w = self.bn3x1.weight
        bn_b = self.bn3x1.bias

        bn_w = bn_w.div(torch.sqrt(self.bn3x1.eps + self.bn3x1.running_var))
        fusion_w = torch.mm(torch.diag(bn_w), conv_w.view(self.channels, -1)).view(self.channels, self.channels, 3, 1)
        w = torch.zeros(self.channels, self.channels, 3, 3)
        w[:, :, :, 1] = fusion_w.squeeze(3)
        fusion_b = bn_w * (conv_b - self.bn3x1.running_mean) + bn_b
        print(w.shape, fusion_b.shape)
        return w, fusion_b

    def reparam1x3(self):
        # Fold BN into the 1x3 convolution, then embed it in the middle row of a 3x3 kernel.
        conv_w = self.conv1x3.weight
        conv_b = self.conv1x3.bias
        bn_w = self.bn1x3.weight
        bn_b = self.bn1x3.bias

        bn_w = bn_w.div(torch.sqrt(self.bn1x3.eps + self.bn1x3.running_var))
        fusion_w = torch.mm(torch.diag(bn_w), conv_w.view(self.channels, -1)).view(self.channels, self.channels, 1, 3)
        w = torch.zeros(self.channels, self.channels, 3, 3)
        w[:, :, 1, :] = fusion_w.squeeze(2)
        fusion_b = bn_w * (conv_b - self.bn1x3.running_mean) + bn_b
        print(w.shape, fusion_b.shape)
        return w, fusion_b

    def reparam(self):
        # Sum the three folded kernels and biases into the fusion conv.
        # no_grad keeps the fused parameters out of the autograd graph.
        with torch.no_grad():
            w_3x3, b_3x3 = self.reparam3x3()
            w_3x1, b_3x1 = self.reparam3x1()
            w_1x3, b_1x3 = self.reparam1x3()
            self.fusion_conv.weight.data = (w_3x3 + w_3x1 + w_1x3).clone()
            self.fusion_conv.bias.data = (b_3x3 + b_3x1 + b_1x3).clone()


x = torch.randn(1, 20, 224, 224)

# Training structure, run in eval mode so that BN uses its running statistics.
net1 = AcNetBlock(20, False)
torch.save(net1.state_dict(), "acnet.pth")
net1.eval()
y1 = net1(x)

# Deploy structure: load the same weights, then fuse the branches.
net2 = AcNetBlock(20, True)
net2.load_state_dict(torch.load("acnet.pth"))
net2.reparam()
net2.eval()
y2 = net2(x)

print(y1.shape, y2.shape)
print(torch.allclose(y1, y2, atol=1e-4))

torch.onnx.export(net1, x, "acnet.onnx", input_names=['input'], output_names=['output'])
torch.onnx.export(net2, x, "acnet_deploy.onnx", input_names=['input'], output_names=['output'])
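For reference, the per-branch folding done in reparam3x3, reparam3x1 and reparam1x3 can be written out as below. This is only a sketch of the standard conv + BN fusion in inference mode. The symbols are notation assumed here, not taken from the original post: W and b are the convolution weight and bias, γ and β the BN weight and bias, μ and σ² the BN running mean and variance, ε the BN eps, and pad(·) the zero-padding of a 3x1 or 1x3 kernel to 3x3. All BN quantities are per output channel.

```latex
\begin{aligned}
\mathrm{BN}(W * x + b)
  &= \gamma \,\frac{W * x + b - \mu}{\sqrt{\sigma^{2} + \epsilon}} + \beta \\
  &= \underbrace{\frac{\gamma}{\sqrt{\sigma^{2} + \epsilon}}\, W}_{W'} * x
     \;+\; \underbrace{\frac{\gamma}{\sqrt{\sigma^{2} + \epsilon}}\,(b - \mu) + \beta}_{b'} \\[6pt]
W_{\text{fused}} &= W'_{3\times 3} + \operatorname{pad}\!\left(W'_{3\times 1}\right) + \operatorname{pad}\!\left(W'_{1\times 3}\right) \\
b_{\text{fused}} &= b'_{3\times 3} + b'_{3\times 1} + b'_{1\times 3}
\end{aligned}
```

In the code, the scaling γ / sqrt(σ² + ε) corresponds to the rescaled bn_w, W' to fusion_w, and b' to fusion_b. The identity only holds when BN uses its running statistics, which is why both networks are switched to eval() before their outputs are compared.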