dragonSwing commited on
Commit
dd9b3ed
1 Parent(s): 1270e99

Initialize commit

Browse files
app.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from gec_model import GecBERTModel
3
+
4
+ title = "Vietnamese Capitalization and Punctuation recovering"
5
+ description = "Auto capitalize and punctuations for plain, lowercase Vietnamese text."
6
+ article = "Coming soon"
7
+
8
+ examples = [
9
+ ["viBERT", "theo đó thủ tướng dự kiến tiếp bộ trưởng nông nghiệp mỹ tom wilsack bộ trưởng thương mại mỹ gina raimondo bộ trưởng tài chính janet yellen gặp gỡ thượng nghị sĩ patrick leahy và một số nghị sĩ mỹ khác"],
10
+ ["viBERT", "những gói cước năm g mobifone sẽ mang đến cho bạn những trải nghiệm mới lạ trên cả tuyệt vời so với mạng bốn g thì tốc độ truy cập mạng năm g mobifone được nhận định là siêu đỉnh với mức truy cập nhanh gấp 10 lần"],
11
+ ]
12
+
13
+ model_dict = {
14
+ "viBERT": GecBERTModel(
15
+ vocab_path="vocabulary",
16
+ model_paths='dragonSwing/vibert-capu',
17
+ max_len=64,
18
+ min_len=3,
19
+ iterations=3,
20
+ min_error_probability=0.2,
21
+ lowercase_tokens=False,
22
+ log=False,
23
+ confidence=0,
24
+ split_chunk=True,
25
+ chunk_size=48,
26
+ overlap_size=16,
27
+ min_words_cut=8,
28
+ ),
29
+ }
30
+
31
+ def fn(model_choice, input):
32
+ if model_choice in model_dict:
33
+ return model_dict[model_choice](input)[0]
34
+ else:
35
+ return "Unsupported model"
36
+
37
+ gr.Interface(fn, [gr.inputs.Dropdown(["viBERT"]), gr.inputs.Textbox(lines=5)], "text", examples=examples, title=title, description=description, article=article).launch()
configuration_seq2labels.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+
3
+
4
+ class Seq2LabelsConfig(PretrainedConfig):
5
+ r"""
6
+ This is the configuration class to store the configuration of a [`Seq2LabelsModel`]. It is used to
7
+ instantiate a Seq2Labels model according to the specified arguments, defining the model architecture. Instantiating a
8
+ configuration with the defaults will yield a similar configuration to that of the Seq2Labels architecture.
9
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
10
+ documentation from [`PretrainedConfig`] for more information.
11
+ Args:
12
+ vocab_size (`int`, *optional*, defaults to 30522):
13
+ Vocabulary size of the BERT model. Defines the number of different tokens that can be represented by the
14
+ `inputs_ids` passed when calling [`BertModel`] or [`TFBertModel`].
15
+ pretrained_name_or_path (`str`, *optional*, defaults to `bert-base-cased`):
16
+ Pretrained BERT-like model path
17
+ load_pretrained (`bool`, *optional*, defaults to `False`):
18
+ Whether to load pretrained model from `pretrained_name_or_path`
19
+ use_cache (`bool`, *optional*, defaults to `True`):
20
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
21
+ relevant if `config.is_decoder=True`.
22
+ predictor_dropout (`float`, *optional*):
23
+ The dropout ratio for the classification head.
24
+ special_tokens_fix (`bool`, *optional*, defaults to `False`):
25
+ Whether to add additional tokens to the BERT's embedding layer.
26
+ Examples:
27
+ ```python
28
+ >>> from transformers import BertModel, BertConfig
29
+ >>> # Initializing a Seq2Labels style configuration
30
+ >>> configuration = Seq2LabelsConfig()
31
+ >>> # Initializing a model from the bert-base-uncased style configuration
32
+ >>> model = Seq2LabelsModel(configuration)
33
+ >>> # Accessing the model configuration
34
+ >>> configuration = model.config
35
+ ```"""
36
+ model_type = "bert"
37
+
38
+ def __init__(
39
+ self,
40
+ pretrained_name_or_path="bert-base-cased",
41
+ vocab_size=15,
42
+ num_detect_classes=4,
43
+ load_pretrained=False,
44
+ initializer_range=0.02,
45
+ pad_token_id=0,
46
+ use_cache=True,
47
+ predictor_dropout=0.0,
48
+ special_tokens_fix=False,
49
+ label_smoothing=0.0,
50
+ **kwargs
51
+ ):
52
+ super().__init__(pad_token_id=pad_token_id, **kwargs)
53
+
54
+ self.vocab_size = vocab_size
55
+ self.num_detect_classes = num_detect_classes
56
+ self.pretrained_name_or_path = pretrained_name_or_path
57
+ self.load_pretrained = load_pretrained
58
+ self.initializer_range = initializer_range
59
+ self.use_cache = use_cache
60
+ self.predictor_dropout = predictor_dropout
61
+ self.special_tokens_fix = special_tokens_fix
62
+ self.label_smoothing = label_smoothing
gec_model.py ADDED
@@ -0,0 +1,443 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Wrapper of Seq2Labels model. Fixes errors based on model predictions"""
2
+ from collections import defaultdict
3
+ from difflib import SequenceMatcher
4
+ import logging
5
+ import re
6
+ from time import time
7
+ from typing import List, Union
8
+ import warnings
9
+
10
+ import torch
11
+ from transformers import AutoTokenizer
12
+ from modeling_seq2labels import Seq2LabelsModel
13
+ from vocabulary import Vocabulary
14
+ from utils import PAD, UNK, START_TOKEN, get_target_sent_by_edits
15
+
16
+ logging.getLogger("werkzeug").setLevel(logging.ERROR)
17
+ logger = logging.getLogger(__file__)
18
+
19
+
20
+ class GecBERTModel(torch.nn.Module):
21
+ def __init__(
22
+ self,
23
+ vocab_path=None,
24
+ model_paths=None,
25
+ weights=None,
26
+ device=None,
27
+ max_len=64,
28
+ min_len=3,
29
+ lowercase_tokens=False,
30
+ log=False,
31
+ iterations=3,
32
+ min_error_probability=0.0,
33
+ confidence=0,
34
+ resolve_cycles=False,
35
+ split_chunk=False,
36
+ chunk_size=48,
37
+ overlap_size=12,
38
+ min_words_cut=6,
39
+ punc_dict={':', ".", ",", "?"},
40
+ ):
41
+ r"""
42
+ Args:
43
+ vocab_path (`str`):
44
+ Path to vocabulary directory.
45
+ model_paths (`List[str]`):
46
+ List of model paths.
47
+ weights (`int`, *Optional*, defaults to None):
48
+ Weights of each model. Only relevant if `is_ensemble is True`.
49
+ device (`int`, *Optional*, defaults to None):
50
+ Device to load model. If not set, device will be automatically choose.
51
+ max_len (`int`, defaults to 64):
52
+ Max sentence length to be processed (all longer will be truncated).
53
+ min_len (`int`, defaults to 3):
54
+ Min sentence length to be processed (all shorted will be returned w/o changes).
55
+ lowercase_tokens (`bool`, defaults to False):
56
+ Whether to lowercase tokens.
57
+ log (`bool`, defaults to False):
58
+ Whether to enable logging.
59
+ iterations (`int`, defaults to 3):
60
+ Max iterations to run during inference.
61
+ special_tokens_fix (`bool`, defaults to True):
62
+ Whether to fix problem with [CLS], [SEP] tokens tokenization.
63
+ min_error_probability (`float`, defaults to `0.0`):
64
+ Minimum probability for each action to apply.
65
+ confidence (`float`, defaults to `0.0`):
66
+ How many probability to add to $KEEP token.
67
+ split_chunk (`bool`, defaults to False):
68
+ Whether to split long sentences to multiple segments of `chunk_size`.
69
+ !Warning: if `chunk_size > max_len`, each segment will be truncate to `max_len`.
70
+ chunk_size (`int`, defaults to 48):
71
+ Length of each segment (in words). Only relevant if `split_chunk is True`.
72
+ overlap_size (`int`, defaults to 12):
73
+ Overlap size (in words) between two consecutive segments. Only relevant if `split_chunk is True`.
74
+ min_words_cut (`int`, defaults to 6):
75
+ Minimun number of words to be cut while merging two consecutive segments.
76
+ Only relevant if `split_chunk is True`.
77
+ punc_dict (List[str], defaults to `{':', ".", ",", "?"}`):
78
+ List of punctuations.
79
+ """
80
+ super().__init__()
81
+ if isinstance(model_paths, str):
82
+ model_paths = [model_paths]
83
+ self.model_weights = list(map(float, weights)) if weights else [1] * len(model_paths)
84
+ self.device = (
85
+ torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else torch.device(device)
86
+ )
87
+ self.max_len = max_len
88
+ self.min_len = min_len
89
+ self.lowercase_tokens = lowercase_tokens
90
+ self.min_error_probability = min_error_probability
91
+ self.vocab = Vocabulary.from_files(vocab_path)
92
+ self.log = log
93
+ self.iterations = iterations
94
+ self.confidence = confidence
95
+ self.resolve_cycles = resolve_cycles
96
+
97
+ assert (
98
+ chunk_size > 0 and chunk_size // 2 >= overlap_size
99
+ ), "Chunk merging required overlap size must be smaller than half of chunk size"
100
+ self.split_chunk = split_chunk
101
+ self.chunk_size = chunk_size
102
+ self.overlap_size = overlap_size
103
+ self.min_words_cut = min_words_cut
104
+ self.stride = chunk_size - overlap_size
105
+ self.punc_dict = punc_dict
106
+ self.punc_str = '[' + ''.join([f'\{x}' for x in punc_dict]) + ']'
107
+ # set training parameters and operations
108
+
109
+ self.indexers = []
110
+ self.models = []
111
+ for model_path in model_paths:
112
+ model = Seq2LabelsModel.from_pretrained(model_path)
113
+ config = model.config
114
+ model_name = config.pretrained_name_or_path
115
+ special_tokens_fix = config.special_tokens_fix
116
+ self.indexers.append(self._get_indexer(model_name, special_tokens_fix))
117
+ model.eval().to(self.device)
118
+ self.models.append(model)
119
+
120
+ def _get_indexer(self, weights_name, special_tokens_fix):
121
+ tokenizer = AutoTokenizer.from_pretrained(
122
+ weights_name, do_basic_tokenize=False, do_lower_case=self.lowercase_tokens, model_max_length=1024
123
+ )
124
+ # to adjust all tokenizers
125
+ if hasattr(tokenizer, 'encoder'):
126
+ tokenizer.vocab = tokenizer.encoder
127
+ if hasattr(tokenizer, 'sp_model'):
128
+ tokenizer.vocab = defaultdict(lambda: 1)
129
+ for i in range(tokenizer.sp_model.get_piece_size()):
130
+ tokenizer.vocab[tokenizer.sp_model.id_to_piece(i)] = i
131
+
132
+ if special_tokens_fix:
133
+ tokenizer.add_tokens([START_TOKEN])
134
+ tokenizer.vocab[START_TOKEN] = len(tokenizer) - 1
135
+ return tokenizer
136
+
137
+ def forward(self, text: Union[str, List[str], List[List[str]]], is_split_into_words=False):
138
+ # Input type checking for clearer error
139
+ def _is_valid_text_input(t):
140
+ if isinstance(t, str):
141
+ # Strings are fine
142
+ return True
143
+ elif isinstance(t, (list, tuple)):
144
+ # List are fine as long as they are...
145
+ if len(t) == 0:
146
+ # ... empty
147
+ return True
148
+ elif isinstance(t[0], str):
149
+ # ... list of strings
150
+ return True
151
+ elif isinstance(t[0], (list, tuple)):
152
+ # ... list with an empty list or with a list of strings
153
+ return len(t[0]) == 0 or isinstance(t[0][0], str)
154
+ else:
155
+ return False
156
+ else:
157
+ return False
158
+
159
+ if not _is_valid_text_input(text):
160
+ raise ValueError(
161
+ "text input must of type `str` (single example), `List[str]` (batch or single pretokenized example) "
162
+ "or `List[List[str]]` (batch of pretokenized examples)."
163
+ )
164
+
165
+ if is_split_into_words:
166
+ is_batched = isinstance(text, (list, tuple)) and text and isinstance(text[0], (list, tuple))
167
+ else:
168
+ is_batched = isinstance(text, (list, tuple))
169
+ if is_batched:
170
+ text = [x.split() for x in text]
171
+ else:
172
+ text = text.split()
173
+
174
+ if not is_batched:
175
+ text = [text]
176
+
177
+ return self.handle_batch(text)
178
+
179
+ def split_chunks(self, batch):
180
+ # return batch pairs of indices
181
+ result = []
182
+ indices = []
183
+ for tokens in batch:
184
+ start = len(result)
185
+ num_token = len(tokens)
186
+ if num_token <= self.chunk_size:
187
+ result.append(tokens)
188
+ elif num_token > self.chunk_size and num_token < (self.chunk_size * 2 - self.overlap_size):
189
+ split_idx = (num_token + self.overlap_size + 1) // 2
190
+ result.append(tokens[:split_idx])
191
+ result.append(tokens[split_idx - self.overlap_size :])
192
+ else:
193
+ for i in range(0, num_token - self.overlap_size, self.stride):
194
+ result.append(tokens[i : i + self.chunk_size])
195
+
196
+ indices.append((start, len(result)))
197
+
198
+ return result, indices
199
+
200
+ def check_alnum(self, s):
201
+ if len(s) < 2:
202
+ return False
203
+ return not (s.isalpha() or s.isdigit())
204
+
205
+ def apply_chunk_merging(self, tokens, next_tokens):
206
+ # Return next tokens if current tokens list is empty
207
+ if not tokens:
208
+ return next_tokens
209
+
210
+ source_token_idx = []
211
+ target_token_idx = []
212
+ source_tokens = []
213
+ target_tokens = []
214
+ num_keep = self.overlap_size - self.min_words_cut
215
+ i = 0
216
+ while len(source_token_idx) < self.overlap_size and -i < len(tokens):
217
+ i -= 1
218
+ if tokens[i] not in self.punc_dict:
219
+ source_token_idx.insert(0, i)
220
+ source_tokens.insert(0, tokens[i].lower())
221
+
222
+ i = 0
223
+ while len(target_token_idx) < self.overlap_size and i < len(next_tokens):
224
+ if next_tokens[i] not in self.punc_dict:
225
+ target_token_idx.append(i)
226
+ target_tokens.append(next_tokens[i].lower())
227
+ i += 1
228
+
229
+ matcher = SequenceMatcher(None, source_tokens, target_tokens)
230
+ diffs = list(matcher.get_opcodes())
231
+
232
+ for diff in diffs:
233
+ tag, i1, i2, j1, j2 = diff
234
+ if tag == "equal":
235
+ if i1 >= num_keep:
236
+ tail_idx = source_token_idx[i1]
237
+ head_idx = target_token_idx[j1]
238
+ break
239
+ elif i2 > num_keep:
240
+ tail_idx = source_token_idx[num_keep]
241
+ head_idx = target_token_idx[j2 - i2 + num_keep]
242
+ break
243
+ elif tag == "delete" and i1 == 0:
244
+ num_keep += i2 // 2
245
+
246
+ tokens = tokens[:tail_idx] + next_tokens[head_idx:]
247
+ return tokens
248
+
249
+ def merge_chunks(self, batch):
250
+ result = []
251
+ if len(batch) == 1 or self.overlap_size == 0:
252
+ for sub_tokens in batch:
253
+ result.extend(sub_tokens)
254
+ else:
255
+ for _, sub_tokens in enumerate(batch):
256
+ try:
257
+ result = self.apply_chunk_merging(result, sub_tokens)
258
+ except Exception as e:
259
+ print(e)
260
+
261
+ result = " ".join(result)
262
+ return result
263
+
264
+ def predict(self, batches):
265
+ t11 = time()
266
+ predictions = []
267
+ for batch, model in zip(batches, self.models):
268
+ batch = batch.to(self.device)
269
+ with torch.no_grad():
270
+ prediction = model.forward(**batch)
271
+ predictions.append(prediction)
272
+
273
+ preds, idx, error_probs = self._convert(predictions)
274
+ t55 = time()
275
+ if self.log:
276
+ print(f"Inference time {t55 - t11}")
277
+ return preds, idx, error_probs
278
+
279
+ def get_token_action(self, token, index, prob, sugg_token):
280
+ """Get lost of suggested actions for token."""
281
+ # cases when we don't need to do anything
282
+ if prob < self.min_error_probability or sugg_token in [UNK, PAD, '$KEEP']:
283
+ return None
284
+
285
+ if sugg_token.startswith('$REPLACE_') or sugg_token.startswith('$TRANSFORM_') or sugg_token == '$DELETE':
286
+ start_pos = index
287
+ end_pos = index + 1
288
+ elif sugg_token.startswith("$APPEND_") or sugg_token.startswith("$MERGE_"):
289
+ start_pos = index + 1
290
+ end_pos = index + 1
291
+
292
+ if sugg_token == "$DELETE":
293
+ sugg_token_clear = ""
294
+ elif sugg_token.startswith('$TRANSFORM_') or sugg_token.startswith("$MERGE_"):
295
+ sugg_token_clear = sugg_token[:]
296
+ else:
297
+ sugg_token_clear = sugg_token[sugg_token.index('_') + 1 :]
298
+
299
+ return start_pos - 1, end_pos - 1, sugg_token_clear, prob
300
+
301
+ def preprocess(self, token_batch):
302
+ seq_lens = [len(sequence) for sequence in token_batch if sequence]
303
+ if not seq_lens:
304
+ return []
305
+ max_len = min(max(seq_lens), self.max_len)
306
+ batches = []
307
+ for indexer in self.indexers:
308
+ token_batch = [[START_TOKEN] + sequence[:max_len] for sequence in token_batch]
309
+ batch = indexer(
310
+ token_batch,
311
+ return_tensors="pt",
312
+ padding=True,
313
+ is_split_into_words=True,
314
+ truncation=True,
315
+ add_special_tokens=False,
316
+ )
317
+ offset_batch = []
318
+ for i in range(len(token_batch)):
319
+ word_ids = batch.word_ids(batch_index=i)
320
+ offsets = [0]
321
+ for i in range(1, len(word_ids)):
322
+ if word_ids[i] != word_ids[i - 1]:
323
+ offsets.append(i)
324
+ offset_batch.append(torch.LongTensor(offsets))
325
+
326
+ batch["input_offsets"] = torch.nn.utils.rnn.pad_sequence(
327
+ offset_batch, batch_first=True, padding_value=0
328
+ ).to(torch.long)
329
+
330
+ batches.append(batch)
331
+
332
+ return batches
333
+
334
+ def _convert(self, data):
335
+ all_class_probs = torch.zeros_like(data[0]['logits'])
336
+ error_probs = torch.zeros_like(data[0]['max_error_probability'])
337
+ for output, weight in zip(data, self.model_weights):
338
+ class_probabilities_labels = torch.softmax(output['logits'], dim=-1)
339
+ all_class_probs += weight * class_probabilities_labels / sum(self.model_weights)
340
+ error_probs += weight * output['max_error_probability'] / sum(self.model_weights)
341
+
342
+ max_vals = torch.max(all_class_probs, dim=-1)
343
+ probs = max_vals[0].tolist()
344
+ idx = max_vals[1].tolist()
345
+ return probs, idx, error_probs.tolist()
346
+
347
+ def update_final_batch(self, final_batch, pred_ids, pred_batch, prev_preds_dict):
348
+ new_pred_ids = []
349
+ total_updated = 0
350
+ for i, orig_id in enumerate(pred_ids):
351
+ orig = final_batch[orig_id]
352
+ pred = pred_batch[i]
353
+ prev_preds = prev_preds_dict[orig_id]
354
+ if orig != pred and pred not in prev_preds:
355
+ final_batch[orig_id] = pred
356
+ new_pred_ids.append(orig_id)
357
+ prev_preds_dict[orig_id].append(pred)
358
+ total_updated += 1
359
+ elif orig != pred and pred in prev_preds:
360
+ # update final batch, but stop iterations
361
+ final_batch[orig_id] = pred
362
+ total_updated += 1
363
+ else:
364
+ continue
365
+ return final_batch, new_pred_ids, total_updated
366
+
367
+ def postprocess_batch(self, batch, all_probabilities, all_idxs, error_probs):
368
+ all_results = []
369
+ noop_index = self.vocab.get_token_index("$KEEP", "labels")
370
+ for tokens, probabilities, idxs, error_prob in zip(batch, all_probabilities, all_idxs, error_probs):
371
+ length = min(len(tokens), self.max_len)
372
+ edits = []
373
+
374
+ # skip whole sentences if there no errors
375
+ if max(idxs) == 0:
376
+ all_results.append(tokens)
377
+ continue
378
+
379
+ # skip whole sentence if probability of correctness is not high
380
+ if error_prob < self.min_error_probability:
381
+ all_results.append(tokens)
382
+ continue
383
+
384
+ for i in range(length + 1):
385
+ # because of START token
386
+ if i == 0:
387
+ token = START_TOKEN
388
+ else:
389
+ token = tokens[i - 1]
390
+ # skip if there is no error
391
+ if idxs[i] == noop_index:
392
+ continue
393
+
394
+ sugg_token = self.vocab.get_token_from_index(idxs[i], namespace='labels')
395
+ action = self.get_token_action(token, i, probabilities[i], sugg_token)
396
+ if not action:
397
+ continue
398
+
399
+ edits.append(action)
400
+ all_results.append(get_target_sent_by_edits(tokens, edits))
401
+ return all_results
402
+
403
+ def handle_batch(self, full_batch, merge_punc=True):
404
+ """
405
+ Handle batch of requests.
406
+ """
407
+ if self.split_chunk:
408
+ full_batch, indices = self.split_chunks(full_batch)
409
+ else:
410
+ indices = None
411
+ final_batch = full_batch[:]
412
+ batch_size = len(full_batch)
413
+ prev_preds_dict = {i: [final_batch[i]] for i in range(len(final_batch))}
414
+ short_ids = [i for i in range(len(full_batch)) if len(full_batch[i]) < self.min_len]
415
+ pred_ids = [i for i in range(len(full_batch)) if i not in short_ids]
416
+ total_updates = 0
417
+
418
+ for n_iter in range(self.iterations):
419
+ orig_batch = [final_batch[i] for i in pred_ids]
420
+
421
+ sequences = self.preprocess(orig_batch)
422
+
423
+ if not sequences:
424
+ break
425
+ probabilities, idxs, error_probs = self.predict(sequences)
426
+
427
+ pred_batch = self.postprocess_batch(orig_batch, probabilities, idxs, error_probs)
428
+ if self.log:
429
+ print(f"Iteration {n_iter + 1}. Predicted {round(100*len(pred_ids)/batch_size, 1)}% of sentences.")
430
+
431
+ final_batch, pred_ids, cnt = self.update_final_batch(final_batch, pred_ids, pred_batch, prev_preds_dict)
432
+ total_updates += cnt
433
+
434
+ if not pred_ids:
435
+ break
436
+ if self.split_chunk:
437
+ final_batch = [self.merge_chunks(final_batch[start:end]) for (start, end) in indices]
438
+ else:
439
+ final_batch = [" ".join(x) for x in final_batch]
440
+ if merge_punc:
441
+ final_batch = [re.sub(r'\s+(%s)' % self.punc_str, r'\1', x) for x in final_batch]
442
+
443
+ return final_batch
modeling_seq2labels.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, List, Optional, Tuple, Union
2
+ from torch import nn
3
+ from torch.nn import CrossEntropyLoss
4
+ from transformers import AutoConfig, AutoModel, BertPreTrainedModel
5
+ from transformers.modeling_outputs import ModelOutput
6
+
7
+ import torch
8
+
9
+
10
+ def get_range_vector(size: int, device: int) -> torch.Tensor:
11
+ """
12
+ Returns a range vector with the desired size, starting at 0. The CUDA implementation
13
+ is meant to avoid copy data from CPU to GPU.
14
+ """
15
+ return torch.arange(0, size, dtype=torch.long, device=device)
16
+
17
+
18
+ class Seq2LabelsOutput(ModelOutput):
19
+ loss: Optional[torch.FloatTensor] = None
20
+ logits: torch.FloatTensor = None
21
+ detect_logits: torch.FloatTensor = None
22
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
23
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
24
+ max_error_probability: Optional[torch.FloatTensor] = None
25
+
26
+
27
+ class Seq2LabelsModel(BertPreTrainedModel):
28
+
29
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
30
+
31
+ def __init__(self, config):
32
+ super().__init__(config)
33
+ self.num_labels = config.num_labels
34
+ self.num_detect_classes = config.num_detect_classes
35
+ self.label_smoothing = config.label_smoothing
36
+
37
+ if config.load_pretrained:
38
+ self.bert = AutoModel.from_pretrained(config.pretrained_name_or_path)
39
+ bert_config = self.bert.config
40
+ else:
41
+ bert_config = AutoConfig.from_pretrained(config.pretrained_name_or_path)
42
+ self.bert = AutoModel.from_config(bert_config)
43
+
44
+ if config.special_tokens_fix:
45
+ try:
46
+ vocab_size = self.bert.embeddings.word_embeddings.num_embeddings
47
+ except AttributeError:
48
+ # reserve more space
49
+ vocab_size = self.bert.word_embedding.num_embeddings + 5
50
+ self.bert.resize_token_embeddings(vocab_size + 1)
51
+
52
+ predictor_dropout = config.predictor_dropout if config.predictor_dropout is not None else 0.0
53
+ self.dropout = nn.Dropout(predictor_dropout)
54
+ self.classifier = nn.Linear(bert_config.hidden_size, config.vocab_size)
55
+ self.detector = nn.Linear(bert_config.hidden_size, config.num_detect_classes)
56
+
57
+ # Initialize weights and apply final processing
58
+ self.post_init()
59
+
60
+ def forward(
61
+ self,
62
+ input_ids: Optional[torch.Tensor] = None,
63
+ input_offsets: Optional[torch.Tensor] = None,
64
+ attention_mask: Optional[torch.Tensor] = None,
65
+ token_type_ids: Optional[torch.Tensor] = None,
66
+ position_ids: Optional[torch.Tensor] = None,
67
+ head_mask: Optional[torch.Tensor] = None,
68
+ inputs_embeds: Optional[torch.Tensor] = None,
69
+ labels: Optional[torch.Tensor] = None,
70
+ d_tags: Optional[torch.Tensor] = None,
71
+ output_attentions: Optional[bool] = None,
72
+ output_hidden_states: Optional[bool] = None,
73
+ return_dict: Optional[bool] = None,
74
+ ) -> Union[Tuple[torch.Tensor], Seq2LabelsOutput]:
75
+ r"""
76
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
77
+ Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
78
+ """
79
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
80
+
81
+ outputs = self.bert(
82
+ input_ids,
83
+ attention_mask=attention_mask,
84
+ token_type_ids=token_type_ids,
85
+ position_ids=position_ids,
86
+ head_mask=head_mask,
87
+ inputs_embeds=inputs_embeds,
88
+ output_attentions=output_attentions,
89
+ output_hidden_states=output_hidden_states,
90
+ return_dict=return_dict,
91
+ )
92
+
93
+ sequence_output = outputs[0]
94
+
95
+ if input_offsets is not None:
96
+ # offsets is (batch_size, d1, ..., dn, orig_sequence_length)
97
+ range_vector = get_range_vector(input_offsets.size(0), device=sequence_output.device).unsqueeze(1)
98
+ # selected embeddings is also (batch_size * d1 * ... * dn, orig_sequence_length)
99
+ sequence_output = sequence_output[range_vector, input_offsets]
100
+
101
+ logits = self.classifier(self.dropout(sequence_output))
102
+ logits_d = self.detector(sequence_output)
103
+
104
+ loss = None
105
+ if labels is not None and d_tags is not None:
106
+ loss_labels_fct = CrossEntropyLoss(label_smoothing=self.label_smoothing)
107
+ loss_d_fct = CrossEntropyLoss()
108
+ loss_labels = loss_labels_fct(logits.view(-1, self.num_labels), labels.view(-1))
109
+ loss_d = loss_d_fct(logits_d.view(-1, self.num_detect_classes), d_tags.view(-1))
110
+ loss = loss_labels + loss_d
111
+
112
+ if not return_dict:
113
+ output = (logits, logits_d) + outputs[2:]
114
+ return ((loss,) + output) if loss is not None else output
115
+
116
+ return Seq2LabelsOutput(
117
+ loss=loss,
118
+ logits=logits,
119
+ detect_logits=logits_d,
120
+ hidden_states=outputs.hidden_states,
121
+ attentions=outputs.attentions,
122
+ max_error_probability=torch.ones(logits.size(0)),
123
+ )
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ argparse
2
+ numpy
3
+ python-Levenshtein==0.12.2
4
+ scikit-learn
5
+ torch>=1.9.1
6
+ transformers==4.17.0
7
+ sentencepiece
utils.py ADDED
@@ -0,0 +1,233 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from pathlib import Path
3
+ import re
4
+
5
+
6
+ VOCAB_DIR = Path(__file__).resolve().parent
7
+ PAD = "@@PADDING@@"
8
+ UNK = "@@UNKNOWN@@"
9
+ START_TOKEN = "$START"
10
+ SEQ_DELIMETERS = {"tokens": " ", "labels": "SEPL|||SEPR", "operations": "SEPL__SEPR"}
11
+
12
+
13
+ def get_verb_form_dicts():
14
+ path_to_dict = os.path.join(VOCAB_DIR, "verb-form-vocab.txt")
15
+ encode, decode = {}, {}
16
+ with open(path_to_dict, encoding="utf-8") as f:
17
+ for line in f:
18
+ words, tags = line.split(":")
19
+ word1, word2 = words.split("_")
20
+ tag1, tag2 = tags.split("_")
21
+ decode_key = f"{word1}_{tag1}_{tag2.strip()}"
22
+ if decode_key not in decode:
23
+ encode[words] = tags
24
+ decode[decode_key] = word2
25
+ return encode, decode
26
+
27
+
28
+ ENCODE_VERB_DICT, DECODE_VERB_DICT = get_verb_form_dicts()
29
+
30
+
31
+ def get_target_sent_by_edits(source_tokens, edits):
32
+ target_tokens = source_tokens[:]
33
+ shift_idx = 0
34
+ for edit in edits:
35
+ start, end, label, _ = edit
36
+ target_pos = start + shift_idx
37
+ if start < 0:
38
+ continue
39
+ elif len(target_tokens) > target_pos:
40
+ source_token = target_tokens[target_pos]
41
+ else:
42
+ source_token = ""
43
+ if label == "":
44
+ del target_tokens[target_pos]
45
+ shift_idx -= 1
46
+ elif start == end:
47
+ word = label.replace("$APPEND_", "")
48
+ # Avoid appending same token twice
49
+ if (target_pos < len(target_tokens) and target_tokens[target_pos] == word) or (
50
+ target_pos > 0 and target_tokens[target_pos - 1] == word
51
+ ):
52
+ continue
53
+ target_tokens[target_pos:target_pos] = [word]
54
+ shift_idx += 1
55
+ elif label.startswith("$TRANSFORM_"):
56
+ word = apply_reverse_transformation(source_token, label)
57
+ if word is None:
58
+ word = source_token
59
+ target_tokens[target_pos] = word
60
+ elif start == end - 1:
61
+ word = label.replace("$REPLACE_", "")
62
+ target_tokens[target_pos] = word
63
+ elif label.startswith("$MERGE_"):
64
+ target_tokens[target_pos + 1 : target_pos + 1] = [label]
65
+ shift_idx += 1
66
+
67
+ return replace_merge_transforms(target_tokens)
68
+
69
+
70
+ def replace_merge_transforms(tokens):
71
+ if all(not x.startswith("$MERGE_") for x in tokens):
72
+ return tokens
73
+ if tokens[0].startswith("$MERGE_"):
74
+ tokens = tokens[1:]
75
+ if tokens[-1].startswith("$MERGE_"):
76
+ tokens = tokens[:-1]
77
+
78
+ target_line = " ".join(tokens)
79
+ target_line = target_line.replace(" $MERGE_HYPHEN ", "-")
80
+ target_line = target_line.replace(" $MERGE_SPACE ", "")
81
+ target_line = re.sub(r'([\.\,\?\:]\s+)+', r'\1', target_line)
82
+ return target_line.split()
83
+
84
+
85
+ def convert_using_case(token, smart_action):
86
+ if not smart_action.startswith("$TRANSFORM_CASE_"):
87
+ return token
88
+ if smart_action.endswith("LOWER"):
89
+ return token.lower()
90
+ elif smart_action.endswith("UPPER"):
91
+ return token.upper()
92
+ elif smart_action.endswith("CAPITAL"):
93
+ return token.capitalize()
94
+ elif smart_action.endswith("CAPITAL_1"):
95
+ return token[0] + token[1:].capitalize()
96
+ elif smart_action.endswith("UPPER_-1"):
97
+ return token[:-1].upper() + token[-1]
98
+ else:
99
+ return token
100
+
101
+
102
+ def convert_using_verb(token, smart_action):
103
+ key_word = "$TRANSFORM_VERB_"
104
+ if not smart_action.startswith(key_word):
105
+ raise Exception(f"Unknown action type {smart_action}")
106
+ encoding_part = f"{token}_{smart_action[len(key_word):]}"
107
+ decoded_target_word = decode_verb_form(encoding_part)
108
+ return decoded_target_word
109
+
110
+
111
+ def convert_using_split(token, smart_action):
112
+ key_word = "$TRANSFORM_SPLIT"
113
+ if not smart_action.startswith(key_word):
114
+ raise Exception(f"Unknown action type {smart_action}")
115
+ target_words = token.split("-")
116
+ return " ".join(target_words)
117
+
118
+
119
+ def convert_using_plural(token, smart_action):
120
+ if smart_action.endswith("PLURAL"):
121
+ return token + "s"
122
+ elif smart_action.endswith("SINGULAR"):
123
+ return token[:-1]
124
+ else:
125
+ raise Exception(f"Unknown action type {smart_action}")
126
+
127
+
128
+ def apply_reverse_transformation(source_token, transform):
129
+ if transform.startswith("$TRANSFORM"):
130
+ # deal with equal
131
+ if transform == "$KEEP":
132
+ return source_token
133
+ # deal with case
134
+ if transform.startswith("$TRANSFORM_CASE"):
135
+ return convert_using_case(source_token, transform)
136
+ # deal with verb
137
+ if transform.startswith("$TRANSFORM_VERB"):
138
+ return convert_using_verb(source_token, transform)
139
+ # deal with split
140
+ if transform.startswith("$TRANSFORM_SPLIT"):
141
+ return convert_using_split(source_token, transform)
142
+ # deal with single/plural
143
+ if transform.startswith("$TRANSFORM_AGREEMENT"):
144
+ return convert_using_plural(source_token, transform)
145
+ # raise exception if not find correct type
146
+ raise Exception(f"Unknown action type {transform}")
147
+ else:
148
+ return source_token
149
+
150
+
151
+ # def read_parallel_lines(fn1, fn2):
152
+ # lines1 = read_lines(fn1, skip_strip=True)
153
+ # lines2 = read_lines(fn2, skip_strip=True)
154
+ # assert len(lines1) == len(lines2)
155
+ # out_lines1, out_lines2 = [], []
156
+ # for line1, line2 in zip(lines1, lines2):
157
+ # if not line1.strip() or not line2.strip():
158
+ # continue
159
+ # else:
160
+ # out_lines1.append(line1)
161
+ # out_lines2.append(line2)
162
+ # return out_lines1, out_lines2
163
+
164
+
165
+ def read_parallel_lines(fn1, fn2):
166
+ with open(fn1, encoding='utf-8') as f1, open(fn2, encoding='utf-8') as f2:
167
+ for line1, line2 in zip(f1, f2):
168
+ line1 = line1.strip()
169
+ line2 = line2.strip()
170
+
171
+ yield line1, line2
172
+
173
+
174
+ def read_lines(fn, skip_strip=False):
175
+ if not os.path.exists(fn):
176
+ return []
177
+ with open(fn, 'r', encoding='utf-8') as f:
178
+ lines = f.readlines()
179
+ return [s.strip() for s in lines if s.strip() or skip_strip]
180
+
181
+
182
+ def write_lines(fn, lines, mode='w'):
183
+ if mode == 'w' and os.path.exists(fn):
184
+ os.remove(fn)
185
+ with open(fn, encoding='utf-8', mode=mode) as f:
186
+ f.writelines(['%s\n' % s for s in lines])
187
+
188
+
189
+ def decode_verb_form(original):
190
+ return DECODE_VERB_DICT.get(original)
191
+
192
+
193
+ def encode_verb_form(original_word, corrected_word):
194
+ decoding_request = original_word + "_" + corrected_word
195
+ decoding_response = ENCODE_VERB_DICT.get(decoding_request, "").strip()
196
+ if original_word and decoding_response:
197
+ answer = decoding_response
198
+ else:
199
+ answer = None
200
+ return answer
201
+
202
+
203
+ def get_weights_name(transformer_name, lowercase):
204
+ if transformer_name == 'bert' and lowercase:
205
+ return 'bert-base-uncased'
206
+ if transformer_name == 'bert' and not lowercase:
207
+ return 'bert-base-cased'
208
+ if transformer_name == 'bert-large' and not lowercase:
209
+ return 'bert-large-cased'
210
+ if transformer_name == 'distilbert':
211
+ if not lowercase:
212
+ print('Warning! This model was trained only on uncased sentences.')
213
+ return 'distilbert-base-uncased'
214
+ if transformer_name == 'albert':
215
+ if not lowercase:
216
+ print('Warning! This model was trained only on uncased sentences.')
217
+ return 'albert-base-v1'
218
+ if lowercase:
219
+ print('Warning! This model was trained only on cased sentences.')
220
+ if transformer_name == 'roberta':
221
+ return 'roberta-base'
222
+ if transformer_name == 'roberta-large':
223
+ return 'roberta-large'
224
+ if transformer_name == 'gpt2':
225
+ return 'gpt2'
226
+ if transformer_name == 'transformerxl':
227
+ return 'transfo-xl-wt103'
228
+ if transformer_name == 'xlnet':
229
+ return 'xlnet-base-cased'
230
+ if transformer_name == 'xlnet-large':
231
+ return 'xlnet-large-cased'
232
+
233
+ return transformer_name
verb-form-vocab.txt ADDED
The diff for this file is too large to render. See raw diff
 
vocabulary.py ADDED
@@ -0,0 +1,277 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import codecs
2
+ from collections import defaultdict
3
+ import logging
4
+ import os
5
+ import re
6
+ from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Union, TYPE_CHECKING
7
+ from filelock import FileLock
8
+
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+ DEFAULT_NON_PADDED_NAMESPACES = ("*tags", "*labels")
13
+ DEFAULT_PADDING_TOKEN = "@@PADDING@@"
14
+ DEFAULT_OOV_TOKEN = "@@UNKNOWN@@"
15
+ NAMESPACE_PADDING_FILE = "non_padded_namespaces.txt"
16
+ _NEW_LINE_REGEX = re.compile(r"\n|\r\n")
17
+
18
+
19
+ def namespace_match(pattern: str, namespace: str):
20
+ """
21
+ Matches a namespace pattern against a namespace string. For example, `*tags` matches
22
+ `passage_tags` and `question_tags` and `tokens` matches `tokens` but not
23
+ `stemmed_tokens`.
24
+ """
25
+ if pattern[0] == "*" and namespace.endswith(pattern[1:]):
26
+ return True
27
+ elif pattern == namespace:
28
+ return True
29
+ return False
30
+
31
+
32
+ class _NamespaceDependentDefaultDict(defaultdict):
33
+ """
34
+ This is a [defaultdict]
35
+ (https://docs.python.org/2/library/collections.html#collections.defaultdict) where the
36
+ default value is dependent on the key that is passed.
37
+ We use "namespaces" in the :class:`Vocabulary` object to keep track of several different
38
+ mappings from strings to integers, so that we have a consistent API for mapping words, tags,
39
+ labels, characters, or whatever else you want, into integers. The issue is that some of those
40
+ namespaces (words and characters) should have integers reserved for padding and
41
+ out-of-vocabulary tokens, while others (labels and tags) shouldn't. This class allows you to
42
+ specify filters on the namespace (the key used in the `defaultdict`), and use different
43
+ default values depending on whether the namespace passes the filter.
44
+ To do filtering, we take a set of `non_padded_namespaces`. This is a set of strings
45
+ that are either matched exactly against the keys, or treated as suffixes, if the
46
+ string starts with `*`. In other words, if `*tags` is in `non_padded_namespaces` then
47
+ `passage_tags`, `question_tags`, etc. (anything that ends with `tags`) will have the
48
+ `non_padded` default value.
49
+ # Parameters
50
+ non_padded_namespaces : `Iterable[str]`
51
+ A set / list / tuple of strings describing which namespaces are not padded. If a namespace
52
+ (key) is missing from this dictionary, we will use :func:`namespace_match` to see whether
53
+ the namespace should be padded. If the given namespace matches any of the strings in this
54
+ list, we will use `non_padded_function` to initialize the value for that namespace, and
55
+ we will use `padded_function` otherwise.
56
+ padded_function : `Callable[[], Any]`
57
+ A zero-argument function to call to initialize a value for a namespace that `should` be
58
+ padded.
59
+ non_padded_function : `Callable[[], Any]`
60
+ A zero-argument function to call to initialize a value for a namespace that should `not` be
61
+ padded.
62
+ """
63
+
64
+ def __init__(
65
+ self,
66
+ non_padded_namespaces: Iterable[str],
67
+ padded_function: Callable[[], Any],
68
+ non_padded_function: Callable[[], Any],
69
+ ) -> None:
70
+ self._non_padded_namespaces = set(non_padded_namespaces)
71
+ self._padded_function = padded_function
72
+ self._non_padded_function = non_padded_function
73
+ super().__init__()
74
+
75
+ def add_non_padded_namespaces(self, non_padded_namespaces: Set[str]):
76
+ # add non_padded_namespaces which weren't already present
77
+ self._non_padded_namespaces.update(non_padded_namespaces)
78
+
79
+
80
+ class _TokenToIndexDefaultDict(_NamespaceDependentDefaultDict):
81
+ def __init__(self, non_padded_namespaces: Set[str], padding_token: str, oov_token: str) -> None:
82
+ super().__init__(non_padded_namespaces, lambda: {padding_token: 0, oov_token: 1}, lambda: {})
83
+
84
+
85
+ class _IndexToTokenDefaultDict(_NamespaceDependentDefaultDict):
86
+ def __init__(self, non_padded_namespaces: Set[str], padding_token: str, oov_token: str) -> None:
87
+ super().__init__(non_padded_namespaces, lambda: {0: padding_token, 1: oov_token}, lambda: {})
88
+
89
+
90
+ class Vocabulary:
91
+ def __init__(
92
+ self,
93
+ counter: Dict[str, Dict[str, int]] = None,
94
+ min_count: Dict[str, int] = None,
95
+ max_vocab_size: Union[int, Dict[str, int]] = None,
96
+ non_padded_namespaces: Iterable[str] = DEFAULT_NON_PADDED_NAMESPACES,
97
+ pretrained_files: Optional[Dict[str, str]] = None,
98
+ only_include_pretrained_words: bool = False,
99
+ tokens_to_add: Dict[str, List[str]] = None,
100
+ min_pretrained_embeddings: Dict[str, int] = None,
101
+ padding_token: Optional[str] = DEFAULT_PADDING_TOKEN,
102
+ oov_token: Optional[str] = DEFAULT_OOV_TOKEN,
103
+ ) -> None:
104
+ self._padding_token = padding_token if padding_token is not None else DEFAULT_PADDING_TOKEN
105
+ self._oov_token = oov_token if oov_token is not None else DEFAULT_OOV_TOKEN
106
+
107
+ self._non_padded_namespaces = set(non_padded_namespaces)
108
+
109
+ self._token_to_index = _TokenToIndexDefaultDict(
110
+ self._non_padded_namespaces, self._padding_token, self._oov_token
111
+ )
112
+ self._index_to_token = _IndexToTokenDefaultDict(
113
+ self._non_padded_namespaces, self._padding_token, self._oov_token
114
+ )
115
+
116
+ @classmethod
117
+ def from_files(
118
+ cls,
119
+ directory: Union[str, os.PathLike],
120
+ padding_token: Optional[str] = DEFAULT_PADDING_TOKEN,
121
+ oov_token: Optional[str] = DEFAULT_OOV_TOKEN,
122
+ ) -> "Vocabulary":
123
+ """
124
+ Loads a `Vocabulary` that was serialized either using `save_to_files` or inside
125
+ a model archive file.
126
+ # Parameters
127
+ directory : `str`
128
+ The directory or archive file containing the serialized vocabulary.
129
+ """
130
+ logger.info("Loading token dictionary from %s.", directory)
131
+ padding_token = padding_token if padding_token is not None else DEFAULT_PADDING_TOKEN
132
+ oov_token = oov_token if oov_token is not None else DEFAULT_OOV_TOKEN
133
+
134
+ if not os.path.isdir(directory):
135
+ raise ValueError(f"{directory} not exist")
136
+
137
+ # We use a lock file to avoid race conditions where multiple processes
138
+ # might be reading/writing from/to the same vocab files at once.
139
+ with FileLock(os.path.join(directory, ".lock")):
140
+ with codecs.open(os.path.join(directory, NAMESPACE_PADDING_FILE), "r", "utf-8") as namespace_file:
141
+ non_padded_namespaces = [namespace_str.strip() for namespace_str in namespace_file]
142
+
143
+ vocab = cls(
144
+ non_padded_namespaces=non_padded_namespaces,
145
+ padding_token=padding_token,
146
+ oov_token=oov_token,
147
+ )
148
+
149
+ # Check every file in the directory.
150
+ for namespace_filename in os.listdir(directory):
151
+ if namespace_filename == NAMESPACE_PADDING_FILE:
152
+ continue
153
+ if namespace_filename.startswith("."):
154
+ continue
155
+ namespace = namespace_filename.replace(".txt", "")
156
+ if any(namespace_match(pattern, namespace) for pattern in non_padded_namespaces):
157
+ is_padded = False
158
+ else:
159
+ is_padded = True
160
+ filename = os.path.join(directory, namespace_filename)
161
+ vocab.set_from_file(filename, is_padded, namespace=namespace, oov_token=oov_token)
162
+
163
+ return vocab
164
+
165
+ @classmethod
166
+ def empty(cls) -> "Vocabulary":
167
+ """
168
+ This method returns a bare vocabulary instantiated with `cls()` (so, `Vocabulary()` if you
169
+ haven't made a subclass of this object). The only reason to call `Vocabulary.empty()`
170
+ instead of `Vocabulary()` is if you are instantiating this object from a config file. We
171
+ register this constructor with the key "empty", so if you know that you don't need to
172
+ compute a vocabulary (either because you're loading a pre-trained model from an archive
173
+ file, you're using a pre-trained transformer that has its own vocabulary, or something
174
+ else), you can use this to avoid having the default vocabulary construction code iterate
175
+ through the data.
176
+ """
177
+ return cls()
178
+
179
+ def set_from_file(
180
+ self,
181
+ filename: str,
182
+ is_padded: bool = True,
183
+ oov_token: str = DEFAULT_OOV_TOKEN,
184
+ namespace: str = "tokens",
185
+ ):
186
+ """
187
+ If you already have a vocabulary file for a trained model somewhere, and you really want to
188
+ use that vocabulary file instead of just setting the vocabulary from a dataset, for
189
+ whatever reason, you can do that with this method. You must specify the namespace to use,
190
+ and we assume that you want to use padding and OOV tokens for this.
191
+ # Parameters
192
+ filename : `str`
193
+ The file containing the vocabulary to load. It should be formatted as one token per
194
+ line, with nothing else in the line. The index we assign to the token is the line
195
+ number in the file (1-indexed if `is_padded`, 0-indexed otherwise). Note that this
196
+ file should contain the OOV token string!
197
+ is_padded : `bool`, optional (default=`True`)
198
+ Is this vocabulary padded? For token / word / character vocabularies, this should be
199
+ `True`; while for tag or label vocabularies, this should typically be `False`. If
200
+ `True`, we add a padding token with index 0, and we enforce that the `oov_token` is
201
+ present in the file.
202
+ oov_token : `str`, optional (default=`DEFAULT_OOV_TOKEN`)
203
+ What token does this vocabulary use to represent out-of-vocabulary characters? This
204
+ must show up as a line in the vocabulary file. When we find it, we replace
205
+ `oov_token` with `self._oov_token`, because we only use one OOV token across
206
+ namespaces.
207
+ namespace : `str`, optional (default=`"tokens"`)
208
+ What namespace should we overwrite with this vocab file?
209
+ """
210
+ if is_padded:
211
+ self._token_to_index[namespace] = {self._padding_token: 0}
212
+ self._index_to_token[namespace] = {0: self._padding_token}
213
+ else:
214
+ self._token_to_index[namespace] = {}
215
+ self._index_to_token[namespace] = {}
216
+ with codecs.open(filename, "r", "utf-8") as input_file:
217
+ lines = _NEW_LINE_REGEX.split(input_file.read())
218
+ # Be flexible about having final newline or not
219
+ if lines and lines[-1] == "":
220
+ lines = lines[:-1]
221
+ for i, line in enumerate(lines):
222
+ index = i + 1 if is_padded else i
223
+ token = line.replace("@@NEWLINE@@", "\n")
224
+ if token == oov_token:
225
+ token = self._oov_token
226
+ self._token_to_index[namespace][token] = index
227
+ self._index_to_token[namespace][index] = token
228
+ if is_padded:
229
+ assert self._oov_token in self._token_to_index[namespace], "OOV token not found!"
230
+
231
+ def add_token_to_namespace(self, token: str, namespace: str = "tokens") -> int:
232
+ """
233
+ Adds `token` to the index, if it is not already present. Either way, we return the index of
234
+ the token.
235
+ """
236
+ if not isinstance(token, str):
237
+ raise ValueError(
238
+ "Vocabulary tokens must be strings, or saving and loading will break."
239
+ " Got %s (with type %s)" % (repr(token), type(token))
240
+ )
241
+ if token not in self._token_to_index[namespace]:
242
+ index = len(self._token_to_index[namespace])
243
+ self._token_to_index[namespace][token] = index
244
+ self._index_to_token[namespace][index] = token
245
+ return index
246
+ else:
247
+ return self._token_to_index[namespace][token]
248
+
249
+ def add_tokens_to_namespace(self, tokens: List[str], namespace: str = "tokens") -> List[int]:
250
+ """
251
+ Adds `tokens` to the index, if they are not already present. Either way, we return the
252
+ indices of the tokens in the order that they were given.
253
+ """
254
+ return [self.add_token_to_namespace(token, namespace) for token in tokens]
255
+
256
+ def get_token_index(self, token: str, namespace: str = "tokens") -> int:
257
+ try:
258
+ return self._token_to_index[namespace][token]
259
+ except KeyError:
260
+ try:
261
+ return self._token_to_index[namespace][self._oov_token]
262
+ except KeyError:
263
+ logger.error("Namespace: %s", namespace)
264
+ logger.error("Token: %s", token)
265
+ raise KeyError(
266
+ f"'{token}' not found in vocab namespace '{namespace}', and namespace "
267
+ f"does not contain the default OOV token ('{self._oov_token}')"
268
+ )
269
+
270
+ def get_token_from_index(self, index: int, namespace: str = "tokens") -> str:
271
+ return self._index_to_token[namespace][index]
272
+
273
+ def get_vocab_size(self, namespace: str = "tokens") -> int:
274
+ return len(self._token_to_index[namespace])
275
+
276
+ def get_namespaces(self) -> Set[str]:
277
+ return set(self._index_to_token.keys())
vocabulary/.lock ADDED
File without changes
vocabulary/d_tags.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ CORRECT
2
+ INCORRECT
3
+ @@UNKNOWN@@
4
+ @@PADDING@@
vocabulary/labels.txt ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ $KEEP
2
+ $TRANSFORM_CASE_CAPITAL
3
+ $APPEND_,
4
+ $APPEND_.
5
+ $TRANSFORM_VERB_VB_VBN
6
+ $TRANSFORM_CASE_UPPER
7
+ $APPEND_:
8
+ $APPEND_?
9
+ $TRANSFORM_VERB_VB_VBC
10
+ $TRANSFORM_CASE_LOWER
11
+ $TRANSFORM_CASE_CAPITAL_1
12
+ $TRANSFORM_CASE_UPPER_-1
13
+ $MERGE_SPACE
14
+ @@UNKNOWN@@
15
+ @@PADDING@@
vocabulary/non_padded_namespaces.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ *tags
2
+ *labels