bertin / app.py
mariagrandury's picture
Fix bug in app
fa5172e
raw
history blame
3.84 kB
import random
from mtranslate import translate
import streamlit as st
from transformers import AutoTokenizer, AutoModelForMaskedLM, pipeline
# Logo shown in the browser tab and in the sidebar.
LOGO = "https://raw.githubusercontent.com/nlp-en-es/assets/main/logo.png"

# Display name -> model metadata.  Only the Hugging Face Hub id ("url") is
# stored today; the dict-of-dicts shape leaves room for more fields.
MODELS = {
    display_name: {"url": hub_id}
    for display_name, hub_id in (
        ("RoBERTa Base", "bertin-project/bertin-roberta-base-spanish"),
        ("RoBERTa Base Gaussian", "bertin-project/bertin-base-gaussian"),
        ("RoBERTa Base Random", "bertin-project/bertin-base-random"),
        ("RoBERTa Base Stepwise", "bertin-project/bertin-base-stepwise"),
        (
            "RoBERTa Base Gaussian Experiment",
            "bertin-project/bertin-base-gaussian-exp-512seqlen",
        ),
        (
            "RoBERTa Base Random Experiment",
            "bertin-project/bertin-base-random-exp-512seqlen",
        ),
    )
}

# Example Spanish sentences, each containing a single <mask> placeholder,
# offered when the user picks the "Random" prompt option.
PROMPT_LIST = [
    "Fui a la librería a comprar un <mask>.",
    "¡Qué buen <mask> hace hoy!",
    "Hoy empiezan las vacaciones, vamos a la <mask>.",
    "Mi color favorito es el <mask>.",
    "Voy a <mask>, estoy muy cansada.",
    "Mañana vienen mis amigos de <mask>.",
    "¿Te apetece venir a <mask> conmigo?",
    "En verano hace mucho <mask>.",
    "En el bosque había <mask>.",
]
@st.cache(allow_output_mutation=True, show_spinner=False)
def _load_pipeline(model_url):
    """Build the fill-mask pipeline for ``model_url`` once and reuse it.

    Loading model weights and tokenizer is the expensive part of a request,
    so it is cached per model id rather than per (text, model) pair.
    ``allow_output_mutation=True`` stops ``st.cache`` from hashing the
    returned pipeline object on every access; the pipeline is deliberately
    not persisted to disk.
    """
    model = AutoModelForMaskedLM.from_pretrained(model_url)
    tokenizer = AutoTokenizer.from_pretrained(model_url)
    return pipeline("fill-mask", model=model, tokenizer=tokenizer)


@st.cache(show_spinner=False, persist=True)
def load_model(masked_text, model_url):
    """Run the fill-mask model ``model_url`` on ``masked_text``.

    Despite its name this returns the pipeline's predictions, not a model:
    the caller reads ``result[0]["sequence"]`` and ``result[0]["token_str"]``
    from it.  Predictions are memoised (and persisted) per
    ``(masked_text, model_url)``, while the underlying pipeline is shared
    across texts via ``_load_pipeline``.

    Args:
        masked_text: Input text containing the tokenizer's mask token.
        model_url: Hugging Face Hub id of the masked-LM checkpoint.

    Returns:
        Whatever the transformers fill-mask pipeline returns for the input.
    """
    nlp = _load_pipeline(model_url)
    return nlp(masked_text)
# Page: configure the tab before any element is rendered.
st.set_page_config(page_icon=LOGO, page_title="BERTIN Demo")
st.title("BERTIN")

# Sidebar: project logo only.
st.sidebar.image(LOGO)

# Body: short project description rendered as Markdown.
_INTRO_MARKDOWN = """
BERTIN is a series of BERT-based models for Spanish.
The models are trained with Flax and using TPUs sponsored by Google since this is part of the
[Flax/Jax Community Week](https://discuss.huggingface.co/t/open-to-the-community-community-week-using-jax-flax-for-nlp-cv/7104)
organised by HuggingFace.
"""
st.markdown(_INTRO_MARKDOWN)
# Model choice: the selectbox shows display names; look up the Hub id.
model_name = st.selectbox("Model", list(MODELS))
model_url = MODELS[model_name]["url"]
prompt = st.selectbox("Prompt", ["Random", "Custom"])


def _stable_random_prompt():
    """Return a random example prompt that stays the same across reruns.

    Streamlit re-executes this whole script on every interaction, including
    clicking "Fill the mask" or editing the text area.  Calling
    ``random.choice`` directly here drew a new prompt on each rerun, which
    could swap out the prompt the user was looking at (or had edited).
    The first draw is kept in ``st.session_state``; reloading the page starts
    a new session and therefore a new draw.  Older Streamlit versions without
    ``st.session_state`` fall back to a fresh draw, as before.
    """
    state = getattr(st, "session_state", None)
    if state is None:
        return random.choice(PROMPT_LIST)
    if "random_prompt" not in state:
        state["random_prompt"] = random.choice(PROMPT_LIST)
    return state["random_prompt"]


if prompt == "Custom":
    prompt_box = "Enter your masked text here..."
else:
    prompt_box = _stable_random_prompt()
text = st.text_area("Enter text", prompt_box)
if st.button("Fill the mask"):
    with st.spinner(text="Getting results..."):
        st.subheader("Result")
        # The prompts use RoBERTa's "<mask>" token.  With no mask the pipeline
        # raises; with several, depending on the transformers version, it
        # either raises or returns one candidate list per mask (which the
        # display code below does not expect), so require exactly one.
        if text.count("<mask>") != 1:
            st.error("Please include exactly one <mask> token in the text.")
        else:
            try:
                result = load_model(text, model_url)
            except Exception as err:  # UI boundary: report instead of a traceback
                st.error(f"Could not fill the mask: {err}")
            else:
                # The local pipeline returns a list of candidate dicts; the
                # first one is shown.
                result_sequence = result[0]["sequence"]
                st.write(result_sequence)
                st.text("English translation")
                # mtranslate argument order: text, target language, source language.
                st.write(translate(result_sequence, "en", "es"))
# Footer: team credits and a link to the main model card.
_FOOTER_MARKDOWN = """
### Team members
- Javier de la Rosa ([versae](https://huggingface.co/versae))
- Eduardo González ([edugp](https://huggingface.co/edugp))
- Paulo Villegas ([paulo](https://huggingface.co/paulo))
- Pablo González de Prado ([Pablogps](https://huggingface.co/Pablogps))
- Manu Romero ([mrm8488](https://huggingface.co/mrm8488))
- María Grandury ([mariagrandury](https://huggingface.co/mariagrandury))
### More information
You can find more information about these models
[here](https://huggingface.co/bertin-project/bertin-roberta-base-spanish).
"""
st.markdown(_FOOTER_MARKDOWN)