SupermanxKiaski committed on
Commit 8e729e5
1 Parent(s): 4919318

Upload video_dataset.py

Files changed (1)
  1. video_dataset.py +360 -0
video_dataset.py ADDED
@@ -0,0 +1,360 @@
import random

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.transforms.functional import crop

from models.video_model import VideoModel
from util.atlas_utils import (
    load_neural_atlases_models,
    get_frames_data,
    get_high_res_atlas,
    get_atlas_crops,
    reconstruct_video_layer,
    create_uv_mask,
    get_masks_boundaries,
    get_random_crop_params,
    get_atlas_bounding_box,
)
from util.util import load_video

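# AtlasDataset prepares training samples for atlas-based video editing: it loads
# the pretrained neural-atlas networks, reconstructs the foreground and background
# layers together with their per-pixel UV coordinates, and on each step returns
# random atlas/frame crops (with alphas and UVs) for whichever layer is being
# fine-tuned.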
class AtlasDataset(Dataset):
    def __init__(self, config):
        self.config = config
        self.device = config["device"]

        self.min_size = min(self.config["resx"], self.config["resy"])
        self.max_size = max(self.config["resx"], self.config["resy"])
        data_folder = f"data/videos/{self.config['checkpoint_path'].split('/')[2]}"
        self.original_video = load_video(
            data_folder,
            resize=(self.config["resy"], self.config["resx"]),
            num_frames=self.config["maximum_number_of_frames"],
        ).to(self.device)

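        # Load the pretrained UV mapping networks, atlas networks, and alpha network,
        # then evaluate them to obtain per-frame UV coordinates and alpha mattes.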
        (
            foreground_mapping,
            background_mapping,
            foreground_atlas_model,
            background_atlas_model,
            alpha_model,
        ) = load_neural_atlases_models(config)
        (
            original_background_all_uvs,
            original_foreground_all_uvs,
            self.all_alpha,
            foreground_atlas_alpha,
        ) = get_frames_data(
            config,
            foreground_mapping,
            background_mapping,
            alpha_model,
        )

        self.background_reconstruction = reconstruct_video_layer(original_background_all_uvs, background_atlas_model)
        # using original video for the foreground layer
        self.foreground_reconstruction = self.original_video * self.all_alpha

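        # Convert both layers' UV coordinates to pixel coordinates of the high-res
        # grid atlas, keeping a [-1, 1]-normalized copy for grid_sample.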
        (
            self.background_all_uvs,
            self.scaled_background_uvs,
            self.background_min_u,
            self.background_min_v,
            self.background_max_u,
            self.background_max_v,
        ) = self.preprocess_uv_values(
            original_background_all_uvs, config["grid_atlas_resolution"], device=self.device, layer="background"
        )
        (
            self.foreground_all_uvs,
            self.scaled_foreground_uvs,
            self.foreground_min_u,
            self.foreground_min_v,
            self.foreground_max_u,
            self.foreground_max_v,
        ) = self.preprocess_uv_values(
            original_foreground_all_uvs, config["grid_atlas_resolution"], device=self.device, layer="foreground"
        )

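        # Build validity masks over the atlas UV space and rasterize both atlases
        # onto a grid of size grid_atlas_resolution.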
        self.background_uv_mask = create_uv_mask(
            config,
            background_mapping,
            self.background_min_u,
            self.background_min_v,
            self.background_max_u,
            self.background_max_v,
            uv_shift=-0.5,
            resolution_shift=1,
        )
        self.foreground_uv_mask = create_uv_mask(
            config,
            foreground_mapping,
            self.foreground_min_u,
            self.foreground_min_v,
            self.foreground_max_u,
            self.foreground_max_v,
            uv_shift=0.5,
            resolution_shift=0,
        )
        self.background_grid_atlas = get_high_res_atlas(
            background_atlas_model,
            self.background_min_v,
            self.background_min_u,
            self.background_max_v,
            self.background_max_u,
            config["grid_atlas_resolution"],
            device=config["device"],
            layer="background",
        )
        self.foreground_grid_atlas = get_high_res_atlas(
            foreground_atlas_model,
            self.foreground_min_v,
            self.foreground_min_u,
            self.foreground_max_v,
            self.foreground_max_u,
            config["grid_atlas_resolution"],
            device=config["device"],
            layer="foreground",
        )
        if config["return_atlas_alpha"]:
            self.foreground_atlas_alpha = foreground_atlas_alpha  # used for visualizations
        self.cnn_min_crop_size = 2 ** self.config["num_scales"] + 1
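        # For foreground fine-tuning, precompute per-frame mask bounding boxes and
        # crop the foreground atlas to the region its UVs actually cover.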
        if self.config["finetune_foreground"]:
            self.mask_boundaries = get_masks_boundaries(
                alpha_video=self.all_alpha.cpu(),
                border=self.config["masks_border_expansion"],
                threshold=self.config["mask_alpha_threshold"],
                min_crop_size=self.cnn_min_crop_size,
            )
            self.cropped_foreground_atlas, self.foreground_atlas_bbox = get_atlas_bounding_box(
                self.mask_boundaries, self.foreground_grid_atlas, self.foreground_all_uvs
            )

        self.step = -1

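        # Mild color-jitter augmentation, applied to each crop with probability 0.1.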
        crop_transforms = transforms.Compose(
            [
                transforms.RandomApply(
                    [transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1)],
                    p=0.1,
                ),
            ]
        )
        self.crop_aug = crop_transforms
        self.dist = self.config["center_frame_distance"]

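    # Shift UVs into pixel coordinates of the grid atlas (background UVs are offset
    # by 1, foreground by 0), re-origin them at their minimum, and also return a
    # [-1, 1]-scaled copy suitable for F.grid_sample.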
    @staticmethod
    def preprocess_uv_values(layer_uv_values, resolution, device="cuda", layer="background"):
        if layer == "background":
            shift = 1
        else:
            shift = 0
        uv_values = (layer_uv_values + shift) * resolution
        min_u, min_v = uv_values.reshape(-1, 2).min(dim=0).values.long()
        uv_values -= torch.tensor([min_u, min_v], device=device)
        max_u, max_v = uv_values.reshape(-1, 2).max(dim=0).values.ceil().long()

        edge_size = torch.tensor([max_u, max_v], device=device)
        scaled_uv_values = ((uv_values.reshape(-1, 2) / edge_size) * 2 - 1).unsqueeze(1).unsqueeze(0)

        return uv_values, scaled_uv_values, min_u, min_v, max_u, max_v

    def get_random_crop_data(self, crop_size):
        t = random.randint(0, self.config["maximum_number_of_frames"] - 1)
        y_start, x_start, h_crop, w_crop = get_random_crop_params((self.config["resx"], self.config["resy"]), crop_size)
        return y_start, x_start, h_crop, w_crop, t

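    # Sample one training example: a center frame t and its neighbors t - dist and
    # t + dist, with matching atlas crops, normalized UV grids, alpha mattes, and
    # (optionally horizontally flipped) crops of the reconstructed layers.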
    def get_global_crops_multi(self):
        foreground_atlas_crops = []
        background_atlas_crops = []
        foreground_uvs = []
        background_uvs = []
        background_alpha_crops = []
        foreground_alpha_crops = []
        original_background_crops = []
        original_foreground_crops = []
        output_dict = {}

        t = random.randint(self.dist, self.config["maximum_number_of_frames"] - 1 - self.dist)
        flip = torch.rand(1) < self.config["flip_p"]
        if self.config["finetune_foreground"]:
            for cur_frame in [t - self.dist, t, t + self.dist]:
                y_start, x_start, frame_h, frame_w = self.mask_boundaries[cur_frame].tolist()
                crop_size = (
                    max(
                        random.randint(round(self.config["crops_min_cover"] * frame_h), frame_h),
                        self.cnn_min_crop_size,
                    ),
                    max(
                        random.randint(round(self.config["crops_min_cover"] * frame_w), frame_w),
                        self.cnn_min_crop_size,
                    ),
                )
                y_crop, x_crop, h_crop, w_crop = get_random_crop_params((frame_w, frame_h), crop_size)

                foreground_uv = self.foreground_all_uvs[
                    cur_frame,
                    y_start + y_crop : y_start + y_crop + h_crop,
                    x_start + x_crop : x_start + x_crop + w_crop,
                ]
                alpha = self.all_alpha[
                    [cur_frame],
                    :,
                    y_start + y_crop : y_start + y_crop + h_crop,
                    x_start + x_crop : x_start + x_crop + w_crop,
                ]

                original_foreground_crop = self.foreground_reconstruction[
                    [cur_frame],
                    :,
                    y_start + y_crop : y_start + y_crop + h_crop,
                    x_start + x_crop : x_start + x_crop + w_crop,
                ]

                original_foreground_crop = self.crop_aug(original_foreground_crop)
                foreground_alpha_crops.append(alpha.flip(-1) if flip else alpha)
                foreground_uvs.append(foreground_uv)  # not scaled
                original_foreground_crops.append(
                    original_foreground_crop.flip(-1) if flip else original_foreground_crop
                )

            foreground_min_vals = torch.tensor(
                [self.config["grid_atlas_resolution"]] * 2, device=self.device, dtype=torch.long
            )
            foreground_max_vals = torch.tensor([0] * 2, device=self.device, dtype=torch.long)
            for uv_values in foreground_uvs:
                min_uv = uv_values.amin(dim=[0, 1]).long()
                max_uv = uv_values.amax(dim=[0, 1]).ceil().long()
                foreground_min_vals = torch.minimum(foreground_min_vals, min_uv)
                foreground_max_vals = torch.maximum(foreground_max_vals, max_uv)

            h_v = foreground_max_vals[1] - foreground_min_vals[1]
            w_u = foreground_max_vals[0] - foreground_min_vals[0]
            foreground_atlas_crop = crop(
                self.foreground_grid_atlas,
                foreground_min_vals[1],
                foreground_min_vals[0],
                h_v,
                w_u,
            )
            foreground_atlas_crop = self.crop_aug(foreground_atlas_crop)

            for i, uv_values in enumerate(foreground_uvs):
                foreground_uvs[i] = (
                    2 * (uv_values - foreground_min_vals) / (foreground_max_vals - foreground_min_vals) - 1
                ).unsqueeze(0)
                if flip:
                    foreground_uvs[i][:, :, :, 0] = -foreground_uvs[i][:, :, :, 0]
                    foreground_uvs[i] = foreground_uvs[i].flip(-2)
            foreground_atlas_crops.append(foreground_atlas_crop.flip(-1) if flip else foreground_atlas_crop)

        elif self.config["finetune_background"]:
            crop_size = (
                random.randint(round(self.config["crops_min_cover"] * self.min_size), self.min_size),
                random.randint(round(self.config["crops_min_cover"] * self.max_size), self.max_size),
            )
            crop_data = self.get_random_crop_data(crop_size)
            y, x, h, w, _ = crop_data
            background_uv = self.background_all_uvs[[t - self.dist, t, t + self.dist], y : y + h, x : x + w]
            original_background_crop = self.background_reconstruction[
                [t - self.dist, t, t + self.dist], :, y : y + h, x : x + w
            ]
            alpha = self.all_alpha[[t - self.dist, t, t + self.dist], :, y : y + h, x : x + w]

            original_background_crop = self.crop_aug(original_background_crop)

            original_background_crops = [
                el.unsqueeze(0).flip(-1) if flip else el.unsqueeze(0) for el in original_background_crop
            ]
            background_alpha_crops = [el.unsqueeze(0).flip(-1) if flip else el.unsqueeze(0) for el in alpha]

            background_atlas_crop, background_min_vals, background_max_vals = get_atlas_crops(
                background_uv,
                self.background_grid_atlas,
                self.crop_aug,
            )
            background_uv = 2 * (background_uv - background_min_vals) / (background_max_vals - background_min_vals) - 1
            if flip:
                background_uv[:, :, :, 0] = -background_uv[:, :, :, 0]
                background_uv = background_uv.flip(-2)
            background_atlas_crops = [
                el.unsqueeze(0).flip(-1) if flip else el.unsqueeze(0) for el in background_atlas_crop
            ]
            background_uvs = [el.unsqueeze(0) for el in background_uv]

        if self.config["finetune_foreground"]:
            output_dict["foreground_alpha"] = foreground_alpha_crops
            output_dict["foreground_uvs"] = foreground_uvs
            output_dict["original_foreground_crops"] = original_foreground_crops
            output_dict["foreground_atlas_crops"] = foreground_atlas_crops
        elif self.config["finetune_background"]:
            output_dict["background_alpha"] = background_alpha_crops
            output_dict["background_uvs"] = background_uvs
            output_dict["original_background_crops"] = original_background_crops
            output_dict["background_atlas_crops"] = background_atlas_crops

        return output_dict

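    # Apply the generator to the whole (cropped) atlas, then warp the edited atlas
    # back onto every frame with grid_sample. Returns the edited-atlas render, the
    # per-frame render over the original video, and the atlas UV validity mask.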
    @torch.no_grad()
    def render_video_from_atlas(self, model, layer="background", foreground_padding_mode="replicate"):
        if layer == "background":
            grid_atlas = self.background_grid_atlas
            all_uvs = self.scaled_background_uvs
            uv_mask = self.background_uv_mask
        else:
            grid_atlas = self.cropped_foreground_atlas
            full_grid_atlas = self.foreground_grid_atlas
            all_uvs = self.scaled_foreground_uvs
            uv_mask = crop(self.foreground_uv_mask, *self.foreground_atlas_bbox)
        atlas_edit_only = model.netG(grid_atlas)
        edited_atlas_dict = model.render(atlas_edit_only, bg_image=grid_atlas)

        if layer == "foreground":
            atlas_edit_only = torch.nn.functional.pad(
                atlas_edit_only,
                pad=(
                    self.foreground_atlas_bbox[1],
                    full_grid_atlas.shape[-1] - (self.foreground_atlas_bbox[1] + self.foreground_atlas_bbox[3]),
                    self.foreground_atlas_bbox[0],
                    full_grid_atlas.shape[-2] - (self.foreground_atlas_bbox[0] + self.foreground_atlas_bbox[2]),
                ),
                mode=foreground_padding_mode,
            )

        edit = F.grid_sample(
            atlas_edit_only, all_uvs, mode="bilinear", align_corners=self.config["align_corners"]
        ).clamp(min=0.0, max=1.0)
        edit = edit.squeeze().t()  # shape (num_pixels, 4): RGB + alpha of the edit layer
        edit = (
            edit.reshape(self.config["maximum_number_of_frames"], self.config["resy"], self.config["resx"], 4)
            .permute(0, 3, 1, 2)
            .clamp(min=0.0, max=1.0)
        )
        edit_dict = model.render(edit, bg_image=self.original_video)

        return edited_atlas_dict, edit_dict, uv_mask

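    # Return the full atlas image (resized via VideoModel.resize_crops) for steps
    # that use the entire atlas as the model input.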
    def get_whole_atlas(self):
        if self.config["finetune_foreground"]:
            atlas = self.cropped_foreground_atlas
        else:
            atlas = self.background_grid_atlas
            atlas = VideoModel.resize_crops(atlas, 3)

        return atlas

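    # Each access draws a fresh set of random crops; every `entire_atlas_every`
    # steps (when enabled) the whole atlas is also attached as "input_image".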
    def __getitem__(self, index):
        self.step += 1
        sample = {"step": self.step}
        sample["global_crops"] = self.get_global_crops_multi()

        if self.config["input_entire_atlas"] and ((self.step + 1) % self.config["entire_atlas_every"] == 0):
            sample["input_image"] = self.get_whole_atlas()

        return sample

    def __len__(self):
        return 1
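
A minimal usage sketch (not part of this commit): the config path, key names, and training-loop shape below are assumptions inferred from the dataset code above, not from files included here.

# Hypothetical usage sketch -- file paths and config keys are assumptions
# inferred from video_dataset.py, not files included in this commit.
import torch
import yaml

from video_dataset import AtlasDataset

with open("configs/video_config.yaml") as f:  # assumed config location
    config = yaml.safe_load(f)
config["device"] = "cuda" if torch.cuda.is_available() else "cpu"

dataset = AtlasDataset(config)

# __len__ is fixed at 1 and fresh random crops are drawn on every access,
# so a training loop can simply re-index the dataset each step.
for step in range(3):
    sample = dataset[0]
    crops = sample["global_crops"]
    print(step, sample["step"], sorted(crops.keys()))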