ground-zero / src /engine /whisper_base.py
jefffffff9
Initial commit: Sahel-Agri Voice AI
76db545
"""
Loads the Whisper backbone model and processor once.
All other modules receive references to this shared instance.
"""
from __future__ import annotations
import logging
from pathlib import Path
import torch
import yaml
from transformers import WhisperForConditionalGeneration, WhisperProcessor
logger = logging.getLogger(__name__)
class WhisperBackbone:
    """Singleton-style loader for the Whisper base model and processor.

    Construction only reads the configuration file; the weights are not
    loaded until :meth:`load` is called. After loading, the shared model and
    processor are exposed through the ``model`` and ``processor`` properties.
    """

    def __init__(self, config_path: str = "configs/base_config.yaml") -> None:
        """Read the model identifier from a YAML configuration file.

        Args:
            config_path: Path to a YAML file containing a ``model`` mapping
                with an ``id`` entry.

        Raises:
            OSError: If the configuration file cannot be opened.
            KeyError: If ``model`` or ``model.id`` is missing from the file.
        """
        # Decode explicitly as UTF-8 so the parsed config does not depend on
        # the platform's default locale encoding.
        with Path(config_path).open(encoding="utf-8") as f:
            cfg = yaml.safe_load(f)
        self._model_id: str = cfg["model"]["id"]
        self._model: WhisperForConditionalGeneration | None = None
        self._processor: WhisperProcessor | None = None
        self._device: str = "cpu"

    @staticmethod
    def _resolve_device(requested: str, cuda_available: bool) -> str:
        """Return the device string to use for *requested*.

        ``"cuda"`` and indexed forms such as ``"cuda:1"`` are kept as-is when
        CUDA is available; every other request resolves to ``"cpu"``.
        """
        is_cuda_request = requested == "cuda" or requested.startswith("cuda:")
        if is_cuda_request and cuda_available:
            return requested
        return "cpu"

    def load(self, device: str = "cuda", hf_token: str | None = None) -> None:
        """Load model and processor into memory. Call once at startup.

        Args:
            device: Requested device. ``"cuda"`` or ``"cuda:<index>"`` is used
                when CUDA is available; anything else falls back to ``"cpu"``.
            hf_token: Optional token forwarded to ``from_pretrained`` as
                ``token``.
        """
        self._device = self._resolve_device(device, torch.cuda.is_available())
        if self._device != device:
            # Make the fallback visible instead of silently running on CPU.
            logger.warning(
                "Requested device %r is unavailable or unsupported; falling back to %s",
                device,
                self._device,
            )
        logger.info("Loading %s on %s", self._model_id, self._device)

        self._processor = WhisperProcessor.from_pretrained(
            self._model_id,
            token=hf_token,
        )

        # Half precision on CUDA devices, full precision on CPU.
        dtype = torch.float16 if self._device.startswith("cuda") else torch.float32
        self._model = WhisperForConditionalGeneration.from_pretrained(
            self._model_id,
            torch_dtype=dtype,
            token=hf_token,
        ).to(self._device)
        self._model.eval()
        logger.info("Model loaded successfully (dtype=%s, device=%s)", dtype, self._device)

    @property
    def model(self) -> WhisperForConditionalGeneration:
        """The loaded model.

        Raises:
            RuntimeError: If :meth:`load` has not been called (or the backbone
                has been freed).
        """
        if self._model is None:
            raise RuntimeError("Call WhisperBackbone.load() before accessing the model.")
        return self._model

    @property
    def processor(self) -> WhisperProcessor:
        """The loaded processor.

        Raises:
            RuntimeError: If :meth:`load` has not been called (or the backbone
                has been freed).
        """
        if self._processor is None:
            raise RuntimeError("Call WhisperBackbone.load() before accessing the processor.")
        return self._processor

    @property
    def device(self) -> str:
        """Device string selected by the last :meth:`load` call (``"cpu"`` before that)."""
        return self._device

    @property
    def model_id(self) -> str:
        """Model identifier read from the configuration file."""
        return self._model_id

    def free(self) -> None:
        """Drop this object's references to the model and processor.

        Also asks torch to release its cached CUDA memory when CUDA is
        available. Safe to call more than once.
        """
        # Rebinding to None drops the references; a preceding ``del`` would
        # add nothing and would fail if the attribute were missing.
        self._model = None
        self._processor = None
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        logger.info("Backbone freed from memory.")