suvash commited on
Commit
3d3df40
1 Parent(s): 31c5d2a

somewhat cleanup app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -8
app.py CHANGED
@@ -4,18 +4,21 @@ from fastai.vision.all import *
4
  MODELS_PATH = Path('./models')
5
  EXAMPLES_PATH = Path('./examples')
6
 
7
- # Required function used by fastai learner (at training setup)
 
 
 
8
  def label_func(filepath):
9
  return filepath.parent.name
10
 
11
- learn = load_learner(MODELS_PATH/'food-101-resnet50.pkl')
 
12
 
13
- labels = learn.dls.vocab
14
-
15
- def predict(img):
16
  img = PILImage.create(img)
17
- pred,pred_idx,probs = learn.predict(img)
18
- return {labels[i]: float(probs[i]) for i in range(len(labels))}
 
19
 
20
  with open('gradio_article.md') as f:
21
  article = f.read()
@@ -31,7 +34,7 @@ interface_options = {
31
  "allow_flagging": "never",
32
  }
33
 
34
- demo = gradio.Interface(fn=predict,
35
  inputs=gradio.inputs.Image(shape=(512, 512)),
36
  outputs=gradio.outputs.Label(num_top_classes=5),
37
  **interface_options)
 
4
  MODELS_PATH = Path('./models')
5
  EXAMPLES_PATH = Path('./examples')
6
 
7
+ # Required function expected by fastai learn object
8
+ # it wasn't exported as a part of the pickle
9
+ # as it was defined externally to the learner object
10
+ # during the training time dataloaders setup
11
  def label_func(filepath):
12
  return filepath.parent.name
13
 
14
+ LEARN = load_learner(MODELS_PATH/'food-101-resnet50.pkl')
15
+ LABELS = LEARN.dls.vocab
16
 
17
+ def gradio_predict(img):
 
 
18
  img = PILImage.create(img)
19
+ _pred, _pred_idx, probs = LEARN.predict(img)
20
+ labels_probs = {LABELS[i]: float(probs[i]) for i, _ in enumerate(LABELS)}
21
+ return labels_probs
22
 
23
  with open('gradio_article.md') as f:
24
  article = f.read()
 
34
  "allow_flagging": "never",
35
  }
36
 
37
+ demo = gradio.Interface(fn=gradio_predict,
38
  inputs=gradio.inputs.Image(shape=(512, 512)),
39
  outputs=gradio.outputs.Label(num_top_classes=5),
40
  **interface_options)