taka-yamakoshi committed on
Commit
cd10873
1 Parent(s): 402ce08
Files changed (1)
  1. app.py +11 -1
app.py CHANGED
@@ -62,6 +62,7 @@ def annotate_mask(sent_id,sent):
     st.write(f'Sentence {sent_id}')
     input_sent = tokenizer(sent).input_ids
     decoded_sent = [tokenizer.decode([token]) for token in input_sent[1:-1]]
+    st.session_state[f'decoded_sent_{sent_id}'] = decoded_sent
     char_nums = [len(word)+2 for word in decoded_sent]
     cols = st.columns(char_nums)
     if f'mask_locs_{sent_id}' not in st.session_state:
@@ -157,13 +158,22 @@ if __name__=='__main__':
     with main_area.container():
         sent_1 = st.session_state['sent_1']
         sent_2 = st.session_state['sent_2']
+        show_annotated_sentence(st.session_state['decoded_sent_1'],
+                                option_locs=st.session_state['option_locs_1'],
+                                mask_locs=st.session_state['mask_locs_1'])
+        show_annotated_sentence(st.session_state['decoded_sent_2'],
+                                option_locs=st.session_state['option_locs_2'],
+                                mask_locs=st.session_state['mask_locs_2'])
 
         input_ids_1 = tokenizer(sent_1).input_ids
         input_ids_2 = tokenizer(sent_2).input_ids
         input_ids = torch.tensor([input_ids_1,input_ids_2])
 
         outputs = SkeletonAlbertForMaskedLM(model,input_ids,
-                                            interventions = {0:{'lay':[(head_id,17,[0,1]) for head_id in range(64)]}})
+                                            interventions = {0:{'lay':[(head_id,17,[0,1]) for head_id in range(64)],
+                                                                'qry':[(head_id,17,[0,1]) for head_id in range(64)],
+                                                                'key':[(head_id,17,[0,1]) for head_id in range(64)],
+                                                                'val':[(head_id,17,[0,1]) for head_id in range(64)]}})
         logprobs = F.log_softmax(outputs['logits'], dim = -1)
         preds = [torch.multinomial(torch.exp(probs), num_samples=1).squeeze(dim=-1) for probs in logprobs[0][1:-1]]
         st.write([tokenizer.decode([token]) for token in preds])
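For context on the first change: annotate_mask now caches the per-token decoding of each sentence under a per-sentence session-state key, and the main block reads those caches back to redraw the annotated sentences before running the model. Below is a minimal sketch of that caching pattern, assuming the ALBERT tokenizer; the exact checkpoint and the signature of show_annotated_sentence are not shown in this diff, so the render step is a hypothetical stand-in.

import streamlit as st
from transformers import AlbertTokenizer

tokenizer = AlbertTokenizer.from_pretrained('albert-xxlarge-v2')  # assumed checkpoint

def annotate_mask_sketch(sent_id, sent):
    # Tokenize, drop [CLS]/[SEP], and decode each remaining token back to text.
    input_sent = tokenizer(sent).input_ids
    decoded_sent = [tokenizer.decode([token]) for token in input_sent[1:-1]]
    # Cache under a per-sentence key so a later rerun can redraw without re-tokenizing.
    st.session_state[f'decoded_sent_{sent_id}'] = decoded_sent

def show_sketch(sent_id):
    # Hypothetical stand-in for show_annotated_sentence: read the cache and display it.
    st.write(' '.join(st.session_state[f'decoded_sent_{sent_id}']))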
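The second change widens the intervention from the 'lay' projection alone to the 'lay', 'qry', 'key', and 'val' projections of every head in layer 0. The semantics of SkeletonAlbertForMaskedLM's interventions argument are not defined in this diff; reading each tuple as (head index, token position, example indices), a helper that builds the same literal for an arbitrary layer could look like the sketch below (the helper name and keyword names are hypothetical).

def build_interventions(layer_id=0, pos=17, examples=(0, 1), num_heads=64):
    # One (head_id, position, example-indices) tuple per attention head,
    # repeated for each of the four per-head projections used above.
    return {layer_id: {proj: [(head_id, pos, list(examples)) for head_id in range(num_heads)]
                       for proj in ('lay', 'qry', 'key', 'val')}}

# Equivalent to the literal passed to SkeletonAlbertForMaskedLM in this commit:
interventions = build_interventions(layer_id=0, pos=17, examples=(0, 1), num_heads=64)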