File size: 1,504 Bytes
ff2d1f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import logging
from transformers import AutoTokenizer, pipeline, BitsAndBytesConfig

# Configure root logging and obtain a module-scoped logger.
# NOTE(review): calling basicConfig at import time in a library module mutates
# global logging state for the whole process; consider leaving configuration
# to the application entry point — confirm this module is used as a script.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class ModelLoader:
    """
    Loads a single translation model + tokenizer, with optional 8-bit quantization.
    """
    def __init__(self, quantize: bool = True):
        # Whether to attempt 8-bit (bitsandbytes) quantization when loading.
        self.quantize = quantize

    def load(self, model_name: str):
        """Load the tokenizer and translation pipeline for ``model_name``.

        If ``self.quantize`` is True, first attempt an 8-bit quantized load and
        fall back to full precision on any failure; otherwise load full
        precision directly.

        Args:
            model_name: Hugging Face model identifier or local path.

        Returns:
            A ``(tokenizer, pipe)`` tuple.

        Raises:
            AttributeError: if the tokenizer lacks a ``lang_code_to_id``
                mapping (required by the multilingual translation models
                this loader targets).
        """
        # 1) Tokenizer
        logger.info("Loading tokenizer for %s", model_name)
        tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
        if not hasattr(tokenizer, "lang_code_to_id"):
            raise AttributeError(f"Tokenizer for {model_name} has no lang_code_to_id mapping")

        # 2) Pipeline.
        # FIX: previously, quantize=False still built a BitsAndBytesConfig
        # (load_in_8bit=False) and logged "with 8-bit quantization" — the flag
        # never actually selected the full-precision path. Now it does.
        if self.quantize:
            try:
                bnb_cfg = BitsAndBytesConfig(load_in_8bit=True)
                # FIX: quantization_config is a model-loading kwarg; the
                # pipeline API expects it inside model_kwargs, not as a
                # direct pipeline() argument.
                pipe = pipeline(
                    "translation",
                    model=model_name,
                    tokenizer=tokenizer,
                    device_map="auto",
                    model_kwargs={"quantization_config": bnb_cfg},
                )
                logger.info("Loaded %s with 8-bit quantization", model_name)
                return tokenizer, pipe
            except Exception as e:
                # Best-effort fallback: quantized loading can fail for many
                # environment reasons (no bitsandbytes, unsupported hardware).
                logger.warning("8-bit quantization failed (%s), loading full-precision", e)

        pipe = pipeline(
            "translation",
            model=model_name,
            tokenizer=tokenizer,
            device_map="auto",
        )
        logger.info("Loaded %s in full precision", model_name)
        return tokenizer, pipe