# Hugging Face Spaces status: Sleeping
# Import and class names setup
import os
from timeit import default_timer as timer
from typing import Dict, Tuple

import gradio as gr
import torch

from model import create_effnetb2_model

# Setup class names: one label per non-empty line of class_names.txt.
# The order matters: predict() maps output index i to class_names[i], so an
# accidental blank line (e.g. a trailing empty line) must not become a class.
with open('class_names.txt', 'r', encoding='utf-8') as f:
    stripped_lines = (line.strip() for line in f)
    class_names = [food_name for food_name in stripped_lines if food_name]
# Model and transforms preparation: the factory returns the network and the
# preprocessing transform that predict() applies to incoming images.
# NOTE(review): assumes create_effnetb2_model() builds an output head whose
# size matches len(class_names) — confirm against model.py.
effnetb2_model, effnetb2_transform = create_effnetb2_model()

# Load the trained weights, mapping every tensor onto the CPU so a checkpoint
# saved on a GPU machine still loads on a CPU-only host.
# NOTE(review): torch.load may unpickle arbitrary objects (depending on the
# installed torch version's weights_only default); fine for the app's own
# bundled checkpoint, never point it at an untrusted file.
_checkpoint = torch.load(
    'effnetb2_food101_model.pth',
    map_location=torch.device('cpu'),
)
effnetb2_model.load_state_dict(_checkpoint)
del _checkpoint  # the weights now live in the model; drop the extra reference
# Predict function
def predict(img) -> Tuple[Dict, float]:
    """Classify one image and report how long the prediction took.

    Args:
        img: The value delivered by the ``gr.Image(type='pil')`` input (a PIL
            image); it may be ``None`` when the input is left empty.

    Returns:
        A ``(label_to_prob, pred_time)`` pair: a dict mapping each class name
        to its softmax probability, and the elapsed time in seconds (measured
        with ``timeit.default_timer``) rounded to 4 decimal places.

    Raises:
        gr.Error: If no image was provided.
        ValueError: If the number of model outputs does not match the number
            of class names (a configuration mismatch).
    """
    if img is None:
        # Fail with a readable message in the UI instead of a stack trace
        # from inside the transform.
        raise gr.Error('Please upload an image to classify.')

    # Start a timer
    start_time = timer()

    # Transform the input image for EfficientNet-B2 and add a batch dim of 1.
    transformed_image = effnetb2_transform(img).unsqueeze(0)

    # Put the model into eval mode and predict without autograd tracking.
    effnetb2_model.eval()
    with torch.inference_mode():
        pred_logits = effnetb2_model(transformed_image)
        pred_probs = torch.softmax(pred_logits, dim=1)

    # Convert the single row of probabilities to Python floats in one call,
    # rather than indexing the tensor once per class.
    probs = pred_probs[0].tolist()
    if len(probs) != len(class_names):
        raise ValueError(
            f'model produced {len(probs)} outputs but there are '
            f'{len(class_names)} class names'
        )
    pred_label_and_prob = dict(zip(class_names, probs))

    # Calculate the prediction time
    pred_time = round(timer() - start_time, 4)

    # Return the prediction dict and prediction time
    return pred_label_and_prob, pred_time
# Create example list
def list_example_paths(directory='example'):
    """Return Gradio example rows, one ``[path]`` per file in *directory*.

    Hidden entries (names starting with ``.``, e.g. ``.DS_Store``) and
    sub-directories are skipped. Rows are sorted by file name because
    ``os.listdir`` returns entries in arbitrary order, which would otherwise
    shuffle the example gallery between deployments. A missing directory
    yields an empty list instead of crashing the app at startup.
    """
    if not os.path.isdir(directory):
        return []
    candidate_paths = (
        os.path.join(directory, name)
        for name in sorted(os.listdir(directory))
        if not name.startswith('.')
    )
    return [[path] for path in candidate_paths if os.path.isfile(path)]


example_list = list_example_paths()
# Create gradio app text.
# NOTE(review): the title's emojis had been mis-encoded (UTF-8 bytes decoded
# as ISO-8859-7, showing up as "π..."). The 2nd and 3rd decode unambiguously
# to CUT OF MEAT and SUSHI; the 1st was unrecoverable and HAMBURGER is an
# assumption — confirm. Escapes are used so a re-encoding can't garble them.
title = 'FoodVision Large \U0001F354\U0001F969\U0001F363 '
description = 'An EfficientnetB2 feature extractor Computer vision model to classify 101 classes of food from the food 101 image dataset'
article = 'Created at [09. PyTorch Model Deployment](https://www.learnpytorch.io/09_pytorch_model_deployment/).'
# Create the gradio demo. Components are built up front, in the same order
# the original inline call created them (input first, then the two outputs).
image_input = gr.Image(type='pil')
prediction_label = gr.Label(num_top_classes=5, label='predictions')
timing_readout = gr.Number(label='Prediction time (S)')

demo = gr.Interface(
    fn=predict,
    inputs=image_input,
    outputs=[prediction_label, timing_readout],
    examples=example_list,
    title=title,
    description=description,
    article=article,
)

# Start the Gradio server for the demo.
demo.launch()