import cv2
import torch
from torch.utils.data import Dataset
from torchvision.transforms import Compose

from dataset.transform import Resize, NormalizeImage, PrepareForNet, Crop


class VKITTI2(Dataset):
    """Virtual KITTI 2 depth dataset.

    Each line of the filelist holds an image path and a depth-map path
    separated by a single space.
    """

    def __init__(self, filelist_path, mode, size=(518, 518)):
        self.mode = mode
        self.size = size

        with open(filelist_path, 'r') as f:
            self.filelist = f.read().splitlines()

        net_w, net_h = size
        self.transform = Compose([
            Resize(
                width=net_w,
                height=net_h,
                resize_target=(mode == 'train'),  # only resize the depth target during training
                keep_aspect_ratio=True,
                ensure_multiple_of=14,  # ViT patch size
                resize_method='lower_bound',
                image_interpolation_method=cv2.INTER_CUBIC,
            ),
            # ImageNet normalization statistics
            NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            PrepareForNet(),
        ] + ([Crop(size[0])] if self.mode == 'train' else []))

    def __getitem__(self, item):
        img_path = self.filelist[item].split(' ')[0]
        depth_path = self.filelist[item].split(' ')[1]

        # OpenCV loads BGR; convert to RGB and scale to [0, 1]
        image = cv2.imread(img_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) / 255.0

        # Depth maps are 16-bit PNGs in centimeters; convert to meters
        depth = cv2.imread(depth_path, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH) / 100.0  # cm -> m

        sample = self.transform({'image': image, 'depth': depth})

        sample['image'] = torch.from_numpy(sample['image'])
        sample['depth'] = torch.from_numpy(sample['depth'])

        # Cap supervision at 80 m, the standard evaluation range
        sample['valid_mask'] = (sample['depth'] <= 80)

        sample['image_path'] = img_path

        return sample

    def __len__(self):
        return len(self.filelist)
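
# --- Usage sketch (illustrative, not part of the dataset class) ---
# A minimal example of consuming this dataset through a DataLoader.
# The filelist path below is a hypothetical placeholder; each line of
# that file must contain "<image_path> <depth_path>", matching the
# parsing in __getitem__ above.
if __name__ == '__main__':
    from torch.utils.data import DataLoader

    # 'splits/vkitti2_train.txt' is an assumed path for illustration
    dataset = VKITTI2('splits/vkitti2_train.txt', mode='train', size=(518, 518))
    loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4)

    batch = next(iter(loader))
    print(batch['image'].shape)       # (4, 3, 518, 518) after the train-time Crop
    print(batch['depth'].shape)       # (4, 518, 518), in meters
    print(batch['valid_mask'].sum())  # count of pixels within the 80 m range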