FoodVision_Big / app.py
DiabeticOwl's picture
Updating deprecated Image input parameter.
c00f0b4 verified
import torch
import gradio as gr
from model import create_effnetb2_model
from pathlib import Path
from timeit import default_timer as timer
from typing import Tuple, Dict
# Load the class labels, one per line. Blank lines are skipped so that a
# stray empty line (e.g. at the end of the file) cannot add an empty label
# and inflate NUM_CLASSES.
with open('class_names.txt', 'r', encoding='utf-8') as f:
    CLASS_NAMES = [name for name in (line.strip() for line in f) if name]
NUM_CLASSES = len(CLASS_NAMES)

# Run on the GPU when one is available, otherwise fall back to the CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

MODEL, MODEL_TRANSFORMS = create_effnetb2_model(NUM_CLASSES, load_st_dict=True)
# `predict` moves its input tensor to DEVICE, so the model has to live on the
# same device; otherwise the forward pass fails with a device mismatch on a
# CUDA host. `Module.to` is harmless when the model is already there.
MODEL.to(DEVICE)
# Minimum top-class probability; anything below it is treated as an image the
# model did not understand.
MIN_CONFIDENCE = .15


def predict(img) -> Tuple[Dict[str, float], float]:
    """Classify an image and measure how long the prediction took.

    Args:
        img: The image to classify. It is passed straight to
            ``MODEL_TRANSFORMS``. ``None`` (no image supplied) is rejected.

    Returns:
        A tuple ``(pred, seconds)`` where ``pred`` maps every name in
        ``CLASS_NAMES`` to its softmax probability and ``seconds`` is the
        elapsed wall-clock time of the transform + forward pass.

    Raises:
        gr.Error: If no image was provided, or if the highest class
            probability is below ``MIN_CONFIDENCE``.
    """
    # Submitting without an image would otherwise crash inside the transform
    # with an opaque error; report it to the user instead.
    if img is None:
        raise gr.Error("Please provide an image to classify.")

    start = timer()

    # Add a batch dimension and put the input on the inference device.
    batch = MODEL_TRANSFORMS(img).unsqueeze(dim=0).to(DEVICE)

    MODEL.eval()
    with torch.inference_mode():
        y_logits = MODEL(batch)
    y_prob = torch.softmax(y_logits, dim=1)

    # Prediction threshold for non understood images.
    if y_prob.max().item() < MIN_CONFIDENCE:
        raise gr.Error("This image might not be of food.")

    # `tolist()` yields plain Python floats. Gradio assumes numpy objects
    # should be iterated, so numpy scalars must not end up in the dict.
    probs = y_prob.squeeze(dim=0).tolist()
    pred = dict(zip(CLASS_NAMES, probs))

    end = timer()

    return pred, end - start
# UI text shown by the Gradio interface.
# The two emoji (U+1F356 "meat on bone", U+2753 "question mark") are written
# as escapes so the title cannot be corrupted by a wrong file decoding again.
# NOTE(review): the code points were reconstructed from mis-decoded bytes;
# confirm they are the intended emoji.
title = "FoodVision Big \U0001F356\u2753"
description = (
    "An [EfficientNetB2](https://pytorch.org/vision/stable/models/generated/torchvision.models.efficientnet_b2.html) "
    "feature extraction that can classify images of 101 classes of food images."
)
article = (
    "Created by [DiabeticOwl](https://huggingface.co/DiabeticOwl). "
    "Uses the [Food101 dataset](https://data.vision.ee.ethz.ch/cvl/datasets_extra/food-101/)."
)
# Example images offered below the UI. `Path.glob` yields files in arbitrary
# order, so they are sorted for a stable listing, and each one is passed as a
# one-element row holding a string path (one entry per input component).
example_list = [[str(path)] for path in sorted(Path('examples').glob('*.jpg'))]

# One image in; top-5 class probabilities and the run time out.
demo = gr.Interface(
    fn=predict,
    # The former `shape=(288, 288)` argument is deprecated in Gradio and was
    # dropped; resizing is presumably handled by MODEL_TRANSFORMS - confirm.
    inputs=gr.Image(type='pil'),
    outputs=[
        gr.Label(num_top_classes=5, label='Predictions'),
        gr.Number(label='Prediction run time (s)'),
    ],
    examples=example_list,
    title=title,
    description=description,
    article=article,
)

# Start the web app only when this file is run as a script, not on import.
if __name__ == '__main__':
    demo.launch()