nickmuchi commited on
Commit
5f842d6
1 Parent(s): c4bed8c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -5
app.py CHANGED
@@ -51,7 +51,7 @@ def visualize_prediction(pil_img, output_dict, threshold=0.7, id2label=None):
51
  plt.axis("off")
52
  return fig2img(plt.gcf())
53
 
54
- def detect_objects(model_name,url,image_upload,threshold):
55
 
56
  #Extract model and feature extractor
57
  feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
@@ -68,8 +68,29 @@ def detect_objects(model_name,url,image_upload,threshold):
68
 
69
  if validators.url(url):
70
  image = Image.open(requests.get(url, stream=True).raw)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
 
72
- elif image_upload:
73
  image = image_upload
74
 
75
  #Make prediction
@@ -78,7 +99,7 @@ def detect_objects(model_name,url,image_upload,threshold):
78
  #Visualize prediction
79
  viz_img = visualize_prediction(image, processed_outputs, threshold, model.config.id2label)
80
 
81
- return viz_img
82
 
83
  #examples=[['facebook/detr-resnet-50','https://media-cldnry.s-nbcnews.com/image/upload/t_fit-1500w,f_auto,q_auto:best/newscms/2020_14/3290756/200331-wall-street-ew-#343p.jpg',,0.7]
84
 
@@ -123,8 +144,8 @@ with demo:
123
  img_but = gr.Button('Detect')
124
 
125
 
126
- url_but.click(detect_objects,inputs=[options,url_input,None,slider_input],outputs=img_output_from_url,queue=True)
127
- img_but.click(detect_objects,inputs=[options,None,img_input,slider_input],outputs=img_output_from_upload,queue=True)
128
 
129
 
130
  demo.launch(enable_queue=True)
 
51
  plt.axis("off")
52
  return fig2img(plt.gcf())
53
 
54
+ def detect_objects_from_url(model_name,url,threshold):
55
 
56
  #Extract model and feature extractor
57
  feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
 
68
 
69
  if validators.url(url):
70
  image = Image.open(requests.get(url, stream=True).raw)
71
+
72
+ #Make prediction
73
+ processed_outputs = make_prediction(image, feature_extractor, model)
74
+
75
+ #Visualize prediction
76
+ viz_img = visualize_prediction(image, processed_outputs, threshold, model.config.id2label)
77
+
78
+ return viz_img
79
+
80
+ def detect_objects_from_upload(model_name,image_upload,threshold):
81
+
82
+ #Extract model and feature extractor
83
+ feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
84
+
85
+ if 'detr' in model_name:
86
+
87
+ model = DetrForObjectDetection.from_pretrained(model_name)
88
+
89
+ elif 'yolos' in model_name:
90
+
91
+ model = YolosForObjectDetection.from_pretrained(model_name)
92
 
93
+ if image_upload:
94
  image = image_upload
95
 
96
  #Make prediction
 
99
  #Visualize prediction
100
  viz_img = visualize_prediction(image, processed_outputs, threshold, model.config.id2label)
101
 
102
+ return viz_img
103
 
104
  #examples=[['facebook/detr-resnet-50','https://media-cldnry.s-nbcnews.com/image/upload/t_fit-1500w,f_auto,q_auto:best/newscms/2020_14/3290756/200331-wall-street-ew-#343p.jpg',,0.7]
105
 
 
144
  img_but = gr.Button('Detect')
145
 
146
 
147
+ url_but.click(detect_objects_from_url,inputs=[options,url_input,slider_input],outputs=img_output_from_url,queue=True)
148
+ img_but.click(detect_objects,inputs=[options,img_input,slider_input],outputs=img_output_from_upload,queue=True)
149
 
150
 
151
  demo.launch(enable_queue=True)