File size: 7,843 Bytes
91126af |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 |
import sys
import subprocess
import logging
import numpy as np
from distutils.util import strtobool
# from ace_trainer import TrainerACE
_logger = logging.getLogger(__name__)
import argparse
from pathlib import Path
from types import SimpleNamespace
# from dataset import CamLocDataset
import torch
import random
from ace_visualizer import ACEVisualizer
# from ace_network import Regressor
from torch.utils.data import DataLoader
import os
from ace_util import load_npz_file
import time
import re
import numpy as np
# import dsacstar
from collections import namedtuple
# import dataset_io
import pickle
import glob
# Spellings accepted by the (deprecated) distutils.util.strtobool.
_TRUE_STRINGS = frozenset({'y', 'yes', 't', 'true', 'on', '1'})
_FALSE_STRINGS = frozenset({'n', 'no', 'f', 'false', 'off', '0'})


def _strtobool(x):
    """Convert a truth-value string such as ``'yes'`` or ``'0'`` to a bool.

    Accepts the same spellings as ``distutils.util.strtobool`` (case-insensitive),
    but does not depend on ``distutils``, which is deprecated and removed in
    Python 3.12 (PEP 632).

    Args:
        x: string to interpret, e.g. a command-line flag value.

    Returns:
        True for y/yes/t/true/on/1, False for n/no/f/false/off/0.

    Raises:
        ValueError: if ``x`` is not one of the recognised spellings.
    """
    value = x.lower()
    if value in _TRUE_STRINGS:
        return True
    if value in _FALSE_STRINGS:
        return False
    raise ValueError(f"invalid truth value {value!r}")
def get_seed_id(seed_idx):
    """Return the identifier ``iteration0_seed<seed_idx>`` for a seed index."""
    return "iteration0_seed{}".format(seed_idx)
def get_render_path(out_dir):
    """Return the ``renderings`` sub-path of ``out_dir`` (the directory is not created)."""
    render_subdir = "renderings"
    render_path = out_dir / render_subdir
    return render_path
def get_register_opt(
    rgb_files=None,
    hypotheses=64,
    hypotheses_max_tries=1000000,
    threshold=10.0,
    inlieralpha=100.0,
    maxpixelerror=100.0,
    render_visualization=False,
    render_target_path='renderings',
    render_flipped_portrait=False,
    render_pose_conf_threshold=5000,
    render_map_depth_filter=10,
    render_camera_z_offset=4,
    base_seed=1305,
    confidence_threshold=1000.0,
    max_estimates=-1,
    render_marker_size=0.03,
    result_npz=None,
    results_folder="result_folder_old_test_raw",
    only_frustum=False,
    pan_start_angle=None,
    pan_radius_scale=None,
    state_dict=None,
):
    """Build the options namespace consumed by ``regitser_visulization``.

    Args:
        rgb_files: glob pattern for the query RGB images. Required.
        hypotheses, hypotheses_max_tries, threshold, inlieralpha, maxpixelerror:
            RANSAC settings for the dsacstar pose estimation, which is commented
            out in ``regitser_visulization``; stored but not read there.
        render_visualization: stored as-is; not read by ``regitser_visulization``.
        render_target_path: output directory for renderings (stored as ``Path``).
        render_flipped_portrait, render_map_depth_filter, render_marker_size,
        render_pose_conf_threshold, confidence_threshold: forwarded to
            ``ACEVisualizer``.
        render_camera_z_offset: forwarded to ``setup_reloc_visualisation``.
        base_seed: seed for torch, numpy and ``random``.
        max_estimates: stop after this many frames; values <= 0 mean no limit.
        result_npz: path of the npz file with camera poses and intrinsics.
            Required.
        results_folder: stored as a ``Path``; not read by
            ``regitser_visulization``.
        only_frustum: passed to ``setup_reloc_visualisation``; when True the
            trajectory frustums are cleared after rendering.
        pan_start_angle, pan_radius_scale: forwarded to ``ACEVisualizer``.
            Attached only when not None, so no default is invented here.
        state_dict: e.g. the dict returned by a previous
            ``regitser_visulization`` call (its ``'frame_idx'`` is read).
            Attached only when not None, because the mere presence of the
            attribute selects the resume path.

    Returns:
        ``types.SimpleNamespace`` holding the options above.

    Raises:
        ValueError: if ``rgb_files`` or ``result_npz`` is missing.
    """
    if rgb_files is None:
        raise ValueError("rgb_files is required")
    if result_npz is None:
        raise ValueError("result_npz is required")
    opt = SimpleNamespace(
        rgb_files=rgb_files,
        hypotheses=hypotheses,
        hypotheses_max_tries=hypotheses_max_tries,
        threshold=threshold,
        inlieralpha=inlieralpha,
        maxpixelerror=maxpixelerror,
        render_visualization=render_visualization,
        render_target_path=Path(render_target_path),
        render_flipped_portrait=render_flipped_portrait,
        render_pose_conf_threshold=render_pose_conf_threshold,
        render_map_depth_filter=render_map_depth_filter,
        render_camera_z_offset=render_camera_z_offset,
        base_seed=base_seed,
        confidence_threshold=confidence_threshold,
        max_estimates=max_estimates,
        render_marker_size=render_marker_size,
        result_npz=result_npz,
        results_folder=Path(results_folder),
        # Read unconditionally by regitser_visulization, so always present.
        only_frustum=only_frustum,
    )
    optional_settings = {
        'pan_start_angle': pan_start_angle,
        'pan_radius_scale': pan_radius_scale,
        'state_dict': state_dict,
    }
    for attr_name, attr_value in optional_settings.items():
        if attr_value is not None:
            setattr(opt, attr_name, attr_value)
    return opt
# Record for one per-frame pose estimate handed to the renderer.
_TestEstimate = namedtuple("TestEstimate", [
    "pose_est",
    "pose_gt",
    "focal_length",
    "confidence",
    "image_file",
])

# Extracts the frame index from file names containing ``_<digits>.png``.
_FRAME_INDEX_RE = re.compile(r'_(\d+)\.png')

# Confidence attached to every estimate. Poses are read straight from the result
# npz (no RANSAC scoring happens here), so this is a fixed value, not a score.
_FIXED_CONFIDENCE = 10000

# Number of render_reloc_frame calls issued for each estimate.
_RENDERS_PER_ESTIMATE = 10


def _frame_index_from_path(frame_path):
    """Return the integer index parsed from ``frame_path``'s file name, or None."""
    match = _FRAME_INDEX_RE.search(os.path.basename(frame_path))
    return int(match.group(1)) if match else None


def _collect_indexed_frames(rgb_files_pattern):
    """Glob ``rgb_files_pattern`` and return ``(index, path)`` pairs sorted by index.

    Files without a ``_<digits>.png`` index are skipped with a warning: without an
    index there is no row of the result arrays to pair them with.
    """
    indexed_frames = []
    for frame_path in glob.glob(rgb_files_pattern):
        img_idx = _frame_index_from_path(frame_path)
        if img_idx is None:
            _logger.warning(f"No number found in {frame_path}; skipping it.")
            continue
        indexed_frames.append((img_idx, frame_path))
    # glob.glob returns paths in arbitrary order; render frames in index order.
    indexed_frames.sort()
    return indexed_frames


def regitser_visulization(opt):
    """Render relocalisation frames for precomputed poses and return the visualiser state.

    Each RGB file matched by the glob ``opt.rgb_files`` must have a file name
    containing ``_<index>.png``. For each such file, the pose ``cam_poses[index]``
    and intrinsics ``intrinsic[index]`` are taken from
    ``load_npz_file(opt.result_npz)``. The frames are then rendered through
    ``ACEVisualizer`` in index order. After all frames are rendered, the position
    markers are reset and the growing map is rendered.

    Args:
        opt: options namespace, e.g. from ``get_register_opt``.
            - ``only_frustum`` is treated as False when absent.
            - ``pan_start_angle`` and ``pan_radius_scale`` are forwarded to
              ``ACEVisualizer`` only when present.
            - A non-None ``state_dict`` attribute resumes the visualisation from
              that state.

    Returns:
        dict with ``frame_idx``, ``camera_buffer``, ``pan_cameras``, ``map_xyz``
        and ``map_clr`` taken from the visualiser. It carries ``frame_idx``, so it
        is presumably meant to be fed back as ``opt.state_dict`` on a later call;
        confirm with the caller.
    """
    # Seed every RNG so renders are reproducible.
    torch.manual_seed(opt.base_seed)
    np.random.seed(opt.base_seed)
    random.seed(opt.base_seed)

    indexed_frames = _collect_indexed_frames(opt.rgb_files)
    only_frustum = getattr(opt, 'only_frustum', False)
    resume_state = vars(opt).get('state_dict')

    target_path = opt.render_target_path
    os.makedirs(target_path, exist_ok=True)

    # Forward the pan settings only when configured. Otherwise ACEVisualizer's own
    # defaults are relied on (assumes it defines defaults for them; TODO confirm).
    pan_kwargs = {
        key: getattr(opt, key)
        for key in ('pan_start_angle', 'pan_radius_scale')
        if hasattr(opt, key)
    }
    ace_visualizer = ACEVisualizer(
        target_path,
        opt.render_flipped_portrait,
        opt.render_map_depth_filter,
        reloc_vis_conf_threshold=opt.render_pose_conf_threshold,
        confidence_threshold=opt.confidence_threshold,
        marker_size=opt.render_marker_size,
        result_npz=opt.result_npz,
        **pan_kwargs,
    )

    setup_kwargs = {
        'frame_count': len(indexed_frames),
        'camera_z_offset': opt.render_camera_z_offset,
        'frame_idx': None if resume_state is None else resume_state['frame_idx'],
        'only_frustum': only_frustum,
    }
    if resume_state is not None:
        setup_kwargs['state_dict'] = resume_state
    ace_visualizer.setup_reloc_visualisation(**setup_kwargs)

    npz_data = load_npz_file(opt.result_npz)
    cam_poses = npz_data['cam_poses']
    cam_intrinsics = npz_data['intrinsic']

    # One estimate per frame. Poses are used exactly as stored in the npz;
    # no RANSAC refinement is run.
    estimates_list = []
    total_frame_time = 0.0
    for img_idx, frame_path in indexed_frames:
        frame_start_time = time.time()
        print(f'current image file {frame_path}')
        estimates_list.append(_TestEstimate(
            pose_est=cam_poses[img_idx].copy(),
            pose_gt=None,
            focal_length=cam_intrinsics[img_idx][0, 0],
            confidence=_FIXED_CONFIDENCE,
            image_file=frame_path,
        ))
        total_frame_time += time.time() - frame_start_time
        if 0 < opt.max_estimates <= len(estimates_list):
            _logger.info(f"Stopping at {len(estimates_list)} estimates.")
            break

    for estimate in estimates_list:
        for _ in range(_RENDERS_PER_ESTIMATE):
            ace_visualizer.render_reloc_frame(
                query_file=estimate.image_file,
                est_pose=estimate.pose_est,
                confidence=estimate.confidence,
            )

    if only_frustum:
        ace_visualizer.trajectory_buffer.clear_frustums()
    ace_visualizer.reset_position_markers(
        marker_color=ace_visualizer.progress_color_map[1] * 255)
    # Overwrite every entry of vis_error in place with mean_value before the map render.
    _, vis_error, mean_value, _, _ = ace_visualizer.get_mean_repreoject_error()
    vis_error[:] = mean_value
    ace_visualizer.render_growing_map()

    if estimates_list:
        avg_time = total_frame_time / len(estimates_list)
        _logger.info(f"Avg. processing time: {avg_time * 1000:4.1f}ms")
    else:
        _logger.warning("No frames were processed.")

    # The transpose moves axis 1 last and reshape(-1, 3) yields RGB triples.
    # (x + 1) / 2 * 255 presumably maps values from [-1, 1] to [0, 255];
    # TODO confirm image_gt's layout and range against ACEVisualizer.
    map_clr = ((ace_visualizer.image_gt.transpose(0, 2, 3, 1).reshape(-1, 3) + 1.0)
               / 2.0 * 255.0).astype('float64')
    return {
        'frame_idx': ace_visualizer.frame_idx,
        'camera_buffer': ace_visualizer.scene_camera.get_camera_buffer(),
        'pan_cameras': ace_visualizer.pan_cams,
        'map_xyz': ace_visualizer.pts3d.reshape(-1, 3),
        'map_clr': map_clr,
    }
|