# roberta-hindi / app.py
# Commit 666b7aa by hassiahk: "Model changes and code formatting"
import json
import random
import pandas as pd
import streamlit as st
from transformers import AutoModelForMaskedLM, AutoTokenizer, pipeline
# App configuration. cfg["models"] maps each model name shown in the UI to the
# model name/path that load_model passes to from_pretrained. The relative path
# is resolved against the current working directory.
with open("config.json", encoding="utf-8") as f:
    cfg = json.load(f)
@st.cache(allow_output_mutation=True, show_spinner=False)
def _load_pipeline(model_name):
    """Build (once per model name) a fill-mask pipeline for ``model_name``.

    Cached separately from ``load_model`` so that the weights are loaded
    only once per model instead of once per distinct input sentence.
    ``allow_output_mutation=True`` keeps Streamlit from hashing the returned
    pipeline object on every cache hit.
    """
    model = AutoModelForMaskedLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return pipeline("fill-mask", model=model, tokenizer=tokenizer)


@st.cache(show_spinner=False, persist=True)
def load_model(masked_text, model_name):
    """Fill the ``<mask>`` placeholder in ``masked_text`` using ``model_name``.

    Args:
        masked_text: Sentence containing the literal placeholder ``<mask>``.
        model_name: Model name/path accepted by ``from_pretrained``.

    Returns:
        A ``(sequence, token_str)`` tuple taken from the first entry of the
        fill-mask pipeline output (its top-ranked candidate).

    Results are cached (and persisted to disk) per (text, model) pair.
    """
    nlp = _load_pipeline(model_name)
    # The UI always uses "<mask>", but each tokenizer may define a different
    # mask token, so translate the placeholder to this tokenizer's own token.
    masked_text = masked_text.replace("<mask>", nlp.tokenizer.mask_token)
    # NOTE(review): assumes a single mask; with several masks newer pipeline
    # versions return a list per mask and the indexing below would fail.
    top_prediction = nlp(masked_text)[0]
    return top_prediction["sequence"], top_prediction["token_str"]
def main():
    """Render the RoBERTa Hindi fill-mask demo page.

    The user chooses one or more models from ``cfg["models"]`` and a masked
    sentence (typed, chosen from the sample CSV, or picked at random). On
    "Fill the Mask!" each selected model's prediction is shown in a table.
    """
    st.title("RoBERTa Hindi")
    st.markdown(
        "This demo uses the below pretrained BERT variants for Mask Language Modeling (MLM):\n"
        "- [RoBERTa Hindi](https://huggingface.co/flax-community/roberta-hindi)\n"
        "- [Indic Transformers Hindi](https://huggingface.co/neuralspace-reverie/indic-transformers-hi-bert)\n"
        "- [HindiBERTa](https://huggingface.co/mrm8488/HindiBERTa)\n"
        "- [RoBERTa Hindi Guj San](https://huggingface.co/surajp/RoBERTa-hindi-guj-san)"
    )

    models_list = list(cfg["models"].keys())
    models = st.multiselect(
        "Choose models",
        models_list,
        models_list[0],
    )

    target_text_path = "./mlm_custom/mlm_targeted_text.csv"
    target_text_df = pd.read_csv(target_text_path)
    texts = target_text_df["text"]

    st.sidebar.title("Hindi MLM")
    pick_random = st.sidebar.checkbox("Pick any random text")

    if pick_random:
        # NOTE(review): Streamlit re-runs this script on every interaction, so
        # a new random text is drawn on each rerun, including the click on
        # "Fill the Mask!". Consider keeping the pick in st.session_state.
        random_text = random.choice(texts.tolist())
        masked_text = st.text_area("Please type a masked sentence to fill", random_text)
    else:
        select_text = st.sidebar.selectbox("Select any of the following text", texts)
        masked_text = st.text_area("Please type a masked sentence to fill", select_text)

    if st.button("Fill the Mask!"):
        # Guard the two inputs that otherwise yield an empty table or a raw
        # pipeline traceback.
        if not models:
            st.warning("Please choose at least one model.")
            return
        if "<mask>" not in masked_text:
            st.warning("Please include a <mask> token in the sentence.")
            return

        rows = []
        with st.spinner("Filling the Mask..."):
            for selected_model in models:
                filled_sentence, filled_token = load_model(masked_text, cfg["models"][selected_model])
                rows.append((selected_model, filled_token, filled_sentence))

        results_df = pd.DataFrame(rows, columns=["Model Name", "Filled Token", "Filled Text"])
        st.table(results_df)
if __name__ == "__main__":
    # Entry point: build the page when this file is run as the main script
    # (e.g. via `streamlit run app.py`).
    main()