# Author: Vincentqyw
# Commit e15a186: "update: limit keypoints number"
# (web-view residue: raw | history | blame | 5.33 kB)
from pathlib import Path
import numpy as np
import torch
import PIL.Image
from tqdm import tqdm
import pycolmap
from ...utils.read_write_model import write_model, read_model
def scene_coordinates(p2D, R_w2c, t_w2c, depth, camera):
    """Back-project pixels with known depth into world coordinates.

    Args:
        p2D: pixel coordinates, one row per point.
        R_w2c, t_w2c: world-to-camera rotation and translation.
        depth: one depth value per row of ``p2D``.
        camera: camera record; passed to ``pycolmap.image_to_world`` as a
            dict via ``_asdict()``.

    Returns:
        World-frame 3D points, one row per input pixel.
    """
    assert len(depth) == len(p2D)
    # NOTE(review): relies on a pycolmap version whose image_to_world accepts
    # a camera dict and returns {"world_points": ...} — confirm on upgrade.
    normalized = np.asarray(
        pycolmap.image_to_world(p2D, camera._asdict())["world_points"]
    )
    # Homogeneous normalized coordinates (x, y, 1) scaled by depth.
    rays = np.hstack(
        [normalized, np.ones((len(normalized), 1), dtype=normalized.dtype)]
    )
    points_cam = rays * depth[:, None]
    # Invert x_c = R x_w + t, i.e. x_w = R^T (x_c - t), in row-vector form.
    return (points_cam - t_w2c) @ R_w2c
def interpolate_depth(depth, kp):
    """Sample a depth map at sub-pixel keypoint locations.

    Args:
        depth: 2-D array ``(H, W)`` of depths; NaN marks missing values.
        kp: array of ``(x, y)`` pixel coordinates, one row per keypoint.
            Every keypoint must lie strictly inside the image.

    Returns:
        ``(interp_depth, valid)``: the sampled depth per keypoint, and a
        boolean mask that is False where no depth could be sampled.
    """
    height, width = depth.shape
    # Pixel coordinates -> [-1, 1], the grid convention for align_corners=True.
    grid = kp / np.array([[width - 1, height - 1]]) * 2 - 1
    assert np.all(grid > -1) and np.all(grid < 1)

    depth_t = torch.from_numpy(depth)[None, None]
    grid_t = torch.from_numpy(grid)[None, None]

    def _sample(mode):
        out = torch.nn.functional.grid_sample(
            depth_t, grid_t, align_corners=True, mode=mode
        )
        return out[0, :, 0]

    bilinear = _sample("bilinear")
    nearest = _sample("nearest")
    # Bilinear sampling is NaN as soon as one neighbour is NaN; fall back to
    # the nearest pixel for those keypoints so more of them get a depth.
    merged = torch.where(torch.isnan(bilinear), nearest, bilinear)
    valid = ~torch.isnan(merged).any(dim=0)
    return merged.T.numpy().flatten(), valid.numpy()
def image_path_to_rendered_depth_path(image_name):
    """Map an image name to the file name of its rendered depth map.

    The code performs these steps:

    1. Remove the dashes from the first path component.
    2. Join that component to the second with ``_``.
    3. Replace ``color`` with ``pose``.
    4. Replace ``png`` with ``depth.tiff``.

    For example, ``"seq-01/frame-000000.color.png"`` becomes
    ``"seq01_frame-000000.pose.depth.tiff"``.
    """
    parts = image_name.split("/")
    flat = f"{parts[0].replace('-', '')}_{parts[1]}"
    return flat.replace("color", "pose").replace("png", "depth.tiff")
def project_to_image(p3D, R, t, camera, eps: float = 1e-4, pad: int = 1):
    """Project world points into an image and keep those that land inside it.

    Args:
        p3D: world points, one row per point.
        R, t: world-to-camera rotation and translation (``x_c = R x_w + t``).
        camera: camera record; passed to ``pycolmap.world_to_image`` as a
            dict via ``_asdict()``. Its ``width``/``height`` bound the image.
        eps: minimum camera-frame z for a point to count as in front of the
            camera. It also clips the divisor of the perspective division.
        pad: margin in pixels that a projection must keep from the image
            border.

    Returns:
        ``(p2D[valid], valid)``: pixel coordinates of the kept points, and
        the boolean mask over all input points.
    """
    cam_pts = p3D @ R.T + t
    in_front = cam_pts[:, -1] >= eps
    # Clip the divisor so points behind the camera do not divide by <= 0;
    # they are rejected by `in_front` anyway.
    normalized = cam_pts[:, :-1] / cam_pts[:, -1:].clip(min=eps)
    p2D = np.asarray(
        pycolmap.world_to_image(normalized, camera._asdict())["image_points"]
    )
    upper = np.array([camera.width - pad - 1, camera.height - pad - 1])
    inside = ((p2D >= pad) & (p2D <= upper)).all(axis=-1)
    valid = inside & in_front
    return p2D[valid], valid
def _load_depth_in_meters(depth_file):
    """Load a depth map stored in millimetres and return it in metres (NaN = missing)."""
    # Context manager so the image file handle is released right away.
    with PIL.Image.open(depth_file) as im:
        depth = np.array(im).astype("float64")
    depth = depth / 1000.0  # mm to meter
    depth[(depth == 0.0) | (depth > 1000.0)] = np.nan
    return depth


def _drop_observations(points3D, p3D_ids, image_id):
    """Remove every track element belonging to `image_id` from the given points."""
    for p3did in p3D_ids:
        point = points3D[p3did]
        drop = np.where(point.image_ids == image_id)[0]
        points3D[p3did] = point._replace(
            image_ids=np.delete(point.image_ids, drop),
            point2D_idxs=np.delete(point.point2D_idxs, drop),
        )


def correct_sfm_with_gt_depth(sfm_path, depth_folder_path, output_path):
    """Move SfM points onto positions back-projected from per-image depth maps.

    The function processes each image of the model at `sfm_path` as follows:

    1. Project the image's observed 3D points into it.
    2. Sample the image's depth map at those locations.
    3. Move every point with an available depth to its back-projected
       position.
    4. Remove the remaining observations from both the image
       (`point3D_ids` set to -1) and the point's track. These are points
       that fall outside the image or have no valid depth.

    Images are processed in order. A point that stays valid in several images
    keeps the position from the last of them.

    Args:
        sfm_path: input model directory (read with `read_model`).
        depth_folder_path: directory with depth maps named as returned by
            `image_path_to_rendered_depth_path`. Accepts `str` or `Path`.
        output_path: directory the corrected model is written to. It is
            created if missing. Accepts `str` or `Path`.
    """
    cameras, images, points3D = read_model(sfm_path)
    depth_folder_path = Path(depth_folder_path)
    for imgid, img in tqdm(images.items()):
        p3D_ids = img.point3D_ids
        observed = p3D_ids != -1
        observed_ids = p3D_ids[observed]
        if len(observed_ids) == 0:
            # Nothing to correct; np.stack on an empty list would raise.
            continue

        depth = _load_depth_in_meters(
            depth_folder_path / image_path_to_rendered_depth_path(img.name)
        )
        R_w2c, t_w2c = img.qvec2rotmat(), img.tvec
        camera = cameras[img.camera_id]

        p3Ds = np.stack([points3D[i].xyz for i in observed_ids], 0)
        p2Ds, valids_projected = project_to_image(p3Ds, R_w2c, t_w2c, camera)

        if len(p2Ds) > 0:
            interp_depth, valids_backprojected = interpolate_depth(depth, p2Ds)
            scs = scene_coordinates(
                p2Ds[valids_backprojected],
                R_w2c,
                t_w2c,
                interp_depth[valids_backprojected],
                camera,
            )
        else:
            # No point projected inside the image: every observation is invalid.
            valids_backprojected = np.zeros(0, dtype=bool)
            scs = np.zeros((0, 3))

        # valids[k]: the k-th observed point projected inside the image AND
        # received a valid depth.
        valids = np.zeros(len(observed_ids), dtype=bool)
        valids[valids_projected] = valids_backprojected

        # NOTE(review): points whose track becomes empty stay in points3D.
        _drop_observations(points3D, observed_ids[~valids], img.id)

        new_p3D_ids = p3D_ids.copy()
        new_p3D_ids[observed] = np.where(valids, observed_ids, -1)
        img = img._replace(point3D_ids=new_p3D_ids)

        kept_ids = new_p3D_ids[new_p3D_ids != -1]
        assert len(kept_ids) == len(scs), f"{len(scs)}, {len(kept_ids)}"
        # kept_ids and scs share the order of the observed points.
        for p3did, xyz in zip(kept_ids, scs):
            points3D[p3did] = points3D[p3did]._replace(xyz=xyz)
        images[imgid] = img

    output_path = Path(output_path)
    output_path.mkdir(parents=True, exist_ok=True)
    write_model(cameras, images, points3D, output_path)
if __name__ == "__main__":
    # Local import: only needed when this module runs as a script.
    import argparse

    SCENES = [
        "chess",
        "fire",
        "heads",
        "office",
        "pumpkin",
        "redkitchen",
        "stairs",
    ]
    parser = argparse.ArgumentParser(
        description="Correct 7-Scenes SfM models with ground-truth depth maps."
    )
    # Defaults reproduce the previously hard-coded paths and scene list.
    parser.add_argument(
        "--dataset",
        type=Path,
        default=Path("datasets/7scenes"),
        help="dataset root containing depth/7scenes_<scene>/train/depth",
    )
    parser.add_argument(
        "--outputs",
        type=Path,
        default=Path("outputs/7Scenes"),
        help="root holding <scene>/sfm_superpoint+superglue models",
    )
    parser.add_argument(
        "--scenes",
        nargs="+",
        choices=SCENES,
        default=SCENES,
        help="scenes to process (default: all)",
    )
    args = parser.parse_args()

    for scene in args.scenes:
        sfm_path = args.outputs / scene / "sfm_superpoint+superglue"
        depth_path = args.dataset / f"depth/7scenes_{scene}/train/depth"
        output_path = args.outputs / scene / "sfm_superpoint+superglue+depth"
        correct_sfm_with_gt_depth(sfm_path, depth_path, output_path)