laureimeisan committed
Commit 3928508 · verified · 1 Parent(s): 7824be1

Upload model.py

Files changed (1): model.py +753 -0
model.py ADDED
@@ -0,0 +1,753 @@
import torch
import torch.nn as nn
from transformers import ViTConfig, ViTModel, ViTForImageClassification
from transformers import AutoImageProcessor
from typing import Optional, Dict, Any, Union
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

"""
References from Hugging Face Transformers ViT documentation:
https://huggingface.co/docs/transformers/en/model_doc/vit
"""

class ViTForFER(nn.Module):
    """
    Vision Transformer for Facial Expression Recognition.
    Fine-tuned on a FER dataset with 7 emotion classes.
    """

    def __init__(
        self,
        model_name: str = "google/vit-base-patch16-224-in21k",
        num_classes: int = 7,
        dropout_rate: float = 0.1,
        freeze_backbone: bool = False,
        use_gradient_checkpointing: bool = False
    ):
        """
        Args:
            model_name: Pre-trained ViT model name from HuggingFace
            num_classes: Number of emotion classes (7 for FER)
            dropout_rate: Dropout rate for the classifier head
            freeze_backbone: Whether to freeze the backbone during fine-tuning
            use_gradient_checkpointing: Whether to use gradient checkpointing
        """
        super().__init__()

        self.model_name = model_name
        self.num_classes = num_classes
        self.dropout_rate = dropout_rate
        self.freeze_backbone = freeze_backbone

        # Load pre-trained ViT configuration
        self.config = ViTConfig.from_pretrained(model_name)
        self.config.num_labels = num_classes
        self.config.id2label = {
            0: "angry", 1: "disgust", 2: "fear", 3: "happy",
            4: "neutral", 5: "sad", 6: "surprised"
        }
        self.config.label2id = {v: k for k, v in self.config.id2label.items()}

        # Initialize ViT model
        self.vit = ViTForImageClassification.from_pretrained(
            model_name,
            config=self.config,
            ignore_mismatched_sizes=True
        )

        # Enable gradient checkpointing if requested
        if use_gradient_checkpointing:
            self.vit.gradient_checkpointing_enable()

        # Freeze backbone if requested
        if freeze_backbone:
            self._freeze_backbone()

        # Replace classifier head with custom one
        self._replace_classifier_head()

        # Initialize image processor
        self.image_processor = AutoImageProcessor.from_pretrained(model_name)

    def _freeze_backbone(self):
        """Freeze the ViT backbone parameters"""
        for param in self.vit.vit.parameters():
            param.requires_grad = False
        logger.info("ViT backbone frozen")

    def _replace_classifier_head(self):
        """Replace the classifier head with a custom one"""
        hidden_size = self.config.hidden_size

        # Custom classifier head with dropout and layer normalization
        self.vit.classifier = nn.Sequential(
            nn.LayerNorm(hidden_size),
            nn.Dropout(self.dropout_rate),
            nn.Linear(hidden_size, hidden_size // 2),
            nn.GELU(),
            nn.Dropout(self.dropout_rate / 2),
            nn.Linear(hidden_size // 2, self.num_classes)
        )

        # Initialize weights
        self._init_classifier_weights()

    def _init_classifier_weights(self):
        """Initialize classifier weights"""
        for module in self.vit.classifier.modules():
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    torch.nn.init.zeros_(module.bias)

    def forward(self, pixel_values: torch.Tensor, labels: Optional[torch.Tensor] = None):
        """
        Forward pass

        Args:
            pixel_values: Input images tensor
            labels: Optional labels for computing loss

        Returns:
            Model output containing logits and, if labels are provided, the loss
        """
        outputs = self.vit(pixel_values=pixel_values, labels=labels)
        return outputs

    def get_features(self, pixel_values: torch.Tensor):
        """
        Get features from the ViT backbone (before the classifier)

        Args:
            pixel_values: Input images tensor

        Returns:
            Features tensor from the ViT backbone
        """
        with torch.no_grad():
            outputs = self.vit.vit(pixel_values=pixel_values)
            # Get the [CLS] token representation
            features = outputs.last_hidden_state[:, 0, :]  # Shape: (batch_size, hidden_size)
            return features

    def get_attention_weights(self, pixel_values: torch.Tensor, layer_idx: int = -1):
        """
        Get attention weights for visualization

        Args:
            pixel_values: Input images tensor
            layer_idx: Which layer's attention to return (-1 for last layer)

        Returns:
            Attention weights tensor
        """
        with torch.no_grad():
            outputs = self.vit.vit(pixel_values=pixel_values, output_attentions=True)
            attentions = outputs.attentions
            return attentions[layer_idx]

    def unfreeze_backbone(self):
        """Unfreeze the backbone for full fine-tuning"""
        for param in self.vit.vit.parameters():
            param.requires_grad = True
        self.freeze_backbone = False
        logger.info("ViT backbone unfrozen")

    def get_model_info(self) -> Dict[str, Any]:
        """Get model information"""
        total_params = sum(p.numel() for p in self.parameters())
        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)

        return {
            "model_name": self.model_name,
            "num_classes": self.num_classes,
            "total_parameters": total_params,
            "trainable_parameters": trainable_params,
            "freeze_backbone": self.freeze_backbone,
            "dropout_rate": self.dropout_rate,
            "image_size": self.config.image_size,
            "patch_size": self.config.patch_size,
            "hidden_size": self.config.hidden_size,
            "num_attention_heads": self.config.num_attention_heads,
            "num_hidden_layers": self.config.num_hidden_layers
        }

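# Usage sketch (illustrative, not part of the training pipeline): a minimal
# single-modality forward pass with ViTForFER, assuming the default 224x224
# input resolution of google/vit-base-patch16-224-in21k and dummy tensors.
#
#     model = ViTForFER(num_classes=7)
#     pixel_values = torch.randn(2, 3, 224, 224)       # dummy RGB batch
#     outputs = model(pixel_values=pixel_values)
#     probs = outputs.logits.softmax(dim=-1)           # (2, 7) emotion probabilities
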
class EarlyFusionViT(nn.Module):
    """
    Early Fusion ViT: concatenates RGB and thermal images at the input level.
    """

    def __init__(
        self,
        model_name: str = "google/vit-base-patch16-224-in21k",
        num_classes: int = 7,
        dropout_rate: float = 0.1,
        freeze_backbone: bool = False,
        use_gradient_checkpointing: bool = False,
        fusion_type: str = "concat"
    ):
        """
        Args:
            model_name: Pre-trained ViT model name
            num_classes: Number of emotion classes
            dropout_rate: Dropout rate for classifier
            freeze_backbone: Whether to freeze backbone
            use_gradient_checkpointing: Whether to use gradient checkpointing
            fusion_type: How to fuse RGB and thermal ("concat" or "add")
        """
        super().__init__()

        self.model_name = model_name
        self.num_classes = num_classes
        self.dropout_rate = dropout_rate
        self.freeze_backbone = freeze_backbone
        self.fusion_type = fusion_type

        # Load pre-trained ViT configuration
        self.config = ViTConfig.from_pretrained(model_name)
        self.config.num_labels = num_classes
        self.config.id2label = {
            0: "angry", 1: "disgust", 2: "fear", 3: "happy",
            4: "neutral", 5: "sad", 6: "surprised"
        }
        self.config.label2id = {v: k for k, v in self.config.id2label.items()}

        # Modify input channels based on fusion type
        if fusion_type == "concat":
            # RGB (3) + thermal (3) = 6 channels
            self.input_channels = 6
            # Modify config to handle 6-channel input
            self.config.num_channels = 6
        else:  # add
            # RGB and thermal both have 3 channels; the fused input keeps 3 channels
            self.input_channels = 3

        # Create ViT backbone
        self.vit = ViTModel.from_pretrained(model_name, config=self.config, ignore_mismatched_sizes=True)

        # Modify the patch embedding layer for different input channels
        if self.input_channels != 3:
            self._modify_patch_embedding()

        # Enable gradient checkpointing if requested
        if use_gradient_checkpointing:
            self.vit.gradient_checkpointing_enable()

        # Freeze backbone if requested
        if freeze_backbone:
            self._freeze_backbone()

        # Create classifier head
        self._create_classifier_head()

    def _modify_patch_embedding(self):
        """Modify the patch embedding layer for a different number of input channels"""
        original_conv = self.vit.embeddings.patch_embeddings.projection

        # Check if the conv layer already has the right number of channels
        if original_conv.in_channels == self.input_channels:
            # Already has the right number of channels, initialize properly
            if self.input_channels == 6:
                with torch.no_grad():
                    # Initialize the 6-channel weights by duplicating the pre-trained
                    # 3-channel weights, taken from a fresh copy of the backbone
                    temp_config = ViTConfig.from_pretrained(self.model_name)
                    temp_vit = ViTModel.from_pretrained(self.model_name, config=temp_config)
                    original_3ch_weight = temp_vit.embeddings.patch_embeddings.projection.weight

                    # Copy the RGB weights into both the RGB and the thermal channels
                    original_conv.weight[:, :3, :, :] = original_3ch_weight
                    original_conv.weight[:, 3:6, :, :] = original_3ch_weight
            return

        # Create new conv layer with different input channels
        new_conv = nn.Conv2d(
            self.input_channels,
            original_conv.out_channels,
            kernel_size=original_conv.kernel_size,
            stride=original_conv.stride,
            padding=original_conv.padding,
            bias=original_conv.bias is not None
        )

        # Initialize weights
        with torch.no_grad():
            if self.input_channels == 6:  # concat case
                # Copy RGB weights to the first 3 channels
                new_conv.weight[:, :3, :, :] = original_conv.weight
                # Copy RGB weights to the thermal channels (channels 3-6)
                new_conv.weight[:, 3:6, :, :] = original_conv.weight

            if original_conv.bias is not None:
                new_conv.bias.copy_(original_conv.bias)

        # Replace the projection layer
        self.vit.embeddings.patch_embeddings.projection = new_conv

    def _freeze_backbone(self):
        """Freeze the ViT backbone parameters"""
        for param in self.vit.parameters():
            param.requires_grad = False
        logger.info("ViT backbone frozen")

    def _create_classifier_head(self):
        """Create classifier head"""
        hidden_size = self.config.hidden_size

        self.classifier = nn.Sequential(
            nn.LayerNorm(hidden_size),
            nn.Dropout(self.dropout_rate),
            nn.Linear(hidden_size, hidden_size // 2),
            nn.GELU(),
            nn.Dropout(self.dropout_rate / 2),
            nn.Linear(hidden_size // 2, self.num_classes)
        )

        # Initialize weights
        self._init_classifier_weights()

    def _init_classifier_weights(self):
        """Initialize classifier weights"""
        for module in self.classifier.modules():
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    torch.nn.init.zeros_(module.bias)

    def forward(self, rgb_images: torch.Tensor, thermal_images: torch.Tensor, labels: Optional[torch.Tensor] = None):
        """
        Forward pass for early fusion

        Args:
            rgb_images: RGB images tensor (B, 3, H, W)
            thermal_images: Thermal images tensor (B, 3, H, W)
            labels: Optional labels for computing loss

        Returns:
            Dictionary containing logits and loss (if labels provided)
        """
        # Fuse RGB and thermal at the input level
        if self.fusion_type == "concat":
            # Concatenate along channel dimension
            fused_input = torch.cat([rgb_images, thermal_images], dim=1)  # (B, 6, H, W)
        else:  # add
            # Element-wise addition
            fused_input = rgb_images + thermal_images  # (B, 3, H, W)

        # Forward through ViT backbone
        outputs = self.vit(pixel_values=fused_input)

        # Get [CLS] token representation
        cls_output = outputs.last_hidden_state[:, 0, :]  # (B, hidden_size)

        # Forward through classifier
        logits = self.classifier(cls_output)

        # Compute loss if labels provided
        loss = None
        if labels is not None:
            loss_fn = nn.CrossEntropyLoss()
            loss = loss_fn(logits, labels)

        return {
            'logits': logits,
            'loss': loss,
            'last_hidden_state': outputs.last_hidden_state
        }

    def unfreeze_backbone(self):
        """Unfreeze the backbone for full fine-tuning"""
        for param in self.vit.parameters():
            param.requires_grad = True
        self.freeze_backbone = False
        logger.info("ViT backbone unfrozen")

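# Usage sketch (illustrative): early fusion expects paired RGB and thermal
# batches of identical spatial size; with fusion_type="concat" they are stacked
# into a 6-channel input before the patch embedding. Shapes below are dummies.
#
#     model = EarlyFusionViT(fusion_type="concat")
#     rgb = torch.randn(2, 3, 224, 224)
#     thermal = torch.randn(2, 3, 224, 224)
#     out = model(rgb, thermal, labels=torch.tensor([0, 3]))
#     out["logits"].shape, out["loss"]                 # torch.Size([2, 7]), scalar loss
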
class LateFusionViT(nn.Module):
    """
    Late Fusion ViT: separate ViT encoders for RGB and thermal, fused at the feature or prediction level.
    """

    def __init__(
        self,
        model_name: str = "google/vit-base-patch16-224-in21k",
        num_classes: int = 7,
        dropout_rate: float = 0.1,
        freeze_backbone: bool = False,
        use_gradient_checkpointing: bool = False,
        fusion_type: str = "concat",
        fusion_layer: str = "feature"
    ):
        """
        Args:
            model_name: Pre-trained ViT model name
            num_classes: Number of emotion classes
            dropout_rate: Dropout rate for classifier
            freeze_backbone: Whether to freeze backbone
            use_gradient_checkpointing: Whether to use gradient checkpointing
            fusion_type: How to fuse features ("concat", "add", "attention")
            fusion_layer: Where to fuse ("feature" or "prediction")
        """
        super().__init__()

        self.model_name = model_name
        self.num_classes = num_classes
        self.dropout_rate = dropout_rate
        self.freeze_backbone = freeze_backbone
        self.fusion_type = fusion_type
        self.fusion_layer = fusion_layer

        # Load pre-trained ViT configuration
        self.config = ViTConfig.from_pretrained(model_name)
        hidden_size = self.config.hidden_size

        # Create separate ViT encoders for RGB and thermal
        self.rgb_vit = ViTModel.from_pretrained(model_name)
        self.thermal_vit = ViTModel.from_pretrained(model_name)

        # Enable gradient checkpointing if requested
        if use_gradient_checkpointing:
            self.rgb_vit.gradient_checkpointing_enable()
            self.thermal_vit.gradient_checkpointing_enable()

        # Freeze backbones if requested
        if freeze_backbone:
            self._freeze_backbone()

        if fusion_layer == "feature":
            # Fuse at feature level, then single classifier
            if fusion_type == "concat":
                fusion_input_size = hidden_size * 2
            elif fusion_type == "attention":
                # Use attention to fuse features
                self.attention_fusion = nn.MultiheadAttention(hidden_size, num_heads=8, batch_first=True)
                fusion_input_size = hidden_size
            else:  # add
                fusion_input_size = hidden_size

            # Single classifier after fusion
            self.classifier = nn.Sequential(
                nn.LayerNorm(fusion_input_size),
                nn.Dropout(dropout_rate),
                nn.Linear(fusion_input_size, hidden_size // 2),
                nn.GELU(),
                nn.Dropout(dropout_rate / 2),
                nn.Linear(hidden_size // 2, num_classes)
            )
        else:  # prediction level fusion
            # Separate classifiers for each modality
            self.rgb_classifier = nn.Sequential(
                nn.LayerNorm(hidden_size),
                nn.Dropout(dropout_rate),
                nn.Linear(hidden_size, hidden_size // 2),
                nn.GELU(),
                nn.Dropout(dropout_rate / 2),
                nn.Linear(hidden_size // 2, num_classes)
            )

            self.thermal_classifier = nn.Sequential(
                nn.LayerNorm(hidden_size),
                nn.Dropout(dropout_rate),
                nn.Linear(hidden_size, hidden_size // 2),
                nn.GELU(),
                nn.Dropout(dropout_rate / 2),
                nn.Linear(hidden_size // 2, num_classes)
            )

            if fusion_type == "attention":
                # Attention-based prediction fusion
                self.prediction_attention = nn.Linear(num_classes * 2, num_classes)

        # Initialize weights
        self._init_classifier_weights()

    def _freeze_backbone(self):
        """Freeze the ViT backbone parameters"""
        for param in self.rgb_vit.parameters():
            param.requires_grad = False
        for param in self.thermal_vit.parameters():
            param.requires_grad = False
        logger.info("ViT backbones frozen")

    def _init_classifier_weights(self):
        """Initialize classifier weights"""
        if hasattr(self, 'classifier'):
            for module in self.classifier.modules():
                if isinstance(module, nn.Linear):
                    torch.nn.init.xavier_uniform_(module.weight)
                    if module.bias is not None:
                        torch.nn.init.zeros_(module.bias)

        if hasattr(self, 'rgb_classifier'):
            for module in self.rgb_classifier.modules():
                if isinstance(module, nn.Linear):
                    torch.nn.init.xavier_uniform_(module.weight)
                    if module.bias is not None:
                        torch.nn.init.zeros_(module.bias)

        if hasattr(self, 'thermal_classifier'):
            for module in self.thermal_classifier.modules():
                if isinstance(module, nn.Linear):
                    torch.nn.init.xavier_uniform_(module.weight)
                    if module.bias is not None:
                        torch.nn.init.zeros_(module.bias)

    def forward(self, rgb_images: torch.Tensor, thermal_images: torch.Tensor, labels: Optional[torch.Tensor] = None):
        """
        Forward pass for late fusion

        Args:
            rgb_images: RGB images tensor (B, 3, H, W)
            thermal_images: Thermal images tensor (B, 3, H, W)
            labels: Optional labels for computing loss

        Returns:
            Dictionary containing logits and loss (if labels provided)
        """
        # Forward through separate ViT encoders
        rgb_outputs = self.rgb_vit(pixel_values=rgb_images)
        thermal_outputs = self.thermal_vit(pixel_values=thermal_images)

        # Get [CLS] token representations
        rgb_features = rgb_outputs.last_hidden_state[:, 0, :]  # (B, hidden_size)
        thermal_features = thermal_outputs.last_hidden_state[:, 0, :]  # (B, hidden_size)

        if self.fusion_layer == "feature":
            # Fuse at feature level
            if self.fusion_type == "concat":
                fused_features = torch.cat([rgb_features, thermal_features], dim=1)
            elif self.fusion_type == "attention":
                # Stack features for attention
                stacked_features = torch.stack([rgb_features, thermal_features], dim=1)  # (B, 2, hidden_size)
                fused_features, _ = self.attention_fusion(stacked_features, stacked_features, stacked_features)
                fused_features = fused_features.mean(dim=1)  # Average the attended features
            else:  # add
                fused_features = rgb_features + thermal_features

            # Forward through single classifier
            logits = self.classifier(fused_features)

        else:  # prediction level fusion
            # Get predictions from separate classifiers
            rgb_logits = self.rgb_classifier(rgb_features)
            thermal_logits = self.thermal_classifier(thermal_features)

            # Fuse predictions
            if self.fusion_type == "concat":
                logits = (rgb_logits + thermal_logits) / 2  # Simple average
            elif self.fusion_type == "attention":
                # Attention-based fusion of predictions
                concat_logits = torch.cat([rgb_logits, thermal_logits], dim=1)
                logits = self.prediction_attention(concat_logits)
            else:  # add
                logits = rgb_logits + thermal_logits

        # Compute loss if labels provided
        loss = None
        if labels is not None:
            loss_fn = nn.CrossEntropyLoss()
            loss = loss_fn(logits, labels)

        return {
            'logits': logits,
            'loss': loss,
            'rgb_features': rgb_features,
            'thermal_features': thermal_features
        }

    def unfreeze_backbone(self):
        """Unfreeze the backbones for full fine-tuning"""
        for param in self.rgb_vit.parameters():
            param.requires_grad = True
        for param in self.thermal_vit.parameters():
            param.requires_grad = True
        self.freeze_backbone = False
        logger.info("ViT backbones unfrozen")

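# Usage sketch (illustrative): late fusion runs each modality through its own
# encoder; with fusion_layer="feature" and fusion_type="attention" the two
# [CLS] embeddings are fused by multi-head attention before a shared classifier.
# Tensors below are dummies at the default 224x224 resolution.
#
#     model = LateFusionViT(fusion_type="attention", fusion_layer="feature")
#     rgb = torch.randn(2, 3, 224, 224)
#     thermal = torch.randn(2, 3, 224, 224)
#     out = model(rgb, thermal)                        # no labels -> out["loss"] is None
#     out["rgb_features"].shape                        # (2, hidden_size) per-modality [CLS] features
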
def create_multimodal_vit_model(
    mode: str = 'rgb',
    fusion_strategy: str = 'early',
    fusion_type: str = 'concat',
    fusion_layer: str = 'feature',
    model_name: str = "google/vit-base-patch16-224-in21k",
    num_classes: int = 7,
    dropout_rate: float = 0.1,
    freeze_backbone: bool = False,
    use_gradient_checkpointing: bool = False
) -> Union[ViTForFER, EarlyFusionViT, LateFusionViT]:
    """
    Create a multimodal ViT model for FER

    Args:
        mode: 'rgb', 'thermal', or 'combined'
        fusion_strategy: 'early' or 'late' (only for combined mode)
        fusion_type: 'concat', 'add', or 'attention' (for fusion)
        fusion_layer: 'feature' or 'prediction' (for late fusion)
        model_name: Pre-trained ViT model name
        num_classes: Number of emotion classes
        dropout_rate: Dropout rate for classifier
        freeze_backbone: Whether to freeze backbone
        use_gradient_checkpointing: Whether to use gradient checkpointing

    Returns:
        Appropriate ViT model based on mode
    """
    if mode in ['rgb', 'thermal']:
        # Single modality model
        model = ViTForFER(
            model_name=model_name,
            num_classes=num_classes,
            dropout_rate=dropout_rate,
            freeze_backbone=freeze_backbone,
            use_gradient_checkpointing=use_gradient_checkpointing
        )
        logger.info(f"Created single modality ViT model for {mode}")

    elif mode == 'combined':
        if fusion_strategy == 'early':
            # Early fusion model
            model = EarlyFusionViT(
                model_name=model_name,
                num_classes=num_classes,
                dropout_rate=dropout_rate,
                freeze_backbone=freeze_backbone,
                use_gradient_checkpointing=use_gradient_checkpointing,
                fusion_type=fusion_type
            )
            logger.info(f"Created early fusion ViT model with {fusion_type} fusion")

        else:  # late fusion
            # Late fusion model
            model = LateFusionViT(
                model_name=model_name,
                num_classes=num_classes,
                dropout_rate=dropout_rate,
                freeze_backbone=freeze_backbone,
                use_gradient_checkpointing=use_gradient_checkpointing,
                fusion_type=fusion_type,
                fusion_layer=fusion_layer
            )
            logger.info(f"Created late fusion ViT model with {fusion_type} fusion at {fusion_layer} level")
    else:
        raise ValueError(f"Invalid mode: {mode}. Must be 'rgb', 'thermal', or 'combined'")

    return model

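# Usage sketch (illustrative): the factory selects the class from `mode` and
# `fusion_strategy`; get_model_info() (available on the single-modality model only)
# can confirm that freezing the backbone leaves just the classifier head trainable.
#
#     model = create_multimodal_vit_model(mode="rgb", freeze_backbone=True)
#     info = model.get_model_info()
#     info["trainable_parameters"]                     # classifier-head parameters only
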
def get_optimizer_and_scheduler(
    model: Union[ViTForFER, EarlyFusionViT, LateFusionViT],
    learning_rate: float = 5e-5,
    weight_decay: float = 0.01,
    warmup_steps: int = 1000,
    num_training_steps: int = 10000,
    optimizer_type: str = "adamw"
):
    """
    Get optimizer and learning rate scheduler

    Args:
        model: ViT model
        learning_rate: Learning rate
        weight_decay: Weight decay
        warmup_steps: Number of warmup steps
        num_training_steps: Total training steps
        optimizer_type: Type of optimizer ("adamw" or "sgd")

    Returns:
        optimizer, scheduler
    """
    # Different learning rates for different parts
    backbone_params = []
    classifier_params = []

    for name, param in model.named_parameters():
        if param.requires_grad:
            if 'classifier' in name:
                classifier_params.append(param)
            else:
                backbone_params.append(param)

    # Set different learning rates
    param_groups = [
        {'params': backbone_params, 'lr': learning_rate * 0.1},  # Lower LR for backbone
        {'params': classifier_params, 'lr': learning_rate}       # Higher LR for classifier
    ]

    # Create optimizer
    if optimizer_type.lower() == "adamw":
        optimizer = torch.optim.AdamW(
            param_groups,
            lr=learning_rate,
            weight_decay=weight_decay,
            betas=(0.9, 0.999),
            eps=1e-8
        )
    else:  # SGD
        optimizer = torch.optim.SGD(
            param_groups,
            lr=learning_rate,
            momentum=0.9,
            weight_decay=weight_decay,
            nesterov=True
        )

    # Learning rate scheduler
    from transformers import get_cosine_schedule_with_warmup

    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=num_training_steps
    )

    return optimizer, scheduler

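# Usage sketch (illustrative; step counts and batch tensors are placeholders):
# one optimization step with the warmup + cosine schedule, shown for a
# single-modality ViTForFER whose output exposes `.loss` directly.
#
#     optimizer, scheduler = get_optimizer_and_scheduler(
#         model, learning_rate=5e-5, warmup_steps=500, num_training_steps=5000
#     )
#     outputs = model(pixel_values, labels=labels)
#     outputs.loss.backward()
#     optimizer.step()
#     scheduler.step()
#     optimizer.zero_grad()
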
if __name__ == "__main__":
    # Example usage
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Test different model configurations
    configs = [
        {'mode': 'rgb'},
        {'mode': 'thermal'},
        {'mode': 'combined', 'fusion_strategy': 'early', 'fusion_type': 'concat'},
        {'mode': 'combined', 'fusion_strategy': 'early', 'fusion_type': 'add'},
        {'mode': 'combined', 'fusion_strategy': 'late', 'fusion_type': 'concat', 'fusion_layer': 'feature'},
        {'mode': 'combined', 'fusion_strategy': 'late', 'fusion_type': 'attention', 'fusion_layer': 'prediction'},
    ]

    batch_size = 4
    dummy_rgb = torch.randn(batch_size, 3, 224, 224).to(device)
    dummy_thermal = torch.randn(batch_size, 3, 224, 224).to(device)
    dummy_labels = torch.randint(0, 7, (batch_size,)).to(device)

    for config in configs:
        print(f"\n=== Testing {config} ===")

        try:
            model = create_multimodal_vit_model(**config)
            model.to(device)

            # Test forward pass
            with torch.no_grad():
                if config['mode'] == 'combined':
                    outputs = model(dummy_rgb, dummy_thermal, dummy_labels)
                else:
                    if config['mode'] == 'rgb':
                        outputs = model(dummy_rgb, dummy_labels)
                    else:  # thermal
                        outputs = model(dummy_thermal, dummy_labels)

            print(f"Output logits shape: {outputs['logits'].shape}")
            print(f"Loss: {outputs['loss']}")

        except Exception as e:
            print(f"Error: {e}")