import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
import logging

# Assuming these are in your project structure
from models.salad import SALAD
from models.mixvpr import MixVPR


class DINOv2FeatureExtractor(nn.Module):
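    """
    DINOv2 backbone (timm) with an optional SALAD or MixVPR aggregation head.

    With aggregator_type == "No" (the default), forward() returns L2-normalized
    CLS features; with "SALAD" or "MixVPR", the patch tokens are reshaped into a
    feature map and passed to the corresponding aggregator, and desc_dim is
    updated to the aggregator's output dimension.
    """
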
    def __init__(
        self,
        image_size=518,  # Default for DINOv2 models
        model_type="vit_base_patch14_reg4_dinov2.lvd142m",
        num_of_layers_to_unfreeze=1,
        desc_dim=768,  # vit-base has 768-dim embeddings
        aggregator_type="No",  # "No" (CLS features only), "SALAD", or "MixVPR"
    ):
        super().__init__()

        # Initialize backbone with registers
        self.backbone = timm.create_model(
            model_type, pretrained=True, num_classes=0, img_size=image_size
        )

        # Store configuration parameters
        self.model_type = model_type
        self.num_channels = self.backbone.embed_dim
        self.desc_dim = desc_dim
        self.image_size = image_size
        self.num_of_layers_to_unfreeze = num_of_layers_to_unfreeze
        self.aggregator_type = aggregator_type
        self.aggregator = None

        if aggregator_type == "SALAD":
            if "vit_small" in model_type:
                self.aggregator = SALAD(
                    num_channels=self.num_channels,
                    num_clusters=24,
                    cluster_dim=64,
                    token_dim=512,
                    dropout=0.3,
                )
                # Output: 512 + (24 * 64) = 2,048 dims
                self.desc_dim = 512 + (24 * 64)
            elif "vit_base" in model_type:
                self.aggregator = SALAD(
                    num_channels=self.num_channels,
                    num_clusters=32,
                    cluster_dim=64,
                    token_dim=1024,
                    dropout=0.3,
                )
                # Output: 1024 + (32 * 64) = 3,072 dims
                self.desc_dim = 1024 + (32 * 64)
            elif "vit_large" in model_type:
                self.aggregator = SALAD(
                    num_channels=self.num_channels,
                    num_clusters=48,
                    cluster_dim=64,
                    token_dim=1024,
                    dropout=0.3,
                )
                # Output: 1024 + (48 * 64) = 4,096 dims
                self.desc_dim = 1024 + (48 * 64)
        elif aggregator_type == "MixVPR":
            # DINOv2 ViTs use 14-pixel patches, so the token grid is image_size // 14 per side
            patch_dim = image_size // 14
            if "vit_small" in model_type:
                out_dim = 2048
            elif "vit_base" in model_type:
                out_dim = 3072
            elif "vit_large" in model_type:
                out_dim = 4096
            else:
                # Fallback for unrecognized ViT variants
                out_dim = 4096

            self.aggregator = MixVPR(
                in_channels=self.num_channels,
                in_h=patch_dim,
                in_w=patch_dim,
                out_channels=out_dim,
            )
            self.desc_dim = out_dim

        # Freeze the backbone (except the last N blocks) regardless of the aggregator type.
        self._freeze_parameters()

    def _freeze_parameters(self):
        """
        Freeze all parameters except the last N transformer blocks and norm layer.
        """
        # First freeze everything
        for param in self.backbone.parameters():
            param.requires_grad = False

        # Unfreeze the last N blocks
        if self.num_of_layers_to_unfreeze > 0:
            for block in self.backbone.blocks[
                -self.num_of_layers_to_unfreeze :
            ]:
                for param in block.parameters():
                    param.requires_grad = True

        # Unfreeze norm layer
        for param in self.backbone.norm.parameters():
            param.requires_grad = True

        # Count trainable parameters for backbone
        def count_trainable_params(model):
            return sum(p.numel() for p in model.parameters() if p.requires_grad)

        logging.info(
            f"Number of trainable parameters backbone: {count_trainable_params(self.backbone):,}"
        )

        # Count aggregator parameters if it exists
        if self.aggregator is not None:
            aggregator_params = count_trainable_params(self.aggregator)
            logging.info(
                f"Number of trainable parameters aggregator: {aggregator_params:,}"
            )
            logging.info(
                f"Total trainable parameters: {count_trainable_params(self.backbone) + aggregator_params:,}"
            )

    def forward(self, x):
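        """
        Args:
            x: image batch of shape (B, 3, H, W); H and W are assumed to be
               multiples of the 14-pixel DINOv2 patch size.

        Returns:
            Global descriptors: the aggregator output when SALAD/MixVPR is
            configured, otherwise L2-normalized CLS features.
        """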
        B, _, H, W = x.shape
        x = self.backbone.forward_features(x)

        # Consistent handling for register vs. non-register models
        if self.aggregator_type in ["SALAD", "MixVPR"]:
            # DINOv2 with registers has 4 register tokens + 1 CLS token
            # Standard ViT has 1 CLS token
            start_index = 5 if "reg" in self.model_type else 1
            patch_tokens = x[:, start_index:]

            # Reshape tokens into a (B, C, H/14, W/14) feature map (14 = DINOv2 patch size)
            patch_tokens_map = patch_tokens.reshape(
                (B, H // 14, W // 14, self.num_channels)
            ).permute(0, 3, 1, 2)

            if self.aggregator_type == "SALAD":
                cls_token = x[:, 0]
                return self.aggregator((patch_tokens_map, cls_token))
            elif self.aggregator_type == "MixVPR":
                return self.aggregator(patch_tokens_map)

        # Default behavior: pooled CLS features from the timm head
        features = self.backbone.forward_head(x, pre_logits=True)

        # L2 normalization
        return F.normalize(features, p=2, dim=-1)
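

# --- Minimal usage sketch (illustrative only, not part of the module above) ---
# Assumes the project-local SALAD/MixVPR modules are importable and that
# downloading the pretrained timm weights is acceptable; the expected descriptor
# size is taken from the comments in __init__.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    # vit_base + SALAD: descriptor dimension should be 1024 + 32 * 64 = 3072
    model = DINOv2FeatureExtractor(
        image_size=518,
        model_type="vit_base_patch14_reg4_dinov2.lvd142m",
        num_of_layers_to_unfreeze=1,
        aggregator_type="SALAD",
    ).eval()

    # Dummy RGB batch at the configured resolution (518 = 37 * 14)
    dummy = torch.randn(2, 3, 518, 518)
    with torch.no_grad():
        descriptors = model(dummy)

    print(descriptors.shape, model.desc_dim)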