File size: 3,702 Bytes
f8d6043
611f039
 
 
7e2639e
611f039
d455cf3
 
 
 
 
 
 
 
 
7785fab
611f039
d455cf3
d797dbc
611f039
d455cf3
 
611f039
d455cf3
 
 
 
 
 
 
 
fdcd960
 
d455cf3
 
14f48f5
611f039
 
d455cf3
7e2639e
d455cf3
 
 
 
 
611f039
d455cf3
 
 
 
 
 
 
 
fdcd960
d455cf3
 
 
 
611f039
d455cf3
 
 
 
 
 
 
 
 
611f039
d455cf3
 
f8d6043
c77a3a5
 
 
f8d6043
d455cf3
 
408c8ec
2bc133e
 
7e1d2bd
d455cf3
611f039
 
d455cf3
f8d6043
ff8b2d0
 
d455cf3
 
 
 
 
c77a3a5
611f039
 
 
 
 
b44f956
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
from timeit import default_timer as timer
import gradio as gr
import torch
from model import create_effnetb2_model
from classes import class_names

# Build the EffNetB2 model and the transforms object returned alongside it.
# The model is created for 525 output classes with a fixed seed of 42.
effnetb2, effnetb2_transforms = create_effnetb2_model(num_classes=525, seed=42)

# Load the Best Model (effnetb2 with 10 epochs)

def load_best_effnetb2_model(model_path):
    """Load saved weights into the module-level EffNetB2 model.

    Note that no new model is created: the shared ``effnetb2`` instance is
    updated in place and that same object is returned.

    Args:
        model_path: Path to a file containing a saved ``state_dict``.

    Returns:
        The ``effnetb2`` model with the loaded weights, in evaluation mode.
    """
    # Reuse the already-constructed network rather than building a new one
    net = effnetb2

    # Deserialize the checkpoint with its tensors mapped onto the CPU
    cpu_device = torch.device('cpu')
    state_dict = torch.load(model_path, map_location=cpu_device)
    net.load_state_dict(state_dict)

    # Switch to evaluation mode before handing the model back
    net.eval()
    return net

# Path to the saved state dictionary of the best model
best_model_path = "best_effnetb2_10_epochs.pth"

# Load the saved weights (the loader also puts the model in eval mode)
best_model = load_best_effnetb2_model(best_model_path)

# Make sure the model lives on the CPU for inference
best_model.to("cpu")


# Prediction Function for the Gradio App
def predict(input_image, model=best_model, transform=effnetb2_transforms, class_names=class_names):
    """Classify one image and report the top class, its probability and timing.

    Args:
        input_image: Image to classify, in whatever form ``transform``
            accepts. ``None`` (nothing uploaded) is rejected with a
            user-visible Gradio error.
        model: Callable that maps the batched, transformed image to class
            logits (classes are expected along dim 1).
        transform: Preprocessing callable applied to ``input_image``; its
            result must support ``.unsqueeze(0)``.
        class_names: Sequence of class labels indexed by predicted class id.

    Returns:
        A tuple ``(pred_class, pred_prob, time_for_pred)``: the predicted
        label, its softmax probability rounded to 4 decimals, and the elapsed
        time in seconds rounded to 4 decimals.

    Raises:
        gr.Error: If ``input_image`` is ``None``.
    """
    # Gradio passes None when the user submits without an image; fail with a
    # readable message instead of an opaque error from inside the transform.
    if input_image is None:
        raise gr.Error("Please upload an image before requesting a prediction.")

    # Time the whole prediction, preprocessing included
    start_time = timer()

    # Preprocess and add the batch dimension the model expects
    batch = transform(input_image).unsqueeze(0)

    # Make sure layers such as dropout/batch-norm run in inference behaviour
    model.eval()

    with torch.inference_mode():
        pred_logit = model(batch)                      # raw class scores
        pred_prob = torch.softmax(pred_logit, dim=1)   # scores -> probabilities
        pred_label = torch.argmax(pred_prob, dim=1)    # most likely class id

    # .item() yields plain Python numbers, so indexing does not rely on
    # implicit tensor-to-index conversion and works on any device.
    pred_class = class_names[pred_label.item()]
    top_prob = round(pred_prob.max().item(), 4)

    time_for_pred = round(timer() - start_time, 4)

    return pred_class, top_prob, time_for_pred

## Deploy Gradio App

# Example images offered in the demo, given as relative file paths
sample_image1 = "sample_images/JABIRU.jpg"
sample_image2 = "sample_images/APAPANE.jpg"
sample_image3 = "sample_images/LITTLE_AUK.jpg"

# Title, description and article strings shown by the Gradio interface.
# The title ends with dove, eagle and duck emoji (stored as UTF-8).
title = "Birds Prediction 🕊️🦅🦆"
description = "An EfficientNetB2 <b>Computer Vision</b> Model to Classify Birds. \
               <br>The model achieved <b>93% accuracy</b> on the validation set. \
               <br>To upload your own photo you can check first in the classes.py file \
               the different <b>525 classes</b> trained with this model."
article = "Created by Benito Martin"

# Assemble the Gradio interface: one image input feeding predict(), whose
# three return values map onto the three output components in order.
demo = gr.Interface(
    fn=predict,
    inputs=[gr.Image(type="pil", label="Upload Your Image")],
    outputs=[
        gr.Label(num_top_classes=525, label="Predicted Bird"),
        gr.Number(label="Prediction Probability"),
        gr.Number(label="Prediction Time (s)"),
    ],
    examples=[sample_image1, sample_image2, sample_image3],
    title=title,
    description=description,
    article=article,
)

# Start the app in debug mode and request a shareable link
demo.launch(debug=True, share=True)