sfmig committed
Commit 912be9c
1 Parent(s): 562224f

added slider for confidence

Files changed (1)
  1. app.py +12 -13
app.py CHANGED
@@ -7,6 +7,8 @@ Using as reference:
 
 https://colab.research.google.com/github/NielsRogge/Transformers-Tutorials/blob/master/DETR/DETR_panoptic_segmentation_minimal_example_(with_DetrFeatureExtractor).ipynb
 
+https://arxiv.org/abs/2005.12872
+
 Additions
 - add shown labels as strings
 - show only animal masks (ask an nlp model?)
@@ -178,7 +180,7 @@ def ade_palette():
 
 
 def predict_animal_mask(im,
-                        flag_high_confidence):
+                        gr_slider_confidence):
     image = Image.fromarray(im) # im: numpy array 3d: 480, 640, 3: to PIL Image
     image = image.resize((200,200)) # PIL image # could I upsample output instead? better?
 
@@ -191,12 +193,9 @@ def predict_animal_mask(im,
 
     # keep only the masks with high confidence?--------------------------------
     # compute the prob per mask (i.e., class), excluding the "no-object" class (the last one)
-    if flag_high_confidence:
-        prob_per_query = outputs.logits.softmax(-1)[..., :-1].max(-1)[0] # why logits last dim 251?
-        # threshold the confidence
-        keep = prob_per_query > 0.85
-    else:
-        keep = torch.ones(outputs.logits.shape[0:2], dtype=torch.bool)
+    prob_per_query = outputs.logits.softmax(-1)[..., :-1].max(-1)[0] # why logits last dim 251?
+    # threshold the confidence
+    keep = prob_per_query > gr_slider_confidence/100.0
 
     # postprocess the mask (numpy arrays)
     label_per_pixel = torch.argmax(masks[keep].squeeze(),dim=0).detach().numpy() # from the masks per class, select the highest per pixel
@@ -217,18 +216,18 @@ model = DetrForSegmentation.from_pretrained('facebook/detr-resnet-50-panoptic')
 
 # gradio components -inputs
 gr_image_input = gr.inputs.Image()
-gr_checkbox_high_confidence = gr.inputs.Checkbox(False,
-                                                 label='disply high confidence only?')
+gr_slider_confidence = gr.inputs.Slider(0,100,5,85,
+                                        label='Set confidence threshold for masks')
 # gradio outputs
 gr_image_output = gr.outputs.Image()
 
 ####################################################
 # Create user interface and launch
 gr.Interface(predict_animal_mask,
-             inputs = [gr_image_input,gr_checkbox_high_confidence],
-             outputs = gr_image_output,
-             title = 'Animals* segmentation in images',
-             description = "An animal* segmentation image webapp using DETR (End-to-End Object Detection) model with ResNet-50 backbone").launch()
+             inputs = [gr_image_input,gr_slider_confidence],
+             outputs = gr_image_output,
+             title = 'Image segmentation with varying confidence',
+             description = "An image segmentation webapp using DETR (End-to-End Object Detection) model with ResNet-50 backbone").launch()
 
 
 ####################################
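
Below is a minimal, self-contained sketch of the thresholding step this commit puts on the slider, run on dummy logits rather than real DETR outputs; the helper name keep_high_confidence_queries and the dummy tensor shapes are illustrative only and are not part of app.py.

import torch

def keep_high_confidence_queries(logits: torch.Tensor, slider_value: float) -> torch.Tensor:
    # probability of the best class per query, excluding the trailing "no-object" class
    prob_per_query = logits.softmax(-1)[..., :-1].max(-1)[0]   # shape: (batch, num_queries)
    # the Gradio slider runs 0-100, so divide by 100 before comparing against probabilities
    return prob_per_query > slider_value / 100.0

# dummy example: 1 image, 100 queries, 251 class logits (the 251 noted in the app's comment)
dummy_logits = torch.randn(1, 100, 251)
keep = keep_high_confidence_queries(dummy_logits, slider_value=85)
print(keep.shape, int(keep.sum()))   # torch.Size([1, 100]) and the number of kept queries

With the slider's default of 85 (the positional arguments of the legacy gr.inputs.Slider are minimum, maximum, step, default), this reproduces the 0.85 cut-off that the removed checkbox branch hard-coded, while letting the user move it anywhere in 0-100.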