Olivia Figueira committed on
Commit
b6e5241
1 Parent(s): 090a2ed

Upload code with streamlit addition

LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2021 Michihiro Yasunaga
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md CHANGED
@@ -5,7 +5,7 @@ colorFrom: pink
  colorTo: gray
  sdk: streamlit
  sdk_version: 1.2.0
- app_file: app.py
+ app_file: critic/critic.py
  pinned: false
  license: afl-3.0
  ---
critic/PIE/__pycache__/word_level_perturb.cpython-37.pyc ADDED
Binary file (9.64 kB)
 
critic/PIE/__pycache__/word_level_perturb.cpython-38.pyc ADDED
Binary file (7.8 kB)
 
critic/PIE/word_level_perturb.py ADDED
@@ -0,0 +1,237 @@
1
+ """
2
+ Word-level perturbation generator.
3
+
4
+ Originally by https://github.com/awasthiabhijeet/PIE/tree/master/errorify
5
+ """
6
+ import os
7
+ import math
8
+ import pickle
9
+ import random
10
+ import editdistance
11
+ from numpy.random import choice as npchoice
12
+ from collections import defaultdict
13
+
14
+
15
+ try:
16
+ dir_path = os.path.dirname(os.path.realpath(__file__))
17
+ except:
18
+ dir_path = '.'
19
+
20
+ VERBS = pickle.load(open(f'{dir_path}/verbs.p', 'rb'))
21
+ COMMON_INSERTS = set(pickle.load(open(f'{dir_path}/common_inserts.p', 'rb'))) #common inserts *to fix a sent*
22
+ COMMON_DELETES = pickle.load(open(f'{dir_path}/common_deletes.p','rb')) #common deletes *to fix a sent*
23
+ _COMMON_REPLACES = pickle.load(open(f'{dir_path}/common_replaces.p', 'rb')) #common replacements *to errorify a sent*
24
+
25
+
26
+
27
+ COMMON_REPLACES = {}
28
+ for src in _COMMON_REPLACES:
29
+ for tgt in _COMMON_REPLACES[src]:
30
+ if (src=="'re" and tgt=="are") or (tgt=="'re" and src=="are"):
31
+ continue
32
+ ED = editdistance.eval(tgt, src)
33
+ if ED > 2:
34
+ continue
35
+ longer = max(len(src), len(tgt))
36
+ if float(ED)/longer >= 0.5:
37
+ continue
38
+ if tgt not in COMMON_REPLACES:
39
+ COMMON_REPLACES[tgt] = {}
40
+ COMMON_REPLACES[tgt][src] = _COMMON_REPLACES[src][tgt]
41
+
42
+
43
+ VERBS_refine = defaultdict(list)
44
+ for src in VERBS:
45
+ for tgt in VERBS[src]:
46
+ ED = editdistance.eval(tgt, src)
47
+ if ED > 2:
48
+ continue
49
+ longer = max(len(src), len(tgt))
50
+ if float(ED)/longer >= 0.5:
51
+ continue
52
+ VERBS_refine[src].append(tgt)
53
+
54
+
55
+
56
+ class WordLevelPerturber_all:
57
+ def __init__(self, sentence: str):
58
+ self.original_sentence = sentence.rstrip()
59
+ self.sentence = self.original_sentence
60
+ self.tokenized = None
61
+ self.tokenize()
62
+
63
+ def tokenize(self):
64
+ self.tokenized = self.sentence.split()
65
+
66
+ def orig(self):
67
+ return self.original_sentence
68
+
69
+ def _insert(self):
70
+ """Insert a commonly deleted word."""
71
+ if len(self.tokenized) > 0:
72
+ insertable = list(range(len(self.tokenized)))
73
+ index = random.choice(insertable)
74
+ plist = list(COMMON_DELETES.values())
75
+ plistsum = sum(plist)
76
+ plist = [x / plistsum for x in plist]
77
+ # Choose a word
78
+ ins_word = npchoice(list(COMMON_DELETES.keys()), p=plist)
79
+ self.tokenized.insert(index,ins_word)
80
+ return ' '.join(self.tokenized)
81
+
82
+ def _mod_verb(self, redir=True):
83
+ if len(self.tokenized) > 0:
84
+ verbs = [i for i, w in enumerate(self.tokenized) if w in VERBS]
85
+ if not verbs:
86
+ if redir:
87
+ return self._replace(redir=False)
88
+ return self.sentence
89
+ index = random.choice(verbs)
90
+ word = self.tokenized[index]
91
+ if not VERBS[word]:
92
+ return self.sentence
93
+ repl = random.choice(VERBS[word])
94
+ self.tokenized[index] = repl
95
+ return ' '.join(self.tokenized)
96
+
97
+ def _delete(self):
98
+ """Delete a commonly inserted word."""
99
+ if len(self.tokenized) > 1:
100
+ toks_len = len(self.tokenized)
101
+ toks = self.tokenized
102
+ deletable = [i for i, w in enumerate(toks) if w in COMMON_INSERTS]
103
+ if not deletable:
104
+ return self.sentence
105
+ index = random.choice(deletable)
106
+ del self.tokenized[index]
107
+ return ' '.join(self.tokenized)
108
+
109
+ def _replace(self, redir=True):
110
+ if len(self.tokenized) > 0:
111
+ deletable = [i for i, w in enumerate(self.tokenized) if (w in COMMON_REPLACES)]
112
+ if not deletable:
113
+ if redir:
114
+ return self._mod_verb(redir=False)
115
+ return self.sentence
116
+ index = random.choice(deletable)
117
+ word = self.tokenized[index]
118
+ if not COMMON_REPLACES[word]:
119
+ return self.sentence
120
+ # Normalize probabilities
121
+ plist = list(COMMON_REPLACES[word].values())
122
+ plistsum = sum(plist)
123
+ plist = [x / plistsum for x in plist]
124
+ # Choose a word
125
+ repl = npchoice(list(COMMON_REPLACES[word].keys()), p=plist)
126
+ self.tokenized[index] = repl
127
+ return ' '.join(self.tokenized)
128
+
129
+ def perturb(self):
130
+ count = 1
131
+ orig_sent = self.sentence
132
+ for x in range(count):
133
+ perturb_probs = [.30,.30,.30,.10]
134
+ perturb_fun = npchoice([self._insert, self._mod_verb, self._replace, self._delete],p=perturb_probs)
135
+ self.sentence = perturb_fun()
136
+ self.tokenize()
137
+ res_sentence = self.sentence
138
+ self.sentence = self.original_sentence
139
+ self.tokenize()
140
+ return res_sentence
141
+
142
+
143
+ class WordLevelPerturber_refine:
144
+ def __init__(self, sentence: str):
145
+ self.original_sentence = sentence.rstrip()
146
+ self.sentence = self.original_sentence
147
+ self.tokenized = None
148
+ self.tokenize()
149
+
150
+ def tokenize(self):
151
+ self.tokenized = self.sentence.split()
152
+
153
+ def orig(self):
154
+ return self.original_sentence
155
+
156
+ def _insert(self):
157
+ """Insert a commonly deleted word."""
158
+ if len(self.tokenized) > 0:
159
+ insertable = list(range(len(self.tokenized)))
160
+ index = random.choice(insertable)
161
+ plist = list(COMMON_DELETES.values())
162
+ plistsum = sum(plist)
163
+ plist = [x / plistsum for x in plist]
164
+ # Choose a word
165
+ ins_word = npchoice(list(COMMON_DELETES.keys()), p=plist)
166
+ self.tokenized.insert(index,ins_word)
167
+ return ' '.join(self.tokenized)
168
+
169
+ def _mod_verb(self, redir=True):
170
+ if len(self.tokenized) > 0:
171
+ verbs = [i for i, w in enumerate(self.tokenized) if w in VERBS_refine]
172
+ if not verbs:
173
+ if redir:
174
+ return self._replace(redir=False)
175
+ return self.sentence
176
+ index = random.choice(verbs)
177
+ word = self.tokenized[index]
178
+ if not VERBS_refine[word]:
179
+ return self.sentence
180
+ repl = random.choice(VERBS_refine[word])
181
+ self.tokenized[index] = repl
182
+
183
+ return ' '.join(self.tokenized)
184
+
185
+ def _delete(self):
186
+ """Delete a commonly inserted word."""
187
+ if len(self.tokenized) > 1:
188
+ toks_len = len(self.tokenized)
189
+ toks = self.tokenized
190
+ deletable = [i for i, w in enumerate(toks) if (w in COMMON_INSERTS) and (i>0 and toks[i-1].lower() == toks[i].lower())]
191
+ if not deletable:
192
+ return self.sentence
193
+ index = random.choice(deletable)
194
+ del self.tokenized[index]
195
+ return ' '.join(self.tokenized)
196
+
197
+ def _replace(self, redir=True):
198
+ def _keep(i,w):
199
+ if w.lower() in {"not", "n't"}:
200
+ return True
201
+ return False
202
+
203
+ if len(self.tokenized) > 0:
204
+ deletable = [i for i, w in enumerate(self.tokenized) if (w in COMMON_REPLACES) and (not _keep(i,w))]
205
+ if not deletable:
206
+ if redir:
207
+ return self._mod_verb(redir=False)
208
+ return self.sentence
209
+
210
+ index = random.choice(deletable)
211
+ word = self.tokenized[index]
212
+ if not COMMON_REPLACES[word]:
213
+ return self.sentence
214
+
215
+ # Normalize probabilities
216
+ plist = list(COMMON_REPLACES[word].values())
217
+ plistsum = sum(plist)
218
+ plist = [x / plistsum for x in plist]
219
+
220
+ # Choose a word
221
+ repl = npchoice(list(COMMON_REPLACES[word].keys()), p=plist)
222
+ self.tokenized[index] = repl
223
+
224
+ return ' '.join(self.tokenized)
225
+
226
+ def perturb(self):
227
+ count = 1
228
+ orig_sent = self.sentence
229
+ for x in range(count):
230
+ perturb_probs = [.30,.30,.30,.10]
231
+ perturb_fun = npchoice([self._insert, self._mod_verb, self._replace, self._delete],p=perturb_probs)
232
+ self.sentence = perturb_fun()
233
+ self.tokenize()
234
+ res_sentence = self.sentence
235
+ self.sentence = self.original_sentence
236
+ self.tokenize()
237
+ return res_sentence
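
For reference, a minimal usage sketch of the two perturber classes above (not part of the commit). It assumes the pickled resources loaded at the top of this module (verbs.p, common_inserts.p, common_deletes.p, common_replaces.p) are present next to word_level_perturb.py; the example sentence is illustrative only.

    from critic.PIE.word_level_perturb import WordLevelPerturber_all, WordLevelPerturber_refine

    sent = "He go to school every day ."               # whitespace-tokenized input
    ptb = WordLevelPerturber_refine(sent)               # or WordLevelPerturber_all(sent)
    neighbors = {ptb.perturb() for _ in range(20)}      # each call returns one perturbed copy
    print(neighbors)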
critic/__init__.py ADDED
File without changes
critic/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (146 Bytes)
 
critic/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (142 Bytes)
 
critic/__pycache__/critic.cpython-37.pyc ADDED
Binary file (5.46 kB)
 
critic/__pycache__/critic.cpython-38.pyc ADDED
Binary file (4.26 kB)
 
critic/__pycache__/edit_dist_utils.cpython-38.pyc ADDED
Binary file (4.62 kB)
 
critic/__pycache__/perturbations.cpython-38.pyc ADDED
Binary file (4.55 kB)
 
critic/common_typo.json ADDED
The diff for this file is too large to render.
 
critic/critic.py ADDED
@@ -0,0 +1,157 @@
1
+ import sys
2
+ import torch
3
+ import random
4
+ import hashlib
5
+ import numpy as np
6
+ from tqdm import tqdm
7
+ from transformers import GPT2Tokenizer, GPT2Model, GPT2LMHeadModel
8
+ import nltk
9
+ nltk.download('punkt')
10
+
11
+ sys.path.insert(0, '.')
12
+ from critic.perturbations import get_local_neighbors_char_level, get_local_neighbors_word_level
13
+ from utils.spacy_tokenizer import spacy_tokenize_gec
14
+
15
+ model_name = 'gpt2'
16
+ tokenizer = GPT2Tokenizer.from_pretrained(model_name)
17
+ tokenizer.pad_token = tokenizer.eos_token
18
+ model = GPT2LMHeadModel.from_pretrained(model_name)
19
+ model.eval()
20
+ #model.cuda()
21
+ model.cpu()
22
+ print (f'Loaded {model_name}')
23
+
24
+
25
+ def get_gpt2_loss(input_ids, attention_mask, labels):
26
+ with torch.no_grad():
27
+ outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
28
+ lm_logits = outputs[1] #[bsize, seqlen, vocab]
29
+ if labels is not None:
30
+ shift_logits = lm_logits[..., :-1, :].contiguous()
31
+ shift_labels = labels[..., 1:].contiguous()
32
+ shift_mask = attention_mask[..., 1:].contiguous()
33
+ loss_fct = torch.nn.CrossEntropyLoss(reduction='none')
34
+ bsize, seqlen = input_ids.size()
35
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)).view(bsize, seqlen-1)
36
+ loss = (loss * shift_mask).sum(dim=1) #[bsize, ]
37
+ return loss
38
+
39
+
40
+ MAX_LENGTH = 66
41
+
42
+ def run_gpt2(sents, cuda=False, model_name=None):
43
+ assert isinstance(sents, list)
44
+ _sents = [tokenizer.bos_token + s for s in sents]
45
+ inputs = tokenizer(_sents, return_tensors="pt", padding=True)
46
+ if inputs['input_ids'].size(1) > MAX_LENGTH:
47
+ return None
48
+ if cuda:
49
+ inputs = {k: v.cuda() for k, v in inputs.items()}
50
+ loss = get_gpt2_loss(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'], labels=inputs['input_ids'])
51
+ logps = - loss.detach().cpu()
52
+ return logps
53
+
54
+
55
+ def gpt2_critic_char_level_only(sent, verbose=1, cuda=False, fp16=True, seed='auto', n_samples=100):
56
+ return_string = []
57
+ if seed == 'auto':
58
+ seed = int(hashlib.md5(sent.encode()).hexdigest(), 16) % (2**32) #Seed must be between 0 and 2**32 - 1
59
+ if verbose > 1:
60
+ print ('seed', seed)
61
+ np.random.seed(seed); random.seed(seed)
62
+ is_good = True
63
+ for _ in range(1):
64
+ sent_perturbations = get_local_neighbors_char_level(sent, max_n_samples=n_samples)
65
+ if verbose > 1:
66
+ print ("#sent_perturbations (char-level)", len(sent_perturbations))
67
+ return_string.append(f"#sent_perturbations (char-level) {len(sent_perturbations)}\n")
68
+ sents = [sent] + list(sent_perturbations)
69
+ if fp16:
70
+ with torch.cuda.amp.autocast():
71
+ logps = run_gpt2(sents, cuda)
72
+ else:
73
+ logps = run_gpt2(sents, cuda)
74
+ if logps is None:
75
+ if verbose:
76
+ print ('Invalid input. Maybe the sentence is too long.')
77
+ return_string.append('Invalid input. Maybe the sentence is too long.\n')
78
+ return None
79
+ best_idx = int(logps.argmax())
80
+ if best_idx != 0:
81
+ is_good = False
82
+ break
83
+ if verbose:
84
+ if is_good:
85
+ print ('Good! Your sentence log(p) = {:.3f}'.format(float(logps[0])))
86
+ return_string.append('Good! Your sentence log(p) = {:.3f}\n'.format(float(logps[0])))
87
+ else:
88
+ print ('Bad! Your sentence log(p) = {:.3f}'.format(float(logps[0])))
89
+ return_string.append('Bad! Your sentence log(p) = {:.3f}\n'.format(float(logps[0])))
90
+ print ('Neighbor sentence with highest log(p): {} (= {:.3f})'.format(sents[best_idx], float(logps[best_idx])))
91
+ return_string.append('Neighbor sentence with highest log(p): {} (= {:.3f})\n'.format(sents[best_idx], float(logps[best_idx])))
92
+ counter_example = None
93
+ if not is_good:
94
+ counter_example = [sents[best_idx], float(logps[best_idx])]
95
+ return is_good, float(logps[0]), counter_example
96
+
97
+
98
+ def gpt2_critic(sent, verbose=1, cuda=False, fp16=True, seed='auto', n_samples=100, word_level_mode='refine'):
99
+ return_string = []
100
+ if seed == 'auto':
101
+ seed = int(hashlib.md5(sent.encode()).hexdigest(), 16) % (2**32) #Seed must be between 0 and 2**32 - 1
102
+ if verbose > 1:
103
+ print ('seed', seed)
104
+ return_string.append(f'seed {seed}\n')
105
+ np.random.seed(seed); random.seed(seed)
106
+ sent_toked = spacy_tokenize_gec(sent)
107
+ is_good = True
108
+ for _ in range(1):
109
+ sent_perturbations_w, orig_sent = get_local_neighbors_word_level(sent_toked, max_n_samples=n_samples//2, mode=word_level_mode)
110
+ sent_perturbations_c = get_local_neighbors_char_level(orig_sent, max_n_samples=n_samples//2)
111
+ if verbose > 1:
112
+ print ("#sent_perturbations (char-level)", len(sent_perturbations_c))
113
+ return_string.append(f"#sent_perturbations (char-level) {len(sent_perturbations_c)}\n")
114
+ print ("#sent_perturbations (word-level)", len(sent_perturbations_w))
115
+ return_string.append(f"#sent_perturbations (word-level) {len(sent_perturbations_w)}\n")
116
+ sents = [orig_sent] + list(sent_perturbations_c.union(sent_perturbations_w))
117
+ if fp16:
118
+ with torch.cuda.amp.autocast():
119
+ logps = run_gpt2(sents, cuda)
120
+ else:
121
+ logps = run_gpt2(sents, cuda)
122
+ if logps is None:
123
+ if verbose:
124
+ print ('Invalid input. Maybe the sentence is too long.')
125
+ return_string.append('Invalid input. Maybe the sentence is too long.\n')
126
+ return None
127
+ best_idx = int(logps.argmax())
128
+ if best_idx != 0:
129
+ is_good = False
130
+ break
131
+ if verbose:
132
+ if is_good:
133
+ print ('Good! Your sentence log(p) = {:.3f}'.format(float(logps[0])))
134
+ return_string.append('Good! Your sentence log(p) = {:.3f}\n'.format(float(logps[0])))
135
+ else:
136
+ print ('Bad! Your sentence log(p) = {:.3f}'.format(float(logps[0])))
137
+ return_string.append('Bad! Your sentence log(p) = {:.3f}\n'.format(float(logps[0])))
138
+ print ('Neighbor sentence with highest log(p): {} (= {:.3f})'.format(sents[best_idx], float(logps[best_idx])))
139
+ return_string.append('Neighbor sentence with highest log(p): {} (= {:.3f})\n'.format(sents[best_idx], float(logps[best_idx])))
140
+ counter_example = None
141
+ if not is_good:
142
+ counter_example = [sents[best_idx], float(logps[best_idx])]
143
+ return is_good, float(logps[0]), counter_example, return_string
144
+
145
+
146
+ def main():
147
+ import streamlit as st
148
+ st.subheader('Exploring Unsupervised Grammatical Error Correction with Transformer-Based Models')
149
+ sent = st.text_input('Enter a sentence:', value="")
150
+ if sent != '':
151
+ st.markdown(f"**Sentence**: {sent}")
152
+ _,_,_,return_string = gpt2_critic(sent)
153
+ st.markdown("**Results:**")
154
+ st.write('\n'.join(return_string))
155
+
156
+ if __name__ == '__main__':
157
+ main()
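
A minimal sketch of calling the critic outside the Streamlit app (not part of the diff). It assumes the repo root is the working directory so that critic/ and utils/ are importable, as the sys.path.insert above implies; the example sentence is illustrative only.

    from critic.critic import gpt2_critic

    res = gpt2_critic("He go to school every day.")    # returns None if the input is too long
    if res is not None:
        is_good, logp, counter_example, log_lines = res
        print(is_good, logp)
        if counter_example is not None:
            print("higher-scoring neighbor:", counter_example[0])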
critic/edit_dist_utils.py ADDED
@@ -0,0 +1,139 @@
1
+ """
2
+ Edit distance utils...
3
+
4
+ Originally by https://worksheets.codalab.org/worksheets/0x8fc01c7fc2b742fdb29c05669f0ad7d2
5
+ """
6
+ from collections import defaultdict
7
+ import numpy as np
8
+ import random
9
+ import string
10
+ from itertools import permutations
11
+
12
+ def process_filetype(filetype):
13
+ insert = (filetype // 1000) % 2 == 1
14
+ delete = (filetype // 100) % 2 == 1
15
+ substitute = (filetype // 10) % 2 == 1
16
+ swap = filetype % 2 == 1
17
+ return insert, delete, substitute, swap
18
+
19
+ def get_all_edit_dist_one(word, filetype = 1111, sub_restrict = None):
20
+ """
21
+ Allowable edit_dist_one perturbations:
22
+ 1. Insert any lowercase characer at any position other than the start
23
+ 2. Delete any character other than the first one
24
+ 3. Substitute any lowercase character for any other lowercase letter other than the start
25
+ 4. Swap adjacent characters
26
+ We also include the original word. Filetype determines which of the allowable perturbations to use.
27
+ """
28
+ insert, delete, substitute, swap = process_filetype(filetype)
29
+ #last_mod_pos is last thing you could insert before
30
+ last_mod_pos = len(word) #- 1
31
+ ed1 = set()
32
+ if len(word) <= 2 or word[:1].isupper() or word[:1].isnumeric():
33
+ return ed1
34
+ for pos in range(1, last_mod_pos + 1): #can add letters at the end
35
+ if delete and pos < last_mod_pos:
36
+ deletion = word[:pos] + word[pos + 1:]
37
+ ed1.add(deletion)
38
+ if swap and pos < last_mod_pos - 1:
39
+ #swapping thing at pos with thing at pos + 1
40
+ swapped = word[:pos] + word[pos + 1] + word[pos] + word[pos + 2:]
41
+ ed1.add(swapped)
42
+ for letter in string.ascii_lowercase: #+"'-": #no need to add '-, as we want to corrupt good to bad
43
+ if insert:
44
+ #Insert right after pos - 1
45
+ insertion = word[:pos] + letter + word[pos:]
46
+ ed1.add(insertion)
47
+ can_substitute = sub_restrict is None or letter in sub_restrict[word[pos]]
48
+ if substitute and pos < last_mod_pos and can_substitute:
49
+ substitution = word[:pos] + letter + word[pos + 1:]
50
+ ed1.add(substitution)
51
+ #Include original word
52
+ # ed1.add(word)
53
+ return ed1
54
+
55
+ def get_all_internal_permutations(word):
56
+ if len(word) > 10:
57
+ return set([word])
58
+ first_char = word[0]
59
+ last_char = word[-1]
60
+ internal_chars = word[1:-1]
61
+ internal_permutations = set()
62
+ for int_perm in permutations(internal_chars):
63
+ int_perm_str = ''.join(int_perm)
64
+ perm = '{}{}{}'.format(first_char, int_perm_str, last_char)
65
+ internal_permutations.add(perm)
66
+ return internal_permutations
67
+
68
+ def sample_random_internal_permutations(word, n_perts = 5):
69
+ #We try swapping everything with the second character...
70
+ if len(word) < 4:
71
+ return set([word])
72
+ #iterate through positions between second and last
73
+ perturbations = set()
74
+ start = word[0]
75
+ end = word[-1]
76
+ middle = word[1:-1]
77
+ for _ in range(n_perts):
78
+ middle_list = list(middle)
79
+ random.shuffle(middle_list)
80
+ mixed_up_middle = ''.join(middle_list)
81
+ perturbations.add('{}{}{}'.format(start, mixed_up_middle, end))
82
+ return perturbations
83
+
84
+ def get_sorted_word(word):
85
+ if len(word) < 3:
86
+ sorted_word = word
87
+ else:
88
+ sorted_word = '{}{}{}'.format(word[0], ''.join(sorted(word[1:-1])), word[-1])
89
+ return sorted_word
90
+
91
+ def get_sorted_word_set(word):
92
+ if len(word) < 3:
93
+ sorted_word = word
94
+ else:
95
+ sorted_word = '{}{}{}'.format(word[0], ''.join(sorted(word[1:-1])), word[-1])
96
+ return set([sorted_word])
97
+
98
+
99
+ #Used to create agglomerative clusters.
100
+ def preprocess_ed1_neighbors(vocab, sub_restrict = None, filetype = 1111):
101
+ vocab = set([word.lower() for word in vocab])
102
+ typo2words = defaultdict(set)
103
+ for word in vocab:
104
+ ed1_typos = get_all_edit_dist_one(word, filetype = filetype, sub_restrict = sub_restrict)
105
+ for typo in ed1_typos:
106
+ typo2words[typo].add(word)
107
+
108
+ word2neighbors = defaultdict(set)
109
+ for typo in typo2words:
110
+ for word in typo2words[typo]:
111
+ word2neighbors[word] = word2neighbors[word].union(typo2words[typo])
112
+ return word2neighbors
113
+
114
+ #Used to create agglomerative clusters.
115
+ def ed1_neighbors_mat(vocab, sub_restrict = None, filetype = 1111):
116
+ vocab = [word.lower() for word in vocab]
117
+ word2idx = dict([(word, i) for i, word in enumerate(vocab)])
118
+ word2neighbors = preprocess_ed1_neighbors(vocab, sub_restrict = sub_restrict, filetype = filetype)
119
+ edges = set()
120
+ for word in word2neighbors:
121
+ for neighbor in word2neighbors[word]:
122
+ edge = [word, neighbor]
123
+ edge.sort()
124
+ edge = tuple(edge)
125
+ edges.add(edge)
126
+ edge_mat = np.zeros((len(vocab), len(vocab)), dtype = int)
127
+ for edge in edges:
128
+ vtx1, vtx2 = edge
129
+ idx1, idx2 = word2idx[vtx1], word2idx[vtx2]
130
+ edge_mat[idx1][idx2] = 1
131
+ edge_mat[idx2][idx1] = 1
132
+ return edge_mat
133
+
134
+
135
+
136
+ if __name__ == '__main__':
137
+ while True:
138
+ word = input("Enter a word: ")
139
+ print("Total number of possible perturbations: {}".format(len(get_all_edit_dist_one(word))))
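
A short illustration (not in the commit) of the filetype flag used throughout this module: each decimal digit toggles one perturbation type (thousands = insert, hundreds = delete, tens = substitute, ones = swap), so 1111 enables all four while 10 allows substitutions only. The word "word" is an arbitrary example.

    from critic.edit_dist_utils import get_all_edit_dist_one

    subs_only = get_all_edit_dist_one("word", filetype=10)   # substitutions only, e.g. 'ward', 'wore'
    swaps_only = get_all_edit_dist_one("word", filetype=1)   # adjacent swaps: {'wrod', 'wodr'}
    print(len(subs_only), len(swaps_only))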
critic/perturbations.py ADDED
@@ -0,0 +1,144 @@
1
+ """
2
+ Originally by https://worksheets.codalab.org/worksheets/0x8fc01c7fc2b742fdb29c05669f0ad7d2
3
+ """
4
+ import json
5
+ import os, sys
6
+ import re
7
+ import random
8
+ import numpy as np
9
+ from random import sample
10
+ from tqdm import tqdm
11
+ from collections import Counter
12
+
13
+ from critic.edit_dist_utils import get_all_edit_dist_one, sample_random_internal_permutations
14
+
15
+
16
+ try:
17
+ dir_path = os.path.dirname(os.path.realpath(__file__))
18
+ except:
19
+ dir_path = '.'
20
+ common_typo = json.load(open(f"{dir_path}/common_typo.json"))
21
+
22
+ random.seed(1234)
23
+ np.random.seed(1234)
24
+
25
+
26
+ class RandomPerturbationAttack(object):
27
+ def __init__(self, attack_type = 'ed1'):
28
+ self.cache = {} #{word: {0: set(), 1: set(),.. }, ..} #0=swap, 1=substitute, 2=delete, 3=insert
29
+ self.n_types = 5
30
+ self.attack_type = attack_type
31
+ #
32
+ def sample_perturbations(self, word, n_samples, types):
33
+ if types is None:
34
+ type_list = list(range(4)) * (n_samples//4) + list(np.random.choice(self.n_types, n_samples % self.n_types, replace=False))
35
+ else:
36
+ type_list = [sample(types,1)[0] for _ in range(n_samples)]
37
+ type_count = Counter(type_list)
38
+ perturbations = set()
39
+ for type in type_count:
40
+ if type not in self.cache[word]:
41
+ continue
42
+ if len(self.cache[word][type]) >= type_count[type]:
43
+ perturbations.update(set(sample(self.cache[word][type], type_count[type])))
44
+ else:
45
+ perturbations.update(self.cache[word][type])
46
+ return perturbations
47
+ #
48
+ def get_perturbations(self, word, n_samples, types=None):
49
+ if word not in self.cache:
50
+ self.cache[word] = {}
51
+ if word[0].islower():
52
+ for type in range(4):
53
+ self.cache[word][type] = get_all_edit_dist_one(word, 10**type)
54
+ if word in common_typo:
55
+ self.cache[word][4] = set(common_typo[word])
56
+ elif word[0].isupper():
57
+ if word in common_typo:
58
+ self.cache[word][4] = set(common_typo[word])
59
+ if self.attack_type == 'ed1':
60
+ perturbations = self.sample_perturbations(word, n_samples, types)
61
+ else:
62
+ raise NotImplementedError("Attack type: {} not implemented yet".format(self.attack_type))
63
+ return perturbations
64
+ #
65
+ def name(self):
66
+ return 'RandomPerturbationAttack'
67
+
68
+
69
+ word_attack = RandomPerturbationAttack()
70
+
71
+
72
+ def _tokenize(sent):
73
+ toks = []
74
+ word_idxs = []
75
+ for idx, match in enumerate(re.finditer(r'([a-zA-Z]+)|([0-9]+)|.', sent)):
76
+ tok = match.group(0)
77
+ toks.append(tok)
78
+ if len(tok) > 2 and tok.isalpha() and (tok[0].islower()):
79
+ word_idxs.append(idx)
80
+ return toks, word_idxs
81
+
82
+ def _detokenize(toks):
83
+ return ''.join(toks)
84
+
85
+ def get_local_neighbors_char_level(sent, max_n_samples=500):
86
+ words, word_idxs = _tokenize(sent)
87
+ n_samples = min(len(word_idxs) *20, max_n_samples)
88
+ sent_perturbations = set()
89
+ if len(word_idxs) == 0:
90
+ return sent_perturbations
91
+ for _ in range(500):
92
+ word_idx = sample(word_idxs, 1)[0]
93
+ words_cp = words[:]
94
+ word_perturbations = list(word_attack.get_perturbations(words_cp[word_idx], n_samples=1))
95
+ if len(word_perturbations) > 0:
96
+ words_cp[word_idx] = word_perturbations[0]
97
+ sent_perturbed = _detokenize(words_cp)
98
+ if sent_perturbed != sent:
99
+ sent_perturbations.add(sent_perturbed)
100
+ if len(sent_perturbations) == n_samples:
101
+ break
102
+ #Adding common typos such as 's'
103
+ for word_idx in word_idxs:
104
+ words_cp = words[:]
105
+ word = words_cp[word_idx]
106
+ if len(word) > 2 and word[0].islower():
107
+ words_cp[word_idx] = word +'s'
108
+ sent_perturbed = _detokenize(words_cp)
109
+ if sent_perturbed != sent:
110
+ sent_perturbations.add(sent_perturbed)
111
+ words_cp[word_idx] = word[:-1]
112
+ sent_perturbed = _detokenize(words_cp)
113
+ if sent_perturbed != sent:
114
+ sent_perturbations.add(sent_perturbed)
115
+ if len(sent_perturbations) > max_n_samples:
116
+ sent_perturbations = list(sent_perturbations)
117
+ np.random.shuffle(sent_perturbations)
118
+ sent_perturbations = set(sent_perturbations[:max_n_samples])
119
+ return sent_perturbations
120
+
121
+
122
+
123
+ from critic.PIE.word_level_perturb import WordLevelPerturber_all, WordLevelPerturber_refine
124
+ from utils.text_utils import detokenize_sent
125
+
126
+ def get_local_neighbors_word_level(sent_toked, max_n_samples=500, mode='refine'):
127
+ """ sent_toked is tokenized by spacy """
128
+ n_samples = min(len(sent_toked) *20, max_n_samples)
129
+ orig_sent = ' '.join(sent_toked)
130
+ orig_sent_detok = detokenize_sent(orig_sent)
131
+ if mode == 'refine':
132
+ ptb = WordLevelPerturber_refine(orig_sent)
133
+ else:
134
+ ptb = WordLevelPerturber_all(orig_sent)
135
+ sent_perturbations = set()
136
+ for _ in range(500):
137
+ sent_perturbed = ptb.perturb()
138
+ if sent_perturbed != orig_sent:
139
+ sent_perturbed_detok = detokenize_sent(sent_perturbed)
140
+ sent_perturbations.add(sent_perturbed_detok)
141
+ if len(sent_perturbations) == n_samples:
142
+ break
143
+ assert len(sent_perturbations) <= max_n_samples
144
+ return sent_perturbations, orig_sent_detok
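
A brief usage sketch (not part of the diff) of the two neighbor generators defined above. The word-level variant expects a spaCy-tokenized sentence, for which utils.spacy_tokenizer.spacy_tokenize_gec (imported elsewhere in this commit) is assumed to be available; the example sentence is illustrative only.

    from critic.perturbations import get_local_neighbors_char_level, get_local_neighbors_word_level
    from utils.spacy_tokenizer import spacy_tokenize_gec

    sent = "He go to school every day."
    char_nbrs = get_local_neighbors_char_level(sent, max_n_samples=20)
    word_nbrs, sent_detok = get_local_neighbors_word_level(spacy_tokenize_gec(sent), max_n_samples=20)
    print(len(char_nbrs), len(word_nbrs))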
eval_critic/eval_critic.py ADDED
@@ -0,0 +1,114 @@
1
+ import re
2
+ import os
3
+ import sys
4
+ import json
5
+ import numpy as np
6
+ import editdistance
7
+ from tqdm import tqdm
8
+ from collections import Counter
9
+
10
+ sys.path.insert(0, '.')
11
+ from utils.text_utils import detokenize_sent
12
+ from critic.critic import run_gpt2, gpt2_critic
13
+
14
+ def load_data():
15
+ data_path = 'eval_critic/eval_data.jsonl'
16
+ good_sents, bad_sents = [], []
17
+ for line in open(data_path):
18
+ obj = json.loads(line)
19
+ good_sents.append(obj['good'])
20
+ bad_sents.append(obj['bad'])
21
+ return good_sents, bad_sents
22
+
23
+ good_sents, bad_sents = load_data()
24
+
25
+
26
+
27
+ def get_logps(sents):
28
+ final = []
29
+ for start in tqdm(range(0, len(sents), 100)):
30
+ sents_sub = sents[start: start+100]
31
+ sents_sub_detok = [detokenize_sent(sent) for sent in sents_sub]
32
+ logps = run_gpt2(sents_sub_detok)
33
+ assert logps is not None
34
+ for i in range(len(sents_sub)):
35
+ final.append({'sent': sents_sub[i], 'sent_detok': sents_sub_detok[i], 'logp': float(logps[i])})
36
+ return final
37
+
38
+ def evaluate_logp():
39
+ """
40
+ Check whether log p(bad_sent) < log p(good_sent)
41
+ """
42
+ good_logps = get_logps(good_sents)
43
+ bad_logps = get_logps(bad_sents)
44
+ accs = []
45
+ for good, bad in zip(good_logps, bad_logps):
46
+ accs.append(int(bad['logp'] < good['logp']))
47
+ avg_acc = float(sum(accs))/len(accs)
48
+ print (f'log p(bad) < log p(good)? {sum(accs)} / {len(accs)} = {avg_acc:.3f}')
49
+ return good_logps, bad_logps
50
+
51
+ good_logps, bad_logps = evaluate_logp()
52
+ # log p(bad) < log p(good)? 555 / 586 = 0.947
53
+
54
+
55
+ def compute_metrics(good_accs, bad_accs):
56
+ goodP = float(sum(good_accs))/(len(bad_accs)-sum(bad_accs)+sum(good_accs))
57
+ goodR = float(sum(good_accs))/len(good_accs)
58
+ goodF05 = (1+0.5**2) * float(goodP * goodR)/((0.5**2 * goodP) + goodR)
59
+ badP = float(sum(bad_accs))/(len(good_accs)-sum(good_accs)+sum(bad_accs))
60
+ badR = float(sum(bad_accs))/len(bad_accs)
61
+ badF05 = (1+0.5**2) * float(badP * badR)/((0.5**2 * badP) + badR)
62
+ print (f' Good precision = {sum(good_accs)} / {(len(bad_accs)-sum(bad_accs)+sum(good_accs))} = {goodP:.3f}')
63
+ print (f' Good recall = {sum(good_accs)} / {len(good_accs)} = {goodR:.3f}')
64
+ print (f' Good F0.5 = {goodF05:.3f}')
65
+ print (f' Bad precision = {sum(bad_accs)} / {(len(good_accs)-sum(good_accs)+sum(bad_accs))} = {badP:.3f}')
66
+ print (f' Bad recall = {sum(bad_accs)} / {len(bad_accs)} = {badR:.3f}')
67
+ print (f' Bad F0.5 = {badF05:.3f}')
68
+ return {'goodP': goodP, 'goodR': goodR, 'goodF05': goodF05, 'badP': badP, 'badR': badR, 'badF05': badF05}
69
+
70
+ def evaluate_baseline_critic():
71
+ threshold = np.mean([elm['logp'] for elm in good_logps + bad_logps])
72
+ good_accs, bad_accs = [], []
73
+ for obj in good_logps:
74
+ pred = int(obj['logp'] > threshold)
75
+ good_accs.append(pred==1)
76
+ for obj in bad_logps:
77
+ pred = int(obj['logp'] > threshold)
78
+ bad_accs.append(pred==0)
79
+ print ('\nBaseline critic:')
80
+ stats = compute_metrics(good_accs, bad_accs)
81
+ json.dump(stats, open('baseline_critic.stats.json', 'w'), indent=2)
82
+
83
+ evaluate_baseline_critic()
84
+ # Baseline critic:
85
+ # Good precision = 365 / 668 = 0.546
86
+ # Good recall = 365 / 586 = 0.623
87
+ # Good F0.5 = 0.560
88
+ # Bad precision = 283 / 504 = 0.562
89
+ # Bad recall = 283 / 586 = 0.483
90
+ # Bad F0.5 = 0.544
91
+
92
+
93
+ def evaluate_LM_Critic():
94
+ good_accs, bad_accs = [], []
95
+ for obj in tqdm(good_logps):
96
+ res = gpt2_critic(obj['sent_detok'], verbose=0, seed=1, n_samples=100, word_level_mode='refine')
97
+ pred = int(res[0])
98
+ good_accs.append(pred==1)
99
+ for obj in tqdm(bad_logps):
100
+ res = gpt2_critic(obj['sent_detok'], verbose=0, seed=1, n_samples=100, word_level_mode='refine')
101
+ pred = int(res[0])
102
+ bad_accs.append(pred==0)
103
+ print ('\nLM-Critic:')
104
+ stats = compute_metrics(good_accs, bad_accs)
105
+ json.dump(stats, open('lm_critic.stats.json', 'w'), indent=2)
106
+
107
+ evaluate_LM_Critic()
108
+ # LM-Critic: (results vary slightly due to sampling randomness and variation in GPT-2's returned scores)
109
+ # Good precision = 446 / 654 = 0.682
110
+ # Good recall = 446 / 586 = 0.761
111
+ # Good F0.5 = 0.696
112
+ # Bad precision = 378 / 518 = 0.730
113
+ # Bad recall = 378 / 586 = 0.645
114
+ # Bad F0.5 = 0.711
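
As a sanity check (not in the commit), the F0.5 figures quoted above follow from the same formula compute_metrics uses, F0.5 = (1 + 0.5**2) * P * R / (0.5**2 * P + R); for the "Good" row of the LM-Critic results:

    P, R = 446 / 654, 446 / 586
    print((1 + 0.5**2) * P * R / (0.5**2 * P + R))   # ~0.696, matching "Good F0.5 = 0.696"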
eval_critic/eval_data.jsonl ADDED
The diff for this file is too large to render.
 
gec/download_data.sh ADDED
@@ -0,0 +1,43 @@
1
+ conda activate errant200
2
+
3
+
4
+ ######################## Set up benchmarks ########################
5
+ mkdir -p benchmarks
6
+ cd benchmarks
7
+
8
+ #Prepare CoNLL2014
9
+ wget https://www.comp.nus.edu.sg/~nlp/conll14st/conll14st-test-data.tar.gz
10
+ tar -xf conll14st-test-data.tar.gz
11
+ python3 scripts/get_orig_from_m2.py conll14st-test-data/noalt/official-2014.combined.m2 \
12
+ -out conll14st-test-data/noalt/official-2014.combined.orig.txt
13
+
14
+
15
+ #Prepare BEA2019
16
+ wget https://www.cl.cam.ac.uk/research/nl/bea2019st/data/wi+locness_v2.1.bea19.tar.gz
17
+ tar -xf wi+locness_v2.1.bea19.tar.gz
18
+ mv wi+locness wi+locness_v2.1.bea19
19
+ python3 scripts/get_orig_from_m2.py wi+locness_v2.1.bea19/m2/ABCN.dev.gold.bea19.m2 \
20
+ -out wi+locness_v2.1.bea19/m2/ABCN.dev.bea19.orig.txt
21
+
22
+
23
+ #Prepare GMEG-wiki and -yahoo
24
+ git clone https://github.com/grammarly/GMEG.git
25
+ root=GMEG/data/test/wiki
26
+ errant_parallel -orig $root/source \
27
+ -cor $root/ref0 $root/ref1 $root/ref2 $root/ref3 \
28
+ -out $root/ref.m2
29
+
30
+ root=GMEG/data/test/yahoo
31
+ errant_parallel -orig $root/source \
32
+ -cor $root/ref0 $root/ref1 $root/ref2 $root/ref3 \
33
+ -out $root/ref.m2
34
+
35
+
36
+ #Download M2 scorer
37
+ git clone https://github.com/nusnlp/m2scorer.git
38
+
39
+
40
+ ######################## Download training data ########################
41
+ cd ../
42
+ wget https://nlp.stanford.edu/projects/myasu/LM-Critic/data.zip
43
+ unzip data.zip
gec/scripts/get_corr_from_m2.py ADDED
@@ -0,0 +1,35 @@
1
+ import argparse
2
+
3
+ # Apply the edits of a single annotator to generate the corrected sentences.
4
+ def main(args):
5
+ m2 = open(args.m2_file).read().strip().split("\n\n")
6
+ out = open(args.out, "w")
7
+ # Do not apply edits with these error types
8
+ skip = {"noop", "UNK", "Um"}
9
+
10
+ for sent in m2:
11
+ sent = sent.split("\n")
12
+ cor_sent = sent[0].split()[1:] # Ignore "S "
13
+ edits = sent[1:]
14
+ offset = 0
15
+ for edit in edits:
16
+ edit = edit.split("|||")
17
+ if edit[1] in skip: continue # Ignore certain edits
18
+ coder = int(edit[-1])
19
+ if coder != args.id: continue # Ignore other coders
20
+ span = edit[0].split()[1:] # Ignore "A "
21
+ start = int(span[0])
22
+ end = int(span[1])
23
+ cor = edit[2].split()
24
+ cor_sent[start+offset:end+offset] = cor
25
+ offset = offset-(end-start)+len(cor)
26
+ out.write(" ".join(cor_sent)+"\n")
27
+
28
+ if __name__ == "__main__":
29
+ # Define and parse program input
30
+ parser = argparse.ArgumentParser()
31
+ parser.add_argument("m2_file", help="The path to an input m2 file.")
32
+ parser.add_argument("-out", help="A path to where we save the output corrected text file.", required=True)
33
+ parser.add_argument("-id", help="The id of the target annotator in the m2 file.", type=int, default=0)
34
+ args = parser.parse_args()
35
+ main(args)
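
For context (not part of the commit), the M2 format consumed by this script pairs one "S ..." source line with "A start end|||type|||correction|||...|||annotator-id" edit lines, which is exactly what the split("|||") above unpacks; applying annotator 0's edit below turns "This are a sentence ." into "This is a sentence .":

    S This are a sentence .
    A 1 2|||R:VERB:SVA|||is|||REQUIRED|||-NONE-|||0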
gec/scripts/get_orig_from_m2.py ADDED
@@ -0,0 +1,22 @@
1
+ import argparse
2
+
3
+ # Extract the original (uncorrected) sentences from an M2 file.
4
+ def main(args):
5
+ m2 = open(args.m2_file).read().strip().split("\n\n")
6
+ out = open(args.out, "w")
7
+ # Do not apply edits with these error types
8
+ skip = {"noop", "UNK", "Um"}
9
+
10
+ for sent in m2:
11
+ sent = sent.split("\n")
12
+ orig_sent = sent[0].split()[1:] # Ignore "S "
13
+ out.write(" ".join(orig_sent)+"\n")
14
+
15
+ if __name__ == "__main__":
16
+ # Define and parse program input
17
+ parser = argparse.ArgumentParser()
18
+ parser.add_argument("m2_file", help="The path to an input m2 file.")
19
+ parser.add_argument("-out", help="A path to where we save the output corrected text file.", required=True)
20
+ parser.add_argument("-id", help="The id of the target annotator in the m2 file.", type=int, default=0)
21
+ args = parser.parse_args()
22
+ main(args)
gec/scripts/parse_errant_output.py ADDED
@@ -0,0 +1,11 @@
1
+ #!/usr/bin/env python
2
+
3
+ import sys
4
+ import json
5
+
6
+ for i, line in enumerate(sys.stdin):
7
+ if i == 3:
8
+ nums = line.split()
9
+ P, R, F = nums[3:6]
10
+ json.dump({'precision': float(P), 'recall': float(R), 'F0.5': float(F)}, open('stats.json', 'w'), indent=2)
11
+ break
gec/scripts/parse_m2_output.py ADDED
@@ -0,0 +1,11 @@
1
+ #!/usr/bin/env python
2
+
3
+ import sys
4
+ import json
5
+
6
+ scores = []
7
+ for i, line in enumerate(sys.stdin):
8
+ score = line.split(':')[1].strip()
9
+ scores.append(float(score))
10
+
11
+ json.dump({'precision': scores[0], 'recall': scores[1], 'F0.5': scores[2]}, open('stats.json', 'w'), indent=2)
gec/src/run-round0.sh ADDED
@@ -0,0 +1,77 @@
1
+ exit 0;
2
+ ################################################################################
3
+ # run the following commands one by one in the `gec/` directory of the repo
4
+ ################################################################################
5
+ export CUDA_VISIBLE_DEVICES=0
6
+ conda activate lm-critic
7
+
8
+ ############### Train the fixer ###############
9
+ dt=`date '+%Y%m%d_%H%M%S'`
10
+ outdir=data/round0__synthetic/model-fixer__${dt}
11
+ mkdir -p $outdir
12
+ python3.8 -u src/run_seq2seq.py \
13
+ --model_name_or_path facebook/bart-base --task summarization --text_column bad_detoked --summary_column good_detoked \
14
+ --do_train --num_train_epochs 1 --train_file data/round0__synthetic/synthetic_paired_data_9M.json \
15
+ --preprocessing_num_workers 20 --overwrite_output_dir --output_dir $outdir --predict_with_generate --fp16 \
16
+ --per_device_train_batch_size 64 --gradient_accumulation_steps 8 --max_source_length 64 --max_target_length 64 \
17
+ --logging_first_step --logging_steps 20 --save_steps 2000 \
18
+ |& tee $outdir/log.txt
19
+
20
+
21
+
22
+ ############### Run the fixer on benchmarks ###############
23
+ model_path=data/round0__synthetic/model-fixer
24
+
25
+ #BEA2019
26
+ python src/run_fixer.py -m $model_path -i benchmarks/wi+locness_v2.1.bea19/m2/ABCN.dev.bea19.orig.txt -o $model_path/predictions/bea19dev.out.txt --bea19
27
+ #CoNLL2014
28
+ python src/run_fixer.py -m $model_path -i benchmarks/conll14st-test-data/noalt/official-2014.combined.orig.txt -o $model_path/predictions/conll14.out.txt
29
+ #GMEG-wiki
30
+ python src/run_fixer.py -m $model_path -i benchmarks/GMEG/data/test/wiki/source -o $model_path/predictions/gmeg.wiki.out.txt
31
+ #GMEG-yahoo
32
+ python src/run_fixer.py -m $model_path -i benchmarks/GMEG/data/test/yahoo/source -o $model_path/predictions/gmeg.yahoo.out.txt
33
+
34
+
35
+
36
+ ############### Evaluate the fixer outputs ###############
37
+ #CoNLL2014
38
+ python2 benchmarks/m2scorer/scripts/m2scorer.py $model_path/predictions/conll14.out.txt \
39
+ benchmarks/conll14st-test-data/noalt/official-2014.combined.m2 | tee $model_path/predictions/conll14.eval.txt
40
+ # Precision : 0.5922
41
+ # Recall : 0.2920
42
+ # F_0.5 : 0.4912
43
+
44
+
45
+ #BEA2019 and GMEG uses errant scorer, which needs its own environment
46
+ conda deactivate
47
+ conda activate errant200
48
+
49
+ #BEA2019
50
+ errant_parallel -orig benchmarks/wi+locness_v2.1.bea19/m2/ABCN.dev.bea19.orig.txt \
51
+ -cor $model_path/predictions/bea19dev.out.txt \
52
+ -out $model_path/predictions/bea19dev.outm2.txt && \
53
+ errant_compare -hyp $model_path/predictions/bea19dev.outm2.txt -ref benchmarks/wi+locness_v2.1.bea19/m2/ABCN.dev.gold.bea19.m2 | tee $model_path/predictions/bea19dev.eval.txt
54
+ # =========== Span-Based Correction ============
55
+ # TP FP FN Prec Rec F0.5
56
+ # 1337 1686 6124 0.4423 0.1792 0.3419
57
+ # ==============================================
58
+
59
+ #GEMG-wiki
60
+ errant_parallel -orig benchmarks/GMEG/data/test/wiki/source \
61
+ -cor $model_path/predictions/gmeg.wiki.out.txt \
62
+ -out $model_path/predictions/gmeg.wiki.outm2.txt && \
63
+ errant_compare -hyp $model_path/predictions/gmeg.wiki.outm2.txt -ref benchmarks/GMEG/data/test/wiki/ref.m2 | tee $model_path/predictions/gmeg.wiki.eval.txt
64
+ # =========== Span-Based Correction ============
65
+ # TP FP FN Prec Rec F0.5
66
+ # 352 323 973 0.5215 0.2657 0.4373
67
+ # ==============================================
68
+
69
+ #GEMG-yahoo
70
+ errant_parallel -orig benchmarks/GMEG/data/test/yahoo/source \
71
+ -cor $model_path/predictions/gmeg.yahoo.out.txt \
72
+ -out $model_path/predictions/gmeg.yahoo.outm2.txt && \
73
+ errant_compare -hyp $model_path/predictions/gmeg.yahoo.outm2.txt -ref benchmarks/GMEG/data/test/yahoo/ref.m2 | tee $model_path/predictions/gmeg.yahoo.eval.txt
74
+ # =========== Span-Based Correction ============
75
+ # TP FP FN Prec Rec F0.5
76
+ # 241 301 411 0.4446 0.3696 0.4273
77
+ # ==============================================
gec/src/run-round1.sh ADDED
@@ -0,0 +1,75 @@
1
+ exit 0;
2
+ ################################################################################
3
+ # run the following commands one by one in the `gec/` directory of the repo
4
+ ################################################################################
5
+ export CUDA_VISIBLE_DEVICES=0
6
+ conda activate lm-critic
7
+
8
+ ############### Train the fixer ###############
9
+ dt=`date '+%Y%m%d_%H%M%S'`
10
+ outdir=data/round1__BIFI/model-fixer__${dt}
11
+ mkdir -p $outdir
12
+ python3.8 -u src/run_seq2seq.py \
13
+ --model_name_or_path facebook/bart-base --task summarization --text_column bad_detoked --summary_column good_detoked \
14
+ --do_train --num_train_epochs 1 --train_file data/round1__BIFI/BIFI_paired_data_9M.json \
15
+ --preprocessing_num_workers 20 --overwrite_output_dir --output_dir $outdir --predict_with_generate --fp16 \
16
+ --per_device_train_batch_size 64 --gradient_accumulation_steps 8 --max_source_length 64 --max_target_length 64 \
17
+ --logging_first_step --logging_steps 20 --save_steps 2000 \
18
+ |& tee $outdir/log.txt
19
+
20
+
21
+
22
+ ############### Run the fixer on benchmarks ###############
23
+ model_path=data/round1__BIFI/model-fixer
24
+ #BEA2019
25
+ python src/run_fixer.py -m $model_path -i benchmarks/wi+locness_v2.1.bea19/m2/ABCN.dev.bea19.orig.txt -o $model_path/predictions/bea19dev.out.txt --bea19
26
+ #CoNLL2014
27
+ python src/run_fixer.py -m $model_path -i benchmarks/conll14st-test-data/noalt/official-2014.combined.orig.txt -o $model_path/predictions/conll14.out.txt
28
+ #GMEG-wiki
29
+ python src/run_fixer.py -m $model_path -i benchmarks/GMEG/data/test/wiki/source -o $model_path/predictions/gmeg.wiki.out.txt
30
+ #GMEG-yahoo
31
+ python src/run_fixer.py -m $model_path -i benchmarks/GMEG/data/test/yahoo/source -o $model_path/predictions/gmeg.yahoo.out.txt
32
+
33
+
34
+
35
+ ############### Evaluate the fixer outputs ###############
36
+ #CoNLL2014
37
+ python2 benchmarks/m2scorer/scripts/m2scorer.py $model_path/predictions/conll14.out.txt \
38
+ benchmarks/conll14st-test-data/noalt/official-2014.combined.m2 | tee $model_path/predictions/conll14.eval.txt
39
+ # Precision : 0.6444
40
+ # Recall : 0.3569
41
+ # F_0.5 : 0.5550
42
+
43
+ #BEA2019 and GMEG uses errant scorer, which needs its own environment
44
+ conda deactivate
45
+ conda activate errant200
46
+
47
+ #BEA2019
48
+ errant_parallel -orig benchmarks/wi+locness_v2.1.bea19/m2/ABCN.dev.bea19.orig.txt \
49
+ -cor $model_path/predictions/bea19dev.out.txt \
50
+ -out $model_path/predictions/bea19dev.outm2.txt && \
51
+ errant_compare -hyp $model_path/predictions/bea19dev.outm2.txt -ref benchmarks/wi+locness_v2.1.bea19/m2/ABCN.dev.gold.bea19.m2 | tee $model_path/predictions/bea19dev.eval.txt
52
+ # =========== Span-Based Correction ============
53
+ # TP FP FN Prec Rec F0.5
54
+ # 1848 1733 5613 0.5161 0.2477 0.4241
55
+ # ==============================================
56
+
57
+ #GEMG-wiki
58
+ errant_parallel -orig benchmarks/GMEG/data/test/wiki/source \
59
+ -cor $model_path/predictions/gmeg.wiki.out.txt \
60
+ -out $model_path/predictions/gmeg.wiki.outm2.txt && \
61
+ errant_compare -hyp $model_path/predictions/gmeg.wiki.outm2.txt -ref benchmarks/GMEG/data/test/wiki/ref.m2 | tee $model_path/predictions/gmeg.wiki.eval.txt
62
+ # =========== Span-Based Correction ============
63
+ # TP FP FN Prec Rec F0.5
64
+ # 468 339 925 0.5799 0.336 0.5064
65
+ # ==============================================
66
+
67
+ #GEMG-yahoo
68
+ errant_parallel -orig benchmarks/GMEG/data/test/yahoo/source \
69
+ -cor $model_path/predictions/gmeg.yahoo.out.txt \
70
+ -out $model_path/predictions/gmeg.yahoo.outm2.txt && \
71
+ errant_compare -hyp $model_path/predictions/gmeg.yahoo.outm2.txt -ref benchmarks/GMEG/data/test/yahoo/ref.m2 | tee $model_path/predictions/gmeg.yahoo.eval.txt
72
+ # =========== Span-Based Correction ============
73
+ # TP FP FN Prec Rec F0.5
74
+ # 382 329 428 0.5373 0.4716 0.5227
75
+ # ==============================================
gec/src/run_fixer.py ADDED
@@ -0,0 +1,87 @@
1
+ import os
2
+ import sys
3
+ import json
4
+ import torch
5
+ import argparse
6
+ from tqdm import tqdm
7
+ from transformers import BartForConditionalGeneration, BartTokenizer
8
+
9
+ sys.path.insert(0, '..')
10
+ from utils.text_utils import detokenize_sent
11
+ from utils.spacy_tokenizer import spacy_tokenize_gec, spacy_tokenize_bea19
12
+
13
+ parser = argparse.ArgumentParser()
14
+ parser.add_argument('-m', '--model_path')
15
+ parser.add_argument('-i', '--input_path')
16
+ parser.add_argument('-o', '--output_path')
17
+ parser.add_argument('--bea19', action='store_true')
18
+ args = parser.parse_args()
19
+
20
+
21
+ tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
22
+ model = BartForConditionalGeneration.from_pretrained(args.model_path, force_bos_token_to_be_generated=True)
23
+ model.eval()
24
+ model.cuda()
25
+
26
+
27
+ def run_model(sents):
28
+ num_ret_seqs = 10
29
+ inp_max_len = 66
30
+ batch = [tokenizer(s, return_tensors='pt', padding='max_length', max_length=inp_max_len) for s in sents]
31
+ oidx2bidx = {} #original index to final batch index
32
+ final_batch = []
33
+ for oidx, elm in enumerate(batch):
34
+ if elm['input_ids'].size(1) <= inp_max_len:
35
+ oidx2bidx[oidx] = len(final_batch)
36
+ final_batch.append(elm)
37
+ batch = {key: torch.cat([elm[key] for elm in final_batch], dim=0) for key in final_batch[0]}
38
+ with torch.no_grad():
39
+ generated_ids = model.generate(batch['input_ids'].cuda(),
40
+ attention_mask=batch['attention_mask'].cuda(),
41
+ num_beams=10, num_return_sequences=num_ret_seqs, max_length=65)
42
+ _out = tokenizer.batch_decode(generated_ids.detach().cpu(), skip_special_tokens=True)
43
+ outs = []
44
+ for i in range(0, len(_out), num_ret_seqs):
45
+ outs.append(_out[i:i+num_ret_seqs])
46
+ final_outs = [[sents[oidx]] if oidx not in oidx2bidx else outs[oidx2bidx[oidx]] for oidx in range(len(sents))]
47
+ return final_outs
48
+
49
+
50
+ def run_for_wiki_yahoo_conll():
51
+ sents = [detokenize_sent(l.strip()) for l in open(args.input_path)]
52
+ b_size = 40
53
+ outs = []
54
+ for j in tqdm(range(0, len(sents), b_size)):
55
+ sents_batch = sents[j:j+b_size]
56
+ outs_batch = run_model(sents_batch)
57
+ for sent, preds in zip(sents_batch, outs_batch):
58
+ preds_detoked = [detokenize_sent(pred) for pred in preds]
59
+ preds = [' '.join(spacy_tokenize_gec(pred)) for pred in preds_detoked]
60
+ outs.append({'src': sent, 'preds': preds})
61
+ os.system('mkdir -p {}'.format(os.path.dirname(args.output_path)))
62
+ with open(args.output_path, 'w') as outf:
63
+ for out in outs:
64
+ print (out['preds'][0], file=outf)
65
+
66
+
67
+ def run_for_bea19():
68
+ sents = [detokenize_sent(l.strip()) for l in open(args.input_path)]
69
+ b_size = 40
70
+ outs = []
71
+ for j in tqdm(range(0, len(sents), b_size)):
72
+ sents_batch = sents[j:j+b_size]
73
+ outs_batch = run_model(sents_batch)
74
+ for sent, preds in zip(sents_batch, outs_batch):
75
+ preds_detoked = [detokenize_sent(pred) for pred in preds]
76
+ preds = [' '.join(spacy_tokenize_bea19(pred)) for pred in preds_detoked]
77
+ outs.append({'src': sent, 'preds': preds})
78
+ os.system('mkdir -p {}'.format(os.path.dirname(args.output_path)))
79
+ with open(args.output_path, 'w') as outf:
80
+ for out in outs:
81
+ print (out['preds'][0], file=outf)
82
+
83
+
84
+ if args.bea19:
85
+ run_for_bea19()
86
+ else:
87
+ run_for_wiki_yahoo_conll()
gec/src/run_seq2seq.py ADDED
@@ -0,0 +1,537 @@
1
+ # coding=utf-8
2
+ # Copyright The HuggingFace Team and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """
16
+ Fine-tuning the library models for sequence to sequence.
17
+ """
18
+ # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
19
+
20
+ import logging
21
+ import os
22
+ import re
23
+ import sys
24
+ from dataclasses import dataclass, field
25
+ from typing import Optional
26
+
27
+ import numpy as np
28
+ from datasets import load_dataset, load_metric
29
+
30
+ import transformers
31
+ from transformers import (
32
+ AutoConfig,
33
+ AutoModelForSeq2SeqLM,
34
+ AutoTokenizer,
35
+ DataCollatorForSeq2Seq,
36
+ HfArgumentParser,
37
+ MBartTokenizer,
38
+ Seq2SeqTrainer,
39
+ Seq2SeqTrainingArguments,
40
+ default_data_collator,
41
+ set_seed,
42
+ )
43
+ from transformers.trainer_utils import get_last_checkpoint, is_main_process
44
+
45
+
46
+ logger = logging.getLogger(__name__)
47
+
48
+
49
+ @dataclass
50
+ class ModelArguments:
51
+ """
52
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
53
+ """
54
+
55
+ model_name_or_path: str = field(
56
+ metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
57
+ )
58
+ config_name: Optional[str] = field(
59
+ default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
60
+ )
61
+ tokenizer_name: Optional[str] = field(
62
+ default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
63
+ )
64
+ cache_dir: Optional[str] = field(
65
+ default=None,
66
+ metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
67
+ )
68
+ use_fast_tokenizer: bool = field(
69
+ default=True,
70
+ metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
71
+ )
72
+ model_revision: str = field(
73
+ default="main",
74
+ metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
75
+ )
76
+ use_auth_token: bool = field(
77
+ default=False,
78
+ metadata={
79
+ "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
80
+ "with private models)."
81
+ },
82
+ )
83
+
84
+
85
+ @dataclass
86
+ class DataTrainingArguments:
87
+ """
88
+ Arguments pertaining to what data we are going to input our model for training and eval.
89
+ """
90
+
91
+ task: str = field(
92
+ default="summarization",
93
+ metadata={
94
+ "help": "The name of the task, should be summarization (or summarization_{dataset} for evaluating "
95
+ "pegasus) or translation (or translation_{xx}_to_{yy})."
96
+ },
97
+ )
98
+ dataset_name: Optional[str] = field(
99
+ default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
100
+ )
101
+ dataset_config_name: Optional[str] = field(
102
+ default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
103
+ )
104
+ text_column: Optional[str] = field(
105
+ default=None,
106
+ metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."},
107
+ )
108
+ summary_column: Optional[str] = field(
109
+ default=None,
110
+ metadata={"help": "The name of the column in the datasets containing the summaries (for summarization)."},
111
+ )
112
+ train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
113
+ validation_file: Optional[str] = field(
114
+ default=None,
115
+ metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
116
+ )
117
+ overwrite_cache: bool = field(
118
+ default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
119
+ )
120
+ preprocessing_num_workers: Optional[int] = field(
121
+ default=None,
122
+ metadata={"help": "The number of processes to use for the preprocessing."},
123
+ )
124
+ max_source_length: Optional[int] = field(
125
+ default=1024,
126
+ metadata={
127
+ "help": "The maximum total input sequence length after tokenization. Sequences longer "
128
+ "than this will be truncated, sequences shorter will be padded."
129
+ },
130
+ )
131
+ max_target_length: Optional[int] = field(
132
+ default=128,
133
+ metadata={
134
+ "help": "The maximum total sequence length for target text after tokenization. Sequences longer "
135
+ "than this will be truncated, sequences shorter will be padded."
136
+ },
137
+ )
138
+ val_max_target_length: Optional[int] = field(
139
+ default=None,
140
+ metadata={
141
+ "help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
142
+ "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
143
+ "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
144
+ "during ``evaluate`` and ``predict``."
145
+ },
146
+ )
147
+ pad_to_max_length: bool = field(
148
+ default=False,
149
+ metadata={
150
+ "help": "Whether to pad all samples to model maximum sentence length. "
151
+ "If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
152
+ "efficient on GPU but very bad for TPU."
153
+ },
154
+ )
155
+ max_train_samples: Optional[int] = field(
156
+ default=None,
157
+ metadata={
158
+ "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
159
+ "value if set."
160
+ },
161
+ )
162
+ max_val_samples: Optional[int] = field(
163
+ default=None,
164
+ metadata={
165
+ "help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
166
+ "value if set."
167
+ },
168
+ )
169
+ source_lang: Optional[str] = field(default=None, metadata={"help": "Source language id for translation."})
170
+ target_lang: Optional[str] = field(default=None, metadata={"help": "Target language id for translation."})
171
+ eval_beams: Optional[int] = field(default=None, metadata={"help": "Number of beams to use for evaluation."})
172
+ ignore_pad_token_for_loss: bool = field(
173
+ default=True,
174
+ metadata={
175
+ "help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."
176
+ },
177
+ )
178
+ source_prefix: Optional[str] = field(
179
+ default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
180
+ )
181
+
182
+ def __post_init__(self):
183
+ if self.dataset_name is None and self.train_file is None and self.validation_file is None:
184
+ raise ValueError("Need either a dataset name or a training/validation file.")
185
+ else:
186
+ if self.train_file is not None:
187
+ extension = self.train_file.split(".")[-1]
188
+ assert extension in ["csv", "json"], "`train_file` should be a csv or a json file."
189
+ if self.validation_file is not None:
190
+ extension = self.validation_file.split(".")[-1]
191
+ assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
192
+ if not self.task.startswith("summarization") and not self.task.startswith("translation"):
193
+ raise ValueError(
194
+ "`task` should be summarization, summarization_{dataset}, translation or translation_{xx}_to_{yy}."
195
+ )
196
+ if self.val_max_target_length is None:
197
+ self.val_max_target_length = self.max_target_length
198
+
199
+
200
+ summarization_name_mapping = {
201
+ "amazon_reviews_multi": ("review_body", "review_title"),
202
+ "big_patent": ("description", "abstract"),
203
+ "cnn_dailymail": ("article", "highlights"),
204
+ "orange_sum": ("text", "summary"),
205
+ "pn_summary": ("article", "summary"),
206
+ "psc": ("extract_text", "summary_text"),
207
+ "samsum": ("dialogue", "summary"),
208
+ "thaisum": ("body", "summary"),
209
+ "xglue": ("news_body", "news_title"),
210
+ "xsum": ("document", "summary"),
211
+ "wiki_summary": ("article", "highlights"),
212
+ }
213
+
214
+
215
+ def main():
216
+ # See all possible arguments in src/transformers/training_args.py
217
+ # or by passing the --help flag to this script.
218
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
219
+
220
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
221
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
222
+ # If we pass only one argument to the script and it's the path to a json file,
223
+ # let's parse it to get our arguments.
224
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
225
+ else:
226
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
227
+
228
+ # Detecting last checkpoint.
229
+ last_checkpoint = None
230
+ if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
231
+ last_checkpoint = get_last_checkpoint(training_args.output_dir)
232
+ if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
233
+ raise ValueError(
234
+ f"Output directory ({training_args.output_dir}) already exists and is not empty. "
235
+ "Use --overwrite_output_dir to overcome."
236
+ )
237
+ elif last_checkpoint is not None:
238
+ logger.info(
239
+ f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
240
+ "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
241
+ )
242
+
243
+ # Setup logging
244
+ logging.basicConfig(
245
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
246
+ datefmt="%m/%d/%Y %H:%M:%S",
247
+ handlers=[logging.StreamHandler(sys.stdout)],
248
+ )
249
+ logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)
250
+
251
+ # Log on each process the small summary:
252
+ logger.warning(
253
+ f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
254
+ + f", distributed training: {bool(training_args.local_rank != -1)}, 16-bit training: {training_args.fp16}"
255
+ )
256
+ # Set the verbosity to info of the Transformers logger (on main process only):
257
+ if is_main_process(training_args.local_rank):
258
+ transformers.utils.logging.set_verbosity_info()
259
+ logger.info("Training/evaluation parameters %s", training_args)
260
+
261
+ # Set seed before initializing model.
262
+ set_seed(training_args.seed)
263
+
264
+ # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
265
+ # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
266
+ # (the dataset will be downloaded automatically from the datasets Hub).
267
+ #
268
+ # For CSV/JSON files in the summarization task, this script will use the first column for the full texts and the
269
+ # second column for the summaries (unless you specify column names for this with the `text_column` and
270
+ # `summary_column` arguments).
271
+ # For translation, only JSON files are supported, with one field named "translation" containing two keys for the
272
+ # source and target languages (unless you adapt what follows).
273
+ #
274
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
275
+ # download the dataset.
276
+ if data_args.dataset_name is not None:
277
+ # Downloading and loading a dataset from the hub.
278
+ datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name)
279
+ else:
280
+ data_files = {}
281
+ if data_args.train_file is not None:
282
+ data_files["train"] = data_args.train_file
283
+ extension = data_args.train_file.split(".")[-1]
284
+ if data_args.validation_file is not None:
285
+ data_files["validation"] = data_args.validation_file
286
+ extension = data_args.validation_file.split(".")[-1]
287
+ datasets = load_dataset(extension, data_files=data_files)
288
+ # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
289
+ # https://huggingface.co/docs/datasets/loading_datasets.html.
290
+
291
+ # Load pretrained model and tokenizer
292
+ #
293
+ # Distributed training:
294
+ # The .from_pretrained methods guarantee that only one local process can concurrently
295
+ # download model & vocab.
296
+ config = AutoConfig.from_pretrained(
297
+ model_args.config_name if model_args.config_name else model_args.model_name_or_path,
298
+ cache_dir=model_args.cache_dir,
299
+ revision=model_args.model_revision,
300
+ use_auth_token=True if model_args.use_auth_token else None,
301
+ )
302
+ tokenizer = AutoTokenizer.from_pretrained(
303
+ model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
304
+ cache_dir=model_args.cache_dir,
305
+ use_fast=model_args.use_fast_tokenizer,
306
+ revision=model_args.model_revision,
307
+ use_auth_token=True if model_args.use_auth_token else None,
308
+ )
309
+ model = AutoModelForSeq2SeqLM.from_pretrained(
310
+ model_args.model_name_or_path,
311
+ from_tf=bool(".ckpt" in model_args.model_name_or_path),
312
+ config=config,
313
+ cache_dir=model_args.cache_dir,
314
+ revision=model_args.model_revision,
315
+ use_auth_token=True if model_args.use_auth_token else None,
316
+ )
317
+
318
+ # Set decoder_start_token_id
319
+ if model.config.decoder_start_token_id is None and isinstance(tokenizer, MBartTokenizer):
320
+ model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.target_lang]
321
+ if model.config.decoder_start_token_id is None:
322
+ raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
323
+
324
+ # Get the default prefix if None is passed.
325
+ if data_args.source_prefix is None:
326
+ task_specific_params = model.config.task_specific_params
327
+ if task_specific_params is not None:
328
+ prefix = task_specific_params.get("prefix", "")
329
+ else:
330
+ prefix = ""
331
+ else:
332
+ prefix = data_args.source_prefix
333
+
334
+ # Preprocessing the datasets.
335
+ # We need to tokenize inputs and targets.
336
+ if training_args.do_train:
337
+ column_names = datasets["train"].column_names
338
+ else:
339
+ column_names = datasets["validation"].column_names
340
+
341
+ # For translation we set the codes of our source and target languages (only useful for mBART, the others will
342
+ # ignore those attributes).
343
+ if data_args.task.startswith("translation"):
344
+ if data_args.source_lang is not None:
345
+ tokenizer.src_lang = data_args.source_lang
346
+ if data_args.target_lang is not None:
347
+ tokenizer.tgt_lang = data_args.target_lang
348
+
349
+ # To serialize preprocess_function below, each of those four variables needs to be defined (even if we won't use
350
+ # them all).
351
+ source_lang, target_lang, text_column, summary_column = None, None, None, None
352
+
353
+ if data_args.task.startswith("summarization"):
354
+ # Get the column names for input/target.
355
+ dataset_columns = summarization_name_mapping.get(data_args.dataset_name, None)
356
+ if data_args.text_column is None:
357
+ text_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
358
+ else:
359
+ text_column = data_args.text_column
360
+ if data_args.summary_column is None:
361
+ summary_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
362
+ else:
363
+ summary_column = data_args.summary_column
364
+ else:
365
+ # Get the language codes for input/target.
366
+ lang_search = re.match("translation_([a-z]+)_to_([a-z]+)", data_args.task)
367
+ if data_args.source_lang is not None:
368
+ source_lang = data_args.source_lang.split("_")[0]
369
+ else:
370
+ assert (
371
+ lang_search is not None
372
+ ), "Provide a source language via --source_lang or rename your task 'translation_xx_to_yy'."
373
+ source_lang = lang_search.groups()[0]
374
+
375
+ if data_args.target_lang is not None:
376
+ target_lang = data_args.target_lang.split("_")[0]
377
+ else:
378
+ assert (
379
+ lang_search is not None
380
+ ), "Provide a target language via --target_lang or rename your task 'translation_xx_to_yy'."
381
+ target_lang = lang_search.groups()[1]
382
+
383
+ # Temporarily set max_target_length for training.
384
+ max_target_length = data_args.max_target_length
385
+ padding = "max_length" if data_args.pad_to_max_length else False
386
+
387
+ def preprocess_function(examples):
388
+ if data_args.task.startswith("translation"):
389
+ inputs = [ex[source_lang] for ex in examples["translation"]]
390
+ targets = [ex[target_lang] for ex in examples["translation"]]
391
+ else:
392
+ inputs = examples[text_column]
393
+ targets = examples[summary_column]
394
+ inputs = [prefix + inp for inp in inputs]
395
+ model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True)
396
+
397
+ # Setup the tokenizer for targets
398
+ with tokenizer.as_target_tokenizer():
399
+ labels = tokenizer(targets, max_length=max_target_length, padding=padding, truncation=True)
400
+
401
+ # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
402
+ # padding in the loss.
403
+ if padding == "max_length" and data_args.ignore_pad_token_for_loss:
404
+ labels["input_ids"] = [
405
+ [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
406
+ ]
407
+
408
+ model_inputs["labels"] = labels["input_ids"]
409
+ return model_inputs
410
+
411
+ if training_args.do_train:
412
+ train_dataset = datasets["train"]
413
+ if data_args.max_train_samples is not None:
414
+ train_dataset = train_dataset.select(range(data_args.max_train_samples))
415
+ train_dataset = train_dataset.map(
416
+ preprocess_function,
417
+ batched=True,
418
+ num_proc=data_args.preprocessing_num_workers,
419
+ remove_columns=column_names,
420
+ load_from_cache_file=not data_args.overwrite_cache,
421
+ )
422
+
423
+ if training_args.do_eval:
424
+ max_target_length = data_args.val_max_target_length
425
+ eval_dataset = datasets["validation"]
426
+ if data_args.max_val_samples is not None:
427
+ eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
428
+ eval_dataset = eval_dataset.map(
429
+ preprocess_function,
430
+ batched=True,
431
+ num_proc=data_args.preprocessing_num_workers,
432
+ remove_columns=column_names,
433
+ load_from_cache_file=not data_args.overwrite_cache,
434
+ )
435
+
436
+ # Data collator
437
+ label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
438
+ if data_args.pad_to_max_length:
439
+ data_collator = default_data_collator
440
+ else:
441
+ data_collator = DataCollatorForSeq2Seq(
442
+ tokenizer,
443
+ label_pad_token_id=label_pad_token_id,
444
+ pad_to_multiple_of=8 if training_args.fp16 else None,
445
+ )
446
+
447
+ # Metric
448
+ metric_name = "rouge" if data_args.task.startswith("summarization") else "sacrebleu"
449
+ metric = load_metric(metric_name)
450
+
451
+ def compute_metrics(eval_preds):
452
+ preds, labels = eval_preds
453
+ if isinstance(preds, tuple):
454
+ preds = preds[0]
455
+ decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
456
+ if data_args.ignore_pad_token_for_loss:
457
+ # Replace -100 in the labels as we can't decode them.
458
+ labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
459
+ decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
460
+
461
+ # Some simple post-processing
462
+ decoded_preds = [pred.strip() for pred in decoded_preds]
463
+ decoded_labels = [label.strip() for label in decoded_labels]
464
+ if metric_name == "sacrebleu":
465
+ decoded_labels = [[label] for label in decoded_labels]
466
+
467
+ result = metric.compute(predictions=decoded_preds, references=decoded_labels)
468
+
469
+ # Extract a few results from ROUGE
470
+ if metric_name == "rouge":
471
+ result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
472
+ else:
473
+ result = {"bleu": result["score"]}
474
+
475
+ prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
476
+ result["gen_len"] = np.mean(prediction_lens)
477
+
478
+ return result
479
+
480
+ # Initialize our Trainer
481
+ trainer = Seq2SeqTrainer(
482
+ model=model,
483
+ args=training_args,
484
+ train_dataset=train_dataset if training_args.do_train else None,
485
+ eval_dataset=eval_dataset if training_args.do_eval else None,
486
+ tokenizer=tokenizer,
487
+ data_collator=data_collator,
488
+ compute_metrics=compute_metrics if training_args.predict_with_generate else None,
489
+ )
490
+
491
+ # Training
492
+ if training_args.do_train:
493
+ if last_checkpoint is not None:
494
+ checkpoint = last_checkpoint
495
+ elif os.path.isdir(model_args.model_name_or_path):
496
+ checkpoint = model_args.model_name_or_path
497
+ else:
498
+ checkpoint = None
499
+ train_result = trainer.train(resume_from_checkpoint=checkpoint)
500
+ trainer.save_model() # Saves the tokenizer too for easy upload
501
+
502
+ output_train_file = os.path.join(training_args.output_dir, "train_results.txt")
503
+ if trainer.is_world_process_zero():
504
+ with open(output_train_file, "w") as writer:
505
+ logger.info("***** Train results *****")
506
+ for key, value in sorted(train_result.metrics.items()):
507
+ logger.info(f" {key} = {value}")
508
+ writer.write(f"{key} = {value}\n")
509
+
510
+ # Need to save the state, since Trainer.save_model saves only the tokenizer with the model
511
+ trainer.state.save_to_json(os.path.join(training_args.output_dir, "trainer_state.json"))
512
+
513
+ # Evaluation
514
+ results = {}
515
+ if training_args.do_eval:
516
+ logger.info("*** Evaluate ***")
517
+
518
+ results = trainer.evaluate()
519
+
520
+ output_eval_file = os.path.join(training_args.output_dir, "eval_results_seq2seq.txt")
521
+ if trainer.is_world_process_zero():
522
+ with open(output_eval_file, "w") as writer:
523
+ logger.info("***** Eval results *****")
524
+ for key, value in sorted(results.items()):
525
+ logger.info(f" {key} = {value}")
526
+ writer.write(f"{key} = {value}\n")
527
+
528
+ return results
529
+
530
+
531
+ def _mp_fn(index):
532
+ # For xla_spawn (TPUs)
533
+ main()
534
+
535
+
536
+ if __name__ == "__main__":
537
+ main()
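For reference, the parser set up in main() accepts either ordinary command-line flags or a single path to a JSON file. The following is a minimal sketch of such a config written from Python; every path, column name and hyperparameter value here is hypothetical, and the script filename finetune_seq2seq.py is assumed for illustration rather than taken from this diff.

# A minimal, hypothetical JSON config for the argument dataclasses above.
import json

args = {
    "model_name_or_path": "facebook/bart-base",   # any seq2seq checkpoint
    "task": "summarization",                      # or "translation_xx_to_yy"
    "train_file": "data/train.csv",               # CSV/JSON with input/target columns
    "validation_file": "data/valid.csv",
    "text_column": "source",                      # hypothetical column names
    "summary_column": "target",
    "max_source_length": 128,
    "max_target_length": 128,
    "output_dir": "out/seq2seq",
    "overwrite_output_dir": True,
    "do_train": True,
    "do_eval": True,
    "predict_with_generate": True,
    "per_device_train_batch_size": 8,
    "num_train_epochs": 3,
}

with open("finetune_args.json", "w") as f:
    json.dump(args, f, indent=2)

# The script can then be launched with the JSON path as its only argument:
#   python finetune_seq2seq.py finetune_args.json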
requirements.txt ADDED
@@ -0,0 +1,9 @@
1
+ datasets==1.3.0
2
+ editdistance==0.6.0
3
+ nltk==3.7
4
+ numpy==1.22.3
5
+ spacy==3.0.5
6
+ streamlit==1.9.0
7
+ torch==1.11.0
8
+ tqdm==4.49.0
9
+ transformers==4.3.3
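These pins cover the pip-installable packages only. The NLTK sentence splitter used in utils/text_utils.py below also needs the "punkt" tokenizer data, which is fetched at runtime rather than by pip; a one-time setup sketch:

# One-time download of the sentence-tokenizer data required by nltk.sent_tokenize.
import nltk
nltk.download("punkt")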
utils/__pycache__/spacy_tokenizer.cpython-38.pyc ADDED
Binary file (2.16 kB)

utils/__pycache__/text_utils.cpython-37.pyc ADDED
Binary file (1.81 kB)

utils/__pycache__/text_utils.cpython-38.pyc ADDED
Binary file (1.82 kB)
utils/spacy_tokenizer.py ADDED
@@ -0,0 +1,62 @@
1
+ import spacy
2
+ from spacy.tokenizer import Tokenizer
3
+ from spacy.lang.char_classes import ALPHA, ALPHA_LOWER, ALPHA_UPPER, CONCAT_QUOTES, LIST_ELLIPSES, LIST_ICONS, HYPHENS
4
+ from spacy.util import compile_infix_regex
5
+ from spacy.lang.en import English
6
+ nlp = English()
7
+
8
+ def get_tokenizer_gec(nlp):
9
+ infixes = (
10
+ LIST_ELLIPSES
11
+ + LIST_ICONS
12
+ + [
13
+ r"(?<=[0-9])[+\-\*^](?=[0-9-])",
14
+ r"(?<=[{al}{q}])\.(?=[{au}{q}])".format(
15
+ al=ALPHA_LOWER, au=ALPHA_UPPER, q=CONCAT_QUOTES
16
+ ),
17
+ r"(?<=[{a}]),(?=[{a}])".format(a=ALPHA),
18
+ #r"(?<=[{a}])(?:{h})(?=[{a}])".format(a=ALPHA, h=HYPHENS),
19
+ r"(?<=[{a}0-9])[:<>=/](?=[{a}])".format(a=ALPHA),
20
+ ]
21
+ )
22
+ infix_re = compile_infix_regex(infixes)
23
+ return Tokenizer(nlp.vocab, prefix_search=nlp.tokenizer.prefix_search,
24
+ suffix_search=nlp.tokenizer.suffix_search,
25
+ infix_finditer=infix_re.finditer,
26
+ token_match=nlp.tokenizer.token_match,
27
+ rules=nlp.Defaults.tokenizer_exceptions)
28
+
29
+
30
+ def get_tokenizer_bea19(nlp):
31
+ infixes = (
32
+ LIST_ELLIPSES
33
+ + LIST_ICONS
34
+ + [
35
+ r"(?<=[0-9])[+\-\*^](?=[0-9-])",
36
+ r"(?<=[{al}{q}])\.(?=[{au}{q}])".format(
37
+ al=ALPHA_LOWER, au=ALPHA_UPPER, q=CONCAT_QUOTES
38
+ ),
39
+ r"(?<=[{a}]),(?=[{a}])".format(a=ALPHA),
40
+ r"(?<=[{a}])(?:{h})(?=[{a}])".format(a=ALPHA, h=HYPHENS),
41
+ r"(?<=[{a}0-9])[:<>=/](?=[{a}])".format(a=ALPHA),
42
+ ]
43
+ )
44
+ infix_re = compile_infix_regex(infixes)
45
+ return Tokenizer(nlp.vocab, prefix_search=nlp.tokenizer.prefix_search,
46
+ suffix_search=nlp.tokenizer.suffix_search,
47
+ infix_finditer=infix_re.finditer,
48
+ token_match=nlp.tokenizer.token_match,
49
+ rules=nlp.Defaults.tokenizer_exceptions)
50
+
51
+
52
+ tokenizer_gec = get_tokenizer_gec(nlp)
53
+ tokenizer_bea19 = get_tokenizer_bea19(nlp)
54
+
55
+
56
+ def spacy_tokenize_gec(text):
57
+ nlp.tokenizer = tokenizer_gec
58
+ return [str(w) for w in nlp(text)]
59
+
60
+ def spacy_tokenize_bea19(text):
61
+ nlp.tokenizer = tokenizer_bea19
62
+ return [str(w) for w in nlp(text)]
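A brief usage sketch of the two tokenizers defined above; the sample sentence is invented, and the import assumes the repository root is on PYTHONPATH. The only difference between the two is the hyphen infix rule (commented out in the GEC variant), so the GEC tokenizer keeps intra-word hyphens together while the BEA-19 tokenizer splits them; both split punctuation sitting directly between letters, such as a comma with no following space.

# Usage sketch for the tokenizers above; the sample text is made up.
from utils.spacy_tokenizer import spacy_tokenize_gec, spacy_tokenize_bea19

text = "I like apples,but the state-of-the-art doesn't."
print(spacy_tokenize_gec(text))    # keeps "state-of-the-art" as one token
print(spacy_tokenize_bea19(text))  # also splits the intra-word hyphens

Note that each helper swaps the module-level nlp.tokenizer before tokenizing, so the two functions should not be interleaved across threads.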
utils/text_utils.py ADDED
@@ -0,0 +1,65 @@
1
+ import re
2
+ from nltk import sent_tokenize, word_tokenize
3
+ from nltk.tokenize.treebank import TreebankWordDetokenizer
4
+ detokenizer = TreebankWordDetokenizer()
5
+
6
+ def handle_double_quote(sent):
7
+ cur_str = ''
8
+ exp_left = True
9
+ ignore_space = False
10
+ for char in sent:
11
+ if char == '"':
12
+ if exp_left: #this is a left "
13
+ cur_str = cur_str.rstrip() + ' "'
14
+ exp_left = (not exp_left)
15
+ ignore_space = True
16
+ else: #this is a right "
17
+ cur_str = cur_str.rstrip() + '" '
18
+ exp_left = (not exp_left)
19
+ ignore_space = False
20
+ else:
21
+ if ignore_space: #expecting right
22
+ if char == ' ':
23
+ continue
24
+ else:
25
+ cur_str = cur_str + char
26
+ ignore_space = False
27
+ else:
28
+ cur_str = cur_str + char
29
+ cur_str = cur_str.strip()
30
+ cur_str = re.sub(r'[ ]+', ' ', cur_str)
31
+ return cur_str
32
+
33
+ def postprocess_space(sent):
34
+ sent = re.sub(r'[ ]+\.', '.', sent)
35
+ sent = re.sub(r'[ ]+,', ',', sent)
36
+ sent = re.sub(r'[ ]+!', '!', sent)
37
+ sent = re.sub(r'[ ]+\?', '?', sent)
38
+ sent = re.sub(r'\([ ]+', '(', sent)
39
+ sent = re.sub(r'[ ]+\)', ')', sent)
40
+ sent = re.sub(r' \'s( |\.|,|!|\?)', r"'s\1", sent)
41
+ sent = re.sub(r'n \'t( |\.|,|!|\?)', r"n't\1", sent)
42
+ return sent
43
+
44
+ def detokenize_sent(sent):
45
+ #Clean raw sent
46
+ sent = re.sub(r'\' s ', '\'s ', sent)
47
+ toks = sent.split()
48
+ if len([1 for t in toks if t=="'"]) % 2 == 0:
49
+ toks = ['"' if t=="'" else t for t in toks]
50
+ sent = ' '.join(toks)
51
+ #
52
+ sents = sent_tokenize(sent)
53
+ final_sents = []
54
+ for _sent in sents:
55
+ _sent = detokenizer.detokenize(_sent.split())
56
+ res = handle_double_quote(_sent)
57
+ if res == -1:
58
+ print ('unbalanced double quote')
59
+ print (_sent)
60
+ else:
61
+ _sent = res
62
+ final_sents.append(_sent)
63
+ sent = ' '.join(final_sents)
64
+ sent = postprocess_space(sent)
65
+ return sent
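Finally, a short usage sketch for detokenize_sent, the entry point of this module; the tokenized input is invented and the NLTK "punkt" data mentioned above must be available. The function sentence-splits a space-separated token sequence, rejoins contractions and quotes, and tightens the spacing before punctuation.

# Usage sketch for detokenize_sent; the tokenized input below is made up.
from utils.text_utils import detokenize_sent

tokens = 'He said , " I do n\'t know " .'
print(detokenize_sent(tokens))
# Expected to come back roughly as: He said, "I don't know".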