# -*- coding: utf-8 -*-
import os
import re
import sys
import time
import unicodedata

import pysbd
import torch
from transformers import AutoModelForSeq2SeqLM, NllbTokenizer

# hy_segmenter = pysbd.Segmenter(language="hy", clean=False)  # not needed

MODEL_NAME = "AriNubar/nllb-200-distilled-600m-en-hyw"

LANGUAGES = {
    "Արեւմտահայերէն | Western Armenian": "hyw_Armn",
    "Անգլերէն | English": "eng_Latn",
}

HF_TOKEN = os.environ.get("HF_TOKEN")


def get_non_printing_char_replacer(replace_by: str = " "):
    """Build a str.translate() map that replaces every non-printing character with `replace_by`."""
    non_printable_map = {
        ord(c): replace_by
        for c in (chr(i) for i in range(sys.maxunicode + 1))
        # same as \p{C} in perl
        # see https://www.unicode.org/reports/tr44/#General_Category_Values
        if unicodedata.category(c) in {"C", "Cc", "Cf", "Cs", "Co", "Cn"}
    }

    def replace_non_printing_char(line: str) -> str:
        return line.translate(non_printable_map)

    return replace_non_printing_char


def clean_text(text: str, lang: str) -> str:
    """Normalize whitespace and, for Western Armenian, punctuation variants."""
    HYW_CHARS_TO_NORMALIZE = {
        "«": '"',
        "»": '"',
        "“": '"',
        "”": '"',
        "’": "'",
        "‘": "'",
        "–": "-",
        "—": "-",
        "ՙ": "'",
        "՚": "'",
    }
    DOUBLE_CHARS_TO_NORMALIZE = {
        "Կ՛": "Կ'",
        "կ՛": "կ'",
        "Չ՛": "Չ'",
        "չ՛": "չ'",
        "Մ՛": "Մ'",
        "մ՛": "մ'",
    }
    replace_nonprint = get_non_printing_char_replacer()
    text = replace_nonprint(text)
    # collapse tabs, newlines, and runs of whitespace
    # (str.replace() matches literally, so re.sub() is required for patterns)
    text = re.sub(r"\s+", " ", text).strip()
    if lang == "hyw_Armn":
        text = text.translate(str.maketrans(HYW_CHARS_TO_NORMALIZE))
        for k, v in DOUBLE_CHARS_TO_NORMALIZE.items():
            text = text.replace(k, v)
    return text


def sentenize_with_fillers(text, splitter, fix_double_space=True, ignore_errors=False):
    """Split `text` into sentences, also returning the "filler" substrings
    between them so the original text can be reconstructed afterwards."""
    if fix_double_space:
        text = re.sub(r"\s+", " ", text)
    text = text.strip()
    sentences = splitter.segment(text)
    fillers = []
    i = 0
    for sent in sentences:
        start_idx = text.find(sent, i)
        if ignore_errors and start_idx == -1:
            start_idx = i + 1
        assert start_idx != -1, f"Sent not found after index {i} in text: {text}"
        fillers.append(text[i:start_idx])
        i = start_idx + len(sent)
    fillers.append(text[i:])
    return sentences, fillers


def init_tokenizer(tokenizer, new_lang="hyw_Armn"):
    """
    Add a new language token to the tokenizer vocabulary
    (this should be done each time after its initialization).
    """
    old_len = len(tokenizer) - int(new_lang in tokenizer.added_tokens_encoder)
    tokenizer.lang_code_to_id[new_lang] = old_len - 1
    tokenizer.id_to_lang_code[old_len - 1] = new_lang
    # always move "<mask>" to the last position
    tokenizer.fairseq_tokens_to_ids["<mask>"] = (
        len(tokenizer.sp_model) + len(tokenizer.lang_code_to_id) + tokenizer.fairseq_offset
    )
    tokenizer.fairseq_tokens_to_ids.update(tokenizer.lang_code_to_id)
    tokenizer.fairseq_ids_to_tokens = {v: k for k, v in tokenizer.fairseq_tokens_to_ids.items()}
    if new_lang not in tokenizer._additional_special_tokens:
        tokenizer._additional_special_tokens.append(new_lang)
    # clear the added token encoder; otherwise a new token may end up there by mistake
    tokenizer.added_tokens_encoder = {}
    tokenizer.added_tokens_decoder = {}
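
# A minimal usage sketch for init_tokenizer, kept as a comment so nothing runs
# at import time ("Բարեւ աշխարհ" is just an illustrative input string):
#
#   tokenizer = NllbTokenizer.from_pretrained(MODEL_NAME, token=HF_TOKEN)
#   init_tokenizer(tokenizer, "hyw_Armn")
#   tokenizer.src_lang = "hyw_Armn"   # the new code now behaves like a built-in one
#   ids = tokenizer("Բարեւ աշխարհ").input_ids
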
class Translator:
    def __init__(self) -> None:
        self.model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME, token=HF_TOKEN)
        if torch.cuda.is_available():
            self.model = self.model.cuda()
        self.tokenizer = NllbTokenizer.from_pretrained(MODEL_NAME, token=HF_TOKEN)
        init_tokenizer(self.tokenizer)
        self.hyw_splitter = pysbd.Segmenter(language="hy", clean=True)
        self.eng_splitter = pysbd.Segmenter(language="en", clean=True)
        self.languages = LANGUAGES

    def translate_single(
        self,
        text,
        src_lang,
        tgt_lang,
        max_length="auto",
        num_beams=4,
        n_out=None,
        **kwargs,
    ):
        self.tokenizer.src_lang = src_lang
        encoded = self.tokenizer(text, return_tensors="pt", truncation=True, max_length=256)
        if max_length == "auto":
            # allow the output to be roughly twice as long as the input
            max_length = int(32 + 2.0 * encoded.input_ids.shape[1])
        generated_tokens = self.model.generate(
            **encoded.to(self.model.device),
            forced_bos_token_id=self.tokenizer.lang_code_to_id[tgt_lang],
            max_length=max_length,
            num_beams=num_beams,
            num_return_sequences=n_out or 1,
            **kwargs,
        )
        out = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
        if isinstance(text, str) and n_out is None:
            return out[0]
        return out

    def translate(
        self,
        text: str,
        src_lang: str,
        tgt_lang: str,
        max_length=256,
        num_beams=4,
        by_sentence=True,
        clean=True,
        **kwargs,
    ):
        if by_sentence:
            if src_lang == "eng_Latn":
                sents = self.eng_splitter.segment(text)
            elif src_lang == "hyw_Armn":
                sents = self.hyw_splitter.segment(text)
            else:
                raise ValueError(f"Unsupported source language: {src_lang}")
        else:
            # translate the whole text as a single chunk
            sents = [text]
        if clean:
            sents = [clean_text(sent, src_lang) for sent in sents]
        if len(sents) > 1:
            results = self.translate_batch(
                sents, src_lang, tgt_lang, num_beams=num_beams, max_length=max_length, **kwargs
            )
        else:
            results = self.translate_single(
                sents, src_lang, tgt_lang, max_length=max_length, num_beams=num_beams, **kwargs
            )
        return " ".join(results)

    def translate_batch(self, texts, src_lang, tgt_lang, num_beams=4, max_length=256, **kwargs):
        self.tokenizer.src_lang = src_lang
        # keep the full encoding (input_ids + attention_mask) and move it to the
        # model's device, which covers both the CUDA and CPU cases
        inputs = self.tokenizer(
            texts, return_tensors="pt", max_length=max_length, padding=True, truncation=True
        ).to(self.model.device)
        translated_tokens = self.model.generate(
            **inputs,
            num_beams=num_beams,
            max_length=max_length,
            forced_bos_token_id=self.tokenizer.lang_code_to_id[tgt_lang],
            **kwargs,
        )
        return self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)


if __name__ == "__main__":
    print("Initializing translator...")
    translator = Translator()
    print("Translator initialized.")

    start_time = time.time()
    print(translator.translate("Hello world!", "eng_Latn", "hyw_Armn"))
    print("Time elapsed:", time.time() - start_time)

    start_time = time.time()
    print(translator.translate(
        "I am the greatest translator! Do not fuck with me!",
        "eng_Latn", "hyw_Armn",
    ))
    print("Time elapsed:", time.time() - start_time)
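
    # A minimal sketch of calling translate_batch directly (the strings below
    # are made-up examples; translate() takes this same path internally whenever
    # pysbd finds more than one sentence):
    start_time = time.time()
    print(translator.translate_batch(
        ["Good morning.", "How are you today?"],
        "eng_Latn",
        "hyw_Armn",
    ))
    print("Time elapsed:", time.time() - start_time)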