taka-yamakoshi
committed on
Commit
•
a267a6b
1
Parent(s):
9239cfa
test
Browse files
app.py
CHANGED
@@ -142,6 +142,23 @@ def mask_out(input_ids,pron_locs,option_locs,mask_id):
|
|
142 |
# note annotations are shifted by 1 because special tokens were omitted
|
143 |
return input_ids[:pron_locs[0]+1] + [mask_id for _ in range(len(option_locs))] + input_ids[pron_locs[-1]+2:]
|
144 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
145 |
if __name__=='__main__':
|
146 |
wide_setup()
|
147 |
load_css('style.css')
|
@@ -220,7 +237,6 @@ if __name__=='__main__':
|
|
220 |
for token_ids in [masked_ids_option_1['sent_1'],masked_ids_option_1['sent_2'],masked_ids_option_2['sent_1'],masked_ids_option_2['sent_2']]:
|
221 |
st.write(' '.join([tokenizer.decode([token]) for token in token_ids]))
|
222 |
|
223 |
-
if st.session_state['page_status'] == 'finish_debug':
|
224 |
option_1_tokens_1 = np.array(input_ids_dict['sent_1'])[np.array(option_1_locs['sent_1'])+1]
|
225 |
option_1_tokens_2 = np.array(input_ids_dict['sent_2'])[np.array(option_1_locs['sent_2'])+1]
|
226 |
option_2_tokens_1 = np.array(input_ids_dict['sent_1'])[np.array(option_2_locs['sent_1'])+1]
|
@@ -229,21 +245,12 @@ if __name__=='__main__':
|
|
229 |
option_1_tokens = option_1_tokens_1
|
230 |
option_2_tokens = option_2_tokens_1
|
231 |
|
|
|
|
|
|
|
|
|
|
|
|
|
232 |
for layer_id in range(num_layers):
|
233 |
interventions = [create_interventions(16,['lay','qry','key','val'],num_heads) if i==layer_id else {'lay':[],'qry':[],'key':[],'val':[]} for i in range(num_layers)]
|
234 |
-
|
235 |
-
input_ids = torch.tensor([
|
236 |
-
*[masked_ids['sent_1'] for _ in range(num_heads)],
|
237 |
-
*[masked_ids['sent_2'] for _ in range(num_heads)]
|
238 |
-
])
|
239 |
-
outputs = SkeletonAlbertForMaskedLM(model,input_ids,interventions=interventions)
|
240 |
-
logprobs = F.log_softmax(outputs['logits'], dim = -1)
|
241 |
-
logprobs_1, logprobs_2 = logprobs[:num_heads], logprobs[num_heads:]
|
242 |
-
evals_1 = [logprobs_1[:,pron_locs['sent_1'][0]+1+i,token] for i,token in enumerate(option_tokens)]
|
243 |
-
evals_2 = [logprobs_2[:,pron_locs['sent_2'][0]+1+i,token] for i,token in enumerate(option_tokens)]
|
244 |
-
|
245 |
-
|
246 |
-
preds_0 = [torch.multinomial(torch.exp(probs), num_samples=1).squeeze(dim=-1) for probs in logprobs[0][1:-1]]
|
247 |
-
preds_1 = [torch.multinomial(torch.exp(probs), num_samples=1).squeeze(dim=-1) for probs in logprobs[1][1:-1]]
|
248 |
-
st.write([tokenizer.decode([token]) for token in preds_0])
|
249 |
-
st.write([tokenizer.decode([token]) for token in preds_1])
|
|
|
142 |
# note annotations are shifted by 1 because special tokens were omitted
|
143 |
return input_ids[:pron_locs[0]+1] + [mask_id for _ in range(len(option_locs))] + input_ids[pron_locs[-1]+2:]
|
144 |
|
145 |
+
def run(interventions,batch_size,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs):
|
146 |
+
probs = []
|
147 |
+
for masked_ids, option_tokens in zip([masked_ids_option_1, masked_ids_option_2],[option_1_tokens,option_2_tokens]):
|
148 |
+
input_ids = torch.tensor([
|
149 |
+
*[masked_ids['sent_1'] for _ in range(batch_size)],
|
150 |
+
*[masked_ids['sent_2'] for _ in range(batch_size)]
|
151 |
+
])
|
152 |
+
outputs = SkeletonAlbertForMaskedLM(model,input_ids,interventions=interventions)
|
153 |
+
logprobs = F.log_softmax(outputs['logits'], dim = -1)
|
154 |
+
logprobs_1, logprobs_2 = logprobs[:batch_size], logprobs[batch_size:]
|
155 |
+
evals_1 = [logprobs_1[:,pron_locs['sent_1'][0]+1+i,token].numpy() for i,token in enumerate(option_tokens)]
|
156 |
+
evals_2 = [logprobs_2[:,pron_locs['sent_2'][0]+1+i,token].numpy() for i,token in enumerate(option_tokens)]
|
157 |
+
probs.append([np.exp(np.mean(evals_1,axis=0)),np.exp(np.mean(evals_2,axis=0))])
|
158 |
+
probs = np.array(probs)
|
159 |
+
assert probs.shape[0]==2 and probs.shape[1]==2 and probs.shape[2]==batch_size
|
160 |
+
return probs
|
161 |
+
|
162 |
if __name__=='__main__':
|
163 |
wide_setup()
|
164 |
load_css('style.css')
|
|
|
237 |
for token_ids in [masked_ids_option_1['sent_1'],masked_ids_option_1['sent_2'],masked_ids_option_2['sent_1'],masked_ids_option_2['sent_2']]:
|
238 |
st.write(' '.join([tokenizer.decode([token]) for token in token_ids]))
|
239 |
|
|
|
240 |
option_1_tokens_1 = np.array(input_ids_dict['sent_1'])[np.array(option_1_locs['sent_1'])+1]
|
241 |
option_1_tokens_2 = np.array(input_ids_dict['sent_2'])[np.array(option_1_locs['sent_2'])+1]
|
242 |
option_2_tokens_1 = np.array(input_ids_dict['sent_1'])[np.array(option_2_locs['sent_1'])+1]
|
|
|
245 |
option_1_tokens = option_1_tokens_1
|
246 |
option_2_tokens = option_2_tokens_1
|
247 |
|
248 |
+
interventions = [{'lay':[],'qry':[],'key':[],'val':[]} for i in range(num_layers)]
|
249 |
+
probs_original = run(interventions,1,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs)
|
250 |
+
st.write(probs_original)
|
251 |
+
print(probs_original)
|
252 |
+
|
253 |
+
if st.session_state['page_status'] == 'finish_debug':
|
254 |
for layer_id in range(num_layers):
|
255 |
interventions = [create_interventions(16,['lay','qry','key','val'],num_heads) if i==layer_id else {'lay':[],'qry':[],'key':[],'val':[]} for i in range(num_layers)]
|
256 |
+
probs = run(interventions,num_heads,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|