|
import os |
|
import numpy as np |
|
import torch |
|
from torch.utils.data import Dataset |
|
from random import shuffle, seed |
|
|
|
from .gl3d.io import read_list, _parse_img, _parse_depth, _parse_kpts |
|
from .utils.common import Notify |
|
from .utils.photaug import photaug |
|
|
|
|
|
class GL3DDataset(Dataset):
    """Image-pair dataset built from GL3D ``.match_set`` files.

    Each item corresponds to one ``.match_set`` file, a flat float32 record
    that stores the indices of the two images of a pair together with
    per-pair metadata. The images and depth maps themselves are loaded
    through the project helpers ``_parse_img`` / ``_parse_depth`` using the
    global image / depth path lists.
    """

    # ``__getitem__`` reads values up to (exclusive) index 46 of the decoded
    # record, so a valid match-set file must hold at least this many float32s.
    _MATCH_SET_MIN_LEN = 46

    def __init__(self, dataset_dir, config, data_split, is_training):
        """Index all usable match sets below ``dataset_dir``.

        Args:
            dataset_dir: Root directory of the dataset; expected to contain
                ``list/<data_split>/`` and ``data/`` sub-directories.
            config: Configuration object forwarded to ``_parse_img`` and
                ``_parse_depth``.
            data_split: Name of the split folder below ``list/``.
            is_training: Selects ``imageset_train.txt`` when true, otherwise
                ``imageset_test.txt``.
        """
        super().__init__()
        self.dataset_dir = dataset_dir
        self.config = config
        self.is_training = is_training
        self.data_split = data_split

        (
            self.match_set_list,
            self.global_img_list,
            self.global_depth_list,
        ) = self.prepare_match_sets()

    def __len__(self):
        """Return the number of indexed match sets."""
        return len(self.match_set_list)

    def __getitem__(self, idx):
        """Load the image pair described by the ``idx``-th match set.

        Args:
            idx: Position in ``self.match_set_list``.

        Returns:
            A dict with the keys ``img0``, ``img1`` (images divided by 255),
            ``depth0``, ``depth1``, ``ori_img_size0``, ``ori_img_size1``
            (2 values each), ``K0``, ``K1`` (3x3), ``rel_pose`` (3x4) and
            ``inlier_num`` (int).

        Raises:
            ValueError: If the match-set file holds fewer float32 values
                than this method needs to decode.
        """
        match_set_path = self.match_set_list[idx]
        decoded = np.fromfile(match_set_path, dtype=np.float32)
        if decoded.size < self._MATCH_SET_MIN_LEN:
            # Fail with the offending path instead of an opaque index or
            # reshape error from the slicing below.
            raise ValueError(
                "Match set file '{}' holds {} float32 values, expected at "
                "least {}".format(
                    match_set_path, decoded.size, self._MATCH_SET_MIN_LEN
                )
            )

        # Record layout as consumed here (values at 25:34 are not used):
        #   [0], [1]  indices into the global image / depth lists
        #   [2]       inlier count
        #   [3:5]     ori_img_size0      [5:7]    ori_img_size1
        #   [7:16]    K0 (3x3)           [16:25]  K1 (3x3)
        #   [34:46]   rel_pose (3x4)
        idx0, idx1 = int(decoded[0]), int(decoded[1])
        inlier_num = int(decoded[2])
        ori_img_size0 = np.reshape(decoded[3:5], (2,))
        ori_img_size1 = np.reshape(decoded[5:7], (2,))
        K0 = np.reshape(decoded[7:16], (3, 3))
        K1 = np.reshape(decoded[16:25], (3, 3))
        rel_pose = np.reshape(decoded[34:46], (3, 4))

        img0 = _parse_img(self.global_img_list, idx0, self.config)
        img1 = _parse_img(self.global_img_list, idx1, self.config)

        depth0 = _parse_depth(self.global_depth_list, idx0, self.config)
        depth1 = _parse_depth(self.global_depth_list, idx1, self.config)

        # NOTE(review): photometric augmentation is applied regardless of
        # ``self.is_training`` -- confirm this is intended for the test split.
        img0 = photaug(img0)
        img1 = photaug(img1)

        return {
            "img0": img0 / 255.0,
            "img1": img1 / 255.0,
            "depth0": depth0,
            "depth1": depth1,
            "ori_img_size0": ori_img_size0,
            "ori_img_size1": ori_img_size1,
            "K0": K0,
            "K1": K1,
            "rel_pose": rel_pose,
            "inlier_num": inlier_num,
        }

    def points_to_2D(self, pnts, H, W):
        """Rasterize point coordinates into a binary ``(H, W)`` map.

        Args:
            pnts: Array-like of shape ``(N, >=2)``; column 0 is used as the
                column (x) index and column 1 as the row (y) index.
                Coordinates are truncated to integers.
            H: Height of the output map.
            W: Width of the output map.

        Returns:
            A float array of shape ``(H, W)`` that is 1 at every in-bounds
            point location and 0 elsewhere. Points outside the map are
            ignored; an empty ``pnts`` yields an all-zero map.
        """
        labels = np.zeros((H, W))
        pnts = np.asarray(pnts)
        if pnts.size == 0:
            return labels

        pnts = pnts.astype(int)
        cols, rows = pnts[:, 0], pnts[:, 1]
        # Drop out-of-bounds points: negative indices would otherwise wrap
        # around to the opposite border, and too-large ones would raise.
        in_bounds = (cols >= 0) & (cols < W) & (rows >= 0) & (rows < H)
        labels[rows[in_bounds], cols[in_bounds]] = 1
        return labels

    def prepare_match_sets(self, q_diff_thld=3, rot_diff_thld=60):
        """Collect the match-set paths and the global image / depth lists.

        Reads ``image_list.txt``, ``depth_list.txt`` and the train or test
        imageset list (chosen by ``self.is_training``) from
        ``<dataset_dir>/list/<data_split>/``.

        Args:
            q_diff_thld: Minimum ``q_diff`` value (parsed from the match-set
                file name) for a pair to be kept.
            rot_diff_thld: Maximum ``rot_diff`` value (parsed from the
                match-set file name) for a pair to be kept.

        Returns:
            match_set_list: List of match-set file paths.
            global_img_list: List of image paths, prefixed with the dataset
                directory.
            global_depth_list: List of depth paths, prefixed with the
                dataset directory.
        """
        gl3d_list_folder = os.path.join(self.dataset_dir, "list", self.data_split)

        global_img_list = [
            os.path.join(self.dataset_dir, i)
            for i in read_list(os.path.join(gl3d_list_folder, "image_list.txt"))
        ]
        global_depth_list = [
            os.path.join(self.dataset_dir, i)
            for i in read_list(os.path.join(gl3d_list_folder, "depth_list.txt"))
        ]

        imageset_list_name = (
            "imageset_train.txt" if self.is_training else "imageset_test.txt"
        )
        match_set_list = self.get_match_set_list(
            os.path.join(gl3d_list_folder, imageset_list_name),
            q_diff_thld,
            rot_diff_thld,
        )
        return match_set_list, global_img_list, global_depth_list

    def get_match_set_list(self, imageset_list_path, q_diff_thld, rot_diff_thld):
        """Get the path list of match sets that pass the sampling thresholds.

        For every imageset named in ``imageset_list_path`` the folder
        ``<dataset_dir>/data/<imageset>/match_sets`` is scanned for
        ``*.match_set`` files. The file stem is split on ``"_"``; field 2 is
        read as ``q_diff`` and field 3 as ``rot_diff``.

        Args:
            imageset_list_path: Path to the imageset list.
            q_diff_thld: A pair is kept only if ``q_diff >= q_diff_thld``.
            rot_diff_thld: A pair is kept only if
                ``rot_diff <= rot_diff_thld``.

        Returns:
            match_set_list: List of match-set paths, in imageset-list order
            and sorted by file name within each imageset.
        """
        imageset_list = [
            os.path.join(self.dataset_dir, "data", i)
            for i in read_list(imageset_list_path)
        ]
        print(Notify.INFO, "Use # imageset", len(imageset_list), Notify.ENDC)
        match_set_list = []

        for imageset_dir in imageset_list:
            match_set_folder = os.path.join(imageset_dir, "match_sets")
            if not os.path.isdir(match_set_folder):
                continue
            # os.listdir returns entries in arbitrary order; sort so that the
            # index -> sample mapping is reproducible across machines and runs.
            for val in sorted(os.listdir(match_set_folder)):
                name, ext = os.path.splitext(val)
                if ext != ".match_set":
                    continue
                splits = name.split("_")
                q_diff = int(splits[2])
                rot_diff = int(splits[3])
                if q_diff >= q_diff_thld and rot_diff <= rot_diff_thld:
                    match_set_list.append(os.path.join(match_set_folder, val))

        print(Notify.INFO, "Get # match sets", len(match_set_list), Notify.ENDC)
        return match_set_list
|
|