Spaces:
Build error
Build error
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from attention import ChannelAttention, SpatialAttention, DualCrossModalAttention | |
| from srm_conv import SRMConv2d_simple, SRMConv2d_Separate | |
| from xception import TransferModel | |
class SRMPixelAttention(nn.Module):
    """Derive a spatial attention map from an SRM-filtered input.

    A small conv/BN/ReLU stem (``in_channels -> 32 -> 64``) extracts features
    which are then handed to ``SpatialAttention`` to produce the map. The SRM
    filtering itself is not done here: ``forward`` receives the tensor that
    has already been filtered by the caller.

    Args:
        in_channels: Channel count of the tensor passed to ``forward``.
    """

    def __init__(self, in_channels):
        super().__init__()

        stem_layers = [
            # 3x3, stride 2, no padding.
            nn.Conv2d(in_channels, 32, kernel_size=3, stride=2, padding=0,
                      bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            # 3x3, default stride/padding.
            nn.Conv2d(32, 64, kernel_size=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        ]
        self.conv = nn.Sequential(*stem_layers)
        self.pa = SpatialAttention()

        # Re-initialise every Conv2d reachable from this module; this walks
        # the stem above and also any Conv2d that ``self.pa`` may contain.
        for layer in self.modules():
            if not isinstance(layer, nn.Conv2d):
                continue
            nn.init.kaiming_normal_(layer.weight, a=1)
            if layer.bias is not None:
                nn.init.constant_(layer.bias, 0)

    def forward(self, x_srm):
        """Return the attention map computed from ``x_srm``.

        Args:
            x_srm: Already SRM-filtered input tensor.

        Returns:
            The output of ``SpatialAttention`` applied to the stem features.
        """
        stem_features = self.conv(x_srm)
        return self.pa(stem_features)
class FeatureFusionModule(nn.Module):
    """Fuse two feature maps by concatenation, 1x1 conv and channel attention.

    The two inputs are concatenated along the channel axis, projected to
    ``out_chan`` channels by a 1x1 conv + BatchNorm + ReLU block, and then
    re-weighted with a residual channel-attention term.

    Args:
        in_chan: Channel count of the concatenated input (sum of the channel
            counts of the two tensors passed to ``forward``).
        out_chan: Channel count of the fused output.
        *args, **kwargs: Accepted for call-site compatibility; unused.
    """

    def __init__(self, in_chan=2048*2, out_chan=2048, *args, **kwargs):
        super(FeatureFusionModule, self).__init__()
        self.convblk = nn.Sequential(
            nn.Conv2d(in_chan, out_chan, 1, 1, 0, bias=False),
            nn.BatchNorm2d(out_chan),
            nn.ReLU()
        )
        self.ca = ChannelAttention(out_chan, ratio=16)
        self.init_weight()

    def forward(self, x, y):
        """Return the fused features of ``x`` and ``y``.

        Args:
            x: First feature map.
            y: Second feature map; concatenated with ``x`` on ``dim=1``.

        Returns:
            ``f + f * ca(f)`` where ``f`` is the projected concatenation.
        """
        fuse_fea = self.convblk(torch.cat((x, y), dim=1))
        # Residual attention: keep the projected features and add the
        # attention-weighted copy on top.
        fuse_fea = fuse_fea + fuse_fea * self.ca(fuse_fea)
        return fuse_fea

    def init_weight(self):
        """Kaiming-initialise the convolution(s) inside ``self.convblk``.

        The direct children of this module are an ``nn.Sequential`` and the
        attention module, neither of which is an ``nn.Conv2d``; iterating
        ``self.children()`` therefore never matched and the initialisation
        was silently skipped. Walking ``self.convblk.modules()`` reaches the
        nested conv while leaving ``self.ca`` with whatever initialisation
        it sets up for itself.
        """
        for ly in self.convblk.modules():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if ly.bias is not None:
                    nn.init.constant_(ly.bias, 0)
class Two_Stream_Net(nn.Module):
    """Two-stream network: an RGB backbone and an SRM backbone, fused at the end.

    Both streams are ``TransferModel('xception', ...)`` instances. The SRM
    stream is fed with ``srm_conv0(x)`` and, in the first two stages, also
    receives SRM-filtered RGB features. The RGB stream is modulated by an
    SRM-guided spatial attention map, the two streams exchange information
    through two ``DualCrossModalAttention`` blocks, and their final features
    are merged by ``FeatureFusionModule``.
    """

    def __init__(self):
        super(Two_Stream_Net, self).__init__()

        # Sub-modules are created in a fixed order; keep it, since it
        # determines registration order of the parameters.
        backbone_kwargs = dict(dropout=0.5, inc=3, return_fea=True)
        self.xception_rgb = TransferModel('xception', **backbone_kwargs)
        self.xception_srm = TransferModel('xception', **backbone_kwargs)

        # SRM filters: one on the raw input, two on intermediate RGB features.
        self.srm_conv0 = SRMConv2d_simple(inc=3)
        self.srm_conv1 = SRMConv2d_Separate(32, 32)
        self.srm_conv2 = SRMConv2d_Separate(64, 64)

        self.relu = nn.ReLU(inplace=True)

        # Last attention map computed by ``features``; returned by ``forward``.
        self.att_map = None
        self.srm_sa = SRMPixelAttention(3)
        self.srm_sa_post = nn.Sequential(
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        )

        self.dual_cma0 = DualCrossModalAttention(in_dim=728, ret_att=False)
        self.dual_cma1 = DualCrossModalAttention(in_dim=728, ret_att=False)

        self.fusion = FeatureFusionModule()

        self.att_dic = {}

    def features(self, x):
        """Run both streams on ``x`` and return their fused feature map.

        As a side effect the SRM-guided attention map is stored in
        ``self.att_map``.

        Args:
            x: Input tensor given to the RGB backbone and to ``srm_conv0``.

        Returns:
            Output of ``self.fusion`` applied to the two stream features.
        """
        rgb_net = self.xception_rgb.model
        srm_net = self.xception_srm.model

        srm = self.srm_conv0(x)

        # Stage 1_0: SRM stream gets its own features plus SRM-filtered RGB.
        rgb = rgb_net.fea_part1_0(x)
        noise = srm_net.fea_part1_0(srm) + self.srm_conv1(rgb)
        noise = self.relu(noise)

        # Stage 1_1: same pattern one level deeper.
        rgb = rgb_net.fea_part1_1(rgb)
        noise = srm_net.fea_part1_1(noise) + self.srm_conv2(rgb)
        noise = self.relu(noise)

        # SRM-guided spatial attention, applied residually to the RGB stream.
        self.att_map = self.srm_sa(srm)
        rgb = rgb * self.att_map + rgb
        rgb = self.srm_sa_post(rgb)

        # Stage 2 followed by the first cross-modal exchange.
        rgb = rgb_net.fea_part2(rgb)
        noise = srm_net.fea_part2(noise)
        rgb, noise = self.dual_cma0(rgb, noise)

        # Stage 3 followed by the second cross-modal exchange.
        rgb = rgb_net.fea_part3(rgb)
        noise = srm_net.fea_part3(noise)
        rgb, noise = self.dual_cma1(rgb, noise)

        # Remaining stages run independently per stream.
        rgb = rgb_net.fea_part4(rgb)
        noise = srm_net.fea_part4(noise)

        rgb = rgb_net.fea_part5(rgb)
        noise = srm_net.fea_part5(noise)

        return self.fusion(rgb, noise)

    def classifier(self, fea):
        """Classify fused features with the RGB backbone's classifier head.

        Args:
            fea: Fused feature map as returned by ``features``.

        Returns:
            The ``(out, fea)`` pair produced by ``xception_rgb.classifier``.
        """
        logits, flat_fea = self.xception_rgb.classifier(fea)
        return logits, flat_fea

    def forward(self, x):
        """Run the full network on ``x``.

        Args:
            x: Original RGB input.

        Returns:
            Tuple ``(out, fea, att_map)``:
            out: the output used for loss computation; (B, 2) per the
                original author's note.
            fea: the flattened features before the last FC; noted as
                (B, 1024) by the original author -- TODO confirm against the
                classifier head.
            att_map: the SRM spatial attention map stored by ``features``.
        """
        fused = self.features(x)
        out, fea = self.classifier(fused)
        return out, fea, self.att_map
if __name__ == '__main__':
    # Smoke test: build the network, push one random 256x256 RGB image
    # through it, then print the module tree.
    net = Two_Stream_Net()
    sample = torch.rand((1, 3, 256, 256))
    outputs = net(sample)
    print(net)