Spaces:
Sleeping
Sleeping
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast | |
class Translator: | |
''' | |
Install Requirements - | |
pip install pickle5 transformers==4.12.2 sentencepiece | |
MBart Documentation | |
https://huggingface.co/transformers/model_doc/mbart.html | |
Get the supported lang codes | |
https://huggingface.co/facebook/mbart-large-50-one-to-many-mmt | |
Class - Translator | |
Initializes MBart Seq2Seq Model and Tokenizer | |
Helper func to translate input language to desired target language | |
Supported Languages: English, Gujarati, Hindi, Bengali, Malayalam, Marathi, Tamil, Telugu | |
''' | |
def __init__(self): | |
self.model = MBartForConditionalGeneration.from_pretrained('facebook/mbart-large-50-many-to-many-mmt') | |
self.tokenizer = MBart50TokenizerFast.from_pretrained('facebook/mbart-large-50-many-to-many-mmt') | |
# self.model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50") | |
# self.tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50") | |
# , src_lang="en_XX", tgt_lang="ro_RO" | |
# https://dl-translate.readthedocs.io/en/latest/available_languages/ | |
self.supported_langs = ['en_XX', 'fr_XX', 'de_DE', 'ru_RU', 'ja_XX'] | |
def translate(self, input_text, src_lang, tgt_lang): | |
if src_lang not in self.supported_langs: | |
raise RuntimeError('Unsupported source language.') | |
if tgt_lang not in self.supported_langs: | |
raise RuntimeError('Unsupported target language.') | |
self.tokenizer.src_lang = src_lang | |
encoded_text = self.tokenizer(input_text, return_tensors='pt') | |
generated_tokens = self.model.generate(**encoded_text, forced_bos_token_id=self.tokenizer.lang_code_to_id[tgt_lang]) | |
output_text_arr = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True) | |
if len(output_text_arr) > 0: | |
return output_text_arr[0] | |
else: | |
raise RuntimeError('Failed to generate output. Output Text Array is empty.') |