# Spaces: Runtime error
from typing import Tuple, List | |
import torch | |
from groundingdino.util.utils import get_phrases_from_posmap | |
def preprocess_caption(caption: str) -> str:
    """Return *caption* lower-cased, stripped, and terminated by a period.

    The caption is lower-cased and surrounding whitespace is removed. If the
    result does not already end with ``"."``, one is appended; an empty or
    whitespace-only caption therefore becomes ``"."``.
    """
    result = caption.lower().strip()
    if result.endswith("."):
        return result
    return result + "."
def predict(
    model,
    image: torch.Tensor,
    caption: str,
    box_threshold: float,
    text_threshold: float,
    device: str = "cuda"
) -> Tuple[torch.Tensor, torch.Tensor, List[str]]:
    """Run the model on one image and caption; return boxes, scores, phrases.

    Args:
        model: Callable as ``model(images, captions=[...])`` returning a
            mapping with ``"pred_logits"`` and ``"pred_boxes"``; must also
            expose ``.to(device)`` and a ``.tokenizer`` attribute.
        image: A single image tensor; a leading batch axis is added here.
        caption: Text prompt; normalized with ``preprocess_caption``.
        box_threshold: A query is kept when its highest per-token score
            (after sigmoid) is strictly greater than this value.
        text_threshold: Per-token score cut-off used to build the boolean
            position map passed to ``get_phrases_from_posmap``.
        device: Device the model and image are moved to before inference.

    Returns:
        A tuple ``(boxes, scores, phrases)`` where ``boxes`` holds the kept
        rows of ``pred_boxes`` (shape ``(n, 4)``), ``scores`` is the highest
        per-token score of each kept query (shape ``(n,)``), and ``phrases``
        is one string per kept query with every ``"."`` removed.
    """
    caption = preprocess_caption(caption=caption)

    model = model.to(device)
    image = image.to(device)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        outputs = model(image[None], captions=[caption])

    # Take the single batch element and bring results back to the CPU.
    prediction_logits = outputs["pred_logits"].cpu().sigmoid()[0]  # (nq, 256)
    prediction_boxes = outputs["pred_boxes"].cpu()[0]  # (nq, 4)

    # Best per-token score of every query; computed once and reused for both
    # the filtering mask and the returned scores.
    max_scores = prediction_logits.max(dim=1)[0]  # (nq,)
    mask = max_scores > box_threshold

    logits = prediction_logits[mask]  # (n, 256)
    boxes = prediction_boxes[mask]  # (n, 4)

    tokenizer = model.tokenizer
    tokenized = tokenizer(caption)

    phrases = [
        get_phrases_from_posmap(logit > text_threshold, tokenized, tokenizer).replace('.', '')
        for logit in logits
    ]

    return boxes, max_scores[mask], phrases