# levanti_he_ar / translate.py
# Author: Guy Mor-Lan
# Commit: "add models" (46f657a)
# File size: 3.65 kB
import torch
from transformers import MarianMTModel, AutoTokenizer
import ctranslate2
from colorize import align_words
import logging
# Attach a file handler to the root logger; this module logs every
# translated text through it (see ``translate``).
logger = logging.getLogger()
logger.setLevel(logging.INFO)  # INFO and above; lower to DEBUG to capture all levels

# 'a' appends to the file instead of truncating it on start-up.
# The logged messages contain Hebrew/Arabic text, so the encoding is pinned to
# UTF-8 rather than left to the platform default, which may be unable to
# encode those characters.
file_handler = logging.FileHandler('app.log', mode='a', encoding='utf-8')
file_handler.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)
# --- Hebrew -> Arabic -------------------------------------------------------
# Three artefacts per direction: the tokenizer, the Hugging Face Marian model
# (loaded with output_attentions=True; its output is later passed to
# ``align_words``) and the CTranslate2 translator that produces the actual
# translation.
tokenizer_to_ar = AutoTokenizer.from_pretrained("./he_ar/")
model_to_ar = MarianMTModel.from_pretrained("./he_ar/", output_attentions=True)
model_to_ar_ct2 = ctranslate2.Translator("./he_ar_ct2/")

# --- Arabic -> Hebrew -------------------------------------------------------
tokenizer_from_ar = AutoTokenizer.from_pretrained("./ar_he/")
model_from_ar = MarianMTModel.from_pretrained("./ar_he/", output_attentions=True)
model_from_ar_ct2 = ctranslate2.Translator("./ar_he_ct2/")

print("Done loading models")
# Maps a dialect name (English or Hebrew spelling) to the one-letter code that
# ``run_translate`` prepends to Hebrew input before translating to Arabic.
dialect_map = {
    # English names
    "Palestinian": "P",
    "Syrian": "S",
    "Lebanese": "L",
    "Egyptian": "E",
    # Hebrew names
    "פלסטיני": "P",
    "סורי": "S",
    "לבנוני": "L",
    "מצרי": "E",
}
def translate(text, ct_model, hf_model, tokenizer, to_arabic=True,
              threshold=None, layer=2, head=6):
    """Translate ``text`` and build an HTML view of source and target.

    The translation is produced by ``ct_model``. ``hf_model`` is then run once
    on the source ids and the translated ids, and its output is handed to
    ``align_words`` (from ``colorize``), which returns the two HTML fragments.

    Args:
        text: Source sentence to translate.
        ct_model: Object exposing ``translate_batch`` (this module passes a
            ``ctranslate2.Translator``).
        hf_model: Callable model accepting ``input_ids`` and
            ``decoder_input_ids`` (this module passes a ``MarianMTModel``).
        tokenizer: Tokenizer matching both models.
        to_arabic: Forwarded to ``align_words`` as ``skip_first_src``.
            Presumably skips the dialect prefix token added by
            ``run_translate`` -- TODO confirm against ``colorize``.
        threshold: Forwarded to ``align_words``. A falsy value (``None`` or
            ``0``) selects a default based on the number of source tokens.
        layer: Forwarded to ``align_words``.
        head: Forwarded to ``align_words``.

    Returns:
        A tuple ``(html, arabic)``: ``html`` wraps both fragments in a
        right-to-left ``<div>``; ``arabic`` is the translation when it is
        mostly Arabic script, otherwise the original ``text``.
    """
    logger.info("Translating: %s", text)

    # CTranslate2 works on token strings, not ids.
    inp_tokens = tokenizer.convert_ids_to_tokens(tokenizer.encode(text))
    out_tokens = ct_model.translate_batch([inp_tokens])[0].hypotheses[0]
    out_string = tokenizer.convert_tokens_to_string(out_tokens)

    # Rebuild batch-of-one id tensors for the Hugging Face model. The decoder
    # sequence is wrapped in "<pad>" ... "</s>".
    encoder_input_ids = torch.tensor(
        tokenizer.convert_tokens_to_ids(inp_tokens)
    ).unsqueeze(0)
    decoder_input_ids = torch.tensor(
        tokenizer.convert_tokens_to_ids(["<pad>"] + list(out_tokens) + ["</s>"])
    ).unsqueeze(0)

    # Inference only: no autograd graph is needed for this forward pass.
    with torch.no_grad():
        colorization_output = hf_model(input_ids=encoder_input_ids,
                                       decoder_input_ids=decoder_input_ids)

    if not threshold:
        # 0.10 for 10-19 source tokens, 0.05 otherwise.
        # NOTE(review): the >= 20 case reuses the short-sentence value 0.05;
        # confirm that this is intentional.
        threshold = 0.10 if 10 <= len(inp_tokens) < 20 else 0.05

    srchtml, tgthtml = align_words(colorization_output,
                                   tokenizer,
                                   encoder_input_ids,
                                   decoder_input_ids,
                                   threshold,
                                   skip_first_src=to_arabic,
                                   skip_second_src=False,
                                   layer=layer,
                                   head=head)
    html = f"<div style='direction: rtl'>{srchtml}<br><br>{tgthtml}</div>"

    # An empty (or space-only) translation cannot be classified by script;
    # fall back to the input instead of dividing by zero in is_arabic().
    has_output = bool(out_string.replace(" ", ""))
    arabic = out_string if has_output and is_arabic(out_string) else text
    return html, arabic
#%%
def is_arabic(text):
    """Return True when more than half of the non-space characters are Arabic.

    A character counts as Arabic when it lies in the Unicode range
    U+0600..U+06FF. Space characters (" ") are ignored; other whitespace is
    counted as a non-Arabic character.

    Args:
        text: String to classify.

    Returns:
        ``True`` if strictly more than 50% of the non-space characters are in
        the Arabic range, ``False`` otherwise. Empty or space-only input has
        no characters to classify and yields ``False``.
    """
    stripped = text.replace(" ", "")
    if not stripped:
        # Nothing to classify; also avoids a division by zero below.
        return False
    arabic_chars = sum("\u0600" <= c <= "\u06FF" for c in stripped)
    return arabic_chars / len(stripped) > 0.5
def run_translate(text, dialect=None):
    """Detect the input language and translate in the matching direction.

    Mostly-Arabic input (per ``is_arabic``) is translated with the
    Arabic -> Hebrew models. Any other input is translated with the
    Hebrew -> Arabic models, optionally prefixed with a dialect code.

    Args:
        text: Text to translate.
        dialect: Optional dialect. Names found in ``dialect_map`` are replaced
            by their one-letter code; any other non-empty value is prepended
            unchanged. Ignored for Arabic input.

    Returns:
        The ``(html, arabic)`` tuple produced by ``translate``, or the empty
        string ``""`` when ``text`` is empty or whitespace-only.
        NOTE(review): the empty-input return is a plain string, not a tuple;
        kept for backward compatibility -- confirm callers handle both shapes.
    """
    # Whitespace-only input would otherwise reach is_arabic() and divide by
    # zero, so it is treated like empty input.
    if not text or not text.strip():
        return ""

    if is_arabic(text):
        return translate(text, model_from_ar_ct2, model_from_ar, tokenizer_from_ar,
                         to_arabic=False, threshold=None, layer=2, head=1)

    dialect = dialect_map.get(dialect, dialect)
    if dialect:
        text = f"{dialect} {text}"
    return translate(text, model_to_ar_ct2, model_to_ar, tokenizer_to_ar,
                     to_arabic=True, threshold=None, layer=2, head=6)