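"""Evaluation script for Spann3R.

Runs online (incremental) or offline reconstruction on the 7-Scenes,
Neural RGB-D and DTU test splits, saves per-scene point clouds, and logs
accuracy, completion and normal-consistency metrics.
"""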
import os
import time
import torch
import argparse
import numpy as np
import open3d as o3d
import os.path as osp
from dust3r.losses import L21
from spann3r.model import Spann3R
from dust3r.inference import inference
from dust3r.utils.geometry import geotrf
from dust3r.image_pairs import make_pairs
from spann3r.loss import Regr3D_t_ScaleShiftInv
from spann3r.datasets import *
from torch.utils.data import DataLoader
from spann3r.tools.eval_recon import accuracy, completion
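# NOTE: the star import from spann3r.datasets provides the SevenScenes,
# NRGBD and DTU dataset classes instantiated in main().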


def get_args_parser():
    parser = argparse.ArgumentParser('Spann3R evaluation', add_help=False)
    parser.add_argument('--exp_path', type=str, help='path to experiment folder', default='./checkpoints')
    parser.add_argument('--exp_name', type=str, default='ckpt_best', help='experiment name (output subfolder)')
    parser.add_argument('--ckpt', type=str, default='spann3r.pth', help='checkpoint file name')
    parser.add_argument('--scenegraph_type', type=str, default='complete', help='scene graph type for offline pairing')
    parser.add_argument('--offline', action='store_true', help='offline reconstruction')
    parser.add_argument('--device', type=str, default='cuda:0', help='device')
    parser.add_argument('--conf_thresh', type=float, default=0.0, help='confidence threshold')
    return parser
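
# Example invocation (the file name eval.py is an assumption; substitute the
# actual script path):
#   python eval.py --exp_path ./checkpoints --ckpt spann3r.pth            # online
#   python eval.py --exp_path ./checkpoints --ckpt spann3r.pth --offline  # offline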


def main(args):
    workspace = args.exp_path
    ckpt_path = osp.join(workspace, args.ckpt)

    if not osp.exists(workspace):
        raise FileNotFoundError(f"Workspace {workspace} not found")

    exp_path = osp.join(workspace, args.exp_name)
    os.makedirs(exp_path, exist_ok=True)
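
    # Benchmark datasets: full_video=True walks each test sequence end to
    # end, and kf_every sets the keyframe sampling stride (every 20th frame
    # for 7-Scenes, 40th for Neural RGB-D, 5th for DTU).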
    datasets_all = {
        '7scenes': SevenScenes(split='test', ROOT="./data/7scenes",
                               resolution=224, num_seq=1, full_video=True, kf_every=20),
        'NRGBD': NRGBD(split='test', ROOT="./data/neural_rgbd",
                       resolution=224, num_seq=1, full_video=True, kf_every=40),
        'DTU': DTU(split='test', ROOT="./data/dtu_test",
                   resolution=224, num_seq=1, full_video=True, kf_every=5),
    }

    model = Spann3R(dus3r_name='./checkpoints/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth',
                    use_feat=False).to(args.device)
    model.load_state_dict(torch.load(ckpt_path)['model'])
    model.eval()

    criterion = Regr3D_t_ScaleShiftInv(L21, norm_mode=False, gt_scale=True)
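    # The criterion is not used as a loss here; its get_all_pts3d_t() helper
    # (called below) brings predicted and GT point maps to a comparable
    # scale/shift before the reconstruction metrics are computed.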

    with torch.no_grad():
        for name_data, dataset in datasets_all.items():
            save_path = osp.join(exp_path, name_data)
            if args.offline:
                save_path = save_path + '_offline'
            os.makedirs(save_path, exist_ok=True)
            log_file = osp.join(save_path, 'logs.txt')

            acc_all = 0
            acc_all_med = 0
            comp_all = 0
            comp_all_med = 0
            nc1_all = 0
            nc1_all_med = 0
            nc2_all = 0
            nc2_all_med = 0
            fps_all = []
            time_all = []

            dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0)

            for i, batch in enumerate(dataloader):
                # Move every tensor the model consumes onto the eval device
                for view in batch:
                    for name in 'img pts3d valid_mask camera_pose camera_intrinsics F_matrix corres'.split():  # pseudo_focal
                        if name not in view:
                            continue
                        view[name] = view[name].to(args.device, non_blocking=True)

                print(f'Started reconstruction for {name_data} {i+1}/{len(dataloader)}')

                if args.offline:
                    # Offline: run DUSt3R over the image pairs of the scene
                    # graph, then fuse the pairwise predictions globally.
                    imgs_all = []
                    for j, view in enumerate(batch):
                        img = view['img']
                        imgs_all.append(
                            dict(
                                img=img,
                                true_shape=torch.tensor(img.shape[2:]).unsqueeze(0),
                                idx=j,
                                instance=str(j)
                            )
                        )

                    start = time.time()
                    pairs = make_pairs(imgs_all, scene_graph=args.scenegraph_type, prefilter=None, symmetrize=True)
                    output = inference(pairs, model.dust3r, args.device, batch_size=2, verbose=True)
                    preds, preds_all, idx_used = model.offline_reconstruction(batch, output)
                    end = time.time()

                    # Keep views in the order the offline pass actually used
                    ordered_batch = [batch[k] for k in idx_used]
                else:
                    # Online: Spann3R processes the sequence incrementally
                    start = time.time()
                    preds, preds_all = model.forward(batch)
                    end = time.time()
                    ordered_batch = batch
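
                # Timing covers the reconstruction pass only; the metric
                # evaluation below is excluded from the reported FPS.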
                fps = len(batch) / (end - start)
                print(f'Finished reconstruction for {name_data} {i+1}/{len(dataloader)}, FPS: {fps:.2f}')

                fps_all.append(fps)
                time_all.append(end - start)

                # Evaluation
                print(f'Evaluation for {name_data} {i+1}/{len(dataloader)}')
                gt_pts, pred_pts, gt_factor, pr_factor, masks, monitoring = criterion.get_all_pts3d_t(ordered_batch, preds_all)
                pred_scale, gt_scale, pred_shift_z, gt_shift_z = monitoring['pred_scale'], monitoring['gt_scale'], monitoring['pred_shift_z'], monitoring['gt_shift_z']

                in_camera1 = None
                pts_all = []
                pts_gt_all = []
                images_all = []
                masks_all = []
                conf_all = []

                for j, view in enumerate(ordered_batch):
                    # Use the first view's camera pose as the common reference frame
                    if in_camera1 is None:
                        in_camera1 = view['camera_pose'][0].cpu()

                    image = view['img'].permute(0, 2, 3, 1).cpu().numpy()[0]
                    mask = view['valid_mask'].cpu().numpy()[0]

                    # pts = preds[j]['pts3d' if j==0 else 'pts3d_in_other_view'].detach().cpu().numpy()[0]
                    pts = pred_pts[0][j].cpu().numpy()[0] if j < len(pred_pts[0]) else pred_pts[1][-1].cpu().numpy()[0]
                    conf = preds[j]['conf'][0].cpu().data.numpy()
                    pts_gt = gt_pts[j].detach().cpu().numpy()[0]

                    #### Align predicted 3D points to the ground truth
                    pts[..., -1] += gt_shift_z.cpu().numpy().item()
                    pts = geotrf(in_camera1, pts)

                    pts_gt[..., -1] += gt_shift_z.cpu().numpy().item()
                    pts_gt = geotrf(in_camera1, pts_gt)

                    # Images come in [-1, 1]; map to [0, 1] for use as colors
                    images_all.append((image[None, ...] + 1.0) / 2.0)
                    pts_all.append(pts[None, ...])
                    pts_gt_all.append(pts_gt[None, ...])
                    masks_all.append(mask[None, ...])
                    conf_all.append(conf[None, ...])

                images_all = np.concatenate(images_all, axis=0)
                pts_all = np.concatenate(pts_all, axis=0)
                pts_gt_all = np.concatenate(pts_gt_all, axis=0)
                masks_all = np.concatenate(masks_all, axis=0)
                conf_all = np.concatenate(conf_all, axis=0)

                scene_id = view['label'][0].rsplit('/', 1)[0]

                save_params = {}
                save_params['images_all'] = images_all
                save_params['pts_all'] = pts_all
                save_params['pts_gt_all'] = pts_gt_all
                save_params['masks_all'] = masks_all
                save_params['conf_all'] = conf_all

                np.save(os.path.join(save_path, f"{scene_id.replace('/', '_')}.npy"), save_params)
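                # The dict can be reloaded with, e.g.:
                #   params = np.load(path_to_npy, allow_pickle=True).item()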

                # Max correspondence distance for ICP: DTU ground truth is in
                # millimetres, the other datasets are metric
                if 'DTU' in name_data:
                    threshold = 100
                else:
                    threshold = 0.1

                pts_all_masked = pts_all[masks_all > 0]
                pts_gt_all_masked = pts_gt_all[masks_all > 0]
                images_all_masked = images_all[masks_all > 0]

                pcd = o3d.geometry.PointCloud()
                pcd.points = o3d.utility.Vector3dVector(pts_all_masked.reshape(-1, 3))
                pcd.colors = o3d.utility.Vector3dVector(images_all_masked.reshape(-1, 3))
                o3d.io.write_point_cloud(os.path.join(save_path, f"{scene_id.replace('/', '_')}-mask.ply"), pcd)

                pcd_gt = o3d.geometry.PointCloud()
                pcd_gt.points = o3d.utility.Vector3dVector(pts_gt_all_masked.reshape(-1, 3))
                # Colors were already mapped to [0, 1] above; no /255 rescale
                pcd_gt.colors = o3d.utility.Vector3dVector(images_all_masked.reshape(-1, 3))
                o3d.io.write_point_cloud(os.path.join(save_path, f"{scene_id.replace('/', '_')}-gt.ply"), pcd_gt)

                # Align the predicted cloud to GT with point-to-point ICP
                # before computing metrics (identity initialisation; the
                # threshold is the max correspondence distance)
                trans_init = np.eye(4)
                reg_p2p = o3d.pipelines.registration.registration_icp(
                    pcd, pcd_gt, threshold, trans_init,
                    o3d.pipelines.registration.TransformationEstimationPointToPoint())
                transformation = reg_p2p.transformation
                pcd = pcd.transform(transformation)

                pcd.estimate_normals()
                pcd_gt.estimate_normals()

                gt_normal = np.asarray(pcd_gt.normals)
                pred_normal = np.asarray(pcd.normals)
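
                # accuracy() and completion() (from spann3r.tools.eval_recon)
                # compute nearest-neighbour distances pred->GT and GT->pred
                # respectively; NC1/NC2 are the corresponding normal-consistency
                # scores, each reported as mean and median.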
                acc, acc_med, nc1, nc1_med = accuracy(pcd_gt.points, pcd.points, gt_normal, pred_normal)
                comp, comp_med, nc2, nc2_med = completion(pcd_gt.points, pcd.points, gt_normal, pred_normal)

                print(f"Idx: {scene_id}, Acc: {acc}, Comp: {comp}, NC1: {nc1}, NC2: {nc2} "
                      f"- Acc_med: {acc_med}, Comp_med: {comp_med}, NC1_med: {nc1_med}, NC2_med: {nc2_med}",
                      file=open(log_file, "a"))

                acc_all += acc
                comp_all += comp
                nc1_all += nc1
                nc2_all += nc2

                acc_all_med += acc_med
                comp_all_med += comp_med
                nc1_all_med += nc1_med
                nc2_all_med += nc2_med

                # release cuda memory
                torch.cuda.empty_cache()

                print(f"Finished evaluation for {name_data} {i+1}/{len(dataloader)}")

            # Get depth from pcd and run TSDFusion
            print(f"Dataset: {name_data}, Accuracy: {acc_all/len(dataloader)}, Completion: {comp_all/len(dataloader)}, "
                  f"NC1: {nc1_all/len(dataloader)}, NC2: {nc2_all/len(dataloader)} "
                  f"- Acc_med: {acc_all_med/len(dataloader)}, Comp_med: {comp_all_med/len(dataloader)}, "
                  f"NC1_med: {nc1_all_med/len(dataloader)}, NC2_med: {nc2_all_med/len(dataloader)}",
                  file=open(log_file, "a"))
            print(f"Average fps: {sum(fps_all) / len(fps_all)}, Average time: {sum(time_all) / len(time_all)}",
                  file=open(log_file, "a"))


if __name__ == '__main__':
    parser = get_args_parser()
    args = parser.parse_args()
    main(args)