import torch
import torch.nn as nn
from torchvision import transforms, models
import gradio as gr

# Preprocessing pipeline that matches the ResNet18 ImageNet weights
transformer = models.ResNet18_Weights.IMAGENET1K_V1.transforms()

class_names = ['anger', 'disgust', 'fear', 'happy', 'pain', 'sad']
classes_count = len(class_names)

# ResNet18 backbone with the classifier head replaced for the six emotion classes
model = models.resnet18(weights='DEFAULT')
model.fc = nn.Sequential(
    nn.Linear(512, classes_count)
)

# Load the fine-tuned parameters on CPU
model.load_state_dict(
    torch.load('./model_param.pt', map_location=torch.device('cpu')),
    strict=False
)


def predict(img):
    # Preprocess the PIL image and add a batch dimension
    img = transformer(img).unsqueeze(0)
    model.eval()
    with torch.inference_mode():
        pred = torch.softmax(model(img), dim=1)
    # Map each class name to its predicted probability for the Gradio Label output
    pred_and_labels = {class_names[i]: pred[0][i].item() for i in range(len(pred[0]))}
    return pred_and_labels


app = gr.Interface(
    predict,
    gr.Image(type='pil'),
    gr.Label(num_top_classes=classes_count)
)
app.launch()