File size: 1,832 Bytes
cb1d8c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import torch
import gradio as gr

from model import create_effnetb2_model
from pathlib import Path
from timeit import default_timer as timer
from typing import Tuple, Dict

# Class labels, one per line of class_names.txt. Their order must match the
# model's output logits, because `predict` pairs names and probabilities
# positionally. Blank lines are skipped so a stray empty line cannot inflate
# NUM_CLASSES (presumably the saved weights expect exactly one entry per real
# class -- confirm against create_effnetb2_model).
with open('class_names.txt', 'r', encoding='utf-8') as f:
    CLASS_NAMES = [name for name in (line.strip() for line in f) if name]
NUM_CLASSES = len(CLASS_NAMES)
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
MODEL, MODEL_TRANSFORMS = create_effnetb2_model(NUM_CLASSES, load_st_dict=True)
# `predict` sends every input tensor to DEVICE, so the weights must live there
# too; otherwise a CUDA host fails with a device-mismatch error. If
# create_effnetb2_model already placed the model on DEVICE this is a no-op.
MODEL = MODEL.to(DEVICE)


def predict(img) -> Tuple[Dict, float]:
    """Classify one image and measure how long the call took.

    Args:
        img: An image accepted by ``MODEL_TRANSFORMS``; the
            ``gr.Image(type='pil')`` input passes a PIL image.

    Returns:
        A ``(predictions, seconds)`` pair. ``predictions`` maps every name in
        ``CLASS_NAMES`` to its softmax probability as a builtin ``float``;
        ``seconds`` is the wall-clock time measured with ``timer``.
    """
    started_at = timer()

    batch = MODEL_TRANSFORMS(img)
    batch = batch.unsqueeze(dim=0)  # prepend a batch dimension of size 1
    batch = batch.to(DEVICE)

    MODEL.eval()
    with torch.inference_mode():
        logits = MODEL(batch)
        probs = torch.softmax(logits, dim=1)

    probs = probs.squeeze(dim=0).cpu().numpy()
    # Gradio assumes numpy objects should be iterated, so each numpy scalar is
    # cast to a plain float here (calling `.tolist()` first would also work).
    predictions = dict(zip(CLASS_NAMES, map(float, probs)))

    elapsed = timer() - started_at
    return predictions, elapsed


# The emoji pair was stored as mojibake ("πŸ–β“"): the UTF-8 bytes of
# U+1F356 and U+2753 had been decoded through a Windows codepage.
title = "FoodVision Big 🍖❓"
description = (
    "An [EfficientNetB2](https://pytorch.org/vision/stable/models/generated/torchvision.models.efficientnet_b2.html) "
    f"feature extractor that can classify images into {NUM_CLASSES} classes of food."
)
article = (
    "Created by [DiabeticOwl](https://huggingface.co/DiabeticOwl). "
    "Uses the [Food101 dataset](https://data.vision.ee.ethz.ch/cvl/datasets_extra/food-101/)."
)

# One inner list per example, holding the single image input's file path.
# Sorted so the example order does not depend on filesystem iteration order;
# an empty result (missing directory or no .jpg files) becomes None below,
# meaning "no examples", instead of an empty list.
examples = [[str(path)] for path in sorted(Path('examples').glob('*.jpg'))]

demo = gr.Interface(
    fn=predict,
    # NOTE(review): `shape=` (resize on upload) is a Gradio 3.x parameter that
    # was removed in Gradio 4.0 -- confirm the pinned Gradio version.
    inputs=gr.Image(shape=(288, 288), type='pil'),
    outputs=[
        gr.Label(num_top_classes=5, label='Predictions'),
        gr.Number(label='Prediction run time (s)'),
    ],
    examples=examples or None,
    title=title,
    description=description,
    article=article,
)

if __name__ == '__main__':
    demo.launch()