LiheYoung's picture
Add Github repository content
2680cbd verified
raw
history blame
No virus
1.86 kB
import cv2
import torch
from torch.utils.data import Dataset
from torchvision.transforms import Compose
from dataset.transform import Resize, NormalizeImage, PrepareForNet
class KITTI(Dataset):
def __init__(self, filelist_path, mode, size=(518, 518)):
if mode != 'val':
raise NotImplementedError
self.mode = mode
self.size = size
with open(filelist_path, 'r') as f:
self.filelist = f.read().splitlines()
net_w, net_h = size
self.transform = Compose([
Resize(
width=net_w,
height=net_h,
resize_target=True if mode == 'train' else False,
keep_aspect_ratio=True,
ensure_multiple_of=14,
resize_method='lower_bound',
image_interpolation_method=cv2.INTER_CUBIC,
),
NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
PrepareForNet(),
])
def __getitem__(self, item):
img_path = self.filelist[item].split(' ')[0]
depth_path = self.filelist[item].split(' ')[1]
image = cv2.imread(img_path)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) / 255.0
depth = cv2.imread(depth_path, cv2.IMREAD_UNCHANGED).astype('float32')
sample = self.transform({'image': image, 'depth': depth})
sample['image'] = torch.from_numpy(sample['image'])
sample['depth'] = torch.from_numpy(sample['depth'])
sample['depth'] = sample['depth'] / 256.0 # convert in meters
sample['valid_mask'] = sample['depth'] > 0
sample['image_path'] = self.filelist[item].split(' ')[0]
return sample
def __len__(self):
return len(self.filelist)