laureimeisan committed
Commit 01c27ca · verified · 1 Parent(s): 51d003b

Upload dataset.py

Files changed (1):
  dataset.py (+465 -0)

dataset.py ADDED
import os
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np
from typing import Dict, List, Tuple


class MultiModalFERDataset(Dataset):
    """
    Multi-modal Facial Expression Recognition dataset for RGB and thermal images.
    Supports 7 emotion classes: angry, disgust, fear, happy, neutral, sad, surprised.
    Supports 3 modes: 'rgb', 'thermal', 'combined'.
    """

    def __init__(self,
                 data_dir: str,
                 mode: str = 'rgb',
                 split_ratio: float = 0.8,
                 split_type: str = 'train',
                 transform_rgb=None,
                 transform_thermal=None,
                 use_augmented: bool = False):
        """
        Args:
            data_dir: Path to the Data directory containing RGB/Thermal/RgbAug/ThermalAug folders
            mode: 'rgb', 'thermal', or 'combined' for fusion approaches
            split_ratio: Ratio for the train/test split (0.8 means 80% train, 20% test)
            split_type: 'train' or 'test'
            transform_rgb: Transform applied to RGB images
            transform_thermal: Transform applied to thermal images
            use_augmented: Whether to include augmented data
        """
        self.data_dir = data_dir
        self.mode = mode.lower()
        self.split_ratio = split_ratio
        self.split_type = split_type
        self.transform_rgb = transform_rgb
        self.transform_thermal = transform_thermal
        self.use_augmented = use_augmented

        # Validate mode
        if self.mode not in ['rgb', 'thermal', 'combined']:
            raise ValueError("Mode must be 'rgb', 'thermal', or 'combined'")

        # Define emotion classes (the order fixes the label indices)
        self.emotion_classes = ['angry', 'disgust', 'fear', 'happy', 'neutral', 'sad', 'surprised']
        self.class_to_idx = {emotion: idx for idx, emotion in enumerate(self.emotion_classes)}
        self.idx_to_class = {idx: emotion for emotion, idx in self.class_to_idx.items()}

        # Load image paths and labels
        self.rgb_paths = []
        self.thermal_paths = []
        self.labels = []
        self._load_data()

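    # Expected on-disk layout (directory names from the docstring above; the
    # example filenames are hypothetical illustrations of the naming scheme):
    #
    #   Data/
    #     RGB/         R_Happy_012_A.jpg, R_Sad_031_B.png, ...
    #     Thermal/     T_Happy_012_A.jpg, T_Sad_031_B.png, ...
    #     RgbAug/      (searched only when use_augmented=True)
    #     ThermalAug/  (searched only when use_augmented=True)
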
    def _load_data(self):
        """Load all image paths and corresponding labels from the filename-based structure."""
        # Define directories to search
        rgb_dirs = [os.path.join(self.data_dir, 'RGB')]
        thermal_dirs = [os.path.join(self.data_dir, 'Thermal')]

        if self.use_augmented:
            rgb_dirs.append(os.path.join(self.data_dir, 'RgbAug'))
            thermal_dirs.append(os.path.join(self.data_dir, 'ThermalAug'))

        # Collect all RGB and thermal files; RGB filenames start with 'R_',
        # thermal filenames with 'T_'
        rgb_files = []
        thermal_files = []

        for rgb_dir in rgb_dirs:
            if os.path.exists(rgb_dir):
                for filename in os.listdir(rgb_dir):
                    if filename.startswith('R_') and filename.lower().endswith(('.jpg', '.jpeg', '.bmp', '.png')):
                        rgb_files.append((os.path.join(rgb_dir, filename), filename))

        for thermal_dir in thermal_dirs:
            if os.path.exists(thermal_dir):
                for filename in os.listdir(thermal_dir):
                    if filename.startswith('T_') and filename.lower().endswith(('.jpg', '.jpeg', '.bmp', '.png')):
                        thermal_files.append((os.path.join(thermal_dir, filename), filename))

        # Build the sample list according to the requested mode
        if self.mode == 'combined':
            self._create_paired_data(rgb_files, thermal_files)
        elif self.mode == 'rgb':
            self._create_single_modal_data(rgb_files, 'rgb')
        elif self.mode == 'thermal':
            self._create_single_modal_data(thermal_files, 'thermal')

        # Apply the train/test split
        self._apply_split()

    def _parse_filename(self, filename: str) -> Tuple[str, str]:
        """Parse a filename to extract the emotion class and a unique ID.

        Expected format: [R|T]_Classname_ID_Source.ext
        Returns: (emotion_class, unique_id)
        """
        # Remove the extension and split on underscores
        basename = os.path.splitext(filename)[0]
        parts = basename.split('_')

        if len(parts) >= 4:
            # parts[0] is the modality prefix ('R' or 'T'); it is already
            # implied by the source folder, so it is not used here
            emotion = parts[1].lower()
            unique_id = parts[2]
            source = parts[3]

            # Normalize the alternate spelling 'surprise' to the class name
            # 'surprised' (naming inconsistency in some files)
            if emotion == 'surprise':
                emotion = 'surprised'

            return emotion, f"{unique_id}_{source}"
        else:
            raise ValueError(f"Invalid filename format: {filename}")

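    # Illustrative parses (hypothetical filenames, documentation only):
    #   _parse_filename("R_Happy_012_A.jpg")    -> ("happy", "012_A")
    #   _parse_filename("T_Surprise_012_A.jpg") -> ("surprised", "012_A")
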
    def _create_paired_data(self, rgb_files: List, thermal_files: List):
        """Create paired RGB-thermal samples for the combined mode."""
        # Map each unique_id to a single file path (keep only the first file
        # per unique_id to avoid duplicated pairs)
        rgb_map = {}
        thermal_map = {}

        for rgb_path, rgb_filename in rgb_files:
            try:
                emotion, unique_id = self._parse_filename(rgb_filename)
            except ValueError:
                continue
            if emotion in self.class_to_idx and unique_id not in rgb_map:
                rgb_map[unique_id] = (rgb_path, emotion)

        for thermal_path, thermal_filename in thermal_files:
            try:
                emotion, unique_id = self._parse_filename(thermal_filename)
            except ValueError:
                continue
            if emotion in self.class_to_idx and unique_id not in thermal_map:
                thermal_map[unique_id] = (thermal_path, emotion)

        # Keep only unique_ids that have both an RGB and a thermal image;
        # sort for a deterministic ordering so the fixed-seed split below
        # is reproducible across runs
        common_ids = set(rgb_map.keys()) & set(thermal_map.keys())

        for unique_id in sorted(common_ids):
            rgb_path, rgb_emotion = rgb_map[unique_id]
            thermal_path, thermal_emotion = thermal_map[unique_id]

            # Ensure the emotion labels of the two modalities match
            if rgb_emotion == thermal_emotion:
                self.rgb_paths.append(rgb_path)
                self.thermal_paths.append(thermal_path)
                self.labels.append(self.class_to_idx[rgb_emotion])

    def _create_single_modal_data(self, files: List, modality: str):
        """Create single-modality samples for RGB-only or thermal-only mode."""
        for file_path, filename in files:
            try:
                emotion, unique_id = self._parse_filename(filename)
            except ValueError:
                continue
            if emotion in self.class_to_idx:
                if modality == 'rgb':
                    self.rgb_paths.append(file_path)
                    self.thermal_paths.append(None)
                else:  # thermal
                    self.rgb_paths.append(None)
                    self.thermal_paths.append(file_path)
                self.labels.append(self.class_to_idx[emotion])

    def _apply_split(self):
        """Apply the train/test split based on split_ratio."""
        total_samples = len(self.labels)
        train_size = int(total_samples * self.split_ratio)

        # Shuffle indices with a fixed seed so the 'train' and 'test'
        # instantiations see the same permutation and stay disjoint
        indices = np.random.RandomState(42).permutation(total_samples)

        if self.split_type == 'train':
            selected_indices = indices[:train_size]
        else:  # test
            selected_indices = indices[train_size:]

        # Keep only the selected samples
        self.rgb_paths = [self.rgb_paths[i] for i in selected_indices]
        self.thermal_paths = [self.thermal_paths[i] for i in selected_indices]
        self.labels = [self.labels[i] for i in selected_indices]

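    # Split sketch with hypothetical sizes: for 100 samples and
    # split_ratio=0.8, RandomState(42) yields the identical permutation in
    # both the 'train' and 'test' instantiations, so indices[:80] and
    # indices[80:] partition the data with no overlap.
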
    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        label = self.labels[idx]

        if self.mode == 'rgb':
            # RGB-only mode
            rgb_image = Image.open(self.rgb_paths[idx]).convert('RGB')
            if self.transform_rgb:
                rgb_image = self.transform_rgb(rgb_image)
            return rgb_image, label

        elif self.mode == 'thermal':
            # Thermal-only mode
            thermal_image = Image.open(self.thermal_paths[idx])
            # Convert thermal to grayscale, then replicate to 3 channels for
            # compatibility with ViT models pretrained on RGB input
            if thermal_image.mode != 'L':
                thermal_image = thermal_image.convert('L')
            thermal_image = thermal_image.convert('RGB')

            if self.transform_thermal:
                thermal_image = self.transform_thermal(thermal_image)
            return thermal_image, label

        elif self.mode == 'combined':
            # Combined mode: return both RGB and thermal
            rgb_image = Image.open(self.rgb_paths[idx]).convert('RGB')
            thermal_image = Image.open(self.thermal_paths[idx])

            # Convert thermal to grayscale, then replicate to 3 channels
            if thermal_image.mode != 'L':
                thermal_image = thermal_image.convert('L')
            thermal_image = thermal_image.convert('RGB')

            if self.transform_rgb:
                rgb_image = self.transform_rgb(rgb_image)
            if self.transform_thermal:
                thermal_image = self.transform_thermal(thermal_image)

            return {'rgb': rgb_image, 'thermal': thermal_image}, label

    def get_class_distribution(self) -> Dict[str, int]:
        """Get the number of samples per class in the dataset."""
        distribution = {}
        for emotion in self.emotion_classes:
            distribution[emotion] = self.labels.count(self.class_to_idx[emotion])
        return distribution

    def get_class_weights(self) -> torch.Tensor:
        """Calculate inverse-frequency class weights for an imbalanced dataset."""
        # minlength guarantees one count per class even if a class is absent;
        # the maximum with 1 avoids division by zero for empty classes
        class_counts = np.bincount(self.labels, minlength=len(self.emotion_classes))
        total_samples = len(self.labels)
        class_weights = total_samples / (len(self.emotion_classes) * np.maximum(class_counts, 1))
        return torch.FloatTensor(class_weights)

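
# Example of the inverse-frequency weighting above (illustrative numbers only):
# with 100 samples over 7 classes, a class holding 10 samples gets weight
# 100 / (7 * 10) ≈ 1.43 while a class holding 30 gets 100 / (7 * 30) ≈ 0.48,
# so under-represented classes contribute more to a weighted loss.
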
def get_rgb_transforms(image_size: int = 224, is_training: bool = True):
    """
    Get RGB data transforms for training and validation.

    Args:
        image_size: Target image size for ViT
        is_training: Whether this is for training (applies augmentation)
    """
    if is_training:
        transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomRotation(degrees=15),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
            transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
    else:
        transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    return transform


def get_thermal_transforms(image_size: int = 224, is_training: bool = True):
    """
    Get thermal data transforms for training and validation.
    Note: Thermal images are treated as grayscale replicated to 3 channels.

    Args:
        image_size: Target image size for ViT
        is_training: Whether this is for training (applies augmentation)
    """
    if is_training:
        transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomRotation(degrees=15),
            # More conservative augmentation for thermal images (no color
            # jitter, smaller translation and scale ranges)
            transforms.RandomAffine(degrees=0, translate=(0.05, 0.05), scale=(0.95, 1.05)),
            transforms.ToTensor(),
            # Use ImageNet normalization for consistency with pretrained ViT
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
    else:
        transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    return transform


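# Quick shape check for the transforms above (illustrative sketch):
#   t = get_thermal_transforms(image_size=224, is_training=False)
#   x = t(Image.new('RGB', (320, 240)))   # -> torch.Size([3, 224, 224])

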
def create_multimodal_data_loaders(
    data_dir: str,
    mode: str = 'rgb',
    batch_size: int = 32,
    image_size: int = 224,
    num_workers: int = 4,
    val_split: float = 0.2,
    use_augmented: bool = False
) -> Tuple[DataLoader, DataLoader]:
    """
    Create train and test data loaders for multimodal FER.

    Args:
        data_dir: Path to the Data directory containing RGB/Thermal folders
        mode: 'rgb', 'thermal', or 'combined'
        batch_size: Batch size for training
        image_size: Target image size for ViT
        num_workers: Number of workers for data loading
        val_split: Fraction of the data held out for the test loader
        use_augmented: Whether to include augmented data

    Returns:
        train_loader, test_loader
    """
    # Create transforms (augmentation only for training)
    rgb_train_transform = get_rgb_transforms(image_size, is_training=True)
    rgb_test_transform = get_rgb_transforms(image_size, is_training=False)
    thermal_train_transform = get_thermal_transforms(image_size, is_training=True)
    thermal_test_transform = get_thermal_transforms(image_size, is_training=False)

    # Fraction of samples assigned to the train split
    train_split_ratio = 1.0 - val_split

    # Create datasets; both use the same split ratio and the same fixed seed,
    # so the train and test samples are disjoint
    train_dataset = MultiModalFERDataset(
        data_dir=data_dir,
        mode=mode,
        split_ratio=train_split_ratio,
        split_type='train',
        transform_rgb=rgb_train_transform,
        transform_thermal=thermal_train_transform,
        use_augmented=use_augmented
    )

    test_dataset = MultiModalFERDataset(
        data_dir=data_dir,
        mode=mode,
        split_ratio=train_split_ratio,
        split_type='test',
        transform_rgb=rgb_test_transform,
        transform_thermal=thermal_test_transform,
        use_augmented=use_augmented
    )

    # Create data loaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=True
    )

    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True
    )

    return train_loader, test_loader


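# Consumption sketch for the combined mode (fusion_model is hypothetical and
# not defined in this file):
#   train_loader, test_loader = create_multimodal_data_loaders(
#       "Data", mode='combined', batch_size=16)
#   for data, labels in train_loader:
#       rgb, thermal = data['rgb'], data['thermal']  # each (B, 3, 224, 224)
#       # logits = fusion_model(rgb, thermal)

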
def analyze_multimodal_dataset(data_dir: str, use_augmented: bool = False):
    """Analyze the multimodal dataset and print statistics."""
    print("=== Multimodal Dataset Analysis ===")

    # Analyze every mode/split combination
    for mode in ['rgb', 'thermal', 'combined']:
        print(f"\n=== {mode.upper()} Mode ===")

        for split in ['train', 'test']:
            print(f"\n{split.upper()} Split:")

            try:
                dataset = MultiModalFERDataset(
                    data_dir=data_dir,
                    mode=mode,
                    split_ratio=0.8,
                    split_type=split,
                    use_augmented=use_augmented
                )

                print(f"Total samples: {len(dataset)}")

                if len(dataset) > 0:
                    # Class distribution
                    distribution = dataset.get_class_distribution()
                    print("Class distribution:")
                    for emotion, count in distribution.items():
                        percentage = (count / len(dataset)) * 100
                        print(f"  {emotion}: {count} ({percentage:.1f}%)")

                    # Class weights
                    weights = dataset.get_class_weights()
                    print("Class weights:")
                    for emotion, weight in zip(dataset.emotion_classes, weights):
                        print(f"  {emotion}: {weight:.3f}")

            except Exception as e:
                print(f"Error loading {mode} {split} dataset: {e}")


if __name__ == "__main__":
    # Example usage
    data_dir = "../vit/vit/data/vit/Data"

    # Analyze the dataset
    analyze_multimodal_dataset(data_dir, use_augmented=True)

    # Test each mode end to end
    for mode in ['rgb', 'thermal', 'combined']:
        print(f"\n=== Testing {mode.upper()} Mode ===")

        try:
            # Create data loaders
            train_loader, test_loader = create_multimodal_data_loaders(
                data_dir, mode=mode, batch_size=8, image_size=224, use_augmented=True
            )

            print("Data loaders created:")
            print(f"Train batches: {len(train_loader)}")
            print(f"Test batches: {len(test_loader)}")

            # Load one batch to sanity-check shapes
            if len(train_loader) > 0:
                batch = next(iter(train_loader))
                if mode == 'combined':
                    data, labels = batch
                    rgb_images = data['rgb']
                    thermal_images = data['thermal']
                    print(f"RGB batch shape: {rgb_images.shape}")
                    print(f"Thermal batch shape: {thermal_images.shape}")
                    print(f"Labels shape: {labels.shape}")
                else:
                    images, labels = batch
                    print(f"Batch shape: {images.shape}")
                    print(f"Labels shape: {labels.shape}")
            else:
                print("No data available for this mode")

        except Exception as e:
            print(f"Error testing {mode} mode: {e}")