import random
import json
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.transforms.functional import crop

from stablevideo.atlas_utils import (
    load_neural_atlases_models,
    get_frames_data,
    get_high_res_atlas,
    get_atlas_crops,
    reconstruct_video_layer,
    create_uv_mask,
    get_masks_boundaries,
    get_random_crop_params,
    get_atlas_bounding_box,
    load_video,
)


class AtlasData:
    """Neural-atlas data for a single video, plus helpers for sampling crops.

    On construction this loads the video frames and the atlas checkpoint from
    ``data/<video_name>/``. It then builds background/foreground layer
    reconstructions, converts the per-frame UVs to grid-atlas pixel units, and
    evaluates both atlases on a ``grid_atlas_resolution`` grid. The crop
    methods sample matching crops of frames, UV maps, alpha and atlases.
    """

    def __init__(self, video_name):
        """Load the video, atlas models and derived UV/atlas data.

        Args:
            video_name: Folder name under ``data/``. The folder must contain
                ``config.json`` (with ``resx``, ``resy`` and either
                ``number_of_frames`` or ``maximum_number_of_frames``),
                ``checkpoint.ckpt``, and a ``<video_name>/`` subfolder that is
                passed to ``load_video``.

        Raises:
            KeyError: if ``config.json`` has neither frame-count key, or lacks
                ``resx``/``resy``.
        """
        with open(f"data/{video_name}/config.json", "r", encoding="utf-8") as f:
            json_dict = json.load(f)
        # The frame count may be stored under either key; "number_of_frames" wins.
        try:
            maximum_number_of_frames = json_dict["number_of_frames"]
        except KeyError:
            maximum_number_of_frames = json_dict["maximum_number_of_frames"]

        config = {
            "device": "cuda",
            "checkpoint_path": f"data/{video_name}/checkpoint.ckpt",
            "resx": json_dict["resx"],
            "resy": json_dict["resy"],
            "maximum_number_of_frames": maximum_number_of_frames,
            "return_atlas_alpha": False,
            "grid_atlas_resolution": 2000,
            "num_scales": 7,
            "masks_border_expansion": 30,
            "mask_alpha_threshold": 0.99,  # 0.95 noted as an alternative value
            "align_corners": False,
        }
        self.config = config
        self.device = config["device"]
        self.min_size = min(self.config["resx"], self.config["resy"])
        self.max_size = max(self.config["resx"], self.config["resy"])

        # Frames loaded with resize=(resy, resx) and num_frames=maximum_number_of_frames.
        data_folder = f"data/{video_name}/{video_name}"
        self.original_video = load_video(
            data_folder,
            resize=(self.config["resy"], self.config["resx"]),
            num_frames=self.config["maximum_number_of_frames"],
        )
        self.original_video = self.original_video.to(self.device)  # tensor

        (
            foreground_mapping,
            background_mapping,
            foreground_atlas_model,
            background_atlas_model,
            alpha_model,
        ) = load_neural_atlases_models(config)
        (
            original_background_all_uvs,
            original_foreground_all_uvs,
            self.all_alpha,
            foreground_atlas_alpha,
        ) = get_frames_data(
            config,
            foreground_mapping,
            background_mapping,
            alpha_model,
        )

        self.background_reconstruction = reconstruct_video_layer(original_background_all_uvs, background_atlas_model)
        # The foreground layer comes from the original video masked by alpha,
        # not from a reconstruction through the foreground atlas.
        self.foreground_reconstruction = self.original_video * self.all_alpha

        # Per-layer UVs in grid-atlas pixel units, their normalised form, and extents.
        (
            self.background_all_uvs,
            self.scaled_background_uvs,
            self.background_min_u,
            self.background_min_v,
            self.background_max_u,
            self.background_max_v,
        ) = self.preprocess_uv_values(
            original_background_all_uvs, config["grid_atlas_resolution"], device=self.device, layer="background"
        )
        (
            self.foreground_all_uvs,
            self.scaled_foreground_uvs,
            self.foreground_min_u,
            self.foreground_min_v,
            self.foreground_max_u,
            self.foreground_max_v,
        ) = self.preprocess_uv_values(
            original_foreground_all_uvs, config["grid_atlas_resolution"], device=self.device, layer="foreground"
        )

        self.background_uv_mask = create_uv_mask(
            config,
            background_mapping,
            self.background_min_u,
            self.background_min_v,
            self.background_max_u,
            self.background_max_v,
            uv_shift=-0.5,
            resolution_shift=1,
        )
        self.foreground_uv_mask = create_uv_mask(
            config,
            foreground_mapping,
            self.foreground_min_u,
            self.foreground_min_v,
            self.foreground_max_u,
            self.foreground_max_v,
            uv_shift=0.5,
            resolution_shift=0,
        )

        # Atlases evaluated at grid_atlas_resolution over each layer's UV extent.
        self.background_grid_atlas = get_high_res_atlas(
            background_atlas_model,
            self.background_min_v,
            self.background_min_u,
            self.background_max_v,
            self.background_max_u,
            config["grid_atlas_resolution"],
            device=config["device"],
            layer="background",
        )
        self.foreground_grid_atlas = get_high_res_atlas(
            foreground_atlas_model,
            self.foreground_min_v,
            self.foreground_min_u,
            self.foreground_max_v,
            self.foreground_max_u,
            config["grid_atlas_resolution"],
            device=config["device"],
            layer="foreground",
        )
        if config["return_atlas_alpha"]:
            self.foreground_atlas_alpha = foreground_atlas_alpha  # used for visualizations

        # Smallest allowed crop side; presumably so a num_scales-level CNN can
        # process it (TODO confirm against the consumer).
        self.cnn_min_crop_size = 2 ** self.config["num_scales"] + 1
        # Per-frame boxes (y, x, h, w) around the foreground, derived from alpha.
        self.mask_boundaries = get_masks_boundaries(
            alpha_video=self.all_alpha.cpu(),
            border=self.config["masks_border_expansion"],
            threshold=self.config["mask_alpha_threshold"],
            min_crop_size=self.cnn_min_crop_size,
        )
        self.cropped_foreground_atlas, self.foreground_atlas_bbox = get_atlas_bounding_box(
            self.mask_boundaries, self.foreground_grid_atlas, self.foreground_all_uvs
        )

        # State for external code; not read by the methods of this class.
        self.step = -1
        self.edited_atlas_dict, self.edit_dict, self.uv_mask = {}, {}, {}

    @staticmethod
    def preprocess_uv_values(layer_uv_values, resolution, device="cuda", layer="background"):
        """Convert atlas UV coordinates to pixel units of the grid atlas.

        The UVs are shifted (+1 for the background layer, 0 otherwise),
        multiplied by ``resolution``, and translated so that their integer
        minimum sits at 0.

        Args:
            layer_uv_values: Tensor whose trailing values form (u, v) pairs;
                presumably shaped (frames, H, W, 2) — TODO confirm.
            resolution: Grid atlas resolution used as the scale factor.
            device: Device for the offset/extent tensors.
            layer: ``"background"`` selects the +1 shift; any other value uses 0.

        Returns:
            ``(uv_values, scaled_uv_values, min_u, min_v, max_u, max_v)``, where:

            - ``uv_values`` is the shifted pixel-unit tensor, same shape as the input.
            - ``scaled_uv_values`` is ``uv_values / (max_u, max_v) * 2 - 1``,
              reshaped to ``(1, N, 1, 2)``.
            - ``min_u``/``min_v`` are the removed offsets, truncated with ``.long()``.
            - ``max_u``/``max_v`` are the ceil'd extents after the offset.
        """
        shift = 1 if layer == "background" else 0
        uv_values = (layer_uv_values + shift) * resolution
        min_u, min_v = uv_values.reshape(-1, 2).min(dim=0).values.long()
        uv_values -= torch.tensor([min_u, min_v], device=device)
        max_u, max_v = uv_values.reshape(-1, 2).max(dim=0).values.ceil().long()

        edge_size = torch.tensor([max_u, max_v], device=device)
        scaled_uv_values = ((uv_values.reshape(-1, 2) / edge_size) * 2 - 1).unsqueeze(1).unsqueeze(0)

        return uv_values, scaled_uv_values, min_u, min_v, max_u, max_v

    def get_random_crop_data(self, crop_size):
        """Sample a random frame index and a random crop window in the full frame.

        Args:
            crop_size: Crop size forwarded to ``get_random_crop_params``; the
                callers in this class pass ``(height, width)``.

        Returns:
            ``(y_start, x_start, h_crop, w_crop, t)``, where ``t`` is uniform in
            ``[0, maximum_number_of_frames - 1]``.
        """
        t = random.randint(0, self.config["maximum_number_of_frames"] - 1)
        y_start, x_start, h_crop, w_crop = get_random_crop_params((self.config["resx"], self.config["resy"]), crop_size)
        return y_start, x_start, h_crop, w_crop, t

    def get_global_crops_multi(self, keyframes, res):
        """Sample near-full crops of both layers for a set of keyframes.

        Side effects: sets ``self.config["crops_min_cover"] = 0.95`` and
        ``self.config["grid_atlas_resolution"] = res``.

        Foreground crops are taken inside each keyframe's mask box. Background
        crops share a single random window (covering 95-100% of the frame
        height and width) applied to all keyframes.

        Args:
            keyframes: Frame indices. They are iterated for the foreground and
                used as an advanced index into the video tensors for the
                background, so a list of ints is expected.
            res: Stored as ``config["grid_atlas_resolution"]``. It sets the
                initial ``[0, res]`` extent of the shared foreground atlas crop.

        Returns:
            dict of lists with keys, in order:

            - ``foreground_alpha``, ``foreground_uvs`` (normalised to the atlas
              crop extent), ``original_foreground_crops``: one entry per keyframe.
            - ``foreground_atlas_crops``: a single atlas crop covering all
              foreground crops.
            - ``background_alpha``, ``background_uvs``,
              ``original_background_crops``, ``background_atlas_crops``: one
              entry per element along dim 0 of the sliced tensors or of the
              ``get_atlas_crops`` result.
        """
        self.config["crops_min_cover"] = 0.95
        self.config["grid_atlas_resolution"] = res

        # Foreground first: the RNG call order must stay fg -> bg for reproducibility.
        (
            foreground_alpha_crops,
            foreground_uvs,
            original_foreground_crops,
            foreground_atlas_crops,
        ) = self._get_foreground_global_crops(keyframes)
        (
            background_alpha_crops,
            background_uvs,
            original_background_crops,
            background_atlas_crops,
        ) = self._get_background_global_crops(keyframes)

        output_dict = {}
        output_dict["foreground_alpha"] = foreground_alpha_crops
        output_dict["foreground_uvs"] = foreground_uvs
        output_dict["original_foreground_crops"] = original_foreground_crops
        output_dict["foreground_atlas_crops"] = foreground_atlas_crops
        output_dict["background_alpha"] = background_alpha_crops
        output_dict["background_uvs"] = background_uvs
        output_dict["original_background_crops"] = original_background_crops
        output_dict["background_atlas_crops"] = background_atlas_crops
        return output_dict

    def _get_foreground_global_crops(self, keyframes):
        """Crop each keyframe inside its mask box and cut one shared foreground atlas crop."""
        min_cover = self.config["crops_min_cover"]
        foreground_alpha_crops = []
        foreground_uvs = []
        original_foreground_crops = []

        for cur_frame in keyframes:
            y_start, x_start, frame_h, frame_w = self.mask_boundaries[cur_frame].tolist()
            crop_size = (
                max(random.randint(round(min_cover * frame_h), frame_h), self.cnn_min_crop_size),
                max(random.randint(round(min_cover * frame_w), frame_w), self.cnn_min_crop_size),
            )
            y_crop, x_crop, h_crop, w_crop = get_random_crop_params((frame_w, frame_h), crop_size)
            rows = slice(y_start + y_crop, y_start + y_crop + h_crop)
            cols = slice(x_start + x_crop, x_start + x_crop + w_crop)

            foreground_uvs.append(self.foreground_all_uvs[cur_frame, rows, cols])  # not scaled
            foreground_alpha_crops.append(self.all_alpha[[cur_frame], :, rows, cols])
            original_foreground_crops.append(self.foreground_reconstruction[[cur_frame], :, rows, cols])

        # Union of all crops' UV extents, starting from [0, grid_atlas_resolution] per axis.
        foreground_max_vals = torch.tensor(
            [self.config["grid_atlas_resolution"]] * 2, device=self.device, dtype=torch.long
        )
        foreground_min_vals = torch.tensor([0] * 2, device=self.device, dtype=torch.long)
        for uv_values in foreground_uvs:
            min_uv = uv_values.amin(dim=[0, 1]).long()
            max_uv = uv_values.amax(dim=[0, 1]).ceil().long()
            foreground_min_vals = torch.minimum(foreground_min_vals, min_uv)
            foreground_max_vals = torch.maximum(foreground_max_vals, max_uv)

        # Index 1 is v (rows), index 0 is u (columns) in the grid atlas.
        h_v = foreground_max_vals[1] - foreground_min_vals[1]
        w_u = foreground_max_vals[0] - foreground_min_vals[0]
        foreground_atlas_crops = [
            crop(
                self.foreground_grid_atlas,
                foreground_min_vals[1],
                foreground_min_vals[0],
                h_v,
                w_u,
            )
        ]

        # Map pixel-unit UVs to [-1, 1] relative to the shared atlas crop.
        extent = foreground_max_vals - foreground_min_vals
        foreground_uvs = [
            (2 * (uv_values - foreground_min_vals) / extent - 1).unsqueeze(0)
            for uv_values in foreground_uvs
        ]
        return foreground_alpha_crops, foreground_uvs, original_foreground_crops, foreground_atlas_crops

    def _get_background_global_crops(self, keyframes):
        """Cut one random near-full window from every keyframe plus the matching background atlas crop."""
        min_cover = self.config["crops_min_cover"]
        resy, resx = self.config["resy"], self.config["resx"]
        # (height, width): bounded by the actual frame height/width, so portrait
        # videos no longer request a crop wider than the frame.
        crop_size = (
            random.randint(round(min_cover * resy), resy),
            random.randint(round(min_cover * resx), resx),
        )
        y, x, h, w, _ = self.get_random_crop_data(crop_size)

        background_uv = self.background_all_uvs[keyframes, y : y + h, x : x + w]
        original_background_crop = self.background_reconstruction[keyframes, :, y : y + h, x : x + w]
        alpha = self.all_alpha[keyframes, :, y : y + h, x : x + w]

        background_atlas_crop, background_min_vals, background_max_vals = get_atlas_crops(
            background_uv,
            self.background_grid_atlas,
        )
        background_uv = 2 * (background_uv - background_min_vals) / (background_max_vals - background_min_vals) - 1

        return (
            [el.unsqueeze(0) for el in alpha],
            [el.unsqueeze(0) for el in background_uv],
            [el.unsqueeze(0) for el in original_background_crop],
            [el.unsqueeze(0) for el in background_atlas_crop],
        )