ktzsh committed
Commit 010f214
1 Parent(s): 9f441d4

Upload folder using huggingface_hub

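The commit message notes the folder was pushed with huggingface_hub. A minimal sketch of such an upload call, assuming a hypothetical repo id and local export directory (neither is stated in this commit view):

from huggingface_hub import upload_folder

# Hypothetical repo id and local folder; substitute the actual target repository and export path.
upload_folder(
    repo_id="ktzsh/gector-xlnet-base",
    folder_path="./gector_export",
    commit_message="Upload folder using huggingface_hub",
)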
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ config.json filter=lfs diff=lfs merge=lfs -text
added_tokens.json ADDED
@@ -0,0 +1,11 @@
+ {
+   "</s>": 2,
+   "<cls>": 3,
+   "<eod>": 7,
+   "<eop>": 8,
+   "<mask>": 6,
+   "<pad>": 5,
+   "<s>": 1,
+   "<sep>": 4,
+   "<unk>": 0
+ }
config.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:97e03cd3dd250c9819297c1c1f6099ad8d59b374c344f07c633a79c68bce182f
+ size 11109513
configuration_gector.py ADDED
@@ -0,0 +1,91 @@
+ import os
+ import json
+
+ from typing import OrderedDict, Mapping, Union
+
+ from transformers import PretrainedConfig, AutoConfig
+ from transformers.onnx import OnnxConfig
+
+
+ class GectorConfig(PretrainedConfig):
+     model_type = "gector"
+
+     # To add config values from base model config
+     def __subclassconfig__(self, base_config: AutoConfig):
+         if base_config:
+             self.__dict__.update(base_config.__dict__)
+
+     def __init__(
+         self,
+         model_id: str = None,
+         id2label: dict = None,
+         label2id: dict = None,
+         detect_id2label: dict = None,
+         detect_label2id: dict = None,
+         classifier_dropout: float = 0,
+         label_pad_token: str = "<PAD>",
+         label_unknown_token: str = "<UNK>",
+         detect_pad_token_id: int = 3,
+         correct_pad_token_id: int = 5001,
+         num_detect_tags: int = 4,
+         num_correct_tags: int = 5002,
+         max_length: int = 128,
+         label_smoothing: float = 0.0,
+         special_tokens_fix: bool = False,
+         delete_confidence: float = 0.0,
+         additional_confidence: float = 0.2,
+         base_config: AutoConfig = None,
+         verb_form_vocab: dict = None,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+         self.__subclassconfig__(base_config)
+
+         self.model_id = model_id
+         self.label2id = label2id
+         self.id2label = id2label
+         self.detect_label2id = detect_label2id
+         self.detect_id2label = detect_id2label
+         self.detect_pad_token_id = detect_pad_token_id
+         self.correct_pad_token_id = correct_pad_token_id
+         self.num_detect_tags = num_detect_tags
+         self.num_correct_tags = num_correct_tags
+         self.classifier_dropout = classifier_dropout
+         self.max_length = max_length
+         self.label_smoothing = label_smoothing
+         self.special_tokens_fix = special_tokens_fix
+         self.delete_confidence = delete_confidence
+         self.additional_confidence = additional_confidence
+         self.verb_form_vocab = verb_form_vocab
+
+     # def save_pretrained(
+     #     self,
+     #     save_directory: Union[str, os.PathLike],
+     #     push_to_hub: bool = False,
+     #     **kwargs,
+     # ):
+     #     if os.path.isfile(save_directory):
+     #         raise AssertionError(
+     #             f"Provided path ({save_directory}) should be a directory, not a file"
+     #         )
+
+     #     os.makedirs(save_directory, exist_ok=True)
+
+     #     if self.verb_form_vocab:
+     #         verb_form_vocab_file = os.path.join(save_directory, "verb_form_vocab.json")
+     #         with open(verb_form_vocab_file, "w", encoding="utf-8") as writer:
+     #             writer.write(json.dumps(self.verb_form_vocab, indent=2, sort_keys=True) + "\n")
+
+     #     super().save_pretrained(save_directory, push_to_hub, **kwargs)
+
+
+ class GectorOnnxConfig(OnnxConfig):
+     @property
+     def inputs(self) -> Mapping[str, Mapping[int, str]]:
+         dynamic_axis = {0: "batch", 1: "sequence"}
+         return OrderedDict(
+             [
+                 ("input_ids", dynamic_axis),
+                 ("attention_mask", dynamic_axis),
+             ]
+         )
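A short sketch of how this custom config might be instantiated and registered with AutoConfig; the base model id and label map below are illustrative placeholders, since the real values live in the LFS-tracked config.json above:

from transformers import AutoConfig
from configuration_gector import GectorConfig  # assumes the file above is on the import path, e.g. a local checkout

# Optional: let AutoConfig resolve the custom "gector" model type.
AutoConfig.register("gector", GectorConfig)

# Illustrative values only; the shipped config.json carries the full label maps and verb-form vocabulary.
config = GectorConfig(
    model_id="xlnet-base-cased",
    detect_label2id={"$CORRECT": 0, "$INCORRECT": 1},  # partial map, for illustration
    base_config=AutoConfig.from_pretrained("xlnet-base-cased"),
)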
grammar_error_correction_pipeline.py ADDED
@@ -0,0 +1,251 @@
+ import os
+ import numpy as np
+
+ from transformers import Pipeline, TensorType
+
+
+ class GectorBase(object):
+     DELIMITER = " "
+     START_TOKEN = "$START"
+     PAD = "<PAD>"
+     UNK = "<UNK>"
+     REPLACEMENTS = {
+         "''": '"',
+         "--": "—",
+         "`": "'",
+         "'ve": "' ve",
+     }
+
+     def decode_verb_form(self, original):
+         return self.model.config.verb_form_vocab["decode"].get(original)
+
+     def get_target_sent_by_edits(self, source_tokens, edits):
+         target_tokens = source_tokens[:]
+         shift_idx = 0
+         for edit in edits:
+             start, end, label, _ = edit
+             target_pos = start + shift_idx
+             source_token = (
+                 target_tokens[target_pos]
+                 if len(target_tokens) > target_pos >= 0
+                 else ""
+             )
+             if label == "":
+                 del target_tokens[target_pos]
+                 shift_idx -= 1
+             elif start == end:
+                 word = label.replace("$APPEND_", "")
+                 target_tokens[target_pos:target_pos] = [word]
+                 shift_idx += 1
+             elif label.startswith("$TRANSFORM_"):
+                 word = self.apply_reverse_transformation(source_token, label)
+                 if word is None:
+                     word = source_token
+                 target_tokens[target_pos] = word
+             elif start == end - 1:
+                 word = label.replace("$REPLACE_", "")
+                 target_tokens[target_pos] = word
+             elif label.startswith("$MERGE_"):
+                 target_tokens[target_pos + 1 : target_pos + 1] = [label]
+                 shift_idx += 1
+
+         return self.replace_merge_transforms(target_tokens)
+
+     def replace_merge_transforms(self, tokens):
+         if all(not x.startswith("$MERGE_") for x in tokens):
+             return tokens
+
+         target_line = " ".join(tokens)
+         target_line = target_line.replace(" $MERGE_HYPHEN ", "-")
+         target_line = target_line.replace(" $MERGE_SPACE ", "")
+         return target_line.split()
+
+     def convert_using_case(self, token, smart_action):
+         if not smart_action.startswith("$TRANSFORM_CASE_"):
+             return token
+         if smart_action.endswith("LOWER"):
+             return token.lower()
+         elif smart_action.endswith("UPPER"):
+             return token.upper()
+         elif smart_action.endswith("CAPITAL"):
+             return token.capitalize()
+         elif smart_action.endswith("CAPITAL_1"):
+             return token[0] + token[1:].capitalize()
+         elif smart_action.endswith("UPPER_-1"):
+             return token[:-1].upper() + token[-1]
+         else:
+             return token
+
+     def convert_using_verb(self, token, smart_action):
+         key_word = "$TRANSFORM_VERB_"
+         if not smart_action.startswith(key_word):
+             raise Exception(f"Unknown action type {smart_action}")
+         encoding_part = f"{token}_{smart_action[len(key_word):]}"
+         decoded_target_word = self.decode_verb_form(encoding_part)
+         return decoded_target_word
+
+     def convert_using_split(self, token, smart_action):
+         key_word = "$TRANSFORM_SPLIT"
+         if not smart_action.startswith(key_word):
+             raise Exception(f"Unknown action type {smart_action}")
+         target_words = token.split("-")
+         return " ".join(target_words)
+
+     def convert_using_plural(self, token, smart_action):
+         if smart_action.endswith("PLURAL"):
+             return token + "s"
+         elif smart_action.endswith("SINGULAR"):
+             return token[:-1]
+         else:
+             raise Exception(f"Unknown action type {smart_action}")
+
+     def apply_reverse_transformation(self, source_token, transform):
+         if transform.startswith("$TRANSFORM"):
+             # deal with equal
+             if transform == "$KEEP":
+                 return source_token
+             # deal with case
+             if transform.startswith("$TRANSFORM_CASE"):
+                 return self.convert_using_case(source_token, transform)
+             # deal with verb
+             if transform.startswith("$TRANSFORM_VERB"):
+                 return self.convert_using_verb(source_token, transform)
+             # deal with split
+             if transform.startswith("$TRANSFORM_SPLIT"):
+                 return self.convert_using_split(source_token, transform)
+             # deal with singular/plural
+             if transform.startswith("$TRANSFORM_AGREEMENT"):
+                 return self.convert_using_plural(source_token, transform)
+             # raise an exception if no transform type matched
+             raise Exception(f"Unknown action type {transform}")
+         else:
+             return source_token
+
+     def get_token_action(self, token, index, prob, sugg_token, min_error_probability):
+         """Get the suggested edit action for a token, or None if nothing should change."""
+         # cases when we don't need to do anything
+         if prob < min_error_probability or sugg_token in [self.UNK, self.PAD, "$KEEP"]:
+             return None
+
+         if (
+             sugg_token.startswith("$REPLACE_")
+             or sugg_token.startswith("$TRANSFORM_")
+             or sugg_token == "$DELETE"
+         ):
+             start_pos = index
+             end_pos = index + 1
+         elif sugg_token.startswith("$APPEND_") or sugg_token.startswith("$MERGE_"):
+             start_pos = index + 1
+             end_pos = index + 1
+
+         if sugg_token == "$DELETE":
+             sugg_token_clear = ""
+         elif sugg_token.startswith("$TRANSFORM_") or sugg_token.startswith("$MERGE_"):
+             sugg_token_clear = sugg_token[:]
+         else:
+             sugg_token_clear = sugg_token[sugg_token.index("_") + 1 :]
+
+         return start_pos - 1, end_pos - 1, sugg_token_clear, prob
+
+
+ class GrammarErrorCorrectionPipeline(Pipeline, GectorBase):
+     def _sanitize_parameters(self, **kwargs):
+         preprocess_kwargs = {
+             "max_len": int(kwargs.get("max_len", 50)),
+             "lowercase_tokens": bool(kwargs.get("lowercase_tokens", False)),
+         }
+         forward_kwargs = {
+             "iterations": int(kwargs.get("iterations", 1)),
+             "max_len": int(kwargs.get("max_len", 50)),
+             "min_len": int(kwargs.get("min_len", 3)),
+             "min_error_probability": float(kwargs.get("min_error_probability", 0.0)),
+         }
+         postprocess_kwargs = {}
+         return preprocess_kwargs, forward_kwargs, postprocess_kwargs
+
+     def add_word_offsets(self, tokenized_input):
+         word_ids = tokenized_input.word_ids()
+         offsets = [i for i, x in enumerate(word_ids) if i == 0 or x != word_ids[i - 1]]
+         if self.framework == TensorType.PYTORCH:
+             import torch
+
+             offsets = torch.tensor([offsets], dtype=torch.long)
+             mask = torch.ones_like(offsets)
+         tokenized_input["word_offsets"] = offsets
+         tokenized_input["word_mask"] = mask
+         return tokenized_input
+
+     def preprocess(self, model_input, **kwargs):
+         tokens = [self.START_TOKEN] + model_input.split(self.DELIMITER)
+         tokenized_input = self.tokenizer(
+             tokens,
+             max_length=kwargs.get("max_len"),
+             add_special_tokens=False,
+             truncation=True,
+             is_split_into_words=True,
+             return_token_type_ids=True,
+             return_tensors=self.framework,
+         )
+         tokenized_input["original_tokens"] = tokens[1:]
+         tokenized_input = self.add_word_offsets(tokenized_input)
+         return tokenized_input
+
+     def _forward_iterative(self, batch, **forward_kwargs):
+         original_tokens = batch.pop("original_tokens")
+         model_outputs = self.model(**batch)
+
+         error_probs = model_outputs.max_error_probabilities.numpy()
+         class_probabilities_correct = model_outputs.class_probabilities_correct.numpy()
+         all_probabilities = np.amax(class_probabilities_correct, axis=-1)
+         all_idxs = np.argmax(class_probabilities_correct, axis=-1)
+
+         all_results = []
+         noop_index = self.model.config.detect_label2id.get("$CORRECT")
+         for tokens, probabilities, idxs, error_prob in zip(
+             original_tokens, all_probabilities, all_idxs, error_probs
+         ):
+             length = min(len(tokens), forward_kwargs.get("max_len"))
+             edits = []
+
+             # skip the whole sentence if there are no errors
+             if max(idxs) == 0:
+                 all_results.append(tokens)
+                 continue
+
+             # skip the whole sentence if the maximum error probability is below the threshold
+             if error_prob < forward_kwargs.get("min_error_probability"):
+                 all_results.append(tokens)
+                 continue
+             for i in range(length + 1):
+                 # because of START token
+                 if i == 0:
+                     token = self.START_TOKEN
+                 else:
+                     token = tokens[i - 1]
+                 # skip if there is no error
+                 if idxs[i] == noop_index:
+                     continue
+
+                 sugg_token = self.model.config.id2label[str(idxs[i])]
+                 action = self.get_token_action(
+                     token,
+                     i,
+                     probabilities[i],
+                     sugg_token,
+                     forward_kwargs.get("min_error_probability"),
+                 )
+                 if not action:
+                     continue
+
+                 edits.append(action)
+             all_results.append(self.get_target_sent_by_edits(tokens, edits))
+         return all_results
+
+     def _forward(self, model_inputs, **forward_kwargs):
+         outputs = []
+         for iter in range(forward_kwargs.get("iterations")):
+             outputs = self._forward_iterative(model_inputs, **forward_kwargs)
+         return {"output": outputs}
+
+     def postprocess(self, model_outputs):
+         return model_outputs
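A hedged sketch of one way to wire this custom pipeline into transformers and run it; the task name and checkout path are placeholders, not something this repo defines:

from transformers import AutoTokenizer, pipeline
from transformers.pipelines import PIPELINE_REGISTRY

from grammar_error_correction_pipeline import GrammarErrorCorrectionPipeline  # the file listed above
from modelling_gector import GectorForTokenClassification

# Register a custom task name so pipeline() knows which classes to assemble.
PIPELINE_REGISTRY.register_pipeline(
    "grammar-error-correction",
    pipeline_class=GrammarErrorCorrectionPipeline,
    pt_model=GectorForTokenClassification,
)

# Placeholder path: a local checkout of this repository.
tok = AutoTokenizer.from_pretrained("./gector-checkout")
gec = pipeline("grammar-error-correction", model="./gector-checkout", tokenizer=tok)
result = gec("she go to school yesterday", min_error_probability=0.5)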
modelling_gector.py ADDED
@@ -0,0 +1,182 @@
+ import torch
+ import logging
+
+ from torch import nn
+ from dataclasses import dataclass
+ from typing import Optional, Tuple, Union
+
+ from transformers import PreTrainedModel, AutoModel, AutoConfig
+ from transformers.modeling_outputs import TokenClassifierOutput
+
+ from .configuration_gector import GectorConfig
+
+ logger = logging.getLogger(__name__)
+
+ GECTOR_PRETRAINED_BASE_MODEL_ARCHIVE_LIST = [
+     "bert-base-cased",
+     "bert-large-cased",
+     "roberta-base",
+     "roberta-large",
+     "xlnet-base-cased",
+     "xlnet-large-cased",
+     "deberta-base-cased",
+     "deberta-large-cased",
+ ]
+
+
+ @dataclass
+ class GectorTokenClassifierOutput(TokenClassifierOutput):
+     loss: Optional[torch.FloatTensor] = None
+     logits_detect: torch.FloatTensor = None
+     class_probabilities_detect: torch.FloatTensor = None
+     logits_correct: torch.FloatTensor = None
+     class_probabilities_correct: torch.FloatTensor = None
+     max_error_probabilities: torch.FloatTensor = None
+     hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+     attentions: Optional[Tuple[torch.FloatTensor]] = None
+
+
+ class GectorModel(PreTrainedModel):
+     config_class = GectorConfig
+
+     def __init__(self, config):
+         super().__init__(config)
+         special_tokens_fix = config.special_tokens_fix
+
+         config = AutoConfig.from_pretrained(config.model_id)
+         self.encoder_model = AutoModel.from_config(config)
+
+         if special_tokens_fix:
+             self.encoder_model.resize_token_embeddings(config.vocab_size + 1)
+
+     def forward(self, *args, **kwargs):
+         return self.encoder_model.forward(*args, **kwargs)
+
+
+ class GectorForTokenClassification(PreTrainedModel):
+     config_class = GectorConfig
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.num_detect_tags = config.num_detect_tags
+         self.num_correct_tags = config.num_correct_tags
+
+         self.text_field_embedder = GectorModel(config)
+         self.embedding_size = self.text_field_embedder.encoder_model.config.hidden_size
+
+         self.dropout = nn.Dropout(config.classifier_dropout)
+
+         self.detect_proj_layer = nn.Linear(self.embedding_size, self.num_detect_tags)
+         self.correct_proj_layer = nn.Linear(self.embedding_size, self.num_correct_tags)
+
+         self.delete_confidence = config.delete_confidence
+         self.additional_confidence = config.additional_confidence
+         self.incorrect_index = config.detect_label2id.get("$INCORRECT")
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         attention_mask: Optional[torch.FloatTensor] = None,
+         word_offsets: Optional[torch.LongTensor] = None,
+         word_mask: Optional[torch.LongTensor] = None,
+         token_type_ids: Optional[torch.LongTensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         head_mask: Optional[torch.FloatTensor] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple[torch.Tensor], GectorTokenClassifierOutput]:
+         r"""
+         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+             Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
+         """
+         return_dict = (
+             return_dict if return_dict is not None else self.config.use_return_dict
+         )
+
+         outputs = self.text_field_embedder(
+             input_ids,
+             attention_mask=attention_mask,
+             token_type_ids=token_type_ids,
+             position_ids=position_ids,
+             head_mask=head_mask,
+             inputs_embeds=inputs_embeds,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+         sequence_output = outputs[0]
+         # If offsets are provided, the returned tensor will contain only the wordpiece
+         # embeddings at those positions, and (in particular) will contain one embedding
+         # per token. If offsets are not provided, the entire tensor of wordpiece embeddings
+         # will be returned.
+         if word_offsets is not None:
+             indices = word_offsets.unsqueeze(-1).expand(
+                 -1, -1, sequence_output.size(-1)
+             )
+             sequence_output = torch.gather(sequence_output, 1, indices)
+         batch_size, sequence_length = sequence_output.size()[0:2]
+
+         logits_detect = self.detect_proj_layer(sequence_output)
+         logits_correct = self.correct_proj_layer(self.dropout(sequence_output))
+
+         class_probabilities_correct = nn.functional.softmax(
+             logits_correct, dim=-1
+         ).view([batch_size, sequence_length, self.num_correct_tags])
+         class_probabilities_detect = nn.functional.softmax(logits_detect, dim=-1).view(
+             [batch_size, sequence_length, self.num_detect_tags]
+         )
+         max_error_probabilities = torch.max(
+             class_probabilities_detect[:, :, self.incorrect_index] * word_mask,
+             dim=-1,
+         )[0]
+         probability_change = [self.additional_confidence, self.delete_confidence] + [
+             0
+         ] * (self.num_correct_tags - 2)
+         class_probabilities_correct += (
+             torch.FloatTensor(probability_change)
+             .repeat((batch_size, sequence_length, 1))
+             .to(self.device)
+         )
+
+         loss = None
+         if labels is not None:
+             detect_labels, correct_labels = torch.tensor_split(labels, 2, dim=-1)
+             # -100 is the default ignore_index of CrossEntropyLoss
+             detect_labels[detect_labels == self.config.detect_pad_token_id] = -100
+             correct_labels[correct_labels == self.config.correct_pad_token_id] = -100
+
+             detect_loss_fct = nn.CrossEntropyLoss()
+             loss_detect = detect_loss_fct(
+                 logits_detect.view(-1, self.config.num_detect_tags),
+                 detect_labels.view(-1),
+             )
+
+             correct_loss_fct = nn.CrossEntropyLoss(
+                 label_smoothing=self.config.label_smoothing
+             )
+             loss_correct = correct_loss_fct(
+                 logits_correct.view(-1, self.config.num_correct_tags),
+                 correct_labels.view(-1),
+             )
+             loss = loss_detect + loss_correct
+
+         if not return_dict:
+             output = (logits_detect, logits_correct) + outputs[2:]
+             return ((loss,) + output) if loss is not None else output
+
+         return GectorTokenClassifierOutput(
+             loss=loss,
+             logits_detect=logits_detect,
+             class_probabilities_detect=class_probabilities_detect,
+             logits_correct=logits_correct,
+             class_probabilities_correct=class_probabilities_correct,
+             max_error_probabilities=max_error_probabilities,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
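A minimal loading sketch for the classifier defined above, assuming a local checkout of this repository (the directory name is a placeholder); the weights come from the LFS-tracked pytorch_model.bin below:

from modelling_gector import GectorForTokenClassification

# Placeholder path to a local clone of this repository.
model = GectorForTokenClassification.from_pretrained("./gector-checkout")
model.eval()
print(model.num_detect_tags, model.num_correct_tags)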
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f23f857799504b7347fad3548ad1c28ecb409f921d99a39ce8fec2ce7c3b98b7
+ size 482343698
special_tokens_map.json ADDED
@@ -0,0 +1,13 @@
+ {
+   "additional_special_tokens": [
+     "<eop>",
+     "<eod>"
+   ],
+   "bos_token": "<s>",
+   "cls_token": "<cls>",
+   "eos_token": "</s>",
+   "mask_token": "<mask>",
+   "pad_token": "<pad>",
+   "sep_token": "<sep>",
+   "unk_token": "<unk>"
+ }
spiece.model ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1f8c1c0bc2854d1af911a8550288c1258af5ba50277f3a5c829b98eb86fc5646
+ size 798011
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,95 @@
+ {
+   "added_tokens_decoder": {
+     "0": {
+       "content": "<unk>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "1": {
+       "content": "<s>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "2": {
+       "content": "</s>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "3": {
+       "content": "<cls>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "4": {
+       "content": "<sep>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "5": {
+       "content": "<pad>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "6": {
+       "content": "<mask>",
+       "lstrip": true,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "7": {
+       "content": "<eod>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "8": {
+       "content": "<eop>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     }
+   },
+   "additional_special_tokens": [
+     "<eop>",
+     "<eod>"
+   ],
+   "bos_token": "<s>",
+   "clean_up_tokenization_spaces": true,
+   "cls_token": "<cls>",
+   "do_basic_tokenize": false,
+   "do_lower_case": false,
+   "eos_token": "</s>",
+   "keep_accents": false,
+   "mask_token": "<mask>",
+   "model_max_length": 1000000000000000019884624838656,
+   "pad_token": "<pad>",
+   "padding_side": "right",
+   "remove_space": true,
+   "sep_token": "<sep>",
+   "tokenizer_class": "XLNetTokenizer",
+   "unk_token": "<unk>"
+ }
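The tokenizer files in this commit (spiece.model, tokenizer.json, added_tokens.json, special_tokens_map.json, and this config) load as a regular XLNet tokenizer; a short sketch, with the checkout path as a placeholder:

from transformers import AutoTokenizer

# Placeholder path to a local clone of this repository.
tokenizer = AutoTokenizer.from_pretrained("./gector-checkout")
enc = tokenizer(
    ["$START", "she", "go", "to", "school"],
    is_split_into_words=True,
    add_special_tokens=False,
)
# word_ids() maps each subword piece back to its source word, which is what the
# pipeline's add_word_offsets relies on to keep one embedding per word.
print(enc.word_ids())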