File size: 2,336 Bytes
6554f2c
 
 
a28a899
6554f2c
 
 
 
8754e0c
7b0fb0e
 
 
 
a28a899
 
 
6554f2c
c320b57
 
 
 
 
6554f2c
a28a899
6554f2c
 
 
a28a899
 
6554f2c
 
 
a28a899
 
 
 
 
6554f2c
 
a28a899
6554f2c
a28a899
 
 
6554f2c
a28a899
6554f2c
a28a899
6554f2c
7b0fb0e
6554f2c
a28a899
 
 
 
 
 
 
 
 
 
 
8754e0c
 
 
c320b57
8754e0c
e668d73
 
8754e0c
 
 
 
 
 
7b0fb0e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
from pandas.io.formats.format import return_docstring
import streamlit as st
import pandas as pd
from transformers import AutoTokenizer, AutoModelForMaskedLM
from transformers import pipeline
import os
import json


# Module-level result accumulators (model names, top predicted tokens, and
# the corresponding filled sentences). Being module globals, they live as long
# as the module object does — presumably across Streamlit reruns if this page
# is imported by a launcher script; confirm before relying on them.
models, predicted_tokens, predicted_sentence = [], [], []


def _use_mask_token(text, mask_token):
    """Return *text* with any known mask spelling replaced by *mask_token*."""
    # The selectable checkpoints do not agree on the mask spelling: RoBERTa
    # tokenizers use "<mask>", BERT/ALBERT-style tokenizers use "[MASK]".
    for spelling in ("<mask>", "[MASK]"):
        if spelling != mask_token:
            text = text.replace(spelling, mask_token)
    return text


@st.cache(show_spinner=False, persist=True)
def load_model(masked_text, model_name):
    """Run fill-mask inference on ``masked_text`` with the checkpoint ``model_name``.

    Despite its name, this returns the pipeline's predictions rather than the
    model itself. Results are cached by ``st.cache`` on the
    ``(masked_text, model_name)`` pair and persisted to disk.

    Args:
        masked_text: Input sentence containing a mask placeholder, written as
            either ``<mask>`` or ``[MASK]``. It is rewritten to the loaded
            tokenizer's own mask token before inference.
        model_name: Hugging Face model identifier.

    Returns:
        Whatever the ``fill-mask`` pipeline returns for the text. The caller
        indexes it as ``result[0]["token_str"]`` / ``result[0]["sequence"]``.
    """
    # Only this checkpoint is loaded from Flax weights (presumably it does not
    # publish PyTorch weights — confirm on the model hub).
    from_flax = model_name == "flax-community/roberta-hindi"

    model = AutoModelForMaskedLM.from_pretrained(model_name, from_flax=from_flax)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    nlp = pipeline("fill-mask", model=model, tokenizer=tokenizer)

    # The same sample text is sent to every selected model, so normalise the
    # placeholder to this tokenizer's mask token. Otherwise the pipeline cannot
    # locate the mask for models whose spelling differs from the text's.
    mask_token = tokenizer.mask_token
    if mask_token:
        masked_text = _use_mask_token(masked_text, mask_token)

    return nlp(masked_text)


def app():
    """Render the RoBERTa Hindi fill-mask demo page.

    Loads candidate sentences from ``./mlm_custom/mlm_targeted_text.csv``
    (column ``text``). The user picks one sentence and one or more models. When
    "Fill the Mask!" is pressed, a table shows each selected model's first
    prediction (predicted word and filled sentence).
    """
    st.markdown(
        "<h1 style='text-align: center; color: green;'>RoBERTa Hindi</h1>",
        unsafe_allow_html=True,
    )
    st.markdown(
        "This demo uses pretrained RoBERTa variants for Mask Language Modelling (MLM)"
    )

    target_text_path = "./mlm_custom/mlm_targeted_text.csv"
    target_text_df = pd.read_csv(target_text_path)

    texts = target_text_df["text"]

    st.markdown("""## Select any of the following text : """)
    masked_text = st.selectbox("", texts)

    st.write("You selected:", masked_text)

    selected_models = st.multiselect(
        "Choose models",
        [
            "flax-community/roberta-hindi",
            "mrm8488/HindiBERTa",
            "ai4bharat/indic-bert",
            "neuralspace-reverie/indic-transformers-hi-bert",
            "surajp/RoBERTa-hindi-guj-san",
        ],
        ["flax-community/roberta-hindi"],
    )
    if st.button("Fill the Mask!"):
        with st.spinner("Filling the Mask..."):
            # Collect rows per click. Appending to module-level lists made
            # results pile up: the module object survives between calls to
            # app(), so every click re-displayed all earlier clicks' rows too.
            rows = []
            for model_name in selected_models:
                predictions = load_model(masked_text, model_name)
                # Take the first (highest-ranked) candidate. This indexing
                # assumes the text contains a single mask token; with several
                # masks the pipeline nests one list per mask.
                top = predictions[0]
                rows.append((model_name, top["token_str"], top["sequence"]))

            results_df = pd.DataFrame(
                rows, columns=["Model Name", "Predicted Word", "Sentence"]
            )
            st.dataframe(results_df)