Spaces:
Configuration error
Configuration error
import scipy.interpolate | |
import torch | |
from torchvision.transforms.functional import crop | |
from tqdm import tqdm | |
from models.implicit_neural_networks import IMLP | |
def load_neural_atlases_models(config):
    """Build the mapping, atlas and alpha networks and load their weights from a checkpoint.

    Args:
        config: dict-like; reads ``config["device"]`` (device every model is moved to)
            and ``config["checkpoint_path"]`` (file passed to ``torch.load``).

    Returns:
        Tuple ``(foreground_mapping, background_mapping, foreground_atlas_model,
        background_atlas_model, alpha_model)``. Every model is switched to eval mode
        with gradients disabled.
    """
    device = config["device"]

    # Mapping networks: 3-D input, 2-D output (UV coordinates after rescaling by callers).
    foreground_mapping = IMLP(
        input_dim=3,
        output_dim=2,
        hidden_dim=256,
        use_positional=False,
        num_layers=6,
        skip_layers=[],
    ).to(device)
    background_mapping = IMLP(
        input_dim=3,
        output_dim=2,
        hidden_dim=256,
        use_positional=False,
        num_layers=4,
        skip_layers=[],
    ).to(device)

    # Atlas networks: 2-D UV input, 3-channel color output.
    foreground_atlas_model = IMLP(
        input_dim=2,
        output_dim=3,
        hidden_dim=256,
        use_positional=True,
        positional_dim=10,
        num_layers=8,
        skip_layers=[4, 7],
    ).to(device)
    background_atlas_model = IMLP(
        input_dim=2,
        output_dim=3,
        hidden_dim=256,
        use_positional=True,
        positional_dim=10,
        num_layers=8,
        skip_layers=[4, 7],
    ).to(device)

    # Alpha network: 3-D input, single-channel output.
    alpha_model = IMLP(
        input_dim=3,
        output_dim=1,
        hidden_dim=256,
        use_positional=True,
        positional_dim=5,
        num_layers=8,
        skip_layers=[],
    ).to(device)

    # map_location lets a checkpoint saved from a GPU run be loaded when
    # config["device"] is different (e.g. a CPU-only machine); load_state_dict then
    # copies the weights into the modules already placed on `device`.
    checkpoint = torch.load(config["checkpoint_path"], map_location=device)
    foreground_mapping.load_state_dict(checkpoint["model_F_mapping1_state_dict"])
    background_mapping.load_state_dict(checkpoint["model_F_mapping2_state_dict"])
    # Both atlas models deliberately load the same "F_atlas_state_dict": the checkpoint
    # holds a single atlas network. Presumably the layers are told apart by the UV
    # region each one is queried at — confirm against the training code.
    foreground_atlas_model.load_state_dict(checkpoint["F_atlas_state_dict"])
    background_atlas_model.load_state_dict(checkpoint["F_atlas_state_dict"])
    alpha_model.load_state_dict(checkpoint["model_F_alpha_state_dict"])

    models = (
        foreground_mapping,
        background_mapping,
        foreground_atlas_model,
        background_atlas_model,
        alpha_model,
    )
    # eval() and requires_grad_() both return the module itself.
    return tuple(model.eval().requires_grad_(False) for model in models)
def get_frames_data(config, foreground_mapping, background_mapping, alpha_model, chunk_size=50000):
    """Evaluate the mapping and alpha networks at every pixel of every frame.

    Pixel coordinates (x, y, t) are normalized to [-1, 1) by dividing x and y by half
    the longer frame side and t by half the frame count, then subtracting 1.

    Args:
        config: dict-like; reads "resx", "resy", "maximum_number_of_frames", "device",
            "return_atlas_alpha" and, when the latter is truthy, "grid_atlas_resolution".
        foreground_mapping, background_mapping: networks applied to (N, 3) normalized
            points; their (N, 2) outputs are rescaled into UV values below.
        alpha_model: network applied to the same points; its (N, 1) output is rescaled
            into the alpha value below.
        chunk_size: maximum number of pixels passed to the networks per call. This
            bounds peak memory on large frames; the results do not depend on it beyond
            floating-point noise.

    Returns:
        background_uv_values: (T, resy, resx, 2) tensor, network output * 0.5 - 0.5.
        foreground_uv_values: (T, resy, resx, 2) tensor, network output * 0.5 + 0.5.
        alpha: (T, 1, resy, resx) tensor, 0.99 * (0.5 * (output + 1)) + 0.001.
        foreground_atlas_alpha: (1, 1, R, R) tensor on config["device"], with
            R = config["grid_atlas_resolution"]; None when
            config["return_atlas_alpha"] is falsy.
    """
    num_frames = config["maximum_number_of_frames"]
    resx, resy = config["resx"], config["resy"]
    device = config["device"]
    max_size = max(resx, resy)
    normalizing_factor = torch.tensor([max_size / 2, max_size / 2, num_frames / 2])

    background_uv_values = torch.zeros(size=(num_frames, resy, resx, 2), device=device)
    foreground_uv_values = torch.zeros(size=(num_frames, resy, resx, 2), device=device)
    alpha = torch.zeros(size=(num_frames, resy, resx, 1), device=device)

    for frame in tqdm(range(num_frames), leave=False):
        indices = get_grid_indices(0, 0, resy, resx, t=torch.tensor(frame))
        # Feed the networks in bounded chunks (as create_uv_mask/get_high_res_atlas
        # do) instead of a whole frame at once.
        for chunk in indices.split(chunk_size, dim=0):
            normalized_chunk = (chunk / normalizing_factor - 1).to(device)
            # Get the atlas UV coordinates from the two mapping networks.
            with torch.no_grad():
                current_background_uv_values = background_mapping(normalized_chunk)
                current_foreground_uv_values = foreground_mapping(normalized_chunk)
                current_alpha = alpha_model(normalized_chunk)
            ys, xs = chunk[:, 1], chunk[:, 0]
            background_uv_values[frame, ys, xs] = current_background_uv_values * 0.5 - 0.5
            foreground_uv_values[frame, ys, xs] = current_foreground_uv_values * 0.5 + 0.5
            current_alpha = 0.5 * (current_alpha + 1.0)
            # Affine squeeze into [0.001, 0.991] for inputs in [0, 1].
            alpha[frame, ys, xs] = 0.99 * current_alpha + 0.001

    if config["return_atlas_alpha"]:  # may take a few minutes
        foreground_atlas_alpha = _compute_foreground_atlas_alpha(config, foreground_uv_values, alpha)
    else:
        foreground_atlas_alpha = None
    return background_uv_values, foreground_uv_values, alpha.permute(0, 3, 1, 2), foreground_atlas_alpha


def _compute_foreground_atlas_alpha(config, foreground_uv_values, alpha):
    """Interpolate per-frame alpha onto the foreground atlas grid; return the median over frames.

    For each frame, the alpha values are placed at their foreground UVs (scaled by
    R = config["grid_atlas_resolution"]). They are then linearly interpolated onto the
    integer R x R grid with scipy's griddata. Grid points outside the data's convex
    hull come back as NaN and are set to 0 before taking the per-texel median across
    frames.

    Returns:
        (1, 1, R, R) tensor on config["device"], indexed [..., v, u].
    """
    resolution = config["grid_atlas_resolution"]
    num_frames = config["maximum_number_of_frames"]
    atlas_alpha = torch.zeros(size=(num_frames, resolution, resolution, 1))
    foreground_uv_values_grid = foreground_uv_values * resolution
    # Query points are (u, v) rows in x-major order, so each reshaped frame result
    # below is indexed [u, v, 0]. The grid is loop-invariant, so convert it once.
    query_points = get_grid_indices(0, 0, resolution, resolution).reshape(-1, 2).cpu().numpy()
    for frame in tqdm(range(num_frames), leave=False):
        interpolated = scipy.interpolate.griddata(
            foreground_uv_values_grid[frame].reshape(-1, 2).cpu().numpy(),
            alpha[frame].reshape(-1).cpu().numpy(),
            query_points,
            method="linear",
        ).reshape(resolution, resolution, 1)
        atlas_alpha[frame] = torch.from_numpy(interpolated)
    atlas_alpha[atlas_alpha.isnan()] = 0.0
    median_alpha = torch.median(atlas_alpha, dim=0, keepdim=True).values
    # (1, u, v, 1) -> (1, 1, v, u): channel first, rows = v, columns = u.
    return median_alpha.to(config["device"]).permute(0, 3, 2, 1)
def reconstruct_video_layer(uv_values, atlas_model):
    """Render one video layer by querying an atlas network at every pixel's UV.

    Args:
        uv_values: rank-4 tensor (T, H, W, 2) of atlas coordinates.
        atlas_model: callable mapping an (N, 2) tensor to an (N, 3) tensor. Its
            output is rescaled with (x + 1) * 0.5, i.e. [-1, 1] -> [0, 1].

    Returns:
        (T, 3, H, W) tensor on ``uv_values.device`` (default float dtype).
    """
    num_frames, height, width, _ = uv_values.shape
    output = torch.zeros(size=(num_frames, height, width, 3), device=uv_values.device)
    # One frame at a time keeps each atlas_model call to H * W points.
    for index, frame_uvs in enumerate(uv_values):
        colors = (atlas_model(frame_uvs.reshape(-1, 2)) + 1) * 0.5
        output[index] = colors.reshape(height, width, 3)
    return output.permute(0, 3, 1, 2)
def create_uv_mask(config, mapping_model, min_u, min_v, max_u, max_v, uv_shift=-0.5, resolution_shift=1):
    """Mark every atlas texel that some video pixel maps to, then crop the mask.

    Every pixel of every frame is normalized the same way as in get_frames_data and
    passed through ``mapping_model``. The output is turned into atlas texel
    coordinates via ``((out * 0.5 + uv_shift) + resolution_shift) * R`` (clipped to
    [0, R - 1], with R = config["grid_atlas_resolution"]). The four integer texels
    around each point (floor/ceil in u and v) are set to 1.

    The result is cropped with torchvision's ``crop(img, top, left, height, width)``
    called as ``(min_v, min_u, max_v, max_u)``.
    NOTE(review): max_v / max_u are therefore used as crop height / width, not as end
    coordinates — confirm callers pass sizes.

    Returns:
        CPU tensor of shape (1, 1, max_v, max_u) holding 0/1 values.
    """
    res = config["grid_atlas_resolution"]
    frame_count = config["maximum_number_of_frames"]
    longest_side = max(config["resx"], config["resy"])
    scale = torch.tensor([longest_side / 2, longest_side / 2, frame_count / 2])
    mask = torch.zeros(size=(res, res), device=config["device"])

    for frame in tqdm(range(frame_count), leave=False):
        frame_indices = get_grid_indices(0, 0, config["resy"], config["resx"], t=torch.tensor(frame))
        for chunk in frame_indices.split(50000, dim=0):
            points = (chunk / scale - 1).to(config["device"])
            with torch.no_grad():
                uv = mapping_model(points)
            uv = uv * 0.5 + uv_shift
            texels = ((uv + resolution_shift) * res).clip(0, res - 1)
            u_lo, u_hi = texels[:, 0].floor().long(), texels[:, 0].ceil().long()
            v_lo, v_hi = texels[:, 1].floor().long(), texels[:, 1].ceil().long()
            # Splat onto all four neighbouring texels so no covered texel is missed.
            for v_idx, u_idx in ((v_lo, u_lo), (v_lo, u_hi), (v_hi, u_lo), (v_hi, u_hi)):
                mask[v_idx, u_idx] = 1

    mask = crop(mask.unsqueeze(0).unsqueeze(0), min_v, min_u, max_v, max_u)
    return mask.detach().cpu()
def get_high_res_atlas(atlas_model, min_v, min_u, max_v, max_u, resolution, device="cuda", layer="background"):
    """Render an atlas layer on a resolution x resolution grid and crop it.

    Texel (u, v) is queried at ``(u / resolution + shift, v / resolution + shift)``,
    with shift = -1 when ``layer == "background"`` and 0 for any other value. Colors
    are rescaled with ``0.5 * (x + 1)``.

    The image is cropped with torchvision's ``crop(img, top, left, height, width)``
    called as ``(min_v, min_u, max_v, max_u)``.
    NOTE(review): max_v / max_u act as crop height / width — confirm callers pass sizes.

    Returns:
        Tensor of shape (1, 3, max_v, max_u) on ``device``.
    """
    shift = -1 if layer == "background" else 0
    grid = get_grid_indices(0, 0, resolution, resolution)
    atlas = torch.zeros((resolution, resolution, 3)).to(device)  # rows = v, columns = u

    with torch.no_grad():
        # Query the atlas network in chunks of at most 50000 texels.
        for chunk in grid.split(50000, dim=0):
            us = (chunk[:, 0] / resolution) + shift
            vs = (chunk[:, 1] / resolution) + shift
            coords = torch.stack([us, vs], dim=-1).to(device)
            atlas[chunk[:, 1], chunk[:, 0], :] = atlas_model(coords)

    atlas = 0.5 * (atlas + 1)
    atlas = atlas.permute(2, 0, 1).unsqueeze(0)  # (1, 3, resolution, resolution)
    return crop(atlas, min_v, min_u, max_v, max_u)
def get_grid_indices(x_start, y_start, h_crop, w_crop, t=None):
    """Return the integer pixel coordinates of an ``h_crop`` x ``w_crop`` window.

    Args:
        x_start, y_start: offsets added to the column (x) and row (y) indices.
        h_crop, w_crop: window height (rows) and width (columns).
        t: optional tensor (e.g. a 0-dim frame index) appended as a third column,
            repeated for every pixel.

    Returns:
        Tensor of shape ``(h_crop * w_crop, 2)`` with columns ``(x, y)``, or
        ``(h_crop * w_crop, 3)`` with columns ``(x, y, t)`` when ``t`` is given.
        Rows are x-major: x varies slowest, y fastest.
    """
    xs = torch.arange(w_crop) + x_start
    ys = torch.arange(h_crop) + y_start
    # Explicit "ij" layout via broadcasting, equivalent to
    # torch.meshgrid(xs, ys, indexing="ij"). Calling meshgrid without `indexing` emits
    # a deprecation warning on recent torch, and the keyword is missing on old torch.
    grid_x = xs[:, None].expand(w_crop, h_crop)
    grid_y = ys[None, :].expand(w_crop, h_crop)
    crop_indices = torch.stack([grid_x, grid_y], dim=-1).reshape(h_crop * w_crop, 2)
    if t is not None:
        crop_indices = torch.cat([crop_indices, t.repeat(h_crop * w_crop, 1)], dim=1)
    return crop_indices
def get_atlas_crops(uv_values, grid_atlas, augmentation=None):
    """Crop ``grid_atlas`` to the integer bounding box of ``uv_values``.

    Args:
        uv_values: rank-3 or rank-4 float tensor whose last dim holds (u, v) in atlas
            pixel units. All leading dims are reduced over.
        grid_atlas: image tensor cropped with torchvision's
            ``crop(img, top=min_v, left=min_u, height, width)``.
        augmentation: optional callable applied to the crop.

    Returns:
        ``(atlas_crop, stack([min_u, min_v]), stack([max_u, max_v]))``. The minima are
        floored and the maxima ceiled, both as int64.

    Raises:
        ValueError: if ``uv_values`` is not rank 3 or 4.
    """
    if uv_values.ndim not in (3, 4):
        raise ValueError("uv_values should be of shape of len 3 or 4")
    dims = list(range(uv_values.ndim - 1))  # every dim except the (u, v) one
    # floor() rather than a bare .long(): .long() truncates toward zero, which rounds
    # negative minima inward and clips the box. floor mirrors ceil() on the maxima.
    min_u, min_v = uv_values.amin(dim=dims).floor().long()
    max_u, max_v = uv_values.amax(dim=dims).ceil().long()
    atlas_crop = crop(grid_atlas, min_v, min_u, max_v - min_v, max_u - min_u)
    if augmentation is not None:
        atlas_crop = augmentation(atlas_crop)
    return atlas_crop, torch.stack([min_u, min_v]), torch.stack([max_u, max_v])
def get_random_crop_params(input_size, output_size):
    """Pick a random top-left corner for a crop of ``output_size`` inside ``input_size``.

    The argument order is asymmetric and kept for compatibility: ``input_size`` is
    ``(width, height)`` while ``output_size`` is ``(height, width)``.

    The check below tolerates a crop up to one pixel larger than the input along each
    axis. Along such an axis the only possible offset is 0, and the crop runs one pixel
    past the input edge. Presumably callers use a cropping routine that pads — TODO
    confirm.

    Returns:
        ``(i, j, th, tw)``: top offset, left offset, crop height, crop width.
        Returns ``(0, 0, h, w)`` when the sizes match exactly.

    Raises:
        ValueError: if the crop exceeds the input by more than one pixel on an axis.
    """
    w, h = input_size
    th, tw = output_size
    if h + 1 < th or w + 1 < tw:
        raise ValueError(f"Required crop size {(th, tw)} is larger then input image size {(h, w)}")
    if w == tw and h == th:
        return 0, 0, h, w
    # Within the one-pixel tolerance, h - th + 1 can be 0, and torch.randint(0, 0)
    # raises. Clamp the exclusive upper bound to 1 so the offset is then 0.
    i = torch.randint(0, max(h - th + 1, 1), size=(1,)).item()
    j = torch.randint(0, max(w - tw + 1, 1), size=(1,)).item()
    return i, j, th, tw
def get_masks_boundaries(alpha_video, border=20, threshold=0.95, min_crop_size=2 ** 7 + 1):
    """Compute a bounding box around the thresholded alpha mask of every frame.

    Each box spans the pixels with ``alpha >= threshold``, grown by ``border`` pixels
    and clipped to the frame. A side shorter than ``min_crop_size`` is grown to exactly
    that length: the far edge is extended if it still fits in the frame, otherwise the
    near edge is extended. The near edge can go negative when the frame is smaller
    than ``min_crop_size``. A frame with no pixel at or above ``threshold`` uses the
    whole frame as its box.

    NOTE(review): the far edge is (last mask index + border) and is used as an
    exclusive end, so with ``border=0`` the last mask row/column falls just outside the
    box — confirm this is intended.

    Args:
        alpha_video: tensor with frames on dim 0, (resy, resx) as its last two dims and
            one channel per frame, e.g. (T, 1, H, W). It is not modified.
        border: margin in pixels added around the mask.
        threshold: alpha value at or above which a pixel counts as foreground.
        min_crop_size: minimum box height and width.

    Returns:
        (T, 4) int64 tensor of rows ``(top, left, height, width)``.
    """
    resy, resx = alpha_video.shape[-2:]
    num_frames = alpha_video.shape[0]
    masks_borders = torch.zeros((num_frames, 4), dtype=torch.int64)
    for i in range(num_frames):
        # Threshold into a fresh boolean mask. Writing into alpha_video[i] (a view)
        # would silently binarize the caller's alpha video.
        mask = (alpha_video[i] >= threshold).reshape(resy, resx)
        all_ones = mask.nonzero()  # (N, 2) rows of (y, x)
        if all_ones.numel() == 0:
            # Nothing above the threshold: fall back to the full frame instead of
            # crashing on a min/max over an empty tensor.
            min_y, min_x = torch.tensor([0, 0])
            max_y, max_x = torch.tensor([resy, resx])
        else:
            min_y, min_x = torch.maximum(all_ones.min(dim=0).values - border, torch.tensor([0, 0]))
            max_y, max_x = torch.minimum(all_ones.max(dim=0).values + border, torch.tensor([resy, resx]))
        h = max_y - min_y
        w = max_x - min_x
        if h < min_crop_size:
            pad = min_crop_size - h
            if max_y + pad > resy:
                min_y -= pad
            else:
                max_y += pad
            h = max_y - min_y
        if w < min_crop_size:
            pad = min_crop_size - w
            if max_x + pad > resx:
                min_x -= pad
            else:
                max_x += pad
            w = max_x - min_x
        masks_borders[i] = torch.stack([min_y, min_x, h, w])
    return masks_borders
def get_atlas_bounding_box(mask_boundaries, grid_atlas, video_uvs):
    """Find the atlas region covered by the boxed part of every frame and crop it.

    Args:
        mask_boundaries: iterable of per-frame boxes ``(top, left, height, width)`` in
            frame pixels, e.g. the rows returned by ``get_masks_boundaries``.
        grid_atlas: atlas image tensor whose last two dims are (height, width).
        video_uvs: tensor indexed as (T, H, W, 2), holding (u, v) per pixel in atlas
            pixel units.

    Returns:
        ``(cropped_atlas, crop_data)``. ``crop_data = [min_v, min_u, height, width]``
        is the (top, left, height, width) box that was passed to torchvision's
        ``crop``. The minima are floored and the maxima ceiled.
    """
    atlas_height, atlas_width = grid_atlas.shape[-2:]
    # Running bounds are kept in (u, v) = (width axis, height axis) order, so the
    # upper sentinel is (width, height), not grid_atlas.shape[-2:] = (height, width).
    min_uv = torch.tensor([atlas_width, atlas_height], device=video_uvs.device)
    max_uv = torch.tensor([0, 0], device=video_uvs.device)
    frame_height, frame_width = video_uvs.shape[1:3]
    for boundary, frame in zip(mask_boundaries, video_uvs):
        top, left, height, width = (int(value) for value in boundary)
        # Clip the box to the frame rather than using torchvision's crop. crop zero-pads
        # out-of-frame areas, and those zeros would be counted as UV coordinates.
        region = frame[
            max(top, 0) : min(top + height, frame_height),
            max(left, 0) : min(left + width, frame_width),
        ]
        if region.numel() == 0:
            continue
        region_uvs = region.reshape(-1, region.shape[-1])  # (pixels, 2)
        min_uv = torch.minimum(region_uvs.amin(dim=0), min_uv).floor().int()
        max_uv = torch.maximum(region_uvs.amax(dim=0), max_uv).ceil().int()
    hw = max_uv - min_uv
    # Reverse (u, v) -> (v, u) to obtain crop's (top, left, height, width).
    crop_data = [*list(min_uv)[::-1], *list(hw)[::-1]]
    return crop(grid_atlas, *crop_data), crop_data