prateekagrawal committed on
Commit
613d2ec
1 Parent(s): 0911d2b

Updated inference.py

Browse files
Files changed (1) hide show
  1. apps/inference.py +61 -7
apps/inference.py CHANGED
@@ -5,21 +5,25 @@ from transformers import AutoTokenizer, AutoModelForMaskedLM
5
  from transformers import pipeline
6
  import os
7
  import json
 
 
 
 
8
 
9
 
10
@st.cache(show_spinner=False, persist=True)
def load_model(masked_text, model_name):
    """Fill the mask token in *masked_text* using the given HF model.

    Loads the model/tokenizer for ``model_name`` (from Flax weights for the
    roberta-hindi checkpoint), swaps the generic ``<mask>`` placeholder for
    the tokenizer's actual mask token, and returns the raw fill-mask
    pipeline output (list of candidate dicts). Cached by Streamlit so
    repeated (text, model) pairs are not recomputed.
    """
    # Only this checkpoint ships Flax weights; everything else loads PyTorch.
    use_flax = model_name == "flax-community/roberta-hindi"
    mlm_model = AutoModelForMaskedLM.from_pretrained(model_name, from_flax=use_flax)
    hf_tokenizer = AutoTokenizer.from_pretrained(model_name)

    # The UI always uses "<mask>"; translate it to this tokenizer's own token.
    real_mask = hf_tokenizer.mask_token
    text_to_fill = masked_text.replace("<mask>", real_mask)

    fill_mask = pipeline("fill-mask", model=mlm_model, tokenizer=hf_tokenizer)
    return fill_mask(text_to_fill)
 
23
 
24
 
25
  def app():
@@ -28,7 +32,7 @@ def app():
28
  unsafe_allow_html=True,
29
  )
30
  st.markdown(
31
- "This demo uses pretrained RoBERTa variants for Mask Language Modelling (MLM)"
32
  )
33
 
34
  target_text_path = "./mlm_custom/mlm_targeted_text.csv"
@@ -51,6 +55,56 @@ def app():
51
  ],
52
  ["flax-community/roberta-hindi"],
53
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  if st.button("Fill the Mask!"):
55
  with st.spinner("Filling the Mask..."):
56
  models = []
5
from transformers import pipeline
import os
import json
import random

# Model registry for the demo: maps display names to HF model ids.
# NOTE(review): assumed schema {"models": {name: model_id, ...}} — the keys
# are shown in the multiselect and the values passed to from_pretrained.
with open("config.json", encoding="utf-8") as f:
    cfg = json.load(f)  # json.load is the idiomatic form of json.loads(f.read())
12
 
13
 
14
@st.cache(show_spinner=False, persist=True)
def load_model(masked_text, model_name):
    """Fill the mask in *masked_text* with *model_name* and return the top hit.

    Loads model and tokenizer for ``model_name``, builds a fill-mask
    pipeline, replaces the generic ``<mask>`` placeholder with the
    tokenizer's real mask token, and returns a tuple
    ``(filled_sentence, filled_token)`` for the highest-scoring prediction.
    Cached by Streamlit so repeated (text, model) pairs are not recomputed.
    """
    model = AutoModelForMaskedLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    nlp = pipeline("fill-mask", model=model, tokenizer=tokenizer)

    # The UI always uses "<mask>"; translate it to this tokenizer's own token.
    MASK_TOKEN = tokenizer.mask_token
    masked_text = masked_text.replace("<mask>", MASK_TOKEN)

    result_sentence = nlp(masked_text)

    # BUG FIX: with a single mask the pipeline returns [dict, ...], but with
    # multiple masks it returns [[dict, ...], ...]; indexing the nested list
    # with a string key used to crash. Take the first mask's top prediction.
    top = result_sentence[0]
    if isinstance(top, list):
        top = top[0]
    return top["sequence"], top["token_str"]
27
 
28
 
29
  def app():
32
  unsafe_allow_html=True,
33
  )
34
  st.markdown(
35
+ "This demo uses multiple hindi transformer models for Masked Language Modelling (MLM)."
36
  )
37
 
38
  target_text_path = "./mlm_custom/mlm_targeted_text.csv"
55
  ],
56
  ["flax-community/roberta-hindi"],
57
  )
58
+
59
+ models_list = list(cfg["models"].keys())
60
+
61
+ models = st.multiselect(
62
+ "Choose models",
63
+ models_list,
64
+ models_list[0],
65
+ )
66
+
67
+ target_text_path = "./mlm_custom/mlm_targeted_text.csv"
68
+ target_text_df = pd.read_csv(target_text_path)
69
+
70
+ texts = target_text_df["text"]
71
+
72
+ st.sidebar.title("Hindi MLM")
73
+
74
+ pick_random = st.sidebar.checkbox("Pick any random text")
75
+
76
+ results_df = pd.DataFrame(columns=["Model Name", "Filled Token", "Filled Text"])
77
+
78
+ model_names = []
79
+ filled_masked_texts = []
80
+ filled_tokens = []
81
+
82
+ if pick_random:
83
+ random_text = texts[random.randint(0, texts.shape[0] - 1)]
84
+ masked_text = st.text_area("Please type a masked sentence to fill", random_text)
85
+ else:
86
+ select_text = st.sidebar.selectbox("Select any of the following text", texts)
87
+ masked_text = st.text_area("Please type a masked sentence to fill", select_text)
88
+
89
+ # pd.set_option('max_colwidth',30)
90
+ if st.button("Fill the Mask!"):
91
+ with st.spinner("Filling the Mask..."):
92
+
93
+ for selected_model in models:
94
+
95
+ filled_sentence, filled_token = load_model(masked_text, cfg["models"][selected_model])
96
+ model_names.append(selected_model)
97
+ filled_tokens.append(filled_token)
98
+ filled_masked_texts.append(filled_sentence)
99
+
100
+ results_df["Model Name"] = model_names
101
+ results_df["Filled Token"] = filled_tokens
102
+ results_df["Filled Text"] = filled_masked_texts
103
+
104
+ st.table(results_df)
105
+
106
+
107
+
108
  if st.button("Fill the Mask!"):
109
  with st.spinner("Filling the Mask..."):
110
  models = []