File size: 1,443 Bytes
79c5088
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import torch
from torchvision import transforms

class DINOv2Processor:
    """Embed images with a DINOv2 backbone and compare them by cosine similarity.

    The backbone is loaded from the ``facebookresearch/dinov2`` torch.hub
    repository at construction time, switched to eval mode and moved to
    ``device``.

    Args:
        model_name: torch.hub entry-point name, e.g. ``"dinov2_vitb14"``.
        device: Device the model and the input tensors are placed on.
        image_size: Size handed to ``Resize`` (shorter side) and ``CenterCrop``.
            NOTE(review): DINOv2's patch embedding presumably requires this to
            be a multiple of the model's patch size (the default 518 = 37 * 14
            for the ``*14`` models) — confirm before changing it.
    """

    def __init__(self, model_name="dinov2_vitb14", device="cpu", image_size=518):
        self.model_name = model_name
        self.device = device
        self.image_size = image_size
        self.model = self._load_model()

    def _load_model(self):
        """Load ``self.model_name`` from torch.hub, set eval mode, move to ``self.device``."""
        # torch.hub.load needs network access or an already-populated hub cache.
        model = torch.hub.load('facebookresearch/dinov2', self.model_name)
        model.eval()
        model.to(self.device)
        return model

    @staticmethod
    def _ensure_rgb(image):
        """Return *image* converted to RGB when it is a non-RGB PIL-like image.

        Objects that have no callable ``convert`` attribute (e.g. NumPy arrays)
        are returned unchanged, as are images whose ``mode`` is already "RGB".
        """
        convert = getattr(image, "convert", None)
        if callable(convert) and getattr(image, "mode", "RGB") != "RGB":
            return convert("RGB")
        return image

    def _preprocess_image(self, image):
        """Turn one image into a normalized ``(C, H, W)`` float tensor.

        Pipeline: RGB conversion, bicubic resize of the shorter side to
        ``self.image_size``, center crop to ``self.image_size``, ``ToTensor``,
        then normalization with the ImageNet mean/std constants below.
        """
        # Grayscale/palette/RGBA/CMYK images would otherwise yield a channel
        # count that does not match the 3-channel Normalize below.
        image = self._ensure_rgb(image)
        preprocess = transforms.Compose([
            transforms.Resize(self.image_size, interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.CenterCrop(self.image_size),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            ),
        ])
        return preprocess(image)

    def compute_similarity(self, pil_image1, pil_image2):
        """Return the cosine similarity of the two images' DINOv2 embeddings.

        Args:
            pil_image1: First image (PIL image or anything ``ToTensor`` accepts).
            pil_image2: Second image.

        Returns:
            A Python float in [-1, 1]; 0.0 if either embedding is all zeros.
        """
        # Both images are cropped to the same size, so they can share one
        # forward pass instead of two.
        batch = torch.stack(
            [self._preprocess_image(pil_image1), self._preprocess_image(pil_image2)]
        ).to(self.device)
        with torch.no_grad():
            # assumes the model returns one embedding row per batch item,
            # i.e. shape (2, D) here — matches the original dim=1 usage.
            feats = self.model(batch)
        # normalize() clamps the norm at a small eps, so a zero embedding gives
        # similarity 0.0 rather than NaN from 0/0.
        feats = torch.nn.functional.normalize(feats, dim=1)
        return (feats[0] * feats[1]).sum().item()