import torch
from torch import nn
from torch.nn import init
import torch.nn.functional as F
from torch.optim import Adam
import numpy
from einops import rearrange
import time

from transformer import Transformer
from Intra_MLP import index_points, knn_l2

# vgg choice
base = {'vgg': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M',
                512, 512, 512, 'M', 512, 512, 512, 'M']}


# vgg16 backbone: returns the layer list (conv + optional BN + ReLU, with max-pools)
def vgg(cfg, i=3, batch_norm=True):
    layers = []
    in_channels = i
    for v in cfg:
        if v == 'M':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            if batch_norm:
                layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
            else:
                layers += [conv2d, nn.ReLU(inplace=True)]
            in_channels = v
    return layers


# 1x1 conv + ReLU used to reduce channels before the second-order pooling in forward()
def hsp(in_channel, out_channel):
    layers = nn.Sequential(nn.Conv2d(in_channel, out_channel, 1, 1),
                           nn.ReLU())
    return layers


def cls_modulation_branch(in_channel, hiden_channel):
    layers = nn.Sequential(nn.Linear(in_channel, hiden_channel),
                           nn.ReLU())
    return layers


def cls_branch(hiden_channel, class_num):
    layers = nn.Sequential(nn.Linear(hiden_channel, class_num),
                           nn.Sigmoid())
    return layers


# intra-MLP: a 1x1 conv applied to the kNN-grouped features and a Sigmoid that
# produces the spatial mask
def intra():
    layers = []
    layers += [nn.Conv2d(512, 512, 1, 1)]
    layers += [nn.Sigmoid()]
    return layers


# decoder block with a 2x upsampling transposed conv at the end
def concat_r():
    layers = []
    layers += [nn.Conv2d(512, 512, 1, 1)]
    layers += [nn.ReLU()]
    layers += [nn.Conv2d(512, 512, 3, 1, 1)]
    layers += [nn.ReLU()]
    layers += [nn.ConvTranspose2d(512, 512, 4, 2, 1)]
    return layers


# decoder block without upsampling (used at the finest level)
def concat_1():
    layers = []
    layers += [nn.Conv2d(512, 512, 1, 1)]
    layers += [nn.ReLU()]
    layers += [nn.Conv2d(512, 512, 3, 1, 1)]
    layers += [nn.ReLU()]
    return layers


# mask head: 2-channel prediction, 4x upsampling, Softmax2d
def mask_branch():
    layers = []
    layers += [nn.Conv2d(512, 2, 3, 1, 1)]
    layers += [nn.ConvTranspose2d(2, 2, 8, 4, 2)]
    layers += [nn.Softmax2d()]
    return layers


# project the four backbone features (128/256/512/512 channels) to 512 channels
def incr_channel():
    layers = []
    layers += [nn.Conv2d(128, 512, 3, 1, 1)]
    layers += [nn.Conv2d(256, 512, 3, 1, 1)]
    layers += [nn.Conv2d(512, 512, 3, 1, 1)]
    layers += [nn.Conv2d(512, 512, 3, 1, 1)]
    return layers


def incr_channel2():
    layers = []
    layers += [nn.Conv2d(512, 512, 3, 1, 1)]
    layers += [nn.Conv2d(512, 512, 3, 1, 1)]
    layers += [nn.Conv2d(512, 512, 3, 1, 1)]
    layers += [nn.Conv2d(512, 512, 3, 1, 1)]
    layers += [nn.ReLU()]
    return layers


def norm(x, dim):
    squared_norm = (x ** 2).sum(dim=dim, keepdim=True)
    normed = x / torch.sqrt(squared_norm)
    return normed


# broadcast one class-modulated vector per group over that group's feature maps:
# x is (num_groups, C), p is (num_groups * group_size, C, H, W);
# the result has the same shape as p.
def fuse_hsp(x, p, group_size=5):
    t = torch.zeros(group_size, x.size(1))
    for i in range(x.size(0)):
        tmp = x[i, :]
        if i == 0:
            nx = tmp.expand_as(t)
        else:
            nx = torch.cat([nx, tmp.expand_as(t)], dim=0)
    nx = nx.view(x.size(0) * group_size, x.size(1), 1, 1)
    y = nx.expand_as(p)
    return y
class Model(nn.Module):
    def __init__(self, device, base, incr_channel, incr_channel2, hsp1, hsp2,
                 cls_m, cls, concat_r, concat_1, mask_branch, intra, demo_mode=False):
        super(Model, self).__init__()
        self.base = nn.ModuleList(base)
        self.sp1 = hsp1
        self.sp2 = hsp2
        self.cls_m = cls_m
        self.cls = cls
        self.incr_channel1 = nn.ModuleList(incr_channel)
        self.incr_channel2 = nn.ModuleList(incr_channel2)
        self.concat4 = nn.ModuleList(concat_r)
        self.concat3 = nn.ModuleList(concat_r)
        self.concat2 = nn.ModuleList(concat_r)
        self.concat1 = nn.ModuleList(concat_1)
        self.mask = nn.ModuleList(mask_branch)
        self.extract = [13, 23, 33, 43]  # indices of pool2..pool5 in the VGG layer list
        self.device = device
        self.group_size = 5
        self.intra = nn.ModuleList(intra)
        self.transformer_1 = Transformer(512, 4, 4, 782, group=self.group_size)
        self.transformer_2 = Transformer(512, 4, 4, 782, group=self.group_size)
        self.demo_mode = demo_mode

    def forward(self, x):
        # backbone; p collects the pool2, pool3, pool4 and pool5 features
        p = list()
        for k in range(len(self.base)):
            x = self.base[k](x)
            if k in self.extract:
                p.append(x)

        # increase the channel of every level to 512; the two deepest levels
        # additionally go through the group transformers
        newp = list()
        newp_T = list()
        for k in range(len(p)):
            np = self.incr_channel1[k](p[k])
            np = self.incr_channel2[k](np)
            newp.append(self.incr_channel2[4](np))
            if k == 3:
                tmp_newp_T3 = self.transformer_1(newp[k])
                newp_T.append(tmp_newp_T3)
            if k == 2:
                newp_T.append(self.transformer_2(newp[k]))
            if k < 2:
                newp_T.append(None)

        # intra-MLP: kNN graph over the spatial locations of the deepest feature
        point = newp[3].view(newp[3].size(0), newp[3].size(1), -1)
        point = point.permute(0, 2, 1)

        idx = knn_l2(self.device, point, 4, 1)
        feat = idx
        new_point = index_points(self.device, point, idx)

        group_point = new_point.permute(0, 3, 2, 1)
        group_point = self.intra[0](group_point)
        group_point = torch.max(group_point, 2)[0]  # [B, D', S]

        # the hard-coded 7x7 grid assumes a 224x224 network input (stride-32 feature)
        intra_mask = group_point.view(group_point.size(0), group_point.size(1), 7, 7)
        intra_mask = intra_mask + newp[3]

        spa_mask = self.intra[1](intra_mask)

        # second-order pooling within each image, then across the group
        x = newp[3]
        x = self.sp1(x)
        x = x.view(-1, x.size(1), x.size(2) * x.size(3))
        x = torch.bmm(x, x.transpose(1, 2))
        x = x.view(-1, x.size(1) * x.size(2))
        x = x.view(x.size(0) // self.group_size, x.size(1), -1, 1)
        x = self.sp2(x)
        x = x.view(-1, x.size(1), x.size(2) * x.size(3))
        x = torch.bmm(x, x.transpose(1, 2))
        x = x.view(-1, x.size(1) * x.size(2))

        # cls pred
        cls_modulated_vector = self.cls_m(x)
        cls_pred = self.cls(cls_modulated_vector)

        # semantic and spatial modulators
        g1 = fuse_hsp(cls_modulated_vector, newp[0], self.group_size)
        g2 = fuse_hsp(cls_modulated_vector, newp[1], self.group_size)
        g3 = fuse_hsp(cls_modulated_vector, newp[2], self.group_size)
        g4 = fuse_hsp(cls_modulated_vector, newp[3], self.group_size)

        spa_1 = F.interpolate(spa_mask, size=[g1.size(2), g1.size(3)], mode='bilinear')
        spa_1 = spa_1.expand_as(g1)
        spa_2 = F.interpolate(spa_mask, size=[g2.size(2), g2.size(3)], mode='bilinear')
        spa_2 = spa_2.expand_as(g2)
        spa_3 = F.interpolate(spa_mask, size=[g3.size(2), g3.size(3)], mode='bilinear')
        spa_3 = spa_3.expand_as(g3)
        spa_4 = F.interpolate(spa_mask, size=[g4.size(2), g4.size(3)], mode='bilinear')
        spa_4 = spa_4.expand_as(g4)

        # top-down decoder: each level is modulated, refined and merged with the
        # coarser level after the second layer of its concat block
        y4 = newp_T[3] * g4 + spa_4
        for k in range(len(self.concat4)):
            y4 = self.concat4[k](y4)

        y3 = newp_T[2] * g3 + spa_3
        for k in range(len(self.concat3)):
            y3 = self.concat3[k](y3)
            if k == 1:
                y3 = y3 + y4

        y2 = newp[1] * g2 + spa_2
        for k in range(len(self.concat2)):
            y2 = self.concat2[k](y2)
            if k == 1:
                y2 = y2 + y3

        y1 = newp[0] * g1 + spa_1
        for k in range(len(self.concat1)):
            y1 = self.concat1[k](y1)
            if k == 1:
                y1 = y1 + y2

        y = y1
        if self.demo_mode:
            # pairwise cosine-similarity map between spatial positions (demo only)
            tmp = F.interpolate(y1, size=[14, 14], mode='bilinear')
            tmp = tmp.permute(0, 2, 3, 1).contiguous().reshape(
                tmp.shape[0] * tmp.shape[2] * tmp.shape[3], tmp.shape[1])
            tmp = tmp / torch.norm(tmp, p=2, dim=1).unsqueeze(1)
            feat2 = tmp @ tmp.t()
            feat = F.interpolate(y, size=[14, 14], mode='bilinear')

        # mask decoder
        for k in range(len(self.mask)):
            y = self.mask[k](y)
        mask_pred = y[:, 0, :, :]

        if self.demo_mode:
            return cls_pred, mask_pred, feat, feat2
        else:
            return cls_pred, mask_pred


# build the whole network
def build_model(device, demo_mode=False):
    return Model(device,
                 vgg(base['vgg']),
                 incr_channel(),
                 incr_channel2(),
                 hsp(512, 64),
                 hsp(64 ** 2, 32),
                 cls_modulation_branch(32 ** 2, 512),
                 cls_branch(512, 78),
                 concat_r(),
                 concat_1(),
                 mask_branch(),
                 intra(),
                 demo_mode)


# weight init
def xavier(param):
    init.xavier_uniform_(param)


def weights_init(m):
    if isinstance(m, nn.Conv2d):
        xavier(m.weight.data)
    elif isinstance(m, nn.BatchNorm2d):
        init.constant_(m.weight, 1)
        init.constant_(m.bias, 0)
'''import os
os.environ['CUDA_VISIBLE_DEVICES'] = '6'
gpu_id = 'cuda:0'
device = torch.device(gpu_id)

nt = build_model(device).to(device)
it = 2
bs = 1
gs = 5
sum = 0
with torch.no_grad():
    for i in range(it):
        A = torch.rand(bs * gs, 3, 448, 256).cuda()
        A = A * 2 - 1
        start = time.time()
        nt(A)
        sum += time.time() - start
print(sum / bs / gs / it)'''
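
# Usage sketch (illustrative; not part of the original training or test scripts).
# It builds the model on CPU, applies the Xavier init defined above, and runs one
# forward pass on a single group of 5 random 224x224 images (the size the
# hard-coded 7x7 reshape in forward() expects). It assumes the companion
# transformer.py and Intra_MLP.py modules are importable and device-agnostic.
if __name__ == '__main__':
    device = torch.device('cpu')
    net = build_model(device).to(device)
    net.apply(weights_init)
    net.eval()
    with torch.no_grad():
        group = torch.rand(5, 3, 224, 224, device=device)  # one group of 5 images
        cls_pred, mask_pred = net(group)
    print(cls_pred.shape)   # expected (1, 78): one class vector per group
    print(mask_pred.shape)  # expected (5, 224, 224): one mask per image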