# Page header captured with this source (likely a Hugging Face Space): status "Runtime error".
import streamlit as st
from langdetect import detect
from langdetect.lang_detect_exception import LangDetectException
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
def load_data():
    """Return a mapping from short language codes to mBART-50 language codes.

    Each mBART-50 code has the form ``<iso>_<REGION>``. The key is the part
    before the first underscore; it is what the result of
    ``langdetect.detect`` gets looked up against.
    """
    mbart_codes = ('en_XX', 'et_EE')
    mapping = {}
    for code in mbart_codes:
        iso_code, _, _ = code.partition('_')
        mapping[iso_code] = code
    return mapping
def load_model():
    """Load the many-to-many mBART-50 translation model and its tokenizer.

    Returns:
        A ``(model, tokenizer)`` tuple, both created with ``from_pretrained``
        from the ``facebook/mbart-large-50-many-to-many-mmt`` checkpoint.
    """
    checkpoint = "facebook/mbart-large-50-many-to-many-mmt"
    loaded_model = MBartForConditionalGeneration.from_pretrained(checkpoint)
    loaded_tokenizer = MBart50TokenizerFast.from_pretrained(checkpoint)
    return loaded_model, loaded_tokenizer
data = load_data()


def translate_to_english(model, tokenizer, text):
    """Detect the language of *text* and translate it to English with mBART-50.

    Args:
        model: mBART-50 seq2seq model, as returned by ``load_model``.
        tokenizer: the matching mBART-50 tokenizer. Its ``src_lang`` is
            overwritten on every call.
        text: the text to translate.

    Returns:
        The list of strings produced by ``tokenizer.batch_decode``, or
        ``None`` when the language cannot be detected or is not a key of
        ``data``.
    """
    try:
        detected = detect(text)
    except LangDetectException as err:
        # langdetect raises on input with no usable features (e.g. empty,
        # whitespace-only or punctuation-only text); report it as "not found"
        # instead of crashing the app.
        print(f"Could not detect language: {err}")
        return None

    mbart_code = data.get(detected)
    if mbart_code is None:
        print(f"Language {detected} not found")
        return None

    # mBART-50 expects its own language codes (e.g. "et_EE"), not the bare
    # code returned by langdetect (e.g. "et").
    # NOTE(review): this mutates the tokenizer object; if one tokenizer is
    # shared by concurrent sessions, this setting could race — confirm.
    tokenizer.src_lang = mbart_code
    encoded_txt = tokenizer(text, return_tensors="pt")
    generated_tokens = model.generate(
        **encoded_txt,
        # Force the English language token as the first generated token to
        # select the target language. convert_tokens_to_ids does not rely on
        # the tokenizer exposing a lang_code_to_id attribute.
        forced_bos_token_id=tokenizer.convert_tokens_to_ids("en_XX"),
    )
    return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
st.title("Auto Translate (To English)")
text = st.text_input(f"Write in any (1 of {len(data)}) language")
st.text("What you wrote: ")
st.write(text)
st.text("English Translation: ")
if text:
    # Streamlit re-executes this script on every interaction. Caching the
    # (model, tokenizer) pair avoids reloading the large checkpoint on every
    # rerun.
    model, tokenizer = st.cache_resource(load_model)()
    translated_text = translate_to_english(model, tokenizer, text)
    st.write(translated_text[0] if translated_text else "No translation found")