taka-yamakoshi committed on
Commit
a267a6b
1 Parent(s): 9239cfa
Files changed (1) hide show
  1. app.py +24 -17
app.py CHANGED
@@ -142,6 +142,23 @@ def mask_out(input_ids,pron_locs,option_locs,mask_id):
142
  # note annotations are shifted by 1 because special tokens were omitted
143
  return input_ids[:pron_locs[0]+1] + [mask_id for _ in range(len(option_locs))] + input_ids[pron_locs[-1]+2:]
144
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
  if __name__=='__main__':
146
  wide_setup()
147
  load_css('style.css')
@@ -220,7 +237,6 @@ if __name__=='__main__':
220
  for token_ids in [masked_ids_option_1['sent_1'],masked_ids_option_1['sent_2'],masked_ids_option_2['sent_1'],masked_ids_option_2['sent_2']]:
221
  st.write(' '.join([tokenizer.decode([token]) for token in token_ids]))
222
 
223
- if st.session_state['page_status'] == 'finish_debug':
224
  option_1_tokens_1 = np.array(input_ids_dict['sent_1'])[np.array(option_1_locs['sent_1'])+1]
225
  option_1_tokens_2 = np.array(input_ids_dict['sent_2'])[np.array(option_1_locs['sent_2'])+1]
226
  option_2_tokens_1 = np.array(input_ids_dict['sent_1'])[np.array(option_2_locs['sent_1'])+1]
@@ -229,21 +245,12 @@ if __name__=='__main__':
229
  option_1_tokens = option_1_tokens_1
230
  option_2_tokens = option_2_tokens_1
231
 
 
 
 
 
 
 
232
  for layer_id in range(num_layers):
233
  interventions = [create_interventions(16,['lay','qry','key','val'],num_heads) if i==layer_id else {'lay':[],'qry':[],'key':[],'val':[]} for i in range(num_layers)]
234
- for masked_ids, option_tokens in zip([masked_ids_option_1, masked_ids_option_2],[option_1_tokens,option_2_tokens]):
235
- input_ids = torch.tensor([
236
- *[masked_ids['sent_1'] for _ in range(num_heads)],
237
- *[masked_ids['sent_2'] for _ in range(num_heads)]
238
- ])
239
- outputs = SkeletonAlbertForMaskedLM(model,input_ids,interventions=interventions)
240
- logprobs = F.log_softmax(outputs['logits'], dim = -1)
241
- logprobs_1, logprobs_2 = logprobs[:num_heads], logprobs[num_heads:]
242
- evals_1 = [logprobs_1[:,pron_locs['sent_1'][0]+1+i,token] for i,token in enumerate(option_tokens)]
243
- evals_2 = [logprobs_2[:,pron_locs['sent_2'][0]+1+i,token] for i,token in enumerate(option_tokens)]
244
-
245
-
246
- preds_0 = [torch.multinomial(torch.exp(probs), num_samples=1).squeeze(dim=-1) for probs in logprobs[0][1:-1]]
247
- preds_1 = [torch.multinomial(torch.exp(probs), num_samples=1).squeeze(dim=-1) for probs in logprobs[1][1:-1]]
248
- st.write([tokenizer.decode([token]) for token in preds_0])
249
- st.write([tokenizer.decode([token]) for token in preds_1])
 
142
  # note annotations are shifted by 1 because special tokens were omitted
143
  return input_ids[:pron_locs[0]+1] + [mask_id for _ in range(len(option_locs))] + input_ids[pron_locs[-1]+2:]
144
 
145
def run(interventions,batch_size,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs):
    """Score both answer options with the (optionally intervened) masked LM.

    Args:
        interventions: per-layer intervention spec passed through to
            SkeletonAlbertForMaskedLM (one dict per layer).
        batch_size: number of copies of each sentence to batch together
            (e.g. one per attention head being intervened on).
        model: the underlying ALBERT model.
        masked_ids_option_1 / masked_ids_option_2: dicts with keys
            'sent_1' and 'sent_2' holding token-id lists where the pronoun
            span is replaced by [MASK] tokens for that option.
        option_1_tokens / option_2_tokens: token ids of each option's
            sub-tokens, evaluated at the masked positions.
        pron_locs: dict with 'sent_1'/'sent_2' pronoun location lists
            (annotation indices are shifted by 1 because special tokens
            were omitted when annotating).

    Returns:
        np.ndarray of shape (2, 2, batch_size):
        [option, sentence, batch] mean-log-prob probabilities.
    """
    probs = []
    # NOTE: fixed a stray ']' in the zip() call that made this a syntax error.
    for masked_ids, option_tokens in zip([masked_ids_option_1, masked_ids_option_2],
                                         [option_1_tokens, option_2_tokens]):
        # Batch layout: batch_size copies of sent_1, then batch_size of sent_2.
        input_ids = torch.tensor([
            *[masked_ids['sent_1'] for _ in range(batch_size)],
            *[masked_ids['sent_2'] for _ in range(batch_size)],
        ])
        outputs = SkeletonAlbertForMaskedLM(model, input_ids, interventions=interventions)
        logprobs = F.log_softmax(outputs['logits'], dim=-1)
        logprobs_1, logprobs_2 = logprobs[:batch_size], logprobs[batch_size:]
        # +1 offsets the leading special token omitted from the annotations.
        # NOTE(review): .numpy() assumes the logits carry no grad — confirm the
        # model runs under no_grad, else .detach() would be needed.
        evals_1 = [logprobs_1[:, pron_locs['sent_1'][0] + 1 + i, token].numpy()
                   for i, token in enumerate(option_tokens)]
        evals_2 = [logprobs_2[:, pron_locs['sent_2'][0] + 1 + i, token].numpy()
                   for i, token in enumerate(option_tokens)]
        # Geometric mean over the option's sub-tokens (exp of mean log-prob).
        probs.append([np.exp(np.mean(evals_1, axis=0)), np.exp(np.mean(evals_2, axis=0))])
    probs = np.array(probs)
    assert probs.shape[0] == 2 and probs.shape[1] == 2 and probs.shape[2] == batch_size
    return probs
161
+
162
  if __name__=='__main__':
163
  wide_setup()
164
  load_css('style.css')
 
237
  for token_ids in [masked_ids_option_1['sent_1'],masked_ids_option_1['sent_2'],masked_ids_option_2['sent_1'],masked_ids_option_2['sent_2']]:
238
  st.write(' '.join([tokenizer.decode([token]) for token in token_ids]))
239
 
 
240
  option_1_tokens_1 = np.array(input_ids_dict['sent_1'])[np.array(option_1_locs['sent_1'])+1]
241
  option_1_tokens_2 = np.array(input_ids_dict['sent_2'])[np.array(option_1_locs['sent_2'])+1]
242
  option_2_tokens_1 = np.array(input_ids_dict['sent_1'])[np.array(option_2_locs['sent_1'])+1]
 
245
  option_1_tokens = option_1_tokens_1
246
  option_2_tokens = option_2_tokens_1
247
 
248
+ interventions = [{'lay':[],'qry':[],'key':[],'val':[]} for i in range(num_layers)]
249
+ probs_original = run(interventions,1,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs)
250
+ st.write(probs_original)
251
+ print(probs_original)
252
+
253
+ if st.session_state['page_status'] == 'finish_debug':
254
  for layer_id in range(num_layers):
255
  interventions = [create_interventions(16,['lay','qry','key','val'],num_heads) if i==layer_id else {'lay':[],'qry':[],'key':[],'val':[]} for i in range(num_layers)]
256
+ probs = run(interventions,num_heads,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs)