# Provenance (from hosting-page metadata): author Vincentqyw,
# commit e15a186 "update: limit keypoints number".
from pathlib import Path
import numpy as np
import torch
import PIL.Image
from tqdm import tqdm
import pycolmap
from ...utils.read_write_model import write_model, read_model
def scene_coordinates(p2D, R_w2c, t_w2c, depth, camera):
    """Back-project 2D image points with known depth into world coordinates.

    p2D are pixel coordinates, (R_w2c, t_w2c) the world-to-camera pose, and
    depth holds one metric depth value per 2D point.
    """
    assert len(depth) == len(p2D)
    # Undistort and normalize the pixels through COLMAP's camera model.
    undistorted = np.asarray(
        pycolmap.image_to_world(p2D, camera._asdict())["world_points"]
    )
    # Homogeneous normalized rays scaled by depth give camera-frame 3D points.
    ones = np.ones_like(undistorted[:, :1])
    rays = np.concatenate([undistorted, ones], 1)
    points_cam = rays * depth[:, None]
    # Invert the w2c pose: X_w = R^T (X_c - t).
    return (points_cam - t_w2c) @ R_w2c
def interpolate_depth(depth, kp):
h, w = depth.shape
kp = kp / np.array([[w - 1, h - 1]]) * 2 - 1
assert np.all(kp > -1) and np.all(kp < 1)
depth = torch.from_numpy(depth)[None, None]
kp = torch.from_numpy(kp)[None, None]
grid_sample = torch.nn.functional.grid_sample
# To maximize the number of points that have depth:
# do bilinear interpolation first and then nearest for the remaining points
interp_lin = grid_sample(depth, kp, align_corners=True, mode="bilinear")[
0, :, 0
]
interp_nn = torch.nn.functional.grid_sample(
depth, kp, align_corners=True, mode="nearest"
)[0, :, 0]
interp = torch.where(torch.isnan(interp_lin), interp_nn, interp_lin)
valid = ~torch.any(torch.isnan(interp), 0)
interp_depth = interp.T.numpy().flatten()
valid = valid.numpy()
return interp_depth, valid
def image_path_to_rendered_depth_path(image_name):
    """Map an RGB image path such as 'seq-01/frame-000000.color.png' to the
    flat filename of its rendered depth map,
    e.g. 'seq01_frame-000000.pose.depth.tiff'."""
    parts = image_name.split("/")
    # Drop the dashes in the sequence directory and flatten to one filename.
    flat = "{}_{}".format(parts[0].replace("-", ""), parts[1])
    return flat.replace("color", "pose").replace("png", "depth.tiff")
def project_to_image(p3D, R, t, camera, eps: float = 1e-4, pad: int = 1):
    """Project world-frame 3D points into the image.

    Returns the 2D points that project validly, plus a boolean mask over ALL
    input points marking those in front of the camera (depth >= eps) and
    inside the image with a `pad`-pixel margin.
    """
    points_cam = (p3D @ R.T) + t
    in_front = points_cam[:, -1] >= eps  # keep points in front of the camera
    # Perspective division, clipping the depth to avoid division by ~zero.
    normalized = points_cam[:, :-1] / points_cam[:, -1:].clip(min=eps)
    projected = np.asarray(
        pycolmap.world_to_image(normalized, camera._asdict())["image_points"]
    )
    upper = np.array([camera.width - pad - 1, camera.height - pad - 1])
    inside = np.all((projected >= pad) & (projected <= upper), -1)
    ok = inside & in_front
    return projected[ok], ok
def correct_sfm_with_gt_depth(sfm_path, depth_folder_path, output_path):
    """Replace the triangulated 3D points of a COLMAP model with points
    back-projected from rendered ground-truth depth maps, pruning any
    observation that projects outside the image or lands on invalid depth.

    sfm_path: directory containing a COLMAP model (cameras/images/points3D).
    depth_folder_path: directory of rendered depth maps (TIFF, millimeters).
    output_path: Path where the corrected model is written (created if needed).
    """
    cameras, images, points3D = read_model(sfm_path)
    for imgid, img in tqdm(images.items()):
        image_name = img.name
        depth_name = image_path_to_rendered_depth_path(image_name)
        depth = PIL.Image.open(Path(depth_folder_path) / depth_name)
        depth = np.array(depth).astype("float64")
        depth = depth / 1000.0  # mm to meter
        # Zero depth means no measurement; values > 1000 m are presumably
        # sentinel/out-of-range values — mark both as NaN. TODO confirm.
        depth[(depth == 0.0) | (depth > 1000.0)] = np.nan
        R_w2c, t_w2c = img.qvec2rotmat(), img.tvec
        camera = cameras[img.camera_id]
        p3D_ids = img.point3D_ids
        # Gather the 3D points observed by this image (-1 = no 3D point).
        p3Ds = np.stack([points3D[i].xyz for i in p3D_ids[p3D_ids != -1]], 0)
        # Re-project; points behind the camera or outside the image become invalid.
        p2Ds, valids_projected = project_to_image(p3Ds, R_w2c, t_w2c, camera)
        invalid_p3D_ids = p3D_ids[p3D_ids != -1][~valids_projected]
        # Sample GT depth at the projected locations; NaN depth -> invalid.
        interp_depth, valids_backprojected = interpolate_depth(depth, p2Ds)
        # Back-project surviving 2D points to world-frame scene coordinates.
        scs = scene_coordinates(
            p2Ds[valids_backprojected],
            R_w2c,
            t_w2c,
            interp_depth[valids_backprojected],
            camera,
        )
        invalid_p3D_ids = np.append(
            invalid_p3D_ids,
            p3D_ids[p3D_ids != -1][valids_projected][~valids_backprojected],
        )
        # Remove this image's observation from every invalidated 3D point so
        # the model stays bidirectionally consistent.
        for p3did in invalid_p3D_ids:
            if p3did == -1:
                continue
            else:
                obs_imgids = points3D[p3did].image_ids
                invalid_imgids = list(np.where(obs_imgids == img.id)[0])
                points3D[p3did] = points3D[p3did]._replace(
                    image_ids=np.delete(obs_imgids, invalid_imgids),
                    point2D_idxs=np.delete(
                        points3D[p3did].point2D_idxs, invalid_imgids
                    ),
                )
        # Rebuild this image's point3D_ids, setting invalidated entries to -1.
        # The valids mask is over the non(-1) subset, combining both filters.
        new_p3D_ids = p3D_ids.copy()
        sub_p3D_ids = new_p3D_ids[new_p3D_ids != -1]
        valids = np.ones(np.count_nonzero(new_p3D_ids != -1), dtype=bool)
        valids[~valids_projected] = False
        valids[valids_projected] = valids_backprojected
        sub_p3D_ids[~valids] = -1
        new_p3D_ids[new_p3D_ids != -1] = sub_p3D_ids
        img = img._replace(point3D_ids=new_p3D_ids)
        # Every remaining observation must have exactly one scene coordinate.
        assert len(img.point3D_ids[img.point3D_ids != -1]) == len(
            scs
        ), f"{len(scs)}, {len(img.point3D_ids[img.point3D_ids != -1])}"
        # Overwrite the triangulated xyz with the depth-derived coordinate.
        for i, p3did in enumerate(img.point3D_ids[img.point3D_ids != -1]):
            points3D[p3did] = points3D[p3did]._replace(xyz=scs[i])
        images[imgid] = img
    output_path.mkdir(parents=True, exist_ok=True)
    write_model(cameras, images, points3D, output_path)
if __name__ == "__main__":
    # Correct every 7Scenes scene's SfM model with its rendered GT depth.
    dataset_root = Path("datasets/7scenes")
    output_root = Path("outputs/7Scenes")
    SCENES = [
        "chess",
        "fire",
        "heads",
        "office",
        "pumpkin",
        "redkitchen",
        "stairs",
    ]
    for scene in SCENES:
        scene_out = output_root / scene
        correct_sfm_with_gt_depth(
            scene_out / "sfm_superpoint+superglue",
            dataset_root / f"depth/7scenes_{scene}/train/depth",
            scene_out / "sfm_superpoint+superglue+depth",
        )