Olivia Figueira committed on
Commit
e3c1abf
1 Parent(s): 3c050d3

Refactored LM inits and changed app ui

Browse files
Files changed (1) hide show
  1. critic/critic.py +110 -112
critic/critic.py CHANGED
@@ -20,7 +20,7 @@ import streamlit as st
20
 
21
  st.subheader('Exploring Unsupervised Grammatical Error Correction with Transformer-Based Models')
22
  st.write('This live demonstration is adapted from the paper [LM-Critic: Language Models for Unsupervised Grammatical Error Correction](https://aclanthology.org/2021.emnlp-main.611.pdf) (EMNLP 2021) by Michihiro Yasunaga, Jure Leskovec, Percy Liang.')
23
- st.write('The below demo first loads several LMs that we use in the LM-Critic. You will be prompted to enter a sentence which will then be scored by each of the LM-Critics using different LMs.')
24
 
25
  def get_gpt2_loss(model, tokenizer, input_ids, attention_mask, labels):
26
  with torch.no_grad():
@@ -142,132 +142,120 @@ def gpt2_critic(sent, model, tokenizer, verbose=1, cuda=False, fp16=True, seed='
142
  counter_example = [sents[best_idx], float(logps[best_idx])]
143
  return is_good, float(logps[0]), counter_example, return_string
144
 
145
- def init_lms():
 
146
  placeholder_lm_name = st.empty()
147
- prog = 0
148
- my_bar = st.progress(prog)
149
-
150
- if "nice_name_gpt2" not in st.session_state:
151
- ## GPT-2 LM (original LM-critic)
152
- model_name_gpt2 = 'gpt2'
153
- nice_name_gpt2 = "GPT-2"
154
- placeholder_lm_name.text(f"Initializing {nice_name_gpt2}...")
155
- tokenizer_gpt2 = GPT2Tokenizer.from_pretrained(model_name_gpt2)
156
- tokenizer_gpt2.pad_token = tokenizer_gpt2.eos_token
157
- model_gpt2 = GPT2LMHeadModel.from_pretrained(model_name_gpt2)
158
- model_gpt2.eval()
159
- model_gpt2.cpu()
160
- st.session_state["model_gpt2"] = model_gpt2
161
- st.session_state["tokenizer_gpt2"] = tokenizer_gpt2
162
- st.session_state["nice_name_gpt2"] = nice_name_gpt2
163
-
164
- prog += 10
165
- my_bar.progress(prog)
166
-
167
- if "nice_name_opt" not in st.session_state:
168
- ## OPT LM
169
- model_name_opt = "facebook/opt-350m"
170
- nice_name_opt = "OPT"
171
- placeholder_lm_name.text(f"Initializing {nice_name_opt}...")
172
- model_opt = OPTForCausalLM.from_pretrained(model_name_opt)
173
- tokenizer_opt = GPT2Tokenizer.from_pretrained(model_name_opt)
174
- tokenizer_opt.pad_token = tokenizer_opt.eos_token
175
- model_opt.eval()
176
- model_opt.cpu()
177
- st.session_state["model_opt"] = model_opt
178
- st.session_state["tokenizer_opt"] = tokenizer_opt
179
- st.session_state["nice_name_opt"] = nice_name_opt
180
-
181
- prog += 10
182
- my_bar.progress(prog)
183
-
184
- if "nice_name_gptneo" not in st.session_state:
185
- ## GPT NEO
186
- model_name_gptneo = "EleutherAI/gpt-neo-1.3B"
187
- nice_name_gptneo = "GPT NEO"
188
- placeholder_lm_name.text(f"Initializing {nice_name_gptneo}...")
189
- model_gptneo = GPTNeoForCausalLM.from_pretrained(model_name_gptneo)
190
- tokenizer_gptneo = GPT2Tokenizer.from_pretrained(model_name_gptneo)
191
- tokenizer_gptneo.pad_token = tokenizer_gptneo.eos_token
192
- model_gptneo.eval()
193
- model_gptneo.cpu()
194
- st.session_state["model_gptneo"] = model_gptneo
195
- st.session_state["tokenizer_gptneo"] = tokenizer_gptneo
196
- st.session_state["nice_name_gptneo"] = nice_name_gptneo
197
-
198
- prog += 10
199
- my_bar.progress(prog)
200
-
201
- if "nice_name_roberta" not in st.session_state:
202
- ## RoBERTa
203
- model_name_roberta = "roberta-base"
204
- nice_name_roberta = "RoBERTa"
205
- placeholder_lm_name.text(f"Initializing {nice_name_roberta}...")
206
- tokenizer_roberta = RobertaTokenizer.from_pretrained(model_name_roberta)
207
- config_roberta = RobertaConfig.from_pretrained(model_name_roberta)
208
- config_roberta.is_decoder = True
209
- model_roberta = RobertaForCausalLM.from_pretrained(model_name_roberta, config=config_roberta)
210
- tokenizer_roberta.pad_token = tokenizer_roberta.eos_token
211
- model_roberta.eval()
212
- model_roberta.cpu()
213
- st.session_state["model_roberta"] = model_gptneo
214
- st.session_state["tokenizer_roberta"] = tokenizer_roberta
215
- st.session_state["nice_name_roberta"] = nice_name_roberta
216
 
217
- prog += 10
218
- my_bar.progress(prog)
 
 
 
 
 
 
 
 
 
 
 
 
 
219
 
220
- if "nice_name_bart" not in st.session_state:
221
- ## BART
222
- model_name_bart = "facebook/bart-base"
223
- nice_name_bart = "BART"
224
- placeholder_lm_name.text(f"Initializing {nice_name_bart}...")
225
- tokenizer_bart = BartTokenizer.from_pretrained(model_name_bart)
226
- model_bart = BartForCausalLM.from_pretrained(model_name_bart, add_cross_attention=False)
227
- assert model_bart.config.is_decoder, f"{model_bart.__class__} has to be configured as a decoder."
228
- tokenizer_bart.pad_token = tokenizer_bart.eos_token
229
- model_bart.eval()
230
- model_bart.cpu()
231
- st.session_state["model_bart"] = model_bart
232
- st.session_state["tokenizer_bart"] = tokenizer_bart
233
- st.session_state["nice_name_bart"] = nice_name_bart
 
234
 
235
- prog += 10
236
- my_bar.progress(prog)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
237
 
238
- if "nice_name_xlmroberta" not in st.session_state:
239
- ## XLM RoBERTa
240
- model_name_xlmroberta = 'xlm-roberta-base'
241
- nice_name_xlmroberta = 'XLM RoBERTa'
242
- placeholder_lm_name.text(f"Initializing {nice_name_xlmroberta}...")
243
- tokenizer_xlmroberta = XLMRobertaTokenizer.from_pretrained(model_name_xlmroberta)
244
- config_xlmroberta = XLMRobertaConfig.from_pretrained(model_name_xlmroberta)
245
- config_xlmroberta.is_decoder = True
246
- model_xlmroberta = XLMRobertaForCausalLM.from_pretrained(model_name_xlmroberta, config=config_xlmroberta)
247
- tokenizer_xlmroberta.pad_token = tokenizer_xlmroberta.eos_token
248
- model_xlmroberta.eval()
249
- model_xlmroberta.cpu()
250
- st.session_state["model_xlmroberta"] = model_xlmroberta
251
- st.session_state["tokenizer_xlmroberta"] = tokenizer_xlmroberta
252
- st.session_state["nice_name_xlmroberta"] = nice_name_xlmroberta
 
253
 
254
- prog += 10
255
- my_bar.progress(prog)
 
 
 
 
 
 
 
 
 
 
 
256
  placeholder_lm_name.empty()
257
- my_bar.empty()
 
 
258
 
259
  def main():
260
- if "GPT-2" not in st.session_state:
261
- init_lms()
262
- sent = st.text_input('Enter a sentence:', value="")
263
 
264
- ### LMs we are trying:
265
- if sent != '':
266
  st.markdown(f"**Input Sentence**: {sent}")
267
  results = {}
268
 
269
  with st.spinner('Running with GPT-2 LM...'):
270
  ## GPT-2 LM (original LM-critic)
 
 
271
  is_good, score, counter_example, return_string_GPT2 = gpt2_critic(sent, st.session_state['model_gpt2'], st.session_state['tokenizer_gpt2'])
272
  st.markdown("**Results with GPT-2 LM:**")
273
  st.write('\n'.join(return_string_GPT2))
@@ -275,6 +263,8 @@ def main():
275
 
276
  with st.spinner('Running with OPT LM...'):
277
  ## OPT LM
 
 
278
  is_good, score, counter_example, return_string_OPT = gpt2_critic(sent, st.session_state['model_opt'], st.session_state['tokenizer_opt'])
279
  st.markdown("**Results with OPT LM:**")
280
  st.write('\n'.join(return_string_OPT))
@@ -282,6 +272,8 @@ def main():
282
 
283
  with st.spinner('Running with GPT NEO LM...'):
284
  ## GPT NEO
 
 
285
  is_good, score, counter_example, return_string_GPTNEO = gpt2_critic(sent, st.session_state['model_gptneo'], st.session_state['tokenizer_gptneo'])
286
  st.markdown("**Results with GPT NEO LM:**")
287
  st.write('\n'.join(return_string_GPTNEO))
@@ -289,6 +281,8 @@ def main():
289
 
290
  with st.spinner('Running with RoBERTa LM...'):
291
  ## RoBERTa
 
 
292
  is_good, score, counter_example, return_string_RoBERTa = gpt2_critic(sent, st.session_state['model_roberta'], st.session_state['tokenizer_roberta'])
293
  st.markdown("**Results with RoBERTa LM:**")
294
  st.write('\n'.join(return_string_RoBERTa))
@@ -296,6 +290,8 @@ def main():
296
 
297
  with st.spinner('Running with BART LM...'):
298
  ## BART
 
 
299
  is_good, score, counter_example, return_string_BART = gpt2_critic(sent, st.session_state['model_bart'], st.session_state['tokenizer_bart'])
300
  st.markdown("**Results with BART LM:**")
301
  st.write('\n'.join(return_string_BART))
@@ -303,6 +299,8 @@ def main():
303
 
304
  with st.spinner('Running with XLM RoBERTa LM...'):
305
  ## XLM RoBERTa
 
 
306
  is_good, score, counter_example, return_string_XLMRoBERTa = gpt2_critic(sent, st.session_state['model_xlmroberta'], st.session_state['tokenizer_xlmroberta'])
307
  st.markdown("**Results with XLM RoBERTa LM:**")
308
  st.write('\n'.join(return_string_XLMRoBERTa))
 
20
 
21
  st.subheader('Exploring Unsupervised Grammatical Error Correction with Transformer-Based Models')
22
  st.write('This live demonstration is adapted from the paper [LM-Critic: Language Models for Unsupervised Grammatical Error Correction](https://aclanthology.org/2021.emnlp-main.611.pdf) (EMNLP 2021) by Michihiro Yasunaga, Jure Leskovec, Percy Liang.')
23
# Intro text shown under the demo header: tells the user how to run the critic.
st.write('Enter any sentence in the text box, press submit, and see the grammatical scoring and judgement results outputted by LM-Critic using different LMs displayed below. Upon running this for the first time, it will initialize each LM.')
24
 
25
  def get_gpt2_loss(model, tokenizer, input_ids, attention_mask, labels):
26
  with torch.no_grad():
 
142
  counter_example = [sents[best_idx], float(logps[best_idx])]
143
  return is_good, float(logps[0]), counter_example, return_string
144
 
145
def gpt2():
    """Lazily initialize the GPT-2 LM (the original LM-Critic model).

    Shows a transient status message while loading, then stores the model,
    tokenizer, and display name under the ``*_gpt2`` keys of
    ``st.session_state`` for reuse across reruns.
    """
    ## GPT-2 LM (original LM-critic)
    status = st.empty()
    hub_id = 'gpt2'
    nice_name_gpt2 = "GPT-2"
    status.text(f"Initializing {nice_name_gpt2}...")
    tok = GPT2Tokenizer.from_pretrained(hub_id)
    # GPT-2 ships without a pad token; reuse EOS so batching works.
    tok.pad_token = tok.eos_token
    lm = GPT2LMHeadModel.from_pretrained(hub_id)
    lm.eval()
    lm.cpu()
    status.empty()
    st.session_state["model_gpt2"] = lm
    st.session_state["tokenizer_gpt2"] = tok
    st.session_state["nice_name_gpt2"] = nice_name_gpt2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
160
 
161
def opt():
    """Lazily initialize the OPT-350m causal LM into ``st.session_state``.

    Mirrors ``gpt2()``: transient status message, load model + tokenizer,
    eval/CPU mode, then cache under the ``*_opt`` session-state keys.
    """
    ## OPT LM
    banner = st.empty()
    hub_id = "facebook/opt-350m"
    nice_name_opt = "OPT"
    banner.text(f"Initializing {nice_name_opt}...")
    lm = OPTForCausalLM.from_pretrained(hub_id)
    # OPT reuses the GPT-2 tokenizer vocabulary.
    tok = GPT2Tokenizer.from_pretrained(hub_id)
    tok.pad_token = tok.eos_token  # reuse EOS as the padding token
    lm.eval()
    lm.cpu()
    banner.empty()
    st.session_state["model_opt"] = lm
    st.session_state["tokenizer_opt"] = tok
    st.session_state["nice_name_opt"] = nice_name_opt
176
 
177
def gpt_neo():
    """Lazily initialize GPT-Neo 1.3B and cache it in ``st.session_state``."""
    ## GPT NEO
    progress_slot = st.empty()
    checkpoint = "EleutherAI/gpt-neo-1.3B"
    nice_name_gptneo = "GPT NEO"
    progress_slot.text(f"Initializing {nice_name_gptneo}...")
    neo_model = GPTNeoForCausalLM.from_pretrained(checkpoint)
    neo_tokenizer = GPT2Tokenizer.from_pretrained(checkpoint)
    neo_tokenizer.pad_token = neo_tokenizer.eos_token  # reuse EOS for padding
    neo_model.eval()
    neo_model.cpu()
    progress_slot.empty()
    st.session_state["model_gptneo"] = neo_model
    st.session_state["tokenizer_gptneo"] = neo_tokenizer
    st.session_state["nice_name_gptneo"] = nice_name_gptneo
192
 
193
def roberta():
    """Lazily initialize RoBERTa as a causal LM and cache it in session state.

    RoBERTa is an encoder by default, so the config is flipped to decoder
    mode before loading it with the causal-LM head.
    """
    ## RoBERTa
    note = st.empty()
    hub_id = "roberta-base"
    nice_name_roberta = "RoBERTa"
    note.text(f"Initializing {nice_name_roberta}...")
    rob_tokenizer = RobertaTokenizer.from_pretrained(hub_id)
    rob_config = RobertaConfig.from_pretrained(hub_id)
    rob_config.is_decoder = True  # required for RobertaForCausalLM
    rob_model = RobertaForCausalLM.from_pretrained(hub_id, config=rob_config)
    rob_tokenizer.pad_token = rob_tokenizer.eos_token
    rob_model.eval()
    rob_model.cpu()
    note.empty()
    st.session_state["model_roberta"] = rob_model
    st.session_state["tokenizer_roberta"] = rob_tokenizer
    st.session_state["nice_name_roberta"] = nice_name_roberta
210
 
211
def bart():
    """Lazily initialize BART's decoder as a causal LM in ``st.session_state``."""
    ## BART
    slot = st.empty()
    hub_id = "facebook/bart-base"
    nice_name_bart = "BART"
    slot.text(f"Initializing {nice_name_bart}...")
    bart_tokenizer = BartTokenizer.from_pretrained(hub_id)
    # Cross-attention is disabled: only the decoder stack is used for scoring.
    bart_model = BartForCausalLM.from_pretrained(hub_id, add_cross_attention=False)
    assert bart_model.config.is_decoder, f"{bart_model.__class__} has to be configured as a decoder."
    bart_tokenizer.pad_token = bart_tokenizer.eos_token
    bart_model.eval()
    bart_model.cpu()
    slot.empty()
    st.session_state["model_bart"] = bart_model
    st.session_state["tokenizer_bart"] = bart_tokenizer
    st.session_state["nice_name_bart"] = nice_name_bart
227
 
228
def xlm_roberta():
    """Lazily initialize XLM-RoBERTa as a causal LM and cache it in session state.

    Same decoder-mode trick as ``roberta()``, applied to the multilingual
    ``xlm-roberta-base`` checkpoint.
    """
    ## XLM RoBERTa
    status_line = st.empty()
    hub_id = 'xlm-roberta-base'
    nice_name_xlmroberta = 'XLM RoBERTa'
    status_line.text(f"Initializing {nice_name_xlmroberta}...")
    xlm_tokenizer = XLMRobertaTokenizer.from_pretrained(hub_id)
    xlm_config = XLMRobertaConfig.from_pretrained(hub_id)
    xlm_config.is_decoder = True  # required for XLMRobertaForCausalLM
    xlm_model = XLMRobertaForCausalLM.from_pretrained(hub_id, config=xlm_config)
    xlm_tokenizer.pad_token = xlm_tokenizer.eos_token
    xlm_model.eval()
    xlm_model.cpu()
    status_line.empty()
    st.session_state["model_xlmroberta"] = xlm_model
    st.session_state["tokenizer_xlmroberta"] = xlm_tokenizer
    st.session_state["nice_name_xlmroberta"] = nice_name_xlmroberta
245
 
246
  def main():
247
+ form = st.form(key='my_form')
248
+ sent = form.text_input(label='Enter a sentence:', value="")
249
+ submit = form.form_submit_button(label='Submit')
250
 
251
+ if submit and sent != '':
 
252
  st.markdown(f"**Input Sentence**: {sent}")
253
  results = {}
254
 
255
  with st.spinner('Running with GPT-2 LM...'):
256
  ## GPT-2 LM (original LM-critic)
257
+ if "nice_name_gpt2" not in st.session_state:
258
+ gpt2()
259
  is_good, score, counter_example, return_string_GPT2 = gpt2_critic(sent, st.session_state['model_gpt2'], st.session_state['tokenizer_gpt2'])
260
  st.markdown("**Results with GPT-2 LM:**")
261
  st.write('\n'.join(return_string_GPT2))
 
263
 
264
  with st.spinner('Running with OPT LM...'):
265
  ## OPT LM
266
+ if "nice_name_opt" not in st.session_state:
267
+ opt()
268
  is_good, score, counter_example, return_string_OPT = gpt2_critic(sent, st.session_state['model_opt'], st.session_state['tokenizer_opt'])
269
  st.markdown("**Results with OPT LM:**")
270
  st.write('\n'.join(return_string_OPT))
 
272
 
273
  with st.spinner('Running with GPT NEO LM...'):
274
  ## GPT NEO
275
+ if "nice_name_gptneo" not in st.session_state:
276
+ gpt_neo()
277
  is_good, score, counter_example, return_string_GPTNEO = gpt2_critic(sent, st.session_state['model_gptneo'], st.session_state['tokenizer_gptneo'])
278
  st.markdown("**Results with GPT NEO LM:**")
279
  st.write('\n'.join(return_string_GPTNEO))
 
281
 
282
  with st.spinner('Running with RoBERTa LM...'):
283
  ## RoBERTa
284
+ if "nice_name_roberta" not in st.session_state:
285
+ roberta()
286
  is_good, score, counter_example, return_string_RoBERTa = gpt2_critic(sent, st.session_state['model_roberta'], st.session_state['tokenizer_roberta'])
287
  st.markdown("**Results with RoBERTa LM:**")
288
  st.write('\n'.join(return_string_RoBERTa))
 
290
 
291
  with st.spinner('Running with BART LM...'):
292
  ## BART
293
+ if "nice_name_bart" not in st.session_state:
294
+ bart()
295
  is_good, score, counter_example, return_string_BART = gpt2_critic(sent, st.session_state['model_bart'], st.session_state['tokenizer_bart'])
296
  st.markdown("**Results with BART LM:**")
297
  st.write('\n'.join(return_string_BART))
 
299
 
300
  with st.spinner('Running with XLM RoBERTa LM...'):
301
  ## XLM RoBERTa
302
+ if "nice_name_xlmroberta" not in st.session_state:
303
+ xlm_roberta()
304
  is_good, score, counter_example, return_string_XLMRoBERTa = gpt2_critic(sent, st.session_state['model_xlmroberta'], st.session_state['tokenizer_xlmroberta'])
305
  st.markdown("**Results with XLM RoBERTa LM:**")
306
  st.write('\n'.join(return_string_XLMRoBERTa))