PascalLiu committed on
Commit
d51072c
1 Parent(s): 4489d0e

complete image and video

api.py ADDED
@@ -0,0 +1,248 @@
+ import os
+ import gradio as gr
+ import yaml
+ from argparse import ArgumentParser
+ from tqdm import tqdm
+
+ import numpy as np
+ import imageio
+ from skimage.transform import resize
+ from skimage import img_as_ubyte
+ from scipy.spatial import ConvexHull
+ import torch
+ from sync_batchnorm import DataParallelWithCallback
+ import face_alignment
+
+ from modules.generator import OcclusionAwareGenerator_SPADE
+ from modules.keypoint_detector import KPDetector
+
+
+ def normalize_kp(kp_source, kp_driving, kp_driving_initial, adapt_movement_scale=False,
+                  use_relative_movement=False, use_relative_jacobian=False):
+     kp_new = {k: v for k, v in kp_driving.items()}
+
+     if adapt_movement_scale:
+         source_area = ConvexHull(kp_source['value'][0].data.cpu().numpy()).volume
+         driving_area = ConvexHull(kp_driving_initial['value'][0].data.cpu().numpy()).volume
+         adapt_movement_scale = np.sqrt(source_area) / np.sqrt(driving_area)
+         kp_new['value'] = kp_driving['value'] * adapt_movement_scale  # for reenactment demo
+     else:
+         adapt_movement_scale = 1
+
+     if use_relative_movement:
+         kp_value_diff = (kp_driving['value'] - kp_driving_initial['value'])
+         kp_value_diff *= adapt_movement_scale
+         kp_new['value'] = kp_value_diff + kp_source['value']
+
+         if use_relative_jacobian:
+             jacobian_diff = torch.matmul(kp_driving['jacobian'], torch.inverse(kp_driving_initial['jacobian']))
+             kp_new['jacobian'] = torch.matmul(jacobian_diff, kp_source['jacobian'])
+
+     return kp_new
+
+
+ def load_checkpoints(config_path, checkpoint_path, cpu=False):
+     with open(config_path) as f:
+         # config = yaml.load(f)
+         config = yaml.load(f, Loader=yaml.FullLoader)
+
+     generator = OcclusionAwareGenerator_SPADE(**config['model_params']['generator_params'],
+                                               **config['model_params']['common_params'])
+     if not cpu:
+         generator.cuda()
+
+     kp_detector = KPDetector(**config['model_params']['kp_detector_params'],
+                              **config['model_params']['common_params'])
+     if not cpu:
+         kp_detector.cuda()
+
+     if cpu:
+         checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
+     else:
+         checkpoint = torch.load(checkpoint_path)
+
+     generator.load_state_dict(checkpoint['generator'])
+     kp_detector.load_state_dict(checkpoint['kp_detector'])
+
+     if not cpu:
+         generator = DataParallelWithCallback(generator)
+         kp_detector = DataParallelWithCallback(kp_detector)
+
+     generator.eval()
+     kp_detector.eval()
+
+     return generator, kp_detector
+
+
+ def make_animation(source_image, driving_video, generator, kp_detector, relative=True, adapt_movement_scale=True,
+                    cpu=False):
+     with torch.no_grad():
+         predictions = []
+         source = torch.tensor(source_image[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2)
+         if not cpu:
+             source = source.cuda()
+         driving = torch.tensor(np.array(driving_video)[np.newaxis].astype(np.float32)).permute(0, 4, 1, 2, 3)
+         kp_source = kp_detector(source)
+         kp_driving_initial = kp_detector(driving[:, :, 0])
+
+         for frame_idx in tqdm(range(driving.shape[2])):
+             driving_frame = driving[:, :, frame_idx]
+             if not cpu:
+                 driving_frame = driving_frame.cuda()
+             kp_driving = kp_detector(driving_frame)
+             kp_norm = normalize_kp(kp_source=kp_source, kp_driving=kp_driving,
+                                    kp_driving_initial=kp_driving_initial, use_relative_movement=relative,
+                                    use_relative_jacobian=relative, adapt_movement_scale=adapt_movement_scale)
+             out = generator(source, kp_source=kp_source, kp_driving=kp_norm)
+
+             predictions.append(np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0])
+     return predictions
+
+
+ def find_best_frame_func(source, driving, cpu=False):
+     def normalize_kp_infunc(kp):
+         kp = kp - kp.mean(axis=0, keepdims=True)
+         area = ConvexHull(kp[:, :2]).volume
+         area = np.sqrt(area)
+         kp[:, :2] = kp[:, :2] / area
+         return kp
+
+     fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=True,
+                                       device='cpu' if cpu else 'cuda')
+     kp_source = fa.get_landmarks(255 * source)[0]
+     kp_source = normalize_kp_infunc(kp_source)
+     norm = float('inf')
+     frame_num = 0
+     for i, image in tqdm(enumerate(driving)):
+         kp_driving = fa.get_landmarks(255 * image)[0]
+         kp_driving = normalize_kp_infunc(kp_driving)
+         new_norm = (np.abs(kp_source - kp_driving) ** 2).sum()
+         if new_norm < norm:
+             norm = new_norm
+             frame_num = i
+     return frame_num
+
+
+ def drive_im(source_image, driving_image, adapt_scale):
+     source_image = resize(source_image, (256, 256))[..., :3]
+     driving_image = [resize(driving_image, (256, 256))[..., :3]]
+
+     prediction = make_animation(source_image, driving_image, generator, kp_detector, relative=False,
+                                 adapt_movement_scale=adapt_scale, cpu=cpu)
+     return img_as_ubyte(prediction[0])
+
+
+ def drive_vi(source_image, driving_video, mode, find_best_frame, relative, adapt_scale):
+     reader = imageio.get_reader(driving_video)
+     fps = reader.get_meta_data()['fps']
+     driving_video = []
+     try:
+         for im in reader:
+             driving_video.append(im)
+     except RuntimeError:
+         pass
+     reader.close()
+
+     if mode == 'reconstruction':
+         source_image = driving_video[0]
+
+     source_image = resize(source_image, (256, 256))[..., :3]
+     driving_video = [resize(frame, (256, 256))[..., :3] for frame in driving_video]
+
+     if find_best_frame:
+         i = find_best_frame_func(source_image, driving_video, cpu=cpu)
+         print("Best frame: " + str(i))
+         driving_forward = driving_video[i:]
+         driving_backward = driving_video[:(i + 1)][::-1]
+         predictions_forward = make_animation(source_image, driving_forward, generator, kp_detector,
+                                              relative=relative, adapt_movement_scale=adapt_scale, cpu=cpu)
+         predictions_backward = make_animation(source_image, driving_backward, generator, kp_detector,
+                                               relative=relative, adapt_movement_scale=adapt_scale, cpu=cpu)
+         predictions = predictions_backward[::-1] + predictions_forward[1:]
+     else:
+         predictions = make_animation(source_image, driving_video, generator, kp_detector, relative=relative,
+                                      adapt_movement_scale=adapt_scale, cpu=cpu)
+     result_video_path = "result_video.mp4"
+     imageio.mimsave(result_video_path, [img_as_ubyte(frame) for frame in predictions], fps=fps)
+     return result_video_path
+
+
+ os.environ['CUDA_VISIBLE_DEVICES'] = "3"
+ config = "config/vox-256.yaml"
+ checkpoint = "00000099-checkpoint.pth.tar"
+ mode = "reenactment"
+ relative = True
+ adapt_scale = True
+ find_best_frame = True
+ cpu = False  # decided by the deployment environment
+
+ description = "We propose a Face Neural Volume Rendering (FNeVR) network for more realistic face animation, by taking the merits of 2D motion warping on facial expression transformation and 3D volume rendering on high-quality image synthesis in a unified framework.<br>[Paper](https://arxiv.org/abs/2209.10340) and [Code](https://github.com/zengbohan0217/FNeVR)"
+ im_description = "We can animate a face portrait with a single image in this tab.<br>Please input the origin face and the driving face which provides the pose and expression information; the demo then returns the generated face.<br>The \"adaptive scale\" parameter can be selected for better optic flow estimation, using an adaptive movement scale based on the convex hull of the keypoints."
+ vi_description = "We can animate a face portrait with a video in this tab.<br>Please input the origin face and the driving video which provides the pose and expression information; the demo then returns the generated video.<br>Please select the inference mode (reenactment for different identities, reconstruction for the same identity).<br>The \"find best frame\" parameter generates the video starting from the driving frame that is best aligned with the source image, the \"relative motion\" parameter uses relative keypoint coordinates to preserve global object geometry, and the \"adaptive scale\" parameter enables better optic flow estimation using an adaptive movement scale based on the convex hull of the keypoints."
+ acknowledgements = "This work was supported by “the Fundamental Research Funds for the Central Universities”, and the National Natural Science Foundation of China under Grant 62076016, Beijing Natural Science Foundation-Xiaomi Innovation Joint Fund L223024. Besides, we gratefully acknowledge the support of [MindSpore](https://www.mindspore.cn), CANN (Compute Architecture for Neural Networks) and the Ascend AI processor used for this research.<br>Our FNeVR implementation is inspired by [FOMM](https://github.com/AliaksandrSiarohin/first-order-model) and [DECA](https://github.com/YadiraF/DECA). We appreciate the authors of these papers for making their code available to the public."
+
+ generator, kp_detector = load_checkpoints(config_path=config, checkpoint_path=checkpoint, cpu=cpu)
+
+ # iface = gr.Interface(fn=drive_im,
+ #                      inputs=[gr.Image(label="Origin face"),
+ #                              gr.Image(label="Driving face"),
+ #                              gr.CheckboxGroup(label="adapt scale")],
+ #                      outputs=gr.Image(label="Generated face"), examples=[["sup-mat/source.png"], ["sup-mat/driving.png"]],
+ #                      title="Demonstration of FNeVR", description=description)
+
+ with gr.Blocks(title="Demonstration of FNeVR") as demo:
+     gr.Markdown("# <center> Demonstration of FNeVR")
+     gr.Markdown(description)
+
+     with gr.Tab("Driving by image"):
+         gr.Markdown(im_description)
+
+         with gr.Row():
+             with gr.Column():
+                 gr.Markdown("#### Inputs")
+                 inp2 = gr.Image(label="Driving face")
+                 inp1 = gr.Image(label="Origin face")
+
+                 gr.Markdown("#### Parameter")
+                 inp3 = gr.Checkbox(value=True, label="adaptive scale")
+
+                 btn = gr.Button(value="Animate")
+             with gr.Column():
+                 gr.Markdown("#### Output")
+                 outp = gr.Image(label="Generated face")
+
+         gr.Examples([["sup-mat/driving.png", "sup-mat/source.png"]], [inp2, inp1])
+
+         btn.click(fn=drive_im, inputs=[inp1, inp2, inp3], outputs=outp)
+     with gr.Tab("Driving by video"):
+         gr.Markdown(vi_description)
+
+         with gr.Row():
+             with gr.Column():
+                 gr.Markdown("#### Inputs")
+                 inp2 = gr.Video(label="Driving video")
+                 inp1 = gr.Image(label="Origin face")
+
+                 gr.Markdown("#### Parameters")
+                 inp3 = gr.Radio(choices=['reenactment', 'reconstruction'], value="reenactment", label="mode (if \"reconstruction\" selected, origin face is the first frame of driving video)")
+                 inp4 = gr.Checkbox(value=True, label="find best frame (more time consumed)")
+                 inp5 = gr.Checkbox(value=True, label="relative motion")
+                 inp6 = gr.Checkbox(value=True, label="adaptive scale")
+
+                 btn = gr.Button(value="Animate")
+             with gr.Column():
+                 gr.Markdown("#### Output")
+                 outp = gr.Video(label="Generated video")
+
+         gr.Examples([["sup-mat/driving.mp4", "sup-mat/source_for_video.png"]], [inp2, inp1])
+
+         btn.click(fn=drive_vi, inputs=[inp1, inp2, inp3, inp4, inp5, inp6], outputs=outp)
+
+     with gr.Tab("Real time animation"):
+         gr.Markdown("#### Real time animation")
+
+     gr.Markdown("## Acknowledgements")
+     gr.Markdown(acknowledgements)
+
+ demo.launch()
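A minimal usage sketch (not part of the commit): calling the helpers above directly instead of through the Gradio UI. It assumes load_checkpoints and make_animation are importable without triggering the module-level demo launch (e.g. copied into a separate script), and that the checkpoint and the sup-mat example images referenced above are available.

import imageio
from skimage.transform import resize
from skimage import img_as_ubyte

# load_checkpoints / make_animation as defined in api.py above (assumed importable here)
generator, kp_detector = load_checkpoints(config_path="config/vox-256.yaml",
                                          checkpoint_path="00000099-checkpoint.pth.tar",
                                          cpu=True)

source = resize(imageio.imread("sup-mat/source.png"), (256, 256))[..., :3]
driving = [resize(imageio.imread("sup-mat/driving.png"), (256, 256))[..., :3]]

# relative=False mirrors drive_im: the absolute keypoints of the driving face are used.
prediction = make_animation(source, driving, generator, kp_detector,
                            relative=False, adapt_movement_scale=True, cpu=True)
imageio.imsave("result.png", img_as_ubyte(prediction[0]))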
augmentation.py ADDED
@@ -0,0 +1,345 @@
+ """
+ Code from https://github.com/hassony2/torch_videovision
+ """
+
+ import numbers
+
+ import random
+ import numpy as np
+ import PIL
+
+ from skimage.transform import resize, rotate
+ from skimage.util import pad
+ import torchvision
+
+ import warnings
+
+ from skimage import img_as_ubyte, img_as_float
+
+
+ def crop_clip(clip, min_h, min_w, h, w):
+     if isinstance(clip[0], np.ndarray):
+         cropped = [img[min_h:min_h + h, min_w:min_w + w, :] for img in clip]
+
+     elif isinstance(clip[0], PIL.Image.Image):
+         cropped = [
+             img.crop((min_w, min_h, min_w + w, min_h + h)) for img in clip
+         ]
+     else:
+         raise TypeError('Expected numpy.ndarray or PIL.Image '
+                         'but got list of {0}'.format(type(clip[0])))
+     return cropped
+
+
+ def pad_clip(clip, h, w):
+     im_h, im_w = clip[0].shape[:2]
+     pad_h = (0, 0) if h < im_h else ((h - im_h) // 2, (h - im_h + 1) // 2)
+     pad_w = (0, 0) if w < im_w else ((w - im_w) // 2, (w - im_w + 1) // 2)
+
+     return pad(clip, ((0, 0), pad_h, pad_w, (0, 0)), mode='edge')
+
+
+ def resize_clip(clip, size, interpolation='bilinear'):
+     if isinstance(clip[0], np.ndarray):
+         if isinstance(size, numbers.Number):
+             im_h, im_w, im_c = clip[0].shape
+             # Min spatial dim already matches minimal size
+             if (im_w <= im_h and im_w == size) or (im_h <= im_w and im_h == size):
+                 return clip
+             new_h, new_w = get_resize_sizes(im_h, im_w, size)
+             size = (new_w, new_h)
+         else:
+             size = size[1], size[0]
+
+         scaled = [
+             resize(img, size, order=1 if interpolation == 'bilinear' else 0, preserve_range=True,
+                    mode='constant', anti_aliasing=True) for img in clip
+         ]
+     elif isinstance(clip[0], PIL.Image.Image):
+         if isinstance(size, numbers.Number):
+             im_w, im_h = clip[0].size
+             # Min spatial dim already matches minimal size
+             if (im_w <= im_h and im_w == size) or (im_h <= im_w and im_h == size):
+                 return clip
+             new_h, new_w = get_resize_sizes(im_h, im_w, size)
+             size = (new_w, new_h)
+         else:
+             size = size[1], size[0]
+         # use the PIL filter that matches the requested interpolation
+         if interpolation == 'bilinear':
+             pil_inter = PIL.Image.BILINEAR
+         else:
+             pil_inter = PIL.Image.NEAREST
+         scaled = [img.resize(size, pil_inter) for img in clip]
+     else:
+         raise TypeError('Expected numpy.ndarray or PIL.Image '
+                         'but got list of {0}'.format(type(clip[0])))
+     return scaled
+
+
+ def get_resize_sizes(im_h, im_w, size):
+     if im_w < im_h:
+         ow = size
+         oh = int(size * im_h / im_w)
+     else:
+         oh = size
+         ow = int(size * im_w / im_h)
+     return oh, ow
+
+
+ class RandomFlip(object):
+     def __init__(self, time_flip=False, horizontal_flip=False):
+         self.time_flip = time_flip
+         self.horizontal_flip = horizontal_flip
+
+     def __call__(self, clip):
+         if random.random() < 0.5 and self.time_flip:
+             return clip[::-1]
+         if random.random() < 0.5 and self.horizontal_flip:
+             return [np.fliplr(img) for img in clip]
+
+         return clip
+
+
+ class RandomResize(object):
+     """Resizes a list of (H x W x C) numpy.ndarray to the final size
+     The larger the original image is, the longer it takes to interpolate
+     Args:
+         interpolation (str): Can be one of 'nearest', 'bilinear';
+             defaults to nearest
+         size (tuple): (width, height)
+     """
+
+     def __init__(self, ratio=(3. / 4., 4. / 3.), interpolation='nearest'):
+         self.ratio = ratio
+         self.interpolation = interpolation
+
+     def __call__(self, clip):
+         scaling_factor = random.uniform(self.ratio[0], self.ratio[1])
+
+         if isinstance(clip[0], np.ndarray):
+             im_h, im_w, im_c = clip[0].shape
+         elif isinstance(clip[0], PIL.Image.Image):
+             im_w, im_h = clip[0].size
+
+         new_w = int(im_w * scaling_factor)
+         new_h = int(im_h * scaling_factor)
+         new_size = (new_w, new_h)
+         resized = resize_clip(
+             clip, new_size, interpolation=self.interpolation)
+
+         return resized
+
+
+ class RandomCrop(object):
+     """Extract random crop at the same location for a list of videos
+     Args:
+         size (sequence or int): Desired output size for the
+             crop in format (h, w)
+     """
+
+     def __init__(self, size):
+         if isinstance(size, numbers.Number):
+             size = (size, size)
+
+         self.size = size
+
+     def __call__(self, clip):
+         """
+         Args:
+             clip (PIL.Image or numpy.ndarray): List of videos to be cropped
+                 in format (h, w, c) in numpy.ndarray
+         Returns:
+             PIL.Image or numpy.ndarray: Cropped list of videos
+         """
+         h, w = self.size
+         if isinstance(clip[0], np.ndarray):
+             im_h, im_w, im_c = clip[0].shape
+         elif isinstance(clip[0], PIL.Image.Image):
+             im_w, im_h = clip[0].size
+         else:
+             raise TypeError('Expected numpy.ndarray or PIL.Image '
+                             'but got list of {0}'.format(type(clip[0])))
+
+         clip = pad_clip(clip, h, w)
+         im_h, im_w = clip.shape[1:3]
+         x1 = 0 if h == im_h else random.randint(0, im_w - w)
+         y1 = 0 if w == im_w else random.randint(0, im_h - h)
+         cropped = crop_clip(clip, y1, x1, h, w)
+
+         return cropped
+
+
+ class RandomRotation(object):
+     """Rotate entire clip randomly by a random angle within
+     given bounds
+     Args:
+         degrees (sequence or int): Range of degrees to select from
+             If degrees is a number instead of a sequence like (min, max),
+             the range of degrees will be (-degrees, +degrees).
+     """
+
+     def __init__(self, degrees):
+         if isinstance(degrees, numbers.Number):
+             if degrees < 0:
+                 raise ValueError('If degrees is a single number, '
+                                  'it must be positive.')
+             degrees = (-degrees, degrees)
+         else:
+             if len(degrees) != 2:
+                 raise ValueError('If degrees is a sequence, '
+                                  'it must be of len 2.')
+
+         self.degrees = degrees
+
+     def __call__(self, clip):
+         """
+         Args:
+             clip (PIL.Image or numpy.ndarray): List of videos to be rotated
+                 in format (h, w, c) in numpy.ndarray
+         Returns:
+             PIL.Image or numpy.ndarray: Rotated list of videos
+         """
+         angle = random.uniform(self.degrees[0], self.degrees[1])
+         if isinstance(clip[0], np.ndarray):
+             rotated = [rotate(image=img, angle=angle, preserve_range=True) for img in clip]
+         elif isinstance(clip[0], PIL.Image.Image):
+             rotated = [img.rotate(angle) for img in clip]
+         else:
+             raise TypeError('Expected numpy.ndarray or PIL.Image '
+                             'but got list of {0}'.format(type(clip[0])))
+
+         return rotated
+
+
+ class ColorJitter(object):
+     """Randomly change the brightness, contrast, saturation and hue of the clip
+     Args:
+         brightness (float): How much to jitter brightness. brightness_factor
+             is chosen uniformly from [max(0, 1 - brightness), 1 + brightness].
+         contrast (float): How much to jitter contrast. contrast_factor
+             is chosen uniformly from [max(0, 1 - contrast), 1 + contrast].
+         saturation (float): How much to jitter saturation. saturation_factor
+             is chosen uniformly from [max(0, 1 - saturation), 1 + saturation].
+         hue (float): How much to jitter hue. hue_factor is chosen uniformly from
+             [-hue, hue]. Should be >= 0 and <= 0.5.
+     """
+
+     def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
+         self.brightness = brightness
+         self.contrast = contrast
+         self.saturation = saturation
+         self.hue = hue
+
+     def get_params(self, brightness, contrast, saturation, hue):
+         if brightness > 0:
+             brightness_factor = random.uniform(
+                 max(0, 1 - brightness), 1 + brightness)
+         else:
+             brightness_factor = None
+
+         if contrast > 0:
+             contrast_factor = random.uniform(
+                 max(0, 1 - contrast), 1 + contrast)
+         else:
+             contrast_factor = None
+
+         if saturation > 0:
+             saturation_factor = random.uniform(
+                 max(0, 1 - saturation), 1 + saturation)
+         else:
+             saturation_factor = None
+
+         if hue > 0:
+             hue_factor = random.uniform(-hue, hue)
+         else:
+             hue_factor = None
+         return brightness_factor, contrast_factor, saturation_factor, hue_factor
+
+     def __call__(self, clip):
+         """
+         Args:
+             clip (list): list of numpy.ndarray or PIL.Image
+         Returns:
+             list: transformed clip, same element type as the input
+         """
+         if isinstance(clip[0], np.ndarray):
+             brightness, contrast, saturation, hue = self.get_params(
+                 self.brightness, self.contrast, self.saturation, self.hue)
+
+             # Create img transform function sequence
+             img_transforms = []
+             if brightness is not None:
+                 img_transforms.append(lambda img: torchvision.transforms.functional.adjust_brightness(img, brightness))
+             if saturation is not None:
+                 img_transforms.append(lambda img: torchvision.transforms.functional.adjust_saturation(img, saturation))
+             if hue is not None:
+                 img_transforms.append(lambda img: torchvision.transforms.functional.adjust_hue(img, hue))
+             if contrast is not None:
+                 img_transforms.append(lambda img: torchvision.transforms.functional.adjust_contrast(img, contrast))
+             random.shuffle(img_transforms)
+             img_transforms = [img_as_ubyte, torchvision.transforms.ToPILImage()] + img_transforms + [np.array, img_as_float]
+
+             with warnings.catch_warnings():
+                 warnings.simplefilter("ignore")
+                 jittered_clip = []
+                 for img in clip:
+                     jittered_img = img
+                     for func in img_transforms:
+                         jittered_img = func(jittered_img)
+                     jittered_clip.append(jittered_img.astype('float32'))
+         elif isinstance(clip[0], PIL.Image.Image):
+             brightness, contrast, saturation, hue = self.get_params(
+                 self.brightness, self.contrast, self.saturation, self.hue)
+
+             # Create img transform function sequence
+             img_transforms = []
+             if brightness is not None:
+                 img_transforms.append(lambda img: torchvision.transforms.functional.adjust_brightness(img, brightness))
+             if saturation is not None:
+                 img_transforms.append(lambda img: torchvision.transforms.functional.adjust_saturation(img, saturation))
+             if hue is not None:
+                 img_transforms.append(lambda img: torchvision.transforms.functional.adjust_hue(img, hue))
+             if contrast is not None:
+                 img_transforms.append(lambda img: torchvision.transforms.functional.adjust_contrast(img, contrast))
+             random.shuffle(img_transforms)
+
+             # Apply to all frames, chaining the jitter functions so each sees the previous result
+             jittered_clip = []
+             for img in clip:
+                 jittered_img = img
+                 for func in img_transforms:
+                     jittered_img = func(jittered_img)
+                 jittered_clip.append(jittered_img)
+
+         else:
+             raise TypeError('Expected numpy.ndarray or PIL.Image '
+                             'but got list of {0}'.format(type(clip[0])))
+         return jittered_clip
+
+
+ class AllAugmentationTransform:
+     def __init__(self, resize_param=None, rotation_param=None, flip_param=None, crop_param=None, jitter_param=None):
+         self.transforms = []
+
+         if flip_param is not None:
+             self.transforms.append(RandomFlip(**flip_param))
+
+         if rotation_param is not None:
+             self.transforms.append(RandomRotation(**rotation_param))
+
+         if resize_param is not None:
+             self.transforms.append(RandomResize(**resize_param))
+
+         if crop_param is not None:
+             self.transforms.append(RandomCrop(**crop_param))
+
+         if jitter_param is not None:
+             self.transforms.append(ColorJitter(**jitter_param))
+
+     def __call__(self, clip):
+         for t in self.transforms:
+             clip = t(clip)
+         return clip
config/vox-256.yaml ADDED
@@ -0,0 +1,93 @@
+ dataset_params:
+   root_dir: /data/lh/repo/datasets/face-video-preprocessing/vox-png
+   frame_shape: [256, 256, 3]
+   id_sampling: True
+   pairs_list: data/vox256.csv
+   augmentation_params:
+     flip_param:
+       horizontal_flip: True
+       time_flip: True
+     jitter_param:
+       brightness: 0.1
+       contrast: 0.1
+       saturation: 0.1
+       hue: 0.1
+
+ model_params:
+   common_params:
+     num_kp: 10
+     num_channels: 3
+     estimate_jacobian: True
+   kp_detector_params:
+     temperature: 0.1
+     block_expansion: 32
+     max_features: 1024
+     scale_factor: 0.25
+     num_blocks: 5
+   generator_params:
+     block_expansion: 64
+     max_features: 512
+     num_down_blocks: 2
+     num_bottleneck_blocks: 6
+     estimate_occlusion_map: True
+     dense_motion_params:
+       block_expansion: 64
+       max_features: 1024
+       num_blocks: 5
+       scale_factor: 0.25
+     render_params:
+       simpled_channel_rgb: 128
+       simpled_channel_sigma: 128
+       floor_num: 8
+       hidden_size: 128
+   discriminator_params:
+     scales: [1]
+     block_expansion: 32
+     max_features: 512
+     num_blocks: 4
+     sn: True
+
+ train_params:
+   num_epochs: 100
+   num_repeats: 75
+   epoch_milestones: [60, 90]
+   lr_generator: 2.0e-4
+   lr_discriminator: 2.0e-4
+   lr_kp_detector: 2.0e-4
+   lr_face_editor: 2.0e-4
+   # batch_size: 40
+   batch_size: 20
+   scales: [1, 0.5, 0.25, 0.125]
+   # checkpoint_freq: 75
+   checkpoint_freq: 10
+   transform_params:
+     sigma_affine: 0.05
+     sigma_tps: 0.005
+     points_tps: 5
+   loss_weights:
+     generator_gan: 1  # 0
+     discriminator_gan: 1
+     feature_matching: [10, 10, 10, 10]
+     perceptual: [10, 10, 10, 10, 10]
+     equivariance_value: 10
+     equivariance_jacobian: 10
+     perceptual_l1: 5
+     pose_edit: 1
+
+ reconstruction_params:
+   num_videos: 1000
+   format: '.mp4'
+
+ animate_params:
+   num_pairs: 50
+   format: '.mp4'
+   normalization_params:
+     adapt_movement_scale: True
+     use_relative_movement: True
+     use_relative_jacobian: True
+
+ visualizer_params:
+   kp_size: 5
+   draw_border: True
+   colormap: 'gist_rainbow'
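A small sketch (not part of the commit) of how dataset_params.augmentation_params above plugs into AllAugmentationTransform from augmentation.py; the training dataset code that actually does this wiring is not included in this commit, so the hookup shown here is an assumption.

import yaml
import numpy as np
from augmentation import AllAugmentationTransform

with open("config/vox-256.yaml") as f:
    config = yaml.load(f, Loader=yaml.FullLoader)

# flip_param and jitter_param from the YAML become RandomFlip and ColorJitter transforms.
transform = AllAugmentationTransform(**config['dataset_params']['augmentation_params'])

# A clip is a list of H x W x C float frames in [0, 1]; the same flip / jitter
# parameters are applied consistently to every frame of the clip.
clip = [np.random.rand(256, 256, 3).astype('float32') for _ in range(2)]
augmented = transform(clip)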
modules/__pycache__/dense_motion.cpython-38.pyc ADDED
Binary file (3.83 kB)
modules/__pycache__/generator.cpython-38.pyc ADDED
Binary file (6.9 kB)
modules/__pycache__/keypoint_detector.cpython-38.pyc ADDED
Binary file (2.48 kB)
modules/__pycache__/nerf_verts_util.cpython-38.pyc ADDED
Binary file (7.04 kB)
modules/__pycache__/util.cpython-38.pyc ADDED
Binary file (7.82 kB)
modules/dense_motion.py ADDED
@@ -0,0 +1,114 @@
+ from torch import nn
+ import torch.nn.functional as F
+ import torch
+ from modules.util import Hourglass, AntiAliasInterpolation2d, make_coordinate_grid, kp2gaussian
+
+
+ class DenseMotionNetwork(nn.Module):
+     """
+     Module that predicts dense motion from the sparse motion representation given by kp_source and kp_driving
+     """
+
+     def __init__(self, block_expansion, num_blocks, max_features, num_kp, num_channels, estimate_occlusion_map=False,
+                  scale_factor=1, kp_variance=0.01):
+         super(DenseMotionNetwork, self).__init__()
+         self.hourglass = Hourglass(block_expansion=block_expansion, in_features=(num_kp + 1) * (num_channels + 1),
+                                    max_features=max_features, num_blocks=num_blocks)
+
+         self.mask = nn.Conv2d(self.hourglass.out_filters, num_kp + 1, kernel_size=(7, 7), padding=(3, 3))
+
+         if estimate_occlusion_map:
+             self.occlusion = nn.Conv2d(self.hourglass.out_filters, 1, kernel_size=(7, 7), padding=(3, 3))
+         else:
+             self.occlusion = None
+
+         self.num_kp = num_kp
+         self.scale_factor = scale_factor
+         self.kp_variance = kp_variance
+
+         if self.scale_factor != 1:
+             self.down = AntiAliasInterpolation2d(num_channels, self.scale_factor)
+
+     def create_heatmap_representations(self, source_image, kp_driving, kp_source):
+         """
+         Eq 6. in the paper H_k(z)
+         """
+         spatial_size = source_image.shape[2:]
+         gaussian_driving = kp2gaussian(kp_driving, spatial_size=spatial_size, kp_variance=self.kp_variance)
+         gaussian_source = kp2gaussian(kp_source, spatial_size=spatial_size, kp_variance=self.kp_variance)
+         heatmap = gaussian_driving - gaussian_source
+
+         # adding background feature
+         zeros = torch.zeros(heatmap.shape[0], 1, spatial_size[0], spatial_size[1]).type(heatmap.type())
+         heatmap = torch.cat([zeros, heatmap], dim=1)
+         heatmap = heatmap.unsqueeze(2)
+         return heatmap
+
+     def create_sparse_motions(self, source_image, kp_driving, kp_source):
+         """
+         Eq 4. in the paper T_{s<-d}(z)
+         """
+         bs, _, h, w = source_image.shape
+         identity_grid = make_coordinate_grid((h, w), type=kp_source['value'].type())
+         identity_grid = identity_grid.view(1, 1, h, w, 2)
+         coordinate_grid = identity_grid - kp_driving['value'].view(bs, self.num_kp, 1, 1, 2)
+         if 'jacobian' in kp_driving:
+             jacobian = torch.matmul(kp_source['jacobian'], torch.inverse(kp_driving['jacobian']))
+             jacobian = jacobian.unsqueeze(-3).unsqueeze(-3)
+             jacobian = jacobian.repeat(1, 1, h, w, 1, 1)
+             coordinate_grid = torch.matmul(jacobian, coordinate_grid.unsqueeze(-1))
+             coordinate_grid = coordinate_grid.squeeze(-1)
+
+         driving_to_source = coordinate_grid + kp_source['value'].view(bs, self.num_kp, 1, 1, 2)
+
+         # adding background feature
+         identity_grid = identity_grid.repeat(bs, 1, 1, 1, 1)
+         sparse_motions = torch.cat([identity_grid, driving_to_source], dim=1)
+         return sparse_motions
+
+     def create_deformed_source_image(self, source_image, sparse_motions):
+         """
+         Eq 7. in the paper \hat{T}_{s<-d}(z)
+         """
+         bs, _, h, w = source_image.shape
+         source_repeat = source_image.unsqueeze(1).unsqueeze(1).repeat(1, self.num_kp + 1, 1, 1, 1, 1)
+         source_repeat = source_repeat.view(bs * (self.num_kp + 1), -1, h, w)
+         sparse_motions = sparse_motions.view((bs * (self.num_kp + 1), h, w, -1))
+         sparse_deformed = F.grid_sample(source_repeat, sparse_motions)
+         sparse_deformed = sparse_deformed.view((bs, self.num_kp + 1, -1, h, w))
+         return sparse_deformed
+
+     def forward(self, source_image, kp_driving, kp_source):
+         if self.scale_factor != 1:
+             source_image = self.down(source_image)
+
+         bs, _, h, w = source_image.shape
+
+         out_dict = dict()
+         heatmap_representation = self.create_heatmap_representations(source_image, kp_driving, kp_source)
+         sparse_motion = self.create_sparse_motions(source_image, kp_driving, kp_source)
+         deformed_source = self.create_deformed_source_image(source_image, sparse_motion)
+         out_dict['sparse_deformed'] = deformed_source
+
+         input = torch.cat([heatmap_representation, deformed_source], dim=2)
+         input = input.view(bs, -1, h, w)
+
+         prediction = self.hourglass(input)
+
+         mask = self.mask(prediction)
+         mask = F.softmax(mask, dim=1)
+         out_dict['mask'] = mask
+         mask = mask.unsqueeze(2)
+         sparse_motion = sparse_motion.permute(0, 1, 4, 2, 3)
+         deformation = (sparse_motion * mask).sum(dim=1)
+         deformation = deformation.permute(0, 2, 3, 1)
+
+         out_dict['deformation'] = deformation
+
+         # Sec. 3.2 in the paper
+         if self.occlusion:
+             occlusion_map = torch.sigmoid(self.occlusion(prediction))
+             out_dict['occlusion_map'] = occlusion_map
+
+         return out_dict
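A shape-check sketch (not part of the commit), instantiating DenseMotionNetwork with the dense_motion_params and common_params from config/vox-256.yaml; the random keypoint dictionaries below merely stand in for KPDetector output.

import torch
from modules.dense_motion import DenseMotionNetwork

net = DenseMotionNetwork(block_expansion=64, num_blocks=5, max_features=1024,
                         num_kp=10, num_channels=3, estimate_occlusion_map=True,
                         scale_factor=0.25)

def random_kp(bs=1, num_kp=10):
    # 'value' holds (x, y) in [-1, 1]; 'jacobian' holds one 2x2 matrix per keypoint.
    return {'value': torch.rand(bs, num_kp, 2) * 2 - 1,
            'jacobian': torch.eye(2).repeat(bs, num_kp, 1, 1)}

with torch.no_grad():
    out = net(source_image=torch.rand(1, 3, 256, 256),
              kp_driving=random_kp(), kp_source=random_kp())

# With scale_factor=0.25 the network operates at 64x64:
# deformation [1, 64, 64, 2], occlusion_map [1, 1, 64, 64], mask [1, 11, 64, 64]
print({k: tuple(v.shape) for k, v in out.items()})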
modules/generator.py ADDED
@@ -0,0 +1,211 @@
+ import torch
+ from torch import nn
+ import torch.nn.functional as F
+ from modules.util import ResBlock2d, SameBlock2d, UpBlock2d, DownBlock2d
+ from modules.dense_motion import DenseMotionNetwork
+ from modules.nerf_verts_util import RenderModel
+
+
+ class SPADE_layer(nn.Module):
+     def __init__(self, norm_channel, label_channel):
+         super(SPADE_layer, self).__init__()
+
+         self.param_free_norm = nn.InstanceNorm2d(norm_channel, affine=False)
+         hidden_channel = 128
+
+         self.mlp_shared = nn.Sequential(
+             nn.Conv2d(label_channel, hidden_channel, kernel_size=3, padding=1),
+             nn.ReLU()
+         )
+         self.mlp_gamma = nn.Conv2d(hidden_channel, norm_channel, kernel_size=3, padding=1)
+         self.mlp_beta = nn.Conv2d(hidden_channel, norm_channel, kernel_size=3, padding=1)
+
+     def forward(self, x, modulation_in):
+         normalized = self.param_free_norm(x)
+         modulation_in = F.interpolate(modulation_in, size=x.size()[2:], mode='nearest')
+         actv = self.mlp_shared(modulation_in)
+         gamma = self.mlp_gamma(actv)
+         beta = self.mlp_beta(actv)
+         out = normalized * (1 + gamma) + beta
+         return out
+
+
+ class SPADE_block(nn.Module):
+     def __init__(self, norm_channel, label_channel, out_channel):
+         super(SPADE_block, self).__init__()
+         self.SPADE_0 = SPADE_layer(norm_channel, label_channel)
+         self.relu_0 = nn.ReLU()
+         self.conv_0 = nn.Conv2d(norm_channel, norm_channel, kernel_size=3, padding=1)
+         self.SPADE_1 = SPADE_layer(norm_channel, label_channel)
+         self.relu_1 = nn.ReLU()
+         self.conv_1 = nn.Conv2d(norm_channel, out_channel, kernel_size=3, padding=1)
+
+     def forward(self, x, modulation_in):
+         out = self.SPADE_0(x, modulation_in)
+         out = self.relu_0(out)
+         out = self.conv_0(out)
+         out = self.SPADE_1(out, modulation_in)
+         out = self.relu_1(out)
+         out = self.conv_1(out)
+         return out
+
+
+ class SPADE_decoder(nn.Module):
+     def __init__(self, in_channel, mid_channel):
+         super(SPADE_decoder, self).__init__()
+         self.in_channel = in_channel
+         self.mid_channel = mid_channel
+         self.seg_conv = nn.Sequential(
+             nn.Conv2d(in_channel, mid_channel, kernel_size=3, padding=1),
+             nn.ReLU()
+         )
+         self.SPADE_0 = SPADE_block(in_channel, mid_channel, in_channel // 4)
+         self.up_0 = nn.UpsamplingBilinear2d(scale_factor=2)
+         in_channel = in_channel // 4
+         self.SPADE_1 = SPADE_block(in_channel, mid_channel, in_channel // 4)
+         self.up_1 = nn.UpsamplingBilinear2d(scale_factor=2)
+         in_channel = in_channel // 4
+         self.SPADE_2 = SPADE_block(in_channel, mid_channel, in_channel)
+         self.SPADE_3 = SPADE_block(in_channel, mid_channel, in_channel)
+         self.final = nn.Sequential(
+             nn.Conv2d(in_channel, 3, kernel_size=7, padding=3),
+             nn.Sigmoid()
+         )
+
+     def forward(self, x):
+         seg = self.seg_conv(x)
+         x = self.SPADE_0(x, seg)
+         x = self.up_0(x)
+         x = self.SPADE_1(x, seg)
+         x = self.up_1(x)
+         x = self.SPADE_2(x, seg)
+         x = self.SPADE_3(x, seg)
+         x = self.final(x)
+         return x
+
+
+ def calc_mean_std(feat, eps=1e-5):
+     # eps is a small value added to the variance to avoid divide-by-zero.
+     size = feat.size()
+     assert (len(size) == 4)
+     N, C = size[:2]
+     feat_var = feat.view(N, C, -1).var(dim=2) + eps
+     feat_std = feat_var.sqrt().view(N, C, 1, 1)
+     feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1)
+     return feat_mean, feat_std
+
+
+ def adaptive_instance_normalization(x, modulation_in):
+     assert (x.size()[:2] == modulation_in.size()[:2])
+     size = x.size()
+     style_mean, style_std = calc_mean_std(modulation_in)
+     content_mean, content_std = calc_mean_std(x)
+
+     normalized_feat = (x - content_mean.expand(size)) / content_std.expand(size)
+     return normalized_feat * style_std.expand(size) + style_mean.expand(size)
+
+
+ class AdaIN_layer(nn.Module):
+     def __init__(self, norm_channel, label_channel):
+         super(AdaIN_layer, self).__init__()
+         self.param_free_norm = nn.InstanceNorm2d(norm_channel, affine=False)
+
+         self.mlp_shared = nn.Sequential(
+             nn.Conv2d(label_channel, norm_channel, kernel_size=3, padding=1),
+             nn.ReLU()
+         )
+
+     def forward(self, x, modulation_in):
+         normalized = self.param_free_norm(x)
+         modulation_in = self.mlp_shared(modulation_in)
+         out = adaptive_instance_normalization(normalized, modulation_in)
+         return out
+
+
+ class OcclusionAwareGenerator_SPADE(nn.Module):
+     """
+     Generator that, given a source image and keypoints, transforms the image according to the movement
+     trajectories induced by the keypoints. The generator follows the Johnson architecture.
+     """
+
+     def __init__(self, num_channels, num_kp, block_expansion, max_features, num_down_blocks,
+                  num_bottleneck_blocks, estimate_occlusion_map=False, dense_motion_params=None, render_params=None,
+                  estimate_jacobian=False):
+         super(OcclusionAwareGenerator_SPADE, self).__init__()
+
+         if dense_motion_params is not None:
+             self.dense_motion_network = DenseMotionNetwork(num_kp=num_kp, num_channels=num_channels,
+                                                            estimate_occlusion_map=estimate_occlusion_map,
+                                                            **dense_motion_params)
+         else:
+             self.dense_motion_network = None
+
+         self.first = SameBlock2d(num_channels, block_expansion, kernel_size=(7, 7), padding=(3, 3))
+
+         down_blocks = []
+         for i in range(num_down_blocks):
+             in_features = min(max_features, block_expansion * (2 ** i))
+             out_features = min(max_features, block_expansion * (2 ** (i + 1)))
+             down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1)))
+         self.down_blocks = nn.ModuleList(down_blocks)
+
+         in_features = min(max_features, block_expansion * (2 ** num_down_blocks))
+
+         self.Render_model = RenderModel(in_channels=in_features, **render_params)
+         self.decoder = SPADE_decoder(in_channel=in_features * 2, mid_channel=128)
+
+         self.estimate_occlusion_map = estimate_occlusion_map
+         self.num_channels = num_channels
+
+     def deform_input(self, inp, deformation):
+         _, h_old, w_old, _ = deformation.shape
+         _, _, h, w = inp.shape
+         if h_old != h or w_old != w:
+             deformation = deformation.permute(0, 3, 1, 2)
+             deformation = F.interpolate(deformation, size=(h, w), mode='bilinear')
+             deformation = deformation.permute(0, 2, 3, 1)
+         return F.grid_sample(inp, deformation)
+
+     def forward(self, source_image, kp_driving, kp_source):
+         # Encoding (downsampling) part
+         out = self.first(source_image)
+         for i in range(len(self.down_blocks)):
+             out = self.down_blocks[i](out)
+
+         # Transforming feature representation according to deformation and occlusion
+         output_dict = {}
+         if self.dense_motion_network is not None:
+             dense_motion = self.dense_motion_network(source_image=source_image, kp_driving=kp_driving,
+                                                      kp_source=kp_source)
+             output_dict['mask'] = dense_motion['mask']
+             output_dict['sparse_deformed'] = dense_motion['sparse_deformed']
+
+             if 'occlusion_map' in dense_motion:
+                 occlusion_map = dense_motion['occlusion_map']
+                 output_dict['occlusion_map'] = occlusion_map
+             else:
+                 occlusion_map = None
+             deformation = dense_motion['deformation']
+             out = self.deform_input(out, deformation)
+
+             if occlusion_map is not None:
+                 if out.shape[2] != occlusion_map.shape[2] or out.shape[3] != occlusion_map.shape[3]:
+                     occlusion_map = F.interpolate(occlusion_map, size=out.shape[2:], mode='bilinear')
+                 out = out * occlusion_map
+
+             output_dict["deformed"] = self.deform_input(source_image, deformation)
+
+         # render part
+         render_result = self.Render_model(feature=out)
+         output_dict['render'] = render_result['mini_pred']
+         output_dict['point_pred'] = render_result['point_pred']
+         out = torch.cat((out, render_result['render']), dim=1)
+         # out = self.merge_conv(out)
+
+         # Decoding part
+         out = self.decoder(out)
+
+         output_dict["prediction"] = out
+
+         return output_dict
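A tiny sketch (not part of the commit) of what a single SPADE_layer computes, on dummy tensors sized like the 512-channel decoder input used above (in_features * 2 with the vox-256 config).

import torch
from modules.generator import SPADE_layer

layer = SPADE_layer(norm_channel=512, label_channel=128)
x = torch.rand(1, 512, 64, 64)           # feature map to be modulated
modulation = torch.rand(1, 128, 64, 64)  # "segmentation" map, as produced by seg_conv in SPADE_decoder

# InstanceNorm(x) * (1 + gamma) + beta, with gamma and beta predicted per pixel from the modulation map.
out = layer(x, modulation)
print(out.shape)  # torch.Size([1, 512, 64, 64])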
modules/keypoint_detector.py ADDED
@@ -0,0 +1,76 @@
+ from torch import nn
+ import torch
+ import torch.nn.functional as F
+ from modules.util import Hourglass, make_coordinate_grid, AntiAliasInterpolation2d
+
+
+ class KPDetector(nn.Module):
+     """
+     Detects keypoints. Returns the keypoint positions and the jacobian near each keypoint.
+     """
+
+     def __init__(self, block_expansion, num_kp, num_channels, max_features,
+                  num_blocks, temperature, estimate_jacobian=False, estimate_hessian=False,
+                  scale_factor=1, single_jacobian_map=False, pad=0):
+         super(KPDetector, self).__init__()
+
+         self.predictor = Hourglass(block_expansion, in_features=num_channels,
+                                    max_features=max_features, num_blocks=num_blocks)
+
+         self.kp = nn.Conv2d(in_channels=self.predictor.out_filters, out_channels=num_kp, kernel_size=(7, 7),
+                             padding=pad)
+
+         if estimate_jacobian:
+             self.num_jacobian_maps = 1 if single_jacobian_map else num_kp
+             self.jacobian = nn.Conv2d(in_channels=self.predictor.out_filters,
+                                       out_channels=4 * self.num_jacobian_maps, kernel_size=(7, 7), padding=pad)
+             self.jacobian.weight.data.zero_()
+             self.jacobian.bias.data.copy_(torch.tensor([1, 0, 0, 1] * self.num_jacobian_maps, dtype=torch.float))
+         else:
+             self.jacobian = None
+
+         self.temperature = temperature
+         self.scale_factor = scale_factor
+         if self.scale_factor != 1:
+             self.down = AntiAliasInterpolation2d(num_channels, self.scale_factor)
+
+     def gaussian2kp(self, heatmap):
+         """
+         Extract the mean from a heatmap
+         """
+         shape = heatmap.shape
+         heatmap = heatmap.unsqueeze(-1)
+         grid = make_coordinate_grid(shape[2:], heatmap.type()).unsqueeze_(0).unsqueeze_(0)
+         value = (heatmap * grid).sum(dim=(2, 3))
+         kp = {'value': value}
+
+         return kp
+
+     def forward(self, x):
+         if self.scale_factor != 1:
+             x = self.down(x)
+
+         feature_map = self.predictor(x)
+         prediction = self.kp(feature_map)
+
+         final_shape = prediction.shape
+         heatmap = prediction.view(final_shape[0], final_shape[1], -1)
+         heatmap = F.softmax(heatmap / self.temperature, dim=2)
+         heatmap = heatmap.view(*final_shape)
+
+         out = self.gaussian2kp(heatmap)
+
+         if self.jacobian is not None:
+             jacobian_map = self.jacobian(feature_map)
+             jacobian_map = jacobian_map.reshape(final_shape[0], self.num_jacobian_maps, 4, final_shape[2],
+                                                 final_shape[3])
+             heatmap = heatmap.unsqueeze(2)
+
+             jacobian = heatmap * jacobian_map
+             jacobian = jacobian.view(final_shape[0], final_shape[1], 4, -1)
+             jacobian = jacobian.sum(dim=-1)
+             jacobian = jacobian.view(jacobian.shape[0], jacobian.shape[1], 2, 2)
+             out['jacobian'] = jacobian
+
+         return out
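A shape-check sketch (not part of the commit), using the kp_detector_params and common_params from config/vox-256.yaml.

import torch
from modules.keypoint_detector import KPDetector

kp_detector = KPDetector(block_expansion=32, num_kp=10, num_channels=3, max_features=1024,
                         num_blocks=5, temperature=0.1, estimate_jacobian=True, scale_factor=0.25)

with torch.no_grad():
    out = kp_detector(torch.rand(1, 3, 256, 256))

# The soft-argmax over each heatmap yields one (x, y) in [-1, 1] per keypoint, plus a 2x2
# jacobian per keypoint because estimate_jacobian is enabled in the config.
print(out['value'].shape, out['jacobian'].shape)  # torch.Size([1, 10, 2]) torch.Size([1, 10, 2, 2])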
modules/nerf_verts_util.py ADDED
@@ -0,0 +1,227 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from sync_batchnorm import SynchronizedBatchNorm2d as BatchNorm2d
+ from sync_batchnorm import SynchronizedBatchNorm3d as BatchNorm3d
+ import einops
+ from modules.util import UpBlock2d, DownBlock2d
+
+
+ def make_coordinate_grid(spatial_size, type):
+     d, h, w = spatial_size
+     x = torch.arange(w).type(type)
+     y = torch.arange(h).type(type)
+     z = torch.arange(d).type(type)
+
+     x = (2 * (x / (w - 1)) - 1)
+     y = (2 * (y / (h - 1)) - 1)
+     z = (2 * (z / (d - 1)) - 1)
+
+     yy = y.view(1, -1, 1).repeat(d, 1, w)
+     xx = x.view(1, 1, -1).repeat(d, h, 1)
+     zz = z.view(-1, 1, 1).repeat(1, h, w)
+
+     meshed = torch.cat([xx.unsqueeze_(3), yy.unsqueeze_(3), zz.unsqueeze_(3)], 3)
+
+     return meshed
+
+
+ def kp2gaussian_3d(kp, spatial_size, kp_variance):
+     """
+     Transform a keypoint into a gaussian-like representation
+     """
+     # mean = kp['value']
+     mean = kp
+
+     coordinate_grid = make_coordinate_grid(spatial_size, mean.type())
+     number_of_leading_dimensions = len(mean.shape) - 1
+     shape = (1,) * number_of_leading_dimensions + coordinate_grid.shape
+     coordinate_grid = coordinate_grid.view(*shape)
+     repeats = mean.shape[:number_of_leading_dimensions] + (1, 1, 1, 1)
+     coordinate_grid = coordinate_grid.repeat(*repeats)
+
+     # Preprocess kp shape
+     shape = mean.shape[:number_of_leading_dimensions] + (1, 1, 1, 3)
+     mean = mean.view(*shape)
+
+     mean_sub = (coordinate_grid - mean)
+
+     out = torch.exp(-0.5 * (mean_sub ** 2).sum(-1) / kp_variance)
+
+     return out
+
+
+ class ResBlock3d(nn.Module):
+     """
+     Res block, preserve spatial resolution.
+     """
+
+     def __init__(self, in_features, kernel_size, padding):
+         super(ResBlock3d, self).__init__()
+         self.conv1 = nn.Conv3d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,
+                                padding=padding)
+         self.conv2 = nn.Conv3d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,
+                                padding=padding)
+         self.norm1 = BatchNorm3d(in_features, affine=True)
+         self.norm2 = BatchNorm3d(in_features, affine=True)
+
+     def forward(self, x):
+         out = self.norm1(x)
+         out = F.relu(out)
+         out = self.conv1(out)
+         out = self.norm2(out)
+         out = F.relu(out)
+         out = self.conv2(out)
+         out += x
+         return out
+
+
+ class rgb_predictor(nn.Module):
+     def __init__(self, in_channels, simpled_channel=128, floor_num=8):
+         super(rgb_predictor, self).__init__()
+         self.floor_num = floor_num
+         self.down_conv = nn.Conv2d(in_channels=in_channels, out_channels=simpled_channel, kernel_size=3, padding=1)
+
+     def forward(self, feature):
+         """
+         Args:
+             feature: warped feature, bs * c * h * w
+         Returns:
+             rgb: bs * h * w * floor_num * e
+         """
+         feature = self.down_conv(feature)
+         feature = einops.rearrange(feature, 'b (c f) h w -> b c f h w', f=self.floor_num)
+         feature = einops.rearrange(feature, 'b c f h w -> b h w f c')
+         return feature
+
+
+ class sigma_predictor(nn.Module):
+     def __init__(self, in_channels, simpled_channel=128, floor_num=8):
+         super(sigma_predictor, self).__init__()
+         self.floor_num = floor_num
+         self.down_conv = nn.Conv2d(in_channels=in_channels, out_channels=simpled_channel, kernel_size=3, padding=1)
+
+         self.res_conv3d = nn.Sequential(
+             ResBlock3d(16, 3, 1),
+             nn.BatchNorm3d(16),
+             ResBlock3d(16, 3, 1),
+             nn.BatchNorm3d(16),
+             ResBlock3d(16, 3, 1),
+             nn.BatchNorm3d(16)
+         )
+
+     def forward(self, feature):
+         """
+         Args:
+             feature: warped feature, bs * c * h * w
+         Returns:
+             sigma: bs * h * w * floor * encode
+             point: bs * 5023 * 3
+         """
+         heatmap = self.down_conv(feature)
+         heatmap = einops.rearrange(heatmap, "b (c f) h w -> b c f h w", f=self.floor_num)
+         heatmap = self.res_conv3d(heatmap)
+         sigma = einops.rearrange(heatmap, "b c f h w -> b h w f c")
+
+         point_dict = {'sigma_map': heatmap}
+         # point_pred = einops.rearrange(point_pred, 'b p n -> b n p')
+         return sigma, point_dict
+
+
+ class MultiHeadNeRFModel(torch.nn.Module):
+
+     def __init__(self, hidden_size=128, num_encoding_rgb=16, num_encoding_sigma=16):
+         super(MultiHeadNeRFModel, self).__init__()
+         # self.xyz_encoding_dims = 1 + 1 * 2 * num_encoding_functions + num_encoding_rgb
+         self.xyz_encoding_dims = num_encoding_sigma
+         self.viewdir_encoding_dims = num_encoding_rgb
+
+         # Input layer (default: 16 -> 128)
+         self.layer1 = torch.nn.Linear(self.xyz_encoding_dims, hidden_size)
+         # Layer 2 (default: 128 -> 128)
+         self.layer2 = torch.nn.Linear(hidden_size, hidden_size)
+         # Layer 3_1 (default: 128 -> 1): Predicts radiance ("sigma")
+         self.layer3_1 = torch.nn.Linear(hidden_size, 1)
+         # Layer 3_2 (default: 128 -> 32): Predicts a feature vector (used for color)
+         self.layer3_2 = torch.nn.Linear(hidden_size, hidden_size // 4)
+         self.layer3_3 = torch.nn.Linear(self.viewdir_encoding_dims, hidden_size)
+
+         # Layer 4 (default: 32 + 128 -> 128)
+         self.layer4 = torch.nn.Linear(
+             hidden_size // 4 + hidden_size, hidden_size
+         )
+         # Layer 5 (default: 128 -> 128)
+         self.layer5 = torch.nn.Linear(hidden_size, hidden_size)
+         # Layer 6 (default: 128 -> 256): Predicts RGB color
+         self.layer6 = torch.nn.Linear(hidden_size, 256)
+
+         # Short hand for torch.nn.functional.relu
+         self.relu = torch.nn.functional.relu
+
+     def forward(self, rgb_in, sigma_in):
+         """
+         Args:
+             rgb_in: output of the rgb predictor, bs * h * w * floor * encode
+             sigma_in: output of the sigma predictor, bs * h * w * floor * encode
+         Returns:
+             the per-floor rgb features and sigma values
+         """
+         bs, h, w, floor_num, _ = rgb_in.size()
+         # x = torch.cat((x, point3D), dim=-1)
+         out = self.relu(self.layer1(sigma_in))
+         out = self.relu(self.layer2(out))
+         sigma = self.layer3_1(out)
+         feat_sigma = self.relu(self.layer3_2(out))
+         feat_rgb = self.relu(self.layer3_3(rgb_in))
+         x = torch.cat((feat_sigma, feat_rgb), dim=-1)
+         x = self.relu(self.layer4(x))
+         x = self.relu(self.layer5(x))
+         x = self.layer6(x)
+         return x, sigma
+
+
+ def volume_render(rgb_pred, sigma_pred):
+     """
+     Args:
+         rgb_pred: result of the NeRF module, [bs, h, w, floor, rgb_channel]
+         sigma_pred: result of the NeRF module, [bs, h, w, floor, sigma_channel]
+     Returns:
+         the composited feature map, [bs, rgb_channel, h, w]
+     """
+     _, _, _, floor, _ = sigma_pred.size()
+     c = 0
+     T = 0
+     for i in range(floor):
+         sigma_mid = torch.nn.functional.relu(sigma_pred[:, :, :, i, :])
+         T = T + (-sigma_mid)
+         c = c + torch.exp(T) * (1 - torch.exp(-sigma_mid)) * rgb_pred[:, :, :, i, :]
+     c = einops.rearrange(c, 'b h w c -> b c h w')
+     return c
+
+
+ class RenderModel(nn.Module):
+     def __init__(self, in_channels, simpled_channel_rgb, simpled_channel_sigma, floor_num, hidden_size):
+         super(RenderModel, self).__init__()
+         self.rgb_predict = rgb_predictor(in_channels=in_channels, simpled_channel=simpled_channel_rgb,
+                                          floor_num=floor_num)
+         self.sigma_predict = sigma_predictor(in_channels=in_channels, simpled_channel=simpled_channel_sigma,
+                                              floor_num=floor_num)
+         num_encoding_rgb, num_encoding_sigma = simpled_channel_rgb // floor_num, simpled_channel_sigma // floor_num
+         self.nerf_module = MultiHeadNeRFModel(hidden_size=hidden_size, num_encoding_rgb=num_encoding_rgb,
+                                               num_encoding_sigma=num_encoding_sigma)
+         self.mini_decoder = nn.Sequential(
+             UpBlock2d(256, 64, kernel_size=3, padding=1),
+             nn.ReLU(),
+             UpBlock2d(64, 3, kernel_size=3, padding=1),
+             nn.Sigmoid()
+         )
+
+     def forward(self, feature):
+         rgb_in = self.rgb_predict(feature)
+         # sigma_in, point_dict = self.sigma_predict(feature.detach())
+         sigma_in, point_dict = self.sigma_predict(feature)
+         rgb_out, sigma_out = self.nerf_module(rgb_in, sigma_in)
+         render_result = volume_render(rgb_out, sigma_out)
+         render_result = torch.sigmoid(render_result)
+         mini_pred = self.mini_decoder(render_result)
+         out_dict = {'render': render_result, 'mini_pred': mini_pred, 'point_pred': point_dict}
+         return out_dict
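A shape-check sketch (not part of the commit), building RenderModel with the render_params from config/vox-256.yaml and feeding it a dummy 256-channel, 64x64 feature map like the warped feature the generator passes in.

import torch
from modules.nerf_verts_util import RenderModel

render = RenderModel(in_channels=256, simpled_channel_rgb=128, simpled_channel_sigma=128,
                     floor_num=8, hidden_size=128)

with torch.no_grad():
    out = render(feature=torch.rand(1, 256, 64, 64))

# 'render' is the volume-rendered 256-channel map that the generator concatenates back onto
# the warped feature; 'mini_pred' is its small RGB decoding used as an auxiliary prediction.
print(out['render'].shape, out['mini_pred'].shape)  # torch.Size([1, 256, 64, 64]) torch.Size([1, 3, 256, 256])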
modules/util.py ADDED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn
2
+
3
+ import torch.nn.functional as F
4
+ import torch
5
+
6
+ from sync_batchnorm import SynchronizedBatchNorm2d as BatchNorm2d
7
+
8
+
9
+ def kp2gaussian(kp, spatial_size, kp_variance):
10
+ """
11
+ Transform a keypoint into gaussian like representation
12
+ """
13
+ mean = kp['value']
14
+
15
+ coordinate_grid = make_coordinate_grid(spatial_size, mean.type())
16
+ number_of_leading_dimensions = len(mean.shape) - 1
17
+ shape = (1,) * number_of_leading_dimensions + coordinate_grid.shape
18
+ coordinate_grid = coordinate_grid.view(*shape)
19
+ repeats = mean.shape[:number_of_leading_dimensions] + (1, 1, 1)
20
+ coordinate_grid = coordinate_grid.repeat(*repeats)
21
+
22
+ # Preprocess kp shape
23
+ shape = mean.shape[:number_of_leading_dimensions] + (1, 1, 2)
24
+ mean = mean.view(*shape)
25
+
26
+ mean_sub = (coordinate_grid - mean)
27
+
28
+ out = torch.exp(-0.5 * (mean_sub ** 2).sum(-1) / kp_variance)
29
+
30
+ return out
31
+
32
+
33
+ def make_coordinate_grid(spatial_size, type):
34
+ """
35
+ Create a meshgrid [-1,1] x [-1,1] of given spatial_size.
36
+ """
37
+ h, w = spatial_size
38
+ x = torch.arange(w).type(type)
39
+ y = torch.arange(h).type(type)
40
+
41
+ x = (2 * (x / (w - 1)) - 1)
42
+ y = (2 * (y / (h - 1)) - 1)
43
+
44
+ yy = y.view(-1, 1).repeat(1, w)
45
+ xx = x.view(1, -1).repeat(h, 1)
46
+
47
+ meshed = torch.cat([xx.unsqueeze_(2), yy.unsqueeze_(2)], 2)
48
+
49
+ return meshed
50
+
51
+
52
+ class ResBlock2d(nn.Module):
53
+ """
54
+ Res block, preserve spatial resolution.
55
+ """
56
+
57
+ def __init__(self, in_features, kernel_size, padding):
58
+ super(ResBlock2d, self).__init__()
59
+ self.conv1 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,
60
+ padding=padding)
61
+ self.conv2 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,
62
+ padding=padding)
63
+ self.norm1 = BatchNorm2d(in_features, affine=True)
64
+ self.norm2 = BatchNorm2d(in_features, affine=True)
65
+
66
+ def forward(self, x):
67
+ out = self.norm1(x)
68
+ out = F.relu(out)
69
+ out = self.conv1(out)
70
+ out = self.norm2(out)
71
+ out = F.relu(out)
72
+ out = self.conv2(out)
73
+ out += x
74
+ return out
75
+
76
+
77
+ class UpBlock2d(nn.Module):
78
+ """
79
+ Upsampling block for use in decoder.
80
+ """
81
+
82
+ def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
83
+ super(UpBlock2d, self).__init__()
84
+
85
+ self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
86
+ padding=padding, groups=groups)
87
+ self.norm = BatchNorm2d(out_features, affine=True)
88
+
89
+ def forward(self, x):
90
+ out = F.interpolate(x, scale_factor=2)
91
+ out = self.conv(out)
92
+ out = self.norm(out)
93
+ out = F.relu(out)
94
+ return out
95
+
96
+
97
+ class DownBlock2d(nn.Module):
98
+ """
99
+ Downsampling block for use in encoder.
100
+ """
101
+
102
+ def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
103
+ super(DownBlock2d, self).__init__()
104
+ self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
105
+ padding=padding, groups=groups)
106
+ self.norm = BatchNorm2d(out_features, affine=True)
107
+ self.pool = nn.AvgPool2d(kernel_size=(2, 2))
108
+
109
+ def forward(self, x):
110
+ out = self.conv(x)
111
+ out = self.norm(out)
112
+ out = F.relu(out)
113
+ out = self.pool(out)
114
+ return out
115
+
116
+
117
+ class SameBlock2d(nn.Module):
118
+ """
119
+ Simple convolution block that preserves spatial resolution.
120
+ """
121
+
122
+ def __init__(self, in_features, out_features, groups=1, kernel_size=3, padding=1):
123
+ super(SameBlock2d, self).__init__()
124
+ self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features,
125
+ kernel_size=kernel_size, padding=padding, groups=groups)
126
+ self.norm = BatchNorm2d(out_features, affine=True)
127
+
128
+ def forward(self, x):
129
+ out = self.conv(x)
130
+ out = self.norm(out)
131
+ out = F.relu(out)
132
+ return out
133
+
134
+
135
+ class Encoder(nn.Module):
136
+ """
137
+ Hourglass Encoder
138
+ """
139
+
140
+ def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
141
+ super(Encoder, self).__init__()
142
+
143
+ down_blocks = []
144
+ for i in range(num_blocks):
145
+ down_blocks.append(DownBlock2d(in_features if i == 0 else min(max_features, block_expansion * (2 ** i)),
146
+ min(max_features, block_expansion * (2 ** (i + 1))),
147
+ kernel_size=3, padding=1))
148
+ self.down_blocks = nn.ModuleList(down_blocks)
149
+
150
+ def forward(self, x):
151
+ outs = [x]
152
+ for down_block in self.down_blocks:
153
+ outs.append(down_block(outs[-1]))
154
+ return outs
155
+
156
+
157
+ class Decoder(nn.Module):
158
+ """
159
+ Hourglass Decoder
160
+ """
161
+
162
+ def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
163
+ super(Decoder, self).__init__()
164
+
165
+ up_blocks = []
166
+
167
+ for i in range(num_blocks)[::-1]:
168
+ in_filters = (1 if i == num_blocks - 1 else 2) * min(max_features, block_expansion * (2 ** (i + 1)))
169
+ out_filters = min(max_features, block_expansion * (2 ** i))
170
+ up_blocks.append(UpBlock2d(in_filters, out_filters, kernel_size=3, padding=1))
171
+
172
+ self.up_blocks = nn.ModuleList(up_blocks)
173
+ self.out_filters = block_expansion + in_features
174
+
175
+ def forward(self, x):
176
+ out = x.pop()
177
+ for up_block in self.up_blocks:
178
+ out = up_block(out)
179
+ skip = x.pop()
180
+ out = torch.cat([out, skip], dim=1)
181
+ return out
182
+
183
+
184
+ class Hourglass(nn.Module):
185
+ """
186
+ Hourglass architecture.
187
+ """
188
+
189
+ def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
190
+ super(Hourglass, self).__init__()
191
+ self.encoder = Encoder(block_expansion, in_features, num_blocks, max_features)
192
+ self.decoder = Decoder(block_expansion, in_features, num_blocks, max_features)
193
+ self.out_filters = self.decoder.out_filters
194
+
195
+ def forward(self, x):
196
+ return self.decoder(self.encoder(x))
197
+
198
+
199
+ class AntiAliasInterpolation2d(nn.Module):
200
+ """
201
+ Band-limited downsampling, for better preservation of the input signal.
202
+ """
203
+ def __init__(self, channels, scale):
204
+ super(AntiAliasInterpolation2d, self).__init__()
205
+ sigma = (1 / scale - 1) / 2
206
+ kernel_size = 2 * round(sigma * 4) + 1
207
+ self.ka = kernel_size // 2
208
+ self.kb = self.ka - 1 if kernel_size % 2 == 0 else self.ka
209
+
210
+ kernel_size = [kernel_size, kernel_size]
211
+ sigma = [sigma, sigma]
212
+ # The gaussian kernel is the product of the
213
+ # gaussian function of each dimension.
214
+ kernel = 1
215
+ meshgrids = torch.meshgrid(
216
+ [
217
+ torch.arange(size, dtype=torch.float32)
218
+ for size in kernel_size
219
+ ]
220
+ )
221
+ for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
222
+ mean = (size - 1) / 2
223
+ kernel *= torch.exp(-(mgrid - mean) ** 2 / (2 * std ** 2))
224
+
225
+ # Make sure sum of values in gaussian kernel equals 1.
226
+ kernel = kernel / torch.sum(kernel)
227
+ # Reshape to depthwise convolutional weight
228
+ kernel = kernel.view(1, 1, *kernel.size())
229
+ kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))
230
+
231
+ self.register_buffer('weight', kernel)
232
+ self.groups = channels
233
+ self.scale = scale
234
+ inv_scale = 1 / scale
235
+ self.int_inv_scale = int(inv_scale)
236
+
237
+ def forward(self, input):
238
+ if self.scale == 1.0:
239
+ return input
240
+
241
+ out = F.pad(input, (self.ka, self.kb, self.ka, self.kb))
242
+ out = F.conv2d(out, weight=self.weight, groups=self.groups)
243
+ out = out[:, :, ::self.int_inv_scale, ::self.int_inv_scale]
244
+
245
+ return out
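As a quick sanity check on the blocks above, here is a minimal sketch (not part of the commit) of how kp2gaussian, Hourglass and AntiAliasInterpolation2d are typically exercised; the batch size, keypoint count and resolutions are illustrative assumptions, and the names are assumed to be importable from this module.

import torch

# 10 keypoints in normalized [-1, 1] coordinates -> one 64x64 heatmap per keypoint.
kp = {'value': torch.rand(1, 10, 2) * 2 - 1}
heatmaps = kp2gaussian(kp, spatial_size=(64, 64), kp_variance=0.01)
print(heatmaps.shape)  # torch.Size([1, 10, 64, 64])

# Hourglass keeps the spatial size and concatenates encoder skips,
# so out_filters = block_expansion + in_features.
hourglass = Hourglass(block_expansion=32, in_features=10, num_blocks=3, max_features=256)
features = hourglass(heatmaps)
print(hourglass.out_filters, features.shape)  # 42 torch.Size([1, 42, 64, 64])

# Band-limited 4x downsampling of an RGB image.
down = AntiAliasInterpolation2d(channels=3, scale=0.25)
print(down(torch.randn(1, 3, 256, 256)).shape)  # torch.Size([1, 3, 64, 64])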
requirements.txt ADDED
@@ -0,0 +1,93 @@
1
+ absl-py==1.0.0
2
+ aiohttp==3.7.4.post0
3
+ astropy==5.0.4
4
+ astunparse==1.6.3
5
+ async-timeout==3.0.1
6
+ attrs==21.4.0
7
+ blessings==1.7
8
+ blinker==1.4
9
+ brotlipy==0.7.0
10
+ cachetools==5.0.0
11
+ certifi==2021.10.8
12
+ cffi==1.15.0
13
+ chardet==4.0.0
14
+ charset-normalizer==2.0.12
15
+ chumpy==0.70
16
+ colorama==0.4.4
17
+ cryptography==35.0.0
18
+ cycler==0.11.0
19
+ Cython==0.29.28
20
+ dill==0.3.5.1
21
+ dominate==2.6.0
22
+ einops==0.4.1
23
+ face-alignment==1.3.5
24
+ flatbuffers==2.0
25
+ fonttools==4.33.3
26
+ fvcore==0.1.5.post20210915
27
+ gast==0.5.3
28
+ google-auth==2.6.6
29
+ google-auth-oauthlib==0.4.6
30
+ google-pasta==0.2.0
31
+ grpcio==1.46.0
32
+ h5py==3.6.0
33
+ idna==3.3
34
+ imageio==2.19.1
35
+ imageio-ffmpeg==0.4.7
36
+ importlib-metadata==4.11.3
37
+ iopath==0.1.9
38
+ joblib==1.1.0
39
+ keras==2.8.0
40
+ Keras-Preprocessing==1.1.2
41
+ kiwisolver==1.4.2
42
+ libclang==14.0.1
43
+ llvmlite==0.38.0
44
+ matplotlib==3.5.2
45
+ mkl-fft==1.3.0
46
+ mkl-random==1.2.2
47
+ mkl-service==2.4.0
48
+ multidict==5.2.0
49
+ multiprocess==0.70.12.2
50
+ networkx==2.8
51
+ numba
52
+ nvidia-ml-py3==7.352.0
53
+ oauthlib==3.2.0
54
+ onnx==1.11.0
55
+ opencv-python==4.5.5.64
56
+ opt-einsum==3.3.0
57
+ packaging==21.3
58
+ pandas==1.4.2
59
+ Pillow==9.0.1
60
+ portalocker==2.4.0
61
+ protobuf==3.20.1
62
+ psutil==5.9.0
63
+ pyasn1==0.4.8
64
+ pyasn1-modules==0.2.8
65
+ pycparser==2.21
66
+ pyerfa==2.0.0.1
67
+ PyJWT==2.3.0
68
+ pyOpenSSL==22.0.0
69
+ pyparsing==3.0.8
70
+ PySocks==1.7.1
71
+ python-dateutil==2.8.2
72
+ pytz==2022.1
73
+ pyu2f==0.1.5
74
+ PyWavelets==1.3.0
75
+ PyYAML==5.4.1
76
+ requests==2.27.1
77
+ requests-oauthlib==1.3.1
78
+ requests-toolbelt==0.9.1
79
+ rsa==4.8
80
+ scikit-image==0.17.2
81
+ scikit-learn==1.0.2
82
+ scipy==1.8.0
83
+ six==1.16.0
84
+ tabulate==0.8.9
85
+ threadpoolctl==3.1.0
86
+ tifffile==2022.5.4
87
+ tqdm==4.64.0
88
+ typing_extensions==4.1.1
89
+ urllib3==1.26.9
90
+ Werkzeug==2.1.2
91
+ wrapt==1.14.1
92
+ yacs==0.1.8
93
+ yarl==1.6.3
sup-mat/driving.mp4 ADDED
Binary file (101 kB).
 
sup-mat/driving.png ADDED
sup-mat/source.png ADDED
sup-mat/source_for_video.png ADDED
sync_batchnorm/__init__.py ADDED
@@ -0,0 +1,12 @@
1
+ # -*- coding: utf-8 -*-
2
+ # File : __init__.py
3
+ # Author : Jiayuan Mao
4
+ # Email : maojiayuan@gmail.com
5
+ # Date : 27/01/2018
6
+ #
7
+ # This file is part of Synchronized-BatchNorm-PyTorch.
8
+ # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9
+ # Distributed under MIT License.
10
+
11
+ from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
12
+ from .replicate import DataParallelWithCallback, patch_replication_callback
sync_batchnorm/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (371 Bytes).
 
sync_batchnorm/__pycache__/batchnorm.cpython-38.pyc ADDED
Binary file (12.9 kB).
 
sync_batchnorm/__pycache__/comm.cpython-38.pyc ADDED
Binary file (4.81 kB).
 
sync_batchnorm/__pycache__/replicate.cpython-38.pyc ADDED
Binary file (3.46 kB).
 
sync_batchnorm/batchnorm.py ADDED
@@ -0,0 +1,315 @@
1
+ # -*- coding: utf-8 -*-
2
+ # File : batchnorm.py
3
+ # Author : Jiayuan Mao
4
+ # Email : maojiayuan@gmail.com
5
+ # Date : 27/01/2018
6
+ #
7
+ # This file is part of Synchronized-BatchNorm-PyTorch.
8
+ # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9
+ # Distributed under MIT License.
10
+
11
+ import collections
12
+
13
+ import torch
14
+ import torch.nn.functional as F
15
+
16
+ from torch.nn.modules.batchnorm import _BatchNorm
17
+ from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
18
+
19
+ from .comm import SyncMaster
20
+
21
+ __all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d']
22
+
23
+
24
+ def _sum_ft(tensor):
25
+ """sum over the first and last dimention"""
26
+ return tensor.sum(dim=0).sum(dim=-1)
27
+
28
+
29
+ def _unsqueeze_ft(tensor):
30
+ """add new dementions at the front and the tail"""
31
+ return tensor.unsqueeze(0).unsqueeze(-1)
32
+
33
+
34
+ _ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])
35
+ _MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])
36
+
37
+
38
+ class _SynchronizedBatchNorm(_BatchNorm):
39
+ def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
40
+ super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)
41
+
42
+ self._sync_master = SyncMaster(self._data_parallel_master)
43
+
44
+ self._is_parallel = False
45
+ self._parallel_id = None
46
+ self._slave_pipe = None
47
+
48
+ def forward(self, input):
49
+ # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
50
+ if not (self._is_parallel and self.training):
51
+ return F.batch_norm(
52
+ input, self.running_mean, self.running_var, self.weight, self.bias,
53
+ self.training, self.momentum, self.eps)
54
+
55
+ # Resize the input to (B, C, -1).
56
+ input_shape = input.size()
57
+ input = input.view(input.size(0), self.num_features, -1)
58
+
59
+ # Compute the sum and square-sum.
60
+ sum_size = input.size(0) * input.size(2)
61
+ input_sum = _sum_ft(input)
62
+ input_ssum = _sum_ft(input ** 2)
63
+
64
+ # Reduce-and-broadcast the statistics.
65
+ if self._parallel_id == 0:
66
+ mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
67
+ else:
68
+ mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))
69
+
70
+ # Compute the output.
71
+ if self.affine:
72
+ # MJY:: Fuse the multiplication for speed.
73
+ output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
74
+ else:
75
+ output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)
76
+
77
+ # Reshape it.
78
+ return output.view(input_shape)
79
+
80
+ def __data_parallel_replicate__(self, ctx, copy_id):
81
+ self._is_parallel = True
82
+ self._parallel_id = copy_id
83
+
84
+ # parallel_id == 0 means master device.
85
+ if self._parallel_id == 0:
86
+ ctx.sync_master = self._sync_master
87
+ else:
88
+ self._slave_pipe = ctx.sync_master.register_slave(copy_id)
89
+
90
+ def _data_parallel_master(self, intermediates):
91
+ """Reduce the sum and square-sum, compute the statistics, and broadcast it."""
92
+
93
+ # Always using same "device order" makes the ReduceAdd operation faster.
94
+ # Thanks to:: Tete Xiao (http://tetexiao.com/)
95
+ intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())
96
+
97
+ to_reduce = [i[1][:2] for i in intermediates]
98
+ to_reduce = [j for i in to_reduce for j in i] # flatten
99
+ target_gpus = [i[1].sum.get_device() for i in intermediates]
100
+
101
+ sum_size = sum([i[1].sum_size for i in intermediates])
102
+ sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
103
+ mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)
104
+
105
+ broadcasted = Broadcast.apply(target_gpus, mean, inv_std)
106
+
107
+ outputs = []
108
+ for i, rec in enumerate(intermediates):
109
+ outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2])))
110
+
111
+ return outputs
112
+
113
+ def _compute_mean_std(self, sum_, ssum, size):
114
+ """Compute the mean and standard-deviation with sum and square-sum. This method
115
+ also maintains the moving average on the master device."""
116
+ assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
117
+ mean = sum_ / size
118
+ sumvar = ssum - sum_ * mean
119
+ unbias_var = sumvar / (size - 1)
120
+ bias_var = sumvar / size
121
+
122
+ self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
123
+ self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data
124
+
125
+ return mean, bias_var.clamp(self.eps) ** -0.5
126
+
127
+
128
+ class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
129
+ r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
130
+ mini-batch.
131
+
132
+ .. math::
133
+
134
+ y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
135
+
136
+ This module differs from the built-in PyTorch BatchNorm1d as the mean and
137
+ standard-deviation are reduced across all devices during training.
138
+
139
+ For example, when one uses `nn.DataParallel` to wrap the network during
140
+ training, PyTorch's implementation normalizes the tensor on each device using
141
+ the statistics only on that device, which accelerates the computation and
142
+ is also easy to implement, but the statistics might be inaccurate.
143
+ Instead, in this synchronized version, the statistics will be computed
144
+ over all training samples distributed on multiple devices.
145
+
146
+ Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
147
+ as the built-in PyTorch implementation.
148
+
149
+ The mean and standard-deviation are calculated per-dimension over
150
+ the mini-batches and gamma and beta are learnable parameter vectors
151
+ of size C (where C is the input size).
152
+
153
+ During training, this layer keeps a running estimate of its computed mean
154
+ and variance. The running sum is kept with a default momentum of 0.1.
155
+
156
+ During evaluation, this running mean/variance is used for normalization.
157
+
158
+ Because the BatchNorm is done over the `C` dimension, computing statistics
159
+ on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm
160
+
161
+ Args:
162
+ num_features: num_features from an expected input of size
163
+ `batch_size x num_features [x width]`
164
+ eps: a value added to the denominator for numerical stability.
165
+ Default: 1e-5
166
+ momentum: the value used for the running_mean and running_var
167
+ computation. Default: 0.1
168
+ affine: a boolean value that when set to ``True``, gives the layer learnable
169
+ affine parameters. Default: ``True``
170
+
171
+ Shape:
172
+ - Input: :math:`(N, C)` or :math:`(N, C, L)`
173
+ - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)
174
+
175
+ Examples:
176
+ >>> # With Learnable Parameters
177
+ >>> m = SynchronizedBatchNorm1d(100)
178
+ >>> # Without Learnable Parameters
179
+ >>> m = SynchronizedBatchNorm1d(100, affine=False)
180
+ >>> input = torch.autograd.Variable(torch.randn(20, 100))
181
+ >>> output = m(input)
182
+ """
183
+
184
+ def _check_input_dim(self, input):
185
+ if input.dim() != 2 and input.dim() != 3:
186
+ raise ValueError('expected 2D or 3D input (got {}D input)'
187
+ .format(input.dim()))
188
+ super(SynchronizedBatchNorm1d, self)._check_input_dim(input)
189
+
190
+
191
+ class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
192
+ r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch
193
+ of 3d inputs
194
+
195
+ .. math::
196
+
197
+ y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
198
+
199
+ This module differs from the built-in PyTorch BatchNorm2d as the mean and
200
+ standard-deviation are reduced across all devices during training.
201
+
202
+ For example, when one uses `nn.DataParallel` to wrap the network during
203
+ training, PyTorch's implementation normalizes the tensor on each device using
204
+ the statistics only on that device, which accelerates the computation and
205
+ is also easy to implement, but the statistics might be inaccurate.
206
+ Instead, in this synchronized version, the statistics will be computed
207
+ over all training samples distributed on multiple devices.
208
+
209
+ Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
210
+ as the built-in PyTorch implementation.
211
+
212
+ The mean and standard-deviation are calculated per-dimension over
213
+ the mini-batches and gamma and beta are learnable parameter vectors
214
+ of size C (where C is the input size).
215
+
216
+ During training, this layer keeps a running estimate of its computed mean
217
+ and variance. The running sum is kept with a default momentum of 0.1.
218
+
219
+ During evaluation, this running mean/variance is used for normalization.
220
+
221
+ Because the BatchNorm is done over the `C` dimension, computing statistics
222
+ on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm
223
+
224
+ Args:
225
+ num_features: num_features from an expected input of
226
+ size batch_size x num_features x height x width
227
+ eps: a value added to the denominator for numerical stability.
228
+ Default: 1e-5
229
+ momentum: the value used for the running_mean and running_var
230
+ computation. Default: 0.1
231
+ affine: a boolean value that when set to ``True``, gives the layer learnable
232
+ affine parameters. Default: ``True``
233
+
234
+ Shape:
235
+ - Input: :math:`(N, C, H, W)`
236
+ - Output: :math:`(N, C, H, W)` (same shape as input)
237
+
238
+ Examples:
239
+ >>> # With Learnable Parameters
240
+ >>> m = SynchronizedBatchNorm2d(100)
241
+ >>> # Without Learnable Parameters
242
+ >>> m = SynchronizedBatchNorm2d(100, affine=False)
243
+ >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45))
244
+ >>> output = m(input)
245
+ """
246
+
247
+ def _check_input_dim(self, input):
248
+ if input.dim() != 4:
249
+ raise ValueError('expected 4D input (got {}D input)'
250
+ .format(input.dim()))
251
+ super(SynchronizedBatchNorm2d, self)._check_input_dim(input)
252
+
253
+
254
+ class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
255
+ r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch
256
+ of 4d inputs
257
+
258
+ .. math::
259
+
260
+ y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
261
+
262
+ This module differs from the built-in PyTorch BatchNorm3d as the mean and
263
+ standard-deviation are reduced across all devices during training.
264
+
265
+ For example, when one uses `nn.DataParallel` to wrap the network during
266
+ training, PyTorch's implementation normalizes the tensor on each device using
267
+ the statistics only on that device, which accelerates the computation and
268
+ is also easy to implement, but the statistics might be inaccurate.
269
+ Instead, in this synchronized version, the statistics will be computed
270
+ over all training samples distributed on multiple devices.
271
+
272
+ Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
273
+ as the built-in PyTorch implementation.
274
+
275
+ The mean and standard-deviation are calculated per-dimension over
276
+ the mini-batches and gamma and beta are learnable parameter vectors
277
+ of size C (where C is the input size).
278
+
279
+ During training, this layer keeps a running estimate of its computed mean
280
+ and variance. The running sum is kept with a default momentum of 0.1.
281
+
282
+ During evaluation, this running mean/variance is used for normalization.
283
+
284
+ Because the BatchNorm is done over the `C` dimension, computing statistics
285
+ on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm
286
+ or Spatio-temporal BatchNorm
287
+
288
+ Args:
289
+ num_features: num_features from an expected input of
290
+ size batch_size x num_features x depth x height x width
291
+ eps: a value added to the denominator for numerical stability.
292
+ Default: 1e-5
293
+ momentum: the value used for the running_mean and running_var
294
+ computation. Default: 0.1
295
+ affine: a boolean value that when set to ``True``, gives the layer learnable
296
+ affine parameters. Default: ``True``
297
+
298
+ Shape:
299
+ - Input: :math:`(N, C, D, H, W)`
300
+ - Output: :math:`(N, C, D, H, W)` (same shape as input)
301
+
302
+ Examples:
303
+ >>> # With Learnable Parameters
304
+ >>> m = SynchronizedBatchNorm3d(100)
305
+ >>> # Without Learnable Parameters
306
+ >>> m = SynchronizedBatchNorm3d(100, affine=False)
307
+ >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10))
308
+ >>> output = m(input)
309
+ """
310
+
311
+ def _check_input_dim(self, input):
312
+ if input.dim() != 5:
313
+ raise ValueError('expected 5D input (got {}D input)'
314
+ .format(input.dim()))
315
+ super(SynchronizedBatchNorm3d, self)._check_input_dim(input)
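The docstring examples above exercise a single module; the synchronized statistics only come into play once the module is replicated across devices. A minimal multi-GPU sketch (not part of the commit, assuming at least two visible CUDA devices and that the package is importable as sync_batchnorm):

import torch
from torch import nn
from sync_batchnorm import SynchronizedBatchNorm2d, DataParallelWithCallback

# Toy network whose normalization statistics should be reduced across replicas.
model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    SynchronizedBatchNorm2d(16),
    nn.ReLU(),
).cuda()
model = DataParallelWithCallback(model, device_ids=[0, 1])

x = torch.randn(8, 3, 32, 32).cuda()
y = model(x)  # BN mean/var are computed over all 8 samples, not 4 per GPU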
sync_batchnorm/comm.py ADDED
@@ -0,0 +1,137 @@
1
+ # -*- coding: utf-8 -*-
2
+ # File : comm.py
3
+ # Author : Jiayuan Mao
4
+ # Email : maojiayuan@gmail.com
5
+ # Date : 27/01/2018
6
+ #
7
+ # This file is part of Synchronized-BatchNorm-PyTorch.
8
+ # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9
+ # Distributed under MIT License.
10
+
11
+ import queue
12
+ import collections
13
+ import threading
14
+
15
+ __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']
16
+
17
+
18
+ class FutureResult(object):
19
+ """A thread-safe future implementation. Used only as one-to-one pipe."""
20
+
21
+ def __init__(self):
22
+ self._result = None
23
+ self._lock = threading.Lock()
24
+ self._cond = threading.Condition(self._lock)
25
+
26
+ def put(self, result):
27
+ with self._lock:
28
+ assert self._result is None, 'Previous result hasn\'t been fetched.'
29
+ self._result = result
30
+ self._cond.notify()
31
+
32
+ def get(self):
33
+ with self._lock:
34
+ if self._result is None:
35
+ self._cond.wait()
36
+
37
+ res = self._result
38
+ self._result = None
39
+ return res
40
+
41
+
42
+ _MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
43
+ _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])
44
+
45
+
46
+ class SlavePipe(_SlavePipeBase):
47
+ """Pipe for master-slave communication."""
48
+
49
+ def run_slave(self, msg):
50
+ self.queue.put((self.identifier, msg))
51
+ ret = self.result.get()
52
+ self.queue.put(True)
53
+ return ret
54
+
55
+
56
+ class SyncMaster(object):
57
+ """An abstract `SyncMaster` object.
58
+
59
+ - During the replication, as the data parallel will trigger a callback on each module, all slave devices should
60
+ call `register_slave(id)` and obtain a `SlavePipe` to communicate with the master.
61
+ - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected,
62
+ and passed to a registered callback.
63
+ - After receiving the messages, the master device should gather the information and determine the message to be passed
64
+ back to each slave device.
65
+ """
66
+
67
+ def __init__(self, master_callback):
68
+ """
69
+
70
+ Args:
71
+ master_callback: a callback to be invoked after having collected messages from slave devices.
72
+ """
73
+ self._master_callback = master_callback
74
+ self._queue = queue.Queue()
75
+ self._registry = collections.OrderedDict()
76
+ self._activated = False
77
+
78
+ def __getstate__(self):
79
+ return {'master_callback': self._master_callback}
80
+
81
+ def __setstate__(self, state):
82
+ self.__init__(state['master_callback'])
83
+
84
+ def register_slave(self, identifier):
85
+ """
86
+ Register a slave device.
87
+
88
+ Args:
89
+ identifier: an identifier, usually is the device id.
90
+
91
+ Returns: a `SlavePipe` object which can be used to communicate with the master device.
92
+
93
+ """
94
+ if self._activated:
95
+ assert self._queue.empty(), 'Queue is not clean before next initialization.'
96
+ self._activated = False
97
+ self._registry.clear()
98
+ future = FutureResult()
99
+ self._registry[identifier] = _MasterRegistry(future)
100
+ return SlavePipe(identifier, self._queue, future)
101
+
102
+ def run_master(self, master_msg):
103
+ """
104
+ Main entry for the master device in each forward pass.
105
+ The messages are first collected from each device (including the master device), and then
106
+ a callback will be invoked to compute the message to be sent back to each device
107
+ (including the master device).
108
+
109
+ Args:
110
+ master_msg: the message that the master wants to send to itself. This will be placed as the first
111
+ message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.
112
+
113
+ Returns: the message to be sent back to the master device.
114
+
115
+ """
116
+ self._activated = True
117
+
118
+ intermediates = [(0, master_msg)]
119
+ for i in range(self.nr_slaves):
120
+ intermediates.append(self._queue.get())
121
+
122
+ results = self._master_callback(intermediates)
123
+ assert results[0][0] == 0, 'The first result should belong to the master.'
124
+
125
+ for i, res in results:
126
+ if i == 0:
127
+ continue
128
+ self._registry[i].result.put(res)
129
+
130
+ for i in range(self.nr_slaves):
131
+ assert self._queue.get() is True
132
+
133
+ return results[0][1]
134
+
135
+ @property
136
+ def nr_slaves(self):
137
+ return len(self._registry)
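The master/slave handshake above is easiest to see with plain threads standing in for GPU replicas. A minimal sketch (not part of the commit; the callback that adds 100 to each message is purely illustrative):

import threading
from sync_batchnorm.comm import SyncMaster

def master_callback(intermediates):
    # intermediates is a list of (identifier, message), master first; reply to each.
    return [(identifier, message + 100) for identifier, message in intermediates]

master = SyncMaster(master_callback)
pipes = [master.register_slave(i) for i in (1, 2)]

results = {}

def slave(pipe):
    # Each "slave" sends its own identifier as the message and waits for the reply.
    results[pipe.identifier] = pipe.run_slave(pipe.identifier)

threads = [threading.Thread(target=slave, args=(p,)) for p in pipes]
for t in threads:
    t.start()
results[0] = master.run_master(0)  # master's own message is 0
for t in threads:
    t.join()
print(sorted(results.items()))  # [(0, 100), (1, 101), (2, 102)]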
sync_batchnorm/replicate.py ADDED
@@ -0,0 +1,94 @@
1
+ # -*- coding: utf-8 -*-
2
+ # File : replicate.py
3
+ # Author : Jiayuan Mao
4
+ # Email : maojiayuan@gmail.com
5
+ # Date : 27/01/2018
6
+ #
7
+ # This file is part of Synchronized-BatchNorm-PyTorch.
8
+ # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9
+ # Distributed under MIT License.
10
+
11
+ import functools
12
+
13
+ from torch.nn.parallel.data_parallel import DataParallel
14
+
15
+ __all__ = [
16
+ 'CallbackContext',
17
+ 'execute_replication_callbacks',
18
+ 'DataParallelWithCallback',
19
+ 'patch_replication_callback'
20
+ ]
21
+
22
+
23
+ class CallbackContext(object):
24
+ pass
25
+
26
+
27
+ def execute_replication_callbacks(modules):
28
+ """
29
+ Execute an replication callback `__data_parallel_replicate__` on each module created by original replication.
30
+
31
+ The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
32
+
33
+ Note that, as all modules are isomorphism, we assign each sub-module with a context
34
+ (shared among multiple copies of this module on different devices).
35
+ Through this context, different copies can share some information.
36
+
37
+ We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback
38
+ of any slave copies.
39
+ """
40
+ master_copy = modules[0]
41
+ nr_modules = len(list(master_copy.modules()))
42
+ ctxs = [CallbackContext() for _ in range(nr_modules)]
43
+
44
+ for i, module in enumerate(modules):
45
+ for j, m in enumerate(module.modules()):
46
+ if hasattr(m, '__data_parallel_replicate__'):
47
+ m.__data_parallel_replicate__(ctxs[j], i)
48
+
49
+
50
+ class DataParallelWithCallback(DataParallel):
51
+ """
52
+ Data Parallel with a replication callback.
53
+
54
+ A replication callback `__data_parallel_replicate__` of each module will be invoked after being created by the
55
+ original `replicate` function.
56
+ The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
57
+
58
+ Examples:
59
+ > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
60
+ > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
61
+ # sync_bn.__data_parallel_replicate__ will be invoked.
62
+ """
63
+
64
+ def replicate(self, module, device_ids):
65
+ modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
66
+ execute_replication_callbacks(modules)
67
+ return modules
68
+
69
+
70
+ def patch_replication_callback(data_parallel):
71
+ """
72
+ Monkey-patch an existing `DataParallel` object. Add the replication callback.
73
+ Useful when you have a customized `DataParallel` implementation.
74
+
75
+ Examples:
76
+ > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
77
+ > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
78
+ > patch_replication_callback(sync_bn)
79
+ # this is equivalent to
80
+ > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
81
+ > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
82
+ """
83
+
84
+ assert isinstance(data_parallel, DataParallel)
85
+
86
+ old_replicate = data_parallel.replicate
87
+
88
+ @functools.wraps(old_replicate)
89
+ def new_replicate(module, device_ids):
90
+ modules = old_replicate(module, device_ids)
91
+ execute_replication_callbacks(modules)
92
+ return modules
93
+
94
+ data_parallel.replicate = new_replicate
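For completeness, a sketch of the monkey-patch route (same assumptions as the earlier multi-GPU example: two CUDA devices, package importable as sync_batchnorm); it is equivalent to wrapping the model in DataParallelWithCallback directly:

import torch
from torch import nn
from sync_batchnorm import SynchronizedBatchNorm2d, patch_replication_callback

model = nn.Sequential(nn.Conv2d(3, 16, kernel_size=3, padding=1),
                      SynchronizedBatchNorm2d(16)).cuda()
model = nn.DataParallel(model, device_ids=[0, 1])
patch_replication_callback(model)  # replicas now receive __data_parallel_replicate__
out = model(torch.randn(8, 3, 32, 32).cuda())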
sync_batchnorm/unittest.py ADDED
@@ -0,0 +1,29 @@
1
+ # -*- coding: utf-8 -*-
2
+ # File : unittest.py
3
+ # Author : Jiayuan Mao
4
+ # Email : maojiayuan@gmail.com
5
+ # Date : 27/01/2018
6
+ #
7
+ # This file is part of Synchronized-BatchNorm-PyTorch.
8
+ # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9
+ # Distributed under MIT License.
10
+
11
+ import unittest
12
+
13
+ import numpy as np
14
+ from torch.autograd import Variable
15
+
16
+
17
+ def as_numpy(v):
18
+ if isinstance(v, Variable):
19
+ v = v.data
20
+ return v.cpu().numpy()
21
+
22
+
23
+ class TorchTestCase(unittest.TestCase):
24
+ def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3):
25
+ npa, npb = as_numpy(a), as_numpy(b)
26
+ self.assertTrue(
27
+ np.allclose(npa, npb, atol=atol),
28
+ 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max())
29
+ )