"""
Decouple the trainer from the renderer.
"""
import os

import cv2 as cv
import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np
import trimesh
from icecream import ic

from utils.misc_utils import visualize_depth_numpy

from loss.depth_metric import compute_depth_errors
from loss.depth_loss import DepthLoss, DepthSmoothLoss

from models.sparse_neus_renderer import SparseNeuSRenderer

class GenericTrainer(nn.Module):
    def __init__(self,
                 rendering_network_outside,
                 pyramid_feature_network_lod0,
                 pyramid_feature_network_lod1,
                 sdf_network_lod0,
                 sdf_network_lod1,
                 variance_network_lod0,
                 variance_network_lod1,
                 rendering_network_lod0,
                 rendering_network_lod1,
                 n_samples_lod0,
                 n_importance_lod0,
                 n_samples_lod1,
                 n_importance_lod1,
                 n_outside,
                 perturb,
                 alpha_type='div',
                 conf=None,
                 timestamp="",
                 mode='train',
                 base_exp_dir=None,
                 ):
        super(GenericTrainer, self).__init__()

        self.conf = conf
        self.timestamp = timestamp
        self.base_exp_dir = base_exp_dir

        # Annealing schedule of the loss weights, per LoD.
        self.anneal_start = self.conf.get_float('train.anneal_start', default=0.0)
        self.anneal_end = self.conf.get_float('train.anneal_end', default=0.0)
        self.anneal_start_lod1 = self.conf.get_float('train.anneal_start_lod1', default=0.0)
        self.anneal_end_lod1 = self.conf.get_float('train.anneal_end_lod1', default=0.0)

        # Networks, one set per LoD.
        self.rendering_network_outside = rendering_network_outside
        self.pyramid_feature_network_geometry_lod0 = pyramid_feature_network_lod0
        self.pyramid_feature_network_geometry_lod1 = pyramid_feature_network_lod1

        self.sdf_network_lod0 = sdf_network_lod0
        self.sdf_network_lod1 = sdf_network_lod1

        self.variance_network_lod0 = variance_network_lod0
        self.variance_network_lod1 = variance_network_lod1

        self.rendering_network_lod0 = rendering_network_lod0
        self.rendering_network_lod1 = rendering_network_lod1

        # Ray-sampling parameters.
        self.n_samples_lod0 = n_samples_lod0
        self.n_importance_lod0 = n_importance_lod0
        self.n_samples_lod1 = n_samples_lod1
        self.n_importance_lod1 = n_importance_lod1
        self.n_outside = n_outside
        self.num_lods = conf.get_int('model.num_lods')
        self.perturb = perturb
        self.alpha_type = alpha_type

        # The trainer owns one renderer per LoD and only orchestrates them.
        self.sdf_renderer_lod0 = SparseNeuSRenderer(
            self.rendering_network_outside,
            self.sdf_network_lod0,
            self.variance_network_lod0,
            self.rendering_network_lod0,
            self.n_samples_lod0,
            self.n_importance_lod0,
            self.n_outside,
            self.perturb,
            alpha_type=self.alpha_type,
            conf=self.conf)

        self.sdf_renderer_lod1 = SparseNeuSRenderer(
            self.rendering_network_outside,
            self.sdf_network_lod1,
            self.variance_network_lod1,
            self.rendering_network_lod1,
            self.n_samples_lod1,
            self.n_importance_lod1,
            self.n_outside,
            self.perturb,
            alpha_type=self.alpha_type,
            conf=self.conf)

        self.if_fix_lod0_networks = self.conf.get_bool('train.if_fix_lod0_networks')

        # Loss weights.
        self.sdf_igr_weight = self.conf.get_float('train.sdf_igr_weight')
        self.sdf_sparse_weight = self.conf.get_float('train.sdf_sparse_weight', default=0)
        self.sdf_decay_param = self.conf.get_float('train.sdf_decay_param', default=100)
        self.fg_bg_weight = self.conf.get_float('train.fg_bg_weight', default=0.00)
        self.bg_ratio = self.conf.get_float('train.bg_ratio', default=0.0)

        self.depth_criterion = DepthLoss()

        self.val_mesh_freq = self.conf.get_int('train.val_mesh_freq')

        self.if_fitted_rendering = self.conf.get_bool('train.if_fitted_rendering', default=False)
        self.prune_depth_filter = self.conf.get_bool('model.prune_depth_filter', default=False)

    def get_trainable_params(self):
        """Collect the parameters to optimize; lod0 networks are skipped when frozen."""
        self.params_to_train = []

        if not self.if_fix_lod0_networks:
            # - lod 0 networks
            self.params_to_train += list(self.pyramid_feature_network_geometry_lod0.parameters())
            self.params_to_train += list(self.sdf_network_lod0.parameters())
            self.params_to_train += list(self.variance_network_lod0.parameters())

            if self.rendering_network_lod0 is not None:
                self.params_to_train += list(self.rendering_network_lod0.parameters())

        if self.sdf_network_lod1 is not None:
            # - lod 1 networks
            self.params_to_train += list(self.pyramid_feature_network_geometry_lod1.parameters())
            self.params_to_train += list(self.sdf_network_lod1.parameters())
            self.params_to_train += list(self.variance_network_lod1.parameters())

            if self.rendering_network_lod1 is not None:
                self.params_to_train += list(self.rendering_network_lod1.parameters())

        return self.params_to_train
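
    # Minimal usage sketch (hypothetical driver code; the actual runner and its
    # config keys may differ):
    #
    #     trainer = GenericTrainer(..., conf=conf, base_exp_dir='./exp')
    #     optimizer = torch.optim.Adam(trainer.get_trainable_params(), lr=1e-4)
    #     losses = trainer(sample, iter_step=step, mode='train')
    #     losses['loss_lod0'].backward()
    #     optimizer.step()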

    def train_step(self, sample,
                   perturb_overwrite=-1,
                   background_rgb=None,
                   alpha_inter_ratio_lod0=0.0,
                   alpha_inter_ratio_lod1=0.0,
                   iter_step=0,
                   ):
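        """
        One training iteration: build the lod0 conditional feature volume,
        render and supervise lod0 (unless its networks are frozen), then prune
        to a sparse lod1 volume and supervise lod1 when num_lods > 1.
        Assumes batch_size == 1.
        """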
        batch_idx = sample['batch_idx'][0]
        meta = sample['meta'][batch_idx]

        sizeW = sample['img_wh'][0][0]
        sizeH = sample['img_wh'][0][1]
        partial_vol_origin = sample['partial_vol_origin']
        near, far = sample['near_fars'][0, 0, :1], sample['near_fars'][0, 0, 1:]

        # - the ray batch of the query view
        sample_rays = sample['rays']
        rays_o = sample_rays['rays_o'][0]
        rays_d = sample_rays['rays_v'][0]

        imgs = sample['images'][0]
        intrinsics = sample['intrinsics'][0]
        intrinsics_l_4x = intrinsics.clone()
        intrinsics_l_4x[:, :2] *= 0.25  # intrinsics of the 4x-downsampled images
        w2cs = sample['w2cs'][0]
        c2ws = sample['c2ws'][0]
        proj_matrices = sample['affine_mats']
        scale_mat = sample['scale_mat']
        trans_mat = sample['trans_mat']

        # ********************** lod 0 **********************
        # Build the lod0 conditional feature volume from the source views
        # (view 0 is the query view, hence the [1:] slicing).
        if not self.if_fix_lod0_networks:
            geometry_feature_maps = self.obtain_pyramid_feature_maps(imgs)

            conditional_features_lod0 = self.sdf_network_lod0.get_conditional_volume(
                feature_maps=geometry_feature_maps[None, 1:, :, :, :],
                partial_vol_origin=partial_vol_origin,
                proj_mats=proj_matrices[:, 1:],
                sizeH=sizeH,
                sizeW=sizeW,
                lod=0,
            )
        else:
            with torch.no_grad():
                geometry_feature_maps = self.obtain_pyramid_feature_maps(imgs, lod=0)

                conditional_features_lod0 = self.sdf_network_lod0.get_conditional_volume(
                    feature_maps=geometry_feature_maps[None, 1:, :, :, :],
                    partial_vol_origin=partial_vol_origin,
                    proj_mats=proj_matrices[:, 1:],
                    sizeH=sizeH,
                    sizeW=sizeW,
                    lod=0,
                )

        con_volume_lod0 = conditional_features_lod0['dense_volume_scale0']
        con_valid_mask_volume_lod0 = conditional_features_lod0['valid_mask_volume_scale0']
        coords_lod0 = conditional_features_lod0['coords_scale0']

        # Coarse depth maps, only needed to prune voxels before lod1.
        depth_maps_lod0, depth_masks_lod0 = None, None
        if self.num_lods > 1:
            sdf_volume_lod0 = self.sdf_network_lod0.get_sdf_volume(
                con_volume_lod0, con_valid_mask_volume_lod0,
                coords_lod0, partial_vol_origin)

            if self.prune_depth_filter:
                # Extract depth at 1/4 resolution, then upsample to full size.
                depth_maps_lod0_l4x, depth_masks_lod0_l4x = self.sdf_renderer_lod0.extract_depth_maps(
                    self.sdf_network_lod0, sdf_volume_lod0, intrinsics_l_4x, c2ws,
                    sizeH // 4, sizeW // 4, near * 1.5, far)
                depth_maps_lod0 = F.interpolate(depth_maps_lod0_l4x, size=(sizeH, sizeW), mode='bilinear',
                                                align_corners=True)
                depth_masks_lod0 = F.interpolate(depth_masks_lod0_l4x.float(), size=(sizeH, sizeW), mode='nearest')

        # Render and supervise lod0 only when its networks are trainable.
        loss_lod0, losses_lod0, depth_statis_lod0 = None, None, None
        if not self.if_fix_lod0_networks:
            render_out = self.sdf_renderer_lod0.render(
                rays_o, rays_d, near, far,
                self.sdf_network_lod0,
                self.rendering_network_lod0,
                background_rgb=background_rgb,
                alpha_inter_ratio=alpha_inter_ratio_lod0,
                lod=0,
                conditional_volume=con_volume_lod0,
                conditional_valid_mask_volume=con_valid_mask_volume_lod0,
                feature_maps=geometry_feature_maps,
                color_maps=imgs,
                w2cs=w2cs,
                intrinsics=intrinsics,
                img_wh=[sizeW, sizeH],
                if_general_rendering=True,
                if_render_with_grad=True,
            )

            loss_lod0, losses_lod0, depth_statis_lod0 = self.cal_losses_sdf(render_out, sample_rays,
                                                                            iter_step, lod=0)

        # ********************** lod 1 **********************
        loss_lod1, losses_lod1, depth_statis_lod1 = None, None, None
        if self.num_lods > 1:
            geometry_feature_maps_lod1 = self.obtain_pyramid_feature_maps(imgs, lod=1)

            # Keep only voxels near the lod0 surface, optionally filtered by
            # the coarse depth maps extracted above.
            if self.prune_depth_filter:
                pre_coords, pre_feats = self.sdf_renderer_lod0.get_valid_sparse_coords_by_sdf_depthfilter(
                    sdf_volume_lod0[0], coords_lod0[0], con_valid_mask_volume_lod0[0], con_volume_lod0[0],
                    depth_maps_lod0, proj_matrices[0],
                    partial_vol_origin, self.sdf_network_lod0.voxel_size,
                    near, far, self.sdf_network_lod0.voxel_size, 12)
            else:
                pre_coords, pre_feats = self.sdf_renderer_lod0.get_valid_sparse_coords_by_sdf(
                    sdf_volume_lod0[0], coords_lod0[0], con_valid_mask_volume_lod0[0], con_volume_lod0[0])

            pre_coords[:, 1:] = pre_coords[:, 1:] * 2  # lod1 doubles the voxel resolution

            conditional_features_lod1 = self.sdf_network_lod1.get_conditional_volume(
                feature_maps=geometry_feature_maps_lod1[None, 1:, :, :, :],
                partial_vol_origin=partial_vol_origin,
                proj_mats=proj_matrices[:, 1:],
                sizeH=sizeH,
                sizeW=sizeW,
                pre_coords=pre_coords,
                pre_feats=pre_feats,
            )

            con_volume_lod1 = conditional_features_lod1['dense_volume_scale1']
            con_valid_mask_volume_lod1 = conditional_features_lod1['valid_mask_volume_scale1']

            render_out_lod1 = self.sdf_renderer_lod1.render(
                rays_o, rays_d, near, far,
                self.sdf_network_lod1,
                self.rendering_network_lod1,
                background_rgb=background_rgb,
                alpha_inter_ratio=alpha_inter_ratio_lod1,
                lod=1,
                conditional_volume=con_volume_lod1,
                conditional_valid_mask_volume=con_valid_mask_volume_lod1,
                feature_maps=geometry_feature_maps_lod1,
                color_maps=imgs,
                w2cs=w2cs,
                intrinsics=intrinsics,
                img_wh=[sizeW, sizeH],
                bg_ratio=self.bg_ratio,
            )
            loss_lod1, losses_lod1, depth_statis_lod1 = self.cal_losses_sdf(render_out_lod1, sample_rays,
                                                                            iter_step, lod=1)

        # Periodically export meshes for visual inspection.
        if iter_step % self.val_mesh_freq == 0:
            torch.cuda.empty_cache()
            self.validate_mesh(self.sdf_network_lod0,
                               self.sdf_renderer_lod0.extract_geometry,
                               conditional_volume=con_volume_lod0, lod=0,
                               threshold=0,
                               mode='train_bg', meta=meta,
                               iter_step=iter_step, scale_mat=scale_mat,
                               trans_mat=trans_mat)
            torch.cuda.empty_cache()

            if self.num_lods > 1:
                self.validate_mesh(self.sdf_network_lod1,
                                   self.sdf_renderer_lod1.extract_geometry,
                                   conditional_volume=con_volume_lod1, lod=1,
                                   mode='train_bg', meta=meta,
                                   iter_step=iter_step, scale_mat=scale_mat,
                                   trans_mat=trans_mat)

        losses = {
            # - lod 0
            'loss_lod0': loss_lod0,
            'losses_lod0': losses_lod0,
            'depth_statis_lod0': depth_statis_lod0,
            # - lod 1
            'loss_lod1': loss_lod1,
            'losses_lod1': losses_lod1,
            'depth_statis_lod1': depth_statis_lod1,
        }

        return losses

    def val_step(self, sample,
                 perturb_overwrite=-1,
                 background_rgb=None,
                 alpha_inter_ratio_lod0=0.0,
                 alpha_inter_ratio_lod1=0.0,
                 iter_step=0,
                 chunk_size=512,
                 save_vis=False,
                 ):
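        """
        Validation: rebuild the conditional volumes without gradients, render
        the query view in chunks of `chunk_size` rays, save image / depth /
        normal visualizations, and periodically extract meshes.
        Assumes batch_size == 1.
        """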
        batch_idx = sample['batch_idx'][0]
        meta = sample['meta'][batch_idx]

        sizeW = sample['img_wh'][0][0]
        sizeH = sample['img_wh'][0][1]
        H, W = sizeH, sizeW

        partial_vol_origin = sample['partial_vol_origin']
        near, far = sample['query_near_far'][0, :1], sample['query_near_far'][0, 1:]

        # - the ray batch of the query view
        sample_rays = sample['rays']
        rays_o = sample_rays['rays_o'][0]
        rays_d = sample_rays['rays_v'][0]
        rays_ndc_uv = sample_rays['rays_ndc_uv'][0]

        imgs = sample['images'][0]
        intrinsics = sample['intrinsics'][0]
        intrinsics_l_4x = intrinsics.clone()
        intrinsics_l_4x[:, :2] *= 0.25
        w2cs = sample['w2cs'][0]
        c2ws = sample['c2ws'][0]
        proj_matrices = sample['affine_mats']

        scale_mat = sample['scale_mat']
        trans_mat = sample['trans_mat']
        query_c2w = sample['query_c2w']
        query_w2c = sample['query_w2c']
        true_img = sample['query_image'][0]
        true_img = np.uint8(true_img.permute(1, 2, 0).cpu().numpy() * 255)

        depth_min, depth_max = near.cpu().numpy(), far.cpu().numpy()

        scale_factor = sample['scale_factor'][0].cpu().numpy()
        true_depth = sample['query_depth'] if 'query_depth' in sample.keys() else None
        if true_depth is not None:
            true_depth = true_depth[0].cpu().numpy()
            true_depth_colored = visualize_depth_numpy(true_depth, [depth_min, depth_max])[0]
        else:
            true_depth_colored = None

        # Render the query view in chunks to bound peak memory.
        rays_o = rays_o.reshape(-1, 3).split(chunk_size)
        rays_d = rays_d.reshape(-1, 3).split(chunk_size)

        # - obtain conditional features once for the whole view
        with torch.no_grad():
            geometry_feature_maps = self.obtain_pyramid_feature_maps(imgs, lod=0)

            # - lod 0 (all views are source views at validation time)
            conditional_features_lod0 = self.sdf_network_lod0.get_conditional_volume(
                feature_maps=geometry_feature_maps[None, :, :, :, :],
                partial_vol_origin=partial_vol_origin,
                proj_mats=proj_matrices,
                sizeH=sizeH,
                sizeW=sizeW,
                lod=0,
            )

        con_volume_lod0 = conditional_features_lod0['dense_volume_scale0']
        con_valid_mask_volume_lod0 = conditional_features_lod0['valid_mask_volume_scale0']
        coords_lod0 = conditional_features_lod0['coords_scale0']

        if self.num_lods > 1:
            sdf_volume_lod0 = self.sdf_network_lod0.get_sdf_volume(
                con_volume_lod0, con_valid_mask_volume_lod0,
                coords_lod0, partial_vol_origin)

        depth_maps_lod0, depth_masks_lod0 = None, None
        if self.prune_depth_filter:
            depth_maps_lod0_l4x, depth_masks_lod0_l4x = self.sdf_renderer_lod0.extract_depth_maps(
                self.sdf_network_lod0, sdf_volume_lod0,
                intrinsics_l_4x, c2ws,
                sizeH // 4, sizeW // 4, near * 1.5, far)
            depth_maps_lod0 = F.interpolate(depth_maps_lod0_l4x, size=(sizeH, sizeW), mode='bilinear',
                                            align_corners=True)
            depth_masks_lod0 = F.interpolate(depth_masks_lod0_l4x.float(), size=(sizeH, sizeW), mode='nearest')

            # Visualize the extracted lod0 depth maps (kept inside this branch
            # because depth_maps_lod0 is None otherwise).
            colored_depth_maps_lod0 = []
            for i in range(depth_maps_lod0.shape[0]):
                colored_depth_maps_lod0.append(
                    visualize_depth_numpy(depth_maps_lod0[i, 0].cpu().numpy(), [depth_min, depth_max])[0])

            colored_depth_maps_lod0 = np.concatenate(colored_depth_maps_lod0, axis=0).astype(np.uint8)
            os.makedirs(os.path.join(self.base_exp_dir, 'depth_maps_lod0'), exist_ok=True)
            cv.imwrite(os.path.join(self.base_exp_dir, 'depth_maps_lod0',
                                    '{:0>8d}_{}.png'.format(iter_step, meta)),
                       colored_depth_maps_lod0[:, :, ::-1])

        if self.num_lods > 1:
            geometry_feature_maps_lod1 = self.obtain_pyramid_feature_maps(imgs, lod=1)

            if self.prune_depth_filter:
                pre_coords, pre_feats = self.sdf_renderer_lod0.get_valid_sparse_coords_by_sdf_depthfilter(
                    sdf_volume_lod0[0], coords_lod0[0], con_valid_mask_volume_lod0[0], con_volume_lod0[0],
                    depth_maps_lod0, proj_matrices[0],
                    partial_vol_origin, self.sdf_network_lod0.voxel_size,
                    near, far, self.sdf_network_lod0.voxel_size, 12)
            else:
                pre_coords, pre_feats = self.sdf_renderer_lod0.get_valid_sparse_coords_by_sdf(
                    sdf_volume_lod0[0], coords_lod0[0], con_valid_mask_volume_lod0[0], con_volume_lod0[0])

            pre_coords[:, 1:] = pre_coords[:, 1:] * 2

            with torch.no_grad():
                conditional_features_lod1 = self.sdf_network_lod1.get_conditional_volume(
                    feature_maps=geometry_feature_maps_lod1[None, :, :, :, :],
                    partial_vol_origin=partial_vol_origin,
                    proj_mats=proj_matrices,
                    sizeH=sizeH,
                    sizeW=sizeW,
                    pre_coords=pre_coords,
                    pre_feats=pre_feats,
                )

            con_volume_lod1 = conditional_features_lod1['dense_volume_scale1']
            con_valid_mask_volume_lod1 = conditional_features_lod1['valid_mask_volume_scale1']

        out_rgb_fine = []
        out_normal_fine = []
        out_depth_fine = []

        out_rgb_fine_lod1 = []
        out_normal_fine_lod1 = []
        out_depth_fine_lod1 = []

        if save_vis:
            # - lod 0
            for rays_o_batch, rays_d_batch in zip(rays_o, rays_d):
                render_out = self.sdf_renderer_lod0.render(
                    rays_o_batch, rays_d_batch, near, far,
                    self.sdf_network_lod0,
                    self.rendering_network_lod0,
                    background_rgb=background_rgb,
                    alpha_inter_ratio=alpha_inter_ratio_lod0,
                    lod=0,
                    conditional_volume=con_volume_lod0,
                    conditional_valid_mask_volume=con_valid_mask_volume_lod0,
                    feature_maps=geometry_feature_maps,
                    color_maps=imgs,
                    w2cs=w2cs,
                    intrinsics=intrinsics,
                    img_wh=[sizeW, sizeH],
                    query_c2w=query_c2w,
                    if_render_with_grad=False,
                )

                feasible = lambda key: ((key in render_out) and (render_out[key] is not None))

                if feasible('depth'):
                    out_depth_fine.append(render_out['depth'].detach().cpu().numpy())
                if feasible('color_fine'):
                    out_rgb_fine.append(render_out['color_fine'].detach().cpu().numpy())
                if feasible('gradients') and feasible('weights'):
                    # Accumulate the weighted SDF gradients into an expected surface normal.
                    n_samples = self.n_samples_lod0 + self.n_importance_lod0
                    normals = render_out['gradients'] * render_out['weights'][:, :n_samples, None]
                    if render_out['inside_sphere'] is not None:
                        normals = normals * render_out['inside_sphere'][..., None]
                    out_normal_fine.append(normals.sum(dim=1).detach().cpu().numpy())
                del render_out

            # - lod 1
            if self.num_lods > 1:
                for rays_o_batch, rays_d_batch in zip(rays_o, rays_d):
                    render_out_lod1 = self.sdf_renderer_lod1.render(
                        rays_o_batch, rays_d_batch, near, far,
                        self.sdf_network_lod1,
                        self.rendering_network_lod1,
                        background_rgb=background_rgb,
                        alpha_inter_ratio=alpha_inter_ratio_lod1,
                        lod=1,
                        conditional_volume=con_volume_lod1,
                        conditional_valid_mask_volume=con_valid_mask_volume_lod1,
                        feature_maps=geometry_feature_maps_lod1,
                        color_maps=imgs,
                        w2cs=w2cs,
                        intrinsics=intrinsics,
                        img_wh=[sizeW, sizeH],
                        query_c2w=query_c2w,
                        if_render_with_grad=False,
                    )

                    feasible = lambda key: ((key in render_out_lod1) and (render_out_lod1[key] is not None))

                    if feasible('depth'):
                        out_depth_fine_lod1.append(render_out_lod1['depth'].detach().cpu().numpy())
                    if feasible('color_fine'):
                        out_rgb_fine_lod1.append(render_out_lod1['color_fine'].detach().cpu().numpy())
                    if feasible('gradients') and feasible('weights'):
                        n_samples = self.n_samples_lod1 + self.n_importance_lod1
                        normals = render_out_lod1['gradients'] * render_out_lod1['weights'][:, :n_samples, None]
                        if render_out_lod1['inside_sphere'] is not None:
                            normals = normals * render_out_lod1['inside_sphere'][..., None]
                        out_normal_fine_lod1.append(normals.sum(dim=1).detach().cpu().numpy())
                    del render_out_lod1

            # - save visualization of lod 0
            self.save_visualization(true_img, true_depth_colored, out_depth_fine, out_normal_fine,
                                    query_w2c[0], out_rgb_fine, H, W,
                                    depth_min, depth_max, iter_step, meta, "val_lod0",
                                    true_depth=true_depth, scale_factor=scale_factor)

            if self.num_lods > 1:
                self.save_visualization(true_img, true_depth_colored, out_depth_fine_lod1, out_normal_fine_lod1,
                                        query_w2c[0], out_rgb_fine_lod1, H, W,
                                        depth_min, depth_max, iter_step, meta, "val_lod1",
                                        true_depth=true_depth, scale_factor=scale_factor)

        # - extract mesh
        if iter_step % self.val_mesh_freq == 0:
            torch.cuda.empty_cache()
            self.validate_mesh(self.sdf_network_lod0,
                               self.sdf_renderer_lod0.extract_geometry,
                               conditional_volume=con_volume_lod0, lod=0,
                               threshold=0,
                               mode='val_bg', meta=meta,
                               iter_step=iter_step, scale_mat=scale_mat, trans_mat=trans_mat)
            torch.cuda.empty_cache()

            if self.num_lods > 1:
                self.validate_mesh(self.sdf_network_lod1,
                                   self.sdf_renderer_lod1.extract_geometry,
                                   conditional_volume=con_volume_lod1, lod=1,
                                   mode='val_bg', meta=meta,
                                   iter_step=iter_step, scale_mat=scale_mat, trans_mat=trans_mat)

            torch.cuda.empty_cache()

    def export_mesh_step(self, sample,
                         perturb_overwrite=-1,
                         background_rgb=None,
                         alpha_inter_ratio_lod0=0.0,
                         alpha_inter_ratio_lod1=0.0,
                         iter_step=0,
                         chunk_size=512,
                         save_vis=False,
                         ):
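        """
        Export colored meshes: rebuild the conditional volumes without
        gradients and run validate_colored_mesh per LoD; unlike val_step, no
        per-ray rendering is performed here.
        """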
        batch_idx = sample['batch_idx'][0]
        meta = sample['meta'][batch_idx]

        sizeW = sample['img_wh'][0][0]
        sizeH = sample['img_wh'][0][1]
        H, W = sizeH, sizeW

        partial_vol_origin = sample['partial_vol_origin']
        near, far = sample['query_near_far'][0, :1], sample['query_near_far'][0, 1:]

        # - the ray batch of the query view
        sample_rays = sample['rays']
        rays_o = sample_rays['rays_o'][0]
        rays_d = sample_rays['rays_v'][0]
        rays_ndc_uv = sample_rays['rays_ndc_uv'][0]

        imgs = sample['images'][0]
        intrinsics = sample['intrinsics'][0]
        intrinsics_l_4x = intrinsics.clone()
        intrinsics_l_4x[:, :2] *= 0.25
        w2cs = sample['w2cs'][0]
        c2ws = sample['c2ws'][0]

        proj_matrices = sample['affine_mats']

        scale_mat = sample['scale_mat']
        trans_mat = sample['trans_mat']
        query_c2w = sample['query_c2w']
        query_w2c = sample['query_w2c']
        true_img = sample['query_image'][0]
        true_img = np.uint8(true_img.permute(1, 2, 0).cpu().numpy() * 255)

        depth_min, depth_max = near.cpu().numpy(), far.cpu().numpy()

        scale_factor = sample['scale_factor'][0].cpu().numpy()
        true_depth = sample['query_depth'] if 'query_depth' in sample.keys() else None
        if true_depth is not None:
            true_depth = true_depth[0].cpu().numpy()
            true_depth_colored = visualize_depth_numpy(true_depth, [depth_min, depth_max])[0]
        else:
            true_depth_colored = None

        rays_o = rays_o.reshape(-1, 3).split(chunk_size)
        rays_d = rays_d.reshape(-1, 3).split(chunk_size)

        # - obtain conditional features
        with torch.no_grad():
            geometry_feature_maps = self.obtain_pyramid_feature_maps(imgs, lod=0)

            # - lod 0
            conditional_features_lod0 = self.sdf_network_lod0.get_conditional_volume(
                feature_maps=geometry_feature_maps[None, :, :, :, :],
                partial_vol_origin=partial_vol_origin,
                proj_mats=proj_matrices,
                sizeH=sizeH,
                sizeW=sizeW,
                lod=0,
            )

        con_volume_lod0 = conditional_features_lod0['dense_volume_scale0']
        con_valid_mask_volume_lod0 = conditional_features_lod0['valid_mask_volume_scale0']
        coords_lod0 = conditional_features_lod0['coords_scale0']

        if self.num_lods > 1:
            sdf_volume_lod0 = self.sdf_network_lod0.get_sdf_volume(
                con_volume_lod0, con_valid_mask_volume_lod0,
                coords_lod0, partial_vol_origin)

        depth_maps_lod0, depth_masks_lod0 = None, None

        if self.num_lods > 1:
            geometry_feature_maps_lod1 = self.obtain_pyramid_feature_maps(imgs, lod=1)

            if self.prune_depth_filter:
                # note: extract_depth_maps is not run in this step, so
                # depth_maps_lod0 is still None when it reaches the filter.
                pre_coords, pre_feats = self.sdf_renderer_lod0.get_valid_sparse_coords_by_sdf_depthfilter(
                    sdf_volume_lod0[0], coords_lod0[0], con_valid_mask_volume_lod0[0], con_volume_lod0[0],
                    depth_maps_lod0, proj_matrices[0],
                    partial_vol_origin, self.sdf_network_lod0.voxel_size,
                    near, far, self.sdf_network_lod0.voxel_size, 12)
            else:
                pre_coords, pre_feats = self.sdf_renderer_lod0.get_valid_sparse_coords_by_sdf(
                    sdf_volume_lod0[0], coords_lod0[0], con_valid_mask_volume_lod0[0], con_volume_lod0[0])

            pre_coords[:, 1:] = pre_coords[:, 1:] * 2

            with torch.no_grad():
                conditional_features_lod1 = self.sdf_network_lod1.get_conditional_volume(
                    feature_maps=geometry_feature_maps_lod1[None, :, :, :, :],
                    partial_vol_origin=partial_vol_origin,
                    proj_mats=proj_matrices,
                    sizeH=sizeH,
                    sizeW=sizeW,
                    pre_coords=pre_coords,
                    pre_feats=pre_feats,
                )

            con_volume_lod1 = conditional_features_lod1['dense_volume_scale1']
            con_valid_mask_volume_lod1 = conditional_features_lod1['valid_mask_volume_scale1']

        out_rgb_fine = []
        out_normal_fine = []
        out_depth_fine = []

        out_rgb_fine_lod1 = []
        out_normal_fine_lod1 = []
        out_depth_fine_lod1 = []

        # - extract colored meshes
        if iter_step % self.val_mesh_freq == 0:
            torch.cuda.empty_cache()

            self.validate_colored_mesh(
                density_or_sdf_network=self.sdf_network_lod0,
                func_extract_geometry=self.sdf_renderer_lod0.extract_geometry,
                conditional_volume=con_volume_lod0,
                conditional_valid_mask_volume=con_valid_mask_volume_lod0,
                feature_maps=geometry_feature_maps,
                color_maps=imgs,
                w2cs=w2cs,
                target_candidate_w2cs=None,
                intrinsics=intrinsics,
                rendering_network=self.rendering_network_lod0,
                rendering_projector=self.sdf_renderer_lod0.rendering_projector,
                lod=0,
                threshold=0,
                query_c2w=query_c2w,
                mode='val_bg', meta=meta,
                iter_step=iter_step, scale_mat=scale_mat, trans_mat=trans_mat
            )
            torch.cuda.empty_cache()

            if self.num_lods > 1:
                self.validate_colored_mesh(
                    density_or_sdf_network=self.sdf_network_lod1,
                    func_extract_geometry=self.sdf_renderer_lod1.extract_geometry,
                    conditional_volume=con_volume_lod1,
                    conditional_valid_mask_volume=con_valid_mask_volume_lod1,
                    feature_maps=geometry_feature_maps,
                    color_maps=imgs,
                    w2cs=w2cs,
                    target_candidate_w2cs=None,
                    intrinsics=intrinsics,
                    rendering_network=self.rendering_network_lod1,
                    rendering_projector=self.sdf_renderer_lod1.rendering_projector,
                    lod=1,
                    threshold=0,
                    query_c2w=query_c2w,
                    mode='val_bg', meta=meta,
                    iter_step=iter_step, scale_mat=scale_mat, trans_mat=trans_mat
                )
                torch.cuda.empty_cache()

    def save_visualization(self, true_img, true_colored_depth, out_depth, out_normal, w2cs, out_color, H, W,
                           depth_min, depth_max, iter_step, meta, comment, out_color_mlp=[], true_depth=None,
                           scale_factor=1.0):
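        """
        Stitch the rendered depth / color / normal chunks into full images and
        write them under base_exp_dir; also prints the PSNR of the synthesized
        color against the ground-truth query image.
        """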
        if len(out_color) > 0:
            img_fine = (np.concatenate(out_color, axis=0).reshape([H, W, 3]) * 256).clip(0, 255)

        if len(out_color_mlp) > 0:
            img_mlp = (np.concatenate(out_color_mlp, axis=0).reshape([H, W, 3]) * 256).clip(0, 255)

        if len(out_normal) > 0:
            normal_img = np.concatenate(out_normal, axis=0)
            rot = w2cs[:3, :3].detach().cpu().numpy()
            # Rotate the normals into the query camera frame and map [-1, 1] to [0, 255].
            normal_img = (np.matmul(rot[None, :, :],
                                    normal_img[:, :, None]).reshape([H, W, 3]) * 128 + 128).clip(0, 255)

        if len(out_depth) > 0:
            pred_depth = np.concatenate(out_depth, axis=0).reshape([H, W])
            pred_depth_colored = visualize_depth_numpy(pred_depth, [depth_min, depth_max])[0]

        if len(out_depth) > 0:
            os.makedirs(os.path.join(self.base_exp_dir, 'depths_' + comment), exist_ok=True)
            if true_colored_depth is not None:
                if true_depth is not None:
                    depth_error_map = np.abs(true_depth - pred_depth) * 2.0 / scale_factor
                    # [H, W] -> [H, W, 3] so it can be stitched with the color images
                    depth_error_map = np.tile(depth_error_map[:, :, None], [1, 1, 3])
                    print("meta: ", meta)
                    print("scale_factor: ", scale_factor)
                    print("depth_error_mean: ", depth_error_map.mean())
                    depth_visualized = np.concatenate(
                        [(depth_error_map * 255).astype(np.uint8), true_colored_depth, pred_depth_colored,
                         true_img], axis=1)[:, :, ::-1]
                else:
                    depth_visualized = np.concatenate(
                        [true_colored_depth, pred_depth_colored, true_img])[:, :, ::-1]
                cv.imwrite(
                    os.path.join(self.base_exp_dir, 'depths_' + comment,
                                 '{:0>8d}_{}.png'.format(iter_step, meta)), depth_visualized
                )
            else:
                cv.imwrite(
                    os.path.join(self.base_exp_dir, 'depths_' + comment,
                                 '{:0>8d}_{}.png'.format(iter_step, meta)),
                    np.concatenate(
                        [pred_depth_colored, true_img])[:, :, ::-1])

        if len(out_color) > 0:
            os.makedirs(os.path.join(self.base_exp_dir, 'synthesized_color_' + comment), exist_ok=True)
            cv.imwrite(os.path.join(self.base_exp_dir, 'synthesized_color_' + comment,
                                    '{:0>8d}_{}.png'.format(iter_step, meta)),
                       np.concatenate(
                           [img_fine, true_img])[:, :, ::-1])  # rgb -> bgr for cv.imwrite

            mse_loss = np.mean((img_fine - true_img) ** 2)
            psnr = 10 * np.log10(255 ** 2 / mse_loss)
            print("PSNR: ", psnr)

        if len(out_color_mlp) > 0:
            os.makedirs(os.path.join(self.base_exp_dir, 'synthesized_color_mlp_' + comment), exist_ok=True)
            cv.imwrite(os.path.join(self.base_exp_dir, 'synthesized_color_mlp_' + comment,
                                    '{:0>8d}_{}.png'.format(iter_step, meta)),
                       np.concatenate(
                           [img_mlp, true_img])[:, :, ::-1])

        if len(out_normal) > 0:
            os.makedirs(os.path.join(self.base_exp_dir, 'normals_' + comment), exist_ok=True)
            cv.imwrite(os.path.join(self.base_exp_dir, 'normals_' + comment,
                                    '{:0>8d}_{}.png'.format(iter_step, meta)),
                       normal_img[:, :, ::-1])

    def forward(self, sample,
                perturb_overwrite=-1,
                background_rgb=None,
                alpha_inter_ratio_lod0=0.0,
                alpha_inter_ratio_lod1=0.0,
                iter_step=0,
                mode='train',
                save_vis=False,
                ):
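        """Dispatch to train_step / val_step / export_mesh_step according to `mode`."""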
        if mode == 'train':
            return self.train_step(sample,
                                   perturb_overwrite=perturb_overwrite,
                                   background_rgb=background_rgb,
                                   alpha_inter_ratio_lod0=alpha_inter_ratio_lod0,
                                   alpha_inter_ratio_lod1=alpha_inter_ratio_lod1,
                                   iter_step=iter_step,
                                   )
        elif mode == 'val':
            import time
            begin = time.time()
            result = self.val_step(sample,
                                   perturb_overwrite=perturb_overwrite,
                                   background_rgb=background_rgb,
                                   alpha_inter_ratio_lod0=alpha_inter_ratio_lod0,
                                   alpha_inter_ratio_lod1=alpha_inter_ratio_lod1,
                                   iter_step=iter_step,
                                   save_vis=save_vis,
                                   )
            end = time.time()
            print("val_step time: ", end - begin)
            return result
        elif mode == 'export_mesh':
            import time
            begin = time.time()
            result = self.export_mesh_step(sample,
                                           perturb_overwrite=perturb_overwrite,
                                           background_rgb=background_rgb,
                                           alpha_inter_ratio_lod0=alpha_inter_ratio_lod0,
                                           alpha_inter_ratio_lod1=alpha_inter_ratio_lod1,
                                           iter_step=iter_step,
                                           save_vis=save_vis,
                                           )
            end = time.time()
            print("export mesh time: ", end - begin)
            return result

    def obtain_pyramid_feature_maps(self, imgs, lod=0):
        """
        Extract pyramid feature maps for all conditional images and fuse the
        three pyramid levels into a single full-resolution feature map.
        :param imgs: conditional images
        :return: fused feature maps
        """
        if lod == 0:
            extractor = self.pyramid_feature_network_geometry_lod0
        else:
            extractor = self.pyramid_feature_network_geometry_lod1

        pyramid_feature_maps = extractor(imgs)

        # Upsample the coarser levels so all three can be concatenated channel-wise.
        fused_feature_maps = torch.cat([
            F.interpolate(pyramid_feature_maps[0], scale_factor=4, mode='bilinear', align_corners=True),
            F.interpolate(pyramid_feature_maps[1], scale_factor=2, mode='bilinear', align_corners=True),
            pyramid_feature_maps[2]
        ], dim=1)

        return fused_feature_maps

    def cal_losses_sdf(self, render_out, sample_rays, iter_step=-1, lod=0):
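        """
        Compute the total loss (color + depth + sparse-SDF + Eikonal + fg/bg
        mask terms) and a dictionary of per-term statistics for logging.
        """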

        def get_weight(iter_step, weight):
            # Linear warm-up of a loss weight: 0 before anneal_start, then a
            # linear ramp to the full value. For lod 0 the ramp runs from
            # anneal_start to 2 * anneal_end; for lod 1 it runs from
            # anneal_end_lod1 to 2 * anneal_end_lod1.
            if lod == 1:
                anneal_start = self.anneal_end_lod1
                anneal_end = self.anneal_end_lod1 * 2
            else:
                anneal_start = self.anneal_start
                anneal_end = self.anneal_end * 2

            if iter_step < 0:
                return weight

            if anneal_end == 0.0:
                return weight
            elif iter_step < anneal_start:
                return 0.0
            else:
                return np.min(
                    [1.0,
                     (iter_step - anneal_start) / (anneal_end - anneal_start)]) * weight
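
        # Schedule example (illustrative numbers, not from any shipped config):
        # at lod 0 with anneal_start=0 and anneal_end=25000, get_weight returns
        # 0.5 * weight at step 25000 and the full weight from step 50000 on
        # (anneal_end is doubled internally).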

        rays_o = sample_rays['rays_o'][0]
        rays_d = sample_rays['rays_v'][0]
        true_rgb = sample_rays['rays_color'][0]

        if 'rays_depth' in sample_rays.keys():
            true_depth = sample_rays['rays_depth'][0]
        else:
            true_depth = None
        mask = sample_rays['rays_mask'][0]

        color_fine = render_out['color_fine']
        color_fine_mask = render_out['color_fine_mask']
        depth_pred = render_out['depth']

        variance = render_out['variance']
        cdf_fine = render_out['cdf_fine']
        weight_sum = render_out['weights_sum']

        gradient_error_fine = render_out['gradient_error_fine']

        sdf = render_out['sdf']

        color_mlp = render_out['color_mlp']
        color_mlp_mask = render_out['color_mlp_mask']

        # - color loss (masked L1) and PSNR of the general rendering branch
        if color_fine is not None:
            color_mask = color_fine_mask if color_fine_mask is not None else mask
            color_mask = color_mask[..., 0]
            color_error = (color_fine[color_mask] - true_rgb[color_mask])
            color_fine_loss = F.l1_loss(color_error, torch.zeros_like(color_error).to(color_error.device),
                                        reduction='mean')
            psnr = 20.0 * torch.log10(
                1.0 / (((color_fine[color_mask] - true_rgb[color_mask]) ** 2).mean() / (3.0)).sqrt())
        else:
            color_fine_loss = 0.
            psnr = 0.

        # - color loss and PSNR of the MLP rendering branch
        if color_mlp is not None:
            color_mlp_mask = color_mlp_mask[..., 0]
            color_error_mlp = (color_mlp[color_mlp_mask] - true_rgb[color_mlp_mask])
            color_mlp_loss = F.l1_loss(color_error_mlp,
                                       torch.zeros_like(color_error_mlp).to(color_error_mlp.device),
                                       reduction='mean')

            psnr_mlp = 20.0 * torch.log10(
                1.0 / (((color_mlp[color_mlp_mask] - true_rgb[color_mlp_mask]) ** 2).mean() / (3.0)).sqrt())
        else:
            color_mlp_loss = 0.
            psnr_mlp = 0.

        # - depth loss (per-pixel depth metrics via compute_depth_errors are skipped here)
        if true_depth is not None:
            depth_loss = self.depth_criterion(depth_pred, true_depth)
            depth_statis = None
        else:
            depth_loss = 0.
            depth_statis = None

        # - sparse loss: push the SDF of randomly sampled and rendered points
        #   away from zero so that free space stays empty
        sparse_loss_1 = torch.exp(
            -1 * torch.abs(render_out['sdf_random']) * self.sdf_decay_param).mean()
        sparse_loss_2 = torch.exp(-1 * torch.abs(sdf) * self.sdf_decay_param).mean()
        sparse_loss = (sparse_loss_1 + sparse_loss_2) / 2

        # Monitoring statistics.
        sdf_mean = torch.abs(sdf).mean()
        sparseness_1 = (torch.abs(sdf) < 0.01).to(torch.float32).mean()
        sparseness_2 = (torch.abs(sdf) < 0.02).to(torch.float32).mean()

        # - Eikonal loss
        gradient_error_loss = gradient_error_fine

        # - fg/bg mask loss is enabled only after 50k iterations
        fg_bg_weight = 0.0 if iter_step < 50000 else get_weight(iter_step, self.fg_bg_weight)

        background_loss = 0.0
        fg_bg_loss = 0.0
        if self.fg_bg_weight > 0 and torch.mean((mask < 0.5).to(torch.float32)) > 0.02:
            weights_sum_fg = render_out['weights_sum_fg']
            fg_bg_error = (weights_sum_fg - mask)[mask < 0.5]
            fg_bg_loss = F.l1_loss(fg_bg_error,
                                   torch.zeros_like(fg_bg_error).to(fg_bg_error.device),
                                   reduction='mean')

        loss = 1.0 * depth_loss + color_fine_loss + color_mlp_loss + \
               sparse_loss * get_weight(iter_step, self.sdf_sparse_weight) + \
               fg_bg_loss * fg_bg_weight + \
               gradient_error_loss * self.sdf_igr_weight

        losses = {
            "loss": loss,
            "depth_loss": depth_loss,
            "color_fine_loss": color_fine_loss,
            "color_mlp_loss": color_mlp_loss,
            "gradient_error_loss": gradient_error_loss,
            "background_loss": background_loss,
            "sparse_loss": sparse_loss,
            "sparseness_1": sparseness_1,
            "sparseness_2": sparseness_2,
            "sdf_mean": sdf_mean,
            "psnr": psnr,
            "psnr_mlp": psnr_mlp,
            "weights_sum": render_out['weights_sum'],
            "weights_sum_fg": render_out['weights_sum_fg'],
            "alpha_sum": render_out['alpha_sum'],
            "variance": render_out['variance'],
            "sparse_weight": get_weight(iter_step, self.sdf_sparse_weight),
            "fg_bg_weight": fg_bg_weight,
            "fg_bg_loss": fg_bg_loss,
        }
        # torch.tensor() cannot wrap a dict; move the scalar entries onto the
        # right device one by one and keep existing tensors as they are.
        losses = {k: v if isinstance(v, torch.Tensor) else torch.tensor(v, device=rays_o.device)
                  for k, v in losses.items()}
        return loss, losses, depth_statis

    @torch.no_grad()
    def validate_mesh(self, density_or_sdf_network, func_extract_geometry, world_space=True, resolution=360,
                      threshold=0.0, mode='val',
                      conditional_volume=None, lod=None, occupancy_mask=None,
                      bound_min=[-1, -1, -1], bound_max=[1, 1, 1], meta='', iter_step=0, scale_mat=None,
                      trans_mat=None
                      ):
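        """
        Extract a triangle mesh from the conditional volume via
        func_extract_geometry, optionally map it back to world space with
        scale_mat / trans_mat, and export it as a .ply file.
        """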
        bound_min = torch.tensor(bound_min, dtype=torch.float32)
        bound_max = torch.tensor(bound_max, dtype=torch.float32)

        vertices, triangles, fields = func_extract_geometry(
            density_or_sdf_network,
            bound_min, bound_max, resolution=resolution,
            threshold=threshold, device=conditional_volume.device,
            conditional_volume=conditional_volume, lod=lod,
            occupancy_mask=occupancy_mask
        )

        # Map the vertices from the normalized volume back to world space.
        if scale_mat is not None:
            scale_mat_np = scale_mat.cpu().numpy()
            vertices = vertices * scale_mat_np[0][0, 0] + scale_mat_np[0][:3, 3][None]

        if trans_mat is not None:
            trans_mat_np = trans_mat.cpu().numpy()
            vertices_homo = np.concatenate([vertices, np.ones_like(vertices[:, :1])], axis=1)
            vertices = np.matmul(trans_mat_np, vertices_homo[:, :, None])[:, :3, 0]

        mesh = trimesh.Trimesh(vertices, triangles)
        os.makedirs(os.path.join(self.base_exp_dir, 'meshes_' + mode), exist_ok=True)
        mesh.export(os.path.join(self.base_exp_dir, 'meshes_' + mode,
                                 'mesh_{:0>8d}_{}_lod{:0>1d}.ply'.format(iter_step, meta, lod)))

    def validate_colored_mesh(self, density_or_sdf_network, func_extract_geometry, world_space=True, resolution=360,
                              threshold=0.0, mode='val',
                              conditional_volume=None,
                              conditional_valid_mask_volume=None,
                              feature_maps=None,
                              color_maps=None,
                              w2cs=None,
                              target_candidate_w2cs=None,
                              intrinsics=None,
                              rendering_network=None,
                              rendering_projector=None,
                              query_c2w=None,
                              lod=None, occupancy_mask=None,
                              bound_min=[-1, -1, -1], bound_max=[1, 1, 1], meta='', iter_step=0, scale_mat=None,
                              trans_mat=None
                              ):
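        """
        Extract a triangle mesh, predict per-vertex colors with the rendering
        network, and export the colored mesh as a .ply file.
        """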
        bound_min = torch.tensor(bound_min, dtype=torch.float32)
        bound_max = torch.tensor(bound_max, dtype=torch.float32)

        vertices, triangles, fields = func_extract_geometry(
            density_or_sdf_network,
            bound_min, bound_max, resolution=resolution,
            threshold=threshold, device=conditional_volume.device,
            conditional_volume=conditional_volume, lod=lod,
            occupancy_mask=occupancy_mask
        )

        # Project the mesh vertices into the source views to gather geometry
        # and color features, then predict a color per vertex.
        with torch.no_grad():
            ren_geo_feats, ren_rgb_feats, ren_ray_diff, ren_mask, _, _ = rendering_projector.compute_view_independent(
                torch.tensor(vertices).to(conditional_volume),
                lod=lod,
                geometryVolume=conditional_volume[0],
                geometryVolumeMask=conditional_valid_mask_volume[0],
                sdf_network=density_or_sdf_network,
                rendering_feature_maps=feature_maps,
                color_maps=color_maps,
                w2cs=w2cs,
                target_candidate_w2cs=target_candidate_w2cs,
                intrinsics=intrinsics,
                img_wh=[256, 256],
                query_img_idx=0,
                query_c2w=query_c2w,
            )

            vertices_color, rendering_valid_mask = rendering_network(
                ren_geo_feats, ren_rgb_feats, ren_ray_diff, ren_mask)

        # Map the vertices from the normalized volume back to world space.
        if scale_mat is not None:
            scale_mat_np = scale_mat.cpu().numpy()
            vertices = vertices * scale_mat_np[0][0, 0] + scale_mat_np[0][:3, 3][None]

        if trans_mat is not None:
            trans_mat_np = trans_mat.cpu().numpy()
            vertices_homo = np.concatenate([vertices, np.ones_like(vertices[:, :1])], axis=1)
            vertices = np.matmul(trans_mat_np, vertices_homo[:, :, None])[:, :3, 0]

        vertices_color = np.array(vertices_color.squeeze(0).cpu() * 255, dtype=np.uint8)
        mesh = trimesh.Trimesh(vertices, triangles, vertex_colors=vertices_color)
        os.makedirs(os.path.join(self.base_exp_dir, 'meshes_' + mode, 'lod{:0>1d}'.format(lod)), exist_ok=True)
        mesh.export(os.path.join(self.base_exp_dir, 'meshes_' + mode, 'lod{:0>1d}'.format(lod),
                                 'mesh_{:0>8d}_gradio_lod{:0>1d}.ply'.format(iter_step, lod)))