Tobias Cornille committed on
Commit: a884db1
Parent: f8636a4

Make more robust

Files changed (1)
  1. app.py +79 -61
app.py CHANGED
@@ -107,7 +107,7 @@ def dino_detection(
         visualization = Image.fromarray(annotated_frame)
         return boxes, category_ids, visualization
     else:
-        return boxes, category_ids
+        return boxes, category_ids, phrases
 
 
 def sam_masks_from_dino_boxes(predictor, image_array, boxes, device):
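
The `else` branch now returns `phrases` as well, giving `dino_detection` the same return arity on both paths. A minimal sketch of the failure mode the old asymmetric return invited (hypothetical `detect` stub, not the app's code):

# Hypothetical stub: two values on one branch, three on the other, as in the
# old code. A caller that always unpacks three values then fails at runtime.
def detect(visualize: bool):
    boxes, category_ids = [[0, 0, 10, 10]], [1]
    if visualize:
        return boxes, category_ids, "visualization"
    return boxes, category_ids  # old behaviour: no third value here

try:
    boxes, category_ids, _ = detect(visualize=False)
except ValueError as err:
    print(err)  # not enough values to unpack (expected 3, got 2)
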
@@ -153,13 +153,16 @@ def clipseg_segmentation(
     ).to(device)
     with torch.no_grad():
         outputs = model(**inputs)
+    logits = outputs.logits
+    if len(logits.shape) == 2:
+        logits = logits.unsqueeze(0)
     # resize the outputs
-    logits = nn.functional.interpolate(
-        outputs.logits.unsqueeze(1),
+    upscaled_logits = nn.functional.interpolate(
+        logits.unsqueeze(1),
         size=(image.size[1], image.size[0]),
         mode="bilinear",
     )
-    preds = torch.sigmoid(logits.squeeze())
+    preds = torch.sigmoid(upscaled_logits.squeeze(dim=1))
     semantic_inds = preds_to_semantic_inds(preds, background_threshold)
     return preds, semantic_inds
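
The added shape check covers the single-prompt case: with one text prompt, CLIPSeg's `outputs.logits` come back 2-D (H, W) rather than (N, H, W), and `nn.functional.interpolate` needs a 4-D batch. Squeezing only `dim=1` afterwards also keeps the category axis when N is 1. A minimal sketch with dummy tensors standing in for the CLIPSeg output (the `resize_preds` helper is hypothetical; shapes are assumptions):

import torch
import torch.nn as nn

def resize_preds(logits: torch.Tensor, height: int, width: int) -> torch.Tensor:
    # Single prompt: (H, W) -> (1, H, W) so a batch dimension exists.
    if len(logits.shape) == 2:
        logits = logits.unsqueeze(0)
    # interpolate expects (N, C, H, W); add a channel axis, resize, then drop it.
    upscaled = nn.functional.interpolate(
        logits.unsqueeze(1),
        size=(height, width),
        mode="bilinear",
    )
    return torch.sigmoid(upscaled.squeeze(dim=1))

print(resize_preds(torch.randn(352, 352), 480, 640).shape)     # torch.Size([1, 480, 640])
print(resize_preds(torch.randn(3, 352, 352), 480, 640).shape)  # torch.Size([3, 480, 640])
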
 
@@ -192,7 +195,7 @@ def clip_and_shrink_preds(semantic_inds, preds, shrink_kernel_size, num_categori
         torch.sum(bool_masks[i].int()).item() for i in range(1, bool_masks.size(0))
     ]
     max_size = max(sizes)
-    relative_sizes = [size / max_size for size in sizes]
+    relative_sizes = [size / max_size for size in sizes] if max_size > 0 else sizes
 
     # use bool masks to clip preds
     clipped_preds = torch.zeros_like(preds)
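
The new conditional guards the division: if every "stuff" mask is empty, `max_size` is 0 and the old comprehension raised `ZeroDivisionError`. A tiny sketch with made-up sizes:

# Made-up sizes for the all-empty case: the guard returns the sizes unchanged
# instead of dividing by zero.
sizes = [0, 0, 0]
max_size = max(sizes)
relative_sizes = [size / max_size for size in sizes] if max_size > 0 else sizes
print(relative_sizes)  # [0, 0, 0]
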
@@ -237,7 +240,7 @@ def upsample_pred(pred, image_source):
     else:
         target_height = int(upsampled_tensor.shape[2] * aspect_ratio)
         upsampled_tensor = upsampled_tensor[:, :, :target_height, :]
-    return upsampled_tensor.squeeze()
+    return upsampled_tensor.squeeze(dim=1)
 
 
 def sam_mask_from_points(predictor, image_array, points):
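
Switching from a bare `squeeze()` to `squeeze(dim=1)` removes only the channel axis, so a single-category prediction keeps its leading dimension. A minimal sketch with a dummy tensor:

import torch

# Dummy (num_categories=1, channels=1, H, W) tensor, as after interpolation.
upsampled = torch.zeros(1, 1, 480, 640)
print(upsampled.squeeze().shape)       # torch.Size([480, 640])    -- category axis dropped
print(upsampled.squeeze(dim=1).shape)  # torch.Size([1, 480, 640]) -- category axis kept
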
@@ -335,67 +338,82 @@ def generate_panoptic_mask(
     image = image.convert("RGB")
     image_array = np.asarray(image)
 
-    # detect boxes for "thing" categories using Grounding DINO
-    thing_boxes, thing_category_ids = dino_detection(
-        dino_model,
-        image,
-        image_array,
-        thing_category_names,
-        category_name_to_id,
-        dino_box_threshold,
-        dino_text_threshold,
-        device,
-    )
     # compute SAM image embedding
     sam_predictor.set_image(image_array)
-    # get segmentation masks for the thing boxes
-    thing_masks = sam_masks_from_dino_boxes(
-        sam_predictor, image_array, thing_boxes, device
-    )
-    # get rough segmentation masks for "stuff" categories using CLIPSeg
-    clipseg_preds, clipseg_semantic_inds = clipseg_segmentation(
-        clipseg_processor,
-        clipseg_model,
-        image,
-        stuff_category_names,
-        segmentation_background_threshold,
-        device,
-    )
-    # remove things from stuff masks
-    combined_things_mask = torch.any(thing_masks, dim=0)
-    clipseg_semantic_inds_without_things = clipseg_semantic_inds.clone()
-    clipseg_semantic_inds_without_things[combined_things_mask[0]] = 0
-    # clip CLIPSeg preds based on non-overlapping semantic segmentation inds (+ optionally shrink the mask of each category)
-    # also returns the relative size of each category
-    clipsed_clipped_preds, relative_sizes = clip_and_shrink_preds(
-        clipseg_semantic_inds_without_things,
-        clipseg_preds,
-        shrink_kernel_size,
-        len(stuff_category_names) + 1,
-    )
-    # get finer segmentation masks for the "stuff" categories using SAM
-    sam_preds = torch.zeros_like(clipsed_clipped_preds)
-    for i in range(clipsed_clipped_preds.shape[0]):
-        clipseg_pred = clipsed_clipped_preds[i]
-        # for each "stuff" category, sample points in the rough segmentation mask
-        num_samples = int(relative_sizes[i] * num_samples_factor)
-        if num_samples == 0:
-            continue
-        points = sample_points_based_on_preds(clipseg_pred.cpu().numpy(), num_samples)
-        if len(points) == 0:
-            continue
-        # use SAM to get mask for points
-        pred = sam_mask_from_points(sam_predictor, image_array, points)
-        sam_preds[i] = pred
-    sam_semantic_inds = preds_to_semantic_inds(
-        sam_preds, segmentation_background_threshold
-    )
+
+    # detect boxes for "thing" categories using Grounding DINO
+    thing_category_ids = []
+    thing_masks = []
+    thing_boxes = []
+    if len(thing_category_names) > 0:
+        thing_boxes, thing_category_ids, _ = dino_detection(
+            dino_model,
+            image,
+            image_array,
+            thing_category_names,
+            category_name_to_id,
+            dino_box_threshold,
+            dino_text_threshold,
+            device,
+        )
+        if len(thing_boxes) > 0:
+            # get segmentation masks for the thing boxes
+            thing_masks = sam_masks_from_dino_boxes(
+                sam_predictor, image_array, thing_boxes, device
+            )
+    if len(stuff_category_names) > 0:
+        # get rough segmentation masks for "stuff" categories using CLIPSeg
+        clipseg_preds, clipseg_semantic_inds = clipseg_segmentation(
+            clipseg_processor,
+            clipseg_model,
+            image,
+            stuff_category_names,
+            segmentation_background_threshold,
+            device,
+        )
+        # remove things from stuff masks
+        clipseg_semantic_inds_without_things = clipseg_semantic_inds.clone()
+        if len(thing_boxes) > 0:
+            combined_things_mask = torch.any(thing_masks, dim=0)
+            clipseg_semantic_inds_without_things[combined_things_mask[0]] = 0
+        # clip CLIPSeg preds based on non-overlapping semantic segmentation inds (+ optionally shrink the mask of each category)
+        # also returns the relative size of each category
+        clipsed_clipped_preds, relative_sizes = clip_and_shrink_preds(
+            clipseg_semantic_inds_without_things,
+            clipseg_preds,
+            shrink_kernel_size,
+            len(stuff_category_names) + 1,
+        )
+        # get finer segmentation masks for the "stuff" categories using SAM
+        sam_preds = torch.zeros_like(clipsed_clipped_preds)
+        for i in range(clipsed_clipped_preds.shape[0]):
+            clipseg_pred = clipsed_clipped_preds[i]
+            # for each "stuff" category, sample points in the rough segmentation mask
+            num_samples = int(relative_sizes[i] * num_samples_factor)
+            if num_samples == 0:
+                continue
+            points = sample_points_based_on_preds(
+                clipseg_pred.cpu().numpy(), num_samples
+            )
+            if len(points) == 0:
+                continue
+            # use SAM to get mask for points
+            pred = sam_mask_from_points(sam_predictor, image_array, points)
+            sam_preds[i] = pred
+        sam_semantic_inds = preds_to_semantic_inds(
+            sam_preds, segmentation_background_threshold
+        )
+
     # combine the thing inds and the stuff inds into panoptic inds
-    panoptic_inds = sam_semantic_inds.clone()
+    panoptic_inds = (
+        sam_semantic_inds.clone()
+        if len(stuff_category_names) > 0
+        else torch.zeros(image_array.shape[0], image_array.shape[1], dtype=torch.long)
+    )
     ind = len(stuff_category_names) + 1
     for thing_mask in thing_masks:
         # overlay thing mask on panoptic inds
-        panoptic_inds[thing_mask.squeeze()] = ind
+        panoptic_inds[thing_mask.squeeze(dim=0)] = ind
         ind += 1
 
     segmentation_bitmap, annotations = inds_to_segments_format(
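
The reworked body of `generate_panoptic_mask` runs the Grounding DINO/SAM "thing" branch and the CLIPSeg "stuff" branch only when their category lists are non-empty, so `panoptic_inds` needs a fallback when the "stuff" branch is skipped. A minimal sketch of just that fallback, with stand-in values (`sam_semantic_inds` is left as None to mirror the skipped branch):

import torch

stuff_category_names = []               # hypothetical: only "thing" classes requested
image_array = torch.zeros(480, 640, 3)  # stand-in for np.asarray(image)
sam_semantic_inds = None                # never computed when the "stuff" branch is skipped

# Start from an all-background index map instead of cloning an undefined tensor.
panoptic_inds = (
    sam_semantic_inds.clone()
    if len(stuff_category_names) > 0
    else torch.zeros(image_array.shape[0], image_array.shape[1], dtype=torch.long)
)
print(panoptic_inds.shape, panoptic_inds.dtype)  # torch.Size([480, 640]) torch.int64
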
 