import copy
import random

import numpy as np
import scipy.signal
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

from lib import seg_dvgo as dvgo
from lib import seg_dcvgo as dcvgo
from .load_data import load_data
from .masked_adam import MaskedAdam


''' Misc
'''
mse2psnr = lambda x: -10. * torch.log10(x)
to8b = lambda x: (255 * np.clip(x, 0, 1)).astype(np.uint8)
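
# Usage sketch (illustrative values only): mse2psnr converts a scalar MSE
# tensor to PSNR in dB, and to8b maps a float image in [0, 1] to uint8.
#   psnr = mse2psnr(torch.tensor(1e-3))        # ~30 dB
#   img8 = to8b(np.random.rand(64, 64, 3))     # uint8 array in [0, 255]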

def seed_everything(args):
    '''Seed everything for better reproducibility.
    NOTE that some PyTorch operations are non-deterministic, e.g. the backward pass of grid_sample.
    '''
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

def cal_IoU(a: Tensor, b: Tensor) -> Tensor:
    """Calculate the Intersection over Union (IoU) between two mask tensors.

    Args:
        a: A tensor of shape (N, H, W).
        b: A tensor of shape (N, H, W).

    Returns:
        A scalar (0-dim) tensor with the IoU computed over all elements,
        treating non-zero entries as foreground.
    """
    intersection = torch.count_nonzero(torch.logical_and(a == b, a != 0))
    union = torch.count_nonzero(a + b)
    return intersection / union
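
# Usage sketch (hypothetical shapes): both inputs are integer/boolean masks
# where 0 is background; the result is a single 0-dim tensor.
#   pred = torch.randint(0, 2, (4, 64, 64))
#   gt = torch.randint(0, 2, (4, 64, 64))
#   iou = cal_IoU(pred, gt)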

def load_everything(args, cfg):
    '''Load images / poses / camera settings / data split.
    '''
    data_dict = load_data(cfg.data)

    # remove useless fields
    kept_keys = {
            'hwf', 'HW', 'Ks', 'near', 'far', 'near_clip',
            'i_train', 'i_val', 'i_test', 'irregular_shape',
            'poses', 'render_poses', 'images'}
    for k in list(data_dict.keys()):
        if k not in kept_keys:
            data_dict.pop(k)

    # construct data tensors
    if data_dict['irregular_shape']:
        data_dict['images'] = [torch.FloatTensor(im, device='cpu') for im in data_dict['images']]
    else:
        data_dict['images'] = torch.FloatTensor(data_dict['images'], device='cpu')
    data_dict['poses'] = torch.Tensor(data_dict['poses'])
    data_dict['render_poses'] = torch.Tensor(data_dict['render_poses'])
    return data_dict
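
# Usage sketch (assumes `cfg` is a parsed project config whose `cfg.data`
# block is understood by load_data; the field names follow kept_keys above):
#   data_dict = load_everything(args, cfg)
#   i_train = data_dict['i_train']
#   images, poses = data_dict['images'], data_dict['poses']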

# semantic nerf is used for reproducing the segmentation of SPIn-NeRF
def load_existed_model(args, cfg, cfg_train, reload_ckpt_path, device):
    model_class = find_model(cfg)
    model = load_model(model_class, reload_ckpt_path).to(device)
    optimizer = create_optimizer_or_freeze_model(model, cfg_train, global_step=0)
    model, optimizer, start = load_checkpoint(
        model, optimizer, reload_ckpt_path, no_reload_optimizer=True)
    return model, optimizer, start

def gen_rand_colors(num_obj):
    rand_colors = np.random.rand(num_obj + 1, 3)
    rand_colors[-1, :] = 0
    return rand_colors

def to_cuda(batch, device=torch.device('cuda')):
    if isinstance(batch, (tuple, list)):
        batch = [to_cuda(b, device) for b in batch]
    elif isinstance(batch, dict):
        batch_ = {}
        for key in batch:
            if key == 'meta':
                batch_[key] = batch[key]
            else:
                batch_[key] = to_cuda(batch[key], device)
        batch = batch_
    elif isinstance(batch, np.ndarray):
        batch = torch.from_numpy(batch).to(device)
    else:
        batch = batch.to(device)
    return batch
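
# Usage sketch (hypothetical batch layout): nested lists/dicts are moved
# recursively; a 'meta' entry is passed through untouched.
#   batch = {'rgb': np.zeros((4, 3)), 'meta': {'name': 'scene_0'}}
#   batch = to_cuda(batch, torch.device('cuda'))  # 'rgb' becomes a CUDA tensor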

def to_tensor(array, device=torch.device('cuda')):
    '''Convert a numpy array to a tensor on `device`; if it is already a tensor,
    only move it when needed. The result is always cast to float.
    '''
    if isinstance(array, np.ndarray):
        array = torch.from_numpy(array).to(device)
    elif isinstance(array, torch.Tensor) and not array.is_cuda:
        array = array.to(device)
    else:
        pass
    return array.float()

''' optimizer
'''
def create_optimizer_or_freeze_model(model, cfg_train, global_step):
    decay_steps = cfg_train.lrate_decay * 1000
    decay_factor = 0.1 ** (global_step / decay_steps)

    param_group = []
    for k in cfg_train.keys():
        if not k.startswith('lrate_'):
            continue
        k = k[len('lrate_'):]

        if not hasattr(model, k):
            continue

        param = getattr(model, k)
        if param is None:
            print(f'create_optimizer_or_freeze_model: param {k} does not exist')
            continue

        lr = getattr(cfg_train, f'lrate_{k}') * decay_factor
        if lr > 0:
            print(f'create_optimizer_or_freeze_model: param {k} lr {lr}')
            if isinstance(param, nn.Module):
                param = param.parameters()
            param_group.append({'params': param, 'lr': lr, 'skip_zero_grad': (k in cfg_train.skip_zero_grad_fields)})
        else:
            print(f'create_optimizer_or_freeze_model: param {k} freeze')
            param.requires_grad = False
    return MaskedAdam(param_group)
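
# Usage sketch (illustrative config fields): every `lrate_<name>` entry in
# cfg_train is matched against a model attribute of the same <name>; a positive
# learning rate adds it to the MaskedAdam param groups, zero freezes it.
#   cfg_train.lrate_density = 1e-1   # optimize model.density
#   cfg_train.lrate_k0 = 0.0         # freeze model.k0
#   optimizer = create_optimizer_or_freeze_model(model, cfg_train, global_step=0)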

def create_segmentation_optimizer(model, cfg_train):
    param_group = []
    for k in cfg_train.keys():
        if not k.startswith('lrate_'):
            continue
        k = k[len('lrate_'):]

        if not hasattr(model, k):
            continue

        param = getattr(model, k)
        if param is None:
            print(f'create_segmentation_optimizer: param {k} does not exist')
            continue

        lr = getattr(cfg_train, f'lrate_{k}')
        if lr > 0:
            print(f'create_segmentation_optimizer: param {k} lr {lr}')
            if isinstance(param, nn.Module):
                param = param.parameters()
            param_group.append({'params': param, 'lr': lr})
        else:
            print(f'create_segmentation_optimizer: param {k} freeze')
            param.requires_grad = False
    return torch.optim.SGD(param_group)

''' Checkpoint utils
'''
def load_checkpoint(model, optimizer, ckpt_path, no_reload_optimizer):
    ckpt = torch.load(ckpt_path)
    try:
        start = ckpt['global_step']
    except KeyError:
        start = 0
    msg = model.load_state_dict(ckpt['model_state_dict'], strict=False)
    print("NeRF loaded with msg: ", msg)
    if not no_reload_optimizer:
        optimizer.load_state_dict(ckpt['optimizer_state_dict'])
    return model, optimizer, start

def find_model(cfg):
    if cfg.data.ndc:
        model_class = dvgo.DirectVoxGO
    elif cfg.data.unbounded_inward:
        model_class = dcvgo.DirectContractedVoxGO
    else:
        model_class = dvgo.DirectVoxGO
    return model_class

def load_model(model_class, ckpt_path):
    ckpt = torch.load(ckpt_path)
    num_objects = 1
    if 'seg_mask_grid.grid' in ckpt['model_state_dict'].keys():
        num_objects = ckpt['model_state_dict']['seg_mask_grid.grid'].shape[1]
    print("Load model with num_objects =", num_objects)
    model = model_class(num_objects=num_objects, **ckpt['model_kwargs'])
    msg = model.load_state_dict(ckpt['model_state_dict'], strict=False)
    print("NeRF loaded with msg: ", msg)
    return model
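
# Usage sketch (hypothetical checkpoint path): find_model picks the grid type
# from the config, and load_model restores it with the stored num_objects.
#   model_class = find_model(cfg)
#   model = load_model(model_class, 'logs/scene/fine_last.tar').to(device)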

def create_new_model(cfg, cfg_model, cfg_train, xyz_min, xyz_max, stage, coarse_ckpt_path, device=torch.device('cuda')):
    model_kwargs = copy.deepcopy(cfg_model)
    num_voxels = model_kwargs.pop('num_voxels')
    if len(cfg_train.pg_scale):
        num_voxels = int(num_voxels / (2**len(cfg_train.pg_scale)))

    if cfg.data.ndc:
        #print(f'scene_rep_reconstruction ({stage}): \033[96muse multiplane images\033[0m')
        #model = dmpigo.DirectMPIGO(
        #    xyz_min=xyz_min, xyz_max=xyz_max,
        #    num_voxels=num_voxels,
        #    **model_kwargs)
        print(f'scene_rep_reconstruction ({stage}): \033[96muse dense voxel grid\033[0m')
        model = dvgo.DirectVoxGO(
            xyz_min=xyz_min, xyz_max=xyz_max,
            num_voxels=num_voxels,
            mask_cache_path=coarse_ckpt_path,
            **model_kwargs)
    elif cfg.data.unbounded_inward:
        print(f'scene_rep_reconstruction ({stage}): \033[96muse contracted voxel grid (covering unbounded)\033[0m')
        model = dcvgo.DirectContractedVoxGO(
            xyz_min=xyz_min, xyz_max=xyz_max,
            num_voxels=num_voxels,
            **model_kwargs)
    else:
        print(f'scene_rep_reconstruction ({stage}): \033[96muse dense voxel grid\033[0m')
        model = dvgo.DirectVoxGO(
            xyz_min=xyz_min, xyz_max=xyz_max,
            num_voxels=num_voxels,
            mask_cache_path=coarse_ckpt_path,
            **model_kwargs)
    model = model.to(device)
    optimizer = create_optimizer_or_freeze_model(model, cfg_train, global_step=0)
    return model, optimizer
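
# Usage sketch (illustrative arguments): xyz_min/xyz_max are scene bounds
# computed elsewhere; pg_scale in cfg_train shrinks the initial grid by a
# factor of 2 per listed step, e.g. pg_scale=[1000, 2000] starts at
# num_voxels / 4 (the grid is upscaled later in training, outside this file).
#   model, optimizer = create_new_model(cfg, cfg.fine_model_and_render, cfg.fine_train,
#                                       xyz_min, xyz_max, stage='fine',
#                                       coarse_ckpt_path=None)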

''' Evaluation metrics (ssim, lpips)
'''
def rgb_ssim(img0, img1, max_val,
             filter_size=11,
             filter_sigma=1.5,
             k1=0.01,
             k2=0.03,
             return_map=False):
    # Modified from https://github.com/google/mipnerf/blob/16e73dfdb52044dcceb47cda5243a686391a6e0f/internal/math.py#L58
    assert len(img0.shape) == 3
    assert img0.shape[-1] == 3
    assert img0.shape == img1.shape

    # Construct a 1D Gaussian blur filter.
    hw = filter_size // 2
    shift = (2 * hw - filter_size + 1) / 2
    f_i = ((np.arange(filter_size) - hw + shift) / filter_sigma)**2
    filt = np.exp(-0.5 * f_i)
    filt /= np.sum(filt)

    # Blur in x and y (faster than the 2D convolution).
    def convolve2d(z, f):
        return scipy.signal.convolve2d(z, f, mode='valid')
    filt_fn = lambda z: np.stack([
        convolve2d(convolve2d(z[..., i], filt[:, None]), filt[None, :])
        for i in range(z.shape[-1])], -1)
    mu0 = filt_fn(img0)
    mu1 = filt_fn(img1)
    mu00 = mu0 * mu0
    mu11 = mu1 * mu1
    mu01 = mu0 * mu1
    sigma00 = filt_fn(img0**2) - mu00
    sigma11 = filt_fn(img1**2) - mu11
    sigma01 = filt_fn(img0 * img1) - mu01

    # Clip the variances and covariances to valid values.
    # Variance must be non-negative:
    sigma00 = np.maximum(0., sigma00)
    sigma11 = np.maximum(0., sigma11)
    sigma01 = np.sign(sigma01) * np.minimum(
        np.sqrt(sigma00 * sigma11), np.abs(sigma01))
    c1 = (k1 * max_val)**2
    c2 = (k2 * max_val)**2
    numer = (2 * mu01 + c1) * (2 * sigma01 + c2)
    denom = (mu00 + mu11 + c1) * (sigma00 + sigma11 + c2)
    ssim_map = numer / denom
    ssim = np.mean(ssim_map)
    return ssim_map if return_map else ssim
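
# Usage sketch (illustrative inputs): both images are float HxWx3 arrays in
# [0, 1], so max_val is 1.0.
#   ssim = rgb_ssim(render_np, gt_np, max_val=1.0)
#   ssim_map = rgb_ssim(render_np, gt_np, max_val=1.0, return_map=True)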

__LPIPS__ = {}
def init_lpips(net_name, device):
    assert net_name in ['alex', 'vgg']
    import lpips
    print(f'init_lpips: lpips_{net_name}')
    return lpips.LPIPS(net=net_name, version='0.1').eval().to(device)

def rgb_lpips(np_gt, np_im, net_name, device):
    if net_name not in __LPIPS__:
        __LPIPS__[net_name] = init_lpips(net_name, device)
    gt = torch.from_numpy(np_gt).permute([2, 0, 1]).contiguous().to(device)
    im = torch.from_numpy(np_im).permute([2, 0, 1]).contiguous().to(device)
    return __LPIPS__[net_name](gt, im, normalize=True).item()
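
# Usage sketch (illustrative inputs): np_gt and np_im are float HxWx3 arrays in
# [0, 1]; normalize=True lets the lpips package rescale them to [-1, 1]. The
# network is built once per backbone ('alex' or 'vgg') and cached in __LPIPS__.
#   score = rgb_lpips(gt_np, render_np, net_name='alex', device=torch.device('cuda'))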

''' generate rays
'''
def get_rays(H, W, K, c2w, inverse_y, flip_x, flip_y, mode='center'):
    i, j = torch.meshgrid(
        torch.linspace(0, W-1, W, device=c2w.device),
        torch.linspace(0, H-1, H, device=c2w.device))  # pytorch's meshgrid has indexing='ij'
    i = i.t().float()
    j = j.t().float()
    if mode == 'lefttop':
        pass
    elif mode == 'center':
        i, j = i + 0.5, j + 0.5
    elif mode == 'random':
        i = i + torch.rand_like(i)
        j = j + torch.rand_like(j)
    else:
        raise NotImplementedError

    if flip_x:
        i = i.flip((1,))
    if flip_y:
        j = j.flip((0,))
    if inverse_y:
        dirs = torch.stack([(i-K[0][2])/K[0][0], (j-K[1][2])/K[1][1], torch.ones_like(i)], -1)
    else:
        dirs = torch.stack([(i-K[0][2])/K[0][0], -(j-K[1][2])/K[1][1], -torch.ones_like(i)], -1)
    # Rotate ray directions from camera frame to the world frame
    rays_d = torch.sum(dirs[..., np.newaxis, :] * c2w[:3, :3], -1)  # dot product, equals to: [c2w.dot(dir) for dir in dirs]
    # Translate camera frame's origin to the world frame. It is the origin of all rays.
    rays_o = c2w[:3, 3].expand(rays_d.shape)
    return rays_o, rays_d
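
# Usage sketch (illustrative intrinsics): K is a 3x3 pinhole matrix and c2w a
# 3x4 (or 4x4) camera-to-world matrix; the outputs are HxWx3 ray origins and
# directions in world space.
#   K = torch.tensor([[fx, 0., W/2.], [0., fy, H/2.], [0., 0., 1.]])
#   rays_o, rays_d = get_rays(H, W, K, c2w, inverse_y=False, flip_x=False, flip_y=False)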

def ndc_rays(H, W, focal, near, rays_o, rays_d):
    # Shift ray origins to near plane
    t = -(near + rays_o[..., 2]) / rays_d[..., 2]
    rays_o = rays_o + t[..., None] * rays_d

    # Projection
    o0 = -1./(W/(2.*focal)) * rays_o[..., 0] / rays_o[..., 2]
    o1 = -1./(H/(2.*focal)) * rays_o[..., 1] / rays_o[..., 2]
    o2 = 1. + 2. * near / rays_o[..., 2]

    d0 = -1./(W/(2.*focal)) * (rays_d[..., 0]/rays_d[..., 2] - rays_o[..., 0]/rays_o[..., 2])
    d1 = -1./(H/(2.*focal)) * (rays_d[..., 1]/rays_d[..., 2] - rays_o[..., 1]/rays_o[..., 2])
    d2 = -2. * near / rays_o[..., 2]

    rays_o = torch.stack([o0, o1, o2], -1)
    rays_d = torch.stack([d0, d1, d2], -1)
    return rays_o, rays_d

def get_rays_of_a_view(H, W, K, c2w, ndc, inverse_y, flip_x, flip_y, mode='center'):
    rays_o, rays_d = get_rays(H, W, K, c2w, inverse_y=inverse_y, flip_x=flip_x, flip_y=flip_y, mode=mode)
    viewdirs = rays_d / rays_d.norm(dim=-1, keepdim=True)
    if ndc:
        rays_o, rays_d = ndc_rays(H, W, K[0][0], 1., rays_o, rays_d)
    return rays_o, rays_d, viewdirs
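
# Usage sketch (illustrative per-view call): convenience wrapper over get_rays;
# when ndc is set, origins/directions are re-expressed in NDC space while
# viewdirs keep the original world-space directions.
#   rays_o, rays_d, viewdirs = get_rays_of_a_view(
#       H, W, K, c2w, ndc=cfg.data.ndc,
#       inverse_y=cfg.data.inverse_y, flip_x=cfg.data.flip_x, flip_y=cfg.data.flip_y)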

''' interactive mode TODO '''
def fetch_user_define_points():
    pass