Spaces:
Running
Running
import torch | |
from transformers import set_seed, pipeline | |
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer | |
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer | |
import time | |
######### HELSINKI NLP ################## | |
def translate_helsinki_nlp(s:str, src_iso:str, dest_iso:str)-> str:
    """
    Translate a text with one of HelsinkiNLP's Opus-MT models.

    The two language codes are interpolated verbatim into the model id
    ``Helsinki-NLP/opus-mt-{src_iso}-{dest_iso}``, so they must be spelled
    the way the model is named on the Hub.

    Parameters
    ----------
    s: str
        The text to translate
    src_iso: str
        Code of the source language, as it appears in the model id
    dest_iso: str
        Code of the destination language, as it appears in the model id

    Returns
    -------
    str
        The translated text
    """
    # Fix the random seed so repeated calls give the same output
    set_seed(555)

    # Build the translation pipeline for this language pair
    model_id = f"Helsinki-NLP/opus-mt-{src_iso}-{dest_iso}"
    opus_pipeline = pipeline("translation", model=model_id)

    # The pipeline returns a list of dicts; only one input was given
    outputs = opus_pipeline(s)
    return outputs[0]['translation_text']
######### MASAKHANE ################## | |
def translate_masakhane(s:str, src_iso:str, dest_iso:str)-> str:
    """
    Translate a text with one of Masakhane's fine-tuned M2M100 models.

    Model and tokenizer are both loaded from
    ``masakhane/m2m100_418m_{src_iso}_{dest_iso}_news``; the codes are
    interpolated verbatim into that repository name.

    Parameters
    ----------
    s: str
        The text to translate
    src_iso: str
        Code of the source language, as it appears in the repository name
    dest_iso: str
        Code of the destination language, as it appears in the repository name

    Returns
    -------
    str
        The translated text
    """
    # Fix the random seed so repeated calls give the same output
    set_seed(555)

    # Model weights and tokenizer live in the same repository
    repo_id = f"masakhane/m2m100_418m_{src_iso}_{dest_iso}_news"
    m2m_model = M2M100ForConditionalGeneration.from_pretrained(repo_id)
    m2m_tokenizer = M2M100Tokenizer.from_pretrained(repo_id)

    # Encode, generate, and decode the single input sequence
    model_inputs = m2m_tokenizer(s, return_tensors="pt")
    output_ids = m2m_model.generate(**model_inputs)
    decoded = m2m_tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    return decoded[0]
######### META ################## | |
def translate_facebook(s:str, src_iso:str, dest_iso:str)-> str:
    """
    Translate a text with a fine-tuned version of Meta's NLLB model.

    The tokenizer comes from ``facebook/nllb-200-distilled-600M`` and the
    weights from ``anyantudre/NLLB-finetuned-fr-to-mos-V3``, whatever the
    requested language pair. Both codes are suffixed with ``_Latn`` to form
    the NLLB language tags.

    Parameters
    ----------
    s: str
        The text to translate
    src_iso: str
        The ISO-3 code of the source language
    dest_iso: str
        The ISO-3 code of the destination language

    Returns
    -------
    str
        The translated text
    """
    # Fix the random seed so repeated calls give the same output
    set_seed(555)

    # Base tokenizer configured for the source language, fine-tuned weights
    nllb_tokenizer = AutoTokenizer.from_pretrained(
        "facebook/nllb-200-distilled-600M",
        src_lang=f"{src_iso}_Latn",
    )
    nllb_model = AutoModelForSeq2SeqLM.from_pretrained("anyantudre/NLLB-finetuned-fr-to-mos-V3")

    # Encode the input, then force generation to start with the target tag
    model_inputs = nllb_tokenizer(s, return_tensors="pt")
    target_lang_id = nllb_tokenizer.convert_tokens_to_ids(f"{dest_iso}_Latn")
    output_ids = nllb_model.generate(
        **model_inputs,
        forced_bos_token_id=target_lang_id,
        max_length=120,
    )
    decoded = nllb_tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    return decoded[0]
######### ALL OF THE ABOVE ################## | |
def translate(s, src_iso, dest_iso):
    """
    Translate the text using all available models (Meta, Masakhane, and Helsinki NLP where applicable).

    Meta's NLLB is always run. HelsinkiNLP's Opus output is appended for the
    pairs mos-eng, eng-mos and fra-mos; Masakhane's M2M output is appended
    for mos-fra and fra-mos. The elapsed wall-clock time is printed.

    Parameters
    ----------
    s: str
        The text
    src_iso: str
        The ISO-3 code of the source language (matched case-insensitively)
    dest_iso: str
        The ISO-3 code of the destination language (matched case-insensitively)

    Returns
    -------
    translation: str
        The translated text, concatenated over different models, each
        section introduced by a heading naming the model
    """
    # Time the whole call, including model loading
    start_time = time.time()

    # Normalise case once, before the pair lookup, so that e.g. "FRA"/"MOS"
    # select the same models as "fra"/"mos"
    src_iso = src_iso.lower()
    dest_iso = dest_iso.lower()
    iso_pair = f"{src_iso}-{dest_iso}"

    # The Opus and M2M checkpoints are named with two-letter codes for
    # English and French; every other code is passed through unchanged
    short_codes = {"eng": "en", "fra": "fr"}
    short_src = short_codes.get(src_iso, src_iso)
    short_dest = short_codes.get(dest_iso, dest_iso)

    # Translate with Meta NLLB (uses the three-letter codes)
    sections = ["Meta's NLLB translation is:\n\n" + translate_facebook(s, src_iso, dest_iso)]

    # Add the other models' output when they support this pair
    if iso_pair in ("mos-eng", "eng-mos", "fra-mos"):
        sections.append(
            f"\n\n\nHelsinkiNLP's Opus translation is:\n\n {translate_helsinki_nlp(s, short_src, short_dest)}"
        )
    if iso_pair in ("mos-fra", "fra-mos"):
        sections.append(
            "\n\n\nMasakhane's M2M translation is:\n\n" + translate_masakhane(s, short_src, short_dest)
        )

    print("Time elapsed: ", int(time.time() - start_time), " seconds")
    return "".join(sections)