Olivia Figueira committed on
Commit
3c050d3
1 Parent(s): 8167d9f

Fixed reloading issue with state init

Browse files
Files changed (1) hide show
  1. critic/critic.py +89 -83
critic/critic.py CHANGED
@@ -142,109 +142,114 @@ def gpt2_critic(sent, model, tokenizer, verbose=1, cuda=False, fp16=True, seed='
142
  counter_example = [sents[best_idx], float(logps[best_idx])]
143
  return is_good, float(logps[0]), counter_example, return_string
144
 
145
- @st.cache(suppress_st_warning=True)
146
  def init_lms():
147
  placeholder_lm_name = st.empty()
148
- my_bar = st.progress(10)
149
- prog = 10
150
-
151
- ## GPT-2 LM (original LM-critic)
152
- model_name_gpt2 = 'gpt2'
153
- nice_name_gpt2 = "GPT-2"
154
- placeholder_lm_name.text(f"Initializing {nice_name_gpt2}...")
155
- tokenizer_gpt2 = GPT2Tokenizer.from_pretrained(model_name_gpt2)
156
- tokenizer_gpt2.pad_token = tokenizer_gpt2.eos_token
157
- model_gpt2 = GPT2LMHeadModel.from_pretrained(model_name_gpt2)
158
- model_gpt2.eval()
159
- model_gpt2.cpu()
160
- st.session_state["model_gpt2"] = model_gpt2
161
- st.session_state["tokenizer_gpt2"] = tokenizer_gpt2
162
- st.session_state["nice_name_gpt2"] = nice_name_gpt2
 
163
 
164
  prog += 10
165
  my_bar.progress(prog)
166
 
167
- ## OPT LM
168
- model_name_opt = "facebook/opt-350m"
169
- nice_name_opt = "OPT"
170
- placeholder_lm_name.text(f"Initializing {nice_name_opt}...")
171
- model_opt = OPTForCausalLM.from_pretrained(model_name_opt)
172
- tokenizer_opt = GPT2Tokenizer.from_pretrained(model_name_opt)
173
- tokenizer_opt.pad_token = tokenizer_opt.eos_token
174
- model_opt.eval()
175
- model_opt.cpu()
176
- st.session_state["model_opt"] = model_opt
177
- st.session_state["tokenizer_opt"] = tokenizer_opt
178
- st.session_state["nice_name_opt"] = nice_name_opt
 
179
 
180
  prog += 10
181
  my_bar.progress(prog)
182
 
183
- ## GPT NEO
184
- model_name_gptneo = "EleutherAI/gpt-neo-1.3B"
185
- nice_name_gptneo = "GPT NEO"
186
- placeholder_lm_name.text(f"Initializing {nice_name_gptneo}...")
187
- model_gptneo = GPTNeoForCausalLM.from_pretrained(model_name_gptneo)
188
- tokenizer_gptneo = GPT2Tokenizer.from_pretrained(model_name_gptneo)
189
- tokenizer_gptneo.pad_token = tokenizer_gptneo.eos_token
190
- model_gptneo.eval()
191
- model_gptneo.cpu()
192
- st.session_state["model_gptneo"] = model_gptneo
193
- st.session_state["tokenizer_gptneo"] = tokenizer_gptneo
194
- st.session_state["nice_name_gptneo"] = nice_name_gptneo
 
195
 
196
  prog += 10
197
  my_bar.progress(prog)
198
 
199
- ## RoBERTa
200
- model_name_roberta = "roberta-base"
201
- nice_name_roberta = "RoBERTa"
202
- placeholder_lm_name.text(f"Initializing {nice_name_roberta}...")
203
- tokenizer_roberta = RobertaTokenizer.from_pretrained(model_name_roberta)
204
- config_roberta = RobertaConfig.from_pretrained(model_name_roberta)
205
- config_roberta.is_decoder = True
206
- model_roberta = RobertaForCausalLM.from_pretrained(model_name_roberta, config=config_roberta)
207
- tokenizer_roberta.pad_token = tokenizer_roberta.eos_token
208
- model_roberta.eval()
209
- model_roberta.cpu()
210
- st.session_state["model_roberta"] = model_gptneo
211
- st.session_state["tokenizer_roberta"] = tokenizer_roberta
212
- st.session_state["nice_name_roberta"] = nice_name_roberta
 
213
 
214
  prog += 10
215
  my_bar.progress(prog)
216
 
217
- ## BART
218
- model_name_bart = "facebook/bart-base"
219
- nice_name_bart = "BART"
220
- placeholder_lm_name.text(f"Initializing {nice_name_bart}...")
221
- tokenizer_bart = BartTokenizer.from_pretrained(model_name_bart)
222
- model_bart = BartForCausalLM.from_pretrained(model_name_bart, add_cross_attention=False)
223
- assert model_bart.config.is_decoder, f"{model_bart.__class__} has to be configured as a decoder."
224
- tokenizer_bart.pad_token = tokenizer_bart.eos_token
225
- model_bart.eval()
226
- model_bart.cpu()
227
- st.session_state["model_bart"] = model_bart
228
- st.session_state["tokenizer_bart"] = tokenizer_bart
229
- st.session_state["nice_name_bart"] = nice_name_bart
 
230
 
231
  prog += 10
232
  my_bar.progress(prog)
233
 
234
- ## XLM RoBERTa
235
- model_name_xlmroberta = 'xlm-roberta-base'
236
- nice_name_xlmroberta = 'XLM RoBERTa'
237
- placeholder_lm_name.text(f"Initializing {nice_name_xlmroberta}...")
238
- tokenizer_xlmroberta = XLMRobertaTokenizer.from_pretrained(model_name_xlmroberta)
239
- config_xlmroberta = XLMRobertaConfig.from_pretrained(model_name_xlmroberta)
240
- config_xlmroberta.is_decoder = True
241
- model_xlmroberta = XLMRobertaForCausalLM.from_pretrained(model_name_xlmroberta, config=config_xlmroberta)
242
- tokenizer_xlmroberta.pad_token = tokenizer_xlmroberta.eos_token
243
- model_xlmroberta.eval()
244
- model_xlmroberta.cpu()
245
- st.session_state["model_xlmroberta"] = model_xlmroberta
246
- st.session_state["tokenizer_xlmroberta"] = tokenizer_xlmroberta
247
- st.session_state["nice_name_xlmroberta"] = nice_name_xlmroberta
 
248
 
249
  prog += 10
250
  my_bar.progress(prog)
@@ -252,7 +257,8 @@ def init_lms():
252
  my_bar.empty()
253
 
254
  def main():
255
- init_lms()
 
256
  sent = st.text_input('Enter a sentence:', value="")
257
 
258
  ### LMs we are trying:
@@ -311,4 +317,4 @@ def main():
311
  st.write("Input another sentence!")
312
 
313
  if __name__ == '__main__':
314
- main()
 
142
  counter_example = [sents[best_idx], float(logps[best_idx])]
143
  return is_good, float(logps[0]), counter_example, return_string
144
 
 
145
  def init_lms():
146
  placeholder_lm_name = st.empty()
147
+ prog = 0
148
+ my_bar = st.progress(prog)
149
+
150
+ if "nice_name_gpt2" not in st.session_state:
151
+ ## GPT-2 LM (original LM-critic)
152
+ model_name_gpt2 = 'gpt2'
153
+ nice_name_gpt2 = "GPT-2"
154
+ placeholder_lm_name.text(f"Initializing {nice_name_gpt2}...")
155
+ tokenizer_gpt2 = GPT2Tokenizer.from_pretrained(model_name_gpt2)
156
+ tokenizer_gpt2.pad_token = tokenizer_gpt2.eos_token
157
+ model_gpt2 = GPT2LMHeadModel.from_pretrained(model_name_gpt2)
158
+ model_gpt2.eval()
159
+ model_gpt2.cpu()
160
+ st.session_state["model_gpt2"] = model_gpt2
161
+ st.session_state["tokenizer_gpt2"] = tokenizer_gpt2
162
+ st.session_state["nice_name_gpt2"] = nice_name_gpt2
163
 
164
  prog += 10
165
  my_bar.progress(prog)
166
 
167
+ if "nice_name_opt" not in st.session_state:
168
+ ## OPT LM
169
+ model_name_opt = "facebook/opt-350m"
170
+ nice_name_opt = "OPT"
171
+ placeholder_lm_name.text(f"Initializing {nice_name_opt}...")
172
+ model_opt = OPTForCausalLM.from_pretrained(model_name_opt)
173
+ tokenizer_opt = GPT2Tokenizer.from_pretrained(model_name_opt)
174
+ tokenizer_opt.pad_token = tokenizer_opt.eos_token
175
+ model_opt.eval()
176
+ model_opt.cpu()
177
+ st.session_state["model_opt"] = model_opt
178
+ st.session_state["tokenizer_opt"] = tokenizer_opt
179
+ st.session_state["nice_name_opt"] = nice_name_opt
180
 
181
  prog += 10
182
  my_bar.progress(prog)
183
 
184
+ if "nice_name_gptneo" not in st.session_state:
185
+ ## GPT NEO
186
+ model_name_gptneo = "EleutherAI/gpt-neo-1.3B"
187
+ nice_name_gptneo = "GPT NEO"
188
+ placeholder_lm_name.text(f"Initializing {nice_name_gptneo}...")
189
+ model_gptneo = GPTNeoForCausalLM.from_pretrained(model_name_gptneo)
190
+ tokenizer_gptneo = GPT2Tokenizer.from_pretrained(model_name_gptneo)
191
+ tokenizer_gptneo.pad_token = tokenizer_gptneo.eos_token
192
+ model_gptneo.eval()
193
+ model_gptneo.cpu()
194
+ st.session_state["model_gptneo"] = model_gptneo
195
+ st.session_state["tokenizer_gptneo"] = tokenizer_gptneo
196
+ st.session_state["nice_name_gptneo"] = nice_name_gptneo
197
 
198
  prog += 10
199
  my_bar.progress(prog)
200
 
201
+ if "nice_name_roberta" not in st.session_state:
202
+ ## RoBERTa
203
+ model_name_roberta = "roberta-base"
204
+ nice_name_roberta = "RoBERTa"
205
+ placeholder_lm_name.text(f"Initializing {nice_name_roberta}...")
206
+ tokenizer_roberta = RobertaTokenizer.from_pretrained(model_name_roberta)
207
+ config_roberta = RobertaConfig.from_pretrained(model_name_roberta)
208
+ config_roberta.is_decoder = True
209
+ model_roberta = RobertaForCausalLM.from_pretrained(model_name_roberta, config=config_roberta)
210
+ tokenizer_roberta.pad_token = tokenizer_roberta.eos_token
211
+ model_roberta.eval()
212
+ model_roberta.cpu()
213
+ st.session_state["model_roberta"] = model_gptneo
214
+ st.session_state["tokenizer_roberta"] = tokenizer_roberta
215
+ st.session_state["nice_name_roberta"] = nice_name_roberta
216
 
217
  prog += 10
218
  my_bar.progress(prog)
219
 
220
+ if "nice_name_bart" not in st.session_state:
221
+ ## BART
222
+ model_name_bart = "facebook/bart-base"
223
+ nice_name_bart = "BART"
224
+ placeholder_lm_name.text(f"Initializing {nice_name_bart}...")
225
+ tokenizer_bart = BartTokenizer.from_pretrained(model_name_bart)
226
+ model_bart = BartForCausalLM.from_pretrained(model_name_bart, add_cross_attention=False)
227
+ assert model_bart.config.is_decoder, f"{model_bart.__class__} has to be configured as a decoder."
228
+ tokenizer_bart.pad_token = tokenizer_bart.eos_token
229
+ model_bart.eval()
230
+ model_bart.cpu()
231
+ st.session_state["model_bart"] = model_bart
232
+ st.session_state["tokenizer_bart"] = tokenizer_bart
233
+ st.session_state["nice_name_bart"] = nice_name_bart
234
 
235
  prog += 10
236
  my_bar.progress(prog)
237
 
238
+ if "nice_name_xlmroberta" not in st.session_state:
239
+ ## XLM RoBERTa
240
+ model_name_xlmroberta = 'xlm-roberta-base'
241
+ nice_name_xlmroberta = 'XLM RoBERTa'
242
+ placeholder_lm_name.text(f"Initializing {nice_name_xlmroberta}...")
243
+ tokenizer_xlmroberta = XLMRobertaTokenizer.from_pretrained(model_name_xlmroberta)
244
+ config_xlmroberta = XLMRobertaConfig.from_pretrained(model_name_xlmroberta)
245
+ config_xlmroberta.is_decoder = True
246
+ model_xlmroberta = XLMRobertaForCausalLM.from_pretrained(model_name_xlmroberta, config=config_xlmroberta)
247
+ tokenizer_xlmroberta.pad_token = tokenizer_xlmroberta.eos_token
248
+ model_xlmroberta.eval()
249
+ model_xlmroberta.cpu()
250
+ st.session_state["model_xlmroberta"] = model_xlmroberta
251
+ st.session_state["tokenizer_xlmroberta"] = tokenizer_xlmroberta
252
+ st.session_state["nice_name_xlmroberta"] = nice_name_xlmroberta
253
 
254
  prog += 10
255
  my_bar.progress(prog)
 
257
  my_bar.empty()
258
 
259
  def main():
260
+ if "GPT-2" not in st.session_state:
261
+ init_lms()
262
  sent = st.text_input('Enter a sentence:', value="")
263
 
264
  ### LMs we are trying:
 
317
  st.write("Input another sentence!")
318
 
319
  if __name__ == '__main__':
320
+ main()