Spaces:
Runtime error
Tolga committed on
Commit 280b585 · 1 Parent(s): 79da588
version1
- app.py +109 -0
- config/vox-256.yaml +74 -0
- model.py +119 -0
- modules/.DS_Store +0 -0
- modules/__pycache__/avd_network.cpython-310.pyc +0 -0
- modules/__pycache__/dense_motion.cpython-310.pyc +0 -0
- modules/__pycache__/inpainting_network.cpython-310.pyc +0 -0
- modules/__pycache__/keypoint_detector.cpython-310.pyc +0 -0
- modules/__pycache__/util.cpython-310.pyc +0 -0
- modules/avd_network.py +65 -0
- modules/bg_motion_predictor.py +24 -0
- modules/dense_motion.py +164 -0
- modules/inpainting_network.py +127 -0
- modules/keypoint_detector.py +27 -0
- modules/model.py +182 -0
- modules/util.py +349 -0
- requirements.txt +27 -0
app.py
ADDED
@@ -0,0 +1,109 @@
import gradio as gr
import torch

import imageio
import imageio_ffmpeg
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from skimage.transform import resize
# from IPython.display import HTML  # only needed for the commented-out preview below; IPython is not in requirements.txt
import warnings
import os
from model import load_checkpoints
from model import make_animation
from skimage import img_as_ubyte
from PIL import Image
import time
warnings.filterwarnings("ignore")

device = torch.device('cpu')  # switch to torch.device('cuda:0') when a GPU is available


dataset_name = 'vox'  # ['vox', 'taichi', 'ted', 'mgif']
source_image_path = './assets/source.png'
driving_video_path = './assets/driving.mp4'
output_video_path = './generated.mp4'
config_path = './config/vox-256.yaml'
checkpoint_path = 'vox.pth.tar'  # checkpoint is expected next to app.py
predict_mode = 'relative'  # ['standard', 'relative', 'avd']
find_best_frame = False  # in relative mode, find_best_frame=True can give better results when animating faces

pixel = 256  # for vox, taichi and mgif the resolution is 256*256
if(dataset_name == 'ted'):  # for ted the resolution is 384*384
    pixel = 384

if find_best_frame:
    #!pip install face_alignment
    pass


def create_video(tt):

    source_image = imageio.imread(f"assets/img_{tt}.jpg")
    reader = imageio.get_reader(f"assets/ref_{tt}.mp4")

    source_image = resize(source_image, (pixel, pixel))[..., :3]

    fps = reader.get_meta_data()['fps']
    driving_video = []
    try:
        for im in reader:
            driving_video.append(im)
    except RuntimeError:
        pass
    reader.close()

    driving_video = [resize(frame, (pixel, pixel))[..., :3] for frame in driving_video]

    def display(source, driving, generated=None):
        fig = plt.figure(figsize=(8 + 4 * (generated is not None), 6))

        ims = []
        for i in range(len(driving)):
            cols = [source]
            cols.append(driving[i])
            if generated is not None:
                cols.append(generated[i])
            im = plt.imshow(np.concatenate(cols, axis=1), animated=True)
            plt.axis('off')
            ims.append([im])

        ani = animation.ArtistAnimation(fig, ims, interval=50, repeat_delay=1000)
        plt.close()
        return ani


    # HTML(display(source_image, driving_video).to_html5_video())
    inpainting, kp_detector, dense_motion_network, avd_network = load_checkpoints(config_path=config_path, checkpoint_path=checkpoint_path, device=device)


    if predict_mode == 'relative' and find_best_frame:
        from model import find_best_frame as _find
        i = _find(source_image, driving_video, device.type == 'cpu')
        print("Best frame: " + str(i))
        driving_forward = driving_video[i:]
        driving_backward = driving_video[:(i+1)][::-1]
        predictions_forward = make_animation(source_image, driving_forward, inpainting, kp_detector, dense_motion_network, avd_network, device=device, mode=predict_mode)
        predictions_backward = make_animation(source_image, driving_backward, inpainting, kp_detector, dense_motion_network, avd_network, device=device, mode=predict_mode)
        predictions = predictions_backward[::-1] + predictions_forward[1:]
    else:
        predictions = make_animation(source_image, driving_video, inpainting, kp_detector, dense_motion_network, avd_network, device=device, mode=predict_mode)

    # save the resulting video
    imageio.mimsave(f"./assets/output_{tt}.mp4", [img_as_ubyte(frame) for frame in predictions], fps=fps)


def greet(img, video):
    tt = str(time.time())
    os.replace(video, f"assets/ref_{tt}.mp4")
    img.save(f"assets/img_{tt}.jpg")
    create_video(tt)
    return f"./assets/output_{tt}.mp4"


iface = gr.Interface(fn=greet, inputs=[gr.inputs.Image(type="pil"), gr.inputs.Video()], outputs=gr.outputs.Video())  # the output must be an output component
iface.launch()
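For local debugging it can help to run the same pipeline that app.py wires up, but without Gradio. The sketch below is an illustration only, not part of the commit: it assumes a portrait at test.png, a short clip at test.mp4 and the vox.pth.tar checkpoint in the working directory, and it reuses load_checkpoints and make_animation from model.py.

# Hypothetical headless run of the same pipeline app.py uses (illustration only).
import imageio
import torch
from skimage import img_as_ubyte
from skimage.transform import resize

from model import load_checkpoints, make_animation

device = torch.device('cpu')
nets = load_checkpoints(config_path='./config/vox-256.yaml',
                        checkpoint_path='vox.pth.tar', device=device)
inpainting, kp_detector, dense_motion_network, avd_network = nets

# test.png / test.mp4 are placeholder inputs you supply yourself
source = resize(imageio.imread('test.png'), (256, 256))[..., :3]
reader = imageio.get_reader('test.mp4')
fps = reader.get_meta_data()['fps']
driving = [resize(frame, (256, 256))[..., :3] for frame in reader]
reader.close()

predictions = make_animation(source, driving, inpainting, kp_detector,
                             dense_motion_network, avd_network,
                             device=device, mode='relative')
imageio.mimsave('headless_output.mp4',
                [img_as_ubyte(frame) for frame in predictions], fps=fps)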
config/vox-256.yaml
ADDED
@@ -0,0 +1,74 @@
dataset_params:
  root_dir: ../vox
  frame_shape: null
  id_sampling: True
  augmentation_params:
    flip_param:
      horizontal_flip: True
      time_flip: True
    jitter_param:
      brightness: 0.1
      contrast: 0.1
      saturation: 0.1
      hue: 0.1


model_params:
  common_params:
    num_tps: 10
    num_channels: 3
    bg: True
    multi_mask: True
  generator_params:
    block_expansion: 64
    max_features: 512
    num_down_blocks: 3
  dense_motion_params:
    block_expansion: 64
    max_features: 1024
    num_blocks: 5
    scale_factor: 0.25
  avd_network_params:
    id_bottle_size: 128
    pose_bottle_size: 128


train_params:
  num_epochs: 100
  num_repeats: 75
  epoch_milestones: [70, 90]
  lr_generator: 2.0e-4
  batch_size: 28
  scales: [1, 0.5, 0.25, 0.125]
  dataloader_workers: 12
  checkpoint_freq: 50
  dropout_epoch: 35
  dropout_maxp: 0.3
  dropout_startp: 0.1
  dropout_inc_epoch: 10
  bg_start: 10
  transform_params:
    sigma_affine: 0.05
    sigma_tps: 0.005
    points_tps: 5
  loss_weights:
    perceptual: [10, 10, 10, 10, 10]
    equivariance_value: 10
    warp_loss: 10
    bg: 10

train_avd_params:
  num_epochs: 200
  num_repeats: 300
  batch_size: 256
  dataloader_workers: 24
  checkpoint_freq: 50
  epoch_milestones: [140, 180]
  lr: 1.0e-3
  lambda_shift: 1
  random_scale: 0.25

visualizer_params:
  kp_size: 5
  draw_border: True
  colormap: 'gist_rainbow'
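The model_params block above is exactly what model.py unpacks when it builds the networks. The sketch below is a reading aid, not new functionality: it loads the YAML and shows which sub-dictionary feeds which constructor, mirroring load_checkpoints, and it assumes it is run from the repository root.

# Illustration of how config sections map to module constructors (run from the repo root).
import yaml

from modules.inpainting_network import InpaintingNetwork
from modules.keypoint_detector import KPDetector
from modules.dense_motion import DenseMotionNetwork
from modules.avd_network import AVDNetwork

with open('./config/vox-256.yaml') as f:
    config = yaml.full_load(f)

common = config['model_params']['common_params']  # num_tps, num_channels, bg, multi_mask
inpainting = InpaintingNetwork(**config['model_params']['generator_params'], **common)
kp_detector = KPDetector(**common)
dense_motion = DenseMotionNetwork(**config['model_params']['dense_motion_params'], **common)
avd = AVDNetwork(num_tps=common['num_tps'], **config['model_params']['avd_network_params'])

print('keypoints per frame:', common['num_tps'] * 5)  # 10 TPS transforms x 5 points = 50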
model.py
ADDED
@@ -0,0 +1,119 @@
import matplotlib
matplotlib.use('Agg')
import sys
import yaml
from argparse import ArgumentParser
from tqdm import tqdm
from scipy.spatial import ConvexHull
import numpy as np
import imageio
from skimage.transform import resize
from skimage import img_as_ubyte
import torch
from modules.inpainting_network import InpaintingNetwork
from modules.keypoint_detector import KPDetector
from modules.dense_motion import DenseMotionNetwork
from modules.avd_network import AVDNetwork

def load_checkpoints(config_path, checkpoint_path, device):
    with open(config_path) as f:
        config = yaml.full_load(f)

    inpainting = InpaintingNetwork(**config['model_params']['generator_params'],
                                   **config['model_params']['common_params'])
    kp_detector = KPDetector(**config['model_params']['common_params'])
    dense_motion_network = DenseMotionNetwork(**config['model_params']['common_params'],
                                              **config['model_params']['dense_motion_params'])
    avd_network = AVDNetwork(num_tps=config['model_params']['common_params']['num_tps'],
                             **config['model_params']['avd_network_params'])
    kp_detector.to(device)
    dense_motion_network.to(device)
    inpainting.to(device)
    avd_network.to(device)

    checkpoint = torch.load(checkpoint_path, map_location=device)

    inpainting.load_state_dict(checkpoint['inpainting_network'])
    kp_detector.load_state_dict(checkpoint['kp_detector'])
    dense_motion_network.load_state_dict(checkpoint['dense_motion_network'])
    if 'avd_network' in checkpoint:
        avd_network.load_state_dict(checkpoint['avd_network'])

    inpainting.eval()
    kp_detector.eval()
    dense_motion_network.eval()
    avd_network.eval()

    return inpainting, kp_detector, dense_motion_network, avd_network

def relative_kp(kp_source, kp_driving, kp_driving_initial):

    source_area = ConvexHull(kp_source['fg_kp'][0].data.cpu().numpy()).volume
    driving_area = ConvexHull(kp_driving_initial['fg_kp'][0].data.cpu().numpy()).volume
    adapt_movement_scale = np.sqrt(source_area) / np.sqrt(driving_area)

    kp_new = {k: v for k, v in kp_driving.items()}

    kp_value_diff = (kp_driving['fg_kp'] - kp_driving_initial['fg_kp'])
    kp_value_diff *= adapt_movement_scale
    kp_new['fg_kp'] = kp_value_diff + kp_source['fg_kp']

    return kp_new

def make_animation(source_image, driving_video, inpainting_network, kp_detector, dense_motion_network, avd_network, device, mode='relative'):
    assert mode in ['standard', 'relative', 'avd']
    with torch.no_grad():
        predictions = []
        source = torch.tensor(source_image[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2)
        source = source.to(device)
        driving = torch.tensor(np.array(driving_video)[np.newaxis].astype(np.float32)).permute(0, 4, 1, 2, 3).to(device)
        kp_source = kp_detector(source)
        kp_driving_initial = kp_detector(driving[:, :, 0])

        for frame_idx in tqdm(range(driving.shape[2])):
            driving_frame = driving[:, :, frame_idx]
            driving_frame = driving_frame.to(device)
            kp_driving = kp_detector(driving_frame)
            if mode == 'standard':
                kp_norm = kp_driving
            elif mode == 'relative':
                kp_norm = relative_kp(kp_source=kp_source, kp_driving=kp_driving,
                                      kp_driving_initial=kp_driving_initial)
            elif mode == 'avd':
                kp_norm = avd_network(kp_source, kp_driving)
            dense_motion = dense_motion_network(source_image=source, kp_driving=kp_norm,
                                                kp_source=kp_source, bg_param=None,
                                                dropout_flag=False)
            out = inpainting_network(source, dense_motion)

            predictions.append(np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0])
        return predictions


def find_best_frame(source, driving, cpu):
    import face_alignment

    def normalize_kp(kp):
        kp = kp - kp.mean(axis=0, keepdims=True)
        area = ConvexHull(kp[:, :2]).volume
        area = np.sqrt(area)
        kp[:, :2] = kp[:, :2] / area
        return kp

    fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=True,
                                      device='cpu' if cpu else 'cuda')
    kp_source = fa.get_landmarks(255 * source)[0]
    kp_source = normalize_kp(kp_source)
    norm = float('inf')
    frame_num = 0
    for i, image in tqdm(enumerate(driving)):
        try:
            kp_driving = fa.get_landmarks(255 * image)[0]
            kp_driving = normalize_kp(kp_driving)
            new_norm = (np.abs(kp_source - kp_driving) ** 2).sum()
            if new_norm < norm:
                norm = new_norm
                frame_num = i
        except:
            pass
    return frame_num
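relative_kp is what makes 'relative' mode robust to a size mismatch between the source face and the first driving frame: the keypoint displacement is rescaled by the square root of the ratio of the two convex-hull areas before being added to the source keypoints. A toy check with made-up keypoint tensors, assuming it is run from the repository root (illustration only):

# Toy check of relative_kp with random keypoints (illustration only).
import torch
from model import relative_kp

torch.manual_seed(0)
kp_source = {'fg_kp': torch.rand(1, 50, 2) * 2 - 1}            # 50 keypoints in [-1, 1]
kp_driving_initial = {'fg_kp': torch.rand(1, 50, 2) * 2 - 1}
# driving keypoints = initial driving keypoints shifted by a constant offset
kp_driving = {'fg_kp': kp_driving_initial['fg_kp'] + 0.1}

kp_new = relative_kp(kp_source, kp_driving, kp_driving_initial)
# the 0.1 offset is rescaled by sqrt(source_area / driving_area), not copied verbatim
print((kp_new['fg_kp'] - kp_source['fg_kp']).mean())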
modules/.DS_Store
ADDED
Binary file (6.15 kB)
modules/__pycache__/avd_network.cpython-310.pyc
ADDED
Binary file (1.57 kB)
modules/__pycache__/dense_motion.cpython-310.pyc
ADDED
Binary file (5.67 kB)
modules/__pycache__/inpainting_network.cpython-310.pyc
ADDED
Binary file (3.75 kB)
modules/__pycache__/keypoint_detector.cpython-310.pyc
ADDED
Binary file (1.12 kB)
modules/__pycache__/util.cpython-310.pyc
ADDED
Binary file (10.8 kB)
modules/avd_network.py
ADDED
@@ -0,0 +1,65 @@
import torch
from torch import nn


class AVDNetwork(nn.Module):
    """
    Animation via Disentanglement network
    """

    def __init__(self, num_tps, id_bottle_size=64, pose_bottle_size=64):
        super(AVDNetwork, self).__init__()
        input_size = 5*2 * num_tps
        self.num_tps = num_tps

        self.id_encoder = nn.Sequential(
            nn.Linear(input_size, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, id_bottle_size)
        )

        self.pose_encoder = nn.Sequential(
            nn.Linear(input_size, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, pose_bottle_size)
        )

        self.decoder = nn.Sequential(
            nn.Linear(pose_bottle_size + id_bottle_size, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Linear(256, input_size)
        )

    def forward(self, kp_source, kp_random):

        bs = kp_source['fg_kp'].shape[0]

        pose_emb = self.pose_encoder(kp_random['fg_kp'].view(bs, -1))
        id_emb = self.id_encoder(kp_source['fg_kp'].view(bs, -1))

        rec = self.decoder(torch.cat([pose_emb, id_emb], dim=1))

        rec = {'fg_kp': rec.view(bs, self.num_tps*5, -1)}
        return rec
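AVDNetwork only ever sees flattened keypoint dictionaries: both encoders take a (batch, num_tps*5*2) vector, and the decoder reconstructs keypoints of the same shape, combining identity from kp_source with pose from kp_random. A quick shape check with random inputs, assuming it is run from the repository root (illustration only):

# Shape check for AVDNetwork with random keypoint dictionaries (illustration only).
import torch
from modules.avd_network import AVDNetwork

avd = AVDNetwork(num_tps=10, id_bottle_size=128, pose_bottle_size=128).eval()
kp_source = {'fg_kp': torch.rand(2, 50, 2) * 2 - 1}
kp_random = {'fg_kp': torch.rand(2, 50, 2) * 2 - 1}

with torch.no_grad():
    rec = avd(kp_source, kp_random)
print(rec['fg_kp'].shape)   # torch.Size([2, 50, 2])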
modules/bg_motion_predictor.py
ADDED
@@ -0,0 +1,24 @@
from torch import nn
import torch
from torchvision import models

class BGMotionPredictor(nn.Module):
    """
    Module for background estimation; returns a single transformation parametrized as a 3x3 matrix. The third row is [0 0 1].
    """

    def __init__(self):
        super(BGMotionPredictor, self).__init__()
        self.bg_encoder = models.resnet18(pretrained=False)
        self.bg_encoder.conv1 = nn.Conv2d(6, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        num_features = self.bg_encoder.fc.in_features
        self.bg_encoder.fc = nn.Linear(num_features, 6)
        self.bg_encoder.fc.weight.data.zero_()
        self.bg_encoder.fc.bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))

    def forward(self, source_image, driving_image):
        bs = source_image.shape[0]
        out = torch.eye(3).unsqueeze(0).repeat(bs, 1, 1).type(source_image.type())
        prediction = self.bg_encoder(torch.cat([source_image, driving_image], dim=1))
        out[:, :2, :] = prediction.view(bs, 2, 3)
        return out
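Because the final fully connected layer is zero-initialized with bias [1, 0, 0, 0, 1, 0], the predictor outputs an exact identity transform for every frame pair until training updates it, which is what allows the background branch to be switched on mid-training (bg_start in the config). A small check, run from the repository root (illustration only):

# The freshly constructed predictor returns identity 3x3 matrices (illustration only).
import torch
from modules.bg_motion_predictor import BGMotionPredictor

bg = BGMotionPredictor().eval()
src = torch.rand(2, 3, 64, 64)
drv = torch.rand(2, 3, 64, 64)
with torch.no_grad():
    out = bg(src, drv)
print(out[0])                                               # identity matrix
print(torch.allclose(out, torch.eye(3).expand(2, 3, 3)))    # True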
modules/dense_motion.py
ADDED
@@ -0,0 +1,164 @@
from torch import nn
import torch.nn.functional as F
import torch
from modules.util import Hourglass, AntiAliasInterpolation2d, make_coordinate_grid, kp2gaussian
from modules.util import to_homogeneous, from_homogeneous, UpBlock2d, TPS
import math

class DenseMotionNetwork(nn.Module):
    """
    Module that estimates an optical flow and multi-resolution occlusion masks
    from K TPS transformations and an affine transformation.
    """

    def __init__(self, block_expansion, num_blocks, max_features, num_tps, num_channels,
                 scale_factor=0.25, bg=False, multi_mask=True, kp_variance=0.01):
        super(DenseMotionNetwork, self).__init__()

        if scale_factor != 1:
            self.down = AntiAliasInterpolation2d(num_channels, scale_factor)
        self.scale_factor = scale_factor
        self.multi_mask = multi_mask

        self.hourglass = Hourglass(block_expansion=block_expansion, in_features=(num_channels * (num_tps+1) + num_tps*5+1),
                                   max_features=max_features, num_blocks=num_blocks)

        hourglass_output_size = self.hourglass.out_channels
        self.maps = nn.Conv2d(hourglass_output_size[-1], num_tps + 1, kernel_size=(7, 7), padding=(3, 3))

        if multi_mask:
            up = []
            self.up_nums = int(math.log(1/scale_factor, 2))
            self.occlusion_num = 4

            channel = [hourglass_output_size[-1]//(2**i) for i in range(self.up_nums)]
            for i in range(self.up_nums):
                up.append(UpBlock2d(channel[i], channel[i]//2, kernel_size=3, padding=1))
            self.up = nn.ModuleList(up)

            channel = [hourglass_output_size[-i-1] for i in range(self.occlusion_num-self.up_nums)[::-1]]
            for i in range(self.up_nums):
                channel.append(hourglass_output_size[-1]//(2**(i+1)))
            occlusion = []

            for i in range(self.occlusion_num):
                occlusion.append(nn.Conv2d(channel[i], 1, kernel_size=(7, 7), padding=(3, 3)))
            self.occlusion = nn.ModuleList(occlusion)
        else:
            occlusion = [nn.Conv2d(hourglass_output_size[-1], 1, kernel_size=(7, 7), padding=(3, 3))]
            self.occlusion = nn.ModuleList(occlusion)

        self.num_tps = num_tps
        self.bg = bg
        self.kp_variance = kp_variance


    def create_heatmap_representations(self, source_image, kp_driving, kp_source):

        spatial_size = source_image.shape[2:]
        gaussian_driving = kp2gaussian(kp_driving['fg_kp'], spatial_size=spatial_size, kp_variance=self.kp_variance)
        gaussian_source = kp2gaussian(kp_source['fg_kp'], spatial_size=spatial_size, kp_variance=self.kp_variance)
        heatmap = gaussian_driving - gaussian_source

        zeros = torch.zeros(heatmap.shape[0], 1, spatial_size[0], spatial_size[1]).type(heatmap.type()).to(heatmap.device)
        heatmap = torch.cat([zeros, heatmap], dim=1)

        return heatmap

    def create_transformations(self, source_image, kp_driving, kp_source, bg_param):
        # K TPS transformations
        bs, _, h, w = source_image.shape
        kp_1 = kp_driving['fg_kp']
        kp_2 = kp_source['fg_kp']
        kp_1 = kp_1.view(bs, -1, 5, 2)
        kp_2 = kp_2.view(bs, -1, 5, 2)
        trans = TPS(mode='kp', bs=bs, kp_1=kp_1, kp_2=kp_2)
        driving_to_source = trans.transform_frame(source_image)

        identity_grid = make_coordinate_grid((h, w), type=kp_1.type()).to(kp_1.device)
        identity_grid = identity_grid.view(1, 1, h, w, 2)
        identity_grid = identity_grid.repeat(bs, 1, 1, 1, 1)

        # affine background transformation
        if not (bg_param is None):
            identity_grid = to_homogeneous(identity_grid)
            identity_grid = torch.matmul(bg_param.view(bs, 1, 1, 1, 3, 3), identity_grid.unsqueeze(-1)).squeeze(-1)
            identity_grid = from_homogeneous(identity_grid)

        transformations = torch.cat([identity_grid, driving_to_source], dim=1)
        return transformations

    def create_deformed_source_image(self, source_image, transformations):

        bs, _, h, w = source_image.shape
        source_repeat = source_image.unsqueeze(1).unsqueeze(1).repeat(1, self.num_tps + 1, 1, 1, 1, 1)
        source_repeat = source_repeat.view(bs * (self.num_tps + 1), -1, h, w)
        transformations = transformations.view((bs * (self.num_tps + 1), h, w, -1))
        deformed = F.grid_sample(source_repeat, transformations, align_corners=True)
        deformed = deformed.view((bs, self.num_tps+1, -1, h, w))
        return deformed

    def dropout_softmax(self, X, P):
        '''
        Dropout for TPS transformations. Eq(7) and Eq(8) in the paper.
        '''
        drop = (torch.rand(X.shape[0], X.shape[1]) < (1-P)).type(X.type()).to(X.device)
        drop[..., 0] = 1
        drop = drop.repeat(X.shape[2], X.shape[3], 1, 1).permute(2, 3, 0, 1)

        maxx = X.max(1).values.unsqueeze_(1)
        X = X - maxx
        X_exp = X.exp()
        X[:, 1:, ...] /= (1-P)
        mask_bool = (drop == 0)
        X_exp = X_exp.masked_fill(mask_bool, 0)
        partition = X_exp.sum(dim=1, keepdim=True) + 1e-6
        return X_exp / partition

    def forward(self, source_image, kp_driving, kp_source, bg_param=None, dropout_flag=False, dropout_p=0):
        if self.scale_factor != 1:
            source_image = self.down(source_image)

        bs, _, h, w = source_image.shape

        out_dict = dict()
        heatmap_representation = self.create_heatmap_representations(source_image, kp_driving, kp_source)
        transformations = self.create_transformations(source_image, kp_driving, kp_source, bg_param)
        deformed_source = self.create_deformed_source_image(source_image, transformations)
        out_dict['deformed_source'] = deformed_source
        # out_dict['transformations'] = transformations
        deformed_source = deformed_source.view(bs, -1, h, w)
        input = torch.cat([heatmap_representation, deformed_source], dim=1)
        input = input.view(bs, -1, h, w)

        prediction = self.hourglass(input, mode=1)

        contribution_maps = self.maps(prediction[-1])
        if(dropout_flag):
            contribution_maps = self.dropout_softmax(contribution_maps, dropout_p)
        else:
            contribution_maps = F.softmax(contribution_maps, dim=1)
        out_dict['contribution_maps'] = contribution_maps

        # Combine the K+1 transformations
        # Eq(6) in the paper
        contribution_maps = contribution_maps.unsqueeze(2)
        transformations = transformations.permute(0, 1, 4, 2, 3)
        deformation = (transformations * contribution_maps).sum(dim=1)
        deformation = deformation.permute(0, 2, 3, 1)

        out_dict['deformation'] = deformation  # Optical Flow

        occlusion_map = []
        if self.multi_mask:
            for i in range(self.occlusion_num-self.up_nums):
                occlusion_map.append(torch.sigmoid(self.occlusion[i](prediction[self.up_nums-self.occlusion_num+i])))
            prediction = prediction[-1]
            for i in range(self.up_nums):
                prediction = self.up[i](prediction)
                occlusion_map.append(torch.sigmoid(self.occlusion[i+self.occlusion_num-self.up_nums](prediction)))
        else:
            occlusion_map.append(torch.sigmoid(self.occlusion[0](prediction[-1])))

        out_dict['occlusion_map'] = occlusion_map  # Multi-resolution Occlusion Masks
        return out_dict
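forward returns a dictionary rather than a single flow: 'deformation' is the combined optical flow on the downscaled grid, and 'occlusion_map' is a list of masks at several resolutions when multi_mask is on. The sketch below instantiates the network with the vox-256 settings and prints what comes out for a dummy batch; it is an illustration only, assumes it runs from the repository root, and simply reports whatever shapes the code produces.

# Inspect DenseMotionNetwork outputs for a dummy 256x256 batch (illustration only).
import torch
from modules.dense_motion import DenseMotionNetwork

net = DenseMotionNetwork(block_expansion=64, num_blocks=5, max_features=1024,
                         num_tps=10, num_channels=3, scale_factor=0.25,
                         bg=True, multi_mask=True).eval()

source = torch.rand(1, 3, 256, 256)
kp_driving = {'fg_kp': torch.rand(1, 50, 2) * 2 - 1}
kp_source = {'fg_kp': torch.rand(1, 50, 2) * 2 - 1}

with torch.no_grad():
    out = net(source_image=source, kp_driving=kp_driving, kp_source=kp_source,
              bg_param=None, dropout_flag=False)

print(out['deformation'].shape)                   # flow on the downscaled grid
print([m.shape for m in out['occlusion_map']])    # one mask per resolution
print(out['contribution_maps'].shape)             # softmax over the K+1 transforms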
modules/inpainting_network.py
ADDED
@@ -0,0 +1,127 @@
import torch
from torch import nn
import torch.nn.functional as F
from modules.util import ResBlock2d, SameBlock2d, UpBlock2d, DownBlock2d
from modules.dense_motion import DenseMotionNetwork


class InpaintingNetwork(nn.Module):
    """
    Inpaint the missing regions and reconstruct the Driving image.
    """
    def __init__(self, num_channels, block_expansion, max_features, num_down_blocks, multi_mask=True, **kwargs):
        super(InpaintingNetwork, self).__init__()

        self.num_down_blocks = num_down_blocks
        self.multi_mask = multi_mask
        self.first = SameBlock2d(num_channels, block_expansion, kernel_size=(7, 7), padding=(3, 3))

        down_blocks = []
        up_blocks = []
        resblock = []
        for i in range(num_down_blocks):
            in_features = min(max_features, block_expansion * (2 ** i))
            out_features = min(max_features, block_expansion * (2 ** (i + 1)))
            down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1)))
            decoder_in_feature = out_features * 2
            if i == num_down_blocks-1:
                decoder_in_feature = out_features
            up_blocks.append(UpBlock2d(decoder_in_feature, in_features, kernel_size=(3, 3), padding=(1, 1)))
            resblock.append(ResBlock2d(decoder_in_feature, kernel_size=(3, 3), padding=(1, 1)))
            resblock.append(ResBlock2d(decoder_in_feature, kernel_size=(3, 3), padding=(1, 1)))
        self.down_blocks = nn.ModuleList(down_blocks)
        self.up_blocks = nn.ModuleList(up_blocks[::-1])
        self.resblock = nn.ModuleList(resblock[::-1])

        self.final = nn.Conv2d(block_expansion, num_channels, kernel_size=(7, 7), padding=(3, 3))
        self.num_channels = num_channels

    def deform_input(self, inp, deformation):
        _, h_old, w_old, _ = deformation.shape
        _, _, h, w = inp.shape
        if h_old != h or w_old != w:
            deformation = deformation.permute(0, 3, 1, 2)
            deformation = F.interpolate(deformation, size=(h, w), mode='bilinear', align_corners=True)
            deformation = deformation.permute(0, 2, 3, 1)
        return F.grid_sample(inp, deformation, align_corners=True)

    def occlude_input(self, inp, occlusion_map):
        if not self.multi_mask:
            if inp.shape[2] != occlusion_map.shape[2] or inp.shape[3] != occlusion_map.shape[3]:
                occlusion_map = F.interpolate(occlusion_map, size=inp.shape[2:], mode='bilinear', align_corners=True)
        out = inp * occlusion_map
        return out

    def forward(self, source_image, dense_motion):
        out = self.first(source_image)
        encoder_map = [out]
        for i in range(len(self.down_blocks)):
            out = self.down_blocks[i](out)
            encoder_map.append(out)

        output_dict = {}
        output_dict['contribution_maps'] = dense_motion['contribution_maps']
        output_dict['deformed_source'] = dense_motion['deformed_source']

        occlusion_map = dense_motion['occlusion_map']
        output_dict['occlusion_map'] = occlusion_map

        deformation = dense_motion['deformation']
        out_ij = self.deform_input(out.detach(), deformation)
        out = self.deform_input(out, deformation)

        out_ij = self.occlude_input(out_ij, occlusion_map[0].detach())
        out = self.occlude_input(out, occlusion_map[0])

        warped_encoder_maps = []
        warped_encoder_maps.append(out_ij)

        for i in range(self.num_down_blocks):

            out = self.resblock[2*i](out)
            out = self.resblock[2*i+1](out)
            out = self.up_blocks[i](out)

            encode_i = encoder_map[-(i+2)]
            encode_ij = self.deform_input(encode_i.detach(), deformation)
            encode_i = self.deform_input(encode_i, deformation)

            occlusion_ind = 0
            if self.multi_mask:
                occlusion_ind = i+1
            encode_ij = self.occlude_input(encode_ij, occlusion_map[occlusion_ind].detach())
            encode_i = self.occlude_input(encode_i, occlusion_map[occlusion_ind])
            warped_encoder_maps.append(encode_ij)

            if(i == self.num_down_blocks-1):
                break

            out = torch.cat([out, encode_i], 1)

        deformed_source = self.deform_input(source_image, deformation)
        output_dict["deformed"] = deformed_source
        output_dict["warped_encoder_maps"] = warped_encoder_maps

        occlusion_last = occlusion_map[-1]
        if not self.multi_mask:
            occlusion_last = F.interpolate(occlusion_last, size=out.shape[2:], mode='bilinear', align_corners=True)

        out = out * (1 - occlusion_last) + encode_i
        out = self.final(out)
        out = torch.sigmoid(out)
        out = out * (1 - occlusion_last) + deformed_source * occlusion_last
        output_dict["prediction"] = out

        return output_dict

    def get_encode(self, driver_image, occlusion_map):
        out = self.first(driver_image)
        encoder_map = []
        encoder_map.append(self.occlude_input(out.detach(), occlusion_map[-1].detach()))
        for i in range(len(self.down_blocks)):
            out = self.down_blocks[i](out.detach())
            out_mask = self.occlude_input(out.detach(), occlusion_map[2-i].detach())
            encoder_map.append(out_mask.detach())

        return encoder_map
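The last lines of forward are the core compositing step: the inpainted decoder output is only used where the occlusion mask marks the warped source as unreliable, i.e. prediction = generated * (1 - mask) + warped_source * mask. A tiny standalone illustration of that blend with placeholder tensors (pure tensor math, no repo code needed):

# Standalone illustration of the occlusion-weighted blend used in forward().
import torch

warped_source = torch.rand(1, 3, 8, 8)   # source features warped by the flow
generated = torch.rand(1, 3, 8, 8)       # what the decoder inpainted
mask = torch.rand(1, 1, 8, 8)            # 1 = trust the warped source, 0 = use the inpainted output

prediction = generated * (1 - mask) + warped_source * mask
print(prediction.shape)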
modules/keypoint_detector.py
ADDED
@@ -0,0 +1,27 @@
from torch import nn
import torch
from torchvision import models

class KPDetector(nn.Module):
    """
    Predict K*5 keypoints.
    """

    def __init__(self, num_tps, **kwargs):
        super(KPDetector, self).__init__()
        self.num_tps = num_tps

        self.fg_encoder = models.resnet18(pretrained=False)
        num_features = self.fg_encoder.fc.in_features
        self.fg_encoder.fc = nn.Linear(num_features, num_tps*5*2)


    def forward(self, image):

        fg_kp = self.fg_encoder(image)
        bs, _, = fg_kp.shape
        fg_kp = torch.sigmoid(fg_kp)
        fg_kp = fg_kp * 2 - 1
        out = {'fg_kp': fg_kp.view(bs, self.num_tps*5, -1)}

        return out
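The detector is a plain ResNet-18 regressor: it outputs num_tps*5 = 50 keypoints per frame, squashed into [-1, 1] so they line up with the coordinate grid used by the TPS transforms. A minimal shape check, assuming it is run from the repository root (illustration only):

# Minimal shape check for KPDetector (illustration only).
import torch
from modules.keypoint_detector import KPDetector

kp_detector = KPDetector(num_tps=10).eval()
with torch.no_grad():
    out = kp_detector(torch.rand(1, 3, 256, 256))

print(out['fg_kp'].shape)                                      # torch.Size([1, 50, 2])
print(out['fg_kp'].min().item(), out['fg_kp'].max().item())    # values stay within [-1, 1]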
modules/model.py
ADDED
@@ -0,0 +1,182 @@
from torch import nn
import torch
import torch.nn.functional as F
from modules.util import AntiAliasInterpolation2d, TPS
from torchvision import models
import numpy as np


class Vgg19(torch.nn.Module):
    """
    Vgg19 network for perceptual loss. See Sec 3.3.
    """
    def __init__(self, requires_grad=False):
        super(Vgg19, self).__init__()
        vgg_pretrained_features = models.vgg19(pretrained=True).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
        for x in range(2):
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        for x in range(2, 7):
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        for x in range(7, 12):
            self.slice3.add_module(str(x), vgg_pretrained_features[x])
        for x in range(12, 21):
            self.slice4.add_module(str(x), vgg_pretrained_features[x])
        for x in range(21, 30):
            self.slice5.add_module(str(x), vgg_pretrained_features[x])

        self.mean = torch.nn.Parameter(data=torch.Tensor(np.array([0.485, 0.456, 0.406]).reshape((1, 3, 1, 1))),
                                       requires_grad=False)
        self.std = torch.nn.Parameter(data=torch.Tensor(np.array([0.229, 0.224, 0.225]).reshape((1, 3, 1, 1))),
                                      requires_grad=False)

        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        X = (X - self.mean) / self.std
        h_relu1 = self.slice1(X)
        h_relu2 = self.slice2(h_relu1)
        h_relu3 = self.slice3(h_relu2)
        h_relu4 = self.slice4(h_relu3)
        h_relu5 = self.slice5(h_relu4)
        out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
        return out


class ImagePyramide(torch.nn.Module):
    """
    Create image pyramide for computing pyramide perceptual loss. See Sec 3.3
    """
    def __init__(self, scales, num_channels):
        super(ImagePyramide, self).__init__()
        downs = {}
        for scale in scales:
            downs[str(scale).replace('.', '-')] = AntiAliasInterpolation2d(num_channels, scale)
        self.downs = nn.ModuleDict(downs)

    def forward(self, x):
        out_dict = {}
        for scale, down_module in self.downs.items():
            out_dict['prediction_' + str(scale).replace('-', '.')] = down_module(x)
        return out_dict


def detach_kp(kp):
    return {key: value.detach() for key, value in kp.items()}


class GeneratorFullModel(torch.nn.Module):
    """
    Merge all generator related updates into single model for better multi-gpu usage
    """

    def __init__(self, kp_extractor, bg_predictor, dense_motion_network, inpainting_network, train_params, *kwargs):
        super(GeneratorFullModel, self).__init__()
        self.kp_extractor = kp_extractor
        self.inpainting_network = inpainting_network
        self.dense_motion_network = dense_motion_network

        self.bg_predictor = None
        if bg_predictor:
            self.bg_predictor = bg_predictor
            self.bg_start = train_params['bg_start']

        self.train_params = train_params
        self.scales = train_params['scales']

        self.pyramid = ImagePyramide(self.scales, inpainting_network.num_channels)
        if torch.cuda.is_available():
            self.pyramid = self.pyramid.cuda()

        self.loss_weights = train_params['loss_weights']
        self.dropout_epoch = train_params['dropout_epoch']
        self.dropout_maxp = train_params['dropout_maxp']
        self.dropout_inc_epoch = train_params['dropout_inc_epoch']
        self.dropout_startp = train_params['dropout_startp']

        if sum(self.loss_weights['perceptual']) != 0:
            self.vgg = Vgg19()
            if torch.cuda.is_available():
                self.vgg = self.vgg.cuda()


    def forward(self, x, epoch):
        kp_source = self.kp_extractor(x['source'])
        kp_driving = self.kp_extractor(x['driving'])
        bg_param = None
        if self.bg_predictor:
            if(epoch >= self.bg_start):
                bg_param = self.bg_predictor(x['source'], x['driving'])

        if(epoch >= self.dropout_epoch):
            dropout_flag = False
            dropout_p = 0
        else:
            # dropout_p will linearly increase from dropout_startp to dropout_maxp
            dropout_flag = True
            dropout_p = min(epoch/self.dropout_inc_epoch * self.dropout_maxp + self.dropout_startp, self.dropout_maxp)

        dense_motion = self.dense_motion_network(source_image=x['source'], kp_driving=kp_driving,
                                                 kp_source=kp_source, bg_param=bg_param,
                                                 dropout_flag=dropout_flag, dropout_p=dropout_p)
        generated = self.inpainting_network(x['source'], dense_motion)
        generated.update({'kp_source': kp_source, 'kp_driving': kp_driving})

        loss_values = {}

        pyramide_real = self.pyramid(x['driving'])
        pyramide_generated = self.pyramid(generated['prediction'])

        # reconstruction loss
        if sum(self.loss_weights['perceptual']) != 0:
            value_total = 0
            for scale in self.scales:
                x_vgg = self.vgg(pyramide_generated['prediction_' + str(scale)])
                y_vgg = self.vgg(pyramide_real['prediction_' + str(scale)])

                for i, weight in enumerate(self.loss_weights['perceptual']):
                    value = torch.abs(x_vgg[i] - y_vgg[i].detach()).mean()
                    value_total += self.loss_weights['perceptual'][i] * value
            loss_values['perceptual'] = value_total

        # equivariance loss
        if self.loss_weights['equivariance_value'] != 0:
            transform_random = TPS(mode='random', bs=x['driving'].shape[0], **self.train_params['transform_params'])
            transform_grid = transform_random.transform_frame(x['driving'])
            transformed_frame = F.grid_sample(x['driving'], transform_grid, padding_mode="reflection", align_corners=True)
            transformed_kp = self.kp_extractor(transformed_frame)

            generated['transformed_frame'] = transformed_frame
            generated['transformed_kp'] = transformed_kp

            warped = transform_random.warp_coordinates(transformed_kp['fg_kp'])
            kp_d = kp_driving['fg_kp']
            value = torch.abs(kp_d - warped).mean()
            loss_values['equivariance_value'] = self.loss_weights['equivariance_value'] * value

        # warp loss
        if self.loss_weights['warp_loss'] != 0:
            occlusion_map = generated['occlusion_map']
            encode_map = self.inpainting_network.get_encode(x['driving'], occlusion_map)
            decode_map = generated['warped_encoder_maps']
            value = 0
            for i in range(len(encode_map)):
                value += torch.abs(encode_map[i]-decode_map[-i-1]).mean()

            loss_values['warp_loss'] = self.loss_weights['warp_loss'] * value

        # bg loss
        if self.bg_predictor and epoch >= self.bg_start and self.loss_weights['bg'] != 0:
            bg_param_reverse = self.bg_predictor(x['driving'], x['source'])
            value = torch.matmul(bg_param, bg_param_reverse)
            eye = torch.eye(3).view(1, 1, 3, 3).type(value.type())
            value = torch.abs(eye - value).mean()
            loss_values['bg'] = self.loss_weights['bg'] * value

        return loss_values, generated
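GeneratorFullModel is only used at training time; it bundles the keypoint extractor, background predictor, dense motion network and inpainting network so that one forward pass returns both the generated frame and a dictionary of losses. A hedged sketch of how a training step could consume it is below; the dataloader, optimizer and epoch loop are assumptions, not part of this commit.

# Hypothetical training step around GeneratorFullModel (illustration only;
# the batch source, optimizer and epoch handling are assumptions, not repo code).
import torch
from modules.model import GeneratorFullModel

def training_step(full_model, optimizer, batch, epoch):
    # full_model is a GeneratorFullModel built from the networks and train_params above;
    # batch is a dict with 'source' and 'driving' image tensors of shape (B, 3, H, W)
    losses, generated = full_model(batch, epoch)
    loss = sum(v.mean() for v in losses.values())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return {k: v.mean().item() for k, v in losses.items()}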
modules/util.py
ADDED
@@ -0,0 +1,349 @@
from torch import nn
import torch.nn.functional as F
import torch


class TPS:
    '''
    TPS transformation, mode 'kp' for Eq(2) in the paper, mode 'random' for equivariance loss.
    '''
    def __init__(self, mode, bs, **kwargs):
        self.bs = bs
        self.mode = mode
        if mode == 'random':
            noise = torch.normal(mean=0, std=kwargs['sigma_affine'] * torch.ones([bs, 2, 3]))
            self.theta = noise + torch.eye(2, 3).view(1, 2, 3)
            self.control_points = make_coordinate_grid((kwargs['points_tps'], kwargs['points_tps']), type=noise.type())
            self.control_points = self.control_points.unsqueeze(0)
            self.control_params = torch.normal(mean=0,
                                               std=kwargs['sigma_tps'] * torch.ones([bs, 1, kwargs['points_tps'] ** 2]))
        elif mode == 'kp':
            kp_1 = kwargs["kp_1"]
            kp_2 = kwargs["kp_2"]
            device = kp_1.device
            kp_type = kp_1.type()
            self.gs = kp_1.shape[1]
            n = kp_1.shape[2]
            K = torch.norm(kp_1[:, :, :, None] - kp_1[:, :, None, :], dim=4, p=2)
            K = K**2
            K = K * torch.log(K+1e-9)

            one1 = torch.ones(self.bs, kp_1.shape[1], kp_1.shape[2], 1).to(device).type(kp_type)
            kp_1p = torch.cat([kp_1, one1], 3)

            zero = torch.zeros(self.bs, kp_1.shape[1], 3, 3).to(device).type(kp_type)
            P = torch.cat([kp_1p, zero], 2)
            L = torch.cat([K, kp_1p.permute(0, 1, 3, 2)], 2)
            L = torch.cat([L, P], 3)

            zero = torch.zeros(self.bs, kp_1.shape[1], 3, 2).to(device).type(kp_type)
            Y = torch.cat([kp_2, zero], 2)
            one = torch.eye(L.shape[2]).expand(L.shape).to(device).type(kp_type)*0.01
            L = L + one

            param = torch.matmul(torch.inverse(L), Y)
            self.theta = param[:, :, n:, :].permute(0, 1, 3, 2)

            self.control_points = kp_1
            self.control_params = param[:, :, :n, :]
        else:
            raise Exception("Error TPS mode")

    def transform_frame(self, frame):
        grid = make_coordinate_grid(frame.shape[2:], type=frame.type()).unsqueeze(0).to(frame.device)
        grid = grid.view(1, frame.shape[2] * frame.shape[3], 2)
        shape = [self.bs, frame.shape[2], frame.shape[3], 2]
        if self.mode == 'kp':
            shape.insert(1, self.gs)
        grid = self.warp_coordinates(grid).view(*shape)
        return grid

    def warp_coordinates(self, coordinates):
        theta = self.theta.type(coordinates.type()).to(coordinates.device)
        control_points = self.control_points.type(coordinates.type()).to(coordinates.device)
        control_params = self.control_params.type(coordinates.type()).to(coordinates.device)

        if self.mode == 'kp':
            transformed = torch.matmul(theta[:, :, :, :2], coordinates.permute(0, 2, 1)) + theta[:, :, :, 2:]

            distances = coordinates.view(coordinates.shape[0], 1, 1, -1, 2) - control_points.view(self.bs, control_points.shape[1], -1, 1, 2)

            distances = distances ** 2
            result = distances.sum(-1)
            result = result * torch.log(result + 1e-9)
            result = torch.matmul(result.permute(0, 1, 3, 2), control_params)
            transformed = transformed.permute(0, 1, 3, 2) + result

        elif self.mode == 'random':
            theta = theta.unsqueeze(1)
            transformed = torch.matmul(theta[:, :, :, :2], coordinates.unsqueeze(-1)) + theta[:, :, :, 2:]
            transformed = transformed.squeeze(-1)
            ances = coordinates.view(coordinates.shape[0], -1, 1, 2) - control_points.view(1, 1, -1, 2)
            distances = ances ** 2

            result = distances.sum(-1)
            result = result * torch.log(result + 1e-9)
            result = result * control_params
            result = result.sum(dim=2).view(self.bs, coordinates.shape[1], 1)
            transformed = transformed + result
        else:
            raise Exception("Error TPS mode")

        return transformed


def kp2gaussian(kp, spatial_size, kp_variance):
    """
    Transform a keypoint into gaussian like representation
    """

    coordinate_grid = make_coordinate_grid(spatial_size, kp.type()).to(kp.device)
    number_of_leading_dimensions = len(kp.shape) - 1
    shape = (1,) * number_of_leading_dimensions + coordinate_grid.shape
    coordinate_grid = coordinate_grid.view(*shape)
    repeats = kp.shape[:number_of_leading_dimensions] + (1, 1, 1)
    coordinate_grid = coordinate_grid.repeat(*repeats)

    # Preprocess kp shape
    shape = kp.shape[:number_of_leading_dimensions] + (1, 1, 2)
    kp = kp.view(*shape)

    mean_sub = (coordinate_grid - kp)

    out = torch.exp(-0.5 * (mean_sub ** 2).sum(-1) / kp_variance)

    return out


def make_coordinate_grid(spatial_size, type):
    """
    Create a meshgrid [-1,1] x [-1,1] of given spatial_size.
    """
    h, w = spatial_size
    x = torch.arange(w).type(type)
    y = torch.arange(h).type(type)

    x = (2 * (x / (w - 1)) - 1)
    y = (2 * (y / (h - 1)) - 1)

    yy = y.view(-1, 1).repeat(1, w)
    xx = x.view(1, -1).repeat(h, 1)

    meshed = torch.cat([xx.unsqueeze_(2), yy.unsqueeze_(2)], 2)

    return meshed


class ResBlock2d(nn.Module):
    """
    Res block, preserve spatial resolution.
    """

    def __init__(self, in_features, kernel_size, padding):
        super(ResBlock2d, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,
                               padding=padding)
        self.conv2 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,
                               padding=padding)
        self.norm1 = nn.InstanceNorm2d(in_features, affine=True)
        self.norm2 = nn.InstanceNorm2d(in_features, affine=True)

    def forward(self, x):
        out = self.norm1(x)
        out = F.relu(out)
        out = self.conv1(out)
        out = self.norm2(out)
        out = F.relu(out)
        out = self.conv2(out)
        out += x
        return out


class UpBlock2d(nn.Module):
    """
    Upsampling block for use in decoder.
    """

    def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
        super(UpBlock2d, self).__init__()

        self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
                              padding=padding, groups=groups)
        self.norm = nn.InstanceNorm2d(out_features, affine=True)

    def forward(self, x):
        out = F.interpolate(x, scale_factor=2)
        out = self.conv(out)
        out = self.norm(out)
        out = F.relu(out)
        return out


class DownBlock2d(nn.Module):
    """
    Downsampling block for use in encoder.
    """

    def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
        super(DownBlock2d, self).__init__()
        self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
                              padding=padding, groups=groups)
        self.norm = nn.InstanceNorm2d(out_features, affine=True)
        self.pool = nn.AvgPool2d(kernel_size=(2, 2))

    def forward(self, x):
        out = self.conv(x)
        out = self.norm(out)
        out = F.relu(out)
        out = self.pool(out)
        return out


class SameBlock2d(nn.Module):
    """
    Simple block, preserve spatial resolution.
    """

    def __init__(self, in_features, out_features, groups=1, kernel_size=3, padding=1):
        super(SameBlock2d, self).__init__()
        self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features,
                              kernel_size=kernel_size, padding=padding, groups=groups)
        self.norm = nn.InstanceNorm2d(out_features, affine=True)

    def forward(self, x):
        out = self.conv(x)
        out = self.norm(out)
        out = F.relu(out)
        return out


class Encoder(nn.Module):
    """
    Hourglass Encoder
    """

    def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
        super(Encoder, self).__init__()

        down_blocks = []
        for i in range(num_blocks):
            down_blocks.append(DownBlock2d(in_features if i == 0 else min(max_features, block_expansion * (2 ** i)),
                                           min(max_features, block_expansion * (2 ** (i + 1))),
                                           kernel_size=3, padding=1))
        self.down_blocks = nn.ModuleList(down_blocks)

    def forward(self, x):
        outs = [x]
        # print('encoder:', outs[-1].shape)
        for down_block in self.down_blocks:
            outs.append(down_block(outs[-1]))
            # print('encoder:', outs[-1].shape)
        return outs


class Decoder(nn.Module):
    """
    Hourglass Decoder
    """

    def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
        super(Decoder, self).__init__()

        up_blocks = []
        self.out_channels = []
        for i in range(num_blocks)[::-1]:
            in_filters = (1 if i == num_blocks - 1 else 2) * min(max_features, block_expansion * (2 ** (i + 1)))
            self.out_channels.append(in_filters)
            out_filters = min(max_features, block_expansion * (2 ** i))
            up_blocks.append(UpBlock2d(in_filters, out_filters, kernel_size=3, padding=1))

        self.up_blocks = nn.ModuleList(up_blocks)
        self.out_channels.append(block_expansion + in_features)
        # self.out_filters = block_expansion + in_features

    def forward(self, x, mode=0):
        out = x.pop()
        outs = []
        for up_block in self.up_blocks:
            out = up_block(out)
            skip = x.pop()
            out = torch.cat([out, skip], dim=1)
            outs.append(out)
        if(mode == 0):
            return out
        else:
            return outs


class Hourglass(nn.Module):
    """
    Hourglass architecture.
    """

    def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
        super(Hourglass, self).__init__()
        self.encoder = Encoder(block_expansion, in_features, num_blocks, max_features)
        self.decoder = Decoder(block_expansion, in_features, num_blocks, max_features)
        self.out_channels = self.decoder.out_channels
        # self.out_filters = self.decoder.out_filters

    def forward(self, x, mode=0):
        return self.decoder(self.encoder(x), mode)


class AntiAliasInterpolation2d(nn.Module):
    """
    Band-limited downsampling, for better preservation of the input signal.
    """
    def __init__(self, channels, scale):
        super(AntiAliasInterpolation2d, self).__init__()
        sigma = (1 / scale - 1) / 2
        kernel_size = 2 * round(sigma * 4) + 1
        self.ka = kernel_size // 2
        self.kb = self.ka - 1 if kernel_size % 2 == 0 else self.ka

        kernel_size = [kernel_size, kernel_size]
        sigma = [sigma, sigma]
        # The gaussian kernel is the product of the
        # gaussian function of each dimension.
        kernel = 1
        meshgrids = torch.meshgrid(
            [
                torch.arange(size, dtype=torch.float32)
                for size in kernel_size
            ]
        )
        for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
            mean = (size - 1) / 2
            kernel *= torch.exp(-(mgrid - mean) ** 2 / (2 * std ** 2))

        # Make sure sum of values in gaussian kernel equals 1.
        kernel = kernel / torch.sum(kernel)
        # Reshape to depthwise convolutional weight
        kernel = kernel.view(1, 1, *kernel.size())
        kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))

        self.register_buffer('weight', kernel)
        self.groups = channels
        self.scale = scale

    def forward(self, input):
        if self.scale == 1.0:
            return input

        out = F.pad(input, (self.ka, self.kb, self.ka, self.kb))
        out = F.conv2d(out, weight=self.weight, groups=self.groups)
        out = F.interpolate(out, scale_factor=(self.scale, self.scale))

        return out


def to_homogeneous(coordinates):
    ones_shape = list(coordinates.shape)
    ones_shape[-1] = 1
    ones = torch.ones(ones_shape).type(coordinates.type())

    return torch.cat([coordinates, ones], dim=-1)

def from_homogeneous(coordinates):
    return coordinates[..., :2] / coordinates[..., 2:3]
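Most of the geometry in this repo lives on the [-1, 1] grid produced by make_coordinate_grid, and kp2gaussian paints keypoints onto that grid as Gaussian heatmaps (this is what DenseMotionNetwork subtracts to build its heatmap representation). A small sketch of both helpers in isolation, assuming it is run from the repository root (illustration only):

# make_coordinate_grid and kp2gaussian in isolation (illustration only).
import torch
from modules.util import make_coordinate_grid, kp2gaussian

grid = make_coordinate_grid((4, 4), type=torch.zeros(1).type())
print(grid.shape)                  # torch.Size([4, 4, 2]), values spanning [-1, 1]
print(grid[0, 0], grid[-1, -1])    # corners: (-1, -1) and (1, 1)

kp = torch.zeros(1, 50, 2)         # 50 keypoints, all placed at the image centre
heat = kp2gaussian(kp, spatial_size=(64, 64), kp_variance=0.01)
print(heat.shape)                  # torch.Size([1, 50, 64, 64])
print(heat[0, 0].argmax())         # peak lands near the centre of the 64x64 map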
requirements.txt
ADDED
@@ -0,0 +1,27 @@
gradio
cffi==1.14.6
cycler==0.10.0
decorator==5.1.0
face-alignment==1.3.5
imageio==2.9.0
imageio-ffmpeg==0.4.5
kiwisolver==1.3.2
matplotlib==3.4.3
networkx==2.6.3
numpy==1.20.3
pandas==1.3.3
Pillow==8.3.2
pycparser==2.20
pyparsing==2.4.7
python-dateutil==2.8.2
pytz==2021.1
PyWavelets==1.1.1
PyYAML==5.4.1
scikit-image==0.18.3
scikit-learn==1.0
scipy==1.7.1
six==1.16.0
torch==1.10.0+cu113
torchvision==0.11.0+cu113
tqdm==4.62.3