import numpy as np
import torch
import torch.utils.data as data
import torch.nn.functional as F
import csv
import os
import cv2
import math
import random
import json
import pickle
import os.path as osp
from .augmentation import RGBDAugmentor
from .rgbd_utils import *


class RGBDDataset(data.Dataset):
    def __init__(self, name, datapath, n_frames=4, crop_size=[480,640],
                 fmin=10.0, fmax=75.0, aug=True, sample=True):
        """ Base class for RGBD dataset """
        self.name = name
        self.root = datapath
        self.aug = aug
        self.sample = sample

        self.n_frames = n_frames
        self.fmin = fmin  # exclude very easy examples
        self.fmax = fmax  # exclude very hard examples

        if self.aug:
            self.aug = RGBDAugmentor(crop_size=crop_size)

        # building the dataset index is expensive; keep a cache directory so
        # it only needs to be done once (here a prebuilt pickle is loaded)
        cur_path = osp.dirname(osp.abspath(__file__))
        if not os.path.isdir(osp.join(cur_path, 'cache')):
            os.mkdir(osp.join(cur_path, 'cache'))

        # scene_info: scene name -> dict with 'images', 'depths', 'poses',
        # 'intrinsics' and a co-visibility 'graph'
        self.scene_info = \
            pickle.load(open('datasets/TartanAir.pickle', 'rb'))[0]

        self._build_dataset_index()

    def _build_dataset_index(self):
        self.dataset_index = []
        for scene in self.scene_info:
            # subclasses mark held-out scenes via is_test_scene
            if not self.__class__.is_test_scene(scene):
                graph = self.scene_info[scene]['graph']
                for i in graph:
                    # keep a margin at the end of each trajectory so a full
                    # multi-frame sample can be drawn from any start index
                    if i < len(graph) - 65:
                        self.dataset_index.append((scene, i))
            else:
                print("Reserving {} for validation".format(scene))

    @staticmethod
    def image_read(image_file):
        return cv2.imread(image_file)

    @staticmethod
    def depth_read(depth_file):
        return np.load(depth_file)
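
    # The readers above are defaults; subclasses override image_read /
    # depth_read (and must provide is_test_scene, which this base class
    # calls but does not define) to match their dataset's on-disk format.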

    def build_frame_graph(self, poses, depths, intrinsics, f=16, max_flow=256):
        """ compute optical flow distance between all pairs of frames """
        def read_disp(fn):
            # subsample depth by a factor of f, clamp near-zero values to
            # the mean to avoid division by zero, and convert to disparity
            depth = self.__class__.depth_read(fn)[f//2::f, f//2::f]
            depth[depth < 0.01] = np.mean(depth)
            return 1.0 / depth

        poses = np.array(poses)
        intrinsics = np.array(intrinsics) / f

        disps = np.stack(list(map(read_disp, depths)), 0)
        d = f * compute_distance_matrix_flow(poses, disps, intrinsics)

        # keep, for each frame i, only the frames j within the flow threshold
        graph = {}
        for i in range(d.shape[0]):
            j, = np.where(d[i] < max_flow)
            graph[i] = (j, d[i, j])

        return graph
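
    # Each graph entry is graph[i] = (j, d): the co-visible frame indices j
    # and their flow distances d (scaled back to full-resolution pixels by
    # the factor f above).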

    def __getitem__(self, index):
        """ return training video """

        index = index % len(self.dataset_index)
        scene_id, ix = self.dataset_index[index]

        frame_graph = self.scene_info[scene_id]['graph']
        images_list = self.scene_info[scene_id]['images']
        depths_list = self.scene_info[scene_id]['depths']
        poses_list = self.scene_info[scene_id]['poses']
        intrinsics_list = self.scene_info[scene_id]['intrinsics']

        # stride = np.random.choice([1,2,3])

        d = np.random.uniform(self.fmin, self.fmax)  # flow target for sample=False
        s = 1  # step direction when walking the trajectory

        inds = [ ix ]
        while len(inds) < self.n_frames:
            if self.sample:
                # randomly pick the next frame among neighbors whose flow
                # distance falls inside the (fmin, fmax) band
                k = (frame_graph[ix][1] > self.fmin) & (frame_graph[ix][1] < self.fmax)
                frames = frame_graph[ix][0][k]

                # prefer frames forward in time
                if np.count_nonzero(frames[frames > ix]):
                    ix = np.random.choice(frames[frames > ix])
                elif ix + 1 < len(images_list):
                    ix = ix + 1
                elif np.count_nonzero(frames):
                    ix = np.random.choice(frames)

            else:
                # deterministic: pick the neighbor with the largest flow
                # distance not exceeding d, walking forward (s > 0) or backward
                i = frame_graph[ix][0].copy()
                g = frame_graph[ix][1].copy()

                g[g > d] = -1
                if s > 0:
                    g[i <= ix] = -1
                else:
                    g[i >= ix] = -1

                if len(g) > 0 and np.max(g) > 0:
                    ix = i[np.argmax(g)]
                else:
                    # no valid neighbor: step along the trajectory, reversing
                    # direction at either end
                    if ix + s >= len(images_list) or ix + s < 0:
                        s *= -1
                    ix = ix + s

            inds += [ ix ]

        images, depths, poses, intrinsics = [], [], [], []
        for i in inds:
            images.append(self.__class__.image_read(images_list[i]))
            depths.append(self.__class__.depth_read(depths_list[i]))
            poses.append(poses_list[i])
            intrinsics.append(intrinsics_list[i])

        images = np.stack(images).astype(np.float32)
        depths = np.stack(depths).astype(np.float32)
        poses = np.stack(poses).astype(np.float32)
        intrinsics = np.stack(intrinsics).astype(np.float32)

        images = torch.from_numpy(images).float()
        images = images.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)

        disps = torch.from_numpy(1.0 / depths)
        poses = torch.from_numpy(poses)
        intrinsics = torch.from_numpy(intrinsics)

        if self.aug:
            images, poses, disps, intrinsics = \
                self.aug(images, poses, disps, intrinsics)

        # normalize scene scale: rescale disparities to a canonical magnitude
        # and scale the pose translations to match
        s = .7 * torch.quantile(disps, .98)
        disps = disps / s
        poses[...,:3] *= s

        return images, poses, disps, intrinsics

    def __len__(self):
        return len(self.dataset_index)

    def __imul__(self, x):
        # dataset *= x replicates the sample index, lengthening an epoch x-fold
        self.dataset_index *= x
        return self
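
# Illustrative usage sketch; the `TartanAir` subclass, its split rule, and the
# data paths below are assumptions, not part of this file. It assumes
# 'datasets/TartanAir.pickle' exists (see __init__ above), and must be run as
# a module (e.g. `python -m datasets.base`, module path assumed) because this
# file uses relative imports.
if __name__ == '__main__':
    class TartanAir(RGBDDataset):
        @staticmethod
        def is_test_scene(scene):
            # hypothetical held-out split; real subclasses define their own
            return 'test' in scene

    db = TartanAir('tartan', datapath='datasets/TartanAir', n_frames=7)
    db *= 4  # uses __imul__ to repeat the index: one epoch, 4x the samples

    loader = data.DataLoader(db, batch_size=2, shuffle=True, num_workers=0)
    images, poses, disps, intrinsics = next(iter(loader))
    print(images.shape, poses.shape, disps.shape, intrinsics.shape)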