# Source: text2live / util / atlas_utils.py
# Uploaded by SupermanxKiaski ("Upload 351 files"), revision 3b40f46
import scipy.interpolate
import torch
from torchvision.transforms.functional import crop
from tqdm import tqdm
from models.implicit_neural_networks import IMLP
def load_neural_atlases_models(config):
    """Build the mapping, atlas and alpha IMLP networks and load their checkpoint weights.

    Args:
        config: dict; reads "device" (target device of every network, also used as
            ``torch.load``'s ``map_location``) and "checkpoint_path".

    Returns:
        Tuple ``(foreground_mapping, background_mapping, foreground_atlas_model,
        background_atlas_model, alpha_model)``. Every network is switched to eval
        mode and has ``requires_grad`` disabled on all of its parameters.
    """
    device = config["device"]

    # Mapping networks: 3-D input -> 2-D UV output (no positional encoding).
    foreground_mapping = IMLP(
        input_dim=3,
        output_dim=2,
        hidden_dim=256,
        use_positional=False,
        num_layers=6,
        skip_layers=[],
    ).to(device)
    background_mapping = IMLP(
        input_dim=3,
        output_dim=2,
        hidden_dim=256,
        use_positional=False,
        num_layers=4,
        skip_layers=[],
    ).to(device)

    # Atlas networks: 2-D UV input -> 3-channel output. Both are built identically and
    # both load the single "F_atlas_state_dict" entry below. Presumably one atlas network
    # covers both layers in different UV ranges (get_high_res_atlas samples the background
    # with a -1 shift); confirm before de-duplicating them.
    foreground_atlas_model = IMLP(
        input_dim=2,
        output_dim=3,
        hidden_dim=256,
        use_positional=True,
        positional_dim=10,
        num_layers=8,
        skip_layers=[4, 7],
    ).to(device)
    background_atlas_model = IMLP(
        input_dim=2,
        output_dim=3,
        hidden_dim=256,
        use_positional=True,
        positional_dim=10,
        num_layers=8,
        skip_layers=[4, 7],
    ).to(device)

    # Alpha network: 3-D input -> 1 output.
    alpha_model = IMLP(
        input_dim=3,
        output_dim=1,
        hidden_dim=256,
        use_positional=True,
        positional_dim=5,
        num_layers=8,
        skip_layers=[],
    ).to(device)

    # map_location lets a checkpoint saved from GPU tensors load on a CPU-only device
    # (without it torch.load tries to restore tensors onto their original CUDA device).
    checkpoint = torch.load(config["checkpoint_path"], map_location=device)
    foreground_mapping.load_state_dict(checkpoint["model_F_mapping1_state_dict"])
    background_mapping.load_state_dict(checkpoint["model_F_mapping2_state_dict"])
    foreground_atlas_model.load_state_dict(checkpoint["F_atlas_state_dict"])
    background_atlas_model.load_state_dict(checkpoint["F_atlas_state_dict"])
    alpha_model.load_state_dict(checkpoint["model_F_alpha_state_dict"])

    networks = (foreground_mapping, background_mapping, foreground_atlas_model, background_atlas_model, alpha_model)
    # eval() and requires_grad_() both return the module itself, so these are the same objects.
    return tuple(network.eval().requires_grad_(False) for network in networks)
def _scatter_foreground_atlas_alpha(config, foreground_uv_values, alpha):
    """Resample each frame's alpha onto the foreground atlas grid and take the per-texel median over frames."""
    num_frames = config["maximum_number_of_frames"]
    resolution = config["grid_atlas_resolution"]
    atlas_alpha = torch.zeros(size=(num_frames, resolution, resolution, 1))
    # Scale UVs to atlas-grid units so they live in the same space as the query points.
    foreground_uv_values_grid = foreground_uv_values * resolution
    # Integer grid points in get_grid_indices' (x, y) column order; loop-invariant, so built once.
    query_points = get_grid_indices(0, 0, resolution, resolution).reshape(-1, 2).cpu().numpy()
    for frame in tqdm(range(num_frames), leave=False):
        interpolated = scipy.interpolate.griddata(
            foreground_uv_values_grid[frame].reshape(-1, 2).cpu().numpy(),
            alpha[frame].reshape(-1).cpu().numpy(),
            query_points,
            method="linear",
        ).reshape(resolution, resolution, 1)
        atlas_alpha[frame] = torch.from_numpy(interpolated)
    # griddata's default fill_value is NaN for query points outside the convex hull of the samples.
    atlas_alpha[atlas_alpha.isnan()] = 0.0
    # (1, R, R, 1) -> (1, 1, R, R); the permute also swaps the two grid axes.
    return torch.median(atlas_alpha, dim=0, keepdim=True).values.to(config["device"]).permute(0, 3, 2, 1)


@torch.no_grad()
def get_frames_data(config, foreground_mapping, background_mapping, alpha_model, chunk_size=50000):
    """Evaluate the mapping and alpha networks on every pixel of every frame.

    Pixel (x, y) of frame t is normalized as ``(x, y, t) / (S/2, S/2, T/2) - 1`` with
    ``S = max(resx, resy)`` and ``T = maximum_number_of_frames``. It is fed to the
    networks in chunks of at most ``chunk_size`` points to bound peak memory.

    Args:
        config: dict; reads "resx", "resy", "maximum_number_of_frames", "device",
            "return_atlas_alpha" and, when the latter is truthy, "grid_atlas_resolution".
        foreground_mapping, background_mapping: callables mapping an (N, 3) tensor to (N, 2).
        alpha_model: callable mapping an (N, 3) tensor to (N, 1).
        chunk_size: maximum number of points per network call.

    Returns:
        background_uv_values: (T, resy, resx, 2), background output * 0.5 - 0.5.
        foreground_uv_values: (T, resy, resx, 2), foreground output * 0.5 + 0.5.
        alpha: (T, 1, resy, resx), computed as 0.99 * (0.5 * (a + 1)) + 0.001.
        foreground_atlas_alpha: (1, 1, R, R) per-texel median over frames of the alpha
            resampled onto the foreground atlas grid, or None when
            ``config["return_atlas_alpha"]`` is falsy.
    """
    num_frames = config["maximum_number_of_frames"]
    resx, resy = config["resx"], config["resy"]
    device = config["device"]
    max_size = max(resx, resy)
    normalizing_factor = torch.tensor([max_size / 2, max_size / 2, num_frames / 2])
    background_uv_values = torch.zeros(size=(num_frames, resy, resx, 2), device=device)
    foreground_uv_values = torch.zeros(size=(num_frames, resy, resx, 2), device=device)
    alpha = torch.zeros(size=(num_frames, resy, resx, 1), device=device)

    for frame in tqdm(range(num_frames), leave=False):
        indices = get_grid_indices(0, 0, resy, resx, t=torch.tensor(frame))
        for chunk in indices.split(chunk_size, dim=0):
            normalized_chunk = (chunk / normalizing_factor - 1).to(device)
            # chunk columns are (x, y, t); the output tensors are indexed [frame, y, x].
            rows, cols = chunk[:, 1], chunk[:, 0]
            background_uv_values[frame, rows, cols] = background_mapping(normalized_chunk) * 0.5 - 0.5
            foreground_uv_values[frame, rows, cols] = foreground_mapping(normalized_chunk) * 0.5 + 0.5
            current_alpha = 0.5 * (alpha_model(normalized_chunk) + 1.0)
            # Keep alpha strictly inside (0, 1).
            alpha[frame, rows, cols] = 0.99 * current_alpha + 0.001

    if config["return_atlas_alpha"]:  # this should take a few minutes
        foreground_atlas_alpha = _scatter_foreground_atlas_alpha(config, foreground_uv_values, alpha)
    else:
        foreground_atlas_alpha = None
    return background_uv_values, foreground_uv_values, alpha.permute(0, 3, 1, 2), foreground_atlas_alpha
@torch.no_grad()
def reconstruct_video_layer(uv_values, atlas_model):
    """Query ``atlas_model`` at every UV coordinate of every frame.

    Args:
        uv_values: tensor of shape (T, H, W, 2).
        atlas_model: callable mapping an (H*W, 2) tensor to (H*W, 3).

    Returns:
        Tensor of shape (T, 3, H, W) on ``uv_values.device`` holding
        ``(atlas_model(uv) + 1) * 0.5`` per pixel.
    """
    frame_count, height, width, _ = uv_values.shape
    layer = torch.zeros(size=(frame_count, height, width, 3), device=uv_values.device)
    for index, frame_uv in enumerate(uv_values):
        colors = atlas_model(frame_uv.reshape(-1, 2))
        layer[index] = ((colors + 1) * 0.5).reshape(height, width, 3)
    return layer.permute(0, 3, 1, 2)
@torch.no_grad()
def create_uv_mask(config, mapping_model, min_u, min_v, max_u, max_v, uv_shift=-0.5, resolution_shift=1):
    """Mark the atlas texels that some video pixel maps onto, then crop the mask.

    Every pixel (x, y) of every frame t is normalized into [-1, 1) per axis and fed
    to ``mapping_model`` in chunks of 50000 points. Each output is converted to atlas
    texel coordinates as ``(out * 0.5 + uv_shift + resolution_shift) * resolution``,
    clipped to ``[0, resolution - 1]``. The four integer texels surrounding that point
    are then set to 1.

    Args:
        config: dict; reads "resx", "resy", "maximum_number_of_frames",
            "grid_atlas_resolution" and "device".
        mapping_model: callable taking an (N, 3) tensor; only output columns 0 (u)
            and 1 (v) are read.
        min_u, min_v, max_u, max_v: forwarded to torchvision ``crop`` as
            (top=min_v, left=min_u, height=max_v, width=max_u). The ``max_*``
            arguments act as the crop height/width, not as end coordinates.
        uv_shift, resolution_shift: offsets applied before scaling (see formula above).

    Returns:
        CPU tensor of shape (1, 1, max_v, max_u).
    """
    frame_count = config["maximum_number_of_frames"]
    largest_side = max(config["resx"], config["resy"])
    scale = torch.tensor([largest_side / 2, largest_side / 2, frame_count / 2])
    resolution = config["grid_atlas_resolution"]
    hit_mask = torch.zeros(size=(resolution, resolution), device=config["device"])

    for frame in tqdm(range(frame_count), leave=False):
        pixel_points = get_grid_indices(0, 0, config["resy"], config["resx"], t=torch.tensor(frame))
        for batch in pixel_points.split(50000, dim=0):
            uv = mapping_model((batch / scale - 1).to(config["device"]))
            texel = ((uv * 0.5 + uv_shift + resolution_shift) * resolution).clip(0, resolution - 1)
            u_low, u_high = texel[:, 0].floor().long(), texel[:, 0].ceil().long()
            v_low, v_high = texel[:, 1].floor().long(), texel[:, 1].ceil().long()
            # Splat onto all four neighbouring texels (row index = v, column index = u).
            for v_index in (v_low, v_high):
                for u_index in (u_low, u_high):
                    hit_mask[v_index, u_index] = 1

    cropped = crop(hit_mask[None, None], min_v, min_u, max_v, max_u)
    return cropped.detach().cpu()
@torch.no_grad()
def get_high_res_atlas(atlas_model, min_v, min_u, max_v, max_u, resolution, device="cuda", layer="background"):
    """Render ``atlas_model`` on a ``resolution x resolution`` grid and crop the result.

    Texel (x, y) is queried at ``(x / resolution + offset, y / resolution + offset)``.
    The offset is -1 when ``layer == "background"`` (coordinates in [-1, 0)) and
    0 for any other layer (coordinates in [0, 1)). Texels are evaluated in chunks of
    50000 points. The outputs are rescaled with ``0.5 * (out + 1)``.

    Args:
        atlas_model: callable mapping an (N, 2) tensor to (N, 3).
        min_v, min_u, max_v, max_u: forwarded to torchvision ``crop`` as
            (top, left, height, width), in that order.
        resolution: side length of the rendered grid.
        device: device on which the grid is rendered.
        layer: "background" selects the -1 offset; anything else uses 0.

    Returns:
        Tensor of shape (1, 3, max_v, max_u).
    """
    offset = -1 if layer == "background" else 0
    atlas = torch.zeros((resolution, resolution, 3)).to(device)  # rows (y), columns (x), channels
    texels = get_grid_indices(0, 0, resolution, resolution)
    for batch in texels.split(50000, dim=0):
        coords = (batch / resolution + offset).to(device)
        atlas[batch[:, 1], batch[:, 0], :] = atlas_model(coords)
    atlas = 0.5 * (atlas + 1)
    atlas = atlas.permute(2, 0, 1).unsqueeze(0)
    return crop(atlas, min_v, min_u, max_v, max_u)
def get_grid_indices(x_start, y_start, h_crop, w_crop, t=None):
    """Enumerate the integer (x, y[, t]) coordinates of an ``h_crop x w_crop`` window.

    Rows are ordered x-major: every y for the first x, then every y for the next x,
    and so on. This matches flattening ``torch.meshgrid(xs, ys, indexing="ij")``.

    Args:
        x_start, y_start: offsets added to the x (column) and y (row) coordinates.
        h_crop: number of y values; w_crop: number of x values.
        t: optional tensor or number. When given, ``t`` is repeated once per row and
            appended as extra column(s).

    Returns:
        Tensor of shape ``(h_crop * w_crop, 2)``, or ``(h_crop * w_crop, 3)`` when
        ``t`` is a scalar.
    """
    xs = torch.arange(w_crop) + x_start
    ys = torch.arange(h_crop) + y_start
    # Equivalent to an "ij" meshgrid, built without torch.meshgrid: calling it without
    # `indexing=` emits a UserWarning on recent PyTorch releases.
    crop_indices = torch.stack([xs.repeat_interleave(h_crop), ys.repeat(w_crop)], dim=-1)
    if t is not None:
        # as_tensor is a no-op for tensors and also accepts plain numbers.
        crop_indices = torch.cat([crop_indices, torch.as_tensor(t).repeat(h_crop * w_crop, 1)], dim=1)
    return crop_indices
def get_atlas_crops(uv_values, grid_atlas, augmentation=None):
    """Crop the region of ``grid_atlas`` spanned by ``uv_values``.

    Args:
        uv_values: tensor of rank 3 or 4 whose last dimension holds (u, v) coordinates
            (presumably in ``grid_atlas`` pixel units; confirm with callers). The bounds
            are reduced over all other dimensions.
        grid_atlas: image tensor accepted by torchvision ``crop`` (``[..., H, W]``).
        augmentation: optional callable applied to the crop.

    Returns:
        ``(atlas_crop, tensor([min_u, min_v]), tensor([max_u, max_v]))`` with integer
        bounds ``floor(min)`` and ``ceil(max)``.

    Raises:
        ValueError: if ``uv_values`` is not of rank 3 or 4.
    """
    if uv_values.dim() not in (3, 4):
        raise ValueError("uv_values should be of shape of len 3 or 4")
    reduce_dims = list(range(uv_values.dim() - 1))

    # floor (rather than .long(), which truncates toward zero) so that the box also
    # encloses negative coordinates; the upper bound is rounded outward with ceil.
    min_u, min_v = uv_values.amin(dim=reduce_dims).floor().long()
    max_u, max_v = uv_values.amax(dim=reduce_dims).ceil().long()

    atlas_crop = crop(grid_atlas, min_v, min_u, max_v - min_v, max_u - min_u)
    if augmentation is not None:
        atlas_crop = augmentation(atlas_crop)
    return atlas_crop, torch.stack([min_u, min_v]), torch.stack([max_u, max_v])
def get_random_crop_params(input_size, output_size):
    """Sample a random crop window (same conventions as torchvision's RandomCrop.get_params).

    Args:
        input_size: ``(w, h)`` of the image being cropped.
        output_size: ``(th, tw)``, the requested crop height and width.

    Returns:
        ``(i, j, th, tw)``: top row, left column, height and width of the crop. When the
        crop exactly matches the image, ``(0, 0, h, w)``.

    Raises:
        ValueError: if the crop exceeds the image by more than one pixel in either dimension.
    """
    w, h = input_size
    th, tw = output_size

    # A one-pixel overshoot is deliberately tolerated; in that dimension the offset is
    # pinned to 0 (callers presumably rely on torchvision's crop zero-padding the overflow).
    if h + 1 < th or w + 1 < tw:
        raise ValueError(f"Required crop size {(th, tw)} is larger then input image size {(h, w)}")

    if w == tw and h == th:
        return 0, 0, h, w

    # max(..., 0) keeps the exclusive upper bound >= 1 for the tolerated overshoot;
    # torch.randint(0, 0) would raise a RuntimeError.
    i = torch.randint(0, max(h - th, 0) + 1, size=(1,)).item()
    j = torch.randint(0, max(w - tw, 0) + 1, size=(1,)).item()
    return i, j, th, tw
def _grow_to_min_size(start, end, limit, min_size):
    """Widen [start, end) to min_size if shorter: extend end unless that passes limit, else move start back."""
    size = end - start
    if size < min_size:
        pad = min_size - size
        if end + pad > limit:
            start -= pad
        else:
            end += pad
    return start, end - start


def get_masks_boundaries(alpha_video, border=20, threshold=0.95, min_crop_size=2 ** 7 + 1):
    """Compute a per-frame bounding box around the high-alpha region of a video.

    For each frame:
    - Pixels with ``alpha >= threshold`` are located.
    - Their bounding box is expanded by ``border`` pixels on every side, clamped to
      the frame.
    - Any side shorter than ``min_crop_size`` is grown to exactly ``min_crop_size``.
      It grows down/right when that stays inside the frame; otherwise the top/left
      edge moves up/left. That edge can go negative when the frame is smaller than
      ``min_crop_size``.

    ``alpha_video`` is not modified.

    Args:
        alpha_video: tensor whose first dim indexes frames and whose last two dims are
            (height, width). Each frame must hold exactly height*width values, e.g.
            shape (T, 1, H, W) or (T, H, W). It may live on any device.
        border: expansion in pixels applied on every side of the detected box.
        threshold: alpha value at or above which a pixel counts as foreground.
        min_crop_size: minimum height and width of every returned box.

    Returns:
        int64 CPU tensor of shape (T, 4) with rows ``[top, left, height, width]``.
        A frame with no pixel at or above ``threshold`` starts from the whole frame.
    """
    resy, resx = alpha_video.shape[-2:]
    num_frames = alpha_video.shape[0]
    masks_borders = torch.zeros((num_frames, 4), dtype=torch.int64)
    for i in range(num_frames):
        # Build a fresh boolean mask rather than thresholding in place, which would
        # overwrite the caller's alpha values through the alpha_video[i] view.
        mask = (alpha_video[i] >= threshold).reshape(resy, resx)
        all_ones = mask.nonzero()  # (N, 2) rows of (y, x)
        if all_ones.numel() == 0:
            # Nothing above threshold: fall back to the whole frame instead of crashing
            # on min()/max() of an empty tensor.
            min_y, min_x, max_y, max_x = 0, 0, resy, resx
        else:
            first_y, first_x = all_ones.min(dim=0).values.tolist()
            last_y, last_x = all_ones.max(dim=0).values.tolist()
            # NOTE(review): last_y/last_x are inclusive indices but max_* is used as an
            # exclusive end below, so with border=0 the last row/column is excluded.
            min_y, min_x = max(first_y - border, 0), max(first_x - border, 0)
            max_y, max_x = min(last_y + border, resy), min(last_x + border, resx)
        min_y, h = _grow_to_min_size(min_y, max_y, resy, min_crop_size)
        min_x, w = _grow_to_min_size(min_x, max_x, resx, min_crop_size)
        masks_borders[i] = torch.tensor([min_y, min_x, h, w])
    return masks_borders
def get_atlas_bounding_box(mask_boundaries, grid_atlas, video_uvs):
    """Crop ``grid_atlas`` to the integer UV box covering the masked region of every frame.

    Args:
        mask_boundaries: iterable of per-frame ``[top, left, height, width]`` rows
            (presumably the output of get_masks_boundaries).
        grid_atlas: image tensor ``[..., H, W]``.
        video_uvs: per-frame UV maps of shape (T, h, w, 2). Channel 0 is u (atlas
            column / left) and channel 1 is v (atlas row / top).

    Returns:
        ``(cropped_atlas, crop_data)`` where ``crop_data`` is
        ``[min_v, min_u, height, width]``, the arguments passed to torchvision ``crop``.
    """
    atlas_h, atlas_w = grid_atlas.shape[-2:]
    # The running min/max are in (u, v) channel order, so the initial upper bound must be
    # (width, height). Using shape[-2:] directly gives (H, W), which is wrong for
    # non-square atlases.
    min_uv = torch.tensor([atlas_w, atlas_h], device=video_uvs.device)
    max_uv = torch.tensor([0, 0], device=video_uvs.device)
    for boundary, frame in zip(mask_boundaries, video_uvs):
        cropped_uvs = crop(frame.permute(2, 0, 1).unsqueeze(0), *list(boundary))  # 1, 2, h, w
        min_uv = torch.minimum(cropped_uvs.amin(dim=[0, 2, 3]), min_uv).floor().int()
        max_uv = torch.maximum(cropped_uvs.amax(dim=[0, 2, 3]), max_uv).ceil().int()
    hw = max_uv - min_uv
    # Reverse (u, v) -> (v, u) to get crop's (top, left, height, width) order.
    crop_data = [*list(min_uv)[::-1], *list(hw)[::-1]]
    return crop(grid_atlas, *crop_data), crop_data