taka-yamakoshi committed on
Commit
50ce4f4
1 Parent(s): 8f32fbf

first pass

Files changed (1)
  1. app.py +65 -28
app.py CHANGED
@@ -58,6 +58,41 @@ def clear_data():
     for key in st.session_state:
         del st.session_state[key]
 
+def annotate_mask(sent_id,sent):
+    st.write(f'Sentence {sent_id}')
+    input_sent = tokenizer(sent).input_ids
+    decoded_sent = [tokenizer.decode([token]) for token in input_sent[1:-1]]
+    char_nums = [len(word)+2 for word in decoded_sent]
+    cols = st.columns(char_nums)
+    mask_locs = []
+    for word_id,(col,word) in enumerate(zip(cols,decoded_sent)):
+        with col:
+            if st.button(word,key=f'word_{word_id}'):
+                if word_id not in mask_locs:
+                    mask_locs.append(word_id)
+                else:
+                    mask_locs.remove(word_id)
+    st.markdown(show_annotated_sentence(decoded_sent,mask_locs=mask_locs), unsafe_allow_html = True)
+    return mask_locs
+
+def annotate_options(sent_id,sent):
+    st.write(f'Sentence {sent_id}')
+    input_sent = tokenizer(sent).input_ids
+    decoded_sent = [tokenizer.decode([token]) for token in input_sent[1:-1]]
+    char_nums = [len(word)+2 for word in decoded_sent]
+    cols = st.columns(char_nums)
+    option_locs = []
+    for word_id,(col,word) in enumerate(zip(cols,decoded_sent)):
+        with col:
+            if st.button(word,key=f'word_{word_id}'):
+                if word_id not in option_locs:
+                    option_locs.append(word_id)
+                else:
+                    option_locs.remove(word_id)
+    st.markdown(show_annotated_sentence(decoded_sent,option_locs=option_locs,
+                                        mask_locs=st.session_state[f'mask_locs_{sent_id}']), unsafe_allow_html = True)
+    return option_locs
+
 def show_annotated_sentence(sent,option_locs=[],mask_locs=[]):
     disp_style = '"font-family:san serif; color:Black; font-size: 20px"'
     prefix = f'<p style={disp_style}><span style="font-weight:bold">'
@@ -90,43 +125,45 @@ if __name__=='__main__':
     sent_1 = st.text_input('Sentence 1',value='It is better to play a prank on Samuel than Craig because he gets angry less often.')
     sent_2 = st.text_input('Sentence 2',value='It is better to play a prank on Samuel than Craig because he gets angry more often.')
     if st.button('Tokenize'):
-        st.session_state['page_status'] = 'tokenized'
+        st.session_state['page_status'] = 'annotate_mask'
         st.session_state['sent_1'] = sent_1
         st.session_state['sent_2'] = sent_2
         main_area.empty()
 
-    if st.session_state['page_status']=='tokenized':
+    if st.session_state['page_status']=='annotate_mask':
         with main_area.container():
             sent_1 = st.session_state['sent_1']
             sent_2 = st.session_state['sent_2']
-            if 'masked_pos_1' not in st.session_state:
-                st.session_state['masked_pos_1'] = []
-            if 'masked_pos_2' not in st.session_state:
-                st.session_state['masked_pos_2'] = []
 
             st.write('2. Select sites to mask out and click "Confirm"')
-            input_sent = tokenizer(sent_1).input_ids
-            decoded_sent = [tokenizer.decode([token]) for token in input_sent[1:-1]]
-            char_nums = [len(word)+2 for word in decoded_sent]
-            cols = st.columns(char_nums)
-            for word_id,(col,word) in enumerate(zip(cols,decoded_sent)):
-                with col:
-                    if st.button(word,key=f'word_{word_id}'):
-                        if word_id not in st.session_state['masked_pos_1']:
-                            st.session_state['masked_pos_1'].append(word_id)
-                        else:
-                            st.session_state['masked_pos_1'].remove(word_id)
-            st.markdown(show_annotated_sentence(decoded_sent,mask_locs=st.session_state['masked_pos_1']), unsafe_allow_html = True)
+            st.session_state[f'mask_locs_1'] = annotate_mask(1,sent_1)
+            st.session_state[f'mask_locs_2'] = annotate_mask(2,sent_2)
+            if st.button('Confirm'):
+                st.session_state['page_status'] = 'annotate_options'
+                main_area.empty()
+
+    if st.session_state['page_status'] == 'annotate_options':
+        with main_area.container():
+            sent_1 = st.session_state['sent_1']
+            sent_2 = st.session_state['sent_2']
+
+            st.write('2. Select options click "Confirm"')
+            st.session_state[f'option_locs_1'] = annotate_options(1,sent_1)
+            st.session_state[f'option_locs_2'] = annotate_options(2,sent_2)
+            if st.button('Confirm'):
+                st.session_state['page_status'] = 'analysis'
+                main_area.empty()
 
     if st.session_state['page_status']=='analysis':
-        sent_1 = st.sidebar.text_input('Sentence 1',value='It is better to play a prank on Samuel than Craig because he gets angry less often.',on_change=clear_data)
-        sent_2 = st.sidebar.text_input('Sentence 2',value='It is better to play a prank on Samuel than Craig because he gets angry more often.',on_change=clear_data)
-        input_ids_1 = tokenizer(sent_1).input_ids
-        input_ids_2 = tokenizer(sent_2).input_ids
-        input_ids = torch.tensor([input_ids_1,input_ids_2])
-
-        outputs = SkeletonAlbertForMaskedLM(model,input_ids,interventions = {0:{'lay':[(8,1,[0,1])]}})
-        logprobs = F.log_softmax(outputs['logits'], dim = -1)
-        preds = [torch.multinomial(torch.exp(probs), num_samples=1).squeeze(dim=-1) for probs in logprobs[0]]
-        st.write([tokenizer.decode([token]) for token in preds])
+        with main_area.container():
+            sent_1 = st.session_state['sent_1']
+            sent_2 = st.session_state['sent_2']
+
+            input_ids_1 = tokenizer(sent_1).input_ids
+            input_ids_2 = tokenizer(sent_2).input_ids
+            input_ids = torch.tensor([input_ids_1,input_ids_2])
+
+            outputs = SkeletonAlbertForMaskedLM(model,input_ids,interventions = {0:{'lay':[(8,1,[0,1])]}})
+            logprobs = F.log_softmax(outputs['logits'], dim = -1)
+            preds = [torch.multinomial(torch.exp(probs), num_samples=1).squeeze(dim=-1) for probs in logprobs[0]]
+            st.write([tokenizer.decode([token]) for token in preds])
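
For reference, the 'analysis' branch above batches the two sentences, runs them through the masked LM, and samples one token per position of the first sentence from the predicted distribution. Below is a minimal, standalone sketch of that step; it assumes the stock transformers checkpoint albert-base-v2 in place of the repo's SkeletonAlbertForMaskedLM wrapper (which is not part of this commit) and applies no interventions, so the model call and checkpoint name are assumptions, not the app's actual code path.

# Minimal sketch of the sampling step, assuming a stock ALBERT masked-LM
# head and no interventions (SkeletonAlbertForMaskedLM is not shown here).
import torch
import torch.nn.functional as F
from transformers import AlbertTokenizer, AlbertForMaskedLM

tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2')  # assumed checkpoint
model = AlbertForMaskedLM.from_pretrained('albert-base-v2')

sent_1 = 'It is better to play a prank on Samuel than Craig because he gets angry less often.'
sent_2 = 'It is better to play a prank on Samuel than Craig because he gets angry more often.'

# Stacking into one tensor assumes both sentences tokenize to the same length,
# as the app's default pair does.
input_ids = torch.tensor([tokenizer(sent_1).input_ids,
                          tokenizer(sent_2).input_ids])

with torch.no_grad():
    logits = model(input_ids).logits

# Sample one token per position of the first sentence from the output
# distribution, mirroring the torch.multinomial loop in the diff.
logprobs = F.log_softmax(logits, dim=-1)
preds = [torch.multinomial(torch.exp(probs), num_samples=1).squeeze(dim=-1)
         for probs in logprobs[0]]
print([tokenizer.decode([int(token)]) for token in preds])

Note that the mask_locs_* and option_locs_* selections stored in st.session_state are not yet consumed by the analysis branch in this first pass; the sketch above likewise leaves the input tokens unmasked.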