File size: 4,342 Bytes
e5e9b34 944775a e5e9b34 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
from functools import lru_cache

import torch
from transformers import set_seed, pipeline
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
######### HELSINKI NLP ##################
@lru_cache(maxsize=8)
def _load_helsinki_pipeline(model_name: str):
    '''Build (once per model name) and return the HF translation pipeline for *model_name*.'''
    return pipeline("translation", model=model_name)


def translate_helsinki_nlp(s: str, src_iso: str, dest_iso: str) -> str:
    '''
    Translate the text using HelsinkiNLP's Opus models for Mossi language.

    The translation pipeline is built once per language pair and cached,
    instead of being re-created (and its weights reloaded) on every call.

    Parameters
    ----------
    s: str
        The text
    src_iso: str
        The source-language code exactly as it appears in the model name
        ``Helsinki-NLP/opus-mt-{src_iso}-{dest_iso}`` (e.g. "en", "fr", "mos")
    dest_iso: str
        The destination-language code, in the same format as ``src_iso``
    Returns
    ----------
    translation: str
        The translated text
    '''
    # Ensure replicability
    set_seed(555)
    # Inference; the pipeline object is cached per model name
    translator = _load_helsinki_pipeline(f"Helsinki-NLP/opus-mt-{src_iso}-{dest_iso}")
    return translator(s)[0]['translation_text']
######### MASAKHANE ##################
@lru_cache(maxsize=4)
def _load_masakhane(model_name: str):
    '''Load (once per model name) and return the ``(model, tokenizer)`` pair for *model_name*.'''
    model = M2M100ForConditionalGeneration.from_pretrained(model_name)
    tokenizer = M2M100Tokenizer.from_pretrained(model_name)
    return model, tokenizer


def translate_masakhane(s: str, src_iso: str, dest_iso: str) -> str:
    '''
    Translate the text using Masakhane's M2M models for Mossi language.

    The model and tokenizer are loaded once per language pair and cached,
    instead of being reloaded on every call.

    Parameters
    ----------
    s: str
        The text
    src_iso: str
        The source-language code exactly as it appears in the model name
        ``masakhane/m2m100_418m_{src_iso}_{dest_iso}_news`` (e.g. "fr", "mos")
    dest_iso: str
        The destination-language code, in the same format as ``src_iso``
    Returns
    ----------
    translation: str
        The translated text
    '''
    # Ensure replicability
    set_seed(555)
    # Load model (cached per model name)
    model, tokenizer = _load_masakhane(f"masakhane/m2m100_418m_{src_iso}_{dest_iso}_news")
    # Inference
    encoded = tokenizer(s, return_tensors="pt")
    generated_tokens = model.generate(**encoded)
    return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
######### META ##################
_NLLB_CHECKPOINT = "facebook/nllb-200-distilled-600M"


@lru_cache(maxsize=1)
def _load_nllb_model():
    '''Load (once) and return the NLLB seq2seq model.'''
    return AutoModelForSeq2SeqLM.from_pretrained(_NLLB_CHECKPOINT)


@lru_cache(maxsize=8)
def _load_nllb_tokenizer(src_lang: str):
    '''Load (once per source language) the NLLB tokenizer configured for *src_lang*.'''
    return AutoTokenizer.from_pretrained(_NLLB_CHECKPOINT, src_lang=src_lang)


def translate_facebook(s: str, src_iso: str, dest_iso: str, max_length: int = 200) -> str:
    '''
    Translate the text using Meta's NLLB model for Mossi language.

    Model and tokenizer are loaded once and cached instead of on every call.

    Parameters
    ----------
    s: str
        The text
    src_iso: str
        The ISO 639-3 code of the source language (e.g. "mos", "eng", "fra");
        the Latin-script NLLB tag ``{src_iso}_Latn`` is used
    dest_iso: str
        The ISO 639-3 code of the destination language, same format
    max_length: int
        Maximum length (in tokens) of the generated translation. The former
        hard-coded value of 30 silently truncated longer translations.
    Returns
    ----------
    translation: str
        The translated text
    Raises
    ----------
    KeyError
        If ``{dest_iso}_Latn`` is not a language tag known to the tokenizer.
    '''
    # Ensure replicability
    set_seed(555)
    # Load model (cached)
    tokenizer = _load_nllb_tokenizer(f"{src_iso}_Latn")
    model = _load_nllb_model()
    # `lang_code_to_id` is deprecated/removed in recent transformers releases;
    # language tags are ordinary special tokens, so look the id up directly.
    target_tag = f"{dest_iso}_Latn"
    target_id = tokenizer.convert_tokens_to_ids(target_tag)
    if target_id is None or target_id == tokenizer.unk_token_id:
        # convert_tokens_to_ids returns the unk id instead of failing; keep the
        # KeyError the old dict lookup raised for unknown languages.
        raise KeyError(f"Unsupported NLLB target language: {target_tag}")
    # Inference
    encoded = tokenizer(s, return_tensors="pt")
    translated_tokens = model.generate(**encoded, forced_bos_token_id=target_id, max_length=max_length)
    return tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
######### ALL OF THE ABOVE ##################
def translate(s: str, src_iso: str, dest_iso: str) -> str:
    '''
    Translate the text using all available models (Meta, Masakhane, and Helsinki NLP where applicable).

    Meta's NLLB is always used. Helsinki NLP's Opus models are added for the
    pairs mos-eng, eng-mos and fra-mos, and Masakhane's M2M models for
    mos-fra and fra-mos.

    Parameters
    ----------
    s: str
        The text
    src_iso: str
        The ISO 639-3 code of the source language (e.g. "mos", "eng", "fra");
        matched case-insensitively
    dest_iso: str
        The ISO 639-3 code of the destination language, same format
    Returns
    ----------
    translation: str
        The translated text, concatenated over different models
    '''
    # Opus/Masakhane model names use two-letter codes for English and French,
    # while NLLB uses ISO 639-3 codes; "mos" is the same in both.
    short_codes = {"eng": "en", "fra": "fr"}
    # Normalise once up front (previously only src was lower-cased, and only
    # after the case-sensitive pair check, so it never had any effect).
    src_iso = src_iso.lower()
    dest_iso = dest_iso.lower()
    iso_pair = f"{src_iso}-{dest_iso}"
    short_src = short_codes.get(src_iso, src_iso)
    short_dest = short_codes.get(dest_iso, dest_iso)

    # Translate with Meta NLLB
    sections = ["Meta's NLLB translation is:\n\n" + translate_facebook(s, src_iso, dest_iso)]
    # Add the other models when they support this ISO pair
    if iso_pair in {"mos-eng", "eng-mos", "fra-mos"}:
        sections.append(f"HelsinkiNLP's Opus translation is:\n\n {translate_helsinki_nlp(s, short_src, short_dest)}")
    if iso_pair in {"mos-fra", "fra-mos"}:
        sections.append("Masakhane's M2M translation is:\n\n" + translate_masakhane(s, short_src, short_dest))
    return "\n\n\n".join(sections)
|