Spaces:
Running
on
Zero
Running
on
Zero
File size: 2,792 Bytes
adfc685 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 |
import cv2
import torch
from torch.utils.data import Dataset
from torchvision.transforms import Compose
from dataset.transform import Resize, NormalizeImage, PrepareForNet, Crop
class PBRDataset(Dataset):
    """Dataset of (image, depth-map) path pairs listed in a text file.

    Each line of the file list holds two paths separated by ``@@``: the
    colour image first, then the map that is loaded as grayscale and
    returned under the ``'depth'`` key. Lines that do not split into
    exactly two parts are skipped.
    """

    def __init__(self, filelist_path, mode, size=(512, 512)):
        """Read the file list and build the preprocessing pipeline.

        Args:
            filelist_path: Text file with one ``image@@depth`` pair per line.
            mode: ``'train'`` enables target resizing and appends the final
                ``Crop`` step; any other value disables both.
            size: ``(width, height)`` handed to ``Resize``; ``size[0]`` is
                also handed to ``Crop`` in train mode.
        """
        self.mode = mode
        self.size = size

        # Read filelist using @@ as delimiter between the two paths.
        self.filelist = []
        with open(filelist_path, 'r') as f:
            for line in f:
                # A line without the delimiter yields a single part and is
                # skipped, as is a line with more than one delimiter.
                parts = line.strip().split('@@')
                if len(parts) == 2:
                    self.filelist.append((parts[0].strip(), parts[1].strip()))
        print(f"Loaded {len(self.filelist)} image pairs")

        net_w, net_h = size
        steps = [
            Resize(
                width=net_w,
                height=net_h,
                resize_target=(mode == 'train'),
                keep_aspect_ratio=True,
                ensure_multiple_of=12,
                resize_method='lower_bound',
                image_interpolation_method=cv2.INTER_CUBIC,
            ),
            NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            PrepareForNet(),
        ]
        if self.mode == 'train':
            steps.append(Crop(size[0]))
        self.transform = Compose(steps)

    def __getitem__(self, item):
        """Return the preprocessed sample at ``item``.

        If the requested pair cannot be loaded, the following entries are
        tried in order (wrapping around the end of the list) until one
        loads. Every entry is tried at most once.

        Args:
            item: Integer index; negative values count from the end.

        Returns:
            dict with ``'image'`` and ``'depth'`` tensors, ``'valid_mask'``
            (an all-True bool tensor shaped like ``'depth'``) and
            ``'image_path'`` (path of the image that was actually loaded).

        Raises:
            IndexError: ``item`` is outside the dataset (always the case
                for an empty dataset).
            RuntimeError: No entry in the file list could be loaded.
        """
        total = len(self.filelist)
        # An out-of-range index must surface as IndexError; silently
        # substituting another sample would make plain iteration endless.
        if not -total <= item < total:
            raise IndexError(
                f"index {item} is out of range for dataset of size {total}"
            )

        start = item % total
        last_error = None
        # Bounded fallback: one pass over the list instead of open-ended
        # recursion, so a fully broken list fails with a clear error.
        for offset in range(total):
            index = (start + offset) % total
            try:
                sample = self._load_sample(index)
            except Exception as e:
                # Best-effort loading: report the problem, try the next pair.
                print(f"Error loading {index}: {str(e)}")
                last_error = e
                continue
            if sample is not None:
                return sample

        raise RuntimeError(
            f"Could not load any of the {total} image pairs"
        ) from last_error

    def _load_sample(self, index):
        """Load and transform one pair; return None if a file is unreadable."""
        img_path, disp_path = self.filelist[index]

        image = cv2.imread(img_path)
        if image is None:
            print(f"Failed to load image: {img_path}")
            return None
        # OpenCV decodes to BGR; convert to RGB and scale to [0, 1].
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) / 255.0

        depth = cv2.imread(disp_path, cv2.IMREAD_GRAYSCALE)
        if depth is None:
            print(f"Failed to load depth: {disp_path}")
            return None
        depth = depth.astype('float32') / 255.0

        sample = self.transform({'image': image, 'depth': depth})
        sample['image'] = torch.from_numpy(sample['image'])
        sample['depth'] = torch.from_numpy(sample['depth'])
        # Every pixel is marked valid.
        sample['valid_mask'] = torch.ones_like(sample['depth'], dtype=torch.bool)
        sample['image_path'] = img_path
        return sample

    def __len__(self):
        """Return the number of image pairs read from the file list."""
        return len(self.filelist)
|