import os
import random

import torch
from torch.utils.data import Dataset
from torchvision.io import read_image


class roiLeishDataset(Dataset):
    """Paired image / mask dataset read from two directories.

    Every regular file in ``file_img`` is paired with the file of the same
    name in ``file_mask``. Items are returned as
    ``{'image': image / 255.0, 'mask': mask / 255.0}``, where the mask keeps
    only its first channel (with a leading channel dimension of size 1).

    Args:
        file_img: Directory containing the input images.
        file_mask: Directory containing the masks. NOTE(review): assumes each
            mask has exactly the same filename as its image — confirm.
        transforms: Optional callable applied to the image and then to the
            mask, replaying the same random draws so that random geometric
            augmentations stay aligned between the two.
    """

    def __init__(self, file_img, file_mask, transforms=None):
        self.file_img = file_img
        self.file_mask = file_mask
        # os.listdir order is arbitrary and filesystem-dependent; sort so the
        # index -> file mapping is reproducible across runs and machines.
        # Sub-directories are skipped because read_image cannot decode them.
        self.list_img = sorted(
            name
            for name in os.listdir(self.file_img)
            if os.path.isfile(os.path.join(self.file_img, name))
        )
        self.transforms = transforms

    def __len__(self):
        return len(self.list_img)

    def __getitem__(self, idx):
        name = self.list_img[idx]
        image = read_image(os.path.join(self.file_img, name))
        mask = read_image(os.path.join(self.file_mask, name))
        if self.transforms is not None:
            image, mask = self._transform_pair(image, mask)
        # Keep only the first mask channel, restoring a (1, ...) channel dim.
        mask = mask[0].unsqueeze(0)
        return {'image': image / 255.0, 'mask': mask / 255.0}

    def _transform_pair(self, image, mask):
        """Apply ``self.transforms`` to image and mask with identical random draws."""
        # Calling a random transform twice would otherwise sample different
        # parameters (e.g. flip/crop) for image and mask, misaligning them.
        # Rewind the torch CPU and Python RNGs so the mask replays the image's
        # draws. NOTE(review): RNGs other than these two (e.g. numpy) are not
        # rewound, and transforms whose draws depend on the channel count may
        # still diverge.
        torch_state = torch.get_rng_state()
        py_state = random.getstate()
        image = self.transforms(image)
        torch.set_rng_state(torch_state)
        random.setstate(py_state)
        mask = self.transforms(mask)
        return image, mask