sunana committed
Commit a5423cb
1 Parent(s): 7f60e6b

main model

Files changed (1):
  1. FFV1MT_MS.py +311 -0
FFV1MT_MS.py ADDED
@@ -0,0 +1,311 @@
+import numpy as np
+import torch
+from torch import nn
+from torch.nn import functional as F
+from MT import FeatureTransformer
+from torch.cuda.amp import autocast as autocast
+from flow_tools import viz_img_seq, save_img_seq, plt_show_img_flow
+from copy import deepcopy
+from V1 import V1
+import matplotlib.pyplot as plt
+from io import BytesIO
+from PIL import Image
+
+
+def conv(in_planes, out_planes, kernel_size=3, stride=1, dilation=1, isReLU=True):
+    if isReLU:
+        return nn.Sequential(
+            nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
+                      dilation=dilation,
+                      padding=((kernel_size - 1) * dilation) // 2, bias=True),
+            nn.GELU()
+        )
+    else:
+        return nn.Sequential(
+            nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
+                      dilation=dilation,
+                      padding=((kernel_size - 1) * dilation) // 2, bias=True)
+        )
+
+
+def plt_attention(attention, h, w):
+    col = len(attention) // 2
+    fig = plt.figure(figsize=(10, 8))
+
+    for i in range(len(attention)):
+        viz = attention[i][0, :, :, h, w].detach().cpu().numpy()
+        # viz = viz[7:-7, 7:-7]
+        if i == 0:
+            viz_all = viz
+        else:
+            viz_all = viz_all + viz
+
+        ax1 = fig.add_subplot(2, col, i + 1)
+        img = ax1.imshow(viz, cmap="rainbow", interpolation="bilinear")
+        ax1.scatter(w, h, color='grey', s=300, alpha=0.5)
+        ax1.scatter(w, h, color='red', s=150, alpha=0.5)
+        plt.title(" Iteration %d" % (i + 1))
+        if i == len(attention) - 1:
+            plt.title(" Final Iteration")
+        plt.xticks([])
+        plt.yticks([])
+
+    # tight layout
+    plt.tight_layout()
+    # save the figure
+    buf = BytesIO()
+    plt.savefig(buf, format='png')
+    buf.seek(0)
+    plt.close()
+    # convert the figure to an array
+    img = Image.open(buf)
+    img = np.array(img)
+    return img
+
+
+class FlowDecoder(nn.Module):
+    # can reduce 25% of training time.
+    def __init__(self, ch_in):
+        super(FlowDecoder, self).__init__()
+        self.conv1 = conv(ch_in, 256, kernel_size=1)
+        self.conv2 = conv(256, 128, kernel_size=1)
+        self.conv3 = conv(256 + 128, 96, kernel_size=1)
+        self.conv4 = conv(96 + 128, 64, kernel_size=1)
+        self.conv5 = conv(96 + 64, 32, kernel_size=1)
+
+        self.feat_dim = 32
+        self.predict_flow = conv(64 + 32, 2, isReLU=False)
+
+    def forward(self, x):
+        x1 = self.conv1(x)
+        x2 = self.conv2(x1)
+        x3 = self.conv3(torch.cat([x1, x2], dim=1))
+        x4 = self.conv4(torch.cat([x2, x3], dim=1))
+        x5 = self.conv5(torch.cat([x3, x4], dim=1))
+        flow = self.predict_flow(torch.cat([x4, x5], dim=1))
+        return flow
+
+
+class FFV1DNN(nn.Module):
+    def __init__(self,
+                 num_scales=8,
+                 num_cells=256,
+                 upsample_factor=8,
+                 feature_channels=256,
+                 scale_factor=16,
+                 num_layers=6,
+                 ):
+        super(FFV1DNN, self).__init__()
+        self.ffv1 = V1(spatial_num=num_cells // num_scales, scale_num=num_scales, scale_factor=scale_factor,
+                       kernel_radius=7, num_ft=num_cells // num_scales,
+                       kernel_size=6, average_time=True)
+        self.v1_kz = 7
+        self.scale_factor = scale_factor
+        scale_each_level = np.exp(1 / (num_scales - 1) * np.log(1 / scale_factor))
+        self.scale_num = num_scales
+        self.scale_each_level = scale_each_level
+        v1_channel = self.ffv1.num_after_st
+        self.num_scales = num_scales
+        self.MT_channel = feature_channels
+        assert self.MT_channel == v1_channel
+        self.feature_channels = feature_channels
+
+        self.upsample_factor = upsample_factor
+        self.num_layers = num_layers
+        # convex upsampling: concat feature0 and flow as input
+        self.upsampler_1 = nn.Sequential(nn.Conv2d(2 + feature_channels, 256, 3, 1, 1),
+                                         nn.ReLU(inplace=True),
+                                         nn.Conv2d(256, 256, 3, 1, 1),
+                                         nn.ReLU(inplace=True),
+                                         nn.Conv2d(256, upsample_factor ** 2 * 9, 3, 1, 1))
+        self.decoder = FlowDecoder(feature_channels)
+        self.conv_feat = nn.ModuleList([conv(v1_channel, feature_channels, 1) for i in range(num_scales)])
+        self.MT = FeatureTransformer(d_model=feature_channels, num_layers=self.num_layers)
+
+    # 2*2*8*scale
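+    # Convex upsampling (implemented in upsample_flow below): upsampler_1 predicts, for
+    # every fine-resolution pixel, 9 softmax weights over the 3x3 neighbourhood of the
+    # coarse flow; the upsampled flow is that convex combination of (upsample_factor * flow),
+    # giving an output of shape [B, 2, upsample_factor * H, upsample_factor * W].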
+    def upsample_flow(self, flow, feature, upsampler=None, bilinear=False, upsample_factor=4):
+        if bilinear:
+            up_flow = F.interpolate(flow, scale_factor=upsample_factor,
+                                    mode='bilinear', align_corners=True) * upsample_factor
+        else:
+            # convex upsampling
+            concat = torch.cat((flow, feature), dim=1)
+            mask = upsampler(concat)
+            b, flow_channel, h, w = flow.shape
+            mask = mask.view(b, 1, 9, upsample_factor, upsample_factor, h, w)  # [B, 1, 9, K, K, H, W]
+            mask = torch.softmax(mask, dim=2)
+
+            up_flow = F.unfold(upsample_factor * flow, [3, 3], padding=1)
+            up_flow = up_flow.view(b, flow_channel, 9, 1, 1, h, w)  # [B, 2, 9, 1, 1, H, W]
+
+            up_flow = torch.sum(mask * up_flow, dim=2)  # [B, 2, K, K, H, W]
+            up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)  # [B, 2, K, H, K, W]
+            up_flow = up_flow.reshape(b, flow_channel, upsample_factor * h,
+                                      upsample_factor * w)  # [B, 2, K*H, K*W]
+
+        return up_flow
+
+    def forward(self, image_list, mix_enable=True, layer=6):
+        if layer is not None:
+            self.MT.num_layers = layer
+            self.num_layers = layer
+        results_dict = {}
+        padding = self.v1_kz * self.scale_factor
+        with torch.no_grad():
+            if image_list[0].max() > 10:
+                image_list = [img / 255.0 for img in image_list]  # [B, 1, H, W] 0-1
+            if image_list[0].shape[1] == 3:
+                # convert to gray using transform Gray = R*0.299 + G*0.587 + B*0.114
+                image_list = [img[:, 0, :, :] * 0.299 + img[:, 1, :, :] * 0.587 + img[:, 2, :, :] * 0.114 for img in
+                              image_list]
+                image_list = [img.unsqueeze(1) for img in image_list]
+
+        B, _, H, W = image_list[0].shape
+        MT_size = (H // 8, W // 8)
+        with autocast(enabled=mix_enable):
+            # with torch.no_grad():  # TODO: only for testing whether a trainable V1 is needed.
+            st_component = self.ffv1(image_list)
+            # viz_img_seq(image_scale, if_debug=True)
+            if self.num_layers == 0:
+                motion_feature = [st_component]
+                flows = [self.decoder(feature) for feature in motion_feature]
+                flows_up = [self.upsample_flow(flow, feature=None, bilinear=True, upsample_factor=8) for flow in flows]
+                results_dict["flow_seq"] = flows_up
+                return results_dict
+            motion_feature, attn = self.MT.forward_save_mem(st_component)
+            flow_v1 = self.decoder(st_component)
+
+            flows = [flow_v1] + [self.decoder(feature) for feature in motion_feature]
+            flows_bi = [self.upsample_flow(flow, feature=None, bilinear=True, upsample_factor=8) for flow in flows]
+            flows_up = [flows_bi[0]] + \
+                       [self.upsample_flow(flows, upsampler=self.upsampler_1, feature=attn, upsample_factor=8) for
+                        flows, attn in zip(flows[1:], attn)]
+            assert len(flows_bi) == len(flows_up)
+            results_dict["flow_seq"] = flows_up
+            results_dict["flow_seq_bi"] = flows_bi
+            return results_dict
+
+    def forward_test(self, image_list, mix_enable=True, layer=6):
+        if layer is not None:
+            self.MT.num_layers = layer
+            self.num_layers = layer
+        results_dict = {}
+        padding = self.v1_kz * self.scale_factor
+        with torch.no_grad():
+            if image_list[0].max() > 10:
+                image_list = [img / 255.0 for img in image_list]  # [B, 1, H, W] 0-1
+
+        B, _, H, W = image_list[0].shape
+        MT_size = (H // 8, W // 8)
+        with autocast(enabled=mix_enable):
+            st_component = self.ffv1(image_list)
+            # viz_img_seq(image_scale, if_debug=True)
+            if self.num_layers == 0:
+                motion_feature = [st_component]
+                flows = [self.decoder(feature) for feature in motion_feature]
+                flows_up = [self.upsample_flow(flow, feature=None, bilinear=True, upsample_factor=8) for flow in flows]
+                results_dict["flow_seq"] = flows_up
+                return results_dict
+            motion_feature, attn, _ = self.MT.forward_save_mem(st_component)
+            flow_v1 = self.decoder(st_component)
+            flows = [flow_v1] + [self.decoder(feature) for feature in motion_feature]
+            flows_bi = [self.upsample_flow(flow, feature=None, bilinear=True, upsample_factor=8) for flow in flows]
+            flows_up = [flows_bi[0]] + \
+                       [self.upsample_flow(flows, upsampler=self.upsampler_1, feature=attn, upsample_factor=8) for
+                        flows, attn in zip(flows[1:], attn)]
+            assert len(flows_bi) == len(flows_up)
+            results_dict["flow_seq"] = flows_up
+            results_dict["flow_seq_bi"] = flows_bi
+            return results_dict
+
+    def forward_viz(self, image_list, layer=None, x=50, y=50):
+        x = x / 100
+        y = y / 100
+        if layer is not None:
+            self.MT.num_layers = layer
+        results_dict = {}
+        padding = self.v1_kz * self.scale_factor
+        with torch.no_grad():
+            if image_list[0].max() > 10:
+                image_list = [img / 255.0 for img in image_list]  # [B, 1, H, W] 0-1
+            if image_list[0].shape[1] == 3:
+                # convert to gray using transform Gray = R*0.299 + G*0.587 + B*0.114
+                image_list = [img[:, 0, :, :] * 0.299 + img[:, 1, :, :] * 0.587 + img[:, 2, :, :] * 0.114 for img in
+                              image_list]
+                image_list = [img.unsqueeze(1) for img in image_list]
+            image_list_ori = deepcopy(image_list)
+
+        B, _, H, W = image_list[0].shape
+        MT_size = (H // 8, W // 8)
+        with autocast(enabled=True):
+            st_component = self.ffv1(image_list)
+            activation = self.ffv1.visualize_activation(st_component)
+            # viz_img_seq(image_scale, if_debug=True)
+            motion_feature, attn, attn_viz = self.MT(st_component)
+            flow_v1 = self.decoder(st_component)
+
+            flows = [flow_v1] + [self.decoder(feature) for feature in motion_feature]
+            flows_bi = [self.upsample_flow(flow, feature=None, bilinear=True, upsample_factor=8) for flow in flows]
+            flows_up = [flows_bi[0]] + \
+                       [self.upsample_flow(flows, upsampler=self.upsampler_1, feature=attn, upsample_factor=8) for
+                        flows, attn in zip(flows[1:], attn)]
+            assert len(flows_bi) == len(flows_up)
+            results_dict["flow_seq"] = flows_up
+            # select iterations 1, 3, 5 and the final one for visualization
+            flows_up = [flows_up[i] for i in [0, 2, 4]] + [flows_up[-1]]
+            attn_viz = [attn_viz[i] for i in [0, 2, 4]] + [attn_viz[-1]]
+            flow = plt_show_img_flow(image_list_ori, flows_up)
+            h = int(MT_size[0] * y)
+            w = int(MT_size[1] * x)
+            attention = plt_attention(attn_viz, h=h, w=w)
+            print("done")
+            results_dict["activation"] = activation
+            results_dict["attention"] = attention
+            results_dict["flow"] = flow
+
+        return results_dict
+
+    def num_parameters(self):
+        return sum(
+            [p.data.nelement() if p.requires_grad else 0 for p in self.parameters()])
+
+    def init_weights(self):
+        for layer in self.modules():
+            if isinstance(layer, nn.Conv2d):
+                nn.init.kaiming_normal_(layer.weight)
+                if layer.bias is not None:
+                    nn.init.constant_(layer.bias, 0)
+            if isinstance(layer, nn.Conv1d):
+                nn.init.kaiming_normal_(layer.weight)
+                if layer.bias is not None:
+                    nn.init.constant_(layer.bias, 0)
+
+            elif isinstance(layer, nn.ConvTranspose2d):
+                nn.init.kaiming_normal_(layer.weight)
+                if layer.bias is not None:
+                    nn.init.constant_(layer.bias, 0)
+
+    @staticmethod
+    def demo(file=None):
+        import time
+        from utils import torch_utils as utils
+        frame_list = [torch.randn([4, 1, 512, 512], device="cuda")] * 11
+        model = FFV1DNN(num_scales=8, scale_factor=16, num_cells=256, upsample_factor=8, num_layers=6,
+                        feature_channels=256).cuda()
+        if file is not None:
+            model = utils.restore_model(model, file)
+        print(model.num_parameters())
+        for i in range(100):
+            start = time.time()
+            output = model.forward_viz(frame_list, layer=7)
+            # print(output["flow_seq"][-1])
+            torch.mean(output["flow_seq"][-1]).backward()
+            print(torch.any(torch.isnan(output["flow_seq"][-1])))
+            end = time.time()
+            print(end - start)
+            print("#================================++#")
+
+
+if __name__ == '__main__':
+    FFV1DNN.demo(None)
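
A minimal usage sketch of the module added above, assuming the sibling modules V1.py, MT.py and flow_tools.py from this repo are importable and a CUDA device is available; the frame count and input shape follow demo():

import torch
from FFV1MT_MS import FFV1DNN

model = FFV1DNN(num_scales=8, scale_factor=16, num_cells=256,
                upsample_factor=8, num_layers=6, feature_channels=256).cuda().eval()
frames = [torch.rand(1, 1, 512, 512, device="cuda") for _ in range(11)]  # 11 grayscale frames in [0, 1]
with torch.no_grad():
    out = model(frames)          # returns a dict; "flow_seq" holds one flow per refinement step
flow = out["flow_seq"][-1]       # final full-resolution flow, [B, 2, H, W]
print(flow.shape)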