File size: 2,068 Bytes
39dcf75
 
 
 
 
 
 
 
8b09e96
39dcf75
 
9c14f10
39dcf75
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95a02b8
39dcf75
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73

# Import and class names setup
import gradio as gr
import os
import torch

from model import create_effnetb2_model
from timeit import default_timer as timer
from typing import Tuple, Dict

# Setup class names
# The file holds one class name per line. The resulting list position is used
# as the class index when names are paired with model outputs in predict(), so
# the line order presumably has to match the order used at training time.
with open('class_names.txt', 'r', encoding='utf-8') as f:
  # Iterate the file directly (no intermediate readlines() list), strip the
  # surrounding whitespace/newline, and drop blank lines so a stray empty line
  # cannot turn into an empty-string class that shifts the indices.
  class_names = [food_name for food_name in map(str.strip, f) if food_name]


# Model and transforms preparation
# NOTE(review): the factory is called with its defaults -- confirm those
# defaults build a classifier head that matches the saved checkpoint below.
effnetb2_model, effnetb2_transform = create_effnetb2_model()

# Load state dict: read the saved weights, mapping every tensor onto the CPU,
# and copy them into the freshly created model.
effnetb2_model.load_state_dict(
    torch.load('effnetb2_food101_model.pth', map_location=torch.device('cpu'))
)

# Predict function

def predict(img) -> Tuple[Dict, float]:
  """Classify an image with the EffNetB2 model and time the prediction.

  Args:
    img: Input image, handed straight to ``effnetb2_transform``. The Gradio
      input is declared with ``type='pil'``, so this is presumably a PIL
      image -- confirm against what the transform expects.

  Returns:
    A tuple ``(pred_label_and_prob, pred_time)``. ``pred_label_and_prob``
    maps every entry of ``class_names`` to its softmax probability as a
    Python float; ``pred_time`` is the elapsed time in seconds, rounded to
    4 decimal places.

  Raises:
    IndexError: If ``class_names`` has more entries than the model produces
      output probabilities.
  """
  # start a timer
  start_time = timer()

  # transform the input image for use with effnet b2 and add a leading
  # batch dimension
  transform_image = effnetb2_transform(img).unsqueeze(0)

  # put model into eval mode, make pred
  effnetb2_model.eval()
  with torch.inference_mode():
    pred_logits = effnetb2_model(transform_image)
    pred_prob = torch.softmax(pred_logits, dim=1)

  # create a pred label and pred prob dict
  # Convert the first (only) row to Python floats in a single call instead of
  # indexing the tensor once per class. Indexing by position (rather than
  # zip) keeps the IndexError if there are more names than probabilities.
  probs = pred_prob[0].tolist()
  pred_label_and_prob = {name: probs[i] for i, name in enumerate(class_names)}

  # calc pred time
  stop_time = timer()
  pred_time = round(stop_time - start_time, 4)

  # return pred dict and pred time
  return pred_label_and_prob, pred_time

# create example list
# Every entry is a one-element list because the interface takes a single input.
# sorted() keeps the example order stable across runs (os.listdir returns
# entries in arbitrary order), and os.path.join builds the path instead of
# manual string concatenation.
example_list = [[os.path.join('example', example)] for example in sorted(os.listdir('example'))]

# Text shown on the Gradio page
title = 'FoodVision Large 🍕🥩🍣 '
description = 'An EfficientnetB2 feature extractor Computer vision model to classify 101 classes of food from the food 101 image dataset'
article = 'Created at [To be uploaded].'

# Assemble the interface: one PIL image goes in; the top-5 class labels and
# the prediction time come out.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type='pil'),
    outputs=[
        gr.Label(num_top_classes=5, label='predictions'),
        gr.Number(label='Prediction time (S)'),
    ],
    examples=example_list,
    title=title,
    description=description,
    article=article,
)

# Start serving the demo
demo.launch()