Spaces:
Running
Running
import torch | |
import gradio as gr | |
from model import create_effnetb2_model | |
from pathlib import Path | |
from timeit import default_timer as timer | |
from typing import Tuple, Dict | |
# Class labels, one per line. `predict` zips these with the model's output
# probabilities, so their order must match the model's output order.
# Blank lines are skipped so a stray empty line cannot add a bogus class or
# inflate NUM_CLASSES (which is passed to `create_effnetb2_model`).
with open('class_names.txt', 'r', encoding='utf-8') as f:
    CLASS_NAMES = [line.strip() for line in f if line.strip()]
NUM_CLASSES = len(CLASS_NAMES)

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Build the EfficientNetB2 model and its matching preprocessing transforms.
# NOTE(review): `load_st_dict=True` presumably loads trained weights inside
# `create_effnetb2_model` — confirm in model.py.
MODEL, MODEL_TRANSFORMS = create_effnetb2_model(NUM_CLASSES, load_st_dict=True)
# `predict` moves its input batch to DEVICE, so the model must live there as
# well; otherwise a CUDA-enabled host fails with a device-mismatch error.
MODEL = MODEL.to(DEVICE)
# Minimum top-class probability; anything less confident is rejected as
# probably not an image of food.
CONFIDENCE_THRESHOLD = 0.15


def predict(img) -> Tuple[Dict[str, float], float]:
    """Classify a single image and time the prediction.

    Args:
        img: The value from the Gradio ``Image`` input (configured with
            ``type='pil'``), or ``None`` when the user submits without an
            image.

    Returns:
        A ``(pred, seconds)`` tuple: ``pred`` maps every class name to its
        softmax probability, and ``seconds`` is the wall-clock time spent on
        preprocessing and inference.

    Raises:
        gr.Error: If no image was provided, or if the highest class
            probability is below ``CONFIDENCE_THRESHOLD``.
    """
    # Gradio passes None for an empty Image input; fail with a clear message
    # instead of letting the transforms raise an obscure error.
    if img is None:
        raise gr.Error("Please upload an image.")

    start = timer()
    # Preprocess, add a batch dimension, and place on the model's device.
    X = MODEL_TRANSFORMS(img).unsqueeze(dim=0).to(DEVICE)

    MODEL.eval()
    with torch.inference_mode():
        y_logits = MODEL(X)
        y_prob = torch.softmax(y_logits, dim=1).squeeze(dim=0)

    # Prediction threshold for images the model does not recognise.
    if y_prob.max().item() < CONFIDENCE_THRESHOLD:
        raise gr.Error("This image might not be of food.")

    # `tolist()` yields plain Python floats; Gradio would otherwise try to
    # iterate numpy scalar values.
    pred = dict(zip(CLASS_NAMES, y_prob.cpu().tolist()))
    end = timer()
    return pred, end - start
# Text shown in the Gradio UI header (title, description) and footer (article).
# NOTE(review): "πβ" looks like emoji that were mis-encoded when this file was
# copied — confirm the intended characters.
title = "FoodVision Big πβ"
# The class count is derived from class_names.txt so the text stays in sync
# with the labels the model actually predicts.
description = (
    "An [EfficientNetB2](https://pytorch.org/vision/stable/models/generated/torchvision.models.efficientnet_b2.html) "
    f"feature extractor that classifies food images into {NUM_CLASSES} classes."
)
article = (
    "Created by [DiabeticOwl](https://huggingface.co/DiabeticOwl). "
    "Uses the [Food101 dataset](https://data.vision.ee.ethz.ch/cvl/datasets_extra/food-101/)."
)
# Gradio UI: one PIL image in; the top-5 class probabilities and the
# prediction run time out.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type='pil'),
    outputs=[
        gr.Label(num_top_classes=5, label='Predictions'),
        gr.Number(label='Prediction run time (s)'),
    ],
    # `Path.glob` yields files in arbitrary filesystem order; sort so the
    # example gallery is stable across restarts and platforms.
    examples=sorted(str(p) for p in Path('examples').glob('*.jpg')),
    title=title,
    description=description,
    article=article,
)

if __name__ == '__main__':
    demo.launch()