Olivia Figueira committed
Commit 8167d9f
1 Parent(s): 37d028c

Fix runtime memory issues by caching and using session state

Files changed (1):
  1. critic/critic.py +128 -63
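
The change follows the standard Streamlit recipe for keeping heavyweight objects out of the per-rerun path: load every model once inside a function wrapped in `@st.cache`, park the results in `st.session_state`, and have `main()` read the models back from session state instead of re-instantiating them on each rerun. A minimal sketch of that pattern using the legacy (pre-1.18) Streamlit caching API the commit relies on; `load_lm` and the `"lm"` session key are illustrative names, not identifiers from this repo:

import streamlit as st
from transformers import GPT2LMHeadModel, GPT2Tokenizer

@st.cache(suppress_st_warning=True, allow_output_mutation=True)
def load_lm(model_name):
    # The body executes only on the first run per model_name; on later
    # reruns Streamlit replays the cached return value instead of
    # re-downloading and re-instantiating the weights.
    tokenizer = GPT2Tokenizer.from_pretrained(model_name)
    tokenizer.pad_token = tokenizer.eos_token
    model = GPT2LMHeadModel.from_pretrained(model_name)
    model.eval()  # inference only, no gradients needed
    return model, tokenizer

# Keep the loaded pair in session state so the rest of the script can
# fetch it by key, as main() below does with its six LMs.
if "lm" not in st.session_state:
    st.session_state["lm"] = load_lm("gpt2")
model, tokenizer = st.session_state["lm"]

The diff itself follows.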
critic/critic.py CHANGED
@@ -16,6 +16,11 @@ nltk.download('punkt')
 sys.path.insert(0, '.')
 from critic.perturbations import get_local_neighbors_char_level, get_local_neighbors_word_level
 from utils.spacy_tokenizer import spacy_tokenize_gec
+import streamlit as st
+
+st.subheader('Exploring Unsupervised Grammatical Error Correction with Transformer-Based Models')
+st.write('This live demonstration is adapted from the paper [LM-Critic: Language Models for Unsupervised Grammatical Error Correction](https://aclanthology.org/2021.emnlp-main.611.pdf) (EMNLP 2021) by Michihiro Yasunaga, Jure Leskovec, Percy Liang.')
+st.write('The demo below first loads the several LMs used in the LM-Critic. You will then be prompted to enter a sentence, which is scored by an LM-Critic built on each of the LMs.')
 
 def get_gpt2_loss(model, tokenizer, input_ids, attention_mask, labels):
     with torch.no_grad():
@@ -137,10 +142,117 @@ def gpt2_critic(sent, model, tokenizer, verbose=1, cuda=False, fp16=True, seed='
     counter_example = [sents[best_idx], float(logps[best_idx])]
     return is_good, float(logps[0]), counter_example, return_string
 
+@st.cache(suppress_st_warning=True)
+def init_lms():
+    placeholder_lm_name = st.empty()
+    my_bar = st.progress(10)
+    prog = 10
+
+    ## GPT-2 LM (original LM-critic)
+    model_name_gpt2 = 'gpt2'
+    nice_name_gpt2 = "GPT-2"
+    placeholder_lm_name.text(f"Initializing {nice_name_gpt2}...")
+    tokenizer_gpt2 = GPT2Tokenizer.from_pretrained(model_name_gpt2)
+    tokenizer_gpt2.pad_token = tokenizer_gpt2.eos_token
+    model_gpt2 = GPT2LMHeadModel.from_pretrained(model_name_gpt2)
+    model_gpt2.eval()
+    model_gpt2.cpu()
+    st.session_state["model_gpt2"] = model_gpt2
+    st.session_state["tokenizer_gpt2"] = tokenizer_gpt2
+    st.session_state["nice_name_gpt2"] = nice_name_gpt2
+
+    prog += 10
+    my_bar.progress(prog)
+
+    ## OPT LM
+    model_name_opt = "facebook/opt-350m"
+    nice_name_opt = "OPT"
+    placeholder_lm_name.text(f"Initializing {nice_name_opt}...")
+    model_opt = OPTForCausalLM.from_pretrained(model_name_opt)
+    tokenizer_opt = GPT2Tokenizer.from_pretrained(model_name_opt)
+    tokenizer_opt.pad_token = tokenizer_opt.eos_token
+    model_opt.eval()
+    model_opt.cpu()
+    st.session_state["model_opt"] = model_opt
+    st.session_state["tokenizer_opt"] = tokenizer_opt
+    st.session_state["nice_name_opt"] = nice_name_opt
+
+    prog += 10
+    my_bar.progress(prog)
+
+    ## GPT NEO
+    model_name_gptneo = "EleutherAI/gpt-neo-1.3B"
+    nice_name_gptneo = "GPT NEO"
+    placeholder_lm_name.text(f"Initializing {nice_name_gptneo}...")
+    model_gptneo = GPTNeoForCausalLM.from_pretrained(model_name_gptneo)
+    tokenizer_gptneo = GPT2Tokenizer.from_pretrained(model_name_gptneo)
+    tokenizer_gptneo.pad_token = tokenizer_gptneo.eos_token
+    model_gptneo.eval()
+    model_gptneo.cpu()
+    st.session_state["model_gptneo"] = model_gptneo
+    st.session_state["tokenizer_gptneo"] = tokenizer_gptneo
+    st.session_state["nice_name_gptneo"] = nice_name_gptneo
+
+    prog += 10
+    my_bar.progress(prog)
+
+    ## RoBERTa
+    model_name_roberta = "roberta-base"
+    nice_name_roberta = "RoBERTa"
+    placeholder_lm_name.text(f"Initializing {nice_name_roberta}...")
+    tokenizer_roberta = RobertaTokenizer.from_pretrained(model_name_roberta)
+    config_roberta = RobertaConfig.from_pretrained(model_name_roberta)
+    config_roberta.is_decoder = True
+    model_roberta = RobertaForCausalLM.from_pretrained(model_name_roberta, config=config_roberta)
+    tokenizer_roberta.pad_token = tokenizer_roberta.eos_token
+    model_roberta.eval()
+    model_roberta.cpu()
+    st.session_state["model_roberta"] = model_roberta
+    st.session_state["tokenizer_roberta"] = tokenizer_roberta
+    st.session_state["nice_name_roberta"] = nice_name_roberta
+
+    prog += 10
+    my_bar.progress(prog)
+
+    ## BART
+    model_name_bart = "facebook/bart-base"
+    nice_name_bart = "BART"
+    placeholder_lm_name.text(f"Initializing {nice_name_bart}...")
+    tokenizer_bart = BartTokenizer.from_pretrained(model_name_bart)
+    model_bart = BartForCausalLM.from_pretrained(model_name_bart, add_cross_attention=False)
+    assert model_bart.config.is_decoder, f"{model_bart.__class__} has to be configured as a decoder."
+    tokenizer_bart.pad_token = tokenizer_bart.eos_token
+    model_bart.eval()
+    model_bart.cpu()
+    st.session_state["model_bart"] = model_bart
+    st.session_state["tokenizer_bart"] = tokenizer_bart
+    st.session_state["nice_name_bart"] = nice_name_bart
+
+    prog += 10
+    my_bar.progress(prog)
+
+    ## XLM RoBERTa
+    model_name_xlmroberta = 'xlm-roberta-base'
+    nice_name_xlmroberta = 'XLM RoBERTa'
+    placeholder_lm_name.text(f"Initializing {nice_name_xlmroberta}...")
+    tokenizer_xlmroberta = XLMRobertaTokenizer.from_pretrained(model_name_xlmroberta)
+    config_xlmroberta = XLMRobertaConfig.from_pretrained(model_name_xlmroberta)
+    config_xlmroberta.is_decoder = True
+    model_xlmroberta = XLMRobertaForCausalLM.from_pretrained(model_name_xlmroberta, config=config_xlmroberta)
+    tokenizer_xlmroberta.pad_token = tokenizer_xlmroberta.eos_token
+    model_xlmroberta.eval()
+    model_xlmroberta.cpu()
+    st.session_state["model_xlmroberta"] = model_xlmroberta
+    st.session_state["tokenizer_xlmroberta"] = tokenizer_xlmroberta
+    st.session_state["nice_name_xlmroberta"] = nice_name_xlmroberta
+
+    prog += 10
+    my_bar.progress(prog)
+    placeholder_lm_name.empty()
+    my_bar.empty()
 
 def main():
-    import streamlit as st
-    st.subheader('Exploring Unsupervised Grammatical Error Correction with Transformer-Based Models')
+    init_lms()
     sent = st.text_input('Enter a sentence:', value="")
 
     ### LMs we are trying:
@@ -150,92 +262,45 @@ def main():
 
     with st.spinner('Running with GPT-2 LM...'):
         ## GPT-2 LM (original LM-critic)
-        model_name = 'gpt2'
-        nice_name = "GPT-2"
-        tokenizer = GPT2Tokenizer.from_pretrained(model_name)
-        tokenizer.pad_token = tokenizer.eos_token
-        model = GPT2LMHeadModel.from_pretrained(model_name)
-        model.eval()
-        model.cpu()
-        is_good, score, counter_example, return_string_GPT2 = gpt2_critic(sent, model, tokenizer)
+        is_good, score, counter_example, return_string_GPT2 = gpt2_critic(sent, st.session_state['model_gpt2'], st.session_state['tokenizer_gpt2'])
         st.markdown("**Results with GPT-2 LM:**")
         st.write('\n'.join(return_string_GPT2))
-        results[nice_name] = ["Good" if is_good else "Bad", str(round(score, 3)), "N/A" if not counter_example else str(counter_example[0]), "N/A" if not counter_example else str(round(counter_example[1], 3))]
+        results[st.session_state['nice_name_gpt2']] = ["Good" if is_good else "Bad", str(round(score, 3)), "N/A" if not counter_example else str(counter_example[0]), "N/A" if not counter_example else str(round(counter_example[1], 3))]
 
     with st.spinner('Running with OPT LM...'):
         ## OPT LM
-        model_name = "facebook/opt-350m"
-        nice_name = "OPT"
-        model = OPTForCausalLM.from_pretrained("facebook/opt-350m")
-        tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-350m")
-        tokenizer.pad_token = tokenizer.eos_token
-        model.eval()
-        model.cpu()
-        is_good, score, counter_example, return_string_OPT = gpt2_critic(sent, model, tokenizer)
+        is_good, score, counter_example, return_string_OPT = gpt2_critic(sent, st.session_state['model_opt'], st.session_state['tokenizer_opt'])
         st.markdown("**Results with OPT LM:**")
         st.write('\n'.join(return_string_OPT))
-        results[nice_name] = ["Good" if is_good else "Bad", str(round(score, 3)), "N/A" if not counter_example else str(counter_example[0]), "N/A" if not counter_example else str(round(counter_example[1], 3))]
+        results[st.session_state['nice_name_opt']] = ["Good" if is_good else "Bad", str(round(score, 3)), "N/A" if not counter_example else str(counter_example[0]), "N/A" if not counter_example else str(round(counter_example[1], 3))]
 
     with st.spinner('Running with GPT NEO LM...'):
         ## GPT NEO
-        model_name = "EleutherAI/gpt-neo-1.3B"
-        nice_name = "GPT NEO"
-        model = GPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-1.3B")
-        tokenizer = GPT2Tokenizer.from_pretrained("EleutherAI/gpt-neo-1.3B")
-        tokenizer.pad_token = tokenizer.eos_token
-        model.eval()
-        model.cpu()
-        is_good, score, counter_example, return_string_GPTNEO = gpt2_critic(sent, model, tokenizer)
+        is_good, score, counter_example, return_string_GPTNEO = gpt2_critic(sent, st.session_state['model_gptneo'], st.session_state['tokenizer_gptneo'])
        st.markdown("**Results with GPT NEO LM:**")
         st.write('\n'.join(return_string_GPTNEO))
-        results[nice_name] = ["Good" if is_good else "Bad", str(round(score, 3)), "N/A" if not counter_example else str(counter_example[0]), "N/A" if not counter_example else str(round(counter_example[1], 3))]
+        results[st.session_state['nice_name_gptneo']] = ["Good" if is_good else "Bad", str(round(score, 3)), "N/A" if not counter_example else str(counter_example[0]), "N/A" if not counter_example else str(round(counter_example[1], 3))]
 
     with st.spinner('Running with RoBERTa LM...'):
         ## RoBERTa
-        model_name = "roberta-base"
-        nice_name = "RoBERTa"
-        tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
-        config = RobertaConfig.from_pretrained("roberta-base")
-        config.is_decoder = True
-        model = RobertaForCausalLM.from_pretrained("roberta-base", config=config)
-        tokenizer.pad_token = tokenizer.eos_token
-        model.eval()
-        model.cpu()
-        is_good, score, counter_example, return_string_RoBERTa = gpt2_critic(sent, model, tokenizer)
+        is_good, score, counter_example, return_string_RoBERTa = gpt2_critic(sent, st.session_state['model_roberta'], st.session_state['tokenizer_roberta'])
         st.markdown("**Results with RoBERTa LM:**")
         st.write('\n'.join(return_string_RoBERTa))
-        results[nice_name] = ["Good" if is_good else "Bad", str(round(score, 3)), "N/A" if not counter_example else str(counter_example[0]), "N/A" if not counter_example else str(round(counter_example[1], 3))]
+        results[st.session_state['nice_name_roberta']] = ["Good" if is_good else "Bad", str(round(score, 3)), "N/A" if not counter_example else str(counter_example[0]), "N/A" if not counter_example else str(round(counter_example[1], 3))]
 
     with st.spinner('Running with BART LM...'):
-        ## RoBERTa
-        model_name = "facebook/bart-base"
-        nice_name = "BART"
-        tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
-        model = BartForCausalLM.from_pretrained("facebook/bart-base", add_cross_attention=False)
-        assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder."
-        tokenizer.pad_token = tokenizer.eos_token
-        model.eval()
-        model.cpu()
-        is_good, score, counter_example, return_string_BART = gpt2_critic(sent, model, tokenizer)
+        ## BART
+        is_good, score, counter_example, return_string_BART = gpt2_critic(sent, st.session_state['model_bart'], st.session_state['tokenizer_bart'])
         st.markdown("**Results with BART LM:**")
         st.write('\n'.join(return_string_BART))
-        results[nice_name] = ["Good" if is_good else "Bad", str(round(score, 3)), "N/A" if not counter_example else str(counter_example[0]), "N/A" if not counter_example else str(round(counter_example[1], 3))]
+        results[st.session_state['nice_name_bart']] = ["Good" if is_good else "Bad", str(round(score, 3)), "N/A" if not counter_example else str(counter_example[0]), "N/A" if not counter_example else str(round(counter_example[1], 3))]
 
     with st.spinner('Running with XLM RoBERTa LM...'):
         ## XLM RoBERTa
-        model_name = 'xlm-roberta-base'
-        nice_name = 'XLM RoBERTa'
-        tokenizer = XLMRobertaTokenizer.from_pretrained("xlm-roberta-base")
-        config = XLMRobertaConfig.from_pretrained("xlm-roberta-base")
-        config.is_decoder = True
-        model = XLMRobertaForCausalLM.from_pretrained("xlm-roberta-base", config=config)
-        tokenizer.pad_token = tokenizer.eos_token
-        model.eval()
-        model.cpu()
-        is_good, score, counter_example, return_string_XLMRoBERTa = gpt2_critic(sent, model, tokenizer)
+        is_good, score, counter_example, return_string_XLMRoBERTa = gpt2_critic(sent, st.session_state['model_xlmroberta'], st.session_state['tokenizer_xlmroberta'])
         st.markdown("**Results with XLM RoBERTa LM:**")
         st.write('\n'.join(return_string_XLMRoBERTa))
-        results[nice_name] = ["Good" if is_good else "Bad", str(round(score, 3)), "N/A" if not counter_example else str(counter_example[0]), "N/A" if not counter_example else str(round(counter_example[1], 3))]
+        results[st.session_state['nice_name_xlmroberta']] = ["Good" if is_good else "Bad", str(round(score, 3)), "N/A" if not counter_example else str(counter_example[0]), "N/A" if not counter_example else str(round(counter_example[1], 3))]
 
     df = pd.DataFrame.from_dict(results,
                                 orient = 'index',
@@ -243,7 +308,7 @@ def main():
     st.markdown("**Tabular summary of results:**")
     st.table(df)
 
-    st.write("Done.")
+    st.write("Input another sentence!")
 
 if __name__ == '__main__':
     main()
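
One caveat worth flagging with `init_lms()` as committed: `@st.cache` is shared across all sessions of the server process, while `st.session_state` is per browser session. If a second session connects while the cache is warm, the memoized `init_lms()` returns without executing its body, so that session's `st.session_state` never receives the model keys and `main()` would raise a `KeyError`. A more defensive variant (a sketch only; `load_all_lms` is a hypothetical name, not part of this commit) returns the models from the cached function and assigns them to session state on every run:

import streamlit as st
from transformers import GPT2LMHeadModel, GPT2Tokenizer

@st.cache(allow_output_mutation=True)
def load_all_lms():
    # Heavy loading runs once per process; the returned dict is what is cached.
    lms = {}
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    tokenizer.pad_token = tokenizer.eos_token
    model = GPT2LMHeadModel.from_pretrained('gpt2')
    model.eval()
    model.cpu()
    lms['GPT-2'] = (model, tokenizer)
    # ... the remaining five LMs would be loaded the same way ...
    return lms

def main():
    # Cheap re-assignment runs in every session and on every rerun,
    # so session state is always populated, warm cache or not.
    for nice_name, (model, tokenizer) in load_all_lms().items():
        st.session_state[nice_name] = (model, tokenizer)

On Streamlit 1.18 and later, the same idea is spelled `@st.cache_resource`, which is designed for unhashable global resources like torch models and removes the need for `suppress_st_warning` and `allow_output_mutation`.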