# EarthLoc2 / models / apl_model_dinov2.py
# Pawel Piwowarski
# init commit
# 0a82b18
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
import logging
from types import SimpleNamespace as Namespace
# Assuming these are in your project structure
from models.salad import SALAD
from models.mixvpr import MixVPR
class DINOv2FeatureExtractor(nn.Module):
    """timm vision-transformer backbone with an optional SALAD / MixVPR head.

    The backbone is created through ``timm.create_model`` with
    ``num_classes=0``. All backbone parameters are frozen except the last
    ``num_of_layers_to_unfreeze`` transformer blocks and the final norm layer.

    Depending on ``aggregator_type`` the forward pass returns:

    * ``"SALAD"``  - the output of the SALAD aggregator, fed with the patch
      token map and the first (CLS) token;
    * ``"MixVPR"`` - the output of the MixVPR aggregator, fed with the patch
      token map;
    * anything else - the backbone's ``forward_head(..., pre_logits=True)``
      features, L2-normalised along the last dimension.

    Args:
        image_size: Input resolution passed to timm as ``img_size``. Must be
            an ``int`` when ``aggregator_type="MixVPR"`` (it is divided by the
            patch size to get the token grid side).
        model_type: timm model name. The substrings ``vit_small`` /
            ``vit_base`` / ``vit_large`` select the aggregator configuration.
        num_of_layers_to_unfreeze: Number of trailing transformer blocks left
            trainable; ``0`` keeps every block frozen.
        desc_dim: Value stored in ``self.desc_dim`` when no aggregator is
            used. It is overwritten with the aggregator's configured output
            size for SALAD / MixVPR.
            NOTE(review): in the no-aggregator case this value is not checked
            against the backbone's embedding size - confirm callers pass a
            matching number.
        aggregator_type: ``"SALAD"``, ``"MixVPR"`` or any other value for
            plain pooled backbone features.

    Raises:
        ValueError: If ``aggregator_type="SALAD"`` and ``model_type`` names
            none of the supported backbone sizes.
    """

    # (model-name substring, SALAD keyword arguments), checked in this order.
    _SALAD_CONFIGS = (
        ("vit_small", {"num_clusters": 24, "cluster_dim": 64, "token_dim": 512}),
        ("vit_base", {"num_clusters": 32, "cluster_dim": 64, "token_dim": 1024}),
        ("vit_large", {"num_clusters": 48, "cluster_dim": 64, "token_dim": 1024}),
    )
    _SALAD_DROPOUT = 0.3

    # (model-name substring, MixVPR output channels), checked in this order.
    _MIXVPR_OUT_DIMS = (
        ("vit_small", 2048),
        ("vit_base", 3072),
        ("vit_large", 4096),
    )
    _MIXVPR_DEFAULT_OUT_DIM = 4096

    # Used when the backbone does not expose ``patch_embed.patch_size``.
    _DEFAULT_PATCH_SIZE = 14

    def __init__(
        self,
        image_size=518,  # Default for DINOv2 models
        model_type="vit_base_patch14_reg4_dinov2.lvd142m",
        num_of_layers_to_unfreeze=1,
        desc_dim=768,  # vit-base has 768-dim embeddings
        aggregator_type="No",
    ):
        super().__init__()

        # Resolve the SALAD configuration *before* loading the pretrained
        # backbone so that an unsupported combination fails fast instead of
        # leaving ``self.aggregator`` as None and crashing later in forward().
        salad_kwargs = None
        if aggregator_type == "SALAD":
            salad_kwargs = self._match_model_type(self._SALAD_CONFIGS, model_type)
            if salad_kwargs is None:
                supported = ", ".join(key for key, _ in self._SALAD_CONFIGS)
                raise ValueError(
                    f"No SALAD configuration for model_type={model_type!r}; "
                    f"expected the name to contain one of: {supported}"
                )

        # Initialize backbone (no classification head).
        self.backbone = timm.create_model(
            model_type, pretrained=True, num_classes=0, img_size=image_size
        )

        # Store configuration parameters.
        self.model_type = model_type
        self.num_channels = self.backbone.embed_dim
        self.desc_dim = desc_dim
        self.image_size = image_size
        self.num_of_layers_to_unfreeze = num_of_layers_to_unfreeze
        self.aggregator_type = aggregator_type
        self.patch_size = self._infer_patch_size(
            self.backbone, self._DEFAULT_PATCH_SIZE
        )
        self.aggregator = None

        if aggregator_type == "SALAD":
            self.aggregator = SALAD(
                num_channels=self.num_channels,
                dropout=self._SALAD_DROPOUT,
                **salad_kwargs,
            )
            # token_dim + num_clusters * cluster_dim
            # (small: 2,048 / base: 3,072 / large: 4,096).
            self.desc_dim = (
                salad_kwargs["token_dim"]
                + salad_kwargs["num_clusters"] * salad_kwargs["cluster_dim"]
            )
        elif aggregator_type == "MixVPR":
            # Side length of the square patch-token grid.
            patch_dim = image_size // self.patch_size
            out_dim = self._match_model_type(self._MIXVPR_OUT_DIMS, model_type)
            if out_dim is None:
                out_dim = self._MIXVPR_DEFAULT_OUT_DIM
            self.aggregator = MixVPR(
                in_channels=self.num_channels,
                in_h=patch_dim,
                in_w=patch_dim,
                out_channels=out_dim,
            )
            self.desc_dim = out_dim

        # This should be called regardless of the aggregator type.
        self._freeze_parameters()

    @staticmethod
    def _match_model_type(table, model_type):
        """Return the value of the first ``(substring, value)`` entry whose
        substring occurs in ``model_type``, or ``None`` if nothing matches."""
        for key, value in table:
            if key in model_type:
                return value
        return None

    @staticmethod
    def _infer_patch_size(backbone, default):
        """Return ``backbone.patch_embed.patch_size`` as an int, else ``default``.

        If the attribute is a tuple/list its first entry is used
        (assumes square patches - TODO confirm for non-DINOv2 backbones).
        """
        patch_embed = getattr(backbone, "patch_embed", None)
        patch_size = getattr(patch_embed, "patch_size", None)
        if patch_size is None:
            return default
        if isinstance(patch_size, (tuple, list)):
            patch_size = patch_size[0]
        return int(patch_size)

    def _num_prefix_tokens(self):
        """Number of non-patch tokens at the start of the token sequence.

        Prefers the backbone's own ``num_prefix_tokens`` attribute; falls back
        to the name heuristic (4 registers + CLS when ``"reg"`` is in the
        model name, otherwise CLS only) when the attribute is missing.
        """
        count = getattr(self.backbone, "num_prefix_tokens", None)
        if count is not None:
            return int(count)
        return 5 if "reg" in self.model_type else 1

    def _freeze_parameters(self):
        """
        Freeze all parameters except the last N transformer blocks and norm layer.

        Also logs the number of trainable parameters of the backbone and,
        when an aggregator exists, of the aggregator and their sum.
        """
        # First freeze everything.
        for param in self.backbone.parameters():
            param.requires_grad = False

        # Unfreeze the last N blocks. The guard matters: ``blocks[-0:]`` would
        # select *every* block.
        if self.num_of_layers_to_unfreeze > 0:
            for block in self.backbone.blocks[-self.num_of_layers_to_unfreeze:]:
                for param in block.parameters():
                    param.requires_grad = True

        # Unfreeze norm layer.
        # NOTE(review): kept unconditional, as the docstring describes; confirm
        # this is intended when num_of_layers_to_unfreeze == 0.
        for param in self.backbone.norm.parameters():
            param.requires_grad = True

        def count_trainable_params(model):
            return sum(p.numel() for p in model.parameters() if p.requires_grad)

        backbone_params = count_trainable_params(self.backbone)
        logging.info(
            f"Number of trainable parameters backbone: {backbone_params:,}"
        )

        # Aggregator statistics are only reported when an aggregator exists.
        if self.aggregator is not None:
            aggregator_params = count_trainable_params(self.aggregator)
            logging.info(
                f"Number of trainable parameters aggregator: {aggregator_params:,}"
            )
            logging.info(
                f"Total trainable parameters: {backbone_params + aggregator_params:,}"
            )

    def forward(self, x):
        """Compute a descriptor for a batch of images.

        Args:
            x: 4-D tensor unpacked as ``(B, _, H, W)``.

        Returns:
            The aggregator output for SALAD / MixVPR, otherwise the backbone's
            pooled pre-logit features, L2-normalised along the last dimension.
        """
        batch_size, _, height, width = x.shape
        tokens = self.backbone.forward_features(x)

        if self.aggregator_type in ("SALAD", "MixVPR"):
            # Drop CLS / register tokens so only patch tokens remain.
            patch_tokens = tokens[:, self._num_prefix_tokens():]

            # (B, N, C) -> (B, H', W', C) -> (B, C, H', W') for the aggregators.
            patch_tokens_map = patch_tokens.reshape(
                batch_size,
                height // self.patch_size,
                width // self.patch_size,
                self.num_channels,
            ).permute(0, 3, 1, 2)

            if self.aggregator_type == "SALAD":
                cls_token = tokens[:, 0]
                return self.aggregator((patch_tokens_map, cls_token))
            return self.aggregator(patch_tokens_map)

        # Default behavior: pooled features from the backbone head.
        features = self.backbone.forward_head(tokens, pre_logits=True)
        # L2 normalization.
        return F.normalize(features, p=2, dim=-1)