Spaces:
Sleeping
Sleeping
import torch
import torch.nn as nn
from torchvision.models import vit_b_16, ViT_B_16_Weights
class ViTEncoder(nn.Module):
    """ViT-B/16 backbone that emits per-patch features (no CLS token).

    Wraps torchvision's ``vit_b_16`` and returns the full sequence of
    patch tokens projected to ``d_model`` — suitable as encoder memory
    for downstream sequence models (e.g. an image-captioning decoder).
    """

    def __init__(self, d_model=512, freeze_backbone=True, pretrained=True):
        """
        Args:
            d_model: output feature dimension of the final projection.
            freeze_backbone: if True, freeze all ViT weights
                (the projection layer stays trainable).
            pretrained: if True, load ImageNet-1k weights.
        """
        super().__init__()
        weights = ViT_B_16_Weights.IMAGENET1K_V1 if pretrained else None
        self.vit = vit_b_16(weights=weights)
        # Drop the classification head; only encoder features are needed.
        self.vit.heads = nn.Identity()
        self.hidden_dim = self.vit.hidden_dim  # 768 for ViT-B/16
        # Project backbone features into the caller's model dimension.
        self.proj = nn.Linear(self.hidden_dim, d_model)
        if freeze_backbone:
            for p in self.vit.parameters():
                p.requires_grad = False

    def unfreeze_backbone(self, unfreeze=True):
        """Toggle gradient computation for every backbone parameter."""
        for p in self.vit.parameters():
            p.requires_grad = unfreeze
        print(f"ViT Backbone {'Unfrozen' if unfreeze else 'Frozen'}")

    def forward(self, images):
        """
        Args:
            images: (B, 3, 224, 224) batch of RGB images.

        Returns:
            (B, 196, d_model) patch-token features.

        Raises:
            ValueError: if the spatial size is not 224x224 — the learned
                positional embedding is fixed at 14x14 = 196 patches, and a
                mismatched input would otherwise die with an opaque
                broadcasting error at the addition below.
        """
        side = self.vit.image_size  # 224 for vit_b_16
        if images.shape[-2:] != (side, side):
            raise ValueError(
                f"ViTEncoder expects {side}x{side} inputs, "
                f"got {tuple(images.shape[-2:])}"
            )
        # 1. Patch embedding: (B, hidden, 14, 14) -> (B, 196, hidden)
        x = self.vit.conv_proj(images)
        x = x.flatten(2).transpose(1, 2)
        # 2. Positional embedding. pos_embedding is (1, 197, hidden);
        # index 0 belongs to the CLS token, which we do not use, so slice
        # it off. Using the model's parameter directly keeps gradient flow
        # and device placement correct.
        x = x + self.vit.encoder.pos_embedding[:, 1:]
        # 3. Run the full transformer encoder (dropout -> layers -> final
        # LayerNorm) — without it this would just be Conv + Linear.
        x = self.vit.encoder.dropout(x)
        x = self.vit.encoder.layers(x)
        x = self.vit.encoder.ln(x)
        # 4. Project: (B, 196, d_model)
        return self.proj(x)