import logging

import timm
import torch
import torch.nn as nn
import torch.nn.functional as F

from models.salad import SALAD
from models.mixvpr import MixVPR
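
# DINOv2 backbone wrapper for place-recognition-style global descriptors:
# features from the timm ViT are either pooled directly or aggregated by an
# optional SALAD / MixVPR head on the patch tokens.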
class DINOv2FeatureExtractor(nn.Module):
    def __init__(
        self,
        image_size=518,
        model_type="vit_base_patch14_reg4_dinov2.lvd142m",
        num_of_layers_to_unfreeze=1,
        desc_dim=768,
        aggregator_type="No",
    ):
        super().__init__()

        # Headless ViT backbone: num_classes=0 removes the classifier and
        # img_size adapts the positional embeddings to the input resolution.
        self.backbone = timm.create_model(
            model_type, pretrained=True, num_classes=0, img_size=image_size
        )

        self.model_type = model_type
        self.num_channels = self.backbone.embed_dim
        self.desc_dim = desc_dim
        self.image_size = image_size
        self.num_of_layers_to_unfreeze = num_of_layers_to_unfreeze
        self.aggregator_type = aggregator_type
        self.aggregator = None
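
        # Optional aggregation head. When one is attached, desc_dim is
        # overridden below; for SALAD the descriptor concatenates the global
        # token with the flattened local clusters, i.e.
        # token_dim + num_clusters * cluster_dim.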
        if aggregator_type == "SALAD":
            if "vit_small" in model_type:
                self.aggregator = SALAD(
                    num_channels=self.num_channels,
                    num_clusters=24,
                    cluster_dim=64,
                    token_dim=512,
                    dropout=0.3,
                )
                self.desc_dim = 512 + (24 * 64)
            elif "vit_base" in model_type:
                self.aggregator = SALAD(
                    num_channels=self.num_channels,
                    num_clusters=32,
                    cluster_dim=64,
                    token_dim=1024,
                    dropout=0.3,
                )
                self.desc_dim = 1024 + (32 * 64)
            elif "vit_large" in model_type:
                self.aggregator = SALAD(
                    num_channels=self.num_channels,
                    num_clusters=48,
                    cluster_dim=64,
                    token_dim=1024,
                    dropout=0.3,
                )
                self.desc_dim = 1024 + (48 * 64)
            else:
                # Guard against a silent misconfiguration: without a matching
                # SALAD setup, self.aggregator would stay None and forward()
                # would fail with an opaque error.
                raise ValueError(
                    f"No SALAD configuration for model type: {model_type}"
                )
        elif aggregator_type == "MixVPR":
            # DINOv2 uses 14x14 patches, so the token grid has
            # image_size // 14 tokens per side.
            patch_dim = image_size // 14
            if "vit_small" in model_type:
                out_dim = 2048
            elif "vit_base" in model_type:
                out_dim = 3072
            elif "vit_large" in model_type:
                out_dim = 4096
            else:
                # Fall back to the largest configuration.
                out_dim = 4096

            self.aggregator = MixVPR(
                in_channels=self.num_channels,
                in_h=patch_dim,
                in_w=patch_dim,
                out_channels=out_dim,
            )
            self.desc_dim = out_dim

        self._freeze_parameters()

    def _freeze_parameters(self):
        """
        Freeze all backbone parameters except the last N transformer blocks
        and the final norm layer (kept frozen when N == 0).
        """
        for param in self.backbone.parameters():
            param.requires_grad = False

        if self.num_of_layers_to_unfreeze > 0:
            for block in self.backbone.blocks[-self.num_of_layers_to_unfreeze:]:
                for param in block.parameters():
                    param.requires_grad = True

            for param in self.backbone.norm.parameters():
                param.requires_grad = True

        def count_trainable_params(model):
            return sum(p.numel() for p in model.parameters() if p.requires_grad)

        logging.info(
            f"Number of trainable parameters in backbone: {count_trainable_params(self.backbone):,}"
        )

        if self.aggregator is not None:
            aggregator_params = count_trainable_params(self.aggregator)
            logging.info(
                f"Number of trainable parameters in aggregator: {aggregator_params:,}"
            )
            logging.info(
                f"Total trainable parameters: {count_trainable_params(self.backbone) + aggregator_params:,}"
            )

    def forward(self, x):
        B, _, H, W = x.shape
        # Token sequence from the ViT: [CLS, (register tokens,) patch tokens].
        x = self.backbone.forward_features(x)

        if self.aggregator_type in ["SALAD", "MixVPR"]:
            # Drop the CLS token and, for models with register tokens
            # (e.g. *_reg4), the 4 register tokens.
            start_index = 5 if "reg" in self.model_type else 1
            patch_tokens = x[:, start_index:]

            # Restore the 2D token grid as a (B, C, H / 14, W / 14) feature map.
            patch_tokens_map = patch_tokens.reshape(
                (B, H // 14, W // 14, self.num_channels)
            ).permute(0, 3, 1, 2)

            if self.aggregator_type == "SALAD":
                cls_token = x[:, 0]
                return self.aggregator((patch_tokens_map, cls_token))
            elif self.aggregator_type == "MixVPR":
                return self.aggregator(patch_tokens_map)

        # No aggregator: pooled backbone features, L2-normalized.
        features = self.backbone.forward_head(x, pre_logits=True)
        return F.normalize(features, p=2, dim=-1)
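

if __name__ == "__main__":
    # Minimal usage sketch: assumes the timm pretrained weights can be
    # downloaded and that models.salad / models.mixvpr are importable.
    # The input resolution must be a multiple of the ViT patch size (14).
    logging.basicConfig(level=logging.INFO)
    model = DINOv2FeatureExtractor(
        image_size=518,
        model_type="vit_base_patch14_reg4_dinov2.lvd142m",
        aggregator_type="No",
    ).eval()
    with torch.no_grad():
        descriptors = model(torch.randn(1, 3, 518, 518))
    print(descriptors.shape)  # (1, 768) for the plain ViT-B backbone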