# Provenance: uploaded by obichimav ("Upload 42 files"), commit 8e5d8c7 (verified).
import os
import re

import cv2
import numpy as np
from torch.utils.data import Dataset as BaseDataset
class SegmentationDataset(BaseDataset):
    """Dataset class for semantic segmentation task.

    ``data_dir`` must contain an ``Images`` and a ``Masks`` subdirectory.
    Every non-directory entry of ``Images`` is treated as an image; its mask
    is expected at ``Masks/<image stem>.png``.

    Args:
        data_dir: Root directory holding ``Images`` and ``Masks``.
        classes: Class names in label order. The position of a name in this
            sequence is the pixel value that encodes it in the mask files.
            Names equal to 'background' (case-insensitive) get no output
            channel.
        augmentation: Optional callable invoked as
            ``augmentation(image=..., mask=...)`` that returns a dict with
            ``'image'`` and ``'mask'`` keys.
        preprocessing: Optional callable with the same contract, applied
            after ``augmentation``.

    Raises:
        FileNotFoundError: if ``Images`` or ``Masks`` does not exist.
    """

    def __init__(self, data_dir, classes=('background', 'object'),
                 augmentation=None, preprocessing=None):
        self.image_dir = os.path.join(data_dir, 'Images')
        self.mask_dir = os.path.join(data_dir, 'Masks')
        for dir_path in [self.image_dir, self.mask_dir]:
            if not os.path.exists(dir_path):
                raise FileNotFoundError(f"Directory not found: {dir_path}")
        self.filenames = self._get_filenames()
        self.image_paths = [os.path.join(self.image_dir, fname) for fname in self.filenames]
        self.mask_paths = self._get_mask_paths()
        self.target_classes = [cls for cls in classes if cls.lower() != 'background']
        # Label value of each non-background class = its index in `classes`.
        self.class_values = [i for i, cls in enumerate(classes) if cls.lower() != 'background']
        self.augmentation = augmentation
        self.preprocessing = preprocessing

    def __getitem__(self, index):
        """Return ``(image, mask)`` for ``index``.

        Before any transform, ``image`` is the RGB image, and ``mask`` is a
        float array of shape ``(H, W, len(self.class_values))``. Afterwards,
        both are whatever ``augmentation`` / ``preprocessing`` return.
        """
        image = self._load_image(self.image_paths[index])
        mask = self._load_mask(self.mask_paths[index])
        if self.augmentation:
            processed = self.augmentation(image=image, mask=mask)
            image, mask = processed['image'], processed['mask']
        if self.preprocessing:
            processed = self.preprocessing(image=image, mask=mask)
            image, mask = processed['image'], processed['mask']
        return image, mask

    def __len__(self):
        return len(self.filenames)

    def _get_filenames(self):
        """Return the sorted names of all non-directory entries in ``Images``."""
        return sorted(
            name for name in os.listdir(self.image_dir)
            if not os.path.isdir(os.path.join(self.image_dir, name))
        )

    def _get_mask_paths(self):
        """Map each image filename to its mask path ``Masks/<stem>.png``."""
        return [
            os.path.join(self.mask_dir, os.path.splitext(name)[0] + '.png')
            for name in self.filenames
        ]

    def _load_image(self, path):
        """Read the image at ``path`` and return it in RGB channel order.

        Raises:
            OSError: if OpenCV cannot read the file. ``cv2.imread`` signals
                this (missing, unreadable or unsupported file) by returning
                ``None`` rather than raising.
        """
        image = cv2.imread(path)
        if image is None:
            raise OSError(f"Could not read image: {path}")
        # OpenCV decodes colour images in BGR order.
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    def _load_mask(self, path):
        """Read the label mask at ``path`` and one-hot encode it.

        Raises:
            OSError: if OpenCV cannot read the file. Without this check a
                missing mask would silently become a meaningless 1-D array.
        """
        mask = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise OSError(f"Could not read mask: {path}")
        return self._mask_to_channels(mask)

    def _mask_to_channels(self, mask):
        """One-hot encode a 2-D label mask into float channels.

        Channel ``k`` is 1.0 where ``mask == self.class_values[k]`` and 0.0
        elsewhere. The result has shape ``(H, W, len(self.class_values))``
        and dtype float64.
        """
        values = np.asarray(self.class_values)
        # Broadcasting (H, W, 1) against (K,) builds all K channels at once.
        # Unlike np.stack on an empty list, this also works when K == 0.
        return (mask[..., np.newaxis] == values).astype('float')
class InferenceDataset(BaseDataset):
    """Dataset class for inference without ground truth masks.

    Every non-directory entry of ``data_dir`` is treated as an image,
    in sorted filename order.

    Args:
        data_dir: Directory containing the images.
        classes: Class names. Names equal to 'background' (case-insensitive)
            are excluded from ``target_classes`` / ``class_values``; neither
            attribute is used when loading items.
        augmentation: Optional callable invoked as ``augmentation(image=...)``
            that returns a dict with an ``'image'`` key.
        preprocessing: Optional callable with the same contract, applied
            after ``augmentation``.
    """

    def __init__(self, data_dir, classes=('background', 'object'),
                 augmentation=None, preprocessing=None):
        self.filenames = sorted(
            f for f in os.listdir(data_dir)
            if not os.path.isdir(os.path.join(data_dir, f))
        )
        self.image_paths = [os.path.join(data_dir, fname) for fname in self.filenames]
        self.target_classes = [cls for cls in classes if cls.lower() != 'background']
        self.class_values = [i for i, cls in enumerate(classes) if cls.lower() != 'background']
        self.augmentation = augmentation
        self.preprocessing = preprocessing

    def __getitem__(self, index):
        """Return ``(image, original_height, original_width)`` for ``index``.

        The height and width are measured on the RGB image before
        ``augmentation`` / ``preprocessing`` run. Presumably this lets callers
        resize predictions back to the source resolution; confirm against the
        callers.

        Raises:
            OSError: if OpenCV cannot read the image file (``cv2.imread``
                returns ``None`` instead of raising).
        """
        # Kept self-contained (no helper methods on ``self``) because
        # StreamingDataset calls ``InferenceDataset.__getitem__(self, index)``
        # with its own, unrelated instance.
        path = self.image_paths[index]
        image = cv2.imread(path)
        if image is None:
            raise OSError(f"Could not read image: {path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        original_height, original_width = image.shape[:2]
        if self.augmentation:
            image = self.augmentation(image=image)['image']
        if self.preprocessing:
            image = self.preprocessing(image=image)['image']
        return image, original_height, original_width

    def __len__(self):
        return len(self.filenames)
class StreamingDataset(BaseDataset):
    """Dataset class optimized for video frame processing.

    Only directory entries that satisfy all of the following are used:
    - the name contains ``'frame'`` or ``'Image'`` (case-sensitive);
    - the name ends with ``'jpg'`` (case-insensitive);
    - the entry is not a directory.

    Frames are sorted in natural order, so ``frame2.jpg`` comes before
    ``frame10.jpg``. Plain string sorting would interleave unpadded frame
    numbers. Items have the same form as ``InferenceDataset`` items:
    ``(image, original_height, original_width)``.

    Args:
        data_dir: Directory containing the frame images.
        classes: Class names. Names equal to 'background' (case-insensitive)
            are excluded from ``target_classes`` / ``class_values``.
        augmentation: Optional callable invoked as ``augmentation(image=...)``
            that returns a dict with an ``'image'`` key.
        preprocessing: Optional callable with the same contract.
    """

    def __init__(self, data_dir, classes=('background', 'object'),
                 augmentation=None, preprocessing=None):
        self.filenames = self._get_frame_filenames(data_dir)
        self.image_paths = [os.path.join(data_dir, fname) for fname in self.filenames]
        self.target_classes = [cls for cls in classes if cls.lower() != 'background']
        self.class_values = [i for i, cls in enumerate(classes) if cls.lower() != 'background']
        self.augmentation = augmentation
        self.preprocessing = preprocessing

    @staticmethod
    def _natural_sort_key(name):
        """Sort key comparing digit runs numerically ('frame2' < 'frame10')."""
        # re.split with a capturing group alternates text, digits, text, ...
        # and always starts with a text chunk (possibly empty). Index parity
        # therefore fixes each chunk's type, so two keys never compare a str
        # with an int at the same position.
        chunks = re.split(r'(\d+)', name)
        key = [int(chunk) if i % 2 else chunk for i, chunk in enumerate(chunks)]
        # The raw name breaks ties such as 'frame01' vs 'frame1', which keeps
        # the ordering deterministic.
        return key, name

    def _get_frame_filenames(self, directory):
        """Return the matching frame filenames in natural order."""
        frames = [
            f for f in os.listdir(directory)
            if ('frame' in f or 'Image' in f)
            and f.lower().endswith('jpg')
            and not os.path.isdir(os.path.join(directory, f))
        ]
        return sorted(frames, key=self._natural_sort_key)

    def __getitem__(self, index):
        # Reuse InferenceDataset's loading logic on this instance. That method
        # only reads image_paths, augmentation and preprocessing, which
        # __init__ above sets.
        return InferenceDataset.__getitem__(self, index)

    def __len__(self):
        return len(self.filenames)