咸宁市网站建设,电影网站建设多少钱,手机版在线公章制作生成,jsp网站建设项目实战课后文章首发及后续更新#xff1a;https://mwhls.top/4475.html#xff0c;无图/无目录/格式错误/更多相关请至首发页查看。 新的更新内容请到mwhls.top查看。 欢迎提出任何疑问及批评#xff0c;非常感谢#xff01; 摘要#xff1a;绘制模型指定层的热力图 可视化环境安装 …文章首发及后续更新https://mwhls.top/4475.html无图/无目录/格式错误/更多相关请至首发页查看。 新的更新内容请到mwhls.top查看。 欢迎提出任何疑问及批评非常感谢 摘要绘制模型指定层的热力图 可视化环境安装
可用的环境版本 mmseg 1.0.0rc5mmdet 3.0.0rc6mmcv 2.0.0rc4mmengine 0.6.0注不要用在其它版本跑的文件覆盖它我最开始一直没成功就是因为我想偷懒直接复制我的模型过去但是模型调用了在原版本存在但新版本不存在的方法导致一直报错。 安装以上环境参考该 issue 代码可正常推理代码如下 还有其它 issue 也提到了 featmap可以在 mmseg 的 GitHub 搜 cam 关键词或者点这里。
import torch
import cv2
import numpy as npfrom mmseg.visualization import SegLocalVisualizer
from mmseg.apis import init_model
from mmseg.utils import register_all_modules
from mmengine.model import revert_sync_batchnormconfig_path ../mmsegv2/configs/pspnet/pspnet_r50-d8_4xb2-40k_cityscapes-512x1024.py
checkpoint_path ../mmsegv2/checkpoints/pspnet_r50-d8_512x1024_80k_cityscapes_20200606_112131-2376f12b.pth
img_path ../mmsegv2/demo/demo.pngregister_all_modules()model init_model(config_path, checkpoint_path, devicecpu)
model revert_sync_batchnorm(model)
vis SegLocalVisualizer()ori_img cv2.imread(img_path)
img torch.from_numpy(ori_img.astype(np.single)).permute(2, 0, 1).unsqueeze(0)logits model(img)
out vis.draw_featmap(logits[0], ori_img)cv2.imshow(cam, out)
cv2.waitKey(0)
指定位置可视化
修改后的可视化代码 Startup.py
# Thank xiexinch: https://github.com/open-mmlab/mmsegmentation/issues/2434#issuecomment-1441392574
import torch
import cv2
import numpy as np
from mmseg.visualization import SegLocalVisualizer
from mmseg.apis import init_model
from mmseg.utils import register_all_modules
from mmengine.model import revert_sync_batchnorm# prefix mmsegmentation-1.0.0rc5/
prefix
config prefix rlog\7_ttpla_p2t_t_20k\ttpla_p2t_t_20k.py
checkpoint prefix rlog\7_ttpla_p2t_t_20k\iter_8000.pthconfig prefix rlog\9_ttpla_r50_20k\ttpla_r50_20k.py
checkpoint prefix rlog\9_ttpla_r50_20k\iter_8000.pthimg_path prefix rimg.pngdef draw_heatmap(featmap):vis SegLocalVisualizer()ori_img cv2.imread(img_path)out vis.draw_featmap(featmap, ori_img)cv2.imshow(cam, out)cv2.waitKey(0)def generate_featmap(config, checkpoint, img_path):register_all_modules()model init_model(config, checkpoint, devicecpu)model revert_sync_batchnorm(model)vis SegLocalVisualizer()ori_img cv2.imread(img_path)img torch.from_numpy(ori_img.astype(np.single)).permute(2, 0, 1).unsqueeze(0)logits model(img)out vis.draw_featmap(logits[0], ori_img)cv2.imshow(cam, out)cv2.waitKey(0)if __name__ __main__:generate_featmap(config, checkpoint, img_path)如下在模型内调用 draw_heatmap()
from Startup import draw_heatmap
draw_heatmap(x[0])def forward(self, x):Forward function.from Startup import draw_heatmapdraw_heatmap(x[0])if self.deep_stem:x self.stem(x)else:x self.conv1(x)x self.norm1(x)x self.relu(x)x self.maxpool(x)outs []for i, layer_name in enumerate(self.res_layers):res_layer getattr(self, layer_name)x res_layer(x)if i in self.out_indices:outs.append(x)from Startup import draw_heatmapdraw_heatmap(x[0])return tuple(outs)效果展示