# translationbot_g2 / translate.py
# Lekha Vulavala
# Added translation controls and code (#1)
# 1f05fa7 verified
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
class Translator:
    """Translate text between languages with an mBART-50 seq2seq model.

    Install requirements::

        pip install pickle5 transformers==4.12.2 sentencepiece

    MBart documentation:
        https://huggingface.co/transformers/model_doc/mbart.html
    Model card of the default checkpoint (lists its language codes):
        https://huggingface.co/facebook/mbart-large-50-many-to-many-mmt

    The model and tokenizer are loaded once in ``__init__``; ``translate``
    is the helper that converts text from a source language to a target
    language.

    Language codes accepted by ``translate`` (see ``supported_langs``):
    ``en_XX``, ``fr_XX``, ``de_DE``, ``ru_RU``, ``ja_XX``.
    """

    # Checkpoint used when the caller does not ask for a different one.
    DEFAULT_MODEL_NAME = 'facebook/mbart-large-50-many-to-many-mmt'

    def __init__(self, model_name=DEFAULT_MODEL_NAME):
        """Load the model and its matching tokenizer.

        Args:
            model_name: Hub identifier or local path passed to
                ``from_pretrained`` for both the model and the tokenizer.
                Defaults to the many-to-many mBART-50 checkpoint.
        """
        self.model = MBartForConditionalGeneration.from_pretrained(model_name)
        self.tokenizer = MBart50TokenizerFast.from_pretrained(model_name)
        # Codes this wrapper accepts as source or target language.
        # See also: https://dl-translate.readthedocs.io/en/latest/available_languages/
        self.supported_langs = ['en_XX', 'fr_XX', 'de_DE', 'ru_RU', 'ja_XX']

    def translate(self, input_text, src_lang, tgt_lang):
        """Translate ``input_text`` from ``src_lang`` to ``tgt_lang``.

        Args:
            input_text: Text to translate; handed to the tokenizer as-is.
            src_lang: Language code of the input; must be listed in
                ``self.supported_langs``.
            tgt_lang: Language code of the desired output; must be listed
                in ``self.supported_langs``.

        Returns:
            The first decoded sequence produced by the model.

        Raises:
            RuntimeError: If either language code is not supported, or if
                decoding produced no output at all.
        """
        # Validate both codes before touching the tokenizer or model.
        if src_lang not in self.supported_langs:
            raise RuntimeError('Unsupported source language.')
        if tgt_lang not in self.supported_langs:
            raise RuntimeError('Unsupported target language.')

        # The source language is configured on the shared tokenizer
        # instance, so this setting persists until the next call.
        self.tokenizer.src_lang = src_lang
        encoded_text = self.tokenizer(input_text, return_tensors='pt')

        # The target language is selected by passing its token id as the
        # forced beginning-of-sequence token for generation.
        generated_tokens = self.model.generate(
            **encoded_text,
            forced_bos_token_id=self.tokenizer.lang_code_to_id[tgt_lang],
        )
        output_text_arr = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)

        if not output_text_arr:
            raise RuntimeError('Failed to generate output. Output Text Array is empty.')
        return output_text_arr[0]