import scipy.interpolate
import torch
from torchvision.transforms.functional import crop
from tqdm import tqdm

from models.implicit_neural_networks import IMLP


def load_neural_atlases_models(config):
    """Build the five neural-atlas networks and load their weights from a checkpoint.

    Args:
        config: dict providing at least ``"device"`` (where the models are placed) and
            ``"checkpoint_path"`` (the file passed to ``torch.load``).

    Returns:
        Tuple ``(foreground_mapping, background_mapping, foreground_atlas_model,
        background_atlas_model, alpha_model)``. Every model is in eval mode with
        ``requires_grad`` disabled on all parameters.
    """
    device = config["device"]
    foreground_mapping = IMLP(
        input_dim=3,
        output_dim=2,
        hidden_dim=256,
        use_positional=False,
        num_layers=6,
        skip_layers=[],
    ).to(device)
    background_mapping = IMLP(
        input_dim=3,
        output_dim=2,
        hidden_dim=256,
        use_positional=False,
        num_layers=4,
        skip_layers=[],
    ).to(device)
    foreground_atlas_model = IMLP(
        input_dim=2,
        output_dim=3,
        hidden_dim=256,
        use_positional=True,
        positional_dim=10,
        num_layers=8,
        skip_layers=[4, 7],
    ).to(device)
    background_atlas_model = IMLP(
        input_dim=2,
        output_dim=3,
        hidden_dim=256,
        use_positional=True,
        positional_dim=10,
        num_layers=8,
        skip_layers=[4, 7],
    ).to(device)
    alpha_model = IMLP(
        input_dim=3,
        output_dim=1,
        hidden_dim=256,
        use_positional=True,
        positional_dim=5,
        num_layers=8,
        skip_layers=[],
    ).to(device)

    # NOTE: torch.load unpickles the file -- only load checkpoints from trusted sources.
    # map_location puts the stored tensors straight onto the target device, so a checkpoint
    # saved on GPU can also be loaded when config["device"] is "cpu".
    checkpoint = torch.load(config["checkpoint_path"], map_location=device)
    foreground_mapping.load_state_dict(checkpoint["model_F_mapping1_state_dict"])
    background_mapping.load_state_dict(checkpoint["model_F_mapping2_state_dict"])
    # Both atlas models are initialised from the same "F_atlas_state_dict" entry. Elsewhere in
    # this file the foreground and background UVs are offset differently (see get_frames_data
    # and get_high_res_atlas), which suggests one shared atlas network split by UV range.
    # NOTE(review): confirm the shared entry is intentional.
    foreground_atlas_model.load_state_dict(checkpoint["F_atlas_state_dict"])
    background_atlas_model.load_state_dict(checkpoint["F_atlas_state_dict"])
    alpha_model.load_state_dict(checkpoint["model_F_alpha_state_dict"])

    models = (foreground_mapping, background_mapping, foreground_atlas_model, background_atlas_model, alpha_model)
    return tuple(model.eval().requires_grad_(False) for model in models)


@torch.no_grad()
def get_frames_data(config, foreground_mapping, background_mapping, alpha_model):
    """Evaluate the mapping and alpha networks on every pixel of every frame.

    Pixel coordinates ``(x, y, t)`` are normalised by dividing by
    ``(max_size / 2, max_size / 2, num_frames / 2)`` and subtracting 1, where
    ``max_size = max(resx, resy)``.

    Args:
        config: dict with ``"device"``, ``"resx"``, ``"resy"``, ``"maximum_number_of_frames"``
            and ``"return_atlas_alpha"``; ``"grid_atlas_resolution"`` is also read when
            ``"return_atlas_alpha"`` is truthy.
        foreground_mapping, background_mapping, alpha_model: networks taking ``(N, 3)`` inputs.

    Returns:
        ``(background_uv_values, foreground_uv_values, alpha, foreground_atlas_alpha)``:
        the two UV tensors have shape ``(T, resy, resx, 2)``; ``alpha`` has shape
        ``(T, 1, resy, resx)``; ``foreground_atlas_alpha`` is ``(1, 1, R, R)`` with
        ``R = config["grid_atlas_resolution"]``, or ``None`` when ``"return_atlas_alpha"`` is falsy.
    """
    device = config["device"]
    num_frames = config["maximum_number_of_frames"]
    resx, resy = config["resx"], config["resy"]
    max_size = max(resx, resy)
    normalizing_factor = torch.tensor([max_size / 2, max_size / 2, num_frames / 2])

    background_uv_values = torch.zeros(size=(num_frames, resy, resx, 2), device=device)
    foreground_uv_values = torch.zeros(size=(num_frames, resy, resx, 2), device=device)
    alpha = torch.zeros(size=(num_frames, resy, resx, 1), device=device)

    for frame in tqdm(range(num_frames), leave=False):
        indices = get_grid_indices(0, 0, resy, resx, t=torch.tensor(frame))
        normalized_chunk = (indices / normalizing_factor - 1).to(device)

        # Atlas UV coordinates from the two mapping networks, plus the per-pixel alpha.
        current_background_uv_values = background_mapping(normalized_chunk)
        current_foreground_uv_values = foreground_mapping(normalized_chunk)
        current_alpha = alpha_model(normalized_chunk)

        # indices columns are (x, y); write into [frame, y, x].
        rows, cols = indices[:, 1], indices[:, 0]
        # Background UVs are shifted by -0.5 and foreground UVs by +0.5.
        background_uv_values[frame, rows, cols] = current_background_uv_values * 0.5 - 0.5
        foreground_uv_values[frame, rows, cols] = current_foreground_uv_values * 0.5 + 0.5
        # Maps [-1, 1] to [0.001, 0.991], assuming the alpha network outputs values in [-1, 1].
        current_alpha = 0.5 * (current_alpha + 1.0)
        current_alpha = 0.99 * current_alpha + 0.001
        alpha[frame, rows, cols] = current_alpha

    if config["return_atlas_alpha"]:  # this should take a few minutes
        foreground_atlas_alpha = _foreground_atlas_alpha(config, foreground_uv_values, alpha)
    else:
        foreground_atlas_alpha = None

    return background_uv_values, foreground_uv_values, alpha.permute(0, 3, 1, 2), foreground_atlas_alpha


def _foreground_atlas_alpha(config, foreground_uv_values, alpha):
    """Resample per-frame alpha onto the foreground atlas grid; return the per-texel median over frames.

    Returns a ``(1, 1, R, R)`` tensor on ``config["device"]`` with ``R = config["grid_atlas_resolution"]``.
    """
    num_frames = config["maximum_number_of_frames"]
    resolution = config["grid_atlas_resolution"]
    atlas_alpha = torch.zeros(size=(num_frames, resolution, resolution, 1))
    foreground_uv_values_grid = foreground_uv_values * resolution
    # The query points are identical for every frame, so convert them to numpy once.
    query_points = get_grid_indices(0, 0, resolution, resolution).reshape(-1, 2).cpu().numpy()

    for frame in tqdm(range(num_frames), leave=False):
        interpolated = scipy.interpolate.griddata(
            foreground_uv_values_grid[frame].reshape(-1, 2).cpu().numpy(),
            alpha[frame].reshape(-1).cpu().numpy(),
            query_points,
            method="linear",
        ).reshape(resolution, resolution, 1)
        atlas_alpha[frame] = torch.from_numpy(interpolated)

    # griddata fills query points outside the convex hull of the samples with NaN.
    atlas_alpha[atlas_alpha.isnan()] = 0.0
    # get_grid_indices enumerates (x, y) x-major, so the reshaped grid is indexed [x, y];
    # permute(0, 3, 2, 1) turns (1, x, y, 1) into (1, 1, y, x).
    return torch.median(atlas_alpha, dim=0, keepdim=True).values.to(config["device"]).permute(0, 3, 2, 1)


@torch.no_grad()
def reconstruct_video_layer(uv_values, atlas_model):
    """Render a video layer by sampling ``atlas_model`` at per-pixel UV coordinates.

    Args:
        uv_values: tensor of shape ``(T, H, W, 2)``.
        atlas_model: network mapping ``(N, 2)`` UVs to ``(N, 3)`` outputs.

    Returns:
        ``(T, 3, H, W)`` tensor holding ``(atlas_model(uv) + 1) * 0.5`` for every pixel.
    """
    t, h, w, _ = uv_values.shape
    reconstruction = torch.zeros(size=(t, h, w, 3), device=uv_values.device)
    for frame in range(t):
        rgb = (atlas_model(uv_values[frame].reshape(-1, 2)) + 1) * 0.5
        reconstruction[frame] = rgb.reshape(h, w, 3)
    return reconstruction.permute(0, 3, 1, 2)


@torch.no_grad()
def create_uv_mask(config, mapping_model, min_u, min_v, max_u, max_v, uv_shift=-0.5, resolution_shift=1):
    """Mark the atlas texels that ``mapping_model`` hits on any pixel of any frame, then crop.

    Each raw mapping output is transformed as
    ``((raw * 0.5 + uv_shift) + resolution_shift) * resolution`` and clipped to
    ``[0, resolution - 1]``. The four texels surrounding that point (floor/ceil in each axis)
    are set to 1.

    The mask is cropped with torchvision ``crop(img, top, left, height, width)`` as
    ``crop(mask, min_v, min_u, max_v, max_u)``: ``max_v`` / ``max_u`` are therefore used as the
    crop height / width, not as end coordinates.
    NOTE(review): confirm callers pass sizes here.

    Returns:
        CPU tensor of shape ``(1, 1, max_v, max_u)``.
    """
    max_size = max(config["resx"], config["resy"])
    normalizing_factor = torch.tensor([max_size / 2, max_size / 2, config["maximum_number_of_frames"] / 2])
    resolution = config["grid_atlas_resolution"]
    uv_mask = torch.zeros(size=(resolution, resolution), device=config["device"])

    for frame in tqdm(range(config["maximum_number_of_frames"]), leave=False):
        indices = get_grid_indices(0, 0, config["resy"], config["resx"], t=torch.tensor(frame))
        for chunk in indices.split(50000, dim=0):
            normalized_chunk = (chunk / normalizing_factor - 1).to(config["device"])
            # Atlas UV coordinates from the mapping network.
            uv_values = mapping_model(normalized_chunk)
            uv_values = uv_values * 0.5 + uv_shift
            uv_values = ((uv_values + resolution_shift) * resolution).clip(0, resolution - 1)

            rows_lo, rows_hi = uv_values[:, 1].floor().long(), uv_values[:, 1].ceil().long()
            cols_lo, cols_hi = uv_values[:, 0].floor().long(), uv_values[:, 0].ceil().long()
            uv_mask[rows_lo, cols_lo] = 1
            uv_mask[rows_lo, cols_hi] = 1
            uv_mask[rows_hi, cols_lo] = 1
            uv_mask[rows_hi, cols_hi] = 1

    uv_mask = crop(uv_mask.unsqueeze(0).unsqueeze(0), min_v, min_u, max_v, max_u)
    return uv_mask.detach().cpu()  # shape [1, 1, max_v, max_u]


@torch.no_grad()
def get_high_res_atlas(atlas_model, min_v, min_u, max_v, max_u, resolution, device="cuda", layer="background"):
    """Render ``atlas_model`` on a ``resolution x resolution`` grid and crop it.

    Grid point ``(x, y)`` is queried at ``(x / resolution + shift, y / resolution + shift)``, where
    ``shift`` is -1 for ``layer == "background"`` and 0 otherwise. Outputs are mapped with
    ``0.5 * (out + 1)``.

    The result is cropped with ``crop(atlas, min_v, min_u, max_v, max_u)``, i.e. ``max_v`` /
    ``max_u`` act as crop height / width.

    Returns:
        Tensor of shape ``(1, 3, max_v, max_u)`` on ``device``.
    """
    inds_grid = get_grid_indices(0, 0, resolution, resolution)
    shift = -1 if layer == "background" else 0

    rendered_atlas = torch.zeros((resolution, resolution, 3), device=device)  # resy, resx, 3
    # Evaluate the atlas in chunks of 50000 grid points.
    for chunk in inds_grid.split(50000, dim=0):
        normalized_chunk = torch.stack(
            [
                (chunk[:, 0] / resolution) + shift,
                (chunk[:, 1] / resolution) + shift,
            ],
            dim=-1,
        ).to(device)
        rendered_atlas[chunk[:, 1], chunk[:, 0], :] = atlas_model(normalized_chunk)

    # move colors to RGB color domain (0,1)
    rendered_atlas = 0.5 * (rendered_atlas + 1)
    rendered_atlas = rendered_atlas.permute(2, 0, 1).unsqueeze(0)  # shape (1, 3, resy, resx)
    return crop(rendered_atlas, min_v, min_u, max_v, max_u)


def get_grid_indices(x_start, y_start, h_crop, w_crop, t=None):
    """Return integer ``(x, y)`` coordinates of an ``h_crop x w_crop`` grid, one row per point.

    Points are enumerated x-major (all ``y`` for the first ``x``, then the next ``x``), which
    follows from ``torch.meshgrid``'s default ``"ij"`` indexing with the x range first.

    Args:
        x_start, y_start: offsets added to the x and y ranges.
        h_crop, w_crop: grid height and width.
        t: optional tensor; when given, it is repeated as a third column.

    Returns:
        Tensor of shape ``(h_crop * w_crop, 2)``, or ``(h_crop * w_crop, 3)`` when ``t`` is given.
    """
    crop_indices = torch.meshgrid(torch.arange(w_crop) + x_start, torch.arange(h_crop) + y_start)
    crop_indices = torch.stack(crop_indices, dim=-1)
    crop_indices = crop_indices.reshape(h_crop * w_crop, crop_indices.shape[-1])
    if t is not None:
        crop_indices = torch.cat([crop_indices, t.repeat(h_crop * w_crop, 1)], dim=1)
    return crop_indices


def get_atlas_crops(uv_values, grid_atlas, augmentation=None):
    """Crop ``grid_atlas`` to the bounding box of ``uv_values``.

    The last dimension of ``uv_values`` holds ``(u, v)``; ``u`` indexes atlas columns and ``v``
    indexes atlas rows. The lower bound uses ``.long()``, which truncates toward zero (equal to
    floor for non-negative coordinates); the upper bound is rounded up.

    Args:
        uv_values: tensor of rank 3 or 4 whose last dimension has size 2.
        grid_atlas: image tensor passed to torchvision ``crop``.
        augmentation: optional callable applied to the crop.

    Returns:
        ``(atlas_crop, min_uv, max_uv)`` where ``min_uv`` / ``max_uv`` are ``[u, v]`` tensors.

    Raises:
        ValueError: if ``uv_values`` is not of rank 3 or 4.
    """
    if len(uv_values.shape) == 3:
        dims = [0, 1]
    elif len(uv_values.shape) == 4:
        dims = [0, 1, 2]
    else:
        raise ValueError("uv_values should be of shape of len 3 or 4")

    min_u, min_v = uv_values.amin(dim=dims).long()
    max_u, max_v = uv_values.amax(dim=dims).ceil().long()
    h_v = max_v - min_v
    w_u = max_u - min_u
    atlas_crop = crop(grid_atlas, min_v, min_u, h_v, w_u)
    if augmentation is not None:
        atlas_crop = augmentation(atlas_crop)
    return atlas_crop, torch.stack([min_u, min_v]), torch.stack([max_u, max_v])


def get_random_crop_params(input_size, output_size):
    """Pick a uniformly random crop window.

    Args:
        input_size: ``(w, h)`` of the source image.
        output_size: ``(th, tw)`` of the crop.

    Returns:
        ``(i, j, th, tw)``: top row, left column, height and width of the crop.

    Raises:
        ValueError: if the crop is larger than the image in either dimension.
    """
    w, h = input_size
    th, tw = output_size

    # The crop must fit inside the image; a one-pixel-larger crop would otherwise reach
    # torch.randint(0, 0) below and fail with an opaque RuntimeError.
    if h < th or w < tw:
        raise ValueError(f"Required crop size {(th, tw)} is larger then input image size {(h, w)}")

    if w == tw and h == th:
        return 0, 0, h, w

    i = torch.randint(0, h - th + 1, size=(1,)).item()
    j = torch.randint(0, w - tw + 1, size=(1,)).item()
    return i, j, th, tw


def _expand_span(low, high, limit, min_size):
    """Grow the span ``[low, high)`` to at least ``min_size`` and return ``(low, size)``.

    The span is extended past ``high`` unless that would exceed ``limit``, in which case ``low``
    is moved down instead (``low`` can become negative when ``limit < min_size``).
    """
    size = high - low
    if size < min_size:
        pad = min_size - size
        if high + pad > limit:
            low = low - pad
        else:
            high = high + pad
        size = high - low
    return low, size


def get_masks_boundaries(alpha_video, border=20, threshold=0.95, min_crop_size=2 ** 7 + 1):
    """Compute a per-frame bounding box around the pixels whose alpha is ``>= threshold``.

    Each box is expanded by ``border`` pixels and clamped to the frame. If it is smaller than
    ``min_crop_size`` in a dimension, it is grown to that size (see ``_expand_span``).
    ``alpha_video`` is not modified.

    Args:
        alpha_video: tensor whose first dimension is frames and last two are ``(resy, resx)``;
            each frame must hold exactly ``resy * resx`` values (e.g. ``(T, 1, resy, resx)``).
        border: pixels added around the thresholded region.
        threshold: alpha threshold.
        min_crop_size: minimum box height/width.

    Returns:
        ``(T, 4)`` int64 tensor of ``[min_y, min_x, h, w]`` rows.

    Raises:
        ValueError: if a frame has no alpha value ``>= threshold``.
    """
    resy, resx = alpha_video.shape[-2:]
    num_frames = alpha_video.shape[0]
    lower_bound = torch.tensor([0, 0])
    upper_bound = torch.tensor([resy, resx])
    masks_borders = torch.zeros((num_frames, 4), dtype=torch.int64)

    for i in range(num_frames):
        # Threshold into a new boolean tensor so the caller's alpha_video is left unchanged.
        mask = alpha_video[i] >= threshold
        all_ones = mask.reshape(resy, resx).nonzero()
        if all_ones.shape[0] == 0:
            raise ValueError(f"Frame {i} has no alpha value >= {threshold}; cannot compute its mask boundary.")

        min_y, min_x = torch.maximum(all_ones.min(dim=0).values - border, lower_bound)
        max_y, max_x = torch.minimum(all_ones.max(dim=0).values + border, upper_bound)
        min_y, h = _expand_span(min_y, max_y, resy, min_crop_size)
        min_x, w = _expand_span(min_x, max_x, resx, min_crop_size)
        masks_borders[i] = torch.tensor([min_y, min_x, h, w])

    return masks_borders


def get_atlas_bounding_box(mask_boundaries, grid_atlas, video_uvs):
    """Crop ``grid_atlas`` to the UV bounding box of all masked regions across frames.

    For each frame, the ``(H, W, 2)`` UV map is cropped to its ``[top, left, height, width]``
    boundary and the running min (floored) / max (ceiled) of its UVs is updated.
    Channel 0 is treated as ``u`` (atlas column) and channel 1 as ``v`` (atlas row), consistent
    with ``get_atlas_crops``.

    Returns:
        ``(cropped_atlas, crop_data)`` with ``crop_data = [min_v, min_u, height, width]``.
    """
    atlas_h, atlas_w = grid_atlas.shape[-2:]
    # Running bounds are ordered (u, v) = (column, row), so the initial minimum is (width, height).
    min_uv = torch.tensor([atlas_w, atlas_h], device=video_uvs.device)
    max_uv = torch.tensor([0, 0], device=video_uvs.device)
    for boundary, frame in zip(mask_boundaries, video_uvs):
        cropped_uvs = crop(frame.permute(2, 0, 1).unsqueeze(0), *list(boundary))  # 1,2,h,w
        min_uv = torch.minimum(cropped_uvs.amin(dim=[0, 2, 3]), min_uv).floor().int()
        max_uv = torch.maximum(cropped_uvs.amax(dim=[0, 2, 3]), max_uv).ceil().int()

    hw = max_uv - min_uv
    crop_data = [*list(min_uv)[::-1], *list(hw)[::-1]]
    return crop(grid_atlas, *crop_data), crop_data