YouLiXiya's picture
Upload 22 files
7dbe662
raw
history blame
No virus
5.1 kB
import PIL
from PIL.Image import Image
from typing import Union
from sklearn.decomposition import PCA
import torch
from torch import nn
from torchvision import transforms as tfs
# Per-channel mean / std handed to ``tfs.Normalize`` by the ``transform``
# helpers of the wrappers below when a PIL image is preprocessed.
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]

# torch.hub repository and the entry-point names accepted by ``DINO``.
# NOTE(review): ``DINO.extract_features`` calls ``get_intermediate_layers`` and
# reads ``patch_embed.patch_size`` on the loaded model -- confirm every entry
# listed here (in particular the non-ViT ones) actually provides both.
DINO_MODEL_HUB = 'facebookresearch/dino:main'
DINO_MODEL_TYPE = [
    'dino_vits16',
    'dino_vits8',
    'dino_vitb16',
    'dino_vitb8',
    'dino_xcit_small_12_p16',
    'dino_xcit_small_12_p8',
    'dino_xcit_medium_24_p16',
    'dino_xcit_medium_24_p8',
    'dino_resnet50',
]

# torch.hub repository and the entry-point names accepted by ``DINOV2``.
DINOV2_MODEL_HUB = 'facebookresearch/dinov2:main'
DINOV2_MODEL_TYPE = [
    'dinov2_vits14',
    'dinov2_vitb14',
    'dinov2_vitl14',
    'dinov2_vitg14',
]
class DINO(nn.Module):
    """Frozen DINO backbone that turns an image into a dense patch-feature map.

    The network is fetched through ``torch.hub`` from ``DINO_MODEL_HUB``, moved
    to ``device``, put in eval mode, and every parameter has ``requires_grad``
    switched off. Patch features can optionally be reduced with a
    scikit-learn ``PCA``.
    """

    def __init__(self, model_type, device='cuda', img_size=224, pca_dim=None):
        """Load the hub model and store the preprocessing / PCA settings.

        Args:
            model_type: One of the names listed in ``DINO_MODEL_TYPE``.
            device: Device the model and the returned features live on.
            img_size: Side length PIL images are resized to in ``transform``.
            pca_dim: Number of PCA components; any falsy value disables PCA.

        Raises:
            AssertionError: If ``model_type`` is not in ``DINO_MODEL_TYPE``.
        """
        super().__init__()
        assert model_type in DINO_MODEL_TYPE, 'Given DINO model type must in DINO_MODEL_TYPE!'
        self.model = torch.hub.load(DINO_MODEL_HUB, model_type).to(device)
        self.device = device
        # The backbone is used as a fixed feature extractor only.
        for param in self.model.parameters():
            param.requires_grad = False
        self.model.eval()
        self.img_size = img_size
        self.pca_dim = pca_dim
        self.pca = self.set_pca(pca_dim) if pca_dim else None

    def set_pca(self, dim=64):
        """Return a fresh, unfitted ``PCA`` with ``dim`` components."""
        return PCA(n_components=dim)

    @torch.no_grad()
    def extract_features(
        self, img: Union[Image, torch.Tensor], transform=True, size=None
    ) -> torch.Tensor:
        """Compute the patch-feature map of ``img``.

        Args:
            img: A PIL image or a tensor. A PIL image is preprocessed with
                ``transform`` and given a batch dimension only when
                ``transform`` is truthy; tensors are forwarded unchanged
                (assumed to already be N x 3 x H x W -- TODO confirm callers).
            transform: Whether PIL inputs are run through ``self.transform``.
            size: Optional spatial size the feature map is bilinearly
                resampled to (forwarded to ``interpolate``).

        Returns:
            Tensor laid out as (batch, height, width, channels). The channel
            count is ``self.pca_dim`` when PCA is enabled. The PCA is refit
            on every call, so projections from different calls do not share
            a basis.
        """
        if transform and isinstance(img, Image):
            img = self.transform(img, self.img_size).unsqueeze(0)  # Nx3xHxW

        # Autograd is already disabled by the decorator; no nested context needed.
        out = self.model.get_intermediate_layers(img.to(self.device), n=1)[0]
        out = out[:, 1:, :]  # we discard the [CLS] token

        # Number of patches along each spatial axis.
        patch_size = self.model.patch_embed.patch_size
        h, w = img.shape[2] // patch_size, img.shape[3] // patch_size
        dim = out.shape[-1]
        out = out.reshape(-1, h, w, dim)
        dtype = out.dtype

        if size is not None:
            # interpolate expects channels-first; permute back to channels-last.
            out = torch.nn.functional.interpolate(
                out.permute(0, 3, 1, 2), size=size, mode='bilinear'
            ).permute(0, 2, 3, 1)

        if self.pca is not None:
            B, H, W, C = out.shape
            # reshape (not view): after the permute above the tensor is
            # non-contiguous and view() would raise a RuntimeError.
            flat = out.reshape(-1, C).cpu().numpy()
            reduced = self.pca.fit_transform(flat)
            out = torch.tensor(
                reduced.reshape(B, H, W, self.pca_dim), dtype=dtype
            ).to(self.device)
        return out

    def forward(self, img: Union[Image, torch.Tensor], transform=True, size=None) -> torch.Tensor:
        """Alias of ``extract_features`` so the module can be called directly."""
        return self.extract_features(img, transform, size)

    @staticmethod
    def transform(img, image_size):
        """Resize ``img`` to a square, convert it to a tensor and normalize it.

        Args:
            img: Image accepted by the torchvision transforms (a PIL image in
                ``extract_features``).
            image_size: Target height and width after resizing.

        Returns:
            The transformed image with ``MEAN`` / ``STD`` normalization
            applied; no batch dimension is added here.
        """
        transforms = tfs.Compose(
            [tfs.Resize((image_size, image_size)), tfs.ToTensor(), tfs.Normalize(MEAN, STD)]
        )
        img = transforms(img)
        return img
class DINOV2(nn.Module):
    """Frozen DINOv2 backbone that turns an image into a dense patch-feature map.

    The network is fetched through ``torch.hub`` from ``DINOV2_MODEL_HUB``,
    moved to ``device``, put in eval mode, and every parameter has
    ``requires_grad`` switched off. Patch features can optionally be reduced
    with a scikit-learn ``PCA``.
    """

    def __init__(self, model_type, device='cuda', img_size=224, pca_dim=None):
        """Load the hub model and store the preprocessing / PCA settings.

        Args:
            model_type: One of the names listed in ``DINOV2_MODEL_TYPE``.
            device: Device the model and the returned features live on.
            img_size: Side length PIL images are resized to in ``transform``.
            pca_dim: Number of PCA components; any falsy value disables PCA.

        Raises:
            AssertionError: If ``model_type`` is not in ``DINOV2_MODEL_TYPE``.
        """
        super().__init__()
        # The message previously named the DINO (v1) list, which is not the
        # list being checked here.
        assert model_type in DINOV2_MODEL_TYPE, 'Given DINOV2 model type must be in DINOV2_MODEL_TYPE!'
        self.model = torch.hub.load(DINOV2_MODEL_HUB, model_type).to(device)
        self.device = device
        # The backbone is used as a fixed feature extractor only.
        for param in self.model.parameters():
            param.requires_grad = False
        self.model.eval()
        self.img_size = img_size
        self.pca_dim = pca_dim
        self.pca = self.set_pca(pca_dim) if pca_dim else None

    def set_pca(self, dim=64):
        """Return a fresh, unfitted ``PCA`` with ``dim`` components."""
        return PCA(n_components=dim)

    @torch.no_grad()
    def extract_features(
        self, img: Union[Image, torch.Tensor], transform=True, size=None
    ) -> torch.Tensor:
        """Compute the patch-feature map of ``img``.

        Args:
            img: A PIL image or a tensor. A PIL image is preprocessed with
                ``transform`` and given a batch dimension only when
                ``transform`` is truthy; tensors are forwarded unchanged
                (assumed to already be N x 3 x H x W -- TODO confirm callers).
            transform: Whether PIL inputs are run through ``self.transform``.
            size: Optional spatial size the feature map is bilinearly
                resampled to (forwarded to ``interpolate``).

        Returns:
            Tensor laid out as (batch, height, width, channels). The channel
            count is ``self.pca_dim`` when PCA is enabled. The PCA is refit
            on every call, so projections from different calls do not share
            a basis.
        """
        if transform and isinstance(img, Image):
            img = self.transform(img, self.img_size).unsqueeze(0)  # Nx3xHxW

        # Autograd is already disabled by the decorator; no nested context needed.
        out = self.model.forward_features(img.to(self.device))['x_norm_patchtokens']

        # Number of patches along each spatial axis.
        patch_size = self.model.patch_size
        h, w = img.shape[2] // patch_size, img.shape[3] // patch_size
        dim = out.shape[-1]
        out = out.reshape(-1, h, w, dim)
        dtype = out.dtype

        if size is not None:
            # interpolate expects channels-first; permute back to channels-last.
            out = torch.nn.functional.interpolate(
                out.permute(0, 3, 1, 2), size=size, mode='bilinear'
            ).permute(0, 2, 3, 1)

        if self.pca is not None:
            B, H, W, C = out.shape
            # reshape (not view): after the permute above the tensor is
            # non-contiguous and view() would raise a RuntimeError.
            flat = out.reshape(-1, C).cpu().numpy()
            reduced = self.pca.fit_transform(flat)
            out = torch.tensor(
                reduced.reshape(B, H, W, self.pca_dim), dtype=dtype
            ).to(self.device)
        return out

    def forward(self, img: Union[Image, torch.Tensor], transform=True, size=None) -> torch.Tensor:
        """Alias of ``extract_features`` so the module can be called directly."""
        return self.extract_features(img, transform, size)

    @staticmethod
    def transform(img, image_size):
        """Resize ``img`` to a square, convert it to a tensor and normalize it.

        Args:
            img: Image accepted by the torchvision transforms (a PIL image in
                ``extract_features``).
            image_size: Target height and width after resizing.

        Returns:
            The transformed image with ``MEAN`` / ``STD`` normalization
            applied; no batch dimension is added here.
        """
        transforms = tfs.Compose(
            [tfs.Resize((image_size, image_size)), tfs.ToTensor(), tfs.Normalize(MEAN, STD)]
        )
        img = transforms(img)
        return img