msivanes commited on
Commit
c0ed310
1 Parent(s): 74c80aa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -4
app.py CHANGED
@@ -2,14 +2,18 @@ from fastai.vision.core import PILImageBW, TensorImageBW
2
  from datasets import ClassLabel
3
  import gradio as gr
4
  from fastai.learner import load_learner
 
 
5
 
6
  def get_image_attr(x): return x['image']
7
  def get_target_attr(x): return x['target']
 
8
 
9
  def img2tensor(im: Image.Image):
10
  return TensorImageBW(array(im)).unsqueeze(0)
11
 
12
  classLabel = ClassLabel(names=['T - shirt / top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot'], id=None)
 
13
 
14
  def add_target(x:dict):
15
  x['target'] = classLabel.int2str(x['label'])
@@ -20,13 +24,19 @@ learn = load_learner('export.pkl', cpu=True)
20
  def classify(inp):
21
  img = PILImageBW.create(inp)
22
  item = dict(image=img)
23
- pred, _, _ = learn.predict(item)
24
- return classLabel.int2str(int(pred))
 
 
 
 
25
 
26
  iface = gr.Interface(
27
  fn=classify,
28
- inputs=gr.inputs.Image(),
29
- outputs="text",
30
  title="Fashion Mnist Classifier",
31
  description="fastai deployment in Gradio.",
 
 
32
  ).launch()
 
2
  from datasets import ClassLabel
3
  import gradio as gr
4
  from fastai.learner import load_learner
5
+ from PIL import Image
6
+ from numpy import array
7
 
8
  def get_image_attr(x): return x['image']
9
  def get_target_attr(x): return x['target']
10
+ def get_label_attr(x): return x['label']
11
 
12
  def img2tensor(im: Image.Image):
13
  return TensorImageBW(array(im)).unsqueeze(0)
14
 
15
  classLabel = ClassLabel(names=['T - shirt / top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot'], id=None)
16
+ labels = classLabel.names
17
 
18
  def add_target(x:dict):
19
  x['target'] = classLabel.int2str(x['label'])
 
24
  def classify(inp):
25
  img = PILImageBW.create(inp)
26
  item = dict(image=img)
27
+ pred, _, prob = learn.predict(item)
28
+ return {label: float(prob[i]) for i, label in enumerate(labels)}
29
+ # return classLabel.int2str(int(pred))
30
+
31
+ examples = ['shoes.jpg', 't-shirt.jpg']
32
+ interpretation='default'
33
 
34
  iface = gr.Interface(
35
  fn=classify,
36
+ inputs=gr.inputs.Image(image_mode='L'),
37
+ outputs=gr.outputs.Label(num_top_classes=3),
38
  title="Fashion Mnist Classifier",
39
  description="fastai deployment in Gradio.",
40
+ examples=examples,
41
+ interpretation=interpretation,
42
  ).launch()