# Mobile-SAM / sam_extension/pipeline/groundingdino.py
# Source: Hugging Face file page, uploaded by YouLiXiya ("Upload 22 files"),
# commit 7dbe662, 3.77 kB, virus scan: no virus.
import os
import functools
import PIL
from PIL.Image import Image
import numpy as np
from typing import List, Union
import supervision as sv
import torch
import torchvision
from huggingface_hub import hf_hub_download
from sam_extension.pipeline import Pipeline
from groundingdino.util.inference import Model
# Default GroundingDINO config / checkpoint locations (relative paths).
GROUNDING_DINO_CONFIG_PATH = "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"
GROUNDING_DINO_CHECKPOINT_PATH = "groundingdino_swint_ogc.pth"

# Hugging Face repo that hosts the weights, and the local directory that
# downloads are written into.
SAM_REPO_ID = 'YouLiXiya/YL-SAM'
LOCAL_DIR = "weights/groundingdino"

# `hf_hub_download` with the repo and target directory pre-bound; callers only
# pass `filename=...`.
# NOTE(review): newer huggingface_hub releases deprecate `local_dir_use_symlinks`
# — confirm against the pinned version before relying on symlink behaviour.
hf_sam_download = functools.partial(
    hf_hub_download,
    repo_id=SAM_REPO_ID,
    local_dir=LOCAL_DIR,
    local_dir_use_symlinks=True,
)
class GroundingDinoPipeline(Pipeline):
    """Text-prompted object detection pipeline around a GroundingDINO ``Model``.

    Holds a constructed ``groundingdino.util.inference.Model`` and exposes
    :meth:`forward` (raw detections for a list of class prompts) and
    :meth:`visualize_results` (NMS-filtered detections drawn on the image).
    """

    def __init__(self,
                 grounding_dino_config_path,
                 grounfing_dino_ckpt_path,
                 grounding_dino_model,
                 device,
                 *args,
                 **kwargs):
        """Store an already-built GroundingDINO model and its settings.

        Args:
            grounding_dino_config_path: Path of the GroundingDINO config file.
            grounfing_dino_ckpt_path: Path of the checkpoint file. The misspelt
                parameter/attribute name is kept for backward compatibility.
            grounding_dino_model: ``Model`` instance used by :meth:`forward`.
            device: Device the model was created on (e.g. ``'cuda'``).
            *args, **kwargs: Forwarded unchanged to ``Pipeline.__init__``.
        """
        super(GroundingDinoPipeline, self).__init__(*args, **kwargs)
        self.grounding_dino_config_path = grounding_dino_config_path
        self.grounfing_dino_ckpt_path = grounfing_dino_ckpt_path
        self.grounding_dino_model = grounding_dino_model
        self.device = device

    @classmethod
    def from_pretrained(cls, grounding_dino_config_path, grounfing_dino_ckpt_path, device='cuda', *args, **kwargs):
        """Build the pipeline, downloading the checkpoint if it is not on disk.

        If ``grounfing_dino_ckpt_path`` does not exist, a file with the same
        base name is fetched via ``hf_sam_download`` (repo ``SAM_REPO_ID``,
        directory ``LOCAL_DIR``) and the downloaded file is what gets loaded.

        Args:
            grounding_dino_config_path: Path of the GroundingDINO config file.
            grounfing_dino_ckpt_path: Desired checkpoint path.
            device: Device passed to ``Model`` (default ``'cuda'``).
            *args, **kwargs: Forwarded to the constructor.

        Returns:
            A new ``GroundingDinoPipeline`` instance.
        """
        if not os.path.exists(grounfing_dino_ckpt_path):
            # The file lands under LOCAL_DIR, which generally differs from the
            # requested path, so load from the path the download returns.
            grounfing_dino_ckpt_path = hf_sam_download(
                filename=os.path.basename(grounfing_dino_ckpt_path))
        grounding_dino_model = Model(model_config_path=grounding_dino_config_path,
                                     model_checkpoint_path=grounfing_dino_ckpt_path,
                                     device=device)
        return cls(grounding_dino_config_path,
                   grounfing_dino_ckpt_path,
                   grounding_dino_model,
                   device,
                   *args,
                   **kwargs)

    @staticmethod
    def _to_bgr_array(img):
        """Return ``img`` as the numpy array handed to the model.

        PIL images are forced to 3-channel RGB (so grayscale / RGBA inputs do
        not break the channel flip), flipped to BGR, and made contiguous.
        Numpy arrays are returned unchanged (presumably already BGR — TODO
        confirm with callers).
        """
        if isinstance(img, Image):
            rgb = np.asarray(img.convert('RGB'), dtype=np.uint8)
            return np.ascontiguousarray(rgb[:, :, ::-1])
        return img

    @staticmethod
    def _class_name(class_list, class_id):
        """Map a detection's class id to its prompt text, tolerating ``None``."""
        # NOTE(review): predict_with_classes presumably yields a None class id
        # when a detected phrase matches no prompt; label it instead of
        # crashing on class_list[None].
        if class_id is None:
            return 'unknown'
        return class_list[int(class_id)]

    def visualize_results(self,
                          img: Union[Image, np.ndarray],
                          class_list: List[str],
                          box_threshold: float = 0.25,
                          text_threshold: float = 0.25,
                          nms_threshold: float = 0.8,
                          pil: bool = True):
        """Detect objects, apply NMS, and draw labelled boxes on a copy of ``img``.

        Args:
            img: PIL image or numpy array (see :meth:`forward`).
            class_list: Class prompts; detection class ids index into it.
            box_threshold: Box score threshold passed to the model.
            text_threshold: Text score threshold passed to the model.
            nms_threshold: IoU threshold for ``torchvision.ops.nms``.
            pil: If True, return the annotated image as a PIL image with the
                channel order flipped back; otherwise return the numpy array.

        Returns:
            ``(annotated_image, detections)``, where ``detections`` holds only
            the boxes kept by NMS.
        """
        # Normalise once so the model and the annotator see the same BGR
        # ndarray; previously PIL input reached the annotator unconverted.
        scene = self._to_bgr_array(img)
        detections = self.forward(scene, class_list, box_threshold, text_threshold)
        # nms requires boxes and scores of the same dtype.
        keep = torchvision.ops.nms(
            torch.from_numpy(detections.xyxy).float(),
            torch.from_numpy(detections.confidence).float(),
            nms_threshold,
        ).numpy()
        detections.xyxy = detections.xyxy[keep]
        detections.confidence = detections.confidence[keep]
        detections.class_id = detections.class_id[keep]
        # Zip the arrays directly: the tuple layout of iterating a Detections
        # object differs across supervision versions.
        labels = [
            f"{self._class_name(class_list, class_id)} {confidence:0.2f}"
            for confidence, class_id in zip(detections.confidence, detections.class_id)
        ]
        box_annotator = sv.BoxAnnotator()
        annotated_frame = box_annotator.annotate(scene=scene.copy(), detections=detections, labels=labels)
        if pil:
            return PIL.Image.fromarray(annotated_frame[:, :, ::-1]), detections
        return annotated_frame, detections

    @torch.no_grad()
    def forward(self,
                img: Union[Image, np.ndarray],
                class_list: List[str],
                box_threshold: float = 0.25,
                text_threshold: float = 0.25
                ) -> sv.Detections:
        """Run GroundingDINO on ``img`` using ``class_list`` as text prompts.

        Args:
            img: PIL image (converted to a contiguous BGR uint8 array) or a
                numpy array passed through unchanged.
            class_list: Class prompt strings.
            box_threshold: Box score threshold passed to the model.
            text_threshold: Text score threshold passed to the model.

        Returns:
            The ``sv.Detections`` returned by ``Model.predict_with_classes``.
        """
        return self.grounding_dino_model.predict_with_classes(
            image=self._to_bgr_array(img),
            classes=class_list,
            box_threshold=box_threshold,
            text_threshold=text_threshold,
        )