birds_pytorch / app.py
bmartinc80's picture
updated launch share = true
b44f956 verified
raw
history blame contribute delete
No virus
3.7 kB
from timeit import default_timer as timer
import gradio as gr
import torch
from model import create_effnetb2_model
from classes import class_names
# Build the EfficientNet-B2 architecture (525-class head) together with the
# preprocessing transforms used on input images in ``predict``. Only the
# architecture is created here; the trained weights are loaded separately from
# the saved checkpoint. (seed=42 presumably makes head initialisation
# reproducible — TODO confirm in model.create_effnetb2_model.)
effnetb2, effnetb2_transforms = create_effnetb2_model(num_classes=525, seed=42)
# Load the Best Model (effnetb2 with 10 epochs)
def load_best_effnetb2_model(model_path, model=None):
    """Load a saved ``state_dict`` into a model and return it ready for inference.

    Args:
        model_path: Path to a checkpoint containing a model ``state_dict``
            (e.g. written with ``torch.save(model.state_dict(), path)``).
        model: Optional ``torch.nn.Module`` to load the weights into. Defaults
            to the module-level ``effnetb2`` instance, whose architecture must
            match the checkpoint. The weights are loaded *in place*: the given
            (or default) model object is mutated and returned, not copied.

    Returns:
        The same model object, holding the loaded weights and set to eval mode.
    """
    if model is None:
        # Reuse the architecture built at import time rather than creating a
        # new instance.
        model = effnetb2
    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only
    # host. NOTE: torch.load unpickles the file — only load trusted checkpoints.
    state_dict = torch.load(model_path, map_location=torch.device("cpu"))
    model.load_state_dict(state_dict)
    # Switch layers such as dropout / batch-norm to inference behaviour.
    model.eval()
    return model
# Path to the saved state_dict of the best model (EffNetB2 trained for 10 epochs).
best_model_path = "best_effnetb2_10_epochs.pth"

# Load the trained weights into the model and put it in eval mode.
best_model = load_best_effnetb2_model(best_model_path)

# Keep the model on CPU for inference. load_state_dict copies the weights into
# the existing parameters, so the model otherwise stays on whatever device
# create_effnetb2_model placed it on; moving it explicitly pins it to CPU.
best_model.to("cpu")
# Prediction Function for the Gradio App
def predict(input_image, model=best_model, transform=effnetb2_transforms,
            class_names=class_names):
    """Classify a single bird image.

    Args:
        input_image: PIL image from the ``gr.Image(type="pil")`` input, or
            ``None`` when the user submits without uploading an image.
        model: Trained classifier; defaults to the loaded ``best_model``.
        transform: Preprocessing applied to the PIL image before inference.
        class_names: Sequence mapping a class index to its class name.

    Returns:
        A tuple ``(pred_class, pred_prob, time_for_pred)``: the predicted class
        name, its softmax probability rounded to 4 decimals, and the seconds
        spent on preprocessing plus inference, rounded to 4 decimals.

    Raises:
        gr.Error: If no image was provided.
    """
    if input_image is None:
        # Without this guard the transform fails on None with an opaque error;
        # gr.Error is shown to the user as a readable message.
        raise gr.Error("Please upload an image.")

    start_time = timer()

    # Preprocess the image and add a leading batch dimension of size 1.
    batch = transform(input_image).unsqueeze(0)

    model.eval()
    with torch.inference_mode():
        logits = model(batch)
        probs = torch.softmax(logits, dim=1)  # logits -> class probabilities

    # Convert the top index to a plain int before indexing the Python list.
    pred_idx = int(torch.argmax(probs, dim=1).item())
    pred_class = class_names[pred_idx]
    pred_prob = round(probs[0, pred_idx].item(), 4)

    time_for_pred = round(timer() - start_time, 4)
    return pred_class, pred_prob, time_for_pred
## Deploy Gradio App
# Example inputs shown under the interface. These are file paths (relative to
# the working directory) that Gradio loads when an example is clicked.
sample_image1 = "sample_images/JABIRU.jpg"
sample_image2 = "sample_images/APAPANE.jpg"
sample_image3 = "sample_images/LITTLE_AUK.jpg"

# Create title, description and article strings
title = "Birds Prediction 🕊️🦅🦆"
description = (
    "An EfficientNetB2 <b>Computer Vision</b> Model to Classify Birds. "
    "<br>The model achieved <b>93% accuracy</b> on the validation set. "
    "<br>To upload your own photo you can check first in the classes.py file "
    "the different <b>525 classes</b> trained with this model."
)
article = "Created by Benito Martin"

# Create the Gradio demo. predict returns (class name, probability, seconds),
# which map one-to-one onto the three output components below.
demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(type="pil", label="Upload Your Image"),
    ],
    outputs=[
        # predict returns a single class-name string, so the Label shows only
        # that class; num_top_classes only takes effect for a
        # {label: confidence} dict output.
        gr.Label(num_top_classes=525, label="Predicted Bird"),
        gr.Number(label="Prediction Probability"),
        gr.Number(label="Prediction Time (s)"),
    ],
    examples=[sample_image1, sample_image2, sample_image3],
    title=title,
    description=description,
    article=article,
)

# Launch only when run as a script (e.g. ``python app.py``), so importing this
# module does not start a server.
# NOTE(review): share=True also creates a public share link when run locally —
# confirm that is intended.
if __name__ == "__main__":
    demo.launch(debug=True, share=True)