File size: 4,953 Bytes
404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee 404d2af |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 |
import os
import numpy as np
import torch
from torch.utils.data import Dataset
from random import shuffle, seed
from .gl3d.io import read_list, _parse_img, _parse_depth, _parse_kpts
from .utils.common import Notify
from .utils.photaug import photaug
class GL3DDataset(Dataset):
    """Map-style dataset over GL3D image-pair "match set" files.

    Each item is decoded from one binary ``.match_set`` file that stores the
    indices of two images together with their original sizes, intrinsics and
    relative pose. The referenced images and depth maps are loaded through
    the GL3D parsing helpers.
    """

    def __init__(self, dataset_dir, config, data_split, is_training, augment=None):
        """Build the match-set index for one data split.

        Args:
            dataset_dir: Root directory of the GL3D dataset.
            config: Configuration forwarded to ``_parse_img`` / ``_parse_depth``.
            data_split: Sub-folder name under ``<dataset_dir>/list``.
            is_training: If True, read ``imageset_train.txt``; otherwise
                read ``imageset_test.txt``.
            augment: Whether ``__getitem__`` applies ``photaug`` to the images.
                Defaults to ``is_training``, so evaluation samples are left
                unaugmented. Pass True to augment regardless of split.
        """
        self.dataset_dir = dataset_dir
        self.config = config
        self.is_training = is_training
        self.data_split = data_split
        self.augment = is_training if augment is None else augment
        (
            self.match_set_list,
            self.global_img_list,
            self.global_depth_list,
        ) = self.prepare_match_sets()

    def __len__(self):
        """Return the number of match sets that passed the pair filters."""
        return len(self.match_set_list)

    def __getitem__(self, idx):
        """Load one image pair and its geometry.

        Returns:
            dict with keys ``img0``/``img1`` (images divided by 255),
            ``depth0``/``depth1``, ``ori_img_size0``/``ori_img_size1``
            (2-vectors), ``K0``/``K1`` (3x3), ``rel_pose`` (3x4) and
            ``inlier_num`` (int).
        """
        match_set_path = self.match_set_list[idx]
        # The match set is a flat float32 record; the offsets below give its field layout.
        decoded = np.fromfile(match_set_path, dtype=np.float32)
        idx0, idx1 = int(decoded[0]), int(decoded[1])
        inlier_num = int(decoded[2])
        ori_img_size0 = np.reshape(decoded[3:5], (2,))
        ori_img_size1 = np.reshape(decoded[5:7], (2,))
        K0 = np.reshape(decoded[7:16], (3, 3))
        K1 = np.reshape(decoded[16:25], (3, 3))
        # decoded[25:34] is not consumed by this loader.
        rel_pose = np.reshape(decoded[34:46], (3, 4))
        # parse images.
        img0 = _parse_img(self.global_img_list, idx0, self.config)
        img1 = _parse_img(self.global_img_list, idx1, self.config)
        # parse depths
        depth0 = _parse_depth(self.global_depth_list, idx0, self.config)
        depth1 = _parse_depth(self.global_depth_list, idx1, self.config)
        # Photometric augmentation only when enabled (by default: training only).
        if self.augment:
            img0 = photaug(img0)
            img1 = photaug(img1)
        return {
            "img0": img0 / 255.0,
            "img1": img1 / 255.0,
            "depth0": depth0,
            "depth1": depth1,
            "ori_img_size0": ori_img_size0,
            "ori_img_size1": ori_img_size1,
            "K0": K0,
            "K1": K1,
            "rel_pose": rel_pose,
            "inlier_num": inlier_num,
        }

    def points_to_2D(self, pnts, H, W):
        """Rasterise point coordinates into an (H, W) binary label map.

        Args:
            pnts: Array of shape (N, 2). Column 0 is used as the x (column)
                index and column 1 as the y (row) index. Values are
                truncated to int.
            H: Height of the output map.
            W: Width of the output map.

        Returns:
            float ndarray of shape (H, W): 1 at every point location,
            0 elsewhere.
        """
        labels = np.zeros((H, W))
        pnts = pnts.astype(int)
        labels[pnts[:, 1], pnts[:, 0]] = 1
        return labels

    def prepare_match_sets(self, q_diff_thld=3, rot_diff_thld=60):
        """Read the split lists and collect the filtered match sets.

        Args:
            q_diff_thld: Minimum ``q_diff`` value a match set must have.
            rot_diff_thld: Maximum ``rot_diff`` value a match set may have.

        Returns:
            A tuple ``(match_set_list, global_img_list, global_depth_list)``:
                match_set_list: Paths of the accepted match-set files.
                global_img_list: Paths of all images in the split.
                global_depth_list: Paths of all depth maps in the split.
        """
        # get necessary lists.
        gl3d_list_folder = os.path.join(self.dataset_dir, "list", self.data_split)
        global_img_list = [
            os.path.join(self.dataset_dir, i)
            for i in read_list(os.path.join(gl3d_list_folder, "image_list.txt"))
        ]
        global_depth_list = [
            os.path.join(self.dataset_dir, i)
            for i in read_list(os.path.join(gl3d_list_folder, "depth_list.txt"))
        ]
        imageset_list_name = (
            "imageset_train.txt" if self.is_training else "imageset_test.txt"
        )
        match_set_list = self.get_match_set_list(
            os.path.join(gl3d_list_folder, imageset_list_name),
            q_diff_thld,
            rot_diff_thld,
        )
        return match_set_list, global_img_list, global_depth_list

    def get_match_set_list(self, imageset_list_path, q_diff_thld, rot_diff_thld):
        """Collect the paths of the match sets that pass the pair filters.

        The stem of each ``.match_set`` file name is split on ``_``.
        Fields 2 and 3 are parsed as the integers ``q_diff`` and
        ``rot_diff``. A pair is kept when ``q_diff >= q_diff_thld`` and
        ``rot_diff <= rot_diff_thld``.

        Args:
            imageset_list_path: Path to the text file listing imageset
                folders, relative to ``<dataset_dir>/data``.
            q_diff_thld: Minimum accepted ``q_diff``.
            rot_diff_thld: Maximum accepted ``rot_diff``.

        Returns:
            match_set_list: Paths of the accepted match-set files. Within
            each folder they are in sorted (deterministic) order.
        """
        imageset_list = [
            os.path.join(self.dataset_dir, "data", i)
            for i in read_list(imageset_list_path)
        ]
        print(Notify.INFO, "Use # imageset", len(imageset_list), Notify.ENDC)
        match_set_list = []
        for imageset_dir in imageset_list:
            match_set_folder = os.path.join(imageset_dir, "match_sets")
            if not os.path.exists(match_set_folder):
                continue
            # os.listdir returns entries in arbitrary order; sort so that
            # dataset indices are reproducible across runs and filesystems.
            for val in sorted(os.listdir(match_set_folder)):
                name, ext = os.path.splitext(val)
                if ext != ".match_set":
                    continue
                splits = name.split("_")
                q_diff = int(splits[2])
                rot_diff = int(splits[3])
                # Keep pairs with sufficient q_diff and bounded rotation difference.
                if q_diff >= q_diff_thld and rot_diff <= rot_diff_thld:
                    match_set_list.append(os.path.join(match_set_folder, val))
        print(Notify.INFO, "Get # match sets", len(match_set_list), Notify.ENDC)
        return match_set_list
|