Vincentqyw
fix: roma
c74a070
raw history blame
No virus
5.34 kB
import os.path as osp
import numpy as np
import torch
from dkm.utils import *
from PIL import Image
from tqdm import tqdm
class ScanNetBenchmark:
    """Relative-pose benchmark over the image pairs listed in ``test.npz``.

    For every pair the model produces dense matches, a sparse subset is
    sampled, a relative pose is estimated several times on shuffled
    correspondences, and the resulting pose errors are aggregated into
    AUC and mAP scores.

    Args:
        data_root: Directory that contains ``test.npz`` and the
            ``scans_test/<scene>/{color,intrinsic}`` folders.
    """

    # Number of sparse correspondences sampled from the dense matches.
    NUM_SAMPLED_MATCHES = 5000
    # Keypoints and intrinsics are expressed at a resolution whose shorter
    # image side has this many pixels.
    SHORT_SIDE = 480
    # Pose estimation is repeated this many times per pair on reshuffled
    # correspondences.
    NUM_RANSAC_RUNS = 5
    # Error recorded for a run in which pose estimation raised
    # (presumably degrees, like the thresholds below -- TODO confirm).
    FAILURE_ERROR = 90

    def __init__(self, data_root="data/scannet") -> None:
        self.data_root = data_root

    @staticmethod
    def _image_size(image_path):
        """Return ``(width, height)`` of an image and close the file again."""
        with Image.open(image_path) as image:
            return image.size

    @staticmethod
    def _load_intrinsics(intrinsics_path):
        """Parse a whitespace-separated text matrix into a 2-D float array.

        Empty lines are skipped; every remaining line becomes one row.
        """
        with open(intrinsics_path, "r") as handle:
            rows = handle.read().split("\n")
        return np.stack(
            [np.array([float(value) for value in row.split()]) for row in rows if row]
        )

    @staticmethod
    def _to_pixel_coords(kpts, width, height, offset=0.5):
        """Map keypoints to pixel coordinates of a ``width`` x ``height`` image.

        Assumes ``kpts`` holds (x, y) columns normalised to [-1, 1], as
        the ``(v + 1) / 2`` mapping suggests -- TODO confirm against the
        model's ``sample`` output.
        """
        return np.stack(
            (
                width * (kpts[:, 0] + 1) / 2 - offset,
                height * (kpts[:, 1] + 1) / 2 - offset,
            ),
            axis=-1,
        )

    def benchmark(self, model, model_name=None):
        """Run the benchmark and return aggregated pose-accuracy metrics.

        Args:
            model: Object exposing ``train(mode)``, ``match(path_a, path_b)``
                and ``sample(dense_matches, dense_certainty, num)``.
            model_name: Accepted for interface compatibility; not used here.

        Returns:
            dict with keys ``auc_5``, ``auc_10``, ``auc_20`` (from
            ``pose_auc``) and ``map_5``, ``map_10``, ``map_20`` (running
            means of the accuracies below the 5/10/15/20 thresholds).
        """
        model.train(False)
        with torch.no_grad():
            # The archive is only needed to pull out the two arrays, so it is
            # closed right away instead of staying open for the whole run.
            with np.load(osp.join(self.data_root, "test.npz")) as archive:
                pairs, rel_pose = archive["name"], archive["rel_pose"]

            tot_e_pose = []
            # Visit the pairs in random order.
            pair_inds = np.random.choice(
                range(len(pairs)), size=len(pairs), replace=False
            )
            for pairind in tqdm(pair_inds, smoothing=0.9):
                scene = pairs[pairind]
                scene_dir = osp.join(
                    self.data_root, "scans_test", f"scene0{scene[0]}_00"
                )
                im1_path = osp.join(scene_dir, "color", f"{scene[2]}.jpg")
                im2_path = osp.join(scene_dir, "color", f"{scene[3]}.jpg")
                w1, h1 = self._image_size(im1_path)
                w2, h2 = self._image_size(im2_path)

                # Ground-truth relative pose stored as a flattened 3x4 [R | t].
                T_gt = rel_pose[pairind].reshape(3, 4)
                R, t = T_gt[:3, :3], T_gt[:3, 3]
                K = self._load_intrinsics(
                    osp.join(scene_dir, "intrinsic", "intrinsic_color.txt")
                )

                dense_matches, dense_certainty = model.match(im1_path, im2_path)
                sparse_matches, sparse_certainty = model.sample(
                    dense_matches, dense_certainty, self.NUM_SAMPLED_MATCHES
                )

                # Image sizes and intrinsics are scaled by the same factor so
                # the keypoints and the camera model stay consistent.
                scale1 = self.SHORT_SIDE / min(w1, h1)
                scale2 = self.SHORT_SIDE / min(w2, h2)
                w1, h1 = scale1 * w1, scale1 * h1
                w2, h2 = scale2 * w2, scale2 * h2
                K1 = K * scale1
                K2 = K * scale2

                kpts1 = self._to_pixel_coords(sparse_matches[:, :2], w1, h1)
                kpts2 = self._to_pixel_coords(sparse_matches[:, 2:], w2, h2)

                # Depends only on the intrinsics, so it is computed once per
                # pair rather than once per run.
                norm_threshold = 0.5 / (
                    np.mean(np.abs(K1[:2, :2])) + np.mean(np.abs(K2[:2, :2]))
                )
                for _ in range(self.NUM_RANSAC_RUNS):
                    shuffling = np.random.permutation(np.arange(len(kpts1)))
                    kpts1 = kpts1[shuffling]
                    kpts2 = kpts2[shuffling]
                    try:
                        R_est, t_est, _ = estimate_pose(
                            kpts1,
                            kpts2,
                            K1,
                            K2,
                            norm_threshold,
                            conf=0.99999,
                        )
                        T1_to_2_est = np.concatenate((R_est, t_est), axis=-1)
                        e_t, e_R = compute_pose_error(T1_to_2_est, R, t)
                        e_pose = max(e_t, e_R)
                    except Exception as e:
                        # Deliberate best effort: a failed estimation is
                        # scored as a large error instead of aborting the
                        # whole benchmark.
                        print(repr(e))
                        e_pose = self.FAILURE_ERROR
                    # Exactly one sample per run, so every run carries the
                    # same weight in the aggregated metrics.
                    tot_e_pose.append(e_pose)

            tot_e_pose = np.array(tot_e_pose)
            thresholds = [5, 10, 20]
            auc = pose_auc(tot_e_pose, thresholds)
            acc_5 = (tot_e_pose < 5).mean()
            acc_10 = (tot_e_pose < 10).mean()
            acc_15 = (tot_e_pose < 15).mean()
            acc_20 = (tot_e_pose < 20).mean()
            map_5 = acc_5
            map_10 = np.mean([acc_5, acc_10])
            map_20 = np.mean([acc_5, acc_10, acc_15, acc_20])
            return {
                "auc_5": auc[0],
                "auc_10": auc[1],
                "auc_20": auc[2],
                "map_5": map_5,
                "map_10": map_10,
                "map_20": map_20,
            }