import time
from typing import Dict, List, Optional, Tuple
import numpy as np
from fastapi.encoders import jsonable_encoder
from inference.core.cache import cache
from inference.core.cache.serializers import to_cachable_inference_item
from inference.core.devices.utils import GLOBAL_INFERENCE_SERVER_ID
from inference.core.entities.requests.inference import InferenceRequest
from inference.core.entities.responses.inference import InferenceResponse
from inference.core.env import (
DISABLE_INFERENCE_CACHE,
METRICS_ENABLED,
METRICS_INTERVAL,
ROBOFLOW_SERVER_UUID,
)
from inference.core.exceptions import InferenceModelNotFound
from inference.core.logger import logger
from inference.core.managers.entities import ModelDescription
from inference.core.managers.pingback import PingbackInfo
from inference.core.models.base import Model, PreprocessReturnMetadata
from inference.core.registries.base import ModelRegistry
class ModelManager:
"""Model managers keep track of a dictionary of Model objects and is responsible for passing requests to the right model using the infer method."""
def __init__(self, model_registry: ModelRegistry, models: Optional[dict] = None):
self.model_registry = model_registry
self._models: Dict[str, Model] = models if models is not None else {}
def init_pingback(self):
"""Initializes pingback mechanism."""
        self.num_errors = 0  # errors observed on this device
self.uuid = ROBOFLOW_SERVER_UUID
if METRICS_ENABLED:
self.pingback = PingbackInfo(self)
self.pingback.start()
def add_model(
self, model_id: str, api_key: str, model_id_alias: Optional[str] = None
) -> None:
"""Adds a new model to the manager.
Args:
model_id (str): The identifier of the model.
model (Model): The model instance.
"""
logger.debug(
f"ModelManager - Adding model with model_id={model_id}, model_id_alias={model_id_alias}"
)
if model_id in self._models:
logger.debug(
f"ModelManager - model with model_id={model_id} is already loaded."
)
return
logger.debug("ModelManager - model initialisation...")
model = self.model_registry.get_model(
model_id if model_id_alias is None else model_id_alias, api_key
)(
model_id=model_id,
api_key=api_key,
)
logger.debug("ModelManager - model successfully loaded.")
self._models[model_id if model_id_alias is None else model_id_alias] = model
def check_for_model(self, model_id: str) -> None:
"""Checks whether the model with the given ID is in the manager.
Args:
model_id (str): The identifier of the model.
Raises:
InferenceModelNotFound: If the model is not found in the manager.
"""
if model_id not in self:
raise InferenceModelNotFound(f"Model with id {model_id} not loaded.")
async def infer_from_request(
self, model_id: str, request: InferenceRequest, **kwargs
) -> InferenceResponse:
"""Runs inference on the specified model with the given request.
Args:
model_id (str): The identifier of the model.
request (InferenceRequest): The request to process.
Returns:
InferenceResponse: The response from the inference.
"""
logger.debug(
f"ModelManager - inference from request started for model_id={model_id}."
)
try:
rtn_val = await self.model_infer(
model_id=model_id, request=request, **kwargs
)
logger.debug(
f"ModelManager - inference from request finished for model_id={model_id}."
)
finish_time = time.time()
if not DISABLE_INFERENCE_CACHE:
logger.debug(
f"ModelManager - caching inference request started for model_id={model_id}"
)
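                # Record recent model usage as a sorted-set entry scored by completion time;
                # entries expire after two metrics intervals so reporting only sees fresh activity.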
cache.zadd(
f"models",
value=f"{GLOBAL_INFERENCE_SERVER_ID}:{request.api_key}:{model_id}",
score=finish_time,
expire=METRICS_INTERVAL * 2,
)
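                # Numpy image payloads are not JSON-serializable, so store a string
                # representation before caching the request/response pair.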
if (
hasattr(request, "image")
and hasattr(request.image, "type")
and request.image.type == "numpy"
):
request.image.value = str(request.image.value)
cache.zadd(
f"inference:{GLOBAL_INFERENCE_SERVER_ID}:{model_id}",
value=to_cachable_inference_item(request, rtn_val),
score=finish_time,
expire=METRICS_INTERVAL * 2,
)
logger.debug(
f"ModelManager - caching inference request finished for model_id={model_id}"
)
return rtn_val
except Exception as e:
finish_time = time.time()
if not DISABLE_INFERENCE_CACHE:
cache.zadd(
f"models",
value=f"{GLOBAL_INFERENCE_SERVER_ID}:{request.api_key}:{model_id}",
score=finish_time,
expire=METRICS_INTERVAL * 2,
)
cache.zadd(
f"error:{GLOBAL_INFERENCE_SERVER_ID}:{model_id}",
value={
"request": jsonable_encoder(
request.dict(exclude={"image", "subject", "prompt"})
),
"error": str(e),
},
score=finish_time,
expire=METRICS_INTERVAL * 2,
)
raise
async def model_infer(self, model_id: str, request: InferenceRequest, **kwargs):
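        """Resolves the model by its identifier and delegates the request to that model's infer_from_request method."""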
self.check_for_model(model_id)
return self._models[model_id].infer_from_request(request)
def make_response(
self, model_id: str, predictions: List[List[float]], *args, **kwargs
) -> InferenceResponse:
"""Creates a response object from the model's predictions.
Args:
model_id (str): The identifier of the model.
predictions (List[List[float]]): The model's predictions.
Returns:
InferenceResponse: The created response object.
"""
self.check_for_model(model_id)
return self._models[model_id].make_response(predictions, *args, **kwargs)
def postprocess(
self,
model_id: str,
predictions: Tuple[np.ndarray, ...],
preprocess_return_metadata: PreprocessReturnMetadata,
*args,
**kwargs,
) -> List[List[float]]:
"""Processes the model's predictions after inference.
Args:
model_id (str): The identifier of the model.
predictions (np.ndarray): The model's predictions.
Returns:
List[List[float]]: The post-processed predictions.
"""
self.check_for_model(model_id)
return self._models[model_id].postprocess(
predictions, preprocess_return_metadata, *args, **kwargs
)
def predict(self, model_id: str, *args, **kwargs) -> Tuple[np.ndarray, ...]:
"""Runs prediction on the specified model.
Args:
model_id (str): The identifier of the model.
Returns:
np.ndarray: The predictions from the model.
"""
self.check_for_model(model_id)
self._models[model_id].metrics["num_inferences"] += 1
tic = time.perf_counter()
res = self._models[model_id].predict(*args, **kwargs)
toc = time.perf_counter()
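        # Note: "avg_inference_time" accumulates the raw elapsed time per call; it is
        # expected to be normalized by "num_inferences" when metrics are reported.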
self._models[model_id].metrics["avg_inference_time"] += toc - tic
return res
def preprocess(
self, model_id: str, request: InferenceRequest
) -> Tuple[np.ndarray, PreprocessReturnMetadata]:
"""Preprocesses the request before inference.
Args:
model_id (str): The identifier of the model.
request (InferenceRequest): The request to preprocess.
Returns:
Tuple[np.ndarray, List[Tuple[int, int]]]: The preprocessed data.
"""
self.check_for_model(model_id)
return self._models[model_id].preprocess(**request.dict())
    def get_class_names(self, model_id: str) -> List[str]:
"""Retrieves the class names for a given model.
Args:
model_id (str): The identifier of the model.
Returns:
List[str]: The class names of the model.
"""
self.check_for_model(model_id)
return self._models[model_id].class_names
    def get_task_type(self, model_id: str, api_key: Optional[str] = None) -> str:
        """Retrieves the task type for a given model.
        Args:
            model_id (str): The identifier of the model.
            api_key (Optional[str]): Optional API key; not used by this base implementation.
        Returns:
            str: The task type of the model.
        """
self.check_for_model(model_id)
return self._models[model_id].task_type
def remove(self, model_id: str) -> None:
"""Removes a model from the manager.
Args:
model_id (str): The identifier of the model.
"""
try:
self.check_for_model(model_id)
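            # Release any model-specific cached resources before dropping the reference.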
self._models[model_id].clear_cache()
del self._models[model_id]
except InferenceModelNotFound:
logger.warning(
f"Attempted to remove model with id {model_id}, but it is not loaded. Skipping..."
)
def clear(self) -> None:
"""Removes all models from the manager."""
for model_id in list(self.keys()):
self.remove(model_id)
def __contains__(self, model_id: str) -> bool:
"""Checks if the model is contained in the manager.
Args:
model_id (str): The identifier of the model.
Returns:
bool: Whether the model is in the manager.
"""
return model_id in self._models
def __getitem__(self, key: str) -> Model:
"""Retrieve a model from the manager by key.
Args:
key (str): The identifier of the model.
Returns:
Model: The model corresponding to the key.
"""
self.check_for_model(model_id=key)
return self._models[key]
def __len__(self) -> int:
"""Retrieve the number of models in the manager.
Returns:
int: The number of models in the manager.
"""
return len(self._models)
def keys(self):
"""Retrieve the keys (model identifiers) from the manager.
Returns:
List[str]: The keys of the models in the manager.
"""
return self._models.keys()
def models(self) -> Dict[str, Model]:
"""Retrieve the models dictionary from the manager.
Returns:
            Dict[str, Model]: A mapping of model identifiers to the loaded Model instances.
"""
return self._models
def describe_models(self) -> List[ModelDescription]:
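        """Describes all models currently loaded in the manager.
        Returns:
            List[ModelDescription]: One description per loaded model, including task type, batch size, and input dimensions when available.
        """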
return [
ModelDescription(
model_id=model_id,
task_type=model.task_type,
batch_size=getattr(model, "batch_size", None),
input_width=getattr(model, "img_size_w", None),
input_height=getattr(model, "img_size_h", None),
)
for model_id, model in self._models.items()
]
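

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the library API). The concrete
    # registry, model-type mapping, and request class imported below are assumptions about
    # the surrounding inference package; the model id and API key are placeholders. Swap in
    # whatever ModelRegistry implementation and request type your deployment actually uses.
    import asyncio

    from inference.core.entities.requests.inference import (
        ObjectDetectionInferenceRequest,  # assumed request type for detection models
    )
    from inference.core.registries.roboflow import (
        RoboflowModelRegistry,  # assumed concrete ModelRegistry implementation
    )
    from inference.models.utils import ROBOFLOW_MODEL_TYPES  # assumed model-type mapping

    manager = ModelManager(model_registry=RoboflowModelRegistry(ROBOFLOW_MODEL_TYPES))
    # Loads (or reuses) the model under its id; placeholder values shown here.
    manager.add_model(model_id="my-project/1", api_key="MY_API_KEY")

    request = ObjectDetectionInferenceRequest(
        api_key="MY_API_KEY",
        model_id="my-project/1",
        image={"type": "url", "value": "https://example.com/image.jpg"},
    )
    # infer_from_request is a coroutine, so drive it with asyncio for a one-off call.
    response = asyncio.run(manager.infer_from_request("my-project/1", request))
    print(manager.describe_models())
    print(response)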