import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from MT import FeatureTransformer
from torch.cuda.amp import autocast as autocast
from flow_tools import viz_img_seq, save_img_seq, plt_show_img_flow
from copy import deepcopy
from V1 import V1
import matplotlib.pyplot as plt
from io import BytesIO
from PIL import Image


def conv(in_planes, out_planes, kernel_size=3, stride=1, dilation=1, isReLU=True):
    """Build a padded Conv2d, optionally followed by a GELU activation.

    Despite its name, ``isReLU=True`` appends ``nn.GELU()`` (not ReLU); the
    parameter name is kept so existing callers keep working. The padding
    ``((kernel_size - 1) * dilation) // 2`` keeps the spatial size unchanged
    for odd kernel sizes when ``stride == 1``.

    Returns:
        ``nn.Sequential`` whose index 0 is the Conv2d and, when ``isReLU`` is
        true, index 1 is the GELU.
    """
    layers = [
        nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
                  dilation=dilation, padding=((kernel_size - 1) * dilation) // 2,
                  bias=True)
    ]
    if isReLU:
        layers.append(nn.GELU())
    return nn.Sequential(*layers)


def plt_attention(attention, h, w):
    """Plot each attention map for one query location and return the plot as an image array.

    Args:
        attention: sequence of tensors, each indexed as ``[0, :, :, h, w]``
            (first batch item, attention over the 2-D grid for query (h, w)).
        h, w: query coordinates; they are also marked on every subplot.

    Returns:
        ``np.ndarray`` of the rendered PNG figure (as decoded by PIL).
    """
    num_maps = len(attention)
    # Two rows of subplots. Round the column count up so that an odd number of
    # maps (or a single map) still fits in the 2 x col grid.
    col = max(1, (num_maps + 1) // 2)
    fig = plt.figure(figsize=(10, 8))
    for i, attn_map in enumerate(attention):
        viz = attn_map[0, :, :, h, w].detach().cpu().numpy()
        ax = fig.add_subplot(2, col, i + 1)
        ax.imshow(viz, cmap="rainbow", interpolation="bilinear")
        # Mark the query location with a grey halo around a red dot.
        ax.scatter(w, h, color='grey', s=300, alpha=0.5)
        ax.scatter(w, h, color='red', s=150, alpha=0.5)
        if i == num_maps - 1:
            ax.set_title(" Final Iteration")
        else:
            ax.set_title(" Iteration %d" % (i + 1))
        ax.set_xticks([])
        ax.set_yticks([])
    fig.tight_layout()

    # Render the figure to an in-memory PNG and decode it back into an array.
    buf = BytesIO()
    fig.savefig(buf, format='png')
    buf.seek(0)
    plt.close(fig)
    img = np.array(Image.open(buf))
    return img


class FlowDecoder(nn.Module):
    """Predict a 2-channel flow map from a feature map.

    A densely connected stack of 1x1 convolutions: from the third stage on,
    each stage consumes the concatenation of the two preceding stages'
    outputs. Only the final ``predict_flow`` layer uses a 3x3 kernel (the
    ``conv`` default). The spatial size of the input is preserved.
    """

    # can reduce 25% of training time.
    def __init__(self, ch_in):
        super(FlowDecoder, self).__init__()
        self.conv1 = conv(ch_in, 256, kernel_size=1)
        self.conv2 = conv(256, 128, kernel_size=1)
        self.conv3 = conv(256 + 128, 96, kernel_size=1)
        self.conv4 = conv(96 + 128, 64, kernel_size=1)
        self.conv5 = conv(96 + 64, 32, kernel_size=1)
        self.feat_dim = 32  # output channels of conv5
        self.predict_flow = conv(64 + 32, 2, isReLU=False)

    def forward(self, x):
        """Return a 2-channel flow map with the same spatial size as ``x``."""
        x1 = self.conv1(x)
        x2 = self.conv2(x1)
        x3 = self.conv3(torch.cat([x1, x2], dim=1))
        x4 = self.conv4(torch.cat([x2, x3], dim=1))
        x5 = self.conv5(torch.cat([x3, x4], dim=1))
        flow = self.predict_flow(torch.cat([x4, x5], dim=1))
        return flow


class FFV1DNN(nn.Module):
    """Optical-flow network: V1 front end -> FeatureTransformer -> FlowDecoder.

    ``self.ffv1`` turns an image sequence into a spatio-temporal feature map.
    ``self.MT`` refines it over several layers. The shared ``self.decoder``
    predicts a flow from the V1 features and from every MT output. Each flow
    is then upsampled to full resolution, bilinearly and/or with the learned
    convex upsampler.
    """

    def __init__(self, num_scales=8, num_cells=256, upsample_factor=8, feature_channels=256,
                 scale_factor=16, num_layers=6, ):
        super(FFV1DNN, self).__init__()
        self.ffv1 = V1(spatial_num=num_cells // num_scales, scale_num=num_scales,
                       scale_factor=scale_factor, kernel_radius=7,
                       num_ft=num_cells // num_scales, kernel_size=6, average_time=True)
        self.v1_kz = 7
        self.scale_factor = scale_factor
        # Per-level scale ratio: (1 / scale_factor) ** (1 / (num_scales - 1)).
        scale_each_level = np.exp(1 / (num_scales - 1) * np.log(1 / scale_factor))
        self.scale_num = num_scales
        self.scale_each_level = scale_each_level
        v1_channel = self.ffv1.num_after_st
        self.num_scales = num_scales
        self.MT_channel = feature_channels
        # The decoder and MT consume V1 output directly, so the widths must match.
        assert self.MT_channel == v1_channel
        self.feature_channels = feature_channels
        self.upsample_factor = upsample_factor
        self.num_layers = num_layers
        # convex upsampling: concat feature0 and flow as input
        self.upsampler_1 = nn.Sequential(nn.Conv2d(2 + feature_channels, 256, 3, 1, 1),
                                         nn.ReLU(inplace=True),
                                         nn.Conv2d(256, 256, 3, 1, 1),
                                         nn.ReLU(inplace=True),
                                         nn.Conv2d(256, upsample_factor ** 2 * 9, 3, 1, 1))
        self.decoder = FlowDecoder(feature_channels)
        self.conv_feat = nn.ModuleList([conv(v1_channel, feature_channels, 1) for _ in range(num_scales)])
        self.MT = FeatureTransformer(d_model=feature_channels, num_layers=self.num_layers)  # 2*2*8*scale

    def upsample_flow(self, flow, feature, upsampler=None, bilinear=False, upsample_factor=4):
        """Upsample ``flow`` by ``upsample_factor`` and scale its values to match.

        With ``bilinear=True``: bilinear interpolation (``align_corners=True``)
        followed by multiplication by ``upsample_factor``; ``feature`` and
        ``upsampler`` are ignored.

        Otherwise (convex upsampling): ``upsampler`` maps ``cat(flow, feature)``
        to a mask with ``9 * upsample_factor**2`` channels. Each fine pixel's
        flow is then a softmax-weighted combination of the 3x3 neighbourhood of
        the (scaled) coarse flow.

        Returns:
            Tensor of shape ``[B, C, upsample_factor * H, upsample_factor * W]``.
        """
        if bilinear:
            return F.interpolate(flow, scale_factor=upsample_factor, mode='bilinear',
                                 align_corners=True) * upsample_factor

        # convex upsampling
        concat = torch.cat((flow, feature), dim=1)
        mask = upsampler(concat)
        b, flow_channel, h, w = flow.shape
        mask = mask.view(b, 1, 9, upsample_factor, upsample_factor, h, w)  # [B, 1, 9, K, K, H, W]
        mask = torch.softmax(mask, dim=2)  # weights over the 3x3 neighbourhood sum to 1

        up_flow = F.unfold(upsample_factor * flow, [3, 3], padding=1)
        up_flow = up_flow.view(b, flow_channel, 9, 1, 1, h, w)  # [B, 2, 9, 1, 1, H, W]

        up_flow = torch.sum(mask * up_flow, dim=2)  # [B, 2, K, K, H, W]
        up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)  # [B, 2, K, H, K, W]
        up_flow = up_flow.reshape(b, flow_channel, upsample_factor * h,
                                  upsample_factor * w)  # [B, 2, K*H, K*W]
        return up_flow

    @staticmethod
    def _prepare_images(image_list, to_gray=True):
        """Rescale images to [0, 1] when they look like 0-255 data; optionally convert RGB to gray.

        The rescaling decision is taken from ``image_list[0].max() > 10`` and
        applied to every image. When ``to_gray`` is true and the images have 3
        channels, they are converted with Gray = R*0.299 + G*0.587 + B*0.114
        and kept 4-D with a single channel.
        """
        with torch.no_grad():
            if image_list[0].max() > 10:
                image_list = [img / 255.0 for img in image_list]  # [B, 1, H, W] 0-1
            if to_gray and image_list[0].shape[1] == 3:
                image_list = [(img[:, 0, :, :] * 0.299 + img[:, 1, :, :] * 0.587
                               + img[:, 2, :, :] * 0.114).unsqueeze(1)
                              for img in image_list]
        return image_list

    def _decode_v1_only(self, st_component):
        """Decode a single flow from V1 features and upsample it bilinearly (no-MT path)."""
        flow = self.decoder(st_component)
        return [self.upsample_flow(flow, feature=None, bilinear=True, upsample_factor=8)]

    def _decode_flows(self, st_component, motion_feature, attn):
        """Decode flows from V1 features and each MT output; return (bilinear, final) lists.

        The first entry of both lists is the V1-only flow, upsampled bilinearly.
        Every later entry of the final list is convex-upsampled with
        ``self.upsampler_1``, guided by the matching element of ``attn``.
        NOTE(review): the factor 8 is hard-coded, while ``upsampler_1``'s mask
        size comes from ``self.upsample_factor``; they only agree for the
        default upsample_factor=8.
        """
        flows = [self.decoder(st_component)] + [self.decoder(feat) for feat in motion_feature]
        flows_bi = [self.upsample_flow(flow, feature=None, bilinear=True, upsample_factor=8)
                    for flow in flows]
        flows_up = [flows_bi[0]] + [
            self.upsample_flow(flow, upsampler=self.upsampler_1, feature=guide, upsample_factor=8)
            for flow, guide in zip(flows[1:], attn)
        ]
        assert len(flows_bi) == len(flows_up)
        return flows_bi, flows_up

    def forward(self, image_list, mix_enable=True, layer=6):
        """Estimate a sequence of flows from ``image_list``.

        Args:
            image_list: list of image tensors, ``[B, 1, H, W]`` or
                ``[B, 3, H, W]`` (RGB is converted to gray). Values above 10
                are treated as 0-255 and rescaled.
            mix_enable: enable CUDA autocast (mixed precision).
            layer: when not None, persistently overwrites ``self.MT.num_layers``
                and ``self.num_layers`` before running.

        Returns:
            dict with ``"flow_seq"`` (final upsampled flows, V1-only flow first).
            Unless the number of layers is 0, it also has ``"flow_seq_bi"`` (the
            bilinearly upsampled flows).
        """
        if layer is not None:
            self.MT.num_layers = layer
            self.num_layers = layer
        results_dict = {}
        image_list = self._prepare_images(image_list, to_gray=True)
        with autocast(enabled=mix_enable):
            # NOTE: V1 runs with gradients enabled (it was once wrapped in no_grad
            # to test whether a trainable V1 is needed).
            st_component = self.ffv1(image_list)
            if self.num_layers == 0:
                results_dict["flow_seq"] = self._decode_v1_only(st_component)
                return results_dict
            motion_feature, attn = self.MT.forward_save_mem(st_component)
            flows_bi, flows_up = self._decode_flows(st_component, motion_feature, attn)
        results_dict["flow_seq"] = flows_up
        results_dict["flow_seq_bi"] = flows_bi
        return results_dict

    def forward_test(self, image_list, mix_enable=True, layer=6):
        """Inference variant of ``forward``, with the same arguments and return dict.

        Two visible differences from ``forward``: inputs are only rescaled (no
        RGB-to-gray conversion), and ``self.MT.forward_save_mem`` is unpacked
        into three values instead of two.
        NOTE(review): ``forward`` and ``forward_test`` cannot both match the
        return arity of ``forward_save_mem``; confirm which one is right against MT.
        """
        if layer is not None:
            self.MT.num_layers = layer
            self.num_layers = layer
        results_dict = {}
        image_list = self._prepare_images(image_list, to_gray=False)
        with autocast(enabled=mix_enable):
            st_component = self.ffv1(image_list)
            if self.num_layers == 0:
                results_dict["flow_seq"] = self._decode_v1_only(st_component)
                return results_dict
            motion_feature, attn, _ = self.MT.forward_save_mem(st_component)
            flows_bi, flows_up = self._decode_flows(st_component, motion_feature, attn)
        results_dict["flow_seq"] = flows_up
        results_dict["flow_seq_bi"] = flows_bi
        return results_dict

    @torch.no_grad()
    def forward_viz(self, image_list, layer=None, x=50, y=50):
        """Run the model without gradients and render visualisations.

        Args:
            image_list: as in ``forward``.
            layer: when not None, overwrites ``self.MT.num_layers`` only.
            x, y: query position as percentages (0-100) of the width/height.
                They are mapped onto the ``(H // 8, W // 8)`` grid used for the
                attention plot.

        Returns:
            dict with ``"flow_seq"``, ``"activation"`` (from
            ``self.ffv1.visualize_activation``), ``"attention"`` (image array
            from ``plt_attention``) and ``"flow"`` (from ``plt_show_img_flow``).
        """
        x = x / 100
        y = y / 100
        if layer is not None:
            self.MT.num_layers = layer
        results_dict = {}
        image_list = self._prepare_images(image_list, to_gray=True)
        image_list_ori = deepcopy(image_list)
        _, _, H, W = image_list[0].shape
        MT_size = (H // 8, W // 8)
        with autocast(enabled=True):
            st_component = self.ffv1(image_list)
            activation = self.ffv1.visualize_activation(st_component)
            motion_feature, attn, attn_viz = self.MT(st_component)
            _, flows_up = self._decode_flows(st_component, motion_feature, attn)
        results_dict["flow_seq"] = flows_up

        flows_to_plot = flows_up[:-1]
        print(len(flows_to_plot), len(attn_viz))
        flow = plt_show_img_flow(image_list_ori, flows_to_plot)
        h = int(MT_size[0] * y)
        w = int(MT_size[1] * x)
        attention = plt_attention(attn_viz, h=h, w=w)
        print("done")
        results_dict["activation"] = activation
        results_dict["attention"] = attention
        results_dict["flow"] = flow
        plt.clf()
        plt.cla()
        plt.close()
        return results_dict

    def num_parameters(self):
        """Return the number of trainable (``requires_grad``) parameter elements."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def init_weights(self):
        """Kaiming-normal init for every Conv1d/Conv2d/ConvTranspose2d submodule, with zero biases.

        NOTE(review): this walks all submodules, including ``self.ffv1`` and
        ``self.MT``. Confirm that re-initialising any convolutions inside V1 is
        intended before calling it.
        """
        # self.modules() yields the modules themselves; named_modules() would
        # yield (name, module) tuples, which never match the isinstance check.
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.Conv1d, nn.ConvTranspose2d)):
                nn.init.kaiming_normal_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

    @staticmethod
    def demo(file=None):
        """Time forward + backward passes on random CUDA input and report NaNs."""
        import time
        from utils import torch_utils as utils
        frame_list = [torch.randn([4, 1, 512, 512], device="cuda")] * 11
        model = FFV1DNN(num_scales=8, scale_factor=16, num_cells=256, upsample_factor=8,
                        num_layers=6, feature_channels=256).cuda()
        if file is not None:
            model = utils.restore_model(model, file)
        print(model.num_parameters())
        for _ in range(100):
            start = time.time()
            # forward_viz runs under torch.no_grad(), so its outputs have no
            # graph and .backward() would raise; use the differentiable forward.
            output = model(frame_list, layer=7)
            final_flow = output["flow_seq"][-1]
            torch.mean(final_flow).backward()
            print(torch.any(torch.isnan(final_flow)))
            end = time.time()
            print(end - start)
            print("#================================++#")


if __name__ == '__main__':
    FFV1DNN.demo(None)