taka-yamakoshi commited on
Commit
9baefdc
1 Parent(s): 5a8ba99

show table

Browse files
Files changed (1) hide show
  1. app.py +8 -4
app.py CHANGED
@@ -233,9 +233,9 @@ if __name__=='__main__':
233
  pron_locs[f'sent_{sent_id}'],
234
  option_2_locs[f'sent_{sent_id}'],mask_id)
235
 
236
- st.write(option_1_locs)
237
- st.write(option_2_locs)
238
- st.write(pron_locs)
239
  for token_ids in [masked_ids_option_1['sent_1'],masked_ids_option_1['sent_2'],masked_ids_option_2['sent_1'],masked_ids_option_2['sent_2']]:
240
  st.write(' '.join([tokenizer.decode([token]) for token in token_ids]))
241
 
@@ -246,10 +246,14 @@ if __name__=='__main__':
246
  assert np.all(option_1_tokens_1==option_1_tokens_2) and np.all(option_2_tokens_1==option_2_tokens_2)
247
  option_1_tokens = option_1_tokens_1
248
  option_2_tokens = option_2_tokens_1
 
 
249
 
250
  interventions = [{'lay':[],'qry':[],'key':[],'val':[]} for i in range(num_layers)]
251
  probs_original = run_intervention(interventions,1,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs)
252
- st.write(probs_original)
 
 
253
 
254
  if st.session_state['page_status'] == 'finish_debug':
255
  for layer_id in range(num_layers):
 
233
  pron_locs[f'sent_{sent_id}'],
234
  option_2_locs[f'sent_{sent_id}'],mask_id)
235
 
236
+ #st.write(option_1_locs)
237
+ #st.write(option_2_locs)
238
+ #st.write(pron_locs)
239
  for token_ids in [masked_ids_option_1['sent_1'],masked_ids_option_1['sent_2'],masked_ids_option_2['sent_1'],masked_ids_option_2['sent_2']]:
240
  st.write(' '.join([tokenizer.decode([token]) for token in token_ids]))
241
 
 
246
  assert np.all(option_1_tokens_1==option_1_tokens_2) and np.all(option_2_tokens_1==option_2_tokens_2)
247
  option_1_tokens = option_1_tokens_1
248
  option_2_tokens = option_2_tokens_1
249
+ st.write(option_1_tokens)
250
+ st.write(option_2_tokens)
251
 
252
  interventions = [{'lay':[],'qry':[],'key':[],'val':[]} for i in range(num_layers)]
253
  probs_original = run_intervention(interventions,1,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs)
254
+ df = pd.DataFrame(data=[[probs_original[0,0][0],probs_original[1,0][0]],
255
+ [probs_original[0,1][0],probs_original[1,1][0]]],columns=['Option 1','Option 2'],index=['Sentence 1','Sentence 2'])
256
+ st.dataframe(df.style.highlight_max(axis=1))
257
 
258
  if st.session_state['page_status'] == 'finish_debug':
259
  for layer_id in range(num_layers):