### 1. Imports and class names setup ###
import gradio as gr
import torch
from pathlib import Path
from zipfile import ZipFile
from model import create_effnetb2_model
from timeit import default_timer as timer
from typing import Tuple, Dict

# Setup class names
class_names = ['pizza', 'steak', 'sushi']

### 2. Handle examples.zip ###
# Define the zip file and the target extraction folder
zip_file_path = Path("examples.zip")
extracted_folder_path = Path("examples")

# Extract .zip file if it exists and is not already extracted
if zip_file_path.exists() and not extracted_folder_path.exists():
    print(f"Extracting {zip_file_path} to {extracted_folder_path}...")
    with ZipFile(zip_file_path, "r") as zf:
        zf.extractall(extracted_folder_path)
    print(f"Extraction complete. Files extracted to {extracted_folder_path}.")
else:
    print(f"ZIP file not found or examples folder already exists.")

### 3. Model and transforms preparation ###
effnetb2, effnetb2_transforms = create_effnetb2_model(num_classes=3)

# Load saved weights
effnetb2.load_state_dict(
    torch.load(
        f="09_pretrained_effnetb2_feature_extractor_pizza_steak_sushi_20_percent.pth",
        map_location=torch.device("cpu")  # Load the model to the CPU
    )
)
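
# Note: create_effnetb2_model comes from the local model.py (not shown here).
# A minimal sketch of what it is assumed to look like, following the standard
# torchvision EffNetB2 feature-extractor pattern (the seed and dropout values
# are assumptions, not confirmed by this file):
#
#   import torchvision
#   from torch import nn
#
#   def create_effnetb2_model(num_classes: int = 3, seed: int = 42):
#       weights = torchvision.models.EfficientNet_B2_Weights.DEFAULT
#       transforms = weights.transforms()  # preprocessing matching the pretrained weights
#       model = torchvision.models.efficientnet_b2(weights=weights)
#       for param in model.parameters():
#           param.requires_grad = False  # freeze the base layers (feature extractor)
#       torch.manual_seed(seed)
#       model.classifier = nn.Sequential(
#           nn.Dropout(p=0.3, inplace=True),
#           nn.Linear(in_features=1408, out_features=num_classes),  # 1408 = EffNetB2 feature dim
#       )
#       return model, transforms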

### 4. Predict function ###

def predict(img) -> Tuple[Dict, float]:
    """Transforms and performs a prediction on img (a PIL image, as provided by
    gr.Image(type="pil")) and returns the prediction dict and the time taken."""
    # Start a timer
    start_time = timer()

    # Transform the input image for use with EffNetB2
    img = effnetb2_transforms(img).unsqueeze(0)  # unsqueeze = add batch dimension on 0th index

    # Put model into eval mode, make prediction
    effnetb2.eval()
    with torch.inference_mode():
        # Pass transformed image through the model and turn the prediction logits into probabilities
        pred_probs = torch.softmax(effnetb2(img), dim=1)

    # Create a prediction label and prediction probability dictionary
    pred_labels_and_probs = {class_names[i]: float(pred_probs[0][i]) for i in range(len(class_names))}

    # Calculate prediction time
    end_time = timer()
    pred_time = round(end_time - start_time, 4)

    # Return pred dict and pred time
    return pred_labels_and_probs, pred_time
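
# Hypothetical standalone usage of predict (outside the Gradio UI), assuming an
# image exists at the given path:
#
#   from PIL import Image
#   img = Image.open("examples/some_pizza.jpg")  # hypothetical path
#   pred_dict, pred_time = predict(img)
#   print(pred_dict)  # e.g. {'pizza': 0.97, 'steak': 0.02, 'sushi': 0.01}
#   print(pred_time)  # e.g. 0.0421 (seconds)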

### 5. Gradio app ###

# Create title, description, and article
title = "Food Extractor 🍕🥩🍣"
description = "An [EfficientNetB2 feature extractor](https://pytorch.org/vision/stable/models/generated/torchvision.models.efficientnet_b2.html#torchvision.models.efficientnet_b2) computer vision model to classify images as pizza, steak or sushi."
article = "Created by [Prof. Sajad Ahmad Rather, IIT Roorkee, PARIMAL LAB](https://github.com/SajadAHMAD1)."

# Create the example list (Gradio expects a list of lists: one inner list per input component)
example_list = [[str(filepath)] for filepath in sorted(extracted_folder_path.glob("*")) if filepath.is_file()]
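# With a populated examples/ folder, example_list looks something like
# [["examples/image_1.jpg"], ["examples/image_2.jpg"], ...] (hypothetical filenames)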

# Create the Gradio demo
demo = gr.Interface(fn=predict,  # Maps inputs to outputs
                    inputs=gr.Image(type="pil"),
                    outputs=[gr.Label(num_top_classes=3, label="Predictions"),
                             gr.Number(label="Prediction time (s)")],
                    examples=example_list,
                    title=title,
                    description=description,
                    article=article)

# Launch the demo!
demo.launch(debug=False)  # Don't need share=True in Hugging Face Spaces