"""Grounding DINO text-prompted detection pipeline.

Wraps ``groundingdino.util.inference.Model`` behind the project's ``Pipeline``
interface and adds an NMS + box-drawing visualisation helper.
"""
import os
import functools

import PIL
from PIL.Image import Image
import numpy as np
from typing import List, Union
import supervision as sv
import torch
import torchvision
from huggingface_hub import hf_hub_download

from sam_extension.pipeline import Pipeline
from groundingdino.util.inference import Model

GROUNDING_DINO_CONFIG_PATH = "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"
GROUNDING_DINO_CHECKPOINT_PATH = "groundingdino_swint_ogc.pth"
SAM_REPO_ID = 'YouLiXiya/YL-SAM'
LOCAL_DIR = "weights/groundingdino"

# Download helper bound to SAM_REPO_ID; downloaded files are placed under LOCAL_DIR.
hf_sam_download = functools.partial(
    hf_hub_download,
    repo_id=SAM_REPO_ID,
    local_dir=LOCAL_DIR,
    local_dir_use_symlinks=True,
)


def _pil_to_bgr(img: Image) -> np.ndarray:
    """Return a contiguous uint8 (H, W, 3) array of *img* with channels reversed (RGB -> BGR)."""
    # convert("RGB") first so RGBA / grayscale / palette images still yield exactly
    # three channels; the original slice broke on those modes.
    rgb = np.asarray(img.convert("RGB"), dtype=np.uint8)
    return np.ascontiguousarray(rgb[:, :, ::-1])


class GroundingDinoPipeline(Pipeline):
    """Text-prompted box detector backed by a Grounding DINO ``Model``."""

    def __init__(self, grounding_dino_config_path, grounfing_dino_ckpt_path, grounding_dino_model, device,
                 *args, **kwargs):
        """Store an already-constructed Grounding DINO model.

        Args:
            grounding_dino_config_path: Path to the Grounding DINO config file.
            grounfing_dino_ckpt_path: Path to the checkpoint file. The misspelled
                name is kept so existing keyword callers keep working.
            grounding_dino_model: A constructed ``groundingdino.util.inference.Model``.
            device: Device the model was built for (e.g. ``'cuda'``).
            *args, **kwargs: Forwarded to ``Pipeline.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.grounding_dino_config_path = grounding_dino_config_path
        self.grounfing_dino_ckpt_path = grounfing_dino_ckpt_path
        self.grounding_dino_model = grounding_dino_model
        self.device = device

    @classmethod
    def from_pretrained(cls, grounding_dino_config_path, grounfing_dino_ckpt_path, device='cuda', *args, **kwargs):
        """Build the pipeline, downloading the checkpoint if it is missing locally.

        If ``grounfing_dino_ckpt_path`` does not exist, the file with the same
        basename is fetched from ``SAM_REPO_ID`` into ``LOCAL_DIR``. The downloaded
        location is then used as the checkpoint path.

        Args:
            grounding_dino_config_path: Path to the Grounding DINO config file.
            grounfing_dino_ckpt_path: Desired checkpoint path.
            device: Device passed to ``Model`` (default ``'cuda'``).
            *args, **kwargs: Forwarded to the constructor.

        Returns:
            A new ``GroundingDinoPipeline``.
        """
        if not os.path.exists(grounfing_dino_ckpt_path):
            # hf_hub_download stores the file under LOCAL_DIR, which generally differs
            # from the requested path; use the path it returns so Model() loads the
            # file that was actually downloaded.
            grounfing_dino_ckpt_path = hf_sam_download(filename=os.path.basename(grounfing_dino_ckpt_path))
        grounding_dino_model = Model(
            model_config_path=grounding_dino_config_path,
            model_checkpoint_path=grounfing_dino_ckpt_path,
            device=device,
        )
        return cls(grounding_dino_config_path, grounfing_dino_ckpt_path, grounding_dino_model, device,
                   *args, **kwargs)

    def visualize_results(self, img: Union[Image, np.ndarray], class_list: List[str], box_threshold: float = 0.25,
                          text_threshold: float = 0.25, nms_threshold: float = 0.8, pil: bool = True):
        """Detect *class_list* in *img*, apply NMS, and draw labelled boxes.

        Args:
            img: A PIL image, or a NumPy array in the channel order the model is fed
                (PIL inputs are channel-reversed RGB -> BGR before detection, so
                arrays are presumably expected as BGR).
            class_list: Class names used as text prompts; ``class_id`` indexes this list.
            box_threshold: Box score threshold forwarded to the model.
            text_threshold: Text score threshold forwarded to the model.
            nms_threshold: IoU threshold for ``torchvision.ops.nms``.
            pil: If True, return the annotated image as a PIL image (channels
                reversed back); otherwise return the annotated array.

        Returns:
            ``(annotated_image, detections)`` where *detections* has been NMS-filtered.
        """
        if isinstance(img, Image):
            # Annotate the same array representation forward() feeds the model;
            # drawing on the raw PIL object broke the [:, :, ::-1] slicing below.
            img = _pil_to_bgr(img)
        detections = self.forward(img, class_list, box_threshold, text_threshold)

        # torchvision's NMS requires boxes and scores of the same floating dtype.
        keep = torchvision.ops.nms(
            torch.from_numpy(detections.xyxy).float(),
            torch.from_numpy(detections.confidence).float(),
            nms_threshold,
        ).numpy().tolist()
        detections.xyxy = detections.xyxy[keep]
        detections.confidence = detections.confidence[keep]
        detections.class_id = detections.class_id[keep]

        # Read the fields directly rather than unpacking iteration tuples, whose
        # arity differs between supervision releases.
        labels = [
            f"{class_list[class_id]} {confidence:0.2f}"
            for confidence, class_id in zip(detections.confidence, detections.class_id)
        ]
        annotated_frame = sv.BoxAnnotator().annotate(scene=img.copy(), detections=detections, labels=labels)
        if pil:
            return PIL.Image.fromarray(annotated_frame[:, :, ::-1]), detections
        return annotated_frame, detections

    @torch.no_grad()
    def forward(self, img: Union[Image, np.ndarray], class_list: List[str], box_threshold: float = 0.25,
                text_threshold: float = 0.25) -> sv.Detections:
        """Run text-prompted detection without gradient tracking.

        Args:
            img: A PIL image (converted to RGB, then channel-reversed) or a NumPy
                array passed to the model unchanged.
            class_list: Class names used as text prompts.
            box_threshold: Box score threshold for ``predict_with_classes``.
            text_threshold: Text score threshold for ``predict_with_classes``.

        Returns:
            The ``sv.Detections`` produced by ``Model.predict_with_classes``.
        """
        if isinstance(img, Image):
            img = _pil_to_bgr(img)
        return self.grounding_dino_model.predict_with_classes(
            image=img,
            classes=class_list,
            box_threshold=box_threshold,
            text_threshold=text_threshold,
        )