#!/usr/bin/env python3
"""
Example of using Rerun to log and visualize the output of Grounding DINO + Segment Anything.

See: [Grounded-Segment-Anything](https://github.com/IDEA-Research/Grounded-Segment-Anything).

Can be used to test mask-generation on one or more images, as well as videos.
Images can be local file-paths or remote URLs. Videos must be local file-paths.
Can use multiple prompts.
"""
import argparse
import logging
from pathlib import Path

import cv2
import rerun as rr
import torch
from groundingdino.models import GroundingDINO
from segment_anything import SamPredictor

from models import (
    CONFIG_PATH,
    MODEL_URLS,
    create_sam,
    get_downloaded_model_path,
    get_grounding_output,
    image_to_tensor,
    load_grounding_model,
    load_image,
    resize_img,
    run_segmentation,
)


def log_images_segmentation(args, model: GroundingDINO, predictor: SamPredictor) -> None:
    """Run detection + segmentation on each image and log the results to Rerun."""
    for n, image_uri in enumerate(args.images):
        rr.set_time_sequence("image", n)
        image = load_image(image_uri)
        rr.log_image("image", image)

        detections, phrases, id_from_phrase = grounding_dino_detect(
            model, args.device, image, args.prompt, args.bbox_threshold
        )
        predictor.set_image(image)
        run_segmentation(predictor, image, detections, phrases, id_from_phrase)


def grounding_dino_detect(model, device, image, prompt, box_threshold):
    image_tensor = image_to_tensor(image)

    logging.info(f"Running GroundingDINO with DETECTION PROMPT {prompt}.")
    # box_threshold comes from --bbox-threshold; the text threshold stays fixed at 0.25.
    boxes_filt, box_phrases = get_grounding_output(
        model, image_tensor, prompt, box_threshold, 0.25, device=device
    )
    logging.info(f"Grounded output with prediction phrases: {box_phrases}")

    # Denormalize boxes: Grounding DINO outputs normalized cxcywh in [0, 1];
    # convert to pixel-space xyxy (top-left, bottom-right).
    H, W, _ = image.shape
    for i in range(boxes_filt.size(0)):
        boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
        boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
        boxes_filt[i][2:] += boxes_filt[i][:2]

    # Start ids at 1: Rerun treats class-id 0 as background in segmentation images.
    id_from_phrase = {phrase: i for i, phrase in enumerate(set(box_phrases), start=1)}
    box_ids = [id_from_phrase[phrase] for phrase in box_phrases]  # One mask per box

    # Make sure we have an AnnotationInfo present for every class-id used in this image
    rr.log_annotation_context(
        "image",
        [rr.AnnotationInfo(id=id, label=phrase) for phrase, id in id_from_phrase.items()],
        timeless=False,
    )

    rr.log_rects(
        "image/detections",
        rects=boxes_filt.numpy(),
        class_ids=box_ids,
        rect_format=rr.RectFormat.XYXY,
    )

    return boxes_filt, box_phrases, id_from_phrase


def log_video_segmentation(args, model: GroundingDINO, predictor: SamPredictor) -> None:
    """Run detection + segmentation on every video frame and log the results to Rerun."""
    video_path = args.video_path
    assert video_path.exists()
    cap = cv2.VideoCapture(str(video_path))

    idx = 0
    while cap.isOpened():
        ret, bgr = cap.read()
        if not ret:
            break
        rr.set_time_sequence("frame", idx)

        # OpenCV decodes frames as BGR; convert to RGB and downscale before logging.
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        rgb = resize_img(rgb, 512)
        rr.log_image("image", rgb)

        detections, phrases, id_from_phrase = grounding_dino_detect(
            model, args.device, rgb, args.prompt, args.bbox_threshold
        )
        predictor.set_image(rgb)
        run_segmentation(predictor, rgb, detections, phrases, id_from_phrase)
        idx += 1


def main() -> None:
    parser = argparse.ArgumentParser(
        description="Run IDEA Research Grounded Dino + SAM example.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--model",
        action="store",
        default="vit_b",
        choices=MODEL_URLS.keys(),
        help="Which model to use. "
        "(See: https://github.com/facebookresearch/segment-anything#model-checkpoints)",
    )
    parser.add_argument(
        "--device",
        action="store",
        default="cpu",
        help="Which torch device to use, e.g. cpu or cuda. "
        "(See: https://pytorch.org/docs/stable/tensor_attributes.html#torch.device)",
    )
    parser.add_argument(
        "--prompt",
        default="tires and windows",
        type=str,
        help="List of prompts to use for bounding box detection.",
    )
    parser.add_argument(
        "images", metavar="N", type=str, nargs="*", help="A list of images to process."
    )
    parser.add_argument(
        "--bbox-threshold",
        default=0.3,
        type=float,
        help="Threshold for a bounding box to be considered.",
    )
    parser.add_argument(
        "--video-path",
        default=None,
        type=Path,
        help="Path to video to run segmentation on.",
    )
    rr.script_add_args(parser)
    args = parser.parse_args()

    rr.script_setup(args, "grounded_sam")
    logging.getLogger().addHandler(rr.LoggingHandler("logs"))
    logging.getLogger().setLevel(logging.INFO)

    # Load the Grounding DINO detector and the SAM predictor.
    grounded_checkpoint = get_downloaded_model_path("grounding")
    model = load_grounding_model(CONFIG_PATH, grounded_checkpoint, device=args.device)
    sam = create_sam(args.model, args.device)
    predictor = SamPredictor(sam)

    if len(args.images) == 0 and args.video_path is None:
        logging.info("No image provided. Using default.")
        args.images = [
            "https://raw.githubusercontent.com/facebookresearch/segment-anything/main/notebooks/images/truck.jpg"
        ]

    if len(args.images) > 0:
        log_images_segmentation(args, model, predictor)
    elif args.video_path is not None:
        log_video_segmentation(args, model, predictor)

    rr.script_teardown(args)


if __name__ == "__main__":
    main()