OnabajoMonsurat's picture
Change typings-typing library
dda10fc
# Import and class names setup
import gradio as gr
import os
import torch
from model import create_effnetb2_model
from timeit import default_timer as timer
from typing import Tuple, Dict
# Labels the classifier can emit; predict() pairs class_names[i] with the
# i-th probability produced by the model.
class_names = ['pizza', 'steak', 'sushi']

# Build the EfficientNet-B2 model together with its matching input transform.
effnetb2_model, effnetb2_transform = create_effnetb2_model()

# Restore the trained weights, mapping every tensor onto the CPU.
# NOTE(review): torch.load unpickles the file, so only load checkpoints that
# come from a trusted source.
effnetb2_model.load_state_dict(
    torch.load(
        f='effnet_b2_model.pth',
        map_location=torch.device('cpu'),
    )
)
# Predict function
def predict(img) -> Tuple[Dict[str, float], float]:
    """Classify an image and report how long the prediction took.

    Args:
        img: Image to classify, in a form ``effnetb2_transform`` accepts
            (the Gradio interface in this file supplies a PIL image).

    Returns:
        A tuple ``(pred_label_and_prob, pred_time)``.
        ``pred_label_and_prob`` maps every name in ``class_names`` to its
        softmax probability; ``pred_time`` is the elapsed wall-clock time in
        seconds, rounded to 4 decimal places.

    Raises:
        gr.Error: If ``img`` is ``None``.
    """
    # Gradio hands over None when the form is submitted without an image;
    # fail with a readable message instead of an error inside the transform.
    if img is None:
        raise gr.Error('Please provide an image to classify.')

    # Start a timer
    start_time = timer()

    # Transform the input image and add a leading batch dimension
    transform_image = effnetb2_transform(img).unsqueeze(0)

    # Put the model into eval mode and run the forward pass without autograd
    effnetb2_model.eval()
    with torch.inference_mode():
        pred_logits = effnetb2_model(transform_image)
        pred_prob = torch.softmax(pred_logits, dim=1)

    # Map each class name to the probability at the same index (first and
    # only item of the batch)
    pred_label_and_prob = {
        name: float(pred_prob[0][i]) for i, name in enumerate(class_names)
    }

    # Compute the elapsed prediction time
    stop_time = timer()
    pred_time = round(stop_time - start_time, 4)

    return pred_label_and_prob, pred_time
# Build the example list: one single-element row per entry in the 'example'
# directory. os.listdir returns entries in arbitrary order, so sort them to
# keep the order of the examples stable between runs.
example_list = [
    [os.path.join('example', example)]
    for example in sorted(os.listdir('example'))
]
# Text passed to the Gradio interface (title, description, article).
# The title's emoji were stored as mojibake (UTF-8 bytes decoded with a
# single-byte code page); they are restored to pizza, steak and sushi here.
title = 'FoodVision Mini 🍕🥩🍣 '
description = 'An EfficientnetB2 feature extractor Computer vision model to classify image as pizza, steak or sushi'
article = 'Created at [09. PyTorch Model Deployment](https://www.learnpytorch.io/09_pytorch_model_deployment/).'
# Assemble the Gradio interface: a PIL image goes into predict(), and its two
# return values feed a label widget and a number widget respectively.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type='pil'),
    outputs=[
        gr.Label(num_top_classes=3, label='predictions'),
        gr.Number(label='Prediction time (S)'),
    ],
    examples=example_list,
    title=title,
    description=description,
    article=article,
)

# Start serving the app
demo.launch(debug=False)