Spaces:
Running
Running
# Copyright (C) 2022-present Naver Corporation. All rights reserved. | |
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). | |
# -------------------------------------------------------- | |
# Dataset structure for stereo | |
# -------------------------------------------------------- | |
import sys, os | |
import os.path as osp | |
import pickle | |
import numpy as np | |
from PIL import Image | |
import json | |
import h5py | |
from glob import glob | |
import cv2 | |
import torch | |
from torch.utils import data | |
from .augmentor import StereoAugmentor | |
# Root directory of every supported stereo dataset, keyed by the name that each
# StereoDataset subclass assigns to ``self.name``.
dataset_to_root = dict(
    CREStereo='./data/stereoflow//crenet_stereo_trainset/stereo_trainset/crestereo/',
    SceneFlow='./data/stereoflow//SceneFlow/',
    ETH3DLowRes='./data/stereoflow/eth3d_lowres/',
    Booster='./data/stereoflow/booster_gt/',
    Middlebury2021='./data/stereoflow/middlebury/2021/data/',
    Middlebury2014='./data/stereoflow/middlebury/2014/',
    Middlebury2006='./data/stereoflow/middlebury/2006/',
    Middlebury2005='./data/stereoflow/middlebury/2005/train/',
    MiddleburyEval3='./data/stereoflow/middlebury/MiddEval3/',
    Spring='./data/stereoflow/spring/',
    Kitti15='./data/stereoflow/kitti-stereo-2015/',
    Kitti12='./data/stereoflow/kitti-stereo-2012/',
)

# Directory holding one pickled pair-list cache per dataset (<name>.pkl).
cache_dir = "./data/stereoflow/datasets_stereo_cache/"

# ImageNet-1k per-channel RGB mean / std, shaped (3, 1, 1) so that they broadcast
# over a channel-first (3, H, W) image tensor.
in1k_mean = torch.tensor([[[0.485]], [[0.456]], [[0.406]]])
in1k_std = torch.tensor([[[0.229]], [[0.224]], [[0.225]]])
def img_to_tensor(img):
    """Turn an (H, W, C) numpy image with values in [0, 255] into a normalized tensor.

    The result is channel-first (C, H, W) float32, scaled to [0, 1] and then
    standardized with the ImageNet-1k mean / std.
    """
    channels_first = torch.from_numpy(img).permute(2, 0, 1)
    unit_range = channels_first.float().div(255.)
    return (unit_range - in1k_mean) / in1k_std
def disp_to_tensor(disp):
    """Wrap a numpy disparity map as a tensor with an extra leading dimension.

    The tensor shares memory with ``disp`` (``torch.from_numpy``).
    """
    tensor = torch.from_numpy(disp)
    return tensor[None, :, :]
class StereoDataset(data.Dataset):
    """Base class of the stereo datasets defined in this file.

    Items are ``(left_image, right_image, left_disparity, pairname)``. Subclasses
    implement:

    * ``_prepare_data``: set ``self.name``, call ``self._set_root()`` and define the
      ``pairname_to_Limgname`` / ``pairname_to_Rimgname`` / ``pairname_to_Ldispname``
      path mappers (the last one may be ``None``, or return ``None`` for pairs without
      ground truth), ``pairname_to_str`` and ``load_disparity``;
    * ``_build_cache``: return a dict mapping every split name to its list of pair names.

    The pair lists are pickled once per dataset to ``<cache_dir>/<name>.pkl``.

    Args:
        split: split to load; must be a key of the dict built by ``_build_cache``.
        augmentor: if truthy, samples go through a ``StereoAugmentor``.
        crop_size: crop size passed to the augmentor; only allowed with ``augmentor``.
        totensor: if True, images are returned as ImageNet-normalized tensors and the
            disparity as a tensor with a leading singleton dimension (an empty tensor
            when there is no ground truth); otherwise numpy arrays / ``None``.
    """

    def __init__(self, split, augmentor=False, crop_size=None, totensor=True):
        self.split = split
        # the crop is applied by the augmentor, hence the two options go together
        if not augmentor:
            assert crop_size is None
        if crop_size:
            assert augmentor
        self.crop_size = crop_size
        self.augmentor_str = augmentor
        self.augmentor = StereoAugmentor(crop_size) if augmentor else None
        self.totensor = totensor
        self.rmul = 1  # repetition factor accumulated through __rmul__
        self.has_constant_resolution = True  # whether the dataset has constant resolution or not (=> don't use batch_size>1 at test time)
        self._prepare_data()
        self._load_or_build_cache()

    def _prepare_data(self):
        """Set name, root, path mappers and disparity loader; to be defined for each dataset."""
        raise NotImplementedError

    def prepare_data(self):
        """Public alias of ``_prepare_data`` (kept for backward compatibility)."""
        return self._prepare_data()

    def _build_cache(self):
        """Return a dict {split name: list of pair names}; to be defined for each dataset."""
        raise NotImplementedError

    def __len__(self):
        return len(self.pairnames)

    def __getitem__(self, index):
        """Load, sanity-check, augment and optionally tensorize the pair at ``index``."""
        pairname = self.pairnames[index]
        # get filenames
        Limgname = self.pairname_to_Limgname(pairname)
        Rimgname = self.pairname_to_Rimgname(pairname)
        Ldispname = self.pairname_to_Ldispname(pairname) if self.pairname_to_Ldispname is not None else None
        # load images and disparities
        Limg = _read_img(Limgname)
        Rimg = _read_img(Rimgname)
        disp = self.load_disparity(Ldispname) if Ldispname is not None else None
        # sanity check: the readers mark invalid disparities as +inf, so all values must be
        # positive (Spring ground truth is exempt)
        if disp is not None:
            assert np.all(disp > 0) or self.name == "Spring", (self.name, pairname, Ldispname)
        # apply augmentations
        if self.augmentor is not None:
            Limg, Rimg, disp = self.augmentor(Limg, Rimg, disp, self.name)
        if self.totensor:
            Limg = img_to_tensor(Limg)
            Rimg = img_to_tensor(Rimg)
            if disp is None:
                disp = torch.tensor([])  # to allow dataloader batching with default collate_fn
            else:
                disp = disp_to_tensor(disp)
        return Limg, Rimg, disp, str(pairname)

    def __rmul__(self, v):
        """``v * dataset`` repeats the pair list ``v`` times in place and returns ``self``."""
        self.rmul *= v
        self.pairnames = v * self.pairnames
        return self

    def __str__(self):
        return f'{self.__class__.__name__}_{self.split}'

    def __repr__(self):
        s = f'{self.__class__.__name__}(split={self.split}, augmentor={self.augmentor_str}, crop_size={str(self.crop_size)}, totensor={self.totensor})'
        if self.rmul == 1:
            s += f'\n\tnum pairs: {len(self.pairnames)}'
        else:
            s += f'\n\tnum pairs: {len(self.pairnames)} ({len(self.pairnames)//self.rmul}x{self.rmul})'
        return s

    def _set_root(self):
        """Set ``self.root`` from ``dataset_to_root`` and check that the directory exists."""
        self.root = dataset_to_root[self.name]
        assert os.path.isdir(self.root), f"could not find root directory for dataset {self.name}: {self.root}"

    def _load_or_build_cache(self):
        """Set ``self.pairnames`` from the cached pair lists, building the cache if missing."""
        cache_file = osp.join(cache_dir, self.name + '.pkl')
        if osp.isfile(cache_file):
            # NOTE: pickle.load can execute arbitrary code; only load cache files written here
            with open(cache_file, 'rb') as fid:
                self.pairnames = pickle.load(fid)[self.split]
        else:
            tosave = self._build_cache()
            os.makedirs(cache_dir, exist_ok=True)
            # write to a private temporary file then atomically rename it, so that an
            # interrupted run or a concurrent reader never sees a truncated cache file
            tmp_file = f'{cache_file}.{os.getpid()}.tmp'
            with open(tmp_file, 'wb') as fid:
                pickle.dump(tosave, fid)
            os.replace(tmp_file, cache_file)
            self.pairnames = tosave[self.split]
class CREStereoDataset(StereoDataset):
    """CREStereo synthetic stereo data; only a 'train' split exists.

    A pair name is '<subdirectory>/<stem>', whose files are '<stem>_left.jpg',
    '<stem>_right.jpg' and '<stem>_left.disp.png'.
    """

    def _prepare_data(self):
        self.name = 'CREStereo'
        self._set_root()
        assert self.split in ['train']

        def with_suffix(suffix):
            # self.root is read each time the mapper is called
            return lambda pairname: osp.join(self.root, pairname + suffix)

        self.pairname_to_Limgname = with_suffix('_left.jpg')
        self.pairname_to_Rimgname = with_suffix('_right.jpg')
        self.pairname_to_Ldispname = with_suffix('_left.disp.png')
        self.pairname_to_str = lambda pairname: pairname
        self.load_disparity = _read_crestereo_disp

    def _build_cache(self):
        marker = '_left.jpg'
        allpairs = []
        for subdir in sorted(os.listdir(self.root)):
            for fname in sorted(os.listdir(self.root + '/' + subdir)):
                if fname.endswith(marker):
                    allpairs.append(subdir + '/' + fname[:-len(marker)])
        assert len(allpairs) == 200000, "incorrect parsing of pairs in CreStereo"
        return {'train': allpairs}
class SceneFlowDataset(StereoDataset):
    """SceneFlow (Driving, Monkaa, FlyingThings) with finalpass, cleanpass or both renders.

    Pair names are left-image paths relative to the root, under 'frames_finalpass' or
    'frames_cleanpass' and ending in '/left/<frame>.png'; test pairs come from the
    FlyingThings TEST part only.
    """

    def _prepare_data(self):
        self.name = "SceneFlow"
        self._set_root()
        valid_splits = [f'{subset}_{render}' for subset in ('train', 'test') for render in ('finalpass', 'cleanpass', 'allpass')]
        valid_splits += ['test1of100_cleanpass', 'test1of100_finalpass']
        assert self.split in valid_splits

        def left_image(pairname):
            return osp.join(self.root, pairname)

        def right_image(pairname):
            return left_image(pairname).replace('/left/', '/right/')

        def left_disparity(pairname):
            # both render passes share the same ground truth under disparity/
            path = left_image(pairname)
            for render_dir in ('/frames_finalpass/', '/frames_cleanpass/'):
                path = path.replace(render_dir, '/disparity/')
            return path[:-4] + '.pfm'

        self.pairname_to_Limgname = left_image
        self.pairname_to_Rimgname = right_image
        self.pairname_to_Ldispname = left_disparity
        self.pairname_to_str = lambda pairname: pairname[:-4]
        self.load_disparity = _read_sceneflow_disp

    def _build_cache(self):
        def relative_glob(pattern):
            # sorted matches, made relative to the dataset root
            return [path[len(self.root):] for path in sorted(glob(self.root + pattern))]

        def to_cleanpass(pairs):
            return [p.replace('frames_finalpass', 'frames_cleanpass') for p in pairs]

        trainpairs = []
        train_parts = (
            ('Driving/frames_finalpass/*/*/*/left/*.png', 4400),
            ('Monkaa/frames_finalpass/*/left/*.png', 8664),
            ('FlyingThings/frames_finalpass/TRAIN/*/*/left/*.png', 22390),
        )
        for pattern, expected in train_parts:
            pairs = relative_glob(pattern)
            assert len(pairs) == expected, "incorrect parsing of pairs in SceneFlow"
            trainpairs += pairs
        assert len(trainpairs) == 35454, "incorrect parsing of pairs in SceneFlow"

        testpairs = relative_glob('FlyingThings/frames_finalpass/TEST/*/*/left/*.png')
        assert len(testpairs) == 4370, "incorrect parsing of pairs in SceneFlow"
        test1of100pairs = testpairs[::100]
        assert len(test1of100pairs) == 44, "incorrect parsing of pairs in SceneFlow"

        tosave = {
            'train_finalpass': trainpairs,
            'train_cleanpass': to_cleanpass(trainpairs),
            'test_finalpass': testpairs,
            'test_cleanpass': to_cleanpass(testpairs),
            'test1of100_finalpass': test1of100pairs,
            'test1of100_cleanpass': to_cleanpass(test1of100pairs),
        }
        tosave['train_allpass'] = tosave['train_finalpass'] + tosave['train_cleanpass']
        tosave['test_allpass'] = tosave['test_finalpass'] + tosave['test_cleanpass']
        return tosave
class Md21Dataset(StereoDataset):
    """Middlebury 2021: every ambient-lighting left view 'im0*' of a scene is one pair.

    The last two scenes (in sorted order) form the 'subval' split.
    """

    def _prepare_data(self):
        self.name = "Middlebury2021"
        self._set_root()
        assert self.split in ['train', 'subtrain', 'subval']

        def left_image(pairname):
            return osp.join(self.root, pairname)

        def right_image(pairname):
            return osp.join(self.root, pairname.replace('/im0', '/im1'))

        def left_disparity(pairname):
            # a single ground truth per scene, shared by all its lightings
            scene = pairname.split('/')[0]
            return osp.join(self.root, scene, 'disp0.pfm')

        self.pairname_to_Limgname = left_image
        self.pairname_to_Rimgname = right_image
        self.pairname_to_Ldispname = left_disparity
        self.pairname_to_str = lambda pairname: pairname[:-4]
        self.load_disparity = _read_middlebury_disp

    def _build_cache(self):
        seqs = sorted(os.listdir(self.root))
        trainpairs = []
        for seq in seqs:
            # the scene-level im0.png is not used: it is included as such in the lightings
            ambient_dir = osp.join(self.root, seq, 'ambient')
            for lighting in sorted(os.listdir(ambient_dir)):
                for fname in sorted(os.listdir(osp.join(ambient_dir, lighting))):
                    if fname.startswith('im0'):
                        trainpairs.append(seq + '/ambient/' + lighting + '/' + fname)
        assert len(trainpairs)==355
        train_prefixes = tuple(seq + '/' for seq in seqs[:-2])
        val_prefixes = tuple(seq + '/' for seq in seqs[-2:])
        subtrainpairs = [p for p in trainpairs if p.startswith(train_prefixes)]
        subvalpairs = [p for p in trainpairs if p.startswith(val_prefixes)]
        assert len(subtrainpairs)==335 and len(subvalpairs)==20, "incorrect parsing of pairs in Middlebury 2021"
        return dict(train=trainpairs, subtrain=subtrainpairs, subval=subvalpairs)
class Md14Dataset(StereoDataset):
    """Middlebury 2014: pair names are the right-image variants of each scene.

    The variants are im1.png, im1E.png and im1L.png. The left image is always the
    scene's im0.png, and the ground truth is its disp0.pfm.
    """

    def _prepare_data(self):
        self.name = "Middlebury2014"
        self._set_root()
        assert self.split in ['train', 'subtrain', 'subval']

        def left_image(pairname):
            return osp.join(self.root, osp.dirname(pairname), 'im0.png')

        def right_image(pairname):
            return osp.join(self.root, pairname)

        def left_disparity(pairname):
            return osp.join(self.root, osp.dirname(pairname), 'disp0.pfm')

        self.pairname_to_Limgname = left_image
        self.pairname_to_Rimgname = right_image
        self.pairname_to_Ldispname = left_disparity
        self.pairname_to_str = lambda pairname: pairname[:-4]
        self.load_disparity = _read_middlebury_disp
        self.has_constant_resolution = False

    def _build_cache(self):
        seqs = sorted(os.listdir(self.root))
        trainpairs = [seq + '/' + right for seq in seqs for right in ('im1.png', 'im1E.png', 'im1L.png')]
        assert len(trainpairs)==138
        valseqs = ['Umbrella-imperfect', 'Vintage-perfect']
        assert all(s in seqs for s in valseqs)
        val_prefixes = tuple(s + '/' for s in valseqs)
        subtrainpairs = [p for p in trainpairs if not p.startswith(val_prefixes)]
        subvalpairs = [p for p in trainpairs if p.startswith(val_prefixes)]
        assert len(subtrainpairs)==132 and len(subvalpairs)==6, "incorrect parsing of pairs in Middlebury 2014"
        return dict(train=trainpairs, subtrain=subtrainpairs, subval=subvalpairs)
class Md06Dataset(StereoDataset):
    """Middlebury 2006: left 'view1' vs right 'view5' for each illumination x exposure.

    Pair names are '<scene>/<IllumK>/<ExpK>/view1.png'. Scenes 'Rocks1' and 'Wood2'
    form the 'subval' split.
    """

    def _prepare_data(self):
        self.name = "Middlebury2006"
        self._set_root()
        assert self.split in ['train', 'subtrain', 'subval']
        self.pairname_to_Limgname = lambda pairname: osp.join(self.root, pairname)
        self.pairname_to_Rimgname = lambda pairname: osp.join(self.root, osp.dirname(pairname), 'view5.png')
        # one ground-truth disparity per scene, shared by all illuminations / exposures
        self.pairname_to_Ldispname = lambda pairname: osp.join(self.root, pairname.split('/')[0], 'disp1.png')
        # same convention as the Middlebury 2005 dataset: drop the '.png' extension
        self.pairname_to_str = lambda pairname: pairname[:-4]
        self.load_disparity = _read_middlebury20052006_disp
        self.has_constant_resolution = False

    def _build_cache(self):
        seqs = sorted(os.listdir(self.root))
        trainpairs = [osp.join(s, i, e, 'view1.png')
                      for s in seqs
                      for i in ['Illum1', 'Illum2', 'Illum3']
                      for e in ['Exp0', 'Exp1', 'Exp2']]
        assert len(trainpairs)==189
        valseqs = ['Rocks1', 'Wood2']
        assert all(s in seqs for s in valseqs)
        subtrainpairs = [p for p in trainpairs if not any(p.startswith(s+'/') for s in valseqs)]
        subvalpairs = [p for p in trainpairs if any(p.startswith(s+'/') for s in valseqs)]
        assert len(subtrainpairs)==171 and len(subvalpairs)==18, "incorrect parsing of pairs in Middlebury 2006"
        tosave = {'train': trainpairs, 'subtrain': subtrainpairs, 'subval': subvalpairs}
        return tosave
class Md05Dataset(StereoDataset):
    """Middlebury 2005: left 'view1' vs right 'view5' for each illumination x exposure.

    Pair names are '<scene>/<IllumK>/<ExpK>/view1.png'. Scene 'Reindeer' forms the
    'subval' split.
    """

    def _prepare_data(self):
        self.name = "Middlebury2005"
        self._set_root()
        assert self.split in ['train', 'subtrain', 'subval']
        self.pairname_to_Limgname = lambda pairname: osp.join(self.root, pairname)
        self.pairname_to_Rimgname = lambda pairname: osp.join(self.root, osp.dirname(pairname), 'view5.png')
        # one ground-truth disparity per scene, shared by all illuminations / exposures
        self.pairname_to_Ldispname = lambda pairname: osp.join(self.root, pairname.split('/')[0], 'disp1.png')
        self.pairname_to_str = lambda pairname: pairname[:-4]
        self.load_disparity = _read_middlebury20052006_disp
        # resolution is not assumed constant across scenes (same setting as Middlebury 2006),
        # which keeps test-time batching off for this dataset
        self.has_constant_resolution = False

    def _build_cache(self):
        seqs = sorted(os.listdir(self.root))
        trainpairs = [osp.join(s, i, e, 'view1.png')
                      for s in seqs
                      for i in ['Illum1', 'Illum2', 'Illum3']
                      for e in ['Exp0', 'Exp1', 'Exp2']]
        assert len(trainpairs)==54, "incorrect parsing of pairs in Middlebury 2005"
        valseqs = ['Reindeer']
        assert all(s in seqs for s in valseqs)
        subtrainpairs = [p for p in trainpairs if not any(p.startswith(s+'/') for s in valseqs)]
        subvalpairs = [p for p in trainpairs if any(p.startswith(s+'/') for s in valseqs)]
        assert len(subtrainpairs)==45 and len(subvalpairs)==9, "incorrect parsing of pairs in Middlebury 2005"
        tosave = {'train': trainpairs, 'subtrain': subtrainpairs, 'subval': subvalpairs}
        return tosave
class MdEval3Dataset(StereoDataset):
    """Middlebury Stereo Evaluation v3.

    Splits are '<subset>_<resolution>' with subset in train / subtrain / subval / test /
    all and resolution in full / half / quarter. Test pairs have no ground truth.
    """

    def _prepare_data(self):
        self.name = "MiddleburyEval3"
        self._set_root()
        assert self.split in [s+'_'+r for s in ['train','subtrain','subval','test','all'] for r in ['full','half','quarter']]
        # full and half resolution data live in sibling directories; quarter uses the root as-is
        if self.split.endswith('_full'):
            self.root = self.root.replace('/MiddEval3','/MiddEval3_F')
        elif self.split.endswith('_half'):
            self.root = self.root.replace('/MiddEval3','/MiddEval3_H')
        else:
            assert self.split.endswith('_quarter')
        self.pairname_to_Limgname = lambda pairname: osp.join(self.root, pairname, 'im0.png')
        self.pairname_to_Rimgname = lambda pairname: osp.join(self.root, pairname, 'im1.png')
        self.pairname_to_Ldispname = lambda pairname: None if pairname.startswith('test') else osp.join(self.root, pairname, 'disp0GT.pfm')
        self.pairname_to_str = lambda pairname: pairname
        self.load_disparity = _read_middlebury_disp
        # for submission only
        self.submission_methodname = "CroCo-Stereo"
        self.submission_sresolution = 'F' if self.split.endswith('_full') else ('H' if self.split.endswith('_half') else 'Q')

    def _build_cache(self):
        trainpairs = ['train/'+s for s in sorted(os.listdir(self.root+'train/'))]
        testpairs = ['test/'+s for s in sorted(os.listdir(self.root+'test/'))]
        subvalpairs = trainpairs[-1:]
        subtrainpairs = trainpairs[:-1]
        allpairs = trainpairs+testpairs
        assert len(trainpairs)==15 and len(testpairs)==15 and len(subvalpairs)==1 and len(subtrainpairs)==14 and len(allpairs)==30, "incorrect parsing of pairs in Middlebury Eval v3"
        tosave = {}
        for r in ['full','half','quarter']:
            tosave.update(**{'train_'+r: trainpairs, 'subtrain_'+r: subtrainpairs, 'subval_'+r: subvalpairs, 'test_'+r: testpairs, 'all_'+r: allpairs})
        return tosave

    def submission_save_pairname(self, pairname, prediction, outdir, time):
        """Write one float32 (H, W) prediction and its runtime in the MiddEval3 layout.

        Files go to '<outdir>/<training|test><F|H|Q>/<scene>/disp0<method>.pfm' and
        'time<method>.txt' next to it.
        """
        assert prediction.ndim==2
        assert prediction.dtype==np.float32
        outfile = os.path.join(outdir, pairname.split('/')[0].replace('train','training')+self.submission_sresolution, pairname.split('/')[1], 'disp0'+self.submission_methodname+'.pfm')
        os.makedirs( os.path.dirname(outfile), exist_ok=True)
        writePFM(outfile, prediction)
        timefile = os.path.join( os.path.dirname(outfile), "time"+self.submission_methodname+'.txt')
        with open(timefile, 'w') as fid:
            fid.write(str(time))

    def finalize_submission(self, outdir):
        """Zip the content of ``outdir`` into '<method>.zip' inside it."""
        # outdir is quoted so that paths containing spaces work
        cmd = f'cd "{outdir}/"; zip -r "{self.submission_methodname}.zip" .'
        print(cmd)
        os.system(cmd)
        print(f'Done. Submission file at {outdir}/{self.submission_methodname}.zip')
class ETH3DLowResDataset(StereoDataset):
    """ETH3D low-resolution two-view stereo; test pairs have no ground truth.

    Ground truth of 'train/<scene>' is read from 'train_gt/<scene>/disp0GT.pfm'.
    """

    def _prepare_data(self):
        self.name = "ETH3DLowRes"
        self._set_root()
        assert self.split in ['train','test','subtrain','subval','all']
        self.pairname_to_Limgname = lambda pairname: osp.join(self.root, pairname, 'im0.png')
        self.pairname_to_Rimgname = lambda pairname: osp.join(self.root, pairname, 'im1.png')

        def left_disparity(pairname):
            # the 'all' split mixes train and test pairs; only train pairs have GT
            if pairname.startswith('test/'):
                return None
            return osp.join(self.root, pairname.replace('train/','train_gt/'), 'disp0GT.pfm')

        self.pairname_to_Ldispname = None if self.split=='test' else left_disparity
        self.pairname_to_str = lambda pairname: pairname
        self.load_disparity = _read_eth3d_disp
        self.has_constant_resolution = False

    def _build_cache(self):
        trainpairs = ['train/' + s for s in sorted(os.listdir(self.root+'train/'))]
        testpairs = ['test/' + s for s in sorted(os.listdir(self.root+'test/'))]
        assert len(trainpairs) == 27 and len(testpairs) == 20, "incorrect parsing of pairs in ETH3D Low Res"
        subvalpairs = ['train/delivery_area_3s','train/electro_3l','train/playground_3l']
        assert all(p in trainpairs for p in subvalpairs)
        subtrainpairs = [p for p in trainpairs if p not in subvalpairs]
        assert len(subvalpairs)==3 and len(subtrainpairs)==24, "incorrect parsing of pairs in ETH3D Low Res"
        tosave = {'train': trainpairs, 'test': testpairs, 'subtrain': subtrainpairs, 'subval': subvalpairs, 'all': trainpairs+testpairs}
        return tosave

    def submission_save_pairname(self, pairname, prediction, outdir, time):
        """Write one float32 (H, W) prediction as PFM plus a 'runtime <t>' text file."""
        assert prediction.ndim==2
        assert prediction.dtype==np.float32
        outfile = os.path.join(outdir, 'low_res_two_view', pairname.split('/')[1]+'.pfm')
        os.makedirs( os.path.dirname(outfile), exist_ok=True)
        writePFM(outfile, prediction)
        timefile = outfile[:-4]+'.txt'
        with open(timefile, 'w') as fid:
            fid.write('runtime '+str(time))

    def finalize_submission(self, outdir):
        """Zip '<outdir>/low_res_two_view' into the ETH3D submission archive."""
        # outdir is quoted so that paths containing spaces work
        cmd = f'cd "{outdir}/"; zip -r "eth3d_low_res_two_view_results.zip" low_res_two_view'
        print(cmd)
        os.system(cmd)
        print(f'Done. Submission file at {outdir}/eth3d_low_res_two_view_results.zip')
class BoosterDataset(StereoDataset):
    """Booster, balanced version only: camera_00 is the left view and camera_02 the right.

    Pair names are '<train|test>/balanced/<scene>/camera_00/<image>'. All images of a
    scene share the scene's 'disp_00.npy' ground truth.
    """

    def _prepare_data(self):
        self.name = "Booster"
        self._set_root()
        assert self.split in ['train_balanced','test_balanced','subtrain_balanced','subval_balanced'] # we use only the balanced version
        self.pairname_to_Limgname = lambda pairname: osp.join(self.root, pairname)
        self.pairname_to_Rimgname = lambda pairname: osp.join(self.root, pairname).replace('/camera_00/','/camera_02/')
        self.pairname_to_Ldispname = lambda pairname: osp.join(self.root, osp.dirname(pairname), '../disp_00.npy') # same images with different colors, same gt per sequence
        self.pairname_to_str = lambda pairname: pairname[:-4].replace('/camera_00/','/')
        self.load_disparity = _read_booster_disp

    def _build_cache(self):
        trainseqs = sorted(os.listdir(self.root+'train/balanced'))
        trainpairs = ['train/balanced/'+s+'/camera_00/'+imname for s in trainseqs for imname in sorted(os.listdir(self.root+'train/balanced/'+s+'/camera_00/'))]
        testpairs = ['test/balanced/'+s+'/camera_00/'+imname for s in sorted(os.listdir(self.root+'test/balanced')) for imname in sorted(os.listdir(self.root+'test/balanced/'+s+'/camera_00/'))]
        assert len(trainpairs) == 228 and len(testpairs) == 191, "incorrect parsing of pairs in Booster"
        # Validation holds out whole scenes (the last two). A pair's scene is its third
        # path component; compare it exactly, since a substring test could assign a pair
        # to the wrong subset (or to both).
        subvalseqs = set(trainseqs[-2:])
        subtrainpairs = [p for p in trainpairs if p.split('/')[2] not in subvalseqs]
        subvalpairs = [p for p in trainpairs if p.split('/')[2] in subvalseqs]
        tosave = {'train_balanced': trainpairs, 'test_balanced': testpairs, 'subtrain_balanced': subtrainpairs, 'subval_balanced': subvalpairs,}
        return tosave
class SpringDataset(StereoDataset):
    """Spring stereo benchmark.

    Pair names are '<train|test>/<seq>/frame_<left|right>/<frame>' without extension;
    the "left" image of a pair is the named view and the "right" image is the other
    view. Test pairs cover both views and have no ground truth.
    """

    def _prepare_data(self):
        self.name = "Spring"
        self._set_root()
        assert self.split in ['train', 'test', 'subtrain', 'subval']

        def swap_sides(pairname):
            # frame_left <-> frame_right through a placeholder so the replacements don't collide
            return pairname.replace('frame_right','<frame_right>').replace('frame_left','frame_right').replace('<frame_right>','frame_left')

        def to_disparity(pairname):
            # replacements are applied to the pair name only, never to the root directory
            return (pairname+'.dsp5').replace('frame_left','disp1_left').replace('frame_right','disp1_right')

        self.pairname_to_Limgname = lambda pairname: osp.join(self.root, pairname+'.png')
        self.pairname_to_Rimgname = lambda pairname: osp.join(self.root, swap_sides(pairname)+'.png')
        self.pairname_to_Ldispname = lambda pairname: None if pairname.startswith('test') else osp.join(self.root, to_disparity(pairname))
        self.pairname_to_str = lambda pairname: pairname
        self.load_disparity = _read_hdf5_disp

    def _build_cache(self):
        trainseqs = sorted(os.listdir( osp.join(self.root,'train')))
        trainpairs = [osp.join('train',s,'frame_left',f[:-4]) for s in trainseqs for f in sorted(os.listdir(osp.join(self.root,'train',s,'frame_left')))]
        testseqs = sorted(os.listdir( osp.join(self.root,'test')))
        testpairs = [osp.join('test',s,'frame_left',f[:-4]) for s in testseqs for f in sorted(os.listdir(osp.join(self.root,'test',s,'frame_left')))]
        # test disparities are predicted for both views
        testpairs += [p.replace('frame_left','frame_right') for p in testpairs]
        # Sequence 0041 is held out for validation. Per-sequence 'maxnorm' values noted
        # when choosing it:
        # {'0001': 32.88, '0002': 228.5, '0004': 298.2, '0005': 142.5, '0006': 113.6,
        #  '0007': 27.3, '0008': 554.5, '0009': 155.6, '0010': 126.1, '0011': 87.6,
        #  '0012': 303.2, '0013': 24.14, '0014': 82.56, '0015': 98.44, '0016': 156.9,
        #  '0017': 28.17, '0018': 21.03, '0020': 178.0, '0021': 58.06, '0022': 354.2,
        #  '0023': 8.79, '0024': 97.06, '0025': 55.16, '0026': 91.9, '0027': 156.6,
        #  '0030': 200.4, '0032': 58.66, '0033': 373.5, '0036': 149.4, '0037': 5.625,
        #  '0038': 37.0, '0039': 12.2, '0041': 453.5, '0043': 457.0, '0044': 379.5,
        #  '0045': 161.8, '0047': 105.44}
        subtrainpairs = [p for p in trainpairs if p.split('/')[1]!='0041']
        subvalpairs = [p for p in trainpairs if p.split('/')[1]=='0041']
        assert len(trainpairs)==5000 and len(testpairs)==2000 and len(subtrainpairs)==4904 and len(subvalpairs)==96, "incorrect parsing of pairs in Spring"
        tosave = {'train': trainpairs, 'test': testpairs, 'subtrain': subtrainpairs, 'subval': subvalpairs}
        return tosave

    def submission_save_pairname(self, pairname, prediction, outdir, time):
        """Write one float32 (H, W) prediction as '<outdir>/.../disp1_<side>/<frame>.dsp5'."""
        assert prediction.ndim==2
        assert prediction.dtype==np.float32
        # replacements are applied to the pair name only, so outdir is never rewritten
        relname = (pairname+'.dsp5').replace('frame_left','disp1_left').replace('frame_right','disp1_right')
        outfile = os.path.join(outdir, relname)
        os.makedirs( os.path.dirname(outfile), exist_ok=True)
        writeDsp5File(prediction, outfile)

    def finalize_submission(self, outdir):
        """Run Spring's 'disp1_subsampling' tool on '<outdir>/test' when it is available."""
        assert self.split=='test'
        # absolute path: the command below changes directory before running the tool
        exe = osp.abspath(osp.join(self.root, 'disp1_subsampling'))
        if os.path.isfile(exe):
            cmd = f'cd "{outdir}/test"; "{exe}" .'
            print(cmd)
            os.system(cmd)
        else:
            print('Could not find disp1_subsampling executable for submission.')
            print('Please download it and run:')
            print(f'cd "{outdir}/test"; <disp1_subsampling_exe> .')
class Kitti12Dataset(StereoDataset):
    """KITTI 2012 stereo: 'colored_0' (left) / 'colored_1' (right) frames '<id>_10.png'.

    Ground truth is read from 'disp_occ'; the test split has none.
    """

    def _prepare_data(self):
        self.name = "Kitti12"
        self._set_root()
        assert self.split in ['train','test']
        self.pairname_to_Limgname = lambda pairname: osp.join(self.root, pairname+'_10.png')
        self.pairname_to_Rimgname = lambda pairname: osp.join(self.root, pairname.replace('/colored_0/','/colored_1/')+'_10.png')
        self.pairname_to_Ldispname = None if self.split=='test' else lambda pairname: osp.join(self.root, pairname.replace('/colored_0/','/disp_occ/')+'_10.png')
        self.pairname_to_str = lambda pairname: pairname.replace('/colored_0/','/')
        self.load_disparity = _read_kitti_disp

    def _build_cache(self):
        trainseqs = ["training/colored_0/%06d"%(i) for i in range(194)]
        testseqs = ["testing/colored_0/%06d"%(i) for i in range(195)]
        assert len(trainseqs)==194 and len(testseqs)==195, "incorrect parsing of pairs in Kitti12"
        tosave = {'train': trainseqs, 'test': testseqs}
        return tosave

    def submission_save_pairname(self, pairname, prediction, outdir, time):
        """Write one float32 (H, W) prediction as a KITTI uint16 PNG '<outdir>/<id>_10.png'."""
        assert prediction.ndim==2
        assert prediction.dtype==np.float32
        outfile = os.path.join(outdir, pairname.split('/')[-1]+'_10.png')
        os.makedirs( os.path.dirname(outfile), exist_ok=True)
        # disparity is stored *256 in uint16 (decoded with coef=256 by _read_kitti_disp);
        # clip so negative or too-large predictions saturate instead of wrapping around
        img = np.clip(prediction * 256, 0, 65535).astype('uint16')
        Image.fromarray(img).save(outfile)

    def finalize_submission(self, outdir):
        """Zip the content of ``outdir`` into 'kitti12_results.zip' inside it."""
        assert self.split=='test'
        # outdir is quoted so that paths containing spaces work
        cmd = f'cd "{outdir}/"; zip -r "kitti12_results.zip" .'
        print(cmd)
        os.system(cmd)
        print(f'Done. Submission file at {outdir}/kitti12_results.zip')
class Kitti15Dataset(StereoDataset):
    """KITTI 2015 stereo: 'image_2' (left) / 'image_3' (right) frames '<id>_10.png'.

    Ground truth is read from 'disp_occ_0'; the test split has none. The last five
    training frames form the 'subval' split.
    """

    def _prepare_data(self):
        self.name = "Kitti15"
        self._set_root()
        assert self.split in ['train','subtrain','subval','test']
        self.pairname_to_Limgname = lambda pairname: osp.join(self.root, pairname+'_10.png')
        self.pairname_to_Rimgname = lambda pairname: osp.join(self.root, pairname.replace('/image_2/','/image_3/')+'_10.png')
        self.pairname_to_Ldispname = None if self.split=='test' else lambda pairname: osp.join(self.root, pairname.replace('/image_2/','/disp_occ_0/')+'_10.png')
        self.pairname_to_str = lambda pairname: pairname.replace('/image_2/','/')
        self.load_disparity = _read_kitti_disp

    def _build_cache(self):
        trainseqs = ["training/image_2/%06d"%(i) for i in range(200)]
        subtrainseqs = trainseqs[:-5]
        subvalseqs = trainseqs[-5:]
        testseqs = ["testing/image_2/%06d"%(i) for i in range(200)]
        assert len(trainseqs)==200 and len(subtrainseqs)==195 and len(subvalseqs)==5 and len(testseqs)==200, "incorrect parsing of pairs in Kitti15"
        tosave = {'train': trainseqs, 'subtrain': subtrainseqs, 'subval': subvalseqs, 'test': testseqs}
        return tosave

    def submission_save_pairname(self, pairname, prediction, outdir, time):
        """Write one float32 (H, W) prediction as a KITTI uint16 PNG in '<outdir>/disp_0'."""
        assert prediction.ndim==2
        assert prediction.dtype==np.float32
        outfile = os.path.join(outdir, 'disp_0', pairname.split('/')[-1]+'_10.png')
        os.makedirs( os.path.dirname(outfile), exist_ok=True)
        # disparity is stored *256 in uint16 (decoded with coef=256 by _read_kitti_disp);
        # clip so negative or too-large predictions saturate instead of wrapping around
        img = np.clip(prediction * 256, 0, 65535).astype('uint16')
        Image.fromarray(img).save(outfile)

    def finalize_submission(self, outdir):
        """Zip '<outdir>/disp_0' into 'kitti15_results.zip' inside ``outdir``."""
        assert self.split=='test'
        # outdir is quoted so that paths containing spaces work
        cmd = f'cd "{outdir}/"; zip -r "kitti15_results.zip" disp_0'
        print(cmd)
        os.system(cmd)
        print(f'Done. Submission file at {outdir}/kitti15_results.zip')
### auxiliary functions | |
def _read_img(filename):
    """Load an image file as a numpy array with three RGB channels."""
    # the RGB conversion is needed for scene flow finalpass data
    rgb_image = Image.open(filename).convert('RGB')
    return np.asarray(rgb_image)
def _read_booster_disp(filename): | |
disp = np.load(filename) | |
disp[disp==0.0] = np.inf | |
return disp | |
def _read_png_disp(filename, coef=1.0):
    """Decode a PNG disparity map as float32 stored value / ``coef``; zeros become +inf."""
    stored = np.asarray(Image.open(filename))
    disp = stored.astype(np.float32) / coef
    invalid = disp == 0.0
    disp[invalid] = np.inf
    return disp
def _read_pfm_disp(filename):
    """Load a PFM disparity map as a C-contiguous array; values <= 0 become +inf."""
    pfm_data, _scale = _read_pfm(filename)
    disp = np.ascontiguousarray(pfm_data)
    # some ground truths contain such values, e.g. middlebury/2014/Shopvac-imperfect/disp0.pfm
    disp[disp <= 0] = np.inf
    return disp
def _read_npy_disp(filename): | |
return np.load(filename) | |
def _read_crestereo_disp(filename):
    """CREStereo disparity PNG: stored values are divided by 32."""
    return _read_png_disp(filename, coef=32.0)


def _read_middlebury20052006_disp(filename):
    """Middlebury 2005/2006 disparity PNG: stored values are used as-is (coef 1)."""
    return _read_png_disp(filename, coef=1.0)


def _read_kitti_disp(filename):
    """KITTI disparity PNG: stored values are divided by 256."""
    return _read_png_disp(filename, coef=256.0)


# datasets whose files need no dataset-specific decoding reuse the generic readers
_read_sceneflow_disp = _read_pfm_disp
_read_eth3d_disp = _read_pfm_disp
_read_middlebury_disp = _read_pfm_disp
_read_carla_disp = _read_pfm_disp
_read_tartanair_disp = _read_npy_disp
def _read_hdf5_disp(filename):
    """Read the 'disparity' dataset of an HDF5 file (e.g. a Spring .dsp5) as float32.

    NaN entries (invalid) are replaced by +inf; zero disparities are left unchanged.
    """
    # explicit read-only mode (older h5py defaults to append), and the file is closed
    # as soon as the data has been read into memory
    with h5py.File(filename, 'r') as fid:
        disp = np.asarray(fid['disparity'])
    disp[np.isnan(disp)] = np.inf  # make invalid values as +inf
    return disp.astype(np.float32)
import re | |
def _read_pfm(file): | |
file = open(file, 'rb') | |
color = None | |
width = None | |
height = None | |
scale = None | |
endian = None | |
header = file.readline().rstrip() | |
if header.decode("ascii") == 'PF': | |
color = True | |
elif header.decode("ascii") == 'Pf': | |
color = False | |
else: | |
raise Exception('Not a PFM file.') | |
dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode("ascii")) | |
if dim_match: | |
width, height = list(map(int, dim_match.groups())) | |
else: | |
raise Exception('Malformed PFM header.') | |
scale = float(file.readline().decode("ascii").rstrip()) | |
if scale < 0: # little-endian | |
endian = '<' | |
scale = -scale | |
else: | |
endian = '>' # big-endian | |
data = np.fromfile(file, endian + 'f') | |
shape = (height, width, 3) if color else (height, width) | |
data = np.reshape(data, shape) | |
data = np.flipud(data) | |
return data, scale | |
def writePFM(file, image, scale=1):
    """Write a float32 image to ``file`` in PFM format.

    Args:
        file: output path.
        image: float32 array of shape (H, W), (H, W, 1) (greyscale, 'Pf') or (H, W, 3)
            (colour, 'PF'). Rows are written bottom-to-top, as PFM requires.
        scale: absolute value of the header scale; its sign is set from the byte order.

    Raises:
        Exception: on a non-float32 dtype or an unsupported shape. Nothing is written then.
    """
    if image.dtype.name != 'float32':
        raise Exception('Image dtype must be float32.')
    image = np.flipud(image)
    if image.ndim == 3 and image.shape[2] == 3:  # color image
        color = True
    elif image.ndim == 2 or (image.ndim == 3 and image.shape[2] == 1):  # greyscale
        color = False
    else:
        raise Exception('Image must have H x W x 3, H x W x 1 or H x W dimensions.')
    # a negative scale in the header marks little-endian data
    endian = image.dtype.byteorder
    if endian == '<' or (endian == '=' and sys.byteorder == 'little'):
        scale = -scale
    # the with-block closes (and flushes) the file; the header is bytes for both
    # colour and greyscale images
    with open(file, 'wb') as fid:
        fid.write(b'PF\n' if color else b'Pf\n')
        fid.write(b'%d %d\n' % (image.shape[1], image.shape[0]))
        fid.write(b'%f\n' % scale)
        image.tofile(fid)
def writeDsp5File(disp, filename):
    """Save ``disp`` as the 'disparity' dataset of a new HDF5 file (gzip level 5)."""
    h5_file = h5py.File(filename, "w")
    try:
        h5_file.create_dataset("disparity", data=disp, compression="gzip", compression_opts=5)
    finally:
        h5_file.close()
# disp visualization | |
def vis_disparity(disp, m=None, M=None):
    """Colorize a numpy disparity map with OpenCV's inferno colormap.

    Args:
        disp: numpy disparity map. Non-finite entries (the readers in this file mark
            invalid disparities as +inf) saturate at the range ends, or map to 0 for NaN.
        m, M: values mapped to the two ends of the colormap. They default to the min and
            max of the finite entries of ``disp`` (of all entries if none is finite).
            Values outside [m, M] are clipped.

    Returns:
        the uint8 color image produced by ``cv2.applyColorMap``.
    """
    finite = disp[np.isfinite(disp)]
    reference = finite if finite.size > 0 else disp
    if m is None:
        m = reference.min()
    if M is None:
        M = reference.max()
    span = M - m
    if span == 0:
        span = 1.0  # constant map: avoid 0/0, every value goes to the low end
    disp_vis = (disp - m) / span * 255.0
    # clip before the uint8 cast, which would otherwise wrap out-of-range values around
    disp_vis = np.nan_to_num(np.clip(disp_vis, 0.0, 255.0), nan=0.0)
    disp_vis = disp_vis.astype("uint8")
    disp_vis = cv2.applyColorMap(disp_vis, cv2.COLORMAP_INFERNO)
    return disp_vis
# dataset getter | |
def get_train_dataset_stereo(dataset_str, augmentor=True, crop_size=None):
    """Build a training dataset from a constructor expression such as "CREStereo('train')".

    'Dataset' is appended to every class name. When requested, ``augmentor=True`` and
    ``crop_size=<crop_size>`` are appended to every call, and the expression is
    evaluated in this module.
    """
    expr = dataset_str.replace('(', 'Dataset(')
    if augmentor:
        expr = expr.replace(')', ', augmentor=True)')
    if crop_size is not None:
        expr = expr.replace(')', f', crop_size={crop_size!s})')
    # NOTE(review): eval runs arbitrary code; dataset_str must come from a trusted source
    return eval(expr)
def get_test_datasets_stereo(dataset_str):
    """Build one dataset per '+'-separated constructor expression of ``dataset_str``.

    E.g. "Kitti15('test')+Spring('test')" -> [Kitti15Dataset('test'), SpringDataset('test')].
    """
    expr = dataset_str.replace('(', 'Dataset(')
    # NOTE(review): eval runs arbitrary code; dataset_str must come from a trusted source
    datasets = []
    for part in expr.split('+'):
        datasets.append(eval(part))
    return datasets