taka-yamakoshi commited on
Commit
6f269e2
1 Parent(s): 20d2531
Files changed (1) hide show
  1. app.py +2 -0
app.py CHANGED
@@ -142,6 +142,7 @@ def mask_out(input_ids,pron_locs,option_locs,mask_id):
142
  # note annotations are shifted by 1 because special tokens were omitted
143
  return input_ids[:pron_locs[0]+1] + [mask_id for _ in range(len(option_locs))] + input_ids[pron_locs[-1]+2:]
144
 
 
145
  def run_intervention(interventions,batch_size,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs):
146
  probs = []
147
  for masked_ids, option_tokens in zip([masked_ids_option_1, masked_ids_option_2],[option_1_tokens,option_2_tokens]):
@@ -158,6 +159,7 @@ def run_intervention(interventions,batch_size,model,masked_ids_option_1,masked_i
158
  probs = np.array(probs)
159
  assert probs.shape[0]==2 and probs.shape[1]==2 and probs.shape[2]==batch_size
160
  return probs
 
161
 
162
  if __name__=='__main__':
163
  wide_setup()
 
142
  # note annotations are shifted by 1 because special tokens were omitted
143
  return input_ids[:pron_locs[0]+1] + [mask_id for _ in range(len(option_locs))] + input_ids[pron_locs[-1]+2:]
144
 
145
+ '''
146
  def run_intervention(interventions,batch_size,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs):
147
  probs = []
148
  for masked_ids, option_tokens in zip([masked_ids_option_1, masked_ids_option_2],[option_1_tokens,option_2_tokens]):
 
159
  probs = np.array(probs)
160
  assert probs.shape[0]==2 and probs.shape[1]==2 and probs.shape[2]==batch_size
161
  return probs
162
+ '''
163
 
164
  if __name__=='__main__':
165
  wide_setup()