Spaces:
Runtime error
Runtime error
taka-yamakoshi
commited on
Commit
·
2ce61c2
1
Parent(s):
0b33931
reduce num runs
Browse files
app.py
CHANGED
@@ -260,7 +260,13 @@ if __name__=='__main__':
|
|
260 |
|
261 |
multihead = True
|
262 |
effect_array = []
|
263 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
264 |
effect_list = []
|
265 |
for layer_id in range(num_layers):
|
266 |
interventions = [create_interventions(token_id,['lay','qry','key','val'],num_heads,multihead) if i==layer_id else {'lay':[],'qry':[],'key':[],'val':[]} for i in range(num_layers)]
|
@@ -274,15 +280,18 @@ if __name__=='__main__':
|
|
274 |
effects = np.transpose(np.array(effect_array),(1,0,2))
|
275 |
|
276 |
cols = st.columns(len(masked_ids_option_1['sent_1'])-2)
|
|
|
277 |
for col_id,col in enumerate(cols):
|
278 |
-
|
279 |
-
|
280 |
-
|
281 |
-
|
282 |
-
|
283 |
-
|
284 |
-
|
285 |
-
|
286 |
-
|
287 |
-
|
288 |
-
|
|
|
|
|
|
260 |
|
261 |
multihead = True
|
262 |
effect_array = []
|
263 |
+
assert np.all(np.array(pron_locs['sent_1'])==np.array(pron_locs['sent_2']))
|
264 |
+
assert np.all(np.array(option_1_locs['sent_1'])==np.array(option_1_locs['sent_2']))
|
265 |
+
assert np.all(np.array(option_2_locs['sent_1'])==np.array(option_2_locs['sent_2']))
|
266 |
+
token_id_list = pron_locs['sent_1'] + option_1_locs['sent_1'] + option_2_locs['sent_1']
|
267 |
+
st.write(token_id_list)
|
268 |
+
for token_id in token_id_list:
|
269 |
+
token_id += 1
|
270 |
effect_list = []
|
271 |
for layer_id in range(num_layers):
|
272 |
interventions = [create_interventions(token_id,['lay','qry','key','val'],num_heads,multihead) if i==layer_id else {'lay':[],'qry':[],'key':[],'val':[]} for i in range(num_layers)]
|
|
|
280 |
effects = np.transpose(np.array(effect_array),(1,0,2))
|
281 |
|
282 |
cols = st.columns(len(masked_ids_option_1['sent_1'])-2)
|
283 |
+
token_id = 0
|
284 |
for col_id,col in enumerate(cols):
|
285 |
+
if col_id in token_id_list:
|
286 |
+
with col:
|
287 |
+
fig,ax = plt.subplots()
|
288 |
+
ax.set_box_aspect(effects.shape[0])
|
289 |
+
ax.imshow(effects[:,token_id:token_id+1,0],cmap=sns.color_palette("light:r", as_cmap=True),
|
290 |
+
vmin=effects[:,:,0].min(),vmax=effects[:,:,0].max())
|
291 |
+
ax.set_xticks([])
|
292 |
+
ax.set_xticklabels([])
|
293 |
+
ax.set_yticks([])
|
294 |
+
ax.set_yticklabels([])
|
295 |
+
ax.set_title(tokenizer.decode([masked_ids_option_1['sent_1'][col_id+1]]))
|
296 |
+
st.pyplot(fig)
|
297 |
+
token_id += 1
|