ngohuudang commited on
Commit
1b76ad1
1 Parent(s): 9bb5ff5

update file

Browse files
.gitattributes CHANGED
@@ -2,34 +2,26 @@
2
  *.arrow filter=lfs diff=lfs merge=lfs -text
3
  *.bin filter=lfs diff=lfs merge=lfs -text
4
  *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
  *.ftz filter=lfs diff=lfs merge=lfs -text
7
  *.gz filter=lfs diff=lfs merge=lfs -text
8
  *.h5 filter=lfs diff=lfs merge=lfs -text
9
  *.joblib filter=lfs diff=lfs merge=lfs -text
10
  *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
  *.model filter=lfs diff=lfs merge=lfs -text
13
  *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
  *.onnx filter=lfs diff=lfs merge=lfs -text
17
  *.ot filter=lfs diff=lfs merge=lfs -text
18
  *.parquet filter=lfs diff=lfs merge=lfs -text
19
  *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
  *.pt filter=lfs diff=lfs merge=lfs -text
23
  *.pth filter=lfs diff=lfs merge=lfs -text
24
  *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
  saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
  *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
  *.tflite filter=lfs diff=lfs merge=lfs -text
30
  *.tgz filter=lfs diff=lfs merge=lfs -text
31
  *.wasm filter=lfs diff=lfs merge=lfs -text
32
  *.xz filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
2
  *.arrow filter=lfs diff=lfs merge=lfs -text
3
  *.bin filter=lfs diff=lfs merge=lfs -text
4
  *.bz2 filter=lfs diff=lfs merge=lfs -text
 
5
  *.ftz filter=lfs diff=lfs merge=lfs -text
6
  *.gz filter=lfs diff=lfs merge=lfs -text
7
  *.h5 filter=lfs diff=lfs merge=lfs -text
8
  *.joblib filter=lfs diff=lfs merge=lfs -text
9
  *.lfs.* filter=lfs diff=lfs merge=lfs -text
 
10
  *.model filter=lfs diff=lfs merge=lfs -text
11
  *.msgpack filter=lfs diff=lfs merge=lfs -text
 
 
12
  *.onnx filter=lfs diff=lfs merge=lfs -text
13
  *.ot filter=lfs diff=lfs merge=lfs -text
14
  *.parquet filter=lfs diff=lfs merge=lfs -text
15
  *.pb filter=lfs diff=lfs merge=lfs -text
 
 
16
  *.pt filter=lfs diff=lfs merge=lfs -text
17
  *.pth filter=lfs diff=lfs merge=lfs -text
18
  *.rar filter=lfs diff=lfs merge=lfs -text
 
19
  saved_model/**/* filter=lfs diff=lfs merge=lfs -text
20
  *.tar.* filter=lfs diff=lfs merge=lfs -text
 
21
  *.tflite filter=lfs diff=lfs merge=lfs -text
22
  *.tgz filter=lfs diff=lfs merge=lfs -text
23
  *.wasm filter=lfs diff=lfs merge=lfs -text
24
  *.xz filter=lfs diff=lfs merge=lfs -text
25
  *.zip filter=lfs diff=lfs merge=lfs -text
26
+ *.zstandard filter=lfs diff=lfs merge=lfs -text
27
  *tfevents* filter=lfs diff=lfs merge=lfs -text
__pycache__/gec_model.cpython-310.pyc ADDED
Binary file (14.1 kB). View file
 
__pycache__/gec_model.cpython-311.pyc ADDED
Binary file (25.9 kB). View file
 
__pycache__/gec_model.cpython-39.pyc ADDED
Binary file (14.2 kB). View file
 
__pycache__/modeling_seq2labels.cpython-310.pyc ADDED
Binary file (3.97 kB). View file
 
__pycache__/modeling_seq2labels.cpython-311.pyc ADDED
Binary file (7.03 kB). View file
 
__pycache__/modeling_seq2labels.cpython-39.pyc ADDED
Binary file (4.06 kB). View file
 
__pycache__/utils.cpython-310.pyc ADDED
Binary file (6.13 kB). View file
 
__pycache__/utils_gec.cpython-310.pyc ADDED
Binary file (6.14 kB). View file
 
__pycache__/utils_gec.cpython-311.pyc ADDED
Binary file (11.8 kB). View file
 
__pycache__/utils_gec.cpython-39.pyc ADDED
Binary file (6.12 kB). View file
 
__pycache__/vocabulary.cpython-310.pyc ADDED
Binary file (12.9 kB). View file
 
__pycache__/vocabulary.cpython-311.pyc ADDED
Binary file (18.9 kB). View file
 
__pycache__/vocabulary.cpython-39.pyc ADDED
Binary file (13 kB). View file
 
config.json ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "Seq2LabelsModel"
4
+ ],
5
+ "initializer_range": 0.02,
6
+ "label_smoothing": 0.0,
7
+ "load_pretrained": false,
8
+ "model_type": "bert",
9
+ "num_detect_classes": 4,
10
+ "pad_token_id": 0,
11
+ "predictor_dropout": 0.0,
12
+ "pretrained_name_or_path": "xlm-roberta-capu/xlm-roberta-base",
13
+ "special_tokens_fix": true,
14
+ "torch_dtype": "float32",
15
+ "transformers_version": "4.18.0",
16
+ "use_cache": true,
17
+ "vocab_size": 15
18
+ }
configuration_seq2labels.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+
3
+
4
+ class Seq2LabelsConfig(PretrainedConfig):
5
+ r"""
6
+ This is the configuration class to store the configuration of a [`Seq2LabelsModel`]. It is used to
7
+ instantiate a Seq2Labels model according to the specified arguments, defining the model architecture. Instantiating a
8
+ configuration with the defaults will yield a similar configuration to that of the Seq2Labels architecture.
9
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
10
+ documentation from [`PretrainedConfig`] for more information.
11
+ Args:
12
+ vocab_size (`int`, *optional*, defaults to 30522):
13
+ Vocabulary size of the BERT model. Defines the number of different tokens that can be represented by the
14
+ `inputs_ids` passed when calling [`BertModel`] or [`TFBertModel`].
15
+ pretrained_name_or_path (`str`, *optional*, defaults to `bert-base-cased`):
16
+ Pretrained BERT-like model path
17
+ load_pretrained (`bool`, *optional*, defaults to `False`):
18
+ Whether to load pretrained model from `pretrained_name_or_path`
19
+ use_cache (`bool`, *optional*, defaults to `True`):
20
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
21
+ relevant if `config.is_decoder=True`.
22
+ predictor_dropout (`float`, *optional*):
23
+ The dropout ratio for the classification head.
24
+ special_tokens_fix (`bool`, *optional*, defaults to `False`):
25
+ Whether to add additional tokens to the BERT's embedding layer.
26
+ Examples:
27
+ ```python
28
+ >>> from transformers import BertModel, BertConfig
29
+ >>> # Initializing a Seq2Labels style configuration
30
+ >>> configuration = Seq2LabelsConfig()
31
+ >>> # Initializing a model from the bert-base-uncased style configuration
32
+ >>> model = Seq2LabelsModel(configuration)
33
+ >>> # Accessing the model configuration
34
+ >>> configuration = model.config
35
+ ```"""
36
+ model_type = "bert"
37
+
38
+ def __init__(
39
+ self,
40
+ pretrained_name_or_path="bert-base-cased",
41
+ vocab_size=15,
42
+ num_detect_classes=4,
43
+ load_pretrained=False,
44
+ initializer_range=0.02,
45
+ pad_token_id=0,
46
+ use_cache=True,
47
+ predictor_dropout=0.0,
48
+ special_tokens_fix=False,
49
+ label_smoothing=0.0,
50
+ **kwargs
51
+ ):
52
+ super().__init__(pad_token_id=pad_token_id, **kwargs)
53
+
54
+ self.vocab_size = vocab_size
55
+ self.num_detect_classes = num_detect_classes
56
+ self.pretrained_name_or_path = pretrained_name_or_path
57
+ self.load_pretrained = load_pretrained
58
+ self.initializer_range = initializer_range
59
+ self.use_cache = use_cache
60
+ self.predictor_dropout = predictor_dropout
61
+ self.special_tokens_fix = special_tokens_fix
62
+ self.label_smoothing = label_smoothing
gec_model.py ADDED
@@ -0,0 +1,449 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Wrapper of Seq2Labels model. Fixes errors based on model predictions"""
2
+ from collections import defaultdict
3
+ from difflib import SequenceMatcher
4
+ import logging
5
+ import re
6
+ from time import time
7
+ from typing import List, Union
8
+ import warnings
9
+ import sys
10
+ import torch
11
+ from transformers import AutoTokenizer
12
+ from modeling_seq2labels import Seq2LabelsModel
13
+ from vocabulary import Vocabulary
14
+ from utils_gec import PAD, UNK, START_TOKEN, get_target_sent_by_edits
15
+ current_dir = sys.path[0].replace('\\','/')
16
+ logging.getLogger("werkzeug").setLevel(logging.ERROR)
17
+ logger = logging.getLogger(__file__)
18
+
19
+
20
+ class GecBERTModel(torch.nn.Module):
21
+ def __init__(
22
+ self,
23
+ vocab_path=None,
24
+ model_paths=None,
25
+ weights=None,
26
+ device=None,
27
+ max_len=64,
28
+ min_len=3,
29
+ lowercase_tokens=False,
30
+ log=False,
31
+ iterations=3,
32
+ min_error_probability=0.0,
33
+ confidence=0,
34
+ resolve_cycles=False,
35
+ split_chunk=False,
36
+ chunk_size=48,
37
+ overlap_size=12,
38
+ min_words_cut=6,
39
+ punc_dict={':', ".", ",", "?"},
40
+ ):
41
+ r"""
42
+ Args:
43
+ vocab_path (`str`):
44
+ Path to vocabulary directory.
45
+ model_paths (`List[str]`):
46
+ List of model paths.
47
+ weights (`int`, *Optional*, defaults to None):
48
+ Weights of each model. Only relevant if `is_ensemble is True`.
49
+ device (`int`, *Optional*, defaults to None):
50
+ Device to load model. If not set, device will be automatically choose.
51
+ max_len (`int`, defaults to 64):
52
+ Max sentence length to be processed (all longer will be truncated).
53
+ min_len (`int`, defaults to 3):
54
+ Min sentence length to be processed (all shorted will be returned w/o changes).
55
+ lowercase_tokens (`bool`, defaults to False):
56
+ Whether to lowercase tokens.
57
+ log (`bool`, defaults to False):
58
+ Whether to enable logging.
59
+ iterations (`int`, defaults to 3):
60
+ Max iterations to run during inference.
61
+ special_tokens_fix (`bool`, defaults to True):
62
+ Whether to fix problem with [CLS], [SEP] tokens tokenization.
63
+ min_error_probability (`float`, defaults to `0.0`):
64
+ Minimum probability for each action to apply.
65
+ confidence (`float`, defaults to `0.0`):
66
+ How many probability to add to $KEEP token.
67
+ split_chunk (`bool`, defaults to False):
68
+ Whether to split long sentences to multiple segments of `chunk_size`.
69
+ !Warning: if `chunk_size > max_len`, each segment will be truncate to `max_len`.
70
+ chunk_size (`int`, defaults to 48):
71
+ Length of each segment (in words). Only relevant if `split_chunk is True`.
72
+ overlap_size (`int`, defaults to 12):
73
+ Overlap size (in words) between two consecutive segments. Only relevant if `split_chunk is True`.
74
+ min_words_cut (`int`, defaults to 6):
75
+ Minimun number of words to be cut while merging two consecutive segments.
76
+ Only relevant if `split_chunk is True`.
77
+ punc_dict (List[str], defaults to `{':', ".", ",", "?"}`):
78
+ List of punctuations.
79
+ """
80
+ super().__init__()
81
+ if isinstance(model_paths, str):
82
+ model_paths = [model_paths]
83
+ self.model_weights = list(map(float, weights)) if weights else [1] * len(model_paths)
84
+ self.device = (
85
+ torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else torch.device(device)
86
+ )
87
+ # self.device = torch.device("cpu")
88
+ self.max_len = max_len
89
+ self.min_len = min_len
90
+ self.lowercase_tokens = lowercase_tokens
91
+ self.min_error_probability = min_error_probability
92
+ self.vocab = Vocabulary.from_files(vocab_path)
93
+ self.incorr_index = self.vocab.get_token_index("INCORRECT", "d_tags")
94
+ self.log = log
95
+ self.iterations = iterations
96
+ self.confidence = confidence
97
+ self.resolve_cycles = resolve_cycles
98
+
99
+ assert (
100
+ chunk_size > 0 and chunk_size // 2 >= overlap_size
101
+ ), "Chunk merging required overlap size must be smaller than half of chunk size"
102
+ self.split_chunk = split_chunk
103
+ self.chunk_size = chunk_size
104
+ self.overlap_size = overlap_size
105
+ self.min_words_cut = min_words_cut
106
+ self.stride = chunk_size - overlap_size
107
+ self.punc_dict = punc_dict
108
+ self.punc_str = '[' + ''.join([f'\{x}' for x in punc_dict]) + ']'
109
+ # set training parameters and operations
110
+
111
+ self.indexers = []
112
+ self.models = []
113
+ for model_path in model_paths:
114
+ model = Seq2LabelsModel.from_pretrained(model_path)
115
+ config = model.config
116
+ model_name = current_dir + "/" + config.pretrained_name_or_path
117
+ special_tokens_fix = config.special_tokens_fix
118
+ self.indexers.append(self._get_indexer(model_name, special_tokens_fix))
119
+ model.eval().to(self.device)
120
+ self.models.append(model)
121
+
122
+ def _get_indexer(self, weights_name, special_tokens_fix):
123
+ tokenizer = AutoTokenizer.from_pretrained(
124
+ weights_name, do_basic_tokenize=False,
125
+ do_lower_case=self.lowercase_tokens, model_max_length=1024
126
+ )
127
+ # to adjust all tokenizers
128
+ if hasattr(tokenizer, 'encoder'):
129
+ tokenizer.vocab = tokenizer.encoder
130
+ if hasattr(tokenizer, 'sp_model'):
131
+ tokenizer.vocab = defaultdict(lambda: 1)
132
+ for i in range(tokenizer.sp_model.get_piece_size()):
133
+ tokenizer.vocab[tokenizer.sp_model.id_to_piece(i)] = i
134
+
135
+ if special_tokens_fix:
136
+ tokenizer.add_tokens([START_TOKEN])
137
+ tokenizer.vocab[START_TOKEN] = len(tokenizer) - 1
138
+ return tokenizer
139
+
140
+ def forward(self, text: Union[str, List[str], List[List[str]]], is_split_into_words=False):
141
+ # Input type checking for clearer error
142
+ def _is_valid_text_input(t):
143
+ if isinstance(t, str):
144
+ # Strings are fine
145
+ return True
146
+ elif isinstance(t, (list, tuple)):
147
+ # List are fine as long as they are...
148
+ if len(t) == 0:
149
+ # ... empty
150
+ return True
151
+ elif isinstance(t[0], str):
152
+ # ... list of strings
153
+ return True
154
+ elif isinstance(t[0], (list, tuple)):
155
+ # ... list with an empty list or with a list of strings
156
+ return len(t[0]) == 0 or isinstance(t[0][0], str)
157
+ else:
158
+ return False
159
+ else:
160
+ return False
161
+
162
+ if not _is_valid_text_input(text):
163
+ raise ValueError(
164
+ "text input must of type `str` (single example), `List[str]` (batch or single pretokenized example) "
165
+ "or `List[List[str]]` (batch of pretokenized examples)."
166
+ )
167
+
168
+ if is_split_into_words:
169
+ is_batched = isinstance(text, (list, tuple)) and text and isinstance(text[0], (list, tuple))
170
+ else:
171
+ is_batched = isinstance(text, (list, tuple))
172
+ if is_batched:
173
+ text = [x.split() for x in text]
174
+ else:
175
+ text = text.split()
176
+
177
+ if not is_batched:
178
+ text = [text]
179
+
180
+ return self.handle_batch(text)
181
+
182
+ def split_chunks(self, batch):
183
+ # return batch pairs of indices
184
+ result = []
185
+ indices = []
186
+ for tokens in batch:
187
+ start = len(result)
188
+ num_token = len(tokens)
189
+ if num_token <= self.chunk_size:
190
+ result.append(tokens)
191
+ elif num_token > self.chunk_size and num_token < (self.chunk_size * 2 - self.overlap_size):
192
+ split_idx = (num_token + self.overlap_size + 1) // 2
193
+ result.append(tokens[:split_idx])
194
+ result.append(tokens[split_idx - self.overlap_size :])
195
+ else:
196
+ for i in range(0, num_token - self.overlap_size, self.stride):
197
+ result.append(tokens[i : i + self.chunk_size])
198
+
199
+ indices.append((start, len(result)))
200
+
201
+ return result, indices
202
+
203
+ def check_alnum(self, s):
204
+ if len(s) < 2:
205
+ return False
206
+ return not (s.isalpha() or s.isdigit())
207
+
208
+ def apply_chunk_merging(self, tokens, next_tokens):
209
+ # Return next tokens if current tokens list is empty
210
+ if not tokens:
211
+ return next_tokens
212
+
213
+ source_token_idx = []
214
+ target_token_idx = []
215
+ source_tokens = []
216
+ target_tokens = []
217
+ num_keep = self.overlap_size - self.min_words_cut
218
+ i = 0
219
+ while len(source_token_idx) < self.overlap_size and -i < len(tokens):
220
+ i -= 1
221
+ if tokens[i] not in self.punc_dict:
222
+ source_token_idx.insert(0, i)
223
+ source_tokens.insert(0, tokens[i].lower())
224
+
225
+ i = 0
226
+ while len(target_token_idx) < self.overlap_size and i < len(next_tokens):
227
+ if next_tokens[i] not in self.punc_dict:
228
+ target_token_idx.append(i)
229
+ target_tokens.append(next_tokens[i].lower())
230
+ i += 1
231
+
232
+ matcher = SequenceMatcher(None, source_tokens, target_tokens)
233
+ diffs = list(matcher.get_opcodes())
234
+
235
+ for diff in diffs:
236
+ tag, i1, i2, j1, j2 = diff
237
+ if tag == "equal":
238
+ if i1 >= num_keep:
239
+ tail_idx = source_token_idx[i1]
240
+ head_idx = target_token_idx[j1]
241
+ break
242
+ elif i2 > num_keep:
243
+ tail_idx = source_token_idx[num_keep]
244
+ head_idx = target_token_idx[j2 - i2 + num_keep]
245
+ break
246
+ elif tag == "delete" and i1 == 0:
247
+ num_keep += i2 // 2
248
+
249
+ tokens = tokens[:tail_idx] + next_tokens[head_idx:]
250
+ return tokens
251
+
252
+ def merge_chunks(self, batch):
253
+ result = []
254
+ if len(batch) == 1 or self.overlap_size == 0:
255
+ for sub_tokens in batch:
256
+ result.extend(sub_tokens)
257
+ else:
258
+ for _, sub_tokens in enumerate(batch):
259
+ try:
260
+ result = self.apply_chunk_merging(result, sub_tokens)
261
+ except Exception as e:
262
+ print(e)
263
+
264
+ result = " ".join(result)
265
+ return result
266
+
267
+ def predict(self, batches):
268
+ t11 = time()
269
+ predictions = []
270
+ for batch, model in zip(batches, self.models):
271
+ batch = batch.to(self.device)
272
+ with torch.no_grad():
273
+ prediction = model.forward(**batch)
274
+ predictions.append(prediction)
275
+
276
+ preds, idx, error_probs = self._convert(predictions)
277
+ t55 = time()
278
+ if self.log:
279
+ print(f"Inference time {t55 - t11}")
280
+ return preds, idx, error_probs
281
+
282
+ def get_token_action(self, token, index, prob, sugg_token):
283
+ """Get lost of suggested actions for token."""
284
+ # cases when we don't need to do anything
285
+ if prob < self.min_error_probability or sugg_token in [UNK, PAD, '$KEEP']:
286
+ return None
287
+
288
+ if sugg_token.startswith('$REPLACE_') or sugg_token.startswith('$TRANSFORM_') or sugg_token == '$DELETE':
289
+ start_pos = index
290
+ end_pos = index + 1
291
+ elif sugg_token.startswith("$APPEND_") or sugg_token.startswith("$MERGE_"):
292
+ start_pos = index + 1
293
+ end_pos = index + 1
294
+
295
+ if sugg_token == "$DELETE":
296
+ sugg_token_clear = ""
297
+ elif sugg_token.startswith('$TRANSFORM_') or sugg_token.startswith("$MERGE_"):
298
+ sugg_token_clear = sugg_token[:]
299
+ else:
300
+ sugg_token_clear = sugg_token[sugg_token.index('_') + 1 :]
301
+
302
+ return start_pos - 1, end_pos - 1, sugg_token_clear, prob
303
+
304
+ def preprocess(self, token_batch):
305
+ seq_lens = [len(sequence) for sequence in token_batch if sequence]
306
+ if not seq_lens:
307
+ return []
308
+ max_len = min(max(seq_lens), self.max_len)
309
+ batches = []
310
+ for indexer in self.indexers:
311
+ token_batch = [[START_TOKEN] + sequence[:max_len] for sequence in token_batch]
312
+ batch = indexer(
313
+ token_batch,
314
+ return_tensors="pt",
315
+ padding=True,
316
+ is_split_into_words=True,
317
+ truncation=True,
318
+ add_special_tokens=False,
319
+ )
320
+ offset_batch = []
321
+ for i in range(len(token_batch)):
322
+ word_ids = batch.word_ids(batch_index=i)
323
+ offsets = [0]
324
+ for i in range(1, len(word_ids)):
325
+ if word_ids[i] != word_ids[i - 1]:
326
+ offsets.append(i)
327
+ offset_batch.append(torch.LongTensor(offsets))
328
+
329
+ batch["input_offsets"] = torch.nn.utils.rnn.pad_sequence(
330
+ offset_batch, batch_first=True, padding_value=0
331
+ ).to(torch.long)
332
+
333
+ batches.append(batch)
334
+
335
+ return batches
336
+
337
+ def _convert(self, data):
338
+ all_class_probs = torch.zeros_like(data[0]['logits'])
339
+ error_probs = torch.zeros_like(data[0]['max_error_probability'])
340
+ for output, weight in zip(data, self.model_weights):
341
+ class_probabilities_labels = torch.softmax(output['logits'], dim=-1)
342
+ all_class_probs += weight * class_probabilities_labels / sum(self.model_weights)
343
+ class_probabilities_d = torch.softmax(output['detect_logits'], dim=-1)
344
+ error_probs_d = class_probabilities_d[:, :, self.incorr_index]
345
+ incorr_prob = torch.max(error_probs_d, dim=-1)[0]
346
+ error_probs += weight * incorr_prob / sum(self.model_weights)
347
+
348
+ max_vals = torch.max(all_class_probs, dim=-1)
349
+ probs = max_vals[0].tolist()
350
+ idx = max_vals[1].tolist()
351
+ return probs, idx, error_probs.tolist()
352
+
353
+ def update_final_batch(self, final_batch, pred_ids, pred_batch, prev_preds_dict):
354
+ new_pred_ids = []
355
+ total_updated = 0
356
+ for i, orig_id in enumerate(pred_ids):
357
+ orig = final_batch[orig_id]
358
+ pred = pred_batch[i]
359
+ prev_preds = prev_preds_dict[orig_id]
360
+ if orig != pred and pred not in prev_preds:
361
+ final_batch[orig_id] = pred
362
+ new_pred_ids.append(orig_id)
363
+ prev_preds_dict[orig_id].append(pred)
364
+ total_updated += 1
365
+ elif orig != pred and pred in prev_preds:
366
+ # update final batch, but stop iterations
367
+ final_batch[orig_id] = pred
368
+ total_updated += 1
369
+ else:
370
+ continue
371
+ return final_batch, new_pred_ids, total_updated
372
+
373
+ def postprocess_batch(self, batch, all_probabilities, all_idxs, error_probs):
374
+ all_results = []
375
+ noop_index = self.vocab.get_token_index("$KEEP", "labels")
376
+ for tokens, probabilities, idxs, error_prob in zip(batch, all_probabilities, all_idxs, error_probs):
377
+ length = min(len(tokens), self.max_len)
378
+ edits = []
379
+
380
+ # skip whole sentences if there no errors
381
+ if max(idxs) == 0:
382
+ all_results.append(tokens)
383
+ continue
384
+
385
+ # skip whole sentence if probability of correctness is not high
386
+ if error_prob < self.min_error_probability:
387
+ all_results.append(tokens)
388
+ continue
389
+
390
+ for i in range(length + 1):
391
+ # because of START token
392
+ if i == 0:
393
+ token = START_TOKEN
394
+ else:
395
+ token = tokens[i - 1]
396
+ # skip if there is no error
397
+ if idxs[i] == noop_index:
398
+ continue
399
+
400
+ sugg_token = self.vocab.get_token_from_index(idxs[i], namespace='labels')
401
+ action = self.get_token_action(token, i, probabilities[i], sugg_token)
402
+ if not action:
403
+ continue
404
+
405
+ edits.append(action)
406
+ all_results.append(get_target_sent_by_edits(tokens, edits))
407
+ return all_results
408
+
409
+ def handle_batch(self, full_batch, merge_punc=True):
410
+ """
411
+ Handle batch of requests.
412
+ """
413
+ if self.split_chunk:
414
+ full_batch, indices = self.split_chunks(full_batch)
415
+ else:
416
+ indices = None
417
+ final_batch = full_batch[:]
418
+ batch_size = len(full_batch)
419
+ prev_preds_dict = {i: [final_batch[i]] for i in range(len(final_batch))}
420
+ short_ids = [i for i in range(len(full_batch)) if len(full_batch[i]) < self.min_len]
421
+ pred_ids = [i for i in range(len(full_batch)) if i not in short_ids]
422
+ total_updates = 0
423
+
424
+ for n_iter in range(self.iterations):
425
+ orig_batch = [final_batch[i] for i in pred_ids]
426
+
427
+ sequences = self.preprocess(orig_batch)
428
+
429
+ if not sequences:
430
+ break
431
+ probabilities, idxs, error_probs = self.predict(sequences)
432
+
433
+ pred_batch = self.postprocess_batch(orig_batch, probabilities, idxs, error_probs)
434
+ if self.log:
435
+ print(f"Iteration {n_iter + 1}. Predicted {round(100*len(pred_ids)/batch_size, 1)}% of sentences.")
436
+
437
+ final_batch, pred_ids, cnt = self.update_final_batch(final_batch, pred_ids, pred_batch, prev_preds_dict)
438
+ total_updates += cnt
439
+
440
+ if not pred_ids:
441
+ break
442
+ if self.split_chunk:
443
+ final_batch = [self.merge_chunks(final_batch[start:end]) for (start, end) in indices]
444
+ else:
445
+ final_batch = [" ".join(x) for x in final_batch]
446
+ if merge_punc:
447
+ final_batch = [re.sub(r'\s+(%s)' % self.punc_str, r'\1', x) for x in final_batch]
448
+
449
+ return final_batch
modeling_seq2labels.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, List, Optional, Tuple, Union
2
+ from torch import nn
3
+ from torch.nn import CrossEntropyLoss
4
+ from transformers import AutoConfig, AutoModel, BertPreTrainedModel
5
+ from transformers.modeling_outputs import ModelOutput
6
+ import sys
7
+ import torch
8
+ current_dir = sys.path[0].replace('\\','/')
9
+
10
+ def get_range_vector(size: int, device: int) -> torch.Tensor:
11
+ """
12
+ Returns a range vector with the desired size, starting at 0. The CUDA implementation
13
+ is meant to avoid copy data from CPU to GPU.
14
+ """
15
+ return torch.arange(0, size, dtype=torch.long, device=device)
16
+
17
+
18
+ class Seq2LabelsOutput(ModelOutput):
19
+ loss: Optional[torch.FloatTensor] = None
20
+ logits: torch.FloatTensor = None
21
+ detect_logits: torch.FloatTensor = None
22
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
23
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
24
+ max_error_probability: Optional[torch.FloatTensor] = None
25
+
26
+
27
+ class Seq2LabelsModel(BertPreTrainedModel):
28
+
29
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
30
+
31
+ def __init__(self, config):
32
+ super().__init__(config)
33
+ self.num_labels = config.num_labels
34
+ self.num_detect_classes = config.num_detect_classes
35
+ self.label_smoothing = config.label_smoothing
36
+
37
+ if config.load_pretrained:
38
+ self.bert = AutoModel.from_pretrained(current_dir + "/" + config.pretrained_name_or_path)
39
+ bert_config = self.bert.config
40
+ else:
41
+ print(current_dir + "/" + config.pretrained_name_or_path)
42
+ bert_config = AutoConfig.from_pretrained(current_dir + "/" + config.pretrained_name_or_path)
43
+ self.bert = AutoModel.from_config(bert_config)
44
+
45
+ if config.special_tokens_fix:
46
+ try:
47
+ vocab_size = self.bert.embeddings.word_embeddings.num_embeddings
48
+ except AttributeError:
49
+ # reserve more space
50
+ vocab_size = self.bert.word_embedding.num_embeddings + 5
51
+ self.bert.resize_token_embeddings(vocab_size + 1)
52
+
53
+ predictor_dropout = config.predictor_dropout if config.predictor_dropout is not None else 0.0
54
+ self.dropout = nn.Dropout(predictor_dropout)
55
+ self.classifier = nn.Linear(bert_config.hidden_size, config.vocab_size)
56
+ self.detector = nn.Linear(bert_config.hidden_size, config.num_detect_classes)
57
+
58
+ # Initialize weights and apply final processing
59
+ self.post_init()
60
+
61
+ def forward(
62
+ self,
63
+ input_ids: Optional[torch.Tensor] = None,
64
+ input_offsets: Optional[torch.Tensor] = None,
65
+ attention_mask: Optional[torch.Tensor] = None,
66
+ token_type_ids: Optional[torch.Tensor] = None,
67
+ position_ids: Optional[torch.Tensor] = None,
68
+ head_mask: Optional[torch.Tensor] = None,
69
+ inputs_embeds: Optional[torch.Tensor] = None,
70
+ labels: Optional[torch.Tensor] = None,
71
+ d_tags: Optional[torch.Tensor] = None,
72
+ output_attentions: Optional[bool] = None,
73
+ output_hidden_states: Optional[bool] = None,
74
+ return_dict: Optional[bool] = None,
75
+ ) -> Union[Tuple[torch.Tensor], Seq2LabelsOutput]:
76
+ r"""
77
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
78
+ Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
79
+ """
80
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
81
+
82
+ outputs = self.bert(
83
+ input_ids,
84
+ attention_mask=attention_mask,
85
+ token_type_ids=token_type_ids,
86
+ position_ids=position_ids,
87
+ head_mask=head_mask,
88
+ inputs_embeds=inputs_embeds,
89
+ output_attentions=output_attentions,
90
+ output_hidden_states=output_hidden_states,
91
+ return_dict=return_dict,
92
+ )
93
+
94
+ sequence_output = outputs[0]
95
+
96
+ if input_offsets is not None:
97
+ # offsets is (batch_size, d1, ..., dn, orig_sequence_length)
98
+ range_vector = get_range_vector(input_offsets.size(0), device=sequence_output.device).unsqueeze(1)
99
+ # selected embeddings is also (batch_size * d1 * ... * dn, orig_sequence_length)
100
+ sequence_output = sequence_output[range_vector, input_offsets]
101
+
102
+ logits = self.classifier(self.dropout(sequence_output))
103
+ logits_d = self.detector(sequence_output)
104
+
105
+ loss = None
106
+ if labels is not None and d_tags is not None:
107
+ loss_labels_fct = CrossEntropyLoss(label_smoothing=self.label_smoothing)
108
+ loss_d_fct = CrossEntropyLoss()
109
+ loss_labels = loss_labels_fct(logits.view(-1, self.num_labels), labels.view(-1))
110
+ loss_d = loss_d_fct(logits_d.view(-1, self.num_detect_classes), d_tags.view(-1))
111
+ loss = loss_labels + loss_d
112
+
113
+ if not return_dict:
114
+ output = (logits, logits_d) + outputs[2:]
115
+ return ((loss,) + output) if loss is not None else output
116
+
117
+ return Seq2LabelsOutput(
118
+ loss=loss,
119
+ logits=logits,
120
+ detect_logits=logits_d,
121
+ hidden_states=outputs.hidden_states,
122
+ attentions=outputs.attentions,
123
+ max_error_probability=torch.ones(logits.size(0), device=logits.device),
124
+ )
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a6e2a5c2b1cbf16a9fd0b88c0dc8585f3832a60d10eea8140854f8d8f32c188d
3
+ size 1112304873
utils_gec.py ADDED
@@ -0,0 +1,233 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from pathlib import Path
3
+ import re
4
+
5
+
6
+ VOCAB_DIR = Path(__file__).resolve().parent
7
+ PAD = "@@PADDING@@"
8
+ UNK = "@@UNKNOWN@@"
9
+ START_TOKEN = "$START"
10
+ SEQ_DELIMETERS = {"tokens": " ", "labels": "SEPL|||SEPR", "operations": "SEPL__SEPR"}
11
+
12
+
13
+ def get_verb_form_dicts():
14
+ path_to_dict = os.path.join(VOCAB_DIR, "verb-form-vocab.txt")
15
+ encode, decode = {}, {}
16
+ with open(path_to_dict, encoding="utf-8") as f:
17
+ for line in f:
18
+ words, tags = line.split(":")
19
+ word1, word2 = words.split("_")
20
+ tag1, tag2 = tags.split("_")
21
+ decode_key = f"{word1}_{tag1}_{tag2.strip()}"
22
+ if decode_key not in decode:
23
+ encode[words] = tags
24
+ decode[decode_key] = word2
25
+ return encode, decode
26
+
27
+
28
+ ENCODE_VERB_DICT, DECODE_VERB_DICT = get_verb_form_dicts()
29
+
30
+
31
+ def get_target_sent_by_edits(source_tokens, edits):
32
+ target_tokens = source_tokens[:]
33
+ shift_idx = 0
34
+ for edit in edits:
35
+ start, end, label, _ = edit
36
+ target_pos = start + shift_idx
37
+ if start < 0:
38
+ continue
39
+ elif len(target_tokens) > target_pos:
40
+ source_token = target_tokens[target_pos]
41
+ else:
42
+ source_token = ""
43
+ if label == "":
44
+ del target_tokens[target_pos]
45
+ shift_idx -= 1
46
+ elif start == end:
47
+ word = label.replace("$APPEND_", "")
48
+ # Avoid appending same token twice
49
+ if (target_pos < len(target_tokens) and target_tokens[target_pos] == word) or (
50
+ target_pos > 0 and target_tokens[target_pos - 1] == word
51
+ ):
52
+ continue
53
+ target_tokens[target_pos:target_pos] = [word]
54
+ shift_idx += 1
55
+ elif label.startswith("$TRANSFORM_"):
56
+ word = apply_reverse_transformation(source_token, label)
57
+ if word is None:
58
+ word = source_token
59
+ target_tokens[target_pos] = word
60
+ elif start == end - 1:
61
+ word = label.replace("$REPLACE_", "")
62
+ target_tokens[target_pos] = word
63
+ elif label.startswith("$MERGE_"):
64
+ target_tokens[target_pos + 1 : target_pos + 1] = [label]
65
+ shift_idx += 1
66
+
67
+ return replace_merge_transforms(target_tokens)
68
+
69
+
70
+ def replace_merge_transforms(tokens):
71
+ if all(not x.startswith("$MERGE_") for x in tokens):
72
+ return tokens
73
+ if tokens[0].startswith("$MERGE_"):
74
+ tokens = tokens[1:]
75
+ if tokens[-1].startswith("$MERGE_"):
76
+ tokens = tokens[:-1]
77
+
78
+ target_line = " ".join(tokens)
79
+ target_line = target_line.replace(" $MERGE_HYPHEN ", "-")
80
+ target_line = target_line.replace(" $MERGE_SPACE ", "")
81
+ target_line = re.sub(r'([\.\,\?\:]\s+)+', r'\1', target_line)
82
+ return target_line.split()
83
+
84
+
85
+ def convert_using_case(token, smart_action):
86
+ if not smart_action.startswith("$TRANSFORM_CASE_"):
87
+ return token
88
+ if smart_action.endswith("LOWER"):
89
+ return token.lower()
90
+ elif smart_action.endswith("UPPER"):
91
+ return token.upper()
92
+ elif smart_action.endswith("CAPITAL"):
93
+ return token.capitalize()
94
+ elif smart_action.endswith("CAPITAL_1"):
95
+ return token[0] + token[1:].capitalize()
96
+ elif smart_action.endswith("UPPER_-1"):
97
+ return token[:-1].upper() + token[-1]
98
+ else:
99
+ return token
100
+
101
+
102
+ def convert_using_verb(token, smart_action):
103
+ key_word = "$TRANSFORM_VERB_"
104
+ if not smart_action.startswith(key_word):
105
+ raise Exception(f"Unknown action type {smart_action}")
106
+ encoding_part = f"{token}_{smart_action[len(key_word):]}"
107
+ decoded_target_word = decode_verb_form(encoding_part)
108
+ return decoded_target_word
109
+
110
+
111
+ def convert_using_split(token, smart_action):
112
+ key_word = "$TRANSFORM_SPLIT"
113
+ if not smart_action.startswith(key_word):
114
+ raise Exception(f"Unknown action type {smart_action}")
115
+ target_words = token.split("-")
116
+ return " ".join(target_words)
117
+
118
+
119
+ def convert_using_plural(token, smart_action):
120
+ if smart_action.endswith("PLURAL"):
121
+ return token + "s"
122
+ elif smart_action.endswith("SINGULAR"):
123
+ return token[:-1]
124
+ else:
125
+ raise Exception(f"Unknown action type {smart_action}")
126
+
127
+
128
+ def apply_reverse_transformation(source_token, transform):
129
+ if transform.startswith("$TRANSFORM"):
130
+ # deal with equal
131
+ if transform == "$KEEP":
132
+ return source_token
133
+ # deal with case
134
+ if transform.startswith("$TRANSFORM_CASE"):
135
+ return convert_using_case(source_token, transform)
136
+ # deal with verb
137
+ if transform.startswith("$TRANSFORM_VERB"):
138
+ return convert_using_verb(source_token, transform)
139
+ # deal with split
140
+ if transform.startswith("$TRANSFORM_SPLIT"):
141
+ return convert_using_split(source_token, transform)
142
+ # deal with single/plural
143
+ if transform.startswith("$TRANSFORM_AGREEMENT"):
144
+ return convert_using_plural(source_token, transform)
145
+ # raise exception if not find correct type
146
+ raise Exception(f"Unknown action type {transform}")
147
+ else:
148
+ return source_token
149
+
150
+
151
+ # def read_parallel_lines(fn1, fn2):
152
+ # lines1 = read_lines(fn1, skip_strip=True)
153
+ # lines2 = read_lines(fn2, skip_strip=True)
154
+ # assert len(lines1) == len(lines2)
155
+ # out_lines1, out_lines2 = [], []
156
+ # for line1, line2 in zip(lines1, lines2):
157
+ # if not line1.strip() or not line2.strip():
158
+ # continue
159
+ # else:
160
+ # out_lines1.append(line1)
161
+ # out_lines2.append(line2)
162
+ # return out_lines1, out_lines2
163
+
164
+
165
+ def read_parallel_lines(fn1, fn2):
166
+ with open(fn1, encoding='utf-8') as f1, open(fn2, encoding='utf-8') as f2:
167
+ for line1, line2 in zip(f1, f2):
168
+ line1 = line1.strip()
169
+ line2 = line2.strip()
170
+
171
+ yield line1, line2
172
+
173
+
174
+ def read_lines(fn, skip_strip=False):
175
+ if not os.path.exists(fn):
176
+ return []
177
+ with open(fn, 'r', encoding='utf-8') as f:
178
+ lines = f.readlines()
179
+ return [s.strip() for s in lines if s.strip() or skip_strip]
180
+
181
+
182
+ def write_lines(fn, lines, mode='w'):
183
+ if mode == 'w' and os.path.exists(fn):
184
+ os.remove(fn)
185
+ with open(fn, encoding='utf-8', mode=mode) as f:
186
+ f.writelines(['%s\n' % s for s in lines])
187
+
188
+
189
+ def decode_verb_form(original):
190
+ return DECODE_VERB_DICT.get(original)
191
+
192
+
193
+ def encode_verb_form(original_word, corrected_word):
194
+ decoding_request = original_word + "_" + corrected_word
195
+ decoding_response = ENCODE_VERB_DICT.get(decoding_request, "").strip()
196
+ if original_word and decoding_response:
197
+ answer = decoding_response
198
+ else:
199
+ answer = None
200
+ return answer
201
+
202
+
203
+ def get_weights_name(transformer_name, lowercase):
204
+ if transformer_name == 'bert' and lowercase:
205
+ return 'bert-base-uncased'
206
+ if transformer_name == 'bert' and not lowercase:
207
+ return 'bert-base-cased'
208
+ if transformer_name == 'bert-large' and not lowercase:
209
+ return 'bert-large-cased'
210
+ if transformer_name == 'distilbert':
211
+ if not lowercase:
212
+ print('Warning! This model was trained only on uncased sentences.')
213
+ return 'distilbert-base-uncased'
214
+ if transformer_name == 'albert':
215
+ if not lowercase:
216
+ print('Warning! This model was trained only on uncased sentences.')
217
+ return 'albert-base-v1'
218
+ if lowercase:
219
+ print('Warning! This model was trained only on cased sentences.')
220
+ if transformer_name == 'roberta':
221
+ return 'roberta-base'
222
+ if transformer_name == 'roberta-large':
223
+ return 'roberta-large'
224
+ if transformer_name == 'gpt2':
225
+ return 'gpt2'
226
+ if transformer_name == 'transformerxl':
227
+ return 'transfo-xl-wt103'
228
+ if transformer_name == 'xlnet':
229
+ return 'xlnet-base-cased'
230
+ if transformer_name == 'xlnet-large':
231
+ return 'xlnet-large-cased'
232
+
233
+ return transformer_name
verb-form-vocab.txt ADDED
The diff for this file is too large to render. See raw diff
 
vocabulary.py ADDED
@@ -0,0 +1,277 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import codecs
2
+ from collections import defaultdict
3
+ import logging
4
+ import os
5
+ import re
6
+ from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Union, TYPE_CHECKING
7
+ from filelock import FileLock
8
+
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+ DEFAULT_NON_PADDED_NAMESPACES = ("*tags", "*labels")
13
+ DEFAULT_PADDING_TOKEN = "@@PADDING@@"
14
+ DEFAULT_OOV_TOKEN = "@@UNKNOWN@@"
15
+ NAMESPACE_PADDING_FILE = "non_padded_namespaces.txt"
16
+ _NEW_LINE_REGEX = re.compile(r"\n|\r\n")
17
+
18
+
19
+ def namespace_match(pattern: str, namespace: str):
20
+ """
21
+ Matches a namespace pattern against a namespace string. For example, `*tags` matches
22
+ `passage_tags` and `question_tags` and `tokens` matches `tokens` but not
23
+ `stemmed_tokens`.
24
+ """
25
+ if pattern[0] == "*" and namespace.endswith(pattern[1:]):
26
+ return True
27
+ elif pattern == namespace:
28
+ return True
29
+ return False
30
+
31
+
32
+ class _NamespaceDependentDefaultDict(defaultdict):
33
+ """
34
+ This is a [defaultdict]
35
+ (https://docs.python.org/2/library/collections.html#collections.defaultdict) where the
36
+ default value is dependent on the key that is passed.
37
+ We use "namespaces" in the :class:`Vocabulary` object to keep track of several different
38
+ mappings from strings to integers, so that we have a consistent API for mapping words, tags,
39
+ labels, characters, or whatever else you want, into integers. The issue is that some of those
40
+ namespaces (words and characters) should have integers reserved for padding and
41
+ out-of-vocabulary tokens, while others (labels and tags) shouldn't. This class allows you to
42
+ specify filters on the namespace (the key used in the `defaultdict`), and use different
43
+ default values depending on whether the namespace passes the filter.
44
+ To do filtering, we take a set of `non_padded_namespaces`. This is a set of strings
45
+ that are either matched exactly against the keys, or treated as suffixes, if the
46
+ string starts with `*`. In other words, if `*tags` is in `non_padded_namespaces` then
47
+ `passage_tags`, `question_tags`, etc. (anything that ends with `tags`) will have the
48
+ `non_padded` default value.
49
+ # Parameters
50
+ non_padded_namespaces : `Iterable[str]`
51
+ A set / list / tuple of strings describing which namespaces are not padded. If a namespace
52
+ (key) is missing from this dictionary, we will use :func:`namespace_match` to see whether
53
+ the namespace should be padded. If the given namespace matches any of the strings in this
54
+ list, we will use `non_padded_function` to initialize the value for that namespace, and
55
+ we will use `padded_function` otherwise.
56
+ padded_function : `Callable[[], Any]`
57
+ A zero-argument function to call to initialize a value for a namespace that `should` be
58
+ padded.
59
+ non_padded_function : `Callable[[], Any]`
60
+ A zero-argument function to call to initialize a value for a namespace that should `not` be
61
+ padded.
62
+ """
63
+
64
+ def __init__(
65
+ self,
66
+ non_padded_namespaces: Iterable[str],
67
+ padded_function: Callable[[], Any],
68
+ non_padded_function: Callable[[], Any],
69
+ ) -> None:
70
+ self._non_padded_namespaces = set(non_padded_namespaces)
71
+ self._padded_function = padded_function
72
+ self._non_padded_function = non_padded_function
73
+ super().__init__()
74
+
75
+ def add_non_padded_namespaces(self, non_padded_namespaces: Set[str]):
76
+ # add non_padded_namespaces which weren't already present
77
+ self._non_padded_namespaces.update(non_padded_namespaces)
78
+
79
+
80
+ class _TokenToIndexDefaultDict(_NamespaceDependentDefaultDict):
81
+ def __init__(self, non_padded_namespaces: Set[str], padding_token: str, oov_token: str) -> None:
82
+ super().__init__(non_padded_namespaces, lambda: {padding_token: 0, oov_token: 1}, lambda: {})
83
+
84
+
85
+ class _IndexToTokenDefaultDict(_NamespaceDependentDefaultDict):
86
+ def __init__(self, non_padded_namespaces: Set[str], padding_token: str, oov_token: str) -> None:
87
+ super().__init__(non_padded_namespaces, lambda: {0: padding_token, 1: oov_token}, lambda: {})
88
+
89
+
90
+ class Vocabulary:
91
+ def __init__(
92
+ self,
93
+ counter: Dict[str, Dict[str, int]] = None,
94
+ min_count: Dict[str, int] = None,
95
+ max_vocab_size: Union[int, Dict[str, int]] = None,
96
+ non_padded_namespaces: Iterable[str] = DEFAULT_NON_PADDED_NAMESPACES,
97
+ pretrained_files: Optional[Dict[str, str]] = None,
98
+ only_include_pretrained_words: bool = False,
99
+ tokens_to_add: Dict[str, List[str]] = None,
100
+ min_pretrained_embeddings: Dict[str, int] = None,
101
+ padding_token: Optional[str] = DEFAULT_PADDING_TOKEN,
102
+ oov_token: Optional[str] = DEFAULT_OOV_TOKEN,
103
+ ) -> None:
104
+ self._padding_token = padding_token if padding_token is not None else DEFAULT_PADDING_TOKEN
105
+ self._oov_token = oov_token if oov_token is not None else DEFAULT_OOV_TOKEN
106
+
107
+ self._non_padded_namespaces = set(non_padded_namespaces)
108
+
109
+ self._token_to_index = _TokenToIndexDefaultDict(
110
+ self._non_padded_namespaces, self._padding_token, self._oov_token
111
+ )
112
+ self._index_to_token = _IndexToTokenDefaultDict(
113
+ self._non_padded_namespaces, self._padding_token, self._oov_token
114
+ )
115
+
116
+ @classmethod
117
+ def from_files(
118
+ cls,
119
+ directory: Union[str, os.PathLike],
120
+ padding_token: Optional[str] = DEFAULT_PADDING_TOKEN,
121
+ oov_token: Optional[str] = DEFAULT_OOV_TOKEN,
122
+ ) -> "Vocabulary":
123
+ """
124
+ Loads a `Vocabulary` that was serialized either using `save_to_files` or inside
125
+ a model archive file.
126
+ # Parameters
127
+ directory : `str`
128
+ The directory or archive file containing the serialized vocabulary.
129
+ """
130
+ logger.info("Loading token dictionary from %s.", directory)
131
+ padding_token = padding_token if padding_token is not None else DEFAULT_PADDING_TOKEN
132
+ oov_token = oov_token if oov_token is not None else DEFAULT_OOV_TOKEN
133
+
134
+ if not os.path.isdir(directory):
135
+ raise ValueError(f"{directory} not exist")
136
+
137
+ # We use a lock file to avoid race conditions where multiple processes
138
+ # might be reading/writing from/to the same vocab files at once.
139
+ with FileLock(os.path.join(directory, ".lock")):
140
+ with codecs.open(os.path.join(directory, NAMESPACE_PADDING_FILE), "r", "utf-8") as namespace_file:
141
+ non_padded_namespaces = [namespace_str.strip() for namespace_str in namespace_file]
142
+
143
+ vocab = cls(
144
+ non_padded_namespaces=non_padded_namespaces,
145
+ padding_token=padding_token,
146
+ oov_token=oov_token,
147
+ )
148
+
149
+ # Check every file in the directory.
150
+ for namespace_filename in os.listdir(directory):
151
+ if namespace_filename == NAMESPACE_PADDING_FILE:
152
+ continue
153
+ if namespace_filename.startswith("."):
154
+ continue
155
+ namespace = namespace_filename.replace(".txt", "")
156
+ if any(namespace_match(pattern, namespace) for pattern in non_padded_namespaces):
157
+ is_padded = False
158
+ else:
159
+ is_padded = True
160
+ filename = os.path.join(directory, namespace_filename)
161
+ vocab.set_from_file(filename, is_padded, namespace=namespace, oov_token=oov_token)
162
+
163
+ return vocab
164
+
165
+ @classmethod
166
+ def empty(cls) -> "Vocabulary":
167
+ """
168
+ This method returns a bare vocabulary instantiated with `cls()` (so, `Vocabulary()` if you
169
+ haven't made a subclass of this object). The only reason to call `Vocabulary.empty()`
170
+ instead of `Vocabulary()` is if you are instantiating this object from a config file. We
171
+ register this constructor with the key "empty", so if you know that you don't need to
172
+ compute a vocabulary (either because you're loading a pre-trained model from an archive
173
+ file, you're using a pre-trained transformer that has its own vocabulary, or something
174
+ else), you can use this to avoid having the default vocabulary construction code iterate
175
+ through the data.
176
+ """
177
+ return cls()
178
+
179
+ def set_from_file(
180
+ self,
181
+ filename: str,
182
+ is_padded: bool = True,
183
+ oov_token: str = DEFAULT_OOV_TOKEN,
184
+ namespace: str = "tokens",
185
+ ):
186
+ """
187
+ If you already have a vocabulary file for a trained model somewhere, and you really want to
188
+ use that vocabulary file instead of just setting the vocabulary from a dataset, for
189
+ whatever reason, you can do that with this method. You must specify the namespace to use,
190
+ and we assume that you want to use padding and OOV tokens for this.
191
+ # Parameters
192
+ filename : `str`
193
+ The file containing the vocabulary to load. It should be formatted as one token per
194
+ line, with nothing else in the line. The index we assign to the token is the line
195
+ number in the file (1-indexed if `is_padded`, 0-indexed otherwise). Note that this
196
+ file should contain the OOV token string!
197
+ is_padded : `bool`, optional (default=`True`)
198
+ Is this vocabulary padded? For token / word / character vocabularies, this should be
199
+ `True`; while for tag or label vocabularies, this should typically be `False`. If
200
+ `True`, we add a padding token with index 0, and we enforce that the `oov_token` is
201
+ present in the file.
202
+ oov_token : `str`, optional (default=`DEFAULT_OOV_TOKEN`)
203
+ What token does this vocabulary use to represent out-of-vocabulary characters? This
204
+ must show up as a line in the vocabulary file. When we find it, we replace
205
+ `oov_token` with `self._oov_token`, because we only use one OOV token across
206
+ namespaces.
207
+ namespace : `str`, optional (default=`"tokens"`)
208
+ What namespace should we overwrite with this vocab file?
209
+ """
210
+ if is_padded:
211
+ self._token_to_index[namespace] = {self._padding_token: 0}
212
+ self._index_to_token[namespace] = {0: self._padding_token}
213
+ else:
214
+ self._token_to_index[namespace] = {}
215
+ self._index_to_token[namespace] = {}
216
+ with codecs.open(filename, "r", "utf-8") as input_file:
217
+ lines = _NEW_LINE_REGEX.split(input_file.read())
218
+ # Be flexible about having final newline or not
219
+ if lines and lines[-1] == "":
220
+ lines = lines[:-1]
221
+ for i, line in enumerate(lines):
222
+ index = i + 1 if is_padded else i
223
+ token = line.replace("@@NEWLINE@@", "\n")
224
+ if token == oov_token:
225
+ token = self._oov_token
226
+ self._token_to_index[namespace][token] = index
227
+ self._index_to_token[namespace][index] = token
228
+ if is_padded:
229
+ assert self._oov_token in self._token_to_index[namespace], "OOV token not found!"
230
+
231
+ def add_token_to_namespace(self, token: str, namespace: str = "tokens") -> int:
232
+ """
233
+ Adds `token` to the index, if it is not already present. Either way, we return the index of
234
+ the token.
235
+ """
236
+ if not isinstance(token, str):
237
+ raise ValueError(
238
+ "Vocabulary tokens must be strings, or saving and loading will break."
239
+ " Got %s (with type %s)" % (repr(token), type(token))
240
+ )
241
+ if token not in self._token_to_index[namespace]:
242
+ index = len(self._token_to_index[namespace])
243
+ self._token_to_index[namespace][token] = index
244
+ self._index_to_token[namespace][index] = token
245
+ return index
246
+ else:
247
+ return self._token_to_index[namespace][token]
248
+
249
+ def add_tokens_to_namespace(self, tokens: List[str], namespace: str = "tokens") -> List[int]:
250
+ """
251
+ Adds `tokens` to the index, if they are not already present. Either way, we return the
252
+ indices of the tokens in the order that they were given.
253
+ """
254
+ return [self.add_token_to_namespace(token, namespace) for token in tokens]
255
+
256
+ def get_token_index(self, token: str, namespace: str = "tokens") -> int:
257
+ try:
258
+ return self._token_to_index[namespace][token]
259
+ except KeyError:
260
+ try:
261
+ return self._token_to_index[namespace][self._oov_token]
262
+ except KeyError:
263
+ logger.error("Namespace: %s", namespace)
264
+ logger.error("Token: %s", token)
265
+ raise KeyError(
266
+ f"'{token}' not found in vocab namespace '{namespace}', and namespace "
267
+ f"does not contain the default OOV token ('{self._oov_token}')"
268
+ )
269
+
270
+ def get_token_from_index(self, index: int, namespace: str = "tokens") -> str:
271
+ return self._index_to_token[namespace][index]
272
+
273
+ def get_vocab_size(self, namespace: str = "tokens") -> int:
274
+ return len(self._token_to_index[namespace])
275
+
276
+ def get_namespaces(self) -> Set[str]:
277
+ return set(self._index_to_token.keys())
vocabulary/d_tags.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ CORRECT
2
+ INCORRECT
3
+ @@UNKNOWN@@
4
+ @@PADDING@@
vocabulary/labels.txt ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ $KEEP
2
+ $TRANSFORM_CASE_CAPITAL
3
+ $APPEND_,
4
+ $APPEND_.
5
+ $TRANSFORM_VERB_VB_VBN
6
+ $TRANSFORM_CASE_UPPER
7
+ $APPEND_:
8
+ $APPEND_?
9
+ $TRANSFORM_VERB_VB_VBC
10
+ $TRANSFORM_CASE_LOWER
11
+ $TRANSFORM_CASE_CAPITAL_1
12
+ $TRANSFORM_CASE_UPPER_-1
13
+ $MERGE_SPACE
14
+ @@UNKNOWN@@
15
+ @@PADDING@@
vocabulary/non_padded_namespaces.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ *tags
2
+ *labels
xlm-roberta-base/config.json ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "XLMRobertaForMaskedLM"
4
+ ],
5
+ "attention_probs_dropout_prob": 0.1,
6
+ "bos_token_id": 0,
7
+ "eos_token_id": 2,
8
+ "hidden_act": "gelu",
9
+ "hidden_dropout_prob": 0.1,
10
+ "hidden_size": 768,
11
+ "initializer_range": 0.02,
12
+ "intermediate_size": 3072,
13
+ "layer_norm_eps": 1e-05,
14
+ "max_position_embeddings": 514,
15
+ "model_type": "xlm-roberta",
16
+ "num_attention_heads": 12,
17
+ "num_hidden_layers": 12,
18
+ "output_past": true,
19
+ "pad_token_id": 1,
20
+ "position_embedding_type": "absolute",
21
+ "transformers_version": "4.17.0.dev0",
22
+ "type_vocab_size": 1,
23
+ "use_cache": true,
24
+ "vocab_size": 250002
25
+ }
xlm-roberta-base/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff