2.模拟测试
这里是接7月31号的部分
2.3 PromptEncoder模块
【经验1】:这一部分是关于pytorch中的部分内置函数。
- torch.topk()
torch.topk()默认使用dim=-1
选择张量的最后一个维度,默认使用largest=True
返回k
个最大值,默认使用sorted
选择k
个元素也是有序的,最后返回的是(values,indices)
组合而成的元组,例如我对大小为[B,1,S,H*W]
张量,使用topk,返回的values,indices
的大小都是[B,1,S,H*W]
🥪:这个函数一般用于选择使用他返回的索引indices
,搭配一种显著性分数计算算法,选择k
个最大显著特征. - torch.Tensor.expand()
torch.Tensor.expand()会返回张量的视图,expand的输入扩展为目标张量的大小,需要严格匹配张量的size,例如输入张量的大小为[B,1,S,H,W]
,那么expand的输入就必须为expand(-1,3,-1,-1,-1)
,其中-1
表示占位符,不对这个维度做扩展(实际上也无法对这个张量做扩展,因为维度不为1
),3
表示对张量扩展复制3
次
🥪:这个函数一般用于批量复制显著点的索引 - torch.gather()
torch.gather(),在官方介绍中,torch.gather()的运行原理可能会比较晦涩,这里需要明确一个概念,torch.gather()的输入和输出的形状是一致的,输入的索引是[H,W]
那么输出的索引也是[H,W]
其次gather用于一维索引比较好理解,即目标图的大小设置为[B,32,S,H*W]
,那么索引图最好也是[B,32,S,Num]
,其中value(Num)
要在H*W
的范围内,此时就可以简单认为索引值在哪个维度就dim参数就设置在相应的维度,例如索引值在最后一维,那么可以设置dim=-1
- torch.squeeze()/torch.unsqueeze()
这两个函数一个用来去掉dim == 1
的维度,默认去掉所有dim == 1
的维度,可以通过指定位置去掉dim == 1
的维度;另一个函数用于用于在指定位置创造dim == 1
的维度,用于
-1.
【】谁都有难以名状的心情,因此人会变的难过