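# Extraction residue from the hosting site's file viewer (kept for provenance):
#   uploader: oguzakif
#   commit message: init repo
#   commit: d4b77ac
#   viewer links: raw / history / blame
#   scan status: No virus
#   file size: 14.1 kB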
from models.BaseNetwork import BaseNetwork
from models.transformer_base.ffn_base import FusionFeedForward
from models.transformer_base.attention_flow import SWMHSA_depthGlobalWindowConcatLN_qkFlow_reweightFlow
from models.transformer_base.attention_base import TMHSA
import torch
import torch.nn as nn
from functools import reduce
import torch.nn.functional as F
class Model(nn.Module):
    """Config-driven wrapper that builds an ``FGT`` network and delegates to it.

    Args:
        config: Mapping providing every key listed in ``_FGT_CONFIG_KEYS``.
            The values are forwarded positionally to ``FGT`` in that order.
    """

    # Config keys, in the positional order of ``FGT.__init__``'s parameters.
    _FGT_CONFIG_KEYS = (
        'tw', 'sw', 'gd', 'input_resolution', 'in_channel',
        'cnum', 'flow_inChannel', 'flow_cnum', 'frame_hidden',
        'flow_hidden', 'PASSMASK',
        'numBlocks', 'kernel_size', 'stride', 'padding',
        'num_head', 'conv_type', 'norm',
        'use_bias', 'ape',
        'mlp_ratio', 'drop', 'init_weights',
    )

    def __init__(self, config):
        super().__init__()
        fgt_args = [config[key] for key in self._FGT_CONFIG_KEYS]
        self.net = FGT(*fgt_args)

    def forward(self, frames, flows, masks):
        """Run the wrapped network on ``frames``, ``flows`` and ``masks`` and return its output."""
        return self.net(frames, flows, masks)
class Encoder(nn.Module):
    """Convolutional frame encoder with grouped skip-concatenation in its later layers.

    The first four conv/activation pairs contain two stride-2 convolutions. From
    layer index 10 onward, the input of layer 8 is re-concatenated (group-wise)
    with the running feature map before each convolution, which is why those
    convolutions take 640/768/640/512 input channels.

    Args:
        in_channels: Number of channels of the input images.
    """

    def __init__(self, in_channels):
        super(Encoder, self).__init__()
        # Group count used for the concat before conv layers 8, 10, 12, 14, 16
        # (entry 0 is unused by ``forward``; it matches ``groups=1`` of layer 8).
        self.group = [1, 2, 4, 8, 1]
        self.layers = nn.ModuleList([
            nn.Conv2d(in_channels, 64, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 384, kernel_size=3, stride=1, padding=1, groups=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(640, 512, kernel_size=3, stride=1, padding=1, groups=2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(768, 384, kernel_size=3, stride=1, padding=1, groups=4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(640, 256, kernel_size=3, stride=1, padding=1, groups=8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(512, 128, kernel_size=3, stride=1, padding=1, groups=1),
            nn.LeakyReLU(0.2, inplace=True)
        ])

    def forward(self, x):
        """Encode a 4-D batch ``(bt, c, h, w)`` into a 128-channel feature map.

        The spatial size used for the grouped reshapes is read from the actual
        skip tensor instead of being assumed to be ``h // 4, w // 4``: the two
        stride-2 convolutions round up, so the old assumption broke for inputs
        whose height or width is not a multiple of 4.
        """
        bt, _, _, _ = x.size()
        out = x
        skip = None
        feat_h = feat_w = None
        for i, layer in enumerate(self.layers):
            if i == 8:
                # Remember the 256-channel feature map that feeds layer 8; it is
                # concatenated back in before every later convolution.
                skip = out
                feat_h, feat_w = skip.shape[-2:]
            if i > 8 and i % 2 == 0:
                g = self.group[(i - 8) // 2]
                # Interleave skip and current features per group so that each
                # conv group sees its own slice of both tensors.
                skip_grouped = skip.view(bt, g, -1, feat_h, feat_w)
                out_grouped = out.view(bt, g, -1, feat_h, feat_w)
                out = torch.cat([skip_grouped, out_grouped], 2).view(bt, -1, feat_h, feat_w)
            out = layer(out)
        return out
class AddPosEmb(nn.Module):
    """Add a convolutional positional term to a token sequence.

    Tokens are reshaped to a 2-D grid, passed through a 3x3 convolution with
    ``groups=out_channels``, added back to the grid (residual), and flattened
    to tokens again.

    Args:
        h, w: Default token-grid size, used when ``forward`` gets ``h == w == 0``.
        in_channels, out_channels: Channel counts of the 3x3 convolution.
    """

    def __init__(self, h, w, in_channels, out_channels):
        super().__init__()
        self.proj = nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=True, groups=out_channels)
        self.h, self.w = h, w

    def forward(self, x, h=0, w=0):
        """Return ``x`` (``(B, N, C)``) with the convolutional term added.

        ``h``/``w`` override the grid size given at construction; passing both
        as 0 selects the constructor values. ``N`` must equal the grid area.
        """
        batch, num_tokens, channels = x.shape
        if h == 0 and w == 0:
            grid_h, grid_w = self.h, self.w
            assert num_tokens == grid_h * grid_w, 'Wrong input size'
        else:
            grid_h, grid_w = h, w
            assert num_tokens == grid_h * grid_w, 'Wrong input size during inference'
        # (B, N, C) -> (B, C, H, W) so the convolution sees a spatial grid.
        grid = x.transpose(1, 2).view(batch, channels, grid_h, grid_w)
        embedded = self.proj(grid) + grid
        # Back to (B, N, C).
        return embedded.flatten(2).transpose(1, 2)
class Vec2Patch(nn.Module):
    """Project token vectors to flattened patches and fold them into a feature map.

    Args:
        channel: Channel count of the folded output.
        hidden: Size of each input token vector.
        output_size: Default spatial size passed to ``nn.Fold``.
        kernel_size, stride, padding: Fold parameters (``kernel_size`` must be
            iterable, since its product sizes the linear projection).
    """

    def __init__(self, channel, hidden, output_size, kernel_size, stride, padding):
        super().__init__()
        self.relu = nn.LeakyReLU(0.2, inplace=True)
        # Each token expands to one full patch: prod(kernel_size) values per channel.
        patch_area = reduce(lambda a, b: a * b, kernel_size)
        self.embedding = nn.Linear(hidden, patch_area * channel)
        self.restore = nn.Fold(output_size=output_size, kernel_size=kernel_size, stride=stride, padding=padding)
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding

    def forward(self, x, output_h=0, output_w=0):
        """Fold tokens ``x`` into a map; a non-zero ``output_h``/``output_w`` overrides the default size."""
        # (B, L, hidden) -> (B, L, C*k) -> (B, C*k, L), the layout fold expects.
        patches = self.embedding(x).permute(0, 2, 1)
        if output_h == 0 and output_w == 0:
            return self.restore(patches)
        return F.fold(patches, output_size=(output_h, output_w), kernel_size=self.kernel_size,
                      stride=self.stride, padding=self.padding)
class TemporalTransformer(nn.Module):
    """Pre-norm transformer layer: ``TMHSA`` attention followed by a fusion feed-forward.

    Both sub-layers are residual; a LayerNorm is applied to the input of each.
    """

    def __init__(self, token_size, frame_hidden, num_heads, t_groupSize, mlp_ratio, dropout, n_vecs,
                 t2t_params):
        super().__init__()
        self.attention = TMHSA(token_size=token_size, group_size=t_groupSize, d_model=frame_hidden,
                               head=num_heads, p=dropout)
        self.ffn = FusionFeedForward(frame_hidden, mlp_ratio, n_vecs, t2t_params, p=dropout)
        self.norm1 = nn.LayerNorm(frame_hidden)
        self.norm2 = nn.LayerNorm(frame_hidden)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, t, h, w, output_size):
        """Apply attention then feed-forward to ``x``.

        ``t``, ``h`` and ``w`` are forwarded to the attention module; ``h * w``
        and the two entries of ``output_size`` are forwarded to the feed-forward.
        """
        tokens_per_frame = h * w
        # Attention sub-layer (normalise -> attend -> dropout -> residual).
        attended = self.attention(self.norm1(x), t, h, w)
        x = x + self.dropout(attended)
        # Feed-forward sub-layer (normalise -> ffn -> residual).
        fused = self.ffn(self.norm2(x), tokens_per_frame, output_size[0], output_size[1])
        return x + fused
class SpatialTransformer(nn.Module):
    """Transformer layer using flow-aware window attention followed by a fusion feed-forward.

    Unlike the temporal layer in this file, the attention input is NOT
    normalised here; only the feed-forward input passes through a LayerNorm.
    """

    def __init__(self, token_size, frame_hidden, flow_hidden, num_heads, s_windowSize, g_downSize, mlp_ratio,
                 dropout, n_vecs, t2t_params):
        super().__init__()
        self.attention = SWMHSA_depthGlobalWindowConcatLN_qkFlow_reweightFlow(
            token_size=token_size, window_size=s_windowSize,
            kernel_size=g_downSize, d_model=frame_hidden,
            flow_dModel=flow_hidden, head=num_heads, p=dropout)
        self.ffn = FusionFeedForward(frame_hidden, mlp_ratio, n_vecs, t2t_params, p=dropout)
        self.norm = nn.LayerNorm(frame_hidden)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, f, t, h, w, output_size):
        """Apply attention (conditioned on flow tokens ``f``) then feed-forward to ``x``."""
        tokens_per_frame = h * w
        # Attention sub-layer with residual connection.
        attended = self.attention(x, f, t, h, w)
        x = x + self.dropout(attended)
        # Feed-forward sub-layer with residual connection.
        fused = self.ffn(self.norm(x), tokens_per_frame, output_size[0], output_size[1])
        return x + fused
class TransformerBlock(nn.Module):
    """One temporal transformer layer followed by one spatial transformer layer.

    ``forward`` takes and returns a dict so that several blocks can be chained
    inside ``nn.Sequential``, which passes a single object between modules.
    """

    def __init__(self, token_size, frame_hidden, flow_hidden, num_heads, t_groupSize, s_windowSize, g_downSize,
                 mlp_ratio,
                 dropout, n_vecs,
                 t2t_params):
        super().__init__()
        # Temporal multi-head self attention.
        self.t_transformer = TemporalTransformer(
            token_size=token_size, frame_hidden=frame_hidden, num_heads=num_heads,
            t_groupSize=t_groupSize, mlp_ratio=mlp_ratio,
            dropout=dropout, n_vecs=n_vecs,
            t2t_params=t2t_params)
        # Spatial (flow-guided) self attention.
        self.s_transformer = SpatialTransformer(
            token_size=token_size, frame_hidden=frame_hidden,
            flow_hidden=flow_hidden, num_heads=num_heads, s_windowSize=s_windowSize,
            g_downSize=g_downSize, mlp_ratio=mlp_ratio,
            dropout=dropout, n_vecs=n_vecs, t2t_params=t2t_params)

    def forward(self, inputs):
        """Update ``inputs['x']`` and return a new dict with the same six keys.

        Reads the keys 'x', 'f', 't', 'h', 'w' and 'output_size'; every entry
        except 'x' is passed through unchanged.
        """
        tokens = inputs['x']
        flow_tokens = inputs['f']
        num_frames = inputs['t']
        grid_h = inputs['h']
        grid_w = inputs['w']
        out_size = inputs['output_size']

        tokens = self.t_transformer(tokens, num_frames, grid_h, grid_w, out_size)
        tokens = self.s_transformer(tokens, flow_tokens, num_frames, grid_h, grid_w, out_size)

        return {
            'x': tokens,
            'f': flow_tokens,
            't': num_frames,
            'h': grid_h,
            'w': grid_w,
            'output_size': out_size,
        }
class Decoder(BaseNetwork):
    """Four-stage decoder: deconv, conv (halves channels), deconv, final conv.

    ``DeconvBlock`` and ``ConvBlock`` come from ``BaseNetwork``; their exact
    behaviour (e.g. upsampling factor) is defined there, not in this class.

    Args:
        conv_type: Forwarded to ``BaseNetwork.__init__``.
        in_channels: Channel count of the incoming feature map.
        out_channels: Channel count produced by the final block.
        use_bias: Passed as ``bias`` to every block.
        norm: Passed as ``norm`` to every block.
    """

    def __init__(self, conv_type, in_channels, out_channels, use_bias, norm=None):
        super().__init__(conv_type)
        mid_channels = in_channels // 2
        self.layer1 = self.DeconvBlock(in_channels, in_channels, kernel_size=3, padding=1,
                                       norm=norm, bias=use_bias)
        self.layer2 = self.ConvBlock(in_channels, mid_channels, kernel_size=3, stride=1, padding=1,
                                     norm=norm, bias=use_bias)
        self.layer3 = self.DeconvBlock(mid_channels, mid_channels, kernel_size=3, padding=1,
                                       norm=norm, bias=use_bias)
        # The last block is built with activation=None.
        self.final = self.ConvBlock(mid_channels, out_channels, kernel_size=3, stride=1, padding=1,
                                    norm=norm, bias=use_bias, activation=None)

    def forward(self, features):
        """Run ``features`` through layer1 -> layer2 -> layer3 -> final and return the result."""
        decoded = features
        for stage in (self.layer1, self.layer2, self.layer3, self.final):
            decoded = stage(decoded)
        return decoded
class FGT(BaseNetwork):
    """Frame-completion network: conv encoders, temporal/spatial transformers, conv decoder.

    Pipeline (see ``forward``): frames (optionally concatenated with masks) go
    through ``Encoder``; flows go through a small conv stack; both are split
    into patch tokens; the frame tokens pass through a first temporal layer,
    an optional positional embedding, a first spatial layer and then
    ``numBlocks // 2 - 1`` ``TransformerBlock``s; the tokens are folded back
    into a feature map, added to the encoder features and decoded.

    Args:
        t_groupSize: ``t_groupSize`` of every temporal layer.
        s_windowSize, g_downSize: Window / down-sampling settings of every spatial layer.
        input_resolution: ``(H, W)`` the model is configured for; the default
            token grid is derived from ``H // 4, W // 4``.
        in_channels: Channel count fed to the frame encoder (after the optional
            mask concatenation).
        cnum: ``cnum * 2`` is the channel count expected from the frame encoder
            and used by the decoder.
        flow_inChannel, flow_cnum: Input / base channel counts of the flow encoder.
        frame_hidden, flow_hidden: Token sizes of frame and flow tokens.
        passmask: If truthy, masks are concatenated to the frames on dim 2.
        numBlocks: Controls the number of stacked ``TransformerBlock``s.
        kernel_size, stride, padding: Per-dimension patch extraction settings.
        num_heads: Attention head count.
        conv_type: Forwarded to ``BaseNetwork.__init__`` and ``Decoder``.
        norm, use_bias: Forwarded to the conv blocks.
        ape: If truthy, a positional embedding module is created and applied.
        mlp_ratio, drop: Feed-forward ratio and dropout probability.
        init_weights: If truthy, ``self.init_weights()`` is called after construction.
    """

    def __init__(self, t_groupSize, s_windowSize, g_downSize, input_resolution, in_channels, cnum, flow_inChannel,
                 flow_cnum,
                 frame_hidden, flow_hidden, passmask, numBlocks, kernel_size, stride, padding, num_heads, conv_type,
                 norm, use_bias, ape, mlp_ratio=4, drop=0, init_weights=True):
        super(FGT, self).__init__(conv_type)
        self.in_channels = in_channels
        self.passmask = passmask
        self.ape = ape
        # NOTE: the attribute name is misspelled ("endoder") but is kept as-is,
        # because it is the name under which the submodule is registered.
        self.frame_endoder = Encoder(in_channels)
        self.flow_encoder = nn.Sequential(
            nn.ReplicationPad2d(2),
            self.ConvBlock(flow_inChannel, flow_cnum, kernel_size=5, stride=1, padding=0, bias=use_bias, norm=norm),
            self.ConvBlock(flow_cnum, flow_cnum * 2, kernel_size=3, stride=2, padding=1, bias=use_bias, norm=norm),
            self.ConvBlock(flow_cnum * 2, flow_cnum * 2, kernel_size=3, stride=1, padding=1, bias=use_bias,
                           norm=norm),
            self.ConvBlock(flow_cnum * 2, flow_cnum * 2, kernel_size=3, stride=2, padding=1, bias=use_bias,
                           norm=norm)
        )
        # patch to vector operation
        self.patch2vec = nn.Conv2d(cnum * 2, frame_hidden, kernel_size=kernel_size, stride=stride, padding=padding)
        self.f_patch2vec = nn.Conv2d(flow_cnum * 2, flow_hidden, kernel_size=kernel_size, stride=stride,
                                     padding=padding)
        # Number of tokens per spatial dimension (and in total) at the
        # configured resolution, using the usual strided-window count.
        n_vecs = 1
        token_size = []
        output_shape = (input_resolution[0] // 4, input_resolution[1] // 4)
        for i, _ in enumerate(kernel_size):
            token_nums = int((output_shape[i] + 2 * padding[i] - kernel_size[i]) / stride[i] + 1)
            n_vecs *= token_nums
            token_size.append(token_nums)
        # Add positional embedding to the encode features
        if self.ape:
            self.add_pos_emb = AddPosEmb(token_size[0], token_size[1], frame_hidden, frame_hidden)
        self.token_size = token_size
        # initialize transformer blocks
        blocks = []
        t2t_params = {'kernel_size': kernel_size, 'stride': stride, 'padding': padding, 'output_size': output_shape}
        for _ in range(numBlocks // 2 - 1):
            layer = TransformerBlock(token_size, frame_hidden, flow_hidden, num_heads, t_groupSize, s_windowSize,
                                     g_downSize, mlp_ratio, drop, n_vecs, t2t_params)
            blocks.append(layer)
        self.first_t_transformer = TemporalTransformer(token_size, frame_hidden, num_heads, t_groupSize, mlp_ratio,
                                                       drop, n_vecs, t2t_params)
        self.first_s_transformer = SpatialTransformer(token_size, frame_hidden, flow_hidden, num_heads, s_windowSize,
                                                      g_downSize, mlp_ratio, drop, n_vecs, t2t_params)
        self.transformer = nn.Sequential(*blocks)
        # vector to patch operation
        self.vec2patch = Vec2Patch(cnum * 2, frame_hidden, output_shape, kernel_size, stride, padding)
        # decoder
        self.decoder = Decoder(conv_type, cnum * 2, 3, use_bias, norm)
        if init_weights:
            self.init_weights()

    def forward(self, masked_frames, flows, masks):
        """Complete the masked frames.

        Args:
            masked_frames: 5-D tensor ``(b, t, c, h, w)``.
            flows: 5-D tensor whose dim 2 is the flow channel count; reshaped to
                ``(b * t, cf, h, w)``.
            masks: Tensor concatenated to ``masked_frames`` on dim 2 when
                ``self.passmask`` is truthy; otherwise unused.

        Returns:
            ``tanh`` of the decoder output for the ``b * t`` flattened frames.
        """
        b, t, c, h, w = masked_frames.shape
        cf = flows.shape[2]
        output_shape = (h // 4, w // 4)
        if self.passmask:
            inputs = torch.cat((masked_frames, masks), dim=2)
        else:
            inputs = masked_frames
        # Merge batch and time so the 2-D encoders see one image per frame.
        inputs = inputs.view(b * t, self.in_channels, h, w)
        flows = flows.view(b * t, cf, h, w)
        enc_feats = self.frame_endoder(inputs)
        flow_feats = self.flow_encoder(flows)
        trans_feat = self.patch2vec(enc_feats)
        flow_patches = self.f_patch2vec(flow_feats)
        _, c, h, w = trans_feat.shape
        cf = flow_patches.shape[1]
        # (0, 0) is the sentinel meaning "use the sizes configured at
        # construction"; explicit sizes are only passed when the token grid
        # differs from the configured one.
        if h != self.token_size[0] or w != self.token_size[1]:
            new_h, new_w = h, w
        else:
            new_h, new_w = 0, 0
            output_shape = (0, 0)
        # (b*t, c, h, w) -> (b*t, h*w, c) token layout.
        trans_feat = trans_feat.view(b * t, c, -1).permute(0, 2, 1)
        flow_patches = flow_patches.view(b * t, cf, -1).permute(0, 2, 1)
        trans_feat = self.first_t_transformer(trans_feat, t, new_h, new_w, output_shape)
        # The embedding module only exists when ``ape`` is enabled; calling it
        # unconditionally raised AttributeError for ``ape=False``.
        if self.ape:
            trans_feat = self.add_pos_emb(trans_feat, new_h, new_w)
        trans_feat = self.first_s_transformer(trans_feat, flow_patches, t, new_h, new_w, output_shape)
        inputs_trans_feat = {'x': trans_feat, 'f': flow_patches, 't': t, 'h': new_h, 'w': new_w,
                             'output_size': output_shape}
        trans_feat = self.transformer(inputs_trans_feat)['x']
        trans_feat = self.vec2patch(trans_feat, output_shape[0], output_shape[1])
        # Residual connection around the whole transformer stack.
        enc_feats = enc_feats + trans_feat
        output = self.decoder(enc_feats)
        output = torch.tanh(output)
        return output