Tobias Cornille committed
Commit a884db1
1 Parent(s): f8636a4
Make more robust
app.py
CHANGED
@@ -107,7 +107,7 @@ def dino_detection(
         visualization = Image.fromarray(annotated_frame)
         return boxes, category_ids, visualization
     else:
-        return boxes, category_ids
+        return boxes, category_ids, phrases
 
 
 def sam_masks_from_dino_boxes(predictor, image_array, boxes, device):
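
Note: before this change, dino_detection returned two values in the no-visualization branch but three in the visualization branch, so any caller unpacking three values crashed when visualize was off. Returning phrases as the third element keeps the arity consistent. A minimal standalone sketch of the failure mode (dummy values, not the app's real signature):

def detect(visualize):
    boxes, ids, phrases = [0], [1], ["cat"]
    if visualize:
        return boxes, ids, "visualization"
    return boxes, ids, phrases  # previously `return boxes, ids`: ValueError on 3-way unpacking

boxes, ids, _ = detect(visualize=False)  # now unpacks safely in both branches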
@@ -153,13 +153,16 @@ def clipseg_segmentation(
     ).to(device)
     with torch.no_grad():
         outputs = model(**inputs)
+    logits = outputs.logits
+    if len(logits.shape) == 2:
+        logits = logits.unsqueeze(0)
     # resize the outputs
-    upscaled_logits = nn.functional.interpolate(
-        outputs.logits.unsqueeze(1),
+    upscaled_logits = nn.functional.interpolate(
+        logits.unsqueeze(1),
         size=(image.size[1], image.size[0]),
         mode="bilinear",
     )
-    preds = torch.sigmoid(upscaled_logits.squeeze())
+    preds = torch.sigmoid(upscaled_logits.squeeze(dim=1))
     semantic_inds = preds_to_semantic_inds(preds, background_threshold)
     return preds, semantic_inds
 
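
Note: with a single category prompt, CLIPSeg's outputs.logits comes back 2-D as (H, W) rather than (N, H, W), which is what the new `len(logits.shape) == 2` guard handles: unsqueeze(0) restores the batch axis, and squeeze(dim=1) later removes only the channel axis instead of every singleton one. A minimal shape sketch with dummy tensors standing in for the CLIPSeg logits (the 352x352 resolution and 480x640 target are assumptions for illustration):

import torch
import torch.nn as nn

logits = torch.randn(352, 352)        # single-prompt case: (H, W)
if len(logits.shape) == 2:
    logits = logits.unsqueeze(0)      # -> (1, H, W)
upscaled = nn.functional.interpolate(
    logits.unsqueeze(1), size=(480, 640), mode="bilinear"
)                                     # -> (1, 1, 480, 640)
preds = torch.sigmoid(upscaled.squeeze(dim=1))
print(preds.shape)  # torch.Size([1, 480, 640]); a bare .squeeze() would give (480, 640)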
@@ -192,7 +195,7 @@ def clip_and_shrink_preds(semantic_inds, preds, shrink_kernel_size, num_categori
         torch.sum(bool_masks[i].int()).item() for i in range(1, bool_masks.size(0))
     ]
     max_size = max(sizes)
-    relative_sizes = [size / max_size for size in sizes]
+    relative_sizes = [size / max_size for size in sizes] if max_size > 0 else sizes
 
     # use bool masks to clip preds
     clipped_preds = torch.zeros_like(preds)
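
Note: max(sizes) is 0 whenever every category mask is empty, and the old comprehension then raised ZeroDivisionError; the guard falls back to the all-zero sizes list. The edge case in isolation:

sizes = [0, 0, 0]
max_size = max(sizes)
relative_sizes = [size / max_size for size in sizes] if max_size > 0 else sizes
print(relative_sizes)  # [0, 0, 0] instead of ZeroDivisionError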
@@ -237,7 +240,7 @@ def upsample_pred(pred, image_source):
     else:
         target_height = int(upsampled_tensor.shape[2] * aspect_ratio)
         upsampled_tensor = upsampled_tensor[:, :, :target_height, :]
-    return upsampled_tensor.squeeze()
+    return upsampled_tensor.squeeze(dim=1)
 
 
 def sam_mask_from_points(predictor, image_array, points):
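
Note: the same squeeze() pitfall as in clipseg_segmentation: for a single-category prediction the tensor is (1, 1, H, W), and a bare squeeze() drops both singleton axes, silently changing the output rank. squeeze(dim=1) drops only the channel axis:

import torch

t = torch.zeros(1, 1, 480, 640)  # (categories, channel, H, W), one category
print(t.squeeze().shape)         # torch.Size([480, 640]) -- category axis lost
print(t.squeeze(dim=1).shape)    # torch.Size([1, 480, 640]) -- rank preserved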
@@ -335,67 +338,82 @@ def generate_panoptic_mask(
     image = image.convert("RGB")
     image_array = np.asarray(image)
 
-    # detect boxes for "thing" categories using Grounding DINO
-    thing_boxes, thing_category_ids = dino_detection(
-        dino_model,
-        image,
-        image_array,
-        thing_category_names,
-        category_name_to_id,
-        dino_box_threshold,
-        dino_text_threshold,
-        device,
-    )
     # compute SAM image embedding
     sam_predictor.set_image(image_array)
-    # get segmentation masks for the thing boxes
-    thing_masks = sam_masks_from_dino_boxes(
-        sam_predictor, image_array, thing_boxes, device
-    )
-    # get rough segmentation masks for "stuff" categories using CLIPSeg
-    clipseg_preds, clipseg_semantic_inds = clipseg_segmentation(
-        clipseg_processor,
-        clipseg_model,
-        image,
-        stuff_category_names,
-        segmentation_background_threshold,
-        device,
-    )
-    # remove things from stuff masks
-    combined_things_mask = torch.any(thing_masks, dim=0)
-    clipseg_semantic_inds_without_things = clipseg_semantic_inds.clone()
-    clipseg_semantic_inds_without_things[combined_things_mask[0]] = 0
-    # clip CLIPSeg preds based on non-overlapping semantic segmentation inds (+ optionally shrink the mask of each category)
-    # also returns the relative size of each category
-    clipsed_clipped_preds, relative_sizes = clip_and_shrink_preds(
-        clipseg_semantic_inds_without_things,
-        clipseg_preds,
-        shrink_kernel_size,
-        len(stuff_category_names) + 1,
-    )
-    # get finer segmentation masks for the "stuff" categories using SAM
-    sam_preds = torch.zeros_like(clipsed_clipped_preds)
-    for i in range(clipsed_clipped_preds.shape[0]):
-        clipseg_pred = clipsed_clipped_preds[i]
-        # for each "stuff" category, sample points in the rough segmentation mask
-        num_samples = int(relative_sizes[i] * num_samples_factor)
-        points = sample_points_based_on_preds(
-            clipseg_pred.cpu().numpy(), num_samples
-        )
-        if len(points) == 0:
-            continue
-        # use SAM to get mask for points
-        pred = sam_mask_from_points(sam_predictor, image_array, points)
-        sam_preds[i] = pred
-    sam_semantic_inds = preds_to_semantic_inds(
-        sam_preds, segmentation_background_threshold
-    )
+
+    # detect boxes for "thing" categories using Grounding DINO
+    thing_category_ids = []
+    thing_masks = []
+    thing_boxes = []
+    if len(thing_category_names) > 0:
+        thing_boxes, thing_category_ids, _ = dino_detection(
+            dino_model,
+            image,
+            image_array,
+            thing_category_names,
+            category_name_to_id,
+            dino_box_threshold,
+            dino_text_threshold,
+            device,
+        )
+        if len(thing_boxes) > 0:
+            # get segmentation masks for the thing boxes
+            thing_masks = sam_masks_from_dino_boxes(
+                sam_predictor, image_array, thing_boxes, device
+            )
+    if len(stuff_category_names) > 0:
+        # get rough segmentation masks for "stuff" categories using CLIPSeg
+        clipseg_preds, clipseg_semantic_inds = clipseg_segmentation(
+            clipseg_processor,
+            clipseg_model,
+            image,
+            stuff_category_names,
+            segmentation_background_threshold,
+            device,
+        )
+        # remove things from stuff masks
+        clipseg_semantic_inds_without_things = clipseg_semantic_inds.clone()
+        if len(thing_boxes) > 0:
+            combined_things_mask = torch.any(thing_masks, dim=0)
+            clipseg_semantic_inds_without_things[combined_things_mask[0]] = 0
+        # clip CLIPSeg preds based on non-overlapping semantic segmentation inds (+ optionally shrink the mask of each category)
+        # also returns the relative size of each category
+        clipsed_clipped_preds, relative_sizes = clip_and_shrink_preds(
+            clipseg_semantic_inds_without_things,
+            clipseg_preds,
+            shrink_kernel_size,
+            len(stuff_category_names) + 1,
+        )
+        # get finer segmentation masks for the "stuff" categories using SAM
+        sam_preds = torch.zeros_like(clipsed_clipped_preds)
+        for i in range(clipsed_clipped_preds.shape[0]):
+            clipseg_pred = clipsed_clipped_preds[i]
+            # for each "stuff" category, sample points in the rough segmentation mask
+            num_samples = int(relative_sizes[i] * num_samples_factor)
+            if num_samples == 0:
+                continue
+            points = sample_points_based_on_preds(
+                clipseg_pred.cpu().numpy(), num_samples
+            )
+            if len(points) == 0:
+                continue
+            # use SAM to get mask for points
+            pred = sam_mask_from_points(sam_predictor, image_array, points)
+            sam_preds[i] = pred
+        sam_semantic_inds = preds_to_semantic_inds(
+            sam_preds, segmentation_background_threshold
+        )
+
     # combine the thing inds and the stuff inds into panoptic inds
-    panoptic_inds = sam_semantic_inds.clone()
+    panoptic_inds = (
+        sam_semantic_inds.clone()
+        if len(stuff_category_names) > 0
+        else torch.zeros(image_array.shape[0], image_array.shape[1], dtype=torch.long)
+    )
     ind = len(stuff_category_names) + 1
     for thing_mask in thing_masks:
         # overlay thing mask on panoptic inds
-        panoptic_inds[thing_mask.squeeze()] = ind
+        panoptic_inds[thing_mask.squeeze(dim=0)] = ind
         ind += 1
 
     segmentation_bitmap, annotations = inds_to_segments_format(
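
Note: generate_panoptic_mask now degrades gracefully: Grounding DINO runs only when thing categories are given, SAM box refinement only when boxes were actually found, the whole CLIPSeg/SAM "stuff" branch only when stuff categories are given, and point sampling is skipped when num_samples rounds to 0. When there are no stuff categories, panoptic_inds starts as an all-background zero tensor. A standalone sketch of that fallback (dummy image array; in app.py the shape comes from the input image):

import numpy as np
import torch

image_array = np.zeros((480, 640, 3), dtype=np.uint8)
panoptic_inds = torch.zeros(
    image_array.shape[0], image_array.shape[1], dtype=torch.long
)  # every pixel starts as background (0); thing masks are overlaid on top
print(panoptic_inds.shape)  # torch.Size([480, 640])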