File size: 5,634 Bytes
0a82b18 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 |
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
import logging
from types import SimpleNamespace as Namespace
# Assuming these are in your project structure
from models.salad import SALAD
from models.mixvpr import MixVPR
class DINOv2FeatureExtractor(nn.Module):
    """Image descriptor extractor built on a timm DINOv2 ViT backbone.

    The backbone is created with ``timm.create_model(..., pretrained=True,
    num_classes=0)`` and frozen except for its last
    ``num_of_layers_to_unfreeze`` transformer blocks and its final norm layer.
    Descriptors are produced by one of:

    * ``"SALAD"``  - SALAD aggregator fed ``(patch_token_map, cls_token)``;
    * ``"MixVPR"`` - MixVPR aggregator fed the patch-token map;
    * anything else - the backbone's pooled head output (``forward_head`` with
      ``pre_logits=True``), L2-normalized.

    Args:
        image_size: Input resolution passed to timm as ``img_size``.
        model_type: timm model name.
        num_of_layers_to_unfreeze: Number of trailing transformer blocks left
            trainable; 0 or a negative value keeps all blocks frozen.
        desc_dim: Expected descriptor size when no aggregator is used. The
            real output size is the backbone embedding width, so a mismatching
            value is replaced by it (with a warning). Ignored for SALAD/MixVPR,
            whose output size follows from their configuration.
        aggregator_type: ``"SALAD"``, ``"MixVPR"`` or ``"No"``.

    Raises:
        ValueError: If ``aggregator_type == "SALAD"`` and ``model_type`` does
            not contain ``vit_small``, ``vit_base`` or ``vit_large``.
    """

    # SALAD hyper-parameters per backbone size, matched in this order by
    # substring of ``model_type``.
    # Output size = token_dim + num_clusters * cluster_dim.
    _SALAD_CONFIGS = (
        ("vit_small", {"num_clusters": 24, "cluster_dim": 64, "token_dim": 512}),  # 2,048
        ("vit_base", {"num_clusters": 32, "cluster_dim": 64, "token_dim": 1024}),  # 3,072
        ("vit_large", {"num_clusters": 48, "cluster_dim": 64, "token_dim": 1024}),  # 4,096
    )
    _SALAD_DROPOUT = 0.3

    # MixVPR output size per backbone size; unmatched names use the default.
    _MIXVPR_OUT_DIMS = (
        ("vit_small", 2048),
        ("vit_base", 3072),
        ("vit_large", 4096),
    )
    _MIXVPR_DEFAULT_OUT_DIM = 4096

    # Used only if the backbone does not expose ``patch_embed.patch_size``
    # (the default model name is a patch14 variant).
    _DEFAULT_PATCH_SIZE = 14

    def __init__(
        self,
        image_size=518,  # Default for DINOv2 models
        model_type="vit_base_patch14_reg4_dinov2.lvd142m",
        num_of_layers_to_unfreeze=1,
        desc_dim=768,  # vit-base has 768-dim embeddings
        aggregator_type="No",
    ):
        super().__init__()
        # Pretrained backbone; num_classes=0 removes the classification head.
        self.backbone = timm.create_model(
            model_type, pretrained=True, num_classes=0, img_size=image_size
        )

        # Store configuration parameters.
        self.model_type = model_type
        self.num_channels = self.backbone.embed_dim
        self.desc_dim = desc_dim
        self.image_size = image_size
        self.num_of_layers_to_unfreeze = num_of_layers_to_unfreeze
        self.aggregator_type = aggregator_type
        self.aggregator = None

        if aggregator_type == "SALAD":
            self.aggregator, self.desc_dim = self._build_salad(model_type)
        elif aggregator_type == "MixVPR":
            self.aggregator, self.desc_dim = self._build_mixvpr(
                model_type, image_size
            )
        else:
            if aggregator_type != "No":
                logging.warning(
                    f"Unknown aggregator_type {aggregator_type!r}; "
                    "falling back to the backbone's pooled CLS output."
                )
            # Without an aggregator the descriptor is the backbone's pooled
            # embedding, so its width is embed_dim regardless of desc_dim.
            if desc_dim != self.num_channels:
                logging.warning(
                    f"desc_dim={desc_dim} does not match the backbone "
                    f"embedding width {self.num_channels}; using "
                    f"{self.num_channels}."
                )
                self.desc_dim = self.num_channels

        # Applies regardless of the aggregator type.
        self._freeze_parameters()

    def _build_salad(self, model_type):
        """Create the SALAD aggregator for this backbone size.

        Returns:
            ``(aggregator, output_dim)``.

        Raises:
            ValueError: If ``model_type`` matches no entry in ``_SALAD_CONFIGS``.
        """
        for size_key, cfg in self._SALAD_CONFIGS:
            if size_key in model_type:
                aggregator = SALAD(
                    num_channels=self.num_channels,
                    dropout=self._SALAD_DROPOUT,
                    **cfg,
                )
                out_dim = cfg["token_dim"] + cfg["num_clusters"] * cfg["cluster_dim"]
                return aggregator, out_dim
        raise ValueError(
            f"No SALAD configuration for model_type={model_type!r}; "
            "expected a vit_small, vit_base or vit_large variant."
        )

    def _build_mixvpr(self, model_type, image_size):
        """Create the MixVPR aggregator; return ``(aggregator, output_dim)``."""
        out_dim = next(
            (dim for size_key, dim in self._MIXVPR_OUT_DIMS if size_key in model_type),
            self._MIXVPR_DEFAULT_OUT_DIM,
        )
        patch_h, patch_w = self._patch_size()
        aggregator = MixVPR(
            in_channels=self.num_channels,
            in_h=image_size // patch_h,
            in_w=image_size // patch_w,
            out_channels=out_dim,
        )
        return aggregator, out_dim

    def _patch_size(self):
        """Return the backbone patch size as ``(height, width)``."""
        patch_embed = getattr(self.backbone, "patch_embed", None)
        size = getattr(patch_embed, "patch_size", self._DEFAULT_PATCH_SIZE)
        if isinstance(size, int):
            return size, size
        return int(size[0]), int(size[1])

    def _num_prefix_tokens(self):
        """Return how many non-patch tokens (CLS + registers) precede the patches."""
        count = getattr(self.backbone, "num_prefix_tokens", None)
        if count is None:
            # Fallback name heuristic: 4 register tokens + CLS, else CLS only.
            count = 5 if "reg" in self.model_type else 1
        return int(count)

    @staticmethod
    def _count_trainable_params(module):
        """Return the number of elements in ``module``'s trainable parameters."""
        return sum(p.numel() for p in module.parameters() if p.requires_grad)

    def _freeze_parameters(self):
        """Freeze all backbone parameters except the last N blocks and the norm layer.

        Aggregator parameters are left untouched. Trainable-parameter counts
        are logged at INFO level.
        """
        # First freeze everything in the backbone.
        for param in self.backbone.parameters():
            param.requires_grad = False

        # Unfreeze the last N blocks. Guarded because blocks[-0:] would select
        # every block.
        if self.num_of_layers_to_unfreeze > 0:
            for block in self.backbone.blocks[-self.num_of_layers_to_unfreeze:]:
                for param in block.parameters():
                    param.requires_grad = True

        # Unfreeze the final norm layer.
        for param in self.backbone.norm.parameters():
            param.requires_grad = True

        backbone_params = self._count_trainable_params(self.backbone)
        logging.info(f"Number of trainable parameters backbone: {backbone_params:,}")
        total_params = backbone_params
        if self.aggregator is not None:
            aggregator_params = self._count_trainable_params(self.aggregator)
            logging.info(
                f"Number of trainable parameters aggregator: {aggregator_params:,}"
            )
            total_params += aggregator_params
        logging.info(f"Total trainable parameters: {total_params:,}")

    def forward(self, x):
        """Compute descriptors for a batch of images.

        Args:
            x: Image batch, unpacked as ``(B, C, H, W)``.

        Returns:
            The aggregator output for SALAD/MixVPR. Otherwise the L2-normalized
            pooled backbone features, of shape ``(B, embed_dim)``.
        """
        B, _, H, W = x.shape
        tokens = self.backbone.forward_features(x)

        if self.aggregator_type in ("SALAD", "MixVPR"):
            # Drop the CLS/register prefix tokens, then lay the patch tokens
            # out on their spatial grid as (B, C, H/patch, W/patch).
            patch_h, patch_w = self._patch_size()
            patch_tokens = tokens[:, self._num_prefix_tokens():]
            patch_map = patch_tokens.reshape(
                B, H // patch_h, W // patch_w, self.num_channels
            ).permute(0, 3, 1, 2)

            if self.aggregator_type == "SALAD":
                # Token 0 is taken as the CLS token.
                return self.aggregator((patch_map, tokens[:, 0]))
            return self.aggregator(patch_map)

        # Default: the backbone's pooled pre-logits features, L2-normalized.
        features = self.backbone.forward_head(tokens, pre_logits=True)
        return F.normalize(features, p=2, dim=-1)