taka-yamakoshi committed on
Commit c6dd7aa
1 Parent(s): a440ac3
Files changed (1)
  1. app.py +56 -16
app.py CHANGED
@@ -16,20 +16,7 @@ from transformers import AlbertTokenizer, AlbertForMaskedLM
 #from custom_modeling_albert_flax import CustomFlaxAlbertForMaskedLM
 from skeleton_modeling_albert import SkeletonAlbertForMaskedLM
 
-@st.cache(show_spinner=True,allow_output_mutation=True)
-def load_model():
-    tokenizer = AlbertTokenizer.from_pretrained('albert-xxlarge-v2')
-    #model = CustomFlaxAlbertForMaskedLM.from_pretrained('albert-xxlarge-v2',from_pt=True)
-    model = AlbertForMaskedLM.from_pretrained('albert-xxlarge-v2')
-    return tokenizer,model
-
-def clear_data():
-    for key in st.session_state:
-        del st.session_state[key]
-
-if __name__=='__main__':
-
-    # Config
+def wide_setup():
     max_width = 1500
     padding_top = 0
     padding_right = 2
@@ -56,9 +43,61 @@ if __name__=='__main__':
     st.markdown(define_margins, unsafe_allow_html=True)
     st.markdown(hide_table_row_index, unsafe_allow_html=True)
 
-    tokenizer,model = load_model()
-    mask_id = tokenizer('[MASK]').input_ids[1:-1][0]
+@st.cache(show_spinner=True,allow_output_mutation=True)
+def load_model():
+    tokenizer = AlbertTokenizer.from_pretrained('albert-xxlarge-v2')
+    #model = CustomFlaxAlbertForMaskedLM.from_pretrained('albert-xxlarge-v2',from_pt=True)
+    model = AlbertForMaskedLM.from_pretrained('albert-xxlarge-v2')
+    return tokenizer,model
+
+def clear_data():
+    for key in st.session_state:
+        del st.session_state[key]
+
+if __name__=='__main__':
+    wide_setup()
+
+    if 'page_status' not in st.session_state:
+        st.session_state['page_status'] = 'type_in'
+
+    if st.session_state['page_status']=='type_in':
+        tokenizer,model = load_model()
+        mask_id = tokenizer('[MASK]').input_ids[1:-1][0]
+
+        st.write('1. Type in the sentences and click "Tokenize"')
+        sent_1 = st.sidebar.text_input('Sentence 1',value='It is better to play a prank on Samuel than Craig because he gets angry less often.')
+        sent_2 = st.sidebar.text_input('Sentence 2',value='It is better to play a prank on Samuel than Craig because he gets angry more often.')
+        if st.sidebar.button('Tokenize'):
+            st.session_state['page_status'] = 'tokenized'
+            st.session_state['sent_1'] = sent_1
+            st.session_state['sent_2'] = sent_2
+
+    if st.session_state['page_status']=='tokenized':
+        tokenizer,model = load_model()
+        mask_id = tokenizer('[MASK]').input_ids[1:-1][0]
+        sent_1 = st.session_state['sent_1']
+        sent_2 = st.session_state['sent_2']
+        if 'masked_pos_1' not in st.session_state:
+            st.session_state['masked_pos_1'] = []
+        if 'masked_pos_2' not in st.session_state:
+            st.session_state['masked_pos_2'] = []
+
+        st.write('2. Select sites to mask out and click "Confirm"')
+        input_sent = tokenizer(sent_1).input_ids
+        decoded_sent = [tokenizer.decode([token]) for token in input_sent]
+        char_nums = [len(word)+2 for word in decoded_sent]
+        cols = st.columns(char_nums)
+        with cols[0]:
+            st.write(decoded_sent[0])
+        with cols[-1]:
+            st.write(decoded_sent[-1])
+        for word_id,(col,word) in enumerate(zip(cols[1:-1],decoded_sent[1:-1])):
+            with col:
+                if st.button(word,key=f'word_{word_id}'):
+                    st.session_state['masked_pos_1'].append(word_id)
+        st.write(f'Masked words: {" ".join([decoded_sent[word_id+1] for word_id in st.session_state["masked_pos_1"]])}')
 
+    '''
     sent_1 = st.sidebar.text_input('Sentence 1',value='It is better to play a prank on Samuel than Craig because he gets angry less often.',on_change=clear_data)
     sent_2 = st.sidebar.text_input('Sentence 2',value='It is better to play a prank on Samuel than Craig because he gets angry more often.',on_change=clear_data)
     input_ids_1 = tokenizer(sent_1).input_ids
@@ -69,3 +108,4 @@ if __name__=='__main__':
     logprobs = F.log_softmax(outputs['logits'], dim = -1)
     preds = [torch.multinomial(torch.exp(probs), num_samples=1).squeeze(dim=-1) for probs in logprobs[0]]
     st.write([tokenizer.decode([token]) for token in preds])
+    '''
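The core change in this commit is a two-step page flow driven by st.session_state['page_status'] rather than a single top-to-bottom script. Streamlit reruns the whole script on every widget interaction, and only st.session_state survives between runs, which is why each step re-checks the status and why load_model is called in both branches. A minimal, self-contained sketch of the same pattern (toy step names and payload are hypothetical; no model loading):

import streamlit as st

# Streamlit reruns this script from the top on every interaction;
# st.session_state is the only state that persists across reruns.
if 'page_status' not in st.session_state:
    st.session_state['page_status'] = 'type_in'

if st.session_state['page_status'] == 'type_in':
    sent = st.sidebar.text_input('Sentence', value='Hello world.')
    if st.sidebar.button('Tokenize'):
        # Stash the input and advance the flow; using plain `if` (not
        # `elif`) below lets the next step render within this same run.
        st.session_state['sent'] = sent
        st.session_state['page_status'] = 'tokenized'

if st.session_state['page_status'] == 'tokenized':
    st.write(f"Step 2 now operates on: {st.session_state['sent']}")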
 
 
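One detail worth noting in the mask-selection UI: st.columns accepts a list of relative widths, so sizing each column by its token's character count (len(word)+2) lays out one button per token at roughly natural spacing. A toy sketch of just that layout trick (the token list is hypothetical):

import streamlit as st

tokens = ['[CLS]', 'it', 'is', 'better', '[SEP]']
# One column per token, width proportional to the token's length,
# so the row of buttons reads roughly like the original sentence.
cols = st.columns([len(tok) + 2 for tok in tokens])
for tok_id, (col, tok) in enumerate(zip(cols, tokens)):
    with col:
        if st.button(tok, key=f'tok_{tok_id}'):
            st.write(f'Selected position {tok_id}')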
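The block fenced off with triple quotes at the bottom preserves the previous demo logic: log-softmax over the vocabulary, then one sampled token per position via torch.multinomial. A standalone sketch of that sampling step, assuming the same albert-xxlarge-v2 checkpoint as load_model (the example sentence is hypothetical):

import torch
import torch.nn.functional as F
from transformers import AlbertTokenizer, AlbertForMaskedLM

tokenizer = AlbertTokenizer.from_pretrained('albert-xxlarge-v2')
model = AlbertForMaskedLM.from_pretrained('albert-xxlarge-v2')

input_ids = torch.tensor([tokenizer('The capital of France is [MASK].').input_ids])
with torch.no_grad():
    outputs = model(input_ids)

# Log-probabilities over the vocab at every position, then draw one
# sample per position; exp() recovers probabilities for multinomial.
logprobs = F.log_softmax(outputs['logits'], dim=-1)
preds = [torch.multinomial(torch.exp(probs), num_samples=1).squeeze(dim=-1)
         for probs in logprobs[0]]
print([tokenizer.decode([int(token)]) for token in preds])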