Clement Vachet commited on
Commit
3a4ff58
1 Parent(s): 2ccf6ca

Use gradio blocks for user interface

Browse files
Files changed (1) hide show
  1. app.py +49 -22
app.py CHANGED
@@ -8,9 +8,10 @@ from PIL import Image
8
  from transformers import pipeline
9
  import matplotlib.pyplot as plt
10
  import io
 
11
 
12
- model_pipeline = pipeline(model="facebook/detr-resnet-50")
13
-
14
 
15
  COLORS = [
16
  [0.000, 0.447, 0.741],
@@ -22,6 +23,11 @@ COLORS = [
22
  ]
23
 
24
 
 
 
 
 
 
25
  def get_output_figure(pil_img, results, threshold):
26
  plt.figure(figsize=(16, 10))
27
  plt.imshow(pil_img)
@@ -45,7 +51,10 @@ def get_output_figure(pil_img, results, threshold):
45
 
46
 
47
  #@spaces.GPU
48
- def detect(image, threshold=0.9):
 
 
 
49
  results = model_pipeline(image)
50
  print(results)
51
 
@@ -59,22 +68,40 @@ def detect(image, threshold=0.9):
59
  return output_pil_img
60
 
61
 
62
- with gr.Blocks() as demo:
63
- gr.Markdown("# Object detection with DETR on COCO dataset")
64
- gr.Markdown(
65
- """
66
- This application uses a DETR (DEtection TRansformers) model to detect objects on images.
67
- This version was trained using the COCO dataset.
68
- You can load an image and see the predictions for the objects detected.
69
- """
70
- )
71
-
72
- gr.Interface(
73
- fn=detect,
74
- inputs=[gr.Image(label="Input image", type="pil"), \
75
- gr.Slider(0, 1.0, value=0.9, label='Threshold')],
76
- outputs=[gr.Image(label="Output prediction", type="pil")],
77
- examples=[['samples/savanna.jpg']],
78
- )
79
-
80
- demo.launch(show_error=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  from transformers import pipeline
9
  import matplotlib.pyplot as plt
10
  import io
11
+ import os
12
 
13
+ list_models = ["facebook/detr-resnet-50"]
14
+ list_models_simple = [os.path.basename(model) for model in list_models]
15
 
16
  COLORS = [
17
  [0.000, 0.447, 0.741],
 
23
  ]
24
 
25
 
26
+ def load_pipeline(model):
27
+ model_pipeline = pipeline(model=model)
28
+ return model_pipeline
29
+
30
+
31
  def get_output_figure(pil_img, results, threshold):
32
  plt.figure(figsize=(16, 10))
33
  plt.imshow(pil_img)
 
51
 
52
 
53
  #@spaces.GPU
54
+ def detect(image, model_id, threshold=0.9):
55
+ print("model:", list_models[model_id])
56
+
57
+ model_pipeline = load_pipeline(list_models[model_id])
58
  results = model_pipeline(image)
59
  print(results)
60
 
 
68
  return output_pil_img
69
 
70
 
71
+ def demo():
72
+ with gr.Blocks(theme="base") as demo:
73
+ gr.Markdown("# Object detection on COCO dataset")
74
+ gr.Markdown(
75
+ """
76
+ This application uses transformer-based models to detect objects on images.
77
+ This version was trained using the COCO dataset.
78
+ You can load an image and see the predictions for the objects detected.
79
+ """
80
+ )
81
+
82
+ with gr.Row():
83
+ model_id = gr.Radio(list_models, \
84
+ label="Detection models", value=list_models[0], type="index", info="Choose your detection model")
85
+ with gr.Row():
86
+ threshold = gr.Slider(0, 1.0, value=0.9, label='Detection threshold', info="Choose your detection threshold")
87
+
88
+ with gr.Row():
89
+ input_image = gr.Image(label="Input image", type="pil")
90
+ output_image = gr.Image(label="Output image", type="pil")
91
+
92
+ with gr.Row():
93
+ submit_btn = gr.Button("Submit")
94
+ clear_button = gr.ClearButton()
95
+
96
+ gr.Examples(['samples/savanna.jpg'], inputs=input_image)
97
+
98
+ submit_btn.click(fn=detect, inputs=[input_image, model_id, threshold], outputs=[output_image])
99
+ clear_button.click(lambda: [None, None], \
100
+ inputs=None, \
101
+ outputs=[input_image, output_image], \
102
+ queue=False)
103
+
104
+ demo.queue().launch(debug=True)
105
+
106
+ if __name__ == "__main__":
107
+ demo()