File size: 5,189 Bytes
8e5d8c7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import os
import cv2
import numpy as np
from torch.utils.data import Dataset as BaseDataset

class SegmentationDataset(BaseDataset):
    """Dataset for semantic segmentation: pairs RGB images with one-hot masks.

    Expects ``data_dir`` to contain two sub-folders, ``Images`` and ``Masks``.
    Every image ``<name>.<ext>`` must have a matching mask ``<name>.png``
    whose pixel values are integer class indices.

    Args:
        data_dir: Root directory containing ``Images`` and ``Masks``.
        classes: Ordered class names; any entry equal to 'background'
            (case-insensitive) is dropped, and the indices of the remaining
            entries select which label values become mask channels.
        augmentation: Optional albumentations-style callable accepting
            ``image=`` / ``mask=`` keywords and returning a dict with the
            same keys.
        preprocessing: Optional callable with the same contract, applied
            after augmentation.

    Raises:
        FileNotFoundError: If either expected sub-directory is missing.
    """

    def __init__(self, data_dir, classes=('background', 'object'),
                 augmentation=None, preprocessing=None):
        # NOTE: default is a tuple, not a list — a shared mutable default
        # would be a latent bug if any caller ever mutated it.
        self.image_dir = os.path.join(data_dir, 'Images')
        self.mask_dir = os.path.join(data_dir, 'Masks')

        for dir_path in [self.image_dir, self.mask_dir]:
            if not os.path.exists(dir_path):
                raise FileNotFoundError(f"Directory not found: {dir_path}")

        self.filenames = self._get_filenames()
        self.image_paths = [os.path.join(self.image_dir, fname) for fname in self.filenames]
        self.mask_paths = self._get_mask_paths()

        # 'background' carries no mask channel; keep the integer label of
        # every remaining class so _load_mask can binarize per class.
        self.target_classes = [cls for cls in classes if cls.lower() != 'background']
        self.class_values = [i for i, cls in enumerate(classes) if cls.lower() != 'background']

        self.augmentation = augmentation
        self.preprocessing = preprocessing

    def __getitem__(self, index):
        """Return the (image, mask) pair at *index*, transformed if configured."""
        image = self._load_image(self.image_paths[index])
        mask = self._load_mask(self.mask_paths[index])

        if self.augmentation:
            processed = self.augmentation(image=image, mask=mask)
            image, mask = processed['image'], processed['mask']

        if self.preprocessing:
            processed = self.preprocessing(image=image, mask=mask)
            image, mask = processed['image'], processed['mask']

        return image, mask

    def __len__(self):
        """Return the number of image/mask pairs."""
        return len(self.filenames)

    def _get_filenames(self):
        """Return a sorted list of image filenames, excluding directories."""
        files = sorted(os.listdir(self.image_dir))
        return [f for f in files if not os.path.isdir(os.path.join(self.image_dir, f))]

    def _get_mask_paths(self):
        """Return mask paths matching each image's stem with a .png suffix."""
        mask_paths = []
        for image_file in self.filenames:
            name, _ = os.path.splitext(image_file)
            mask_paths.append(os.path.join(self.mask_dir, f"{name}.png"))
        return mask_paths

    def _load_image(self, path):
        """Load an image from disk and convert BGR -> RGB.

        Raises:
            FileNotFoundError: If OpenCV cannot read the file
                (``cv2.imread`` returns None instead of raising, which
                would otherwise surface as an opaque cvtColor error).
        """
        image = cv2.imread(path)
        if image is None:
            raise FileNotFoundError(f"Could not read image: {path}")
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    def _load_mask(self, path):
        """Load a label mask and expand it to one float channel per class."""
        mask = cv2.imread(path, 0)  # grayscale: pixel value == class index
        if mask is None:
            raise FileNotFoundError(f"Could not read mask: {path}")
        masks = [(mask == value) for value in self.class_values]
        return np.stack(masks, axis=-1).astype('float')

class InferenceDataset(BaseDataset):
    """Dataset for inference: yields images (no ground-truth masks).

    Each item is ``(image, original_height, original_width)`` so predictions
    can be resized back to the source resolution after transformation.

    Args:
        data_dir: Directory containing the images (flat; sub-dirs skipped).
        classes: Ordered class names; 'background' (case-insensitive) is
            dropped when computing ``target_classes`` / ``class_values``.
        augmentation: Optional callable taking ``image=`` and returning a
            dict with an ``'image'`` key.
        preprocessing: Optional callable with the same contract.
    """

    def __init__(self, data_dir, classes=('background', 'object'),
                 augmentation=None, preprocessing=None):
        # Tuple default avoids the shared mutable-default pitfall.
        self.filenames = sorted(
            f for f in os.listdir(data_dir)
            if not os.path.isdir(os.path.join(data_dir, f))
        )
        self.image_paths = [os.path.join(data_dir, fname) for fname in self.filenames]

        self.target_classes = [cls for cls in classes if cls.lower() != 'background']
        self.class_values = [i for i, cls in enumerate(classes) if cls.lower() != 'background']

        self.augmentation = augmentation
        self.preprocessing = preprocessing

    def __getitem__(self, index):
        """Return ``(image, original_height, original_width)`` for *index*.

        Raises:
            FileNotFoundError: If OpenCV cannot read the file
                (``cv2.imread`` returns None instead of raising).
        """
        image = cv2.imread(self.image_paths[index])
        if image is None:
            raise FileNotFoundError(f"Could not read image: {self.image_paths[index]}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # Record the pre-transform size so callers can undo any resizing.
        original_height, original_width = image.shape[:2]

        if self.augmentation:
            image = self.augmentation(image=image)['image']

        if self.preprocessing:
            image = self.preprocessing(image=image)['image']

        return image, original_height, original_width

    def __len__(self):
        """Return the number of images found in the directory."""
        return len(self.filenames)

class StreamingDataset(BaseDataset):
    """Dataset optimized for processing extracted video frames.

    Keeps only files whose names contain ``'frame'`` or ``'Image'`` and end
    in ``jpg``; item loading is delegated to ``InferenceDataset`` so each
    item is ``(image, original_height, original_width)``.

    Args:
        data_dir: Directory containing the frame images.
        classes: Ordered class names; 'background' (case-insensitive) is
            dropped when computing ``target_classes`` / ``class_values``.
        augmentation: Optional callable taking ``image=`` and returning a
            dict with an ``'image'`` key.
        preprocessing: Optional callable with the same contract.
    """

    def __init__(self, data_dir, classes=('background', 'object'),
                 augmentation=None, preprocessing=None):
        # Tuple default avoids the shared mutable-default pitfall.
        self.filenames = self._get_frame_filenames(data_dir)
        self.image_paths = [os.path.join(data_dir, fname) for fname in self.filenames]

        self.target_classes = [cls for cls in classes if cls.lower() != 'background']
        self.class_values = [i for i, cls in enumerate(classes) if cls.lower() != 'background']

        self.augmentation = augmentation
        self.preprocessing = preprocessing

    def _get_frame_filenames(self, directory):
        """Return sorted frame filenames ('frame'/'Image' names ending in jpg).

        NOTE(review): ``endswith('jpg')`` excludes ``.jpeg`` files — confirm
        with the frame-extraction pipeline whether that is intentional.
        """
        files = sorted(os.listdir(directory))
        return [f for f in files
                if ('frame' in f or 'Image' in f)
                and f.lower().endswith('jpg')
                and not os.path.isdir(os.path.join(directory, f))]

    def __getitem__(self, index):
        # Deliberate unbound delegation: this class mirrors the attribute
        # layout InferenceDataset.__getitem__ reads (image_paths,
        # augmentation, preprocessing), so its loading logic applies as-is.
        return InferenceDataset.__getitem__(self, index)

    def __len__(self):
        """Return the number of frame files found."""
        return len(self.filenames)