Spaces:
Sleeping
Sleeping
| from timeit import default_timer as timer | |
| import gradio as gr | |
| import torch | |
| from model import create_effnetb2_model | |
| from classes import class_names | |
# Build the EfficientNet-B2 model (525-way classifier head) together with the
# preprocessing transforms that the project helper returns alongside it.
# NOTE(review): seed=42 presumably makes the head initialisation reproducible —
# confirm in model.create_effnetb2_model.
effnetb2, effnetb2_transforms = create_effnetb2_model(
    seed=42,
    num_classes=525,
)
# Load the Best Model (effnetb2 with 10 epochs)
def load_best_effnetb2_model(model_path, *, model=None):
    """Load saved weights into an EfficientNet-B2 model and switch it to eval mode.

    Args:
        model_path: Path to a file containing a ``state_dict`` saved with
            ``torch.save`` (it is passed straight to ``load_state_dict``).
        model: Optional module to load the weights into. When omitted, the
            module-level ``effnetb2`` instance is used and modified in place;
            no new model object is created.

    Returns:
        The same module object, with the weights loaded and ``eval()`` applied.
    """
    if model is None:
        model = effnetb2

    # map_location lets a checkpoint saved on GPU load on a CPU-only host.
    # weights_only=True restricts unpickling to tensors and plain containers,
    # which is all a state_dict needs. It also makes the behaviour explicit
    # across PyTorch versions (the default flipped to True in 2.6; earlier
    # 2.x releases warn when the argument is omitted).
    state_dict = torch.load(
        model_path,
        map_location=torch.device("cpu"),
        weights_only=True,
    )
    model.load_state_dict(state_dict)

    # Put layers such as dropout / batch-norm into inference behaviour
    model.eval()
    return model
# Path to the saved state_dict of the best run (EffNetB2 trained for 10 epochs).
# NOTE: this relative path is resolved against the current working directory.
best_model_path = "best_effnetb2_10_epochs.pth"

# Load the trained weights (the loader also switches the model to eval mode)
best_model = load_best_effnetb2_model(best_model_path)

# Serve on CPU. nn.Module.to() moves parameters in place and returns the same
# module, so no rebinding is needed.
best_model.to("cpu")
# Prediction Function for the Gradio App
def predict(input_image, model=best_model, transform=effnetb2_transforms, class_names=class_names):
    """Classify one bird image and report the top-1 result with its latency.

    Args:
        input_image: Image from the ``gr.Image(type="pil")`` input (a PIL
            image), or ``None`` if the user submitted without uploading one.
        model: Classifier to run. Defaults to the loaded ``best_model``.
        transform: Preprocessing applied before the forward pass. Defaults to
            ``effnetb2_transforms``.
        class_names: Indexable sequence mapping class indices to class names.

    Returns:
        Tuple ``(pred_class, pred_prob, time_for_pred)`` with these items:
        the top-1 class name; its softmax probability rounded to 4 decimals;
        the wall-clock prediction time in seconds, rounded to 4 decimals.

    Raises:
        gr.Error: If ``input_image`` is ``None``. Gradio shows the message to
            the user instead of a stack trace.
    """
    if input_image is None:
        raise gr.Error("Please upload an image before submitting.")

    start_time = timer()

    # PNG uploads are often RGBA, and some photos are grayscale or palette
    # images. The model is assumed to take 3-channel RGB input (standard for
    # EfficientNet-B2), so normalise the mode of PIL inputs. Non-PIL inputs
    # are passed through unchanged.
    if hasattr(input_image, "convert"):
        input_image = input_image.convert("RGB")

    # Preprocess and add a leading batch dimension of size 1
    batch = transform(input_image).unsqueeze(0)

    model.eval()
    with torch.inference_mode():
        pred_logit = model(batch)
        # Softmax over the class axis (dim=1) to get probabilities
        pred_prob = torch.softmax(pred_logit, dim=1)
        # .item() turns the single-element index tensor into a plain int,
        # which is what a Python sequence expects as an index.
        pred_idx = pred_prob.argmax(dim=1).item()
        pred_class = class_names[pred_idx]
        top_prob = round(pred_prob.max().item(), 4)

    time_for_pred = round(timer() - start_time, 4)
    return pred_class, top_prob, time_for_pred
## Deploy Gradio App
# Example inputs listed under the interface. These are image file paths
# (relative to the working directory), not PIL.Image objects; Gradio loads
# them when an example is selected.
sample_image1, sample_image2, sample_image3 = (
    f"sample_images/{bird}.jpg" for bird in ("JABIRU", "APAPANE", "LITTLE_AUK")
)

# Page text: title, HTML description and footer article.
# NOTE(review): the emoji in the title look mis-encoded (UTF-8 mojibake);
# verify against the original file.
title = "Birds Prediction ποΈπ¦ π¦"
description = (
    "An EfficientNetB2 <b>Computer Vision</b> Model to Classify Birds. "
    "<br>The model achieved <b>93% accuracy</b> on the validation set. "
    "<br>To upload your own photo you can check first in the classes.py file "
    "the different <b>525 classes</b> trained with this model."
)
article = "Created by Benito Martin"

# Single image input, handed to ``predict`` as a PIL image
image_input = gr.Image(type="pil", label="Upload Your Image")

# Outputs map positionally onto predict's three return values.
# NOTE(review): predict returns a plain class-name string for the Label, so
# num_top_classes has no visible effect unless it returns a {class: prob} dict.
result_outputs = [
    gr.Label(num_top_classes=525, label="Predicted Bird"),
    gr.Number(label="Prediction Probability"),
    gr.Number(label="Prediction Time (s)"),
]

demo = gr.Interface(
    fn=predict,
    inputs=[image_input],
    outputs=result_outputs,
    examples=[sample_image1, sample_image2, sample_image3],
    title=title,
    description=description,
    article=article,
)

# Launch the demo.
# NOTE(review): share=True only matters for local runs; hosted platforms such
# as Hugging Face Spaces presumably ignore it with a warning — confirm.
demo.launch(debug=True, share=True)