"""Gradio demo app for FoodVision Big: classify food images with EfficientNetB2.

At import time this module reads the class names from ``class_names.txt``,
builds the model via ``create_effnetb2_model`` and defines ``demo``, a
``gr.Interface`` that returns the top predictions for an uploaded image
together with how long the prediction took.
"""
from pathlib import Path
from timeit import default_timer as timer
from typing import Dict, List, Tuple

import gradio as gr
import torch

from model import create_effnetb2_model


def _load_class_names(path: str = 'class_names.txt') -> List[str]:
    """Return one class name per non-blank line of *path*, stripped."""
    with open(path, 'r', encoding='utf-8') as f:
        # Skip blank lines (e.g. a stray trailing empty line) so they do not
        # become empty class names and inflate NUM_CLASSES.
        return [line.strip() for line in f if line.strip()]


CLASS_NAMES = _load_class_names()
NUM_CLASSES = len(CLASS_NAMES)

# Prediction threshold for non-understood images: if no class reaches this
# softmax probability, the image is rejected as probably not food.
PRED_THRESHOLD = 0.15

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

MODEL, MODEL_TRANSFORMS = create_effnetb2_model(NUM_CLASSES, load_st_dict=True)
# `predict` sends its inputs to DEVICE, so the model must live there too;
# otherwise a CUDA host would feed CUDA tensors to a CPU model. (If
# create_effnetb2_model already placed it on DEVICE, this is a no-op.)
MODEL.to(DEVICE)
# This app only runs inference, so switch to eval mode once at load time
# rather than on every request.
MODEL.eval()


def predict(img) -> Tuple[Dict[str, float], float]:
    """Classify *img* and measure how long the prediction took.

    Args:
        img: The value delivered by the ``gr.Image(type='pil')`` input,
            i.e. a PIL image.

    Returns:
        A ``(pred, seconds)`` tuple. ``pred`` maps every class name to its
        softmax probability. ``seconds`` is the wall-clock run time of this
        call.

    Raises:
        gr.Error: If the highest class probability is below PRED_THRESHOLD.
    """
    start = timer()
    # Add a leading batch dimension of size 1 before running the model.
    X = MODEL_TRANSFORMS(img).unsqueeze(dim=0).to(DEVICE)

    with torch.inference_mode():
        y_prob = torch.softmax(MODEL(X), dim=1)

    if y_prob.max().item() < PRED_THRESHOLD:
        raise gr.Error("This image might not be of food.")

    # `tolist()` yields plain Python floats. Gradio's Label output otherwise
    # assumes numpy scalars should be iterated.
    probs = y_prob.squeeze(dim=0).cpu().tolist()
    pred = dict(zip(CLASS_NAMES, probs))

    return pred, timer() - start


title = "FoodVision Big 🍖❓"
# NOTE(review): the markdown link targets in `description` and `article`
# appear to be missing (the "(" is never closed), so the links render broken.
# Restore the intended URLs.
description = (
    "An [EfficientNetB2]( "
    "feature extraction that can classify images of 101 classes of food images."
)
article = (
    "Created by [DiabeticOwl]( "
    "Uses the [Food101 dataset]("
)

demo = gr.Interface(
    fn=predict,
    # NOTE: an earlier revision passed `shape=(288, 288)` to gr.Image. That
    # argument is not accepted by newer Gradio releases (confirm for the
    # pinned version). Resizing is presumably left to MODEL_TRANSFORMS.
    inputs=gr.Image(type='pil'),
    outputs=[
        gr.Label(num_top_classes=5, label='Predictions'),
        gr.Number(label='Prediction run time (s)'),
    ],
    # Sorted so the example gallery order is deterministic across platforms.
    examples=[str(p) for p in sorted(Path('examples').glob('*.jpg'))],
    title=title,
    description=description,
    article=article,
)

if __name__ == '__main__':
    demo.launch()