from typing import Union

from PIL.Image import Image
from sklearn.decomposition import PCA
import torch
from torch import nn
from torchvision import transforms as tfs

# ImageNet normalization statistics used by both DINO and DINOv2 preprocessing.
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]

# torch.hub repositories and the entrypoint names they expose.
DINO_MODEL_HUB = 'facebookresearch/dino:main'
DINO_MODEL_TYPE = ['dino_vits16',
                   'dino_vits8',
                   'dino_vitb16',
                   'dino_vitb8',
                   'dino_xcit_small_12_p16',
                   'dino_xcit_small_12_p8',
                   'dino_xcit_medium_24_p16',
                   'dino_xcit_medium_24_p8',
                   'dino_resnet50']

DINOV2_MODEL_HUB = 'facebookresearch/dinov2:main'
DINOV2_MODEL_TYPE = ['dinov2_vits14',
                     'dinov2_vitb14',
                     'dinov2_vitl14',
                     'dinov2_vitg14']
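
# A minimal convenience sketch, not part of the original interface: resolve a
# model type to the matching hub repo before loading. The helper name
# `load_backbone` is illustrative; only `torch.hub.load(repo, entrypoint)` is
# assumed from the real API.
def load_backbone(model_type: str) -> nn.Module:
    """Load a DINO or DINOv2 backbone from torch.hub by entrypoint name."""
    if model_type in DINO_MODEL_TYPE:
        return torch.hub.load(DINO_MODEL_HUB, model_type)
    if model_type in DINOV2_MODEL_TYPE:
        return torch.hub.load(DINOV2_MODEL_HUB, model_type)
    raise ValueError(f'Unknown model type: {model_type}')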
class DINO(nn.Module):
    def __init__(self, model_type, device='cuda', img_size=224, pca_dim=None):
        super(DINO, self).__init__()
        assert model_type in DINO_MODEL_TYPE, 'Given DINO model type must be in DINO_MODEL_TYPE!'
        self.model = torch.hub.load(DINO_MODEL_HUB, model_type).to(device)
        self.device = device
        # Freeze the backbone: it is used purely as a fixed feature extractor.
        for param in self.model.parameters():
            param.requires_grad = False
        self.model.eval()
        self.img_size = img_size
        self.pca_dim = pca_dim
        self.pca = self.set_pca(pca_dim) if pca_dim else None

    def set_pca(self, dim=64):
        return PCA(n_components=dim)
    def extract_features(
        self, img: Union[Image, torch.Tensor], transform=True, size=None
    ):
        if transform and isinstance(img, Image):
            img = self.transform(img, self.img_size).unsqueeze(0)  # Nx3xHxW
        with torch.no_grad():
            # ViT entrypoints expose get_intermediate_layers; other backbones
            # (e.g. dino_resnet50) may not.
            out = self.model.get_intermediate_layers(img.to(self.device), n=1)[0]
            out = out[:, 1:, :]  # discard the [CLS] token, keep patch tokens
            h, w = int(img.shape[2] / self.model.patch_embed.patch_size), int(
                img.shape[3] / self.model.patch_embed.patch_size
            )
            dim = out.shape[-1]
            out = out.reshape(-1, h, w, dim)  # NxHxWxC patch-feature grid
        dtype = out.dtype
        if size is not None:
            # Bilinearly resample the feature grid to the requested spatial size.
            out = torch.nn.functional.interpolate(
                out.permute(0, 3, 1, 2), size=size, mode='bilinear'
            ).permute(0, 2, 3, 1)
        if self.pca:
            B, H, W, C = out.shape
            # reshape (not view): `out` may be non-contiguous after permute.
            out = out.reshape(-1, C).cpu().numpy()
            # Note: PCA is refit on every call, so projections are not
            # comparable across images unless fit on a common sample.
            out = self.pca.fit_transform(out)
            out = torch.tensor(out.reshape(B, H, W, self.pca_dim), dtype=dtype).to(self.device)
        return out
    def forward(self, img: Union[Image, torch.Tensor], transform=True, size=None):
        return self.extract_features(img, transform, size)

    def transform(self, img, image_size):
        transforms = tfs.Compose(
            [tfs.Resize((image_size, image_size)), tfs.ToTensor(), tfs.Normalize(MEAN, STD)]
        )
        img = transforms(img)
        return img
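
# A usage sketch (assumptions: PIL is installed and a hypothetical
# 'example.jpg' exists on disk; weights download on first use via torch.hub).
# The function below is illustrative, not part of the original interface.
def _dino_example(path='example.jpg'):
    """Extract a patch-feature grid with dino_vits8 at 224x224 input."""
    import PIL.Image
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = DINO('dino_vits8', device=device)
    img = PIL.Image.open(path).convert('RGB')
    feats = model(img)  # shape: 1 x 28 x 28 x 384 (224/8 = 28 patches per side)
    return feats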
class DINOV2(nn.Module):
    def __init__(self, model_type, device='cuda', img_size=224, pca_dim=None):
        super(DINOV2, self).__init__()
        assert model_type in DINOV2_MODEL_TYPE, 'Given DINOv2 model type must be in DINOV2_MODEL_TYPE!'
        self.model = torch.hub.load(DINOV2_MODEL_HUB, model_type).to(device)
        self.device = device
        # Freeze the backbone: it is used purely as a fixed feature extractor.
        for param in self.model.parameters():
            param.requires_grad = False
        self.model.eval()
        self.img_size = img_size
        self.pca_dim = pca_dim
        self.pca = self.set_pca(pca_dim) if pca_dim else None

    def set_pca(self, dim=64):
        return PCA(n_components=dim)
    def extract_features(
        self, img: Union[Image, torch.Tensor], transform=True, size=None
    ):
        if transform and isinstance(img, Image):
            img = self.transform(img, self.img_size).unsqueeze(0)  # Nx3xHxW
        with torch.no_grad():
            # 'x_norm_patchtokens' already excludes the [CLS] token.
            out = self.model.forward_features(img.to(self.device))['x_norm_patchtokens']
            h, w = int(img.shape[2] / self.model.patch_size), int(
                img.shape[3] / self.model.patch_size
            )
            dim = out.shape[-1]
            out = out.reshape(-1, h, w, dim)  # NxHxWxC patch-feature grid
        dtype = out.dtype
        if size is not None:
            # Bilinearly resample the feature grid to the requested spatial size.
            out = torch.nn.functional.interpolate(
                out.permute(0, 3, 1, 2), size=size, mode='bilinear'
            ).permute(0, 2, 3, 1)
        if self.pca:
            B, H, W, C = out.shape
            # reshape (not view): `out` may be non-contiguous after permute.
            out = out.reshape(-1, C).cpu().numpy()
            # Note: PCA is refit on every call, so projections are not
            # comparable across images unless fit on a common sample.
            out = self.pca.fit_transform(out)
            out = torch.tensor(out.reshape(B, H, W, self.pca_dim), dtype=dtype).to(self.device)
        return out
    def forward(self, img: Union[Image, torch.Tensor], transform=True, size=None):
        return self.extract_features(img, transform, size)

    def transform(self, img, image_size):
        transforms = tfs.Compose(
            [tfs.Resize((image_size, image_size)), tfs.ToTensor(), tfs.Normalize(MEAN, STD)]
        )
        img = transforms(img)
        return img
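
# A runnable end-to-end sketch (assumptions: PIL installed, a hypothetical
# 'example.jpg' on disk; weights download on first use via torch.hub).
if __name__ == '__main__':
    import PIL.Image
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # dinov2_vits14 at 224x224 yields a 16x16 grid (224/14) of 384-dim tokens.
    model = DINOV2('dinov2_vits14', device=device, img_size=224, pca_dim=3)
    img = PIL.Image.open('example.jpg').convert('RGB')
    feats = model(img, size=(64, 64))  # 1x64x64x3 after upsampling + PCA
    print(feats.shape)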