diff --git a/.gitignore b/.gitignore index ba0430d26c996e7f078385407f959c96c271087c..208b0be95d9ed2968365b9367a95b7ca1fbe58ed 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,2 @@ -__pycache__/ \ No newline at end of file +__pycache__/ +*.DS_Store \ No newline at end of file diff --git a/SparseNeuS_demo_v1/confs/blender_general_lod1_val_new.conf b/SparseNeuS_demo_v1/confs/blender_general_lod1_val_new.conf new file mode 100644 index 0000000000000000000000000000000000000000..dacbc09968c2f4cd6f7348dd93552ea5d8876236 --- /dev/null +++ b/SparseNeuS_demo_v1/confs/blender_general_lod1_val_new.conf @@ -0,0 +1,137 @@ +# - for the lod1 geometry network, using adaptive cost for sparse cost regularization network +#- for lod1 rendering network, using depth-adaptive render + +general { + base_exp_dir = ./exp/val/1_4_only_narrow_lod1 + + recording = [ + ./, + ./data + ./ops + ./models + ./loss + ] +} + +dataset { + # local path + trainpath = /objaverse-processed/zero12345_img/eval_selected + valpath = /objaverse-processed/zero12345_img/eval_selected + testpath = /objaverse-processed/zero12345_img/eval_selected + # trainpath = /objaverse-processed/zero12345_img/zero12345_2stage_5pred_sample/ + # valpath = /objaverse-processed/zero12345_img/zero12345_2stage_5pred_sample/ + # testpath = /objaverse-processed/zero12345_img/zero12345_2stage_5pred_sample/ + imgScale_train = 1.0 + imgScale_test = 1.0 + nviews = 5 + clean_image = True + importance_sample = True + test_ref_views = [23] + + # test dataset + test_n_views = 2 + test_img_wh = [256, 256] + test_clip_wh = [0, 0] + test_scan_id = scan110 + train_img_idx = [49, 50, 52, 53, 54, 56, 58] #[21, 22, 23, 24, 25] # + test_img_idx = [51, 55, 57] #[32, 33, 34] # + + test_dir_comment = train +} + +train { + learning_rate = 2e-4 + learning_rate_milestone = [100000, 150000, 200000] + learning_rate_factor = 0.5 + end_iter = 200000 + save_freq = 5000 + val_freq = 1 + val_mesh_freq =1 + report_freq = 100 + + N_rays = 512 + + validate_resolution_level = 4 + anneal_start = 0 + anneal_end = 25000 + anneal_start_lod1 = 0 + anneal_end_lod1 = 15000 + + use_white_bkgd = True + + # Loss + # ! for training the lod1 network, don't use this regularization in first 10k steps; then use the regularization + sdf_igr_weight = 0.1 + sdf_sparse_weight = 0.02 # 0.002 for lod1 network; 0.02 for lod0 network + sdf_decay_param = 100 # cannot be too large, which decide the tsdf range + fg_bg_weight = 0.01 # first 0.01 + bg_ratio = 0.3 + + if_fix_lod0_networks = True +} + +model { + num_lods = 2 + + sdf_network_lod0 { + lod = 0, + ch_in = 56, # the channel num of fused pyramid features + voxel_size = 0.02105263, # 0.02083333, should be 2/95 + vol_dims = [96, 96, 96], + hidden_dim = 128, + cost_type = variance_mean + d_pyramid_feature_compress = 16, + regnet_d_out = 16, + num_sdf_layers = 4, + # position embedding + multires = 6 + } + + + sdf_network_lod1 { + lod = 1, + ch_in = 56, # the channel num of fused pyramid features + voxel_size = 0.0104712, #0.01041667, should be 2/191 + vol_dims = [192, 192, 192], + hidden_dim = 128, + cost_type = variance_mean + d_pyramid_feature_compress = 8, + regnet_d_out = 8, + num_sdf_layers = 4, + # position embedding + multires = 6 + } + + + variance_network { + init_val = 0.2 + } + + variance_network_lod1 { + init_val = 0.2 + } + + rendering_network { + in_geometry_feat_ch = 16 + in_rendering_feat_ch = 56 + anti_alias_pooling = True + } + + rendering_network_lod1 { + in_geometry_feat_ch = 8 + in_rendering_feat_ch = 56 + anti_alias_pooling = True + + } + + + trainer { + n_samples_lod0 = 64 + n_importance_lod0 = 64 + n_samples_lod1 = 64 + n_importance_lod1 = 64 + n_outside = 0 # 128 if render_outside_uniform_sampling + perturb = 1.0 + alpha_type = div + } +} diff --git a/SparseNeuS_demo_v1/confs/one2345_lod0_val_demo.conf b/SparseNeuS_demo_v1/confs/one2345_lod0_val_demo.conf new file mode 100644 index 0000000000000000000000000000000000000000..7be6d4098d66473f63252c42d0a1bd25e2338a6b --- /dev/null +++ b/SparseNeuS_demo_v1/confs/one2345_lod0_val_demo.conf @@ -0,0 +1,137 @@ +# - for the lod1 geometry network, using adaptive cost for sparse cost regularization network +#- for lod1 rendering network, using depth-adaptive render + +general { + + base_exp_dir = exp/lod0 # !!! where you store the results and checkpoints to be used + recording = [ + ./, + ./data + ./ops + ./models + ./loss + ] +} + +dataset { + trainpath = ../ + valpath = ../ # !!! where you store the validation data + testpath = ../ + + + + imgScale_train = 1.0 + imgScale_test = 1.0 + nviews = 5 + clean_image = True + importance_sample = True + test_ref_views = [23] + + # test dataset + test_n_views = 2 + test_img_wh = [256, 256] + test_clip_wh = [0, 0] + test_scan_id = scan110 + train_img_idx = [49, 50, 52, 53, 54, 56, 58] #[21, 22, 23, 24, 25] # + test_img_idx = [51, 55, 57] #[32, 33, 34] # + + test_dir_comment = train +} + +train { + learning_rate = 2e-4 + learning_rate_milestone = [100000, 150000, 200000] + learning_rate_factor = 0.5 + end_iter = 200000 + save_freq = 5000 + val_freq = 1 + val_mesh_freq = 1 + report_freq = 100 + + N_rays = 512 + + validate_resolution_level = 4 + anneal_start = 0 + anneal_end = 25000 + anneal_start_lod1 = 0 + anneal_end_lod1 = 15000 + + use_white_bkgd = True + + # Loss + # ! for training the lod1 network, don't use this regularization in first 10k steps; then use the regularization + sdf_igr_weight = 0.1 + sdf_sparse_weight = 0.02 # 0.002 for lod1 network; 0.02 for lod0 network + sdf_decay_param = 100 # cannot be too large, which decide the tsdf range + fg_bg_weight = 0.01 # first 0.01 + bg_ratio = 0.3 + + if_fix_lod0_networks = False +} + +model { + num_lods = 1 + + sdf_network_lod0 { + lod = 0, + ch_in = 56, # the channel num of fused pyramid features + voxel_size = 0.02105263, # 0.02083333, should be 2/95 + vol_dims = [96, 96, 96], + hidden_dim = 128, + cost_type = variance_mean + d_pyramid_feature_compress = 16, + regnet_d_out = 16, + num_sdf_layers = 4, + # position embedding + multires = 6 + } + + + sdf_network_lod1 { + lod = 1, + ch_in = 56, # the channel num of fused pyramid features + voxel_size = 0.0104712, #0.01041667, should be 2/191 + vol_dims = [192, 192, 192], + hidden_dim = 128, + cost_type = variance_mean + d_pyramid_feature_compress = 8, + regnet_d_out = 16, + num_sdf_layers = 4, + + # position embedding + multires = 6 + } + + + variance_network { + init_val = 0.2 + } + + variance_network_lod1 { + init_val = 0.2 + } + + rendering_network { + in_geometry_feat_ch = 16 + in_rendering_feat_ch = 56 + anti_alias_pooling = True + } + + rendering_network_lod1 { + in_geometry_feat_ch = 16 # default 8 + in_rendering_feat_ch = 56 + anti_alias_pooling = True + + } + + + trainer { + n_samples_lod0 = 64 + n_importance_lod0 = 64 + n_samples_lod1 = 64 + n_importance_lod1 = 64 + n_outside = 0 # 128 if render_outside_uniform_sampling + perturb = 1.0 + alpha_type = div + } +} diff --git a/SparseNeuS_demo_v1/data/__init__.py b/SparseNeuS_demo_v1/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/SparseNeuS_demo_v1/data/blender.py b/SparseNeuS_demo_v1/data/blender.py new file mode 100644 index 0000000000000000000000000000000000000000..c027f3e05367497c91026b362af4378fe31ff24a --- /dev/null +++ b/SparseNeuS_demo_v1/data/blender.py @@ -0,0 +1,340 @@ +import torch +from torch.utils.data import Dataset +import json +import numpy as np +import os +from PIL import Image +from torchvision import transforms as T +from kornia import create_meshgrid +from models.rays import gen_rays_from_single_image, gen_random_rays_from_single_image +import cv2 as cv +from data.scene import get_boundingbox + + +def get_ray_directions(H, W, focal, center=None): + """ + Get ray directions for all pixels in camera coordinate. + Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/ + ray-tracing-generating-camera-rays/standard-coordinate-systems + Inputs: + H, W, focal: image height, width and focal length + Outputs: + directions: (H, W, 3), the direction of the rays in camera coordinate + """ + grid = create_meshgrid(H, W, normalized_coordinates=False)[0] + i, j = grid.unbind(-1) + # the direction here is without +0.5 pixel centering as calibration is not so accurate + # see https://github.com/bmild/nerf/issues/24 + cent = center if center is not None else [W / 2, H / 2] + directions = torch.stack([(i - cent[0]) / focal[0], (j - cent[1]) / focal[1], torch.ones_like(i)], -1) # (H, W, 3) + + return directions + +def get_rays(directions, c2w): + """ + Get ray origin and normalized directions in world coordinate for all pixels in one image. + Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/ + ray-tracing-generating-camera-rays/standard-coordinate-systems + Inputs: + directions: (H, W, 3) precomputed ray directions in camera coordinate + c2w: (3, 4) transformation matrix from camera coordinate to world coordinate + Outputs: + rays_o: (H*W, 3), the origin of the rays in world coordinate + rays_d: (H*W, 3), the normalized direction of the rays in world coordinate + """ + # Rotate ray directions from camera coordinate to the world coordinate + rays_d = directions @ c2w[:3, :3].T # (H, W, 3) + # rays_d = rays_d / torch.norm(rays_d, dim=-1, keepdim=True) + # The origin of all rays is the camera origin in world coordinate + rays_o = c2w[:3, 3].expand(rays_d.shape) # (H, W, 3) + + rays_d = rays_d.view(-1, 3) + rays_o = rays_o.view(-1, 3) + + return rays_o, rays_d + + +def load_K_Rt_from_P(filename, P=None): + if P is None: + lines = open(filename).read().splitlines() + if len(lines) == 4: + lines = lines[1:] + lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)] + P = np.asarray(lines).astype(np.float32).squeeze() + + out = cv.decomposeProjectionMatrix(P) + K = out[0] + R = out[1] + t = out[2] + + K = K / K[2, 2] + intrinsics = np.eye(4) + intrinsics[:3, :3] = K + + pose = np.eye(4, dtype=np.float32) + pose[:3, :3] = R.transpose() # ? why need transpose here + pose[:3, 3] = (t[:3] / t[3])[:, 0] + + return intrinsics, pose # ! return cam2world matrix here + + +class BlenderDataset(Dataset): + def __init__(self, root_dir, split, scan_id, n_views, train_img_idx=[], test_img_idx=[], + img_wh=[800, 800], clip_wh=[0, 0], original_img_wh=[800, 800], + N_rays=512, h_patch_size=5, near=2.0, far=6.0): + self.root_dir = root_dir + self.split = split + self.img_wh = img_wh + self.clip_wh = clip_wh + self.define_transforms() + self.train_img_idx = train_img_idx + self.test_img_idx = test_img_idx + self.N_rays = N_rays + self.h_patch_size = h_patch_size # used to extract patch for supervision + self.n_views = n_views + self.near, self.far = near, far + self.blender2opencv = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]) + + with open(os.path.join(self.root_dir, f"transforms_{self.split}.json"), 'r') as f: + self.meta = json.load(f) + + + self.read_meta(near, far) + # import ipdb; ipdb.set_trace() + self.raw_near_fars = np.stack([np.array([self.near, self.far]) for i in range(len(self.meta['frames']))]) + + + # ! estimate scale_mat + self.scale_mat, self.scale_factor = self.cal_scale_mat( + img_hw=[self.img_wh[1], self.img_wh[0]], + intrinsics=self.all_intrinsics[self.train_img_idx], + extrinsics=self.all_w2cs[self.train_img_idx], + near_fars=self.raw_near_fars[self.train_img_idx], + factor=1.1) + # self.scale_mat = np.eye(4) + # self.scale_factor = 1.0 + # import ipdb; ipdb.set_trace() + # * after scaling and translation, unit bounding box + self.scaled_intrinsics, self.scaled_w2cs, self.scaled_c2ws, \ + self.scaled_affine_mats, self.scaled_near_fars = self.scale_cam_info() + + self.bbox_min = np.array([-1.0, -1.0, -1.0]) + self.bbox_max = np.array([1.0, 1.0, 1.0]) + self.partial_vol_origin = torch.Tensor([-1., -1., -1.]) + self.white_back = True + + def read_meta(self, near=2.0, far=6.0): + + + self.ref_img_idx = self.train_img_idx[0] + ref_c2w = np.array(self.meta['frames'][self.ref_img_idx]['transform_matrix']) @ self.blender2opencv + # ref_c2w = torch.FloatTensor(ref_c2w) + self.ref_c2w = ref_c2w + self.ref_w2c = np.linalg.inv(ref_c2w) + + + w, h = self.img_wh + self.focal = 0.5 * 800 / np.tan(0.5 * self.meta['camera_angle_x']) # original focal length + self.focal *= self.img_wh[0] / 800 # modify focal length to match size self.img_wh + + # bounds, common for all scenes + self.near = near + self.far = far + self.bounds = np.array([self.near, self.far]) + + # ray directions for all pixels, same for all images (same H, W, focal) + self.directions = get_ray_directions(h, w, [self.focal,self.focal]) # (h, w, 3) + intrinsics = np.eye(4) + intrinsics[:3, :3] = np.array([[self.focal,0,w/2],[0,self.focal,h/2],[0,0,1]]).astype(np.float32) + self.intrinsics = intrinsics + + self.image_paths = [] + self.poses = [] + self.all_rays = [] + self.all_images = [] + self.all_masks = [] + self.all_w2cs = [] + self.all_intrinsics = [] + for frame in self.meta['frames']: + pose = np.array(frame['transform_matrix']) @ self.blender2opencv + self.poses += [pose] + c2w = torch.FloatTensor(pose) + w2c = np.linalg.inv(c2w) + image_path = os.path.join(self.root_dir, f"{frame['file_path']}.png") + self.image_paths += [image_path] + img = Image.open(image_path) + img = img.resize(self.img_wh, Image.LANCZOS) + img = self.transform(img) # (4, h, w) + + self.all_masks += [img[-1:,:]>0] + # img = img[:3, :] * img[ -1:,:] + (1 - img[-1:, :]) # blend A to RGB + img = img[:3, :] * img[ -1:,:] + img = img.numpy() # (3, h, w) + self.all_images += [img] + + + self.all_masks += [] + self.all_intrinsics.append(self.intrinsics) + # - transform from world system to ref-camera system + self.all_w2cs.append(w2c @ np.linalg.inv(self.ref_w2c)) + + self.all_images = torch.from_numpy(np.stack(self.all_images)).to(torch.float32) + self.all_intrinsics = torch.from_numpy(np.stack(self.all_intrinsics)).to(torch.float32) + self.all_w2cs = torch.from_numpy(np.stack(self.all_w2cs)).to(torch.float32) + # self.img_wh = [self.img_wh[0] - self.clip_wh[0] - self.clip_wh[2], + # self.img_wh[1] - self.clip_wh[1] - self.clip_wh[3]] + + def cal_scale_mat(self, img_hw, intrinsics, extrinsics, near_fars, factor=1.): + center, radius, _ = get_boundingbox(img_hw, intrinsics, extrinsics, near_fars) + radius = radius * factor + scale_mat = np.diag([radius, radius, radius, 1.0]) + scale_mat[:3, 3] = center.cpu().numpy() + scale_mat = scale_mat.astype(np.float32) + + return scale_mat, 1. / radius.cpu().numpy() + + def scale_cam_info(self): + new_intrinsics = [] + new_near_fars = [] + new_w2cs = [] + new_c2ws = [] + new_affine_mats = [] + for idx in range(len(self.all_images)): + + intrinsics = self.all_intrinsics[idx] + # import ipdb; ipdb.set_trace() + P = intrinsics @ self.all_w2cs[idx] @ self.scale_mat + P = P.cpu().numpy()[:3, :4] + + # - should use load_K_Rt_from_P() to obtain c2w + c2w = load_K_Rt_from_P(None, P)[1] + w2c = np.linalg.inv(c2w) + new_w2cs.append(w2c) + new_c2ws.append(c2w) + new_intrinsics.append(intrinsics) + affine_mat = np.eye(4) + affine_mat[:3, :4] = intrinsics[:3, :3] @ w2c[:3, :4] + new_affine_mats.append(affine_mat) + + camera_o = c2w[:3, 3] + dist = np.sqrt(np.sum(camera_o ** 2)) + near = dist - 1 + far = dist + 1 + + new_near_fars.append([0.95 * near, 1.05 * far]) + + new_intrinsics, new_w2cs, new_c2ws, new_affine_mats, new_near_fars = \ + np.stack(new_intrinsics), np.stack(new_w2cs), np.stack(new_c2ws), \ + np.stack(new_affine_mats), np.stack(new_near_fars) + + new_intrinsics = torch.from_numpy(np.float32(new_intrinsics)) + new_w2cs = torch.from_numpy(np.float32(new_w2cs)) + new_c2ws = torch.from_numpy(np.float32(new_c2ws)) + new_affine_mats = torch.from_numpy(np.float32(new_affine_mats)) + new_near_fars = torch.from_numpy(np.float32(new_near_fars)) + + return new_intrinsics, new_w2cs, new_c2ws, new_affine_mats, new_near_fars + + def load_poses_all(self, file=f"transforms_train.json"): + with open(os.path.join(self.root_dir, file), 'r') as f: + meta = json.load(f) + + c2ws = [] + for i,frame in enumerate(meta['frames']): + c2ws.append(np.array(frame['transform_matrix']) @ self.blender2opencv) + return np.stack(c2ws) + + def define_transforms(self): + self.transform = T.ToTensor() + + + + def get_conditional_sample(self): + sample = {} + support_idxs = self.train_img_idx + + sample['images'] = self.all_images[support_idxs] # (V, 3, H, W) + sample['w2cs'] = self.scaled_w2cs[self.train_img_idx] # (V, 4, 4) + sample['c2ws'] = self.scaled_c2ws[self.train_img_idx] # (V, 4, 4) + sample['near_fars'] = self.scaled_near_fars[self.train_img_idx] # (V, 2) + sample['intrinsics'] = self.scaled_intrinsics[self.train_img_idx][:, :3, :3] # (V, 3, 3) + sample['affine_mats'] = self.scaled_affine_mats[self.train_img_idx] # ! in world space + + # sample['scan'] = self.scan_id + sample['scale_factor'] = torch.tensor(self.scale_factor) + sample['scale_mat'] = torch.from_numpy(self.scale_mat) + sample['trans_mat'] = torch.from_numpy(np.linalg.inv(self.ref_w2c)) + sample['img_wh'] = torch.from_numpy(np.array(self.img_wh)) + sample['partial_vol_origin'] = torch.tensor(self.partial_vol_origin, dtype=torch.float32) + + return sample + + + + def __len__(self): + if self.split == 'train': + return self.n_views * 1000 + else: + return len(self.test_img_idx) * 1000 + + + def __getitem__(self, idx): + sample = {} + + if self.split == 'train': + render_idx = self.train_img_idx[idx % self.n_views] + support_idxs = [idx for idx in self.train_img_idx if idx != render_idx] + else: + # render_idx = idx % self.n_test_images + self.n_train_images + render_idx = self.test_img_idx[idx % len(self.test_img_idx)] + support_idxs = [render_idx] + + sample['images'] = self.all_images[support_idxs] # (V, 3, H, W) + sample['w2cs'] = self.scaled_w2cs[support_idxs] # (V, 4, 4) + sample['c2ws'] = self.scaled_c2ws[support_idxs] # (V, 4, 4) + sample['intrinsics'] = self.scaled_intrinsics[support_idxs][:, :3, :3] # (V, 3, 3) + sample['affine_mats'] = self.scaled_affine_mats[support_idxs] # ! in world space + # sample['scan'] = self.scan_id + sample['scale_factor'] = torch.tensor(self.scale_factor) + sample['img_wh'] = torch.from_numpy(np.array(self.img_wh)) + sample['partial_vol_origin'] = torch.tensor(self.partial_vol_origin, dtype=torch.float32) + sample['img_index'] = torch.tensor(render_idx) + + # - query image + sample['query_image'] = self.all_images[render_idx] + sample['query_c2w'] = self.scaled_c2ws[render_idx] + sample['query_w2c'] = self.scaled_w2cs[render_idx] + sample['query_intrinsic'] = self.scaled_intrinsics[render_idx] + sample['query_near_far'] = self.scaled_near_fars[render_idx] + # sample['meta'] = str(self.scan_id) + "_" + os.path.basename(self.images_list[render_idx]) + sample['scale_mat'] = torch.from_numpy(self.scale_mat) + sample['trans_mat'] = torch.from_numpy(np.linalg.inv(self.ref_w2c)) + sample['rendering_c2ws'] = self.scaled_c2ws[self.test_img_idx] + sample['rendering_imgs_idx'] = torch.Tensor(np.array(self.test_img_idx).astype(np.int32)) + + # - generate rays + if self.split == 'val' or self.split == 'test': + sample_rays = gen_rays_from_single_image( + self.img_wh[1], self.img_wh[0], + sample['query_image'], + sample['query_intrinsic'], + sample['query_c2w'], + depth=None, + mask=None) + else: + sample_rays = gen_random_rays_from_single_image( + self.img_wh[1], self.img_wh[0], + self.N_rays, + sample['query_image'], + sample['query_intrinsic'], + sample['query_c2w'], + depth=None, + mask=None, + dilated_mask=None, + importance_sample=False, + h_patch_size=self.h_patch_size + ) + + sample['rays'] = sample_rays + + return sample \ No newline at end of file diff --git a/SparseNeuS_demo_v1/data/blender_general.py b/SparseNeuS_demo_v1/data/blender_general.py new file mode 100644 index 0000000000000000000000000000000000000000..871bcd6e9e2542110213e34ac5e7bde97184d938 --- /dev/null +++ b/SparseNeuS_demo_v1/data/blender_general.py @@ -0,0 +1,432 @@ +from torch.utils.data import Dataset +from utils.misc_utils import read_pfm +import os +import numpy as np +import cv2 +from PIL import Image +import torch +from torchvision import transforms as T +from data.scene import get_boundingbox + +from models.rays import gen_rays_from_single_image, gen_random_rays_from_single_image +import json +from termcolor import colored +import imageio +from kornia import create_meshgrid +import open3d as o3d +def get_ray_directions(H, W, focal, center=None): + """ + Get ray directions for all pixels in camera coordinate. + Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/ + ray-tracing-generating-camera-rays/standard-coordinate-systems + Inputs: + H, W, focal: image height, width and focal length + Outputs: + directions: (H, W, 3), the direction of the rays in camera coordinate + """ + grid = create_meshgrid(H, W, normalized_coordinates=False)[0] + 0.5 # 1xHxWx2 + + i, j = grid.unbind(-1) + # the direction here is without +0.5 pixel centering as calibration is not so accurate + # see https://github.com/bmild/nerf/issues/24 + cent = center if center is not None else [W / 2, H / 2] + directions = torch.stack([(i - cent[0]) / focal[0], (j - cent[1]) / focal[1], torch.ones_like(i)], -1) # (H, W, 3) + + return directions + +def load_K_Rt_from_P(filename, P=None): + if P is None: + lines = open(filename).read().splitlines() + if len(lines) == 4: + lines = lines[1:] + lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)] + P = np.asarray(lines).astype(np.float32).squeeze() + + out = cv2.decomposeProjectionMatrix(P) + K = out[0] + R = out[1] + t = out[2] + + K = K / K[2, 2] + intrinsics = np.eye(4) + intrinsics[:3, :3] = K + + pose = np.eye(4, dtype=np.float32) + pose[:3, :3] = R.transpose() # ? why need transpose here + pose[:3, 3] = (t[:3] / t[3])[:, 0] + + return intrinsics, pose # ! return cam2world matrix here + + +# ! load one ref-image with multiple src-images in camera coordinate system +class BlenderPerView(Dataset): + def __init__(self, root_dir, split, n_views=3, img_wh=(256, 256), downSample=1.0, + split_filepath=None, pair_filepath=None, + N_rays=512, + vol_dims=[128, 128, 128], batch_size=1, + clean_image=False, importance_sample=False, test_ref_views=[]): + + # print("root_dir: ", root_dir) + self.root_dir = root_dir + self.split = split + + self.n_views = n_views + self.N_rays = N_rays + self.batch_size = batch_size # - used for construct new metas for gru fusion training + + self.clean_image = clean_image + self.importance_sample = importance_sample + self.test_ref_views = test_ref_views # used for testing + self.scale_factor = 1.0 + self.scale_mat = np.float32(np.diag([1, 1, 1, 1.0])) + + lvis_json_path = '/objaverse-processed/zero12345_img/lvis_split.json' # folder_id and uid + with open(lvis_json_path, 'r') as f: + lvis_paths = json.load(f) + if self.split == 'train': + self.lvis_paths = lvis_paths['train'] + else: + self.lvis_paths = lvis_paths['val'] + if img_wh is not None: + assert img_wh[0] % 32 == 0 and img_wh[1] % 32 == 0, \ + 'img_wh must both be multiples of 32!' + + + pose_json_path = "/objaverse-processed/zero12345_img/zero12345_narrow_pose.json" + with open(pose_json_path, 'r') as f: + meta = json.load(f) + + self.img_ids = list(meta["c2ws"].keys()) # e.g. "view_0", "view_7", "view_0_2_10" + self.img_wh = (256, 256) + self.input_poses = np.array(list(meta["c2ws"].values())) + intrinsic = np.eye(4) + intrinsic[:3, :3] = np.array(meta["intrinsics"]) + self.intrinsic = intrinsic + self.near_far = np.array(meta["near_far"]) + + self.define_transforms() + self.blender2opencv = np.array( + [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]] + ) + + + self.c2ws = [] + self.w2cs = [] + self.near_fars = [] + # self.root_dir = root_dir + for idx, img_id in enumerate(self.img_ids): + pose = self.input_poses[idx] + c2w = pose @ self.blender2opencv + self.c2ws.append(c2w) + self.w2cs.append(np.linalg.inv(c2w)) + self.near_fars.append(self.near_far) + self.c2ws = np.stack(self.c2ws, axis=0) + self.w2cs = np.stack(self.w2cs, axis=0) + + + self.all_intrinsics = [] # the cam info of the whole scene + self.all_extrinsics = [] + self.all_near_fars = [] + self.load_cam_info() + + # * bounding box for rendering + self.bbox_min = np.array([-1.0, -1.0, -1.0]) + self.bbox_max = np.array([1.0, 1.0, 1.0]) + + # - used for cost volume regularization + self.voxel_dims = torch.tensor(vol_dims, dtype=torch.float32) + self.partial_vol_origin = torch.tensor([-1., -1., -1.], dtype=torch.float32) + + + def define_transforms(self): + self.transform = T.Compose([T.ToTensor()]) + + + + def load_cam_info(self): + for vid, img_id in enumerate(self.img_ids): + intrinsic, extrinsic, near_far = self.intrinsic, np.linalg.inv(self.c2ws[vid]), self.near_far + self.all_intrinsics.append(intrinsic) + self.all_extrinsics.append(extrinsic) + self.all_near_fars.append(near_far) + + def read_depth(self, filename): + depth_h = np.array(read_pfm(filename)[0], dtype=np.float32) # (1200, 1600) + depth_h = cv2.resize(depth_h, None, fx=0.5, fy=0.5, + interpolation=cv2.INTER_NEAREST) # (600, 800) + depth_h = depth_h[44:556, 80:720] # (512, 640) + depth_h = cv2.resize(depth_h, None, fx=self.downSample, fy=self.downSample, + interpolation=cv2.INTER_NEAREST) + depth = cv2.resize(depth_h, None, fx=1.0 / 4, fy=1.0 / 4, + interpolation=cv2.INTER_NEAREST) + + return depth, depth_h + + def read_mask(self, filename): + mask_h = cv2.imread(filename, 0) + mask_h = cv2.resize(mask_h, None, fx=self.downSample, fy=self.downSample, + interpolation=cv2.INTER_NEAREST) + mask = cv2.resize(mask_h, None, fx=0.25, fy=0.25, + interpolation=cv2.INTER_NEAREST) + + mask[mask > 0] = 1 # the masks stored in png are not binary + mask_h[mask_h > 0] = 1 + + return mask, mask_h + + def cal_scale_mat(self, img_hw, intrinsics, extrinsics, near_fars, factor=1.): + + center, radius, bounds = get_boundingbox(img_hw, intrinsics, extrinsics, near_fars) + # print("center", center) + # print("radius", radius) + # print("bounds", bounds) + # import ipdb; ipdb.set_trace() + radius = radius * factor + scale_mat = np.diag([radius, radius, radius, 1.0]) + scale_mat[:3, 3] = center.cpu().numpy() + scale_mat = scale_mat.astype(np.float32) + + return scale_mat, 1. / radius.cpu().numpy() + + def __len__(self): + return 8*len(self.lvis_paths) + + + def read_depth(self, filename, near_bound, noisy_factor=1.0): + depth_h = cv2.imread(filename, cv2.IMREAD_UNCHANGED).astype(np.uint16) / 65535 * 1.4 + 0.5 + + depth_h[depth_h < near_bound+1e-3] = 0.0 + + depth = {} + for l in range(3): + depth[f"level_{l}"] = cv2.resize( + depth_h, + None, + fx=1.0 / (2**l), + fy=1.0 / (2**l), + interpolation=cv2.INTER_NEAREST, + ) + + if self.split == "train": + cutout = np.ones_like(depth[f"level_2"]) + h0 = int(np.random.randint(0, high=cutout.shape[0] // 5, size=1)) + h1 = int( + np.random.randint( + 4 * cutout.shape[0] // 5, high=cutout.shape[0], size=1 + ) + ) + w0 = int(np.random.randint(0, high=cutout.shape[1] // 5, size=1)) + w1 = int( + np.random.randint( + 4 * cutout.shape[1] // 5, high=cutout.shape[1], size=1 + ) + ) + cutout[h0:h1, w0:w1] = 0 + depth_aug = depth[f"level_2"] * cutout + else: + depth_aug = depth[f"level_2"].copy() + + return depth, depth_h, depth_aug + + + def __getitem__(self, idx): + sample = {} + origin_idx = idx + imgs, depths_h, masks_h = [], [], [] # full size (256, 256) + intrinsics, w2cs, c2ws, near_fars = [], [], [], [] # record proj mats between views + + + folder_uid_dict = self.lvis_paths[idx//8] + idx = idx % 8 # [0, 7] + folder_id = folder_uid_dict['folder_id'] + uid = folder_uid_dict['uid'] + + # idx = idx % 8 + # uid = 'c40d63d5d740405e91c7f5fce855076e' + # folder_id = '000-123' + + # target view + c2w = self.c2ws[idx] + w2c = np.linalg.inv(c2w) + w2c_ref = w2c + w2c_ref_inv = np.linalg.inv(w2c_ref) + + w2cs.append(w2c @ w2c_ref_inv) + c2ws.append(np.linalg.inv(w2c @ w2c_ref_inv)) + + img_filename = os.path.join(self.root_dir, folder_id, uid, f'view_{idx}.png') + + depth_filename = os.path.join(os.path.join(self.root_dir, folder_id, uid, f'view_{idx}_depth_mm.png')) + + + img = Image.open(img_filename) + + img = self.transform(img) # (4, h, w) + + + if img.shape[0] == 4: + img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB + imgs += [img] + + depth_h = cv2.imread(depth_filename, cv2.IMREAD_UNCHANGED).astype(np.uint16) / 1000.0 + mask_h = depth_h > 0 + # print("valid pixels", np.sum(mask_h)) + directions = get_ray_directions(self.img_wh[1], self.img_wh[0], [self.intrinsic[0, 0], self.intrinsic[1, 1]]) # [H, W, 3] + surface_points = directions * depth_h[..., None] # [H, W, 3] + distance = np.linalg.norm(surface_points, axis=-1) # [H, W] + depth_h = distance + + + depths_h.append(depth_h) + masks_h.append(mask_h) + + intrinsic = self.intrinsic + intrinsics.append(intrinsic) + + + near_fars.append(self.near_fars[idx]) + image_perm = 0 # only supervised on reference view + + mask_dilated = None + + src_views = range(8+idx*4, 8+(idx+1)*4) + + + for vid in src_views: + img_filename = os.path.join(self.root_dir, folder_id, uid, f'view_{idx}_{vid%4}_10.png') + + img = Image.open(img_filename) + img_wh = self.img_wh + + img = self.transform(img) + if img.shape[0] == 4: + img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB + + imgs += [img] + depth_h = np.ones(img.shape[1:], dtype=np.float32) + depths_h.append(depth_h) + masks_h.append(np.ones(img.shape[1:], dtype=np.int32)) + + near_fars.append(self.all_near_fars[vid]) + intrinsics.append(self.all_intrinsics[vid]) + + w2cs.append(self.all_extrinsics[vid] @ w2c_ref_inv) + + + # ! estimate scale_mat + scale_mat, scale_factor = self.cal_scale_mat( + img_hw=[img_wh[1], img_wh[0]], + intrinsics=intrinsics, extrinsics=w2cs, + near_fars=near_fars, factor=1.1 + ) + # print(scale_mat) + # print(scale_factor) + # ! calculate the new w2cs after scaling + new_near_fars = [] + new_w2cs = [] + new_c2ws = [] + new_affine_mats = [] + new_depths_h = [] + for intrinsic, extrinsic, near_far, depth in zip(intrinsics, w2cs, near_fars, depths_h): + + P = intrinsic @ extrinsic @ scale_mat + P = P[:3, :4] + # - should use load_K_Rt_from_P() to obtain c2w + c2w = load_K_Rt_from_P(None, P)[1] + w2c = np.linalg.inv(c2w) + new_w2cs.append(w2c) + new_c2ws.append(c2w) + affine_mat = np.eye(4) + affine_mat[:3, :4] = intrinsic[:3, :3] @ w2c[:3, :4] + new_affine_mats.append(affine_mat) + + camera_o = c2w[:3, 3] + dist = np.sqrt(np.sum(camera_o ** 2)) + near = dist - 1 + far = dist + 1 + + new_near_fars.append([0.95 * near, 1.05 * far]) + new_depths_h.append(depth * scale_factor) + + # print(new_near_fars) + imgs = torch.stack(imgs).float() + depths_h = np.stack(new_depths_h) + masks_h = np.stack(masks_h) + + affine_mats = np.stack(new_affine_mats) + intrinsics, w2cs, c2ws, near_fars = np.stack(intrinsics), np.stack(new_w2cs), np.stack(new_c2ws), np.stack( + new_near_fars) + + if self.split == 'train': + start_idx = 0 + else: + start_idx = 1 + + view_ids = [idx] + list(src_views) + sample['origin_idx'] = origin_idx + sample['images'] = imgs # (V, 3, H, W) + sample['depths_h'] = torch.from_numpy(depths_h.astype(np.float32)) # (V, H, W) + sample['masks_h'] = torch.from_numpy(masks_h.astype(np.float32)) # (V, H, W) + sample['w2cs'] = torch.from_numpy(w2cs.astype(np.float32)) # (V, 4, 4) + sample['c2ws'] = torch.from_numpy(c2ws.astype(np.float32)) # (V, 4, 4) + sample['near_fars'] = torch.from_numpy(near_fars.astype(np.float32)) # (V, 2) + sample['intrinsics'] = torch.from_numpy(intrinsics.astype(np.float32))[:, :3, :3] # (V, 3, 3) + sample['view_ids'] = torch.from_numpy(np.array(view_ids)) + sample['affine_mats'] = torch.from_numpy(affine_mats.astype(np.float32)) # ! in world space + + # sample['light_idx'] = torch.tensor(light_idx) + sample['scan'] = folder_id + + sample['scale_factor'] = torch.tensor(scale_factor) + sample['img_wh'] = torch.from_numpy(np.array(img_wh)) + sample['render_img_idx'] = torch.tensor(image_perm) + sample['partial_vol_origin'] = self.partial_vol_origin + sample['meta'] = str(folder_id) + "_" + str(uid) + "_refview" + str(view_ids[0]) + + + # - image to render + sample['query_image'] = sample['images'][0] + sample['query_c2w'] = sample['c2ws'][0] + sample['query_w2c'] = sample['w2cs'][0] + sample['query_intrinsic'] = sample['intrinsics'][0] + sample['query_depth'] = sample['depths_h'][0] + sample['query_mask'] = sample['masks_h'][0] + sample['query_near_far'] = sample['near_fars'][0] + + sample['images'] = sample['images'][start_idx:] # (V, 3, H, W) + sample['depths_h'] = sample['depths_h'][start_idx:] # (V, H, W) + sample['masks_h'] = sample['masks_h'][start_idx:] # (V, H, W) + sample['w2cs'] = sample['w2cs'][start_idx:] # (V, 4, 4) + sample['c2ws'] = sample['c2ws'][start_idx:] # (V, 4, 4) + sample['intrinsics'] = sample['intrinsics'][start_idx:] # (V, 3, 3) + sample['view_ids'] = sample['view_ids'][start_idx:] + sample['affine_mats'] = sample['affine_mats'][start_idx:] # ! in world space + + sample['scale_mat'] = torch.from_numpy(scale_mat) + sample['trans_mat'] = torch.from_numpy(w2c_ref_inv) + + # - generate rays + if ('val' in self.split) or ('test' in self.split): + sample_rays = gen_rays_from_single_image( + img_wh[1], img_wh[0], + sample['query_image'], + sample['query_intrinsic'], + sample['query_c2w'], + depth=sample['query_depth'], + mask=sample['query_mask'] if self.clean_image else None) + else: + sample_rays = gen_random_rays_from_single_image( + img_wh[1], img_wh[0], + self.N_rays, + sample['query_image'], + sample['query_intrinsic'], + sample['query_c2w'], + depth=sample['query_depth'], + mask=sample['query_mask'] if self.clean_image else None, + dilated_mask=mask_dilated, + importance_sample=self.importance_sample) + + + sample['rays'] = sample_rays + + return sample diff --git a/SparseNeuS_demo_v1/data/blender_general_12_narrow.py b/SparseNeuS_demo_v1/data/blender_general_12_narrow.py new file mode 100644 index 0000000000000000000000000000000000000000..bb1183fb695101bac1f8f33da9438a84378b3dca --- /dev/null +++ b/SparseNeuS_demo_v1/data/blender_general_12_narrow.py @@ -0,0 +1,427 @@ +from torch.utils.data import Dataset +from utils.misc_utils import read_pfm +import os +import numpy as np +import cv2 +from PIL import Image +import torch +from torchvision import transforms as T +from data.scene import get_boundingbox + +from models.rays import gen_rays_from_single_image, gen_random_rays_from_single_image +import json +from termcolor import colored +import imageio +from kornia import create_meshgrid +import open3d as o3d +def get_ray_directions(H, W, focal, center=None): + """ + Get ray directions for all pixels in camera coordinate. + Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/ + ray-tracing-generating-camera-rays/standard-coordinate-systems + Inputs: + H, W, focal: image height, width and focal length + Outputs: + directions: (H, W, 3), the direction of the rays in camera coordinate + """ + grid = create_meshgrid(H, W, normalized_coordinates=False)[0] + 0.5 # 1xHxWx2 + + i, j = grid.unbind(-1) + # the direction here is without +0.5 pixel centering as calibration is not so accurate + # see https://github.com/bmild/nerf/issues/24 + cent = center if center is not None else [W / 2, H / 2] + directions = torch.stack([(i - cent[0]) / focal[0], (j - cent[1]) / focal[1], torch.ones_like(i)], -1) # (H, W, 3) + + return directions + +def load_K_Rt_from_P(filename, P=None): + if P is None: + lines = open(filename).read().splitlines() + if len(lines) == 4: + lines = lines[1:] + lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)] + P = np.asarray(lines).astype(np.float32).squeeze() + + out = cv2.decomposeProjectionMatrix(P) + K = out[0] + R = out[1] + t = out[2] + + K = K / K[2, 2] + intrinsics = np.eye(4) + intrinsics[:3, :3] = K + + pose = np.eye(4, dtype=np.float32) + pose[:3, :3] = R.transpose() # ? why need transpose here + pose[:3, 3] = (t[:3] / t[3])[:, 0] + + return intrinsics, pose # ! return cam2world matrix here + + +# ! load one ref-image with multiple src-images in camera coordinate system +class BlenderPerView(Dataset): + def __init__(self, root_dir, split, n_views=3, img_wh=(256, 256), downSample=1.0, + split_filepath=None, pair_filepath=None, + N_rays=512, + vol_dims=[128, 128, 128], batch_size=1, + clean_image=False, importance_sample=False, test_ref_views=[]): + + self.root_dir = root_dir + self.split = split + self.imgs_per_instance = 12 + self.n_views = n_views + self.N_rays = N_rays + self.batch_size = batch_size # - used for construct new metas for gru fusion training + + self.clean_image = clean_image + self.importance_sample = importance_sample + self.test_ref_views = test_ref_views # used for testing + self.scale_factor = 1.0 + self.scale_mat = np.float32(np.diag([1, 1, 1, 1.0])) + + lvis_json_path = '/objaverse-processed/zero12345_img/narrow_12_split_upd.json' # folder_id and uid + with open(lvis_json_path, 'r') as f: + lvis_paths = json.load(f) + if self.split == 'train': + self.lvis_paths = lvis_paths['train'] + else: + self.lvis_paths = lvis_paths['val'] + if img_wh is not None: + assert img_wh[0] % 32 == 0 and img_wh[1] % 32 == 0, \ + 'img_wh must both be multiples of 32!' + + + pose_json_path_narrow_8 = "/objaverse-processed/zero12345_img/zero12345_narrow_pose.json" + with open(pose_json_path_narrow_8, 'r') as f: + narrow_8_meta = json.load(f) + + pose_json_path_narrow_4 = "/objaverse-processed/zero12345_img/zero12345_2stage_12_pose.json" + with open(pose_json_path_narrow_4, 'r') as f: + narrow_4_meta = json.load(f) + + + self.img_ids = list(narrow_8_meta["c2ws"].keys()) + list(narrow_4_meta["c2ws"].keys()) # (8 + 8*4) + (4 + 4*4) + self.img_wh = (256, 256) + self.input_poses = np.array(list(narrow_8_meta["c2ws"].values()) + list(narrow_4_meta["c2ws"].values())) + intrinsic = np.eye(4) + assert narrow_8_meta["intrinsics"] == narrow_4_meta["intrinsics"], "intrinsics not equal" + intrinsic[:3, :3] = np.array(narrow_8_meta["intrinsics"]) + self.intrinsic = intrinsic + assert narrow_8_meta["near_far"] == narrow_4_meta["near_far"], "near_far not equal" + self.near_far = np.array(narrow_8_meta["near_far"]) + self.near_far[1] = 1.8 + self.define_transforms() + self.blender2opencv = np.array( + [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]] + ) + + + self.c2ws = [] + self.w2cs = [] + self.near_fars = [] + for idx, img_id in enumerate(self.img_ids): + pose = self.input_poses[idx] + c2w = pose @ self.blender2opencv + self.c2ws.append(c2w) + self.w2cs.append(np.linalg.inv(c2w)) + self.near_fars.append(self.near_far) + + + + self.c2ws = np.stack(self.c2ws, axis=0) + self.w2cs = np.stack(self.w2cs, axis=0) + + + self.all_intrinsics = [] # the cam info of the whole scene + self.all_extrinsics = [] + self.all_near_fars = [] + self.load_cam_info() + + # * bounding box for rendering + self.bbox_min = np.array([-1.0, -1.0, -1.0]) + self.bbox_max = np.array([1.0, 1.0, 1.0]) + + # - used for cost volume regularization + self.voxel_dims = torch.tensor(vol_dims, dtype=torch.float32) + self.partial_vol_origin = torch.tensor([-1., -1., -1.], dtype=torch.float32) + + + def define_transforms(self): + self.transform = T.Compose([T.ToTensor()]) + + + + def load_cam_info(self): + for vid, img_id in enumerate(self.img_ids): + intrinsic, extrinsic, near_far = self.intrinsic, np.linalg.inv(self.c2ws[vid]), self.near_far + self.all_intrinsics.append(intrinsic) + self.all_extrinsics.append(extrinsic) + self.all_near_fars.append(near_far) + + def read_depth(self, filename): + pass + + def read_mask(self, filename): + mask_h = cv2.imread(filename, 0) + mask_h = cv2.resize(mask_h, None, fx=self.downSample, fy=self.downSample, + interpolation=cv2.INTER_NEAREST) + mask = cv2.resize(mask_h, None, fx=0.25, fy=0.25, + interpolation=cv2.INTER_NEAREST) + + mask[mask > 0] = 1 # the masks stored in png are not binary + mask_h[mask_h > 0] = 1 + + return mask, mask_h + + def cal_scale_mat(self, img_hw, intrinsics, extrinsics, near_fars, factor=1.): + + center, radius, bounds = get_boundingbox(img_hw, intrinsics, extrinsics, near_fars) + + radius = radius * factor + scale_mat = np.diag([radius, radius, radius, 1.0]) + scale_mat[:3, 3] = center.cpu().numpy() + scale_mat = scale_mat.astype(np.float32) + + return scale_mat, 1. / radius.cpu().numpy() + + def __len__(self): + return self.imgs_per_instance*len(self.lvis_paths) + + + def read_depth(self, filename, near_bound, noisy_factor=1.0): + pass + + + def __getitem__(self, idx): + sample = {} + origin_idx = idx + imgs, depths_h, masks_h = [], [], [] # full size (256, 256) + intrinsics, w2cs, c2ws, near_fars = [], [], [], [] # record proj mats between views + idx_original=idx + + folder_uid_dict = self.lvis_paths[idx//self.imgs_per_instance] + + folder_id = folder_uid_dict['folder_id'] + uid = folder_uid_dict['uid'] + + idx = idx % self.imgs_per_instance # [0, 11] + if idx < 8: + # target view + c2w = self.c2ws[idx] + w2c = np.linalg.inv(c2w) + w2c_ref = w2c + w2c_ref_inv = np.linalg.inv(w2c_ref) + + w2cs.append(w2c @ w2c_ref_inv) + c2ws.append(np.linalg.inv(w2c @ w2c_ref_inv)) + + img_filename = os.path.join("/objaverse-processed/zero12345_img/zero12345_narrow/", folder_id, uid, f'view_{idx}.png') + + depth_filename = os.path.join(os.path.join("/objaverse-processed/zero12345_img/zero12345_narrow/", folder_id, uid, f'view_{idx}_depth_mm.png')) + + img = Image.open(img_filename) + + img = self.transform(img) # (4, h, w) + else: + # target view + c2w = self.c2ws[idx-8+40] + w2c = np.linalg.inv(c2w) + w2c_ref = w2c + w2c_ref_inv = np.linalg.inv(w2c_ref) + + w2cs.append(w2c @ w2c_ref_inv) + c2ws.append(np.linalg.inv(w2c @ w2c_ref_inv)) + + img_filename = os.path.join("/objaverse-processed/zero12345_img/zero12345_narrow_12/", folder_id, uid, f'view_{idx}.png') + + depth_filename = os.path.join(os.path.join("/objaverse-processed/zero12345_img/zero12345_narrow_12/", folder_id, uid, f'view_{idx}_depth_mm.png')) + + img = Image.open(img_filename) + + img = self.transform(img) # (4, h, w) + + if img.shape[0] == 4: + img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB + imgs += [img] + + depth_h = cv2.imread(depth_filename, cv2.IMREAD_UNCHANGED).astype(np.uint16) / 1000.0 + mask_h = depth_h > 0 + # print("valid pixels", np.sum(mask_h)) + directions = get_ray_directions(self.img_wh[1], self.img_wh[0], [self.intrinsic[0, 0], self.intrinsic[1, 1]]) # [H, W, 3] + surface_points = directions * depth_h[..., None] # [H, W, 3] + distance = np.linalg.norm(surface_points, axis=-1) # [H, W] + depth_h = distance + + + + depths_h.append(depth_h) + masks_h.append(mask_h) + + intrinsic = self.intrinsic + intrinsics.append(intrinsic) + + + near_fars.append(self.near_fars[idx]) + image_perm = 0 # only supervised on reference view + + mask_dilated = None + + + src_views = range(8, 8 + 8 * 4 + 4 + 4*4) + src_views_used = [] + skipped_idx = [40, 41, 42, 43] + for vid in src_views: + if vid in skipped_idx: + continue + + src_views_used.append(vid) + cur_view_id = (vid - 8) // 4 # [0, 7] + + # choose narrow + if cur_view_id < 8: + img_filename = os.path.join("/objaverse-processed/zero12345_img/zero12345_narrow/", folder_id, uid, f'view_{cur_view_id}_{vid%4}_10.png') + else: # choose 2-stage + cur_view_id = cur_view_id - 1 + img_filename = os.path.join("/objaverse-processed/zero12345_img/zero12345_narrow_12", folder_id, uid, f'view_{cur_view_id}_{vid%4}.png') + + img = Image.open(img_filename) + img_wh = self.img_wh + + img = self.transform(img) + if img.shape[0] == 4: + img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB + + imgs += [img] + depth_h = np.ones(img.shape[1:], dtype=np.float32) + depths_h.append(depth_h) + masks_h.append(np.ones(img.shape[1:], dtype=np.int32)) + + near_fars.append(self.all_near_fars[vid]) + intrinsics.append(self.all_intrinsics[vid]) + + w2cs.append(self.all_extrinsics[vid] @ w2c_ref_inv) + + + + + scale_mat, scale_factor = self.cal_scale_mat( + img_hw=[img_wh[1], img_wh[0]], + intrinsics=intrinsics, extrinsics=w2cs, + near_fars=near_fars, factor=1.1 + ) + + + new_near_fars = [] + new_w2cs = [] + new_c2ws = [] + new_affine_mats = [] + new_depths_h = [] + for intrinsic, extrinsic, near_far, depth in zip(intrinsics, w2cs, near_fars, depths_h): + + P = intrinsic @ extrinsic @ scale_mat + P = P[:3, :4] + # - should use load_K_Rt_from_P() to obtain c2w + c2w = load_K_Rt_from_P(None, P)[1] + w2c = np.linalg.inv(c2w) + new_w2cs.append(w2c) + new_c2ws.append(c2w) + affine_mat = np.eye(4) + affine_mat[:3, :4] = intrinsic[:3, :3] @ w2c[:3, :4] + new_affine_mats.append(affine_mat) + + camera_o = c2w[:3, 3] + dist = np.sqrt(np.sum(camera_o ** 2)) + near = dist - 1 + far = dist + 1 + + new_near_fars.append([0.95 * near, 1.05 * far]) + + new_depths_h.append(depth * scale_factor) + + # print(new_near_fars) + # print("img numeber: ", len(imgs)) + imgs = torch.stack(imgs).float() + depths_h = np.stack(new_depths_h) + masks_h = np.stack(masks_h) + + affine_mats = np.stack(new_affine_mats) + intrinsics, w2cs, c2ws, near_fars = np.stack(intrinsics), np.stack(new_w2cs), np.stack(new_c2ws), np.stack( + new_near_fars) + + if self.split == 'train': + start_idx = 0 + else: + start_idx = 1 + + view_ids = [idx_original % self.imgs_per_instance] + src_views_used + sample['origin_idx'] = origin_idx + sample['images'] = imgs # (V, 3, H, W) + sample['depths_h'] = torch.from_numpy(depths_h.astype(np.float32)) # (V, H, W) + sample['masks_h'] = torch.from_numpy(masks_h.astype(np.float32)) # (V, H, W) + sample['w2cs'] = torch.from_numpy(w2cs.astype(np.float32)) # (V, 4, 4) + sample['c2ws'] = torch.from_numpy(c2ws.astype(np.float32)) # (V, 4, 4) + sample['near_fars'] = torch.from_numpy(near_fars.astype(np.float32)) # (V, 2) + sample['intrinsics'] = torch.from_numpy(intrinsics.astype(np.float32))[:, :3, :3] # (V, 3, 3) + sample['view_ids'] = torch.from_numpy(np.array(view_ids)) + sample['affine_mats'] = torch.from_numpy(affine_mats.astype(np.float32)) # ! in world space + + # sample['light_idx'] = torch.tensor(light_idx) + sample['scan'] = folder_id + + sample['scale_factor'] = torch.tensor(scale_factor) + sample['img_wh'] = torch.from_numpy(np.array(img_wh)) + sample['render_img_idx'] = torch.tensor(image_perm) + sample['partial_vol_origin'] = self.partial_vol_origin + if view_ids[0] < 8: + meta_end = "_narrow"+ "_refview" + str(view_ids[0]) + else: + meta_end = "_two_stage"+ "_refview" + str(view_ids[0] - 8) + sample['meta'] = str(folder_id) + "_" + str(uid) + "_refview" + meta_end + + + # - image to render + sample['query_image'] = sample['images'][0] + sample['query_c2w'] = sample['c2ws'][0] + sample['query_w2c'] = sample['w2cs'][0] + sample['query_intrinsic'] = sample['intrinsics'][0] + sample['query_depth'] = sample['depths_h'][0] + sample['query_mask'] = sample['masks_h'][0] + sample['query_near_far'] = sample['near_fars'][0] + + sample['images'] = sample['images'][start_idx:] # (V, 3, H, W) + sample['depths_h'] = sample['depths_h'][start_idx:] # (V, H, W) + sample['masks_h'] = sample['masks_h'][start_idx:] # (V, H, W) + sample['w2cs'] = sample['w2cs'][start_idx:] # (V, 4, 4) + sample['c2ws'] = sample['c2ws'][start_idx:] # (V, 4, 4) + sample['intrinsics'] = sample['intrinsics'][start_idx:] # (V, 3, 3) + sample['view_ids'] = sample['view_ids'][start_idx:] + sample['affine_mats'] = sample['affine_mats'][start_idx:] # ! in world space + + sample['scale_mat'] = torch.from_numpy(scale_mat) + sample['trans_mat'] = torch.from_numpy(w2c_ref_inv) + + # - generate rays + if ('val' in self.split) or ('test' in self.split): + sample_rays = gen_rays_from_single_image( + img_wh[1], img_wh[0], + sample['query_image'], + sample['query_intrinsic'], + sample['query_c2w'], + depth=sample['query_depth'], + mask=sample['query_mask'] if self.clean_image else None) + else: + sample_rays = gen_random_rays_from_single_image( + img_wh[1], img_wh[0], + self.N_rays, + sample['query_image'], + sample['query_intrinsic'], + sample['query_c2w'], + depth=sample['query_depth'], + mask=sample['query_mask'] if self.clean_image else None, + dilated_mask=mask_dilated, + importance_sample=self.importance_sample) + + + sample['rays'] = sample_rays + + return sample diff --git a/SparseNeuS_demo_v1/data/blender_general_12_narrow_8.py b/SparseNeuS_demo_v1/data/blender_general_12_narrow_8.py new file mode 100644 index 0000000000000000000000000000000000000000..467dc5d4d1df3b6d3c8aa4384a1048bec9910973 --- /dev/null +++ b/SparseNeuS_demo_v1/data/blender_general_12_narrow_8.py @@ -0,0 +1,427 @@ +from torch.utils.data import Dataset +from utils.misc_utils import read_pfm +import os +import numpy as np +import cv2 +from PIL import Image +import torch +from torchvision import transforms as T +from data.scene import get_boundingbox + +from models.rays import gen_rays_from_single_image, gen_random_rays_from_single_image +import json +from termcolor import colored +import imageio +from kornia import create_meshgrid +import open3d as o3d +def get_ray_directions(H, W, focal, center=None): + """ + Get ray directions for all pixels in camera coordinate. + Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/ + ray-tracing-generating-camera-rays/standard-coordinate-systems + Inputs: + H, W, focal: image height, width and focal length + Outputs: + directions: (H, W, 3), the direction of the rays in camera coordinate + """ + grid = create_meshgrid(H, W, normalized_coordinates=False)[0] + 0.5 # 1xHxWx2 + + i, j = grid.unbind(-1) + # the direction here is without +0.5 pixel centering as calibration is not so accurate + # see https://github.com/bmild/nerf/issues/24 + cent = center if center is not None else [W / 2, H / 2] + directions = torch.stack([(i - cent[0]) / focal[0], (j - cent[1]) / focal[1], torch.ones_like(i)], -1) # (H, W, 3) + + return directions + +def load_K_Rt_from_P(filename, P=None): + if P is None: + lines = open(filename).read().splitlines() + if len(lines) == 4: + lines = lines[1:] + lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)] + P = np.asarray(lines).astype(np.float32).squeeze() + + out = cv2.decomposeProjectionMatrix(P) + K = out[0] + R = out[1] + t = out[2] + + K = K / K[2, 2] + intrinsics = np.eye(4) + intrinsics[:3, :3] = K + + pose = np.eye(4, dtype=np.float32) + pose[:3, :3] = R.transpose() # ? why need transpose here + pose[:3, 3] = (t[:3] / t[3])[:, 0] + + return intrinsics, pose # ! return cam2world matrix here + + +# ! load one ref-image with multiple src-images in camera coordinate system +class BlenderPerView(Dataset): + def __init__(self, root_dir, split, n_views=3, img_wh=(256, 256), downSample=1.0, + split_filepath=None, pair_filepath=None, + N_rays=512, + vol_dims=[128, 128, 128], batch_size=1, + clean_image=False, importance_sample=False, test_ref_views=[]): + + self.root_dir = root_dir + self.split = split + self.imgs_per_instance = 8 + self.n_views = n_views + self.N_rays = N_rays + self.batch_size = batch_size # - used for construct new metas for gru fusion training + + self.clean_image = clean_image + self.importance_sample = importance_sample + self.test_ref_views = test_ref_views # used for testing + self.scale_factor = 1.0 + self.scale_mat = np.float32(np.diag([1, 1, 1, 1.0])) + + lvis_json_path = '/objaverse-processed/zero12345_img/narrow_12_split_upd.json' # folder_id and uid + with open(lvis_json_path, 'r') as f: + lvis_paths = json.load(f) + if self.split == 'train': + self.lvis_paths = lvis_paths['train'] + else: + self.lvis_paths = lvis_paths['val'] + if img_wh is not None: + assert img_wh[0] % 32 == 0 and img_wh[1] % 32 == 0, \ + 'img_wh must both be multiples of 32!' + + + pose_json_path_narrow_8 = "/objaverse-processed/zero12345_img/zero12345_narrow_pose.json" + with open(pose_json_path_narrow_8, 'r') as f: + narrow_8_meta = json.load(f) + + pose_json_path_narrow_4 = "/objaverse-processed/zero12345_img/zero12345_2stage_12_pose.json" + with open(pose_json_path_narrow_4, 'r') as f: + narrow_4_meta = json.load(f) + + + self.img_ids = list(narrow_8_meta["c2ws"].keys()) + list(narrow_4_meta["c2ws"].keys()) # (8 + 8*4) + (4 + 4*4) + self.img_wh = (256, 256) + self.input_poses = np.array(list(narrow_8_meta["c2ws"].values()) + list(narrow_4_meta["c2ws"].values())) + intrinsic = np.eye(4) + assert narrow_8_meta["intrinsics"] == narrow_4_meta["intrinsics"], "intrinsics not equal" + intrinsic[:3, :3] = np.array(narrow_8_meta["intrinsics"]) + self.intrinsic = intrinsic + assert narrow_8_meta["near_far"] == narrow_4_meta["near_far"], "near_far not equal" + self.near_far = np.array(narrow_8_meta["near_far"]) + self.near_far[1] = 1.8 + self.define_transforms() + self.blender2opencv = np.array( + [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]] + ) + + + self.c2ws = [] + self.w2cs = [] + self.near_fars = [] + for idx, img_id in enumerate(self.img_ids): + pose = self.input_poses[idx] + c2w = pose @ self.blender2opencv + self.c2ws.append(c2w) + self.w2cs.append(np.linalg.inv(c2w)) + self.near_fars.append(self.near_far) + + + + self.c2ws = np.stack(self.c2ws, axis=0) + self.w2cs = np.stack(self.w2cs, axis=0) + + + self.all_intrinsics = [] # the cam info of the whole scene + self.all_extrinsics = [] + self.all_near_fars = [] + self.load_cam_info() + + # * bounding box for rendering + self.bbox_min = np.array([-1.0, -1.0, -1.0]) + self.bbox_max = np.array([1.0, 1.0, 1.0]) + + # - used for cost volume regularization + self.voxel_dims = torch.tensor(vol_dims, dtype=torch.float32) + self.partial_vol_origin = torch.tensor([-1., -1., -1.], dtype=torch.float32) + + + def define_transforms(self): + self.transform = T.Compose([T.ToTensor()]) + + + + def load_cam_info(self): + for vid, img_id in enumerate(self.img_ids): + intrinsic, extrinsic, near_far = self.intrinsic, np.linalg.inv(self.c2ws[vid]), self.near_far + self.all_intrinsics.append(intrinsic) + self.all_extrinsics.append(extrinsic) + self.all_near_fars.append(near_far) + + def read_depth(self, filename): + pass + + def read_mask(self, filename): + mask_h = cv2.imread(filename, 0) + mask_h = cv2.resize(mask_h, None, fx=self.downSample, fy=self.downSample, + interpolation=cv2.INTER_NEAREST) + mask = cv2.resize(mask_h, None, fx=0.25, fy=0.25, + interpolation=cv2.INTER_NEAREST) + + mask[mask > 0] = 1 # the masks stored in png are not binary + mask_h[mask_h > 0] = 1 + + return mask, mask_h + + def cal_scale_mat(self, img_hw, intrinsics, extrinsics, near_fars, factor=1.): + + center, radius, bounds = get_boundingbox(img_hw, intrinsics, extrinsics, near_fars) + + radius = radius * factor + scale_mat = np.diag([radius, radius, radius, 1.0]) + scale_mat[:3, 3] = center.cpu().numpy() + scale_mat = scale_mat.astype(np.float32) + + return scale_mat, 1. / radius.cpu().numpy() + + def __len__(self): + return self.imgs_per_instance*len(self.lvis_paths) + + + def read_depth(self, filename, near_bound, noisy_factor=1.0): + pass + + + def __getitem__(self, idx): + sample = {} + origin_idx = idx + imgs, depths_h, masks_h = [], [], [] # full size (256, 256) + intrinsics, w2cs, c2ws, near_fars = [], [], [], [] # record proj mats between views + idx_original=idx + + folder_uid_dict = self.lvis_paths[idx//self.imgs_per_instance] + + folder_id = folder_uid_dict['folder_id'] + uid = folder_uid_dict['uid'] + + idx = idx % self.imgs_per_instance # [0, 11] + if idx < 8: + # target view + c2w = self.c2ws[idx] + w2c = np.linalg.inv(c2w) + w2c_ref = w2c + w2c_ref_inv = np.linalg.inv(w2c_ref) + + w2cs.append(w2c @ w2c_ref_inv) + c2ws.append(np.linalg.inv(w2c @ w2c_ref_inv)) + + img_filename = os.path.join("/objaverse-processed/zero12345_img/zero12345_narrow/", folder_id, uid, f'view_{idx}.png') + + depth_filename = os.path.join(os.path.join("/objaverse-processed/zero12345_img/zero12345_narrow/", folder_id, uid, f'view_{idx}_depth_mm.png')) + + img = Image.open(img_filename) + + img = self.transform(img) # (4, h, w) + else: + # target view + c2w = self.c2ws[idx-8+40] + w2c = np.linalg.inv(c2w) + w2c_ref = w2c + w2c_ref_inv = np.linalg.inv(w2c_ref) + + w2cs.append(w2c @ w2c_ref_inv) + c2ws.append(np.linalg.inv(w2c @ w2c_ref_inv)) + + img_filename = os.path.join("/objaverse-processed/zero12345_img/zero12345_narrow_12/", folder_id, uid, f'view_{idx}.png') + + depth_filename = os.path.join(os.path.join("/objaverse-processed/zero12345_img/zero12345_narrow_12/", folder_id, uid, f'view_{idx}_depth_mm.png')) + + img = Image.open(img_filename) + + img = self.transform(img) # (4, h, w) + + if img.shape[0] == 4: + img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB + imgs += [img] + + depth_h = cv2.imread(depth_filename, cv2.IMREAD_UNCHANGED).astype(np.uint16) / 1000.0 + mask_h = depth_h > 0 + # print("valid pixels", np.sum(mask_h)) + directions = get_ray_directions(self.img_wh[1], self.img_wh[0], [self.intrinsic[0, 0], self.intrinsic[1, 1]]) # [H, W, 3] + surface_points = directions * depth_h[..., None] # [H, W, 3] + distance = np.linalg.norm(surface_points, axis=-1) # [H, W] + depth_h = distance + + + + depths_h.append(depth_h) + masks_h.append(mask_h) + + intrinsic = self.intrinsic + intrinsics.append(intrinsic) + + + near_fars.append(self.near_fars[idx]) + image_perm = 0 # only supervised on reference view + + mask_dilated = None + + + src_views = range(8, 8 + 8 * 4 + 4 + 4*4) + src_views_used = [] + skipped_idx = [40, 41, 42, 43] + for vid in src_views: + if vid in skipped_idx: + continue + + src_views_used.append(vid) + cur_view_id = (vid - 8) // 4 # [0, 7] + + # choose narrow + if cur_view_id < 8: + img_filename = os.path.join("/objaverse-processed/zero12345_img/zero12345_narrow/", folder_id, uid, f'view_{cur_view_id}_{vid%4}_10.png') + else: # choose 2-stage + cur_view_id = cur_view_id - 1 + img_filename = os.path.join("/objaverse-processed/zero12345_img/zero12345_narrow_12", folder_id, uid, f'view_{cur_view_id}_{vid%4}.png') + + img = Image.open(img_filename) + img_wh = self.img_wh + + img = self.transform(img) + if img.shape[0] == 4: + img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB + + imgs += [img] + depth_h = np.ones(img.shape[1:], dtype=np.float32) + depths_h.append(depth_h) + masks_h.append(np.ones(img.shape[1:], dtype=np.int32)) + + near_fars.append(self.all_near_fars[vid]) + intrinsics.append(self.all_intrinsics[vid]) + + w2cs.append(self.all_extrinsics[vid] @ w2c_ref_inv) + + + + + scale_mat, scale_factor = self.cal_scale_mat( + img_hw=[img_wh[1], img_wh[0]], + intrinsics=intrinsics, extrinsics=w2cs, + near_fars=near_fars, factor=1.1 + ) + + + new_near_fars = [] + new_w2cs = [] + new_c2ws = [] + new_affine_mats = [] + new_depths_h = [] + for intrinsic, extrinsic, near_far, depth in zip(intrinsics, w2cs, near_fars, depths_h): + + P = intrinsic @ extrinsic @ scale_mat + P = P[:3, :4] + # - should use load_K_Rt_from_P() to obtain c2w + c2w = load_K_Rt_from_P(None, P)[1] + w2c = np.linalg.inv(c2w) + new_w2cs.append(w2c) + new_c2ws.append(c2w) + affine_mat = np.eye(4) + affine_mat[:3, :4] = intrinsic[:3, :3] @ w2c[:3, :4] + new_affine_mats.append(affine_mat) + + camera_o = c2w[:3, 3] + dist = np.sqrt(np.sum(camera_o ** 2)) + near = dist - 1 + far = dist + 1 + + new_near_fars.append([0.95 * near, 1.05 * far]) + + new_depths_h.append(depth * scale_factor) + + # print(new_near_fars) + # print("img numeber: ", len(imgs)) + imgs = torch.stack(imgs).float() + depths_h = np.stack(new_depths_h) + masks_h = np.stack(masks_h) + + affine_mats = np.stack(new_affine_mats) + intrinsics, w2cs, c2ws, near_fars = np.stack(intrinsics), np.stack(new_w2cs), np.stack(new_c2ws), np.stack( + new_near_fars) + + if self.split == 'train': + start_idx = 0 + else: + start_idx = 1 + + view_ids = [idx_original % self.imgs_per_instance] + src_views_used + sample['origin_idx'] = origin_idx + sample['images'] = imgs # (V, 3, H, W) + sample['depths_h'] = torch.from_numpy(depths_h.astype(np.float32)) # (V, H, W) + sample['masks_h'] = torch.from_numpy(masks_h.astype(np.float32)) # (V, H, W) + sample['w2cs'] = torch.from_numpy(w2cs.astype(np.float32)) # (V, 4, 4) + sample['c2ws'] = torch.from_numpy(c2ws.astype(np.float32)) # (V, 4, 4) + sample['near_fars'] = torch.from_numpy(near_fars.astype(np.float32)) # (V, 2) + sample['intrinsics'] = torch.from_numpy(intrinsics.astype(np.float32))[:, :3, :3] # (V, 3, 3) + sample['view_ids'] = torch.from_numpy(np.array(view_ids)) + sample['affine_mats'] = torch.from_numpy(affine_mats.astype(np.float32)) # ! in world space + + # sample['light_idx'] = torch.tensor(light_idx) + sample['scan'] = folder_id + + sample['scale_factor'] = torch.tensor(scale_factor) + sample['img_wh'] = torch.from_numpy(np.array(img_wh)) + sample['render_img_idx'] = torch.tensor(image_perm) + sample['partial_vol_origin'] = self.partial_vol_origin + if view_ids[0] < 8: + meta_end = "_narrow"+ "_refview" + str(view_ids[0]) + else: + meta_end = "_two_stage"+ "_refview" + str(view_ids[0] - 8) + sample['meta'] = str(folder_id) + "_" + str(uid) + "_refview" + meta_end + + + # - image to render + sample['query_image'] = sample['images'][0] + sample['query_c2w'] = sample['c2ws'][0] + sample['query_w2c'] = sample['w2cs'][0] + sample['query_intrinsic'] = sample['intrinsics'][0] + sample['query_depth'] = sample['depths_h'][0] + sample['query_mask'] = sample['masks_h'][0] + sample['query_near_far'] = sample['near_fars'][0] + + sample['images'] = sample['images'][start_idx:] # (V, 3, H, W) + sample['depths_h'] = sample['depths_h'][start_idx:] # (V, H, W) + sample['masks_h'] = sample['masks_h'][start_idx:] # (V, H, W) + sample['w2cs'] = sample['w2cs'][start_idx:] # (V, 4, 4) + sample['c2ws'] = sample['c2ws'][start_idx:] # (V, 4, 4) + sample['intrinsics'] = sample['intrinsics'][start_idx:] # (V, 3, 3) + sample['view_ids'] = sample['view_ids'][start_idx:] + sample['affine_mats'] = sample['affine_mats'][start_idx:] # ! in world space + + sample['scale_mat'] = torch.from_numpy(scale_mat) + sample['trans_mat'] = torch.from_numpy(w2c_ref_inv) + + # - generate rays + if ('val' in self.split) or ('test' in self.split): + sample_rays = gen_rays_from_single_image( + img_wh[1], img_wh[0], + sample['query_image'], + sample['query_intrinsic'], + sample['query_c2w'], + depth=sample['query_depth'], + mask=sample['query_mask'] if self.clean_image else None) + else: + sample_rays = gen_random_rays_from_single_image( + img_wh[1], img_wh[0], + self.N_rays, + sample['query_image'], + sample['query_intrinsic'], + sample['query_c2w'], + depth=sample['query_depth'], + mask=sample['query_mask'] if self.clean_image else None, + dilated_mask=mask_dilated, + importance_sample=self.importance_sample) + + + sample['rays'] = sample_rays + + return sample diff --git a/SparseNeuS_demo_v1/data/blender_general_360.py b/SparseNeuS_demo_v1/data/blender_general_360.py new file mode 100644 index 0000000000000000000000000000000000000000..37e8664613a614c03227375d8a0b25224d694bdc --- /dev/null +++ b/SparseNeuS_demo_v1/data/blender_general_360.py @@ -0,0 +1,412 @@ +from torch.utils.data import Dataset +from utils.misc_utils import read_pfm +import os +import numpy as np +import cv2 +from PIL import Image +import torch +from torchvision import transforms as T +from data.scene import get_boundingbox + +from models.rays import gen_rays_from_single_image, gen_random_rays_from_single_image +import json +from termcolor import colored +import imageio +from kornia import create_meshgrid +import open3d as o3d + +def get_ray_directions(H, W, focal, center=None): + """ + Get ray directions for all pixels in camera coordinate. + Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/ + ray-tracing-generating-camera-rays/standard-coordinate-systems + Inputs: + H, W, focal: image height, width and focal length + Outputs: + directions: (H, W, 3), the direction of the rays in camera coordinate + """ + grid = create_meshgrid(H, W, normalized_coordinates=False)[0] + 0.5 # 1xHxWx2 + + i, j = grid.unbind(-1) + # the direction here is without +0.5 pixel centering as calibration is not so accurate + # see https://github.com/bmild/nerf/issues/24 + cent = center if center is not None else [W / 2, H / 2] + directions = torch.stack([(i - cent[0]) / focal[0], (j - cent[1]) / focal[1], torch.ones_like(i)], -1) # (H, W, 3) + + return directions + +def load_K_Rt_from_P(filename, P=None): + if P is None: + lines = open(filename).read().splitlines() + if len(lines) == 4: + lines = lines[1:] + lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)] + P = np.asarray(lines).astype(np.float32).squeeze() + + out = cv2.decomposeProjectionMatrix(P) + K = out[0] + R = out[1] + t = out[2] + + K = K / K[2, 2] + intrinsics = np.eye(4) + intrinsics[:3, :3] = K + + pose = np.eye(4, dtype=np.float32) + pose[:3, :3] = R.transpose() # ? why need transpose here + pose[:3, 3] = (t[:3] / t[3])[:, 0] + + return intrinsics, pose # ! return cam2world matrix here + + +# ! load one ref-image with multiple src-images in camera coordinate system +class BlenderPerView(Dataset): + def __init__(self, root_dir, split, n_views=3, img_wh=(256, 256), downSample=1.0, + split_filepath=None, pair_filepath=None, + N_rays=512, + vol_dims=[128, 128, 128], batch_size=1, + clean_image=False, importance_sample=False, test_ref_views=[]): + + # print("root_dir: ", root_dir) + self.root_dir = root_dir + self.split = split + + self.n_views = n_views + self.N_rays = N_rays + self.batch_size = batch_size # - used for construct new metas for gru fusion training + + self.clean_image = clean_image + self.importance_sample = importance_sample + self.test_ref_views = test_ref_views # used for testing + self.scale_factor = 1.0 + self.scale_mat = np.float32(np.diag([1, 1, 1, 1.0])) + + lvis_json_path = '/objaverse-processed/zero12345_img/lvis_split.json' # folder_id and uid + with open(lvis_json_path, 'r') as f: + lvis_paths = json.load(f) + if self.split == 'train': + self.lvis_paths = lvis_paths['train'] + else: + self.lvis_paths = lvis_paths['val'] + if img_wh is not None: + assert img_wh[0] % 32 == 0 and img_wh[1] % 32 == 0, \ + 'img_wh must both be multiples of 32!' + + + pose_json_path = "/objaverse-processed/zero12345_img/zero12345_wide_pose.json" + with open(pose_json_path, 'r') as f: + meta = json.load(f) + + self.img_ids = list(meta["c2ws"].keys()) # e.g. "view_0_0", "view_0_5", "view_1_7" + self.img_wh = (256, 256) + self.input_poses = np.array(list(meta["c2ws"].values())) + intrinsic = np.eye(4) + intrinsic[:3, :3] = np.array(meta["intrinsics"]) + self.intrinsic = intrinsic + self.near_far = np.array(meta["near_far"]) + + + self.define_transforms() + self.blender2opencv = np.array( + [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]] + ) + + + self.c2ws = [] + self.w2cs = [] + self.near_fars = [] + # self.root_dir = root_dir + for idx, img_id in enumerate(self.img_ids): + pose = self.input_poses[idx] + c2w = pose @ self.blender2opencv + self.c2ws.append(c2w) + self.w2cs.append(np.linalg.inv(c2w)) + self.near_fars.append(self.near_far) + self.c2ws = np.stack(self.c2ws, axis=0) + self.w2cs = np.stack(self.w2cs, axis=0) + + + self.all_intrinsics = [] # the cam info of the whole scene + self.all_extrinsics = [] + self.all_near_fars = [] + self.load_cam_info() + + # * bounding box for rendering + self.bbox_min = np.array([-1.0, -1.0, -1.0]) + self.bbox_max = np.array([1.0, 1.0, 1.0]) + + # - used for cost volume regularization + self.voxel_dims = torch.tensor(vol_dims, dtype=torch.float32) + self.partial_vol_origin = torch.tensor([-1., -1., -1.], dtype=torch.float32) + + + def define_transforms(self): + self.transform = T.Compose([T.ToTensor()]) + + + + def load_cam_info(self): + for vid, img_id in enumerate(self.img_ids): + intrinsic, extrinsic, near_far = self.intrinsic, np.linalg.inv(self.c2ws[vid]), self.near_far + self.all_intrinsics.append(intrinsic) + self.all_extrinsics.append(extrinsic) + self.all_near_fars.append(near_far) + + def read_depth(self, filename): + depth_h = np.array(read_pfm(filename)[0], dtype=np.float32) # (1200, 1600) + depth_h = cv2.resize(depth_h, None, fx=0.5, fy=0.5, + interpolation=cv2.INTER_NEAREST) # (600, 800) + depth_h = depth_h[44:556, 80:720] # (512, 640) + depth_h = cv2.resize(depth_h, None, fx=self.downSample, fy=self.downSample, + interpolation=cv2.INTER_NEAREST) + depth = cv2.resize(depth_h, None, fx=1.0 / 4, fy=1.0 / 4, + interpolation=cv2.INTER_NEAREST) + + return depth, depth_h + + def read_mask(self, filename): + mask_h = cv2.imread(filename, 0) + mask_h = cv2.resize(mask_h, None, fx=self.downSample, fy=self.downSample, + interpolation=cv2.INTER_NEAREST) + mask = cv2.resize(mask_h, None, fx=0.25, fy=0.25, + interpolation=cv2.INTER_NEAREST) + + mask[mask > 0] = 1 # the masks stored in png are not binary + mask_h[mask_h > 0] = 1 + + return mask, mask_h + + def cal_scale_mat(self, img_hw, intrinsics, extrinsics, near_fars, factor=1.): + + center, radius, bounds = get_boundingbox(img_hw, intrinsics, extrinsics, near_fars) + # print("center", center) + # print("radius", radius) + # print("bounds", bounds) + # import ipdb; ipdb.set_trace() + radius = radius * factor + scale_mat = np.diag([radius, radius, radius, 1.0]) + scale_mat[:3, 3] = center.cpu().numpy() + scale_mat = scale_mat.astype(np.float32) + + return scale_mat, 1. / radius.cpu().numpy() + + def __len__(self): + return 36*len(self.lvis_paths) + + + def read_depth(self, filename, near_bound, noisy_factor=1.0): + pass + + + def __getitem__(self, idx): + sample = {} + + imgs, depths_h, masks_h = [], [], [] # full size (256, 256) + intrinsics, w2cs, c2ws, near_fars = [], [], [], [] # record proj mats between views + + + folder_uid_dict = self.lvis_paths[idx//36] + + + folder_id = folder_uid_dict['folder_id'] + uid = folder_uid_dict['uid'] + + idx = idx % 36 # [0, 35] + gt_view_idx = idx // 12 # [0, 2] + target_view_idx = idx % 12 # [0, 11] + + + + # target view + c2w = self.c2ws[idx] + w2c = np.linalg.inv(c2w) + w2c_ref = w2c + w2c_ref_inv = np.linalg.inv(w2c_ref) + + w2cs.append(w2c @ w2c_ref_inv) + c2ws.append(np.linalg.inv(w2c @ w2c_ref_inv)) + + img_filename = os.path.join(self.root_dir, folder_id, uid, f'view_{gt_view_idx}_{target_view_idx}_gt.png') + + depth_filename = os.path.join(self.root_dir, folder_id, uid, f'view_{gt_view_idx}_{target_view_idx}_gt_depth_mm.png') + + + img = Image.open(img_filename) + + img = self.transform(img) # (4, h, w) + + + if img.shape[0] == 4: + img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB + imgs += [img] + + depth_h = cv2.imread(depth_filename, cv2.IMREAD_UNCHANGED).astype(np.uint16) / 1000.0 + mask_h = depth_h > 0 + # print("valid pixels", np.sum(mask_h)) + directions = get_ray_directions(self.img_wh[1], self.img_wh[0], [self.intrinsic[0, 0], self.intrinsic[1, 1]]) # [H, W, 3] + surface_points = directions * depth_h[..., None] # [H, W, 3] + distance = np.linalg.norm(surface_points, axis=-1) # [H, W] + depth_h = distance + + + depths_h.append(depth_h) + masks_h.append(mask_h) + + intrinsic = self.intrinsic + intrinsics.append(intrinsic) + + + + near_fars.append(self.near_fars[idx]) + image_perm = 0 # only supervised on reference view + + mask_dilated = None + + # src_views = range(gt_view_idx * 12, (gt_view_idx + 1) * 12) + + idx_of_12 = idx - 12 * gt_view_idx # idx % 12 + + src_views = list(i % 12 + 12 * gt_view_idx for i in range(idx_of_12 - 1-1, idx_of_12 + 2+1)) + + + for vid in src_views: + # if vid == idx: + # continue + img_filename = os.path.join(self.root_dir, folder_id, uid, f'view_{gt_view_idx}_{target_view_idx}.png') + + img = Image.open(img_filename) + img_wh = self.img_wh + + img = self.transform(img) + if img.shape[0] == 4: + img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB + + imgs += [img] + depth_h = np.ones(img.shape[1:], dtype=np.float32) + depths_h.append(depth_h) + masks_h.append(np.ones(img.shape[1:], dtype=np.int32)) + + near_fars.append(self.all_near_fars[vid]) + intrinsics.append(self.all_intrinsics[vid]) + + w2cs.append(self.all_extrinsics[vid] @ w2c_ref_inv) + + + # ! estimate scale_mat + scale_mat, scale_factor = self.cal_scale_mat( + img_hw=[img_wh[1], img_wh[0]], + intrinsics=intrinsics, extrinsics=w2cs, + near_fars=near_fars, factor=1.1 + ) + # print(scale_mat) + # print(scale_factor) + # ! calculate the new w2cs after scaling + new_near_fars = [] + new_w2cs = [] + new_c2ws = [] + new_affine_mats = [] + new_depths_h = [] + for intrinsic, extrinsic, near_far, depth in zip(intrinsics, w2cs, near_fars, depths_h): + + P = intrinsic @ extrinsic @ scale_mat + P = P[:3, :4] + # - should use load_K_Rt_from_P() to obtain c2w + c2w = load_K_Rt_from_P(None, P)[1] + w2c = np.linalg.inv(c2w) + new_w2cs.append(w2c) + new_c2ws.append(c2w) + affine_mat = np.eye(4) + affine_mat[:3, :4] = intrinsic[:3, :3] @ w2c[:3, :4] + new_affine_mats.append(affine_mat) + + camera_o = c2w[:3, 3] + dist = np.sqrt(np.sum(camera_o ** 2)) + near = dist - 1 + far = dist + 1 + + new_near_fars.append([0.95 * near, 1.05 * far]) + new_depths_h.append(depth * scale_factor) + + # print(new_near_fars) + imgs = torch.stack(imgs).float() + depths_h = np.stack(new_depths_h) + masks_h = np.stack(masks_h) + + affine_mats = np.stack(new_affine_mats) + intrinsics, w2cs, c2ws, near_fars = np.stack(intrinsics), np.stack(new_w2cs), np.stack(new_c2ws), np.stack( + new_near_fars) + + if self.split == 'train': + start_idx = 0 + else: + start_idx = 1 + + view_ids = [idx] + list(src_views) + + sample['images'] = imgs # (V, 3, H, W) + sample['depths_h'] = torch.from_numpy(depths_h.astype(np.float32)) # (V, H, W) + sample['masks_h'] = torch.from_numpy(masks_h.astype(np.float32)) # (V, H, W) + sample['w2cs'] = torch.from_numpy(w2cs.astype(np.float32)) # (V, 4, 4) + sample['c2ws'] = torch.from_numpy(c2ws.astype(np.float32)) # (V, 4, 4) + sample['near_fars'] = torch.from_numpy(near_fars.astype(np.float32)) # (V, 2) + sample['intrinsics'] = torch.from_numpy(intrinsics.astype(np.float32))[:, :3, :3] # (V, 3, 3) + sample['view_ids'] = torch.from_numpy(np.array(view_ids)) + sample['affine_mats'] = torch.from_numpy(affine_mats.astype(np.float32)) # ! in world space + + # sample['light_idx'] = torch.tensor(light_idx) + sample['scan'] = folder_id + + sample['scale_factor'] = torch.tensor(scale_factor) + sample['img_wh'] = torch.from_numpy(np.array(img_wh)) + sample['render_img_idx'] = torch.tensor(image_perm) + sample['partial_vol_origin'] = self.partial_vol_origin + sample['meta'] = str(folder_id) + "_" + str(uid) + "_refview" + str(view_ids[0]) + + + # - image to render + sample['query_image'] = sample['images'][0] + sample['query_c2w'] = sample['c2ws'][0] + sample['query_w2c'] = sample['w2cs'][0] + sample['query_intrinsic'] = sample['intrinsics'][0] + sample['query_depth'] = sample['depths_h'][0] + sample['query_mask'] = sample['masks_h'][0] + sample['query_near_far'] = sample['near_fars'][0] + + sample['images'] = sample['images'][start_idx:] # (V, 3, H, W) + sample['depths_h'] = sample['depths_h'][start_idx:] # (V, H, W) + sample['masks_h'] = sample['masks_h'][start_idx:] # (V, H, W) + sample['w2cs'] = sample['w2cs'][start_idx:] # (V, 4, 4) + sample['c2ws'] = sample['c2ws'][start_idx:] # (V, 4, 4) + sample['intrinsics'] = sample['intrinsics'][start_idx:] # (V, 3, 3) + sample['view_ids'] = sample['view_ids'][start_idx:] + sample['affine_mats'] = sample['affine_mats'][start_idx:] # ! in world space + + sample['scale_mat'] = torch.from_numpy(scale_mat) + sample['trans_mat'] = torch.from_numpy(w2c_ref_inv) + + # - generate rays + if ('val' in self.split) or ('test' in self.split): + sample_rays = gen_rays_from_single_image( + img_wh[1], img_wh[0], + sample['query_image'], + sample['query_intrinsic'], + sample['query_c2w'], + depth=sample['query_depth'], + mask=sample['query_mask'] if self.clean_image else None) + + else: + sample_rays = gen_random_rays_from_single_image( + img_wh[1], img_wh[0], + self.N_rays, + sample['query_image'], + sample['query_intrinsic'], + sample['query_c2w'], + depth=sample['query_depth'], + mask=sample['query_mask'] if self.clean_image else None, + dilated_mask=mask_dilated, + importance_sample=self.importance_sample) + + + sample['rays'] = sample_rays + + return sample diff --git a/SparseNeuS_demo_v1/data/blender_general_360_2_stage_1_3.py b/SparseNeuS_demo_v1/data/blender_general_360_2_stage_1_3.py new file mode 100644 index 0000000000000000000000000000000000000000..72ad72bbfb336fa3e0d8b69f74c94afbea1593b7 --- /dev/null +++ b/SparseNeuS_demo_v1/data/blender_general_360_2_stage_1_3.py @@ -0,0 +1,406 @@ +from torch.utils.data import Dataset +from utils.misc_utils import read_pfm +import os +import numpy as np +import cv2 +from PIL import Image +import torch +from torchvision import transforms as T +from data.scene import get_boundingbox + +from models.rays import gen_rays_from_single_image, gen_random_rays_from_single_image +import json +from termcolor import colored +import imageio +from kornia import create_meshgrid +import open3d as o3d +def get_ray_directions(H, W, focal, center=None): + """ + Get ray directions for all pixels in camera coordinate. + Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/ + ray-tracing-generating-camera-rays/standard-coordinate-systems + Inputs: + H, W, focal: image height, width and focal length + Outputs: + directions: (H, W, 3), the direction of the rays in camera coordinate + """ + grid = create_meshgrid(H, W, normalized_coordinates=False)[0] + 0.5 # 1xHxWx2 + + i, j = grid.unbind(-1) + # the direction here is without +0.5 pixel centering as calibration is not so accurate + # see https://github.com/bmild/nerf/issues/24 + cent = center if center is not None else [W / 2, H / 2] + directions = torch.stack([(i - cent[0]) / focal[0], (j - cent[1]) / focal[1], torch.ones_like(i)], -1) # (H, W, 3) + + return directions + +def load_K_Rt_from_P(filename, P=None): + if P is None: + lines = open(filename).read().splitlines() + if len(lines) == 4: + lines = lines[1:] + lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)] + P = np.asarray(lines).astype(np.float32).squeeze() + + out = cv2.decomposeProjectionMatrix(P) + K = out[0] + R = out[1] + t = out[2] + + K = K / K[2, 2] + intrinsics = np.eye(4) + intrinsics[:3, :3] = K + + pose = np.eye(4, dtype=np.float32) + pose[:3, :3] = R.transpose() # ? why need transpose here + pose[:3, 3] = (t[:3] / t[3])[:, 0] + + return intrinsics, pose # ! return cam2world matrix here + + +# ! load one ref-image with multiple src-images in camera coordinate system +class BlenderPerView(Dataset): + def __init__(self, root_dir, split, n_views=3, img_wh=(256, 256), downSample=1.0, + split_filepath=None, pair_filepath=None, + N_rays=512, + vol_dims=[128, 128, 128], batch_size=1, + clean_image=False, importance_sample=False, test_ref_views=[]): + + # print("root_dir: ", root_dir) + self.root_dir = root_dir + self.split = split + + self.n_views = n_views + self.N_rays = N_rays + self.batch_size = batch_size # - used for construct new metas for gru fusion training + + self.clean_image = clean_image + self.importance_sample = importance_sample + self.test_ref_views = test_ref_views # used for testing + self.scale_factor = 1.0 + self.scale_mat = np.float32(np.diag([1, 1, 1, 1.0])) + + lvis_json_path = '/objaverse-processed/zero12345_img/lvis_split.json' # folder_id and uid + with open(lvis_json_path, 'r') as f: + lvis_paths = json.load(f) + if self.split == 'train': + self.lvis_paths = lvis_paths['train'] + else: + self.lvis_paths = lvis_paths['val'] + if img_wh is not None: + assert img_wh[0] % 32 == 0 and img_wh[1] % 32 == 0, \ + 'img_wh must both be multiples of 32!' + + + pose_json_path = "/objaverse-processed/zero12345_img/zero12345_2stage_pose.json" + with open(pose_json_path, 'r') as f: + meta = json.load(f) + + self.img_ids = list(meta["c2ws"].keys()) # e.g. "view_0_0", "view_0_5", "view_1_7" + self.img_wh = (256, 256) + self.input_poses = np.array(list(meta["c2ws"].values())) + intrinsic = np.eye(4) + intrinsic[:3, :3] = np.array(meta["intrinsics"]) + self.intrinsic = intrinsic + self.near_far = np.array(meta["near_far"]) + + self.define_transforms() + self.blender2opencv = np.array( + [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]] + ) + + + self.c2ws = [] + self.w2cs = [] + self.near_fars = [] + # self.root_dir = root_dir + for idx, img_id in enumerate(self.img_ids): + pose = self.input_poses[idx] + c2w = pose @ self.blender2opencv + self.c2ws.append(c2w) + self.w2cs.append(np.linalg.inv(c2w)) + self.near_fars.append(self.near_far) + self.c2ws = np.stack(self.c2ws, axis=0) + self.w2cs = np.stack(self.w2cs, axis=0) + + + self.all_intrinsics = [] # the cam info of the whole scene + self.all_extrinsics = [] + self.all_near_fars = [] + self.load_cam_info() + + # * bounding box for rendering + self.bbox_min = np.array([-1.0, -1.0, -1.0]) + self.bbox_max = np.array([1.0, 1.0, 1.0]) + + # - used for cost volume regularization + self.voxel_dims = torch.tensor(vol_dims, dtype=torch.float32) + self.partial_vol_origin = torch.tensor([-1., -1., -1.], dtype=torch.float32) + + + def define_transforms(self): + self.transform = T.Compose([T.ToTensor()]) + + + + def load_cam_info(self): + for vid, img_id in enumerate(self.img_ids): + intrinsic, extrinsic, near_far = self.intrinsic, np.linalg.inv(self.c2ws[vid]), self.near_far + self.all_intrinsics.append(intrinsic) + self.all_extrinsics.append(extrinsic) + self.all_near_fars.append(near_far) + + def read_depth(self, filename): + depth_h = np.array(read_pfm(filename)[0], dtype=np.float32) # (1200, 1600) + depth_h = cv2.resize(depth_h, None, fx=0.5, fy=0.5, + interpolation=cv2.INTER_NEAREST) # (600, 800) + depth_h = depth_h[44:556, 80:720] # (512, 640) + depth_h = cv2.resize(depth_h, None, fx=self.downSample, fy=self.downSample, + interpolation=cv2.INTER_NEAREST) + depth = cv2.resize(depth_h, None, fx=1.0 / 4, fy=1.0 / 4, + interpolation=cv2.INTER_NEAREST) + + return depth, depth_h + + def read_mask(self, filename): + mask_h = cv2.imread(filename, 0) + mask_h = cv2.resize(mask_h, None, fx=self.downSample, fy=self.downSample, + interpolation=cv2.INTER_NEAREST) + mask = cv2.resize(mask_h, None, fx=0.25, fy=0.25, + interpolation=cv2.INTER_NEAREST) + + mask[mask > 0] = 1 # the masks stored in png are not binary + mask_h[mask_h > 0] = 1 + + return mask, mask_h + + def cal_scale_mat(self, img_hw, intrinsics, extrinsics, near_fars, factor=1.): + + center, radius, bounds = get_boundingbox(img_hw, intrinsics, extrinsics, near_fars) + # print("center", center) + # print("radius", radius) + # print("bounds", bounds) + # import ipdb; ipdb.set_trace() + radius = radius * factor + scale_mat = np.diag([radius, radius, radius, 1.0]) + scale_mat[:3, 3] = center.cpu().numpy() + scale_mat = scale_mat.astype(np.float32) + + return scale_mat, 1. / radius.cpu().numpy() + + def __len__(self): + return 6*len(self.lvis_paths) + + + def read_depth(self, filename, near_bound, noisy_factor=1.0): + pass + + + def __getitem__(self, idx): + sample = {} + + imgs, depths_h, masks_h = [], [], [] # full size (256, 256) + intrinsics, w2cs, c2ws, near_fars = [], [], [], [] # record proj mats between views + + + folder_uid_dict = self.lvis_paths[idx//6] + idx = idx % 6 + + folder_id = folder_uid_dict['folder_id'] + uid = folder_uid_dict['uid'] + + # idx = idx % 24 # [0, 23] + + + + # target view + c2w = self.c2ws[idx] + w2c = np.linalg.inv(c2w) + w2c_ref = w2c + w2c_ref_inv = np.linalg.inv(w2c_ref) + + w2cs.append(w2c @ w2c_ref_inv) + c2ws.append(np.linalg.inv(w2c @ w2c_ref_inv)) + + img_filename = os.path.join(self.root_dir, folder_id, uid, f'view_0_{idx}_gt.png') + + depth_filename = os.path.join(self.root_dir, folder_id, uid, f'view_0_{idx}_gt_depth_mm.png') + + + img = Image.open(img_filename) + + img = self.transform(img) # (4, h, w) + + + if img.shape[0] == 4: + img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB + imgs += [img] + + depth_h = cv2.imread(depth_filename, cv2.IMREAD_UNCHANGED).astype(np.uint16) / 1000.0 + mask_h = depth_h > 0 + # print("valid pixels", np.sum(mask_h)) + directions = get_ray_directions(self.img_wh[1], self.img_wh[0], [self.intrinsic[0, 0], self.intrinsic[1, 1]]) # [H, W, 3] + surface_points = directions * depth_h[..., None] # [H, W, 3] + distance = np.linalg.norm(surface_points, axis=-1) # [H, W] + depth_h = distance + + + depths_h.append(depth_h) + masks_h.append(mask_h) + + intrinsic = self.intrinsic + intrinsics.append(intrinsic) + + + + near_fars.append(self.near_fars[idx]) + image_perm = 0 # only supervised on reference view + + mask_dilated = None + + # src_views = range(gt_view_idx * 12, (gt_view_idx + 1) * 12) + + + src_views = range(6+idx*4, 6+(idx+1)*4) + + for vid in src_views: + # if vid == idx: + # continue + img_filename = os.path.join(self.root_dir, folder_id, uid, f'view_0_{idx}_{vid % 4}.png') + + img = Image.open(img_filename) + img_wh = self.img_wh + + img = self.transform(img) + if img.shape[0] == 4: + img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB + + imgs += [img] + depth_h = np.ones(img.shape[1:], dtype=np.float32) + depths_h.append(depth_h) + masks_h.append(np.ones(img.shape[1:], dtype=np.int32)) + + near_fars.append(self.all_near_fars[vid]) + intrinsics.append(self.all_intrinsics[vid]) + + w2cs.append(self.all_extrinsics[vid] @ w2c_ref_inv) + + + # ! estimate scale_mat + scale_mat, scale_factor = self.cal_scale_mat( + img_hw=[img_wh[1], img_wh[0]], + intrinsics=intrinsics, extrinsics=w2cs, + near_fars=near_fars, factor=1.1 + ) + # print(scale_mat) + # print(scale_factor) + # ! calculate the new w2cs after scaling + new_near_fars = [] + new_w2cs = [] + new_c2ws = [] + new_affine_mats = [] + new_depths_h = [] + for intrinsic, extrinsic, near_far, depth in zip(intrinsics, w2cs, near_fars, depths_h): + + P = intrinsic @ extrinsic @ scale_mat + P = P[:3, :4] + # - should use load_K_Rt_from_P() to obtain c2w + c2w = load_K_Rt_from_P(None, P)[1] + w2c = np.linalg.inv(c2w) + new_w2cs.append(w2c) + new_c2ws.append(c2w) + affine_mat = np.eye(4) + affine_mat[:3, :4] = intrinsic[:3, :3] @ w2c[:3, :4] + new_affine_mats.append(affine_mat) + + camera_o = c2w[:3, 3] + dist = np.sqrt(np.sum(camera_o ** 2)) + near = dist - 1 + far = dist + 1 + + new_near_fars.append([0.95 * near, 1.05 * far]) + new_depths_h.append(depth * scale_factor) + + # print(new_near_fars) + imgs = torch.stack(imgs).float() + depths_h = np.stack(new_depths_h) + masks_h = np.stack(masks_h) + + affine_mats = np.stack(new_affine_mats) + intrinsics, w2cs, c2ws, near_fars = np.stack(intrinsics), np.stack(new_w2cs), np.stack(new_c2ws), np.stack( + new_near_fars) + + if self.split == 'train': + start_idx = 0 + else: + start_idx = 1 + + view_ids = [idx] + list(src_views) + + sample['images'] = imgs # (V, 3, H, W) + sample['depths_h'] = torch.from_numpy(depths_h.astype(np.float32)) # (V, H, W) + sample['masks_h'] = torch.from_numpy(masks_h.astype(np.float32)) # (V, H, W) + sample['w2cs'] = torch.from_numpy(w2cs.astype(np.float32)) # (V, 4, 4) + sample['c2ws'] = torch.from_numpy(c2ws.astype(np.float32)) # (V, 4, 4) + sample['near_fars'] = torch.from_numpy(near_fars.astype(np.float32)) # (V, 2) + sample['intrinsics'] = torch.from_numpy(intrinsics.astype(np.float32))[:, :3, :3] # (V, 3, 3) + sample['view_ids'] = torch.from_numpy(np.array(view_ids)) + sample['affine_mats'] = torch.from_numpy(affine_mats.astype(np.float32)) # ! in world space + + # sample['light_idx'] = torch.tensor(light_idx) + sample['scan'] = folder_id + + sample['scale_factor'] = torch.tensor(scale_factor) + sample['img_wh'] = torch.from_numpy(np.array(img_wh)) + sample['render_img_idx'] = torch.tensor(image_perm) + sample['partial_vol_origin'] = self.partial_vol_origin + sample['meta'] = str(folder_id) + "_" + str(uid) + "_refview" + str(view_ids[0]) + + + # - image to render + sample['query_image'] = sample['images'][0] + sample['query_c2w'] = sample['c2ws'][0] + sample['query_w2c'] = sample['w2cs'][0] + sample['query_intrinsic'] = sample['intrinsics'][0] + sample['query_depth'] = sample['depths_h'][0] + sample['query_mask'] = sample['masks_h'][0] + sample['query_near_far'] = sample['near_fars'][0] + + sample['images'] = sample['images'][start_idx:] # (V, 3, H, W) + sample['depths_h'] = sample['depths_h'][start_idx:] # (V, H, W) + sample['masks_h'] = sample['masks_h'][start_idx:] # (V, H, W) + sample['w2cs'] = sample['w2cs'][start_idx:] # (V, 4, 4) + sample['c2ws'] = sample['c2ws'][start_idx:] # (V, 4, 4) + sample['intrinsics'] = sample['intrinsics'][start_idx:] # (V, 3, 3) + sample['view_ids'] = sample['view_ids'][start_idx:] + sample['affine_mats'] = sample['affine_mats'][start_idx:] # ! in world space + + sample['scale_mat'] = torch.from_numpy(scale_mat) + sample['trans_mat'] = torch.from_numpy(w2c_ref_inv) + + # - generate rays + if ('val' in self.split) or ('test' in self.split): + sample_rays = gen_rays_from_single_image( + img_wh[1], img_wh[0], + sample['query_image'], + sample['query_intrinsic'], + sample['query_c2w'], + depth=sample['query_depth'], + mask=sample['query_mask'] if self.clean_image else None) + + else: + sample_rays = gen_random_rays_from_single_image( + img_wh[1], img_wh[0], + self.N_rays, + sample['query_image'], + sample['query_intrinsic'], + sample['query_c2w'], + depth=sample['query_depth'], + mask=sample['query_mask'] if self.clean_image else None, + dilated_mask=mask_dilated, + importance_sample=self.importance_sample) + + + sample['rays'] = sample_rays + + return sample diff --git a/SparseNeuS_demo_v1/data/blender_general_360_2_stage_1_4.py b/SparseNeuS_demo_v1/data/blender_general_360_2_stage_1_4.py new file mode 100644 index 0000000000000000000000000000000000000000..380706615bfe4a183b302f127af9913bfc2f4790 --- /dev/null +++ b/SparseNeuS_demo_v1/data/blender_general_360_2_stage_1_4.py @@ -0,0 +1,411 @@ +from torch.utils.data import Dataset +from utils.misc_utils import read_pfm +import os +import numpy as np +import cv2 +from PIL import Image +import torch +from torchvision import transforms as T +from data.scene import get_boundingbox + +from models.rays import gen_rays_from_single_image, gen_random_rays_from_single_image +import json +from termcolor import colored +import imageio +from kornia import create_meshgrid +import open3d as o3d +def get_ray_directions(H, W, focal, center=None): + """ + Get ray directions for all pixels in camera coordinate. + Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/ + ray-tracing-generating-camera-rays/standard-coordinate-systems + Inputs: + H, W, focal: image height, width and focal length + Outputs: + directions: (H, W, 3), the direction of the rays in camera coordinate + """ + grid = create_meshgrid(H, W, normalized_coordinates=False)[0] + 0.5 # 1xHxWx2 + + i, j = grid.unbind(-1) + # the direction here is without +0.5 pixel centering as calibration is not so accurate + # see https://github.com/bmild/nerf/issues/24 + cent = center if center is not None else [W / 2, H / 2] + directions = torch.stack([(i - cent[0]) / focal[0], (j - cent[1]) / focal[1], torch.ones_like(i)], -1) # (H, W, 3) + + return directions + +def load_K_Rt_from_P(filename, P=None): + if P is None: + lines = open(filename).read().splitlines() + if len(lines) == 4: + lines = lines[1:] + lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)] + P = np.asarray(lines).astype(np.float32).squeeze() + + out = cv2.decomposeProjectionMatrix(P) + K = out[0] + R = out[1] + t = out[2] + + K = K / K[2, 2] + intrinsics = np.eye(4) + intrinsics[:3, :3] = K + + pose = np.eye(4, dtype=np.float32) + pose[:3, :3] = R.transpose() # ? why need transpose here + pose[:3, 3] = (t[:3] / t[3])[:, 0] + + return intrinsics, pose # ! return cam2world matrix here + + +# ! load one ref-image with multiple src-images in camera coordinate system +class BlenderPerView(Dataset): + def __init__(self, root_dir, split, n_views=3, img_wh=(256, 256), downSample=1.0, + split_filepath=None, pair_filepath=None, + N_rays=512, + vol_dims=[128, 128, 128], batch_size=1, + clean_image=False, importance_sample=False, test_ref_views=[]): + + # print("root_dir: ", root_dir) + self.root_dir = root_dir + self.split = split + + self.n_views = n_views + self.N_rays = N_rays + self.batch_size = batch_size # - used for construct new metas for gru fusion training + + self.clean_image = clean_image + self.importance_sample = importance_sample + self.test_ref_views = test_ref_views # used for testing + self.scale_factor = 1.0 + self.scale_mat = np.float32(np.diag([1, 1, 1, 1.0])) + + lvis_json_path = '/objaverse-processed/zero12345_img/lvis_split.json' # folder_id and uid + with open(lvis_json_path, 'r') as f: + lvis_paths = json.load(f) + if self.split == 'train': + self.lvis_paths = lvis_paths['train'] + else: + self.lvis_paths = lvis_paths['val'] + if img_wh is not None: + assert img_wh[0] % 32 == 0 and img_wh[1] % 32 == 0, \ + 'img_wh must both be multiples of 32!' + + + pose_json_path = "/objaverse-processed/zero12345_img/zero12345_2stage_5pred_pose.json" + with open(pose_json_path, 'r') as f: + meta = json.load(f) + + self.img_ids = list(meta["c2ws"].keys()) # e.g. "view_0_0", "view_0_5", "view_1_7" + self.img_wh = (256, 256) + self.input_poses = np.array(list(meta["c2ws"].values())) + intrinsic = np.eye(4) + intrinsic[:3, :3] = np.array(meta["intrinsics"]) + self.intrinsic = intrinsic + self.near_far = np.array(meta["near_far"]) + + self.define_transforms() + self.blender2opencv = np.array( + [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]] + ) + + + self.c2ws = [] + self.w2cs = [] + self.near_fars = [] + # self.root_dir = root_dir + for idx, img_id in enumerate(self.img_ids): + pose = self.input_poses[idx] + c2w = pose @ self.blender2opencv + self.c2ws.append(c2w) + self.w2cs.append(np.linalg.inv(c2w)) + self.near_fars.append(self.near_far) + self.c2ws = np.stack(self.c2ws, axis=0) + self.w2cs = np.stack(self.w2cs, axis=0) + + + self.all_intrinsics = [] # the cam info of the whole scene + self.all_extrinsics = [] + self.all_near_fars = [] + self.load_cam_info() + + # * bounding box for rendering + self.bbox_min = np.array([-1.0, -1.0, -1.0]) + self.bbox_max = np.array([1.0, 1.0, 1.0]) + + # - used for cost volume regularization + self.voxel_dims = torch.tensor(vol_dims, dtype=torch.float32) + self.partial_vol_origin = torch.tensor([-1., -1., -1.], dtype=torch.float32) + + + def define_transforms(self): + self.transform = T.Compose([T.ToTensor()]) + + + + def load_cam_info(self): + for vid, img_id in enumerate(self.img_ids): + intrinsic, extrinsic, near_far = self.intrinsic, np.linalg.inv(self.c2ws[vid]), self.near_far + self.all_intrinsics.append(intrinsic) + self.all_extrinsics.append(extrinsic) + self.all_near_fars.append(near_far) + + def read_depth(self, filename): + depth_h = np.array(read_pfm(filename)[0], dtype=np.float32) # (1200, 1600) + depth_h = cv2.resize(depth_h, None, fx=0.5, fy=0.5, + interpolation=cv2.INTER_NEAREST) # (600, 800) + depth_h = depth_h[44:556, 80:720] # (512, 640) + depth_h = cv2.resize(depth_h, None, fx=self.downSample, fy=self.downSample, + interpolation=cv2.INTER_NEAREST) + depth = cv2.resize(depth_h, None, fx=1.0 / 4, fy=1.0 / 4, + interpolation=cv2.INTER_NEAREST) + + return depth, depth_h + + def read_mask(self, filename): + mask_h = cv2.imread(filename, 0) + mask_h = cv2.resize(mask_h, None, fx=self.downSample, fy=self.downSample, + interpolation=cv2.INTER_NEAREST) + mask = cv2.resize(mask_h, None, fx=0.25, fy=0.25, + interpolation=cv2.INTER_NEAREST) + + mask[mask > 0] = 1 # the masks stored in png are not binary + mask_h[mask_h > 0] = 1 + + return mask, mask_h + + def cal_scale_mat(self, img_hw, intrinsics, extrinsics, near_fars, factor=1.): + + center, radius, bounds = get_boundingbox(img_hw, intrinsics, extrinsics, near_fars) + # print("center", center) + # print("radius", radius) + # print("bounds", bounds) + # import ipdb; ipdb.set_trace() + radius = radius * factor + scale_mat = np.diag([radius, radius, radius, 1.0]) + scale_mat[:3, 3] = center.cpu().numpy() + scale_mat = scale_mat.astype(np.float32) + + return scale_mat, 1. / radius.cpu().numpy() + + def __len__(self): + return 6*len(self.lvis_paths) + + + def read_depth(self, filename, near_bound, noisy_factor=1.0): + pass + + + def __getitem__(self, idx): + sample = {} + + imgs, depths_h, masks_h = [], [], [] # full size (256, 256) + intrinsics, w2cs, c2ws, near_fars = [], [], [], [] # record proj mats between views + + + folder_uid_dict = self.lvis_paths[idx//6] + idx = idx % 6 + + folder_id = folder_uid_dict['folder_id'] + uid = folder_uid_dict['uid'] + + # idx = idx % 24 # [0, 23] + + + + # target view + c2w = self.c2ws[idx] + w2c = np.linalg.inv(c2w) + w2c_ref = w2c + w2c_ref_inv = np.linalg.inv(w2c_ref) + + w2cs.append(w2c @ w2c_ref_inv) + c2ws.append(np.linalg.inv(w2c @ w2c_ref_inv)) + + img_filename = os.path.join("/objaverse-processed/zero12345_img/zero12345_2stage", folder_id, uid, f'view_0_{idx}_gt.png') + + depth_filename = os.path.join("/objaverse-processed/zero12345_img/zero12345_2stage", folder_id, uid, f'view_0_{idx}_gt_depth_mm.png') + + + img = Image.open(img_filename) + + img = self.transform(img) # (4, h, w) + + # print("img_pre", img.shape) + if img.shape[0] == 4: + img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB + # print("img", img.shape) + imgs += [img] + + + depth_h = cv2.imread(depth_filename, cv2.IMREAD_UNCHANGED).astype(np.uint16) / 1000.0 + mask_h = depth_h > 0 + # print("valid pixels", np.sum(mask_h)) + directions = get_ray_directions(self.img_wh[1], self.img_wh[0], [self.intrinsic[0, 0], self.intrinsic[1, 1]]) # [H, W, 3] + surface_points = directions * depth_h[..., None] # [H, W, 3] + distance = np.linalg.norm(surface_points, axis=-1) # [H, W] + depth_h = distance + # print("depth_h", depth_h.shape) + + depths_h.append(depth_h) + masks_h.append(mask_h) + + intrinsic = self.intrinsic + intrinsics.append(intrinsic) + + + near_fars.append(self.near_fars[idx]) + image_perm = 0 # only supervised on reference view + + mask_dilated = None + + # src_views = range(gt_view_idx * 12, (gt_view_idx + 1) * 12) + + + src_views = range(6+idx*4, 6+(idx+1)*4) + + for vid in src_views: + # if vid == idx: + # continue + img_filename = os.path.join(self.root_dir, folder_id, uid, f'view_0_{idx}_{vid % 4 + 1}.png') + + img = Image.open(img_filename) + img_wh = self.img_wh + + img = self.transform(img) + # print("img shape1: ", img.shape) + if img.shape[0] == 4: + img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB + # print("img shape2: ", img.shape) + imgs += [img] + depth_h = np.ones(img.shape[1:], dtype=np.float32) + depths_h.append(depth_h) + masks_h.append(np.ones(img.shape[1:], dtype=np.int32)) + + near_fars.append(self.all_near_fars[vid]) + intrinsics.append(self.all_intrinsics[vid]) + + w2cs.append(self.all_extrinsics[vid] @ w2c_ref_inv) + + + # ! estimate scale_mat + scale_mat, scale_factor = self.cal_scale_mat( + img_hw=[img_wh[1], img_wh[0]], + intrinsics=intrinsics, extrinsics=w2cs, + near_fars=near_fars, factor=1.1 + ) + # print(scale_mat) + # print(scale_factor) + # ! calculate the new w2cs after scaling + new_near_fars = [] + new_w2cs = [] + new_c2ws = [] + new_affine_mats = [] + new_depths_h = [] + for intrinsic, extrinsic, near_far, depth in zip(intrinsics, w2cs, near_fars, depths_h): + + P = intrinsic @ extrinsic @ scale_mat + P = P[:3, :4] + # - should use load_K_Rt_from_P() to obtain c2w + c2w = load_K_Rt_from_P(None, P)[1] + w2c = np.linalg.inv(c2w) + new_w2cs.append(w2c) + new_c2ws.append(c2w) + affine_mat = np.eye(4) + affine_mat[:3, :4] = intrinsic[:3, :3] @ w2c[:3, :4] + new_affine_mats.append(affine_mat) + + camera_o = c2w[:3, 3] + dist = np.sqrt(np.sum(camera_o ** 2)) + near = dist - 1 + far = dist + 1 + + new_near_fars.append([0.95 * near, 1.05 * far]) + new_depths_h.append(depth * scale_factor) + + # print(new_near_fars) + # print("imgs: ", len(imgs)) + # print("img1 shape:", imgs[0].shape) + # print("img2 shape:", imgs[1].shape) + imgs = torch.stack(imgs).float() + depths_h = np.stack(new_depths_h) + masks_h = np.stack(masks_h) + + affine_mats = np.stack(new_affine_mats) + intrinsics, w2cs, c2ws, near_fars = np.stack(intrinsics), np.stack(new_w2cs), np.stack(new_c2ws), np.stack( + new_near_fars) + + if self.split == 'train': + start_idx = 0 + else: + start_idx = 1 + + view_ids = [idx] + list(src_views) + + sample['images'] = imgs # (V, 3, H, W) + sample['depths_h'] = torch.from_numpy(depths_h.astype(np.float32)) # (V, H, W) + sample['masks_h'] = torch.from_numpy(masks_h.astype(np.float32)) # (V, H, W) + sample['w2cs'] = torch.from_numpy(w2cs.astype(np.float32)) # (V, 4, 4) + sample['c2ws'] = torch.from_numpy(c2ws.astype(np.float32)) # (V, 4, 4) + sample['near_fars'] = torch.from_numpy(near_fars.astype(np.float32)) # (V, 2) + sample['intrinsics'] = torch.from_numpy(intrinsics.astype(np.float32))[:, :3, :3] # (V, 3, 3) + sample['view_ids'] = torch.from_numpy(np.array(view_ids)) + sample['affine_mats'] = torch.from_numpy(affine_mats.astype(np.float32)) # ! in world space + + # sample['light_idx'] = torch.tensor(light_idx) + sample['scan'] = folder_id + + sample['scale_factor'] = torch.tensor(scale_factor) + sample['img_wh'] = torch.from_numpy(np.array(img_wh)) + sample['render_img_idx'] = torch.tensor(image_perm) + sample['partial_vol_origin'] = self.partial_vol_origin + sample['meta'] = str(folder_id) + "_" + str(uid) + "_refview" + str(view_ids[0]) + + + # - image to render + sample['query_image'] = sample['images'][0] + sample['query_c2w'] = sample['c2ws'][0] + sample['query_w2c'] = sample['w2cs'][0] + sample['query_intrinsic'] = sample['intrinsics'][0] + sample['query_depth'] = sample['depths_h'][0] + sample['query_mask'] = sample['masks_h'][0] + sample['query_near_far'] = sample['near_fars'][0] + + sample['images'] = sample['images'][start_idx:] # (V, 3, H, W) + sample['depths_h'] = sample['depths_h'][start_idx:] # (V, H, W) + sample['masks_h'] = sample['masks_h'][start_idx:] # (V, H, W) + sample['w2cs'] = sample['w2cs'][start_idx:] # (V, 4, 4) + sample['c2ws'] = sample['c2ws'][start_idx:] # (V, 4, 4) + sample['intrinsics'] = sample['intrinsics'][start_idx:] # (V, 3, 3) + sample['view_ids'] = sample['view_ids'][start_idx:] + sample['affine_mats'] = sample['affine_mats'][start_idx:] # ! in world space + + sample['scale_mat'] = torch.from_numpy(scale_mat) + sample['trans_mat'] = torch.from_numpy(w2c_ref_inv) + + # - generate rays + if ('val' in self.split) or ('test' in self.split): + sample_rays = gen_rays_from_single_image( + img_wh[1], img_wh[0], + sample['query_image'], + sample['query_intrinsic'], + sample['query_c2w'], + depth=sample['query_depth'], + mask=sample['query_mask'] if self.clean_image else None) + + else: + sample_rays = gen_random_rays_from_single_image( + img_wh[1], img_wh[0], + self.N_rays, + sample['query_image'], + sample['query_intrinsic'], + sample['query_c2w'], + depth=sample['query_depth'], + mask=sample['query_mask'] if self.clean_image else None, + dilated_mask=mask_dilated, + importance_sample=self.importance_sample) + + + sample['rays'] = sample_rays + + return sample diff --git a/SparseNeuS_demo_v1/data/blender_general_4_narrow_and_4_2_stage_mix.py b/SparseNeuS_demo_v1/data/blender_general_4_narrow_and_4_2_stage_mix.py new file mode 100644 index 0000000000000000000000000000000000000000..beb1f976907680936b20b37d76133589804d40c5 --- /dev/null +++ b/SparseNeuS_demo_v1/data/blender_general_4_narrow_and_4_2_stage_mix.py @@ -0,0 +1,480 @@ +from torch.utils.data import Dataset +from utils.misc_utils import read_pfm +import os +import numpy as np +import cv2 +from PIL import Image +import torch +from torchvision import transforms as T +from data.scene import get_boundingbox + +from models.rays import gen_rays_from_single_image, gen_random_rays_from_single_image +import json +from termcolor import colored +import imageio +from kornia import create_meshgrid +import open3d as o3d +def get_ray_directions(H, W, focal, center=None): + """ + Get ray directions for all pixels in camera coordinate. + Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/ + ray-tracing-generating-camera-rays/standard-coordinate-systems + Inputs: + H, W, focal: image height, width and focal length + Outputs: + directions: (H, W, 3), the direction of the rays in camera coordinate + """ + grid = create_meshgrid(H, W, normalized_coordinates=False)[0] + 0.5 # 1xHxWx2 + + i, j = grid.unbind(-1) + # the direction here is without +0.5 pixel centering as calibration is not so accurate + # see https://github.com/bmild/nerf/issues/24 + cent = center if center is not None else [W / 2, H / 2] + directions = torch.stack([(i - cent[0]) / focal[0], (j - cent[1]) / focal[1], torch.ones_like(i)], -1) # (H, W, 3) + + return directions + +def load_K_Rt_from_P(filename, P=None): + if P is None: + lines = open(filename).read().splitlines() + if len(lines) == 4: + lines = lines[1:] + lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)] + P = np.asarray(lines).astype(np.float32).squeeze() + + out = cv2.decomposeProjectionMatrix(P) + K = out[0] + R = out[1] + t = out[2] + + K = K / K[2, 2] + intrinsics = np.eye(4) + intrinsics[:3, :3] = K + + pose = np.eye(4, dtype=np.float32) + pose[:3, :3] = R.transpose() # ? why need transpose here + pose[:3, 3] = (t[:3] / t[3])[:, 0] + + return intrinsics, pose # ! return cam2world matrix here + + +# ! load one ref-image with multiple src-images in camera coordinate system +class BlenderPerView(Dataset): + def __init__(self, root_dir, split, n_views=3, img_wh=(256, 256), downSample=1.0, + split_filepath=None, pair_filepath=None, + N_rays=512, + vol_dims=[128, 128, 128], batch_size=1, + clean_image=False, importance_sample=False, test_ref_views=[]): + + self.root_dir = root_dir + self.split = split + self.imgs_per_instance = 16 + self.n_views = n_views + self.N_rays = N_rays + self.batch_size = batch_size # - used for construct new metas for gru fusion training + + self.clean_image = clean_image + self.importance_sample = importance_sample + self.test_ref_views = test_ref_views # used for testing + self.scale_factor = 1.0 + self.scale_mat = np.float32(np.diag([1, 1, 1, 1.0])) + + lvis_json_path = '/objaverse-processed/zero12345_img/lvis_split.json' # folder_id and uid + with open(lvis_json_path, 'r') as f: + lvis_paths = json.load(f) + if self.split == 'train': + self.lvis_paths = lvis_paths['train'] + else: + self.lvis_paths = lvis_paths['val'] + if img_wh is not None: + assert img_wh[0] % 32 == 0 and img_wh[1] % 32 == 0, \ + 'img_wh must both be multiples of 32!' + + + pose_json_path_narrow = "/objaverse-processed/zero12345_img/zero12345_narrow_pose.json" + with open(pose_json_path_narrow, 'r') as f: + narrow_meta = json.load(f) + + pose_json_path_two_stage = "/objaverse-processed/zero12345_img/zero12345_2stage_8_pose.json" + with open(pose_json_path_two_stage, 'r') as f: + two_stage_meta = json.load(f) + + + self.img_ids = list(narrow_meta["c2ws"].keys()) + list(two_stage_meta["c2ws"].keys()) # (8 + 8*4) + (8 + 4*4) + self.img_wh = (256, 256) + self.input_poses = np.array(list(narrow_meta["c2ws"].values()) + list(two_stage_meta["c2ws"].values())) + intrinsic = np.eye(4) + assert narrow_meta["intrinsics"] == two_stage_meta["intrinsics"], "intrinsics not equal" + intrinsic[:3, :3] = np.array(narrow_meta["intrinsics"]) + self.intrinsic = intrinsic + assert narrow_meta["near_far"] == two_stage_meta["near_far"], "near_far not equal" + self.near_far = np.array(narrow_meta["near_far"]) + self.near_far[1] = 1.8 + self.define_transforms() + self.blender2opencv = np.array( + [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]] + ) + + + self.c2ws = [] + self.w2cs = [] + self.near_fars = [] + for idx, img_id in enumerate(self.img_ids): + pose = self.input_poses[idx] + c2w = pose @ self.blender2opencv + self.c2ws.append(c2w) + self.w2cs.append(np.linalg.inv(c2w)) + self.near_fars.append(self.near_far) + + + + self.c2ws = np.stack(self.c2ws, axis=0) + self.w2cs = np.stack(self.w2cs, axis=0) + + + self.all_intrinsics = [] # the cam info of the whole scene + self.all_extrinsics = [] + self.all_near_fars = [] + self.load_cam_info() + + # * bounding box for rendering + self.bbox_min = np.array([-1.0, -1.0, -1.0]) + self.bbox_max = np.array([1.0, 1.0, 1.0]) + + # - used for cost volume regularization + self.voxel_dims = torch.tensor(vol_dims, dtype=torch.float32) + self.partial_vol_origin = torch.tensor([-1., -1., -1.], dtype=torch.float32) + + + def define_transforms(self): + self.transform = T.Compose([T.ToTensor()]) + + + + def load_cam_info(self): + for vid, img_id in enumerate(self.img_ids): + intrinsic, extrinsic, near_far = self.intrinsic, np.linalg.inv(self.c2ws[vid]), self.near_far + self.all_intrinsics.append(intrinsic) + self.all_extrinsics.append(extrinsic) + self.all_near_fars.append(near_far) + + def read_depth(self, filename): + pass + + def read_mask(self, filename): + mask_h = cv2.imread(filename, 0) + mask_h = cv2.resize(mask_h, None, fx=self.downSample, fy=self.downSample, + interpolation=cv2.INTER_NEAREST) + mask = cv2.resize(mask_h, None, fx=0.25, fy=0.25, + interpolation=cv2.INTER_NEAREST) + + mask[mask > 0] = 1 # the masks stored in png are not binary + mask_h[mask_h > 0] = 1 + + return mask, mask_h + + def cal_scale_mat(self, img_hw, intrinsics, extrinsics, near_fars, factor=1.): + + center, radius, bounds = get_boundingbox(img_hw, intrinsics, extrinsics, near_fars) + + radius = radius * factor + scale_mat = np.diag([radius, radius, radius, 1.0]) + scale_mat[:3, 3] = center.cpu().numpy() + scale_mat = scale_mat.astype(np.float32) + + return scale_mat, 1. / radius.cpu().numpy() + + def __len__(self): + return self.imgs_per_instance * len(self.lvis_paths) + + + def read_depth(self, filename, near_bound, noisy_factor=1.0): + pass + + + def __getitem__(self, idx): + sample = {} + origin_idx = idx + imgs, depths_h, masks_h = [], [], [] # full size (256, 256) + intrinsics, w2cs, c2ws, near_fars = [], [], [], [] # record proj mats between views + idx_original=idx + + folder_uid_dict = self.lvis_paths[idx//self.imgs_per_instance] + + folder_id = folder_uid_dict['folder_id'] + uid = folder_uid_dict['uid'] + + if idx % 2 == 0: + valid_list = [0, 2, 4, 6] + else: + valid_list = [1, 3, 5, 7] + + if idx % 16 < 8: + idx = idx % 16 # [0, 7] + # target view + c2w = self.c2ws[idx] + w2c = np.linalg.inv(c2w) + w2c_ref = w2c + w2c_ref_inv = np.linalg.inv(w2c_ref) + + w2cs.append(w2c @ w2c_ref_inv) + c2ws.append(np.linalg.inv(w2c @ w2c_ref_inv)) + + img_filename = os.path.join("/objaverse-processed/zero12345_img/zero12345_narrow/", folder_id, uid, f'view_{idx}.png') + + depth_filename = os.path.join(os.path.join("/objaverse-processed/zero12345_img/zero12345_narrow/", folder_id, uid, f'view_{idx}_depth_mm.png')) + + + img = Image.open(img_filename) + + img = self.transform(img) # (4, h, w) + + + if img.shape[0] == 4: + img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB + imgs += [img] + + depth_h = cv2.imread(depth_filename, cv2.IMREAD_UNCHANGED).astype(np.uint16) / 1000.0 + mask_h = depth_h > 0 + # print("valid pixels", np.sum(mask_h)) + directions = get_ray_directions(self.img_wh[1], self.img_wh[0], [self.intrinsic[0, 0], self.intrinsic[1, 1]]) # [H, W, 3] + surface_points = directions * depth_h[..., None] # [H, W, 3] + distance = np.linalg.norm(surface_points, axis=-1) # [H, W] + depth_h = distance + + + depths_h.append(depth_h) + masks_h.append(mask_h) + + intrinsic = self.intrinsic + intrinsics.append(intrinsic) + + + near_fars.append(self.near_fars[idx]) + image_perm = 0 # only supervised on reference view + + mask_dilated = None + + # src_views = range(8+idx*4, 8+(idx+1)*4) + + src_views = range(8, 8 + 8 * 4) + src_views_used = [] + for vid in src_views: + view_dix_to_use = (vid - 8) // 4 + if view_dix_to_use not in valid_list: + continue + src_views_used.append(vid) + img_filename = os.path.join("/objaverse-processed/zero12345_img/zero12345_narrow/", folder_id, uid, f'view_{(vid - 8) // 4}_{vid%4}_10.png') + + img = Image.open(img_filename) + img_wh = self.img_wh + + img = self.transform(img) + if img.shape[0] == 4: + img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB + + imgs += [img] + depth_h = np.ones(img.shape[1:], dtype=np.float32) + depths_h.append(depth_h) + masks_h.append(np.ones(img.shape[1:], dtype=np.int32)) + + near_fars.append(self.all_near_fars[vid]) + intrinsics.append(self.all_intrinsics[vid]) + + w2cs.append(self.all_extrinsics[vid] @ w2c_ref_inv) + + else: + idx = idx % 16 - 8 # [0, 7] + + c2w = self.c2ws[idx + 40] + w2c = np.linalg.inv(c2w) + w2c_ref = w2c + w2c_ref_inv = np.linalg.inv(w2c_ref) + + w2cs.append(w2c @ w2c_ref_inv) + c2ws.append(np.linalg.inv(w2c @ w2c_ref_inv)) + + img_filename = os.path.join("/objaverse-processed/zero12345_img/zero12345_2stage_8/", folder_id, uid, f'view_0_{idx}_0.png') + + + + img = Image.open(img_filename) + + img = self.transform(img) # (4, h, w) + + # print("img_pre", img.shape) + if img.shape[0] == 4: + img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB + # print("img", img.shape) + imgs += [img] + + + depth_h =torch.ones((img.shape[1], img.shape[2]), dtype=torch.float32) + depth_h = depth_h.fill_(-1.0) + # depth_h = torch.fill((img.shape[1], img.shape[2]), -1.0) + # print("depth_h", depth_h.shape) + mask_h = torch.ones((img.shape[1], img.shape[2]), dtype=torch.int32) + depths_h.append(depth_h) + masks_h.append(mask_h) + + intrinsic = self.intrinsic + intrinsics.append(intrinsic) + + + near_fars.append(self.near_fars[idx]) + image_perm = 0 # only supervised on reference view + + mask_dilated = None + + + + src_views = range(40+8, 40+8+32) + src_views_used = [] + for vid in src_views: + view_dix_to_use = (vid - 40 - 8) // 4 + if view_dix_to_use not in valid_list: + continue + src_views_used.append(vid) + img_filename = os.path.join("/objaverse-processed/zero12345_img/zero12345_2stage_8/", folder_id, uid, f'view_0_{idx}_{(vid-48) % 4 + 1}.png') + + img = Image.open(img_filename) + img_wh = self.img_wh + + img = self.transform(img) + # print("img shape1: ", img.shape) + if img.shape[0] == 4: + img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB + # print("img shape2: ", img.shape) + imgs += [img] + depth_h =torch.ones((img.shape[1], img.shape[2]), dtype=torch.float32) + depth_h = depth_h.fill_(-1.0) + depths_h.append(depth_h) + masks_h.append(np.ones(img.shape[1:], dtype=np.int32)) + + near_fars.append(self.all_near_fars[vid]) + intrinsics.append(self.all_intrinsics[vid]) + + w2cs.append(self.all_extrinsics[vid] @ w2c_ref_inv) + + + scale_mat, scale_factor = self.cal_scale_mat( + img_hw=[img_wh[1], img_wh[0]], + intrinsics=intrinsics, extrinsics=w2cs, + near_fars=near_fars, factor=1.1 + ) + + + new_near_fars = [] + new_w2cs = [] + new_c2ws = [] + new_affine_mats = [] + new_depths_h = [] + for intrinsic, extrinsic, near_far, depth in zip(intrinsics, w2cs, near_fars, depths_h): + + P = intrinsic @ extrinsic @ scale_mat + P = P[:3, :4] + # - should use load_K_Rt_from_P() to obtain c2w + c2w = load_K_Rt_from_P(None, P)[1] + w2c = np.linalg.inv(c2w) + new_w2cs.append(w2c) + new_c2ws.append(c2w) + affine_mat = np.eye(4) + affine_mat[:3, :4] = intrinsic[:3, :3] @ w2c[:3, :4] + new_affine_mats.append(affine_mat) + + camera_o = c2w[:3, 3] + dist = np.sqrt(np.sum(camera_o ** 2)) + near = dist - 1 + far = dist + 1 + + new_near_fars.append([0.95 * near, 1.05 * far]) + + new_depths_h.append(depth * scale_factor) + + # print(new_near_fars) + # print("img numeber: ", len(imgs)) + imgs = torch.stack(imgs).float() + depths_h = np.stack(new_depths_h) + masks_h = np.stack(masks_h) + + affine_mats = np.stack(new_affine_mats) + intrinsics, w2cs, c2ws, near_fars = np.stack(intrinsics), np.stack(new_w2cs), np.stack(new_c2ws), np.stack( + new_near_fars) + + if self.split == 'train': + start_idx = 0 + else: + start_idx = 1 + + view_ids = [idx_original % self.imgs_per_instance] + src_views_used + sample['origin_idx'] = origin_idx + sample['images'] = imgs # (V, 3, H, W) + sample['depths_h'] = torch.from_numpy(depths_h.astype(np.float32)) # (V, H, W) + sample['masks_h'] = torch.from_numpy(masks_h.astype(np.float32)) # (V, H, W) + sample['w2cs'] = torch.from_numpy(w2cs.astype(np.float32)) # (V, 4, 4) + sample['c2ws'] = torch.from_numpy(c2ws.astype(np.float32)) # (V, 4, 4) + sample['near_fars'] = torch.from_numpy(near_fars.astype(np.float32)) # (V, 2) + sample['intrinsics'] = torch.from_numpy(intrinsics.astype(np.float32))[:, :3, :3] # (V, 3, 3) + sample['view_ids'] = torch.from_numpy(np.array(view_ids)) + sample['affine_mats'] = torch.from_numpy(affine_mats.astype(np.float32)) # ! in world space + + # sample['light_idx'] = torch.tensor(light_idx) + sample['scan'] = folder_id + + sample['scale_factor'] = torch.tensor(scale_factor) + sample['img_wh'] = torch.from_numpy(np.array(img_wh)) + sample['render_img_idx'] = torch.tensor(image_perm) + sample['partial_vol_origin'] = self.partial_vol_origin + if view_ids[0] < 8: + meta_end = "_narrow"+ "_refview" + str(view_ids[0]) + else: + meta_end = "_two_stage"+ "_refview" + str(view_ids[0] - 8) + sample['meta'] = str(folder_id) + "_" + str(uid) + "_refview" + meta_end + + + # - image to render + sample['query_image'] = sample['images'][0] + sample['query_c2w'] = sample['c2ws'][0] + sample['query_w2c'] = sample['w2cs'][0] + sample['query_intrinsic'] = sample['intrinsics'][0] + sample['query_depth'] = sample['depths_h'][0] + sample['query_mask'] = sample['masks_h'][0] + sample['query_near_far'] = sample['near_fars'][0] + + sample['images'] = sample['images'][start_idx:] # (V, 3, H, W) + sample['depths_h'] = sample['depths_h'][start_idx:] # (V, H, W) + sample['masks_h'] = sample['masks_h'][start_idx:] # (V, H, W) + sample['w2cs'] = sample['w2cs'][start_idx:] # (V, 4, 4) + sample['c2ws'] = sample['c2ws'][start_idx:] # (V, 4, 4) + sample['intrinsics'] = sample['intrinsics'][start_idx:] # (V, 3, 3) + sample['view_ids'] = sample['view_ids'][start_idx:] + sample['affine_mats'] = sample['affine_mats'][start_idx:] # ! in world space + + sample['scale_mat'] = torch.from_numpy(scale_mat) + sample['trans_mat'] = torch.from_numpy(w2c_ref_inv) + + # - generate rays + if ('val' in self.split) or ('test' in self.split): + sample_rays = gen_rays_from_single_image( + img_wh[1], img_wh[0], + sample['query_image'], + sample['query_intrinsic'], + sample['query_c2w'], + depth=sample['query_depth'], + mask=sample['query_mask'] if self.clean_image else None) + else: + sample_rays = gen_random_rays_from_single_image( + img_wh[1], img_wh[0], + self.N_rays, + sample['query_image'], + sample['query_intrinsic'], + sample['query_c2w'], + depth=sample['query_depth'], + mask=sample['query_mask'] if self.clean_image else None, + dilated_mask=mask_dilated, + importance_sample=self.importance_sample) + + + sample['rays'] = sample_rays + + return sample diff --git a/SparseNeuS_demo_v1/data/blender_general_4_narrow_and_6_2_stage_mix.py b/SparseNeuS_demo_v1/data/blender_general_4_narrow_and_6_2_stage_mix.py new file mode 100644 index 0000000000000000000000000000000000000000..e80567fe34ee51cb49355ee26ea8ce80dff706e6 --- /dev/null +++ b/SparseNeuS_demo_v1/data/blender_general_4_narrow_and_6_2_stage_mix.py @@ -0,0 +1,476 @@ +from torch.utils.data import Dataset +from utils.misc_utils import read_pfm +import os +import numpy as np +import cv2 +from PIL import Image +import torch +from torchvision import transforms as T +from data.scene import get_boundingbox + +from models.rays import gen_rays_from_single_image, gen_random_rays_from_single_image +import json +from termcolor import colored +import imageio +from kornia import create_meshgrid +import open3d as o3d +def get_ray_directions(H, W, focal, center=None): + """ + Get ray directions for all pixels in camera coordinate. + Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/ + ray-tracing-generating-camera-rays/standard-coordinate-systems + Inputs: + H, W, focal: image height, width and focal length + Outputs: + directions: (H, W, 3), the direction of the rays in camera coordinate + """ + grid = create_meshgrid(H, W, normalized_coordinates=False)[0] + 0.5 # 1xHxWx2 + + i, j = grid.unbind(-1) + # the direction here is without +0.5 pixel centering as calibration is not so accurate + # see https://github.com/bmild/nerf/issues/24 + cent = center if center is not None else [W / 2, H / 2] + directions = torch.stack([(i - cent[0]) / focal[0], (j - cent[1]) / focal[1], torch.ones_like(i)], -1) # (H, W, 3) + + return directions + +def load_K_Rt_from_P(filename, P=None): + if P is None: + lines = open(filename).read().splitlines() + if len(lines) == 4: + lines = lines[1:] + lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)] + P = np.asarray(lines).astype(np.float32).squeeze() + + out = cv2.decomposeProjectionMatrix(P) + K = out[0] + R = out[1] + t = out[2] + + K = K / K[2, 2] + intrinsics = np.eye(4) + intrinsics[:3, :3] = K + + pose = np.eye(4, dtype=np.float32) + pose[:3, :3] = R.transpose() # ? why need transpose here + pose[:3, 3] = (t[:3] / t[3])[:, 0] + + return intrinsics, pose # ! return cam2world matrix here + + +# ! load one ref-image with multiple src-images in camera coordinate system +class BlenderPerView(Dataset): + def __init__(self, root_dir, split, n_views=3, img_wh=(256, 256), downSample=1.0, + split_filepath=None, pair_filepath=None, + N_rays=512, + vol_dims=[128, 128, 128], batch_size=1, + clean_image=False, importance_sample=False, test_ref_views=[]): + + self.root_dir = root_dir + self.split = split + + self.n_views = n_views + self.N_rays = N_rays + self.batch_size = batch_size # - used for construct new metas for gru fusion training + + self.clean_image = clean_image + self.importance_sample = importance_sample + self.test_ref_views = test_ref_views # used for testing + self.scale_factor = 1.0 + self.scale_mat = np.float32(np.diag([1, 1, 1, 1.0])) + + lvis_json_path = '/objaverse-processed/zero12345_img/lvis_split.json' # folder_id and uid + with open(lvis_json_path, 'r') as f: + lvis_paths = json.load(f) + if self.split == 'train': + self.lvis_paths = lvis_paths['train'] + else: + self.lvis_paths = lvis_paths['val'] + if img_wh is not None: + assert img_wh[0] % 32 == 0 and img_wh[1] % 32 == 0, \ + 'img_wh must both be multiples of 32!' + + + pose_json_path_narrow = "/objaverse-processed/zero12345_img/zero12345_narrow_pose.json" + with open(pose_json_path_narrow, 'r') as f: + narrow_meta = json.load(f) + + pose_json_path_two_stage = "/objaverse-processed/zero12345_img/zero12345_2stage_5pred_pose.json" + with open(pose_json_path_two_stage, 'r') as f: + two_stage_meta = json.load(f) + + + self.img_ids = list(narrow_meta["c2ws"].keys()) + list(two_stage_meta["c2ws"].keys()) # (8 + 8*4) + (6 + 6*4) + self.img_wh = (256, 256) + self.input_poses = np.array(list(narrow_meta["c2ws"].values()) + list(two_stage_meta["c2ws"].values())) + intrinsic = np.eye(4) + assert narrow_meta["intrinsics"] == two_stage_meta["intrinsics"], "intrinsics not equal" + intrinsic[:3, :3] = np.array(narrow_meta["intrinsics"]) + self.intrinsic = intrinsic + assert narrow_meta["near_far"] == two_stage_meta["near_far"], "near_far not equal" + self.near_far = np.array(narrow_meta["near_far"]) + self.near_far[1] = 1.8 + self.define_transforms() + self.blender2opencv = np.array( + [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]] + ) + + + self.c2ws = [] + self.w2cs = [] + self.near_fars = [] + for idx, img_id in enumerate(self.img_ids): + pose = self.input_poses[idx] + c2w = pose @ self.blender2opencv + self.c2ws.append(c2w) + self.w2cs.append(np.linalg.inv(c2w)) + self.near_fars.append(self.near_far) + + + + self.c2ws = np.stack(self.c2ws, axis=0) + self.w2cs = np.stack(self.w2cs, axis=0) + + + self.all_intrinsics = [] # the cam info of the whole scene + self.all_extrinsics = [] + self.all_near_fars = [] + self.load_cam_info() + + # * bounding box for rendering + self.bbox_min = np.array([-1.0, -1.0, -1.0]) + self.bbox_max = np.array([1.0, 1.0, 1.0]) + + # - used for cost volume regularization + self.voxel_dims = torch.tensor(vol_dims, dtype=torch.float32) + self.partial_vol_origin = torch.tensor([-1., -1., -1.], dtype=torch.float32) + + + def define_transforms(self): + self.transform = T.Compose([T.ToTensor()]) + + + + def load_cam_info(self): + for vid, img_id in enumerate(self.img_ids): + intrinsic, extrinsic, near_far = self.intrinsic, np.linalg.inv(self.c2ws[vid]), self.near_far + self.all_intrinsics.append(intrinsic) + self.all_extrinsics.append(extrinsic) + self.all_near_fars.append(near_far) + + def read_depth(self, filename): + pass + + def read_mask(self, filename): + mask_h = cv2.imread(filename, 0) + mask_h = cv2.resize(mask_h, None, fx=self.downSample, fy=self.downSample, + interpolation=cv2.INTER_NEAREST) + mask = cv2.resize(mask_h, None, fx=0.25, fy=0.25, + interpolation=cv2.INTER_NEAREST) + + mask[mask > 0] = 1 # the masks stored in png are not binary + mask_h[mask_h > 0] = 1 + + return mask, mask_h + + def cal_scale_mat(self, img_hw, intrinsics, extrinsics, near_fars, factor=1.): + + center, radius, bounds = get_boundingbox(img_hw, intrinsics, extrinsics, near_fars) + + radius = radius * factor + scale_mat = np.diag([radius, radius, radius, 1.0]) + scale_mat[:3, 3] = center.cpu().numpy() + scale_mat = scale_mat.astype(np.float32) + + return scale_mat, 1. / radius.cpu().numpy() + + def __len__(self): + return 12*len(self.lvis_paths) + + + def read_depth(self, filename, near_bound, noisy_factor=1.0): + pass + + + def __getitem__(self, idx): + sample = {} + origin_idx = idx + imgs, depths_h, masks_h = [], [], [] # full size (256, 256) + intrinsics, w2cs, c2ws, near_fars = [], [], [], [] # record proj mats between views + idx_original=idx + + folder_uid_dict = self.lvis_paths[idx//12] + + folder_id = folder_uid_dict['folder_id'] + uid = folder_uid_dict['uid'] + + if idx % 12 < 8: + idx = idx % 12 # [0, 7] + # target view + c2w = self.c2ws[idx] + w2c = np.linalg.inv(c2w) + w2c_ref = w2c + w2c_ref_inv = np.linalg.inv(w2c_ref) + + w2cs.append(w2c @ w2c_ref_inv) + c2ws.append(np.linalg.inv(w2c @ w2c_ref_inv)) + + img_filename = os.path.join("/objaverse-processed/zero12345_img/zero12345_narrow/", folder_id, uid, f'view_{idx}.png') + + depth_filename = os.path.join(os.path.join("/objaverse-processed/zero12345_img/zero12345_narrow/", folder_id, uid, f'view_{idx}_depth_mm.png')) + + + img = Image.open(img_filename) + + img = self.transform(img) # (4, h, w) + + + if img.shape[0] == 4: + img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB + imgs += [img] + + depth_h = cv2.imread(depth_filename, cv2.IMREAD_UNCHANGED).astype(np.uint16) / 1000.0 + mask_h = depth_h > 0 + # print("valid pixels", np.sum(mask_h)) + directions = get_ray_directions(self.img_wh[1], self.img_wh[0], [self.intrinsic[0, 0], self.intrinsic[1, 1]]) # [H, W, 3] + surface_points = directions * depth_h[..., None] # [H, W, 3] + distance = np.linalg.norm(surface_points, axis=-1) # [H, W] + depth_h = distance + + + depths_h.append(depth_h) + masks_h.append(mask_h) + + intrinsic = self.intrinsic + intrinsics.append(intrinsic) + + + near_fars.append(self.near_fars[idx]) + image_perm = 0 # only supervised on reference view + + mask_dilated = None + + # src_views = range(8+idx*4, 8+(idx+1)*4) + + src_views = range(8, 8 + 8 * 4) + src_views_used = [] + for vid in src_views: + if (vid // 4) % 2 != idx % 2: + continue + src_views_used.append(vid) + img_filename = os.path.join("/objaverse-processed/zero12345_img/zero12345_narrow/", folder_id, uid, f'view_{(vid - 8) // 4}_{vid%4}_10.png') + + img = Image.open(img_filename) + img_wh = self.img_wh + + img = self.transform(img) + if img.shape[0] == 4: + img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB + + imgs += [img] + depth_h = np.ones(img.shape[1:], dtype=np.float32) + depths_h.append(depth_h) + masks_h.append(np.ones(img.shape[1:], dtype=np.int32)) + + near_fars.append(self.all_near_fars[vid]) + intrinsics.append(self.all_intrinsics[vid]) + + w2cs.append(self.all_extrinsics[vid] @ w2c_ref_inv) + + else: + idx = idx % 12 - 8 # [0, 5] + valid_list = [0, 2, 3, 5] + idx = valid_list[idx] # [0, 3] + c2w = self.c2ws[idx + 40] + w2c = np.linalg.inv(c2w) + w2c_ref = w2c + w2c_ref_inv = np.linalg.inv(w2c_ref) + + w2cs.append(w2c @ w2c_ref_inv) + c2ws.append(np.linalg.inv(w2c @ w2c_ref_inv)) + + img_filename = os.path.join("/objaverse-processed/zero12345_img/zero12345_2stage_5pred/", folder_id, uid, f'view_0_{idx}_0.png') + + + + img = Image.open(img_filename) + + img = self.transform(img) # (4, h, w) + + # print("img_pre", img.shape) + if img.shape[0] == 4: + img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB + # print("img", img.shape) + imgs += [img] + + + depth_h =torch.ones((img.shape[1], img.shape[2]), dtype=torch.float32) + depth_h = depth_h.fill_(-1.0) + # depth_h = torch.fill((img.shape[1], img.shape[2]), -1.0) + # print("depth_h", depth_h.shape) + mask_h = torch.ones((img.shape[1], img.shape[2]), dtype=torch.int32) + depths_h.append(depth_h) + masks_h.append(mask_h) + + intrinsic = self.intrinsic + intrinsics.append(intrinsic) + + + near_fars.append(self.near_fars[idx]) + image_perm = 0 # only supervised on reference view + + mask_dilated = None + + # src_views = range(gt_view_idx * 12, (gt_view_idx + 1) * 12) + + + src_views = range(40+6, 40+6+24) + src_views_used = [] + for vid in src_views: + view_dix_to_use = (vid - 40 - 6) // 4 + if view_dix_to_use not in valid_list: + continue + src_views_used.append(vid) + img_filename = os.path.join("/objaverse-processed/zero12345_img/zero12345_2stage_5pred/", folder_id, uid, f'view_0_{idx}_{(vid-46) % 4 + 1}.png') + + img = Image.open(img_filename) + img_wh = self.img_wh + + img = self.transform(img) + # print("img shape1: ", img.shape) + if img.shape[0] == 4: + img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB + # print("img shape2: ", img.shape) + imgs += [img] + depth_h =torch.ones((img.shape[1], img.shape[2]), dtype=torch.float32) + depth_h = depth_h.fill_(-1.0) + depths_h.append(depth_h) + masks_h.append(np.ones(img.shape[1:], dtype=np.int32)) + + near_fars.append(self.all_near_fars[vid]) + intrinsics.append(self.all_intrinsics[vid]) + + w2cs.append(self.all_extrinsics[vid] @ w2c_ref_inv) + + + scale_mat, scale_factor = self.cal_scale_mat( + img_hw=[img_wh[1], img_wh[0]], + intrinsics=intrinsics, extrinsics=w2cs, + near_fars=near_fars, factor=1.1 + ) + + + new_near_fars = [] + new_w2cs = [] + new_c2ws = [] + new_affine_mats = [] + new_depths_h = [] + for intrinsic, extrinsic, near_far, depth in zip(intrinsics, w2cs, near_fars, depths_h): + + P = intrinsic @ extrinsic @ scale_mat + P = P[:3, :4] + # - should use load_K_Rt_from_P() to obtain c2w + c2w = load_K_Rt_from_P(None, P)[1] + w2c = np.linalg.inv(c2w) + new_w2cs.append(w2c) + new_c2ws.append(c2w) + affine_mat = np.eye(4) + affine_mat[:3, :4] = intrinsic[:3, :3] @ w2c[:3, :4] + new_affine_mats.append(affine_mat) + + camera_o = c2w[:3, 3] + dist = np.sqrt(np.sum(camera_o ** 2)) + near = dist - 1 + far = dist + 1 + + new_near_fars.append([0.95 * near, 1.05 * far]) + + new_depths_h.append(depth * scale_factor) + + # print(new_near_fars) + # print("img numeber: ", len(imgs)) + imgs = torch.stack(imgs).float() + depths_h = np.stack(new_depths_h) + masks_h = np.stack(masks_h) + + affine_mats = np.stack(new_affine_mats) + intrinsics, w2cs, c2ws, near_fars = np.stack(intrinsics), np.stack(new_w2cs), np.stack(new_c2ws), np.stack( + new_near_fars) + + if self.split == 'train': + start_idx = 0 + else: + start_idx = 1 + + view_ids = [idx_original % 12] + src_views_used + sample['origin_idx'] = origin_idx + sample['images'] = imgs # (V, 3, H, W) + sample['depths_h'] = torch.from_numpy(depths_h.astype(np.float32)) # (V, H, W) + sample['masks_h'] = torch.from_numpy(masks_h.astype(np.float32)) # (V, H, W) + sample['w2cs'] = torch.from_numpy(w2cs.astype(np.float32)) # (V, 4, 4) + sample['c2ws'] = torch.from_numpy(c2ws.astype(np.float32)) # (V, 4, 4) + sample['near_fars'] = torch.from_numpy(near_fars.astype(np.float32)) # (V, 2) + sample['intrinsics'] = torch.from_numpy(intrinsics.astype(np.float32))[:, :3, :3] # (V, 3, 3) + sample['view_ids'] = torch.from_numpy(np.array(view_ids)) + sample['affine_mats'] = torch.from_numpy(affine_mats.astype(np.float32)) # ! in world space + + # sample['light_idx'] = torch.tensor(light_idx) + sample['scan'] = folder_id + + sample['scale_factor'] = torch.tensor(scale_factor) + sample['img_wh'] = torch.from_numpy(np.array(img_wh)) + sample['render_img_idx'] = torch.tensor(image_perm) + sample['partial_vol_origin'] = self.partial_vol_origin + if view_ids[0] < 8: + meta_end = "_narrow"+ "_refview" + str(view_ids[0]) + else: + meta_end = "_two_stage"+ "_refview" + str(view_ids[0] - 8) + sample['meta'] = str(folder_id) + "_" + str(uid) + "_refview" + meta_end + + + # - image to render + sample['query_image'] = sample['images'][0] + sample['query_c2w'] = sample['c2ws'][0] + sample['query_w2c'] = sample['w2cs'][0] + sample['query_intrinsic'] = sample['intrinsics'][0] + sample['query_depth'] = sample['depths_h'][0] + sample['query_mask'] = sample['masks_h'][0] + sample['query_near_far'] = sample['near_fars'][0] + + sample['images'] = sample['images'][start_idx:] # (V, 3, H, W) + sample['depths_h'] = sample['depths_h'][start_idx:] # (V, H, W) + sample['masks_h'] = sample['masks_h'][start_idx:] # (V, H, W) + sample['w2cs'] = sample['w2cs'][start_idx:] # (V, 4, 4) + sample['c2ws'] = sample['c2ws'][start_idx:] # (V, 4, 4) + sample['intrinsics'] = sample['intrinsics'][start_idx:] # (V, 3, 3) + sample['view_ids'] = sample['view_ids'][start_idx:] + sample['affine_mats'] = sample['affine_mats'][start_idx:] # ! in world space + + sample['scale_mat'] = torch.from_numpy(scale_mat) + sample['trans_mat'] = torch.from_numpy(w2c_ref_inv) + + # - generate rays + if ('val' in self.split) or ('test' in self.split): + sample_rays = gen_rays_from_single_image( + img_wh[1], img_wh[0], + sample['query_image'], + sample['query_intrinsic'], + sample['query_c2w'], + depth=sample['query_depth'], + mask=sample['query_mask'] if self.clean_image else None) + else: + sample_rays = gen_random_rays_from_single_image( + img_wh[1], img_wh[0], + self.N_rays, + sample['query_image'], + sample['query_intrinsic'], + sample['query_c2w'], + depth=sample['query_depth'], + mask=sample['query_mask'] if self.clean_image else None, + dilated_mask=mask_dilated, + importance_sample=self.importance_sample) + + + sample['rays'] = sample_rays + + return sample diff --git a/SparseNeuS_demo_v1/data/blender_general_6_narrow_and_6_2_stage_blend_mix.py b/SparseNeuS_demo_v1/data/blender_general_6_narrow_and_6_2_stage_blend_mix.py new file mode 100644 index 0000000000000000000000000000000000000000..248e9f9591b95a711406b0e1efb3568e05e2414a --- /dev/null +++ b/SparseNeuS_demo_v1/data/blender_general_6_narrow_and_6_2_stage_blend_mix.py @@ -0,0 +1,449 @@ +from torch.utils.data import Dataset +from utils.misc_utils import read_pfm +import os +import numpy as np +import cv2 +from PIL import Image +import torch +from torchvision import transforms as T +from data.scene import get_boundingbox + +from models.rays import gen_rays_from_single_image, gen_random_rays_from_single_image +import json +from termcolor import colored +import imageio +from kornia import create_meshgrid +import open3d as o3d +def get_ray_directions(H, W, focal, center=None): + """ + Get ray directions for all pixels in camera coordinate. + Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/ + ray-tracing-generating-camera-rays/standard-coordinate-systems + Inputs: + H, W, focal: image height, width and focal length + Outputs: + directions: (H, W, 3), the direction of the rays in camera coordinate + """ + grid = create_meshgrid(H, W, normalized_coordinates=False)[0] + 0.5 # 1xHxWx2 + + i, j = grid.unbind(-1) + # the direction here is without +0.5 pixel centering as calibration is not so accurate + # see https://github.com/bmild/nerf/issues/24 + cent = center if center is not None else [W / 2, H / 2] + directions = torch.stack([(i - cent[0]) / focal[0], (j - cent[1]) / focal[1], torch.ones_like(i)], -1) # (H, W, 3) + + return directions + +def load_K_Rt_from_P(filename, P=None): + if P is None: + lines = open(filename).read().splitlines() + if len(lines) == 4: + lines = lines[1:] + lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)] + P = np.asarray(lines).astype(np.float32).squeeze() + + out = cv2.decomposeProjectionMatrix(P) + K = out[0] + R = out[1] + t = out[2] + + K = K / K[2, 2] + intrinsics = np.eye(4) + intrinsics[:3, :3] = K + + pose = np.eye(4, dtype=np.float32) + pose[:3, :3] = R.transpose() # ? why need transpose here + pose[:3, 3] = (t[:3] / t[3])[:, 0] + + return intrinsics, pose # ! return cam2world matrix here + + +# ! load one ref-image with multiple src-images in camera coordinate system +class BlenderPerView(Dataset): + def __init__(self, root_dir, split, n_views=3, img_wh=(256, 256), downSample=1.0, + split_filepath=None, pair_filepath=None, + N_rays=512, + vol_dims=[128, 128, 128], batch_size=1, + clean_image=False, importance_sample=False, test_ref_views=[]): + + self.root_dir = root_dir + self.split = split + if self.split == 'train': + self.imgs_per_instance = 12 + else: + self.imgs_per_instance = 16 + self.n_views = n_views + self.N_rays = N_rays + self.batch_size = batch_size # - used for construct new metas for gru fusion training + + self.clean_image = clean_image + self.importance_sample = importance_sample + self.test_ref_views = test_ref_views # used for testing + self.scale_factor = 1.0 + self.scale_mat = np.float32(np.diag([1, 1, 1, 1.0])) + + lvis_json_path = '/objaverse-processed/zero12345_img/lvis_split.json' # folder_id and uid + with open(lvis_json_path, 'r') as f: + lvis_paths = json.load(f) + if self.split == 'train': + self.lvis_paths = lvis_paths['train'] + else: + self.lvis_paths = lvis_paths['val'] + if img_wh is not None: + assert img_wh[0] % 32 == 0 and img_wh[1] % 32 == 0, \ + 'img_wh must both be multiples of 32!' + + + pose_json_path_narrow = "/objaverse-processed/zero12345_img/zero12345_narrow_pose.json" + with open(pose_json_path_narrow, 'r') as f: + narrow_meta = json.load(f) + + pose_json_path_two_stage = "/objaverse-processed/zero12345_img/zero12345_2stage_8_pose.json" + with open(pose_json_path_two_stage, 'r') as f: + two_stage_meta = json.load(f) + + + self.img_ids = list(narrow_meta["c2ws"].keys()) + list(two_stage_meta["c2ws"].keys()) # (8 + 8*4) + (8 + 4*4) + self.img_wh = (256, 256) + self.input_poses = np.array(list(narrow_meta["c2ws"].values()) + list(two_stage_meta["c2ws"].values())) + intrinsic = np.eye(4) + assert narrow_meta["intrinsics"] == two_stage_meta["intrinsics"], "intrinsics not equal" + intrinsic[:3, :3] = np.array(narrow_meta["intrinsics"]) + self.intrinsic = intrinsic + assert narrow_meta["near_far"] == two_stage_meta["near_far"], "near_far not equal" + self.near_far = np.array(narrow_meta["near_far"]) + self.near_far[1] = 1.8 + self.define_transforms() + self.blender2opencv = np.array( + [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]] + ) + + + self.c2ws = [] + self.w2cs = [] + self.near_fars = [] + for idx, img_id in enumerate(self.img_ids): + pose = self.input_poses[idx] + c2w = pose @ self.blender2opencv + self.c2ws.append(c2w) + self.w2cs.append(np.linalg.inv(c2w)) + self.near_fars.append(self.near_far) + + + + self.c2ws = np.stack(self.c2ws, axis=0) + self.w2cs = np.stack(self.w2cs, axis=0) + + + self.all_intrinsics = [] # the cam info of the whole scene + self.all_extrinsics = [] + self.all_near_fars = [] + self.load_cam_info() + + # * bounding box for rendering + self.bbox_min = np.array([-1.0, -1.0, -1.0]) + self.bbox_max = np.array([1.0, 1.0, 1.0]) + + # - used for cost volume regularization + self.voxel_dims = torch.tensor(vol_dims, dtype=torch.float32) + self.partial_vol_origin = torch.tensor([-1., -1., -1.], dtype=torch.float32) + + + def define_transforms(self): + self.transform = T.Compose([T.ToTensor()]) + + + + def load_cam_info(self): + for vid, img_id in enumerate(self.img_ids): + intrinsic, extrinsic, near_far = self.intrinsic, np.linalg.inv(self.c2ws[vid]), self.near_far + self.all_intrinsics.append(intrinsic) + self.all_extrinsics.append(extrinsic) + self.all_near_fars.append(near_far) + + def read_depth(self, filename): + pass + + def read_mask(self, filename): + mask_h = cv2.imread(filename, 0) + mask_h = cv2.resize(mask_h, None, fx=self.downSample, fy=self.downSample, + interpolation=cv2.INTER_NEAREST) + mask = cv2.resize(mask_h, None, fx=0.25, fy=0.25, + interpolation=cv2.INTER_NEAREST) + + mask[mask > 0] = 1 # the masks stored in png are not binary + mask_h[mask_h > 0] = 1 + + return mask, mask_h + + def cal_scale_mat(self, img_hw, intrinsics, extrinsics, near_fars, factor=1.): + + center, radius, bounds = get_boundingbox(img_hw, intrinsics, extrinsics, near_fars) + + radius = radius * factor + scale_mat = np.diag([radius, radius, radius, 1.0]) + scale_mat[:3, 3] = center.cpu().numpy() + scale_mat = scale_mat.astype(np.float32) + + return scale_mat, 1. / radius.cpu().numpy() + + def __len__(self): + return self.imgs_per_instance*len(self.lvis_paths) + + + def read_depth(self, filename, near_bound, noisy_factor=1.0): + pass + + + def __getitem__(self, idx): + sample = {} + origin_idx = idx + imgs, depths_h, masks_h = [], [], [] # full size (256, 256) + intrinsics, w2cs, c2ws, near_fars = [], [], [], [] # record proj mats between views + idx_original=idx + + folder_uid_dict = self.lvis_paths[idx//self.imgs_per_instance] + + folder_id = folder_uid_dict['folder_id'] + uid = folder_uid_dict['uid'] + + if self.split == 'train': + if idx == 4: + idx = 5 + elif idx == 5: + idx = 7 + elif idx == 10: + idx = 13 + elif idx == 11: + idx = 15 + + if idx % 16 < 8: # narrow image as target + idx = idx % 16 # [0, 7] + # target view + c2w = self.c2ws[idx] + w2c = np.linalg.inv(c2w) + w2c_ref = w2c + w2c_ref_inv = np.linalg.inv(w2c_ref) + + w2cs.append(w2c @ w2c_ref_inv) + c2ws.append(np.linalg.inv(w2c @ w2c_ref_inv)) + + img_filename = os.path.join("/objaverse-processed/zero12345_img/zero12345_narrow/", folder_id, uid, f'view_{idx}.png') + + depth_filename = os.path.join(os.path.join("/objaverse-processed/zero12345_img/zero12345_narrow/", folder_id, uid, f'view_{idx}_depth_mm.png')) + + + img = Image.open(img_filename) + + img = self.transform(img) # (4, h, w) + + + if img.shape[0] == 4: + img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB + imgs += [img] + + depth_h = cv2.imread(depth_filename, cv2.IMREAD_UNCHANGED).astype(np.uint16) / 1000.0 + mask_h = depth_h > 0 + # print("valid pixels", np.sum(mask_h)) + directions = get_ray_directions(self.img_wh[1], self.img_wh[0], [self.intrinsic[0, 0], self.intrinsic[1, 1]]) # [H, W, 3] + surface_points = directions * depth_h[..., None] # [H, W, 3] + distance = np.linalg.norm(surface_points, axis=-1) # [H, W] + depth_h = distance + + else: + idx = idx % 16 - 8 # [0, 5] + c2w = self.c2ws[idx + 40] + w2c = np.linalg.inv(c2w) + w2c_ref = w2c + w2c_ref_inv = np.linalg.inv(w2c_ref) + + w2cs.append(w2c @ w2c_ref_inv) + c2ws.append(np.linalg.inv(w2c @ w2c_ref_inv)) + + img_filename = os.path.join("/objaverse-processed/zero12345_img/zero12345_2stage_8/", folder_id, uid, f'view_0_{idx}_0.png') + + img = Image.open(img_filename) + img = self.transform(img) # (4, h, w) + + # print("img_pre", img.shape) + if img.shape[0] == 4: + img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB + # print("img", img.shape) + imgs += [img] + + depth_h =torch.ones((img.shape[1], img.shape[2]), dtype=torch.float32) + depth_h = depth_h.fill_(-1.0) + + mask_h = torch.ones((img.shape[1], img.shape[2]), dtype=torch.int32) + depths_h.append(depth_h) + masks_h.append(mask_h) + + intrinsic = self.intrinsic + intrinsics.append(intrinsic) + + + near_fars.append(self.near_fars[idx]) + image_perm = 0 # only supervised on reference view + + mask_dilated = None + if_use_narrow = [] + if self.split == 'train': + for i in range(8): + if np.random.random() > 0.5: + if_use_narrow.append(True) # use narrow + else: + if_use_narrow.append(False) # 2-stage prediction + if_use_narrow[origin_idx % 8] = True if origin_idx < 8 else False + else: + for i in range(8): + if_use_narrow.append( True if origin_idx < 8 else False) + src_views = range(8, 8 + 8 * 4) + src_views_used = [] + for vid in src_views: + if ((vid - 8) // 4 == 4) or ((vid - 8) // 4 == 6): + continue + src_views_used.append(vid) + cur_view_id = (vid - 8) // 4 + # choose narrow + if if_use_narrow[cur_view_id]: + img_filename = os.path.join("/objaverse-processed/zero12345_img/zero12345_narrow/", folder_id, uid, f'view_{cur_view_id}_{vid%4}_10.png') + else: # choose 2-stage + img_filename = os.path.join("/objaverse-processed/zero12345_img/zero12345_2stage_8/", folder_id, uid, f'view_0_{(vid - 8) // 4}_{(vid-8) % 4 + 1}.png') + + img = Image.open(img_filename) + img_wh = self.img_wh + + img = self.transform(img) + if img.shape[0] == 4: + img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB + + imgs += [img] + depth_h = np.ones(img.shape[1:], dtype=np.float32) + depths_h.append(depth_h) + masks_h.append(np.ones(img.shape[1:], dtype=np.int32)) + + near_fars.append(self.all_near_fars[vid]) + intrinsics.append(self.all_intrinsics[vid]) + + w2cs.append(self.all_extrinsics[vid] @ w2c_ref_inv) + + + scale_mat, scale_factor = self.cal_scale_mat( + img_hw=[img_wh[1], img_wh[0]], + intrinsics=intrinsics, extrinsics=w2cs, + near_fars=near_fars, factor=1.1 + ) + + + new_near_fars = [] + new_w2cs = [] + new_c2ws = [] + new_affine_mats = [] + new_depths_h = [] + for intrinsic, extrinsic, near_far, depth in zip(intrinsics, w2cs, near_fars, depths_h): + + P = intrinsic @ extrinsic @ scale_mat + P = P[:3, :4] + # - should use load_K_Rt_from_P() to obtain c2w + c2w = load_K_Rt_from_P(None, P)[1] + w2c = np.linalg.inv(c2w) + new_w2cs.append(w2c) + new_c2ws.append(c2w) + affine_mat = np.eye(4) + affine_mat[:3, :4] = intrinsic[:3, :3] @ w2c[:3, :4] + new_affine_mats.append(affine_mat) + + camera_o = c2w[:3, 3] + dist = np.sqrt(np.sum(camera_o ** 2)) + near = dist - 1 + far = dist + 1 + + new_near_fars.append([0.95 * near, 1.05 * far]) + new_depths_h.append(depth * scale_factor) + + + imgs = torch.stack(imgs).float() + depths_h = np.stack(new_depths_h) + masks_h = np.stack(masks_h) + + affine_mats = np.stack(new_affine_mats) + intrinsics, w2cs, c2ws, near_fars = np.stack(intrinsics), np.stack(new_w2cs), np.stack(new_c2ws), np.stack( + new_near_fars) + + if self.split == 'train': + start_idx = 0 + else: + start_idx = 1 + + view_ids = [idx_original % self.imgs_per_instance] + src_views_used + sample['origin_idx'] = origin_idx + sample['images'] = imgs # (V, 3, H, W) + sample['depths_h'] = torch.from_numpy(depths_h.astype(np.float32)) # (V, H, W) + sample['masks_h'] = torch.from_numpy(masks_h.astype(np.float32)) # (V, H, W) + sample['w2cs'] = torch.from_numpy(w2cs.astype(np.float32)) # (V, 4, 4) + sample['c2ws'] = torch.from_numpy(c2ws.astype(np.float32)) # (V, 4, 4) + sample['near_fars'] = torch.from_numpy(near_fars.astype(np.float32)) # (V, 2) + sample['intrinsics'] = torch.from_numpy(intrinsics.astype(np.float32))[:, :3, :3] # (V, 3, 3) + sample['view_ids'] = torch.from_numpy(np.array(view_ids)) + sample['affine_mats'] = torch.from_numpy(affine_mats.astype(np.float32)) # ! in world space + + # sample['light_idx'] = torch.tensor(light_idx) + sample['scan'] = folder_id + + sample['scale_factor'] = torch.tensor(scale_factor) + sample['img_wh'] = torch.from_numpy(np.array(img_wh)) + sample['render_img_idx'] = torch.tensor(image_perm) + sample['partial_vol_origin'] = self.partial_vol_origin + if view_ids[0] < 8: + meta_end = "_narrow"+ "_refview" + str(view_ids[0]) + else: + meta_end = "_two_stage"+ "_refview" + str(view_ids[0] - 8) + sample['meta'] = str(folder_id) + "_" + str(uid) + "_refview" + meta_end + + + # - image to render + sample['query_image'] = sample['images'][0] + sample['query_c2w'] = sample['c2ws'][0] + sample['query_w2c'] = sample['w2cs'][0] + sample['query_intrinsic'] = sample['intrinsics'][0] + sample['query_depth'] = sample['depths_h'][0] + sample['query_mask'] = sample['masks_h'][0] + sample['query_near_far'] = sample['near_fars'][0] + + sample['images'] = sample['images'][start_idx:] # (V, 3, H, W) + sample['depths_h'] = sample['depths_h'][start_idx:] # (V, H, W) + sample['masks_h'] = sample['masks_h'][start_idx:] # (V, H, W) + sample['w2cs'] = sample['w2cs'][start_idx:] # (V, 4, 4) + sample['c2ws'] = sample['c2ws'][start_idx:] # (V, 4, 4) + sample['intrinsics'] = sample['intrinsics'][start_idx:] # (V, 3, 3) + sample['view_ids'] = sample['view_ids'][start_idx:] + sample['affine_mats'] = sample['affine_mats'][start_idx:] # ! in world space + + sample['scale_mat'] = torch.from_numpy(scale_mat) + sample['trans_mat'] = torch.from_numpy(w2c_ref_inv) + + # - generate rays + if ('val' in self.split) or ('test' in self.split): + sample_rays = gen_rays_from_single_image( + img_wh[1], img_wh[0], + sample['query_image'], + sample['query_intrinsic'], + sample['query_c2w'], + depth=sample['query_depth'], + mask=sample['query_mask'] if self.clean_image else None) + else: + sample_rays = gen_random_rays_from_single_image( + img_wh[1], img_wh[0], + self.N_rays, + sample['query_image'], + sample['query_intrinsic'], + sample['query_c2w'], + depth=sample['query_depth'], + mask=sample['query_mask'] if self.clean_image else None, + dilated_mask=mask_dilated, + importance_sample=self.importance_sample) + + + sample['rays'] = sample_rays + + return sample diff --git a/SparseNeuS_demo_v1/data/blender_general_8_2_stage.py b/SparseNeuS_demo_v1/data/blender_general_8_2_stage.py new file mode 100644 index 0000000000000000000000000000000000000000..e1fd371e5fc7be9685b81efa3d607018b2a9bdb1 --- /dev/null +++ b/SparseNeuS_demo_v1/data/blender_general_8_2_stage.py @@ -0,0 +1,396 @@ +from torch.utils.data import Dataset +from utils.misc_utils import read_pfm +import os +import numpy as np +import cv2 +from PIL import Image +import torch +from torchvision import transforms as T +from data.scene import get_boundingbox + +from models.rays import gen_rays_from_single_image, gen_random_rays_from_single_image +import json +from termcolor import colored +import imageio +from kornia import create_meshgrid +import open3d as o3d +def get_ray_directions(H, W, focal, center=None): + """ + Get ray directions for all pixels in camera coordinate. + Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/ + ray-tracing-generating-camera-rays/standard-coordinate-systems + Inputs: + H, W, focal: image height, width and focal length + Outputs: + directions: (H, W, 3), the direction of the rays in camera coordinate + """ + grid = create_meshgrid(H, W, normalized_coordinates=False)[0] + 0.5 # 1xHxWx2 + + i, j = grid.unbind(-1) + # the direction here is without +0.5 pixel centering as calibration is not so accurate + # see https://github.com/bmild/nerf/issues/24 + cent = center if center is not None else [W / 2, H / 2] + directions = torch.stack([(i - cent[0]) / focal[0], (j - cent[1]) / focal[1], torch.ones_like(i)], -1) # (H, W, 3) + + return directions + +def load_K_Rt_from_P(filename, P=None): + if P is None: + lines = open(filename).read().splitlines() + if len(lines) == 4: + lines = lines[1:] + lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)] + P = np.asarray(lines).astype(np.float32).squeeze() + + out = cv2.decomposeProjectionMatrix(P) + K = out[0] + R = out[1] + t = out[2] + + K = K / K[2, 2] + intrinsics = np.eye(4) + intrinsics[:3, :3] = K + + pose = np.eye(4, dtype=np.float32) + pose[:3, :3] = R.transpose() # ? why need transpose here + pose[:3, 3] = (t[:3] / t[3])[:, 0] + + return intrinsics, pose # ! return cam2world matrix here + + +# ! load one ref-image with multiple src-images in camera coordinate system +class BlenderPerView(Dataset): + def __init__(self, root_dir, split, n_views=3, img_wh=(256, 256), downSample=1.0, + split_filepath=None, pair_filepath=None, + N_rays=512, + vol_dims=[128, 128, 128], batch_size=1, + clean_image=False, importance_sample=False, test_ref_views=[]): + + self.root_dir = root_dir + self.split = split + + self.imgs_per_instance = 8 + + self.n_views = n_views + self.N_rays = N_rays + self.batch_size = batch_size # - used for construct new metas for gru fusion training + + self.clean_image = clean_image + self.importance_sample = importance_sample + self.test_ref_views = test_ref_views # used for testing + self.scale_factor = 1.0 + self.scale_mat = np.float32(np.diag([1, 1, 1, 1.0])) + + lvis_json_path = '/objaverse-processed/zero12345_img/lvis_split.json' # folder_id and uid + with open(lvis_json_path, 'r') as f: + lvis_paths = json.load(f) + if self.split == 'train': + self.lvis_paths = lvis_paths['train'] + else: + self.lvis_paths = lvis_paths['val'] + if img_wh is not None: + assert img_wh[0] % 32 == 0 and img_wh[1] % 32 == 0, \ + 'img_wh must both be multiples of 32!' + + + pose_json_path_narrow = "/objaverse-processed/zero12345_img/zero12345_narrow_pose.json" + with open(pose_json_path_narrow, 'r') as f: + narrow_meta = json.load(f) + + pose_json_path_two_stage = "/objaverse-processed/zero12345_img/zero12345_2stage_8_pose.json" + with open(pose_json_path_two_stage, 'r') as f: + two_stage_meta = json.load(f) + + + self.img_ids = list(narrow_meta["c2ws"].keys()) + list(two_stage_meta["c2ws"].keys()) # (8 + 8*4) + (8 + 8*4) + self.img_wh = (256, 256) + self.input_poses = np.array(list(narrow_meta["c2ws"].values()) + list(two_stage_meta["c2ws"].values())) + intrinsic = np.eye(4) + assert narrow_meta["intrinsics"] == two_stage_meta["intrinsics"], "intrinsics not equal" + intrinsic[:3, :3] = np.array(narrow_meta["intrinsics"]) + self.intrinsic = intrinsic + assert narrow_meta["near_far"] == two_stage_meta["near_far"], "near_far not equal" + self.near_far = np.array(narrow_meta["near_far"]) + self.near_far[1] = 1.8 + self.define_transforms() + self.blender2opencv = np.array( + [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]] + ) + + + self.c2ws = [] + self.w2cs = [] + self.near_fars = [] + for idx, img_id in enumerate(self.img_ids): + pose = self.input_poses[idx] + c2w = pose @ self.blender2opencv + self.c2ws.append(c2w) + self.w2cs.append(np.linalg.inv(c2w)) + self.near_fars.append(self.near_far) + + + + self.c2ws = np.stack(self.c2ws, axis=0) + self.w2cs = np.stack(self.w2cs, axis=0) + + + self.all_intrinsics = [] # the cam info of the whole scene + self.all_extrinsics = [] + self.all_near_fars = [] + self.load_cam_info() + + # * bounding box for rendering + self.bbox_min = np.array([-1.0, -1.0, -1.0]) + self.bbox_max = np.array([1.0, 1.0, 1.0]) + + # - used for cost volume regularization + self.voxel_dims = torch.tensor(vol_dims, dtype=torch.float32) + self.partial_vol_origin = torch.tensor([-1., -1., -1.], dtype=torch.float32) + + + def define_transforms(self): + self.transform = T.Compose([T.ToTensor()]) + + + + def load_cam_info(self): + for vid, img_id in enumerate(self.img_ids): + intrinsic, extrinsic, near_far = self.intrinsic, np.linalg.inv(self.c2ws[vid]), self.near_far + self.all_intrinsics.append(intrinsic) + self.all_extrinsics.append(extrinsic) + self.all_near_fars.append(near_far) + + def read_depth(self, filename): + pass + + def read_mask(self, filename): + mask_h = cv2.imread(filename, 0) + mask_h = cv2.resize(mask_h, None, fx=self.downSample, fy=self.downSample, + interpolation=cv2.INTER_NEAREST) + mask = cv2.resize(mask_h, None, fx=0.25, fy=0.25, + interpolation=cv2.INTER_NEAREST) + + mask[mask > 0] = 1 # the masks stored in png are not binary + mask_h[mask_h > 0] = 1 + + return mask, mask_h + + def cal_scale_mat(self, img_hw, intrinsics, extrinsics, near_fars, factor=1.): + + center, radius, bounds = get_boundingbox(img_hw, intrinsics, extrinsics, near_fars) + + radius = radius * factor + scale_mat = np.diag([radius, radius, radius, 1.0]) + scale_mat[:3, 3] = center.cpu().numpy() + scale_mat = scale_mat.astype(np.float32) + + return scale_mat, 1. / radius.cpu().numpy() + + def __len__(self): + return self.imgs_per_instance * len(self.lvis_paths) + + + def read_depth(self, filename, near_bound, noisy_factor=1.0): + pass + + + def __getitem__(self, idx): + sample = {} + origin_idx = idx + imgs, depths_h, masks_h = [], [], [] # full size (256, 256) + intrinsics, w2cs, c2ws, near_fars = [], [], [], [] # record proj mats between views + idx_original=idx + + folder_uid_dict = self.lvis_paths[idx//self.imgs_per_instance] + + folder_id = folder_uid_dict['folder_id'] + uid = folder_uid_dict['uid'] + + idx = idx % self.imgs_per_instance # [0, 7] + # target view + c2w = self.c2ws[idx] + w2c = np.linalg.inv(c2w) + w2c_ref = w2c + w2c_ref_inv = np.linalg.inv(w2c_ref) + + w2cs.append(w2c @ w2c_ref_inv) + c2ws.append(np.linalg.inv(w2c @ w2c_ref_inv)) + + img_filename = os.path.join("/objaverse-processed/zero12345_img/zero12345_narrow/", folder_id, uid, f'view_{idx}.png') + + depth_filename = os.path.join(os.path.join("/objaverse-processed/zero12345_img/zero12345_narrow/", folder_id, uid, f'view_{idx}_depth_mm.png')) + + + img = Image.open(img_filename) + + img = self.transform(img) # (4, h, w) + + + if img.shape[0] == 4: + img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB + imgs += [img] + + depth_h = cv2.imread(depth_filename, cv2.IMREAD_UNCHANGED).astype(np.uint16) / 1000.0 + mask_h = depth_h > 0 + directions = get_ray_directions(self.img_wh[1], self.img_wh[0], [self.intrinsic[0, 0], self.intrinsic[1, 1]]) # [H, W, 3] + surface_points = directions * depth_h[..., None] # [H, W, 3] + distance = np.linalg.norm(surface_points, axis=-1) # [H, W] + depth_h = distance + + + depths_h.append(depth_h) + masks_h.append(mask_h) + + intrinsic = self.intrinsic + intrinsics.append(intrinsic) + + + near_fars.append(self.near_fars[idx]) + image_perm = 0 # only supervised on reference view + + mask_dilated = None + + + + src_views = range(8, 8+32) + src_views_used = [] + for vid in src_views: + view_dix_to_use = (vid - 8) // 4 + src_views_used.append(vid) + img_filename = os.path.join("/objaverse-processed/zero12345_img/zero12345_2stage_8/", folder_id, uid, f'view_0_{idx}_{(vid-8) % 4 + 1}.png') + + img = Image.open(img_filename) + img_wh = self.img_wh + + img = self.transform(img) + if img.shape[0] == 4: + img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB + imgs += [img] + depth_h =torch.ones((img.shape[1], img.shape[2]), dtype=torch.float32) + depth_h = depth_h.fill_(-1.0) + depths_h.append(depth_h) + masks_h.append(np.ones(img.shape[1:], dtype=np.int32)) + + near_fars.append(self.all_near_fars[vid]) + intrinsics.append(self.all_intrinsics[vid]) + + w2cs.append(self.all_extrinsics[vid] @ w2c_ref_inv) + + + scale_mat, scale_factor = self.cal_scale_mat( + img_hw=[img_wh[1], img_wh[0]], + intrinsics=intrinsics, extrinsics=w2cs, + near_fars=near_fars, factor=1.1 + ) + + + new_near_fars = [] + new_w2cs = [] + new_c2ws = [] + new_affine_mats = [] + new_depths_h = [] + for intrinsic, extrinsic, near_far, depth in zip(intrinsics, w2cs, near_fars, depths_h): + + P = intrinsic @ extrinsic @ scale_mat + P = P[:3, :4] + # - should use load_K_Rt_from_P() to obtain c2w + c2w = load_K_Rt_from_P(None, P)[1] + w2c = np.linalg.inv(c2w) + new_w2cs.append(w2c) + new_c2ws.append(c2w) + affine_mat = np.eye(4) + affine_mat[:3, :4] = intrinsic[:3, :3] @ w2c[:3, :4] + new_affine_mats.append(affine_mat) + + camera_o = c2w[:3, 3] + dist = np.sqrt(np.sum(camera_o ** 2)) + near = dist - 1 + far = dist + 1 + + new_near_fars.append([0.95 * near, 1.05 * far]) + + new_depths_h.append(depth * scale_factor) + + + imgs = torch.stack(imgs).float() + depths_h = np.stack(new_depths_h) + masks_h = np.stack(masks_h) + + affine_mats = np.stack(new_affine_mats) + intrinsics, w2cs, c2ws, near_fars = np.stack(intrinsics), np.stack(new_w2cs), np.stack(new_c2ws), np.stack( + new_near_fars) + + if self.split == 'train': + start_idx = 0 + else: + start_idx = 1 + + view_ids = [idx_original % self.imgs_per_instance] + src_views_used + sample['origin_idx'] = origin_idx + sample['images'] = imgs # (V, 3, H, W) + sample['depths_h'] = torch.from_numpy(depths_h.astype(np.float32)) # (V, H, W) + sample['masks_h'] = torch.from_numpy(masks_h.astype(np.float32)) # (V, H, W) + sample['w2cs'] = torch.from_numpy(w2cs.astype(np.float32)) # (V, 4, 4) + sample['c2ws'] = torch.from_numpy(c2ws.astype(np.float32)) # (V, 4, 4) + sample['near_fars'] = torch.from_numpy(near_fars.astype(np.float32)) # (V, 2) + sample['intrinsics'] = torch.from_numpy(intrinsics.astype(np.float32))[:, :3, :3] # (V, 3, 3) + sample['view_ids'] = torch.from_numpy(np.array(view_ids)) + sample['affine_mats'] = torch.from_numpy(affine_mats.astype(np.float32)) # ! in world space + + # sample['light_idx'] = torch.tensor(light_idx) + sample['scan'] = folder_id + + sample['scale_factor'] = torch.tensor(scale_factor) + sample['img_wh'] = torch.from_numpy(np.array(img_wh)) + sample['render_img_idx'] = torch.tensor(image_perm) + sample['partial_vol_origin'] = self.partial_vol_origin + meta_end = "_two_stage"+ "_refview" + str(view_ids[0] - 8) + sample['meta'] = str(folder_id) + "_" + str(uid) + "_refview" + meta_end + + + # - image to render + sample['query_image'] = sample['images'][0] + sample['query_c2w'] = sample['c2ws'][0] + sample['query_w2c'] = sample['w2cs'][0] + sample['query_intrinsic'] = sample['intrinsics'][0] + sample['query_depth'] = sample['depths_h'][0] + sample['query_mask'] = sample['masks_h'][0] + sample['query_near_far'] = sample['near_fars'][0] + + sample['images'] = sample['images'][start_idx:] # (V, 3, H, W) + sample['depths_h'] = sample['depths_h'][start_idx:] # (V, H, W) + sample['masks_h'] = sample['masks_h'][start_idx:] # (V, H, W) + sample['w2cs'] = sample['w2cs'][start_idx:] # (V, 4, 4) + sample['c2ws'] = sample['c2ws'][start_idx:] # (V, 4, 4) + sample['intrinsics'] = sample['intrinsics'][start_idx:] # (V, 3, 3) + sample['view_ids'] = sample['view_ids'][start_idx:] + sample['affine_mats'] = sample['affine_mats'][start_idx:] # ! in world space + + sample['scale_mat'] = torch.from_numpy(scale_mat) + sample['trans_mat'] = torch.from_numpy(w2c_ref_inv) + + # - generate rays + if ('val' in self.split) or ('test' in self.split): + sample_rays = gen_rays_from_single_image( + img_wh[1], img_wh[0], + sample['query_image'], + sample['query_intrinsic'], + sample['query_c2w'], + depth=sample['query_depth'], + mask=sample['query_mask'] if self.clean_image else None) + else: + sample_rays = gen_random_rays_from_single_image( + img_wh[1], img_wh[0], + self.N_rays, + sample['query_image'], + sample['query_intrinsic'], + sample['query_c2w'], + depth=sample['query_depth'], + mask=sample['query_mask'] if self.clean_image else None, + dilated_mask=mask_dilated, + importance_sample=self.importance_sample) + + + sample['rays'] = sample_rays + + return sample diff --git a/SparseNeuS_demo_v1/data/blender_general_8_4_gt.py b/SparseNeuS_demo_v1/data/blender_general_8_4_gt.py new file mode 100644 index 0000000000000000000000000000000000000000..b1072d6a3e02f1908add474963aa6c6acaf69055 --- /dev/null +++ b/SparseNeuS_demo_v1/data/blender_general_8_4_gt.py @@ -0,0 +1,396 @@ +from torch.utils.data import Dataset +from utils.misc_utils import read_pfm +import os +import numpy as np +import cv2 +from PIL import Image +import torch +from torchvision import transforms as T +from data.scene import get_boundingbox + +from models.rays import gen_rays_from_single_image, gen_random_rays_from_single_image +import json +from termcolor import colored +import imageio +from kornia import create_meshgrid +import open3d as o3d +def get_ray_directions(H, W, focal, center=None): + """ + Get ray directions for all pixels in camera coordinate. + Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/ + ray-tracing-generating-camera-rays/standard-coordinate-systems + Inputs: + H, W, focal: image height, width and focal length + Outputs: + directions: (H, W, 3), the direction of the rays in camera coordinate + """ + grid = create_meshgrid(H, W, normalized_coordinates=False)[0] + 0.5 # 1xHxWx2 + + i, j = grid.unbind(-1) + # the direction here is without +0.5 pixel centering as calibration is not so accurate + # see https://github.com/bmild/nerf/issues/24 + cent = center if center is not None else [W / 2, H / 2] + directions = torch.stack([(i - cent[0]) / focal[0], (j - cent[1]) / focal[1], torch.ones_like(i)], -1) # (H, W, 3) + + return directions + +def load_K_Rt_from_P(filename, P=None): + if P is None: + lines = open(filename).read().splitlines() + if len(lines) == 4: + lines = lines[1:] + lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)] + P = np.asarray(lines).astype(np.float32).squeeze() + + out = cv2.decomposeProjectionMatrix(P) + K = out[0] + R = out[1] + t = out[2] + + K = K / K[2, 2] + intrinsics = np.eye(4) + intrinsics[:3, :3] = K + + pose = np.eye(4, dtype=np.float32) + pose[:3, :3] = R.transpose() # ? why need transpose here + pose[:3, 3] = (t[:3] / t[3])[:, 0] + + return intrinsics, pose # ! return cam2world matrix here + + +# ! load one ref-image with multiple src-images in camera coordinate system +class BlenderPerView(Dataset): + def __init__(self, root_dir, split, n_views=3, img_wh=(256, 256), downSample=1.0, + split_filepath=None, pair_filepath=None, + N_rays=512, + vol_dims=[128, 128, 128], batch_size=1, + clean_image=False, importance_sample=False, test_ref_views=[]): + + self.root_dir = root_dir + self.split = split + + self.imgs_per_instance = 8 + + self.n_views = n_views + self.N_rays = N_rays + self.batch_size = batch_size # - used for construct new metas for gru fusion training + + self.clean_image = clean_image + self.importance_sample = importance_sample + self.test_ref_views = test_ref_views # used for testing + self.scale_factor = 1.0 + self.scale_mat = np.float32(np.diag([1, 1, 1, 1.0])) + + lvis_json_path = '/objaverse-processed/zero12345_img/lvis_split.json' # folder_id and uid + with open(lvis_json_path, 'r') as f: + lvis_paths = json.load(f) + if self.split == 'train': + self.lvis_paths = lvis_paths['train'] + else: + self.lvis_paths = lvis_paths['val'] + if img_wh is not None: + assert img_wh[0] % 32 == 0 and img_wh[1] % 32 == 0, \ + 'img_wh must both be multiples of 32!' + + + pose_json_path_narrow = "/objaverse-processed/zero12345_img/zero12345_narrow_pose.json" + with open(pose_json_path_narrow, 'r') as f: + narrow_meta = json.load(f) + + pose_json_path_two_stage = "/objaverse-processed/zero12345_img/zero12345_2stage_8_pose.json" + with open(pose_json_path_two_stage, 'r') as f: + two_stage_meta = json.load(f) + + + self.img_ids = list(narrow_meta["c2ws"].keys()) + list(two_stage_meta["c2ws"].keys()) # (8 + 8*4) + (8 + 8*4) + self.img_wh = (256, 256) + self.input_poses = np.array(list(narrow_meta["c2ws"].values()) + list(two_stage_meta["c2ws"].values())) + intrinsic = np.eye(4) + assert narrow_meta["intrinsics"] == two_stage_meta["intrinsics"], "intrinsics not equal" + intrinsic[:3, :3] = np.array(narrow_meta["intrinsics"]) + self.intrinsic = intrinsic + assert narrow_meta["near_far"] == two_stage_meta["near_far"], "near_far not equal" + self.near_far = np.array(narrow_meta["near_far"]) + self.near_far[1] = 1.8 + self.define_transforms() + self.blender2opencv = np.array( + [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]] + ) + + + self.c2ws = [] + self.w2cs = [] + self.near_fars = [] + for idx, img_id in enumerate(self.img_ids): + pose = self.input_poses[idx] + c2w = pose @ self.blender2opencv + self.c2ws.append(c2w) + self.w2cs.append(np.linalg.inv(c2w)) + self.near_fars.append(self.near_far) + + + + self.c2ws = np.stack(self.c2ws, axis=0) + self.w2cs = np.stack(self.w2cs, axis=0) + + + self.all_intrinsics = [] # the cam info of the whole scene + self.all_extrinsics = [] + self.all_near_fars = [] + self.load_cam_info() + + # * bounding box for rendering + self.bbox_min = np.array([-1.0, -1.0, -1.0]) + self.bbox_max = np.array([1.0, 1.0, 1.0]) + + # - used for cost volume regularization + self.voxel_dims = torch.tensor(vol_dims, dtype=torch.float32) + self.partial_vol_origin = torch.tensor([-1., -1., -1.], dtype=torch.float32) + + + def define_transforms(self): + self.transform = T.Compose([T.ToTensor()]) + + + + def load_cam_info(self): + for vid, img_id in enumerate(self.img_ids): + intrinsic, extrinsic, near_far = self.intrinsic, np.linalg.inv(self.c2ws[vid]), self.near_far + self.all_intrinsics.append(intrinsic) + self.all_extrinsics.append(extrinsic) + self.all_near_fars.append(near_far) + + def read_depth(self, filename): + pass + + def read_mask(self, filename): + mask_h = cv2.imread(filename, 0) + mask_h = cv2.resize(mask_h, None, fx=self.downSample, fy=self.downSample, + interpolation=cv2.INTER_NEAREST) + mask = cv2.resize(mask_h, None, fx=0.25, fy=0.25, + interpolation=cv2.INTER_NEAREST) + + mask[mask > 0] = 1 # the masks stored in png are not binary + mask_h[mask_h > 0] = 1 + + return mask, mask_h + + def cal_scale_mat(self, img_hw, intrinsics, extrinsics, near_fars, factor=1.): + + center, radius, bounds = get_boundingbox(img_hw, intrinsics, extrinsics, near_fars) + + radius = radius * factor + scale_mat = np.diag([radius, radius, radius, 1.0]) + scale_mat[:3, 3] = center.cpu().numpy() + scale_mat = scale_mat.astype(np.float32) + + return scale_mat, 1. / radius.cpu().numpy() + + def __len__(self): + return self.imgs_per_instance * len(self.lvis_paths) + + + def read_depth(self, filename, near_bound, noisy_factor=1.0): + pass + + + def __getitem__(self, idx): + sample = {} + origin_idx = idx + imgs, depths_h, masks_h = [], [], [] # full size (256, 256) + intrinsics, w2cs, c2ws, near_fars = [], [], [], [] # record proj mats between views + idx_original=idx + + folder_uid_dict = self.lvis_paths[idx//self.imgs_per_instance] + + folder_id = folder_uid_dict['folder_id'] + uid = folder_uid_dict['uid'] + + idx = idx % self.imgs_per_instance # [0, 7] + # target view + c2w = self.c2ws[idx] + w2c = np.linalg.inv(c2w) + w2c_ref = w2c + w2c_ref_inv = np.linalg.inv(w2c_ref) + + w2cs.append(w2c @ w2c_ref_inv) + c2ws.append(np.linalg.inv(w2c @ w2c_ref_inv)) + + img_filename = os.path.join("/objaverse-processed/zero12345_img/zero12345_narrow/", folder_id, uid, f'view_{idx}.png') + + depth_filename = os.path.join(os.path.join("/objaverse-processed/zero12345_img/zero12345_narrow/", folder_id, uid, f'view_{idx}_depth_mm.png')) + + + img = Image.open(img_filename) + + img = self.transform(img) # (4, h, w) + + + if img.shape[0] == 4: + img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB + imgs += [img] + + depth_h = cv2.imread(depth_filename, cv2.IMREAD_UNCHANGED).astype(np.uint16) / 1000.0 + mask_h = depth_h > 0 + directions = get_ray_directions(self.img_wh[1], self.img_wh[0], [self.intrinsic[0, 0], self.intrinsic[1, 1]]) # [H, W, 3] + surface_points = directions * depth_h[..., None] # [H, W, 3] + distance = np.linalg.norm(surface_points, axis=-1) # [H, W] + depth_h = distance + + + depths_h.append(depth_h) + masks_h.append(mask_h) + + intrinsic = self.intrinsic + intrinsics.append(intrinsic) + + + near_fars.append(self.near_fars[idx]) + image_perm = 0 # only supervised on reference view + + mask_dilated = None + + + + src_views = range(8, 8+32) + src_views_used = [] + for vid in src_views: + view_dix_to_use = (vid - 8) // 4 + src_views_used.append(vid) + img_filename = os.path.join("/objaverse-processed/zero12345_img/zero12345_narrow/", folder_id, uid, f'view_{(vid - 8) // 4}_{vid%4}_10_gt.png') + + img = Image.open(img_filename) + img_wh = self.img_wh + + img = self.transform(img) + if img.shape[0] == 4: + img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB + imgs += [img] + depth_h =torch.ones((img.shape[1], img.shape[2]), dtype=torch.float32) + depth_h = depth_h.fill_(-1.0) + depths_h.append(depth_h) + masks_h.append(np.ones(img.shape[1:], dtype=np.int32)) + + near_fars.append(self.all_near_fars[vid]) + intrinsics.append(self.all_intrinsics[vid]) + + w2cs.append(self.all_extrinsics[vid] @ w2c_ref_inv) + + + scale_mat, scale_factor = self.cal_scale_mat( + img_hw=[img_wh[1], img_wh[0]], + intrinsics=intrinsics, extrinsics=w2cs, + near_fars=near_fars, factor=1.1 + ) + + + new_near_fars = [] + new_w2cs = [] + new_c2ws = [] + new_affine_mats = [] + new_depths_h = [] + for intrinsic, extrinsic, near_far, depth in zip(intrinsics, w2cs, near_fars, depths_h): + + P = intrinsic @ extrinsic @ scale_mat + P = P[:3, :4] + # - should use load_K_Rt_from_P() to obtain c2w + c2w = load_K_Rt_from_P(None, P)[1] + w2c = np.linalg.inv(c2w) + new_w2cs.append(w2c) + new_c2ws.append(c2w) + affine_mat = np.eye(4) + affine_mat[:3, :4] = intrinsic[:3, :3] @ w2c[:3, :4] + new_affine_mats.append(affine_mat) + + camera_o = c2w[:3, 3] + dist = np.sqrt(np.sum(camera_o ** 2)) + near = dist - 1 + far = dist + 1 + + new_near_fars.append([0.95 * near, 1.05 * far]) + + new_depths_h.append(depth * scale_factor) + + + imgs = torch.stack(imgs).float() + depths_h = np.stack(new_depths_h) + masks_h = np.stack(masks_h) + + affine_mats = np.stack(new_affine_mats) + intrinsics, w2cs, c2ws, near_fars = np.stack(intrinsics), np.stack(new_w2cs), np.stack(new_c2ws), np.stack( + new_near_fars) + + if self.split == 'train': + start_idx = 0 + else: + start_idx = 1 + + view_ids = [idx_original % self.imgs_per_instance] + src_views_used + sample['origin_idx'] = origin_idx + sample['images'] = imgs # (V, 3, H, W) + sample['depths_h'] = torch.from_numpy(depths_h.astype(np.float32)) # (V, H, W) + sample['masks_h'] = torch.from_numpy(masks_h.astype(np.float32)) # (V, H, W) + sample['w2cs'] = torch.from_numpy(w2cs.astype(np.float32)) # (V, 4, 4) + sample['c2ws'] = torch.from_numpy(c2ws.astype(np.float32)) # (V, 4, 4) + sample['near_fars'] = torch.from_numpy(near_fars.astype(np.float32)) # (V, 2) + sample['intrinsics'] = torch.from_numpy(intrinsics.astype(np.float32))[:, :3, :3] # (V, 3, 3) + sample['view_ids'] = torch.from_numpy(np.array(view_ids)) + sample['affine_mats'] = torch.from_numpy(affine_mats.astype(np.float32)) # ! in world space + + # sample['light_idx'] = torch.tensor(light_idx) + sample['scan'] = folder_id + + sample['scale_factor'] = torch.tensor(scale_factor) + sample['img_wh'] = torch.from_numpy(np.array(img_wh)) + sample['render_img_idx'] = torch.tensor(image_perm) + sample['partial_vol_origin'] = self.partial_vol_origin + meta_end = "_two_stage"+ "_refview" + str(view_ids[0] - 8) + sample['meta'] = str(folder_id) + "_" + str(uid) + "_refview" + meta_end + + + # - image to render + sample['query_image'] = sample['images'][0] + sample['query_c2w'] = sample['c2ws'][0] + sample['query_w2c'] = sample['w2cs'][0] + sample['query_intrinsic'] = sample['intrinsics'][0] + sample['query_depth'] = sample['depths_h'][0] + sample['query_mask'] = sample['masks_h'][0] + sample['query_near_far'] = sample['near_fars'][0] + + sample['images'] = sample['images'][start_idx:] # (V, 3, H, W) + sample['depths_h'] = sample['depths_h'][start_idx:] # (V, H, W) + sample['masks_h'] = sample['masks_h'][start_idx:] # (V, H, W) + sample['w2cs'] = sample['w2cs'][start_idx:] # (V, 4, 4) + sample['c2ws'] = sample['c2ws'][start_idx:] # (V, 4, 4) + sample['intrinsics'] = sample['intrinsics'][start_idx:] # (V, 3, 3) + sample['view_ids'] = sample['view_ids'][start_idx:] + sample['affine_mats'] = sample['affine_mats'][start_idx:] # ! in world space + + sample['scale_mat'] = torch.from_numpy(scale_mat) + sample['trans_mat'] = torch.from_numpy(w2c_ref_inv) + + # - generate rays + if ('val' in self.split) or ('test' in self.split): + sample_rays = gen_rays_from_single_image( + img_wh[1], img_wh[0], + sample['query_image'], + sample['query_intrinsic'], + sample['query_c2w'], + depth=sample['query_depth'], + mask=sample['query_mask'] if self.clean_image else None) + else: + sample_rays = gen_random_rays_from_single_image( + img_wh[1], img_wh[0], + self.N_rays, + sample['query_image'], + sample['query_intrinsic'], + sample['query_c2w'], + depth=sample['query_depth'], + mask=sample['query_mask'] if self.clean_image else None, + dilated_mask=mask_dilated, + importance_sample=self.importance_sample) + + + sample['rays'] = sample_rays + + return sample diff --git a/SparseNeuS_demo_v1/data/blender_general_8_narrow_and_8_2_stage_blend_3_views.py b/SparseNeuS_demo_v1/data/blender_general_8_narrow_and_8_2_stage_blend_3_views.py new file mode 100644 index 0000000000000000000000000000000000000000..fa97eb6ca99c254548e501f2e05d883f2b015e1c --- /dev/null +++ b/SparseNeuS_demo_v1/data/blender_general_8_narrow_and_8_2_stage_blend_3_views.py @@ -0,0 +1,446 @@ +from torch.utils.data import Dataset +from utils.misc_utils import read_pfm +import os +import numpy as np +import cv2 +from PIL import Image +import torch +from torchvision import transforms as T +from data.scene import get_boundingbox + +from models.rays import gen_rays_from_single_image, gen_random_rays_from_single_image +import json +from termcolor import colored +import imageio +from kornia import create_meshgrid +import open3d as o3d +def get_ray_directions(H, W, focal, center=None): + """ + Get ray directions for all pixels in camera coordinate. + Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/ + ray-tracing-generating-camera-rays/standard-coordinate-systems + Inputs: + H, W, focal: image height, width and focal length + Outputs: + directions: (H, W, 3), the direction of the rays in camera coordinate + """ + grid = create_meshgrid(H, W, normalized_coordinates=False)[0] + 0.5 # 1xHxWx2 + + i, j = grid.unbind(-1) + # the direction here is without +0.5 pixel centering as calibration is not so accurate + # see https://github.com/bmild/nerf/issues/24 + cent = center if center is not None else [W / 2, H / 2] + directions = torch.stack([(i - cent[0]) / focal[0], (j - cent[1]) / focal[1], torch.ones_like(i)], -1) # (H, W, 3) + + return directions + +def load_K_Rt_from_P(filename, P=None): + if P is None: + lines = open(filename).read().splitlines() + if len(lines) == 4: + lines = lines[1:] + lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)] + P = np.asarray(lines).astype(np.float32).squeeze() + + out = cv2.decomposeProjectionMatrix(P) + K = out[0] + R = out[1] + t = out[2] + + K = K / K[2, 2] + intrinsics = np.eye(4) + intrinsics[:3, :3] = K + + pose = np.eye(4, dtype=np.float32) + pose[:3, :3] = R.transpose() # ? why need transpose here + pose[:3, 3] = (t[:3] / t[3])[:, 0] + + return intrinsics, pose # ! return cam2world matrix here + + +# ! load one ref-image with multiple src-images in camera coordinate system +class BlenderPerView(Dataset): + def __init__(self, root_dir, split, n_views=3, img_wh=(256, 256), downSample=1.0, + split_filepath=None, pair_filepath=None, + N_rays=512, + vol_dims=[128, 128, 128], batch_size=1, + clean_image=False, importance_sample=False, test_ref_views=[]): + + self.root_dir = root_dir + self.split = split + self.imgs_per_instance = 16 + self.n_views = n_views + self.N_rays = N_rays + self.batch_size = batch_size # - used for construct new metas for gru fusion training + + self.clean_image = clean_image + self.importance_sample = importance_sample + self.test_ref_views = test_ref_views # used for testing + self.scale_factor = 1.0 + self.scale_mat = np.float32(np.diag([1, 1, 1, 1.0])) + + lvis_json_path = '/objaverse-processed/zero12345_img/lvis_split.json' # folder_id and uid + with open(lvis_json_path, 'r') as f: + lvis_paths = json.load(f) + if self.split == 'train': + self.lvis_paths = lvis_paths['train'] + else: + self.lvis_paths = lvis_paths['val'] + if img_wh is not None: + assert img_wh[0] % 32 == 0 and img_wh[1] % 32 == 0, \ + 'img_wh must both be multiples of 32!' + + + pose_json_path_narrow = "/objaverse-processed/zero12345_img/zero12345_narrow_pose.json" + with open(pose_json_path_narrow, 'r') as f: + narrow_meta = json.load(f) + + pose_json_path_two_stage = "/objaverse-processed/zero12345_img/zero12345_2stage_8_pose.json" + with open(pose_json_path_two_stage, 'r') as f: + two_stage_meta = json.load(f) + + + self.img_ids = list(narrow_meta["c2ws"].keys()) + list(two_stage_meta["c2ws"].keys()) # (8 + 8*4) + (8 + 4*4) + self.img_wh = (256, 256) + self.input_poses = np.array(list(narrow_meta["c2ws"].values()) + list(two_stage_meta["c2ws"].values())) + intrinsic = np.eye(4) + assert narrow_meta["intrinsics"] == two_stage_meta["intrinsics"], "intrinsics not equal" + intrinsic[:3, :3] = np.array(narrow_meta["intrinsics"]) + self.intrinsic = intrinsic + assert narrow_meta["near_far"] == two_stage_meta["near_far"], "near_far not equal" + self.near_far = np.array(narrow_meta["near_far"]) + self.near_far[1] = 1.8 + self.define_transforms() + self.blender2opencv = np.array( + [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]] + ) + + + self.c2ws = [] + self.w2cs = [] + self.near_fars = [] + for idx, img_id in enumerate(self.img_ids): + pose = self.input_poses[idx] + c2w = pose @ self.blender2opencv + self.c2ws.append(c2w) + self.w2cs.append(np.linalg.inv(c2w)) + self.near_fars.append(self.near_far) + + + + self.c2ws = np.stack(self.c2ws, axis=0) + self.w2cs = np.stack(self.w2cs, axis=0) + + + self.all_intrinsics = [] # the cam info of the whole scene + self.all_extrinsics = [] + self.all_near_fars = [] + self.load_cam_info() + + # * bounding box for rendering + self.bbox_min = np.array([-1.0, -1.0, -1.0]) + self.bbox_max = np.array([1.0, 1.0, 1.0]) + + # - used for cost volume regularization + self.voxel_dims = torch.tensor(vol_dims, dtype=torch.float32) + self.partial_vol_origin = torch.tensor([-1., -1., -1.], dtype=torch.float32) + + + def define_transforms(self): + self.transform = T.Compose([T.ToTensor()]) + + + + def load_cam_info(self): + for vid, img_id in enumerate(self.img_ids): + intrinsic, extrinsic, near_far = self.intrinsic, np.linalg.inv(self.c2ws[vid]), self.near_far + self.all_intrinsics.append(intrinsic) + self.all_extrinsics.append(extrinsic) + self.all_near_fars.append(near_far) + + def read_depth(self, filename): + pass + + def read_mask(self, filename): + mask_h = cv2.imread(filename, 0) + mask_h = cv2.resize(mask_h, None, fx=self.downSample, fy=self.downSample, + interpolation=cv2.INTER_NEAREST) + mask = cv2.resize(mask_h, None, fx=0.25, fy=0.25, + interpolation=cv2.INTER_NEAREST) + + mask[mask > 0] = 1 # the masks stored in png are not binary + mask_h[mask_h > 0] = 1 + + return mask, mask_h + + def cal_scale_mat(self, img_hw, intrinsics, extrinsics, near_fars, factor=1.): + + center, radius, bounds = get_boundingbox(img_hw, intrinsics, extrinsics, near_fars) + + radius = radius * factor + scale_mat = np.diag([radius, radius, radius, 1.0]) + scale_mat[:3, 3] = center.cpu().numpy() + scale_mat = scale_mat.astype(np.float32) + + return scale_mat, 1. / radius.cpu().numpy() + + def __len__(self): + return self.imgs_per_instance*len(self.lvis_paths) + + + def read_depth(self, filename, near_bound, noisy_factor=1.0): + pass + + + def __getitem__(self, idx): + sample = {} + origin_idx = idx + imgs, depths_h, masks_h = [], [], [] # full size (256, 256) + intrinsics, w2cs, c2ws, near_fars = [], [], [], [] # record proj mats between views + idx_original=idx + + folder_uid_dict = self.lvis_paths[idx//self.imgs_per_instance] + + folder_id = folder_uid_dict['folder_id'] + uid = folder_uid_dict['uid'] + + if idx % 16 < 8: # narrow image as target + idx = idx % self.imgs_per_instance # [0, 7] + # target view + c2w = self.c2ws[idx] + w2c = np.linalg.inv(c2w) + w2c_ref = w2c + w2c_ref_inv = np.linalg.inv(w2c_ref) + + w2cs.append(w2c @ w2c_ref_inv) + c2ws.append(np.linalg.inv(w2c @ w2c_ref_inv)) + + img_filename = os.path.join("/objaverse-processed/zero12345_img/zero12345_narrow/", folder_id, uid, f'view_{idx}.png') + + depth_filename = os.path.join(os.path.join("/objaverse-processed/zero12345_img/zero12345_narrow/", folder_id, uid, f'view_{idx}_depth_mm.png')) + + + img = Image.open(img_filename) + + img = self.transform(img) # (4, h, w) + + + if img.shape[0] == 4: + img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB + imgs += [img] + + depth_h = cv2.imread(depth_filename, cv2.IMREAD_UNCHANGED).astype(np.uint16) / 1000.0 + mask_h = depth_h > 0 + # print("valid pixels", np.sum(mask_h)) + directions = get_ray_directions(self.img_wh[1], self.img_wh[0], [self.intrinsic[0, 0], self.intrinsic[1, 1]]) # [H, W, 3] + surface_points = directions * depth_h[..., None] # [H, W, 3] + distance = np.linalg.norm(surface_points, axis=-1) # [H, W] + depth_h = distance + + else: + idx = idx % self.imgs_per_instance - 8 # [0, 5] + c2w = self.c2ws[idx + 40] + w2c = np.linalg.inv(c2w) + w2c_ref = w2c + w2c_ref_inv = np.linalg.inv(w2c_ref) + + w2cs.append(w2c @ w2c_ref_inv) + c2ws.append(np.linalg.inv(w2c @ w2c_ref_inv)) + + img_filename = os.path.join("/objaverse-processed/zero12345_img/zero12345_2stage_8/", folder_id, uid, f'view_0_{idx}_0.png') + + + img = Image.open(img_filename) + img = self.transform(img) # (4, h, w) + + # print("img_pre", img.shape) + if img.shape[0] == 4: + img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB + # print("img", img.shape) + imgs += [img] + + depth_h =torch.ones((img.shape[1], img.shape[2]), dtype=torch.float32) + depth_h = depth_h.fill_(-1.0) + + mask_h = torch.ones((img.shape[1], img.shape[2]), dtype=torch.int32) + depths_h.append(depth_h) + masks_h.append(mask_h) + + intrinsic = self.intrinsic + intrinsics.append(intrinsic) + + + near_fars.append(self.near_fars[idx]) + image_perm = 0 # only supervised on reference view + + mask_dilated = None + if_use_narrow = [] + if self.split == 'train': + for i in range(8): + if np.random.random() > 0.5: + if_use_narrow.append(True) # use narrow + else: + if_use_narrow.append(False) # 2-stage prediction + if_use_narrow[origin_idx % 8] = True if origin_idx < 8 else False + else: + for i in range(8): + if_use_narrow.append( True if origin_idx < 8 else False) + + src_views = list() + for i in range(8): + # randomly choose 3 different number from [0,3] + local_idxs = np.random.choice(4, 3, replace=False) + local_idxs = [0,1,2] + local_idxs = [8+i*4+local_idx for local_idx in local_idxs] + src_views += local_idxs + src_views_used = [] + for vid in src_views: + src_views_used.append(vid) + cur_view_id = (vid - 8) // 4 + # choose narrow + if if_use_narrow[cur_view_id]: + img_filename = os.path.join("/objaverse-processed/zero12345_img/zero12345_narrow/", folder_id, uid, f'view_{cur_view_id}_{vid%4}_10.png') + else: # choose 2-stage + img_filename = os.path.join("/objaverse-processed/zero12345_img/zero12345_2stage_8/", folder_id, uid, f'view_0_{(vid - 8) // 4}_{(vid-8) % 4 + 1}.png') + + img = Image.open(img_filename) + img_wh = self.img_wh + + img = self.transform(img) + if img.shape[0] == 4: + img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB + + imgs += [img] + depth_h = np.ones(img.shape[1:], dtype=np.float32) + depths_h.append(depth_h) + masks_h.append(np.ones(img.shape[1:], dtype=np.int32)) + + near_fars.append(self.all_near_fars[vid]) + intrinsics.append(self.all_intrinsics[vid]) + + w2cs.append(self.all_extrinsics[vid] @ w2c_ref_inv) + + + + + scale_mat, scale_factor = self.cal_scale_mat( + img_hw=[img_wh[1], img_wh[0]], + intrinsics=intrinsics, extrinsics=w2cs, + near_fars=near_fars, factor=1.1 + ) + + + new_near_fars = [] + new_w2cs = [] + new_c2ws = [] + new_affine_mats = [] + new_depths_h = [] + for intrinsic, extrinsic, near_far, depth in zip(intrinsics, w2cs, near_fars, depths_h): + + P = intrinsic @ extrinsic @ scale_mat + P = P[:3, :4] + # - should use load_K_Rt_from_P() to obtain c2w + c2w = load_K_Rt_from_P(None, P)[1] + w2c = np.linalg.inv(c2w) + new_w2cs.append(w2c) + new_c2ws.append(c2w) + affine_mat = np.eye(4) + affine_mat[:3, :4] = intrinsic[:3, :3] @ w2c[:3, :4] + new_affine_mats.append(affine_mat) + + camera_o = c2w[:3, 3] + dist = np.sqrt(np.sum(camera_o ** 2)) + near = dist - 1 + far = dist + 1 + + new_near_fars.append([0.95 * near, 1.05 * far]) + + new_depths_h.append(depth * scale_factor) + + # print(new_near_fars) + # print("img numeber: ", len(imgs)) + imgs = torch.stack(imgs).float() + depths_h = np.stack(new_depths_h) + masks_h = np.stack(masks_h) + + affine_mats = np.stack(new_affine_mats) + intrinsics, w2cs, c2ws, near_fars = np.stack(intrinsics), np.stack(new_w2cs), np.stack(new_c2ws), np.stack( + new_near_fars) + + if self.split == 'train': + start_idx = 0 + else: + start_idx = 1 + + view_ids = [idx_original % self.imgs_per_instance] + src_views_used + sample['origin_idx'] = origin_idx + sample['images'] = imgs # (V, 3, H, W) + sample['depths_h'] = torch.from_numpy(depths_h.astype(np.float32)) # (V, H, W) + sample['masks_h'] = torch.from_numpy(masks_h.astype(np.float32)) # (V, H, W) + sample['w2cs'] = torch.from_numpy(w2cs.astype(np.float32)) # (V, 4, 4) + sample['c2ws'] = torch.from_numpy(c2ws.astype(np.float32)) # (V, 4, 4) + sample['near_fars'] = torch.from_numpy(near_fars.astype(np.float32)) # (V, 2) + sample['intrinsics'] = torch.from_numpy(intrinsics.astype(np.float32))[:, :3, :3] # (V, 3, 3) + sample['view_ids'] = torch.from_numpy(np.array(view_ids)) + sample['affine_mats'] = torch.from_numpy(affine_mats.astype(np.float32)) # ! in world space + + # sample['light_idx'] = torch.tensor(light_idx) + sample['scan'] = folder_id + + sample['scale_factor'] = torch.tensor(scale_factor) + sample['img_wh'] = torch.from_numpy(np.array(img_wh)) + sample['render_img_idx'] = torch.tensor(image_perm) + sample['partial_vol_origin'] = self.partial_vol_origin + if view_ids[0] < 8: + meta_end = "_narrow"+ "_refview" + str(view_ids[0]) + else: + meta_end = "_two_stage"+ "_refview" + str(view_ids[0] - 8) + sample['meta'] = str(folder_id) + "_" + str(uid) + "_refview" + meta_end + + + # - image to render + sample['query_image'] = sample['images'][0] + sample['query_c2w'] = sample['c2ws'][0] + sample['query_w2c'] = sample['w2cs'][0] + sample['query_intrinsic'] = sample['intrinsics'][0] + sample['query_depth'] = sample['depths_h'][0] + sample['query_mask'] = sample['masks_h'][0] + sample['query_near_far'] = sample['near_fars'][0] + + sample['images'] = sample['images'][start_idx:] # (V, 3, H, W) + sample['depths_h'] = sample['depths_h'][start_idx:] # (V, H, W) + sample['masks_h'] = sample['masks_h'][start_idx:] # (V, H, W) + sample['w2cs'] = sample['w2cs'][start_idx:] # (V, 4, 4) + sample['c2ws'] = sample['c2ws'][start_idx:] # (V, 4, 4) + sample['intrinsics'] = sample['intrinsics'][start_idx:] # (V, 3, 3) + sample['view_ids'] = sample['view_ids'][start_idx:] + sample['affine_mats'] = sample['affine_mats'][start_idx:] # ! in world space + + sample['scale_mat'] = torch.from_numpy(scale_mat) + sample['trans_mat'] = torch.from_numpy(w2c_ref_inv) + + # - generate rays + if ('val' in self.split) or ('test' in self.split): + sample_rays = gen_rays_from_single_image( + img_wh[1], img_wh[0], + sample['query_image'], + sample['query_intrinsic'], + sample['query_c2w'], + depth=sample['query_depth'], + mask=sample['query_mask'] if self.clean_image else None) + else: + sample_rays = gen_random_rays_from_single_image( + img_wh[1], img_wh[0], + self.N_rays, + sample['query_image'], + sample['query_intrinsic'], + sample['query_c2w'], + depth=sample['query_depth'], + mask=sample['query_mask'] if self.clean_image else None, + dilated_mask=mask_dilated, + importance_sample=self.importance_sample) + + + sample['rays'] = sample_rays + + return sample diff --git a/SparseNeuS_demo_v1/data/blender_general_8_narrow_and_8_2_stage_blend_mix.py b/SparseNeuS_demo_v1/data/blender_general_8_narrow_and_8_2_stage_blend_mix.py new file mode 100644 index 0000000000000000000000000000000000000000..740bb81125a297fc1d504f4c119c7f9a76630507 --- /dev/null +++ b/SparseNeuS_demo_v1/data/blender_general_8_narrow_and_8_2_stage_blend_mix.py @@ -0,0 +1,439 @@ +from torch.utils.data import Dataset +from utils.misc_utils import read_pfm +import os +import numpy as np +import cv2 +from PIL import Image +import torch +from torchvision import transforms as T +from data.scene import get_boundingbox + +from models.rays import gen_rays_from_single_image, gen_random_rays_from_single_image +import json +from termcolor import colored +import imageio +from kornia import create_meshgrid +import open3d as o3d +def get_ray_directions(H, W, focal, center=None): + """ + Get ray directions for all pixels in camera coordinate. + Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/ + ray-tracing-generating-camera-rays/standard-coordinate-systems + Inputs: + H, W, focal: image height, width and focal length + Outputs: + directions: (H, W, 3), the direction of the rays in camera coordinate + """ + grid = create_meshgrid(H, W, normalized_coordinates=False)[0] + 0.5 # 1xHxWx2 + + i, j = grid.unbind(-1) + # the direction here is without +0.5 pixel centering as calibration is not so accurate + # see https://github.com/bmild/nerf/issues/24 + cent = center if center is not None else [W / 2, H / 2] + directions = torch.stack([(i - cent[0]) / focal[0], (j - cent[1]) / focal[1], torch.ones_like(i)], -1) # (H, W, 3) + + return directions + +def load_K_Rt_from_P(filename, P=None): + if P is None: + lines = open(filename).read().splitlines() + if len(lines) == 4: + lines = lines[1:] + lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)] + P = np.asarray(lines).astype(np.float32).squeeze() + + out = cv2.decomposeProjectionMatrix(P) + K = out[0] + R = out[1] + t = out[2] + + K = K / K[2, 2] + intrinsics = np.eye(4) + intrinsics[:3, :3] = K + + pose = np.eye(4, dtype=np.float32) + pose[:3, :3] = R.transpose() # ? why need transpose here + pose[:3, 3] = (t[:3] / t[3])[:, 0] + + return intrinsics, pose # ! return cam2world matrix here + + +# ! load one ref-image with multiple src-images in camera coordinate system +class BlenderPerView(Dataset): + def __init__(self, root_dir, split, n_views=3, img_wh=(256, 256), downSample=1.0, + split_filepath=None, pair_filepath=None, + N_rays=512, + vol_dims=[128, 128, 128], batch_size=1, + clean_image=False, importance_sample=False, test_ref_views=[]): + + self.root_dir = root_dir + self.split = split + self.imgs_per_instance = 16 + self.n_views = n_views + self.N_rays = N_rays + self.batch_size = batch_size # - used for construct new metas for gru fusion training + + self.clean_image = clean_image + self.importance_sample = importance_sample + self.test_ref_views = test_ref_views # used for testing + self.scale_factor = 1.0 + self.scale_mat = np.float32(np.diag([1, 1, 1, 1.0])) + + lvis_json_path = '/objaverse-processed/zero12345_img/lvis_split.json' # folder_id and uid + with open(lvis_json_path, 'r') as f: + lvis_paths = json.load(f) + if self.split == 'train': + self.lvis_paths = lvis_paths['train'] + else: + self.lvis_paths = lvis_paths['val'] + if img_wh is not None: + assert img_wh[0] % 32 == 0 and img_wh[1] % 32 == 0, \ + 'img_wh must both be multiples of 32!' + + + pose_json_path_narrow = "/objaverse-processed/zero12345_img/zero12345_narrow_pose.json" + with open(pose_json_path_narrow, 'r') as f: + narrow_meta = json.load(f) + + pose_json_path_two_stage = "/objaverse-processed/zero12345_img/zero12345_2stage_8_pose.json" + with open(pose_json_path_two_stage, 'r') as f: + two_stage_meta = json.load(f) + + + self.img_ids = list(narrow_meta["c2ws"].keys()) + list(two_stage_meta["c2ws"].keys()) # (8 + 8*4) + (8 + 8*4) + self.img_wh = (256, 256) + self.input_poses = np.array(list(narrow_meta["c2ws"].values()) + list(two_stage_meta["c2ws"].values())) + intrinsic = np.eye(4) + assert narrow_meta["intrinsics"] == two_stage_meta["intrinsics"], "intrinsics not equal" + intrinsic[:3, :3] = np.array(narrow_meta["intrinsics"]) + self.intrinsic = intrinsic + assert narrow_meta["near_far"] == two_stage_meta["near_far"], "near_far not equal" + self.near_far = np.array(narrow_meta["near_far"]) + self.near_far[1] = 1.8 + self.define_transforms() + self.blender2opencv = np.array( + [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]] + ) + + + self.c2ws = [] + self.w2cs = [] + self.near_fars = [] + for idx, img_id in enumerate(self.img_ids): + pose = self.input_poses[idx] + c2w = pose @ self.blender2opencv + self.c2ws.append(c2w) + self.w2cs.append(np.linalg.inv(c2w)) + self.near_fars.append(self.near_far) + + + + self.c2ws = np.stack(self.c2ws, axis=0) + self.w2cs = np.stack(self.w2cs, axis=0) + + + self.all_intrinsics = [] # the cam info of the whole scene + self.all_extrinsics = [] + self.all_near_fars = [] + self.load_cam_info() + + # * bounding box for rendering + self.bbox_min = np.array([-1.0, -1.0, -1.0]) + self.bbox_max = np.array([1.0, 1.0, 1.0]) + + # - used for cost volume regularization + self.voxel_dims = torch.tensor(vol_dims, dtype=torch.float32) + self.partial_vol_origin = torch.tensor([-1., -1., -1.], dtype=torch.float32) + + + def define_transforms(self): + self.transform = T.Compose([T.ToTensor()]) + + + + def load_cam_info(self): + for vid, img_id in enumerate(self.img_ids): + intrinsic, extrinsic, near_far = self.intrinsic, np.linalg.inv(self.c2ws[vid]), self.near_far + self.all_intrinsics.append(intrinsic) + self.all_extrinsics.append(extrinsic) + self.all_near_fars.append(near_far) + + def read_depth(self, filename): + pass + + def read_mask(self, filename): + mask_h = cv2.imread(filename, 0) + mask_h = cv2.resize(mask_h, None, fx=self.downSample, fy=self.downSample, + interpolation=cv2.INTER_NEAREST) + mask = cv2.resize(mask_h, None, fx=0.25, fy=0.25, + interpolation=cv2.INTER_NEAREST) + + mask[mask > 0] = 1 # the masks stored in png are not binary + mask_h[mask_h > 0] = 1 + + return mask, mask_h + + def cal_scale_mat(self, img_hw, intrinsics, extrinsics, near_fars, factor=1.): + + center, radius, bounds = get_boundingbox(img_hw, intrinsics, extrinsics, near_fars) + + radius = radius * factor + scale_mat = np.diag([radius, radius, radius, 1.0]) + scale_mat[:3, 3] = center.cpu().numpy() + scale_mat = scale_mat.astype(np.float32) + + return scale_mat, 1. / radius.cpu().numpy() + + def __len__(self): + return self.imgs_per_instance*len(self.lvis_paths) + + + def read_depth(self, filename, near_bound, noisy_factor=1.0): + pass + + + def __getitem__(self, idx): + sample = {} + origin_idx = idx + imgs, depths_h, masks_h = [], [], [] # full size (256, 256) + intrinsics, w2cs, c2ws, near_fars = [], [], [], [] # record proj mats between views + idx_original=idx + + folder_uid_dict = self.lvis_paths[idx//self.imgs_per_instance] + + folder_id = folder_uid_dict['folder_id'] + uid = folder_uid_dict['uid'] + + if idx % 16 < 8: # gt image as target + idx = idx % self.imgs_per_instance # [0, 7] + # target view + c2w = self.c2ws[idx] + w2c = np.linalg.inv(c2w) + w2c_ref = w2c + w2c_ref_inv = np.linalg.inv(w2c_ref) + + w2cs.append(w2c @ w2c_ref_inv) + c2ws.append(np.linalg.inv(w2c @ w2c_ref_inv)) + + img_filename = os.path.join("/objaverse-processed/zero12345_img/zero12345_narrow/", folder_id, uid, f'view_{idx}.png') + + depth_filename = os.path.join(os.path.join("/objaverse-processed/zero12345_img/zero12345_narrow/", folder_id, uid, f'view_{idx}_depth_mm.png')) + + + img = Image.open(img_filename) + + img = self.transform(img) # (4, h, w) + + + if img.shape[0] == 4: + img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB + imgs += [img] + + depth_h = cv2.imread(depth_filename, cv2.IMREAD_UNCHANGED).astype(np.uint16) / 1000.0 + mask_h = depth_h > 0 + # print("valid pixels", np.sum(mask_h)) + directions = get_ray_directions(self.img_wh[1], self.img_wh[0], [self.intrinsic[0, 0], self.intrinsic[1, 1]]) # [H, W, 3] + surface_points = directions * depth_h[..., None] # [H, W, 3] + distance = np.linalg.norm(surface_points, axis=-1) # [H, W] + depth_h = distance + + else: + idx = idx % self.imgs_per_instance - 8 # [0, 7] + c2w = self.c2ws[idx + 40] + w2c = np.linalg.inv(c2w) + w2c_ref = w2c + w2c_ref_inv = np.linalg.inv(w2c_ref) + + w2cs.append(w2c @ w2c_ref_inv) + c2ws.append(np.linalg.inv(w2c @ w2c_ref_inv)) + + img_filename = os.path.join("/objaverse-processed/zero12345_img/zero12345_2stage_8/", folder_id, uid, f'view_0_{idx}_0.png') + + + img = Image.open(img_filename) + img = self.transform(img) # (4, h, w) + + # print("img_pre", img.shape) + if img.shape[0] == 4: + img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB + # print("img", img.shape) + imgs += [img] + + depth_h = torch.ones((img.shape[1], img.shape[2]), dtype=torch.float32) + depth_h = depth_h.fill_(-1.0) + + mask_h = torch.ones((img.shape[1], img.shape[2]), dtype=torch.int32) + depths_h.append(depth_h) + masks_h.append(mask_h) + + intrinsic = self.intrinsic + intrinsics.append(intrinsic) + + + near_fars.append(self.near_fars[idx]) + image_perm = 0 # only supervised on reference view + + mask_dilated = None + if_use_narrow = [] + if self.split == 'train': + for i in range(8): + if np.random.random() > 0.5: + if_use_narrow.append(True) # use narrow + else: + if_use_narrow.append(False) # 2-stage prediction + if_use_narrow[origin_idx % 8] = True if (origin_idx % 16) < 8 else False + else: + for i in range(8): + if_use_narrow.append( True if (origin_idx % 16) < 8 else False) + src_views = range(8, 8 + 8 * 4) + src_views_used = [] + for vid in src_views: + src_views_used.append(vid) + cur_view_id = (vid - 8) // 4 # [0, 7] + # choose narrow + if if_use_narrow[cur_view_id]: + img_filename = os.path.join("/objaverse-processed/zero12345_img/zero12345_narrow/", folder_id, uid, f'view_{cur_view_id}_{vid%4}_10.png') + else: # choose 2-stage + img_filename = os.path.join("/objaverse-processed/zero12345_img/zero12345_2stage_8/", folder_id, uid, f'view_0_{cur_view_id}_{(vid) % 4 + 1}.png') + + img = Image.open(img_filename) + img_wh = self.img_wh + + img = self.transform(img) + if img.shape[0] == 4: + img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB + + imgs += [img] + depth_h = np.ones(img.shape[1:], dtype=np.float32) + depths_h.append(depth_h) + masks_h.append(np.ones(img.shape[1:], dtype=np.int32)) + + near_fars.append(self.all_near_fars[vid]) + intrinsics.append(self.all_intrinsics[vid]) + + w2cs.append(self.all_extrinsics[vid] @ w2c_ref_inv) + + + + + scale_mat, scale_factor = self.cal_scale_mat( + img_hw=[img_wh[1], img_wh[0]], + intrinsics=intrinsics, extrinsics=w2cs, + near_fars=near_fars, factor=1.1 + ) + + + new_near_fars = [] + new_w2cs = [] + new_c2ws = [] + new_affine_mats = [] + new_depths_h = [] + for intrinsic, extrinsic, near_far, depth in zip(intrinsics, w2cs, near_fars, depths_h): + + P = intrinsic @ extrinsic @ scale_mat + P = P[:3, :4] + # - should use load_K_Rt_from_P() to obtain c2w + c2w = load_K_Rt_from_P(None, P)[1] + w2c = np.linalg.inv(c2w) + new_w2cs.append(w2c) + new_c2ws.append(c2w) + affine_mat = np.eye(4) + affine_mat[:3, :4] = intrinsic[:3, :3] @ w2c[:3, :4] + new_affine_mats.append(affine_mat) + + camera_o = c2w[:3, 3] + dist = np.sqrt(np.sum(camera_o ** 2)) + near = dist - 1 + far = dist + 1 + + new_near_fars.append([0.95 * near, 1.05 * far]) + + new_depths_h.append(depth * scale_factor) + + # print(new_near_fars) + # print("img numeber: ", len(imgs)) + imgs = torch.stack(imgs).float() + depths_h = np.stack(new_depths_h) + masks_h = np.stack(masks_h) + + affine_mats = np.stack(new_affine_mats) + intrinsics, w2cs, c2ws, near_fars = np.stack(intrinsics), np.stack(new_w2cs), np.stack(new_c2ws), np.stack( + new_near_fars) + + if self.split == 'train': + start_idx = 0 + else: + start_idx = 1 + + view_ids = [idx_original % self.imgs_per_instance] + src_views_used + sample['origin_idx'] = origin_idx + sample['images'] = imgs # (V, 3, H, W) + sample['depths_h'] = torch.from_numpy(depths_h.astype(np.float32)) # (V, H, W) + sample['masks_h'] = torch.from_numpy(masks_h.astype(np.float32)) # (V, H, W) + sample['w2cs'] = torch.from_numpy(w2cs.astype(np.float32)) # (V, 4, 4) + sample['c2ws'] = torch.from_numpy(c2ws.astype(np.float32)) # (V, 4, 4) + sample['near_fars'] = torch.from_numpy(near_fars.astype(np.float32)) # (V, 2) + sample['intrinsics'] = torch.from_numpy(intrinsics.astype(np.float32))[:, :3, :3] # (V, 3, 3) + sample['view_ids'] = torch.from_numpy(np.array(view_ids)) + sample['affine_mats'] = torch.from_numpy(affine_mats.astype(np.float32)) # ! in world space + + # sample['light_idx'] = torch.tensor(light_idx) + sample['scan'] = folder_id + + sample['scale_factor'] = torch.tensor(scale_factor) + sample['img_wh'] = torch.from_numpy(np.array(img_wh)) + sample['render_img_idx'] = torch.tensor(image_perm) + sample['partial_vol_origin'] = self.partial_vol_origin + if view_ids[0] < 8: + meta_end = "_narrow"+ "_refview" + str(view_ids[0]) + else: + meta_end = "_two_stage"+ "_refview" + str(view_ids[0] - 8) + sample['meta'] = str(folder_id) + "_" + str(uid) + "_refview" + meta_end + + + # - image to render + sample['query_image'] = sample['images'][0] + sample['query_c2w'] = sample['c2ws'][0] + sample['query_w2c'] = sample['w2cs'][0] + sample['query_intrinsic'] = sample['intrinsics'][0] + sample['query_depth'] = sample['depths_h'][0] + sample['query_mask'] = sample['masks_h'][0] + sample['query_near_far'] = sample['near_fars'][0] + + sample['images'] = sample['images'][start_idx:] # (V, 3, H, W) + sample['depths_h'] = sample['depths_h'][start_idx:] # (V, H, W) + sample['masks_h'] = sample['masks_h'][start_idx:] # (V, H, W) + sample['w2cs'] = sample['w2cs'][start_idx:] # (V, 4, 4) + sample['c2ws'] = sample['c2ws'][start_idx:] # (V, 4, 4) + sample['intrinsics'] = sample['intrinsics'][start_idx:] # (V, 3, 3) + sample['view_ids'] = sample['view_ids'][start_idx:] + sample['affine_mats'] = sample['affine_mats'][start_idx:] # ! in world space + + sample['scale_mat'] = torch.from_numpy(scale_mat) + sample['trans_mat'] = torch.from_numpy(w2c_ref_inv) + + # - generate rays + if ('val' in self.split) or ('test' in self.split): + sample_rays = gen_rays_from_single_image( + img_wh[1], img_wh[0], + sample['query_image'], + sample['query_intrinsic'], + sample['query_c2w'], + depth=sample['query_depth'], + mask=sample['query_mask'] if self.clean_image else None) + else: + sample_rays = gen_random_rays_from_single_image( + img_wh[1], img_wh[0], + self.N_rays, + sample['query_image'], + sample['query_intrinsic'], + sample['query_c2w'], + depth=sample['query_depth'], + mask=sample['query_mask'] if self.clean_image else None, + dilated_mask=mask_dilated, + importance_sample=self.importance_sample) + + + sample['rays'] = sample_rays + + return sample diff --git a/SparseNeuS_demo_v1/data/blender_general_8_narrow_and_8_2_stage_mix.py b/SparseNeuS_demo_v1/data/blender_general_8_narrow_and_8_2_stage_mix.py new file mode 100644 index 0000000000000000000000000000000000000000..6d860e521935b529c4240a0299d892ff90f683b2 --- /dev/null +++ b/SparseNeuS_demo_v1/data/blender_general_8_narrow_and_8_2_stage_mix.py @@ -0,0 +1,470 @@ +from torch.utils.data import Dataset +from utils.misc_utils import read_pfm +import os +import numpy as np +import cv2 +from PIL import Image +import torch +from torchvision import transforms as T +from data.scene import get_boundingbox + +from models.rays import gen_rays_from_single_image, gen_random_rays_from_single_image +import json +from termcolor import colored +import imageio +from kornia import create_meshgrid +import open3d as o3d +def get_ray_directions(H, W, focal, center=None): + """ + Get ray directions for all pixels in camera coordinate. + Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/ + ray-tracing-generating-camera-rays/standard-coordinate-systems + Inputs: + H, W, focal: image height, width and focal length + Outputs: + directions: (H, W, 3), the direction of the rays in camera coordinate + """ + grid = create_meshgrid(H, W, normalized_coordinates=False)[0] + 0.5 # 1xHxWx2 + + i, j = grid.unbind(-1) + # the direction here is without +0.5 pixel centering as calibration is not so accurate + # see https://github.com/bmild/nerf/issues/24 + cent = center if center is not None else [W / 2, H / 2] + directions = torch.stack([(i - cent[0]) / focal[0], (j - cent[1]) / focal[1], torch.ones_like(i)], -1) # (H, W, 3) + + return directions + +def load_K_Rt_from_P(filename, P=None): + if P is None: + lines = open(filename).read().splitlines() + if len(lines) == 4: + lines = lines[1:] + lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)] + P = np.asarray(lines).astype(np.float32).squeeze() + + out = cv2.decomposeProjectionMatrix(P) + K = out[0] + R = out[1] + t = out[2] + + K = K / K[2, 2] + intrinsics = np.eye(4) + intrinsics[:3, :3] = K + + pose = np.eye(4, dtype=np.float32) + pose[:3, :3] = R.transpose() # ? why need transpose here + pose[:3, 3] = (t[:3] / t[3])[:, 0] + + return intrinsics, pose # ! return cam2world matrix here + + +# ! load one ref-image with multiple src-images in camera coordinate system +class BlenderPerView(Dataset): + def __init__(self, root_dir, split, n_views=3, img_wh=(256, 256), downSample=1.0, + split_filepath=None, pair_filepath=None, + N_rays=512, + vol_dims=[128, 128, 128], batch_size=1, + clean_image=False, importance_sample=False, test_ref_views=[]): + + self.root_dir = root_dir + self.split = split + self.imgs_per_instance = 16 + self.n_views = n_views + self.N_rays = N_rays + self.batch_size = batch_size # - used for construct new metas for gru fusion training + + self.clean_image = clean_image + self.importance_sample = importance_sample + self.test_ref_views = test_ref_views # used for testing + self.scale_factor = 1.0 + self.scale_mat = np.float32(np.diag([1, 1, 1, 1.0])) + + lvis_json_path = '/objaverse-processed/zero12345_img/lvis_split.json' # folder_id and uid + with open(lvis_json_path, 'r') as f: + lvis_paths = json.load(f) + if self.split == 'train': + self.lvis_paths = lvis_paths['train'] + else: + self.lvis_paths = lvis_paths['val'] + if img_wh is not None: + assert img_wh[0] % 32 == 0 and img_wh[1] % 32 == 0, \ + 'img_wh must both be multiples of 32!' + + + pose_json_path_narrow = "/objaverse-processed/zero12345_img/zero12345_narrow_pose.json" + with open(pose_json_path_narrow, 'r') as f: + narrow_meta = json.load(f) + + pose_json_path_two_stage = "/objaverse-processed/zero12345_img/zero12345_2stage_8_pose.json" + with open(pose_json_path_two_stage, 'r') as f: + two_stage_meta = json.load(f) + + + self.img_ids = list(narrow_meta["c2ws"].keys()) + list(two_stage_meta["c2ws"].keys()) # (8 + 8*4) + (8 + 8*4) + self.img_wh = (256, 256) + self.input_poses = np.array(list(narrow_meta["c2ws"].values()) + list(two_stage_meta["c2ws"].values())) + intrinsic = np.eye(4) + assert narrow_meta["intrinsics"] == two_stage_meta["intrinsics"], "intrinsics not equal" + intrinsic[:3, :3] = np.array(narrow_meta["intrinsics"]) + self.intrinsic = intrinsic + assert narrow_meta["near_far"] == two_stage_meta["near_far"], "near_far not equal" + self.near_far = np.array(narrow_meta["near_far"]) + self.near_far[1] = 1.8 + self.define_transforms() + self.blender2opencv = np.array( + [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]] + ) + + + self.c2ws = [] + self.w2cs = [] + self.near_fars = [] + for idx, img_id in enumerate(self.img_ids): + pose = self.input_poses[idx] + c2w = pose @ self.blender2opencv + self.c2ws.append(c2w) + self.w2cs.append(np.linalg.inv(c2w)) + self.near_fars.append(self.near_far) + + + + self.c2ws = np.stack(self.c2ws, axis=0) + self.w2cs = np.stack(self.w2cs, axis=0) + + + self.all_intrinsics = [] # the cam info of the whole scene + self.all_extrinsics = [] + self.all_near_fars = [] + self.load_cam_info() + + # * bounding box for rendering + self.bbox_min = np.array([-1.0, -1.0, -1.0]) + self.bbox_max = np.array([1.0, 1.0, 1.0]) + + # - used for cost volume regularization + self.voxel_dims = torch.tensor(vol_dims, dtype=torch.float32) + self.partial_vol_origin = torch.tensor([-1., -1., -1.], dtype=torch.float32) + + + def define_transforms(self): + self.transform = T.Compose([T.ToTensor()]) + + + + def load_cam_info(self): + for vid, img_id in enumerate(self.img_ids): + intrinsic, extrinsic, near_far = self.intrinsic, np.linalg.inv(self.c2ws[vid]), self.near_far + self.all_intrinsics.append(intrinsic) + self.all_extrinsics.append(extrinsic) + self.all_near_fars.append(near_far) + + def read_depth(self, filename): + pass + + def read_mask(self, filename): + mask_h = cv2.imread(filename, 0) + mask_h = cv2.resize(mask_h, None, fx=self.downSample, fy=self.downSample, + interpolation=cv2.INTER_NEAREST) + mask = cv2.resize(mask_h, None, fx=0.25, fy=0.25, + interpolation=cv2.INTER_NEAREST) + + mask[mask > 0] = 1 # the masks stored in png are not binary + mask_h[mask_h > 0] = 1 + + return mask, mask_h + + def cal_scale_mat(self, img_hw, intrinsics, extrinsics, near_fars, factor=1.): + + center, radius, bounds = get_boundingbox(img_hw, intrinsics, extrinsics, near_fars) + + radius = radius * factor + scale_mat = np.diag([radius, radius, radius, 1.0]) + scale_mat[:3, 3] = center.cpu().numpy() + scale_mat = scale_mat.astype(np.float32) + + return scale_mat, 1. / radius.cpu().numpy() + + def __len__(self): + return self.imgs_per_instance * len(self.lvis_paths) + + + def read_depth(self, filename, near_bound, noisy_factor=1.0): + pass + + + def __getitem__(self, idx): + sample = {} + origin_idx = idx + imgs, depths_h, masks_h = [], [], [] # full size (256, 256) + intrinsics, w2cs, c2ws, near_fars = [], [], [], [] # record proj mats between views + idx_original=idx + + folder_uid_dict = self.lvis_paths[idx//self.imgs_per_instance] + + folder_id = folder_uid_dict['folder_id'] + uid = folder_uid_dict['uid'] + + if idx % self.imgs_per_instance < 8: + idx = idx % self.imgs_per_instance # [0, 7] + # target view + c2w = self.c2ws[idx] + w2c = np.linalg.inv(c2w) + w2c_ref = w2c + w2c_ref_inv = np.linalg.inv(w2c_ref) + + w2cs.append(w2c @ w2c_ref_inv) + c2ws.append(np.linalg.inv(w2c @ w2c_ref_inv)) + + img_filename = os.path.join("/objaverse-processed/zero12345_img/zero12345_narrow/", folder_id, uid, f'view_{idx}.png') + + depth_filename = os.path.join(os.path.join("/objaverse-processed/zero12345_img/zero12345_narrow/", folder_id, uid, f'view_{idx}_depth_mm.png')) + + + img = Image.open(img_filename) + + img = self.transform(img) # (4, h, w) + + + if img.shape[0] == 4: + img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB + imgs += [img] + + depth_h = cv2.imread(depth_filename, cv2.IMREAD_UNCHANGED).astype(np.uint16) / 1000.0 + mask_h = depth_h > 0 + # print("valid pixels", np.sum(mask_h)) + directions = get_ray_directions(self.img_wh[1], self.img_wh[0], [self.intrinsic[0, 0], self.intrinsic[1, 1]]) # [H, W, 3] + surface_points = directions * depth_h[..., None] # [H, W, 3] + distance = np.linalg.norm(surface_points, axis=-1) # [H, W] + depth_h = distance + + + depths_h.append(depth_h) + masks_h.append(mask_h) + + intrinsic = self.intrinsic + intrinsics.append(intrinsic) + + + near_fars.append(self.near_fars[idx]) + image_perm = 0 # only supervised on reference view + + mask_dilated = None + + # src_views = range(8+idx*4, 8+(idx+1)*4) + + src_views = range(8, 8 + 8 * 4) + src_views_used = [] + for vid in src_views: + src_views_used.append(vid) + img_filename = os.path.join("/objaverse-processed/zero12345_img/zero12345_narrow/", folder_id, uid, f'view_{(vid - 8) // 4}_{vid%4}_10.png') + + img = Image.open(img_filename) + img_wh = self.img_wh + + img = self.transform(img) + if img.shape[0] == 4: + img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB + + imgs += [img] + depth_h = np.ones(img.shape[1:], dtype=np.float32) + depths_h.append(depth_h) + masks_h.append(np.ones(img.shape[1:], dtype=np.int32)) + + near_fars.append(self.all_near_fars[vid]) + intrinsics.append(self.all_intrinsics[vid]) + + w2cs.append(self.all_extrinsics[vid] @ w2c_ref_inv) + + else: + idx = idx % self.imgs_per_instance - 8 # [0, 5] + + c2w = self.c2ws[idx + 40] + w2c = np.linalg.inv(c2w) + w2c_ref = w2c + w2c_ref_inv = np.linalg.inv(w2c_ref) + + w2cs.append(w2c @ w2c_ref_inv) + c2ws.append(np.linalg.inv(w2c @ w2c_ref_inv)) + + img_filename = os.path.join("/objaverse-processed/zero12345_img/zero12345_2stage_8/", folder_id, uid, f'view_0_{idx}_0.png') + + + img = Image.open(img_filename) + + img = self.transform(img) # (4, h, w) + + # print("img_pre", img.shape) + if img.shape[0] == 4: + img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB + # print("img", img.shape) + imgs += [img] + + + depth_h = torch.ones((img.shape[1], img.shape[2]), dtype=torch.float32) + depth_h = depth_h.fill_(-1.0) + # depth_h = torch.fill((img.shape[1], img.shape[2]), -1.0) + # print("depth_h", depth_h.shape) + mask_h = torch.ones((img.shape[1], img.shape[2]), dtype=torch.int32) + depths_h.append(depth_h) + masks_h.append(mask_h) + + intrinsic = self.intrinsic + intrinsics.append(intrinsic) + + + near_fars.append(self.near_fars[idx]) + image_perm = 0 # only supervised on reference view + + mask_dilated = None + + + + src_views = range(40+8, 40+8+32) + src_views_used = [] + for vid in src_views: + view_dix_to_use = (vid - 40 - 8) // 4 + + src_views_used.append(vid) + img_filename = os.path.join("/objaverse-processed/zero12345_img/zero12345_2stage_8/", folder_id, uid, f'view_0_{idx}_{(vid-48) % 4 + 1}.png') + + img = Image.open(img_filename) + img_wh = self.img_wh + + img = self.transform(img) + # print("img shape1: ", img.shape) + if img.shape[0] == 4: + img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB + # print("img shape2: ", img.shape) + imgs += [img] + depth_h =torch.ones((img.shape[1], img.shape[2]), dtype=torch.float32) + depth_h = depth_h.fill_(-1.0) + depths_h.append(depth_h) + masks_h.append(np.ones(img.shape[1:], dtype=np.int32)) + + near_fars.append(self.all_near_fars[vid]) + intrinsics.append(self.all_intrinsics[vid]) + + w2cs.append(self.all_extrinsics[vid] @ w2c_ref_inv) + + + scale_mat, scale_factor = self.cal_scale_mat( + img_hw=[img_wh[1], img_wh[0]], + intrinsics=intrinsics, extrinsics=w2cs, + near_fars=near_fars, factor=1.1 + ) + + + new_near_fars = [] + new_w2cs = [] + new_c2ws = [] + new_affine_mats = [] + new_depths_h = [] + for intrinsic, extrinsic, near_far, depth in zip(intrinsics, w2cs, near_fars, depths_h): + + P = intrinsic @ extrinsic @ scale_mat + P = P[:3, :4] + # - should use load_K_Rt_from_P() to obtain c2w + c2w = load_K_Rt_from_P(None, P)[1] + w2c = np.linalg.inv(c2w) + new_w2cs.append(w2c) + new_c2ws.append(c2w) + affine_mat = np.eye(4) + affine_mat[:3, :4] = intrinsic[:3, :3] @ w2c[:3, :4] + new_affine_mats.append(affine_mat) + + camera_o = c2w[:3, 3] + dist = np.sqrt(np.sum(camera_o ** 2)) + near = dist - 1 + far = dist + 1 + + new_near_fars.append([0.95 * near, 1.05 * far]) + + new_depths_h.append(depth * scale_factor) + + # print(new_near_fars) + # print("img numeber: ", len(imgs)) + imgs = torch.stack(imgs).float() + depths_h = np.stack(new_depths_h) + masks_h = np.stack(masks_h) + + affine_mats = np.stack(new_affine_mats) + intrinsics, w2cs, c2ws, near_fars = np.stack(intrinsics), np.stack(new_w2cs), np.stack(new_c2ws), np.stack( + new_near_fars) + + if self.split == 'train': + start_idx = 0 + else: + start_idx = 1 + + view_ids = [idx_original % self.imgs_per_instance] + src_views_used + sample['origin_idx'] = origin_idx + sample['images'] = imgs # (V, 3, H, W) + sample['depths_h'] = torch.from_numpy(depths_h.astype(np.float32)) # (V, H, W) + sample['masks_h'] = torch.from_numpy(masks_h.astype(np.float32)) # (V, H, W) + sample['w2cs'] = torch.from_numpy(w2cs.astype(np.float32)) # (V, 4, 4) + sample['c2ws'] = torch.from_numpy(c2ws.astype(np.float32)) # (V, 4, 4) + sample['near_fars'] = torch.from_numpy(near_fars.astype(np.float32)) # (V, 2) + sample['intrinsics'] = torch.from_numpy(intrinsics.astype(np.float32))[:, :3, :3] # (V, 3, 3) + sample['view_ids'] = torch.from_numpy(np.array(view_ids)) + sample['affine_mats'] = torch.from_numpy(affine_mats.astype(np.float32)) # ! in world space + + # sample['light_idx'] = torch.tensor(light_idx) + sample['scan'] = folder_id + + sample['scale_factor'] = torch.tensor(scale_factor) + sample['img_wh'] = torch.from_numpy(np.array(img_wh)) + sample['render_img_idx'] = torch.tensor(image_perm) + sample['partial_vol_origin'] = self.partial_vol_origin + if view_ids[0] < 8: + meta_end = "_narrow"+ "_refview" + str(view_ids[0]) + else: + meta_end = "_two_stage"+ "_refview" + str(view_ids[0] - 8) + sample['meta'] = str(folder_id) + "_" + str(uid) + "_refview" + meta_end + + + # - image to render + sample['query_image'] = sample['images'][0] + sample['query_c2w'] = sample['c2ws'][0] + sample['query_w2c'] = sample['w2cs'][0] + sample['query_intrinsic'] = sample['intrinsics'][0] + sample['query_depth'] = sample['depths_h'][0] + sample['query_mask'] = sample['masks_h'][0] + sample['query_near_far'] = sample['near_fars'][0] + + sample['images'] = sample['images'][start_idx:] # (V, 3, H, W) + sample['depths_h'] = sample['depths_h'][start_idx:] # (V, H, W) + sample['masks_h'] = sample['masks_h'][start_idx:] # (V, H, W) + sample['w2cs'] = sample['w2cs'][start_idx:] # (V, 4, 4) + sample['c2ws'] = sample['c2ws'][start_idx:] # (V, 4, 4) + sample['intrinsics'] = sample['intrinsics'][start_idx:] # (V, 3, 3) + sample['view_ids'] = sample['view_ids'][start_idx:] + sample['affine_mats'] = sample['affine_mats'][start_idx:] # ! in world space + + sample['scale_mat'] = torch.from_numpy(scale_mat) + sample['trans_mat'] = torch.from_numpy(w2c_ref_inv) + + # - generate rays + if ('val' in self.split) or ('test' in self.split): + sample_rays = gen_rays_from_single_image( + img_wh[1], img_wh[0], + sample['query_image'], + sample['query_intrinsic'], + sample['query_c2w'], + depth=sample['query_depth'], + mask=sample['query_mask'] if self.clean_image else None) + else: + sample_rays = gen_random_rays_from_single_image( + img_wh[1], img_wh[0], + self.N_rays, + sample['query_image'], + sample['query_intrinsic'], + sample['query_c2w'], + depth=sample['query_depth'], + mask=sample['query_mask'] if self.clean_image else None, + dilated_mask=mask_dilated, + importance_sample=self.importance_sample) + + + sample['rays'] = sample_rays + + return sample diff --git a/SparseNeuS_demo_v1/data/blender_general_8_wide_from_2_stage.py b/SparseNeuS_demo_v1/data/blender_general_8_wide_from_2_stage.py new file mode 100644 index 0000000000000000000000000000000000000000..9609f20a733486544347d7fec78ae16bf1b9e2a3 --- /dev/null +++ b/SparseNeuS_demo_v1/data/blender_general_8_wide_from_2_stage.py @@ -0,0 +1,395 @@ +from torch.utils.data import Dataset +from utils.misc_utils import read_pfm +import os +import numpy as np +import cv2 +from PIL import Image +import torch +from torchvision import transforms as T +from data.scene import get_boundingbox + +from models.rays import gen_rays_from_single_image, gen_random_rays_from_single_image +import json +from termcolor import colored +import imageio +from kornia import create_meshgrid +import open3d as o3d +def get_ray_directions(H, W, focal, center=None): + """ + Get ray directions for all pixels in camera coordinate. + Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/ + ray-tracing-generating-camera-rays/standard-coordinate-systems + Inputs: + H, W, focal: image height, width and focal length + Outputs: + directions: (H, W, 3), the direction of the rays in camera coordinate + """ + grid = create_meshgrid(H, W, normalized_coordinates=False)[0] + 0.5 # 1xHxWx2 + + i, j = grid.unbind(-1) + # the direction here is without +0.5 pixel centering as calibration is not so accurate + # see https://github.com/bmild/nerf/issues/24 + cent = center if center is not None else [W / 2, H / 2] + directions = torch.stack([(i - cent[0]) / focal[0], (j - cent[1]) / focal[1], torch.ones_like(i)], -1) # (H, W, 3) + + return directions + +def load_K_Rt_from_P(filename, P=None): + if P is None: + lines = open(filename).read().splitlines() + if len(lines) == 4: + lines = lines[1:] + lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)] + P = np.asarray(lines).astype(np.float32).squeeze() + + out = cv2.decomposeProjectionMatrix(P) + K = out[0] + R = out[1] + t = out[2] + + K = K / K[2, 2] + intrinsics = np.eye(4) + intrinsics[:3, :3] = K + + pose = np.eye(4, dtype=np.float32) + pose[:3, :3] = R.transpose() # ? why need transpose here + pose[:3, 3] = (t[:3] / t[3])[:, 0] + + return intrinsics, pose # ! return cam2world matrix here + + +# ! load one ref-image with multiple src-images in camera coordinate system +class BlenderPerView(Dataset): + def __init__(self, root_dir, split, n_views=3, img_wh=(256, 256), downSample=1.0, + split_filepath=None, pair_filepath=None, + N_rays=512, + vol_dims=[128, 128, 128], batch_size=1, + clean_image=False, importance_sample=False, test_ref_views=[]): + + self.root_dir = root_dir + self.split = split + + self.imgs_per_instance = 8 + + self.n_views = n_views + self.N_rays = N_rays + self.batch_size = batch_size # - used for construct new metas for gru fusion training + + self.clean_image = clean_image + self.importance_sample = importance_sample + self.test_ref_views = test_ref_views # used for testing + self.scale_factor = 1.0 + self.scale_mat = np.float32(np.diag([1, 1, 1, 1.0])) + + lvis_json_path = '/objaverse-processed/zero12345_img/random32_split.json' # folder_id and uid + with open(lvis_json_path, 'r') as f: + lvis_paths = json.load(f) + if self.split == 'train': + self.lvis_paths = lvis_paths['train'] + else: + self.lvis_paths = lvis_paths['val'] + if img_wh is not None: + assert img_wh[0] % 32 == 0 and img_wh[1] % 32 == 0, \ + 'img_wh must both be multiples of 32!' + + + pose_json_path_narrow = "/objaverse-processed/zero12345_img/zero12345_narrow_pose.json" + with open(pose_json_path_narrow, 'r') as f: + narrow_meta = json.load(f) + + pose_json_path_two_stage = "/objaverse-processed/zero12345_img/zero12345_2stage_8_pose.json" + with open(pose_json_path_two_stage, 'r') as f: + two_stage_meta = json.load(f) + + + self.img_ids = list(narrow_meta["c2ws"].keys()) + list(two_stage_meta["c2ws"].keys()) # (8 + 8*4) + (8 + 8*4) + self.img_wh = (256, 256) + self.input_poses = np.array(list(narrow_meta["c2ws"].values()) + list(two_stage_meta["c2ws"].values())) + intrinsic = np.eye(4) + assert narrow_meta["intrinsics"] == two_stage_meta["intrinsics"], "intrinsics not equal" + intrinsic[:3, :3] = np.array(narrow_meta["intrinsics"]) + self.intrinsic = intrinsic + assert narrow_meta["near_far"] == two_stage_meta["near_far"], "near_far not equal" + self.near_far = np.array(narrow_meta["near_far"]) + self.near_far[1] = 1.8 + self.define_transforms() + self.blender2opencv = np.array( + [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]] + ) + + + self.c2ws = [] + self.w2cs = [] + self.near_fars = [] + for idx, img_id in enumerate(self.img_ids): + pose = self.input_poses[idx] + c2w = pose @ self.blender2opencv + self.c2ws.append(c2w) + self.w2cs.append(np.linalg.inv(c2w)) + self.near_fars.append(self.near_far) + + + + self.c2ws = np.stack(self.c2ws, axis=0) + self.w2cs = np.stack(self.w2cs, axis=0) + + + self.all_intrinsics = [] # the cam info of the whole scene + self.all_extrinsics = [] + self.all_near_fars = [] + self.load_cam_info() + + # * bounding box for rendering + self.bbox_min = np.array([-1.0, -1.0, -1.0]) + self.bbox_max = np.array([1.0, 1.0, 1.0]) + + # - used for cost volume regularization + self.voxel_dims = torch.tensor(vol_dims, dtype=torch.float32) + self.partial_vol_origin = torch.tensor([-1., -1., -1.], dtype=torch.float32) + + + def define_transforms(self): + self.transform = T.Compose([T.ToTensor()]) + + + + def load_cam_info(self): + for vid, img_id in enumerate(self.img_ids): + intrinsic, extrinsic, near_far = self.intrinsic, np.linalg.inv(self.c2ws[vid]), self.near_far + self.all_intrinsics.append(intrinsic) + self.all_extrinsics.append(extrinsic) + self.all_near_fars.append(near_far) + + def read_depth(self, filename): + pass + + def read_mask(self, filename): + mask_h = cv2.imread(filename, 0) + mask_h = cv2.resize(mask_h, None, fx=self.downSample, fy=self.downSample, + interpolation=cv2.INTER_NEAREST) + mask = cv2.resize(mask_h, None, fx=0.25, fy=0.25, + interpolation=cv2.INTER_NEAREST) + + mask[mask > 0] = 1 # the masks stored in png are not binary + mask_h[mask_h > 0] = 1 + + return mask, mask_h + + def cal_scale_mat(self, img_hw, intrinsics, extrinsics, near_fars, factor=1.): + + center, radius, bounds = get_boundingbox(img_hw, intrinsics, extrinsics, near_fars) + + radius = radius * factor + scale_mat = np.diag([radius, radius, radius, 1.0]) + scale_mat[:3, 3] = center.cpu().numpy() + scale_mat = scale_mat.astype(np.float32) + + return scale_mat, 1. / radius.cpu().numpy() + + def __len__(self): + return self.imgs_per_instance * len(self.lvis_paths) + + + def read_depth(self, filename, near_bound, noisy_factor=1.0): + pass + + + def __getitem__(self, idx): + sample = {} + origin_idx = idx + imgs, depths_h, masks_h = [], [], [] # full size (256, 256) + intrinsics, w2cs, c2ws, near_fars = [], [], [], [] # record proj mats between views + idx_original=idx + + folder_uid_dict = self.lvis_paths[idx//self.imgs_per_instance] + + folder_id = folder_uid_dict['folder_id'] + uid = folder_uid_dict['uid'] + + idx = idx % self.imgs_per_instance # [0, 7] + # target view + c2w = self.c2ws[idx] + w2c = np.linalg.inv(c2w) + w2c_ref = w2c + w2c_ref_inv = np.linalg.inv(w2c_ref) + + w2cs.append(w2c @ w2c_ref_inv) + c2ws.append(np.linalg.inv(w2c @ w2c_ref_inv)) + + img_filename = os.path.join("/objaverse-processed/zero12345_img/zero12345_narrow/", folder_id, uid, f'view_{idx}.png') + + depth_filename = os.path.join(os.path.join("/objaverse-processed/zero12345_img/zero12345_narrow/", folder_id, uid, f'view_{idx}_depth_mm.png')) + + + img = Image.open(img_filename) + + img = self.transform(img) # (4, h, w) + + + if img.shape[0] == 4: + img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB + imgs += [img] + + depth_h = cv2.imread(depth_filename, cv2.IMREAD_UNCHANGED).astype(np.uint16) / 1000.0 + mask_h = depth_h > 0 + directions = get_ray_directions(self.img_wh[1], self.img_wh[0], [self.intrinsic[0, 0], self.intrinsic[1, 1]]) # [H, W, 3] + surface_points = directions * depth_h[..., None] # [H, W, 3] + distance = np.linalg.norm(surface_points, axis=-1) # [H, W] + depth_h = distance + + + depths_h.append(depth_h) + masks_h.append(mask_h) + + intrinsic = self.intrinsic + intrinsics.append(intrinsic) + + + near_fars.append(self.near_fars[idx]) + image_perm = 0 # only supervised on reference view + + mask_dilated = None + + + + src_views = range(0, 8) + src_views_used = [] + for vid in src_views: + src_views_used.append(vid) + img_filename = os.path.join("/objaverse-processed/zero12345_img/zero12345_2stage_8/", folder_id, uid, f'view_0_{vid}_0.png') + + img = Image.open(img_filename) + img_wh = self.img_wh + + img = self.transform(img) + if img.shape[0] == 4: + img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB + imgs += [img] + depth_h =torch.ones((img.shape[1], img.shape[2]), dtype=torch.float32) + depth_h = depth_h.fill_(-1.0) + depths_h.append(depth_h) + masks_h.append(np.ones(img.shape[1:], dtype=np.int32)) + + near_fars.append(self.all_near_fars[vid]) + intrinsics.append(self.all_intrinsics[vid]) + + w2cs.append(self.all_extrinsics[vid] @ w2c_ref_inv) + + + scale_mat, scale_factor = self.cal_scale_mat( + img_hw=[img_wh[1], img_wh[0]], + intrinsics=intrinsics, extrinsics=w2cs, + near_fars=near_fars, factor=1.1 + ) + + + new_near_fars = [] + new_w2cs = [] + new_c2ws = [] + new_affine_mats = [] + new_depths_h = [] + for intrinsic, extrinsic, near_far, depth in zip(intrinsics, w2cs, near_fars, depths_h): + + P = intrinsic @ extrinsic @ scale_mat + P = P[:3, :4] + # - should use load_K_Rt_from_P() to obtain c2w + c2w = load_K_Rt_from_P(None, P)[1] + w2c = np.linalg.inv(c2w) + new_w2cs.append(w2c) + new_c2ws.append(c2w) + affine_mat = np.eye(4) + affine_mat[:3, :4] = intrinsic[:3, :3] @ w2c[:3, :4] + new_affine_mats.append(affine_mat) + + camera_o = c2w[:3, 3] + dist = np.sqrt(np.sum(camera_o ** 2)) + near = dist - 1 + far = dist + 1 + + new_near_fars.append([0.95 * near, 1.05 * far]) + + new_depths_h.append(depth * scale_factor) + + + imgs = torch.stack(imgs).float() + depths_h = np.stack(new_depths_h) + masks_h = np.stack(masks_h) + + affine_mats = np.stack(new_affine_mats) + intrinsics, w2cs, c2ws, near_fars = np.stack(intrinsics), np.stack(new_w2cs), np.stack(new_c2ws), np.stack( + new_near_fars) + + if self.split == 'train': + start_idx = 0 + else: + start_idx = 1 + + view_ids = [idx_original % self.imgs_per_instance] + src_views_used + sample['origin_idx'] = origin_idx + sample['images'] = imgs # (V, 3, H, W) + sample['depths_h'] = torch.from_numpy(depths_h.astype(np.float32)) # (V, H, W) + sample['masks_h'] = torch.from_numpy(masks_h.astype(np.float32)) # (V, H, W) + sample['w2cs'] = torch.from_numpy(w2cs.astype(np.float32)) # (V, 4, 4) + sample['c2ws'] = torch.from_numpy(c2ws.astype(np.float32)) # (V, 4, 4) + sample['near_fars'] = torch.from_numpy(near_fars.astype(np.float32)) # (V, 2) + sample['intrinsics'] = torch.from_numpy(intrinsics.astype(np.float32))[:, :3, :3] # (V, 3, 3) + sample['view_ids'] = torch.from_numpy(np.array(view_ids)) + sample['affine_mats'] = torch.from_numpy(affine_mats.astype(np.float32)) # ! in world space + + # sample['light_idx'] = torch.tensor(light_idx) + sample['scan'] = folder_id + + sample['scale_factor'] = torch.tensor(scale_factor) + sample['img_wh'] = torch.from_numpy(np.array(img_wh)) + sample['render_img_idx'] = torch.tensor(image_perm) + sample['partial_vol_origin'] = self.partial_vol_origin + meta_end = "_two_stage"+ "_refview" + str(view_ids[0] - 8) + sample['meta'] = str(folder_id) + "_" + str(uid) + "_refview" + meta_end + + + # - image to render + sample['query_image'] = sample['images'][0] + sample['query_c2w'] = sample['c2ws'][0] + sample['query_w2c'] = sample['w2cs'][0] + sample['query_intrinsic'] = sample['intrinsics'][0] + sample['query_depth'] = sample['depths_h'][0] + sample['query_mask'] = sample['masks_h'][0] + sample['query_near_far'] = sample['near_fars'][0] + + sample['images'] = sample['images'][start_idx:] # (V, 3, H, W) + sample['depths_h'] = sample['depths_h'][start_idx:] # (V, H, W) + sample['masks_h'] = sample['masks_h'][start_idx:] # (V, H, W) + sample['w2cs'] = sample['w2cs'][start_idx:] # (V, 4, 4) + sample['c2ws'] = sample['c2ws'][start_idx:] # (V, 4, 4) + sample['intrinsics'] = sample['intrinsics'][start_idx:] # (V, 3, 3) + sample['view_ids'] = sample['view_ids'][start_idx:] + sample['affine_mats'] = sample['affine_mats'][start_idx:] # ! in world space + + sample['scale_mat'] = torch.from_numpy(scale_mat) + sample['trans_mat'] = torch.from_numpy(w2c_ref_inv) + + # - generate rays + if ('val' in self.split) or ('test' in self.split): + sample_rays = gen_rays_from_single_image( + img_wh[1], img_wh[0], + sample['query_image'], + sample['query_intrinsic'], + sample['query_c2w'], + depth=sample['query_depth'], + mask=sample['query_mask'] if self.clean_image else None) + else: + sample_rays = gen_random_rays_from_single_image( + img_wh[1], img_wh[0], + self.N_rays, + sample['query_image'], + sample['query_intrinsic'], + sample['query_c2w'], + depth=sample['query_depth'], + mask=sample['query_mask'] if self.clean_image else None, + dilated_mask=mask_dilated, + importance_sample=self.importance_sample) + + + sample['rays'] = sample_rays + + return sample diff --git a/SparseNeuS_demo_v1/data/blender_general_narrow_4_1_eval_new_data.py b/SparseNeuS_demo_v1/data/blender_general_narrow_4_1_eval_new_data.py new file mode 100644 index 0000000000000000000000000000000000000000..bacd68d0d8cc7b578bf546e4484590f985920051 --- /dev/null +++ b/SparseNeuS_demo_v1/data/blender_general_narrow_4_1_eval_new_data.py @@ -0,0 +1,418 @@ +from torch.utils.data import Dataset +from utils.misc_utils import read_pfm +import os +import numpy as np +import cv2 +from PIL import Image +import torch +from torchvision import transforms as T +from data.scene import get_boundingbox + +from models.rays import gen_rays_from_single_image, gen_random_rays_from_single_image +import json +from termcolor import colored +import imageio +from kornia import create_meshgrid +import open3d as o3d + + +def get_ray_directions(H, W, focal, center=None): + """ + Get ray directions for all pixels in camera coordinate. + Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/ + ray-tracing-generating-camera-rays/standard-coordinate-systems + Inputs: + H, W, focal: image height, width and focal length + Outputs: + directions: (H, W, 3), the direction of the rays in camera coordinate + """ + grid = create_meshgrid(H, W, normalized_coordinates=False)[0] + 0.5 # 1xHxWx2 + + i, j = grid.unbind(-1) + # the direction here is without +0.5 pixel centering as calibration is not so accurate + # see https://github.com/bmild/nerf/issues/24 + cent = center if center is not None else [W / 2, H / 2] + directions = torch.stack([(i - cent[0]) / focal[0], (j - cent[1]) / focal[1], torch.ones_like(i)], -1) # (H, W, 3) + + return directions + +def load_K_Rt_from_P(filename, P=None): + if P is None: + lines = open(filename).read().splitlines() + if len(lines) == 4: + lines = lines[1:] + lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)] + P = np.asarray(lines).astype(np.float32).squeeze() + + out = cv2.decomposeProjectionMatrix(P) + K = out[0] + R = out[1] + t = out[2] + + K = K / K[2, 2] + intrinsics = np.eye(4) + intrinsics[:3, :3] = K + + pose = np.eye(4, dtype=np.float32) + pose[:3, :3] = R.transpose() # ? why need transpose here + pose[:3, 3] = (t[:3] / t[3])[:, 0] + + return intrinsics, pose # ! return cam2world matrix here + + +# ! load one ref-image with multiple src-images in camera coordinate system +class BlenderPerView(Dataset): + def __init__(self, root_dir, split, n_views=3, img_wh=(256, 256), downSample=1.0, + split_filepath=None, pair_filepath=None, + N_rays=512, + vol_dims=[128, 128, 128], batch_size=1, + clean_image=False, importance_sample=False, test_ref_views=[], + specific_dataset_name = 'GSO' + ): + + # print("root_dir: ", root_dir) + self.root_dir = root_dir + self.split = split + # self.specific_dataset_name = 'Realfusion' + # self.specific_dataset_name = 'GSO' + # self.specific_dataset_name = 'Objaverse' + # self.specific_dataset_name = 'Zero123' + + self.specific_dataset_name = specific_dataset_name + self.n_views = n_views + self.N_rays = N_rays + self.batch_size = batch_size # - used for construct new metas for gru fusion training + + self.clean_image = clean_image + self.importance_sample = importance_sample + self.test_ref_views = test_ref_views # used for testing + self.scale_factor = 1.0 + self.scale_mat = np.float32(np.diag([1, 1, 1, 1.0])) + assert self.split == 'val' or 'export_mesh', 'only support val or export_mesh' + # find all subfolders + main_folder = os.path.join(root_dir, self.specific_dataset_name) + self.shape_list = os.listdir(main_folder) + self.shape_list.sort() + + # self.shape_list = ['barrel_render'] + # self.shape_list = ["barrel", "bag", "mailbox", "shoe", "chair", "car", "dog", "teddy"] # TO BE DELETED + + + self.lvis_paths = [] + for shape_name in self.shape_list: + self.lvis_paths.append(os.path.join(main_folder, shape_name)) + + # print("lvis_paths: ", self.lvis_paths) + + if img_wh is not None: + assert img_wh[0] % 32 == 0 and img_wh[1] % 32 == 0, \ + 'img_wh must both be multiples of 32!' + + + # * bounding box for rendering + self.bbox_min = np.array([-1.0, -1.0, -1.0]) + self.bbox_max = np.array([1.0, 1.0, 1.0]) + + # - used for cost volume regularization + self.voxel_dims = torch.tensor(vol_dims, dtype=torch.float32) + self.partial_vol_origin = torch.tensor([-1., -1., -1.], dtype=torch.float32) + + + def define_transforms(self): + self.transform = T.Compose([T.ToTensor()]) + + + + def load_cam_info(self): + for vid, img_id in enumerate(self.img_ids): + intrinsic, extrinsic, near_far = self.intrinsic, np.linalg.inv(self.c2ws[vid]), self.near_far + self.all_intrinsics.append(intrinsic) + self.all_extrinsics.append(extrinsic) + self.all_near_fars.append(near_far) + + def read_depth(self, filename): + pass + + def read_mask(self, filename): + mask_h = cv2.imread(filename, 0) + mask_h = cv2.resize(mask_h, None, fx=self.downSample, fy=self.downSample, + interpolation=cv2.INTER_NEAREST) + mask = cv2.resize(mask_h, None, fx=0.25, fy=0.25, + interpolation=cv2.INTER_NEAREST) + + mask[mask > 0] = 1 # the masks stored in png are not binary + mask_h[mask_h > 0] = 1 + + return mask, mask_h + + def cal_scale_mat(self, img_hw, intrinsics, extrinsics, near_fars, factor=1.): + + center, radius, bounds = get_boundingbox(img_hw, intrinsics, extrinsics, near_fars) + + radius = radius * factor + scale_mat = np.diag([radius, radius, radius, 1.0]) + scale_mat[:3, 3] = center.cpu().numpy() + scale_mat = scale_mat.astype(np.float32) + + return scale_mat, 1. / radius.cpu().numpy() + + def __len__(self): + return 8*len(self.lvis_paths) + # return len(self.lvis_paths) + + + def read_depth(self, filename, near_bound, noisy_factor=1.0): + pass + + + def __getitem__(self, idx): + sample = {} + # idx = idx * 8 # to be deleted + origin_idx = idx + imgs, depths_h, masks_h = [], [], [] # full size (256, 256) + intrinsics, w2cs, c2ws, near_fars = [], [], [], [] # record proj-mats between views + + folder_path = self.lvis_paths[idx//8] + idx = idx % 8 # [0, 7] + + # last subdir name + shape_name = os.path.split(folder_path)[-1] + + pose_json_path = os.path.join(folder_path, "pose.json") + with open(pose_json_path, 'r') as f: + meta = json.load(f) + + self.img_ids = list(meta["c2ws"].keys()) # e.g. "view_0", "view_7", "view_0_2_10" + self.img_wh = (256, 256) + self.input_poses = np.array(list(meta["c2ws"].values())) + intrinsic = np.eye(4) + intrinsic[:3, :3] = np.array(meta["intrinsics"]) + self.intrinsic = intrinsic + self.near_far = np.array(meta["near_far"]) + self.near_far[1] = 1.8 + self.define_transforms() + self.blender2opencv = np.array( + [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]] + ) + + self.c2ws = [] + self.w2cs = [] + self.near_fars = [] + # self.root_dir = root_dir + for image_dix, img_id in enumerate(self.img_ids): + pose = self.input_poses[image_dix] + c2w = pose @ self.blender2opencv + self.c2ws.append(c2w) + self.w2cs.append(np.linalg.inv(c2w)) + self.near_fars.append(self.near_far) + self.c2ws = np.stack(self.c2ws, axis=0) + self.w2cs = np.stack(self.w2cs, axis=0) + + + self.all_intrinsics = [] # the cam info of the whole scene + self.all_extrinsics = [] + self.all_near_fars = [] + self.load_cam_info() + + + # target view + c2w = self.c2ws[idx] + w2c = np.linalg.inv(c2w) + w2c_ref = w2c + w2c_ref_inv = np.linalg.inv(w2c_ref) + + w2cs.append(w2c @ w2c_ref_inv) + c2ws.append(np.linalg.inv(w2c @ w2c_ref_inv)) + + # img_filename = os.path.join(folder_path, 'stage1_8_debug', f'{self.img_ids[idx]}') + img_filename = os.path.join(folder_path, 'stage1_8', f'{self.img_ids[idx]}') + + img = Image.open(img_filename) + img = self.transform(img) # (4, h, w) + + + if img.shape[0] == 4: + img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB + imgs += [img] + + + depth_h = torch.ones((img.shape[1], img.shape[2]), dtype=torch.float32) + depth_h = depth_h.fill_(-1.0) + mask_h = torch.ones((img.shape[1], img.shape[2]), dtype=torch.int32) + + + depths_h.append(depth_h) + masks_h.append(mask_h) + + intrinsic = self.intrinsic + intrinsics.append(intrinsic) + + + near_fars.append(self.near_fars[idx]) + image_perm = 0 # only supervised on reference view + + mask_dilated = None + + + # src_views = range(8, 8 + 8 * 4) + src_views = range(8+idx*4, 8+(idx+1)*4) + for vid in src_views: + + # img_filename = os.path.join(folder_path, 'stage2_8_debug', f'{self.img_ids[vid]}') + img_filename = os.path.join(folder_path, 'stage2_8', f'{self.img_ids[vid]}') + img = Image.open(img_filename) + img_wh = self.img_wh + + img = self.transform(img) + if img.shape[0] == 4: + img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB + + imgs += [img] + depth_h = np.ones(img.shape[1:], dtype=np.float32) + depths_h.append(depth_h) + masks_h.append(np.ones(img.shape[1:], dtype=np.int32)) + + near_fars.append(self.all_near_fars[vid]) + intrinsics.append(self.all_intrinsics[vid]) + + w2cs.append(self.all_extrinsics[vid] @ w2c_ref_inv) + + + # ! estimate scale_mat + scale_mat, scale_factor = self.cal_scale_mat( + img_hw=[img_wh[1], img_wh[0]], + intrinsics=intrinsics, extrinsics=w2cs, + near_fars=near_fars, factor=1.1 + ) + + + new_near_fars = [] + new_w2cs = [] + new_c2ws = [] + new_affine_mats = [] + new_depths_h = [] + for intrinsic, extrinsic, near_far, depth in zip(intrinsics, w2cs, near_fars, depths_h): + + P = intrinsic @ extrinsic @ scale_mat + P = P[:3, :4] + # - should use load_K_Rt_from_P() to obtain c2w + c2w = load_K_Rt_from_P(None, P)[1] + w2c = np.linalg.inv(c2w) + new_w2cs.append(w2c) + new_c2ws.append(c2w) + affine_mat = np.eye(4) + affine_mat[:3, :4] = intrinsic[:3, :3] @ w2c[:3, :4] + new_affine_mats.append(affine_mat) + + camera_o = c2w[:3, 3] + dist = np.sqrt(np.sum(camera_o ** 2)) + near = dist - 1 + far = dist + 1 + + new_near_fars.append([0.95 * near, 1.05 * far]) + new_depths_h.append(depth * scale_factor) + + # print(new_near_fars) + imgs = torch.stack(imgs).float() + depths_h = np.stack(new_depths_h) + masks_h = np.stack(masks_h) + + affine_mats = np.stack(new_affine_mats) + intrinsics, w2cs, c2ws, near_fars = np.stack(intrinsics), np.stack(new_w2cs), np.stack(new_c2ws), np.stack( + new_near_fars) + + if self.split == 'train': + start_idx = 0 + else: + start_idx = 1 + + + + target_w2cs = [] + target_intrinsics = [] + new_target_w2cs = [] + for i_idx in range(8): + target_w2cs.append(self.all_extrinsics[i_idx] @ w2c_ref_inv) + target_intrinsics.append(self.all_intrinsics[i_idx]) + + for intrinsic, extrinsic in zip(target_intrinsics, target_w2cs): + + P = intrinsic @ extrinsic @ scale_mat + P = P[:3, :4] + # - should use load_K_Rt_from_P() to obtain c2w + c2w = load_K_Rt_from_P(None, P)[1] + w2c = np.linalg.inv(c2w) + new_target_w2cs.append(w2c) + target_w2cs = np.stack(new_target_w2cs) + + + + view_ids = [idx] + list(src_views) + sample['origin_idx'] = origin_idx + sample['images'] = imgs # (V, 3, H, W) + sample['depths_h'] = torch.from_numpy(depths_h.astype(np.float32)) # (V, H, W) + sample['masks_h'] = torch.from_numpy(masks_h.astype(np.float32)) # (V, H, W) + sample['w2cs'] = torch.from_numpy(w2cs.astype(np.float32)) # (V, 4, 4) + sample['c2ws'] = torch.from_numpy(c2ws.astype(np.float32)) # (V, 4, 4) + sample['target_candidate_w2cs'] = torch.from_numpy(target_w2cs.astype(np.float32)) # (8, 4, 4) + sample['near_fars'] = torch.from_numpy(near_fars.astype(np.float32)) # (V, 2) + sample['intrinsics'] = torch.from_numpy(intrinsics.astype(np.float32))[:, :3, :3] # (V, 3, 3) + sample['view_ids'] = torch.from_numpy(np.array(view_ids)) + sample['affine_mats'] = torch.from_numpy(affine_mats.astype(np.float32)) # ! in world space + + # sample['light_idx'] = torch.tensor(light_idx) + sample['scan'] = shape_name + + sample['scale_factor'] = torch.tensor(scale_factor) + sample['img_wh'] = torch.from_numpy(np.array(img_wh)) + sample['render_img_idx'] = torch.tensor(image_perm) + sample['partial_vol_origin'] = self.partial_vol_origin + sample['meta'] = str(self.specific_dataset_name) + '_' + str(shape_name) + "_refview" + str(view_ids[0]) + # print("meta: ", sample['meta']) + + # - image to render + sample['query_image'] = sample['images'][0] + sample['query_c2w'] = sample['c2ws'][0] + sample['query_w2c'] = sample['w2cs'][0] + sample['query_intrinsic'] = sample['intrinsics'][0] + sample['query_depth'] = sample['depths_h'][0] + sample['query_mask'] = sample['masks_h'][0] + sample['query_near_far'] = sample['near_fars'][0] + + sample['images'] = sample['images'][start_idx:] # (V, 3, H, W) + sample['depths_h'] = sample['depths_h'][start_idx:] # (V, H, W) + sample['masks_h'] = sample['masks_h'][start_idx:] # (V, H, W) + sample['w2cs'] = sample['w2cs'][start_idx:] # (V, 4, 4) + sample['c2ws'] = sample['c2ws'][start_idx:] # (V, 4, 4) + sample['intrinsics'] = sample['intrinsics'][start_idx:] # (V, 3, 3) + sample['view_ids'] = sample['view_ids'][start_idx:] + sample['affine_mats'] = sample['affine_mats'][start_idx:] # ! in world space + + sample['scale_mat'] = torch.from_numpy(scale_mat) + sample['trans_mat'] = torch.from_numpy(w2c_ref_inv) + + # - generate rays + if ('val' in self.split) or ('test' in self.split): + sample_rays = gen_rays_from_single_image( + img_wh[1], img_wh[0], + sample['query_image'], + sample['query_intrinsic'], + sample['query_c2w'], + depth=sample['query_depth'], + mask=sample['query_mask'] if self.clean_image else None) + else: + sample_rays = gen_random_rays_from_single_image( + img_wh[1], img_wh[0], + self.N_rays, + sample['query_image'], + sample['query_intrinsic'], + sample['query_c2w'], + depth=sample['query_depth'], + mask=sample['query_mask'] if self.clean_image else None, + dilated_mask=mask_dilated, + importance_sample=self.importance_sample) + + + sample['rays'] = sample_rays + + return sample diff --git a/SparseNeuS_demo_v1/data/blender_general_narrow_6.py b/SparseNeuS_demo_v1/data/blender_general_narrow_6.py new file mode 100644 index 0000000000000000000000000000000000000000..5d8333986bb15b3e3fd495f1ee4600e22ef93246 --- /dev/null +++ b/SparseNeuS_demo_v1/data/blender_general_narrow_6.py @@ -0,0 +1,399 @@ +from torch.utils.data import Dataset +from utils.misc_utils import read_pfm +import os +import numpy as np +import cv2 +from PIL import Image +import torch +from torchvision import transforms as T +from data.scene import get_boundingbox + +from models.rays import gen_rays_from_single_image, gen_random_rays_from_single_image +import json +from termcolor import colored +import imageio +from kornia import create_meshgrid +import open3d as o3d +def get_ray_directions(H, W, focal, center=None): + """ + Get ray directions for all pixels in camera coordinate. + Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/ + ray-tracing-generating-camera-rays/standard-coordinate-systems + Inputs: + H, W, focal: image height, width and focal length + Outputs: + directions: (H, W, 3), the direction of the rays in camera coordinate + """ + grid = create_meshgrid(H, W, normalized_coordinates=False)[0] + 0.5 # 1xHxWx2 + + i, j = grid.unbind(-1) + # the direction here is without +0.5 pixel centering as calibration is not so accurate + # see https://github.com/bmild/nerf/issues/24 + cent = center if center is not None else [W / 2, H / 2] + directions = torch.stack([(i - cent[0]) / focal[0], (j - cent[1]) / focal[1], torch.ones_like(i)], -1) # (H, W, 3) + + return directions + +def load_K_Rt_from_P(filename, P=None): + if P is None: + lines = open(filename).read().splitlines() + if len(lines) == 4: + lines = lines[1:] + lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)] + P = np.asarray(lines).astype(np.float32).squeeze() + + out = cv2.decomposeProjectionMatrix(P) + K = out[0] + R = out[1] + t = out[2] + + K = K / K[2, 2] + intrinsics = np.eye(4) + intrinsics[:3, :3] = K + + pose = np.eye(4, dtype=np.float32) + pose[:3, :3] = R.transpose() # ? why need transpose here + pose[:3, 3] = (t[:3] / t[3])[:, 0] + + return intrinsics, pose # ! return cam2world matrix here + + +# ! load one ref-image with multiple src-images in camera coordinate system +class BlenderPerView(Dataset): + def __init__(self, root_dir, split, n_views=3, img_wh=(256, 256), downSample=1.0, + split_filepath=None, pair_filepath=None, + N_rays=512, + vol_dims=[128, 128, 128], batch_size=1, + clean_image=False, importance_sample=False, test_ref_views=[]): + + # print("root_dir: ", root_dir) + self.root_dir = root_dir + self.split = split + + self.n_views = n_views + self.N_rays = N_rays + self.batch_size = batch_size # - used for construct new metas for gru fusion training + + self.clean_image = clean_image + self.importance_sample = importance_sample + self.test_ref_views = test_ref_views # used for testing + self.scale_factor = 1.0 + self.scale_mat = np.float32(np.diag([1, 1, 1, 1.0])) + + lvis_json_path = '/objaverse-processed/zero12345_img/lvis_split.json' # folder_id and uid + with open(lvis_json_path, 'r') as f: + lvis_paths = json.load(f) + if self.split == 'train': + self.lvis_paths = lvis_paths['train'] + else: + self.lvis_paths = lvis_paths['val'] + if img_wh is not None: + assert img_wh[0] % 32 == 0 and img_wh[1] % 32 == 0, \ + 'img_wh must both be multiples of 32!' + + + pose_json_path = "/objaverse-processed/zero12345_img/zero12345_narrow_pose.json" + with open(pose_json_path, 'r') as f: + meta = json.load(f) + + self.img_ids = list(meta["c2ws"].keys()) # e.g. "view_0", "view_7", "view_0_2_10" + self.img_wh = (256, 256) + self.input_poses = np.array(list(meta["c2ws"].values())) + intrinsic = np.eye(4) + intrinsic[:3, :3] = np.array(meta["intrinsics"]) + self.intrinsic = intrinsic + self.near_far = np.array(meta["near_far"]) + self.near_far[1] = 1.8 + self.define_transforms() + self.blender2opencv = np.array( + [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]] + ) + + + self.c2ws = [] + self.w2cs = [] + self.near_fars = [] + # self.root_dir = root_dir + for idx, img_id in enumerate(self.img_ids): + pose = self.input_poses[idx] + c2w = pose @ self.blender2opencv + self.c2ws.append(c2w) + self.w2cs.append(np.linalg.inv(c2w)) + self.near_fars.append(self.near_far) + self.c2ws = np.stack(self.c2ws, axis=0) + self.w2cs = np.stack(self.w2cs, axis=0) + + + self.all_intrinsics = [] # the cam info of the whole scene + self.all_extrinsics = [] + self.all_near_fars = [] + self.load_cam_info() + + # * bounding box for rendering + self.bbox_min = np.array([-1.0, -1.0, -1.0]) + self.bbox_max = np.array([1.0, 1.0, 1.0]) + + # - used for cost volume regularization + self.voxel_dims = torch.tensor(vol_dims, dtype=torch.float32) + self.partial_vol_origin = torch.tensor([-1., -1., -1.], dtype=torch.float32) + + + def define_transforms(self): + self.transform = T.Compose([T.ToTensor()]) + + + + def load_cam_info(self): + for vid, img_id in enumerate(self.img_ids): + intrinsic, extrinsic, near_far = self.intrinsic, np.linalg.inv(self.c2ws[vid]), self.near_far + self.all_intrinsics.append(intrinsic) + self.all_extrinsics.append(extrinsic) + self.all_near_fars.append(near_far) + + def read_depth(self, filename): + pass + + def read_mask(self, filename): + mask_h = cv2.imread(filename, 0) + mask_h = cv2.resize(mask_h, None, fx=self.downSample, fy=self.downSample, + interpolation=cv2.INTER_NEAREST) + mask = cv2.resize(mask_h, None, fx=0.25, fy=0.25, + interpolation=cv2.INTER_NEAREST) + + mask[mask > 0] = 1 # the masks stored in png are not binary + mask_h[mask_h > 0] = 1 + + return mask, mask_h + + def cal_scale_mat(self, img_hw, intrinsics, extrinsics, near_fars, factor=1.): + + center, radius, bounds = get_boundingbox(img_hw, intrinsics, extrinsics, near_fars) + # print("center", center) + # print("radius", radius) + # print("bounds", bounds) + # import ipdb; ipdb.set_trace() + radius = radius * factor + scale_mat = np.diag([radius, radius, radius, 1.0]) + scale_mat[:3, 3] = center.cpu().numpy() + scale_mat = scale_mat.astype(np.float32) + + return scale_mat, 1. / radius.cpu().numpy() + + def __len__(self): + if self.split == 'train': + return 6*len(self.lvis_paths) + else: + return 8*len(self.lvis_paths) + + + def read_depth(self, filename, near_bound, noisy_factor=1.0): + pass + + + def __getitem__(self, idx): + sample = {} + origin_idx = idx + imgs, depths_h, masks_h = [], [], [] # full size (256, 256) + intrinsics, w2cs, c2ws, near_fars = [], [], [], [] # record proj mats between views + + if self.split == 'train': + folder_uid_dict = self.lvis_paths[idx//6] + idx = idx % 6 # [0, 5] + if idx == 4: + idx = 5 + elif idx == 5: + idx = 7 + else: + folder_uid_dict = self.lvis_paths[idx//8] + idx = idx % 8 # [0, 7] + + folder_id = folder_uid_dict['folder_id'] + uid = folder_uid_dict['uid'] + + + # target view + c2w = self.c2ws[idx] + w2c = np.linalg.inv(c2w) + w2c_ref = w2c + w2c_ref_inv = np.linalg.inv(w2c_ref) + + w2cs.append(w2c @ w2c_ref_inv) + c2ws.append(np.linalg.inv(w2c @ w2c_ref_inv)) + + img_filename = os.path.join(self.root_dir, folder_id, uid, f'view_{idx}.png') + + depth_filename = os.path.join(os.path.join(self.root_dir, folder_id, uid, f'view_{idx}_depth_mm.png')) + + + img = Image.open(img_filename) + + img = self.transform(img) # (4, h, w) + + + if img.shape[0] == 4: + img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB + imgs += [img] + + depth_h = cv2.imread(depth_filename, cv2.IMREAD_UNCHANGED).astype(np.uint16) / 1000.0 + mask_h = depth_h > 0 + # print("valid pixels", np.sum(mask_h)) + directions = get_ray_directions(self.img_wh[1], self.img_wh[0], [self.intrinsic[0, 0], self.intrinsic[1, 1]]) # [H, W, 3] + surface_points = directions * depth_h[..., None] # [H, W, 3] + distance = np.linalg.norm(surface_points, axis=-1) # [H, W] + depth_h = distance + + + depths_h.append(depth_h) + masks_h.append(mask_h) + + intrinsic = self.intrinsic + intrinsics.append(intrinsic) + + + near_fars.append(self.near_fars[idx]) + image_perm = 0 # only supervised on reference view + + mask_dilated = None + + # src_views = range(8+idx*4, 8+(idx+1)*4) + src_views = range(8, 8 + 8 * 4) + + for vid in src_views: + if ((vid - 8) // 4 == 4) or ((vid - 8) // 4 == 6): + continue + img_filename = os.path.join(self.root_dir, folder_id, uid, f'view_{(vid - 8) // 4}_{vid%4}_10.png') + + img = Image.open(img_filename) + img_wh = self.img_wh + + img = self.transform(img) + if img.shape[0] == 4: + img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB + + imgs += [img] + depth_h = np.ones(img.shape[1:], dtype=np.float32) + depths_h.append(depth_h) + masks_h.append(np.ones(img.shape[1:], dtype=np.int32)) + + near_fars.append(self.all_near_fars[vid]) + intrinsics.append(self.all_intrinsics[vid]) + + w2cs.append(self.all_extrinsics[vid] @ w2c_ref_inv) + # print("len(imges)", len(imgs)) + + # ! estimate scale_mat + scale_mat, scale_factor = self.cal_scale_mat( + img_hw=[img_wh[1], img_wh[0]], + intrinsics=intrinsics, extrinsics=w2cs, + near_fars=near_fars, factor=1.1 + ) + + + new_near_fars = [] + new_w2cs = [] + new_c2ws = [] + new_affine_mats = [] + new_depths_h = [] + for intrinsic, extrinsic, near_far, depth in zip(intrinsics, w2cs, near_fars, depths_h): + + P = intrinsic @ extrinsic @ scale_mat + P = P[:3, :4] + # - should use load_K_Rt_from_P() to obtain c2w + c2w = load_K_Rt_from_P(None, P)[1] + w2c = np.linalg.inv(c2w) + new_w2cs.append(w2c) + new_c2ws.append(c2w) + affine_mat = np.eye(4) + affine_mat[:3, :4] = intrinsic[:3, :3] @ w2c[:3, :4] + new_affine_mats.append(affine_mat) + + camera_o = c2w[:3, 3] + dist = np.sqrt(np.sum(camera_o ** 2)) + near = dist - 1 + far = dist + 1 + + new_near_fars.append([0.95 * near, 1.05 * far]) + new_depths_h.append(depth * scale_factor) + + # print(new_near_fars) + imgs = torch.stack(imgs).float() + depths_h = np.stack(new_depths_h) + masks_h = np.stack(masks_h) + + affine_mats = np.stack(new_affine_mats) + intrinsics, w2cs, c2ws, near_fars = np.stack(intrinsics), np.stack(new_w2cs), np.stack(new_c2ws), np.stack( + new_near_fars) + + if self.split == 'train': + start_idx = 0 + else: + start_idx = 1 + + view_ids = [idx] + list(src_views) + sample['origin_idx'] = origin_idx + sample['images'] = imgs # (V, 3, H, W) + sample['depths_h'] = torch.from_numpy(depths_h.astype(np.float32)) # (V, H, W) + sample['masks_h'] = torch.from_numpy(masks_h.astype(np.float32)) # (V, H, W) + sample['w2cs'] = torch.from_numpy(w2cs.astype(np.float32)) # (V, 4, 4) + sample['c2ws'] = torch.from_numpy(c2ws.astype(np.float32)) # (V, 4, 4) + sample['near_fars'] = torch.from_numpy(near_fars.astype(np.float32)) # (V, 2) + sample['intrinsics'] = torch.from_numpy(intrinsics.astype(np.float32))[:, :3, :3] # (V, 3, 3) + sample['view_ids'] = torch.from_numpy(np.array(view_ids)) + sample['affine_mats'] = torch.from_numpy(affine_mats.astype(np.float32)) # ! in world space + + # sample['light_idx'] = torch.tensor(light_idx) + sample['scan'] = folder_id + + sample['scale_factor'] = torch.tensor(scale_factor) + sample['img_wh'] = torch.from_numpy(np.array(img_wh)) + sample['render_img_idx'] = torch.tensor(image_perm) + sample['partial_vol_origin'] = self.partial_vol_origin + sample['meta'] = str(folder_id) + "_" + str(uid) + "_refview" + str(view_ids[0]) + + + # - image to render + sample['query_image'] = sample['images'][0] + sample['query_c2w'] = sample['c2ws'][0] + sample['query_w2c'] = sample['w2cs'][0] + sample['query_intrinsic'] = sample['intrinsics'][0] + sample['query_depth'] = sample['depths_h'][0] + sample['query_mask'] = sample['masks_h'][0] + sample['query_near_far'] = sample['near_fars'][0] + + sample['images'] = sample['images'][start_idx:] # (V, 3, H, W) + sample['depths_h'] = sample['depths_h'][start_idx:] # (V, H, W) + sample['masks_h'] = sample['masks_h'][start_idx:] # (V, H, W) + sample['w2cs'] = sample['w2cs'][start_idx:] # (V, 4, 4) + sample['c2ws'] = sample['c2ws'][start_idx:] # (V, 4, 4) + sample['intrinsics'] = sample['intrinsics'][start_idx:] # (V, 3, 3) + sample['view_ids'] = sample['view_ids'][start_idx:] + sample['affine_mats'] = sample['affine_mats'][start_idx:] # ! in world space + + sample['scale_mat'] = torch.from_numpy(scale_mat) + sample['trans_mat'] = torch.from_numpy(w2c_ref_inv) + + # - generate rays + if ('val' in self.split) or ('test' in self.split): + sample_rays = gen_rays_from_single_image( + img_wh[1], img_wh[0], + sample['query_image'], + sample['query_intrinsic'], + sample['query_c2w'], + depth=sample['query_depth'], + mask=sample['query_mask'] if self.clean_image else None) + else: + sample_rays = gen_random_rays_from_single_image( + img_wh[1], img_wh[0], + self.N_rays, + sample['query_image'], + sample['query_intrinsic'], + sample['query_c2w'], + depth=sample['query_depth'], + mask=sample['query_mask'] if self.clean_image else None, + dilated_mask=mask_dilated, + importance_sample=self.importance_sample) + + + sample['rays'] = sample_rays + + return sample diff --git a/SparseNeuS_demo_v1/data/blender_general_narrow_8_3_fixed.py b/SparseNeuS_demo_v1/data/blender_general_narrow_8_3_fixed.py new file mode 100644 index 0000000000000000000000000000000000000000..58c26348e73b44fdcb33bad81b1fddba66efeffc --- /dev/null +++ b/SparseNeuS_demo_v1/data/blender_general_narrow_8_3_fixed.py @@ -0,0 +1,393 @@ +from torch.utils.data import Dataset +from utils.misc_utils import read_pfm +import os +import numpy as np +import cv2 +from PIL import Image +import torch +from torchvision import transforms as T +from data.scene import get_boundingbox + +from models.rays import gen_rays_from_single_image, gen_random_rays_from_single_image +import json +from termcolor import colored +import imageio +from kornia import create_meshgrid +import open3d as o3d +def get_ray_directions(H, W, focal, center=None): + """ + Get ray directions for all pixels in camera coordinate. + Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/ + ray-tracing-generating-camera-rays/standard-coordinate-systems + Inputs: + H, W, focal: image height, width and focal length + Outputs: + directions: (H, W, 3), the direction of the rays in camera coordinate + """ + grid = create_meshgrid(H, W, normalized_coordinates=False)[0] + 0.5 # 1xHxWx2 + + i, j = grid.unbind(-1) + # the direction here is without +0.5 pixel centering as calibration is not so accurate + # see https://github.com/bmild/nerf/issues/24 + cent = center if center is not None else [W / 2, H / 2] + directions = torch.stack([(i - cent[0]) / focal[0], (j - cent[1]) / focal[1], torch.ones_like(i)], -1) # (H, W, 3) + + return directions + +def load_K_Rt_from_P(filename, P=None): + if P is None: + lines = open(filename).read().splitlines() + if len(lines) == 4: + lines = lines[1:] + lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)] + P = np.asarray(lines).astype(np.float32).squeeze() + + out = cv2.decomposeProjectionMatrix(P) + K = out[0] + R = out[1] + t = out[2] + + K = K / K[2, 2] + intrinsics = np.eye(4) + intrinsics[:3, :3] = K + + pose = np.eye(4, dtype=np.float32) + pose[:3, :3] = R.transpose() # ? why need transpose here + pose[:3, 3] = (t[:3] / t[3])[:, 0] + + return intrinsics, pose # ! return cam2world matrix here + + +# ! load one ref-image with multiple src-images in camera coordinate system +class BlenderPerView(Dataset): + def __init__(self, root_dir, split, n_views=3, img_wh=(256, 256), downSample=1.0, + split_filepath=None, pair_filepath=None, + N_rays=512, + vol_dims=[128, 128, 128], batch_size=1, + clean_image=False, importance_sample=False, test_ref_views=[]): + + # print("root_dir: ", root_dir) + self.root_dir = root_dir + self.split = split + + self.n_views = n_views + self.N_rays = N_rays + self.batch_size = batch_size # - used for construct new metas for gru fusion training + + self.clean_image = clean_image + self.importance_sample = importance_sample + self.test_ref_views = test_ref_views # used for testing + self.scale_factor = 1.0 + self.scale_mat = np.float32(np.diag([1, 1, 1, 1.0])) + + lvis_json_path = '/objaverse-processed/zero12345_img/lvis_split.json' # folder_id and uid + with open(lvis_json_path, 'r') as f: + lvis_paths = json.load(f) + if self.split == 'train': + self.lvis_paths = lvis_paths['train'] + else: + self.lvis_paths = lvis_paths['val'] + if img_wh is not None: + assert img_wh[0] % 32 == 0 and img_wh[1] % 32 == 0, \ + 'img_wh must both be multiples of 32!' + + + pose_json_path = "/objaverse-processed/zero12345_img/zero12345_narrow_pose.json" + with open(pose_json_path, 'r') as f: + meta = json.load(f) + + self.img_ids = list(meta["c2ws"].keys()) # e.g. "view_0", "view_7", "view_0_2_10" + self.img_wh = (256, 256) + self.input_poses = np.array(list(meta["c2ws"].values())) + intrinsic = np.eye(4) + intrinsic[:3, :3] = np.array(meta["intrinsics"]) + self.intrinsic = intrinsic + self.near_far = np.array(meta["near_far"]) + self.near_far[1] = 1.8 + self.define_transforms() + self.blender2opencv = np.array( + [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]] + ) + + + self.c2ws = [] + self.w2cs = [] + self.near_fars = [] + # self.root_dir = root_dir + for idx, img_id in enumerate(self.img_ids): + pose = self.input_poses[idx] + c2w = pose @ self.blender2opencv + self.c2ws.append(c2w) + self.w2cs.append(np.linalg.inv(c2w)) + self.near_fars.append(self.near_far) + self.c2ws = np.stack(self.c2ws, axis=0) + self.w2cs = np.stack(self.w2cs, axis=0) + + + self.all_intrinsics = [] # the cam info of the whole scene + self.all_extrinsics = [] + self.all_near_fars = [] + self.load_cam_info() + + # * bounding box for rendering + self.bbox_min = np.array([-1.0, -1.0, -1.0]) + self.bbox_max = np.array([1.0, 1.0, 1.0]) + + # - used for cost volume regularization + self.voxel_dims = torch.tensor(vol_dims, dtype=torch.float32) + self.partial_vol_origin = torch.tensor([-1., -1., -1.], dtype=torch.float32) + + + def define_transforms(self): + self.transform = T.Compose([T.ToTensor()]) + + + + def load_cam_info(self): + for vid, img_id in enumerate(self.img_ids): + intrinsic, extrinsic, near_far = self.intrinsic, np.linalg.inv(self.c2ws[vid]), self.near_far + self.all_intrinsics.append(intrinsic) + self.all_extrinsics.append(extrinsic) + self.all_near_fars.append(near_far) + + def read_depth(self, filename): + pass + + def read_mask(self, filename): + mask_h = cv2.imread(filename, 0) + mask_h = cv2.resize(mask_h, None, fx=self.downSample, fy=self.downSample, + interpolation=cv2.INTER_NEAREST) + mask = cv2.resize(mask_h, None, fx=0.25, fy=0.25, + interpolation=cv2.INTER_NEAREST) + + mask[mask > 0] = 1 # the masks stored in png are not binary + mask_h[mask_h > 0] = 1 + + return mask, mask_h + + def cal_scale_mat(self, img_hw, intrinsics, extrinsics, near_fars, factor=1.): + + center, radius, bounds = get_boundingbox(img_hw, intrinsics, extrinsics, near_fars) + # print("center", center) + # print("radius", radius) + # print("bounds", bounds) + # import ipdb; ipdb.set_trace() + radius = radius * factor + scale_mat = np.diag([radius, radius, radius, 1.0]) + scale_mat[:3, 3] = center.cpu().numpy() + scale_mat = scale_mat.astype(np.float32) + + return scale_mat, 1. / radius.cpu().numpy() + + def __len__(self): + return 8*len(self.lvis_paths) + + + def read_depth(self, filename, near_bound, noisy_factor=1.0): + pass + + + def __getitem__(self, idx): + sample = {} + origin_idx = idx + imgs, depths_h, masks_h = [], [], [] # full size (256, 256) + intrinsics, w2cs, c2ws, near_fars = [], [], [], [] # record proj mats between views + + + folder_uid_dict = self.lvis_paths[idx//8] + idx = idx % 8 # [0, 7] + folder_id = folder_uid_dict['folder_id'] + uid = folder_uid_dict['uid'] + + + # target view + c2w = self.c2ws[idx] + w2c = np.linalg.inv(c2w) + w2c_ref = w2c + w2c_ref_inv = np.linalg.inv(w2c_ref) + + w2cs.append(w2c @ w2c_ref_inv) + c2ws.append(np.linalg.inv(w2c @ w2c_ref_inv)) + + img_filename = os.path.join(self.root_dir, folder_id, uid, f'view_{idx}.png') + + depth_filename = os.path.join(os.path.join(self.root_dir, folder_id, uid, f'view_{idx}_depth_mm.png')) + + + img = Image.open(img_filename) + + img = self.transform(img) # (4, h, w) + + + if img.shape[0] == 4: + img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB + imgs += [img] + + depth_h = cv2.imread(depth_filename, cv2.IMREAD_UNCHANGED).astype(np.uint16) / 1000.0 + mask_h = depth_h > 0 + # print("valid pixels", np.sum(mask_h)) + directions = get_ray_directions(self.img_wh[1], self.img_wh[0], [self.intrinsic[0, 0], self.intrinsic[1, 1]]) # [H, W, 3] + surface_points = directions * depth_h[..., None] # [H, W, 3] + distance = np.linalg.norm(surface_points, axis=-1) # [H, W] + depth_h = distance + + + depths_h.append(depth_h) + masks_h.append(mask_h) + + intrinsic = self.intrinsic + intrinsics.append(intrinsic) + + + near_fars.append(self.near_fars[idx]) + image_perm = 0 # only supervised on reference view + + mask_dilated = None + + # src_views = range(8+idx*4, 8+(idx+1)*4) + src_views = list() + for i in range(8): + # randomly choose 3 different number from [0,3] + # local_idxs = np.random.choice(4, 3, replace=False) + local_idxs = [0, 2, 3] + # local_idxs = np.random.choice(4, 3, replace=False) + + local_idxs = [8 + i * 4 + local_idx for local_idx in local_idxs] + src_views += local_idxs + for vid in src_views: + img_filename = os.path.join(self.root_dir, folder_id, uid, f'view_{(vid - 8) // 4}_{vid%4}_10.png') + + img = Image.open(img_filename) + img_wh = self.img_wh + + img = self.transform(img) + if img.shape[0] == 4: + img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB + + imgs += [img] + depth_h = np.ones(img.shape[1:], dtype=np.float32) + depths_h.append(depth_h) + masks_h.append(np.ones(img.shape[1:], dtype=np.int32)) + + near_fars.append(self.all_near_fars[vid]) + intrinsics.append(self.all_intrinsics[vid]) + + w2cs.append(self.all_extrinsics[vid] @ w2c_ref_inv) + + # print("len(imgs)", len(imgs)) + # ! estimate scale_mat + scale_mat, scale_factor = self.cal_scale_mat( + img_hw=[img_wh[1], img_wh[0]], + intrinsics=intrinsics, extrinsics=w2cs, + near_fars=near_fars, factor=1.1 + ) + + + new_near_fars = [] + new_w2cs = [] + new_c2ws = [] + new_affine_mats = [] + new_depths_h = [] + for intrinsic, extrinsic, near_far, depth in zip(intrinsics, w2cs, near_fars, depths_h): + + P = intrinsic @ extrinsic @ scale_mat + P = P[:3, :4] + # - should use load_K_Rt_from_P() to obtain c2w + c2w = load_K_Rt_from_P(None, P)[1] + w2c = np.linalg.inv(c2w) + new_w2cs.append(w2c) + new_c2ws.append(c2w) + affine_mat = np.eye(4) + affine_mat[:3, :4] = intrinsic[:3, :3] @ w2c[:3, :4] + new_affine_mats.append(affine_mat) + + camera_o = c2w[:3, 3] + dist = np.sqrt(np.sum(camera_o ** 2)) + near = dist - 1 + far = dist + 1 + + new_near_fars.append([0.95 * near, 1.05 * far]) + new_depths_h.append(depth * scale_factor) + + # print(new_near_fars) + imgs = torch.stack(imgs).float() + depths_h = np.stack(new_depths_h) + masks_h = np.stack(masks_h) + + affine_mats = np.stack(new_affine_mats) + intrinsics, w2cs, c2ws, near_fars = np.stack(intrinsics), np.stack(new_w2cs), np.stack(new_c2ws), np.stack( + new_near_fars) + + if self.split == 'train': + start_idx = 0 + else: + start_idx = 1 + + view_ids = [idx] + list(src_views) + sample['origin_idx'] = origin_idx + sample['images'] = imgs # (V, 3, H, W) + sample['depths_h'] = torch.from_numpy(depths_h.astype(np.float32)) # (V, H, W) + sample['masks_h'] = torch.from_numpy(masks_h.astype(np.float32)) # (V, H, W) + sample['w2cs'] = torch.from_numpy(w2cs.astype(np.float32)) # (V, 4, 4) + sample['c2ws'] = torch.from_numpy(c2ws.astype(np.float32)) # (V, 4, 4) + sample['near_fars'] = torch.from_numpy(near_fars.astype(np.float32)) # (V, 2) + sample['intrinsics'] = torch.from_numpy(intrinsics.astype(np.float32))[:, :3, :3] # (V, 3, 3) + sample['view_ids'] = torch.from_numpy(np.array(view_ids)) + sample['affine_mats'] = torch.from_numpy(affine_mats.astype(np.float32)) # ! in world space + + # sample['light_idx'] = torch.tensor(light_idx) + sample['scan'] = folder_id + + sample['scale_factor'] = torch.tensor(scale_factor) + sample['img_wh'] = torch.from_numpy(np.array(img_wh)) + sample['render_img_idx'] = torch.tensor(image_perm) + sample['partial_vol_origin'] = self.partial_vol_origin + sample['meta'] = str(folder_id) + "_" + str(uid) + "_refview" + str(view_ids[0]) + + + # - image to render + sample['query_image'] = sample['images'][0] + sample['query_c2w'] = sample['c2ws'][0] + sample['query_w2c'] = sample['w2cs'][0] + sample['query_intrinsic'] = sample['intrinsics'][0] + sample['query_depth'] = sample['depths_h'][0] + sample['query_mask'] = sample['masks_h'][0] + sample['query_near_far'] = sample['near_fars'][0] + + sample['images'] = sample['images'][start_idx:] # (V, 3, H, W) + sample['depths_h'] = sample['depths_h'][start_idx:] # (V, H, W) + sample['masks_h'] = sample['masks_h'][start_idx:] # (V, H, W) + sample['w2cs'] = sample['w2cs'][start_idx:] # (V, 4, 4) + sample['c2ws'] = sample['c2ws'][start_idx:] # (V, 4, 4) + sample['intrinsics'] = sample['intrinsics'][start_idx:] # (V, 3, 3) + sample['view_ids'] = sample['view_ids'][start_idx:] + sample['affine_mats'] = sample['affine_mats'][start_idx:] # ! in world space + + sample['scale_mat'] = torch.from_numpy(scale_mat) + sample['trans_mat'] = torch.from_numpy(w2c_ref_inv) + + # - generate rays + if ('val' in self.split) or ('test' in self.split): + sample_rays = gen_rays_from_single_image( + img_wh[1], img_wh[0], + sample['query_image'], + sample['query_intrinsic'], + sample['query_c2w'], + depth=sample['query_depth'], + mask=sample['query_mask'] if self.clean_image else None) + else: + sample_rays = gen_random_rays_from_single_image( + img_wh[1], img_wh[0], + self.N_rays, + sample['query_image'], + sample['query_intrinsic'], + sample['query_c2w'], + depth=sample['query_depth'], + mask=sample['query_mask'] if self.clean_image else None, + dilated_mask=mask_dilated, + importance_sample=self.importance_sample) + + + sample['rays'] = sample_rays + + return sample diff --git a/SparseNeuS_demo_v1/data/blender_general_narrow_8_3_random.py b/SparseNeuS_demo_v1/data/blender_general_narrow_8_3_random.py new file mode 100644 index 0000000000000000000000000000000000000000..b52542595e8d39dff91f18e63a0b504c4c4d2d48 --- /dev/null +++ b/SparseNeuS_demo_v1/data/blender_general_narrow_8_3_random.py @@ -0,0 +1,395 @@ +from torch.utils.data import Dataset +from utils.misc_utils import read_pfm +import os +import numpy as np +import cv2 +from PIL import Image +import torch +from torchvision import transforms as T +from data.scene import get_boundingbox + +from models.rays import gen_rays_from_single_image, gen_random_rays_from_single_image +import json +from termcolor import colored +import imageio +from kornia import create_meshgrid +import open3d as o3d +def get_ray_directions(H, W, focal, center=None): + """ + Get ray directions for all pixels in camera coordinate. + Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/ + ray-tracing-generating-camera-rays/standard-coordinate-systems + Inputs: + H, W, focal: image height, width and focal length + Outputs: + directions: (H, W, 3), the direction of the rays in camera coordinate + """ + grid = create_meshgrid(H, W, normalized_coordinates=False)[0] + 0.5 # 1xHxWx2 + + i, j = grid.unbind(-1) + # the direction here is without +0.5 pixel centering as calibration is not so accurate + # see https://github.com/bmild/nerf/issues/24 + cent = center if center is not None else [W / 2, H / 2] + directions = torch.stack([(i - cent[0]) / focal[0], (j - cent[1]) / focal[1], torch.ones_like(i)], -1) # (H, W, 3) + + return directions + +def load_K_Rt_from_P(filename, P=None): + if P is None: + lines = open(filename).read().splitlines() + if len(lines) == 4: + lines = lines[1:] + lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)] + P = np.asarray(lines).astype(np.float32).squeeze() + + out = cv2.decomposeProjectionMatrix(P) + K = out[0] + R = out[1] + t = out[2] + + K = K / K[2, 2] + intrinsics = np.eye(4) + intrinsics[:3, :3] = K + + pose = np.eye(4, dtype=np.float32) + pose[:3, :3] = R.transpose() # ? why need transpose here + pose[:3, 3] = (t[:3] / t[3])[:, 0] + + return intrinsics, pose # ! return cam2world matrix here + + +# ! load one ref-image with multiple src-images in camera coordinate system +class BlenderPerView(Dataset): + def __init__(self, root_dir, split, n_views=3, img_wh=(256, 256), downSample=1.0, + split_filepath=None, pair_filepath=None, + N_rays=512, + vol_dims=[128, 128, 128], batch_size=1, + clean_image=False, importance_sample=False, test_ref_views=[]): + + # print("root_dir: ", root_dir) + self.root_dir = root_dir + self.split = split + + self.n_views = n_views + self.N_rays = N_rays + self.batch_size = batch_size # - used for construct new metas for gru fusion training + + self.clean_image = clean_image + self.importance_sample = importance_sample + self.test_ref_views = test_ref_views # used for testing + self.scale_factor = 1.0 + self.scale_mat = np.float32(np.diag([1, 1, 1, 1.0])) + + lvis_json_path = '/objaverse-processed/zero12345_img/lvis_split.json' # folder_id and uid + with open(lvis_json_path, 'r') as f: + lvis_paths = json.load(f) + if self.split == 'train': + self.lvis_paths = lvis_paths['train'] + else: + self.lvis_paths = lvis_paths['val'] + if img_wh is not None: + assert img_wh[0] % 32 == 0 and img_wh[1] % 32 == 0, \ + 'img_wh must both be multiples of 32!' + + + pose_json_path = "/objaverse-processed/zero12345_img/zero12345_narrow_pose.json" + with open(pose_json_path, 'r') as f: + meta = json.load(f) + + self.img_ids = list(meta["c2ws"].keys()) # e.g. "view_0", "view_7", "view_0_2_10" + self.img_wh = (256, 256) + self.input_poses = np.array(list(meta["c2ws"].values())) + intrinsic = np.eye(4) + intrinsic[:3, :3] = np.array(meta["intrinsics"]) + self.intrinsic = intrinsic + self.near_far = np.array(meta["near_far"]) + self.near_far[1] = 1.8 + self.define_transforms() + self.blender2opencv = np.array( + [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]] + ) + + + self.c2ws = [] + self.w2cs = [] + self.near_fars = [] + # self.root_dir = root_dir + for idx, img_id in enumerate(self.img_ids): + pose = self.input_poses[idx] + c2w = pose @ self.blender2opencv + self.c2ws.append(c2w) + self.w2cs.append(np.linalg.inv(c2w)) + self.near_fars.append(self.near_far) + self.c2ws = np.stack(self.c2ws, axis=0) + self.w2cs = np.stack(self.w2cs, axis=0) + + + self.all_intrinsics = [] # the cam info of the whole scene + self.all_extrinsics = [] + self.all_near_fars = [] + self.load_cam_info() + + # * bounding box for rendering + self.bbox_min = np.array([-1.0, -1.0, -1.0]) + self.bbox_max = np.array([1.0, 1.0, 1.0]) + + # - used for cost volume regularization + self.voxel_dims = torch.tensor(vol_dims, dtype=torch.float32) + self.partial_vol_origin = torch.tensor([-1., -1., -1.], dtype=torch.float32) + + + def define_transforms(self): + self.transform = T.Compose([T.ToTensor()]) + + + + def load_cam_info(self): + for vid, img_id in enumerate(self.img_ids): + intrinsic, extrinsic, near_far = self.intrinsic, np.linalg.inv(self.c2ws[vid]), self.near_far + self.all_intrinsics.append(intrinsic) + self.all_extrinsics.append(extrinsic) + self.all_near_fars.append(near_far) + + def read_depth(self, filename): + pass + + def read_mask(self, filename): + mask_h = cv2.imread(filename, 0) + mask_h = cv2.resize(mask_h, None, fx=self.downSample, fy=self.downSample, + interpolation=cv2.INTER_NEAREST) + mask = cv2.resize(mask_h, None, fx=0.25, fy=0.25, + interpolation=cv2.INTER_NEAREST) + + mask[mask > 0] = 1 # the masks stored in png are not binary + mask_h[mask_h > 0] = 1 + + return mask, mask_h + + def cal_scale_mat(self, img_hw, intrinsics, extrinsics, near_fars, factor=1.): + + center, radius, bounds = get_boundingbox(img_hw, intrinsics, extrinsics, near_fars) + # print("center", center) + # print("radius", radius) + # print("bounds", bounds) + # import ipdb; ipdb.set_trace() + radius = radius * factor + scale_mat = np.diag([radius, radius, radius, 1.0]) + scale_mat[:3, 3] = center.cpu().numpy() + scale_mat = scale_mat.astype(np.float32) + + return scale_mat, 1. / radius.cpu().numpy() + + def __len__(self): + return 8*len(self.lvis_paths) + + + def read_depth(self, filename, near_bound, noisy_factor=1.0): + pass + + + def __getitem__(self, idx): + sample = {} + origin_idx = idx + imgs, depths_h, masks_h = [], [], [] # full size (256, 256) + intrinsics, w2cs, c2ws, near_fars = [], [], [], [] # record proj mats between views + + + folder_uid_dict = self.lvis_paths[idx//8] + idx = idx % 8 # [0, 7] + folder_id = folder_uid_dict['folder_id'] + uid = folder_uid_dict['uid'] + + + # target view + c2w = self.c2ws[idx] + w2c = np.linalg.inv(c2w) + w2c_ref = w2c + w2c_ref_inv = np.linalg.inv(w2c_ref) + + w2cs.append(w2c @ w2c_ref_inv) + c2ws.append(np.linalg.inv(w2c @ w2c_ref_inv)) + + img_filename = os.path.join(self.root_dir, folder_id, uid, f'view_{idx}.png') + + depth_filename = os.path.join(os.path.join(self.root_dir, folder_id, uid, f'view_{idx}_depth_mm.png')) + + + img = Image.open(img_filename) + + img = self.transform(img) # (4, h, w) + + + if img.shape[0] == 4: + img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB + imgs += [img] + + depth_h = cv2.imread(depth_filename, cv2.IMREAD_UNCHANGED).astype(np.uint16) / 1000.0 + mask_h = depth_h > 0 + # print("valid pixels", np.sum(mask_h)) + directions = get_ray_directions(self.img_wh[1], self.img_wh[0], [self.intrinsic[0, 0], self.intrinsic[1, 1]]) # [H, W, 3] + surface_points = directions * depth_h[..., None] # [H, W, 3] + distance = np.linalg.norm(surface_points, axis=-1) # [H, W] + depth_h = distance + + + depths_h.append(depth_h) + masks_h.append(mask_h) + + intrinsic = self.intrinsic + intrinsics.append(intrinsic) + + + near_fars.append(self.near_fars[idx]) + image_perm = 0 # only supervised on reference view + + mask_dilated = None + + # src_views = range(8+idx*4, 8+(idx+1)*4) + src_views = list() + for i in range(8): + + if self.split == 'train': + local_idxs = np.random.choice(4, 3, replace=False) + else: + local_idxs = [0, 2, 3] + # local_idxs = np.random.choice(4, 3, replace=False) + + local_idxs = [8 + i * 4 + local_idx for local_idx in local_idxs] + src_views += local_idxs + for vid in src_views: + img_filename = os.path.join(self.root_dir, folder_id, uid, f'view_{(vid - 8) // 4}_{vid%4}_10.png') + + img = Image.open(img_filename) + img_wh = self.img_wh + + img = self.transform(img) + if img.shape[0] == 4: + img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB + + imgs += [img] + depth_h = np.ones(img.shape[1:], dtype=np.float32) + depths_h.append(depth_h) + masks_h.append(np.ones(img.shape[1:], dtype=np.int32)) + + near_fars.append(self.all_near_fars[vid]) + intrinsics.append(self.all_intrinsics[vid]) + + w2cs.append(self.all_extrinsics[vid] @ w2c_ref_inv) + + # print("len(imgs)", len(imgs)) + # ! estimate scale_mat + scale_mat, scale_factor = self.cal_scale_mat( + img_hw=[img_wh[1], img_wh[0]], + intrinsics=intrinsics, extrinsics=w2cs, + near_fars=near_fars, factor=1.1 + ) + + + new_near_fars = [] + new_w2cs = [] + new_c2ws = [] + new_affine_mats = [] + new_depths_h = [] + for intrinsic, extrinsic, near_far, depth in zip(intrinsics, w2cs, near_fars, depths_h): + + P = intrinsic @ extrinsic @ scale_mat + P = P[:3, :4] + # - should use load_K_Rt_from_P() to obtain c2w + c2w = load_K_Rt_from_P(None, P)[1] + w2c = np.linalg.inv(c2w) + new_w2cs.append(w2c) + new_c2ws.append(c2w) + affine_mat = np.eye(4) + affine_mat[:3, :4] = intrinsic[:3, :3] @ w2c[:3, :4] + new_affine_mats.append(affine_mat) + + camera_o = c2w[:3, 3] + dist = np.sqrt(np.sum(camera_o ** 2)) + near = dist - 1 + far = dist + 1 + + new_near_fars.append([0.95 * near, 1.05 * far]) + new_depths_h.append(depth * scale_factor) + + # print(new_near_fars) + imgs = torch.stack(imgs).float() + depths_h = np.stack(new_depths_h) + masks_h = np.stack(masks_h) + + affine_mats = np.stack(new_affine_mats) + intrinsics, w2cs, c2ws, near_fars = np.stack(intrinsics), np.stack(new_w2cs), np.stack(new_c2ws), np.stack( + new_near_fars) + + if self.split == 'train': + start_idx = 0 + else: + start_idx = 1 + + view_ids = [idx] + list(src_views) + sample['origin_idx'] = origin_idx + sample['images'] = imgs # (V, 3, H, W) + sample['depths_h'] = torch.from_numpy(depths_h.astype(np.float32)) # (V, H, W) + sample['masks_h'] = torch.from_numpy(masks_h.astype(np.float32)) # (V, H, W) + sample['w2cs'] = torch.from_numpy(w2cs.astype(np.float32)) # (V, 4, 4) + sample['c2ws'] = torch.from_numpy(c2ws.astype(np.float32)) # (V, 4, 4) + sample['near_fars'] = torch.from_numpy(near_fars.astype(np.float32)) # (V, 2) + sample['intrinsics'] = torch.from_numpy(intrinsics.astype(np.float32))[:, :3, :3] # (V, 3, 3) + sample['view_ids'] = torch.from_numpy(np.array(view_ids)) + sample['affine_mats'] = torch.from_numpy(affine_mats.astype(np.float32)) # ! in world space + + # sample['light_idx'] = torch.tensor(light_idx) + sample['scan'] = folder_id + + sample['scale_factor'] = torch.tensor(scale_factor) + sample['img_wh'] = torch.from_numpy(np.array(img_wh)) + sample['render_img_idx'] = torch.tensor(image_perm) + sample['partial_vol_origin'] = self.partial_vol_origin + sample['meta'] = str(folder_id) + "_" + str(uid) + "_refview" + str(view_ids[0]) + + + # - image to render + sample['query_image'] = sample['images'][0] + sample['query_c2w'] = sample['c2ws'][0] + sample['query_w2c'] = sample['w2cs'][0] + sample['query_intrinsic'] = sample['intrinsics'][0] + sample['query_depth'] = sample['depths_h'][0] + sample['query_mask'] = sample['masks_h'][0] + sample['query_near_far'] = sample['near_fars'][0] + + sample['images'] = sample['images'][start_idx:] # (V, 3, H, W) + sample['depths_h'] = sample['depths_h'][start_idx:] # (V, H, W) + sample['masks_h'] = sample['masks_h'][start_idx:] # (V, H, W) + sample['w2cs'] = sample['w2cs'][start_idx:] # (V, 4, 4) + sample['c2ws'] = sample['c2ws'][start_idx:] # (V, 4, 4) + sample['intrinsics'] = sample['intrinsics'][start_idx:] # (V, 3, 3) + sample['view_ids'] = sample['view_ids'][start_idx:] + sample['affine_mats'] = sample['affine_mats'][start_idx:] # ! in world space + + sample['scale_mat'] = torch.from_numpy(scale_mat) + sample['trans_mat'] = torch.from_numpy(w2c_ref_inv) + + # - generate rays + if ('val' in self.split) or ('test' in self.split): + sample_rays = gen_rays_from_single_image( + img_wh[1], img_wh[0], + sample['query_image'], + sample['query_intrinsic'], + sample['query_c2w'], + depth=sample['query_depth'], + mask=sample['query_mask'] if self.clean_image else None) + else: + sample_rays = gen_random_rays_from_single_image( + img_wh[1], img_wh[0], + self.N_rays, + sample['query_image'], + sample['query_intrinsic'], + sample['query_c2w'], + depth=sample['query_depth'], + mask=sample['query_mask'] if self.clean_image else None, + dilated_mask=mask_dilated, + importance_sample=self.importance_sample) + + + sample['rays'] = sample_rays + + return sample diff --git a/SparseNeuS_demo_v1/data/blender_general_narrow_8_4_random_shading.py b/SparseNeuS_demo_v1/data/blender_general_narrow_8_4_random_shading.py new file mode 100644 index 0000000000000000000000000000000000000000..e120367ce96847e9fb60b2ae038a812583fe75e3 --- /dev/null +++ b/SparseNeuS_demo_v1/data/blender_general_narrow_8_4_random_shading.py @@ -0,0 +1,432 @@ +from torch.utils.data import Dataset +from utils.misc_utils import read_pfm +import os +import numpy as np +import cv2 +from PIL import Image +import torch +from torchvision import transforms as T +from data.scene import get_boundingbox + +from models.rays import gen_rays_from_single_image, gen_random_rays_from_single_image +import json +from termcolor import colored +import imageio +from kornia import create_meshgrid +import open3d as o3d +def get_ray_directions(H, W, focal, center=None): + """ + Get ray directions for all pixels in camera coordinate. + Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/ + ray-tracing-generating-camera-rays/standard-coordinate-systems + Inputs: + H, W, focal: image height, width and focal length + Outputs: + directions: (H, W, 3), the direction of the rays in camera coordinate + """ + grid = create_meshgrid(H, W, normalized_coordinates=False)[0] + 0.5 # 1xHxWx2 + + i, j = grid.unbind(-1) + # the direction here is without +0.5 pixel centering as calibration is not so accurate + # see https://github.com/bmild/nerf/issues/24 + cent = center if center is not None else [W / 2, H / 2] + directions = torch.stack([(i - cent[0]) / focal[0], (j - cent[1]) / focal[1], torch.ones_like(i)], -1) # (H, W, 3) + + return directions + +def load_K_Rt_from_P(filename, P=None): + if P is None: + lines = open(filename).read().splitlines() + if len(lines) == 4: + lines = lines[1:] + lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)] + P = np.asarray(lines).astype(np.float32).squeeze() + + out = cv2.decomposeProjectionMatrix(P) + K = out[0] + R = out[1] + t = out[2] + + K = K / K[2, 2] + intrinsics = np.eye(4) + intrinsics[:3, :3] = K + + pose = np.eye(4, dtype=np.float32) + pose[:3, :3] = R.transpose() # ? why need transpose here + pose[:3, 3] = (t[:3] / t[3])[:, 0] + + return intrinsics, pose # ! return cam2world matrix here + + +# ! load one ref-image with multiple src-images in camera coordinate system +class BlenderPerView(Dataset): + def __init__(self, root_dir, split, n_views=3, img_wh=(256, 256), downSample=1.0, + split_filepath=None, pair_filepath=None, + N_rays=512, + vol_dims=[128, 128, 128], batch_size=1, + clean_image=False, importance_sample=False, test_ref_views=[]): + + # print("root_dir: ", root_dir) + self.root_dir = root_dir + self.split = split + + self.n_views = n_views + self.N_rays = N_rays + self.batch_size = batch_size # - used for construct new metas for gru fusion training + + self.clean_image = clean_image + self.importance_sample = importance_sample + self.test_ref_views = test_ref_views # used for testing + self.scale_factor = 1.0 + self.scale_mat = np.float32(np.diag([1, 1, 1, 1.0])) + + lvis_json_path = '/objaverse-processed/zero12345_img/lvis_split.json' # folder_id and uid + with open(lvis_json_path, 'r') as f: + lvis_paths = json.load(f) + if self.split == 'train': + self.lvis_paths = lvis_paths['train'] + else: + self.lvis_paths = lvis_paths['val'] + if img_wh is not None: + assert img_wh[0] % 32 == 0 and img_wh[1] % 32 == 0, \ + 'img_wh must both be multiples of 32!' + + + pose_json_path = "/objaverse-processed/zero12345_img/zero12345_narrow_pose.json" + with open(pose_json_path, 'r') as f: + meta = json.load(f) + + self.img_ids = list(meta["c2ws"].keys()) # e.g. "view_0", "view_7", "view_0_2_10" + self.img_wh = (256, 256) + self.input_poses = np.array(list(meta["c2ws"].values())) + intrinsic = np.eye(4) + intrinsic[:3, :3] = np.array(meta["intrinsics"]) + self.intrinsic = intrinsic + self.near_far = np.array(meta["near_far"]) + self.near_far[1] = 1.8 + self.define_transforms() + self.blender2opencv = np.array( + [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]] + ) + + + self.c2ws = [] + self.w2cs = [] + self.near_fars = [] + # self.root_dir = root_dir + for idx, img_id in enumerate(self.img_ids): + pose = self.input_poses[idx] + c2w = pose @ self.blender2opencv + self.c2ws.append(c2w) + self.w2cs.append(np.linalg.inv(c2w)) + self.near_fars.append(self.near_far) + self.c2ws = np.stack(self.c2ws, axis=0) + self.w2cs = np.stack(self.w2cs, axis=0) + + + self.all_intrinsics = [] # the cam info of the whole scene + self.all_extrinsics = [] + self.all_near_fars = [] + self.load_cam_info() + + # * bounding box for rendering + self.bbox_min = np.array([-1.0, -1.0, -1.0]) + self.bbox_max = np.array([1.0, 1.0, 1.0]) + + # - used for cost volume regularization + self.voxel_dims = torch.tensor(vol_dims, dtype=torch.float32) + self.partial_vol_origin = torch.tensor([-1., -1., -1.], dtype=torch.float32) + + + def define_transforms(self): + self.transform = T.Compose([T.ToTensor()]) + + + + def load_cam_info(self): + for vid, img_id in enumerate(self.img_ids): + intrinsic, extrinsic, near_far = self.intrinsic, np.linalg.inv(self.c2ws[vid]), self.near_far + self.all_intrinsics.append(intrinsic) + self.all_extrinsics.append(extrinsic) + self.all_near_fars.append(near_far) + + def read_depth(self, filename): + pass + + def read_mask(self, filename): + mask_h = cv2.imread(filename, 0) + mask_h = cv2.resize(mask_h, None, fx=self.downSample, fy=self.downSample, + interpolation=cv2.INTER_NEAREST) + mask = cv2.resize(mask_h, None, fx=0.25, fy=0.25, + interpolation=cv2.INTER_NEAREST) + + mask[mask > 0] = 1 # the masks stored in png are not binary + mask_h[mask_h > 0] = 1 + + return mask, mask_h + + def cal_scale_mat(self, img_hw, intrinsics, extrinsics, near_fars, factor=1.): + + center, radius, bounds = get_boundingbox(img_hw, intrinsics, extrinsics, near_fars) + # print("center", center) + # print("radius", radius) + # print("bounds", bounds) + # import ipdb; ipdb.set_trace() + radius = radius * factor + scale_mat = np.diag([radius, radius, radius, 1.0]) + scale_mat[:3, 3] = center.cpu().numpy() + scale_mat = scale_mat.astype(np.float32) + + return scale_mat, 1. / radius.cpu().numpy() + + def __len__(self): + return 8*len(self.lvis_paths) + + + def read_depth(self, filename, near_bound, noisy_factor=1.0): + pass + + + def __getitem__(self, idx): + sample = {} + origin_idx = idx + imgs, depths_h, masks_h = [], [], [] # full size (256, 256) + intrinsics, w2cs, c2ws, near_fars = [], [], [], [] # record proj mats between views + + + folder_uid_dict = self.lvis_paths[idx//8] + idx = idx % 8 # [0, 7] + folder_id = folder_uid_dict['folder_id'] + uid = folder_uid_dict['uid'] + + + # target view + c2w = self.c2ws[idx] + w2c = np.linalg.inv(c2w) + w2c_ref = w2c + w2c_ref_inv = np.linalg.inv(w2c_ref) + + w2cs.append(w2c @ w2c_ref_inv) + c2ws.append(np.linalg.inv(w2c @ w2c_ref_inv)) + + img_filename = os.path.join(self.root_dir, folder_id, uid, f'view_{idx}.png') + + depth_filename = os.path.join(os.path.join(self.root_dir, folder_id, uid, f'view_{idx}_depth_mm.png')) + + + img = Image.open(img_filename) + + img = self.transform(img) # (4, h, w) + + + if img.shape[0] == 4: + img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB + imgs += [img] + + depth_h = cv2.imread(depth_filename, cv2.IMREAD_UNCHANGED).astype(np.uint16) / 1000.0 + mask_h = depth_h > 0 + # print("valid pixels", np.sum(mask_h)) + directions = get_ray_directions(self.img_wh[1], self.img_wh[0], [self.intrinsic[0, 0], self.intrinsic[1, 1]]) # [H, W, 3] + surface_points = directions * depth_h[..., None] # [H, W, 3] + distance = np.linalg.norm(surface_points, axis=-1) # [H, W] + depth_h = distance + + + depths_h.append(depth_h) + masks_h.append(mask_h) + + intrinsic = self.intrinsic + intrinsics.append(intrinsic) + + + near_fars.append(self.near_fars[idx]) + image_perm = 0 # only supervised on reference view + + mask_dilated = None + + # src_views = range(8+idx*4, 8+(idx+1)*4) + src_views = range(8, 8 + 8 * 4) + + for vid in src_views: + img_filename = os.path.join(self.root_dir, folder_id, uid, f'view_{(vid - 8) // 4}_{vid%4}_10.png') + + img = Image.open(img_filename) + img_wh = self.img_wh + + img = self.transform(img) + if img.shape[0] == 4: + img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB + + imgs += [img] + depth_h = np.ones(img.shape[1:], dtype=np.float32) + depths_h.append(depth_h) + masks_h.append(np.ones(img.shape[1:], dtype=np.int32)) + + near_fars.append(self.all_near_fars[vid]) + intrinsics.append(self.all_intrinsics[vid]) + + w2cs.append(self.all_extrinsics[vid] @ w2c_ref_inv) + + + # ! estimate scale_mat + scale_mat, scale_factor = self.cal_scale_mat( + img_hw=[img_wh[1], img_wh[0]], + intrinsics=intrinsics, extrinsics=w2cs, + near_fars=near_fars, factor=1.1 + ) + + + new_near_fars = [] + new_w2cs = [] + new_c2ws = [] + new_affine_mats = [] + new_depths_h = [] + for intrinsic, extrinsic, near_far, depth in zip(intrinsics, w2cs, near_fars, depths_h): + + P = intrinsic @ extrinsic @ scale_mat + P = P[:3, :4] + # - should use load_K_Rt_from_P() to obtain c2w + c2w = load_K_Rt_from_P(None, P)[1] + w2c = np.linalg.inv(c2w) + new_w2cs.append(w2c) + new_c2ws.append(c2w) + affine_mat = np.eye(4) + affine_mat[:3, :4] = intrinsic[:3, :3] @ w2c[:3, :4] + new_affine_mats.append(affine_mat) + + camera_o = c2w[:3, 3] + dist = np.sqrt(np.sum(camera_o ** 2)) + near = dist - 1 + far = dist + 1 + + new_near_fars.append([0.95 * near, 1.05 * far]) + new_depths_h.append(depth * scale_factor) + + if self.split == 'train': + # randomly select one view from eight views as reference view + idx_to_select = np.random.randint(0, 8) + + img_filename = os.path.join(self.root_dir, folder_id, uid, f'view_{idx_to_select}.png') + img = Image.open(img_filename) + img = self.transform(img) # (4, h, w) + + if img.shape[0] == 4: + img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB + + imgs[0] = img + + w2c_selected = self.all_extrinsics[idx_to_select] @ w2c_ref_inv + P = self.all_intrinsics[idx_to_select] @ w2c_selected @ scale_mat + P = P[:3, :4] + + c2w = load_K_Rt_from_P(None, P)[1] + w2c = np.linalg.inv(c2w) + affine_mat = np.eye(4) + affine_mat[:3, :4] = self.all_intrinsics[idx_to_select][:3, :3] @ w2c[:3, :4] + new_affine_mats[0] = affine_mat + camera_o = c2w[:3, 3] + dist = np.sqrt(np.sum(camera_o ** 2)) + near = dist - 1 + far = dist + 1 + new_near_fars[0] = [0.95 * near, 1.05 * far] + + new_w2cs[0] = w2c + new_c2ws[0] = c2w + + depth_filename = os.path.join(os.path.join(self.root_dir, folder_id, uid, f'view_{idx_to_select}_depth_mm.png')) + depth_h = cv2.imread(depth_filename, cv2.IMREAD_UNCHANGED).astype(np.uint16) / 1000.0 + mask_h = depth_h > 0 + directions = get_ray_directions(self.img_wh[1], self.img_wh[0], [self.intrinsic[0, 0], self.intrinsic[1, 1]]) # [H, W, 3] + surface_points = directions * depth_h[..., None] # [H, W, 3] + distance = np.linalg.norm(surface_points, axis=-1) # [H, W] + depth_h = distance * scale_factor + + new_depths_h[0] = depth_h + masks_h[0] = mask_h + + + + # print(new_near_fars) + imgs = torch.stack(imgs).float() + depths_h = np.stack(new_depths_h) + masks_h = np.stack(masks_h) + + affine_mats = np.stack(new_affine_mats) + intrinsics, w2cs, c2ws, near_fars = np.stack(intrinsics), np.stack(new_w2cs), np.stack(new_c2ws), np.stack( + new_near_fars) + + if self.split == 'train': + start_idx = 0 + else: + start_idx = 1 + + + view_ids = [idx] + list(src_views) + sample['origin_idx'] = origin_idx + sample['images'] = imgs # (V, 3, H, W) + sample['depths_h'] = torch.from_numpy(depths_h.astype(np.float32)) # (V, H, W) + sample['masks_h'] = torch.from_numpy(masks_h.astype(np.float32)) # (V, H, W) + sample['w2cs'] = torch.from_numpy(w2cs.astype(np.float32)) # (V, 4, 4) + sample['c2ws'] = torch.from_numpy(c2ws.astype(np.float32)) # (V, 4, 4) + sample['near_fars'] = torch.from_numpy(near_fars.astype(np.float32)) # (V, 2) + sample['intrinsics'] = torch.from_numpy(intrinsics.astype(np.float32))[:, :3, :3] # (V, 3, 3) + sample['view_ids'] = torch.from_numpy(np.array(view_ids)) + sample['affine_mats'] = torch.from_numpy(affine_mats.astype(np.float32)) # ! in world space + + # sample['light_idx'] = torch.tensor(light_idx) + sample['scan'] = folder_id + + sample['scale_factor'] = torch.tensor(scale_factor) + sample['img_wh'] = torch.from_numpy(np.array(img_wh)) + sample['render_img_idx'] = torch.tensor(image_perm) + sample['partial_vol_origin'] = self.partial_vol_origin + sample['meta'] = str(folder_id) + "_" + str(uid) + "_refview" + str(view_ids[0]) + + + # - image to render + sample['query_image'] = sample['images'][0] + sample['query_c2w'] = sample['c2ws'][0] + sample['query_w2c'] = sample['w2cs'][0] + sample['query_intrinsic'] = sample['intrinsics'][0] + sample['query_depth'] = sample['depths_h'][0] + sample['query_mask'] = sample['masks_h'][0] + sample['query_near_far'] = sample['near_fars'][0] + + + sample['images'] = sample['images'][start_idx:] # (V, 3, H, W) + sample['depths_h'] = sample['depths_h'][start_idx:] # (V, H, W) + sample['masks_h'] = sample['masks_h'][start_idx:] # (V, H, W) + sample['w2cs'] = sample['w2cs'][start_idx:] # (V, 4, 4) + sample['c2ws'] = sample['c2ws'][start_idx:] # (V, 4, 4) + sample['intrinsics'] = sample['intrinsics'][start_idx:] # (V, 3, 3) + sample['view_ids'] = sample['view_ids'][start_idx:] + sample['affine_mats'] = sample['affine_mats'][start_idx:] # ! in world space + + sample['scale_mat'] = torch.from_numpy(scale_mat) + sample['trans_mat'] = torch.from_numpy(w2c_ref_inv) + + # - generate rays + if ('val' in self.split) or ('test' in self.split): + sample_rays = gen_rays_from_single_image( + img_wh[1], img_wh[0], + sample['query_image'], + sample['query_intrinsic'], + sample['query_c2w'], + depth=sample['query_depth'], + mask=sample['query_mask'] if self.clean_image else None) + else: + sample_rays = gen_random_rays_from_single_image( + img_wh[1], img_wh[0], + self.N_rays, + sample['query_image'], + sample['query_intrinsic'], + sample['query_c2w'], + depth=sample['query_depth'], + mask=sample['query_mask'] if self.clean_image else None, + dilated_mask=mask_dilated, + importance_sample=self.importance_sample) + + + sample['rays'] = sample_rays + + return sample diff --git a/SparseNeuS_demo_v1/data/blender_general_narrow_all.py b/SparseNeuS_demo_v1/data/blender_general_narrow_all.py new file mode 100644 index 0000000000000000000000000000000000000000..50b85d133707e83b36d926b7acf1cb121dd4d04d --- /dev/null +++ b/SparseNeuS_demo_v1/data/blender_general_narrow_all.py @@ -0,0 +1,386 @@ +from torch.utils.data import Dataset +from utils.misc_utils import read_pfm +import os +import numpy as np +import cv2 +from PIL import Image +import torch +from torchvision import transforms as T +from data.scene import get_boundingbox + +from models.rays import gen_rays_from_single_image, gen_random_rays_from_single_image +import json +from termcolor import colored +import imageio +from kornia import create_meshgrid +import open3d as o3d +def get_ray_directions(H, W, focal, center=None): + """ + Get ray directions for all pixels in camera coordinate. + Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/ + ray-tracing-generating-camera-rays/standard-coordinate-systems + Inputs: + H, W, focal: image height, width and focal length + Outputs: + directions: (H, W, 3), the direction of the rays in camera coordinate + """ + grid = create_meshgrid(H, W, normalized_coordinates=False)[0] + 0.5 # 1xHxWx2 + + i, j = grid.unbind(-1) + # the direction here is without +0.5 pixel centering as calibration is not so accurate + # see https://github.com/bmild/nerf/issues/24 + cent = center if center is not None else [W / 2, H / 2] + directions = torch.stack([(i - cent[0]) / focal[0], (j - cent[1]) / focal[1], torch.ones_like(i)], -1) # (H, W, 3) + + return directions + +def load_K_Rt_from_P(filename, P=None): + if P is None: + lines = open(filename).read().splitlines() + if len(lines) == 4: + lines = lines[1:] + lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)] + P = np.asarray(lines).astype(np.float32).squeeze() + + out = cv2.decomposeProjectionMatrix(P) + K = out[0] + R = out[1] + t = out[2] + + K = K / K[2, 2] + intrinsics = np.eye(4) + intrinsics[:3, :3] = K + + pose = np.eye(4, dtype=np.float32) + pose[:3, :3] = R.transpose() # ? why need transpose here + pose[:3, 3] = (t[:3] / t[3])[:, 0] + + return intrinsics, pose # ! return cam2world matrix here + + +# ! load one ref-image with multiple src-images in camera coordinate system +class BlenderPerView(Dataset): + def __init__(self, root_dir, split, n_views=3, img_wh=(256, 256), downSample=1.0, + split_filepath=None, pair_filepath=None, + N_rays=512, + vol_dims=[128, 128, 128], batch_size=1, + clean_image=False, importance_sample=False, test_ref_views=[]): + + # print("root_dir: ", root_dir) + self.root_dir = root_dir + self.split = split + + self.n_views = n_views + self.N_rays = N_rays + self.batch_size = batch_size # - used for construct new metas for gru fusion training + + self.clean_image = clean_image + self.importance_sample = importance_sample + self.test_ref_views = test_ref_views # used for testing + self.scale_factor = 1.0 + self.scale_mat = np.float32(np.diag([1, 1, 1, 1.0])) + + lvis_json_path = '/objaverse-processed/zero12345_img/lvis_split.json' # folder_id and uid + with open(lvis_json_path, 'r') as f: + lvis_paths = json.load(f) + if self.split == 'train': + self.lvis_paths = lvis_paths['train'] + else: + self.lvis_paths = lvis_paths['val'] + if img_wh is not None: + assert img_wh[0] % 32 == 0 and img_wh[1] % 32 == 0, \ + 'img_wh must both be multiples of 32!' + + + pose_json_path = "/objaverse-processed/zero12345_img/zero12345_narrow_pose.json" + with open(pose_json_path, 'r') as f: + meta = json.load(f) + + self.img_ids = list(meta["c2ws"].keys()) # e.g. "view_0", "view_7", "view_0_2_10" + self.img_wh = (256, 256) + self.input_poses = np.array(list(meta["c2ws"].values())) + intrinsic = np.eye(4) + intrinsic[:3, :3] = np.array(meta["intrinsics"]) + self.intrinsic = intrinsic + self.near_far = np.array(meta["near_far"]) + self.near_far[1] = 1.8 + self.define_transforms() + self.blender2opencv = np.array( + [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]] + ) + + + self.c2ws = [] + self.w2cs = [] + self.near_fars = [] + # self.root_dir = root_dir + for idx, img_id in enumerate(self.img_ids): + pose = self.input_poses[idx] + c2w = pose @ self.blender2opencv + self.c2ws.append(c2w) + self.w2cs.append(np.linalg.inv(c2w)) + self.near_fars.append(self.near_far) + self.c2ws = np.stack(self.c2ws, axis=0) + self.w2cs = np.stack(self.w2cs, axis=0) + + + self.all_intrinsics = [] # the cam info of the whole scene + self.all_extrinsics = [] + self.all_near_fars = [] + self.load_cam_info() + + # * bounding box for rendering + self.bbox_min = np.array([-1.0, -1.0, -1.0]) + self.bbox_max = np.array([1.0, 1.0, 1.0]) + + # - used for cost volume regularization + self.voxel_dims = torch.tensor(vol_dims, dtype=torch.float32) + self.partial_vol_origin = torch.tensor([-1., -1., -1.], dtype=torch.float32) + + + def define_transforms(self): + self.transform = T.Compose([T.ToTensor()]) + + + + def load_cam_info(self): + for vid, img_id in enumerate(self.img_ids): + intrinsic, extrinsic, near_far = self.intrinsic, np.linalg.inv(self.c2ws[vid]), self.near_far + self.all_intrinsics.append(intrinsic) + self.all_extrinsics.append(extrinsic) + self.all_near_fars.append(near_far) + + def read_depth(self, filename): + pass + + def read_mask(self, filename): + mask_h = cv2.imread(filename, 0) + mask_h = cv2.resize(mask_h, None, fx=self.downSample, fy=self.downSample, + interpolation=cv2.INTER_NEAREST) + mask = cv2.resize(mask_h, None, fx=0.25, fy=0.25, + interpolation=cv2.INTER_NEAREST) + + mask[mask > 0] = 1 # the masks stored in png are not binary + mask_h[mask_h > 0] = 1 + + return mask, mask_h + + def cal_scale_mat(self, img_hw, intrinsics, extrinsics, near_fars, factor=1.): + + center, radius, bounds = get_boundingbox(img_hw, intrinsics, extrinsics, near_fars) + # print("center", center) + # print("radius", radius) + # print("bounds", bounds) + # import ipdb; ipdb.set_trace() + radius = radius * factor + scale_mat = np.diag([radius, radius, radius, 1.0]) + scale_mat[:3, 3] = center.cpu().numpy() + scale_mat = scale_mat.astype(np.float32) + + return scale_mat, 1. / radius.cpu().numpy() + + def __len__(self): + return 8*len(self.lvis_paths) + + + def read_depth(self, filename, near_bound, noisy_factor=1.0): + pass + + + def __getitem__(self, idx): + sample = {} + origin_idx = idx + imgs, depths_h, masks_h = [], [], [] # full size (256, 256) + intrinsics, w2cs, c2ws, near_fars = [], [], [], [] # record proj mats between views + + + folder_uid_dict = self.lvis_paths[idx//8] + idx = idx % 8 # [0, 7] + folder_id = folder_uid_dict['folder_id'] + uid = folder_uid_dict['uid'] + + + # target view + c2w = self.c2ws[idx] + w2c = np.linalg.inv(c2w) + w2c_ref = w2c + w2c_ref_inv = np.linalg.inv(w2c_ref) + + w2cs.append(w2c @ w2c_ref_inv) + c2ws.append(np.linalg.inv(w2c @ w2c_ref_inv)) + + img_filename = os.path.join(self.root_dir, folder_id, uid, f'view_{idx}.png') + + depth_filename = os.path.join(os.path.join(self.root_dir, folder_id, uid, f'view_{idx}_depth_mm.png')) + + + img = Image.open(img_filename) + + img = self.transform(img) # (4, h, w) + + + if img.shape[0] == 4: + img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB + imgs += [img] + + depth_h = cv2.imread(depth_filename, cv2.IMREAD_UNCHANGED).astype(np.uint16) / 1000.0 + mask_h = depth_h > 0 + # print("valid pixels", np.sum(mask_h)) + directions = get_ray_directions(self.img_wh[1], self.img_wh[0], [self.intrinsic[0, 0], self.intrinsic[1, 1]]) # [H, W, 3] + surface_points = directions * depth_h[..., None] # [H, W, 3] + distance = np.linalg.norm(surface_points, axis=-1) # [H, W] + depth_h = distance + + + depths_h.append(depth_h) + masks_h.append(mask_h) + + intrinsic = self.intrinsic + intrinsics.append(intrinsic) + + + near_fars.append(self.near_fars[idx]) + image_perm = 0 # only supervised on reference view + + mask_dilated = None + + # src_views = range(8+idx*4, 8+(idx+1)*4) + src_views = range(8, 8 + 8 * 4) + + for vid in src_views: + img_filename = os.path.join(self.root_dir, folder_id, uid, f'view_{(vid - 8) // 4}_{vid%4}_10.png') + + img = Image.open(img_filename) + img_wh = self.img_wh + + img = self.transform(img) + if img.shape[0] == 4: + img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB + + imgs += [img] + depth_h = np.ones(img.shape[1:], dtype=np.float32) + depths_h.append(depth_h) + masks_h.append(np.ones(img.shape[1:], dtype=np.int32)) + + near_fars.append(self.all_near_fars[vid]) + intrinsics.append(self.all_intrinsics[vid]) + + w2cs.append(self.all_extrinsics[vid] @ w2c_ref_inv) + + + # ! estimate scale_mat + scale_mat, scale_factor = self.cal_scale_mat( + img_hw=[img_wh[1], img_wh[0]], + intrinsics=intrinsics, extrinsics=w2cs, + near_fars=near_fars, factor=1.1 + ) + + + new_near_fars = [] + new_w2cs = [] + new_c2ws = [] + new_affine_mats = [] + new_depths_h = [] + for intrinsic, extrinsic, near_far, depth in zip(intrinsics, w2cs, near_fars, depths_h): + + P = intrinsic @ extrinsic @ scale_mat + P = P[:3, :4] + # - should use load_K_Rt_from_P() to obtain c2w + c2w = load_K_Rt_from_P(None, P)[1] + w2c = np.linalg.inv(c2w) + new_w2cs.append(w2c) + new_c2ws.append(c2w) + affine_mat = np.eye(4) + affine_mat[:3, :4] = intrinsic[:3, :3] @ w2c[:3, :4] + new_affine_mats.append(affine_mat) + + camera_o = c2w[:3, 3] + dist = np.sqrt(np.sum(camera_o ** 2)) + near = dist - 1 + far = dist + 1 + + new_near_fars.append([0.95 * near, 1.05 * far]) + new_depths_h.append(depth * scale_factor) + + # print(new_near_fars) + imgs = torch.stack(imgs).float() + depths_h = np.stack(new_depths_h) + masks_h = np.stack(masks_h) + + affine_mats = np.stack(new_affine_mats) + intrinsics, w2cs, c2ws, near_fars = np.stack(intrinsics), np.stack(new_w2cs), np.stack(new_c2ws), np.stack( + new_near_fars) + + if self.split == 'train': + start_idx = 0 + else: + start_idx = 1 + + view_ids = [idx] + list(src_views) + sample['origin_idx'] = origin_idx + sample['images'] = imgs # (V, 3, H, W) + sample['depths_h'] = torch.from_numpy(depths_h.astype(np.float32)) # (V, H, W) + sample['masks_h'] = torch.from_numpy(masks_h.astype(np.float32)) # (V, H, W) + sample['w2cs'] = torch.from_numpy(w2cs.astype(np.float32)) # (V, 4, 4) + sample['c2ws'] = torch.from_numpy(c2ws.astype(np.float32)) # (V, 4, 4) + sample['near_fars'] = torch.from_numpy(near_fars.astype(np.float32)) # (V, 2) + sample['intrinsics'] = torch.from_numpy(intrinsics.astype(np.float32))[:, :3, :3] # (V, 3, 3) + sample['view_ids'] = torch.from_numpy(np.array(view_ids)) + sample['affine_mats'] = torch.from_numpy(affine_mats.astype(np.float32)) # ! in world space + + # sample['light_idx'] = torch.tensor(light_idx) + sample['scan'] = folder_id + + sample['scale_factor'] = torch.tensor(scale_factor) + sample['img_wh'] = torch.from_numpy(np.array(img_wh)) + sample['render_img_idx'] = torch.tensor(image_perm) + sample['partial_vol_origin'] = self.partial_vol_origin + sample['meta'] = str(folder_id) + "_" + str(uid) + "_refview" + str(view_ids[0]) + + + # - image to render + sample['query_image'] = sample['images'][0] + sample['query_c2w'] = sample['c2ws'][0] + sample['query_w2c'] = sample['w2cs'][0] + sample['query_intrinsic'] = sample['intrinsics'][0] + sample['query_depth'] = sample['depths_h'][0] + sample['query_mask'] = sample['masks_h'][0] + sample['query_near_far'] = sample['near_fars'][0] + + sample['images'] = sample['images'][start_idx:] # (V, 3, H, W) + sample['depths_h'] = sample['depths_h'][start_idx:] # (V, H, W) + sample['masks_h'] = sample['masks_h'][start_idx:] # (V, H, W) + sample['w2cs'] = sample['w2cs'][start_idx:] # (V, 4, 4) + sample['c2ws'] = sample['c2ws'][start_idx:] # (V, 4, 4) + sample['intrinsics'] = sample['intrinsics'][start_idx:] # (V, 3, 3) + sample['view_ids'] = sample['view_ids'][start_idx:] + sample['affine_mats'] = sample['affine_mats'][start_idx:] # ! in world space + + sample['scale_mat'] = torch.from_numpy(scale_mat) + sample['trans_mat'] = torch.from_numpy(w2c_ref_inv) + + # - generate rays + if ('val' in self.split) or ('test' in self.split): + sample_rays = gen_rays_from_single_image( + img_wh[1], img_wh[0], + sample['query_image'], + sample['query_intrinsic'], + sample['query_c2w'], + depth=sample['query_depth'], + mask=sample['query_mask'] if self.clean_image else None) + else: + sample_rays = gen_random_rays_from_single_image( + img_wh[1], img_wh[0], + self.N_rays, + sample['query_image'], + sample['query_intrinsic'], + sample['query_c2w'], + depth=sample['query_depth'], + mask=sample['query_mask'] if self.clean_image else None, + dilated_mask=mask_dilated, + importance_sample=self.importance_sample) + + + sample['rays'] = sample_rays + + return sample diff --git a/SparseNeuS_demo_v1/data/blender_general_narrow_all_eval_2_stage.py b/SparseNeuS_demo_v1/data/blender_general_narrow_all_eval_2_stage.py new file mode 100644 index 0000000000000000000000000000000000000000..1b832beccd85c8a0be98edf95f0d244c1cbf8b17 --- /dev/null +++ b/SparseNeuS_demo_v1/data/blender_general_narrow_all_eval_2_stage.py @@ -0,0 +1,410 @@ +from torch.utils.data import Dataset +from utils.misc_utils import read_pfm +import os +import numpy as np +import cv2 +from PIL import Image +import torch +from torchvision import transforms as T +from data.scene import get_boundingbox + +from models.rays import gen_rays_from_single_image, gen_random_rays_from_single_image +import json +from termcolor import colored +import imageio +from kornia import create_meshgrid +import open3d as o3d +def get_ray_directions(H, W, focal, center=None): + """ + Get ray directions for all pixels in camera coordinate. + Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/ + ray-tracing-generating-camera-rays/standard-coordinate-systems + Inputs: + H, W, focal: image height, width and focal length + Outputs: + directions: (H, W, 3), the direction of the rays in camera coordinate + """ + grid = create_meshgrid(H, W, normalized_coordinates=False)[0] + 0.5 # 1xHxWx2 + + i, j = grid.unbind(-1) + # the direction here is without +0.5 pixel centering as calibration is not so accurate + # see https://github.com/bmild/nerf/issues/24 + cent = center if center is not None else [W / 2, H / 2] + directions = torch.stack([(i - cent[0]) / focal[0], (j - cent[1]) / focal[1], torch.ones_like(i)], -1) # (H, W, 3) + + return directions + +def load_K_Rt_from_P(filename, P=None): + if P is None: + lines = open(filename).read().splitlines() + if len(lines) == 4: + lines = lines[1:] + lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)] + P = np.asarray(lines).astype(np.float32).squeeze() + + out = cv2.decomposeProjectionMatrix(P) + K = out[0] + R = out[1] + t = out[2] + + K = K / K[2, 2] + intrinsics = np.eye(4) + intrinsics[:3, :3] = K + + pose = np.eye(4, dtype=np.float32) + pose[:3, :3] = R.transpose() # ? why need transpose here + pose[:3, 3] = (t[:3] / t[3])[:, 0] + + return intrinsics, pose # ! return cam2world matrix here + + +# ! load one ref-image with multiple src-images in camera coordinate system +class BlenderPerView(Dataset): + def __init__(self, root_dir, split, n_views=3, img_wh=(256, 256), downSample=1.0, + split_filepath=None, pair_filepath=None, + N_rays=512, + vol_dims=[128, 128, 128], batch_size=1, + clean_image=False, importance_sample=False, test_ref_views=[]): + + # print("root_dir: ", root_dir) + self.root_dir = root_dir + self.split = split + + self.n_views = n_views + self.N_rays = N_rays + self.batch_size = batch_size # - used for construct new metas for gru fusion training + + self.clean_image = clean_image + self.importance_sample = importance_sample + self.test_ref_views = test_ref_views # used for testing + self.scale_factor = 1.0 + self.scale_mat = np.float32(np.diag([1, 1, 1, 1.0])) + + lvis_json_path = '/objaverse-processed/zero12345_img/lvis_split.json' # folder_id and uid + with open(lvis_json_path, 'r') as f: + lvis_paths = json.load(f) + if self.split == 'train': + self.lvis_paths = lvis_paths['train'] + else: + self.lvis_paths = lvis_paths['val'] + if img_wh is not None: + assert img_wh[0] % 32 == 0 and img_wh[1] % 32 == 0, \ + 'img_wh must both be multiples of 32!' + + + pose_json_path = "/objaverse-processed/zero12345_img/zero12345_narrow_pose.json" + with open(pose_json_path, 'r') as f: + meta = json.load(f) + + self.img_ids = list(meta["c2ws"].keys()) # e.g. "view_0", "view_7", "view_0_2_10" + self.img_wh = (256, 256) + self.input_poses = np.array(list(meta["c2ws"].values())) + intrinsic = np.eye(4) + intrinsic[:3, :3] = np.array(meta["intrinsics"]) + self.intrinsic = intrinsic + self.near_far = np.array(meta["near_far"]) + self.near_far[1] = 1.8 + self.define_transforms() + self.blender2opencv = np.array( + [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]] + ) + + + self.c2ws = [] + self.w2cs = [] + self.near_fars = [] + # self.root_dir = root_dir + for idx, img_id in enumerate(self.img_ids): + pose = self.input_poses[idx] + c2w = pose @ self.blender2opencv + self.c2ws.append(c2w) + self.w2cs.append(np.linalg.inv(c2w)) + self.near_fars.append(self.near_far) + self.c2ws = np.stack(self.c2ws, axis=0) + self.w2cs = np.stack(self.w2cs, axis=0) + + + self.all_intrinsics = [] # the cam info of the whole scene + self.all_extrinsics = [] + self.all_near_fars = [] + self.load_cam_info() + + # * bounding box for rendering + self.bbox_min = np.array([-1.0, -1.0, -1.0]) + self.bbox_max = np.array([1.0, 1.0, 1.0]) + + # - used for cost volume regularization + self.voxel_dims = torch.tensor(vol_dims, dtype=torch.float32) + self.partial_vol_origin = torch.tensor([-1., -1., -1.], dtype=torch.float32) + + + def define_transforms(self): + self.transform = T.Compose([T.ToTensor()]) + + + + def load_cam_info(self): + for vid, img_id in enumerate(self.img_ids): + intrinsic, extrinsic, near_far = self.intrinsic, np.linalg.inv(self.c2ws[vid]), self.near_far + self.all_intrinsics.append(intrinsic) + self.all_extrinsics.append(extrinsic) + self.all_near_fars.append(near_far) + + def read_depth(self, filename): + pass + + def read_mask(self, filename): + mask_h = cv2.imread(filename, 0) + mask_h = cv2.resize(mask_h, None, fx=self.downSample, fy=self.downSample, + interpolation=cv2.INTER_NEAREST) + mask = cv2.resize(mask_h, None, fx=0.25, fy=0.25, + interpolation=cv2.INTER_NEAREST) + + mask[mask > 0] = 1 # the masks stored in png are not binary + mask_h[mask_h > 0] = 1 + + return mask, mask_h + + def cal_scale_mat(self, img_hw, intrinsics, extrinsics, near_fars, factor=1.): + + center, radius, bounds = get_boundingbox(img_hw, intrinsics, extrinsics, near_fars) + # print("center", center) + # print("radius", radius) + # print("bounds", bounds) + # import ipdb; ipdb.set_trace() + radius = radius * factor + scale_mat = np.diag([radius, radius, radius, 1.0]) + scale_mat[:3, 3] = center.cpu().numpy() + scale_mat = scale_mat.astype(np.float32) + + return scale_mat, 1. / radius.cpu().numpy() + + def __len__(self): + return 8*len(self.lvis_paths) + + + def read_depth(self, filename, near_bound, noisy_factor=1.0): + pass + + + def __getitem__(self, idx): + sample = {} + origin_idx = idx + imgs, depths_h, masks_h = [], [], [] # full size (256, 256) + intrinsics, w2cs, c2ws, near_fars = [], [], [], [] # record proj mats between views + + + folder_uid_dict = self.lvis_paths[idx//8] + idx = idx % 8 # [0, 7] + folder_id = folder_uid_dict['folder_id'] + uid = folder_uid_dict['uid'] + + + # target view + c2w = self.c2ws[idx] + w2c = np.linalg.inv(c2w) + w2c_ref = w2c + w2c_ref_inv = np.linalg.inv(w2c_ref) + + w2cs.append(w2c @ w2c_ref_inv) + c2ws.append(np.linalg.inv(w2c @ w2c_ref_inv)) + + img_filename = os.path.join("/objaverse-processed/zero12345_img/zero12345_narrow/", folder_id, uid, f'view_{idx}.png') + + depth_filename = os.path.join(os.path.join("/objaverse-processed/zero12345_img/zero12345_narrow/", folder_id, uid, f'view_{idx}_depth_mm.png')) + + + img = Image.open(img_filename) + + img = self.transform(img) # (4, h, w) + + # print("img_pre", img.shape) + if img.shape[0] == 4: + img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB + # print("img", img.shape) + imgs += [img] + + + depth_h = cv2.imread(depth_filename, cv2.IMREAD_UNCHANGED).astype(np.uint16) / 1000.0 + mask_h = depth_h > 0 + # print("valid pixels", np.sum(mask_h)) + directions = get_ray_directions(self.img_wh[1], self.img_wh[0], [self.intrinsic[0, 0], self.intrinsic[1, 1]]) # [H, W, 3] + surface_points = directions * depth_h[..., None] # [H, W, 3] + distance = np.linalg.norm(surface_points, axis=-1) # [H, W] + depth_h = distance + # print("depth_h", depth_h.shape) + + depths_h.append(depth_h) + masks_h.append(mask_h) + + intrinsic = self.intrinsic + intrinsics.append(intrinsic) + + + near_fars.append(self.near_fars[idx]) + image_perm = 0 # only supervised on reference view + + mask_dilated = None + + # src_views = range(8+idx*4, 8+(idx+1)*4) + src_views = range(8, 8 + 8 * 4) + + for vid in src_views: + img_filename = os.path.join(self.root_dir, folder_id, uid, f'view_0_{(vid - 8) // 4}_{vid % 4 + 1}.png') + + img = Image.open(img_filename) + img_wh = self.img_wh + + img = self.transform(img) + if img.shape[0] == 4: + img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB + + imgs += [img] + depth_h = np.ones(img.shape[1:], dtype=np.float32) + depths_h.append(depth_h) + masks_h.append(np.ones(img.shape[1:], dtype=np.int32)) + + near_fars.append(self.all_near_fars[vid]) + intrinsics.append(self.all_intrinsics[vid]) + + w2cs.append(self.all_extrinsics[vid] @ w2c_ref_inv) + + + # ! estimate scale_mat + scale_mat, scale_factor = self.cal_scale_mat( + img_hw=[img_wh[1], img_wh[0]], + intrinsics=intrinsics, extrinsics=w2cs, + near_fars=near_fars, factor=1.1 + ) + + + new_near_fars = [] + new_w2cs = [] + new_c2ws = [] + new_affine_mats = [] + new_depths_h = [] + for intrinsic, extrinsic, near_far, depth in zip(intrinsics, w2cs, near_fars, depths_h): + + P = intrinsic @ extrinsic @ scale_mat + P = P[:3, :4] + # - should use load_K_Rt_from_P() to obtain c2w + c2w = load_K_Rt_from_P(None, P)[1] + w2c = np.linalg.inv(c2w) + new_w2cs.append(w2c) + new_c2ws.append(c2w) + affine_mat = np.eye(4) + affine_mat[:3, :4] = intrinsic[:3, :3] @ w2c[:3, :4] + new_affine_mats.append(affine_mat) + + camera_o = c2w[:3, 3] + dist = np.sqrt(np.sum(camera_o ** 2)) + near = dist - 1 + far = dist + 1 + + new_near_fars.append([0.95 * near, 1.05 * far]) + new_depths_h.append(depth * scale_factor) + + # print(new_near_fars) + imgs = torch.stack(imgs).float() + depths_h = np.stack(new_depths_h) + masks_h = np.stack(masks_h) + + affine_mats = np.stack(new_affine_mats) + intrinsics, w2cs, c2ws, near_fars = np.stack(intrinsics), np.stack(new_w2cs), np.stack(new_c2ws), np.stack( + new_near_fars) + + if self.split == 'train': + start_idx = 0 + else: + start_idx = 1 + + + + target_w2cs = [] + target_intrinsics = [] + new_target_w2cs = [] + for i_idx in range(8): + target_w2cs.append(self.all_extrinsics[i_idx] @ w2c_ref_inv) + target_intrinsics.append(self.all_intrinsics[i_idx]) + + for intrinsic, extrinsic in zip(target_intrinsics, target_w2cs): + + P = intrinsic @ extrinsic @ scale_mat + P = P[:3, :4] + # - should use load_K_Rt_from_P() to obtain c2w + c2w = load_K_Rt_from_P(None, P)[1] + w2c = np.linalg.inv(c2w) + new_target_w2cs.append(w2c) + target_w2cs = np.stack(new_target_w2cs) + + + + view_ids = [idx] + list(src_views) + sample['origin_idx'] = origin_idx + sample['images'] = imgs # (V, 3, H, W) + sample['depths_h'] = torch.from_numpy(depths_h.astype(np.float32)) # (V, H, W) + sample['masks_h'] = torch.from_numpy(masks_h.astype(np.float32)) # (V, H, W) + sample['w2cs'] = torch.from_numpy(w2cs.astype(np.float32)) # (V, 4, 4) + sample['c2ws'] = torch.from_numpy(c2ws.astype(np.float32)) # (V, 4, 4) + sample['target_candidate_w2cs'] = torch.from_numpy(target_w2cs.astype(np.float32)) # (8, 4, 4) + sample['near_fars'] = torch.from_numpy(near_fars.astype(np.float32)) # (V, 2) + sample['intrinsics'] = torch.from_numpy(intrinsics.astype(np.float32))[:, :3, :3] # (V, 3, 3) + sample['view_ids'] = torch.from_numpy(np.array(view_ids)) + sample['affine_mats'] = torch.from_numpy(affine_mats.astype(np.float32)) # ! in world space + + # sample['light_idx'] = torch.tensor(light_idx) + sample['scan'] = folder_id + + sample['scale_factor'] = torch.tensor(scale_factor) + sample['img_wh'] = torch.from_numpy(np.array(img_wh)) + sample['render_img_idx'] = torch.tensor(image_perm) + sample['partial_vol_origin'] = self.partial_vol_origin + sample['meta'] = str(folder_id) + "_" + str(uid) + "_refview" + str(view_ids[0]) + + + # - image to render + sample['query_image'] = sample['images'][0] + sample['query_c2w'] = sample['c2ws'][0] + sample['query_w2c'] = sample['w2cs'][0] + sample['query_intrinsic'] = sample['intrinsics'][0] + sample['query_depth'] = sample['depths_h'][0] + sample['query_mask'] = sample['masks_h'][0] + sample['query_near_far'] = sample['near_fars'][0] + + sample['images'] = sample['images'][start_idx:] # (V, 3, H, W) + sample['depths_h'] = sample['depths_h'][start_idx:] # (V, H, W) + sample['masks_h'] = sample['masks_h'][start_idx:] # (V, H, W) + sample['w2cs'] = sample['w2cs'][start_idx:] # (V, 4, 4) + sample['c2ws'] = sample['c2ws'][start_idx:] # (V, 4, 4) + sample['intrinsics'] = sample['intrinsics'][start_idx:] # (V, 3, 3) + sample['view_ids'] = sample['view_ids'][start_idx:] + sample['affine_mats'] = sample['affine_mats'][start_idx:] # ! in world space + + sample['scale_mat'] = torch.from_numpy(scale_mat) + sample['trans_mat'] = torch.from_numpy(w2c_ref_inv) + + # - generate rays + if ('val' in self.split) or ('test' in self.split): + sample_rays = gen_rays_from_single_image( + img_wh[1], img_wh[0], + sample['query_image'], + sample['query_intrinsic'], + sample['query_c2w'], + depth=sample['query_depth'], + mask=sample['query_mask'] if self.clean_image else None) + else: + sample_rays = gen_random_rays_from_single_image( + img_wh[1], img_wh[0], + self.N_rays, + sample['query_image'], + sample['query_intrinsic'], + sample['query_c2w'], + depth=sample['query_depth'], + mask=sample['query_mask'] if self.clean_image else None, + dilated_mask=mask_dilated, + importance_sample=self.importance_sample) + + + sample['rays'] = sample_rays + + return sample diff --git a/SparseNeuS_demo_v1/data/blender_general_narrow_all_eval_2_stage_temp.py b/SparseNeuS_demo_v1/data/blender_general_narrow_all_eval_2_stage_temp.py new file mode 100644 index 0000000000000000000000000000000000000000..5c2dbebd00ed9e0293c26029c97ab77b7880fcf0 --- /dev/null +++ b/SparseNeuS_demo_v1/data/blender_general_narrow_all_eval_2_stage_temp.py @@ -0,0 +1,411 @@ +from torch.utils.data import Dataset +from utils.misc_utils import read_pfm +import os +import numpy as np +import cv2 +from PIL import Image +import torch +from torchvision import transforms as T +from data.scene import get_boundingbox + +from models.rays import gen_rays_from_single_image, gen_random_rays_from_single_image +import json +from termcolor import colored +import imageio +from kornia import create_meshgrid +import open3d as o3d +def get_ray_directions(H, W, focal, center=None): + """ + Get ray directions for all pixels in camera coordinate. + Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/ + ray-tracing-generating-camera-rays/standard-coordinate-systems + Inputs: + H, W, focal: image height, width and focal length + Outputs: + directions: (H, W, 3), the direction of the rays in camera coordinate + """ + grid = create_meshgrid(H, W, normalized_coordinates=False)[0] + 0.5 # 1xHxWx2 + + i, j = grid.unbind(-1) + # the direction here is without +0.5 pixel centering as calibration is not so accurate + # see https://github.com/bmild/nerf/issues/24 + cent = center if center is not None else [W / 2, H / 2] + directions = torch.stack([(i - cent[0]) / focal[0], (j - cent[1]) / focal[1], torch.ones_like(i)], -1) # (H, W, 3) + + return directions + +def load_K_Rt_from_P(filename, P=None): + if P is None: + lines = open(filename).read().splitlines() + if len(lines) == 4: + lines = lines[1:] + lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)] + P = np.asarray(lines).astype(np.float32).squeeze() + + out = cv2.decomposeProjectionMatrix(P) + K = out[0] + R = out[1] + t = out[2] + + K = K / K[2, 2] + intrinsics = np.eye(4) + intrinsics[:3, :3] = K + + pose = np.eye(4, dtype=np.float32) + pose[:3, :3] = R.transpose() # ? why need transpose here + pose[:3, 3] = (t[:3] / t[3])[:, 0] + + return intrinsics, pose # ! return cam2world matrix here + + +# ! load one ref-image with multiple src-images in camera coordinate system +class BlenderPerView(Dataset): + def __init__(self, root_dir, split, n_views=3, img_wh=(256, 256), downSample=1.0, + split_filepath=None, pair_filepath=None, + N_rays=512, + vol_dims=[128, 128, 128], batch_size=1, + clean_image=False, importance_sample=False, test_ref_views=[]): + + # print("root_dir: ", root_dir) + self.root_dir = root_dir + self.split = split + + self.n_views = n_views + self.N_rays = N_rays + self.batch_size = batch_size # - used for construct new metas for gru fusion training + + self.clean_image = clean_image + self.importance_sample = importance_sample + self.test_ref_views = test_ref_views # used for testing + self.scale_factor = 1.0 + self.scale_mat = np.float32(np.diag([1, 1, 1, 1.0])) + + lvis_json_path = '/objaverse-processed/zero12345_img/lvis_split.json' # folder_id and uid + with open(lvis_json_path, 'r') as f: + lvis_paths = json.load(f) + if self.split == 'train': + self.lvis_paths = lvis_paths['train'] + else: + self.lvis_paths = lvis_paths['val'] + if img_wh is not None: + assert img_wh[0] % 32 == 0 and img_wh[1] % 32 == 0, \ + 'img_wh must both be multiples of 32!' + + + pose_json_path = "/objaverse-processed/zero12345_img/zero12345_narrow_pose.json" + with open(pose_json_path, 'r') as f: + meta = json.load(f) + + self.img_ids = list(meta["c2ws"].keys()) # e.g. "view_0", "view_7", "view_0_2_10" + self.img_wh = (256, 256) + self.input_poses = np.array(list(meta["c2ws"].values())) + intrinsic = np.eye(4) + intrinsic[:3, :3] = np.array(meta["intrinsics"]) + self.intrinsic = intrinsic + self.near_far = np.array(meta["near_far"]) + self.near_far[1] = 1.8 + self.define_transforms() + self.blender2opencv = np.array( + [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]] + ) + + + self.c2ws = [] + self.w2cs = [] + self.near_fars = [] + # self.root_dir = root_dir + for idx, img_id in enumerate(self.img_ids): + pose = self.input_poses[idx] + c2w = pose @ self.blender2opencv + self.c2ws.append(c2w) + self.w2cs.append(np.linalg.inv(c2w)) + self.near_fars.append(self.near_far) + self.c2ws = np.stack(self.c2ws, axis=0) + self.w2cs = np.stack(self.w2cs, axis=0) + + + self.all_intrinsics = [] # the cam info of the whole scene + self.all_extrinsics = [] + self.all_near_fars = [] + self.load_cam_info() + + # * bounding box for rendering + self.bbox_min = np.array([-1.0, -1.0, -1.0]) + self.bbox_max = np.array([1.0, 1.0, 1.0]) + + # - used for cost volume regularization + self.voxel_dims = torch.tensor(vol_dims, dtype=torch.float32) + self.partial_vol_origin = torch.tensor([-1., -1., -1.], dtype=torch.float32) + + + def define_transforms(self): + self.transform = T.Compose([T.ToTensor()]) + + + + def load_cam_info(self): + for vid, img_id in enumerate(self.img_ids): + intrinsic, extrinsic, near_far = self.intrinsic, np.linalg.inv(self.c2ws[vid]), self.near_far + self.all_intrinsics.append(intrinsic) + self.all_extrinsics.append(extrinsic) + self.all_near_fars.append(near_far) + + def read_depth(self, filename): + pass + + def read_mask(self, filename): + mask_h = cv2.imread(filename, 0) + mask_h = cv2.resize(mask_h, None, fx=self.downSample, fy=self.downSample, + interpolation=cv2.INTER_NEAREST) + mask = cv2.resize(mask_h, None, fx=0.25, fy=0.25, + interpolation=cv2.INTER_NEAREST) + + mask[mask > 0] = 1 # the masks stored in png are not binary + mask_h[mask_h > 0] = 1 + + return mask, mask_h + + def cal_scale_mat(self, img_hw, intrinsics, extrinsics, near_fars, factor=1.): + + center, radius, bounds = get_boundingbox(img_hw, intrinsics, extrinsics, near_fars) + # print("center", center) + # print("radius", radius) + # print("bounds", bounds) + # import ipdb; ipdb.set_trace() + radius = radius * factor + scale_mat = np.diag([radius, radius, radius, 1.0]) + scale_mat[:3, 3] = center.cpu().numpy() + scale_mat = scale_mat.astype(np.float32) + + return scale_mat, 1. / radius.cpu().numpy() + + def __len__(self): + return 10 + + + def read_depth(self, filename, near_bound, noisy_factor=1.0): + pass + + + def __getitem__(self, idx): + idx = idx * 8 + sample = {} + origin_idx = idx + imgs, depths_h, masks_h = [], [], [] # full size (256, 256) + intrinsics, w2cs, c2ws, near_fars = [], [], [], [] # record proj mats between views + + + folder_uid_dict = self.lvis_paths[idx//8] + idx = idx % 8 # [0, 7] + folder_id = folder_uid_dict['folder_id'] + uid = folder_uid_dict['uid'] + + + # target view + c2w = self.c2ws[idx] + w2c = np.linalg.inv(c2w) + w2c_ref = w2c + w2c_ref_inv = np.linalg.inv(w2c_ref) + + w2cs.append(w2c @ w2c_ref_inv) + c2ws.append(np.linalg.inv(w2c @ w2c_ref_inv)) + + img_filename = os.path.join("/objaverse-processed/zero12345_img/zero12345_narrow/", folder_id, uid, f'view_{idx}.png') + + depth_filename = os.path.join(os.path.join("/objaverse-processed/zero12345_img/zero12345_narrow/", folder_id, uid, f'view_{idx}_depth_mm.png')) + + + img = Image.open(img_filename) + + img = self.transform(img) # (4, h, w) + + # print("img_pre", img.shape) + if img.shape[0] == 4: + img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB + # print("img", img.shape) + imgs += [img] + + + depth_h = cv2.imread(depth_filename, cv2.IMREAD_UNCHANGED).astype(np.uint16) / 1000.0 + mask_h = depth_h > 0 + # print("valid pixels", np.sum(mask_h)) + directions = get_ray_directions(self.img_wh[1], self.img_wh[0], [self.intrinsic[0, 0], self.intrinsic[1, 1]]) # [H, W, 3] + surface_points = directions * depth_h[..., None] # [H, W, 3] + distance = np.linalg.norm(surface_points, axis=-1) # [H, W] + depth_h = distance + # print("depth_h", depth_h.shape) + + depths_h.append(depth_h) + masks_h.append(mask_h) + + intrinsic = self.intrinsic + intrinsics.append(intrinsic) + + + near_fars.append(self.near_fars[idx]) + image_perm = 0 # only supervised on reference view + + mask_dilated = None + + # src_views = range(8+idx*4, 8+(idx+1)*4) + src_views = range(8, 8 + 8 * 4) + + for vid in src_views: + img_filename = os.path.join(self.root_dir, folder_id, uid, f'view_0_{(vid - 8) // 4}_{vid % 4 + 1}.png') + + img = Image.open(img_filename) + img_wh = self.img_wh + + img = self.transform(img) + if img.shape[0] == 4: + img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB + + imgs += [img] + depth_h = np.ones(img.shape[1:], dtype=np.float32) + depths_h.append(depth_h) + masks_h.append(np.ones(img.shape[1:], dtype=np.int32)) + + near_fars.append(self.all_near_fars[vid]) + intrinsics.append(self.all_intrinsics[vid]) + + w2cs.append(self.all_extrinsics[vid] @ w2c_ref_inv) + + + # ! estimate scale_mat + scale_mat, scale_factor = self.cal_scale_mat( + img_hw=[img_wh[1], img_wh[0]], + intrinsics=intrinsics, extrinsics=w2cs, + near_fars=near_fars, factor=1.1 + ) + + + new_near_fars = [] + new_w2cs = [] + new_c2ws = [] + new_affine_mats = [] + new_depths_h = [] + for intrinsic, extrinsic, near_far, depth in zip(intrinsics, w2cs, near_fars, depths_h): + + P = intrinsic @ extrinsic @ scale_mat + P = P[:3, :4] + # - should use load_K_Rt_from_P() to obtain c2w + c2w = load_K_Rt_from_P(None, P)[1] + w2c = np.linalg.inv(c2w) + new_w2cs.append(w2c) + new_c2ws.append(c2w) + affine_mat = np.eye(4) + affine_mat[:3, :4] = intrinsic[:3, :3] @ w2c[:3, :4] + new_affine_mats.append(affine_mat) + + camera_o = c2w[:3, 3] + dist = np.sqrt(np.sum(camera_o ** 2)) + near = dist - 1 + far = dist + 1 + + new_near_fars.append([0.95 * near, 1.05 * far]) + new_depths_h.append(depth * scale_factor) + + # print(new_near_fars) + imgs = torch.stack(imgs).float() + depths_h = np.stack(new_depths_h) + masks_h = np.stack(masks_h) + + affine_mats = np.stack(new_affine_mats) + intrinsics, w2cs, c2ws, near_fars = np.stack(intrinsics), np.stack(new_w2cs), np.stack(new_c2ws), np.stack( + new_near_fars) + + if self.split == 'train': + start_idx = 0 + else: + start_idx = 1 + + + + target_w2cs = [] + target_intrinsics = [] + new_target_w2cs = [] + for i_idx in range(8): + target_w2cs.append(self.all_extrinsics[i_idx] @ w2c_ref_inv) + target_intrinsics.append(self.all_intrinsics[i_idx]) + + for intrinsic, extrinsic in zip(target_intrinsics, target_w2cs): + + P = intrinsic @ extrinsic @ scale_mat + P = P[:3, :4] + # - should use load_K_Rt_from_P() to obtain c2w + c2w = load_K_Rt_from_P(None, P)[1] + w2c = np.linalg.inv(c2w) + new_target_w2cs.append(w2c) + target_w2cs = np.stack(new_target_w2cs) + + + + view_ids = [idx] + list(src_views) + sample['origin_idx'] = origin_idx + sample['images'] = imgs # (V, 3, H, W) + sample['depths_h'] = torch.from_numpy(depths_h.astype(np.float32)) # (V, H, W) + sample['masks_h'] = torch.from_numpy(masks_h.astype(np.float32)) # (V, H, W) + sample['w2cs'] = torch.from_numpy(w2cs.astype(np.float32)) # (V, 4, 4) + sample['c2ws'] = torch.from_numpy(c2ws.astype(np.float32)) # (V, 4, 4) + sample['target_candidate_w2cs'] = torch.from_numpy(target_w2cs.astype(np.float32)) # (8, 4, 4) + sample['near_fars'] = torch.from_numpy(near_fars.astype(np.float32)) # (V, 2) + sample['intrinsics'] = torch.from_numpy(intrinsics.astype(np.float32))[:, :3, :3] # (V, 3, 3) + sample['view_ids'] = torch.from_numpy(np.array(view_ids)) + sample['affine_mats'] = torch.from_numpy(affine_mats.astype(np.float32)) # ! in world space + + # sample['light_idx'] = torch.tensor(light_idx) + sample['scan'] = folder_id + + sample['scale_factor'] = torch.tensor(scale_factor) + sample['img_wh'] = torch.from_numpy(np.array(img_wh)) + sample['render_img_idx'] = torch.tensor(image_perm) + sample['partial_vol_origin'] = self.partial_vol_origin + sample['meta'] = str(folder_id) + "_" + str(uid) + "_refview" + str(view_ids[0]) + + + # - image to render + sample['query_image'] = sample['images'][0] + sample['query_c2w'] = sample['c2ws'][0] + sample['query_w2c'] = sample['w2cs'][0] + sample['query_intrinsic'] = sample['intrinsics'][0] + sample['query_depth'] = sample['depths_h'][0] + sample['query_mask'] = sample['masks_h'][0] + sample['query_near_far'] = sample['near_fars'][0] + + sample['images'] = sample['images'][start_idx:] # (V, 3, H, W) + sample['depths_h'] = sample['depths_h'][start_idx:] # (V, H, W) + sample['masks_h'] = sample['masks_h'][start_idx:] # (V, H, W) + sample['w2cs'] = sample['w2cs'][start_idx:] # (V, 4, 4) + sample['c2ws'] = sample['c2ws'][start_idx:] # (V, 4, 4) + sample['intrinsics'] = sample['intrinsics'][start_idx:] # (V, 3, 3) + sample['view_ids'] = sample['view_ids'][start_idx:] + sample['affine_mats'] = sample['affine_mats'][start_idx:] # ! in world space + + sample['scale_mat'] = torch.from_numpy(scale_mat) + sample['trans_mat'] = torch.from_numpy(w2c_ref_inv) + + # - generate rays + if ('val' in self.split) or ('test' in self.split): + sample_rays = gen_rays_from_single_image( + img_wh[1], img_wh[0], + sample['query_image'], + sample['query_intrinsic'], + sample['query_c2w'], + depth=sample['query_depth'], + mask=sample['query_mask'] if self.clean_image else None) + else: + sample_rays = gen_random_rays_from_single_image( + img_wh[1], img_wh[0], + self.N_rays, + sample['query_image'], + sample['query_intrinsic'], + sample['query_c2w'], + depth=sample['query_depth'], + mask=sample['query_mask'] if self.clean_image else None, + dilated_mask=mask_dilated, + importance_sample=self.importance_sample) + + + sample['rays'] = sample_rays + + return sample diff --git a/SparseNeuS_demo_v1/data/blender_general_narrow_all_eval_new_data.py b/SparseNeuS_demo_v1/data/blender_general_narrow_all_eval_new_data.py new file mode 100644 index 0000000000000000000000000000000000000000..194cf007f54d2d377ce6561050f82e38dc246e73 --- /dev/null +++ b/SparseNeuS_demo_v1/data/blender_general_narrow_all_eval_new_data.py @@ -0,0 +1,418 @@ +from torch.utils.data import Dataset +from utils.misc_utils import read_pfm +import os +import numpy as np +import cv2 +from PIL import Image +import torch +from torchvision import transforms as T +from data.scene import get_boundingbox + +from models.rays import gen_rays_from_single_image, gen_random_rays_from_single_image +import json +from termcolor import colored +import imageio +from kornia import create_meshgrid +import open3d as o3d + + +def get_ray_directions(H, W, focal, center=None): + """ + Get ray directions for all pixels in camera coordinate. + Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/ + ray-tracing-generating-camera-rays/standard-coordinate-systems + Inputs: + H, W, focal: image height, width and focal length + Outputs: + directions: (H, W, 3), the direction of the rays in camera coordinate + """ + grid = create_meshgrid(H, W, normalized_coordinates=False)[0] + 0.5 # 1xHxWx2 + + i, j = grid.unbind(-1) + # the direction here is without +0.5 pixel centering as calibration is not so accurate + # see https://github.com/bmild/nerf/issues/24 + cent = center if center is not None else [W / 2, H / 2] + directions = torch.stack([(i - cent[0]) / focal[0], (j - cent[1]) / focal[1], torch.ones_like(i)], -1) # (H, W, 3) + + return directions + +def load_K_Rt_from_P(filename, P=None): + if P is None: + lines = open(filename).read().splitlines() + if len(lines) == 4: + lines = lines[1:] + lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)] + P = np.asarray(lines).astype(np.float32).squeeze() + + out = cv2.decomposeProjectionMatrix(P) + K = out[0] + R = out[1] + t = out[2] + + K = K / K[2, 2] + intrinsics = np.eye(4) + intrinsics[:3, :3] = K + + pose = np.eye(4, dtype=np.float32) + pose[:3, :3] = R.transpose() # ? why need transpose here + pose[:3, 3] = (t[:3] / t[3])[:, 0] + + return intrinsics, pose # ! return cam2world matrix here + + +# ! load one ref-image with multiple src-images in camera coordinate system +class BlenderPerView(Dataset): + def __init__(self, root_dir, split, n_views=3, img_wh=(256, 256), downSample=1.0, + split_filepath=None, pair_filepath=None, + N_rays=512, + vol_dims=[128, 128, 128], batch_size=1, + clean_image=False, importance_sample=False, test_ref_views=[], + specific_dataset_name = 'GSO' + ): + + # print("root_dir: ", root_dir) + self.root_dir = root_dir + self.split = split + # self.specific_dataset_name = 'Realfusion' + # self.specific_dataset_name = 'GSO' + # self.specific_dataset_name = 'Objaverse' + # self.specific_dataset_name = 'Zero123' + + self.specific_dataset_name = specific_dataset_name + self.n_views = n_views + self.N_rays = N_rays + self.batch_size = batch_size # - used for construct new metas for gru fusion training + + self.clean_image = clean_image + self.importance_sample = importance_sample + self.test_ref_views = test_ref_views # used for testing + self.scale_factor = 1.0 + self.scale_mat = np.float32(np.diag([1, 1, 1, 1.0])) + assert self.split == 'val' or 'export_mesh', 'only support val or export_mesh' + # find all subfolders + main_folder = os.path.join(root_dir, self.specific_dataset_name) + self.shape_list = [""] # os.listdir(main_folder) # MODIFIED + self.shape_list.sort() + + # self.shape_list = ['barrel_render'] + # self.shape_list = ["barrel", "bag", "mailbox", "shoe", "chair", "car", "dog", "teddy"] # TO BE DELETED + + + self.lvis_paths = [] + for shape_name in self.shape_list: + self.lvis_paths.append(os.path.join(main_folder, shape_name)) + + # print("lvis_paths: ", self.lvis_paths) + + if img_wh is not None: + assert img_wh[0] % 32 == 0 and img_wh[1] % 32 == 0, \ + 'img_wh must both be multiples of 32!' + + + # * bounding box for rendering + self.bbox_min = np.array([-1.0, -1.0, -1.0]) + self.bbox_max = np.array([1.0, 1.0, 1.0]) + + # - used for cost volume regularization + self.voxel_dims = torch.tensor(vol_dims, dtype=torch.float32) + self.partial_vol_origin = torch.tensor([-1., -1., -1.], dtype=torch.float32) + + + def define_transforms(self): + self.transform = T.Compose([T.ToTensor()]) + + + + def load_cam_info(self): + for vid, img_id in enumerate(self.img_ids): + intrinsic, extrinsic, near_far = self.intrinsic, np.linalg.inv(self.c2ws[vid]), self.near_far + self.all_intrinsics.append(intrinsic) + self.all_extrinsics.append(extrinsic) + self.all_near_fars.append(near_far) + + def read_depth(self, filename): + pass + + def read_mask(self, filename): + mask_h = cv2.imread(filename, 0) + mask_h = cv2.resize(mask_h, None, fx=self.downSample, fy=self.downSample, + interpolation=cv2.INTER_NEAREST) + mask = cv2.resize(mask_h, None, fx=0.25, fy=0.25, + interpolation=cv2.INTER_NEAREST) + + mask[mask > 0] = 1 # the masks stored in png are not binary + mask_h[mask_h > 0] = 1 + + return mask, mask_h + + def cal_scale_mat(self, img_hw, intrinsics, extrinsics, near_fars, factor=1.): + + center, radius, bounds = get_boundingbox(img_hw, intrinsics, extrinsics, near_fars) + + radius = radius * factor + scale_mat = np.diag([radius, radius, radius, 1.0]) + scale_mat[:3, 3] = center.cpu().numpy() + scale_mat = scale_mat.astype(np.float32) + + return scale_mat, 1. / radius.cpu().numpy() + + def __len__(self): + # return 8*len(self.lvis_paths) + return len(self.lvis_paths) + + + def read_depth(self, filename, near_bound, noisy_factor=1.0): + pass + + + def __getitem__(self, idx): + sample = {} + idx = idx * 8 # to be deleted + origin_idx = idx + imgs, depths_h, masks_h = [], [], [] # full size (256, 256) + intrinsics, w2cs, c2ws, near_fars = [], [], [], [] # record proj-mats between views + + folder_path = self.lvis_paths[idx//8] + idx = idx % 8 # [0, 7] + + # last subdir name + shape_name = os.path.split(folder_path)[-1] + + pose_json_path = os.path.join(folder_path, "pose.json") + with open(pose_json_path, 'r') as f: + meta = json.load(f) + + self.img_ids = list(meta["c2ws"].keys()) # e.g. "view_0", "view_7", "view_0_2_10" + self.img_wh = (256, 256) + self.input_poses = np.array(list(meta["c2ws"].values())) + intrinsic = np.eye(4) + intrinsic[:3, :3] = np.array(meta["intrinsics"]) + self.intrinsic = intrinsic + self.near_far = np.array(meta["near_far"]) + self.near_far[1] = 1.8 + self.define_transforms() + self.blender2opencv = np.array( + [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]] + ) + + self.c2ws = [] + self.w2cs = [] + self.near_fars = [] + # self.root_dir = root_dir + for image_dix, img_id in enumerate(self.img_ids): + pose = self.input_poses[image_dix] + c2w = pose @ self.blender2opencv + self.c2ws.append(c2w) + self.w2cs.append(np.linalg.inv(c2w)) + self.near_fars.append(self.near_far) + self.c2ws = np.stack(self.c2ws, axis=0) + self.w2cs = np.stack(self.w2cs, axis=0) + + + self.all_intrinsics = [] # the cam info of the whole scene + self.all_extrinsics = [] + self.all_near_fars = [] + self.load_cam_info() + + + # target view + c2w = self.c2ws[idx] + w2c = np.linalg.inv(c2w) + w2c_ref = w2c + w2c_ref_inv = np.linalg.inv(w2c_ref) + + w2cs.append(w2c @ w2c_ref_inv) + c2ws.append(np.linalg.inv(w2c @ w2c_ref_inv)) + + # img_filename = os.path.join(folder_path, 'stage1_8_debug', f'{self.img_ids[idx]}') + img_filename = os.path.join(folder_path, 'stage1_8', f'{self.img_ids[idx]}') + + img = Image.open(img_filename) + img = self.transform(img) # (4, h, w) + + + if img.shape[0] == 4: + img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB + imgs += [img] + + + depth_h = torch.ones((img.shape[1], img.shape[2]), dtype=torch.float32) + depth_h = depth_h.fill_(-1.0) + mask_h = torch.ones((img.shape[1], img.shape[2]), dtype=torch.int32) + + + depths_h.append(depth_h) + masks_h.append(mask_h) + + intrinsic = self.intrinsic + intrinsics.append(intrinsic) + + + near_fars.append(self.near_fars[idx]) + image_perm = 0 # only supervised on reference view + + mask_dilated = None + + + src_views = range(8, 8 + 8 * 4) + + for vid in src_views: + + # img_filename = os.path.join(folder_path, 'stage2_8_debug', f'{self.img_ids[vid]}') + img_filename = os.path.join(folder_path, 'stage2_8', f'{self.img_ids[vid]}') + img = Image.open(img_filename) + img_wh = self.img_wh + + img = self.transform(img) + if img.shape[0] == 4: + img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB + + imgs += [img] + depth_h = np.ones(img.shape[1:], dtype=np.float32) + depths_h.append(depth_h) + masks_h.append(np.ones(img.shape[1:], dtype=np.int32)) + + near_fars.append(self.all_near_fars[vid]) + intrinsics.append(self.all_intrinsics[vid]) + + w2cs.append(self.all_extrinsics[vid] @ w2c_ref_inv) + + + # ! estimate scale_mat + scale_mat, scale_factor = self.cal_scale_mat( + img_hw=[img_wh[1], img_wh[0]], + intrinsics=intrinsics, extrinsics=w2cs, + near_fars=near_fars, factor=1.1 + ) + + + new_near_fars = [] + new_w2cs = [] + new_c2ws = [] + new_affine_mats = [] + new_depths_h = [] + for intrinsic, extrinsic, near_far, depth in zip(intrinsics, w2cs, near_fars, depths_h): + + P = intrinsic @ extrinsic @ scale_mat + P = P[:3, :4] + # - should use load_K_Rt_from_P() to obtain c2w + c2w = load_K_Rt_from_P(None, P)[1] + w2c = np.linalg.inv(c2w) + new_w2cs.append(w2c) + new_c2ws.append(c2w) + affine_mat = np.eye(4) + affine_mat[:3, :4] = intrinsic[:3, :3] @ w2c[:3, :4] + new_affine_mats.append(affine_mat) + + camera_o = c2w[:3, 3] + dist = np.sqrt(np.sum(camera_o ** 2)) + near = dist - 1 + far = dist + 1 + + new_near_fars.append([0.95 * near, 1.05 * far]) + new_depths_h.append(depth * scale_factor) + + # print(new_near_fars) + imgs = torch.stack(imgs).float() + depths_h = np.stack(new_depths_h) + masks_h = np.stack(masks_h) + + affine_mats = np.stack(new_affine_mats) + intrinsics, w2cs, c2ws, near_fars = np.stack(intrinsics), np.stack(new_w2cs), np.stack(new_c2ws), np.stack( + new_near_fars) + + if self.split == 'train': + start_idx = 0 + else: + start_idx = 1 + + + + target_w2cs = [] + target_intrinsics = [] + new_target_w2cs = [] + for i_idx in range(8): + target_w2cs.append(self.all_extrinsics[i_idx] @ w2c_ref_inv) + target_intrinsics.append(self.all_intrinsics[i_idx]) + + for intrinsic, extrinsic in zip(target_intrinsics, target_w2cs): + + P = intrinsic @ extrinsic @ scale_mat + P = P[:3, :4] + # - should use load_K_Rt_from_P() to obtain c2w + c2w = load_K_Rt_from_P(None, P)[1] + w2c = np.linalg.inv(c2w) + new_target_w2cs.append(w2c) + target_w2cs = np.stack(new_target_w2cs) + + + + view_ids = [idx] + list(src_views) + sample['origin_idx'] = origin_idx + sample['images'] = imgs # (V, 3, H, W) + sample['depths_h'] = torch.from_numpy(depths_h.astype(np.float32)) # (V, H, W) + sample['masks_h'] = torch.from_numpy(masks_h.astype(np.float32)) # (V, H, W) + sample['w2cs'] = torch.from_numpy(w2cs.astype(np.float32)) # (V, 4, 4) + sample['c2ws'] = torch.from_numpy(c2ws.astype(np.float32)) # (V, 4, 4) + sample['target_candidate_w2cs'] = torch.from_numpy(target_w2cs.astype(np.float32)) # (8, 4, 4) + sample['near_fars'] = torch.from_numpy(near_fars.astype(np.float32)) # (V, 2) + sample['intrinsics'] = torch.from_numpy(intrinsics.astype(np.float32))[:, :3, :3] # (V, 3, 3) + sample['view_ids'] = torch.from_numpy(np.array(view_ids)) + sample['affine_mats'] = torch.from_numpy(affine_mats.astype(np.float32)) # ! in world space + + # sample['light_idx'] = torch.tensor(light_idx) + sample['scan'] = shape_name + + sample['scale_factor'] = torch.tensor(scale_factor) + sample['img_wh'] = torch.from_numpy(np.array(img_wh)) + sample['render_img_idx'] = torch.tensor(image_perm) + sample['partial_vol_origin'] = self.partial_vol_origin + sample['meta'] = str(self.specific_dataset_name) + '_' + str(shape_name) + "_refview" + str(view_ids[0]) + # print("meta: ", sample['meta']) + + # - image to render + sample['query_image'] = sample['images'][0] + sample['query_c2w'] = sample['c2ws'][0] + sample['query_w2c'] = sample['w2cs'][0] + sample['query_intrinsic'] = sample['intrinsics'][0] + sample['query_depth'] = sample['depths_h'][0] + sample['query_mask'] = sample['masks_h'][0] + sample['query_near_far'] = sample['near_fars'][0] + + sample['images'] = sample['images'][start_idx:] # (V, 3, H, W) + sample['depths_h'] = sample['depths_h'][start_idx:] # (V, H, W) + sample['masks_h'] = sample['masks_h'][start_idx:] # (V, H, W) + sample['w2cs'] = sample['w2cs'][start_idx:] # (V, 4, 4) + sample['c2ws'] = sample['c2ws'][start_idx:] # (V, 4, 4) + sample['intrinsics'] = sample['intrinsics'][start_idx:] # (V, 3, 3) + sample['view_ids'] = sample['view_ids'][start_idx:] + sample['affine_mats'] = sample['affine_mats'][start_idx:] # ! in world space + + sample['scale_mat'] = torch.from_numpy(scale_mat) + sample['trans_mat'] = torch.from_numpy(w2c_ref_inv) + + # - generate rays + if ('val' in self.split) or ('test' in self.split): + sample_rays = gen_rays_from_single_image( + img_wh[1], img_wh[0], + sample['query_image'], + sample['query_intrinsic'], + sample['query_c2w'], + depth=sample['query_depth'], + mask=sample['query_mask'] if self.clean_image else None) + else: + sample_rays = gen_random_rays_from_single_image( + img_wh[1], img_wh[0], + self.N_rays, + sample['query_image'], + sample['query_intrinsic'], + sample['query_c2w'], + depth=sample['query_depth'], + mask=sample['query_mask'] if self.clean_image else None, + dilated_mask=mask_dilated, + importance_sample=self.importance_sample) + + + sample['rays'] = sample_rays + + return sample diff --git a/SparseNeuS_demo_v1/data/blender_general_narrow_all_eval_new_data3_1.py b/SparseNeuS_demo_v1/data/blender_general_narrow_all_eval_new_data3_1.py new file mode 100644 index 0000000000000000000000000000000000000000..7ce059be019a360b193c526c358057ffc9b48d1a --- /dev/null +++ b/SparseNeuS_demo_v1/data/blender_general_narrow_all_eval_new_data3_1.py @@ -0,0 +1,414 @@ +from torch.utils.data import Dataset +from utils.misc_utils import read_pfm +import os +import numpy as np +import cv2 +from PIL import Image +import torch +from torchvision import transforms as T +from data.scene import get_boundingbox + +from models.rays import gen_rays_from_single_image, gen_random_rays_from_single_image +import json +from termcolor import colored +import imageio +from kornia import create_meshgrid +import open3d as o3d +def get_ray_directions(H, W, focal, center=None): + """ + Get ray directions for all pixels in camera coordinate. + Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/ + ray-tracing-generating-camera-rays/standard-coordinate-systems + Inputs: + H, W, focal: image height, width and focal length + Outputs: + directions: (H, W, 3), the direction of the rays in camera coordinate + """ + grid = create_meshgrid(H, W, normalized_coordinates=False)[0] + 0.5 # 1xHxWx2 + + i, j = grid.unbind(-1) + # the direction here is without +0.5 pixel centering as calibration is not so accurate + # see https://github.com/bmild/nerf/issues/24 + cent = center if center is not None else [W / 2, H / 2] + directions = torch.stack([(i - cent[0]) / focal[0], (j - cent[1]) / focal[1], torch.ones_like(i)], -1) # (H, W, 3) + + return directions + +def load_K_Rt_from_P(filename, P=None): + if P is None: + lines = open(filename).read().splitlines() + if len(lines) == 4: + lines = lines[1:] + lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)] + P = np.asarray(lines).astype(np.float32).squeeze() + + out = cv2.decomposeProjectionMatrix(P) + K = out[0] + R = out[1] + t = out[2] + + K = K / K[2, 2] + intrinsics = np.eye(4) + intrinsics[:3, :3] = K + + pose = np.eye(4, dtype=np.float32) + pose[:3, :3] = R.transpose() # ? why need transpose here + pose[:3, 3] = (t[:3] / t[3])[:, 0] + + return intrinsics, pose # ! return cam2world matrix here + + +# ! load one ref-image with multiple src-images in camera coordinate system +class BlenderPerView(Dataset): + def __init__(self, root_dir, split, n_views=3, img_wh=(256, 256), downSample=1.0, + split_filepath=None, pair_filepath=None, + N_rays=512, + vol_dims=[128, 128, 128], batch_size=1, + clean_image=False, importance_sample=False, test_ref_views=[]): + + # print("root_dir: ", root_dir) + self.root_dir = root_dir + self.split = split + # self.specific_dataset_name = 'Realfusion' + self.specific_dataset_name = 'Objaverse' + self.n_views = n_views + self.N_rays = N_rays + self.batch_size = batch_size # - used for construct new metas for gru fusion training + + self.clean_image = clean_image + self.importance_sample = importance_sample + self.test_ref_views = test_ref_views # used for testing + self.scale_factor = 1.0 + self.scale_mat = np.float32(np.diag([1, 1, 1, 1.0])) + assert self.split == 'val' or 'export_mesh', 'only support val or export_mesh' + # find all subfolders + main_folder = os.path.join(root_dir, self.specific_dataset_name) + self.shape_list = os.listdir(main_folder) + self.shape_list.sort() + + # self.shape_list = ["barrel", "bag", "mailbox", "shoe", "chair", "car", "dog", "teddy"] # TO BE DELETED + + + self.lvis_paths = [] + for shape_name in self.shape_list: + self.lvis_paths.append(os.path.join(main_folder, shape_name)) + + # print("lvis_paths: ", self.lvis_paths) + + if img_wh is not None: + assert img_wh[0] % 32 == 0 and img_wh[1] % 32 == 0, \ + 'img_wh must both be multiples of 32!' + + + # * bounding box for rendering + self.bbox_min = np.array([-1.0, -1.0, -1.0]) + self.bbox_max = np.array([1.0, 1.0, 1.0]) + + # - used for cost volume regularization + self.voxel_dims = torch.tensor(vol_dims, dtype=torch.float32) + self.partial_vol_origin = torch.tensor([-1., -1., -1.], dtype=torch.float32) + + + def define_transforms(self): + self.transform = T.Compose([T.ToTensor()]) + + + + def load_cam_info(self): + for vid, img_id in enumerate(self.img_ids): + intrinsic, extrinsic, near_far = self.intrinsic, np.linalg.inv(self.c2ws[vid]), self.near_far + self.all_intrinsics.append(intrinsic) + self.all_extrinsics.append(extrinsic) + self.all_near_fars.append(near_far) + + def read_depth(self, filename): + pass + + def read_mask(self, filename): + mask_h = cv2.imread(filename, 0) + mask_h = cv2.resize(mask_h, None, fx=self.downSample, fy=self.downSample, + interpolation=cv2.INTER_NEAREST) + mask = cv2.resize(mask_h, None, fx=0.25, fy=0.25, + interpolation=cv2.INTER_NEAREST) + + mask[mask > 0] = 1 # the masks stored in png are not binary + mask_h[mask_h > 0] = 1 + + return mask, mask_h + + def cal_scale_mat(self, img_hw, intrinsics, extrinsics, near_fars, factor=1.): + + center, radius, bounds = get_boundingbox(img_hw, intrinsics, extrinsics, near_fars) + + radius = radius * factor + scale_mat = np.diag([radius, radius, radius, 1.0]) + scale_mat[:3, 3] = center.cpu().numpy() + scale_mat = scale_mat.astype(np.float32) + + return scale_mat, 1. / radius.cpu().numpy() + + def __len__(self): + # return 8*len(self.lvis_paths) + return len(self.lvis_paths) + + + def read_depth(self, filename, near_bound, noisy_factor=1.0): + pass + + + def __getitem__(self, idx): + sample = {} + idx = idx * 8 # to be deleted + origin_idx = idx + imgs, depths_h, masks_h = [], [], [] # full size (256, 256) + intrinsics, w2cs, c2ws, near_fars = [], [], [], [] # record proj mats between views + + + folder_path = self.lvis_paths[idx//8] + idx = idx % 8 # [0, 7] + + # last subdir name + shape_name = os.path.split(folder_path)[-1] + + + pose_json_path = os.path.join(folder_path, "pose.json") + with open(pose_json_path, 'r') as f: + meta = json.load(f) + + self.img_ids = list(meta["c2ws"].keys()) # e.g. "view_0", "view_7", "view_0_2_10" + self.img_wh = (256, 256) + self.input_poses = np.array(list(meta["c2ws"].values())) + intrinsic = np.eye(4) + intrinsic[:3, :3] = np.array(meta["intrinsics"]) + self.intrinsic = intrinsic + self.near_far = np.array(meta["near_far"]) + self.near_far[1] = 1.8 + self.define_transforms() + self.blender2opencv = np.array( + [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]] + ) + + + self.c2ws = [] + self.w2cs = [] + self.near_fars = [] + # self.root_dir = root_dir + for image_dix, img_id in enumerate(self.img_ids): + pose = self.input_poses[image_dix] + c2w = pose @ self.blender2opencv + self.c2ws.append(c2w) + self.w2cs.append(np.linalg.inv(c2w)) + self.near_fars.append(self.near_far) + self.c2ws = np.stack(self.c2ws, axis=0) + self.w2cs = np.stack(self.w2cs, axis=0) + + + self.all_intrinsics = [] # the cam info of the whole scene + self.all_extrinsics = [] + self.all_near_fars = [] + self.load_cam_info() + + + # target view + c2w = self.c2ws[idx] + w2c = np.linalg.inv(c2w) + w2c_ref = w2c + w2c_ref_inv = np.linalg.inv(w2c_ref) + + w2cs.append(w2c @ w2c_ref_inv) + c2ws.append(np.linalg.inv(w2c @ w2c_ref_inv)) + + img_filename = os.path.join(folder_path, 'stage1_8', f'{self.img_ids[idx]}') + # print(self.img_ids) + img = Image.open(img_filename) + img = self.transform(img) # (4, h, w) + + + if img.shape[0] == 4: + img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB + imgs += [img] + + + depth_h = torch.ones((img.shape[1], img.shape[2]), dtype=torch.float32) + depth_h = depth_h.fill_(-1.0) + mask_h = torch.ones((img.shape[1], img.shape[2]), dtype=torch.int32) + + + depths_h.append(depth_h) + masks_h.append(mask_h) + + intrinsic = self.intrinsic + intrinsics.append(intrinsic) + + + near_fars.append(self.near_fars[idx]) + image_perm = 0 # only supervised on reference view + + mask_dilated = None + + # src_views = range(8+idx*4, 8+(idx+1)*4) + src_views = range(8, 8 + 8 * 4) + + for vid in src_views: + if vid % 4 == 0: + vid = (vid - 8) // 4 + img_filename = os.path.join(folder_path, 'stage1_8', f'{self.img_ids[vid]}') + else: + img_filename = os.path.join(folder_path, 'stage2_8', f'{self.img_ids[vid]}') + + img = Image.open(img_filename) + img_wh = self.img_wh + + img = self.transform(img) + if img.shape[0] == 4: + img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB + + imgs += [img] + depth_h = np.ones(img.shape[1:], dtype=np.float32) + depths_h.append(depth_h) + masks_h.append(np.ones(img.shape[1:], dtype=np.int32)) + + near_fars.append(self.all_near_fars[vid]) + intrinsics.append(self.all_intrinsics[vid]) + + w2cs.append(self.all_extrinsics[vid] @ w2c_ref_inv) + + + # ! estimate scale_mat + scale_mat, scale_factor = self.cal_scale_mat( + img_hw=[img_wh[1], img_wh[0]], + intrinsics=intrinsics, extrinsics=w2cs, + near_fars=near_fars, factor=1.1 + ) + + + new_near_fars = [] + new_w2cs = [] + new_c2ws = [] + new_affine_mats = [] + new_depths_h = [] + for intrinsic, extrinsic, near_far, depth in zip(intrinsics, w2cs, near_fars, depths_h): + + P = intrinsic @ extrinsic @ scale_mat + P = P[:3, :4] + # - should use load_K_Rt_from_P() to obtain c2w + c2w = load_K_Rt_from_P(None, P)[1] + w2c = np.linalg.inv(c2w) + new_w2cs.append(w2c) + new_c2ws.append(c2w) + affine_mat = np.eye(4) + affine_mat[:3, :4] = intrinsic[:3, :3] @ w2c[:3, :4] + new_affine_mats.append(affine_mat) + + camera_o = c2w[:3, 3] + dist = np.sqrt(np.sum(camera_o ** 2)) + near = dist - 1 + far = dist + 1 + + new_near_fars.append([0.95 * near, 1.05 * far]) + new_depths_h.append(depth * scale_factor) + + # print(new_near_fars) + imgs = torch.stack(imgs).float() + depths_h = np.stack(new_depths_h) + masks_h = np.stack(masks_h) + + affine_mats = np.stack(new_affine_mats) + intrinsics, w2cs, c2ws, near_fars = np.stack(intrinsics), np.stack(new_w2cs), np.stack(new_c2ws), np.stack( + new_near_fars) + + if self.split == 'train': + start_idx = 0 + else: + start_idx = 1 + + + + target_w2cs = [] + target_intrinsics = [] + new_target_w2cs = [] + for i_idx in range(8): + target_w2cs.append(self.all_extrinsics[i_idx] @ w2c_ref_inv) + target_intrinsics.append(self.all_intrinsics[i_idx]) + + for intrinsic, extrinsic in zip(target_intrinsics, target_w2cs): + + P = intrinsic @ extrinsic @ scale_mat + P = P[:3, :4] + # - should use load_K_Rt_from_P() to obtain c2w + c2w = load_K_Rt_from_P(None, P)[1] + w2c = np.linalg.inv(c2w) + new_target_w2cs.append(w2c) + target_w2cs = np.stack(new_target_w2cs) + + + + view_ids = [idx] + list(src_views) + sample['origin_idx'] = origin_idx + sample['images'] = imgs # (V, 3, H, W) + sample['depths_h'] = torch.from_numpy(depths_h.astype(np.float32)) # (V, H, W) + sample['masks_h'] = torch.from_numpy(masks_h.astype(np.float32)) # (V, H, W) + sample['w2cs'] = torch.from_numpy(w2cs.astype(np.float32)) # (V, 4, 4) + sample['c2ws'] = torch.from_numpy(c2ws.astype(np.float32)) # (V, 4, 4) + sample['target_candidate_w2cs'] = torch.from_numpy(target_w2cs.astype(np.float32)) # (8, 4, 4) + sample['near_fars'] = torch.from_numpy(near_fars.astype(np.float32)) # (V, 2) + sample['intrinsics'] = torch.from_numpy(intrinsics.astype(np.float32))[:, :3, :3] # (V, 3, 3) + sample['view_ids'] = torch.from_numpy(np.array(view_ids)) + sample['affine_mats'] = torch.from_numpy(affine_mats.astype(np.float32)) # ! in world space + + # sample['light_idx'] = torch.tensor(light_idx) + sample['scan'] = shape_name + + sample['scale_factor'] = torch.tensor(scale_factor) + sample['img_wh'] = torch.from_numpy(np.array(img_wh)) + sample['render_img_idx'] = torch.tensor(image_perm) + sample['partial_vol_origin'] = self.partial_vol_origin + sample['meta'] = str(self.specific_dataset_name) + '_' + str(shape_name) + "_refview" + str(view_ids[0]) + # print("meta: ", sample['meta']) + + # - image to render + sample['query_image'] = sample['images'][0] + sample['query_c2w'] = sample['c2ws'][0] + sample['query_w2c'] = sample['w2cs'][0] + sample['query_intrinsic'] = sample['intrinsics'][0] + sample['query_depth'] = sample['depths_h'][0] + sample['query_mask'] = sample['masks_h'][0] + sample['query_near_far'] = sample['near_fars'][0] + + sample['images'] = sample['images'][start_idx:] # (V, 3, H, W) + sample['depths_h'] = sample['depths_h'][start_idx:] # (V, H, W) + sample['masks_h'] = sample['masks_h'][start_idx:] # (V, H, W) + sample['w2cs'] = sample['w2cs'][start_idx:] # (V, 4, 4) + sample['c2ws'] = sample['c2ws'][start_idx:] # (V, 4, 4) + sample['intrinsics'] = sample['intrinsics'][start_idx:] # (V, 3, 3) + sample['view_ids'] = sample['view_ids'][start_idx:] + sample['affine_mats'] = sample['affine_mats'][start_idx:] # ! in world space + + sample['scale_mat'] = torch.from_numpy(scale_mat) + sample['trans_mat'] = torch.from_numpy(w2c_ref_inv) + + # - generate rays + if ('val' in self.split) or ('test' in self.split): + sample_rays = gen_rays_from_single_image( + img_wh[1], img_wh[0], + sample['query_image'], + sample['query_intrinsic'], + sample['query_c2w'], + depth=sample['query_depth'], + mask=sample['query_mask'] if self.clean_image else None) + else: + sample_rays = gen_random_rays_from_single_image( + img_wh[1], img_wh[0], + self.N_rays, + sample['query_image'], + sample['query_intrinsic'], + sample['query_c2w'], + depth=sample['query_depth'], + mask=sample['query_mask'] if self.clean_image else None, + dilated_mask=mask_dilated, + importance_sample=self.importance_sample) + + + sample['rays'] = sample_rays + + return sample diff --git a/SparseNeuS_demo_v1/data/blender_general_narrow_all_eval_new_data_32_wide.py b/SparseNeuS_demo_v1/data/blender_general_narrow_all_eval_new_data_32_wide.py new file mode 100644 index 0000000000000000000000000000000000000000..f69ece26bdd88955bf5612f2f6f66ae7f9262e19 --- /dev/null +++ b/SparseNeuS_demo_v1/data/blender_general_narrow_all_eval_new_data_32_wide.py @@ -0,0 +1,465 @@ +from torch.utils.data import Dataset +from utils.misc_utils import read_pfm +import os +import numpy as np +import cv2 +from PIL import Image +import torch +from torchvision import transforms as T +from data.scene import get_boundingbox + +from models.rays import gen_rays_from_single_image, gen_random_rays_from_single_image +import json +from termcolor import colored +import imageio +from kornia import create_meshgrid +import open3d as o3d + +def calc_pose(phis, thetas, size, radius = 1.2): + import torch + def normalize(vectors): + return vectors / (torch.norm(vectors, dim=-1, keepdim=True) + 1e-10) + # device = torch.device('cuda') + thetas = torch.FloatTensor(thetas) + phis = torch.FloatTensor(phis) + + centers = torch.stack([ + radius * torch.sin(thetas) * torch.sin(phis), + -radius * torch.cos(thetas) * torch.sin(phis), + radius * torch.cos(phis), + ], dim=-1) # [B, 3] + + # lookat + forward_vector = normalize(centers).squeeze(0) + up_vector = torch.FloatTensor([0, 0, 1]).unsqueeze(0).repeat(size, 1) + right_vector = normalize(torch.cross(up_vector, forward_vector, dim=-1)) + if right_vector.pow(2).sum() < 0.01: + right_vector = torch.FloatTensor([0, 1, 0]).unsqueeze(0).repeat(size, 1) + up_vector = normalize(torch.cross(forward_vector, right_vector, dim=-1)) + + poses = torch.eye(4, dtype=torch.float)[:3].unsqueeze(0).repeat(size, 1, 1) + poses[:, :3, :3] = torch.stack((right_vector, up_vector, forward_vector), dim=-1) + poses[:, :3, 3] = centers + return poses + +def get_ray_directions(H, W, focal, center=None): + """ + Get ray directions for all pixels in camera coordinate. + Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/ + ray-tracing-generating-camera-rays/standard-coordinate-systems + Inputs: + H, W, focal: image height, width and focal length + Outputs: + directions: (H, W, 3), the direction of the rays in camera coordinate + """ + grid = create_meshgrid(H, W, normalized_coordinates=False)[0] + 0.5 # 1xHxWx2 + + i, j = grid.unbind(-1) + # the direction here is without +0.5 pixel centering as calibration is not so accurate + # see https://github.com/bmild/nerf/issues/24 + cent = center if center is not None else [W / 2, H / 2] + directions = torch.stack([(i - cent[0]) / focal[0], (j - cent[1]) / focal[1], torch.ones_like(i)], -1) # (H, W, 3) + + return directions + +def load_K_Rt_from_P(filename, P=None): + if P is None: + lines = open(filename).read().splitlines() + if len(lines) == 4: + lines = lines[1:] + lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)] + P = np.asarray(lines).astype(np.float32).squeeze() + + out = cv2.decomposeProjectionMatrix(P) + K = out[0] + R = out[1] + t = out[2] + + K = K / K[2, 2] + intrinsics = np.eye(4) + intrinsics[:3, :3] = K + + pose = np.eye(4, dtype=np.float32) + pose[:3, :3] = R.transpose() # ? why need transpose here + pose[:3, 3] = (t[:3] / t[3])[:, 0] + + return intrinsics, pose # ! return cam2world matrix here + + +# ! load one ref-image with multiple src-images in camera coordinate system +class BlenderPerView(Dataset): + def __init__(self, root_dir, split, n_views=3, img_wh=(256, 256), downSample=1.0, + split_filepath=None, pair_filepath=None, + N_rays=512, + vol_dims=[128, 128, 128], batch_size=1, + clean_image=False, importance_sample=False, test_ref_views=[], + specific_dataset_name = 'GSO' + ): + + # print("root_dir: ", root_dir) + self.root_dir = root_dir + self.split = split + # self.specific_dataset_name = 'Realfusion' + # self.specific_dataset_name = 'GSO' + # self.specific_dataset_name = 'Objaverse' + # self.specific_dataset_name = 'Zero123' + + self.specific_dataset_name = specific_dataset_name + self.n_views = n_views + self.N_rays = N_rays + self.batch_size = batch_size # - used for construct new metas for gru fusion training + + self.clean_image = clean_image + self.importance_sample = importance_sample + self.test_ref_views = test_ref_views # used for testing + self.scale_factor = 1.0 + self.scale_mat = np.float32(np.diag([1, 1, 1, 1.0])) + assert self.split == 'val' or 'export_mesh', 'only support val or export_mesh' + # find all subfolders + main_folder = os.path.join(root_dir) + self.shape_list = os.listdir(main_folder) + self.shape_list.sort() + + # self.shape_list = ['barrel_render'] + # self.shape_list = ["barrel", "bag", "mailbox", "shoe", "chair", "car", "dog", "teddy"] # TO BE DELETED + + + self.lvis_paths = [] + for shape_name in self.shape_list: + self.lvis_paths.append(os.path.join(main_folder, shape_name)) + + # print("lvis_paths: ", self.lvis_paths) + + if img_wh is not None: + assert img_wh[0] % 32 == 0 and img_wh[1] % 32 == 0, \ + 'img_wh must both be multiples of 32!' + + pose_json_path = "/objaverse-processed/zero12345_img/zero12345_narrow_pose.json" + + with open(pose_json_path, 'r') as f: + meta = json.load(f) + intrinsic = np.eye(4) + intrinsic[:3, :3] = np.array(meta["intrinsics"]) + self.intrinsic = intrinsic + self.near_far = np.array(meta["near_far"]) + self.near_far[1] = 1.8 + + # * bounding box for rendering + self.bbox_min = np.array([-1.0, -1.0, -1.0]) + self.bbox_max = np.array([1.0, 1.0, 1.0]) + + # - used for cost volume regularization + self.voxel_dims = torch.tensor(vol_dims, dtype=torch.float32) + self.partial_vol_origin = torch.tensor([-1., -1., -1.], dtype=torch.float32) + + + def define_transforms(self): + self.transform = T.Compose([T.ToTensor()]) + + + + def load_cam_info(self): + for vid in range(self.input_poses.shape[0]): + intrinsic, extrinsic, near_far = self.intrinsic, np.linalg.inv(self.c2ws[vid]), self.near_far + self.all_intrinsics.append(intrinsic) + self.all_extrinsics.append(extrinsic) + self.all_near_fars.append(near_far) + + def read_depth(self, filename): + pass + + def read_mask(self, filename): + mask_h = cv2.imread(filename, 0) + mask_h = cv2.resize(mask_h, None, fx=self.downSample, fy=self.downSample, + interpolation=cv2.INTER_NEAREST) + mask = cv2.resize(mask_h, None, fx=0.25, fy=0.25, + interpolation=cv2.INTER_NEAREST) + + mask[mask > 0] = 1 # the masks stored in png are not binary + mask_h[mask_h > 0] = 1 + + return mask, mask_h + + def cal_scale_mat(self, img_hw, intrinsics, extrinsics, near_fars, factor=1.): + + center, radius, bounds = get_boundingbox(img_hw, intrinsics, extrinsics, near_fars) + + radius = radius * factor + scale_mat = np.diag([radius, radius, radius, 1.0]) + scale_mat[:3, 3] = center.cpu().numpy() + scale_mat = scale_mat.astype(np.float32) + + return scale_mat, 1. / radius.cpu().numpy() + + def __len__(self): + # return 8*len(self.lvis_paths) + return len(self.lvis_paths) + + + def read_depth(self, filename, near_bound, noisy_factor=1.0): + pass + + + def __getitem__(self, idx): + sample = {} + idx = idx * 8 # to be deleted + origin_idx = idx + imgs, depths_h, masks_h = [], [], [] # full size (256, 256) + intrinsics, w2cs, c2ws, near_fars = [], [], [], [] # record proj-mats between views + + folder_path = self.lvis_paths[idx//8] + idx = idx % 8 # [0, 7] + + # last subdir name + shape_name = os.path.split(folder_path)[-1] + + # pose_json_path = os.path.join(folder_path, "pose.json") + # with open(pose_json_path, 'r') as f: + # meta = json.load(f) + + # self.img_ids = list(meta["c2ws"].keys()) # e.g. "view_0", "view_7", "view_0_2_10" + # self.img_wh = (256, 256) + # self.input_poses = np.array(list(meta["c2ws"].values())) + # intrinsic = np.eye(4) + # intrinsic[:3, :3] = np.array(meta["intrinsics"]) + # self.intrinsic = intrinsic + # self.near_far = np.array(meta["near_far"]) + # self.near_far[1] = 1.8 + # self.define_transforms() + # self.blender2opencv = np.array( + # [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]] + # ) + + pose_file = os.path.join(folder_path, '32_random', 'views.npz') + pose_array = np.load(pose_file) + pose = calc_pose(pose_array['elevations'], pose_array['azimuths'], 32) # [32, 3, 4] c2ws + + self.img_wh = (256, 256) + self.input_poses = np.array(pose) + self.input_poses = np.concatenate([self.input_poses, np.tile(np.array([0, 0, 0, 1], dtype=np.float32)[None, None, :], [self.input_poses.shape[0], 1, 1])], axis=1) + self.define_transforms() + self.blender2opencv = np.array( + [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]] + ) + + self.c2ws = [] + self.w2cs = [] + self.near_fars = [] + # self.root_dir = root_dir + for image_dix in range(pose.shape[0]): + pose = self.input_poses[image_dix] + c2w = pose @ self.blender2opencv + self.c2ws.append(c2w) + self.w2cs.append(np.linalg.inv(c2w)) + self.near_fars.append(self.near_far) + self.c2ws = np.stack(self.c2ws, axis=0) + self.w2cs = np.stack(self.w2cs, axis=0) + + + self.all_intrinsics = [] # the cam info of the whole scene + self.all_extrinsics = [] + self.all_near_fars = [] + self.load_cam_info() + + + # target view + c2w = self.c2ws[idx] + w2c = np.linalg.inv(c2w) + w2c_ref = w2c + w2c_ref_inv = np.linalg.inv(w2c_ref) + + w2cs.append(w2c @ w2c_ref_inv) + c2ws.append(np.linalg.inv(w2c @ w2c_ref_inv)) + + # img_filename = os.path.join(folder_path, 'stage1_8_debug', f'{self.img_ids[idx]}') + img_filename = os.path.join(folder_path, '32_random', f'{idx}.png') + + img = Image.open(img_filename) + img = self.transform(img) # (4, h, w) + + + if img.shape[0] == 4: + img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB + imgs += [img] + + + depth_h = torch.ones((img.shape[1], img.shape[2]), dtype=torch.float32) + depth_h = depth_h.fill_(-1.0) + mask_h = torch.ones((img.shape[1], img.shape[2]), dtype=torch.int32) + + + depths_h.append(depth_h) + masks_h.append(mask_h) + + intrinsic = self.intrinsic + intrinsics.append(intrinsic) + + + near_fars.append(self.near_fars[idx]) + image_perm = 0 # only supervised on reference view + + mask_dilated = None + + + src_views = range(0, 8 * 4) + + for vid in src_views: + + # img_filename = os.path.join(folder_path, 'stage2_8_debug', f'{self.img_ids[vid]}') + img_filename = os.path.join(folder_path, '32_random', f'{vid}.png') + img = Image.open(img_filename) + img_wh = self.img_wh + + img = self.transform(img) + if img.shape[0] == 4: + img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB + + imgs += [img] + depth_h = np.ones(img.shape[1:], dtype=np.float32) + depths_h.append(depth_h) + masks_h.append(np.ones(img.shape[1:], dtype=np.int32)) + + near_fars.append(self.all_near_fars[vid]) + intrinsics.append(self.all_intrinsics[vid]) + + w2cs.append(self.all_extrinsics[vid] @ w2c_ref_inv) + + + # ! estimate scale_mat + scale_mat, scale_factor = self.cal_scale_mat( + img_hw=[img_wh[1], img_wh[0]], + intrinsics=intrinsics, extrinsics=w2cs, + near_fars=near_fars, factor=1.1 + ) + + + new_near_fars = [] + new_w2cs = [] + new_c2ws = [] + new_affine_mats = [] + new_depths_h = [] + for intrinsic, extrinsic, near_far, depth in zip(intrinsics, w2cs, near_fars, depths_h): + + P = intrinsic @ extrinsic @ scale_mat + P = P[:3, :4] + # - should use load_K_Rt_from_P() to obtain c2w + c2w = load_K_Rt_from_P(None, P)[1] + w2c = np.linalg.inv(c2w) + new_w2cs.append(w2c) + new_c2ws.append(c2w) + affine_mat = np.eye(4) + affine_mat[:3, :4] = intrinsic[:3, :3] @ w2c[:3, :4] + new_affine_mats.append(affine_mat) + + camera_o = c2w[:3, 3] + dist = np.sqrt(np.sum(camera_o ** 2)) + near = dist - 1 + far = dist + 1 + + new_near_fars.append([0.95 * near, 1.05 * far]) + new_depths_h.append(depth * scale_factor) + + # print(new_near_fars) + imgs = torch.stack(imgs).float() + depths_h = np.stack(new_depths_h) + masks_h = np.stack(masks_h) + + affine_mats = np.stack(new_affine_mats) + intrinsics, w2cs, c2ws, near_fars = np.stack(intrinsics), np.stack(new_w2cs), np.stack(new_c2ws), np.stack( + new_near_fars) + + if self.split == 'train': + start_idx = 0 + else: + start_idx = 1 + + + + target_w2cs = [] + target_intrinsics = [] + new_target_w2cs = [] + for i_idx in range(8): + target_w2cs.append(self.all_extrinsics[i_idx] @ w2c_ref_inv) + target_intrinsics.append(self.all_intrinsics[i_idx]) + + for intrinsic, extrinsic in zip(target_intrinsics, target_w2cs): + + P = intrinsic @ extrinsic @ scale_mat + P = P[:3, :4] + # - should use load_K_Rt_from_P() to obtain c2w + c2w = load_K_Rt_from_P(None, P)[1] + w2c = np.linalg.inv(c2w) + new_target_w2cs.append(w2c) + target_w2cs = np.stack(new_target_w2cs) + + + + view_ids = [idx] + list(src_views) + sample['origin_idx'] = origin_idx + sample['images'] = imgs # (V, 3, H, W) + sample['depths_h'] = torch.from_numpy(depths_h.astype(np.float32)) # (V, H, W) + sample['masks_h'] = torch.from_numpy(masks_h.astype(np.float32)) # (V, H, W) + sample['w2cs'] = torch.from_numpy(w2cs.astype(np.float32)) # (V, 4, 4) + sample['c2ws'] = torch.from_numpy(c2ws.astype(np.float32)) # (V, 4, 4) + sample['target_candidate_w2cs'] = torch.from_numpy(target_w2cs.astype(np.float32)) # (8, 4, 4) + sample['near_fars'] = torch.from_numpy(near_fars.astype(np.float32)) # (V, 2) + sample['intrinsics'] = torch.from_numpy(intrinsics.astype(np.float32))[:, :3, :3] # (V, 3, 3) + sample['view_ids'] = torch.from_numpy(np.array(view_ids)) + sample['affine_mats'] = torch.from_numpy(affine_mats.astype(np.float32)) # ! in world space + + # sample['light_idx'] = torch.tensor(light_idx) + sample['scan'] = shape_name + + sample['scale_factor'] = torch.tensor(scale_factor) + sample['img_wh'] = torch.from_numpy(np.array(img_wh)) + sample['render_img_idx'] = torch.tensor(image_perm) + sample['partial_vol_origin'] = self.partial_vol_origin + sample['meta'] = str(self.specific_dataset_name) + '_' + str(shape_name) + "_refview" + str(view_ids[0]) + # print("meta: ", sample['meta']) + + # - image to render + sample['query_image'] = sample['images'][0] + sample['query_c2w'] = sample['c2ws'][0] + sample['query_w2c'] = sample['w2cs'][0] + sample['query_intrinsic'] = sample['intrinsics'][0] + sample['query_depth'] = sample['depths_h'][0] + sample['query_mask'] = sample['masks_h'][0] + sample['query_near_far'] = sample['near_fars'][0] + + sample['images'] = sample['images'][start_idx:] # (V, 3, H, W) + sample['depths_h'] = sample['depths_h'][start_idx:] # (V, H, W) + sample['masks_h'] = sample['masks_h'][start_idx:] # (V, H, W) + sample['w2cs'] = sample['w2cs'][start_idx:] # (V, 4, 4) + sample['c2ws'] = sample['c2ws'][start_idx:] # (V, 4, 4) + sample['intrinsics'] = sample['intrinsics'][start_idx:] # (V, 3, 3) + sample['view_ids'] = sample['view_ids'][start_idx:] + sample['affine_mats'] = sample['affine_mats'][start_idx:] # ! in world space + + sample['scale_mat'] = torch.from_numpy(scale_mat) + sample['trans_mat'] = torch.from_numpy(w2c_ref_inv) + + # - generate rays + if ('val' in self.split) or ('test' in self.split): + sample_rays = gen_rays_from_single_image( + img_wh[1], img_wh[0], + sample['query_image'], + sample['query_intrinsic'], + sample['query_c2w'], + depth=sample['query_depth'], + mask=sample['query_mask'] if self.clean_image else None) + else: + sample_rays = gen_random_rays_from_single_image( + img_wh[1], img_wh[0], + self.N_rays, + sample['query_image'], + sample['query_intrinsic'], + sample['query_c2w'], + depth=sample['query_depth'], + mask=sample['query_mask'] if self.clean_image else None, + dilated_mask=mask_dilated, + importance_sample=self.importance_sample) + + + sample['rays'] = sample_rays + + return sample diff --git a/SparseNeuS_demo_v1/data/blender_general_narrow_all_eval_new_data_4_4.py b/SparseNeuS_demo_v1/data/blender_general_narrow_all_eval_new_data_4_4.py new file mode 100644 index 0000000000000000000000000000000000000000..6263a9ff47edc8f7b65600786c244fafb809240b --- /dev/null +++ b/SparseNeuS_demo_v1/data/blender_general_narrow_all_eval_new_data_4_4.py @@ -0,0 +1,419 @@ +from torch.utils.data import Dataset +from utils.misc_utils import read_pfm +import os +import numpy as np +import cv2 +from PIL import Image +import torch +from torchvision import transforms as T +from data.scene import get_boundingbox + +from models.rays import gen_rays_from_single_image, gen_random_rays_from_single_image +import json +from termcolor import colored +import imageio +from kornia import create_meshgrid +import open3d as o3d + + +def get_ray_directions(H, W, focal, center=None): + """ + Get ray directions for all pixels in camera coordinate. + Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/ + ray-tracing-generating-camera-rays/standard-coordinate-systems + Inputs: + H, W, focal: image height, width and focal length + Outputs: + directions: (H, W, 3), the direction of the rays in camera coordinate + """ + grid = create_meshgrid(H, W, normalized_coordinates=False)[0] + 0.5 # 1xHxWx2 + + i, j = grid.unbind(-1) + # the direction here is without +0.5 pixel centering as calibration is not so accurate + # see https://github.com/bmild/nerf/issues/24 + cent = center if center is not None else [W / 2, H / 2] + directions = torch.stack([(i - cent[0]) / focal[0], (j - cent[1]) / focal[1], torch.ones_like(i)], -1) # (H, W, 3) + + return directions + +def load_K_Rt_from_P(filename, P=None): + if P is None: + lines = open(filename).read().splitlines() + if len(lines) == 4: + lines = lines[1:] + lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)] + P = np.asarray(lines).astype(np.float32).squeeze() + + out = cv2.decomposeProjectionMatrix(P) + K = out[0] + R = out[1] + t = out[2] + + K = K / K[2, 2] + intrinsics = np.eye(4) + intrinsics[:3, :3] = K + + pose = np.eye(4, dtype=np.float32) + pose[:3, :3] = R.transpose() # ? why need transpose here + pose[:3, 3] = (t[:3] / t[3])[:, 0] + + return intrinsics, pose # ! return cam2world matrix here + + +# ! load one ref-image with multiple src-images in camera coordinate system +class BlenderPerView(Dataset): + def __init__(self, root_dir, split, n_views=3, img_wh=(256, 256), downSample=1.0, + split_filepath=None, pair_filepath=None, + N_rays=512, + vol_dims=[128, 128, 128], batch_size=1, + clean_image=False, importance_sample=False, test_ref_views=[], + specific_dataset_name = 'GSO' + ): + + # print("root_dir: ", root_dir) + self.root_dir = root_dir + self.split = split + # self.specific_dataset_name = 'Realfusion' + # self.specific_dataset_name = 'GSO' + # self.specific_dataset_name = 'Objaverse' + # self.specific_dataset_name = 'Zero123' + + self.specific_dataset_name = specific_dataset_name + self.n_views = n_views + self.N_rays = N_rays + self.batch_size = batch_size # - used for construct new metas for gru fusion training + + self.clean_image = clean_image + self.importance_sample = importance_sample + self.test_ref_views = test_ref_views # used for testing + self.scale_factor = 1.0 + self.scale_mat = np.float32(np.diag([1, 1, 1, 1.0])) + assert self.split == 'val' or 'export_mesh', 'only support val or export_mesh' + # find all subfolders + main_folder = os.path.join(root_dir, self.specific_dataset_name) + self.shape_list = os.listdir(main_folder) + self.shape_list.sort() + + # self.shape_list = ['barrel_render'] + # self.shape_list = ["barrel", "bag", "mailbox", "shoe", "chair", "car", "dog", "teddy"] # TO BE DELETED + + + self.lvis_paths = [] + for shape_name in self.shape_list: + self.lvis_paths.append(os.path.join(main_folder, shape_name)) + + # print("lvis_paths: ", self.lvis_paths) + + if img_wh is not None: + assert img_wh[0] % 32 == 0 and img_wh[1] % 32 == 0, \ + 'img_wh must both be multiples of 32!' + + + # * bounding box for rendering + self.bbox_min = np.array([-1.0, -1.0, -1.0]) + self.bbox_max = np.array([1.0, 1.0, 1.0]) + + # - used for cost volume regularization + self.voxel_dims = torch.tensor(vol_dims, dtype=torch.float32) + self.partial_vol_origin = torch.tensor([-1., -1., -1.], dtype=torch.float32) + + + def define_transforms(self): + self.transform = T.Compose([T.ToTensor()]) + + + + def load_cam_info(self): + for vid, img_id in enumerate(self.img_ids): + intrinsic, extrinsic, near_far = self.intrinsic, np.linalg.inv(self.c2ws[vid]), self.near_far + self.all_intrinsics.append(intrinsic) + self.all_extrinsics.append(extrinsic) + self.all_near_fars.append(near_far) + + def read_depth(self, filename): + pass + + def read_mask(self, filename): + mask_h = cv2.imread(filename, 0) + mask_h = cv2.resize(mask_h, None, fx=self.downSample, fy=self.downSample, + interpolation=cv2.INTER_NEAREST) + mask = cv2.resize(mask_h, None, fx=0.25, fy=0.25, + interpolation=cv2.INTER_NEAREST) + + mask[mask > 0] = 1 # the masks stored in png are not binary + mask_h[mask_h > 0] = 1 + + return mask, mask_h + + def cal_scale_mat(self, img_hw, intrinsics, extrinsics, near_fars, factor=1.): + + center, radius, bounds = get_boundingbox(img_hw, intrinsics, extrinsics, near_fars) + + radius = radius * factor + scale_mat = np.diag([radius, radius, radius, 1.0]) + scale_mat[:3, 3] = center.cpu().numpy() + scale_mat = scale_mat.astype(np.float32) + + return scale_mat, 1. / radius.cpu().numpy() + + def __len__(self): + # return 8*len(self.lvis_paths) + return len(self.lvis_paths) + + + def read_depth(self, filename, near_bound, noisy_factor=1.0): + pass + + + def __getitem__(self, idx): + sample = {} + idx = idx * 8 # to be deleted + origin_idx = idx + imgs, depths_h, masks_h = [], [], [] # full size (256, 256) + intrinsics, w2cs, c2ws, near_fars = [], [], [], [] # record proj-mats between views + + folder_path = self.lvis_paths[idx//8] + idx = idx % 8 # [0, 7] + + # last subdir name + shape_name = os.path.split(folder_path)[-1] + + pose_json_path = os.path.join(folder_path, "pose.json") + with open(pose_json_path, 'r') as f: + meta = json.load(f) + + self.img_ids = list(meta["c2ws"].keys()) # e.g. "view_0", "view_7", "view_0_2_10" + self.img_wh = (256, 256) + self.input_poses = np.array(list(meta["c2ws"].values())) + intrinsic = np.eye(4) + intrinsic[:3, :3] = np.array(meta["intrinsics"]) + self.intrinsic = intrinsic + self.near_far = np.array(meta["near_far"]) + self.near_far[1] = 1.8 + self.define_transforms() + self.blender2opencv = np.array( + [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]] + ) + + self.c2ws = [] + self.w2cs = [] + self.near_fars = [] + # self.root_dir = root_dir + for image_dix, img_id in enumerate(self.img_ids): + pose = self.input_poses[image_dix] + c2w = pose @ self.blender2opencv + self.c2ws.append(c2w) + self.w2cs.append(np.linalg.inv(c2w)) + self.near_fars.append(self.near_far) + self.c2ws = np.stack(self.c2ws, axis=0) + self.w2cs = np.stack(self.w2cs, axis=0) + + + self.all_intrinsics = [] # the cam info of the whole scene + self.all_extrinsics = [] + self.all_near_fars = [] + self.load_cam_info() + + + # target view + c2w = self.c2ws[idx] + w2c = np.linalg.inv(c2w) + w2c_ref = w2c + w2c_ref_inv = np.linalg.inv(w2c_ref) + + w2cs.append(w2c @ w2c_ref_inv) + c2ws.append(np.linalg.inv(w2c @ w2c_ref_inv)) + + # img_filename = os.path.join(folder_path, 'stage1_8_debug', f'{self.img_ids[idx]}') + img_filename = os.path.join(folder_path, 'stage1_8', f'{self.img_ids[idx]}') + + img = Image.open(img_filename) + img = self.transform(img) # (4, h, w) + + + if img.shape[0] == 4: + img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB + imgs += [img] + + + depth_h = torch.ones((img.shape[1], img.shape[2]), dtype=torch.float32) + depth_h = depth_h.fill_(-1.0) + mask_h = torch.ones((img.shape[1], img.shape[2]), dtype=torch.int32) + + + depths_h.append(depth_h) + masks_h.append(mask_h) + + intrinsic = self.intrinsic + intrinsics.append(intrinsic) + + + near_fars.append(self.near_fars[idx]) + image_perm = 0 # only supervised on reference view + + mask_dilated = None + + + src_views = range(8, 8 + 8 * 4) + + for vid in src_views: + if (vid // 4) % 2 != 0: + continue + # img_filename = os.path.join(folder_path, 'stage2_8_debug', f'{self.img_ids[vid]}') + img_filename = os.path.join(folder_path, 'stage2_8', f'{self.img_ids[vid]}') + img = Image.open(img_filename) + img_wh = self.img_wh + + img = self.transform(img) + if img.shape[0] == 4: + img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB + + imgs += [img] + depth_h = np.ones(img.shape[1:], dtype=np.float32) + depths_h.append(depth_h) + masks_h.append(np.ones(img.shape[1:], dtype=np.int32)) + + near_fars.append(self.all_near_fars[vid]) + intrinsics.append(self.all_intrinsics[vid]) + + w2cs.append(self.all_extrinsics[vid] @ w2c_ref_inv) + + + # ! estimate scale_mat + scale_mat, scale_factor = self.cal_scale_mat( + img_hw=[img_wh[1], img_wh[0]], + intrinsics=intrinsics, extrinsics=w2cs, + near_fars=near_fars, factor=1.1 + ) + + + new_near_fars = [] + new_w2cs = [] + new_c2ws = [] + new_affine_mats = [] + new_depths_h = [] + for intrinsic, extrinsic, near_far, depth in zip(intrinsics, w2cs, near_fars, depths_h): + + P = intrinsic @ extrinsic @ scale_mat + P = P[:3, :4] + # - should use load_K_Rt_from_P() to obtain c2w + c2w = load_K_Rt_from_P(None, P)[1] + w2c = np.linalg.inv(c2w) + new_w2cs.append(w2c) + new_c2ws.append(c2w) + affine_mat = np.eye(4) + affine_mat[:3, :4] = intrinsic[:3, :3] @ w2c[:3, :4] + new_affine_mats.append(affine_mat) + + camera_o = c2w[:3, 3] + dist = np.sqrt(np.sum(camera_o ** 2)) + near = dist - 1 + far = dist + 1 + + new_near_fars.append([0.95 * near, 1.05 * far]) + new_depths_h.append(depth * scale_factor) + + # print(new_near_fars) + imgs = torch.stack(imgs).float() + depths_h = np.stack(new_depths_h) + masks_h = np.stack(masks_h) + + affine_mats = np.stack(new_affine_mats) + intrinsics, w2cs, c2ws, near_fars = np.stack(intrinsics), np.stack(new_w2cs), np.stack(new_c2ws), np.stack( + new_near_fars) + + if self.split == 'train': + start_idx = 0 + else: + start_idx = 1 + + + + target_w2cs = [] + target_intrinsics = [] + new_target_w2cs = [] + for i_idx in range(8): + target_w2cs.append(self.all_extrinsics[i_idx] @ w2c_ref_inv) + target_intrinsics.append(self.all_intrinsics[i_idx]) + + for intrinsic, extrinsic in zip(target_intrinsics, target_w2cs): + + P = intrinsic @ extrinsic @ scale_mat + P = P[:3, :4] + # - should use load_K_Rt_from_P() to obtain c2w + c2w = load_K_Rt_from_P(None, P)[1] + w2c = np.linalg.inv(c2w) + new_target_w2cs.append(w2c) + target_w2cs = np.stack(new_target_w2cs) + + + + view_ids = [idx] + list(src_views) + sample['origin_idx'] = origin_idx + sample['images'] = imgs # (V, 3, H, W) + sample['depths_h'] = torch.from_numpy(depths_h.astype(np.float32)) # (V, H, W) + sample['masks_h'] = torch.from_numpy(masks_h.astype(np.float32)) # (V, H, W) + sample['w2cs'] = torch.from_numpy(w2cs.astype(np.float32)) # (V, 4, 4) + sample['c2ws'] = torch.from_numpy(c2ws.astype(np.float32)) # (V, 4, 4) + sample['target_candidate_w2cs'] = torch.from_numpy(target_w2cs.astype(np.float32)) # (8, 4, 4) + sample['near_fars'] = torch.from_numpy(near_fars.astype(np.float32)) # (V, 2) + sample['intrinsics'] = torch.from_numpy(intrinsics.astype(np.float32))[:, :3, :3] # (V, 3, 3) + sample['view_ids'] = torch.from_numpy(np.array(view_ids)) + sample['affine_mats'] = torch.from_numpy(affine_mats.astype(np.float32)) # ! in world space + + # sample['light_idx'] = torch.tensor(light_idx) + sample['scan'] = shape_name + + sample['scale_factor'] = torch.tensor(scale_factor) + sample['img_wh'] = torch.from_numpy(np.array(img_wh)) + sample['render_img_idx'] = torch.tensor(image_perm) + sample['partial_vol_origin'] = self.partial_vol_origin + sample['meta'] = str(self.specific_dataset_name) + '_' + str(shape_name) + "_refview" + str(view_ids[0]) + # print("meta: ", sample['meta']) + + # - image to render + sample['query_image'] = sample['images'][0] + sample['query_c2w'] = sample['c2ws'][0] + sample['query_w2c'] = sample['w2cs'][0] + sample['query_intrinsic'] = sample['intrinsics'][0] + sample['query_depth'] = sample['depths_h'][0] + sample['query_mask'] = sample['masks_h'][0] + sample['query_near_far'] = sample['near_fars'][0] + + sample['images'] = sample['images'][start_idx:] # (V, 3, H, W) + sample['depths_h'] = sample['depths_h'][start_idx:] # (V, H, W) + sample['masks_h'] = sample['masks_h'][start_idx:] # (V, H, W) + sample['w2cs'] = sample['w2cs'][start_idx:] # (V, 4, 4) + sample['c2ws'] = sample['c2ws'][start_idx:] # (V, 4, 4) + sample['intrinsics'] = sample['intrinsics'][start_idx:] # (V, 3, 3) + sample['view_ids'] = sample['view_ids'][start_idx:] + sample['affine_mats'] = sample['affine_mats'][start_idx:] # ! in world space + + sample['scale_mat'] = torch.from_numpy(scale_mat) + sample['trans_mat'] = torch.from_numpy(w2c_ref_inv) + + # - generate rays + if ('val' in self.split) or ('test' in self.split): + sample_rays = gen_rays_from_single_image( + img_wh[1], img_wh[0], + sample['query_image'], + sample['query_intrinsic'], + sample['query_c2w'], + depth=sample['query_depth'], + mask=sample['query_mask'] if self.clean_image else None) + else: + sample_rays = gen_random_rays_from_single_image( + img_wh[1], img_wh[0], + self.N_rays, + sample['query_image'], + sample['query_intrinsic'], + sample['query_c2w'], + depth=sample['query_depth'], + mask=sample['query_mask'] if self.clean_image else None, + dilated_mask=mask_dilated, + importance_sample=self.importance_sample) + + + sample['rays'] = sample_rays + + return sample diff --git a/SparseNeuS_demo_v1/data/blender_general_narrow_all_eval_new_data_6_4.py b/SparseNeuS_demo_v1/data/blender_general_narrow_all_eval_new_data_6_4.py new file mode 100644 index 0000000000000000000000000000000000000000..c88c0d9b37402f970d9b2d7686b774943366e9a8 --- /dev/null +++ b/SparseNeuS_demo_v1/data/blender_general_narrow_all_eval_new_data_6_4.py @@ -0,0 +1,420 @@ +from torch.utils.data import Dataset +from utils.misc_utils import read_pfm +import os +import numpy as np +import cv2 +from PIL import Image +import torch +from torchvision import transforms as T +from data.scene import get_boundingbox + +from models.rays import gen_rays_from_single_image, gen_random_rays_from_single_image +import json +from termcolor import colored +import imageio +from kornia import create_meshgrid +import open3d as o3d + + +def get_ray_directions(H, W, focal, center=None): + """ + Get ray directions for all pixels in camera coordinate. + Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/ + ray-tracing-generating-camera-rays/standard-coordinate-systems + Inputs: + H, W, focal: image height, width and focal length + Outputs: + directions: (H, W, 3), the direction of the rays in camera coordinate + """ + grid = create_meshgrid(H, W, normalized_coordinates=False)[0] + 0.5 # 1xHxWx2 + + i, j = grid.unbind(-1) + # the direction here is without +0.5 pixel centering as calibration is not so accurate + # see https://github.com/bmild/nerf/issues/24 + cent = center if center is not None else [W / 2, H / 2] + directions = torch.stack([(i - cent[0]) / focal[0], (j - cent[1]) / focal[1], torch.ones_like(i)], -1) # (H, W, 3) + + return directions + +def load_K_Rt_from_P(filename, P=None): + if P is None: + lines = open(filename).read().splitlines() + if len(lines) == 4: + lines = lines[1:] + lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)] + P = np.asarray(lines).astype(np.float32).squeeze() + + out = cv2.decomposeProjectionMatrix(P) + K = out[0] + R = out[1] + t = out[2] + + K = K / K[2, 2] + intrinsics = np.eye(4) + intrinsics[:3, :3] = K + + pose = np.eye(4, dtype=np.float32) + pose[:3, :3] = R.transpose() # ? why need transpose here + pose[:3, 3] = (t[:3] / t[3])[:, 0] + + return intrinsics, pose # ! return cam2world matrix here + + +# ! load one ref-image with multiple src-images in camera coordinate system +class BlenderPerView(Dataset): + def __init__(self, root_dir, split, n_views=3, img_wh=(256, 256), downSample=1.0, + split_filepath=None, pair_filepath=None, + N_rays=512, + vol_dims=[128, 128, 128], batch_size=1, + clean_image=False, importance_sample=False, test_ref_views=[], + specific_dataset_name = 'GSO' + ): + + # print("root_dir: ", root_dir) + self.root_dir = root_dir + self.split = split + # self.specific_dataset_name = 'Realfusion' + # self.specific_dataset_name = 'GSO' + # self.specific_dataset_name = 'Objaverse' + # self.specific_dataset_name = 'Zero123' + + self.specific_dataset_name = specific_dataset_name + self.n_views = n_views + self.N_rays = N_rays + self.batch_size = batch_size # - used for construct new metas for gru fusion training + + self.clean_image = clean_image + self.importance_sample = importance_sample + self.test_ref_views = test_ref_views # used for testing + self.scale_factor = 1.0 + self.scale_mat = np.float32(np.diag([1, 1, 1, 1.0])) + assert self.split == 'val' or 'export_mesh', 'only support val or export_mesh' + # find all subfolders + main_folder = os.path.join(root_dir, self.specific_dataset_name) + self.shape_list = os.listdir(main_folder) + self.shape_list.sort() + + # self.shape_list = ['barrel_render'] + # self.shape_list = ["barrel", "bag", "mailbox", "shoe", "chair", "car", "dog", "teddy"] # TO BE DELETED + + + self.lvis_paths = [] + for shape_name in self.shape_list: + self.lvis_paths.append(os.path.join(main_folder, shape_name)) + + # print("lvis_paths: ", self.lvis_paths) + + if img_wh is not None: + assert img_wh[0] % 32 == 0 and img_wh[1] % 32 == 0, \ + 'img_wh must both be multiples of 32!' + + + # * bounding box for rendering + self.bbox_min = np.array([-1.0, -1.0, -1.0]) + self.bbox_max = np.array([1.0, 1.0, 1.0]) + + # - used for cost volume regularization + self.voxel_dims = torch.tensor(vol_dims, dtype=torch.float32) + self.partial_vol_origin = torch.tensor([-1., -1., -1.], dtype=torch.float32) + + + def define_transforms(self): + self.transform = T.Compose([T.ToTensor()]) + + + + def load_cam_info(self): + for vid, img_id in enumerate(self.img_ids): + intrinsic, extrinsic, near_far = self.intrinsic, np.linalg.inv(self.c2ws[vid]), self.near_far + self.all_intrinsics.append(intrinsic) + self.all_extrinsics.append(extrinsic) + self.all_near_fars.append(near_far) + + def read_depth(self, filename): + pass + + def read_mask(self, filename): + mask_h = cv2.imread(filename, 0) + mask_h = cv2.resize(mask_h, None, fx=self.downSample, fy=self.downSample, + interpolation=cv2.INTER_NEAREST) + mask = cv2.resize(mask_h, None, fx=0.25, fy=0.25, + interpolation=cv2.INTER_NEAREST) + + mask[mask > 0] = 1 # the masks stored in png are not binary + mask_h[mask_h > 0] = 1 + + return mask, mask_h + + def cal_scale_mat(self, img_hw, intrinsics, extrinsics, near_fars, factor=1.): + + center, radius, bounds = get_boundingbox(img_hw, intrinsics, extrinsics, near_fars) + + radius = radius * factor + scale_mat = np.diag([radius, radius, radius, 1.0]) + scale_mat[:3, 3] = center.cpu().numpy() + scale_mat = scale_mat.astype(np.float32) + + return scale_mat, 1. / radius.cpu().numpy() + + def __len__(self): + # return 8*len(self.lvis_paths) + return len(self.lvis_paths) + + + def read_depth(self, filename, near_bound, noisy_factor=1.0): + pass + + + def __getitem__(self, idx): + sample = {} + idx = idx * 8 # to be deleted + origin_idx = idx + imgs, depths_h, masks_h = [], [], [] # full size (256, 256) + intrinsics, w2cs, c2ws, near_fars = [], [], [], [] # record proj-mats between views + + folder_path = self.lvis_paths[idx//8] + idx = idx % 8 # [0, 7] + + # last subdir name + shape_name = os.path.split(folder_path)[-1] + + pose_json_path = os.path.join(folder_path, "pose.json") + with open(pose_json_path, 'r') as f: + meta = json.load(f) + + self.img_ids = list(meta["c2ws"].keys()) # e.g. "view_0", "view_7", "view_0_2_10" + self.img_wh = (256, 256) + self.input_poses = np.array(list(meta["c2ws"].values())) + intrinsic = np.eye(4) + intrinsic[:3, :3] = np.array(meta["intrinsics"]) + self.intrinsic = intrinsic + self.near_far = np.array(meta["near_far"]) + self.near_far[1] = 1.8 + self.define_transforms() + self.blender2opencv = np.array( + [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]] + ) + + self.c2ws = [] + self.w2cs = [] + self.near_fars = [] + # self.root_dir = root_dir + for image_dix, img_id in enumerate(self.img_ids): + pose = self.input_poses[image_dix] + c2w = pose @ self.blender2opencv + self.c2ws.append(c2w) + self.w2cs.append(np.linalg.inv(c2w)) + self.near_fars.append(self.near_far) + self.c2ws = np.stack(self.c2ws, axis=0) + self.w2cs = np.stack(self.w2cs, axis=0) + + + self.all_intrinsics = [] # the cam info of the whole scene + self.all_extrinsics = [] + self.all_near_fars = [] + self.load_cam_info() + + + # target view + c2w = self.c2ws[idx] + w2c = np.linalg.inv(c2w) + w2c_ref = w2c + w2c_ref_inv = np.linalg.inv(w2c_ref) + + w2cs.append(w2c @ w2c_ref_inv) + c2ws.append(np.linalg.inv(w2c @ w2c_ref_inv)) + + # img_filename = os.path.join(folder_path, 'stage1_8_debug', f'{self.img_ids[idx]}') + img_filename = os.path.join(folder_path, 'stage1_8', f'{self.img_ids[idx]}') + + img = Image.open(img_filename) + img = self.transform(img) # (4, h, w) + + + if img.shape[0] == 4: + img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB + imgs += [img] + + + depth_h = torch.ones((img.shape[1], img.shape[2]), dtype=torch.float32) + depth_h = depth_h.fill_(-1.0) + mask_h = torch.ones((img.shape[1], img.shape[2]), dtype=torch.int32) + + + depths_h.append(depth_h) + masks_h.append(mask_h) + + intrinsic = self.intrinsic + intrinsics.append(intrinsic) + + + near_fars.append(self.near_fars[idx]) + image_perm = 0 # only supervised on reference view + + mask_dilated = None + + + src_views = range(8, 8 + 8 * 4) + + for vid in src_views: + if ((vid - 8) // 4 == 4) or ((vid - 8) // 4 == 6): + continue + + # img_filename = os.path.join(folder_path, 'stage2_8_debug', f'{self.img_ids[vid]}') + img_filename = os.path.join(folder_path, 'stage2_8', f'{self.img_ids[vid]}') + img = Image.open(img_filename) + img_wh = self.img_wh + + img = self.transform(img) + if img.shape[0] == 4: + img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB + + imgs += [img] + depth_h = np.ones(img.shape[1:], dtype=np.float32) + depths_h.append(depth_h) + masks_h.append(np.ones(img.shape[1:], dtype=np.int32)) + + near_fars.append(self.all_near_fars[vid]) + intrinsics.append(self.all_intrinsics[vid]) + + w2cs.append(self.all_extrinsics[vid] @ w2c_ref_inv) + + + # ! estimate scale_mat + scale_mat, scale_factor = self.cal_scale_mat( + img_hw=[img_wh[1], img_wh[0]], + intrinsics=intrinsics, extrinsics=w2cs, + near_fars=near_fars, factor=1.1 + ) + + + new_near_fars = [] + new_w2cs = [] + new_c2ws = [] + new_affine_mats = [] + new_depths_h = [] + for intrinsic, extrinsic, near_far, depth in zip(intrinsics, w2cs, near_fars, depths_h): + + P = intrinsic @ extrinsic @ scale_mat + P = P[:3, :4] + # - should use load_K_Rt_from_P() to obtain c2w + c2w = load_K_Rt_from_P(None, P)[1] + w2c = np.linalg.inv(c2w) + new_w2cs.append(w2c) + new_c2ws.append(c2w) + affine_mat = np.eye(4) + affine_mat[:3, :4] = intrinsic[:3, :3] @ w2c[:3, :4] + new_affine_mats.append(affine_mat) + + camera_o = c2w[:3, 3] + dist = np.sqrt(np.sum(camera_o ** 2)) + near = dist - 1 + far = dist + 1 + + new_near_fars.append([0.95 * near, 1.05 * far]) + new_depths_h.append(depth * scale_factor) + + # print(new_near_fars) + imgs = torch.stack(imgs).float() + depths_h = np.stack(new_depths_h) + masks_h = np.stack(masks_h) + + affine_mats = np.stack(new_affine_mats) + intrinsics, w2cs, c2ws, near_fars = np.stack(intrinsics), np.stack(new_w2cs), np.stack(new_c2ws), np.stack( + new_near_fars) + + if self.split == 'train': + start_idx = 0 + else: + start_idx = 1 + + + + target_w2cs = [] + target_intrinsics = [] + new_target_w2cs = [] + for i_idx in range(8): + target_w2cs.append(self.all_extrinsics[i_idx] @ w2c_ref_inv) + target_intrinsics.append(self.all_intrinsics[i_idx]) + + for intrinsic, extrinsic in zip(target_intrinsics, target_w2cs): + + P = intrinsic @ extrinsic @ scale_mat + P = P[:3, :4] + # - should use load_K_Rt_from_P() to obtain c2w + c2w = load_K_Rt_from_P(None, P)[1] + w2c = np.linalg.inv(c2w) + new_target_w2cs.append(w2c) + target_w2cs = np.stack(new_target_w2cs) + + + + view_ids = [idx] + list(src_views) + sample['origin_idx'] = origin_idx + sample['images'] = imgs # (V, 3, H, W) + sample['depths_h'] = torch.from_numpy(depths_h.astype(np.float32)) # (V, H, W) + sample['masks_h'] = torch.from_numpy(masks_h.astype(np.float32)) # (V, H, W) + sample['w2cs'] = torch.from_numpy(w2cs.astype(np.float32)) # (V, 4, 4) + sample['c2ws'] = torch.from_numpy(c2ws.astype(np.float32)) # (V, 4, 4) + sample['target_candidate_w2cs'] = torch.from_numpy(target_w2cs.astype(np.float32)) # (8, 4, 4) + sample['near_fars'] = torch.from_numpy(near_fars.astype(np.float32)) # (V, 2) + sample['intrinsics'] = torch.from_numpy(intrinsics.astype(np.float32))[:, :3, :3] # (V, 3, 3) + sample['view_ids'] = torch.from_numpy(np.array(view_ids)) + sample['affine_mats'] = torch.from_numpy(affine_mats.astype(np.float32)) # ! in world space + + # sample['light_idx'] = torch.tensor(light_idx) + sample['scan'] = shape_name + + sample['scale_factor'] = torch.tensor(scale_factor) + sample['img_wh'] = torch.from_numpy(np.array(img_wh)) + sample['render_img_idx'] = torch.tensor(image_perm) + sample['partial_vol_origin'] = self.partial_vol_origin + sample['meta'] = str(self.specific_dataset_name) + '_' + str(shape_name) + "_refview" + str(view_ids[0]) + # print("meta: ", sample['meta']) + + # - image to render + sample['query_image'] = sample['images'][0] + sample['query_c2w'] = sample['c2ws'][0] + sample['query_w2c'] = sample['w2cs'][0] + sample['query_intrinsic'] = sample['intrinsics'][0] + sample['query_depth'] = sample['depths_h'][0] + sample['query_mask'] = sample['masks_h'][0] + sample['query_near_far'] = sample['near_fars'][0] + + sample['images'] = sample['images'][start_idx:] # (V, 3, H, W) + sample['depths_h'] = sample['depths_h'][start_idx:] # (V, H, W) + sample['masks_h'] = sample['masks_h'][start_idx:] # (V, H, W) + sample['w2cs'] = sample['w2cs'][start_idx:] # (V, 4, 4) + sample['c2ws'] = sample['c2ws'][start_idx:] # (V, 4, 4) + sample['intrinsics'] = sample['intrinsics'][start_idx:] # (V, 3, 3) + sample['view_ids'] = sample['view_ids'][start_idx:] + sample['affine_mats'] = sample['affine_mats'][start_idx:] # ! in world space + + sample['scale_mat'] = torch.from_numpy(scale_mat) + sample['trans_mat'] = torch.from_numpy(w2c_ref_inv) + + # - generate rays + if ('val' in self.split) or ('test' in self.split): + sample_rays = gen_rays_from_single_image( + img_wh[1], img_wh[0], + sample['query_image'], + sample['query_intrinsic'], + sample['query_c2w'], + depth=sample['query_depth'], + mask=sample['query_mask'] if self.clean_image else None) + else: + sample_rays = gen_random_rays_from_single_image( + img_wh[1], img_wh[0], + self.N_rays, + sample['query_image'], + sample['query_intrinsic'], + sample['query_c2w'], + depth=sample['query_depth'], + mask=sample['query_mask'] if self.clean_image else None, + dilated_mask=mask_dilated, + importance_sample=self.importance_sample) + + + sample['rays'] = sample_rays + + return sample diff --git a/SparseNeuS_demo_v1/data/blender_general_narrow_all_eval_new_data_8_3.py b/SparseNeuS_demo_v1/data/blender_general_narrow_all_eval_new_data_8_3.py new file mode 100644 index 0000000000000000000000000000000000000000..512c3db02edc8e68208167b7d1715f1f67025cdf --- /dev/null +++ b/SparseNeuS_demo_v1/data/blender_general_narrow_all_eval_new_data_8_3.py @@ -0,0 +1,428 @@ +from torch.utils.data import Dataset +from utils.misc_utils import read_pfm +import os +import numpy as np +import cv2 +from PIL import Image +import torch +from torchvision import transforms as T +from data.scene import get_boundingbox + +from models.rays import gen_rays_from_single_image, gen_random_rays_from_single_image +import json +from termcolor import colored +import imageio +from kornia import create_meshgrid +import open3d as o3d + + +def get_ray_directions(H, W, focal, center=None): + """ + Get ray directions for all pixels in camera coordinate. + Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/ + ray-tracing-generating-camera-rays/standard-coordinate-systems + Inputs: + H, W, focal: image height, width and focal length + Outputs: + directions: (H, W, 3), the direction of the rays in camera coordinate + """ + grid = create_meshgrid(H, W, normalized_coordinates=False)[0] + 0.5 # 1xHxWx2 + + i, j = grid.unbind(-1) + # the direction here is without +0.5 pixel centering as calibration is not so accurate + # see https://github.com/bmild/nerf/issues/24 + cent = center if center is not None else [W / 2, H / 2] + directions = torch.stack([(i - cent[0]) / focal[0], (j - cent[1]) / focal[1], torch.ones_like(i)], -1) # (H, W, 3) + + return directions + +def load_K_Rt_from_P(filename, P=None): + if P is None: + lines = open(filename).read().splitlines() + if len(lines) == 4: + lines = lines[1:] + lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)] + P = np.asarray(lines).astype(np.float32).squeeze() + + out = cv2.decomposeProjectionMatrix(P) + K = out[0] + R = out[1] + t = out[2] + + K = K / K[2, 2] + intrinsics = np.eye(4) + intrinsics[:3, :3] = K + + pose = np.eye(4, dtype=np.float32) + pose[:3, :3] = R.transpose() # ? why need transpose here + pose[:3, 3] = (t[:3] / t[3])[:, 0] + + return intrinsics, pose # ! return cam2world matrix here + + +# ! load one ref-image with multiple src-images in camera coordinate system +class BlenderPerView(Dataset): + def __init__(self, root_dir, split, n_views=3, img_wh=(256, 256), downSample=1.0, + split_filepath=None, pair_filepath=None, + N_rays=512, + vol_dims=[128, 128, 128], batch_size=1, + clean_image=False, importance_sample=False, test_ref_views=[], + specific_dataset_name = 'GSO' + ): + + # print("root_dir: ", root_dir) + self.root_dir = root_dir + self.split = split + # self.specific_dataset_name = 'Realfusion' + # self.specific_dataset_name = 'GSO' + # self.specific_dataset_name = 'Objaverse' + # self.specific_dataset_name = 'Zero123' + + self.specific_dataset_name = specific_dataset_name + self.n_views = n_views + self.N_rays = N_rays + self.batch_size = batch_size # - used for construct new metas for gru fusion training + + self.clean_image = clean_image + self.importance_sample = importance_sample + self.test_ref_views = test_ref_views # used for testing + self.scale_factor = 1.0 + self.scale_mat = np.float32(np.diag([1, 1, 1, 1.0])) + assert self.split == 'val' or 'export_mesh', 'only support val or export_mesh' + # find all subfolders + main_folder = os.path.join(root_dir, self.specific_dataset_name) + self.shape_list = os.listdir(main_folder) + self.shape_list.sort() + + # self.shape_list = ['barrel_render'] + # self.shape_list = ["barrel", "bag", "mailbox", "shoe", "chair", "car", "dog", "teddy"] # TO BE DELETED + + + self.lvis_paths = [] + for shape_name in self.shape_list: + self.lvis_paths.append(os.path.join(main_folder, shape_name)) + + # print("lvis_paths: ", self.lvis_paths) + + if img_wh is not None: + assert img_wh[0] % 32 == 0 and img_wh[1] % 32 == 0, \ + 'img_wh must both be multiples of 32!' + + + # * bounding box for rendering + self.bbox_min = np.array([-1.0, -1.0, -1.0]) + self.bbox_max = np.array([1.0, 1.0, 1.0]) + + # - used for cost volume regularization + self.voxel_dims = torch.tensor(vol_dims, dtype=torch.float32) + self.partial_vol_origin = torch.tensor([-1., -1., -1.], dtype=torch.float32) + + + def define_transforms(self): + self.transform = T.Compose([T.ToTensor()]) + + + + def load_cam_info(self): + for vid, img_id in enumerate(self.img_ids): + intrinsic, extrinsic, near_far = self.intrinsic, np.linalg.inv(self.c2ws[vid]), self.near_far + self.all_intrinsics.append(intrinsic) + self.all_extrinsics.append(extrinsic) + self.all_near_fars.append(near_far) + + def read_depth(self, filename): + pass + + def read_mask(self, filename): + mask_h = cv2.imread(filename, 0) + mask_h = cv2.resize(mask_h, None, fx=self.downSample, fy=self.downSample, + interpolation=cv2.INTER_NEAREST) + mask = cv2.resize(mask_h, None, fx=0.25, fy=0.25, + interpolation=cv2.INTER_NEAREST) + + mask[mask > 0] = 1 # the masks stored in png are not binary + mask_h[mask_h > 0] = 1 + + return mask, mask_h + + def cal_scale_mat(self, img_hw, intrinsics, extrinsics, near_fars, factor=1.): + + center, radius, bounds = get_boundingbox(img_hw, intrinsics, extrinsics, near_fars) + + radius = radius * factor + scale_mat = np.diag([radius, radius, radius, 1.0]) + scale_mat[:3, 3] = center.cpu().numpy() + scale_mat = scale_mat.astype(np.float32) + + return scale_mat, 1. / radius.cpu().numpy() + + def __len__(self): + # return 8*len(self.lvis_paths) + return len(self.lvis_paths) + + + def read_depth(self, filename, near_bound, noisy_factor=1.0): + pass + + + def __getitem__(self, idx): + sample = {} + idx = idx * 8 # to be deleted + origin_idx = idx + imgs, depths_h, masks_h = [], [], [] # full size (256, 256) + intrinsics, w2cs, c2ws, near_fars = [], [], [], [] # record proj-mats between views + + folder_path = self.lvis_paths[idx//8] + idx = idx % 8 # [0, 7] + + # last subdir name + shape_name = os.path.split(folder_path)[-1] + + pose_json_path = os.path.join(folder_path, "pose.json") + with open(pose_json_path, 'r') as f: + meta = json.load(f) + + self.img_ids = list(meta["c2ws"].keys()) # e.g. "view_0", "view_7", "view_0_2_10" + self.img_wh = (256, 256) + self.input_poses = np.array(list(meta["c2ws"].values())) + intrinsic = np.eye(4) + intrinsic[:3, :3] = np.array(meta["intrinsics"]) + self.intrinsic = intrinsic + self.near_far = np.array(meta["near_far"]) + self.near_far[1] = 1.8 + self.define_transforms() + self.blender2opencv = np.array( + [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]] + ) + + self.c2ws = [] + self.w2cs = [] + self.near_fars = [] + # self.root_dir = root_dir + for image_dix, img_id in enumerate(self.img_ids): + pose = self.input_poses[image_dix] + c2w = pose @ self.blender2opencv + self.c2ws.append(c2w) + self.w2cs.append(np.linalg.inv(c2w)) + self.near_fars.append(self.near_far) + self.c2ws = np.stack(self.c2ws, axis=0) + self.w2cs = np.stack(self.w2cs, axis=0) + + + self.all_intrinsics = [] # the cam info of the whole scene + self.all_extrinsics = [] + self.all_near_fars = [] + self.load_cam_info() + + + # target view + c2w = self.c2ws[idx] + w2c = np.linalg.inv(c2w) + w2c_ref = w2c + w2c_ref_inv = np.linalg.inv(w2c_ref) + + w2cs.append(w2c @ w2c_ref_inv) + c2ws.append(np.linalg.inv(w2c @ w2c_ref_inv)) + + # img_filename = os.path.join(folder_path, 'stage1_8_debug', f'{self.img_ids[idx]}') + img_filename = os.path.join(folder_path, 'stage1_8', f'{self.img_ids[idx]}') + + img = Image.open(img_filename) + img = self.transform(img) # (4, h, w) + + + if img.shape[0] == 4: + img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB + imgs += [img] + + + depth_h = torch.ones((img.shape[1], img.shape[2]), dtype=torch.float32) + depth_h = depth_h.fill_(-1.0) + mask_h = torch.ones((img.shape[1], img.shape[2]), dtype=torch.int32) + + + depths_h.append(depth_h) + masks_h.append(mask_h) + + intrinsic = self.intrinsic + intrinsics.append(intrinsic) + + + near_fars.append(self.near_fars[idx]) + image_perm = 0 # only supervised on reference view + + mask_dilated = None + + + # src_views = range(8, 8 + 8 * 4) + + src_views = list() + for i in range(8): + # randomly choose 3 different number from [0,3] + # local_idxs = np.random.choice(4, 3, replace=False) + local_idxs = [0, 2, 3] + # local_idxs = np.random.choice(4, 3, replace=False) + + local_idxs = [8 + i * 4 + local_idx for local_idx in local_idxs] + src_views += local_idxs + + for vid in src_views: + + # img_filename = os.path.join(folder_path, 'stage2_8_debug', f'{self.img_ids[vid]}') + img_filename = os.path.join(folder_path, 'stage2_8', f'{self.img_ids[vid]}') + img = Image.open(img_filename) + img_wh = self.img_wh + + img = self.transform(img) + if img.shape[0] == 4: + img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB + + imgs += [img] + depth_h = np.ones(img.shape[1:], dtype=np.float32) + depths_h.append(depth_h) + masks_h.append(np.ones(img.shape[1:], dtype=np.int32)) + + near_fars.append(self.all_near_fars[vid]) + intrinsics.append(self.all_intrinsics[vid]) + + w2cs.append(self.all_extrinsics[vid] @ w2c_ref_inv) + + + # ! estimate scale_mat + scale_mat, scale_factor = self.cal_scale_mat( + img_hw=[img_wh[1], img_wh[0]], + intrinsics=intrinsics, extrinsics=w2cs, + near_fars=near_fars, factor=1.1 + ) + + + new_near_fars = [] + new_w2cs = [] + new_c2ws = [] + new_affine_mats = [] + new_depths_h = [] + for intrinsic, extrinsic, near_far, depth in zip(intrinsics, w2cs, near_fars, depths_h): + + P = intrinsic @ extrinsic @ scale_mat + P = P[:3, :4] + # - should use load_K_Rt_from_P() to obtain c2w + c2w = load_K_Rt_from_P(None, P)[1] + w2c = np.linalg.inv(c2w) + new_w2cs.append(w2c) + new_c2ws.append(c2w) + affine_mat = np.eye(4) + affine_mat[:3, :4] = intrinsic[:3, :3] @ w2c[:3, :4] + new_affine_mats.append(affine_mat) + + camera_o = c2w[:3, 3] + dist = np.sqrt(np.sum(camera_o ** 2)) + near = dist - 1 + far = dist + 1 + + new_near_fars.append([0.95 * near, 1.05 * far]) + new_depths_h.append(depth * scale_factor) + + # print(new_near_fars) + imgs = torch.stack(imgs).float() + depths_h = np.stack(new_depths_h) + masks_h = np.stack(masks_h) + + affine_mats = np.stack(new_affine_mats) + intrinsics, w2cs, c2ws, near_fars = np.stack(intrinsics), np.stack(new_w2cs), np.stack(new_c2ws), np.stack( + new_near_fars) + + if self.split == 'train': + start_idx = 0 + else: + start_idx = 1 + + + + target_w2cs = [] + target_intrinsics = [] + new_target_w2cs = [] + for i_idx in range(8): + target_w2cs.append(self.all_extrinsics[i_idx] @ w2c_ref_inv) + target_intrinsics.append(self.all_intrinsics[i_idx]) + + for intrinsic, extrinsic in zip(target_intrinsics, target_w2cs): + + P = intrinsic @ extrinsic @ scale_mat + P = P[:3, :4] + # - should use load_K_Rt_from_P() to obtain c2w + c2w = load_K_Rt_from_P(None, P)[1] + w2c = np.linalg.inv(c2w) + new_target_w2cs.append(w2c) + target_w2cs = np.stack(new_target_w2cs) + + + + view_ids = [idx] + list(src_views) + sample['origin_idx'] = origin_idx + sample['images'] = imgs # (V, 3, H, W) + sample['depths_h'] = torch.from_numpy(depths_h.astype(np.float32)) # (V, H, W) + sample['masks_h'] = torch.from_numpy(masks_h.astype(np.float32)) # (V, H, W) + sample['w2cs'] = torch.from_numpy(w2cs.astype(np.float32)) # (V, 4, 4) + sample['c2ws'] = torch.from_numpy(c2ws.astype(np.float32)) # (V, 4, 4) + sample['target_candidate_w2cs'] = torch.from_numpy(target_w2cs.astype(np.float32)) # (8, 4, 4) + sample['near_fars'] = torch.from_numpy(near_fars.astype(np.float32)) # (V, 2) + sample['intrinsics'] = torch.from_numpy(intrinsics.astype(np.float32))[:, :3, :3] # (V, 3, 3) + sample['view_ids'] = torch.from_numpy(np.array(view_ids)) + sample['affine_mats'] = torch.from_numpy(affine_mats.astype(np.float32)) # ! in world space + + # sample['light_idx'] = torch.tensor(light_idx) + sample['scan'] = shape_name + + sample['scale_factor'] = torch.tensor(scale_factor) + sample['img_wh'] = torch.from_numpy(np.array(img_wh)) + sample['render_img_idx'] = torch.tensor(image_perm) + sample['partial_vol_origin'] = self.partial_vol_origin + sample['meta'] = str(self.specific_dataset_name) + '_' + str(shape_name) + "_refview" + str(view_ids[0]) + # print("meta: ", sample['meta']) + + # - image to render + sample['query_image'] = sample['images'][0] + sample['query_c2w'] = sample['c2ws'][0] + sample['query_w2c'] = sample['w2cs'][0] + sample['query_intrinsic'] = sample['intrinsics'][0] + sample['query_depth'] = sample['depths_h'][0] + sample['query_mask'] = sample['masks_h'][0] + sample['query_near_far'] = sample['near_fars'][0] + + sample['images'] = sample['images'][start_idx:] # (V, 3, H, W) + sample['depths_h'] = sample['depths_h'][start_idx:] # (V, H, W) + sample['masks_h'] = sample['masks_h'][start_idx:] # (V, H, W) + sample['w2cs'] = sample['w2cs'][start_idx:] # (V, 4, 4) + sample['c2ws'] = sample['c2ws'][start_idx:] # (V, 4, 4) + sample['intrinsics'] = sample['intrinsics'][start_idx:] # (V, 3, 3) + sample['view_ids'] = sample['view_ids'][start_idx:] + sample['affine_mats'] = sample['affine_mats'][start_idx:] # ! in world space + + sample['scale_mat'] = torch.from_numpy(scale_mat) + sample['trans_mat'] = torch.from_numpy(w2c_ref_inv) + + # - generate rays + if ('val' in self.split) or ('test' in self.split): + sample_rays = gen_rays_from_single_image( + img_wh[1], img_wh[0], + sample['query_image'], + sample['query_intrinsic'], + sample['query_c2w'], + depth=sample['query_depth'], + mask=sample['query_mask'] if self.clean_image else None) + else: + sample_rays = gen_random_rays_from_single_image( + img_wh[1], img_wh[0], + self.N_rays, + sample['query_image'], + sample['query_intrinsic'], + sample['query_c2w'], + depth=sample['query_depth'], + mask=sample['query_mask'] if self.clean_image else None, + dilated_mask=mask_dilated, + importance_sample=self.importance_sample) + + + sample['rays'] = sample_rays + + return sample diff --git a/SparseNeuS_demo_v1/data/blender_general_narrow_all_eval_new_data_8_wide.py b/SparseNeuS_demo_v1/data/blender_general_narrow_all_eval_new_data_8_wide.py new file mode 100644 index 0000000000000000000000000000000000000000..3c1a23183a388175c2212bf552fb15ae385737ab --- /dev/null +++ b/SparseNeuS_demo_v1/data/blender_general_narrow_all_eval_new_data_8_wide.py @@ -0,0 +1,420 @@ +from torch.utils.data import Dataset +from utils.misc_utils import read_pfm +import os +import numpy as np +import cv2 +from PIL import Image +import torch +from torchvision import transforms as T +from data.scene import get_boundingbox + +from models.rays import gen_rays_from_single_image, gen_random_rays_from_single_image +import json +from termcolor import colored +import imageio +from kornia import create_meshgrid +import open3d as o3d + + +def get_ray_directions(H, W, focal, center=None): + """ + Get ray directions for all pixels in camera coordinate. + Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/ + ray-tracing-generating-camera-rays/standard-coordinate-systems + Inputs: + H, W, focal: image height, width and focal length + Outputs: + directions: (H, W, 3), the direction of the rays in camera coordinate + """ + grid = create_meshgrid(H, W, normalized_coordinates=False)[0] + 0.5 # 1xHxWx2 + + i, j = grid.unbind(-1) + # the direction here is without +0.5 pixel centering as calibration is not so accurate + # see https://github.com/bmild/nerf/issues/24 + cent = center if center is not None else [W / 2, H / 2] + directions = torch.stack([(i - cent[0]) / focal[0], (j - cent[1]) / focal[1], torch.ones_like(i)], -1) # (H, W, 3) + + return directions + +def load_K_Rt_from_P(filename, P=None): + if P is None: + lines = open(filename).read().splitlines() + if len(lines) == 4: + lines = lines[1:] + lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)] + P = np.asarray(lines).astype(np.float32).squeeze() + + out = cv2.decomposeProjectionMatrix(P) + K = out[0] + R = out[1] + t = out[2] + + K = K / K[2, 2] + intrinsics = np.eye(4) + intrinsics[:3, :3] = K + + pose = np.eye(4, dtype=np.float32) + pose[:3, :3] = R.transpose() # ? why need transpose here + pose[:3, 3] = (t[:3] / t[3])[:, 0] + + return intrinsics, pose # ! return cam2world matrix here + + +# ! load one ref-image with multiple src-images in camera coordinate system +class BlenderPerView(Dataset): + def __init__(self, root_dir, split, n_views=3, img_wh=(256, 256), downSample=1.0, + split_filepath=None, pair_filepath=None, + N_rays=512, + vol_dims=[128, 128, 128], batch_size=1, + clean_image=False, importance_sample=False, test_ref_views=[], + specific_dataset_name = 'GSO' + ): + + # print("root_dir: ", root_dir) + self.root_dir = root_dir + self.split = split + # self.specific_dataset_name = 'Realfusion' + # self.specific_dataset_name = 'GSO' + # self.specific_dataset_name = 'Objaverse' + # self.specific_dataset_name = 'Zero123' + + self.specific_dataset_name = specific_dataset_name + self.n_views = n_views + self.N_rays = N_rays + self.batch_size = batch_size # - used for construct new metas for gru fusion training + + self.clean_image = clean_image + self.importance_sample = importance_sample + self.test_ref_views = test_ref_views # used for testing + self.scale_factor = 1.0 + self.scale_mat = np.float32(np.diag([1, 1, 1, 1.0])) + assert self.split == 'val' or 'export_mesh', 'only support val or export_mesh' + # find all subfolders + main_folder = os.path.join(root_dir, self.specific_dataset_name) + self.shape_list = os.listdir(main_folder) + self.shape_list.sort() + + # self.shape_list = ['barrel_render'] + # self.shape_list = ["barrel", "bag", "mailbox", "shoe", "chair", "car", "dog", "teddy"] # TO BE DELETED + + + self.lvis_paths = [] + for shape_name in self.shape_list: + self.lvis_paths.append(os.path.join(main_folder, shape_name)) + + # print("lvis_paths: ", self.lvis_paths) + + if img_wh is not None: + assert img_wh[0] % 32 == 0 and img_wh[1] % 32 == 0, \ + 'img_wh must both be multiples of 32!' + + + # * bounding box for rendering + self.bbox_min = np.array([-1.0, -1.0, -1.0]) + self.bbox_max = np.array([1.0, 1.0, 1.0]) + + # - used for cost volume regularization + self.voxel_dims = torch.tensor(vol_dims, dtype=torch.float32) + self.partial_vol_origin = torch.tensor([-1., -1., -1.], dtype=torch.float32) + + + def define_transforms(self): + self.transform = T.Compose([T.ToTensor()]) + + + + def load_cam_info(self): + for vid, img_id in enumerate(self.img_ids): + intrinsic, extrinsic, near_far = self.intrinsic, np.linalg.inv(self.c2ws[vid]), self.near_far + self.all_intrinsics.append(intrinsic) + self.all_extrinsics.append(extrinsic) + self.all_near_fars.append(near_far) + + def read_depth(self, filename): + pass + + def read_mask(self, filename): + mask_h = cv2.imread(filename, 0) + mask_h = cv2.resize(mask_h, None, fx=self.downSample, fy=self.downSample, + interpolation=cv2.INTER_NEAREST) + mask = cv2.resize(mask_h, None, fx=0.25, fy=0.25, + interpolation=cv2.INTER_NEAREST) + + mask[mask > 0] = 1 # the masks stored in png are not binary + mask_h[mask_h > 0] = 1 + + return mask, mask_h + + def cal_scale_mat(self, img_hw, intrinsics, extrinsics, near_fars, factor=1.): + + center, radius, bounds = get_boundingbox(img_hw, intrinsics, extrinsics, near_fars) + + radius = radius * factor + scale_mat = np.diag([radius, radius, radius, 1.0]) + scale_mat[:3, 3] = center.cpu().numpy() + scale_mat = scale_mat.astype(np.float32) + + return scale_mat, 1. / radius.cpu().numpy() + + def __len__(self): + # return 8*len(self.lvis_paths) + return len(self.lvis_paths) + + + def read_depth(self, filename, near_bound, noisy_factor=1.0): + pass + + + def __getitem__(self, idx): + sample = {} + idx = idx * 8 # to be deleted + origin_idx = idx + imgs, depths_h, masks_h = [], [], [] # full size (256, 256) + intrinsics, w2cs, c2ws, near_fars = [], [], [], [] # record proj-mats between views + + folder_path = self.lvis_paths[idx//8] + idx = idx % 8 # [0, 7] + + # last subdir name + shape_name = os.path.split(folder_path)[-1] + + pose_json_path = os.path.join(folder_path, "pose.json") + with open(pose_json_path, 'r') as f: + meta = json.load(f) + + self.img_ids = list(meta["c2ws"].keys()) # e.g. "view_0", "view_7", "view_0_2_10" + self.img_wh = (256, 256) + self.input_poses = np.array(list(meta["c2ws"].values())) + intrinsic = np.eye(4) + intrinsic[:3, :3] = np.array(meta["intrinsics"]) + self.intrinsic = intrinsic + self.near_far = np.array(meta["near_far"]) + self.near_far[1] = 1.8 + self.define_transforms() + self.blender2opencv = np.array( + [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]] + ) + + self.c2ws = [] + self.w2cs = [] + self.near_fars = [] + # self.root_dir = root_dir + for image_dix, img_id in enumerate(self.img_ids): + pose = self.input_poses[image_dix] + c2w = pose @ self.blender2opencv + self.c2ws.append(c2w) + self.w2cs.append(np.linalg.inv(c2w)) + self.near_fars.append(self.near_far) + self.c2ws = np.stack(self.c2ws, axis=0) + self.w2cs = np.stack(self.w2cs, axis=0) + + + self.all_intrinsics = [] # the cam info of the whole scene + self.all_extrinsics = [] + self.all_near_fars = [] + self.load_cam_info() + + + # target view + c2w = self.c2ws[idx] + w2c = np.linalg.inv(c2w) + w2c_ref = w2c + w2c_ref_inv = np.linalg.inv(w2c_ref) + + w2cs.append(w2c @ w2c_ref_inv) + c2ws.append(np.linalg.inv(w2c @ w2c_ref_inv)) + + # img_filename = os.path.join(folder_path, 'stage1_8_debug', f'{self.img_ids[idx]}') + img_filename = os.path.join(folder_path, 'stage1_8', f'{self.img_ids[idx]}') + + img = Image.open(img_filename) + img = self.transform(img) # (4, h, w) + + + if img.shape[0] == 4: + img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB + imgs += [img] + + + depth_h = torch.ones((img.shape[1], img.shape[2]), dtype=torch.float32) + depth_h = depth_h.fill_(-1.0) + mask_h = torch.ones((img.shape[1], img.shape[2]), dtype=torch.int32) + + + depths_h.append(depth_h) + masks_h.append(mask_h) + + intrinsic = self.intrinsic + intrinsics.append(intrinsic) + + + near_fars.append(self.near_fars[idx]) + image_perm = 0 # only supervised on reference view + + mask_dilated = None + + + src_views = range(8) + + + for vid in src_views: + + # img_filename = os.path.join(folder_path, 'stage2_8_debug', f'{self.img_ids[vid]}') + # img_filename = os.path.join(folder_path, 'stage2_8', f'{self.img_ids[vid]}') + img_filename = os.path.join(folder_path, 'stage1_8', f'{self.img_ids[vid]}') + img = Image.open(img_filename) + img_wh = self.img_wh + + img = self.transform(img) + if img.shape[0] == 4: + img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB + + imgs += [img] + depth_h = np.ones(img.shape[1:], dtype=np.float32) + depths_h.append(depth_h) + masks_h.append(np.ones(img.shape[1:], dtype=np.int32)) + + near_fars.append(self.all_near_fars[vid]) + intrinsics.append(self.all_intrinsics[vid]) + + w2cs.append(self.all_extrinsics[vid] @ w2c_ref_inv) + + + # ! estimate scale_mat + scale_mat, scale_factor = self.cal_scale_mat( + img_hw=[img_wh[1], img_wh[0]], + intrinsics=intrinsics, extrinsics=w2cs, + near_fars=near_fars, factor=1.1 + ) + + + new_near_fars = [] + new_w2cs = [] + new_c2ws = [] + new_affine_mats = [] + new_depths_h = [] + for intrinsic, extrinsic, near_far, depth in zip(intrinsics, w2cs, near_fars, depths_h): + + P = intrinsic @ extrinsic @ scale_mat + P = P[:3, :4] + # - should use load_K_Rt_from_P() to obtain c2w + c2w = load_K_Rt_from_P(None, P)[1] + w2c = np.linalg.inv(c2w) + new_w2cs.append(w2c) + new_c2ws.append(c2w) + affine_mat = np.eye(4) + affine_mat[:3, :4] = intrinsic[:3, :3] @ w2c[:3, :4] + new_affine_mats.append(affine_mat) + + camera_o = c2w[:3, 3] + dist = np.sqrt(np.sum(camera_o ** 2)) + near = dist - 1 + far = dist + 1 + + new_near_fars.append([0.95 * near, 1.05 * far]) + new_depths_h.append(depth * scale_factor) + + # print(new_near_fars) + imgs = torch.stack(imgs).float() + depths_h = np.stack(new_depths_h) + masks_h = np.stack(masks_h) + + affine_mats = np.stack(new_affine_mats) + intrinsics, w2cs, c2ws, near_fars = np.stack(intrinsics), np.stack(new_w2cs), np.stack(new_c2ws), np.stack( + new_near_fars) + + if self.split == 'train': + start_idx = 0 + else: + start_idx = 1 + + + + target_w2cs = [] + target_intrinsics = [] + new_target_w2cs = [] + for i_idx in range(8): + target_w2cs.append(self.all_extrinsics[i_idx] @ w2c_ref_inv) + target_intrinsics.append(self.all_intrinsics[i_idx]) + + for intrinsic, extrinsic in zip(target_intrinsics, target_w2cs): + + P = intrinsic @ extrinsic @ scale_mat + P = P[:3, :4] + # - should use load_K_Rt_from_P() to obtain c2w + c2w = load_K_Rt_from_P(None, P)[1] + w2c = np.linalg.inv(c2w) + new_target_w2cs.append(w2c) + target_w2cs = np.stack(new_target_w2cs) + + + + view_ids = [idx] + list(src_views) + sample['origin_idx'] = origin_idx + sample['images'] = imgs # (V, 3, H, W) + sample['depths_h'] = torch.from_numpy(depths_h.astype(np.float32)) # (V, H, W) + sample['masks_h'] = torch.from_numpy(masks_h.astype(np.float32)) # (V, H, W) + sample['w2cs'] = torch.from_numpy(w2cs.astype(np.float32)) # (V, 4, 4) + sample['c2ws'] = torch.from_numpy(c2ws.astype(np.float32)) # (V, 4, 4) + sample['target_candidate_w2cs'] = torch.from_numpy(target_w2cs.astype(np.float32)) # (8, 4, 4) + sample['near_fars'] = torch.from_numpy(near_fars.astype(np.float32)) # (V, 2) + sample['intrinsics'] = torch.from_numpy(intrinsics.astype(np.float32))[:, :3, :3] # (V, 3, 3) + sample['view_ids'] = torch.from_numpy(np.array(view_ids)) + sample['affine_mats'] = torch.from_numpy(affine_mats.astype(np.float32)) # ! in world space + + # sample['light_idx'] = torch.tensor(light_idx) + sample['scan'] = shape_name + + sample['scale_factor'] = torch.tensor(scale_factor) + sample['img_wh'] = torch.from_numpy(np.array(img_wh)) + sample['render_img_idx'] = torch.tensor(image_perm) + sample['partial_vol_origin'] = self.partial_vol_origin + sample['meta'] = str(self.specific_dataset_name) + '_' + str(shape_name) + "_refview" + str(view_ids[0]) + # print("meta: ", sample['meta']) + + # - image to render + sample['query_image'] = sample['images'][0] + sample['query_c2w'] = sample['c2ws'][0] + sample['query_w2c'] = sample['w2cs'][0] + sample['query_intrinsic'] = sample['intrinsics'][0] + sample['query_depth'] = sample['depths_h'][0] + sample['query_mask'] = sample['masks_h'][0] + sample['query_near_far'] = sample['near_fars'][0] + + sample['images'] = sample['images'][start_idx:] # (V, 3, H, W) + sample['depths_h'] = sample['depths_h'][start_idx:] # (V, H, W) + sample['masks_h'] = sample['masks_h'][start_idx:] # (V, H, W) + sample['w2cs'] = sample['w2cs'][start_idx:] # (V, 4, 4) + sample['c2ws'] = sample['c2ws'][start_idx:] # (V, 4, 4) + sample['intrinsics'] = sample['intrinsics'][start_idx:] # (V, 3, 3) + sample['view_ids'] = sample['view_ids'][start_idx:] + sample['affine_mats'] = sample['affine_mats'][start_idx:] # ! in world space + + sample['scale_mat'] = torch.from_numpy(scale_mat) + sample['trans_mat'] = torch.from_numpy(w2c_ref_inv) + + # - generate rays + if ('val' in self.split) or ('test' in self.split): + sample_rays = gen_rays_from_single_image( + img_wh[1], img_wh[0], + sample['query_image'], + sample['query_intrinsic'], + sample['query_c2w'], + depth=sample['query_depth'], + mask=sample['query_mask'] if self.clean_image else None) + else: + sample_rays = gen_random_rays_from_single_image( + img_wh[1], img_wh[0], + self.N_rays, + sample['query_image'], + sample['query_intrinsic'], + sample['query_c2w'], + depth=sample['query_depth'], + mask=sample['query_mask'] if self.clean_image else None, + dilated_mask=mask_dilated, + importance_sample=self.importance_sample) + + + sample['rays'] = sample_rays + + return sample diff --git a/SparseNeuS_demo_v1/data/blender_general_narrow_all_eval_new_data_temp.py b/SparseNeuS_demo_v1/data/blender_general_narrow_all_eval_new_data_temp.py new file mode 100644 index 0000000000000000000000000000000000000000..4b2c7f6b2306cca93f476c2c233956e4cff0dcfb --- /dev/null +++ b/SparseNeuS_demo_v1/data/blender_general_narrow_all_eval_new_data_temp.py @@ -0,0 +1,417 @@ +from torch.utils.data import Dataset +from utils.misc_utils import read_pfm +import os +import numpy as np +import cv2 +from PIL import Image +import torch +from torchvision import transforms as T +from data.scene import get_boundingbox + +from models.rays import gen_rays_from_single_image, gen_random_rays_from_single_image +import json +from termcolor import colored +import imageio +from kornia import create_meshgrid +import open3d as o3d + + +def get_ray_directions(H, W, focal, center=None): + """ + Get ray directions for all pixels in camera coordinate. + Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/ + ray-tracing-generating-camera-rays/standard-coordinate-systems + Inputs: + H, W, focal: image height, width and focal length + Outputs: + directions: (H, W, 3), the direction of the rays in camera coordinate + """ + grid = create_meshgrid(H, W, normalized_coordinates=False)[0] + 0.5 # 1xHxWx2 + + i, j = grid.unbind(-1) + # the direction here is without +0.5 pixel centering as calibration is not so accurate + # see https://github.com/bmild/nerf/issues/24 + cent = center if center is not None else [W / 2, H / 2] + directions = torch.stack([(i - cent[0]) / focal[0], (j - cent[1]) / focal[1], torch.ones_like(i)], -1) # (H, W, 3) + + return directions + +def load_K_Rt_from_P(filename, P=None): + if P is None: + lines = open(filename).read().splitlines() + if len(lines) == 4: + lines = lines[1:] + lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)] + P = np.asarray(lines).astype(np.float32).squeeze() + + out = cv2.decomposeProjectionMatrix(P) + K = out[0] + R = out[1] + t = out[2] + + K = K / K[2, 2] + intrinsics = np.eye(4) + intrinsics[:3, :3] = K + + pose = np.eye(4, dtype=np.float32) + pose[:3, :3] = R.transpose() # ? why need transpose here + pose[:3, 3] = (t[:3] / t[3])[:, 0] + + return intrinsics, pose # ! return cam2world matrix here + + +# ! load one ref-image with multiple src-images in camera coordinate system +class BlenderPerView(Dataset): + def __init__(self, root_dir, split, n_views=3, img_wh=(256, 256), downSample=1.0, + split_filepath=None, pair_filepath=None, + N_rays=512, + vol_dims=[128, 128, 128], batch_size=1, + clean_image=False, importance_sample=False, test_ref_views=[], + specific_dataset_name = 'GSO' + ): + + # print("root_dir: ", root_dir) + self.root_dir = root_dir + self.split = split + # self.specific_dataset_name = 'Realfusion' + # self.specific_dataset_name = 'GSO' + # self.specific_dataset_name = 'Objaverse' + self.specific_dataset_name = 'Objaverse_archived' + + # self.specific_dataset_name = specific_dataset_name + self.n_views = n_views + self.N_rays = N_rays + self.batch_size = batch_size # - used for construct new metas for gru fusion training + + self.clean_image = clean_image + self.importance_sample = importance_sample + self.test_ref_views = test_ref_views # used for testing + self.scale_factor = 1.0 + self.scale_mat = np.float32(np.diag([1, 1, 1, 1.0])) + assert self.split == 'val' or 'export_mesh', 'only support val or export_mesh' + # find all subfolders + main_folder = os.path.join(root_dir, self.specific_dataset_name) + self.shape_list = os.listdir(main_folder) + self.shape_list.sort() + + # self.shape_list = ["barrel", "bag", "mailbox", "shoe", "chair", "car", "dog", "teddy"] # TO BE DELETED + + + self.lvis_paths = [] + for shape_name in self.shape_list: + self.lvis_paths.append(os.path.join(main_folder, shape_name)) + + # print("lvis_paths: ", self.lvis_paths) + + if img_wh is not None: + assert img_wh[0] % 32 == 0 and img_wh[1] % 32 == 0, \ + 'img_wh must both be multiples of 32!' + + + # * bounding box for rendering + self.bbox_min = np.array([-1.0, -1.0, -1.0]) + self.bbox_max = np.array([1.0, 1.0, 1.0]) + + # - used for cost volume regularization + self.voxel_dims = torch.tensor(vol_dims, dtype=torch.float32) + self.partial_vol_origin = torch.tensor([-1., -1., -1.], dtype=torch.float32) + + + def define_transforms(self): + self.transform = T.Compose([T.ToTensor()]) + + + + def load_cam_info(self): + for vid, img_id in enumerate(self.img_ids): + intrinsic, extrinsic, near_far = self.intrinsic, np.linalg.inv(self.c2ws[vid]), self.near_far + self.all_intrinsics.append(intrinsic) + self.all_extrinsics.append(extrinsic) + self.all_near_fars.append(near_far) + + def read_depth(self, filename): + pass + + def read_mask(self, filename): + mask_h = cv2.imread(filename, 0) + mask_h = cv2.resize(mask_h, None, fx=self.downSample, fy=self.downSample, + interpolation=cv2.INTER_NEAREST) + mask = cv2.resize(mask_h, None, fx=0.25, fy=0.25, + interpolation=cv2.INTER_NEAREST) + + mask[mask > 0] = 1 # the masks stored in png are not binary + mask_h[mask_h > 0] = 1 + + return mask, mask_h + + def cal_scale_mat(self, img_hw, intrinsics, extrinsics, near_fars, factor=1.): + + center, radius, bounds = get_boundingbox(img_hw, intrinsics, extrinsics, near_fars) + + radius = radius * factor + scale_mat = np.diag([radius, radius, radius, 1.0]) + scale_mat[:3, 3] = center.cpu().numpy() + scale_mat = scale_mat.astype(np.float32) + + return scale_mat, 1. / radius.cpu().numpy() + + def __len__(self): + # return 8*len(self.lvis_paths) + return len(self.lvis_paths) + + + def read_depth(self, filename, near_bound, noisy_factor=1.0): + pass + + + def __getitem__(self, idx): + sample = {} + idx = idx * 8 # to be deleted + origin_idx = idx + imgs, depths_h, masks_h = [], [], [] # full size (256, 256) + intrinsics, w2cs, c2ws, near_fars = [], [], [], [] # record proj-mats between views + + folder_path = self.lvis_paths[idx//8] + idx = idx % 8 # [0, 7] + + # last subdir name + shape_name = os.path.split(folder_path)[-1] + + pose_json_path = os.path.join('/objaverse-processed/zero12345_img/zero12345_narrow_pose.json') + with open(pose_json_path, 'r') as f: + meta = json.load(f) + + self.img_ids = list(meta["c2ws"].keys()) # e.g. "view_0", "view_7", "view_0_2_10" + self.img_wh = (256, 256) + self.input_poses = np.array(list(meta["c2ws"].values())) + intrinsic = np.eye(4) + intrinsic[:3, :3] = np.array(meta["intrinsics"]) + self.intrinsic = intrinsic + self.near_far = np.array(meta["near_far"]) + self.near_far[1] = 1.8 + self.define_transforms() + self.blender2opencv = np.array( + [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]] + ) + + self.c2ws = [] + self.w2cs = [] + self.near_fars = [] + # self.root_dir = root_dir + for image_dix, img_id in enumerate(self.img_ids): + pose = self.input_poses[image_dix] + c2w = pose @ self.blender2opencv + self.c2ws.append(c2w) + self.w2cs.append(np.linalg.inv(c2w)) + self.near_fars.append(self.near_far) + self.c2ws = np.stack(self.c2ws, axis=0) + self.w2cs = np.stack(self.w2cs, axis=0) + + + self.all_intrinsics = [] # the cam info of the whole scene + self.all_extrinsics = [] + self.all_near_fars = [] + self.load_cam_info() + + + # target view + c2w = self.c2ws[idx] + w2c = np.linalg.inv(c2w) + w2c_ref = w2c + w2c_ref_inv = np.linalg.inv(w2c_ref) + + w2cs.append(w2c @ w2c_ref_inv) + c2ws.append(np.linalg.inv(w2c @ w2c_ref_inv)) + + # img_filename = os.path.join(folder_path, 'stage1_8_debug', f'{self.img_ids[idx]}') + img_filename = os.path.join(folder_path, 'stage1_8', f'{idx}.png') + + img = Image.open(img_filename) + img = self.transform(img) # (4, h, w) + + + if img.shape[0] == 4: + img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB + imgs += [img] + + + depth_h = torch.ones((img.shape[1], img.shape[2]), dtype=torch.float32) + depth_h = depth_h.fill_(-1.0) + mask_h = torch.ones((img.shape[1], img.shape[2]), dtype=torch.int32) + + + depths_h.append(depth_h) + masks_h.append(mask_h) + + intrinsic = self.intrinsic + intrinsics.append(intrinsic) + + + near_fars.append(self.near_fars[idx]) + image_perm = 0 # only supervised on reference view + + mask_dilated = None + + + src_views = range(8, 8 + 8 * 4) + + for vid in src_views: + + # img_filename = os.path.join(folder_path, 'stage2_8_debug', f'{self.img_ids[vid]}') + img_filename = os.path.join(folder_path, 'stage2_8', f'{(vid-8)//4}_{(vid-8)%4}.png') + img = Image.open(img_filename) + img_wh = self.img_wh + + img = self.transform(img) + if img.shape[0] == 4: + img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB + + imgs += [img] + depth_h = np.ones(img.shape[1:], dtype=np.float32) + depths_h.append(depth_h) + masks_h.append(np.ones(img.shape[1:], dtype=np.int32)) + + near_fars.append(self.all_near_fars[vid]) + intrinsics.append(self.all_intrinsics[vid]) + + w2cs.append(self.all_extrinsics[vid] @ w2c_ref_inv) + + + # ! estimate scale_mat + scale_mat, scale_factor = self.cal_scale_mat( + img_hw=[img_wh[1], img_wh[0]], + intrinsics=intrinsics, extrinsics=w2cs, + near_fars=near_fars, factor=1.1 + ) + + + new_near_fars = [] + new_w2cs = [] + new_c2ws = [] + new_affine_mats = [] + new_depths_h = [] + for intrinsic, extrinsic, near_far, depth in zip(intrinsics, w2cs, near_fars, depths_h): + + P = intrinsic @ extrinsic @ scale_mat + P = P[:3, :4] + # - should use load_K_Rt_from_P() to obtain c2w + c2w = load_K_Rt_from_P(None, P)[1] + w2c = np.linalg.inv(c2w) + new_w2cs.append(w2c) + new_c2ws.append(c2w) + affine_mat = np.eye(4) + affine_mat[:3, :4] = intrinsic[:3, :3] @ w2c[:3, :4] + new_affine_mats.append(affine_mat) + + camera_o = c2w[:3, 3] + dist = np.sqrt(np.sum(camera_o ** 2)) + near = dist - 1 + far = dist + 1 + + new_near_fars.append([0.95 * near, 1.05 * far]) + new_depths_h.append(depth * scale_factor) + + # print(new_near_fars) + imgs = torch.stack(imgs).float() + depths_h = np.stack(new_depths_h) + masks_h = np.stack(masks_h) + + affine_mats = np.stack(new_affine_mats) + intrinsics, w2cs, c2ws, near_fars = np.stack(intrinsics), np.stack(new_w2cs), np.stack(new_c2ws), np.stack( + new_near_fars) + + if self.split == 'train': + start_idx = 0 + else: + start_idx = 1 + + + + target_w2cs = [] + target_intrinsics = [] + new_target_w2cs = [] + for i_idx in range(8): + target_w2cs.append(self.all_extrinsics[i_idx] @ w2c_ref_inv) + target_intrinsics.append(self.all_intrinsics[i_idx]) + + for intrinsic, extrinsic in zip(target_intrinsics, target_w2cs): + + P = intrinsic @ extrinsic @ scale_mat + P = P[:3, :4] + # - should use load_K_Rt_from_P() to obtain c2w + c2w = load_K_Rt_from_P(None, P)[1] + w2c = np.linalg.inv(c2w) + new_target_w2cs.append(w2c) + target_w2cs = np.stack(new_target_w2cs) + + + + view_ids = [idx] + list(src_views) + sample['origin_idx'] = origin_idx + sample['images'] = imgs # (V, 3, H, W) + sample['depths_h'] = torch.from_numpy(depths_h.astype(np.float32)) # (V, H, W) + sample['masks_h'] = torch.from_numpy(masks_h.astype(np.float32)) # (V, H, W) + sample['w2cs'] = torch.from_numpy(w2cs.astype(np.float32)) # (V, 4, 4) + sample['c2ws'] = torch.from_numpy(c2ws.astype(np.float32)) # (V, 4, 4) + sample['target_candidate_w2cs'] = torch.from_numpy(target_w2cs.astype(np.float32)) # (8, 4, 4) + sample['near_fars'] = torch.from_numpy(near_fars.astype(np.float32)) # (V, 2) + sample['intrinsics'] = torch.from_numpy(intrinsics.astype(np.float32))[:, :3, :3] # (V, 3, 3) + sample['view_ids'] = torch.from_numpy(np.array(view_ids)) + sample['affine_mats'] = torch.from_numpy(affine_mats.astype(np.float32)) # ! in world space + + # sample['light_idx'] = torch.tensor(light_idx) + sample['scan'] = shape_name + + sample['scale_factor'] = torch.tensor(scale_factor) + sample['img_wh'] = torch.from_numpy(np.array(img_wh)) + sample['render_img_idx'] = torch.tensor(image_perm) + sample['partial_vol_origin'] = self.partial_vol_origin + sample['meta'] = str(self.specific_dataset_name) + '_' + str(shape_name) + "_refview" + str(view_ids[0]) + # print("meta: ", sample['meta']) + + # - image to render + sample['query_image'] = sample['images'][0] + sample['query_c2w'] = sample['c2ws'][0] + sample['query_w2c'] = sample['w2cs'][0] + sample['query_intrinsic'] = sample['intrinsics'][0] + sample['query_depth'] = sample['depths_h'][0] + sample['query_mask'] = sample['masks_h'][0] + sample['query_near_far'] = sample['near_fars'][0] + + sample['images'] = sample['images'][start_idx:] # (V, 3, H, W) + sample['depths_h'] = sample['depths_h'][start_idx:] # (V, H, W) + sample['masks_h'] = sample['masks_h'][start_idx:] # (V, H, W) + sample['w2cs'] = sample['w2cs'][start_idx:] # (V, 4, 4) + sample['c2ws'] = sample['c2ws'][start_idx:] # (V, 4, 4) + sample['intrinsics'] = sample['intrinsics'][start_idx:] # (V, 3, 3) + sample['view_ids'] = sample['view_ids'][start_idx:] + sample['affine_mats'] = sample['affine_mats'][start_idx:] # ! in world space + + sample['scale_mat'] = torch.from_numpy(scale_mat) + sample['trans_mat'] = torch.from_numpy(w2c_ref_inv) + + # - generate rays + if ('val' in self.split) or ('test' in self.split): + sample_rays = gen_rays_from_single_image( + img_wh[1], img_wh[0], + sample['query_image'], + sample['query_intrinsic'], + sample['query_c2w'], + depth=sample['query_depth'], + mask=sample['query_mask'] if self.clean_image else None) + else: + sample_rays = gen_random_rays_from_single_image( + img_wh[1], img_wh[0], + self.N_rays, + sample['query_image'], + sample['query_intrinsic'], + sample['query_c2w'], + depth=sample['query_depth'], + mask=sample['query_mask'] if self.clean_image else None, + dilated_mask=mask_dilated, + importance_sample=self.importance_sample) + + + sample['rays'] = sample_rays + + return sample diff --git a/SparseNeuS_demo_v1/data/blender_general_narrow_all_no_depth.py b/SparseNeuS_demo_v1/data/blender_general_narrow_all_no_depth.py new file mode 100644 index 0000000000000000000000000000000000000000..33a4ecf7de541049e3b89cc98f74106b59d418c7 --- /dev/null +++ b/SparseNeuS_demo_v1/data/blender_general_narrow_all_no_depth.py @@ -0,0 +1,388 @@ +from torch.utils.data import Dataset +from utils.misc_utils import read_pfm +import os +import numpy as np +import cv2 +from PIL import Image +import torch +from torchvision import transforms as T +from data.scene import get_boundingbox + +from models.rays import gen_rays_from_single_image, gen_random_rays_from_single_image +import json +from termcolor import colored +import imageio +from kornia import create_meshgrid +import open3d as o3d +def get_ray_directions(H, W, focal, center=None): + """ + Get ray directions for all pixels in camera coordinate. + Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/ + ray-tracing-generating-camera-rays/standard-coordinate-systems + Inputs: + H, W, focal: image height, width and focal length + Outputs: + directions: (H, W, 3), the direction of the rays in camera coordinate + """ + grid = create_meshgrid(H, W, normalized_coordinates=False)[0] + 0.5 # 1xHxWx2 + + i, j = grid.unbind(-1) + # the direction here is without +0.5 pixel centering as calibration is not so accurate + # see https://github.com/bmild/nerf/issues/24 + cent = center if center is not None else [W / 2, H / 2] + directions = torch.stack([(i - cent[0]) / focal[0], (j - cent[1]) / focal[1], torch.ones_like(i)], -1) # (H, W, 3) + + return directions + +def load_K_Rt_from_P(filename, P=None): + if P is None: + lines = open(filename).read().splitlines() + if len(lines) == 4: + lines = lines[1:] + lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)] + P = np.asarray(lines).astype(np.float32).squeeze() + + out = cv2.decomposeProjectionMatrix(P) + K = out[0] + R = out[1] + t = out[2] + + K = K / K[2, 2] + intrinsics = np.eye(4) + intrinsics[:3, :3] = K + + pose = np.eye(4, dtype=np.float32) + pose[:3, :3] = R.transpose() # ? why need transpose here + pose[:3, 3] = (t[:3] / t[3])[:, 0] + + return intrinsics, pose # ! return cam2world matrix here + + +# ! load one ref-image with multiple src-images in camera coordinate system +class BlenderPerView(Dataset): + def __init__(self, root_dir, split, n_views=3, img_wh=(256, 256), downSample=1.0, + split_filepath=None, pair_filepath=None, + N_rays=512, + vol_dims=[128, 128, 128], batch_size=1, + clean_image=False, importance_sample=False, test_ref_views=[]): + + # print("root_dir: ", root_dir) + self.root_dir = root_dir + self.split = split + + self.n_views = n_views + self.N_rays = N_rays + self.batch_size = batch_size # - used for construct new metas for gru fusion training + + self.clean_image = clean_image + self.importance_sample = importance_sample + self.test_ref_views = test_ref_views # used for testing + self.scale_factor = 1.0 + self.scale_mat = np.float32(np.diag([1, 1, 1, 1.0])) + + lvis_json_path = '/objaverse-processed/zero12345_img/lvis_split.json' # folder_id and uid + with open(lvis_json_path, 'r') as f: + lvis_paths = json.load(f) + if self.split == 'train': + self.lvis_paths = lvis_paths['train'] + else: + self.lvis_paths = lvis_paths['val'] + if img_wh is not None: + assert img_wh[0] % 32 == 0 and img_wh[1] % 32 == 0, \ + 'img_wh must both be multiples of 32!' + + + pose_json_path = "/objaverse-processed/zero12345_img/zero12345_narrow_pose.json" + with open(pose_json_path, 'r') as f: + meta = json.load(f) + + self.img_ids = list(meta["c2ws"].keys()) # e.g. "view_0", "view_7", "view_0_2_10" + self.img_wh = (256, 256) + self.input_poses = np.array(list(meta["c2ws"].values())) + intrinsic = np.eye(4) + intrinsic[:3, :3] = np.array(meta["intrinsics"]) + self.intrinsic = intrinsic + self.near_far = np.array(meta["near_far"]) + self.near_far[1] = 1.8 + self.define_transforms() + self.blender2opencv = np.array( + [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]] + ) + + + self.c2ws = [] + self.w2cs = [] + self.near_fars = [] + # self.root_dir = root_dir + for idx, img_id in enumerate(self.img_ids): + pose = self.input_poses[idx] + c2w = pose @ self.blender2opencv + self.c2ws.append(c2w) + self.w2cs.append(np.linalg.inv(c2w)) + self.near_fars.append(self.near_far) + self.c2ws = np.stack(self.c2ws, axis=0) + self.w2cs = np.stack(self.w2cs, axis=0) + + + self.all_intrinsics = [] # the cam info of the whole scene + self.all_extrinsics = [] + self.all_near_fars = [] + self.load_cam_info() + + # * bounding box for rendering + self.bbox_min = np.array([-1.0, -1.0, -1.0]) + self.bbox_max = np.array([1.0, 1.0, 1.0]) + + # - used for cost volume regularization + self.voxel_dims = torch.tensor(vol_dims, dtype=torch.float32) + self.partial_vol_origin = torch.tensor([-1., -1., -1.], dtype=torch.float32) + + + def define_transforms(self): + self.transform = T.Compose([T.ToTensor()]) + + + + def load_cam_info(self): + for vid, img_id in enumerate(self.img_ids): + intrinsic, extrinsic, near_far = self.intrinsic, np.linalg.inv(self.c2ws[vid]), self.near_far + self.all_intrinsics.append(intrinsic) + self.all_extrinsics.append(extrinsic) + self.all_near_fars.append(near_far) + + def read_depth(self, filename): + pass + + def read_mask(self, filename): + mask_h = cv2.imread(filename, 0) + mask_h = cv2.resize(mask_h, None, fx=self.downSample, fy=self.downSample, + interpolation=cv2.INTER_NEAREST) + mask = cv2.resize(mask_h, None, fx=0.25, fy=0.25, + interpolation=cv2.INTER_NEAREST) + + mask[mask > 0] = 1 # the masks stored in png are not binary + mask_h[mask_h > 0] = 1 + + return mask, mask_h + + def cal_scale_mat(self, img_hw, intrinsics, extrinsics, near_fars, factor=1.): + + center, radius, bounds = get_boundingbox(img_hw, intrinsics, extrinsics, near_fars) + # print("center", center) + # print("radius", radius) + # print("bounds", bounds) + # import ipdb; ipdb.set_trace() + radius = radius * factor + scale_mat = np.diag([radius, radius, radius, 1.0]) + scale_mat[:3, 3] = center.cpu().numpy() + scale_mat = scale_mat.astype(np.float32) + + return scale_mat, 1. / radius.cpu().numpy() + + def __len__(self): + return 8*len(self.lvis_paths) + + + def read_depth(self, filename, near_bound, noisy_factor=1.0): + pass + + + def __getitem__(self, idx): + sample = {} + origin_idx = idx + imgs, depths_h, masks_h = [], [], [] # full size (256, 256) + intrinsics, w2cs, c2ws, near_fars = [], [], [], [] # record proj mats between views + + + folder_uid_dict = self.lvis_paths[idx//8] + idx = idx % 8 # [0, 7] + folder_id = folder_uid_dict['folder_id'] + uid = folder_uid_dict['uid'] + + + # target view + c2w = self.c2ws[idx] + w2c = np.linalg.inv(c2w) + w2c_ref = w2c + w2c_ref_inv = np.linalg.inv(w2c_ref) + + w2cs.append(w2c @ w2c_ref_inv) + c2ws.append(np.linalg.inv(w2c @ w2c_ref_inv)) + + img_filename = os.path.join(self.root_dir, folder_id, uid, f'view_{idx}.png') + + depth_filename = os.path.join(os.path.join(self.root_dir, folder_id, uid, f'view_{idx}_depth_mm.png')) + + + img = Image.open(img_filename) + + img = self.transform(img) # (4, h, w) + + + if img.shape[0] == 4: + img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB + imgs += [img] + + depth_h = cv2.imread(depth_filename, cv2.IMREAD_UNCHANGED).astype(np.uint16) / 1000.0 + mask_h = depth_h > 0 + # print("valid pixels", np.sum(mask_h)) + # directions = get_ray_directions(self.img_wh[1], self.img_wh[0], [self.intrinsic[0, 0], self.intrinsic[1, 1]]) # [H, W, 3] + # surface_points = directions * depth_h[..., None] # [H, W, 3] + # distance = np.linalg.norm(surface_points, axis=-1) # [H, W] + # depth_h = distance + + depth_h = torch.ones((img.shape[1], img.shape[2]), dtype=torch.float32) + depth_h = depth_h.fill_(-1.0) + + depths_h.append(depth_h) + masks_h.append(mask_h) + + intrinsic = self.intrinsic + intrinsics.append(intrinsic) + + + near_fars.append(self.near_fars[idx]) + image_perm = 0 # only supervised on reference view + + mask_dilated = None + + # src_views = range(8+idx*4, 8+(idx+1)*4) + src_views = range(8, 8 + 8 * 4) + + for vid in src_views: + img_filename = os.path.join(self.root_dir, folder_id, uid, f'view_{(vid - 8) // 4}_{vid%4}_10.png') + + img = Image.open(img_filename) + img_wh = self.img_wh + + img = self.transform(img) + if img.shape[0] == 4: + img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB + + imgs += [img] + depth_h = np.ones(img.shape[1:], dtype=np.float32) + depths_h.append(depth_h) + masks_h.append(np.ones(img.shape[1:], dtype=np.int32)) + + near_fars.append(self.all_near_fars[vid]) + intrinsics.append(self.all_intrinsics[vid]) + + w2cs.append(self.all_extrinsics[vid] @ w2c_ref_inv) + + + # ! estimate scale_mat + scale_mat, scale_factor = self.cal_scale_mat( + img_hw=[img_wh[1], img_wh[0]], + intrinsics=intrinsics, extrinsics=w2cs, + near_fars=near_fars, factor=1.1 + ) + + + new_near_fars = [] + new_w2cs = [] + new_c2ws = [] + new_affine_mats = [] + new_depths_h = [] + for intrinsic, extrinsic, near_far, depth in zip(intrinsics, w2cs, near_fars, depths_h): + + P = intrinsic @ extrinsic @ scale_mat + P = P[:3, :4] + # - should use load_K_Rt_from_P() to obtain c2w + c2w = load_K_Rt_from_P(None, P)[1] + w2c = np.linalg.inv(c2w) + new_w2cs.append(w2c) + new_c2ws.append(c2w) + affine_mat = np.eye(4) + affine_mat[:3, :4] = intrinsic[:3, :3] @ w2c[:3, :4] + new_affine_mats.append(affine_mat) + + camera_o = c2w[:3, 3] + dist = np.sqrt(np.sum(camera_o ** 2)) + near = dist - 1 + far = dist + 1 + + new_near_fars.append([0.95 * near, 1.05 * far]) + new_depths_h.append(depth * scale_factor) + + # print(new_near_fars) + imgs = torch.stack(imgs).float() + depths_h = np.stack(new_depths_h) + masks_h = np.stack(masks_h) + + affine_mats = np.stack(new_affine_mats) + intrinsics, w2cs, c2ws, near_fars = np.stack(intrinsics), np.stack(new_w2cs), np.stack(new_c2ws), np.stack( + new_near_fars) + + if self.split == 'train': + start_idx = 0 + else: + start_idx = 1 + + view_ids = [idx] + list(src_views) + sample['origin_idx'] = origin_idx + sample['images'] = imgs # (V, 3, H, W) + sample['depths_h'] = torch.from_numpy(depths_h.astype(np.float32)) # (V, H, W) + sample['masks_h'] = torch.from_numpy(masks_h.astype(np.float32)) # (V, H, W) + sample['w2cs'] = torch.from_numpy(w2cs.astype(np.float32)) # (V, 4, 4) + sample['c2ws'] = torch.from_numpy(c2ws.astype(np.float32)) # (V, 4, 4) + sample['near_fars'] = torch.from_numpy(near_fars.astype(np.float32)) # (V, 2) + sample['intrinsics'] = torch.from_numpy(intrinsics.astype(np.float32))[:, :3, :3] # (V, 3, 3) + sample['view_ids'] = torch.from_numpy(np.array(view_ids)) + sample['affine_mats'] = torch.from_numpy(affine_mats.astype(np.float32)) # ! in world space + + # sample['light_idx'] = torch.tensor(light_idx) + sample['scan'] = folder_id + + sample['scale_factor'] = torch.tensor(scale_factor) + sample['img_wh'] = torch.from_numpy(np.array(img_wh)) + sample['render_img_idx'] = torch.tensor(image_perm) + sample['partial_vol_origin'] = self.partial_vol_origin + sample['meta'] = str(folder_id) + "_" + str(uid) + "_refview" + str(view_ids[0]) + + + # - image to render + sample['query_image'] = sample['images'][0] + sample['query_c2w'] = sample['c2ws'][0] + sample['query_w2c'] = sample['w2cs'][0] + sample['query_intrinsic'] = sample['intrinsics'][0] + sample['query_depth'] = sample['depths_h'][0] + sample['query_mask'] = sample['masks_h'][0] + sample['query_near_far'] = sample['near_fars'][0] + + sample['images'] = sample['images'][start_idx:] # (V, 3, H, W) + sample['depths_h'] = sample['depths_h'][start_idx:] # (V, H, W) + sample['masks_h'] = sample['masks_h'][start_idx:] # (V, H, W) + sample['w2cs'] = sample['w2cs'][start_idx:] # (V, 4, 4) + sample['c2ws'] = sample['c2ws'][start_idx:] # (V, 4, 4) + sample['intrinsics'] = sample['intrinsics'][start_idx:] # (V, 3, 3) + sample['view_ids'] = sample['view_ids'][start_idx:] + sample['affine_mats'] = sample['affine_mats'][start_idx:] # ! in world space + + sample['scale_mat'] = torch.from_numpy(scale_mat) + sample['trans_mat'] = torch.from_numpy(w2c_ref_inv) + + # - generate rays + if ('val' in self.split) or ('test' in self.split): + sample_rays = gen_rays_from_single_image( + img_wh[1], img_wh[0], + sample['query_image'], + sample['query_intrinsic'], + sample['query_c2w'], + depth=sample['query_depth'], + mask=sample['query_mask'] if self.clean_image else None) + else: + sample_rays = gen_random_rays_from_single_image( + img_wh[1], img_wh[0], + self.N_rays, + sample['query_image'], + sample['query_intrinsic'], + sample['query_c2w'], + depth=sample['query_depth'], + mask=sample['query_mask'] if self.clean_image else None, + dilated_mask=mask_dilated, + importance_sample=self.importance_sample) + + + sample['rays'] = sample_rays + + return sample diff --git a/SparseNeuS_demo_v1/data/blender_general_narrow_all_only_4.py b/SparseNeuS_demo_v1/data/blender_general_narrow_all_only_4.py new file mode 100644 index 0000000000000000000000000000000000000000..f811326da45563ae870350f78ccdbe358411f3b6 --- /dev/null +++ b/SparseNeuS_demo_v1/data/blender_general_narrow_all_only_4.py @@ -0,0 +1,389 @@ +from torch.utils.data import Dataset +from utils.misc_utils import read_pfm +import os +import numpy as np +import cv2 +from PIL import Image +import torch +from torchvision import transforms as T +from data.scene import get_boundingbox + +from models.rays import gen_rays_from_single_image, gen_random_rays_from_single_image +import json +from termcolor import colored +import imageio +from kornia import create_meshgrid +import open3d as o3d +def get_ray_directions(H, W, focal, center=None): + """ + Get ray directions for all pixels in camera coordinate. + Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/ + ray-tracing-generating-camera-rays/standard-coordinate-systems + Inputs: + H, W, focal: image height, width and focal length + Outputs: + directions: (H, W, 3), the direction of the rays in camera coordinate + """ + grid = create_meshgrid(H, W, normalized_coordinates=False)[0] + 0.5 # 1xHxWx2 + + i, j = grid.unbind(-1) + # the direction here is without +0.5 pixel centering as calibration is not so accurate + # see https://github.com/bmild/nerf/issues/24 + cent = center if center is not None else [W / 2, H / 2] + directions = torch.stack([(i - cent[0]) / focal[0], (j - cent[1]) / focal[1], torch.ones_like(i)], -1) # (H, W, 3) + + return directions + +def load_K_Rt_from_P(filename, P=None): + if P is None: + lines = open(filename).read().splitlines() + if len(lines) == 4: + lines = lines[1:] + lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)] + P = np.asarray(lines).astype(np.float32).squeeze() + + out = cv2.decomposeProjectionMatrix(P) + K = out[0] + R = out[1] + t = out[2] + + K = K / K[2, 2] + intrinsics = np.eye(4) + intrinsics[:3, :3] = K + + pose = np.eye(4, dtype=np.float32) + pose[:3, :3] = R.transpose() # ? why need transpose here + pose[:3, 3] = (t[:3] / t[3])[:, 0] + + return intrinsics, pose # ! return cam2world matrix here + + +# ! load one ref-image with multiple src-images in camera coordinate system +class BlenderPerView(Dataset): + def __init__(self, root_dir, split, n_views=3, img_wh=(256, 256), downSample=1.0, + split_filepath=None, pair_filepath=None, + N_rays=512, + vol_dims=[128, 128, 128], batch_size=1, + clean_image=False, importance_sample=False, test_ref_views=[]): + + # print("root_dir: ", root_dir) + self.root_dir = root_dir + self.split = split + + self.n_views = n_views + self.N_rays = N_rays + self.batch_size = batch_size # - used for construct new metas for gru fusion training + + self.clean_image = clean_image + self.importance_sample = importance_sample + self.test_ref_views = test_ref_views # used for testing + self.scale_factor = 1.0 + self.scale_mat = np.float32(np.diag([1, 1, 1, 1.0])) + + lvis_json_path = '/objaverse-processed/zero12345_img/lvis_split.json' # folder_id and uid + with open(lvis_json_path, 'r') as f: + lvis_paths = json.load(f) + if self.split == 'train': + self.lvis_paths = lvis_paths['train'] + else: + self.lvis_paths = lvis_paths['val'] + if img_wh is not None: + assert img_wh[0] % 32 == 0 and img_wh[1] % 32 == 0, \ + 'img_wh must both be multiples of 32!' + + + pose_json_path = "/objaverse-processed/zero12345_img/zero12345_narrow_pose.json" + with open(pose_json_path, 'r') as f: + meta = json.load(f) + + self.img_ids = list(meta["c2ws"].keys()) # e.g. "view_0", "view_7", "view_0_2_10" + self.img_wh = (256, 256) + self.input_poses = np.array(list(meta["c2ws"].values())) + intrinsic = np.eye(4) + intrinsic[:3, :3] = np.array(meta["intrinsics"]) + self.intrinsic = intrinsic + self.near_far = np.array(meta["near_far"]) + self.near_far[1] = 1.8 + self.define_transforms() + self.blender2opencv = np.array( + [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]] + ) + + + self.c2ws = [] + self.w2cs = [] + self.near_fars = [] + # self.root_dir = root_dir + for idx, img_id in enumerate(self.img_ids): + pose = self.input_poses[idx] + c2w = pose @ self.blender2opencv + self.c2ws.append(c2w) + self.w2cs.append(np.linalg.inv(c2w)) + self.near_fars.append(self.near_far) + self.c2ws = np.stack(self.c2ws, axis=0) + self.w2cs = np.stack(self.w2cs, axis=0) + + + self.all_intrinsics = [] # the cam info of the whole scene + self.all_extrinsics = [] + self.all_near_fars = [] + self.load_cam_info() + + # * bounding box for rendering + self.bbox_min = np.array([-1.0, -1.0, -1.0]) + self.bbox_max = np.array([1.0, 1.0, 1.0]) + + # - used for cost volume regularization + self.voxel_dims = torch.tensor(vol_dims, dtype=torch.float32) + self.partial_vol_origin = torch.tensor([-1., -1., -1.], dtype=torch.float32) + + + def define_transforms(self): + self.transform = T.Compose([T.ToTensor()]) + + + + def load_cam_info(self): + for vid, img_id in enumerate(self.img_ids): + intrinsic, extrinsic, near_far = self.intrinsic, np.linalg.inv(self.c2ws[vid]), self.near_far + self.all_intrinsics.append(intrinsic) + self.all_extrinsics.append(extrinsic) + self.all_near_fars.append(near_far) + + def read_depth(self, filename): + pass + + def read_mask(self, filename): + mask_h = cv2.imread(filename, 0) + mask_h = cv2.resize(mask_h, None, fx=self.downSample, fy=self.downSample, + interpolation=cv2.INTER_NEAREST) + mask = cv2.resize(mask_h, None, fx=0.25, fy=0.25, + interpolation=cv2.INTER_NEAREST) + + mask[mask > 0] = 1 # the masks stored in png are not binary + mask_h[mask_h > 0] = 1 + + return mask, mask_h + + def cal_scale_mat(self, img_hw, intrinsics, extrinsics, near_fars, factor=1.): + + center, radius, bounds = get_boundingbox(img_hw, intrinsics, extrinsics, near_fars) + # print("center", center) + # print("radius", radius) + # print("bounds", bounds) + # import ipdb; ipdb.set_trace() + radius = radius * factor + scale_mat = np.diag([radius, radius, radius, 1.0]) + scale_mat[:3, 3] = center.cpu().numpy() + scale_mat = scale_mat.astype(np.float32) + + return scale_mat, 1. / radius.cpu().numpy() + + def __len__(self): + return 4*len(self.lvis_paths) + + + def read_depth(self, filename, near_bound, noisy_factor=1.0): + pass + + + def __getitem__(self, idx): + idx = idx * 2 + sample = {} + origin_idx = idx + imgs, depths_h, masks_h = [], [], [] # full size (256, 256) + intrinsics, w2cs, c2ws, near_fars = [], [], [], [] # record proj mats between views + + + folder_uid_dict = self.lvis_paths[idx//8] + idx = idx % 8 # [0, 7] + folder_id = folder_uid_dict['folder_id'] + uid = folder_uid_dict['uid'] + + + # target view + c2w = self.c2ws[idx] + w2c = np.linalg.inv(c2w) + w2c_ref = w2c + w2c_ref_inv = np.linalg.inv(w2c_ref) + + w2cs.append(w2c @ w2c_ref_inv) + c2ws.append(np.linalg.inv(w2c @ w2c_ref_inv)) + + img_filename = os.path.join(self.root_dir, folder_id, uid, f'view_{idx}.png') + + depth_filename = os.path.join(os.path.join(self.root_dir, folder_id, uid, f'view_{idx}_depth_mm.png')) + + + img = Image.open(img_filename) + + img = self.transform(img) # (4, h, w) + + + if img.shape[0] == 4: + img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB + imgs += [img] + + depth_h = cv2.imread(depth_filename, cv2.IMREAD_UNCHANGED).astype(np.uint16) / 1000.0 + mask_h = depth_h > 0 + # print("valid pixels", np.sum(mask_h)) + directions = get_ray_directions(self.img_wh[1], self.img_wh[0], [self.intrinsic[0, 0], self.intrinsic[1, 1]]) # [H, W, 3] + surface_points = directions * depth_h[..., None] # [H, W, 3] + distance = np.linalg.norm(surface_points, axis=-1) # [H, W] + depth_h = distance + + + depths_h.append(depth_h) + masks_h.append(mask_h) + + intrinsic = self.intrinsic + intrinsics.append(intrinsic) + + + near_fars.append(self.near_fars[idx]) + image_perm = 0 # only supervised on reference view + + mask_dilated = None + + # src_views = range(8+idx*4, 8+(idx+1)*4) + src_views = range(8, 8 + 8 * 4) + + for vid in src_views: + if (vid // 4) % 2 != 0: + continue + img_filename = os.path.join(self.root_dir, folder_id, uid, f'view_{(vid - 8) // 4}_{vid%4}_10.png') + + img = Image.open(img_filename) + img_wh = self.img_wh + + img = self.transform(img) + if img.shape[0] == 4: + img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB + + imgs += [img] + depth_h = np.ones(img.shape[1:], dtype=np.float32) + depths_h.append(depth_h) + masks_h.append(np.ones(img.shape[1:], dtype=np.int32)) + + near_fars.append(self.all_near_fars[vid]) + intrinsics.append(self.all_intrinsics[vid]) + + w2cs.append(self.all_extrinsics[vid] @ w2c_ref_inv) + + # print("len(imgs)", len(imgs)) + # ! estimate scale_mat + scale_mat, scale_factor = self.cal_scale_mat( + img_hw=[img_wh[1], img_wh[0]], + intrinsics=intrinsics, extrinsics=w2cs, + near_fars=near_fars, factor=1.1 + ) + + + new_near_fars = [] + new_w2cs = [] + new_c2ws = [] + new_affine_mats = [] + new_depths_h = [] + for intrinsic, extrinsic, near_far, depth in zip(intrinsics, w2cs, near_fars, depths_h): + + P = intrinsic @ extrinsic @ scale_mat + P = P[:3, :4] + # - should use load_K_Rt_from_P() to obtain c2w + c2w = load_K_Rt_from_P(None, P)[1] + w2c = np.linalg.inv(c2w) + new_w2cs.append(w2c) + new_c2ws.append(c2w) + affine_mat = np.eye(4) + affine_mat[:3, :4] = intrinsic[:3, :3] @ w2c[:3, :4] + new_affine_mats.append(affine_mat) + + camera_o = c2w[:3, 3] + dist = np.sqrt(np.sum(camera_o ** 2)) + near = dist - 1 + far = dist + 1 + + new_near_fars.append([0.95 * near, 1.05 * far]) + new_depths_h.append(depth * scale_factor) + + # print(new_near_fars) + imgs = torch.stack(imgs).float() + depths_h = np.stack(new_depths_h) + masks_h = np.stack(masks_h) + + affine_mats = np.stack(new_affine_mats) + intrinsics, w2cs, c2ws, near_fars = np.stack(intrinsics), np.stack(new_w2cs), np.stack(new_c2ws), np.stack( + new_near_fars) + + if self.split == 'train': + start_idx = 0 + else: + start_idx = 1 + + view_ids = [idx] + list(src_views) + sample['origin_idx'] = origin_idx + sample['images'] = imgs # (V, 3, H, W) + sample['depths_h'] = torch.from_numpy(depths_h.astype(np.float32)) # (V, H, W) + sample['masks_h'] = torch.from_numpy(masks_h.astype(np.float32)) # (V, H, W) + sample['w2cs'] = torch.from_numpy(w2cs.astype(np.float32)) # (V, 4, 4) + sample['c2ws'] = torch.from_numpy(c2ws.astype(np.float32)) # (V, 4, 4) + sample['near_fars'] = torch.from_numpy(near_fars.astype(np.float32)) # (V, 2) + sample['intrinsics'] = torch.from_numpy(intrinsics.astype(np.float32))[:, :3, :3] # (V, 3, 3) + sample['view_ids'] = torch.from_numpy(np.array(view_ids)) + sample['affine_mats'] = torch.from_numpy(affine_mats.astype(np.float32)) # ! in world space + + # sample['light_idx'] = torch.tensor(light_idx) + sample['scan'] = folder_id + + sample['scale_factor'] = torch.tensor(scale_factor) + sample['img_wh'] = torch.from_numpy(np.array(img_wh)) + sample['render_img_idx'] = torch.tensor(image_perm) + sample['partial_vol_origin'] = self.partial_vol_origin + sample['meta'] = str(folder_id) + "_" + str(uid) + "_refview" + str(view_ids[0]) + + + # - image to render + sample['query_image'] = sample['images'][0] + sample['query_c2w'] = sample['c2ws'][0] + sample['query_w2c'] = sample['w2cs'][0] + sample['query_intrinsic'] = sample['intrinsics'][0] + sample['query_depth'] = sample['depths_h'][0] + sample['query_mask'] = sample['masks_h'][0] + sample['query_near_far'] = sample['near_fars'][0] + + sample['images'] = sample['images'][start_idx:] # (V, 3, H, W) + sample['depths_h'] = sample['depths_h'][start_idx:] # (V, H, W) + sample['masks_h'] = sample['masks_h'][start_idx:] # (V, H, W) + sample['w2cs'] = sample['w2cs'][start_idx:] # (V, 4, 4) + sample['c2ws'] = sample['c2ws'][start_idx:] # (V, 4, 4) + sample['intrinsics'] = sample['intrinsics'][start_idx:] # (V, 3, 3) + sample['view_ids'] = sample['view_ids'][start_idx:] + sample['affine_mats'] = sample['affine_mats'][start_idx:] # ! in world space + + sample['scale_mat'] = torch.from_numpy(scale_mat) + sample['trans_mat'] = torch.from_numpy(w2c_ref_inv) + + # - generate rays + if ('val' in self.split) or ('test' in self.split): + sample_rays = gen_rays_from_single_image( + img_wh[1], img_wh[0], + sample['query_image'], + sample['query_intrinsic'], + sample['query_c2w'], + depth=sample['query_depth'], + mask=sample['query_mask'] if self.clean_image else None) + else: + sample_rays = gen_random_rays_from_single_image( + img_wh[1], img_wh[0], + self.N_rays, + sample['query_image'], + sample['query_intrinsic'], + sample['query_c2w'], + depth=sample['query_depth'], + mask=sample['query_mask'] if self.clean_image else None, + dilated_mask=mask_dilated, + importance_sample=self.importance_sample) + + + sample['rays'] = sample_rays + + return sample diff --git a/SparseNeuS_demo_v1/data/blender_general_narrow_all_only_4_and_4.py b/SparseNeuS_demo_v1/data/blender_general_narrow_all_only_4_and_4.py new file mode 100644 index 0000000000000000000000000000000000000000..76b9fccad69f6929e086074b55807ef5a0a17eee --- /dev/null +++ b/SparseNeuS_demo_v1/data/blender_general_narrow_all_only_4_and_4.py @@ -0,0 +1,395 @@ +from torch.utils.data import Dataset +from utils.misc_utils import read_pfm +import os +import numpy as np +import cv2 +from PIL import Image +import torch +from torchvision import transforms as T +from data.scene import get_boundingbox + +from models.rays import gen_rays_from_single_image, gen_random_rays_from_single_image +import json +from termcolor import colored +import imageio +from kornia import create_meshgrid +import open3d as o3d +def get_ray_directions(H, W, focal, center=None): + """ + Get ray directions for all pixels in camera coordinate. + Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/ + ray-tracing-generating-camera-rays/standard-coordinate-systems + Inputs: + H, W, focal: image height, width and focal length + Outputs: + directions: (H, W, 3), the direction of the rays in camera coordinate + """ + grid = create_meshgrid(H, W, normalized_coordinates=False)[0] + 0.5 # 1xHxWx2 + + i, j = grid.unbind(-1) + # the direction here is without +0.5 pixel centering as calibration is not so accurate + # see https://github.com/bmild/nerf/issues/24 + cent = center if center is not None else [W / 2, H / 2] + directions = torch.stack([(i - cent[0]) / focal[0], (j - cent[1]) / focal[1], torch.ones_like(i)], -1) # (H, W, 3) + + return directions + +def load_K_Rt_from_P(filename, P=None): + if P is None: + lines = open(filename).read().splitlines() + if len(lines) == 4: + lines = lines[1:] + lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)] + P = np.asarray(lines).astype(np.float32).squeeze() + + out = cv2.decomposeProjectionMatrix(P) + K = out[0] + R = out[1] + t = out[2] + + K = K / K[2, 2] + intrinsics = np.eye(4) + intrinsics[:3, :3] = K + + pose = np.eye(4, dtype=np.float32) + pose[:3, :3] = R.transpose() # ? why need transpose here + pose[:3, 3] = (t[:3] / t[3])[:, 0] + + return intrinsics, pose # ! return cam2world matrix here + + +# ! load one ref-image with multiple src-images in camera coordinate system +class BlenderPerView(Dataset): + def __init__(self, root_dir, split, n_views=3, img_wh=(256, 256), downSample=1.0, + split_filepath=None, pair_filepath=None, + N_rays=512, + vol_dims=[128, 128, 128], batch_size=1, + clean_image=False, importance_sample=False, test_ref_views=[]): + + # print("root_dir: ", root_dir) + self.root_dir = root_dir + self.split = split + + self.n_views = n_views + self.N_rays = N_rays + self.batch_size = batch_size # - used for construct new metas for gru fusion training + + self.clean_image = clean_image + self.importance_sample = importance_sample + self.test_ref_views = test_ref_views # used for testing + self.scale_factor = 1.0 + self.scale_mat = np.float32(np.diag([1, 1, 1, 1.0])) + + lvis_json_path = '/objaverse-processed/zero12345_img/lvis_split.json' # folder_id and uid + with open(lvis_json_path, 'r') as f: + lvis_paths = json.load(f) + if self.split == 'train': + self.lvis_paths = lvis_paths['train'] + else: + self.lvis_paths = lvis_paths['val'] + if img_wh is not None: + assert img_wh[0] % 32 == 0 and img_wh[1] % 32 == 0, \ + 'img_wh must both be multiples of 32!' + + + pose_json_path = "/objaverse-processed/zero12345_img/zero12345_narrow_pose.json" + with open(pose_json_path, 'r') as f: + meta = json.load(f) + + self.img_ids = list(meta["c2ws"].keys()) # e.g. "view_0", "view_7", "view_0_2_10" + self.img_wh = (256, 256) + self.input_poses = np.array(list(meta["c2ws"].values())) + intrinsic = np.eye(4) + intrinsic[:3, :3] = np.array(meta["intrinsics"]) + self.intrinsic = intrinsic + self.near_far = np.array(meta["near_far"]) + self.near_far[1] = 1.8 + self.define_transforms() + self.blender2opencv = np.array( + [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]] + ) + + + self.c2ws = [] + self.w2cs = [] + self.near_fars = [] + # self.root_dir = root_dir + for idx, img_id in enumerate(self.img_ids): + pose = self.input_poses[idx] + c2w = pose @ self.blender2opencv + self.c2ws.append(c2w) + self.w2cs.append(np.linalg.inv(c2w)) + self.near_fars.append(self.near_far) + self.c2ws = np.stack(self.c2ws, axis=0) + self.w2cs = np.stack(self.w2cs, axis=0) + + + self.all_intrinsics = [] # the cam info of the whole scene + self.all_extrinsics = [] + self.all_near_fars = [] + self.load_cam_info() + + # * bounding box for rendering + self.bbox_min = np.array([-1.0, -1.0, -1.0]) + self.bbox_max = np.array([1.0, 1.0, 1.0]) + + # - used for cost volume regularization + self.voxel_dims = torch.tensor(vol_dims, dtype=torch.float32) + self.partial_vol_origin = torch.tensor([-1., -1., -1.], dtype=torch.float32) + + + def define_transforms(self): + self.transform = T.Compose([T.ToTensor()]) + + + + def load_cam_info(self): + for vid, img_id in enumerate(self.img_ids): + intrinsic, extrinsic, near_far = self.intrinsic, np.linalg.inv(self.c2ws[vid]), self.near_far + self.all_intrinsics.append(intrinsic) + self.all_extrinsics.append(extrinsic) + self.all_near_fars.append(near_far) + + def read_depth(self, filename): + pass + + def read_mask(self, filename): + mask_h = cv2.imread(filename, 0) + mask_h = cv2.resize(mask_h, None, fx=self.downSample, fy=self.downSample, + interpolation=cv2.INTER_NEAREST) + mask = cv2.resize(mask_h, None, fx=0.25, fy=0.25, + interpolation=cv2.INTER_NEAREST) + + mask[mask > 0] = 1 # the masks stored in png are not binary + mask_h[mask_h > 0] = 1 + + return mask, mask_h + + def cal_scale_mat(self, img_hw, intrinsics, extrinsics, near_fars, factor=1.): + + center, radius, bounds = get_boundingbox(img_hw, intrinsics, extrinsics, near_fars) + # print("center", center) + # print("radius", radius) + # print("bounds", bounds) + # import ipdb; ipdb.set_trace() + radius = radius * factor + scale_mat = np.diag([radius, radius, radius, 1.0]) + scale_mat[:3, 3] = center.cpu().numpy() + scale_mat = scale_mat.astype(np.float32) + + return scale_mat, 1. / radius.cpu().numpy() + + def __len__(self): + return 8*len(self.lvis_paths) + + + def read_depth(self, filename, near_bound, noisy_factor=1.0): + pass + + + def __getitem__(self, idx): + idx = idx + sample = {} + origin_idx = idx + imgs, depths_h, masks_h = [], [], [] # full size (256, 256) + intrinsics, w2cs, c2ws, near_fars = [], [], [], [] # record proj mats between views + + + folder_uid_dict = self.lvis_paths[idx//8] + idx = idx % 8 # [0, 7] + folder_id = folder_uid_dict['folder_id'] + uid = folder_uid_dict['uid'] + + + # target view + c2w = self.c2ws[idx] + w2c = np.linalg.inv(c2w) + w2c_ref = w2c + w2c_ref_inv = np.linalg.inv(w2c_ref) + + w2cs.append(w2c @ w2c_ref_inv) + c2ws.append(np.linalg.inv(w2c @ w2c_ref_inv)) + + img_filename = os.path.join(self.root_dir, folder_id, uid, f'view_{idx}.png') + + depth_filename = os.path.join(os.path.join(self.root_dir, folder_id, uid, f'view_{idx}_depth_mm.png')) + + + img = Image.open(img_filename) + + img = self.transform(img) # (4, h, w) + + + if img.shape[0] == 4: + img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB + imgs += [img] + + depth_h = cv2.imread(depth_filename, cv2.IMREAD_UNCHANGED).astype(np.uint16) / 1000.0 + mask_h = depth_h > 0 + # print("valid pixels", np.sum(mask_h)) + directions = get_ray_directions(self.img_wh[1], self.img_wh[0], [self.intrinsic[0, 0], self.intrinsic[1, 1]]) # [H, W, 3] + surface_points = directions * depth_h[..., None] # [H, W, 3] + distance = np.linalg.norm(surface_points, axis=-1) # [H, W] + depth_h = distance + + + depths_h.append(depth_h) + masks_h.append(mask_h) + + intrinsic = self.intrinsic + intrinsics.append(intrinsic) + + + near_fars.append(self.near_fars[idx]) + image_perm = 0 # only supervised on reference view + + mask_dilated = None + + # src_views = range(8+idx*4, 8+(idx+1)*4) + + src_views = range(8, 8 + 8 * 4) + + vid_list = [] + for vid in src_views: + if (vid // 4) % 2 != idx % 2: + continue + vid_list.append(vid) + img_filename = os.path.join(self.root_dir, folder_id, uid, f'view_{(vid - 8) // 4}_{vid%4}_10.png') + + img = Image.open(img_filename) + img_wh = self.img_wh + + img = self.transform(img) + if img.shape[0] == 4: + img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB + + imgs += [img] + depth_h = np.ones(img.shape[1:], dtype=np.float32) + depths_h.append(depth_h) + masks_h.append(np.ones(img.shape[1:], dtype=np.int32)) + + near_fars.append(self.all_near_fars[vid]) + intrinsics.append(self.all_intrinsics[vid]) + + w2cs.append(self.all_extrinsics[vid] @ w2c_ref_inv) + + + # print("idx:", idx) + # print("len(imgs)", len(imgs)) + # print("vid_list", vid_list) + # ! estimate scale_mat + scale_mat, scale_factor = self.cal_scale_mat( + img_hw=[img_wh[1], img_wh[0]], + intrinsics=intrinsics, extrinsics=w2cs, + near_fars=near_fars, factor=1.1 + ) + + + new_near_fars = [] + new_w2cs = [] + new_c2ws = [] + new_affine_mats = [] + new_depths_h = [] + for intrinsic, extrinsic, near_far, depth in zip(intrinsics, w2cs, near_fars, depths_h): + + P = intrinsic @ extrinsic @ scale_mat + P = P[:3, :4] + # - should use load_K_Rt_from_P() to obtain c2w + c2w = load_K_Rt_from_P(None, P)[1] + w2c = np.linalg.inv(c2w) + new_w2cs.append(w2c) + new_c2ws.append(c2w) + affine_mat = np.eye(4) + affine_mat[:3, :4] = intrinsic[:3, :3] @ w2c[:3, :4] + new_affine_mats.append(affine_mat) + + camera_o = c2w[:3, 3] + dist = np.sqrt(np.sum(camera_o ** 2)) + near = dist - 1 + far = dist + 1 + + new_near_fars.append([0.95 * near, 1.05 * far]) + new_depths_h.append(depth * scale_factor) + + # print(new_near_fars) + imgs = torch.stack(imgs).float() + depths_h = np.stack(new_depths_h) + masks_h = np.stack(masks_h) + + affine_mats = np.stack(new_affine_mats) + intrinsics, w2cs, c2ws, near_fars = np.stack(intrinsics), np.stack(new_w2cs), np.stack(new_c2ws), np.stack( + new_near_fars) + + if self.split == 'train': + start_idx = 0 + else: + start_idx = 1 + + view_ids = [idx] + list(src_views) + sample['origin_idx'] = origin_idx + sample['images'] = imgs # (V, 3, H, W) + sample['depths_h'] = torch.from_numpy(depths_h.astype(np.float32)) # (V, H, W) + sample['masks_h'] = torch.from_numpy(masks_h.astype(np.float32)) # (V, H, W) + sample['w2cs'] = torch.from_numpy(w2cs.astype(np.float32)) # (V, 4, 4) + sample['c2ws'] = torch.from_numpy(c2ws.astype(np.float32)) # (V, 4, 4) + sample['near_fars'] = torch.from_numpy(near_fars.astype(np.float32)) # (V, 2) + sample['intrinsics'] = torch.from_numpy(intrinsics.astype(np.float32))[:, :3, :3] # (V, 3, 3) + sample['view_ids'] = torch.from_numpy(np.array(view_ids)) + sample['affine_mats'] = torch.from_numpy(affine_mats.astype(np.float32)) # ! in world space + + # sample['light_idx'] = torch.tensor(light_idx) + sample['scan'] = folder_id + + sample['scale_factor'] = torch.tensor(scale_factor) + sample['img_wh'] = torch.from_numpy(np.array(img_wh)) + sample['render_img_idx'] = torch.tensor(image_perm) + sample['partial_vol_origin'] = self.partial_vol_origin + sample['meta'] = str(folder_id) + "_" + str(uid) + "_refview" + str(view_ids[0]) + + + # - image to render + sample['query_image'] = sample['images'][0] + sample['query_c2w'] = sample['c2ws'][0] + sample['query_w2c'] = sample['w2cs'][0] + sample['query_intrinsic'] = sample['intrinsics'][0] + sample['query_depth'] = sample['depths_h'][0] + sample['query_mask'] = sample['masks_h'][0] + sample['query_near_far'] = sample['near_fars'][0] + + sample['images'] = sample['images'][start_idx:] # (V, 3, H, W) + sample['depths_h'] = sample['depths_h'][start_idx:] # (V, H, W) + sample['masks_h'] = sample['masks_h'][start_idx:] # (V, H, W) + sample['w2cs'] = sample['w2cs'][start_idx:] # (V, 4, 4) + sample['c2ws'] = sample['c2ws'][start_idx:] # (V, 4, 4) + sample['intrinsics'] = sample['intrinsics'][start_idx:] # (V, 3, 3) + sample['view_ids'] = sample['view_ids'][start_idx:] + sample['affine_mats'] = sample['affine_mats'][start_idx:] # ! in world space + + sample['scale_mat'] = torch.from_numpy(scale_mat) + sample['trans_mat'] = torch.from_numpy(w2c_ref_inv) + + # - generate rays + if ('val' in self.split) or ('test' in self.split): + sample_rays = gen_rays_from_single_image( + img_wh[1], img_wh[0], + sample['query_image'], + sample['query_intrinsic'], + sample['query_c2w'], + depth=sample['query_depth'], + mask=sample['query_mask'] if self.clean_image else None) + else: + sample_rays = gen_random_rays_from_single_image( + img_wh[1], img_wh[0], + self.N_rays, + sample['query_image'], + sample['query_intrinsic'], + sample['query_c2w'], + depth=sample['query_depth'], + mask=sample['query_mask'] if self.clean_image else None, + dilated_mask=mask_dilated, + importance_sample=self.importance_sample) + + + sample['rays'] = sample_rays + + return sample diff --git a/SparseNeuS_demo_v1/data/blender_gt_32.py b/SparseNeuS_demo_v1/data/blender_gt_32.py new file mode 100644 index 0000000000000000000000000000000000000000..9ec6f0075febfcd46061e61ae10cd68b05dfb5fc --- /dev/null +++ b/SparseNeuS_demo_v1/data/blender_gt_32.py @@ -0,0 +1,419 @@ +from torch.utils.data import Dataset +from utils.misc_utils import read_pfm +import os +import numpy as np +import cv2 +from PIL import Image +import torch +from torchvision import transforms as T +from data.scene import get_boundingbox + +from models.rays import gen_rays_from_single_image, gen_random_rays_from_single_image +import json +from termcolor import colored +import imageio +from kornia import create_meshgrid +import open3d as o3d +def get_ray_directions(H, W, focal, center=None): + """ + Get ray directions for all pixels in camera coordinate. + Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/ + ray-tracing-generating-camera-rays/standard-coordinate-systems + Inputs: + H, W, focal: image height, width and focal length + Outputs: + directions: (H, W, 3), the direction of the rays in camera coordinate + """ + grid = create_meshgrid(H, W, normalized_coordinates=False)[0] + 0.5 # 1xHxWx2 + + i, j = grid.unbind(-1) + # the direction here is without +0.5 pixel centering as calibration is not so accurate + # see https://github.com/bmild/nerf/issues/24 + cent = center if center is not None else [W / 2, H / 2] + directions = torch.stack([(i - cent[0]) / focal[0], (j - cent[1]) / focal[1], torch.ones_like(i)], -1) # (H, W, 3) + + return directions + +import os, json +import numpy as np +def calc_pose(phis, thetas, size, radius = 1.2): + import torch + def normalize(vectors): + return vectors / (torch.norm(vectors, dim=-1, keepdim=True) + 1e-10) + # device = torch.device('cuda') + thetas = torch.FloatTensor(thetas) + phis = torch.FloatTensor(phis) + + centers = torch.stack([ + radius * torch.sin(thetas) * torch.sin(phis), + -radius * torch.cos(thetas) * torch.sin(phis), + radius * torch.cos(phis), + ], dim=-1) # [B, 3] + + # lookat + forward_vector = normalize(centers).squeeze(0) + up_vector = torch.FloatTensor([0, 0, 1]).unsqueeze(0).repeat(size, 1) + right_vector = normalize(torch.cross(up_vector, forward_vector, dim=-1)) + if right_vector.pow(2).sum() < 0.01: + right_vector = torch.FloatTensor([0, 1, 0]).unsqueeze(0).repeat(size, 1) + up_vector = normalize(torch.cross(forward_vector, right_vector, dim=-1)) + + poses = torch.eye(4, dtype=torch.float)[:3].unsqueeze(0).repeat(size, 1, 1) + poses[:, :3, :3] = torch.stack((right_vector, up_vector, forward_vector), dim=-1) + poses[:, :3, 3] = centers + return poses + +def load_K_Rt_from_P(filename, P=None): + if P is None: + lines = open(filename).read().splitlines() + if len(lines) == 4: + lines = lines[1:] + lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)] + P = np.asarray(lines).astype(np.float32).squeeze() + + out = cv2.decomposeProjectionMatrix(P) + K = out[0] + R = out[1] + t = out[2] + + K = K / K[2, 2] + intrinsics = np.eye(4) + intrinsics[:3, :3] = K + + pose = np.eye(4, dtype=np.float32) + pose[:3, :3] = R.transpose() # ? why need transpose here + pose[:3, 3] = (t[:3] / t[3])[:, 0] + + return intrinsics, pose # ! return cam2world matrix here + + +# ! load one ref-image with multiple src-images in camera coordinate system +class BlenderPerView(Dataset): + def __init__(self, root_dir, split, n_views=3, img_wh=(256, 256), downSample=1.0, + split_filepath=None, pair_filepath=None, + N_rays=512, + vol_dims=[128, 128, 128], batch_size=1, + clean_image=False, importance_sample=False, test_ref_views=[]): + + # print("root_dir: ", root_dir) + self.root_dir = root_dir + self.split = split + + self.n_views = n_views + self.N_rays = N_rays + self.batch_size = batch_size # - used for construct new metas for gru fusion training + + self.clean_image = clean_image + self.importance_sample = importance_sample + self.test_ref_views = test_ref_views # used for testing + self.scale_factor = 1.0 + self.scale_mat = np.float32(np.diag([1, 1, 1, 1.0])) + + lvis_json_path = '/objaverse-processed/zero12345_img/random32_split.json' # folder_id and uid + with open(lvis_json_path, 'r') as f: + lvis_paths = json.load(f) + if self.split == 'train': + self.lvis_paths = lvis_paths['train'] + else: + self.lvis_paths = lvis_paths['val'] + if img_wh is not None: + assert img_wh[0] % 32 == 0 and img_wh[1] % 32 == 0, \ + 'img_wh must both be multiples of 32!' + + pose_json_path = "/objaverse-processed/zero12345_img/zero12345_narrow_pose.json" + + with open(pose_json_path, 'r') as f: + meta = json.load(f) + intrinsic = np.eye(4) + intrinsic[:3, :3] = np.array(meta["intrinsics"]) + self.intrinsic = intrinsic + self.near_far = np.array(meta["near_far"]) + self.near_far[1] = 1.8 + + # * bounding box for rendering + self.bbox_min = np.array([-1.0, -1.0, -1.0]) + self.bbox_max = np.array([1.0, 1.0, 1.0]) + + # - used for cost volume regularization + self.voxel_dims = torch.tensor(vol_dims, dtype=torch.float32) + self.partial_vol_origin = torch.tensor([-1., -1., -1.], dtype=torch.float32) + + + def define_transforms(self): + self.transform = T.Compose([T.ToTensor()]) + + + + def load_cam_info(self): + for vid in range(self.input_poses.shape[0]): + intrinsic, extrinsic, near_far = self.intrinsic, np.linalg.inv(self.c2ws[vid]), self.near_far + self.all_intrinsics.append(intrinsic) + self.all_extrinsics.append(extrinsic) + self.all_near_fars.append(near_far) + + def read_depth(self, filename): + pass + + def read_mask(self, filename): + mask_h = cv2.imread(filename, 0) + mask_h = cv2.resize(mask_h, None, fx=self.downSample, fy=self.downSample, + interpolation=cv2.INTER_NEAREST) + mask = cv2.resize(mask_h, None, fx=0.25, fy=0.25, + interpolation=cv2.INTER_NEAREST) + + mask[mask > 0] = 1 # the masks stored in png are not binary + mask_h[mask_h > 0] = 1 + + return mask, mask_h + + def cal_scale_mat(self, img_hw, intrinsics, extrinsics, near_fars, factor=1.): + + center, radius, bounds = get_boundingbox(img_hw, intrinsics, extrinsics, near_fars) + # print("center", center) + # print("radius", radius) + # print("bounds", bounds) + # import ipdb; ipdb.set_trace() + radius = radius * factor + scale_mat = np.diag([radius, radius, radius, 1.0]) + scale_mat[:3, 3] = center.cpu().numpy() + scale_mat = scale_mat.astype(np.float32) + + return scale_mat, 1. / radius.cpu().numpy() + + def __len__(self): + return 32*len(self.lvis_paths) + + + def read_depth(self, filename, near_bound, noisy_factor=1.0): + pass + + + def __getitem__(self, idx): + sample = {} + origin_idx = idx + imgs, depths_h, masks_h = [], [], [] # full size (256, 256) + intrinsics, w2cs, c2ws, near_fars = [], [], [], [] # record proj mats between views + + + folder_uid_dict = self.lvis_paths[idx//32] + idx = idx % 32 # [0, 7] + folder_id = folder_uid_dict['folder_id'] + uid = folder_uid_dict['uid'] + + pose_file = os.path.join('/objaverse-processed/zero12345_img/random32/', folder_id, uid, 'views.npz') + pose_array = np.load(pose_file) + pose = calc_pose(pose_array['elevations'], pose_array['azimuths'], 32) # [32, 3, 4] c2ws + + self.img_wh = (256, 256) + self.input_poses = np.array(pose) + self.input_poses = np.concatenate([self.input_poses, np.tile(np.array([0, 0, 0, 1], dtype=np.float32)[None, None, :], [self.input_poses.shape[0], 1, 1])], axis=1) + self.define_transforms() + self.blender2opencv = np.array( + [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]] + ) + + self.c2ws = [] + self.w2cs = [] + self.near_fars = [] + # self.root_dir = root_dir + for image_dix in range(pose.shape[0]): + pose = self.input_poses[image_dix] + c2w = pose @ self.blender2opencv + self.c2ws.append(c2w) + self.w2cs.append(np.linalg.inv(c2w)) + self.near_fars.append(self.near_far) + self.c2ws = np.stack(self.c2ws, axis=0) + self.w2cs = np.stack(self.w2cs, axis=0) + + + self.all_intrinsics = [] # the cam info of the whole scene + self.all_extrinsics = [] + self.all_near_fars = [] + self.load_cam_info() + + + + # target view + c2w = self.c2ws[idx] + w2c = np.linalg.inv(c2w) + w2c_ref = w2c + w2c_ref_inv = np.linalg.inv(w2c_ref) + + w2cs.append(w2c @ w2c_ref_inv) + c2ws.append(np.linalg.inv(w2c @ w2c_ref_inv)) + + img_filename = os.path.join('/objaverse-processed/zero12345_img/random32/', folder_id, uid, f'{idx}.png') + + depth_filename = os.path.join(os.path.join('/objaverse-processed/zero12345_img/random32/', folder_id, uid, f'{idx}_depth_mm.png')) + + + img = Image.open(img_filename) + + img = self.transform(img) # (4, h, w) + + + if img.shape[0] == 4: + img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB + imgs += [img] + + depth_h = cv2.imread(depth_filename, cv2.IMREAD_UNCHANGED).astype(np.uint16) / 1000.0 + mask_h = depth_h > 0 + + directions = get_ray_directions(self.img_wh[1], self.img_wh[0], [self.intrinsic[0, 0], self.intrinsic[1, 1]]) # [H, W, 3] + surface_points = directions * depth_h[..., None] # [H, W, 3] + distance = np.linalg.norm(surface_points, axis=-1) # [H, W] + depth_h = distance + + + depths_h.append(depth_h) + masks_h.append(mask_h) + + intrinsic = self.intrinsic + intrinsics.append(intrinsic) + + + near_fars.append(self.near_fars[idx]) + image_perm = 0 # only supervised on reference view + + mask_dilated = None + + # src_views = range(8+idx*4, 8+(idx+1)*4) + src_views = range(0, 8 * 4) + + for vid in src_views: + img_filename = os.path.join('/objaverse-processed/zero12345_img/random32/', folder_id, uid, f'{vid}.png') + + img = Image.open(img_filename) + img_wh = self.img_wh + + img = self.transform(img) + if img.shape[0] == 4: + img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB + + imgs += [img] + depth_h = np.ones(img.shape[1:], dtype=np.float32) + depths_h.append(depth_h) + masks_h.append(np.ones(img.shape[1:], dtype=np.int32)) + + near_fars.append(self.all_near_fars[vid]) + intrinsics.append(self.all_intrinsics[vid]) + + w2cs.append(self.all_extrinsics[vid] @ w2c_ref_inv) + + + # ! estimate scale_mat + scale_mat, scale_factor = self.cal_scale_mat( + img_hw=[img_wh[1], img_wh[0]], + intrinsics=intrinsics, extrinsics=w2cs, + near_fars=near_fars, factor=1.1 + ) + + + new_near_fars = [] + new_w2cs = [] + new_c2ws = [] + new_affine_mats = [] + new_depths_h = [] + for intrinsic, extrinsic, near_far, depth in zip(intrinsics, w2cs, near_fars, depths_h): + + P = intrinsic @ extrinsic @ scale_mat + P = P[:3, :4] + # - should use load_K_Rt_from_P() to obtain c2w + c2w = load_K_Rt_from_P(None, P)[1] + w2c = np.linalg.inv(c2w) + new_w2cs.append(w2c) + new_c2ws.append(c2w) + affine_mat = np.eye(4) + affine_mat[:3, :4] = intrinsic[:3, :3] @ w2c[:3, :4] + new_affine_mats.append(affine_mat) + + camera_o = c2w[:3, 3] + dist = np.sqrt(np.sum(camera_o ** 2)) + near = dist - 1 + far = dist + 1 + + new_near_fars.append([0.95 * near, 1.05 * far]) + new_depths_h.append(depth * scale_factor) + + # print(new_near_fars) + imgs = torch.stack(imgs).float() + depths_h = np.stack(new_depths_h) + masks_h = np.stack(masks_h) + + affine_mats = np.stack(new_affine_mats) + intrinsics, w2cs, c2ws, near_fars = np.stack(intrinsics), np.stack(new_w2cs), np.stack(new_c2ws), np.stack( + new_near_fars) + + if self.split == 'train': + start_idx = 0 + else: + start_idx = 1 + + view_ids = [idx] + list(src_views) + sample['origin_idx'] = origin_idx + sample['images'] = imgs # (V, 3, H, W) + sample['depths_h'] = torch.from_numpy(depths_h.astype(np.float32)) # (V, H, W) + sample['masks_h'] = torch.from_numpy(masks_h.astype(np.float32)) # (V, H, W) + sample['w2cs'] = torch.from_numpy(w2cs.astype(np.float32)) # (V, 4, 4) + sample['c2ws'] = torch.from_numpy(c2ws.astype(np.float32)) # (V, 4, 4) + sample['near_fars'] = torch.from_numpy(near_fars.astype(np.float32)) # (V, 2) + sample['intrinsics'] = torch.from_numpy(intrinsics.astype(np.float32))[:, :3, :3] # (V, 3, 3) + sample['view_ids'] = torch.from_numpy(np.array(view_ids)) + sample['affine_mats'] = torch.from_numpy(affine_mats.astype(np.float32)) # ! in world space + + # sample['light_idx'] = torch.tensor(light_idx) + sample['scan'] = folder_id + + sample['scale_factor'] = torch.tensor(scale_factor) + sample['img_wh'] = torch.from_numpy(np.array(img_wh)) + sample['render_img_idx'] = torch.tensor(image_perm) + sample['partial_vol_origin'] = self.partial_vol_origin + sample['meta'] = str(folder_id) + "_" + str(uid) + "_refview" + str(view_ids[0]) + + + # - image to render + sample['query_image'] = sample['images'][0] + sample['query_c2w'] = sample['c2ws'][0] + sample['query_w2c'] = sample['w2cs'][0] + sample['query_intrinsic'] = sample['intrinsics'][0] + sample['query_depth'] = sample['depths_h'][0] + sample['query_mask'] = sample['masks_h'][0] + sample['query_near_far'] = sample['near_fars'][0] + + sample['images'] = sample['images'][start_idx:] # (V, 3, H, W) + sample['depths_h'] = sample['depths_h'][start_idx:] # (V, H, W) + sample['masks_h'] = sample['masks_h'][start_idx:] # (V, H, W) + sample['w2cs'] = sample['w2cs'][start_idx:] # (V, 4, 4) + sample['c2ws'] = sample['c2ws'][start_idx:] # (V, 4, 4) + sample['intrinsics'] = sample['intrinsics'][start_idx:] # (V, 3, 3) + sample['view_ids'] = sample['view_ids'][start_idx:] + sample['affine_mats'] = sample['affine_mats'][start_idx:] # ! in world space + + sample['scale_mat'] = torch.from_numpy(scale_mat) + sample['trans_mat'] = torch.from_numpy(w2c_ref_inv) + + # - generate rays + if ('val' in self.split) or ('test' in self.split): + sample_rays = gen_rays_from_single_image( + img_wh[1], img_wh[0], + sample['query_image'], + sample['query_intrinsic'], + sample['query_c2w'], + depth=sample['query_depth'], + mask=sample['query_mask'] if self.clean_image else None) + else: + sample_rays = gen_random_rays_from_single_image( + img_wh[1], img_wh[0], + self.N_rays, + sample['query_image'], + sample['query_intrinsic'], + sample['query_c2w'], + depth=sample['query_depth'], + mask=sample['query_mask'] if self.clean_image else None, + dilated_mask=mask_dilated, + importance_sample=self.importance_sample) + + + sample['rays'] = sample_rays + + return sample diff --git a/SparseNeuS_demo_v1/data/dtu/dtu_pairs.txt b/SparseNeuS_demo_v1/data/dtu/dtu_pairs.txt new file mode 100644 index 0000000000000000000000000000000000000000..bd0d79868f196991c06ec2a496dbe06e5ded0fd2 --- /dev/null +++ b/SparseNeuS_demo_v1/data/dtu/dtu_pairs.txt @@ -0,0 +1,93 @@ +46 +0 +10 10 2346.410000 1 2036.530000 9 1243.890000 12 1052.870000 11 1000.840000 13 703.583000 2 604.456000 8 439.759000 14 327.419000 27 249.278000 +1 +10 9 2850.870000 10 2583.940000 2 2105.590000 0 2052.840000 8 1868.240000 13 1184.230000 14 1017.510000 12 961.966000 7 670.208000 15 657.218000 +2 +10 8 2501.240000 1 2106.880000 7 1856.500000 9 1782.340000 3 1141.770000 15 1061.760000 14 815.457000 16 762.153000 6 709.789000 10 699.921000 +3 +10 7 1294.390000 6 1159.130000 2 1134.270000 4 905.717000 8 687.320000 5 600.015000 17 496.958000 16 481.969000 1 379.011000 15 307.450000 +4 +10 5 1333.740000 6 1145.150000 3 895.254000 7 486.504000 18 446.420000 2 418.517000 17 326.528000 8 161.115000 16 149.154000 1 103.626000 +5 +10 6 1676.060000 18 1555.060000 4 1335.550000 17 868.416000 3 593.755000 7 467.816000 20 440.579000 19 428.255000 16 242.327000 21 210.253000 +6 +10 17 2332.350000 7 1848.240000 18 1812.740000 5 1696.070000 16 1273.000000 3 1157.990000 4 1155.410000 20 771.624000 21 744.945000 2 700.368000 +7 +10 16 2709.460000 8 2439.700000 15 2078.210000 6 1864.160000 2 1846.600000 17 1791.710000 3 1296.860000 22 957.793000 9 879.088000 21 782.277000 +8 +10 15 3124.010000 9 3099.920000 14 2756.290000 2 2501.220000 7 2449.320000 1 1875.940000 16 1726.040000 13 1325.760000 23 1177.090000 24 1108.820000 +9 +10 13 3355.620000 14 3226.070000 8 3098.800000 10 3097.070000 1 2861.420000 12 1873.630000 2 1785.980000 15 1753.320000 25 1365.450000 0 1261.590000 +10 +10 12 3750.700000 9 3085.870000 13 3028.390000 1 2590.550000 0 2369.790000 11 2266.670000 14 1524.160000 26 1448.150000 27 1293.600000 8 1041.840000 +11 +10 12 3543.760000 27 3056.050000 10 2248.070000 26 1524.280000 28 1273.330000 13 1265.900000 29 1129.550000 0 998.164000 9 591.176000 30 572.919000 +12 +10 27 3889.870000 10 3754.540000 13 3745.210000 11 3584.260000 26 3574.560000 25 1877.110000 9 1866.340000 29 1482.720000 30 1418.510000 14 1341.860000 +13 +10 12 3773.140000 26 3699.280000 25 3657.170000 14 3652.040000 9 3356.290000 10 3049.270000 24 2098.910000 27 1900.960000 31 1460.960000 30 1349.620000 +14 +10 13 3663.520000 24 3610.690000 9 3232.550000 25 3216.400000 15 3128.840000 8 2758.040000 23 2219.910000 26 1567.450000 10 1536.600000 32 1419.330000 +15 +10 23 3194.920000 14 3126.000000 8 3120.430000 16 2897.020000 24 2562.490000 7 2084.050000 22 2041.630000 9 1752.080000 33 1232.290000 13 1137.550000 +16 +10 15 2884.140000 7 2713.880000 22 2708.570000 17 2448.500000 21 2173.300000 23 1908.030000 8 1718.790000 6 1281.960000 35 1047.380000 34 980.064000 +17 +10 21 2632.480000 16 2428.000000 6 2343.570000 18 2250.230000 20 2149.750000 7 1779.420000 22 1380.250000 36 957.046000 5 878.398000 15 789.068000 +18 +9 17 2219.150000 20 2173.020000 6 1802.390000 19 1575.770000 5 1564.810000 21 1160.130000 16 660.317000 7 589.484000 36 559.983000 +19 +7 20 1828.970000 18 1564.630000 17 685.249000 36 613.420000 21 572.770000 5 427.597000 6 368.651000 +20 +8 21 2569.790000 36 2258.330000 18 2186.710000 17 2130.670000 19 1865.060000 35 996.122000 16 799.808000 40 778.721000 +21 +9 36 2704.590000 35 2639.690000 17 2638.190000 20 2605.430000 22 2604.260000 16 2158.250000 34 1239.250000 18 1178.240000 40 1128.570000 +22 +10 23 3232.680000 34 3175.150000 35 2831.090000 16 2712.510000 21 2632.190000 15 2033.390000 33 1712.670000 17 1393.860000 36 1290.960000 24 1195.330000 +23 +10 24 3710.900000 33 3603.070000 22 3244.200000 15 3190.620000 34 3086.490000 14 2220.110000 32 2100.000000 16 1917.100000 35 1359.790000 25 1356.710000 +24 +10 25 3844.600000 32 3750.750000 23 3710.600000 14 3609.090000 33 3091.040000 15 2559.240000 31 2423.710000 13 2109.360000 26 1440.580000 34 1410.030000 +25 +10 26 3951.740000 31 3888.570000 24 3833.070000 13 3667.350000 14 3208.210000 32 2993.460000 30 2681.520000 12 1900.230000 45 1484.030000 27 1462.880000 +26 +10 30 4033.350000 27 3970.470000 25 3925.250000 13 3686.340000 12 3595.590000 29 2943.870000 31 2917.000000 14 1556.340000 11 1554.750000 46 1503.840000 +27 +10 29 4027.840000 26 3929.940000 12 3875.580000 11 3085.030000 28 2908.600000 30 2792.670000 13 1878.420000 25 1438.550000 47 1425.200000 10 1290.250000 +28 +10 29 3687.020000 48 3209.130000 27 2872.860000 47 2014.530000 30 1361.950000 11 1273.600000 26 1062.850000 12 840.841000 46 672.985000 31 271.952000 +29 +10 27 4029.430000 30 3909.550000 28 3739.930000 47 3695.230000 48 3135.870000 26 2910.970000 46 2229.550000 12 1479.160000 31 1430.260000 11 1144.560000 +30 +10 26 4029.860000 29 3953.720000 31 3811.120000 46 3630.460000 47 3105.960000 27 2824.430000 25 2657.890000 45 2347.750000 32 1459.110000 12 1429.620000 +31 +10 25 3882.210000 30 3841.880000 32 3808.500000 45 3649.820000 46 3000.670000 26 2939.940000 24 2409.930000 44 2381.300000 13 1467.590000 29 1459.560000 +32 +10 31 3826.500000 24 3744.140000 33 3613.240000 44 3552.040000 25 3004.600000 45 2884.590000 43 2393.340000 23 2095.270000 30 1478.600000 14 1420.780000 +33 +10 32 3618.110000 23 3598.100000 34 3530.530000 43 3462.370000 24 3091.530000 44 2608.080000 42 2426.000000 22 1717.940000 31 1407.650000 25 1324.780000 +34 +10 33 3523.370000 42 3356.550000 35 3210.340000 22 3178.850000 23 3079.030000 43 2396.450000 41 2386.860000 24 1408.020000 32 1301.340000 21 1256.450000 +35 +10 34 3187.880000 41 3106.440000 36 2866.040000 22 2817.740000 21 2654.870000 40 2416.980000 42 2137.810000 23 1346.860000 33 1150.330000 16 1044.660000 +36 +8 40 2910.700000 35 2832.660000 21 2689.960000 20 2280.460000 41 1787.970000 22 1268.490000 34 981.636000 17 954.229000 +40 +7 36 2918.140000 41 2852.620000 35 2392.960000 21 1124.300000 42 1056.480000 34 877.946000 20 788.701000 +41 +9 35 3111.050000 42 3049.710000 40 2885.360000 34 2371.020000 36 1813.690000 43 1164.710000 22 1126.900000 21 906.536000 33 903.238000 +42 +10 34 3356.980000 43 3183.000000 41 3070.540000 33 2421.770000 35 2155.080000 44 1278.410000 23 1183.520000 22 1147.070000 40 1077.080000 32 899.646000 +43 +10 33 3461.240000 44 3380.740000 42 3188.700000 34 2400.600000 32 2399.090000 45 1359.370000 23 1314.080000 41 1176.120000 24 1159.620000 31 901.556000 +44 +10 32 3550.810000 45 3510.160000 43 3373.110000 33 2602.330000 31 2395.930000 24 1410.430000 46 1386.310000 42 1279.000000 25 1095.240000 34 968.440000 +45 +10 31 3650.090000 46 3555.090000 44 3491.150000 32 2868.390000 30 2373.590000 25 1485.370000 47 1405.280000 43 1349.540000 33 1104.770000 26 1046.810000 +46 +10 30 3635.640000 47 3562.170000 45 3524.170000 31 2976.820000 29 2264.040000 26 1508.870000 44 1367.410000 48 1352.100000 32 1211.240000 25 1102.170000 +47 +10 29 3705.310000 46 3519.760000 48 3450.480000 30 3074.770000 28 2054.630000 27 1434.570000 45 1377.340000 31 1268.230000 26 1223.830000 25 471.111000 +48 +10 47 3401.950000 28 3224.840000 29 3101.160000 46 1317.100000 30 1306.700000 27 1235.070000 26 537.731000 31 291.919000 45 276.869000 11 258.856000 diff --git a/SparseNeuS_demo_v1/data/dtu/lists/test.txt b/SparseNeuS_demo_v1/data/dtu/lists/test.txt new file mode 100644 index 0000000000000000000000000000000000000000..b1420254bbe0fe15e9ad9358cdbaedf34605a558 --- /dev/null +++ b/SparseNeuS_demo_v1/data/dtu/lists/test.txt @@ -0,0 +1,15 @@ +scan24 +scan37 +scan40 +scan55 +scan63 +scan65 +scan69 +scan83 +scan97 +scan105 +scan106 +scan110 +scan114 +scan118 +scan122 \ No newline at end of file diff --git a/SparseNeuS_demo_v1/data/dtu/lists/train.txt b/SparseNeuS_demo_v1/data/dtu/lists/train.txt new file mode 100644 index 0000000000000000000000000000000000000000..4259e846edcee621baf19875e2900e169849f5e3 --- /dev/null +++ b/SparseNeuS_demo_v1/data/dtu/lists/train.txt @@ -0,0 +1,75 @@ +scan1 +scan4 +scan5 +scan6 +scan8 +scan9 +scan10 +scan11 +scan12 +scan13 +scan14 +scan15 +scan16 +scan17 +scan18 +scan19 +scan20 +scan21 +scan22 +scan23 +scan28 +scan29 +scan30 +scan31 +scan32 +scan33 +scan34 +scan35 +scan36 +scan38 +scan39 +scan41 +scan42 +scan43 +scan44 +scan45 +scan46 +scan47 +scan48 +scan49 +scan50 +scan51 +scan52 +scan59 +scan60 +scan61 +scan62 +scan64 +scan74 +scan75 +scan76 +scan77 +scan84 +scan85 +scan86 +scan87 +scan88 +scan89 +scan90 +scan91 +scan92 +scan93 +scan94 +scan95 +scan96 +scan98 +scan99 +scan100 +scan101 +scan102 +scan103 +scan104 +scan126 +scan127 +scan128 \ No newline at end of file diff --git a/SparseNeuS_demo_v1/data/dtu_fit.py b/SparseNeuS_demo_v1/data/dtu_fit.py new file mode 100644 index 0000000000000000000000000000000000000000..e4a97d28b635a9158c49e2a651c7799ad1009021 --- /dev/null +++ b/SparseNeuS_demo_v1/data/dtu_fit.py @@ -0,0 +1,278 @@ +import torch +import torch.nn as nn +import cv2 as cv +import numpy as np +import re +import os +import logging +from glob import glob + +from models.rays import gen_rays_from_single_image, gen_random_rays_from_single_image + +from data.scene import get_boundingbox + + +def load_K_Rt_from_P(filename, P=None): + if P is None: + lines = open(filename).read().splitlines() + if len(lines) == 4: + lines = lines[1:] + lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)] + P = np.asarray(lines).astype(np.float32).squeeze() + + out = cv.decomposeProjectionMatrix(P) + K = out[0] + R = out[1] + t = out[2] + + K = K / K[2, 2] + intrinsics = np.eye(4) + intrinsics[:3, :3] = K + + pose = np.eye(4, dtype=np.float32) + pose[:3, :3] = R.transpose() # ? why need transpose here + pose[:3, 3] = (t[:3] / t[3])[:, 0] + + return intrinsics, pose # ! return cam2world matrix here + + +class DtuFit: + def __init__(self, root_dir, split, scan_id, n_views, train_img_idx=[], test_img_idx=[], + img_wh=[800, 600], clip_wh=[0, 0], original_img_wh=[1600, 1200], + N_rays=512, h_patch_size=5, near=425, far=900): + super(DtuFit, self).__init__() + logging.info('Load data: Begin') + + self.root_dir = root_dir + self.split = split + self.scan_id = scan_id + self.n_views = n_views + + self.near = near + self.far = far + + if self.scan_id is not None: + self.data_dir = os.path.join(self.root_dir, self.scan_id) + else: + self.data_dir = self.root_dir + + self.img_wh = img_wh + self.clip_wh = clip_wh + + if len(self.clip_wh) == 2: + self.clip_wh = self.clip_wh + self.clip_wh + + self.original_img_wh = original_img_wh + self.N_rays = N_rays + self.h_patch_size = h_patch_size # used to extract patch for supervision + self.train_img_idx = train_img_idx + self.test_img_idx = test_img_idx + + camera_dict = np.load(os.path.join(self.data_dir, 'cameras.npz'), allow_pickle=True) + self.images_list = sorted(glob(os.path.join(self.data_dir, "image/*.png"))) + # world_mat: projection matrix: world to image + self.world_mats_np = [camera_dict['world_mat_%d' % idx].astype(np.float32) for idx in + range(len(self.images_list))] + + self.raw_near_fars = np.stack([np.array([self.near, self.far]) for i in range(len(self.images_list))]) + + # - reference image; transform the world system to the ref-camera system + self.ref_img_idx = self.train_img_idx[0] + ref_world_mat = self.world_mats_np[self.ref_img_idx] + self.ref_w2c = np.linalg.inv(load_K_Rt_from_P(None, ref_world_mat[:3, :4])[1]) + + self.all_images = [] + self.all_intrinsics = [] + self.all_w2cs = [] + + self.load_scene() # load the scene + + # ! estimate scale_mat + self.scale_mat, self.scale_factor = self.cal_scale_mat( + img_hw=[self.img_wh[1], self.img_wh[0]], + intrinsics=self.all_intrinsics[self.train_img_idx], + extrinsics=self.all_w2cs[self.train_img_idx], + near_fars=self.raw_near_fars[self.train_img_idx], + factor=1.1) + + # * after scaling and translation, unit bounding box + self.scaled_intrinsics, self.scaled_w2cs, self.scaled_c2ws, \ + self.scaled_affine_mats, self.scaled_near_fars = self.scale_cam_info() + # import ipdb; ipdb.set_trace() + self.bbox_min = np.array([-1.0, -1.0, -1.0]) + self.bbox_max = np.array([1.0, 1.0, 1.0]) + self.partial_vol_origin = torch.Tensor([-1., -1., -1.]) + + logging.info('Load data: End') + + def load_scene(self): + + scale_x = self.img_wh[0] / self.original_img_wh[0] + scale_y = self.img_wh[1] / self.original_img_wh[1] + + for idx in range(len(self.images_list)): + image = cv.imread(self.images_list[idx]) + image = cv.resize(image, (self.img_wh[0], self.img_wh[1])) / 255. + + image = image[self.clip_wh[1]:self.img_wh[1] - self.clip_wh[3], + self.clip_wh[0]:self.img_wh[0] - self.clip_wh[2]] + self.all_images.append(np.transpose(image[:, :, ::-1], (2, 0, 1))) # append [3,] + + P = self.world_mats_np[idx] + P = P[:3, :4] + intrinsics, c2w = load_K_Rt_from_P(None, P) + w2c = np.linalg.inv(c2w) + + intrinsics[:1] *= scale_x + intrinsics[1:2] *= scale_y + + intrinsics[0, 2] -= self.clip_wh[0] + intrinsics[1, 2] -= self.clip_wh[1] + + self.all_intrinsics.append(intrinsics) + # - transform from world system to ref-camera system + self.all_w2cs.append(w2c @ np.linalg.inv(self.ref_w2c)) + + + self.all_images = torch.from_numpy(np.stack(self.all_images)).to(torch.float32) + self.all_intrinsics = torch.from_numpy(np.stack(self.all_intrinsics)).to(torch.float32) + self.all_w2cs = torch.from_numpy(np.stack(self.all_w2cs)).to(torch.float32) + self.img_wh = [self.img_wh[0] - self.clip_wh[0] - self.clip_wh[2], + self.img_wh[1] - self.clip_wh[1] - self.clip_wh[3]] + + def cal_scale_mat(self, img_hw, intrinsics, extrinsics, near_fars, factor=1.): + center, radius, _ = get_boundingbox(img_hw, intrinsics, extrinsics, near_fars) + radius = radius * factor + scale_mat = np.diag([radius, radius, radius, 1.0]) + scale_mat[:3, 3] = center.cpu().numpy() + scale_mat = scale_mat.astype(np.float32) + + return scale_mat, 1. / radius.cpu().numpy() + + def scale_cam_info(self): + new_intrinsics = [] + new_near_fars = [] + new_w2cs = [] + new_c2ws = [] + new_affine_mats = [] + for idx in range(len(self.all_images)): + intrinsics = self.all_intrinsics[idx] + P = intrinsics @ self.all_w2cs[idx] @ self.scale_mat + P = P.cpu().numpy()[:3, :4] + + # - should use load_K_Rt_from_P() to obtain c2w + c2w = load_K_Rt_from_P(None, P)[1] + w2c = np.linalg.inv(c2w) + new_w2cs.append(w2c) + new_c2ws.append(c2w) + new_intrinsics.append(intrinsics) + affine_mat = np.eye(4) + affine_mat[:3, :4] = intrinsics[:3, :3] @ w2c[:3, :4] + new_affine_mats.append(affine_mat) + + camera_o = c2w[:3, 3] + dist = np.sqrt(np.sum(camera_o ** 2)) + near = dist - 1 + far = dist + 1 + + new_near_fars.append([0.95 * near, 1.05 * far]) + + new_intrinsics, new_w2cs, new_c2ws, new_affine_mats, new_near_fars = \ + np.stack(new_intrinsics), np.stack(new_w2cs), np.stack(new_c2ws), \ + np.stack(new_affine_mats), np.stack(new_near_fars) + + new_intrinsics = torch.from_numpy(np.float32(new_intrinsics)) + new_w2cs = torch.from_numpy(np.float32(new_w2cs)) + new_c2ws = torch.from_numpy(np.float32(new_c2ws)) + new_affine_mats = torch.from_numpy(np.float32(new_affine_mats)) + new_near_fars = torch.from_numpy(np.float32(new_near_fars)) + + return new_intrinsics, new_w2cs, new_c2ws, new_affine_mats, new_near_fars + + + def get_conditional_sample(self): + sample = {} + support_idxs = self.train_img_idx + + sample['images'] = self.all_images[support_idxs] # (V, 3, H, W) + sample['w2cs'] = self.scaled_w2cs[self.train_img_idx] # (V, 4, 4) + sample['c2ws'] = self.scaled_c2ws[self.train_img_idx] # (V, 4, 4) + sample['near_fars'] = self.scaled_near_fars[self.train_img_idx] # (V, 2) + sample['intrinsics'] = self.scaled_intrinsics[self.train_img_idx][:, :3, :3] # (V, 3, 3) + sample['affine_mats'] = self.scaled_affine_mats[self.train_img_idx] # ! in world space + + sample['scan'] = self.scan_id + sample['scale_factor'] = torch.tensor(self.scale_factor) + sample['scale_mat'] = torch.from_numpy(self.scale_mat) + sample['trans_mat'] = torch.from_numpy(np.linalg.inv(self.ref_w2c)) + sample['img_wh'] = torch.from_numpy(np.array(self.img_wh)) + sample['partial_vol_origin'] = torch.tensor(self.partial_vol_origin, dtype=torch.float32) + + return sample + + def __len__(self): + if self.split == 'train': + return self.n_views * 1000 + else: + return len(self.test_img_idx) * 1000 + + def __getitem__(self, idx): + sample = {} + + if self.split == 'train': + render_idx = self.train_img_idx[idx % self.n_views] + support_idxs = [idx for idx in self.train_img_idx if idx != render_idx] + else: + # render_idx = idx % self.n_test_images + self.n_train_images + render_idx = self.test_img_idx[idx % len(self.test_img_idx)] + support_idxs = [render_idx] + + sample['images'] = self.all_images[support_idxs] # (V, 3, H, W) + sample['w2cs'] = self.scaled_w2cs[support_idxs] # (V, 4, 4) + sample['c2ws'] = self.scaled_c2ws[support_idxs] # (V, 4, 4) + sample['intrinsics'] = self.scaled_intrinsics[support_idxs][:, :3, :3] # (V, 3, 3) + sample['affine_mats'] = self.scaled_affine_mats[support_idxs] # ! in world space + sample['scan'] = self.scan_id + sample['scale_factor'] = torch.tensor(self.scale_factor) + sample['img_wh'] = torch.from_numpy(np.array(self.img_wh)) + sample['partial_vol_origin'] = torch.tensor(self.partial_vol_origin, dtype=torch.float32) + sample['img_index'] = torch.tensor(render_idx) + + # - query image + sample['query_image'] = self.all_images[render_idx] + sample['query_c2w'] = self.scaled_c2ws[render_idx] + sample['query_w2c'] = self.scaled_w2cs[render_idx] + sample['query_intrinsic'] = self.scaled_intrinsics[render_idx] + sample['query_near_far'] = self.scaled_near_fars[render_idx] + sample['meta'] = str(self.scan_id) + "_" + os.path.basename(self.images_list[render_idx]) + sample['scale_mat'] = torch.from_numpy(self.scale_mat) + sample['trans_mat'] = torch.from_numpy(np.linalg.inv(self.ref_w2c)) + sample['rendering_c2ws'] = self.scaled_c2ws[self.test_img_idx] + sample['rendering_imgs_idx'] = torch.Tensor(np.array(self.test_img_idx).astype(np.int32)) + + # - generate rays + if self.split == 'val' or self.split == 'test': + sample_rays = gen_rays_from_single_image( + self.img_wh[1], self.img_wh[0], + sample['query_image'], + sample['query_intrinsic'], + sample['query_c2w'], + depth=None, + mask=None) + else: + sample_rays = gen_random_rays_from_single_image( + self.img_wh[1], self.img_wh[0], + self.N_rays, + sample['query_image'], + sample['query_intrinsic'], + sample['query_c2w'], + depth=None, + mask=None, + dilated_mask=None, + importance_sample=False, + h_patch_size=self.h_patch_size + ) + + sample['rays'] = sample_rays + + return sample diff --git a/SparseNeuS_demo_v1/data/dtu_general.py b/SparseNeuS_demo_v1/data/dtu_general.py new file mode 100644 index 0000000000000000000000000000000000000000..c6c7734df6072dd618ccdde71ca428f983a605e8 --- /dev/null +++ b/SparseNeuS_demo_v1/data/dtu_general.py @@ -0,0 +1,376 @@ +from torch.utils.data import Dataset +from utils.misc_utils import read_pfm +import os +import numpy as np +import cv2 +from PIL import Image +import torch +from torchvision import transforms as T +from data.scene import get_boundingbox + +from models.rays import gen_rays_from_single_image, gen_random_rays_from_single_image + +from termcolor import colored +import pdb +import random + + +def load_K_Rt_from_P(filename, P=None): + if P is None: + lines = open(filename).read().splitlines() + if len(lines) == 4: + lines = lines[1:] + lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)] + P = np.asarray(lines).astype(np.float32).squeeze() + + out = cv2.decomposeProjectionMatrix(P) + K = out[0] + R = out[1] + t = out[2] + + K = K / K[2, 2] + intrinsics = np.eye(4) + intrinsics[:3, :3] = K + + pose = np.eye(4, dtype=np.float32) + pose[:3, :3] = R.transpose() # ? why need transpose here + pose[:3, 3] = (t[:3] / t[3])[:, 0] + + return intrinsics, pose # ! return cam2world matrix here + + +# ! load one ref-image with multiple src-images in camera coordinate system +class MVSDatasetDtuPerView(Dataset): + def __init__(self, root_dir, split, n_views=3, img_wh=(640, 512), downSample=1.0, + split_filepath=None, pair_filepath=None, + N_rays=512, + vol_dims=[128, 128, 128], batch_size=1, + clean_image=False, importance_sample=False, test_ref_views=[]): + + self.root_dir = root_dir + self.split = split + + self.img_wh = img_wh + self.downSample = downSample + self.num_all_imgs = 49 # this preprocessed DTU dataset has 49 images + self.n_views = n_views + self.N_rays = N_rays + self.batch_size = batch_size # - used for construct new metas for gru fusion training + + self.clean_image = clean_image + self.importance_sample = importance_sample + self.test_ref_views = test_ref_views # used for testing + self.scale_factor = 1.0 + self.scale_mat = np.float32(np.diag([1, 1, 1, 1.0])) + + if img_wh is not None: + assert img_wh[0] % 32 == 0 and img_wh[1] % 32 == 0, \ + 'img_wh must both be multiples of 32!' + + self.split_filepath = f'data/dtu/lists/{self.split}.txt' if split_filepath is None else split_filepath + self.pair_filepath = f'data/dtu/dtu_pairs.txt' if pair_filepath is None else pair_filepath + + print(colored("loading all scenes together", 'red')) + with open(self.split_filepath) as f: + self.scans = [line.rstrip() for line in f.readlines()] + + self.all_intrinsics = [] # the cam info of the whole scene + self.all_extrinsics = [] + self.all_near_fars = [] + + self.metas, self.ref_src_pairs = self.build_metas() # load ref-srcs view pairs info of the scene + + self.allview_ids = [i for i in range(self.num_all_imgs)] + + self.load_cam_info() # load camera info of DTU, and estimate scale_mat + + self.build_remap() + self.define_transforms() + print(f'==> image down scale: {self.downSample}') + + # * bounding box for rendering + self.bbox_min = np.array([-1.0, -1.0, -1.0]) + self.bbox_max = np.array([1.0, 1.0, 1.0]) + + # - used for cost volume regularization + self.voxel_dims = torch.tensor(vol_dims, dtype=torch.float32) + self.partial_vol_origin = torch.Tensor([-1., -1., -1.]) + + def build_remap(self): + self.remap = np.zeros(np.max(self.allview_ids) + 1).astype('int') + for i, item in enumerate(self.allview_ids): + self.remap[item] = i + + def define_transforms(self): + self.transform = T.Compose([T.ToTensor()]) + + def build_metas(self): + metas = [] + ref_src_pairs = {} + # light conditions 0-6 for training + # light condition 3 for testing (the brightest?) + light_idxs = [3] if 'train' not in self.split else range(7) + + with open(self.pair_filepath) as f: + num_viewpoint = int(f.readline()) + # viewpoints (49) + for _ in range(num_viewpoint): + ref_view = int(f.readline().rstrip()) + src_views = [int(x) for x in f.readline().rstrip().split()[1::2]] + + ref_src_pairs[ref_view] = src_views + + for light_idx in light_idxs: + for scan in self.scans: + with open(self.pair_filepath) as f: + num_viewpoint = int(f.readline()) + # viewpoints (49) + for _ in range(num_viewpoint): + ref_view = int(f.readline().rstrip()) + src_views = [int(x) for x in f.readline().rstrip().split()[1::2]] + + # ! only for validation + if len(self.test_ref_views) > 0 and ref_view not in self.test_ref_views: + continue + + metas += [(scan, light_idx, ref_view, src_views)] + + return metas, ref_src_pairs + + def read_cam_file(self, filename): + with open(filename) as f: + lines = [line.rstrip() for line in f.readlines()] + # extrinsics: line [1,5), 4x4 matrix + extrinsics = np.fromstring(' '.join(lines[1:5]), dtype=np.float32, sep=' ') + extrinsics = extrinsics.reshape((4, 4)) + # intrinsics: line [7-10), 3x3 matrix + intrinsics = np.fromstring(' '.join(lines[7:10]), dtype=np.float32, sep=' ') + intrinsics = intrinsics.reshape((3, 3)) + # depth_min & depth_interval: line 11 + depth_min = float(lines[11].split()[0]) + depth_max = depth_min + float(lines[11].split()[1]) * 192 + self.depth_interval = float(lines[11].split()[1]) + intrinsics_ = np.float32(np.diag([1, 1, 1, 1])) + intrinsics_[:3, :3] = intrinsics + return intrinsics_, extrinsics, [depth_min, depth_max] + + def load_cam_info(self): + for vid in range(self.num_all_imgs): + proj_mat_filename = os.path.join(self.root_dir, + f'Cameras/train/{vid:08d}_cam.txt') + intrinsic, extrinsic, near_far = self.read_cam_file(proj_mat_filename) + intrinsic[:2] *= 4 # * the provided intrinsics is 4x downsampled, now keep the same scale with image + self.all_intrinsics.append(intrinsic) + self.all_extrinsics.append(extrinsic) + self.all_near_fars.append(near_far) + + def read_depth(self, filename): + # import ipdb; ipdb.set_trace() + depth_h = np.array(read_pfm(filename)[0], dtype=np.float32) # (1200, 1600) + depth_h = np.ones((1200, 1600)) + # print(depth_h.shape) + depth_h = cv2.resize(depth_h, None, fx=0.5, fy=0.5, + interpolation=cv2.INTER_NEAREST) # (600, 800) + depth_h = depth_h[44:556, 80:720] # (512, 640) + # print(depth_h.shape) + # import ipdb; ipdb.set_trace() + depth_h = cv2.resize(depth_h, None, fx=self.downSample, fy=self.downSample, + interpolation=cv2.INTER_NEAREST) + depth = cv2.resize(depth_h, None, fx=1.0 / 4, fy=1.0 / 4, + interpolation=cv2.INTER_NEAREST) + + return depth, depth_h + + def read_mask(self, filename): + mask_h = cv2.imread(filename, 0) + mask_h = cv2.resize(mask_h, None, fx=self.downSample, fy=self.downSample, + interpolation=cv2.INTER_NEAREST) + mask = cv2.resize(mask_h, None, fx=0.25, fy=0.25, + interpolation=cv2.INTER_NEAREST) + + mask[mask > 0] = 1 # the masks stored in png are not binary + mask_h[mask_h > 0] = 1 + + return mask, mask_h + + def cal_scale_mat(self, img_hw, intrinsics, extrinsics, near_fars, factor=1.): + center, radius, _ = get_boundingbox(img_hw, intrinsics, extrinsics, near_fars) + radius = radius * factor + scale_mat = np.diag([radius, radius, radius, 1.0]) + scale_mat[:3, 3] = center.cpu().numpy() + scale_mat = scale_mat.astype(np.float32) + + return scale_mat, 1. / radius.cpu().numpy() + + def __len__(self): + return len(self.metas) + + def __getitem__(self, idx): + sample = {} + scan, light_idx, ref_view, src_views = self.metas[idx % len(self.metas)] + + # generalized, load some images at once + view_ids = [ref_view] + src_views[:self.n_views] + # * transform from world system to camera system + w2c_ref = self.all_extrinsics[self.remap[ref_view]] + w2c_ref_inv = np.linalg.inv(w2c_ref) + + image_perm = 0 # only supervised on reference view + + imgs, depths_h, masks_h = [], [], [] # full size (640, 512) + intrinsics, w2cs, near_fars = [], [], [] # record proj mats between views + mask_dilated = None + for i, vid in enumerate(view_ids): + # NOTE that the id in image file names is from 1 to 49 (not 0~48) + img_filename = os.path.join(self.root_dir, + f'Rectified/{scan}_train/rect_{vid + 1:03d}_{light_idx}_r5000.png') + depth_filename = os.path.join(self.root_dir, + f'Depths/{scan}_train/depth_map_{vid:04d}.pfm') + # print(depth_filename) + mask_filename = os.path.join(self.root_dir, + f'Masks_clean_dilated/{scan}_train/mask_{vid:04d}.png') + + img = Image.open(img_filename) + img_wh = np.round(np.array(img.size) * self.downSample).astype('int') + img = img.resize(img_wh, Image.BILINEAR) + + if os.path.exists(mask_filename) and self.clean_image: + mask_l, mask_h = self.read_mask(mask_filename) + else: + # print(self.split, "don't find mask file", mask_filename) + mask_h = np.ones([img_wh[1], img_wh[0]]) + masks_h.append(mask_h) + + if i == 0: + kernel_size = 101 # default 101 + kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size)) + mask_dilated = np.float32(cv2.dilate(np.uint8(mask_h * 255), kernel, iterations=1) > 128) + + if self.clean_image: + img = np.array(img) + img[mask_h < 0.5] = 0.0 + + img = self.transform(img) + + imgs += [img] + + index_mat = self.remap[vid] + near_fars.append(self.all_near_fars[index_mat]) + intrinsics.append(self.all_intrinsics[index_mat]) + + w2cs.append(self.all_extrinsics[index_mat] @ w2c_ref_inv) + + # print(depth_filename) + if os.path.exists(depth_filename): # and i == 0 + # print("file exists") + depth_l, depth_h = self.read_depth(depth_filename) + depths_h.append(depth_h) + # ! estimate scale_mat + scale_mat, scale_factor = self.cal_scale_mat(img_hw=[img_wh[1], img_wh[0]], + intrinsics=intrinsics, extrinsics=w2cs, + near_fars=near_fars, factor=1.1) + + # ! calculate the new w2cs after scaling + new_near_fars = [] + new_w2cs = [] + new_c2ws = [] + new_affine_mats = [] + new_depths_h = [] + for intrinsic, extrinsic, near_far, depth in zip(intrinsics, w2cs, near_fars, depths_h): + P = intrinsic @ extrinsic @ scale_mat + P = P[:3, :4] + # - should use load_K_Rt_from_P() to obtain c2w + c2w = load_K_Rt_from_P(None, P)[1] + w2c = np.linalg.inv(c2w) + new_w2cs.append(w2c) + new_c2ws.append(c2w) + affine_mat = np.eye(4) + affine_mat[:3, :4] = intrinsic[:3, :3] @ w2c[:3, :4] + new_affine_mats.append(affine_mat) + + camera_o = c2w[:3, 3] + dist = np.sqrt(np.sum(camera_o ** 2)) + near = dist - 1 + far = dist + 1 + + new_near_fars.append([0.95 * near, 1.05 * far]) + new_depths_h.append(depth * scale_factor) + + imgs = torch.stack(imgs).float() + print(new_near_fars) + depths_h = np.stack(new_depths_h) + masks_h = np.stack(masks_h) + + affine_mats = np.stack(new_affine_mats) + intrinsics, w2cs, c2ws, near_fars = np.stack(intrinsics), np.stack(new_w2cs), np.stack(new_c2ws), np.stack( + new_near_fars) + + if 'train' in self.split: + start_idx = 0 + else: + start_idx = 1 + + sample['images'] = imgs # (V, 3, H, W) + sample['depths_h'] = torch.from_numpy(depths_h.astype(np.float32)) # (V, H, W) + sample['masks_h'] = torch.from_numpy(masks_h.astype(np.float32)) # (V, H, W) + sample['w2cs'] = torch.from_numpy(w2cs.astype(np.float32)) # (V, 4, 4) + sample['c2ws'] = torch.from_numpy(c2ws.astype(np.float32)) # (V, 4, 4) + sample['near_fars'] = torch.from_numpy(near_fars.astype(np.float32)) # (V, 2) + sample['intrinsics'] = torch.from_numpy(intrinsics.astype(np.float32))[:, :3, :3] # (V, 3, 3) + sample['view_ids'] = torch.from_numpy(np.array(view_ids)) + sample['affine_mats'] = torch.from_numpy(affine_mats.astype(np.float32)) # ! in world space + + sample['light_idx'] = torch.tensor(light_idx) + sample['scan'] = scan + + sample['scale_factor'] = torch.tensor(scale_factor) + sample['img_wh'] = torch.from_numpy(img_wh) + sample['render_img_idx'] = torch.tensor(image_perm) + sample['partial_vol_origin'] = torch.tensor(self.partial_vol_origin, dtype=torch.float32) + sample['meta'] = str(scan) + "_light" + str(light_idx) + "_refview" + str(ref_view) + + # - image to render + sample['query_image'] = sample['images'][0] + sample['query_c2w'] = sample['c2ws'][0] + sample['query_w2c'] = sample['w2cs'][0] + sample['query_intrinsic'] = sample['intrinsics'][0] + sample['query_depth'] = sample['depths_h'][0] + sample['query_mask'] = sample['masks_h'][0] + sample['query_near_far'] = sample['near_fars'][0] + + sample['images'] = sample['images'][start_idx:] # (V, 3, H, W) + sample['depths_h'] = sample['depths_h'][start_idx:] # (V, H, W) + sample['masks_h'] = sample['masks_h'][start_idx:] # (V, H, W) + sample['w2cs'] = sample['w2cs'][start_idx:] # (V, 4, 4) + sample['c2ws'] = sample['c2ws'][start_idx:] # (V, 4, 4) + sample['intrinsics'] = sample['intrinsics'][start_idx:] # (V, 3, 3) + sample['view_ids'] = sample['view_ids'][start_idx:] + sample['affine_mats'] = sample['affine_mats'][start_idx:] # ! in world space + + sample['scale_mat'] = torch.from_numpy(scale_mat) + sample['trans_mat'] = torch.from_numpy(w2c_ref_inv) + + # - generate rays + if ('val' in self.split) or ('test' in self.split): + sample_rays = gen_rays_from_single_image( + img_wh[1], img_wh[0], + sample['query_image'], + sample['query_intrinsic'], + sample['query_c2w'], + depth=sample['query_depth'], + mask=sample['query_mask'] if self.clean_image else None) + else: + sample_rays = gen_random_rays_from_single_image( + img_wh[1], img_wh[0], + self.N_rays, + sample['query_image'], + sample['query_intrinsic'], + sample['query_c2w'], + depth=sample['query_depth'], + mask=sample['query_mask'] if self.clean_image else None, + dilated_mask=mask_dilated, + importance_sample=self.importance_sample) + + sample['rays'] = sample_rays + + return sample diff --git a/SparseNeuS_demo_v1/data/scene.py b/SparseNeuS_demo_v1/data/scene.py new file mode 100644 index 0000000000000000000000000000000000000000..49183c65418338864ecabdd1af914bbb0f055579 --- /dev/null +++ b/SparseNeuS_demo_v1/data/scene.py @@ -0,0 +1,102 @@ +import numpy as np +import torch +import pdb + + +def rigid_transform(xyz, transform): + """Applies a rigid transform (c2w) to an (N, 3) pointcloud. + """ + device = xyz.device + xyz_h = torch.cat([xyz, torch.ones((len(xyz), 1)).to(device)], dim=1) # (N, 4) + xyz_t_h = (transform @ xyz_h.T).T # * checked: the same with the below + + return xyz_t_h[:, :3] + + +def get_view_frustum(min_depth, max_depth, size, cam_intr, c2w): + """Get corners of 3D camera view frustum of depth image + """ + device = cam_intr.device + im_h, im_w = size + im_h = int(im_h) + im_w = int(im_w) + view_frust_pts = torch.stack([ + (torch.tensor([0, 0, im_w, im_w, 0, 0, im_w, im_w]).to(device) - cam_intr[0, 2]) * torch.tensor( + [min_depth, min_depth, min_depth, min_depth, max_depth, max_depth, max_depth, max_depth]).to(device) / + cam_intr[0, 0], + (torch.tensor([0, im_h, 0, im_h, 0, im_h, 0, im_h]).to(device) - cam_intr[1, 2]) * torch.tensor( + [min_depth, min_depth, min_depth, min_depth, max_depth, max_depth, max_depth, max_depth]).to(device) / + cam_intr[1, 1], + torch.tensor([min_depth, min_depth, min_depth, min_depth, max_depth, max_depth, max_depth, max_depth]).to( + device) + ]) + view_frust_pts = view_frust_pts.type(torch.float32) + c2w = c2w.type(torch.float32) + view_frust_pts = rigid_transform(view_frust_pts.T, c2w).T + return view_frust_pts + + +def set_pixel_coords(h, w): + i_range = torch.arange(0, h).view(1, h, 1).expand(1, h, w).type(torch.float32) # [1, H, W] + j_range = torch.arange(0, w).view(1, 1, w).expand(1, h, w).type(torch.float32) # [1, H, W] + ones = torch.ones(1, h, w).type(torch.float32) + + pixel_coords = torch.stack((j_range, i_range, ones), dim=1) # [1, 3, H, W] + + return pixel_coords + + +def get_boundingbox(img_hw, intrinsics, extrinsics, near_fars): + """ + # get the minimum bounding box of all visual hulls + :param img_hw: + :param intrinsics: + :param extrinsics: + :param near_fars: + :return: + """ + + bnds = torch.zeros((3, 2)) + bnds[:, 0] = np.inf + bnds[:, 1] = -np.inf + + if isinstance(intrinsics, list): + num = len(intrinsics) + else: + num = intrinsics.shape[0] + # print("num: ", num) + view_frust_pts_list = [] + for i in range(num): + if not isinstance(intrinsics[i], torch.Tensor): + cam_intr = torch.tensor(intrinsics[i]) + w2c = torch.tensor(extrinsics[i]) + c2w = torch.inverse(w2c) + else: + cam_intr = intrinsics[i] + w2c = extrinsics[i] + c2w = torch.inverse(w2c) + min_depth, max_depth = near_fars[i][0], near_fars[i][1] + # todo: check the coresponding points are matched + + view_frust_pts = get_view_frustum(min_depth, max_depth, img_hw, cam_intr, c2w) + bnds[:, 0] = torch.min(bnds[:, 0], torch.min(view_frust_pts, dim=1)[0]) + bnds[:, 1] = torch.max(bnds[:, 1], torch.max(view_frust_pts, dim=1)[0]) + view_frust_pts_list.append(view_frust_pts) + all_view_frust_pts = torch.cat(view_frust_pts_list, dim=1) + + # print("all_view_frust_pts: ", all_view_frust_pts.shape) + # distance = torch.norm(all_view_frust_pts, dim=0) + # print("distance: ", distance) + + # print("all_view_frust_pts_z: ", all_view_frust_pts[2, :]) + + center = torch.tensor(((bnds[0, 1] + bnds[0, 0]) / 2, (bnds[1, 1] + bnds[1, 0]) / 2, + (bnds[2, 1] + bnds[2, 0]) / 2)) + + lengths = bnds[:, 1] - bnds[:, 0] + + max_length, _ = torch.max(lengths, dim=0) + radius = max_length / 2 + + # print("radius: ", radius) + return center, radius, bnds diff --git a/SparseNeuS_demo_v1/evaluation/__init__.py b/SparseNeuS_demo_v1/evaluation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/SparseNeuS_demo_v1/evaluation/clean_mesh.py b/SparseNeuS_demo_v1/evaluation/clean_mesh.py new file mode 100644 index 0000000000000000000000000000000000000000..ab65cc72d3be615b71ec852a7adea933355aa250 --- /dev/null +++ b/SparseNeuS_demo_v1/evaluation/clean_mesh.py @@ -0,0 +1,283 @@ +import numpy as np +import cv2 as cv +import os +from glob import glob +from scipy.io import loadmat +import trimesh +import open3d as o3d +import torch +from tqdm import tqdm + +import sys + +sys.path.append("../") + + +def gen_rays_from_single_image(H, W, image, intrinsic, c2w, depth=None, mask=None): + """ + generate rays in world space, for image image + :param H: + :param W: + :param intrinsics: [3,3] + :param c2ws: [4,4] + :return: + """ + device = image.device + ys, xs = torch.meshgrid(torch.linspace(0, H - 1, H), + torch.linspace(0, W - 1, W)) # pytorch's meshgrid has indexing='ij' + p = torch.stack([xs, ys, torch.ones_like(ys)], dim=-1) # H, W, 3 + + # normalized ndc uv coordinates, (-1, 1) + ndc_u = 2 * xs / (W - 1) - 1 + ndc_v = 2 * ys / (H - 1) - 1 + rays_ndc_uv = torch.stack([ndc_u, ndc_v], dim=-1).view(-1, 2).float().to(device) + + intrinsic_inv = torch.inverse(intrinsic) + + p = p.view(-1, 3).float().to(device) # N_rays, 3 + p = torch.matmul(intrinsic_inv[None, :3, :3], p[:, :, None]).squeeze() # N_rays, 3 + rays_v = p / torch.linalg.norm(p, ord=2, dim=-1, keepdim=True) # N_rays, 3 + rays_v = torch.matmul(c2w[None, :3, :3], rays_v[:, :, None]).squeeze() # N_rays, 3 + rays_o = c2w[None, :3, 3].expand(rays_v.shape) # N_rays, 3 + + image = image.permute(1, 2, 0) + color = image.view(-1, 3) + depth = depth.view(-1, 1) if depth is not None else None + mask = mask.view(-1, 1) if mask is not None else torch.ones([H * W, 1]).to(device) + sample = { + 'rays_o': rays_o, + 'rays_v': rays_v, + 'rays_ndc_uv': rays_ndc_uv, + 'rays_color': color, + # 'rays_depth': depth, + 'rays_mask': mask, + 'rays_norm_XYZ_cam': p # - XYZ_cam, before multiply depth + } + if depth is not None: + sample['rays_depth'] = depth + + return sample + + +def load_K_Rt_from_P(filename, P=None): + if P is None: + lines = open(filename).read().splitlines() + if len(lines) == 4: + lines = lines[1:] + lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)] + P = np.asarray(lines).astype(np.float32).squeeze() + + out = cv.decomposeProjectionMatrix(P) + K = out[0] + R = out[1] + t = out[2] + + K = K / K[2, 2] + intrinsics = np.eye(4) + intrinsics[:3, :3] = K + + pose = np.eye(4, dtype=np.float32) + pose[:3, :3] = R.transpose() # ? why need transpose here + pose[:3, 3] = (t[:3] / t[3])[:, 0] + + return intrinsics, pose # ! return cam2world matrix here + + +def clean_points_by_mask(points, scan, imgs_idx=None, minimal_vis=0, mask_dilated_size=11): + cameras = np.load('{}/scan{}/cameras.npz'.format(DTU_DIR, scan)) + mask_lis = sorted(glob('{}/scan{}/mask/*.png'.format(DTU_DIR, scan))) + n_images = 49 if scan < 83 else 64 + inside_mask = np.zeros(len(points)) + + if imgs_idx is None: + imgs_idx = [i for i in range(n_images)] + + # imgs_idx = [i for i in range(n_images)] + for i in imgs_idx: + P = cameras['world_mat_{}'.format(i)] + pts_image = np.matmul(P[None, :3, :3], points[:, :, None]).squeeze() + P[None, :3, 3] + pts_image = pts_image / pts_image[:, 2:] + pts_image = np.round(pts_image).astype(np.int32) + 1 + + mask_image = cv.imread(mask_lis[i]) + kernel_size = mask_dilated_size # default 101 + kernel = cv.getStructuringElement(cv.MORPH_ELLIPSE, (kernel_size, kernel_size)) + mask_image = cv.dilate(mask_image, kernel, iterations=1) + mask_image = (mask_image[:, :, 0] > 128) + + mask_image = np.concatenate([np.ones([1, 1600]), mask_image, np.ones([1, 1600])], axis=0) + mask_image = np.concatenate([np.ones([1202, 1]), mask_image, np.ones([1202, 1])], axis=1) + + in_mask = (pts_image[:, 0] >= 0) * (pts_image[:, 0] <= 1600) * (pts_image[:, 1] >= 0) * ( + pts_image[:, 1] <= 1200) > 0 + curr_mask = mask_image[(pts_image[:, 1].clip(0, 1201), pts_image[:, 0].clip(0, 1601))] + + curr_mask = curr_mask.astype(np.float32) * in_mask + + inside_mask += curr_mask + + return inside_mask > minimal_vis + + +def clean_mesh_faces_by_mask(mesh_file, new_mesh_file, scan, imgs_idx, minimal_vis=0, mask_dilated_size=11): + old_mesh = trimesh.load(mesh_file) + old_vertices = old_mesh.vertices[:] + old_faces = old_mesh.faces[:] + mask = clean_points_by_mask(old_vertices, scan, imgs_idx, minimal_vis, mask_dilated_size) + indexes = np.ones(len(old_vertices)) * -1 + indexes = indexes.astype(np.long) + indexes[np.where(mask)] = np.arange(len(np.where(mask)[0])) + + faces_mask = mask[old_faces[:, 0]] & mask[old_faces[:, 1]] & mask[old_faces[:, 2]] + new_faces = old_faces[np.where(faces_mask)] + new_faces[:, 0] = indexes[new_faces[:, 0]] + new_faces[:, 1] = indexes[new_faces[:, 1]] + new_faces[:, 2] = indexes[new_faces[:, 2]] + new_vertices = old_vertices[np.where(mask)] + + new_mesh = trimesh.Trimesh(new_vertices, new_faces) + + new_mesh.export(new_mesh_file) + + +def clean_mesh_by_faces_num(mesh, faces_num=500): + old_vertices = mesh.vertices[:] + old_faces = mesh.faces[:] + + cc = trimesh.graph.connected_components(mesh.face_adjacency, min_len=faces_num) + mask = np.zeros(len(mesh.faces), dtype=np.bool) + mask[np.concatenate(cc)] = True + + indexes = np.ones(len(old_vertices)) * -1 + indexes = indexes.astype(np.long) + indexes[np.where(mask)] = np.arange(len(np.where(mask)[0])) + + faces_mask = mask[old_faces[:, 0]] & mask[old_faces[:, 1]] & mask[old_faces[:, 2]] + new_faces = old_faces[np.where(faces_mask)] + new_faces[:, 0] = indexes[new_faces[:, 0]] + new_faces[:, 1] = indexes[new_faces[:, 1]] + new_faces[:, 2] = indexes[new_faces[:, 2]] + new_vertices = old_vertices[np.where(mask)] + + new_mesh = trimesh.Trimesh(new_vertices, new_faces) + + return new_mesh + + +def clean_mesh_faces_outside_frustum(old_mesh_file, new_mesh_file, imgs_idx, H=1200, W=1600, mask_dilated_size=11, + isolated_face_num=500, keep_largest=True): + '''Remove faces of mesh which cannot be orserved by all cameras + ''' + # if path_mask_npz: + # path_save_clean = IOUtils.add_file_name_suffix(path_save_clean, '_mask') + + cameras = np.load('{}/scan{}/cameras.npz'.format(DTU_DIR, scan)) + mask_lis = sorted(glob('{}/scan{}/mask/*.png'.format(DTU_DIR, scan))) + + mesh = trimesh.load(old_mesh_file) + intersector = trimesh.ray.ray_pyembree.RayMeshIntersector(mesh) + + all_indices = [] + chunk_size = 5120 + for i in imgs_idx: + mask_image = cv.imread(mask_lis[i]) + kernel_size = mask_dilated_size # default 101 + kernel = cv.getStructuringElement(cv.MORPH_ELLIPSE, (kernel_size, kernel_size)) + mask_image = cv.dilate(mask_image, kernel, iterations=1) + + P = cameras['world_mat_{}'.format(i)] + + intrinsic, pose = load_K_Rt_from_P(None, P[:3, :]) + + rays = gen_rays_from_single_image(H, W, torch.from_numpy(mask_image).permute(2, 0, 1).float(), + torch.from_numpy(intrinsic)[:3, :3].float(), + torch.from_numpy(pose).float()) + rays_o = rays['rays_o'] + rays_d = rays['rays_v'] + rays_mask = rays['rays_color'] + + rays_o = rays_o.split(chunk_size) + rays_d = rays_d.split(chunk_size) + rays_mask = rays_mask.split(chunk_size) + + for rays_o_batch, rays_d_batch, rays_mask_batch in tqdm(zip(rays_o, rays_d, rays_mask)): + rays_mask_batch = rays_mask_batch[:, 0] > 128 + rays_o_batch = rays_o_batch[rays_mask_batch] + rays_d_batch = rays_d_batch[rays_mask_batch] + + idx_faces_hits = intersector.intersects_first(rays_o_batch.cpu().numpy(), rays_d_batch.cpu().numpy()) + all_indices.append(idx_faces_hits) + + values = np.unique(np.concatenate(all_indices, axis=0)) + mask_faces = np.ones(len(mesh.faces)) + mask_faces[values[1:]] = 0 + print(f'Surfaces/Kept: {len(mesh.faces)}/{len(values)}') + + mesh_o3d = o3d.io.read_triangle_mesh(old_mesh_file) + print("removing triangles by mask") + mesh_o3d.remove_triangles_by_mask(mask_faces) + + o3d.io.write_triangle_mesh(new_mesh_file, mesh_o3d) + + # # clean meshes + new_mesh = trimesh.load(new_mesh_file) + cc = trimesh.graph.connected_components(new_mesh.face_adjacency, min_len=500) + mask = np.zeros(len(new_mesh.faces), dtype=np.bool) + mask[np.concatenate(cc)] = True + new_mesh.update_faces(mask) + new_mesh.remove_unreferenced_vertices() + new_mesh.export(new_mesh_file) + + # meshes = new_mesh.split(only_watertight=False) + # + # if not keep_largest: + # meshes = [mesh for mesh in meshes if len(mesh.faces) > isolated_face_num] + # # new_mesh = meshes[np.argmax([len(mesh.faces) for mesh in meshes])] + # merged_mesh = trimesh.util.concatenate(meshes) + # merged_mesh.export(new_mesh_file) + # else: + # new_mesh = meshes[np.argmax([len(mesh.faces) for mesh in meshes])] + # new_mesh.export(new_mesh_file) + + o3d.io.write_triangle_mesh(new_mesh_file.replace(".ply", "_raw.ply"), mesh_o3d) + print("finishing removing triangles") + + +def clean_outliers(old_mesh_file, new_mesh_file): + new_mesh = trimesh.load(old_mesh_file) + + meshes = new_mesh.split(only_watertight=False) + new_mesh = meshes[np.argmax([len(mesh.faces) for mesh in meshes])] + + new_mesh.export(new_mesh_file) + + +if __name__ == "__main__": + + scans = [24, 37, 40, 55, 63, 65, 69, 83, 97, 105, 106, 110, 114, 118, 122] + + mask_kernel_size = 11 + + imgs_idx = [0, 1, 2] + # imgs_idx = [42, 43, 44] + # imgs_idx = [1, 8, 9] + + DTU_DIR = "/home/xiaoxiao/dataset/DTU_IDR/DTU" + # DTU_DIR = "/userhome/cs/xxlong/dataset/DTU_IDR/DTU" + + base_path = "/home/xiaoxiao/Workplace/nerf_reconstruction/Volume_NeuS/neus_camsys/exp/dtu/evaluation_23_24_33_new/volsdf" + + for scan in scans: + print("processing scan%d" % scan) + dir_path = os.path.join(base_path, "scan%d" % scan) + + old_mesh_file = glob(os.path.join(dir_path, "*.ply"))[0] + + clean_mesh_file = os.path.join(dir_path, "clean_%03d.ply" % scan) + final_mesh_file = os.path.join(dir_path, "final_%03d.ply" % scan) + + clean_mesh_faces_by_mask(old_mesh_file, clean_mesh_file, scan, imgs_idx, minimal_vis=1, + mask_dilated_size=mask_kernel_size) + clean_mesh_faces_outside_frustum(clean_mesh_file, final_mesh_file, imgs_idx, mask_dilated_size=mask_kernel_size) + + print("finish processing scan%d" % scan) diff --git a/SparseNeuS_demo_v1/evaluation/eval_dtu_python.py b/SparseNeuS_demo_v1/evaluation/eval_dtu_python.py new file mode 100644 index 0000000000000000000000000000000000000000..a60230705ab3f8c7c2a0ed64a20634c7ab4d2eea --- /dev/null +++ b/SparseNeuS_demo_v1/evaluation/eval_dtu_python.py @@ -0,0 +1,369 @@ +import numpy as np +import open3d as o3d +import sklearn.neighbors as skln +from tqdm import tqdm +from scipy.io import loadmat +import multiprocessing as mp +import argparse, os, sys +import cv2 as cv + +from pathlib import Path + + +def get_path_components(path): + path = Path(path) + ppath = str(path.parent) + stem = str(path.stem) + ext = str(path.suffix) + return ppath, stem, ext + + +def sample_single_tri(input_): + n1, n2, v1, v2, tri_vert = input_ + c = np.mgrid[:n1 + 1, :n2 + 1] + c += 0.5 + c[0] /= max(n1, 1e-7) + c[1] /= max(n2, 1e-7) + c = np.transpose(c, (1, 2, 0)) + k = c[c.sum(axis=-1) < 1] # m2 + q = v1 * k[:, :1] + v2 * k[:, 1:] + tri_vert + return q + + +def write_vis_pcd(file, points, colors): + pcd = o3d.geometry.PointCloud() + pcd.points = o3d.utility.Vector3dVector(points) + pcd.colors = o3d.utility.Vector3dVector(colors) + o3d.io.write_point_cloud(file, pcd) + + +def eval_cloud(args, num_cpu_cores=-1): + mp.freeze_support() + os.makedirs(args.vis_out_dir, exist_ok=True) + + thresh = args.downsample_density + if args.mode == 'mesh': + pbar = tqdm(total=9) + pbar.set_description('read data mesh') + data_mesh = o3d.io.read_triangle_mesh(args.data) + + vertices = np.asarray(data_mesh.vertices) + triangles = np.asarray(data_mesh.triangles) + tri_vert = vertices[triangles] + + pbar.update(1) + pbar.set_description('sample pcd from mesh') + v1 = tri_vert[:, 1] - tri_vert[:, 0] + v2 = tri_vert[:, 2] - tri_vert[:, 0] + l1 = np.linalg.norm(v1, axis=-1, keepdims=True) + l2 = np.linalg.norm(v2, axis=-1, keepdims=True) + area2 = np.linalg.norm(np.cross(v1, v2), axis=-1, keepdims=True) + non_zero_area = (area2 > 0)[:, 0] + l1, l2, area2, v1, v2, tri_vert = [ + arr[non_zero_area] for arr in [l1, l2, area2, v1, v2, tri_vert] + ] + thr = thresh * np.sqrt(l1 * l2 / area2) + n1 = np.floor(l1 / thr) + n2 = np.floor(l2 / thr) + + with mp.Pool() as mp_pool: + new_pts = mp_pool.map(sample_single_tri, + ((n1[i, 0], n2[i, 0], v1[i:i + 1], v2[i:i + 1], tri_vert[i:i + 1, 0]) for i in + range(len(n1))), chunksize=1024) + + new_pts = np.concatenate(new_pts, axis=0) + data_pcd = np.concatenate([vertices, new_pts], axis=0) + + elif args.mode == 'pcd': + pbar = tqdm(total=8) + pbar.set_description('read data pcd') + data_pcd_o3d = o3d.io.read_point_cloud(args.data) + data_pcd = np.asarray(data_pcd_o3d.points) + + pbar.update(1) + pbar.set_description('random shuffle pcd index') + shuffle_rng = np.random.default_rng() + shuffle_rng.shuffle(data_pcd, axis=0) + + pbar.update(1) + pbar.set_description('downsample pcd') + nn_engine = skln.NearestNeighbors(n_neighbors=1, radius=thresh, algorithm='kd_tree', n_jobs=num_cpu_cores) + nn_engine.fit(data_pcd) + rnn_idxs = nn_engine.radius_neighbors(data_pcd, radius=thresh, return_distance=False) + mask = np.ones(data_pcd.shape[0], dtype=np.bool_) + for curr, idxs in enumerate(rnn_idxs): + if mask[curr]: + mask[idxs] = 0 + mask[curr] = 1 + data_down = data_pcd[mask] + + pbar.update(1) + pbar.set_description('masking data pcd') + obs_mask_file = loadmat(f'{args.dataset_dir}/ObsMask/ObsMask{args.scan}_10.mat') + ObsMask, BB, Res = [obs_mask_file[attr] for attr in ['ObsMask', 'BB', 'Res']] + BB = BB.astype(np.float32) + + patch = args.patch_size + inbound = ((data_down >= BB[:1] - patch) & (data_down < BB[1:] + patch * 2)).sum(axis=-1) == 3 + data_in = data_down[inbound] + + data_grid = np.around((data_in - BB[:1]) / Res).astype(np.int32) + grid_inbound = ((data_grid >= 0) & (data_grid < np.expand_dims(ObsMask.shape, 0))).sum(axis=-1) == 3 + data_grid_in = data_grid[grid_inbound] + in_obs = ObsMask[data_grid_in[:, 0], data_grid_in[:, 1], data_grid_in[:, 2]].astype(np.bool_) + data_in_obs = data_in[grid_inbound][in_obs] + + pbar.update(1) + pbar.set_description('read STL pcd') + stl_pcd = o3d.io.read_point_cloud(args.gt) + stl = np.asarray(stl_pcd.points) + + pbar.update(1) + pbar.set_description('compute data2stl') + nn_engine.fit(stl) + dist_d2s, idx_d2s = nn_engine.kneighbors(data_in_obs, n_neighbors=1, return_distance=True) + max_dist = args.max_dist + mean_d2s = dist_d2s[dist_d2s < max_dist].mean() + + pbar.update(1) + pbar.set_description('compute stl2data') + ground_plane = loadmat(f'{args.dataset_dir}/ObsMask/Plane{args.scan}.mat')['P'] + + stl_hom = np.concatenate([stl, np.ones_like(stl[:, :1])], -1) + above = (ground_plane.reshape((1, 4)) * stl_hom).sum(-1) > 0 + stl_above = stl[above] + + nn_engine.fit(data_in) + dist_s2d, idx_s2d = nn_engine.kneighbors(stl_above, n_neighbors=1, return_distance=True) + mean_s2d = dist_s2d[dist_s2d < max_dist].mean() + + pbar.update(1) + pbar.set_description('visualize error') + vis_dist = args.visualize_threshold + R = np.array([[1, 0, 0]], dtype=np.float64) + G = np.array([[0, 1, 0]], dtype=np.float64) + B = np.array([[0, 0, 1]], dtype=np.float64) + W = np.array([[1, 1, 1]], dtype=np.float64) + data_color = np.tile(B, (data_down.shape[0], 1)) + data_alpha = dist_d2s.clip(max=vis_dist) / vis_dist + data_color[np.where(inbound)[0][grid_inbound][in_obs]] = R * data_alpha + W * (1 - data_alpha) + data_color[np.where(inbound)[0][grid_inbound][in_obs][dist_d2s[:, 0] >= max_dist]] = G + write_vis_pcd(f'{args.vis_out_dir}/vis_{args.scan:03}_d2gt.ply', data_down, data_color) + stl_color = np.tile(B, (stl.shape[0], 1)) + stl_alpha = dist_s2d.clip(max=vis_dist) / vis_dist + stl_color[np.where(above)[0]] = R * stl_alpha + W * (1 - stl_alpha) + stl_color[np.where(above)[0][dist_s2d[:, 0] >= max_dist]] = G + write_vis_pcd(f'{args.vis_out_dir}/vis_{args.scan:03}_gt2d.ply', stl, stl_color) + + pbar.update(1) + pbar.set_description('done') + pbar.close() + over_all = (mean_d2s + mean_s2d) / 2 + print(f'ean_d2gt: {mean_d2s}; mean_gt2d: {mean_s2d} over_all: {over_all}; .') + + pparent, stem, ext = get_path_components(args.data) + if args.log is None: + path_log = os.path.join(pparent, 'eval_result.txt') + else: + path_log = args.log + with open(path_log, 'a+') as fLog: + fLog.write(f'mean_d2gt {np.round(mean_d2s, 3)} ' + f'mean_gt2d {np.round(mean_s2d, 3)} ' + f'Over_all {np.round(over_all, 3)} ' + f'[{stem}] \n') + + return over_all, mean_d2s, mean_s2d + + +if __name__ == '__main__': + from glob import glob + + mp.freeze_support() + + parser = argparse.ArgumentParser() + parser.add_argument('--data', type=str, default='data_in.ply') + parser.add_argument('--gt', type=str, help='ground truth') + parser.add_argument('--scan', type=int, default=1) + parser.add_argument('--mode', type=str, default='mesh', choices=['mesh', 'pcd']) + parser.add_argument('--dataset_dir', type=str, default='/dataset/dtu_official/SampleSet/MVS_Data') + parser.add_argument('--vis_out_dir', type=str, default='.') + parser.add_argument('--downsample_density', type=float, default=0.2) + parser.add_argument('--patch_size', type=float, default=60) + parser.add_argument('--max_dist', type=float, default=20) + parser.add_argument('--visualize_threshold', type=float, default=10) + parser.add_argument('--log', type=str, default=None) + args = parser.parse_args() + + base_dir = "./exp" + + GT_DIR = "./gt_pcd" + + scans = [24, 37, 40, 55, 63, 65, 69, 83, 97, 105, 106, 110, 114, 118, 122] + + for scan in scans: + + print("processing scan%d" % scan) + + args.data = os.path.join(base_dir, "scan{}".format(scan), "final_%03d.ply" % scan) + + if not os.path.exists(args.data): + continue + + args.gt = os.path.join(GT_DIR, "stl%03d_total.ply" % scan) + args.vis_out_dir = os.path.join(base_dir, "scan{}".format(scan)) + args.scan = scan + os.makedirs(args.vis_out_dir, exist_ok=True) + + dist_thred1 = 1 + dist_thred2 = 2 + + thresh = args.downsample_density + + if args.mode == 'mesh': + pbar = tqdm(total=9) + pbar.set_description('read data mesh') + data_mesh = o3d.io.read_triangle_mesh(args.data) + + vertices = np.asarray(data_mesh.vertices) + triangles = np.asarray(data_mesh.triangles) + tri_vert = vertices[triangles] + + pbar.update(1) + pbar.set_description('sample pcd from mesh') + v1 = tri_vert[:, 1] - tri_vert[:, 0] + v2 = tri_vert[:, 2] - tri_vert[:, 0] + l1 = np.linalg.norm(v1, axis=-1, keepdims=True) + l2 = np.linalg.norm(v2, axis=-1, keepdims=True) + area2 = np.linalg.norm(np.cross(v1, v2), axis=-1, keepdims=True) + non_zero_area = (area2 > 0)[:, 0] + l1, l2, area2, v1, v2, tri_vert = [ + arr[non_zero_area] for arr in [l1, l2, area2, v1, v2, tri_vert] + ] + thr = thresh * np.sqrt(l1 * l2 / area2) + n1 = np.floor(l1 / thr) + n2 = np.floor(l2 / thr) + + with mp.Pool() as mp_pool: + new_pts = mp_pool.map(sample_single_tri, + ((n1[i, 0], n2[i, 0], v1[i:i + 1], v2[i:i + 1], tri_vert[i:i + 1, 0]) for i in + range(len(n1))), chunksize=1024) + + new_pts = np.concatenate(new_pts, axis=0) + data_pcd = np.concatenate([vertices, new_pts], axis=0) + + elif args.mode == 'pcd': + pbar = tqdm(total=8) + pbar.set_description('read data pcd') + data_pcd_o3d = o3d.io.read_point_cloud(args.data) + data_pcd = np.asarray(data_pcd_o3d.points) + + pbar.update(1) + pbar.set_description('random shuffle pcd index') + shuffle_rng = np.random.default_rng() + shuffle_rng.shuffle(data_pcd, axis=0) + + pbar.update(1) + pbar.set_description('downsample pcd') + nn_engine = skln.NearestNeighbors(n_neighbors=1, radius=thresh, algorithm='kd_tree', n_jobs=-1) + nn_engine.fit(data_pcd) + rnn_idxs = nn_engine.radius_neighbors(data_pcd, radius=thresh, return_distance=False) + mask = np.ones(data_pcd.shape[0], dtype=np.bool_) + for curr, idxs in enumerate(rnn_idxs): + if mask[curr]: + mask[idxs] = 0 + mask[curr] = 1 + data_down = data_pcd[mask] + + pbar.update(1) + pbar.set_description('masking data pcd') + obs_mask_file = loadmat(f'{args.dataset_dir}/ObsMask/ObsMask{args.scan}_10.mat') + ObsMask, BB, Res = [obs_mask_file[attr] for attr in ['ObsMask', 'BB', 'Res']] + BB = BB.astype(np.float32) + + patch = args.patch_size + inbound = ((data_down >= BB[:1] - patch) & (data_down < BB[1:] + patch * 2)).sum(axis=-1) == 3 + data_in = data_down[inbound] + + data_grid = np.around((data_in - BB[:1]) / Res).astype(np.int32) + grid_inbound = ((data_grid >= 0) & (data_grid < np.expand_dims(ObsMask.shape, 0))).sum(axis=-1) == 3 + data_grid_in = data_grid[grid_inbound] + in_obs = ObsMask[data_grid_in[:, 0], data_grid_in[:, 1], data_grid_in[:, 2]].astype(np.bool_) + data_in_obs = data_in[grid_inbound][in_obs] + + pbar.update(1) + pbar.set_description('read STL pcd') + stl_pcd = o3d.io.read_point_cloud(args.gt) + stl = np.asarray(stl_pcd.points) + + pbar.update(1) + pbar.set_description('compute data2stl') + nn_engine.fit(stl) + dist_d2s, idx_d2s = nn_engine.kneighbors(data_in_obs, n_neighbors=1, return_distance=True) + max_dist = args.max_dist + mean_d2s = dist_d2s[dist_d2s < max_dist].mean() + + precision_1 = len(dist_d2s[dist_d2s < dist_thred1]) / len(dist_d2s) + precision_2 = len(dist_d2s[dist_d2s < dist_thred2]) / len(dist_d2s) + + pbar.update(1) + pbar.set_description('compute stl2data') + ground_plane = loadmat(f'{args.dataset_dir}/ObsMask/Plane{args.scan}.mat')['P'] + + stl_hom = np.concatenate([stl, np.ones_like(stl[:, :1])], -1) + above = (ground_plane.reshape((1, 4)) * stl_hom).sum(-1) > 0 + + stl_above = stl[above] + + nn_engine.fit(data_in) + dist_s2d, idx_s2d = nn_engine.kneighbors(stl_above, n_neighbors=1, return_distance=True) + mean_s2d = dist_s2d[dist_s2d < max_dist].mean() + + recall_1 = len(dist_s2d[dist_s2d < dist_thred1]) / len(dist_s2d) + recall_2 = len(dist_s2d[dist_s2d < dist_thred2]) / len(dist_s2d) + + pbar.update(1) + pbar.set_description('visualize error') + vis_dist = args.visualize_threshold + R = np.array([[1, 0, 0]], dtype=np.float64) + G = np.array([[0, 1, 0]], dtype=np.float64) + B = np.array([[0, 0, 1]], dtype=np.float64) + W = np.array([[1, 1, 1]], dtype=np.float64) + data_color = np.tile(B, (data_down.shape[0], 1)) + data_alpha = dist_d2s.clip(max=vis_dist) / vis_dist + data_color[np.where(inbound)[0][grid_inbound][in_obs]] = R * data_alpha + W * (1 - data_alpha) + data_color[np.where(inbound)[0][grid_inbound][in_obs][dist_d2s[:, 0] >= max_dist]] = G + write_vis_pcd(f'{args.vis_out_dir}/vis_{args.scan:03}_d2gt.ply', data_down, data_color) + stl_color = np.tile(B, (stl.shape[0], 1)) + stl_alpha = dist_s2d.clip(max=vis_dist) / vis_dist + stl_color[np.where(above)[0]] = R * stl_alpha + W * (1 - stl_alpha) + stl_color[np.where(above)[0][dist_s2d[:, 0] >= max_dist]] = G + write_vis_pcd(f'{args.vis_out_dir}/vis_{args.scan:03}_gt2d.ply', stl, stl_color) + + pbar.update(1) + pbar.set_description('done') + pbar.close() + over_all = (mean_d2s + mean_s2d) / 2 + + fscore_1 = 2 * precision_1 * recall_1 / (precision_1 + recall_1 + 1e-6) + fscore_2 = 2 * precision_2 * recall_2 / (precision_2 + recall_2 + 1e-6) + + print(f'over_all: {over_all}; mean_d2gt: {mean_d2s}; mean_gt2d: {mean_s2d}.') + print(f'precision_1mm: {precision_1}; recall_1mm: {recall_1}; fscore_1mm: {fscore_1}') + print(f'precision_2mm: {precision_2}; recall_2mm: {recall_2}; fscore_2mm: {fscore_2}') + + pparent, stem, ext = get_path_components(args.data) + if args.log is None: + path_log = os.path.join(pparent, 'eval_result.txt') + else: + path_log = args.log + with open(path_log, 'w+') as fLog: + fLog.write(f'over_all {np.round(over_all, 3)} ' + f'mean_d2gt {np.round(mean_d2s, 3)} ' + f'mean_gt2d {np.round(mean_s2d, 3)} \n' + f'precision_1mm {np.round(precision_1, 3)} ' + f'recall_1mm {np.round(recall_1, 3)} ' + f'fscore_1mm {np.round(fscore_1, 3)} \n' + f'precision_2mm {np.round(precision_2, 3)} ' + f'recall_2mm {np.round(recall_2, 3)} ' + f'fscore_2mm {np.round(fscore_2, 3)} \n' + f'[{stem}] \n') diff --git a/SparseNeuS_demo_v1/exp/lod0/checkpoint_trash/ckpt_285000.pth b/SparseNeuS_demo_v1/exp/lod0/checkpoint_trash/ckpt_285000.pth new file mode 100644 index 0000000000000000000000000000000000000000..043937847350af33b459128ade1a470064ce261c --- /dev/null +++ b/SparseNeuS_demo_v1/exp/lod0/checkpoint_trash/ckpt_285000.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:763c2a4934928cc089342905ba61481d6f9efc977b9729d7fc2d3eae4f0e1f9b +size 5310703 diff --git a/SparseNeuS_demo_v1/exp/lod0/checkpoints/ckpt_340000.pth b/SparseNeuS_demo_v1/exp/lod0/checkpoints/ckpt_340000.pth new file mode 100644 index 0000000000000000000000000000000000000000..b5ba43d31ad82a3ccd5e5be45087e602fb98260e --- /dev/null +++ b/SparseNeuS_demo_v1/exp/lod0/checkpoints/ckpt_340000.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3a947469b4b1a7044b2dcdd576e52279ed48a05d52135231137ece9f0ef810c8 +size 5310703 diff --git a/SparseNeuS_demo_v1/exp/lod0/checkpoints_white/ckpt_245000.pth b/SparseNeuS_demo_v1/exp/lod0/checkpoints_white/ckpt_245000.pth new file mode 100644 index 0000000000000000000000000000000000000000..90e582ba7a02d6b46dc2366a8b9ef61e195dd9ef --- /dev/null +++ b/SparseNeuS_demo_v1/exp/lod0/checkpoints_white/ckpt_245000.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f40cd7db7f7a5ff16bb2bfbfcbd7c6a8c7e6d3032698cc5779eeaf507225cd97 +size 5310703 diff --git a/SparseNeuS_demo_v1/exp_runner_generic_blender_val.py b/SparseNeuS_demo_v1/exp_runner_generic_blender_val.py new file mode 100644 index 0000000000000000000000000000000000000000..8a94a670988c6c354fdface782fe88870cd891c5 --- /dev/null +++ b/SparseNeuS_demo_v1/exp_runner_generic_blender_val.py @@ -0,0 +1,656 @@ +import torch +import torch.nn.functional as F +from torch.utils.data import DataLoader +import argparse +import os +import logging +import numpy as np +import cv2 as cv +import trimesh +from shutil import copyfile +from torch.utils.tensorboard import SummaryWriter +from icecream import ic +from tqdm import tqdm +from pyhocon import ConfigFactory + +from models.fields import SingleVarianceNetwork + +from models.featurenet import FeatureNet + +from models.trainer_generic import GenericTrainer + +from models.sparse_sdf_network import SparseSdfNetwork + +from models.rendering_network import GeneralRenderingNetwork + +from datetime import datetime + +from data.dtu_general import MVSDatasetDtuPerView + +from utils.training_utils import tocuda +from data.blender_general_narrow_all_eval_new_data import BlenderPerView + +from termcolor import colored + +from datetime import datetime + +class Runner: + def __init__(self, conf_path, mode='train', is_continue=False, + is_restore=False, restore_lod0=False, local_rank=0): + + # Initial setting + self.device = torch.device('cuda:%d' % local_rank) + # self.device = torch.device('cuda') + self.num_devices = torch.cuda.device_count() + self.is_continue = is_continue + self.is_restore = is_restore + self.restore_lod0 = restore_lod0 + self.mode = mode + self.model_list = [] + self.logger = logging.getLogger('exp_logger') + + print(colored("detected %d GPUs" % self.num_devices, "red")) + + self.conf_path = conf_path + self.conf = ConfigFactory.parse_file(conf_path) + self.timestamp = None + if not self.is_continue: + self.timestamp = '_{:%Y_%m_%d_%H_%M_%S}'.format(datetime.now()) + self.base_exp_dir = self.conf['general.base_exp_dir'] + self.timestamp # jha comment this when testing and use this when training + else: + self.base_exp_dir = self.conf['general.base_exp_dir'] + self.conf['general.base_exp_dir'] = self.base_exp_dir # jha use this when testing + print(colored("base_exp_dir: " + self.base_exp_dir, 'yellow')) + os.makedirs(self.base_exp_dir, exist_ok=True) + self.iter_step = 0 + self.val_step = 0 + + # trainning parameters + self.end_iter = self.conf.get_int('train.end_iter') + self.save_freq = self.conf.get_int('train.save_freq') + self.report_freq = self.conf.get_int('train.report_freq') + self.val_freq = self.conf.get_int('train.val_freq') + self.val_mesh_freq = self.conf.get_int('train.val_mesh_freq') + self.batch_size = self.num_devices # use DataParallel to warp + self.validate_resolution_level = self.conf.get_int('train.validate_resolution_level') + self.learning_rate = self.conf.get_float('train.learning_rate') + self.learning_rate_milestone = self.conf.get_list('train.learning_rate_milestone') + self.learning_rate_factor = self.conf.get_float('train.learning_rate_factor') + self.use_white_bkgd = self.conf.get_bool('train.use_white_bkgd') + self.N_rays = self.conf.get_int('train.N_rays') + + # warmup params for sdf gradient + self.anneal_start_lod0 = self.conf.get_float('train.anneal_start', default=0) + self.anneal_end_lod0 = self.conf.get_float('train.anneal_end', default=0) + self.anneal_start_lod1 = self.conf.get_float('train.anneal_start_lod1', default=0) + self.anneal_end_lod1 = self.conf.get_float('train.anneal_end_lod1', default=0) + + self.writer = None + + # Networks + self.num_lods = self.conf.get_int('model.num_lods') + + self.rendering_network_outside = None + self.sdf_network_lod0 = None + self.sdf_network_lod1 = None + self.variance_network_lod0 = None + self.variance_network_lod1 = None + self.rendering_network_lod0 = None + self.rendering_network_lod1 = None + self.pyramid_feature_network = None # extract 2d pyramid feature maps from images, used for geometry + self.pyramid_feature_network_lod1 = None # may use different feature network for different lod + + # * pyramid_feature_network + self.pyramid_feature_network = FeatureNet().to(self.device) + self.sdf_network_lod0 = SparseSdfNetwork(**self.conf['model.sdf_network_lod0']).to(self.device) + self.variance_network_lod0 = SingleVarianceNetwork(**self.conf['model.variance_network']).to(self.device) + + if self.num_lods > 1: + self.sdf_network_lod1 = SparseSdfNetwork(**self.conf['model.sdf_network_lod1']).to(self.device) + self.variance_network_lod1 = SingleVarianceNetwork(**self.conf['model.variance_network']).to(self.device) + + self.rendering_network_lod0 = GeneralRenderingNetwork(**self.conf['model.rendering_network']).to( + self.device) + + if self.num_lods > 1: + self.pyramid_feature_network_lod1 = FeatureNet().to(self.device) + self.rendering_network_lod1 = GeneralRenderingNetwork( + **self.conf['model.rendering_network_lod1']).to(self.device) + if self.mode == 'export_mesh' or self.mode == 'val': + # base_exp_dir_to_store = os.path.join(self.base_exp_dir, '{:%Y_%m_%d_%H_%M_%S}'.format(datetime.now())) + print("save mesh to:", os.path.join("../", args.specific_dataset_name)) + base_exp_dir_to_store = os.path.join("../", args.specific_dataset_name) #"../gradio_tmp" # MODIFIED + else: + base_exp_dir_to_store = self.base_exp_dir + + print(colored(f"Store in: {base_exp_dir_to_store}", "blue")) + # Renderer model + self.trainer = GenericTrainer( + self.rendering_network_outside, + self.pyramid_feature_network, + self.pyramid_feature_network_lod1, + self.sdf_network_lod0, + self.sdf_network_lod1, + self.variance_network_lod0, + self.variance_network_lod1, + self.rendering_network_lod0, + self.rendering_network_lod1, + **self.conf['model.trainer'], + timestamp=self.timestamp, + base_exp_dir=base_exp_dir_to_store, + conf=self.conf) + + self.data_setup() # * data setup + + self.optimizer_setup() + + # Load checkpoint + latest_model_name = None + if is_continue: + model_list_raw = os.listdir(os.path.join(self.base_exp_dir, 'checkpoints')) + model_list = [] + for model_name in model_list_raw: + if model_name.startswith('ckpt'): + if model_name[-3:] == 'pth': # and int(model_name[5:-4]) <= self.end_iter: + model_list.append(model_name) + model_list.sort() + latest_model_name = model_list[-1] + + if latest_model_name is not None: + self.logger.info('Find checkpoint: {}'.format(latest_model_name)) + self.load_checkpoint(latest_model_name) + + self.trainer = torch.nn.DataParallel(self.trainer).to(self.device) + + if self.mode[:5] == 'train': + self.file_backup() + + def optimizer_setup(self): + self.params_to_train = self.trainer.get_trainable_params() + self.optimizer = torch.optim.Adam(self.params_to_train, lr=self.learning_rate) + + def data_setup(self): + """ + if use ddp, use setup() not prepare_data(), + prepare_data() only called on 1 GPU/TPU in distributed + :return: + """ + + self.train_dataset = BlenderPerView( + root_dir=self.conf['dataset.trainpath'], + split=self.conf.get_string('dataset.train_split', default='train'), + split_filepath=self.conf.get_string('dataset.train_split_filepath', default=None), + n_views=self.conf['dataset.nviews'], + downSample=self.conf['dataset.imgScale_train'], + N_rays=self.N_rays, + batch_size=self.batch_size, + clean_image=True, # True for training + importance_sample=self.conf.get_bool('dataset.importance_sample', default=False), + specific_dataset_name = args.specific_dataset_name + ) + + self.val_dataset = BlenderPerView( + root_dir=self.conf['dataset.valpath'], + split=self.conf.get_string('dataset.test_split', default='test'), + split_filepath=self.conf.get_string('dataset.val_split_filepath', default=None), + n_views=3, + downSample=self.conf['dataset.imgScale_test'], + N_rays=self.N_rays, + batch_size=self.batch_size, + clean_image=self.conf.get_bool('dataset.mask_out_image', + default=False) if self.mode != 'train' else False, + importance_sample=self.conf.get_bool('dataset.importance_sample', default=False), + test_ref_views=self.conf.get_list('dataset.test_ref_views', default=[]), + specific_dataset_name = args.specific_dataset_name + ) + + # item = self.train_dataset.__getitem__(0) + self.train_dataloader = DataLoader(self.train_dataset, + shuffle=True, + num_workers=4 * self.batch_size, + # num_workers=1, + batch_size=self.batch_size, + pin_memory=True, + drop_last=True + ) + + self.val_dataloader = DataLoader(self.val_dataset, + # shuffle=False if self.mode == 'train' else True, + shuffle=False, + num_workers=4 * self.batch_size, + # num_workers=1, + batch_size=self.batch_size, + pin_memory=True, + drop_last=False + ) + + self.val_dataloader_iterator = iter(self.val_dataloader) # - should be after "reconstruct_metas_for_gru_fusion" + + def train(self): + self.writer = SummaryWriter(log_dir=os.path.join(self.base_exp_dir, 'logs')) + res_step = self.end_iter - self.iter_step + + dataloader = self.train_dataloader + + epochs = int(1 + res_step // len(dataloader)) + + self.adjust_learning_rate() + print(colored("starting training learning rate: {:.5f}".format(self.optimizer.param_groups[0]['lr']), "yellow")) + + background_rgb = None + if self.use_white_bkgd: + # background_rgb = torch.ones([1, 3]).to(self.device) + background_rgb = 1.0 + + for epoch_i in range(epochs): + + print(colored("current epoch %d" % epoch_i, 'red')) + dataloader = tqdm(dataloader) + + for batch in dataloader: + # print("Checker1:, fetch data") + batch['batch_idx'] = torch.tensor([x for x in range(self.batch_size)]) # used to get meta + + # - warmup params + if self.num_lods == 1: + alpha_inter_ratio_lod0 = self.get_alpha_inter_ratio(self.anneal_start_lod0, self.anneal_end_lod0) + else: + alpha_inter_ratio_lod0 = 1. + alpha_inter_ratio_lod1 = self.get_alpha_inter_ratio(self.anneal_start_lod1, self.anneal_end_lod1) + + losses = self.trainer( + batch, + background_rgb=background_rgb, + alpha_inter_ratio_lod0=alpha_inter_ratio_lod0, + alpha_inter_ratio_lod1=alpha_inter_ratio_lod1, + iter_step=self.iter_step, + mode='train', + ) + + loss_types = ['loss_lod0', 'loss_lod1'] + # print("[TEST]: weights_sum in trainer return", losses['losses_lod0']['weights_sum'].mean()) + + losses_lod0 = losses['losses_lod0'] + losses_lod1 = losses['losses_lod1'] + # import ipdb; ipdb.set_trace() + loss = 0 + for loss_type in loss_types: + if losses[loss_type] is not None: + loss = loss + losses[loss_type].mean() + # print("Checker4:, begin BP") + self.optimizer.zero_grad() + loss.backward() + torch.nn.utils.clip_grad_norm_(self.params_to_train, 1.0) + self.optimizer.step() + # print("Checker5:, end BP") + self.iter_step += 1 + + if self.iter_step % self.report_freq == 0: + self.writer.add_scalar('Loss/loss', loss, self.iter_step) + + if losses_lod0 is not None: + self.writer.add_scalar('Loss/d_loss_lod0', + losses_lod0['depth_loss'].mean() if losses_lod0 is not None else 0, + self.iter_step) + self.writer.add_scalar('Loss/sparse_loss_lod0', + losses_lod0[ + 'sparse_loss'].mean() if losses_lod0 is not None else 0, + self.iter_step) + self.writer.add_scalar('Loss/color_loss_lod0', + losses_lod0['color_fine_loss'].mean() + if losses_lod0['color_fine_loss'] is not None else 0, + self.iter_step) + + self.writer.add_scalar('statis/psnr_lod0', + losses_lod0['psnr'].mean() + if losses_lod0['psnr'] is not None else 0, + self.iter_step) + + self.writer.add_scalar('param/variance_lod0', + 1. / torch.exp(self.variance_network_lod0.variance * 10), + self.iter_step) + self.writer.add_scalar('param/eikonal_loss', losses_lod0['gradient_error_loss'].mean() if losses_lod0 is not None else 0, + self.iter_step) + + ######## - lod 1 + if self.num_lods > 1: + self.writer.add_scalar('Loss/d_loss_lod1', + losses_lod1['depth_loss'].mean() if losses_lod1 is not None else 0, + self.iter_step) + self.writer.add_scalar('Loss/sparse_loss_lod1', + losses_lod1[ + 'sparse_loss'].mean() if losses_lod1 is not None else 0, + self.iter_step) + self.writer.add_scalar('Loss/color_loss_lod1', + losses_lod1['color_fine_loss'].mean() + if losses_lod1['color_fine_loss'] is not None else 0, + self.iter_step) + self.writer.add_scalar('statis/sdf_mean_lod1', + losses_lod1['sdf_mean'].mean() if losses_lod1 is not None else 0, + self.iter_step) + self.writer.add_scalar('statis/psnr_lod1', + losses_lod1['psnr'].mean() + if losses_lod1['psnr'] is not None else 0, + self.iter_step) + self.writer.add_scalar('statis/sparseness_0.01_lod1', + losses_lod1['sparseness_1'].mean() + if losses_lod1['sparseness_1'] is not None else 0, + self.iter_step) + self.writer.add_scalar('statis/sparseness_0.02_lod1', + losses_lod1['sparseness_2'].mean() + if losses_lod1['sparseness_2'] is not None else 0, + self.iter_step) + self.writer.add_scalar('param/variance_lod1', + 1. / torch.exp(self.variance_network_lod1.variance * 10), + self.iter_step) + + print(self.base_exp_dir) + print( + 'iter:{:8>d} ' + 'loss = {:.4f} ' + 'd_loss_lod0 = {:.4f} ' + 'color_loss_lod0 = {:.4f} ' + 'sparse_loss_lod0= {:.4f} ' + 'd_loss_lod1 = {:.4f} ' + 'color_loss_lod1 = {:.4f} ' + ' lr = {:.5f}'.format( + self.iter_step, loss, + losses_lod0['depth_loss'].mean() if losses_lod0 is not None else 0, + losses_lod0['color_fine_loss'].mean() if losses_lod0 is not None else 0, + losses_lod0['sparse_loss'].mean() if losses_lod0 is not None else 0, + losses_lod1['depth_loss'].mean() if losses_lod1 is not None else 0, + losses_lod1['color_fine_loss'].mean() if losses_lod1 is not None else 0, + self.optimizer.param_groups[0]['lr'])) + + print(colored('alpha_inter_ratio_lod0 = {:.4f} alpha_inter_ratio_lod1 = {:.4f}\n'.format( + alpha_inter_ratio_lod0, alpha_inter_ratio_lod1), 'green')) + + if losses_lod0 is not None: + # print("[TEST]: weights_sum in print", losses_lod0['weights_sum'].mean()) + # import ipdb; ipdb.set_trace() + print( + 'iter:{:8>d} ' + 'variance = {:.5f} ' + 'weights_sum = {:.4f} ' + 'weights_sum_fg = {:.4f} ' + 'alpha_sum = {:.4f} ' + 'sparse_weight= {:.4f} ' + 'background_loss = {:.4f} ' + 'background_weight = {:.4f} ' + .format( + self.iter_step, + losses_lod0['variance'].mean(), + losses_lod0['weights_sum'].mean(), + losses_lod0['weights_sum_fg'].mean(), + losses_lod0['alpha_sum'].mean(), + losses_lod0['sparse_weight'].mean(), + losses_lod0['fg_bg_loss'].mean(), + losses_lod0['fg_bg_weight'].mean(), + )) + + if losses_lod1 is not None: + print( + 'iter:{:8>d} ' + 'variance = {:.5f} ' + ' weights_sum = {:.4f} ' + 'alpha_sum = {:.4f} ' + 'fg_bg_loss = {:.4f} ' + 'fg_bg_weight = {:.4f} ' + 'sparse_weight= {:.4f} ' + 'fg_bg_loss = {:.4f} ' + 'fg_bg_weight = {:.4f} ' + .format( + self.iter_step, + losses_lod1['variance'].mean(), + losses_lod1['weights_sum'].mean(), + losses_lod1['alpha_sum'].mean(), + losses_lod1['fg_bg_loss'].mean(), + losses_lod1['fg_bg_weight'].mean(), + losses_lod1['sparse_weight'].mean(), + losses_lod1['fg_bg_loss'].mean(), + losses_lod1['fg_bg_weight'].mean(), + )) + + if self.iter_step % self.save_freq == 0: + self.save_checkpoint() + + if self.iter_step % self.val_freq == 0: + self.validate() + + # - ajust learning rate + self.adjust_learning_rate() + + def adjust_learning_rate(self): + # - ajust learning rate, cosine learning schedule + learning_rate = (np.cos(np.pi * self.iter_step / self.end_iter) + 1.0) * 0.5 * 0.9 + 0.1 + learning_rate = self.learning_rate * learning_rate + for g in self.optimizer.param_groups: + g['lr'] = learning_rate + + def get_alpha_inter_ratio(self, start, end): + if end == 0.0: + return 1.0 + elif self.iter_step < start: + return 0.0 + else: + return np.min([1.0, (self.iter_step - start) / (end - start)]) + + def file_backup(self): + # copy python file + dir_lis = self.conf['general.recording'] + os.makedirs(os.path.join(self.base_exp_dir, 'recording'), exist_ok=True) + for dir_name in dir_lis: + cur_dir = os.path.join(self.base_exp_dir, 'recording', dir_name) + os.makedirs(cur_dir, exist_ok=True) + files = os.listdir(dir_name) + for f_name in files: + if f_name[-3:] == '.py': + copyfile(os.path.join(dir_name, f_name), os.path.join(cur_dir, f_name)) + + # copy configs + copyfile(self.conf_path, os.path.join(self.base_exp_dir, 'recording', 'config.conf')) + + def load_checkpoint(self, checkpoint_name): + + def load_state_dict(network, checkpoint, comment): + if network is not None: + try: + pretrained_dict = checkpoint[comment] + + model_dict = network.state_dict() + + # 1. filter out unnecessary keys + pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} + # 2. overwrite entries in the existing state dict + model_dict.update(pretrained_dict) + # 3. load the new state dict + network.load_state_dict(pretrained_dict) + except: + print(colored(comment + " load fails", 'yellow')) + + checkpoint = torch.load(os.path.join(self.base_exp_dir, 'checkpoints', checkpoint_name), + map_location=self.device) + + load_state_dict(self.rendering_network_outside, checkpoint, 'rendering_network_outside') + + load_state_dict(self.sdf_network_lod0, checkpoint, 'sdf_network_lod0') + load_state_dict(self.sdf_network_lod1, checkpoint, 'sdf_network_lod1') + + load_state_dict(self.pyramid_feature_network, checkpoint, 'pyramid_feature_network') + load_state_dict(self.pyramid_feature_network_lod1, checkpoint, 'pyramid_feature_network_lod1') + + load_state_dict(self.variance_network_lod0, checkpoint, 'variance_network_lod0') + load_state_dict(self.variance_network_lod1, checkpoint, 'variance_network_lod1') + + load_state_dict(self.rendering_network_lod0, checkpoint, 'rendering_network_lod0') + load_state_dict(self.rendering_network_lod1, checkpoint, 'rendering_network_lod1') + + if self.restore_lod0: # use the trained lod0 networks to initialize lod1 networks + load_state_dict(self.sdf_network_lod1, checkpoint, 'sdf_network_lod0') + load_state_dict(self.pyramid_feature_network_lod1, checkpoint, 'pyramid_feature_network') + load_state_dict(self.rendering_network_lod1, checkpoint, 'rendering_network_lod0') + + if self.is_continue and (not self.restore_lod0): + try: + self.optimizer.load_state_dict(checkpoint['optimizer']) + except: + print(colored("load optimizer fails", "yellow")) + self.iter_step = checkpoint['iter_step'] + self.val_step = checkpoint['val_step'] if 'val_step' in checkpoint.keys() else 0 + + self.logger.info('End') + + def save_checkpoint(self): + + def save_state_dict(network, checkpoint, comment): + if network is not None: + checkpoint[comment] = network.state_dict() + + checkpoint = { + 'optimizer': self.optimizer.state_dict(), + 'iter_step': self.iter_step, + 'val_step': self.val_step, + } + + save_state_dict(self.sdf_network_lod0, checkpoint, "sdf_network_lod0") + save_state_dict(self.sdf_network_lod1, checkpoint, "sdf_network_lod1") + + save_state_dict(self.rendering_network_outside, checkpoint, 'rendering_network_outside') + save_state_dict(self.rendering_network_lod0, checkpoint, "rendering_network_lod0") + save_state_dict(self.rendering_network_lod1, checkpoint, "rendering_network_lod1") + + save_state_dict(self.variance_network_lod0, checkpoint, 'variance_network_lod0') + save_state_dict(self.variance_network_lod1, checkpoint, 'variance_network_lod1') + + save_state_dict(self.pyramid_feature_network, checkpoint, 'pyramid_feature_network') + save_state_dict(self.pyramid_feature_network_lod1, checkpoint, 'pyramid_feature_network_lod1') + + os.makedirs(os.path.join(self.base_exp_dir, 'checkpoints'), exist_ok=True) + torch.save(checkpoint, + os.path.join(self.base_exp_dir, 'checkpoints', 'ckpt_{:0>6d}.pth'.format(self.iter_step))) + + def validate(self, idx=-1, resolution_level=-1): + # validate image + + ic(self.iter_step, idx) + self.logger.info('Validate begin') + + if idx < 0: + idx = self.val_step + # idx = np.random.randint(len(self.val_dataset)) + self.val_step += 1 + + try: + batch = next(self.val_dataloader_iterator) + except: + self.val_dataloader_iterator = iter(self.val_dataloader) # reset + + batch = next(self.val_dataloader_iterator) + + + background_rgb = None + if self.use_white_bkgd: + # background_rgb = torch.ones([1, 3]).to(self.device) + background_rgb = 1.0 + + batch['batch_idx'] = torch.tensor([x for x in range(self.batch_size)]) + + # - warmup params + if self.num_lods == 1: + alpha_inter_ratio_lod0 = self.get_alpha_inter_ratio(self.anneal_start_lod0, self.anneal_end_lod0) + else: + alpha_inter_ratio_lod0 = 1. + alpha_inter_ratio_lod1 = self.get_alpha_inter_ratio(self.anneal_start_lod1, self.anneal_end_lod1) + + self.trainer( + batch, + background_rgb=background_rgb, + alpha_inter_ratio_lod0=alpha_inter_ratio_lod0, + alpha_inter_ratio_lod1=alpha_inter_ratio_lod1, + iter_step=self.iter_step, + save_vis=True, + mode='val', + ) + + + def export_mesh(self, idx=-1, resolution_level=-1): + # validate image + + ic(self.iter_step, idx) + self.logger.info('Validate begin') + import time + start1 = time.time() + if idx < 0: + idx = self.val_step + # idx = np.random.randint(len(self.val_dataset)) + self.val_step += 1 + + try: + batch = next(self.val_dataloader_iterator) + except: + self.val_dataloader_iterator = iter(self.val_dataloader) # reset + + batch = next(self.val_dataloader_iterator) + + + background_rgb = None + if self.use_white_bkgd: + # background_rgb = torch.ones([1, 3]).to(self.device) + background_rgb = 1.0 + + batch['batch_idx'] = torch.tensor([x for x in range(self.batch_size)]) + + # - warmup params + if self.num_lods == 1: + alpha_inter_ratio_lod0 = self.get_alpha_inter_ratio(self.anneal_start_lod0, self.anneal_end_lod0) + else: + alpha_inter_ratio_lod0 = 1. + alpha_inter_ratio_lod1 = self.get_alpha_inter_ratio(self.anneal_start_lod1, self.anneal_end_lod1) + end1 = time.time() + print("time for getting data", end1 - start1) + self.trainer( + batch, + background_rgb=background_rgb, + alpha_inter_ratio_lod0=alpha_inter_ratio_lod0, + alpha_inter_ratio_lod1=alpha_inter_ratio_lod1, + iter_step=self.iter_step, + save_vis=True, + mode='export_mesh', + ) + + +if __name__ == '__main__': + # torch.set_default_tensor_type('torch.cuda.FloatTensor') + torch.set_default_dtype(torch.float32) + FORMAT = "[%(filename)s:%(lineno)s - %(funcName)20s() ] %(message)s" + logging.basicConfig(level=logging.INFO, format=FORMAT) + + parser = argparse.ArgumentParser() + parser.add_argument('--conf', type=str, default='./confs/base.conf') + parser.add_argument('--mode', type=str, default='train') + parser.add_argument('--threshold', type=float, default=0.0) + parser.add_argument('--is_continue', default=False, action="store_true") + parser.add_argument('--is_restore', default=False, action="store_true") + parser.add_argument('--is_finetune', default=False, action="store_true") + parser.add_argument('--train_from_scratch', default=False, action="store_true") + parser.add_argument('--restore_lod0', default=False, action="store_true") + parser.add_argument('--local_rank', type=int, default=0) + parser.add_argument('--specific_dataset_name', type=str, default='GSO') + + + args = parser.parse_args() + + torch.cuda.set_device(args.local_rank) + torch.backends.cudnn.benchmark = True # ! make training 2x faster + + runner = Runner(args.conf, args.mode, args.is_continue, args.is_restore, args.restore_lod0, + args.local_rank) + + if args.mode == 'train': + runner.train() + elif args.mode == 'val': + for i in range(len(runner.val_dataset)): + runner.validate() + elif args.mode == 'export_mesh': + for i in range(len(runner.val_dataset)): + runner.export_mesh() diff --git a/SparseNeuS_demo_v1/loss/__init__.py b/SparseNeuS_demo_v1/loss/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/SparseNeuS_demo_v1/loss/color_loss.py b/SparseNeuS_demo_v1/loss/color_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..ddb5d1d0d6f7b71416b010ad8d167f4c2eb04f1c --- /dev/null +++ b/SparseNeuS_demo_v1/loss/color_loss.py @@ -0,0 +1,156 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +import icecream as ic +from loss.ncc import NCC +from termcolor import colored + + +class Normalize(nn.Module): + def __init__(self): + super(Normalize, self).__init__() + + def forward(self, bottom): + qn = torch.norm(bottom, p=2, dim=1).unsqueeze(dim=1) + 1e-12 + top = bottom.div(qn) + + return top + + +class OcclusionColorLoss(nn.Module): + def __init__(self, alpha=1, beta=0.025, gama=0.01, occlusion_aware=True, weight_thred=[0.6]): + super(OcclusionColorLoss, self).__init__() + self.alpha = alpha + self.beta = beta + self.gama = gama + self.occlusion_aware = occlusion_aware + self.eps = 1e-4 + + self.weight_thred = weight_thred + self.adjuster = ParamAdjuster(self.weight_thred, self.beta) + + def forward(self, pred, gt, weight, mask, detach=False, occlusion_aware=True): + """ + + :param pred: [N_pts, 3] + :param gt: [N_pts, 3] + :param weight: [N_pts] + :param mask: [N_pts] + :return: + """ + if detach: + weight = weight.detach() + + error = torch.abs(pred - gt).sum(dim=-1, keepdim=False) # [N_pts] + error = error[mask] + + if not (self.occlusion_aware and occlusion_aware): + return torch.mean(error), torch.mean(error) + + beta = self.adjuster(weight.mean()) + + # weight = weight[mask] + weight = weight.clamp(0.0, 1.0) + term1 = self.alpha * torch.mean(weight[mask] * error) + term2 = beta * torch.log(1 - weight + self.eps).mean() + term3 = self.gama * torch.log(weight + self.eps).mean() + + return term1 + term2 + term3, term1 + + +class OcclusionColorPatchLoss(nn.Module): + def __init__(self, alpha=1, beta=0.025, gama=0.015, + occlusion_aware=True, type='l1', h_patch_size=3, weight_thred=[0.6]): + super(OcclusionColorPatchLoss, self).__init__() + self.alpha = alpha + self.beta = beta + self.gama = gama + self.occlusion_aware = occlusion_aware + self.type = type # 'l1' or 'ncc' loss + self.ncc = NCC(h_patch_size=h_patch_size) + self.eps = 1e-4 + self.weight_thred = weight_thred + + self.adjuster = ParamAdjuster(self.weight_thred, self.beta) + + print("type {} patch_size {} beta {} gama {} weight_thred {}".format(type, h_patch_size, beta, gama, + weight_thred)) + + def forward(self, pred, gt, weight, mask, penalize_ratio=0.9, detach=False, occlusion_aware=True): + """ + + :param pred: [N_pts, Npx, 3] + :param gt: [N_pts, Npx, 3] + :param weight: [N_pts] + :param mask: [N_pts] + :return: + """ + + if detach: + weight = weight.detach() + + if self.type == 'l1': + error = torch.abs(pred - gt).mean(dim=-1, keepdim=False).sum(dim=-1, keepdim=False) # [N_pts] + elif self.type == 'ncc': + error = 1 - self.ncc(pred[:, None, :, :], gt)[:, 0] # ncc 1 positive, -1 negative + error, indices = torch.sort(error) + mask = torch.index_select(mask, 0, index=indices) + mask[int(penalize_ratio * mask.shape[0]):] = False # can help boundaries + elif self.type == 'ssd': + error = ((pred - gt) ** 2).mean(dim=-1, keepdim=False).sum(dim=-1, keepdims=False) + + error = error[mask] + if not (self.occlusion_aware and occlusion_aware): + return torch.mean(error), torch.mean(error), 0. + + # * weight adjuster + beta = self.adjuster(weight.mean()) + + # weight = weight[mask] + weight = weight.clamp(0.0, 1.0) + + term1 = self.alpha * torch.mean(weight[mask] * error) + term2 = beta * torch.log(1 - weight + self.eps).mean() + term3 = self.gama * torch.log(weight + self.eps).mean() + + return term1 + term2 + term3, term1, beta + + +class ParamAdjuster(nn.Module): + def __init__(self, weight_thred, param): + super(ParamAdjuster, self).__init__() + self.weight_thred = weight_thred + self.thred_num = len(weight_thred) + self.param = param + self.global_step = 0 + self.statis_window = 100 + self.counter = 0 + self.adjusted = False + self.adjusted_step = 0 + self.thred_idx = 0 + + def reset(self): + self.counter = 0 + self.adjusted = False + + def adjust(self): + if (self.counter / self.statis_window) > 0.3: + self.param = self.param + 0.005 + self.adjusted = True + self.adjusted_step = self.global_step + self.thred_idx += 1 + print(colored("ajusted param, now {}".format(self.param), 'red')) + + def forward(self, weight_mean): + self.global_step += 1 + + if (self.global_step % self.statis_window == 0) and self.adjusted is False: + self.adjust() + self.reset() + + if self.thred_idx < self.thred_num: + if weight_mean < self.weight_thred[self.thred_idx] and (not self.adjusted): + self.counter += 1 + + return self.param diff --git a/SparseNeuS_demo_v1/loss/depth_loss.py b/SparseNeuS_demo_v1/loss/depth_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..cba92851a79857ff6edd5c2f2eb12a2972b85bdc --- /dev/null +++ b/SparseNeuS_demo_v1/loss/depth_loss.py @@ -0,0 +1,71 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class DepthLoss(nn.Module): + def __init__(self, type='l1'): + super(DepthLoss, self).__init__() + self.type = type + + + def forward(self, depth_pred, depth_gt, mask=None): + if (depth_gt < 0).sum() > 0: + # print("no depth loss") + return torch.tensor(0.0).to(depth_pred.device) + if mask is not None: + mask_d = (depth_gt > 0).float() + + mask = mask * mask_d + + mask_sum = mask.sum() + 1e-5 + depth_error = (depth_pred - depth_gt) * mask + depth_loss = F.l1_loss(depth_error, torch.zeros_like(depth_error).to(depth_error.device), + reduction='sum') / mask_sum + else: + depth_error = depth_pred - depth_gt + depth_loss = F.l1_loss(depth_error, torch.zeros_like(depth_error).to(depth_error.device), + reduction='mean') + return depth_loss + +def forward(self, depth_pred, depth_gt, mask=None): + if mask is not None: + mask_d = (depth_gt > 0).float() + + mask = mask * mask_d + + mask_sum = mask.sum() + 1e-5 + depth_error = (depth_pred - depth_gt) * mask + depth_loss = F.l1_loss(depth_error, torch.zeros_like(depth_error).to(depth_error.device), + reduction='sum') / mask_sum + else: + depth_error = depth_pred - depth_gt + depth_loss = F.l1_loss(depth_error, torch.zeros_like(depth_error).to(depth_error.device), + reduction='mean') + return depth_loss + +class DepthSmoothLoss(nn.Module): + def __init__(self): + super(DepthSmoothLoss, self).__init__() + + def forward(self, disp, img, mask): + """ + Computes the smoothness loss for a disparity image + The color image is used for edge-aware smoothness + :param disp: [B, 1, H, W] + :param img: [B, 1, H, W] + :param mask: [B, 1, H, W] + :return: + """ + grad_disp_x = torch.abs(disp[:, :, :, :-1] - disp[:, :, :, 1:]) + grad_disp_y = torch.abs(disp[:, :, :-1, :] - disp[:, :, 1:, :]) + + grad_img_x = torch.mean(torch.abs(img[:, :, :, :-1] - img[:, :, :, 1:]), 1, keepdim=True) + grad_img_y = torch.mean(torch.abs(img[:, :, :-1, :] - img[:, :, 1:, :]), 1, keepdim=True) + + grad_disp_x *= torch.exp(-grad_img_x) + grad_disp_y *= torch.exp(-grad_img_y) + + grad_disp = (grad_disp_x * mask[:, :, :, :-1]).mean() + (grad_disp_y * mask[:, :, :-1, :]).mean() + + return grad_disp diff --git a/SparseNeuS_demo_v1/loss/depth_metric.py b/SparseNeuS_demo_v1/loss/depth_metric.py new file mode 100644 index 0000000000000000000000000000000000000000..e8b6249ac6a06906e20a344f468fc1c6e4b992ae --- /dev/null +++ b/SparseNeuS_demo_v1/loss/depth_metric.py @@ -0,0 +1,240 @@ +import numpy as np + + +def l1(depth1, depth2): + """ + Computes the l1 errors between the two depth maps. + Takes preprocessed depths (no nans, infs and non-positive values) + + depth1: one depth map + depth2: another depth map + + Returns: + L1(log) + + """ + assert (np.all(np.isfinite(depth1) & np.isfinite(depth2) & (depth1 >= 0) & (depth2 >= 0))) + diff = depth1 - depth2 + num_pixels = float(diff.size) + + if num_pixels == 0: + return np.nan + else: + return np.sum(np.absolute(diff)) / num_pixels + + +def l1_inverse(depth1, depth2): + """ + Computes the l1 errors between inverses of two depth maps. + Takes preprocessed depths (no nans, infs and non-positive values) + + depth1: one depth map + depth2: another depth map + + Returns: + L1(log) + + """ + assert (np.all(np.isfinite(depth1) & np.isfinite(depth2) & (depth1 >= 0) & (depth2 >= 0))) + diff = np.reciprocal(depth1) - np.reciprocal(depth2) + num_pixels = float(diff.size) + + if num_pixels == 0: + return np.nan + else: + return np.sum(np.absolute(diff)) / num_pixels + + +def rmse_log(depth1, depth2): + """ + Computes the root min square errors between the logs of two depth maps. + Takes preprocessed depths (no nans, infs and non-positive values) + + depth1: one depth map + depth2: another depth map + + Returns: + RMSE(log) + + """ + assert (np.all(np.isfinite(depth1) & np.isfinite(depth2) & (depth1 >= 0) & (depth2 >= 0))) + log_diff = np.log(depth1) - np.log(depth2) + num_pixels = float(log_diff.size) + + if num_pixels == 0: + return np.nan + else: + return np.sqrt(np.sum(np.square(log_diff)) / num_pixels) + + +def rmse(depth1, depth2): + """ + Computes the root min square errors between the two depth maps. + Takes preprocessed depths (no nans, infs and non-positive values) + + depth1: one depth map + depth2: another depth map + + Returns: + RMSE(log) + + """ + assert (np.all(np.isfinite(depth1) & np.isfinite(depth2) & (depth1 >= 0) & (depth2 >= 0))) + diff = depth1 - depth2 + num_pixels = float(diff.size) + + if num_pixels == 0: + return np.nan + else: + return np.sqrt(np.sum(np.square(diff)) / num_pixels) + + +def scale_invariant(depth1, depth2): + """ + Computes the scale invariant loss based on differences of logs of depth maps. + Takes preprocessed depths (no nans, infs and non-positive values) + + depth1: one depth map + depth2: another depth map + + Returns: + scale_invariant_distance + + """ + # sqrt(Eq. 3) + assert (np.all(np.isfinite(depth1) & np.isfinite(depth2) & (depth1 >= 0) & (depth2 >= 0))) + log_diff = np.log(depth1) - np.log(depth2) + num_pixels = float(log_diff.size) + + if num_pixels == 0: + return np.nan + else: + return np.sqrt(np.sum(np.square(log_diff)) / num_pixels - np.square(np.sum(log_diff)) / np.square(num_pixels)) + + +def abs_relative(depth_pred, depth_gt): + """ + Computes relative absolute distance. + Takes preprocessed depths (no nans, infs and non-positive values) + + depth_pred: depth map prediction + depth_gt: depth map ground truth + + Returns: + abs_relative_distance + + """ + assert (np.all(np.isfinite(depth_pred) & np.isfinite(depth_gt) & (depth_pred >= 0) & (depth_gt >= 0))) + diff = depth_pred - depth_gt + num_pixels = float(diff.size) + + if num_pixels == 0: + return np.nan + else: + return np.sum(np.absolute(diff) / depth_gt) / num_pixels + + +def avg_log10(depth1, depth2): + """ + Computes average log_10 error (Liu, Neural Fields, 2015). + Takes preprocessed depths (no nans, infs and non-positive values) + + depth1: one depth map + depth2: another depth map + + Returns: + abs_relative_distance + + """ + assert (np.all(np.isfinite(depth1) & np.isfinite(depth2) & (depth1 >= 0) & (depth2 >= 0))) + log_diff = np.log10(depth1) - np.log10(depth2) + num_pixels = float(log_diff.size) + + if num_pixels == 0: + return np.nan + else: + return np.sum(np.absolute(log_diff)) / num_pixels + + +def sq_relative(depth_pred, depth_gt): + """ + Computes relative squared distance. + Takes preprocessed depths (no nans, infs and non-positive values) + + depth_pred: depth map prediction + depth_gt: depth map ground truth + + Returns: + squared_relative_distance + + """ + assert (np.all(np.isfinite(depth_pred) & np.isfinite(depth_gt) & (depth_pred >= 0) & (depth_gt >= 0))) + diff = depth_pred - depth_gt + num_pixels = float(diff.size) + + if num_pixels == 0: + return np.nan + else: + return np.sum(np.square(diff) / depth_gt) / num_pixels + + +def ratio_threshold(depth1, depth2, threshold): + """ + Computes the percentage of pixels for which the ratio of the two depth maps is less than a given threshold. + Takes preprocessed depths (no nans, infs and non-positive values) + + depth1: one depth map + depth2: another depth map + + Returns: + percentage of pixels with ratio less than the threshold + + """ + assert (threshold > 0.) + assert (np.all(np.isfinite(depth1) & np.isfinite(depth2) & (depth1 >= 0) & (depth2 >= 0))) + log_diff = np.log(depth1) - np.log(depth2) + num_pixels = float(log_diff.size) + + if num_pixels == 0: + return np.nan + else: + return float(np.sum(np.absolute(log_diff) < np.log(threshold))) / num_pixels + + +def compute_depth_errors(depth_pred, depth_gt, valid_mask): + """ + Computes different distance measures between two depth maps. + + depth_pred: depth map prediction + depth_gt: depth map ground truth + distances_to_compute: which distances to compute + + Returns: + a dictionary with computed distances, and the number of valid pixels + + """ + depth_pred = depth_pred[valid_mask] + depth_gt = depth_gt[valid_mask] + num_valid = np.sum(valid_mask) + + distances_to_compute = ['l1', + 'l1_inverse', + 'scale_invariant', + 'abs_relative', + 'sq_relative', + 'avg_log10', + 'rmse_log', + 'rmse', + 'ratio_threshold_1.25', + 'ratio_threshold_1.5625', + 'ratio_threshold_1.953125'] + + results = {'num_valid': num_valid} + for dist in distances_to_compute: + if dist.startswith('ratio_threshold'): + threshold = float(dist.split('_')[-1]) + results[dist] = ratio_threshold(depth_pred, depth_gt, threshold) + else: + results[dist] = globals()[dist](depth_pred, depth_gt) + + return results diff --git a/SparseNeuS_demo_v1/loss/ncc.py b/SparseNeuS_demo_v1/loss/ncc.py new file mode 100644 index 0000000000000000000000000000000000000000..768fcefc3aab55d8e3fed49f23ffb4a974eec4ec --- /dev/null +++ b/SparseNeuS_demo_v1/loss/ncc.py @@ -0,0 +1,65 @@ +import torch +import torch.nn.functional as F +import numpy as np +from math import exp, sqrt + + +class NCC(torch.nn.Module): + def __init__(self, h_patch_size, mode='rgb'): + super(NCC, self).__init__() + self.window_size = 2 * h_patch_size + 1 + self.mode = mode # 'rgb' or 'gray' + self.channel = 3 + self.register_buffer("window", create_window(self.window_size, self.channel)) + + def forward(self, img_pred, img_gt): + """ + :param img_pred: [Npx, nviews, npatch, c] + :param img_gt: [Npx, npatch, c] + :return: + """ + ntotpx, nviews, npatch, channels = img_pred.shape + + patch_size = int(sqrt(npatch)) + patch_img_pred = img_pred.reshape(ntotpx, nviews, patch_size, patch_size, channels).permute(0, 1, 4, 2, + 3).contiguous() + patch_img_gt = img_gt.reshape(ntotpx, patch_size, patch_size, channels).permute(0, 3, 1, 2) + + return _ncc(patch_img_pred, patch_img_gt, self.window, self.channel) + + +def gaussian(window_size, sigma): + gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)]) + return gauss / gauss.sum() + + +def create_window(window_size, channel, std=1.5): + _1D_window = gaussian(window_size, std).unsqueeze(1) + _2D_window = _1D_window.mm(_1D_window.t()).unsqueeze(0).unsqueeze(0) + window = _2D_window.expand(channel, 1, window_size, window_size).contiguous() + return window + + +def _ncc(pred, gt, window, channel): + ntotpx, nviews, nc, h, w = pred.shape + flat_pred = pred.view(-1, nc, h, w) + mu1 = F.conv2d(flat_pred, window, padding=0, groups=channel).view(ntotpx, nviews, nc) + mu2 = F.conv2d(gt, window, padding=0, groups=channel).view(ntotpx, nc) + + mu1_sq = mu1.pow(2) + mu2_sq = mu2.pow(2).unsqueeze(1) # (ntotpx, 1, nc) + + sigma1_sq = F.conv2d(flat_pred * flat_pred, window, padding=0, groups=channel).view(ntotpx, nviews, nc) - mu1_sq + sigma2_sq = F.conv2d(gt * gt, window, padding=0, groups=channel).view(ntotpx, 1, 3) - mu2_sq + + sigma1 = torch.sqrt(sigma1_sq + 1e-4) + sigma2 = torch.sqrt(sigma2_sq + 1e-4) + + pred_norm = (pred - mu1[:, :, :, None, None]) / (sigma1[:, :, :, None, None] + 1e-8) # [ntotpx, nviews, nc, h, w] + gt_norm = (gt[:, None, :, :, :] - mu2[:, None, :, None, None]) / ( + sigma2[:, :, :, None, None] + 1e-8) # ntotpx, nc, h, w + + ncc = F.conv2d((pred_norm * gt_norm).view(-1, nc, h, w), window, padding=0, groups=channel).view( + ntotpx, nviews, nc) + + return torch.mean(ncc, dim=2) diff --git a/SparseNeuS_demo_v1/models/__init__.py b/SparseNeuS_demo_v1/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/SparseNeuS_demo_v1/models/embedder.py b/SparseNeuS_demo_v1/models/embedder.py new file mode 100644 index 0000000000000000000000000000000000000000..d327d92d9f64c0b32908dbee864160b65daa450e --- /dev/null +++ b/SparseNeuS_demo_v1/models/embedder.py @@ -0,0 +1,101 @@ +import torch +import torch.nn as nn + +""" Positional encoding embedding. Code was taken from https://github.com/bmild/nerf. """ + + +class Embedder: + def __init__(self, **kwargs): + self.kwargs = kwargs + self.create_embedding_fn() + + def create_embedding_fn(self): + embed_fns = [] + d = self.kwargs['input_dims'] + out_dim = 0 + if self.kwargs['include_input']: + embed_fns.append(lambda x: x) + out_dim += d + + max_freq = self.kwargs['max_freq_log2'] + N_freqs = self.kwargs['num_freqs'] + + if self.kwargs['log_sampling']: + freq_bands = 2. ** torch.linspace(0., max_freq, N_freqs) + else: + freq_bands = torch.linspace(2. ** 0., 2. ** max_freq, N_freqs) + + for freq in freq_bands: + for p_fn in self.kwargs['periodic_fns']: + if self.kwargs['normalize']: + embed_fns.append(lambda x, p_fn=p_fn, + freq=freq: p_fn(x * freq) / freq) + else: + embed_fns.append(lambda x, p_fn=p_fn, + freq=freq: p_fn(x * freq)) + out_dim += d + + self.embed_fns = embed_fns + self.out_dim = out_dim + + def embed(self, inputs): + return torch.cat([fn(inputs) for fn in self.embed_fns], -1) + + +def get_embedder(multires, normalize=False, input_dims=3): + embed_kwargs = { + 'include_input': True, + 'input_dims': input_dims, + 'max_freq_log2': multires - 1, + 'num_freqs': multires, + 'normalize': normalize, + 'log_sampling': True, + 'periodic_fns': [torch.sin, torch.cos], + } + + embedder_obj = Embedder(**embed_kwargs) + + def embed(x, eo=embedder_obj): return eo.embed(x) + + return embed, embedder_obj.out_dim + + +class Embedding(nn.Module): + def __init__(self, in_channels, N_freqs, logscale=True, normalize=False): + """ + Defines a function that embeds x to (x, sin(2^k x), cos(2^k x), ...) + in_channels: number of input channels (3 for both xyz and direction) + """ + super(Embedding, self).__init__() + self.N_freqs = N_freqs + self.in_channels = in_channels + self.funcs = [torch.sin, torch.cos] + self.out_channels = in_channels * (len(self.funcs) * N_freqs + 1) + self.normalize = normalize + + if logscale: + self.freq_bands = 2 ** torch.linspace(0, N_freqs - 1, N_freqs) + else: + self.freq_bands = torch.linspace(1, 2 ** (N_freqs - 1), N_freqs) + + def forward(self, x): + """ + Embeds x to (x, sin(2^k x), cos(2^k x), ...) + Different from the paper, "x" is also in the output + See https://github.com/bmild/nerf/issues/12 + + Inputs: + x: (B, self.in_channels) + + Outputs: + out: (B, self.out_channels) + """ + out = [x] + for freq in self.freq_bands: + for func in self.funcs: + if self.normalize: + out += [func(freq * x) / freq] + else: + out += [func(freq * x)] + + return torch.cat(out, -1) diff --git a/SparseNeuS_demo_v1/models/fast_renderer.py b/SparseNeuS_demo_v1/models/fast_renderer.py new file mode 100644 index 0000000000000000000000000000000000000000..1faeba85e5b156d0de12e430287d90f4a803aa92 --- /dev/null +++ b/SparseNeuS_demo_v1/models/fast_renderer.py @@ -0,0 +1,316 @@ +import torch +import torch.nn.functional as F +import torch.nn as nn +from icecream import ic + + +# - neus: use sphere-tracing to speed up depth maps extraction +# This code snippet is heavily borrowed from IDR. +class FastRenderer(nn.Module): + def __init__(self): + super(FastRenderer, self).__init__() + + self.sdf_threshold = 5e-5 + self.line_search_step = 0.5 + self.line_step_iters = 1 + self.sphere_tracing_iters = 10 + self.n_steps = 100 + self.n_secant_steps = 8 + + # - use sdf_network to inference sdf value or directly interpolate sdf value from precomputed sdf_volume + self.network_inference = False + + def extract_depth_maps(self, rays_o, rays_d, near, far, sdf_network, conditional_volume): + with torch.no_grad(): + curr_start_points, network_object_mask, acc_start_dis = self.get_intersection( + rays_o, rays_d, near, far, + sdf_network, conditional_volume) + + network_object_mask = network_object_mask.reshape(-1) + + return network_object_mask, acc_start_dis + + def get_intersection(self, rays_o, rays_d, near, far, sdf_network, conditional_volume): + device = rays_o.device + num_pixels, _ = rays_d.shape + + curr_start_points, unfinished_mask_start, acc_start_dis, acc_end_dis, min_dis, max_dis = \ + self.sphere_tracing(rays_o, rays_d, near, far, sdf_network, conditional_volume) + + network_object_mask = (acc_start_dis < acc_end_dis) + + # The non convergent rays should be handled by the sampler + sampler_mask = unfinished_mask_start + sampler_net_obj_mask = torch.zeros_like(sampler_mask).bool().to(device) + if sampler_mask.sum() > 0: + # sampler_min_max = torch.zeros((num_pixels, 2)).to(device) + # sampler_min_max[sampler_mask, 0] = acc_start_dis[sampler_mask] + # sampler_min_max[sampler_mask, 1] = acc_end_dis[sampler_mask] + + # ray_sampler(self, rays_o, rays_d, near, far, sampler_mask): + sampler_pts, sampler_net_obj_mask, sampler_dists = self.ray_sampler(rays_o, + rays_d, + acc_start_dis, + acc_end_dis, + sampler_mask, + sdf_network, + conditional_volume + ) + + curr_start_points[sampler_mask] = sampler_pts[sampler_mask] + acc_start_dis[sampler_mask] = sampler_dists[sampler_mask][:, None] + network_object_mask[sampler_mask] = sampler_net_obj_mask[sampler_mask][:, None] + + # print('----------------------------------------------------------------') + # print('RayTracing: object = {0}/{1}, secant on {2}/{3}.' + # .format(network_object_mask.sum(), len(network_object_mask), sampler_net_obj_mask.sum(), + # sampler_mask.sum())) + # print('----------------------------------------------------------------') + + return curr_start_points, network_object_mask, acc_start_dis + + def sphere_tracing(self, rays_o, rays_d, near, far, sdf_network, conditional_volume): + ''' Run sphere tracing algorithm for max iterations from both sides of unit sphere intersection ''' + + device = rays_o.device + + unfinished_mask_start = (near < far).reshape(-1).clone() + unfinished_mask_end = (near < far).reshape(-1).clone() + + # Initialize start current points + curr_start_points = rays_o + rays_d * near + acc_start_dis = near.clone() + + # Initialize end current points + curr_end_points = rays_o + rays_d * far + acc_end_dis = far.clone() + + # Initizlize min and max depth + min_dis = acc_start_dis.clone() + max_dis = acc_end_dis.clone() + + # Iterate on the rays (from both sides) till finding a surface + iters = 0 + + next_sdf_start = torch.zeros_like(acc_start_dis).to(device) + + if self.network_inference: + sdf_func = sdf_network.sdf + else: + sdf_func = sdf_network.sdf_from_sdfvolume + + next_sdf_start[unfinished_mask_start] = sdf_func( + curr_start_points[unfinished_mask_start], + conditional_volume, lod=0, gru_fusion=False)['sdf_pts_scale%d' % 0] + + next_sdf_end = torch.zeros_like(acc_end_dis).to(device) + next_sdf_end[unfinished_mask_end] = sdf_func(curr_end_points[unfinished_mask_end], + conditional_volume, lod=0, gru_fusion=False)[ + 'sdf_pts_scale%d' % 0] + + while True: + # Update sdf + curr_sdf_start = torch.zeros_like(acc_start_dis).to(device) + curr_sdf_start[unfinished_mask_start] = next_sdf_start[unfinished_mask_start] + curr_sdf_start[curr_sdf_start <= self.sdf_threshold] = 0 + + curr_sdf_end = torch.zeros_like(acc_end_dis).to(device) + curr_sdf_end[unfinished_mask_end] = next_sdf_end[unfinished_mask_end] + curr_sdf_end[curr_sdf_end <= self.sdf_threshold] = 0 + + # Update masks + unfinished_mask_start = unfinished_mask_start & (curr_sdf_start > self.sdf_threshold).reshape(-1) + unfinished_mask_end = unfinished_mask_end & (curr_sdf_end > self.sdf_threshold).reshape(-1) + + if ( + unfinished_mask_start.sum() == 0 and unfinished_mask_end.sum() == 0) or iters == self.sphere_tracing_iters: + break + iters += 1 + + # Make step + # Update distance + acc_start_dis = acc_start_dis + curr_sdf_start + acc_end_dis = acc_end_dis - curr_sdf_end + + # Update points + curr_start_points = rays_o + acc_start_dis * rays_d + curr_end_points = rays_o + acc_end_dis * rays_d + + # Fix points which wrongly crossed the surface + next_sdf_start = torch.zeros_like(acc_start_dis).to(device) + if unfinished_mask_start.sum() > 0: + next_sdf_start[unfinished_mask_start] = sdf_func(curr_start_points[unfinished_mask_start], + conditional_volume, lod=0, gru_fusion=False)[ + 'sdf_pts_scale%d' % 0] + + next_sdf_end = torch.zeros_like(acc_end_dis).to(device) + if unfinished_mask_end.sum() > 0: + next_sdf_end[unfinished_mask_end] = sdf_func(curr_end_points[unfinished_mask_end], + conditional_volume, lod=0, gru_fusion=False)[ + 'sdf_pts_scale%d' % 0] + + not_projected_start = (next_sdf_start < 0).reshape(-1) + not_projected_end = (next_sdf_end < 0).reshape(-1) + not_proj_iters = 0 + + while ( + not_projected_start.sum() > 0 or not_projected_end.sum() > 0) and not_proj_iters < self.line_step_iters: + # Step backwards + if not_projected_start.sum() > 0: + acc_start_dis[not_projected_start] -= ((1 - self.line_search_step) / (2 ** not_proj_iters)) * \ + curr_sdf_start[not_projected_start] + curr_start_points[not_projected_start] = (rays_o + acc_start_dis * rays_d)[not_projected_start] + + next_sdf_start[not_projected_start] = sdf_func( + curr_start_points[not_projected_start], + conditional_volume, lod=0, gru_fusion=False)['sdf_pts_scale%d' % 0] + + if not_projected_end.sum() > 0: + acc_end_dis[not_projected_end] += ((1 - self.line_search_step) / (2 ** not_proj_iters)) * \ + curr_sdf_end[ + not_projected_end] + curr_end_points[not_projected_end] = (rays_o + acc_end_dis * rays_d)[not_projected_end] + + # Calc sdf + + next_sdf_end[not_projected_end] = sdf_func( + curr_end_points[not_projected_end], + conditional_volume, lod=0, gru_fusion=False)['sdf_pts_scale%d' % 0] + + # Update mask + not_projected_start = (next_sdf_start < 0).reshape(-1) + not_projected_end = (next_sdf_end < 0).reshape(-1) + not_proj_iters += 1 + + unfinished_mask_start = unfinished_mask_start & (acc_start_dis < acc_end_dis).reshape(-1) + unfinished_mask_end = unfinished_mask_end & (acc_start_dis < acc_end_dis).reshape(-1) + + return curr_start_points, unfinished_mask_start, acc_start_dis, acc_end_dis, min_dis, max_dis + + def ray_sampler(self, rays_o, rays_d, near, far, sampler_mask, sdf_network, conditional_volume): + ''' Sample the ray in a given range and run secant on rays which have sign transition ''' + device = rays_o.device + num_pixels, _ = rays_d.shape + sampler_pts = torch.zeros(num_pixels, 3).to(device).float() + sampler_dists = torch.zeros(num_pixels).to(device).float() + + intervals_dist = torch.linspace(0, 1, steps=self.n_steps).to(device).view(1, -1) + + pts_intervals = near + intervals_dist * (far - near) + points = rays_o[:, None, :] + pts_intervals[:, :, None] * rays_d[:, None, :] + + # Get the non convergent rays + mask_intersect_idx = torch.nonzero(sampler_mask).flatten() + points = points.reshape((-1, self.n_steps, 3))[sampler_mask, :, :] + pts_intervals = pts_intervals.reshape((-1, self.n_steps))[sampler_mask] + + if self.network_inference: + sdf_func = sdf_network.sdf + else: + sdf_func = sdf_network.sdf_from_sdfvolume + + sdf_val_all = [] + for pnts in torch.split(points.reshape(-1, 3), 100000, dim=0): + sdf_val_all.append(sdf_func(pnts, + conditional_volume, lod=0, gru_fusion=False)['sdf_pts_scale%d' % 0]) + sdf_val = torch.cat(sdf_val_all).reshape(-1, self.n_steps) + + tmp = torch.sign(sdf_val) * torch.arange(self.n_steps, 0, -1).to(device).float().reshape( + (1, self.n_steps)) # Force argmin to return the first min value + sampler_pts_ind = torch.argmin(tmp, -1) + sampler_pts[mask_intersect_idx] = points[torch.arange(points.shape[0]), sampler_pts_ind, :] + sampler_dists[mask_intersect_idx] = pts_intervals[torch.arange(pts_intervals.shape[0]), sampler_pts_ind] + + net_surface_pts = (sdf_val[torch.arange(sdf_val.shape[0]), sampler_pts_ind] < 0) + + # take points with minimal SDF value for P_out pixels + p_out_mask = ~net_surface_pts + n_p_out = p_out_mask.sum() + if n_p_out > 0: + out_pts_idx = torch.argmin(sdf_val[p_out_mask, :], -1) + sampler_pts[mask_intersect_idx[p_out_mask]] = points[p_out_mask, :, :][torch.arange(n_p_out), out_pts_idx, + :] + sampler_dists[mask_intersect_idx[p_out_mask]] = pts_intervals[p_out_mask, :][ + torch.arange(n_p_out), out_pts_idx] + + # Get Network object mask + sampler_net_obj_mask = sampler_mask.clone() + sampler_net_obj_mask[mask_intersect_idx[~net_surface_pts]] = False + + # Run Secant method + secant_pts = net_surface_pts + n_secant_pts = secant_pts.sum() + if n_secant_pts > 0: + # Get secant z predictions + z_high = pts_intervals[torch.arange(pts_intervals.shape[0]), sampler_pts_ind][secant_pts] + sdf_high = sdf_val[torch.arange(sdf_val.shape[0]), sampler_pts_ind][secant_pts] + z_low = pts_intervals[secant_pts][torch.arange(n_secant_pts), sampler_pts_ind[secant_pts] - 1] + sdf_low = sdf_val[secant_pts][torch.arange(n_secant_pts), sampler_pts_ind[secant_pts] - 1] + + cam_loc_secant = rays_o[mask_intersect_idx[secant_pts]] + ray_directions_secant = rays_d[mask_intersect_idx[secant_pts]] + z_pred_secant = self.secant(sdf_low, sdf_high, z_low, z_high, cam_loc_secant, ray_directions_secant, + sdf_network, conditional_volume) + + # Get points + sampler_pts[mask_intersect_idx[secant_pts]] = cam_loc_secant + z_pred_secant[:, + None] * ray_directions_secant + sampler_dists[mask_intersect_idx[secant_pts]] = z_pred_secant + + return sampler_pts, sampler_net_obj_mask, sampler_dists + + def secant(self, sdf_low, sdf_high, z_low, z_high, rays_o, rays_d, sdf_network, conditional_volume): + ''' Runs the secant method for interval [z_low, z_high] for n_secant_steps ''' + + if self.network_inference: + sdf_func = sdf_network.sdf + else: + sdf_func = sdf_network.sdf_from_sdfvolume + + z_pred = -sdf_low * (z_high - z_low) / (sdf_high - sdf_low) + z_low + for i in range(self.n_secant_steps): + p_mid = rays_o + z_pred[:, None] * rays_d + sdf_mid = sdf_func(p_mid, + conditional_volume, lod=0, gru_fusion=False)['sdf_pts_scale%d' % 0].reshape(-1) + ind_low = (sdf_mid > 0).reshape(-1) + if ind_low.sum() > 0: + z_low[ind_low] = z_pred[ind_low] + sdf_low[ind_low] = sdf_mid[ind_low] + ind_high = sdf_mid < 0 + if ind_high.sum() > 0: + z_high[ind_high] = z_pred[ind_high] + sdf_high[ind_high] = sdf_mid[ind_high] + + z_pred = - sdf_low * (z_high - z_low) / (sdf_high - sdf_low) + z_low + + return z_pred # 1D tensor + + def minimal_sdf_points(self, num_pixels, sdf, cam_loc, ray_directions, mask, min_dis, max_dis): + ''' Find points with minimal SDF value on rays for P_out pixels ''' + device = sdf.device + n_mask_points = mask.sum() + + n = self.n_steps + # steps = torch.linspace(0.0, 1.0,n).to(device) + steps = torch.empty(n).uniform_(0.0, 1.0).to(device) + mask_max_dis = max_dis[mask].unsqueeze(-1) + mask_min_dis = min_dis[mask].unsqueeze(-1) + steps = steps.unsqueeze(0).repeat(n_mask_points, 1) * (mask_max_dis - mask_min_dis) + mask_min_dis + + mask_points = cam_loc.unsqueeze(1).repeat(1, num_pixels, 1).reshape(-1, 3)[mask] + mask_rays = ray_directions[mask, :] + + mask_points_all = mask_points.unsqueeze(1).repeat(1, n, 1) + steps.unsqueeze(-1) * mask_rays.unsqueeze( + 1).repeat(1, n, 1) + points = mask_points_all.reshape(-1, 3) + + mask_sdf_all = [] + for pnts in torch.split(points, 100000, dim=0): + mask_sdf_all.append(sdf(pnts)) + + mask_sdf_all = torch.cat(mask_sdf_all).reshape(-1, n) + min_vals, min_idx = mask_sdf_all.min(-1) + min_mask_points = mask_points_all.reshape(-1, n, 3)[torch.arange(0, n_mask_points), min_idx] + min_mask_dist = steps.reshape(-1, n)[torch.arange(0, n_mask_points), min_idx] + + return min_mask_points, min_mask_dist diff --git a/SparseNeuS_demo_v1/models/featurenet.py b/SparseNeuS_demo_v1/models/featurenet.py new file mode 100644 index 0000000000000000000000000000000000000000..652e65967708f57a1722c5951d53e72f05ddf1d3 --- /dev/null +++ b/SparseNeuS_demo_v1/models/featurenet.py @@ -0,0 +1,91 @@ +import torch + +# ! amazing!!!! autograd.grad with set_detect_anomaly(True) will cause memory leak +# ! https://github.com/pytorch/pytorch/issues/51349 +# torch.autograd.set_detect_anomaly(True) +import torch.nn as nn +import torch.nn.functional as F +from inplace_abn import InPlaceABN + + +############################################# MVS Net models ################################################ +class ConvBnReLU(nn.Module): + def __init__(self, in_channels, out_channels, + kernel_size=3, stride=1, pad=1, + norm_act=InPlaceABN): + super(ConvBnReLU, self).__init__() + self.conv = nn.Conv2d(in_channels, out_channels, + kernel_size, stride=stride, padding=pad, bias=False) + self.bn = norm_act(out_channels) + + def forward(self, x): + return self.bn(self.conv(x)) + + +class ConvBnReLU3D(nn.Module): + def __init__(self, in_channels, out_channels, + kernel_size=3, stride=1, pad=1, + norm_act=InPlaceABN): + super(ConvBnReLU3D, self).__init__() + self.conv = nn.Conv3d(in_channels, out_channels, + kernel_size, stride=stride, padding=pad, bias=False) + self.bn = norm_act(out_channels) + # self.bn = nn.ReLU() + + def forward(self, x): + return self.bn(self.conv(x)) + + +################################### feature net ###################################### +class FeatureNet(nn.Module): + """ + output 3 levels of features using a FPN structure + """ + + def __init__(self, norm_act=InPlaceABN): + super(FeatureNet, self).__init__() + + self.conv0 = nn.Sequential( + ConvBnReLU(3, 8, 3, 1, 1, norm_act=norm_act), + ConvBnReLU(8, 8, 3, 1, 1, norm_act=norm_act)) + + self.conv1 = nn.Sequential( + ConvBnReLU(8, 16, 5, 2, 2, norm_act=norm_act), + ConvBnReLU(16, 16, 3, 1, 1, norm_act=norm_act), + ConvBnReLU(16, 16, 3, 1, 1, norm_act=norm_act)) + + self.conv2 = nn.Sequential( + ConvBnReLU(16, 32, 5, 2, 2, norm_act=norm_act), + ConvBnReLU(32, 32, 3, 1, 1, norm_act=norm_act), + ConvBnReLU(32, 32, 3, 1, 1, norm_act=norm_act)) + + self.toplayer = nn.Conv2d(32, 32, 1) + self.lat1 = nn.Conv2d(16, 32, 1) + self.lat0 = nn.Conv2d(8, 32, 1) + + # to reduce channel size of the outputs from FPN + self.smooth1 = nn.Conv2d(32, 16, 3, padding=1) + self.smooth0 = nn.Conv2d(32, 8, 3, padding=1) + + def _upsample_add(self, x, y): + return F.interpolate(x, scale_factor=2, + mode="bilinear", align_corners=True) + y + + def forward(self, x): + # x: (B, 3, H, W) + conv0 = self.conv0(x) # (B, 8, H, W) + conv1 = self.conv1(conv0) # (B, 16, H//2, W//2) + conv2 = self.conv2(conv1) # (B, 32, H//4, W//4) + feat2 = self.toplayer(conv2) # (B, 32, H//4, W//4) + feat1 = self._upsample_add(feat2, self.lat1(conv1)) # (B, 32, H//2, W//2) + feat0 = self._upsample_add(feat1, self.lat0(conv0)) # (B, 32, H, W) + + # reduce output channels + feat1 = self.smooth1(feat1) # (B, 16, H//2, W//2) + feat0 = self.smooth0(feat0) # (B, 8, H, W) + + # feats = {"level_0": feat0, + # "level_1": feat1, + # "level_2": feat2} + + return [feat2, feat1, feat0] # coarser to finer features diff --git a/SparseNeuS_demo_v1/models/fields.py b/SparseNeuS_demo_v1/models/fields.py new file mode 100644 index 0000000000000000000000000000000000000000..184e4a55399f56f8f505379ce4a14add8821c4c4 --- /dev/null +++ b/SparseNeuS_demo_v1/models/fields.py @@ -0,0 +1,333 @@ +# The codes are from NeuS + +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from models.embedder import get_embedder + + +class SDFNetwork(nn.Module): + def __init__(self, + d_in, + d_out, + d_hidden, + n_layers, + skip_in=(4,), + multires=0, + bias=0.5, + scale=1, + geometric_init=True, + weight_norm=True, + activation='softplus', + conditional_type='multiply'): + super(SDFNetwork, self).__init__() + + dims = [d_in] + [d_hidden for _ in range(n_layers)] + [d_out] + + self.embed_fn_fine = None + + if multires > 0: + embed_fn, input_ch = get_embedder(multires, input_dims=d_in, normalize=False) + self.embed_fn_fine = embed_fn + dims[0] = input_ch + + self.num_layers = len(dims) + self.skip_in = skip_in + self.scale = scale + + for l in range(0, self.num_layers - 1): + if l + 1 in self.skip_in: + out_dim = dims[l + 1] - dims[0] + else: + out_dim = dims[l + 1] + + lin = nn.Linear(dims[l], out_dim) + + if geometric_init: + if l == self.num_layers - 2: + torch.nn.init.normal_(lin.weight, mean=np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001) + torch.nn.init.constant_(lin.bias, -bias) + elif multires > 0 and l == 0: + torch.nn.init.constant_(lin.bias, 0.0) + torch.nn.init.constant_(lin.weight[:, 3:], 0.0) + torch.nn.init.normal_(lin.weight[:, :3], 0.0, np.sqrt(2) / np.sqrt(out_dim)) + elif multires > 0 and l in self.skip_in: + torch.nn.init.constant_(lin.bias, 0.0) + torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim)) + torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3):], 0.0) # ? why dims[0] - 3 + else: + torch.nn.init.constant_(lin.bias, 0.0) + torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim)) + + if weight_norm: + lin = nn.utils.weight_norm(lin) + + setattr(self, "lin" + str(l), lin) + + if activation == 'softplus': + self.activation = nn.Softplus(beta=100) + else: + assert activation == 'relu' + self.activation = nn.ReLU() + + def forward(self, inputs): + inputs = inputs * self.scale + if self.embed_fn_fine is not None: + inputs = self.embed_fn_fine(inputs) + + x = inputs + for l in range(0, self.num_layers - 1): + lin = getattr(self, "lin" + str(l)) + + if l in self.skip_in: + x = torch.cat([x, inputs], 1) / np.sqrt(2) + + x = lin(x) + + if l < self.num_layers - 2: + x = self.activation(x) + return torch.cat([x[:, :1] / self.scale, x[:, 1:]], dim=-1) + + def sdf(self, x): + return self.forward(x)[:, :1] + + def sdf_hidden_appearance(self, x): + return self.forward(x) + + def gradient(self, x): + x.requires_grad_(True) + y = self.sdf(x) + d_output = torch.ones_like(y, requires_grad=False, device=y.device) + gradients = torch.autograd.grad( + outputs=y, + inputs=x, + grad_outputs=d_output, + create_graph=True, + retain_graph=True, + only_inputs=True)[0] + return gradients.unsqueeze(1) + + +class VarianceNetwork(nn.Module): + def __init__(self, d_in, d_out, d_hidden, n_layers, skip_in=(4,), multires=0): + super(VarianceNetwork, self).__init__() + + dims = [d_in] + [d_hidden for _ in range(n_layers)] + [d_out] + + self.embed_fn_fine = None + + if multires > 0: + embed_fn, input_ch = get_embedder(multires, normalize=False) + self.embed_fn_fine = embed_fn + dims[0] = input_ch + + self.num_layers = len(dims) + self.skip_in = skip_in + + for l in range(0, self.num_layers - 1): + if l + 1 in self.skip_in: + out_dim = dims[l + 1] - dims[0] + else: + out_dim = dims[l + 1] + + lin = nn.Linear(dims[l], out_dim) + setattr(self, "lin" + str(l), lin) + + self.relu = nn.ReLU() + self.softplus = nn.Softplus(beta=100) + + def forward(self, inputs): + if self.embed_fn_fine is not None: + inputs = self.embed_fn_fine(inputs) + + x = inputs + for l in range(0, self.num_layers - 1): + lin = getattr(self, "lin" + str(l)) + + if l in self.skip_in: + x = torch.cat([x, inputs], 1) / np.sqrt(2) + + x = lin(x) + + if l < self.num_layers - 2: + x = self.relu(x) + + # return torch.exp(x) + return 1.0 / (self.softplus(x + 0.5) + 1e-3) + + def coarse(self, inputs): + return self.forward(inputs)[:, :1] + + def fine(self, inputs): + return self.forward(inputs)[:, 1:] + + +class FixVarianceNetwork(nn.Module): + def __init__(self, base): + super(FixVarianceNetwork, self).__init__() + self.base = base + self.iter_step = 0 + + def set_iter_step(self, iter_step): + self.iter_step = iter_step + + def forward(self, x): + return torch.ones([len(x), 1]) * np.exp(-self.iter_step / self.base) + + +class SingleVarianceNetwork(nn.Module): + def __init__(self, init_val=1.0): + super(SingleVarianceNetwork, self).__init__() + self.register_parameter('variance', nn.Parameter(torch.tensor(init_val))) + + def forward(self, x): + return torch.ones([len(x), 1]).to(x.device) * torch.exp(self.variance * 10.0) + + + +class RenderingNetwork(nn.Module): + def __init__( + self, + d_feature, + mode, + d_in, + d_out, + d_hidden, + n_layers, + weight_norm=True, + multires_view=0, + squeeze_out=True, + d_conditional_colors=0 + ): + super().__init__() + + self.mode = mode + self.squeeze_out = squeeze_out + dims = [d_in + d_feature] + [d_hidden for _ in range(n_layers)] + [d_out] + + self.embedview_fn = None + if multires_view > 0: + embedview_fn, input_ch = get_embedder(multires_view) + self.embedview_fn = embedview_fn + dims[0] += (input_ch - 3) + + self.num_layers = len(dims) + + for l in range(0, self.num_layers - 1): + out_dim = dims[l + 1] + lin = nn.Linear(dims[l], out_dim) + + if weight_norm: + lin = nn.utils.weight_norm(lin) + + setattr(self, "lin" + str(l), lin) + + self.relu = nn.ReLU() + + def forward(self, points, normals, view_dirs, feature_vectors): + if self.embedview_fn is not None: + view_dirs = self.embedview_fn(view_dirs) + + rendering_input = None + + if self.mode == 'idr': + rendering_input = torch.cat([points, view_dirs, normals, feature_vectors], dim=-1) + elif self.mode == 'no_view_dir': + rendering_input = torch.cat([points, normals, feature_vectors], dim=-1) + elif self.mode == 'no_normal': + rendering_input = torch.cat([points, view_dirs, feature_vectors], dim=-1) + elif self.mode == 'no_points': + rendering_input = torch.cat([view_dirs, normals, feature_vectors], dim=-1) + elif self.mode == 'no_points_no_view_dir': + rendering_input = torch.cat([normals, feature_vectors], dim=-1) + + x = rendering_input + + for l in range(0, self.num_layers - 1): + lin = getattr(self, "lin" + str(l)) + + x = lin(x) + + if l < self.num_layers - 2: + x = self.relu(x) + + if self.squeeze_out: + x = torch.sigmoid(x) + return x + + +# Code from nerf-pytorch +class NeRF(nn.Module): + def __init__(self, D=8, W=256, d_in=3, d_in_view=3, multires=0, multires_view=0, output_ch=4, skips=[4], + use_viewdirs=False): + """ + """ + super(NeRF, self).__init__() + self.D = D + self.W = W + self.d_in = d_in + self.d_in_view = d_in_view + self.input_ch = 3 + self.input_ch_view = 3 + self.embed_fn = None + self.embed_fn_view = None + + if multires > 0: + embed_fn, input_ch = get_embedder(multires, input_dims=d_in, normalize=False) + self.embed_fn = embed_fn + self.input_ch = input_ch + + if multires_view > 0: + embed_fn_view, input_ch_view = get_embedder(multires_view, input_dims=d_in_view, normalize=False) + self.embed_fn_view = embed_fn_view + self.input_ch_view = input_ch_view + + self.skips = skips + self.use_viewdirs = use_viewdirs + + self.pts_linears = nn.ModuleList( + [nn.Linear(self.input_ch, W)] + [nn.Linear(W, W) if i not in self.skips else nn.Linear(W + self.input_ch, W) + for i in + range(D - 1)]) + + ### Implementation according to the official code release (https://github.com/bmild/nerf/blob/master/run_nerf_helpers.py#L104-L105) + self.views_linears = nn.ModuleList([nn.Linear(self.input_ch_view + W, W // 2)]) + + ### Implementation according to the paper + # self.views_linears = nn.ModuleList( + # [nn.Linear(input_ch_views + W, W//2)] + [nn.Linear(W//2, W//2) for i in range(D//2)]) + + if use_viewdirs: + self.feature_linear = nn.Linear(W, W) + self.alpha_linear = nn.Linear(W, 1) + self.rgb_linear = nn.Linear(W // 2, 3) + else: + self.output_linear = nn.Linear(W, output_ch) + + def forward(self, input_pts, input_views): + if self.embed_fn is not None: + input_pts = self.embed_fn(input_pts) + if self.embed_fn_view is not None: + input_views = self.embed_fn_view(input_views) + + h = input_pts + for i, l in enumerate(self.pts_linears): + h = self.pts_linears[i](h) + h = F.relu(h) + if i in self.skips: + h = torch.cat([input_pts, h], -1) + + if self.use_viewdirs: + alpha = self.alpha_linear(h) + feature = self.feature_linear(h) + h = torch.cat([feature, input_views], -1) + + for i, l in enumerate(self.views_linears): + h = self.views_linears[i](h) + h = F.relu(h) + + rgb = self.rgb_linear(h) + return alpha + 1.0, rgb + else: + assert False diff --git a/SparseNeuS_demo_v1/models/patch_projector.py b/SparseNeuS_demo_v1/models/patch_projector.py new file mode 100644 index 0000000000000000000000000000000000000000..cf9ca424c588e49d754988814233069b2cf127fa --- /dev/null +++ b/SparseNeuS_demo_v1/models/patch_projector.py @@ -0,0 +1,211 @@ +""" +Patch Projector +""" +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from models.render_utils import sample_ptsFeatures_from_featureMaps + + +class PatchProjector(): + def __init__(self, patch_size): + self.h_patch_size = patch_size + self.offsets = build_patch_offset(patch_size) # the warping patch offsets index + + self.z_axis = torch.tensor([0, 0, 1]).float() + + self.plane_dist_thresh = 0.001 + + # * correctness checked + def pixel_warp(self, pts, imgs, intrinsics, + w2cs, img_wh=None): + """ + + :param pts: [N_rays, n_samples, 3] + :param imgs: [N_views, 3, H, W] + :param intrinsics: [N_views, 4, 4] + :param c2ws: [N_views, 4, 4] + :param img_wh: + :return: + """ + if img_wh is None: + N_views, _, sizeH, sizeW = imgs.shape + img_wh = [sizeW, sizeH] + + pts_color, valid_mask = sample_ptsFeatures_from_featureMaps( + pts, imgs, w2cs, intrinsics, img_wh, + proj_matrix=None, return_mask=True) # [N_views, c, N_rays, n_samples], [N_views, N_rays, n_samples] + + pts_color = pts_color.permute(2, 3, 0, 1) + valid_mask = valid_mask.permute(1, 2, 0) + + return pts_color, valid_mask # [N_rays, n_samples, N_views, 3] , [N_rays, n_samples, N_views] + + def patch_warp(self, pts, uv, normals, src_imgs, + ref_intrinsic, src_intrinsics, + ref_c2w, src_c2ws, img_wh=None + ): + """ + + :param pts: [N_rays, n_samples, 3] + :param uv : [N_rays, 2] normalized in (-1, 1) + :param normals: [N_rays, n_samples, 3] The normal of pt in world space + :param src_imgs: [N_src, 3, h, w] + :param ref_intrinsic: [4,4] + :param src_intrinsics: [N_src, 4, 4] + :param ref_c2w: [4,4] + :param src_c2ws: [N_src, 4, 4] + :return: + """ + device = pts.device + + N_rays, n_samples, _ = pts.shape + N_pts = N_rays * n_samples + + N_src, _, sizeH, sizeW = src_imgs.shape + + if img_wh is not None: + sizeW, sizeH = img_wh[0], img_wh[1] + + # scale uv from (-1, 1) to (0, W/H) + uv[:, 0] = (uv[:, 0] + 1) / 2. * (sizeW - 1) + uv[:, 1] = (uv[:, 1] + 1) / 2. * (sizeH - 1) + + ref_intr = ref_intrinsic[:3, :3] + inv_ref_intr = torch.inverse(ref_intr) + src_intrs = src_intrinsics[:, :3, :3] + inv_src_intrs = torch.inverse(src_intrs) + + ref_pose = ref_c2w + inv_ref_pose = torch.inverse(ref_pose) + src_poses = src_c2ws + inv_src_poses = torch.inverse(src_poses) + + ref_cam_loc = ref_pose[:3, 3].unsqueeze(0) # [1, 3] + sampled_dists = torch.norm(pts - ref_cam_loc, dim=-1) # [N_pts, 1] + + relative_proj = inv_src_poses @ ref_pose + R_rel = relative_proj[:, :3, :3] + t_rel = relative_proj[:, :3, 3:] + R_ref = inv_ref_pose[:3, :3] + t_ref = inv_ref_pose[:3, 3:] + + pts = pts.view(-1, 3) + normals = normals.view(-1, 3) + + with torch.no_grad(): + rot_normals = R_ref @ normals.unsqueeze(-1) # [N_pts, 3, 1] + points_in_ref = R_ref @ pts.unsqueeze( + -1) + t_ref # [N_pts, 3, 1] points in the reference frame coordiantes system + d1 = torch.sum(rot_normals * points_in_ref, dim=1).unsqueeze( + 1) # distance from the plane to ref camera center + + d2 = torch.sum(rot_normals.unsqueeze(1) * (-R_rel.transpose(1, 2) @ t_rel).unsqueeze(0), + dim=2) # distance from the plane to src camera center + valid_hom = (torch.abs(d1) > self.plane_dist_thresh) & ( + torch.abs(d1 - d2) > self.plane_dist_thresh) & ((d2 / d1) < 1) + + d1 = d1.squeeze() + sign = torch.sign(d1) + sign[sign == 0] = 1 + d = torch.clamp(torch.abs(d1), 1e-8) * sign + + H = src_intrs.unsqueeze(1) @ ( + R_rel.unsqueeze(1) + t_rel.unsqueeze(1) @ rot_normals.view(1, N_pts, 1, 3) / d.view(1, + N_pts, + 1, 1) + ) @ inv_ref_intr.view(1, 1, 3, 3) + + # replace invalid homs with fronto-parallel homographies + H_invalid = src_intrs.unsqueeze(1) @ ( + R_rel.unsqueeze(1) + t_rel.unsqueeze(1) @ self.z_axis.to(device).view(1, 1, 1, 3).expand(-1, N_pts, + -1, + -1) / sampled_dists.view( + 1, N_pts, 1, 1) + ) @ inv_ref_intr.view(1, 1, 3, 3) + tmp_m = ~valid_hom.view(-1, N_src).t() + H[tmp_m] = H_invalid[tmp_m] + + pixels = uv.view(N_rays, 1, 2) + self.offsets.float().to(device) + Npx = pixels.shape[1] + grid, warp_mask_full = self.patch_homography(H, pixels) + + warp_mask_full = warp_mask_full & (grid[..., 0] < (sizeW - self.h_patch_size)) & ( + grid[..., 1] < (sizeH - self.h_patch_size)) & (grid >= self.h_patch_size).all(dim=-1) + warp_mask_full = warp_mask_full.view(N_src, N_rays, n_samples, Npx) + + grid = torch.clamp(normalize(grid, sizeH, sizeW), -10, 10) + + sampled_rgb_val = F.grid_sample(src_imgs, grid.view(N_src, -1, 1, 2), align_corners=True).squeeze( + -1).transpose(1, 2) + sampled_rgb_val = sampled_rgb_val.view(N_src, N_rays, n_samples, Npx, 3) + + warp_mask_full = warp_mask_full.permute(1, 2, 0, 3).contiguous() # (N_rays, n_samples, N_src, Npx) + sampled_rgb_val = sampled_rgb_val.permute(1, 2, 0, 3, 4).contiguous() # (N_rays, n_samples, N_src, Npx, 3) + + return sampled_rgb_val, warp_mask_full + + def patch_homography(self, H, uv): + N, Npx = uv.shape[:2] + Nsrc = H.shape[0] + H = H.view(Nsrc, N, -1, 3, 3) + hom_uv = add_hom(uv) + + # einsum is 30 times faster + # tmp = (H.view(Nsrc, N, -1, 1, 3, 3) @ hom_uv.view(1, N, 1, -1, 3, 1)).squeeze(-1).view(Nsrc, -1, 3) + tmp = torch.einsum("vprik,pok->vproi", H, hom_uv).reshape(Nsrc, -1, 3) + + grid = tmp[..., :2] / torch.clamp(tmp[..., 2:], 1e-8) + mask = tmp[..., 2] > 0 + return grid, mask + + +def add_hom(pts): + try: + dev = pts.device + ones = torch.ones(pts.shape[:-1], device=dev).unsqueeze(-1) + return torch.cat((pts, ones), dim=-1) + + except AttributeError: + ones = np.ones((pts.shape[0], 1)) + return np.concatenate((pts, ones), axis=1) + + +def normalize(flow, h, w, clamp=None): + # either h and w are simple float or N torch.tensor where N batch size + try: + h.device + + except AttributeError: + h = torch.tensor(h, device=flow.device).float().unsqueeze(0) + w = torch.tensor(w, device=flow.device).float().unsqueeze(0) + + if len(flow.shape) == 4: + w = w.unsqueeze(1).unsqueeze(2) + h = h.unsqueeze(1).unsqueeze(2) + elif len(flow.shape) == 3: + w = w.unsqueeze(1) + h = h.unsqueeze(1) + elif len(flow.shape) == 5: + w = w.unsqueeze(0).unsqueeze(2).unsqueeze(2) + h = h.unsqueeze(0).unsqueeze(2).unsqueeze(2) + + res = torch.empty_like(flow) + if res.shape[-1] == 3: + res[..., 2] = 1 + + # for grid_sample with align_corners=True + # https://github.com/pytorch/pytorch/blob/c371542efc31b1abfe6f388042aa3ab0cef935f2/aten/src/ATen/native/GridSampler.h#L33 + res[..., 0] = 2 * flow[..., 0] / (w - 1) - 1 + res[..., 1] = 2 * flow[..., 1] / (h - 1) - 1 + + if clamp: + return torch.clamp(res, -clamp, clamp) + else: + return res + + +def build_patch_offset(h_patch_size): + offsets = torch.arange(-h_patch_size, h_patch_size + 1) + return torch.stack(torch.meshgrid(offsets, offsets)[::-1], dim=-1).view(1, -1, 2) # nb_pixels_patch * 2 diff --git a/SparseNeuS_demo_v1/models/projector.py b/SparseNeuS_demo_v1/models/projector.py new file mode 100644 index 0000000000000000000000000000000000000000..aa58d3f896edefff25cbb6fa713e7342d9b84a1d --- /dev/null +++ b/SparseNeuS_demo_v1/models/projector.py @@ -0,0 +1,425 @@ +# The codes are partly from IBRNet + +import torch +import torch.nn.functional as F +from models.render_utils import sample_ptsFeatures_from_featureMaps, sample_ptsFeatures_from_featureVolume + +def safe_l2_normalize(x, dim=None, eps=1e-6): + return F.normalize(x, p=2, dim=dim, eps=eps) + +class Projector(): + """ + Obtain features from geometryVolume and rendering_feature_maps for generalized rendering + """ + + def compute_angle(self, xyz, query_c2w, supporting_c2ws): + """ + + :param xyz: [N_rays, n_samples,3 ] + :param query_c2w: [1,4,4] + :param supporting_c2ws: [n,4,4] + :return: + """ + N_rays, n_samples, _ = xyz.shape + num_views = supporting_c2ws.shape[0] + xyz = xyz.reshape(-1, 3) + + ray2tar_pose = (query_c2w[:, :3, 3].unsqueeze(1) - xyz.unsqueeze(0)) + ray2tar_pose /= (torch.norm(ray2tar_pose, dim=-1, keepdim=True) + 1e-6) + ray2support_pose = (supporting_c2ws[:, :3, 3].unsqueeze(1) - xyz.unsqueeze(0)) + ray2support_pose /= (torch.norm(ray2support_pose, dim=-1, keepdim=True) + 1e-6) + ray_diff = ray2tar_pose - ray2support_pose + ray_diff_norm = torch.norm(ray_diff, dim=-1, keepdim=True) + ray_diff_dot = torch.sum(ray2tar_pose * ray2support_pose, dim=-1, keepdim=True) + ray_diff_direction = ray_diff / torch.clamp(ray_diff_norm, min=1e-6) + ray_diff = torch.cat([ray_diff_direction, ray_diff_dot], dim=-1) + ray_diff = ray_diff.reshape((num_views, N_rays, n_samples, 4)) # the last dimension (4) is dot-product + return ray_diff.detach() + + + def compute_angle_view_independent(self, xyz, surface_normals, supporting_c2ws): + """ + + :param xyz: [N_rays, n_samples,3 ] + :param surface_normals: [N_rays, n_samples,3 ] + :param supporting_c2ws: [n,4,4] + :return: + """ + N_rays, n_samples, _ = xyz.shape + num_views = supporting_c2ws.shape[0] + xyz = xyz.reshape(-1, 3) + + ray2tar_pose = surface_normals + ray2support_pose = (supporting_c2ws[:, :3, 3].unsqueeze(1) - xyz.unsqueeze(0)) + ray2support_pose /= (torch.norm(ray2support_pose, dim=-1, keepdim=True) + 1e-6) + ray_diff = ray2tar_pose - ray2support_pose + ray_diff_norm = torch.norm(ray_diff, dim=-1, keepdim=True) + ray_diff_dot = torch.sum(ray2tar_pose * ray2support_pose, dim=-1, keepdim=True) + ray_diff_direction = ray_diff / torch.clamp(ray_diff_norm, min=1e-6) + ray_diff = torch.cat([ray_diff_direction, ray_diff_dot], dim=-1) + ray_diff = ray_diff.reshape((num_views, N_rays, n_samples, 4)) # the last dimension (4) is dot-product, + # and the first three dimensions is the normalized ray diff vector + return ray_diff.detach() + + @torch.no_grad() + def compute_z_diff(self, xyz, w2cs, intrinsics, pred_depth_values): + """ + compute the depth difference of query pts projected on the image and the predicted depth values of the image + :param xyz: [N_rays, n_samples,3 ] + :param w2cs: [N_views, 4, 4] + :param intrinsics: [N_views, 3, 3] + :param pred_depth_values: [N_views, N_rays, n_samples,1 ] + :param pred_depth_masks: [N_views, N_rays, n_samples] + :return: + """ + device = xyz.device + N_views = w2cs.shape[0] + N_rays, n_samples, _ = xyz.shape + proj_matrix = torch.matmul(intrinsics, w2cs[:, :3, :]) + + proj_rot = proj_matrix[:, :3, :3] + proj_trans = proj_matrix[:, :3, 3:] + + batch_xyz = xyz.permute(2, 0, 1).contiguous().view(1, 3, N_rays * n_samples).repeat(N_views, 1, 1) + + proj_xyz = proj_rot.bmm(batch_xyz) + proj_trans + + # X = proj_xyz[:, 0] + # Y = proj_xyz[:, 1] + Z = proj_xyz[:, 2].clamp(min=1e-3) # [N_views, N_rays*n_samples] + proj_z = Z.view(N_views, N_rays, n_samples, 1) + + z_diff = proj_z - pred_depth_values # [N_views, N_rays, n_samples,1 ] + + return z_diff + + def compute(self, + pts, + # * 3d geometry feature volumes + geometryVolume=None, + geometryVolumeMask=None, + vol_dims=None, + partial_vol_origin=None, + vol_size=None, + # * 2d rendering feature maps + rendering_feature_maps=None, + color_maps=None, + w2cs=None, + intrinsics=None, + img_wh=None, + query_img_idx=0, # the index of the N_views dim for rendering + query_c2w=None, + pred_depth_maps=None, # no use here + pred_depth_masks=None # no use here + ): + """ + extract features of pts for rendering + :param pts: + :param geometryVolume: + :param vol_dims: + :param partial_vol_origin: + :param vol_size: + :param rendering_feature_maps: + :param color_maps: + :param w2cs: + :param intrinsics: + :param img_wh: + :param rendering_img_idx: by default, we render the first view of w2cs + :return: + """ + device = pts.device + c2ws = torch.inverse(w2cs) + + if len(pts.shape) == 2: + pts = pts[None, :, :] + + N_rays, n_samples, _ = pts.shape + N_views = rendering_feature_maps.shape[0] # shape (N_views, C, H, W) + + supporting_img_idxs = torch.LongTensor([x for x in range(N_views) if x != query_img_idx]).to(device) + query_img_idx = torch.LongTensor([query_img_idx]).to(device) + + if query_c2w is None and query_img_idx > -1: + query_c2w = torch.index_select(c2ws, 0, query_img_idx) + supporting_c2ws = torch.index_select(c2ws, 0, supporting_img_idxs) + supporting_w2cs = torch.index_select(w2cs, 0, supporting_img_idxs) + supporting_rendering_feature_maps = torch.index_select(rendering_feature_maps, 0, supporting_img_idxs) + supporting_color_maps = torch.index_select(color_maps, 0, supporting_img_idxs) + supporting_intrinsics = torch.index_select(intrinsics, 0, supporting_img_idxs) + + if pred_depth_maps is not None: + supporting_depth_maps = torch.index_select(pred_depth_maps, 0, supporting_img_idxs) + supporting_depth_masks = torch.index_select(pred_depth_masks, 0, supporting_img_idxs) + # print("N_supporting_views: ", N_views - 1) + N_supporting_views = N_views - 1 + else: + supporting_c2ws = c2ws + supporting_w2cs = w2cs + supporting_rendering_feature_maps = rendering_feature_maps + supporting_color_maps = color_maps + supporting_intrinsics = intrinsics + supporting_depth_maps = pred_depth_masks + supporting_depth_masks = pred_depth_masks + # print("N_supporting_views: ", N_views) + N_supporting_views = N_views + # import ipdb; ipdb.set_trace() + if geometryVolume is not None: + # * sample feature of pts from 3D feature volume + pts_geometry_feature, pts_geometry_masks_0 = sample_ptsFeatures_from_featureVolume( + pts, geometryVolume, vol_dims, + partial_vol_origin, vol_size) # [N_rays, n_samples, C], [N_rays, n_samples] + + if len(geometryVolumeMask.shape) == 3: + geometryVolumeMask = geometryVolumeMask[None, :, :, :] + + pts_geometry_masks_1, _ = sample_ptsFeatures_from_featureVolume( + pts, geometryVolumeMask.to(geometryVolume.dtype), vol_dims, + partial_vol_origin, vol_size) # [N_rays, n_samples, C] + + pts_geometry_masks = pts_geometry_masks_0 & (pts_geometry_masks_1[..., 0] > 0) + else: + pts_geometry_feature = None + pts_geometry_masks = None + + # * sample feature of pts from 2D feature maps + pts_rendering_feats, pts_rendering_mask = sample_ptsFeatures_from_featureMaps( + pts, supporting_rendering_feature_maps, supporting_w2cs, + supporting_intrinsics, img_wh, + return_mask=True) # [N_views, C, N_rays, n_samples], # [N_views, N_rays, n_samples] + # import ipdb; ipdb.set_trace() + # * size (N_views, N_rays*n_samples, c) + pts_rendering_feats = pts_rendering_feats.permute(0, 2, 3, 1).contiguous() + + pts_rendering_colors = sample_ptsFeatures_from_featureMaps(pts, supporting_color_maps, supporting_w2cs, + supporting_intrinsics, img_wh) + # * size (N_views, N_rays*n_samples, c) + pts_rendering_colors = pts_rendering_colors.permute(0, 2, 3, 1).contiguous() + + rgb_feats = torch.cat([pts_rendering_colors, pts_rendering_feats], dim=-1) # [N_views, N_rays, n_samples, 3+c] + + + ray_diff = self.compute_angle(pts, query_c2w, supporting_c2ws) # [N_views, N_rays, n_samples, 4] + # import ipdb; ipdb.set_trace() + if pts_geometry_masks is not None: + final_mask = pts_geometry_masks[None, :, :].repeat(N_supporting_views, 1, 1) & \ + pts_rendering_mask # [N_views, N_rays, n_samples] + else: + final_mask = pts_rendering_mask + # import ipdb; ipdb.set_trace() + z_diff, pts_pred_depth_masks = None, None + + if pred_depth_maps is not None: + pts_pred_depth_values = sample_ptsFeatures_from_featureMaps(pts, supporting_depth_maps, supporting_w2cs, + supporting_intrinsics, img_wh) + pts_pred_depth_values = pts_pred_depth_values.permute(0, 2, 3, + 1).contiguous() # (N_views, N_rays*n_samples, 1) + + # - pts_pred_depth_masks are critical than final_mask, + # - the ray containing few invalid pts will be treated invalid + pts_pred_depth_masks = sample_ptsFeatures_from_featureMaps(pts, supporting_depth_masks.float(), + supporting_w2cs, + supporting_intrinsics, img_wh) + + pts_pred_depth_masks = pts_pred_depth_masks.permute(0, 2, 3, 1).contiguous()[:, :, :, + 0] # (N_views, N_rays*n_samples) + + z_diff = self.compute_z_diff(pts, supporting_w2cs, supporting_intrinsics, pts_pred_depth_values) + # import ipdb; ipdb.set_trace() + return pts_geometry_feature, rgb_feats, ray_diff, final_mask, z_diff, pts_pred_depth_masks + + + def compute_view_independent( + self, + pts, + # * 3d geometry feature volumes + geometryVolume=None, + geometryVolumeMask=None, + sdf_network=None, + lod=0, + vol_dims=None, + partial_vol_origin=None, + vol_size=None, + # * 2d rendering feature maps + rendering_feature_maps=None, + color_maps=None, + w2cs=None, + target_candidate_w2cs=None, + intrinsics=None, + img_wh=None, + query_img_idx=0, # the index of the N_views dim for rendering + query_c2w=None, + pred_depth_maps=None, # no use here + pred_depth_masks=None # no use here + ): + """ + extract features of pts for rendering + :param pts: + :param geometryVolume: + :param vol_dims: + :param partial_vol_origin: + :param vol_size: + :param rendering_feature_maps: + :param color_maps: + :param w2cs: + :param intrinsics: + :param img_wh: + :param rendering_img_idx: by default, we render the first view of w2cs + :return: + """ + device = pts.device + c2ws = torch.inverse(w2cs) + + if len(pts.shape) == 2: + pts = pts[None, :, :] + + N_rays, n_samples, _ = pts.shape + N_views = rendering_feature_maps.shape[0] # shape (N_views, C, H, W) + + supporting_img_idxs = torch.LongTensor([x for x in range(N_views) if x != query_img_idx]).to(device) + query_img_idx = torch.LongTensor([query_img_idx]).to(device) + + if query_c2w is None and query_img_idx > -1: + query_c2w = torch.index_select(c2ws, 0, query_img_idx) + supporting_c2ws = torch.index_select(c2ws, 0, supporting_img_idxs) + supporting_w2cs = torch.index_select(w2cs, 0, supporting_img_idxs) + supporting_rendering_feature_maps = torch.index_select(rendering_feature_maps, 0, supporting_img_idxs) + supporting_color_maps = torch.index_select(color_maps, 0, supporting_img_idxs) + supporting_intrinsics = torch.index_select(intrinsics, 0, supporting_img_idxs) + + if pred_depth_maps is not None: + supporting_depth_maps = torch.index_select(pred_depth_maps, 0, supporting_img_idxs) + supporting_depth_masks = torch.index_select(pred_depth_masks, 0, supporting_img_idxs) + # print("N_supporting_views: ", N_views - 1) + N_supporting_views = N_views - 1 + else: + supporting_c2ws = c2ws + supporting_w2cs = w2cs + supporting_rendering_feature_maps = rendering_feature_maps + supporting_color_maps = color_maps + supporting_intrinsics = intrinsics + supporting_depth_maps = pred_depth_masks + supporting_depth_masks = pred_depth_masks + # print("N_supporting_views: ", N_views) + N_supporting_views = N_views + # import ipdb; ipdb.set_trace() + if geometryVolume is not None: + # * sample feature of pts from 3D feature volume + pts_geometry_feature, pts_geometry_masks_0 = sample_ptsFeatures_from_featureVolume( + pts, geometryVolume, vol_dims, + partial_vol_origin, vol_size) # [N_rays, n_samples, C], [N_rays, n_samples] + + if len(geometryVolumeMask.shape) == 3: + geometryVolumeMask = geometryVolumeMask[None, :, :, :] + + pts_geometry_masks_1, _ = sample_ptsFeatures_from_featureVolume( + pts, geometryVolumeMask.to(geometryVolume.dtype), vol_dims, + partial_vol_origin, vol_size) # [N_rays, n_samples, C] + + pts_geometry_masks = pts_geometry_masks_0 & (pts_geometry_masks_1[..., 0] > 0) + else: + pts_geometry_feature = None + pts_geometry_masks = None + + # * sample feature of pts from 2D feature maps + pts_rendering_feats, pts_rendering_mask = sample_ptsFeatures_from_featureMaps( + pts, supporting_rendering_feature_maps, supporting_w2cs, + supporting_intrinsics, img_wh, + return_mask=True) # [N_views, C, N_rays, n_samples], # [N_views, N_rays, n_samples] + + # * size (N_views, N_rays*n_samples, c) + pts_rendering_feats = pts_rendering_feats.permute(0, 2, 3, 1).contiguous() + + pts_rendering_colors = sample_ptsFeatures_from_featureMaps(pts, supporting_color_maps, supporting_w2cs, + supporting_intrinsics, img_wh) + # * size (N_views, N_rays*n_samples, c) + pts_rendering_colors = pts_rendering_colors.permute(0, 2, 3, 1).contiguous() + + rgb_feats = torch.cat([pts_rendering_colors, pts_rendering_feats], dim=-1) # [N_views, N_rays, n_samples, 3+c] + + # import ipdb; ipdb.set_trace() + + gradients = sdf_network.gradient( + pts.reshape(-1, 3), # pts.squeeze(0), + geometryVolume.unsqueeze(0), + lod=lod + ).squeeze() + + surface_normals = safe_l2_normalize(gradients, dim=-1) # [npts, 3] + # input normals + ren_ray_diff = self.compute_angle_view_independent( + xyz=pts, + surface_normals=surface_normals, + supporting_c2ws=supporting_c2ws + ) + + # # choose closest target view direction from 32 candidate views + # # choose the closest source view as view direction instead of the normals vectors + # pts2src_centers = safe_l2_normalize((supporting_c2ws[:, :3, 3].unsqueeze(1) - pts)) # [N_views, npts, 3] + + # cosine_distance = torch.sum(pts2src_centers * surface_normals, dim=-1, keepdim=True) # [N_views, npts, 1] + # # choose the largest cosine distance as the view direction + # max_idx = torch.argmax(cosine_distance, dim=0) # [npts, 1] + + # chosen_view_direction = pts2src_centers[max_idx.squeeze(), torch.arange(pts.shape[1]), :] # [npts, 3] + # ren_ray_diff = self.compute_angle_view_independent( + # xyz=pts, + # surface_normals=chosen_view_direction, + # supporting_c2ws=supporting_c2ws + # ) + + + + # # choose closest target view direction from 8 candidate views + # # choose the closest source view as view direction instead of the normals vectors + # target_candidate_c2ws = torch.inverse(target_candidate_w2cs) + # pts2src_centers = safe_l2_normalize((target_candidate_c2ws[:, :3, 3].unsqueeze(1) - pts)) # [N_views, npts, 3] + + # cosine_distance = torch.sum(pts2src_centers * surface_normals, dim=-1, keepdim=True) # [N_views, npts, 1] + # # choose the largest cosine distance as the view direction + # max_idx = torch.argmax(cosine_distance, dim=0) # [npts, 1] + + # chosen_view_direction = pts2src_centers[max_idx.squeeze(), torch.arange(pts.shape[1]), :] # [npts, 3] + # ren_ray_diff = self.compute_angle_view_independent( + # xyz=pts, + # surface_normals=chosen_view_direction, + # supporting_c2ws=supporting_c2ws + # ) + + + # ray_diff = self.compute_angle(pts, query_c2w, supporting_c2ws) # [N_views, N_rays, n_samples, 4] + # import ipdb; ipdb.set_trace() + + + # input_directions = safe_l2_normalize(pts) + # ren_ray_diff = self.compute_angle_view_independent( + # xyz=pts, + # surface_normals=input_directions, + # supporting_c2ws=supporting_c2ws + # ) + + if pts_geometry_masks is not None: + final_mask = pts_geometry_masks[None, :, :].repeat(N_supporting_views, 1, 1) & \ + pts_rendering_mask # [N_views, N_rays, n_samples] + else: + final_mask = pts_rendering_mask + # import ipdb; ipdb.set_trace() + z_diff, pts_pred_depth_masks = None, None + + if pred_depth_maps is not None: + pts_pred_depth_values = sample_ptsFeatures_from_featureMaps(pts, supporting_depth_maps, supporting_w2cs, + supporting_intrinsics, img_wh) + pts_pred_depth_values = pts_pred_depth_values.permute(0, 2, 3, + 1).contiguous() # (N_views, N_rays*n_samples, 1) + + # - pts_pred_depth_masks are critical than final_mask, + # - the ray containing few invalid pts will be treated invalid + pts_pred_depth_masks = sample_ptsFeatures_from_featureMaps(pts, supporting_depth_masks.float(), + supporting_w2cs, + supporting_intrinsics, img_wh) + + pts_pred_depth_masks = pts_pred_depth_masks.permute(0, 2, 3, 1).contiguous()[:, :, :, + 0] # (N_views, N_rays*n_samples) + + z_diff = self.compute_z_diff(pts, supporting_w2cs, supporting_intrinsics, pts_pred_depth_values) + # import ipdb; ipdb.set_trace() + return pts_geometry_feature, rgb_feats, ren_ray_diff, final_mask, z_diff, pts_pred_depth_masks diff --git a/SparseNeuS_demo_v1/models/rays.py b/SparseNeuS_demo_v1/models/rays.py new file mode 100644 index 0000000000000000000000000000000000000000..a31df93e727fd79adaaa3e934c67378b611d4ee0 --- /dev/null +++ b/SparseNeuS_demo_v1/models/rays.py @@ -0,0 +1,325 @@ +import os, torch, cv2, re +import numpy as np + +from PIL import Image +import torch.nn.functional as F +import torchvision.transforms as T + +from random import random + + +def build_patch_offset(h_patch_size): + offsets = torch.arange(-h_patch_size, h_patch_size + 1) + return torch.stack(torch.meshgrid(offsets, offsets)[::-1], dim=-1).view(1, -1, 2) # nb_pixels_patch * 2 + + +def gen_rays_from_single_image(H, W, image, intrinsic, c2w, depth=None, mask=None): + """ + generate rays in world space, for image image + :param H: + :param W: + :param intrinsics: [3,3] + :param c2ws: [4,4] + :return: + """ + device = image.device + ys, xs = torch.meshgrid(torch.linspace(0, H - 1, H), + torch.linspace(0, W - 1, W)) # pytorch's meshgrid has indexing='ij' + p = torch.stack([xs, ys, torch.ones_like(ys)], dim=-1) # H, W, 3 + + # normalized ndc uv coordinates, (-1, 1) + ndc_u = 2 * xs / (W - 1) - 1 + ndc_v = 2 * ys / (H - 1) - 1 + rays_ndc_uv = torch.stack([ndc_u, ndc_v], dim=-1).view(-1, 2).float().to(device) + + intrinsic_inv = torch.inverse(intrinsic) + + p = p.view(-1, 3).float().to(device) # N_rays, 3 + p = torch.matmul(intrinsic_inv[None, :3, :3], p[:, :, None]).squeeze() # N_rays, 3 + rays_v = p / torch.linalg.norm(p, ord=2, dim=-1, keepdim=True) # N_rays, 3 + rays_v = torch.matmul(c2w[None, :3, :3], rays_v[:, :, None]).squeeze() # N_rays, 3 + rays_o = c2w[None, :3, 3].expand(rays_v.shape) # N_rays, 3 + + image = image.permute(1, 2, 0) + color = image.view(-1, 3) + depth = depth.view(-1, 1) if depth is not None else None + mask = mask.view(-1, 1) if mask is not None else torch.ones([H * W, 1]).to(device) + sample = { + 'rays_o': rays_o, + 'rays_v': rays_v, + 'rays_ndc_uv': rays_ndc_uv, + 'rays_color': color, + # 'rays_depth': depth, + 'rays_mask': mask, + 'rays_norm_XYZ_cam': p # - XYZ_cam, before multiply depth + } + if depth is not None: + sample['rays_depth'] = depth + + return sample + + +def gen_random_rays_from_single_image(H, W, N_rays, image, intrinsic, c2w, depth=None, mask=None, dilated_mask=None, + importance_sample=False, h_patch_size=3): + """ + generate random rays in world space, for a single image + :param H: + :param W: + :param N_rays: + :param image: [3, H, W] + :param intrinsic: [3,3] + :param c2w: [4,4] + :param depth: [H, W] + :param mask: [H, W] + :return: + """ + device = image.device + + if dilated_mask is None: + dilated_mask = mask + + if not importance_sample: + pixels_x = torch.randint(low=0, high=W, size=[N_rays]) + pixels_y = torch.randint(low=0, high=H, size=[N_rays]) + elif importance_sample and dilated_mask is not None: # sample more pts in the valid mask regions + pixels_x_1 = torch.randint(low=0, high=W, size=[N_rays // 4]) + pixels_y_1 = torch.randint(low=0, high=H, size=[N_rays // 4]) + + ys, xs = torch.meshgrid(torch.linspace(0, H - 1, H), + torch.linspace(0, W - 1, W)) # pytorch's meshgrid has indexing='ij' + p = torch.stack([xs, ys], dim=-1) # H, W, 2 + + try: + p_valid = p[dilated_mask > 0] # [num, 2] + random_idx = torch.randint(low=0, high=p_valid.shape[0], size=[N_rays // 4 * 3]) + except: + print("dilated_mask.shape: ", dilated_mask.shape) + print("dilated_mask valid number", dilated_mask.sum()) + + raise ValueError("hhhh") + p_select = p_valid[random_idx] # [N_rays//2, 2] + pixels_x_2 = p_select[:, 0] + pixels_y_2 = p_select[:, 1] + + pixels_x = torch.cat([pixels_x_1, pixels_x_2], dim=0).to(torch.int64) + pixels_y = torch.cat([pixels_y_1, pixels_y_2], dim=0).to(torch.int64) + + # - crop patch from images + offsets = build_patch_offset(h_patch_size).to(device) + grid_patch = torch.stack([pixels_x, pixels_y], dim=-1).view(-1, 1, 2) + offsets.float() # [N_pts, Npx, 2] + patch_mask = (pixels_x > h_patch_size) * (pixels_x < (W - h_patch_size)) * (pixels_y > h_patch_size) * ( + pixels_y < H - h_patch_size) # [N_pts] + grid_patch_u = 2 * grid_patch[:, :, 0] / (W - 1) - 1 + grid_patch_v = 2 * grid_patch[:, :, 1] / (H - 1) - 1 + grid_patch_uv = torch.stack([grid_patch_u, grid_patch_v], dim=-1) # [N_pts, Npx, 2] + patch_color = F.grid_sample(image[None, :, :, :], grid_patch_uv[None, :, :, :], mode='bilinear', + padding_mode='zeros',align_corners=True)[0] # [3, N_pts, Npx] + patch_color = patch_color.permute(1, 2, 0).contiguous() + + # normalized ndc uv coordinates, (-1, 1) + ndc_u = 2 * pixels_x / (W - 1) - 1 + ndc_v = 2 * pixels_y / (H - 1) - 1 + rays_ndc_uv = torch.stack([ndc_u, ndc_v], dim=-1).view(-1, 2).float().to(device) + + image = image.permute(1, 2, 0) # H ,W, C + color = image[(pixels_y, pixels_x)] # N_rays, 3 + + if mask is not None: + mask = mask[(pixels_y, pixels_x)] # N_rays + patch_mask = patch_mask * mask # N_rays + mask = mask.view(-1, 1) + else: + mask = torch.ones([N_rays, 1]) + + if depth is not None: + depth = depth[(pixels_y, pixels_x)] # N_rays + depth = depth.view(-1, 1) + + intrinsic_inv = torch.inverse(intrinsic) + + p = torch.stack([pixels_x, pixels_y, torch.ones_like(pixels_y)], dim=-1).float().to(device) # N_rays, 3 + p = torch.matmul(intrinsic_inv[None, :3, :3], p[:, :, None]).squeeze() # N_rays, 3 + rays_v = p / torch.linalg.norm(p, ord=2, dim=-1, keepdim=True) # N_rays, 3 + rays_v = torch.matmul(c2w[None, :3, :3], rays_v[:, :, None]).squeeze() # N_rays, 3 + rays_o = c2w[None, :3, 3].expand(rays_v.shape) # N_rays, 3 + + sample = { + 'rays_o': rays_o, + 'rays_v': rays_v, + 'rays_ndc_uv': rays_ndc_uv, + 'rays_color': color, + # 'rays_depth': depth, + 'rays_mask': mask, + 'rays_norm_XYZ_cam': p, # - XYZ_cam, before multiply depth, + 'rays_patch_color': patch_color, + 'rays_patch_mask': patch_mask.view(-1, 1) + } + + if depth is not None: + sample['rays_depth'] = depth + + return sample + + +def gen_random_rays_of_patch_from_single_image(H, W, N_rays, num_neighboring_pts, patch_size, + image, intrinsic, c2w, depth=None, mask=None): + """ + generate random rays in world space, for a single image + sample rays from local patches + :param H: + :param W: + :param N_rays: the number of center rays of patches + :param image: [3, H, W] + :param intrinsic: [3,3] + :param c2w: [4,4] + :param depth: [H, W] + :param mask: [H, W] + :return: + """ + device = image.device + patch_radius_max = patch_size // 2 + + unit_u = 2 / (W - 1) + unit_v = 2 / (H - 1) + + pixels_x_center = torch.randint(low=patch_size, high=W - patch_size, size=[N_rays]) + pixels_y_center = torch.randint(low=patch_size, high=H - patch_size, size=[N_rays]) + + # normalized ndc uv coordinates, (-1, 1) + ndc_u_center = 2 * pixels_x_center / (W - 1) - 1 + ndc_v_center = 2 * pixels_y_center / (H - 1) - 1 + ndc_uv_center = torch.stack([ndc_u_center, ndc_v_center], dim=-1).view(-1, 2).float().to(device)[:, None, + :] # [N_rays, 1, 2] + + shift_u, shift_v = torch.rand([N_rays, num_neighboring_pts, 1]), torch.rand( + [N_rays, num_neighboring_pts, 1]) # uniform distribution of [0,1) + shift_u = 2 * (shift_u - 0.5) # mapping to [-1, 1) + shift_v = 2 * (shift_v - 0.5) + + # - avoid sample points which are too close to center point + shift_uv = torch.cat([(shift_u * patch_radius_max) * unit_u, (shift_v * patch_radius_max) * unit_v], + dim=-1) # [N_rays, num_npts, 2] + neighboring_pts_uv = ndc_uv_center + shift_uv # [N_rays, num_npts, 2] + + sampled_pts_uv = torch.cat([ndc_uv_center, neighboring_pts_uv], dim=1) # concat the center point + + # sample the gts + color = F.grid_sample(image[None, :, :, :], sampled_pts_uv[None, :, :, :], mode='bilinear', + align_corners=True)[0] # [3, N_rays, num_npts] + depth = F.grid_sample(depth[None, None, :, :], sampled_pts_uv[None, :, :, :], mode='bilinear', + align_corners=True)[0] # [1, N_rays, num_npts] + + mask = F.grid_sample(mask[None, None, :, :].to(torch.float32), sampled_pts_uv[None, :, :, :], mode='nearest', + align_corners=True).to(torch.int64)[0] # [1, N_rays, num_npts] + + intrinsic_inv = torch.inverse(intrinsic) + + sampled_pts_uv = sampled_pts_uv.view(N_rays * (1 + num_neighboring_pts), 2) + color = color.permute(1, 2, 0).contiguous().view(N_rays * (1 + num_neighboring_pts), 3) + depth = depth.permute(1, 2, 0).contiguous().view(N_rays * (1 + num_neighboring_pts), 1) + mask = mask.permute(1, 2, 0).contiguous().view(N_rays * (1 + num_neighboring_pts), 1) + + pixels_x = (sampled_pts_uv[:, 0] + 1) * (W - 1) / 2 + pixels_y = (sampled_pts_uv[:, 1] + 1) * (H - 1) / 2 + p = torch.stack([pixels_x, pixels_y, torch.ones_like(pixels_y)], dim=-1).float().to(device) # N_rays*num_pts, 3 + p = torch.matmul(intrinsic_inv[None, :3, :3], p[:, :, None]).squeeze() # N_rays*num_pts, 3 + rays_v = p / torch.linalg.norm(p, ord=2, dim=-1, keepdim=True) # N_rays*num_pts, 3 + rays_v = torch.matmul(c2w[None, :3, :3], rays_v[:, :, None]).squeeze() # N_rays*num_pts, 3 + rays_o = c2w[None, :3, 3].expand(rays_v.shape) # N_rays*num_pts, 3 + + sample = { + 'rays_o': rays_o, + 'rays_v': rays_v, + 'rays_ndc_uv': sampled_pts_uv, + 'rays_color': color, + 'rays_depth': depth, + 'rays_mask': mask, + # 'rays_norm_XYZ_cam': p # - XYZ_cam, before multiply depth + } + + return sample + + +def gen_random_rays_from_batch_images(H, W, N_rays, images, intrinsics, c2ws, depths=None, masks=None): + """ + + :param H: + :param W: + :param N_rays: + :param images: [B,3,H,W] + :param intrinsics: [B, 3, 3] + :param c2ws: [B, 4, 4] + :param depths: [B,H,W] + :param masks: [B,H,W] + :return: + """ + assert len(images.shape) == 4 + + rays_o = [] + rays_v = [] + rays_color = [] + rays_depth = [] + rays_mask = [] + for i in range(images.shape[0]): + sample = gen_random_rays_from_single_image(H, W, N_rays, images[i], intrinsics[i], c2ws[i], + depth=depths[i] if depths is not None else None, + mask=masks[i] if masks is not None else None) + rays_o.append(sample['rays_o']) + rays_v.append(sample['rays_v']) + rays_color.append(sample['rays_color']) + if depths is not None: + rays_depth.append(sample['rays_depth']) + if masks is not None: + rays_mask.append(sample['rays_mask']) + + sample = { + 'rays_o': torch.stack(rays_o, dim=0), # [batch, N_rays, 3] + 'rays_v': torch.stack(rays_v, dim=0), + 'rays_color': torch.stack(rays_color, dim=0), + 'rays_depth': torch.stack(rays_depth, dim=0) if depths is not None else None, + 'rays_mask': torch.stack(rays_mask, dim=0) if masks is not None else None + } + return sample + + +from scipy.spatial.transform import Rotation as Rot +from scipy.spatial.transform import Slerp + + +def gen_rays_between(c2w_0, c2w_1, intrinsic, ratio, H, W, resolution_level=1): + device = c2w_0.device + + l = resolution_level + tx = torch.linspace(0, W - 1, W // l) + ty = torch.linspace(0, H - 1, H // l) + pixels_x, pixels_y = torch.meshgrid(tx, ty) + p = torch.stack([pixels_x, pixels_y, torch.ones_like(pixels_y)], dim=-1).to(device) # W, H, 3 + + intrinsic_inv = torch.inverse(intrinsic[:3, :3]) + p = torch.matmul(intrinsic_inv[None, None, :3, :3], p[:, :, :, None]).squeeze() # W, H, 3 + rays_v = p / torch.linalg.norm(p, ord=2, dim=-1, keepdim=True) # W, H, 3 + trans = c2w_0[:3, 3] * (1.0 - ratio) + c2w_1[:3, 3] * ratio + + pose_0 = c2w_0.detach().cpu().numpy() + pose_1 = c2w_1.detach().cpu().numpy() + pose_0 = np.linalg.inv(pose_0) + pose_1 = np.linalg.inv(pose_1) + rot_0 = pose_0[:3, :3] + rot_1 = pose_1[:3, :3] + rots = Rot.from_matrix(np.stack([rot_0, rot_1])) + key_times = [0, 1] + key_rots = [rot_0, rot_1] + slerp = Slerp(key_times, rots) + rot = slerp(ratio) + pose = np.diag([1.0, 1.0, 1.0, 1.0]) + pose = pose.astype(np.float32) + pose[:3, :3] = rot.as_matrix() + pose[:3, 3] = ((1.0 - ratio) * pose_0 + ratio * pose_1)[:3, 3] + pose = np.linalg.inv(pose) + + c2w = torch.from_numpy(pose).to(device) + rot = torch.from_numpy(pose[:3, :3]).cuda() + trans = torch.from_numpy(pose[:3, 3]).cuda() + rays_v = torch.matmul(rot[None, None, :3, :3], rays_v[:, :, :, None]).squeeze() # W, H, 3 + rays_o = trans[None, None, :3].expand(rays_v.shape) # W, H, 3 + return c2w, rays_o.transpose(0, 1).contiguous().view(-1, 3), rays_v.transpose(0, 1).contiguous().view(-1, 3) diff --git a/SparseNeuS_demo_v1/models/render_utils.py b/SparseNeuS_demo_v1/models/render_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e1d3d8fc4ca7bf5e306733a213dec96a517a71c7 --- /dev/null +++ b/SparseNeuS_demo_v1/models/render_utils.py @@ -0,0 +1,126 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +import logging +import mcubes +import trimesh +from icecream import ic + +from ops.back_project import cam2pixel +import pdb + + +def sample_pdf(bins, weights, n_samples, det=False): + ''' + :param bins: tensor of shape [N_rays, M+1], M is the number of bins + :param weights: tensor of shape [N_rays, M] + :param N_samples: number of samples along each ray + :param det: if True, will perform deterministic sampling + :return: [N_rays, N_samples] + ''' + device = weights.device + + weights = weights + 1e-5 # prevent nans + pdf = weights / torch.sum(weights, -1, keepdim=True) + cdf = torch.cumsum(pdf, -1) + cdf = torch.cat([torch.zeros_like(cdf[..., :1]).to(device), cdf], -1) + + # if bins.shape[1] != weights.shape[1]: # - minor modification, add this constraint + # cdf = torch.cat([torch.zeros_like(cdf[..., :1]).to(device), cdf], -1) + # Take uniform samples + if det: + u = torch.linspace(0. + 0.5 / n_samples, 1. - 0.5 / n_samples, steps=n_samples).to(device) + u = u.expand(list(cdf.shape[:-1]) + [n_samples]) + else: + u = torch.rand(list(cdf.shape[:-1]) + [n_samples]).to(device) + + # Invert CDF + u = u.contiguous() + # inds = searchsorted(cdf, u, side='right') + inds = torch.searchsorted(cdf, u, right=True) + + below = torch.max(torch.zeros_like(inds - 1), inds - 1) + above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds) + inds_g = torch.stack([below, above], -1) # (batch, n_samples, 2) + + matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]] + cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g) + bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g) + + denom = (cdf_g[..., 1] - cdf_g[..., 0]) + denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom) + t = (u - cdf_g[..., 0]) / denom + samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0]) + + # pdb.set_trace() + return samples + + +def sample_ptsFeatures_from_featureVolume(pts, featureVolume, vol_dims=None, partial_vol_origin=None, vol_size=None): + """ + sample feature of pts_wrd from featureVolume, all in world space + :param pts: [N_rays, n_samples, 3] + :param featureVolume: [C,wX,wY,wZ] + :param vol_dims: [3] "3" for dimX, dimY, dimZ + :param partial_vol_origin: [3] + :return: pts_feature: [N_rays, n_samples, C] + :return: valid_mask: [N_rays] + """ + + N_rays, n_samples, _ = pts.shape + + if vol_dims is None: + pts_normalized = pts + else: + # normalized to (-1, 1) + pts_normalized = 2 * (pts - partial_vol_origin[None, None, :]) / (vol_size * (vol_dims[None, None, :] - 1)) - 1 + + valid_mask = (torch.abs(pts_normalized[:, :, 0]) < 1.0) & ( + torch.abs(pts_normalized[:, :, 1]) < 1.0) & ( + torch.abs(pts_normalized[:, :, 2]) < 1.0) # (N_rays, n_samples) + + pts_normalized = torch.flip(pts_normalized, dims=[-1]) # ! reverse the xyz for grid_sample + + # ! checked grid_sample, (x,y,z) is for (D,H,W), reverse for (W,H,D) + pts_feature = F.grid_sample(featureVolume[None, :, :, :, :], pts_normalized[None, None, :, :, :], + padding_mode='zeros', + align_corners=True).view(-1, N_rays, n_samples) # [C, N_rays, n_samples] + + pts_feature = pts_feature.permute(1, 2, 0) # [N_rays, n_samples, C] + return pts_feature, valid_mask + + +def sample_ptsFeatures_from_featureMaps(pts, featureMaps, w2cs, intrinsics, WH, proj_matrix=None, return_mask=False): + """ + sample features of pts from 2d feature maps + :param pts: [N_rays, N_samples, 3] + :param featureMaps: [N_views, C, H, W] + :param w2cs: [N_views, 4, 4] + :param intrinsics: [N_views, 3, 3] + :param proj_matrix: [N_views, 4, 4] + :param HW: + :return: + """ + # normalized to (-1, 1) + N_rays, n_samples, _ = pts.shape + N_views = featureMaps.shape[0] + + if proj_matrix is None: + proj_matrix = torch.matmul(intrinsics, w2cs[:, :3, :]) + + pts = pts.permute(2, 0, 1).contiguous().view(1, 3, N_rays, n_samples).repeat(N_views, 1, 1, 1) + pixel_grids = cam2pixel(pts, proj_matrix[:, :3, :3], proj_matrix[:, :3, 3:], + 'zeros', sizeH=WH[1], sizeW=WH[0]) # (nviews, N_rays, n_samples, 2) + + valid_mask = (torch.abs(pixel_grids[:, :, :, 0]) < 1.0) & ( + torch.abs(pixel_grids[:, :, :, 1]) < 1.00) # (nviews, N_rays, n_samples) + + pts_feature = F.grid_sample(featureMaps, pixel_grids, + padding_mode='zeros', + align_corners=True) # [N_views, C, N_rays, n_samples] + + if return_mask: + return pts_feature, valid_mask + else: + return pts_feature diff --git a/SparseNeuS_demo_v1/models/rendering_network.py b/SparseNeuS_demo_v1/models/rendering_network.py new file mode 100644 index 0000000000000000000000000000000000000000..bfc984223450a609024a65956439ff741a6b133d --- /dev/null +++ b/SparseNeuS_demo_v1/models/rendering_network.py @@ -0,0 +1,129 @@ +# the codes are partly borrowed from IBRNet + +import torch +import torch.nn as nn +import torch.nn.functional as F + +torch._C._jit_set_profiling_executor(False) +torch._C._jit_set_profiling_mode(False) + + +# default tensorflow initialization of linear layers +def weights_init(m): + if isinstance(m, nn.Linear): + nn.init.kaiming_normal_(m.weight.data) + if m.bias is not None: + nn.init.zeros_(m.bias.data) + + +@torch.jit.script +def fused_mean_variance(x, weight): + mean = torch.sum(x * weight, dim=2, keepdim=True) + var = torch.sum(weight * (x - mean) ** 2, dim=2, keepdim=True) + return mean, var + + +class GeneralRenderingNetwork(nn.Module): + """ + This model is not sensitive to finetuning + """ + + def __init__(self, in_geometry_feat_ch=8, in_rendering_feat_ch=56, anti_alias_pooling=True): + super(GeneralRenderingNetwork, self).__init__() + + self.in_geometry_feat_ch = in_geometry_feat_ch + self.in_rendering_feat_ch = in_rendering_feat_ch + self.anti_alias_pooling = anti_alias_pooling + + if self.anti_alias_pooling: + self.s = nn.Parameter(torch.tensor(0.2), requires_grad=True) + activation_func = nn.ELU(inplace=True) + + self.ray_dir_fc = nn.Sequential(nn.Linear(4, 16), + activation_func, + nn.Linear(16, in_rendering_feat_ch + 3), + activation_func) + + self.base_fc = nn.Sequential(nn.Linear((in_rendering_feat_ch + 3) * 3 + in_geometry_feat_ch, 64), + activation_func, + nn.Linear(64, 32), + activation_func) + + self.vis_fc = nn.Sequential(nn.Linear(32, 32), + activation_func, + nn.Linear(32, 33), + activation_func, + ) + + self.vis_fc2 = nn.Sequential(nn.Linear(32, 32), + activation_func, + nn.Linear(32, 1), + nn.Sigmoid() + ) + + self.rgb_fc = nn.Sequential(nn.Linear(32 + 1 + 4, 16), + activation_func, + nn.Linear(16, 8), + activation_func, + nn.Linear(8, 1)) + + self.base_fc.apply(weights_init) + self.vis_fc2.apply(weights_init) + self.vis_fc.apply(weights_init) + self.rgb_fc.apply(weights_init) + + def forward(self, geometry_feat, rgb_feat, ray_diff, mask): + ''' + :param geometry_feat: geometry features indicates sdf [n_rays, n_samples, n_feat] + :param rgb_feat: rgbs and image features [n_views, n_rays, n_samples, n_feat] + :param ray_diff: ray direction difference [n_views, n_rays, n_samples, 4], first 3 channels are directions, + last channel is inner product + :param mask: mask for whether each projection is valid or not. [n_views, n_rays, n_samples] + :return: rgb and density output, [n_rays, n_samples, 4] + ''' + + rgb_feat = rgb_feat.permute(1, 2, 0, 3).contiguous() + ray_diff = ray_diff.permute(1, 2, 0, 3).contiguous() + mask = mask[:, :, :, None].permute(1, 2, 0, 3).contiguous() + num_views = rgb_feat.shape[2] + geometry_feat = geometry_feat[:, :, None, :].repeat(1, 1, num_views, 1) + + direction_feat = self.ray_dir_fc(ray_diff) + rgb_in = rgb_feat[..., :3] + rgb_feat = rgb_feat + direction_feat + + if self.anti_alias_pooling: + _, dot_prod = torch.split(ray_diff, [3, 1], dim=-1) + exp_dot_prod = torch.exp(torch.abs(self.s) * (dot_prod - 1)) + weight = (exp_dot_prod - torch.min(exp_dot_prod, dim=2, keepdim=True)[0]) * mask + weight = weight / (torch.sum(weight, dim=2, keepdim=True) + 1e-8) + else: + weight = mask / (torch.sum(mask, dim=2, keepdim=True) + 1e-8) + + # compute mean and variance across different views for each point + mean, var = fused_mean_variance(rgb_feat, weight) # [n_rays, n_samples, 1, n_feat] + globalfeat = torch.cat([mean, var], dim=-1) # [n_rays, n_samples, 1, 2*n_feat] + + x = torch.cat([geometry_feat, globalfeat.expand(-1, -1, num_views, -1), rgb_feat], + dim=-1) # [n_rays, n_samples, n_views, 3*n_feat+n_geo_feat] + x = self.base_fc(x) + + x_vis = self.vis_fc(x * weight) + x_res, vis = torch.split(x_vis, [x_vis.shape[-1] - 1, 1], dim=-1) + vis = F.sigmoid(vis) * mask + x = x + x_res + vis = self.vis_fc2(x * vis) * mask + + # rgb computation + x = torch.cat([x, vis, ray_diff], dim=-1) + x = self.rgb_fc(x) + x = x.masked_fill(mask == 0, -1e9) + blending_weights_valid = F.softmax(x, dim=2) # color blending + rgb_out = torch.sum(rgb_in * blending_weights_valid, dim=2) + + mask = mask.detach().to(rgb_out.dtype) # [n_rays, n_samples, n_views, 1] + mask = torch.sum(mask, dim=2, keepdim=False) + mask = mask >= 2 # more than 2 views see the point + mask = torch.sum(mask.to(rgb_out.dtype), dim=1, keepdim=False) + valid_mask = mask > 8 # valid rays, more than 8 valid samples + return rgb_out, valid_mask # (N_rays, n_samples, 3), (N_rays, 1) diff --git a/SparseNeuS_demo_v1/models/sparse_neus_renderer.py b/SparseNeuS_demo_v1/models/sparse_neus_renderer.py new file mode 100644 index 0000000000000000000000000000000000000000..8015669f349f5b61ca1cb234ec2fcdf71cd10407 --- /dev/null +++ b/SparseNeuS_demo_v1/models/sparse_neus_renderer.py @@ -0,0 +1,990 @@ +""" +The codes are heavily borrowed from NeuS +""" + +import os +import cv2 as cv +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +import logging +import mcubes +import trimesh +from icecream import ic +from models.render_utils import sample_pdf + +from models.projector import Projector +from tsparse.torchsparse_utils import sparse_to_dense_channel + +from models.fast_renderer import FastRenderer + +from models.patch_projector import PatchProjector + +from models.rays import gen_rays_between + +import pdb + + +class SparseNeuSRenderer(nn.Module): + """ + conditional neus render; + optimize on normalized world space; + warped by nn.Module to support DataParallel traning + """ + + def __init__(self, + rendering_network_outside, + sdf_network, + variance_network, + rendering_network, + n_samples, + n_importance, + n_outside, + perturb, + alpha_type='div', + conf=None + ): + super(SparseNeuSRenderer, self).__init__() + + self.conf = conf + self.base_exp_dir = conf['general.base_exp_dir'] + + # network setups + self.rendering_network_outside = rendering_network_outside + self.sdf_network = sdf_network + self.variance_network = variance_network + self.rendering_network = rendering_network + + self.n_samples = n_samples + self.n_importance = n_importance + self.n_outside = n_outside + self.perturb = perturb + self.alpha_type = alpha_type + + self.rendering_projector = Projector() # used to obtain features for generalized rendering + + self.h_patch_size = self.conf.get_int('model.h_patch_size', default=3) + self.patch_projector = PatchProjector(self.h_patch_size) + + self.ray_tracer = FastRenderer() # ray_tracer to extract depth maps from sdf_volume + + # - fitted rendering or general rendering + try: + self.if_fitted_rendering = self.sdf_network.if_fitted_rendering + except: + self.if_fitted_rendering = False + + def up_sample(self, rays_o, rays_d, z_vals, sdf, n_importance, inv_variance, + conditional_valid_mask_volume=None): + device = rays_o.device + batch_size, n_samples = z_vals.shape + pts = rays_o[:, None, :] + rays_d[:, None, :] * z_vals[..., :, None] # n_rays, n_samples, 3 + + if conditional_valid_mask_volume is not None: + pts_mask = self.get_pts_mask_for_conditional_volume(pts.view(-1, 3), conditional_valid_mask_volume) + pts_mask = pts_mask.reshape(batch_size, n_samples) + pts_mask = pts_mask[:, :-1] * pts_mask[:, 1:] # [batch_size, n_samples-1] + else: + pts_mask = torch.ones([batch_size, n_samples]).to(pts.device) + + sdf = sdf.reshape(batch_size, n_samples) + prev_sdf, next_sdf = sdf[:, :-1], sdf[:, 1:] + prev_z_vals, next_z_vals = z_vals[:, :-1], z_vals[:, 1:] + mid_sdf = (prev_sdf + next_sdf) * 0.5 + dot_val = None + if self.alpha_type == 'uniform': + dot_val = torch.ones([batch_size, n_samples - 1]) * -1.0 + else: + dot_val = (next_sdf - prev_sdf) / (next_z_vals - prev_z_vals + 1e-5) + prev_dot_val = torch.cat([torch.zeros([batch_size, 1]).to(device), dot_val[:, :-1]], dim=-1) + dot_val = torch.stack([prev_dot_val, dot_val], dim=-1) + dot_val, _ = torch.min(dot_val, dim=-1, keepdim=False) + dot_val = dot_val.clip(-10.0, 0.0) * pts_mask + dist = (next_z_vals - prev_z_vals) + prev_esti_sdf = mid_sdf - dot_val * dist * 0.5 + next_esti_sdf = mid_sdf + dot_val * dist * 0.5 + prev_cdf = torch.sigmoid(prev_esti_sdf * inv_variance) + next_cdf = torch.sigmoid(next_esti_sdf * inv_variance) + alpha_sdf = (prev_cdf - next_cdf + 1e-5) / (prev_cdf + 1e-5) + + alpha = alpha_sdf + + # - apply pts_mask + alpha = pts_mask * alpha + + weights = alpha * torch.cumprod( + torch.cat([torch.ones([batch_size, 1]).to(device), 1. - alpha + 1e-7], -1), -1)[:, :-1] + + z_samples = sample_pdf(z_vals, weights, n_importance, det=True).detach() + return z_samples + + def cat_z_vals(self, rays_o, rays_d, z_vals, new_z_vals, sdf, lod, + sdf_network, gru_fusion, + # * related to conditional feature + conditional_volume=None, + conditional_valid_mask_volume=None + ): + device = rays_o.device + batch_size, n_samples = z_vals.shape + _, n_importance = new_z_vals.shape + pts = rays_o[:, None, :] + rays_d[:, None, :] * new_z_vals[..., :, None] + + if conditional_valid_mask_volume is not None: + pts_mask = self.get_pts_mask_for_conditional_volume(pts.view(-1, 3), conditional_valid_mask_volume) + pts_mask = pts_mask.reshape(batch_size, n_importance) + pts_mask_bool = (pts_mask > 0).view(-1) + else: + pts_mask = torch.ones([batch_size, n_importance]).to(pts.device) + + new_sdf = torch.ones([batch_size * n_importance, 1]).to(pts.dtype).to(device) * 100 + + if torch.sum(pts_mask) > 1: + new_outputs = sdf_network.sdf(pts.reshape(-1, 3)[pts_mask_bool], conditional_volume, lod=lod) + new_sdf[pts_mask_bool] = new_outputs['sdf_pts_scale%d' % lod] # .reshape(batch_size, n_importance) + + new_sdf = new_sdf.view(batch_size, n_importance) + + z_vals = torch.cat([z_vals, new_z_vals], dim=-1) + sdf = torch.cat([sdf, new_sdf], dim=-1) + + z_vals, index = torch.sort(z_vals, dim=-1) + xx = torch.arange(batch_size)[:, None].expand(batch_size, n_samples + n_importance).reshape(-1) + index = index.reshape(-1) + sdf = sdf[(xx, index)].reshape(batch_size, n_samples + n_importance) + + return z_vals, sdf + + @torch.no_grad() + def get_pts_mask_for_conditional_volume(self, pts, mask_volume): + """ + + :param pts: [N, 3] + :param mask_volume: [1, 1, X, Y, Z] + :return: + """ + num_pts = pts.shape[0] + pts = pts.view(1, 1, 1, num_pts, 3) # - should be in range (-1, 1) + + pts = torch.flip(pts, dims=[-1]) + + pts_mask = F.grid_sample(mask_volume, pts, mode='nearest') # [1, c, 1, 1, num_pts] + pts_mask = pts_mask.view(-1, num_pts).permute(1, 0).contiguous() # [num_pts, 1] + + return pts_mask + + def render_core(self, + rays_o, + rays_d, + z_vals, + sample_dist, + lod, + sdf_network, + rendering_network, + background_alpha=None, # - no use here + background_sampled_color=None, # - no use here + background_rgb=None, # - no use here + alpha_inter_ratio=0.0, + # * related to conditional feature + conditional_volume=None, + conditional_valid_mask_volume=None, + # * 2d feature maps + feature_maps=None, + color_maps=None, + w2cs=None, + intrinsics=None, + img_wh=None, + query_c2w=None, # - used for testing + if_general_rendering=True, + if_render_with_grad=True, + # * used for blending mlp rendering network + img_index=None, + rays_uv=None, + # * used for clear bg and fg + bg_num=0 + ): + device = rays_o.device + N_rays = rays_o.shape[0] + _, n_samples = z_vals.shape + dists = z_vals[..., 1:] - z_vals[..., :-1] + dists = torch.cat([dists, torch.Tensor([sample_dist]).expand(dists[..., :1].shape).to(device)], -1) + + mid_z_vals = z_vals + dists * 0.5 + mid_dists = mid_z_vals[..., 1:] - mid_z_vals[..., :-1] + + pts = rays_o[:, None, :] + rays_d[:, None, :] * mid_z_vals[..., :, None] # n_rays, n_samples, 3 + dirs = rays_d[:, None, :].expand(pts.shape) + + pts = pts.reshape(-1, 3) + dirs = dirs.reshape(-1, 3) + + # * if conditional_volume is restored from sparse volume, need mask for pts + if conditional_valid_mask_volume is not None: + pts_mask = self.get_pts_mask_for_conditional_volume(pts, conditional_valid_mask_volume) + pts_mask = pts_mask.reshape(N_rays, n_samples).float().detach() + pts_mask_bool = (pts_mask > 0).view(-1) + + if torch.sum(pts_mask_bool.float()) < 1: # ! when render out image, may meet this problem + pts_mask_bool[:100] = True + + else: + pts_mask = torch.ones([N_rays, n_samples]).to(pts.device) + # import ipdb; ipdb.set_trace() + # pts_valid = pts[pts_mask_bool] + sdf_nn_output = sdf_network.sdf(pts[pts_mask_bool], conditional_volume, lod=lod) + + sdf = torch.ones([N_rays * n_samples, 1]).to(pts.dtype).to(device) * 100 + sdf[pts_mask_bool] = sdf_nn_output['sdf_pts_scale%d' % lod] # [N_rays*n_samples, 1] + feature_vector_valid = sdf_nn_output['sdf_features_pts_scale%d' % lod] + feature_vector = torch.zeros([N_rays * n_samples, feature_vector_valid.shape[1]]).to(pts.dtype).to(device) + feature_vector[pts_mask_bool] = feature_vector_valid + + # * estimate alpha from sdf + gradients = torch.zeros([N_rays * n_samples, 3]).to(pts.dtype).to(device) + # import ipdb; ipdb.set_trace() + gradients[pts_mask_bool] = sdf_network.gradient( + pts[pts_mask_bool], conditional_volume, lod=lod).squeeze() + + sampled_color_mlp = None + rendering_valid_mask_mlp = None + sampled_color_patch = None + rendering_patch_mask = None + + if self.if_fitted_rendering: # used for fine-tuning + position_latent = sdf_nn_output['sampled_latent_scale%d' % lod] + sampled_color_mlp = torch.zeros([N_rays * n_samples, 3]).to(pts.dtype).to(device) + sampled_color_mlp_mask = torch.zeros([N_rays * n_samples, 1]).to(pts.dtype).to(device) + + # - extract pixel + pts_pixel_color, pts_pixel_mask = self.patch_projector.pixel_warp( + pts[pts_mask_bool][:, None, :], color_maps, intrinsics, + w2cs, img_wh=None) # [N_rays * n_samples,1, N_views, 3] , [N_rays*n_samples, 1, N_views] + pts_pixel_color = pts_pixel_color[:, 0, :, :] # [N_rays * n_samples, N_views, 3] + pts_pixel_mask = pts_pixel_mask[:, 0, :] # [N_rays*n_samples, N_views] + + # - extract patch + if_patch_blending = False if rays_uv is None else True + pts_patch_color, pts_patch_mask = None, None + if if_patch_blending: + pts_patch_color, pts_patch_mask = self.patch_projector.patch_warp( + pts.reshape([N_rays, n_samples, 3]), + rays_uv, gradients.reshape([N_rays, n_samples, 3]), + color_maps, + intrinsics[0], intrinsics, + query_c2w[0], torch.inverse(w2cs), img_wh=None + ) # (N_rays, n_samples, N_src, Npx, 3), (N_rays, n_samples, N_src, Npx) + N_src, Npx = pts_patch_mask.shape[2:] + pts_patch_color = pts_patch_color.view(N_rays * n_samples, N_src, Npx, 3)[pts_mask_bool] + pts_patch_mask = pts_patch_mask.view(N_rays * n_samples, N_src, Npx)[pts_mask_bool] + + sampled_color_patch = torch.zeros([N_rays * n_samples, Npx, 3]).to(device) + sampled_color_patch_mask = torch.zeros([N_rays * n_samples, 1]).to(device) + + sampled_color_mlp_, sampled_color_mlp_mask_, \ + sampled_color_patch_, sampled_color_patch_mask_ = sdf_network.color_blend( + pts[pts_mask_bool], + position_latent, + gradients[pts_mask_bool], + dirs[pts_mask_bool], + feature_vector[pts_mask_bool], + img_index=img_index, + pts_pixel_color=pts_pixel_color, + pts_pixel_mask=pts_pixel_mask, + pts_patch_color=pts_patch_color, + pts_patch_mask=pts_patch_mask + + ) # [n, 3], [n, 1] + sampled_color_mlp[pts_mask_bool] = sampled_color_mlp_ + sampled_color_mlp_mask[pts_mask_bool] = sampled_color_mlp_mask_.float() + sampled_color_mlp = sampled_color_mlp.view(N_rays, n_samples, 3) + sampled_color_mlp_mask = sampled_color_mlp_mask.view(N_rays, n_samples) + rendering_valid_mask_mlp = torch.mean(pts_mask * sampled_color_mlp_mask, dim=-1, keepdim=True) > 0.5 + + # patch blending + if if_patch_blending: + sampled_color_patch[pts_mask_bool] = sampled_color_patch_ + sampled_color_patch_mask[pts_mask_bool] = sampled_color_patch_mask_.float() + sampled_color_patch = sampled_color_patch.view(N_rays, n_samples, Npx, 3) + sampled_color_patch_mask = sampled_color_patch_mask.view(N_rays, n_samples) + rendering_patch_mask = torch.mean(pts_mask * sampled_color_patch_mask, dim=-1, + keepdim=True) > 0.5 # [N_rays, 1] + else: + sampled_color_patch, rendering_patch_mask = None, None + + if if_general_rendering: # used for general training + # [512, 128, 16]; [4, 512, 128, 59]; [4, 512, 128, 4] + ren_geo_feats, ren_rgb_feats, ren_ray_diff, ren_mask, _, _ = self.rendering_projector.compute( + pts.view(N_rays, n_samples, 3), + # * 3d geometry feature volumes + geometryVolume=conditional_volume[0], + geometryVolumeMask=conditional_valid_mask_volume[0], + # * 2d rendering feature maps + rendering_feature_maps=feature_maps, # [n_views, 56, 256, 256] + color_maps=color_maps, + w2cs=w2cs, + intrinsics=intrinsics, + img_wh=img_wh, + query_img_idx=0, # the index of the N_views dim for rendering + query_c2w=query_c2w, + ) + + # (N_rays, n_samples, 3) + if if_render_with_grad: + # import ipdb; ipdb.set_trace() + # [nrays, 3] [nrays, 1] + sampled_color, rendering_valid_mask = rendering_network( + ren_geo_feats, ren_rgb_feats, ren_ray_diff, ren_mask) + # import ipdb; ipdb.set_trace() + else: + with torch.no_grad(): + sampled_color, rendering_valid_mask = rendering_network( + ren_geo_feats, ren_rgb_feats, ren_ray_diff, ren_mask) + else: + sampled_color, rendering_valid_mask = None, None + + inv_variance = self.variance_network(feature_vector)[:, :1].clip(1e-6, 1e6) + + true_dot_val = (dirs * gradients).sum(-1, keepdim=True) # * calculate + + iter_cos = -(F.relu(-true_dot_val * 0.5 + 0.5) * (1.0 - alpha_inter_ratio) + F.relu( + -true_dot_val) * alpha_inter_ratio) # always non-positive + + iter_cos = iter_cos * pts_mask.view(-1, 1) + + true_estimate_sdf_half_next = sdf + iter_cos.clip(-10.0, 10.0) * dists.reshape(-1, 1) * 0.5 + true_estimate_sdf_half_prev = sdf - iter_cos.clip(-10.0, 10.0) * dists.reshape(-1, 1) * 0.5 + + prev_cdf = torch.sigmoid(true_estimate_sdf_half_prev * inv_variance) + next_cdf = torch.sigmoid(true_estimate_sdf_half_next * inv_variance) + + p = prev_cdf - next_cdf + c = prev_cdf + + if self.alpha_type == 'div': + alpha_sdf = ((p + 1e-5) / (c + 1e-5)).reshape(N_rays, n_samples).clip(0.0, 1.0) + elif self.alpha_type == 'uniform': + uniform_estimate_sdf_half_next = sdf - dists.reshape(-1, 1) * 0.5 + uniform_estimate_sdf_half_prev = sdf + dists.reshape(-1, 1) * 0.5 + uniform_prev_cdf = torch.sigmoid(uniform_estimate_sdf_half_prev * inv_variance) + uniform_next_cdf = torch.sigmoid(uniform_estimate_sdf_half_next * inv_variance) + uniform_alpha = F.relu( + (uniform_prev_cdf - uniform_next_cdf + 1e-5) / (uniform_prev_cdf + 1e-5)).reshape( + N_rays, n_samples).clip(0.0, 1.0) + alpha_sdf = uniform_alpha + else: + assert False + + alpha = alpha_sdf + + # - apply pts_mask + alpha = alpha * pts_mask + + # pts_radius = torch.linalg.norm(pts, ord=2, dim=-1, keepdim=True).reshape(N_rays, n_samples) + # inside_sphere = (pts_radius < 1.0).float().detach() + # relax_inside_sphere = (pts_radius < 1.2).float().detach() + inside_sphere = pts_mask + relax_inside_sphere = pts_mask + + weights = alpha * torch.cumprod(torch.cat([torch.ones([N_rays, 1]).to(device), 1. - alpha + 1e-7], -1), -1)[:, + :-1] # n_rays, n_samples + weights_sum = weights.sum(dim=-1, keepdim=True) + alpha_sum = alpha.sum(dim=-1, keepdim=True) + + if bg_num > 0: + weights_sum_fg = weights[:, :-bg_num].sum(dim=-1, keepdim=True) + else: + weights_sum_fg = weights_sum + + if sampled_color is not None: + color = (sampled_color * weights[:, :, None]).sum(dim=1) + else: + color = None + # import ipdb; ipdb.set_trace() + + if background_rgb is not None and color is not None: + color = color + background_rgb * (1.0 - weights_sum) + # print("color device:" + str(color.device)) + # if color is not None: + # # import ipdb; ipdb.set_trace() + # color = color + (1.0 - weights_sum) + + + ###################* mlp color rendering ##################### + color_mlp = None + # import ipdb; ipdb.set_trace() + if sampled_color_mlp is not None: + color_mlp = (sampled_color_mlp * weights[:, :, None]).sum(dim=1) + + if background_rgb is not None and color_mlp is not None: + color_mlp = color_mlp + background_rgb * (1.0 - weights_sum) + + ############################ * patch blending ################ + blended_color_patch = None + if sampled_color_patch is not None: + blended_color_patch = (sampled_color_patch * weights[:, :, None, None]).sum(dim=1) # [N_rays, Npx, 3] + + ###################################################### + + gradient_error = (torch.linalg.norm(gradients.reshape(N_rays, n_samples, 3), ord=2, + dim=-1) - 1.0) ** 2 + # ! the gradient normal should be masked out, the pts out of the bounding box should also be penalized + gradient_error = (pts_mask * gradient_error).sum() / ( + (pts_mask).sum() + 1e-5) + + depth = (mid_z_vals * weights[:, :n_samples]).sum(dim=1, keepdim=True) + # print("[TEST]: weights_sum in render_core", weights_sum.mean()) + # print("[TEST]: weights_sum in render_core NAN number", weights_sum.isnan().sum()) + # if weights_sum.isnan().sum() > 0: + # import ipdb; ipdb.set_trace() + return { + 'color': color, + 'color_mask': rendering_valid_mask, # (N_rays, 1) + 'color_mlp': color_mlp, + 'color_mlp_mask': rendering_valid_mask_mlp, + 'sdf': sdf, # (N_rays, n_samples) + 'depth': depth, # (N_rays, 1) + 'dists': dists, + 'gradients': gradients.reshape(N_rays, n_samples, 3), + 'variance': 1.0 / inv_variance, + 'mid_z_vals': mid_z_vals, + 'weights': weights, + 'weights_sum': weights_sum, + 'alpha_sum': alpha_sum, + 'alpha_mean': alpha.mean(), + 'cdf': c.reshape(N_rays, n_samples), + 'gradient_error': gradient_error, + 'inside_sphere': inside_sphere, + 'blended_color_patch': blended_color_patch, + 'blended_color_patch_mask': rendering_patch_mask, + 'weights_sum_fg': weights_sum_fg + } + + def render(self, rays_o, rays_d, near, far, sdf_network, rendering_network, + perturb_overwrite=-1, + background_rgb=None, + alpha_inter_ratio=0.0, + # * related to conditional feature + lod=None, + conditional_volume=None, + conditional_valid_mask_volume=None, + # * 2d feature maps + feature_maps=None, + color_maps=None, + w2cs=None, + intrinsics=None, + img_wh=None, + query_c2w=None, # -used for testing + if_general_rendering=True, + if_render_with_grad=True, + # * used for blending mlp rendering network + img_index=None, + rays_uv=None, + # * importance sample for second lod network + pre_sample=False, # no use here + # * for clear foreground + bg_ratio=0.0 + ): + device = rays_o.device + N_rays = len(rays_o) + # sample_dist = 2.0 / self.n_samples + sample_dist = ((far - near) / self.n_samples).mean().item() + z_vals = torch.linspace(0.0, 1.0, self.n_samples).to(device) + z_vals = near + (far - near) * z_vals[None, :] + + bg_num = int(self.n_samples * bg_ratio) + + if z_vals.shape[0] == 1: + z_vals = z_vals.repeat(N_rays, 1) + + if bg_num > 0: + z_vals_bg = z_vals[:, self.n_samples - bg_num:] + z_vals = z_vals[:, :self.n_samples - bg_num] + + n_samples = self.n_samples - bg_num + perturb = self.perturb + + # - significantly speed up training, for the second lod network + if pre_sample: + z_vals = self.sample_z_vals_from_maskVolume(rays_o, rays_d, near, far, + conditional_valid_mask_volume) + + if perturb_overwrite >= 0: + perturb = perturb_overwrite + if perturb > 0: + # get intervals between samples + mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1]) + upper = torch.cat([mids, z_vals[..., -1:]], -1) + lower = torch.cat([z_vals[..., :1], mids], -1) + # stratified samples in those intervals + t_rand = torch.rand(z_vals.shape).to(device) + z_vals = lower + (upper - lower) * t_rand + + background_alpha = None + background_sampled_color = None + z_val_before = z_vals.clone() + # Up sample + if self.n_importance > 0: + with torch.no_grad(): + pts = rays_o[:, None, :] + rays_d[:, None, :] * z_vals[..., :, None] + + sdf_outputs = sdf_network.sdf( + pts.reshape(-1, 3), conditional_volume, lod=lod) + # pdb.set_trace() + sdf = sdf_outputs['sdf_pts_scale%d' % lod].reshape(N_rays, self.n_samples - bg_num) + + n_steps = 4 + for i in range(n_steps): + new_z_vals = self.up_sample(rays_o, rays_d, z_vals, sdf, self.n_importance // n_steps, + 64 * 2 ** i, + conditional_valid_mask_volume=conditional_valid_mask_volume, + ) + + # if new_z_vals.isnan().sum() > 0: + # import ipdb; ipdb.set_trace() + + z_vals, sdf = self.cat_z_vals( + rays_o, rays_d, z_vals, new_z_vals, sdf, lod, + sdf_network, gru_fusion=False, + conditional_volume=conditional_volume, + conditional_valid_mask_volume=conditional_valid_mask_volume, + ) + + del sdf + + n_samples = self.n_samples + self.n_importance + + # Background + ret_outside = None + + # Render + if bg_num > 0: + z_vals = torch.cat([z_vals, z_vals_bg], dim=1) + # if z_vals.isnan().sum() > 0: + # import ipdb; ipdb.set_trace() + ret_fine = self.render_core(rays_o, + rays_d, + z_vals, + sample_dist, + lod, + sdf_network, + rendering_network, + background_rgb=background_rgb, + background_alpha=background_alpha, + background_sampled_color=background_sampled_color, + alpha_inter_ratio=alpha_inter_ratio, + # * related to conditional feature + conditional_volume=conditional_volume, + conditional_valid_mask_volume=conditional_valid_mask_volume, + # * 2d feature maps + feature_maps=feature_maps, + color_maps=color_maps, + w2cs=w2cs, + intrinsics=intrinsics, + img_wh=img_wh, + query_c2w=query_c2w, + if_general_rendering=if_general_rendering, + if_render_with_grad=if_render_with_grad, + # * used for blending mlp rendering network + img_index=img_index, + rays_uv=rays_uv + ) + + color_fine = ret_fine['color'] + + if self.n_outside > 0: + color_fine_mask = torch.logical_or(ret_fine['color_mask'], ret_outside['color_mask']) + else: + color_fine_mask = ret_fine['color_mask'] + + weights = ret_fine['weights'] + weights_sum = ret_fine['weights_sum'] + + gradients = ret_fine['gradients'] + mid_z_vals = ret_fine['mid_z_vals'] + + # depth = (mid_z_vals * weights[:, :n_samples]).sum(dim=1, keepdim=True) + depth = ret_fine['depth'] + depth_varaince = ((mid_z_vals - depth) ** 2 * weights[:, :n_samples]).sum(dim=-1, keepdim=True) + variance = ret_fine['variance'].reshape(N_rays, n_samples).mean(dim=-1, keepdim=True) + + # - randomly sample points from the volume, and maximize the sdf + pts_random = torch.rand([1024, 3]).float().to(device) * 2 - 1 # normalized to (-1, 1) + sdf_random = sdf_network.sdf(pts_random, conditional_volume, lod=lod)['sdf_pts_scale%d' % lod] + + result = { + 'depth': depth, + 'color_fine': color_fine, + 'color_fine_mask': color_fine_mask, + 'color_outside': ret_outside['color'] if ret_outside is not None else None, + 'color_outside_mask': ret_outside['color_mask'] if ret_outside is not None else None, + 'color_mlp': ret_fine['color_mlp'], + 'color_mlp_mask': ret_fine['color_mlp_mask'], + 'variance': variance.mean(), + 'cdf_fine': ret_fine['cdf'], + 'depth_variance': depth_varaince, + 'weights_sum': weights_sum, + 'weights_max': torch.max(weights, dim=-1, keepdim=True)[0], + 'alpha_sum': ret_fine['alpha_sum'].mean(), + 'alpha_mean': ret_fine['alpha_mean'], + 'gradients': gradients, + 'weights': weights, + 'gradient_error_fine': ret_fine['gradient_error'], + 'inside_sphere': ret_fine['inside_sphere'], + 'sdf': ret_fine['sdf'], + 'sdf_random': sdf_random, + 'blended_color_patch': ret_fine['blended_color_patch'], + 'blended_color_patch_mask': ret_fine['blended_color_patch_mask'], + 'weights_sum_fg': ret_fine['weights_sum_fg'] + } + + return result + + @torch.no_grad() + def sample_z_vals_from_sdfVolume(self, rays_o, rays_d, near, far, sdf_volume, mask_volume): + # ? based on sdf to do importance sampling, seems that too biased on pre-estimation + device = rays_o.device + N_rays = len(rays_o) + n_samples = self.n_samples * 2 + + z_vals = torch.linspace(0.0, 1.0, n_samples).to(device) + z_vals = near + (far - near) * z_vals[None, :] + + if z_vals.shape[0] == 1: + z_vals = z_vals.repeat(N_rays, 1) + + pts = rays_o[:, None, :] + rays_d[:, None, :] * z_vals[..., :, None] + + sdf = self.get_pts_mask_for_conditional_volume(pts.view(-1, 3), sdf_volume).reshape([N_rays, n_samples]) + + new_z_vals = self.up_sample(rays_o, rays_d, z_vals, sdf, self.n_samples, + 200, + conditional_valid_mask_volume=mask_volume, + ) + return new_z_vals + + @torch.no_grad() + def sample_z_vals_from_maskVolume(self, rays_o, rays_d, near, far, mask_volume): # don't use + device = rays_o.device + N_rays = len(rays_o) + n_samples = self.n_samples * 2 + + z_vals = torch.linspace(0.0, 1.0, n_samples).to(device) + z_vals = near + (far - near) * z_vals[None, :] + + if z_vals.shape[0] == 1: + z_vals = z_vals.repeat(N_rays, 1) + + mid_z_vals = (z_vals[:, 1:] + z_vals[:, :-1]) * 0.5 + + pts = rays_o[:, None, :] + rays_d[:, None, :] * mid_z_vals[..., :, None] + + pts_mask = self.get_pts_mask_for_conditional_volume(pts.view(-1, 3), mask_volume).reshape( + [N_rays, n_samples - 1]) + + # empty voxel set to 0.1, non-empty voxel set to 1 + weights = torch.where(pts_mask > 0, torch.ones_like(pts_mask).to(device), + 0.1 * torch.ones_like(pts_mask).to(device)) + + # sample more pts in non-empty voxels + z_samples = sample_pdf(z_vals, weights, self.n_samples, det=True).detach() + return z_samples + + @torch.no_grad() + def filter_pts_by_depthmaps(self, coords, pred_depth_maps, proj_matrices, + partial_vol_origin, voxel_size, + near, far, depth_interval, d_plane_nums): + """ + Use the pred_depthmaps to remove redundant pts (pruned by sdf, sdf always have two sides, the back side is useless) + :param coords: [n, 3] int coords + :param pred_depth_maps: [N_views, 1, h, w] + :param proj_matrices: [N_views, 4, 4] + :param partial_vol_origin: [3] + :param voxel_size: 1 + :param near: 1 + :param far: 1 + :param depth_interval: 1 + :param d_plane_nums: 1 + :return: + """ + device = pred_depth_maps.device + n_views, _, sizeH, sizeW = pred_depth_maps.shape + + if len(partial_vol_origin.shape) == 1: + partial_vol_origin = partial_vol_origin[None, :] + pts = coords * voxel_size + partial_vol_origin + + rs_grid = pts.unsqueeze(0).expand(n_views, -1, -1) + rs_grid = rs_grid.permute(0, 2, 1).contiguous() # [n_views, 3, n_pts] + nV = rs_grid.shape[-1] + rs_grid = torch.cat([rs_grid, torch.ones([n_views, 1, nV]).to(device)], dim=1) # [n_views, 4, n_pts] + + # Project grid + im_p = proj_matrices @ rs_grid # - transform world pts to image UV space # [n_views, 4, n_pts] + im_x, im_y, im_z = im_p[:, 0], im_p[:, 1], im_p[:, 2] + im_x = im_x / im_z + im_y = im_y / im_z + + im_grid = torch.stack([2 * im_x / (sizeW - 1) - 1, 2 * im_y / (sizeH - 1) - 1], dim=-1) + + im_grid = im_grid.view(n_views, 1, -1, 2) + sampled_depths = torch.nn.functional.grid_sample(pred_depth_maps, im_grid, mode='bilinear', + padding_mode='zeros', + align_corners=True)[:, 0, 0, :] # [n_views, n_pts] + sampled_depths_valid = (sampled_depths > 0.5 * near).float() + valid_d_min = (sampled_depths - d_plane_nums * depth_interval).clamp(near.item(), + far.item()) * sampled_depths_valid + valid_d_max = (sampled_depths + d_plane_nums * depth_interval).clamp(near.item(), + far.item()) * sampled_depths_valid + + mask = im_grid.abs() <= 1 + mask = mask[:, 0] # [n_views, n_pts, 2] + mask = (mask.sum(dim=-1) == 2) & (im_z > valid_d_min) & (im_z < valid_d_max) + + mask = mask.view(n_views, -1) + mask = mask.permute(1, 0).contiguous() # [num_pts, nviews] + + mask_final = torch.sum(mask.float(), dim=1, keepdim=False) > 0 + + return mask_final + + @torch.no_grad() + def get_valid_sparse_coords_by_sdf_depthfilter(self, sdf_volume, coords_volume, mask_volume, feature_volume, + pred_depth_maps, proj_matrices, + partial_vol_origin, voxel_size, + near, far, depth_interval, d_plane_nums, + threshold=0.02, maximum_pts=110000): + """ + assume batch size == 1, from the first lod to get sparse voxels + :param sdf_volume: [1, X, Y, Z] + :param coords_volume: [3, X, Y, Z] + :param mask_volume: [1, X, Y, Z] + :param feature_volume: [C, X, Y, Z] + :param threshold: + :return: + """ + device = coords_volume.device + _, dX, dY, dZ = coords_volume.shape + + def prune(sdf_pts, coords_pts, mask_volume, threshold): + occupancy_mask = (torch.abs(sdf_pts) < threshold).squeeze(1) # [num_pts] + valid_coords = coords_pts[occupancy_mask] + + # - filter backside surface by depth maps + mask_filtered = self.filter_pts_by_depthmaps(valid_coords, pred_depth_maps, proj_matrices, + partial_vol_origin, voxel_size, + near, far, depth_interval, d_plane_nums) + valid_coords = valid_coords[mask_filtered] + + # - dilate + occupancy_mask = sparse_to_dense_channel(valid_coords, 1, [dX, dY, dZ], 1, 0, device) # [dX, dY, dZ, 1] + + # - dilate + occupancy_mask = occupancy_mask.float() + occupancy_mask = occupancy_mask.view(1, 1, dX, dY, dZ) + occupancy_mask = F.avg_pool3d(occupancy_mask, kernel_size=7, stride=1, padding=3) + occupancy_mask = occupancy_mask.view(-1, 1) > 0 + + final_mask = torch.logical_and(mask_volume, occupancy_mask)[:, 0] # [num_pts] + + return final_mask, torch.sum(final_mask.float()) + + C, dX, dY, dZ = feature_volume.shape + sdf_volume = sdf_volume.permute(1, 2, 3, 0).contiguous().view(-1, 1) + coords_volume = coords_volume.permute(1, 2, 3, 0).contiguous().view(-1, 3) + mask_volume = mask_volume.permute(1, 2, 3, 0).contiguous().view(-1, 1) + feature_volume = feature_volume.permute(1, 2, 3, 0).contiguous().view(-1, C) + + # - for check + # sdf_volume = torch.rand_like(sdf_volume).float().to(sdf_volume.device) * 0.02 + + final_mask, valid_num = prune(sdf_volume, coords_volume, mask_volume, threshold) + + while (valid_num > maximum_pts) and (threshold > 0.003): + threshold = threshold - 0.002 + final_mask, valid_num = prune(sdf_volume, coords_volume, mask_volume, threshold) + + valid_coords = coords_volume[final_mask] # [N, 3] + valid_feature = feature_volume[final_mask] # [N, C] + + valid_coords = torch.cat([torch.ones([valid_coords.shape[0], 1]).to(valid_coords.device) * 0, + valid_coords], dim=1) # [N, 4], append batch idx + + # ! if the valid_num is still larger than maximum_pts, sample part of pts + if valid_num > maximum_pts: + valid_num = valid_num.long() + occupancy = torch.ones([valid_num]).to(device) > 0 + choice = np.random.choice(valid_num.cpu().numpy(), valid_num.cpu().numpy() - maximum_pts, + replace=False) + ind = torch.nonzero(occupancy).to(device) + occupancy[ind[choice]] = False + valid_coords = valid_coords[occupancy] + valid_feature = valid_feature[occupancy] + + print(threshold, "randomly sample to save memory") + + return valid_coords, valid_feature + + @torch.no_grad() + def get_valid_sparse_coords_by_sdf(self, sdf_volume, coords_volume, mask_volume, feature_volume, threshold=0.02, + maximum_pts=110000): + """ + assume batch size == 1, from the first lod to get sparse voxels + :param sdf_volume: [num_pts, 1] + :param coords_volume: [3, X, Y, Z] + :param mask_volume: [1, X, Y, Z] + :param feature_volume: [C, X, Y, Z] + :param threshold: + :return: + """ + + def prune(sdf_volume, mask_volume, threshold): + occupancy_mask = torch.abs(sdf_volume) < threshold # [num_pts, 1] + + # - dilate + occupancy_mask = occupancy_mask.float() + occupancy_mask = occupancy_mask.view(1, 1, dX, dY, dZ) + occupancy_mask = F.avg_pool3d(occupancy_mask, kernel_size=7, stride=1, padding=3) + occupancy_mask = occupancy_mask.view(-1, 1) > 0 + + final_mask = torch.logical_and(mask_volume, occupancy_mask)[:, 0] # [num_pts] + + return final_mask, torch.sum(final_mask.float()) + + C, dX, dY, dZ = feature_volume.shape + coords_volume = coords_volume.permute(1, 2, 3, 0).contiguous().view(-1, 3) + mask_volume = mask_volume.permute(1, 2, 3, 0).contiguous().view(-1, 1) + feature_volume = feature_volume.permute(1, 2, 3, 0).contiguous().view(-1, C) + + final_mask, valid_num = prune(sdf_volume, mask_volume, threshold) + + while (valid_num > maximum_pts) and (threshold > 0.003): + threshold = threshold - 0.002 + final_mask, valid_num = prune(sdf_volume, mask_volume, threshold) + + valid_coords = coords_volume[final_mask] # [N, 3] + valid_feature = feature_volume[final_mask] # [N, C] + + valid_coords = torch.cat([torch.ones([valid_coords.shape[0], 1]).to(valid_coords.device) * 0, + valid_coords], dim=1) # [N, 4], append batch idx + + # ! if the valid_num is still larger than maximum_pts, sample part of pts + if valid_num > maximum_pts: + device = sdf_volume.device + valid_num = valid_num.long() + occupancy = torch.ones([valid_num]).to(device) > 0 + choice = np.random.choice(valid_num.cpu().numpy(), valid_num.cpu().numpy() - maximum_pts, + replace=False) + ind = torch.nonzero(occupancy).to(device) + occupancy[ind[choice]] = False + valid_coords = valid_coords[occupancy] + valid_feature = valid_feature[occupancy] + + print(threshold, "randomly sample to save memory") + + return valid_coords, valid_feature + + @torch.no_grad() + def extract_fields(self, bound_min, bound_max, resolution, query_func, device, + # * related to conditional feature + **kwargs + ): + N = 64 + X = torch.linspace(bound_min[0], bound_max[0], resolution).to(device).split(N) + Y = torch.linspace(bound_min[1], bound_max[1], resolution).to(device).split(N) + Z = torch.linspace(bound_min[2], bound_max[2], resolution).to(device).split(N) + + u = np.zeros([resolution, resolution, resolution], dtype=np.float32) + with torch.no_grad(): + for xi, xs in enumerate(X): + for yi, ys in enumerate(Y): + for zi, zs in enumerate(Z): + xx, yy, zz = torch.meshgrid(xs, ys, zs) + pts = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1) + + # ! attention, the query function is different for extract geometry and fields + output = query_func(pts, **kwargs) + sdf = output['sdf_pts_scale%d' % kwargs['lod']].reshape(len(xs), len(ys), + len(zs)).detach().cpu().numpy() + + u[xi * N: xi * N + len(xs), yi * N: yi * N + len(ys), zi * N: zi * N + len(zs)] = -1 * sdf + return u + + @torch.no_grad() + def extract_geometry(self, sdf_network, bound_min, bound_max, resolution, threshold, device, occupancy_mask=None, + # * 3d feature volume + **kwargs + ): + # logging.info('threshold: {}'.format(threshold)) + + u = self.extract_fields(bound_min, bound_max, resolution, + lambda pts, **kwargs: sdf_network.sdf(pts, **kwargs), + # - sdf need to be multiplied by -1 + device, + # * 3d feature volume + **kwargs + ) + if occupancy_mask is not None: + dX, dY, dZ = occupancy_mask.shape + empty_mask = 1 - occupancy_mask + empty_mask = empty_mask.view(1, 1, dX, dY, dZ) + # - dilation + # empty_mask = F.avg_pool3d(empty_mask, kernel_size=7, stride=1, padding=3) + empty_mask = F.interpolate(empty_mask, [resolution, resolution, resolution], mode='nearest') + empty_mask = empty_mask.view(resolution, resolution, resolution).cpu().numpy() > 0 + u[empty_mask] = -100 + del empty_mask + + vertices, triangles = mcubes.marching_cubes(u, threshold) + b_max_np = bound_max.detach().cpu().numpy() + b_min_np = bound_min.detach().cpu().numpy() + + vertices = vertices / (resolution - 1.0) * (b_max_np - b_min_np)[None, :] + b_min_np[None, :] + return vertices, triangles, u + + @torch.no_grad() + def extract_depth_maps(self, sdf_network, con_volume, intrinsics, c2ws, H, W, near, far): + """ + extract depth maps from the density volume + :param con_volume: [1, 1+C, dX, dY, dZ] can by con_volume or sdf_volume + :param c2ws: [B, 4, 4] + :param H: + :param W: + :param near: + :param far: + :return: + """ + device = con_volume.device + batch_size = intrinsics.shape[0] + + with torch.no_grad(): + ys, xs = torch.meshgrid(torch.linspace(0, H - 1, H), + torch.linspace(0, W - 1, W)) # pytorch's meshgrid has indexing='ij' + p = torch.stack([xs, ys, torch.ones_like(ys)], dim=-1) # H, W, 3 + + intrinsics_inv = torch.inverse(intrinsics) + + p = p.view(-1, 3).float().to(device) # N_rays, 3 + p = torch.matmul(intrinsics_inv[:, None, :3, :3], p[:, :, None]).squeeze() # Batch, N_rays, 3 + rays_v = p / torch.linalg.norm(p, ord=2, dim=-1, keepdim=True) # Batch, N_rays, 3 + rays_v = torch.matmul(c2ws[:, None, :3, :3], rays_v[:, :, :, None]).squeeze() # Batch, N_rays, 3 + rays_o = c2ws[:, None, :3, 3].expand(rays_v.shape) # Batch, N_rays, 3 + rays_d = rays_v + + rays_o = rays_o.contiguous().view(-1, 3) + rays_d = rays_d.contiguous().view(-1, 3) + + ################## - sphere tracer to extract depth maps ###################### + depth_masks_sphere, depth_maps_sphere = self.ray_tracer.extract_depth_maps( + rays_o, rays_d, + near[None, :].repeat(rays_o.shape[0], 1), + far[None, :].repeat(rays_o.shape[0], 1), + sdf_network, con_volume + ) + + depth_maps = depth_maps_sphere.view(batch_size, 1, H, W) + depth_masks = depth_masks_sphere.view(batch_size, 1, H, W) + + depth_maps = torch.where(depth_masks, depth_maps, + torch.zeros_like(depth_masks.float()).to(device)) # fill invalid pixels by 0 + + return depth_maps, depth_masks diff --git a/SparseNeuS_demo_v1/models/sparse_neus_renderer_normals_new.py b/SparseNeuS_demo_v1/models/sparse_neus_renderer_normals_new.py new file mode 100644 index 0000000000000000000000000000000000000000..34e22aa312312b4fc7e8225e15f1eea5a2de71d1 --- /dev/null +++ b/SparseNeuS_demo_v1/models/sparse_neus_renderer_normals_new.py @@ -0,0 +1,992 @@ +""" +The codes are heavily borrowed from NeuS +""" + +import os +import cv2 as cv +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +import logging +import mcubes +import trimesh +from icecream import ic +from models.render_utils import sample_pdf + +from models.projector import Projector +from tsparse.torchsparse_utils import sparse_to_dense_channel + +from models.fast_renderer import FastRenderer + +from models.patch_projector import PatchProjector + +from models.rays import gen_rays_between + +import pdb + + +class SparseNeuSRenderer(nn.Module): + """ + conditional neus render; + optimize on normalized world space; + warped by nn.Module to support DataParallel traning + """ + + def __init__(self, + rendering_network_outside, + sdf_network, + variance_network, + rendering_network, + n_samples, + n_importance, + n_outside, + perturb, + alpha_type='div', + conf=None + ): + super(SparseNeuSRenderer, self).__init__() + + self.conf = conf + self.base_exp_dir = conf['general.base_exp_dir'] + + # network setups + self.rendering_network_outside = rendering_network_outside + self.sdf_network = sdf_network + self.variance_network = variance_network + self.rendering_network = rendering_network + + self.n_samples = n_samples + self.n_importance = n_importance + self.n_outside = n_outside + self.perturb = perturb + self.alpha_type = alpha_type + + self.rendering_projector = Projector() # used to obtain features for generalized rendering + + self.h_patch_size = self.conf.get_int('model.h_patch_size', default=3) + self.patch_projector = PatchProjector(self.h_patch_size) + + self.ray_tracer = FastRenderer() # ray_tracer to extract depth maps from sdf_volume + + # - fitted rendering or general rendering + try: + self.if_fitted_rendering = self.sdf_network.if_fitted_rendering + except: + self.if_fitted_rendering = False + + def up_sample(self, rays_o, rays_d, z_vals, sdf, n_importance, inv_variance, + conditional_valid_mask_volume=None): + device = rays_o.device + batch_size, n_samples = z_vals.shape + pts = rays_o[:, None, :] + rays_d[:, None, :] * z_vals[..., :, None] # n_rays, n_samples, 3 + + if conditional_valid_mask_volume is not None: + pts_mask = self.get_pts_mask_for_conditional_volume(pts.view(-1, 3), conditional_valid_mask_volume) + pts_mask = pts_mask.reshape(batch_size, n_samples) + pts_mask = pts_mask[:, :-1] * pts_mask[:, 1:] # [batch_size, n_samples-1] + else: + pts_mask = torch.ones([batch_size, n_samples]).to(pts.device) + + sdf = sdf.reshape(batch_size, n_samples) + prev_sdf, next_sdf = sdf[:, :-1], sdf[:, 1:] + prev_z_vals, next_z_vals = z_vals[:, :-1], z_vals[:, 1:] + mid_sdf = (prev_sdf + next_sdf) * 0.5 + dot_val = None + if self.alpha_type == 'uniform': + dot_val = torch.ones([batch_size, n_samples - 1]) * -1.0 + else: + dot_val = (next_sdf - prev_sdf) / (next_z_vals - prev_z_vals + 1e-5) + prev_dot_val = torch.cat([torch.zeros([batch_size, 1]).to(device), dot_val[:, :-1]], dim=-1) + dot_val = torch.stack([prev_dot_val, dot_val], dim=-1) + dot_val, _ = torch.min(dot_val, dim=-1, keepdim=False) + dot_val = dot_val.clip(-10.0, 0.0) * pts_mask + dist = (next_z_vals - prev_z_vals) + prev_esti_sdf = mid_sdf - dot_val * dist * 0.5 + next_esti_sdf = mid_sdf + dot_val * dist * 0.5 + prev_cdf = torch.sigmoid(prev_esti_sdf * inv_variance) + next_cdf = torch.sigmoid(next_esti_sdf * inv_variance) + alpha_sdf = (prev_cdf - next_cdf + 1e-5) / (prev_cdf + 1e-5) + + alpha = alpha_sdf + + # - apply pts_mask + alpha = pts_mask * alpha + + weights = alpha * torch.cumprod( + torch.cat([torch.ones([batch_size, 1]).to(device), 1. - alpha + 1e-7], -1), -1)[:, :-1] + + z_samples = sample_pdf(z_vals, weights, n_importance, det=True).detach() + return z_samples + + def cat_z_vals(self, rays_o, rays_d, z_vals, new_z_vals, sdf, lod, + sdf_network, gru_fusion, + # * related to conditional feature + conditional_volume=None, + conditional_valid_mask_volume=None + ): + device = rays_o.device + batch_size, n_samples = z_vals.shape + _, n_importance = new_z_vals.shape + pts = rays_o[:, None, :] + rays_d[:, None, :] * new_z_vals[..., :, None] + + if conditional_valid_mask_volume is not None: + pts_mask = self.get_pts_mask_for_conditional_volume(pts.view(-1, 3), conditional_valid_mask_volume) + pts_mask = pts_mask.reshape(batch_size, n_importance) + pts_mask_bool = (pts_mask > 0).view(-1) + else: + pts_mask = torch.ones([batch_size, n_importance]).to(pts.device) + + new_sdf = torch.ones([batch_size * n_importance, 1]).to(pts.dtype).to(device) * 100 + + if torch.sum(pts_mask) > 1: + new_outputs = sdf_network.sdf(pts.reshape(-1, 3)[pts_mask_bool], conditional_volume, lod=lod) + new_sdf[pts_mask_bool] = new_outputs['sdf_pts_scale%d' % lod] # .reshape(batch_size, n_importance) + + new_sdf = new_sdf.view(batch_size, n_importance) + + z_vals = torch.cat([z_vals, new_z_vals], dim=-1) + sdf = torch.cat([sdf, new_sdf], dim=-1) + + z_vals, index = torch.sort(z_vals, dim=-1) + xx = torch.arange(batch_size)[:, None].expand(batch_size, n_samples + n_importance).reshape(-1) + index = index.reshape(-1) + sdf = sdf[(xx, index)].reshape(batch_size, n_samples + n_importance) + + return z_vals, sdf + + @torch.no_grad() + def get_pts_mask_for_conditional_volume(self, pts, mask_volume): + """ + + :param pts: [N, 3] + :param mask_volume: [1, 1, X, Y, Z] + :return: + """ + num_pts = pts.shape[0] + pts = pts.view(1, 1, 1, num_pts, 3) # - should be in range (-1, 1) + + pts = torch.flip(pts, dims=[-1]) + + pts_mask = F.grid_sample(mask_volume, pts, mode='nearest') # [1, c, 1, 1, num_pts] + pts_mask = pts_mask.view(-1, num_pts).permute(1, 0).contiguous() # [num_pts, 1] + + return pts_mask + + def render_core(self, + rays_o, + rays_d, + z_vals, + sample_dist, + lod, + sdf_network, + rendering_network, + background_alpha=None, # - no use here + background_sampled_color=None, # - no use here + background_rgb=None, # - no use here + alpha_inter_ratio=0.0, + # * related to conditional feature + conditional_volume=None, + conditional_valid_mask_volume=None, + # * 2d feature maps + feature_maps=None, + color_maps=None, + w2cs=None, + intrinsics=None, + img_wh=None, + query_c2w=None, # - used for testing + if_general_rendering=True, + if_render_with_grad=True, + # * used for blending mlp rendering network + img_index=None, + rays_uv=None, + # * used for clear bg and fg + bg_num=0 + ): + device = rays_o.device + N_rays = rays_o.shape[0] + _, n_samples = z_vals.shape + dists = z_vals[..., 1:] - z_vals[..., :-1] + dists = torch.cat([dists, torch.Tensor([sample_dist]).expand(dists[..., :1].shape).to(device)], -1) + + mid_z_vals = z_vals + dists * 0.5 + mid_dists = mid_z_vals[..., 1:] - mid_z_vals[..., :-1] + + pts = rays_o[:, None, :] + rays_d[:, None, :] * mid_z_vals[..., :, None] # n_rays, n_samples, 3 + dirs = rays_d[:, None, :].expand(pts.shape) + + pts = pts.reshape(-1, 3) + dirs = dirs.reshape(-1, 3) + + # * if conditional_volume is restored from sparse volume, need mask for pts + if conditional_valid_mask_volume is not None: + pts_mask = self.get_pts_mask_for_conditional_volume(pts, conditional_valid_mask_volume) + pts_mask = pts_mask.reshape(N_rays, n_samples).float().detach() + pts_mask_bool = (pts_mask > 0).view(-1) + + if torch.sum(pts_mask_bool.float()) < 1: # ! when render out image, may meet this problem + pts_mask_bool[:100] = True + + else: + pts_mask = torch.ones([N_rays, n_samples]).to(pts.device) + # import ipdb; ipdb.set_trace() + # pts_valid = pts[pts_mask_bool] + sdf_nn_output = sdf_network.sdf(pts[pts_mask_bool], conditional_volume, lod=lod) + + sdf = torch.ones([N_rays * n_samples, 1]).to(pts.dtype).to(device) * 100 + sdf[pts_mask_bool] = sdf_nn_output['sdf_pts_scale%d' % lod] # [N_rays*n_samples, 1] + feature_vector_valid = sdf_nn_output['sdf_features_pts_scale%d' % lod] + feature_vector = torch.zeros([N_rays * n_samples, feature_vector_valid.shape[1]]).to(pts.dtype).to(device) + feature_vector[pts_mask_bool] = feature_vector_valid + + # * estimate alpha from sdf + gradients = torch.zeros([N_rays * n_samples, 3]).to(pts.dtype).to(device) + # import ipdb; ipdb.set_trace() + gradients[pts_mask_bool] = sdf_network.gradient( + pts[pts_mask_bool], conditional_volume, lod=lod).squeeze() + + sampled_color_mlp = None + rendering_valid_mask_mlp = None + sampled_color_patch = None + rendering_patch_mask = None + + if self.if_fitted_rendering: # used for fine-tuning + position_latent = sdf_nn_output['sampled_latent_scale%d' % lod] + sampled_color_mlp = torch.zeros([N_rays * n_samples, 3]).to(pts.dtype).to(device) + sampled_color_mlp_mask = torch.zeros([N_rays * n_samples, 1]).to(pts.dtype).to(device) + + # - extract pixel + pts_pixel_color, pts_pixel_mask = self.patch_projector.pixel_warp( + pts[pts_mask_bool][:, None, :], color_maps, intrinsics, + w2cs, img_wh=None) # [N_rays * n_samples,1, N_views, 3] , [N_rays*n_samples, 1, N_views] + pts_pixel_color = pts_pixel_color[:, 0, :, :] # [N_rays * n_samples, N_views, 3] + pts_pixel_mask = pts_pixel_mask[:, 0, :] # [N_rays*n_samples, N_views] + + # - extract patch + if_patch_blending = False if rays_uv is None else True + pts_patch_color, pts_patch_mask = None, None + if if_patch_blending: + pts_patch_color, pts_patch_mask = self.patch_projector.patch_warp( + pts.reshape([N_rays, n_samples, 3]), + rays_uv, gradients.reshape([N_rays, n_samples, 3]), + color_maps, + intrinsics[0], intrinsics, + query_c2w[0], torch.inverse(w2cs), img_wh=None + ) # (N_rays, n_samples, N_src, Npx, 3), (N_rays, n_samples, N_src, Npx) + N_src, Npx = pts_patch_mask.shape[2:] + pts_patch_color = pts_patch_color.view(N_rays * n_samples, N_src, Npx, 3)[pts_mask_bool] + pts_patch_mask = pts_patch_mask.view(N_rays * n_samples, N_src, Npx)[pts_mask_bool] + + sampled_color_patch = torch.zeros([N_rays * n_samples, Npx, 3]).to(device) + sampled_color_patch_mask = torch.zeros([N_rays * n_samples, 1]).to(device) + + sampled_color_mlp_, sampled_color_mlp_mask_, \ + sampled_color_patch_, sampled_color_patch_mask_ = sdf_network.color_blend( + pts[pts_mask_bool], + position_latent, + gradients[pts_mask_bool], + dirs[pts_mask_bool], + feature_vector[pts_mask_bool], + img_index=img_index, + pts_pixel_color=pts_pixel_color, + pts_pixel_mask=pts_pixel_mask, + pts_patch_color=pts_patch_color, + pts_patch_mask=pts_patch_mask + + ) # [n, 3], [n, 1] + sampled_color_mlp[pts_mask_bool] = sampled_color_mlp_ + sampled_color_mlp_mask[pts_mask_bool] = sampled_color_mlp_mask_.float() + sampled_color_mlp = sampled_color_mlp.view(N_rays, n_samples, 3) + sampled_color_mlp_mask = sampled_color_mlp_mask.view(N_rays, n_samples) + rendering_valid_mask_mlp = torch.mean(pts_mask * sampled_color_mlp_mask, dim=-1, keepdim=True) > 0.5 + + # patch blending + if if_patch_blending: + sampled_color_patch[pts_mask_bool] = sampled_color_patch_ + sampled_color_patch_mask[pts_mask_bool] = sampled_color_patch_mask_.float() + sampled_color_patch = sampled_color_patch.view(N_rays, n_samples, Npx, 3) + sampled_color_patch_mask = sampled_color_patch_mask.view(N_rays, n_samples) + rendering_patch_mask = torch.mean(pts_mask * sampled_color_patch_mask, dim=-1, + keepdim=True) > 0.5 # [N_rays, 1] + else: + sampled_color_patch, rendering_patch_mask = None, None + + if if_general_rendering: # used for general training + # [512, 128, 16]; [4, 512, 128, 59]; [4, 512, 128, 4] + ren_geo_feats, ren_rgb_feats, ren_ray_diff, ren_mask, _, _ = self.rendering_projector.compute_view_independent( + pts.view(N_rays, n_samples, 3), + # * 3d geometry feature volumes + geometryVolume=conditional_volume[0], + geometryVolumeMask=conditional_valid_mask_volume[0], + sdf_network=sdf_network, + lod=lod, + # * 2d rendering feature maps + rendering_feature_maps=feature_maps, # [n_views, 56, 256, 256] + color_maps=color_maps, + w2cs=w2cs, + intrinsics=intrinsics, + img_wh=img_wh, + query_img_idx=0, # the index of the N_views dim for rendering + query_c2w=query_c2w, + ) + + # (N_rays, n_samples, 3) + if if_render_with_grad: + # import ipdb; ipdb.set_trace() + # [nrays, 3] [nrays, 1] + sampled_color, rendering_valid_mask = rendering_network( + ren_geo_feats, ren_rgb_feats, ren_ray_diff, ren_mask) + # import ipdb; ipdb.set_trace() + else: + with torch.no_grad(): + sampled_color, rendering_valid_mask = rendering_network( + ren_geo_feats, ren_rgb_feats, ren_ray_diff, ren_mask) + else: + sampled_color, rendering_valid_mask = None, None + + inv_variance = self.variance_network(feature_vector)[:, :1].clip(1e-6, 1e6) + + true_dot_val = (dirs * gradients).sum(-1, keepdim=True) # * calculate + + iter_cos = -(F.relu(-true_dot_val * 0.5 + 0.5) * (1.0 - alpha_inter_ratio) + F.relu( + -true_dot_val) * alpha_inter_ratio) # always non-positive + + iter_cos = iter_cos * pts_mask.view(-1, 1) + + true_estimate_sdf_half_next = sdf + iter_cos.clip(-10.0, 10.0) * dists.reshape(-1, 1) * 0.5 + true_estimate_sdf_half_prev = sdf - iter_cos.clip(-10.0, 10.0) * dists.reshape(-1, 1) * 0.5 + + prev_cdf = torch.sigmoid(true_estimate_sdf_half_prev * inv_variance) + next_cdf = torch.sigmoid(true_estimate_sdf_half_next * inv_variance) + + p = prev_cdf - next_cdf + c = prev_cdf + + if self.alpha_type == 'div': + alpha_sdf = ((p + 1e-5) / (c + 1e-5)).reshape(N_rays, n_samples).clip(0.0, 1.0) + elif self.alpha_type == 'uniform': + uniform_estimate_sdf_half_next = sdf - dists.reshape(-1, 1) * 0.5 + uniform_estimate_sdf_half_prev = sdf + dists.reshape(-1, 1) * 0.5 + uniform_prev_cdf = torch.sigmoid(uniform_estimate_sdf_half_prev * inv_variance) + uniform_next_cdf = torch.sigmoid(uniform_estimate_sdf_half_next * inv_variance) + uniform_alpha = F.relu( + (uniform_prev_cdf - uniform_next_cdf + 1e-5) / (uniform_prev_cdf + 1e-5)).reshape( + N_rays, n_samples).clip(0.0, 1.0) + alpha_sdf = uniform_alpha + else: + assert False + + alpha = alpha_sdf + + # - apply pts_mask + alpha = alpha * pts_mask + + # pts_radius = torch.linalg.norm(pts, ord=2, dim=-1, keepdim=True).reshape(N_rays, n_samples) + # inside_sphere = (pts_radius < 1.0).float().detach() + # relax_inside_sphere = (pts_radius < 1.2).float().detach() + inside_sphere = pts_mask + relax_inside_sphere = pts_mask + + weights = alpha * torch.cumprod(torch.cat([torch.ones([N_rays, 1]).to(device), 1. - alpha + 1e-7], -1), -1)[:, + :-1] # n_rays, n_samples + weights_sum = weights.sum(dim=-1, keepdim=True) + alpha_sum = alpha.sum(dim=-1, keepdim=True) + + if bg_num > 0: + weights_sum_fg = weights[:, :-bg_num].sum(dim=-1, keepdim=True) + else: + weights_sum_fg = weights_sum + + if sampled_color is not None: + color = (sampled_color * weights[:, :, None]).sum(dim=1) + else: + color = None + # import ipdb; ipdb.set_trace() + + if background_rgb is not None and color is not None: + color = color + background_rgb * (1.0 - weights_sum) + # print("color device:" + str(color.device)) + # if color is not None: + # # import ipdb; ipdb.set_trace() + # color = color + (1.0 - weights_sum) + + + ###################* mlp color rendering ##################### + color_mlp = None + # import ipdb; ipdb.set_trace() + if sampled_color_mlp is not None: + color_mlp = (sampled_color_mlp * weights[:, :, None]).sum(dim=1) + + if background_rgb is not None and color_mlp is not None: + color_mlp = color_mlp + background_rgb * (1.0 - weights_sum) + + ############################ * patch blending ################ + blended_color_patch = None + if sampled_color_patch is not None: + blended_color_patch = (sampled_color_patch * weights[:, :, None, None]).sum(dim=1) # [N_rays, Npx, 3] + + ###################################################### + + gradient_error = (torch.linalg.norm(gradients.reshape(N_rays, n_samples, 3), ord=2, + dim=-1) - 1.0) ** 2 + # ! the gradient normal should be masked out, the pts out of the bounding box should also be penalized + gradient_error = (pts_mask * gradient_error).sum() / ( + (pts_mask).sum() + 1e-5) + + depth = (mid_z_vals * weights[:, :n_samples]).sum(dim=1, keepdim=True) + # print("[TEST]: weights_sum in render_core", weights_sum.mean()) + # print("[TEST]: weights_sum in render_core NAN number", weights_sum.isnan().sum()) + # if weights_sum.isnan().sum() > 0: + # import ipdb; ipdb.set_trace() + return { + 'color': color, + 'color_mask': rendering_valid_mask, # (N_rays, 1) + 'color_mlp': color_mlp, + 'color_mlp_mask': rendering_valid_mask_mlp, + 'sdf': sdf, # (N_rays, n_samples) + 'depth': depth, # (N_rays, 1) + 'dists': dists, + 'gradients': gradients.reshape(N_rays, n_samples, 3), + 'variance': 1.0 / inv_variance, + 'mid_z_vals': mid_z_vals, + 'weights': weights, + 'weights_sum': weights_sum, + 'alpha_sum': alpha_sum, + 'alpha_mean': alpha.mean(), + 'cdf': c.reshape(N_rays, n_samples), + 'gradient_error': gradient_error, + 'inside_sphere': inside_sphere, + 'blended_color_patch': blended_color_patch, + 'blended_color_patch_mask': rendering_patch_mask, + 'weights_sum_fg': weights_sum_fg + } + + def render(self, rays_o, rays_d, near, far, sdf_network, rendering_network, + perturb_overwrite=-1, + background_rgb=None, + alpha_inter_ratio=0.0, + # * related to conditional feature + lod=None, + conditional_volume=None, + conditional_valid_mask_volume=None, + # * 2d feature maps + feature_maps=None, + color_maps=None, + w2cs=None, + intrinsics=None, + img_wh=None, + query_c2w=None, # -used for testing + if_general_rendering=True, + if_render_with_grad=True, + # * used for blending mlp rendering network + img_index=None, + rays_uv=None, + # * importance sample for second lod network + pre_sample=False, # no use here + # * for clear foreground + bg_ratio=0.0 + ): + device = rays_o.device + N_rays = len(rays_o) + # sample_dist = 2.0 / self.n_samples + sample_dist = ((far - near) / self.n_samples).mean().item() + z_vals = torch.linspace(0.0, 1.0, self.n_samples).to(device) + z_vals = near + (far - near) * z_vals[None, :] + + bg_num = int(self.n_samples * bg_ratio) + + if z_vals.shape[0] == 1: + z_vals = z_vals.repeat(N_rays, 1) + + if bg_num > 0: + z_vals_bg = z_vals[:, self.n_samples - bg_num:] + z_vals = z_vals[:, :self.n_samples - bg_num] + + n_samples = self.n_samples - bg_num + perturb = self.perturb + + # - significantly speed up training, for the second lod network + if pre_sample: + z_vals = self.sample_z_vals_from_maskVolume(rays_o, rays_d, near, far, + conditional_valid_mask_volume) + + if perturb_overwrite >= 0: + perturb = perturb_overwrite + if perturb > 0: + # get intervals between samples + mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1]) + upper = torch.cat([mids, z_vals[..., -1:]], -1) + lower = torch.cat([z_vals[..., :1], mids], -1) + # stratified samples in those intervals + t_rand = torch.rand(z_vals.shape).to(device) + z_vals = lower + (upper - lower) * t_rand + + background_alpha = None + background_sampled_color = None + z_val_before = z_vals.clone() + # Up sample + if self.n_importance > 0: + with torch.no_grad(): + pts = rays_o[:, None, :] + rays_d[:, None, :] * z_vals[..., :, None] + + sdf_outputs = sdf_network.sdf( + pts.reshape(-1, 3), conditional_volume, lod=lod) + # pdb.set_trace() + sdf = sdf_outputs['sdf_pts_scale%d' % lod].reshape(N_rays, self.n_samples - bg_num) + + n_steps = 4 + for i in range(n_steps): + new_z_vals = self.up_sample(rays_o, rays_d, z_vals, sdf, self.n_importance // n_steps, + 64 * 2 ** i, + conditional_valid_mask_volume=conditional_valid_mask_volume, + ) + + # if new_z_vals.isnan().sum() > 0: + # import ipdb; ipdb.set_trace() + + z_vals, sdf = self.cat_z_vals( + rays_o, rays_d, z_vals, new_z_vals, sdf, lod, + sdf_network, gru_fusion=False, + conditional_volume=conditional_volume, + conditional_valid_mask_volume=conditional_valid_mask_volume, + ) + + del sdf + + n_samples = self.n_samples + self.n_importance + + # Background + ret_outside = None + + # Render + if bg_num > 0: + z_vals = torch.cat([z_vals, z_vals_bg], dim=1) + # if z_vals.isnan().sum() > 0: + # import ipdb; ipdb.set_trace() + ret_fine = self.render_core(rays_o, + rays_d, + z_vals, + sample_dist, + lod, + sdf_network, + rendering_network, + background_rgb=background_rgb, + background_alpha=background_alpha, + background_sampled_color=background_sampled_color, + alpha_inter_ratio=alpha_inter_ratio, + # * related to conditional feature + conditional_volume=conditional_volume, + conditional_valid_mask_volume=conditional_valid_mask_volume, + # * 2d feature maps + feature_maps=feature_maps, + color_maps=color_maps, + w2cs=w2cs, + intrinsics=intrinsics, + img_wh=img_wh, + query_c2w=query_c2w, + if_general_rendering=if_general_rendering, + if_render_with_grad=if_render_with_grad, + # * used for blending mlp rendering network + img_index=img_index, + rays_uv=rays_uv + ) + + color_fine = ret_fine['color'] + + if self.n_outside > 0: + color_fine_mask = torch.logical_or(ret_fine['color_mask'], ret_outside['color_mask']) + else: + color_fine_mask = ret_fine['color_mask'] + + weights = ret_fine['weights'] + weights_sum = ret_fine['weights_sum'] + + gradients = ret_fine['gradients'] + mid_z_vals = ret_fine['mid_z_vals'] + + # depth = (mid_z_vals * weights[:, :n_samples]).sum(dim=1, keepdim=True) + depth = ret_fine['depth'] + depth_varaince = ((mid_z_vals - depth) ** 2 * weights[:, :n_samples]).sum(dim=-1, keepdim=True) + variance = ret_fine['variance'].reshape(N_rays, n_samples).mean(dim=-1, keepdim=True) + + # - randomly sample points from the volume, and maximize the sdf + pts_random = torch.rand([1024, 3]).float().to(device) * 2 - 1 # normalized to (-1, 1) + sdf_random = sdf_network.sdf(pts_random, conditional_volume, lod=lod)['sdf_pts_scale%d' % lod] + + result = { + 'depth': depth, + 'color_fine': color_fine, + 'color_fine_mask': color_fine_mask, + 'color_outside': ret_outside['color'] if ret_outside is not None else None, + 'color_outside_mask': ret_outside['color_mask'] if ret_outside is not None else None, + 'color_mlp': ret_fine['color_mlp'], + 'color_mlp_mask': ret_fine['color_mlp_mask'], + 'variance': variance.mean(), + 'cdf_fine': ret_fine['cdf'], + 'depth_variance': depth_varaince, + 'weights_sum': weights_sum, + 'weights_max': torch.max(weights, dim=-1, keepdim=True)[0], + 'alpha_sum': ret_fine['alpha_sum'].mean(), + 'alpha_mean': ret_fine['alpha_mean'], + 'gradients': gradients, + 'weights': weights, + 'gradient_error_fine': ret_fine['gradient_error'], + 'inside_sphere': ret_fine['inside_sphere'], + 'sdf': ret_fine['sdf'], + 'sdf_random': sdf_random, + 'blended_color_patch': ret_fine['blended_color_patch'], + 'blended_color_patch_mask': ret_fine['blended_color_patch_mask'], + 'weights_sum_fg': ret_fine['weights_sum_fg'] + } + + return result + + @torch.no_grad() + def sample_z_vals_from_sdfVolume(self, rays_o, rays_d, near, far, sdf_volume, mask_volume): + # ? based on sdf to do importance sampling, seems that too biased on pre-estimation + device = rays_o.device + N_rays = len(rays_o) + n_samples = self.n_samples * 2 + + z_vals = torch.linspace(0.0, 1.0, n_samples).to(device) + z_vals = near + (far - near) * z_vals[None, :] + + if z_vals.shape[0] == 1: + z_vals = z_vals.repeat(N_rays, 1) + + pts = rays_o[:, None, :] + rays_d[:, None, :] * z_vals[..., :, None] + + sdf = self.get_pts_mask_for_conditional_volume(pts.view(-1, 3), sdf_volume).reshape([N_rays, n_samples]) + + new_z_vals = self.up_sample(rays_o, rays_d, z_vals, sdf, self.n_samples, + 200, + conditional_valid_mask_volume=mask_volume, + ) + return new_z_vals + + @torch.no_grad() + def sample_z_vals_from_maskVolume(self, rays_o, rays_d, near, far, mask_volume): # don't use + device = rays_o.device + N_rays = len(rays_o) + n_samples = self.n_samples * 2 + + z_vals = torch.linspace(0.0, 1.0, n_samples).to(device) + z_vals = near + (far - near) * z_vals[None, :] + + if z_vals.shape[0] == 1: + z_vals = z_vals.repeat(N_rays, 1) + + mid_z_vals = (z_vals[:, 1:] + z_vals[:, :-1]) * 0.5 + + pts = rays_o[:, None, :] + rays_d[:, None, :] * mid_z_vals[..., :, None] + + pts_mask = self.get_pts_mask_for_conditional_volume(pts.view(-1, 3), mask_volume).reshape( + [N_rays, n_samples - 1]) + + # empty voxel set to 0.1, non-empty voxel set to 1 + weights = torch.where(pts_mask > 0, torch.ones_like(pts_mask).to(device), + 0.1 * torch.ones_like(pts_mask).to(device)) + + # sample more pts in non-empty voxels + z_samples = sample_pdf(z_vals, weights, self.n_samples, det=True).detach() + return z_samples + + @torch.no_grad() + def filter_pts_by_depthmaps(self, coords, pred_depth_maps, proj_matrices, + partial_vol_origin, voxel_size, + near, far, depth_interval, d_plane_nums): + """ + Use the pred_depthmaps to remove redundant pts (pruned by sdf, sdf always have two sides, the back side is useless) + :param coords: [n, 3] int coords + :param pred_depth_maps: [N_views, 1, h, w] + :param proj_matrices: [N_views, 4, 4] + :param partial_vol_origin: [3] + :param voxel_size: 1 + :param near: 1 + :param far: 1 + :param depth_interval: 1 + :param d_plane_nums: 1 + :return: + """ + device = pred_depth_maps.device + n_views, _, sizeH, sizeW = pred_depth_maps.shape + + if len(partial_vol_origin.shape) == 1: + partial_vol_origin = partial_vol_origin[None, :] + pts = coords * voxel_size + partial_vol_origin + + rs_grid = pts.unsqueeze(0).expand(n_views, -1, -1) + rs_grid = rs_grid.permute(0, 2, 1).contiguous() # [n_views, 3, n_pts] + nV = rs_grid.shape[-1] + rs_grid = torch.cat([rs_grid, torch.ones([n_views, 1, nV]).to(device)], dim=1) # [n_views, 4, n_pts] + + # Project grid + im_p = proj_matrices @ rs_grid # - transform world pts to image UV space # [n_views, 4, n_pts] + im_x, im_y, im_z = im_p[:, 0], im_p[:, 1], im_p[:, 2] + im_x = im_x / im_z + im_y = im_y / im_z + + im_grid = torch.stack([2 * im_x / (sizeW - 1) - 1, 2 * im_y / (sizeH - 1) - 1], dim=-1) + + im_grid = im_grid.view(n_views, 1, -1, 2) + sampled_depths = torch.nn.functional.grid_sample(pred_depth_maps, im_grid, mode='bilinear', + padding_mode='zeros', + align_corners=True)[:, 0, 0, :] # [n_views, n_pts] + sampled_depths_valid = (sampled_depths > 0.5 * near).float() + valid_d_min = (sampled_depths - d_plane_nums * depth_interval).clamp(near.item(), + far.item()) * sampled_depths_valid + valid_d_max = (sampled_depths + d_plane_nums * depth_interval).clamp(near.item(), + far.item()) * sampled_depths_valid + + mask = im_grid.abs() <= 1 + mask = mask[:, 0] # [n_views, n_pts, 2] + mask = (mask.sum(dim=-1) == 2) & (im_z > valid_d_min) & (im_z < valid_d_max) + + mask = mask.view(n_views, -1) + mask = mask.permute(1, 0).contiguous() # [num_pts, nviews] + + mask_final = torch.sum(mask.float(), dim=1, keepdim=False) > 0 + + return mask_final + + @torch.no_grad() + def get_valid_sparse_coords_by_sdf_depthfilter(self, sdf_volume, coords_volume, mask_volume, feature_volume, + pred_depth_maps, proj_matrices, + partial_vol_origin, voxel_size, + near, far, depth_interval, d_plane_nums, + threshold=0.02, maximum_pts=110000): + """ + assume batch size == 1, from the first lod to get sparse voxels + :param sdf_volume: [1, X, Y, Z] + :param coords_volume: [3, X, Y, Z] + :param mask_volume: [1, X, Y, Z] + :param feature_volume: [C, X, Y, Z] + :param threshold: + :return: + """ + device = coords_volume.device + _, dX, dY, dZ = coords_volume.shape + + def prune(sdf_pts, coords_pts, mask_volume, threshold): + occupancy_mask = (torch.abs(sdf_pts) < threshold).squeeze(1) # [num_pts] + valid_coords = coords_pts[occupancy_mask] + + # - filter backside surface by depth maps + mask_filtered = self.filter_pts_by_depthmaps(valid_coords, pred_depth_maps, proj_matrices, + partial_vol_origin, voxel_size, + near, far, depth_interval, d_plane_nums) + valid_coords = valid_coords[mask_filtered] + + # - dilate + occupancy_mask = sparse_to_dense_channel(valid_coords, 1, [dX, dY, dZ], 1, 0, device) # [dX, dY, dZ, 1] + + # - dilate + occupancy_mask = occupancy_mask.float() + occupancy_mask = occupancy_mask.view(1, 1, dX, dY, dZ) + occupancy_mask = F.avg_pool3d(occupancy_mask, kernel_size=7, stride=1, padding=3) + occupancy_mask = occupancy_mask.view(-1, 1) > 0 + + final_mask = torch.logical_and(mask_volume, occupancy_mask)[:, 0] # [num_pts] + + return final_mask, torch.sum(final_mask.float()) + + C, dX, dY, dZ = feature_volume.shape + sdf_volume = sdf_volume.permute(1, 2, 3, 0).contiguous().view(-1, 1) + coords_volume = coords_volume.permute(1, 2, 3, 0).contiguous().view(-1, 3) + mask_volume = mask_volume.permute(1, 2, 3, 0).contiguous().view(-1, 1) + feature_volume = feature_volume.permute(1, 2, 3, 0).contiguous().view(-1, C) + + # - for check + # sdf_volume = torch.rand_like(sdf_volume).float().to(sdf_volume.device) * 0.02 + + final_mask, valid_num = prune(sdf_volume, coords_volume, mask_volume, threshold) + + while (valid_num > maximum_pts) and (threshold > 0.003): + threshold = threshold - 0.002 + final_mask, valid_num = prune(sdf_volume, coords_volume, mask_volume, threshold) + + valid_coords = coords_volume[final_mask] # [N, 3] + valid_feature = feature_volume[final_mask] # [N, C] + + valid_coords = torch.cat([torch.ones([valid_coords.shape[0], 1]).to(valid_coords.device) * 0, + valid_coords], dim=1) # [N, 4], append batch idx + + # ! if the valid_num is still larger than maximum_pts, sample part of pts + if valid_num > maximum_pts: + valid_num = valid_num.long() + occupancy = torch.ones([valid_num]).to(device) > 0 + choice = np.random.choice(valid_num.cpu().numpy(), valid_num.cpu().numpy() - maximum_pts, + replace=False) + ind = torch.nonzero(occupancy).to(device) + occupancy[ind[choice]] = False + valid_coords = valid_coords[occupancy] + valid_feature = valid_feature[occupancy] + + print(threshold, "randomly sample to save memory") + + return valid_coords, valid_feature + + @torch.no_grad() + def get_valid_sparse_coords_by_sdf(self, sdf_volume, coords_volume, mask_volume, feature_volume, threshold=0.02, + maximum_pts=110000): + """ + assume batch size == 1, from the first lod to get sparse voxels + :param sdf_volume: [num_pts, 1] + :param coords_volume: [3, X, Y, Z] + :param mask_volume: [1, X, Y, Z] + :param feature_volume: [C, X, Y, Z] + :param threshold: + :return: + """ + + def prune(sdf_volume, mask_volume, threshold): + occupancy_mask = torch.abs(sdf_volume) < threshold # [num_pts, 1] + + # - dilate + occupancy_mask = occupancy_mask.float() + occupancy_mask = occupancy_mask.view(1, 1, dX, dY, dZ) + occupancy_mask = F.avg_pool3d(occupancy_mask, kernel_size=7, stride=1, padding=3) + occupancy_mask = occupancy_mask.view(-1, 1) > 0 + + final_mask = torch.logical_and(mask_volume, occupancy_mask)[:, 0] # [num_pts] + + return final_mask, torch.sum(final_mask.float()) + + C, dX, dY, dZ = feature_volume.shape + coords_volume = coords_volume.permute(1, 2, 3, 0).contiguous().view(-1, 3) + mask_volume = mask_volume.permute(1, 2, 3, 0).contiguous().view(-1, 1) + feature_volume = feature_volume.permute(1, 2, 3, 0).contiguous().view(-1, C) + + final_mask, valid_num = prune(sdf_volume, mask_volume, threshold) + + while (valid_num > maximum_pts) and (threshold > 0.003): + threshold = threshold - 0.002 + final_mask, valid_num = prune(sdf_volume, mask_volume, threshold) + + valid_coords = coords_volume[final_mask] # [N, 3] + valid_feature = feature_volume[final_mask] # [N, C] + + valid_coords = torch.cat([torch.ones([valid_coords.shape[0], 1]).to(valid_coords.device) * 0, + valid_coords], dim=1) # [N, 4], append batch idx + + # ! if the valid_num is still larger than maximum_pts, sample part of pts + if valid_num > maximum_pts: + device = sdf_volume.device + valid_num = valid_num.long() + occupancy = torch.ones([valid_num]).to(device) > 0 + choice = np.random.choice(valid_num.cpu().numpy(), valid_num.cpu().numpy() - maximum_pts, + replace=False) + ind = torch.nonzero(occupancy).to(device) + occupancy[ind[choice]] = False + valid_coords = valid_coords[occupancy] + valid_feature = valid_feature[occupancy] + + print(threshold, "randomly sample to save memory") + + return valid_coords, valid_feature + + @torch.no_grad() + def extract_fields(self, bound_min, bound_max, resolution, query_func, device, + # * related to conditional feature + **kwargs + ): + N = 64 + X = torch.linspace(bound_min[0], bound_max[0], resolution).split(N) + Y = torch.linspace(bound_min[1], bound_max[1], resolution).split(N) + Z = torch.linspace(bound_min[2], bound_max[2], resolution).split(N) + + u = np.zeros([resolution, resolution, resolution], dtype=np.float32) + with torch.no_grad(): + for xi, xs in enumerate(X): + for yi, ys in enumerate(Y): + for zi, zs in enumerate(Z): + xx, yy, zz = torch.meshgrid(xs, ys, zs) + pts = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1).to(device) + + # ! attention, the query function is different for extract geometry and fields + output = query_func(pts, **kwargs) + sdf = output['sdf_pts_scale%d' % kwargs['lod']].reshape(len(xs), len(ys), + len(zs)).detach().cpu().numpy() + + u[xi * N: xi * N + len(xs), yi * N: yi * N + len(ys), zi * N: zi * N + len(zs)] = -1 * sdf + return u + + @torch.no_grad() + def extract_geometry(self, sdf_network, bound_min, bound_max, resolution, threshold, device, occupancy_mask=None, + # * 3d feature volume + **kwargs + ): + # logging.info('threshold: {}'.format(threshold)) + + u = self.extract_fields(bound_min, bound_max, resolution, + lambda pts, **kwargs: sdf_network.sdf(pts, **kwargs), + # - sdf need to be multiplied by -1 + device, + # * 3d feature volume + **kwargs + ) + if occupancy_mask is not None: + dX, dY, dZ = occupancy_mask.shape + empty_mask = 1 - occupancy_mask + empty_mask = empty_mask.view(1, 1, dX, dY, dZ) + # - dilation + # empty_mask = F.avg_pool3d(empty_mask, kernel_size=7, stride=1, padding=3) + empty_mask = F.interpolate(empty_mask, [resolution, resolution, resolution], mode='nearest') + empty_mask = empty_mask.view(resolution, resolution, resolution).cpu().numpy() > 0 + u[empty_mask] = -100 + del empty_mask + + vertices, triangles = mcubes.marching_cubes(u, threshold) + b_max_np = bound_max.detach().cpu().numpy() + b_min_np = bound_min.detach().cpu().numpy() + + vertices = vertices / (resolution - 1.0) * (b_max_np - b_min_np)[None, :] + b_min_np[None, :] + return vertices, triangles, u + + @torch.no_grad() + def extract_depth_maps(self, sdf_network, con_volume, intrinsics, c2ws, H, W, near, far): + """ + extract depth maps from the density volume + :param con_volume: [1, 1+C, dX, dY, dZ] can by con_volume or sdf_volume + :param c2ws: [B, 4, 4] + :param H: + :param W: + :param near: + :param far: + :return: + """ + device = con_volume.device + batch_size = intrinsics.shape[0] + + with torch.no_grad(): + ys, xs = torch.meshgrid(torch.linspace(0, H - 1, H), + torch.linspace(0, W - 1, W)) # pytorch's meshgrid has indexing='ij' + p = torch.stack([xs, ys, torch.ones_like(ys)], dim=-1) # H, W, 3 + + intrinsics_inv = torch.inverse(intrinsics) + + p = p.view(-1, 3).float().to(device) # N_rays, 3 + p = torch.matmul(intrinsics_inv[:, None, :3, :3], p[:, :, None]).squeeze() # Batch, N_rays, 3 + rays_v = p / torch.linalg.norm(p, ord=2, dim=-1, keepdim=True) # Batch, N_rays, 3 + rays_v = torch.matmul(c2ws[:, None, :3, :3], rays_v[:, :, :, None]).squeeze() # Batch, N_rays, 3 + rays_o = c2ws[:, None, :3, 3].expand(rays_v.shape) # Batch, N_rays, 3 + rays_d = rays_v + + rays_o = rays_o.contiguous().view(-1, 3) + rays_d = rays_d.contiguous().view(-1, 3) + + ################## - sphere tracer to extract depth maps ###################### + depth_masks_sphere, depth_maps_sphere = self.ray_tracer.extract_depth_maps( + rays_o, rays_d, + near[None, :].repeat(rays_o.shape[0], 1), + far[None, :].repeat(rays_o.shape[0], 1), + sdf_network, con_volume + ) + + depth_maps = depth_maps_sphere.view(batch_size, 1, H, W) + depth_masks = depth_masks_sphere.view(batch_size, 1, H, W) + + depth_maps = torch.where(depth_masks, depth_maps, + torch.zeros_like(depth_masks.float()).to(device)) # fill invalid pixels by 0 + + return depth_maps, depth_masks diff --git a/SparseNeuS_demo_v1/models/sparse_sdf_network.py b/SparseNeuS_demo_v1/models/sparse_sdf_network.py new file mode 100644 index 0000000000000000000000000000000000000000..817f40ed08b7cb65fb284a4666d6f6a4a3c52683 --- /dev/null +++ b/SparseNeuS_demo_v1/models/sparse_sdf_network.py @@ -0,0 +1,907 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchsparse.tensor import PointTensor, SparseTensor +import torchsparse.nn as spnn + +from tsparse.modules import SparseCostRegNet +from tsparse.torchsparse_utils import sparse_to_dense_channel +from ops.grid_sampler import grid_sample_3d, tricubic_sample_3d + +# from .gru_fusion import GRUFusion +from ops.back_project import back_project_sparse_type +from ops.generate_grids import generate_grid + +from inplace_abn import InPlaceABN + +from models.embedder import Embedding +from models.featurenet import ConvBnReLU + +import pdb +import random + +torch._C._jit_set_profiling_executor(False) +torch._C._jit_set_profiling_mode(False) + + +@torch.jit.script +def fused_mean_variance(x, weight): + mean = torch.sum(x * weight, dim=1, keepdim=True) + var = torch.sum(weight * (x - mean) ** 2, dim=1, keepdim=True) + return mean, var + + +class LatentSDFLayer(nn.Module): + def __init__(self, + d_in=3, + d_out=129, + d_hidden=128, + n_layers=4, + skip_in=(4,), + multires=0, + bias=0.5, + geometric_init=True, + weight_norm=True, + activation='softplus', + d_conditional_feature=16): + super(LatentSDFLayer, self).__init__() + + self.d_conditional_feature = d_conditional_feature + + # concat latent code for ench layer input excepting the first layer and the last layer + dims_in = [d_in] + [d_hidden + d_conditional_feature for _ in range(n_layers - 2)] + [d_hidden] + dims_out = [d_hidden for _ in range(n_layers - 1)] + [d_out] + + self.embed_fn_fine = None + + if multires > 0: + embed_fn = Embedding(in_channels=d_in, N_freqs=multires) # * include the input + self.embed_fn_fine = embed_fn + dims_in[0] = embed_fn.out_channels + + self.num_layers = n_layers + self.skip_in = skip_in + + for l in range(0, self.num_layers - 1): + if l in self.skip_in: + in_dim = dims_in[l] + dims_in[0] + else: + in_dim = dims_in[l] + + out_dim = dims_out[l] + lin = nn.Linear(in_dim, out_dim) + + if geometric_init: # - from IDR code, + if l == self.num_layers - 2: + torch.nn.init.normal_(lin.weight, mean=np.sqrt(np.pi) / np.sqrt(in_dim), std=0.0001) + torch.nn.init.constant_(lin.bias, -bias) + # the channels for latent codes are set to 0 + torch.nn.init.constant_(lin.weight[:, -d_conditional_feature:], 0.0) + torch.nn.init.constant_(lin.bias[-d_conditional_feature:], 0.0) + + elif multires > 0 and l == 0: # the first layer + torch.nn.init.constant_(lin.bias, 0.0) + # * the channels for position embeddings are set to 0 + torch.nn.init.constant_(lin.weight[:, 3:], 0.0) + # * the channels for the xyz coordinate (3 channels) for initialized by normal distribution + torch.nn.init.normal_(lin.weight[:, :3], 0.0, np.sqrt(2) / np.sqrt(out_dim)) + elif multires > 0 and l in self.skip_in: + torch.nn.init.constant_(lin.bias, 0.0) + torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim)) + # * the channels for position embeddings (and conditional_feature) are initialized to 0 + torch.nn.init.constant_(lin.weight[:, -(dims_in[0] - 3 + d_conditional_feature):], 0.0) + else: + torch.nn.init.constant_(lin.bias, 0.0) + torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim)) + # the channels for latent code are initialized to 0 + torch.nn.init.constant_(lin.weight[:, -d_conditional_feature:], 0.0) + + if weight_norm: + lin = nn.utils.weight_norm(lin) + + setattr(self, "lin" + str(l), lin) + + if activation == 'softplus': + self.activation = nn.Softplus(beta=100) + else: + assert activation == 'relu' + self.activation = nn.ReLU() + + def forward(self, inputs, latent): + inputs = inputs + if self.embed_fn_fine is not None: + inputs = self.embed_fn_fine(inputs) + + # - only for lod1 network can use the pretrained params of lod0 network + if latent.shape[1] != self.d_conditional_feature: + latent = torch.cat([latent, latent], dim=1) + + x = inputs + for l in range(0, self.num_layers - 1): + lin = getattr(self, "lin" + str(l)) + + # * due to the conditional bias, different from original neus version + if l in self.skip_in: + x = torch.cat([x, inputs], 1) / np.sqrt(2) + + if 0 < l < self.num_layers - 1: + x = torch.cat([x, latent], 1) + + x = lin(x) + + if l < self.num_layers - 2: + x = self.activation(x) + + return x + + +class SparseSdfNetwork(nn.Module): + ''' + Coarse-to-fine sparse cost regularization network + return sparse volume feature for extracting sdf + ''' + + def __init__(self, lod, ch_in, voxel_size, vol_dims, + hidden_dim=128, activation='softplus', + cost_type='variance_mean', + d_pyramid_feature_compress=16, + regnet_d_out=8, num_sdf_layers=4, + multires=6, + ): + super(SparseSdfNetwork, self).__init__() + + self.lod = lod # - gradually training, the current regularization lod + self.ch_in = ch_in + self.voxel_size = voxel_size # - the voxel size of the current volume + self.vol_dims = torch.tensor(vol_dims) # - the dims of the current volume + + self.selected_views_num = 2 # the number of selected views for feature aggregation + self.hidden_dim = hidden_dim + self.activation = activation + self.cost_type = cost_type + self.d_pyramid_feature_compress = d_pyramid_feature_compress + self.gru_fusion = None + + self.regnet_d_out = regnet_d_out + self.multires = multires + + self.pos_embedder = Embedding(3, self.multires) + + self.compress_layer = ConvBnReLU( + self.ch_in, self.d_pyramid_feature_compress, 3, 1, 1, + norm_act=InPlaceABN) + sparse_ch_in = self.d_pyramid_feature_compress * 2 + + sparse_ch_in = sparse_ch_in + 16 if self.lod > 0 else sparse_ch_in + self.sparse_costreg_net = SparseCostRegNet( + d_in=sparse_ch_in, d_out=self.regnet_d_out) + # self.regnet_d_out = self.sparse_costreg_net.d_out + + if activation == 'softplus': + self.activation = nn.Softplus(beta=100) + else: + assert activation == 'relu' + self.activation = nn.ReLU() + + self.sdf_layer = LatentSDFLayer(d_in=3, + d_out=self.hidden_dim + 1, + d_hidden=self.hidden_dim, + n_layers=num_sdf_layers, + multires=multires, + geometric_init=True, + weight_norm=True, + activation=activation, + d_conditional_feature=16 # self.regnet_d_out + ) + + def upsample(self, pre_feat, pre_coords, interval, num=8): + ''' + + :param pre_feat: (Tensor), features from last level, (N, C) + :param pre_coords: (Tensor), coordinates from last level, (N, 4) (4 : Batch ind, x, y, z) + :param interval: interval of voxels, interval = scale ** 2 + :param num: 1 -> 8 + :return: up_feat : (Tensor), upsampled features, (N*8, C) + :return: up_coords: (N*8, 4), upsampled coordinates, (4 : Batch ind, x, y, z) + ''' + with torch.no_grad(): + pos_list = [1, 2, 3, [1, 2], [1, 3], [2, 3], [1, 2, 3]] + n, c = pre_feat.shape + up_feat = pre_feat.unsqueeze(1).expand(-1, num, -1).contiguous() + up_coords = pre_coords.unsqueeze(1).repeat(1, num, 1).contiguous() + for i in range(num - 1): + up_coords[:, i + 1, pos_list[i]] += interval + + up_feat = up_feat.view(-1, c) + up_coords = up_coords.view(-1, 4) + + return up_feat, up_coords + + def aggregate_multiview_features(self, multiview_features, multiview_masks): + """ + aggregate mutli-view features by compute their cost variance + :param multiview_features: (num of voxels, num_of_views, c) + :param multiview_masks: (num of voxels, num_of_views) + :return: + """ + num_pts, n_views, C = multiview_features.shape + + counts = torch.sum(multiview_masks, dim=1, keepdim=False) # [num_pts] + + assert torch.all(counts > 0) # the point is visible for at least 1 view + + volume_sum = torch.sum(multiview_features, dim=1, keepdim=False) # [num_pts, C] + volume_sq_sum = torch.sum(multiview_features ** 2, dim=1, keepdim=False) + + if volume_sum.isnan().sum() > 0: + import ipdb; ipdb.set_trace() + + del multiview_features + + counts = 1. / (counts + 1e-5) + costvar = volume_sq_sum * counts[:, None] - (volume_sum * counts[:, None]) ** 2 + + costvar_mean = torch.cat([costvar, volume_sum * counts[:, None]], dim=1) + del volume_sum, volume_sq_sum, counts + + + + return costvar_mean + + def sparse_to_dense_volume(self, coords, feature, vol_dims, interval, device=None): + """ + convert the sparse volume into dense volume to enable trilinear sampling + to save GPU memory; + :param coords: [num_pts, 3] + :param feature: [num_pts, C] + :param vol_dims: [3] dX, dY, dZ + :param interval: + :return: + """ + + # * assume batch size is 1 + if device is None: + device = feature.device + + coords_int = (coords / interval).to(torch.int64) + vol_dims = (vol_dims / interval).to(torch.int64) + + # - if stored in CPU, too slow + dense_volume = sparse_to_dense_channel( + coords_int.to(device), feature.to(device), vol_dims.to(device), + feature.shape[1], 0, device) # [X, Y, Z, C] + + valid_mask_volume = sparse_to_dense_channel( + coords_int.to(device), + torch.ones([feature.shape[0], 1]).to(feature.device), + vol_dims.to(device), + 1, 0, device) # [X, Y, Z, 1] + + dense_volume = dense_volume.permute(3, 0, 1, 2).contiguous().unsqueeze(0) # [1, C, X, Y, Z] + valid_mask_volume = valid_mask_volume.permute(3, 0, 1, 2).contiguous().unsqueeze(0) # [1, 1, X, Y, Z] + + return dense_volume, valid_mask_volume + + def get_conditional_volume(self, feature_maps, partial_vol_origin, proj_mats, sizeH=None, sizeW=None, lod=0, + pre_coords=None, pre_feats=None, + ): + """ + + :param feature_maps: pyramid features (B,V,C0+C1+C2,H,W) fused pyramid features + :param partial_vol_origin: [B, 3] the world coordinates of the volume origin (0,0,0) + :param proj_mats: projection matrix transform world pts into image space [B,V,4,4] suitable for original image size + :param sizeH: the H of original image size + :param sizeW: the W of original image size + :param pre_coords: the coordinates of sparse volume from the prior lod + :param pre_feats: the features of sparse volume from the prior lod + :return: + """ + device = proj_mats.device + bs = feature_maps.shape[0] + N_views = feature_maps.shape[1] + minimum_visible_views = np.min([1, N_views - 1]) + # import ipdb; ipdb.set_trace() + outputs = {} + pts_samples = [] + + # ----coarse to fine---- + + # * use fused pyramid feature maps are very important + if self.compress_layer is not None: + feats = self.compress_layer(feature_maps[0]) + else: + feats = feature_maps[0] + feats = feats[:, None, :, :, :] # [V, B, C, H, W] + KRcam = proj_mats.permute(1, 0, 2, 3).contiguous() # [V, B, 4, 4] + interval = 1 + + if self.lod == 0: + # ----generate new coords---- + coords = generate_grid(self.vol_dims, 1)[0] + coords = coords.view(3, -1).to(device) # [3, num_pts] + up_coords = [] + for b in range(bs): + up_coords.append(torch.cat([torch.ones(1, coords.shape[-1]).to(coords.device) * b, coords])) + up_coords = torch.cat(up_coords, dim=1).permute(1, 0).contiguous() + # * since we only estimate the geometry of input reference image at one time; + # * mask the outside of the camera frustum + # import ipdb; ipdb.set_trace() + frustum_mask = back_project_sparse_type( + up_coords, partial_vol_origin, self.voxel_size, + feats, KRcam, sizeH=sizeH, sizeW=sizeW, only_mask=True) # [num_pts, n_views] + frustum_mask = torch.sum(frustum_mask, dim=-1) > minimum_visible_views # ! here should be large + up_coords = up_coords[frustum_mask] # [num_pts_valid, 4] + + else: + # ----upsample coords---- + assert pre_feats is not None + assert pre_coords is not None + up_feat, up_coords = self.upsample(pre_feats, pre_coords, 1) + + # ----back project---- + # give each valid 3d grid point all valid 2D features and masks + multiview_features, multiview_masks = back_project_sparse_type( + up_coords, partial_vol_origin, self.voxel_size, feats, + KRcam, sizeH=sizeH, sizeW=sizeW) # (num of voxels, num_of_views, c), (num of voxels, num_of_views) + # num_of_views = all views + + # if multiview_features.isnan().sum() > 0: + # import ipdb; ipdb.set_trace() + + # import ipdb; ipdb.set_trace() + if self.lod > 0: + # ! need another invalid voxels filtering + frustum_mask = torch.sum(multiview_masks, dim=-1) > 1 + up_feat = up_feat[frustum_mask] + up_coords = up_coords[frustum_mask] + multiview_features = multiview_features[frustum_mask] + multiview_masks = multiview_masks[frustum_mask] + # if multiview_features.isnan().sum() > 0: + # import ipdb; ipdb.set_trace() + volume = self.aggregate_multiview_features(multiview_features, multiview_masks) # compute variance for all images features + # import ipdb; ipdb.set_trace() + + # if volume.isnan().sum() > 0: + # import ipdb; ipdb.set_trace() + + del multiview_features, multiview_masks + + # ----concat feature from last stage---- + if self.lod != 0: + feat = torch.cat([volume, up_feat], dim=1) + else: + feat = volume + + # batch index is in the last position + r_coords = up_coords[:, [1, 2, 3, 0]] + + # if feat.isnan().sum() > 0: + # print('feat has nan:', feat.isnan().sum()) + # import ipdb; ipdb.set_trace() + + sparse_feat = SparseTensor(feat, r_coords.to( + torch.int32)) # - directly use sparse tensor to avoid point2voxel operations + # import ipdb; ipdb.set_trace() + feat = self.sparse_costreg_net(sparse_feat) + + dense_volume, valid_mask_volume = self.sparse_to_dense_volume(up_coords[:, 1:], feat, self.vol_dims, interval, + device=None) # [1, C/1, X, Y, Z] + + # if dense_volume.isnan().sum() > 0: + # import ipdb; ipdb.set_trace() + + + outputs['dense_volume_scale%d' % self.lod] = dense_volume # [1, 16, 96, 96, 96] + outputs['valid_mask_volume_scale%d' % self.lod] = valid_mask_volume # [1, 1, 96, 96, 96] + outputs['visible_mask_scale%d' % self.lod] = valid_mask_volume # [1, 1, 96, 96, 96] + outputs['coords_scale%d' % self.lod] = generate_grid(self.vol_dims, interval).to(device) + # import ipdb; ipdb.set_trace() + return outputs + + def sdf(self, pts, conditional_volume, lod): + num_pts = pts.shape[0] + device = pts.device + pts_ = pts.clone() + pts = pts.view(1, 1, 1, num_pts, 3) # - should be in range (-1, 1) + + pts = torch.flip(pts, dims=[-1]) + # import ipdb; ipdb.set_trace() + sampled_feature = grid_sample_3d(conditional_volume, pts) # [1, c, 1, 1, num_pts] + sampled_feature = sampled_feature.view(-1, num_pts).permute(1, 0).contiguous().to(device) + + sdf_pts = self.sdf_layer(pts_, sampled_feature) + + outputs = {} + outputs['sdf_pts_scale%d' % lod] = sdf_pts[:, :1] + outputs['sdf_features_pts_scale%d' % lod] = sdf_pts[:, 1:] + outputs['sampled_latent_scale%d' % lod] = sampled_feature + + return outputs + + @torch.no_grad() + def sdf_from_sdfvolume(self, pts, sdf_volume, lod=0): + num_pts = pts.shape[0] + device = pts.device + pts_ = pts.clone() + pts = pts.view(1, 1, 1, num_pts, 3) # - should be in range (-1, 1) + + pts = torch.flip(pts, dims=[-1]) + + sdf = torch.nn.functional.grid_sample(sdf_volume, pts, mode='bilinear', align_corners=True, + padding_mode='border') + sdf = sdf.view(-1, num_pts).permute(1, 0).contiguous().to(device) + + outputs = {} + outputs['sdf_pts_scale%d' % lod] = sdf + + return outputs + + @torch.no_grad() + def get_sdf_volume(self, conditional_volume, mask_volume, coords_volume, partial_origin): + """ + + :param conditional_volume: [1,C, dX,dY,dZ] + :param mask_volume: [1,1, dX,dY,dZ] + :param coords_volume: [1,3, dX,dY,dZ] + :return: + """ + device = conditional_volume.device + chunk_size = 10240 + + _, C, dX, dY, dZ = conditional_volume.shape + conditional_volume = conditional_volume.view(C, dX * dY * dZ).permute(1, 0).contiguous() + mask_volume = mask_volume.view(-1) + coords_volume = coords_volume.view(3, dX * dY * dZ).permute(1, 0).contiguous() + + pts = coords_volume * self.voxel_size + partial_origin # [dX*dY*dZ, 3] + + sdf_volume = torch.ones([dX * dY * dZ, 1]).float().to(device) + + conditional_volume = conditional_volume[mask_volume > 0] + pts = pts[mask_volume > 0] + conditional_volume = conditional_volume.split(chunk_size) + pts = pts.split(chunk_size) + + sdf_all = [] + for pts_part, feature_part in zip(pts, conditional_volume): + sdf_part = self.sdf_layer(pts_part, feature_part)[:, :1] + sdf_all.append(sdf_part) + + sdf_all = torch.cat(sdf_all, dim=0) + sdf_volume[mask_volume > 0] = sdf_all + sdf_volume = sdf_volume.view(1, 1, dX, dY, dZ) + return sdf_volume + + def gradient(self, x, conditional_volume, lod): + """ + return the gradient of specific lod + :param x: + :param lod: + :return: + """ + x.requires_grad_(True) + # import ipdb; ipdb.set_trace() + with torch.enable_grad(): + output = self.sdf(x, conditional_volume, lod) + y = output['sdf_pts_scale%d' % lod] + + d_output = torch.ones_like(y, requires_grad=False, device=y.device) + # ! Distributed Data Parallel doesn’t work with torch.autograd.grad() + # ! (i.e. it will only work if gradients are to be accumulated in .grad attributes of parameters). + gradients = torch.autograd.grad( + outputs=y, + inputs=x, + grad_outputs=d_output, + create_graph=True, + retain_graph=True, + only_inputs=True)[0] + return gradients.unsqueeze(1) + + +def sparse_to_dense_volume(coords, feature, vol_dims, interval, device=None): + """ + convert the sparse volume into dense volume to enable trilinear sampling + to save GPU memory; + :param coords: [num_pts, 3] + :param feature: [num_pts, C] + :param vol_dims: [3] dX, dY, dZ + :param interval: + :return: + """ + + # * assume batch size is 1 + if device is None: + device = feature.device + + coords_int = (coords / interval).to(torch.int64) + vol_dims = (vol_dims / interval).to(torch.int64) + + # - if stored in CPU, too slow + dense_volume = sparse_to_dense_channel( + coords_int.to(device), feature.to(device), vol_dims.to(device), + feature.shape[1], 0, device) # [X, Y, Z, C] + + valid_mask_volume = sparse_to_dense_channel( + coords_int.to(device), + torch.ones([feature.shape[0], 1]).to(feature.device), + vol_dims.to(device), + 1, 0, device) # [X, Y, Z, 1] + + dense_volume = dense_volume.permute(3, 0, 1, 2).contiguous().unsqueeze(0) # [1, C, X, Y, Z] + valid_mask_volume = valid_mask_volume.permute(3, 0, 1, 2).contiguous().unsqueeze(0) # [1, 1, X, Y, Z] + + return dense_volume, valid_mask_volume + + +class SdfVolume(nn.Module): + def __init__(self, volume, coords=None, type='dense'): + super(SdfVolume, self).__init__() + self.volume = torch.nn.Parameter(volume, requires_grad=True) + self.coords = coords + self.type = type + + def forward(self): + return self.volume + + +class FinetuneOctreeSdfNetwork(nn.Module): + ''' + After obtain the conditional volume from generalized network; + directly optimize the conditional volume + The conditional volume is still sparse + ''' + + def __init__(self, voxel_size, vol_dims, + origin=[-1., -1., -1.], + hidden_dim=128, activation='softplus', + regnet_d_out=8, + multires=6, + if_fitted_rendering=True, + num_sdf_layers=4, + ): + super(FinetuneOctreeSdfNetwork, self).__init__() + + self.voxel_size = voxel_size # - the voxel size of the current volume + self.vol_dims = torch.tensor(vol_dims) # - the dims of the current volume + + self.origin = torch.tensor(origin).to(torch.float32) + + self.hidden_dim = hidden_dim + self.activation = activation + + self.regnet_d_out = regnet_d_out + + self.if_fitted_rendering = if_fitted_rendering + self.multires = multires + # d_in_embedding = self.regnet_d_out if self.pos_add_type == 'latent' else 3 + # self.pos_embedder = Embedding(d_in_embedding, self.multires) + + # - the optimized parameters + self.sparse_volume_lod0 = None + self.sparse_coords_lod0 = None + + if activation == 'softplus': + self.activation = nn.Softplus(beta=100) + else: + assert activation == 'relu' + self.activation = nn.ReLU() + + self.sdf_layer = LatentSDFLayer(d_in=3, + d_out=self.hidden_dim + 1, + d_hidden=self.hidden_dim, + n_layers=num_sdf_layers, + multires=multires, + geometric_init=True, + weight_norm=True, + activation=activation, + d_conditional_feature=16 # self.regnet_d_out + ) + + # - add mlp rendering when finetuning + self.renderer = None + + d_in_renderer = 3 + self.regnet_d_out + 3 + 3 + self.renderer = BlendingRenderingNetwork( + d_feature=self.hidden_dim - 1, + mode='idr', # ! the view direction influence a lot + d_in=d_in_renderer, + d_out=50, # maximum 50 images + d_hidden=self.hidden_dim, + n_layers=3, + weight_norm=True, + multires_view=4, + squeeze_out=True, + ) + + def initialize_conditional_volumes(self, dense_volume_lod0, dense_volume_mask_lod0, + sparse_volume_lod0=None, sparse_coords_lod0=None): + """ + + :param dense_volume_lod0: [1,C,dX,dY,dZ] + :param dense_volume_mask_lod0: [1,1,dX,dY,dZ] + :param dense_volume_lod1: + :param dense_volume_mask_lod1: + :return: + """ + + if sparse_volume_lod0 is None: + device = dense_volume_lod0.device + _, C, dX, dY, dZ = dense_volume_lod0.shape + + dense_volume_lod0 = dense_volume_lod0.view(C, dX * dY * dZ).permute(1, 0).contiguous() + mask_lod0 = dense_volume_mask_lod0.view(dX * dY * dZ) > 0 + + self.sparse_volume_lod0 = SdfVolume(dense_volume_lod0[mask_lod0], type='sparse') + + coords = generate_grid(self.vol_dims, 1)[0] # [3, dX, dY, dZ] + coords = coords.view(3, dX * dY * dZ).permute(1, 0).to(device) + self.sparse_coords_lod0 = torch.nn.Parameter(coords[mask_lod0], requires_grad=False) + else: + self.sparse_volume_lod0 = SdfVolume(sparse_volume_lod0, type='sparse') + self.sparse_coords_lod0 = torch.nn.Parameter(sparse_coords_lod0, requires_grad=False) + + def get_conditional_volume(self): + dense_volume, valid_mask_volume = sparse_to_dense_volume( + self.sparse_coords_lod0, + self.sparse_volume_lod0(), self.vol_dims, interval=1, + device=None) # [1, C/1, X, Y, Z] + + # valid_mask_volume = self.dense_volume_mask_lod0 + + outputs = {} + outputs['dense_volume_scale%d' % 0] = dense_volume + outputs['valid_mask_volume_scale%d' % 0] = valid_mask_volume + + return outputs + + def tv_regularizer(self): + dense_volume, valid_mask_volume = sparse_to_dense_volume( + self.sparse_coords_lod0, + self.sparse_volume_lod0(), self.vol_dims, interval=1, + device=None) # [1, C/1, X, Y, Z] + + dx = (dense_volume[:, :, 1:, :, :] - dense_volume[:, :, :-1, :, :]) ** 2 # [1, C/1, X-1, Y, Z] + dy = (dense_volume[:, :, :, 1:, :] - dense_volume[:, :, :, :-1, :]) ** 2 # [1, C/1, X, Y-1, Z] + dz = (dense_volume[:, :, :, :, 1:] - dense_volume[:, :, :, :, :-1]) ** 2 # [1, C/1, X, Y, Z-1] + + tv = dx[:, :, :, :-1, :-1] + dy[:, :, :-1, :, :-1] + dz[:, :, :-1, :-1, :] # [1, C/1, X-1, Y-1, Z-1] + + mask = valid_mask_volume[:, :, :-1, :-1, :-1] * valid_mask_volume[:, :, 1:, :-1, :-1] * \ + valid_mask_volume[:, :, :-1, 1:, :-1] * valid_mask_volume[:, :, :-1, :-1, 1:] + + tv = torch.sqrt(tv + 1e-6).mean(dim=1, keepdim=True) * mask + # tv = tv.mean(dim=1, keepdim=True) * mask + + assert torch.all(~torch.isnan(tv)) + + return torch.mean(tv) + + def sdf(self, pts, conditional_volume, lod): + + outputs = {} + + num_pts = pts.shape[0] + device = pts.device + pts_ = pts.clone() + pts = pts.view(1, 1, 1, num_pts, 3) # - should be in range (-1, 1) + + pts = torch.flip(pts, dims=[-1]) + + sampled_feature = grid_sample_3d(conditional_volume, pts) # [1, c, 1, 1, num_pts] + sampled_feature = sampled_feature.view(-1, num_pts).permute(1, 0).contiguous() + outputs['sampled_latent_scale%d' % lod] = sampled_feature + + sdf_pts = self.sdf_layer(pts_, sampled_feature) + + lod = 0 + outputs['sdf_pts_scale%d' % lod] = sdf_pts[:, :1] + outputs['sdf_features_pts_scale%d' % lod] = sdf_pts[:, 1:] + + return outputs + + def color_blend(self, pts, position, normals, view_dirs, feature_vectors, img_index, + pts_pixel_color, pts_pixel_mask, pts_patch_color=None, pts_patch_mask=None): + + return self.renderer(torch.cat([pts, position], dim=-1), normals, view_dirs, feature_vectors, + img_index, pts_pixel_color, pts_pixel_mask, + pts_patch_color=pts_patch_color, pts_patch_mask=pts_patch_mask) + + def gradient(self, x, conditional_volume, lod): + """ + return the gradient of specific lod + :param x: + :param lod: + :return: + """ + x.requires_grad_(True) + output = self.sdf(x, conditional_volume, lod) + y = output['sdf_pts_scale%d' % 0] + + d_output = torch.ones_like(y, requires_grad=False, device=y.device) + + gradients = torch.autograd.grad( + outputs=y, + inputs=x, + grad_outputs=d_output, + create_graph=True, + retain_graph=True, + only_inputs=True)[0] + return gradients.unsqueeze(1) + + @torch.no_grad() + def prune_dense_mask(self, threshold=0.02): + """ + Just gradually prune the mask of dense volume to decrease the number of sdf network inference + :return: + """ + chunk_size = 10240 + coords = generate_grid(self.vol_dims_lod0, 1)[0] # [3, dX, dY, dZ] + + _, dX, dY, dZ = coords.shape + + pts = coords.view(3, -1).permute(1, + 0).contiguous() * self.voxel_size_lod0 + self.origin[None, :] # [dX*dY*dZ, 3] + + # dense_volume = self.dense_volume_lod0() # [1,C,dX,dY,dZ] + dense_volume, _ = sparse_to_dense_volume( + self.sparse_coords_lod0, + self.sparse_volume_lod0(), self.vol_dims_lod0, interval=1, + device=None) # [1, C/1, X, Y, Z] + + sdf_volume = torch.ones([dX * dY * dZ, 1]).float().to(dense_volume.device) * 100 + + mask = self.dense_volume_mask_lod0.view(-1) > 0 + + pts_valid = pts[mask].to(dense_volume.device) + feature_valid = dense_volume.view(self.regnet_d_out, -1).permute(1, 0).contiguous()[mask] + + pts_valid = pts_valid.split(chunk_size) + feature_valid = feature_valid.split(chunk_size) + + sdf_list = [] + + for pts_part, feature_part in zip(pts_valid, feature_valid): + sdf_part = self.sdf_layer(pts_part, feature_part)[:, :1] + sdf_list.append(sdf_part) + + sdf_list = torch.cat(sdf_list, dim=0) + + sdf_volume[mask] = sdf_list + + occupancy_mask = torch.abs(sdf_volume) < threshold # [num_pts, 1] + + # - dilate + occupancy_mask = occupancy_mask.float() + occupancy_mask = occupancy_mask.view(1, 1, dX, dY, dZ) + occupancy_mask = F.avg_pool3d(occupancy_mask, kernel_size=7, stride=1, padding=3) + occupancy_mask = occupancy_mask > 0 + + self.dense_volume_mask_lod0 = torch.logical_and(self.dense_volume_mask_lod0, + occupancy_mask).float() # (1, 1, dX, dY, dZ) + + +class BlendingRenderingNetwork(nn.Module): + def __init__( + self, + d_feature, + mode, + d_in, + d_out, + d_hidden, + n_layers, + weight_norm=True, + multires_view=0, + squeeze_out=True, + ): + super(BlendingRenderingNetwork, self).__init__() + + self.mode = mode + self.squeeze_out = squeeze_out + dims = [d_in + d_feature] + [d_hidden for _ in range(n_layers)] + [d_out] + + self.embedder = None + if multires_view > 0: + self.embedder = Embedding(3, multires_view) + dims[0] += (self.embedder.out_channels - 3) + + self.num_layers = len(dims) + + for l in range(0, self.num_layers - 1): + out_dim = dims[l + 1] + lin = nn.Linear(dims[l], out_dim) + + if weight_norm: + lin = nn.utils.weight_norm(lin) + + setattr(self, "lin" + str(l), lin) + + self.relu = nn.ReLU() + + self.color_volume = None + + self.softmax = nn.Softmax(dim=1) + + self.type = 'blending' + + def sample_pts_from_colorVolume(self, pts): + device = pts.device + num_pts = pts.shape[0] + pts_ = pts.clone() + pts = pts.view(1, 1, 1, num_pts, 3) # - should be in range (-1, 1) + + pts = torch.flip(pts, dims=[-1]) + + sampled_color = grid_sample_3d(self.color_volume, pts) # [1, c, 1, 1, num_pts] + sampled_color = sampled_color.view(-1, num_pts).permute(1, 0).contiguous().to(device) + + return sampled_color + + def forward(self, position, normals, view_dirs, feature_vectors, img_index, + pts_pixel_color, pts_pixel_mask, pts_patch_color=None, pts_patch_mask=None): + """ + + :param position: can be 3d coord or interpolated volume latent + :param normals: + :param view_dirs: + :param feature_vectors: + :param img_index: [N_views], used to extract corresponding weights + :param pts_pixel_color: [N_pts, N_views, 3] + :param pts_pixel_mask: [N_pts, N_views] + :param pts_patch_color: [N_pts, N_views, Npx, 3] + :return: + """ + if self.embedder is not None: + view_dirs = self.embedder(view_dirs) + + rendering_input = None + + if self.mode == 'idr': + rendering_input = torch.cat([position, view_dirs, normals, feature_vectors], dim=-1) + elif self.mode == 'no_view_dir': + rendering_input = torch.cat([position, normals, feature_vectors], dim=-1) + elif self.mode == 'no_normal': + rendering_input = torch.cat([position, view_dirs, feature_vectors], dim=-1) + elif self.mode == 'no_points': + rendering_input = torch.cat([view_dirs, normals, feature_vectors], dim=-1) + elif self.mode == 'no_points_no_view_dir': + rendering_input = torch.cat([normals, feature_vectors], dim=-1) + + x = rendering_input + + for l in range(0, self.num_layers - 1): + lin = getattr(self, "lin" + str(l)) + + x = lin(x) + + if l < self.num_layers - 2: + x = self.relu(x) # [n_pts, d_out] + + ## extract value based on img_index + x_extracted = torch.index_select(x, 1, img_index.long()) + + weights_pixel = self.softmax(x_extracted) # [n_pts, N_views] + weights_pixel = weights_pixel * pts_pixel_mask + weights_pixel = weights_pixel / ( + torch.sum(weights_pixel.float(), dim=1, keepdim=True) + 1e-8) # [n_pts, N_views] + final_pixel_color = torch.sum(pts_pixel_color * weights_pixel[:, :, None], dim=1, + keepdim=False) # [N_pts, 3] + + final_pixel_mask = torch.sum(pts_pixel_mask.float(), dim=1, keepdim=True) > 0 # [N_pts, 1] + + final_patch_color, final_patch_mask = None, None + # pts_patch_color [N_pts, N_views, Npx, 3]; pts_patch_mask [N_pts, N_views, Npx] + if pts_patch_color is not None: + N_pts, N_views, Npx, _ = pts_patch_color.shape + patch_mask = torch.sum(pts_patch_mask, dim=-1, keepdim=False) > Npx - 1 # [N_pts, N_views] + + weights_patch = self.softmax(x_extracted) # [N_pts, N_views] + weights_patch = weights_patch * patch_mask + weights_patch = weights_patch / ( + torch.sum(weights_patch.float(), dim=1, keepdim=True) + 1e-8) # [n_pts, N_views] + + final_patch_color = torch.sum(pts_patch_color * weights_patch[:, :, None, None], dim=1, + keepdim=False) # [N_pts, Npx, 3] + final_patch_mask = torch.sum(patch_mask, dim=1, keepdim=True) > 0 # [N_pts, 1] at least one image sees + + return final_pixel_color, final_pixel_mask, final_patch_color, final_patch_mask diff --git a/SparseNeuS_demo_v1/models/trainer_finetune.py b/SparseNeuS_demo_v1/models/trainer_finetune.py new file mode 100644 index 0000000000000000000000000000000000000000..e6203976b2a72dea61e1e728a3b1a225366f56a2 --- /dev/null +++ b/SparseNeuS_demo_v1/models/trainer_finetune.py @@ -0,0 +1,979 @@ +""" +Trainer for fine-tuning +""" +import os +import cv2 as cv +import torch +import torch.nn as nn +import torch.nn.functional as F + +import numpy as np +import logging +import mcubes +import trimesh +from icecream import ic +from models.render_utils import sample_pdf +from utils.misc_utils import visualize_depth_numpy + +from utils.training_utils import tocuda, numpy2tensor +from loss.depth_metric import compute_depth_errors +from loss.color_loss import OcclusionColorLoss, OcclusionColorPatchLoss +from loss.depth_loss import DepthLoss, DepthSmoothLoss + +from models.projector import Projector + +from models.rays import gen_rays_between + +from models.sparse_neus_renderer import SparseNeuSRenderer + +import pdb + + +class FinetuneTrainer(nn.Module): + """ + Trainer used for fine-tuning + """ + + def __init__(self, + rendering_network_outside, + pyramid_feature_network_lod0, + pyramid_feature_network_lod1, + sdf_network_lod0, + sdf_network_lod1, + variance_network_lod0, + variance_network_lod1, + sdf_network_finetune, + finetune_lod, # which lod fine-tuning use + n_samples, + n_importance, + n_outside, + perturb, + alpha_type='div', + conf=None + ): + super(FinetuneTrainer, self).__init__() + + self.conf = conf + self.base_exp_dir = conf['general.base_exp_dir'] + + self.finetune_lod = finetune_lod + + self.anneal_start = self.conf.get_float('train.anneal_start', default=0.0) + self.anneal_end = self.conf.get_float('train.anneal_end', default=0.0) + self.end_iter = self.conf.get_int('train.end_iter') + + # network setups + self.rendering_network_outside = rendering_network_outside + self.pyramid_feature_network_geometry_lod0 = pyramid_feature_network_lod0 # 2D pyramid feature network for geometry + self.pyramid_feature_network_geometry_lod1 = pyramid_feature_network_lod1 # use differnet networks for the two lods + + self.sdf_network_lod0 = sdf_network_lod0 # the first lod is density_network + self.sdf_network_lod1 = sdf_network_lod1 + + # - warpped by ModuleList to support DataParallel + self.variance_network_lod0 = variance_network_lod0 + self.variance_network_lod1 = variance_network_lod1 + self.variance_network_finetune = variance_network_lod0 if self.finetune_lod == 0 else variance_network_lod1 + + self.sdf_network_finetune = sdf_network_finetune + + self.n_samples = n_samples + self.n_importance = n_importance + self.n_outside = n_outside + self.perturb = perturb + self.alpha_type = alpha_type + + self.sdf_renderer_finetune = SparseNeuSRenderer( + self.rendering_network_outside, + self.sdf_network_finetune, + self.variance_network_finetune, + None, # rendering_network + self.n_samples, + self.n_importance, + self.n_outside, + self.perturb, + alpha_type='div', + conf=self.conf) + + # sdf network weights + self.sdf_igr_weight = self.conf.get_float('train.sdf_igr_weight') + self.sdf_sparse_weight = self.conf.get_float('train.sdf_sparse_weight', default=0) + + self.sdf_decay_param = self.conf.get_float('train.sdf_decay_param', default=100) + self.color_pixel_weight = self.conf.get_float('train.color_pixel_weight', default=1.0) + self.color_patch_weight = self.conf.get_float('train.color_patch_weight', default=0.) + self.tv_weight = self.conf.get_float('train.tv_weight', default=0.001) # no use + self.visibility_beta = self.conf.get_float('train.visibility_beta', default=0.025) + self.visibility_gama = self.conf.get_float('train.visibility_gama', default=0.015) + self.visibility_penalize_ratio = self.conf.get_float('train.visibility_penalize_ratio', default=0.8) + self.visibility_weight_thred = self.conf.get_list('train.visibility_weight_thred', default=[0.7]) + self.if_visibility_aware = self.conf.get_bool('train.if_visibility_aware', default=True) + self.train_from_scratch = self.conf.get_bool('train.train_from_scratch', default=False) + + self.depth_criterion = DepthLoss() + self.depth_smooth_criterion = DepthSmoothLoss() + self.occlusion_color_criterion = OcclusionColorLoss(beta=self.visibility_beta, + gama=self.visibility_gama, + weight_thred=self.visibility_weight_thred, + occlusion_aware=self.if_visibility_aware) + self.occlusion_color_patch_criterion = OcclusionColorPatchLoss( + type=self.conf.get_string('train.patch_loss_type', default='ncc'), + h_patch_size=self.conf.get_int('model.h_patch_size', default=5), + beta=self.visibility_beta, gama=self.visibility_gama, + weight_thred=self.visibility_weight_thred, + occlusion_aware=self.if_visibility_aware + ) + + # self.iter_step = 0 + self.val_mesh_freq = self.conf.get_int('train.val_mesh_freq') + + # - True if fine-tuning + self.if_fitted_rendering = self.conf.get_bool('train.if_fitted_rendering', default=False) + + def get_trainable_params(self): + # set trainable params + + params = [] + faster_params = [] + slower_params = [] + + params += self.variance_network_finetune.parameters() + slower_params += self.sdf_network_finetune.sparse_volume_lod0.parameters() + params += self.sdf_network_finetune.sdf_layer.parameters() + + faster_params += self.sdf_network_finetune.renderer.parameters() + + self.params_to_train = { + 'slower_params': slower_params, + 'params': params, + 'faster_params': faster_params + } + + return self.params_to_train + + @torch.no_grad() + def prepare_con_volume(self, sample): + # * only support batch_size==1 + sizeW = sample['img_wh'][0] + sizeH = sample['img_wh'][1] + partial_vol_origin = sample['partial_vol_origin'][None, :] # [B, 3] + near, far = sample['near_fars'][0, :1], sample['near_fars'][0, 1:] + near = 0.8 * near + far = 1.2 * far + + imgs = sample['images'] + intrinsics = sample['intrinsics'] + intrinsics_l_4x = intrinsics.clone() + intrinsics_l_4x[:, :2] *= 0.25 + w2cs = sample['w2cs'] + c2ws = sample['c2ws'] + proj_matrices = sample['affine_mats'][None, :, :, :] + + # *********************** Lod==0 *********************** + + with torch.no_grad(): + geometry_feature_maps = self.obtain_pyramid_feature_maps(imgs) + # import ipdb; ipdb.set_trace() + conditional_features_lod0 = self.sdf_network_lod0.get_conditional_volume( + feature_maps=geometry_feature_maps[None, :, :, :, :], + partial_vol_origin=partial_vol_origin, + proj_mats=proj_matrices, + sizeH=sizeH, + sizeW=sizeW, + lod=0, + ) + + con_volume_lod0 = conditional_features_lod0['dense_volume_scale0'] + + con_valid_mask_volume_lod0 = conditional_features_lod0['valid_mask_volume_scale0'] + coords_lod0 = conditional_features_lod0['coords_scale0'] # [1,3,wX,wY,wZ] + + if self.finetune_lod == 0: + return con_volume_lod0, con_valid_mask_volume_lod0, coords_lod0 + + # * extract depth maps for all the images for adaptive rendering_network + depth_maps_lod0, depth_masks_lod0 = None, None + if self.finetune_lod == 1: + sdf_volume_lod0 = self.sdf_network_lod0.get_sdf_volume( + con_volume_lod0, con_valid_mask_volume_lod0, + coords_lod0, partial_vol_origin) # [1, 1, dX, dY, dZ] + + if self.finetune_lod == 1: + geometry_feature_maps_lod1 = self.obtain_pyramid_feature_maps(imgs, lod=1) + + pre_coords, pre_feats = self.sdf_renderer_finetune.get_valid_sparse_coords_by_sdf( + sdf_volume_lod0[0], coords_lod0[0], con_valid_mask_volume_lod0[0], con_volume_lod0[0], + maximum_pts=200000) + + pre_coords[:, 1:] = pre_coords[:, 1:] * 2 + + conditional_features_lod1 = self.sdf_network_lod1.get_conditional_volume( + feature_maps=geometry_feature_maps_lod1[None, :, :, :, :], + partial_vol_origin=partial_vol_origin, + proj_mats=proj_matrices, + sizeH=sizeH, + sizeW=sizeW, + pre_coords=pre_coords, + pre_feats=pre_feats + ) + + con_volume_lod1 = conditional_features_lod1['dense_volume_scale1'] + con_valid_mask_volume_lod1 = conditional_features_lod1['valid_mask_volume_scale1'] + coords_lod1 = conditional_features_lod1['coords_scale1'] # [1,3,wX,wY,wZ] + con_valid_mask_volume_lod0 = F.interpolate(con_valid_mask_volume_lod0, scale_factor=2) + + return con_volume_lod1, con_valid_mask_volume_lod1, coords_lod1 + + def initialize_finetune_network(self, sample, sparse_con_volume=None, sparse_coords_volume=None, + train_from_scratch=False): + + if not train_from_scratch: + if sparse_con_volume is None: # if the + + con_volume, con_mask_volume, _ = self.prepare_con_volume(sample) + + device = con_volume.device + + self.sdf_network_finetune.initialize_conditional_volumes( + con_volume, + con_mask_volume + ) + else: + self.sdf_network_finetune.initialize_conditional_volumes( + None, + None, + sparse_con_volume, + sparse_coords_volume + ) + else: + device = sample['images'].device + vol_dims = self.sdf_network_finetune.vol_dims + con_volume = torch.zeros( + [1, self.sdf_network_finetune.regnet_d_out, vol_dims[0], vol_dims[1], vol_dims[2]]).to(device) + con_mask_volume = torch.ones([1, 1, vol_dims[0], vol_dims[1], vol_dims[2]]).to(device) + self.sdf_network_finetune.initialize_conditional_volumes( + con_volume, + con_mask_volume + ) + + self.sdf_network_lod0, self.sdf_network_lod1 = None, None + self.pyramid_feature_network_geometry_lod0, self.pyramid_feature_network_geometry_lod1 = None, None + + def train_step(self, sample, + perturb_overwrite=-1, + background_rgb=None, + iter_step=0, + chunk_size=512, + save_vis=False, + ): + + # * finetune on one specific scene + # * only support batch_size==1 + # ! attention: the list of string cannot be splited in DataParallel + batch_idx = sample['batch_idx'][0] + meta = sample['meta'][batch_idx] # the scan lighting ref_view info + + sizeW = sample['img_wh'][0][0] + sizeH = sample['img_wh'][0][1] + partial_vol_origin = sample['partial_vol_origin'] # [B, 3] + near, far = sample['query_near_far'][0, :1], sample['query_near_far'][0, 1:] + + img_index = sample['img_index'][0] # [n] + + # the full-size ray variables + sample_rays = sample['rays'] + rays_o = sample_rays['rays_o'][0] + rays_d = sample_rays['rays_v'][0] + rays_ndc_uv = sample_rays['rays_ndc_uv'][0] + + imgs = sample['images'][0] + intrinsics = sample['intrinsics'][0] + w2cs = sample['w2cs'][0] + proj_matrices = sample['affine_mats'] + scale_mat = sample['scale_mat'] + trans_mat = sample['trans_mat'] + + query_c2w = sample['query_c2w'] + + # *********************** Lod==0 *********************** + + conditional_features_lod0 = self.sdf_network_finetune.get_conditional_volume() + + con_volume_lod0 = conditional_features_lod0['dense_volume_scale0'] + con_valid_mask_volume_lod0 = conditional_features_lod0['valid_mask_volume_scale0'] + + # coords_lod0 = conditional_features_lod0['coords_scale0'] # [1,3,wX,wY,wZ] + + # # - extract mesh + if iter_step % self.val_mesh_freq == 0: + torch.cuda.empty_cache() + self.validate_mesh(self.sdf_network_finetune, + self.sdf_renderer_finetune.extract_geometry, + conditional_volume=con_volume_lod0, + lod=0, + threshold=0., + occupancy_mask=con_valid_mask_volume_lod0[0, 0], + mode='ft', meta=meta, + iter_step=iter_step, scale_mat=scale_mat, trans_mat=trans_mat) + + torch.cuda.empty_cache() + + render_out = self.sdf_renderer_finetune.render( + rays_o, rays_d, near, far, + self.sdf_network_finetune, + None, # rendering_network + background_rgb=background_rgb, + alpha_inter_ratio=1.0, + # * related to conditional feature + lod=0, + conditional_volume=con_volume_lod0, + conditional_valid_mask_volume=con_valid_mask_volume_lod0, + # * 2d feature maps + feature_maps=None, + color_maps=imgs, + w2cs=w2cs, + intrinsics=intrinsics, + img_wh=[sizeW, sizeH], + query_c2w=query_c2w, + if_general_rendering=False, + img_index=img_index, + rays_uv=rays_ndc_uv if self.color_patch_weight > 0 else None, + ) + + # * optional TV regularizer, we don't use in this paper + if self.tv_weight > 0: + tv = self.sdf_network_finetune.tv_regularizer() + else: + tv = 0.0 + render_out['tv'] = tv + loss_lod0, losses_lod0, depth_statis_lod0 = self.cal_losses_sdf(render_out, sample_rays, iter_step) + + losses = { + # - lod 0 + 'loss_lod0': loss_lod0, + 'losses_lod0': losses_lod0, + 'depth_statis_lod0': depth_statis_lod0, + } + + return losses + + def val_step(self, sample, + perturb_overwrite=-1, + background_rgb=None, + iter_step=0, + chunk_size=512, + save_vis=True, + ): + # * only support batch_size==1 + # ! attention: the list of string cannot be splited in DataParallel + batch_idx = sample['batch_idx'][0] + meta = sample['meta'][batch_idx] # the scan lighting ref_view info + + sizeW = sample['img_wh'][0][0] + sizeH = sample['img_wh'][0][1] + H, W = sizeH, sizeW + + partial_vol_origin = sample['partial_vol_origin'] # [B, 3] + near, far = sample['query_near_far'][0, :1], sample['query_near_far'][0, 1:] + + img_index = sample['img_index'][0] # [n] + + # the ray variables + sample_rays = sample['rays'] + rays_o = sample_rays['rays_o'][0] + rays_d = sample_rays['rays_v'][0] + rays_ndc_uv = sample_rays['rays_ndc_uv'][0] + + imgs = sample['images'][0] + intrinsics = sample['intrinsics'][0] + intrinsics_l_4x = intrinsics.clone() + intrinsics_l_4x[:, :2] *= 0.25 + w2cs = sample['w2cs'][0] + c2ws = sample['c2ws'][0] + proj_matrices = sample['affine_mats'] + + # - the image to render + scale_mat = sample['scale_mat'] # [1,4,4] used to convert mesh into true scale + trans_mat = sample['trans_mat'] + query_c2w = sample['query_c2w'] # [1,4,4] + query_w2c = sample['query_w2c'] # [1,4,4] + true_img = sample['query_image'][0] + true_img = np.uint8(true_img.permute(1, 2, 0).cpu().numpy() * 255) + + depth_min, depth_max = near.cpu().numpy(), far.cpu().numpy() + + true_depth = sample['query_depth'] if 'query_depth' in sample.keys() else None + if true_depth is not None: + true_depth = true_depth[0].cpu().numpy() + true_depth_colored = visualize_depth_numpy(true_depth, [depth_min, depth_max])[0] + else: + true_depth_colored = None + + rays_o = rays_o.reshape(-1, 3).split(chunk_size) + rays_d = rays_d.reshape(-1, 3).split(chunk_size) + + # - obtain conditional features + with torch.no_grad(): + # - lod 0 + conditional_features_lod0 = self.sdf_network_finetune.get_conditional_volume() + + con_volume_lod0 = conditional_features_lod0['dense_volume_scale0'] + con_valid_mask_volume_lod0 = conditional_features_lod0['valid_mask_volume_scale0'] + # coords_lod0 = conditional_features_lod0['coords_scale0'] # [1,3,wX,wY,wZ] + + out_rgb_fine = [] + out_normal_fine = [] + out_depth_fine = [] + + out_rgb_mlp = [] + + if save_vis: + for rays_o_batch, rays_d_batch in zip(rays_o, rays_d): + + # ****** lod 0 **** + render_out = self.sdf_renderer_finetune.render( + rays_o_batch, rays_d_batch, near, far, + self.sdf_network_finetune, + None, + background_rgb=background_rgb, + alpha_inter_ratio=1., + # * related to conditional feature + lod=0, + conditional_volume=con_volume_lod0, + conditional_valid_mask_volume=con_valid_mask_volume_lod0, + # * 2d feature maps + feature_maps=None, + color_maps=imgs, + w2cs=w2cs, + intrinsics=intrinsics, + img_wh=[sizeW, sizeH], + query_c2w=query_c2w, + if_general_rendering=False, + if_render_with_grad=False, + img_index=img_index, + # rays_uv=rays_ndc_uv + ) + + feasible = lambda key: ((key in render_out) and (render_out[key] is not None)) + + if feasible('depth'): + out_depth_fine.append(render_out['depth'].detach().cpu().numpy()) + + # if render_out['color_coarse'] is not None: + if feasible('color_fine'): + out_rgb_fine.append(render_out['color_fine'].detach().cpu().numpy()) + + if feasible('color_mlp'): + out_rgb_mlp.append(render_out['color_mlp'].detach().cpu().numpy()) + + if feasible('gradients') and feasible('weights'): + if render_out['inside_sphere'] is not None: + out_normal_fine.append((render_out['gradients'] * render_out['weights'][:, + :self.n_samples + self.n_importance, + None] * render_out['inside_sphere'][ + ..., None]).sum(dim=1).detach().cpu().numpy()) + else: + out_normal_fine.append((render_out['gradients'] * render_out['weights'][:, + :self.n_samples + self.n_importance, + None]).sum(dim=1).detach().cpu().numpy()) + del render_out + + # - save visualization of lod 0 + + self.save_visualization(true_img, true_depth_colored, out_depth_fine, out_normal_fine, + query_w2c[0], out_rgb_fine, H, W, + depth_min, depth_max, iter_step, meta, "val_lod0", + out_color_mlp=out_rgb_mlp, true_depth=true_depth) + + # - extract mesh + if (iter_step % self.val_mesh_freq == 0): + torch.cuda.empty_cache() + self.validate_mesh(self.sdf_network_finetune, + self.sdf_renderer_finetune.extract_geometry, + conditional_volume=con_volume_lod0, lod=0, + threshold=0, + occupancy_mask=con_valid_mask_volume_lod0[0, 0], + mode='val', meta=meta, + iter_step=iter_step, scale_mat=scale_mat, trans_mat=trans_mat) + + torch.cuda.empty_cache() + + def export_mesh_step(self, sample, + perturb_overwrite=-1, + background_rgb=None, + iter_step=0, + chunk_size=512, + save_vis=True, + ): + # * only support batch_size==1 + # ! attention: the list of string cannot be splited in DataParallel + batch_idx = sample['batch_idx'][0] + # meta = sample['meta'][batch_idx] # the scan lighting ref_view info + meta='' + + sizeW = sample['img_wh'][0][0] + sizeH = sample['img_wh'][0][1] + + near, far = sample['query_near_far'][0, :1], sample['query_near_far'][0, 1:] + + + # the ray variables + sample_rays = sample['rays'] + rays_o = sample_rays['rays_o'][0] + rays_d = sample_rays['rays_v'][0] + + intrinsics = sample['intrinsics'][0] + intrinsics_l_4x = intrinsics.clone() + intrinsics_l_4x[:, :2] *= 0.25 + + + # - the image to render + scale_mat = sample['scale_mat'] # [1,4,4] used to convert mesh into true scale + trans_mat = sample['trans_mat'] + + true_img = sample['query_image'][0] + true_img = np.uint8(true_img.permute(1, 2, 0).cpu().numpy() * 255) + + + rays_o = rays_o.reshape(-1, 3).split(chunk_size) + rays_d = rays_d.reshape(-1, 3).split(chunk_size) + + # import ipdb; ipdb.set_trace() + # - obtain conditional features + with torch.no_grad(): + # - lod 0 + conditional_features_lod0 = self.sdf_network_finetune.get_conditional_volume() + + con_volume_lod0 = conditional_features_lod0['dense_volume_scale0'] + con_valid_mask_volume_lod0 = conditional_features_lod0['valid_mask_volume_scale0'] + # coords_lod0 = conditional_features_lod0['coords_scale0'] # [1,3,wX,wY,wZ] + + + # - extract mesh + + torch.cuda.empty_cache() + self.validate_mesh(self.sdf_network_finetune, + self.sdf_renderer_finetune.extract_geometry, + conditional_volume=con_volume_lod0, lod=0, + threshold=0, + occupancy_mask=con_valid_mask_volume_lod0[0, 0], + mode='val', meta=meta, + iter_step=iter_step, scale_mat=scale_mat, trans_mat=trans_mat) + + torch.cuda.empty_cache() + + def save_visualization(self, true_img, true_colored_depth, out_depth, out_normal, w2cs, out_color, H, W, + depth_min, depth_max, iter_step, meta, comment, out_color_mlp=[], true_depth=None): + if len(out_color) > 0: + img_fine = (np.concatenate(out_color, axis=0).reshape([H, W, 3]) * 256).clip(0, 255) + + if len(out_color_mlp) > 0: + img_mlp = (np.concatenate(out_color_mlp, axis=0).reshape([H, W, 3]) * 256).clip(0, 255) + + if len(out_normal) > 0: + normal_img = np.concatenate(out_normal, axis=0) + rot = w2cs[:3, :3].detach().cpu().numpy() + # - convert normal from world space to camera space + normal_img = (np.matmul(rot[None, :, :], + normal_img[:, :, None]).reshape([H, W, 3]) * 128 + 128).clip(0, 255) + if len(out_depth) > 0: + pred_depth = np.concatenate(out_depth, axis=0).reshape([H, W]) + pred_depth_colored = visualize_depth_numpy(pred_depth, [depth_min, depth_max])[0] + + if len(out_depth) > 0: + os.makedirs(os.path.join(self.base_exp_dir, 'depths_' + comment), exist_ok=True) + if true_colored_depth is not None: + + if true_depth is not None: + depth_error_map = np.abs(true_depth - pred_depth) * 5.0 + depth_visualized = np.concatenate( + [depth_error_map, true_colored_depth, pred_depth_colored, true_img], axis=1)[:, :, ::-1] + else: + depth_visualized = np.concatenate( + [true_colored_depth, pred_depth_colored, true_img])[:, :, ::-1] + cv.imwrite( + os.path.join(self.base_exp_dir, 'depths_' + comment, + '{:0>8d}_{}.png'.format(iter_step, meta)), depth_visualized + ) + else: + cv.imwrite( + os.path.join(self.base_exp_dir, 'depths_' + comment, + '{:0>8d}_{}.png'.format(iter_step, meta)), + np.concatenate( + [pred_depth_colored, true_img])[:, :, ::-1]) + if len(out_color) > 0: + os.makedirs(os.path.join(self.base_exp_dir, 'synthesized_color_' + comment), exist_ok=True) + cv.imwrite(os.path.join(self.base_exp_dir, 'synthesized_color_' + comment, + '{:0>8d}_{}.png'.format(iter_step, meta)), + np.concatenate( + [img_fine, true_img])[:, :, ::-1]) # bgr2rgb + + if len(out_color_mlp) > 0: + os.makedirs(os.path.join(self.base_exp_dir, 'synthesized_color_mlp_' + comment), exist_ok=True) + cv.imwrite(os.path.join(self.base_exp_dir, 'synthesized_color_mlp_' + comment, + '{:0>8d}_{}.png'.format(iter_step, meta)), + np.concatenate( + [img_mlp, true_img])[:, :, ::-1]) # bgr2rgb + + if len(out_normal) > 0: + os.makedirs(os.path.join(self.base_exp_dir, 'normals_' + comment), exist_ok=True) + cv.imwrite(os.path.join(self.base_exp_dir, 'normals_' + comment, + '{:0>8d}_{}.png'.format(iter_step, meta)), + normal_img[:, :, ::-1]) + + def forward(self, sample, + perturb_overwrite=-1, + background_rgb=None, + iter_step=0, + mode='train', + save_vis=False, + ): + + if mode == 'train': + return self.train_step(sample, + perturb_overwrite=perturb_overwrite, + background_rgb=background_rgb, + iter_step=iter_step, + ) + elif mode == 'val': + return self.val_step(sample, + perturb_overwrite=perturb_overwrite, + background_rgb=background_rgb, + iter_step=iter_step, save_vis=save_vis, + ) + elif mode == 'export_mesh': + return self.export_mesh_step(sample, + perturb_overwrite=perturb_overwrite, + background_rgb=background_rgb, + iter_step=iter_step, save_vis=save_vis, + ) + + def obtain_pyramid_feature_maps(self, imgs, lod=0): + """ + get feature maps of all conditional images + :param imgs: + :return: + """ + + if lod == 0: + extractor = self.pyramid_feature_network_geometry_lod0 + elif lod >= 1: + extractor = self.pyramid_feature_network_geometry_lod1 + + pyramid_feature_maps = extractor(imgs) + + # * the pyramid features are very important, if only use the coarst features, hard to optimize + fused_feature_maps = torch.cat([ + F.interpolate(pyramid_feature_maps[0], scale_factor=4, mode='bilinear', align_corners=True), + F.interpolate(pyramid_feature_maps[1], scale_factor=2, mode='bilinear', align_corners=True), + pyramid_feature_maps[2] + ], dim=1) + + return fused_feature_maps + + def cal_losses_sdf(self, render_out, sample_rays, iter_step=-1): + + def get_weight(iter_step, weight): + if iter_step < 0: + return weight + + if self.anneal_end == 0.0: + return weight + elif iter_step < self.anneal_start: + return 0.0 + else: + return np.min( + [1.0, + (iter_step - self.anneal_start) / (self.anneal_end * 2 - self.anneal_start)]) * weight + + rays_o = sample_rays['rays_o'][0] + rays_d = sample_rays['rays_v'][0] + true_rgb = sample_rays['rays_color'][0] + + if 'rays_depth' in sample_rays.keys(): + true_depth = sample_rays['rays_depth'][0] + else: + true_depth = None + mask = sample_rays['rays_mask'][0] + + color_fine = render_out['color_fine'] + color_fine_mask = render_out['color_fine_mask'] + depth_pred = render_out['depth'] + + variance = render_out['variance'] + cdf_fine = render_out['cdf_fine'] + weight_sum = render_out['weights_sum'] + + if self.train_from_scratch: + occlusion_aware = False if iter_step < 5000 else True + else: + occlusion_aware = True + + gradient_error_fine = render_out['gradient_error_fine'] + + sdf = render_out['sdf'] + + # * color generated by mlp + color_mlp = render_out['color_mlp'] + color_mlp_mask = render_out['color_mlp_mask'] + + if color_mlp is not None: + # Color loss + color_mlp_mask = color_mlp_mask[..., 0] + + color_mlp_loss, color_mlp_error = self.occlusion_color_criterion(pred=color_mlp, gt=true_rgb, + weight=weight_sum.squeeze(), + mask=color_mlp_mask, + occlusion_aware=occlusion_aware) + + psnr_mlp = 20.0 * torch.log10( + 1.0 / (((color_mlp[color_mlp_mask] - true_rgb[color_mlp_mask]) ** 2).mean() / (3.0)).sqrt()) + else: + color_mlp_loss = 0. + psnr_mlp = 0. + + # - blended patch loss + blended_color_patch = render_out['blended_color_patch'] # [N_pts, Npx, 3] + blended_color_patch_mask = render_out['blended_color_patch_mask'] # [N_pts, 1] + color_patch_loss = 0.0 + color_patch_error = 0.0 + visibility_beta = 0.0 + if blended_color_patch is not None: + rays_patch_color = sample_rays['rays_patch_color'][0] + rays_patch_mask = sample_rays['rays_patch_mask'][0] + patch_mask = (rays_patch_mask * blended_color_patch_mask).float()[:, 0] > 0 # [N_pts] + + color_patch_loss, color_patch_error, visibility_beta = self.occlusion_color_patch_criterion( + blended_color_patch, + rays_patch_color, + weight=weight_sum.squeeze(), + mask=patch_mask, + penalize_ratio=self.visibility_penalize_ratio, + occlusion_aware=occlusion_aware + ) + + if true_depth is not None: + depth_loss = self.depth_criterion(depth_pred, true_depth, mask) + + # depth evaluation + depth_statis = compute_depth_errors(depth_pred.detach().cpu().numpy(), true_depth.cpu().numpy(), + mask.cpu().numpy() > 0) + depth_statis = numpy2tensor(depth_statis, device=rays_o.device) + else: + depth_loss = 0. + depth_statis = None + + # - if without sparse_loss, the mean sdf is 0.02. + # - use sparse_loss to prevent occluded pts have 0 sdf + sparse_loss_1 = torch.exp(-1 * torch.abs(render_out['sdf_random']) * self.sdf_decay_param * 10).mean() + sparse_loss_2 = torch.exp(-1 * torch.abs(sdf) * self.sdf_decay_param).mean() + sparse_loss = (sparse_loss_1 + sparse_loss_2) / 2 + + sdf_mean = torch.abs(sdf).mean() + sparseness_1 = (torch.abs(sdf) < 0.01).to(torch.float32).mean() + sparseness_2 = (torch.abs(sdf) < 0.02).to(torch.float32).mean() + + # Eikonal loss + gradient_error_loss = gradient_error_fine + + # * optional TV regularizer + if 'tv' in render_out.keys(): + tv = render_out['tv'] + else: + tv = 0.0 + + loss = color_mlp_loss + \ + color_patch_loss * self.color_patch_weight + \ + sparse_loss * get_weight(iter_step, self.sdf_sparse_weight) + \ + gradient_error_loss * self.sdf_igr_weight + + losses = { + "loss": loss, + "depth_loss": depth_loss, + "color_mlp_loss": color_mlp_error, + "gradient_error_loss": gradient_error_loss, + "sparse_loss": sparse_loss, + "sparseness_1": sparseness_1, + "sparseness_2": sparseness_2, + "sdf_mean": sdf_mean, + "psnr_mlp": psnr_mlp, + "weights_sum": render_out['weights_sum'], + "alpha_sum": render_out['alpha_sum'], + "variance": render_out['variance'], + "sparse_weight": get_weight(iter_step, self.sdf_sparse_weight), + 'color_patch_loss': color_patch_error, + 'visibility_beta': visibility_beta, + 'tv': tv, + } + + losses = numpy2tensor(losses, device=rays_o.device) + + return loss, losses, depth_statis + + def validate_mesh(self, sdf_network, func_extract_geometry, world_space=True, resolution=256, + threshold=0.0, mode='val', + # * 3d feature volume + conditional_volume=None, lod=None, occupancy_mask=None, + bound_min=[-1, -1, -1], bound_max=[1, 1, 1], meta='', iter_step=0, scale_mat=None, + trans_mat=None + ): + bound_min = torch.tensor(bound_min, dtype=torch.float32) + bound_max = torch.tensor(bound_max, dtype=torch.float32) + + vertices, triangles, fields = func_extract_geometry( + sdf_network, + bound_min, bound_max, resolution=resolution, + threshold=threshold, device=conditional_volume.device, + # * 3d feature volume + conditional_volume=conditional_volume, lod=lod, + # occupancy_mask=occupancy_mask + ) + + + + if scale_mat is not None: + scale_mat_np = scale_mat.cpu().numpy() + vertices = vertices * scale_mat_np[0][0, 0] + scale_mat_np[0][:3, 3][None] + + if trans_mat is not None: + trans_mat_np = trans_mat.cpu().numpy() + vertices_homo = np.concatenate([vertices, np.ones_like(vertices[:, :1])], axis=1) + vertices = np.matmul(trans_mat_np, vertices_homo[:, :, None])[:, :3, 0] + + mesh = trimesh.Trimesh(vertices, triangles) + os.makedirs(os.path.join(self.base_exp_dir, 'meshes_' + mode), exist_ok=True) + mesh.export(os.path.join(self.base_exp_dir, 'meshes_' + mode, + 'mesh_{:0>8d}_{}_lod{:0>1d}.ply'.format(iter_step, meta, lod))) + + def gen_video(self, sample, + perturb_overwrite=-1, + background_rgb=None, + iter_step=0, + chunk_size=1024, + ): + # * only support batch_size==1 + batch_idx = sample['batch_idx'][0] + meta = sample['meta'][batch_idx] # the scan lighting ref_view info + + sizeW = sample['img_wh'][0][0] + sizeH = sample['img_wh'][0][1] + H, W = sizeH, sizeW + + partial_vol_origin = sample['partial_vol_origin'] # [B, 3] + near, far = sample['query_near_far'][0, :1], sample['query_near_far'][0, 1:] * 0.8 + + img_index = sample['img_index'][0] # [n] + + # the ray variables + sample_rays = sample['rays'] + rays_o = sample_rays['rays_o'][0] + rays_d = sample_rays['rays_v'][0] + rays_ndc_uv = sample_rays['rays_ndc_uv'][0] + + imgs = sample['images'][0] + intrinsics = sample['intrinsics'][0] + intrinsics_l_4x = intrinsics.clone() + intrinsics_l_4x[:, :2] *= 0.25 + w2cs = sample['w2cs'][0] + c2ws = sample['c2ws'][0] + proj_matrices = sample['affine_mats'] + + # - the image to render + scale_mat = sample['scale_mat'] # [1,4,4] used to convert mesh into true scale + trans_mat = sample['trans_mat'] + query_c2w = sample['query_c2w'] # [1,4,4] + query_w2c = sample['query_w2c'] # [1,4,4] + true_img = sample['query_image'][0] + true_img = np.uint8(true_img.permute(1, 2, 0).cpu().numpy() * 255) + rendering_c2ws = sample['rendering_c2ws'][0] # [n, 4, 4] + rendering_imgs_idx = sample['rendering_imgs_idx'][0] + + depth_min, depth_max = near.cpu().numpy(), far.cpu().numpy() + + true_depth = sample['query_depth'] if 'query_depth' in sample.keys() else None + if true_depth is not None: + true_depth = true_depth[0].cpu().numpy() + true_depth_colored = visualize_depth_numpy(true_depth, [depth_min, depth_max])[0] + else: + true_depth_colored = None + + # - obtain conditional features + with torch.no_grad(): + # - lod 0 + conditional_features_lod0 = self.sdf_network_finetune.get_conditional_volume() + + con_volume_lod0 = conditional_features_lod0['dense_volume_scale0'] + con_valid_mask_volume_lod0 = conditional_features_lod0['valid_mask_volume_scale0'] + # coords_lod0 = conditional_features_lod0['coords_scale0'] # [1,3,wX,wY,wZ] + + inter_views_num = 60 + resolution_level = 2 + for r_idx in range(rendering_c2ws.shape[0] - 1): + for idx in range(inter_views_num): + query_c2w, rays_o, rays_d = gen_rays_between( + rendering_c2ws[r_idx], rendering_c2ws[r_idx + 1], intrinsics[0], + np.sin(((idx / 60.0) - 0.5) * np.pi) * 0.5 + 0.5, + H, W, resolution_level=resolution_level) + + rays_o = rays_o.reshape(-1, 3).split(chunk_size) + rays_d = rays_d.reshape(-1, 3).split(chunk_size) + + out_rgb_fine = [] + out_normal_fine = [] + out_depth_fine = [] + + for rays_o_batch, rays_d_batch in zip(rays_o, rays_d): + # ****** lod 0 **** + render_out = self.sdf_renderer_finetune.render( + rays_o_batch, rays_d_batch, near, far, + self.sdf_network_finetune, + None, + background_rgb=background_rgb, + alpha_inter_ratio=1., + # * related to conditional feature + lod=0, + conditional_volume=con_volume_lod0, + conditional_valid_mask_volume=con_valid_mask_volume_lod0, + # * 2d feature maps + feature_maps=None, + color_maps=imgs, + w2cs=w2cs, + intrinsics=intrinsics, + img_wh=[sizeW, sizeH], + query_c2w=query_c2w, + if_general_rendering=False, + if_render_with_grad=False, + img_index=img_index, + # rays_uv=rays_ndc_uv + ) + # pdb.set_trace() + feasible = lambda key: ((key in render_out) and (render_out[key] is not None)) + + if feasible('depth'): + out_depth_fine.append(render_out['depth'].detach().cpu().numpy()) + + # if render_out['color_coarse'] is not None: + if feasible('color_mlp'): + out_rgb_fine.append(render_out['color_mlp'].detach().cpu().numpy()) + if feasible('gradients') and feasible('weights'): + if render_out['inside_sphere'] is not None: + out_normal_fine.append((render_out['gradients'] * render_out['weights'][:, + :self.n_samples + self.n_importance, + None] * render_out['inside_sphere'][ + ..., None]).sum(dim=1).detach().cpu().numpy()) + else: + out_normal_fine.append((render_out['gradients'] * render_out['weights'][:, + :self.n_samples + self.n_importance, + None]).sum(dim=1).detach().cpu().numpy()) + del render_out + + img_fine = (np.concatenate(out_rgb_fine, axis=0).reshape( + [H // resolution_level, W // resolution_level, 3, -1]) * 256).clip(0, 255) + save_dir = os.path.join(self.base_exp_dir, 'render_{}_{}'.format(rendering_imgs_idx[r_idx], + rendering_imgs_idx[r_idx + 1])) + os.makedirs(save_dir, exist_ok=True) + # ic(img_fine.shape) + print(cv.imwrite( + os.path.join(save_dir, '{}.png'.format(idx + r_idx * inter_views_num)), + img_fine.squeeze()[:, :, ::-1])) + print(os.path.join(save_dir, '{}.png'.format(idx + r_idx * inter_views_num))) diff --git a/SparseNeuS_demo_v1/models/trainer_generic.py b/SparseNeuS_demo_v1/models/trainer_generic.py new file mode 100644 index 0000000000000000000000000000000000000000..5c87d61d5c7feb93dadd40099a5ebe0a9db81924 --- /dev/null +++ b/SparseNeuS_demo_v1/models/trainer_generic.py @@ -0,0 +1,1224 @@ +""" +decouple the trainer with the renderer +""" +import os +import cv2 as cv +import torch +import torch.nn as nn +import torch.nn.functional as F + +import numpy as np +import logging +import mcubes +import trimesh +from icecream import ic + +from utils.misc_utils import visualize_depth_numpy + +from utils.training_utils import numpy2tensor +from loss.depth_metric import compute_depth_errors + +from loss.depth_loss import DepthLoss, DepthSmoothLoss + +from models.rays import gen_rays_between + +from models.sparse_neus_renderer import SparseNeuSRenderer + +def safe_l2_normalize(x, dim=None, eps=1e-6): + return F.normalize(x, p=2, dim=dim, eps=eps) + + +class GenericTrainer(nn.Module): + def __init__(self, + rendering_network_outside, + pyramid_feature_network_lod0, + pyramid_feature_network_lod1, + sdf_network_lod0, + sdf_network_lod1, + variance_network_lod0, + variance_network_lod1, + rendering_network_lod0, + rendering_network_lod1, + n_samples_lod0, + n_importance_lod0, + n_samples_lod1, + n_importance_lod1, + n_outside, + perturb, + alpha_type='div', + conf=None, + timestamp="", + mode='train', + base_exp_dir=None, + ): + super(GenericTrainer, self).__init__() + + self.conf = conf + self.timestamp = timestamp + + + self.base_exp_dir = base_exp_dir + + + self.anneal_start = self.conf.get_float('train.anneal_start', default=0.0) + self.anneal_end = self.conf.get_float('train.anneal_end', default=0.0) + self.anneal_start_lod1 = self.conf.get_float('train.anneal_start_lod1', default=0.0) + self.anneal_end_lod1 = self.conf.get_float('train.anneal_end_lod1', default=0.0) + + # network setups + self.rendering_network_outside = rendering_network_outside + self.pyramid_feature_network_geometry_lod0 = pyramid_feature_network_lod0 # 2D pyramid feature network for geometry + self.pyramid_feature_network_geometry_lod1 = pyramid_feature_network_lod1 # use differnet networks for the two lods + + # when num_lods==2, may consume too much memeory + self.sdf_network_lod0 = sdf_network_lod0 + self.sdf_network_lod1 = sdf_network_lod1 + + # - warpped by ModuleList to support DataParallel + self.variance_network_lod0 = variance_network_lod0 + self.variance_network_lod1 = variance_network_lod1 + + self.rendering_network_lod0 = rendering_network_lod0 + self.rendering_network_lod1 = rendering_network_lod1 + + self.n_samples_lod0 = n_samples_lod0 + self.n_importance_lod0 = n_importance_lod0 + self.n_samples_lod1 = n_samples_lod1 + self.n_importance_lod1 = n_importance_lod1 + self.n_outside = n_outside + self.num_lods = conf.get_int('model.num_lods') # the number of octree lods + self.perturb = perturb + self.alpha_type = alpha_type + + # - the two renderers + self.sdf_renderer_lod0 = SparseNeuSRenderer( + self.rendering_network_outside, + self.sdf_network_lod0, + self.variance_network_lod0, + self.rendering_network_lod0, + self.n_samples_lod0, + self.n_importance_lod0, + self.n_outside, + self.perturb, + alpha_type='div', + conf=self.conf) + + self.sdf_renderer_lod1 = SparseNeuSRenderer( + self.rendering_network_outside, + self.sdf_network_lod1, + self.variance_network_lod1, + self.rendering_network_lod1, + self.n_samples_lod1, + self.n_importance_lod1, + self.n_outside, + self.perturb, + alpha_type='div', + conf=self.conf) + + self.if_fix_lod0_networks = self.conf.get_bool('train.if_fix_lod0_networks') + + # sdf network weights + self.sdf_igr_weight = self.conf.get_float('train.sdf_igr_weight') + self.sdf_sparse_weight = self.conf.get_float('train.sdf_sparse_weight', default=0) + self.sdf_decay_param = self.conf.get_float('train.sdf_decay_param', default=100) + self.fg_bg_weight = self.conf.get_float('train.fg_bg_weight', default=0.00) + self.bg_ratio = self.conf.get_float('train.bg_ratio', default=0.0) + + self.depth_criterion = DepthLoss() + + # - DataParallel mode, cannot modify attributes in forward() + # self.iter_step = 0 + self.val_mesh_freq = self.conf.get_int('train.val_mesh_freq') + + # - True for finetuning; False for general training + self.if_fitted_rendering = self.conf.get_bool('train.if_fitted_rendering', default=False) + + self.prune_depth_filter = self.conf.get_bool('model.prune_depth_filter', default=False) + + def get_trainable_params(self): + # set trainable params + + self.params_to_train = [] + + if not self.if_fix_lod0_networks: + # load pretrained featurenet + self.params_to_train += list(self.pyramid_feature_network_geometry_lod0.parameters()) + self.params_to_train += list(self.sdf_network_lod0.parameters()) + self.params_to_train += list(self.variance_network_lod0.parameters()) + + if self.rendering_network_lod0 is not None: + self.params_to_train += list(self.rendering_network_lod0.parameters()) + + if self.sdf_network_lod1 is not None: + # load pretrained featurenet + self.params_to_train += list(self.pyramid_feature_network_geometry_lod1.parameters()) + + self.params_to_train += list(self.sdf_network_lod1.parameters()) + self.params_to_train += list(self.variance_network_lod1.parameters()) + if self.rendering_network_lod1 is not None: + self.params_to_train += list(self.rendering_network_lod1.parameters()) + + return self.params_to_train + + def train_step(self, sample, + perturb_overwrite=-1, + background_rgb=None, + alpha_inter_ratio_lod0=0.0, + alpha_inter_ratio_lod1=0.0, + iter_step=0, + ): + # * only support batch_size==1 + # ! attention: the list of string cannot be splited in DataParallel + batch_idx = sample['batch_idx'][0] + meta = sample['meta'][batch_idx] # the scan lighting ref_view info + + sizeW = sample['img_wh'][0][0] + sizeH = sample['img_wh'][0][1] + partial_vol_origin = sample['partial_vol_origin'] # [B, 3] + near, far = sample['near_fars'][0, 0, :1], sample['near_fars'][0, 0, 1:] + + # the full-size ray variables + sample_rays = sample['rays'] + rays_o = sample_rays['rays_o'][0] + rays_d = sample_rays['rays_v'][0] + + imgs = sample['images'][0] + intrinsics = sample['intrinsics'][0] + intrinsics_l_4x = intrinsics.clone() + intrinsics_l_4x[:, :2] *= 0.25 + w2cs = sample['w2cs'][0] + c2ws = sample['c2ws'][0] + proj_matrices = sample['affine_mats'] + scale_mat = sample['scale_mat'] + trans_mat = sample['trans_mat'] + + # *********************** Lod==0 *********************** + if not self.if_fix_lod0_networks: + geometry_feature_maps = self.obtain_pyramid_feature_maps(imgs) + + conditional_features_lod0 = self.sdf_network_lod0.get_conditional_volume( + feature_maps=geometry_feature_maps[None, 1:, :, :, :], + partial_vol_origin=partial_vol_origin, + proj_mats=proj_matrices[:,1:], + # proj_mats=proj_matrices, + sizeH=sizeH, + sizeW=sizeW, + lod=0, + ) + + else: + with torch.no_grad(): + geometry_feature_maps = self.obtain_pyramid_feature_maps(imgs, lod=0) + # geometry_feature_maps = self.obtain_pyramid_feature_maps(imgs, lod=0) + conditional_features_lod0 = self.sdf_network_lod0.get_conditional_volume( + feature_maps=geometry_feature_maps[None, 1:, :, :, :], + partial_vol_origin=partial_vol_origin, + proj_mats=proj_matrices[:,1:], + # proj_mats=proj_matrices, + sizeH=sizeH, + sizeW=sizeW, + lod=0, + ) + # print("Checker2:, construct cost volume") + con_volume_lod0 = conditional_features_lod0['dense_volume_scale0'] + + con_valid_mask_volume_lod0 = conditional_features_lod0['valid_mask_volume_scale0'] + # import ipdb; ipdb.set_trace() + coords_lod0 = conditional_features_lod0['coords_scale0'] # [1,3,wX,wY,wZ] + + # * extract depth maps for all the images + depth_maps_lod0, depth_masks_lod0 = None, None + if self.num_lods > 1: + sdf_volume_lod0 = self.sdf_network_lod0.get_sdf_volume( + con_volume_lod0, con_valid_mask_volume_lod0, + coords_lod0, partial_vol_origin) # [1, 1, dX, dY, dZ] + + if self.prune_depth_filter: + depth_maps_lod0_l4x, depth_masks_lod0_l4x = self.sdf_renderer_lod0.extract_depth_maps( + self.sdf_network_lod0, sdf_volume_lod0, intrinsics_l_4x, c2ws, + sizeH // 4, sizeW // 4, near * 1.5, far) + depth_maps_lod0 = F.interpolate(depth_maps_lod0_l4x, size=(sizeH, sizeW), mode='bilinear', + align_corners=True) + depth_masks_lod0 = F.interpolate(depth_masks_lod0_l4x.float(), size=(sizeH, sizeW), mode='nearest') + + # *************** losses + loss_lod0, losses_lod0, depth_statis_lod0 = None, None, None + + if not self.if_fix_lod0_networks: + + render_out = self.sdf_renderer_lod0.render( + rays_o, rays_d, near, far, + self.sdf_network_lod0, + self.rendering_network_lod0, + background_rgb=background_rgb, + alpha_inter_ratio=alpha_inter_ratio_lod0, + # * related to conditional feature + lod=0, + conditional_volume=con_volume_lod0, + conditional_valid_mask_volume=con_valid_mask_volume_lod0, + # * 2d feature maps + feature_maps=geometry_feature_maps, + color_maps=imgs, + w2cs=w2cs, + intrinsics=intrinsics, + img_wh=[sizeW, sizeH], + if_general_rendering=True, + if_render_with_grad=True, + ) + + loss_lod0, losses_lod0, depth_statis_lod0 = self.cal_losses_sdf(render_out, sample_rays, + iter_step, lod=0) + + # *********************** Lod==1 *********************** + + loss_lod1, losses_lod1, depth_statis_lod1 = None, None, None + + if self.num_lods > 1: + geometry_feature_maps_lod1 = self.obtain_pyramid_feature_maps(imgs, lod=1) + # geometry_feature_maps_lod1 = self.obtain_pyramid_feature_maps(imgs, lod=1) + if self.prune_depth_filter: + pre_coords, pre_feats = self.sdf_renderer_lod0.get_valid_sparse_coords_by_sdf_depthfilter( + sdf_volume_lod0[0], coords_lod0[0], con_valid_mask_volume_lod0[0], con_volume_lod0[0], + depth_maps_lod0, proj_matrices[0], + partial_vol_origin, self.sdf_network_lod0.voxel_size, + near, far, self.sdf_network_lod0.voxel_size, 12) + else: + pre_coords, pre_feats = self.sdf_renderer_lod0.get_valid_sparse_coords_by_sdf( + sdf_volume_lod0[0], coords_lod0[0], con_valid_mask_volume_lod0[0], con_volume_lod0[0]) + + pre_coords[:, 1:] = pre_coords[:, 1:] * 2 + + # ? It seems that training gru_fusion, this part should be trainable too + conditional_features_lod1 = self.sdf_network_lod1.get_conditional_volume( + feature_maps=geometry_feature_maps_lod1[None, 1:, :, :, :], + partial_vol_origin=partial_vol_origin, + proj_mats=proj_matrices[:,1:], + # proj_mats=proj_matrices, + sizeH=sizeH, + sizeW=sizeW, + pre_coords=pre_coords, + pre_feats=pre_feats, + ) + + con_volume_lod1 = conditional_features_lod1['dense_volume_scale1'] + con_valid_mask_volume_lod1 = conditional_features_lod1['valid_mask_volume_scale1'] + + # if not self.if_gru_fusion_lod1: + render_out_lod1 = self.sdf_renderer_lod1.render( + rays_o, rays_d, near, far, + self.sdf_network_lod1, + self.rendering_network_lod1, + background_rgb=background_rgb, + alpha_inter_ratio=alpha_inter_ratio_lod1, + # * related to conditional feature + lod=1, + conditional_volume=con_volume_lod1, + conditional_valid_mask_volume=con_valid_mask_volume_lod1, + # * 2d feature maps + feature_maps=geometry_feature_maps_lod1, + color_maps=imgs, + w2cs=w2cs, + intrinsics=intrinsics, + img_wh=[sizeW, sizeH], + bg_ratio=self.bg_ratio, + ) + loss_lod1, losses_lod1, depth_statis_lod1 = self.cal_losses_sdf(render_out_lod1, sample_rays, + iter_step, lod=1) + + # print("Checker3:, compute losses") + # # - extract mesh + if iter_step % self.val_mesh_freq == 0: + torch.cuda.empty_cache() + self.validate_mesh(self.sdf_network_lod0, + self.sdf_renderer_lod0.extract_geometry, + conditional_volume=con_volume_lod0, lod=0, + threshold=0, + # occupancy_mask=con_valid_mask_volume_lod0[0, 0], + mode='train_bg', meta=meta, + iter_step=iter_step, scale_mat=scale_mat, + trans_mat=trans_mat) + torch.cuda.empty_cache() + + if self.num_lods > 1: + self.validate_mesh(self.sdf_network_lod1, + self.sdf_renderer_lod1.extract_geometry, + conditional_volume=con_volume_lod1, lod=1, + # occupancy_mask=con_valid_mask_volume_lod1[0, 0].detach(), + mode='train_bg', meta=meta, + iter_step=iter_step, scale_mat=scale_mat, + trans_mat=trans_mat) + # import ipdb; ipdb.set_trace() + # print("Checker3.1:, after val mesh") + losses = { + # - lod 0 + 'loss_lod0': loss_lod0, + 'losses_lod0': losses_lod0, + 'depth_statis_lod0': depth_statis_lod0, + + # - lod 1 + 'loss_lod1': loss_lod1, + 'losses_lod1': losses_lod1, + 'depth_statis_lod1': depth_statis_lod1, + + } + + return losses + + def val_step(self, sample, + perturb_overwrite=-1, + background_rgb=None, + alpha_inter_ratio_lod0=0.0, + alpha_inter_ratio_lod1=0.0, + iter_step=0, + chunk_size=512, + save_vis=False, + ): + # * only support batch_size==1 + # ! attention: the list of string cannot be splited in DataParallel + batch_idx = sample['batch_idx'][0] + meta = sample['meta'][batch_idx] # the scan lighting ref_view info + + sizeW = sample['img_wh'][0][0] + sizeH = sample['img_wh'][0][1] + H, W = sizeH, sizeW + + partial_vol_origin = sample['partial_vol_origin'] # [B, 3] + near, far = sample['query_near_far'][0, :1], sample['query_near_far'][0, 1:] + + # the ray variables + sample_rays = sample['rays'] + rays_o = sample_rays['rays_o'][0] + rays_d = sample_rays['rays_v'][0] + rays_ndc_uv = sample_rays['rays_ndc_uv'][0] + + imgs = sample['images'][0] + intrinsics = sample['intrinsics'][0] + intrinsics_l_4x = intrinsics.clone() + intrinsics_l_4x[:, :2] *= 0.25 + w2cs = sample['w2cs'][0] + c2ws = sample['c2ws'][0] + proj_matrices = sample['affine_mats'] + + # render_img_idx = sample['render_img_idx'][0] + # true_img = sample['images'][0][render_img_idx] + + # - the image to render + scale_mat = sample['scale_mat'] # [1,4,4] used to convert mesh into true scale + trans_mat = sample['trans_mat'] + query_c2w = sample['query_c2w'] # [1,4,4] + query_w2c = sample['query_w2c'] # [1,4,4] + true_img = sample['query_image'][0] + true_img = np.uint8(true_img.permute(1, 2, 0).cpu().numpy() * 255) + + depth_min, depth_max = near.cpu().numpy(), far.cpu().numpy() + + scale_factor = sample['scale_factor'][0].cpu().numpy() + true_depth = sample['query_depth'] if 'query_depth' in sample.keys() else None + if true_depth is not None: + true_depth = true_depth[0].cpu().numpy() + true_depth_colored = visualize_depth_numpy(true_depth, [depth_min, depth_max])[0] + else: + true_depth_colored = None + + rays_o = rays_o.reshape(-1, 3).split(chunk_size) + rays_d = rays_d.reshape(-1, 3).split(chunk_size) + + # - obtain conditional features + with torch.no_grad(): + # - obtain conditional features + geometry_feature_maps = self.obtain_pyramid_feature_maps(imgs, lod=0) + # import ipdb; ipdb.set_trace() + # - lod 0 + conditional_features_lod0 = self.sdf_network_lod0.get_conditional_volume( + feature_maps=geometry_feature_maps[None, :, :, :, :], + partial_vol_origin=partial_vol_origin, + proj_mats=proj_matrices, + sizeH=sizeH, + sizeW=sizeW, + lod=0, + ) + + con_volume_lod0 = conditional_features_lod0['dense_volume_scale0'] + con_valid_mask_volume_lod0 = conditional_features_lod0['valid_mask_volume_scale0'] + coords_lod0 = conditional_features_lod0['coords_scale0'] # [1,3,wX,wY,wZ] + + if self.num_lods > 1: + sdf_volume_lod0 = self.sdf_network_lod0.get_sdf_volume( + con_volume_lod0, con_valid_mask_volume_lod0, + coords_lod0, partial_vol_origin) # [1, 1, dX, dY, dZ] + + depth_maps_lod0, depth_masks_lod0 = None, None + if self.prune_depth_filter: + depth_maps_lod0_l4x, depth_masks_lod0_l4x = self.sdf_renderer_lod0.extract_depth_maps( + self.sdf_network_lod0, sdf_volume_lod0, + intrinsics_l_4x, c2ws, + sizeH // 4, sizeW // 4, near * 1.5, far) # - near*1.5 is a experienced number + depth_maps_lod0 = F.interpolate(depth_maps_lod0_l4x, size=(sizeH, sizeW), mode='bilinear', + align_corners=True) + depth_masks_lod0 = F.interpolate(depth_masks_lod0_l4x.float(), size=(sizeH, sizeW), mode='nearest') + + #### visualize the depth_maps_lod0 for checking + colored_depth_maps_lod0 = [] + for i in range(depth_maps_lod0.shape[0]): + colored_depth_maps_lod0.append( + visualize_depth_numpy(depth_maps_lod0[i, 0].cpu().numpy(), [depth_min, depth_max])[0]) + + colored_depth_maps_lod0 = np.concatenate(colored_depth_maps_lod0, axis=0).astype(np.uint8) + os.makedirs(os.path.join(self.base_exp_dir, 'depth_maps_lod0'), exist_ok=True) + cv.imwrite(os.path.join(self.base_exp_dir, 'depth_maps_lod0', + '{:0>8d}_{}.png'.format(iter_step, meta)), + colored_depth_maps_lod0[:, :, ::-1]) + + if self.num_lods > 1: + geometry_feature_maps_lod1 = self.obtain_pyramid_feature_maps(imgs, lod=1) + + if self.prune_depth_filter: + pre_coords, pre_feats = self.sdf_renderer_lod0.get_valid_sparse_coords_by_sdf_depthfilter( + sdf_volume_lod0[0], coords_lod0[0], con_valid_mask_volume_lod0[0], con_volume_lod0[0], + depth_maps_lod0, proj_matrices[0], + partial_vol_origin, self.sdf_network_lod0.voxel_size, + near, far, self.sdf_network_lod0.voxel_size, 12) + else: + pre_coords, pre_feats = self.sdf_renderer_lod0.get_valid_sparse_coords_by_sdf( + sdf_volume_lod0[0], coords_lod0[0], con_valid_mask_volume_lod0[0], con_volume_lod0[0]) + + pre_coords[:, 1:] = pre_coords[:, 1:] * 2 + + with torch.no_grad(): + conditional_features_lod1 = self.sdf_network_lod1.get_conditional_volume( + feature_maps=geometry_feature_maps_lod1[None, :, :, :, :], + partial_vol_origin=partial_vol_origin, + proj_mats=proj_matrices, + sizeH=sizeH, + sizeW=sizeW, + pre_coords=pre_coords, + pre_feats=pre_feats, + ) + + con_volume_lod1 = conditional_features_lod1['dense_volume_scale1'] + con_valid_mask_volume_lod1 = conditional_features_lod1['valid_mask_volume_scale1'] + + out_rgb_fine = [] + out_normal_fine = [] + out_depth_fine = [] + + out_rgb_fine_lod1 = [] + out_normal_fine_lod1 = [] + out_depth_fine_lod1 = [] + + # out_depth_fine_explicit = [] + if save_vis: + for rays_o_batch, rays_d_batch in zip(rays_o, rays_d): + + # ****** lod 0 **** + render_out = self.sdf_renderer_lod0.render( + rays_o_batch, rays_d_batch, near, far, + self.sdf_network_lod0, + self.rendering_network_lod0, + background_rgb=background_rgb, + alpha_inter_ratio=alpha_inter_ratio_lod0, + # * related to conditional feature + lod=0, + conditional_volume=con_volume_lod0, + conditional_valid_mask_volume=con_valid_mask_volume_lod0, + # * 2d feature maps + feature_maps=geometry_feature_maps, + color_maps=imgs, + w2cs=w2cs, + intrinsics=intrinsics, + img_wh=[sizeW, sizeH], + query_c2w=query_c2w, + if_render_with_grad=False, + ) + + feasible = lambda key: ((key in render_out) and (render_out[key] is not None)) + + if feasible('depth'): + out_depth_fine.append(render_out['depth'].detach().cpu().numpy()) + + # if render_out['color_coarse'] is not None: + if feasible('color_fine'): + out_rgb_fine.append(render_out['color_fine'].detach().cpu().numpy()) + if feasible('gradients') and feasible('weights'): + if render_out['inside_sphere'] is not None: + out_normal_fine.append((render_out['gradients'] * render_out['weights'][:, + :self.n_samples_lod0 + self.n_importance_lod0, + None] * render_out['inside_sphere'][ + ..., None]).sum(dim=1).detach().cpu().numpy()) + else: + out_normal_fine.append((render_out['gradients'] * render_out['weights'][:, + :self.n_samples_lod0 + self.n_importance_lod0, + None]).sum(dim=1).detach().cpu().numpy()) + del render_out + + # ****************** lod 1 ************************** + if self.num_lods > 1: + for rays_o_batch, rays_d_batch in zip(rays_o, rays_d): + render_out_lod1 = self.sdf_renderer_lod1.render( + rays_o_batch, rays_d_batch, near, far, + self.sdf_network_lod1, + self.rendering_network_lod1, + background_rgb=background_rgb, + alpha_inter_ratio=alpha_inter_ratio_lod1, + # * related to conditional feature + lod=1, + conditional_volume=con_volume_lod1, + conditional_valid_mask_volume=con_valid_mask_volume_lod1, + # * 2d feature maps + feature_maps=geometry_feature_maps_lod1, + color_maps=imgs, + w2cs=w2cs, + intrinsics=intrinsics, + img_wh=[sizeW, sizeH], + query_c2w=query_c2w, + if_render_with_grad=False, + ) + + feasible = lambda key: ((key in render_out_lod1) and (render_out_lod1[key] is not None)) + + if feasible('depth'): + out_depth_fine_lod1.append(render_out_lod1['depth'].detach().cpu().numpy()) + + # if render_out['color_coarse'] is not None: + if feasible('color_fine'): + out_rgb_fine_lod1.append(render_out_lod1['color_fine'].detach().cpu().numpy()) + if feasible('gradients') and feasible('weights'): + if render_out_lod1['inside_sphere'] is not None: + out_normal_fine_lod1.append((render_out_lod1['gradients'] * render_out_lod1['weights'][:, + :self.n_samples_lod1 + self.n_importance_lod1, + None] * + render_out_lod1['inside_sphere'][ + ..., None]).sum(dim=1).detach().cpu().numpy()) + else: + out_normal_fine_lod1.append((render_out_lod1['gradients'] * render_out_lod1['weights'][:, + :self.n_samples_lod1 + self.n_importance_lod1, + None]).sum( + dim=1).detach().cpu().numpy()) + del render_out_lod1 + + # - save visualization of lod 0 + + self.save_visualization(true_img, true_depth_colored, out_depth_fine, out_normal_fine, + query_w2c[0], out_rgb_fine, H, W, + depth_min, depth_max, iter_step, meta, "val_lod0", true_depth=true_depth, scale_factor=scale_factor) + + if self.num_lods > 1: + self.save_visualization(true_img, true_depth_colored, out_depth_fine_lod1, out_normal_fine_lod1, + query_w2c[0], out_rgb_fine_lod1, H, W, + depth_min, depth_max, iter_step, meta, "val_lod1", true_depth=true_depth, scale_factor=scale_factor) + + # - extract mesh + if (iter_step % self.val_mesh_freq == 0): + torch.cuda.empty_cache() + self.validate_mesh(self.sdf_network_lod0, + self.sdf_renderer_lod0.extract_geometry, + conditional_volume=con_volume_lod0, lod=0, + threshold=0, + # occupancy_mask=con_valid_mask_volume_lod0[0, 0], + mode='val_bg', meta=meta, + iter_step=iter_step, scale_mat=scale_mat, trans_mat=trans_mat) + torch.cuda.empty_cache() + + if self.num_lods > 1: + self.validate_mesh(self.sdf_network_lod1, + self.sdf_renderer_lod1.extract_geometry, + conditional_volume=con_volume_lod1, lod=1, + # occupancy_mask=con_valid_mask_volume_lod1[0, 0].detach(), + mode='val_bg', meta=meta, + iter_step=iter_step, scale_mat=scale_mat, trans_mat=trans_mat) + + torch.cuda.empty_cache() + + + + def export_mesh_step(self, sample, + perturb_overwrite=-1, + background_rgb=None, + alpha_inter_ratio_lod0=0.0, + alpha_inter_ratio_lod1=0.0, + iter_step=0, + chunk_size=512, + save_vis=False, + ): + # * only support batch_size==1 + # ! attention: the list of string cannot be splited in DataParallel + batch_idx = sample['batch_idx'][0] + meta = sample['meta'][batch_idx] # the scan lighting ref_view info + + sizeW = sample['img_wh'][0][0] + sizeH = sample['img_wh'][0][1] + H, W = sizeH, sizeW + + partial_vol_origin = sample['partial_vol_origin'] # [B, 3] + near, far = sample['query_near_far'][0, :1], sample['query_near_far'][0, 1:] + + # the ray variables + sample_rays = sample['rays'] + rays_o = sample_rays['rays_o'][0] + rays_d = sample_rays['rays_v'][0] + rays_ndc_uv = sample_rays['rays_ndc_uv'][0] + + imgs = sample['images'][0] + intrinsics = sample['intrinsics'][0] + intrinsics_l_4x = intrinsics.clone() + intrinsics_l_4x[:, :2] *= 0.25 + w2cs = sample['w2cs'][0] + c2ws = sample['c2ws'][0] + # target_candidate_w2cs = sample['target_candidate_w2cs'][0] + proj_matrices = sample['affine_mats'] + + + # - the image to render + scale_mat = sample['scale_mat'] # [1,4,4] used to convert mesh into true scale + trans_mat = sample['trans_mat'] + query_c2w = sample['query_c2w'] # [1,4,4] + query_w2c = sample['query_w2c'] # [1,4,4] + true_img = sample['query_image'][0] + true_img = np.uint8(true_img.permute(1, 2, 0).cpu().numpy() * 255) + + depth_min, depth_max = near.cpu().numpy(), far.cpu().numpy() + + scale_factor = sample['scale_factor'][0].cpu().numpy() + true_depth = sample['query_depth'] if 'query_depth' in sample.keys() else None + if true_depth is not None: + true_depth = true_depth[0].cpu().numpy() + true_depth_colored = visualize_depth_numpy(true_depth, [depth_min, depth_max])[0] + else: + true_depth_colored = None + + rays_o = rays_o.reshape(-1, 3).split(chunk_size) + rays_d = rays_d.reshape(-1, 3).split(chunk_size) + # import time + # jha_begin1 = time.time() + # - obtain conditional features + with torch.no_grad(): + # - obtain conditional features + geometry_feature_maps = self.obtain_pyramid_feature_maps(imgs, lod=0) + # - lod 0 + conditional_features_lod0 = self.sdf_network_lod0.get_conditional_volume( + feature_maps=geometry_feature_maps[None, :, :, :, :], + partial_vol_origin=partial_vol_origin, + proj_mats=proj_matrices, + sizeH=sizeH, + sizeW=sizeW, + lod=0, + ) + # jha_end1 = time.time() + # print("get_conditional_volume: ", jha_end1 - jha_begin1) + # jha_begin2 = time.time() + con_volume_lod0 = conditional_features_lod0['dense_volume_scale0'] + con_valid_mask_volume_lod0 = conditional_features_lod0['valid_mask_volume_scale0'] + coords_lod0 = conditional_features_lod0['coords_scale0'] # [1,3,wX,wY,wZ] + + if self.num_lods > 1: + sdf_volume_lod0 = self.sdf_network_lod0.get_sdf_volume( + con_volume_lod0, con_valid_mask_volume_lod0, + coords_lod0, partial_vol_origin) # [1, 1, dX, dY, dZ] + + depth_maps_lod0, depth_masks_lod0 = None, None + + + if self.num_lods > 1: + geometry_feature_maps_lod1 = self.obtain_pyramid_feature_maps(imgs, lod=1) + + if self.prune_depth_filter: + pre_coords, pre_feats = self.sdf_renderer_lod0.get_valid_sparse_coords_by_sdf_depthfilter( + sdf_volume_lod0[0], coords_lod0[0], con_valid_mask_volume_lod0[0], con_volume_lod0[0], + depth_maps_lod0, proj_matrices[0], + partial_vol_origin, self.sdf_network_lod0.voxel_size, + near, far, self.sdf_network_lod0.voxel_size, 12) + else: + pre_coords, pre_feats = self.sdf_renderer_lod0.get_valid_sparse_coords_by_sdf( + sdf_volume_lod0[0], coords_lod0[0], con_valid_mask_volume_lod0[0], con_volume_lod0[0]) + + pre_coords[:, 1:] = pre_coords[:, 1:] * 2 + + with torch.no_grad(): + conditional_features_lod1 = self.sdf_network_lod1.get_conditional_volume( + feature_maps=geometry_feature_maps_lod1[None, :, :, :, :], + partial_vol_origin=partial_vol_origin, + proj_mats=proj_matrices, + sizeH=sizeH, + sizeW=sizeW, + pre_coords=pre_coords, + pre_feats=pre_feats, + ) + + con_volume_lod1 = conditional_features_lod1['dense_volume_scale1'] + con_valid_mask_volume_lod1 = conditional_features_lod1['valid_mask_volume_scale1'] + + out_rgb_fine = [] + out_normal_fine = [] + out_depth_fine = [] + + out_rgb_fine_lod1 = [] + out_normal_fine_lod1 = [] + out_depth_fine_lod1 = [] + + # jha_end2 = time.time() + # print("interval before starting mesh export: ", jha_end2 - jha_begin2) + + # - extract mesh + if (iter_step % self.val_mesh_freq == 0): + torch.cuda.empty_cache() + # jha_begin3 = time.time() + self.validate_colored_mesh( + density_or_sdf_network=self.sdf_network_lod0, + func_extract_geometry=self.sdf_renderer_lod0.extract_geometry, + conditional_volume=con_volume_lod0, + conditional_valid_mask_volume = con_valid_mask_volume_lod0, + feature_maps=geometry_feature_maps, + color_maps=imgs, + w2cs=w2cs, + target_candidate_w2cs=None, + intrinsics=intrinsics, + rendering_network=self.rendering_network_lod0, + rendering_projector=self.sdf_renderer_lod0.rendering_projector, + lod=0, + threshold=0, + query_c2w=query_c2w, + mode='val_bg', meta=meta, + iter_step=iter_step, scale_mat=scale_mat, trans_mat=trans_mat + ) + torch.cuda.empty_cache() + # jha_end3 = time.time() + # print("validate_colored_mesh_test_time: ", jha_end3 - jha_begin3) + + if self.num_lods > 1: + self.validate_colored_mesh( + density_or_sdf_network=self.sdf_network_lod1, + func_extract_geometry=self.sdf_renderer_lod1.extract_geometry, + conditional_volume=con_volume_lod1, + conditional_valid_mask_volume = con_valid_mask_volume_lod1, + feature_maps=geometry_feature_maps, + color_maps=imgs, + w2cs=w2cs, + target_candidate_w2cs=None, + intrinsics=intrinsics, + rendering_network=self.rendering_network_lod1, + rendering_projector=self.sdf_renderer_lod1.rendering_projector, + lod=1, + threshold=0, + query_c2w=query_c2w, + mode='val_bg', meta=meta, + iter_step=iter_step, scale_mat=scale_mat, trans_mat=trans_mat + ) + torch.cuda.empty_cache() + + + + def save_visualization(self, true_img, true_colored_depth, out_depth, out_normal, w2cs, out_color, H, W, + depth_min, depth_max, iter_step, meta, comment, out_color_mlp=[], true_depth=None, scale_factor=1.0): + if len(out_color) > 0: + img_fine = (np.concatenate(out_color, axis=0).reshape([H, W, 3]) * 256).clip(0, 255) + + if len(out_color_mlp) > 0: + img_mlp = (np.concatenate(out_color_mlp, axis=0).reshape([H, W, 3]) * 256).clip(0, 255) + + if len(out_normal) > 0: + normal_img = np.concatenate(out_normal, axis=0) + rot = w2cs[:3, :3].detach().cpu().numpy() + # - convert normal from world space to camera space + normal_img = (np.matmul(rot[None, :, :], + normal_img[:, :, None]).reshape([H, W, 3]) * 128 + 128).clip(0, 255) + if len(out_depth) > 0: + pred_depth = np.concatenate(out_depth, axis=0).reshape([H, W]) + pred_depth_colored = visualize_depth_numpy(pred_depth, [depth_min, depth_max])[0] + + if len(out_depth) > 0: + os.makedirs(os.path.join(self.base_exp_dir, 'depths_' + comment), exist_ok=True) + if true_colored_depth is not None: + + if true_depth is not None: + depth_error_map = np.abs(true_depth - pred_depth) * 2.0 / scale_factor + # [256, 256, 1] -> [256, 256, 3] + depth_error_map = np.tile(depth_error_map[:, :, None], [1, 1, 3]) + print("meta: ", meta) + print("scale_factor: ", scale_factor) + print("depth_error_mean: ", depth_error_map.mean()) + # import ipdb; ipdb.set_trace() + depth_visualized = np.concatenate( + [(depth_error_map * 255).astype(np.uint8), true_colored_depth, pred_depth_colored, true_img], axis=1)[:, :, ::-1] + # print("depth_visualized.shape: ", depth_visualized.shape) + # write depth error result text on img, the input is a numpy array of [256, 1024, 3] + # cv.putText(depth_visualized.copy(), "depth_error_mean: {:.4f}".format(depth_error_map.mean()), (10, 30), cv.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2) + else: + depth_visualized = np.concatenate( + [true_colored_depth, pred_depth_colored, true_img])[:, :, ::-1] + cv.imwrite( + os.path.join(self.base_exp_dir, 'depths_' + comment, + '{:0>8d}_{}.png'.format(iter_step, meta)), depth_visualized + ) + else: + cv.imwrite( + os.path.join(self.base_exp_dir, 'depths_' + comment, + '{:0>8d}_{}.png'.format(iter_step, meta)), + np.concatenate( + [pred_depth_colored, true_img])[:, :, ::-1]) + if len(out_color) > 0: + os.makedirs(os.path.join(self.base_exp_dir, 'synthesized_color_' + comment), exist_ok=True) + cv.imwrite(os.path.join(self.base_exp_dir, 'synthesized_color_' + comment, + '{:0>8d}_{}.png'.format(iter_step, meta)), + np.concatenate( + [img_fine, true_img])[:, :, ::-1]) # bgr2rgb + # compute psnr (image pixel lie in [0, 255]) + mse_loss = np.mean((img_fine - true_img) ** 2) + psnr = 10 * np.log10(255 ** 2 / mse_loss) + + print("PSNR: ", psnr) + + if len(out_color_mlp) > 0: + os.makedirs(os.path.join(self.base_exp_dir, 'synthesized_color_mlp_' + comment), exist_ok=True) + cv.imwrite(os.path.join(self.base_exp_dir, 'synthesized_color_mlp_' + comment, + '{:0>8d}_{}.png'.format(iter_step, meta)), + np.concatenate( + [img_mlp, true_img])[:, :, ::-1]) # bgr2rgb + + if len(out_normal) > 0: + os.makedirs(os.path.join(self.base_exp_dir, 'normals_' + comment), exist_ok=True) + cv.imwrite(os.path.join(self.base_exp_dir, 'normals_' + comment, + '{:0>8d}_{}.png'.format(iter_step, meta)), + normal_img[:, :, ::-1]) + + def forward(self, sample, + perturb_overwrite=-1, + background_rgb=None, + alpha_inter_ratio_lod0=0.0, + alpha_inter_ratio_lod1=0.0, + iter_step=0, + mode='train', + save_vis=False, + ): + + if mode == 'train': + return self.train_step(sample, + perturb_overwrite=perturb_overwrite, + background_rgb=background_rgb, + alpha_inter_ratio_lod0=alpha_inter_ratio_lod0, + alpha_inter_ratio_lod1=alpha_inter_ratio_lod1, + iter_step=iter_step + ) + elif mode == 'val': + import time + begin = time.time() + result = self.val_step(sample, + perturb_overwrite=perturb_overwrite, + background_rgb=background_rgb, + alpha_inter_ratio_lod0=alpha_inter_ratio_lod0, + alpha_inter_ratio_lod1=alpha_inter_ratio_lod1, + iter_step=iter_step, + save_vis=save_vis, + ) + end = time.time() + print("val_step time: ", end - begin) + return result + elif mode == 'export_mesh': + import time + begin = time.time() + result = self.export_mesh_step(sample, + perturb_overwrite=perturb_overwrite, + background_rgb=background_rgb, + alpha_inter_ratio_lod0=alpha_inter_ratio_lod0, + alpha_inter_ratio_lod1=alpha_inter_ratio_lod1, + iter_step=iter_step, + save_vis=save_vis, + ) + end = time.time() + print("export mesh time: ", end - begin) + return result + def obtain_pyramid_feature_maps(self, imgs, lod=0): + """ + get feature maps of all conditional images + :param imgs: + :return: + """ + + if lod == 0: + extractor = self.pyramid_feature_network_geometry_lod0 + elif lod >= 1: + extractor = self.pyramid_feature_network_geometry_lod1 + + pyramid_feature_maps = extractor(imgs) + + # * the pyramid features are very important, if only use the coarst features, hard to optimize + fused_feature_maps = torch.cat([ + F.interpolate(pyramid_feature_maps[0], scale_factor=4, mode='bilinear', align_corners=True), + F.interpolate(pyramid_feature_maps[1], scale_factor=2, mode='bilinear', align_corners=True), + pyramid_feature_maps[2] + ], dim=1) + + return fused_feature_maps + + def cal_losses_sdf(self, render_out, sample_rays, iter_step=-1, lod=0): + + # loss weight schedule; the regularization terms should be added in later training stage + def get_weight(iter_step, weight): + if lod == 1: + anneal_start = self.anneal_end if lod == 0 else self.anneal_end_lod1 + anneal_end = self.anneal_end if lod == 0 else self.anneal_end_lod1 + anneal_end = anneal_end * 2 + else: + anneal_start = self.anneal_start if lod == 0 else self.anneal_start_lod1 + anneal_end = self.anneal_end if lod == 0 else self.anneal_end_lod1 + anneal_end = anneal_end * 2 + + if iter_step < 0: + return weight + + if anneal_end == 0.0: + return weight + elif iter_step < anneal_start: + return 0.0 + else: + return np.min( + [1.0, + (iter_step - anneal_start) / (anneal_end - anneal_start)]) * weight + + rays_o = sample_rays['rays_o'][0] + rays_d = sample_rays['rays_v'][0] + true_rgb = sample_rays['rays_color'][0] + + if 'rays_depth' in sample_rays.keys(): + true_depth = sample_rays['rays_depth'][0] + else: + true_depth = None + mask = sample_rays['rays_mask'][0] + + color_fine = render_out['color_fine'] + color_fine_mask = render_out['color_fine_mask'] + depth_pred = render_out['depth'] + + variance = render_out['variance'] + cdf_fine = render_out['cdf_fine'] + weight_sum = render_out['weights_sum'] + + gradient_error_fine = render_out['gradient_error_fine'] + + sdf = render_out['sdf'] + + # * color generated by mlp + color_mlp = render_out['color_mlp'] + color_mlp_mask = render_out['color_mlp_mask'] + + if color_fine is not None: + # Color loss + color_mask = color_fine_mask if color_fine_mask is not None else mask + # import ipdb; ipdb.set_trace() + color_mask = color_mask[..., 0] + color_error = (color_fine[color_mask] - true_rgb[color_mask]) + # print("Nan number", torch.isnan(color_error).sum()) + # print("Color error shape", color_error.shape) + # import ipdb; ipdb.set_trace() + color_fine_loss = F.l1_loss(color_error, torch.zeros_like(color_error).to(color_error.device), + reduction='mean') + # print(color_fine_loss) + psnr = 20.0 * torch.log10( + 1.0 / (((color_fine[color_mask] - true_rgb[color_mask]) ** 2).mean() / (3.0)).sqrt()) + else: + color_fine_loss = 0. + psnr = 0. + + if color_mlp is not None: + # Color loss + color_mlp_mask = color_mlp_mask[..., 0] + color_error_mlp = (color_mlp[color_mlp_mask] - true_rgb[color_mlp_mask]) + color_mlp_loss = F.l1_loss(color_error_mlp, + torch.zeros_like(color_error_mlp).to(color_error_mlp.device), + reduction='mean') + + psnr_mlp = 20.0 * torch.log10( + 1.0 / (((color_mlp[color_mlp_mask] - true_rgb[color_mlp_mask]) ** 2).mean() / (3.0)).sqrt()) + else: + color_mlp_loss = 0. + psnr_mlp = 0. + + # depth loss is only used for inference, not included in total loss + if true_depth is not None: + # depth_loss = self.depth_criterion(depth_pred, true_depth, mask) + depth_loss = self.depth_criterion(depth_pred, true_depth) + + # # depth evaluation + # depth_statis = compute_depth_errors(depth_pred.detach().cpu().numpy(), true_depth.cpu().numpy()) + # depth_statis = numpy2tensor(depth_statis, device=rays_o.device) + depth_statis = None + else: + depth_loss = 0. + depth_statis = None + + sparse_loss_1 = torch.exp( + -1 * torch.abs(render_out['sdf_random']) * self.sdf_decay_param).mean() # - should equal + sparse_loss_2 = torch.exp(-1 * torch.abs(sdf) * self.sdf_decay_param).mean() + sparse_loss = (sparse_loss_1 + sparse_loss_2) / 2 + + sdf_mean = torch.abs(sdf).mean() + sparseness_1 = (torch.abs(sdf) < 0.01).to(torch.float32).mean() + sparseness_2 = (torch.abs(sdf) < 0.02).to(torch.float32).mean() + + # Eikonal loss + gradient_error_loss = gradient_error_fine + + # ! the first 50k, don't use bg constraint + fg_bg_weight = 0.0 if iter_step < 50000 else get_weight(iter_step, self.fg_bg_weight) + + # Mask loss, optional + # The images of DTU dataset contain large black regions (0 rgb values), + # can use this data prior to make fg more clean + background_loss = 0.0 + fg_bg_loss = 0.0 + if self.fg_bg_weight > 0 and torch.mean((mask < 0.5).to(torch.float32)) > 0.02: + weights_sum_fg = render_out['weights_sum_fg'] + fg_bg_error = (weights_sum_fg - mask)[mask < 0.5] + fg_bg_loss = F.l1_loss(fg_bg_error, + torch.zeros_like(fg_bg_error).to(fg_bg_error.device), + reduction='mean') + + + + loss = 1.0 * depth_loss + color_fine_loss + color_mlp_loss + \ + sparse_loss * get_weight(iter_step, self.sdf_sparse_weight) + \ + fg_bg_loss * fg_bg_weight + \ + gradient_error_loss * self.sdf_igr_weight # ! gradient_error_loss need a mask + + losses = { + "loss": loss, + "depth_loss": depth_loss, + "color_fine_loss": color_fine_loss, + "color_mlp_loss": color_mlp_loss, + "gradient_error_loss": gradient_error_loss, + "background_loss": background_loss, + "sparse_loss": sparse_loss, + "sparseness_1": sparseness_1, + "sparseness_2": sparseness_2, + "sdf_mean": sdf_mean, + "psnr": psnr, + "psnr_mlp": psnr_mlp, + "weights_sum": render_out['weights_sum'], + "weights_sum_fg": render_out['weights_sum_fg'], + "alpha_sum": render_out['alpha_sum'], + "variance": render_out['variance'], + "sparse_weight": get_weight(iter_step, self.sdf_sparse_weight), + "fg_bg_weight": fg_bg_weight, + "fg_bg_loss": fg_bg_loss, # added by jha, bug of sparseNeuS + } + # print("[TEST]: weights_sum in trainner forward", losses['weights_sum'].mean()) + losses = numpy2tensor(losses, device=rays_o.device) + return loss, losses, depth_statis + + @torch.no_grad() + def validate_mesh(self, density_or_sdf_network, func_extract_geometry, world_space=True, resolution=360, + threshold=0.0, mode='val', + # * 3d feature volume + conditional_volume=None, lod=None, occupancy_mask=None, + bound_min=[-1, -1, -1], bound_max=[1, 1, 1], meta='', iter_step=0, scale_mat=None, + trans_mat=None + ): + + bound_min = torch.tensor(bound_min, dtype=torch.float32) + bound_max = torch.tensor(bound_max, dtype=torch.float32) + + vertices, triangles, fields = func_extract_geometry( + density_or_sdf_network, + bound_min, bound_max, resolution=resolution, + threshold=threshold, device=conditional_volume.device, + # * 3d feature volume + conditional_volume=conditional_volume, lod=lod, + occupancy_mask=occupancy_mask + ) + + + if scale_mat is not None: + scale_mat_np = scale_mat.cpu().numpy() + vertices = vertices * scale_mat_np[0][0, 0] + scale_mat_np[0][:3, 3][None] + + if trans_mat is not None: # w2c_ref_inv + trans_mat_np = trans_mat.cpu().numpy() + vertices_homo = np.concatenate([vertices, np.ones_like(vertices[:, :1])], axis=1) + vertices = np.matmul(trans_mat_np, vertices_homo[:, :, None])[:, :3, 0] + + mesh = trimesh.Trimesh(vertices, triangles) + os.makedirs(os.path.join(self.base_exp_dir, 'meshes_' + mode), exist_ok=True) + mesh.export(os.path.join(self.base_exp_dir, 'meshes_' + mode, + 'mesh_{:0>8d}_{}_lod{:0>1d}.ply'.format(iter_step, meta, lod))) + + + + def validate_colored_mesh(self, density_or_sdf_network, func_extract_geometry, world_space=True, resolution=360, + threshold=0.0, mode='val', + # * 3d feature volume + conditional_volume=None, + conditional_valid_mask_volume=None, + feature_maps=None, + color_maps = None, + w2cs=None, + target_candidate_w2cs=None, + intrinsics=None, + rendering_network=None, + rendering_projector=None, + query_c2w=None, + lod=None, occupancy_mask=None, + bound_min=[-1, -1, -1], bound_max=[1, 1, 1], meta='', iter_step=0, scale_mat=None, + trans_mat=None + ): + + bound_min = torch.tensor(bound_min, dtype=torch.float32) + bound_max = torch.tensor(bound_max, dtype=torch.float32) + # import time + # jha_begin4 = time.time() + vertices, triangles, fields = func_extract_geometry( + density_or_sdf_network, + bound_min, bound_max, resolution=resolution, + threshold=threshold, device=conditional_volume.device, + # * 3d feature volume + conditional_volume=conditional_volume, lod=lod, + occupancy_mask=occupancy_mask + ) + # jha_end4 = time.time() + # print("[TEST]: func_extract_geometry time", jha_end4 - jha_begin4) + + # import time + # jha_begin5 = time.time() + + + with torch.no_grad(): + ren_geo_feats, ren_rgb_feats, ren_ray_diff, ren_mask, _, _ = rendering_projector.compute_view_independent( + torch.tensor(vertices).to(conditional_volume), + lod=lod, # JHA EDITED + # * 3d geometry feature volumes + geometryVolume=conditional_volume[0], + geometryVolumeMask=conditional_valid_mask_volume[0], + sdf_network=density_or_sdf_network, + # * 2d rendering feature maps + rendering_feature_maps=feature_maps, # [n_view, 56, 256, 256] + color_maps=color_maps, + w2cs=w2cs, + target_candidate_w2cs=target_candidate_w2cs, + intrinsics=intrinsics, + img_wh=[256,256], + query_img_idx=0, # the index of the N_views dim for rendering + query_c2w=query_c2w, + ) + + + vertices_color, rendering_valid_mask = rendering_network( + ren_geo_feats, ren_rgb_feats, ren_ray_diff, ren_mask) + + # jha_end5 = time.time() + # print("[TEST]: rendering_network time", jha_end5 - jha_begin5) + + if scale_mat is not None: + scale_mat_np = scale_mat.cpu().numpy() + vertices = vertices * scale_mat_np[0][0, 0] + scale_mat_np[0][:3, 3][None] + + if trans_mat is not None: # w2c_ref_inv + trans_mat_np = trans_mat.cpu().numpy() + vertices_homo = np.concatenate([vertices, np.ones_like(vertices[:, :1])], axis=1) + vertices = np.matmul(trans_mat_np, vertices_homo[:, :, None])[:, :3, 0] + # import ipdb; ipdb.set_trace() + vertices_color = np.array(vertices_color.squeeze(0).cpu() * 255, dtype=np.uint8) + mesh = trimesh.Trimesh(vertices, triangles, vertex_colors=vertices_color) + os.makedirs(os.path.join(self.base_exp_dir, 'meshes_' + mode, 'lod{:0>1d}'.format(lod)), exist_ok=True) + # mesh.export(os.path.join(self.base_exp_dir, 'meshes_' + mode, 'lod{:0>1d}'.format(lod), + # 'mesh_{:0>8d}_{}_lod{:0>1d}.ply'.format(iter_step, meta, lod))) + # MODIFIED + mesh.export(os.path.join(self.base_exp_dir, 'meshes_' + mode, 'lod{:0>1d}'.format(lod), + 'mesh_{:0>8d}_gradio_lod{:0>1d}.ply'.format(iter_step, lod))) \ No newline at end of file diff --git a/SparseNeuS_demo_v1/models/trainer_generic_normals_new.py b/SparseNeuS_demo_v1/models/trainer_generic_normals_new.py new file mode 100644 index 0000000000000000000000000000000000000000..8a75f2c7fcaf613e1a4c5deeb9a8be15abd96d8d --- /dev/null +++ b/SparseNeuS_demo_v1/models/trainer_generic_normals_new.py @@ -0,0 +1,1313 @@ +""" +decouple the trainer with the renderer +""" +import os +import cv2 as cv +import torch +import torch.nn as nn +import torch.nn.functional as F + +import numpy as np +import logging +import mcubes +import trimesh +from icecream import ic + +from utils.misc_utils import visualize_depth_numpy + +from utils.training_utils import numpy2tensor +from loss.depth_metric import compute_depth_errors + +from loss.depth_loss import DepthLoss, DepthSmoothLoss + +from models.rays import gen_rays_between + +from models.sparse_neus_renderer_normals_new import SparseNeuSRenderer + +def safe_l2_normalize(x, dim=None, eps=1e-6): + return F.normalize(x, p=2, dim=dim, eps=eps) + + +class GenericTrainer(nn.Module): + def __init__(self, + rendering_network_outside, + pyramid_feature_network_lod0, + pyramid_feature_network_lod1, + sdf_network_lod0, + sdf_network_lod1, + variance_network_lod0, + variance_network_lod1, + rendering_network_lod0, + rendering_network_lod1, + n_samples_lod0, + n_importance_lod0, + n_samples_lod1, + n_importance_lod1, + n_outside, + perturb, + alpha_type='div', + conf=None, + timestamp="", + mode='train', + base_exp_dir=None, + ): + super(GenericTrainer, self).__init__() + + self.conf = conf + self.timestamp = timestamp + + + self.base_exp_dir = base_exp_dir + + + self.anneal_start = self.conf.get_float('train.anneal_start', default=0.0) + self.anneal_end = self.conf.get_float('train.anneal_end', default=0.0) + self.anneal_start_lod1 = self.conf.get_float('train.anneal_start_lod1', default=0.0) + self.anneal_end_lod1 = self.conf.get_float('train.anneal_end_lod1', default=0.0) + + # network setups + self.rendering_network_outside = rendering_network_outside + self.pyramid_feature_network_geometry_lod0 = pyramid_feature_network_lod0 # 2D pyramid feature network for geometry + self.pyramid_feature_network_geometry_lod1 = pyramid_feature_network_lod1 # use differnet networks for the two lods + + # when num_lods==2, may consume too much memeory + self.sdf_network_lod0 = sdf_network_lod0 + self.sdf_network_lod1 = sdf_network_lod1 + + # - warpped by ModuleList to support DataParallel + self.variance_network_lod0 = variance_network_lod0 + self.variance_network_lod1 = variance_network_lod1 + + self.rendering_network_lod0 = rendering_network_lod0 + self.rendering_network_lod1 = rendering_network_lod1 + + self.n_samples_lod0 = n_samples_lod0 + self.n_importance_lod0 = n_importance_lod0 + self.n_samples_lod1 = n_samples_lod1 + self.n_importance_lod1 = n_importance_lod1 + self.n_outside = n_outside + self.num_lods = conf.get_int('model.num_lods') # the number of octree lods + self.perturb = perturb + self.alpha_type = alpha_type + + # - the two renderers + self.sdf_renderer_lod0 = SparseNeuSRenderer( + self.rendering_network_outside, + self.sdf_network_lod0, + self.variance_network_lod0, + self.rendering_network_lod0, + self.n_samples_lod0, + self.n_importance_lod0, + self.n_outside, + self.perturb, + alpha_type='div', + conf=self.conf) + + self.sdf_renderer_lod1 = SparseNeuSRenderer( + self.rendering_network_outside, + self.sdf_network_lod1, + self.variance_network_lod1, + self.rendering_network_lod1, + self.n_samples_lod1, + self.n_importance_lod1, + self.n_outside, + self.perturb, + alpha_type='div', + conf=self.conf) + + self.if_fix_lod0_networks = self.conf.get_bool('train.if_fix_lod0_networks') + + # sdf network weights + self.sdf_igr_weight = self.conf.get_float('train.sdf_igr_weight') + self.sdf_sparse_weight = self.conf.get_float('train.sdf_sparse_weight', default=0) + self.sdf_decay_param = self.conf.get_float('train.sdf_decay_param', default=100) + self.fg_bg_weight = self.conf.get_float('train.fg_bg_weight', default=0.00) + self.bg_ratio = self.conf.get_float('train.bg_ratio', default=0.0) + + self.depth_criterion = DepthLoss() + + # - DataParallel mode, cannot modify attributes in forward() + # self.iter_step = 0 + self.val_mesh_freq = self.conf.get_int('train.val_mesh_freq') + + # - True for finetuning; False for general training + self.if_fitted_rendering = self.conf.get_bool('train.if_fitted_rendering', default=False) + + self.prune_depth_filter = self.conf.get_bool('model.prune_depth_filter', default=False) + + def get_trainable_params(self): + # set trainable params + + self.params_to_train = [] + + if not self.if_fix_lod0_networks: + # load pretrained featurenet + self.params_to_train += list(self.pyramid_feature_network_geometry_lod0.parameters()) + self.params_to_train += list(self.sdf_network_lod0.parameters()) + self.params_to_train += list(self.variance_network_lod0.parameters()) + + if self.rendering_network_lod0 is not None: + self.params_to_train += list(self.rendering_network_lod0.parameters()) + + if self.sdf_network_lod1 is not None: + # load pretrained featurenet + self.params_to_train += list(self.pyramid_feature_network_geometry_lod1.parameters()) + + self.params_to_train += list(self.sdf_network_lod1.parameters()) + self.params_to_train += list(self.variance_network_lod1.parameters()) + if self.rendering_network_lod1 is not None: + self.params_to_train += list(self.rendering_network_lod1.parameters()) + + return self.params_to_train + + def train_step(self, sample, + perturb_overwrite=-1, + background_rgb=None, + alpha_inter_ratio_lod0=0.0, + alpha_inter_ratio_lod1=0.0, + iter_step=0, + ): + # * only support batch_size==1 + # ! attention: the list of string cannot be splited in DataParallel + batch_idx = sample['batch_idx'][0] + meta = sample['meta'][batch_idx] # the scan lighting ref_view info + + sizeW = sample['img_wh'][0][0] + sizeH = sample['img_wh'][0][1] + partial_vol_origin = sample['partial_vol_origin'] # [B, 3] + near, far = sample['near_fars'][0, 0, :1], sample['near_fars'][0, 0, 1:] + + # the full-size ray variables + sample_rays = sample['rays'] + rays_o = sample_rays['rays_o'][0] + rays_d = sample_rays['rays_v'][0] + + imgs = sample['images'][0] + intrinsics = sample['intrinsics'][0] + intrinsics_l_4x = intrinsics.clone() + intrinsics_l_4x[:, :2] *= 0.25 + w2cs = sample['w2cs'][0] + c2ws = sample['c2ws'][0] + proj_matrices = sample['affine_mats'] + scale_mat = sample['scale_mat'] + trans_mat = sample['trans_mat'] + + # *********************** Lod==0 *********************** + if not self.if_fix_lod0_networks: + geometry_feature_maps = self.obtain_pyramid_feature_maps(imgs) + + conditional_features_lod0 = self.sdf_network_lod0.get_conditional_volume( + feature_maps=geometry_feature_maps[None, 1:, :, :, :], + partial_vol_origin=partial_vol_origin, + proj_mats=proj_matrices[:,1:], + # proj_mats=proj_matrices, + sizeH=sizeH, + sizeW=sizeW, + lod=0, + ) + + else: + with torch.no_grad(): + geometry_feature_maps = self.obtain_pyramid_feature_maps(imgs, lod=0) + # geometry_feature_maps = self.obtain_pyramid_feature_maps(imgs, lod=0) + conditional_features_lod0 = self.sdf_network_lod0.get_conditional_volume( + feature_maps=geometry_feature_maps[None, 1:, :, :, :], + partial_vol_origin=partial_vol_origin, + proj_mats=proj_matrices[:,1:], + # proj_mats=proj_matrices, + sizeH=sizeH, + sizeW=sizeW, + lod=0, + ) + # print("Checker2:, construct cost volume") + con_volume_lod0 = conditional_features_lod0['dense_volume_scale0'] + + con_valid_mask_volume_lod0 = conditional_features_lod0['valid_mask_volume_scale0'] + # import ipdb; ipdb.set_trace() + coords_lod0 = conditional_features_lod0['coords_scale0'] # [1,3,wX,wY,wZ] + + # * extract depth maps for all the images + depth_maps_lod0, depth_masks_lod0 = None, None + if self.num_lods > 1: + sdf_volume_lod0 = self.sdf_network_lod0.get_sdf_volume( + con_volume_lod0, con_valid_mask_volume_lod0, + coords_lod0, partial_vol_origin) # [1, 1, dX, dY, dZ] + + if self.prune_depth_filter: + depth_maps_lod0_l4x, depth_masks_lod0_l4x = self.sdf_renderer_lod0.extract_depth_maps( + self.sdf_network_lod0, sdf_volume_lod0, intrinsics_l_4x, c2ws, + sizeH // 4, sizeW // 4, near * 1.5, far) + depth_maps_lod0 = F.interpolate(depth_maps_lod0_l4x, size=(sizeH, sizeW), mode='bilinear', + align_corners=True) + depth_masks_lod0 = F.interpolate(depth_masks_lod0_l4x.float(), size=(sizeH, sizeW), mode='nearest') + + # *************** losses + loss_lod0, losses_lod0, depth_statis_lod0 = None, None, None + + if not self.if_fix_lod0_networks: + + render_out = self.sdf_renderer_lod0.render( + rays_o, rays_d, near, far, + self.sdf_network_lod0, + self.rendering_network_lod0, + background_rgb=background_rgb, + alpha_inter_ratio=alpha_inter_ratio_lod0, + # * related to conditional feature + lod=0, + conditional_volume=con_volume_lod0, + conditional_valid_mask_volume=con_valid_mask_volume_lod0, + # * 2d feature maps + feature_maps=geometry_feature_maps, + color_maps=imgs, + w2cs=w2cs, + intrinsics=intrinsics, + img_wh=[sizeW, sizeH], + if_general_rendering=True, + if_render_with_grad=True, + ) + + loss_lod0, losses_lod0, depth_statis_lod0 = self.cal_losses_sdf(render_out, sample_rays, + iter_step, lod=0) + + # *********************** Lod==1 *********************** + + loss_lod1, losses_lod1, depth_statis_lod1 = None, None, None + + if self.num_lods > 1: + geometry_feature_maps_lod1 = self.obtain_pyramid_feature_maps(imgs, lod=1) + # geometry_feature_maps_lod1 = self.obtain_pyramid_feature_maps(imgs, lod=1) + if self.prune_depth_filter: + pre_coords, pre_feats = self.sdf_renderer_lod0.get_valid_sparse_coords_by_sdf_depthfilter( + sdf_volume_lod0[0], coords_lod0[0], con_valid_mask_volume_lod0[0], con_volume_lod0[0], + depth_maps_lod0, proj_matrices[0], + partial_vol_origin, self.sdf_network_lod0.voxel_size, + near, far, self.sdf_network_lod0.voxel_size, 12) + else: + pre_coords, pre_feats = self.sdf_renderer_lod0.get_valid_sparse_coords_by_sdf( + sdf_volume_lod0[0], coords_lod0[0], con_valid_mask_volume_lod0[0], con_volume_lod0[0]) + + pre_coords[:, 1:] = pre_coords[:, 1:] * 2 + + # ? It seems that training gru_fusion, this part should be trainable too + conditional_features_lod1 = self.sdf_network_lod1.get_conditional_volume( + feature_maps=geometry_feature_maps_lod1[None, 1:, :, :, :], + partial_vol_origin=partial_vol_origin, + proj_mats=proj_matrices[:,1:], + # proj_mats=proj_matrices, + sizeH=sizeH, + sizeW=sizeW, + pre_coords=pre_coords, + pre_feats=pre_feats, + ) + + con_volume_lod1 = conditional_features_lod1['dense_volume_scale1'] + con_valid_mask_volume_lod1 = conditional_features_lod1['valid_mask_volume_scale1'] + + # if not self.if_gru_fusion_lod1: + render_out_lod1 = self.sdf_renderer_lod1.render( + rays_o, rays_d, near, far, + self.sdf_network_lod1, + self.rendering_network_lod1, + background_rgb=background_rgb, + alpha_inter_ratio=alpha_inter_ratio_lod1, + # * related to conditional feature + lod=1, + conditional_volume=con_volume_lod1, + conditional_valid_mask_volume=con_valid_mask_volume_lod1, + # * 2d feature maps + feature_maps=geometry_feature_maps_lod1, + color_maps=imgs, + w2cs=w2cs, + intrinsics=intrinsics, + img_wh=[sizeW, sizeH], + bg_ratio=self.bg_ratio, + ) + loss_lod1, losses_lod1, depth_statis_lod1 = self.cal_losses_sdf(render_out_lod1, sample_rays, + iter_step, lod=1) + + # print("Checker3:, compute losses") + # # - extract mesh + if iter_step % self.val_mesh_freq == 0: + torch.cuda.empty_cache() + self.validate_mesh(self.sdf_network_lod0, + self.sdf_renderer_lod0.extract_geometry, + conditional_volume=con_volume_lod0, lod=0, + threshold=0, + # occupancy_mask=con_valid_mask_volume_lod0[0, 0], + mode='train_bg', meta=meta, + iter_step=iter_step, scale_mat=scale_mat, + trans_mat=trans_mat) + torch.cuda.empty_cache() + + if self.num_lods > 1: + self.validate_mesh(self.sdf_network_lod1, + self.sdf_renderer_lod1.extract_geometry, + conditional_volume=con_volume_lod1, lod=1, + # occupancy_mask=con_valid_mask_volume_lod1[0, 0].detach(), + mode='train_bg', meta=meta, + iter_step=iter_step, scale_mat=scale_mat, + trans_mat=trans_mat) + # import ipdb; ipdb.set_trace() + # print("Checker3.1:, after val mesh") + losses = { + # - lod 0 + 'loss_lod0': loss_lod0, + 'losses_lod0': losses_lod0, + 'depth_statis_lod0': depth_statis_lod0, + + # - lod 1 + 'loss_lod1': loss_lod1, + 'losses_lod1': losses_lod1, + 'depth_statis_lod1': depth_statis_lod1, + + } + + return losses + + def val_step(self, sample, + perturb_overwrite=-1, + background_rgb=None, + alpha_inter_ratio_lod0=0.0, + alpha_inter_ratio_lod1=0.0, + iter_step=0, + chunk_size=512, + save_vis=False, + ): + # * only support batch_size==1 + # ! attention: the list of string cannot be splited in DataParallel + batch_idx = sample['batch_idx'][0] + meta = sample['meta'][batch_idx] # the scan lighting ref_view info + + sizeW = sample['img_wh'][0][0] + sizeH = sample['img_wh'][0][1] + H, W = sizeH, sizeW + + partial_vol_origin = sample['partial_vol_origin'] # [B, 3] + near, far = sample['query_near_far'][0, :1], sample['query_near_far'][0, 1:] + + # the ray variables + sample_rays = sample['rays'] + rays_o = sample_rays['rays_o'][0] + rays_d = sample_rays['rays_v'][0] + rays_ndc_uv = sample_rays['rays_ndc_uv'][0] + + imgs = sample['images'][0] + intrinsics = sample['intrinsics'][0] + intrinsics_l_4x = intrinsics.clone() + intrinsics_l_4x[:, :2] *= 0.25 + w2cs = sample['w2cs'][0] + c2ws = sample['c2ws'][0] + proj_matrices = sample['affine_mats'] + + # render_img_idx = sample['render_img_idx'][0] + # true_img = sample['images'][0][render_img_idx] + + # - the image to render + scale_mat = sample['scale_mat'] # [1,4,4] used to convert mesh into true scale + trans_mat = sample['trans_mat'] + query_c2w = sample['query_c2w'] # [1,4,4] + query_w2c = sample['query_w2c'] # [1,4,4] + true_img = sample['query_image'][0] + true_img = np.uint8(true_img.permute(1, 2, 0).cpu().numpy() * 255) + + depth_min, depth_max = near.cpu().numpy(), far.cpu().numpy() + + scale_factor = sample['scale_factor'][0].cpu().numpy() + true_depth = sample['query_depth'] if 'query_depth' in sample.keys() else None + if true_depth is not None: + true_depth = true_depth[0].cpu().numpy() + true_depth_colored = visualize_depth_numpy(true_depth, [depth_min, depth_max])[0] + else: + true_depth_colored = None + + rays_o = rays_o.reshape(-1, 3).split(chunk_size) + rays_d = rays_d.reshape(-1, 3).split(chunk_size) + + # - obtain conditional features + with torch.no_grad(): + # - obtain conditional features + geometry_feature_maps = self.obtain_pyramid_feature_maps(imgs, lod=0) + # import ipdb; ipdb.set_trace() + # - lod 0 + conditional_features_lod0 = self.sdf_network_lod0.get_conditional_volume( + feature_maps=geometry_feature_maps[None, :, :, :, :], + partial_vol_origin=partial_vol_origin, + proj_mats=proj_matrices, + sizeH=sizeH, + sizeW=sizeW, + lod=0, + ) + + con_volume_lod0 = conditional_features_lod0['dense_volume_scale0'] + con_valid_mask_volume_lod0 = conditional_features_lod0['valid_mask_volume_scale0'] + coords_lod0 = conditional_features_lod0['coords_scale0'] # [1,3,wX,wY,wZ] + + if self.num_lods > 1: + sdf_volume_lod0 = self.sdf_network_lod0.get_sdf_volume( + con_volume_lod0, con_valid_mask_volume_lod0, + coords_lod0, partial_vol_origin) # [1, 1, dX, dY, dZ] + + depth_maps_lod0, depth_masks_lod0 = None, None + if self.prune_depth_filter: + depth_maps_lod0_l4x, depth_masks_lod0_l4x = self.sdf_renderer_lod0.extract_depth_maps( + self.sdf_network_lod0, sdf_volume_lod0, + intrinsics_l_4x, c2ws, + sizeH // 4, sizeW // 4, near * 1.5, far) # - near*1.5 is a experienced number + depth_maps_lod0 = F.interpolate(depth_maps_lod0_l4x, size=(sizeH, sizeW), mode='bilinear', + align_corners=True) + depth_masks_lod0 = F.interpolate(depth_masks_lod0_l4x.float(), size=(sizeH, sizeW), mode='nearest') + + #### visualize the depth_maps_lod0 for checking + colored_depth_maps_lod0 = [] + for i in range(depth_maps_lod0.shape[0]): + colored_depth_maps_lod0.append( + visualize_depth_numpy(depth_maps_lod0[i, 0].cpu().numpy(), [depth_min, depth_max])[0]) + + colored_depth_maps_lod0 = np.concatenate(colored_depth_maps_lod0, axis=0).astype(np.uint8) + os.makedirs(os.path.join(self.base_exp_dir, 'depth_maps_lod0'), exist_ok=True) + cv.imwrite(os.path.join(self.base_exp_dir, 'depth_maps_lod0', + '{:0>8d}_{}.png'.format(iter_step, meta)), + colored_depth_maps_lod0[:, :, ::-1]) + + if self.num_lods > 1: + geometry_feature_maps_lod1 = self.obtain_pyramid_feature_maps(imgs, lod=1) + + if self.prune_depth_filter: + pre_coords, pre_feats = self.sdf_renderer_lod0.get_valid_sparse_coords_by_sdf_depthfilter( + sdf_volume_lod0[0], coords_lod0[0], con_valid_mask_volume_lod0[0], con_volume_lod0[0], + depth_maps_lod0, proj_matrices[0], + partial_vol_origin, self.sdf_network_lod0.voxel_size, + near, far, self.sdf_network_lod0.voxel_size, 12) + else: + pre_coords, pre_feats = self.sdf_renderer_lod0.get_valid_sparse_coords_by_sdf( + sdf_volume_lod0[0], coords_lod0[0], con_valid_mask_volume_lod0[0], con_volume_lod0[0]) + + pre_coords[:, 1:] = pre_coords[:, 1:] * 2 + + with torch.no_grad(): + conditional_features_lod1 = self.sdf_network_lod1.get_conditional_volume( + feature_maps=geometry_feature_maps_lod1[None, :, :, :, :], + partial_vol_origin=partial_vol_origin, + proj_mats=proj_matrices, + sizeH=sizeH, + sizeW=sizeW, + pre_coords=pre_coords, + pre_feats=pre_feats, + ) + + con_volume_lod1 = conditional_features_lod1['dense_volume_scale1'] + con_valid_mask_volume_lod1 = conditional_features_lod1['valid_mask_volume_scale1'] + + out_rgb_fine = [] + out_normal_fine = [] + out_depth_fine = [] + + out_rgb_fine_lod1 = [] + out_normal_fine_lod1 = [] + out_depth_fine_lod1 = [] + + # out_depth_fine_explicit = [] + if save_vis: + for rays_o_batch, rays_d_batch in zip(rays_o, rays_d): + + # ****** lod 0 **** + render_out = self.sdf_renderer_lod0.render( + rays_o_batch, rays_d_batch, near, far, + self.sdf_network_lod0, + self.rendering_network_lod0, + background_rgb=background_rgb, + alpha_inter_ratio=alpha_inter_ratio_lod0, + # * related to conditional feature + lod=0, + conditional_volume=con_volume_lod0, + conditional_valid_mask_volume=con_valid_mask_volume_lod0, + # * 2d feature maps + feature_maps=geometry_feature_maps, + color_maps=imgs, + w2cs=w2cs, + intrinsics=intrinsics, + img_wh=[sizeW, sizeH], + query_c2w=query_c2w, + if_render_with_grad=False, + ) + + feasible = lambda key: ((key in render_out) and (render_out[key] is not None)) + + if feasible('depth'): + out_depth_fine.append(render_out['depth'].detach().cpu().numpy()) + + # if render_out['color_coarse'] is not None: + if feasible('color_fine'): + out_rgb_fine.append(render_out['color_fine'].detach().cpu().numpy()) + if feasible('gradients') and feasible('weights'): + if render_out['inside_sphere'] is not None: + out_normal_fine.append((render_out['gradients'] * render_out['weights'][:, + :self.n_samples_lod0 + self.n_importance_lod0, + None] * render_out['inside_sphere'][ + ..., None]).sum(dim=1).detach().cpu().numpy()) + else: + out_normal_fine.append((render_out['gradients'] * render_out['weights'][:, + :self.n_samples_lod0 + self.n_importance_lod0, + None]).sum(dim=1).detach().cpu().numpy()) + del render_out + + # ****************** lod 1 ************************** + if self.num_lods > 1: + for rays_o_batch, rays_d_batch in zip(rays_o, rays_d): + render_out_lod1 = self.sdf_renderer_lod1.render( + rays_o_batch, rays_d_batch, near, far, + self.sdf_network_lod1, + self.rendering_network_lod1, + background_rgb=background_rgb, + alpha_inter_ratio=alpha_inter_ratio_lod1, + # * related to conditional feature + lod=1, + conditional_volume=con_volume_lod1, + conditional_valid_mask_volume=con_valid_mask_volume_lod1, + # * 2d feature maps + feature_maps=geometry_feature_maps_lod1, + color_maps=imgs, + w2cs=w2cs, + intrinsics=intrinsics, + img_wh=[sizeW, sizeH], + query_c2w=query_c2w, + if_render_with_grad=False, + ) + + feasible = lambda key: ((key in render_out_lod1) and (render_out_lod1[key] is not None)) + + if feasible('depth'): + out_depth_fine_lod1.append(render_out_lod1['depth'].detach().cpu().numpy()) + + # if render_out['color_coarse'] is not None: + if feasible('color_fine'): + out_rgb_fine_lod1.append(render_out_lod1['color_fine'].detach().cpu().numpy()) + if feasible('gradients') and feasible('weights'): + if render_out_lod1['inside_sphere'] is not None: + out_normal_fine_lod1.append((render_out_lod1['gradients'] * render_out_lod1['weights'][:, + :self.n_samples_lod1 + self.n_importance_lod1, + None] * + render_out_lod1['inside_sphere'][ + ..., None]).sum(dim=1).detach().cpu().numpy()) + else: + out_normal_fine_lod1.append((render_out_lod1['gradients'] * render_out_lod1['weights'][:, + :self.n_samples_lod1 + self.n_importance_lod1, + None]).sum( + dim=1).detach().cpu().numpy()) + del render_out_lod1 + + # - save visualization of lod 0 + + self.save_visualization(true_img, true_depth_colored, out_depth_fine, out_normal_fine, + query_w2c[0], out_rgb_fine, H, W, + depth_min, depth_max, iter_step, meta, "val_lod0", true_depth=true_depth, scale_factor=scale_factor) + + if self.num_lods > 1: + self.save_visualization(true_img, true_depth_colored, out_depth_fine_lod1, out_normal_fine_lod1, + query_w2c[0], out_rgb_fine_lod1, H, W, + depth_min, depth_max, iter_step, meta, "val_lod1", true_depth=true_depth, scale_factor=scale_factor) + + # - extract mesh + if (iter_step % self.val_mesh_freq == 0): + torch.cuda.empty_cache() + self.validate_mesh(self.sdf_network_lod0, + self.sdf_renderer_lod0.extract_geometry, + conditional_volume=con_volume_lod0, lod=0, + threshold=0, + # occupancy_mask=con_valid_mask_volume_lod0[0, 0], + mode='val_bg', meta=meta, + iter_step=iter_step, scale_mat=scale_mat, trans_mat=trans_mat) + torch.cuda.empty_cache() + + if self.num_lods > 1: + self.validate_mesh(self.sdf_network_lod1, + self.sdf_renderer_lod1.extract_geometry, + conditional_volume=con_volume_lod1, lod=1, + # occupancy_mask=con_valid_mask_volume_lod1[0, 0].detach(), + mode='val_bg', meta=meta, + iter_step=iter_step, scale_mat=scale_mat, trans_mat=trans_mat) + + torch.cuda.empty_cache() + + + + def export_mesh_step(self, sample, + perturb_overwrite=-1, + background_rgb=None, + alpha_inter_ratio_lod0=0.0, + alpha_inter_ratio_lod1=0.0, + iter_step=0, + chunk_size=512, + save_vis=False, + ): + # * only support batch_size==1 + # ! attention: the list of string cannot be splited in DataParallel + batch_idx = sample['batch_idx'][0] + meta = sample['meta'][batch_idx] # the scan lighting ref_view info + + sizeW = sample['img_wh'][0][0] + sizeH = sample['img_wh'][0][1] + H, W = sizeH, sizeW + + partial_vol_origin = sample['partial_vol_origin'] # [B, 3] + near, far = sample['query_near_far'][0, :1], sample['query_near_far'][0, 1:] + + # the ray variables + sample_rays = sample['rays'] + rays_o = sample_rays['rays_o'][0] + rays_d = sample_rays['rays_v'][0] + rays_ndc_uv = sample_rays['rays_ndc_uv'][0] + + imgs = sample['images'][0] + intrinsics = sample['intrinsics'][0] + intrinsics_l_4x = intrinsics.clone() + intrinsics_l_4x[:, :2] *= 0.25 + w2cs = sample['w2cs'][0] + c2ws = sample['c2ws'][0] + # target_candidate_w2cs = sample['target_candidate_w2cs'][0] + proj_matrices = sample['affine_mats'] + + + # - the image to render + scale_mat = sample['scale_mat'] # [1,4,4] used to convert mesh into true scale + trans_mat = sample['trans_mat'] + query_c2w = sample['query_c2w'] # [1,4,4] + query_w2c = sample['query_w2c'] # [1,4,4] + true_img = sample['query_image'][0] + true_img = np.uint8(true_img.permute(1, 2, 0).cpu().numpy() * 255) + + depth_min, depth_max = near.cpu().numpy(), far.cpu().numpy() + + scale_factor = sample['scale_factor'][0].cpu().numpy() + true_depth = sample['query_depth'] if 'query_depth' in sample.keys() else None + if true_depth is not None: + true_depth = true_depth[0].cpu().numpy() + true_depth_colored = visualize_depth_numpy(true_depth, [depth_min, depth_max])[0] + else: + true_depth_colored = None + + rays_o = rays_o.reshape(-1, 3).split(chunk_size) + rays_d = rays_d.reshape(-1, 3).split(chunk_size) + + # - obtain conditional features + with torch.no_grad(): + # - obtain conditional features + geometry_feature_maps = self.obtain_pyramid_feature_maps(imgs, lod=0) + # - lod 0 + conditional_features_lod0 = self.sdf_network_lod0.get_conditional_volume( + feature_maps=geometry_feature_maps[None, :, :, :, :], + partial_vol_origin=partial_vol_origin, + proj_mats=proj_matrices, + sizeH=sizeH, + sizeW=sizeW, + lod=0, + ) + + con_volume_lod0 = conditional_features_lod0['dense_volume_scale0'] + con_valid_mask_volume_lod0 = conditional_features_lod0['valid_mask_volume_scale0'] + coords_lod0 = conditional_features_lod0['coords_scale0'] # [1,3,wX,wY,wZ] + + if self.num_lods > 1: + sdf_volume_lod0 = self.sdf_network_lod0.get_sdf_volume( + con_volume_lod0, con_valid_mask_volume_lod0, + coords_lod0, partial_vol_origin) # [1, 1, dX, dY, dZ] + + depth_maps_lod0, depth_masks_lod0 = None, None + + + if self.num_lods > 1: + geometry_feature_maps_lod1 = self.obtain_pyramid_feature_maps(imgs, lod=1) + + if self.prune_depth_filter: + pre_coords, pre_feats = self.sdf_renderer_lod0.get_valid_sparse_coords_by_sdf_depthfilter( + sdf_volume_lod0[0], coords_lod0[0], con_valid_mask_volume_lod0[0], con_volume_lod0[0], + depth_maps_lod0, proj_matrices[0], + partial_vol_origin, self.sdf_network_lod0.voxel_size, + near, far, self.sdf_network_lod0.voxel_size, 12) + else: + pre_coords, pre_feats = self.sdf_renderer_lod0.get_valid_sparse_coords_by_sdf( + sdf_volume_lod0[0], coords_lod0[0], con_valid_mask_volume_lod0[0], con_volume_lod0[0]) + + pre_coords[:, 1:] = pre_coords[:, 1:] * 2 + + with torch.no_grad(): + conditional_features_lod1 = self.sdf_network_lod1.get_conditional_volume( + feature_maps=geometry_feature_maps_lod1[None, :, :, :, :], + partial_vol_origin=partial_vol_origin, + proj_mats=proj_matrices, + sizeH=sizeH, + sizeW=sizeW, + pre_coords=pre_coords, + pre_feats=pre_feats, + ) + + con_volume_lod1 = conditional_features_lod1['dense_volume_scale1'] + con_valid_mask_volume_lod1 = conditional_features_lod1['valid_mask_volume_scale1'] + + out_rgb_fine = [] + out_normal_fine = [] + out_depth_fine = [] + + out_rgb_fine_lod1 = [] + out_normal_fine_lod1 = [] + out_depth_fine_lod1 = [] + + # # out_depth_fine_explicit = [] + # if save_vis: + # for rays_o_batch, rays_d_batch in zip(rays_o, rays_d): + + # # ****** lod 0 **** + # render_out = self.sdf_renderer_lod0.render( + # rays_o_batch, rays_d_batch, near, far, + # self.sdf_network_lod0, + # self.rendering_network_lod0, + # background_rgb=background_rgb, + # alpha_inter_ratio=alpha_inter_ratio_lod0, + # # * related to conditional feature + # lod=0, + # conditional_volume=con_volume_lod0, + # conditional_valid_mask_volume=con_valid_mask_volume_lod0, + # # * 2d feature maps + # feature_maps=geometry_feature_maps, + # color_maps=imgs, + # w2cs=w2cs, + # intrinsics=intrinsics, + # img_wh=[sizeW, sizeH], + # query_c2w=query_c2w, + # if_render_with_grad=False, + # ) + + # feasible = lambda key: ((key in render_out) and (render_out[key] is not None)) + + # if feasible('depth'): + # out_depth_fine.append(render_out['depth'].detach().cpu().numpy()) + + # # if render_out['color_coarse'] is not None: + # if feasible('color_fine'): + # out_rgb_fine.append(render_out['color_fine'].detach().cpu().numpy()) + # if feasible('gradients') and feasible('weights'): + # if render_out['inside_sphere'] is not None: + # out_normal_fine.append((render_out['gradients'] * render_out['weights'][:, + # :self.n_samples_lod0 + self.n_importance_lod0, + # None] * render_out['inside_sphere'][ + # ..., None]).sum(dim=1).detach().cpu().numpy()) + # else: + # out_normal_fine.append((render_out['gradients'] * render_out['weights'][:, + # :self.n_samples_lod0 + self.n_importance_lod0, + # None]).sum(dim=1).detach().cpu().numpy()) + # del render_out + + # # ****************** lod 1 ************************** + # if self.num_lods > 1: + # for rays_o_batch, rays_d_batch in zip(rays_o, rays_d): + # render_out_lod1 = self.sdf_renderer_lod1.render( + # rays_o_batch, rays_d_batch, near, far, + # self.sdf_network_lod1, + # self.rendering_network_lod1, + # background_rgb=background_rgb, + # alpha_inter_ratio=alpha_inter_ratio_lod1, + # # * related to conditional feature + # lod=1, + # conditional_volume=con_volume_lod1, + # conditional_valid_mask_volume=con_valid_mask_volume_lod1, + # # * 2d feature maps + # feature_maps=geometry_feature_maps_lod1, + # color_maps=imgs, + # w2cs=w2cs, + # intrinsics=intrinsics, + # img_wh=[sizeW, sizeH], + # query_c2w=query_c2w, + # if_render_with_grad=False, + # ) + + # feasible = lambda key: ((key in render_out_lod1) and (render_out_lod1[key] is not None)) + + # if feasible('depth'): + # out_depth_fine_lod1.append(render_out_lod1['depth'].detach().cpu().numpy()) + + # # if render_out['color_coarse'] is not None: + # if feasible('color_fine'): + # out_rgb_fine_lod1.append(render_out_lod1['color_fine'].detach().cpu().numpy()) + # if feasible('gradients') and feasible('weights'): + # if render_out_lod1['inside_sphere'] is not None: + # out_normal_fine_lod1.append((render_out_lod1['gradients'] * render_out_lod1['weights'][:, + # :self.n_samples_lod1 + self.n_importance_lod1, + # None] * + # render_out_lod1['inside_sphere'][ + # ..., None]).sum(dim=1).detach().cpu().numpy()) + # else: + # out_normal_fine_lod1.append((render_out_lod1['gradients'] * render_out_lod1['weights'][:, + # :self.n_samples_lod1 + self.n_importance_lod1, + # None]).sum( + # dim=1).detach().cpu().numpy()) + # del render_out_lod1 + + # # - save visualization of lod 0 + + # self.save_visualization(true_img, true_depth_colored, out_depth_fine, out_normal_fine, + # query_w2c[0], out_rgb_fine, H, W, + # depth_min, depth_max, iter_step, meta, "val_lod0", true_depth=true_depth, scale_factor=scale_factor) + + # if self.num_lods > 1: + # self.save_visualization(true_img, true_depth_colored, out_depth_fine_lod1, out_normal_fine_lod1, + # query_w2c[0], out_rgb_fine_lod1, H, W, + # depth_min, depth_max, iter_step, meta, "val_lod1", true_depth=true_depth, scale_factor=scale_factor) + + # - extract mesh + if (iter_step % self.val_mesh_freq == 0): + torch.cuda.empty_cache() + self.validate_colored_mesh( + density_or_sdf_network=self.sdf_network_lod0, + func_extract_geometry=self.sdf_renderer_lod0.extract_geometry, + conditional_volume=con_volume_lod0, + conditional_valid_mask_volume = con_valid_mask_volume_lod0, + feature_maps=geometry_feature_maps, + color_maps=imgs, + w2cs=w2cs, + target_candidate_w2cs=None, + intrinsics=intrinsics, + rendering_network=self.rendering_network_lod0, + rendering_projector=self.sdf_renderer_lod0.rendering_projector, + lod=0, + threshold=0, + query_c2w=query_c2w, + mode='val_bg', meta=meta, + iter_step=iter_step, scale_mat=scale_mat, trans_mat=trans_mat + ) + torch.cuda.empty_cache() + + if self.num_lods > 1: + self.validate_colored_mesh( + density_or_sdf_network=self.sdf_network_lod1, + func_extract_geometry=self.sdf_renderer_lod1.extract_geometry, + conditional_volume=con_volume_lod1, + conditional_valid_mask_volume = con_valid_mask_volume_lod1, + feature_maps=geometry_feature_maps, + color_maps=imgs, + w2cs=w2cs, + target_candidate_w2cs=None, + intrinsics=intrinsics, + rendering_network=self.rendering_network_lod1, + rendering_projector=self.sdf_renderer_lod1.rendering_projector, + lod=1, + threshold=0, + query_c2w=query_c2w, + mode='val_bg', meta=meta, + iter_step=iter_step, scale_mat=scale_mat, trans_mat=trans_mat + ) + torch.cuda.empty_cache() + # self.validate_mesh(self.sdf_network_lod1, + # self.sdf_renderer_lod1.extract_geometry, + # conditional_volume=con_volume_lod1, lod=1, + # # occupancy_mask=con_valid_mask_volume_lod1[0, 0].detach(), + # mode='val_bg', meta=meta, + # iter_step=iter_step, scale_mat=scale_mat, trans_mat=trans_mat) + + # torch.cuda.empty_cache() + + + def save_visualization(self, true_img, true_colored_depth, out_depth, out_normal, w2cs, out_color, H, W, + depth_min, depth_max, iter_step, meta, comment, out_color_mlp=[], true_depth=None, scale_factor=1.0): + if len(out_color) > 0: + img_fine = (np.concatenate(out_color, axis=0).reshape([H, W, 3]) * 256).clip(0, 255) + + if len(out_color_mlp) > 0: + img_mlp = (np.concatenate(out_color_mlp, axis=0).reshape([H, W, 3]) * 256).clip(0, 255) + + if len(out_normal) > 0: + normal_img = np.concatenate(out_normal, axis=0) + rot = w2cs[:3, :3].detach().cpu().numpy() + # - convert normal from world space to camera space + normal_img = (np.matmul(rot[None, :, :], + normal_img[:, :, None]).reshape([H, W, 3]) * 128 + 128).clip(0, 255) + if len(out_depth) > 0: + pred_depth = np.concatenate(out_depth, axis=0).reshape([H, W]) + pred_depth_colored = visualize_depth_numpy(pred_depth, [depth_min, depth_max])[0] + + if len(out_depth) > 0: + os.makedirs(os.path.join(self.base_exp_dir, 'depths_' + comment), exist_ok=True) + if true_colored_depth is not None: + + if true_depth is not None: + depth_error_map = np.abs(true_depth - pred_depth) * 2.0 / scale_factor + # [256, 256, 1] -> [256, 256, 3] + depth_error_map = np.tile(depth_error_map[:, :, None], [1, 1, 3]) + print("meta: ", meta) + print("scale_factor: ", scale_factor) + print("depth_error_mean: ", depth_error_map.mean()) + # import ipdb; ipdb.set_trace() + depth_visualized = np.concatenate( + [(depth_error_map * 255).astype(np.uint8), true_colored_depth, pred_depth_colored, true_img], axis=1)[:, :, ::-1] + # print("depth_visualized.shape: ", depth_visualized.shape) + # write depth error result text on img, the input is a numpy array of [256, 1024, 3] + # cv.putText(depth_visualized.copy(), "depth_error_mean: {:.4f}".format(depth_error_map.mean()), (10, 30), cv.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2) + else: + depth_visualized = np.concatenate( + [true_colored_depth, pred_depth_colored, true_img])[:, :, ::-1] + cv.imwrite( + os.path.join(self.base_exp_dir, 'depths_' + comment, + '{:0>8d}_{}.png'.format(iter_step, meta)), depth_visualized + ) + else: + cv.imwrite( + os.path.join(self.base_exp_dir, 'depths_' + comment, + '{:0>8d}_{}.png'.format(iter_step, meta)), + np.concatenate( + [pred_depth_colored, true_img])[:, :, ::-1]) + if len(out_color) > 0: + os.makedirs(os.path.join(self.base_exp_dir, 'synthesized_color_' + comment), exist_ok=True) + cv.imwrite(os.path.join(self.base_exp_dir, 'synthesized_color_' + comment, + '{:0>8d}_{}.png'.format(iter_step, meta)), + np.concatenate( + [img_fine, true_img])[:, :, ::-1]) # bgr2rgb + # compute psnr (image pixel lie in [0, 255]) + mse_loss = np.mean((img_fine - true_img) ** 2) + psnr = 10 * np.log10(255 ** 2 / mse_loss) + + print("PSNR: ", psnr) + + if len(out_color_mlp) > 0: + os.makedirs(os.path.join(self.base_exp_dir, 'synthesized_color_mlp_' + comment), exist_ok=True) + cv.imwrite(os.path.join(self.base_exp_dir, 'synthesized_color_mlp_' + comment, + '{:0>8d}_{}.png'.format(iter_step, meta)), + np.concatenate( + [img_mlp, true_img])[:, :, ::-1]) # bgr2rgb + + if len(out_normal) > 0: + os.makedirs(os.path.join(self.base_exp_dir, 'normals_' + comment), exist_ok=True) + cv.imwrite(os.path.join(self.base_exp_dir, 'normals_' + comment, + '{:0>8d}_{}.png'.format(iter_step, meta)), + normal_img[:, :, ::-1]) + + def forward(self, sample, + perturb_overwrite=-1, + background_rgb=None, + alpha_inter_ratio_lod0=0.0, + alpha_inter_ratio_lod1=0.0, + iter_step=0, + mode='train', + save_vis=False, + ): + + if mode == 'train': + return self.train_step(sample, + perturb_overwrite=perturb_overwrite, + background_rgb=background_rgb, + alpha_inter_ratio_lod0=alpha_inter_ratio_lod0, + alpha_inter_ratio_lod1=alpha_inter_ratio_lod1, + iter_step=iter_step + ) + elif mode == 'val': + import time + begin = time.time() + result = self.val_step(sample, + perturb_overwrite=perturb_overwrite, + background_rgb=background_rgb, + alpha_inter_ratio_lod0=alpha_inter_ratio_lod0, + alpha_inter_ratio_lod1=alpha_inter_ratio_lod1, + iter_step=iter_step, + save_vis=save_vis, + ) + end = time.time() + print("val_step time: ", end - begin) + return result + elif mode == 'export_mesh': + import time + begin = time.time() + result = self.export_mesh_step(sample, + perturb_overwrite=perturb_overwrite, + background_rgb=background_rgb, + alpha_inter_ratio_lod0=alpha_inter_ratio_lod0, + alpha_inter_ratio_lod1=alpha_inter_ratio_lod1, + iter_step=iter_step, + save_vis=save_vis, + ) + end = time.time() + print("export mesh time: ", end - begin) + return result + def obtain_pyramid_feature_maps(self, imgs, lod=0): + """ + get feature maps of all conditional images + :param imgs: + :return: + """ + + if lod == 0: + extractor = self.pyramid_feature_network_geometry_lod0 + elif lod >= 1: + extractor = self.pyramid_feature_network_geometry_lod1 + + pyramid_feature_maps = extractor(imgs) + + # * the pyramid features are very important, if only use the coarst features, hard to optimize + fused_feature_maps = torch.cat([ + F.interpolate(pyramid_feature_maps[0], scale_factor=4, mode='bilinear', align_corners=True), + F.interpolate(pyramid_feature_maps[1], scale_factor=2, mode='bilinear', align_corners=True), + pyramid_feature_maps[2] + ], dim=1) + + return fused_feature_maps + + def cal_losses_sdf(self, render_out, sample_rays, iter_step=-1, lod=0): + + # loss weight schedule; the regularization terms should be added in later training stage + def get_weight(iter_step, weight): + if lod == 1: + anneal_start = self.anneal_end if lod == 0 else self.anneal_end_lod1 + anneal_end = self.anneal_end if lod == 0 else self.anneal_end_lod1 + anneal_end = anneal_end * 2 + else: + anneal_start = self.anneal_start if lod == 0 else self.anneal_start_lod1 + anneal_end = self.anneal_end if lod == 0 else self.anneal_end_lod1 + anneal_end = anneal_end * 2 + + if iter_step < 0: + return weight + + if anneal_end == 0.0: + return weight + elif iter_step < anneal_start: + return 0.0 + else: + return np.min( + [1.0, + (iter_step - anneal_start) / (anneal_end - anneal_start)]) * weight + + rays_o = sample_rays['rays_o'][0] + rays_d = sample_rays['rays_v'][0] + true_rgb = sample_rays['rays_color'][0] + + if 'rays_depth' in sample_rays.keys(): + true_depth = sample_rays['rays_depth'][0] + else: + true_depth = None + mask = sample_rays['rays_mask'][0] + + color_fine = render_out['color_fine'] + color_fine_mask = render_out['color_fine_mask'] + depth_pred = render_out['depth'] + + variance = render_out['variance'] + cdf_fine = render_out['cdf_fine'] + weight_sum = render_out['weights_sum'] + + gradient_error_fine = render_out['gradient_error_fine'] + + sdf = render_out['sdf'] + + # * color generated by mlp + color_mlp = render_out['color_mlp'] + color_mlp_mask = render_out['color_mlp_mask'] + + if color_fine is not None: + # Color loss + color_mask = color_fine_mask if color_fine_mask is not None else mask + # import ipdb; ipdb.set_trace() + color_mask = color_mask[..., 0] + color_error = (color_fine[color_mask] - true_rgb[color_mask]) + # print("Nan number", torch.isnan(color_error).sum()) + # print("Color error shape", color_error.shape) + # import ipdb; ipdb.set_trace() + color_fine_loss = F.l1_loss(color_error, torch.zeros_like(color_error).to(color_error.device), + reduction='mean') + # print(color_fine_loss) + psnr = 20.0 * torch.log10( + 1.0 / (((color_fine[color_mask] - true_rgb[color_mask]) ** 2).mean() / (3.0)).sqrt()) + else: + color_fine_loss = 0. + psnr = 0. + + if color_mlp is not None: + # Color loss + color_mlp_mask = color_mlp_mask[..., 0] + color_error_mlp = (color_mlp[color_mlp_mask] - true_rgb[color_mlp_mask]) + color_mlp_loss = F.l1_loss(color_error_mlp, + torch.zeros_like(color_error_mlp).to(color_error_mlp.device), + reduction='mean') + + psnr_mlp = 20.0 * torch.log10( + 1.0 / (((color_mlp[color_mlp_mask] - true_rgb[color_mlp_mask]) ** 2).mean() / (3.0)).sqrt()) + else: + color_mlp_loss = 0. + psnr_mlp = 0. + + # depth loss is only used for inference, not included in total loss + if true_depth is not None: + # depth_loss = self.depth_criterion(depth_pred, true_depth, mask) + depth_loss = self.depth_criterion(depth_pred, true_depth) + + # # depth evaluation + # depth_statis = compute_depth_errors(depth_pred.detach().cpu().numpy(), true_depth.cpu().numpy()) + # depth_statis = numpy2tensor(depth_statis, device=rays_o.device) + depth_statis = None + else: + depth_loss = 0. + depth_statis = None + + sparse_loss_1 = torch.exp( + -1 * torch.abs(render_out['sdf_random']) * self.sdf_decay_param).mean() # - should equal + sparse_loss_2 = torch.exp(-1 * torch.abs(sdf) * self.sdf_decay_param).mean() + sparse_loss = (sparse_loss_1 + sparse_loss_2) / 2 + + sdf_mean = torch.abs(sdf).mean() + sparseness_1 = (torch.abs(sdf) < 0.01).to(torch.float32).mean() + sparseness_2 = (torch.abs(sdf) < 0.02).to(torch.float32).mean() + + # Eikonal loss + gradient_error_loss = gradient_error_fine + + # ! the first 50k, don't use bg constraint + fg_bg_weight = 0.0 if iter_step < 50000 else get_weight(iter_step, self.fg_bg_weight) + + # Mask loss, optional + # The images of DTU dataset contain large black regions (0 rgb values), + # can use this data prior to make fg more clean + background_loss = 0.0 + fg_bg_loss = 0.0 + if self.fg_bg_weight > 0 and torch.mean((mask < 0.5).to(torch.float32)) > 0.02: + weights_sum_fg = render_out['weights_sum_fg'] + fg_bg_error = (weights_sum_fg - mask)[mask < 0.5] + fg_bg_loss = F.l1_loss(fg_bg_error, + torch.zeros_like(fg_bg_error).to(fg_bg_error.device), + reduction='mean') + + + + loss = 1.0 * depth_loss + color_fine_loss + color_mlp_loss + \ + sparse_loss * get_weight(iter_step, self.sdf_sparse_weight) + \ + fg_bg_loss * fg_bg_weight + \ + gradient_error_loss * self.sdf_igr_weight # ! gradient_error_loss need a mask + + losses = { + "loss": loss, + "depth_loss": depth_loss, + "color_fine_loss": color_fine_loss, + "color_mlp_loss": color_mlp_loss, + "gradient_error_loss": gradient_error_loss, + "background_loss": background_loss, + "sparse_loss": sparse_loss, + "sparseness_1": sparseness_1, + "sparseness_2": sparseness_2, + "sdf_mean": sdf_mean, + "psnr": psnr, + "psnr_mlp": psnr_mlp, + "weights_sum": render_out['weights_sum'], + "weights_sum_fg": render_out['weights_sum_fg'], + "alpha_sum": render_out['alpha_sum'], + "variance": render_out['variance'], + "sparse_weight": get_weight(iter_step, self.sdf_sparse_weight), + "fg_bg_weight": fg_bg_weight, + "fg_bg_loss": fg_bg_loss, # added by jha, bug of sparseNeuS + } + # print("[TEST]: weights_sum in trainner forward", losses['weights_sum'].mean()) + losses = numpy2tensor(losses, device=rays_o.device) + return loss, losses, depth_statis + + @torch.no_grad() + def validate_mesh(self, density_or_sdf_network, func_extract_geometry, world_space=True, resolution=360, + threshold=0.0, mode='val', + # * 3d feature volume + conditional_volume=None, lod=None, occupancy_mask=None, + bound_min=[-1, -1, -1], bound_max=[1, 1, 1], meta='', iter_step=0, scale_mat=None, + trans_mat=None + ): + + bound_min = torch.tensor(bound_min, dtype=torch.float32) + bound_max = torch.tensor(bound_max, dtype=torch.float32) + + vertices, triangles, fields = func_extract_geometry( + density_or_sdf_network, + bound_min, bound_max, resolution=resolution, + threshold=threshold, device=conditional_volume.device, + # * 3d feature volume + conditional_volume=conditional_volume, lod=lod, + occupancy_mask=occupancy_mask + ) + + + if scale_mat is not None: + scale_mat_np = scale_mat.cpu().numpy() + vertices = vertices * scale_mat_np[0][0, 0] + scale_mat_np[0][:3, 3][None] + + if trans_mat is not None: # w2c_ref_inv + trans_mat_np = trans_mat.cpu().numpy() + vertices_homo = np.concatenate([vertices, np.ones_like(vertices[:, :1])], axis=1) + vertices = np.matmul(trans_mat_np, vertices_homo[:, :, None])[:, :3, 0] + + mesh = trimesh.Trimesh(vertices, triangles) + os.makedirs(os.path.join(self.base_exp_dir, 'meshes_' + mode), exist_ok=True) + mesh.export(os.path.join(self.base_exp_dir, 'meshes_' + mode, + 'mesh_{:0>8d}_{}_lod{:0>1d}.ply'.format(iter_step, meta, lod))) + + + + def validate_colored_mesh(self, density_or_sdf_network, func_extract_geometry, world_space=True, resolution=360, + threshold=0.0, mode='val', + # * 3d feature volume + conditional_volume=None, + conditional_valid_mask_volume=None, + feature_maps=None, + color_maps = None, + w2cs=None, + target_candidate_w2cs=None, + intrinsics=None, + rendering_network=None, + rendering_projector=None, + query_c2w=None, + lod=None, occupancy_mask=None, + bound_min=[-1, -1, -1], bound_max=[1, 1, 1], meta='', iter_step=0, scale_mat=None, + trans_mat=None + ): + + bound_min = torch.tensor(bound_min, dtype=torch.float32) + bound_max = torch.tensor(bound_max, dtype=torch.float32) + + vertices, triangles, fields = func_extract_geometry( + density_or_sdf_network, + bound_min, bound_max, resolution=resolution, + threshold=threshold, device=conditional_volume.device, + # * 3d feature volume + conditional_volume=conditional_volume, lod=lod, + occupancy_mask=occupancy_mask + ) + + + with torch.no_grad(): + ren_geo_feats, ren_rgb_feats, ren_ray_diff, ren_mask, _, _ = rendering_projector.compute_view_independent( + torch.tensor(vertices).to(conditional_volume), + lod=0, + # * 3d geometry feature volumes + geometryVolume=conditional_volume[0], + geometryVolumeMask=conditional_valid_mask_volume[0], + sdf_network=density_or_sdf_network, + # * 2d rendering feature maps + rendering_feature_maps=feature_maps, # [n_view, 56, 256, 256] + color_maps=color_maps, + w2cs=w2cs, + target_candidate_w2cs=target_candidate_w2cs, + intrinsics=intrinsics, + img_wh=[256,256], + query_img_idx=0, # the index of the N_views dim for rendering + query_c2w=query_c2w, + ) + + + vertices_color, rendering_valid_mask = rendering_network( + ren_geo_feats, ren_rgb_feats, ren_ray_diff, ren_mask) + + + + if scale_mat is not None: + scale_mat_np = scale_mat.cpu().numpy() + vertices = vertices * scale_mat_np[0][0, 0] + scale_mat_np[0][:3, 3][None] + + if trans_mat is not None: # w2c_ref_inv + trans_mat_np = trans_mat.cpu().numpy() + vertices_homo = np.concatenate([vertices, np.ones_like(vertices[:, :1])], axis=1) + vertices = np.matmul(trans_mat_np, vertices_homo[:, :, None])[:, :3, 0] + # import ipdb; ipdb.set_trace() + vertices_color = np.array(vertices_color.squeeze(0).cpu() * 255, dtype=np.uint8) + mesh = trimesh.Trimesh(vertices, triangles, vertex_colors=vertices_color) + os.makedirs(os.path.join(self.base_exp_dir, 'meshes_' + mode, 'lod{:0>1d}'.format(lod)), exist_ok=True) + mesh.export(os.path.join(self.base_exp_dir, 'meshes_' + mode, 'lod{:0>1d}'.format(lod), + 'mesh_{:0>8d}_{}_lod{:0>1d}.ply'.format(iter_step, meta, lod))) \ No newline at end of file diff --git a/SparseNeuS_demo_v1/ops/__init__.py b/SparseNeuS_demo_v1/ops/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/SparseNeuS_demo_v1/ops/back_project.py b/SparseNeuS_demo_v1/ops/back_project.py new file mode 100644 index 0000000000000000000000000000000000000000..5398f285f786a0e6c7a029138aa8a6554aae6e58 --- /dev/null +++ b/SparseNeuS_demo_v1/ops/back_project.py @@ -0,0 +1,175 @@ +import torch +from torch.nn.functional import grid_sample + + +def back_project_sparse_type(coords, origin, voxel_size, feats, KRcam, sizeH=None, sizeW=None, only_mask=False, + with_proj_z=False): + # - modified version from NeuRecon + ''' + Unproject the image fetures to form a 3D (sparse) feature volume + + :param coords: coordinates of voxels, + dim: (num of voxels, 4) (4 : batch ind, x, y, z) + :param origin: origin of the partial voxel volume (xyz position of voxel (0, 0, 0)) + dim: (batch size, 3) (3: x, y, z) + :param voxel_size: floats specifying the size of a voxel + :param feats: image features + dim: (num of views, batch size, C, H, W) + :param KRcam: projection matrix + dim: (num of views, batch size, 4, 4) + :return: feature_volume_all: 3D feature volumes + dim: (num of voxels, num_of_views, c) + :return: mask_volume_all: indicate the voxel of sampled feature volume is valid or not + dim: (num of voxels, num_of_views) + ''' + n_views, bs, c, h, w = feats.shape + device = feats.device + + if sizeH is None: + sizeH, sizeW = h, w # - if the KRcam is not suitable for the current feats + + feature_volume_all = torch.zeros(coords.shape[0], n_views, c).to(device) + mask_volume_all = torch.zeros([coords.shape[0], n_views], dtype=torch.int32).to(device) + # import ipdb; ipdb.set_trace() + for batch in range(bs): + # import ipdb; ipdb.set_trace() + batch_ind = torch.nonzero(coords[:, 0] == batch).squeeze(1) + coords_batch = coords[batch_ind][:, 1:] + + coords_batch = coords_batch.view(-1, 3) + origin_batch = origin[batch].unsqueeze(0) + feats_batch = feats[:, batch] + proj_batch = KRcam[:, batch] + + grid_batch = coords_batch * voxel_size + origin_batch.float() + rs_grid = grid_batch.unsqueeze(0).expand(n_views, -1, -1) + rs_grid = rs_grid.permute(0, 2, 1).contiguous() + nV = rs_grid.shape[-1] + rs_grid = torch.cat([rs_grid, torch.ones([n_views, 1, nV]).to(device)], dim=1) + + # Project grid + im_p = proj_batch @ rs_grid # - transform world pts to image UV space + im_x, im_y, im_z = im_p[:, 0], im_p[:, 1], im_p[:, 2] + + im_z[im_z >= 0] = im_z[im_z >= 0].clamp(min=1e-6) + + im_x = im_x / im_z + im_y = im_y / im_z + + im_grid = torch.stack([2 * im_x / (sizeW - 1) - 1, 2 * im_y / (sizeH - 1) - 1], dim=-1) + mask = im_grid.abs() <= 1 + mask = (mask.sum(dim=-1) == 2) & (im_z > 0) + + mask = mask.view(n_views, -1) + mask = mask.permute(1, 0).contiguous() # [num_pts, nviews] + + mask_volume_all[batch_ind] = mask.to(torch.int32) + + if only_mask: + return mask_volume_all + + feats_batch = feats_batch.view(n_views, c, h, w) + im_grid = im_grid.view(n_views, 1, -1, 2) + features = grid_sample(feats_batch, im_grid, padding_mode='zeros', align_corners=True) + # if features.isnan().sum() > 0: + # import ipdb; ipdb.set_trace() + features = features.view(n_views, c, -1) + features = features.permute(2, 0, 1).contiguous() # [num_pts, nviews, c] + + feature_volume_all[batch_ind] = features + + if with_proj_z: + im_z = im_z.view(n_views, 1, -1).permute(2, 0, 1).contiguous() # [num_pts, nviews, 1] + return feature_volume_all, mask_volume_all, im_z + # if feature_volume_all.isnan().sum() > 0: + # import ipdb; ipdb.set_trace() + return feature_volume_all, mask_volume_all + + +def cam2pixel(cam_coords, proj_c2p_rot, proj_c2p_tr, padding_mode, sizeH=None, sizeW=None, with_depth=False): + """Transform coordinates in the camera frame to the pixel frame. + Args: + cam_coords: pixel coordinates defined in the first camera coordinates system -- [B, 3, H, W] + proj_c2p_rot: rotation matrix of cameras -- [B, 3, 3] + proj_c2p_tr: translation vectors of cameras -- [B, 3, 1] + Returns: + array of [-1,1] coordinates -- [B, H, W, 2] + """ + b, _, h, w = cam_coords.size() + if sizeH is None: + sizeH = h + sizeW = w + + cam_coords_flat = cam_coords.view(b, 3, -1) # [B, 3, H*W] + if proj_c2p_rot is not None: + pcoords = proj_c2p_rot.bmm(cam_coords_flat) + else: + pcoords = cam_coords_flat + + if proj_c2p_tr is not None: + pcoords = pcoords + proj_c2p_tr # [B, 3, H*W] + X = pcoords[:, 0] + Y = pcoords[:, 1] + Z = pcoords[:, 2].clamp(min=1e-3) + + X_norm = 2 * (X / Z) / (sizeW - 1) - 1 # Normalized, -1 if on extreme left, + # 1 if on extreme right (x = w-1) [B, H*W] + Y_norm = 2 * (Y / Z) / (sizeH - 1) - 1 # Idem [B, H*W] + if padding_mode == 'zeros': + X_mask = ((X_norm > 1) + (X_norm < -1)).detach() + X_norm[X_mask] = 2 # make sure that no point in warped image is a combinaison of im and gray + Y_mask = ((Y_norm > 1) + (Y_norm < -1)).detach() + Y_norm[Y_mask] = 2 + + if with_depth: + pixel_coords = torch.stack([X_norm, Y_norm, Z], dim=2) # [B, H*W, 3] + return pixel_coords.view(b, h, w, 3) + else: + pixel_coords = torch.stack([X_norm, Y_norm], dim=2) # [B, H*W, 2] + return pixel_coords.view(b, h, w, 2) + + +# * have already checked, should check whether proj_matrix is for right coordinate system and resolution +def back_project_dense_type(coords, origin, voxel_size, feats, proj_matrix, sizeH=None, sizeW=None): + ''' + Unproject the image fetures to form a 3D (dense) feature volume + + :param coords: coordinates of voxels, + dim: (batch, nviews, 3, X,Y,Z) + :param origin: origin of the partial voxel volume (xyz position of voxel (0, 0, 0)) + dim: (batch size, 3) (3: x, y, z) + :param voxel_size: floats specifying the size of a voxel + :param feats: image features + dim: (batch size, num of views, C, H, W) + :param proj_matrix: projection matrix + dim: (batch size, num of views, 4, 4) + :return: feature_volume_all: 3D feature volumes + dim: (batch, nviews, C, X,Y,Z) + :return: count: number of times each voxel can be seen + dim: (batch, nviews, 1, X,Y,Z) + ''' + + batch, nviews, _, wX, wY, wZ = coords.shape + + if sizeH is None: + sizeH, sizeW = feats.shape[-2:] + proj_matrix = proj_matrix.view(batch * nviews, *proj_matrix.shape[2:]) + + coords_wrd = coords * voxel_size + origin.view(batch, 1, 3, 1, 1, 1) + coords_wrd = coords_wrd.view(batch * nviews, 3, wX * wY * wZ, 1) # (b*nviews,3,wX*wY*wZ, 1) + + pixel_grids = cam2pixel(coords_wrd, proj_matrix[:, :3, :3], proj_matrix[:, :3, 3:], + 'zeros', sizeH=sizeH, sizeW=sizeW) # (b*nviews,wX*wY*wZ, 2) + pixel_grids = pixel_grids.view(batch * nviews, 1, wX * wY * wZ, 2) + + feats = feats.view(batch * nviews, *feats.shape[2:]) # (b*nviews,c,h,w) + + ones = torch.ones((batch * nviews, 1, *feats.shape[2:])).to(feats.dtype).to(feats.device) + + features_volume = torch.nn.functional.grid_sample(feats, pixel_grids, padding_mode='zeros', align_corners=True) + counts_volume = torch.nn.functional.grid_sample(ones, pixel_grids, padding_mode='zeros', align_corners=True) + + features_volume = features_volume.view(batch, nviews, -1, wX, wY, wZ) # (batch, nviews, C, X,Y,Z) + counts_volume = counts_volume.view(batch, nviews, -1, wX, wY, wZ) + return features_volume, counts_volume + diff --git a/SparseNeuS_demo_v1/ops/generate_grids.py b/SparseNeuS_demo_v1/ops/generate_grids.py new file mode 100644 index 0000000000000000000000000000000000000000..884c37793131323c566c6d1a738f06d497bbd2fb --- /dev/null +++ b/SparseNeuS_demo_v1/ops/generate_grids.py @@ -0,0 +1,33 @@ +import torch + + +def generate_grid(n_vox, interval): + """ + generate grid + if 3D volume, grid[:,:,x,y,z] = (x,y,z) + :param n_vox: + :param interval: + :return: + """ + with torch.no_grad(): + # Create voxel grid + grid_range = [torch.arange(0, n_vox[axis], interval) for axis in range(3)] + grid = torch.stack(torch.meshgrid(grid_range[0], grid_range[1], grid_range[2])) # 3 dx dy dz + # ! don't create tensor on gpu; imbalanced gpu memory in ddp mode + grid = grid.unsqueeze(0).type(torch.float32) # 1 3 dx dy dz + + return grid + + +if __name__ == "__main__": + import torch.nn.functional as F + grid = generate_grid([5, 6, 8], 1) + + pts = 2 * torch.tensor([1, 2, 3]) / (torch.tensor([5, 6, 8]) - 1) - 1 + pts = pts.view(1, 1, 1, 1, 3) + + pts = torch.flip(pts, dims=[-1]) + + sampled = F.grid_sample(grid, pts, mode='nearest') + + print(sampled) diff --git a/SparseNeuS_demo_v1/ops/grid_sampler.py b/SparseNeuS_demo_v1/ops/grid_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..44113faa705f0b98a5689c0e4fb9e7a95865d6c1 --- /dev/null +++ b/SparseNeuS_demo_v1/ops/grid_sampler.py @@ -0,0 +1,467 @@ +""" +pytorch grid_sample doesn't support second-order derivative +implement custom version +""" + +import torch +import torch.nn.functional as F +import numpy as np + + +def grid_sample_2d(image, optical): + N, C, IH, IW = image.shape + _, H, W, _ = optical.shape + + ix = optical[..., 0] + iy = optical[..., 1] + + ix = ((ix + 1) / 2) * (IW - 1); + iy = ((iy + 1) / 2) * (IH - 1); + with torch.no_grad(): + ix_nw = torch.floor(ix); + iy_nw = torch.floor(iy); + ix_ne = ix_nw + 1; + iy_ne = iy_nw; + ix_sw = ix_nw; + iy_sw = iy_nw + 1; + ix_se = ix_nw + 1; + iy_se = iy_nw + 1; + + nw = (ix_se - ix) * (iy_se - iy) + ne = (ix - ix_sw) * (iy_sw - iy) + sw = (ix_ne - ix) * (iy - iy_ne) + se = (ix - ix_nw) * (iy - iy_nw) + + with torch.no_grad(): + torch.clamp(ix_nw, 0, IW - 1, out=ix_nw) + torch.clamp(iy_nw, 0, IH - 1, out=iy_nw) + + torch.clamp(ix_ne, 0, IW - 1, out=ix_ne) + torch.clamp(iy_ne, 0, IH - 1, out=iy_ne) + + torch.clamp(ix_sw, 0, IW - 1, out=ix_sw) + torch.clamp(iy_sw, 0, IH - 1, out=iy_sw) + + torch.clamp(ix_se, 0, IW - 1, out=ix_se) + torch.clamp(iy_se, 0, IH - 1, out=iy_se) + + image = image.view(N, C, IH * IW) + + nw_val = torch.gather(image, 2, (iy_nw * IW + ix_nw).long().view(N, 1, H * W).repeat(1, C, 1)) + ne_val = torch.gather(image, 2, (iy_ne * IW + ix_ne).long().view(N, 1, H * W).repeat(1, C, 1)) + sw_val = torch.gather(image, 2, (iy_sw * IW + ix_sw).long().view(N, 1, H * W).repeat(1, C, 1)) + se_val = torch.gather(image, 2, (iy_se * IW + ix_se).long().view(N, 1, H * W).repeat(1, C, 1)) + + out_val = (nw_val.view(N, C, H, W) * nw.view(N, 1, H, W) + + ne_val.view(N, C, H, W) * ne.view(N, 1, H, W) + + sw_val.view(N, C, H, W) * sw.view(N, 1, H, W) + + se_val.view(N, C, H, W) * se.view(N, 1, H, W)) + + return out_val + + +# - checked for correctness +def grid_sample_3d(volume, optical): + """ + bilinear sampling cannot guarantee continuous first-order gradient + mimic pytorch grid_sample function + The 8 corner points of a volume noted as: 4 points (front view); 4 points (back view) + fnw (front north west) point + bse (back south east) point + :param volume: [B, C, X, Y, Z] + :param optical: [B, x, y, z, 3] + :return: + """ + N, C, ID, IH, IW = volume.shape + _, D, H, W, _ = optical.shape + + ix = optical[..., 0] + iy = optical[..., 1] + iz = optical[..., 2] + + ix = ((ix + 1) / 2) * (IW - 1) + iy = ((iy + 1) / 2) * (IH - 1) + iz = ((iz + 1) / 2) * (ID - 1) + + mask_x = (ix > 0) & (ix < IW) + mask_y = (iy > 0) & (iy < IH) + mask_z = (iz > 0) & (iz < ID) + + mask = mask_x & mask_y & mask_z # [B, x, y, z] + mask = mask[:, None, :, :, :].repeat(1, C, 1, 1, 1) # [B, C, x, y, z] + + with torch.no_grad(): + # back north west + ix_bnw = torch.floor(ix) + iy_bnw = torch.floor(iy) + iz_bnw = torch.floor(iz) + + ix_bne = ix_bnw + 1 + iy_bne = iy_bnw + iz_bne = iz_bnw + + ix_bsw = ix_bnw + iy_bsw = iy_bnw + 1 + iz_bsw = iz_bnw + + ix_bse = ix_bnw + 1 + iy_bse = iy_bnw + 1 + iz_bse = iz_bnw + + # front view + ix_fnw = ix_bnw + iy_fnw = iy_bnw + iz_fnw = iz_bnw + 1 + + ix_fne = ix_bnw + 1 + iy_fne = iy_bnw + iz_fne = iz_bnw + 1 + + ix_fsw = ix_bnw + iy_fsw = iy_bnw + 1 + iz_fsw = iz_bnw + 1 + + ix_fse = ix_bnw + 1 + iy_fse = iy_bnw + 1 + iz_fse = iz_bnw + 1 + + # back view + bnw = (ix_fse - ix) * (iy_fse - iy) * (iz_fse - iz) # smaller volume, larger weight + bne = (ix - ix_fsw) * (iy_fsw - iy) * (iz_fsw - iz) + bsw = (ix_fne - ix) * (iy - iy_fne) * (iz_fne - iz) + bse = (ix - ix_fnw) * (iy - iy_fnw) * (iz_fnw - iz) + + # front view + fnw = (ix_bse - ix) * (iy_bse - iy) * (iz - iz_bse) # smaller volume, larger weight + fne = (ix - ix_bsw) * (iy_bsw - iy) * (iz - iz_bsw) + fsw = (ix_bne - ix) * (iy - iy_bne) * (iz - iz_bne) + fse = (ix - ix_bnw) * (iy - iy_bnw) * (iz - iz_bnw) + + with torch.no_grad(): + # back view + torch.clamp(ix_bnw, 0, IW - 1, out=ix_bnw) + torch.clamp(iy_bnw, 0, IH - 1, out=iy_bnw) + torch.clamp(iz_bnw, 0, ID - 1, out=iz_bnw) + + torch.clamp(ix_bne, 0, IW - 1, out=ix_bne) + torch.clamp(iy_bne, 0, IH - 1, out=iy_bne) + torch.clamp(iz_bne, 0, ID - 1, out=iz_bne) + + torch.clamp(ix_bsw, 0, IW - 1, out=ix_bsw) + torch.clamp(iy_bsw, 0, IH - 1, out=iy_bsw) + torch.clamp(iz_bsw, 0, ID - 1, out=iz_bsw) + + torch.clamp(ix_bse, 0, IW - 1, out=ix_bse) + torch.clamp(iy_bse, 0, IH - 1, out=iy_bse) + torch.clamp(iz_bse, 0, ID - 1, out=iz_bse) + + # front view + torch.clamp(ix_fnw, 0, IW - 1, out=ix_fnw) + torch.clamp(iy_fnw, 0, IH - 1, out=iy_fnw) + torch.clamp(iz_fnw, 0, ID - 1, out=iz_fnw) + + torch.clamp(ix_fne, 0, IW - 1, out=ix_fne) + torch.clamp(iy_fne, 0, IH - 1, out=iy_fne) + torch.clamp(iz_fne, 0, ID - 1, out=iz_fne) + + torch.clamp(ix_fsw, 0, IW - 1, out=ix_fsw) + torch.clamp(iy_fsw, 0, IH - 1, out=iy_fsw) + torch.clamp(iz_fsw, 0, ID - 1, out=iz_fsw) + + torch.clamp(ix_fse, 0, IW - 1, out=ix_fse) + torch.clamp(iy_fse, 0, IH - 1, out=iy_fse) + torch.clamp(iz_fse, 0, ID - 1, out=iz_fse) + + # xxx = volume[:, :, iz_bnw.long(), iy_bnw.long(), ix_bnw.long()] + volume = volume.view(N, C, ID * IH * IW) + # yyy = volume[:, :, (iz_bnw * ID + iy_bnw * IW + ix_bnw).long()] + + # back view + bnw_val = torch.gather(volume, 2, + (iz_bnw * ID ** 2 + iy_bnw * IW + ix_bnw).long().view(N, 1, D * H * W).repeat(1, C, 1)) + bne_val = torch.gather(volume, 2, + (iz_bne * ID ** 2 + iy_bne * IW + ix_bne).long().view(N, 1, D * H * W).repeat(1, C, 1)) + bsw_val = torch.gather(volume, 2, + (iz_bsw * ID ** 2 + iy_bsw * IW + ix_bsw).long().view(N, 1, D * H * W).repeat(1, C, 1)) + bse_val = torch.gather(volume, 2, + (iz_bse * ID ** 2 + iy_bse * IW + ix_bse).long().view(N, 1, D * H * W).repeat(1, C, 1)) + + # front view + fnw_val = torch.gather(volume, 2, + (iz_fnw * ID ** 2 + iy_fnw * IW + ix_fnw).long().view(N, 1, D * H * W).repeat(1, C, 1)) + fne_val = torch.gather(volume, 2, + (iz_fne * ID ** 2 + iy_fne * IW + ix_fne).long().view(N, 1, D * H * W).repeat(1, C, 1)) + fsw_val = torch.gather(volume, 2, + (iz_fsw * ID ** 2 + iy_fsw * IW + ix_fsw).long().view(N, 1, D * H * W).repeat(1, C, 1)) + fse_val = torch.gather(volume, 2, + (iz_fse * ID ** 2 + iy_fse * IW + ix_fse).long().view(N, 1, D * H * W).repeat(1, C, 1)) + + out_val = ( + # back + bnw_val.view(N, C, D, H, W) * bnw.view(N, 1, D, H, W) + + bne_val.view(N, C, D, H, W) * bne.view(N, 1, D, H, W) + + bsw_val.view(N, C, D, H, W) * bsw.view(N, 1, D, H, W) + + bse_val.view(N, C, D, H, W) * bse.view(N, 1, D, H, W) + + # front + fnw_val.view(N, C, D, H, W) * fnw.view(N, 1, D, H, W) + + fne_val.view(N, C, D, H, W) * fne.view(N, 1, D, H, W) + + fsw_val.view(N, C, D, H, W) * fsw.view(N, 1, D, H, W) + + fse_val.view(N, C, D, H, W) * fse.view(N, 1, D, H, W) + + ) + + # * zero padding + out_val = torch.where(mask, out_val, torch.zeros_like(out_val).float().to(out_val.device)) + + return out_val + + +# Interpolation kernel +def get_weight(s, a=-0.5): + mask_0 = (torch.abs(s) >= 0) & (torch.abs(s) <= 1) + mask_1 = (torch.abs(s) > 1) & (torch.abs(s) <= 2) + mask_2 = torch.abs(s) > 2 + + weight = torch.zeros_like(s).to(s.device) + weight = torch.where(mask_0, (a + 2) * (torch.abs(s) ** 3) - (a + 3) * (torch.abs(s) ** 2) + 1, weight) + weight = torch.where(mask_1, + a * (torch.abs(s) ** 3) - (5 * a) * (torch.abs(s) ** 2) + (8 * a) * torch.abs(s) - 4 * a, + weight) + + # if (torch.abs(s) >= 0) & (torch.abs(s) <= 1): + # return (a + 2) * (torch.abs(s) ** 3) - (a + 3) * (torch.abs(s) ** 2) + 1 + # + # elif (torch.abs(s) > 1) & (torch.abs(s) <= 2): + # return a * (torch.abs(s) ** 3) - (5 * a) * (torch.abs(s) ** 2) + (8 * a) * torch.abs(s) - 4 * a + # return 0 + + return weight + + +def cubic_interpolate(p, x): + """ + one dimensional cubic interpolation + :param p: [N, 4] (4) should be in order + :param x: [N] + :return: + """ + return p[:, 1] + 0.5 * x * (p[:, 2] - p[:, 0] + x * ( + 2.0 * p[:, 0] - 5.0 * p[:, 1] + 4.0 * p[:, 2] - p[:, 3] + x * ( + 3.0 * (p[:, 1] - p[:, 2]) + p[:, 3] - p[:, 0]))) + + +def bicubic_interpolate(p, x, y, if_batch=True): + """ + two dimensional cubic interpolation + :param p: [N, 4, 4] + :param x: [N] + :param y: [N] + :return: + """ + num = p.shape[0] + + if not if_batch: + arr0 = cubic_interpolate(p[:, 0, :], x) # [N] + arr1 = cubic_interpolate(p[:, 1, :], x) + arr2 = cubic_interpolate(p[:, 2, :], x) + arr3 = cubic_interpolate(p[:, 3, :], x) + return cubic_interpolate(torch.stack([arr0, arr1, arr2, arr3], dim=-1), y) # [N] + else: + x = x[:, None].repeat(1, 4).view(-1) + p = p.contiguous().view(num * 4, 4) + arr = cubic_interpolate(p, x) + arr = arr.view(num, 4) + + return cubic_interpolate(arr, y) + + +def tricubic_interpolate(p, x, y, z): + """ + three dimensional cubic interpolation + :param p: [N,4,4,4] + :param x: [N] + :param y: [N] + :param z: [N] + :return: + """ + num = p.shape[0] + + arr0 = bicubic_interpolate(p[:, 0, :, :], x, y) # [N] + arr1 = bicubic_interpolate(p[:, 1, :, :], x, y) + arr2 = bicubic_interpolate(p[:, 2, :, :], x, y) + arr3 = bicubic_interpolate(p[:, 3, :, :], x, y) + + return cubic_interpolate(torch.stack([arr0, arr1, arr2, arr3], dim=-1), z) # [N] + + +def cubic_interpolate_batch(p, x): + """ + one dimensional cubic interpolation + :param p: [B, N, 4] (4) should be in order + :param x: [B, N] + :return: + """ + return p[:, :, 1] + 0.5 * x * (p[:, :, 2] - p[:, :, 0] + x * ( + 2.0 * p[:, :, 0] - 5.0 * p[:, :, 1] + 4.0 * p[:, :, 2] - p[:, :, 3] + x * ( + 3.0 * (p[:, :, 1] - p[:, :, 2]) + p[:, :, 3] - p[:, :, 0]))) + + +def bicubic_interpolate_batch(p, x, y): + """ + two dimensional cubic interpolation + :param p: [B, N, 4, 4] + :param x: [B, N] + :param y: [B, N] + :return: + """ + B, N, _, _ = p.shape + + x = x[:, :, None].repeat(1, 1, 4).view(B, N * 4) # [B, N*4] + arr = cubic_interpolate_batch(p.contiguous().view(B, N * 4, 4), x) + arr = arr.view(B, N, 4) + return cubic_interpolate_batch(arr, y) # [B, N] + + +# * batch version cannot speed up training +def tricubic_interpolate_batch(p, x, y, z): + """ + three dimensional cubic interpolation + :param p: [N,4,4,4] + :param x: [N] + :param y: [N] + :param z: [N] + :return: + """ + N = p.shape[0] + + x = x[None, :].repeat(4, 1) + y = y[None, :].repeat(4, 1) + + p = p.permute(1, 0, 2, 3).contiguous() + + arr = bicubic_interpolate_batch(p[:, :, :, :], x, y) # [4, N] + + arr = arr.permute(1, 0).contiguous() # [N, 4] + + return cubic_interpolate(arr, z) # [N] + + +def tricubic_sample_3d(volume, optical): + """ + tricubic sampling; can guarantee continuous gradient (interpolation border) + :param volume: [B, C, ID, IH, IW] + :param optical: [B, D, H, W, 3] + :param sample_num: + :return: + """ + + @torch.no_grad() + def get_shifts(x): + x1 = -1 * (1 + x - torch.floor(x)) + x2 = -1 * (x - torch.floor(x)) + x3 = torch.floor(x) + 1 - x + x4 = torch.floor(x) + 2 - x + + return torch.stack([x1, x2, x3, x4], dim=-1) # (B,d,h,w,4) + + N, C, ID, IH, IW = volume.shape + _, D, H, W, _ = optical.shape + + device = volume.device + + ix = optical[..., 0] + iy = optical[..., 1] + iz = optical[..., 2] + + ix = ((ix + 1) / 2) * (IW - 1) # (B,d,h,w) + iy = ((iy + 1) / 2) * (IH - 1) + iz = ((iz + 1) / 2) * (ID - 1) + + ix = ix.view(-1) + iy = iy.view(-1) + iz = iz.view(-1) + + with torch.no_grad(): + shifts_x = get_shifts(ix).view(-1, 4) # (B*d*h*w,4) + shifts_y = get_shifts(iy).view(-1, 4) + shifts_z = get_shifts(iz).view(-1, 4) + + perm_weights = torch.ones([N * D * H * W, 4 * 4 * 4]).long().to(device) + perm = torch.cumsum(perm_weights, dim=-1) - 1 # (B*d*h*w,64) + + perm_z = perm // 16 # [N*D*H*W, num] + perm_y = (perm - perm_z * 16) // 4 + perm_x = (perm - perm_z * 16 - perm_y * 4) + + shifts_x = torch.gather(shifts_x, 1, perm_x) # [N*D*H*W, num] + shifts_y = torch.gather(shifts_y, 1, perm_y) + shifts_z = torch.gather(shifts_z, 1, perm_z) + + ix_target = (ix[:, None] + shifts_x).long() # [N*D*H*W, num] + iy_target = (iy[:, None] + shifts_y).long() + iz_target = (iz[:, None] + shifts_z).long() + + torch.clamp(ix_target, 0, IW - 1, out=ix_target) + torch.clamp(iy_target, 0, IH - 1, out=iy_target) + torch.clamp(iz_target, 0, ID - 1, out=iz_target) + + local_dist_x = ix - ix_target[:, 1] # ! attention here is [:, 1] + local_dist_y = iy - iy_target[:, 1 + 4] + local_dist_z = iz - iz_target[:, 1 + 16] + + local_dist_x = local_dist_x.view(N, 1, D * H * W).repeat(1, C, 1).view(-1) + local_dist_y = local_dist_y.view(N, 1, D * H * W).repeat(1, C, 1).view(-1) + local_dist_z = local_dist_z.view(N, 1, D * H * W).repeat(1, C, 1).view(-1) + + # ! attention: IW is correct + idx_target = iz_target * ID ** 2 + iy_target * IW + ix_target # [N*D*H*W, num] + + volume = volume.view(N, C, ID * IH * IW) + + out = torch.gather(volume, 2, + idx_target.view(N, 1, D * H * W * 64).repeat(1, C, 1)) + out = out.view(N * C * D * H * W, 4, 4, 4) + + # - tricubic_interpolate() is a bit faster than tricubic_interpolate_batch() + final = tricubic_interpolate(out, local_dist_x, local_dist_y, local_dist_z).view(N, C, D, H, W) # [N,C,D,H,W] + + return final + + + +if __name__ == "__main__": + # image = torch.Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).view(1, 3, 1, 3) + # + # optical = torch.Tensor([0.9, 0.5, 0.6, -0.7]).view(1, 1, 2, 2) + # + # print(grid_sample_2d(image, optical)) + # + # print(F.grid_sample(image, optical, padding_mode='border', align_corners=True)) + + from ops.generate_grids import generate_grid + + p = torch.tensor([x for x in range(4)]).view(1, 4).float() + + v = cubic_interpolate(p, torch.tensor([0.5]).view(1)) + # v = bicubic_interpolate(p, torch.tensor([2/3]).view(1) , torch.tensor([2/3]).view(1)) + + vsize = 9 + volume = generate_grid([vsize, vsize, vsize], 1) # [1,3,10,10,10] + # volume = torch.tensor([x for x in range(1000)]).view(1, 1, 10, 10, 10).float() + X, Y, Z = 0, 0, 6 + x = 2 * X / (vsize - 1) - 1 + y = 2 * Y / (vsize - 1) - 1 + z = 2 * Z / (vsize - 1) - 1 + + # print(volume[:, :, Z, Y, X]) + + # volume = volume.view(1, 3, -1) + # xx = volume[:, :, Z * 9*9 + Y * 9 + X] + + optical = torch.Tensor([-0.6, -0.7, 0.5, 0.3, 0.5, 0.5]).view(1, 1, 1, 2, 3) + + print(F.grid_sample(volume, optical, padding_mode='border', align_corners=True)) + print(grid_sample_3d(volume, optical)) + print(tricubic_sample_3d(volume, optical)) + # target, relative_coords = implicit_sample_3d(volume, optical, 1) + # print(target) diff --git a/SparseNeuS_demo_v1/requirements.txt b/SparseNeuS_demo_v1/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..06a26a213732e63398677064788c50e7d03bf95f --- /dev/null +++ b/SparseNeuS_demo_v1/requirements.txt @@ -0,0 +1,11 @@ +git+https://github.com/mit-han-lab/torchsparse.git@v1.4.0 +opencv_python +trimesh +numpy +pyhocon +icecream +tqdm +scipy +PyMCubes +# sudo apt-get install libsparsehash-dev +inplace_abn \ No newline at end of file diff --git a/SparseNeuS_demo_v1/tsparse/__init__.py b/SparseNeuS_demo_v1/tsparse/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/SparseNeuS_demo_v1/tsparse/modules.py b/SparseNeuS_demo_v1/tsparse/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..520809144718d84b77708bbc7a582a64078958b4 --- /dev/null +++ b/SparseNeuS_demo_v1/tsparse/modules.py @@ -0,0 +1,326 @@ +import torch +import torch.nn as nn +import torchsparse +import torchsparse.nn as spnn +from torchsparse.tensor import PointTensor + +from tsparse.torchsparse_utils import * + + +# __all__ = ['SPVCNN', 'SConv3d', 'SparseConvGRU'] + + +class ConvBnReLU(nn.Module): + def __init__(self, in_channels, out_channels, + kernel_size=3, stride=1, pad=1): + super(ConvBnReLU, self).__init__() + self.conv = nn.Conv2d(in_channels, out_channels, + kernel_size, stride=stride, padding=pad, bias=False) + self.bn = nn.BatchNorm2d(out_channels) + self.activation = nn.ReLU(inplace=True) + + def forward(self, x): + return self.activation(self.bn(self.conv(x))) + + +class ConvBnReLU3D(nn.Module): + def __init__(self, in_channels, out_channels, + kernel_size=3, stride=1, pad=1): + super(ConvBnReLU3D, self).__init__() + self.conv = nn.Conv3d(in_channels, out_channels, + kernel_size, stride=stride, padding=pad, bias=False) + self.bn = nn.BatchNorm3d(out_channels) + self.activation = nn.ReLU(inplace=True) + + def forward(self, x): + return self.activation(self.bn(self.conv(x))) + + +################################### feature net ###################################### +class FeatureNet(nn.Module): + """ + output 3 levels of features using a FPN structure + """ + + def __init__(self): + super(FeatureNet, self).__init__() + + self.conv0 = nn.Sequential( + ConvBnReLU(3, 8, 3, 1, 1), + ConvBnReLU(8, 8, 3, 1, 1)) + + self.conv1 = nn.Sequential( + ConvBnReLU(8, 16, 5, 2, 2), + ConvBnReLU(16, 16, 3, 1, 1), + ConvBnReLU(16, 16, 3, 1, 1)) + + self.conv2 = nn.Sequential( + ConvBnReLU(16, 32, 5, 2, 2), + ConvBnReLU(32, 32, 3, 1, 1), + ConvBnReLU(32, 32, 3, 1, 1)) + + self.toplayer = nn.Conv2d(32, 32, 1) + self.lat1 = nn.Conv2d(16, 32, 1) + self.lat0 = nn.Conv2d(8, 32, 1) + + # to reduce channel size of the outputs from FPN + self.smooth1 = nn.Conv2d(32, 16, 3, padding=1) + self.smooth0 = nn.Conv2d(32, 8, 3, padding=1) + + def _upsample_add(self, x, y): + return torch.nn.functional.interpolate(x, scale_factor=2, + mode="bilinear", align_corners=True) + y + + def forward(self, x): + # x: (B, 3, H, W) + conv0 = self.conv0(x) # (B, 8, H, W) + conv1 = self.conv1(conv0) # (B, 16, H//2, W//2) + conv2 = self.conv2(conv1) # (B, 32, H//4, W//4) + feat2 = self.toplayer(conv2) # (B, 32, H//4, W//4) + feat1 = self._upsample_add(feat2, self.lat1(conv1)) # (B, 32, H//2, W//2) + feat0 = self._upsample_add(feat1, self.lat0(conv0)) # (B, 32, H, W) + + # reduce output channels + feat1 = self.smooth1(feat1) # (B, 16, H//2, W//2) + feat0 = self.smooth0(feat0) # (B, 8, H, W) + + # feats = {"level_0": feat0, + # "level_1": feat1, + # "level_2": feat2} + + return [feat2, feat1, feat0] # coarser to finer features + + +class BasicSparseConvolutionBlock(nn.Module): + def __init__(self, inc, outc, ks=3, stride=1, dilation=1): + super().__init__() + self.net = nn.Sequential( + spnn.Conv3d(inc, + outc, + kernel_size=ks, + dilation=dilation, + stride=stride), + spnn.BatchNorm(outc), + spnn.ReLU(True)) + + def forward(self, x): + out = self.net(x) + return out + + +class BasicSparseDeconvolutionBlock(nn.Module): + def __init__(self, inc, outc, ks=3, stride=1): + super().__init__() + self.net = nn.Sequential( + spnn.Conv3d(inc, + outc, + kernel_size=ks, + stride=stride, + transposed=True), + spnn.BatchNorm(outc), + spnn.ReLU(True)) + + def forward(self, x): + return self.net(x) + + +class SparseResidualBlock(nn.Module): + def __init__(self, inc, outc, ks=3, stride=1, dilation=1): + super().__init__() + self.net = nn.Sequential( + spnn.Conv3d(inc, + outc, + kernel_size=ks, + dilation=dilation, + stride=stride), spnn.BatchNorm(outc), + spnn.ReLU(True), + spnn.Conv3d(outc, + outc, + kernel_size=ks, + dilation=dilation, + stride=1), spnn.BatchNorm(outc)) + + self.downsample = nn.Sequential() if (inc == outc and stride == 1) else \ + nn.Sequential( + spnn.Conv3d(inc, outc, kernel_size=1, dilation=1, stride=stride), + spnn.BatchNorm(outc) + ) + + self.relu = spnn.ReLU(True) + + def forward(self, x): + out = self.relu(self.net(x) + self.downsample(x)) + return out + + +class SPVCNN(nn.Module): + def __init__(self, **kwargs): + super().__init__() + + self.dropout = kwargs['dropout'] + + cr = kwargs.get('cr', 1.0) + cs = [32, 64, 128, 96, 96] + cs = [int(cr * x) for x in cs] + + if 'pres' in kwargs and 'vres' in kwargs: + self.pres = kwargs['pres'] + self.vres = kwargs['vres'] + + self.stem = nn.Sequential( + spnn.Conv3d(kwargs['in_channels'], cs[0], kernel_size=3, stride=1), + spnn.BatchNorm(cs[0]), spnn.ReLU(True) + ) + + self.stage1 = nn.Sequential( + BasicSparseConvolutionBlock(cs[0], cs[0], ks=2, stride=2, dilation=1), + SparseResidualBlock(cs[0], cs[1], ks=3, stride=1, dilation=1), + SparseResidualBlock(cs[1], cs[1], ks=3, stride=1, dilation=1), + ) + + self.stage2 = nn.Sequential( + BasicSparseConvolutionBlock(cs[1], cs[1], ks=2, stride=2, dilation=1), + SparseResidualBlock(cs[1], cs[2], ks=3, stride=1, dilation=1), + SparseResidualBlock(cs[2], cs[2], ks=3, stride=1, dilation=1), + ) + + self.up1 = nn.ModuleList([ + BasicSparseDeconvolutionBlock(cs[2], cs[3], ks=2, stride=2), + nn.Sequential( + SparseResidualBlock(cs[3] + cs[1], cs[3], ks=3, stride=1, + dilation=1), + SparseResidualBlock(cs[3], cs[3], ks=3, stride=1, dilation=1), + ) + ]) + + self.up2 = nn.ModuleList([ + BasicSparseDeconvolutionBlock(cs[3], cs[4], ks=2, stride=2), + nn.Sequential( + SparseResidualBlock(cs[4] + cs[0], cs[4], ks=3, stride=1, + dilation=1), + SparseResidualBlock(cs[4], cs[4], ks=3, stride=1, dilation=1), + ) + ]) + + self.point_transforms = nn.ModuleList([ + nn.Sequential( + nn.Linear(cs[0], cs[2]), + nn.BatchNorm1d(cs[2]), + nn.ReLU(True), + ), + nn.Sequential( + nn.Linear(cs[2], cs[4]), + nn.BatchNorm1d(cs[4]), + nn.ReLU(True), + ) + ]) + + self.weight_initialization() + + if self.dropout: + self.dropout = nn.Dropout(0.3, True) + + def weight_initialization(self): + for m in self.modules(): + if isinstance(m, nn.BatchNorm1d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def forward(self, z): + # x: SparseTensor z: PointTensor + x0 = initial_voxelize(z, self.pres, self.vres) + + x0 = self.stem(x0) + z0 = voxel_to_point(x0, z, nearest=False) + z0.F = z0.F + + x1 = point_to_voxel(x0, z0) + x1 = self.stage1(x1) + x2 = self.stage2(x1) + z1 = voxel_to_point(x2, z0) + z1.F = z1.F + self.point_transforms[0](z0.F) + + y3 = point_to_voxel(x2, z1) + if self.dropout: + y3.F = self.dropout(y3.F) + y3 = self.up1[0](y3) + y3 = torchsparse.cat([y3, x1]) + y3 = self.up1[1](y3) + + y4 = self.up2[0](y3) + y4 = torchsparse.cat([y4, x0]) + y4 = self.up2[1](y4) + z3 = voxel_to_point(y4, z1) + z3.F = z3.F + self.point_transforms[1](z1.F) + + return z3.F + + +class SparseCostRegNet(nn.Module): + """ + Sparse cost regularization network; + require sparse tensors as input + """ + + def __init__(self, d_in, d_out=8): + super(SparseCostRegNet, self).__init__() + self.d_in = d_in + self.d_out = d_out + + self.conv0 = BasicSparseConvolutionBlock(d_in, d_out) + + self.conv1 = BasicSparseConvolutionBlock(d_out, 16, stride=2) + self.conv2 = BasicSparseConvolutionBlock(16, 16) + + self.conv3 = BasicSparseConvolutionBlock(16, 32, stride=2) + self.conv4 = BasicSparseConvolutionBlock(32, 32) + + self.conv5 = BasicSparseConvolutionBlock(32, 64, stride=2) + self.conv6 = BasicSparseConvolutionBlock(64, 64) + + self.conv7 = BasicSparseDeconvolutionBlock(64, 32, ks=3, stride=2) + + self.conv9 = BasicSparseDeconvolutionBlock(32, 16, ks=3, stride=2) + + self.conv11 = BasicSparseDeconvolutionBlock(16, d_out, ks=3, stride=2) + + def forward(self, x): + """ + + :param x: sparse tensor + :return: sparse tensor + """ + conv0 = self.conv0(x) + conv2 = self.conv2(self.conv1(conv0)) + conv4 = self.conv4(self.conv3(conv2)) + + x = self.conv6(self.conv5(conv4)) + x = conv4 + self.conv7(x) + del conv4 + x = conv2 + self.conv9(x) + del conv2 + x = conv0 + self.conv11(x) + del conv0 + return x.F + + +class SConv3d(nn.Module): + def __init__(self, inc, outc, pres, vres, ks=3, stride=1, dilation=1): + super().__init__() + self.net = spnn.Conv3d(inc, + outc, + kernel_size=ks, + dilation=dilation, + stride=stride) + self.point_transforms = nn.Sequential( + nn.Linear(inc, outc), + ) + self.pres = pres + self.vres = vres + + def forward(self, z): + x = initial_voxelize(z, self.pres, self.vres) + x = self.net(x) + out = voxel_to_point(x, z, nearest=False) + out.F = out.F + self.point_transforms(z.F) + return out diff --git a/SparseNeuS_demo_v1/tsparse/torchsparse_utils.py b/SparseNeuS_demo_v1/tsparse/torchsparse_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..32f5b92ae5ef4bf9836b1e4c1dc17eaf3f7c93f9 --- /dev/null +++ b/SparseNeuS_demo_v1/tsparse/torchsparse_utils.py @@ -0,0 +1,137 @@ +""" +Copied from: +https://github.com/mit-han-lab/spvnas/blob/b24f50379ed888d3a0e784508a809d4e92e820c0/core/models/utils.py +""" +import torch +import torchsparse.nn.functional as F +from torchsparse import PointTensor, SparseTensor +from torchsparse.nn.utils import get_kernel_offsets + +import numpy as np + +# __all__ = ['initial_voxelize', 'point_to_voxel', 'voxel_to_point'] + + +# z: PointTensor +# return: SparseTensor +def initial_voxelize(z, init_res, after_res): + new_float_coord = torch.cat( + [(z.C[:, :3] * init_res) / after_res, z.C[:, -1].view(-1, 1)], 1) + + pc_hash = F.sphash(torch.floor(new_float_coord).int()) + sparse_hash = torch.unique(pc_hash) + idx_query = F.sphashquery(pc_hash, sparse_hash) + counts = F.spcount(idx_query.int(), len(sparse_hash)) + + inserted_coords = F.spvoxelize(torch.floor(new_float_coord), idx_query, + counts) + inserted_coords = torch.round(inserted_coords).int() + inserted_feat = F.spvoxelize(z.F, idx_query, counts) + + new_tensor = SparseTensor(inserted_feat, inserted_coords, 1) + new_tensor.cmaps.setdefault(new_tensor.stride, new_tensor.coords) + z.additional_features['idx_query'][1] = idx_query + z.additional_features['counts'][1] = counts + z.C = new_float_coord + + return new_tensor + + +# x: SparseTensor, z: PointTensor +# return: SparseTensor +def point_to_voxel(x, z): + if z.additional_features is None or z.additional_features.get('idx_query') is None \ + or z.additional_features['idx_query'].get(x.s) is None: + # pc_hash = hash_gpu(torch.floor(z.C).int()) + pc_hash = F.sphash( + torch.cat([ + torch.floor(z.C[:, :3] / x.s[0]).int() * x.s[0], + z.C[:, -1].int().view(-1, 1) + ], 1)) + sparse_hash = F.sphash(x.C) + idx_query = F.sphashquery(pc_hash, sparse_hash) + counts = F.spcount(idx_query.int(), x.C.shape[0]) + z.additional_features['idx_query'][x.s] = idx_query + z.additional_features['counts'][x.s] = counts + else: + idx_query = z.additional_features['idx_query'][x.s] + counts = z.additional_features['counts'][x.s] + + inserted_feat = F.spvoxelize(z.F, idx_query, counts) + new_tensor = SparseTensor(inserted_feat, x.C, x.s) + new_tensor.cmaps = x.cmaps + new_tensor.kmaps = x.kmaps + + return new_tensor + + +# x: SparseTensor, z: PointTensor +# return: PointTensor +def voxel_to_point(x, z, nearest=False): + if z.idx_query is None or z.weights is None or z.idx_query.get( + x.s) is None or z.weights.get(x.s) is None: + off = get_kernel_offsets(2, x.s, 1, device=z.F.device) + # old_hash = kernel_hash_gpu(torch.floor(z.C).int(), off) + old_hash = F.sphash( + torch.cat([ + torch.floor(z.C[:, :3] / x.s[0]).int() * x.s[0], + z.C[:, -1].int().view(-1, 1) + ], 1), off) + mm = x.C.to(z.F.device) + pc_hash = F.sphash(x.C.to(z.F.device)) + idx_query = F.sphashquery(old_hash, pc_hash) + weights = F.calc_ti_weights(z.C, idx_query, + scale=x.s[0]).transpose(0, 1).contiguous() + idx_query = idx_query.transpose(0, 1).contiguous() + if nearest: + weights[:, 1:] = 0. + idx_query[:, 1:] = -1 + new_feat = F.spdevoxelize(x.F, idx_query, weights) + new_tensor = PointTensor(new_feat, + z.C, + idx_query=z.idx_query, + weights=z.weights) + new_tensor.additional_features = z.additional_features + new_tensor.idx_query[x.s] = idx_query + new_tensor.weights[x.s] = weights + z.idx_query[x.s] = idx_query + z.weights[x.s] = weights + + else: + new_feat = F.spdevoxelize(x.F, z.idx_query.get(x.s), + z.weights.get(x.s)) # - sparse trilinear interpoltation operation + new_tensor = PointTensor(new_feat, + z.C, + idx_query=z.idx_query, + weights=z.weights) + new_tensor.additional_features = z.additional_features + + return new_tensor + + +def sparse_to_dense_torch_batch(locs, values, dim, default_val): + dense = torch.full([dim[0], dim[1], dim[2], dim[3]], float(default_val), device=locs.device) + dense[locs[:, 0], locs[:, 1], locs[:, 2], locs[:, 3]] = values + return dense + + +def sparse_to_dense_torch(locs, values, dim, default_val, device): + dense = torch.full([dim[0], dim[1], dim[2]], float(default_val), device=device) + if locs.shape[0] > 0: + dense[locs[:, 0], locs[:, 1], locs[:, 2]] = values + return dense + + +def sparse_to_dense_channel(locs, values, dim, c, default_val, device): + locs = locs.to(torch.int64) + dense = torch.full([dim[0], dim[1], dim[2], c], float(default_val), device=device) + if locs.shape[0] > 0: + dense[locs[:, 0], locs[:, 1], locs[:, 2]] = values + return dense + + +def sparse_to_dense_np(locs, values, dim, default_val): + dense = np.zeros([dim[0], dim[1], dim[2]], dtype=values.dtype) + dense.fill(default_val) + dense[locs[:, 0], locs[:, 1], locs[:, 2]] = values + return dense diff --git a/SparseNeuS_demo_v1/utils/__init__.py b/SparseNeuS_demo_v1/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/SparseNeuS_demo_v1/utils/misc_utils.py b/SparseNeuS_demo_v1/utils/misc_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..85e80cf4e2bcf8bed0086e2b6c8a3bf3da40a056 --- /dev/null +++ b/SparseNeuS_demo_v1/utils/misc_utils.py @@ -0,0 +1,219 @@ +import os, torch, cv2, re +import numpy as np + +from PIL import Image +import torch.nn.functional as F +import torchvision.transforms as T + +# Misc +img2mse = lambda x, y: torch.mean((x - y) ** 2) +mse2psnr = lambda x: -10. * torch.log(x) / torch.log(torch.Tensor([10.])) +to8b = lambda x: (255 * np.clip(x, 0, 1)).astype(np.uint8) +mse2psnr2 = lambda x: -10. * np.log(x) / np.log(10.) + + +def get_psnr(imgs_pred, imgs_gt): + psnrs = [] + for (img, tar) in zip(imgs_pred, imgs_gt): + psnrs.append(mse2psnr2(np.mean((img - tar.cpu().numpy()) ** 2))) + return np.array(psnrs) + + +def init_log(log, keys): + for key in keys: + log[key] = torch.tensor([0.0], dtype=float) + return log + + +def visualize_depth_numpy(depth, minmax=None, cmap=cv2.COLORMAP_JET): + """ + depth: (H, W) + """ + + x = np.nan_to_num(depth) # change nan to 0 + if minmax is None: + mi = np.min(x[x > 0]) # get minimum positive depth (ignore background) + ma = np.max(x) + else: + mi, ma = minmax + + x = (x - mi) / (ma - mi + 1e-8) # normalize to 0~1 + x = (255 * x).astype(np.uint8) + x_ = cv2.applyColorMap(x, cmap) + return x_, [mi, ma] + + +def visualize_depth(depth, minmax=None, cmap=cv2.COLORMAP_JET): + """ + depth: (H, W) + """ + if type(depth) is not np.ndarray: + depth = depth.cpu().numpy() + + x = np.nan_to_num(depth) # change nan to 0 + if minmax is None: + mi = np.min(x[x > 0]) # get minimum positive depth (ignore background) + ma = np.max(x) + else: + mi, ma = minmax + + x = (x - mi) / (ma - mi + 1e-8) # normalize to 0~1 + x = (255 * x).astype(np.uint8) + x_ = Image.fromarray(cv2.applyColorMap(x, cmap)) + x_ = T.ToTensor()(x_) # (3, H, W) + return x_, [mi, ma] + + +def abs_error_numpy(depth_pred, depth_gt, mask): + depth_pred, depth_gt = depth_pred[mask], depth_gt[mask] + return np.abs(depth_pred - depth_gt) + + +def abs_error(depth_pred, depth_gt, mask): + depth_pred, depth_gt = depth_pred[mask], depth_gt[mask] + err = depth_pred - depth_gt + return np.abs(err) if type(depth_pred) is np.ndarray else err.abs() + + +def acc_threshold(depth_pred, depth_gt, mask, threshold): + """ + computes the percentage of pixels whose depth error is less than @threshold + """ + errors = abs_error(depth_pred, depth_gt, mask) + acc_mask = errors < threshold + return acc_mask.astype('float') if type(depth_pred) is np.ndarray else acc_mask.float() + + +def to_tensor_cuda(data, device, filter): + for item in data.keys(): + + if item in filter: + continue + + if type(data[item]) is np.ndarray: + data[item] = torch.tensor(data[item], dtype=torch.float32, device=device) + else: + data[item] = data[item].float().to(device) + return data + + +def to_cuda(data, device, filter): + for item in data.keys(): + if item in filter: + continue + + data[item] = data[item].float().to(device) + return data + + +def tensor_unsqueeze(data, filter): + for item in data.keys(): + if item in filter: + continue + + data[item] = data[item][None] + return data + + +def filter_keys(dict): + dict.pop('N_samples') + if 'ndc' in dict.keys(): + dict.pop('ndc') + if 'lindisp' in dict.keys(): + dict.pop('lindisp') + return dict + + +def sub_selete_data(data_batch, device, idx, filtKey=[], + filtIndex=['view_ids_all', 'c2ws_all', 'scan', 'bbox', 'w2ref', 'ref2w', 'light_id', 'ckpt', + 'idx']): + data_sub_selete = {} + for item in data_batch.keys(): + data_sub_selete[item] = data_batch[item][:, idx].float() if ( + item not in filtIndex and torch.is_tensor(item) and item.dim() > 2) else data_batch[item].float() + if not data_sub_selete[item].is_cuda: + data_sub_selete[item] = data_sub_selete[item].to(device) + return data_sub_selete + + +def detach_data(dictionary): + dictionary_new = {} + for key in dictionary.keys(): + dictionary_new[key] = dictionary[key].detach().clone() + return dictionary_new + + +def read_pfm(filename): + file = open(filename, 'rb') + color = None + width = None + height = None + scale = None + endian = None + + header = file.readline().decode('utf-8').rstrip() + if header == 'PF': + color = True + elif header == 'Pf': + color = False + else: + raise Exception('Not a PFM file.') + + dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode('utf-8')) + if dim_match: + width, height = map(int, dim_match.groups()) + else: + raise Exception('Malformed PFM header.') + + scale = float(file.readline().rstrip()) + if scale < 0: # little-endian + endian = '<' + scale = -scale + else: + endian = '>' # big-endian + + data = np.fromfile(file, endian + 'f') + shape = (height, width, 3) if color else (height, width) + + data = np.reshape(data, shape) + data = np.flipud(data) + file.close() + return data, scale + + +from torch.optim.lr_scheduler import CosineAnnealingLR, MultiStepLR + + +# from warmup_scheduler import GradualWarmupScheduler +def get_scheduler(hparams, optimizer): + eps = 1e-8 + if hparams.lr_scheduler == 'steplr': + scheduler = MultiStepLR(optimizer, milestones=hparams.decay_step, + gamma=hparams.decay_gamma) + elif hparams.lr_scheduler == 'cosine': + scheduler = CosineAnnealingLR(optimizer, T_max=hparams.num_epochs, eta_min=eps) + + else: + raise ValueError('scheduler not recognized!') + + # if hparams.warmup_epochs > 0 and hparams.optimizer not in ['radam', 'ranger']: + # scheduler = GradualWarmupScheduler(optimizer, multiplier=hparams.warmup_multiplier, + # total_epoch=hparams.warmup_epochs, after_scheduler=scheduler) + return scheduler + + +#### pairing #### +def get_nearest_pose_ids(tar_pose, ref_poses, num_select): + ''' + Args: + tar_pose: target pose [N, 4, 4] + ref_poses: reference poses [M, 4, 4] + num_select: the number of nearest views to select + Returns: the selected indices + ''' + + dists = np.linalg.norm(tar_pose[:, None, :3, 3] - ref_poses[None, :, :3, 3], axis=-1) + + sorted_ids = np.argsort(dists, axis=-1) + selected_ids = sorted_ids[:, :num_select] + return selected_ids diff --git a/SparseNeuS_demo_v1/utils/training_utils.py b/SparseNeuS_demo_v1/utils/training_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5d128ba2beda39b708850bd4c17c4603a8a17848 --- /dev/null +++ b/SparseNeuS_demo_v1/utils/training_utils.py @@ -0,0 +1,129 @@ +import numpy as np +import torchvision.utils as vutils +import torch, random +import torch.nn.functional as F + + +# print arguments +def print_args(args): + print("################################ args ################################") + for k, v in args.__dict__.items(): + print("{0: <10}\t{1: <30}\t{2: <20}".format(k, str(v), str(type(v)))) + print("########################################################################") + + +# torch.no_grad warpper for functions +def make_nograd_func(func): + def wrapper(*f_args, **f_kwargs): + with torch.no_grad(): + ret = func(*f_args, **f_kwargs) + return ret + + return wrapper + + +# convert a function into recursive style to handle nested dict/list/tuple variables +def make_recursive_func(func): + def wrapper(vars, device=None): + if isinstance(vars, list): + return [wrapper(x, device) for x in vars] + elif isinstance(vars, tuple): + return tuple([wrapper(x, device) for x in vars]) + elif isinstance(vars, dict): + return {k: wrapper(v, device) for k, v in vars.items()} + else: + return func(vars, device) + + return wrapper + + +@make_recursive_func +def tensor2float(vars): + if isinstance(vars, float): + return vars + elif isinstance(vars, torch.Tensor): + return vars.data.item() + else: + raise NotImplementedError("invalid input type {} for tensor2float".format(type(vars))) + + +@make_recursive_func +def tensor2numpy(vars): + if isinstance(vars, np.ndarray): + return vars + elif isinstance(vars, torch.Tensor): + return vars.detach().cpu().numpy().copy() + else: + raise NotImplementedError("invalid input type {} for tensor2numpy".format(type(vars))) + + +@make_recursive_func +def numpy2tensor(vars, device='cpu'): + if not isinstance(vars, torch.Tensor) and vars is not None : + return torch.tensor(vars, device=device) + elif isinstance(vars, torch.Tensor): + return vars + elif vars is None: + return vars + else: + raise NotImplementedError("invalid input type {} for float2tensor".format(type(vars))) + + +@make_recursive_func +def tocuda(vars, device='cuda'): + if isinstance(vars, torch.Tensor): + return vars.to(device) + elif isinstance(vars, str): + return vars + else: + raise NotImplementedError("invalid input type {} for tocuda".format(type(vars))) + + +import torch.distributed as dist + + +def synchronize(): + """ + Helper function to synchronize (barrier) among all processes when + using distributed training + """ + if not dist.is_available(): + return + if not dist.is_initialized(): + return + world_size = dist.get_world_size() + if world_size == 1: + return + dist.barrier() + + +def get_world_size(): + if not dist.is_available(): + return 1 + if not dist.is_initialized(): + return 1 + return dist.get_world_size() + + +def reduce_scalar_outputs(scalar_outputs): + world_size = get_world_size() + if world_size < 2: + return scalar_outputs + with torch.no_grad(): + names = [] + scalars = [] + for k in sorted(scalar_outputs.keys()): + names.append(k) + if isinstance(scalar_outputs[k], torch.Tensor): + scalars.append(scalar_outputs[k]) + else: + scalars.append(torch.tensor(scalar_outputs[k], device='cuda')) + scalars = torch.stack(scalars, dim=0) + dist.reduce(scalars, dst=0) + if dist.get_rank() == 0: + # only main process gets accumulated, so only divide by + # world_size in this case + scalars /= world_size + reduced_scalars = {k: v for k, v in zip(names, scalars)} + + return reduced_scalars diff --git a/SparseNeuS_demo_v1/val.ipynb b/SparseNeuS_demo_v1/val.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..a39350692b1cc7e35754de19b4dae0277a959f2b --- /dev/null +++ b/SparseNeuS_demo_v1/val.ipynb @@ -0,0 +1,951 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CUDA_VISIBLE_DEVICES=3 python exp_runner_generic_blender_val.py --specific_dataset_name gradio_tmp --mode export_mesh --conf confs/one2345_lod0_val_demo.conf --is_continue\n", + "\u001b[31mdetected 1 GPUs\u001b[0m\n", + "\u001b[33mbase_exp_dir: exp/lod0\u001b[0m\n", + "\u001b[34mStore in: ../gradio_tmp\u001b[0m\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/chao/anaconda3/envs/gradio/lib/python3.10/site-packages/torch/functional.py:504: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3483.)\n", + " return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined]\n", + "[exp_runner_generic_blender_val.py:159 - __init__() ] Find checkpoint: ckpt_285000.pth\n", + "[exp_runner_generic_blender_val.py:500 - load_checkpoint() ] End\n", + "ic| self.iter_step: 285000, idx: -1\n", + "[exp_runner_generic_blender_val.py:579 - export_mesh() ] Validate begin\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "time for getting data 0.0004603862762451172\n", + "export mesh time: 5.274312734603882\n" + ] + }, + { + "data": { + "text/plain": [ + "0" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import os \n", + "\n", + "dataset = 'gradio_tmp' # !!! the subfolder name in valpath for which you want to eval\n", + "# os.system('pwd')\n", + "\n", + "bash_script = f'CUDA_VISIBLE_DEVICES=3 python exp_runner_generic_blender_val.py --specific_dataset_name {dataset} --mode export_mesh --conf confs/one2345_lod0_val_demo.conf --is_continue'\n", + "print(bash_script)\n", + "os.system(bash_script)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CUDA_VISIBLE_DEVICES=3 python exp_runner_generic_blender_val.py --specific_dataset_name eval_test/ebicycle2 --mode export_mesh --conf confs/one2345_lod0_val_demo.conf --is_continue\n", + "\u001b[31mdetected 1 GPUs\u001b[0m\n", + "\u001b[33mbase_exp_dir: exp/lod0\u001b[0m\n", + "save mesh to: ../eval_test/ebicycle2\n", + "\u001b[34mStore in: ../eval_test/ebicycle2\u001b[0m\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/chao/anaconda3/envs/gradio/lib/python3.10/site-packages/torch/functional.py:504: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3483.)\n", + " return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined]\n", + "[exp_runner_generic_blender_val.py:160 - __init__() ] Find checkpoint: ckpt_180000.pth\n", + "[exp_runner_generic_blender_val.py:501 - load_checkpoint() ] End\n", + "ic| self.iter_step: 180000, idx: -1\n", + "[exp_runner_generic_blender_val.py:580 - export_mesh() ] Validate begin\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "time for getting data 0.0004477500915527344\n", + "export mesh time: 5.30308723449707\n" + ] + }, + { + "data": { + "text/plain": [ + "0" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import os \n", + "\n", + "dataset = 'eval_test/ebicycle2' # !!! the subfolder name in valpath for which you want to eval\n", + "# os.system('pwd')\n", + "\n", + "bash_script = f'CUDA_VISIBLE_DEVICES=3 python exp_runner_generic_blender_val.py --specific_dataset_name {dataset} --mode export_mesh --conf confs/one2345_lod0_val_demo.conf --is_continue'\n", + "print(bash_script)\n", + "os.system(bash_script)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 0%| | 0/20 [00:00 + + + \ No newline at end of file diff --git a/one2345_elev_est/.idea/inspectionProfiles/profiles_settings.xml b/one2345_elev_est/.idea/inspectionProfiles/profiles_settings.xml new file mode 100644 index 0000000000000000000000000000000000000000..105ce2da2d6447d11dfe32bfb846c3d5b199fc99 --- /dev/null +++ b/one2345_elev_est/.idea/inspectionProfiles/profiles_settings.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/one2345_elev_est/.idea/misc.xml b/one2345_elev_est/.idea/misc.xml new file mode 100644 index 0000000000000000000000000000000000000000..d56657add3eb3c246989284ec6e6a8475603cf1d --- /dev/null +++ b/one2345_elev_est/.idea/misc.xml @@ -0,0 +1,4 @@ + + + + \ No newline at end of file diff --git a/one2345_elev_est/.idea/modules.xml b/one2345_elev_est/.idea/modules.xml new file mode 100644 index 0000000000000000000000000000000000000000..c835d1a7ad8ef9f9ef336501b88a5c38eac2dd86 --- /dev/null +++ b/one2345_elev_est/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/one2345_elev_est/.idea/one2345_elev_est.iml b/one2345_elev_est/.idea/one2345_elev_est.iml new file mode 100644 index 0000000000000000000000000000000000000000..870ae2cd58bfdf63caf9d20b67f1b848bad7aabe --- /dev/null +++ b/one2345_elev_est/.idea/one2345_elev_est.iml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/one2345_elev_est/install.sh b/one2345_elev_est/install.sh new file mode 100644 index 0000000000000000000000000000000000000000..27ad025b471b9368a059759d92501730d9f14cd2 --- /dev/null +++ b/one2345_elev_est/install.sh @@ -0,0 +1 @@ +python setup.py build develop diff --git a/one2345_elev_est/oee/models/loftr/__init__.py b/one2345_elev_est/oee/models/loftr/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0d69b9c131cf41e95c5c6ee7d389b375267b22fa --- /dev/null +++ b/one2345_elev_est/oee/models/loftr/__init__.py @@ -0,0 +1,2 @@ +from .loftr import LoFTR +from .utils.cvpr_ds_config import default_cfg diff --git a/one2345_elev_est/oee/models/loftr/backbone/__init__.py b/one2345_elev_est/oee/models/loftr/backbone/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b6e731b3f53ab367c89ef0ea8e1cbffb0d990775 --- /dev/null +++ b/one2345_elev_est/oee/models/loftr/backbone/__init__.py @@ -0,0 +1,11 @@ +from .resnet_fpn import ResNetFPN_8_2, ResNetFPN_16_4 + + +def build_backbone(config): + if config['backbone_type'] == 'ResNetFPN': + if config['resolution'] == (8, 2): + return ResNetFPN_8_2(config['resnetfpn']) + elif config['resolution'] == (16, 4): + return ResNetFPN_16_4(config['resnetfpn']) + else: + raise ValueError(f"LOFTR.BACKBONE_TYPE {config['backbone_type']} not supported.") diff --git a/one2345_elev_est/oee/models/loftr/backbone/resnet_fpn.py b/one2345_elev_est/oee/models/loftr/backbone/resnet_fpn.py new file mode 100644 index 0000000000000000000000000000000000000000..985e5b3f273a51e51447a8025ca3aadbe46752eb --- /dev/null +++ b/one2345_elev_est/oee/models/loftr/backbone/resnet_fpn.py @@ -0,0 +1,199 @@ +import torch.nn as nn +import torch.nn.functional as F + + +def conv1x1(in_planes, out_planes, stride=1): + """1x1 convolution without padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=False) + + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) + + +class BasicBlock(nn.Module): + def __init__(self, in_planes, planes, stride=1): + super().__init__() + self.conv1 = conv3x3(in_planes, planes, stride) + self.conv2 = conv3x3(planes, planes) + self.bn1 = nn.BatchNorm2d(planes) + self.bn2 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + + if stride == 1: + self.downsample = None + else: + self.downsample = nn.Sequential( + conv1x1(in_planes, planes, stride=stride), + nn.BatchNorm2d(planes) + ) + + def forward(self, x): + y = x + y = self.relu(self.bn1(self.conv1(y))) + y = self.bn2(self.conv2(y)) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x+y) + + +class ResNetFPN_8_2(nn.Module): + """ + ResNet+FPN, output resolution are 1/8 and 1/2. + Each block has 2 layers. + """ + + def __init__(self, config): + super().__init__() + # Config + block = BasicBlock + initial_dim = config['initial_dim'] + block_dims = config['block_dims'] + + # Class Variable + self.in_planes = initial_dim + + # Networks + self.conv1 = nn.Conv2d(1, initial_dim, kernel_size=7, stride=2, padding=3, bias=False) + self.bn1 = nn.BatchNorm2d(initial_dim) + self.relu = nn.ReLU(inplace=True) + + self.layer1 = self._make_layer(block, block_dims[0], stride=1) # 1/2 + self.layer2 = self._make_layer(block, block_dims[1], stride=2) # 1/4 + self.layer3 = self._make_layer(block, block_dims[2], stride=2) # 1/8 + + # 3. FPN upsample + self.layer3_outconv = conv1x1(block_dims[2], block_dims[2]) + self.layer2_outconv = conv1x1(block_dims[1], block_dims[2]) + self.layer2_outconv2 = nn.Sequential( + conv3x3(block_dims[2], block_dims[2]), + nn.BatchNorm2d(block_dims[2]), + nn.LeakyReLU(), + conv3x3(block_dims[2], block_dims[1]), + ) + self.layer1_outconv = conv1x1(block_dims[0], block_dims[1]) + self.layer1_outconv2 = nn.Sequential( + conv3x3(block_dims[1], block_dims[1]), + nn.BatchNorm2d(block_dims[1]), + nn.LeakyReLU(), + conv3x3(block_dims[1], block_dims[0]), + ) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def _make_layer(self, block, dim, stride=1): + layer1 = block(self.in_planes, dim, stride=stride) + layer2 = block(dim, dim, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + def forward(self, x): + # ResNet Backbone + x0 = self.relu(self.bn1(self.conv1(x))) + x1 = self.layer1(x0) # 1/2 + x2 = self.layer2(x1) # 1/4 + x3 = self.layer3(x2) # 1/8 + + # FPN + x3_out = self.layer3_outconv(x3) + + x3_out_2x = F.interpolate(x3_out, scale_factor=2., mode='bilinear', align_corners=True) + x2_out = self.layer2_outconv(x2) + x2_out = self.layer2_outconv2(x2_out+x3_out_2x) + + x2_out_2x = F.interpolate(x2_out, scale_factor=2., mode='bilinear', align_corners=True) + x1_out = self.layer1_outconv(x1) + x1_out = self.layer1_outconv2(x1_out+x2_out_2x) + + return [x3_out, x1_out] + + +class ResNetFPN_16_4(nn.Module): + """ + ResNet+FPN, output resolution are 1/16 and 1/4. + Each block has 2 layers. + """ + + def __init__(self, config): + super().__init__() + # Config + block = BasicBlock + initial_dim = config['initial_dim'] + block_dims = config['block_dims'] + + # Class Variable + self.in_planes = initial_dim + + # Networks + self.conv1 = nn.Conv2d(1, initial_dim, kernel_size=7, stride=2, padding=3, bias=False) + self.bn1 = nn.BatchNorm2d(initial_dim) + self.relu = nn.ReLU(inplace=True) + + self.layer1 = self._make_layer(block, block_dims[0], stride=1) # 1/2 + self.layer2 = self._make_layer(block, block_dims[1], stride=2) # 1/4 + self.layer3 = self._make_layer(block, block_dims[2], stride=2) # 1/8 + self.layer4 = self._make_layer(block, block_dims[3], stride=2) # 1/16 + + # 3. FPN upsample + self.layer4_outconv = conv1x1(block_dims[3], block_dims[3]) + self.layer3_outconv = conv1x1(block_dims[2], block_dims[3]) + self.layer3_outconv2 = nn.Sequential( + conv3x3(block_dims[3], block_dims[3]), + nn.BatchNorm2d(block_dims[3]), + nn.LeakyReLU(), + conv3x3(block_dims[3], block_dims[2]), + ) + + self.layer2_outconv = conv1x1(block_dims[1], block_dims[2]) + self.layer2_outconv2 = nn.Sequential( + conv3x3(block_dims[2], block_dims[2]), + nn.BatchNorm2d(block_dims[2]), + nn.LeakyReLU(), + conv3x3(block_dims[2], block_dims[1]), + ) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def _make_layer(self, block, dim, stride=1): + layer1 = block(self.in_planes, dim, stride=stride) + layer2 = block(dim, dim, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + def forward(self, x): + # ResNet Backbone + x0 = self.relu(self.bn1(self.conv1(x))) + x1 = self.layer1(x0) # 1/2 + x2 = self.layer2(x1) # 1/4 + x3 = self.layer3(x2) # 1/8 + x4 = self.layer4(x3) # 1/16 + + # FPN + x4_out = self.layer4_outconv(x4) + + x4_out_2x = F.interpolate(x4_out, scale_factor=2., mode='bilinear', align_corners=True) + x3_out = self.layer3_outconv(x3) + x3_out = self.layer3_outconv2(x3_out+x4_out_2x) + + x3_out_2x = F.interpolate(x3_out, scale_factor=2., mode='bilinear', align_corners=True) + x2_out = self.layer2_outconv(x2) + x2_out = self.layer2_outconv2(x2_out+x3_out_2x) + + return [x4_out, x2_out] diff --git a/one2345_elev_est/oee/models/loftr/loftr.py b/one2345_elev_est/oee/models/loftr/loftr.py new file mode 100644 index 0000000000000000000000000000000000000000..79c491ee47a4d67cb8b3fe493397349e0867accd --- /dev/null +++ b/one2345_elev_est/oee/models/loftr/loftr.py @@ -0,0 +1,81 @@ +import torch +import torch.nn as nn +from einops.einops import rearrange + +from .backbone import build_backbone +from .utils.position_encoding import PositionEncodingSine +from .loftr_module import LocalFeatureTransformer, FinePreprocess +from .utils.coarse_matching import CoarseMatching +from .utils.fine_matching import FineMatching + + +class LoFTR(nn.Module): + def __init__(self, config): + super().__init__() + # Misc + self.config = config + + # Modules + self.backbone = build_backbone(config) + self.pos_encoding = PositionEncodingSine( + config['coarse']['d_model'], + temp_bug_fix=config['coarse']['temp_bug_fix']) + self.loftr_coarse = LocalFeatureTransformer(config['coarse']) + self.coarse_matching = CoarseMatching(config['match_coarse']) + self.fine_preprocess = FinePreprocess(config) + self.loftr_fine = LocalFeatureTransformer(config["fine"]) + self.fine_matching = FineMatching() + + def forward(self, data): + """ + Update: + data (dict): { + 'image0': (torch.Tensor): (N, 1, H, W) + 'image1': (torch.Tensor): (N, 1, H, W) + 'mask0'(optional) : (torch.Tensor): (N, H, W) '0' indicates a padded position + 'mask1'(optional) : (torch.Tensor): (N, H, W) + } + """ + # 1. Local Feature CNN + data.update({ + 'bs': data['image0'].size(0), + 'hw0_i': data['image0'].shape[2:], 'hw1_i': data['image1'].shape[2:] + }) + + if data['hw0_i'] == data['hw1_i']: # faster & better BN convergence + feats_c, feats_f = self.backbone(torch.cat([data['image0'], data['image1']], dim=0)) + (feat_c0, feat_c1), (feat_f0, feat_f1) = feats_c.split(data['bs']), feats_f.split(data['bs']) + else: # handle different input shapes + (feat_c0, feat_f0), (feat_c1, feat_f1) = self.backbone(data['image0']), self.backbone(data['image1']) + + data.update({ + 'hw0_c': feat_c0.shape[2:], 'hw1_c': feat_c1.shape[2:], + 'hw0_f': feat_f0.shape[2:], 'hw1_f': feat_f1.shape[2:] + }) + + # 2. coarse-level loftr module + # add featmap with positional encoding, then flatten it to sequence [N, HW, C] + feat_c0 = rearrange(self.pos_encoding(feat_c0), 'n c h w -> n (h w) c') + feat_c1 = rearrange(self.pos_encoding(feat_c1), 'n c h w -> n (h w) c') + + mask_c0 = mask_c1 = None # mask is useful in training + if 'mask0' in data: + mask_c0, mask_c1 = data['mask0'].flatten(-2), data['mask1'].flatten(-2) + feat_c0, feat_c1 = self.loftr_coarse(feat_c0, feat_c1, mask_c0, mask_c1) + + # 3. match coarse-level + self.coarse_matching(feat_c0, feat_c1, data, mask_c0=mask_c0, mask_c1=mask_c1) + + # 4. fine-level refinement + feat_f0_unfold, feat_f1_unfold = self.fine_preprocess(feat_f0, feat_f1, feat_c0, feat_c1, data) + if feat_f0_unfold.size(0) != 0: # at least one coarse level predicted + feat_f0_unfold, feat_f1_unfold = self.loftr_fine(feat_f0_unfold, feat_f1_unfold) + + # 5. match fine-level + self.fine_matching(feat_f0_unfold, feat_f1_unfold, data) + + def load_state_dict(self, state_dict, *args, **kwargs): + for k in list(state_dict.keys()): + if k.startswith('matcher.'): + state_dict[k.replace('matcher.', '', 1)] = state_dict.pop(k) + return super().load_state_dict(state_dict, *args, **kwargs) diff --git a/one2345_elev_est/oee/models/loftr/loftr_module/__init__.py b/one2345_elev_est/oee/models/loftr/loftr_module/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ca51db4f50a0c4f3dcd795e74b83e633ab2e990a --- /dev/null +++ b/one2345_elev_est/oee/models/loftr/loftr_module/__init__.py @@ -0,0 +1,2 @@ +from .transformer import LocalFeatureTransformer +from .fine_preprocess import FinePreprocess diff --git a/one2345_elev_est/oee/models/loftr/loftr_module/fine_preprocess.py b/one2345_elev_est/oee/models/loftr/loftr_module/fine_preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..5bb8eefd362240a9901a335f0e6e07770ff04567 --- /dev/null +++ b/one2345_elev_est/oee/models/loftr/loftr_module/fine_preprocess.py @@ -0,0 +1,59 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops.einops import rearrange, repeat + + +class FinePreprocess(nn.Module): + def __init__(self, config): + super().__init__() + + self.config = config + self.cat_c_feat = config['fine_concat_coarse_feat'] + self.W = self.config['fine_window_size'] + + d_model_c = self.config['coarse']['d_model'] + d_model_f = self.config['fine']['d_model'] + self.d_model_f = d_model_f + if self.cat_c_feat: + self.down_proj = nn.Linear(d_model_c, d_model_f, bias=True) + self.merge_feat = nn.Linear(2*d_model_f, d_model_f, bias=True) + + self._reset_parameters() + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.kaiming_normal_(p, mode="fan_out", nonlinearity="relu") + + def forward(self, feat_f0, feat_f1, feat_c0, feat_c1, data): + W = self.W + stride = data['hw0_f'][0] // data['hw0_c'][0] + + data.update({'W': W}) + if data['b_ids'].shape[0] == 0: + feat0 = torch.empty(0, self.W**2, self.d_model_f, device=feat_f0.device) + feat1 = torch.empty(0, self.W**2, self.d_model_f, device=feat_f0.device) + return feat0, feat1 + + # 1. unfold(crop) all local windows + feat_f0_unfold = F.unfold(feat_f0, kernel_size=(W, W), stride=stride, padding=W//2) + feat_f0_unfold = rearrange(feat_f0_unfold, 'n (c ww) l -> n l ww c', ww=W**2) + feat_f1_unfold = F.unfold(feat_f1, kernel_size=(W, W), stride=stride, padding=W//2) + feat_f1_unfold = rearrange(feat_f1_unfold, 'n (c ww) l -> n l ww c', ww=W**2) + + # 2. select only the predicted matches + feat_f0_unfold = feat_f0_unfold[data['b_ids'], data['i_ids']] # [n, ww, cf] + feat_f1_unfold = feat_f1_unfold[data['b_ids'], data['j_ids']] + + # option: use coarse-level loftr feature as context: concat and linear + if self.cat_c_feat: + feat_c_win = self.down_proj(torch.cat([feat_c0[data['b_ids'], data['i_ids']], + feat_c1[data['b_ids'], data['j_ids']]], 0)) # [2n, c] + feat_cf_win = self.merge_feat(torch.cat([ + torch.cat([feat_f0_unfold, feat_f1_unfold], 0), # [2n, ww, cf] + repeat(feat_c_win, 'n c -> n ww c', ww=W**2), # [2n, ww, cf] + ], -1)) + feat_f0_unfold, feat_f1_unfold = torch.chunk(feat_cf_win, 2, dim=0) + + return feat_f0_unfold, feat_f1_unfold diff --git a/one2345_elev_est/oee/models/loftr/loftr_module/linear_attention.py b/one2345_elev_est/oee/models/loftr/loftr_module/linear_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..b73c5a6a6a722a44c0b68f70cb77c0988b8a5fb3 --- /dev/null +++ b/one2345_elev_est/oee/models/loftr/loftr_module/linear_attention.py @@ -0,0 +1,81 @@ +""" +Linear Transformer proposed in "Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention" +Modified from: https://github.com/idiap/fast-transformers/blob/master/fast_transformers/attention/linear_attention.py +""" + +import torch +from torch.nn import Module, Dropout + + +def elu_feature_map(x): + return torch.nn.functional.elu(x) + 1 + + +class LinearAttention(Module): + def __init__(self, eps=1e-6): + super().__init__() + self.feature_map = elu_feature_map + self.eps = eps + + def forward(self, queries, keys, values, q_mask=None, kv_mask=None): + """ Multi-Head linear attention proposed in "Transformers are RNNs" + Args: + queries: [N, L, H, D] + keys: [N, S, H, D] + values: [N, S, H, D] + q_mask: [N, L] + kv_mask: [N, S] + Returns: + queried_values: (N, L, H, D) + """ + Q = self.feature_map(queries) + K = self.feature_map(keys) + + # set padded position to zero + if q_mask is not None: + Q = Q * q_mask[:, :, None, None] + if kv_mask is not None: + K = K * kv_mask[:, :, None, None] + values = values * kv_mask[:, :, None, None] + + v_length = values.size(1) + values = values / v_length # prevent fp16 overflow + KV = torch.einsum("nshd,nshv->nhdv", K, values) # (S,D)' @ S,V + Z = 1 / (torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1)) + self.eps) + queried_values = torch.einsum("nlhd,nhdv,nlh->nlhv", Q, KV, Z) * v_length + + return queried_values.contiguous() + + +class FullAttention(Module): + def __init__(self, use_dropout=False, attention_dropout=0.1): + super().__init__() + self.use_dropout = use_dropout + self.dropout = Dropout(attention_dropout) + + def forward(self, queries, keys, values, q_mask=None, kv_mask=None): + """ Multi-head scaled dot-product attention, a.k.a full attention. + Args: + queries: [N, L, H, D] + keys: [N, S, H, D] + values: [N, S, H, D] + q_mask: [N, L] + kv_mask: [N, S] + Returns: + queried_values: (N, L, H, D) + """ + + # Compute the unnormalized attention and apply the masks + QK = torch.einsum("nlhd,nshd->nlsh", queries, keys) + if kv_mask is not None: + QK.masked_fill_(~(q_mask[:, :, None, None] * kv_mask[:, None, :, None]), float('-inf')) + + # Compute the attention and the weighted average + softmax_temp = 1. / queries.size(3)**.5 # sqrt(D) + A = torch.softmax(softmax_temp * QK, dim=2) + if self.use_dropout: + A = self.dropout(A) + + queried_values = torch.einsum("nlsh,nshd->nlhd", A, values) + + return queried_values.contiguous() diff --git a/one2345_elev_est/oee/models/loftr/loftr_module/transformer.py b/one2345_elev_est/oee/models/loftr/loftr_module/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..d79390ca08953bbef44e98149e662a681a16e42e --- /dev/null +++ b/one2345_elev_est/oee/models/loftr/loftr_module/transformer.py @@ -0,0 +1,101 @@ +import copy +import torch +import torch.nn as nn +from .linear_attention import LinearAttention, FullAttention + + +class LoFTREncoderLayer(nn.Module): + def __init__(self, + d_model, + nhead, + attention='linear'): + super(LoFTREncoderLayer, self).__init__() + + self.dim = d_model // nhead + self.nhead = nhead + + # multi-head attention + self.q_proj = nn.Linear(d_model, d_model, bias=False) + self.k_proj = nn.Linear(d_model, d_model, bias=False) + self.v_proj = nn.Linear(d_model, d_model, bias=False) + self.attention = LinearAttention() if attention == 'linear' else FullAttention() + self.merge = nn.Linear(d_model, d_model, bias=False) + + # feed-forward network + self.mlp = nn.Sequential( + nn.Linear(d_model*2, d_model*2, bias=False), + nn.ReLU(True), + nn.Linear(d_model*2, d_model, bias=False), + ) + + # norm and dropout + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + + def forward(self, x, source, x_mask=None, source_mask=None): + """ + Args: + x (torch.Tensor): [N, L, C] + source (torch.Tensor): [N, S, C] + x_mask (torch.Tensor): [N, L] (optional) + source_mask (torch.Tensor): [N, S] (optional) + """ + bs = x.size(0) + query, key, value = x, source, source + + # multi-head attention + query = self.q_proj(query).view(bs, -1, self.nhead, self.dim) # [N, L, (H, D)] + key = self.k_proj(key).view(bs, -1, self.nhead, self.dim) # [N, S, (H, D)] + value = self.v_proj(value).view(bs, -1, self.nhead, self.dim) + message = self.attention(query, key, value, q_mask=x_mask, kv_mask=source_mask) # [N, L, (H, D)] + message = self.merge(message.view(bs, -1, self.nhead*self.dim)) # [N, L, C] + message = self.norm1(message) + + # feed-forward network + message = self.mlp(torch.cat([x, message], dim=2)) + message = self.norm2(message) + + return x + message + + +class LocalFeatureTransformer(nn.Module): + """A Local Feature Transformer (LoFTR) module.""" + + def __init__(self, config): + super(LocalFeatureTransformer, self).__init__() + + self.config = config + self.d_model = config['d_model'] + self.nhead = config['nhead'] + self.layer_names = config['layer_names'] + encoder_layer = LoFTREncoderLayer(config['d_model'], config['nhead'], config['attention']) + self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(len(self.layer_names))]) + self._reset_parameters() + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward(self, feat0, feat1, mask0=None, mask1=None): + """ + Args: + feat0 (torch.Tensor): [N, L, C] + feat1 (torch.Tensor): [N, S, C] + mask0 (torch.Tensor): [N, L] (optional) + mask1 (torch.Tensor): [N, S] (optional) + """ + + assert self.d_model == feat0.size(2), "the feature number of src and transformer must be equal" + + for layer, name in zip(self.layers, self.layer_names): + if name == 'self': + feat0 = layer(feat0, feat0, mask0, mask0) + feat1 = layer(feat1, feat1, mask1, mask1) + elif name == 'cross': + feat0 = layer(feat0, feat1, mask0, mask1) + feat1 = layer(feat1, feat0, mask1, mask0) + else: + raise KeyError + + return feat0, feat1 diff --git a/one2345_elev_est/oee/models/loftr/utils/coarse_matching.py b/one2345_elev_est/oee/models/loftr/utils/coarse_matching.py new file mode 100644 index 0000000000000000000000000000000000000000..a97263339462dec3af9705d33d6ee634e2f46914 --- /dev/null +++ b/one2345_elev_est/oee/models/loftr/utils/coarse_matching.py @@ -0,0 +1,261 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops.einops import rearrange + +INF = 1e9 + +def mask_border(m, b: int, v): + """ Mask borders with value + Args: + m (torch.Tensor): [N, H0, W0, H1, W1] + b (int) + v (m.dtype) + """ + if b <= 0: + return + + m[:, :b] = v + m[:, :, :b] = v + m[:, :, :, :b] = v + m[:, :, :, :, :b] = v + m[:, -b:] = v + m[:, :, -b:] = v + m[:, :, :, -b:] = v + m[:, :, :, :, -b:] = v + + +def mask_border_with_padding(m, bd, v, p_m0, p_m1): + if bd <= 0: + return + + m[:, :bd] = v + m[:, :, :bd] = v + m[:, :, :, :bd] = v + m[:, :, :, :, :bd] = v + + h0s, w0s = p_m0.sum(1).max(-1)[0].int(), p_m0.sum(-1).max(-1)[0].int() + h1s, w1s = p_m1.sum(1).max(-1)[0].int(), p_m1.sum(-1).max(-1)[0].int() + for b_idx, (h0, w0, h1, w1) in enumerate(zip(h0s, w0s, h1s, w1s)): + m[b_idx, h0 - bd:] = v + m[b_idx, :, w0 - bd:] = v + m[b_idx, :, :, h1 - bd:] = v + m[b_idx, :, :, :, w1 - bd:] = v + + +def compute_max_candidates(p_m0, p_m1): + """Compute the max candidates of all pairs within a batch + + Args: + p_m0, p_m1 (torch.Tensor): padded masks + """ + h0s, w0s = p_m0.sum(1).max(-1)[0], p_m0.sum(-1).max(-1)[0] + h1s, w1s = p_m1.sum(1).max(-1)[0], p_m1.sum(-1).max(-1)[0] + max_cand = torch.sum( + torch.min(torch.stack([h0s * w0s, h1s * w1s], -1), -1)[0]) + return max_cand + + +class CoarseMatching(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + # general config + self.thr = config['thr'] + self.border_rm = config['border_rm'] + # -- # for trainig fine-level LoFTR + self.train_coarse_percent = config['train_coarse_percent'] + self.train_pad_num_gt_min = config['train_pad_num_gt_min'] + + # we provide 2 options for differentiable matching + self.match_type = config['match_type'] + if self.match_type == 'dual_softmax': + self.temperature = config['dsmax_temperature'] + elif self.match_type == 'sinkhorn': + try: + from .superglue import log_optimal_transport + except ImportError: + raise ImportError("download superglue.py first!") + self.log_optimal_transport = log_optimal_transport + self.bin_score = nn.Parameter( + torch.tensor(config['skh_init_bin_score'], requires_grad=True)) + self.skh_iters = config['skh_iters'] + self.skh_prefilter = config['skh_prefilter'] + else: + raise NotImplementedError() + + def forward(self, feat_c0, feat_c1, data, mask_c0=None, mask_c1=None): + """ + Args: + feat0 (torch.Tensor): [N, L, C] + feat1 (torch.Tensor): [N, S, C] + data (dict) + mask_c0 (torch.Tensor): [N, L] (optional) + mask_c1 (torch.Tensor): [N, S] (optional) + Update: + data (dict): { + 'b_ids' (torch.Tensor): [M'], + 'i_ids' (torch.Tensor): [M'], + 'j_ids' (torch.Tensor): [M'], + 'gt_mask' (torch.Tensor): [M'], + 'mkpts0_c' (torch.Tensor): [M, 2], + 'mkpts1_c' (torch.Tensor): [M, 2], + 'mconf' (torch.Tensor): [M]} + NOTE: M' != M during training. + """ + N, L, S, C = feat_c0.size(0), feat_c0.size(1), feat_c1.size(1), feat_c0.size(2) + + # normalize + feat_c0, feat_c1 = map(lambda feat: feat / feat.shape[-1]**.5, + [feat_c0, feat_c1]) + + if self.match_type == 'dual_softmax': + sim_matrix = torch.einsum("nlc,nsc->nls", feat_c0, + feat_c1) / self.temperature + if mask_c0 is not None: + sim_matrix.masked_fill_( + ~(mask_c0[..., None] * mask_c1[:, None]).bool(), + -INF) + conf_matrix = F.softmax(sim_matrix, 1) * F.softmax(sim_matrix, 2) + + elif self.match_type == 'sinkhorn': + # sinkhorn, dustbin included + sim_matrix = torch.einsum("nlc,nsc->nls", feat_c0, feat_c1) + if mask_c0 is not None: + sim_matrix[:, :L, :S].masked_fill_( + ~(mask_c0[..., None] * mask_c1[:, None]).bool(), + -INF) + + # build uniform prior & use sinkhorn + log_assign_matrix = self.log_optimal_transport( + sim_matrix, self.bin_score, self.skh_iters) + assign_matrix = log_assign_matrix.exp() + conf_matrix = assign_matrix[:, :-1, :-1] + + # filter prediction with dustbin score (only in evaluation mode) + if not self.training and self.skh_prefilter: + filter0 = (assign_matrix.max(dim=2)[1] == S)[:, :-1] # [N, L] + filter1 = (assign_matrix.max(dim=1)[1] == L)[:, :-1] # [N, S] + conf_matrix[filter0[..., None].repeat(1, 1, S)] = 0 + conf_matrix[filter1[:, None].repeat(1, L, 1)] = 0 + + if self.config['sparse_spvs']: + data.update({'conf_matrix_with_bin': assign_matrix.clone()}) + + data.update({'conf_matrix': conf_matrix}) + + # predict coarse matches from conf_matrix + data.update(**self.get_coarse_match(conf_matrix, data)) + + @torch.no_grad() + def get_coarse_match(self, conf_matrix, data): + """ + Args: + conf_matrix (torch.Tensor): [N, L, S] + data (dict): with keys ['hw0_i', 'hw1_i', 'hw0_c', 'hw1_c'] + Returns: + coarse_matches (dict): { + 'b_ids' (torch.Tensor): [M'], + 'i_ids' (torch.Tensor): [M'], + 'j_ids' (torch.Tensor): [M'], + 'gt_mask' (torch.Tensor): [M'], + 'm_bids' (torch.Tensor): [M], + 'mkpts0_c' (torch.Tensor): [M, 2], + 'mkpts1_c' (torch.Tensor): [M, 2], + 'mconf' (torch.Tensor): [M]} + """ + axes_lengths = { + 'h0c': data['hw0_c'][0], + 'w0c': data['hw0_c'][1], + 'h1c': data['hw1_c'][0], + 'w1c': data['hw1_c'][1] + } + _device = conf_matrix.device + # 1. confidence thresholding + mask = conf_matrix > self.thr + mask = rearrange(mask, 'b (h0c w0c) (h1c w1c) -> b h0c w0c h1c w1c', + **axes_lengths) + if 'mask0' not in data: + mask_border(mask, self.border_rm, False) + else: + mask_border_with_padding(mask, self.border_rm, False, + data['mask0'], data['mask1']) + mask = rearrange(mask, 'b h0c w0c h1c w1c -> b (h0c w0c) (h1c w1c)', + **axes_lengths) + + # 2. mutual nearest + mask = mask \ + * (conf_matrix == conf_matrix.max(dim=2, keepdim=True)[0]) \ + * (conf_matrix == conf_matrix.max(dim=1, keepdim=True)[0]) + + # 3. find all valid coarse matches + # this only works when at most one `True` in each row + mask_v, all_j_ids = mask.max(dim=2) + b_ids, i_ids = torch.where(mask_v) + j_ids = all_j_ids[b_ids, i_ids] + mconf = conf_matrix[b_ids, i_ids, j_ids] + + # 4. Random sampling of training samples for fine-level LoFTR + # (optional) pad samples with gt coarse-level matches + if self.training: + # NOTE: + # The sampling is performed across all pairs in a batch without manually balancing + # #samples for fine-level increases w.r.t. batch_size + if 'mask0' not in data: + num_candidates_max = mask.size(0) * max( + mask.size(1), mask.size(2)) + else: + num_candidates_max = compute_max_candidates( + data['mask0'], data['mask1']) + num_matches_train = int(num_candidates_max * + self.train_coarse_percent) + num_matches_pred = len(b_ids) + assert self.train_pad_num_gt_min < num_matches_train, "min-num-gt-pad should be less than num-train-matches" + + # pred_indices is to select from prediction + if num_matches_pred <= num_matches_train - self.train_pad_num_gt_min: + pred_indices = torch.arange(num_matches_pred, device=_device) + else: + pred_indices = torch.randint( + num_matches_pred, + (num_matches_train - self.train_pad_num_gt_min, ), + device=_device) + + # gt_pad_indices is to select from gt padding. e.g. max(3787-4800, 200) + gt_pad_indices = torch.randint( + len(data['spv_b_ids']), + (max(num_matches_train - num_matches_pred, + self.train_pad_num_gt_min), ), + device=_device) + mconf_gt = torch.zeros(len(data['spv_b_ids']), device=_device) # set conf of gt paddings to all zero + + b_ids, i_ids, j_ids, mconf = map( + lambda x, y: torch.cat([x[pred_indices], y[gt_pad_indices]], + dim=0), + *zip([b_ids, data['spv_b_ids']], [i_ids, data['spv_i_ids']], + [j_ids, data['spv_j_ids']], [mconf, mconf_gt])) + + # These matches select patches that feed into fine-level network + coarse_matches = {'b_ids': b_ids, 'i_ids': i_ids, 'j_ids': j_ids} + + # 4. Update with matches in original image resolution + scale = data['hw0_i'][0] / data['hw0_c'][0] + scale0 = scale * data['scale0'][b_ids] if 'scale0' in data else scale + scale1 = scale * data['scale1'][b_ids] if 'scale1' in data else scale + mkpts0_c = torch.stack( + [i_ids % data['hw0_c'][1], i_ids // data['hw0_c'][1]], + dim=1) * scale0 + mkpts1_c = torch.stack( + [j_ids % data['hw1_c'][1], j_ids // data['hw1_c'][1]], + dim=1) * scale1 + + # These matches is the current prediction (for visualization) + coarse_matches.update({ + 'gt_mask': mconf == 0, + 'm_bids': b_ids[mconf != 0], # mconf == 0 => gt matches + 'mkpts0_c': mkpts0_c[mconf != 0], + 'mkpts1_c': mkpts1_c[mconf != 0], + 'mconf': mconf[mconf != 0] + }) + + return coarse_matches diff --git a/one2345_elev_est/oee/models/loftr/utils/cvpr_ds_config.py b/one2345_elev_est/oee/models/loftr/utils/cvpr_ds_config.py new file mode 100644 index 0000000000000000000000000000000000000000..1c9ce70154d3a1b961d3b4f08897415720f451f8 --- /dev/null +++ b/one2345_elev_est/oee/models/loftr/utils/cvpr_ds_config.py @@ -0,0 +1,50 @@ +from yacs.config import CfgNode as CN + + +def lower_config(yacs_cfg): + if not isinstance(yacs_cfg, CN): + return yacs_cfg + return {k.lower(): lower_config(v) for k, v in yacs_cfg.items()} + + +_CN = CN() +_CN.BACKBONE_TYPE = 'ResNetFPN' +_CN.RESOLUTION = (8, 2) # options: [(8, 2), (16, 4)] +_CN.FINE_WINDOW_SIZE = 5 # window_size in fine_level, must be odd +_CN.FINE_CONCAT_COARSE_FEAT = True + +# 1. LoFTR-backbone (local feature CNN) config +_CN.RESNETFPN = CN() +_CN.RESNETFPN.INITIAL_DIM = 128 +_CN.RESNETFPN.BLOCK_DIMS = [128, 196, 256] # s1, s2, s3 + +# 2. LoFTR-coarse module config +_CN.COARSE = CN() +_CN.COARSE.D_MODEL = 256 +_CN.COARSE.D_FFN = 256 +_CN.COARSE.NHEAD = 8 +_CN.COARSE.LAYER_NAMES = ['self', 'cross'] * 4 +_CN.COARSE.ATTENTION = 'linear' # options: ['linear', 'full'] +_CN.COARSE.TEMP_BUG_FIX = False + +# 3. Coarse-Matching config +_CN.MATCH_COARSE = CN() +_CN.MATCH_COARSE.THR = 0.2 +_CN.MATCH_COARSE.BORDER_RM = 2 +_CN.MATCH_COARSE.MATCH_TYPE = 'dual_softmax' # options: ['dual_softmax, 'sinkhorn'] +_CN.MATCH_COARSE.DSMAX_TEMPERATURE = 0.1 +_CN.MATCH_COARSE.SKH_ITERS = 3 +_CN.MATCH_COARSE.SKH_INIT_BIN_SCORE = 1.0 +_CN.MATCH_COARSE.SKH_PREFILTER = True +_CN.MATCH_COARSE.TRAIN_COARSE_PERCENT = 0.4 # training tricks: save GPU memory +_CN.MATCH_COARSE.TRAIN_PAD_NUM_GT_MIN = 200 # training tricks: avoid DDP deadlock + +# 4. LoFTR-fine module config +_CN.FINE = CN() +_CN.FINE.D_MODEL = 128 +_CN.FINE.D_FFN = 128 +_CN.FINE.NHEAD = 8 +_CN.FINE.LAYER_NAMES = ['self', 'cross'] * 1 +_CN.FINE.ATTENTION = 'linear' + +default_cfg = lower_config(_CN) diff --git a/one2345_elev_est/oee/models/loftr/utils/fine_matching.py b/one2345_elev_est/oee/models/loftr/utils/fine_matching.py new file mode 100644 index 0000000000000000000000000000000000000000..6e77aded52e1eb5c01e22c2738104f3b09d6922a --- /dev/null +++ b/one2345_elev_est/oee/models/loftr/utils/fine_matching.py @@ -0,0 +1,74 @@ +import math +import torch +import torch.nn as nn + +from kornia.geometry.subpix import dsnt +from kornia.utils.grid import create_meshgrid + + +class FineMatching(nn.Module): + """FineMatching with s2d paradigm""" + + def __init__(self): + super().__init__() + + def forward(self, feat_f0, feat_f1, data): + """ + Args: + feat0 (torch.Tensor): [M, WW, C] + feat1 (torch.Tensor): [M, WW, C] + data (dict) + Update: + data (dict):{ + 'expec_f' (torch.Tensor): [M, 3], + 'mkpts0_f' (torch.Tensor): [M, 2], + 'mkpts1_f' (torch.Tensor): [M, 2]} + """ + M, WW, C = feat_f0.shape + W = int(math.sqrt(WW)) + scale = data['hw0_i'][0] / data['hw0_f'][0] + self.M, self.W, self.WW, self.C, self.scale = M, W, WW, C, scale + + # corner case: if no coarse matches found + if M == 0: + assert self.training == False, "M is always >0, when training, see coarse_matching.py" + # logger.warning('No matches found in coarse-level.') + data.update({ + 'expec_f': torch.empty(0, 3, device=feat_f0.device), + 'mkpts0_f': data['mkpts0_c'], + 'mkpts1_f': data['mkpts1_c'], + }) + return + + feat_f0_picked = feat_f0_picked = feat_f0[:, WW//2, :] + sim_matrix = torch.einsum('mc,mrc->mr', feat_f0_picked, feat_f1) + softmax_temp = 1. / C**.5 + heatmap = torch.softmax(softmax_temp * sim_matrix, dim=1).view(-1, W, W) + + # compute coordinates from heatmap + coords_normalized = dsnt.spatial_expectation2d(heatmap[None], True)[0] # [M, 2] + grid_normalized = create_meshgrid(W, W, True, heatmap.device).reshape(1, -1, 2) # [1, WW, 2] + + # compute std over + var = torch.sum(grid_normalized**2 * heatmap.view(-1, WW, 1), dim=1) - coords_normalized**2 # [M, 2] + std = torch.sum(torch.sqrt(torch.clamp(var, min=1e-10)), -1) # [M] clamp needed for numerical stability + + # for fine-level supervision + data.update({'expec_f': torch.cat([coords_normalized, std.unsqueeze(1)], -1)}) + + # compute absolute kpt coords + self.get_fine_match(coords_normalized, data) + + @torch.no_grad() + def get_fine_match(self, coords_normed, data): + W, WW, C, scale = self.W, self.WW, self.C, self.scale + + # mkpts0_f and mkpts1_f + mkpts0_f = data['mkpts0_c'] + scale1 = scale * data['scale1'][data['b_ids']] if 'scale0' in data else scale + mkpts1_f = data['mkpts1_c'] + (coords_normed * (W // 2) * scale1)[:len(data['mconf'])] + + data.update({ + "mkpts0_f": mkpts0_f, + "mkpts1_f": mkpts1_f + }) diff --git a/one2345_elev_est/oee/models/loftr/utils/geometry.py b/one2345_elev_est/oee/models/loftr/utils/geometry.py new file mode 100644 index 0000000000000000000000000000000000000000..f95cdb65b48324c4f4ceb20231b1bed992b41116 --- /dev/null +++ b/one2345_elev_est/oee/models/loftr/utils/geometry.py @@ -0,0 +1,54 @@ +import torch + + +@torch.no_grad() +def warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1): + """ Warp kpts0 from I0 to I1 with depth, K and Rt + Also check covisibility and depth consistency. + Depth is consistent if relative error < 0.2 (hard-coded). + + Args: + kpts0 (torch.Tensor): [N, L, 2] - , + depth0 (torch.Tensor): [N, H, W], + depth1 (torch.Tensor): [N, H, W], + T_0to1 (torch.Tensor): [N, 3, 4], + K0 (torch.Tensor): [N, 3, 3], + K1 (torch.Tensor): [N, 3, 3], + Returns: + calculable_mask (torch.Tensor): [N, L] + warped_keypoints0 (torch.Tensor): [N, L, 2] + """ + kpts0_long = kpts0.round().long() + + # Sample depth, get calculable_mask on depth != 0 + kpts0_depth = torch.stack( + [depth0[i, kpts0_long[i, :, 1], kpts0_long[i, :, 0]] for i in range(kpts0.shape[0])], dim=0 + ) # (N, L) + nonzero_mask = kpts0_depth != 0 + + # Unproject + kpts0_h = torch.cat([kpts0, torch.ones_like(kpts0[:, :, [0]])], dim=-1) * kpts0_depth[..., None] # (N, L, 3) + kpts0_cam = K0.inverse() @ kpts0_h.transpose(2, 1) # (N, 3, L) + + # Rigid Transform + w_kpts0_cam = T_0to1[:, :3, :3] @ kpts0_cam + T_0to1[:, :3, [3]] # (N, 3, L) + w_kpts0_depth_computed = w_kpts0_cam[:, 2, :] + + # Project + w_kpts0_h = (K1 @ w_kpts0_cam).transpose(2, 1) # (N, L, 3) + w_kpts0 = w_kpts0_h[:, :, :2] / (w_kpts0_h[:, :, [2]] + 1e-4) # (N, L, 2), +1e-4 to avoid zero depth + + # Covisible Check + h, w = depth1.shape[1:3] + covisible_mask = (w_kpts0[:, :, 0] > 0) * (w_kpts0[:, :, 0] < w-1) * \ + (w_kpts0[:, :, 1] > 0) * (w_kpts0[:, :, 1] < h-1) + w_kpts0_long = w_kpts0.long() + w_kpts0_long[~covisible_mask, :] = 0 + + w_kpts0_depth = torch.stack( + [depth1[i, w_kpts0_long[i, :, 1], w_kpts0_long[i, :, 0]] for i in range(w_kpts0_long.shape[0])], dim=0 + ) # (N, L) + consistent_mask = ((w_kpts0_depth - w_kpts0_depth_computed) / w_kpts0_depth).abs() < 0.2 + valid_mask = nonzero_mask * covisible_mask * consistent_mask + + return valid_mask, w_kpts0 diff --git a/one2345_elev_est/oee/models/loftr/utils/position_encoding.py b/one2345_elev_est/oee/models/loftr/utils/position_encoding.py new file mode 100644 index 0000000000000000000000000000000000000000..732d28c814ef93bf48d338ba7554f6dcfc3b880e --- /dev/null +++ b/one2345_elev_est/oee/models/loftr/utils/position_encoding.py @@ -0,0 +1,42 @@ +import math +import torch +from torch import nn + + +class PositionEncodingSine(nn.Module): + """ + This is a sinusoidal position encoding that generalized to 2-dimensional images + """ + + def __init__(self, d_model, max_shape=(256, 256), temp_bug_fix=True): + """ + Args: + max_shape (tuple): for 1/8 featmap, the max length of 256 corresponds to 2048 pixels + temp_bug_fix (bool): As noted in this [issue](https://github.com/zju3dv/LoFTR/issues/41), + the original implementation of LoFTR includes a bug in the pos-enc impl, which has little impact + on the final performance. For now, we keep both impls for backward compatability. + We will remove the buggy impl after re-training all variants of our released models. + """ + super().__init__() + + pe = torch.zeros((d_model, *max_shape)) + y_position = torch.ones(max_shape).cumsum(0).float().unsqueeze(0) + x_position = torch.ones(max_shape).cumsum(1).float().unsqueeze(0) + if temp_bug_fix: + div_term = torch.exp(torch.arange(0, d_model//2, 2).float() * (-math.log(10000.0) / (d_model//2))) + else: # a buggy implementation (for backward compatability only) + div_term = torch.exp(torch.arange(0, d_model//2, 2).float() * (-math.log(10000.0) / d_model//2)) + div_term = div_term[:, None, None] # [C//4, 1, 1] + pe[0::4, :, :] = torch.sin(x_position * div_term) + pe[1::4, :, :] = torch.cos(x_position * div_term) + pe[2::4, :, :] = torch.sin(y_position * div_term) + pe[3::4, :, :] = torch.cos(y_position * div_term) + + self.register_buffer('pe', pe.unsqueeze(0), persistent=False) # [1, C, H, W] + + def forward(self, x): + """ + Args: + x: [N, C, H, W] + """ + return x + self.pe[:, :, :x.size(2), :x.size(3)] diff --git a/one2345_elev_est/oee/models/loftr/utils/supervision.py b/one2345_elev_est/oee/models/loftr/utils/supervision.py new file mode 100644 index 0000000000000000000000000000000000000000..8ce6e79ec72b45fcb6b187e33bda93a47b168acd --- /dev/null +++ b/one2345_elev_est/oee/models/loftr/utils/supervision.py @@ -0,0 +1,151 @@ +from math import log +from loguru import logger + +import torch +from einops import repeat +from kornia.utils import create_meshgrid + +from .geometry import warp_kpts + +############## ↓ Coarse-Level supervision ↓ ############## + + +@torch.no_grad() +def mask_pts_at_padded_regions(grid_pt, mask): + """For megadepth dataset, zero-padding exists in images""" + mask = repeat(mask, 'n h w -> n (h w) c', c=2) + grid_pt[~mask.bool()] = 0 + return grid_pt + + +@torch.no_grad() +def spvs_coarse(data, config): + """ + Update: + data (dict): { + "conf_matrix_gt": [N, hw0, hw1], + 'spv_b_ids': [M] + 'spv_i_ids': [M] + 'spv_j_ids': [M] + 'spv_w_pt0_i': [N, hw0, 2], in original image resolution + 'spv_pt1_i': [N, hw1, 2], in original image resolution + } + + NOTE: + - for scannet dataset, there're 3 kinds of resolution {i, c, f} + - for megadepth dataset, there're 4 kinds of resolution {i, i_resize, c, f} + """ + # 1. misc + device = data['image0'].device + N, _, H0, W0 = data['image0'].shape + _, _, H1, W1 = data['image1'].shape + scale = config['LOFTR']['RESOLUTION'][0] + scale0 = scale * data['scale0'][:, None] if 'scale0' in data else scale + scale1 = scale * data['scale1'][:, None] if 'scale0' in data else scale + h0, w0, h1, w1 = map(lambda x: x // scale, [H0, W0, H1, W1]) + + # 2. warp grids + # create kpts in meshgrid and resize them to image resolution + grid_pt0_c = create_meshgrid(h0, w0, False, device).reshape(1, h0*w0, 2).repeat(N, 1, 1) # [N, hw, 2] + grid_pt0_i = scale0 * grid_pt0_c + grid_pt1_c = create_meshgrid(h1, w1, False, device).reshape(1, h1*w1, 2).repeat(N, 1, 1) + grid_pt1_i = scale1 * grid_pt1_c + + # mask padded region to (0, 0), so no need to manually mask conf_matrix_gt + if 'mask0' in data: + grid_pt0_i = mask_pts_at_padded_regions(grid_pt0_i, data['mask0']) + grid_pt1_i = mask_pts_at_padded_regions(grid_pt1_i, data['mask1']) + + # warp kpts bi-directionally and resize them to coarse-level resolution + # (no depth consistency check, since it leads to worse results experimentally) + # (unhandled edge case: points with 0-depth will be warped to the left-up corner) + _, w_pt0_i = warp_kpts(grid_pt0_i, data['depth0'], data['depth1'], data['T_0to1'], data['K0'], data['K1']) + _, w_pt1_i = warp_kpts(grid_pt1_i, data['depth1'], data['depth0'], data['T_1to0'], data['K1'], data['K0']) + w_pt0_c = w_pt0_i / scale1 + w_pt1_c = w_pt1_i / scale0 + + # 3. check if mutual nearest neighbor + w_pt0_c_round = w_pt0_c[:, :, :].round().long() + nearest_index1 = w_pt0_c_round[..., 0] + w_pt0_c_round[..., 1] * w1 + w_pt1_c_round = w_pt1_c[:, :, :].round().long() + nearest_index0 = w_pt1_c_round[..., 0] + w_pt1_c_round[..., 1] * w0 + + # corner case: out of boundary + def out_bound_mask(pt, w, h): + return (pt[..., 0] < 0) + (pt[..., 0] >= w) + (pt[..., 1] < 0) + (pt[..., 1] >= h) + nearest_index1[out_bound_mask(w_pt0_c_round, w1, h1)] = 0 + nearest_index0[out_bound_mask(w_pt1_c_round, w0, h0)] = 0 + + loop_back = torch.stack([nearest_index0[_b][_i] for _b, _i in enumerate(nearest_index1)], dim=0) + correct_0to1 = loop_back == torch.arange(h0*w0, device=device)[None].repeat(N, 1) + correct_0to1[:, 0] = False # ignore the top-left corner + + # 4. construct a gt conf_matrix + conf_matrix_gt = torch.zeros(N, h0*w0, h1*w1, device=device) + b_ids, i_ids = torch.where(correct_0to1 != 0) + j_ids = nearest_index1[b_ids, i_ids] + + conf_matrix_gt[b_ids, i_ids, j_ids] = 1 + data.update({'conf_matrix_gt': conf_matrix_gt}) + + # 5. save coarse matches(gt) for training fine level + if len(b_ids) == 0: + logger.warning(f"No groundtruth coarse match found for: {data['pair_names']}") + # this won't affect fine-level loss calculation + b_ids = torch.tensor([0], device=device) + i_ids = torch.tensor([0], device=device) + j_ids = torch.tensor([0], device=device) + + data.update({ + 'spv_b_ids': b_ids, + 'spv_i_ids': i_ids, + 'spv_j_ids': j_ids + }) + + # 6. save intermediate results (for fast fine-level computation) + data.update({ + 'spv_w_pt0_i': w_pt0_i, + 'spv_pt1_i': grid_pt1_i + }) + + +def compute_supervision_coarse(data, config): + assert len(set(data['dataset_name'])) == 1, "Do not support mixed datasets training!" + data_source = data['dataset_name'][0] + if data_source.lower() in ['scannet', 'megadepth']: + spvs_coarse(data, config) + else: + raise ValueError(f'Unknown data source: {data_source}') + + +############## ↓ Fine-Level supervision ↓ ############## + +@torch.no_grad() +def spvs_fine(data, config): + """ + Update: + data (dict):{ + "expec_f_gt": [M, 2]} + """ + # 1. misc + # w_pt0_i, pt1_i = data.pop('spv_w_pt0_i'), data.pop('spv_pt1_i') + w_pt0_i, pt1_i = data['spv_w_pt0_i'], data['spv_pt1_i'] + scale = config['LOFTR']['RESOLUTION'][1] + radius = config['LOFTR']['FINE_WINDOW_SIZE'] // 2 + + # 2. get coarse prediction + b_ids, i_ids, j_ids = data['b_ids'], data['i_ids'], data['j_ids'] + + # 3. compute gt + scale = scale * data['scale1'][b_ids] if 'scale0' in data else scale + # `expec_f_gt` might exceed the window, i.e. abs(*) > 1, which would be filtered later + expec_f_gt = (w_pt0_i[b_ids, i_ids] - pt1_i[b_ids, j_ids]) / scale / radius # [M, 2] + data.update({"expec_f_gt": expec_f_gt}) + + +def compute_supervision_fine(data, config): + data_source = data['dataset_name'][0] + if data_source.lower() in ['scannet', 'megadepth']: + spvs_fine(data, config) + else: + raise NotImplementedError diff --git a/one2345_elev_est/oee/utils/elev_est_api.py b/one2345_elev_est/oee/utils/elev_est_api.py new file mode 100644 index 0000000000000000000000000000000000000000..a82345c02eae79e19e450c7d4583467da153f501 --- /dev/null +++ b/one2345_elev_est/oee/utils/elev_est_api.py @@ -0,0 +1,206 @@ +import matplotlib.pyplot as plt +import warnings + +import numpy as np +import cv2 +import os +import os.path as osp +import imageio +from copy import deepcopy + +import loguru +import torch +from oee.models.loftr import LoFTR, default_cfg +import matplotlib.cm as cm + +from oee.utils import plt_utils +from oee.utils.plotting import make_matching_figure +from oee.utils.utils3d import rect_to_img, canonical_to_camera, calc_pose + + +class ElevEstHelper: + _feature_matcher = None + + @classmethod + def get_feature_matcher(cls): + if cls._feature_matcher is None: + loguru.logger.info("Loading feature matcher...") + _default_cfg = deepcopy(default_cfg) + _default_cfg['coarse']['temp_bug_fix'] = True # set to False when using the old ckpt + matcher = LoFTR(config=_default_cfg) + ckpt_path = "weights/indoor_ds_new.ckpt" + if not osp.exists(ckpt_path): + loguru.logger.info("Downloading feature matcher...") + os.makedirs("weights", exist_ok=True) + import gdown + gdown.cached_download(url="https://drive.google.com/uc?id=19s3QvcCWQ6g-N1PrYlDCg-2mOJZ3kkgS", + path=ckpt_path) + matcher.load_state_dict(torch.load(ckpt_path)['state_dict']) + matcher = matcher.eval().cuda() + cls._feature_matcher = matcher + return cls._feature_matcher + + +def mask_out_bkgd(img_path, dbg=False): + img = imageio.imread_v2(img_path) + if img.shape[-1] == 4: + fg_mask = img[:, :, :3] + else: + loguru.logger.info("Image has no alpha channel, using thresholding to mask out background") + fg_mask = ~(img > 245).all(axis=-1) + if dbg: + plt.imshow(plt_utils.vis_mask(img, fg_mask.astype(np.uint8), color=[0, 255, 0])) + plt.show() + return fg_mask + + +def get_feature_matching(img_paths, dbg=False): + assert len(img_paths) == 4 + matcher = ElevEstHelper.get_feature_matcher() + feature_matching = {} + masks = [] + for i in range(4): + mask = mask_out_bkgd(img_paths[i], dbg=dbg) + masks.append(mask) + for i in range(0, 4): + for j in range(i + 1, 4): + img0_pth = img_paths[i] + img1_pth = img_paths[j] + mask0 = masks[i] + mask1 = masks[j] + img0_raw = cv2.imread(img0_pth, cv2.IMREAD_GRAYSCALE) + img1_raw = cv2.imread(img1_pth, cv2.IMREAD_GRAYSCALE) + original_shape = img0_raw.shape + img0_raw_resized = cv2.resize(img0_raw, (480, 480)) + img1_raw_resized = cv2.resize(img1_raw, (480, 480)) + + img0 = torch.from_numpy(img0_raw_resized)[None][None].cuda() / 255. + img1 = torch.from_numpy(img1_raw_resized)[None][None].cuda() / 255. + batch = {'image0': img0, 'image1': img1} + + # Inference with LoFTR and get prediction + with torch.no_grad(): + matcher(batch) + mkpts0 = batch['mkpts0_f'].cpu().numpy() + mkpts1 = batch['mkpts1_f'].cpu().numpy() + mconf = batch['mconf'].cpu().numpy() + mkpts0[:, 0] = mkpts0[:, 0] * original_shape[1] / 480 + mkpts0[:, 1] = mkpts0[:, 1] * original_shape[0] / 480 + mkpts1[:, 0] = mkpts1[:, 0] * original_shape[1] / 480 + mkpts1[:, 1] = mkpts1[:, 1] * original_shape[0] / 480 + keep0 = mask0[mkpts0[:, 1].astype(int), mkpts1[:, 0].astype(int)] + keep1 = mask1[mkpts1[:, 1].astype(int), mkpts1[:, 0].astype(int)] + keep = np.logical_and(keep0, keep1) + mkpts0 = mkpts0[keep] + mkpts1 = mkpts1[keep] + mconf = mconf[keep] + if dbg: + # Draw visualization + color = cm.jet(mconf) + text = [ + 'LoFTR', + 'Matches: {}'.format(len(mkpts0)), + ] + fig = make_matching_figure(img0_raw, img1_raw, mkpts0, mkpts1, color, text=text) + fig.show() + feature_matching[f"{i}_{j}"] = np.concatenate([mkpts0, mkpts1, mconf[:, None]], axis=1) + + return feature_matching + + +def gen_pose_hypothesis(center_elevation): + elevations = np.radians( + [center_elevation, center_elevation - 10, center_elevation + 10, center_elevation, center_elevation]) # 45~120 + azimuths = np.radians([30, 30, 30, 20, 40]) + input_poses = calc_pose(elevations, azimuths, len(azimuths)) + input_poses = input_poses[1:] + input_poses[..., 1] *= -1 + input_poses[..., 2] *= -1 + return input_poses + + +def ba_error_general(K, matches, poses): + projmat0 = K @ poses[0].inverse()[:3, :4] + projmat1 = K @ poses[1].inverse()[:3, :4] + match_01 = matches[0] + pts0 = match_01[:, :2] + pts1 = match_01[:, 2:4] + Xref = cv2.triangulatePoints(projmat0.cpu().numpy(), projmat1.cpu().numpy(), + pts0.cpu().numpy().T, pts1.cpu().numpy().T) + Xref = Xref[:3] / Xref[3:] + Xref = Xref.T + Xref = torch.from_numpy(Xref).cuda().float() + reproj_error = 0 + for match, cp in zip(matches[1:], poses[2:]): + dist = (torch.norm(match_01[:, :2][:, None, :] - match[:, :2][None, :, :], dim=-1)) + if dist.numel() > 0: + # print("dist.shape", dist.shape) + m0to2_index = dist.argmin(1) + keep = dist[torch.arange(match_01.shape[0]), m0to2_index] < 1 + if keep.sum() > 0: + xref_in2 = rect_to_img(K, canonical_to_camera(Xref, cp.inverse())) + reproj_error2 = torch.norm(match[m0to2_index][keep][:, 2:4] - xref_in2[keep], dim=-1) + conf02 = match[m0to2_index][keep][:, -1] + reproj_error += (reproj_error2 * conf02).sum() / (conf02.sum()) + + return reproj_error + + +def find_optim_elev(elevs, nimgs, matches, K, dbg=False): + errs = [] + for elev in elevs: + err = 0 + cam_poses = gen_pose_hypothesis(elev) + for start in range(nimgs - 1): + batch_matches, batch_poses = [], [] + for i in range(start, nimgs + start): + ci = i % nimgs + batch_poses.append(cam_poses[ci]) + for j in range(nimgs - 1): + key = f"{start}_{(start + j + 1) % nimgs}" + match = matches[key] + batch_matches.append(match) + err += ba_error_general(K, batch_matches, batch_poses) + errs.append(err) + errs = torch.tensor(errs) + if dbg: + plt.plot(elevs, errs) + plt.show() + optim_elev = elevs[torch.argmin(errs)].item() + return optim_elev + + +def get_elev_est(feature_matching, min_elev=30, max_elev=150, K=None, dbg=False): + flag = True + matches = {} + for i in range(4): + for j in range(i + 1, 4): + match_ij = feature_matching[f"{i}_{j}"] + if len(match_ij) == 0: + flag = False + match_ji = np.concatenate([match_ij[:, 2:4], match_ij[:, 0:2], match_ij[:, 4:5]], axis=1) + matches[f"{i}_{j}"] = torch.from_numpy(match_ij).float().cuda() + matches[f"{j}_{i}"] = torch.from_numpy(match_ji).float().cuda() + if not flag: + loguru.logger.info("0 matches, could not estimate elevation") + return None + interval = 10 + elevs = np.arange(min_elev, max_elev, interval) + optim_elev1 = find_optim_elev(elevs, 4, matches, K) + + elevs = np.arange(optim_elev1 - 10, optim_elev1 + 10, 1) + optim_elev2 = find_optim_elev(elevs, 4, matches, K) + + return optim_elev2 + + +def elev_est_api(img_paths, min_elev=30, max_elev=150, K=None, dbg=False): + feature_matching = get_feature_matching(img_paths, dbg=dbg) + if K is None: + loguru.logger.warning("K is not provided, using default K") + K = np.array([[280.0, 0, 128.0], + [0, 280.0, 128.0], + [0, 0, 1]]) + K = torch.from_numpy(K).cuda().float() + elev = get_elev_est(feature_matching, min_elev, max_elev, K, dbg=dbg) + return elev diff --git a/one2345_elev_est/oee/utils/plotting.py b/one2345_elev_est/oee/utils/plotting.py new file mode 100644 index 0000000000000000000000000000000000000000..9e7ac1de4b1fb6d0cbeda2f61eca81c68a9ba423 --- /dev/null +++ b/one2345_elev_est/oee/utils/plotting.py @@ -0,0 +1,154 @@ +import bisect +import numpy as np +import matplotlib.pyplot as plt +import matplotlib + + +def _compute_conf_thresh(data): + dataset_name = data['dataset_name'][0].lower() + if dataset_name == 'scannet': + thr = 5e-4 + elif dataset_name == 'megadepth': + thr = 1e-4 + else: + raise ValueError(f'Unknown dataset: {dataset_name}') + return thr + + +# --- VISUALIZATION --- # + +def make_matching_figure( + img0, img1, mkpts0, mkpts1, color, + kpts0=None, kpts1=None, text=[], dpi=75, path=None): + # draw image pair + assert mkpts0.shape[0] == mkpts1.shape[0], f'mkpts0: {mkpts0.shape[0]} v.s. mkpts1: {mkpts1.shape[0]}' + fig, axes = plt.subplots(1, 2, figsize=(10, 6), dpi=dpi) + axes[0].imshow(img0, cmap='gray') + axes[1].imshow(img1, cmap='gray') + for i in range(2): # clear all frames + axes[i].get_yaxis().set_ticks([]) + axes[i].get_xaxis().set_ticks([]) + for spine in axes[i].spines.values(): + spine.set_visible(False) + plt.tight_layout(pad=1) + + if kpts0 is not None: + assert kpts1 is not None + axes[0].scatter(kpts0[:, 0], kpts0[:, 1], c='w', s=2) + axes[1].scatter(kpts1[:, 0], kpts1[:, 1], c='w', s=2) + + # draw matches + if mkpts0.shape[0] != 0 and mkpts1.shape[0] != 0: + fig.canvas.draw() + transFigure = fig.transFigure.inverted() + fkpts0 = transFigure.transform(axes[0].transData.transform(mkpts0)) + fkpts1 = transFigure.transform(axes[1].transData.transform(mkpts1)) + fig.lines = [matplotlib.lines.Line2D((fkpts0[i, 0], fkpts1[i, 0]), + (fkpts0[i, 1], fkpts1[i, 1]), + transform=fig.transFigure, c=color[i], linewidth=1) + for i in range(len(mkpts0))] + + axes[0].scatter(mkpts0[:, 0], mkpts0[:, 1], c=color, s=4) + axes[1].scatter(mkpts1[:, 0], mkpts1[:, 1], c=color, s=4) + + # put txts + txt_color = 'k' if img0[:100, :200].mean() > 200 else 'w' + fig.text( + 0.01, 0.99, '\n'.join(text), transform=fig.axes[0].transAxes, + fontsize=15, va='top', ha='left', color=txt_color) + + # save or return figure + if path: + plt.savefig(str(path), bbox_inches='tight', pad_inches=0) + plt.close() + else: + return fig + + +def _make_evaluation_figure(data, b_id, alpha='dynamic'): + b_mask = data['m_bids'] == b_id + conf_thr = _compute_conf_thresh(data) + + img0 = (data['image0'][b_id][0].cpu().numpy() * 255).round().astype(np.int32) + img1 = (data['image1'][b_id][0].cpu().numpy() * 255).round().astype(np.int32) + kpts0 = data['mkpts0_f'][b_mask].cpu().numpy() + kpts1 = data['mkpts1_f'][b_mask].cpu().numpy() + + # for megadepth, we visualize matches on the resized image + if 'scale0' in data: + kpts0 = kpts0 / data['scale0'][b_id].cpu().numpy()[[1, 0]] + kpts1 = kpts1 / data['scale1'][b_id].cpu().numpy()[[1, 0]] + + epi_errs = data['epi_errs'][b_mask].cpu().numpy() + correct_mask = epi_errs < conf_thr + precision = np.mean(correct_mask) if len(correct_mask) > 0 else 0 + n_correct = np.sum(correct_mask) + n_gt_matches = int(data['conf_matrix_gt'][b_id].sum().cpu()) + recall = 0 if n_gt_matches == 0 else n_correct / (n_gt_matches) + # recall might be larger than 1, since the calculation of conf_matrix_gt + # uses groundtruth depths and camera poses, but epipolar distance is used here. + + # matching info + if alpha == 'dynamic': + alpha = dynamic_alpha(len(correct_mask)) + color = error_colormap(epi_errs, conf_thr, alpha=alpha) + + text = [ + f'#Matches {len(kpts0)}', + f'Precision({conf_thr:.2e}) ({100 * precision:.1f}%): {n_correct}/{len(kpts0)}', + f'Recall({conf_thr:.2e}) ({100 * recall:.1f}%): {n_correct}/{n_gt_matches}' + ] + + # make the figure + figure = make_matching_figure(img0, img1, kpts0, kpts1, + color, text=text) + return figure + +def _make_confidence_figure(data, b_id): + # TODO: Implement confidence figure + raise NotImplementedError() + + +def make_matching_figures(data, config, mode='evaluation'): + """ Make matching figures for a batch. + + Args: + data (Dict): a batch updated by PL_LoFTR. + config (Dict): matcher config + Returns: + figures (Dict[str, List[plt.figure]] + """ + assert mode in ['evaluation', 'confidence'] # 'confidence' + figures = {mode: []} + for b_id in range(data['image0'].size(0)): + if mode == 'evaluation': + fig = _make_evaluation_figure( + data, b_id, + alpha=config.TRAINER.PLOT_MATCHES_ALPHA) + elif mode == 'confidence': + fig = _make_confidence_figure(data, b_id) + else: + raise ValueError(f'Unknown plot mode: {mode}') + figures[mode].append(fig) + return figures + + +def dynamic_alpha(n_matches, + milestones=[0, 300, 1000, 2000], + alphas=[1.0, 0.8, 0.4, 0.2]): + if n_matches == 0: + return 1.0 + ranges = list(zip(alphas, alphas[1:] + [None])) + loc = bisect.bisect_right(milestones, n_matches) - 1 + _range = ranges[loc] + if _range[1] is None: + return _range[0] + return _range[1] + (milestones[loc + 1] - n_matches) / ( + milestones[loc + 1] - milestones[loc]) * (_range[0] - _range[1]) + + +def error_colormap(err, thr, alpha=1.0): + assert alpha <= 1.0 and alpha > 0, f"Invaid alpha value: {alpha}" + x = 1 - np.clip(err / (thr * 2), 0, 1) + return np.clip( + np.stack([2-x*2, x*2, np.zeros_like(x), np.ones_like(x)*alpha], -1), 0, 1) diff --git a/one2345_elev_est/oee/utils/plt_utils.py b/one2345_elev_est/oee/utils/plt_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..92353edab179de9f702633a01e123e94403bd83f --- /dev/null +++ b/one2345_elev_est/oee/utils/plt_utils.py @@ -0,0 +1,318 @@ +import os.path as osp +import os +import matplotlib.pyplot as plt +import torch +import cv2 +import math + +import numpy as np +import tqdm +from cv2 import findContours +from dl_ext.primitive import safe_zip +from dl_ext.timer import EvalTime + + +def plot_confidence(confidence): + n = len(confidence) + plt.plot(np.arange(n), confidence) + plt.show() + + +def image_grid( + images, + rows=None, + cols=None, + fill: bool = True, + show_axes: bool = False, + rgb=None, + show=True, + label=None, + **kwargs +): + """ + A util function for plotting a grid of images. + Args: + images: (N, H, W, 4) array of RGBA images + rows: number of rows in the grid + cols: number of columns in the grid + fill: boolean indicating if the space between images should be filled + show_axes: boolean indicating if the axes of the plots should be visible + rgb: boolean, If True, only RGB channels are plotted. + If False, only the alpha channel is plotted. + Returns: + None + """ + evaltime = EvalTime(disable=True) + evaltime('') + if isinstance(images, torch.Tensor): + images = images.detach().cpu() + if len(images[0].shape) == 2: + rgb = False + if images[0].shape[-1] == 2: + # flow + images = [flow_to_image(im) for im in images] + if (rows is None) != (cols is None): + raise ValueError("Specify either both rows and cols or neither.") + + if rows is None: + rows = int(len(images) ** 0.5) + cols = math.ceil(len(images) / rows) + + gridspec_kw = {"wspace": 0.0, "hspace": 0.0} if fill else {} + if len(images) < 50: + figsize = (10, 10) + else: + figsize = (15, 15) + evaltime('0.5') + plt.figure(figsize=figsize) + # fig, axarr = plt.subplots(rows, cols, gridspec_kw=gridspec_kw, figsize=figsize) + if label: + # fig.suptitle(label, fontsize=30) + plt.suptitle(label, fontsize=30) + # bleed = 0 + # fig.subplots_adjust(left=bleed, bottom=bleed, right=(1 - bleed), top=(1 - bleed)) + evaltime('subplots') + + # for i, (ax, im) in enumerate(tqdm.tqdm(zip(axarr.ravel(), images), leave=True, total=len(images))): + for i in range(len(images)): + # evaltime(f'{i} begin') + plt.subplot(rows, cols, i + 1) + if rgb: + # only render RGB channels + plt.imshow(images[i][..., :3], **kwargs) + # ax.imshow(im[..., :3], **kwargs) + else: + # only render Alpha channel + plt.imshow(images[i], **kwargs) + # ax.imshow(im, **kwargs) + if not show_axes: + plt.axis('off') + # ax.set_axis_off() + # ax.set_title(f'{i}') + plt.title(f'{i}') + # evaltime(f'{i} end') + evaltime('2') + if show: + plt.show() + # return fig + + +def depth_grid( + depths, + rows=None, + cols=None, + fill: bool = True, + show_axes: bool = False, +): + """ + A util function for plotting a grid of images. + Args: + images: (N, H, W, 4) array of RGBA images + rows: number of rows in the grid + cols: number of columns in the grid + fill: boolean indicating if the space between images should be filled + show_axes: boolean indicating if the axes of the plots should be visible + rgb: boolean, If True, only RGB channels are plotted. + If False, only the alpha channel is plotted. + Returns: + None + """ + if (rows is None) != (cols is None): + raise ValueError("Specify either both rows and cols or neither.") + + if rows is None: + rows = len(depths) + cols = 1 + + gridspec_kw = {"wspace": 0.0, "hspace": 0.0} if fill else {} + fig, axarr = plt.subplots(rows, cols, gridspec_kw=gridspec_kw, figsize=(15, 9)) + bleed = 0 + fig.subplots_adjust(left=bleed, bottom=bleed, right=(1 - bleed), top=(1 - bleed)) + + for ax, im in zip(axarr.ravel(), depths): + ax.imshow(im) + if not show_axes: + ax.set_axis_off() + plt.show() + + +def hover_masks_on_imgs(images, masks): + masks = np.array(masks) + new_imgs = [] + tids = list(range(1, masks.max() + 1)) + colors = colormap(rgb=True, lighten=True) + for im, mask in tqdm.tqdm(safe_zip(images, masks), total=len(images)): + for tid in tids: + im = vis_mask( + im, + (mask == tid).astype(np.uint8), + color=colors[tid], + alpha=0.5, + border_alpha=0.5, + border_color=[255, 255, 255], + border_thick=3) + new_imgs.append(im) + return new_imgs + + +def vis_mask(img, + mask, + color=[255, 255, 255], + alpha=0.4, + show_border=True, + border_alpha=0.5, + border_thick=1, + border_color=None): + """Visualizes a single binary mask.""" + if isinstance(mask, torch.Tensor): + from anypose.utils.pn_utils import to_array + mask = to_array(mask > 0).astype(np.uint8) + img = img.astype(np.float32) + idx = np.nonzero(mask) + + img[idx[0], idx[1], :] *= 1.0 - alpha + img[idx[0], idx[1], :] += [alpha * x for x in color] + + if show_border: + contours, _ = findContours( + mask.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) + # contours = [c for c in contours if c.shape[0] > 10] + if border_color is None: + border_color = color + if not isinstance(border_color, list): + border_color = border_color.tolist() + if border_alpha < 1: + with_border = img.copy() + cv2.drawContours(with_border, contours, -1, border_color, + border_thick, cv2.LINE_AA) + img = (1 - border_alpha) * img + border_alpha * with_border + else: + cv2.drawContours(img, contours, -1, border_color, border_thick, + cv2.LINE_AA) + + return img.astype(np.uint8) + + +def colormap(rgb=False, lighten=True): + """Copied from Detectron codebase.""" + color_list = np.array( + [ + 0.000, 0.447, 0.741, + 0.850, 0.325, 0.098, + 0.929, 0.694, 0.125, + 0.494, 0.184, 0.556, + 0.466, 0.674, 0.188, + 0.301, 0.745, 0.933, + 0.635, 0.078, 0.184, + 0.300, 0.300, 0.300, + 0.600, 0.600, 0.600, + 1.000, 0.000, 0.000, + 1.000, 0.500, 0.000, + 0.749, 0.749, 0.000, + 0.000, 1.000, 0.000, + 0.000, 0.000, 1.000, + 0.667, 0.000, 1.000, + 0.333, 0.333, 0.000, + 0.333, 0.667, 0.000, + 0.333, 1.000, 0.000, + 0.667, 0.333, 0.000, + 0.667, 0.667, 0.000, + 0.667, 1.000, 0.000, + 1.000, 0.333, 0.000, + 1.000, 0.667, 0.000, + 1.000, 1.000, 0.000, + 0.000, 0.333, 0.500, + 0.000, 0.667, 0.500, + 0.000, 1.000, 0.500, + 0.333, 0.000, 0.500, + 0.333, 0.333, 0.500, + 0.333, 0.667, 0.500, + 0.333, 1.000, 0.500, + 0.667, 0.000, 0.500, + 0.667, 0.333, 0.500, + 0.667, 0.667, 0.500, + 0.667, 1.000, 0.500, + 1.000, 0.000, 0.500, + 1.000, 0.333, 0.500, + 1.000, 0.667, 0.500, + 1.000, 1.000, 0.500, + 0.000, 0.333, 1.000, + 0.000, 0.667, 1.000, + 0.000, 1.000, 1.000, + 0.333, 0.000, 1.000, + 0.333, 0.333, 1.000, + 0.333, 0.667, 1.000, + 0.333, 1.000, 1.000, + 0.667, 0.000, 1.000, + 0.667, 0.333, 1.000, + 0.667, 0.667, 1.000, + 0.667, 1.000, 1.000, + 1.000, 0.000, 1.000, + 1.000, 0.333, 1.000, + 1.000, 0.667, 1.000, + 0.167, 0.000, 0.000, + 0.333, 0.000, 0.000, + 0.500, 0.000, 0.000, + 0.667, 0.000, 0.000, + 0.833, 0.000, 0.000, + 1.000, 0.000, 0.000, + 0.000, 0.167, 0.000, + 0.000, 0.333, 0.000, + 0.000, 0.500, 0.000, + 0.000, 0.667, 0.000, + 0.000, 0.833, 0.000, + 0.000, 1.000, 0.000, + 0.000, 0.000, 0.167, + 0.000, 0.000, 0.333, + 0.000, 0.000, 0.500, + 0.000, 0.000, 0.667, + 0.000, 0.000, 0.833, + 0.000, 0.000, 1.000, + 0.000, 0.000, 0.000, + 0.143, 0.143, 0.143, + 0.286, 0.286, 0.286, + 0.429, 0.429, 0.429, + 0.571, 0.571, 0.571, + 0.714, 0.714, 0.714, + 0.857, 0.857, 0.857, + 1.000, 1.000, 1.000 + ] + ).astype(np.float32) + color_list = color_list.reshape((-1, 3)) + if not rgb: + color_list = color_list[:, ::-1] + + if lighten: + # Make all the colors a little lighter / whiter. This is copied + # from the detectron visualization code (search for 'w_ratio'). + w_ratio = 0.4 + color_list = (color_list * (1 - w_ratio) + w_ratio) + return color_list * 255 + + +def vis_layer_mask(masks, save_path=None): + masks = torch.as_tensor(masks) + tids = masks.unique().tolist() + tids.remove(0) + for tid in tqdm.tqdm(tids): + show = save_path is None + image_grid(masks == tid, label=f'{tid}', show=show) + if save_path: + os.makedirs(osp.dirname(save_path), exist_ok=True) + plt.savefig(save_path % tid) + plt.close('all') + + +def show(x, **kwargs): + if isinstance(x, torch.Tensor): + x = x.detach().cpu() + plt.imshow(x, **kwargs) + plt.show() + + +def vis_title(rgb, text, shift_y=30): + tmp = rgb.copy() + shift_x = rgb.shape[1] // 2 + cv2.putText(tmp, text, + (shift_x, shift_y), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), thickness=2, lineType=cv2.LINE_AA) + return tmp diff --git a/one2345_elev_est/oee/utils/utils3d.py b/one2345_elev_est/oee/utils/utils3d.py new file mode 100644 index 0000000000000000000000000000000000000000..9cc92fbde4143a4ed5187c989e3f98a896e7caab --- /dev/null +++ b/one2345_elev_est/oee/utils/utils3d.py @@ -0,0 +1,62 @@ +import numpy as np +import torch + + +def cart_to_hom(pts): + """ + :param pts: (N, 3 or 2) + :return pts_hom: (N, 4 or 3) + """ + if isinstance(pts, np.ndarray): + pts_hom = np.concatenate((pts, np.ones([*pts.shape[:-1], 1], dtype=np.float32)), -1) + else: + ones = torch.ones([*pts.shape[:-1], 1], dtype=torch.float32, device=pts.device) + pts_hom = torch.cat((pts, ones), dim=-1) + return pts_hom + + +def hom_to_cart(pts): + return pts[..., :-1] / pts[..., -1:] + + +def canonical_to_camera(pts, pose): + pts = cart_to_hom(pts) + pts = pts @ pose.transpose(-1, -2) + pts = hom_to_cart(pts) + return pts + + +def rect_to_img(K, pts_rect): + from dl_ext.vision_ext.datasets.kitti.structures import Calibration + pts_2d_hom = pts_rect @ K.t() + pts_img = Calibration.hom_to_cart(pts_2d_hom) + return pts_img + + +def calc_pose(phis, thetas, size, radius=1.2): + import torch + def normalize(vectors): + return vectors / (torch.norm(vectors, dim=-1, keepdim=True) + 1e-10) + + device = torch.device('cuda') + thetas = torch.FloatTensor(thetas).to(device) + phis = torch.FloatTensor(phis).to(device) + + centers = torch.stack([ + radius * torch.sin(thetas) * torch.sin(phis), + -radius * torch.cos(thetas) * torch.sin(phis), + radius * torch.cos(phis), + ], dim=-1) # [B, 3] + + # lookat + forward_vector = normalize(centers).squeeze(0) + up_vector = torch.FloatTensor([0, 0, 1]).to(device).unsqueeze(0).repeat(size, 1) + right_vector = normalize(torch.cross(up_vector, forward_vector, dim=-1)) + if right_vector.pow(2).sum() < 0.01: + right_vector = torch.FloatTensor([0, 1, 0]).to(device).unsqueeze(0).repeat(size, 1) + up_vector = normalize(torch.cross(forward_vector, right_vector, dim=-1)) + + poses = torch.eye(4, dtype=torch.float, device=device).unsqueeze(0).repeat(size, 1, 1) + poses[:, :3, :3] = torch.stack((right_vector, up_vector, forward_vector), dim=-1) + poses[:, :3, 3] = centers + return poses diff --git a/one2345_elev_est/one2345_elev_est.egg-info/PKG-INFO b/one2345_elev_est/one2345_elev_est.egg-info/PKG-INFO new file mode 100644 index 0000000000000000000000000000000000000000..3487d5a2d125c78739dd3655fe057b9769a7c54a --- /dev/null +++ b/one2345_elev_est/one2345_elev_est.egg-info/PKG-INFO @@ -0,0 +1,4 @@ +Metadata-Version: 2.1 +Name: one2345-elev-est +Version: 0.1 +Author: chenlinghao diff --git a/one2345_elev_est/one2345_elev_est.egg-info/SOURCES.txt b/one2345_elev_est/one2345_elev_est.egg-info/SOURCES.txt new file mode 100644 index 0000000000000000000000000000000000000000..281c2ecbcdefaac7a27c6ef2a61b7958a3ca102a --- /dev/null +++ b/one2345_elev_est/one2345_elev_est.egg-info/SOURCES.txt @@ -0,0 +1,5 @@ +setup.py +one2345_elev_est.egg-info/PKG-INFO +one2345_elev_est.egg-info/SOURCES.txt +one2345_elev_est.egg-info/dependency_links.txt +one2345_elev_est.egg-info/top_level.txt \ No newline at end of file diff --git a/one2345_elev_est/one2345_elev_est.egg-info/dependency_links.txt b/one2345_elev_est/one2345_elev_est.egg-info/dependency_links.txt new file mode 100644 index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc --- /dev/null +++ b/one2345_elev_est/one2345_elev_est.egg-info/dependency_links.txt @@ -0,0 +1 @@ + diff --git a/one2345_elev_est/one2345_elev_est.egg-info/top_level.txt b/one2345_elev_est/one2345_elev_est.egg-info/top_level.txt new file mode 100644 index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc --- /dev/null +++ b/one2345_elev_est/one2345_elev_est.egg-info/top_level.txt @@ -0,0 +1 @@ + diff --git a/one2345_elev_est/requirements.txt b/one2345_elev_est/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..d6459497e3f3aad592f73f06d0207167cb2b4ec5 --- /dev/null +++ b/one2345_elev_est/requirements.txt @@ -0,0 +1,42 @@ +dl_ext +easydict +glumpy +gym +h5py +imageio +loguru +matplotlib +mplib +multipledispatch +# numpy +open3d +packaging +# pandas +Pillow +pycocotools +# pyk4a +motion-planning +# pyrealsense2 +pyrender +# pytorch3d +PyYAML +scikit_image +scikit_learn +scipy +screeninfo +# seaborn +setuptools +# skimage +tensorboardX +termcolor +# torch +# torchvision +tqdm +transforms3d +trimesh +yacs +zarr +sapien +pyglet==1.5.27 +wis3d +git+https://github.com/NVlabs/nvdiffrast.git diff --git a/one2345_elev_est/setup.py b/one2345_elev_est/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..370ef83c85e5700ee440957117ef9382304f8321 --- /dev/null +++ b/one2345_elev_est/setup.py @@ -0,0 +1,9 @@ +from setuptools import find_packages +from setuptools import setup + +setup( + name="one2345_elev_est", + version="0.1", + author="chenlinghao", + packages=find_packages(exclude=("configs", "tests",)), +) diff --git a/one2345_elev_est/tools/estimate_wild_imgs.py b/one2345_elev_est/tools/estimate_wild_imgs.py new file mode 100644 index 0000000000000000000000000000000000000000..47103d46e48d3758a2cc237f659dd6fe5fc183c7 --- /dev/null +++ b/one2345_elev_est/tools/estimate_wild_imgs.py @@ -0,0 +1,35 @@ +import tqdm +import imageio +import json +import os.path as osp +import os + +from oee.utils import plt_utils +from oee.utils.elev_est_api import elev_est_api + + +def visualize(img_paths, elev): + imgs = [imageio.imread_v2(img_path) for img_path in img_paths] + plt_utils.image_grid(imgs, 2, 2, label=f"elev={elev}") + + +def estimate_elev(root_dir): + # root_dir = "/home/linghao/Datasets/objaverse-processed/zero12345_img/wild" + # dataset = "supp_fail" + # root_dir = "/home/chao/chao/OpenComplete/zero123/zero123/gradio_tmp/" + # obj_names = sorted(os.listdir(root_dir)) + # results = {} + # for obj_name in tqdm.tqdm(obj_names): + img_dir = osp.join(root_dir, "stage2_8") + img_paths = [] + for i in range(4): + img_paths.append(f"{img_dir}/0_{i}.png") + elev = elev_est_api(img_paths) + # visualize(img_paths, elev) + # results[obj_name] = elev + # json.dump(results, open(osp.join(root_dir, f"../{dataset}_elev.json"), "w"), indent=4) + return elev + + +# if __name__ == '__main__': +# main() diff --git a/one2345_elev_est/tools/example.py b/one2345_elev_est/tools/example.py new file mode 100644 index 0000000000000000000000000000000000000000..065f31e3e64f21494b04ca2aed87e665ddc6d23d --- /dev/null +++ b/one2345_elev_est/tools/example.py @@ -0,0 +1,38 @@ +import imageio +import numpy as np + +from oee.utils import plt_utils +from oee.utils.elev_est_api import elev_est_api +import argparse + + +def visualize(img_paths, elev): + imgs = [imageio.imread_v2(img_path) for img_path in img_paths] + plt_utils.image_grid(imgs, 2, 2, label=f"elev={elev}") + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--img_paths", type=str, nargs=4, help="image paths", + default=["assets/example_data/0_0.png", + "assets/example_data/0_1.png", + "assets/example_data/0_2.png", + "assets/example_data/0_3.png"]) + parser.add_argument("--min_elev", type=float, default=30, help="min elevation") + parser.add_argument("--max_elev", type=float, default=150, help="max elevation") + parser.add_argument("--dbg", default=False, action="store_true", help="debug mode") + parser.add_argument("--K_path", type=str, default=None, help="path to K") + args = parser.parse_args() + + if args.K_path is not None: + K = np.loadtxt(args.K_path) + else: + K = None + + elev = elev_est_api(args.img_paths, args.min_elev, args.max_elev, K, args.dbg) + + visualize(args.img_paths, elev) + + +if __name__ == '__main__': + main() diff --git a/one2345_elev_est/tools/weights/indoor_ds_new.ckpt b/one2345_elev_est/tools/weights/indoor_ds_new.ckpt new file mode 100644 index 0000000000000000000000000000000000000000..ef68cc903b08e710d46a51c2aeb97407e047f3a5 --- /dev/null +++ b/one2345_elev_est/tools/weights/indoor_ds_new.ckpt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:be9ff88b323ec27889114719f668ae41aff7034b56a4c4acbd46b8b180b87ed3 +size 46355053 diff --git a/taming-transformers b/taming-transformers new file mode 160000 index 0000000000000000000000000000000000000000..3ba01b241669f5ade541ce990f7650a3b8f65318 --- /dev/null +++ b/taming-transformers @@ -0,0 +1 @@ +Subproject commit 3ba01b241669f5ade541ce990f7650a3b8f65318