Spaces:
Runtime error
Runtime error
Commit
•
a28a899
1
Parent(s):
ce82499
updated inference.py
Browse files- apps/inference.py +30 -22
apps/inference.py
CHANGED
@@ -1,50 +1,58 @@
|
|
1 |
from pandas.io.formats.format import return_docstring
|
2 |
import streamlit as st
|
3 |
import pandas as pd
|
4 |
-
from transformers import AutoTokenizer,AutoModelForMaskedLM
|
5 |
from transformers import pipeline
|
6 |
import os
|
7 |
import json
|
8 |
|
9 |
-
|
10 |
-
|
|
|
11 |
|
12 |
model = AutoModelForMaskedLM.from_pretrained(model_name, from_flax=True)
|
13 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
14 |
-
nlp = pipeline(
|
15 |
|
16 |
result_sentence = nlp(masked_text)
|
17 |
|
18 |
-
return result_sentence
|
|
|
19 |
|
20 |
def app():
|
21 |
-
st.markdown("<h1 style='text-align: center; color: green;'>RoBERTa Hindi</h1>", unsafe_allow_html=True)
|
22 |
st.markdown(
|
23 |
-
|
|
|
|
|
|
|
|
|
24 |
)
|
25 |
|
26 |
-
target_text_path =
|
27 |
target_text_df = pd.read_csv(target_text_path)
|
28 |
-
|
29 |
-
texts = target_text_df[
|
30 |
-
|
31 |
st.markdown("""## Select any of the following text : """)
|
32 |
-
masked_text = st.selectbox(
|
33 |
-
texts)
|
34 |
|
35 |
-
st.write(
|
36 |
|
37 |
models = st.multiselect(
|
38 |
"Choose models",
|
39 |
-
[
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
|
|
|
|
|
|
|
|
44 |
|
45 |
selected_model = models[0]
|
46 |
|
47 |
-
if st.button(
|
48 |
-
|
49 |
-
filled_sentence = load_model(masked_text,selected_model)
|
50 |
st.write(filled_sentence)
|
|
|
1 |
from pandas.io.formats.format import return_docstring
|
2 |
import streamlit as st
|
3 |
import pandas as pd
|
4 |
+
from transformers import AutoTokenizer, AutoModelForMaskedLM
|
5 |
from transformers import pipeline
|
6 |
import os
|
7 |
import json
|
8 |
|
9 |
+
|
10 |
+
@st.cache(allow_output_mutation=True, show_spinner=False)
def _load_fill_mask_pipeline(model_name):
    """Build the fill-mask pipeline for ``model_name`` once per model.

    Cached on the model name alone so that switching the input sentence
    does not reload the weights and tokenizer. ``allow_output_mutation``
    stops Streamlit from hashing the (large, unhashable) pipeline object
    on every cache hit.
    """
    try:
        model = AutoModelForMaskedLM.from_pretrained(model_name, from_flax=True)
    except OSError:
        # NOTE(review): not every checkpoint offered in the UI is guaranteed
        # to ship Flax weights; fall back to the repository's default
        # weights instead of crashing. Confirm against the model hub.
        model = AutoModelForMaskedLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return pipeline("fill-mask", model=model, tokenizer=tokenizer)


@st.cache(show_spinner=False, persist=True)
def load_model(masked_text, model_name):
    """Run fill-mask inference on ``masked_text`` with ``model_name``.

    Args:
        masked_text: Sentence containing the model's mask token.
        model_name: Hugging Face model identifier to load.

    Returns:
        Whatever the ``fill-mask`` pipeline returns for ``masked_text``.
        The result is cached per ``(masked_text, model_name)`` pair.
    """
    # Model loading is delegated to a helper cached per model, so a new
    # sentence only costs one forward pass rather than a full reload.
    nlp = _load_fill_mask_pipeline(model_name)
    return nlp(masked_text)
|
20 |
+
|
21 |
|
22 |
def app():
    """Render the Hindi masked-language-modelling demo page.

    Reads candidate sentences from the ``text`` column of
    ``./mlm_custom/mlm_targeted_text.csv``, lets the user pick a sentence
    and one or more models, and on button press writes the output of
    ``load_model`` for the first selected model.
    """
    st.markdown(
        "<h1 style='text-align: center; color: green;'>RoBERTa Hindi</h1>",
        unsafe_allow_html=True,
    )
    st.markdown(
        "This demo uses pretrained RoBERTa variants for Mask Language Modelling (MLM)"
    )

    target_text_path = "./mlm_custom/mlm_targeted_text.csv"
    target_text_df = pd.read_csv(target_text_path)

    texts = target_text_df["text"]

    st.markdown("""## Select any of the following text : """)
    masked_text = st.selectbox("", texts)

    st.write("You selected:", masked_text)

    models = st.multiselect(
        "Choose models",
        [
            "flax-community/roberta-hindi",
            "mrm8488/HindiBERTa",
            "ai4bharat/indic-bert",
            "neuralspace-reverie/indic-transformers-hi-bert",
            "surajp/RoBERTa-hindi-guj-san",
        ],
        ["flax-community/roberta-hindi"],
    )

    # The multiselect yields an empty list once the user removes every
    # entry; indexing it would raise IndexError and break the page.
    if not models:
        st.warning("Please select at least one model.")
        return

    # Only the first selected model is used for inference.
    selected_model = models[0]

    if st.button("Fill the Mask!"):
        with st.spinner("Filling the Mask..."):
            filled_sentence = load_model(masked_text, selected_model)
        st.write(filled_sentence)
|