Spaces:
Sleeping
Sleeping
from timeit import default_timer as timer | |
import gradio as gr | |
import torch | |
from model import create_effnetb2_model | |
from classes import class_names | |
# Build the EfficientNet-B2 model and its accompanying image transforms using
# the project-local `create_effnetb2_model` helper.
# NOTE(review): 525 presumably matches len(class_names) — confirm.
effnetb2, effnetb2_transforms = create_effnetb2_model(num_classes=525, seed=42)
# Load the best model (effnetb2 trained for 10 epochs)
def load_best_effnetb2_model(model_path):
    """Load saved weights into the shared EfficientNet-B2 model and return it.

    The module-level ``effnetb2`` instance is reused rather than a new model
    being built, so the returned object is that same instance with the saved
    weights loaded into it.

    Args:
        model_path: Location of the saved ``state_dict``; handed to
            ``torch.load``.

    Returns:
        The ``effnetb2`` model after ``eval()`` has been called on it.
    """
    network = effnetb2

    # Map all stored tensors onto the CPU while deserialising the weights.
    cpu_device = torch.device('cpu')
    saved_weights = torch.load(model_path, map_location=cpu_device)
    network.load_state_dict(saved_weights)

    # Switch to evaluation mode before handing the model back.
    network.eval()
    return network
# Path to the saved model state dictionary.
best_model_path = "best_effnetb2_10_epochs.pth"

# Load the saved weights into the model (returned in evaluation mode).
best_model = load_best_effnetb2_model(best_model_path)

# Make sure the model lives on the CPU for inference.
best_model.to("cpu")
# Prediction function for the Gradio app
def predict(input_image, model=best_model, transform=effnetb2_transforms, class_names=class_names):
    """Classify one image and report the top class, its probability and timing.

    Args:
        input_image: Image to classify; passed unchanged to ``transform``.
        model: Model called on the transformed image after a batch dimension
            has been added.
        transform: Callable applied to ``input_image``; its result must
            support ``unsqueeze(0)``.
        class_names: Collection of class labels, indexed with the predicted
            label.

    Returns:
        A tuple ``(pred_class, pred_prob, time_for_pred)``: the predicted
        class, the highest softmax probability and the elapsed time in
        seconds, the latter two rounded to 4 decimal places.
    """
    started = timer()

    # Add a leading batch dimension to the transformed image.
    batch = transform(input_image).unsqueeze(0)

    model.eval()
    with torch.inference_mode():
        # Logits -> probabilities -> index of the most likely class.
        probabilities = torch.softmax(model(batch), dim=1)
        best_index = torch.argmax(probabilities, dim=1)
        predicted_class = class_names[best_index.cpu()]

    confidence = round(probabilities.max().cpu().item(), 4)

    # Stop the clock only after the result values have been prepared.
    elapsed = round(timer() - started, 4)
    return predicted_class, confidence, elapsed
## Deploy Gradio App

# File paths of the example images offered in the demo.
sample_image1 = "sample_images/JABIRU.jpg"
sample_image2 = "sample_images/APAPANE.jpg"
sample_image3 = "sample_images/LITTLE_AUK.jpg"

# Text shown in the interface: title, HTML description and footer article.
# NOTE(review): the trailing characters of the title look like mis-decoded
# emoji — confirm the intended characters against the original file.
title = "Birds Prediction ποΈπ¦ π¦"
description = (
    "An EfficientNetB2 <b>Computer Vision</b> Model to Classify Birds. "
    "<br>The model achieved <b>93% accuracy</b> on the validation set. "
    "<br>To upload your own photo you can check first in the classes.py file "
    "the different <b>525 classes</b> trained with this model."
)
article = "Created by Benito Martin"

# Wire `predict` to one image input and three outputs that line up with its
# return tuple: predicted class, probability and prediction time.
demo = gr.Interface(
    fn=predict,
    inputs=[gr.Image(type="pil", label="Upload Your Image")],
    outputs=[
        gr.Label(num_top_classes=525, label="Predicted Bird"),
        gr.Number(label="Prediction Probability"),
        gr.Number(label="Prediction Time (s)"),
    ],
    title=title,
    description=description,
    article=article,
    examples=[sample_image1, sample_image2, sample_image3],
)

# Start the app.
demo.launch(debug=True, share=True)