taka-yamakoshi
commited on
Commit
•
9baefdc
1
Parent(s):
5a8ba99
show table
Browse files
app.py
CHANGED
@@ -233,9 +233,9 @@ if __name__=='__main__':
|
|
233 |
pron_locs[f'sent_{sent_id}'],
|
234 |
option_2_locs[f'sent_{sent_id}'],mask_id)
|
235 |
|
236 |
-
st.write(option_1_locs)
|
237 |
-
st.write(option_2_locs)
|
238 |
-
st.write(pron_locs)
|
239 |
for token_ids in [masked_ids_option_1['sent_1'],masked_ids_option_1['sent_2'],masked_ids_option_2['sent_1'],masked_ids_option_2['sent_2']]:
|
240 |
st.write(' '.join([tokenizer.decode([token]) for token in token_ids]))
|
241 |
|
@@ -246,10 +246,14 @@ if __name__=='__main__':
|
|
246 |
assert np.all(option_1_tokens_1==option_1_tokens_2) and np.all(option_2_tokens_1==option_2_tokens_2)
|
247 |
option_1_tokens = option_1_tokens_1
|
248 |
option_2_tokens = option_2_tokens_1
|
|
|
|
|
249 |
|
250 |
interventions = [{'lay':[],'qry':[],'key':[],'val':[]} for i in range(num_layers)]
|
251 |
probs_original = run_intervention(interventions,1,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs)
|
252 |
-
|
|
|
|
|
253 |
|
254 |
if st.session_state['page_status'] == 'finish_debug':
|
255 |
for layer_id in range(num_layers):
|
|
|
233 |
pron_locs[f'sent_{sent_id}'],
|
234 |
option_2_locs[f'sent_{sent_id}'],mask_id)
|
235 |
|
236 |
+
#st.write(option_1_locs)
|
237 |
+
#st.write(option_2_locs)
|
238 |
+
#st.write(pron_locs)
|
239 |
for token_ids in [masked_ids_option_1['sent_1'],masked_ids_option_1['sent_2'],masked_ids_option_2['sent_1'],masked_ids_option_2['sent_2']]:
|
240 |
st.write(' '.join([tokenizer.decode([token]) for token in token_ids]))
|
241 |
|
|
|
246 |
assert np.all(option_1_tokens_1==option_1_tokens_2) and np.all(option_2_tokens_1==option_2_tokens_2)
|
247 |
option_1_tokens = option_1_tokens_1
|
248 |
option_2_tokens = option_2_tokens_1
|
249 |
+
st.write(option_1_tokens)
|
250 |
+
st.write(option_2_tokens)
|
251 |
|
252 |
interventions = [{'lay':[],'qry':[],'key':[],'val':[]} for i in range(num_layers)]
|
253 |
probs_original = run_intervention(interventions,1,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs)
|
254 |
+
df = pd.DataFrame(data=[[probs_original[0,0][0],probs_original[1,0][0]],
|
255 |
+
[probs_original[0,1][0],probs_original[1,1][0]]],columns=['Option 1','Option 2'],index=['Sentence 1','Sentence 2'])
|
256 |
+
st.dataframe(df.style.highlight_max(axis=1))
|
257 |
|
258 |
if st.session_state['page_status'] == 'finish_debug':
|
259 |
for layer_id in range(num_layers):
|