taka-yamakoshi committed on
Commit
ce466e4
1 Parent(s): d1e605d

check masking

Browse files
Files changed (1) hide show
  1. app.py +55 -11
app.py CHANGED
@@ -111,16 +111,39 @@ def show_annotated_sentence(sent,option_locs=[],mask_locs=[]):
111
  suffix = '</span></p>'
112
  return st.markdown(prefix + disp + suffix, unsafe_allow_html = True)
113
 
114
- def show_instruction(sent):
115
- disp_style = '"font-family:san serif; color:Black; font-size: 20px"'
116
  prefix = f'<p style={disp_style}><span style="font-weight:bold">'
117
  suffix = '</span></p>'
118
  return st.markdown(prefix + sent + suffix, unsafe_allow_html = True)
119
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
  if __name__=='__main__':
121
  wide_setup()
122
  load_css('style.css')
123
  tokenizer,model = load_model()
 
124
  mask_id = tokenizer('[MASK]').input_ids[1:-1][0]
125
 
126
  main_area = st.empty()
@@ -171,16 +194,37 @@ if __name__=='__main__':
171
  option_locs=st.session_state['option_locs_2'],
172
  mask_locs=st.session_state['mask_locs_2'])
173
 
174
- input_ids_1 = tokenizer(sent_1).input_ids
175
- input_ids_2 = tokenizer(sent_2).input_ids
176
- input_ids = torch.tensor([input_ids_1,input_ids_2])
177
-
178
- outputs = SkeletonAlbertForMaskedLM(model,input_ids,
179
- interventions = {0:{'lay':[(head_id,16,[0,1]) for head_id in range(64)],
180
- 'qry':[(head_id,16,[0,1]) for head_id in range(64)],
181
- 'key':[(head_id,16,[0,1]) for head_id in range(64)],
182
- 'val':[(head_id,16,[0,1]) for head_id in range(64)]}})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
183
  logprobs = F.log_softmax(outputs['logits'], dim = -1)
 
 
184
  preds_0 = [torch.multinomial(torch.exp(probs), num_samples=1).squeeze(dim=-1) for probs in logprobs[0][1:-1]]
185
  preds_1 = [torch.multinomial(torch.exp(probs), num_samples=1).squeeze(dim=-1) for probs in logprobs[1][1:-1]]
186
  st.write([tokenizer.decode([token]) for token in preds_0])
 
111
  suffix = '</span></p>'
112
  return st.markdown(prefix + disp + suffix, unsafe_allow_html = True)
113
 
114
def show_instruction(sent, fontsize=20):
    """Render *sent* as a single bold instruction line in the Streamlit app.

    Parameters
    ----------
    sent : str
        Text to display; it is interpolated into raw HTML, so it is assumed
        to be HTML-safe (no user-controlled markup) — TODO confirm at callers.
    fontsize : int, optional
        Font size in pixels (default 20).

    Returns
    -------
    The value of ``st.markdown`` (a Streamlit DeltaGenerator).
    """
    # Bug fix: the original value 'san serif' is not a valid CSS generic
    # font family, so browsers silently ignored the declaration.
    # 'sans-serif' is the correct keyword.
    disp_style = f'"font-family:sans-serif; color:Black; font-size: {fontsize}px"'
    prefix = f'<p style={disp_style}><span style="font-weight:bold">'
    suffix = '</span></p>'
    return st.markdown(prefix + sent + suffix, unsafe_allow_html=True)
119
 
120
def create_interventions(token_id, interv_type, num_layers, num_heads):
    """Build the per-layer intervention specification for the skeleton model.

    Parameters
    ----------
    token_id : int
        Position of the token to intervene on.
    interv_type : str
        One of 'lay', 'qry', 'key', 'val', or 'all' to cover every
        representation type.
    num_layers, num_heads : int
        Model geometry; one entry is produced per layer, and one tuple per
        head within each representation.

    Returns
    -------
    dict
        ``{layer_id: {rep: [(head_id, token_id, [head_id, head_id + num_heads]), ...]}}``
        where the inner list pairs each head with its two batch indices.
    """
    if interv_type == 'all':
        rep_names = ['lay', 'qry', 'key', 'val']
    else:
        rep_names = [interv_type]

    def _head_specs():
        # Fresh list per representation so entries are never aliased.
        return [(head_id, token_id, [head_id, head_id + num_heads])
                for head_id in range(num_heads)]

    return {layer_id: {rep: _head_specs() for rep in rep_names}
            for layer_id in range(num_layers)}
130
+
131
def separate_options(option_locs):
    """Split a flat list of token positions into the two contiguous option spans.

    The positions of the two answer options arrive as one sorted list with a
    single gap (> 1) between the spans; that gap marks the boundary.

    Parameters
    ----------
    option_locs : sequence of int
        Sorted token positions covering exactly two contiguous runs.

    Returns
    -------
    tuple
        ``(option_1_locs, option_2_locs)`` — the positions before and after
        the gap.

    Raises
    ------
    AssertionError
        If there is not exactly one gap, or either span is non-contiguous.
    """
    gaps = np.diff(option_locs) > 1
    assert np.sum(gaps) == 1  # exactly one boundary between the two options
    sep = list(gaps).index(True) + 1
    option_1_locs, option_2_locs = option_locs[:sep], option_locs[sep:]
    # Bug fix: the second operand read `option_2_loc` (undefined name), so the
    # original raised NameError on every call; it must check `option_2_locs`.
    assert np.all(np.diff(option_1_locs) == 1) and np.all(np.diff(option_2_locs) == 1)
    return option_1_locs, option_2_locs
137
+
138
def mask_out(input_ids, pron_locs, option_locs, mask_id):
    """Replace the pronoun span of *input_ids* with one mask id per option token.

    Parameters
    ----------
    input_ids : list of int
        Tokenized sentence.
    pron_locs : sequence of int
        Contiguous positions of the pronoun to remove.
    option_locs : sequence of int
        Positions of the candidate option; only its length is used, to decide
        how many mask tokens to insert.
    mask_id : int
        Token id of the [MASK] token.

    Returns
    -------
    list of int
        ``input_ids`` with the pronoun span swapped for ``len(option_locs)``
        mask tokens.
    """
    assert np.all(np.diff(pron_locs) == 1)  # pronoun span must be contiguous
    before = input_ids[:pron_locs[0]]
    after = input_ids[pron_locs[-1] + 1:]
    masks = [mask_id] * len(option_locs)
    return before + masks + after
141
+
142
  if __name__=='__main__':
143
  wide_setup()
144
  load_css('style.css')
145
  tokenizer,model = load_model()
146
+ num_layers, num_heads = 12, 64
147
  mask_id = tokenizer('[MASK]').input_ids[1:-1][0]
148
 
149
  main_area = st.empty()
 
194
  option_locs=st.session_state['option_locs_2'],
195
  mask_locs=st.session_state['mask_locs_2'])
196
 
197
+ option_1_locs, option_2_locs = {}, {}
198
+ pron_id = {}
199
+ input_ids_dict = {}
200
+ masked_ids_option_1 = {}
201
+ masked_ids_option_2 = {}
202
+ for sent_id in range(2):
203
+ option_1_locs[f'sent_{sent_id+1}'], option_2_locs[f'sent_{sent_id+1}'] = separate_options(st.session_state[f'option_locs_{sent_id}'])
204
+ pron_locs[f'sent_{sent_id+1}'] = st.session_state[f'mask_locs_{sent_id+1}']
205
+ input_ids_dict[f'sent_{sent_id+1}'] = tokenizer(st.session_state[f'sent_{sent_id+1}']).input_ids
206
+
207
+ masked_ids_option_1[f'sent_{sent_id+1}'] = mask_out(input_ids_dict[f'sent_{sent_id+1}'],
208
+ pron_locs[f'sent_{sent_id+1}'],
209
+ option_1_locs[f'sent_{sent_id+1}'],mask_id)
210
+ masked_ids_option_2[f'sent_{sent_id+1}'] = mask_out(input_ids_dict[f'sent_{sent_id+1}'],
211
+ pron_locs[f'sent_{sent_id+1}'],
212
+ option_2_locs[f'sent_{sent_id+1}'],mask_id)
213
+
214
+ for token_ids in [masked_ids_option_1['sent_1'],masked_ids_option_1['sent_2'],masked_ids_option_2['sent_1'],masked_ids_option_2['sent_2']]:
215
+ st.write(' '.join([tokenizer.decode([token]) for toke in token_ids]))
216
+
217
+ if st.session_state['page_status'] == 'finish_debug':
218
+ try:
219
+ assert len(input_ids_1) == len(input_ids_2)
220
+ except AssertionError:
221
+ show_instruction('Please make sure the number of tokens match between Sentence 1 and Sentence 2', fontsize=12)
222
+ input_ids = torch.tensor([*[input_ids_1 for _ in range(num_heads)],*[input_ids_2 for _ in range(num_heads)]])
223
+ interventions = create_interventions(16,'all',num_layers=num_layers,num_heads=num_heads)
224
+ outputs = SkeletonAlbertForMaskedLM(model,input_ids,interventions=interventions)
225
  logprobs = F.log_softmax(outputs['logits'], dim = -1)
226
+
227
+
228
  preds_0 = [torch.multinomial(torch.exp(probs), num_samples=1).squeeze(dim=-1) for probs in logprobs[0][1:-1]]
229
  preds_1 = [torch.multinomial(torch.exp(probs), num_samples=1).squeeze(dim=-1) for probs in logprobs[1][1:-1]]
230
  st.write([tokenizer.decode([token]) for token in preds_0])