Spaces:
Running
Running
File size: 886 Bytes
4d4dd90 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 |
import torch
import torch.nn.functional as F
from ..base_model import BaseModel
class DinoV2(BaseModel):
    """Feature extractor wrapping a DINOv2 Vision Transformer loaded via torch.hub.

    Configuration:
        weights: name of the ``facebookresearch/dinov2`` torch.hub entrypoint
            to load (default ``"dinov2_vits14"``).
        allow_resize: if True, resize the input image so that its height and
            width are multiples of the backbone's patch size before running
            the network; if False the image is passed through unchanged.

    Only inference is supported; ``loss`` is not implemented.
    """

    default_conf = {"weights": "dinov2_vits14", "allow_resize": False}
    required_data_keys = ["image"]

    def _init(self, conf):
        """Load the DINOv2 backbone named by ``conf.weights`` from torch.hub."""
        # NOTE: torch.hub.load fetches the repository/weights if they are not
        # already cached, so the first call may need network access.
        self.net = torch.hub.load("facebookresearch/dinov2", conf.weights)
        self.set_initialized()

    def _forward(self, data):
        """Run the backbone on ``data["image"]``.

        Returns a dict with:
            features: the spatial patch-token map from the last layer
                (presumably (B, C, h, w) — confirm against the dinov2 version).
            global_descriptor: the class token of the last layer
                (presumably (B, C)).
            descriptors: ``features`` with the two spatial dims flattened and
                the last two axes swapped; for a (B, C, h, w) map this is
                (B, h*w, C), i.e. one row per patch.
        """
        img = data["image"]
        if self.conf.allow_resize:
            # Use the network's own patch size when exposed; every public
            # DINOv2 hub checkpoint ("*14") uses 14-pixel patches.
            patch = getattr(self.net, "patch_size", 14)
            # Round each side down to a multiple of the patch size, but keep
            # at least one patch so tiny inputs do not collapse to size 0.
            size = [max(patch, side // patch * patch) for side in img.shape[-2:]]
            if list(img.shape[-2:]) != size:
                # F.upsample is deprecated; F.interpolate with mode="nearest"
                # reproduces its default behaviour without the warning.
                img = F.interpolate(img, size=size, mode="nearest")
        # n=1 requests only the last block's output; [0] selects that single
        # (patch_tokens, class_token) pair.
        desc, cls_token = self.net.get_intermediate_layers(
            img, n=1, return_class_token=True, reshape=True
        )[0]
        return {
            "features": desc,
            "global_descriptor": cls_token,
            "descriptors": desc.flatten(-2).transpose(-2, -1),
        }

    def loss(self, pred, data):
        """Not supported: this wrapper defines no training objective."""
        raise NotImplementedError("DinoV2 does not define a training loss.")
|