| from queue import Queue |
| from typing import TYPE_CHECKING, Optional, TypeVar |
|
|
| from invokeai.app.services.object_serializer.object_serializer_base import ObjectSerializerBase |
|
|
# `Invoker` is only referenced in string annotations below; importing it solely
# for type checkers avoids a runtime import cycle with the invoker module.
if TYPE_CHECKING:
    from invokeai.app.services.invoker import Invoker

# Type of the objects handled by the serializer / cache.
T = TypeVar("T")
|
|
|
|
class ObjectSerializerForwardCache(ObjectSerializerBase[T]):
    """
    Provides a least-recently-used (LRU) in-memory cache in front of another
    `ObjectSerializerBase`.

    Saving and deleting always write through to the underlying storage; loads are
    served from memory when the object is cached. At most `max_cache_size` objects
    are held in memory; when that limit is exceeded, the entry that was least
    recently loaded or saved is evicted.
    """

    def __init__(self, underlying_storage: ObjectSerializerBase[T], max_cache_size: int = 20):
        """
        Args:
            underlying_storage: The serializer that actually persists objects.
            max_cache_size: Maximum number of objects kept in memory. A value of 0
                (or less) disables in-memory caching.
        """
        super().__init__()
        self._underlying_storage = underlying_storage
        # Plain dicts preserve insertion order, so the first key is always the
        # least recently used entry and the last key the most recently used one.
        self._cache: dict[str, T] = {}
        self._max_cache_size = max_cache_size

    def start(self, invoker: "Invoker") -> None:
        """Remember the invoker and forward `start` to the underlying storage, if it has one."""
        self._invoker = invoker
        start_op = getattr(self._underlying_storage, "start", None)
        if callable(start_op):
            start_op(invoker)

    def stop(self, invoker: "Invoker") -> None:
        """Remember the invoker and forward `stop` to the underlying storage, if it has one."""
        self._invoker = invoker
        stop_op = getattr(self._underlying_storage, "stop", None)
        if callable(stop_op):
            stop_op(invoker)

    def load(self, name: str) -> T:
        """
        Return the object stored under `name`.

        Served from memory when cached (and marked most recently used); otherwise
        loaded from the underlying storage and added to the cache.
        """
        # Membership test rather than a None check, so a cached falsy/None value
        # is still treated as a hit.
        if name in self._cache:
            return self._mark_recent(name)

        obj = self._underlying_storage.load(name)
        self._set_cache(name, obj)
        return obj

    def save(self, obj: T) -> str:
        """Write `obj` through to the underlying storage, cache it, and return its name."""
        name = self._underlying_storage.save(obj)
        self._set_cache(name, obj)
        return name

    def delete(self, name: str) -> None:
        """Delete `name` from the underlying storage and drop it from the cache."""
        self._underlying_storage.delete(name)
        # Single source of truth: removing the key here leaves no stale bookkeeping
        # behind that a later eviction could trip over.
        self._cache.pop(name, None)
        self._on_deleted(name)

    def _mark_recent(self, name: str) -> T:
        """Move the cached entry `name` to the most-recently-used position and return it."""
        obj = self._cache.pop(name)
        self._cache[name] = obj
        return obj

    def _get_cache(self, name: str) -> Optional[T]:
        """Return the cached object for `name` (marking it most recently used), or None if absent."""
        return self._mark_recent(name) if name in self._cache else None

    def _set_cache(self, name: str, data: T) -> None:
        """Store `data` under `name` as the most recently used entry, evicting LRU entries if over budget."""
        # Re-insert so an existing key moves to the end and always holds the
        # latest data written through `save`.
        self._cache.pop(name, None)
        self._cache[name] = data
        # Evict from the front (least recently used). The emptiness guard keeps a
        # non-positive max size from calling next() on an empty dict.
        while self._cache and len(self._cache) > self._max_cache_size:
            del self._cache[next(iter(self._cache))]
|
|