import sys

import modules.config
import numpy as np
import torch

from extras.GroundingDINO.util.inference import default_groundingdino
from extras.sam.predictor import SamPredictor
from rembg import remove, new_session
from segment_anything import sam_model_registry
from segment_anything.utils.amg import remove_small_regions

class SAMOptions:
    def __init__(self,
                 # GroundingDINO
                 dino_prompt: str = '',
                 dino_box_threshold=0.3,
                 dino_text_threshold=0.25,
                 dino_erode_or_dilate=0,
                 dino_debug=False,
                 # SAM
                 max_detections=2,
                 model_type='vit_b'
                 ):
        self.dino_prompt = dino_prompt
        self.dino_box_threshold = dino_box_threshold
        self.dino_text_threshold = dino_text_threshold
        self.dino_erode_or_dilate = dino_erode_or_dilate
        self.dino_debug = dino_debug
        self.max_detections = max_detections
        self.model_type = model_type

def optimize_masks(masks: torch.Tensor) -> torch.Tensor:
    """
    Removes small disconnected regions and holes from the predicted masks.
    """
    fine_masks = []
    for mask in masks.to('cpu').numpy():  # masks: [num_masks, 1, h, w]
        fine_masks.append(remove_small_regions(mask[0], 400, mode="holes")[0])
    masks = np.stack(fine_masks, axis=0)[:, np.newaxis]
    return torch.from_numpy(masks)
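
# Illustrative note (not part of the original module): optimize_masks expects the
# boolean tensor returned by SamPredictor.predict_torch, shaped [num_masks, 1, H, W],
# and fills holes smaller than 400 pixels in each mask, e.g.:
#
#     dummy = torch.ones((1, 1, 64, 64), dtype=torch.bool)
#     dummy[0, 0, 32, 32] = False          # a 1-pixel hole, well under the threshold
#     assert bool(optimize_masks(dummy)[0, 0, 32, 32])
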
def generate_mask_from_image(image: np.ndarray, mask_model: str = 'sam', extras=None,
                             sam_options: SAMOptions | None = SAMOptions()) -> tuple[np.ndarray | None, int | None, int | None, int | None]:
    dino_detection_count = 0
    sam_detection_count = 0
    sam_detection_on_mask_count = 0

    if image is None:
        return None, dino_detection_count, sam_detection_count, sam_detection_on_mask_count

    if extras is None:
        extras = {}

    # The UI may hand over a dict with an 'image' key instead of a bare array.
    if 'image' in image:
        image = image['image']

    # Any model other than SAM (or a missing SAMOptions) falls back to rembg.
    if mask_model != 'sam' or sam_options is None:
        result = remove(
            image,
            session=new_session(mask_model, **extras),
            only_mask=True,
            **extras
        )

        return result, dino_detection_count, sam_detection_count, sam_detection_on_mask_count
    # GroundingDINO: detect boxes matching the text prompt.
    detections, boxes, logits, phrases = default_groundingdino(
        image=image,
        caption=sam_options.dino_prompt,
        box_threshold=sam_options.dino_box_threshold,
        text_threshold=sam_options.dino_text_threshold
    )

    # Boxes come back as normalized (cx, cy, w, h); convert to absolute (x0, y0, x1, y1).
    H, W = image.shape[0], image.shape[1]
    boxes = boxes * torch.Tensor([W, H, W, H])
    boxes[:, :2] = boxes[:, :2] - boxes[:, 2:] / 2
    boxes[:, 2:] = boxes[:, 2:] + boxes[:, :2]

    sam_checkpoint = modules.config.download_sam_model(sam_options.model_type)
    sam = sam_model_registry[sam_options.model_type](checkpoint=sam_checkpoint)

    sam_predictor = SamPredictor(sam)
    final_mask_tensor = torch.zeros((image.shape[0], image.shape[1]))
    dino_detection_count = boxes.size(0)
    if dino_detection_count > 0:
        sam_predictor.set_image(image)

        # Optionally grow (positive value) or shrink (negative value) each detected box.
        if sam_options.dino_erode_or_dilate != 0:
            for index in range(boxes.size(0)):
                assert boxes.size(1) == 4
                boxes[index][0] -= sam_options.dino_erode_or_dilate
                boxes[index][1] -= sam_options.dino_erode_or_dilate
                boxes[index][2] += sam_options.dino_erode_or_dilate
                boxes[index][3] += sam_options.dino_erode_or_dilate

        # Debug mode: return the raw GroundingDINO boxes as a black-and-white image.
        if sam_options.dino_debug:
            from PIL import ImageDraw, Image
            debug_dino_image = Image.new("RGB", (image.shape[1], image.shape[0]), color="black")
            draw = ImageDraw.Draw(debug_dino_image)
            for box in boxes.numpy():
                draw.rectangle(box.tolist(), fill="white")
            return np.array(debug_dino_image), dino_detection_count, sam_detection_count, sam_detection_on_mask_count

        # SAM: predict one mask per detected box.
        transformed_boxes = sam_predictor.transform.apply_boxes_torch(boxes, image.shape[:2])
        masks, _, _ = sam_predictor.predict_torch(
            point_coords=None,
            point_labels=None,
            boxes=transformed_boxes,
            multimask_output=False,
        )

        masks = optimize_masks(masks)

        # Merge up to max_detections masks into the final mask (0 means no limit).
        sam_detection_count = len(masks)
        if sam_options.max_detections == 0:
            sam_options.max_detections = sys.maxsize
        sam_objects = min(len(logits), sam_options.max_detections)
        for obj_ind in range(sam_objects):
            mask_tensor = masks[obj_ind][0]
            final_mask_tensor += mask_tensor
            sam_detection_on_mask_count += 1

    final_mask_tensor = (final_mask_tensor > 0).to('cpu').numpy()
    mask_image = np.dstack((final_mask_tensor, final_mask_tensor, final_mask_tensor)) * 255
    mask_image = np.array(mask_image, dtype=np.uint8)
    return mask_image, dino_detection_count, sam_detection_count, sam_detection_on_mask_count
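

if __name__ == '__main__':
    # Usage sketch, not part of the original module: the file paths, prompt and
    # option values below are placeholders. Loads an RGB image as an H x W x 3
    # uint8 array and asks GroundingDINO + SAM for a combined mask of everything
    # matching the text prompt.
    from PIL import Image

    source = np.array(Image.open('input.png').convert('RGB'))
    options = SAMOptions(dino_prompt='cat', dino_box_threshold=0.3, max_detections=2)
    mask, dino_count, sam_count, merged_count = generate_mask_from_image(
        source, mask_model='sam', sam_options=options)
    print(f'DINO boxes: {dino_count}, SAM masks: {sam_count}, merged: {merged_count}')
    Image.fromarray(mask).save('mask.png')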