# OMG/inference/core/managers/decorators/fixed_size_cache.py
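"""Fixed-size, least-recently-used cache decorator for a ModelManager.

The decorator below keeps at most ``max_size`` models loaded at once and evicts
the least recently used one (tracked with a deque of model ids) whenever adding
a new model would exceed that limit.
"""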
from collections import deque
from typing import List, Optional
from inference.core.entities.requests.inference import InferenceRequest
from inference.core.entities.responses.inference import InferenceResponse
from inference.core.managers.base import Model, ModelManager
from inference.core.managers.decorators.base import ModelManagerDecorator
from inference.core.managers.entities import ModelDescription


class WithFixedSizeCache(ModelManagerDecorator):
def __init__(self, model_manager: ModelManager, max_size: int = 8):
"""Cache decorator, models will be evicted based on the last utilization (`.infer` call). Internally, a [double-ended queue](https://docs.python.org/3/library/collections.html#collections.deque) is used to keep track of model utilization.
Args:
model_manager (ModelManager): Instance of a ModelManager.
max_size (int, optional): Max number of models at the same time. Defaults to 8.
"""
super().__init__(model_manager)
self.max_size = max_size
self._key_queue = deque(self.model_manager.keys())

    def add_model(
        self, model_id: str, api_key: str, model_id_alias: Optional[str] = None
    ):
        """Adds a model to the manager, evicting the least recently used model if the cache is full.
        Args:
            model_id (str): The identifier of the model.
            api_key (str): API key used to load the model.
            model_id_alias (str, optional): Optional alias under which the model is tracked in the cache. Defaults to None.
        """
queue_id = self._resolve_queue_id(
model_id=model_id, model_id_alias=model_id_alias
)
        if model_id in self:
            # Model is already loaded; just refresh its position in the LRU queue.
            self._key_queue.remove(queue_id)
            self._key_queue.append(queue_id)
            return
        # At capacity: evict the least recently used model before adding the new one.
        should_pop = len(self) == self.max_size
        if should_pop:
            to_remove_model_id = self._key_queue.popleft()
            self.remove(to_remove_model_id)
        self._key_queue.append(queue_id)
        try:
            return super().add_model(model_id, api_key, model_id_alias=model_id_alias)
        except Exception as error:
            # Loading failed: roll back the queue entry added above. The queue holds
            # queue_id (the alias when one is given), not necessarily model_id.
            self._key_queue.remove(queue_id)
            raise error

    def clear(self) -> None:
"""Removes all models from the manager."""
for model_id in list(self.keys()):
self.remove(model_id)

    def remove(self, model_id: str) -> Model:
        """Removes a model from the manager and drops it from the LRU queue if present."""
        try:
            self._key_queue.remove(model_id)
        except ValueError:
            pass
        return super().remove(model_id)

    async def infer_from_request(
self, model_id: str, request: InferenceRequest, **kwargs
) -> InferenceResponse:
"""Processes a complete inference request and updates the cache.
Args:
model_id (str): The identifier of the model.
request (InferenceRequest): The request to process.
Returns:
InferenceResponse: The response from the inference.
"""
        # Mark this model as most recently used before running inference.
        self._key_queue.remove(model_id)
        self._key_queue.append(model_id)
return await super().infer_from_request(model_id, request, **kwargs)

    def infer_only(self, model_id: str, request, img_in, img_dims, batch_size=None):
"""Performs only the inference part of a request and updates the cache.
Args:
model_id (str): The identifier of the model.
request: The request to process.
img_in: Input image.
img_dims: Image dimensions.
batch_size (int, optional): Batch size.
Returns:
Response from the inference-only operation.
"""
self._key_queue.remove(model_id)
self._key_queue.append(model_id)
return super().infer_only(model_id, request, img_in, img_dims, batch_size)

    def preprocess(self, model_id: str, request):
"""Processes the preprocessing part of a request and updates the cache.
Args:
model_id (str): The identifier of the model.
request (InferenceRequest): The request to preprocess.
"""
self._key_queue.remove(model_id)
self._key_queue.append(model_id)
return super().preprocess(model_id, request)

    def describe_models(self) -> List[ModelDescription]:
        """Returns descriptions of all models held by the wrapped manager."""
        return self.model_manager.describe_models()

    def _resolve_queue_id(
self, model_id: str, model_id_alias: Optional[str] = None
) -> str:
return model_id if model_id_alias is None else model_id_alias
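

# ---------------------------------------------------------------------------
# Minimal sketch of the eviction behavior implemented above. This is an
# illustration only: it reproduces the deque bookkeeping (most recently used
# ids at the right, eviction from the left) with plain strings instead of
# loaded models, so it runs without a ModelManager or model registry.
# Wrapping a real manager would look roughly like
# `WithFixedSizeCache(ModelManager(model_registry=...), max_size=8)`, where
# `model_registry` is whatever registry your deployment uses (assumption, not
# shown in this file).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    max_size = 2
    key_queue = deque()  # mirrors WithFixedSizeCache._key_queue

    def _touch(model_id: str) -> None:
        # Same pattern as infer_from_request/add_model: move the id to the MRU end.
        if model_id in key_queue:
            key_queue.remove(model_id)
        key_queue.append(model_id)

    def _add(model_id: str) -> None:
        if model_id in key_queue:
            _touch(model_id)
            return
        if len(key_queue) == max_size:
            evicted = key_queue.popleft()  # least recently used id
            print(f"evicting {evicted}")
        key_queue.append(model_id)

    _add("model-a")
    _add("model-b")
    _touch("model-a")       # model-a becomes most recently used
    _add("model-c")         # cache full -> evicts model-b, not model-a
    print(list(key_queue))  # ['model-a', 'model-c']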