import urllib.request
from pathlib import Path
from typing import List

import timm
import torch
from PIL import Image

from src.interface import ModelInterface
from src.data.classification_result import ClassificationResult


# Plain-text list of ImageNet class names, one per line. Line i is assumed to be the
# name of model output index i (classify_image indexes the list by predicted index).
_IMAGENET_LABELS_URL = "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"
# Local cache of the label file, relative to the current working directory.
_IMAGENET_LABELS_FILE = "imagenet_classes.txt"


class MobilenetV3(ModelInterface):
    """Image classifier backed by timm's pretrained ``mobilenetv3_large_100``."""

    def __init__(self):
        print('init... mobilenet v3 model')
        self.model = timm.create_model('mobilenetv3_large_100', pretrained=True).eval()

        # The preprocessing pipeline (resize, crop, normalization) depends only on the
        # model's data config, so build it once instead of on every classify_image call.
        data_config = timm.data.resolve_model_data_config(self.model)
        self.transforms = timm.data.create_transform(**data_config, is_training=False)

        self.class_labels = self._load_class_labels()

    @staticmethod
    def _load_class_labels() -> List[str]:
        """Return the class names, downloading the label file only if it is not cached.

        Raises:
            urllib.error.URLError: if the file is missing locally and cannot be fetched.
        """
        labels_path = Path(_IMAGENET_LABELS_FILE)
        if not labels_path.exists():
            # Download to a temporary name and rename afterwards, so an interrupted
            # download never leaves a truncated file that would be reused as the cache.
            partial_path = Path(_IMAGENET_LABELS_FILE + ".part")
            urllib.request.urlretrieve(_IMAGENET_LABELS_URL, str(partial_path))
            partial_path.replace(labels_path)
        with labels_path.open("r", encoding="utf-8") as f:
            return [line.strip() for line in f]

    def classify_image(self, image, top_k: int = 5) -> List[ClassificationResult]:
        """Classify ``image`` and return its ``top_k`` most likely classes.

        Args:
            image: Image passed directly to the timm eval transform; presumably a
                ``PIL.Image.Image``. Non-RGB PIL images are converted to RGB first.
            top_k: Number of predictions to return. Defaults to 5 and is clamped to
                the number of model classes.

        Returns:
            ``ClassificationResult`` objects ordered from most to least confident.
            ``confidence`` is the softmax probability of that class.
        """
        # Grayscale, RGBA or palette images would yield a channel count that the
        # 3-channel ImageNet model and its normalization do not accept.
        if isinstance(image, Image.Image) and image.mode != "RGB":
            image = image.convert("RGB")

        input_tensor = self.transforms(image).unsqueeze(0)  # add a batch dimension of 1

        with torch.no_grad():
            logits = self.model(input_tensor)

        probabilities = logits.softmax(dim=1)[0]  # drop the batch dimension
        k = min(top_k, probabilities.numel())
        top_probs, top_indices = torch.topk(probabilities, k=k)

        return [
            ClassificationResult(
                class_name=self.class_labels[index],
                confidence=prob,
            )
            for prob, index in zip(top_probs.tolist(), top_indices.tolist())
        ]