Spaces:
Runtime error
Runtime error
import time | |
from typing import Dict, List, Optional, Tuple | |
import numpy as np | |
from fastapi.encoders import jsonable_encoder | |
from inference.core.cache import cache | |
from inference.core.cache.serializers import to_cachable_inference_item | |
from inference.core.devices.utils import GLOBAL_INFERENCE_SERVER_ID | |
from inference.core.entities.requests.inference import InferenceRequest | |
from inference.core.entities.responses.inference import InferenceResponse | |
from inference.core.env import ( | |
DISABLE_INFERENCE_CACHE, | |
METRICS_ENABLED, | |
METRICS_INTERVAL, | |
ROBOFLOW_SERVER_UUID, | |
) | |
from inference.core.exceptions import InferenceModelNotFound | |
from inference.core.logger import logger | |
from inference.core.managers.entities import ModelDescription | |
from inference.core.managers.pingback import PingbackInfo | |
from inference.core.models.base import Model, PreprocessReturnMetadata | |
from inference.core.registries.base import ModelRegistry | |
class ModelManager:
    """Keeps a dictionary of loaded ``Model`` objects and routes requests to them.

    Models live in ``self._models``, keyed by model id (or by the alias given
    to :meth:`add_model`). Most public methods first verify that the model is
    loaded via :meth:`check_for_model` and then delegate to the corresponding
    method on the model instance.
    """

    def __init__(self, model_registry: ModelRegistry, models: Optional[dict] = None):
        """Creates a model manager.

        Args:
            model_registry (ModelRegistry): Registry used by :meth:`add_model`
                to resolve an identifier to a model class.
            models (Optional[dict]): Optional pre-populated mapping of model
                identifier to model instance. The dict is used as-is (not
                copied); a new empty dict is created when omitted.
        """
        self.model_registry = model_registry
        self._models: Dict[str, Model] = models if models is not None else {}

    def init_pingback(self):
        """Initializes the pingback mechanism.

        Resets the error counter and records the server UUID. When
        ``METRICS_ENABLED`` is set, it also creates a ``PingbackInfo`` bound
        to this manager and starts it. ``self.pingback`` exists only in that
        case.
        """
        self.num_errors = 0  # in the device
        self.uuid = ROBOFLOW_SERVER_UUID
        if METRICS_ENABLED:
            self.pingback = PingbackInfo(self)
            self.pingback.start()

    def add_model(
        self, model_id: str, api_key: str, model_id_alias: Optional[str] = None
    ) -> None:
        """Loads a model and registers it with the manager.

        Does nothing if a model is already registered under the resolved key
        (the alias when given, otherwise ``model_id``).

        Args:
            model_id (str): The identifier of the model. Always passed to the
                model constructor.
            api_key (str): API key used for the registry lookup and passed to
                the model constructor.
            model_id_alias (Optional[str]): When given, used instead of
                ``model_id`` both for the registry lookup and as the key under
                which the model is stored in the manager.
        """
        logger.debug(
            f"ModelManager - Adding model with model_id={model_id}, model_id_alias={model_id_alias}"
        )
        # The model is stored under the alias when one is provided, so the
        # "already loaded" check must use that same key. Checking ``model_id``
        # would miss aliased models and reload them on every call.
        resolved_id = model_id if model_id_alias is None else model_id_alias
        if resolved_id in self._models:
            logger.debug(
                f"ModelManager - model with model_id={resolved_id} is already loaded."
            )
            return
        logger.debug("ModelManager - model initialisation...")
        model_class = self.model_registry.get_model(resolved_id, api_key)
        model = model_class(model_id=model_id, api_key=api_key)
        logger.debug("ModelManager - model successfully loaded.")
        self._models[resolved_id] = model

    def check_for_model(self, model_id: str) -> None:
        """Checks whether the model with the given ID is in the manager.

        Args:
            model_id (str): The identifier of the model.

        Raises:
            InferenceModelNotFound: If the model is not found in the manager.
        """
        if model_id not in self:
            raise InferenceModelNotFound(f"Model with id {model_id} not loaded.")

    def _record_model_usage(
        self, api_key: Optional[str], model_id: str, timestamp: float
    ) -> None:
        """Adds ``<server_id>:<api_key>:<model_id>`` to the ``models`` cache set, scored by ``timestamp``."""
        cache.zadd(
            "models",
            value=f"{GLOBAL_INFERENCE_SERVER_ID}:{api_key}:{model_id}",
            score=timestamp,
            expire=METRICS_INTERVAL * 2,
        )

    async def infer_from_request(
        self, model_id: str, request: InferenceRequest, **kwargs
    ) -> InferenceResponse:
        """Runs inference on the specified model with the given request.

        Unless ``DISABLE_INFERENCE_CACHE`` is set, the call is recorded in
        the cache, and in both cases model usage is added to the ``models``
        set:

        - On success, the request/response pair is stored under
          ``inference:<server_id>:<model_id>``.
        - On failure, the request (excluding ``image``, ``subject`` and
          ``prompt``) and the error message are stored under
          ``error:<server_id>:<model_id>``.

        Args:
            model_id (str): The identifier of the model.
            request (InferenceRequest): The request to process.
                NOTE: when caching is enabled and ``request.image.type`` is
                ``"numpy"``, ``request.image.value`` is replaced in place by
                its ``str`` form.
            **kwargs: Forwarded to :meth:`model_infer`.

        Returns:
            InferenceResponse: The response from the inference.

        Raises:
            Exception: Any exception raised inside the inference/caching
                section is re-raised after being recorded. This includes
                ``InferenceModelNotFound`` when the model is not loaded.
        """
        logger.debug(
            f"ModelManager - inference from request started for model_id={model_id}."
        )
        try:
            rtn_val = await self.model_infer(
                model_id=model_id, request=request, **kwargs
            )
            logger.debug(
                f"ModelManager - inference from request finished for model_id={model_id}."
            )
            finish_time = time.time()
            if not DISABLE_INFERENCE_CACHE:
                logger.debug(
                    f"ModelManager - caching inference request started for model_id={model_id}"
                )
                self._record_model_usage(request.api_key, model_id, finish_time)
                if (
                    hasattr(request, "image")
                    and hasattr(request.image, "type")
                    and request.image.type == "numpy"
                ):
                    # NOTE(review): presumably stringified because raw numpy
                    # arrays cannot be serialized for the cache. This mutates
                    # the caller's request object.
                    request.image.value = str(request.image.value)
                cache.zadd(
                    f"inference:{GLOBAL_INFERENCE_SERVER_ID}:{model_id}",
                    value=to_cachable_inference_item(request, rtn_val),
                    score=finish_time,
                    expire=METRICS_INTERVAL * 2,
                )
                logger.debug(
                    f"ModelManager - caching inference request finished for model_id={model_id}"
                )
            return rtn_val
        except Exception as e:
            finish_time = time.time()
            if not DISABLE_INFERENCE_CACHE:
                self._record_model_usage(request.api_key, model_id, finish_time)
                cache.zadd(
                    f"error:{GLOBAL_INFERENCE_SERVER_ID}:{model_id}",
                    value={
                        "request": jsonable_encoder(
                            request.dict(exclude={"image", "subject", "prompt"})
                        ),
                        "error": str(e),
                    },
                    score=finish_time,
                    expire=METRICS_INTERVAL * 2,
                )
            raise

    async def model_infer(self, model_id: str, request: InferenceRequest, **kwargs):
        """Runs the loaded model's ``infer_from_request`` on ``request``.

        Args:
            model_id (str): The identifier of the model.
            request (InferenceRequest): The request to process.
            **kwargs: Accepted but currently not forwarded to the model.

        Returns:
            Whatever the model's ``infer_from_request`` returns.

        Raises:
            InferenceModelNotFound: If the model is not loaded.
        """
        self.check_for_model(model_id)
        # NOTE(review): the model call is not awaited. If it does blocking
        # work, it blocks the event loop for the duration of the call.
        return self._models[model_id].infer_from_request(request)

    def make_response(
        self, model_id: str, predictions: List[List[float]], *args, **kwargs
    ) -> InferenceResponse:
        """Creates a response object from the model's predictions.

        Args:
            model_id (str): The identifier of the model.
            predictions (List[List[float]]): The model's predictions.
            *args: Forwarded to the model's ``make_response``.
            **kwargs: Forwarded to the model's ``make_response``.

        Returns:
            InferenceResponse: The created response object.

        Raises:
            InferenceModelNotFound: If the model is not loaded.
        """
        self.check_for_model(model_id)
        return self._models[model_id].make_response(predictions, *args, **kwargs)

    def postprocess(
        self,
        model_id: str,
        predictions: Tuple[np.ndarray, ...],
        preprocess_return_metadata: PreprocessReturnMetadata,
        *args,
        **kwargs,
    ) -> List[List[float]]:
        """Processes the model's predictions after inference.

        Args:
            model_id (str): The identifier of the model.
            predictions (Tuple[np.ndarray, ...]): The raw model outputs.
            preprocess_return_metadata (PreprocessReturnMetadata): Metadata of
                the kind returned by :meth:`preprocess`.
            *args: Forwarded to the model's ``postprocess``.
            **kwargs: Forwarded to the model's ``postprocess``.

        Returns:
            List[List[float]]: The post-processed predictions.

        Raises:
            InferenceModelNotFound: If the model is not loaded.
        """
        self.check_for_model(model_id)
        return self._models[model_id].postprocess(
            predictions, preprocess_return_metadata, *args, **kwargs
        )

    def predict(self, model_id: str, *args, **kwargs) -> Tuple[np.ndarray, ...]:
        """Runs prediction on the specified model and updates its metrics.

        Increments the model's ``metrics["num_inferences"]``. It also adds the
        elapsed time of the call (measured with ``time.perf_counter``) to
        ``metrics["avg_inference_time"]``. Despite the key name, that value is
        accumulated as a running total here.

        Args:
            model_id (str): The identifier of the model.
            *args: Forwarded to the model's ``predict``.
            **kwargs: Forwarded to the model's ``predict``.

        Returns:
            Tuple[np.ndarray, ...]: The predictions from the model.

        Raises:
            InferenceModelNotFound: If the model is not loaded.
        """
        self.check_for_model(model_id)
        self._models[model_id].metrics["num_inferences"] += 1
        tic = time.perf_counter()
        res = self._models[model_id].predict(*args, **kwargs)
        toc = time.perf_counter()
        self._models[model_id].metrics["avg_inference_time"] += toc - tic
        return res

    def preprocess(
        self, model_id: str, request: InferenceRequest
    ) -> Tuple[np.ndarray, PreprocessReturnMetadata]:
        """Preprocesses the request before inference.

        The request's fields (``request.dict()``) are passed to the model's
        ``preprocess`` as keyword arguments.

        Args:
            model_id (str): The identifier of the model.
            request (InferenceRequest): The request to preprocess.

        Returns:
            Tuple[np.ndarray, PreprocessReturnMetadata]: The preprocessed data
            and its metadata.

        Raises:
            InferenceModelNotFound: If the model is not loaded.
        """
        self.check_for_model(model_id)
        return self._models[model_id].preprocess(**request.dict())

    def get_class_names(self, model_id):
        """Retrieves the class names for a given model.

        Args:
            model_id (str): The identifier of the model.

        Returns:
            List[str]: The class names of the model.

        Raises:
            InferenceModelNotFound: If the model is not loaded.
        """
        self.check_for_model(model_id)
        return self._models[model_id].class_names

    def get_task_type(self, model_id: str, api_key: Optional[str] = None) -> str:
        """Retrieves the task type for a given model.

        Args:
            model_id (str): The identifier of the model.
            api_key (Optional[str]): Accepted but not used by this method.

        Returns:
            str: The task type of the model.

        Raises:
            InferenceModelNotFound: If the model is not loaded.
        """
        self.check_for_model(model_id)
        return self._models[model_id].task_type

    def remove(self, model_id: str) -> None:
        """Removes a model from the manager after calling its ``clear_cache``.

        Logs a warning instead of raising if the model is not loaded.

        Args:
            model_id (str): The identifier of the model.
        """
        try:
            self.check_for_model(model_id)
            self._models[model_id].clear_cache()
            del self._models[model_id]
        except InferenceModelNotFound:
            logger.warning(
                f"Attempted to remove model with id {model_id}, but it is not loaded. Skipping..."
            )

    def clear(self) -> None:
        """Removes all models from the manager."""
        # Snapshot the keys: ``remove`` mutates ``self._models`` while we iterate.
        for model_id in list(self.keys()):
            self.remove(model_id)

    def __contains__(self, model_id: str) -> bool:
        """Checks if the model is contained in the manager.

        Args:
            model_id (str): The identifier of the model.

        Returns:
            bool: Whether the model is in the manager.
        """
        return model_id in self._models

    def __getitem__(self, key: str) -> Model:
        """Retrieve a model from the manager by key.

        Args:
            key (str): The identifier of the model.

        Returns:
            Model: The model corresponding to the key.

        Raises:
            InferenceModelNotFound: If the model is not loaded.
        """
        self.check_for_model(model_id=key)
        return self._models[key]

    def __len__(self) -> int:
        """Retrieve the number of models in the manager.

        Returns:
            int: The number of models in the manager.
        """
        return len(self._models)

    def keys(self):
        """Retrieve the keys (model identifiers) from the manager.

        Returns:
            KeysView[str]: A live view of the model identifiers (not a list
            copy).
        """
        return self._models.keys()

    def models(self) -> Dict[str, Model]:
        """Retrieve the models dictionary from the manager.

        Returns:
            Dict[str, Model]: The internal mapping of model identifier to
            model instance (not a copy).
        """
        return self._models

    def describe_models(self) -> List[ModelDescription]:
        """Describes every loaded model.

        Returns:
            List[ModelDescription]: One entry per loaded model.
            ``batch_size``, ``input_width`` and ``input_height`` are ``None``
            when the model lacks the ``batch_size``, ``img_size_w`` or
            ``img_size_h`` attribute, respectively.
        """
        return [
            ModelDescription(
                model_id=model_id,
                task_type=model.task_type,
                batch_size=getattr(model, "batch_size", None),
                input_width=getattr(model, "img_size_w", None),
                input_height=getattr(model, "img_size_h", None),
            )
            for model_id, model in self._models.items()
        ]