import functools
import os
from typing import List, Union

import numpy as np
import PIL
from PIL.Image import Image
import supervision as sv
import torch
import torchvision
from huggingface_hub import hf_hub_download

from groundingdino.util.inference import Model
from sam_extension.pipeline import Pipeline
# Local paths for the GroundingDINO config and checkpoint; missing weights are
# fetched from the Hugging Face Hub repo below.
GROUNDING_DINO_CONFIG_PATH = "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"
GROUNDING_DINO_CHECKPOINT_PATH = "groundingdino_swint_ogc.pth"
SAM_REPO_ID = 'YouLiXiya/YL-SAM'
LOCAL_DIR = "weights/groundingdino"

# Convenience wrapper: download a file from SAM_REPO_ID into LOCAL_DIR.
hf_sam_download = functools.partial(hf_hub_download,
                                    repo_id=SAM_REPO_ID,
                                    local_dir=LOCAL_DIR,
                                    local_dir_use_symlinks=True)
class GroundingDinoPipeline(Pipeline):
    def __init__(self,
                 grounding_dino_config_path,
                 grounding_dino_ckpt_path,
                 grounding_dino_model,
                 device,
                 *args,
                 **kwargs):
        super().__init__(*args, **kwargs)
        self.grounding_dino_config_path = grounding_dino_config_path
        self.grounding_dino_ckpt_path = grounding_dino_ckpt_path
        self.grounding_dino_model = grounding_dino_model
        self.device = device
    @classmethod
    def from_pretrained(cls,
                        grounding_dino_config_path,
                        grounding_dino_ckpt_path,
                        device='cuda',
                        *args,
                        **kwargs):
        # Download the checkpoint from the Hub if it is not available locally.
        if not os.path.exists(grounding_dino_ckpt_path):
            hf_sam_download(filename=os.path.basename(grounding_dino_ckpt_path))
        grounding_dino_model = Model(model_config_path=grounding_dino_config_path,
                                     model_checkpoint_path=grounding_dino_ckpt_path,
                                     device=device)
        return cls(grounding_dino_config_path,
                   grounding_dino_ckpt_path,
                   grounding_dino_model,
                   device,
                   *args,
                   **kwargs)
    def visualize_results(self,
                          img: Union[Image, np.ndarray],
                          class_list: List[str],
                          box_threshold: float = 0.25,
                          text_threshold: float = 0.25,
                          nms_threshold: float = 0.8,
                          pil: bool = True):
        # Work on a BGR numpy array so detection and annotation see the same image.
        if isinstance(img, Image):
            img = np.uint8(img)[:, :, ::-1]
        detections = self.forward(img, class_list, box_threshold, text_threshold)
        box_annotator = sv.BoxAnnotator()
        # Drop overlapping boxes with non-maximum suppression.
        nms_idx = torchvision.ops.nms(
            torch.from_numpy(detections.xyxy),
            torch.from_numpy(detections.confidence),
            nms_threshold
        ).numpy().tolist()
        detections.xyxy = detections.xyxy[nms_idx]
        detections.confidence = detections.confidence[nms_idx]
        detections.class_id = detections.class_id[nms_idx]
        labels = [
            f"{class_list[class_id]} {confidence:0.2f}"
            for _, _, confidence, class_id, _ in detections
        ]
        annotated_frame = box_annotator.annotate(scene=img.copy(), detections=detections, labels=labels)
        if pil:
            # Convert BGR back to RGB for the returned PIL image.
            return PIL.Image.fromarray(annotated_frame[:, :, ::-1]), detections
        else:
            return annotated_frame, detections
    def forward(self,
                img: Union[Image, np.ndarray],
                class_list: List[str],
                box_threshold: float = 0.25,
                text_threshold: float = 0.25) -> sv.Detections:
        # GroundingDINO's Model.predict_with_classes expects a BGR numpy array.
        if isinstance(img, Image):
            img = np.uint8(img)[:, :, ::-1]
        detections = self.grounding_dino_model.predict_with_classes(
            image=img,
            classes=class_list,
            box_threshold=box_threshold,
            text_threshold=text_threshold
        )
        return detections
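

# Minimal usage sketch (not part of the pipeline itself): builds the pipeline from the
# constants above and runs open-vocabulary detection on a local image. The image path
# 'example.jpg' and the class list are placeholders; adjust them for your data, and
# pass device='cpu' if CUDA is unavailable.
if __name__ == '__main__':
    pipeline = GroundingDinoPipeline.from_pretrained(GROUNDING_DINO_CONFIG_PATH,
                                                     GROUNDING_DINO_CHECKPOINT_PATH,
                                                     device='cuda')
    image = PIL.Image.open('example.jpg')  # placeholder path
    annotated, detections = pipeline.visualize_results(image, ['person', 'dog'])
    annotated.save('annotated_example.jpg')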