import os
from timeit import default_timer as timer
from typing import Dict, Tuple

import gradio as gr
import numpy as np
import torch
from torchvision import datasets, transforms

from model import create_resnet

# Load the trained ResNet18 weights once at start-up. map_location moves every
# tensor to the CPU so the app also runs on machines without a GPU.
# NOTE: torch.load unpickles the file; only load checkpoints you trust.
model = create_resnet()
model.load_state_dict(
    torch.load(f="ResNet18_epoch-14.pth", map_location=torch.device("cpu"))
)
# This app only does inference, so switch to eval mode once here rather than
# on every request.
model.eval()

# Preprocessing pipeline: resize, centre-crop to 224x224, convert to a tensor
# and normalise with the per-channel mean/std below (ImageNet statistics).
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Sigmoid probability at or above which the image is labelled "Good".
THRESHOLD = 0.5


def predict(img) -> Tuple[str, float]:
    """Classify a lemon image as "Good" or "Bad" quality.

    Args:
        img: PIL image supplied by the ``gr.Image(type="pil")`` input, or
            ``None`` when the user submits without uploading an image.

    Returns:
        A ``(label, seconds)`` tuple: the predicted label ("Good" or "Bad")
        and the prediction wall-clock time rounded to 4 decimal places.

    Raises:
        gr.Error: if no image was provided.
    """
    if img is None:
        raise gr.Error("Please upload an image of a lemon.")

    start_time = timer()

    # Normalize expects exactly 3 channels; RGBA PNGs or greyscale uploads
    # would otherwise fail, so coerce to RGB first.
    transformed_image = transform(img.convert("RGB")).unsqueeze(0)

    with torch.no_grad():
        output = model(transformed_image)

    # .item() requires a single-element output, i.e. one logit for the batch
    # of one. Threshold the sigmoid probability instead of truncating it with
    # int(), which would return 0 for every probability below 1.0.
    probability = torch.sigmoid(output).item()
    predicted_label = 1 if probability >= THRESHOLD else 0

    pred_time = round(timer() - start_time, 4)

    label = "Good" if predicted_label == 1 else "Bad"
    return label, pred_time


# Gradio Interface
title = "🍋 Lemon Quality Classifier 🍋"
description = "A [ResNet18](https://pytorch.org/vision/main/models/generated/torchvision.models.resnet18.html) computer vision model to classify lemons as good or bad in quality."
article = "Created for practice and learning."

# Collect example images in a stable order, skipping stray non-image files
# (e.g. .DS_Store, .gitkeep) and tolerating a missing examples directory.
EXAMPLES_DIR = "examples"
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp", ".gif", ".webp", ".tif", ".tiff")
if os.path.isdir(EXAMPLES_DIR):
    example_list = [
        [os.path.join(EXAMPLES_DIR, name)]
        for name in sorted(os.listdir(EXAMPLES_DIR))
        if name.lower().endswith(IMAGE_EXTENSIONS)
    ]
else:
    example_list = []

demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Label(num_top_classes=1, label="Prediction"),
        gr.Number(label="Prediction time (s)"),
    ],
    # Gradio expects None (not an empty list) when there are no examples.
    examples=example_list or None,
    title=title,
    description=description,
    article=article,
)

demo.launch()