"""Re-place the 3D points of a 7-Scenes SfM model using rendered depth maps.

For every image of an existing COLMAP reconstruction, the 3D points it
observes are projected into the image, the matching rendered depth map is
sampled at those projections, and each point is moved to the back-projection
of its projection at the sampled depth. Observations that cannot be given a
depth are removed from both the image and the point track.
"""
from pathlib import Path
import numpy as np
import torch
import PIL.Image
from tqdm import tqdm
import pycolmap

from ...utils.read_write_model import write_model, read_model


def scene_coordinates(p2D, R_w2c, t_w2c, depth, camera):
    """Back-project image points to world coordinates at the given depths.

    Args:
        p2D: (N, 2) image points.
        R_w2c: (3, 3) world-to-camera rotation.
        t_w2c: (3,) world-to-camera translation.
        depth: (N,) depth of each point along the camera z-axis.
        camera: COLMAP camera namedtuple; passed to pycolmap as a dict.

    Returns:
        (N, 3) array of world coordinates.
    """
    assert len(depth) == len(p2D)
    if len(p2D) == 0:
        # Nothing to back-project; avoid calling pycolmap on an empty array.
        return np.empty((0, 3))
    ret = pycolmap.image_to_world(p2D, camera._asdict())
    p2D_norm = np.asarray(ret["world_points"])
    # Homogeneous normalized coordinates (z = 1) scaled by depth give the
    # points in the camera frame.
    p2D_h = np.concatenate([p2D_norm, np.ones_like(p2D_norm[:, :1])], 1)
    p3D_c = p2D_h * depth[:, None]
    # Invert x_c = R x_w + t, i.e. x_w = R^T (x_c - t), written for row vectors.
    p3D_w = (p3D_c - t_w2c) @ R_w2c
    return p3D_w


def interpolate_depth(depth, kp):
    """Sample a depth map at sub-pixel locations.

    Bilinear interpolation is used first. Points where it yields NaN (because
    a neighbouring pixel is NaN) fall back to nearest-neighbour sampling, which
    maximizes the number of points that receive a depth.

    Args:
        depth: (H, W) depth map, with NaN marking invalid pixels.
        kp: (N, 2) array of (x, y) pixel coordinates, strictly inside the image.

    Returns:
        interp_depth: (N,) sampled depths (NaN where no depth is available).
        valid: (N,) boolean mask of points whose sampled depth is not NaN.
    """
    h, w = depth.shape
    if len(kp) == 0:
        return np.empty(0, dtype=depth.dtype), np.empty(0, dtype=bool)
    # Map pixel coordinates to the [-1, 1] range expected by grid_sample with
    # align_corners=True (-1 / 1 are the centers of the border pixels).
    kp = kp / np.array([[w - 1, h - 1]]) * 2 - 1
    assert np.all(kp > -1) and np.all(kp < 1)
    depth = torch.from_numpy(depth)[None, None]
    kp = torch.from_numpy(kp)[None, None]
    grid_sample = torch.nn.functional.grid_sample

    interp_lin = grid_sample(depth, kp, align_corners=True, mode="bilinear")[0, :, 0]
    interp_nn = grid_sample(depth, kp, align_corners=True, mode="nearest")[0, :, 0]
    interp = torch.where(torch.isnan(interp_lin), interp_nn, interp_lin)
    valid = ~torch.any(torch.isnan(interp), 0)

    interp_depth = interp.T.numpy().flatten()
    valid = valid.numpy()
    return interp_depth, valid


def image_path_to_rendered_depth_path(image_name):
    """Map a 7-Scenes image name to the file name of its rendered depth map.

    Example: "seq-01/frame-000000.color.png" -> "seq01_frame-000000.pose.depth.tiff".
    Assumes a two-component "<sequence>/<frame file>" name.
    """
    parts = image_name.split("/")
    name = "_".join(["".join(parts[0].split("-")), parts[1]])
    name = name.replace("color", "pose")
    name = name.replace("png", "depth.tiff")
    return name


def project_to_image(p3D, R, t, camera, eps: float = 1e-4, pad: int = 1):
    """Project world points into an image and keep those that land inside it.

    Args:
        p3D: (N, 3) world points.
        R: (3, 3) world-to-camera rotation.
        t: (3,) world-to-camera translation.
        camera: COLMAP camera namedtuple; passed to pycolmap as a dict.
        eps: minimum camera-frame depth for a point to count as visible.
        pad: border margin in pixels; projections closer than this to the
            image border are rejected.

    Returns:
        p2D: (M, 2) image coordinates of the valid points.
        valid: (N,) boolean mask selecting those points in p3D.
    """
    p3D = (p3D @ R.T) + t
    visible = p3D[:, -1] >= eps  # keep points in front of the camera
    # Clip the depth so points behind the camera do not divide by <= 0; they
    # are discarded through `visible` anyway.
    p2D_norm = p3D[:, :-1] / p3D[:, -1:].clip(min=eps)
    ret = pycolmap.world_to_image(p2D_norm, camera._asdict())
    p2D = np.asarray(ret["image_points"])
    size = np.array([camera.width - pad - 1, camera.height - pad - 1])
    valid = np.all((p2D >= pad) & (p2D <= size), -1)
    valid &= visible
    return p2D[valid], valid


def _load_depth(path):
    """Load a depth image stored in millimetres; return metres with NaN for invalid pixels."""
    # The context manager closes the file handle opened lazily by PIL.
    with PIL.Image.open(path) as image:
        depth = np.array(image).astype("float64")
    depth = depth / 1000.0  # mm to meter
    # 0 marks missing depth. NOTE(review): the > 1000.0 test is applied after
    # the mm->m conversion, so it only rejects depths above 1000 m; confirm
    # this threshold is intended.
    depth[(depth == 0.0) | (depth > 1000.0)] = np.nan
    return depth


def _remove_observations(points3D, p3D_ids, image_id):
    """Drop every observation made by `image_id` from the tracks of `p3D_ids`."""
    for p3did in p3D_ids:
        point = points3D[p3did]
        obs_idxs = np.where(point.image_ids == image_id)[0]
        points3D[p3did] = point._replace(
            image_ids=np.delete(point.image_ids, obs_idxs),
            point2D_idxs=np.delete(point.point2D_idxs, obs_idxs),
        )


def correct_sfm_with_gt_depth(sfm_path, depth_folder_path, output_path):
    """Re-place the 3D points of a COLMAP model using rendered depth maps.

    For each image, the observed 3D points are projected into the image, and
    the depth map is sampled at their projections. Each point is then moved to
    the back-projection of its projection at that depth.

    Images are processed sequentially and `points3D` is updated in place, so:
    - later images project points using positions already updated by earlier
      images;
    - the last image that validly observes a point sets its final position.

    Observations are removed from both the image and the point track when the
    projection falls behind the camera or outside the image, or when no depth
    is available there. Images without any 3D point are left unchanged.
    NOTE(review): points whose track becomes empty are not deleted.

    Args:
        sfm_path: directory of the input model, read with `read_model`.
        depth_folder_path: directory holding the rendered depth maps, named as
            returned by `image_path_to_rendered_depth_path`.
        output_path: directory (str or Path) where the corrected model is
            written; created if it does not exist.
    """
    cameras, images, points3D = read_model(sfm_path)

    for imgid, img in tqdm(images.items()):
        p3D_ids = img.point3D_ids
        observed = p3D_ids != -1
        if not np.any(observed):
            # np.stack below cannot stack zero points; nothing to correct.
            continue
        obs_p3D_ids = p3D_ids[observed]

        depth_name = image_path_to_rendered_depth_path(img.name)
        depth = _load_depth(Path(depth_folder_path) / depth_name)

        R_w2c, t_w2c = img.qvec2rotmat(), img.tvec
        camera = cameras[img.camera_id]
        p3Ds = np.stack([points3D[i].xyz for i in obs_p3D_ids], 0)
        p2Ds, valids_projected = project_to_image(p3Ds, R_w2c, t_w2c, camera)
        interp_depth, valids_backprojected = interpolate_depth(depth, p2Ds)
        scs = scene_coordinates(
            p2Ds[valids_backprojected],
            R_w2c,
            t_w2c,
            interp_depth[valids_backprojected],
            camera,
        )

        # valids[k] is True iff the k-th observed point projects inside the
        # image AND has a depth at its projection.
        valids = valids_projected.copy()
        valids[valids_projected] = valids_backprojected

        _remove_observations(points3D, obs_p3D_ids[~valids], img.id)

        new_p3D_ids = p3D_ids.copy()
        new_p3D_ids[observed] = np.where(valids, obs_p3D_ids, -1)
        img = img._replace(point3D_ids=new_p3D_ids)
        kept_p3D_ids = obs_p3D_ids[valids]
        assert len(kept_p3D_ids) == len(scs), f"{len(scs)}, {len(kept_p3D_ids)}"

        # scs is ordered like the valid observations, so the pairing is 1:1.
        for p3did, xyz in zip(kept_p3D_ids, scs):
            points3D[p3did] = points3D[p3did]._replace(xyz=xyz)
        images[imgid] = img

    output_path = Path(output_path)
    output_path.mkdir(parents=True, exist_ok=True)
    write_model(cameras, images, points3D, output_path)


if __name__ == "__main__":
    dataset = Path("datasets/7scenes")
    outputs = Path("outputs/7Scenes")

    SCENES = [
        "chess",
        "fire",
        "heads",
        "office",
        "pumpkin",
        "redkitchen",
        "stairs",
    ]
    for scene in SCENES:
        sfm_path = outputs / scene / "sfm_superpoint+superglue"
        depth_path = dataset / f"depth/7scenes_{scene}/train/depth"
        output_path = outputs / scene / "sfm_superpoint+superglue+depth"
        correct_sfm_with_gt_depth(sfm_path, depth_path, output_path)