# photo2video/modules/model.py
from torch import nn
import torch
import torch.nn.functional as F
from modules.util import AntiAliasInterpolation2d, TPS
from torchvision import models
import numpy as np


class Vgg19(torch.nn.Module):
"""
Vgg19 network for perceptual loss. See Sec 3.3.
"""
def __init__(self, requires_grad=False):
super(Vgg19, self).__init__()
vgg_pretrained_features = models.vgg19(pretrained=True).features
self.slice1 = torch.nn.Sequential()
self.slice2 = torch.nn.Sequential()
self.slice3 = torch.nn.Sequential()
self.slice4 = torch.nn.Sequential()
self.slice5 = torch.nn.Sequential()
for x in range(2):
self.slice1.add_module(str(x), vgg_pretrained_features[x])
for x in range(2, 7):
self.slice2.add_module(str(x), vgg_pretrained_features[x])
for x in range(7, 12):
self.slice3.add_module(str(x), vgg_pretrained_features[x])
for x in range(12, 21):
self.slice4.add_module(str(x), vgg_pretrained_features[x])
for x in range(21, 30):
self.slice5.add_module(str(x), vgg_pretrained_features[x])
self.mean = torch.nn.Parameter(data=torch.Tensor(np.array([0.485, 0.456, 0.406]).reshape((1, 3, 1, 1))),
requires_grad=False)
self.std = torch.nn.Parameter(data=torch.Tensor(np.array([0.229, 0.224, 0.225]).reshape((1, 3, 1, 1))),
requires_grad=False)
if not requires_grad:
for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
X = (X - self.mean) / self.std
h_relu1 = self.slice1(X)
h_relu2 = self.slice2(h_relu1)
h_relu3 = self.slice3(h_relu2)
h_relu4 = self.slice4(h_relu3)
h_relu5 = self.slice5(h_relu4)
out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
return out
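

# A minimal usage sketch for Vgg19 (illustrative only: the uniform weights below
# are placeholders, not values fixed by this file -- GeneratorFullModel reads its
# real weights from train_params['loss_weights']['perceptual']):
def _example_vgg_perceptual_loss(vgg, prediction, target, weights=(1, 1, 1, 1, 1)):
    """Illustrative only: weighted L1 distance between the five VGG feature maps."""
    pred_feats = vgg(prediction)
    target_feats = vgg(target)
    loss = 0
    for weight, p, t in zip(weights, pred_feats, target_feats):
        # detach the target features so gradients flow only through `prediction`
        loss = loss + weight * torch.abs(p - t.detach()).mean()
    return loss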

class ImagePyramide(torch.nn.Module):
    """
    Create an image pyramid for computing the pyramid perceptual loss. See Sec 3.3.
    """
def __init__(self, scales, num_channels):
super(ImagePyramide, self).__init__()
downs = {}
for scale in scales:
downs[str(scale).replace('.', '-')] = AntiAliasInterpolation2d(num_channels, scale)
        self.downs = nn.ModuleDict(downs)

    def forward(self, x):
out_dict = {}
for scale, down_module in self.downs.items():
out_dict['prediction_' + str(scale).replace('-', '.')] = down_module(x)
return out_dict
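

# Hedged usage sketch for ImagePyramide (batch shape is an illustrative
# assumption): for scales [1, 0.5, 0.25] the forward pass returns a dict keyed
# 'prediction_1', 'prediction_0.5', 'prediction_0.25', each holding the
# anti-aliased downsample of the input batch.
def _example_image_pyramid():
    """Illustrative only: build a three-level pyramid over a random RGB batch."""
    pyramid = ImagePyramide(scales=[1, 0.5, 0.25], num_channels=3)
    frames = torch.rand(2, 3, 256, 256)
    out = pyramid(frames)
    # e.g. out['prediction_0.25'] has spatial size 64x64 for 256x256 inputs
    return out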

def detach_kp(kp):
    """Return a copy of the keypoint dict with every tensor detached from the graph."""
    return {key: value.detach() for key, value in kp.items()}

class GeneratorFullModel(torch.nn.Module):
    """
    Merge all generator-related updates into a single model for better multi-GPU usage.
    """

    def __init__(self, kp_extractor, bg_predictor, dense_motion_network, inpainting_network, train_params, **kwargs):
super(GeneratorFullModel, self).__init__()
self.kp_extractor = kp_extractor
self.inpainting_network = inpainting_network
self.dense_motion_network = dense_motion_network
self.bg_predictor = None
if bg_predictor:
self.bg_predictor = bg_predictor
self.bg_start = train_params['bg_start']
self.train_params = train_params
self.scales = train_params['scales']
self.pyramid = ImagePyramide(self.scales, inpainting_network.num_channels)
if torch.cuda.is_available():
self.pyramid = self.pyramid.cuda()
self.loss_weights = train_params['loss_weights']
self.dropout_epoch = train_params['dropout_epoch']
self.dropout_maxp = train_params['dropout_maxp']
self.dropout_inc_epoch = train_params['dropout_inc_epoch']
        self.dropout_startp = train_params['dropout_startp']
if sum(self.loss_weights['perceptual']) != 0:
self.vgg = Vgg19()
if torch.cuda.is_available():
self.vgg = self.vgg.cuda()
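
    # For orientation, __init__ expects train_params shaped roughly as sketched
    # below. The key names are the ones read above; the values are illustrative
    # assumptions, not the repo's actual config:
    #
    #   train_params = {
    #       'scales': [1, 0.5, 0.25, 0.125],
    #       'loss_weights': {'perceptual': [10, 10, 10, 10, 10],
    #                        'equivariance_value': 10, 'warp_loss': 10, 'bg': 10},
    #       'transform_params': {...},          # consumed by TPS in forward()
    #       'bg_start': 0,
    #       'dropout_epoch': 35, 'dropout_maxp': 0.3,
    #       'dropout_startp': 0.0, 'dropout_inc_epoch': 10,
    #   }
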
def forward(self, x, epoch):
kp_source = self.kp_extractor(x['source'])
kp_driving = self.kp_extractor(x['driving'])
bg_param = None
if self.bg_predictor:
            if epoch >= self.bg_start:
                bg_param = self.bg_predictor(x['source'], x['driving'])
        if epoch >= self.dropout_epoch:
            dropout_flag = False
            dropout_p = 0
        else:
            # dropout_p rises linearly from dropout_startp and is clamped at dropout_maxp
            dropout_flag = True
            dropout_p = min(epoch / self.dropout_inc_epoch * self.dropout_maxp + self.dropout_startp,
                            self.dropout_maxp)
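        # Worked example of the schedule (illustrative numbers): with
        # dropout_startp=0.0, dropout_maxp=0.3 and dropout_inc_epoch=10,
        # epoch 5 gives min(5/10 * 0.3 + 0.0, 0.3) = 0.15, while any
        # epoch >= 10 is clamped at 0.3.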
        dense_motion = self.dense_motion_network(source_image=x['source'], kp_driving=kp_driving,
                                                 kp_source=kp_source, bg_param=bg_param,
                                                 dropout_flag=dropout_flag, dropout_p=dropout_p)
generated = self.inpainting_network(x['source'], dense_motion)
generated.update({'kp_source': kp_source, 'kp_driving': kp_driving})
loss_values = {}
pyramide_real = self.pyramid(x['driving'])
pyramide_generated = self.pyramid(generated['prediction'])
        # reconstruction loss: multi-scale VGG perceptual loss over the image pyramid
if sum(self.loss_weights['perceptual']) != 0:
value_total = 0
for scale in self.scales:
x_vgg = self.vgg(pyramide_generated['prediction_' + str(scale)])
y_vgg = self.vgg(pyramide_real['prediction_' + str(scale)])
                for i, weight in enumerate(self.loss_weights['perceptual']):
                    value = torch.abs(x_vgg[i] - y_vgg[i].detach()).mean()
                    value_total += weight * value
loss_values['perceptual'] = value_total
# equivariance loss
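        # A known random TPS warp is applied to the driving frame; keypoints
        # detected on the warped frame and then mapped through that warp should
        # land on the keypoints detected on the original frame. The L1 residual
        # below penalizes violations of this equivariance.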
if self.loss_weights['equivariance_value'] != 0:
            transform_random = TPS(mode='random', bs=x['driving'].shape[0], **self.train_params['transform_params'])
            transform_grid = transform_random.transform_frame(x['driving'])
            transformed_frame = F.grid_sample(x['driving'], transform_grid, padding_mode="reflection", align_corners=True)
transformed_kp = self.kp_extractor(transformed_frame)
generated['transformed_frame'] = transformed_frame
generated['transformed_kp'] = transformed_kp
warped = transform_random.warp_coordinates(transformed_kp['fg_kp'])
kp_d = kp_driving['fg_kp']
value = torch.abs(kp_d - warped).mean()
loss_values['equivariance_value'] = self.loss_weights['equivariance_value'] * value
# warp loss
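        # The encoder features of the real driving frame (modulated by the
        # occlusion map) should agree with the warped source encoder maps;
        # decode_map is indexed in reverse since the two lists apparently store
        # the feature maps in opposite depth order.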
if self.loss_weights['warp_loss'] != 0:
occlusion_map = generated['occlusion_map']
encode_map = self.inpainting_network.get_encode(x['driving'], occlusion_map)
decode_map = generated['warped_encoder_maps']
value = 0
            for i in range(len(encode_map)):
                value += torch.abs(encode_map[i] - decode_map[-i - 1]).mean()
loss_values['warp_loss'] = self.loss_weights['warp_loss'] * value
# bg loss
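        # The affine background transforms predicted for (source -> driving) and
        # (driving -> source) should be mutual inverses, so their product is
        # pushed toward the 3x3 identity.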
if self.bg_predictor and epoch >= self.bg_start and self.loss_weights['bg'] != 0:
bg_param_reverse = self.bg_predictor(x['driving'], x['source'])
value = torch.matmul(bg_param, bg_param_reverse)
            eye = torch.eye(3, dtype=value.dtype, device=value.device).view(1, 1, 3, 3)
value = torch.abs(eye - value).mean()
loss_values['bg'] = self.loss_weights['bg'] * value
return loss_values, generated
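

# Hedged smoke test (illustrative only; run as `python -m modules.model`). It
# exercises just the self-contained pieces above: GeneratorFullModel also needs
# the keypoint, dense-motion and inpainting networks from the rest of the repo,
# so it is not constructed here. Note Vgg19 downloads ImageNet weights on first
# use.
if __name__ == '__main__':
    frames = torch.rand(1, 3, 256, 256)
    pyramid = ImagePyramide(scales=[1, 0.5], num_channels=3)
    print({k: tuple(v.shape) for k, v in pyramid(frames).items()})
    with torch.no_grad():
        feats = Vgg19().eval()(frames)
    print([tuple(f.shape) for f in feats])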