# mini_dpvo/api/inference.py
import os
from dataclasses import dataclass
from multiprocessing import Process, Queue
from pathlib import Path
from timeit import default_timer as timer

import cv2
import mmcv
import numpy as np
import rerun as rr
import torch
from jaxtyping import Float32, Float64, UInt8
from scipy.spatial.transform import Rotation
from tqdm import tqdm
from yacs.config import CfgNode

from mini_dust3r.api import OptimizedResult, inferece_dust3r
from mini_dust3r.model import AsymmetricCroCo3DStereo

from mini_dpvo.dpvo import DPVO
from mini_dpvo.stream import image_stream, video_stream
from mini_dpvo.utils import Timer


@dataclass
class DPVOPrediction:
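    """Final outputs of a DPVO run: 7-DoF keyframe poses (tx ty tz, qx qy qz qw)
    with timestamps, plus the accumulated point cloud and per-patch colors."""
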
final_poses: Float32[torch.Tensor, "num_keyframes 7"] # noqa: F722
tstamps: Float64[torch.Tensor, "num_keyframes"] # noqa: F821
final_points: Float32[torch.Tensor, "buffer_size*num_patches 3"] # noqa: F722
    final_colors: UInt8[torch.Tensor, "buffer_size num_patches 3"]  # noqa: F722


def log_trajectory(
parent_log_path: Path,
poses: Float32[torch.Tensor, "buffer_size 7"], # noqa: F722
points: Float32[torch.Tensor, "buffer_size*num_patches 3"], # noqa: F722
colors: UInt8[torch.Tensor, "buffer_size num_patches 3"], # noqa: F722
intri_np: Float64[np.ndarray, "4"],
bgr_hw3: UInt8[np.ndarray, "h w 3"], # noqa: F722
path_list: list,
jpg_quality: int = 90,
):
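    """Log the newest camera pose, its image, the accumulated trajectory path,
    and an outlier-filtered point cloud to rerun; returns the updated path_list."""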
cam_log_path = f"{parent_log_path}/camera"
rgb_hw3 = mmcv.bgr2rgb(bgr_hw3)
rr.log(
f"{cam_log_path}/pinhole/image",
rr.Image(rgb_hw3).compress(jpeg_quality=jpg_quality),
)
rr.log(
f"{cam_log_path}/pinhole",
rr.Pinhole(
height=bgr_hw3.shape[0],
width=bgr_hw3.shape[1],
focal_length=[intri_np[0], intri_np[1]],
principal_point=[intri_np[2], intri_np[3]],
),
)
    # mask out uninitialized buffer entries (all-zero poses/points)
    poses_mask = ~(poses[:, :6] == 0).all(dim=1)
    points_mask = ~(points == 0).all(dim=1)
    nonzero_poses = poses[poses_mask]
    nonzero_points = points[points_mask]

    # most recent valid pose, stored as (tx ty tz, qx qy qz qw)
    last_index = nonzero_poses.shape[0] - 1
    quat_pose = nonzero_poses[last_index].numpy(force=True)
trans_quat: Float32[np.ndarray, "3"] = quat_pose[:3]
rotation_quat = Rotation.from_quat(quat_pose[3:])
cam_R_world: Float64[np.ndarray, "3 3"] = rotation_quat.as_matrix()
cam_T_world = np.eye(4)
cam_T_world[:3, :3] = cam_R_world
cam_T_world[0:3, 3] = trans_quat
world_T_cam = np.linalg.inv(cam_T_world)
path_list.append(world_T_cam[:3, 3].copy().tolist())
rr.log(
f"{cam_log_path}",
rr.Transform3D(
translation=world_T_cam[:3, 3],
mat3x3=world_T_cam[:3, :3],
from_parent=False,
),
)
# log path using linestrip
rr.log(
f"{parent_log_path}/path",
rr.LineStrips3D(
strips=[
path_list,
],
colors=[255, 0, 0],
),
)
    # outlier removal: keep only points within 5x the farthest camera's
    # distance from the median trajectory center
    trajectory_center = np.median(nonzero_poses[:, :3].numpy(force=True), axis=0)

    def radii(a):
        return np.linalg.norm(a - trajectory_center, axis=1)

    points_np = nonzero_points.view(-1, 3).numpy(force=True)
    colors_np = colors.view(-1, 3)[points_mask].numpy(force=True)
    inlier_mask = (
        radii(points_np) < radii(nonzero_poses[:, :3].numpy(force=True)).max() * 5
    )
points_filtered = points_np[inlier_mask]
colors_filtered = colors_np[inlier_mask]
# log all points and colors at the same time
rr.log(
f"{parent_log_path}/pointcloud",
rr.Points3D(
positions=points_filtered,
colors=colors_filtered,
),
)
    return path_list


def log_final(
parent_log_path: Path,
final_poses: Float32[torch.Tensor, "num_keyframes 7"],
tstamps: Float64[torch.Tensor, "num_keyframes"], # noqa: F821
final_points: Float32[torch.Tensor, "buffer_size*num_patches 3"],
final_colors: UInt8[torch.Tensor, "buffer_size num_patches 3"],
):
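    """Log every final keyframe pose as its own camera entity in rerun."""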
    for idx, (pose_quat, _tstamp) in enumerate(zip(final_poses, tstamps)):
cam_log_path = f"{parent_log_path}/camera_{idx}"
trans_quat = pose_quat[:3]
R_33 = Rotation.from_quat(pose_quat[3:]).as_matrix()
rr.log(
f"{cam_log_path}",
rr.Transform3D(translation=trans_quat, mat3x3=R_33, from_parent=False),
        )


def create_reader(
imagedir: str, calib: str | None, stride: int, skip: int, queue: Queue
) -> Process:
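    """Spawn a background reader process: image_stream for a directory of
    frames, video_stream for a video file. Frames are pushed onto `queue`."""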
if os.path.isdir(imagedir):
reader = Process(
target=image_stream, args=(queue, imagedir, calib, stride, skip)
)
else:
reader = Process(
target=video_stream, args=(queue, imagedir, calib, stride, skip)
)
    return reader


def calculate_num_frames(video_or_image_dir: str, stride: int, skip: int) -> int:
    """Count usable frames in a directory or video, accounting for skip and stride."""
    total_frames = 0
if os.path.isdir(video_or_image_dir):
total_frames = len(
[
name
for name in os.listdir(video_or_image_dir)
if os.path.isfile(os.path.join(video_or_image_dir, name))
]
)
else:
cap = cv2.VideoCapture(video_or_image_dir)
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
cap.release()
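    # drop the first `skip` frames, then sample every `stride`-th frame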
total_frames = (total_frames - skip) // stride
    return total_frames


def calib_from_dust3r(
bgr_hw3: UInt8[np.ndarray, "height width 3"],
model: AsymmetricCroCo3DStereo,
device: str,
) -> Float64[np.ndarray, "3 3"]:
"""
Calculates the calibration matrix from mini-dust3r.
Args:
bgr_hw3: The input image in BGR format with shape (height, width, 3).
model: The Dust3D-R model used for inference.
device: The device to run the inference on.
Returns:
The calibration matrix with shape (3, 3).
Raises:
None.
"""
tmp_path = Path("/tmp/dpvo/tmp.png")
# save image
mmcv.imwrite(bgr_hw3, str(tmp_path))
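    # run mini-dust3r on the temp directory (assumed to contain only this frame)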
optimized_results: OptimizedResult = inferece_dust3r(
image_dir_or_list=tmp_path.parent,
model=model,
device=device,
batch_size=1,
)
# DELETE tmp file
tmp_path.unlink()
# get predicted intrinsics in original image size
downscaled_h, downscaled_w, _ = optimized_results.rgb_hw3_list[0].shape
orig_h, orig_w, _ = bgr_hw3.shape
# Scaling factors
scaling_factor_x = orig_w / downscaled_w
scaling_factor_y = orig_h / downscaled_h
# Scale the intrinsic matrix to the original image size
K_33_original = optimized_results.K_b33[0].copy()
K_33_original[0, 0] *= scaling_factor_x # fx
K_33_original[1, 1] *= scaling_factor_y # fy
K_33_original[0, 2] *= scaling_factor_x # cx
K_33_original[1, 2] *= scaling_factor_y # cy
    return K_33_original


@torch.no_grad()
def inference_dpvo(
cfg: CfgNode,
network_path: str,
imagedir: str,
    calib: str | None,
stride: int = 1,
skip: int = 0,
timeit: bool = False,
) -> tuple[DPVOPrediction, float]:
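    """Run DPVO over a video file or image directory, streaming intermediate
    results to rerun. If `calib` is None, intrinsics are estimated from the
    first frame with mini-dust3r. Returns the prediction and runtime in seconds."""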
slam = None
queue = Queue(maxsize=8)
reader: Process = create_reader(imagedir, calib, stride, skip, queue)
reader.start()
parent_log_path = Path("world")
rr.log(f"{parent_log_path}", rr.ViewCoordinates.RDF, timeless=True)
start = timer()
total_frames = calculate_num_frames(imagedir, stride, skip)
# estimate camera intrinsics if not provided
if calib is None:
dust3r_device = (
"mps"
if torch.backends.mps.is_available()
else "cuda"
if torch.cuda.is_available()
else "cpu"
)
dust3r_model = AsymmetricCroCo3DStereo.from_pretrained(
"naver/DUSt3R_ViTLarge_BaseDecoder_512_dpt"
).to(dust3r_device)
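        # consume the first frame from the reader queue to estimate intrinsics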
_, bgr_hw3, _ = queue.get()
K_33_pred = calib_from_dust3r(bgr_hw3, dust3r_model, dust3r_device)
intri_np_dust3r = np.array(
[K_33_pred[0, 0], K_33_pred[1, 1], K_33_pred[0, 2], K_33_pred[1, 2]]
)
# path list for visualizing the trajectory
path_list = []
with tqdm(total=total_frames, desc="Processing Frames") as pbar:
while True:
t: int
bgr_hw3: UInt8[np.ndarray, "h w 3"]
intri_np: Float64[np.ndarray, "4"]
(t, bgr_hw3, intri_np_calib) = queue.get()
intri_np = intri_np_calib if calib is not None else intri_np_dust3r
# queue will have a (-1, image, intrinsics) tuple when the reader is done
if t < 0:
break
rr.set_time_sequence(timeline="timestep", sequence=t)
            bgr_3hw: UInt8[torch.Tensor, "3 h w"] = (
                torch.from_numpy(bgr_hw3).permute(2, 0, 1).cuda()
            )
intri_torch: Float64[torch.Tensor, "4"] = torch.from_numpy(intri_np).cuda()
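            # lazily construct the SLAM system once the frame size is known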
if slam is None:
slam = DPVO(cfg, network_path, ht=bgr_3hw.shape[1], wd=bgr_3hw.shape[2])
with Timer("SLAM", enabled=timeit):
slam(t, bgr_3hw, intri_torch)
if slam.is_initialized:
poses: Float32[torch.Tensor, "buffer_size 7"] = slam.poses_
points: Float32[torch.Tensor, "buffer_size*num_patches 3"] = (
slam.points_
)
colors: UInt8[torch.Tensor, "buffer_size num_patches 3"] = slam.colors_
path_list = log_trajectory(
parent_log_path=parent_log_path,
poses=poses,
points=points,
colors=colors,
intri_np=intri_np,
bgr_hw3=bgr_hw3,
path_list=path_list,
)
pbar.update(1)
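    # extra optimizer iterations to refine poses and points before termination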
for _ in range(12):
slam.update()
total_time: float = timer() - start
print(f"Total time: {total_time:.2f}s")
reader.join()
final_poses: Float32[torch.Tensor, "num_keyframes 7"]
tstamps: Float64[torch.Tensor, "num_keyframes"] # noqa: F821
final_poses, tstamps = slam.terminate()
final_points: Float32[torch.Tensor, "buffer_size*num_patches 3"] = slam.points_
final_colors: UInt8[torch.Tensor, "buffer_size num_patches 3"] = slam.colors_
dpvo_pred = DPVOPrediction(
final_poses=final_poses,
tstamps=tstamps,
final_points=final_points,
final_colors=final_colors,
)
return dpvo_pred, total_time
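

if __name__ == "__main__":
    # Minimal usage sketch. The config and checkpoint paths below are
    # assumptions for illustration; point them at your own DPVO config
    # YAML, pretrained weights, and input video or image directory.
    rr.init("mini_dpvo", spawn=True)
    with open("config/default.yaml") as f:  # hypothetical config path
        cfg = CfgNode.load_cfg(f)
    dpvo_pred, total_time = inference_dpvo(
        cfg=cfg,
        network_path="checkpoints/dpvo.pth",  # hypothetical checkpoint path
        imagedir="assets/example.mp4",  # video file or directory of frames
        calib=None,  # None -> estimate intrinsics with mini-dust3r
    )
    print(f"Reconstructed {dpvo_pred.final_poses.shape[0]} keyframes")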