NightRaven109's picture
Upload 21 files
adfc685 verified
raw
history blame
2.79 kB
import cv2
import torch
from torch.utils.data import Dataset
from torchvision.transforms import Compose
from dataset.transform import Resize, NormalizeImage, PrepareForNet, Crop
class PBRDataset(Dataset):
    """Paired image / single-channel target dataset read from an ``@@``-delimited file list.

    Each usable line of the file list has the form ``<image_path>@@<depth_path>``.
    Samples are read with OpenCV, run through a resize / normalize / prepare
    pipeline (plus a crop in ``'train'`` mode), and returned as a dict of tensors.

    Args:
        filelist_path: Path to the text file listing the image/target pairs.
        mode: Dataset mode; ``'train'`` enables target resizing and cropping.
        size: ``(width, height)`` given to the ``Resize`` transform; in
            ``'train'`` mode ``size[0]`` is also the ``Crop`` size.
    """

    def __init__(self, filelist_path, mode, size=(512, 512)):
        self.mode = mode
        self.size = size

        # A line is kept only if it contains exactly one "@@" separating the
        # two paths; blank lines, lines without the delimiter, and lines with
        # extra delimiters are skipped.
        self.filelist = []
        with open(filelist_path, 'r') as f:
            for line in f:
                parts = line.strip().split('@@')
                if len(parts) == 2:
                    self.filelist.append((parts[0].strip(), parts[1].strip()))
        print(f"Loaded {len(self.filelist)} image pairs")

        net_w, net_h = size
        is_train = mode == 'train'
        self.transform = Compose([
            Resize(
                width=net_w,
                height=net_h,
                resize_target=is_train,
                keep_aspect_ratio=True,
                # NOTE(review): confirm 12 matches the backbone's patch size
                # (ViT-based depth models commonly use 14).
                ensure_multiple_of=12,
                resize_method='lower_bound',
                image_interpolation_method=cv2.INTER_CUBIC,
            ),
            NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            PrepareForNet(),
        ] + ([Crop(size[0])] if is_train else []))

    def __getitem__(self, item):
        """Return sample ``item``, falling back to following entries on failure.

        If entry ``item`` cannot be loaded (unreadable file or any exception
        while decoding/transforming it), the following entries are tried in
        order, wrapping around, until one succeeds. Each entry is tried at
        most once per call.

        Returns:
            dict with ``'image'`` and ``'depth'`` tensors, a boolean
            ``'valid_mask'`` shaped like ``'depth'``, and ``'image_path'``.

        Raises:
            IndexError: if ``item`` is out of range (always, for an empty list).
            RuntimeError: if no entry in the file list can be loaded.
        """
        n = len(self.filelist)
        if not -n <= item < n:
            raise IndexError(f"index {item} out of range for dataset of size {n}")
        # Iterate instead of recursing so a run of bad entries cannot exhaust
        # the stack, and so a failing fallback is never retried forever.
        for offset in range(n):
            index = (item + offset) % n
            sample = self._load_sample(index)
            if sample is not None:
                return sample
        raise RuntimeError(f"Could not load any of the {n} samples in the file list")

    def _load_sample(self, index):
        """Load and transform entry ``index``; return None if it cannot be loaded."""
        img_path, disp_path = self.filelist[index]
        try:
            image = cv2.imread(img_path)
            if image is None:
                print(f"Failed to load image: {img_path}")
                return None
            # cv2.imread returns 8-bit BGR; convert to RGB floats in [0, 1].
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) / 255.0

            # IMREAD_GRAYSCALE yields one 8-bit channel, scaled here to [0, 1].
            # NOTE(review): 16-bit target maps would be reduced to 8 bits by this
            # flag — confirm targets are stored as 8-bit images.
            depth = cv2.imread(disp_path, cv2.IMREAD_GRAYSCALE)
            if depth is None:
                print(f"Failed to load depth: {disp_path}")
                return None
            depth = depth.astype('float32') / 255.0

            sample = self.transform({'image': image, 'depth': depth})
            sample['image'] = torch.from_numpy(sample['image'])
            sample['depth'] = torch.from_numpy(sample['depth'])
            # Every target pixel is treated as valid.
            sample['valid_mask'] = torch.ones_like(sample['depth'], dtype=torch.bool)
            sample['image_path'] = img_path
            return sample
        except Exception as e:
            # Best-effort: a corrupt entry is reported and skipped, not fatal.
            print(f"Error loading {index}: {str(e)}")
            return None

    def __len__(self):
        """Return the number of image pairs parsed from the file list."""
        return len(self.filelist)