import logging
import os
from pathlib import Path
from typing import Final, List, Mapping
from urllib.parse import urlparse

import cv2
import numpy as np
import requests
import rerun as rr
import torch
import torchvision
from cv2 import Mat
from PIL import Image
from segment_anything import SamPredictor, sam_model_registry
from segment_anything.modeling import Sam
from tqdm import tqdm

# Grounding DINO
import GroundingDINO.groundingdino.datasets.transforms as T
from GroundingDINO.groundingdino.models import GroundingDINO, build_model
from GroundingDINO.groundingdino.util.slconfig import SLConfig
from GroundingDINO.groundingdino.util.utils import (
    clean_state_dict,
    get_phrases_from_posmap,
)

CONFIG_PATH: Final = (
    Path(os.path.dirname(__file__))
    / "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"
)
MODEL_DIR: Final = Path(os.path.dirname(__file__)) / "model"
MODEL_URLS: Final = {
    "vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
    "vit_l": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
    "vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
    "grounding": "https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth",
}


def download_with_progress(url: str, dest: Path) -> None:
    """Download a file with a tqdm progress bar."""
    chunk_size = 1024 * 1024
    resp = requests.get(url, stream=True)
    total_size = int(resp.headers.get("content-length", 0))
    with open(dest, "wb") as dest_file:
        with tqdm(
            desc="Downloading model",
            total=total_size,
            unit="iB",
            unit_scale=True,
            unit_divisor=1024,
        ) as progress:
            for data in resp.iter_content(chunk_size):
                dest_file.write(data)
                progress.update(len(data))


def get_downloaded_model_path(model_name: str) -> Path:
    """Fetch a model checkpoint into a local cache directory, downloading it if necessary."""
    model_url = MODEL_URLS[model_name]

    model_location = MODEL_DIR / model_url.split("/")[-1]
    if not model_location.exists():
        os.makedirs(MODEL_DIR, exist_ok=True)
        download_with_progress(model_url, model_location)

    return model_location


def create_sam(model: str, device: str) -> Sam:
    """Load the segment-anything model, fetching the model-file as necessary."""
    model_path = get_downloaded_model_path(model)

    logging.info("PyTorch version: {}".format(torch.__version__))
    logging.info("Torchvision version: {}".format(torchvision.__version__))
    logging.info("CUDA is available: {}".format(torch.cuda.is_available()))

    logging.info("Building sam from: {}".format(model_path))
    sam = sam_model_registry[model](checkpoint=model_path)
    return sam.to(device=device)


def run_segmentation(
    predictor: SamPredictor,
    image: Mat,
    detections: torch.Tensor,
    phrases: List[str],
    id_from_phrase: Mapping[str, int],
) -> None:
    """Run SAM segmentation on a single image, prompted with the detected boxes."""
    if detections.shape[0] == 0:
        return

    logging.info("Finding masks")
    transformed_boxes = predictor.transform.apply_boxes_torch(
        detections, image.shape[:2]
    )

    masks, _, _ = predictor.predict_torch(
        point_coords=None,
        point_labels=None,
        boxes=transformed_boxes.to(predictor.device),
        multimask_output=False,
    )

    logging.info("Found {} masks".format(len(masks)))

    # Layer all of the masks that belong to a single phrase together
    segmentation_img = np.zeros((image.shape[0], image.shape[1]))
    for phrase, mask in zip(phrases, masks):
        segmentation_img[mask.squeeze()] = id_from_phrase[phrase]

    rr.log_segmentation_image("image/segmentation", segmentation_img)
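

# `run_segmentation` prompts SAM with absolute (x0, y0, x1, y1) boxes, while
# Grounding DINO's `pred_boxes` (returned by `get_grounding_output` below) are
# normalized (cx, cy, w, h). A minimal conversion sketch follows; the helper
# name `boxes_to_xyxy` is our own, not part of either library.
def boxes_to_xyxy(boxes_cxcywh: torch.Tensor, image: Mat) -> torch.Tensor:
    """Convert normalized (cx, cy, w, h) boxes to absolute (x0, y0, x1, y1) pixel boxes."""
    height, width = image.shape[:2]
    boxes = boxes_cxcywh * torch.tensor([width, height, width, height])
    boxes[:, :2] -= boxes[:, 2:] / 2  # (cx, cy) -> top-left corner
    boxes[:, 2:] += boxes[:, :2]  # (w, h) -> bottom-right corner
    return boxes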


def is_url(path: str) -> bool:
    """Check if a path is a url or a local file."""
    try:
        result = urlparse(path)
        return all([result.scheme, result.netloc])
    except ValueError:
        return False


def resize_img(img: Mat, max_dimension: int = 512) -> Mat:
    """Downscale the image, preserving aspect ratio, so its largest side is at most `max_dimension`."""
    height, width = img.shape[:2]

    # Return the image unchanged if neither dimension exceeds the maximum
    if max(height, width) <= max_dimension:
        return img

    # Calculate the new dimensions while maintaining the aspect ratio
    if height > width:
        new_height = max_dimension
        new_width = int((new_height * width) / height)
    else:
        new_width = max_dimension
        new_height = int((new_width * height) / width)

    # Resize the image
    return cv2.resize(img, (new_width, new_height), interpolation=cv2.INTER_AREA)


def image_to_tensor(image: Mat) -> torch.Tensor:
    """Convert an RGB OpenCV image into the normalized tensor expected by the DINO model."""
    image_pil = Image.fromarray(image)

    transform = T.Compose(
        [
            T.RandomResize([800], max_size=1333),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )
    image_tensor, _ = transform(image_pil, None)  # 3, h, w
    return image_tensor


def load_image(image_uri: str) -> Mat:
    """Conditionally download an image from a URL or load it from disk, returning it as RGB."""
    logging.info("Loading: {}".format(image_uri))
    if is_url(image_uri):
        response = requests.get(image_uri)
        response.raise_for_status()
        image_data = np.asarray(bytearray(response.content), dtype="uint8")
        image = cv2.imdecode(image_data, cv2.IMREAD_COLOR)
    else:
        image = cv2.imread(image_uri, cv2.IMREAD_COLOR)

    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)


def load_grounding_model(
    model_config_path: Path, model_checkpoint_path: Path, device: str
) -> GroundingDINO:
    """Load the Grounding DINO model from a config file and a checkpoint."""
    args = SLConfig.fromfile(model_config_path)
    args.device = device
    model = build_model(args)
    checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
    _ = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
    _ = model.eval()
    return model


def get_grounding_output(
    model: GroundingDINO,
    image: torch.Tensor,
    caption: str,
    box_threshold: float,
    text_threshold: float,
    with_logits: bool = False,
    device: str = "cpu",
):
    """Detect boxes matching the caption, returning them along with their predicted phrases."""
    caption = caption.lower().strip()
    if not caption.endswith("."):
        caption = caption + "."

    model = model.to(device)
    image = image.to(device)
    with torch.no_grad():
        outputs = model(image[None], captions=[caption])
    logits = outputs["pred_logits"].cpu().sigmoid()[0]  # (nq, 256)
    boxes = outputs["pred_boxes"].cpu()[0]  # (nq, 4)

    # Keep only the queries whose best token score clears the box threshold
    filt_mask = logits.max(dim=1)[0] > box_threshold
    logits_filt = logits[filt_mask]  # (num_filt, 256)
    boxes_filt = boxes[filt_mask]  # (num_filt, 4)

    # Decode a phrase for each remaining box from the tokens above the text threshold
    tokenizer = model.tokenizer
    tokenized = tokenizer(caption)
    pred_phrases = []
    for logit in logits_filt:
        pred_phrase = get_phrases_from_posmap(
            logit > text_threshold, tokenized, tokenizer
        )
        if with_logits:
            pred_phrases.append(pred_phrase + f"({str(logit.max().item())[:4]})")
        else:
            pred_phrases.append(pred_phrase)

    return boxes_filt, pred_phrases
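

# A minimal end-to-end sketch showing how the helpers above fit together. This
# `__main__` block is illustrative, not the example's real CLI: the image URL,
# prompt, thresholds, and application id below are placeholder assumptions.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    rr.init("grounded_sam", spawn=True)

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load both models, downloading the checkpoints on first use.
    grounding_model = load_grounding_model(
        CONFIG_PATH, get_downloaded_model_path("grounding"), device
    )
    predictor = SamPredictor(create_sam("vit_b", device))

    # Load, resize, and log the input image (placeholder URL).
    image = resize_img(load_image("https://example.com/image.jpg"))
    rr.log_image("image", image)

    # Detect boxes for a text prompt, then feed them to SAM as box prompts.
    boxes, phrases = get_grounding_output(
        grounding_model,
        image_to_tensor(image),
        "a dog",  # placeholder prompt
        box_threshold=0.35,
        text_threshold=0.25,
        device=device,
    )
    # Class id 0 is reserved for the background in segmentation images.
    id_from_phrase = {phrase: i + 1 for i, phrase in enumerate(dict.fromkeys(phrases))}

    predictor.set_image(image)
    run_segmentation(
        predictor, image, boxes_to_xyxy(boxes, image), phrases, id_from_phrase
    )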