taka-yamakoshi
commited on
Commit
•
28525ba
1
Parent(s):
6b8bbf9
change name of func
Browse files
app.py
CHANGED
@@ -142,7 +142,7 @@ def mask_out(input_ids,pron_locs,option_locs,mask_id):
|
|
142 |
# note annotations are shifted by 1 because special tokens were omitted
|
143 |
return input_ids[:pron_locs[0]+1] + [mask_id for _ in range(len(option_locs))] + input_ids[pron_locs[-1]+2:]
|
144 |
|
145 |
-
def
|
146 |
probs = []
|
147 |
for masked_ids, option_tokens in zip([masked_ids_option_1, masked_ids_option_2],[option_1_tokens,option_2_tokens]]):
|
148 |
input_ids = torch.tensor([
|
@@ -247,9 +247,9 @@ if __name__=='__main__':
|
|
247 |
option_2_tokens = option_2_tokens_1
|
248 |
|
249 |
interventions = [{'lay':[],'qry':[],'key':[],'val':[]} for i in range(num_layers)]
|
250 |
-
probs_original =
|
251 |
st.write(probs_original)
|
252 |
|
253 |
for layer_id in range(num_layers):
|
254 |
interventions = [create_interventions(16,['lay','qry','key','val'],num_heads) if i==layer_id else {'lay':[],'qry':[],'key':[],'val':[]} for i in range(num_layers)]
|
255 |
-
probs =
|
|
|
142 |
# note annotations are shifted by 1 because special tokens were omitted
|
143 |
return input_ids[:pron_locs[0]+1] + [mask_id for _ in range(len(option_locs))] + input_ids[pron_locs[-1]+2:]
|
144 |
|
145 |
+
def run_intervention(interventions,batch_size,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs):
|
146 |
probs = []
|
147 |
for masked_ids, option_tokens in zip([masked_ids_option_1, masked_ids_option_2],[option_1_tokens,option_2_tokens]]):
|
148 |
input_ids = torch.tensor([
|
|
|
247 |
option_2_tokens = option_2_tokens_1
|
248 |
|
249 |
interventions = [{'lay':[],'qry':[],'key':[],'val':[]} for i in range(num_layers)]
|
250 |
+
probs_original = run_intervention(interventions,1,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs)
|
251 |
st.write(probs_original)
|
252 |
|
253 |
for layer_id in range(num_layers):
|
254 |
interventions = [create_interventions(16,['lay','qry','key','val'],num_heads) if i==layer_id else {'lay':[],'qry':[],'key':[],'val':[]} for i in range(num_layers)]
|
255 |
+
probs = run_intervention(interventions,num_heads,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs)
|