File size: 886 Bytes
4d4dd90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import torch
import torch.nn.functional as F

from ..base_model import BaseModel


class DinoV2(BaseModel):
    """Feature extractor wrapping a DINOv2 network loaded through ``torch.hub``.

    Configuration keys (see ``default_conf``):
        weights: name of the ``facebookresearch/dinov2`` hub entry point to load.
        allow_resize: when true, the input image is resized so that each of its
            last two dimensions is rounded down to a multiple of 14 before it
            is passed to the network.
    """

    default_conf = {"weights": "dinov2_vits14", "allow_resize": False}
    required_data_keys = ["image"]

    def _init(self, conf):
        """Load the network named by ``conf.weights`` and mark the model ready."""
        # NOTE: torch.hub.load may fetch the repository and weights on first use.
        self.net = torch.hub.load("facebookresearch/dinov2", conf.weights)
        self.set_initialized()

    def _forward(self, data):
        """Run the network on ``data["image"]`` and return its last-layer outputs.

        Args:
            data: mapping holding the input tensor under the key ``"image"``.
                Its last two dimensions are treated as the spatial size.

        Returns:
            A dict with:
                ``"features"``: the feature map returned by the network
                    (presumably ``(B, C, H / 14, W / 14)`` -- confirm against
                    the DINOv2 implementation).
                ``"global_descriptor"``: the class token returned by the network.
                ``"descriptors"``: ``"features"`` with its last two dimensions
                    flattened into one and then swapped with the preceding one.
        """
        img = data["image"]
        if self.conf.allow_resize:
            # Round each spatial dimension down to a multiple of 14 (presumably
            # the ViT patch size of the default weights -- confirm if other
            # weights are used).
            target_size = [int(x // 14 * 14) for x in img.shape[-2:]]
            # F.interpolate replaces the deprecated F.upsample; the default
            # "nearest" mode is the same, so the result is unchanged.
            img = F.interpolate(img, size=target_size)
        # n=1 asks for a single layer's output; [0] unpacks that one entry.
        desc, cls_token = self.net.get_intermediate_layers(
            img, n=1, return_class_token=True, reshape=True
        )[0]

        return {
            "features": desc,
            "global_descriptor": cls_token,
            "descriptors": desc.flatten(-2).transpose(-2, -1),
        }

    def loss(self, pred, data):
        """Not implemented for this model; always raises ``NotImplementedError``."""
        raise NotImplementedError