import PIL
from PIL.Image import Image
from typing import Union

from sklearn.decomposition import PCA
import torch
from torch import nn
from torchvision import transforms as tfs

# Per-channel normalization statistics applied by ``transform`` after ``ToTensor``.
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]

# torch.hub repository and entry points accepted by ``DINO``.
DINO_MODEL_HUB = 'facebookresearch/dino:main'
DINO_MODEL_TYPE = [
    'dino_vits16',
    'dino_vits8',
    'dino_vitb16',
    'dino_vitb8',
    'dino_xcit_small_12_p16',
    'dino_xcit_small_12_p8',
    'dino_xcit_medium_24_p16',
    'dino_xcit_medium_24_p8',
    'dino_resnet50',
]
# NOTE(review): the xcit / resnet50 entries pass the check in ``DINO.__init__``,
# but ``DINO`` reads ``model.patch_embed.patch_size`` and calls
# ``model.get_intermediate_layers``; confirm those backbones provide both.

# torch.hub repository and entry points accepted by ``DINOV2``.
DINOV2_MODEL_HUB = 'facebookresearch/dinov2:main'
DINOV2_MODEL_TYPE = ['dinov2_vits14', 'dinov2_vitb14', 'dinov2_vitl14', 'dinov2_vitg14']


class _DINOFeatureExtractor(nn.Module):
    """Frozen torch.hub backbone used as a dense patch-feature extractor.

    Shared implementation for :class:`DINO` and :class:`DINOV2`. Subclasses
    only define how to read the patch size (``_patch_size``) and the per-patch
    tokens (``_patch_tokens``) from their backbone.
    """

    def __init__(self, hub, model_type, device='cuda', img_size=224, pca_dim=None):
        super().__init__()
        self.model = torch.hub.load(hub, model_type).to(device)
        self.device = device
        # The backbone is only used for inference: no gradients, eval mode.
        for param in self.model.parameters():
            param.requires_grad = False
        self.model.eval()
        self.img_size = img_size
        self.pca_dim = pca_dim
        # A falsy pca_dim (None or 0) disables the PCA projection.
        self.pca = self.set_pca(pca_dim) if pca_dim else None

    def set_pca(self, dim=64):
        """Return a new, unfitted sklearn ``PCA`` with ``n_components=dim``."""
        return PCA(n_components=dim)

    def _patch_size(self):
        """Return the backbone's patch size in pixels (subclass hook)."""
        raise NotImplementedError

    def _patch_tokens(self, img):
        """Return per-patch tokens for ``img``, expected as (B, h*w, C) (subclass hook)."""
        raise NotImplementedError

    @torch.no_grad()
    def extract_features(
        self, img: Union[Image, torch.Tensor], transform=True, size=None
    ):
        """Compute a dense patch-feature map for ``img``.

        Args:
            img: A PIL image, or an already-preprocessed image tensor. Tensors
                are used as-is; dims 2 and 3 are read as height and width
                (presumably N x 3 x H x W).
            transform: When True and ``img`` is a PIL image, resize/normalize it
                with :meth:`transform` and add a batch dimension. A PIL image
                passed with ``transform=False`` is not converted and will fail.
            size: Optional output spatial size forwarded to
                ``torch.nn.functional.interpolate`` (bilinear) to upsample the
                feature map.

        Returns:
            Tensor of shape (B, h, w, C) on ``self.device``, where
            h = H // patch_size and w = W // patch_size, or (B, *size, C) when
            ``size`` is given. With PCA enabled, C is the number of PCA
            components. Note that PCA is re-fit on every call
            (``fit_transform``), so projections from different calls do not
            share a basis.
        """
        if transform and isinstance(img, Image):
            img = self.transform(img, self.img_size).unsqueeze(0)  # 1x3xHxW
        tokens = self._patch_tokens(img.to(self.device))
        patch_size = self._patch_size()
        h, w = img.shape[2] // patch_size, img.shape[3] // patch_size
        feats = tokens.reshape(-1, h, w, tokens.shape[-1])
        dtype = feats.dtype
        if size is not None:
            feats = torch.nn.functional.interpolate(
                feats.permute(0, 3, 1, 2), size=size, mode='bilinear'
            ).permute(0, 2, 3, 1)
        if self.pca is not None:
            feats = self._project_pca(feats, dtype)
        return feats

    def _project_pca(self, feats, dtype):
        """Fit PCA on every feature vector in ``feats`` (B, H, W, C) and return the projection."""
        B, H, W, C = feats.shape
        # reshape, not view: slicing off DINO's [CLS] token leaves ``feats``
        # non-contiguous for batches > 1, where ``view`` raises.
        flat = feats.reshape(-1, C)
        if flat.dtype == torch.bfloat16:
            # NumPy has no bfloat16 dtype, so upcast before ``.numpy()``.
            flat = flat.float()
        reduced = self.pca.fit_transform(flat.cpu().numpy())
        # -1 (rather than pca_dim) also supports a float n_components
        # (explained-variance fraction), where the output width is data-dependent.
        return torch.as_tensor(reduced.reshape(B, H, W, -1), dtype=dtype, device=self.device)

    def forward(self, img: Union[Image, torch.Tensor], transform=True, size=None):
        """Alias for :meth:`extract_features` so the module can be called directly."""
        return self.extract_features(img, transform, size)

    @staticmethod
    def transform(img, image_size):
        """Resize a PIL image to a square ``image_size``, convert it to a tensor and normalize with MEAN/STD."""
        transforms = tfs.Compose(
            [tfs.Resize((image_size, image_size)), tfs.ToTensor(), tfs.Normalize(MEAN, STD)]
        )
        return transforms(img)


class DINO(_DINOFeatureExtractor):
    """Frozen DINO (v1) backbone loaded from ``DINO_MODEL_HUB``.

    Args:
        model_type: One of ``DINO_MODEL_TYPE``.
        device: Device the backbone and returned features live on.
        img_size: Square side length PIL inputs are resized to.
        pca_dim: If truthy, PCA ``n_components`` applied to the output features.
    """

    def __init__(self, model_type, device='cuda', img_size=224, pca_dim=None):
        assert model_type in DINO_MODEL_TYPE, 'Given DINO model type must in DINO_MODEL_TYPE!'
        super().__init__(DINO_MODEL_HUB, model_type, device=device, img_size=img_size, pca_dim=pca_dim)

    def _patch_size(self):
        return self.model.patch_embed.patch_size

    def _patch_tokens(self, img):
        out = self.model.get_intermediate_layers(img, n=1)[0]
        return out[:, 1:, :]  # discard the [CLS] token


class DINOV2(_DINOFeatureExtractor):
    """Frozen DINOv2 backbone loaded from ``DINOV2_MODEL_HUB``.

    Args:
        model_type: One of ``DINOV2_MODEL_TYPE``.
        device: Device the backbone and returned features live on.
        img_size: Square side length PIL inputs are resized to.
        pca_dim: If truthy, PCA ``n_components`` applied to the output features.
    """

    def __init__(self, model_type, device='cuda', img_size=224, pca_dim=None):
        assert model_type in DINOV2_MODEL_TYPE, 'Given DINOv2 model type must be in DINOV2_MODEL_TYPE!'
        super().__init__(DINOV2_MODEL_HUB, model_type, device=device, img_size=img_size, pca_dim=pca_dim)

    def _patch_size(self):
        return self.model.patch_size

    def _patch_tokens(self, img):
        # Normalized patch tokens only (no class/register tokens).
        return self.model.forward_features(img)['x_norm_patchtokens']