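"""Frozen DINO / DINOv2 feature extractors.

Both wrappers load a pretrained backbone from torch.hub, freeze its weights,
and return dense per-patch features reshaped to a (B, H, W, C) grid, with
optional bilinear resizing to a target spatial size and optional PCA
reduction of the channel dimension.
"""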
from PIL.Image import Image
from typing import Union

from sklearn.decomposition import PCA

import torch
from torch import nn
from torchvision import transforms as tfs


MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]
DINO_MODEL_HUB = 'facebookresearch/dino:main'
DINO_MODEL_TYPE = ['dino_vits16',
                   'dino_vits8',
                   'dino_vitb16',
                   'dino_vitb8',
                   'dino_xcit_small_12_p16',
                   'dino_xcit_small_12_p8',
                   'dino_xcit_medium_24_p16',
                   'dino_xcit_medium_24_p8',
                   'dino_resnet50']

DINOV2_MODEL_HUB = 'facebookresearch/dinov2:main'
DINOV2_MODEL_TYPE = ['dinov2_vits14',
                     'dinov2_vitb14',
                     'dinov2_vitl14',
                     'dinov2_vitg14']

class DINO(nn.Module):
    def __init__(self, model_type, device='cuda', img_size=224, pca_dim=None):
        super(DINO, self).__init__()
        assert model_type in DINO_MODEL_TYPE, 'Given DINO model type must be in DINO_MODEL_TYPE!'
        self.model = torch.hub.load(DINO_MODEL_HUB, model_type).to(device)
        self.device = device
        # Freeze the backbone: the extractor is inference-only.
        for param in self.model.parameters():
            param.requires_grad = False
        self.model.eval()
        self.img_size = img_size
        self.pca_dim = pca_dim
        self.pca = self.set_pca(pca_dim) if pca_dim else None

    def set_pca(self, dim=64):
        return PCA(n_components=dim)

    @torch.no_grad()
    def extract_features(
            self, img: Union[Image, torch.Tensor], transform=True, size=None
    ):
        # PIL inputs are resized/normalized and given a batch dim: Nx3xHxW.
        if transform and isinstance(img, Image):
            img = self.transform(img, self.img_size).unsqueeze(0)
        # Patch tokens from the last block; discard the [CLS] token.
        out = self.model.get_intermediate_layers(img.to(self.device), n=1)[0]
        out = out[:, 1:, :]
        h, w = int(img.shape[2] / self.model.patch_embed.patch_size), int(
            img.shape[3] / self.model.patch_embed.patch_size
        )
        dim = out.shape[-1]
        out = out.reshape(-1, h, w, dim)  # BxHxWxC patch-feature grid
        dtype = out.dtype
        if size is not None:
            # Optionally resample the feature grid to a target spatial size.
            out = torch.nn.functional.interpolate(out.permute(0, 3, 1, 2), size=size, mode='bilinear').permute(0, 2, 3, 1)
        if self.pca:
            # Reduce channel dimensionality with PCA fitted on this batch.
            # reshape (not view) handles the non-contiguous tensor after permute.
            B, H, W, C = out.shape
            out = out.reshape(-1, C).cpu().numpy()
            out = self.pca.fit_transform(out)
            out = torch.tensor(out.reshape(B, H, W, self.pca_dim), dtype=dtype).to(self.device)
        return out

    def forward(self, img: Union[Image, torch.Tensor], transform=True, size=None):
        return self.extract_features(img, transform, size)

    @staticmethod
    def transform(img, image_size):
        # Standard ImageNet preprocessing: resize, to-tensor, normalize.
        transforms = tfs.Compose(
            [tfs.Resize((image_size, image_size)), tfs.ToTensor(), tfs.Normalize(MEAN, STD)]
        )
        img = transforms(img)
        return img

class DINOV2(nn.Module):
    def __init__(self, model_type, device='cuda', img_size=224, pca_dim=None):
        super(DINOV2, self).__init__()
        assert model_type in DINOV2_MODEL_TYPE, 'Given DINOv2 model type must be in DINOV2_MODEL_TYPE!'
        self.model = torch.hub.load(DINOV2_MODEL_HUB, model_type).to(device)
        self.device = device
        # Freeze the backbone: the extractor is inference-only.
        for param in self.model.parameters():
            param.requires_grad = False
        self.model.eval()
        self.img_size = img_size
        self.pca_dim = pca_dim
        self.pca = self.set_pca(pca_dim) if pca_dim else None

    def set_pca(self, dim=64):
        return PCA(n_components=dim)

    @torch.no_grad()
    def extract_features(
            self, img: Union[Image, torch.Tensor], transform=True, size=None
    ):
        # PIL inputs are resized/normalized and given a batch dim: Nx3xHxW.
        if transform and isinstance(img, Image):
            img = self.transform(img, self.img_size).unsqueeze(0)
        # DINOv2 exposes normalized patch tokens directly (no [CLS] to drop).
        out = self.model.forward_features(img.to(self.device))['x_norm_patchtokens']
        h, w = int(img.shape[2] / self.model.patch_size), int(
            img.shape[3] / self.model.patch_size
        )
        dim = out.shape[-1]
        out = out.reshape(-1, h, w, dim)  # BxHxWxC patch-feature grid
        dtype = out.dtype
        if size is not None:
            # Optionally resample the feature grid to a target spatial size.
            out = torch.nn.functional.interpolate(out.permute(0, 3, 1, 2), size=size, mode='bilinear').permute(0, 2, 3, 1)
        if self.pca:
            # Reduce channel dimensionality with PCA fitted on this batch.
            # reshape (not view) handles the non-contiguous tensor after permute.
            B, H, W, C = out.shape
            out = out.reshape(-1, C).cpu().numpy()
            out = self.pca.fit_transform(out)
            out = torch.tensor(out.reshape(B, H, W, self.pca_dim), dtype=dtype).to(self.device)
        return out

    def forward(self, img: Union[Image, torch.Tensor], transform=True, size=None):
        return self.extract_features(img, transform, size)

    @staticmethod
    def transform(img, image_size):
        # Standard ImageNet preprocessing: resize, to-tensor, normalize.
        transforms = tfs.Compose(
            [tfs.Resize((image_size, image_size)), tfs.ToTensor(), tfs.Normalize(MEAN, STD)]
        )
        img = transforms(img)
        return img
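

if __name__ == '__main__':
    # Minimal usage sketch (not part of the original module): assumes a CUDA
    # device, torch.hub access to the pretrained weights, Pillow installed,
    # and a hypothetical input file 'example.jpg'.
    from PIL import Image as PILImage

    extractor = DINOV2('dinov2_vits14', device='cuda', img_size=224, pca_dim=None)
    image = PILImage.open('example.jpg').convert('RGB')
    features = extractor(image)  # (1, 16, 16, C) patch-feature grid for 224 / 14
    print(features.shape)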