sfmig committed on
Commit
562224f
1 Parent(s): d14d564

added confidence of segmentation from notebook detr

Browse files
Files changed (1) hide show
  1. app.py +38 -15
app.py CHANGED
@@ -3,6 +3,13 @@ Using as reference:
3
  - https://huggingface.co/nvidia/segformer-b0-finetuned-ade-512-512
4
  - https://huggingface.co/spaces/chansung/segformer-tf-transformers/blob/main/app.py
5
  - https://huggingface.co/facebook/detr-resnet-50-panoptic
 
 
 
 
 
 
 
6
  """
7
 
8
  from transformers import DetrFeatureExtractor, DetrForSegmentation
@@ -168,26 +175,31 @@ def ade_palette():
168
  [102, 255, 0],
169
  [92, 0, 255],
170
  ]
 
171
 
172
- feature_extractor = DetrFeatureExtractor.from_pretrained('facebook/detr-resnet-50-panoptic')
173
- model = DetrForSegmentation.from_pretrained('facebook/detr-resnet-50-panoptic')
174
-
175
- # gradio components
176
- input = gr.inputs.Image()
177
- output = gr.outputs.Image()
178
-
179
- def predict_animal_mask(im):
180
  image = Image.fromarray(im) # im: numpy array 3d: 480, 640, 3: to PIL Image
181
  image = image.resize((200,200)) # PIL image # could I upsample output instead? better?
182
 
183
- inputs = feature_extractor(images=image, return_tensors="pt") #pt=Pytorch, tf=TensorFlow
184
- outputs = model(**inputs)
185
- logits = outputs.logits # torch.Size([1, 100, 251])
 
186
  bboxes = outputs.pred_boxes
187
- masks = outputs.pred_masks # torch.Size([1, 100, 200, 200])
 
 
 
 
 
 
 
 
 
188
 
189
  # postprocess the mask (numpy arrays)
190
- label_per_pixel = torch.argmax(masks.squeeze(),dim=0).detach().numpy()
191
  color_mask = np.zeros(image.size+(3,))
192
  for lbl, color in enumerate(ade_palette()):
193
  color_mask[label_per_pixel==lbl,:] = color
@@ -198,12 +210,23 @@ def predict_animal_mask(im):
198
 
199
  return pred_img
200
 
 
 
 
 
 
 
 
 
 
 
 
201
 
202
  ####################################################
203
  # Create user interface and launch
204
  gr.Interface(predict_animal_mask,
205
- inputs = input,
206
- outputs = output,
207
  title = 'Animals* segmentation in images',
208
  description = "An animal* segmentation image webapp using DETR (End-to-End Object Detection) model with ResNet-50 backbone").launch()
209
 
 
3
  - https://huggingface.co/nvidia/segformer-b0-finetuned-ade-512-512
4
  - https://huggingface.co/spaces/chansung/segformer-tf-transformers/blob/main/app.py
5
  - https://huggingface.co/facebook/detr-resnet-50-panoptic
6
+ # https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/
7
+
8
+ https://colab.research.google.com/github/NielsRogge/Transformers-Tutorials/blob/master/DETR/DETR_panoptic_segmentation_minimal_example_(with_DetrFeatureExtractor).ipynb
9
+
10
+ Additions
11
+ - add shown labels as strings
12
+ - show only animal masks (ask an nlp model?)
13
  """
14
 
15
  from transformers import DetrFeatureExtractor, DetrForSegmentation
 
175
  [102, 255, 0],
176
  [92, 0, 255],
177
  ]
178
+
179
 
180
+ def predict_animal_mask(im,
181
+ flag_high_confidence):
 
 
 
 
 
 
182
  image = Image.fromarray(im) # im: numpy array 3d: 480, 640, 3: to PIL Image
183
  image = image.resize((200,200)) # PIL image # could I upsample output instead? better?
184
 
185
+ # encoding is a dict with pixel_values and pixel_mask
186
+ encoding = feature_extractor(images=image, return_tensors="pt") #pt=Pytorch, tf=TensorFlow
187
+ outputs = model(**encoding) # odict with keys: ['logits', 'pred_boxes', 'pred_masks', 'last_hidden_state', 'encoder_last_hidden_state']
188
+ logits = outputs.logits # torch.Size([1, 100, 251]); why 251?
189
  bboxes = outputs.pred_boxes
190
+ masks = outputs.pred_masks # torch.Size([1, 100, 200, 200]); for every pixel, score in each of the 100 classes? there is a mask per class
191
+
192
+ # keep only the masks with high confidence?--------------------------------
193
+ # compute the top class probability per query (i.e., per predicted mask), excluding the "no-object" class (the last one)
194
+ if flag_high_confidence:
195
+ prob_per_query = outputs.logits.softmax(-1)[..., :-1].max(-1)[0] # why logits last dim 251?
196
+ # threshold the confidence
197
+ keep = prob_per_query > 0.85
198
+ else:
199
+ keep = torch.ones(outputs.logits.shape[0:2], dtype=torch.bool)
200
 
201
  # postprocess the mask (numpy arrays)
202
+ label_per_pixel = torch.argmax(masks[keep].squeeze(),dim=0).detach().numpy() # from the masks per class, select the highest per pixel
203
  color_mask = np.zeros(image.size+(3,))
204
  for lbl, color in enumerate(ade_palette()):
205
  color_mask[label_per_pixel==lbl,:] = color
 
210
 
211
  return pred_img
212
 
213
+ #######################################
214
+ # get models from hugging face
215
+ feature_extractor = DetrFeatureExtractor.from_pretrained('facebook/detr-resnet-50-panoptic')
216
+ model = DetrForSegmentation.from_pretrained('facebook/detr-resnet-50-panoptic')
217
+
218
+ # gradio components -inputs
219
+ gr_image_input = gr.inputs.Image()
220
+ gr_checkbox_high_confidence = gr.inputs.Checkbox(False,
221
+ label='disply high confidence only?')
222
+ # gradio outputs
223
+ gr_image_output = gr.outputs.Image()
224
 
225
  ####################################################
226
  # Create user interface and launch
227
  gr.Interface(predict_animal_mask,
228
+ inputs = [gr_image_input,gr_checkbox_high_confidence],
229
+ outputs = gr_image_output,
230
  title = 'Animals* segmentation in images',
231
  description = "An animal* segmentation image webapp using DETR (End-to-End Object Detection) model with ResNet-50 backbone").launch()
232