# UFO / model_video.py
# Committed by djl234 -- "Update model_video.py" (0a31ff1)
import torch
from torch import nn
from torch.nn import init
import torch.nn.functional as F
from torch.optim import Adam
import numpy
from einops import rearrange
import time
from transformer import Transformer
from Intra_MLP import index_points,knn_l2
# VGG-16 layer configuration consumed by ``vgg()``: each integer is the output
# channel count of a 3x3 convolution and 'M' marks a 2x2 max-pool, giving five
# pooled stages.
base = {
    'vgg': [64, 64, 'M',
            128, 128, 'M',
            256, 256, 256, 'M',
            512, 512, 512, 'M',
            512, 512, 512, 'M'],
}
# vgg16
def vgg(cfg, i=3, batch_norm=True):
    """Return the VGG feature-extractor layers described by ``cfg`` as a flat list.

    Args:
        cfg: sequence of output channel counts (one 3x3, padding-1 convolution
            each) and ``'M'`` markers (2x2, stride-2 max-pool).
        i: number of input channels of the first convolution.
        batch_norm: if True, insert ``BatchNorm2d`` between each conv and its ReLU.

    Returns:
        A plain ``list`` of modules (the caller wraps it, e.g. in ``nn.ModuleList``).
    """
    modules = []
    channels = i
    for spec in cfg:
        if spec == 'M':
            modules.append(nn.MaxPool2d(kernel_size=2, stride=2))
            continue
        conv = nn.Conv2d(channels, spec, kernel_size=3, padding=1)
        unit = [conv, nn.BatchNorm2d(spec)] if batch_norm else [conv]
        unit.append(nn.ReLU(inplace=True))
        modules.extend(unit)
        channels = spec
    return modules
def hsp(in_channel, out_channel):
    """Return ``Sequential(Conv2d(in_channel, out_channel, 1x1, stride 1), ReLU())``."""
    projection = nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=1)
    activation = nn.ReLU()
    return nn.Sequential(projection, activation)
def cls_modulation_branch(in_channel, hiden_channel):
    """Return ``Sequential(Linear(in_channel, hiden_channel), ReLU())``."""
    stages = [nn.Linear(in_channel, hiden_channel), nn.ReLU()]
    return nn.Sequential(*stages)
def cls_branch(hiden_channel, class_num):
    """Return ``Sequential(Linear(hiden_channel, class_num), Sigmoid())``.

    The sigmoid makes every class score independent and in (0, 1).
    """
    stages = [nn.Linear(hiden_channel, class_num), nn.Sigmoid()]
    return nn.Sequential(*stages)
def intra():
    """Return the intra-MLP layers ``[Conv2d(512, 512, 1x1), Sigmoid()]`` as a plain list.

    ``Model.forward`` applies element 0 to the gathered neighbour features and
    element 1 to the resulting spatial mask.
    """
    return [nn.Conv2d(512, 512, 1, 1), nn.Sigmoid()]
def concat_r():
    """Return a 512-channel decoder stage (used for the pool3..pool5 levels) as a list.

    Layers: 1x1 conv, ReLU, 3x3 conv (padding 1), ReLU, then a transposed conv
    (kernel 4, stride 2, padding 1) that doubles the spatial resolution.
    """
    width = 512
    return [
        nn.Conv2d(width, width, 1, 1),
        nn.ReLU(),
        nn.Conv2d(width, width, 3, 1, 1),
        nn.ReLU(),
        nn.ConvTranspose2d(width, width, 4, 2, 1),
    ]
def concat_1():
    """Return the shallowest 512-channel decoder stage as a list (no upsampling).

    Layers: 1x1 conv, ReLU, 3x3 conv (padding 1), ReLU.
    """
    width = 512
    return [
        nn.Conv2d(width, width, 1, 1),
        nn.ReLU(),
        nn.Conv2d(width, width, 3, 1, 1),
        nn.ReLU(),
    ]
def mask_branch():
    """Return the mask head as a list.

    The head is a 3x3 conv from 512 to 2 channels, then a transposed conv
    (kernel 8, stride 4, padding 2) that upsamples 4x, then a per-pixel
    softmax over the 2 channels.
    """
    to_two = nn.Conv2d(512, 2, 3, 1, 1)
    upsample = nn.ConvTranspose2d(2, 2, 8, 4, 2)
    return [to_two, upsample, nn.Softmax2d()]
def incr_channel():
    """Return four 3x3 convs lifting 128/256/512/512 input channels to 512 each.

    Element ``k`` is applied to the ``k``-th extracted backbone level.
    """
    return [nn.Conv2d(c_in, 512, 3, 1, 1) for c_in in (128, 256, 512, 512)]
def incr_channel2():
    """Return four 512->512 3x3 convs followed by a single ReLU (index 4).

    ``Model.forward`` applies conv ``k`` to level ``k`` and then the shared ReLU.
    """
    layers = [nn.Conv2d(512, 512, 3, 1, 1) for _ in range(4)]
    layers.append(nn.ReLU())
    return layers
def norm(x, dim, eps=1e-12):
    """Return ``x`` scaled to unit L2 norm along ``dim``.

    Args:
        x: input tensor.
        dim: dimension along which the L2 norm is taken (kept with size 1 for
            broadcasting).
        eps: lower bound on the norm.  Vectors whose norm is below ``eps`` are
            divided by ``eps`` instead, so all-zero vectors map to zeros rather
            than NaN (same convention as ``torch.nn.functional.normalize``).

    Returns:
        Tensor of the same shape as ``x``.
    """
    squared_norm = (x ** 2).sum(dim=dim, keepdim=True)
    # Clamp before the sqrt: the result is bit-identical to the unclamped version
    # for norms >= eps, and sqrt never sees 0, so the gradient stays finite.
    return x / squared_norm.clamp_min(eps * eps).sqrt()
def fuse_hsp(x, p, group_size=5):
    """Broadcast one vector per group onto the feature-map shape of ``p``.

    Args:
        x: 2-D tensor ``(num_groups, C)`` holding one vector per group.
        p: reference tensor whose shape the result takes.  It is expected to
            be ``(num_groups * group_size, C, H, W)``, with the members of each
            group in consecutive rows.
        group_size: number of consecutive rows of ``p`` that share a vector.

    Returns:
        A tensor shaped like ``p``, expanded (not copied) along H and W.  Rows
        ``i * group_size`` .. ``(i + 1) * group_size - 1`` all equal ``x[i]``.
    """
    # A single repeat_interleave replaces growing a tensor with torch.cat once per
    # row (quadratic copying, and no result at all when x has zero rows).
    per_member = x.repeat_interleave(group_size, dim=0)
    per_member = per_member.view(x.size(0) * group_size, x.size(1), 1, 1)
    return per_member.expand_as(p)
class Model(nn.Module):
    """Group-wise network with a VGG backbone, transformer / intra-MLP modulation and a top-down decoder.

    ``forward`` takes ``num_groups * group_size`` images.  The members of each
    group are assumed to sit in consecutive rows; this is what the
    ``x.size(0) // self.group_size`` view and ``fuse_hsp`` rely on.
    """

    def __init__(self, device, base, incr_channel, incr_channel2, hsp1, hsp2, cls_m, cls, concat_r, concat_1, mask_branch, intra, demo_mode=False):
        """Register the sub-modules (as produced by ``build_model``).

        Args:
            device: passed to ``knn_l2`` / ``index_points`` in ``forward``.
            base: flat list of backbone layers (``vgg(...)``).
            incr_channel: four convs lifting the extracted levels to 512 channels.
            incr_channel2: four 512->512 convs plus a trailing ReLU at index 4.
            hsp1, hsp2: 1x1-conv + ReLU projections used for the group descriptor.
            cls_m: Linear + ReLU producing the modulation vector.
            cls: Linear + Sigmoid producing class scores.
            concat_r: decoder stage list for the three deepest levels.
            concat_1: decoder stage list for the shallowest level.
            mask_branch: conv / transposed-conv / Softmax2d mask head.
            intra: ``[Conv2d, Sigmoid]`` used by the intra-MLP branch.
            demo_mode: if True, ``forward`` also returns intermediate features.
        """
        super(Model, self).__init__()
        self.base = nn.ModuleList(base)
        self.sp1 = hsp1
        self.sp2 = hsp2
        self.cls_m = cls_m
        self.cls = cls
        self.incr_channel1 = nn.ModuleList(incr_channel)
        self.incr_channel2 = nn.ModuleList(incr_channel2)
        # NOTE(review): concat4/3/2 wrap the *same* list.  When one concat_r() list
        # is passed (as build_model does), they hold the same module objects and
        # share parameters.  Left unchanged so behaviour and checkpoints are
        # unaffected -- confirm whether the sharing is intended.
        self.concat4 = nn.ModuleList(concat_r)
        self.concat3 = nn.ModuleList(concat_r)
        self.concat2 = nn.ModuleList(concat_r)
        self.concat1 = nn.ModuleList(concat_1)
        self.mask = nn.ModuleList(mask_branch)
        # Backbone layer indices whose outputs are kept.  With vgg(base['vgg'],
        # batch_norm=True) these are the 2nd..5th max-pools (pool2..pool5).
        self.extract = [13, 23, 33, 43]
        self.device = device
        self.group_size = 5
        self.intra = nn.ModuleList(intra)
        self.transformer_1 = Transformer(512, 4, 4, 782, group=self.group_size)
        self.transformer_2 = Transformer(512, 4, 4, 782, group=self.group_size)
        self.demo_mode = demo_mode

    @staticmethod
    def _flat_gram(x):
        """Return the per-sample channel Gram matrix of a 4-D tensor, flattened to ``(N, C*C)``."""
        x = x.view(-1, x.size(1), x.size(2) * x.size(3))
        x = torch.bmm(x, x.transpose(1, 2))
        return x.view(-1, x.size(1) * x.size(2))

    @staticmethod
    def _run_stage(layers, y, skip=None):
        """Apply ``layers`` in order, adding ``skip`` right after the layer at index 1."""
        for k, layer in enumerate(layers):
            y = layer(y)
            if k == 1 and skip is not None:
                y = y + skip
        return y

    def forward(self, x):
        """Run the network on a batch of grouped images.

        Returns:
            ``(cls_pred, mask_pred)``, where ``cls_pred`` has one row per group
            and ``mask_pred`` is channel 0 of the 2-way softmax mask.  In demo
            mode it also returns ``feat`` (the last decoder features resized to
            14x14) and ``feat2`` (cosine similarities between all 14x14
            positions of all images).
        """
        # Backbone: keep the outputs at the indices in self.extract (pool2..pool5).
        p = []
        for k, layer in enumerate(self.base):
            x = layer(x)
            if k in self.extract:
                p.append(x)

        # Lift every level to 512 channels.  The two deepest levels also pass
        # through a transformer; newp_T holds None for the two shallow ones.
        newp = []
        newp_T = []
        for k, level in enumerate(p):
            level = self.incr_channel1[k](level)
            level = self.incr_channel2[k](level)
            newp.append(self.incr_channel2[4](level))
            if k == 3:
                newp_T.append(self.transformer_1(newp[k]))
            elif k == 2:
                newp_T.append(self.transformer_2(newp[k]))
            else:
                newp_T.append(None)

        # Intra-MLP on the deepest level.  Each position gathers neighbours via
        # knn_l2 / index_points (defined in Intra_MLP; presumably its 4 nearest
        # neighbours in feature space -- confirm).  A shared 1x1 conv is applied,
        # then a max over the neighbour axis.
        top = newp[3]
        point = top.view(top.size(0), top.size(1), -1).permute(0, 2, 1)
        idx = knn_l2(self.device, point, 4, 1)
        new_point = index_points(self.device, point, idx)
        group_point = new_point.permute(0, 3, 2, 1)
        group_point = self.intra[0](group_point)
        group_point = torch.max(group_point, 2)[0]  # [B, D', S]
        # Reshape back onto the deepest feature grid using its real size.  The
        # old hard-coded 7x7 broke whenever that map was not 7x7.
        intra_mask = group_point.view(group_point.size(0), group_point.size(1), top.size(2), top.size(3))
        spa_mask = self.intra[1](intra_mask + top)

        # Group descriptor: a per-image Gram matrix of sp1 features, regrouped
        # per group, projected by sp2, then a second Gram matrix.
        x = self._flat_gram(self.sp1(top))
        # NOTE(review): this view reinterprets the contiguous (N, C*C) buffer as
        # (N // group_size, C*C, group_size, 1) without a transpose, so axis 2
        # does not index group members.  Kept to preserve behaviour -- confirm.
        x = x.view(x.size(0) // self.group_size, x.size(1), -1, 1)
        x = self._flat_gram(self.sp2(x))

        # Classification.
        cls_modulated_vector = self.cls_m(x)
        cls_pred = self.cls(cls_modulated_vector)

        # Semantic (group vector) and spatial (intra-MLP mask) modulation of each
        # level.  The two deepest levels use their transformer outputs.
        modulated = []
        for k in range(4):
            g = fuse_hsp(cls_modulated_vector, newp[k], self.group_size)
            spa = F.interpolate(spa_mask, size=[g.size(2), g.size(3)], mode='bilinear', align_corners=False)
            spa = spa.expand_as(g)
            feature = newp_T[k] if k >= 2 else newp[k]
            modulated.append(feature * g + spa)

        # Top-down decoder: each stage adds the output of the deeper stage after
        # its second layer.
        y4 = self._run_stage(self.concat4, modulated[3])
        y3 = self._run_stage(self.concat3, modulated[2], y4)
        y2 = self._run_stage(self.concat2, modulated[1], y3)
        y1 = self._run_stage(self.concat1, modulated[0], y2)

        feat = feat2 = None
        if self.demo_mode:
            feat = F.interpolate(y1, size=[14, 14], mode='bilinear', align_corners=False)
            n, c, h, w = feat.shape
            tmp = feat.permute(0, 2, 3, 1).contiguous().reshape(n * h * w, c)
            tmp = tmp / torch.norm(tmp, p=2, dim=1).unsqueeze(1)
            feat2 = tmp @ tmp.t()

        # Mask head.
        y = y1
        for layer in self.mask:
            y = layer(y)
        mask_pred = y[:, 0, :, :]
        if self.demo_mode:
            return cls_pred, mask_pred, feat, feat2
        return cls_pred, mask_pred
# build the whole network
def build_model(device, demo_mode=False):
    """Assemble ``Model`` with the standard configuration.

    The configuration is a VGG-16-BN backbone, 512-channel levels, a 78-way
    classifier and a group descriptor built from 64- then 32-channel
    projections.  Sub-modules are created in the same order as before
    (backbone first, intra layers last), so parameter initialisation is
    reproducible under a fixed seed.

    Args:
        device: forwarded to ``Model``.
        demo_mode: forwarded to ``Model``.
    """
    backbone = vgg(base['vgg'])
    lift = incr_channel()
    lift2 = incr_channel2()
    gram_proj1 = hsp(512, 64)
    gram_proj2 = hsp(64 ** 2, 32)
    modulation = cls_modulation_branch(32 ** 2, 512)
    classifier = cls_branch(512, 78)
    deep_stage = concat_r()
    shallow_stage = concat_1()
    mask_head = mask_branch()
    intra_layers = intra()
    return Model(device, backbone, lift, lift2, gram_proj1, gram_proj2,
                 modulation, classifier, deep_stage, shallow_stage,
                 mask_head, intra_layers, demo_mode)
# Weight initialisation helpers
def xavier(param):
    """Fill ``param`` in place with Xavier/Glorot-uniform values (gain 1); returns None."""
    nn.init.xavier_uniform_(tensor=param, gain=1.0)
def weights_init(m):
    """Initialise a single module in place (presumably applied via ``model.apply``).

    Behaviour by module type:
    * ``nn.BatchNorm2d``: weight is set to 1 and bias to 0.
    * ``nn.Conv2d``: weight gets Xavier-uniform values; the bias is untouched.
    * Any other type, including ``nn.ConvTranspose2d`` and ``nn.Linear``: left
      unchanged.
    """
    if isinstance(m, nn.BatchNorm2d):
        init.constant_(m.weight, 1)
        init.constant_(m.bias, 0)
    elif isinstance(m, nn.Conv2d):
        init.xavier_uniform_(m.weight.data)
# Disabled latency benchmark, kept as an unused module-level string literal
# (never executed).  When enabled, it would:
#   * restrict CUDA to GPU 6,
#   * build the model,
#   * run `it` forward passes on random inputs of shape (bs*gs, 3, 448, 256)
#     scaled to [-1, 1),
#   * print the mean wall-clock seconds per image (total / bs / gs / it).
# NOTE(review): CUDA kernels launch asynchronously.  Without
# torch.cuda.synchronize() before reading time.time(), this probably measures
# launch time rather than compute time.  Also confirm the 448x256 input size is
# supported by Model.forward before re-enabling.
'''import os
os.environ['CUDA_VISIBLE_DEVICES']='6'
gpu_id='cuda:0'
device = torch.device(gpu_id)
nt=build_model(device).to(device)
it=2
bs=1
gs=5
sum=0
with torch.no_grad():
for i in range(it):
A=torch.rand(bs*gs,3,448,256).cuda()
A=A*2-1
start=time.time()
nt(A)
sum+=time.time()-start
print(sum/bs/gs/it)'''