# shygtfdsaxz / app.py
# Last change: "Update app.py" by aditi2222 (commit a5cb888)
import streamlit as st
from langdetect import detect
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
@st.cache
def load_data():
    """Return the supported languages keyed by their short code.

    Each full mBART-50 code (``"<short>_<REGION>"``) is indexed by the part
    before the first underscore, e.g. ``{"en": "en_XX", "et": "et_EE"}``.
    """
    mbart_codes = ("en_XX", "et_EE")
    mapping = {}
    for code in mbart_codes:
        short, _, _ = code.partition("_")
        mapping[short] = code
    return mapping
@st.cache(allow_output_mutation=True, suppress_st_warning=True)
def load_model():
    """Load the mBART-50 many-to-many checkpoint and its fast tokenizer.

    Returns:
        tuple: ``(model, tokenizer)``.

    The result is cached by Streamlit with ``allow_output_mutation=True``
    because callers mutate the returned tokenizer (its ``src_lang`` is
    reassigned per request).
    """
    checkpoint = "facebook/mbart-large-50-many-to-many-mmt"
    mbart = MBartForConditionalGeneration.from_pretrained(checkpoint)
    mbart_tokenizer = MBart50TokenizerFast.from_pretrained(checkpoint)
    return mbart, mbart_tokenizer
# Short language code (the form compared against langdetect's result in
# translate_to_english) -> full mBART-50 code, e.g. "et" -> "et_EE".
data = load_data()
def translate_to_english(model, tokenizer, text):
    """Detect the language of *text* and translate it into English.

    Args:
        model: the mBART-50 generation model returned by ``load_model``.
        tokenizer: the matching ``MBart50TokenizerFast``. Its ``src_lang``
            is reassigned to the detected language before encoding.
        text: the user's input string.

    Returns:
        The list of decoded translations produced by
        ``tokenizer.batch_decode`` (one entry for a single input string),
        or ``None`` when the language cannot be detected or is not a key
        of ``data``.
    """
    from langdetect.lang_detect_exception import LangDetectException

    try:
        src_lang = detect(text)
    except LangDetectException:
        # langdetect raises when the input has no usable features
        # (e.g. only digits, punctuation or whitespace).
        print(f"Could not detect the language of {text!r}")
        return None

    if src_lang not in data:
        print(f"Language {src_lang} not found")
        return None

    # mBART-50 expects the full language code (e.g. "et_EE"), not the bare
    # code returned by langdetect ("et"); `data` maps one to the other.
    tokenizer.src_lang = data[src_lang]
    encoded_txt = tokenizer(text, return_tensors="pt")
    generated_tokens = model.generate(
        **encoded_txt,
        # Force the decoder to start with the English language token.
        forced_bos_token_id=tokenizer.lang_code_to_id["en_XX"],
    )
    return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
st.title("Auto Translate (To English)")

# The prompt advertises how many source languages are supported.
text = st.text_input(f"Write in any (1 of {len(data)}) language")

# Echo the raw input back to the user.
st.text("What you wrote: ")
st.write(text)

st.text("English Translation: ")
if text:
    # load_model is wrapped in st.cache, so reruns reuse the cached model
    # instead of rebuilding it.
    model, tokenizer = load_model()
    translated_text = translate_to_english(model, tokenizer, text)
    if translated_text:
        st.write(translated_text[0])
    else:
        st.write("No translation found")