Spaces:
Runtime error
Runtime error
taka-yamakoshi
commited on
Commit
·
0aa4961
1
Parent(s):
41743bf
aes
Browse files
app.py
CHANGED
@@ -288,10 +288,11 @@ if __name__=='__main__':
|
|
288 |
cols = st.columns(len(masked_ids_option_1['sent_1'])-2)
|
289 |
token_id = 0
|
290 |
for col_id,col in enumerate(cols):
|
291 |
-
|
292 |
-
|
293 |
-
|
294 |
-
|
|
|
295 |
ax.set_box_aspect(num_layers)
|
296 |
ax.imshow(effect_array[:,interv_id:interv_id+1,0],cmap=sns.color_palette("light:r", as_cmap=True),
|
297 |
vmin=effect_array[:,:,0].min(),vmax=effect_array[:,:,0].max())
|
@@ -299,5 +300,8 @@ if __name__=='__main__':
|
|
299 |
ax.set_xticklabels([])
|
300 |
ax.set_yticks([])
|
301 |
ax.set_yticklabels([])
|
302 |
-
ax.
|
|
|
|
|
|
|
303 |
st.pyplot(fig)
|
|
|
288 |
cols = st.columns(len(masked_ids_option_1['sent_1'])-2)
|
289 |
token_id = 0
|
290 |
for col_id,col in enumerate(cols):
|
291 |
+
with col:
|
292 |
+
st.write(tokenizer.decode([masked_ids_option_1['sent_1'][col_id+1]]))
|
293 |
+
if col_id in token_id_list:
|
294 |
+
interv_id = token_id_list.index(col_id)
|
295 |
+
fig,ax = plt.subplots()
|
296 |
ax.set_box_aspect(num_layers)
|
297 |
ax.imshow(effect_array[:,interv_id:interv_id+1,0],cmap=sns.color_palette("light:r", as_cmap=True),
|
298 |
vmin=effect_array[:,:,0].min(),vmax=effect_array[:,:,0].max())
|
|
|
300 |
ax.set_xticklabels([])
|
301 |
ax.set_yticks([])
|
302 |
ax.set_yticklabels([])
|
303 |
+
ax.spines['top'].set_visible(False)
|
304 |
+
ax.spines['bottom'].set_visible(False)
|
305 |
+
ax.spines['right'].set_visible(False)
|
306 |
+
ax.spines['left'].set_visible(False)
|
307 |
st.pyplot(fig)
|