nickmuchi commited on
Commit
5432f01
1 Parent(s): 220e378

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -28
app.py CHANGED
@@ -3,6 +3,7 @@ import gradio as gr
3
  import matplotlib.pyplot as plt
4
  import requests, validators
5
  import torch
 
6
  from PIL import Image
7
  from transformers import AutoFeatureExtractor, DetrForObjectDetection, YolosForObjectDetection
8
 
@@ -51,7 +52,7 @@ def visualize_prediction(pil_img, output_dict, threshold=0.7, id2label=None):
51
  plt.axis("off")
52
  return fig2img(plt.gcf())
53
 
54
- def detect_objects_from_url(model_name,url,threshold):
55
 
56
  #Extract model and feature extractor
57
  feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
@@ -66,29 +67,8 @@ def detect_objects_from_url(model_name,url,threshold):
66
 
67
  if url and validators.url(url):
68
  image = Image.open(requests.get(url, stream=True).raw)
69
-
70
- #Make prediction
71
- processed_outputs = make_prediction(image, feature_extractor, model)
72
-
73
- #Visualize prediction
74
- viz_img = visualize_prediction(image, processed_outputs, threshold, model.config.id2label)
75
-
76
- return viz_img
77
-
78
- def detect_objects_from_upload(model_name,image_upload,threshold):
79
-
80
- #Extract model and feature extractor
81
- feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
82
-
83
- if 'detr' in model_name:
84
 
85
- model = DetrForObjectDetection.from_pretrained(model_name)
86
-
87
- elif 'yolos' in model_name:
88
-
89
- model = YolosForObjectDetection.from_pretrained(model_name)
90
-
91
- if image_upload:
92
  image = image_upload
93
 
94
  #Make prediction
@@ -97,13 +77,13 @@ def detect_objects_from_upload(model_name,image_upload,threshold):
97
  #Visualize prediction
98
  viz_img = visualize_prediction(image, processed_outputs, threshold, model.config.id2label)
99
 
100
- return viz_img
101
-
102
  #examples=[['facebook/detr-resnet-50','https://media-cldnry.s-nbcnews.com/image/upload/t_fit-1500w,f_auto,q_auto:best/newscms/2020_14/3290756/200331-wall-street-ew-#343p.jpg',,0.7]
103
 
104
 
105
 
106
- title = 'Object Detection App with DETR and YOLOS'
107
 
108
  description = """
109
  Links to HuggingFace Models:
@@ -131,6 +111,12 @@ with demo:
131
  with gr.Row():
132
  url_input = gr.Textbox(lines=1,label='Enter valid image URL here..')
133
  img_output_from_url = gr.Image(shape=(450,450))
 
 
 
 
 
 
134
 
135
  url_but = gr.Button('Detect')
136
 
@@ -139,11 +125,17 @@ with demo:
139
  img_input = gr.Image(type='pil')
140
  img_output_from_upload= gr.Image(shape=(450,450))
141
 
 
 
 
 
 
 
142
  img_but = gr.Button('Detect')
143
 
144
 
145
- url_but.click(detect_objects_from_url,inputs=[options,url_input,slider_input],outputs=img_output_from_url,queue=True)
146
- img_but.click(detect_objects_from_upload,inputs=[options,img_input,slider_input],outputs=img_output_from_upload,queue=True)
147
 
148
 
149
  demo.launch(enable_queue=True)
 
3
  import matplotlib.pyplot as plt
4
  import requests, validators
5
  import torch
6
+ import pathlib
7
  from PIL import Image
8
  from transformers import AutoFeatureExtractor, DetrForObjectDetection, YolosForObjectDetection
9
 
 
52
  plt.axis("off")
53
  return fig2img(plt.gcf())
54
 
55
+ def detect_objects(model_name,url_input,image_input,threshold):
56
 
57
  #Extract model and feature extractor
58
  feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
 
67
 
68
  if url and validators.url(url):
69
  image = Image.open(requests.get(url, stream=True).raw)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
 
71
+ elif image_upload:
 
 
 
 
 
 
72
  image = image_upload
73
 
74
  #Make prediction
 
77
  #Visualize prediction
78
  viz_img = visualize_prediction(image, processed_outputs, threshold, model.config.id2label)
79
 
80
+ return viz_img
81
+
82
  #examples=[['facebook/detr-resnet-50','https://media-cldnry.s-nbcnews.com/image/upload/t_fit-1500w,f_auto,q_auto:best/newscms/2020_14/3290756/200331-wall-street-ew-#343p.jpg',,0.7]
83
 
84
 
85
 
86
+ title = """<h1 id="title">Object Detection App with DETR and YOLOS</h1>"""
87
 
88
  description = """
89
  Links to HuggingFace Models:
 
111
  with gr.Row():
112
  url_input = gr.Textbox(lines=1,label='Enter valid image URL here..')
113
  img_output_from_url = gr.Image(shape=(450,450))
114
+
115
+ with gr.Row():
116
+ urls = ["https://media-cldnry.s-nbcnews.com/image/upload/t_fit-1500w,f_auto,q_auto:best/newscms/2020_14/3290756/200331-wall-street-ew-#343p.jpg"]
117
+ example_url = gr.Dataset(components=[url_input],
118
+ samples=[[url.as_posix()]
119
+ for url in urls])
120
 
121
  url_but = gr.Button('Detect')
122
 
 
125
  img_input = gr.Image(type='pil')
126
  img_output_from_upload= gr.Image(shape=(450,450))
127
 
128
+ with gr.Row():
129
+ paths = sorted(pathlib.Path('images').rglob('*.JPG')
130
+ example_images = gr.Dataset(components=[img_input],
131
+ samples=[[path.as_posix()]
132
+ for path in paths])
133
+
134
  img_but = gr.Button('Detect')
135
 
136
 
137
+ url_but.click(detect_objects,inputs=[options,url_input,img_input,slider_input],outputs=img_output_from_url,queue=True)
138
+ img_but.click(detect_objects,inputs=[options,url_input,img_input,slider_input],outputs=img_output_from_upload,queue=True)
139
 
140
 
141
  demo.launch(enable_queue=True)