Spaces:
Runtime error
Runtime error
File size: 5,106 Bytes
df6c67d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 |
import os
import urllib.request
from time import perf_counter
from typing import Any
import torch
from groundingdino.util.inference import Model
from inference.core.entities.requests.groundingdino import GroundingDINOInferenceRequest
from inference.core.entities.requests.inference import InferenceRequestImage
from inference.core.entities.responses.inference import (
InferenceResponseImage,
ObjectDetectionInferenceResponse,
ObjectDetectionPrediction,
)
from inference.core.env import MODEL_CACHE_DIR
from inference.core.models.roboflow import RoboflowCoreModel
from inference.core.utils.image_utils import load_image_rgb, xyxy_to_xywh
class GroundingDINO(RoboflowCoreModel):
    """GroundingDINO class for zero-shot object detection.

    Attributes:
        model: The GroundingDINO model.
        class_names: The text prompts used by the most recent ``infer`` call;
            detection class ids index into this list.
    """

    # Default cut-offs forwarded to ``Model.predict_with_classes``. They can be
    # overridden per call through the ``box_threshold`` / ``text_threshold``
    # keywords of ``infer``.
    DEFAULT_BOX_THRESHOLD = 0.5
    DEFAULT_TEXT_THRESHOLD = 0.5

    def __init__(
        self, *args, model_id="grounding_dino/groundingdino_swint_ogc", **kwargs
    ):
        """Initializes the GroundingDINO model.

        Ensures the model config file exists in the local model cache
        (downloading it if missing), then builds the GroundingDINO ``Model``
        on CUDA when available, otherwise on CPU.

        Args:
            *args: Variable length argument list, forwarded to ``RoboflowCoreModel``.
            model_id (str): Model identifier; also used as the cache sub-directory
                under ``MODEL_CACHE_DIR``.
            **kwargs: Arbitrary keyword arguments, forwarded to ``RoboflowCoreModel``.
        """
        super().__init__(*args, model_id=model_id, **kwargs)
        cache_dir = os.path.join(MODEL_CACHE_DIR, model_id)
        config_path = os.path.join(cache_dir, "GroundingDINO_SwinT_OGC.py")
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(cache_dir, exist_ok=True)
        if not os.path.exists(config_path):
            config_url = "https://raw.githubusercontent.com/roboflow/GroundingDINO/main/groundingdino/config/GroundingDINO_SwinT_OGC.py"
            # Download under a temporary name and rename only once the download
            # has finished, so an interrupted download never leaves a truncated
            # config behind that later runs would treat as valid.
            partial_path = config_path + ".part"
            urllib.request.urlretrieve(config_url, partial_path)
            os.replace(partial_path, config_path)
        self.model = Model(
            model_config_path=config_path,
            # NOTE(review): this is the file named in get_infer_bucket_file_list;
            # presumably RoboflowCoreModel.__init__ fetches it into this
            # directory — confirm.
            model_checkpoint_path=os.path.join(
                cache_dir, "groundingdino_swint_ogc.pth"
            ),
            device="cuda" if torch.cuda.is_available() else "cpu",
        )

    def preproc_image(self, image: Any):
        """Preprocesses an image.

        Args:
            image (Any): The image to preprocess, in any form accepted by
                ``load_image_rgb``.

        Returns:
            np.ndarray: The image as returned by ``load_image_rgb``.
        """
        return load_image_rgb(image)

    def infer_from_request(
        self,
        request: GroundingDINOInferenceRequest,
    ) -> ObjectDetectionInferenceResponse:
        """
        Perform inference based on the details provided in the request, and return the associated responses.

        Every field of ``request`` is passed to ``infer`` as a keyword argument;
        fields ``infer`` does not name are absorbed by its ``**kwargs``.
        """
        result = self.infer(**request.dict())
        return result

    def infer(
        self,
        image: Any = None,
        text: list = None,
        class_filter: list = None,
        box_threshold: float = None,
        text_threshold: float = None,
        **kwargs,
    ):
        """
        Run zero-shot detection on a provided image.

        Args:
            image (Any): The image to run on, in any form accepted by
                ``preproc_image``.
            text (List[str]): Text prompts (class names) to detect. Detection
                class ids index into this list.
            class_filter (Optional[List[str]]): If provided and non-empty, only
                detections whose class name is in this list are returned.
            box_threshold (Optional[float]): Box threshold for the model;
                ``None`` means ``DEFAULT_BOX_THRESHOLD``.
            text_threshold (Optional[float]): Text threshold for the model;
                ``None`` means ``DEFAULT_TEXT_THRESHOLD``.
            **kwargs: Ignored; allows passing a whole request dict.

        Returns:
            ObjectDetectionInferenceResponse: The detections, the image size and
            the elapsed inference time in seconds.
        """
        if box_threshold is None:
            box_threshold = self.DEFAULT_BOX_THRESHOLD
        if text_threshold is None:
            text_threshold = self.DEFAULT_TEXT_THRESHOLD

        t1 = perf_counter()
        image = self.preproc_image(image)
        img_dims = image.shape
        detections = self.model.predict_with_classes(
            image=image,
            classes=text,
            box_threshold=box_threshold,
            text_threshold=text_threshold,
        )
        self.class_names = text

        predictions = []
        for xyxy, confidence, class_id in zip(
            detections.xyxy, detections.confidence, detections.class_id
        ):
            # NOTE(review): GroundingDINO's phrase->class mapping appears to
            # yield None when a predicted phrase matches no prompt; such boxes
            # have no class to report, so skip them instead of crashing in int().
            if class_id is None:
                continue
            class_id = int(class_id)
            class_name = self.class_names[class_id]
            if class_filter and class_name not in class_filter:
                continue
            xywh = xyxy_to_xywh(xyxy)
            predictions.append(
                ObjectDetectionPrediction(
                    **{
                        "x": xywh[0],
                        "y": xywh[1],
                        "width": xywh[2],
                        "height": xywh[3],
                        "confidence": float(confidence),
                        "class": class_name,
                        "class_id": class_id,
                    }
                )
            )
        elapsed = perf_counter() - t1

        return ObjectDetectionInferenceResponse(
            predictions=predictions,
            # image.shape is indexed as (height, width, ...) here.
            image=InferenceResponseImage(width=img_dims[1], height=img_dims[0]),
            time=elapsed,
        )

    def get_infer_bucket_file_list(self) -> list:
        """Get the list of required files for inference.

        Returns:
            list: A list of required files for inference, e.g., ["model.pt"].
        """
        return ["groundingdino_swint_ogc.pth"]
|