# app.py — commit b283dd7 "Update app.py" by waleedgeorgy (verified; 1.87 kB; scanned: no virus).
import torch
import torchvision
import random
import gradio as gr
import os
from torch import nn
from typing import Tuple, Dict
from timeit import default_timer as timer
from model import create_effnetb2
# Hardcoding the class names; the model's output index i corresponds to class_names[i].
class_names = ['pizza', 'steak', 'sushi']

# Creating the EffnetB2 model and its matching input transforms.
effnetb2, effnetb2_transforms = create_effnetb2(num_class=len(class_names),
                                                seed=42)

# Load the trained weights, mapping every tensor onto the CPU so the app
# does not require a GPU.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints;
# consider weights_only=True if the deployed PyTorch version supports it.
effnetb2.load_state_dict(torch.load(f='effnetb2.pth',
                                    map_location=torch.device('cpu')))

# Example images for the Gradio interface, one single-input row per file.
# os.listdir returns entries in arbitrary order, so sort for a stable gallery,
# and skip hidden files (e.g. .DS_Store) that are not example images.
example_list = [[os.path.join('examples', example)]
                for example in sorted(os.listdir('examples'))
                if not example.startswith('.')]
# Defining the predict function
def predict(img) -> Tuple[Dict[str, float], float]:
    """Classify a food image and report how long the prediction took.

    Args:
        img: The input image. The Gradio interface in this app supplies it as
            a PIL image (``gr.Image(type='pil')``); it is passed unchanged to
            ``effnetb2_transforms``.

    Returns:
        A tuple ``(pred_labels_probs, pred_time)``. ``pred_labels_probs`` maps
        each name in ``class_names`` to its softmax probability.
        ``pred_time`` is the elapsed wall-clock time of the call in seconds,
        rounded to 4 decimal places.
    """
    start_time = timer()
    # Add a leading batch dimension so the model sees a batch of one image.
    transformed_img = effnetb2_transforms(img).unsqueeze(0)
    # Switch to evaluation mode (affects layers such as dropout/batch-norm,
    # if the model has them) before inference.
    effnetb2.eval()
    with torch.inference_mode():
        y_logits = effnetb2(transformed_img)
        y_preds = torch.softmax(y_logits, dim=1)
    # Pair each class name with the probability at the same index of the
    # single batch row; .tolist() yields plain Python floats for gr.Label.
    pred_labels_probs = dict(zip(class_names, y_preds[0].tolist()))
    pred_time = round(timer() - start_time, 4)
    return pred_labels_probs, pred_time
# Creating the gradio app: one image input, and two outputs (the class
# probabilities and the prediction time).
title = 'FoodVision Mini'
description = 'An EfficientNetB2 feature extractor CV model to classify images of food as pizza, steak or sushi.'
demo = gr.Interface(fn=predict,
                    inputs=gr.Image(type='pil'),
                    outputs=[gr.Label(num_top_classes=len(class_names),
                                      label='Prediction Probabilities'),
                             gr.Number(label='Prediction Time (s)')],
                    examples=example_list,
                    title=title,
                    description=description)

# Start the server only when executed as a script (`python app.py`); importing
# this module (e.g. from tests) then just builds `demo` without launching it.
if __name__ == "__main__":
    demo.launch(share=True)