taka-yamakoshi commited on
Commit
0aa4961
1 Parent(s): 41743bf
Files changed (1) hide show
  1. app.py +9 -5
app.py CHANGED
@@ -288,10 +288,11 @@ if __name__=='__main__':
288
  cols = st.columns(len(masked_ids_option_1['sent_1'])-2)
289
  token_id = 0
290
  for col_id,col in enumerate(cols):
291
- if col_id in token_id_list:
292
- interv_id = token_id_list.index(col_id)
293
- with col:
294
- fig,ax = plt.subplots(figsize=(1,6))
 
295
  ax.set_box_aspect(num_layers)
296
  ax.imshow(effect_array[:,interv_id:interv_id+1,0],cmap=sns.color_palette("light:r", as_cmap=True),
297
  vmin=effect_array[:,:,0].min(),vmax=effect_array[:,:,0].max())
@@ -299,5 +300,8 @@ if __name__=='__main__':
299
  ax.set_xticklabels([])
300
  ax.set_yticks([])
301
  ax.set_yticklabels([])
302
- ax.set_title(tokenizer.decode([masked_ids_option_1['sent_1'][col_id+1]]))
 
 
 
303
  st.pyplot(fig)
 
288
  cols = st.columns(len(masked_ids_option_1['sent_1'])-2)
289
  token_id = 0
290
  for col_id,col in enumerate(cols):
291
+ with col:
292
+ st.write(tokenizer.decode([masked_ids_option_1['sent_1'][col_id+1]]))
293
+ if col_id in token_id_list:
294
+ interv_id = token_id_list.index(col_id)
295
+ fig,ax = plt.subplots()
296
  ax.set_box_aspect(num_layers)
297
  ax.imshow(effect_array[:,interv_id:interv_id+1,0],cmap=sns.color_palette("light:r", as_cmap=True),
298
  vmin=effect_array[:,:,0].min(),vmax=effect_array[:,:,0].max())
 
300
  ax.set_xticklabels([])
301
  ax.set_yticks([])
302
  ax.set_yticklabels([])
303
+ ax.spines['top'].set_visible(False)
304
+ ax.spines['bottom'].set_visible(False)
305
+ ax.spines['right'].set_visible(False)
306
+ ax.spines['left'].set_visible(False)
307
  st.pyplot(fig)