from torch import nn
import torch
import torch.nn.functional as F
from modules.util import AntiAliasInterpolation2d, TPS
from torchvision import models
import numpy as np


class Vgg19(torch.nn.Module):
    """
    Vgg19 network for perceptual loss. See Sec 3.3.
    """
    def __init__(self, requires_grad=False):
        super(Vgg19, self).__init__()
        vgg_pretrained_features = models.vgg19(pretrained=True).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
        for x in range(2):
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        for x in range(2, 7):
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        for x in range(7, 12):
            self.slice3.add_module(str(x), vgg_pretrained_features[x])
        for x in range(12, 21):
            self.slice4.add_module(str(x), vgg_pretrained_features[x])
        for x in range(21, 30):
            self.slice5.add_module(str(x), vgg_pretrained_features[x])

        # ImageNet normalization statistics, stored as non-trainable parameters
        self.mean = torch.nn.Parameter(data=torch.Tensor(np.array([0.485, 0.456, 0.406]).reshape((1, 3, 1, 1))),
                                       requires_grad=False)
        self.std = torch.nn.Parameter(data=torch.Tensor(np.array([0.229, 0.224, 0.225]).reshape((1, 3, 1, 1))),
                                      requires_grad=False)

        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        X = (X - self.mean) / self.std
        h_relu1 = self.slice1(X)
        h_relu2 = self.slice2(h_relu1)
        h_relu3 = self.slice3(h_relu2)
        h_relu4 = self.slice4(h_relu3)
        h_relu5 = self.slice5(h_relu4)
        out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
        return out


class ImagePyramide(torch.nn.Module):
    """
    Create an image pyramid for computing the pyramid perceptual loss. See Sec 3.3.
    """
    def __init__(self, scales, num_channels):
        super(ImagePyramide, self).__init__()
        downs = {}
        for scale in scales:
            downs[str(scale).replace('.', '-')] = AntiAliasInterpolation2d(num_channels, scale)
        self.downs = nn.ModuleDict(downs)

    def forward(self, x):
        out_dict = {}
        for scale, down_module in self.downs.items():
            out_dict['prediction_' + str(scale).replace('-', '.')] = down_module(x)
        return out_dict


def detach_kp(kp):
    return {key: value.detach() for key, value in kp.items()}
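
# --- Illustrative sketch, not part of the original training code ---
# A minimal, hedged example of how ImagePyramide and Vgg19 combine into the
# multi-scale perceptual loss of Sec 3.3. The scales, per-layer weights and
# image size below are placeholder assumptions, not values from this repo's
# config. Note that Vgg19() fetches ImageNet-pretrained weights on first use.
def _perceptual_loss_sketch():
    scales = [1, 0.5, 0.25]          # assumed pyramid scales
    weights = [10, 10, 10, 10, 10]   # assumed weight per VGG feature slice
    vgg = Vgg19().eval()
    pyramid = ImagePyramide(scales, num_channels=3)
    real = torch.rand(2, 3, 256, 256)       # stand-in driving frames
    generated = torch.rand(2, 3, 256, 256)  # stand-in generated frames
    pyr_real = pyramid(real)
    pyr_generated = pyramid(generated)
    total = 0
    for scale in scales:
        x_vgg = vgg(pyr_generated['prediction_' + str(scale)])
        y_vgg = vgg(pyr_real['prediction_' + str(scale)])
        for i, weight in enumerate(weights):
            # L1 distance between VGG features; the real branch is detached
            total += weight * torch.abs(x_vgg[i] - y_vgg[i].detach()).mean()
    return total
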
class GeneratorFullModel(torch.nn.Module):
    """
    Merge all generator-related updates into a single model for better multi-GPU usage.
    """
    def __init__(self, kp_extractor, bg_predictor, dense_motion_network, inpainting_network, train_params, **kwargs):
        super(GeneratorFullModel, self).__init__()
        self.kp_extractor = kp_extractor
        self.inpainting_network = inpainting_network
        self.dense_motion_network = dense_motion_network

        self.bg_predictor = None
        if bg_predictor:
            self.bg_predictor = bg_predictor
            self.bg_start = train_params['bg_start']

        self.train_params = train_params
        self.scales = train_params['scales']

        self.pyramid = ImagePyramide(self.scales, inpainting_network.num_channels)
        if torch.cuda.is_available():
            self.pyramid = self.pyramid.cuda()

        self.loss_weights = train_params['loss_weights']
        self.dropout_epoch = train_params['dropout_epoch']
        self.dropout_maxp = train_params['dropout_maxp']
        self.dropout_inc_epoch = train_params['dropout_inc_epoch']
        self.dropout_startp = train_params['dropout_startp']

        if sum(self.loss_weights['perceptual']) != 0:
            self.vgg = Vgg19()
            if torch.cuda.is_available():
                self.vgg = self.vgg.cuda()

    def forward(self, x, epoch):
        kp_source = self.kp_extractor(x['source'])
        kp_driving = self.kp_extractor(x['driving'])

        bg_param = None
        if self.bg_predictor:
            if epoch >= self.bg_start:
                bg_param = self.bg_predictor(x['source'], x['driving'])

        if epoch >= self.dropout_epoch:
            dropout_flag = False
            dropout_p = 0
        else:
            # dropout_p increases linearly from dropout_startp to dropout_maxp
            dropout_flag = True
            dropout_p = min(epoch / self.dropout_inc_epoch * self.dropout_maxp + self.dropout_startp,
                            self.dropout_maxp)

        dense_motion = self.dense_motion_network(source_image=x['source'], kp_driving=kp_driving,
                                                 kp_source=kp_source, bg_param=bg_param,
                                                 dropout_flag=dropout_flag, dropout_p=dropout_p)
        generated = self.inpainting_network(x['source'], dense_motion)
        generated.update({'kp_source': kp_source, 'kp_driving': kp_driving})

        loss_values = {}

        pyramide_real = self.pyramid(x['driving'])
        pyramide_generated = self.pyramid(generated['prediction'])

        # reconstruction loss
        if sum(self.loss_weights['perceptual']) != 0:
            value_total = 0
            for scale in self.scales:
                x_vgg = self.vgg(pyramide_generated['prediction_' + str(scale)])
                y_vgg = self.vgg(pyramide_real['prediction_' + str(scale)])
                for i, weight in enumerate(self.loss_weights['perceptual']):
                    value = torch.abs(x_vgg[i] - y_vgg[i].detach()).mean()
                    value_total += weight * value
            loss_values['perceptual'] = value_total

        # equivariance loss
        if self.loss_weights['equivariance_value'] != 0:
            transform_random = TPS(mode='random', bs=x['driving'].shape[0], **self.train_params['transform_params'])
            transform_grid = transform_random.transform_frame(x['driving'])
            transformed_frame = F.grid_sample(x['driving'], transform_grid,
                                              padding_mode="reflection", align_corners=True)
            transformed_kp = self.kp_extractor(transformed_frame)

            generated['transformed_frame'] = transformed_frame
            generated['transformed_kp'] = transformed_kp

            warped = transform_random.warp_coordinates(transformed_kp['fg_kp'])
            kp_d = kp_driving['fg_kp']
            value = torch.abs(kp_d - warped).mean()
            loss_values['equivariance_value'] = self.loss_weights['equivariance_value'] * value

        # warp loss
        if self.loss_weights['warp_loss'] != 0:
            occlusion_map = generated['occlusion_map']
            encode_map = self.inpainting_network.get_encode(x['driving'], occlusion_map)
            decode_map = generated['warped_encoder_maps']
            value = 0
            for i in range(len(encode_map)):
                value += torch.abs(encode_map[i] - decode_map[-i - 1]).mean()
            loss_values['warp_loss'] = self.loss_weights['warp_loss'] * value

        # bg loss
        if self.bg_predictor and epoch >= self.bg_start and self.loss_weights['bg'] != 0:
            bg_param_reverse = self.bg_predictor(x['driving'], x['source'])
            value = torch.matmul(bg_param, bg_param_reverse)
            eye = torch.eye(3).view(1, 1, 3, 3).type(value.type())
            value = torch.abs(eye - value).mean()
            loss_values['bg'] = self.loss_weights['bg'] * value

        return loss_values, generated
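

# --- Illustrative sketch, not part of the original training code ---
# A hedged, self-contained walk-through of the equivariance term above:
# keypoints predicted on a TPS-warped frame, once mapped back through the same
# transform, should agree with the keypoints of the original frame.
# `kp_extractor` is any module returning {'fg_kp': ...} as in forward(), and
# `transform_params` stands in for train_params['transform_params'].
def _equivariance_sketch(kp_extractor, frame, transform_params):
    transform = TPS(mode='random', bs=frame.shape[0], **transform_params)
    grid = transform.transform_frame(frame)
    warped_frame = F.grid_sample(frame, grid, padding_mode="reflection",
                                 align_corners=True)
    kp_warped = kp_extractor(warped_frame)
    kp_orig = kp_extractor(frame)
    # map the warped-frame keypoints back into the original frame's coordinates
    recovered = transform.warp_coordinates(kp_warped['fg_kp'])
    return torch.abs(kp_orig['fg_kp'] - recovered).mean()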