taka-yamakoshi commited on
Commit
2ce61c2
·
1 Parent(s): 0b33931

reduce num runs

Browse files
Files changed (1) hide show
  1. app.py +21 -12
app.py CHANGED
@@ -260,7 +260,13 @@ if __name__=='__main__':
260
 
261
  multihead = True
262
  effect_array = []
263
- for token_id in range(1,len(masked_ids_option_1['sent_1'])-1):
 
 
 
 
 
 
264
  effect_list = []
265
  for layer_id in range(num_layers):
266
  interventions = [create_interventions(token_id,['lay','qry','key','val'],num_heads,multihead) if i==layer_id else {'lay':[],'qry':[],'key':[],'val':[]} for i in range(num_layers)]
@@ -274,15 +280,18 @@ if __name__=='__main__':
274
  effects = np.transpose(np.array(effect_array),(1,0,2))
275
 
276
  cols = st.columns(len(masked_ids_option_1['sent_1'])-2)
 
277
  for col_id,col in enumerate(cols):
278
- with col:
279
- fig,ax = plt.subplots()
280
- ax.set_box_aspect(effects.shape[0])
281
- ax.imshow(effects[:,col_id:col_id+1,0],cmap=sns.color_palette("light:r", as_cmap=True),
282
- vmin=effects[:,:,0].min(),vmax=effects[:,:,0].max())
283
- ax.set_xticks([])
284
- ax.set_xticklabels([])
285
- ax.set_yticks([])
286
- ax.set_yticklabels([])
287
- ax.set_title(tokenizer.decode([masked_ids_option_1['sent_1'][col_id+1]]))
288
- st.pyplot(fig)
 
 
 
260
 
261
  multihead = True
262
  effect_array = []
263
+ assert np.all(np.array(pron_locs['sent_1'])==np.array(pron_locs['sent_2']))
264
+ assert np.all(np.array(option_1_locs['sent_1'])==np.array(option_1_locs['sent_2']))
265
+ assert np.all(np.array(option_2_locs['sent_1'])==np.array(option_2_locs['sent_2']))
266
+ token_id_list = pron_locs['sent_1'] + option_1_locs['sent_1'] + option_2_locs['sent_1']
267
+ st.write(token_id_list)
268
+ for token_id in token_id_list:
269
+ token_id += 1
270
  effect_list = []
271
  for layer_id in range(num_layers):
272
  interventions = [create_interventions(token_id,['lay','qry','key','val'],num_heads,multihead) if i==layer_id else {'lay':[],'qry':[],'key':[],'val':[]} for i in range(num_layers)]
 
280
  effects = np.transpose(np.array(effect_array),(1,0,2))
281
 
282
  cols = st.columns(len(masked_ids_option_1['sent_1'])-2)
283
+ token_id = 0
284
  for col_id,col in enumerate(cols):
285
+ if col_id in token_id_list:
286
+ with col:
287
+ fig,ax = plt.subplots()
288
+ ax.set_box_aspect(effects.shape[0])
289
+ ax.imshow(effects[:,token_id:token_id+1,0],cmap=sns.color_palette("light:r", as_cmap=True),
290
+ vmin=effects[:,:,0].min(),vmax=effects[:,:,0].max())
291
+ ax.set_xticks([])
292
+ ax.set_xticklabels([])
293
+ ax.set_yticks([])
294
+ ax.set_yticklabels([])
295
+ ax.set_title(tokenizer.decode([masked_ids_option_1['sent_1'][col_id+1]]))
296
+ st.pyplot(fig)
297
+ token_id += 1