---
language:
- en
- az
pipeline_tag: translation
tags:
- translation
---

This is a version of NLLB fine-tuned to translate sentences between English (`eng`) and North Azerbaijani (`azj`), using the corresponding subset of the TIL corpus: https://github.com/turkic-interlingua/til-mt/tree/master/til_corpus.

Example inference code (with the correct NLLB preprocessing!):

```Python
from transformers import NllbTokenizer, AutoModelForSeq2SeqLM

# this code is adapted from the Stopes repo of the NLLB team
# https://github.com/facebookresearch/stopes/blob/main/stopes/pipelines/monolingual/monolingual_line_processor.py#L214

import re
import sys
import typing as tp
import unicodedata

from sacremoses import MosesPunctNormalizer


mpn = MosesPunctNormalizer(lang="en")
mpn.substitutions = [
    (re.compile(r), sub) for r, sub in mpn.substitutions
]


def get_non_printing_char_replacer(replace_by: str = " ") -> tp.Callable[[str], str]:
    non_printable_map = {
        ord(c): replace_by
        for c in (chr(i) for i in range(sys.maxunicode + 1))
        # same as \p{C} in perl
        # see https://www.unicode.org/reports/tr44/#General_Category_Values
        if unicodedata.category(c) in {"C", "Cc", "Cf", "Cs", "Co", "Cn"}
    }

    def replace_non_printing_char(line) -> str:
        return line.translate(non_printable_map)

    return replace_non_printing_char


replace_nonprint = get_non_printing_char_replacer(" ")


def preproc(text):
    """Normalize punctuation and remove non-printing characters, as in the NLLB training data."""
    clean = mpn.normalize(text)
    clean = replace_nonprint(clean)
    # replace 𝓕𝓻𝓪𝓷𝓬𝓮𝓼𝓬𝓪 by Francesca
    clean = unicodedata.normalize("NFKC", clean)
    return clean


# loading the model
model_name = "slone/nllb-600M-azj-eng-v1"
model = AutoModelForSeq2SeqLM.from_pretrained(model_name).cuda()
tokenizer = NllbTokenizer.from_pretrained(model_name)


def translate(text, src_lang='eng_Latn', tgt_lang='azj_Latn', a=32, b=3, max_input_length=1024, num_beams=4, **kwargs):
    """Translate a string or a list of strings."""
    tokenizer.src_lang = src_lang
    tokenizer.tgt_lang = tgt_lang
    if isinstance(text, str):
        text = [text]
    text = [preproc(t) for t in text]
    inputs = tokenizer(text, return_tensors='pt', padding=True, truncation=True, max_length=max_input_length)
    result = model.generate(
        **inputs.to(model.device),
        forced_bos_token_id=tokenizer.convert_tokens_to_ids(tgt_lang),
        # the output length limit grows linearly with the input length: a + b * n_input_tokens
        max_new_tokens=int(a + b * inputs.input_ids.shape[1]),
        num_beams=num_beams,
        **kwargs
    )
    return tokenizer.batch_decode(result, skip_special_tokens=True)


# Example of translating a couple of texts:
texts = translate(["To be, or not to be, that is the question.", "Hello, how are you?"], src_lang='eng_Latn', tgt_lang='azj_Latn')
print(texts)
# ['Olmaq və ya olmamaq sualdır.', 'Salam, necə var?']
```
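
The model translates in both directions, so the same `translate` function can be used for Azerbaijani-to-English translation by swapping the language codes:

```Python
# translating in the opposite direction: swap the source and target language codes
print(translate("Olmaq və ya olmamaq sualdır.", src_lang='azj_Latn', tgt_lang='eng_Latn'))
```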

If you want to translate a large number of sentences, you will need to process them in small batches (the batch size can be chosen as the largest one that still fits into your GPU memory). An efficient way is to group texts of similar length into the same batch, like below:

```Python
from tqdm.auto import trange

def batched_translate(texts, batch_size=16, **kwargs):
    """Translate texts in batches of similar length"""
    # sort the texts by length (longest first), remembering the original order
    idxs, texts2 = zip(*sorted(enumerate(texts), key=lambda p: len(p[1]), reverse=True))
    results = []
    for i in trange(0, len(texts2), batch_size):
        results.extend(translate(texts2[i: i+batch_size], **kwargs))
    # restore the original order of the translations
    return [p for i, p in sorted(zip(idxs, results))]
```
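
For example (with `sentences` standing in for your own list of texts):

```Python
sentences = ["Hello, how are you?", "To be, or not to be, that is the question."] * 50
translations = batched_translate(sentences, src_lang='eng_Latn', tgt_lang='azj_Latn', batch_size=32)
```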

Please be aware that to translate a longer text, you need to split it into individual sentences first and translate them separately, for example as in the sketch below.
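
A minimal sketch of one way to do this, using a naive regex-based sentence splitter (a dedicated splitter such as `nltk.tokenize.sent_tokenize` would handle abbreviations and other edge cases better); the helper `translate_long_text` below is illustrative, not part of this model:

```Python
import re

def translate_long_text(text, **kwargs):
    # naive split on sentence-final punctuation followed by whitespace
    sentences = [s for s in re.split(r'(?<=[.!?])\s+', text.strip()) if s]
    return ' '.join(batched_translate(sentences, **kwargs))

long_text = "To be, or not to be, that is the question. Hello, how are you?"
print(translate_long_text(long_text, src_lang='eng_Latn', tgt_lang='azj_Latn'))
```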