from typing import List

import torch
from transformers import ViTForImageClassification, ViTImageProcessor

from src.interface import ModelInterface
from src.data.classification_result import ClassificationResult


class GoogleVit(ModelInterface):
    def __init__(self):
        print('init... google vit model')
        # Load the ViT image processor and classification model
        self.processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')
        self.model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
        self.model.eval()

    def classify_image(self, image) -> List[ClassificationResult]:
        # Preprocess the image into model-ready tensors
        inputs = self.processor(images=image, return_tensors="pt")

        # Perform inference without tracking gradients
        with torch.no_grad():
            outputs = self.model(**inputs)
        logits = outputs.logits

        # Convert logits to probabilities using softmax
        probabilities = torch.nn.functional.softmax(logits, dim=-1)[0]

        # Get the indices of the top 5 predictions
        top_5 = torch.argsort(probabilities, dim=-1, descending=True)[:5]

        # Create ClassificationResult objects with class names and confidences
        results = [
            ClassificationResult(
                class_name=self.model.config.id2label[int(idx)],
                confidence=float(probabilities[idx]),
            )
            for idx in top_5
        ]
        return results
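

# --- Usage sketch (assumptions: PIL is installed, "example.jpg" is a local image,
# and ClassificationResult exposes class_name/confidence attributes; none of these
# are guaranteed by this module) ---
if __name__ == "__main__":
    from PIL import Image

    classifier = GoogleVit()
    image = Image.open("example.jpg")
    for result in classifier.classify_image(image):
        print(result.class_name, result.confidence)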