import logging
from io import BytesIO
from typing import List
from urllib.request import urlopen

from PIL import Image

from .data.model_data import ModelData
from .models.mobilenet_v3 import MobilenetV3
from .models.clip_vit import ClipVit
from .models.google_vit import GoogleVit
from .models.resnet_50 import Resnet50
from .data.classification_result import ClassificationResult

logger = logging.getLogger(__name__)

# String prefixes that image_to_pil treats as remote URLs to download
# (compared case-insensitively) instead of local paths to open.
_URL_PREFIXES = ('http://', 'https://')


class ClassificationModel:
    """
    Registry of the available image-classification models.

    On construction every supported model is instantiated and wrapped in a
    ``ModelData`` entry. ``classify`` looks a model up by name and delegates
    the actual prediction to that entry's ``model_class.classify_image``.
    """

    def __init__(self):
        self.load_model()

    def get_model_names(self):
        """Return the names of all registered models, in registration order."""
        return [model.name for model in self.models]

    def get_model_data(self, model_name):
        """
        Return the ``ModelData`` entry whose ``name`` equals *model_name*.

        Raises:
            ValueError: if no registered model has that name.
        """
        for model in self.models:
            if model.name == model_name:
                return model
        # ValueError subclasses Exception, so callers that caught the
        # previous bare Exception keep working.
        raise ValueError(f'Model {model_name} not found')

    def load_model(self):
        """
        Instantiate every supported model and store them in ``self.models``.

        Each model class is constructed eagerly here, so whatever setup its
        constructor performs happens when ``ClassificationModel`` is created.
        """
        self.models = [
            ModelData('clip-vit-base-patch32', model_class=ClipVit()),
            ModelData('mobilenet_v3', model_class=MobilenetV3()),
            ModelData('google-vit-base-patch16-224', model_class=GoogleVit()),
            ModelData('microsoft/resnet-50', model_class=Resnet50()),
        ]

    def classify(self, model_name, image) -> List[ClassificationResult]:
        """
        Classify *image* with the model registered as *model_name*.

        Args:
            model_name: one of the names returned by ``get_model_names()``.
            image: anything ``image_to_pil`` accepts (http/https URL string,
                local path, binary file object, or a ``PIL.Image.Image``).

        Returns:
            The result of the selected model's ``classify_image`` call
            (annotated as a list of ``ClassificationResult``).

        Raises:
            ValueError: if *model_name* is not registered.
        """
        logger.debug('image type --> %s', type(image))
        # Resolve the model first so an unknown name fails fast, before any
        # network download or file I/O is spent on the image.
        model = self.get_model_data(model_name)
        img = self.image_to_pil(image)
        return model.model_class.classify_image(img)

    def image_to_pil(self, image, timeout=30):
        """
        Convert *image* to a ``PIL.Image.Image``.

        Args:
            image: an ``http://`` / ``https://`` URL string (downloaded),
                a ``PIL.Image.Image`` (returned unchanged), or anything
                ``PIL.Image.open`` accepts (path string, ``pathlib.Path``,
                binary file object).
            timeout: seconds passed to ``urlopen`` for blocking network
                operations when *image* is a URL.

        Returns:
            The opened image.
        """
        if isinstance(image, Image.Image):
            return image
        if isinstance(image, str) and image.lower().startswith(_URL_PREFIXES):
            # Read the whole body and close the connection immediately:
            # Image.open is lazy, so passing it the live response would keep
            # the socket open until the pixel data is eventually decoded.
            with urlopen(image, timeout=timeout) as response:
                data = response.read()
            return Image.open(BytesIO(data))
        return Image.open(image)