import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as T
from PIL import Image


class DinoVisionTransformerClassifier(nn.Module):
    """DINOv2 ViT-B/14 linear-classifier model with a custom MLP head.

    The hub model ``dinov2_vitb14_lc`` is loaded and its ``linear_head`` is
    replaced by a three-layer MLP ending in ``num_classes`` outputs.
    ``forward`` accepts a file path, a Pillow image, or an already prepared
    tensor.
    """

    def __init__(self, num_classes):
        """Build the model and the image preprocessing pipeline.

        Args:
            num_classes: Number of output units of the final linear layer.
        """
        super().__init__()
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )

        # Workaround to bypass HTTP Error 403 rate limit exceeded.
        # NOTE: this patches a private torch.hub function process-wide.
        torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
        self.model = torch.hub.load(
            "facebookresearch/dinov2", "dinov2_vitb14_lc"
        )

        # Replace the stock classification head. The input width (3840) is
        # carried over from the original code; presumably it matches the
        # feature width the hub model feeds into `linear_head` -- verify
        # if a different DINOv2 variant is ever used.
        self.model.linear_head = nn.Sequential(
            nn.Linear(3840, 512, bias=True),
            nn.ReLU(),
            nn.Linear(512, 256, bias=True),
            nn.ReLU(),
            nn.Linear(256, num_classes, bias=True),
        )
        self.model.to(self.device)

        # Resize to 224x224, convert to a float tensor and normalise with
        # per-channel mean/std values for three channels.
        self.transform_image = T.Compose([
            T.Resize((224, 224)),
            T.ToTensor(),
            T.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ])
        self.model_name = "dinov2"

    def load_image_from_filepath(self, img: str) -> torch.Tensor:
        """Load an image from a file path and return a model-ready tensor.

        Args:
            img: Path of the image file to open.

        Returns:
            A tensor with a leading batch dimension of 1, placed on
            ``self.device``.
        """
        # The context manager closes the underlying file as soon as the
        # pixel data has been consumed, instead of leaving the handle open
        # until garbage collection.
        with Image.open(img) as opened:
            return self.load_image_from_pillowimage(opened)

    def load_image_from_pillowimage(self, img: Image.Image) -> torch.Tensor:
        """Convert a Pillow image into a model-ready tensor.

        Args:
            img: Pillow image in any mode; it is converted to RGB first.

        Returns:
            A tensor with a leading batch dimension of 1, placed on
            ``self.device``.
        """
        # Convert BEFORE the transform: the normalisation step is configured
        # for exactly three channels, so RGBA / grayscale / palette images
        # must become RGB first. Slicing channels after the transform (as
        # the previous implementation did) is too late to help.
        rgb_img = img.convert("RGB")
        batch = self.transform_image(rgb_img).unsqueeze(0)
        return batch.to(self.device)

    def forward(self, x):
        """Run the model on a file path, a Pillow image, or a tensor.

        Args:
            x: A ``str`` file path, a ``PIL.Image.Image``, or any other
                object (e.g. an already preprocessed tensor), which is
                passed to the model unchanged.

        Returns:
            Whatever the wrapped model returns for the prepared input.
        """
        if isinstance(x, str):
            x = self.load_image_from_filepath(x)
        elif isinstance(x, Image.Image):
            x = self.load_image_from_pillowimage(x)
        return self.model(x)