File size: 2,590 Bytes
6554f2c
 
 
a28a899
6554f2c
 
 
613d2ec
 
 
 
6554f2c
8754e0c
a28a899
 
6554f2c
613d2ec
6554f2c
613d2ec
 
c1b0837
613d2ec
04d97c8
6554f2c
613d2ec
 
a28a899
6554f2c
 
 
a28a899
 
 
 
613d2ec
6554f2c
 
aeb3532
 
c7070e4
aeb3532
a28a899
6554f2c
a28a899
 
 
613d2ec
 
 
 
 
 
 
c7070e4
613d2ec
 
 
 
 
 
 
 
 
 
 
 
 
10dfd97
 
 
613d2ec
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
from pandas.io.formats.format import return_docstring
import streamlit as st
import pandas as pd
from transformers import AutoTokenizer, AutoModelForMaskedLM
from transformers import pipeline
import os
import json
import random

# App configuration; ``cfg["models"]`` maps display names to model identifiers.
# JSON is specified as UTF-8, so decode explicitly rather than relying on the
# platform's locale encoding (which breaks on non-ASCII labels on e.g. Windows).
with open("config.json", encoding="utf-8") as config_file:
    cfg = json.load(config_file)


@st.cache(allow_output_mutation=True, show_spinner=False)
def _load_fill_mask_pipeline(model_name):
    """Build a ``fill-mask`` pipeline for ``model_name``, once per model name.

    ``allow_output_mutation=True`` stops Streamlit from hashing the returned
    pipeline object on every cache hit. The pipeline is kept in memory only,
    not persisted to disk.
    """
    model = AutoModelForMaskedLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return pipeline("fill-mask", model=model, tokenizer=tokenizer)


@st.cache(show_spinner=False, persist=True)
def load_model(masked_text, model_name):
    """Fill the single mask in ``masked_text`` using the model ``model_name``.

    The generic ``<mask>`` placeholder is rewritten to the model's own mask
    token, so the same prompt works across models with different mask tokens.

    Args:
        masked_text: Sentence containing exactly one mask. This is either the
            ``<mask>`` placeholder or the model's native mask token.
        model_name: Identifier passed to ``from_pretrained``.

    Returns:
        A ``(sequence, token_str)`` tuple for the top-ranked prediction. It holds
        the completed sentence and the token that was filled in.

    Raises:
        ValueError: If the prompt does not contain exactly one mask token.
    """
    # The heavy model/tokenizer load is cached per model, independently of the
    # text. Only the cheap inference step reruns for each new sentence.
    nlp = _load_fill_mask_pipeline(model_name)
    mask_token = nlp.tokenizer.mask_token

    text = masked_text.replace("<mask>", mask_token)
    mask_count = text.count(mask_token)
    if mask_count != 1:
        # Zero masks makes the pipeline fail opaquely. Several masks yield a
        # nested list that the indexing below cannot handle.
        raise ValueError(
            f"Expected exactly one mask in the input text, found {mask_count}."
        )

    # Predictions are returned best-first; keep only the top one.
    top_prediction = nlp(text)[0]
    return top_prediction["sequence"], top_prediction["token_str"]


def app():
    """Render the multi-model Hindi masked-language-modelling demo page.

    The user picks any subset of the models listed in ``cfg["models"]``. They
    then choose (or type) a sentence containing a mask. Pressing
    "Fill the Mask!" shows a table with each chosen model's top prediction.
    """
    # Page header and short description.
    st.markdown(
        "<h1 style='text-align: center; color: green;'>RoBERTa Hindi</h1>",
        unsafe_allow_html=True,
    )
    st.markdown(
        "This demo uses multiple hindi transformer models for Masked Language Modelling (MLM)."
    )

    # Offer every configured model, all pre-selected.
    available_models = list(cfg["models"].keys())
    chosen_models = st.multiselect(
        "Choose models", available_models, available_models
    )

    # Example sentences shipped with the app.
    sample_texts = pd.read_csv("./mlm_custom/mlm_targeted_text.csv")["text"]

    st.sidebar.title("Hindi MLM")

    prompt_label = "Please type a masked sentence to fill"
    if st.checkbox("Pick any random text"):
        # NOTE(review): Streamlit re-executes the page on each interaction, so
        # this presumably draws a new sentence on every rerun (including the
        # one triggered by the button below) — confirm this is intended.
        default_text = sample_texts[random.randint(0, sample_texts.shape[0] - 1)]
    else:
        default_text = st.sidebar.selectbox(
            "Select any of the following text", sample_texts
        )
    masked_text = st.text_area(prompt_label, default_text)

    # Nothing more to render until the user asks for predictions.
    if not st.button("Fill the Mask!"):
        return

    with st.spinner("Filling the Mask..."):
        columns = {"Model Name": [], "Filled Token": [], "Filled Text": []}
        for model_label in chosen_models:
            sentence, token = load_model(masked_text, cfg["models"][model_label])
            columns["Model Name"].append(model_label)
            columns["Filled Token"].append(token)
            columns["Filled Text"].append(sentence)

        # Start from an empty frame with the fixed column order, then fill
        # each column from the collected lists.
        results_df = pd.DataFrame(columns=list(columns))
        for column_name, values in columns.items():
            results_df[column_name] = values

        st.table(results_df)