# Food_Vision_Big / app.py
# v0idcr0w's picture
# Upload 6 files
# bfd7be3
# raw
# history blame contribute delete
# No virus
# 2.06 kB
# Imports
import torch
import gradio as gr
from typing import Tuple, Dict
from timeit import default_timer as timer
import model
import os
# Class names: one label per line of class_names.txt. The i-th label names the
# model's i-th output (predict() pairs them by index), so line order matters.
with open("class_names.txt", "r", encoding="utf-8") as f:
    # Strip newlines/whitespace and skip blank lines (e.g. an extra trailing
    # empty line), which would otherwise add a bogus "" class and change the
    # num_classes value derived from len(class_names).
    class_names = [food.strip() for food in f if food.strip()]
# Create instance of model: the factory hands back the network together with
# the preprocessing transforms it expects; num_classes comes from the number
# of class names read above.
effnetb2, effnetb2_transforms = model.create_effnetb2_model(
    num_classes=len(class_names)
)

# Load the trained weights from disk. map_location pins every tensor to the
# CPU, so the app also starts on hosts without a GPU.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
effnetb2.load_state_dict(
    state_dict=torch.load(
        "effnetb2_food101_20pct.pth",
        map_location=torch.device("cpu"),
    )
)
# Predict function
def predict(img) -> Tuple[Dict, float]:
    """Classify one image with the EfficientNetB2 model and time the call.

    Args:
        img: The image supplied by the Gradio ``Image`` input (configured with
            ``type="pil"``, so presumably a PIL image — confirm if reused).

    Returns:
        ``(pred_dict, delta_time)``: ``pred_dict`` maps every class name to its
        softmax probability; ``delta_time`` is the elapsed time of the
        preprocessing + forward pass in seconds, rounded to 4 decimals.
    """
    start = timer()

    # Preprocess with the transforms returned alongside the model, then add a
    # leading batch dimension (a batch of one image).
    batch = effnetb2_transforms(img).unsqueeze(dim=0)

    effnetb2.eval()
    with torch.inference_mode():
        logits = effnetb2(batch)
        # Take the single batch row explicitly; a bare squeeze() would also
        # collapse the class axis if the model had only one class.
        pred_probs = torch.softmax(logits, dim=1)[0]

    end = timer()

    # One tolist() converts all probabilities to Python floats at once instead
    # of calling .item() per class; class_names[i] labels output i.
    pred_dict = dict(zip(class_names, pred_probs.tolist()))
    delta_time = round(end - start, 4)
    return pred_dict, delta_time
# Gradio app
# Create example list from the files in the local "examples" directory. Each
# entry is a one-element list because the interface has a single input.
# Sorted so the display order is stable (os.listdir order is arbitrary);
# hidden files such as .DS_Store are skipped; a missing directory just means
# no examples instead of a startup crash.
example_list = (
    [
        [os.path.join("examples", example)]
        for example in sorted(os.listdir("examples"))
        if not example.startswith(".")
    ]
    if os.path.isdir("examples")
    else []
)
title = "FoodVision Big"
description = "EfficientNetB2 feature extractor CV model to classify images of 101 types of food from the Food101 dataset."
article = "Created for PyTorch ZTM course"

demo = gr.Interface(
    fn=predict,  # maps inputs to outputs
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Label(num_top_classes=5, label="Predictions"),  # for the prediction dictionary
        gr.Number(label="Prediction time (s)"),
    ],
    # Pass None rather than an empty list when there are no examples.
    examples=example_list or None,
    title=title,
    description=description,
    article=article,
)
demo.launch()