Upload folder using huggingface_hub
- .gitattributes +1 -0
- added_tokens.json +11 -0
- config.json +3 -0
- configuration_gector.py +91 -0
- grammar_error_correction_pipeline.py +251 -0
- modelling_gector.py +182 -0
- pytorch_model.bin +3 -0
- special_tokens_map.json +13 -0
- spiece.model +3 -0
- tokenizer.json +0 -0
- tokenizer_config.json +95 -0
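For illustration only (not part of this commit): an upload like this is typically produced with huggingface_hub's upload_folder; the local path and repo id below are placeholders.

from huggingface_hub import HfApi

api = HfApi()
api.upload_folder(
    folder_path="./gector-checkpoint",    # placeholder local folder
    repo_id="<user>/<this-repo>",         # placeholder repo id
    repo_type="model",
    commit_message="Upload folder using huggingface_hub",
)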
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+config.json filter=lfs diff=lfs merge=lfs -text
added_tokens.json
ADDED
@@ -0,0 +1,11 @@
+{
+  "</s>": 2,
+  "<cls>": 3,
+  "<eod>": 7,
+  "<eop>": 8,
+  "<mask>": 6,
+  "<pad>": 5,
+  "<s>": 1,
+  "<sep>": 4,
+  "<unk>": 0
+}
config.json
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:97e03cd3dd250c9819297c1c1f6099ad8d59b374c344f07c633a79c68bce182f
+size 11109513
configuration_gector.py
ADDED
@@ -0,0 +1,91 @@
+import os
+import json
+
+from typing import OrderedDict, Mapping, Union
+
+from transformers import PretrainedConfig, AutoConfig
+from transformers.onnx import OnnxConfig
+
+
+class GectorConfig(PretrainedConfig):
+    model_type = "gector"
+
+    # To add config values from base model config
+    def __subclassconfig__(self, base_config: AutoConfig):
+        if base_config:
+            self.__dict__.update(base_config.__dict__)
+
+    def __init__(
+        self,
+        model_id: str = None,
+        id2label: dict = None,
+        label2id: dict = None,
+        detect_id2label: dict = None,
+        detect_label2id: dict = None,
+        classifier_dropout: float = 0,
+        label_pad_token: str = "<PAD>",
+        label_unknown_token: str = "<UNK>",
+        detect_pad_token_id: int = 3,
+        correct_pad_token_id: int = 5001,
+        num_detect_tags: int = 4,
+        num_correct_tags: int = 5002,
+        max_length: int = 128,
+        label_smoothing: float = 0.0,
+        special_tokens_fix: bool = False,
+        delete_confidence: float = 0.0,
+        additional_confidence: float = 0.2,
+        base_config: AutoConfig = None,
+        verb_form_vocab: dict = None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.__subclassconfig__(base_config)
+
+        self.model_id = model_id
+        self.label2id = label2id
+        self.id2label = id2label
+        self.detect_label2id = detect_label2id
+        self.detect_id2label = detect_id2label
+        self.detect_pad_token_id = detect_pad_token_id
+        self.correct_pad_token_id = correct_pad_token_id
+        self.num_detect_tags = num_detect_tags
+        self.num_correct_tags = num_correct_tags
+        self.classifier_dropout = classifier_dropout
+        self.max_length = max_length
+        self.label_smoothing = label_smoothing
+        self.special_tokens_fix = special_tokens_fix
+        self.delete_confidence = delete_confidence
+        self.additional_confidence = additional_confidence
+        self.verb_form_vocab = verb_form_vocab
+
+    # def save_pretrained(
+    #     self,
+    #     save_directory: Union[str, os.PathLike],
+    #     push_to_hub: bool = False,
+    #     **kwargs,
+    # ):
+    #     if os.path.isfile(save_directory):
+    #         raise AssertionError(
+    #             f"Provided path ({save_directory}) should be a directory, not a file"
+    #         )
+
+    #     os.makedirs(save_directory, exist_ok=True)
+
+    #     if self.verb_form_vocab:
+    #         verb_form_vocab_file = os.path.join(save_directory, "verb_form_vocab.json")
+    #         with open(verb_form_vocab_file, "w", encoding="utf-8") as writer:
+    #             writer.write(json.dumps(self.verb_form_vocab, indent=2, sort_keys=True) + "\n")
+
+    #     super().save_pretrained(save_directory, push_to_hub, **kwargs)
+
+
+class GectorOnnxConfig(OnnxConfig):
+    @property
+    def inputs(self) -> Mapping[str, Mapping[int, str]]:
+        dynamic_axis = {0: "batch", 1: "sequence"}
+        return OrderedDict(
+            [
+                ("input_ids", dynamic_axis),
+                ("attention_mask", dynamic_axis),
+            ]
+        )
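For illustration only (not part of this commit): a minimal sketch of instantiating the configuration class above, assuming configuration_gector.py is importable from the working directory; the backbone id and label map below are placeholder values, since the real ones ship inside this repo's config.json.

from configuration_gector import GectorConfig

# Placeholder values for illustration only.
config = GectorConfig(
    model_id="xlnet-base-cased",
    detect_label2id={"$CORRECT": 0, "$INCORRECT": 1, "<PAD>": 3},
    num_detect_tags=4,
    num_correct_tags=5002,
)
config.save_pretrained("./gector-config")  # writes a config.json with model_type "gector"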
grammar_error_correction_pipeline.py
ADDED
@@ -0,0 +1,251 @@
+import os
+import numpy as np
+
+from transformers import Pipeline, TensorType
+
+
+class GectorBase(object):
+    DELIMINTER = " "
+    START_TOKEN = "$START"
+    PAD = "<PAD>"
+    UNK = "<UNK>"
+    REPLACEMENTS = {
+        "''": '"',
+        "--": "—",
+        "`": "'",
+        "'ve": "' ve",
+    }
+
+    def decode_verb_form(self, original):
+        return self.model.config.verb_form_vocab["decode"].get(original)
+
+    def get_target_sent_by_edits(self, source_tokens, edits):
+        target_tokens = source_tokens[:]
+        shift_idx = 0
+        for edit in edits:
+            start, end, label, _ = edit
+            target_pos = start + shift_idx
+            source_token = (
+                target_tokens[target_pos]
+                if len(target_tokens) > target_pos >= 0
+                else ""
+            )
+            if label == "":
+                del target_tokens[target_pos]
+                shift_idx -= 1
+            elif start == end:
+                word = label.replace("$APPEND_", "")
+                target_tokens[target_pos:target_pos] = [word]
+                shift_idx += 1
+            elif label.startswith("$TRANSFORM_"):
+                word = self.apply_reverse_transformation(source_token, label)
+                if word is None:
+                    word = source_token
+                target_tokens[target_pos] = word
+            elif start == end - 1:
+                word = label.replace("$REPLACE_", "")
+                target_tokens[target_pos] = word
+            elif label.startswith("$MERGE_"):
+                target_tokens[target_pos + 1 : target_pos + 1] = [label]
+                shift_idx += 1
+
+        return self.replace_merge_transforms(target_tokens)
+
+    def replace_merge_transforms(self, tokens):
+        if all(not x.startswith("$MERGE_") for x in tokens):
+            return tokens
+
+        target_line = " ".join(tokens)
+        target_line = target_line.replace(" $MERGE_HYPHEN ", "-")
+        target_line = target_line.replace(" $MERGE_SPACE ", "")
+        return target_line.split()
+
+    def convert_using_case(self, token, smart_action):
+        if not smart_action.startswith("$TRANSFORM_CASE_"):
+            return token
+        if smart_action.endswith("LOWER"):
+            return token.lower()
+        elif smart_action.endswith("UPPER"):
+            return token.upper()
+        elif smart_action.endswith("CAPITAL"):
+            return token.capitalize()
+        elif smart_action.endswith("CAPITAL_1"):
+            return token[0] + token[1:].capitalize()
+        elif smart_action.endswith("UPPER_-1"):
+            return token[:-1].upper() + token[-1]
+        else:
+            return token
+
+    def convert_using_verb(self, token, smart_action):
+        key_word = "$TRANSFORM_VERB_"
+        if not smart_action.startswith(key_word):
+            raise Exception(f"Unknown action type {smart_action}")
+        encoding_part = f"{token}_{smart_action[len(key_word):]}"
+        decoded_target_word = self.decode_verb_form(encoding_part)
+        return decoded_target_word
+
+    def convert_using_split(self, token, smart_action):
+        key_word = "$TRANSFORM_SPLIT"
+        if not smart_action.startswith(key_word):
+            raise Exception(f"Unknown action type {smart_action}")
+        target_words = token.split("-")
+        return " ".join(target_words)
+
+    def convert_using_plural(self, token, smart_action):
+        if smart_action.endswith("PLURAL"):
+            return token + "s"
+        elif smart_action.endswith("SINGULAR"):
+            return token[:-1]
+        else:
+            raise Exception(f"Unknown action type {smart_action}")
+
+    def apply_reverse_transformation(self, source_token, transform):
+        if transform.startswith("$TRANSFORM"):
+            # deal with equal
+            if transform == "$KEEP":
+                return source_token
+            # deal with case
+            if transform.startswith("$TRANSFORM_CASE"):
+                return self.convert_using_case(source_token, transform)
+            # deal with verb
+            if transform.startswith("$TRANSFORM_VERB"):
+                return self.convert_using_verb(source_token, transform)
+            # deal with split
+            if transform.startswith("$TRANSFORM_SPLIT"):
+                return self.convert_using_split(source_token, transform)
+            # deal with single/plural
+            if transform.startswith("$TRANSFORM_AGREEMENT"):
+                return self.convert_using_plural(source_token, transform)
+            # raise exception if not find correct type
+            raise Exception(f"Unknown action type {transform}")
+        else:
+            return source_token
+
+    def get_token_action(self, token, index, prob, sugg_token, min_error_probability):
+        """Get lost of suggested actions for token."""
+        # cases when we don't need to do anything
+        if prob < min_error_probability or sugg_token in [self.UNK, self.PAD, "$KEEP"]:
+            return None
+
+        if (
+            sugg_token.startswith("$REPLACE_")
+            or sugg_token.startswith("$TRANSFORM_")
+            or sugg_token == "$DELETE"
+        ):
+            start_pos = index
+            end_pos = index + 1
+        elif sugg_token.startswith("$APPEND_") or sugg_token.startswith("$MERGE_"):
+            start_pos = index + 1
+            end_pos = index + 1
+
+        if sugg_token == "$DELETE":
+            sugg_token_clear = ""
+        elif sugg_token.startswith("$TRANSFORM_") or sugg_token.startswith("$MERGE_"):
+            sugg_token_clear = sugg_token[:]
+        else:
+            sugg_token_clear = sugg_token[sugg_token.index("_") + 1 :]
+
+        return start_pos - 1, end_pos - 1, sugg_token_clear, prob
+
+
+class GrammarErrorCorrectionPipeline(Pipeline, GectorBase):
+    def _sanitize_parameters(self, **kwargs):
+        preprocess_kwargs = {
+            "max_len": int(kwargs.get("max_len", 50)),
+            "lowercase_tokens": bool(kwargs.get("lowercase_tokens", False)),
+        }
+        forward_kwargs = {
+            "iterations": int(kwargs.get("iterations", 1)),
+            "max_len": int(kwargs.get("max_len", 50)),
+            "min_len": int(kwargs.get("min_len", 3)),
+            "min_error_probability": float(kwargs.get("min_error_probability", 0.0)),
+        }
+        postprocess_kwargs = {}
+        return preprocess_kwargs, forward_kwargs, postprocess_kwargs
+
+    def add_word_offsets(self, tokenized_input):
+        word_ids = tokenized_input.word_ids()
+        offsets = [i for i, x in enumerate(word_ids) if i == 0 or x != word_ids[i - 1]]
+        if self.framework == TensorType.PYTORCH:
+            import torch
+
+            offsets = torch.tensor([offsets], dtype=torch.long)
+            mask = torch.ones_like(offsets)
+        tokenized_input["word_offsets"] = offsets
+        tokenized_input["word_mask"] = mask
+        return tokenized_input
+
+    def preprocess(self, model_input, **kwargs):
+        tokens = [self.START_TOKEN] + model_input.split(self.DELIMINTER)
+        tokenized_input = self.tokenizer(
+            tokens,
+            max_length=kwargs.get("max_len"),
+            add_special_tokens=False,
+            truncation=True,
+            is_split_into_words=True,
+            return_token_type_ids=True,
+            return_tensors=self.framework,
+        )
+        tokenized_input["oriignal_tokens"] = tokens[1:]
+        tokenized_input = self.add_word_offsets(tokenized_input)
+        return tokenized_input
+
+    def _forward_iterative(self, batch, **forward_kwargs):
+        oriignal_tokens = batch.pop("oriignal_tokens")
+        model_outputs = self.model(**batch)
+
+        error_probs = model_outputs.max_error_probabilities.numpy()
+        class_probabilities_correct = model_outputs.class_probabilities_correct.numpy()
+        all_probabilities = np.amax(class_probabilities_correct, axis=-1)
+        all_idxs = np.argmax(class_probabilities_correct, axis=-1)
+
+        all_results = []
+        noop_index = self.model.config.detect_label2id.get("$CORRECT")
+        for tokens, probabilities, idxs, error_prob in zip(
+            oriignal_tokens, all_probabilities, all_idxs, error_probs
+        ):
+            length = min(len(tokens), forward_kwargs.get("max_len"))
+            edits = []
+
+            # skip whole sentences if there no errors
+            if max(idxs) == 0:
+                all_results.append(tokens)
+                continue
+
+            # skip whole sentence if probability of correctness is not high
+            if error_prob < forward_kwargs.get("min_error_probability"):
+                all_results.append(tokens)
+                continue
+            for i in range(length + 1):
+                # because of START token
+                if i == 0:
+                    token = self.START_TOKEN
+                else:
+                    token = tokens[i - 1]
+                # skip if there is no error
+                if idxs[i] == noop_index:
+                    continue
+
+                sugg_token = self.model.config.id2label[str(idxs[i])]
+                action = self.get_token_action(
+                    token,
+                    i,
+                    probabilities[i],
+                    sugg_token,
+                    forward_kwargs.get("min_error_probability"),
+                )
+                if not action:
+                    continue
+
+                edits.append(action)
+            all_results.append(self.get_target_sent_by_edits(tokens, edits))
+        return all_results
+
+    def _forward(self, model_inputs, **forward_kwargs):
+        outputs = []
+        for iter in range(forward_kwargs.get("iterations")):
+            outputs = self._forward_iterative(model_inputs, **forward_kwargs)
+        return {"output": outputs}
+
+    def postprocess(self, model_outputs):
+        return model_outputs
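For illustration only (not part of this commit): a sketch of wiring this custom pipeline up locally with transformers' PIPELINE_REGISTRY, assuming the repo is cloned and its modules are importable (modelling_gector.py uses a relative import, so they may need to sit inside a package). The task name and paths are placeholders, not something this repo is known to register.

from transformers import AutoTokenizer, pipeline
from transformers.pipelines import PIPELINE_REGISTRY

from modelling_gector import GectorForTokenClassification
from grammar_error_correction_pipeline import GrammarErrorCorrectionPipeline

# Register a task name so pipeline() can resolve the custom class.
PIPELINE_REGISTRY.register_pipeline(
    "grammar-error-correction",                     # assumed task name
    pipeline_class=GrammarErrorCorrectionPipeline,
    pt_model=GectorForTokenClassification,
)

model = GectorForTokenClassification.from_pretrained(".")  # local clone of this repo
tokenizer = AutoTokenizer.from_pretrained(".")
gec = pipeline("grammar-error-correction", model=model, tokenizer=tokenizer)
print(gec("she go to school yesterday", iterations=3, min_error_probability=0.5))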
modelling_gector.py
ADDED
@@ -0,0 +1,182 @@
+import torch
+import logging
+
+from torch import nn
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+from transformers import PreTrainedModel, AutoModel, AutoConfig
+from transformers.modeling_outputs import TokenClassifierOutput
+
+from .configuration_gector import GectorConfig
+
+logger = logging.getLogger(__name__)
+
+GECTOR_PRETRAINED_BASE_MODEL_ARCHIVE_LIST = [
+    "bert-base-cased",
+    "bert-large-cased",
+    "roberta-base",
+    "roberta-large",
+    "xlnet-base-cased",
+    "xlnet-large-cased",
+    "deberta-base-cased",
+    "deberta-large-cased",
+]
+
+
+@dataclass
+class GectorTokenClassifierOutput(TokenClassifierOutput):
+    loss: Optional[torch.FloatTensor] = None
+    logits_detect: torch.FloatTensor = None
+    class_probabilities_detect: torch.FloatTensor = None
+    logits_correct: torch.FloatTensor = None
+    class_probabilities_correct: torch.FloatTensor = None
+    max_error_probabilities: torch.FloatTensor = None
+    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+    attentions: Optional[Tuple[torch.FloatTensor]] = None
+
+
+class GectorModel(PreTrainedModel):
+    config_class = GectorConfig
+
+    def __init__(self, config):
+        super().__init__(config)
+        special_tokens_fix = config.special_tokens_fix
+
+        config = AutoConfig.from_pretrained(config.model_id)
+        self.encoder_model = AutoModel.from_config(config)
+
+        if special_tokens_fix:
+            self.encoder_model.resize_token_embeddings(config.vocab_size + 1)
+
+    def forward(self, *args, **kwargs):
+        return self.encoder_model.forward(*args, **kwargs)
+
+
+class GectorForTokenClassification(PreTrainedModel):
+    config_class = GectorConfig
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.num_detect_tags = config.num_detect_tags
+        self.num_correct_tags = config.num_correct_tags
+
+        self.text_field_embedder = GectorModel(config)
+        self.embedding_size = self.text_field_embedder.encoder_model.config.hidden_size
+
+        self.dropout = nn.Dropout(config.classifier_dropout)
+
+        self.detect_proj_layer = nn.Linear(self.embedding_size, self.num_detect_tags)
+        self.correct_proj_layer = nn.Linear(self.embedding_size, self.num_correct_tags)
+
+        self.delete_confidence = config.delete_confidence
+        self.additional_confidence = config.additional_confidence
+        self.incorrect_index = config.detect_label2id.get("$INCORRECT")
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        word_offsets: Optional[torch.LongTensor] = None,
+        word_mask: Optional[torch.LongTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.Tensor], GectorTokenClassifierOutput]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
+        """
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
+        )
+
+        outputs = self.text_field_embedder(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        sequence_output = outputs[0]
+        # If offsets are provided, the returned tensor will contain only the wordpiece
+        # embeddings at those positions, and (in particular) will contain one embedding
+        # per token. If offsets are not provided, the entire tensor of wordpiece embeddings
+        # will be returned.
+        if word_offsets is not None:
+            indices = word_offsets.unsqueeze(-1).expand(
+                -1, -1, sequence_output.size(-1)
+            )
+            sequence_output = torch.gather(sequence_output, 1, indices)
+        batch_size, sequence_length = sequence_output.size()[0:2]
+
+        logits_detect = self.detect_proj_layer(sequence_output)
+        logits_correct = self.correct_proj_layer(self.dropout(sequence_output))
+
+        class_probabilities_correct = nn.functional.softmax(
+            logits_correct, dim=-1
+        ).view([batch_size, sequence_length, self.num_correct_tags])
+        class_probabilities_detect = nn.functional.softmax(logits_detect, dim=-1).view(
+            [batch_size, sequence_length, self.num_detect_tags]
+        )
+        max_error_probabilities = torch.max(
+            class_probabilities_detect[:, :, self.incorrect_index] * word_mask,
+            dim=-1,
+        )[0]
+        probability_change = [self.additional_confidence, self.delete_confidence] + [
+            0
+        ] * (self.num_correct_tags - 2)
+        class_probabilities_correct += (
+            torch.FloatTensor(probability_change)
+            .repeat((batch_size, sequence_length, 1))
+            .to(self.device)
+        )
+
+        loss = None
+        if labels is not None:
+            detect_labels, correct_labels = torch.tensor_split(labels, 2, dim=-1)
+            # -100 is the default ignore_idx of CrossEntropyLoss
+            detect_labels[detect_labels == self.config.detect_pad_token_id] = -100
+            correct_labels[correct_labels == self.config.correct_pad_token_id] = -100
+
+            detect_loss_fct = nn.CrossEntropyLoss()
+            loss_detect = detect_loss_fct(
+                logits_detect.view(-1, self.config.num_detect_tags),
+                detect_labels.view(-1),
+            )
+
+            correct_loss_fct = nn.CrossEntropyLoss(
+                label_smoothing=self.config.label_smoothing
+            )
+            loss_correct = correct_loss_fct(
+                logits_correct.view(-1, self.config.num_correct_tags),
+                correct_labels.view(-1),
+            )
+            loss = loss_detect + loss_correct
+
+        if not return_dict:
+            output = (logits_detect, logits_correct) + outputs[2:]
+            return ((loss,) + output) if loss is not None else output
+
+        return GectorTokenClassifierOutput(
+            loss=loss,
+            logits_detect=logits_detect,
+            class_probabilities_detect=class_probabilities_detect,
+            logits_correct=logits_correct,
+            class_probabilities_correct=class_probabilities_correct,
+            max_error_probabilities=max_error_probabilities,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
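For illustration only (not part of this commit): the classes above can be exposed through the Auto* factories by registering the "gector" model_type, assuming the files are importable as a package (note the relative import of configuration_gector); the "." path is a placeholder for a local clone of this repo.

from transformers import AutoConfig, AutoModelForTokenClassification

from configuration_gector import GectorConfig
from modelling_gector import GectorForTokenClassification

# Teach the Auto* factories about the custom "gector" model_type.
AutoConfig.register("gector", GectorConfig)
AutoModelForTokenClassification.register(GectorConfig, GectorForTokenClassification)

model = AutoModelForTokenClassification.from_pretrained(".")  # local clone of this repo
print(model.config.num_detect_tags, model.config.num_correct_tags)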
pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f23f857799504b7347fad3548ad1c28ecb409f921d99a39ce8fec2ce7c3b98b7
+size 482343698
special_tokens_map.json
ADDED
@@ -0,0 +1,13 @@
+{
+  "additional_special_tokens": [
+    "<eop>",
+    "<eod>"
+  ],
+  "bos_token": "<s>",
+  "cls_token": "<cls>",
+  "eos_token": "</s>",
+  "mask_token": "<mask>",
+  "pad_token": "<pad>",
+  "sep_token": "<sep>",
+  "unk_token": "<unk>"
+}
spiece.model
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1f8c1c0bc2854d1af911a8550288c1258af5ba50277f3a5c829b98eb86fc5646
+size 798011
tokenizer.json
ADDED
The diff for this file is too large to render.
tokenizer_config.json
ADDED
@@ -0,0 +1,95 @@
+{
+  "added_tokens_decoder": {
+    "0": {
+      "content": "<unk>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "1": {
+      "content": "<s>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "2": {
+      "content": "</s>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "3": {
+      "content": "<cls>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "4": {
+      "content": "<sep>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "5": {
+      "content": "<pad>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "6": {
+      "content": "<mask>",
+      "lstrip": true,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "7": {
+      "content": "<eod>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "8": {
+      "content": "<eop>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    }
+  },
+  "additional_special_tokens": [
+    "<eop>",
+    "<eod>"
+  ],
+  "bos_token": "<s>",
+  "clean_up_tokenization_spaces": true,
+  "cls_token": "<cls>",
+  "do_basic_tokenize": false,
+  "do_lower_case": false,
+  "eos_token": "</s>",
+  "keep_accents": false,
+  "mask_token": "<mask>",
+  "model_max_length": 1000000000000000019884624838656,
+  "pad_token": "<pad>",
+  "padding_side": "right",
+  "remove_space": true,
+  "sep_token": "<sep>",
+  "tokenizer_class": "XLNetTokenizer",
+  "unk_token": "<unk>"
+}
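For illustration only (not part of this commit): the tokenizer_config above declares an XLNetTokenizer backed by spiece.model, so a local clone should load with AutoTokenizer; the "." path is a placeholder, and the $START prefix mirrors what the pipeline's preprocess step does.

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained(".")  # local clone of this repo
words = ["$START", "she", "go", "to", "school"]
enc = tok(words, is_split_into_words=True, add_special_tokens=False)
print(enc.tokens())  # subword pieces; enc.word_ids() maps them back to the input words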