import os

import cv2
import numpy as np
from torch.utils.data import Dataset as BaseDataset


class SegmentationDataset(BaseDataset):
    """Dataset class for semantic segmentation task."""

    def __init__(self, data_dir, classes=['background', 'object'],
                 augmentation=None, preprocessing=None):
        self.image_dir = os.path.join(data_dir, 'Images')
        self.mask_dir = os.path.join(data_dir, 'Masks')

        for dir_path in [self.image_dir, self.mask_dir]:
            if not os.path.exists(dir_path):
                raise FileNotFoundError(f"Directory not found: {dir_path}")

        self.filenames = self._get_filenames()
        self.image_paths = [os.path.join(self.image_dir, fname) for fname in self.filenames]
        self.mask_paths = self._get_mask_paths()

        # Background is excluded: only foreground classes get a mask channel.
        self.target_classes = [cls for cls in classes if cls.lower() != 'background']
        self.class_values = [i for i, cls in enumerate(classes) if cls.lower() != 'background']

        self.augmentation = augmentation
        self.preprocessing = preprocessing

    def __getitem__(self, index):
        image = self._load_image(self.image_paths[index])
        mask = self._load_mask(self.mask_paths[index])

        if self.augmentation:
            processed = self.augmentation(image=image, mask=mask)
            image, mask = processed['image'], processed['mask']

        if self.preprocessing:
            processed = self.preprocessing(image=image, mask=mask)
            image, mask = processed['image'], processed['mask']

        return image, mask

    def __len__(self):
        return len(self.filenames)

    def _get_filenames(self):
        """Returns sorted list of filenames, excluding directories."""
        files = sorted(os.listdir(self.image_dir))
        return [f for f in files if not os.path.isdir(os.path.join(self.image_dir, f))]

    def _get_mask_paths(self):
        """Generates corresponding mask paths for each image."""
        mask_paths = []
        for image_file in self.filenames:
            name, _ = os.path.splitext(image_file)
            mask_paths.append(os.path.join(self.mask_dir, f"{name}.png"))
        return mask_paths

    def _load_image(self, path):
        """Loads and converts image to RGB."""
        image = cv2.imread(path)
        if image is None:
            raise FileNotFoundError(f"Could not read image: {path}")
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    def _load_mask(self, path):
        """Loads and processes segmentation mask."""
        mask = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"Could not read mask: {path}")
        # One binary channel per target class value, stacked along the last axis.
        masks = [(mask == value) for value in self.class_values]
        return np.stack(masks, axis=-1).astype('float')


class InferenceDataset(BaseDataset):
    """Dataset class for inference without ground truth masks."""

    def __init__(self, data_dir, classes=['background', 'object'],
                 augmentation=None, preprocessing=None):
        self.filenames = sorted(
            f for f in os.listdir(data_dir)
            if not os.path.isdir(os.path.join(data_dir, f))
        )
        self.image_paths = [os.path.join(data_dir, fname) for fname in self.filenames]

        self.target_classes = [cls for cls in classes if cls.lower() != 'background']
        self.class_values = [i for i, cls in enumerate(classes) if cls.lower() != 'background']

        self.augmentation = augmentation
        self.preprocessing = preprocessing

    def __getitem__(self, index):
        image = cv2.imread(self.image_paths[index])
        if image is None:
            raise FileNotFoundError(f"Could not read image: {self.image_paths[index]}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        original_height, original_width = image.shape[:2]

        if self.augmentation:
            image = self.augmentation(image=image)['image']
        if self.preprocessing:
            image = self.preprocessing(image=image)['image']

        # Original size is returned so predictions can be resized back later.
        return image, original_height, original_width

    def __len__(self):
        return len(self.filenames)


class StreamingDataset(InferenceDataset):
    """Dataset class optimized for video frame processing."""

    def __init__(self, data_dir, classes=['background', 'object'],
                 augmentation=None, preprocessing=None):
        self.filenames = self._get_frame_filenames(data_dir)
        self.image_paths = [os.path.join(data_dir, fname) for fname in self.filenames]

        self.target_classes = [cls for cls in classes if cls.lower() != 'background']
        self.class_values = [i for i, cls in enumerate(classes) if cls.lower() != 'background']

        self.augmentation = augmentation
        self.preprocessing = preprocessing

    def _get_frame_filenames(self, directory):
        """Returns sorted list of frame filenames."""
        files = sorted(os.listdir(directory))
        return [
            f for f in files
            if ('frame' in f or 'Image' in f)
            and f.lower().endswith('.jpg')
            and not os.path.isdir(os.path.join(directory, f))
        ]

    # __getitem__ and __len__ are inherited unchanged from InferenceDataset.
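

# A minimal usage sketch, not part of the module itself: it assumes a hypothetical
# 'data/train' directory containing 'Images/' and 'Masks/' subfolders, and wraps the
# dataset in a standard torch DataLoader. The class list and batch size are placeholders.
if __name__ == '__main__':
    from torch.utils.data import DataLoader

    train_dataset = SegmentationDataset(
        'data/train',                       # hypothetical path with Images/ and Masks/
        classes=['background', 'object'],
        augmentation=None,                  # e.g. an albumentations Compose(...) callable
        preprocessing=None,
    )
    train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)

    # Inspect a single sample: image as an RGB array, mask with one channel per class.
    image, mask = train_dataset[0]
    print('image shape:', image.shape, 'mask shape:', mask.shape)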