"""
File: model_translation.py
Description: Loading models for text translations
Author: Didier Guillevic
Date: 2024-03-16
"""

from typing import Dict, Tuple

import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from transformers import M2M100Tokenizer, M2M100ForConditionalGeneration
from transformers import PreTrainedModel, PreTrainedTokenizerBase


class Singleton(type):
    """Metaclass ensuring each class using it is instantiated at most once.

    The first call ``cls(...)`` builds and caches the instance; every later
    call returns that same cached object (any new arguments are ignored).
    """

    _instances = {}

    def __call__(cls, *args, **kwargs):
        if cls not in cls._instances:
            cls._instances[cls] = super().__call__(*args, **kwargs)
        return cls._instances[cls]


class ModelM2M100(metaclass=Singleton):
    """Loads an instance of the M2M100 model (418M).

    Tokenizer and model are loaded once, at first instantiation.
    """

    def __init__(self):
        self._model_name = "facebook/m2m100_418M"
        self._tokenizer = M2M100Tokenizer.from_pretrained(self._model_name)
        # NOTE(review): device_map="auto" requires the `accelerate` package;
        # float16 weights may be slow/unsupported if this lands on a CPU-only
        # host — confirm the deployment target has a GPU.
        self._model = M2M100ForConditionalGeneration.from_pretrained(
            self._model_name,
            device_map="auto",
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
        )

    @property
    def model_name(self):
        """Hugging Face hub identifier of the loaded model."""
        return self._model_name

    @property
    def tokenizer(self):
        """The M2M100 tokenizer."""
        return self._tokenizer

    @property
    def model(self):
        """The M2M100 seq2seq model (float16)."""
        return self._model


class ModelMADLAD(metaclass=Singleton):
    """Loads an instance of the Google MADLAD model (3B).

    Tokenizer and model are loaded once, at first instantiation.
    """

    def __init__(self):
        self._model_name = "google/madlad400-3b-mt"
        self._tokenizer = AutoTokenizer.from_pretrained(
            self._model_name, use_fast=True
        )
        # NOTE(review): same device_map / float16 caveats as ModelM2M100.
        self._model = AutoModelForSeq2SeqLM.from_pretrained(
            self._model_name,
            device_map="auto",
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
        )

    @property
    def model_name(self):
        """Hugging Face hub identifier of the loaded model."""
        return self._model_name

    @property
    def tokenizer(self):
        """The MADLAD tokenizer (fast variant)."""
        return self._tokenizer

    @property
    def model(self):
        """The MADLAD seq2seq model (float16)."""
        return self._model


# Bi-lingual individual models, keyed by (lowercase) source language code.
# Most translate into English; "en" translates English into French.
# NOTE(review): "fa" maps to a multi-target tc-big model (Persian -> Italic
# languages); such OPUS-MT models presumably expect a target-language token
# (e.g. ">>fra<<") prefixed to the input — verify the caller supplies it.
model_names = {
    "ar": "Helsinki-NLP/opus-mt-ar-en",
    "en": "Helsinki-NLP/opus-mt-en-fr",
    "fa": "Helsinki-NLP/opus-mt-tc-big-fa-itc",
    "fr": "Helsinki-NLP/opus-mt-fr-en",
    "he": "Helsinki-NLP/opus-mt-tc-big-he-en",
    "ja": "Helsinki-NLP/opus-mt-ja-en",
    "zh": "Helsinki-NLP/opus-mt-zh-en",
}

# Source languages for which a bilingual model is available. Derived from
# model_names so the two can never disagree (previously "ja" was listed here
# without a model, making lookups for it always fail).
src_langs = set(model_names)

# Registry for all loaded bilingual models: src_lang -> (tokenizer, model).
tokenizer_model_registry: Dict[
    str, Tuple[PreTrainedTokenizerBase, PreTrainedModel]
] = {}

# Device the bilingual models are moved to.
device = 'cpu'


def get_tokenizer_model_for_src_lang(
    src_lang: str,
) -> Tuple[PreTrainedTokenizerBase, PreTrainedModel]:
    """Return the (tokenizer, model) pair for a given source language.

    Models are loaded lazily on first request and cached in
    ``tokenizer_model_registry`` so subsequent calls reuse them.

    Args:
        src_lang: Source language code, case-insensitive (e.g. "fr").

    Returns:
        A ``(tokenizer, model)`` tuple; the model has been moved to ``device``.

    Raises:
        ValueError: If no bilingual model is configured for ``src_lang``.
    """
    src_lang = src_lang.lower()

    # Already loaded?
    cached = tokenizer_model_registry.get(src_lang)
    if cached is not None:
        return cached

    # Load tokenizer and model
    model_name = model_names.get(src_lang)
    if not model_name:
        raise ValueError(f"No model defined for language: {src_lang}")

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

    # Only use half precision on GPU: float16 on CPU is slow and, on older
    # PyTorch releases, several CPU kernels have no Half implementation.
    # Check the weights' actual dtype — from_pretrained loads float32 by
    # default, whatever dtype model.config records.
    if torch.device(device).type == "cuda" and model.dtype != torch.float16:
        model = model.half()
    model.to(device)

    tokenizer_model_registry[src_lang] = (tokenizer, model)
    return (tokenizer, model)


# Max number of words for given input text
# - Usually 512 tokens (max position encodings, as well as max length)
# - Let's set to some number of words somewhat lower than that threshold
# - e.g. 200 words
max_words_per_chunk = 200