rhobit commited on
Commit
022cb46
·
1 Parent(s): 246b547

Optimize performance and tweak UI

Browse files
Files changed (1) hide show
  1. app.py +145 -121
app.py CHANGED
@@ -9,16 +9,15 @@ from sklearn.decomposition import PCA
9
  from sklearn.manifold import TSNE
10
  from sentence_transformers import SentenceTransformer
11
  from transformers import BertTokenizer,BertForMaskedLM
12
- import cv2
13
  import io
14
  import time
15
 
16
- @st.cache(show_spinner=False,allow_output_mutation=True)
17
  def load_sentence_model():
18
  sentence_model = SentenceTransformer('paraphrase-distilroberta-base-v1')
19
  return sentence_model
20
 
21
- @st.cache(show_spinner=False)
22
  def load_model(model_name):
23
  if model_name.startswith('bert'):
24
  tokenizer = BertTokenizer.from_pretrained(model_name)
@@ -30,7 +29,7 @@ def load_model(model_name):
30
  def load_data(sentence_num):
31
  df = pd.read_csv('tsne_out.csv')
32
  df = df.loc[lambda d: (d['sentence_num']==sentence_num)&(d['iter_num']<1000)]
33
- return df
34
 
35
  #@st.cache(show_spinner=False)
36
  def mask_prob(model,mask_id,sentences,position,temp=1):
@@ -67,7 +66,25 @@ def run_chains(tokenizer,model,mask_id,input_text,num_steps):
67
  sentence,_ = sample_words(probs,pos,sentence)
68
  return pd.DataFrame(data=data_list,columns=['step','sentence','next_sample_loc'])
69
 
70
- #@st.cache(suppress_st_warning=True,show_spinner=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  def run_tsne(chain):
72
  st.sidebar.write('Running t-SNE...')
73
  st.sidebar.write('This takes ~1 min for 1000 steps with ~10 token sentences')
@@ -81,20 +98,92 @@ def run_tsne(chain):
81
  tsne = pd.concat([chain, pd.DataFrame(tsne_vals, columns = ['x_tsne', 'y_tsne'],index=chain.index)], axis = 1)
82
  return tsne
83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  def clear_df():
85
  if 'df' in st.session_state:
86
  del st.session_state['df']
87
 
88
- def update_sent_id(increment_value=0):
89
- sent_id = st.session_state.sent_id
90
- sent_id += increment_value
91
- sent_id = min(len(st.session_state.df)-1,max(0,sent_id))
92
- st.session_state.sent_id = sent_id
93
-
94
- def initialize_sent_id():
95
- st.session_state.sent_id = st.session_state.sent_id_from_slider
96
 
97
  if __name__=='__main__':
 
98
  # Config
99
  max_width = 1500
100
  padding_top = 0
@@ -121,139 +210,74 @@ if __name__=='__main__':
121
  """
122
  st.markdown(define_margins, unsafe_allow_html=True)
123
  st.markdown(hide_table_row_index, unsafe_allow_html=True)
 
 
 
 
 
124
 
125
  # Title
126
  st.header("Demo: Probing BERT's priors with serial reproduction chains")
127
- # Descriptions
128
- with st.expander("Expand to read the descriptions"):
129
- st.text("Let's explore sentences in the serial reproduction chains generated by BERT!")
130
- st.text("First, please choose the samples from the two pre-generated chains,")
131
- st.text("or specify your own initial sentence, from which you can generate samples.")
132
- st.text("After selecting the chain, you can use the slider to choose the starting point")
133
- st.text("and then either click through steps or watch the autoplay.")
134
- st.text("Finally, you can check 'Show candidates', to see which words are proposed")
135
- st.text("when each word is masked out.")
136
  # Load BERT
137
  tokenizer,model = load_model('bert-base-uncased')
138
  mask_id = tokenizer.encode("[MASK]")[1:-1][0]
139
 
140
  # First step: load the dataframe containing sentences
141
- input_type = st.sidebar.radio(label='1. Choose the input type',on_change=clear_df,
142
- options=('Use one of the example sentences','Use your own initial sentence'))
143
  if input_type=='Use one of the example sentences':
144
  sentence = st.sidebar.selectbox("Select the inital sentence",
145
  ('--- Please select one from below ---',
146
- 'About 170 campers attend the camps each week.',
147
- 'She grew up with three brothers and ten sisters.'))
 
148
  if sentence!='--- Please select one from below ---':
149
  if sentence=='About 170 campers attend the camps each week.':
150
  sentence_num = 6
151
  elif sentence=='She grew up with three brothers and ten sisters.':
152
  sentence_num = 8
 
 
153
  st.session_state.df = load_data(sentence_num)
 
154
  else:
155
- sentence = st.sidebar.text_input('Type down your own sentence here.',on_change=clear_df)
156
- num_steps = st.sidebar.number_input(label='How many steps do you want to run?',value=1000)
157
  if st.sidebar.button('Run chains'):
158
- chain = run_chains(tokenizer,model,mask_id,sentence,num_steps=num_steps)
159
  st.session_state.df = run_tsne(chain)
160
  st.session_state.finished_sampling = True
161
- st.session_state.sent_id = 0
 
 
 
 
 
 
 
 
162
 
163
  if 'df' in st.session_state:
164
  df = st.session_state.df
165
- x_tsne, y_tsne = df.x_tsne, df.y_tsne
166
- xscale_unit = (max(x_tsne)-min(x_tsne))/10
167
- yscale_unit = (max(y_tsne)-min(y_tsne))/10
168
- xlims = [(min(x_tsne)//xscale_unit-1)*xscale_unit,(max(x_tsne)//xscale_unit+1)*xscale_unit]
169
- ylims = [(min(y_tsne)//yscale_unit-1)*yscale_unit,(max(y_tsne)//yscale_unit+1)*yscale_unit]
170
- color_list = sns.color_palette('flare',n_colors=int(len(df)*1.2))
171
 
172
- st.sidebar.slider(label='2. Select a position in the chain to start exploring',
173
- min_value=0,max_value=len(df)-1,value=0,key='sent_id_from_slider',on_change=initialize_sent_id)
174
- if 'sent_id' not in st.session_state:
175
- initialize_sent_id()
176
 
177
- explore_type = st.sidebar.radio('3. Choose the way to explore',options=['Click through steps','Autoplay'])
178
  if explore_type=='Autoplay':
179
- st.text('Please use the slider on the left to change the starting position.')
180
- cols = st.columns(8)
181
- with cols[0]:
182
- start_autoplay = st.button('Play',key='play')
183
- with cols[1]:
184
- stop_autoplay = st.button('Stop',key='stop')
185
- fig_place_holder = st.empty()
186
- if start_autoplay and not stop_autoplay:
187
- for sent_id in range(st.session_state.sent_id_from_slider,len(st.session_state.df),10):
188
- sentence = df.cleaned_sentence.to_list()[sent_id]
189
- fig = plt.figure(figsize=(5,5),dpi=200)
190
- ax = fig.add_subplot(1,1,1)
191
- ax.plot(x_tsne[:sent_id+1],y_tsne[:sent_id+1],linewidth=0.2,color='gray',zorder=1)
192
- ax.scatter(x_tsne[:sent_id+1],y_tsne[:sent_id+1],s=5,color=color_list[:sent_id+1],zorder=2)
193
- ax.scatter(x_tsne[sent_id:sent_id+1],y_tsne[sent_id:sent_id+1],s=50,marker='*',color='blue',zorder=3)
194
- ax.set_xlim(*xlims)
195
- ax.set_ylim(*ylims)
196
- ax.axis('off')
197
- plt.title(f'Step {sent_id}: {sentence}')
198
- cols = fig_place_holder.columns([1,2,1])
199
- with cols[1]:
200
- fig_place_holder.pyplot(fig)
201
- time.sleep(3)
202
- fig_place_holder.empty()
203
- else:
204
- if explore_type=='Click through steps':
205
- button_labels = ['+1','+10','+100','+500']
206
- cols = st.sidebar.columns([4,5,6,6])
207
- for col_id,col in enumerate(cols):
208
- with col:
209
- st.button(button_labels[col_id],key=button_labels[col_id],
210
- on_click=update_sent_id,kwargs=dict(increment_value=int(button_labels[col_id].replace('+',''))))
211
- button_labels = ['-1','-10','-100','-500']
212
- cols = st.sidebar.columns([4,5,6,6])
213
- for col_id,col in enumerate(cols):
214
- with col:
215
- st.button(button_labels[col_id],key=button_labels[col_id],
216
- on_click=update_sent_id,kwargs=dict(increment_value=int(button_labels[col_id].replace('+',''))))
217
-
218
- sent_id = st.session_state.sent_id
219
- sentence = df.cleaned_sentence.to_list()[sent_id]
220
- input_sent = tokenizer(sentence,return_tensors='pt')['input_ids']
221
- decoded_sent = [tokenizer.decode([token]) for token in input_sent[0]]
222
- char_nums = [len(word)+2 for word in decoded_sent]
223
- show_candidates = st.checkbox('Show candidates')
224
- disp_style = '"font-family:san serif; color:Black; font-size: 20px"'
225
- if explore_type=='Click through steps' and input_type=='Use your own initial sentence' and sent_id>0 and 'finished_sampling' in st.session_state:
226
- sampled_loc = df.next_sample_loc.to_list()[sent_id-1]
227
- disp_step = f'<p style={disp_style}>Step {st.session_state.sent_id}&colon;&nbsp;'
228
- disp_sent_before = f'{disp_step}<span style="font-weight:bold">'+' '.join(decoded_sent[1:sampled_loc])
229
- new_word = f'<span style="color:Red">{decoded_sent[sampled_loc]}</span>'
230
- disp_sent_after = ' '.join(decoded_sent[sampled_loc+1:-1])+'</span></p>'
231
- st.markdown(disp_sent_before+' '+new_word+' '+disp_sent_after,unsafe_allow_html=True)
232
- else:
233
- disp_step = f'<p style={disp_style}>Step {st.session_state.sent_id}&colon;&nbsp;'
234
- st.markdown(f'{disp_step}<span style="font-weight:bold">{sentence}</span></p>',unsafe_allow_html=True)
235
- if show_candidates:
236
- st.write('Click any word to see each candidate with its probability')
237
- cols = st.columns(char_nums)
238
- with cols[0]:
239
- st.write(decoded_sent[0])
240
- with cols[-1]:
241
- st.write(decoded_sent[-1])
242
- for word_id,(col,word) in enumerate(zip(cols[1:-1],decoded_sent[1:-1])):
243
- with col:
244
- if st.button(word,key=f'word_{word_id}'):
245
- probs = mask_prob(model,mask_id,input_sent,word_id+1)
246
- _,candidates_df = sample_words(probs, word_id+1, input_sent)
247
- st.table(candidates_df)
248
 
249
- fig = plt.figure(figsize=(5,5),dpi=200)
250
- ax = fig.add_subplot(1,1,1)
251
- ax.plot(x_tsne[:sent_id+1],y_tsne[:sent_id+1],linewidth=0.2,color='gray',zorder=1)
252
- ax.scatter(x_tsne[:sent_id+1],y_tsne[:sent_id+1],s=5,color=color_list[:sent_id+1],zorder=2)
253
- ax.scatter(x_tsne[sent_id:sent_id+1],y_tsne[sent_id:sent_id+1],s=50,marker='*',color='blue',zorder=3)
254
- ax.set_xlim(*xlims)
255
- ax.set_ylim(*ylims)
256
- ax.axis('off')
257
- cols = st.columns([1,2,1])
258
- with cols[1]:
259
- st.pyplot(fig)
 
9
  from sklearn.manifold import TSNE
10
  from sentence_transformers import SentenceTransformer
11
  from transformers import BertTokenizer,BertForMaskedLM
 
12
  import io
13
  import time
14
 
15
+ @st.cache(show_spinner=True,allow_output_mutation=True)
16
  def load_sentence_model():
17
  sentence_model = SentenceTransformer('paraphrase-distilroberta-base-v1')
18
  return sentence_model
19
 
20
+ @st.cache(show_spinner=True,allow_output_mutation=True)
21
  def load_model(model_name):
22
  if model_name.startswith('bert'):
23
  tokenizer = BertTokenizer.from_pretrained(model_name)
 
29
  def load_data(sentence_num):
30
  df = pd.read_csv('tsne_out.csv')
31
  df = df.loc[lambda d: (d['sentence_num']==sentence_num)&(d['iter_num']<1000)]
32
+ return df.reset_index()
33
 
34
  #@st.cache(show_spinner=False)
35
  def mask_prob(model,mask_id,sentences,position,temp=1):
 
66
  sentence,_ = sample_words(probs,pos,sentence)
67
  return pd.DataFrame(data=data_list,columns=['step','sentence','next_sample_loc'])
68
 
69
+ #@st.cache(show_spinner=True,allow_output_mutation=True)
70
+ def show_tsne_panel(df, step_id):
71
+ x_tsne, y_tsne = df.x_tsne, df.y_tsne
72
+ xscale_unit = (max(x_tsne)-min(x_tsne))/10
73
+ yscale_unit = (max(y_tsne)-min(y_tsne))/10
74
+ xlims = [(min(x_tsne)//xscale_unit-1)*xscale_unit,(max(x_tsne)//xscale_unit+1)*xscale_unit]
75
+ ylims = [(min(y_tsne)//yscale_unit-1)*yscale_unit,(max(y_tsne)//yscale_unit+1)*yscale_unit]
76
+ color_list = sns.color_palette('flare',n_colors=int(len(df)*1.2))
77
+
78
+ fig = plt.figure(figsize=(5,5),dpi=200)
79
+ ax = fig.add_subplot(1,1,1)
80
+ ax.plot(x_tsne[:step_id+1],y_tsne[:step_id+1],linewidth=0.2,color='gray',zorder=1)
81
+ ax.scatter(x_tsne[:step_id+1],y_tsne[:step_id+1],s=5,color=color_list[:step_id+1],zorder=2)
82
+ ax.scatter(x_tsne[step_id:step_id+1],y_tsne[step_id:step_id+1],s=50,marker='*',color='blue',zorder=3)
83
+ ax.set_xlim(*xlims)
84
+ ax.set_ylim(*ylims)
85
+ ax.axis('off')
86
+ return fig
87
+
88
  def run_tsne(chain):
89
  st.sidebar.write('Running t-SNE...')
90
  st.sidebar.write('This takes ~1 min for 1000 steps with ~10 token sentences')
 
98
  tsne = pd.concat([chain, pd.DataFrame(tsne_vals, columns = ['x_tsne', 'y_tsne'],index=chain.index)], axis = 1)
99
  return tsne
100
 
101
+ def autoplay() :
102
+ for step_id in range(st.session_state.step_id, len(st.session_state.df), 1):
103
+ x = st.empty()
104
+ with x.container():
105
+ st.markdown(show_changed_site(), unsafe_allow_html = True)
106
+ fig = show_tsne_panel(st.session_state.df, step_id)
107
+ st.session_state.prev_step_id = st.session_state.step_id
108
+ st.session_state.step_id = step_id
109
+ #plt.title(f'Step {step_id}')#: {show_changed_site()}')
110
+ cols = st.columns([1,2,1])
111
+ with cols[1]:
112
+ st.pyplot(fig)
113
+ time.sleep(.25)
114
+ x.empty()
115
+
116
+ def initialize_buttons() :
117
+ buttons = st.sidebar.empty()
118
+ button_ids = []
119
+ with buttons.container() :
120
+ row1_labels = ['+1','+10','+100','+500']
121
+ row1 = st.columns([4,5,6,6])
122
+ for col_id,col in enumerate(row1):
123
+ button_ids.append(col.button(row1_labels[col_id],key=row1_labels[col_id]))
124
+
125
+ row2_labels = ['-1','-10','-100','-500']
126
+ row2 = st.columns([4,5,6,6])
127
+ for col_id,col in enumerate(row2):
128
+ button_ids.append(col.button(row2_labels[col_id],key=row2_labels[col_id]))
129
+
130
+ show_candidates_checked = st.checkbox('Show candidates')
131
+
132
+ # Increment if any of them have been pressed
133
+ increments = np.array([1,10,100,500,-1,-10,-100,-500])
134
+ if any(button_ids) :
135
+ increment_value = increments[np.array(button_ids)][0]
136
+ st.session_state.prev_step_id = st.session_state.step_id
137
+ new_step_id = st.session_state.step_id + increment_value
138
+ st.session_state.step_id = min(len(st.session_state.df) - 1, max(0, new_step_id))
139
+ if show_candidates_checked:
140
+ st.write('Click any word to see each candidate with its probability')
141
+ show_candidates()
142
+
143
+ def show_candidates():
144
+ if 'curr_table' in st.session_state:
145
+ st.session_state.curr_table.empty()
146
+ step_id = st.session_state.step_id
147
+ sentence = df.cleaned_sentence.loc[step_id]
148
+ input_sent = tokenizer(sentence,return_tensors='pt')['input_ids']
149
+ decoded_sent = [tokenizer.decode([token]) for token in input_sent[0]]
150
+ char_nums = [len(word)+2 for word in decoded_sent]
151
+ cols = st.columns(char_nums)
152
+ with cols[0]:
153
+ st.write(decoded_sent[0])
154
+ with cols[-1]:
155
+ st.write(decoded_sent[-1])
156
+ for word_id,(col,word) in enumerate(zip(cols[1:-1],decoded_sent[1:-1])):
157
+ with col:
158
+ if st.button(word,key=f'word_{word_id}'):
159
+ probs = mask_prob(model,mask_id,input_sent,word_id+1)
160
+ _, candidates_df = sample_words(probs, word_id+1, input_sent)
161
+ st.session_state.curr_table = st.table(candidates_df)
162
+
163
+
164
+ def show_changed_site():
165
+ df = st.session_state.df
166
+ step_id = st.session_state.step_id
167
+ prev_step_id = st.session_state.prev_step_id
168
+ curr_sent = df.cleaned_sentence.loc[step_id].split(' ')
169
+ prev_sent = df.cleaned_sentence.loc[prev_step_id].split(' ')
170
+ locs = [df.next_sample_loc.to_list()[step_id-1]] if 'next_sample_loc' in df else (
171
+ [i for i in range(len(curr_sent)) if curr_sent[i] not in prev_sent]
172
+ )
173
+ disp_style = '"font-family:san serif; color:Black; font-size: 20px"'
174
+ prefix = f'<p style={disp_style}>Step {st.session_state.step_id}&colon;&nbsp; <span style="font-weight:bold">'
175
+ disp = ' '.join([f'<span style="color:Red">{word}</span>' if i in locs else f'{word}'
176
+ for (i, word) in enumerate(curr_sent)])
177
+ suffix = '</span></p>'
178
+ return prefix + disp + suffix
179
+
180
  def clear_df():
181
  if 'df' in st.session_state:
182
  del st.session_state['df']
183
 
 
 
 
 
 
 
 
 
184
 
185
  if __name__=='__main__':
186
+
187
  # Config
188
  max_width = 1500
189
  padding_top = 0
 
210
  """
211
  st.markdown(define_margins, unsafe_allow_html=True)
212
  st.markdown(hide_table_row_index, unsafe_allow_html=True)
213
+ input_type = st.sidebar.radio(
214
+ label='1. Choose the input type',
215
+ on_change=clear_df,
216
+ options=('Use one of the example sentences','Use your own initial sentence')
217
+ )
218
 
219
  # Title
220
  st.header("Demo: Probing BERT's priors with serial reproduction chains")
221
+
 
 
 
 
 
 
 
 
222
  # Load BERT
223
  tokenizer,model = load_model('bert-base-uncased')
224
  mask_id = tokenizer.encode("[MASK]")[1:-1][0]
225
 
226
  # First step: load the dataframe containing sentences
 
 
227
  if input_type=='Use one of the example sentences':
228
  sentence = st.sidebar.selectbox("Select the inital sentence",
229
  ('--- Please select one from below ---',
230
+ 'About 170 campers attend the camps each week.',
231
+ "Ali marpet's mother is joy rose.",
232
+ 'She grew up with three brothers and ten sisters.'))
233
  if sentence!='--- Please select one from below ---':
234
  if sentence=='About 170 campers attend the camps each week.':
235
  sentence_num = 6
236
  elif sentence=='She grew up with three brothers and ten sisters.':
237
  sentence_num = 8
238
+ elif sentence=="Ali marpet's mother is joy rose." :
239
+ sentence_num = 2
240
  st.session_state.df = load_data(sentence_num)
241
+ st.session_state.finished_sampling = True
242
  else:
243
+ sentence = st.sidebar.text_input('Type your own sentence here.',on_change=clear_df)
244
+ num_steps = st.sidebar.number_input(label='How many steps do you want to run?',value=500)
245
  if st.sidebar.button('Run chains'):
246
+ chain = run_chains(tokenizer, model, mask_id, sentence, num_steps=num_steps)
247
  st.session_state.df = run_tsne(chain)
248
  st.session_state.finished_sampling = True
249
+
250
+ st.empty().markdown("\
251
+ Let's explore sentences from BERT's prior! \
252
+ Use the menu to the left to select a pre-generated chain, \
253
+ or start a new chain using your own initial sentence.\
254
+ " if not 'df' in st.session_state else "\
255
+ Use the slider to select a step, or watch the autoplay.\
256
+ Click 'Show candidates' to see the top proposals when each word is masked out.\
257
+ ")
258
 
259
  if 'df' in st.session_state:
260
  df = st.session_state.df
261
+ if 'step_id' not in st.session_state:
262
+ st.session_state.prev_step_id = 0
263
+ st.session_state.step_id = 0
264
+
 
 
265
 
266
+ explore_type = st.sidebar.radio(
267
+ '2. Choose how to explore the chain',
268
+ options=['Click through steps','Autoplay']
269
+ )
270
 
 
271
  if explore_type=='Autoplay':
272
+ st.empty()
273
+ st.sidebar.empty()
274
+ autoplay()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
275
 
276
+ elif explore_type=='Click through steps':
277
+ initialize_buttons()
278
+ with st.container():
279
+ st.markdown(show_changed_site(), unsafe_allow_html = True)
280
+ fig = show_tsne_panel(df, st.session_state.step_id)
281
+ cols = st.columns([1,2,1])
282
+ with cols[1]:
283
+ st.pyplot(fig)