File size: 4,537 Bytes
2eafbc4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
from time import perf_counter
from typing import Any

from ultralytics import YOLO

from inference.core.cache import cache
from inference.core.entities.requests.yolo_world import YOLOWorldInferenceRequest
from inference.core.entities.responses.inference import (
    InferenceResponseImage,
    ObjectDetectionInferenceResponse,
    ObjectDetectionPrediction,
)
from inference.core.models.defaults import DEFAULT_CONFIDENCE
from inference.core.models.roboflow import RoboflowCoreModel
from inference.core.utils.hash import get_string_list_hash
from inference.core.utils.image_utils import load_image_rgb


class YOLOWorld(RoboflowCoreModel):
    """YOLO-World model wrapper for zero-shot object detection.

    Attributes:
        model: The Ultralytics ``YOLO`` instance loaded from ``yolo-world.pt``.
        class_names: Class names currently configured on the model, or ``None``
            until ``set_classes`` (or ``infer`` with ``text``) has been called.
    """

    def __init__(self, *args, model_id="yolo_world/l", **kwargs):
        """Initializes the YOLO-World model.

        Args:
            *args: Variable length argument list, forwarded to the base class.
            model_id (str): Identifier of the model to load. Defaults to
                ``"yolo_world/l"``.
            **kwargs: Arbitrary keyword arguments, forwarded to the base class.
        """

        super().__init__(*args, model_id=model_id, **kwargs)

        # NOTE(review): `cache_file` comes from the base class; presumably it
        # resolves the local path of a downloaded weight file -- confirm.
        self.model = YOLO(self.cache_file("yolo-world.pt"))
        self.class_names = None

    def preproc_image(self, image: Any):
        """Preprocesses an image.

        Args:
            image (InferenceRequestImage): The image to preprocess.

        Returns:
            np.array: The loaded image with its last (channel) axis reversed.
        """
        np_image = load_image_rgb(image)
        # Reverse the channel axis. Presumably RGB -> BGR, the channel order
        # Ultralytics expects for numpy inputs -- confirm against
        # `load_image_rgb`.
        return np_image[:, :, ::-1]

    def infer_from_request(
        self,
        request: YOLOWorldInferenceRequest,
    ) -> ObjectDetectionInferenceResponse:
        """
        Perform inference based on the details provided in the request, and return the associated responses.

        Args:
            request (YOLOWorldInferenceRequest): The request whose fields are
                unpacked as keyword arguments to ``infer``.

        Returns:
            ObjectDetectionInferenceResponse: The inference response.
        """
        return self.infer(**request.dict())

    def infer(
        self,
        image: Any = None,
        text: list = None,
        confidence: float = DEFAULT_CONFIDENCE,
        **kwargs,
    ) -> ObjectDetectionInferenceResponse:
        """
        Run inference on a provided image.

        Args:
            image (Any): The image to run detection on, in any form accepted
                by ``preproc_image``.
            text (list): Class names to detect. When provided and different
                from the currently configured names, the model is reconfigured
                via ``set_classes`` before predicting. When ``None``, the
                previously configured class names are used.
            confidence (float): Confidence threshold passed to the model as
                ``conf``.
            **kwargs: Extra keyword arguments; accepted and ignored so that a
                full request dict can be unpacked into this method.

        Returns:
            ObjectDetectionInferenceResponse: Predictions, the dimensions of
            the preprocessed image, and the elapsed time.

        Raises:
            ValueError: If no class names were set beforehand and none were
                provided via ``text``.
        """
        start = perf_counter()
        image = self.preproc_image(image)
        img_dims = image.shape

        if text is not None:
            # Materialise once so tuples/iterables compare by content against
            # the stored list and are not consumed twice.
            requested_names = list(text)
            if requested_names != self.class_names:
                self.set_classes(requested_names)
        if self.class_names is None:
            raise ValueError(
                "Class names not set and not provided in the request. Must set class names before inference or provide them via the argument `text`."
            )
        # A single image is passed, so only the first result is relevant.
        results = self.model.predict(
            image,
            conf=confidence,
            verbose=False,
        )[0]

        # Covers preprocessing, (re)configuring classes, and prediction.
        elapsed = perf_counter() - start

        predictions = []
        for box in results.boxes:
            x, y, w, h = box.xywh.tolist()[0]
            class_id = int(box.cls)
            predictions.append(
                # Dict unpacking is required: `class` is a Python keyword and
                # cannot be written as a keyword argument.
                ObjectDetectionPrediction(
                    **{
                        "x": x,
                        "y": y,
                        "width": w,
                        "height": h,
                        "confidence": float(box.conf),
                        "class": self.class_names[class_id],
                        "class_id": class_id,
                    }
                )
            )

        return ObjectDetectionInferenceResponse(
            predictions=predictions,
            image=InferenceResponseImage(width=img_dims[1], height=img_dims[0]),
            time=elapsed,
        )

    def set_classes(self, text: list):
        """Set the class names for the model.

        Text embeddings are looked up in the shared cache by a hash of the
        class names; on a miss they are computed by the model and cached.

        Args:
            text (list): The class names.
        """
        # Keep a private copy: if the caller later mutates and re-submits the
        # same list object, `infer` must still detect that the names changed.
        class_names = list(text)
        text_hash = get_string_list_hash(class_names)
        cached_embeddings = cache.get_numpy(text_hash)
        if cached_embeddings is not None:
            # Cache hit: install the embeddings directly and update the class
            # count on the model's last layer.
            self.model.model.txt_feats = cached_embeddings
            self.model.model.model[-1].nc = len(class_names)
        else:
            self.model.set_classes(class_names)
            # NOTE(review): `expire=300` is presumably a TTL in seconds --
            # confirm against the cache implementation.
            cache.set_numpy(text_hash, self.model.model.txt_feats, expire=300)
        self.class_names = class_names

    def get_infer_bucket_file_list(self) -> list:
        """Get the list of required files for inference.

        Returns:
            list: A list of required files for inference, e.g., ["model.pt"].
        """
        return ["yolo-world.pt"]