Spaces:
Sleeping
Sleeping
import os | |
import time | |
import torch | |
import argparse | |
import numpy as np | |
import open3d as o3d | |
import os.path as osp | |
from dust3r.losses import L21 | |
from spann3r.model import Spann3R | |
from dust3r.inference import inference | |
from dust3r.utils.geometry import geotrf | |
from dust3r.image_pairs import make_pairs | |
from spann3r.loss import Regr3D_t_ScaleShiftInv | |
from spann3r.datasets import * | |
from torch.utils.data import DataLoader | |
from spann3r.tools.eval_recon import accuracy, completion | |
from spann3r.tools.vis import render_frames, find_render_cam, vis_pred_and_imgs | |
from pose_utils import solve_cemara | |
from backend_utils import improved_multiway_registration, pts2normal, point2mesh, combine_and_clean_point_clouds | |
def get_args_parser():
    """Build the command-line parser for the Spann3R demo script.

    The parser is created with ``add_help=False``, so it does not register
    a ``-h/--help`` option on its own.

    Returns:
        argparse.ArgumentParser: parser exposing ``--save_path``,
        ``--demo_path``, ``--ckpt_path``, ``--scenegraph_type``,
        ``--offline``, ``--device``, ``--conf_thresh``, ``--kf_every``,
        ``--vis`` and ``--voxel_size``.
    """
    parser = argparse.ArgumentParser('Spann3R demo', add_help=False)
    parser.add_argument('--save_path', type=str, default='./output/demo/',
                        help='Path to experiment folder')
    # The help text used to be a copy of --save_path's; this option is the
    # *input* location (a frame folder, or a video that main() converts to
    # frames), not the experiment folder.
    parser.add_argument('--demo_path', type=str, default='./examples/s00567',
                        help='Path to the demo input (frame folder or .mp4/.avi/.webm video)')
    parser.add_argument('--ckpt_path', type=str, default='./checkpoints/spann3r.pth',
                        help='ckpt path')
    parser.add_argument('--scenegraph_type', type=str, default='complete',
                        help='scenegraph type')
    parser.add_argument('--offline', action='store_true',
                        help='offline reconstruction')
    parser.add_argument('--device', type=str, default='cuda:0', help='device')
    parser.add_argument('--conf_thresh', type=float, default=1e-3,
                        help='confidence threshold')
    parser.add_argument('--kf_every', type=int, default=10,
                        help='map every kf_every frames')
    parser.add_argument('--vis', action='store_true', help='visualize')
    # NOTE(review): main() in this file does not read args.voxel_size (it
    # passes a literal voxel size to the registration step) - confirm intent.
    parser.add_argument('--voxel_size', type=float, default=0.004,
                        help='voxel size for multiway registration')
    return parser
import tempfile | |
import subprocess | |
def extract_frames(video_path: str, duration: float = 20.0, fps: float = 3.0) -> str:
    """Extract JPEG frames from the beginning of a video using ffmpeg.

    Frames are written as ``001.jpg``, ``002.jpg``, ... into a freshly
    created temporary directory.

    Args:
        video_path: path of the video file handed to ``ffmpeg -i``.
        duration: frames whose timestamp ``t`` is below this value (in the
            units of ffmpeg's ``t`` variable, i.e. seconds) are kept.
        fps: output frame rate passed to ffmpeg's ``fps`` filter.

    Returns:
        Path of the temporary directory holding the extracted frames. The
        caller owns this directory and is responsible for deleting it.

    Raises:
        OSError: if the ``ffmpeg`` executable cannot be started.
        subprocess.CalledProcessError: if ffmpeg exits with a non-zero code.
    """
    import shutil  # local import: only needed to clean up after a failure

    temp_dir = tempfile.mkdtemp()
    output_path = os.path.join(temp_dir, "%03d.jpg")

    # Keep only frames with t < duration, then resample to the requested rate.
    filter_complex = f"select='if(lt(t,{duration}),1,0)',fps={fps}"
    command = [
        "ffmpeg",
        "-i", video_path,
        "-vf", filter_complex,
        "-vsync", "0",
        output_path,
    ]

    try:
        subprocess.run(command, check=True)
    except (OSError, subprocess.SubprocessError):
        # Do not leak the temporary directory when extraction fails; the
        # caller never receives its path and so could not remove it.
        shutil.rmtree(temp_dir, ignore_errors=True)
        raise
    return temp_dir
def main(args):
    """Run the Spann3R demo: load the model, reconstruct, register and mesh.

    Steps, in order: build the model and load its checkpoint, turn a video
    input into frames if needed, load all views through the ``Demo``
    dataset, run online or offline reconstruction, collect per-frame
    points/normals/confidences, then fuse the per-frame point clouds with
    ``improved_multiway_registration`` and mesh the result with
    ``point2mesh``.

    Args:
        args: namespace produced by ``get_args_parser()``. When
            ``args.demo_path`` is a video file, ``args.demo_path`` is
            overwritten with the extracted-frames directory and
            ``args.kf_every`` is set to 1.
    """
    workspace = args.save_path
    os.makedirs(workspace, exist_ok=True)

    ##### Load model
    model = Spann3R(dus3r_name='./checkpoints/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth',
                    use_feat=False).to(args.device)
    # map_location lets a checkpoint saved on a different device load onto
    # args.device instead of failing when the original device is absent.
    checkpoint = torch.load(args.ckpt_path, map_location=args.device)
    model.load_state_dict(checkpoint['model'])
    model.eval()

    # Video input: extract frames first and use every extracted frame.
    if args.demo_path.lower().endswith(('.mp4', '.avi', '.webm')):
        # NOTE(review): the temporary frame directory returned here is never
        # removed by this function - confirm whether cleanup is wanted.
        args.demo_path = extract_frames(args.demo_path)
        args.kf_every = 1

    ##### Load dataset
    dataset = Demo(ROOT=args.demo_path, resolution=224, full_video=True, kf_every=args.kf_every)
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0)
    # Only the first item of the loader is used as the list of views.
    batch = next(iter(dataloader))

    ##### Inference
    for view in batch:
        view['img'] = view['img'].to(args.device, non_blocking=True)

    # normpath drops a trailing separator so the name is never empty.
    demo_name = osp.basename(osp.normpath(args.demo_path))

    print(f'Started reconstruction for {demo_name}')

    # Pure inference: no autograd graph is needed, and the outputs can then
    # be converted to NumPy without further precautions.
    with torch.no_grad():
        if args.offline:
            imgs_all = []
            for j, view in enumerate(batch):
                img = view['img']
                imgs_all.append(
                    dict(
                        img=img,
                        true_shape=torch.tensor(img.shape[2:]).unsqueeze(0),
                        idx=j,
                        instance=str(j)
                    )
                )
            start = time.time()
            pairs = make_pairs(imgs_all, scene_graph=args.scenegraph_type,
                               prefilter=None, symmetrize=True)
            output = inference(pairs, model.dust3r, args.device, batch_size=2, verbose=True)
            preds, preds_all, idx_used = model.offline_reconstruction(batch, output)
            end = time.time()
            # Predictions follow the order in idx_used, so reorder the views.
            ordered_batch = [batch[i] for i in idx_used]
        else:
            start = time.time()
            preds, preds_all = model.forward(batch)
            end = time.time()
            ordered_batch = batch

    fps = len(batch) / (end - start)
    print(f'Finished reconstruction for {demo_name}, FPS: {fps:.2f}')

    ##### Save results
    # NOTE(review): the directory is created, but nothing in this function
    # writes the point cloud or mesh into it - confirm whether saving is
    # expected here.
    save_demo_path = osp.join(workspace, demo_name)
    os.makedirs(save_demo_path, exist_ok=True)

    pts_all = []
    pts_normal_all = []
    pts_gt_all = []
    pts_scale_all = []  # was appended to without ever being defined (NameError)
    images_all = []
    masks_all = []
    conf_sig_all = []
    cameras_all = []
    last_focal = None

    for j, view in enumerate(ordered_batch):
        image = view['img'].permute(0, 2, 3, 1).cpu().numpy()[0]
        mask = view['valid_mask'].cpu().numpy()[0]

        # The first prediction stores its points under 'pts3d'; all later
        # ones use 'pts3d_in_other_view'.
        pts_key = 'pts3d' if j == 0 else 'pts3d_in_other_view'
        pts_tensor = preds[j][pts_key]
        pts = pts_tensor.detach().cpu().numpy()[0]
        pts_normal = pts2normal(pts_tensor[0]).detach().cpu().numpy()

        conf = preds[j]['conf'][0].detach().cpu().numpy()
        conf_sig = (conf - 1) / conf
        pts_gt = view['pts3d'].cpu().numpy()[0]

        # The focal returned for this frame is fed back in for the next one.
        camera, last_focal, depth_map = solve_cemara(
            torch.tensor(pts), torch.tensor(conf_sig) > args.conf_thresh,
            args.device, focal=last_focal)
        pts_scale = depth_map / last_focal

        # Images are shifted from [-1, 1] to [0, 1] so they can serve as colors.
        images_all.append((image[None, ...] + 1.0) / 2.0)
        pts_all.append(pts[None, ...])
        pts_normal_all.append(pts_normal[None, ...])
        pts_gt_all.append(pts_gt[None, ...])
        pts_scale_all.append(pts_scale[None, ...])
        masks_all.append(mask[None, ...])
        conf_sig_all.append(conf_sig[None, ...])
        cameras_all.append(camera)

    images_all = np.concatenate(images_all, axis=0)
    pts_all = np.concatenate(pts_all, axis=0)
    pts_normal_all = np.concatenate(pts_normal_all, axis=0)
    pts_gt_all = np.concatenate(pts_gt_all, axis=0)
    masks_all = np.concatenate(masks_all, axis=0)
    conf_sig_all = np.concatenate(conf_sig_all, axis=0)

    # Create point clouds for multiway registration
    pcds = []
    for frame_pts, frame_rgb, frame_normals, frame_conf in zip(
            pts_all, images_all, pts_normal_all, conf_sig_all):
        # Keep only the points whose confidence exceeds the threshold.
        keep = frame_conf > args.conf_thresh
        pcd = o3d.geometry.PointCloud()
        pcd.points = o3d.utility.Vector3dVector(frame_pts[keep])
        pcd.colors = o3d.utility.Vector3dVector(frame_rgb[keep])
        pcd.normals = o3d.utility.Vector3dVector(frame_normals[keep])
        pcds.append(pcd)

    print("Performing global registration...")
    # NOTE(review): the voxel size is a literal here and args.voxel_size is
    # ignored - confirm which value is intended before wiring the CLI option.
    pcd_combined, _, _ = improved_multiway_registration(pcds, voxel_size=0.001)
    # pcd_combined = combine_and_clean_point_clouds(transformed_pcds, voxel_size=args.voxel_size * 0.1)

    mesh_recon = point2mesh(pcd_combined)
if __name__ == '__main__':
    # Script entry point: parse the command line and run the demo.
    cli_args = get_args_parser().parse_args()
    main(cli_args)