# Spaces: Sleeping  (apparently status text from the hosting page; not part of the app)
### 1. Imports ###
import os
from timeit import default_timer as timer
from typing import Tuple, Dict

import gradio as gr
import torch
from PIL import Image

from model import create_model_alexnet
### 2. Model and transforms preparation ###
# Build the AlexNet-based classifier with two output classes, together with
# the preprocessing transforms that `predict` applies to every input image.
model_alexnet, transforms = create_model_alexnet(num_classes=2)

# Resolve the checkpoint next to this file instead of the current working
# directory, so the app still finds its weights when launched from elsewhere.
_CHECKPOINT_PATH = os.path.join(
    os.path.dirname(os.path.abspath(__file__)), "cat_dog_classifier.pth"
)

# Load the saved weights onto the CPU; map_location remaps any tensors that
# were saved from a GPU so the app runs on CPU-only hardware.
model_alexnet.load_state_dict(
    torch.load(f=_CHECKPOINT_PATH, map_location=torch.device("cpu"))
)
# Inference-only app: switch to eval mode once the weights are in place.
model_alexnet.eval()
### 3. Predict function ###
def predict(img):
    """Classify an image file as 'Cat' or 'Dog' and time the inference.

    Args:
        img: Path to the image file. The Gradio input is declared with
            ``type='filepath'``; anything ``PIL.Image.open`` accepts works too.

    Returns:
        A ``(label, seconds)`` tuple. ``label`` is the string ``'Cat'`` or
        ``'Dog'`` (``gr.Label`` accepts a plain string). ``seconds`` is the
        wall-clock prediction time rounded to 5 decimal places.

    Raises:
        gr.Error: If no image was supplied (e.g. Submit pressed with an empty
            input), so the UI shows a readable message instead of a traceback.
    """
    if img is None:
        raise gr.Error("Please upload an image first.")

    start_time = timer()

    # Re-assert eval mode on every call so any dropout / batch-norm layers run
    # in inference mode even if something switched the model back to training.
    model_alexnet.eval()

    # Open the upload in a context manager so the file handle is closed
    # promptly. convert("RGB") loads the pixels into an independent copy and
    # normalises RGBA / palette / grayscale uploads to 3 channels. This
    # presumably matches what the AlexNet-style model expects; confirm
    # against model.py.
    with Image.open(img) as pil_image:
        rgb_image = pil_image.convert("RGB")
    batch = transforms(rgb_image).unsqueeze(0)  # add a leading batch dimension

    with torch.inference_mode():
        logits = model_alexnet(batch)
    class_idx = logits.argmax(dim=1).item()

    # NOTE(review): mapping index 1 -> 'Cat' must match the class order used
    # during training. torchvision's ImageFolder sorts class folders
    # alphabetically ('cat' = 0, 'dog' = 1), which would be the opposite.
    # Confirm against the training code.
    pred_labels = 'Cat' if class_idx == 1 else 'Dog'

    pred_time = round(timer() - start_time, 5)
    return pred_labels, pred_time
### 4. Gradio app ###
# Title and description shown above the interface.
title = "Classification Demo"
description = "Cat/Dog classification - Transfer Learning "

# Wire `predict` to the UI: one image (handed to `predict` as a file path) in,
# and two outputs that match predict's (label, seconds) return tuple.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type='filepath'),
    outputs=[
        gr.Label(label="Predictions"),
        gr.Number(label="Prediction time (s)"),
    ],
    title=title,
    description=description,
)

# Launch the demo!
demo.launch()