# Retrieved from watchtowerss's "track-anything" repository
# (commit 4d1ebf3, "--version 1"; file size 14.2 kB).
''' Towards An End-to-End Framework for Video Inpainting
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
from model.modules.flow_comp import SPyNet
from model.modules.feat_prop import BidirectionalPropagation, SecondOrderDeformableAlignment
from model.modules.tfocal_transformer import TemporalFocalTransformerBlock, SoftSplit, SoftComp
from model.modules.spectral_norm import spectral_norm as _spectral_norm
class BaseNetwork(nn.Module):
    """Shared base class for the generator and the discriminator.

    Adds two helpers to ``nn.Module``: a parameter-count report
    (``print_network``) and a configurable weight initialiser
    (``init_weights``).
    """

    def __init__(self):
        super(BaseNetwork, self).__init__()

    def print_network(self):
        """Print the network's class name and parameter count in millions."""
        # Also works when called unbound on a list of networks; only the
        # first entry is reported in that case.
        net = self[0] if isinstance(self, list) else self
        total = sum(p.numel() for p in net.parameters())
        print(f'Network [{type(net).__name__}] was created. '
              f'Total number of parameters: {total / 1000000:.1f} million. '
              'To see the architecture, do print(network).')

    def init_weights(self, init_type='normal', gain=0.02):
        """Initialise the weights of every sub-module, then recurse into children.

        Args:
            init_type: 'normal', 'xavier', 'xavier_uniform', 'kaiming',
                'orthogonal', or 'none' (call each layer's own
                ``reset_parameters``).
            gain: std for 'normal'; gain for 'xavier' and 'orthogonal'.
                Not used by 'xavier_uniform' (fixed gain 1.0) or 'kaiming'.

        Raises:
            NotImplementedError: if ``init_type`` is unknown and the network
                contains at least one Conv*/Linear* layer with a weight.

        Conv*/Linear* layers get their weight initialised and their bias
        zeroed. InstanceNorm2d layers get weight = 1 and bias = 0.

        Adapted from
        https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/9451e70673400885567d08a9e97ade2524c700d0/models/networks.py#L39
        """

        def _zero_bias(module):
            if hasattr(module, 'bias') and module.bias is not None:
                nn.init.constant_(module.bias.data, 0.0)

        def _init_weight(module):
            weight = module.weight.data
            if init_type == 'normal':
                nn.init.normal_(weight, 0.0, gain)
            elif init_type == 'xavier':
                nn.init.xavier_normal_(weight, gain=gain)
            elif init_type == 'xavier_uniform':
                # NOTE(review): ignores ``gain`` (always 1.0); kept as-is.
                nn.init.xavier_uniform_(weight, gain=1.0)
            elif init_type == 'kaiming':
                nn.init.kaiming_normal_(weight, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                nn.init.orthogonal_(weight, gain=gain)
            elif init_type == 'none':  # keep PyTorch's default init
                module.reset_parameters()
            else:
                raise NotImplementedError(
                    'initialization method [%s] is not implemented' %
                    init_type)

        def _init_module(module):
            name = module.__class__.__name__
            if 'InstanceNorm2d' in name:
                if hasattr(module, 'weight') and module.weight is not None:
                    nn.init.constant_(module.weight.data, 1.0)
                _zero_bias(module)
                return
            if not hasattr(module, 'weight'):
                return
            if 'Conv' not in name and 'Linear' not in name:
                return
            _init_weight(module)
            _zero_bias(module)

        self.apply(_init_module)
        # Give children that define their own init_weights a chance to
        # (re)initialise themselves with the same settings.
        for child in self.children():
            if hasattr(child, 'init_weights'):
                child.init_weights(init_type, gain)
class Encoder(nn.Module):
    """Convolutional frame encoder with grouped skip-concatenation.

    Input: a 4-D batch ``(N, 3, H, W)`` (``N`` is typically batch * frames).
    Output: ``(N, 128, ceil(H / 4), ceil(W / 4))``. Two stride-2 convs halve
    the resolution twice, and every later conv keeps the size.

    The activation entering layer index 8 (256 channels) is kept as ``skip``.
    Before each later conv (layer indices 10, 12, 14, 16), ``skip`` and the
    current activation are both split into ``self.group[k]`` channel groups
    and concatenated group-by-group. As a result, each group of the grouped
    conv sees matching slices of both tensors.
    """

    def __init__(self):
        super(Encoder, self).__init__()
        # Group counts for the convs at layer indices 10/12/14/16 (entries
        # 1..4). Entry 0 is never read by ``forward``.
        self.group = [1, 2, 4, 8, 1]
        # Layer order and attribute names must stay fixed: they determine the
        # state_dict keys (``layers.<idx>.*``) and the indices used in forward.
        self.layers = nn.ModuleList([
            nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 384, kernel_size=3, stride=1, padding=1, groups=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(640, 512, kernel_size=3, stride=1, padding=1, groups=2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(768, 384, kernel_size=3, stride=1, padding=1, groups=4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(640, 256, kernel_size=3, stride=1, padding=1, groups=8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(512, 128, kernel_size=3, stride=1, padding=1, groups=1),
            nn.LeakyReLU(0.2, inplace=True)
        ])

    def forward(self, x):
        """Encode ``x`` of shape ``(N, 3, H, W)``; see the class docstring."""
        bt, _, _, _ = x.size()  # also enforces a 4-D input
        out = x
        skip = None
        h = w = None
        for i, layer in enumerate(self.layers):
            if i == 8:
                skip = out
                # Read the spatial size from the real feature map instead of
                # using H // 4, W // 4. The stride-2 convs round up, so the two
                # only agree when H and W are multiples of 4.
                _, _, h, w = skip.size()
            elif i > 8 and i % 2 == 0:
                g = self.group[(i - 8) // 2]
                skip_groups = skip.view(bt, g, -1, h, w)
                out_groups = out.view(bt, g, -1, h, w)
                out = torch.cat([skip_groups, out_groups], 2).view(bt, -1, h, w)
            out = layer(out)
        return out
class deconv(nn.Module):
    """Upsampling block: 2x bilinear resize followed by a stride-1 conv.

    Args:
        input_channel: channels of the incoming feature map.
        output_channel: channels produced by the convolution.
        kernel_size: convolution kernel size (default 3).
        padding: convolution padding (default 0).
    """

    def __init__(self, input_channel, output_channel, kernel_size=3,
                 padding=0):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=input_channel,
            out_channels=output_channel,
            kernel_size=kernel_size,
            stride=1,
            padding=padding,
        )

    def forward(self, x):
        """Double the spatial size of ``x`` bilinearly, then convolve."""
        upsampled = F.interpolate(x, scale_factor=2, mode='bilinear',
                                  align_corners=True)
        return self.conv(upsampled)
class InpaintGenerator(BaseNetwork):
    """Video inpainting generator ("Towards An End-to-End Framework for
    Video Inpainting").

    The steps visible in ``forward`` are:
      1. SPyNet flow estimation between consecutive local frames, at 1/4
         resolution.
      2. A shared CNN ``Encoder`` over all frames.
      3. Flow-guided bidirectional feature propagation over the local frames.
      4. A stack of temporal focal transformer blocks over all frames,
         between soft split and soft composition.
      5. A CNN decoder with a tanh output.

    Args:
        init_weights: if true, run ``BaseNetwork.init_weights`` and
            ``init_offset`` on every ``SecondOrderDeformableAlignment``
            sub-module. Runs before ``update_spynet`` is created.
    """
    def __init__(self, init_weights=True):
        super(InpaintGenerator, self).__init__()
        channel = 256
        hidden = 512

        # encoder
        # Encoder output has 128 (= channel // 2) channels at 1/4 resolution.
        self.encoder = Encoder()

        # decoder
        # Two deconv stages (bilinear x2 + conv) undo the encoder's 4x
        # downsampling. The final conv maps back to 3 channels.
        self.decoder = nn.Sequential(
            deconv(channel // 2, 128, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            deconv(64, 64, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1))

        # feature propagation module
        self.feat_prop_module = BidirectionalPropagation(channel // 2)

        # soft split and soft composition
        kernel_size = (7, 7)
        padding = (3, 3)
        stride = (3, 3)
        # NOTE(review): fixed feature-map size for soft split/composition.
        # Presumably the encoder output for 240x432 inputs (the encoder
        # downsamples by 4); other input sizes likely need a different
        # value. TODO confirm against SoftSplit/SoftComp.
        output_size = (60, 108)
        t2t_params = {
            'kernel_size': kernel_size,
            'stride': stride,
            'padding': padding,
            'output_size': output_size
        }
        self.ss = SoftSplit(channel // 2,
                            hidden,
                            kernel_size,
                            stride,
                            padding,
                            t2t_param=t2t_params)
        self.sc = SoftComp(channel // 2, hidden, output_size, kernel_size,
                           stride, padding)

        # Number of soft-split patches per frame: the standard convolution
        # output-size formula applied to each spatial dimension, multiplied.
        n_vecs = 1
        for i, d in enumerate(kernel_size):
            n_vecs *= int((output_size[i] + 2 * padding[i] -
                           (d - 1) - 1) / stride[i] + 1)

        # Stack of identical temporal focal transformer blocks.
        blocks = []
        depths = 8
        num_heads = [4] * depths
        window_size = [(5, 9)] * depths
        focal_windows = [(5, 9)] * depths
        focal_levels = [2] * depths
        pool_method = "fc"

        for i in range(depths):
            blocks.append(
                TemporalFocalTransformerBlock(dim=hidden,
                                              num_heads=num_heads[i],
                                              window_size=window_size[i],
                                              focal_level=focal_levels[i],
                                              focal_window=focal_windows[i],
                                              n_vecs=n_vecs,
                                              t2t_params=t2t_params,
                                              pool_method=pool_method))
        self.transformer = nn.Sequential(*blocks)

        if init_weights:
            self.init_weights()
            # The generic init above does not set up the deformable-alignment
            # offsets, so each SecondOrderDeformableAlignment initialises
            # its own via init_offset().
            for m in self.modules():
                if isinstance(m, SecondOrderDeformableAlignment):
                    m.init_offset()

        # flow completion network
        # Created after init_weights() on purpose, so the generic init does
        # not touch it. Presumably SPyNet loads its own (pretrained) weights;
        # TODO confirm in model.modules.flow_comp.
        self.update_spynet = SPyNet()

    def forward_bidirect_flow(self, masked_local_frames):
        """Estimate flows between consecutive local frames at 1/4 resolution.

        Args:
            masked_local_frames: tensor of shape (b, l_t, c, h, w).

        Returns:
            (pred_flows_forward, pred_flows_backward), each of shape
            (b, l_t - 1, 2, h // 4, w // 4). "forward" is
            ``update_spynet(frame_i, frame_{i+1})`` and "backward" is
            ``update_spynet(frame_{i+1}, frame_i)``.
        """
        b, l_t, c, h, w = masked_local_frames.size()

        # compute forward and backward flows of masked frames
        # Downsample every frame by 4 before running SPyNet.
        masked_local_frames = F.interpolate(masked_local_frames.view(
            -1, c, h, w),
                                            scale_factor=1 / 4,
                                            mode='bilinear',
                                            align_corners=True,
                                            recompute_scale_factor=True)
        masked_local_frames = masked_local_frames.view(b, l_t, c, h // 4,
                                                       w // 4)
        # Pairs (frame i, frame i+1), flattened into the batch dimension.
        mlf_1 = masked_local_frames[:, :-1, :, :, :].reshape(
            -1, c, h // 4, w // 4)
        mlf_2 = masked_local_frames[:, 1:, :, :, :].reshape(
            -1, c, h // 4, w // 4)
        pred_flows_forward = self.update_spynet(mlf_1, mlf_2)
        pred_flows_backward = self.update_spynet(mlf_2, mlf_1)

        pred_flows_forward = pred_flows_forward.view(b, l_t - 1, 2, h // 4,
                                                     w // 4)
        pred_flows_backward = pred_flows_backward.view(b, l_t - 1, 2, h // 4,
                                                       w // 4)

        return pred_flows_forward, pred_flows_backward

    def forward(self, masked_frames, num_local_frames):
        """Inpaint a clip.

        Args:
            masked_frames: tensor of shape (b, t, c, h, w). The first
                ``num_local_frames`` frames are "local" frames; they get flow
                estimation and feature propagation. The rest are reference
                frames that only enter the transformer. Presumably values in
                [-1, 1], given the (x + 1) / 2 step below; TODO confirm.
            num_local_frames: number of leading local frames (l_t).

        Returns:
            (output, pred_flows). ``output`` has shape (b * t, 3, H', W')
            with values in (-1, 1) from tanh. ``pred_flows`` is the pair
            returned by ``forward_bidirect_flow``.
        """
        l_t = num_local_frames
        b, t, ori_c, ori_h, ori_w = masked_frames.size()

        # normalization before feeding into the flow completion module
        # Maps [-1, 1] to [0, 1] (assumes [-1, 1] inputs; see docstring).
        masked_local_frames = (masked_frames[:, :l_t, ...] + 1) / 2
        pred_flows = self.forward_bidirect_flow(masked_local_frames)

        # extracting features and performing the feature propagation on local features
        enc_feat = self.encoder(masked_frames.view(b * t, ori_c, ori_h, ori_w))
        _, c, h, w = enc_feat.size()
        local_feat = enc_feat.view(b, t, c, h, w)[:, :l_t, ...]
        ref_feat = enc_feat.view(b, t, c, h, w)[:, l_t:, ...]
        local_feat = self.feat_prop_module(local_feat, pred_flows[0],
                                           pred_flows[1])
        # Re-join propagated local features with the untouched reference
        # features along the time axis: (b, t, c, h, w).
        enc_feat = torch.cat((local_feat, ref_feat), dim=1)

        # content hallucination through stacking multiple temporal focal transformer blocks
        trans_feat = self.ss(enc_feat.view(-1, c, h, w), b)
        trans_feat = self.transformer(trans_feat)
        trans_feat = self.sc(trans_feat, t)
        trans_feat = trans_feat.view(b, t, -1, h, w)
        # Residual connection around the transformer stack.
        enc_feat = enc_feat + trans_feat

        # decode frames from features
        output = self.decoder(enc_feat.view(b * t, c, h, w))
        output = torch.tanh(output)
        return output, pred_flows
# ######################################################################
# Discriminator for Temporal Patch GAN
# ######################################################################
class Discriminator(BaseNetwork):
    """Temporal PatchGAN discriminator built from strided 3-D convolutions.

    Six Conv3d layers with kernel (3, 5, 5) and stride (1, 2, 2). The first
    five are optionally spectrally normalised and followed by LeakyReLU(0.2).
    The time length is preserved (temporal kernel 3, stride 1, padding 1),
    while each layer roughly halves H and W.

    Args:
        in_channels: channels of each input frame.
        use_sigmoid: apply a sigmoid to the output scores.
        use_spectral_norm: wrap the first five convs with spectral norm. When
            enabled, those convs are built without a bias.
        init_weights: run ``BaseNetwork.init_weights`` after construction.
    """

    def __init__(self,
                 in_channels=3,
                 use_sigmoid=False,
                 use_spectral_norm=True,
                 init_weights=True):
        super(Discriminator, self).__init__()
        self.use_sigmoid = use_sigmoid
        nf = 32
        kernel = (3, 5, 5)
        strides = (1, 2, 2)

        # (in_channels, out_channels, padding) for each normalised conv.
        # The first layer's scalar padding of 1 matches the original design.
        normed_specs = [
            (in_channels, nf * 1, 1),
            (nf * 1, nf * 2, (1, 2, 2)),
            (nf * 2, nf * 4, (1, 2, 2)),
            (nf * 4, nf * 4, (1, 2, 2)),
            (nf * 4, nf * 4, (1, 2, 2)),
        ]
        # Modules are created in the same order as listed, so Sequential
        # indices (state_dict keys) and RNG consumption stay stable.
        layers = []
        for c_in, c_out, pad in normed_specs:
            conv3d = nn.Conv3d(c_in,
                               c_out,
                               kernel_size=kernel,
                               stride=strides,
                               padding=pad,
                               bias=not use_spectral_norm)
            layers.append(spectral_norm(conv3d, use_spectral_norm))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        # Final scoring conv: never spectrally normalised, default bias.
        layers.append(
            nn.Conv3d(nf * 4,
                      nf * 4,
                      kernel_size=kernel,
                      stride=strides,
                      padding=(1, 2, 2)))
        self.conv = nn.Sequential(*layers)

        if init_weights:
            self.init_weights()

    def forward(self, xs):
        """Score ``xs`` of shape (B, T, C, H, W); returns (B, T, C', H', W')."""
        # Conv3d expects channels before time, so swap dims 1 and 2 around it.
        scores = self.conv(xs.transpose(1, 2))
        if self.use_sigmoid:
            scores = scores.sigmoid()
        return scores.transpose(1, 2)
def spectral_norm(module, mode=True):
    """Optionally wrap ``module`` with spectral normalisation.

    Returns ``_spectral_norm(module)`` when ``mode`` is truthy. Otherwise
    returns ``module`` itself, unchanged.
    """
    return _spectral_norm(module) if mode else module