Spaces:
Runtime error
Runtime error
maduvantha
commited on
Commit
•
54bf1bc
1
Parent(s):
eddf80e
Upload 14 files
Browse files- .gitattributes +1 -0
- 00.mp4 +0 -0
- animate.py +101 -0
- augmentation.py +345 -0
- crop-video.py +158 -0
- demo.py +157 -0
- discordbot.py +0 -0
- frames_dataset.py +197 -0
- generated.mp4 +0 -0
- got-05.jpg +3 -0
- logger.py +208 -0
- reconstruction.py +67 -0
- run.py +87 -0
- sdkan.png +0 -0
- train.py +87 -0
.gitattributes
CHANGED
@@ -38,3 +38,4 @@ sup-mat/fashion-teaser.gif filter=lfs diff=lfs merge=lfs -text
|
|
38 |
sup-mat/mgif-teaser.gif filter=lfs diff=lfs merge=lfs -text
|
39 |
sup-mat/relative-demo.gif filter=lfs diff=lfs merge=lfs -text
|
40 |
sup-mat/vox-teaser.gif filter=lfs diff=lfs merge=lfs -text
|
|
|
|
38 |
sup-mat/mgif-teaser.gif filter=lfs diff=lfs merge=lfs -text
|
39 |
sup-mat/relative-demo.gif filter=lfs diff=lfs merge=lfs -text
|
40 |
sup-mat/vox-teaser.gif filter=lfs diff=lfs merge=lfs -text
|
41 |
+
got-05.jpg filter=lfs diff=lfs merge=lfs -text
|
00.mp4
ADDED
Binary file (527 kB). View file
|
|
animate.py
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from tqdm import tqdm
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from torch.utils.data import DataLoader
|
6 |
+
|
7 |
+
from frames_dataset import PairedDataset
|
8 |
+
from logger import Logger, Visualizer
|
9 |
+
import imageio
|
10 |
+
from scipy.spatial import ConvexHull
|
11 |
+
import numpy as np
|
12 |
+
|
13 |
+
from sync_batchnorm import DataParallelWithCallback
|
14 |
+
|
15 |
+
|
16 |
+
def normalize_kp(kp_source, kp_driving, kp_driving_initial, adapt_movement_scale=False,
|
17 |
+
use_relative_movement=False, use_relative_jacobian=False):
|
18 |
+
if adapt_movement_scale:
|
19 |
+
source_area = ConvexHull(kp_source['value'][0].data.cpu().numpy()).volume
|
20 |
+
driving_area = ConvexHull(kp_driving_initial['value'][0].data.cpu().numpy()).volume
|
21 |
+
adapt_movement_scale = np.sqrt(source_area) / np.sqrt(driving_area)
|
22 |
+
else:
|
23 |
+
adapt_movement_scale = 1
|
24 |
+
|
25 |
+
kp_new = {k: v for k, v in kp_driving.items()}
|
26 |
+
|
27 |
+
if use_relative_movement:
|
28 |
+
kp_value_diff = (kp_driving['value'] - kp_driving_initial['value'])
|
29 |
+
kp_value_diff *= adapt_movement_scale
|
30 |
+
kp_new['value'] = kp_value_diff + kp_source['value']
|
31 |
+
|
32 |
+
if use_relative_jacobian:
|
33 |
+
jacobian_diff = torch.matmul(kp_driving['jacobian'], torch.inverse(kp_driving_initial['jacobian']))
|
34 |
+
kp_new['jacobian'] = torch.matmul(jacobian_diff, kp_source['jacobian'])
|
35 |
+
|
36 |
+
return kp_new
|
37 |
+
|
38 |
+
|
39 |
+
def animate(config, generator, kp_detector, checkpoint, log_dir, dataset):
|
40 |
+
log_dir = os.path.join(log_dir, 'animation')
|
41 |
+
png_dir = os.path.join(log_dir, 'png')
|
42 |
+
animate_params = config['animate_params']
|
43 |
+
|
44 |
+
dataset = PairedDataset(initial_dataset=dataset, number_of_pairs=animate_params['num_pairs'])
|
45 |
+
dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)
|
46 |
+
|
47 |
+
if checkpoint is not None:
|
48 |
+
Logger.load_cpk(checkpoint, generator=generator, kp_detector=kp_detector)
|
49 |
+
else:
|
50 |
+
raise AttributeError("Checkpoint should be specified for mode='animate'.")
|
51 |
+
|
52 |
+
if not os.path.exists(log_dir):
|
53 |
+
os.makedirs(log_dir)
|
54 |
+
|
55 |
+
if not os.path.exists(png_dir):
|
56 |
+
os.makedirs(png_dir)
|
57 |
+
|
58 |
+
if torch.cuda.is_available():
|
59 |
+
generator = DataParallelWithCallback(generator)
|
60 |
+
kp_detector = DataParallelWithCallback(kp_detector)
|
61 |
+
|
62 |
+
generator.eval()
|
63 |
+
kp_detector.eval()
|
64 |
+
|
65 |
+
for it, x in tqdm(enumerate(dataloader)):
|
66 |
+
with torch.no_grad():
|
67 |
+
predictions = []
|
68 |
+
visualizations = []
|
69 |
+
|
70 |
+
driving_video = x['driving_video']
|
71 |
+
source_frame = x['source_video'][:, :, 0, :, :]
|
72 |
+
|
73 |
+
kp_source = kp_detector(source_frame)
|
74 |
+
kp_driving_initial = kp_detector(driving_video[:, :, 0])
|
75 |
+
|
76 |
+
for frame_idx in range(driving_video.shape[2]):
|
77 |
+
driving_frame = driving_video[:, :, frame_idx]
|
78 |
+
kp_driving = kp_detector(driving_frame)
|
79 |
+
kp_norm = normalize_kp(kp_source=kp_source, kp_driving=kp_driving,
|
80 |
+
kp_driving_initial=kp_driving_initial, **animate_params['normalization_params'])
|
81 |
+
out = generator(source_frame, kp_source=kp_source, kp_driving=kp_norm)
|
82 |
+
|
83 |
+
out['kp_driving'] = kp_driving
|
84 |
+
out['kp_source'] = kp_source
|
85 |
+
out['kp_norm'] = kp_norm
|
86 |
+
|
87 |
+
del out['sparse_deformed']
|
88 |
+
|
89 |
+
predictions.append(np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0])
|
90 |
+
|
91 |
+
visualization = Visualizer(**config['visualizer_params']).visualize(source=source_frame,
|
92 |
+
driving=driving_frame, out=out)
|
93 |
+
visualization = visualization
|
94 |
+
visualizations.append(visualization)
|
95 |
+
|
96 |
+
predictions = np.concatenate(predictions, axis=1)
|
97 |
+
result_name = "-".join([x['driving_name'][0], x['source_name'][0]])
|
98 |
+
imageio.imsave(os.path.join(png_dir, result_name + '.png'), (255 * predictions).astype(np.uint8))
|
99 |
+
|
100 |
+
image_name = result_name + animate_params['format']
|
101 |
+
imageio.mimsave(os.path.join(log_dir, image_name), visualizations)
|
augmentation.py
ADDED
@@ -0,0 +1,345 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Code from https://github.com/hassony2/torch_videovision
|
3 |
+
"""
|
4 |
+
|
5 |
+
import numbers
|
6 |
+
|
7 |
+
import random
|
8 |
+
import numpy as np
|
9 |
+
import PIL
|
10 |
+
|
11 |
+
from skimage.transform import resize, rotate
|
12 |
+
from skimage.util import pad
|
13 |
+
import torchvision
|
14 |
+
|
15 |
+
import warnings
|
16 |
+
|
17 |
+
from skimage import img_as_ubyte, img_as_float
|
18 |
+
|
19 |
+
|
20 |
+
def crop_clip(clip, min_h, min_w, h, w):
|
21 |
+
if isinstance(clip[0], np.ndarray):
|
22 |
+
cropped = [img[min_h:min_h + h, min_w:min_w + w, :] for img in clip]
|
23 |
+
|
24 |
+
elif isinstance(clip[0], PIL.Image.Image):
|
25 |
+
cropped = [
|
26 |
+
img.crop((min_w, min_h, min_w + w, min_h + h)) for img in clip
|
27 |
+
]
|
28 |
+
else:
|
29 |
+
raise TypeError('Expected numpy.ndarray or PIL.Image' +
|
30 |
+
'but got list of {0}'.format(type(clip[0])))
|
31 |
+
return cropped
|
32 |
+
|
33 |
+
|
34 |
+
def pad_clip(clip, h, w):
|
35 |
+
im_h, im_w = clip[0].shape[:2]
|
36 |
+
pad_h = (0, 0) if h < im_h else ((h - im_h) // 2, (h - im_h + 1) // 2)
|
37 |
+
pad_w = (0, 0) if w < im_w else ((w - im_w) // 2, (w - im_w + 1) // 2)
|
38 |
+
|
39 |
+
return pad(clip, ((0, 0), pad_h, pad_w, (0, 0)), mode='edge')
|
40 |
+
|
41 |
+
|
42 |
+
def resize_clip(clip, size, interpolation='bilinear'):
|
43 |
+
if isinstance(clip[0], np.ndarray):
|
44 |
+
if isinstance(size, numbers.Number):
|
45 |
+
im_h, im_w, im_c = clip[0].shape
|
46 |
+
# Min spatial dim already matches minimal size
|
47 |
+
if (im_w <= im_h and im_w == size) or (im_h <= im_w
|
48 |
+
and im_h == size):
|
49 |
+
return clip
|
50 |
+
new_h, new_w = get_resize_sizes(im_h, im_w, size)
|
51 |
+
size = (new_w, new_h)
|
52 |
+
else:
|
53 |
+
size = size[1], size[0]
|
54 |
+
|
55 |
+
scaled = [
|
56 |
+
resize(img, size, order=1 if interpolation == 'bilinear' else 0, preserve_range=True,
|
57 |
+
mode='constant', anti_aliasing=True) for img in clip
|
58 |
+
]
|
59 |
+
elif isinstance(clip[0], PIL.Image.Image):
|
60 |
+
if isinstance(size, numbers.Number):
|
61 |
+
im_w, im_h = clip[0].size
|
62 |
+
# Min spatial dim already matches minimal size
|
63 |
+
if (im_w <= im_h and im_w == size) or (im_h <= im_w
|
64 |
+
and im_h == size):
|
65 |
+
return clip
|
66 |
+
new_h, new_w = get_resize_sizes(im_h, im_w, size)
|
67 |
+
size = (new_w, new_h)
|
68 |
+
else:
|
69 |
+
size = size[1], size[0]
|
70 |
+
if interpolation == 'bilinear':
|
71 |
+
pil_inter = PIL.Image.NEAREST
|
72 |
+
else:
|
73 |
+
pil_inter = PIL.Image.BILINEAR
|
74 |
+
scaled = [img.resize(size, pil_inter) for img in clip]
|
75 |
+
else:
|
76 |
+
raise TypeError('Expected numpy.ndarray or PIL.Image' +
|
77 |
+
'but got list of {0}'.format(type(clip[0])))
|
78 |
+
return scaled
|
79 |
+
|
80 |
+
|
81 |
+
def get_resize_sizes(im_h, im_w, size):
|
82 |
+
if im_w < im_h:
|
83 |
+
ow = size
|
84 |
+
oh = int(size * im_h / im_w)
|
85 |
+
else:
|
86 |
+
oh = size
|
87 |
+
ow = int(size * im_w / im_h)
|
88 |
+
return oh, ow
|
89 |
+
|
90 |
+
|
91 |
+
class RandomFlip(object):
|
92 |
+
def __init__(self, time_flip=False, horizontal_flip=False):
|
93 |
+
self.time_flip = time_flip
|
94 |
+
self.horizontal_flip = horizontal_flip
|
95 |
+
|
96 |
+
def __call__(self, clip):
|
97 |
+
if random.random() < 0.5 and self.time_flip:
|
98 |
+
return clip[::-1]
|
99 |
+
if random.random() < 0.5 and self.horizontal_flip:
|
100 |
+
return [np.fliplr(img) for img in clip]
|
101 |
+
|
102 |
+
return clip
|
103 |
+
|
104 |
+
|
105 |
+
class RandomResize(object):
|
106 |
+
"""Resizes a list of (H x W x C) numpy.ndarray to the final size
|
107 |
+
The larger the original image is, the more times it takes to
|
108 |
+
interpolate
|
109 |
+
Args:
|
110 |
+
interpolation (str): Can be one of 'nearest', 'bilinear'
|
111 |
+
defaults to nearest
|
112 |
+
size (tuple): (widht, height)
|
113 |
+
"""
|
114 |
+
|
115 |
+
def __init__(self, ratio=(3. / 4., 4. / 3.), interpolation='nearest'):
|
116 |
+
self.ratio = ratio
|
117 |
+
self.interpolation = interpolation
|
118 |
+
|
119 |
+
def __call__(self, clip):
|
120 |
+
scaling_factor = random.uniform(self.ratio[0], self.ratio[1])
|
121 |
+
|
122 |
+
if isinstance(clip[0], np.ndarray):
|
123 |
+
im_h, im_w, im_c = clip[0].shape
|
124 |
+
elif isinstance(clip[0], PIL.Image.Image):
|
125 |
+
im_w, im_h = clip[0].size
|
126 |
+
|
127 |
+
new_w = int(im_w * scaling_factor)
|
128 |
+
new_h = int(im_h * scaling_factor)
|
129 |
+
new_size = (new_w, new_h)
|
130 |
+
resized = resize_clip(
|
131 |
+
clip, new_size, interpolation=self.interpolation)
|
132 |
+
|
133 |
+
return resized
|
134 |
+
|
135 |
+
|
136 |
+
class RandomCrop(object):
|
137 |
+
"""Extract random crop at the same location for a list of videos
|
138 |
+
Args:
|
139 |
+
size (sequence or int): Desired output size for the
|
140 |
+
crop in format (h, w)
|
141 |
+
"""
|
142 |
+
|
143 |
+
def __init__(self, size):
|
144 |
+
if isinstance(size, numbers.Number):
|
145 |
+
size = (size, size)
|
146 |
+
|
147 |
+
self.size = size
|
148 |
+
|
149 |
+
def __call__(self, clip):
|
150 |
+
"""
|
151 |
+
Args:
|
152 |
+
img (PIL.Image or numpy.ndarray): List of videos to be cropped
|
153 |
+
in format (h, w, c) in numpy.ndarray
|
154 |
+
Returns:
|
155 |
+
PIL.Image or numpy.ndarray: Cropped list of videos
|
156 |
+
"""
|
157 |
+
h, w = self.size
|
158 |
+
if isinstance(clip[0], np.ndarray):
|
159 |
+
im_h, im_w, im_c = clip[0].shape
|
160 |
+
elif isinstance(clip[0], PIL.Image.Image):
|
161 |
+
im_w, im_h = clip[0].size
|
162 |
+
else:
|
163 |
+
raise TypeError('Expected numpy.ndarray or PIL.Image' +
|
164 |
+
'but got list of {0}'.format(type(clip[0])))
|
165 |
+
|
166 |
+
clip = pad_clip(clip, h, w)
|
167 |
+
im_h, im_w = clip.shape[1:3]
|
168 |
+
x1 = 0 if h == im_h else random.randint(0, im_w - w)
|
169 |
+
y1 = 0 if w == im_w else random.randint(0, im_h - h)
|
170 |
+
cropped = crop_clip(clip, y1, x1, h, w)
|
171 |
+
|
172 |
+
return cropped
|
173 |
+
|
174 |
+
|
175 |
+
class RandomRotation(object):
|
176 |
+
"""Rotate entire clip randomly by a random angle within
|
177 |
+
given bounds
|
178 |
+
Args:
|
179 |
+
degrees (sequence or int): Range of degrees to select from
|
180 |
+
If degrees is a number instead of sequence like (min, max),
|
181 |
+
the range of degrees, will be (-degrees, +degrees).
|
182 |
+
"""
|
183 |
+
|
184 |
+
def __init__(self, degrees):
|
185 |
+
if isinstance(degrees, numbers.Number):
|
186 |
+
if degrees < 0:
|
187 |
+
raise ValueError('If degrees is a single number,'
|
188 |
+
'must be positive')
|
189 |
+
degrees = (-degrees, degrees)
|
190 |
+
else:
|
191 |
+
if len(degrees) != 2:
|
192 |
+
raise ValueError('If degrees is a sequence,'
|
193 |
+
'it must be of len 2.')
|
194 |
+
|
195 |
+
self.degrees = degrees
|
196 |
+
|
197 |
+
def __call__(self, clip):
|
198 |
+
"""
|
199 |
+
Args:
|
200 |
+
img (PIL.Image or numpy.ndarray): List of videos to be cropped
|
201 |
+
in format (h, w, c) in numpy.ndarray
|
202 |
+
Returns:
|
203 |
+
PIL.Image or numpy.ndarray: Cropped list of videos
|
204 |
+
"""
|
205 |
+
angle = random.uniform(self.degrees[0], self.degrees[1])
|
206 |
+
if isinstance(clip[0], np.ndarray):
|
207 |
+
rotated = [rotate(image=img, angle=angle, preserve_range=True) for img in clip]
|
208 |
+
elif isinstance(clip[0], PIL.Image.Image):
|
209 |
+
rotated = [img.rotate(angle) for img in clip]
|
210 |
+
else:
|
211 |
+
raise TypeError('Expected numpy.ndarray or PIL.Image' +
|
212 |
+
'but got list of {0}'.format(type(clip[0])))
|
213 |
+
|
214 |
+
return rotated
|
215 |
+
|
216 |
+
|
217 |
+
class ColorJitter(object):
|
218 |
+
"""Randomly change the brightness, contrast and saturation and hue of the clip
|
219 |
+
Args:
|
220 |
+
brightness (float): How much to jitter brightness. brightness_factor
|
221 |
+
is chosen uniformly from [max(0, 1 - brightness), 1 + brightness].
|
222 |
+
contrast (float): How much to jitter contrast. contrast_factor
|
223 |
+
is chosen uniformly from [max(0, 1 - contrast), 1 + contrast].
|
224 |
+
saturation (float): How much to jitter saturation. saturation_factor
|
225 |
+
is chosen uniformly from [max(0, 1 - saturation), 1 + saturation].
|
226 |
+
hue(float): How much to jitter hue. hue_factor is chosen uniformly from
|
227 |
+
[-hue, hue]. Should be >=0 and <= 0.5.
|
228 |
+
"""
|
229 |
+
|
230 |
+
def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
|
231 |
+
self.brightness = brightness
|
232 |
+
self.contrast = contrast
|
233 |
+
self.saturation = saturation
|
234 |
+
self.hue = hue
|
235 |
+
|
236 |
+
def get_params(self, brightness, contrast, saturation, hue):
|
237 |
+
if brightness > 0:
|
238 |
+
brightness_factor = random.uniform(
|
239 |
+
max(0, 1 - brightness), 1 + brightness)
|
240 |
+
else:
|
241 |
+
brightness_factor = None
|
242 |
+
|
243 |
+
if contrast > 0:
|
244 |
+
contrast_factor = random.uniform(
|
245 |
+
max(0, 1 - contrast), 1 + contrast)
|
246 |
+
else:
|
247 |
+
contrast_factor = None
|
248 |
+
|
249 |
+
if saturation > 0:
|
250 |
+
saturation_factor = random.uniform(
|
251 |
+
max(0, 1 - saturation), 1 + saturation)
|
252 |
+
else:
|
253 |
+
saturation_factor = None
|
254 |
+
|
255 |
+
if hue > 0:
|
256 |
+
hue_factor = random.uniform(-hue, hue)
|
257 |
+
else:
|
258 |
+
hue_factor = None
|
259 |
+
return brightness_factor, contrast_factor, saturation_factor, hue_factor
|
260 |
+
|
261 |
+
def __call__(self, clip):
|
262 |
+
"""
|
263 |
+
Args:
|
264 |
+
clip (list): list of PIL.Image
|
265 |
+
Returns:
|
266 |
+
list PIL.Image : list of transformed PIL.Image
|
267 |
+
"""
|
268 |
+
if isinstance(clip[0], np.ndarray):
|
269 |
+
brightness, contrast, saturation, hue = self.get_params(
|
270 |
+
self.brightness, self.contrast, self.saturation, self.hue)
|
271 |
+
|
272 |
+
# Create img transform function sequence
|
273 |
+
img_transforms = []
|
274 |
+
if brightness is not None:
|
275 |
+
img_transforms.append(lambda img: torchvision.transforms.functional.adjust_brightness(img, brightness))
|
276 |
+
if saturation is not None:
|
277 |
+
img_transforms.append(lambda img: torchvision.transforms.functional.adjust_saturation(img, saturation))
|
278 |
+
if hue is not None:
|
279 |
+
img_transforms.append(lambda img: torchvision.transforms.functional.adjust_hue(img, hue))
|
280 |
+
if contrast is not None:
|
281 |
+
img_transforms.append(lambda img: torchvision.transforms.functional.adjust_contrast(img, contrast))
|
282 |
+
random.shuffle(img_transforms)
|
283 |
+
img_transforms = [img_as_ubyte, torchvision.transforms.ToPILImage()] + img_transforms + [np.array,
|
284 |
+
img_as_float]
|
285 |
+
|
286 |
+
with warnings.catch_warnings():
|
287 |
+
warnings.simplefilter("ignore")
|
288 |
+
jittered_clip = []
|
289 |
+
for img in clip:
|
290 |
+
jittered_img = img
|
291 |
+
for func in img_transforms:
|
292 |
+
jittered_img = func(jittered_img)
|
293 |
+
jittered_clip.append(jittered_img.astype('float32'))
|
294 |
+
elif isinstance(clip[0], PIL.Image.Image):
|
295 |
+
brightness, contrast, saturation, hue = self.get_params(
|
296 |
+
self.brightness, self.contrast, self.saturation, self.hue)
|
297 |
+
|
298 |
+
# Create img transform function sequence
|
299 |
+
img_transforms = []
|
300 |
+
if brightness is not None:
|
301 |
+
img_transforms.append(lambda img: torchvision.transforms.functional.adjust_brightness(img, brightness))
|
302 |
+
if saturation is not None:
|
303 |
+
img_transforms.append(lambda img: torchvision.transforms.functional.adjust_saturation(img, saturation))
|
304 |
+
if hue is not None:
|
305 |
+
img_transforms.append(lambda img: torchvision.transforms.functional.adjust_hue(img, hue))
|
306 |
+
if contrast is not None:
|
307 |
+
img_transforms.append(lambda img: torchvision.transforms.functional.adjust_contrast(img, contrast))
|
308 |
+
random.shuffle(img_transforms)
|
309 |
+
|
310 |
+
# Apply to all videos
|
311 |
+
jittered_clip = []
|
312 |
+
for img in clip:
|
313 |
+
for func in img_transforms:
|
314 |
+
jittered_img = func(img)
|
315 |
+
jittered_clip.append(jittered_img)
|
316 |
+
|
317 |
+
else:
|
318 |
+
raise TypeError('Expected numpy.ndarray or PIL.Image' +
|
319 |
+
'but got list of {0}'.format(type(clip[0])))
|
320 |
+
return jittered_clip
|
321 |
+
|
322 |
+
|
323 |
+
class AllAugmentationTransform:
|
324 |
+
def __init__(self, resize_param=None, rotation_param=None, flip_param=None, crop_param=None, jitter_param=None):
|
325 |
+
self.transforms = []
|
326 |
+
|
327 |
+
if flip_param is not None:
|
328 |
+
self.transforms.append(RandomFlip(**flip_param))
|
329 |
+
|
330 |
+
if rotation_param is not None:
|
331 |
+
self.transforms.append(RandomRotation(**rotation_param))
|
332 |
+
|
333 |
+
if resize_param is not None:
|
334 |
+
self.transforms.append(RandomResize(**resize_param))
|
335 |
+
|
336 |
+
if crop_param is not None:
|
337 |
+
self.transforms.append(RandomCrop(**crop_param))
|
338 |
+
|
339 |
+
if jitter_param is not None:
|
340 |
+
self.transforms.append(ColorJitter(**jitter_param))
|
341 |
+
|
342 |
+
def __call__(self, clip):
|
343 |
+
for t in self.transforms:
|
344 |
+
clip = t(clip)
|
345 |
+
return clip
|
crop-video.py
ADDED
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import face_alignment
|
2 |
+
import skimage.io
|
3 |
+
import numpy
|
4 |
+
from argparse import ArgumentParser
|
5 |
+
from skimage import img_as_ubyte
|
6 |
+
from skimage.transform import resize
|
7 |
+
from tqdm import tqdm
|
8 |
+
import os
|
9 |
+
import imageio
|
10 |
+
import numpy as np
|
11 |
+
import warnings
|
12 |
+
warnings.filterwarnings("ignore")
|
13 |
+
|
14 |
+
def extract_bbox(frame, fa):
|
15 |
+
if max(frame.shape[0], frame.shape[1]) > 640:
|
16 |
+
scale_factor = max(frame.shape[0], frame.shape[1]) / 640.0
|
17 |
+
frame = resize(frame, (int(frame.shape[0] / scale_factor), int(frame.shape[1] / scale_factor)))
|
18 |
+
frame = img_as_ubyte(frame)
|
19 |
+
else:
|
20 |
+
scale_factor = 1
|
21 |
+
frame = frame[..., :3]
|
22 |
+
bboxes = fa.face_detector.detect_from_image(frame[..., ::-1])
|
23 |
+
if len(bboxes) == 0:
|
24 |
+
return []
|
25 |
+
return np.array(bboxes)[:, :-1] * scale_factor
|
26 |
+
|
27 |
+
|
28 |
+
|
29 |
+
def bb_intersection_over_union(boxA, boxB):
|
30 |
+
xA = max(boxA[0], boxB[0])
|
31 |
+
yA = max(boxA[1], boxB[1])
|
32 |
+
xB = min(boxA[2], boxB[2])
|
33 |
+
yB = min(boxA[3], boxB[3])
|
34 |
+
interArea = max(0, xB - xA + 1) * max(0, yB - yA + 1)
|
35 |
+
boxAArea = (boxA[2] - boxA[0] + 1) * (boxA[3] - boxA[1] + 1)
|
36 |
+
boxBArea = (boxB[2] - boxB[0] + 1) * (boxB[3] - boxB[1] + 1)
|
37 |
+
iou = interArea / float(boxAArea + boxBArea - interArea)
|
38 |
+
return iou
|
39 |
+
|
40 |
+
|
41 |
+
def join(tube_bbox, bbox):
|
42 |
+
xA = min(tube_bbox[0], bbox[0])
|
43 |
+
yA = min(tube_bbox[1], bbox[1])
|
44 |
+
xB = max(tube_bbox[2], bbox[2])
|
45 |
+
yB = max(tube_bbox[3], bbox[3])
|
46 |
+
return (xA, yA, xB, yB)
|
47 |
+
|
48 |
+
|
49 |
+
def compute_bbox(start, end, fps, tube_bbox, frame_shape, inp, image_shape, increase_area=0.1):
|
50 |
+
left, top, right, bot = tube_bbox
|
51 |
+
width = right - left
|
52 |
+
height = bot - top
|
53 |
+
|
54 |
+
#Computing aspect preserving bbox
|
55 |
+
width_increase = max(increase_area, ((1 + 2 * increase_area) * height - width) / (2 * width))
|
56 |
+
height_increase = max(increase_area, ((1 + 2 * increase_area) * width - height) / (2 * height))
|
57 |
+
|
58 |
+
left = int(left - width_increase * width)
|
59 |
+
top = int(top - height_increase * height)
|
60 |
+
right = int(right + width_increase * width)
|
61 |
+
bot = int(bot + height_increase * height)
|
62 |
+
|
63 |
+
top, bot, left, right = max(0, top), min(bot, frame_shape[0]), max(0, left), min(right, frame_shape[1])
|
64 |
+
h, w = bot - top, right - left
|
65 |
+
|
66 |
+
start = start / fps
|
67 |
+
end = end / fps
|
68 |
+
time = end - start
|
69 |
+
|
70 |
+
scale = f'{image_shape[0]}:{image_shape[1]}'
|
71 |
+
|
72 |
+
return f'ffmpeg -i {inp} -ss {start} -t {time} -filter:v "crop={w}:{h}:{left}:{top}, scale={scale}" crop.mp4'
|
73 |
+
|
74 |
+
|
75 |
+
def compute_bbox_trajectories(trajectories, fps, frame_shape, args):
|
76 |
+
commands = []
|
77 |
+
for i, (bbox, tube_bbox, start, end) in enumerate(trajectories):
|
78 |
+
if (end - start) > args.min_frames:
|
79 |
+
command = compute_bbox(start, end, fps, tube_bbox, frame_shape, inp=args.inp, image_shape=args.image_shape, increase_area=args.increase)
|
80 |
+
commands.append(command)
|
81 |
+
return commands
|
82 |
+
|
83 |
+
|
84 |
+
def process_video(args):
|
85 |
+
device = 'cpu' if args.cpu else 'cuda'
|
86 |
+
fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=False, device=device)
|
87 |
+
video = imageio.get_reader(args.inp)
|
88 |
+
|
89 |
+
trajectories = []
|
90 |
+
previous_frame = None
|
91 |
+
fps = video.get_meta_data()['fps']
|
92 |
+
commands = []
|
93 |
+
try:
|
94 |
+
for i, frame in tqdm(enumerate(video)):
|
95 |
+
frame_shape = frame.shape
|
96 |
+
bboxes = extract_bbox(frame, fa)
|
97 |
+
## For each trajectory check the criterion
|
98 |
+
not_valid_trajectories = []
|
99 |
+
valid_trajectories = []
|
100 |
+
|
101 |
+
for trajectory in trajectories:
|
102 |
+
tube_bbox = trajectory[0]
|
103 |
+
intersection = 0
|
104 |
+
for bbox in bboxes:
|
105 |
+
intersection = max(intersection, bb_intersection_over_union(tube_bbox, bbox))
|
106 |
+
if intersection > args.iou_with_initial:
|
107 |
+
valid_trajectories.append(trajectory)
|
108 |
+
else:
|
109 |
+
not_valid_trajectories.append(trajectory)
|
110 |
+
|
111 |
+
commands += compute_bbox_trajectories(not_valid_trajectories, fps, frame_shape, args)
|
112 |
+
trajectories = valid_trajectories
|
113 |
+
|
114 |
+
## Assign bbox to trajectories, create new trajectories
|
115 |
+
for bbox in bboxes:
|
116 |
+
intersection = 0
|
117 |
+
current_trajectory = None
|
118 |
+
for trajectory in trajectories:
|
119 |
+
tube_bbox = trajectory[0]
|
120 |
+
current_intersection = bb_intersection_over_union(tube_bbox, bbox)
|
121 |
+
if intersection < current_intersection and current_intersection > args.iou_with_initial:
|
122 |
+
intersection = bb_intersection_over_union(tube_bbox, bbox)
|
123 |
+
current_trajectory = trajectory
|
124 |
+
|
125 |
+
## Create new trajectory
|
126 |
+
if current_trajectory is None:
|
127 |
+
trajectories.append([bbox, bbox, i, i])
|
128 |
+
else:
|
129 |
+
current_trajectory[3] = i
|
130 |
+
current_trajectory[1] = join(current_trajectory[1], bbox)
|
131 |
+
|
132 |
+
|
133 |
+
except IndexError as e:
|
134 |
+
raise (e)
|
135 |
+
|
136 |
+
commands += compute_bbox_trajectories(trajectories, fps, frame_shape, args)
|
137 |
+
return commands
|
138 |
+
|
139 |
+
|
140 |
+
if __name__ == "__main__":
|
141 |
+
parser = ArgumentParser()
|
142 |
+
|
143 |
+
parser.add_argument("--image_shape", default=(256, 256), type=lambda x: tuple(map(int, x.split(','))),
|
144 |
+
help="Image shape")
|
145 |
+
parser.add_argument("--increase", default=0.1, type=float, help='Increase bbox by this amount')
|
146 |
+
parser.add_argument("--iou_with_initial", type=float, default=0.25, help="The minimal allowed iou with inital bbox")
|
147 |
+
parser.add_argument("--inp", required=True, help='Input image or video')
|
148 |
+
parser.add_argument("--min_frames", type=int, default=150, help='Minimum number of frames')
|
149 |
+
parser.add_argument("--cpu", dest="cpu", action="store_true", help="cpu mode.")
|
150 |
+
|
151 |
+
|
152 |
+
args = parser.parse_args()
|
153 |
+
|
154 |
+
commands = process_video(args)
|
155 |
+
for command in commands:
|
156 |
+
print (command)
|
157 |
+
|
158 |
+
|
demo.py
ADDED
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import matplotlib
|
2 |
+
matplotlib.use('Agg')
|
3 |
+
import os, sys
|
4 |
+
import yaml
|
5 |
+
from argparse import ArgumentParser
|
6 |
+
from tqdm import tqdm
|
7 |
+
|
8 |
+
import imageio
|
9 |
+
import numpy as np
|
10 |
+
from skimage.transform import resize
|
11 |
+
from skimage import img_as_ubyte
|
12 |
+
import torch
|
13 |
+
from sync_batchnorm import DataParallelWithCallback
|
14 |
+
from modules.generator import OcclusionAwareGenerator
|
15 |
+
from modules.keypoint_detector import KPDetector
|
16 |
+
from animate import normalize_kp
|
17 |
+
from scipy.spatial import ConvexHull
|
18 |
+
|
19 |
+
|
20 |
+
if sys.version_info[0] < 3:
|
21 |
+
raise Exception("You must use Python 3 or higher. Recommended version is Python 3.7")
|
22 |
+
|
23 |
+
|
24 |
+
def load_checkpoints(config_path, checkpoint_path, cpu=False):
|
25 |
+
|
26 |
+
with open(config_path) as f:
|
27 |
+
config = yaml.load(f)
|
28 |
+
|
29 |
+
generator = OcclusionAwareGenerator(**config['model_params']['generator_params'],
|
30 |
+
**config['model_params']['common_params'])
|
31 |
+
if not cpu:
|
32 |
+
generator.cuda()
|
33 |
+
|
34 |
+
kp_detector = KPDetector(**config['model_params']['kp_detector_params'],
|
35 |
+
**config['model_params']['common_params'])
|
36 |
+
if not cpu:
|
37 |
+
kp_detector.cuda()
|
38 |
+
|
39 |
+
if cpu:
|
40 |
+
checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
|
41 |
+
else:
|
42 |
+
checkpoint = torch.load(checkpoint_path)
|
43 |
+
|
44 |
+
generator.load_state_dict(checkpoint['generator'])
|
45 |
+
kp_detector.load_state_dict(checkpoint['kp_detector'])
|
46 |
+
|
47 |
+
if not cpu:
|
48 |
+
generator = DataParallelWithCallback(generator)
|
49 |
+
kp_detector = DataParallelWithCallback(kp_detector)
|
50 |
+
|
51 |
+
generator.eval()
|
52 |
+
kp_detector.eval()
|
53 |
+
|
54 |
+
return generator, kp_detector
|
55 |
+
|
56 |
+
def make_animation(source_image, driving_video, generator, kp_detector, relative=True, adapt_movement_scale=True, cpu=False):
|
57 |
+
with torch.no_grad():
|
58 |
+
predictions = []
|
59 |
+
source = torch.tensor(source_image[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2)
|
60 |
+
if not cpu:
|
61 |
+
source = source.cuda()
|
62 |
+
driving = torch.tensor(np.array(driving_video)[np.newaxis].astype(np.float32)).permute(0, 4, 1, 2, 3)
|
63 |
+
kp_source = kp_detector(source)
|
64 |
+
kp_driving_initial = kp_detector(driving[:, :, 0])
|
65 |
+
|
66 |
+
for frame_idx in tqdm(range(driving.shape[2])):
|
67 |
+
|
68 |
+
driving_frame = driving[:, :, frame_idx]
|
69 |
+
if not cpu:
|
70 |
+
driving_frame = driving_frame.cuda()
|
71 |
+
kp_driving = kp_detector(driving_frame)
|
72 |
+
kp_norm = normalize_kp(kp_source=kp_source, kp_driving=kp_driving,
|
73 |
+
kp_driving_initial=kp_driving_initial, use_relative_movement=relative,
|
74 |
+
use_relative_jacobian=relative, adapt_movement_scale=adapt_movement_scale)
|
75 |
+
out = generator(source, kp_source=kp_source, kp_driving=kp_norm)
|
76 |
+
|
77 |
+
predictions.append(np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0])
|
78 |
+
return predictions
|
79 |
+
|
80 |
+
def find_best_frame(source, driving, cpu=False):
|
81 |
+
import face_alignment
|
82 |
+
|
83 |
+
def normalize_kp(kp):
|
84 |
+
kp = kp - kp.mean(axis=0, keepdims=True)
|
85 |
+
area = ConvexHull(kp[:, :2]).volume
|
86 |
+
area = np.sqrt(area)
|
87 |
+
kp[:, :2] = kp[:, :2] / area
|
88 |
+
return kp
|
89 |
+
|
90 |
+
fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=True,
|
91 |
+
device='cpu' if cpu else 'cuda')
|
92 |
+
kp_source = fa.get_landmarks(255 * source)[0]
|
93 |
+
kp_source = normalize_kp(kp_source)
|
94 |
+
norm = float('inf')
|
95 |
+
frame_num = 0
|
96 |
+
for i, image in tqdm(enumerate(driving)):
|
97 |
+
kp_driving = fa.get_landmarks(255 * image)[0]
|
98 |
+
kp_driving = normalize_kp(kp_driving)
|
99 |
+
new_norm = (np.abs(kp_source - kp_driving) ** 2).sum()
|
100 |
+
if new_norm < norm:
|
101 |
+
norm = new_norm
|
102 |
+
frame_num = i
|
103 |
+
return frame_num
|
104 |
+
|
105 |
+
if __name__ == "__main__":
|
106 |
+
parser = ArgumentParser()
|
107 |
+
parser.add_argument("--config", required=True, help="path to config")
|
108 |
+
parser.add_argument("--checkpoint", default='vox-cpk.pth.tar', help="path to checkpoint to restore")
|
109 |
+
|
110 |
+
parser.add_argument("--source_image", default='sup-mat/source.png', help="path to source image")
|
111 |
+
parser.add_argument("--driving_video", default='sup-mat/source.png', help="path to driving video")
|
112 |
+
parser.add_argument("--result_video", default='result.mp4', help="path to output")
|
113 |
+
|
114 |
+
parser.add_argument("--relative", dest="relative", action="store_true", help="use relative or absolute keypoint coordinates")
|
115 |
+
parser.add_argument("--adapt_scale", dest="adapt_scale", action="store_true", help="adapt movement scale based on convex hull of keypoints")
|
116 |
+
|
117 |
+
parser.add_argument("--find_best_frame", dest="find_best_frame", action="store_true",
|
118 |
+
help="Generate from the frame that is the most alligned with source. (Only for faces, requires face_aligment lib)")
|
119 |
+
|
120 |
+
parser.add_argument("--best_frame", dest="best_frame", type=int, default=None,
|
121 |
+
help="Set frame to start from.")
|
122 |
+
|
123 |
+
parser.add_argument("--cpu", dest="cpu", action="store_true", help="cpu mode.")
|
124 |
+
|
125 |
+
|
126 |
+
parser.set_defaults(relative=False)
|
127 |
+
parser.set_defaults(adapt_scale=False)
|
128 |
+
|
129 |
+
opt = parser.parse_args()
|
130 |
+
|
131 |
+
source_image = imageio.imread(opt.source_image)
|
132 |
+
reader = imageio.get_reader(opt.driving_video)
|
133 |
+
fps = reader.get_meta_data()['fps']
|
134 |
+
driving_video = []
|
135 |
+
try:
|
136 |
+
for im in reader:
|
137 |
+
driving_video.append(im)
|
138 |
+
except RuntimeError:
|
139 |
+
pass
|
140 |
+
reader.close()
|
141 |
+
|
142 |
+
source_image = resize(source_image, (256, 256))[..., :3]
|
143 |
+
driving_video = [resize(frame, (256, 256))[..., :3] for frame in driving_video]
|
144 |
+
generator, kp_detector = load_checkpoints(config_path=opt.config, checkpoint_path=opt.checkpoint, cpu=opt.cpu)
|
145 |
+
|
146 |
+
if opt.find_best_frame or opt.best_frame is not None:
|
147 |
+
i = opt.best_frame if opt.best_frame is not None else find_best_frame(source_image, driving_video, cpu=opt.cpu)
|
148 |
+
print ("Best frame: " + str(i))
|
149 |
+
driving_forward = driving_video[i:]
|
150 |
+
driving_backward = driving_video[:(i+1)][::-1]
|
151 |
+
predictions_forward = make_animation(source_image, driving_forward, generator, kp_detector, relative=opt.relative, adapt_movement_scale=opt.adapt_scale, cpu=opt.cpu)
|
152 |
+
predictions_backward = make_animation(source_image, driving_backward, generator, kp_detector, relative=opt.relative, adapt_movement_scale=opt.adapt_scale, cpu=opt.cpu)
|
153 |
+
predictions = predictions_backward[::-1] + predictions_forward[1:]
|
154 |
+
else:
|
155 |
+
predictions = make_animation(source_image, driving_video, generator, kp_detector, relative=opt.relative, adapt_movement_scale=opt.adapt_scale, cpu=opt.cpu)
|
156 |
+
imageio.mimsave(opt.result_video, [img_as_ubyte(frame) for frame in predictions], fps=fps)
|
157 |
+
|
discordbot.py
ADDED
File without changes
|
frames_dataset.py
ADDED
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from skimage import io, img_as_float32
|
3 |
+
from skimage.color import gray2rgb
|
4 |
+
from sklearn.model_selection import train_test_split
|
5 |
+
from imageio import mimread
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
from torch.utils.data import Dataset
|
9 |
+
import pandas as pd
|
10 |
+
from augmentation import AllAugmentationTransform
|
11 |
+
import glob
|
12 |
+
|
13 |
+
|
14 |
+
def read_video(name, frame_shape):
|
15 |
+
"""
|
16 |
+
Read video which can be:
|
17 |
+
- an image of concatenated frames
|
18 |
+
- '.mp4' and'.gif'
|
19 |
+
- folder with videos
|
20 |
+
"""
|
21 |
+
|
22 |
+
if os.path.isdir(name):
|
23 |
+
frames = sorted(os.listdir(name))
|
24 |
+
num_frames = len(frames)
|
25 |
+
video_array = np.array(
|
26 |
+
[img_as_float32(io.imread(os.path.join(name, frames[idx]))) for idx in range(num_frames)])
|
27 |
+
elif name.lower().endswith('.png') or name.lower().endswith('.jpg'):
|
28 |
+
image = io.imread(name)
|
29 |
+
|
30 |
+
if len(image.shape) == 2 or image.shape[2] == 1:
|
31 |
+
image = gray2rgb(image)
|
32 |
+
|
33 |
+
if image.shape[2] == 4:
|
34 |
+
image = image[..., :3]
|
35 |
+
|
36 |
+
image = img_as_float32(image)
|
37 |
+
|
38 |
+
video_array = np.moveaxis(image, 1, 0)
|
39 |
+
|
40 |
+
video_array = video_array.reshape((-1,) + frame_shape)
|
41 |
+
video_array = np.moveaxis(video_array, 1, 2)
|
42 |
+
elif name.lower().endswith('.gif') or name.lower().endswith('.mp4') or name.lower().endswith('.mov'):
|
43 |
+
video = np.array(mimread(name))
|
44 |
+
if len(video.shape) == 3:
|
45 |
+
video = np.array([gray2rgb(frame) for frame in video])
|
46 |
+
if video.shape[-1] == 4:
|
47 |
+
video = video[..., :3]
|
48 |
+
video_array = img_as_float32(video)
|
49 |
+
else:
|
50 |
+
raise Exception("Unknown file extensions %s" % name)
|
51 |
+
|
52 |
+
return video_array
|
53 |
+
|
54 |
+
|
55 |
+
class FramesDataset(Dataset):
|
56 |
+
"""
|
57 |
+
Dataset of videos, each video can be represented as:
|
58 |
+
- an image of concatenated frames
|
59 |
+
- '.mp4' or '.gif'
|
60 |
+
- folder with all frames
|
61 |
+
"""
|
62 |
+
|
63 |
+
def __init__(self, root_dir, frame_shape=(256, 256, 3), id_sampling=False, is_train=True,
|
64 |
+
random_seed=0, pairs_list=None, augmentation_params=None):
|
65 |
+
self.root_dir = root_dir
|
66 |
+
self.videos = os.listdir(root_dir)
|
67 |
+
self.frame_shape = tuple(frame_shape)
|
68 |
+
self.pairs_list = pairs_list
|
69 |
+
self.id_sampling = id_sampling
|
70 |
+
if os.path.exists(os.path.join(root_dir, 'train')):
|
71 |
+
assert os.path.exists(os.path.join(root_dir, 'test'))
|
72 |
+
print("Use predefined train-test split.")
|
73 |
+
if id_sampling:
|
74 |
+
train_videos = {os.path.basename(video).split('#')[0] for video in
|
75 |
+
os.listdir(os.path.join(root_dir, 'train'))}
|
76 |
+
train_videos = list(train_videos)
|
77 |
+
else:
|
78 |
+
train_videos = os.listdir(os.path.join(root_dir, 'train'))
|
79 |
+
test_videos = os.listdir(os.path.join(root_dir, 'test'))
|
80 |
+
self.root_dir = os.path.join(self.root_dir, 'train' if is_train else 'test')
|
81 |
+
else:
|
82 |
+
print("Use random train-test split.")
|
83 |
+
train_videos, test_videos = train_test_split(self.videos, random_state=random_seed, test_size=0.2)
|
84 |
+
|
85 |
+
if is_train:
|
86 |
+
self.videos = train_videos
|
87 |
+
else:
|
88 |
+
self.videos = test_videos
|
89 |
+
|
90 |
+
self.is_train = is_train
|
91 |
+
|
92 |
+
if self.is_train:
|
93 |
+
self.transform = AllAugmentationTransform(**augmentation_params)
|
94 |
+
else:
|
95 |
+
self.transform = None
|
96 |
+
|
97 |
+
def __len__(self):
|
98 |
+
return len(self.videos)
|
99 |
+
|
100 |
+
def __getitem__(self, idx):
|
101 |
+
if self.is_train and self.id_sampling:
|
102 |
+
name = self.videos[idx]
|
103 |
+
path = np.random.choice(glob.glob(os.path.join(self.root_dir, name + '*.mp4')))
|
104 |
+
else:
|
105 |
+
name = self.videos[idx]
|
106 |
+
path = os.path.join(self.root_dir, name)
|
107 |
+
|
108 |
+
video_name = os.path.basename(path)
|
109 |
+
|
110 |
+
if self.is_train and os.path.isdir(path):
|
111 |
+
frames = os.listdir(path)
|
112 |
+
num_frames = len(frames)
|
113 |
+
frame_idx = np.sort(np.random.choice(num_frames, replace=True, size=2))
|
114 |
+
video_array = [img_as_float32(io.imread(os.path.join(path, frames[idx]))) for idx in frame_idx]
|
115 |
+
else:
|
116 |
+
video_array = read_video(path, frame_shape=self.frame_shape)
|
117 |
+
num_frames = len(video_array)
|
118 |
+
frame_idx = np.sort(np.random.choice(num_frames, replace=True, size=2)) if self.is_train else range(
|
119 |
+
num_frames)
|
120 |
+
video_array = video_array[frame_idx]
|
121 |
+
|
122 |
+
if self.transform is not None:
|
123 |
+
video_array = self.transform(video_array)
|
124 |
+
|
125 |
+
out = {}
|
126 |
+
if self.is_train:
|
127 |
+
source = np.array(video_array[0], dtype='float32')
|
128 |
+
driving = np.array(video_array[1], dtype='float32')
|
129 |
+
|
130 |
+
out['driving'] = driving.transpose((2, 0, 1))
|
131 |
+
out['source'] = source.transpose((2, 0, 1))
|
132 |
+
else:
|
133 |
+
video = np.array(video_array, dtype='float32')
|
134 |
+
out['video'] = video.transpose((3, 0, 1, 2))
|
135 |
+
|
136 |
+
out['name'] = video_name
|
137 |
+
|
138 |
+
return out
|
139 |
+
|
140 |
+
|
141 |
+
class DatasetRepeater(Dataset):
|
142 |
+
"""
|
143 |
+
Pass several times over the same dataset for better i/o performance
|
144 |
+
"""
|
145 |
+
|
146 |
+
def __init__(self, dataset, num_repeats=100):
|
147 |
+
self.dataset = dataset
|
148 |
+
self.num_repeats = num_repeats
|
149 |
+
|
150 |
+
def __len__(self):
|
151 |
+
return self.num_repeats * self.dataset.__len__()
|
152 |
+
|
153 |
+
def __getitem__(self, idx):
|
154 |
+
return self.dataset[idx % self.dataset.__len__()]
|
155 |
+
|
156 |
+
|
157 |
+
class PairedDataset(Dataset):
|
158 |
+
"""
|
159 |
+
Dataset of pairs for animation.
|
160 |
+
"""
|
161 |
+
|
162 |
+
def __init__(self, initial_dataset, number_of_pairs, seed=0):
|
163 |
+
self.initial_dataset = initial_dataset
|
164 |
+
pairs_list = self.initial_dataset.pairs_list
|
165 |
+
|
166 |
+
np.random.seed(seed)
|
167 |
+
|
168 |
+
if pairs_list is None:
|
169 |
+
max_idx = min(number_of_pairs, len(initial_dataset))
|
170 |
+
nx, ny = max_idx, max_idx
|
171 |
+
xy = np.mgrid[:nx, :ny].reshape(2, -1).T
|
172 |
+
number_of_pairs = min(xy.shape[0], number_of_pairs)
|
173 |
+
self.pairs = xy.take(np.random.choice(xy.shape[0], number_of_pairs, replace=False), axis=0)
|
174 |
+
else:
|
175 |
+
videos = self.initial_dataset.videos
|
176 |
+
name_to_index = {name: index for index, name in enumerate(videos)}
|
177 |
+
pairs = pd.read_csv(pairs_list)
|
178 |
+
pairs = pairs[np.logical_and(pairs['source'].isin(videos), pairs['driving'].isin(videos))]
|
179 |
+
|
180 |
+
number_of_pairs = min(pairs.shape[0], number_of_pairs)
|
181 |
+
self.pairs = []
|
182 |
+
self.start_frames = []
|
183 |
+
for ind in range(number_of_pairs):
|
184 |
+
self.pairs.append(
|
185 |
+
(name_to_index[pairs['driving'].iloc[ind]], name_to_index[pairs['source'].iloc[ind]]))
|
186 |
+
|
187 |
+
def __len__(self):
|
188 |
+
return len(self.pairs)
|
189 |
+
|
190 |
+
def __getitem__(self, idx):
|
191 |
+
pair = self.pairs[idx]
|
192 |
+
first = self.initial_dataset[pair[0]]
|
193 |
+
second = self.initial_dataset[pair[1]]
|
194 |
+
first = {'driving_' + key: value for key, value in first.items()}
|
195 |
+
second = {'source_' + key: value for key, value in second.items()}
|
196 |
+
|
197 |
+
return {**first, **second}
|
generated.mp4
ADDED
Binary file (130 kB). View file
|
|
got-05.jpg
ADDED
Git LFS Details
|
logger.py
ADDED
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import imageio
|
5 |
+
|
6 |
+
import os
|
7 |
+
from skimage.draw import circle
|
8 |
+
|
9 |
+
import matplotlib.pyplot as plt
|
10 |
+
import collections
|
11 |
+
|
12 |
+
|
13 |
+
class Logger:
|
14 |
+
def __init__(self, log_dir, checkpoint_freq=100, visualizer_params=None, zfill_num=8, log_file_name='log.txt'):
|
15 |
+
|
16 |
+
self.loss_list = []
|
17 |
+
self.cpk_dir = log_dir
|
18 |
+
self.visualizations_dir = os.path.join(log_dir, 'train-vis')
|
19 |
+
if not os.path.exists(self.visualizations_dir):
|
20 |
+
os.makedirs(self.visualizations_dir)
|
21 |
+
self.log_file = open(os.path.join(log_dir, log_file_name), 'a')
|
22 |
+
self.zfill_num = zfill_num
|
23 |
+
self.visualizer = Visualizer(**visualizer_params)
|
24 |
+
self.checkpoint_freq = checkpoint_freq
|
25 |
+
self.epoch = 0
|
26 |
+
self.best_loss = float('inf')
|
27 |
+
self.names = None
|
28 |
+
|
29 |
+
def log_scores(self, loss_names):
|
30 |
+
loss_mean = np.array(self.loss_list).mean(axis=0)
|
31 |
+
|
32 |
+
loss_string = "; ".join(["%s - %.5f" % (name, value) for name, value in zip(loss_names, loss_mean)])
|
33 |
+
loss_string = str(self.epoch).zfill(self.zfill_num) + ") " + loss_string
|
34 |
+
|
35 |
+
print(loss_string, file=self.log_file)
|
36 |
+
self.loss_list = []
|
37 |
+
self.log_file.flush()
|
38 |
+
|
39 |
+
def visualize_rec(self, inp, out):
|
40 |
+
image = self.visualizer.visualize(inp['driving'], inp['source'], out)
|
41 |
+
imageio.imsave(os.path.join(self.visualizations_dir, "%s-rec.png" % str(self.epoch).zfill(self.zfill_num)), image)
|
42 |
+
|
43 |
+
def save_cpk(self, emergent=False):
|
44 |
+
cpk = {k: v.state_dict() for k, v in self.models.items()}
|
45 |
+
cpk['epoch'] = self.epoch
|
46 |
+
cpk_path = os.path.join(self.cpk_dir, '%s-checkpoint.pth.tar' % str(self.epoch).zfill(self.zfill_num))
|
47 |
+
if not (os.path.exists(cpk_path) and emergent):
|
48 |
+
torch.save(cpk, cpk_path)
|
49 |
+
|
50 |
+
@staticmethod
|
51 |
+
def load_cpk(checkpoint_path, generator=None, discriminator=None, kp_detector=None,
|
52 |
+
optimizer_generator=None, optimizer_discriminator=None, optimizer_kp_detector=None):
|
53 |
+
checkpoint = torch.load(checkpoint_path)
|
54 |
+
if generator is not None:
|
55 |
+
generator.load_state_dict(checkpoint['generator'])
|
56 |
+
if kp_detector is not None:
|
57 |
+
kp_detector.load_state_dict(checkpoint['kp_detector'])
|
58 |
+
if discriminator is not None:
|
59 |
+
try:
|
60 |
+
discriminator.load_state_dict(checkpoint['discriminator'])
|
61 |
+
except:
|
62 |
+
print ('No discriminator in the state-dict. Dicriminator will be randomly initialized')
|
63 |
+
if optimizer_generator is not None:
|
64 |
+
optimizer_generator.load_state_dict(checkpoint['optimizer_generator'])
|
65 |
+
if optimizer_discriminator is not None:
|
66 |
+
try:
|
67 |
+
optimizer_discriminator.load_state_dict(checkpoint['optimizer_discriminator'])
|
68 |
+
except RuntimeError as e:
|
69 |
+
print ('No discriminator optimizer in the state-dict. Optimizer will be not initialized')
|
70 |
+
if optimizer_kp_detector is not None:
|
71 |
+
optimizer_kp_detector.load_state_dict(checkpoint['optimizer_kp_detector'])
|
72 |
+
|
73 |
+
return checkpoint['epoch']
|
74 |
+
|
75 |
+
def __enter__(self):
|
76 |
+
return self
|
77 |
+
|
78 |
+
def __exit__(self, exc_type, exc_val, exc_tb):
|
79 |
+
if 'models' in self.__dict__:
|
80 |
+
self.save_cpk()
|
81 |
+
self.log_file.close()
|
82 |
+
|
83 |
+
def log_iter(self, losses):
|
84 |
+
losses = collections.OrderedDict(losses.items())
|
85 |
+
if self.names is None:
|
86 |
+
self.names = list(losses.keys())
|
87 |
+
self.loss_list.append(list(losses.values()))
|
88 |
+
|
89 |
+
def log_epoch(self, epoch, models, inp, out):
|
90 |
+
self.epoch = epoch
|
91 |
+
self.models = models
|
92 |
+
if (self.epoch + 1) % self.checkpoint_freq == 0:
|
93 |
+
self.save_cpk()
|
94 |
+
self.log_scores(self.names)
|
95 |
+
self.visualize_rec(inp, out)
|
96 |
+
|
97 |
+
|
98 |
+
class Visualizer:
|
99 |
+
def __init__(self, kp_size=5, draw_border=False, colormap='gist_rainbow'):
|
100 |
+
self.kp_size = kp_size
|
101 |
+
self.draw_border = draw_border
|
102 |
+
self.colormap = plt.get_cmap(colormap)
|
103 |
+
|
104 |
+
def draw_image_with_kp(self, image, kp_array):
|
105 |
+
image = np.copy(image)
|
106 |
+
spatial_size = np.array(image.shape[:2][::-1])[np.newaxis]
|
107 |
+
kp_array = spatial_size * (kp_array + 1) / 2
|
108 |
+
num_kp = kp_array.shape[0]
|
109 |
+
for kp_ind, kp in enumerate(kp_array):
|
110 |
+
rr, cc = circle(kp[1], kp[0], self.kp_size, shape=image.shape[:2])
|
111 |
+
image[rr, cc] = np.array(self.colormap(kp_ind / num_kp))[:3]
|
112 |
+
return image
|
113 |
+
|
114 |
+
def create_image_column_with_kp(self, images, kp):
|
115 |
+
image_array = np.array([self.draw_image_with_kp(v, k) for v, k in zip(images, kp)])
|
116 |
+
return self.create_image_column(image_array)
|
117 |
+
|
118 |
+
def create_image_column(self, images):
|
119 |
+
if self.draw_border:
|
120 |
+
images = np.copy(images)
|
121 |
+
images[:, :, [0, -1]] = (1, 1, 1)
|
122 |
+
images[:, :, [0, -1]] = (1, 1, 1)
|
123 |
+
return np.concatenate(list(images), axis=0)
|
124 |
+
|
125 |
+
def create_image_grid(self, *args):
|
126 |
+
out = []
|
127 |
+
for arg in args:
|
128 |
+
if type(arg) == tuple:
|
129 |
+
out.append(self.create_image_column_with_kp(arg[0], arg[1]))
|
130 |
+
else:
|
131 |
+
out.append(self.create_image_column(arg))
|
132 |
+
return np.concatenate(out, axis=1)
|
133 |
+
|
134 |
+
def visualize(self, driving, source, out):
|
135 |
+
images = []
|
136 |
+
|
137 |
+
# Source image with keypoints
|
138 |
+
source = source.data.cpu()
|
139 |
+
kp_source = out['kp_source']['value'].data.cpu().numpy()
|
140 |
+
source = np.transpose(source, [0, 2, 3, 1])
|
141 |
+
images.append((source, kp_source))
|
142 |
+
|
143 |
+
# Equivariance visualization
|
144 |
+
if 'transformed_frame' in out:
|
145 |
+
transformed = out['transformed_frame'].data.cpu().numpy()
|
146 |
+
transformed = np.transpose(transformed, [0, 2, 3, 1])
|
147 |
+
transformed_kp = out['transformed_kp']['value'].data.cpu().numpy()
|
148 |
+
images.append((transformed, transformed_kp))
|
149 |
+
|
150 |
+
# Driving image with keypoints
|
151 |
+
kp_driving = out['kp_driving']['value'].data.cpu().numpy()
|
152 |
+
driving = driving.data.cpu().numpy()
|
153 |
+
driving = np.transpose(driving, [0, 2, 3, 1])
|
154 |
+
images.append((driving, kp_driving))
|
155 |
+
|
156 |
+
# Deformed image
|
157 |
+
if 'deformed' in out:
|
158 |
+
deformed = out['deformed'].data.cpu().numpy()
|
159 |
+
deformed = np.transpose(deformed, [0, 2, 3, 1])
|
160 |
+
images.append(deformed)
|
161 |
+
|
162 |
+
# Result with and without keypoints
|
163 |
+
prediction = out['prediction'].data.cpu().numpy()
|
164 |
+
prediction = np.transpose(prediction, [0, 2, 3, 1])
|
165 |
+
if 'kp_norm' in out:
|
166 |
+
kp_norm = out['kp_norm']['value'].data.cpu().numpy()
|
167 |
+
images.append((prediction, kp_norm))
|
168 |
+
images.append(prediction)
|
169 |
+
|
170 |
+
|
171 |
+
## Occlusion map
|
172 |
+
if 'occlusion_map' in out:
|
173 |
+
occlusion_map = out['occlusion_map'].data.cpu().repeat(1, 3, 1, 1)
|
174 |
+
occlusion_map = F.interpolate(occlusion_map, size=source.shape[1:3]).numpy()
|
175 |
+
occlusion_map = np.transpose(occlusion_map, [0, 2, 3, 1])
|
176 |
+
images.append(occlusion_map)
|
177 |
+
|
178 |
+
# Deformed images according to each individual transform
|
179 |
+
if 'sparse_deformed' in out:
|
180 |
+
full_mask = []
|
181 |
+
for i in range(out['sparse_deformed'].shape[1]):
|
182 |
+
image = out['sparse_deformed'][:, i].data.cpu()
|
183 |
+
image = F.interpolate(image, size=source.shape[1:3])
|
184 |
+
mask = out['mask'][:, i:(i+1)].data.cpu().repeat(1, 3, 1, 1)
|
185 |
+
mask = F.interpolate(mask, size=source.shape[1:3])
|
186 |
+
image = np.transpose(image.numpy(), (0, 2, 3, 1))
|
187 |
+
mask = np.transpose(mask.numpy(), (0, 2, 3, 1))
|
188 |
+
|
189 |
+
if i != 0:
|
190 |
+
color = np.array(self.colormap((i - 1) / (out['sparse_deformed'].shape[1] - 1)))[:3]
|
191 |
+
else:
|
192 |
+
color = np.array((0, 0, 0))
|
193 |
+
|
194 |
+
color = color.reshape((1, 1, 1, 3))
|
195 |
+
|
196 |
+
images.append(image)
|
197 |
+
if i != 0:
|
198 |
+
images.append(mask * color)
|
199 |
+
else:
|
200 |
+
images.append(mask)
|
201 |
+
|
202 |
+
full_mask.append(mask * color)
|
203 |
+
|
204 |
+
images.append(sum(full_mask))
|
205 |
+
|
206 |
+
image = self.create_image_grid(*images)
|
207 |
+
image = (255 * image).astype(np.uint8)
|
208 |
+
return image
|
reconstruction.py
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from tqdm import tqdm
|
3 |
+
import torch
|
4 |
+
from torch.utils.data import DataLoader
|
5 |
+
from logger import Logger, Visualizer
|
6 |
+
import numpy as np
|
7 |
+
import imageio
|
8 |
+
from sync_batchnorm import DataParallelWithCallback
|
9 |
+
|
10 |
+
|
11 |
+
def reconstruction(config, generator, kp_detector, checkpoint, log_dir, dataset):
|
12 |
+
png_dir = os.path.join(log_dir, 'reconstruction/png')
|
13 |
+
log_dir = os.path.join(log_dir, 'reconstruction')
|
14 |
+
|
15 |
+
if checkpoint is not None:
|
16 |
+
Logger.load_cpk(checkpoint, generator=generator, kp_detector=kp_detector)
|
17 |
+
else:
|
18 |
+
raise AttributeError("Checkpoint should be specified for mode='reconstruction'.")
|
19 |
+
dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)
|
20 |
+
|
21 |
+
if not os.path.exists(log_dir):
|
22 |
+
os.makedirs(log_dir)
|
23 |
+
|
24 |
+
if not os.path.exists(png_dir):
|
25 |
+
os.makedirs(png_dir)
|
26 |
+
|
27 |
+
loss_list = []
|
28 |
+
if torch.cuda.is_available():
|
29 |
+
generator = DataParallelWithCallback(generator)
|
30 |
+
kp_detector = DataParallelWithCallback(kp_detector)
|
31 |
+
|
32 |
+
generator.eval()
|
33 |
+
kp_detector.eval()
|
34 |
+
|
35 |
+
for it, x in tqdm(enumerate(dataloader)):
|
36 |
+
if config['reconstruction_params']['num_videos'] is not None:
|
37 |
+
if it > config['reconstruction_params']['num_videos']:
|
38 |
+
break
|
39 |
+
with torch.no_grad():
|
40 |
+
predictions = []
|
41 |
+
visualizations = []
|
42 |
+
if torch.cuda.is_available():
|
43 |
+
x['video'] = x['video'].cuda()
|
44 |
+
kp_source = kp_detector(x['video'][:, :, 0])
|
45 |
+
for frame_idx in range(x['video'].shape[2]):
|
46 |
+
source = x['video'][:, :, 0]
|
47 |
+
driving = x['video'][:, :, frame_idx]
|
48 |
+
kp_driving = kp_detector(driving)
|
49 |
+
out = generator(source, kp_source=kp_source, kp_driving=kp_driving)
|
50 |
+
out['kp_source'] = kp_source
|
51 |
+
out['kp_driving'] = kp_driving
|
52 |
+
del out['sparse_deformed']
|
53 |
+
predictions.append(np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0])
|
54 |
+
|
55 |
+
visualization = Visualizer(**config['visualizer_params']).visualize(source=source,
|
56 |
+
driving=driving, out=out)
|
57 |
+
visualizations.append(visualization)
|
58 |
+
|
59 |
+
loss_list.append(torch.abs(out['prediction'] - driving).mean().cpu().numpy())
|
60 |
+
|
61 |
+
predictions = np.concatenate(predictions, axis=1)
|
62 |
+
imageio.imsave(os.path.join(png_dir, x['name'][0] + '.png'), (255 * predictions).astype(np.uint8))
|
63 |
+
|
64 |
+
image_name = x['name'][0] + config['reconstruction_params']['format']
|
65 |
+
imageio.mimsave(os.path.join(log_dir, image_name), visualizations)
|
66 |
+
|
67 |
+
print("Reconstruction loss: %s" % np.mean(loss_list))
|
run.py
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import matplotlib
|
2 |
+
|
3 |
+
matplotlib.use('Agg')
|
4 |
+
|
5 |
+
import os, sys
|
6 |
+
import yaml
|
7 |
+
from argparse import ArgumentParser
|
8 |
+
from time import gmtime, strftime
|
9 |
+
from shutil import copy
|
10 |
+
|
11 |
+
from frames_dataset import FramesDataset
|
12 |
+
|
13 |
+
from modules.generator import OcclusionAwareGenerator
|
14 |
+
from modules.discriminator import MultiScaleDiscriminator
|
15 |
+
from modules.keypoint_detector import KPDetector
|
16 |
+
|
17 |
+
import torch
|
18 |
+
|
19 |
+
from train import train
|
20 |
+
from reconstruction import reconstruction
|
21 |
+
from animate import animate
|
22 |
+
|
23 |
+
if __name__ == "__main__":
|
24 |
+
|
25 |
+
if sys.version_info[0] < 3:
|
26 |
+
raise Exception("You must use Python 3 or higher. Recommended version is Python 3.7")
|
27 |
+
|
28 |
+
parser = ArgumentParser()
|
29 |
+
parser.add_argument("--config", required=True, help="path to config")
|
30 |
+
parser.add_argument("--mode", default="train", choices=["train", "reconstruction", "animate"])
|
31 |
+
parser.add_argument("--log_dir", default='log', help="path to log into")
|
32 |
+
parser.add_argument("--checkpoint", default=None, help="path to checkpoint to restore")
|
33 |
+
parser.add_argument("--device_ids", default="0", type=lambda x: list(map(int, x.split(','))),
|
34 |
+
help="Names of the devices comma separated.")
|
35 |
+
parser.add_argument("--verbose", dest="verbose", action="store_true", help="Print model architecture")
|
36 |
+
parser.set_defaults(verbose=False)
|
37 |
+
|
38 |
+
opt = parser.parse_args()
|
39 |
+
with open(opt.config) as f:
|
40 |
+
config = yaml.load(f)
|
41 |
+
|
42 |
+
if opt.checkpoint is not None:
|
43 |
+
log_dir = os.path.join(*os.path.split(opt.checkpoint)[:-1])
|
44 |
+
else:
|
45 |
+
log_dir = os.path.join(opt.log_dir, os.path.basename(opt.config).split('.')[0])
|
46 |
+
log_dir += ' ' + strftime("%d_%m_%y_%H.%M.%S", gmtime())
|
47 |
+
|
48 |
+
generator = OcclusionAwareGenerator(**config['model_params']['generator_params'],
|
49 |
+
**config['model_params']['common_params'])
|
50 |
+
|
51 |
+
if torch.cuda.is_available():
|
52 |
+
generator.to(opt.device_ids[0])
|
53 |
+
if opt.verbose:
|
54 |
+
print(generator)
|
55 |
+
|
56 |
+
discriminator = MultiScaleDiscriminator(**config['model_params']['discriminator_params'],
|
57 |
+
**config['model_params']['common_params'])
|
58 |
+
if torch.cuda.is_available():
|
59 |
+
discriminator.to(opt.device_ids[0])
|
60 |
+
if opt.verbose:
|
61 |
+
print(discriminator)
|
62 |
+
|
63 |
+
kp_detector = KPDetector(**config['model_params']['kp_detector_params'],
|
64 |
+
**config['model_params']['common_params'])
|
65 |
+
|
66 |
+
if torch.cuda.is_available():
|
67 |
+
kp_detector.to(opt.device_ids[0])
|
68 |
+
|
69 |
+
if opt.verbose:
|
70 |
+
print(kp_detector)
|
71 |
+
|
72 |
+
dataset = FramesDataset(is_train=(opt.mode == 'train'), **config['dataset_params'])
|
73 |
+
|
74 |
+
if not os.path.exists(log_dir):
|
75 |
+
os.makedirs(log_dir)
|
76 |
+
if not os.path.exists(os.path.join(log_dir, os.path.basename(opt.config))):
|
77 |
+
copy(opt.config, log_dir)
|
78 |
+
|
79 |
+
if opt.mode == 'train':
|
80 |
+
print("Training...")
|
81 |
+
train(config, generator, discriminator, kp_detector, opt.checkpoint, log_dir, dataset, opt.device_ids)
|
82 |
+
elif opt.mode == 'reconstruction':
|
83 |
+
print("Reconstruction...")
|
84 |
+
reconstruction(config, generator, kp_detector, opt.checkpoint, log_dir, dataset)
|
85 |
+
elif opt.mode == 'animate':
|
86 |
+
print("Animate...")
|
87 |
+
animate(config, generator, kp_detector, opt.checkpoint, log_dir, dataset)
|
sdkan.png
ADDED
train.py
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from tqdm import trange
|
2 |
+
import torch
|
3 |
+
|
4 |
+
from torch.utils.data import DataLoader
|
5 |
+
|
6 |
+
from logger import Logger
|
7 |
+
from modules.model import GeneratorFullModel, DiscriminatorFullModel
|
8 |
+
|
9 |
+
from torch.optim.lr_scheduler import MultiStepLR
|
10 |
+
|
11 |
+
from sync_batchnorm import DataParallelWithCallback
|
12 |
+
|
13 |
+
from frames_dataset import DatasetRepeater
|
14 |
+
|
15 |
+
|
16 |
+
def train(config, generator, discriminator, kp_detector, checkpoint, log_dir, dataset, device_ids):
|
17 |
+
train_params = config['train_params']
|
18 |
+
|
19 |
+
optimizer_generator = torch.optim.Adam(generator.parameters(), lr=train_params['lr_generator'], betas=(0.5, 0.999))
|
20 |
+
optimizer_discriminator = torch.optim.Adam(discriminator.parameters(), lr=train_params['lr_discriminator'], betas=(0.5, 0.999))
|
21 |
+
optimizer_kp_detector = torch.optim.Adam(kp_detector.parameters(), lr=train_params['lr_kp_detector'], betas=(0.5, 0.999))
|
22 |
+
|
23 |
+
if checkpoint is not None:
|
24 |
+
start_epoch = Logger.load_cpk(checkpoint, generator, discriminator, kp_detector,
|
25 |
+
optimizer_generator, optimizer_discriminator,
|
26 |
+
None if train_params['lr_kp_detector'] == 0 else optimizer_kp_detector)
|
27 |
+
else:
|
28 |
+
start_epoch = 0
|
29 |
+
|
30 |
+
scheduler_generator = MultiStepLR(optimizer_generator, train_params['epoch_milestones'], gamma=0.1,
|
31 |
+
last_epoch=start_epoch - 1)
|
32 |
+
scheduler_discriminator = MultiStepLR(optimizer_discriminator, train_params['epoch_milestones'], gamma=0.1,
|
33 |
+
last_epoch=start_epoch - 1)
|
34 |
+
scheduler_kp_detector = MultiStepLR(optimizer_kp_detector, train_params['epoch_milestones'], gamma=0.1,
|
35 |
+
last_epoch=-1 + start_epoch * (train_params['lr_kp_detector'] != 0))
|
36 |
+
|
37 |
+
if 'num_repeats' in train_params or train_params['num_repeats'] != 1:
|
38 |
+
dataset = DatasetRepeater(dataset, train_params['num_repeats'])
|
39 |
+
dataloader = DataLoader(dataset, batch_size=train_params['batch_size'], shuffle=True, num_workers=6, drop_last=True)
|
40 |
+
|
41 |
+
generator_full = GeneratorFullModel(kp_detector, generator, discriminator, train_params)
|
42 |
+
discriminator_full = DiscriminatorFullModel(kp_detector, generator, discriminator, train_params)
|
43 |
+
|
44 |
+
if torch.cuda.is_available():
|
45 |
+
generator_full = DataParallelWithCallback(generator_full, device_ids=device_ids)
|
46 |
+
discriminator_full = DataParallelWithCallback(discriminator_full, device_ids=device_ids)
|
47 |
+
|
48 |
+
with Logger(log_dir=log_dir, visualizer_params=config['visualizer_params'], checkpoint_freq=train_params['checkpoint_freq']) as logger:
|
49 |
+
for epoch in trange(start_epoch, train_params['num_epochs']):
|
50 |
+
for x in dataloader:
|
51 |
+
losses_generator, generated = generator_full(x)
|
52 |
+
|
53 |
+
loss_values = [val.mean() for val in losses_generator.values()]
|
54 |
+
loss = sum(loss_values)
|
55 |
+
|
56 |
+
loss.backward()
|
57 |
+
optimizer_generator.step()
|
58 |
+
optimizer_generator.zero_grad()
|
59 |
+
optimizer_kp_detector.step()
|
60 |
+
optimizer_kp_detector.zero_grad()
|
61 |
+
|
62 |
+
if train_params['loss_weights']['generator_gan'] != 0:
|
63 |
+
optimizer_discriminator.zero_grad()
|
64 |
+
losses_discriminator = discriminator_full(x, generated)
|
65 |
+
loss_values = [val.mean() for val in losses_discriminator.values()]
|
66 |
+
loss = sum(loss_values)
|
67 |
+
|
68 |
+
loss.backward()
|
69 |
+
optimizer_discriminator.step()
|
70 |
+
optimizer_discriminator.zero_grad()
|
71 |
+
else:
|
72 |
+
losses_discriminator = {}
|
73 |
+
|
74 |
+
losses_generator.update(losses_discriminator)
|
75 |
+
losses = {key: value.mean().detach().data.cpu().numpy() for key, value in losses_generator.items()}
|
76 |
+
logger.log_iter(losses=losses)
|
77 |
+
|
78 |
+
scheduler_generator.step()
|
79 |
+
scheduler_discriminator.step()
|
80 |
+
scheduler_kp_detector.step()
|
81 |
+
|
82 |
+
logger.log_epoch(epoch, {'generator': generator,
|
83 |
+
'discriminator': discriminator,
|
84 |
+
'kp_detector': kp_detector,
|
85 |
+
'optimizer_generator': optimizer_generator,
|
86 |
+
'optimizer_discriminator': optimizer_discriminator,
|
87 |
+
'optimizer_kp_detector': optimizer_kp_detector}, inp=x, out=generated)
|