taka-yamakoshi commited on
Commit
65b8143
1 Parent(s): bdd1d60
Files changed (1) hide show
  1. app.py +17 -7
app.py CHANGED
@@ -116,11 +116,14 @@ def show_instruction(sent,fontsize=20):
116
  suffix = '</span></p>'
117
  return st.markdown(prefix + sent + suffix, unsafe_allow_html = True)
118
 
119
- def create_interventions(token_id,interv_types,num_heads):
120
  interventions = {}
121
  for rep in ['lay','qry','key','val']:
122
  if rep in interv_types:
123
- interventions[rep] = [(head_id,token_id,[head_id,head_id+num_heads]) for head_id in range(num_heads)]
 
 
 
124
  else:
125
  interventions[rep] = []
126
  return interventions
@@ -251,10 +254,17 @@ if __name__=='__main__':
251
  interventions = [{'lay':[],'qry':[],'key':[],'val':[]} for i in range(num_layers)]
252
  probs_original = run_intervention(interventions,1,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs)
253
  df = pd.DataFrame(data=[[probs_original[0,0][0],probs_original[1,0][0]],
254
- [probs_original[0,1][0],probs_original[1,1][0]]],columns=['Option 1','Option 2'],index=['Sentence 1','Sentence 2'])
 
 
255
  st.dataframe(df.style.highlight_max(axis=1))
256
 
257
- if st.session_state['page_status'] == 'finish_debug':
258
- for layer_id in range(num_layers):
259
- interventions = [create_interventions(16,['lay','qry','key','val'],num_heads) if i==layer_id else {'lay':[],'qry':[],'key':[],'val':[]} for i in range(num_layers)]
260
- probs = run_intervention(interventions,num_heads,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs)
 
 
 
 
 
 
116
  suffix = '</span></p>'
117
  return st.markdown(prefix + sent + suffix, unsafe_allow_html = True)
118
 
119
+ def create_interventions(token_id,interv_types,num_heads,multihead=False):
120
  interventions = {}
121
  for rep in ['lay','qry','key','val']:
122
  if rep in interv_types:
123
+ if multihead:
124
+ interventions[rep] = [(head_id,token_id,[0,1]) for head_id in range(num_heads)]
125
+ else:
126
+ interventions[rep] = [(head_id,token_id,[head_id,head_id+num_heads]) for head_id in range(num_heads)]
127
  else:
128
  interventions[rep] = []
129
  return interventions
 
254
  interventions = [{'lay':[],'qry':[],'key':[],'val':[]} for i in range(num_layers)]
255
  probs_original = run_intervention(interventions,1,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs)
256
  df = pd.DataFrame(data=[[probs_original[0,0][0],probs_original[1,0][0]],
257
+ [probs_original[0,1][0],probs_original[1,1][0]]],
258
+ columns=[tokenizer.decode(option_1_tokens),tokenizer.decode(option_2_tokens)],
259
+ index=['Sentence 1','Sentence 2'])
260
  st.dataframe(df.style.highlight_max(axis=1))
261
 
262
+ multihead = True
263
+ for layer_id in range(num_layers)[:1]:
264
+ interventions = [create_interventions(16,['lay','qry','key','val'],num_heads,multihead) if i==layer_id else {'lay':[],'qry':[],'key':[],'val':[]} for i in range(num_layers)]
265
+ if multihead:
266
+ probs = run_intervention(interventions,1,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs)
267
+ else
268
+ probs = run_intervention(interventions,num_heads,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs)
269
+
270
+ st.write(probs)