# NOTE(review): removed non-code page residue ("Spaces: / Sleeping / Sleeping")
# left over from a Hugging Face Spaces web capture; it was not valid Python.
import logging

from transformers import AutoTokenizer, pipeline, BitsAndBytesConfig

# Module-wide logging: INFO level, module-named logger (PEP 282 convention).
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class ModelLoader:
    """Loads a single translation model + tokenizer, with optional 8-bit quantization.

    When ``quantize`` is True, loading first attempts bitsandbytes 8-bit
    quantization and falls back to full precision if that fails. When False,
    the model is loaded in full precision directly.
    """

    def __init__(self, quantize: bool = True):
        # Whether to attempt 8-bit (bitsandbytes) loading before full precision.
        self.quantize = quantize

    def load(self, model_name: str):
        """Load the tokenizer and translation pipeline for *model_name*.

        Args:
            model_name: Hugging Face model id or local path.

        Returns:
            A ``(tokenizer, pipeline)`` tuple.

        Raises:
            AttributeError: if the tokenizer has no ``lang_code_to_id``
                mapping (required by the multilingual translation models
                this loader targets).
        """
        # 1) Tokenizer
        logger.info("Loading tokenizer for %s", model_name)
        tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
        if not hasattr(tokenizer, "lang_code_to_id"):
            raise AttributeError(f"Tokenizer for {model_name} has no lang_code_to_id mapping")

        # 2) Pipeline.
        # BUGFIX: the original always built BitsAndBytesConfig(load_in_8bit=self.quantize)
        # and logged "with 8-bit quantization" even when quantize=False. Now the
        # 8-bit path is attempted only when requested; otherwise (or on failure)
        # we load full precision and log accordingly.
        if self.quantize:
            try:
                bnb_cfg = BitsAndBytesConfig(load_in_8bit=True)
                pipe = pipeline(
                    "translation",
                    model=model_name,
                    tokenizer=tokenizer,
                    device_map="auto",
                    quantization_config=bnb_cfg,
                )
                logger.info("Loaded %s with 8-bit quantization", model_name)
                return tokenizer, pipe
            except Exception as e:
                # Best-effort: quantization support depends on hardware/installed
                # bitsandbytes, so fall back rather than crash.
                logger.warning("8-bit quantization failed (%s), loading full-precision", e)

        pipe = pipeline(
            "translation",
            model=model_name,
            tokenizer=tokenizer,
            device_map="auto",
        )
        logger.info("Loaded %s in full precision", model_name)
        return tokenizer, pipe