# nllb-extended-v2024-demo / translation.py
# Evolved from the myv-rus demo by cointegrated (commit 193923d).
import re
import sys
import typing as tp
import unicodedata
import torch
from sacremoses import MosesPunctNormalizer
from sentence_splitter import SentenceSplitter
from transformers import AutoModelForSeq2SeqLM, NllbTokenizer
# Hugging Face Hub ID of the fine-tuned NLLB checkpoint used for translation.
MODEL_URL = "slone/nllb-210-v1"
# Maps the human-readable label shown in the demo UI to the NLLB-style
# language code (ISO 639-3 + script) that the tokenizer/model expect.
LANGUAGES = {
    "Русский | Russian": "rus_Cyrl",
    "English | Английский": "eng_Latn",
    "Azərbaycan | Azerbaijani | Азербайджанский": "azj_Latn",
    "Башҡорт | Bashkir | Башкирский": "bak_Cyrl",
    "Буряад | Buryat | Бурятский": "bxr_Cyrl",
    "Чӑваш | Chuvash | Чувашский": "chv_Cyrl",
    "Хакас | Khakas | Хакасский": "kjh_Cyrl",
    "Къарачай-малкъар | Karachay-Balkar | Карачаево-балкарский": "krc_Cyrl",
    "Марий | Meadow Mari | Марийский": "mhr_Cyrl",
    "Эрзянь | Erzya | Эрзянский": "myv_Cyrl",
    "Татар | Tatar | Татарский": "tat_Cyrl",
    "Тыва | Тувинский | Tuvan ": "tyv_Cyrl",
}
# Default source (L1) and target (L2) language codes for the translator.
L1 = "rus_Cyrl"
L2 = "eng_Latn"
def get_non_printing_char_replacer(replace_by: str = " ") -> tp.Callable[[str], str]:
    """Build a callable that replaces every non-printing character.

    Any character whose Unicode general category belongs to the "Other"
    group (control, format, surrogate, private-use, unassigned — i.e.
    ``\\p{C}`` in Perl regex terms) is mapped to ``replace_by``; all other
    characters pass through unchanged.
    """
    # see https://www.unicode.org/reports/tr44/#General_Category_Values
    other_categories = {"C", "Cc", "Cf", "Cs", "Co", "Cn"}
    # Precompute a full str.translate table over the whole Unicode range
    # once, so each call is a single C-level pass over the string.
    translation_table = {}
    for codepoint in range(sys.maxunicode + 1):
        if unicodedata.category(chr(codepoint)) in other_categories:
            translation_table[codepoint] = replace_by

    def _replace(line: str) -> str:
        return line.translate(translation_table)

    return _replace
class TextPreprocessor:
    """
    Mimic the text preprocessing made for the NLLB model.
    This code is adapted from the Stopes repo of the NLLB team:
    https://github.com/facebookresearch/stopes/blob/main/stopes/pipelines/monolingual/monolingual_line_processor.py#L214
    """

    def __init__(self, lang: str = "en"):
        self.mpn = MosesPunctNormalizer(lang=lang)
        # Pre-compile the normalizer's regex patterns once up front;
        # re.sub accepts compiled patterns, so normalize() behaves the
        # same but avoids recompiling on every call.
        self.mpn.substitutions = [
            (re.compile(pattern), replacement)
            for pattern, replacement in self.mpn.substitutions
        ]
        self.replace_nonprint = get_non_printing_char_replacer(" ")

    def __call__(self, text: str) -> str:
        """Return ``text`` with punctuation normalized, non-printing
        characters blanked, and compatibility characters folded."""
        normalized = self.mpn.normalize(text)
        normalized = self.replace_nonprint(normalized)
        # Fold compatibility forms, e.g. 𝓕𝔯𝔞𝔫𝔠𝔢𝔰𝔠𝔞 -> Francesca
        return unicodedata.normalize("NFKC", normalized)
def sentenize_with_fillers(text, splitter, fix_double_space=True, ignore_errors=False):
    """Apply a sentence splitter and return the sentences and all separators
    before and after them.

    Args:
        text: the input string to split.
        splitter: an object with a ``split(text) -> list[str]`` method
            (e.g. a ``SentenceSplitter`` instance).
        fix_double_space: collapse runs of spaces first, so that the
            splitter's output can be found verbatim inside ``text``.
        ignore_errors: if True, a sentence that cannot be located in
            ``text`` is handled best-effort instead of raising.

    Returns:
        ``(sentences, fillers)`` where ``len(fillers) == len(sentences) + 1``;
        interleaving fillers and sentences reconstructs the (space-collapsed)
        input.

    Raises:
        ValueError: if a sentence produced by the splitter cannot be found
            in ``text`` and ``ignore_errors`` is False.
    """
    if fix_double_space:
        text = re.sub(" +", " ", text)
    sentences = splitter.split(text)
    fillers = []
    cursor = 0
    for sentence in sentences:
        start_idx = text.find(sentence, cursor)
        if start_idx == -1:
            # An `assert` here would be stripped under `python -O`,
            # silently disabling validation — raise explicitly instead.
            if not ignore_errors:
                raise ValueError(f"sent not found after {cursor}: `{sentence}`")
            # Best effort: pretend the sentence starts right after the cursor.
            start_idx = cursor + 1
        fillers.append(text[cursor:start_idx])
        cursor = start_idx + len(sentence)
    fillers.append(text[cursor:])
    return sentences, fillers
class Translator:
    """Loads the NLLB model/tokenizer and translates text between the
    languages listed in LANGUAGES, optionally sentence by sentence."""

    def __init__(self):
        # low_cpu_mem_usage reduces peak RAM while loading the checkpoint;
        # weights are downloaded from the Hub on first use.
        self.model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_URL, low_cpu_mem_usage=True)
        if torch.cuda.is_available():
            self.model.cuda()
        self.tokenizer = NllbTokenizer.from_pretrained(MODEL_URL)
        # Russian splitting rules are used for every source language;
        # presumably adequate because the supported languages share
        # similar sentence punctuation — TODO confirm.
        self.splitter = SentenceSplitter("ru")
        self.preprocessor = TextPreprocessor()
        self.languages = LANGUAGES

    def translate(
        self,
        text,
        src_lang=L1,
        tgt_lang=L2,
        max_length="auto",
        num_beams=4,
        by_sentence=True,
        preprocess=True,
        **kwargs,
    ):
        """Translate a text sentence by sentence, preserving the fillers
        (whitespace/separators) around the sentences.

        Extra **kwargs are forwarded to ``translate_single`` (and from
        there to ``model.generate``). Returns the translated string with
        the original separators re-inserted.
        """
        if by_sentence:
            sents, fillers = sentenize_with_fillers(
                text, splitter=self.splitter, ignore_errors=True
            )
        else:
            # Single-chunk mode: one "sentence" with empty separators.
            sents = [text]
            fillers = ["", ""]
        if preprocess:
            sents = [self.preprocessor(sent) for sent in sents]
        results = []
        # fillers has one more element than sents; zip pairs each sentence
        # with the separator that precedes it, and the trailing separator
        # is appended after the loop.
        for sent, sep in zip(sents, fillers):
            results.append(sep)
            results.append(
                self.translate_single(
                    sent,
                    src_lang=src_lang,
                    tgt_lang=tgt_lang,
                    max_length=max_length,
                    num_beams=num_beams,
                    **kwargs,
                )
            )
        results.append(fillers[-1])
        return "".join(results)

    def translate_single(
        self,
        text,
        src_lang=L1,
        tgt_lang=L2,
        max_length="auto",
        num_beams=4,
        n_out=None,
        **kwargs,
    ):
        """Translate one text (or batch) with beam search.

        Returns a single string when ``text`` is a str and ``n_out`` is
        None; otherwise a list of decoded strings (``n_out`` hypotheses
        per input).
        """
        # The NLLB tokenizer prepends the source-language code based on
        # this attribute, so it must be set before encoding.
        self.tokenizer.src_lang = src_lang
        encoded = self.tokenizer(
            text, return_tensors="pt", truncation=True, max_length=512
        )
        if max_length == "auto":
            # Heuristic output budget: roughly twice the input length
            # plus a constant headroom.
            max_length = int(32 + 2.0 * encoded.input_ids.shape[1])
        generated_tokens = self.model.generate(
            **encoded.to(self.model.device),
            # Force the first generated token to be the target-language
            # code, as NLLB generation expects.
            forced_bos_token_id=self.tokenizer.lang_code_to_id[tgt_lang],
            max_length=max_length,
            num_beams=num_beams,
            num_return_sequences=n_out or 1,
            **kwargs,
        )
        out = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
        if isinstance(text, str) and n_out is None:
            return out[0]
        return out
if __name__ == "__main__":
    # Running this module directly just pre-downloads the model weights
    # (useful e.g. when building a container image), then exits.
    print("Initializing a translator to pre-download models...")
    translator = Translator()
    print("Initialization successful!")