from models.BaseNetwork import BaseNetwork from models.transformer_base.ffn_base import FusionFeedForward from models.transformer_base.attention_flow import SWMHSA_depthGlobalWindowConcatLN_qkFlow_reweightFlow from models.transformer_base.attention_base import TMHSA import torch import torch.nn as nn from functools import reduce import torch.nn.functional as F class Model(nn.Module): def __init__(self, config): super(Model, self).__init__() self.net = FGT(config['tw'], config['sw'], config['gd'], config['input_resolution'], config['in_channel'], config['cnum'], config['flow_inChannel'], config['flow_cnum'], config['frame_hidden'], config['flow_hidden'], config['PASSMASK'], config['numBlocks'], config['kernel_size'], config['stride'], config['padding'], config['num_head'], config['conv_type'], config['norm'], config['use_bias'], config['ape'], config['mlp_ratio'], config['drop'], config['init_weights']) def forward(self, frames, flows, masks): ret = self.net(frames, flows, masks) return ret class Encoder(nn.Module): def __init__(self, in_channels): super(Encoder, self).__init__() self.group = [1, 2, 4, 8, 1] self.layers = nn.ModuleList([ nn.Conv2d(in_channels, 64, kernel_size=3, stride=2, padding=1), nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1), nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1), nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1), nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(256, 384, kernel_size=3, stride=1, padding=1, groups=1), nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(640, 512, kernel_size=3, stride=1, padding=1, groups=2), nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(768, 384, kernel_size=3, stride=1, padding=1, groups=4), nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(640, 256, kernel_size=3, stride=1, padding=1, groups=8), nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(512, 128, kernel_size=3, stride=1, padding=1, groups=1), nn.LeakyReLU(0.2, inplace=True) ]) def forward(self, x): bt, c, h, w = x.size() h, w = h // 4, w // 4 out = x for i, layer in enumerate(self.layers): if i == 8: x0 = out if i > 8 and i % 2 == 0: g = self.group[(i - 8) // 2] x = x0.view(bt, g, -1, h, w) o = out.view(bt, g, -1, h, w) out = torch.cat([x, o], 2).view(bt, -1, h, w) out = layer(out) return out class AddPosEmb(nn.Module): def __init__(self, h, w, in_channels, out_channels): super(AddPosEmb, self).__init__() self.proj = nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=True, groups=out_channels) self.h, self.w = h, w def forward(self, x, h=0, w=0): B, N, C = x.shape if h == 0 and w == 0: assert N == self.h * self.w, 'Wrong input size' else: assert N == h * w, 'Wrong input size during inference' feat_token = x if h == 0 and w == 0: cnn_feat = feat_token.transpose(1, 2).view(B, C, self.h, self.w) else: cnn_feat = feat_token.transpose(1, 2).view(B, C, h, w) x = self.proj(cnn_feat) + cnn_feat x = x.flatten(2).transpose(1, 2) return x class Vec2Patch(nn.Module): def __init__(self, channel, hidden, output_size, kernel_size, stride, padding): super(Vec2Patch, self).__init__() self.relu = nn.LeakyReLU(0.2, inplace=True) c_out = reduce((lambda x, y: x * y), kernel_size) * channel self.embedding = nn.Linear(hidden, c_out) self.restore = nn.Fold(output_size=output_size, kernel_size=kernel_size, stride=stride, padding=padding) self.kernel_size = kernel_size self.stride = stride self.padding = padding def forward(self, x, output_h=0, output_w=0): feat = self.embedding(x) feat = feat.permute(0, 2, 1) if output_h != 0 or output_w != 0: feat = F.fold(feat, output_size=(output_h, output_w), kernel_size=self.kernel_size, stride=self.stride, padding=self.padding) else: feat = self.restore(feat) return feat class TemporalTransformer(nn.Module): def __init__(self, token_size, frame_hidden, num_heads, t_groupSize, mlp_ratio, dropout, n_vecs, t2t_params): super(TemporalTransformer, self).__init__() self.attention = TMHSA(token_size=token_size, group_size=t_groupSize, d_model=frame_hidden, head=num_heads, p=dropout) self.ffn = FusionFeedForward(frame_hidden, mlp_ratio, n_vecs, t2t_params, p=dropout) self.norm1 = nn.LayerNorm(frame_hidden) self.norm2 = nn.LayerNorm(frame_hidden) self.dropout = nn.Dropout(p=dropout) def forward(self, x, t, h, w, output_size): token_size = h * w s = self.norm1(x) x = x + self.dropout(self.attention(s, t, h, w)) y = self.norm2(x) x = x + self.ffn(y, token_size, output_size[0], output_size[1]) return x class SpatialTransformer(nn.Module): def __init__(self, token_size, frame_hidden, flow_hidden, num_heads, s_windowSize, g_downSize, mlp_ratio, dropout, n_vecs, t2t_params): super(SpatialTransformer, self).__init__() self.attention = SWMHSA_depthGlobalWindowConcatLN_qkFlow_reweightFlow(token_size=token_size, window_size=s_windowSize, kernel_size=g_downSize, d_model=frame_hidden, flow_dModel=flow_hidden, head=num_heads, p=dropout) self.ffn = FusionFeedForward(frame_hidden, mlp_ratio, n_vecs, t2t_params, p=dropout) self.norm = nn.LayerNorm(frame_hidden) self.dropout = nn.Dropout(p=dropout) def forward(self, x, f, t, h, w, output_size): token_size = h * w x = x + self.dropout(self.attention(x, f, t, h, w)) y = self.norm(x) x = x + self.ffn(y, token_size, output_size[0], output_size[1]) return x class TransformerBlock(nn.Module): def __init__(self, token_size, frame_hidden, flow_hidden, num_heads, t_groupSize, s_windowSize, g_downSize, mlp_ratio, dropout, n_vecs, t2t_params): super(TransformerBlock, self).__init__() self.t_transformer = TemporalTransformer(token_size=token_size, frame_hidden=frame_hidden, num_heads=num_heads, t_groupSize=t_groupSize, mlp_ratio=mlp_ratio, dropout=dropout, n_vecs=n_vecs, t2t_params=t2t_params) # temporal multi-head self attention self.s_transformer = SpatialTransformer(token_size=token_size, frame_hidden=frame_hidden, flow_hidden=flow_hidden, num_heads=num_heads, s_windowSize=s_windowSize, g_downSize=g_downSize, mlp_ratio=mlp_ratio, dropout=dropout, n_vecs=n_vecs, t2t_params=t2t_params) def forward(self, inputs): x, f, t = inputs['x'], inputs['f'], inputs['t'] h, w = inputs['h'], inputs['w'] output_size = inputs['output_size'] x = self.t_transformer(x, t, h, w, output_size) x = self.s_transformer(x, f, t, h, w, output_size) return {'x': x, 'f': f, 't': t, 'h': h, 'w': w, 'output_size': output_size} class Decoder(BaseNetwork): def __init__(self, conv_type, in_channels, out_channels, use_bias, norm=None): super(Decoder, self).__init__(conv_type) self.layer1 = self.DeconvBlock(in_channels, in_channels, kernel_size=3, padding=1, norm=norm, bias=use_bias) self.layer2 = self.ConvBlock(in_channels, in_channels // 2, kernel_size=3, stride=1, padding=1, norm=norm, bias=use_bias) self.layer3 = self.DeconvBlock(in_channels // 2, in_channels // 2, kernel_size=3, padding=1, norm=norm, bias=use_bias) self.final = self.ConvBlock(in_channels // 2, out_channels, kernel_size=3, stride=1, padding=1, norm=norm, bias=use_bias, activation=None) def forward(self, features): feat1 = self.layer1(features) feat2 = self.layer2(feat1) feat3 = self.layer3(feat2) output = self.final(feat3) return output class FGT(BaseNetwork): def __init__(self, t_groupSize, s_windowSize, g_downSize, input_resolution, in_channels, cnum, flow_inChannel, flow_cnum, frame_hidden, flow_hidden, passmask, numBlocks, kernel_size, stride, padding, num_heads, conv_type, norm, use_bias, ape, mlp_ratio=4, drop=0, init_weights=True): super(FGT, self).__init__(conv_type) self.in_channels = in_channels self.passmask = passmask self.ape = ape self.frame_endoder = Encoder(in_channels) self.flow_encoder = nn.Sequential( nn.ReplicationPad2d(2), self.ConvBlock(flow_inChannel, flow_cnum, kernel_size=5, stride=1, padding=0, bias=use_bias, norm=norm), self.ConvBlock(flow_cnum, flow_cnum * 2, kernel_size=3, stride=2, padding=1, bias=use_bias, norm=norm), self.ConvBlock(flow_cnum * 2, flow_cnum * 2, kernel_size=3, stride=1, padding=1, bias=use_bias, norm=norm), self.ConvBlock(flow_cnum * 2, flow_cnum * 2, kernel_size=3, stride=2, padding=1, bias=use_bias, norm=norm) ) # patch to vector operation self.patch2vec = nn.Conv2d(cnum * 2, frame_hidden, kernel_size=kernel_size, stride=stride, padding=padding) self.f_patch2vec = nn.Conv2d(flow_cnum * 2, flow_hidden, kernel_size=kernel_size, stride=stride, padding=padding) # initialize transformer blocks for frame completion n_vecs = 1 token_size = [] output_shape = (input_resolution[0] // 4, input_resolution[1] // 4) for i, d in enumerate(kernel_size): token_nums = int((output_shape[i] + 2 * padding[i] - kernel_size[i]) / stride[i] + 1) n_vecs *= token_nums token_size.append(token_nums) # Add positional embedding to the encode features if self.ape: self.add_pos_emb = AddPosEmb(token_size[0], token_size[1], frame_hidden, frame_hidden) self.token_size = token_size # initialize transformer blocks blocks = [] t2t_params = {'kernel_size': kernel_size, 'stride': stride, 'padding': padding, 'output_size': output_shape} for i in range(numBlocks // 2 - 1): layer = TransformerBlock(token_size, frame_hidden, flow_hidden, num_heads, t_groupSize, s_windowSize, g_downSize, mlp_ratio, drop, n_vecs, t2t_params) blocks.append(layer) self.first_t_transformer = TemporalTransformer(token_size, frame_hidden, num_heads, t_groupSize, mlp_ratio, drop, n_vecs, t2t_params) self.first_s_transformer = SpatialTransformer(token_size, frame_hidden, flow_hidden, num_heads, s_windowSize, g_downSize, mlp_ratio, drop, n_vecs, t2t_params) self.transformer = nn.Sequential(*blocks) # vector to patch operation self.vec2patch = Vec2Patch(cnum * 2, frame_hidden, output_shape, kernel_size, stride, padding) # decoder self.decoder = Decoder(conv_type, cnum * 2, 3, use_bias, norm) if init_weights: self.init_weights() def forward(self, masked_frames, flows, masks): b, t, c, h, w = masked_frames.shape cf = flows.shape[2] output_shape = (h // 4, w // 4) if self.passmask: inputs = torch.cat((masked_frames, masks), dim=2) else: inputs = masked_frames inputs = inputs.view(b * t, self.in_channels, h, w) flows = flows.view(b * t, cf, h, w) enc_feats = self.frame_endoder(inputs) flow_feats = self.flow_encoder(flows) trans_feat = self.patch2vec(enc_feats) flow_patches = self.f_patch2vec(flow_feats) _, c, h, w = trans_feat.shape cf = flow_patches.shape[1] if h != self.token_size[0] or w != self.token_size[1]: new_h, new_w = h, w else: new_h, new_w = 0, 0 output_shape = (0, 0) trans_feat = trans_feat.view(b * t, c, -1).permute(0, 2, 1) flow_patches = flow_patches.view(b * t, cf, -1).permute(0, 2, 1) trans_feat = self.first_t_transformer(trans_feat, t, new_h, new_w, output_shape) trans_feat = self.add_pos_emb(trans_feat, new_h, new_w) trans_feat = self.first_s_transformer(trans_feat, flow_patches, t, new_h, new_w, output_shape) inputs_trans_feat = {'x': trans_feat, 'f': flow_patches, 't': t, 'h': new_h, 'w': new_w, 'output_size': output_shape} trans_feat = self.transformer(inputs_trans_feat)['x'] trans_feat = self.vec2patch(trans_feat, output_shape[0], output_shape[1]) enc_feats = enc_feats + trans_feat output = self.decoder(enc_feats) output = torch.tanh(output) return output