c语言sscanf函数的用法是什么
341
2022-11-30
注意力机制论文:CCNet: Criss-Cross Attention for Semantic Segmentation及其PyTorch实现
CCNet: Criss-Cross Attention for Semantic Segmentation PDF: map计算的是所有像素与所有像素之间的相似性,空间复杂度为(HxW)x(HxW),而本文采用了criss-cross思想,只计算每个像素与其同行同列即十字上的像素的相似性,通过进行循环(两次相同操作),间接计算到每个像素与每个像素的相似性,将空间复杂度降为(HxW)x(H+W-1)
具有以下优点
1)较少的GPU内存使用;2)计算效率高;3)最先进的性能。
1 Criss-Cross Attention
2 Recurrent Criss-Cross Attention
在计算矩阵相乘时每个像素只抽取特征图中对应十字位置的像素进行点乘,计算相似度。和non-local的方法相比极大的降低了计算量,同时采用二阶注意力,能够从所有像素中获取全图像的上下文信息,以生成具有密集且丰富的上下文信息的新特征图。在计算矩阵相乘时,每个像素只抽取特征图中对应十字位置的像素进行点乘,计算相似度。
3 CCNet
4 实验结果
PyTorch代码:
import torchimport torch.nn as nndef INF(B, H, W): return -torch.diag(torch.tensor(float("inf")).repeat(H), 0).unsqueeze(0).repeat(B * W, 1, 1)class CrissCrossAttention(nn.Module): """ Criss-Cross Attention Module""" def __init__(self, in_dim): super(CrissCrossAttention, self).__init__() self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1) self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1) self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1) self.softmax = nn.Softmax(dim=3) self.INF = INF self.gamma = nn.Parameter(torch.zeros(1)) def forward(self, x): m_batchsize, _, height, width = x.size() proj_query = self.query_conv(x) proj_query_H = proj_query.permute(0, 3, 1, 2).contiguous().view(m_batchsize * width, -1, height).permute(0, 2, 1) proj_query_W = proj_query.permute(0, 2, 1, 3).contiguous().view(m_batchsize * height, -1, width).permute(0, 2, 1) proj_key = self.key_conv(x) proj_key_H = proj_key.permute(0, 3, 1, 2).contiguous().view(m_batchsize * width, -1, height) proj_key_W = proj_key.permute(0, 2, 1, 3).contiguous().view(m_batchsize * height, -1, width) proj_value = self.value_conv(x) proj_value_H = proj_value.permute(0, 3, 1, 2).contiguous().view(m_batchsize * width, -1, height) proj_value_W = proj_value.permute(0, 2, 1, 3).contiguous().view(m_batchsize * height, -1, width) energy_H = (torch.bmm(proj_query_H, proj_key_H) + self.INF(m_batchsize, height, width)).view(m_batchsize, width, height, height).permute(0, 2, 1, 3) energy_W = torch.bmm(proj_query_W, proj_key_W).view(m_batchsize, height, width, width) concate = self.softmax(torch.cat([energy_H, energy_W], 3)) att_H = concate[:, :, :, 0:height].permute(0, 2, 1, 3).contiguous().view(m_batchsize * width, height, height) # print(concate) # print(att_H) att_W = concate[:, :, :, height:height + width].contiguous().view(m_batchsize * height, width, width) out_H = torch.bmm(proj_value_H, att_H.permute(0, 2, 1)).view(m_batchsize, width, -1, height).permute(0, 2, 3, 1) out_W = torch.bmm(proj_value_W, att_W.permute(0, 2, 1)).view(m_batchsize, height, -1, width).permute(0, 2, 1, 3) # print(out_H.size(),out_W.size()) return self.gamma * (out_H + out_W) + xif __name__=='__main__': model = CrissCrossAttention(16) print(model) input = torch.randn(1, 16, 64, 64) out = model(input) print(out.shape)
版权声明:本文内容由网络用户投稿,版权归原作者所有,本站不拥有其著作权,亦不承担相应法律责任。如果您发现本站中有涉嫌抄袭或描述失实的内容,请联系我们jiasou666@gmail.com 处理,核实后本网站将在24小时内删除侵权内容。
发表评论
暂时没有评论,来抢沙发吧~