| | """ |
| | Manage a RAM cache of diffusion/transformer models for fast switching. |
| | They are moved between GPU VRAM and CPU RAM as necessary. If the cache |
| | grows larger than a preset maximum, then the least recently used |
| | model will be cleared and (re)loaded from disk when next needed. |
| | """ |

from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from logging import Logger
from typing import Dict, Generic, Optional, TypeVar

import torch

from invokeai.backend.model_manager.config import AnyModel, SubModelType


class ModelLockerBase(ABC):
    """Base class for the model locker used by the loader."""

    @abstractmethod
    def lock(self) -> AnyModel:
        """Lock the contained model and move it into VRAM."""
        pass

    @abstractmethod
    def unlock(self) -> None:
        """Unlock the contained model, and remove it from VRAM."""
        pass

    @abstractmethod
    def get_state_dict(self) -> Optional[Dict[str, torch.Tensor]]:
        """Return the state dict (if any) for the cached model."""
        pass

    @property
    @abstractmethod
    def model(self) -> AnyModel:
        """Return the model."""
        pass


T = TypeVar("T")


@dataclass
class CacheRecord(Generic[T]):
    """
    Elements of the cache:

    key: Unique key for each model, same as used in the models database.
    model: Model in memory.
    state_dict: A read-only copy of the model's state dict in RAM. It will be
        used as a template for creating a copy in the VRAM.
    size: Size of the model.
    loaded: True if the model's state dict is currently in VRAM.

    Before a model is executed, the state_dict template is copied into VRAM,
    and then injected into the model. When the model is finished, the VRAM
    copy of the state dict is deleted, and the RAM version is reinjected
    into the model.

    The state_dict should be treated as a read-only attribute. Do not attempt
    to patch or otherwise modify it. Instead, patch the copy of the state_dict
    after it is loaded into the execution device (e.g. CUDA) using the `LoadedModel`
    context manager call `model_on_device()`.
    """

    key: str
    model: T
    device: torch.device
    state_dict: Optional[Dict[str, torch.Tensor]]
    size: int
    loaded: bool = False
    _locks: int = 0

    def lock(self) -> None:
        """Lock this record."""
        self._locks += 1

    def unlock(self) -> None:
        """Unlock this record."""
        self._locks -= 1
        assert self._locks >= 0

    @property
    def locked(self) -> bool:
        """Return true if record is locked."""
        return self._locks > 0
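

# Locking of a CacheRecord is reference-counted: the record reports locked=True
# until every lock() has been matched by an unlock(), at which point the cache may
# offload it. Illustrative sketch ("m" and the field values are placeholders):
#
#     record = CacheRecord(key="abc", model=m, device=torch.device("cpu"),
#                          state_dict=None, size=0)
#     record.lock()
#     record.lock()
#     record.unlock()
#     assert record.locked        # one lock() is still outstanding
#     record.unlock()
#     assert not record.locked    # now eligible for offloading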


@dataclass
class CacheStats(object):
    """Collect statistics on cache performance."""

    hits: int = 0
    misses: int = 0
    high_watermark: int = 0
    in_cache: int = 0
    cleared: int = 0
    cache_size: int = 0
    loaded_model_sizes: Dict[str, int] = field(default_factory=dict)
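
# Illustrative sketch of how a concrete cache might maintain these counters inside
# its get() method (the _cached_models attribute is hypothetical, not part of this
# interface):
#
#     if self.stats:
#         if key in self._cached_models:
#             self.stats.hits += 1
#         else:
#             self.stats.misses += 1
#         self.stats.in_cache = len(self._cached_models)
#         self.stats.high_watermark = max(self.stats.high_watermark, self.cache_size())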


class ModelCacheBase(ABC, Generic[T]):
    """Virtual base class for RAM model cache."""

    @property
    @abstractmethod
    def storage_device(self) -> torch.device:
        """Return the storage device (e.g. "cpu" for RAM)."""
        pass

    @property
    @abstractmethod
    def execution_device(self) -> torch.device:
        """Return the execution device (e.g. "cuda" for VRAM)."""
        pass

    @property
    @abstractmethod
    def lazy_offloading(self) -> bool:
        """Return true if the cache is configured to lazily offload models in VRAM."""
        pass

    @property
    @abstractmethod
    def max_cache_size(self) -> float:
        """Return the maximum size the RAM cache can grow to."""
        pass

    @max_cache_size.setter
    @abstractmethod
    def max_cache_size(self, value: float) -> None:
        """Set the maximum size the RAM cache can grow to."""
        pass

    @property
    @abstractmethod
    def max_vram_cache_size(self) -> float:
        """Return the maximum size the VRAM cache can grow to."""
        pass

    @max_vram_cache_size.setter
    @abstractmethod
    def max_vram_cache_size(self, value: float) -> None:
        """Set the maximum size the VRAM cache can grow to."""
        pass

    @abstractmethod
    def offload_unlocked_models(self, size_required: int) -> None:
        """Offload from VRAM any models not actively in use."""
        pass

    @abstractmethod
    def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None:
        """Move model into the indicated device."""
        pass

    @property
    @abstractmethod
    def stats(self) -> Optional[CacheStats]:
        """Return collected CacheStats object."""
        pass

    @stats.setter
    @abstractmethod
    def stats(self, stats: CacheStats) -> None:
        """Set the CacheStats object for collecting cache statistics."""
        pass

    @property
    @abstractmethod
    def logger(self) -> Logger:
        """Return the logger used by the cache."""
        pass

    @abstractmethod
    def make_room(self, size: int) -> None:
        """Make enough room in the cache to accommodate a new model of indicated size."""
        pass
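
    # One possible shape for make_room() in a concrete subclass, consistent with the
    # least-recently-used eviction described in the module docstring (the attribute
    # and helper names below are hypothetical, not part of this interface):
    #
    #     def make_room(self, size: int) -> None:
    #         while self.cache_size() + size > self._ram_limit_bytes:
    #             candidate = self._least_recently_used_entry()
    #             if candidate is None or candidate.locked:
    #                 break                      # nothing evictable remains
    #             del self._cached_models[candidate.key]
    #             if self.stats:
    #                 self.stats.cleared += 1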

    @abstractmethod
    def put(
        self,
        key: str,
        model: T,
        submodel_type: Optional[SubModelType] = None,
    ) -> None:
        """Store model under key and optional submodel_type."""
        pass

    @abstractmethod
    def get(
        self,
        key: str,
        submodel_type: Optional[SubModelType] = None,
        stats_name: Optional[str] = None,
    ) -> ModelLockerBase:
        """
        Retrieve model using key and optional submodel_type.

        :param key: Opaque model key
        :param submodel_type: Type of the submodel to fetch
        :param stats_name: A human-readable id for the model for the purposes of
            stats reporting.

        This may raise an IndexError if the model is not in the cache.
        """
        pass
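
    # Illustrative call pattern for get() ("cache", the key, and the stats_name value
    # are placeholders; SubModelType.UNet is shown only as an example member):
    #
    #     try:
    #         locker = cache.get("abc123", submodel_type=SubModelType.UNet,
    #                            stats_name="main model: unet")
    #     except IndexError:
    #         ...  # not in the cache: load it from disk, cache.put() it, and retry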

    @abstractmethod
    def cache_size(self) -> int:
        """Get the total size of the models currently cached."""
        pass

    @abstractmethod
    def print_cuda_stats(self) -> None:
        """Log debugging information on CUDA usage."""
        pass