import logging

from transformers import AutoTokenizer, BitsAndBytesConfig, pipeline

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class ModelLoader:
    """
    Loads a single translation model + tokenizer, with optional 8-bit quantization.
    """

    def __init__(self, quantize: bool = True):
        self.quantize = quantize

    def load(self, model_name: str):
        # 1) Tokenizer
        logger.info(f"Loading tokenizer for {model_name}")
        tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
        if not hasattr(tokenizer, "lang_code_to_id"):
            raise AttributeError(
                f"Tokenizer for {model_name} has no lang_code_to_id mapping"
            )

        # 2) Pipeline: try 8-bit quantization if requested; otherwise (or if
        #    quantized loading fails) fall back to full precision.
        if self.quantize:
            try:
                bnb_cfg = BitsAndBytesConfig(load_in_8bit=True)
                pipe = pipeline(
                    "translation",
                    model=model_name,
                    tokenizer=tokenizer,
                    device_map="auto",
                    # quantization_config is a model-loading argument, so it is
                    # forwarded via model_kwargs rather than passed to the
                    # pipeline call itself.
                    model_kwargs={"quantization_config": bnb_cfg},
                )
                logger.info(f"Loaded {model_name} with 8-bit quantization")
                return tokenizer, pipe
            except Exception as e:
                logger.warning(f"8-bit quantization failed ({e}), loading full-precision")

        pipe = pipeline(
            "translation",
            model=model_name,
            tokenizer=tokenizer,
            device_map="auto",
        )
        logger.info(f"Loaded {model_name} in full precision")
        return tokenizer, pipe
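

# Hedged usage sketch: the checkpoint name below ("facebook/m2m100_418M") is an
# assumption for illustration only; any multilingual translation model whose
# tokenizer exposes a lang_code_to_id mapping (mBART/M2M100-style checkpoints)
# should work the same way.
if __name__ == "__main__":
    loader = ModelLoader(quantize=True)
    tokenizer, translator = loader.load("facebook/m2m100_418M")

    # Translation pipelines accept src_lang / tgt_lang for multilingual models.
    result = translator("Hello, world!", src_lang="en", tgt_lang="fr")
    logger.info(result[0]["translation_text"])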