codealchemist01 commited on
Commit
28b51fd
·
verified ·
1 Parent(s): 84c468a

Upload models/vit_branch.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. models/vit_branch.py +111 -0
models/vit_branch.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Vision Transformer Branch for Hybrid Food Classifier
3
+ Uses DeiT-Base as backbone with custom head
4
+ """
5
+ import torch
6
+ import torch.nn as nn
7
+ from transformers import DeiTModel, DeiTConfig
8
+ from typing import Tuple
9
+
10
+ class ViTBranch(nn.Module):
11
+ """Vision Transformer branch using DeiT-Base"""
12
+
13
+ def __init__(
14
+ self,
15
+ model_name: str = "facebook/deit-base-distilled-patch16-224",
16
+ pretrained: bool = True,
17
+ freeze_early_layers: bool = True,
18
+ dropout: float = 0.1,
19
+ feature_dim: int = 768
20
+ ):
21
+ super(ViTBranch, self).__init__()
22
+
23
+ self.feature_dim = feature_dim
24
+
25
+ # Load DeiT model
26
+ if pretrained:
27
+ self.vit = DeiTModel.from_pretrained(model_name)
28
+ else:
29
+ config = DeiTConfig.from_pretrained(model_name)
30
+ self.vit = DeiTModel(config)
31
+
32
+ # Get model dimensions
33
+ self.hidden_size = self.vit.config.hidden_size # 768 for base
34
+ self.num_patches = (224 // 16) ** 2 # 196 patches for 224x224 image
35
+
36
+ # Freeze early layers if specified
37
+ if freeze_early_layers:
38
+ self._freeze_early_layers()
39
+
40
+ # Feature projection to match CNN branch
41
+ self.feature_proj = nn.Sequential(
42
+ nn.Linear(self.hidden_size, feature_dim),
43
+ nn.LayerNorm(feature_dim),
44
+ nn.GELU(),
45
+ nn.Dropout(dropout)
46
+ )
47
+
48
+ # Spatial feature projection (for fusion with CNN spatial features)
49
+ self.spatial_proj = nn.Sequential(
50
+ nn.Linear(self.hidden_size, feature_dim),
51
+ nn.LayerNorm(feature_dim),
52
+ nn.GELU(),
53
+ nn.Dropout(dropout)
54
+ )
55
+
56
+ # Additional processing head
57
+ self.feature_head = nn.Sequential(
58
+ nn.Linear(feature_dim, feature_dim),
59
+ nn.LayerNorm(feature_dim),
60
+ nn.GELU(),
61
+ nn.Dropout(dropout)
62
+ )
63
+
64
+ def _freeze_early_layers(self):
65
+ """Freeze early layers of the ViT"""
66
+ # Freeze first 8 transformer layers (out of 12)
67
+ layers_to_freeze = 8
68
+ for i, layer in enumerate(self.vit.encoder.layer):
69
+ if i < layers_to_freeze:
70
+ for param in layer.parameters():
71
+ param.requires_grad = False
72
+
73
+ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
74
+ """
75
+ Forward pass
76
+
77
+ Args:
78
+ x: Input tensor [B, 3, H, W]
79
+
80
+ Returns:
81
+ spatial_features: Patch features [B, num_patches, feature_dim]
82
+ global_features: CLS token features [B, feature_dim]
83
+ """
84
+ # Get ViT outputs
85
+ outputs = self.vit(pixel_values=x)
86
+
87
+ # Extract features
88
+ last_hidden_states = outputs.last_hidden_state # [B, seq_len, hidden_size]
89
+
90
+ # CLS token (first token) for global features
91
+ cls_token = last_hidden_states[:, 0] # [B, hidden_size]
92
+
93
+ # Patch tokens for spatial features
94
+ patch_tokens = last_hidden_states[:, 1:] # [B, num_patches, hidden_size]
95
+
96
+ # Project features
97
+ global_features = self.feature_proj(cls_token) # [B, feature_dim]
98
+ spatial_features = self.spatial_proj(patch_tokens) # [B, num_patches, feature_dim]
99
+
100
+ # Additional processing
101
+ global_features = self.feature_head(global_features) # [B, feature_dim]
102
+
103
+ return spatial_features, global_features
104
+
105
+ def get_feature_dim(self) -> int:
106
+ """Get feature dimension"""
107
+ return self.feature_dim
108
+
109
+ def get_num_patches(self) -> int:
110
+ """Get number of patches"""
111
+ return self.num_patches