# Source: OMG/inference/models/grounding_dino/grounding_dino.py
# Uploaded by Fucius ("Upload 422 files"), commit df6c67d (verified), 5.11 kB.
import os
import urllib.request
from time import perf_counter
from typing import Any
import torch
from groundingdino.util.inference import Model
from inference.core.entities.requests.groundingdino import GroundingDINOInferenceRequest
from inference.core.entities.requests.inference import InferenceRequestImage
from inference.core.entities.responses.inference import (
InferenceResponseImage,
ObjectDetectionInferenceResponse,
ObjectDetectionPrediction,
)
from inference.core.env import MODEL_CACHE_DIR
from inference.core.models.roboflow import RoboflowCoreModel
from inference.core.utils.image_utils import load_image_rgb, xyxy_to_xywh
class GroundingDINO(RoboflowCoreModel):
    """GroundingDINO class for zero-shot object detection.

    Attributes:
        model: The GroundingDINO model.
        class_names: The class prompts (``text``) passed to the most recent
            ``infer`` call. Only set once ``infer`` has run.
    """

    def __init__(
        self, *args, model_id="grounding_dino/groundingdino_swint_ogc", **kwargs
    ):
        """Initializes the GroundingDINO model.

        Ensures the model cache directory exists, downloads the model config
        file into it if it is not already there, and builds the underlying
        GroundingDINO ``Model`` on CUDA when available (CPU otherwise).

        Args:
            *args: Variable length argument list, forwarded to the parent class.
            model_id: Model identifier; also used as the sub-directory of
                ``MODEL_CACHE_DIR`` holding the config and checkpoint.
            **kwargs: Arbitrary keyword arguments, forwarded to the parent class.
        """
        super().__init__(*args, model_id=model_id, **kwargs)

        cache_dir = os.path.join(MODEL_CACHE_DIR, model_id)
        config_path = os.path.join(cache_dir, "GroundingDINO_SwinT_OGC.py")
        # The checkpoint is not downloaded here. It is the file named by
        # get_infer_bucket_file_list(); presumably the parent class has already
        # placed it in the cache directory -- TODO confirm.
        checkpoint_path = os.path.join(cache_dir, "groundingdino_swint_ogc.pth")

        # exist_ok avoids the check-then-create race when several workers
        # initialise the model at the same time.
        os.makedirs(cache_dir, exist_ok=True)

        if not os.path.exists(config_path):
            url = "https://raw.githubusercontent.com/roboflow/GroundingDINO/main/groundingdino/config/GroundingDINO_SwinT_OGC.py"
            urllib.request.urlretrieve(url, config_path)

        self.model = Model(
            model_config_path=config_path,
            model_checkpoint_path=checkpoint_path,
            device="cuda" if torch.cuda.is_available() else "cpu",
        )

    def preproc_image(self, image: Any):
        """Preprocesses an image.

        Args:
            image (InferenceRequestImage): The image to preprocess.

        Returns:
            np.array: The image as loaded by ``load_image_rgb``.
        """
        np_image = load_image_rgb(image)
        return np_image

    def infer_from_request(
        self,
        request: GroundingDINOInferenceRequest,
    ) -> ObjectDetectionInferenceResponse:
        """
        Perform inference based on the details provided in the request, and return the associated responses.

        Every field of the request is forwarded to ``infer`` as a keyword
        argument; fields ``infer`` does not name are absorbed by its ``**kwargs``.
        """
        result = self.infer(**request.dict())
        return result

    def infer(
        self, image: Any = None, text: list = None, class_filter: list = None, **kwargs
    ):
        """
        Run inference on a provided image.

        Args:
            image (Any): The image to run detection on, in any form accepted by ``load_image_rgb``.
            text (List[str]): The class names to detect; a prediction's ``class_id`` is an index into this list.
            class_filter (Optional[List[str]]): A list of class names to filter, if provided.
            **kwargs: Ignored; accepted so that a full request dict can be passed.

        Returns:
            ObjectDetectionInferenceResponse: The inference response.
        """
        t1 = perf_counter()
        image = self.preproc_image(image)
        img_dims = image.shape

        detections = self.model.predict_with_classes(
            image=image,
            classes=text,
            box_threshold=0.5,
            text_threshold=0.5,
        )
        # Kept as instance state because the attribute is publicly visible.
        self.class_names = text

        xywh_bboxes = [xyxy_to_xywh(detection) for detection in detections.xyxy]
        t2 = perf_counter() - t1

        predictions = []
        for i, bbox in enumerate(xywh_bboxes):
            raw_class_id = detections.class_id[i]
            # NOTE(review): GroundingDINO appears to report None as the class id
            # when a detected phrase matches none of the requested classes --
            # confirm against the library. Such a detection cannot be named,
            # so it is skipped instead of crashing on int(None).
            if raw_class_id is None:
                continue
            class_id = int(raw_class_id)
            class_name = text[class_id]
            # The class id lives in detections.class_id, not in the box row
            # (a row of detections.xyxy only holds the four coordinates).
            if class_filter and class_name not in class_filter:
                continue
            predictions.append(
                # "class" is a Python keyword, hence the dict unpacking.
                ObjectDetectionPrediction(
                    **{
                        "x": bbox[0],
                        "y": bbox[1],
                        "width": bbox[2],
                        "height": bbox[3],
                        "confidence": detections.confidence[i],
                        "class": class_name,
                        "class_id": class_id,
                    }
                )
            )

        responses = ObjectDetectionInferenceResponse(
            predictions=predictions,
            image=InferenceResponseImage(width=img_dims[1], height=img_dims[0]),
            time=t2,
        )
        return responses

    def get_infer_bucket_file_list(self) -> list:
        """Get the list of required files for inference.

        Returns:
            list: A list of required files for inference, e.g., ["model.pt"].
        """
        return ["groundingdino_swint_ogc.pth"]