"""Painting classifier: predicts the art *school* and painting *type* of an image.

Loads a fine-tuned MobileNetV3-Large checkpoint from the Hugging Face Hub
(``Irina1402/mobilnetv3-painting-classification``). The model's single linear
head emits ``num_classes_school + num_classes_type`` logits: the first
``num_classes_school`` are school logits, the remaining ones are type logits.
"""
import torch
import torch.nn as nn
from torchvision import transforms, models
from huggingface_hub import hf_hub_download
from PIL import Image

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

num_classes_school = 26
num_classes_type = 10

# NOTE: runs at import time - downloads (or reuses the locally cached copy of)
# the checkpoint file from the Hugging Face Hub.
model_path = hf_hub_download(
    repo_id="Irina1402/mobilnetv3-painting-classification",
    filename="model.pth",
)

# The backbone is built WITHOUT pretrained ImageNet weights: load_state_dict()
# below runs in strict mode, so every parameter and buffer is replaced by the
# checkpoint anyway; fetching the ImageNet weights first would be wasted work.
model = models.mobilenet_v3_large(weights=None)
num_features = model.classifier[0].in_features
# Replace the ImageNet classifier with the fine-tuned two-task head; its output
# layout must match the checkpoint exactly or the strict load below fails.
model.classifier = nn.Sequential(
    nn.Linear(num_features, 512),
    nn.ReLU(),
    nn.Dropout(0.5),
    nn.Linear(512, num_classes_school + num_classes_type),
)
# SECURITY NOTE(review): torch.load unpickles the downloaded file. If the
# installed torch supports it (>= 1.13), pass weights_only=True so a tampered
# checkpoint cannot execute arbitrary code.
model.load_state_dict(torch.load(model_path, map_location=device))
model = model.to(device)
model.eval()  # inference mode for Dropout / BatchNorm

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    # Standard ImageNet per-channel mean/std (3 channels -> input must be RGB).
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

school_labels = [
    "American", "Austrian", "Belgian", "Bohemian", "Catalan", "Danish",
    "Dutch", "English", "Finnish", "Flemish", "French", "German", "Greek",
    "Hungarian", "Irish", "Italian", "Netherlandish", "Norwegian", "Other",
    "Polish", "Portuguese", "Russian", "Scottish", "Spanish", "Swedish",
    "Swiss",
]
type_labels = [
    "genre", "historical", "interior", "landscape", "mythological", "other",
    "portrait", "religious", "still-life", "study",
]

# Fail fast if the label tables drift out of sync with the model head; an
# off-by-one here would silently map logits to the wrong names.
if (len(school_labels) != num_classes_school
        or len(type_labels) != num_classes_type):
    raise RuntimeError("label lists do not match the model's head sizes")


def classify_image(image: Image.Image) -> dict:
    """Classify the uploaded image and return type and school predictions.

    Args:
        image: A PIL image in any mode. It is converted to RGB first because
            the normalization step expects exactly three channels.

    Returns:
        A dict ``{"school": <school label>, "type": <type label>}`` holding
        the highest-scoring label from each part of the model's output.
    """
    # Grayscale ("L"), palette ("P"), RGBA or CMYK inputs would otherwise yield
    # a tensor with the wrong channel count (crashing in Normalize) or, for
    # palette images, raw palette indices instead of colors.
    rgb_image = image.convert("RGB")
    # Add a batch dimension -> (1, 3, 224, 224).
    input_tensor = transform(rgb_image).unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(input_tensor)

    # Split the shared head's logits into the two tasks.
    school_output = output[:, :num_classes_school]
    type_output = output[:, num_classes_school:]

    # argmax over the class dimension (dim=1) of the single batch row.
    school_prediction = school_output.argmax(dim=1).item()
    type_prediction = type_output.argmax(dim=1).item()

    return {
        "school": school_labels[school_prediction],
        "type": type_labels[type_prediction],
    }