import time
from typing import Dict, List, Optional, Tuple

import numpy as np
from fastapi.encoders import jsonable_encoder

from inference.core.cache import cache
from inference.core.cache.serializers import to_cachable_inference_item
from inference.core.devices.utils import GLOBAL_INFERENCE_SERVER_ID
from inference.core.entities.requests.inference import InferenceRequest
from inference.core.entities.responses.inference import InferenceResponse
from inference.core.env import (
    DISABLE_INFERENCE_CACHE,
    METRICS_ENABLED,
    METRICS_INTERVAL,
    ROBOFLOW_SERVER_UUID,
)
from inference.core.exceptions import InferenceModelNotFound
from inference.core.logger import logger
from inference.core.managers.entities import ModelDescription
from inference.core.managers.pingback import PingbackInfo
from inference.core.models.base import Model, PreprocessReturnMetadata
from inference.core.registries.base import ModelRegistry


class ModelManager:
    """Model managers keep track of a dictionary of Model objects and are
    responsible for passing requests to the right model using the infer method."""

    def __init__(self, model_registry: ModelRegistry, models: Optional[dict] = None):
        self.model_registry = model_registry
        self._models: Dict[str, Model] = models if models is not None else {}

    def init_pingback(self):
        """Initializes the pingback mechanism."""
        self.num_errors = 0  # in the device
        self.uuid = ROBOFLOW_SERVER_UUID
        if METRICS_ENABLED:
            self.pingback = PingbackInfo(self)
            self.pingback.start()

    def add_model(
        self, model_id: str, api_key: str, model_id_alias: Optional[str] = None
    ) -> None:
        """Adds a new model to the manager.

        Args:
            model_id (str): The identifier of the model.
            api_key (str): The API key used to resolve and load the model.
            model_id_alias (Optional[str]): Alternative identifier under which
                the model should be registered in the manager.
        """
        logger.debug(
            f"ModelManager - Adding model with model_id={model_id}, model_id_alias={model_id_alias}"
        )
        resolved_identifier = model_id if model_id_alias is None else model_id_alias
        if resolved_identifier in self._models:
            logger.debug(
                f"ModelManager - model with model_id={model_id} is already loaded."
            )
            return
        logger.debug("ModelManager - model initialisation...")
        model = self.model_registry.get_model(resolved_identifier, api_key)(
            model_id=model_id,
            api_key=api_key,
        )
        logger.debug("ModelManager - model successfully loaded.")
        self._models[resolved_identifier] = model

    def check_for_model(self, model_id: str) -> None:
        """Checks whether the model with the given ID is in the manager.

        Args:
            model_id (str): The identifier of the model.

        Raises:
            InferenceModelNotFound: If the model is not found in the manager.
        """
        if model_id not in self:
            raise InferenceModelNotFound(f"Model with id {model_id} not loaded.")

    async def infer_from_request(
        self, model_id: str, request: InferenceRequest, **kwargs
    ) -> InferenceResponse:
        """Runs inference on the specified model with the given request.

        Args:
            model_id (str): The identifier of the model.
            request (InferenceRequest): The request to process.

        Returns:
            InferenceResponse: The response from the inference.
        """
        logger.debug(
            f"ModelManager - inference from request started for model_id={model_id}."
        )
        try:
            rtn_val = await self.model_infer(
                model_id=model_id, request=request, **kwargs
            )
            logger.debug(
                f"ModelManager - inference from request finished for model_id={model_id}."
            )
            finish_time = time.time()
            if not DISABLE_INFERENCE_CACHE:
                logger.debug(
                    f"ModelManager - caching inference request started for model_id={model_id}"
                )
                cache.zadd(
                    "models",
                    value=f"{GLOBAL_INFERENCE_SERVER_ID}:{request.api_key}:{model_id}",
                    score=finish_time,
                    expire=METRICS_INTERVAL * 2,
                )
                if (
                    hasattr(request, "image")
                    and hasattr(request.image, "type")
                    and request.image.type == "numpy"
                ):
                    request.image.value = str(request.image.value)
                cache.zadd(
                    f"inference:{GLOBAL_INFERENCE_SERVER_ID}:{model_id}",
                    value=to_cachable_inference_item(request, rtn_val),
                    score=finish_time,
                    expire=METRICS_INTERVAL * 2,
                )
                logger.debug(
                    f"ModelManager - caching inference request finished for model_id={model_id}"
                )
            return rtn_val
        except Exception as e:
            finish_time = time.time()
            if not DISABLE_INFERENCE_CACHE:
                cache.zadd(
                    "models",
                    value=f"{GLOBAL_INFERENCE_SERVER_ID}:{request.api_key}:{model_id}",
                    score=finish_time,
                    expire=METRICS_INTERVAL * 2,
                )
                cache.zadd(
                    f"error:{GLOBAL_INFERENCE_SERVER_ID}:{model_id}",
                    value={
                        "request": jsonable_encoder(
                            request.dict(exclude={"image", "subject", "prompt"})
                        ),
                        "error": str(e),
                    },
                    score=finish_time,
                    expire=METRICS_INTERVAL * 2,
                )
            raise

    async def model_infer(self, model_id: str, request: InferenceRequest, **kwargs):
        """Dispatches the request to the model identified by model_id.

        Args:
            model_id (str): The identifier of the model.
            request (InferenceRequest): The request to process.

        Returns:
            InferenceResponse: The response produced by the model.
        """
        self.check_for_model(model_id)
        return self._models[model_id].infer_from_request(request)

    def make_response(
        self, model_id: str, predictions: List[List[float]], *args, **kwargs
    ) -> InferenceResponse:
        """Creates a response object from the model's predictions.

        Args:
            model_id (str): The identifier of the model.
            predictions (List[List[float]]): The model's predictions.

        Returns:
            InferenceResponse: The created response object.
        """
        self.check_for_model(model_id)
        return self._models[model_id].make_response(predictions, *args, **kwargs)

    def postprocess(
        self,
        model_id: str,
        predictions: Tuple[np.ndarray, ...],
        preprocess_return_metadata: PreprocessReturnMetadata,
        *args,
        **kwargs,
    ) -> List[List[float]]:
        """Processes the model's predictions after inference.

        Args:
            model_id (str): The identifier of the model.
            predictions (Tuple[np.ndarray, ...]): The model's raw predictions.
            preprocess_return_metadata (PreprocessReturnMetadata): Metadata
                produced by the preprocessing step.

        Returns:
            List[List[float]]: The post-processed predictions.
        """
        self.check_for_model(model_id)
        return self._models[model_id].postprocess(
            predictions, preprocess_return_metadata, *args, **kwargs
        )

    def predict(self, model_id: str, *args, **kwargs) -> Tuple[np.ndarray, ...]:
        """Runs prediction on the specified model.

        Args:
            model_id (str): The identifier of the model.

        Returns:
            Tuple[np.ndarray, ...]: The predictions from the model.
        """
        self.check_for_model(model_id)
        self._models[model_id].metrics["num_inferences"] += 1
        tic = time.perf_counter()
        res = self._models[model_id].predict(*args, **kwargs)
        toc = time.perf_counter()
        # Accumulates total elapsed inference time under the "avg_inference_time" key.
        self._models[model_id].metrics["avg_inference_time"] += toc - tic
        return res

    def preprocess(
        self, model_id: str, request: InferenceRequest
    ) -> Tuple[np.ndarray, PreprocessReturnMetadata]:
        """Preprocesses the request before inference.

        Args:
            model_id (str): The identifier of the model.
            request (InferenceRequest): The request to preprocess.

        Returns:
            Tuple[np.ndarray, PreprocessReturnMetadata]: The preprocessed data
                and the metadata required for postprocessing.
        """
        self.check_for_model(model_id)
        return self._models[model_id].preprocess(**request.dict())

    def get_class_names(self, model_id: str):
        """Retrieves the class names for a given model.

        Args:
            model_id (str): The identifier of the model.

        Returns:
            List[str]: The class names of the model.
        """
        self.check_for_model(model_id)
        return self._models[model_id].class_names

    def get_task_type(self, model_id: str, api_key: Optional[str] = None) -> str:
        """Retrieves the task type for a given model.

        Args:
            model_id (str): The identifier of the model.
            api_key (Optional[str]): Unused in this implementation; accepted
                for interface compatibility.

        Returns:
            str: The task type of the model.
        """
        self.check_for_model(model_id)
        return self._models[model_id].task_type

    def remove(self, model_id: str) -> None:
        """Removes a model from the manager.

        Args:
            model_id (str): The identifier of the model.
        """
        try:
            self.check_for_model(model_id)
            self._models[model_id].clear_cache()
            del self._models[model_id]
        except InferenceModelNotFound:
            logger.warning(
                f"Attempted to remove model with id {model_id}, but it is not loaded. Skipping..."
            )

    def clear(self) -> None:
        """Removes all models from the manager."""
        for model_id in list(self.keys()):
            self.remove(model_id)

    def __contains__(self, model_id: str) -> bool:
        """Checks if the model is contained in the manager.

        Args:
            model_id (str): The identifier of the model.

        Returns:
            bool: Whether the model is in the manager.
        """
        return model_id in self._models

    def __getitem__(self, key: str) -> Model:
        """Retrieves a model from the manager by key.

        Args:
            key (str): The identifier of the model.

        Returns:
            Model: The model corresponding to the key.

        Raises:
            InferenceModelNotFound: If the model is not loaded.
        """
        self.check_for_model(model_id=key)
        return self._models[key]

    def __len__(self) -> int:
        """Retrieves the number of models in the manager.

        Returns:
            int: The number of models in the manager.
        """
        return len(self._models)

    def keys(self):
        """Retrieves the keys (model identifiers) from the manager.

        Returns:
            The identifiers of the models currently held by the manager.
        """
        return self._models.keys()

    def models(self) -> Dict[str, Model]:
        """Retrieves the models dictionary from the manager.

        Returns:
            Dict[str, Model]: The mapping of model identifiers to loaded models.
        """
        return self._models

    def describe_models(self) -> List[ModelDescription]:
        """Builds a lightweight description of every loaded model.

        Returns:
            List[ModelDescription]: One entry per loaded model, including its
                task type and input geometry when available.
        """
        return [
            ModelDescription(
                model_id=model_id,
                task_type=model.task_type,
                batch_size=getattr(model, "batch_size", None),
                input_width=getattr(model, "img_size_w", None),
                input_height=getattr(model, "img_size_h", None),
            )
            for model_id, model in self._models.items()
        ]