Vincentqyw
fix: roma
8b973ee
raw
history blame
4.95 kB
import os
import numpy as np
import torch
from torch.utils.data import Dataset
from random import shuffle, seed
from .gl3d.io import read_list, _parse_img, _parse_depth, _parse_kpts
from .utils.common import Notify
from .utils.photaug import photaug
class GL3DDataset(Dataset):
    """Image-pair dataset driven by GL3D ``.match_set`` files.

    Each sample corresponds to one ``.match_set`` file discovered under
    ``<dataset_dir>/data/<imageset>/match_sets``. The file is decoded as a
    flat float32 record that names two images (by index into the global
    image/depth lists) and carries per-pair metadata.

    Args:
        dataset_dir: Root directory of the dataset.
        config: Configuration object forwarded unchanged to the image and
            depth parsers.
        data_split: Name of the sub-folder of ``<dataset_dir>/list`` that
            holds the list files for this split.
        is_training: Selects ``imageset_train.txt`` when true, otherwise
            ``imageset_test.txt``.
    """

    def __init__(self, dataset_dir, config, data_split, is_training):
        super().__init__()
        self.dataset_dir = dataset_dir
        self.config = config
        self.is_training = is_training
        self.data_split = data_split
        # The list files are scanned once, at construction time.
        (
            self.match_set_list,
            self.global_img_list,
            self.global_depth_list,
        ) = self.prepare_match_sets()

    def __len__(self):
        """Return the number of selected match sets (one sample each)."""
        return len(self.match_set_list)

    def __getitem__(self, idx):
        """Load the image pair and metadata for the ``idx``-th match set.

        The match-set file is read as a flat float32 array with this layout
        (element offsets):

            0, 1     indices of the two images in the global lists
            2        inlier count
            3:5      ``ori_img_size0``
            5:7      ``ori_img_size1``
            7:16     ``K0`` (3x3)
            16:25    ``K1`` (3x3)
            25:34    not read here
            34:46    ``rel_pose`` (3x4)

        Returns:
            A dict with keys ``img0``, ``img1`` (augmented images divided by
            255.0), ``depth0``, ``depth1``, ``ori_img_size0``,
            ``ori_img_size1``, ``K0``, ``K1``, ``rel_pose`` and
            ``inlier_num``.
        """
        match_set_path = self.match_set_list[idx]
        decoded = np.fromfile(match_set_path, dtype=np.float32)

        # Indices and counts are stored as floats in the record.
        idx0, idx1 = int(decoded[0]), int(decoded[1])
        inlier_num = int(decoded[2])
        ori_img_size0 = np.reshape(decoded[3:5], (2,))
        ori_img_size1 = np.reshape(decoded[5:7], (2,))
        # Presumably camera intrinsics and a relative [R|t] pose — TODO confirm
        # against the match-set writer.
        K0 = np.reshape(decoded[7:16], (3, 3))
        K1 = np.reshape(decoded[16:25], (3, 3))
        rel_pose = np.reshape(decoded[34:46], (3, 4))

        # Parse images.
        img0 = _parse_img(self.global_img_list, idx0, self.config)
        img1 = _parse_img(self.global_img_list, idx1, self.config)
        # Parse depths.
        depth0 = _parse_depth(self.global_depth_list, idx0, self.config)
        depth1 = _parse_depth(self.global_depth_list, idx1, self.config)

        # Photometric augmentation.
        # NOTE(review): this runs for every sample, including when
        # ``self.is_training`` is False — confirm that is intended for
        # evaluation splits.
        img0 = photaug(img0)
        img1 = photaug(img1)

        return {
            "img0": img0 / 255.0,
            "img1": img1 / 255.0,
            "depth0": depth0,
            "depth1": depth1,
            "ori_img_size0": ori_img_size0,
            "ori_img_size1": ori_img_size1,
            "K0": K0,
            "K1": K1,
            "rel_pose": rel_pose,
            "inlier_num": inlier_num,
        }

    def points_to_2D(self, pnts, H, W):
        """Rasterize points into a binary ``(H, W)`` map.

        Args:
            pnts: Array whose column 0 indexes the second (width) axis and
                whose column 1 indexes the first (height) axis. Values are
                cast with ``astype(int)`` before indexing.
            H: Number of rows of the output map.
            W: Number of columns of the output map.

        Returns:
            A float array of zeros with ones at the given point locations.

        Note:
            No bounds checking is done: under NumPy indexing rules,
            out-of-range coordinates raise ``IndexError`` and negative
            coordinates wrap around.
        """
        labels = np.zeros((H, W))
        pnts = pnts.astype(int)
        labels[pnts[:, 1], pnts[:, 0]] = 1
        return labels

    def prepare_match_sets(self, q_diff_thld=3, rot_diff_thld=60):
        """Build the match-set list and the global image/depth path lists.

        List files are read from ``<dataset_dir>/list/<data_split>``:
        ``image_list.txt``, ``depth_list.txt`` and, depending on
        ``self.is_training``, ``imageset_train.txt`` or ``imageset_test.txt``.

        Args:
            q_diff_thld: Minimum value of the first filename-encoded quantity
                for a match set to be kept (see ``get_match_set_list``).
            rot_diff_thld: Maximum value of the second filename-encoded
                quantity for a match set to be kept.

        Returns:
            match_set_list: List of selected match-set file paths.
            global_img_list: List of image paths, prefixed with
                ``dataset_dir``.
            global_depth_list: List of depth paths, prefixed with
                ``dataset_dir``.
        """
        gl3d_list_folder = os.path.join(self.dataset_dir, "list", self.data_split)

        global_img_list = [
            os.path.join(self.dataset_dir, i)
            for i in read_list(os.path.join(gl3d_list_folder, "image_list.txt"))
        ]
        global_depth_list = [
            os.path.join(self.dataset_dir, i)
            for i in read_list(os.path.join(gl3d_list_folder, "depth_list.txt"))
        ]

        imageset_list_name = (
            "imageset_train.txt" if self.is_training else "imageset_test.txt"
        )
        match_set_list = self.get_match_set_list(
            os.path.join(gl3d_list_folder, imageset_list_name),
            q_diff_thld,
            rot_diff_thld,
        )
        return match_set_list, global_img_list, global_depth_list

    def get_match_set_list(self, imageset_list_path, q_diff_thld, rot_diff_thld):
        """Get the path list of match sets that pass the sampling thresholds.

        Args:
            imageset_list_path: Path to the imageset list file; each entry is
                resolved relative to ``<dataset_dir>/data``.
            q_diff_thld: Keep a match set only if the integer in the third
                ``_``-separated field of its filename is at least this value.
            rot_diff_thld: Keep a match set only if the integer in the fourth
                ``_``-separated field of its filename is at most this value.

        Returns:
            match_set_list: List of match-set paths, in imageset-list order
                and sorted by filename within each imageset.
        """
        imageset_list = [
            os.path.join(self.dataset_dir, "data", i)
            for i in read_list(imageset_list_path)
        ]
        print(Notify.INFO, "Use # imageset", len(imageset_list), Notify.ENDC)
        match_set_list = self._collect_match_sets(
            imageset_list, q_diff_thld, rot_diff_thld
        )
        print(Notify.INFO, "Get # match sets", len(match_set_list), Notify.ENDC)
        return match_set_list

    @staticmethod
    def _collect_match_sets(imageset_dirs, q_diff_thld, rot_diff_thld):
        """Scan imageset folders for ``.match_set`` files within thresholds."""
        match_set_list = []
        for imageset_dir in imageset_dirs:
            match_set_folder = os.path.join(imageset_dir, "match_sets")
            if not os.path.exists(match_set_folder):
                continue
            # os.listdir() returns entries in arbitrary order; sorting keeps
            # the index -> sample mapping stable across runs and machines.
            for val in sorted(os.listdir(match_set_folder)):
                name, ext = os.path.splitext(val)
                if ext != ".match_set":
                    continue
                # Discard pairs whose filename-encoded differences fall
                # outside the requested thresholds.
                splits = name.split("_")
                q_diff = int(splits[2])
                rot_diff = int(splits[3])
                if q_diff >= q_diff_thld and rot_diff <= rot_diff_thld:
                    match_set_list.append(os.path.join(match_set_folder, val))
        return match_set_list