from time import perf_counter
from typing import Any

from ultralytics import YOLO

from inference.core.cache import cache
from inference.core.entities.requests.yolo_world import YOLOWorldInferenceRequest
from inference.core.entities.responses.inference import (
    InferenceResponseImage,
    ObjectDetectionInferenceResponse,
    ObjectDetectionPrediction,
)
from inference.core.models.defaults import DEFAULT_CONFIDENCE
from inference.core.models.roboflow import RoboflowCoreModel
from inference.core.utils.hash import get_string_list_hash
from inference.core.utils.image_utils import load_image_rgb


class YOLOWorld(RoboflowCoreModel):
    """YOLO-World model for zero-shot object detection.

    Attributes:
        model: The ultralytics ``YOLO`` model loaded from the cached
            ``yolo-world.pt`` weights file.
        class_names: The class names currently configured on the model, or
            ``None`` until ``set_classes`` has been called.
    """

    def __init__(self, *args, model_id="yolo_world/l", **kwargs):
        """Initializes the YOLO-World model.

        Args:
            *args: Variable length argument list, forwarded to the base class.
            model_id (str): Identifier of the model to load.
            **kwargs: Arbitrary keyword arguments, forwarded to the base class.
        """
        super().__init__(*args, model_id=model_id, **kwargs)

        self.model = YOLO(self.cache_file("yolo-world.pt"))
        # No classes are configured until set_classes() is called, either
        # directly or through the `text` argument of infer().
        self.class_names = None

    def preproc_image(self, image: Any):
        """Preprocesses an image.

        Args:
            image (InferenceRequestImage): The image to preprocess.

        Returns:
            np.array: The preprocessed image.
        """
        np_image = load_image_rgb(image)
        # Reverse the last (channel) axis; assuming load_image_rgb returns RGB
        # as its name suggests, this yields a BGR array.
        return np_image[:, :, ::-1]

    def infer_from_request(
        self,
        request: YOLOWorldInferenceRequest,
    ) -> ObjectDetectionInferenceResponse:
        """
        Perform inference based on the details provided in the request, and return the associated responses.

        Args:
            request (YOLOWorldInferenceRequest): The inference request; its
                fields are unpacked as keyword arguments to ``infer``.

        Returns:
            ObjectDetectionInferenceResponse: The inference response.
        """
        result = self.infer(**request.dict())
        return result

    def infer(
        self,
        image: Any = None,
        text: list = None,
        confidence: float = DEFAULT_CONFIDENCE,
        **kwargs,
    ) -> ObjectDetectionInferenceResponse:
        """
        Run inference on a provided image.

        Args:
            image (Any): The image to run detection on, in any form accepted
                by ``preproc_image``.
            text (list): The class names to detect. When given and different
                from the currently configured class names, the model is
                reconfigured through ``set_classes`` before predicting.
            confidence (float): Confidence threshold passed to the model.
            **kwargs: Ignored. Accepted so that a full request dict can be
                unpacked into this method.

        Returns:
            ObjectDetectionInferenceResponse: The detections, the dimensions
                of the preprocessed image and the elapsed time.

        Raises:
            ValueError: If no class names have been set and none are provided
                via ``text``.
        """
        t1 = perf_counter()
        image = self.preproc_image(image)
        img_dims = image.shape

        if text is not None and text != self.class_names:
            self.set_classes(text)
        if self.class_names is None:
            raise ValueError(
                "Class names not set and not provided in the request. Must set class names before inference or provide them via the argument `text`."
            )
        # predict() is given a single image here; take its first result.
        results = self.model.predict(
            image,
            conf=confidence,
            verbose=False,
        )[0]

        # The reported time covers preprocessing, class setup and prediction,
        # but not the construction of the response objects below.
        t2 = perf_counter() - t1

        predictions = []
        for box in results.boxes:
            # Each box carries a single row of (x, y, width, height).
            x, y, w, h = box.xywh.tolist()[0]
            class_id = int(box.cls)
            predictions.append(
                ObjectDetectionPrediction(
                    **{
                        "x": x,
                        "y": y,
                        "width": w,
                        "height": h,
                        "confidence": float(box.conf),
                        "class": self.class_names[class_id],
                        "class_id": class_id,
                    }
                )
            )

        responses = ObjectDetectionInferenceResponse(
            predictions=predictions,
            image=InferenceResponseImage(width=img_dims[1], height=img_dims[0]),
            time=t2,
        )
        return responses

    def set_classes(self, text: list):
        """Set the class names for the model.

        The text features for a given list of class names are cached under a
        hash of that list, so a repeated list can reuse them instead of
        calling the model's ``set_classes`` again.

        Args:
            text (list): The class names.
        """
        text_hash = get_string_list_hash(text)
        cached_embeddings = cache.get_numpy(text_hash)
        if cached_embeddings is not None:
            # Cache hit: install the stored text features directly and update
            # the class count on the last layer of the underlying model.
            self.model.model.txt_feats = cached_embeddings
            self.model.model.model[-1].nc = len(text)
        else:
            self.model.set_classes(text)
            # expire=300 is presumably a TTL in seconds -- confirm against the
            # cache implementation.
            cache.set_numpy(text_hash, self.model.model.txt_feats, expire=300)
        # Store a copy rather than the caller's list: infer() decides whether
        # to reconfigure the model by comparing `text` with `class_names`, and
        # an aliased list mutated by the caller would always compare equal to
        # itself, leaving the model configured for the old classes.
        self.class_names = list(text)

    def get_infer_bucket_file_list(self) -> list:
        """Get the list of required files for inference.

        Returns:
            list: A list of required files for inference, e.g., ["model.pt"].
        """
        return ["yolo-world.pt"]