taka-yamakoshi committed on
Commit
ca1b654
1 Parent(s): 9839e32
Files changed (1) hide show
  1. app.py +54 -40
app.py CHANGED
@@ -130,14 +130,14 @@ def show_instruction(sent,fontsize=20):
130
  suffix = '</span></p>'
131
  return st.markdown(prefix + sent + suffix, unsafe_allow_html = True)
132
 
133
def create_interventions(token_id,interv_types,num_heads,multihead=False):
    """Build the per-representation intervention spec for one layer.

    Returns a dict with the keys 'lay', 'qry', 'key' and 'val' (in that
    order).  A representation listed in ``interv_types`` gets one
    ``(head_id, token_id, pair)`` tuple per head in ``range(num_heads)``;
    every other representation maps to an empty list.

    ``pair`` is ``[0, 1]`` when ``multihead`` is true, otherwise
    ``[head_id, head_id + num_heads]``.
    NOTE(review): the pair presumably indexes rows of the batch handed to
    the model -- confirm against the skeleton model implementation.
    """
    def swap_pair(head):
        # A fresh list per tuple, exactly like the original comprehension.
        return [0,1] if multihead else [head,head+num_heads]

    return {
        rep: ([(head,token_id,swap_pair(head)) for head in range(num_heads)]
              if rep in interv_types else [])
        for rep in ('lay','qry','key','val')
    }
@@ -176,6 +176,27 @@ def run_intervention(interventions,batch_size,skeleton_model,model,masked_ids_op
176
  assert probs.shape[0]==2 and probs.shape[1]==2 and probs.shape[2]==batch_size
177
  return probs
178
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
179
  if __name__=='__main__':
180
  wide_setup()
181
  load_css('style.css')
@@ -217,7 +238,7 @@ if __name__=='__main__':
217
  show_instruction('2. Select sites to mask out and click "Confirm"',fontsize=16)
218
  #show_instruction('------------------------------',fontsize=32)
219
  annotate_mask(1,sent_1)
220
- show_instruction('------------------------------',fontsize=32)
221
  annotate_mask(2,sent_2)
222
  if st.button('Confirm',key='confirm_mask'):
223
  st.session_state['page_status'] = 'annotate_options'
@@ -230,21 +251,34 @@ if __name__=='__main__':
230
  show_instruction('3. Select options and click "Confirm"',fontsize=16)
231
  #show_instruction('------------------------------',fontsize=32)
232
  annotate_options(1,sent_1)
233
- show_instruction('------------------------------',fontsize=32)
234
  annotate_options(2,sent_2)
235
  if st.button('Confirm',key='confirm_option'):
236
  st.session_state['page_status'] = 'analysis'
237
  st.experimental_rerun()
238
 
239
  if st.session_state['page_status']=='analysis':
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
240
  sent_1 = st.session_state['sent_1']
241
  sent_2 = st.session_state['sent_2']
242
- #show_annotated_sentence(st.session_state['decoded_sent_1'],
243
- # option_locs=st.session_state['option_locs_1'],
244
- # mask_locs=st.session_state['mask_locs_1'])
245
- #show_annotated_sentence(st.session_state['decoded_sent_2'],
246
- # option_locs=st.session_state['option_locs_2'],
247
- # mask_locs=st.session_state['mask_locs_2'])
248
 
249
  option_1_locs, option_2_locs = {}, {}
250
  pron_locs = {}
@@ -263,12 +297,6 @@ if __name__=='__main__':
263
  pron_locs[f'sent_{sent_id}'],
264
  option_2_locs[f'sent_{sent_id}'],mask_id)
265
 
266
- #st.write(option_1_locs)
267
- #st.write(option_2_locs)
268
- #st.write(pron_locs)
269
- #for token_ids in [masked_ids_option_1['sent_1'],masked_ids_option_1['sent_2'],masked_ids_option_2['sent_1'],masked_ids_option_2['sent_2']]:
270
- # st.write(' '.join([tokenizer.decode([token]) for token in token_ids]))
271
-
272
  option_1_tokens_1 = np.array(input_ids_dict['sent_1'])[np.array(option_1_locs['sent_1'])+1]
273
  option_1_tokens_2 = np.array(input_ids_dict['sent_2'])[np.array(option_1_locs['sent_2'])+1]
274
  option_2_tokens_1 = np.array(input_ids_dict['sent_1'])[np.array(option_2_locs['sent_1'])+1]
@@ -293,45 +321,31 @@ if __name__=='__main__':
293
  assert np.all(compare_1.astype(int)==compare_2.astype(int))
294
  context_locs = list(np.arange(len(masked_ids_option_1['sent_1']))[compare_1]-1) # match the indexing for annotation
295
 
296
- multihead = True
297
  assert np.all(np.array(pron_locs['sent_1'])==np.array(pron_locs['sent_2']))
298
  assert np.all(np.array(option_1_locs['sent_1'])==np.array(option_1_locs['sent_2']))
299
  assert np.all(np.array(option_2_locs['sent_1'])==np.array(option_2_locs['sent_2']))
300
  token_id_list = pron_locs['sent_1'] + option_1_locs['sent_1'] + option_2_locs['sent_1'] + context_locs
301
- #st.write(token_id_list)
302
 
303
  effect_array = []
304
  for token_id in token_id_list:
305
  token_id += 1
306
  effect_list = []
307
  for layer_id in range(num_layers):
308
- interventions = [create_interventions(token_id,['lay','qry','key','val'],num_heads,multihead) if i==layer_id else {'lay':[],'qry':[],'key':[],'val':[]} for i in range(num_layers)]
 
309
  if multihead:
310
  probs = run_intervention(interventions,1,skeleton_model,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs)
311
  else:
312
- probs = run_intervention(interventions,num_heads,skeleton_model,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs)
313
  effect = ((probs_original-probs)[0,0] + (probs_original-probs)[1,1] + (probs-probs_original)[0,1] + (probs-probs_original)[1,0])/4
314
  effect_list.append(effect)
315
  effect_array.append(effect_list)
316
  effect_array = np.transpose(np.array(effect_array),(1,0,2))
317
 
318
- cols = st.columns(len(masked_ids_option_1['sent_1'])-2)
319
- token_id = 0
320
- for col_id,col in enumerate(cols):
321
- with col:
322
- st.write(tokenizer.decode([masked_ids_option_1['sent_1'][col_id+1]]))
323
- if col_id in token_id_list:
324
- interv_id = token_id_list.index(col_id)
325
- fig,ax = plt.subplots()
326
- ax.set_box_aspect(num_layers)
327
- ax.imshow(effect_array[:,interv_id:interv_id+1,0],cmap=sns.color_palette("light:r", as_cmap=True),
328
- vmin=effect_array[:,:,0].min(),vmax=effect_array[:,:,0].max())
329
- ax.set_xticks([])
330
- ax.set_xticklabels([])
331
- ax.set_yticks([])
332
- ax.set_yticklabels([])
333
- ax.spines['top'].set_visible(False)
334
- ax.spines['bottom'].set_visible(False)
335
- ax.spines['right'].set_visible(False)
336
- ax.spines['left'].set_visible(False)
337
- st.pyplot(fig)
 
130
  suffix = '</span></p>'
131
  return st.markdown(prefix + sent + suffix, unsafe_allow_html = True)
132
 
133
def create_interventions(token_id,interv_types,num_heads,multihead=False,heads=None):
    """Build the per-representation intervention spec for one layer.

    Args:
        token_id: token position to intervene on; copied into every tuple.
        interv_types: collection of representation names ('lay', 'qry',
            'key', 'val') that should receive interventions.
        num_heads: number of attention heads; used only when ``multihead``
            is true.
        multihead: when true, every head in ``range(num_heads)`` is
            included with the pair ``[0, 1]``.  When false, only the heads
            in ``heads`` are included.
        heads: head indices to intervene on individually (used only when
            ``multihead`` is false).  ``None`` means no heads.

    Returns:
        A dict with the keys 'lay', 'qry', 'key', 'val'.  Selected
        representations map to a list of ``(head_id, token_id, pair)``
        tuples, the others to an empty list.  In the single-head case the
        i-th selected head gets ``pair == [i, i + len(heads)]``.
        NOTE(review): the pair presumably indexes rows of the batch handed
        to the model -- confirm against the skeleton model implementation.
    """
    # Avoid a shared mutable default; also accept any iterable of head ids.
    heads = [] if heads is None else list(heads)
    interventions = {}
    for rep in ['lay','qry','key','val']:
        if rep not in interv_types:
            interventions[rep] = []
        elif multihead:
            interventions[rep] = [(head_id,token_id,[0,1]) for head_id in range(num_heads)]
        else:
            interventions[rep] = [(head_id,token_id,[i,i+len(heads)]) for i,head_id in enumerate(heads)]
    return interventions
 
176
  assert probs.shape[0]==2 and probs.shape[1]==2 and probs.shape[2]==batch_size
177
  return probs
178
 
179
def show_results(effect_array,masked_sent,token_id_list,num_layers):
    """Draw one Streamlit column per token with its intervention effects.

    A column is created for every entry of ``masked_sent`` except the first
    and the last (presumably the special tokens -- confirm with the
    tokenizer in use).  Each column shows the decoded token; if the column
    index appears in ``token_id_list``, a one-column heat map of the
    matching slice ``effect_array[:, interv_id:interv_id+1]`` is drawn
    underneath it.

    Args:
        effect_array: 2-D array indexed as ``[row, intervention]``; the
            second axis follows the order of ``token_id_list``.
            NOTE(review): rows are presumably layers -- confirm at the caller.
        masked_sent: sequence of token ids for the displayed sentence.
        token_id_list: list of column indices that were intervened on.
        num_layers: passed to ``set_box_aspect`` as the height/width ratio
            of each heat map.

    Relies on the module-level ``tokenizer`` for decoding token ids.
    """
    # Loop invariants: one shared colormap and colour scale so that all
    # columns are comparable.
    cmap = sns.color_palette("light:r", as_cmap=True)
    vmin,vmax = effect_array.min(),effect_array.max()

    cols = st.columns(len(masked_sent)-2)
    for col_id,col in enumerate(cols):
        with col:
            st.write(tokenizer.decode([masked_sent[col_id+1]]))
            if col_id not in token_id_list:
                continue
            interv_id = token_id_list.index(col_id)
            fig,ax = plt.subplots()
            ax.set_box_aspect(num_layers)
            ax.imshow(effect_array[:,interv_id:interv_id+1],cmap=cmap,vmin=vmin,vmax=vmax)
            ax.set_xticks([])
            ax.set_xticklabels([])
            ax.set_yticks([])
            ax.set_yticklabels([])
            for side in ('top','bottom','right','left'):
                ax.spines[side].set_visible(False)
            st.pyplot(fig)
            # pyplot keeps every figure alive until it is closed; without
            # this, figures accumulate on every Streamlit rerun.
            plt.close(fig)
+
200
  if __name__=='__main__':
201
  wide_setup()
202
  load_css('style.css')
 
238
  show_instruction('2. Select sites to mask out and click "Confirm"',fontsize=16)
239
  #show_instruction('------------------------------',fontsize=32)
240
  annotate_mask(1,sent_1)
241
+ show_instruction('------------------------------',fontsize=24)
242
  annotate_mask(2,sent_2)
243
  if st.button('Confirm',key='confirm_mask'):
244
  st.session_state['page_status'] = 'annotate_options'
 
251
  show_instruction('3. Select options and click "Confirm"',fontsize=16)
252
  #show_instruction('------------------------------',fontsize=32)
253
  annotate_options(1,sent_1)
254
+ show_instruction('------------------------------',fontsize=24)
255
  annotate_options(2,sent_2)
256
  if st.button('Confirm',key='confirm_option'):
257
  st.session_state['page_status'] = 'analysis'
258
  st.experimental_rerun()
259
 
260
  if st.session_state['page_status']=='analysis':
261
+ interv_reps = st.multiselect('Select the types of representations to intervene.',['layer','query','key','value'])
262
+ rep_dict = {'layer':'lay','query':'qry','key':'key','value':'val'}
263
+ multihead = not st.checkbox('Perform individual head analysis (takes time)')
264
+ if not multihead:
265
+ heads = st.multiselect('Select heads to intervene.',list(np.arange(1,num_heads+1)))
266
+ else:
267
+ heads = []
268
+
269
+ if st.button('Run',key='run'):
270
+ st.session_state['reps'] = [rep_dict[rep] for rep in interv_reps]
271
+ st.session_state['multihead'] = multihead
272
+ st.session_state['heads'] = heads
273
+ st.session_state['page_status'] = 'results'
274
+ st.experimental_rerun()
275
+
276
+ if st.session_state['page_status']=='results':
277
  sent_1 = st.session_state['sent_1']
278
  sent_2 = st.session_state['sent_2']
279
+ multihead = st.session_state['multihead']
280
+ heads = st.session_state['heads']
281
+ reps = st.session_state['reps']
 
 
 
282
 
283
  option_1_locs, option_2_locs = {}, {}
284
  pron_locs = {}
 
297
  pron_locs[f'sent_{sent_id}'],
298
  option_2_locs[f'sent_{sent_id}'],mask_id)
299
 
 
 
 
 
 
 
300
  option_1_tokens_1 = np.array(input_ids_dict['sent_1'])[np.array(option_1_locs['sent_1'])+1]
301
  option_1_tokens_2 = np.array(input_ids_dict['sent_2'])[np.array(option_1_locs['sent_2'])+1]
302
  option_2_tokens_1 = np.array(input_ids_dict['sent_1'])[np.array(option_2_locs['sent_1'])+1]
 
321
  assert np.all(compare_1.astype(int)==compare_2.astype(int))
322
  context_locs = list(np.arange(len(masked_ids_option_1['sent_1']))[compare_1]-1) # match the indexing for annotation
323
 
 
324
  assert np.all(np.array(pron_locs['sent_1'])==np.array(pron_locs['sent_2']))
325
  assert np.all(np.array(option_1_locs['sent_1'])==np.array(option_1_locs['sent_2']))
326
  assert np.all(np.array(option_2_locs['sent_1'])==np.array(option_2_locs['sent_2']))
327
  token_id_list = pron_locs['sent_1'] + option_1_locs['sent_1'] + option_2_locs['sent_1'] + context_locs
 
328
 
329
  effect_array = []
330
  for token_id in token_id_list:
331
  token_id += 1
332
  effect_list = []
333
  for layer_id in range(num_layers):
334
+ interventions = [create_interventions(token_id,reps,num_heads,multihead,[head_id-1 for head_id in heads])
335
+ if i==layer_id else {'lay':[],'qry':[],'key':[],'val':[]} for i in range(num_layers)]
336
  if multihead:
337
  probs = run_intervention(interventions,1,skeleton_model,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs)
338
  else:
339
+ probs = run_intervention(interventions,len(heads),skeleton_model,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs)
340
  effect = ((probs_original-probs)[0,0] + (probs_original-probs)[1,1] + (probs-probs_original)[0,1] + (probs-probs_original)[1,0])/4
341
  effect_list.append(effect)
342
  effect_array.append(effect_list)
343
  effect_array = np.transpose(np.array(effect_array),(1,0,2))
344
 
345
# Render the results: a single panel for the all-heads intervention, or one
# tab per selected head for the individual-head analysis.
if multihead:
    show_results(effect_array[:,:,0],masked_ids_option_1['sent_1'],token_id_list,num_layers)
else:
    # st.tabs only accepts string labels, while `heads` holds integer ids
    # coming from the multiselect widget.
    tabs = st.tabs([str(head_id) for head_id in heads])
    for i,tab in enumerate(tabs):
        with tab:
            # The last axis of effect_array follows the order of `heads`.
            show_results(effect_array[:,:,i],masked_ids_option_1['sent_1'],token_id_list,num_layers)