Spaces:
Runtime error
Runtime error
Olivia Figueira
committed on
Commit
•
3c050d3
1
Parent(s):
8167d9f
Fixed reloading issue with state init
Browse files- critic/critic.py +89 -83
critic/critic.py
CHANGED
@@ -142,109 +142,114 @@ def gpt2_critic(sent, model, tokenizer, verbose=1, cuda=False, fp16=True, seed='
|
|
142 |
counter_example = [sents[best_idx], float(logps[best_idx])]
|
143 |
return is_good, float(logps[0]), counter_example, return_string
|
144 |
|
145 |
-
@st.cache(suppress_st_warning=True)
|
146 |
def init_lms():
|
147 |
placeholder_lm_name = st.empty()
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
|
|
163 |
|
164 |
prog += 10
|
165 |
my_bar.progress(prog)
|
166 |
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
|
|
179 |
|
180 |
prog += 10
|
181 |
my_bar.progress(prog)
|
182 |
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
|
|
195 |
|
196 |
prog += 10
|
197 |
my_bar.progress(prog)
|
198 |
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
|
212 |
-
|
|
|
213 |
|
214 |
prog += 10
|
215 |
my_bar.progress(prog)
|
216 |
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
-
|
221 |
-
|
222 |
-
|
223 |
-
|
224 |
-
|
225 |
-
|
226 |
-
|
227 |
-
|
228 |
-
|
229 |
-
|
|
|
230 |
|
231 |
prog += 10
|
232 |
my_bar.progress(prog)
|
233 |
|
234 |
-
|
235 |
-
|
236 |
-
|
237 |
-
|
238 |
-
|
239 |
-
|
240 |
-
|
241 |
-
|
242 |
-
|
243 |
-
|
244 |
-
|
245 |
-
|
246 |
-
|
247 |
-
|
|
|
248 |
|
249 |
prog += 10
|
250 |
my_bar.progress(prog)
|
@@ -252,7 +257,8 @@ def init_lms():
|
|
252 |
my_bar.empty()
|
253 |
|
254 |
def main():
|
255 |
-
|
|
|
256 |
sent = st.text_input('Enter a sentence:', value="")
|
257 |
|
258 |
### LMs we are trying:
|
@@ -311,4 +317,4 @@ def main():
|
|
311 |
st.write("Input another sentence!")
|
312 |
|
313 |
if __name__ == '__main__':
|
314 |
-
main()
|
|
|
142 |
counter_example = [sents[best_idx], float(logps[best_idx])]
|
143 |
return is_good, float(logps[0]), counter_example, return_string
|
144 |
|
|
|
145 |
def init_lms():
    """Load all candidate language models once and cache them in Streamlit session state.

    For each of six LMs (GPT-2, OPT, GPT-Neo, RoBERTa, BART, XLM-RoBERTa) this
    stores three entries in ``st.session_state``: ``model_<key>``,
    ``tokenizer_<key>`` and ``nice_name_<key>`` (human-readable display name).
    Each section is guarded by a ``"nice_name_<key>" not in st.session_state``
    check so Streamlit reruns do not reload models that are already cached.
    A progress bar advances by 10 after each section and is cleared at the end.

    All models are put in eval mode and moved to CPU. No return value.
    """
    placeholder_lm_name = st.empty()
    prog = 0
    my_bar = st.progress(prog)

    if "nice_name_gpt2" not in st.session_state:
        ## GPT-2 LM (original LM-critic)
        model_name_gpt2 = 'gpt2'
        nice_name_gpt2 = "GPT-2"
        placeholder_lm_name.text(f"Initializing {nice_name_gpt2}...")
        tokenizer_gpt2 = GPT2Tokenizer.from_pretrained(model_name_gpt2)
        tokenizer_gpt2.pad_token = tokenizer_gpt2.eos_token
        model_gpt2 = GPT2LMHeadModel.from_pretrained(model_name_gpt2)
        model_gpt2.eval()
        model_gpt2.cpu()
        st.session_state["model_gpt2"] = model_gpt2
        st.session_state["tokenizer_gpt2"] = tokenizer_gpt2
        st.session_state["nice_name_gpt2"] = nice_name_gpt2

    prog += 10
    my_bar.progress(prog)

    if "nice_name_opt" not in st.session_state:
        ## OPT LM
        model_name_opt = "facebook/opt-350m"
        nice_name_opt = "OPT"
        placeholder_lm_name.text(f"Initializing {nice_name_opt}...")
        model_opt = OPTForCausalLM.from_pretrained(model_name_opt)
        # OPT uses the GPT-2 BPE tokenizer vocabulary.
        tokenizer_opt = GPT2Tokenizer.from_pretrained(model_name_opt)
        tokenizer_opt.pad_token = tokenizer_opt.eos_token
        model_opt.eval()
        model_opt.cpu()
        st.session_state["model_opt"] = model_opt
        st.session_state["tokenizer_opt"] = tokenizer_opt
        st.session_state["nice_name_opt"] = nice_name_opt

    prog += 10
    my_bar.progress(prog)

    if "nice_name_gptneo" not in st.session_state:
        ## GPT NEO
        model_name_gptneo = "EleutherAI/gpt-neo-1.3B"
        nice_name_gptneo = "GPT NEO"
        placeholder_lm_name.text(f"Initializing {nice_name_gptneo}...")
        model_gptneo = GPTNeoForCausalLM.from_pretrained(model_name_gptneo)
        tokenizer_gptneo = GPT2Tokenizer.from_pretrained(model_name_gptneo)
        tokenizer_gptneo.pad_token = tokenizer_gptneo.eos_token
        model_gptneo.eval()
        model_gptneo.cpu()
        st.session_state["model_gptneo"] = model_gptneo
        st.session_state["tokenizer_gptneo"] = tokenizer_gptneo
        st.session_state["nice_name_gptneo"] = nice_name_gptneo

    prog += 10
    my_bar.progress(prog)

    if "nice_name_roberta" not in st.session_state:
        ## RoBERTa
        model_name_roberta = "roberta-base"
        nice_name_roberta = "RoBERTa"
        placeholder_lm_name.text(f"Initializing {nice_name_roberta}...")
        tokenizer_roberta = RobertaTokenizer.from_pretrained(model_name_roberta)
        # RoBERTa is an encoder; flag it as a decoder so it can be scored causally.
        config_roberta = RobertaConfig.from_pretrained(model_name_roberta)
        config_roberta.is_decoder = True
        model_roberta = RobertaForCausalLM.from_pretrained(model_name_roberta, config=config_roberta)
        tokenizer_roberta.pad_token = tokenizer_roberta.eos_token
        model_roberta.eval()
        model_roberta.cpu()
        # BUG FIX: previously stored model_gptneo under this key, so selecting
        # RoBERTa silently ran GPT-Neo and the loaded RoBERTa was discarded.
        st.session_state["model_roberta"] = model_roberta
        st.session_state["tokenizer_roberta"] = tokenizer_roberta
        st.session_state["nice_name_roberta"] = nice_name_roberta

    prog += 10
    my_bar.progress(prog)

    if "nice_name_bart" not in st.session_state:
        ## BART
        model_name_bart = "facebook/bart-base"
        nice_name_bart = "BART"
        placeholder_lm_name.text(f"Initializing {nice_name_bart}...")
        tokenizer_bart = BartTokenizer.from_pretrained(model_name_bart)
        # Decoder-only use of BART: drop cross-attention to the (absent) encoder.
        model_bart = BartForCausalLM.from_pretrained(model_name_bart, add_cross_attention=False)
        assert model_bart.config.is_decoder, f"{model_bart.__class__} has to be configured as a decoder."
        tokenizer_bart.pad_token = tokenizer_bart.eos_token
        model_bart.eval()
        model_bart.cpu()
        st.session_state["model_bart"] = model_bart
        st.session_state["tokenizer_bart"] = tokenizer_bart
        st.session_state["nice_name_bart"] = nice_name_bart

    prog += 10
    my_bar.progress(prog)

    if "nice_name_xlmroberta" not in st.session_state:
        ## XLM RoBERTa
        model_name_xlmroberta = 'xlm-roberta-base'
        nice_name_xlmroberta = 'XLM RoBERTa'
        placeholder_lm_name.text(f"Initializing {nice_name_xlmroberta}...")
        tokenizer_xlmroberta = XLMRobertaTokenizer.from_pretrained(model_name_xlmroberta)
        config_xlmroberta = XLMRobertaConfig.from_pretrained(model_name_xlmroberta)
        config_xlmroberta.is_decoder = True
        model_xlmroberta = XLMRobertaForCausalLM.from_pretrained(model_name_xlmroberta, config=config_xlmroberta)
        tokenizer_xlmroberta.pad_token = tokenizer_xlmroberta.eos_token
        model_xlmroberta.eval()
        model_xlmroberta.cpu()
        st.session_state["model_xlmroberta"] = model_xlmroberta
        st.session_state["tokenizer_xlmroberta"] = tokenizer_xlmroberta
        st.session_state["nice_name_xlmroberta"] = nice_name_xlmroberta

    prog += 10
    my_bar.progress(prog)

    # NOTE(review): six +10 steps end at 60, not 100 — the bar never fills
    # before being cleared. Harmless, but confirm whether 100 was intended.
    my_bar.empty()
|
258 |
|
259 |
def main():
|
260 |
+
if "GPT-2" not in st.session_state:
|
261 |
+
init_lms()
|
262 |
sent = st.text_input('Enter a sentence:', value="")
|
263 |
|
264 |
### LMs we are trying:
|
|
|
317 |
st.write("Input another sentence!")
|
318 |
|
319 |
if __name__ == '__main__':
|
320 |
+
main()
|