from typing import List

import torch
from transformers import AutoImageProcessor, ResNetForImageClassification

from src.data.classification_result import ClassificationResult
from src.interface import ModelInterface


class Resnet50(ModelInterface):
    """Image classifier backed by the Hugging Face ``microsoft/resnet-50`` checkpoint."""

    # Hub identifier shared by the image processor and the model weights.
    _CHECKPOINT = "microsoft/resnet-50"

    def __init__(self):
        """Download/load the ResNet-50 image processor and classification model."""
        print('init... resnet-50 model')
        self.processor = AutoImageProcessor.from_pretrained(self._CHECKPOINT)
        self.model = ResNetForImageClassification.from_pretrained(self._CHECKPOINT)
        # Explicitly select inference behaviour (e.g. batch-norm running stats);
        # this class never trains the model.
        self.model.eval()

    def classify_image(self, image, top_k: int = 5) -> List[ClassificationResult]:
        """Return the ``top_k`` most probable labels for ``image``.

        Args:
            image: Anything accepted by the Hugging Face image processor's
                ``images=`` argument (presumably a PIL image or array — confirm
                against callers). If a batch is passed, only the first image's
                predictions are returned.
            top_k: Number of predictions to return (default 5). Clamped to the
                number of labels the model knows.

        Returns:
            ``ClassificationResult`` objects ordered from most to least
            probable. ``confidence`` is the softmax probability over all labels.

        Raises:
            ValueError: If ``top_k`` is less than 1.
        """
        if top_k < 1:
            raise ValueError(f"top_k must be >= 1, got {top_k}")

        inputs = self.processor(images=image, return_tensors="pt")

        # Pure inference: don't build an autograd graph (saves memory and time).
        with torch.no_grad():
            logits = self.model(**inputs).logits

        # Score only the first item of the batch, as before.
        probabilities = torch.nn.functional.softmax(logits[0], dim=-1)

        # Never ask for more predictions than there are labels.
        k = min(top_k, probabilities.shape[-1])
        top_probs, top_ids = torch.topk(probabilities, k)  # sorted descending

        id2label = self.model.config.id2label
        # .tolist() yields plain Python floats/ints (clean dict keys, JSON-friendly).
        return [
            ClassificationResult(class_name=id2label[label_id], confidence=prob)
            for prob, label_id in zip(top_probs.tolist(), top_ids.tolist())
        ]