# Spaces: Sleeping  (appears to be web-page residue captured with the file; not part of the module)
import logging
from typing import List, Optional, Union

from langdetect import detect, LangDetectException

from models.model_loader import ModelLoader
from models.model_selector import ModelSelector
logger = logging.getLogger(__name__)


class ModelManager:
    """
    Orchestrates model selection, loading, and auto-language detection.

    Exposes:
    - translate(text, src_lang=None, tgt_lang=None)
    - get_info()

    Constructing an instance immediately selects and loads a model, so
    construction raises ValueError when no default target language was
    given and none of the tokenizer's language codes looks Turkish.
    """

    # Language subtags accepted as Turkish when picking the default target:
    # the two-letter form ("tr_TR"-style codes) and the three-letter form
    # ("tur_Latn"-style codes).
    _TURKISH_SUBTAGS = ("tr", "tur")

    def __init__(
        self,
        candidates: Optional[List[str]] = None,
        quantize: bool = True,
        default_tgt: Optional[str] = None,
    ):
        """
        Build the selector/loader pair and load the selected model.

        Args:
            candidates: Forwarded unchanged to ModelSelector.
            quantize: Forwarded to both ModelSelector and ModelLoader.
            default_tgt: Target language code used when translate() is
                called without tgt_lang. When falsy, a Turkish code is
                picked from the tokenizer's language codes.
        """
        self.selector = ModelSelector(candidates, quantize)
        self.loader = ModelLoader(quantize)
        self.tokenizer = None
        self.pipeline = None
        self.lang_codes = []
        self.default_tgt = default_tgt  # e.g. "tur_Latn"
        self._load_best_model()

    @staticmethod
    def _language_subtag(code: str) -> str:
        """Return the lower-cased part of *code* before the first '_' or '-'."""
        return code.lower().replace("-", "_").split("_", 1)[0]

    def _find_turkish_code(self) -> Optional[str]:
        """Return the first Turkish-looking code in self.lang_codes, or None."""
        by_subtag = [
            code
            for code in self.lang_codes
            if self._language_subtag(code) in self._TURKISH_SUBTAGS
        ]
        if by_subtag:
            return by_subtag[0]
        # Keep the historical plain-prefix match as a last resort so code
        # lists that used to resolve still do.
        by_prefix = [c for c in self.lang_codes if c.lower().startswith("tr")]
        return by_prefix[0] if by_prefix else None

    def _load_best_model(self):
        """
        Select and load a model, then cache its tokenizer, pipeline and codes.

        Raises:
            ValueError: If no default target was configured and no Turkish
                code is present among the tokenizer's language codes.
        """
        model_name = self.selector.select()
        tok, pipe = self.loader.load(model_name)
        self.tokenizer = tok
        self.pipeline = pipe
        # NOTE(review): relies on the tokenizer exposing `lang_code_to_id`;
        # confirm this holds for the installed tokenizer/library version.
        self.lang_codes = list(tok.lang_code_to_id.keys())

        # Pick a Turkish code if not explicitly set
        if not self.default_tgt:
            turkish = self._find_turkish_code()
            if turkish is None:
                raise ValueError(f"No Turkish code found in {model_name}")
            self.default_tgt = turkish
        logger.info("Default target language: %s", self.default_tgt)

    def _match_iso(self, iso: str) -> Optional[str]:
        """
        Map a detector language tag onto one of self.lang_codes.

        Only the primary subtag of *iso* is used (so "zh-cn" is looked up
        as "zh"). A code whose own language subtag equals it wins;
        otherwise the first code starting with it is used.
        Returns None when nothing matches.
        """
        primary = self._language_subtag(iso)
        exact = [c for c in self.lang_codes if self._language_subtag(c) == primary]
        if exact:
            return exact[0]
        # NOTE(review): prefix matching is a heuristic; a two-letter tag can
        # prefix the code of a different language. An explicit tag-to-code
        # table would be more reliable.
        prefixed = [c for c in self.lang_codes if c.lower().startswith(primary)]
        return prefixed[0] if prefixed else None

    def _detect_src_lang(self, text: Union[str, List[str]]) -> str:
        """
        Guess the source language code for *text*.

        For a list, the first element is the sample; an empty list or empty
        string skips detection. Any failure (detector error, empty input,
        or no matching code) logs a warning and falls back to the first
        English-looking code, or to the first known code if there is none.
        """
        if isinstance(text, list):
            # An empty batch has nothing to sample; treat it as undetectable
            # instead of raising IndexError.
            sample = text[0] if text else ""
        else:
            sample = text

        if not sample:
            failure = "empty input"
        else:
            try:
                iso = detect(sample).lower()
            except Exception as exc:
                # Detection is deliberately best-effort: it must never
                # prevent the translation call itself.
                failure = exc
            else:
                matched = self._match_iso(iso)
                if matched is not None:
                    logger.info("Detected src_lang=%s", matched)
                    return matched
                failure = f"No mapping for ISO '{iso}'"

        logger.warning("Auto-detect failed (%s); defaulting to English", failure)
        english = [c for c in self.lang_codes if c.lower().startswith("en")]
        return english[0] if english else self.lang_codes[0]

    def translate(
        self,
        text: Union[str, List[str]],
        src_lang: Optional[str] = None,
        tgt_lang: Optional[str] = None,
    ):
        """
        Run the loaded pipeline on *text*.

        Args:
            text: A single string or a list of strings.
            src_lang: Source language code; auto-detected when falsy.
            tgt_lang: Target language code; self.default_tgt when falsy.

        Returns:
            Whatever the pipeline returns for
            (text, src_lang=..., tgt_lang=...).
        """
        tgt = tgt_lang or self.default_tgt
        # Auto-detect source if missing
        src = src_lang or self._detect_src_lang(text)
        return self.pipeline(text, src_lang=src, tgt_lang=tgt)

    def get_info(self) -> dict:
        """
        Returns a dict for your sidebar:
        { model, quantized, device, default_tgt }

        Values are read from the pipeline's `model` attribute and fall back
        to None / False / "auto" when an attribute is missing.
        """
        mdl = getattr(self.pipeline, "model", None)
        return {
            "model": getattr(mdl, "name_or_path", None),
            "quantized": getattr(mdl, "is_loaded_in_8bit", False),
            "device": str(getattr(mdl, "device", "auto")),
            "default_tgt": self.default_tgt,
        }