# ground-zero/src/engine/adapter_manager.py
# Initial commit: Sahel-Agri Voice AI (jefffffff9 / 76db545)
"""
LoRA adapter hot-swap manager.
Uses PEFT's multi-adapter API:
- model.load_adapter(path, adapter_name=lang) — first load (~2s per adapter)
- model.set_adapter(lang) — subsequent swap (~50ms)
This keeps a single backbone in VRAM and swaps only the ~50MB adapter weights,
vs reloading the full 1.5GB model per language.
"""
from __future__ import annotations
import logging
from pathlib import Path
from typing import TYPE_CHECKING
from peft import PeftModel
if TYPE_CHECKING:
from transformers import WhisperForConditionalGeneration
logger = logging.getLogger(__name__)
class AdapterManager:
    """Manages registration and hot-swapping of LoRA language adapters.

    Keeps a single backbone model in memory and swaps only the lightweight
    LoRA adapter weights per language:
      - register():     record an adapter path (no load).
      - load_adapter(): first load of an adapter from disk (slow, ~2s).
      - activate():     switch to an already-loaded adapter (fast, ~50ms).
    """

    def __init__(self, base_model: "WhisperForConditionalGeneration", config: dict) -> None:
        self._base_model = base_model
        self._config = config
        # language_code -> adapter_path (registered, not necessarily loaded)
        self._registry: dict[str, str] = {}
        # Created lazily on the first load_adapter() call.
        self._peft_model: PeftModel | None = None
        # Language code of the currently active adapter, if any.
        self._active: str | None = None

    def register(self, language: str, adapter_path: str) -> None:
        """Register an adapter path. Does not load it yet."""
        path = Path(adapter_path)
        if not path.exists():
            # Warn (not raise): the adapter may be produced later by training.
            logger.warning(
                "Adapter path '%s' for language '%s' does not exist. "
                "Run training first, or check the path.",
                adapter_path, language,
            )
        self._registry[language] = str(path)
        logger.info("Registered adapter '%s' → %s", language, adapter_path)

    def load_adapter(self, language: str) -> None:
        """
        Load an adapter into the model for the first time.

        Slow (~2s): reads adapter weights from disk.
        Subsequent activate() calls reuse the already-loaded weights.

        Raises:
            KeyError: if no adapter is registered for ``language``.
        """
        if language not in self._registry:
            raise KeyError(f"No adapter registered for language '{language}'. "
                           f"Available: {list(self._registry)}")
        adapter_path = self._registry[language]
        if self._peft_model is None:
            # First adapter: wrap the base model with PeftModel.
            # from_pretrained() also activates this adapter.
            logger.info("Wrapping base model with first adapter '%s'...", language)
            self._peft_model = PeftModel.from_pretrained(
                self._base_model,
                adapter_path,
                adapter_name=language,
            )
        elif language in self._peft_model.peft_config:
            # Already loaded: skip the disk read and just hot-swap to it.
            self._peft_model.set_adapter(language)
        else:
            # Subsequent adapters: load into the existing PeftModel.
            logger.info("Loading adapter '%s' into existing PeftModel...", language)
            self._peft_model.load_adapter(adapter_path, adapter_name=language)
            # BUG FIX: PeftModel.load_adapter() does NOT activate the newly
            # loaded adapter — without this set_adapter() call the previously
            # active adapter would keep serving while self._active claimed
            # otherwise.
            self._peft_model.set_adapter(language)
        self._active = language
        logger.info("Adapter '%s' loaded and active.", language)

    def activate(self, language: str) -> None:
        """
        Hot-swap to a previously loaded adapter (~50ms).
        Falls back to the slow load_adapter() path if the adapter
        hasn't been loaded yet.
        """
        if self._peft_model is None or language not in self._peft_model.peft_config:
            # Not loaded yet: take the slow path (which also activates it).
            self.load_adapter(language)
            return
        self._peft_model.set_adapter(language)
        self._active = language
        logger.debug("Hot-swapped to adapter '%s'.", language)

    def get_model(self) -> "WhisperForConditionalGeneration | PeftModel":
        """Return the PeftModel (or base model if no adapter loaded yet)."""
        return self._peft_model if self._peft_model is not None else self._base_model

    def get_active(self) -> str | None:
        """Return the language code of the active adapter, or None."""
        return self._active

    def list_available(self) -> list[str]:
        """Return all registered language codes (loaded or not)."""
        return list(self._registry.keys())

    def list_loaded(self) -> list[str]:
        """Return language codes whose adapter weights are loaded in memory."""
        if self._peft_model is None:
            return []
        return list(self._peft_model.peft_config.keys())