|
import os |
|
import cv2 |
|
import numpy as np |
|
from torch.utils.data import Dataset as BaseDataset |
|
|
|
class SegmentationDataset(BaseDataset):
    """Dataset of image / mask pairs for a semantic segmentation task.

    ``data_dir`` must contain an ``Images`` and a ``Masks`` sub-directory.
    Every non-directory entry of ``Images`` is treated as one sample; its
    mask is looked up in ``Masks`` under the same base name with a ``.png``
    extension.

    Args:
        data_dir: Root directory holding the ``Images`` and ``Masks`` folders.
        classes: Ordered class names. The index of a name is the pixel value
            that marks that class in the mask. Names equal to
            ``'background'`` (case-insensitive) get no output channel.
        augmentation: Optional callable invoked as
            ``augmentation(image=..., mask=...)`` that returns a mapping with
            ``'image'`` and ``'mask'`` keys.
        preprocessing: Optional callable with the same calling convention,
            applied after ``augmentation``.

    Raises:
        FileNotFoundError: If ``Images`` or ``Masks`` is not an existing
            directory.
    """

    # The default is a tuple so the shared default value can never be mutated.
    def __init__(self, data_dir, classes=('background', 'object'),
                 augmentation=None, preprocessing=None):
        self.image_dir = os.path.join(data_dir, 'Images')
        self.mask_dir = os.path.join(data_dir, 'Masks')

        # isdir (not exists) so a regular file at that path is reported here
        # instead of failing later inside os.listdir.
        for dir_path in (self.image_dir, self.mask_dir):
            if not os.path.isdir(dir_path):
                raise FileNotFoundError(f"Directory not found: {dir_path}")

        self.filenames = self._get_filenames()
        self.image_paths = [
            os.path.join(self.image_dir, fname) for fname in self.filenames
        ]
        self.mask_paths = self._get_mask_paths()

        # Keep names and pixel values aligned: both skip 'background'.
        foreground = [
            (value, name) for value, name in enumerate(classes)
            if name.lower() != 'background'
        ]
        self.target_classes = [name for _, name in foreground]
        self.class_values = [value for value, _ in foreground]

        self.augmentation = augmentation
        self.preprocessing = preprocessing

    def __getitem__(self, index):
        """Return the ``(image, mask)`` pair at ``index``.

        The image is loaded as RGB and the mask as one channel per
        non-background class; ``augmentation`` and then ``preprocessing`` are
        applied to both when set.
        """
        image = self._load_image(self.image_paths[index])
        mask = self._load_mask(self.mask_paths[index])

        if self.augmentation:
            processed = self.augmentation(image=image, mask=mask)
            image, mask = processed['image'], processed['mask']

        if self.preprocessing:
            processed = self.preprocessing(image=image, mask=mask)
            image, mask = processed['image'], processed['mask']

        return image, mask

    def __len__(self):
        """Return the number of samples (files found in ``Images``)."""
        return len(self.filenames)

    def _get_filenames(self):
        """Return the sorted entries of the image directory, excluding directories."""
        return [
            fname for fname in sorted(os.listdir(self.image_dir))
            if not os.path.isdir(os.path.join(self.image_dir, fname))
        ]

    def _get_mask_paths(self):
        """Return the mask path matching each image: same base name, ``.png``."""
        return [
            os.path.join(self.mask_dir, f"{os.path.splitext(fname)[0]}.png")
            for fname in self.filenames
        ]

    def _load_image(self, path):
        """Load the image at ``path`` and convert it from BGR to RGB.

        Raises:
            FileNotFoundError: If OpenCV could not read the file.
        """
        image = cv2.imread(path)
        # cv2.imread signals failure by returning None instead of raising.
        if image is None:
            raise FileNotFoundError(f"Image not found or unreadable: {path}")
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    def _load_mask(self, path):
        """Load the mask at ``path`` as one binary channel per target class.

        The file is read as a single-channel image; channel ``k`` of the
        result is 1.0 where the pixel equals ``self.class_values[k]`` and 0.0
        elsewhere.

        Raises:
            FileNotFoundError: If OpenCV could not read the file.
        """
        mask = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        # Without this check ``None == value`` is just ``False`` and a bogus
        # mask of shape (num_classes,) would be returned silently.
        if mask is None:
            raise FileNotFoundError(f"Mask not found or unreadable: {path}")
        channels = [mask == value for value in self.class_values]
        return np.stack(channels, axis=-1).astype('float')
|
|
|
class InferenceDataset(BaseDataset):
    """Dataset for inference on a flat directory of images, without masks.

    Every non-directory entry of ``data_dir`` is treated as one image.

    Args:
        data_dir: Directory containing the images.
        classes: Ordered class names. Names equal to ``'background'``
            (case-insensitive) are dropped from ``target_classes``, and their
            indices from ``class_values``.
        augmentation: Optional callable invoked as ``augmentation(image=...)``
            that returns a mapping with an ``'image'`` key.
        preprocessing: Optional callable with the same calling convention,
            applied after ``augmentation``.
    """

    # The default is a tuple so the shared default value can never be mutated.
    def __init__(self, data_dir, classes=('background', 'object'),
                 augmentation=None, preprocessing=None):
        self.filenames = sorted(
            fname for fname in os.listdir(data_dir)
            if not os.path.isdir(os.path.join(data_dir, fname))
        )
        self.image_paths = [
            os.path.join(data_dir, fname) for fname in self.filenames
        ]

        # Keep names and indices aligned: both skip 'background'.
        foreground = [
            (value, name) for value, name in enumerate(classes)
            if name.lower() != 'background'
        ]
        self.target_classes = [name for _, name in foreground]
        self.class_values = [value for value, _ in foreground]

        self.augmentation = augmentation
        self.preprocessing = preprocessing

    def __getitem__(self, index):
        """Return ``(image, original_height, original_width)`` for ``index``.

        The height and width are taken right after loading, before
        ``augmentation`` and ``preprocessing`` are applied to the image.

        Raises:
            FileNotFoundError: If OpenCV could not read the file.
        """
        path = self.image_paths[index]
        image = cv2.imread(path)
        # cv2.imread signals failure by returning None instead of raising.
        if image is None:
            raise FileNotFoundError(f"Image not found or unreadable: {path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        original_height, original_width = image.shape[:2]

        if self.augmentation:
            image = self.augmentation(image=image)['image']

        if self.preprocessing:
            image = self.preprocessing(image=image)['image']

        return image, original_height, original_width

    def __len__(self):
        """Return the number of images found in ``data_dir``."""
        return len(self.filenames)
|
|
|
class StreamingDataset(BaseDataset):
    """Dataset over extracted video frames stored as JPEG files.

    Only entries of ``data_dir`` whose name contains ``'frame'`` or
    ``'Image'`` and ends with ``.jpg`` (case-insensitive) are used.

    Args:
        data_dir: Directory containing the frame images.
        classes: Ordered class names. Names equal to ``'background'``
            (case-insensitive) are dropped from ``target_classes``, and their
            indices from ``class_values``.
        augmentation: Optional callable invoked as ``augmentation(image=...)``
            that returns a mapping with an ``'image'`` key.
        preprocessing: Optional callable with the same calling convention,
            applied after ``augmentation``.
    """

    # The default is a tuple so the shared default value can never be mutated.
    def __init__(self, data_dir, classes=('background', 'object'),
                 augmentation=None, preprocessing=None):
        self.filenames = self._get_frame_filenames(data_dir)
        self.image_paths = [
            os.path.join(data_dir, fname) for fname in self.filenames
        ]

        # Keep names and indices aligned: both skip 'background'.
        foreground = [
            (value, name) for value, name in enumerate(classes)
            if name.lower() != 'background'
        ]
        self.target_classes = [name for _, name in foreground]
        self.class_values = [value for value, _ in foreground]

        self.augmentation = augmentation
        self.preprocessing = preprocessing

    def _get_frame_filenames(self, directory):
        """Return the sorted frame filenames found in ``directory``.

        A name qualifies when it contains ``'frame'`` or ``'Image'``
        (case-sensitive), ends with ``.jpg`` (case-insensitive) and is not a
        directory. Sorting is plain lexicographic.
        """
        return [
            fname for fname in sorted(os.listdir(directory))
            if ('frame' in fname or 'Image' in fname)
            # The leading dot matters: a bare 'jpg' suffix would also accept
            # names such as 'frame_listjpg' that have no .jpg extension.
            and fname.lower().endswith('.jpg')
            and not os.path.isdir(os.path.join(directory, fname))
        ]

    def __getitem__(self, index):
        """Return ``(image, original_height, original_width)`` for ``index``.

        Loading is delegated to ``InferenceDataset.__getitem__``, called as a
        plain function with this instance as ``self``.
        """
        return InferenceDataset.__getitem__(self, index)

    def __len__(self):
        """Return the number of frames selected from ``data_dir``."""
        return len(self.filenames)