import logging
from typing import Tuple, List

import torch

from groundingdino.util.utils import get_phrases_from_posmap

logger = logging.getLogger(__name__)


def preprocess_caption(caption: str) -> str:
    """Return *caption* lower-cased, stripped, and terminated by a period.

    A trailing "." is appended only when the normalized caption does not
    already end with one.
    """
    result = caption.lower().strip()
    if result.endswith("."):
        return result
    return result + "."


def predict(
    model,
    image: torch.Tensor,
    caption: str,
    box_threshold: float,
    text_threshold: float,
    device: str = "cuda",
) -> Tuple[torch.Tensor, torch.Tensor, List[str]]:
    """Run the model on one image and return the detections above threshold.

    Args:
        model: Callable model exposing ``.to(device)`` and ``.tokenizer``; it is
            invoked as ``model(image[None], captions=[caption])`` and must
            return a mapping with ``"pred_logits"`` and ``"pred_boxes"``.
        image: A single image tensor; a leading batch axis is added here.
        caption: Text prompt; normalized with ``preprocess_caption`` first.
        box_threshold: A query is kept when its highest per-token score
            (after sigmoid) is strictly greater than this value.
        text_threshold: Per-token score cutoff used to build the boolean
            position map passed to ``get_phrases_from_posmap``.
        device: Device the model and the image are moved to before inference.

    Returns:
        A tuple ``(boxes, scores, phrases)`` where ``boxes`` holds the kept
        rows of ``pred_boxes``, ``scores`` holds each kept query's highest
        per-token score, and ``phrases`` holds one string per kept query with
        every "." removed.
    """
    caption = preprocess_caption(caption=caption)

    model = model.to(device)
    image = image.to(device)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model(image[None], captions=[caption])

    # Results are moved to the CPU and the single batch entry is selected.
    prediction_logits = outputs["pred_logits"].cpu().sigmoid()[0]  # prediction_logits.shape = (nq, 256)
    prediction_boxes = outputs["pred_boxes"].cpu()[0]  # prediction_boxes.shape = (nq, 4)

    # Highest token score per query, computed once and reused for both the
    # filtering mask and the returned confidences.
    max_scores = prediction_logits.max(dim=1)[0]
    mask = max_scores > box_threshold
    logits = prediction_logits[mask]  # logits.shape = (n, 256)
    boxes = prediction_boxes[mask]  # boxes.shape = (n, 4)

    tokenizer = model.tokenizer
    tokenized = tokenizer(caption)
    # Diagnostic output goes through logging rather than stdout so that
    # library callers are not flooded on every prediction.
    logger.debug("tokenized caption: %s", tokenized)

    phrases = [
        get_phrases_from_posmap(logit > text_threshold, tokenized, tokenizer).replace('.', '')
        for logit in logits
    ]

    return boxes, max_scores[mask], phrases