""" decouple the trainer with the renderer """ import os import cv2 as cv import torch import torch.nn as nn import torch.nn.functional as F import numpy as np import trimesh from icecream import ic from utils.misc_utils import visualize_depth_numpy from loss.depth_metric import compute_depth_errors from loss.depth_loss import DepthLoss, DepthSmoothLoss from models.sparse_neus_renderer import SparseNeuSRenderer class GenericTrainer(nn.Module): def __init__(self, rendering_network_outside, pyramid_feature_network_lod0, pyramid_feature_network_lod1, sdf_network_lod0, sdf_network_lod1, variance_network_lod0, variance_network_lod1, rendering_network_lod0, rendering_network_lod1, n_samples_lod0, n_importance_lod0, n_samples_lod1, n_importance_lod1, n_outside, perturb, alpha_type='div', conf=None, timestamp="", mode='train', base_exp_dir=None, ): super(GenericTrainer, self).__init__() self.conf = conf self.timestamp = timestamp self.base_exp_dir = base_exp_dir self.anneal_start = self.conf.get_float('train.anneal_start', default=0.0) self.anneal_end = self.conf.get_float('train.anneal_end', default=0.0) self.anneal_start_lod1 = self.conf.get_float('train.anneal_start_lod1', default=0.0) self.anneal_end_lod1 = self.conf.get_float('train.anneal_end_lod1', default=0.0) # network setups self.rendering_network_outside = rendering_network_outside self.pyramid_feature_network_geometry_lod0 = pyramid_feature_network_lod0 # 2D pyramid feature network for geometry self.pyramid_feature_network_geometry_lod1 = pyramid_feature_network_lod1 # use differnet networks for the two lods # when num_lods==2, may consume too much memeory self.sdf_network_lod0 = sdf_network_lod0 self.sdf_network_lod1 = sdf_network_lod1 # - warpped by ModuleList to support DataParallel self.variance_network_lod0 = variance_network_lod0 self.variance_network_lod1 = variance_network_lod1 self.rendering_network_lod0 = rendering_network_lod0 self.rendering_network_lod1 = rendering_network_lod1 self.n_samples_lod0 = n_samples_lod0 self.n_importance_lod0 = n_importance_lod0 self.n_samples_lod1 = n_samples_lod1 self.n_importance_lod1 = n_importance_lod1 self.n_outside = n_outside self.num_lods = conf.get_int('model.num_lods') # the number of octree lods self.perturb = perturb self.alpha_type = alpha_type # - the two renderers self.sdf_renderer_lod0 = SparseNeuSRenderer( self.rendering_network_outside, self.sdf_network_lod0, self.variance_network_lod0, self.rendering_network_lod0, self.n_samples_lod0, self.n_importance_lod0, self.n_outside, self.perturb, alpha_type='div', conf=self.conf) self.sdf_renderer_lod1 = SparseNeuSRenderer( self.rendering_network_outside, self.sdf_network_lod1, self.variance_network_lod1, self.rendering_network_lod1, self.n_samples_lod1, self.n_importance_lod1, self.n_outside, self.perturb, alpha_type='div', conf=self.conf) self.if_fix_lod0_networks = self.conf.get_bool('train.if_fix_lod0_networks') # sdf network weights self.sdf_igr_weight = self.conf.get_float('train.sdf_igr_weight') self.sdf_sparse_weight = self.conf.get_float('train.sdf_sparse_weight', default=0) self.sdf_decay_param = self.conf.get_float('train.sdf_decay_param', default=100) self.fg_bg_weight = self.conf.get_float('train.fg_bg_weight', default=0.00) self.bg_ratio = self.conf.get_float('train.bg_ratio', default=0.0) self.depth_criterion = DepthLoss() # - DataParallel mode, cannot modify attributes in forward() # self.iter_step = 0 self.val_mesh_freq = self.conf.get_int('train.val_mesh_freq') # - True for finetuning; False for general training self.if_fitted_rendering = self.conf.get_bool('train.if_fitted_rendering', default=False) self.prune_depth_filter = self.conf.get_bool('model.prune_depth_filter', default=False) def get_trainable_params(self): # set trainable params self.params_to_train = [] if not self.if_fix_lod0_networks: # load pretrained featurenet self.params_to_train += list(self.pyramid_feature_network_geometry_lod0.parameters()) self.params_to_train += list(self.sdf_network_lod0.parameters()) self.params_to_train += list(self.variance_network_lod0.parameters()) if self.rendering_network_lod0 is not None: self.params_to_train += list(self.rendering_network_lod0.parameters()) if self.sdf_network_lod1 is not None: # load pretrained featurenet self.params_to_train += list(self.pyramid_feature_network_geometry_lod1.parameters()) self.params_to_train += list(self.sdf_network_lod1.parameters()) self.params_to_train += list(self.variance_network_lod1.parameters()) if self.rendering_network_lod1 is not None: self.params_to_train += list(self.rendering_network_lod1.parameters()) return self.params_to_train def train_step(self, sample, perturb_overwrite=-1, background_rgb=None, alpha_inter_ratio_lod0=0.0, alpha_inter_ratio_lod1=0.0, iter_step=0, ): # * only support batch_size==1 # ! attention: the list of string cannot be splited in DataParallel batch_idx = sample['batch_idx'][0] meta = sample['meta'][batch_idx] # the scan lighting ref_view info sizeW = sample['img_wh'][0][0] sizeH = sample['img_wh'][0][1] partial_vol_origin = sample['partial_vol_origin'] # [B, 3] near, far = sample['near_fars'][0, 0, :1], sample['near_fars'][0, 0, 1:] # the full-size ray variables sample_rays = sample['rays'] rays_o = sample_rays['rays_o'][0] rays_d = sample_rays['rays_v'][0] imgs = sample['images'][0] intrinsics = sample['intrinsics'][0] intrinsics_l_4x = intrinsics.clone() intrinsics_l_4x[:, :2] *= 0.25 w2cs = sample['w2cs'][0] c2ws = sample['c2ws'][0] proj_matrices = sample['affine_mats'] scale_mat = sample['scale_mat'] trans_mat = sample['trans_mat'] # *********************** Lod==0 *********************** if not self.if_fix_lod0_networks: geometry_feature_maps = self.obtain_pyramid_feature_maps(imgs) conditional_features_lod0 = self.sdf_network_lod0.get_conditional_volume( feature_maps=geometry_feature_maps[None, 1:, :, :, :], partial_vol_origin=partial_vol_origin, proj_mats=proj_matrices[:,1:], # proj_mats=proj_matrices, sizeH=sizeH, sizeW=sizeW, lod=0, ) else: with torch.no_grad(): geometry_feature_maps = self.obtain_pyramid_feature_maps(imgs, lod=0) # geometry_feature_maps = self.obtain_pyramid_feature_maps(imgs, lod=0) conditional_features_lod0 = self.sdf_network_lod0.get_conditional_volume( feature_maps=geometry_feature_maps[None, 1:, :, :, :], partial_vol_origin=partial_vol_origin, proj_mats=proj_matrices[:,1:], # proj_mats=proj_matrices, sizeH=sizeH, sizeW=sizeW, lod=0, ) # print("Checker2:, construct cost volume") con_volume_lod0 = conditional_features_lod0['dense_volume_scale0'] con_valid_mask_volume_lod0 = conditional_features_lod0['valid_mask_volume_scale0'] coords_lod0 = conditional_features_lod0['coords_scale0'] # [1,3,wX,wY,wZ] # * extract depth maps for all the images depth_maps_lod0, depth_masks_lod0 = None, None if self.num_lods > 1: sdf_volume_lod0 = self.sdf_network_lod0.get_sdf_volume( con_volume_lod0, con_valid_mask_volume_lod0, coords_lod0, partial_vol_origin) # [1, 1, dX, dY, dZ] if self.prune_depth_filter: depth_maps_lod0_l4x, depth_masks_lod0_l4x = self.sdf_renderer_lod0.extract_depth_maps( self.sdf_network_lod0, sdf_volume_lod0, intrinsics_l_4x, c2ws, sizeH // 4, sizeW // 4, near * 1.5, far) depth_maps_lod0 = F.interpolate(depth_maps_lod0_l4x, size=(sizeH, sizeW), mode='bilinear', align_corners=True) depth_masks_lod0 = F.interpolate(depth_masks_lod0_l4x.float(), size=(sizeH, sizeW), mode='nearest') # *************** losses loss_lod0, losses_lod0, depth_statis_lod0 = None, None, None if not self.if_fix_lod0_networks: render_out = self.sdf_renderer_lod0.render( rays_o, rays_d, near, far, self.sdf_network_lod0, self.rendering_network_lod0, background_rgb=background_rgb, alpha_inter_ratio=alpha_inter_ratio_lod0, # * related to conditional feature lod=0, conditional_volume=con_volume_lod0, conditional_valid_mask_volume=con_valid_mask_volume_lod0, # * 2d feature maps feature_maps=geometry_feature_maps, color_maps=imgs, w2cs=w2cs, intrinsics=intrinsics, img_wh=[sizeW, sizeH], if_general_rendering=True, if_render_with_grad=True, ) loss_lod0, losses_lod0, depth_statis_lod0 = self.cal_losses_sdf(render_out, sample_rays, iter_step, lod=0) # *********************** Lod==1 *********************** loss_lod1, losses_lod1, depth_statis_lod1 = None, None, None if self.num_lods > 1: geometry_feature_maps_lod1 = self.obtain_pyramid_feature_maps(imgs, lod=1) # geometry_feature_maps_lod1 = self.obtain_pyramid_feature_maps(imgs, lod=1) if self.prune_depth_filter: pre_coords, pre_feats = self.sdf_renderer_lod0.get_valid_sparse_coords_by_sdf_depthfilter( sdf_volume_lod0[0], coords_lod0[0], con_valid_mask_volume_lod0[0], con_volume_lod0[0], depth_maps_lod0, proj_matrices[0], partial_vol_origin, self.sdf_network_lod0.voxel_size, near, far, self.sdf_network_lod0.voxel_size, 12) else: pre_coords, pre_feats = self.sdf_renderer_lod0.get_valid_sparse_coords_by_sdf( sdf_volume_lod0[0], coords_lod0[0], con_valid_mask_volume_lod0[0], con_volume_lod0[0]) pre_coords[:, 1:] = pre_coords[:, 1:] * 2 # ? It seems that training gru_fusion, this part should be trainable too conditional_features_lod1 = self.sdf_network_lod1.get_conditional_volume( feature_maps=geometry_feature_maps_lod1[None, 1:, :, :, :], partial_vol_origin=partial_vol_origin, proj_mats=proj_matrices[:,1:], # proj_mats=proj_matrices, sizeH=sizeH, sizeW=sizeW, pre_coords=pre_coords, pre_feats=pre_feats, ) con_volume_lod1 = conditional_features_lod1['dense_volume_scale1'] con_valid_mask_volume_lod1 = conditional_features_lod1['valid_mask_volume_scale1'] # if not self.if_gru_fusion_lod1: render_out_lod1 = self.sdf_renderer_lod1.render( rays_o, rays_d, near, far, self.sdf_network_lod1, self.rendering_network_lod1, background_rgb=background_rgb, alpha_inter_ratio=alpha_inter_ratio_lod1, # * related to conditional feature lod=1, conditional_volume=con_volume_lod1, conditional_valid_mask_volume=con_valid_mask_volume_lod1, # * 2d feature maps feature_maps=geometry_feature_maps_lod1, color_maps=imgs, w2cs=w2cs, intrinsics=intrinsics, img_wh=[sizeW, sizeH], bg_ratio=self.bg_ratio, ) loss_lod1, losses_lod1, depth_statis_lod1 = self.cal_losses_sdf(render_out_lod1, sample_rays, iter_step, lod=1) # print("Checker3:, compute losses") # # - extract mesh if iter_step % self.val_mesh_freq == 0: torch.cuda.empty_cache() self.validate_mesh(self.sdf_network_lod0, self.sdf_renderer_lod0.extract_geometry, conditional_volume=con_volume_lod0, lod=0, threshold=0, # occupancy_mask=con_valid_mask_volume_lod0[0, 0], mode='train_bg', meta=meta, iter_step=iter_step, scale_mat=scale_mat, trans_mat=trans_mat) torch.cuda.empty_cache() if self.num_lods > 1: self.validate_mesh(self.sdf_network_lod1, self.sdf_renderer_lod1.extract_geometry, conditional_volume=con_volume_lod1, lod=1, # occupancy_mask=con_valid_mask_volume_lod1[0, 0].detach(), mode='train_bg', meta=meta, iter_step=iter_step, scale_mat=scale_mat, trans_mat=trans_mat) losses = { # - lod 0 'loss_lod0': loss_lod0, 'losses_lod0': losses_lod0, 'depth_statis_lod0': depth_statis_lod0, # - lod 1 'loss_lod1': loss_lod1, 'losses_lod1': losses_lod1, 'depth_statis_lod1': depth_statis_lod1, } return losses def val_step(self, sample, perturb_overwrite=-1, background_rgb=None, alpha_inter_ratio_lod0=0.0, alpha_inter_ratio_lod1=0.0, iter_step=0, chunk_size=512, save_vis=False, ): # * only support batch_size==1 # ! attention: the list of string cannot be splited in DataParallel batch_idx = sample['batch_idx'][0] meta = sample['meta'][batch_idx] # the scan lighting ref_view info sizeW = sample['img_wh'][0][0] sizeH = sample['img_wh'][0][1] H, W = sizeH, sizeW partial_vol_origin = sample['partial_vol_origin'] # [B, 3] near, far = sample['query_near_far'][0, :1], sample['query_near_far'][0, 1:] # the ray variables sample_rays = sample['rays'] rays_o = sample_rays['rays_o'][0] rays_d = sample_rays['rays_v'][0] rays_ndc_uv = sample_rays['rays_ndc_uv'][0] imgs = sample['images'][0] intrinsics = sample['intrinsics'][0] intrinsics_l_4x = intrinsics.clone() intrinsics_l_4x[:, :2] *= 0.25 w2cs = sample['w2cs'][0] c2ws = sample['c2ws'][0] proj_matrices = sample['affine_mats'] # render_img_idx = sample['render_img_idx'][0] # true_img = sample['images'][0][render_img_idx] # - the image to render scale_mat = sample['scale_mat'] # [1,4,4] used to convert mesh into true scale trans_mat = sample['trans_mat'] query_c2w = sample['query_c2w'] # [1,4,4] query_w2c = sample['query_w2c'] # [1,4,4] true_img = sample['query_image'][0] true_img = np.uint8(true_img.permute(1, 2, 0).cpu().numpy() * 255) depth_min, depth_max = near.cpu().numpy(), far.cpu().numpy() scale_factor = sample['scale_factor'][0].cpu().numpy() true_depth = sample['query_depth'] if 'query_depth' in sample.keys() else None if true_depth is not None: true_depth = true_depth[0].cpu().numpy() true_depth_colored = visualize_depth_numpy(true_depth, [depth_min, depth_max])[0] else: true_depth_colored = None rays_o = rays_o.reshape(-1, 3).split(chunk_size) rays_d = rays_d.reshape(-1, 3).split(chunk_size) # - obtain conditional features with torch.no_grad(): # - obtain conditional features geometry_feature_maps = self.obtain_pyramid_feature_maps(imgs, lod=0) # - lod 0 conditional_features_lod0 = self.sdf_network_lod0.get_conditional_volume( feature_maps=geometry_feature_maps[None, :, :, :, :], partial_vol_origin=partial_vol_origin, proj_mats=proj_matrices, sizeH=sizeH, sizeW=sizeW, lod=0, ) con_volume_lod0 = conditional_features_lod0['dense_volume_scale0'] con_valid_mask_volume_lod0 = conditional_features_lod0['valid_mask_volume_scale0'] coords_lod0 = conditional_features_lod0['coords_scale0'] # [1,3,wX,wY,wZ] if self.num_lods > 1: sdf_volume_lod0 = self.sdf_network_lod0.get_sdf_volume( con_volume_lod0, con_valid_mask_volume_lod0, coords_lod0, partial_vol_origin) # [1, 1, dX, dY, dZ] depth_maps_lod0, depth_masks_lod0 = None, None if self.prune_depth_filter: depth_maps_lod0_l4x, depth_masks_lod0_l4x = self.sdf_renderer_lod0.extract_depth_maps( self.sdf_network_lod0, sdf_volume_lod0, intrinsics_l_4x, c2ws, sizeH // 4, sizeW // 4, near * 1.5, far) # - near*1.5 is a experienced number depth_maps_lod0 = F.interpolate(depth_maps_lod0_l4x, size=(sizeH, sizeW), mode='bilinear', align_corners=True) depth_masks_lod0 = F.interpolate(depth_masks_lod0_l4x.float(), size=(sizeH, sizeW), mode='nearest') #### visualize the depth_maps_lod0 for checking colored_depth_maps_lod0 = [] for i in range(depth_maps_lod0.shape[0]): colored_depth_maps_lod0.append( visualize_depth_numpy(depth_maps_lod0[i, 0].cpu().numpy(), [depth_min, depth_max])[0]) colored_depth_maps_lod0 = np.concatenate(colored_depth_maps_lod0, axis=0).astype(np.uint8) os.makedirs(os.path.join(self.base_exp_dir, 'depth_maps_lod0'), exist_ok=True) cv.imwrite(os.path.join(self.base_exp_dir, 'depth_maps_lod0', '{:0>8d}_{}.png'.format(iter_step, meta)), colored_depth_maps_lod0[:, :, ::-1]) if self.num_lods > 1: geometry_feature_maps_lod1 = self.obtain_pyramid_feature_maps(imgs, lod=1) if self.prune_depth_filter: pre_coords, pre_feats = self.sdf_renderer_lod0.get_valid_sparse_coords_by_sdf_depthfilter( sdf_volume_lod0[0], coords_lod0[0], con_valid_mask_volume_lod0[0], con_volume_lod0[0], depth_maps_lod0, proj_matrices[0], partial_vol_origin, self.sdf_network_lod0.voxel_size, near, far, self.sdf_network_lod0.voxel_size, 12) else: pre_coords, pre_feats = self.sdf_renderer_lod0.get_valid_sparse_coords_by_sdf( sdf_volume_lod0[0], coords_lod0[0], con_valid_mask_volume_lod0[0], con_volume_lod0[0]) pre_coords[:, 1:] = pre_coords[:, 1:] * 2 with torch.no_grad(): conditional_features_lod1 = self.sdf_network_lod1.get_conditional_volume( feature_maps=geometry_feature_maps_lod1[None, :, :, :, :], partial_vol_origin=partial_vol_origin, proj_mats=proj_matrices, sizeH=sizeH, sizeW=sizeW, pre_coords=pre_coords, pre_feats=pre_feats, ) con_volume_lod1 = conditional_features_lod1['dense_volume_scale1'] con_valid_mask_volume_lod1 = conditional_features_lod1['valid_mask_volume_scale1'] out_rgb_fine = [] out_normal_fine = [] out_depth_fine = [] out_rgb_fine_lod1 = [] out_normal_fine_lod1 = [] out_depth_fine_lod1 = [] # out_depth_fine_explicit = [] if save_vis: for rays_o_batch, rays_d_batch in zip(rays_o, rays_d): # ****** lod 0 **** render_out = self.sdf_renderer_lod0.render( rays_o_batch, rays_d_batch, near, far, self.sdf_network_lod0, self.rendering_network_lod0, background_rgb=background_rgb, alpha_inter_ratio=alpha_inter_ratio_lod0, # * related to conditional feature lod=0, conditional_volume=con_volume_lod0, conditional_valid_mask_volume=con_valid_mask_volume_lod0, # * 2d feature maps feature_maps=geometry_feature_maps, color_maps=imgs, w2cs=w2cs, intrinsics=intrinsics, img_wh=[sizeW, sizeH], query_c2w=query_c2w, if_render_with_grad=False, ) feasible = lambda key: ((key in render_out) and (render_out[key] is not None)) if feasible('depth'): out_depth_fine.append(render_out['depth'].detach().cpu().numpy()) # if render_out['color_coarse'] is not None: if feasible('color_fine'): out_rgb_fine.append(render_out['color_fine'].detach().cpu().numpy()) if feasible('gradients') and feasible('weights'): if render_out['inside_sphere'] is not None: out_normal_fine.append((render_out['gradients'] * render_out['weights'][:, :self.n_samples_lod0 + self.n_importance_lod0, None] * render_out['inside_sphere'][ ..., None]).sum(dim=1).detach().cpu().numpy()) else: out_normal_fine.append((render_out['gradients'] * render_out['weights'][:, :self.n_samples_lod0 + self.n_importance_lod0, None]).sum(dim=1).detach().cpu().numpy()) del render_out # ****************** lod 1 ************************** if self.num_lods > 1: for rays_o_batch, rays_d_batch in zip(rays_o, rays_d): render_out_lod1 = self.sdf_renderer_lod1.render( rays_o_batch, rays_d_batch, near, far, self.sdf_network_lod1, self.rendering_network_lod1, background_rgb=background_rgb, alpha_inter_ratio=alpha_inter_ratio_lod1, # * related to conditional feature lod=1, conditional_volume=con_volume_lod1, conditional_valid_mask_volume=con_valid_mask_volume_lod1, # * 2d feature maps feature_maps=geometry_feature_maps_lod1, color_maps=imgs, w2cs=w2cs, intrinsics=intrinsics, img_wh=[sizeW, sizeH], query_c2w=query_c2w, if_render_with_grad=False, ) feasible = lambda key: ((key in render_out_lod1) and (render_out_lod1[key] is not None)) if feasible('depth'): out_depth_fine_lod1.append(render_out_lod1['depth'].detach().cpu().numpy()) # if render_out['color_coarse'] is not None: if feasible('color_fine'): out_rgb_fine_lod1.append(render_out_lod1['color_fine'].detach().cpu().numpy()) if feasible('gradients') and feasible('weights'): if render_out_lod1['inside_sphere'] is not None: out_normal_fine_lod1.append((render_out_lod1['gradients'] * render_out_lod1['weights'][:, :self.n_samples_lod1 + self.n_importance_lod1, None] * render_out_lod1['inside_sphere'][ ..., None]).sum(dim=1).detach().cpu().numpy()) else: out_normal_fine_lod1.append((render_out_lod1['gradients'] * render_out_lod1['weights'][:, :self.n_samples_lod1 + self.n_importance_lod1, None]).sum( dim=1).detach().cpu().numpy()) del render_out_lod1 # - save visualization of lod 0 self.save_visualization(true_img, true_depth_colored, out_depth_fine, out_normal_fine, query_w2c[0], out_rgb_fine, H, W, depth_min, depth_max, iter_step, meta, "val_lod0", true_depth=true_depth, scale_factor=scale_factor) if self.num_lods > 1: self.save_visualization(true_img, true_depth_colored, out_depth_fine_lod1, out_normal_fine_lod1, query_w2c[0], out_rgb_fine_lod1, H, W, depth_min, depth_max, iter_step, meta, "val_lod1", true_depth=true_depth, scale_factor=scale_factor) # - extract mesh if (iter_step % self.val_mesh_freq == 0): torch.cuda.empty_cache() self.validate_mesh(self.sdf_network_lod0, self.sdf_renderer_lod0.extract_geometry, conditional_volume=con_volume_lod0, lod=0, threshold=0, # occupancy_mask=con_valid_mask_volume_lod0[0, 0], mode='val_bg', meta=meta, iter_step=iter_step, scale_mat=scale_mat, trans_mat=trans_mat) torch.cuda.empty_cache() if self.num_lods > 1: self.validate_mesh(self.sdf_network_lod1, self.sdf_renderer_lod1.extract_geometry, conditional_volume=con_volume_lod1, lod=1, # occupancy_mask=con_valid_mask_volume_lod1[0, 0].detach(), mode='val_bg', meta=meta, iter_step=iter_step, scale_mat=scale_mat, trans_mat=trans_mat) torch.cuda.empty_cache() def export_mesh_step(self, sample, perturb_overwrite=-1, background_rgb=None, alpha_inter_ratio_lod0=0.0, alpha_inter_ratio_lod1=0.0, iter_step=0, chunk_size=512, save_vis=False, ): # * only support batch_size==1 # ! attention: the list of string cannot be splited in DataParallel batch_idx = sample['batch_idx'][0] meta = sample['meta'][batch_idx] # the scan lighting ref_view info sizeW = sample['img_wh'][0][0] sizeH = sample['img_wh'][0][1] H, W = sizeH, sizeW partial_vol_origin = sample['partial_vol_origin'] # [B, 3] near, far = sample['query_near_far'][0, :1], sample['query_near_far'][0, 1:] # the ray variables sample_rays = sample['rays'] rays_o = sample_rays['rays_o'][0] rays_d = sample_rays['rays_v'][0] rays_ndc_uv = sample_rays['rays_ndc_uv'][0] imgs = sample['images'][0] intrinsics = sample['intrinsics'][0] intrinsics_l_4x = intrinsics.clone() intrinsics_l_4x[:, :2] *= 0.25 w2cs = sample['w2cs'][0] c2ws = sample['c2ws'][0] # target_candidate_w2cs = sample['target_candidate_w2cs'][0] proj_matrices = sample['affine_mats'] # - the image to render scale_mat = sample['scale_mat'] # [1,4,4] used to convert mesh into true scale trans_mat = sample['trans_mat'] query_c2w = sample['query_c2w'] # [1,4,4] query_w2c = sample['query_w2c'] # [1,4,4] true_img = sample['query_image'][0] true_img = np.uint8(true_img.permute(1, 2, 0).cpu().numpy() * 255) depth_min, depth_max = near.cpu().numpy(), far.cpu().numpy() scale_factor = sample['scale_factor'][0].cpu().numpy() true_depth = sample['query_depth'] if 'query_depth' in sample.keys() else None if true_depth is not None: true_depth = true_depth[0].cpu().numpy() true_depth_colored = visualize_depth_numpy(true_depth, [depth_min, depth_max])[0] else: true_depth_colored = None rays_o = rays_o.reshape(-1, 3).split(chunk_size) rays_d = rays_d.reshape(-1, 3).split(chunk_size) # import time # jha_begin1 = time.time() # - obtain conditional features with torch.no_grad(): # - obtain conditional features geometry_feature_maps = self.obtain_pyramid_feature_maps(imgs, lod=0) # - lod 0 conditional_features_lod0 = self.sdf_network_lod0.get_conditional_volume( feature_maps=geometry_feature_maps[None, :, :, :, :], partial_vol_origin=partial_vol_origin, proj_mats=proj_matrices, sizeH=sizeH, sizeW=sizeW, lod=0, ) # jha_end1 = time.time() # print("get_conditional_volume: ", jha_end1 - jha_begin1) # jha_begin2 = time.time() con_volume_lod0 = conditional_features_lod0['dense_volume_scale0'] con_valid_mask_volume_lod0 = conditional_features_lod0['valid_mask_volume_scale0'] coords_lod0 = conditional_features_lod0['coords_scale0'] # [1,3,wX,wY,wZ] if self.num_lods > 1: sdf_volume_lod0 = self.sdf_network_lod0.get_sdf_volume( con_volume_lod0, con_valid_mask_volume_lod0, coords_lod0, partial_vol_origin) # [1, 1, dX, dY, dZ] depth_maps_lod0, depth_masks_lod0 = None, None if self.num_lods > 1: geometry_feature_maps_lod1 = self.obtain_pyramid_feature_maps(imgs, lod=1) if self.prune_depth_filter: pre_coords, pre_feats = self.sdf_renderer_lod0.get_valid_sparse_coords_by_sdf_depthfilter( sdf_volume_lod0[0], coords_lod0[0], con_valid_mask_volume_lod0[0], con_volume_lod0[0], depth_maps_lod0, proj_matrices[0], partial_vol_origin, self.sdf_network_lod0.voxel_size, near, far, self.sdf_network_lod0.voxel_size, 12) else: pre_coords, pre_feats = self.sdf_renderer_lod0.get_valid_sparse_coords_by_sdf( sdf_volume_lod0[0], coords_lod0[0], con_valid_mask_volume_lod0[0], con_volume_lod0[0]) pre_coords[:, 1:] = pre_coords[:, 1:] * 2 with torch.no_grad(): conditional_features_lod1 = self.sdf_network_lod1.get_conditional_volume( feature_maps=geometry_feature_maps_lod1[None, :, :, :, :], partial_vol_origin=partial_vol_origin, proj_mats=proj_matrices, sizeH=sizeH, sizeW=sizeW, pre_coords=pre_coords, pre_feats=pre_feats, ) con_volume_lod1 = conditional_features_lod1['dense_volume_scale1'] con_valid_mask_volume_lod1 = conditional_features_lod1['valid_mask_volume_scale1'] out_rgb_fine = [] out_normal_fine = [] out_depth_fine = [] out_rgb_fine_lod1 = [] out_normal_fine_lod1 = [] out_depth_fine_lod1 = [] # jha_end2 = time.time() # print("interval before starting mesh export: ", jha_end2 - jha_begin2) # - extract mesh if (iter_step % self.val_mesh_freq == 0): torch.cuda.empty_cache() # jha_begin3 = time.time() self.validate_colored_mesh( density_or_sdf_network=self.sdf_network_lod0, func_extract_geometry=self.sdf_renderer_lod0.extract_geometry, conditional_volume=con_volume_lod0, conditional_valid_mask_volume = con_valid_mask_volume_lod0, feature_maps=geometry_feature_maps, color_maps=imgs, w2cs=w2cs, target_candidate_w2cs=None, intrinsics=intrinsics, rendering_network=self.rendering_network_lod0, rendering_projector=self.sdf_renderer_lod0.rendering_projector, lod=0, threshold=0, query_c2w=query_c2w, mode='val_bg', meta=meta, iter_step=iter_step, scale_mat=scale_mat, trans_mat=trans_mat ) torch.cuda.empty_cache() # jha_end3 = time.time() # print("validate_colored_mesh_test_time: ", jha_end3 - jha_begin3) if self.num_lods > 1: self.validate_colored_mesh( density_or_sdf_network=self.sdf_network_lod1, func_extract_geometry=self.sdf_renderer_lod1.extract_geometry, conditional_volume=con_volume_lod1, conditional_valid_mask_volume = con_valid_mask_volume_lod1, feature_maps=geometry_feature_maps, color_maps=imgs, w2cs=w2cs, target_candidate_w2cs=None, intrinsics=intrinsics, rendering_network=self.rendering_network_lod1, rendering_projector=self.sdf_renderer_lod1.rendering_projector, lod=1, threshold=0, query_c2w=query_c2w, mode='val_bg', meta=meta, iter_step=iter_step, scale_mat=scale_mat, trans_mat=trans_mat ) torch.cuda.empty_cache() def save_visualization(self, true_img, true_colored_depth, out_depth, out_normal, w2cs, out_color, H, W, depth_min, depth_max, iter_step, meta, comment, out_color_mlp=[], true_depth=None, scale_factor=1.0): if len(out_color) > 0: img_fine = (np.concatenate(out_color, axis=0).reshape([H, W, 3]) * 256).clip(0, 255) if len(out_color_mlp) > 0: img_mlp = (np.concatenate(out_color_mlp, axis=0).reshape([H, W, 3]) * 256).clip(0, 255) if len(out_normal) > 0: normal_img = np.concatenate(out_normal, axis=0) rot = w2cs[:3, :3].detach().cpu().numpy() # - convert normal from world space to camera space normal_img = (np.matmul(rot[None, :, :], normal_img[:, :, None]).reshape([H, W, 3]) * 128 + 128).clip(0, 255) if len(out_depth) > 0: pred_depth = np.concatenate(out_depth, axis=0).reshape([H, W]) pred_depth_colored = visualize_depth_numpy(pred_depth, [depth_min, depth_max])[0] if len(out_depth) > 0: os.makedirs(os.path.join(self.base_exp_dir, 'depths_' + comment), exist_ok=True) if true_colored_depth is not None: if true_depth is not None: depth_error_map = np.abs(true_depth - pred_depth) * 2.0 / scale_factor # [256, 256, 1] -> [256, 256, 3] depth_error_map = np.tile(depth_error_map[:, :, None], [1, 1, 3]) print("meta: ", meta) print("scale_factor: ", scale_factor) print("depth_error_mean: ", depth_error_map.mean()) depth_visualized = np.concatenate( [(depth_error_map * 255).astype(np.uint8), true_colored_depth, pred_depth_colored, true_img], axis=1)[:, :, ::-1] # print("depth_visualized.shape: ", depth_visualized.shape) # write depth error result text on img, the input is a numpy array of [256, 1024, 3] # cv.putText(depth_visualized.copy(), "depth_error_mean: {:.4f}".format(depth_error_map.mean()), (10, 30), cv.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2) else: depth_visualized = np.concatenate( [true_colored_depth, pred_depth_colored, true_img])[:, :, ::-1] cv.imwrite( os.path.join(self.base_exp_dir, 'depths_' + comment, '{:0>8d}_{}.png'.format(iter_step, meta)), depth_visualized ) else: cv.imwrite( os.path.join(self.base_exp_dir, 'depths_' + comment, '{:0>8d}_{}.png'.format(iter_step, meta)), np.concatenate( [pred_depth_colored, true_img])[:, :, ::-1]) if len(out_color) > 0: os.makedirs(os.path.join(self.base_exp_dir, 'synthesized_color_' + comment), exist_ok=True) cv.imwrite(os.path.join(self.base_exp_dir, 'synthesized_color_' + comment, '{:0>8d}_{}.png'.format(iter_step, meta)), np.concatenate( [img_fine, true_img])[:, :, ::-1]) # bgr2rgb # compute psnr (image pixel lie in [0, 255]) mse_loss = np.mean((img_fine - true_img) ** 2) psnr = 10 * np.log10(255 ** 2 / mse_loss) print("PSNR: ", psnr) if len(out_color_mlp) > 0: os.makedirs(os.path.join(self.base_exp_dir, 'synthesized_color_mlp_' + comment), exist_ok=True) cv.imwrite(os.path.join(self.base_exp_dir, 'synthesized_color_mlp_' + comment, '{:0>8d}_{}.png'.format(iter_step, meta)), np.concatenate( [img_mlp, true_img])[:, :, ::-1]) # bgr2rgb if len(out_normal) > 0: os.makedirs(os.path.join(self.base_exp_dir, 'normals_' + comment), exist_ok=True) cv.imwrite(os.path.join(self.base_exp_dir, 'normals_' + comment, '{:0>8d}_{}.png'.format(iter_step, meta)), normal_img[:, :, ::-1]) def forward(self, sample, perturb_overwrite=-1, background_rgb=None, alpha_inter_ratio_lod0=0.0, alpha_inter_ratio_lod1=0.0, iter_step=0, mode='train', save_vis=False, ): if mode == 'train': return self.train_step(sample, perturb_overwrite=perturb_overwrite, background_rgb=background_rgb, alpha_inter_ratio_lod0=alpha_inter_ratio_lod0, alpha_inter_ratio_lod1=alpha_inter_ratio_lod1, iter_step=iter_step ) elif mode == 'val': import time begin = time.time() result = self.val_step(sample, perturb_overwrite=perturb_overwrite, background_rgb=background_rgb, alpha_inter_ratio_lod0=alpha_inter_ratio_lod0, alpha_inter_ratio_lod1=alpha_inter_ratio_lod1, iter_step=iter_step, save_vis=save_vis, ) end = time.time() print("val_step time: ", end - begin) return result elif mode == 'export_mesh': import time begin = time.time() result = self.export_mesh_step(sample, perturb_overwrite=perturb_overwrite, background_rgb=background_rgb, alpha_inter_ratio_lod0=alpha_inter_ratio_lod0, alpha_inter_ratio_lod1=alpha_inter_ratio_lod1, iter_step=iter_step, save_vis=save_vis, ) end = time.time() print("export mesh time: ", end - begin) return result def obtain_pyramid_feature_maps(self, imgs, lod=0): """ get feature maps of all conditional images :param imgs: :return: """ if lod == 0: extractor = self.pyramid_feature_network_geometry_lod0 elif lod >= 1: extractor = self.pyramid_feature_network_geometry_lod1 pyramid_feature_maps = extractor(imgs) # * the pyramid features are very important, if only use the coarst features, hard to optimize fused_feature_maps = torch.cat([ F.interpolate(pyramid_feature_maps[0], scale_factor=4, mode='bilinear', align_corners=True), F.interpolate(pyramid_feature_maps[1], scale_factor=2, mode='bilinear', align_corners=True), pyramid_feature_maps[2] ], dim=1) return fused_feature_maps def cal_losses_sdf(self, render_out, sample_rays, iter_step=-1, lod=0): # loss weight schedule; the regularization terms should be added in later training stage def get_weight(iter_step, weight): if lod == 1: anneal_start = self.anneal_end if lod == 0 else self.anneal_end_lod1 anneal_end = self.anneal_end if lod == 0 else self.anneal_end_lod1 anneal_end = anneal_end * 2 else: anneal_start = self.anneal_start if lod == 0 else self.anneal_start_lod1 anneal_end = self.anneal_end if lod == 0 else self.anneal_end_lod1 anneal_end = anneal_end * 2 if iter_step < 0: return weight if anneal_end == 0.0: return weight elif iter_step < anneal_start: return 0.0 else: return np.min( [1.0, (iter_step - anneal_start) / (anneal_end - anneal_start)]) * weight rays_o = sample_rays['rays_o'][0] rays_d = sample_rays['rays_v'][0] true_rgb = sample_rays['rays_color'][0] if 'rays_depth' in sample_rays.keys(): true_depth = sample_rays['rays_depth'][0] else: true_depth = None mask = sample_rays['rays_mask'][0] color_fine = render_out['color_fine'] color_fine_mask = render_out['color_fine_mask'] depth_pred = render_out['depth'] variance = render_out['variance'] cdf_fine = render_out['cdf_fine'] weight_sum = render_out['weights_sum'] gradient_error_fine = render_out['gradient_error_fine'] sdf = render_out['sdf'] # * color generated by mlp color_mlp = render_out['color_mlp'] color_mlp_mask = render_out['color_mlp_mask'] if color_fine is not None: # Color loss color_mask = color_fine_mask if color_fine_mask is not None else mask color_mask = color_mask[..., 0] color_error = (color_fine[color_mask] - true_rgb[color_mask]) # print("Nan number", torch.isnan(color_error).sum()) # print("Color error shape", color_error.shape) color_fine_loss = F.l1_loss(color_error, torch.zeros_like(color_error).to(color_error.device), reduction='mean') # print(color_fine_loss) psnr = 20.0 * torch.log10( 1.0 / (((color_fine[color_mask] - true_rgb[color_mask]) ** 2).mean() / (3.0)).sqrt()) else: color_fine_loss = 0. psnr = 0. if color_mlp is not None: # Color loss color_mlp_mask = color_mlp_mask[..., 0] color_error_mlp = (color_mlp[color_mlp_mask] - true_rgb[color_mlp_mask]) color_mlp_loss = F.l1_loss(color_error_mlp, torch.zeros_like(color_error_mlp).to(color_error_mlp.device), reduction='mean') psnr_mlp = 20.0 * torch.log10( 1.0 / (((color_mlp[color_mlp_mask] - true_rgb[color_mlp_mask]) ** 2).mean() / (3.0)).sqrt()) else: color_mlp_loss = 0. psnr_mlp = 0. # depth loss is only used for inference, not included in total loss if true_depth is not None: # depth_loss = self.depth_criterion(depth_pred, true_depth, mask) depth_loss = self.depth_criterion(depth_pred, true_depth) # # depth evaluation # depth_statis = compute_depth_errors(depth_pred.detach().cpu().numpy(), true_depth.cpu().numpy()) # depth_statis = numpy2tensor(depth_statis, device=rays_o.device) depth_statis = None else: depth_loss = 0. depth_statis = None sparse_loss_1 = torch.exp( -1 * torch.abs(render_out['sdf_random']) * self.sdf_decay_param).mean() # - should equal sparse_loss_2 = torch.exp(-1 * torch.abs(sdf) * self.sdf_decay_param).mean() sparse_loss = (sparse_loss_1 + sparse_loss_2) / 2 sdf_mean = torch.abs(sdf).mean() sparseness_1 = (torch.abs(sdf) < 0.01).to(torch.float32).mean() sparseness_2 = (torch.abs(sdf) < 0.02).to(torch.float32).mean() # Eikonal loss gradient_error_loss = gradient_error_fine # ! the first 50k, don't use bg constraint fg_bg_weight = 0.0 if iter_step < 50000 else get_weight(iter_step, self.fg_bg_weight) # Mask loss, optional # The images of DTU dataset contain large black regions (0 rgb values), # can use this data prior to make fg more clean background_loss = 0.0 fg_bg_loss = 0.0 if self.fg_bg_weight > 0 and torch.mean((mask < 0.5).to(torch.float32)) > 0.02: weights_sum_fg = render_out['weights_sum_fg'] fg_bg_error = (weights_sum_fg - mask)[mask < 0.5] fg_bg_loss = F.l1_loss(fg_bg_error, torch.zeros_like(fg_bg_error).to(fg_bg_error.device), reduction='mean') loss = 1.0 * depth_loss + color_fine_loss + color_mlp_loss + \ sparse_loss * get_weight(iter_step, self.sdf_sparse_weight) + \ fg_bg_loss * fg_bg_weight + \ gradient_error_loss * self.sdf_igr_weight # ! gradient_error_loss need a mask losses = { "loss": loss, "depth_loss": depth_loss, "color_fine_loss": color_fine_loss, "color_mlp_loss": color_mlp_loss, "gradient_error_loss": gradient_error_loss, "background_loss": background_loss, "sparse_loss": sparse_loss, "sparseness_1": sparseness_1, "sparseness_2": sparseness_2, "sdf_mean": sdf_mean, "psnr": psnr, "psnr_mlp": psnr_mlp, "weights_sum": render_out['weights_sum'], "weights_sum_fg": render_out['weights_sum_fg'], "alpha_sum": render_out['alpha_sum'], "variance": render_out['variance'], "sparse_weight": get_weight(iter_step, self.sdf_sparse_weight), "fg_bg_weight": fg_bg_weight, "fg_bg_loss": fg_bg_loss, # added by jha, bug of sparseNeuS } losses = torch.tensor(losses, device=rays_o.device) return loss, losses, depth_statis @torch.no_grad() def validate_mesh(self, density_or_sdf_network, func_extract_geometry, world_space=True, resolution=360, threshold=0.0, mode='val', # * 3d feature volume conditional_volume=None, lod=None, occupancy_mask=None, bound_min=[-1, -1, -1], bound_max=[1, 1, 1], meta='', iter_step=0, scale_mat=None, trans_mat=None ): bound_min = torch.tensor(bound_min, dtype=torch.float32) bound_max = torch.tensor(bound_max, dtype=torch.float32) vertices, triangles, fields = func_extract_geometry( density_or_sdf_network, bound_min, bound_max, resolution=resolution, threshold=threshold, device=conditional_volume.device, # * 3d feature volume conditional_volume=conditional_volume, lod=lod, occupancy_mask=occupancy_mask ) if scale_mat is not None: scale_mat_np = scale_mat.cpu().numpy() vertices = vertices * scale_mat_np[0][0, 0] + scale_mat_np[0][:3, 3][None] if trans_mat is not None: # w2c_ref_inv trans_mat_np = trans_mat.cpu().numpy() vertices_homo = np.concatenate([vertices, np.ones_like(vertices[:, :1])], axis=1) vertices = np.matmul(trans_mat_np, vertices_homo[:, :, None])[:, :3, 0] mesh = trimesh.Trimesh(vertices, triangles) os.makedirs(os.path.join(self.base_exp_dir, 'meshes_' + mode), exist_ok=True) mesh.export(os.path.join(self.base_exp_dir, 'meshes_' + mode, 'mesh_{:0>8d}_{}_lod{:0>1d}.ply'.format(iter_step, meta, lod))) def validate_colored_mesh(self, density_or_sdf_network, func_extract_geometry, world_space=True, resolution=360, threshold=0.0, mode='val', # * 3d feature volume conditional_volume=None, conditional_valid_mask_volume=None, feature_maps=None, color_maps = None, w2cs=None, target_candidate_w2cs=None, intrinsics=None, rendering_network=None, rendering_projector=None, query_c2w=None, lod=None, occupancy_mask=None, bound_min=[-1, -1, -1], bound_max=[1, 1, 1], meta='', iter_step=0, scale_mat=None, trans_mat=None ): bound_min = torch.tensor(bound_min, dtype=torch.float32) bound_max = torch.tensor(bound_max, dtype=torch.float32) # import time # jha_begin4 = time.time() vertices, triangles, fields = func_extract_geometry( density_or_sdf_network, bound_min, bound_max, resolution=resolution, threshold=threshold, device=conditional_volume.device, # * 3d feature volume conditional_volume=conditional_volume, lod=lod, occupancy_mask=occupancy_mask ) # jha_end4 = time.time() # print("[TEST]: func_extract_geometry time", jha_end4 - jha_begin4) # import time # jha_begin5 = time.time() with torch.no_grad(): ren_geo_feats, ren_rgb_feats, ren_ray_diff, ren_mask, _, _ = rendering_projector.compute_view_independent( torch.tensor(vertices).to(conditional_volume), lod=lod, # JHA EDITED # * 3d geometry feature volumes geometryVolume=conditional_volume[0], geometryVolumeMask=conditional_valid_mask_volume[0], sdf_network=density_or_sdf_network, # * 2d rendering feature maps rendering_feature_maps=feature_maps, # [n_view, 56, 256, 256] color_maps=color_maps, w2cs=w2cs, target_candidate_w2cs=target_candidate_w2cs, intrinsics=intrinsics, img_wh=[256,256], query_img_idx=0, # the index of the N_views dim for rendering query_c2w=query_c2w, ) vertices_color, rendering_valid_mask = rendering_network( ren_geo_feats, ren_rgb_feats, ren_ray_diff, ren_mask) # jha_end5 = time.time() # print("[TEST]: rendering_network time", jha_end5 - jha_begin5) if scale_mat is not None: scale_mat_np = scale_mat.cpu().numpy() vertices = vertices * scale_mat_np[0][0, 0] + scale_mat_np[0][:3, 3][None] if trans_mat is not None: # w2c_ref_inv trans_mat_np = trans_mat.cpu().numpy() vertices_homo = np.concatenate([vertices, np.ones_like(vertices[:, :1])], axis=1) vertices = np.matmul(trans_mat_np, vertices_homo[:, :, None])[:, :3, 0] vertices_color = np.array(vertices_color.squeeze(0).cpu() * 255, dtype=np.uint8) mesh = trimesh.Trimesh(vertices, triangles, vertex_colors=vertices_color) os.makedirs(os.path.join(self.base_exp_dir, 'meshes_' + mode, 'lod{:0>1d}'.format(lod)), exist_ok=True) # mesh.export(os.path.join(self.base_exp_dir, 'meshes_' + mode, 'lod{:0>1d}'.format(lod), # 'mesh_{:0>8d}_{}_lod{:0>1d}.ply'.format(iter_step, meta, lod))) # MODIFIED mesh.export(os.path.join(self.base_exp_dir, 'meshes_' + mode, 'lod{:0>1d}'.format(lod), 'mesh_{:0>8d}_gradio_lod{:0>1d}.ply'.format(iter_step, lod)))