# Our3D/lib/utils.py
import copy
import random
import numpy as np
import scipy.signal
import torch
import torch.nn as nn
import torch.nn.functional as F
from lib import seg_dvgo as dvgo
from lib import seg_dcvgo as dcvgo
from .load_data import load_data
from .masked_adam import MaskedAdam
from torch import Tensor
''' Misc
'''
mse2psnr = lambda x : -10. * torch.log10(x)
to8b = lambda x : (255*np.clip(x,0,1)).astype(np.uint8)
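# Illustrative usage (assumes renders and ground truth are normalized to
# [0, 1], so PSNR reduces to -10 * log10(MSE)):
#   mse = F.mse_loss(render, gt)         # e.g. tensor(0.0100)
#   psnr = mse2psnr(mse)                 # -10 * log10(0.01) = 20.0 dB
#   frame = to8b(render.cpu().numpy())   # uint8 image ready to write to disk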
def seed_everything(args):
    '''Seed everything for better reproducibility.
    NOTE that some PyTorch operations are non-deterministic, e.g. the backward
    pass of grid_sample.
    '''
torch.manual_seed(args.seed)
np.random.seed(args.seed)
random.seed(args.seed)
@torch.jit.script
def cal_IoU(a: Tensor, b: Tensor) -> Tensor:
"""Calculates the Intersection over Union (IoU) between two tensors.
Args:
a: A tensor of shape (N, H, W).
b: A tensor of shape (N, H, W).
    Returns:
        A scalar tensor containing the IoU computed jointly over all elements
        of a and b (label 0 is treated as background).
"""
intersection = torch.count_nonzero(torch.logical_and(a == b, a != 0))
union = torch.count_nonzero(a + b)
return intersection / union
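# Illustrative check (hypothetical 1x2x2 integer masks; label 0 is background):
#   pred = torch.tensor([[[0, 1], [1, 0]]])
#   gt = torch.tensor([[[0, 1], [0, 0]]])
#   cal_IoU(pred, gt)  # intersection 1, union 2 -> tensor(0.5000)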
def load_everything(args, cfg):
'''Load images / poses / camera settings / data split.
'''
data_dict = load_data(cfg.data)
    # Drop fields that are not used downstream.
kept_keys = {
'hwf', 'HW', 'Ks', 'near', 'far', 'near_clip',
'i_train', 'i_val', 'i_test', 'irregular_shape',
'poses', 'render_poses', 'images'}
for k in list(data_dict.keys()):
if k not in kept_keys:
data_dict.pop(k)
# construct data tensor
if data_dict['irregular_shape']:
data_dict['images'] = [torch.FloatTensor(im, device='cpu') for im in data_dict['images']]
else:
data_dict['images'] = torch.FloatTensor(data_dict['images'], device='cpu')
data_dict['poses'] = torch.Tensor(data_dict['poses'])
data_dict['render_poses'] = torch.Tensor(data_dict['render_poses'])
return data_dict
# Semantic NeRF is used to reproduce the segmentation of SPIn-NeRF.
def load_existed_model(args, cfg, cfg_train, reload_ckpt_path, device):
model_class = find_model(cfg)
model = load_model(model_class, reload_ckpt_path).to(device)
optimizer = create_optimizer_or_freeze_model(model, cfg_train, global_step=0)
    model, optimizer, start = load_checkpoint(
        model, optimizer, reload_ckpt_path, no_reload_optimizer=True)
return model, optimizer, start
def gen_rand_colors(num_obj):
rand_colors = np.random.rand(num_obj + 1, 3)
rand_colors[-1,:] = 0
return rand_colors
def to_cuda(batch, device=torch.device('cuda')):
if isinstance(batch, tuple) or isinstance(batch, list):
batch = [to_cuda(b, device) for b in batch]
elif isinstance(batch, dict):
batch_ = {}
for key in batch:
if key == 'meta':
batch_[key] = batch[key]
else:
batch_[key] = to_cuda(batch[key], device)
batch = batch_
elif isinstance(batch, np.ndarray):
batch = torch.from_numpy(batch).to(device)
else:
batch = batch.to(device)
return batch
def to_tensor(array, device=torch.device('cuda')):
    '''Convert a numpy array or CPU tensor to a float tensor on `device`;
    tensors already on CUDA are only cast to float.
'''
if isinstance(array, np.ndarray):
array = torch.from_numpy(array).to(device)
elif isinstance(array, torch.Tensor) and not array.is_cuda:
array = array.to(device)
    return array.float()
''' optimizer
'''
def create_optimizer_or_freeze_model(model, cfg_train, global_step):
decay_steps = cfg_train.lrate_decay * 1000
decay_factor = 0.1 ** (global_step/decay_steps)
param_group = []
for k in cfg_train.keys():
if not k.startswith('lrate_'):
continue
k = k[len('lrate_'):]
if not hasattr(model, k):
continue
param = getattr(model, k)
if param is None:
            print(f'create_optimizer_or_freeze_model: param {k} is None, skip')
continue
lr = getattr(cfg_train, f'lrate_{k}') * decay_factor
if lr > 0:
print(f'create_optimizer_or_freeze_model: param {k} lr {lr}')
if isinstance(param, nn.Module):
param = param.parameters()
param_group.append({'params': param, 'lr': lr, 'skip_zero_grad': (k in cfg_train.skip_zero_grad_fields)})
else:
print(f'create_optimizer_or_freeze_model: param {k} freeze')
param.requires_grad = False
return MaskedAdam(param_group)
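# Illustrative cfg_train layout consumed by the loop above (the exact field
# names are an assumption in the style of DVGO configs, not guaranteed here):
#   cfg_train.lrate_density = 1e-1               # lr for model.density
#   cfg_train.lrate_k0 = 1e-1                    # lr for model.k0
#   cfg_train.lrate_rgbnet = 1e-3                # lr for model.rgbnet (nn.Module)
#   cfg_train.skip_zero_grad_fields = ['density', 'k0']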
def create_segmentation_optimizer(model, cfg_train):
param_group = []
for k in cfg_train.keys():
if not k.startswith('lrate_'):
continue
k = k[len('lrate_'):]
if not hasattr(model, k):
continue
param = getattr(model, k)
if param is None:
            print(f'create_segmentation_optimizer: param {k} is None, skip')
continue
lr = getattr(cfg_train, f'lrate_{k}')
if lr > 0:
            print(f'create_segmentation_optimizer: param {k} lr {lr}')
if isinstance(param, nn.Module):
param = param.parameters()
param_group.append({'params': param, 'lr': lr})
else:
            print(f'create_segmentation_optimizer: param {k} freeze')
param.requires_grad = False
return torch.optim.SGD(param_group)
''' Checkpoint utils
'''
def load_checkpoint(model, optimizer, ckpt_path, no_reload_optimizer):
ckpt = torch.load(ckpt_path)
    start = ckpt.get('global_step', 0)
    msg = model.load_state_dict(ckpt['model_state_dict'], strict=False)
print("NeRF loaded with msg: ", msg)
if not no_reload_optimizer:
optimizer.load_state_dict(ckpt['optimizer_state_dict'])
return model, optimizer, start
def find_model(cfg):
    if cfg.data.ndc:
        # NDC scenes currently fall back to the dense voxel grid as well
        # (the MPI variant is commented out in create_new_model below).
        model_class = dvgo.DirectVoxGO
elif cfg.data.unbounded_inward:
model_class = dcvgo.DirectContractedVoxGO
else:
model_class = dvgo.DirectVoxGO
return model_class
def load_model(model_class, ckpt_path):
ckpt = torch.load(ckpt_path)
num_objects = 1
if 'seg_mask_grid.grid' in ckpt['model_state_dict'].keys():
num_objects = ckpt['model_state_dict']['seg_mask_grid.grid'].shape[1]
print("Load model with num_objects =", num_objects)
    model = model_class(num_objects=num_objects, **ckpt['model_kwargs'])
    msg = model.load_state_dict(ckpt['model_state_dict'], strict=False)
print("NeRF loaded with msg: ", msg)
return model
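# Illustrative checkpoint layout assumed by load_checkpoint/load_model above
# (keys inferred from the reads in this file):
#   ckpt = {'global_step': int, 'model_kwargs': dict,
#           'model_state_dict': OrderedDict, 'optimizer_state_dict': dict}
# where 'model_state_dict' may contain 'seg_mask_grid.grid' with num_objects
# along dimension 1 once segmentation has been trained.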
def create_new_model(cfg, cfg_model, cfg_train, xyz_min, xyz_max, stage, coarse_ckpt_path, device=torch.device('cuda')):
model_kwargs = copy.deepcopy(cfg_model)
num_voxels = model_kwargs.pop('num_voxels')
if len(cfg_train.pg_scale):
num_voxels = int(num_voxels / (2**len(cfg_train.pg_scale)))
if cfg.data.ndc:
#print(f'scene_rep_reconstruction ({stage}): \033[96muse multiplane images\033[0m')
#model = dmpigo.DirectMPIGO(
# xyz_min=xyz_min, xyz_max=xyz_max,
# num_voxels=num_voxels,
# **model_kwargs)
print(f'scene_rep_reconstruction ({stage}): \033[96muse dense voxel grid\033[0m')
model = dvgo.DirectVoxGO(
xyz_min=xyz_min, xyz_max=xyz_max,
num_voxels=num_voxels,
mask_cache_path=coarse_ckpt_path,
**model_kwargs)
elif cfg.data.unbounded_inward:
        print(f'scene_rep_reconstruction ({stage}): \033[96muse contracted voxel grid (covering unbounded)\033[0m')
model = dcvgo.DirectContractedVoxGO(
xyz_min=xyz_min, xyz_max=xyz_max,
num_voxels=num_voxels,
**model_kwargs)
else:
print(f'scene_rep_reconstruction ({stage}): \033[96muse dense voxel grid\033[0m')
model = dvgo.DirectVoxGO(
xyz_min=xyz_min, xyz_max=xyz_max,
num_voxels=num_voxels,
mask_cache_path=coarse_ckpt_path,
**model_kwargs)
model = model.to(device)
optimizer = create_optimizer_or_freeze_model(model, cfg_train, global_step=0)
return model, optimizer
''' Evaluation metrics (ssim, lpips)
'''
def rgb_ssim(img0, img1, max_val,
filter_size=11,
filter_sigma=1.5,
k1=0.01,
k2=0.03,
return_map=False):
# Modified from https://github.com/google/mipnerf/blob/16e73dfdb52044dcceb47cda5243a686391a6e0f/internal/math.py#L58
assert len(img0.shape) == 3
assert img0.shape[-1] == 3
assert img0.shape == img1.shape
# Construct a 1D Gaussian blur filter.
hw = filter_size // 2
shift = (2 * hw - filter_size + 1) / 2
f_i = ((np.arange(filter_size) - hw + shift) / filter_sigma)**2
filt = np.exp(-0.5 * f_i)
filt /= np.sum(filt)
# Blur in x and y (faster than the 2D convolution).
def convolve2d(z, f):
return scipy.signal.convolve2d(z, f, mode='valid')
filt_fn = lambda z: np.stack([
convolve2d(convolve2d(z[...,i], filt[:, None]), filt[None, :])
for i in range(z.shape[-1])], -1)
mu0 = filt_fn(img0)
mu1 = filt_fn(img1)
mu00 = mu0 * mu0
mu11 = mu1 * mu1
mu01 = mu0 * mu1
sigma00 = filt_fn(img0**2) - mu00
sigma11 = filt_fn(img1**2) - mu11
sigma01 = filt_fn(img0 * img1) - mu01
# Clip the variances and covariances to valid values.
# Variance must be non-negative:
sigma00 = np.maximum(0., sigma00)
sigma11 = np.maximum(0., sigma11)
sigma01 = np.sign(sigma01) * np.minimum(
np.sqrt(sigma00 * sigma11), np.abs(sigma01))
c1 = (k1 * max_val)**2
c2 = (k2 * max_val)**2
numer = (2 * mu01 + c1) * (2 * sigma01 + c2)
denom = (mu00 + mu11 + c1) * (sigma00 + sigma11 + c2)
ssim_map = numer / denom
ssim = np.mean(ssim_map)
return ssim_map if return_map else ssim
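# Illustrative usage (hypothetical HxWx3 float arrays in [0, 1], hence
# max_val=1.0):
#   score = rgb_ssim(img_pred, img_gt, max_val=1.0)    # scalar, higher is better
#   ssim_map = rgb_ssim(img_pred, img_gt, 1.0, return_map=True)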
__LPIPS__ = {}
def init_lpips(net_name, device):
assert net_name in ['alex', 'vgg']
import lpips
print(f'init_lpips: lpips_{net_name}')
return lpips.LPIPS(net=net_name, version='0.1').eval().to(device)
def rgb_lpips(np_gt, np_im, net_name, device):
if net_name not in __LPIPS__:
__LPIPS__[net_name] = init_lpips(net_name, device)
gt = torch.from_numpy(np_gt).permute([2, 0, 1]).contiguous().to(device)
im = torch.from_numpy(np_im).permute([2, 0, 1]).contiguous().to(device)
return __LPIPS__[net_name](gt, im, normalize=True).item()
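# Illustrative usage (hypothetical float32 HxWx3 arrays in [0, 1]; normalize=True
# tells LPIPS the inputs are in [0, 1] rather than [-1, 1]):
#   d = rgb_lpips(gt_np, pred_np, net_name='alex', device='cuda')  # lower is better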
''' generate rays
'''
def get_rays(H, W, K, c2w, inverse_y, flip_x, flip_y, mode='center'):
    i, j = torch.meshgrid(
        torch.linspace(0, W-1, W, device=c2w.device),
        torch.linspace(0, H-1, H, device=c2w.device),
        indexing='ij')  # make the historical 'ij' behavior explicit (torch >= 1.10)
i = i.t().float()
j = j.t().float()
if mode == 'lefttop':
pass
elif mode == 'center':
i, j = i+0.5, j+0.5
elif mode == 'random':
i = i+torch.rand_like(i)
j = j+torch.rand_like(j)
else:
raise NotImplementedError
if flip_x:
i = i.flip((1,))
if flip_y:
j = j.flip((0,))
if inverse_y:
dirs = torch.stack([(i-K[0][2])/K[0][0], (j-K[1][2])/K[1][1], torch.ones_like(i)], -1)
else:
dirs = torch.stack([(i-K[0][2])/K[0][0], -(j-K[1][2])/K[1][1], -torch.ones_like(i)], -1)
# Rotate ray directions from camera frame to the world frame
rays_d = torch.sum(dirs[..., np.newaxis, :] * c2w[:3,:3], -1) # dot product, equals to: [c2w.dot(dir) for dir in dirs]
# Translate camera frame's origin to the world frame. It is the origin of all rays.
rays_o = c2w[:3,3].expand(rays_d.shape)
return rays_o, rays_d
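# Illustrative call (hypothetical pinhole intrinsics; K is indexed above as
# [[fx, 0, cx], [0, fy, cy], [0, 0, 1]]):
#   K = torch.tensor([[500., 0., 320.], [0., 500., 240.], [0., 0., 1.]])
#   c2w = torch.eye(4)[:3]  # 3x4 camera-to-world pose
#   rays_o, rays_d = get_rays(480, 640, K, c2w, inverse_y=False,
#                             flip_x=False, flip_y=False)  # each (480, 640, 3)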
def ndc_rays(H, W, focal, near, rays_o, rays_d):
# Shift ray origins to near plane
t = -(near + rays_o[...,2]) / rays_d[...,2]
rays_o = rays_o + t[...,None] * rays_d
# Projection
o0 = -1./(W/(2.*focal)) * rays_o[...,0] / rays_o[...,2]
o1 = -1./(H/(2.*focal)) * rays_o[...,1] / rays_o[...,2]
o2 = 1. + 2. * near / rays_o[...,2]
d0 = -1./(W/(2.*focal)) * (rays_d[...,0]/rays_d[...,2] - rays_o[...,0]/rays_o[...,2])
d1 = -1./(H/(2.*focal)) * (rays_d[...,1]/rays_d[...,2] - rays_o[...,1]/rays_o[...,2])
d2 = -2. * near / rays_o[...,2]
rays_o = torch.stack([o0,o1,o2], -1)
rays_d = torch.stack([d0,d1,d2], -1)
return rays_o, rays_d
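# The stack above is the standard NeRF NDC mapping: with ray origins shifted to
# the z = -near plane, o' = (-2*focal/W * ox/oz, -2*focal/H * oy/oz, 1 + 2*near/oz)
# maps the camera frustum into the [-1, 1]^3 cube, and d' is chosen so that
# o' + t'*d' traces the same projected ray, with t' = 1 reached at z = -infinity.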
def get_rays_of_a_view(H, W, K, c2w, ndc, inverse_y, flip_x, flip_y, mode='center'):
rays_o, rays_d = get_rays(H, W, K, c2w, inverse_y=inverse_y, flip_x=flip_x, flip_y=flip_y, mode=mode)
viewdirs = rays_d / rays_d.norm(dim=-1, keepdim=True)
if ndc:
rays_o, rays_d = ndc_rays(H, W, K[0][0], 1., rays_o, rays_d)
return rays_o, rays_d, viewdirs
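# Illustrative per-view usage (cfg.data.ndc mirrors how ndc is read elsewhere
# in this file; inverse_y/flip_x/flip_y are assumed config fields):
#   rays_o, rays_d, viewdirs = get_rays_of_a_view(
#       H, W, K, c2w, ndc=cfg.data.ndc, inverse_y=cfg.data.inverse_y,
#       flip_x=cfg.data.flip_x, flip_y=cfg.data.flip_y, mode='center')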
''' Interactive mode (TODO)
'''
def fetch_user_define_points():
    '''Placeholder (TODO): fetch user-defined points for interactive mode.'''
    pass