Spaces:
Running
on
Zero
Running
on
Zero
import cv2 | |
import torch | |
from torch.utils.data import Dataset | |
from torchvision.transforms import Compose | |
from dataset.transform import Resize, NormalizeImage, PrepareForNet, Crop | |
class PBRDataset(Dataset):
    """Paired image / grayscale-map dataset driven by a text filelist.

    Each usable filelist line has the form ``<image_path>@@<depth_path>``.
    Lines that do not split on ``@@`` into exactly two parts are skipped.

    Each returned sample is a dict with:
        image:      tensor from the transformed RGB image (scaled to [0, 1]
                    before ``NormalizeImage``).
        depth:      tensor from the transformed 8-bit grayscale map, scaled
                    to [0, 1].
        valid_mask: boolean tensor shaped like ``depth``, all True.
        image_path: the image path string from the filelist.

    Entries whose files cannot be read (or whose processing raises) are
    skipped: the next entry, wrapping around, is tried instead. Each entry is
    tried at most once per ``__getitem__`` call.
    """

    def __init__(self, filelist_path, mode, size=(512, 512)):
        """Read the filelist and build the transform pipeline.

        Args:
            filelist_path: path to a text file of ``image@@depth`` lines.
            mode: dataset mode; ``'train'`` enables target resizing in
                ``Resize`` and appends ``Crop(size[0])``.
            size: ``(width, height)`` handed to ``Resize``.
        """
        self.mode = mode
        self.size = size

        self.filelist = self._read_filelist(filelist_path)
        print(f"Loaded {len(self.filelist)} image pairs")

        net_w, net_h = size
        is_train = mode == 'train'
        transforms = [
            Resize(
                width=net_w,
                height=net_h,
                resize_target=is_train,
                keep_aspect_ratio=True,
                # NOTE(review): confirm 12 matches the backbone's patch size.
                ensure_multiple_of=12,
                resize_method='lower_bound',
                image_interpolation_method=cv2.INTER_CUBIC,
            ),
            NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            PrepareForNet(),
        ]
        if is_train:
            # Only size[0] is passed to Crop.
            transforms.append(Crop(size[0]))
        self.transform = Compose(transforms)

    @staticmethod
    def _read_filelist(filelist_path):
        """Return ``(image_path, depth_path)`` pairs parsed from ``filelist_path``.

        Only lines that split on ``@@`` into exactly two parts are kept; both
        parts are whitespace-stripped.
        """
        pairs = []
        with open(filelist_path, 'r', encoding='utf-8') as f:
            for raw_line in f:
                parts = raw_line.strip().split('@@')
                if len(parts) == 2:
                    pairs.append((parts[0].strip(), parts[1].strip()))
        return pairs

    def _load_sample(self, idx):
        """Load and transform entry ``idx``; return None if a file can't be read."""
        img_path, disp_path = self.filelist[idx]

        image = cv2.imread(img_path)
        if image is None:
            print(f"Failed to load image: {img_path}")
            return None
        # OpenCV decodes as BGR; convert to RGB and scale to [0, 1] (float64).
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) / 255.0

        # IMREAD_GRAYSCALE yields an 8-bit single-channel map, so /255 -> [0, 1].
        depth = cv2.imread(disp_path, cv2.IMREAD_GRAYSCALE)
        if depth is None:
            print(f"Failed to load depth: {disp_path}")
            return None
        depth = depth.astype('float32') / 255.0

        sample = self.transform({'image': image, 'depth': depth})
        # Assumes the transform pipeline leaves numpy arrays under these keys.
        sample['image'] = torch.from_numpy(sample['image'])
        sample['depth'] = torch.from_numpy(sample['depth'])
        sample['valid_mask'] = torch.ones_like(sample['depth'], dtype=torch.bool)
        sample['image_path'] = img_path
        return sample

    def __getitem__(self, item):
        """Return the sample at ``item``, falling forward past unloadable entries.

        Raises:
            IndexError: if ``item`` is outside ``[-len(self), len(self))``.
            RuntimeError: if no entry in the filelist can be loaded.
        """
        n = len(self.filelist)
        # Bounds-check before any fallback logic. Out-of-range indices must
        # raise IndexError rather than wrap around; otherwise sequence-protocol
        # iteration never terminates, and an empty dataset would divide by zero.
        if not -n <= item < n:
            raise IndexError(f"index {item} out of range for dataset of size {n}")

        idx = item % n
        # Loop instead of recursing so a run of bad files cannot exhaust the
        # stack; each entry is attempted at most once.
        for _ in range(n):
            try:
                sample = self._load_sample(idx)
            except Exception as e:  # deliberate best-effort skip of bad entries
                print(f"Error loading {idx}: {str(e)}")
                sample = None
            if sample is not None:
                return sample
            idx = (idx + 1) % n
        raise RuntimeError(f"Failed to load any of the {n} entries in the filelist")

    def __len__(self):
        """Number of ``(image, depth)`` pairs read from the filelist."""
        return len(self.filelist)