Spaces:
Runtime error
Runtime error
Olivia Figueira
commited on
Commit
•
8167d9f
1
Parent(s):
37d028c
Fix runtime memory issues by caching and using session state
Browse files- critic/critic.py +128 -63
critic/critic.py
CHANGED
@@ -16,6 +16,11 @@ nltk.download('punkt')
|
|
16 |
sys.path.insert(0, '.')
|
17 |
from critic.perturbations import get_local_neighbors_char_level, get_local_neighbors_word_level
|
18 |
from utils.spacy_tokenizer import spacy_tokenize_gec
|
|
|
|
|
|
|
|
|
|
|
19 |
|
20 |
def get_gpt2_loss(model, tokenizer, input_ids, attention_mask, labels):
|
21 |
with torch.no_grad():
|
@@ -137,10 +142,117 @@ def gpt2_critic(sent, model, tokenizer, verbose=1, cuda=False, fp16=True, seed='
|
|
137 |
counter_example = [sents[best_idx], float(logps[best_idx])]
|
138 |
return is_good, float(logps[0]), counter_example, return_string
|
139 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
140 |
|
141 |
def main():
|
142 |
-
|
143 |
-
st.subheader('Exploring Unsupervised Grammatical Error Correction with Transformer-Based Models')
|
144 |
sent = st.text_input('Enter a sentence:', value="")
|
145 |
|
146 |
### LMs we are trying:
|
@@ -150,92 +262,45 @@ def main():
|
|
150 |
|
151 |
with st.spinner('Running with GPT-2 LM...'):
|
152 |
## GPT-2 LM (original LM-critic)
|
153 |
-
|
154 |
-
nice_name = "GPT-2"
|
155 |
-
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
|
156 |
-
tokenizer.pad_token = tokenizer.eos_token
|
157 |
-
model = GPT2LMHeadModel.from_pretrained(model_name)
|
158 |
-
model.eval()
|
159 |
-
model.cpu()
|
160 |
-
is_good, score, counter_example, return_string_GPT2 = gpt2_critic(sent, model, tokenizer)
|
161 |
st.markdown("**Results with GPT-2 LM:**")
|
162 |
st.write('\n'.join(return_string_GPT2))
|
163 |
-
results[
|
164 |
|
165 |
with st.spinner('Running with OPT LM...'):
|
166 |
## OPT LM
|
167 |
-
|
168 |
-
nice_name = "OPT"
|
169 |
-
model = OPTForCausalLM.from_pretrained("facebook/opt-350m")
|
170 |
-
tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-350m")
|
171 |
-
tokenizer.pad_token = tokenizer.eos_token
|
172 |
-
model.eval()
|
173 |
-
model.cpu()
|
174 |
-
is_good, score, counter_example, return_string_OPT = gpt2_critic(sent, model, tokenizer)
|
175 |
st.markdown("**Results with OPT LM:**")
|
176 |
st.write('\n'.join(return_string_OPT))
|
177 |
-
results[
|
178 |
|
179 |
with st.spinner('Running with GPT NEO LM...'):
|
180 |
## GPT NEO
|
181 |
-
|
182 |
-
nice_name = "GPT NEO"
|
183 |
-
model = GPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-1.3B")
|
184 |
-
tokenizer = GPT2Tokenizer.from_pretrained("EleutherAI/gpt-neo-1.3B")
|
185 |
-
tokenizer.pad_token = tokenizer.eos_token
|
186 |
-
model.eval()
|
187 |
-
model.cpu()
|
188 |
-
is_good, score, counter_example, return_string_GPTNEO = gpt2_critic(sent, model, tokenizer)
|
189 |
st.markdown("**Results with GPT NEO LM:**")
|
190 |
st.write('\n'.join(return_string_GPTNEO))
|
191 |
-
results[
|
192 |
|
193 |
with st.spinner('Running with RoBERTa LM...'):
|
194 |
## RoBERTa
|
195 |
-
|
196 |
-
nice_name = "RoBERTa"
|
197 |
-
tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
|
198 |
-
config = RobertaConfig.from_pretrained("roberta-base")
|
199 |
-
config.is_decoder = True
|
200 |
-
model = RobertaForCausalLM.from_pretrained("roberta-base", config=config)
|
201 |
-
tokenizer.pad_token = tokenizer.eos_token
|
202 |
-
model.eval()
|
203 |
-
model.cpu()
|
204 |
-
is_good, score, counter_example, return_string_RoBERTa = gpt2_critic(sent, model, tokenizer)
|
205 |
st.markdown("**Results with RoBERTa LM:**")
|
206 |
st.write('\n'.join(return_string_RoBERTa))
|
207 |
-
results[
|
208 |
|
209 |
with st.spinner('Running with BART LM...'):
|
210 |
-
##
|
211 |
-
|
212 |
-
nice_name = "BART"
|
213 |
-
tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
|
214 |
-
model = BartForCausalLM.from_pretrained("facebook/bart-base", add_cross_attention=False)
|
215 |
-
assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder."
|
216 |
-
tokenizer.pad_token = tokenizer.eos_token
|
217 |
-
model.eval()
|
218 |
-
model.cpu()
|
219 |
-
is_good, score, counter_example, return_string_BART = gpt2_critic(sent, model, tokenizer)
|
220 |
st.markdown("**Results with BART LM:**")
|
221 |
st.write('\n'.join(return_string_BART))
|
222 |
-
results[
|
223 |
|
224 |
with st.spinner('Running with XLM RoBERTa LM...'):
|
225 |
## XLM RoBERTa
|
226 |
-
|
227 |
-
nice_name = 'XLM RoBERTa'
|
228 |
-
tokenizer = XLMRobertaTokenizer.from_pretrained("xlm-roberta-base")
|
229 |
-
config = XLMRobertaConfig.from_pretrained("xlm-roberta-base")
|
230 |
-
config.is_decoder = True
|
231 |
-
model = XLMRobertaForCausalLM.from_pretrained("xlm-roberta-base", config=config)
|
232 |
-
tokenizer.pad_token = tokenizer.eos_token
|
233 |
-
model.eval()
|
234 |
-
model.cpu()
|
235 |
-
is_good, score, counter_example, return_string_XLMRoBERTa = gpt2_critic(sent, model, tokenizer)
|
236 |
st.markdown("**Results with XLM RoBERTa LM:**")
|
237 |
st.write('\n'.join(return_string_XLMRoBERTa))
|
238 |
-
results[
|
239 |
|
240 |
df = pd.DataFrame.from_dict(results,
|
241 |
orient = 'index',
|
@@ -243,7 +308,7 @@ def main():
|
|
243 |
st.markdown("**Tabular summary of results:**")
|
244 |
st.table(df)
|
245 |
|
246 |
-
st.write("
|
247 |
|
248 |
if __name__ == '__main__':
|
249 |
main()
|
|
|
16 |
sys.path.insert(0, '.')
|
17 |
from critic.perturbations import get_local_neighbors_char_level, get_local_neighbors_word_level
|
18 |
from utils.spacy_tokenizer import spacy_tokenize_gec
|
19 |
+
import streamlit as st
|
20 |
+
|
21 |
+
st.subheader('Exploring Unsupervised Grammatical Error Correction with Transformer-Based Models')
|
22 |
+
st.write('This live demonstration is adapted from the paper [LM-Critic: Language Models for Unsupervised Grammatical Error Correction](https://aclanthology.org/2021.emnlp-main.611.pdf) (EMNLP 2021) by Michihiro Yasunaga, Jure Leskovec, Percy Liang.')
|
23 |
+
st.write('The below demo first loads several LMs that we use in the LM-Critic. You will be prompted to enter a sentence which will then be scored by each of the LM-Critics using different LMs.')
|
24 |
|
25 |
def get_gpt2_loss(model, tokenizer, input_ids, attention_mask, labels):
|
26 |
with torch.no_grad():
|
|
|
142 |
counter_example = [sents[best_idx], float(logps[best_idx])]
|
143 |
return is_good, float(logps[0]), counter_example, return_string
|
144 |
|
145 |
+
@st.cache(suppress_st_warning=True)
|
146 |
+
def init_lms():
|
147 |
+
placeholder_lm_name = st.empty()
|
148 |
+
my_bar = st.progress(10)
|
149 |
+
prog = 10
|
150 |
+
|
151 |
+
## GPT-2 LM (original LM-critic)
|
152 |
+
model_name_gpt2 = 'gpt2'
|
153 |
+
nice_name_gpt2 = "GPT-2"
|
154 |
+
placeholder_lm_name.text(f"Initializing {nice_name_gpt2}...")
|
155 |
+
tokenizer_gpt2 = GPT2Tokenizer.from_pretrained(model_name_gpt2)
|
156 |
+
tokenizer_gpt2.pad_token = tokenizer_gpt2.eos_token
|
157 |
+
model_gpt2 = GPT2LMHeadModel.from_pretrained(model_name_gpt2)
|
158 |
+
model_gpt2.eval()
|
159 |
+
model_gpt2.cpu()
|
160 |
+
st.session_state["model_gpt2"] = model_gpt2
|
161 |
+
st.session_state["tokenizer_gpt2"] = tokenizer_gpt2
|
162 |
+
st.session_state["nice_name_gpt2"] = nice_name_gpt2
|
163 |
+
|
164 |
+
prog += 10
|
165 |
+
my_bar.progress(prog)
|
166 |
+
|
167 |
+
## OPT LM
|
168 |
+
model_name_opt = "facebook/opt-350m"
|
169 |
+
nice_name_opt = "OPT"
|
170 |
+
placeholder_lm_name.text(f"Initializing {nice_name_opt}...")
|
171 |
+
model_opt = OPTForCausalLM.from_pretrained(model_name_opt)
|
172 |
+
tokenizer_opt = GPT2Tokenizer.from_pretrained(model_name_opt)
|
173 |
+
tokenizer_opt.pad_token = tokenizer_opt.eos_token
|
174 |
+
model_opt.eval()
|
175 |
+
model_opt.cpu()
|
176 |
+
st.session_state["model_opt"] = model_opt
|
177 |
+
st.session_state["tokenizer_opt"] = tokenizer_opt
|
178 |
+
st.session_state["nice_name_opt"] = nice_name_opt
|
179 |
+
|
180 |
+
prog += 10
|
181 |
+
my_bar.progress(prog)
|
182 |
+
|
183 |
+
## GPT NEO
|
184 |
+
model_name_gptneo = "EleutherAI/gpt-neo-1.3B"
|
185 |
+
nice_name_gptneo = "GPT NEO"
|
186 |
+
placeholder_lm_name.text(f"Initializing {nice_name_gptneo}...")
|
187 |
+
model_gptneo = GPTNeoForCausalLM.from_pretrained(model_name_gptneo)
|
188 |
+
tokenizer_gptneo = GPT2Tokenizer.from_pretrained(model_name_gptneo)
|
189 |
+
tokenizer_gptneo.pad_token = tokenizer_gptneo.eos_token
|
190 |
+
model_gptneo.eval()
|
191 |
+
model_gptneo.cpu()
|
192 |
+
st.session_state["model_gptneo"] = model_gptneo
|
193 |
+
st.session_state["tokenizer_gptneo"] = tokenizer_gptneo
|
194 |
+
st.session_state["nice_name_gptneo"] = nice_name_gptneo
|
195 |
+
|
196 |
+
prog += 10
|
197 |
+
my_bar.progress(prog)
|
198 |
+
|
199 |
+
## RoBERTa
|
200 |
+
model_name_roberta = "roberta-base"
|
201 |
+
nice_name_roberta = "RoBERTa"
|
202 |
+
placeholder_lm_name.text(f"Initializing {nice_name_roberta}...")
|
203 |
+
tokenizer_roberta = RobertaTokenizer.from_pretrained(model_name_roberta)
|
204 |
+
config_roberta = RobertaConfig.from_pretrained(model_name_roberta)
|
205 |
+
config_roberta.is_decoder = True
|
206 |
+
model_roberta = RobertaForCausalLM.from_pretrained(model_name_roberta, config=config_roberta)
|
207 |
+
tokenizer_roberta.pad_token = tokenizer_roberta.eos_token
|
208 |
+
model_roberta.eval()
|
209 |
+
model_roberta.cpu()
|
210 |
+
st.session_state["model_roberta"] = model_gptneo
|
211 |
+
st.session_state["tokenizer_roberta"] = tokenizer_roberta
|
212 |
+
st.session_state["nice_name_roberta"] = nice_name_roberta
|
213 |
+
|
214 |
+
prog += 10
|
215 |
+
my_bar.progress(prog)
|
216 |
+
|
217 |
+
## BART
|
218 |
+
model_name_bart = "facebook/bart-base"
|
219 |
+
nice_name_bart = "BART"
|
220 |
+
placeholder_lm_name.text(f"Initializing {nice_name_bart}...")
|
221 |
+
tokenizer_bart = BartTokenizer.from_pretrained(model_name_bart)
|
222 |
+
model_bart = BartForCausalLM.from_pretrained(model_name_bart, add_cross_attention=False)
|
223 |
+
assert model_bart.config.is_decoder, f"{model_bart.__class__} has to be configured as a decoder."
|
224 |
+
tokenizer_bart.pad_token = tokenizer_bart.eos_token
|
225 |
+
model_bart.eval()
|
226 |
+
model_bart.cpu()
|
227 |
+
st.session_state["model_bart"] = model_bart
|
228 |
+
st.session_state["tokenizer_bart"] = tokenizer_bart
|
229 |
+
st.session_state["nice_name_bart"] = nice_name_bart
|
230 |
+
|
231 |
+
prog += 10
|
232 |
+
my_bar.progress(prog)
|
233 |
+
|
234 |
+
## XLM RoBERTa
|
235 |
+
model_name_xlmroberta = 'xlm-roberta-base'
|
236 |
+
nice_name_xlmroberta = 'XLM RoBERTa'
|
237 |
+
placeholder_lm_name.text(f"Initializing {nice_name_xlmroberta}...")
|
238 |
+
tokenizer_xlmroberta = XLMRobertaTokenizer.from_pretrained(model_name_xlmroberta)
|
239 |
+
config_xlmroberta = XLMRobertaConfig.from_pretrained(model_name_xlmroberta)
|
240 |
+
config_xlmroberta.is_decoder = True
|
241 |
+
model_xlmroberta = XLMRobertaForCausalLM.from_pretrained(model_name_xlmroberta, config=config_xlmroberta)
|
242 |
+
tokenizer_xlmroberta.pad_token = tokenizer_xlmroberta.eos_token
|
243 |
+
model_xlmroberta.eval()
|
244 |
+
model_xlmroberta.cpu()
|
245 |
+
st.session_state["model_xlmroberta"] = model_xlmroberta
|
246 |
+
st.session_state["tokenizer_xlmroberta"] = tokenizer_xlmroberta
|
247 |
+
st.session_state["nice_name_xlmroberta"] = nice_name_xlmroberta
|
248 |
+
|
249 |
+
prog += 10
|
250 |
+
my_bar.progress(prog)
|
251 |
+
placeholder_lm_name.empty()
|
252 |
+
my_bar.empty()
|
253 |
|
254 |
def main():
|
255 |
+
init_lms()
|
|
|
256 |
sent = st.text_input('Enter a sentence:', value="")
|
257 |
|
258 |
### LMs we are trying:
|
|
|
262 |
|
263 |
with st.spinner('Running with GPT-2 LM...'):
|
264 |
## GPT-2 LM (original LM-critic)
|
265 |
+
is_good, score, counter_example, return_string_GPT2 = gpt2_critic(sent, st.session_state['model_gpt2'], st.session_state['tokenizer_gpt2'])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
266 |
st.markdown("**Results with GPT-2 LM:**")
|
267 |
st.write('\n'.join(return_string_GPT2))
|
268 |
+
results[st.session_state['nice_name_gpt2']] = ["Good" if is_good else "Bad", str(round(score, 3)), "N/A" if not counter_example else str(counter_example[0]), "N/A" if not counter_example else str(round(counter_example[1], 3))]
|
269 |
|
270 |
with st.spinner('Running with OPT LM...'):
|
271 |
## OPT LM
|
272 |
+
is_good, score, counter_example, return_string_OPT = gpt2_critic(sent, st.session_state['model_opt'], st.session_state['tokenizer_opt'])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
273 |
st.markdown("**Results with OPT LM:**")
|
274 |
st.write('\n'.join(return_string_OPT))
|
275 |
+
results[st.session_state['nice_name_opt']] = ["Good" if is_good else "Bad", str(round(score, 3)), "N/A" if not counter_example else str(counter_example[0]), "N/A" if not counter_example else str(round(counter_example[1], 3))]
|
276 |
|
277 |
with st.spinner('Running with GPT NEO LM...'):
|
278 |
## GPT NEO
|
279 |
+
is_good, score, counter_example, return_string_GPTNEO = gpt2_critic(sent, st.session_state['model_gptneo'], st.session_state['tokenizer_gptneo'])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
280 |
st.markdown("**Results with GPT NEO LM:**")
|
281 |
st.write('\n'.join(return_string_GPTNEO))
|
282 |
+
results[st.session_state['nice_name_gptneo']] = ["Good" if is_good else "Bad", str(round(score, 3)), "N/A" if not counter_example else str(counter_example[0]), "N/A" if not counter_example else str(round(counter_example[1], 3))]
|
283 |
|
284 |
with st.spinner('Running with RoBERTa LM...'):
|
285 |
## RoBERTa
|
286 |
+
is_good, score, counter_example, return_string_RoBERTa = gpt2_critic(sent, st.session_state['model_roberta'], st.session_state['tokenizer_roberta'])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
287 |
st.markdown("**Results with RoBERTa LM:**")
|
288 |
st.write('\n'.join(return_string_RoBERTa))
|
289 |
+
results[st.session_state['nice_name_roberta']] = ["Good" if is_good else "Bad", str(round(score, 3)), "N/A" if not counter_example else str(counter_example[0]), "N/A" if not counter_example else str(round(counter_example[1], 3))]
|
290 |
|
291 |
with st.spinner('Running with BART LM...'):
|
292 |
+
## BART
|
293 |
+
is_good, score, counter_example, return_string_BART = gpt2_critic(sent, st.session_state['model_bart'], st.session_state['tokenizer_bart'])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
294 |
st.markdown("**Results with BART LM:**")
|
295 |
st.write('\n'.join(return_string_BART))
|
296 |
+
results[st.session_state['nice_name_bart']] = ["Good" if is_good else "Bad", str(round(score, 3)), "N/A" if not counter_example else str(counter_example[0]), "N/A" if not counter_example else str(round(counter_example[1], 3))]
|
297 |
|
298 |
with st.spinner('Running with XLM RoBERTa LM...'):
|
299 |
## XLM RoBERTa
|
300 |
+
is_good, score, counter_example, return_string_XLMRoBERTa = gpt2_critic(sent, st.session_state['model_xlmroberta'], st.session_state['tokenizer_xlmroberta'])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
301 |
st.markdown("**Results with XLM RoBERTa LM:**")
|
302 |
st.write('\n'.join(return_string_XLMRoBERTa))
|
303 |
+
results[st.session_state['nice_name_xlmroberta']] = ["Good" if is_good else "Bad", str(round(score, 3)), "N/A" if not counter_example else str(counter_example[0]), "N/A" if not counter_example else str(round(counter_example[1], 3))]
|
304 |
|
305 |
df = pd.DataFrame.from_dict(results,
|
306 |
orient = 'index',
|
|
|
308 |
st.markdown("**Tabular summary of results:**")
|
309 |
st.table(df)
|
310 |
|
311 |
+
st.write("Input another sentence!")
|
312 |
|
313 |
if __name__ == '__main__':
|
314 |
main()
|