from collections import deque
from typing import List, Optional

from inference.core.entities.requests.inference import InferenceRequest
from inference.core.entities.responses.inference import InferenceResponse
from inference.core.managers.base import Model, ModelManager
from inference.core.managers.decorators.base import ModelManagerDecorator
from inference.core.managers.entities import ModelDescription


class WithFixedSizeCache(ModelManagerDecorator):
    def __init__(self, model_manager: ModelManager, max_size: int = 8):
        """Cache decorator. Models are evicted least-recently-used first, where
        "use" means the most recent utilization (`.infer` call). Internally, a
        [double-ended queue](https://docs.python.org/3/library/collections.html#collections.deque)
        keeps track of model utilization.

        Args:
            model_manager (ModelManager): Instance of a ModelManager.
            max_size (int, optional): Maximum number of models held at the same time. Defaults to 8.
        """
        super().__init__(model_manager)
        self.max_size = max_size
        self._key_queue = deque(self.model_manager.keys())
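        # The deque acts as an LRU index: the leftmost key is the least
        # recently used model (the next eviction candidate) and the rightmost
        # is the most recently used. Every touch moves a key to the right end.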

    def add_model(
        self, model_id: str, api_key: str, model_id_alias: Optional[str] = None
    ):
        """Adds a model to the manager, evicting the least recently used model if the cache is full.

        Args:
            model_id (str): The identifier of the model.
            api_key (str): API key used to load the model.
            model_id_alias (Optional[str]): Alias under which the model is tracked, if any.
        """
        queue_id = self._resolve_queue_id(
            model_id=model_id, model_id_alias=model_id_alias
        )
        if model_id in self:
            # Cache hit: move the key to the most-recently-used end.
            self._key_queue.remove(queue_id)
            self._key_queue.append(queue_id)
            return
        should_pop = len(self) == self.max_size
        if should_pop:
            # Cache full: evict the least recently used model.
            to_remove_model_id = self._key_queue.popleft()
            self.remove(to_remove_model_id)
        self._key_queue.append(queue_id)
        try:
            return super().add_model(model_id, api_key, model_id_alias=model_id_alias)
        except Exception as error:
            # Loading failed: roll back the queue entry added above. Note the
            # queue stores queue_id, which differs from model_id when an alias is used.
            self._key_queue.remove(queue_id)
            raise error
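
    # Hypothetical trace of the eviction logic above, for max_size=2
    # (model ids and key are placeholders, not real models):
    #   add_model("a", key)  -> queue: [a]
    #   add_model("b", key)  -> queue: [a, b]
    #   add_model("a", key)  -> cache hit, moved to MRU end: [b, a]
    #   add_model("c", key)  -> cache full, "b" evicted: [a, c]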

    def clear(self) -> None:
        """Removes all models from the manager."""
        for model_id in list(self.keys()):
            self.remove(model_id)

    def remove(self, model_id: str) -> Model:
        """Removes a model from the manager and drops its queue entry, if present."""
        try:
            self._key_queue.remove(model_id)
        except ValueError:
            # The model may be tracked under an alias, or not tracked at all.
            pass
        return super().remove(model_id)

    async def infer_from_request(
        self, model_id: str, request: InferenceRequest, **kwargs
    ) -> InferenceResponse:
        """Processes a complete inference request and updates the cache.

        Args:
            model_id (str): The identifier of the model.
            request (InferenceRequest): The request to process.

        Returns:
            InferenceResponse: The response from the inference.
        """
        # Mark the model as most recently used.
        self._key_queue.remove(model_id)
        self._key_queue.append(model_id)
        return await super().infer_from_request(model_id, request, **kwargs)

    def infer_only(self, model_id: str, request, img_in, img_dims, batch_size=None):
        """Performs only the inference part of a request and updates the cache.

        Args:
            model_id (str): The identifier of the model.
            request: The request to process.
            img_in: Input image.
            img_dims: Image dimensions.
            batch_size (int, optional): Batch size.

        Returns:
            Response from the inference-only operation.
        """
        self._key_queue.remove(model_id)
        self._key_queue.append(model_id)
        return super().infer_only(model_id, request, img_in, img_dims, batch_size)

    def preprocess(self, model_id: str, request):
        """Processes the preprocessing part of a request and updates the cache.

        Args:
            model_id (str): The identifier of the model.
            request (InferenceRequest): The request to preprocess.
        """
        self._key_queue.remove(model_id)
        self._key_queue.append(model_id)
        return super().preprocess(model_id, request)

    def describe_models(self) -> List[ModelDescription]:
        return self.model_manager.describe_models()

    def _resolve_queue_id(
        self, model_id: str, model_id_alias: Optional[str] = None
    ) -> str:
        return model_id if model_id_alias is None else model_id_alias
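

# Usage sketch (illustrative, not part of the module). Assumes a ModelManager
# instance is already available; its construction, the model id, and the API
# key below are placeholders:
#
#     manager = ...  # an inference.core.managers.base.ModelManager
#     cached_manager = WithFixedSizeCache(manager, max_size=4)
#     cached_manager.add_model("some-model/1", api_key="<YOUR_API_KEY>")
#     descriptions = cached_manager.describe_models()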