Tolga committed on
Commit 280b585
1 Parent(s): 79da588
app.py ADDED
@@ -0,0 +1,109 @@
+ import gradio as gr
+ import torch
+
+ import imageio
+ import imageio_ffmpeg
+ import numpy as np
+ import matplotlib.pyplot as plt
+ import matplotlib.animation as animation
+ from skimage.transform import resize
+ from IPython.display import HTML
+ import warnings
+ import os
+ from model import load_checkpoints
+ from model import make_animation
+ from skimage import img_as_ubyte
+ from PIL import Image
+ import time
+ warnings.filterwarnings("ignore")
+
+ device = torch.device('cuda:0')
+ device = torch.device('cpu')
+
+
+ dataset_name = 'vox' # ['vox', 'taichi', 'ted', 'mgif']
+ source_image_path = './assets/source.png'
+ driving_video_path = './assets/driving.mp4'
+ output_video_path = './generated.mp4'
+ config_path = './config/vox-256.yaml'
+ checkpoint_path = 'checkpoints/vox.pth.tar'
+ checkpoint_path = 'vox.pth.tar'
+ predict_mode = 'relative' # ['standard', 'relative', 'avd']
+ find_best_frame = False # when using relative mode to animate a face, setting find_best_frame=True can give a better quality result
+
+ pixel = 256 # for vox, taichi and mgif, the resolution is 256*256
+ if(dataset_name == 'ted'): # for ted, the resolution is 384*384
+     pixel = 384
+
+ if find_best_frame:
+     #!pip install face_alignment
+     pass
+
+
+ def create_video(tt):
+
+     source_image = imageio.imread(f"assets/img_{tt}.jpg")
+     reader = imageio.get_reader(f"assets/ref_{tt}.mp4")
+
+     source_image = resize(source_image, (pixel, pixel))[..., :3]
+
+     fps = reader.get_meta_data()['fps']
+     driving_video = []
+     try:
+         for im in reader:
+             driving_video.append(im)
+     except RuntimeError:
+         pass
+     reader.close()
+
+     driving_video = [resize(frame, (pixel, pixel))[..., :3] for frame in driving_video]
+
+     def display(source, driving, generated=None):
+         fig = plt.figure(figsize=(8 + 4 * (generated is not None), 6))
+
+         ims = []
+         for i in range(len(driving)):
+             cols = [source]
+             cols.append(driving[i])
+             if generated is not None:
+                 cols.append(generated[i])
+             im = plt.imshow(np.concatenate(cols, axis=1), animated=True)
+             plt.axis('off')
+             ims.append([im])
+
+         ani = animation.ArtistAnimation(fig, ims, interval=50, repeat_delay=1000)
+         plt.close()
+         return ani
+
+
+     # HTML(display(source_image, driving_video).to_html5_video())
+     inpainting, kp_detector, dense_motion_network, avd_network = load_checkpoints(config_path = config_path, checkpoint_path = checkpoint_path, device = device)
+
+
+
+     if predict_mode=='relative' and find_best_frame:
+         from model import find_best_frame as _find
+         i = _find(source_image, driving_video, device.type=='cpu')
+         print("Best frame: " + str(i))
+         driving_forward = driving_video[i:]
+         driving_backward = driving_video[:(i+1)][::-1]
+         predictions_forward = make_animation(source_image, driving_forward, inpainting, kp_detector, dense_motion_network, avd_network, device = device, mode = predict_mode)
+         predictions_backward = make_animation(source_image, driving_backward, inpainting, kp_detector, dense_motion_network, avd_network, device = device, mode = predict_mode)
+         predictions = predictions_backward[::-1] + predictions_forward[1:]
+     else:
+         predictions = make_animation(source_image, driving_video, inpainting, kp_detector, dense_motion_network, avd_network, device = device, mode = predict_mode)
+
+     # save the resulting video
+     imageio.mimsave(f"./assets/output_{tt}.mp4", [img_as_ubyte(frame) for frame in predictions], fps=fps)
+
+
+ def greet(img, video):
+     tt = str(time.time())
+     os.replace(video, f"assets/ref_{tt}.mp4")
+     img.save(f"assets/img_{tt}.jpg")
+     create_video(tt)
+     return f"./assets/output_{tt}.mp4"
+
+
+ iface = gr.Interface(fn=greet, inputs=[gr.inputs.Image(type="pil"), gr.inputs.Video()], outputs=gr.inputs.Video())
+ iface.launch()
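For reference, a minimal sketch (not part of this commit) of the same inference path without the Gradio UI. It assumes the default paths declared in app.py above — ./assets/source.png, ./assets/driving.mp4, ./config/vox-256.yaml and a local vox.pth.tar checkpoint — actually exist:

import imageio
import torch
from skimage import img_as_ubyte
from skimage.transform import resize

from model import load_checkpoints, make_animation

device = torch.device('cpu')
pixel = 256

# read and resize the source image and driving video (hypothetical example assets)
source_image = resize(imageio.imread('./assets/source.png'), (pixel, pixel))[..., :3]
reader = imageio.get_reader('./assets/driving.mp4')
fps = reader.get_meta_data()['fps']
driving_video = [resize(frame, (pixel, pixel))[..., :3] for frame in reader]
reader.close()

inpainting, kp_detector, dense_motion_network, avd_network = load_checkpoints(
    config_path='./config/vox-256.yaml', checkpoint_path='vox.pth.tar', device=device)
predictions = make_animation(source_image, driving_video, inpainting, kp_detector,
                             dense_motion_network, avd_network, device=device, mode='relative')
imageio.mimsave('./generated.mp4', [img_as_ubyte(frame) for frame in predictions], fps=fps)

The Gradio app does the same work per request inside create_video(), just with per-upload file names under assets/.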
config/vox-256.yaml ADDED
@@ -0,0 +1,74 @@
+ dataset_params:
+   root_dir: ../vox
+   frame_shape: null
+   id_sampling: True
+   augmentation_params:
+     flip_param:
+       horizontal_flip: True
+       time_flip: True
+     jitter_param:
+       brightness: 0.1
+       contrast: 0.1
+       saturation: 0.1
+       hue: 0.1
+
+
+ model_params:
+   common_params:
+     num_tps: 10
+     num_channels: 3
+     bg: True
+     multi_mask: True
+   generator_params:
+     block_expansion: 64
+     max_features: 512
+     num_down_blocks: 3
+   dense_motion_params:
+     block_expansion: 64
+     max_features: 1024
+     num_blocks: 5
+     scale_factor: 0.25
+   avd_network_params:
+     id_bottle_size: 128
+     pose_bottle_size: 128
+
+
+ train_params:
+   num_epochs: 100
+   num_repeats: 75
+   epoch_milestones: [70, 90]
+   lr_generator: 2.0e-4
+   batch_size: 28
+   scales: [1, 0.5, 0.25, 0.125]
+   dataloader_workers: 12
+   checkpoint_freq: 50
+   dropout_epoch: 35
+   dropout_maxp: 0.3
+   dropout_startp: 0.1
+   dropout_inc_epoch: 10
+   bg_start: 10
+   transform_params:
+     sigma_affine: 0.05
+     sigma_tps: 0.005
+     points_tps: 5
+   loss_weights:
+     perceptual: [10, 10, 10, 10, 10]
+     equivariance_value: 10
+     warp_loss: 10
+     bg: 10
+
+ train_avd_params:
+   num_epochs: 200
+   num_repeats: 300
+   batch_size: 256
+   dataloader_workers: 24
+   checkpoint_freq: 50
+   epoch_milestones: [140, 180]
+   lr: 1.0e-3
+   lambda_shift: 1
+   random_scale: 0.25
+
+ visualizer_params:
+   kp_size: 5
+   draw_border: True
+   colormap: 'gist_rainbow'
model.py ADDED
@@ -0,0 +1,119 @@
+ import matplotlib
+ matplotlib.use('Agg')
+ import sys
+ import yaml
+ from argparse import ArgumentParser
+ from tqdm import tqdm
+ from scipy.spatial import ConvexHull
+ import numpy as np
+ import imageio
+ from skimage.transform import resize
+ from skimage import img_as_ubyte
+ import torch
+ from modules.inpainting_network import InpaintingNetwork
+ from modules.keypoint_detector import KPDetector
+ from modules.dense_motion import DenseMotionNetwork
+ from modules.avd_network import AVDNetwork
+
+ def load_checkpoints(config_path, checkpoint_path, device):
+     with open(config_path) as f:
+         config = yaml.full_load(f)
+
+     inpainting = InpaintingNetwork(**config['model_params']['generator_params'],
+                                    **config['model_params']['common_params'])
+     kp_detector = KPDetector(**config['model_params']['common_params'])
+     dense_motion_network = DenseMotionNetwork(**config['model_params']['common_params'],
+                                               **config['model_params']['dense_motion_params'])
+     avd_network = AVDNetwork(num_tps=config['model_params']['common_params']['num_tps'],
+                              **config['model_params']['avd_network_params'])
+     kp_detector.to(device)
+     dense_motion_network.to(device)
+     inpainting.to(device)
+     avd_network.to(device)
+
+     checkpoint = torch.load(checkpoint_path, map_location=device)
+
+     inpainting.load_state_dict(checkpoint['inpainting_network'])
+     kp_detector.load_state_dict(checkpoint['kp_detector'])
+     dense_motion_network.load_state_dict(checkpoint['dense_motion_network'])
+     if 'avd_network' in checkpoint:
+         avd_network.load_state_dict(checkpoint['avd_network'])
+
+     inpainting.eval()
+     kp_detector.eval()
+     dense_motion_network.eval()
+     avd_network.eval()
+
+     return inpainting, kp_detector, dense_motion_network, avd_network
+
+ def relative_kp(kp_source, kp_driving, kp_driving_initial):
+
+     source_area = ConvexHull(kp_source['fg_kp'][0].data.cpu().numpy()).volume
+     driving_area = ConvexHull(kp_driving_initial['fg_kp'][0].data.cpu().numpy()).volume
+     adapt_movement_scale = np.sqrt(source_area) / np.sqrt(driving_area)
+
+     kp_new = {k: v for k, v in kp_driving.items()}
+
+     kp_value_diff = (kp_driving['fg_kp'] - kp_driving_initial['fg_kp'])
+     kp_value_diff *= adapt_movement_scale
+     kp_new['fg_kp'] = kp_value_diff + kp_source['fg_kp']
+
+     return kp_new
+
+ def make_animation(source_image, driving_video, inpainting_network, kp_detector, dense_motion_network, avd_network, device, mode = 'relative'):
+     assert mode in ['standard', 'relative', 'avd']
+     with torch.no_grad():
+         predictions = []
+         source = torch.tensor(source_image[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2)
+         source = source.to(device)
+         driving = torch.tensor(np.array(driving_video)[np.newaxis].astype(np.float32)).permute(0, 4, 1, 2, 3).to(device)
+         kp_source = kp_detector(source)
+         kp_driving_initial = kp_detector(driving[:, :, 0])
+
+         for frame_idx in tqdm(range(driving.shape[2])):
+             driving_frame = driving[:, :, frame_idx]
+             driving_frame = driving_frame.to(device)
+             kp_driving = kp_detector(driving_frame)
+             if mode == 'standard':
+                 kp_norm = kp_driving
+             elif mode=='relative':
+                 kp_norm = relative_kp(kp_source=kp_source, kp_driving=kp_driving,
+                                       kp_driving_initial=kp_driving_initial)
+             elif mode == 'avd':
+                 kp_norm = avd_network(kp_source, kp_driving)
+             dense_motion = dense_motion_network(source_image=source, kp_driving=kp_norm,
+                                                 kp_source=kp_source, bg_param = None,
+                                                 dropout_flag = False)
+             out = inpainting_network(source, dense_motion)
+
+             predictions.append(np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0])
+     return predictions
+
+
+ def find_best_frame(source, driving, cpu):
+     import face_alignment
+
+     def normalize_kp(kp):
+         kp = kp - kp.mean(axis=0, keepdims=True)
+         area = ConvexHull(kp[:, :2]).volume
+         area = np.sqrt(area)
+         kp[:, :2] = kp[:, :2] / area
+         return kp
+
+     fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=True,
+                                       device= 'cpu' if cpu else 'cuda')
+     kp_source = fa.get_landmarks(255 * source)[0]
+     kp_source = normalize_kp(kp_source)
+     norm = float('inf')
+     frame_num = 0
+     for i, image in tqdm(enumerate(driving)):
+         try:
+             kp_driving = fa.get_landmarks(255 * image)[0]
+             kp_driving = normalize_kp(kp_driving)
+             new_norm = (np.abs(kp_source - kp_driving) ** 2).sum()
+             if new_norm < norm:
+                 norm = new_norm
+                 frame_num = i
+         except:
+             pass
+     return frame_num
modules/.DS_Store ADDED
Binary file (6.15 kB)
modules/__pycache__/avd_network.cpython-310.pyc ADDED
Binary file (1.57 kB)
modules/__pycache__/dense_motion.cpython-310.pyc ADDED
Binary file (5.67 kB)
modules/__pycache__/inpainting_network.cpython-310.pyc ADDED
Binary file (3.75 kB)
modules/__pycache__/keypoint_detector.cpython-310.pyc ADDED
Binary file (1.12 kB)
modules/__pycache__/util.cpython-310.pyc ADDED
Binary file (10.8 kB)
modules/avd_network.py ADDED
@@ -0,0 +1,65 @@
+
+ import torch
+ from torch import nn
+
+
+ class AVDNetwork(nn.Module):
+     """
+     Animation via Disentanglement network
+     """
+
+     def __init__(self, num_tps, id_bottle_size=64, pose_bottle_size=64):
+         super(AVDNetwork, self).__init__()
+         input_size = 5*2 * num_tps
+         self.num_tps = num_tps
+
+         self.id_encoder = nn.Sequential(
+             nn.Linear(input_size, 256),
+             nn.BatchNorm1d(256),
+             nn.ReLU(inplace=True),
+             nn.Linear(256, 512),
+             nn.BatchNorm1d(512),
+             nn.ReLU(inplace=True),
+             nn.Linear(512, 1024),
+             nn.BatchNorm1d(1024),
+             nn.ReLU(inplace=True),
+             nn.Linear(1024, id_bottle_size)
+         )
+
+         self.pose_encoder = nn.Sequential(
+             nn.Linear(input_size, 256),
+             nn.BatchNorm1d(256),
+             nn.ReLU(inplace=True),
+             nn.Linear(256, 512),
+             nn.BatchNorm1d(512),
+             nn.ReLU(inplace=True),
+             nn.Linear(512, 1024),
+             nn.BatchNorm1d(1024),
+             nn.ReLU(inplace=True),
+             nn.Linear(1024, pose_bottle_size)
+         )
+
+         self.decoder = nn.Sequential(
+             nn.Linear(pose_bottle_size + id_bottle_size, 1024),
+             nn.BatchNorm1d(1024),
+             nn.ReLU(),
+             nn.Linear(1024, 512),
+             nn.BatchNorm1d(512),
+             nn.ReLU(),
+             nn.Linear(512, 256),
+             nn.BatchNorm1d(256),
+             nn.ReLU(),
+             nn.Linear(256, input_size)
+         )
+
+     def forward(self, kp_source, kp_random):
+
+         bs = kp_source['fg_kp'].shape[0]
+
+         pose_emb = self.pose_encoder(kp_random['fg_kp'].view(bs, -1))
+         id_emb = self.id_encoder(kp_source['fg_kp'].view(bs, -1))
+
+         rec = self.decoder(torch.cat([pose_emb, id_emb], dim=1))
+
+         rec = {'fg_kp': rec.view(bs, self.num_tps*5, -1)}
+         return rec
modules/bg_motion_predictor.py ADDED
@@ -0,0 +1,24 @@
+ from torch import nn
+ import torch
+ from torchvision import models
+
+ class BGMotionPredictor(nn.Module):
+     """
+     Module for background estimation. Returns a single transformation, parametrized as a 3x3 matrix; the third row is [0 0 1].
+     """
+
+     def __init__(self):
+         super(BGMotionPredictor, self).__init__()
+         self.bg_encoder = models.resnet18(pretrained=False)
+         self.bg_encoder.conv1 = nn.Conv2d(6, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
+         num_features = self.bg_encoder.fc.in_features
+         self.bg_encoder.fc = nn.Linear(num_features, 6)
+         self.bg_encoder.fc.weight.data.zero_()
+         self.bg_encoder.fc.bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))
+
+     def forward(self, source_image, driving_image):
+         bs = source_image.shape[0]
+         out = torch.eye(3).unsqueeze(0).repeat(bs, 1, 1).type(source_image.type())
+         prediction = self.bg_encoder(torch.cat([source_image, driving_image], dim=1))
+         out[:, :2, :] = prediction.view(bs, 2, 3)
+         return out
modules/dense_motion.py ADDED
@@ -0,0 +1,164 @@
+ from torch import nn
+ import torch.nn.functional as F
+ import torch
+ from modules.util import Hourglass, AntiAliasInterpolation2d, make_coordinate_grid, kp2gaussian
+ from modules.util import to_homogeneous, from_homogeneous, UpBlock2d, TPS
+ import math
+
+ class DenseMotionNetwork(nn.Module):
+     """
+     Module that estimates an optical flow and multi-resolution occlusion masks
+     from K TPS transformations and an affine transformation.
+     """
+
+     def __init__(self, block_expansion, num_blocks, max_features, num_tps, num_channels,
+                  scale_factor=0.25, bg = False, multi_mask = True, kp_variance=0.01):
+         super(DenseMotionNetwork, self).__init__()
+
+         if scale_factor != 1:
+             self.down = AntiAliasInterpolation2d(num_channels, scale_factor)
+         self.scale_factor = scale_factor
+         self.multi_mask = multi_mask
+
+         self.hourglass = Hourglass(block_expansion=block_expansion, in_features=(num_channels * (num_tps+1) + num_tps*5+1),
+                                    max_features=max_features, num_blocks=num_blocks)
+
+         hourglass_output_size = self.hourglass.out_channels
+         self.maps = nn.Conv2d(hourglass_output_size[-1], num_tps + 1, kernel_size=(7, 7), padding=(3, 3))
+
+         if multi_mask:
+             up = []
+             self.up_nums = int(math.log(1/scale_factor, 2))
+             self.occlusion_num = 4
+
+             channel = [hourglass_output_size[-1]//(2**i) for i in range(self.up_nums)]
+             for i in range(self.up_nums):
+                 up.append(UpBlock2d(channel[i], channel[i]//2, kernel_size=3, padding=1))
+             self.up = nn.ModuleList(up)
+
+             channel = [hourglass_output_size[-i-1] for i in range(self.occlusion_num-self.up_nums)[::-1]]
+             for i in range(self.up_nums):
+                 channel.append(hourglass_output_size[-1]//(2**(i+1)))
+             occlusion = []
+
+             for i in range(self.occlusion_num):
+                 occlusion.append(nn.Conv2d(channel[i], 1, kernel_size=(7, 7), padding=(3, 3)))
+             self.occlusion = nn.ModuleList(occlusion)
+         else:
+             occlusion = [nn.Conv2d(hourglass_output_size[-1], 1, kernel_size=(7, 7), padding=(3, 3))]
+             self.occlusion = nn.ModuleList(occlusion)
+
+         self.num_tps = num_tps
+         self.bg = bg
+         self.kp_variance = kp_variance
+
+
+     def create_heatmap_representations(self, source_image, kp_driving, kp_source):
+
+         spatial_size = source_image.shape[2:]
+         gaussian_driving = kp2gaussian(kp_driving['fg_kp'], spatial_size=spatial_size, kp_variance=self.kp_variance)
+         gaussian_source = kp2gaussian(kp_source['fg_kp'], spatial_size=spatial_size, kp_variance=self.kp_variance)
+         heatmap = gaussian_driving - gaussian_source
+
+         zeros = torch.zeros(heatmap.shape[0], 1, spatial_size[0], spatial_size[1]).type(heatmap.type()).to(heatmap.device)
+         heatmap = torch.cat([zeros, heatmap], dim=1)
+
+         return heatmap
+
+     def create_transformations(self, source_image, kp_driving, kp_source, bg_param):
+         # K TPS transformations
+         bs, _, h, w = source_image.shape
+         kp_1 = kp_driving['fg_kp']
+         kp_2 = kp_source['fg_kp']
+         kp_1 = kp_1.view(bs, -1, 5, 2)
+         kp_2 = kp_2.view(bs, -1, 5, 2)
+         trans = TPS(mode = 'kp', bs = bs, kp_1 = kp_1, kp_2 = kp_2)
+         driving_to_source = trans.transform_frame(source_image)
+
+         identity_grid = make_coordinate_grid((h, w), type=kp_1.type()).to(kp_1.device)
+         identity_grid = identity_grid.view(1, 1, h, w, 2)
+         identity_grid = identity_grid.repeat(bs, 1, 1, 1, 1)
+
+         # affine background transformation
+         if not (bg_param is None):
+             identity_grid = to_homogeneous(identity_grid)
+             identity_grid = torch.matmul(bg_param.view(bs, 1, 1, 1, 3, 3), identity_grid.unsqueeze(-1)).squeeze(-1)
+             identity_grid = from_homogeneous(identity_grid)
+
+         transformations = torch.cat([identity_grid, driving_to_source], dim=1)
+         return transformations
+
+     def create_deformed_source_image(self, source_image, transformations):
+
+         bs, _, h, w = source_image.shape
+         source_repeat = source_image.unsqueeze(1).unsqueeze(1).repeat(1, self.num_tps + 1, 1, 1, 1, 1)
+         source_repeat = source_repeat.view(bs * (self.num_tps + 1), -1, h, w)
+         transformations = transformations.view((bs * (self.num_tps + 1), h, w, -1))
+         deformed = F.grid_sample(source_repeat, transformations, align_corners=True)
+         deformed = deformed.view((bs, self.num_tps+1, -1, h, w))
+         return deformed
+
+     def dropout_softmax(self, X, P):
+         '''
+         Dropout for TPS transformations. Eq(7) and Eq(8) in the paper.
+         '''
+         drop = (torch.rand(X.shape[0], X.shape[1]) < (1-P)).type(X.type()).to(X.device)
+         drop[..., 0] = 1
+         drop = drop.repeat(X.shape[2], X.shape[3], 1, 1).permute(2, 3, 0, 1)
+
+         maxx = X.max(1).values.unsqueeze_(1)
+         X = X - maxx
+         X_exp = X.exp()
+         X[:, 1:, ...] /= (1-P)
+         mask_bool = (drop == 0)
+         X_exp = X_exp.masked_fill(mask_bool, 0)
+         partition = X_exp.sum(dim=1, keepdim=True) + 1e-6
+         return X_exp / partition
+
+     def forward(self, source_image, kp_driving, kp_source, bg_param = None, dropout_flag=False, dropout_p = 0):
+         if self.scale_factor != 1:
+             source_image = self.down(source_image)
+
+         bs, _, h, w = source_image.shape
+
+         out_dict = dict()
+         heatmap_representation = self.create_heatmap_representations(source_image, kp_driving, kp_source)
+         transformations = self.create_transformations(source_image, kp_driving, kp_source, bg_param)
+         deformed_source = self.create_deformed_source_image(source_image, transformations)
+         out_dict['deformed_source'] = deformed_source
+         # out_dict['transformations'] = transformations
+         deformed_source = deformed_source.view(bs, -1, h, w)
+         input = torch.cat([heatmap_representation, deformed_source], dim=1)
+         input = input.view(bs, -1, h, w)
+
+         prediction = self.hourglass(input, mode = 1)
+
+         contribution_maps = self.maps(prediction[-1])
+         if(dropout_flag):
+             contribution_maps = self.dropout_softmax(contribution_maps, dropout_p)
+         else:
+             contribution_maps = F.softmax(contribution_maps, dim=1)
+         out_dict['contribution_maps'] = contribution_maps
+
+         # Combine the K+1 transformations
+         # Eq(6) in the paper
+         contribution_maps = contribution_maps.unsqueeze(2)
+         transformations = transformations.permute(0, 1, 4, 2, 3)
+         deformation = (transformations * contribution_maps).sum(dim=1)
+         deformation = deformation.permute(0, 2, 3, 1)
+
+         out_dict['deformation'] = deformation # Optical Flow
+
+         occlusion_map = []
+         if self.multi_mask:
+             for i in range(self.occlusion_num-self.up_nums):
+                 occlusion_map.append(torch.sigmoid(self.occlusion[i](prediction[self.up_nums-self.occlusion_num+i])))
+             prediction = prediction[-1]
+             for i in range(self.up_nums):
+                 prediction = self.up[i](prediction)
+                 occlusion_map.append(torch.sigmoid(self.occlusion[i+self.occlusion_num-self.up_nums](prediction)))
+         else:
+             occlusion_map.append(torch.sigmoid(self.occlusion[0](prediction[-1])))
+
+         out_dict['occlusion_map'] = occlusion_map # Multi-resolution Occlusion Masks
+         return out_dict
modules/inpainting_network.py ADDED
@@ -0,0 +1,127 @@
+ import torch
+ from torch import nn
+ import torch.nn.functional as F
+ from modules.util import ResBlock2d, SameBlock2d, UpBlock2d, DownBlock2d
+ from modules.dense_motion import DenseMotionNetwork
+
+
+ class InpaintingNetwork(nn.Module):
+     """
+     Inpaint the missing regions and reconstruct the Driving image.
+     """
+     def __init__(self, num_channels, block_expansion, max_features, num_down_blocks, multi_mask = True, **kwargs):
+         super(InpaintingNetwork, self).__init__()
+
+         self.num_down_blocks = num_down_blocks
+         self.multi_mask = multi_mask
+         self.first = SameBlock2d(num_channels, block_expansion, kernel_size=(7, 7), padding=(3, 3))
+
+         down_blocks = []
+         up_blocks = []
+         resblock = []
+         for i in range(num_down_blocks):
+             in_features = min(max_features, block_expansion * (2 ** i))
+             out_features = min(max_features, block_expansion * (2 ** (i + 1)))
+             down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1)))
+             decoder_in_feature = out_features * 2
+             if i==num_down_blocks-1:
+                 decoder_in_feature = out_features
+             up_blocks.append(UpBlock2d(decoder_in_feature, in_features, kernel_size=(3, 3), padding=(1, 1)))
+             resblock.append(ResBlock2d(decoder_in_feature, kernel_size=(3, 3), padding=(1, 1)))
+             resblock.append(ResBlock2d(decoder_in_feature, kernel_size=(3, 3), padding=(1, 1)))
+         self.down_blocks = nn.ModuleList(down_blocks)
+         self.up_blocks = nn.ModuleList(up_blocks[::-1])
+         self.resblock = nn.ModuleList(resblock[::-1])
+
+         self.final = nn.Conv2d(block_expansion, num_channels, kernel_size=(7, 7), padding=(3, 3))
+         self.num_channels = num_channels
+
+     def deform_input(self, inp, deformation):
+         _, h_old, w_old, _ = deformation.shape
+         _, _, h, w = inp.shape
+         if h_old != h or w_old != w:
+             deformation = deformation.permute(0, 3, 1, 2)
+             deformation = F.interpolate(deformation, size=(h, w), mode='bilinear', align_corners=True)
+             deformation = deformation.permute(0, 2, 3, 1)
+         return F.grid_sample(inp, deformation, align_corners=True)
+
+     def occlude_input(self, inp, occlusion_map):
+         if not self.multi_mask:
+             if inp.shape[2] != occlusion_map.shape[2] or inp.shape[3] != occlusion_map.shape[3]:
+                 occlusion_map = F.interpolate(occlusion_map, size=inp.shape[2:], mode='bilinear', align_corners=True)
+         out = inp * occlusion_map
+         return out
+
+     def forward(self, source_image, dense_motion):
+         out = self.first(source_image)
+         encoder_map = [out]
+         for i in range(len(self.down_blocks)):
+             out = self.down_blocks[i](out)
+             encoder_map.append(out)
+
+         output_dict = {}
+         output_dict['contribution_maps'] = dense_motion['contribution_maps']
+         output_dict['deformed_source'] = dense_motion['deformed_source']
+
+         occlusion_map = dense_motion['occlusion_map']
+         output_dict['occlusion_map'] = occlusion_map
+
+         deformation = dense_motion['deformation']
+         out_ij = self.deform_input(out.detach(), deformation)
+         out = self.deform_input(out, deformation)
+
+         out_ij = self.occlude_input(out_ij, occlusion_map[0].detach())
+         out = self.occlude_input(out, occlusion_map[0])
+
+         warped_encoder_maps = []
+         warped_encoder_maps.append(out_ij)
+
+         for i in range(self.num_down_blocks):
+
+             out = self.resblock[2*i](out)
+             out = self.resblock[2*i+1](out)
+             out = self.up_blocks[i](out)
+
+             encode_i = encoder_map[-(i+2)]
+             encode_ij = self.deform_input(encode_i.detach(), deformation)
+             encode_i = self.deform_input(encode_i, deformation)
+
+             occlusion_ind = 0
+             if self.multi_mask:
+                 occlusion_ind = i+1
+             encode_ij = self.occlude_input(encode_ij, occlusion_map[occlusion_ind].detach())
+             encode_i = self.occlude_input(encode_i, occlusion_map[occlusion_ind])
+             warped_encoder_maps.append(encode_ij)
+
+             if(i==self.num_down_blocks-1):
+                 break
+
+             out = torch.cat([out, encode_i], 1)
+
+         deformed_source = self.deform_input(source_image, deformation)
+         output_dict["deformed"] = deformed_source
+         output_dict["warped_encoder_maps"] = warped_encoder_maps
+
+         occlusion_last = occlusion_map[-1]
+         if not self.multi_mask:
+             occlusion_last = F.interpolate(occlusion_last, size=out.shape[2:], mode='bilinear', align_corners=True)
+
+         out = out * (1 - occlusion_last) + encode_i
+         out = self.final(out)
+         out = torch.sigmoid(out)
+         out = out * (1 - occlusion_last) + deformed_source * occlusion_last
+         output_dict["prediction"] = out
+
+         return output_dict
+
+     def get_encode(self, driver_image, occlusion_map):
+         out = self.first(driver_image)
+         encoder_map = []
+         encoder_map.append(self.occlude_input(out.detach(), occlusion_map[-1].detach()))
+         for i in range(len(self.down_blocks)):
+             out = self.down_blocks[i](out.detach())
+             out_mask = self.occlude_input(out.detach(), occlusion_map[2-i].detach())
+             encoder_map.append(out_mask.detach())
+
+         return encoder_map
+
modules/keypoint_detector.py ADDED
@@ -0,0 +1,27 @@
+ from torch import nn
+ import torch
+ from torchvision import models
+
+ class KPDetector(nn.Module):
+     """
+     Predict K*5 keypoints.
+     """
+
+     def __init__(self, num_tps, **kwargs):
+         super(KPDetector, self).__init__()
+         self.num_tps = num_tps
+
+         self.fg_encoder = models.resnet18(pretrained=False)
+         num_features = self.fg_encoder.fc.in_features
+         self.fg_encoder.fc = nn.Linear(num_features, num_tps*5*2)
+
+
+     def forward(self, image):
+
+         fg_kp = self.fg_encoder(image)
+         bs, _, = fg_kp.shape
+         fg_kp = torch.sigmoid(fg_kp)
+         fg_kp = fg_kp * 2 - 1
+         out = {'fg_kp': fg_kp.view(bs, self.num_tps*5, -1)}
+
+         return out
modules/model.py ADDED
@@ -0,0 +1,182 @@
+ from torch import nn
+ import torch
+ import torch.nn.functional as F
+ from modules.util import AntiAliasInterpolation2d, TPS
+ from torchvision import models
+ import numpy as np
+
+
+ class Vgg19(torch.nn.Module):
+     """
+     Vgg19 network for perceptual loss. See Sec 3.3.
+     """
+     def __init__(self, requires_grad=False):
+         super(Vgg19, self).__init__()
+         vgg_pretrained_features = models.vgg19(pretrained=True).features
+         self.slice1 = torch.nn.Sequential()
+         self.slice2 = torch.nn.Sequential()
+         self.slice3 = torch.nn.Sequential()
+         self.slice4 = torch.nn.Sequential()
+         self.slice5 = torch.nn.Sequential()
+         for x in range(2):
+             self.slice1.add_module(str(x), vgg_pretrained_features[x])
+         for x in range(2, 7):
+             self.slice2.add_module(str(x), vgg_pretrained_features[x])
+         for x in range(7, 12):
+             self.slice3.add_module(str(x), vgg_pretrained_features[x])
+         for x in range(12, 21):
+             self.slice4.add_module(str(x), vgg_pretrained_features[x])
+         for x in range(21, 30):
+             self.slice5.add_module(str(x), vgg_pretrained_features[x])
+
+         self.mean = torch.nn.Parameter(data=torch.Tensor(np.array([0.485, 0.456, 0.406]).reshape((1, 3, 1, 1))),
+                                        requires_grad=False)
+         self.std = torch.nn.Parameter(data=torch.Tensor(np.array([0.229, 0.224, 0.225]).reshape((1, 3, 1, 1))),
+                                       requires_grad=False)
+
+         if not requires_grad:
+             for param in self.parameters():
+                 param.requires_grad = False
+
+     def forward(self, X):
+         X = (X - self.mean) / self.std
+         h_relu1 = self.slice1(X)
+         h_relu2 = self.slice2(h_relu1)
+         h_relu3 = self.slice3(h_relu2)
+         h_relu4 = self.slice4(h_relu3)
+         h_relu5 = self.slice5(h_relu4)
+         out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
+         return out
+
+
+ class ImagePyramide(torch.nn.Module):
+     """
+     Create image pyramide for computing pyramide perceptual loss. See Sec 3.3
+     """
+     def __init__(self, scales, num_channels):
+         super(ImagePyramide, self).__init__()
+         downs = {}
+         for scale in scales:
+             downs[str(scale).replace('.', '-')] = AntiAliasInterpolation2d(num_channels, scale)
+         self.downs = nn.ModuleDict(downs)
+
+     def forward(self, x):
+         out_dict = {}
+         for scale, down_module in self.downs.items():
+             out_dict['prediction_' + str(scale).replace('-', '.')] = down_module(x)
+         return out_dict
+
+
+ def detach_kp(kp):
+     return {key: value.detach() for key, value in kp.items()}
+
+
+ class GeneratorFullModel(torch.nn.Module):
+     """
+     Merge all generator related updates into single model for better multi-gpu usage
+     """
+
+     def __init__(self, kp_extractor, bg_predictor, dense_motion_network, inpainting_network, train_params, *kwargs):
+         super(GeneratorFullModel, self).__init__()
+         self.kp_extractor = kp_extractor
+         self.inpainting_network = inpainting_network
+         self.dense_motion_network = dense_motion_network
+
+         self.bg_predictor = None
+         if bg_predictor:
+             self.bg_predictor = bg_predictor
+             self.bg_start = train_params['bg_start']
+
+         self.train_params = train_params
+         self.scales = train_params['scales']
+
+         self.pyramid = ImagePyramide(self.scales, inpainting_network.num_channels)
+         if torch.cuda.is_available():
+             self.pyramid = self.pyramid.cuda()
+
+         self.loss_weights = train_params['loss_weights']
+         self.dropout_epoch = train_params['dropout_epoch']
+         self.dropout_maxp = train_params['dropout_maxp']
+         self.dropout_inc_epoch = train_params['dropout_inc_epoch']
+         self.dropout_startp = train_params['dropout_startp']
+
+         if sum(self.loss_weights['perceptual']) != 0:
+             self.vgg = Vgg19()
+             if torch.cuda.is_available():
+                 self.vgg = self.vgg.cuda()
+
+
+     def forward(self, x, epoch):
+         kp_source = self.kp_extractor(x['source'])
+         kp_driving = self.kp_extractor(x['driving'])
+         bg_param = None
+         if self.bg_predictor:
+             if(epoch>=self.bg_start):
+                 bg_param = self.bg_predictor(x['source'], x['driving'])
+
+         if(epoch>=self.dropout_epoch):
+             dropout_flag = False
+             dropout_p = 0
+         else:
+             # dropout_p will linearly increase from dropout_startp to dropout_maxp
+             dropout_flag = True
+             dropout_p = min(epoch/self.dropout_inc_epoch * self.dropout_maxp + self.dropout_startp, self.dropout_maxp)
+
+         dense_motion = self.dense_motion_network(source_image=x['source'], kp_driving=kp_driving,
+                                                  kp_source=kp_source, bg_param = bg_param,
+                                                  dropout_flag = dropout_flag, dropout_p = dropout_p)
+         generated = self.inpainting_network(x['source'], dense_motion)
+         generated.update({'kp_source': kp_source, 'kp_driving': kp_driving})
+
+         loss_values = {}
+
+         pyramide_real = self.pyramid(x['driving'])
+         pyramide_generated = self.pyramid(generated['prediction'])
+
+         # reconstruction loss
+         if sum(self.loss_weights['perceptual']) != 0:
+             value_total = 0
+             for scale in self.scales:
+                 x_vgg = self.vgg(pyramide_generated['prediction_' + str(scale)])
+                 y_vgg = self.vgg(pyramide_real['prediction_' + str(scale)])
+
+                 for i, weight in enumerate(self.loss_weights['perceptual']):
+                     value = torch.abs(x_vgg[i] - y_vgg[i].detach()).mean()
+                     value_total += self.loss_weights['perceptual'][i] * value
+             loss_values['perceptual'] = value_total
+
+         # equivariance loss
+         if self.loss_weights['equivariance_value'] != 0:
+             transform_random = TPS(mode = 'random', bs = x['driving'].shape[0], **self.train_params['transform_params'])
+             transform_grid = transform_random.transform_frame(x['driving'])
+             transformed_frame = F.grid_sample(x['driving'], transform_grid, padding_mode="reflection", align_corners=True)
+             transformed_kp = self.kp_extractor(transformed_frame)
+
+             generated['transformed_frame'] = transformed_frame
+             generated['transformed_kp'] = transformed_kp
+
+             warped = transform_random.warp_coordinates(transformed_kp['fg_kp'])
+             kp_d = kp_driving['fg_kp']
+             value = torch.abs(kp_d - warped).mean()
+             loss_values['equivariance_value'] = self.loss_weights['equivariance_value'] * value
+
+         # warp loss
+         if self.loss_weights['warp_loss'] != 0:
+             occlusion_map = generated['occlusion_map']
+             encode_map = self.inpainting_network.get_encode(x['driving'], occlusion_map)
+             decode_map = generated['warped_encoder_maps']
+             value = 0
+             for i in range(len(encode_map)):
+                 value += torch.abs(encode_map[i]-decode_map[-i-1]).mean()
+
+             loss_values['warp_loss'] = self.loss_weights['warp_loss'] * value
+
+         # bg loss
+         if self.bg_predictor and epoch >= self.bg_start and self.loss_weights['bg'] != 0:
+             bg_param_reverse = self.bg_predictor(x['driving'], x['source'])
+             value = torch.matmul(bg_param, bg_param_reverse)
+             eye = torch.eye(3).view(1, 1, 3, 3).type(value.type())
+             value = torch.abs(eye - value).mean()
+             loss_values['bg'] = self.loss_weights['bg'] * value
+
+         return loss_values, generated
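Illustrative only (not part of this commit — no training script is included here): a sketch of how GeneratorFullModel could be wired up with the hyper-parameters from config/vox-256.yaml for a single forward pass. The perceptual weights are zeroed so no pretrained VGG19 is downloaded, and the batch size, image size and epoch value are arbitrary assumptions.

import torch
from modules.keypoint_detector import KPDetector
from modules.dense_motion import DenseMotionNetwork
from modules.inpainting_network import InpaintingNetwork
from modules.model import GeneratorFullModel

# model_params.common_params from config/vox-256.yaml
common = {'num_tps': 10, 'num_channels': 3, 'bg': True, 'multi_mask': True}
kp_detector = KPDetector(**common)
dense_motion_network = DenseMotionNetwork(block_expansion=64, max_features=1024, num_blocks=5,
                                          scale_factor=0.25, **common)
inpainting = InpaintingNetwork(block_expansion=64, max_features=512, num_down_blocks=3, **common)

train_params = {'scales': [1, 0.5, 0.25, 0.125],
                'dropout_epoch': 35, 'dropout_maxp': 0.3, 'dropout_startp': 0.1, 'dropout_inc_epoch': 10,
                'bg_start': 10,
                'transform_params': {'sigma_affine': 0.05, 'sigma_tps': 0.005, 'points_tps': 5},
                'loss_weights': {'perceptual': [0, 0, 0, 0, 0],  # zeroed here to skip the VGG19 download
                                 'equivariance_value': 10, 'warp_loss': 10, 'bg': 10}}

model = GeneratorFullModel(kp_detector, None, dense_motion_network, inpainting, train_params)
x = {'source': torch.rand(2, 3, 256, 256), 'driving': torch.rand(2, 3, 256, 256)}
if torch.cuda.is_available():  # GeneratorFullModel puts its pyramid on CUDA, so follow it
    model, x = model.cuda(), {k: v.cuda() for k, v in x.items()}
loss_values, generated = model(x, epoch=0)
print({k: float(v) for k, v in loss_values.items()}, generated['prediction'].shape)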
modules/util.py ADDED
@@ -0,0 +1,349 @@
+ from torch import nn
+ import torch.nn.functional as F
+ import torch
+
+
+ class TPS:
+     '''
+     TPS transformation, mode 'kp' for Eq(2) in the paper, mode 'random' for equivariance loss.
+     '''
+     def __init__(self, mode, bs, **kwargs):
+         self.bs = bs
+         self.mode = mode
+         if mode == 'random':
+             noise = torch.normal(mean=0, std=kwargs['sigma_affine'] * torch.ones([bs, 2, 3]))
+             self.theta = noise + torch.eye(2, 3).view(1, 2, 3)
+             self.control_points = make_coordinate_grid((kwargs['points_tps'], kwargs['points_tps']), type=noise.type())
+             self.control_points = self.control_points.unsqueeze(0)
+             self.control_params = torch.normal(mean=0,
+                                                std=kwargs['sigma_tps'] * torch.ones([bs, 1, kwargs['points_tps'] ** 2]))
+         elif mode == 'kp':
+             kp_1 = kwargs["kp_1"]
+             kp_2 = kwargs["kp_2"]
+             device = kp_1.device
+             kp_type = kp_1.type()
+             self.gs = kp_1.shape[1]
+             n = kp_1.shape[2]
+             K = torch.norm(kp_1[:,:,:, None]-kp_1[:,:, None, :], dim=4, p=2)
+             K = K**2
+             K = K * torch.log(K+1e-9)
+
+             one1 = torch.ones(self.bs, kp_1.shape[1], kp_1.shape[2], 1).to(device).type(kp_type)
+             kp_1p = torch.cat([kp_1, one1], 3)
+
+             zero = torch.zeros(self.bs, kp_1.shape[1], 3, 3).to(device).type(kp_type)
+             P = torch.cat([kp_1p, zero], 2)
+             L = torch.cat([K, kp_1p.permute(0,1,3,2)], 2)
+             L = torch.cat([L, P], 3)
+
+             zero = torch.zeros(self.bs, kp_1.shape[1], 3, 2).to(device).type(kp_type)
+             Y = torch.cat([kp_2, zero], 2)
+             one = torch.eye(L.shape[2]).expand(L.shape).to(device).type(kp_type)*0.01
+             L = L + one
+
+             param = torch.matmul(torch.inverse(L), Y)
+             self.theta = param[:,:,n:,:].permute(0,1,3,2)
+
+             self.control_points = kp_1
+             self.control_params = param[:,:,:n,:]
+         else:
+             raise Exception("Error TPS mode")
+
+     def transform_frame(self, frame):
+         grid = make_coordinate_grid(frame.shape[2:], type=frame.type()).unsqueeze(0).to(frame.device)
+         grid = grid.view(1, frame.shape[2] * frame.shape[3], 2)
+         shape = [self.bs, frame.shape[2], frame.shape[3], 2]
+         if self.mode == 'kp':
+             shape.insert(1, self.gs)
+         grid = self.warp_coordinates(grid).view(*shape)
+         return grid
+
+     def warp_coordinates(self, coordinates):
+         theta = self.theta.type(coordinates.type()).to(coordinates.device)
+         control_points = self.control_points.type(coordinates.type()).to(coordinates.device)
+         control_params = self.control_params.type(coordinates.type()).to(coordinates.device)
+
+         if self.mode == 'kp':
+             transformed = torch.matmul(theta[:, :, :, :2], coordinates.permute(0, 2, 1)) + theta[:, :, :, 2:]
+
+             distances = coordinates.view(coordinates.shape[0], 1, 1, -1, 2) - control_points.view(self.bs, control_points.shape[1], -1, 1, 2)
+
+             distances = distances ** 2
+             result = distances.sum(-1)
+             result = result * torch.log(result + 1e-9)
+             result = torch.matmul(result.permute(0, 1, 3, 2), control_params)
+             transformed = transformed.permute(0, 1, 3, 2) + result
+
+         elif self.mode == 'random':
+             theta = theta.unsqueeze(1)
+             transformed = torch.matmul(theta[:, :, :, :2], coordinates.unsqueeze(-1)) + theta[:, :, :, 2:]
+             transformed = transformed.squeeze(-1)
+             ances = coordinates.view(coordinates.shape[0], -1, 1, 2) - control_points.view(1, 1, -1, 2)
+             distances = ances ** 2
+
+             result = distances.sum(-1)
+             result = result * torch.log(result + 1e-9)
+             result = result * control_params
+             result = result.sum(dim=2).view(self.bs, coordinates.shape[1], 1)
+             transformed = transformed + result
+         else:
+             raise Exception("Error TPS mode")
+
+         return transformed
+
+
+ def kp2gaussian(kp, spatial_size, kp_variance):
+     """
+     Transform a keypoint into gaussian like representation
+     """
+
+     coordinate_grid = make_coordinate_grid(spatial_size, kp.type()).to(kp.device)
+     number_of_leading_dimensions = len(kp.shape) - 1
+     shape = (1,) * number_of_leading_dimensions + coordinate_grid.shape
+     coordinate_grid = coordinate_grid.view(*shape)
+     repeats = kp.shape[:number_of_leading_dimensions] + (1, 1, 1)
+     coordinate_grid = coordinate_grid.repeat(*repeats)
+
+     # Preprocess kp shape
+     shape = kp.shape[:number_of_leading_dimensions] + (1, 1, 2)
+     kp = kp.view(*shape)
+
+     mean_sub = (coordinate_grid - kp)
+
+     out = torch.exp(-0.5 * (mean_sub ** 2).sum(-1) / kp_variance)
+
+     return out
+
+
+ def make_coordinate_grid(spatial_size, type):
+     """
+     Create a meshgrid [-1,1] x [-1,1] of given spatial_size.
+     """
+     h, w = spatial_size
+     x = torch.arange(w).type(type)
+     y = torch.arange(h).type(type)
+
+     x = (2 * (x / (w - 1)) - 1)
+     y = (2 * (y / (h - 1)) - 1)
+
+     yy = y.view(-1, 1).repeat(1, w)
+     xx = x.view(1, -1).repeat(h, 1)
+
+     meshed = torch.cat([xx.unsqueeze_(2), yy.unsqueeze_(2)], 2)
+
+     return meshed
+
+
+ class ResBlock2d(nn.Module):
+     """
+     Res block, preserve spatial resolution.
+     """
+
+     def __init__(self, in_features, kernel_size, padding):
+         super(ResBlock2d, self).__init__()
+         self.conv1 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,
+                                padding=padding)
+         self.conv2 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,
+                                padding=padding)
+         self.norm1 = nn.InstanceNorm2d(in_features, affine=True)
+         self.norm2 = nn.InstanceNorm2d(in_features, affine=True)
+
+     def forward(self, x):
+         out = self.norm1(x)
+         out = F.relu(out)
+         out = self.conv1(out)
+         out = self.norm2(out)
+         out = F.relu(out)
+         out = self.conv2(out)
+         out += x
+         return out
+
+
+ class UpBlock2d(nn.Module):
+     """
+     Upsampling block for use in decoder.
+     """
+
+     def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
+         super(UpBlock2d, self).__init__()
+
+         self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
+                               padding=padding, groups=groups)
+         self.norm = nn.InstanceNorm2d(out_features, affine=True)
+
+     def forward(self, x):
+         out = F.interpolate(x, scale_factor=2)
+         out = self.conv(out)
+         out = self.norm(out)
+         out = F.relu(out)
+         return out
+
+
+ class DownBlock2d(nn.Module):
+     """
+     Downsampling block for use in encoder.
+     """
+
+     def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
+         super(DownBlock2d, self).__init__()
+         self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
+                               padding=padding, groups=groups)
+         self.norm = nn.InstanceNorm2d(out_features, affine=True)
+         self.pool = nn.AvgPool2d(kernel_size=(2, 2))
+
+     def forward(self, x):
+         out = self.conv(x)
+         out = self.norm(out)
+         out = F.relu(out)
+         out = self.pool(out)
+         return out
+
+
+ class SameBlock2d(nn.Module):
+     """
+     Simple block, preserve spatial resolution.
+     """
+
+     def __init__(self, in_features, out_features, groups=1, kernel_size=3, padding=1):
+         super(SameBlock2d, self).__init__()
+         self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features,
+                               kernel_size=kernel_size, padding=padding, groups=groups)
+         self.norm = nn.InstanceNorm2d(out_features, affine=True)
+
+     def forward(self, x):
+         out = self.conv(x)
+         out = self.norm(out)
+         out = F.relu(out)
+         return out
+
+
+ class Encoder(nn.Module):
+     """
+     Hourglass Encoder
+     """
+
+     def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
+         super(Encoder, self).__init__()
+
+         down_blocks = []
+         for i in range(num_blocks):
+             down_blocks.append(DownBlock2d(in_features if i == 0 else min(max_features, block_expansion * (2 ** i)),
+                                            min(max_features, block_expansion * (2 ** (i + 1))),
+                                            kernel_size=3, padding=1))
+         self.down_blocks = nn.ModuleList(down_blocks)
+
+     def forward(self, x):
+         outs = [x]
+         # print('encoder:', outs[-1].shape)
+         for down_block in self.down_blocks:
+             outs.append(down_block(outs[-1]))
+             # print('encoder:', outs[-1].shape)
+         return outs
+
+
+ class Decoder(nn.Module):
+     """
+     Hourglass Decoder
+     """
+
+     def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
+         super(Decoder, self).__init__()
+
+         up_blocks = []
+         self.out_channels = []
+         for i in range(num_blocks)[::-1]:
+             in_filters = (1 if i == num_blocks - 1 else 2) * min(max_features, block_expansion * (2 ** (i + 1)))
+             self.out_channels.append(in_filters)
+             out_filters = min(max_features, block_expansion * (2 ** i))
+             up_blocks.append(UpBlock2d(in_filters, out_filters, kernel_size=3, padding=1))
+
+         self.up_blocks = nn.ModuleList(up_blocks)
+         self.out_channels.append(block_expansion + in_features)
+         # self.out_filters = block_expansion + in_features
+
+     def forward(self, x, mode = 0):
+         out = x.pop()
+         outs = []
+         for up_block in self.up_blocks:
+             out = up_block(out)
+             skip = x.pop()
+             out = torch.cat([out, skip], dim=1)
+             outs.append(out)
+         if(mode == 0):
+             return out
+         else:
+             return outs
+
+
+ class Hourglass(nn.Module):
+     """
+     Hourglass architecture.
+     """
+
+     def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
+         super(Hourglass, self).__init__()
+         self.encoder = Encoder(block_expansion, in_features, num_blocks, max_features)
+         self.decoder = Decoder(block_expansion, in_features, num_blocks, max_features)
+         self.out_channels = self.decoder.out_channels
+         # self.out_filters = self.decoder.out_filters
+
+     def forward(self, x, mode = 0):
+         return self.decoder(self.encoder(x), mode)
+
+
+ class AntiAliasInterpolation2d(nn.Module):
+     """
+     Band-limited downsampling, for better preservation of the input signal.
+     """
+     def __init__(self, channels, scale):
+         super(AntiAliasInterpolation2d, self).__init__()
+         sigma = (1 / scale - 1) / 2
+         kernel_size = 2 * round(sigma * 4) + 1
+         self.ka = kernel_size // 2
+         self.kb = self.ka - 1 if kernel_size % 2 == 0 else self.ka
+
+         kernel_size = [kernel_size, kernel_size]
+         sigma = [sigma, sigma]
+         # The gaussian kernel is the product of the
+         # gaussian function of each dimension.
+         kernel = 1
+         meshgrids = torch.meshgrid(
+             [
+                 torch.arange(size, dtype=torch.float32)
+                 for size in kernel_size
+             ]
+         )
+         for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
+             mean = (size - 1) / 2
+             kernel *= torch.exp(-(mgrid - mean) ** 2 / (2 * std ** 2))
+
+         # Make sure sum of values in gaussian kernel equals 1.
+         kernel = kernel / torch.sum(kernel)
+         # Reshape to depthwise convolutional weight
+         kernel = kernel.view(1, 1, *kernel.size())
+         kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))
+
+         self.register_buffer('weight', kernel)
+         self.groups = channels
+         self.scale = scale
+
+     def forward(self, input):
+         if self.scale == 1.0:
+             return input
+
+         out = F.pad(input, (self.ka, self.kb, self.ka, self.kb))
+         out = F.conv2d(out, weight=self.weight, groups=self.groups)
+         out = F.interpolate(out, scale_factor=(self.scale, self.scale))
+
+         return out
+
+
+ def to_homogeneous(coordinates):
+     ones_shape = list(coordinates.shape)
+     ones_shape[-1] = 1
+     ones = torch.ones(ones_shape).type(coordinates.type())
+
+     return torch.cat([coordinates, ones], dim=-1)
+
+ def from_homogeneous(coordinates):
+     return coordinates[..., :2] / coordinates[..., 2:3]
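Illustrative only (not part of this commit): how modules/model.py exercises TPS in 'random' mode for the equivariance loss, using the transform_params values from config/vox-256.yaml. The batch of two random 256x256 frames is an arbitrary assumption.

import torch
import torch.nn.functional as F

from modules.util import TPS

frames = torch.rand(2, 3, 256, 256)                      # stand-in batch of driving frames
tps = TPS(mode='random', bs=2, sigma_affine=0.05, sigma_tps=0.005, points_tps=5)
grid = tps.transform_frame(frames)                       # (2, 256, 256, 2) sampling grid
warped = F.grid_sample(frames, grid, padding_mode="reflection", align_corners=True)
print(warped.shape)                                      # torch.Size([2, 3, 256, 256])

In 'kp' mode the same class fits the TPS parameters from driving/source keypoint pairs, which is what DenseMotionNetwork.create_transformations does above.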
requirements.txt ADDED
@@ -0,0 +1,27 @@
+ gradio
+ cffi==1.14.6
+ cycler==0.10.0
+ decorator==5.1.0
+ face-alignment==1.3.5
+ imageio==2.9.0
+ imageio-ffmpeg==0.4.5
+ kiwisolver==1.3.2
+ matplotlib==3.4.3
+ networkx==2.6.3
+ numpy==1.20.3
+ pandas==1.3.3
+ Pillow==8.3.2
+ pycparser==2.20
+ pyparsing==2.4.7
+ python-dateutil==2.8.2
+ pytz==2021.1
+ PyWavelets==1.1.1
+ PyYAML==5.4.1
+ scikit-image==0.18.3
+ scikit-learn==1.0
+ scipy==1.7.1
+ six==1.16.0
+ torch==1.10.0+cu113
+ torchvision==0.11.0+cu113
+ tqdm==4.62.3
+ face-alignment