taka-yamakoshi commited on
Commit
402ce08
1 Parent(s): 318a1a5
Files changed (1) hide show
  1. app.py +27 -32
app.py CHANGED
@@ -91,7 +91,6 @@ def annotate_options(sent_id,sent):
91
  st.session_state[f'option_locs_{sent_id}'].append(word_id)
92
  else:
93
  st.session_state[f'option_locs_{sent_id}'].remove(word_id)
94
- st.write([st.session_state[f'option_locs_{sent_id}']])
95
  st.markdown(show_annotated_sentence(decoded_sent,
96
  option_locs=st.session_state[f'option_locs_{sent_id}'],
97
  mask_locs=st.session_state[f'mask_locs_{sent_id}']), unsafe_allow_html = True)
@@ -123,41 +122,36 @@ if __name__=='__main__':
123
  st.session_state['page_status'] = 'type_in'
124
 
125
  if st.session_state['page_status']=='type_in':
126
- with main_area.container():
127
- st.write('1. Type in the sentences and click "Tokenize"')
128
- sent_1 = st.text_input('Sentence 1',value='It is better to play a prank on Samuel than Craig because he gets angry less often.')
129
- sent_2 = st.text_input('Sentence 2',value='It is better to play a prank on Samuel than Craig because he gets angry more often.')
130
- if st.button('Tokenize'):
131
- st.session_state['page_status'] = 'annotate_mask'
132
- st.session_state['sent_1'] = sent_1
133
- st.session_state['sent_2'] = sent_2
134
- main_area.empty()
135
 
136
  if st.session_state['page_status']=='annotate_mask':
137
- with main_area.container():
138
- sent_1 = st.session_state['sent_1']
139
- sent_2 = st.session_state['sent_2']
140
 
141
- st.write('2. Select sites to mask out and click "Confirm"')
142
- annotate_mask(1,sent_1)
143
- annotate_mask(2,sent_2)
144
- st.write(st.session_state['mask_locs_1'])
145
- st.write(st.session_state['mask_locs_2'])
146
- if st.button('Confirm',key='mask'):
147
- st.session_state['page_status'] = 'annotate_options'
148
- main_area.empty()
149
 
150
  if st.session_state['page_status'] == 'annotate_options':
151
- with main_area.container():
152
- sent_1 = st.session_state['sent_1']
153
- sent_2 = st.session_state['sent_2']
154
 
155
- st.write('2. Select options and click "Confirm"')
156
- annotate_options(1,sent_1)
157
- annotate_options(2,sent_2)
158
- if st.button('Confirm',key='option'):
159
- st.session_state['page_status'] = 'analysis'
160
- main_area.empty()
161
 
162
  if st.session_state['page_status']=='analysis':
163
  with main_area.container():
@@ -168,7 +162,8 @@ if __name__=='__main__':
168
  input_ids_2 = tokenizer(sent_2).input_ids
169
  input_ids = torch.tensor([input_ids_1,input_ids_2])
170
 
171
- outputs = SkeletonAlbertForMaskedLM(model,input_ids,interventions = {0:{'lay':[(8,1,[0,1])]}})
 
172
  logprobs = F.log_softmax(outputs['logits'], dim = -1)
173
- preds = [torch.multinomial(torch.exp(probs), num_samples=1).squeeze(dim=-1) for probs in logprobs[0]]
174
  st.write([tokenizer.decode([token]) for token in preds])
 
91
  st.session_state[f'option_locs_{sent_id}'].append(word_id)
92
  else:
93
  st.session_state[f'option_locs_{sent_id}'].remove(word_id)
 
94
  st.markdown(show_annotated_sentence(decoded_sent,
95
  option_locs=st.session_state[f'option_locs_{sent_id}'],
96
  mask_locs=st.session_state[f'mask_locs_{sent_id}']), unsafe_allow_html = True)
 
122
  st.session_state['page_status'] = 'type_in'
123
 
124
  if st.session_state['page_status']=='type_in':
125
+ st.write('1. Type in the sentences and click "Tokenize"')
126
+ sent_1 = st.text_input('Sentence 1',value='It is better to play a prank on Samuel than Craig because he gets angry less often.')
127
+ sent_2 = st.text_input('Sentence 2',value='It is better to play a prank on Samuel than Craig because he gets angry more often.')
128
+ if st.button('Tokenize'):
129
+ st.session_state['page_status'] = 'annotate_mask'
130
+ st.session_state['sent_1'] = sent_1
131
+ st.session_state['sent_2'] = sent_2
132
+ st.experimental_rerun()
 
133
 
134
  if st.session_state['page_status']=='annotate_mask':
135
+ sent_1 = st.session_state['sent_1']
136
+ sent_2 = st.session_state['sent_2']
 
137
 
138
+ st.write('2. Select sites to mask out and click "Confirm"')
139
+ annotate_mask(1,sent_1)
140
+ annotate_mask(2,sent_2)
141
+ if st.button('Confirm',key='mask'):
142
+ st.session_state['page_status'] = 'annotate_options'
143
+ st.experimental_rerun()
 
 
144
 
145
  if st.session_state['page_status'] == 'annotate_options':
146
+ sent_1 = st.session_state['sent_1']
147
+ sent_2 = st.session_state['sent_2']
 
148
 
149
+ st.write('2. Select options and click "Confirm"')
150
+ annotate_options(1,sent_1)
151
+ annotate_options(2,sent_2)
152
+ if st.button('Confirm',key='option'):
153
+ st.session_state['page_status'] = 'analysis'
154
+ st.experimental_rerun()
155
 
156
  if st.session_state['page_status']=='analysis':
157
  with main_area.container():
 
162
  input_ids_2 = tokenizer(sent_2).input_ids
163
  input_ids = torch.tensor([input_ids_1,input_ids_2])
164
 
165
+ outputs = SkeletonAlbertForMaskedLM(model,input_ids,
166
+ interventions = {0:{'lay':[(head_id,17,[0,1]) for head_id in range(64)]}})
167
  logprobs = F.log_softmax(outputs['logits'], dim = -1)
168
+ preds = [torch.multinomial(torch.exp(probs), num_samples=1).squeeze(dim=-1) for probs in logprobs[0][1:-1]]
169
  st.write([tokenizer.decode([token]) for token in preds])