djl234 committed
Commit
c83be2c
1 Parent(s): edc384d

Create new file

Files changed (1)
  1. model_video.py +297 -0
model_video.py ADDED
@@ -0,0 +1,297 @@
+ import torch
+ from torch import nn
+ from torch.nn import init
+ import torch.nn.functional as F
+ from torch.optim import Adam
+ import numpy
+ from einops import rearrange
+ import time
+ from transformer import Transformer
+ from Intra_MLP import index_points, knn_l2
+
+ # VGG configuration ('M' marks a 2x2 max-pooling layer)
+ base = {'vgg': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M']}
+
+ # VGG-16 backbone assembled from the configuration above
+ def vgg(cfg, i=3, batch_norm=True):
+     layers = []
+     in_channels = i
+     for v in cfg:
+         if v == 'M':
+             layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
+         else:
+             conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
+             if batch_norm:
+                 layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
+             else:
+                 layers += [conv2d, nn.ReLU(inplace=True)]
+             in_channels = v
+     return layers
+
+
+ # 1x1 conv + ReLU reduction used by the second-order pooling head
+ def hsp(in_channel, out_channel):
+     layers = nn.Sequential(nn.Conv2d(in_channel, out_channel, 1, 1),
+                            nn.ReLU())
+     return layers
+
+ def cls_modulation_branch(in_channel, hidden_channel):
+     layers = nn.Sequential(nn.Linear(in_channel, hidden_channel),
+                            nn.ReLU())
+     return layers
+
+ def cls_branch(hidden_channel, class_num):
+     layers = nn.Sequential(nn.Linear(hidden_channel, class_num),
+                            nn.Sigmoid())
+     return layers
+
+ # 1x1 conv and sigmoid used by the intra-MLP branch to form the spatial mask
+ def intra():
+     layers = []
+     layers += [nn.Conv2d(512, 512, 1, 1)]
+     layers += [nn.Sigmoid()]
+     return layers
+
+ # decoder block: 1x1 conv, 3x3 conv, then 2x transposed-conv upsampling
+ def concat_r():
+     layers = []
+     layers += [nn.Conv2d(512, 512, 1, 1)]
+     layers += [nn.ReLU()]
+     layers += [nn.Conv2d(512, 512, 3, 1, 1)]
+     layers += [nn.ReLU()]
+     layers += [nn.ConvTranspose2d(512, 512, 4, 2, 1)]
+     return layers
+
+ # decoder block without upsampling, used for the finest level
+ def concat_1():
+     layers = []
+     layers += [nn.Conv2d(512, 512, 1, 1)]
+     layers += [nn.ReLU()]
+     layers += [nn.Conv2d(512, 512, 3, 1, 1)]
+     layers += [nn.ReLU()]
+     return layers
+
+ # prediction head: 2-channel map, 4x upsampling, per-pixel softmax
+ def mask_branch():
+     layers = []
+     layers += [nn.Conv2d(512, 2, 3, 1, 1)]
+     layers += [nn.ConvTranspose2d(2, 2, 8, 4, 2)]
+     layers += [nn.Softmax2d()]
+     return layers
+
+ # 3x3 convs that lift the pool2-pool5 features (128/256/512/512 channels) to 512
+ def incr_channel():
+     layers = []
+     layers += [nn.Conv2d(128, 512, 3, 1, 1)]
+     layers += [nn.Conv2d(256, 512, 3, 1, 1)]
+     layers += [nn.Conv2d(512, 512, 3, 1, 1)]
+     layers += [nn.Conv2d(512, 512, 3, 1, 1)]
+     return layers
+
+ def incr_channel2():
+     layers = []
+     layers += [nn.Conv2d(512, 512, 3, 1, 1)]
+     layers += [nn.Conv2d(512, 512, 3, 1, 1)]
+     layers += [nn.Conv2d(512, 512, 3, 1, 1)]
+     layers += [nn.Conv2d(512, 512, 3, 1, 1)]
+     layers += [nn.ReLU()]
+     return layers
+
+ def norm(x, dim):
+     squared_norm = (x ** 2).sum(dim=dim, keepdim=True)
+     normed = x / torch.sqrt(squared_norm)
+     return normed
+
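+ # Shape sketch for fuse_hsp (illustrative, assuming the default group_size of 5):
+ # given a per-group vector x of shape [B, 512] and a feature map p of shape
+ # [B * 5, 512, H, W], each row of x is repeated over the 5 frames of its group and
+ # broadcast spatially, yielding a [B * 5, 512, H, W] modulator.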
+ def fuse_hsp(x, p, group_size=5):
+     t = torch.zeros(group_size, x.size(1))
+     for i in range(x.size(0)):
+         tmp = x[i, :]
+         if i == 0:
+             nx = tmp.expand_as(t)
+         else:
+             nx = torch.cat([nx, tmp.expand_as(t)], dim=0)
+     nx = nx.view(x.size(0) * group_size, x.size(1), 1, 1)
+     y = nx.expand_as(p)
+     return y
+
+
+ class Model(nn.Module):
+     def __init__(self, device, base, incr_channel, incr_channel2, hsp1, hsp2, cls_m, cls, concat_r, concat_1, mask_branch, intra, demo_mode=False):
+         super(Model, self).__init__()
+         self.base = nn.ModuleList(base)
+         self.sp1 = hsp1
+         self.sp2 = hsp2
+         self.cls_m = cls_m
+         self.cls = cls
+         self.incr_channel1 = nn.ModuleList(incr_channel)
+         self.incr_channel2 = nn.ModuleList(incr_channel2)
+         # note: concat4/3/2 wrap the same module instances, so their weights are shared
+         self.concat4 = nn.ModuleList(concat_r)
+         self.concat3 = nn.ModuleList(concat_r)
+         self.concat2 = nn.ModuleList(concat_r)
+         self.concat1 = nn.ModuleList(concat_1)
+         self.mask = nn.ModuleList(mask_branch)
+         # indices of the pool2-pool5 outputs in the VGG layer list
+         self.extract = [13, 23, 33, 43]
+         self.device = device
+         self.group_size = 5
+         self.intra = nn.ModuleList(intra)
+         self.transformer_1 = Transformer(512, 4, 4, 782, group=self.group_size)
+         self.transformer_2 = Transformer(512, 4, 4, 782, group=self.group_size)
+         self.demo_mode = demo_mode
+
+     def forward(self, x):
+         # backbone; p collects the pool2-pool5 feature maps
+         p = list()
+         for k in range(len(self.base)):
+             x = self.base[k](x)
+             if k in self.extract:
+                 p.append(x)
+
+         # increase the channel
+         newp = list()
+         newp_T = list()
+         for k in range(len(p)):
+             np = self.incr_channel1[k](p[k])
+             np = self.incr_channel2[k](np)
+             newp.append(self.incr_channel2[4](np))
+             if k == 3:
+                 tmp_newp_T3 = self.transformer_1(newp[k])
+                 newp_T.append(tmp_newp_T3)
+             if k == 2:
+                 newp_T.append(self.transformer_2(newp[k]))
+             if k < 2:
+                 newp_T.append(None)
+
+         # intra-MLP
+         point = newp[3].view(newp[3].size(0), newp[3].size(1), -1)
+         point = point.permute(0, 2, 1)
+
+         idx = knn_l2(self.device, point, 4, 1)
+         feat = idx
+         new_point = index_points(self.device, point, idx)
+
+         group_point = new_point.permute(0, 3, 2, 1)
+         group_point = self.intra[0](group_point)
+         group_point = torch.max(group_point, 2)[0]  # [B, D', S]
+
+         intra_mask = group_point.view(group_point.size(0), group_point.size(1), 7, 7)
+         intra_mask = intra_mask + newp[3]
+
+         spa_mask = self.intra[1](intra_mask)
+
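+         # second-order (correlation) pooling: sp1/sp2 reduce channels, and the two
+         # torch.bmm calls build feature-correlation matrices whose flattened entries
+         # feed the classification branch below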
+         x = newp[3]
+         x = self.sp1(x)
+         x = x.view(-1, x.size(1), x.size(2) * x.size(3))
+         x = torch.bmm(x, x.transpose(1, 2))
+         x = x.view(-1, x.size(1) * x.size(2))
+         x = x.view(x.size(0) // self.group_size, x.size(1), -1, 1)
+         x = self.sp2(x)
+         x = x.view(-1, x.size(1), x.size(2) * x.size(3))
+         x = torch.bmm(x, x.transpose(1, 2))
+         x = x.view(-1, x.size(1) * x.size(2))
+
+         # cls pred
+         cls_modulated_vector = self.cls_m(x)
+         cls_pred = self.cls(cls_modulated_vector)
+
+         # semantic and spatial modulator
+         g1 = fuse_hsp(cls_modulated_vector, newp[0], self.group_size)
+         g2 = fuse_hsp(cls_modulated_vector, newp[1], self.group_size)
+         g3 = fuse_hsp(cls_modulated_vector, newp[2], self.group_size)
+         g4 = fuse_hsp(cls_modulated_vector, newp[3], self.group_size)
+
+         spa_1 = F.interpolate(spa_mask, size=[g1.size(2), g1.size(3)], mode='bilinear')
+         spa_1 = spa_1.expand_as(g1)
+         spa_2 = F.interpolate(spa_mask, size=[g2.size(2), g2.size(3)], mode='bilinear')
+         spa_2 = spa_2.expand_as(g2)
+         spa_3 = F.interpolate(spa_mask, size=[g3.size(2), g3.size(3)], mode='bilinear')
+         spa_3 = spa_3.expand_as(g3)
+         spa_4 = F.interpolate(spa_mask, size=[g4.size(2), g4.size(3)], mode='bilinear')
+         spa_4 = spa_4.expand_as(g4)
+
+         y4 = newp_T[3] * g4 + spa_4
+         for k in range(len(self.concat4)):
+             y4 = self.concat4[k](y4)
+
+         y3 = newp_T[2] * g3 + spa_3
+         for k in range(len(self.concat3)):
+             y3 = self.concat3[k](y3)
+             if k == 1:
+                 y3 = y3 + y4
+
+         y2 = newp[1] * g2 + spa_2
+         # print(y2.shape)
+         for k in range(len(self.concat2)):
+             y2 = self.concat2[k](y2)
+             if k == 1:
+                 y2 = y2 + y3
+
+         y1 = newp[0] * g1 + spa_1
+         for k in range(len(self.concat1)):
+             y1 = self.concat1[k](y1)
+             if k == 1:
+                 y1 = y1 + y2
+
+         y = y1
+         if self.demo_mode:
+             tmp = F.interpolate(y1, size=[14, 14], mode='bilinear')
+             tmp = tmp.permute(0, 2, 3, 1).contiguous().reshape(tmp.shape[0] * tmp.shape[2] * tmp.shape[3], tmp.shape[1])
+             tmp = tmp / torch.norm(tmp, p=2, dim=1).unsqueeze(1)
+             feat2 = (tmp @ tmp.t())
+             feat = F.interpolate(y, size=[14, 14], mode='bilinear')
+
+         # decoder
+         for k in range(len(self.mask)):
+             y = self.mask[k](y)
+         mask_pred = y[:, 0, :, :]
+
+         if self.demo_mode:
+             return cls_pred, mask_pred, feat, feat2
+         else:
+             return cls_pred, mask_pred
+
+
+ # build the whole network
+ def build_model(device, demo_mode=False):
+     return Model(device,
+                  vgg(base['vgg']),
+                  incr_channel(),
+                  incr_channel2(),
+                  hsp(512, 64),
+                  hsp(64 ** 2, 32),
+                  cls_modulation_branch(32 ** 2, 512),
+                  cls_branch(512, 78),
+                  concat_r(),
+                  concat_1(),
+                  mask_branch(),
+                  intra(), demo_mode)
+
+ # weight init
+ def xavier(param):
+     init.xavier_uniform_(param)
+
+ def weights_init(m):
+     if isinstance(m, nn.Conv2d):
+         xavier(m.weight.data)
+     elif isinstance(m, nn.BatchNorm2d):
+         init.constant_(m.weight, 1)
+         init.constant_(m.bias, 0)
+
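+ # A minimal usage sketch, assuming the usual nn.Module.apply pattern for these
+ # init helpers:
+ #
+ #     model = build_model(device)
+ #     model.apply(weights_init)   # Xavier for Conv2d weights, constants for BatchNorm2d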
+ '''import os
+ os.environ['CUDA_VISIBLE_DEVICES']='6'
+ gpu_id='cuda:0'
+ device = torch.device(gpu_id)
+ nt=build_model(device).to(device)
+ it=2
+ bs=1
+ gs=5
+ sum=0
+ with torch.no_grad():
+     for i in range(it):
+         A=torch.rand(bs*gs,3,448,256).cuda()
+         A=A*2-1
+         start=time.time()
+         nt(A)
+         sum+=time.time()-start
+ print(sum/bs/gs/it)'''
+