import torch
import torch.nn as nn
import torch.nn.functional as F

from attention import ChannelAttention, SpatialAttention, DualCrossModalAttention
from srm_conv import SRMConv2d_simple, SRMConv2d_Separate
from xception import TransferModel
class SRMPixelAttention(nn.Module):
    """Pixel-level spatial attention computed from SRM noise residuals."""

    def __init__(self, in_channels):
        super(SRMPixelAttention, self).__init__()
        # self.srm = SRMConv2d_simple()
        # Two conv blocks mirroring Xception's stem (3x3 stride-2, then 3x3),
        # so the attention map matches the RGB stream's early feature size.
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, 32, 3, 2, 0, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, 3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        )
        self.pa = SpatialAttention()

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, a=1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x_srm):
        # x_srm = self.srm(x)
        fea = self.conv(x_srm)
        att_map = self.pa(fea)
        return att_map
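
# A minimal shape smoke test for SRMPixelAttention (a sketch, not part of the
# training pipeline). It assumes SpatialAttention returns a single-channel
# map broadcastable over the 64-channel features; with a 256x256 input the
# two valid (unpadded) convolutions give a 125x125 map, which matches the
# spatial size of Xception's stem output.
def _demo_srm_pixel_attention():
    srm_residual = torch.rand(1, 3, 256, 256)  # stand-in for SRMConv2d_simple output
    att = SRMPixelAttention(3)(srm_residual)
    print('attention map shape:', att.shape)  # expected: (1, 1, 125, 125)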
class FeatureFusionModule(nn.Module):
    """Fuses the RGB and SRM streams via 1x1 conv plus channel attention."""

    def __init__(self, in_chan=2048 * 2, out_chan=2048, *args, **kwargs):
        super(FeatureFusionModule, self).__init__()
        self.convblk = nn.Sequential(
            nn.Conv2d(in_chan, out_chan, 1, 1, 0, bias=False),
            nn.BatchNorm2d(out_chan),
            nn.ReLU()
        )
        self.ca = ChannelAttention(out_chan, ratio=16)
        self.init_weight()

    def forward(self, x, y):
        fuse_fea = self.convblk(torch.cat((x, y), dim=1))
        # Residual channel re-weighting: keep the fused features and add
        # their attention-scaled copy.
        fuse_fea = fuse_fea + fuse_fea * self.ca(fuse_fea)
        return fuse_fea

    def init_weight(self):
        # Iterate over self.modules() rather than self.children(): the conv
        # layer sits inside nn.Sequential, so children() never yields an
        # nn.Conv2d and the initialization would silently be skipped.
        for ly in self.modules():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if ly.bias is not None:
                    nn.init.constant_(ly.bias, 0)
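
# Hedged usage sketch for FeatureFusionModule: both inputs are assumed to be
# the 2048-channel Xception outputs at the same spatial resolution, so the
# fused tensor keeps that shape.
def _demo_feature_fusion():
    rgb_fea = torch.rand(2, 2048, 8, 8)
    srm_fea = torch.rand(2, 2048, 8, 8)
    fused = FeatureFusionModule()(rgb_fea, srm_fea)
    print('fused shape:', fused.shape)  # expected: (2, 2048, 8, 8)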
class Two_Stream_Net(nn.Module):
    """Two-stream network: an RGB Xception branch and an SRM noise branch,
    coupled by SRM-guided spatial attention and dual cross-modal attention."""

    def __init__(self):
        super().__init__()
        self.xception_rgb = TransferModel(
            'xception', dropout=0.5, inc=3, return_fea=True)
        self.xception_srm = TransferModel(
            'xception', dropout=0.5, inc=3, return_fea=True)

        self.srm_conv0 = SRMConv2d_simple(inc=3)
        self.srm_conv1 = SRMConv2d_Separate(32, 32)
        self.srm_conv2 = SRMConv2d_Separate(64, 64)
        self.relu = nn.ReLU(inplace=True)

        self.att_map = None
        self.srm_sa = SRMPixelAttention(3)
        self.srm_sa_post = nn.Sequential(
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )

        self.dual_cma0 = DualCrossModalAttention(in_dim=728, ret_att=False)
        self.dual_cma1 = DualCrossModalAttention(in_dim=728, ret_att=False)
        self.fusion = FeatureFusionModule()

        self.att_dic = {}
    def features(self, x):
        srm = self.srm_conv0(x)

        # Stage 1: run both stems; the SRM stream additionally receives
        # SRM-filtered versions of the RGB features.
        x = self.xception_rgb.model.fea_part1_0(x)
        y = self.xception_srm.model.fea_part1_0(srm) \
            + self.srm_conv1(x)
        y = self.relu(y)

        x = self.xception_rgb.model.fea_part1_1(x)
        y = self.xception_srm.model.fea_part1_1(y) \
            + self.srm_conv2(x)
        y = self.relu(y)

        # SRM-guided spatial attention on the RGB stream, in residual form.
        self.att_map = self.srm_sa(srm)
        x = x * self.att_map + x
        x = self.srm_sa_post(x)

        x = self.xception_rgb.model.fea_part2(x)
        y = self.xception_srm.model.fea_part2(y)
        # Dual cross-modal attention lets the two 728-channel streams attend
        # to each other at two depths of the backbone.
        x, y = self.dual_cma0(x, y)

        x = self.xception_rgb.model.fea_part3(x)
        y = self.xception_srm.model.fea_part3(y)
        x, y = self.dual_cma1(x, y)

        x = self.xception_rgb.model.fea_part4(x)
        y = self.xception_srm.model.fea_part4(y)

        x = self.xception_rgb.model.fea_part5(x)
        y = self.xception_srm.model.fea_part5(y)

        fea = self.fusion(x, y)
        return fea
    def classifier(self, fea):
        out, fea = self.xception_rgb.classifier(fea)
        return out, fea

    def forward(self, x):
        '''
        x: original RGB input
        Return:
            out: (B, 2) logits for loss computation
            fea: (B, 1024) the flattened features before the last FC
            att_map: SRM-guided spatial attention map
        '''
        out, fea = self.classifier(self.features(x))
        return out, fea, self.att_map
if __name__ == '__main__':
    model = Two_Stream_Net()
    dummy = torch.rand((1, 3, 256, 256))
    # forward returns (logits, flattened features, attention map)
    out, fea, att_map = model(dummy)
    print(out.shape, fea.shape, att_map.shape)