ajitrajasekharan commited on
Commit
d1b63cc
1 Parent(s): 04d9f88

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -43
app.py CHANGED
@@ -3,10 +3,6 @@ import streamlit as st
3
  import torch
4
  import string
5
 
6
- bert_tokenizer = None
7
- bert_model = None
8
- global top_k
9
- model_name = "ajitrajaskharan/biomedical"
10
 
11
 
12
  from transformers import BertTokenizer, BertForMaskedLM
@@ -75,14 +71,12 @@ def get_bert_prediction(input_text,top_k,model_name):
75
 
76
  def run_test(sent,top_k,model_name):
77
  start = None
78
- global bert_tokenizer
79
- global bert_model
80
- if (bert_tokenizer is None):
81
- bert_tokenizer, bert_model = load_bert_model(model_name)
82
  with st.spinner("Computing"):
83
  start = time.time()
84
  try:
85
- res = get_bert_prediction(sent,top_k,model_name)
86
  st.caption("Results in JSON")
87
  st.json(res)
88
 
@@ -93,47 +87,48 @@ def run_test(sent,top_k,model_name):
93
  st.text(f"prediction took {time.time() - start:.2f}s")
94
 
95
  def on_text_change():
96
- global top_k,model_name
97
  text = st.session_state.my_text
98
- run_test(text,top_k,model_name)
99
 
100
  def on_option_change():
101
- global top_k,model_name
102
  text = st.session_state.my_choice
103
- run_test(text,top_k,model_name)
104
 
105
  def on_results_count_change():
106
- global top_k
107
- top_k = int(st.session_state.my_slider)
108
- st.info("Results count changed " + str(top_k))
109
 
110
  def on_model_change1():
111
- global model_name
112
- global bert_tokenizer
113
- global bert_model
114
- model_name = st.session_state.my_model1
115
- st.info("Pre-selected model chosen: " + model_name)
116
- bert_tokenizer, bert_model = load_bert_model(model_name)
117
 
118
  def on_model_change2():
119
- global model_name
120
- global bert_tokenizer
121
- global bert_model
122
- model_name = st.session_state.my_model2
123
- st.info("Custom model chosen: " + model_name)
124
- bert_tokenizer, bert_model = load_bert_model(model_name)
125
 
126
  def init_selectbox():
127
  st.selectbox(
128
  'Choose any of these sentences or type any text below',
129
  ('', "[MASK] who lives in New York and works for XCorp suffers from Parkinson's", "Lou Gehrig who lives in [MASK] and works for XCorp suffers from Parkinson's","Lou Gehrig who lives in New York and works for [MASK] suffers from Parkinson's","Lou Gehrig who lives in New York and works for XCorp suffers from [MASK]","[MASK] who lives in New York and works for XCorp suffers from Lou Gehrig's", "Parkinson who lives in [MASK] and works for XCorp suffers from Lou Gehrig's","Parkinson who lives in New York and works for [MASK] suffers from Lou Gehrig's","Parkinson who lives in New York and works for XCorp suffers from [MASK]","Lou Gehrig","Parkinson","Lou Gehrigh's is a [MASK]","Parkinson is a [MASK]","New York is a [MASK]","New York","XCorp","XCorp is a [MASK]","acute lymphoblastic leukemia","acute lymphoblastic leukemia is a [MASK]"),on_change=on_option_change,key='my_choice')
130
 
131
-
 
 
 
 
 
 
 
 
132
 
133
  def main():
134
- global top_k
135
- global bert_tokenizer
136
- global bert_model
137
 
138
 
139
  st.markdown("<h3 style='text-align: center;'>Qualitative evaluation of any pretrained BERT model</h3>", unsafe_allow_html=True)
@@ -144,8 +139,8 @@ def main():
144
  st.write("This app can be used to examine both model prediction for a masked position as well as the neighborhood of CLS vector")
145
  st.write(" - To examine model prediction for a position, enter the token [MASK] or <mask>")
146
  st.write(" - To examine just the [CLS] vector, enter a word/phrase or sentence. Example: eGFR or EGFR or non small cell lung cancer")
147
- top_k = st.sidebar.slider("Select how many predictions do you need", 1 , 50, 20,key='my_slider',on_change=on_results_count_change) #some times it is possible to have less words
148
- print(top_k)
149
 
150
 
151
 
@@ -153,13 +148,13 @@ def main():
153
 
154
  # with st.spinner("Computing"):
155
  try:
156
- model_name = st.sidebar.selectbox(label='Select Model to Apply', options=['ajitrajasekharan/biomedical', 'bert-base-cased','bert-large-cased','microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext','allenai/scibert_scivocab_cased','dmis-lab/biobert-v1.1'], index=0, key = "my_model1",on_change=on_model_change1)
157
  init_selectbox()
158
  st.text_input("Enter text below", "",on_change=on_text_change,key='my_text')
159
- custom_model_name = st.text_input("Model not listed on left? Type the model name (fill-mask BERT models only)", "",key="my_model2",on_change=on_model_change2)
160
- if (len(custom_model_name) > 0):
161
- model_name = custom_model_name
162
- st.info("Custom model selected: " + model_name)
163
  # bert_tokenizer, bert_model = load_bert_model(model_name)
164
  #if len(input_text) > 0:
165
  # run_test(input_text,top_k,model_name)
@@ -167,10 +162,10 @@ def main():
167
  # if len(option) > 0:
168
  # run_test(option,top_k,model_name)
169
 
170
- st.info("Top k = " + str(top_k))
171
- st.info("Model name = " + model_name)
172
- if (bert_tokenizer is None):
173
- bert_tokenizer, bert_model = load_bert_model(model_name)
174
 
175
 
176
 
 
3
  import torch
4
  import string
5
 
 
 
 
 
6
 
7
 
8
  from transformers import BertTokenizer, BertForMaskedLM
 
71
 
72
  def run_test(sent,top_k,model_name):
73
  start = None
74
+ if (st.session_state['bert_tokenizer'] is None):
75
+ st.session_state['bert_tokenizer'], st.session_state['bert_model'] = load_bert_model(st.session_state['model_name'])
 
 
76
  with st.spinner("Computing"):
77
  start = time.time()
78
  try:
79
+ res = get_bert_prediction(sent,st.session_state['top_k'],st.session_state['model_name'])
80
  st.caption("Results in JSON")
81
  st.json(res)
82
 
 
87
  st.text(f"prediction took {time.time() - start:.2f}s")
88
 
89
  def on_text_change():
90
+
91
  text = st.session_state.my_text
92
+ run_test(text,st.session_state['top_k']),st.session_state['model_name'])
93
 
94
  def on_option_change():
95
+
96
  text = st.session_state.my_choice
97
+ run_test(text,st.session_state['top_k']),st.session_state['model_name'])
98
 
99
  def on_results_count_change():
100
+
101
+ st.session_state['top_k'] = int(st.session_state.my_slider)
102
+ st.info("Results count changed " + str(st.session_state['top_k']))
103
 
104
  def on_model_change1():
105
+ st.session_state['model_name'] = st.session_state.my_model1
106
+ st.info("Pre-selected model chosen: " + st.session_state['model_name'])
107
+ st.session_state['bert_tokenizer'], st.session_state['bert_model'] = load_bert_model(st.session_state['model_name'])
 
 
 
108
 
109
  def on_model_change2():
110
+ st.session_state['model_name'] = st.session_state.my_model2
111
+ st.info("Custom model chosen: " + st.session_state['model_name'])
112
+ st.session_state['bert_tokenizer'], st.session_state['bert_model'] = load_bert_model(st.session_state['model_name'])
 
 
 
113
 
114
  def init_selectbox():
115
  st.selectbox(
116
  'Choose any of these sentences or type any text below',
117
  ('', "[MASK] who lives in New York and works for XCorp suffers from Parkinson's", "Lou Gehrig who lives in [MASK] and works for XCorp suffers from Parkinson's","Lou Gehrig who lives in New York and works for [MASK] suffers from Parkinson's","Lou Gehrig who lives in New York and works for XCorp suffers from [MASK]","[MASK] who lives in New York and works for XCorp suffers from Lou Gehrig's", "Parkinson who lives in [MASK] and works for XCorp suffers from Lou Gehrig's","Parkinson who lives in New York and works for [MASK] suffers from Lou Gehrig's","Parkinson who lives in New York and works for XCorp suffers from [MASK]","Lou Gehrig","Parkinson","Lou Gehrigh's is a [MASK]","Parkinson is a [MASK]","New York is a [MASK]","New York","XCorp","XCorp is a [MASK]","acute lymphoblastic leukemia","acute lymphoblastic leukemia is a [MASK]"),on_change=on_option_change,key='my_choice')
118
 
119
+ def init_session_states():
120
+ if 'top_k' not in st.session_state:
121
+ st.session_state['top_k'] = 20
122
+ if 'bert_tokenizer' not in st.session_state:
123
+ st.session_state['bert_tokenizer'] = None
124
+ if 'bert_model' not in st.session_state:
125
+ st.session_state['bert_model'] = None
126
+ if 'model_name' not in st.session_state:
127
+ st.session_state['model_name'] = "ajitrajasekharan/biomedical"
128
 
129
  def main():
130
+ init_session_states()
131
+
 
132
 
133
 
134
  st.markdown("<h3 style='text-align: center;'>Qualitative evaluation of any pretrained BERT model</h3>", unsafe_allow_html=True)
 
139
  st.write("This app can be used to examine both model prediction for a masked position as well as the neighborhood of CLS vector")
140
  st.write(" - To examine model prediction for a position, enter the token [MASK] or <mask>")
141
  st.write(" - To examine just the [CLS] vector, enter a word/phrase or sentence. Example: eGFR or EGFR or non small cell lung cancer")
142
+ st.sidebar.slider("Select how many predictions do you need", 1 , 50, 20,key='my_slider',on_change=on_results_count_change) #some times it is possible to have less words
143
+
144
 
145
 
146
 
 
148
 
149
  # with st.spinner("Computing"):
150
  try:
151
+ st.sidebar.selectbox(label='Select Model to Apply', options=['ajitrajasekharan/biomedical', 'bert-base-cased','bert-large-cased','microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext','allenai/scibert_scivocab_cased','dmis-lab/biobert-v1.1'], index=0, key = "my_model1",on_change=on_model_change1)
152
  init_selectbox()
153
  st.text_input("Enter text below", "",on_change=on_text_change,key='my_text')
154
+ st.text_input("Model not listed on left? Type the model name (fill-mask BERT models only)", "",key="my_model2",on_change=on_model_change2)
155
+ #if (len(custom_model_name) > 0):
156
+ # model_name = custom_model_name
157
+ # st.info("Custom model selected: " + model_name)
158
  # bert_tokenizer, bert_model = load_bert_model(model_name)
159
  #if len(input_text) > 0:
160
  # run_test(input_text,top_k,model_name)
 
162
  # if len(option) > 0:
163
  # run_test(option,top_k,model_name)
164
 
165
+ st.info("Top k = " + str(st.session_state['top_k']))
166
+ st.info("Model name = " + st.session_state['model_name'])
167
+ if (st.session_state['bert_tokenizer'] is None):
168
+ st.session_state['bert_tokenizer'], st.session_state['bert_model'] = load_bert_model(st.session_state['model_name'])
169
 
170
 
171