# sketch_rec_Mini / app.py
### 1. Imports and class names setup ###
import gradio as gr
import os
import torch
from torchvision import transforms
from models import get_mobilenet_v2_model, get_resnet_18_model, get_vgg_16_model
from timeit import default_timer as timer
from typing import Tuple, Dict
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Setup class names
class_names = ["car", "dragon", "horse", "pegasus", "ship", "t-rex", "tree"]
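# Note: the class order here is assumed to match the label order the checkpoints
# were trained with; only the index-to-name mapping is used below.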
### 2. Model and transforms preparation ###
# Create image transforms (shared by all three models)
img_transforms = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
    ]
)
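# 224x224 is the standard input size for these ImageNet-style backbones;
# ToTensor() converts the PIL sketch to a float tensor scaled to [0, 1].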
model_name_to_fn = {
    "mobilenet_v2": get_mobilenet_v2_model,
    "resnet_18": get_resnet_18_model,
    "vgg_16": get_vgg_16_model,
}
model_name_to_path = {
    "mobilenet_v2": "mobilenet_v2.pth",
    "resnet_18": "resnet_18.pth",
    "vgg_16": "vgg_16.pt",
}
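# For reference only: a hedged sketch of what the factories in models.py are
# assumed to return - a torchvision backbone whose classifier head is replaced
# with a len(class_names)-way linear layer. The real get_mobilenet_v2_model may
# differ; this local `_reference_mobilenet_v2` is illustrative and never called.
def _reference_mobilenet_v2(num_classes: int = len(class_names)):
    import torch.nn as nn
    from torchvision import models as tv_models

    backbone = tv_models.mobilenet_v2()  # weights are loaded from the checkpoint later
    backbone.classifier[1] = nn.Linear(backbone.last_channel, num_classes)
    return backbone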
### 3. Predict function ###
# Create predict function
def predict(img, model_name: str = "mobilenet_v2") -> Tuple[Dict, float]:
    """Transforms and performs a prediction on img and returns the prediction and the time taken.

    Args:
        img (PIL.Image): Image to perform prediction on.
        model_name (str): Name of the model to use for prediction.

    Returns:
        Tuple[Dict, float]: Dictionary mapping class names to prediction probabilities,
            and the time taken (in seconds) to perform the prediction.
    """
    # Start the timer
    start_time = timer()

    # Look up the model factory and checkpoint filename for the chosen model
    model_fn = model_name_to_fn[model_name]
    model_path = model_name_to_path[model_name]

    # Create the model and load its saved weights (the filename comes from
    # model_name_to_path, since vgg_16 uses a .pt extension rather than .pth)
    model = model_fn().to(device)
    model.load_state_dict(
        torch.load(f"./models/{model_path}", map_location=device)
    )
    # Put model into evaluation mode and turn on inference mode
    model.eval()
    with torch.inference_mode():
        # Transform the target image and add a batch dimension
        img = img_transforms(img).unsqueeze(0).to(device)

        # Pass the transformed image through the model and turn the prediction logits into prediction probabilities
        pred_probs = torch.softmax(model(img), dim=1)

    # Create a prediction label and prediction probability dictionary for each prediction class (this is the required format for Gradio's output parameter)
    pred_labels_and_probs = {
        class_names[i]: float(pred_probs[0][i]) for i in range(len(class_names))
    }

    # Calculate the prediction time
    pred_time = round(timer() - start_time, 5)

    # Return the prediction dictionary and prediction time
    return pred_labels_and_probs, pred_time
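# Optional local smoke test (a sketch, not part of the Space): runs `predict`
# directly on the first file found in "examples/", assuming that directory
# contains images PIL can open. The SKETCHREC_SMOKE_TEST environment variable
# is a hypothetical guard so this never runs when the Space imports the module.
if os.environ.get("SKETCHREC_SMOKE_TEST"):
    from PIL import Image

    _sample_path = os.path.join("examples", sorted(os.listdir("examples"))[0])
    _labels, _seconds = predict(Image.open(_sample_path).convert("RGB"), model_name="mobilenet_v2")
    print(_labels, _seconds)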
### 4. Gradio app ###
# Create title, description and article strings
title = "SketchRec Mini ✍🏻"
description = "A Multi-model Sketch Recognition App 🎨"
article = ""
# Create examples list from the "examples/" directory
# (each example supplies both inputs: the image path and the default model name)
example_list = [
    ["examples/" + example, "mobilenet_v2"] for example in os.listdir("examples")
]
# Create the Gradio demo
model_selection_dropdown = gr.components.Dropdown(
    choices=list(model_name_to_fn.keys()),
    label="Select a model",
    value="mobilenet_v2",
)
demo = gr.Interface(
    fn=predict,  # mapping function from inputs to outputs
    inputs=[gr.Image(type="pil"), model_selection_dropdown],  # what are the inputs?
    outputs=[
        gr.Label(num_top_classes=7, label="Predictions"),  # what are the outputs?
        gr.Number(label="Prediction time (s)"),
    ],  # our fn has two outputs, therefore we have two outputs
    examples=example_list,
    title=title,
    description=description,
    article=article,
)
# Launch the demo!
demo.launch(
    # debug=True,
)