roberta-testing / apps /inference.py
prateekagrawal's picture
Updated flax error in inference
c320b57
raw
history blame
No virus
2.34 kB
from pandas.io.formats.format import return_docstring
import streamlit as st
import pandas as pd
from transformers import AutoTokenizer, AutoModelForMaskedLM
from transformers import pipeline
import os
import json
# Module-level result accumulators (one entry per model run). As plain module
# globals they live as long as the module stays imported, so anything appended
# to them carries over between calls to app().
models, predicted_tokens, predicted_sentence = [], [], []
@st.cache(allow_output_mutation=True, show_spinner=False)
def _load_pipeline(model_name):
    """Build a fill-mask pipeline for ``model_name``, cached per model name.

    Loading weights is the expensive step, so it is cached on the model name
    alone (not on the input text). ``allow_output_mutation=True`` tells
    Streamlit not to hash the returned pipeline object on every cache hit.
    """
    # The flax-community checkpoint is loaded from its Flax weights; every
    # other model uses the default (PyTorch) weights.
    from_flax = model_name == "flax-community/roberta-hindi"
    model = AutoModelForMaskedLM.from_pretrained(model_name, from_flax=from_flax)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return pipeline("fill-mask", model=model, tokenizer=tokenizer)


@st.cache(show_spinner=False, persist=True)
def load_model(masked_text, model_name):
    """Run the fill-mask model ``model_name`` on ``masked_text``.

    Despite its name, this returns predictions rather than a model. The
    predictions are cached (and persisted to disk) per (text, model) pair,
    while the underlying pipeline is cached per model by ``_load_pipeline``.

    Args:
        masked_text: Input sentence containing a mask placeholder.
        model_name: Hugging Face model id of a masked-LM checkpoint.

    Returns:
        The raw fill-mask pipeline output for ``masked_text``. For a single
        mask this is a list of candidate dicts; app() reads the first
        candidate's "token_str" and "sequence".
    """
    nlp = _load_pipeline(model_name)

    # RoBERTa-style tokenizers expect "<mask>" while BERT/ALBERT-style ones
    # expect "[MASK]". The pipeline errors out if its own mask token is
    # missing, so rewrite either spelling to this tokenizer's mask token.
    # NOTE(review): assumes the example texts use one of these two spellings.
    mask_token = nlp.tokenizer.mask_token
    if mask_token:
        for placeholder in ("<mask>", "[MASK]"):
            masked_text = masked_text.replace(placeholder, mask_token)

    return nlp(masked_text)
def app():
    """Render the RoBERTa Hindi masked-language-modelling demo page.

    The page reads candidate sentences from the "text" column of
    ./mlm_custom/mlm_targeted_text.csv and lets the user pick one sentence and
    any number of models. On "Fill the Mask!" it shows a table with each
    model's top predicted word and the completed sentence.
    """
    st.markdown(
        "<h1 style='text-align: center; color: green;'>RoBERTa Hindi</h1>",
        unsafe_allow_html=True,
    )
    st.markdown(
        "This demo uses pretrained RoBERTa variants for Mask Language Modelling (MLM)"
    )

    target_text_path = "./mlm_custom/mlm_targeted_text.csv"
    target_text_df = pd.read_csv(target_text_path)
    texts = target_text_df["text"]

    st.markdown("""## Select any of the following text : """)
    masked_text = st.selectbox("", texts)
    st.write("You selected:", masked_text)

    selected_models = st.multiselect(
        "Choose models",
        [
            "flax-community/roberta-hindi",
            "mrm8488/HindiBERTa",
            "ai4bharat/indic-bert",
            "neuralspace-reverie/indic-transformers-hi-bert",
            "surajp/RoBERTa-hindi-guj-san",
        ],
        ["flax-community/roberta-hindi"],
    )

    if st.button("Fill the Mask!"):
        # Fresh lists per click: module-level lists would persist while the
        # module stays imported, so rows from earlier clicks would pile up.
        models, predicted_tokens, predicted_sentence = [], [], []
        with st.spinner("Filling the Mask..."):
            for model_name in selected_models:
                filled_sentence = load_model(masked_text, model_name)
                # Take the first candidate returned by the fill-mask pipeline.
                top_prediction = filled_sentence[0]
                models.append(model_name)
                predicted_tokens.append(top_prediction["token_str"])
                predicted_sentence.append(top_prediction["sequence"])

        results_df = pd.DataFrame(
            {
                "Model Name": models,
                "Predicted Word": predicted_tokens,
                "Sentence": predicted_sentence,
            }
        )
        st.dataframe(results_df)