Spaces:
Runtime error
Runtime error
File size: 1,438 Bytes
6f1d0d4 a5cb888 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 |
import streamlit as st
from langdetect import detect
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
@st.cache
def load_data():
    """Build the lookup from a bare language prefix to its full language code.

    Returns:
        dict: maps the part of each supported code before the first
        underscore (e.g. ``'en'``) to the full code (e.g. ``'en_XX'``).
    """
    codes = ['en_XX', 'et_EE']
    lookup = {}
    for code in codes:
        # Everything before the first underscore is the lookup key.
        prefix, _, _ = code.partition('_')
        lookup[prefix] = code
    return lookup
@st.cache(allow_output_mutation=True, suppress_st_warning=True)
def load_model():
    """Load the pretrained translation model and its matching tokenizer.

    Both objects are created with ``from_pretrained`` from the same
    checkpoint name.

    Returns:
        tuple: ``(model, tokenizer)``.
    """
    checkpoint = "facebook/mbart-large-50-many-to-many-mmt"
    # Model is loaded first, then the tokenizer, as a single returned pair.
    return (
        MBartForConditionalGeneration.from_pretrained(checkpoint),
        MBart50TokenizerFast.from_pretrained(checkpoint),
    )
# Lookup from language prefix to full language code (e.g. 'en' -> 'en_XX'),
# as returned by load_data(). Read by translate_to_english() and by the
# input prompt further down.
data = load_data()
def translate_to_english(model, tokenizer, text):
    """Translate *text* into English, auto-detecting its source language.

    The source language is detected with ``langdetect`` and mapped through
    the module-level ``data`` lookup to the full language code the
    tokenizer is configured with.

    Args:
        model: generation model; called as ``model.generate(...)``.
        tokenizer: matching tokenizer; must expose ``src_lang``,
            ``lang_code_to_id`` and ``batch_decode``.
        text: the string to translate.

    Returns:
        The list produced by ``tokenizer.batch_decode`` (one decoded string
        per generated sequence), or ``None`` when the language cannot be
        detected or is not one of the supported languages in ``data``.
    """
    # Imported locally so this function does not rely on a new file-level
    # import; langdetect itself is already a dependency of this file.
    from langdetect.lang_detect_exception import LangDetectException

    try:
        detected = detect(text)
    except LangDetectException:
        # Raised for input with no usable features (e.g. digits or
        # punctuation only); treat it like an unsupported language.
        print("Could not detect the language of the input")
        return None

    lang_code = data.get(detected)
    if lang_code is None:
        print(f"Language {detected} not found")
        return None

    # The tokenizer needs the full code (e.g. 'et_EE'), not the bare
    # prefix returned by langdetect (e.g. 'et').
    tokenizer.src_lang = lang_code
    encoded_txt = tokenizer(text, return_tensors="pt")
    generated_tokens = model.generate(
        **encoded_txt,
        # Force the decoder to start with the English language token.
        forced_bos_token_id=tokenizer.lang_code_to_id["en_XX"],
    )
    return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
# Page layout: title, input box, echo of the input, then the translation.
st.title("Auto Translate (To English)")
text = st.text_input(f"Write in any (1 of {len(data)}) language")
st.text("What you wrote: ")
st.write(text)
st.text("English Translation: ")
# Only load the model and translate once the user has typed something.
if text:
    model, tokenizer = load_model()
    translated_text = translate_to_english(model, tokenizer, text)
    if translated_text:
        # Single input string, so the first decoded entry is shown.
        st.write(translated_text[0])
    else:
        st.write("No translation found")