Ahsen Khaliq commited on
Commit
6b9c074
1 Parent(s): 2a86230

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -2
app.py CHANGED
@@ -16,6 +16,14 @@ import gradio as gr
16
  from PIL import Image
17
  from torchvision import transforms
18
 
 
 
 
 
 
 
 
 
19
 
20
  def load_image(image_path):
21
  return Image.open(image_path).convert("RGB")
@@ -48,6 +56,7 @@ def visualize_and_predict(model, resolution, image_path):
48
  _, preds = model(image).topk(5)
49
  # convert preds to a Python list and remove the batch dimension
50
  preds = preds.tolist()[0]
 
51
  return preds
52
 
53
  os.system("wget https://github.com/pytorch/hub/raw/master/images/dog.jpg -O dog.jpg")
@@ -56,8 +65,7 @@ os.system("wget https://github.com/pytorch/hub/raw/master/images/dog.jpg -O dog.
56
 
57
  def inference(img):
58
  preds = visualize_and_predict(model, resolution, img)
59
-
60
- return preds
61
 
62
  inputs = gr.inputs.Image(type='filepath')
63
  outputs = gr.outputs.Textbox(label="Output")
16
  from PIL import Image
17
  from torchvision import transforms
18
 
19
+ import json
20
+
21
+
22
+ with open("in_cls_idx.json", "r") as f:
23
+ imagenet_id_to_name = {int(cls_id): name for cls_id, (label, name) in json.load(f).items()}
24
+
25
+
26
+
27
 
28
  def load_image(image_path):
29
  return Image.open(image_path).convert("RGB")
56
  _, preds = model(image).topk(5)
57
  # convert preds to a Python list and remove the batch dimension
58
  preds = preds.tolist()[0]
59
+
60
  return preds
61
 
62
  os.system("wget https://github.com/pytorch/hub/raw/master/images/dog.jpg -O dog.jpg")
65
 
66
  def inference(img):
67
  preds = visualize_and_predict(model, resolution, img)
68
+ return [imagenet_id_to_name[cls_id] for cls_id in preds]
 
69
 
70
  inputs = gr.inputs.Image(type='filepath')
71
  outputs = gr.outputs.Textbox(label="Output")