demo_language_moore / src /translation.py
khof312's picture
Remove use of authentication token.
944775a
raw
history blame
4.34 kB
from functools import lru_cache
from typing import Optional

import torch
from transformers import set_seed, pipeline
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
######### HELSINKI NLP ##################
@lru_cache(maxsize=4)
def _load_helsinki_pipeline(model_name: str):
    '''Build (once) and return the translation pipeline for `model_name`.'''
    return pipeline("translation", model=model_name)


def translate_helsinki_nlp(s: str, src_iso: str, dest_iso: str) -> str:
    '''
    Translate the text using HelsinkiNLP's Opus models for Mossi language.

    The pipeline for each model id is built on first use and cached, so
    repeated calls do not reload the model.

    Parameters
    ----------
    s : str
        The text
    src_iso : str
        Source-language code as it appears in the model id
        ``Helsinki-NLP/opus-mt-{src_iso}-{dest_iso}`` (e.g. "en", "fr", "mos").
    dest_iso : str
        Destination-language code, in the same form as `src_iso`.

    Returns
    -------
    translation : str
        The translated text
    '''
    # Ensure replicability
    set_seed(555)
    # Inference (the pipeline is cached per model id)
    translator = _load_helsinki_pipeline(f"Helsinki-NLP/opus-mt-{src_iso}-{dest_iso}")
    translation = translator(s)[0]['translation_text']
    return translation
######### MASAKHANE ##################
@lru_cache(maxsize=4)
def _load_masakhane(model_name: str):
    '''Load (once) and return the ``(model, tokenizer)`` pair for `model_name`.'''
    model = M2M100ForConditionalGeneration.from_pretrained(model_name)
    tokenizer = M2M100Tokenizer.from_pretrained(model_name)
    return model, tokenizer


def translate_masakhane(s: str, src_iso: str, dest_iso: str) -> str:
    '''
    Translate the text using Masakhane's M2M models for Mossi language.

    The model and tokenizer for each model id are loaded on first use and
    cached, so repeated calls do not reload them.

    Parameters
    ----------
    s : str
        The text
    src_iso : str
        Source-language code as it appears in the model id
        ``masakhane/m2m100_418m_{src_iso}_{dest_iso}_news`` (e.g. "fr", "mos").
    dest_iso : str
        Destination-language code, in the same form as `src_iso`.

    Returns
    -------
    translation : str
        The translated text
    '''
    # Ensure replicability
    set_seed(555)
    # Load model (cached per model id)
    model, tokenizer = _load_masakhane(f"masakhane/m2m100_418m_{src_iso}_{dest_iso}_news")
    # Inference
    encoded = tokenizer(s, return_tensors="pt")
    generated_tokens = model.generate(**encoded)
    translation = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
    return translation
######### META ##################
# Hugging Face id of the NLLB checkpoint used by translate_facebook.
_NLLB_MODEL_NAME = "facebook/nllb-200-distilled-600M"


@lru_cache(maxsize=1)
def _load_nllb_model():
    '''Load (once) and return the NLLB seq2seq model.'''
    return AutoModelForSeq2SeqLM.from_pretrained(_NLLB_MODEL_NAME)


@lru_cache(maxsize=8)
def _load_nllb_tokenizer(src_lang: str):
    '''Load (once per source language) an NLLB tokenizer configured for `src_lang`.'''
    return AutoTokenizer.from_pretrained(_NLLB_MODEL_NAME, src_lang=src_lang)


def translate_facebook(s: str, src_iso: str, dest_iso: str, *, max_length: Optional[int] = None) -> str:
    '''
    Translate the text using Meta's NLLB model for Mossi language.

    The model and tokenizer are loaded on first use and cached, so repeated
    calls do not reload them.

    Parameters
    ----------
    s : str
        The text
    src_iso : str
        The ISO-3 code of the source language (e.g. "eng", "fra", "mos").
        It is combined with "_Latn" to form the NLLB language code.
    dest_iso : str
        The ISO-3 code of the destination language, same form as `src_iso`.
    max_length : int or None, keyword-only
        Maximum generation length passed to ``model.generate``. When None
        (default), the model's generation configuration decides. A small
        fixed cap such as 30 truncates translations of longer sentences.

    Returns
    -------
    translation : str
        The translated text

    Raises
    ------
    KeyError
        If ``f"{dest_iso}_Latn"`` is not a language code known to the tokenizer.
    '''
    # Ensure replicability
    set_seed(555)
    # NLLB codes are "<iso3>_<script>"; this module always uses Latin script.
    src_lang = f"{src_iso}_Latn"
    dest_lang = f"{dest_iso}_Latn"
    # Load model (cached)
    tokenizer = _load_nllb_tokenizer(src_lang)
    model = _load_nllb_model()
    # Resolve the target-language token. convert_tokens_to_ids works across
    # transformers versions, whereas lang_code_to_id is deprecated/removed.
    dest_token_id = tokenizer.convert_tokens_to_ids(dest_lang)
    if dest_token_id == tokenizer.unk_token_id:
        # Unknown codes silently map to <unk>; fail loudly instead of
        # generating an untargeted translation.
        raise KeyError(dest_lang)
    # Inference: forcing the first decoder token selects the output language.
    generate_kwargs = {"forced_bos_token_id": dest_token_id}
    if max_length is not None:
        generate_kwargs["max_length"] = max_length
    encoded = tokenizer(s, return_tensors="pt")
    translated_tokens = model.generate(**encoded, **generate_kwargs)
    translation = tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
    return translation
######### ALL OF THE ABOVE ##################
# ISO 639-3 -> ISO 639-1 codes, as used in the Helsinki-NLP and Masakhane model ids.
# Codes not listed here (e.g. "mos") are passed through unchanged.
_ISO3_TO_ISO1 = {"eng": "en", "fra": "fr"}

# "src-dest" ISO 639-3 pairs for which an additional model is queried.
_HELSINKI_PAIRS = ("mos-eng", "eng-mos", "fra-mos")
_MASAKHANE_PAIRS = ("mos-fra", "fra-mos")


def translate(s: str, src_iso: str, dest_iso: str) -> str:
    '''
    Translate the text using all available models (Meta, Masakhane, and Helsinki NLP where applicable).

    Meta's NLLB is always used. Helsinki NLP's Opus is added for the pairs in
    ``_HELSINKI_PAIRS`` and Masakhane's M2M for the pairs in ``_MASAKHANE_PAIRS``.

    Parameters
    ----------
    s : str
        The text
    src_iso : str
        The ISO-3 code of the source language
    dest_iso : str
        The ISO-3 code of the destination language

    Returns
    -------
    translation : str
        The translated text, concatenated over different models
    '''
    # Translate with Meta NLLB
    translation = "Meta's NLLB translation is:\n\n" + translate_facebook(s, src_iso, dest_iso)

    # The other models are keyed by the ISO-3 pair but their model ids use
    # ISO-1 codes where one exists; convert once instead of mutating the args.
    iso_pair = f"{src_iso}-{dest_iso}"
    src_short = _ISO3_TO_ISO1.get(src_iso, src_iso)
    dest_short = _ISO3_TO_ISO1.get(dest_iso, dest_iso)

    if iso_pair in _HELSINKI_PAIRS:
        translation += f"\n\n\nHelsinkiNLP's Opus translation is:\n\n {translate_helsinki_nlp(s, src_short, dest_short)}"
    if iso_pair in _MASAKHANE_PAIRS:
        translation += "\n\n\nMasakhane's M2M translation is:\n\n" + translate_masakhane(s, src_short, dest_short)
    return translation