|
import cv2 |
|
import numpy as np |
|
import supervision as sv |
|
import argparse |
|
import torch |
|
import torchvision |
|
|
|
from groundingdino.util.inference import Model |
|
from segment_anything import SamPredictor |
|
from MobileSAM.setup_mobile_sam import setup_model |
|
|
|
def parse_args():
    """Declare the command-line options of the demo and parse ``sys.argv``.

    Returns:
        argparse.Namespace exposing MOBILE_SAM_CHECKPOINT_PATH,
        SOURCE_IMAGE_PATH, CAPTION, OUT_FILE_BOX, OUT_FILE_SEG,
        OUT_FILE_BIN_MASK, BOX_THRESHOLD, TEXT_THRESHOLD, NMS_THRESHOLD
        and DEVICE.
    """
    # When --DEVICE is omitted the value stays a torch.device object:
    # argparse only runs ``type`` on defaults that are strings.
    fallback_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # (flag, type, default, help), listed in the order shown by --help.
    option_table = (
        ("--MOBILE_SAM_CHECKPOINT_PATH", str, "./EfficientSAM/mobile_sam.pt", "model"),
        ("--SOURCE_IMAGE_PATH", str, "./assets/demo2.jpg", "path to image file"),
        ("--CAPTION", str, "The running dog", "text prompt for GroundingDINO"),
        ("--OUT_FILE_BOX", str, "groundingdino_annotated_image.jpg", "the output filename"),
        ("--OUT_FILE_SEG", str, "grounded_mobile_sam_annotated_image.jpg", "the output filename"),
        ("--OUT_FILE_BIN_MASK", str, "grounded_mobile_sam_bin_mask.jpg", "the output filename"),
        ("--BOX_THRESHOLD", float, 0.25, ""),
        ("--TEXT_THRESHOLD", float, 0.25, ""),
        ("--NMS_THRESHOLD", float, 0.8, ""),
        ("--DEVICE", str, fallback_device, "cuda:[0,1,2,3,4] or cpu"),
    )

    cli = argparse.ArgumentParser()
    for flag, value_type, default_value, description in option_table:
        cli.add_argument(flag, type=value_type, default=default_value, help=description)
    return cli.parse_args()
|
|
|
def _format_labels(classes, detections):
    """Return one ``"<class name> <confidence>"`` caption per detection."""
    labels = []
    # Read the arrays directly instead of unpacking the tuples yielded by
    # iterating `detections`, whose arity differs between supervision releases.
    for confidence, class_id in zip(detections.confidence, detections.class_id):
        # class_id may be None when the predicted phrase could not be mapped
        # back to an entry of `classes`; indexing with None would raise.
        name = classes[class_id] if class_id is not None else "unknown"
        labels.append(f"{name} {confidence:0.2f}")
    return labels


def _segment_boxes(sam_predictor, image, xyxy):
    """Predict one mask per box, keeping the best-scoring SAM candidate."""
    sam_predictor.set_image(image)
    best_masks = []
    for box in xyxy:
        masks, scores, _ = sam_predictor.predict(box=box, multimask_output=True)
        best_masks.append(masks[np.argmax(scores)])
    return np.array(best_masks)


def main(args):
    """Run GroundingDINO detection followed by MobileSAM segmentation.

    Writes three images: the box-annotated detections (``args.OUT_FILE_BOX``),
    the binary mask of the first detection kept by NMS
    (``args.OUT_FILE_BIN_MASK``) and the mask+box annotated image
    (``args.OUT_FILE_SEG``). When nothing is detected, an all-zero mask and
    the unannotated image are written instead.

    Args:
        args: Namespace produced by ``parse_args``.

    Raises:
        FileNotFoundError: if ``args.SOURCE_IMAGE_PATH`` cannot be read.
    """
    device = args.DEVICE

    grounding_dino_config_path = "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"
    grounding_dino_checkpoint_path = "./groundingdino_swint_ogc.pth"

    # Forward the requested device; otherwise the detector falls back to its
    # own default and ignores --DEVICE.
    grounding_dino_model = Model(
        model_config_path=grounding_dino_config_path,
        model_checkpoint_path=grounding_dino_checkpoint_path,
        device=str(device),
    )

    # Load onto CPU first so a checkpoint saved from a GPU still loads on a
    # CPU-only host; the model is moved to the target device afterwards.
    # NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
    checkpoint = torch.load(args.MOBILE_SAM_CHECKPOINT_PATH, map_location="cpu")
    mobile_sam = setup_model()
    mobile_sam.load_state_dict(checkpoint, strict=True)
    mobile_sam.to(device=device)
    # Inference only: eval mode keeps normalization/dropout layers deterministic.
    mobile_sam.eval()
    sam_predictor = SamPredictor(mobile_sam)

    classes = [args.CAPTION]

    image = cv2.imread(args.SOURCE_IMAGE_PATH)
    if image is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise FileNotFoundError(f"Could not read image: {args.SOURCE_IMAGE_PATH}")

    detections = grounding_dino_model.predict_with_classes(
        image=image,
        classes=classes,
        box_threshold=args.BOX_THRESHOLD,
        text_threshold=args.TEXT_THRESHOLD,
    )

    box_annotator = sv.BoxAnnotator()
    annotated_frame = box_annotator.annotate(
        scene=image.copy(),
        detections=detections,
        labels=_format_labels(classes, detections),
    )
    cv2.imwrite(args.OUT_FILE_BOX, annotated_frame)

    if len(detections.xyxy) == 0:
        # Nothing to segment: indexing mask[0] below would raise IndexError.
        print("No detections found; writing an empty mask.")
        cv2.imwrite(args.OUT_FILE_BIN_MASK, np.zeros(image.shape[:2], dtype=np.uint8))
        cv2.imwrite(args.OUT_FILE_SEG, image)
        return

    print(f"Before NMS: {len(detections.xyxy)} boxes")
    nms_idx = torchvision.ops.nms(
        torch.from_numpy(detections.xyxy),
        torch.from_numpy(detections.confidence),
        args.NMS_THRESHOLD,
    ).numpy().tolist()

    detections.xyxy = detections.xyxy[nms_idx]
    detections.confidence = detections.confidence[nms_idx]
    detections.class_id = detections.class_id[nms_idx]
    print(f"After NMS: {len(detections.xyxy)} boxes")

    # SAM is fed RGB, while cv2.imread produced BGR.
    detections.mask = _segment_boxes(
        sam_predictor=sam_predictor,
        image=cv2.cvtColor(image, cv2.COLOR_BGR2RGB),
        xyxy=detections.xyxy,
    )

    # Only the first surviving detection is exported as the binary mask.
    binary_mask = detections.mask[0].astype(np.uint8) * 255
    cv2.imwrite(args.OUT_FILE_BIN_MASK, binary_mask)

    mask_annotator = sv.MaskAnnotator()
    annotated_image = mask_annotator.annotate(scene=image.copy(), detections=detections)
    annotated_image = box_annotator.annotate(
        scene=annotated_image,
        detections=detections,
        labels=_format_labels(classes, detections),
    )
    cv2.imwrite(args.OUT_FILE_SEG, annotated_image)
|
|
|
if __name__ == "__main__":
    # Script entry point: parse the CLI options and run the pipeline once.
    main(parse_args())
|
|