import cv2
import torch
from torch.utils.data import Dataset
from torchvision.transforms import Compose

from dataset.transform import Resize, NormalizeImage, PrepareForNet, Crop


class PBRDataset(Dataset):
    """Dataset of (image, depth map) pairs listed in an ``@@``-delimited file.

    Each line of the file list should have the form
    ``<image_path>@@<depth_path>``; lines that do not split into exactly two
    parts on ``@@`` are skipped.

    Items are dicts with the keys ``'image'``, ``'depth'``, ``'valid_mask'``
    (an all-True bool tensor shaped like ``'depth'``) and ``'image_path'``.

    A pair that fails to load is skipped: the next pair in the list (wrapping
    around to the start) is returned instead, so a few unreadable files do not
    abort an epoch.

    Args:
        filelist_path: Path to the ``@@``-delimited pair list.
        mode: Split name. ``'train'`` passes ``resize_target=True`` to
            ``Resize`` and appends a ``Crop`` transform.
        size: ``(width, height)`` handed to ``Resize``; ``size[0]`` is also the
            argument given to ``Crop`` in train mode (presumably a square crop
            side -- confirm against ``dataset.transform.Crop``).
    """

    def __init__(self, filelist_path, mode, size=(512, 512)):
        self.mode = mode
        self.size = size

        self.filelist = self._read_filelist(filelist_path)
        print(f"Loaded {len(self.filelist)} image pairs")

        net_w, net_h = size
        is_train = mode == 'train'
        transforms = [
            Resize(
                width=net_w,
                height=net_h,
                resize_target=is_train,
                keep_aspect_ratio=True,
                ensure_multiple_of=12,
                resize_method='lower_bound',
                image_interpolation_method=cv2.INTER_CUBIC,
            ),
            # Standard ImageNet RGB mean / std.
            NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            PrepareForNet(),
        ]
        if is_train:
            transforms.append(Crop(size[0]))
        self.transform = Compose(transforms)

    @staticmethod
    def _read_filelist(filelist_path):
        """Parse ``<image>@@<depth>`` lines into a list of ``(image, depth)`` path tuples."""
        pairs = []
        with open(filelist_path, 'r', encoding='utf-8') as f:
            for line in f:
                parts = line.strip().split('@@')
                # Lines without '@@', or with more than one, are skipped.
                if len(parts) == 2:
                    pairs.append((parts[0].strip(), parts[1].strip()))
        return pairs

    def __getitem__(self, item):
        """Return sample ``item``, falling back to the following pairs if it fails to load.

        Raises:
            IndexError: If ``item`` is outside ``[-len(self), len(self))``. This
                also lets plain ``for sample in dataset`` iteration terminate.
            RuntimeError: If no pair in the whole list can be loaded.
        """
        num_pairs = len(self.filelist)
        if not -num_pairs <= item < num_pairs:
            raise IndexError(
                f"index {item} out of range for dataset of size {num_pairs}"
            )

        # Walk forward (wrapping around) iteratively instead of recursing, so a
        # long run of broken files cannot exhaust the stack and a list where
        # every pair is broken terminates with a clear error.
        for offset in range(num_pairs):
            index = (item + offset) % num_pairs
            sample = self._load_pair(index)
            if sample is not None:
                return sample
        raise RuntimeError(
            f"Could not load any of the {num_pairs} image/depth pairs"
        )

    def _load_pair(self, index):
        """Load and transform pair ``index``; return the sample dict, or ``None`` on failure."""
        img_path, disp_path = self.filelist[index]
        try:
            image = cv2.imread(img_path)
            if image is None:
                print(f"Failed to load image: {img_path}")
                return None
            # 8-bit BGR -> RGB floats in [0, 1].
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) / 255.0

            depth = cv2.imread(disp_path, cv2.IMREAD_GRAYSCALE)
            if depth is None:
                print(f"Failed to load depth: {disp_path}")
                return None
            # 8-bit grayscale -> float32 in [0, 1].
            depth = depth.astype('float32') / 255.0

            sample = self.transform({'image': image, 'depth': depth})
            sample['image'] = torch.from_numpy(sample['image'])
            sample['depth'] = torch.from_numpy(sample['depth'])
            # No per-pixel validity information is read, so every pixel is marked valid.
            sample['valid_mask'] = torch.ones_like(sample['depth'], dtype=torch.bool)
            sample['image_path'] = img_path
            return sample
        except Exception as e:
            # Deliberate best-effort: report the failure and let __getitem__
            # move on to the next pair.
            print(f"Error loading {index}: {str(e)}")
            return None

    def __len__(self):
        """Number of (image, depth) pairs parsed from the file list."""
        return len(self.filelist)