taka-yamakoshi committed on
Commit
9874228
1 Parent(s): 3cc4ad8
Files changed (1) hide show
  1. app.py +19 -11
app.py CHANGED
@@ -1,8 +1,8 @@
1
  import numpy as np
2
  import pandas as pd
3
  import streamlit as st
4
- #import matplotlib.pyplot as plt
5
- #import seaborn as sns
6
 
7
  #import jax
8
  #import jax.numpy as jnp
@@ -169,7 +169,6 @@ if __name__=='__main__':
169
  load_css('style.css')
170
  tokenizer,model = load_model()
171
  num_layers, num_heads = model.config.num_hidden_layers, model.config.num_attention_heads
172
- st.write(num_layers,num_heads)
173
  mask_id = tokenizer('[MASK]').input_ids[1:-1][0]
174
 
175
  main_area = st.empty()
@@ -260,11 +259,20 @@ if __name__=='__main__':
260
  st.dataframe(df.style.highlight_max(axis=1))
261
 
262
  multihead = True
263
- for layer_id in range(num_layers):
264
- interventions = [create_interventions(15,['lay','qry','key','val'],num_heads,multihead) if i==layer_id else {'lay':[],'qry':[],'key':[],'val':[]} for i in range(num_layers)]
265
- if multihead:
266
- probs = run_intervention(interventions,1,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs)
267
- else:
268
- probs = run_intervention(interventions,num_heads,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs)
269
- effect = ((probs_original-probs)[0,0] + (probs_original-probs)[1,1] + (probs-probs_original)[0,1] + (probs-probs_original)[1,0])/4
270
- st.write(effect)
 
 
 
 
 
 
 
 
 
 
1
  import numpy as np
2
  import pandas as pd
3
  import streamlit as st
4
+ import matplotlib.pyplot as plt
5
+ import seaborn as sns
6
 
7
  #import jax
8
  #import jax.numpy as jnp
 
169
  load_css('style.css')
170
  tokenizer,model = load_model()
171
  num_layers, num_heads = model.config.num_hidden_layers, model.config.num_attention_heads
 
172
  mask_id = tokenizer('[MASK]').input_ids[1:-1][0]
173
 
174
  main_area = st.empty()
 
259
  st.dataframe(df.style.highlight_max(axis=1))
260
 
261
  multihead = True
262
+ effect_array = []
263
+ for token_id in range(1,len(masked_ids_option_1['sent_1'])-1):
264
+ effect_list = []
265
+ for layer_id in range(num_layers):
266
+ interventions = [create_interventions(token_id,['lay','qry','key','val'],num_heads,multihead) if i==layer_id else {'lay':[],'qry':[],'key':[],'val':[]} for i in range(num_layers)]
267
+ if multihead:
268
+ probs = run_intervention(interventions,1,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs)
269
+ else:
270
+ probs = run_intervention(interventions,num_heads,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs)
271
+ effect = ((probs_original-probs)[0,0] + (probs_original-probs)[1,1] + (probs-probs_original)[0,1] + (probs-probs_original)[1,0])/4
272
+ effect_list.append(effect)
273
+ effect_array.append(effect_list)
274
+ effects = np.array(effect_array)
275
+
276
+ fig,ax = plt.subplots(1,1,figsize=(8,6))
277
+ ax.imshow(effects.T)
278
+ st.pyplot(fig)