maduvantha committed on
Commit 54bf1bc
1 Parent(s): eddf80e

Upload 14 files

Files changed (15)
  1. .gitattributes +1 -0
  2. 00.mp4 +0 -0
  3. animate.py +101 -0
  4. augmentation.py +345 -0
  5. crop-video.py +158 -0
  6. demo.py +157 -0
  7. discordbot.py +0 -0
  8. frames_dataset.py +197 -0
  9. generated.mp4 +0 -0
  10. got-05.jpg +3 -0
  11. logger.py +208 -0
  12. reconstruction.py +67 -0
  13. run.py +87 -0
  14. sdkan.png +0 -0
  15. train.py +87 -0
.gitattributes CHANGED
@@ -38,3 +38,4 @@ sup-mat/fashion-teaser.gif filter=lfs diff=lfs merge=lfs -text
 sup-mat/mgif-teaser.gif filter=lfs diff=lfs merge=lfs -text
 sup-mat/relative-demo.gif filter=lfs diff=lfs merge=lfs -text
 sup-mat/vox-teaser.gif filter=lfs diff=lfs merge=lfs -text
+got-05.jpg filter=lfs diff=lfs merge=lfs -text
00.mp4 ADDED
Binary file (527 kB)
animate.py ADDED
@@ -0,0 +1,101 @@
+import os
+from tqdm import tqdm
+
+import torch
+from torch.utils.data import DataLoader
+
+from frames_dataset import PairedDataset
+from logger import Logger, Visualizer
+import imageio
+from scipy.spatial import ConvexHull
+import numpy as np
+
+from sync_batchnorm import DataParallelWithCallback
+
+
+def normalize_kp(kp_source, kp_driving, kp_driving_initial, adapt_movement_scale=False,
+                 use_relative_movement=False, use_relative_jacobian=False):
+    if adapt_movement_scale:
+        source_area = ConvexHull(kp_source['value'][0].data.cpu().numpy()).volume
+        driving_area = ConvexHull(kp_driving_initial['value'][0].data.cpu().numpy()).volume
+        adapt_movement_scale = np.sqrt(source_area) / np.sqrt(driving_area)
+    else:
+        adapt_movement_scale = 1
+
+    kp_new = {k: v for k, v in kp_driving.items()}
+
+    if use_relative_movement:
+        kp_value_diff = (kp_driving['value'] - kp_driving_initial['value'])
+        kp_value_diff *= adapt_movement_scale
+        kp_new['value'] = kp_value_diff + kp_source['value']
+
+        if use_relative_jacobian:
+            jacobian_diff = torch.matmul(kp_driving['jacobian'], torch.inverse(kp_driving_initial['jacobian']))
+            kp_new['jacobian'] = torch.matmul(jacobian_diff, kp_source['jacobian'])
+
+    return kp_new
+
+
+def animate(config, generator, kp_detector, checkpoint, log_dir, dataset):
+    log_dir = os.path.join(log_dir, 'animation')
+    png_dir = os.path.join(log_dir, 'png')
+    animate_params = config['animate_params']
+
+    dataset = PairedDataset(initial_dataset=dataset, number_of_pairs=animate_params['num_pairs'])
+    dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)
+
+    if checkpoint is not None:
+        Logger.load_cpk(checkpoint, generator=generator, kp_detector=kp_detector)
+    else:
+        raise AttributeError("Checkpoint should be specified for mode='animate'.")
+
+    if not os.path.exists(log_dir):
+        os.makedirs(log_dir)
+
+    if not os.path.exists(png_dir):
+        os.makedirs(png_dir)
+
+    if torch.cuda.is_available():
+        generator = DataParallelWithCallback(generator)
+        kp_detector = DataParallelWithCallback(kp_detector)
+
+    generator.eval()
+    kp_detector.eval()
+
+    for it, x in tqdm(enumerate(dataloader)):
+        with torch.no_grad():
+            predictions = []
+            visualizations = []
+
+            driving_video = x['driving_video']
+            source_frame = x['source_video'][:, :, 0, :, :]
+
+            kp_source = kp_detector(source_frame)
+            kp_driving_initial = kp_detector(driving_video[:, :, 0])
+
+            for frame_idx in range(driving_video.shape[2]):
+                driving_frame = driving_video[:, :, frame_idx]
+                kp_driving = kp_detector(driving_frame)
+                kp_norm = normalize_kp(kp_source=kp_source, kp_driving=kp_driving,
+                                       kp_driving_initial=kp_driving_initial, **animate_params['normalization_params'])
+                out = generator(source_frame, kp_source=kp_source, kp_driving=kp_norm)
+
+                out['kp_driving'] = kp_driving
+                out['kp_source'] = kp_source
+                out['kp_norm'] = kp_norm
+
+                del out['sparse_deformed']
+
+                predictions.append(np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0])
+
+                visualization = Visualizer(**config['visualizer_params']).visualize(source=source_frame,
+                                                                                    driving=driving_frame, out=out)
+                visualizations.append(visualization)
+
+            predictions = np.concatenate(predictions, axis=1)
+            result_name = "-".join([x['driving_name'][0], x['source_name'][0]])
+            imageio.imsave(os.path.join(png_dir, result_name + '.png'), (255 * predictions).astype(np.uint8))
+
+            image_name = result_name + animate_params['format']
+            imageio.mimsave(os.path.join(log_dir, image_name), visualizations)
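
The relative mode of normalize_kp re-targets motion as an offset from the first driving frame, kp_new = kp_source + s * (kp_driving - kp_initial), where s is 1 unless adapt_movement_scale rescales by the ratio of convex-hull areas. A minimal sketch, using toy tensors shaped like the dicts KPDetector returns ('value': (B, K, 2), 'jacobian': (B, K, 2, 2)); the numbers are illustrative only:

    import torch
    from animate import normalize_kp

    # Toy keypoint dicts in the format the detector produces.
    eye = torch.eye(2).repeat(1, 10, 1, 1)  # (1, 10, 2, 2) identity jacobians
    kp_source = {'value': torch.zeros(1, 10, 2), 'jacobian': eye}
    kp_initial = {'value': torch.zeros(1, 10, 2), 'jacobian': eye}
    kp_driving = {'value': 0.1 * torch.ones(1, 10, 2), 'jacobian': eye}

    kp_new = normalize_kp(kp_source, kp_driving, kp_initial,
                          adapt_movement_scale=False,
                          use_relative_movement=True, use_relative_jacobian=True)
    print(kp_new['value'][0, 0])  # tensor([0.1000, 0.1000]): source keypoint plus the driving offset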
augmentation.py ADDED
@@ -0,0 +1,345 @@
+"""
+Code from https://github.com/hassony2/torch_videovision
+"""
+
+import numbers
+
+import random
+import numpy as np
+import PIL
+
+from skimage.transform import resize, rotate
+from numpy import pad  # np.pad replaces the removed skimage.util.pad
+import torchvision
+
+import warnings
+
+from skimage import img_as_ubyte, img_as_float
+
+
+def crop_clip(clip, min_h, min_w, h, w):
+    if isinstance(clip[0], np.ndarray):
+        cropped = [img[min_h:min_h + h, min_w:min_w + w, :] for img in clip]
+
+    elif isinstance(clip[0], PIL.Image.Image):
+        cropped = [
+            img.crop((min_w, min_h, min_w + w, min_h + h)) for img in clip
+        ]
+    else:
+        raise TypeError('Expected numpy.ndarray or PIL.Image ' +
+                        'but got list of {0}'.format(type(clip[0])))
+    return cropped
+
+
+def pad_clip(clip, h, w):
+    im_h, im_w = clip[0].shape[:2]
+    pad_h = (0, 0) if h < im_h else ((h - im_h) // 2, (h - im_h + 1) // 2)
+    pad_w = (0, 0) if w < im_w else ((w - im_w) // 2, (w - im_w + 1) // 2)
+
+    return pad(clip, ((0, 0), pad_h, pad_w, (0, 0)), mode='edge')
+
+
+def resize_clip(clip, size, interpolation='bilinear'):
+    if isinstance(clip[0], np.ndarray):
+        if isinstance(size, numbers.Number):
+            im_h, im_w, im_c = clip[0].shape
+            # Min spatial dim already matches minimal size
+            if (im_w <= im_h and im_w == size) or (im_h <= im_w and im_h == size):
+                return clip
+            new_h, new_w = get_resize_sizes(im_h, im_w, size)
+            size = (new_w, new_h)
+        else:
+            size = size[1], size[0]
+
+        scaled = [
+            resize(img, size, order=1 if interpolation == 'bilinear' else 0, preserve_range=True,
+                   mode='constant', anti_aliasing=True) for img in clip
+        ]
+    elif isinstance(clip[0], PIL.Image.Image):
+        if isinstance(size, numbers.Number):
+            im_w, im_h = clip[0].size
+            # Min spatial dim already matches minimal size
+            if (im_w <= im_h and im_w == size) or (im_h <= im_w and im_h == size):
+                return clip
+            new_h, new_w = get_resize_sizes(im_h, im_w, size)
+            size = (new_w, new_h)
+        else:
+            size = size[1], size[0]
+        if interpolation == 'bilinear':
+            pil_inter = PIL.Image.BILINEAR
+        else:
+            pil_inter = PIL.Image.NEAREST
+        scaled = [img.resize(size, pil_inter) for img in clip]
+    else:
+        raise TypeError('Expected numpy.ndarray or PIL.Image ' +
+                        'but got list of {0}'.format(type(clip[0])))
+    return scaled
+
+
+def get_resize_sizes(im_h, im_w, size):
+    if im_w < im_h:
+        ow = size
+        oh = int(size * im_h / im_w)
+    else:
+        oh = size
+        ow = int(size * im_w / im_h)
+    return oh, ow
+
+
+class RandomFlip(object):
+    def __init__(self, time_flip=False, horizontal_flip=False):
+        self.time_flip = time_flip
+        self.horizontal_flip = horizontal_flip
+
+    def __call__(self, clip):
+        if random.random() < 0.5 and self.time_flip:
+            return clip[::-1]
+        if random.random() < 0.5 and self.horizontal_flip:
+            return [np.fliplr(img) for img in clip]
+
+        return clip
+
+
+class RandomResize(object):
+    """Resizes a list of (H x W x C) numpy.ndarray to the final size.
+    The larger the original image is, the longer the interpolation takes.
+    Args:
+        interpolation (str): Can be one of 'nearest', 'bilinear'
+            (defaults to nearest)
+        size (tuple): (width, height)
+    """
+
+    def __init__(self, ratio=(3. / 4., 4. / 3.), interpolation='nearest'):
+        self.ratio = ratio
+        self.interpolation = interpolation
+
+    def __call__(self, clip):
+        scaling_factor = random.uniform(self.ratio[0], self.ratio[1])
+
+        if isinstance(clip[0], np.ndarray):
+            im_h, im_w, im_c = clip[0].shape
+        elif isinstance(clip[0], PIL.Image.Image):
+            im_w, im_h = clip[0].size
+
+        new_w = int(im_w * scaling_factor)
+        new_h = int(im_h * scaling_factor)
+        new_size = (new_w, new_h)
+        resized = resize_clip(
+            clip, new_size, interpolation=self.interpolation)
+
+        return resized
+
+
+class RandomCrop(object):
+    """Extract a random crop at the same location for a list of frames.
+    Args:
+        size (sequence or int): Desired output size for the
+            crop in format (h, w)
+    """
+
+    def __init__(self, size):
+        if isinstance(size, numbers.Number):
+            size = (size, size)
+
+        self.size = size
+
+    def __call__(self, clip):
+        """
+        Args:
+            clip (PIL.Image or numpy.ndarray): List of frames to be cropped
+                in format (h, w, c) in numpy.ndarray
+        Returns:
+            PIL.Image or numpy.ndarray: Cropped list of frames
+        """
+        h, w = self.size
+        if isinstance(clip[0], np.ndarray):
+            im_h, im_w, im_c = clip[0].shape
+        elif isinstance(clip[0], PIL.Image.Image):
+            im_w, im_h = clip[0].size
+        else:
+            raise TypeError('Expected numpy.ndarray or PIL.Image ' +
+                            'but got list of {0}'.format(type(clip[0])))
+
+        clip = pad_clip(clip, h, w)
+        im_h, im_w = clip.shape[1:3]
+        x1 = 0 if w == im_w else random.randint(0, im_w - w)
+        y1 = 0 if h == im_h else random.randint(0, im_h - h)
+        cropped = crop_clip(clip, y1, x1, h, w)
+
+        return cropped
+
+
+class RandomRotation(object):
+    """Rotate the entire clip by a random angle within the
+    given bounds.
+    Args:
+        degrees (sequence or int): Range of degrees to select from.
+            If degrees is a number instead of a sequence like (min, max),
+            the range of degrees will be (-degrees, +degrees).
+    """
+
+    def __init__(self, degrees):
+        if isinstance(degrees, numbers.Number):
+            if degrees < 0:
+                raise ValueError('If degrees is a single number, '
+                                 'it must be positive')
+            degrees = (-degrees, degrees)
+        else:
+            if len(degrees) != 2:
+                raise ValueError('If degrees is a sequence, '
+                                 'it must be of len 2.')
+
+        self.degrees = degrees
+
+    def __call__(self, clip):
+        """
+        Args:
+            clip (PIL.Image or numpy.ndarray): List of frames to be rotated
+                in format (h, w, c) in numpy.ndarray
+        Returns:
+            PIL.Image or numpy.ndarray: Rotated list of frames
+        """
+        angle = random.uniform(self.degrees[0], self.degrees[1])
+        if isinstance(clip[0], np.ndarray):
+            rotated = [rotate(image=img, angle=angle, preserve_range=True) for img in clip]
+        elif isinstance(clip[0], PIL.Image.Image):
+            rotated = [img.rotate(angle) for img in clip]
+        else:
+            raise TypeError('Expected numpy.ndarray or PIL.Image ' +
+                            'but got list of {0}'.format(type(clip[0])))
+
+        return rotated
+
+
+class ColorJitter(object):
+    """Randomly change the brightness, contrast, saturation and hue of the clip.
+    Args:
+        brightness (float): How much to jitter brightness. brightness_factor
+            is chosen uniformly from [max(0, 1 - brightness), 1 + brightness].
+        contrast (float): How much to jitter contrast. contrast_factor
+            is chosen uniformly from [max(0, 1 - contrast), 1 + contrast].
+        saturation (float): How much to jitter saturation. saturation_factor
+            is chosen uniformly from [max(0, 1 - saturation), 1 + saturation].
+        hue (float): How much to jitter hue. hue_factor is chosen uniformly from
+            [-hue, hue]. Should be >= 0 and <= 0.5.
+    """
+
+    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
+        self.brightness = brightness
+        self.contrast = contrast
+        self.saturation = saturation
+        self.hue = hue
+
+    def get_params(self, brightness, contrast, saturation, hue):
+        if brightness > 0:
+            brightness_factor = random.uniform(
+                max(0, 1 - brightness), 1 + brightness)
+        else:
+            brightness_factor = None
+
+        if contrast > 0:
+            contrast_factor = random.uniform(
+                max(0, 1 - contrast), 1 + contrast)
+        else:
+            contrast_factor = None
+
+        if saturation > 0:
+            saturation_factor = random.uniform(
+                max(0, 1 - saturation), 1 + saturation)
+        else:
+            saturation_factor = None
+
+        if hue > 0:
+            hue_factor = random.uniform(-hue, hue)
+        else:
+            hue_factor = None
+        return brightness_factor, contrast_factor, saturation_factor, hue_factor
+
+    def __call__(self, clip):
+        """
+        Args:
+            clip (list): list of PIL.Image or numpy.ndarray
+        Returns:
+            list: list of transformed frames
+        """
+        if isinstance(clip[0], np.ndarray):
+            brightness, contrast, saturation, hue = self.get_params(
+                self.brightness, self.contrast, self.saturation, self.hue)
+
+            # Create img transform function sequence
+            img_transforms = []
+            if brightness is not None:
+                img_transforms.append(lambda img: torchvision.transforms.functional.adjust_brightness(img, brightness))
+            if saturation is not None:
+                img_transforms.append(lambda img: torchvision.transforms.functional.adjust_saturation(img, saturation))
+            if hue is not None:
+                img_transforms.append(lambda img: torchvision.transforms.functional.adjust_hue(img, hue))
+            if contrast is not None:
+                img_transforms.append(lambda img: torchvision.transforms.functional.adjust_contrast(img, contrast))
+            random.shuffle(img_transforms)
+            img_transforms = [img_as_ubyte, torchvision.transforms.ToPILImage()] + img_transforms + [np.array,
+                                                                                                     img_as_float]
+
+            with warnings.catch_warnings():
+                warnings.simplefilter("ignore")
+                jittered_clip = []
+                for img in clip:
+                    jittered_img = img
+                    for func in img_transforms:
+                        jittered_img = func(jittered_img)
+                    jittered_clip.append(jittered_img.astype('float32'))
+        elif isinstance(clip[0], PIL.Image.Image):
+            brightness, contrast, saturation, hue = self.get_params(
+                self.brightness, self.contrast, self.saturation, self.hue)
+
+            # Create img transform function sequence
+            img_transforms = []
+            if brightness is not None:
+                img_transforms.append(lambda img: torchvision.transforms.functional.adjust_brightness(img, brightness))
+            if saturation is not None:
+                img_transforms.append(lambda img: torchvision.transforms.functional.adjust_saturation(img, saturation))
+            if hue is not None:
+                img_transforms.append(lambda img: torchvision.transforms.functional.adjust_hue(img, hue))
+            if contrast is not None:
+                img_transforms.append(lambda img: torchvision.transforms.functional.adjust_contrast(img, contrast))
+            random.shuffle(img_transforms)
+
+            # Apply the same transform chain to every frame of the clip
+            jittered_clip = []
+            for img in clip:
+                jittered_img = img
+                for func in img_transforms:
+                    jittered_img = func(jittered_img)
+                jittered_clip.append(jittered_img)
+
+        else:
+            raise TypeError('Expected numpy.ndarray or PIL.Image ' +
+                            'but got list of {0}'.format(type(clip[0])))
+        return jittered_clip
+
+
+class AllAugmentationTransform:
+    def __init__(self, resize_param=None, rotation_param=None, flip_param=None, crop_param=None, jitter_param=None):
+        self.transforms = []
+
+        if flip_param is not None:
+            self.transforms.append(RandomFlip(**flip_param))
+
+        if rotation_param is not None:
+            self.transforms.append(RandomRotation(**rotation_param))
+
+        if resize_param is not None:
+            self.transforms.append(RandomResize(**resize_param))
+
+        if crop_param is not None:
+            self.transforms.append(RandomCrop(**crop_param))
+
+        if jitter_param is not None:
+            self.transforms.append(ColorJitter(**jitter_param))
+
+    def __call__(self, clip):
+        for t in self.transforms:
+            clip = t(clip)
+        return clip
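
A minimal sketch of how these transforms compose through AllAugmentationTransform; the frame size and jitter values below are illustrative, not taken from any shipped config:

    import numpy as np
    from augmentation import AllAugmentationTransform

    # A dummy clip: 8 float frames in [0, 1], each (H, W, C).
    clip = [np.random.rand(64, 64, 3).astype('float32') for _ in range(8)]

    transform = AllAugmentationTransform(
        flip_param={'time_flip': True, 'horizontal_flip': True},
        jitter_param={'brightness': 0.1, 'contrast': 0.1, 'saturation': 0.1, 'hue': 0.1})

    augmented = transform(clip)  # flips apply clip-wide; one jitter draw is shared by all frames
    print(len(augmented), augmented[0].shape)  # 8 (64, 64, 3)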
crop-video.py ADDED
@@ -0,0 +1,158 @@
+import face_alignment
+import numpy as np
+from argparse import ArgumentParser
+from skimage import img_as_ubyte
+from skimage.transform import resize
+from tqdm import tqdm
+import imageio
+import warnings
+warnings.filterwarnings("ignore")
+
+
+def extract_bbox(frame, fa):
+    if max(frame.shape[0], frame.shape[1]) > 640:
+        scale_factor = max(frame.shape[0], frame.shape[1]) / 640.0
+        frame = resize(frame, (int(frame.shape[0] / scale_factor), int(frame.shape[1] / scale_factor)))
+        frame = img_as_ubyte(frame)
+    else:
+        scale_factor = 1
+    frame = frame[..., :3]
+    bboxes = fa.face_detector.detect_from_image(frame[..., ::-1])
+    if len(bboxes) == 0:
+        return []
+    return np.array(bboxes)[:, :-1] * scale_factor
+
+
+def bb_intersection_over_union(boxA, boxB):
+    xA = max(boxA[0], boxB[0])
+    yA = max(boxA[1], boxB[1])
+    xB = min(boxA[2], boxB[2])
+    yB = min(boxA[3], boxB[3])
+    interArea = max(0, xB - xA + 1) * max(0, yB - yA + 1)
+    boxAArea = (boxA[2] - boxA[0] + 1) * (boxA[3] - boxA[1] + 1)
+    boxBArea = (boxB[2] - boxB[0] + 1) * (boxB[3] - boxB[1] + 1)
+    iou = interArea / float(boxAArea + boxBArea - interArea)
+    return iou
+
+
+def join(tube_bbox, bbox):
+    xA = min(tube_bbox[0], bbox[0])
+    yA = min(tube_bbox[1], bbox[1])
+    xB = max(tube_bbox[2], bbox[2])
+    yB = max(tube_bbox[3], bbox[3])
+    return (xA, yA, xB, yB)
+
+
+def compute_bbox(start, end, fps, tube_bbox, frame_shape, inp, image_shape, increase_area=0.1):
+    left, top, right, bot = tube_bbox
+    width = right - left
+    height = bot - top
+
+    # Compute an aspect-preserving bbox
+    width_increase = max(increase_area, ((1 + 2 * increase_area) * height - width) / (2 * width))
+    height_increase = max(increase_area, ((1 + 2 * increase_area) * width - height) / (2 * height))
+
+    left = int(left - width_increase * width)
+    top = int(top - height_increase * height)
+    right = int(right + width_increase * width)
+    bot = int(bot + height_increase * height)
+
+    top, bot, left, right = max(0, top), min(bot, frame_shape[0]), max(0, left), min(right, frame_shape[1])
+    h, w = bot - top, right - left
+
+    start = start / fps
+    end = end / fps
+    time = end - start
+
+    scale = f'{image_shape[0]}:{image_shape[1]}'
+
+    return f'ffmpeg -i {inp} -ss {start} -t {time} -filter:v "crop={w}:{h}:{left}:{top}, scale={scale}" crop.mp4'
+
+
+def compute_bbox_trajectories(trajectories, fps, frame_shape, args):
+    commands = []
+    for i, (bbox, tube_bbox, start, end) in enumerate(trajectories):
+        if (end - start) > args.min_frames:
+            command = compute_bbox(start, end, fps, tube_bbox, frame_shape, inp=args.inp, image_shape=args.image_shape, increase_area=args.increase)
+            commands.append(command)
+    return commands
+
+
+def process_video(args):
+    device = 'cpu' if args.cpu else 'cuda'
+    fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=False, device=device)
+    video = imageio.get_reader(args.inp)
+
+    trajectories = []
+    fps = video.get_meta_data()['fps']
+    commands = []
+    try:
+        for i, frame in tqdm(enumerate(video)):
+            frame_shape = frame.shape
+            bboxes = extract_bbox(frame, fa)
+            # For each trajectory, check the overlap criterion
+            not_valid_trajectories = []
+            valid_trajectories = []
+
+            for trajectory in trajectories:
+                tube_bbox = trajectory[0]
+                intersection = 0
+                for bbox in bboxes:
+                    intersection = max(intersection, bb_intersection_over_union(tube_bbox, bbox))
+                if intersection > args.iou_with_initial:
+                    valid_trajectories.append(trajectory)
+                else:
+                    not_valid_trajectories.append(trajectory)
+
+            commands += compute_bbox_trajectories(not_valid_trajectories, fps, frame_shape, args)
+            trajectories = valid_trajectories
+
+            # Assign each bbox to an existing trajectory, or create a new one
+            for bbox in bboxes:
+                intersection = 0
+                current_trajectory = None
+                for trajectory in trajectories:
+                    tube_bbox = trajectory[0]
+                    current_intersection = bb_intersection_over_union(tube_bbox, bbox)
+                    if intersection < current_intersection and current_intersection > args.iou_with_initial:
+                        intersection = current_intersection
+                        current_trajectory = trajectory
+
+                # Create a new trajectory
+                if current_trajectory is None:
+                    trajectories.append([bbox, bbox, i, i])
+                else:
+                    current_trajectory[3] = i
+                    current_trajectory[1] = join(current_trajectory[1], bbox)
+
+    except IndexError:
+        raise
+
+    commands += compute_bbox_trajectories(trajectories, fps, frame_shape, args)
+    return commands
+
+
+if __name__ == "__main__":
+    parser = ArgumentParser()
+
+    parser.add_argument("--image_shape", default=(256, 256), type=lambda x: tuple(map(int, x.split(','))),
+                        help="Image shape")
+    parser.add_argument("--increase", default=0.1, type=float, help='Increase bbox by this amount')
+    parser.add_argument("--iou_with_initial", type=float, default=0.25, help="The minimal allowed IOU with the initial bbox")
+    parser.add_argument("--inp", required=True, help='Input image or video')
+    parser.add_argument("--min_frames", type=int, default=150, help='Minimum number of frames')
+    parser.add_argument("--cpu", dest="cpu", action="store_true", help="cpu mode.")
+
+    args = parser.parse_args()
+
+    commands = process_video(args)
+    for command in commands:
+        print(command)
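
Note that the script does not crop anything itself: it tracks face bounding-box trajectories and prints one ffmpeg command per trajectory longer than --min_frames. A hypothetical invocation (the file name and the numbers in the printed command are placeholders):

    python crop-video.py --inp some_video.mp4
    # ffmpeg -i some_video.mp4 -ss 1.2 -t 8.5 -filter:v "crop=340:340:120:60, scale=256:256" crop.mp4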
demo.py ADDED
@@ -0,0 +1,157 @@
+import matplotlib
+matplotlib.use('Agg')
+import os, sys
+import yaml
+from argparse import ArgumentParser
+from tqdm import tqdm
+
+import imageio
+import numpy as np
+from skimage.transform import resize
+from skimage import img_as_ubyte
+import torch
+from sync_batchnorm import DataParallelWithCallback
+from modules.generator import OcclusionAwareGenerator
+from modules.keypoint_detector import KPDetector
+from animate import normalize_kp
+from scipy.spatial import ConvexHull
+
+
+if sys.version_info[0] < 3:
+    raise Exception("You must use Python 3 or higher. Recommended version is Python 3.7")
+
+
+def load_checkpoints(config_path, checkpoint_path, cpu=False):
+
+    with open(config_path) as f:
+        config = yaml.load(f, Loader=yaml.FullLoader)
+
+    generator = OcclusionAwareGenerator(**config['model_params']['generator_params'],
+                                        **config['model_params']['common_params'])
+    if not cpu:
+        generator.cuda()
+
+    kp_detector = KPDetector(**config['model_params']['kp_detector_params'],
+                             **config['model_params']['common_params'])
+    if not cpu:
+        kp_detector.cuda()
+
+    if cpu:
+        checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
+    else:
+        checkpoint = torch.load(checkpoint_path)
+
+    generator.load_state_dict(checkpoint['generator'])
+    kp_detector.load_state_dict(checkpoint['kp_detector'])
+
+    if not cpu:
+        generator = DataParallelWithCallback(generator)
+        kp_detector = DataParallelWithCallback(kp_detector)
+
+    generator.eval()
+    kp_detector.eval()
+
+    return generator, kp_detector
+
+
+def make_animation(source_image, driving_video, generator, kp_detector, relative=True, adapt_movement_scale=True, cpu=False):
+    with torch.no_grad():
+        predictions = []
+        source = torch.tensor(source_image[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2)
+        if not cpu:
+            source = source.cuda()
+        driving = torch.tensor(np.array(driving_video)[np.newaxis].astype(np.float32)).permute(0, 4, 1, 2, 3)
+        kp_source = kp_detector(source)
+        kp_driving_initial = kp_detector(driving[:, :, 0])
+
+        for frame_idx in tqdm(range(driving.shape[2])):
+            driving_frame = driving[:, :, frame_idx]
+            if not cpu:
+                driving_frame = driving_frame.cuda()
+            kp_driving = kp_detector(driving_frame)
+            kp_norm = normalize_kp(kp_source=kp_source, kp_driving=kp_driving,
+                                   kp_driving_initial=kp_driving_initial, use_relative_movement=relative,
+                                   use_relative_jacobian=relative, adapt_movement_scale=adapt_movement_scale)
+            out = generator(source, kp_source=kp_source, kp_driving=kp_norm)
+
+            predictions.append(np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0])
+    return predictions
+
+
+def find_best_frame(source, driving, cpu=False):
+    import face_alignment
+
+    def normalize_kp(kp):
+        kp = kp - kp.mean(axis=0, keepdims=True)
+        area = ConvexHull(kp[:, :2]).volume
+        area = np.sqrt(area)
+        kp[:, :2] = kp[:, :2] / area
+        return kp
+
+    fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=True,
+                                      device='cpu' if cpu else 'cuda')
+    kp_source = fa.get_landmarks(255 * source)[0]
+    kp_source = normalize_kp(kp_source)
+    norm = float('inf')
+    frame_num = 0
+    for i, image in tqdm(enumerate(driving)):
+        kp_driving = fa.get_landmarks(255 * image)[0]
+        kp_driving = normalize_kp(kp_driving)
+        new_norm = (np.abs(kp_source - kp_driving) ** 2).sum()
+        if new_norm < norm:
+            norm = new_norm
+            frame_num = i
+    return frame_num
+
+
+if __name__ == "__main__":
+    parser = ArgumentParser()
+    parser.add_argument("--config", required=True, help="path to config")
+    parser.add_argument("--checkpoint", default='vox-cpk.pth.tar', help="path to checkpoint to restore")
+
+    parser.add_argument("--source_image", default='sup-mat/source.png', help="path to source image")
+    parser.add_argument("--driving_video", default='sup-mat/source.png', help="path to driving video")
+    parser.add_argument("--result_video", default='result.mp4', help="path to output")
+
+    parser.add_argument("--relative", dest="relative", action="store_true", help="use relative rather than absolute keypoint coordinates")
+    parser.add_argument("--adapt_scale", dest="adapt_scale", action="store_true", help="adapt movement scale based on convex hull of keypoints")
+
+    parser.add_argument("--find_best_frame", dest="find_best_frame", action="store_true",
+                        help="Generate from the frame that is best aligned with the source. (Only for faces, requires the face_alignment lib)")
+
+    parser.add_argument("--best_frame", dest="best_frame", type=int, default=None,
+                        help="Set frame to start from.")
+
+    parser.add_argument("--cpu", dest="cpu", action="store_true", help="cpu mode.")
+
+    parser.set_defaults(relative=False)
+    parser.set_defaults(adapt_scale=False)
+
+    opt = parser.parse_args()
+
+    source_image = imageio.imread(opt.source_image)
+    reader = imageio.get_reader(opt.driving_video)
+    fps = reader.get_meta_data()['fps']
+    driving_video = []
+    try:
+        for im in reader:
+            driving_video.append(im)
+    except RuntimeError:
+        pass
+    reader.close()
+
+    source_image = resize(source_image, (256, 256))[..., :3]
+    driving_video = [resize(frame, (256, 256))[..., :3] for frame in driving_video]
+    generator, kp_detector = load_checkpoints(config_path=opt.config, checkpoint_path=opt.checkpoint, cpu=opt.cpu)
+
+    if opt.find_best_frame or opt.best_frame is not None:
+        i = opt.best_frame if opt.best_frame is not None else find_best_frame(source_image, driving_video, cpu=opt.cpu)
+        print("Best frame: " + str(i))
+        driving_forward = driving_video[i:]
+        driving_backward = driving_video[:(i + 1)][::-1]
+        predictions_forward = make_animation(source_image, driving_forward, generator, kp_detector, relative=opt.relative, adapt_movement_scale=opt.adapt_scale, cpu=opt.cpu)
+        predictions_backward = make_animation(source_image, driving_backward, generator, kp_detector, relative=opt.relative, adapt_movement_scale=opt.adapt_scale, cpu=opt.cpu)
+        predictions = predictions_backward[::-1] + predictions_forward[1:]
+    else:
+        predictions = make_animation(source_image, driving_video, generator, kp_detector, relative=opt.relative, adapt_movement_scale=opt.adapt_scale, cpu=opt.cpu)
+    imageio.mimsave(opt.result_video, [img_as_ubyte(frame) for frame in predictions], fps=fps)
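
A typical invocation, assuming a matching config and checkpoint are available locally (all paths below are placeholders). The --relative and --adapt_scale flags toggle the normalize_kp behavior defined in animate.py; without them, absolute driving keypoint coordinates are used:

    python demo.py --config config/vox-256.yaml --checkpoint vox-cpk.pth.tar \
                   --source_image source.png --driving_video driving.mp4 \
                   --relative --adapt_scale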
discordbot.py ADDED
File without changes
frames_dataset.py ADDED
@@ -0,0 +1,197 @@
+import os
+from skimage import io, img_as_float32
+from skimage.color import gray2rgb
+from sklearn.model_selection import train_test_split
+from imageio import mimread
+
+import numpy as np
+from torch.utils.data import Dataset
+import pandas as pd
+from augmentation import AllAugmentationTransform
+import glob
+
+
+def read_video(name, frame_shape):
+    """
+    Read a video, which can be:
+      - an image of concatenated frames
+      - a '.mp4', '.gif' or '.mov' file
+      - a folder with frames
+    """
+    if os.path.isdir(name):
+        frames = sorted(os.listdir(name))
+        num_frames = len(frames)
+        video_array = np.array(
+            [img_as_float32(io.imread(os.path.join(name, frames[idx]))) for idx in range(num_frames)])
+    elif name.lower().endswith('.png') or name.lower().endswith('.jpg'):
+        image = io.imread(name)
+
+        if len(image.shape) == 2 or image.shape[2] == 1:
+            image = gray2rgb(image)
+
+        if image.shape[2] == 4:
+            image = image[..., :3]
+
+        image = img_as_float32(image)
+
+        video_array = np.moveaxis(image, 1, 0)
+
+        video_array = video_array.reshape((-1,) + frame_shape)
+        video_array = np.moveaxis(video_array, 1, 2)
+    elif name.lower().endswith('.gif') or name.lower().endswith('.mp4') or name.lower().endswith('.mov'):
+        video = np.array(mimread(name))
+        if len(video.shape) == 3:
+            video = np.array([gray2rgb(frame) for frame in video])
+        if video.shape[-1] == 4:
+            video = video[..., :3]
+        video_array = img_as_float32(video)
+    else:
+        raise Exception("Unknown file extension: %s" % name)
+
+    return video_array
+
+
+class FramesDataset(Dataset):
+    """
+    Dataset of videos; each video can be represented as:
+      - an image of concatenated frames
+      - a '.mp4' or '.gif' file
+      - a folder with all frames
+    """
+
+    def __init__(self, root_dir, frame_shape=(256, 256, 3), id_sampling=False, is_train=True,
+                 random_seed=0, pairs_list=None, augmentation_params=None):
+        self.root_dir = root_dir
+        self.videos = os.listdir(root_dir)
+        self.frame_shape = tuple(frame_shape)
+        self.pairs_list = pairs_list
+        self.id_sampling = id_sampling
+        if os.path.exists(os.path.join(root_dir, 'train')):
+            assert os.path.exists(os.path.join(root_dir, 'test'))
+            print("Use predefined train-test split.")
+            if id_sampling:
+                train_videos = {os.path.basename(video).split('#')[0] for video in
+                                os.listdir(os.path.join(root_dir, 'train'))}
+                train_videos = list(train_videos)
+            else:
+                train_videos = os.listdir(os.path.join(root_dir, 'train'))
+            test_videos = os.listdir(os.path.join(root_dir, 'test'))
+            self.root_dir = os.path.join(self.root_dir, 'train' if is_train else 'test')
+        else:
+            print("Use random train-test split.")
+            train_videos, test_videos = train_test_split(self.videos, random_state=random_seed, test_size=0.2)
+
+        if is_train:
+            self.videos = train_videos
+        else:
+            self.videos = test_videos
+
+        self.is_train = is_train
+
+        if self.is_train:
+            self.transform = AllAugmentationTransform(**augmentation_params)
+        else:
+            self.transform = None
+
+    def __len__(self):
+        return len(self.videos)
+
+    def __getitem__(self, idx):
+        if self.is_train and self.id_sampling:
+            name = self.videos[idx]
+            path = np.random.choice(glob.glob(os.path.join(self.root_dir, name + '*.mp4')))
+        else:
+            name = self.videos[idx]
+            path = os.path.join(self.root_dir, name)
+
+        video_name = os.path.basename(path)
+
+        if self.is_train and os.path.isdir(path):
+            frames = os.listdir(path)
+            num_frames = len(frames)
+            frame_idx = np.sort(np.random.choice(num_frames, replace=True, size=2))
+            video_array = [img_as_float32(io.imread(os.path.join(path, frames[idx]))) for idx in frame_idx]
+        else:
+            video_array = read_video(path, frame_shape=self.frame_shape)
+            num_frames = len(video_array)
+            frame_idx = np.sort(np.random.choice(num_frames, replace=True, size=2)) if self.is_train else range(
+                num_frames)
+            video_array = video_array[frame_idx]
+
+        if self.transform is not None:
+            video_array = self.transform(video_array)
+
+        out = {}
+        if self.is_train:
+            source = np.array(video_array[0], dtype='float32')
+            driving = np.array(video_array[1], dtype='float32')
+
+            out['driving'] = driving.transpose((2, 0, 1))
+            out['source'] = source.transpose((2, 0, 1))
+        else:
+            video = np.array(video_array, dtype='float32')
+            out['video'] = video.transpose((3, 0, 1, 2))
+
+        out['name'] = video_name
+
+        return out
+
+
+class DatasetRepeater(Dataset):
+    """
+    Pass several times over the same dataset for better i/o performance.
+    """
+
+    def __init__(self, dataset, num_repeats=100):
+        self.dataset = dataset
+        self.num_repeats = num_repeats
+
+    def __len__(self):
+        return self.num_repeats * len(self.dataset)
+
+    def __getitem__(self, idx):
+        return self.dataset[idx % len(self.dataset)]
+
+
+class PairedDataset(Dataset):
+    """
+    Dataset of pairs for animation.
+    """
+
+    def __init__(self, initial_dataset, number_of_pairs, seed=0):
+        self.initial_dataset = initial_dataset
+        pairs_list = self.initial_dataset.pairs_list
+
+        np.random.seed(seed)
+
+        if pairs_list is None:
+            max_idx = min(number_of_pairs, len(initial_dataset))
+            nx, ny = max_idx, max_idx
+            xy = np.mgrid[:nx, :ny].reshape(2, -1).T
+            number_of_pairs = min(xy.shape[0], number_of_pairs)
+            self.pairs = xy.take(np.random.choice(xy.shape[0], number_of_pairs, replace=False), axis=0)
+        else:
+            videos = self.initial_dataset.videos
+            name_to_index = {name: index for index, name in enumerate(videos)}
+            pairs = pd.read_csv(pairs_list)
+            pairs = pairs[np.logical_and(pairs['source'].isin(videos), pairs['driving'].isin(videos))]
+
+            number_of_pairs = min(pairs.shape[0], number_of_pairs)
+            self.pairs = []
+            self.start_frames = []
+            for ind in range(number_of_pairs):
+                self.pairs.append(
+                    (name_to_index[pairs['driving'].iloc[ind]], name_to_index[pairs['source'].iloc[ind]]))
+
+    def __len__(self):
+        return len(self.pairs)
+
+    def __getitem__(self, idx):
+        pair = self.pairs[idx]
+        first = self.initial_dataset[pair[0]]
+        second = self.initial_dataset[pair[1]]
+        first = {'driving_' + key: value for key, value in first.items()}
+        second = {'source_' + key: value for key, value in second.items()}
+
+        return {**first, **second}
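
A minimal sketch of wiring the dataset into a training loader; the directory name and augmentation values are placeholders:

    from torch.utils.data import DataLoader
    from frames_dataset import FramesDataset, DatasetRepeater

    dataset = FramesDataset(root_dir='data/my-dataset', frame_shape=(256, 256, 3), is_train=True,
                            augmentation_params={'flip_param': {'horizontal_flip': True, 'time_flip': True}})
    dataset = DatasetRepeater(dataset, num_repeats=50)  # stretch one epoch over 50 passes for better i/o
    loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2, drop_last=True)

    batch = next(iter(loader))
    print(batch['source'].shape, batch['driving'].shape)  # torch.Size([4, 3, 256, 256]) for both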
generated.mp4 ADDED
Binary file (130 kB)
got-05.jpg ADDED
Git LFS Details

  • SHA256: fe5f2983b83f900c2df85c41a90dd3cb0ccd4459830effac5134ab518e2e63ca
  • Pointer size: 132 Bytes
  • Size of remote file: 1.13 MB
logger.py ADDED
@@ -0,0 +1,208 @@
+import numpy as np
+import torch
+import torch.nn.functional as F
+import imageio
+
+import os
+from skimage.draw import disk  # skimage removed `circle`; `disk` is its replacement
+
+import matplotlib.pyplot as plt
+import collections
+
+
+class Logger:
+    def __init__(self, log_dir, checkpoint_freq=100, visualizer_params=None, zfill_num=8, log_file_name='log.txt'):
+
+        self.loss_list = []
+        self.cpk_dir = log_dir
+        self.visualizations_dir = os.path.join(log_dir, 'train-vis')
+        if not os.path.exists(self.visualizations_dir):
+            os.makedirs(self.visualizations_dir)
+        self.log_file = open(os.path.join(log_dir, log_file_name), 'a')
+        self.zfill_num = zfill_num
+        self.visualizer = Visualizer(**visualizer_params)
+        self.checkpoint_freq = checkpoint_freq
+        self.epoch = 0
+        self.best_loss = float('inf')
+        self.names = None
+
+    def log_scores(self, loss_names):
+        loss_mean = np.array(self.loss_list).mean(axis=0)
+
+        loss_string = "; ".join(["%s - %.5f" % (name, value) for name, value in zip(loss_names, loss_mean)])
+        loss_string = str(self.epoch).zfill(self.zfill_num) + ") " + loss_string
+
+        print(loss_string, file=self.log_file)
+        self.loss_list = []
+        self.log_file.flush()
+
+    def visualize_rec(self, inp, out):
+        image = self.visualizer.visualize(inp['driving'], inp['source'], out)
+        imageio.imsave(os.path.join(self.visualizations_dir, "%s-rec.png" % str(self.epoch).zfill(self.zfill_num)), image)
+
+    def save_cpk(self, emergent=False):
+        cpk = {k: v.state_dict() for k, v in self.models.items()}
+        cpk['epoch'] = self.epoch
+        cpk_path = os.path.join(self.cpk_dir, '%s-checkpoint.pth.tar' % str(self.epoch).zfill(self.zfill_num))
+        if not (os.path.exists(cpk_path) and emergent):
+            torch.save(cpk, cpk_path)
+
+    @staticmethod
+    def load_cpk(checkpoint_path, generator=None, discriminator=None, kp_detector=None,
+                 optimizer_generator=None, optimizer_discriminator=None, optimizer_kp_detector=None):
+        checkpoint = torch.load(checkpoint_path)
+        if generator is not None:
+            generator.load_state_dict(checkpoint['generator'])
+        if kp_detector is not None:
+            kp_detector.load_state_dict(checkpoint['kp_detector'])
+        if discriminator is not None:
+            try:
+                discriminator.load_state_dict(checkpoint['discriminator'])
+            except Exception:
+                print('No discriminator in the state-dict. Discriminator will be randomly initialized.')
+        if optimizer_generator is not None:
+            optimizer_generator.load_state_dict(checkpoint['optimizer_generator'])
+        if optimizer_discriminator is not None:
+            try:
+                optimizer_discriminator.load_state_dict(checkpoint['optimizer_discriminator'])
+            except RuntimeError:
+                print('No discriminator optimizer in the state-dict. Optimizer will not be initialized.')
+        if optimizer_kp_detector is not None:
+            optimizer_kp_detector.load_state_dict(checkpoint['optimizer_kp_detector'])
+
+        return checkpoint['epoch']
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        if 'models' in self.__dict__:
+            self.save_cpk()
+        self.log_file.close()
+
+    def log_iter(self, losses):
+        losses = collections.OrderedDict(losses.items())
+        if self.names is None:
+            self.names = list(losses.keys())
+        self.loss_list.append(list(losses.values()))
+
+    def log_epoch(self, epoch, models, inp, out):
+        self.epoch = epoch
+        self.models = models
+        if (self.epoch + 1) % self.checkpoint_freq == 0:
+            self.save_cpk()
+        self.log_scores(self.names)
+        self.visualize_rec(inp, out)
+
+
+class Visualizer:
+    def __init__(self, kp_size=5, draw_border=False, colormap='gist_rainbow'):
+        self.kp_size = kp_size
+        self.draw_border = draw_border
+        self.colormap = plt.get_cmap(colormap)
+
+    def draw_image_with_kp(self, image, kp_array):
+        image = np.copy(image)
+        spatial_size = np.array(image.shape[:2][::-1])[np.newaxis]
+        kp_array = spatial_size * (kp_array + 1) / 2
+        num_kp = kp_array.shape[0]
+        for kp_ind, kp in enumerate(kp_array):
+            rr, cc = disk((kp[1], kp[0]), self.kp_size, shape=image.shape[:2])
+            image[rr, cc] = np.array(self.colormap(kp_ind / num_kp))[:3]
+        return image
+
+    def create_image_column_with_kp(self, images, kp):
+        image_array = np.array([self.draw_image_with_kp(v, k) for v, k in zip(images, kp)])
+        return self.create_image_column(image_array)
+
+    def create_image_column(self, images):
+        if self.draw_border:
+            images = np.copy(images)
+            images[:, :, [0, -1]] = (1, 1, 1)
+        return np.concatenate(list(images), axis=0)
+
+    def create_image_grid(self, *args):
+        out = []
+        for arg in args:
+            if type(arg) == tuple:
+                out.append(self.create_image_column_with_kp(arg[0], arg[1]))
+            else:
+                out.append(self.create_image_column(arg))
+        return np.concatenate(out, axis=1)
+
+    def visualize(self, driving, source, out):
+        images = []
+
+        # Source image with keypoints
+        source = source.data.cpu()
+        kp_source = out['kp_source']['value'].data.cpu().numpy()
+        source = np.transpose(source, [0, 2, 3, 1])
+        images.append((source, kp_source))
+
+        # Equivariance visualization
+        if 'transformed_frame' in out:
+            transformed = out['transformed_frame'].data.cpu().numpy()
+            transformed = np.transpose(transformed, [0, 2, 3, 1])
+            transformed_kp = out['transformed_kp']['value'].data.cpu().numpy()
+            images.append((transformed, transformed_kp))
+
+        # Driving image with keypoints
+        kp_driving = out['kp_driving']['value'].data.cpu().numpy()
+        driving = driving.data.cpu().numpy()
+        driving = np.transpose(driving, [0, 2, 3, 1])
+        images.append((driving, kp_driving))
+
+        # Deformed image
+        if 'deformed' in out:
+            deformed = out['deformed'].data.cpu().numpy()
+            deformed = np.transpose(deformed, [0, 2, 3, 1])
+            images.append(deformed)
+
+        # Result with and without keypoints
+        prediction = out['prediction'].data.cpu().numpy()
+        prediction = np.transpose(prediction, [0, 2, 3, 1])
+        if 'kp_norm' in out:
+            kp_norm = out['kp_norm']['value'].data.cpu().numpy()
+            images.append((prediction, kp_norm))
+        images.append(prediction)
+
+        # Occlusion map
+        if 'occlusion_map' in out:
+            occlusion_map = out['occlusion_map'].data.cpu().repeat(1, 3, 1, 1)
+            occlusion_map = F.interpolate(occlusion_map, size=source.shape[1:3]).numpy()
+            occlusion_map = np.transpose(occlusion_map, [0, 2, 3, 1])
+            images.append(occlusion_map)
+
+        # Deformed images according to each individual transform
+        if 'sparse_deformed' in out:
+            full_mask = []
+            for i in range(out['sparse_deformed'].shape[1]):
+                image = out['sparse_deformed'][:, i].data.cpu()
+                image = F.interpolate(image, size=source.shape[1:3])
+                mask = out['mask'][:, i:(i + 1)].data.cpu().repeat(1, 3, 1, 1)
+                mask = F.interpolate(mask, size=source.shape[1:3])
+                image = np.transpose(image.numpy(), (0, 2, 3, 1))
+                mask = np.transpose(mask.numpy(), (0, 2, 3, 1))
+
+                if i != 0:
+                    color = np.array(self.colormap((i - 1) / (out['sparse_deformed'].shape[1] - 1)))[:3]
+                else:
+                    color = np.array((0, 0, 0))
+
+                color = color.reshape((1, 1, 1, 3))
+
+                images.append(image)
+                if i != 0:
+                    images.append(mask * color)
+                else:
+                    images.append(mask)
+
+                full_mask.append(mask * color)
+
+            images.append(sum(full_mask))
+
+        image = self.create_image_grid(*images)
+        image = (255 * image).astype(np.uint8)
+        return image
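
Logger is meant to be used as a context manager so a final checkpoint is written on exit. A minimal sketch (log_dir and the loss names are placeholders; visualizer_params mirrors Visualizer's constructor):

    from logger import Logger

    with Logger(log_dir='log/debug', checkpoint_freq=50,
                visualizer_params={'kp_size': 5, 'draw_border': True, 'colormap': 'gist_rainbow'}) as logger:
        logger.log_iter(losses={'perceptual': 0.42, 'equivariance_value': 0.05})
        # logger.log_epoch(epoch, models, inp, out) would additionally checkpoint
        # `models` and write a train-vis image via Visualizer.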
reconstruction.py ADDED
@@ -0,0 +1,67 @@
+import os
+from tqdm import tqdm
+import torch
+from torch.utils.data import DataLoader
+from logger import Logger, Visualizer
+import numpy as np
+import imageio
+from sync_batchnorm import DataParallelWithCallback
+
+
+def reconstruction(config, generator, kp_detector, checkpoint, log_dir, dataset):
+    png_dir = os.path.join(log_dir, 'reconstruction/png')
+    log_dir = os.path.join(log_dir, 'reconstruction')
+
+    if checkpoint is not None:
+        Logger.load_cpk(checkpoint, generator=generator, kp_detector=kp_detector)
+    else:
+        raise AttributeError("Checkpoint should be specified for mode='reconstruction'.")
+    dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)
+
+    if not os.path.exists(log_dir):
+        os.makedirs(log_dir)
+
+    if not os.path.exists(png_dir):
+        os.makedirs(png_dir)
+
+    loss_list = []
+    if torch.cuda.is_available():
+        generator = DataParallelWithCallback(generator)
+        kp_detector = DataParallelWithCallback(kp_detector)
+
+    generator.eval()
+    kp_detector.eval()
+
+    for it, x in tqdm(enumerate(dataloader)):
+        if config['reconstruction_params']['num_videos'] is not None:
+            if it > config['reconstruction_params']['num_videos']:
+                break
+        with torch.no_grad():
+            predictions = []
+            visualizations = []
+            if torch.cuda.is_available():
+                x['video'] = x['video'].cuda()
+            kp_source = kp_detector(x['video'][:, :, 0])
+            for frame_idx in range(x['video'].shape[2]):
+                source = x['video'][:, :, 0]
+                driving = x['video'][:, :, frame_idx]
+                kp_driving = kp_detector(driving)
+                out = generator(source, kp_source=kp_source, kp_driving=kp_driving)
+                out['kp_source'] = kp_source
+                out['kp_driving'] = kp_driving
+                del out['sparse_deformed']
+                predictions.append(np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0])
+
+                visualization = Visualizer(**config['visualizer_params']).visualize(source=source,
+                                                                                    driving=driving, out=out)
+                visualizations.append(visualization)
+
+                loss_list.append(torch.abs(out['prediction'] - driving).mean().cpu().numpy())
+
+            predictions = np.concatenate(predictions, axis=1)
+            imageio.imsave(os.path.join(png_dir, x['name'][0] + '.png'), (255 * predictions).astype(np.uint8))
+
+            image_name = x['name'][0] + config['reconstruction_params']['format']
+            imageio.mimsave(os.path.join(log_dir, image_name), visualizations)
+
+    print("Reconstruction loss: %s" % np.mean(loss_list))
run.py ADDED
@@ -0,0 +1,87 @@
+import matplotlib
+
+matplotlib.use('Agg')
+
+import os, sys
+import yaml
+from argparse import ArgumentParser
+from time import gmtime, strftime
+from shutil import copy
+
+from frames_dataset import FramesDataset
+
+from modules.generator import OcclusionAwareGenerator
+from modules.discriminator import MultiScaleDiscriminator
+from modules.keypoint_detector import KPDetector
+
+import torch
+
+from train import train
+from reconstruction import reconstruction
+from animate import animate
+
+if __name__ == "__main__":
+
+    if sys.version_info[0] < 3:
+        raise Exception("You must use Python 3 or higher. Recommended version is Python 3.7")
+
+    parser = ArgumentParser()
+    parser.add_argument("--config", required=True, help="path to config")
+    parser.add_argument("--mode", default="train", choices=["train", "reconstruction", "animate"])
+    parser.add_argument("--log_dir", default='log', help="path to log into")
+    parser.add_argument("--checkpoint", default=None, help="path to checkpoint to restore")
+    parser.add_argument("--device_ids", default="0", type=lambda x: list(map(int, x.split(','))),
+                        help="Names of the devices comma separated.")
+    parser.add_argument("--verbose", dest="verbose", action="store_true", help="Print model architecture")
+    parser.set_defaults(verbose=False)
+
+    opt = parser.parse_args()
+    with open(opt.config) as f:
+        config = yaml.load(f, Loader=yaml.FullLoader)
+
+    if opt.checkpoint is not None:
+        log_dir = os.path.join(*os.path.split(opt.checkpoint)[:-1])
+    else:
+        log_dir = os.path.join(opt.log_dir, os.path.basename(opt.config).split('.')[0])
+        log_dir += ' ' + strftime("%d_%m_%y_%H.%M.%S", gmtime())
+
+    generator = OcclusionAwareGenerator(**config['model_params']['generator_params'],
+                                        **config['model_params']['common_params'])
+
+    if torch.cuda.is_available():
+        generator.to(opt.device_ids[0])
+    if opt.verbose:
+        print(generator)
+
+    discriminator = MultiScaleDiscriminator(**config['model_params']['discriminator_params'],
+                                            **config['model_params']['common_params'])
+    if torch.cuda.is_available():
+        discriminator.to(opt.device_ids[0])
+    if opt.verbose:
+        print(discriminator)
+
+    kp_detector = KPDetector(**config['model_params']['kp_detector_params'],
+                             **config['model_params']['common_params'])
+
+    if torch.cuda.is_available():
+        kp_detector.to(opt.device_ids[0])
+
+    if opt.verbose:
+        print(kp_detector)
+
+    dataset = FramesDataset(is_train=(opt.mode == 'train'), **config['dataset_params'])
+
+    if not os.path.exists(log_dir):
+        os.makedirs(log_dir)
+    if not os.path.exists(os.path.join(log_dir, os.path.basename(opt.config))):
+        copy(opt.config, log_dir)
+
+    if opt.mode == 'train':
+        print("Training...")
+        train(config, generator, discriminator, kp_detector, opt.checkpoint, log_dir, dataset, opt.device_ids)
+    elif opt.mode == 'reconstruction':
+        print("Reconstruction...")
+        reconstruction(config, generator, kp_detector, opt.checkpoint, log_dir, dataset)
+    elif opt.mode == 'animate':
+        print("Animate...")
+        animate(config, generator, kp_detector, opt.checkpoint, log_dir, dataset)
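
run.py is the single entry point for all three modes. Typical invocations, with the config and checkpoint paths as placeholders:

    CUDA_VISIBLE_DEVICES=0,1 python run.py --config config/dataset.yaml --device_ids 0,1
    python run.py --config config/dataset.yaml --mode reconstruction --checkpoint path/to/checkpoint.pth.tar
    python run.py --config config/dataset.yaml --mode animate --checkpoint path/to/checkpoint.pth.tar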
sdkan.png ADDED
train.py ADDED
@@ -0,0 +1,87 @@
+from tqdm import trange
+import torch
+
+from torch.utils.data import DataLoader
+
+from logger import Logger
+from modules.model import GeneratorFullModel, DiscriminatorFullModel
+
+from torch.optim.lr_scheduler import MultiStepLR
+
+from sync_batchnorm import DataParallelWithCallback
+
+from frames_dataset import DatasetRepeater
+
+
+def train(config, generator, discriminator, kp_detector, checkpoint, log_dir, dataset, device_ids):
+    train_params = config['train_params']
+
+    optimizer_generator = torch.optim.Adam(generator.parameters(), lr=train_params['lr_generator'], betas=(0.5, 0.999))
+    optimizer_discriminator = torch.optim.Adam(discriminator.parameters(), lr=train_params['lr_discriminator'], betas=(0.5, 0.999))
+    optimizer_kp_detector = torch.optim.Adam(kp_detector.parameters(), lr=train_params['lr_kp_detector'], betas=(0.5, 0.999))
+
+    if checkpoint is not None:
+        start_epoch = Logger.load_cpk(checkpoint, generator, discriminator, kp_detector,
+                                      optimizer_generator, optimizer_discriminator,
+                                      None if train_params['lr_kp_detector'] == 0 else optimizer_kp_detector)
+    else:
+        start_epoch = 0
+
+    scheduler_generator = MultiStepLR(optimizer_generator, train_params['epoch_milestones'], gamma=0.1,
+                                      last_epoch=start_epoch - 1)
+    scheduler_discriminator = MultiStepLR(optimizer_discriminator, train_params['epoch_milestones'], gamma=0.1,
+                                          last_epoch=start_epoch - 1)
+    scheduler_kp_detector = MultiStepLR(optimizer_kp_detector, train_params['epoch_milestones'], gamma=0.1,
+                                        last_epoch=-1 + start_epoch * (train_params['lr_kp_detector'] != 0))
+
+    if 'num_repeats' in train_params and train_params['num_repeats'] != 1:
+        dataset = DatasetRepeater(dataset, train_params['num_repeats'])
+    dataloader = DataLoader(dataset, batch_size=train_params['batch_size'], shuffle=True, num_workers=6, drop_last=True)
+
+    generator_full = GeneratorFullModel(kp_detector, generator, discriminator, train_params)
+    discriminator_full = DiscriminatorFullModel(kp_detector, generator, discriminator, train_params)
+
+    if torch.cuda.is_available():
+        generator_full = DataParallelWithCallback(generator_full, device_ids=device_ids)
+        discriminator_full = DataParallelWithCallback(discriminator_full, device_ids=device_ids)
+
+    with Logger(log_dir=log_dir, visualizer_params=config['visualizer_params'], checkpoint_freq=train_params['checkpoint_freq']) as logger:
+        for epoch in trange(start_epoch, train_params['num_epochs']):
+            for x in dataloader:
+                losses_generator, generated = generator_full(x)
+
+                loss_values = [val.mean() for val in losses_generator.values()]
+                loss = sum(loss_values)
+
+                loss.backward()
+                optimizer_generator.step()
+                optimizer_generator.zero_grad()
+                optimizer_kp_detector.step()
+                optimizer_kp_detector.zero_grad()
+
+                if train_params['loss_weights']['generator_gan'] != 0:
+                    optimizer_discriminator.zero_grad()
+                    losses_discriminator = discriminator_full(x, generated)
+                    loss_values = [val.mean() for val in losses_discriminator.values()]
+                    loss = sum(loss_values)
+
+                    loss.backward()
+                    optimizer_discriminator.step()
+                    optimizer_discriminator.zero_grad()
+                else:
+                    losses_discriminator = {}
+
+                losses_generator.update(losses_discriminator)
+                losses = {key: value.mean().detach().data.cpu().numpy() for key, value in losses_generator.items()}
+                logger.log_iter(losses=losses)
+
+            scheduler_generator.step()
+            scheduler_discriminator.step()
+            scheduler_kp_detector.step()
+
+            logger.log_epoch(epoch, {'generator': generator,
+                                     'discriminator': discriminator,
+                                     'kp_detector': kp_detector,
+                                     'optimizer_generator': optimizer_generator,
+                                     'optimizer_discriminator': optimizer_discriminator,
+                                     'optimizer_kp_detector': optimizer_kp_detector}, inp=x, out=generated)
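
For reference, a sketch of the train_params block that train() reads; the keys are exactly the ones accessed above, while the values are placeholders rather than any shipped defaults:

    train_params = {
        'lr_generator': 2e-4,
        'lr_discriminator': 2e-4,
        'lr_kp_detector': 2e-4,        # 0 freezes the keypoint detector's optimizer schedule
        'epoch_milestones': [60, 90],  # MultiStepLR decays each lr by 10x at these epochs
        'num_epochs': 100,
        'num_repeats': 75,             # passes over the dataset per epoch (DatasetRepeater)
        'batch_size': 16,
        'checkpoint_freq': 50,
        'loss_weights': {'generator_gan': 1},
    }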