SaffalPoosh's picture
Upload folder using huggingface_hub
e5765b1
raw
history blame
No virus
1.39 kB
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import albumentations as albu
import numpy as np
import torch
from iglovikov_helper_functions.utils.image_utils import load_rgb, load_grayscale
from pytorch_toolbelt.utils.torch_utils import tensor_from_rgb_image
from torch.utils.data import Dataset
class SegmentationDataset(Dataset):
    """Map-style dataset that yields an image together with its binary mask.

    Each entry of ``samples`` is an ``(image_path, mask_path)`` pair. The
    reported dataset length may differ from the number of pairs: indices are
    wrapped modulo ``len(samples)``, so a ``length`` larger than the number of
    pairs revisits samples within one pass.
    """

    def __init__(
        self,
        samples: List[Tuple[Path, Path]],
        transform: albu.Compose,
        length: Optional[int] = None,
    ) -> None:
        """Store the sample list and the augmentation pipeline.

        Args:
            samples: ``(image_path, mask_path)`` pairs.
            transform: Callable invoked as ``transform(image=..., mask=...)``
                that returns a mapping with ``"image"`` and ``"mask"`` keys.
            length: Value reported by ``len()``. Defaults to ``len(samples)``
                when ``None``.
        """
        self.samples = samples
        self.transform = transform
        # Compare against None explicitly so an explicit length of 0 is kept.
        self.length = len(self.samples) if length is None else length

    def __len__(self) -> int:
        """Return the configured dataset length."""
        return self.length

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Load, augment and tensorize the sample selected by ``idx``.

        Args:
            idx: Sample index. It is wrapped modulo ``len(samples)``, so
                negative values and values above the number of pairs are
                mapped onto an existing pair.

        Returns:
            A dict with ``"image_id"`` (stem of the image path),
            ``"features"`` (result of ``tensor_from_rgb_image`` on the
            augmented image) and ``"masks"`` (float tensor holding 0/1 values
            with a new leading dimension).

        Raises:
            IndexError: If ``idx >= len(self)`` or there are no samples.
        """
        # Without an upper bound the modulo below makes every index valid, so
        # plain iteration over the dataset (which stops on IndexError) would
        # never terminate.
        if idx >= self.length:
            raise IndexError(
                f"index {idx} is out of range for dataset of length {self.length}"
            )

        num_samples = len(self.samples)
        if num_samples == 0:
            # Would otherwise surface as an opaque ZeroDivisionError.
            raise IndexError("cannot index a SegmentationDataset with no samples")

        image_path, mask_path = self.samples[idx % num_samples]

        image = load_rgb(image_path, lib="cv2")
        mask = load_grayscale(mask_path)

        # Apply augmentations jointly so image and mask stay aligned.
        augmented = self.transform(image=image, mask=mask)
        image, mask = augmented["image"], augmented["mask"]

        # Binarize: any non-zero mask value becomes 1.
        # NOTE(review): .astype requires the transform to return a numpy
        # mask, not a tensor - confirm against the configured pipeline.
        mask = torch.from_numpy((mask > 0).astype(np.uint8))

        return {
            "image_id": image_path.stem,
            "features": tensor_from_rgb_image(image),
            "masks": torch.unsqueeze(mask, 0).float(),
        }