# Data loading based on https://github.com/NVIDIA/flownet2-pytorch
import numpy as np
import torch
import torch.utils.data as data
import os
import json
import random
from glob import glob
import os.path as osp
from core.utils import frame_utils
from core.utils.augmentor import FlowAugmentor, SparseFlowAugmentor
from core.utils.utils import merge_flows, fill_invalid
from functools import reduce
from queue import Queue


def js_read(filename: str):
    with open(filename) as f_in:
        return json.load(f_in)
# Per-scene container of file paths:
# - images: (frame index, camera) -> image path
# - flows: source (frame index, camera) -> {target (frame index, camera): flow path}
# - flow_masks / flow_mults: keyed like flows (mask paths / scalar multipliers)
class SceneData:
    def __init__(self):
        self.images = {}
        self.flows = {}
        self.flow_masks = {}
        self.flow_mults = {}
        self.name = ""


# Bare attribute container for a single training sample.
class DataBlob:
    def __init__(self):
        pass
class FlowDataset(data.Dataset):
    def __init__(self, aug_params=None, sparse=False, scene_params={}):
        self.augmentor = None
        self.sparse = sparse
        self.dataset = "unknown"
        self.subsample_groundtruth = False
        if aug_params is not None:
            if sparse:
                self.augmentor = SparseFlowAugmentor(**aug_params)
            else:
                self.augmentor = FlowAugmentor(**aug_params)
        self.is_test = False
        self.init_seed = False
        self.scenes = []
        self.scene_params = scene_params
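    # `aug_params` is forwarded to the augmentor unchanged. A sketch of the
    # keys this file itself passes from fetch_dataloader below (values are
    # illustrative, not defaults):
    #   {"crop_size": args.image_size, "pre_scale": args.scale,
    #    "min_scale": -0.2, "max_scale": +0.4, "do_flip": True,
    #    "do_rotate": False}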
    # BFS over the flow graph: nodes are (frame index, camera) images, edges
    # are stored flow files. Returns the chain of image pairs connecting s to
    # t, or None if no chain of length <= max_depth exists.
    @staticmethod
    def _flow_bfs(flows, s, t, max_depth):
        q = Queue()
        q.put(s)
        prev = {s: None}
        depth = {s: 0}
        ok = False
        while not ok and not q.empty():
            v = q.get()
            if depth[v] > max_depth or v not in flows:
                continue
            for u in flows[v]:
                if u not in prev:
                    prev[u] = v
                    depth[u] = depth[v] + 1
                    q.put(u)
                    if u == t:
                        ok = True
                        break
        if not ok:
            return None
        res = []
        while t != s:
            res.append([prev[t], t])
            t = prev[t]
        res.reverse()
        return res
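    # Example (sketch) of the chaining behaviour, assuming two stored flows:
    #   flows = {(0, "left"): {(1, "left"): "f01.flo"},
    #            (1, "left"): {(2, "left"): "f12.flo"}}
    #   FlowDataset._flow_bfs(flows, (0, "left"), (2, "left"), max_depth=3)
    #   # -> [[(0, "left"), (1, "left")], [(1, "left"), (2, "left")]]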
    # Default parameters are for the two-frame forward-flow case.
    def process_scenes(
        self,
        frames=[(0, "left"), (1, "left")],
        flows=[((0, "left"), (1, "left"))],  # (source image, target image)
        sequence_bounds=(0, 1),  # range(min_image + bound[0], max_image + bound[1])
        invalid_images="clip",  # "skip" or "clip"
        invalid_flows="merge",  # "skip" or "merge"
    ):
        self.data = []
        for scene in self.scenes:  # each `scene` is a SceneData instance
            if not scene.images:
                continue
            min_image = min(scene.images.keys())[0]
            max_image = max(scene.images.keys())[0]
            if sequence_bounds is None:
                i_list = sorted(set(x[0] for x in scene.images.keys()))
            else:
                i_list = list(
                    range(
                        min_image + sequence_bounds[0], max_image + sequence_bounds[1]
                    )
                )
            for i in i_list:
                valid_data = True
                image_list = []
                for di, cam in frames:
                    image = (i + di, cam)
                    if image not in scene.images:
                        if invalid_images == "skip":
                            valid_data = False
                        elif invalid_images == "clip":
                            image = (max(min_image, min(max_image, image[0])), image[1])
                        else:
                            raise Exception("Invalid image mode")
                    if image not in scene.images:
                        valid_data = False
                    if not valid_data:
                        break
                    image_list.append(scene.images[image])
                flow_list, flow_mask_list, flow_mults_list = [], [], []
                for img_pair in flows:
                    if img_pair is None:
                        # Placeholder slot: no supervision for this flow.
                        flow_list.append(None)
                        flow_mask_list.append(None)
                        flow_mults_list.append(None)
                        continue
                    img1_, img2_ = img_pair
                    img1 = (img1_[0] + i, img1_[1])
                    img2 = (img2_[0] + i, img2_[1])
                    img_pairs = []
                    if (img1 not in scene.flows) or (img2 not in scene.flows[img1]):
                        if invalid_flows == "skip":
                            valid_data = False
                        elif invalid_flows == "merge":
                            # No direct flow stored; chain flows between the
                            # two images via BFS.
                            img_pairs = FlowDataset._flow_bfs(
                                scene.flows,
                                img1,
                                img2,
                                max_depth=abs(img1[0] - img2[0]) + 10,
                            )
                            if img_pairs is None:
                                valid_data = False
                        else:
                            raise Exception("Invalid flow mode")
                    else:
                        img_pairs = [(img1, img2)]
                    if not valid_data:
                        break
                    flow_list.append([scene.flows[a][b] for a, b in img_pairs])
                    try:
                        flow_mask_list.append(
                            [scene.flow_masks[a][b] for a, b in img_pairs]
                        )
                    except KeyError:
                        flow_mask_list.append(None)
                    try:
                        flow_mults_list.append(
                            [scene.flow_mults[a][b] for a, b in img_pairs]
                        )
                    except KeyError:
                        flow_mults_list.append(None)
                if not valid_data:
                    continue
                new_data = DataBlob()
                new_data.images = image_list
                new_data.flows = flow_list
                new_data.flow_masks = flow_mask_list
                new_data.flow_mults = flow_mults_list
                new_data.extra_info = (scene.name, i, len(self.data), frames, flows)
                self.data.append(new_data)
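    # process_scenes / scene_params examples (sketches). The two-frame
    # forward-flow defaults:
    #   self.process_scenes(
    #       frames=[(0, "left"), (1, "left")],
    #       flows=[((0, "left"), (1, "left"))],
    #   )
    # Three frames with only the central forward flow supervised, as built by
    # three_frame_wrapper2 below:
    #   self.process_scenes(
    #       frames=[(-1, "left"), (0, "left"), (1, "left")],
    #       flows=[None, ((0, "left"), (1, "left"))],
    #   )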
    def __getitem__(self, index):
        # If a sample fails to load, resample a random index and retry.
        while True:
            try:
                return self.fetch(index)
            except Exception:
                index = random.randint(0, len(self) - 1)
    def fetch(self, index):
        if self.is_test:
            imgs = []
            for image in self.data[index].images:
                img = frame_utils.read_gen(image)
                img = np.array(img).astype(np.uint8)[..., :3]
                img = torch.from_numpy(img).permute(2, 0, 1).float()
                imgs.append(img)
            return torch.stack(imgs), self.data[index].extra_info
        if not self.init_seed:
            # Give each dataloader worker its own random seed.
            worker_info = torch.utils.data.get_worker_info()
            if worker_info is not None:
                torch.manual_seed(worker_info.id)
                np.random.seed(worker_info.id)
                random.seed(worker_info.id)
                self.init_seed = True
        index = index % len(self.data)
        imgs = []
        for image in self.data[index].images:
            img = frame_utils.read_gen(image)
            imgs.append(img)
        imgs = [np.array(img).astype(np.uint8) for img in imgs]
        flows, valids = [], []
        for flow_ind in range(len(self.data[index].flows)):
            if self.data[index].flows[flow_ind] is None:
                # Placeholder slot: emit an all-invalid zero flow.
                if len(imgs) > 0:
                    n, m = imgs[0].shape[:2]
                else:
                    n, m = 1, 1
                flows.append(np.zeros((n, m, 2)).astype(np.float32))
                valids.append(np.zeros((n, m)).astype(np.float32))
                continue
            cur_flows, cur_valids = [], []
            if self.sparse:
                if self.dataset == "TartanAir":
                    for f, fm in zip(
                        self.data[index].flows[flow_ind],
                        self.data[index].flow_masks[flow_ind],
                    ):
                        flow = np.load(f)
                        valid = np.load(fm)
                        valid = 1 - valid / 100
                        cur_flows.append(flow)
                        cur_valids.append(valid)
                else:
                    for f in self.data[index].flows[flow_ind]:
                        flow, valid = frame_utils.readFlowKITTI(f)
                        cur_flows.append(flow)
                        cur_valids.append(valid)
            else:
                if self.dataset == "Infinigen":
                    # Infinigen flow is stored as a 3D numpy array [Flow, Depth];
                    # keep only the first two (flow) channels.
                    for f in self.data[index].flows[flow_ind]:
                        flow = np.load(f)
                        flow = flow[..., :2]
                        cur_flows.append(flow)
                elif self.data[index].flow_mults[flow_ind] is not None:
                    # Apply the per-pair multiplier (e.g. -1 for right-to-left
                    # disparity).
                    for f, f_mult in zip(
                        self.data[index].flows[flow_ind],
                        self.data[index].flow_mults[flow_ind],
                    ):
                        flow = frame_utils.read_gen(f)
                        flow *= np.array(f_mult).astype(flow.dtype)
                        cur_flows.append(flow)
                else:
                    for f in self.data[index].flows[flow_ind]:
                        flow = frame_utils.read_gen(f)
                        cur_flows.append(flow)
            cur_flows = [np.array(flow).astype(np.float32) for flow in cur_flows]
            if not self.sparse:
                cur_valids = [
                    (
                        (np.abs(flow[:, :, 0]) < 1000) & (np.abs(flow[:, :, 1]) < 1000)
                    ).astype(np.float32)
                    for flow in cur_flows
                ]
            if self.subsample_groundtruth:
                # Use only every second value in both spatial directions, so
                # the flow has the same dimensions as the images (used for the
                # Spring dataset).
                cur_flows = [flow[::2, ::2] for flow in cur_flows]
                cur_valids = [valid[::2, ::2] for valid in cur_valids]
            # Merge a chain of flows into a single flow field.
            if len(cur_flows) == 1:
                flows.append(cur_flows[0])
                valids.append(cur_valids[0])
            else:
                cur_flow = cur_flows[0]
                cur_valid = cur_valids[0]
                for i in range(1, len(cur_flows)):
                    cur_flow, cur_valid = merge_flows(
                        cur_flow, cur_valid, cur_flows[i], cur_valids[i]
                    )
                cur_flow = fill_invalid(cur_flow, cur_valid)
                flows.append(cur_flow)
                valids.append(cur_valid)
        # Grayscale images: replicate to three channels; otherwise drop alpha.
        if len(imgs[0].shape) == 2:
            imgs = [np.tile(img[..., None], (1, 1, 3)) for img in imgs]
        else:
            imgs = [img[..., :3] for img in imgs]
        if self.augmentor is not None:
            imgs, flows, valids = self.augmentor(imgs, flows, valids)
        imgs = [torch.from_numpy(img).permute(2, 0, 1).float() for img in imgs]
        new_flows = []
        for flow in flows:
            flow = torch.from_numpy(flow).permute(2, 0, 1).float()
            flow[torch.isnan(flow)] = 0
            flow[flow.abs() > 1e9] = 0
            new_flows.append(flow)
        flows = new_flows
        if len(valids):
            valids = [torch.from_numpy(valid) for valid in valids]
            if not self.sparse:
                valids = [
                    (valids[i] >= 0.5)
                    & (flows[i][0].abs() < 1000)
                    & (flows[i][1].abs() < 1000)
                    for i in range(len(flows))
                ]
        else:
            # Fallback when no validity masks were built (should not happen in
            # practice): derive them from the flow magnitudes.
            valids = [(flow[0].abs() < 1000) & (flow[1].abs() < 1000) for flow in flows]
        return torch.stack(imgs), torch.stack(flows), torch.stack(valids).float()
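    # Shapes returned by fetch for training samples: images (num_frames, 3, H, W),
    # flows (num_flows, 2, H, W), valids (num_flows, H, W).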
    def add_datasets(self, other):
        self.data = self.data + other.data
        return self

    def __rmul__(self, v):
        self.data = v * self.data
        return self

    def __len__(self):
        return len(self.data)
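    # Dataset arithmetic used in fetch_dataloader below: `n * ds` repeats the
    # sample list n times via __rmul__, and `ds1 + ds2` concatenates datasets
    # through torch.utils.data.Dataset.__add__ (a ConcatDataset).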
class MpiSintel(FlowDataset):
    def __init__(
        self,
        aug_params=None,
        split="train",
        root="datasets/Sintel",
        dstype="clean",
        scene_params={},
    ):
        super(MpiSintel, self).__init__(
            aug_params=aug_params, scene_params=scene_params
        )
        assert split in ["train", "val", "submission"]
        assert dstype in ["clean", "final"]
        self.dataset = "MpiSintel"
        seq_root = {
            "train": "training",
            "val": "training",
            "submission": "test",
        }[split]
        flow_root = osp.join(root, seq_root, "flow")
        image_root = osp.join(root, seq_root, dstype)
        if split == "submission":
            self.is_test = True
        scene_split = js_read(os.path.join("config", "splits", "sintel.json"))
        for scene in sorted(os.listdir(image_root)):
            if scene_split[scene] != split:
                continue
            image_list = sorted(glob(osp.join(image_root, scene, "*.png")))
            flow_list = sorted(glob(osp.join(flow_root, scene, "*.flo")))
            current_scene = SceneData()
            current_scene.name = scene
            for i in range(len(image_list)):
                current_scene.images[(i, "left")] = image_list[i]
            if split != "submission":
                for i in range(len(flow_list)):
                    current_scene.flows[(i, "left")] = {(i + 1, "left"): flow_list[i]}
            self.scenes.append(current_scene)
        self.process_scenes(**scene_params)
class FlyingChairs(FlowDataset):
    def __init__(
        self,
        aug_params=None,
        split="train",
        root="datasets/FlyingChairs/FlyingChairs_release/data",
        scene_params={},
    ):
        super(FlyingChairs, self).__init__(
            aug_params=aug_params, scene_params=scene_params
        )
        assert split in ["train", "val"]
        self.dataset = "FlyingChairs"
        images = sorted(glob(osp.join(root, "*.ppm")))
        flows = sorted(glob(osp.join(root, "*.flo")))
        assert len(images) // 2 == len(flows)
        split_list = np.loadtxt("chairs_split.txt", dtype=np.int32)
        for i in range(len(flows)):
            xid = split_list[i]
            if (split == "train" and xid == 1) or (split == "val" and xid == 2):
                current_scene = SceneData()
                current_scene.name = str(i)
                current_scene.images[(0, "left")] = images[2 * i]
                current_scene.images[(1, "left")] = images[2 * i + 1]
                current_scene.flows[(0, "left")] = {(1, "left"): flows[i]}
                self.scenes.append(current_scene)
        self.process_scenes(**scene_params)
class FlyingThings3D(FlowDataset):
    def __init__(
        self,
        aug_params=None,
        root="datasets/FlyingThings3D",
        dstype="frames_cleanpass",
        scene_params={},
    ):
        super(FlyingThings3D, self).__init__(
            aug_params=aug_params, scene_params=scene_params
        )
        self.dataset = "FlyingThings3D"
        image_dirs, flow_dirs, disp_dirs = {}, {}, {}
        for cam in ["left", "right"]:
            image_dirs[cam] = sorted(glob(osp.join(root, dstype, "TRAIN/*/*")))
            image_dirs[cam] = sorted([osp.join(f, cam) for f in image_dirs[cam]])
            flow_dirs[cam] = {}
            for direction in ["into_future", "into_past"]:
                flow_dirs[cam][direction] = sorted(
                    glob(osp.join(root, "optical_flow/TRAIN/*/*"))
                )
                flow_dirs[cam][direction] = sorted(
                    [osp.join(f, direction, cam) for f in flow_dirs[cam][direction]]
                )
            disp_dirs[cam] = sorted(glob(osp.join(root, "disparity/TRAIN/*/*")))
            disp_dirs[cam] = sorted([osp.join(f, cam) for f in disp_dirs[cam]])
        for (
            idir_l,
            idir_r,
            fdir_fw_l,
            fdir_fw_r,
            fdir_bw_l,
            fdir_bw_r,
            ddir_l,
            ddir_r,
        ) in zip(
            image_dirs["left"],
            image_dirs["right"],
            flow_dirs["left"]["into_future"],
            flow_dirs["right"]["into_future"],
            flow_dirs["left"]["into_past"],
            flow_dirs["right"]["into_past"],
            disp_dirs["left"],
            disp_dirs["right"],
        ):
            images_l = sorted(glob(osp.join(idir_l, "*.png")))
            fw_flows_l = sorted(glob(osp.join(fdir_fw_l, "*.pfm")))
            bw_flows_l = sorted(glob(osp.join(fdir_bw_l, "*.pfm")))
            disp_l = sorted(glob(osp.join(ddir_l, "*.pfm")))
            images_r = sorted(glob(osp.join(idir_r, "*.png")))
            fw_flows_r = sorted(glob(osp.join(fdir_fw_r, "*.pfm")))
            bw_flows_r = sorted(glob(osp.join(fdir_bw_r, "*.pfm")))
            disp_r = sorted(glob(osp.join(ddir_r, "*.pfm")))
            current_scene = SceneData()
            for i in range(len(images_l)):
                current_scene.images[(i, "left")] = images_l[i]
                current_scene.flows[(i, "left")] = {
                    (i - 1, "left"): bw_flows_l[i],
                    (i + 1, "left"): fw_flows_l[i],
                }
                current_scene.images[(i, "right")] = images_r[i]
                current_scene.flows[(i, "right")] = {
                    (i - 1, "right"): bw_flows_r[i],
                    (i + 1, "right"): fw_flows_r[i],
                }
                # Treat disparity as a horizontal flow between the two cameras.
                current_scene.flows[(i, "left")][(i, "right")] = disp_l[i]
                current_scene.flows[(i, "right")][(i, "left")] = disp_r[i]
            for k in current_scene.flows:
                current_scene.flow_mults[k] = {}
                for k2 in current_scene.flows[k]:
                    current_scene.flow_mults[k][k2] = 1
            for i in range(len(images_l)):
                # Right-to-left disparity points in the opposite direction.
                current_scene.flow_mults[(i, "right")][(i, "left")] = -1
            self.scenes.append(current_scene)
        self.process_scenes(**scene_params)
class KITTI(FlowDataset):
    def __init__(
        self, aug_params=None, split="train", root="datasets/KITTI", scene_params={}
    ):
        super(KITTI, self).__init__(
            aug_params=aug_params, scene_params=scene_params, sparse=True
        )
        assert split in ["train", "val", "test", "submission"]
        self.dataset = "KITTI"
        if split == "submission":
            self.is_test = True
        seq_split = {
            "train": "training",
            "val": "training",
            "test": "training",
            "submission": "testing",
        }[split]
        root = osp.join(root, seq_split)
        # Frames 09/10/11 of each sequence; the ground truth is the flow from
        # frame 10 to frame 11.
        images0 = sorted(glob(osp.join(root, "image_2/*_09.png")))
        images1 = sorted(glob(osp.join(root, "image_2/*_10.png")))
        images2 = sorted(glob(osp.join(root, "image_2/*_11.png")))
        scene_split = js_read(os.path.join("config", "splits", "kitti.json"))
        image_list = []
        flow_list = []
        for img0, img1, img2 in zip(images0, images1, images2):
            if split != "submission" and scene_split[img1[-13:-7]] != split:
                continue
            image_list += [[img0, img1, img2]]
        if split != "submission":
            flow_list = [
                f
                for f in sorted(glob(osp.join(root, "flow_occ/*_10.png")))
                if scene_split[f[-13:-7]] == split
            ]
        for i in range(len(image_list)):
            current_scene = SceneData()
            current_scene.images[(0, "left")] = image_list[i][0]
            current_scene.images[(1, "left")] = image_list[i][1]
            current_scene.images[(2, "left")] = image_list[i][2]
            if split != "submission":
                current_scene.flows[(1, "left")] = {(2, "left"): flow_list[i]}
            self.scenes.append(current_scene)
        self.process_scenes(**scene_params)
class HD1K(FlowDataset):
    def __init__(self, aug_params=None, root="datasets/HD1K", scene_params={}):
        super(HD1K, self).__init__(
            aug_params=aug_params, scene_params=scene_params, sparse=True
        )
        self.dataset = "HD1K"
        seq_ix = 0
        while True:
            flows = sorted(
                glob(os.path.join(root, "hd1k_flow_gt", "flow_occ/%06d_*.png" % seq_ix))
            )
            images = sorted(
                glob(os.path.join(root, "hd1k_input", "image_2/%06d_*.png" % seq_ix))
            )
            if len(flows) == 0:
                break
            current_scene = SceneData()
            for i in range(len(images)):
                current_scene.images[(i, "left")] = images[i]
            for i in range(len(flows)):
                current_scene.flows[(i, "left")] = {(i + 1, "left"): flows[i]}
            self.scenes.append(current_scene)
            seq_ix += 1
        self.process_scenes(**scene_params)
class SpringFlowDataset(FlowDataset):
    """
    Dataset class for the Spring optical flow dataset.

    For train, this dataset returns image1, image2, flow and a data tuple
    (framenum, scene name, left/right cam, FW/BW direction).
    For test, it returns image1, image2 and the same data tuple.

    root: root directory of the Spring dataset (should contain test/train directories)
    split: train/test split
    subsample_groundtruth: if True, return ground truth subsampled to the same
        dimensions as the images (1920x1080 px); if False, return the full 4K
        resolution
    """
    def __init__(
        self,
        aug_params=None,
        root="datasets/Spring",
        split="train",
        subsample_groundtruth=True,
        scene_params={},
    ):
        super(SpringFlowDataset, self).__init__(
            aug_params=aug_params, scene_params=scene_params
        )
        assert split in ["train", "val", "test", "submission"]
        self.dataset = "Spring"
        seq_root = {
            "train": "train",
            "val": "train",
            "test": "train",
            "submission": "test",
        }[split]
        seq_root = os.path.join(root, seq_root)
        if not os.path.exists(seq_root):
            raise ValueError(f"Spring directory does not exist: {seq_root}")
        self.subsample_groundtruth = subsample_groundtruth
        self.split = split
        self.seq_root = seq_root
        self.data_list = []
        if split == "submission":
            self.is_test = True
        scene_split = js_read(os.path.join("config", "splits", "spring.json"))
        for scene in sorted(os.listdir(seq_root)):
            if scene_split[scene] != split:
                continue
            current_scene = SceneData()
            current_scene.name = scene
            for cam in ["left", "right"]:
                images = sorted(
                    glob(os.path.join(seq_root, scene, f"frame_{cam}", "*.png"))
                )
                for i in range(len(images)):
                    current_scene.images[(i, cam)] = images[i]
                    current_scene.flows[(i, cam)] = {}
                    current_scene.flow_mults[(i, cam)] = {}
                if split != "submission":
                    fw_flows = sorted(
                        glob(os.path.join(seq_root, scene, f"flow_FW_{cam}", "*.flo5"))
                    )
                    for i in range(len(fw_flows)):
                        current_scene.flows[(i, cam)][(i + 1, cam)] = fw_flows[i]
                        current_scene.flow_mults[(i, cam)][(i + 1, cam)] = 1
                    bw_flows = sorted(
                        glob(os.path.join(seq_root, scene, f"flow_BW_{cam}", "*.flo5"))
                    )
                    for i in range(len(bw_flows)):
                        current_scene.flows[(i + 1, cam)][(i, cam)] = bw_flows[i]
                        current_scene.flow_mults[(i + 1, cam)][(i, cam)] = 1
                    othercam = "right" if cam == "left" else "left"
                    disps = sorted(
                        glob(os.path.join(seq_root, scene, f"disp1_{cam}", "*.dsp5"))
                    )
                    for i in range(len(disps)):
                        # Disparity acts as a flow to the other camera; the
                        # right-to-left direction gets a negative multiplier.
                        current_scene.flows[(i, cam)][(i, othercam)] = disps[i]
                        current_scene.flow_mults[(i, cam)][(i, othercam)] = (
                            1 if cam == "left" else -1
                        )
            self.scenes.append(current_scene)
        self.process_scenes(**scene_params)
class Infinigen(FlowDataset):
    def __init__(self, aug_params=None, root="datasets/Infinigen", scene_params={}):
        super(Infinigen, self).__init__(
            aug_params=aug_params, scene_params=scene_params
        )
        self.root = root
        scenes = glob(osp.join(self.root, "*/"))
        self.dataset = "Infinigen"
        for scene in sorted(scenes):
            if not osp.isdir(osp.join(scene, "frames")):
                continue
            current_scene = SceneData()
            images_left = sorted(glob(osp.join(scene, "frames/Image/camera_0/*.png")))
            images_right = sorted(glob(osp.join(scene, "frames/Image/camera_1/*.png")))
            for idx in range(len(images_left)):
                current_scene.images[(idx, "left")] = images_left[idx]
                current_scene.images[(idx, "right")] = images_right[idx]
            for idx in range(len(images_left) - 1):
                # Image filenames look like "Image_{ID}.png"; reuse the ID to
                # locate the matching "Flow3D_{ID}.npy".
                ID_left = images_left[idx].split("/")[-1][6:-4]
                ID_right = images_right[idx].split("/")[-1][6:-4]
                flow_path_left = osp.join(
                    scene, "frames/Flow3D/camera_0", f"Flow3D_{ID_left}.npy"
                )
                flow_path_right = osp.join(
                    scene, "frames/Flow3D/camera_1", f"Flow3D_{ID_right}.npy"
                )
                current_scene.flows[(idx, "left")] = {(idx + 1, "left"): flow_path_left}
                current_scene.flows[(idx, "right")] = {
                    (idx + 1, "right"): flow_path_right
                }
            self.scenes.append(current_scene)
        self.process_scenes(**scene_params)
class TartanAir(FlowDataset):
    # scale depths to balance rot & trans
    DEPTH_SCALE = 5.0

    def __init__(self, aug_params=None, root="datasets/TartanAir", scene_params={}):
        super(TartanAir, self).__init__(
            aug_params=aug_params, scene_params=scene_params, sparse=True
        )
        self.dataset = "TartanAir"
        self.root = root
        self._build_dataset()
        self.process_scenes(**scene_params)

    def _build_dataset(self):
        scenes = glob(osp.join(self.root, "*/*/*"))
        for scene in sorted(scenes):
            current_scene = SceneData()
            current_scene.name = scene
            images = sorted(glob(osp.join(scene, "image_left/*.png")))
            for idx in range(len(images)):
                current_scene.images[(idx, "left")] = images[idx]
            for idx in range(len(images) - 1):
                frame0 = str(idx).zfill(6)
                frame1 = str(idx + 1).zfill(6)
                current_scene.flows[(idx, "left")] = {
                    (idx + 1, "left"): osp.join(
                        scene, "flow", f"{frame0}_{frame1}_flow.npy"
                    )
                }
                current_scene.flow_masks[(idx, "left")] = {
                    (idx + 1, "left"): osp.join(
                        scene, "flow", f"{frame0}_{frame1}_mask.npy"
                    )
                }
            self.scenes.append(current_scene)
# TODO: Adapt for new base class
class MegaScene(data.Dataset):
    def __init__(self, root_dir, npz_path, min_overlap_score=0.4, **kwargs):
        super().__init__()
        self.root_dir = root_dir
        # Load as a plain dict so pair_infos can be dropped below (np.load
        # returns a read-only NpzFile mapping for .npz archives).
        self.scene_info = dict(np.load(npz_path, allow_pickle=True))
        self.pair_infos = self.scene_info["pair_infos"].copy()
        del self.scene_info["pair_infos"]
        self.pair_infos = [
            pair_info
            for pair_info in self.pair_infos
            if pair_info[1] > min_overlap_score
        ]

    def __len__(self):
        return len(self.pair_infos)

    def __getitem__(self, idx):
        (idx0, idx1), overlap_score, central_matches = self.pair_infos[idx]
        img_name0 = osp.join(self.root_dir, self.scene_info["image_paths"][idx0])
        img_name1 = osp.join(self.root_dir, self.scene_info["image_paths"][idx1])
        depth_name0 = osp.join(self.root_dir, self.scene_info["depth_paths"][idx0])
        depth_name1 = osp.join(self.root_dir, self.scene_info["depth_paths"][idx1])
        # read intrinsics of original size
        K_0 = self.scene_info["intrinsics"][idx0].copy().reshape(3, 3)
        K_1 = self.scene_info["intrinsics"][idx1].copy().reshape(3, 3)
        # read and compute relative poses
        T0 = self.scene_info["poses"][idx0]
        T1 = self.scene_info["poses"][idx1]
        data = {
            "image0": img_name0,
            "image1": img_name1,
            "depth0": depth_name0,
            "depth1": depth_name1,
            "T0": T0,
            "T1": T1,  # (4, 4)
            "K0": K_0,  # (3, 3)
            "K1": K_1,
        }
        return data
# TODO: Adapt for new base class (still populates the old base class's
# image_list / extra_info attributes)
class MegaDepth(FlowDataset):
    def __init__(self, aug_params=None, root="datasets/megadepth"):
        super(MegaDepth, self).__init__(aug_params, sparse=True)
        self.n_frames = 2
        self.dataset = "MegaDepth"
        self.root = root
        self._build_dataset()

    def _build_dataset(self):
        dataset_path = osp.join(self.root, "train")
        index_folder = osp.join(self.root, "index/scene_info_0.1_0.7")
        index_path_list = glob(index_folder + "/*.npz")
        dataset_list = []
        for index_path in index_path_list:
            my_dataset = MegaScene(dataset_path, index_path, min_overlap_score=0.4)
            dataset_list.append(my_dataset)
        self.megascene = torch.utils.data.ConcatDataset(dataset_list)
        for i in range(len(self.megascene)):
            data = self.megascene[i]
            self.image_list.append([data["image0"], data["image1"]])
            self.extra_info.append(
                [
                    data["depth0"],
                    data["depth1"],
                    data["T0"],
                    data["T1"],
                    data["K0"],
                    data["K1"],
                ]
            )
# TODO: Adapt for new base class
class Middlebury(FlowDataset):
    def __init__(self, aug_params=None, root="datasets/middlebury"):
        super(Middlebury, self).__init__(aug_params)
        img_root = os.path.join(root, "images")
        flow_root = os.path.join(root, "flow")
        flows = []
        imgs = []
        info = []
        for scene in sorted(os.listdir(flow_root)):
            img0 = os.path.join(img_root, scene, "frame10.png")
            img1 = os.path.join(img_root, scene, "frame11.png")
            flow = os.path.join(flow_root, scene, "flow10.flo")
            imgs += [(img0, img1)]
            flows += [flow]
            info += [scene]
        self.image_list = imgs
        self.flow_list = flows
        self.extra_info = info
# Build 3-frame samples over both cameras. Each sample has two flow slots, one
# a placeholder (None) and one supervising the flow from frame 0 to frame 1;
# with add_reversed, the time-reversed frame order is added as well.
def three_frame_wrapper2(
    dataset_class, dataset_args, add_reversed=True, invalid_images="skip"
):
    datasets = []
    for cam in ["left", "right"]:
        datasets.append(
            dataset_class(
                **dataset_args,
                scene_params={
                    "frames": [(-1, cam), (0, cam), (1, cam)],
                    "flows": [None, ((0, cam), (1, cam))],
                    "invalid_images": invalid_images,
                },
            )
        )
        if add_reversed:
            datasets.append(
                dataset_class(
                    **dataset_args,
                    scene_params={
                        "frames": [(1, cam), (0, cam), (-1, cam)],
                        "flows": [((0, cam), (1, cam)), None],
                        "invalid_images": invalid_images,
                    },
                )
            )
    return reduce(lambda x, y: x.add_datasets(y), datasets)
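# Example (a sketch mirroring the "things" branch of fetch_dataloader below):
#   ds = three_frame_wrapper2(
#       FlyingThings3D, {"aug_params": None, "dstype": "frames_cleanpass"}
#   )
#   # ds now holds 3-frame samples over both cameras and, with add_reversed,
#   # both temporal orderings.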
def three_frame_wrapper_val(dataset_class, dataset_args):
    datasets = []
    for cam in ["left", "right"]:
        datasets.append(
            dataset_class(
                **dataset_args,
                scene_params={
                    "frames": [(-1, cam), (0, cam), (1, cam)],
                    "flows": [((0, cam), (1, cam))],
                },
            )
        )
        datasets.append(
            dataset_class(
                **dataset_args,
                scene_params={
                    "frames": [(1, cam), (0, cam), (-1, cam)],
                    "flows": [((0, cam), (-1, cam))],
                },
            )
        )
    return reduce(lambda x, y: x.add_datasets(y), datasets)


def three_frame_wrapper_spring_submission(dataset_class, dataset_args):
    datasets = []
    for cam in ["left", "right"]:
        datasets.append(
            dataset_class(
                **dataset_args,
                scene_params={
                    "frames": [(-1, cam), (0, cam), (1, cam)],
                    "flows": [],
                    "sequence_bounds": (0, 0),
                },
            )
        )
        datasets.append(
            dataset_class(
                **dataset_args,
                scene_params={
                    "frames": [(1, cam), (0, cam), (-1, cam)],
                    "flows": [],
                    "sequence_bounds": (1, 1),
                },
            )
        )
    return reduce(lambda x, y: x.add_datasets(y), datasets)


def three_frame_wrapper_sintel_submission(dataset_class, dataset_args):
    datasets = []
    for cam in ["left", "right"]:
        datasets.append(
            dataset_class(
                **dataset_args,
                scene_params={
                    "frames": [(-1, cam), (0, cam), (1, cam)],
                    "flows": [],
                    "sequence_bounds": (0, 0),
                },
            )
        )
    return reduce(lambda x, y: x.add_datasets(y), datasets)


def three_frame_wrapper_kitti_submission(dataset_class, dataset_args):
    datasets = []
    for cam in ["left", "right"]:
        datasets.append(
            dataset_class(
                **dataset_args,
                scene_params={
                    "frames": [(-1, cam), (0, cam), (1, cam)],
                    "flows": [],
                    "sequence_bounds": (0, 0),
                    "invalid_images": "skip",
                },
            )
        )
    return reduce(lambda x, y: x.add_datasets(y), datasets)
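# The *_submission wrappers above build test-time datasets without ground
# truth: `flows` is empty, and `sequence_bounds` controls which frames may act
# as the central frame of a 3-frame sample.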
def fetch_dataloader(args):
    """Create the data loader for the corresponding training set."""
    if args.dataset == "things":
        aug_params = {
            "crop_size": args.image_size,
            "pre_scale": args.scale,
            "min_scale": -0.4,
            "max_scale": +0.8,
            "do_flip": True,
            "do_rotate": False,
        }
        clean_dataset = three_frame_wrapper2(
            FlyingThings3D, {"aug_params": aug_params, "dstype": "frames_cleanpass"}
        )
        final_dataset = three_frame_wrapper2(
            FlyingThings3D, {"aug_params": aug_params, "dstype": "frames_finalpass"}
        )
        train_dataset = clean_dataset + final_dataset
elif args.dataset == "kitti": | |
aug_params = { | |
"crop_size": args.image_size, | |
"pre_scale": args.scale, | |
"min_scale": -0.2, | |
"max_scale": +0.4, | |
"do_flip": False, | |
} | |
train_dataset = three_frame_wrapper2( | |
KITTI, {"aug_params": aug_params, "split": "train"} | |
) | |
elif args.dataset == "kitti-full": | |
kitti = reduce( | |
lambda x, y: x.add_datasets(y), | |
[ | |
three_frame_wrapper2( | |
KITTI, | |
{ | |
"aug_params": { | |
"crop_size": args.image_size, | |
"pre_scale": args.scale, | |
"min_scale": -0.2, | |
"max_scale": +0.4, | |
"do_flip": True, | |
}, | |
"split": current_split, | |
}, | |
) | |
for current_split in ["train", "val", "test"] | |
], | |
) | |
train_dataset = kitti | |
elif args.dataset == "spring": | |
aug_params = { | |
"crop_size": args.image_size, | |
"pre_scale": args.scale, | |
"min_scale": 0.0, | |
"max_scale": +0.2, | |
"do_flip": True, | |
"do_rotate": False, | |
} | |
train_dataset = three_frame_wrapper2( | |
SpringFlowDataset, {"aug_params": aug_params, "subsample_groundtruth": True} | |
) | |
elif args.dataset == "spring-full": | |
aug_params = { | |
"crop_size": args.image_size, | |
"pre_scale": args.scale, | |
"min_scale": 0.0, | |
"max_scale": +0.2, | |
"do_flip": True, | |
"do_rotate": False, | |
} | |
train_dataset = reduce( | |
lambda x, y: x.add_datasets(y), | |
[ | |
three_frame_wrapper2( | |
SpringFlowDataset, | |
{ | |
"split": cur_sp, | |
"aug_params": aug_params, | |
"subsample_groundtruth": True, | |
}, | |
) | |
for cur_sp in ["train", "val", "test"] | |
], | |
) | |
elif args.dataset == "TartanAir": | |
aug_params = { | |
"crop_size": args.image_size, | |
"pre_scale": args.scale, | |
"min_scale": -0.2, | |
"max_scale": +0.4, | |
"do_flip": True, | |
"do_rotate": False, | |
} | |
train_dataset = three_frame_wrapper2(TartanAir, {"aug_params": aug_params}) | |
elif args.dataset == "TSKH": | |
aug_params = { | |
"crop_size": args.image_size, | |
"pre_scale": args.scale, | |
"min_scale": -0.2, | |
"max_scale": +0.6, | |
"do_flip": True, | |
"do_rotate": False, | |
} | |
things_clean = three_frame_wrapper2( | |
FlyingThings3D, {"aug_params": aug_params, "dstype": "frames_cleanpass"} | |
) | |
things_final = three_frame_wrapper2( | |
FlyingThings3D, {"aug_params": aug_params, "dstype": "frames_finalpass"} | |
) | |
things = things_clean + things_final | |
sintel_clean = three_frame_wrapper2( | |
MpiSintel, {"aug_params": aug_params, "split": "train", "dstype": "clean"} | |
) | |
sintel_final = three_frame_wrapper2( | |
MpiSintel, {"aug_params": aug_params, "split": "train", "dstype": "final"} | |
) | |
kitti = three_frame_wrapper2( | |
KITTI, | |
{ | |
"aug_params": { | |
"crop_size": args.image_size, | |
"pre_scale": args.scale, | |
"min_scale": -0.3, | |
"max_scale": +0.5, | |
"do_flip": True, | |
}, | |
"split": "train", | |
}, | |
) | |
hd1k = three_frame_wrapper2( | |
HD1K, | |
{ | |
"aug_params": { | |
"crop_size": args.image_size, | |
"pre_scale": args.scale, | |
"min_scale": -0.5, | |
"max_scale": +0.2, | |
"do_flip": True, | |
} | |
}, | |
) | |
        # `things` already spans (left + right) x (cleanpass + finalpass), so
        # the smaller datasets get explicit multipliers.
        train_dataset = (
            100 * sintel_clean + 100 * sintel_final + 400 * kitti + 120 * hd1k + things
        )
elif args.dataset == "TSKH-full": | |
aug_params = { | |
"crop_size": args.image_size, | |
"pre_scale": args.scale, | |
"min_scale": -0.2, | |
"max_scale": +0.6, | |
"do_flip": True, | |
"do_rotate": False, | |
} | |
things_clean = three_frame_wrapper2( | |
FlyingThings3D, {"aug_params": aug_params, "dstype": "frames_cleanpass"} | |
) | |
things_final = three_frame_wrapper2( | |
FlyingThings3D, {"aug_params": aug_params, "dstype": "frames_finalpass"} | |
) | |
things = things_clean + things_final | |
sintel_clean = reduce( | |
lambda x, y: x.add_datasets(y), | |
[ | |
three_frame_wrapper2( | |
MpiSintel, | |
{ | |
"aug_params": aug_params, | |
"split": current_split, | |
"dstype": "clean", | |
}, | |
) | |
for current_split in ["train", "val"] | |
], | |
) | |
sintel_final = reduce( | |
lambda x, y: x.add_datasets(y), | |
[ | |
three_frame_wrapper2( | |
MpiSintel, | |
{ | |
"aug_params": aug_params, | |
"split": current_split, | |
"dstype": "final", | |
}, | |
) | |
for current_split in ["train", "val"] | |
], | |
) | |
kitti = reduce( | |
lambda x, y: x.add_datasets(y), | |
[ | |
three_frame_wrapper2( | |
KITTI, | |
{ | |
"aug_params": { | |
"crop_size": args.image_size, | |
"pre_scale": args.scale, | |
"min_scale": -0.3, | |
"max_scale": +0.5, | |
"do_flip": True, | |
}, | |
"split": current_split, | |
}, | |
) | |
for current_split in ["train", "val", "test"] | |
], | |
) | |
hd1k = three_frame_wrapper2( | |
HD1K, | |
{ | |
"aug_params": { | |
"crop_size": args.image_size, | |
"pre_scale": args.scale, | |
"min_scale": -0.5, | |
"max_scale": +0.2, | |
"do_flip": True, | |
} | |
}, | |
) | |
        # `things` already spans (left + right) x (cleanpass + finalpass), so
        # the smaller datasets get explicit multipliers.
        train_dataset = (
            80 * sintel_clean + 80 * sintel_final + 320 * kitti + 120 * hd1k + things
        )
elif args.dataset == "sintel-full": | |
aug_params = { | |
"crop_size": args.image_size, | |
"pre_scale": args.scale, | |
"min_scale": -0.2, | |
"max_scale": +0.6, | |
"do_flip": True, | |
"do_rotate": False, | |
} | |
sintel_clean = reduce( | |
lambda x, y: x.add_datasets(y), | |
[ | |
three_frame_wrapper2( | |
MpiSintel, | |
{ | |
"aug_params": aug_params, | |
"split": current_split, | |
"dstype": "clean", | |
}, | |
invalid_images="clip", | |
) | |
for current_split in ["train", "val"] | |
], | |
) | |
sintel_final = reduce( | |
lambda x, y: x.add_datasets(y), | |
[ | |
three_frame_wrapper2( | |
MpiSintel, | |
{ | |
"aug_params": aug_params, | |
"split": current_split, | |
"dstype": "final", | |
}, | |
invalid_images="clip", | |
) | |
for current_split in ["train", "val"] | |
], | |
) | |
train_dataset = sintel_clean + sintel_final | |
else: | |
raise ValueError(f"Invalid dataset name {args.dataset}") | |
    train_loader = data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        pin_memory=False,
        shuffle=True,
        num_workers=args.num_workers,
        drop_last=True,
    )
    print("Training with %d samples" % len(train_dataset))
    return train_loader
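if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the training pipeline): the
    # argument names mirror what fetch_dataloader reads; the default values
    # below are illustrative assumptions, and the chosen dataset must exist
    # on disk together with its config/splits JSON.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", default="things")
    parser.add_argument("--image_size", type=int, nargs=2, default=[368, 768])
    parser.add_argument("--scale", type=float, default=0)
    parser.add_argument("--batch_size", type=int, default=2)
    parser.add_argument("--num_workers", type=int, default=2)
    args = parser.parse_args()

    loader = fetch_dataloader(args)
    imgs, flows, valids = next(iter(loader))
    print(imgs.shape, flows.shape, valids.shape)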