code / SparseNeuS_demo_v1 /models /trainer_generic.py
Chao Xu
code pruning
216282e
raw
history blame
56.5 kB
"""
Decouple the trainer from the renderer.
"""
import os
import cv2 as cv
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import trimesh
from icecream import ic
from utils.misc_utils import visualize_depth_numpy
from loss.depth_metric import compute_depth_errors
from loss.depth_loss import DepthLoss, DepthSmoothLoss
from models.sparse_neus_renderer import SparseNeuSRenderer
class GenericTrainer(nn.Module):
def __init__(self,
             rendering_network_outside,
             pyramid_feature_network_lod0,
             pyramid_feature_network_lod1,
             sdf_network_lod0,
             sdf_network_lod1,
             variance_network_lod0,
             variance_network_lod1,
             rendering_network_lod0,
             rendering_network_lod1,
             n_samples_lod0,
             n_importance_lod0,
             n_samples_lod1,
             n_importance_lod1,
             n_outside,
             perturb,
             alpha_type='div',
             conf=None,
             timestamp="",
             mode='train',
             base_exp_dir=None,
             ):
    """Store the per-LoD networks, build one renderer per LoD and read options.

    Args:
        rendering_network_outside: network passed as first argument to both
            renderers.
        pyramid_feature_network_lod0: 2D pyramid feature network for geometry
            used at lod 0.
        pyramid_feature_network_lod1: separate 2D pyramid feature network used
            at lod 1.
        sdf_network_lod0, sdf_network_lod1: SDF networks, one per LoD.
        variance_network_lod0, variance_network_lod1: variance networks, one
            per LoD.
        rendering_network_lod0, rendering_network_lod1: rendering networks,
            one per LoD.
        n_samples_lod0, n_importance_lod0: sample counts forwarded to the
            lod-0 renderer.
        n_samples_lod1, n_importance_lod1: sample counts forwarded to the
            lod-1 renderer.
        n_outside: forwarded to both renderers.
        perturb: forwarded to both renderers.
        alpha_type: forwarded to both renderers (previously stored but
            ignored in favour of a hard-coded 'div').
        conf: configuration object queried through ``get_float``,
            ``get_int`` and ``get_bool``.  Although the default is ``None``,
            a real object is required because it is queried immediately.
        timestamp: stored on the instance.
        mode: accepted for interface compatibility; not used here.
        base_exp_dir: stored on the instance; output directories are created
            below it by the validation helpers.
    """
    super(GenericTrainer, self).__init__()

    self.conf = conf
    self.timestamp = timestamp
    self.base_exp_dir = base_exp_dir

    # annealing windows for the loss-weight schedule (0.0 when not configured)
    self.anneal_start = self.conf.get_float('train.anneal_start', default=0.0)
    self.anneal_end = self.conf.get_float('train.anneal_end', default=0.0)
    self.anneal_start_lod1 = self.conf.get_float('train.anneal_start_lod1', default=0.0)
    self.anneal_end_lod1 = self.conf.get_float('train.anneal_end_lod1', default=0.0)

    # network setups
    self.rendering_network_outside = rendering_network_outside
    self.pyramid_feature_network_geometry_lod0 = pyramid_feature_network_lod0  # 2D pyramid feature network for geometry
    self.pyramid_feature_network_geometry_lod1 = pyramid_feature_network_lod1  # use different networks for the two lods

    # when num_lods==2, may consume too much memory
    self.sdf_network_lod0 = sdf_network_lod0
    self.sdf_network_lod1 = sdf_network_lod1

    # - wrapped by ModuleList to support DataParallel
    self.variance_network_lod0 = variance_network_lod0
    self.variance_network_lod1 = variance_network_lod1

    self.rendering_network_lod0 = rendering_network_lod0
    self.rendering_network_lod1 = rendering_network_lod1

    self.n_samples_lod0 = n_samples_lod0
    self.n_importance_lod0 = n_importance_lod0
    self.n_samples_lod1 = n_samples_lod1
    self.n_importance_lod1 = n_importance_lod1
    self.n_outside = n_outside
    self.num_lods = self.conf.get_int('model.num_lods')  # the number of octree lods
    self.perturb = perturb
    self.alpha_type = alpha_type

    # - the two renderers; both now honour the caller-supplied alpha_type
    self.sdf_renderer_lod0 = SparseNeuSRenderer(
        self.rendering_network_outside,
        self.sdf_network_lod0,
        self.variance_network_lod0,
        self.rendering_network_lod0,
        self.n_samples_lod0,
        self.n_importance_lod0,
        self.n_outside,
        self.perturb,
        alpha_type=self.alpha_type,
        conf=self.conf)

    self.sdf_renderer_lod1 = SparseNeuSRenderer(
        self.rendering_network_outside,
        self.sdf_network_lod1,
        self.variance_network_lod1,
        self.rendering_network_lod1,
        self.n_samples_lod1,
        self.n_importance_lod1,
        self.n_outside,
        self.perturb,
        alpha_type=self.alpha_type,
        conf=self.conf)

    # no default: this option must be present in the configuration
    self.if_fix_lod0_networks = self.conf.get_bool('train.if_fix_lod0_networks')

    # sdf network weights
    self.sdf_igr_weight = self.conf.get_float('train.sdf_igr_weight')
    self.sdf_sparse_weight = self.conf.get_float('train.sdf_sparse_weight', default=0)
    self.sdf_decay_param = self.conf.get_float('train.sdf_decay_param', default=100)
    self.fg_bg_weight = self.conf.get_float('train.fg_bg_weight', default=0.00)
    self.bg_ratio = self.conf.get_float('train.bg_ratio', default=0.0)

    self.depth_criterion = DepthLoss()

    # - DataParallel mode, cannot modify attributes in forward()
    # self.iter_step = 0
    self.val_mesh_freq = self.conf.get_int('train.val_mesh_freq')

    # - True for finetuning; False for general training
    self.if_fitted_rendering = self.conf.get_bool('train.if_fitted_rendering', default=False)

    self.prune_depth_filter = self.conf.get_bool('model.prune_depth_filter', default=False)
def get_trainable_params(self):
    """Gather the parameters to optimise and cache them on the instance.

    The lod-0 networks contribute only when ``if_fix_lod0_networks`` is
    false; the lod-1 networks contribute whenever ``sdf_network_lod1`` is
    set.  A rendering network is included only if it is not ``None``.

    :return: the list stored in ``self.params_to_train``
    """
    collected = self.params_to_train = []

    def _absorb(*networks):
        # append every parameter of each network, keeping the given order
        for network in networks:
            collected.extend(network.parameters())

    lod0_is_trainable = not self.if_fix_lod0_networks
    if lod0_is_trainable:
        _absorb(self.pyramid_feature_network_geometry_lod0,
                self.sdf_network_lod0,
                self.variance_network_lod0)
        if self.rendering_network_lod0 is not None:
            _absorb(self.rendering_network_lod0)

    if self.sdf_network_lod1 is not None:
        _absorb(self.pyramid_feature_network_geometry_lod1,
                self.sdf_network_lod1,
                self.variance_network_lod1)
        if self.rendering_network_lod1 is not None:
            _absorb(self.rendering_network_lod1)

    return self.params_to_train
def train_step(self, sample,
perturb_overwrite=-1,
background_rgb=None,
alpha_inter_ratio_lod0=0.0,
alpha_inter_ratio_lod1=0.0,
iter_step=0,
):
# Run one training iteration on `sample` and return the per-LoD losses.
#
# sample: dict; the keys read here are 'batch_idx', 'meta', 'img_wh',
#   'partial_vol_origin', 'near_fars', 'rays', 'images', 'intrinsics',
#   'w2cs', 'c2ws', 'affine_mats', 'scale_mat' and 'trans_mat'.
# perturb_overwrite: accepted but not used in this method.
# background_rgb: forwarded unchanged to the renderers.
# alpha_inter_ratio_lod0 / alpha_inter_ratio_lod1: forwarded to the lod-0 /
#   lod-1 renderer as `alpha_inter_ratio`.
# iter_step: passed to cal_losses_sdf and used to decide when a mesh is
#   extracted (every `val_mesh_freq` steps).
# Returns a dict with 'loss_lodX', 'losses_lodX' and 'depth_statis_lodX' for
#   X in {0, 1}; entries stay None for a LoD that was not rendered.
#
# NOTE(review): this listing carries no indentation, so the exact nesting of
# the branches below cannot be confirmed from here.
#
# * only support batch_size==1
# ! attention: the list of string cannot be splited in DataParallel
batch_idx = sample['batch_idx'][0]
meta = sample['meta'][batch_idx] # the scan lighting ref_view info
sizeW = sample['img_wh'][0][0]
sizeH = sample['img_wh'][0][1]
partial_vol_origin = sample['partial_vol_origin'] # [B, 3]
# near / far are kept as length-1 slices rather than scalars
near, far = sample['near_fars'][0, 0, :1], sample['near_fars'][0, 0, 1:]
# the full-size ray variables
sample_rays = sample['rays']
rays_o = sample_rays['rays_o'][0]
rays_d = sample_rays['rays_v'][0]
imgs = sample['images'][0]
intrinsics = sample['intrinsics'][0]
# copy of the intrinsics with the first two rows scaled by 0.25; it is used
# below together with sizeH // 4 and sizeW // 4 for depth extraction
intrinsics_l_4x = intrinsics.clone()
intrinsics_l_4x[:, :2] *= 0.25
w2cs = sample['w2cs'][0]
c2ws = sample['c2ws'][0]
proj_matrices = sample['affine_mats']
scale_mat = sample['scale_mat']
trans_mat = sample['trans_mat']
# *********************** Lod==0 ***********************
# Both branches build the lod-0 conditional volume; index 0 is dropped from
# the feature maps and from the projection matrices (presumably the view
# axis - TODO confirm).  The frozen-lod0 branch runs under torch.no_grad().
if not self.if_fix_lod0_networks:
geometry_feature_maps = self.obtain_pyramid_feature_maps(imgs)
conditional_features_lod0 = self.sdf_network_lod0.get_conditional_volume(
feature_maps=geometry_feature_maps[None, 1:, :, :, :],
partial_vol_origin=partial_vol_origin,
proj_mats=proj_matrices[:,1:],
# proj_mats=proj_matrices,
sizeH=sizeH,
sizeW=sizeW,
lod=0,
)
else:
with torch.no_grad():
geometry_feature_maps = self.obtain_pyramid_feature_maps(imgs, lod=0)
# geometry_feature_maps = self.obtain_pyramid_feature_maps(imgs, lod=0)
conditional_features_lod0 = self.sdf_network_lod0.get_conditional_volume(
feature_maps=geometry_feature_maps[None, 1:, :, :, :],
partial_vol_origin=partial_vol_origin,
proj_mats=proj_matrices[:,1:],
# proj_mats=proj_matrices,
sizeH=sizeH,
sizeW=sizeW,
lod=0,
)
# print("Checker2:, construct cost volume")
con_volume_lod0 = conditional_features_lod0['dense_volume_scale0']
con_valid_mask_volume_lod0 = conditional_features_lod0['valid_mask_volume_scale0']
coords_lod0 = conditional_features_lod0['coords_scale0'] # [1,3,wX,wY,wZ]
# * extract depth maps for all the images
# The lod-0 SDF volume is only needed to select voxels for lod 1.
depth_maps_lod0, depth_masks_lod0 = None, None
if self.num_lods > 1:
sdf_volume_lod0 = self.sdf_network_lod0.get_sdf_volume(
con_volume_lod0, con_valid_mask_volume_lod0,
coords_lod0, partial_vol_origin) # [1, 1, dX, dY, dZ]
# Depth maps are extracted at quarter resolution (with near scaled by 1.5)
# and then resized back to (sizeH, sizeW).
if self.prune_depth_filter:
depth_maps_lod0_l4x, depth_masks_lod0_l4x = self.sdf_renderer_lod0.extract_depth_maps(
self.sdf_network_lod0, sdf_volume_lod0, intrinsics_l_4x, c2ws,
sizeH // 4, sizeW // 4, near * 1.5, far)
depth_maps_lod0 = F.interpolate(depth_maps_lod0_l4x, size=(sizeH, sizeW), mode='bilinear',
align_corners=True)
depth_masks_lod0 = F.interpolate(depth_masks_lod0_l4x.float(), size=(sizeH, sizeW), mode='nearest')
# *************** losses
# lod 0 is rendered and scored only when its networks are being trained
loss_lod0, losses_lod0, depth_statis_lod0 = None, None, None
if not self.if_fix_lod0_networks:
render_out = self.sdf_renderer_lod0.render(
rays_o, rays_d, near, far,
self.sdf_network_lod0,
self.rendering_network_lod0,
background_rgb=background_rgb,
alpha_inter_ratio=alpha_inter_ratio_lod0,
# * related to conditional feature
lod=0,
conditional_volume=con_volume_lod0,
conditional_valid_mask_volume=con_valid_mask_volume_lod0,
# * 2d feature maps
feature_maps=geometry_feature_maps,
color_maps=imgs,
w2cs=w2cs,
intrinsics=intrinsics,
img_wh=[sizeW, sizeH],
if_general_rendering=True,
if_render_with_grad=True,
)
loss_lod0, losses_lod0, depth_statis_lod0 = self.cal_losses_sdf(render_out, sample_rays,
iter_step, lod=0)
# *********************** Lod==1 ***********************
loss_lod1, losses_lod1, depth_statis_lod1 = None, None, None
if self.num_lods > 1:
geometry_feature_maps_lod1 = self.obtain_pyramid_feature_maps(imgs, lod=1)
# geometry_feature_maps_lod1 = self.obtain_pyramid_feature_maps(imgs, lod=1)
# Select the lod-0 voxels carried over to lod 1, either with the depth-map
# filter or from the SDF volume alone.
if self.prune_depth_filter:
pre_coords, pre_feats = self.sdf_renderer_lod0.get_valid_sparse_coords_by_sdf_depthfilter(
sdf_volume_lod0[0], coords_lod0[0], con_valid_mask_volume_lod0[0], con_volume_lod0[0],
depth_maps_lod0, proj_matrices[0],
partial_vol_origin, self.sdf_network_lod0.voxel_size,
near, far, self.sdf_network_lod0.voxel_size, 12)
else:
pre_coords, pre_feats = self.sdf_renderer_lod0.get_valid_sparse_coords_by_sdf(
sdf_volume_lod0[0], coords_lod0[0], con_valid_mask_volume_lod0[0], con_volume_lod0[0])
# every column except the first is doubled (presumably mapping lod-0 voxel
# indices onto a grid twice as fine - TODO confirm)
pre_coords[:, 1:] = pre_coords[:, 1:] * 2
# ? It seems that training gru_fusion, this part should be trainable too
conditional_features_lod1 = self.sdf_network_lod1.get_conditional_volume(
feature_maps=geometry_feature_maps_lod1[None, 1:, :, :, :],
partial_vol_origin=partial_vol_origin,
proj_mats=proj_matrices[:,1:],
# proj_mats=proj_matrices,
sizeH=sizeH,
sizeW=sizeW,
pre_coords=pre_coords,
pre_feats=pre_feats,
)
con_volume_lod1 = conditional_features_lod1['dense_volume_scale1']
con_valid_mask_volume_lod1 = conditional_features_lod1['valid_mask_volume_scale1']
# if not self.if_gru_fusion_lod1:
render_out_lod1 = self.sdf_renderer_lod1.render(
rays_o, rays_d, near, far,
self.sdf_network_lod1,
self.rendering_network_lod1,
background_rgb=background_rgb,
alpha_inter_ratio=alpha_inter_ratio_lod1,
# * related to conditional feature
lod=1,
conditional_volume=con_volume_lod1,
conditional_valid_mask_volume=con_valid_mask_volume_lod1,
# * 2d feature maps
feature_maps=geometry_feature_maps_lod1,
color_maps=imgs,
w2cs=w2cs,
intrinsics=intrinsics,
img_wh=[sizeW, sizeH],
bg_ratio=self.bg_ratio,
)
loss_lod1, losses_lod1, depth_statis_lod1 = self.cal_losses_sdf(render_out_lod1, sample_rays,
iter_step, lod=1)
# print("Checker3:, compute losses")
# # - extract mesh
# Periodic mesh dump; the CUDA cache is emptied around the lod-0 extraction.
if iter_step % self.val_mesh_freq == 0:
torch.cuda.empty_cache()
self.validate_mesh(self.sdf_network_lod0,
self.sdf_renderer_lod0.extract_geometry,
conditional_volume=con_volume_lod0, lod=0,
threshold=0,
# occupancy_mask=con_valid_mask_volume_lod0[0, 0],
mode='train_bg', meta=meta,
iter_step=iter_step, scale_mat=scale_mat,
trans_mat=trans_mat)
torch.cuda.empty_cache()
if self.num_lods > 1:
self.validate_mesh(self.sdf_network_lod1,
self.sdf_renderer_lod1.extract_geometry,
conditional_volume=con_volume_lod1, lod=1,
# occupancy_mask=con_valid_mask_volume_lod1[0, 0].detach(),
mode='train_bg', meta=meta,
iter_step=iter_step, scale_mat=scale_mat,
trans_mat=trans_mat)
# Bundle the results of both LoDs for the caller.
losses = {
# - lod 0
'loss_lod0': loss_lod0,
'losses_lod0': losses_lod0,
'depth_statis_lod0': depth_statis_lod0,
# - lod 1
'loss_lod1': loss_lod1,
'losses_lod1': losses_lod1,
'depth_statis_lod1': depth_statis_lod1,
}
return losses
def val_step(self, sample,
perturb_overwrite=-1,
background_rgb=None,
alpha_inter_ratio_lod0=0.0,
alpha_inter_ratio_lod1=0.0,
iter_step=0,
chunk_size=512,
save_vis=False,
):
# Validate on one sample: build the conditional volumes, optionally render
# the query view in ray chunks and save visualisations, and periodically
# extract meshes.  Nothing is returned explicitly.
#
# sample: dict; the keys read here are 'batch_idx', 'meta', 'img_wh',
#   'partial_vol_origin', 'query_near_far', 'rays', 'images', 'intrinsics',
#   'w2cs', 'c2ws', 'affine_mats', 'scale_mat', 'trans_mat', 'query_c2w',
#   'query_w2c', 'query_image', 'scale_factor' and, optionally, 'query_depth'.
# perturb_overwrite: accepted but not used in this method.
# background_rgb: forwarded unchanged to the renderers.
# alpha_inter_ratio_lod0 / alpha_inter_ratio_lod1: forwarded to the lod-0 /
#   lod-1 renderer as `alpha_inter_ratio`.
# iter_step: used in output file names and to decide when meshes are extracted.
# chunk_size: number of rays per rendering chunk.
# save_vis: enables the rendering / visualisation section.
#
# NOTE(review): this listing carries no indentation, so the exact nesting of
# the branches below cannot be confirmed from here.
#
# * only support batch_size==1
# ! attention: the list of string cannot be splited in DataParallel
batch_idx = sample['batch_idx'][0]
meta = sample['meta'][batch_idx] # the scan lighting ref_view info
sizeW = sample['img_wh'][0][0]
sizeH = sample['img_wh'][0][1]
H, W = sizeH, sizeW
partial_vol_origin = sample['partial_vol_origin'] # [B, 3]
# near / far are kept as length-1 slices rather than scalars
near, far = sample['query_near_far'][0, :1], sample['query_near_far'][0, 1:]
# the ray variables
sample_rays = sample['rays']
rays_o = sample_rays['rays_o'][0]
rays_d = sample_rays['rays_v'][0]
# read but not used further in this method
rays_ndc_uv = sample_rays['rays_ndc_uv'][0]
imgs = sample['images'][0]
intrinsics = sample['intrinsics'][0]
# copy of the intrinsics with the first two rows scaled by 0.25; it is used
# below together with sizeH // 4 and sizeW // 4 for depth extraction
intrinsics_l_4x = intrinsics.clone()
intrinsics_l_4x[:, :2] *= 0.25
w2cs = sample['w2cs'][0]
c2ws = sample['c2ws'][0]
proj_matrices = sample['affine_mats']
# render_img_idx = sample['render_img_idx'][0]
# true_img = sample['images'][0][render_img_idx]
# - the image to render
scale_mat = sample['scale_mat'] # [1,4,4] used to convert mesh into true scale
trans_mat = sample['trans_mat']
query_c2w = sample['query_c2w'] # [1,4,4]
query_w2c = sample['query_w2c'] # [1,4,4]
true_img = sample['query_image'][0]
# channel-first tensor -> channel-last uint8 array scaled by 255
true_img = np.uint8(true_img.permute(1, 2, 0).cpu().numpy() * 255)
depth_min, depth_max = near.cpu().numpy(), far.cpu().numpy()
scale_factor = sample['scale_factor'][0].cpu().numpy()
# ground-truth depth is optional; without it no coloured GT depth is made
true_depth = sample['query_depth'] if 'query_depth' in sample.keys() else None
if true_depth is not None:
true_depth = true_depth[0].cpu().numpy()
true_depth_colored = visualize_depth_numpy(true_depth, [depth_min, depth_max])[0]
else:
true_depth_colored = None
# flatten the rays and split them into chunks of `chunk_size`
rays_o = rays_o.reshape(-1, 3).split(chunk_size)
rays_d = rays_d.reshape(-1, 3).split(chunk_size)
# - obtain conditional features
with torch.no_grad():
# - obtain conditional features
geometry_feature_maps = self.obtain_pyramid_feature_maps(imgs, lod=0)
# - lod 0
# all feature maps and projection matrices are passed here (no slicing)
conditional_features_lod0 = self.sdf_network_lod0.get_conditional_volume(
feature_maps=geometry_feature_maps[None, :, :, :, :],
partial_vol_origin=partial_vol_origin,
proj_mats=proj_matrices,
sizeH=sizeH,
sizeW=sizeW,
lod=0,
)
con_volume_lod0 = conditional_features_lod0['dense_volume_scale0']
con_valid_mask_volume_lod0 = conditional_features_lod0['valid_mask_volume_scale0']
coords_lod0 = conditional_features_lod0['coords_scale0'] # [1,3,wX,wY,wZ]
if self.num_lods > 1:
sdf_volume_lod0 = self.sdf_network_lod0.get_sdf_volume(
con_volume_lod0, con_valid_mask_volume_lod0,
coords_lod0, partial_vol_origin) # [1, 1, dX, dY, dZ]
depth_maps_lod0, depth_masks_lod0 = None, None
# NOTE(review): sdf_volume_lod0 is only assigned under `num_lods > 1`;
# confirm this branch cannot run with a single LoD.
if self.prune_depth_filter:
depth_maps_lod0_l4x, depth_masks_lod0_l4x = self.sdf_renderer_lod0.extract_depth_maps(
self.sdf_network_lod0, sdf_volume_lod0,
intrinsics_l_4x, c2ws,
sizeH // 4, sizeW // 4, near * 1.5, far) # - near*1.5 is a experienced number
depth_maps_lod0 = F.interpolate(depth_maps_lod0_l4x, size=(sizeH, sizeW), mode='bilinear',
align_corners=True)
depth_masks_lod0 = F.interpolate(depth_masks_lod0_l4x.float(), size=(sizeH, sizeW), mode='nearest')
#### visualize the depth_maps_lod0 for checking
# one coloured depth image per entry, stacked along axis 0 and written to
# <base_exp_dir>/depth_maps_lod0/
colored_depth_maps_lod0 = []
for i in range(depth_maps_lod0.shape[0]):
colored_depth_maps_lod0.append(
visualize_depth_numpy(depth_maps_lod0[i, 0].cpu().numpy(), [depth_min, depth_max])[0])
colored_depth_maps_lod0 = np.concatenate(colored_depth_maps_lod0, axis=0).astype(np.uint8)
os.makedirs(os.path.join(self.base_exp_dir, 'depth_maps_lod0'), exist_ok=True)
cv.imwrite(os.path.join(self.base_exp_dir, 'depth_maps_lod0',
'{:0>8d}_{}.png'.format(iter_step, meta)),
colored_depth_maps_lod0[:, :, ::-1])
if self.num_lods > 1:
geometry_feature_maps_lod1 = self.obtain_pyramid_feature_maps(imgs, lod=1)
# Select the lod-0 voxels carried over to lod 1, either with the depth-map
# filter or from the SDF volume alone.
if self.prune_depth_filter:
pre_coords, pre_feats = self.sdf_renderer_lod0.get_valid_sparse_coords_by_sdf_depthfilter(
sdf_volume_lod0[0], coords_lod0[0], con_valid_mask_volume_lod0[0], con_volume_lod0[0],
depth_maps_lod0, proj_matrices[0],
partial_vol_origin, self.sdf_network_lod0.voxel_size,
near, far, self.sdf_network_lod0.voxel_size, 12)
else:
pre_coords, pre_feats = self.sdf_renderer_lod0.get_valid_sparse_coords_by_sdf(
sdf_volume_lod0[0], coords_lod0[0], con_valid_mask_volume_lod0[0], con_volume_lod0[0])
# every column except the first is doubled (presumably mapping lod-0 voxel
# indices onto a grid twice as fine - TODO confirm)
pre_coords[:, 1:] = pre_coords[:, 1:] * 2
with torch.no_grad():
conditional_features_lod1 = self.sdf_network_lod1.get_conditional_volume(
feature_maps=geometry_feature_maps_lod1[None, :, :, :, :],
partial_vol_origin=partial_vol_origin,
proj_mats=proj_matrices,
sizeH=sizeH,
sizeW=sizeW,
pre_coords=pre_coords,
pre_feats=pre_feats,
)
con_volume_lod1 = conditional_features_lod1['dense_volume_scale1']
con_valid_mask_volume_lod1 = conditional_features_lod1['valid_mask_volume_scale1']
# per-chunk render outputs, collected as numpy arrays
out_rgb_fine = []
out_normal_fine = []
out_depth_fine = []
out_rgb_fine_lod1 = []
out_normal_fine_lod1 = []
out_depth_fine_lod1 = []
# out_depth_fine_explicit = []
if save_vis:
for rays_o_batch, rays_d_batch in zip(rays_o, rays_d):
# ****** lod 0 ****
render_out = self.sdf_renderer_lod0.render(
rays_o_batch, rays_d_batch, near, far,
self.sdf_network_lod0,
self.rendering_network_lod0,
background_rgb=background_rgb,
alpha_inter_ratio=alpha_inter_ratio_lod0,
# * related to conditional feature
lod=0,
conditional_volume=con_volume_lod0,
conditional_valid_mask_volume=con_valid_mask_volume_lod0,
# * 2d feature maps
feature_maps=geometry_feature_maps,
color_maps=imgs,
w2cs=w2cs,
intrinsics=intrinsics,
img_wh=[sizeW, sizeH],
query_c2w=query_c2w,
if_render_with_grad=False,
)
# a key is usable only if it is present and its value is not None
feasible = lambda key: ((key in render_out) and (render_out[key] is not None))
if feasible('depth'):
out_depth_fine.append(render_out['depth'].detach().cpu().numpy())
# if render_out['color_coarse'] is not None:
if feasible('color_fine'):
out_rgb_fine.append(render_out['color_fine'].detach().cpu().numpy())
# normal per ray = sum over samples of gradient * weight, additionally
# multiplied by 'inside_sphere' when that entry is available; only the
# first n_samples + n_importance weights are used
if feasible('gradients') and feasible('weights'):
if render_out['inside_sphere'] is not None:
out_normal_fine.append((render_out['gradients'] * render_out['weights'][:,
:self.n_samples_lod0 + self.n_importance_lod0,
None] * render_out['inside_sphere'][
..., None]).sum(dim=1).detach().cpu().numpy())
else:
out_normal_fine.append((render_out['gradients'] * render_out['weights'][:,
:self.n_samples_lod0 + self.n_importance_lod0,
None]).sum(dim=1).detach().cpu().numpy())
# drop the chunk result before rendering the next one
del render_out
# ****************** lod 1 **************************
# same chunked rendering and collection as above, with the lod-1 networks
if self.num_lods > 1:
for rays_o_batch, rays_d_batch in zip(rays_o, rays_d):
render_out_lod1 = self.sdf_renderer_lod1.render(
rays_o_batch, rays_d_batch, near, far,
self.sdf_network_lod1,
self.rendering_network_lod1,
background_rgb=background_rgb,
alpha_inter_ratio=alpha_inter_ratio_lod1,
# * related to conditional feature
lod=1,
conditional_volume=con_volume_lod1,
conditional_valid_mask_volume=con_valid_mask_volume_lod1,
# * 2d feature maps
feature_maps=geometry_feature_maps_lod1,
color_maps=imgs,
w2cs=w2cs,
intrinsics=intrinsics,
img_wh=[sizeW, sizeH],
query_c2w=query_c2w,
if_render_with_grad=False,
)
feasible = lambda key: ((key in render_out_lod1) and (render_out_lod1[key] is not None))
if feasible('depth'):
out_depth_fine_lod1.append(render_out_lod1['depth'].detach().cpu().numpy())
# if render_out['color_coarse'] is not None:
if feasible('color_fine'):
out_rgb_fine_lod1.append(render_out_lod1['color_fine'].detach().cpu().numpy())
if feasible('gradients') and feasible('weights'):
if render_out_lod1['inside_sphere'] is not None:
out_normal_fine_lod1.append((render_out_lod1['gradients'] * render_out_lod1['weights'][:,
:self.n_samples_lod1 + self.n_importance_lod1,
None] *
render_out_lod1['inside_sphere'][
..., None]).sum(dim=1).detach().cpu().numpy())
else:
out_normal_fine_lod1.append((render_out_lod1['gradients'] * render_out_lod1['weights'][:,
:self.n_samples_lod1 + self.n_importance_lod1,
None]).sum(
dim=1).detach().cpu().numpy())
del render_out_lod1
# - save visualization of lod 0
self.save_visualization(true_img, true_depth_colored, out_depth_fine, out_normal_fine,
query_w2c[0], out_rgb_fine, H, W,
depth_min, depth_max, iter_step, meta, "val_lod0", true_depth=true_depth, scale_factor=scale_factor)
if self.num_lods > 1:
self.save_visualization(true_img, true_depth_colored, out_depth_fine_lod1, out_normal_fine_lod1,
query_w2c[0], out_rgb_fine_lod1, H, W,
depth_min, depth_max, iter_step, meta, "val_lod1", true_depth=true_depth, scale_factor=scale_factor)
# - extract mesh
# Periodic mesh dump; the CUDA cache is emptied around the extractions.
if (iter_step % self.val_mesh_freq == 0):
torch.cuda.empty_cache()
self.validate_mesh(self.sdf_network_lod0,
self.sdf_renderer_lod0.extract_geometry,
conditional_volume=con_volume_lod0, lod=0,
threshold=0,
# occupancy_mask=con_valid_mask_volume_lod0[0, 0],
mode='val_bg', meta=meta,
iter_step=iter_step, scale_mat=scale_mat, trans_mat=trans_mat)
torch.cuda.empty_cache()
if self.num_lods > 1:
self.validate_mesh(self.sdf_network_lod1,
self.sdf_renderer_lod1.extract_geometry,
conditional_volume=con_volume_lod1, lod=1,
# occupancy_mask=con_valid_mask_volume_lod1[0, 0].detach(),
mode='val_bg', meta=meta,
iter_step=iter_step, scale_mat=scale_mat, trans_mat=trans_mat)
torch.cuda.empty_cache()
def export_mesh_step(self, sample,
                     perturb_overwrite=-1,
                     background_rgb=None,
                     alpha_inter_ratio_lod0=0.0,
                     alpha_inter_ratio_lod1=0.0,
                     iter_step=0,
                     chunk_size=512,
                     save_vis=False,
                     ):
    """Build the conditional volumes for one sample and export coloured meshes.

    The lod-0 volume is always built; the lod-1 volume is built when
    ``num_lods > 1``.  When ``iter_step`` is a multiple of ``val_mesh_freq``
    the meshes are written through ``validate_colored_mesh``, each LoD using
    the feature maps produced by its own pyramid network.

    Args:
        sample: dict; the keys read are 'batch_idx', 'meta', 'img_wh',
            'partial_vol_origin', 'query_near_far', 'images', 'intrinsics',
            'w2cs', 'affine_mats', 'scale_mat', 'trans_mat' and 'query_c2w'.
        perturb_overwrite, background_rgb, alpha_inter_ratio_lod0,
        alpha_inter_ratio_lod1, chunk_size, save_vis: accepted so the
            signature matches the other step methods; not used here.
        iter_step: forwarded to ``validate_colored_mesh`` and used for the
            export-frequency check.

    Returns:
        None.
    """
    # * only support batch_size==1
    # ! attention: the list of string cannot be split in DataParallel
    batch_idx = sample['batch_idx'][0]
    meta = sample['meta'][batch_idx]  # the scan lighting ref_view info

    sizeW = sample['img_wh'][0][0]
    sizeH = sample['img_wh'][0][1]
    partial_vol_origin = sample['partial_vol_origin']  # [B, 3]
    near, far = sample['query_near_far'][0, :1], sample['query_near_far'][0, 1:]

    imgs = sample['images'][0]
    intrinsics = sample['intrinsics'][0]
    w2cs = sample['w2cs'][0]
    proj_matrices = sample['affine_mats']

    scale_mat = sample['scale_mat']  # [1,4,4]  used to convert mesh into true scale
    trans_mat = sample['trans_mat']
    query_c2w = sample['query_c2w']  # [1,4,4]

    # - lod 0 conditional features
    with torch.no_grad():
        geometry_feature_maps = self.obtain_pyramid_feature_maps(imgs, lod=0)
        conditional_features_lod0 = self.sdf_network_lod0.get_conditional_volume(
            feature_maps=geometry_feature_maps[None, :, :, :, :],
            partial_vol_origin=partial_vol_origin,
            proj_mats=proj_matrices,
            sizeH=sizeH,
            sizeW=sizeW,
            lod=0,
        )

    con_volume_lod0 = conditional_features_lod0['dense_volume_scale0']
    con_valid_mask_volume_lod0 = conditional_features_lod0['valid_mask_volume_scale0']
    coords_lod0 = conditional_features_lod0['coords_scale0']  # [1,3,wX,wY,wZ]

    if self.num_lods > 1:
        sdf_volume_lod0 = self.sdf_network_lod0.get_sdf_volume(
            con_volume_lod0, con_valid_mask_volume_lod0,
            coords_lod0, partial_vol_origin)  # [1, 1, dX, dY, dZ]

        geometry_feature_maps_lod1 = self.obtain_pyramid_feature_maps(imgs, lod=1)

        if self.prune_depth_filter:
            # NOTE(review): no depth maps are extracted in this method, so the
            # filter receives None exactly as before - confirm the callee
            # accepts that, or extract depth maps here first.
            depth_maps_lod0 = None
            pre_coords, pre_feats = self.sdf_renderer_lod0.get_valid_sparse_coords_by_sdf_depthfilter(
                sdf_volume_lod0[0], coords_lod0[0], con_valid_mask_volume_lod0[0], con_volume_lod0[0],
                depth_maps_lod0, proj_matrices[0],
                partial_vol_origin, self.sdf_network_lod0.voxel_size,
                near, far, self.sdf_network_lod0.voxel_size, 12)
        else:
            pre_coords, pre_feats = self.sdf_renderer_lod0.get_valid_sparse_coords_by_sdf(
                sdf_volume_lod0[0], coords_lod0[0], con_valid_mask_volume_lod0[0], con_volume_lod0[0])

        # every column except the first is doubled before being handed to lod 1
        pre_coords[:, 1:] = pre_coords[:, 1:] * 2

        with torch.no_grad():
            conditional_features_lod1 = self.sdf_network_lod1.get_conditional_volume(
                feature_maps=geometry_feature_maps_lod1[None, :, :, :, :],
                partial_vol_origin=partial_vol_origin,
                proj_mats=proj_matrices,
                sizeH=sizeH,
                sizeW=sizeW,
                pre_coords=pre_coords,
                pre_feats=pre_feats,
            )

        con_volume_lod1 = conditional_features_lod1['dense_volume_scale1']
        con_valid_mask_volume_lod1 = conditional_features_lod1['valid_mask_volume_scale1']

    # - extract mesh
    if iter_step % self.val_mesh_freq == 0:
        # arguments that are identical for both LoDs
        shared_kwargs = dict(
            color_maps=imgs,
            w2cs=w2cs,
            target_candidate_w2cs=None,
            intrinsics=intrinsics,
            threshold=0,
            query_c2w=query_c2w,
            mode='val_bg', meta=meta,
            iter_step=iter_step, scale_mat=scale_mat, trans_mat=trans_mat,
        )

        torch.cuda.empty_cache()
        self.validate_colored_mesh(
            density_or_sdf_network=self.sdf_network_lod0,
            func_extract_geometry=self.sdf_renderer_lod0.extract_geometry,
            conditional_volume=con_volume_lod0,
            conditional_valid_mask_volume=con_valid_mask_volume_lod0,
            feature_maps=geometry_feature_maps,
            rendering_network=self.rendering_network_lod0,
            rendering_projector=self.sdf_renderer_lod0.rendering_projector,
            lod=0,
            **shared_kwargs
        )
        torch.cuda.empty_cache()

        if self.num_lods > 1:
            self.validate_colored_mesh(
                density_or_sdf_network=self.sdf_network_lod1,
                func_extract_geometry=self.sdf_renderer_lod1.extract_geometry,
                conditional_volume=con_volume_lod1,
                conditional_valid_mask_volume=con_valid_mask_volume_lod1,
                # fix: the lod-1 rendering network must see lod-1 features,
                # as it does in the training and validation renders
                feature_maps=geometry_feature_maps_lod1,
                rendering_network=self.rendering_network_lod1,
                rendering_projector=self.sdf_renderer_lod1.rendering_projector,
                lod=1,
                **shared_kwargs
            )

    torch.cuda.empty_cache()
def save_visualization(self, true_img, true_colored_depth, out_depth, out_normal, w2cs, out_color, H, W,
                       depth_min, depth_max, iter_step, meta, comment, out_color_mlp=None, true_depth=None,
                       scale_factor=1.0):
    """Assemble rendered chunks into images and write them below base_exp_dir.

    Each of ``out_depth``, ``out_normal``, ``out_color`` and
    ``out_color_mlp`` is a list of per-chunk numpy arrays; an empty list
    skips that output.  Files are named ``'{:0>8d}_{}.png'`` from
    ``iter_step`` and ``meta`` and stored in ``<prefix><comment>`` folders.

    Args:
        true_img: reference image, concatenated next to the predictions.
        true_colored_depth: coloured reference depth, or None.
        out_depth, out_normal, out_color: per-chunk predictions.
        w2cs: matrix whose upper-left 3x3 block rotates the normals.
        H, W: image height and width used to reshape the chunks.
        depth_min, depth_max: range handed to the depth colouring helper.
        iter_step, meta: used in the output file name.
        comment: suffix of the output folder names.
        out_color_mlp: optional per-chunk colours from the MLP branch;
            ``None`` (the default) is treated as an empty list.
        true_depth: reference depth used for the error map, or None.
        scale_factor: divisor applied to the depth error.
    """
    # None sentinel instead of a shared mutable default list
    if out_color_mlp is None:
        out_color_mlp = []

    def _target(prefix):
        # create '<base_exp_dir>/<prefix><comment>' and return the file path in it
        folder = os.path.join(self.base_exp_dir, prefix + comment)
        os.makedirs(folder, exist_ok=True)
        return os.path.join(folder, '{:0>8d}_{}.png'.format(iter_step, meta))

    # --- assemble every image first, so nothing is written if a reshape fails
    if len(out_color) > 0:
        img_fine = (np.concatenate(out_color, axis=0).reshape([H, W, 3]) * 256).clip(0, 255)

    if len(out_color_mlp) > 0:
        img_mlp = (np.concatenate(out_color_mlp, axis=0).reshape([H, W, 3]) * 256).clip(0, 255)

    if len(out_normal) > 0:
        normal_img = np.concatenate(out_normal, axis=0)
        rot = w2cs[:3, :3].detach().cpu().numpy()
        # - convert normal from world space to camera space
        normal_img = (np.matmul(rot[None, :, :],
                                normal_img[:, :, None]).reshape([H, W, 3]) * 128 + 128).clip(0, 255)

    if len(out_depth) > 0:
        pred_depth = np.concatenate(out_depth, axis=0).reshape([H, W])
        pred_depth_colored = visualize_depth_numpy(pred_depth, [depth_min, depth_max])[0]

    # --- depth
    if len(out_depth) > 0:
        depth_path = _target('depths_')
        if true_colored_depth is not None:
            if true_depth is not None:
                depth_error_map = np.abs(true_depth - pred_depth) * 2.0 / scale_factor
                # [256, 256, 1] -> [256, 256, 3]
                depth_error_map = np.tile(depth_error_map[:, :, None], [1, 1, 3])
                print("meta: ", meta)
                print("scale_factor: ", scale_factor)
                print("depth_error_mean: ", depth_error_map.mean())
                # side by side: error map | true depth | predicted depth | image
                depth_visualized = np.concatenate(
                    [(depth_error_map * 255).astype(np.uint8), true_colored_depth, pred_depth_colored, true_img],
                    axis=1)[:, :, ::-1]
            else:
                depth_visualized = np.concatenate(
                    [true_colored_depth, pred_depth_colored, true_img])[:, :, ::-1]
            cv.imwrite(depth_path, depth_visualized)
        else:
            cv.imwrite(depth_path,
                       np.concatenate([pred_depth_colored, true_img])[:, :, ::-1])

    # --- colour from the blending branch, plus PSNR against the reference
    if len(out_color) > 0:
        cv.imwrite(_target('synthesized_color_'),
                   np.concatenate([img_fine, true_img])[:, :, ::-1])  # bgr2rgb
        # compute psnr (image pixel lie in [0, 255])
        mse_loss = np.mean((img_fine - true_img) ** 2)
        psnr = 10 * np.log10(255 ** 2 / mse_loss)
        print("PSNR: ", psnr)

    # --- colour from the MLP branch
    if len(out_color_mlp) > 0:
        cv.imwrite(_target('synthesized_color_mlp_'),
                   np.concatenate([img_mlp, true_img])[:, :, ::-1])  # bgr2rgb

    # --- normals
    if len(out_normal) > 0:
        cv.imwrite(_target('normals_'), normal_img[:, :, ::-1])
def forward(self, sample,
            perturb_overwrite=-1,
            background_rgb=None,
            alpha_inter_ratio_lod0=0.0,
            alpha_inter_ratio_lod1=0.0,
            iter_step=0,
            mode='train',
            save_vis=False,
            ):
    """Dispatch to the step method selected by ``mode``.

    ``'train'`` calls ``train_step``; ``'val'`` and ``'export_mesh'`` call
    ``val_step`` / ``export_mesh_step`` with ``save_vis`` added and print the
    elapsed wall-clock time.  Any other mode yields ``None``.
    """
    # keyword arguments common to all three step methods
    step_kwargs = dict(
        perturb_overwrite=perturb_overwrite,
        background_rgb=background_rgb,
        alpha_inter_ratio_lod0=alpha_inter_ratio_lod0,
        alpha_inter_ratio_lod1=alpha_inter_ratio_lod1,
        iter_step=iter_step,
    )

    if mode == 'train':
        return self.train_step(sample, **step_kwargs)

    if mode == 'val':
        timed_step, label = self.val_step, "val_step time: "
    elif mode == 'export_mesh':
        timed_step, label = self.export_mesh_step, "export mesh time: "
    else:
        # unknown modes fall through without doing anything
        return None

    import time
    started = time.time()
    outcome = timed_step(sample, save_vis=save_vis, **step_kwargs)
    finished = time.time()
    print(label, finished - started)
    return outcome
def obtain_pyramid_feature_maps(self, imgs, lod=0):
    """
    get feature maps of all conditional images
    :param imgs: input passed unchanged to the pyramid feature network
    :param lod: 0 selects the lod-0 network, any value >= 1 the lod-1 network
    :return: the three pyramid levels concatenated along dim 1; level 0 is
             upsampled x4 and level 1 x2 (bilinear) before concatenation
    :raises ValueError: if ``lod`` is negative
    """
    if lod == 0:
        extractor = self.pyramid_feature_network_geometry_lod0
    elif lod >= 1:
        extractor = self.pyramid_feature_network_geometry_lod1
    else:
        # previously this fell through and failed later on an unbound local
        raise ValueError("lod must be a non-negative level, got {!r}".format(lod))

    pyramid_feature_maps = extractor(imgs)

    # * the pyramid features are very important, if only use the coarsest features, hard to optimize
    fused_feature_maps = torch.cat([
        F.interpolate(pyramid_feature_maps[0], scale_factor=4, mode='bilinear', align_corners=True),
        F.interpolate(pyramid_feature_maps[1], scale_factor=2, mode='bilinear', align_corners=True),
        pyramid_feature_maps[2]
    ], dim=1)

    return fused_feature_maps
def cal_losses_sdf(self, render_out, sample_rays, iter_step=-1, lod=0):
    """
    Compute the total loss and the individual loss/statistic terms for one rendering pass.

    :param render_out: dict produced by the renderer; the keys read here are 'color_fine',
                       'color_fine_mask', 'depth', 'gradient_error_fine', 'sdf', 'sdf_random',
                       'color_mlp', 'color_mlp_mask', 'weights_sum', 'weights_sum_fg',
                       'alpha_sum' and 'variance'
    :param sample_rays: dict with 'rays_o', 'rays_color', 'rays_mask' and optionally
                        'rays_depth'; entry [0] along the leading dimension of each is used
    :param iter_step: current iteration, drives the weight annealing; a negative value
                      disables annealing (full weights are used)
    :param lod: level of detail, selects which annealing schedule is applied
    :return: (loss, losses, depth_statis) where ``losses`` is a dict whose entries are
             tensors (or None) and ``depth_statis`` is currently always None
    """

    # loss weight schedule; the regularization terms should be added in later training stage
    def get_weight(iter_step, weight):
        if lod == 1:
            # lod1 ramps up between anneal_end_lod1 and 2 * anneal_end_lod1
            anneal_start = self.anneal_end_lod1
            anneal_end = self.anneal_end_lod1 * 2
        else:
            anneal_start = self.anneal_start if lod == 0 else self.anneal_start_lod1
            anneal_end = (self.anneal_end if lod == 0 else self.anneal_end_lod1) * 2

        if iter_step < 0:
            return weight

        if anneal_end == 0.0:
            return weight
        elif iter_step < anneal_start:
            return 0.0
        else:
            return np.min(
                [1.0,
                 (iter_step - anneal_start) / (anneal_end - anneal_start)]) * weight

    rays_o = sample_rays['rays_o'][0]
    true_rgb = sample_rays['rays_color'][0]

    if 'rays_depth' in sample_rays.keys():
        true_depth = sample_rays['rays_depth'][0]
    else:
        true_depth = None
    mask = sample_rays['rays_mask'][0]

    color_fine = render_out['color_fine']
    color_fine_mask = render_out['color_fine_mask']
    depth_pred = render_out['depth']

    gradient_error_fine = render_out['gradient_error_fine']

    sdf = render_out['sdf']

    # * color generated by mlp
    color_mlp = render_out['color_mlp']
    color_mlp_mask = render_out['color_mlp_mask']

    if color_fine is not None:
        # Color loss
        color_mask = color_fine_mask if color_fine_mask is not None else mask
        color_mask = color_mask[..., 0]
        color_error = (color_fine[color_mask] - true_rgb[color_mask])
        color_fine_loss = F.l1_loss(color_error, torch.zeros_like(color_error),
                                    reduction='mean')
        psnr = 20.0 * torch.log10(
            1.0 / ((color_error ** 2).mean() / (3.0)).sqrt())
    else:
        color_fine_loss = 0.
        psnr = 0.

    if color_mlp is not None:
        # Color loss
        color_mlp_mask = color_mlp_mask[..., 0]
        color_error_mlp = (color_mlp[color_mlp_mask] - true_rgb[color_mlp_mask])
        color_mlp_loss = F.l1_loss(color_error_mlp,
                                   torch.zeros_like(color_error_mlp),
                                   reduction='mean')
        psnr_mlp = 20.0 * torch.log10(
            1.0 / ((color_error_mlp ** 2).mean() / (3.0)).sqrt())
    else:
        color_mlp_loss = 0.
        psnr_mlp = 0.

    # depth loss, computed only when the sample provides 'rays_depth'
    # NOTE(review): an older comment said this term is not part of the total loss,
    # but it is added to `loss` below with weight 1.0 - confirm which is intended
    if true_depth is not None:
        depth_loss = self.depth_criterion(depth_pred, true_depth)
    else:
        depth_loss = 0.
    # depth evaluation statistics are currently disabled
    depth_statis = None

    sparse_loss_1 = torch.exp(
        -1 * torch.abs(render_out['sdf_random']) * self.sdf_decay_param).mean()  # - should equal
    sparse_loss_2 = torch.exp(-1 * torch.abs(sdf) * self.sdf_decay_param).mean()
    sparse_loss = (sparse_loss_1 + sparse_loss_2) / 2

    sdf_mean = torch.abs(sdf).mean()
    sparseness_1 = (torch.abs(sdf) < 0.01).to(torch.float32).mean()
    sparseness_2 = (torch.abs(sdf) < 0.02).to(torch.float32).mean()

    # Eikonal loss
    gradient_error_loss = gradient_error_fine

    # ! the first 50k, don't use bg constraint
    fg_bg_weight = 0.0 if iter_step < 50000 else get_weight(iter_step, self.fg_bg_weight)

    # Mask loss, optional
    # The images of DTU dataset contain large black regions (0 rgb values),
    # can use this data prior to make fg more clean
    background_loss = 0.0
    fg_bg_loss = 0.0
    if self.fg_bg_weight > 0 and torch.mean((mask < 0.5).to(torch.float32)) > 0.02:
        weights_sum_fg = render_out['weights_sum_fg']
        fg_bg_error = (weights_sum_fg - mask)[mask < 0.5]
        fg_bg_loss = F.l1_loss(fg_bg_error,
                               torch.zeros_like(fg_bg_error),
                               reduction='mean')

    sparse_weight = get_weight(iter_step, self.sdf_sparse_weight)

    loss = 1.0 * depth_loss + color_fine_loss + color_mlp_loss + \
           sparse_loss * sparse_weight + \
           fg_bg_loss * fg_bg_weight + \
           gradient_error_loss * self.sdf_igr_weight  # ! gradient_error_loss need a mask

    losses = {
        "loss": loss,
        "depth_loss": depth_loss,
        "color_fine_loss": color_fine_loss,
        "color_mlp_loss": color_mlp_loss,
        "gradient_error_loss": gradient_error_loss,
        "background_loss": background_loss,
        "sparse_loss": sparse_loss,
        "sparseness_1": sparseness_1,
        "sparseness_2": sparseness_2,
        "sdf_mean": sdf_mean,
        "psnr": psnr,
        "psnr_mlp": psnr_mlp,
        "weights_sum": render_out['weights_sum'],
        "weights_sum_fg": render_out['weights_sum_fg'],
        "alpha_sum": render_out['alpha_sum'],
        "variance": render_out['variance'],
        "sparse_weight": sparse_weight,
        "fg_bg_weight": fg_bg_weight,
        "fg_bg_loss": fg_bg_loss,  # added by jha, bug of sparseNeuS
    }

    # torch.tensor() cannot consume a dict, so convert entry by entry:
    # tensors and None pass through, plain python / numpy scalars become tensors
    losses = {
        key: value if (value is None or isinstance(value, torch.Tensor))
        else torch.tensor(value, device=rays_o.device)
        for key, value in losses.items()
    }

    return loss, losses, depth_statis
@torch.no_grad()
def validate_mesh(self, density_or_sdf_network, func_extract_geometry, world_space=True, resolution=360,
                  threshold=0.0, mode='val',
                  # * 3d feature volume
                  conditional_volume=None, lod=None, occupancy_mask=None,
                  bound_min=[-1, -1, -1], bound_max=[1, 1, 1], meta='', iter_step=0, scale_mat=None,
                  trans_mat=None
                  ):
    """
    Extract a mesh with ``func_extract_geometry`` and export it as a .ply file under
    ``<base_exp_dir>/meshes_<mode>/``.

    :param density_or_sdf_network: network handed to ``func_extract_geometry``
    :param func_extract_geometry: callable returning (vertices, triangles, fields)
    :param world_space: not used by this method
    :param resolution: forwarded to ``func_extract_geometry``
    :param threshold: forwarded to ``func_extract_geometry``
    :param mode: suffix of the output directory name
    :param conditional_volume: 3d feature volume; its ``.device`` is used for the extraction
    :param lod: level of detail, forwarded to the extraction and written into the file name
    :param occupancy_mask: forwarded to ``func_extract_geometry``
    :param bound_min: lower corner of the extraction bounds
    :param bound_max: upper corner of the extraction bounds
    :param meta: free-form tag written into the file name
    :param iter_step: iteration number written into the file name
    :param scale_mat: optional; vertices are scaled by entry [0][0, 0] and shifted by [0][:3, 3]
    :param trans_mat: optional transform applied to the homogeneous vertices (w2c_ref_inv)
    """
    lower_corner = torch.tensor(bound_min, dtype=torch.float32)
    upper_corner = torch.tensor(bound_max, dtype=torch.float32)

    # the third output (the field values) is not needed here
    vertices, triangles, _ = func_extract_geometry(
        density_or_sdf_network,
        lower_corner, upper_corner, resolution=resolution,
        threshold=threshold, device=conditional_volume.device,
        # * 3d feature volume
        conditional_volume=conditional_volume, lod=lod,
        occupancy_mask=occupancy_mask
    )

    if scale_mat is not None:
        first_scale_mat = scale_mat.cpu().numpy()[0]
        uniform_scale = first_scale_mat[0, 0]
        translation = first_scale_mat[:3, 3][None]
        vertices = vertices * uniform_scale + translation

    if trans_mat is not None:  # w2c_ref_inv
        transform = trans_mat.cpu().numpy()
        homogeneous_ones = np.ones_like(vertices[:, :1])
        vertices_homo = np.concatenate([vertices, homogeneous_ones], axis=1)
        vertices = np.matmul(transform, vertices_homo[:, :, None])[:, :3, 0]

    mesh = trimesh.Trimesh(vertices, triangles)

    output_dir = os.path.join(self.base_exp_dir, 'meshes_' + mode)
    os.makedirs(output_dir, exist_ok=True)
    file_name = 'mesh_{:0>8d}_{}_lod{:0>1d}.ply'.format(iter_step, meta, lod)
    mesh.export(os.path.join(output_dir, file_name))
def validate_colored_mesh(self, density_or_sdf_network, func_extract_geometry, world_space=True, resolution=360,
                          threshold=0.0, mode='val',
                          # * 3d feature volume
                          conditional_volume=None,
                          conditional_valid_mask_volume=None,
                          feature_maps=None,
                          color_maps = None,
                          w2cs=None,
                          target_candidate_w2cs=None,
                          intrinsics=None,
                          rendering_network=None,
                          rendering_projector=None,
                          query_c2w=None,
                          lod=None, occupancy_mask=None,
                          bound_min=[-1, -1, -1], bound_max=[1, 1, 1], meta='', iter_step=0, scale_mat=None,
                          trans_mat=None
                          ):
    """
    Extract a mesh, predict a color for every vertex and export the colored mesh as a
    .ply file under ``<base_exp_dir>/meshes_<mode>/lod<lod>/``.

    The geometry comes from ``func_extract_geometry``; per-vertex features come from
    ``rendering_projector.compute_view_independent`` and are turned into colors by
    ``rendering_network``.

    :param density_or_sdf_network: network handed to the extraction and to the projector
    :param func_extract_geometry: callable returning (vertices, triangles, fields)
    :param world_space: not used by this method
    :param resolution: forwarded to ``func_extract_geometry``
    :param threshold: forwarded to ``func_extract_geometry``
    :param mode: suffix of the output directory name
    :param conditional_volume: 3d feature volume; entry [0] is passed to the projector
    :param conditional_valid_mask_volume: entry [0] is passed to the projector as the volume mask
    :param feature_maps: 2d rendering feature maps passed to the projector
    :param color_maps: passed to the projector
    :param w2cs: passed to the projector
    :param target_candidate_w2cs: passed to the projector
    :param intrinsics: passed to the projector
    :param rendering_network: callable mapping the projected features to (colors, valid mask)
    :param rendering_projector: object providing ``compute_view_independent``
    :param query_c2w: passed to the projector
    :param lod: level of detail; used for the extraction, the projector and the output path
    :param occupancy_mask: forwarded to ``func_extract_geometry``
    :param bound_min: lower corner of the extraction bounds
    :param bound_max: upper corner of the extraction bounds
    :param meta: not used in the exported file name (the name uses a fixed 'gradio' tag)
    :param iter_step: iteration number written into the file name
    :param scale_mat: optional; vertices are scaled by entry [0][0, 0] and shifted by [0][:3, 3]
    :param trans_mat: optional transform applied to the homogeneous vertices (w2c_ref_inv)
    """
    bound_min = torch.tensor(bound_min, dtype=torch.float32)
    bound_max = torch.tensor(bound_max, dtype=torch.float32)

    # import time
    # jha_begin4 = time.time()
    # `fields` is returned by the extraction but not used below
    vertices, triangles, fields = func_extract_geometry(
        density_or_sdf_network,
        bound_min, bound_max, resolution=resolution,
        threshold=threshold, device=conditional_volume.device,
        # * 3d feature volume
        conditional_volume=conditional_volume, lod=lod,
        occupancy_mask=occupancy_mask
    )
    # jha_end4 = time.time()
    # print("[TEST]: func_extract_geometry time", jha_end4 - jha_begin4)

    # import time
    # jha_begin5 = time.time()
    # NOTE(review): SOURCE indentation was lost; both calls below are assumed to sit inside
    # the no_grad block, since the colors are converted to numpy afterwards - confirm
    with torch.no_grad():
        # query the features at the mesh vertices (moved to conditional_volume's device/dtype)
        ren_geo_feats, ren_rgb_feats, ren_ray_diff, ren_mask, _, _ = rendering_projector.compute_view_independent(
            torch.tensor(vertices).to(conditional_volume),
            lod=lod, # JHA EDITED
            # * 3d geometry feature volumes
            geometryVolume=conditional_volume[0],
            geometryVolumeMask=conditional_valid_mask_volume[0],
            sdf_network=density_or_sdf_network,
            # * 2d rendering feature maps
            rendering_feature_maps=feature_maps,  # [n_view, 56, 256, 256]
            color_maps=color_maps,
            w2cs=w2cs,
            target_candidate_w2cs=target_candidate_w2cs,
            intrinsics=intrinsics,
            img_wh=[256,256],
            query_img_idx=0,  # the index of the N_views dim for rendering
            query_c2w=query_c2w,
        )

        # `rendering_valid_mask` is returned but not used below
        vertices_color, rendering_valid_mask = rendering_network(
            ren_geo_feats, ren_rgb_feats, ren_ray_diff, ren_mask)
    # jha_end5 = time.time()
    # print("[TEST]: rendering_network time", jha_end5 - jha_begin5)

    # the colors were queried at the un-transformed vertices; scale/transform only afterwards
    if scale_mat is not None:
        scale_mat_np = scale_mat.cpu().numpy()
        vertices = vertices * scale_mat_np[0][0, 0] + scale_mat_np[0][:3, 3][None]

    if trans_mat is not None:  # w2c_ref_inv
        trans_mat_np = trans_mat.cpu().numpy()
        vertices_homo = np.concatenate([vertices, np.ones_like(vertices[:, :1])], axis=1)
        vertices = np.matmul(trans_mat_np, vertices_homo[:, :, None])[:, :3, 0]

    # presumably the predicted colors lie in [0, 1]; values outside that range are not
    # clipped before the uint8 cast - TODO confirm
    vertices_color = np.array(vertices_color.squeeze(0).cpu() * 255, dtype=np.uint8)
    mesh = trimesh.Trimesh(vertices, triangles, vertex_colors=vertices_color)
    os.makedirs(os.path.join(self.base_exp_dir, 'meshes_' + mode, 'lod{:0>1d}'.format(lod)), exist_ok=True)
    # mesh.export(os.path.join(self.base_exp_dir, 'meshes_' + mode, 'lod{:0>1d}'.format(lod),
    #                          'mesh_{:0>8d}_{}_lod{:0>1d}.ply'.format(iter_step, meta, lod)))
    # MODIFIED
    mesh.export(os.path.join(self.base_exp_dir, 'meshes_' + mode, 'lod{:0>1d}'.format(lod),
                             'mesh_{:0>8d}_gradio_lod{:0>1d}.ply'.format(iter_step, lod)))