andrewrreed HF staff commited on
Commit
67a58db
·
1 Parent(s): aee0221

add handler

Browse files
__pycache__/handler.cpython-310.pyc ADDED
Binary file (2.17 kB). View file
 
data/verb-form-vocab.txt ADDED
The diff for this file is too large to render. See raw diff
 
gector/__init__.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .modeling import GECToR
2
+ from .configuration import GECToRConfig
3
+ from .dataset import load_dataset, GECToRDataset
4
+ from .predict import predict, load_verb_dict
5
+ from .predict_verbose import predict_verbose
6
+ from .vocab import (
7
+ build_vocab,
8
+ load_vocab_from_config,
9
+ load_vocab_from_official
10
+ )
11
+ __all__ = [
12
+ 'GECToR',
13
+ 'GECToRConfig',
14
+ 'load_dataset',
15
+ 'GECToRDataset',
16
+ 'predict',
17
+ 'load_verb_dict',
18
+ 'predict_verbose',
19
+ 'build_vocab',
20
+ 'load_vocab_from_config',
21
+ 'load_vocab_from_official'
22
+ ]
gector/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (626 Bytes). View file
 
gector/__pycache__/configuration.cpython-310.pyc ADDED
Binary file (1.6 kB). View file
 
gector/__pycache__/dataset.cpython-310.pyc ADDED
Binary file (5.04 kB). View file
 
gector/__pycache__/modeling.cpython-310.pyc ADDED
Binary file (5.49 kB). View file
 
gector/__pycache__/predict.cpython-310.pyc ADDED
Binary file (5.35 kB). View file
 
gector/__pycache__/predict_verbose.cpython-310.pyc ADDED
Binary file (1.93 kB). View file
 
gector/__pycache__/vocab.cpython-310.pyc ADDED
Binary file (2.23 kB). View file
 
gector/configuration.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ from transformers import PretrainedConfig
4
+ class GECToRConfig(PretrainedConfig):
5
+ def __init__(
6
+ self,
7
+ model_id: str = 'bert-base-cased',
8
+ p_dropout: float=0,
9
+ label_pad_token: str='<PAD>',
10
+ label_oov_token: str='<OOV>',
11
+ d_pad_token: str='<PAD>',
12
+ keep_label: str='$KEEP',
13
+ correct_label: str='$CORRECT',
14
+ incorrect_label: str='$INCORRECT',
15
+ label_smoothing: float=0.0,
16
+ has_add_pooling_layer: bool=True,
17
+ initializer_range: float=0.02,
18
+ **kwards
19
+ ):
20
+ super().__init__(**kwards)
21
+ self.d_label2id = {
22
+ "$CORRECT": 0,
23
+ "$INCORRECT": 1,
24
+ "<PAD>": 2
25
+ }
26
+ self.d_id2label = {v: k for k, v in self.d_label2id.items()}
27
+ self.d_num_labels = len(self.d_label2id)
28
+ self.model_id = model_id
29
+ self.p_dropout = p_dropout
30
+ self.label_pad_token = label_pad_token
31
+ self.label_oov_token = label_oov_token
32
+ self.d_pad_token = d_pad_token
33
+ self.keep_label = keep_label
34
+ self.correct_label = correct_label
35
+ self.incorrect_label = incorrect_label
36
+ self.label_smoothing = label_smoothing
37
+ self.has_add_pooling_layer = has_add_pooling_layer
38
+ self.initializer_range = initializer_range
gector/dataset.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Tuple
2
+ from collections import Counter
3
+ import torch
4
+ from tqdm import tqdm
5
+ import os
6
+ from transformers import PreTrainedTokenizer
7
+
8
+ class GECToRDataset:
9
+ def __init__(
10
+ self,
11
+ srcs: List[str],
12
+ d_labels: List[List[int]]=None,
13
+ labels: List[List[int]]=None,
14
+ word_masks: List[List[int]]=None,
15
+ tokenizer: PreTrainedTokenizer=None,
16
+ max_length:int=128
17
+ ):
18
+ self.tokenizer = tokenizer
19
+ self.srcs = srcs
20
+ self.d_labels = d_labels
21
+ self.labels = labels
22
+ self.word_masks = word_masks
23
+ self.max_length = max_length
24
+ self.label2id = None
25
+ self.d_label2id = None
26
+
27
+ def __len__(self):
28
+ return len(self.srcs)
29
+
30
+ def __getitem__(self, idx):
31
+ src = self.srcs[idx]
32
+ d_labels = self.d_labels[idx]
33
+ labels = self.labels[idx]
34
+ wmask = self.word_masks[idx]
35
+ encode = self.tokenizer(
36
+ src,
37
+ return_tensors='pt',
38
+ max_length=self.max_length,
39
+ padding='max_length',
40
+ truncation=True,
41
+ is_split_into_words=True
42
+ )
43
+ return {
44
+ 'input_ids': encode['input_ids'].squeeze(),
45
+ 'attention_mask': encode['attention_mask'].squeeze(),
46
+ 'd_labels': torch.tensor(d_labels).squeeze(),
47
+ 'labels': torch.tensor(labels).squeeze(),
48
+ 'word_masks': torch.tensor(wmask).squeeze()
49
+ }
50
+
51
+ def append_vocab(self, label2id, d_label2id):
52
+ self.label2id = label2id
53
+ self.d_label2id = d_label2id
54
+ for i in range(len(self.labels)):
55
+ self.labels[i] = [self.label2id.get(l, self.label2id['<OOV>']) for l in self.labels[i]]
56
+ self.d_labels[i] = [self.d_label2id[l] for l in self.d_labels[i]]
57
+
58
+ def get_labels_freq(self, exluded_labels: List[str] = []):
59
+ assert(self.labels is not None and self.d_labels is not None)
60
+ flatten_labels = [ll for l in self.labels for ll in l if ll not in exluded_labels]
61
+ flatten_d_labels = [ll for l in self.d_labels for ll in l if ll not in exluded_labels]
62
+ return Counter(flatten_labels), Counter(flatten_d_labels)
63
+
64
+ def align_labels_to_subwords(
65
+ srcs: List[str],
66
+ word_labels: List[List[str]],
67
+ tokenizer: PreTrainedTokenizer,
68
+ batch_size: int=100000,
69
+ max_length: int=128,
70
+ keep_label: str='$KEEP',
71
+ pad_token: str='<PAD>',
72
+ correct_label: str='$CORRECT',
73
+ incorrect_label: str='$INCORRECT'
74
+ ):
75
+ itr = list(range(0, len(srcs), batch_size))
76
+ subword_labels = []
77
+ subword_d_labels = []
78
+ word_masks = []
79
+ for i in tqdm(itr):
80
+ encode = tokenizer(
81
+ srcs[i:i+batch_size],
82
+ max_length=max_length,
83
+ return_tensors='pt',
84
+ padding='max_length',
85
+ truncation=True,
86
+ is_split_into_words=True
87
+ )
88
+ for i, wlabels in enumerate(word_labels[i:i+batch_size]):
89
+ d_labels = []
90
+ labels = []
91
+ wmask = []
92
+ word_ids = encode.word_ids(i)
93
+ previous_word_idx = None
94
+ for word_idx in word_ids:
95
+ if word_idx is None:
96
+ labels.append(pad_token)
97
+ d_labels.append(pad_token)
98
+ wmask.append(0)
99
+ elif word_idx != previous_word_idx:
100
+ l = wlabels[word_idx]
101
+ labels.append(l)
102
+ wmask.append(1)
103
+ if l != keep_label:
104
+ d_labels.append(incorrect_label)
105
+ else:
106
+ d_labels.append(correct_label)
107
+ else:
108
+ labels.append(pad_token)
109
+ d_labels.append(pad_token)
110
+ wmask.append(0)
111
+ previous_word_idx = word_idx
112
+ subword_d_labels.append(d_labels)
113
+ subword_labels.append(labels)
114
+ word_masks.append(wmask)
115
+ return subword_d_labels, subword_labels, word_masks
116
+
117
+ def load_gector_format(
118
+ input_file: str,
119
+ delimeter: str='SEPL|||SEPR',
120
+ additional_delimeter: str='SEPL__SEPR'
121
+ ):
122
+ srcs = []
123
+ word_level_labels = [] # the size will be (#sents, seq_length) if not get_interactive_tags,
124
+ # (#iteration, #sents, seq_length) if get_interactive_tags
125
+ with open(input_file) as f:
126
+ for line in f:
127
+ src = [x.split(delimeter)[0] for x in line.split()]
128
+ labels = [x.split(delimeter)[1] for x in line.split()]
129
+ # Use only first tags. E.g. $REPLACE_meSEPL__SEPR$APPEND_too → $REPLACE_me
130
+ labels = [l.split(additional_delimeter)[0] for l in labels]
131
+ srcs.append(src)
132
+ word_level_labels.append(labels)
133
+ return srcs, word_level_labels
134
+
135
+ def load_dataset(
136
+ input_file: str,
137
+ tokenizer: PreTrainedTokenizer,
138
+ delimeter: str='SEPL|||SEPR',
139
+ additional_delimeter: str='SEPL__SEPR',
140
+ batch_size: int=50000, # avoid too heavy computation in the tokenization
141
+ max_length: int=128
142
+ ):
143
+ srcs, word_level_labels = load_gector_format(
144
+ input_file,
145
+ delimeter=delimeter,
146
+ additional_delimeter=additional_delimeter
147
+ )
148
+ d_labels, labels, word_masks = align_labels_to_subwords(
149
+ srcs,
150
+ word_level_labels,
151
+ tokenizer=tokenizer,
152
+ batch_size=batch_size,
153
+ max_length=max_length
154
+ )
155
+ return GECToRDataset(
156
+ srcs=srcs,
157
+ d_labels=d_labels,
158
+ labels=labels,
159
+ word_masks=word_masks,
160
+ tokenizer=tokenizer,
161
+ max_length=max_length
162
+ )
163
+
164
+
gector/modeling.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoModel, AutoTokenizer, AutoConfig, PreTrainedModel
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import torch.nn as nn
5
+ from torch.nn import CrossEntropyLoss
6
+ from dataclasses import dataclass
7
+ from .configuration import GECToRConfig
8
+ from typing import List, Union, Optional, Tuple
9
+ import os
10
+ import json
11
+ from huggingface_hub import snapshot_download, ModelCard
12
+
13
+ @dataclass
14
+ class GECToROutput:
15
+ loss: torch.Tensor = None
16
+ loss_d: torch.Tensor = None
17
+ loss_labels: torch.Tensor = None
18
+ logits_d: torch.Tensor = None
19
+ logits_labels: torch.Tensor = None
20
+ accuracy: torch.Tensor = None
21
+ accuracy_d: torch.Tensor = None
22
+
23
+ @dataclass
24
+ class GECToRPredictionOutput:
25
+ probability_labels: torch.Tensor = None
26
+ probability_d: torch.Tensor = None
27
+ pred_labels: List[List[str]] = None
28
+ pred_label_ids: torch.Tensor = None
29
+ max_error_probability: torch.Tensor = None
30
+
31
+ class GECToR(PreTrainedModel):
32
+ config_class = GECToRConfig
33
+ def __init__(
34
+ self,
35
+ config: GECToRConfig
36
+ ):
37
+ super().__init__(config)
38
+ self.config = config
39
+ self.tokenizer = AutoTokenizer.from_pretrained(
40
+ self.config.model_id
41
+ )
42
+ if self.config.has_add_pooling_layer:
43
+ self.bert = AutoModel.from_pretrained(
44
+ self.config.model_id,
45
+ add_pooling_layer=False
46
+ )
47
+ else:
48
+ self.bert = AutoModel.from_pretrained(
49
+ self.config.model_id
50
+ )
51
+ # +1 is for $START token
52
+ self.bert.resize_token_embeddings(self.bert.config.vocab_size + 1)
53
+ self.label_proj_layer = nn.Linear(
54
+ self.bert.config.hidden_size,
55
+ self.config.num_labels - 1
56
+ ) # -1 is for <PAD>
57
+ self.d_proj_layer = nn.Linear(
58
+ self.bert.config.hidden_size,
59
+ self.config.d_num_labels - 1
60
+ )
61
+ self.dropout = nn.Dropout(self.config.p_dropout)
62
+ self.loss_fn = CrossEntropyLoss(
63
+ label_smoothing=self.config.label_smoothing
64
+ )
65
+
66
+ self.post_init()
67
+ self.tune_bert(False)
68
+
69
+ def init_weight(self) -> None:
70
+ self._init_weights(self.label_proj_layer)
71
+ self._init_weights(self.d_proj_layer)
72
+
73
+ def _init_weights(self, module) -> None:
74
+ """Initialize the weights"""
75
+ if isinstance(module, nn.Linear):
76
+ # Slightly different from the TF version which uses truncated_normal for initialization
77
+ # cf https://github.com/pytorch/pytorch/pull/5617
78
+ module.weight.data.normal_(
79
+ mean=0.0,
80
+ std=self.config.initializer_range
81
+ )
82
+ if module.bias is not None:
83
+ module.bias.data.zero_()
84
+ return
85
+
86
+ def tune_bert(self, tune=True):
87
+ # If tune=False, only classifier layers will be tuned.
88
+ for param in self.bert.parameters():
89
+ param.requires_grad = tune
90
+ return
91
+
92
+ def forward(
93
+ self,
94
+ input_ids: Optional[torch.Tensor] = None,
95
+ attention_mask: Optional[torch.Tensor] = None,
96
+ token_type_ids: Optional[torch.Tensor] = None,
97
+ position_ids: Optional[torch.Tensor] = None,
98
+ inputs_embeds: Optional[torch.Tensor] = None,
99
+ labels: Optional[torch.Tensor] = None,
100
+ d_labels: Optional[torch.Tensor] = None,
101
+ output_attentions: Optional[bool] = None,
102
+ output_hidden_states: Optional[bool] = None,
103
+ return_dict: Optional[bool] = None,
104
+ word_masks: Optional[torch.Tensor] = None,
105
+ ) -> GECToROutput:
106
+ bert_logits = self.bert(
107
+ input_ids,
108
+ attention_mask=attention_mask,
109
+ token_type_ids=token_type_ids,
110
+ position_ids=position_ids,
111
+ inputs_embeds=inputs_embeds,
112
+ output_attentions=output_attentions,
113
+ output_hidden_states=output_hidden_states,
114
+ return_dict=return_dict,
115
+ ).last_hidden_state
116
+ logits_d = self.d_proj_layer(bert_logits)
117
+ logits_labels = self.label_proj_layer(self.dropout(bert_logits))
118
+ loss_d, loss_labels, loss = None, None, None
119
+ accuracy, accuracy_d = None, None
120
+ if d_labels is not None and labels is not None:
121
+ pad_id = self.config.label2id[self.config.label_pad_token]
122
+ # -100 is the default ignore_idx of CrossEntropyLoss
123
+ labels[labels == pad_id] = -100
124
+ d_labels[labels == -100] = -100
125
+ loss_d = self.loss_fn(
126
+ logits_d.view(-1, self.config.d_num_labels - 1), # -1 for <PAD>
127
+ d_labels.view(-1)
128
+ )
129
+ loss_labels = self.loss_fn(
130
+ logits_labels.view(-1, self.config.num_labels - 1),
131
+ labels.view(-1)
132
+ )
133
+ loss = loss_d + loss_labels
134
+
135
+ pred_labels = torch.argmax(logits_labels, dim=-1)
136
+ accuracy = torch.sum(
137
+ (labels == pred_labels) * word_masks
138
+ ) / torch.sum(word_masks)
139
+ pred_d = torch.argmax(logits_d, dim=-1)
140
+ accuracy_d = torch.sum(
141
+ (d_labels == pred_d) * word_masks
142
+ ) / torch.sum(word_masks)
143
+
144
+ return GECToROutput(
145
+ loss=loss,
146
+ loss_d=loss_d,
147
+ loss_labels=loss_labels,
148
+ logits_d=logits_d,
149
+ logits_labels=logits_labels,
150
+ accuracy=accuracy,
151
+ accuracy_d=accuracy_d
152
+ )
153
+
154
+ def predict(
155
+ self,
156
+ input_ids: torch.Tensor,
157
+ attention_mask: torch.Tensor,
158
+ word_masks: torch.Tensor,
159
+ keep_confidence: float=0,
160
+ min_error_prob: float=0
161
+ ):
162
+ with torch.no_grad():
163
+ outputs = self.forward(
164
+ input_ids,
165
+ attention_mask
166
+ )
167
+ probability_labels = F.softmax(outputs.logits_labels, dim=-1)
168
+ probability_d = F.softmax(outputs.logits_d, dim=-1)
169
+
170
+ # Get actual labels considering inference parameters.
171
+ keep_index = self.config.label2id[self.config.keep_label]
172
+ probability_labels[:, :, keep_index] += keep_confidence
173
+ incor_idx = self.config.d_label2id[self.config.incorrect_label]
174
+ probability_d = probability_d[:, :, incor_idx]
175
+ max_error_probability = torch.max(probability_d * word_masks, dim=-1)[0]
176
+ probability_labels[max_error_probability < min_error_prob, :, keep_index] \
177
+ = float('inf')
178
+ pred_label_ids = torch.argmax(probability_labels, dim=-1)
179
+
180
+ def convert_ids_to_labels(ids, id2label):
181
+ labels = []
182
+ for id in ids.tolist():
183
+ labels.append(id2label[id])
184
+ return labels
185
+
186
+ pred_labels = []
187
+ for ids in pred_label_ids:
188
+ labels = convert_ids_to_labels(
189
+ ids,
190
+ self.config.id2label
191
+ )
192
+ pred_labels.append(labels)
193
+
194
+ return GECToRPredictionOutput(
195
+ probability_labels=probability_labels,
196
+ probability_d=probability_d,
197
+ pred_labels=pred_labels,
198
+ pred_label_ids=pred_label_ids,
199
+ max_error_probability=max_error_probability
200
+ )
gector/predict.py ADDED
@@ -0,0 +1,232 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import os
3
+ from tqdm import tqdm
4
+ from .modeling import GECToR
5
+ from transformers import PreTrainedTokenizer
6
+ from typing import List
7
+
8
+ def load_verb_dict(verb_file: str):
9
+ path_to_dict = os.path.join(verb_file)
10
+ encode, decode = {}, {}
11
+ with open(path_to_dict, encoding="utf-8") as f:
12
+ for line in f:
13
+ words, tags = line.split(":")
14
+ word1, word2 = words.split("_")
15
+ tag1, tag2 = tags.split("_")
16
+ decode_key = f"{word1}_{tag1}_{tag2.strip()}"
17
+ if decode_key not in decode:
18
+ encode[words] = tags
19
+ decode[decode_key] = word2
20
+ return encode, decode
21
+
22
+ def edit_src_by_tags(
23
+ srcs: List[List[str]],
24
+ pred_labels: List[List[str]],
25
+ encode: dict,
26
+ decode: dict
27
+ ) -> List[str]:
28
+ edited_srcs = []
29
+ for tokens, labels in zip(srcs, pred_labels):
30
+ edited_tokens = []
31
+ for t, l, in zip(tokens, labels):
32
+ n_token = process_token(t, l, encode, decode)
33
+ if n_token == None:
34
+ n_token = t
35
+ edited_tokens += n_token.split(' ')
36
+ if len(tokens) > len(labels):
37
+ omitted_tokens = tokens[len(labels):]
38
+ edited_tokens += omitted_tokens
39
+ temp_str = ' '.join(edited_tokens) \
40
+ .replace(' $MERGE_HYPHEN ', '-') \
41
+ .replace(' $MERGE_SPACE ', '') \
42
+ .replace(' $DELETE', '') \
43
+ .replace('$DELETE ', '')
44
+ edited_srcs.append(temp_str.split(' '))
45
+ return edited_srcs
46
+
47
+ def process_token(
48
+ token: str,
49
+ label: str,
50
+ encode: dict,
51
+ decode: dict
52
+ ) -> str:
53
+ if '$APPEND_' in label:
54
+ return token + ' ' + label.replace('$APPEND_', '')
55
+ elif token == '$START':
56
+ # [unused1] token cannot be replaced with another token and cannot be deleted.
57
+ return token
58
+ elif label in ['<PAD>', '<OOV>', '$KEEP']:
59
+ return token
60
+ elif '$APPEND_' in label:
61
+ return token + ' ' + label.replace('$APPEND_', '')
62
+ elif '$TRANSFORM_' in label:
63
+ return g_transform_processer(token, label, encode, decode)
64
+ elif '$REPLACE_' in label:
65
+ return label.replace('$REPLACE_', '')
66
+ elif label == '$DELETE':
67
+ return label
68
+ elif '$MERGE_' in label:
69
+ return token + ' ' + label
70
+ else:
71
+ return token
72
+
73
+ def g_transform_processer(
74
+ token: str,
75
+ label: str,
76
+ encode: dict,
77
+ decode: dict
78
+ ) -> str:
79
+ # Case related
80
+ if label == '$TRANSFORM_CASE_LOWER':
81
+ return token.lower()
82
+ elif label == '$TRANSFORM_CASE_UPPER':
83
+ return token.upper()
84
+ elif label == '$TRANSFORM_CASE_CAPITAL':
85
+ return token.capitalize()
86
+ elif label == '$TRANSFORM_CASE_CAPITAL_1':
87
+ if len(token) <= 1:
88
+ return token
89
+ return token[0] + token[1:].capitalize()
90
+ elif label == '$TRANSFORM_AGREEMENT_PLURAL':
91
+ return token + 's'
92
+ elif label == '$TRANSFORM_AGREEMENT_SINGULAR':
93
+ return token[:-1]
94
+ elif label == '$TRANSFORM_SPLIT_HYPHEN':
95
+ return ' '.join(token.split('-'))
96
+ else:
97
+ encoding_part = f"{token}_{label[len('$TRANSFORM_VERB_'):]}"
98
+ decoded_target_word = decode.get(encoding_part)
99
+ return decoded_target_word
100
+
101
+ def get_word_masks_from_word_ids(
102
+ word_ids: List[List[int]],
103
+ n: int
104
+ ):
105
+ word_masks = []
106
+ for i in range(n):
107
+ previous_id = 0
108
+ mask = []
109
+ for _id in word_ids(i):
110
+ if _id is None:
111
+ mask.append(0)
112
+ elif previous_id != _id:
113
+ mask.append(1)
114
+ else:
115
+ mask.append(0)
116
+ previous_id = _id
117
+ word_masks.append(mask)
118
+ return word_masks
119
+
120
+ def _predict(
121
+ model: GECToR,
122
+ tokenizer: PreTrainedTokenizer,
123
+ srcs: List[str],
124
+ keep_confidence: float=0,
125
+ min_error_prob: float=0,
126
+ batch_size: int=128
127
+ ):
128
+ itr = list(range(0, len(srcs), batch_size))
129
+ pred_labels = []
130
+ no_corrections = []
131
+ for i in tqdm(itr):
132
+ batch = tokenizer(
133
+ srcs[i:i+batch_size],
134
+ return_tensors='pt',
135
+ max_length=model.config.max_length,
136
+ padding='max_length',
137
+ truncation=True,
138
+ is_split_into_words=True
139
+ )
140
+ batch['word_masks'] = torch.tensor(
141
+ get_word_masks_from_word_ids(
142
+ batch.word_ids,
143
+ batch['input_ids'].size(0)
144
+ )
145
+ )
146
+ word_ids = batch.word_ids
147
+ if torch.cuda.is_available():
148
+ batch = {k:v.cuda() for k,v in batch.items()}
149
+ outputs = model.predict(
150
+ batch['input_ids'],
151
+ batch['attention_mask'],
152
+ batch['word_masks'],
153
+ keep_confidence,
154
+ min_error_prob
155
+ )
156
+ # Align subword-level label to word-level label
157
+ for i in range(len(outputs.pred_labels)):
158
+ no_correct = True
159
+ labels = []
160
+ previous_word_idx = None
161
+ for j, idx in enumerate(word_ids(i)):
162
+ if idx is None:
163
+ continue
164
+ if idx != previous_word_idx:
165
+ labels.append(outputs.pred_labels[i][j])
166
+ if outputs.pred_label_ids[i][j] > 2:
167
+ no_correct = False
168
+ previous_word_idx = idx
169
+ # print(no_correct, labels)
170
+ pred_labels.append(labels)
171
+ no_corrections.append(no_correct)
172
+ # print(pred_labels)
173
+ return pred_labels, no_corrections
174
+
175
+ def predict(
176
+ model: GECToR,
177
+ tokenizer: PreTrainedTokenizer,
178
+ srcs: List[str],
179
+ encode: dict,
180
+ decode: dict,
181
+ keep_confidence: float=0,
182
+ min_error_prob: float=0,
183
+ batch_size: int=128,
184
+ n_iteration: int=5
185
+ ) -> List[str]:
186
+ srcs = [['$START'] + src.split(' ') for src in srcs]
187
+ final_edited_sents = ['-1'] * len(srcs)
188
+ to_be_processed = srcs
189
+ original_sent_idx = list(range(0, len(srcs)))
190
+ for itr in range(n_iteration):
191
+ print(f'Iteratoin {itr}. the number of to_be_processed: {len(to_be_processed)}')
192
+ pred_labels, no_corrections = _predict(
193
+ model,
194
+ tokenizer,
195
+ to_be_processed,
196
+ keep_confidence,
197
+ min_error_prob,
198
+ batch_size
199
+ )
200
+ current_srcs = []
201
+ current_pred_labels = []
202
+ current_orig_idx = []
203
+ for i, yes in enumerate(no_corrections):
204
+ if yes: # there's no corrections?
205
+ final_edited_sents[original_sent_idx[i]] = ' '.join(to_be_processed[i]).replace('$START ', '')
206
+ else:
207
+ current_srcs.append(to_be_processed[i])
208
+ current_pred_labels.append(pred_labels[i])
209
+ current_orig_idx.append(original_sent_idx[i])
210
+ if current_srcs == []:
211
+ # Correcting for all sentences is completed.
212
+ break
213
+ # if itr > 2:
214
+ # for l in current_pred_labels:
215
+ # print(l)
216
+ edited_srcs = edit_src_by_tags(
217
+ current_srcs,
218
+ current_pred_labels,
219
+ encode,
220
+ decode
221
+ )
222
+ to_be_processed = edited_srcs
223
+ original_sent_idx = current_orig_idx
224
+
225
+ # print(f'=== Iteration {itr} ===')
226
+ # print('\n'.join(final_edited_sents))
227
+ # print(to_be_processed)
228
+ # print(have_corrections)
229
+ for i in range(len(to_be_processed)):
230
+ final_edited_sents[original_sent_idx[i]] = ' '.join(to_be_processed[i]).replace('$START ', '')
231
+ assert('-1' not in final_edited_sents)
232
+ return final_edited_sents
gector/predict_verbose.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import os
3
+ from tqdm import tqdm
4
+ from .modeling import GECToR
5
+ from transformers import PreTrainedTokenizer
6
+ from typing import List, Dict
7
+ from .predict import (
8
+ edit_src_by_tags,
9
+ _predict
10
+ )
11
+
12
+ def predict_verbose(
13
+ model: GECToR,
14
+ tokenizer: PreTrainedTokenizer,
15
+ srcs: List[str],
16
+ encode: dict,
17
+ decode: dict,
18
+ keep_confidence: float=0,
19
+ min_error_prob: float=0,
20
+ batch_size: int=128,
21
+ n_iteration: int=5
22
+ ) -> List[str]:
23
+ srcs = [['$START'] + src.split(' ') for src in srcs]
24
+ final_edited_sents = ['-1'] * len(srcs)
25
+ to_be_processed = srcs
26
+ original_sent_idx = list(range(0, len(srcs)))
27
+ iteration_log: List[List[Dict]] = [] # [send_id][iteration_id]['src' or 'tags']
28
+ iteration_log = []
29
+ # Initialize iteration logs.
30
+ for i, src in enumerate(srcs):
31
+ iteration_log.append([{
32
+ 'src': src,
33
+ 'tag': None
34
+ }])
35
+ for itr in range(n_iteration):
36
+ print(f'Iteratoin {itr}. the number of to_be_processed: {len(to_be_processed)}')
37
+ pred_labels, no_corrections = _predict(
38
+ model,
39
+ tokenizer,
40
+ to_be_processed,
41
+ keep_confidence,
42
+ min_error_prob,
43
+ batch_size
44
+ )
45
+ current_srcs = []
46
+ current_pred_labels = []
47
+ current_orig_idx = []
48
+ for i, yes in enumerate(no_corrections):
49
+ if yes: # there's no corrections?
50
+ final_edited_sents[original_sent_idx[i]] = ' '.join(to_be_processed[i]).replace('$START ', '')
51
+ else:
52
+ current_srcs.append(to_be_processed[i])
53
+ current_pred_labels.append(pred_labels[i])
54
+ current_orig_idx.append(original_sent_idx[i])
55
+ if current_srcs == []:
56
+ # Correcting for all sentences is completed.
57
+ break
58
+ edited_srcs = edit_src_by_tags(
59
+ current_srcs,
60
+ current_pred_labels,
61
+ encode,
62
+ decode
63
+ )
64
+ # Register the information during iteration.
65
+ # edited_src will be the src of the next iteration.
66
+ for i, orig_id in enumerate(current_orig_idx):
67
+ iteration_log[orig_id][itr]['tag'] = current_pred_labels[i]
68
+ iteration_log[orig_id].append({
69
+ 'src': edited_srcs[i],
70
+ 'tag': None
71
+ })
72
+
73
+ to_be_processed = edited_srcs
74
+ original_sent_idx = current_orig_idx
75
+
76
+ # print(f'=== Iteration {itr} ===')
77
+ # print('\n'.join(final_edited_sents))
78
+ # print(to_be_processed)
79
+ # print(have_corrections)
80
+ for i in range(len(to_be_processed)):
81
+ final_edited_sents[original_sent_idx[i]] = ' '.join(to_be_processed[i]).replace('$START ', '')
82
+ assert('-1' not in final_edited_sents)
83
+ return final_edited_sents, iteration_log
gector/vocab.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .configuration import GECToRConfig
2
+ from .dataset import GECToRDataset
3
+ import os
4
+
5
+ def build_vocab(
6
+ train_dataset: GECToRDataset,
7
+ n_max_labels: int=5000,
8
+ n_max_d_labels: int=2
9
+ ):
10
+ label2id = {'<OOV>':0, '$KEEP':1}
11
+ d_label2id = {'$CORRECT':0, '$INCORRECT':1, '<PAD>':2}
12
+ freq_labels, _ = train_dataset.get_labels_freq(
13
+ exluded_labels=['<PAD>'] + list(label2id.keys())
14
+ )
15
+
16
+ def get_high_freq(freq: dict, n_max: int):
17
+ descending_freq = sorted(
18
+ freq.items(), key=lambda x:x[1], reverse=True
19
+ )
20
+ high_freq = [x[0] for x in descending_freq][:n_max]
21
+ if len(high_freq) < n_max:
22
+ print(f'Warning: the size of the vocablary: {len(high_freq)} is less than n_max: {n_max}.')
23
+ return high_freq
24
+
25
+ high_freq_labels = get_high_freq(freq_labels, n_max_labels-2)
26
+ for i, x in enumerate(high_freq_labels):
27
+ label2id[x] = i + 2
28
+ label2id['<PAD>'] = len(label2id)
29
+ return label2id, d_label2id
30
+
31
+ def load_vocab_from_config(config_file: str):
32
+ config = GECToRConfig.from_pretrained(config_file, not_dir=True)
33
+ return config.label2id, config.d_label2id
34
+
35
+ def load_vocab_from_official(dir):
36
+ vocab_path = os.path.join(dir, 'labels.txt')
37
+ vocab = open(vocab_path).read().replace('@@PADDING@@', '').replace('@@UNKNOWN@@', '').rstrip().split('\n')
38
+ # vocab_d = open(dir + 'd_tags.txt').read().rstrip().replace('@@PADDING@@', '<PAD>').replace('@@UNKNOWN@@', '<OOV>').split('\n')
39
+ label2id = {'<OOV>':0, '$KEEP':1}
40
+ d_label2id = {'$CORRECT':0, '$INCORRECT':1, '<PAD>':2}
41
+ idx = len(label2id)
42
+ for v in vocab:
43
+ if v not in label2id:
44
+ label2id[v] = idx
45
+ idx += 1
46
+ label2id['<PAD>'] = idx
47
+ return label2id, d_label2id
48
+
handler.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
2
+ from transformers import AutoTokenizer
3
+ from gector import GECToR, predict, load_verb_dict
4
+
5
+
6
+ class EndpointHandler:
7
+ def __init__(self, path=""):
8
+ self.model = GECToR.from_pretrained(path)
9
+ self.tokenizer = AutoTokenizer.from_pretrained(path)
10
+ self.encode, self.decode = load_verb_dict("data/verb-form-vocab.txt")
11
+
12
+ def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
13
+ """
14
+ Process the input data and return the predicted results.
15
+
16
+ Args:
17
+ data (Dict[str, Any]): The input data dictionary containing the following keys:
18
+ - "inputs" (List[str]): A list of input strings to be processed.
19
+ - "n_iterations" (int, optional): The number of iterations for prediction. Defaults to 5.
20
+ - "batch_size" (int, optional): The batch size for prediction. Defaults to 2.
21
+ - "keep_confidence" (float, optional): The confidence threshold for keeping predictions. Defaults to 0.0.
22
+ - "min_error_prob" (float, optional): The minimum error probability for keeping predictions. Defaults to 0.0.
23
+
24
+ Returns:
25
+ List[Dict[str, Any]]: A list of dictionaries containing the predicted results for each input string.
26
+ """
27
+ srcs = data["inputs"]
28
+
29
+ # Extract optional parameters from data, with defaults
30
+ n_iterations = data.get("n_iterations", 5)
31
+ batch_size = data.get("batch_size", 2)
32
+ keep_confidence = data.get("keep_confidence", 0.0)
33
+ min_error_prob = data.get("min_error_prob", 0.0)
34
+
35
+ return predict(
36
+ model=self.model,
37
+ tokenizer=self.tokenizer,
38
+ srcs=srcs,
39
+ encode=self.encode,
40
+ decode=self.decode,
41
+ keep_confidence=keep_confidence,
42
+ min_error_prob=min_error_prob,
43
+ n_iteration=n_iterations,
44
+ batch_size=batch_size,
45
+ )
requirements.txt ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ accelerate==0.27.0
2
+ certifi==2024.2.2
3
+ charset-normalizer==3.3.2
4
+ filelock==3.13.1
5
+ fsspec==2024.2.0
6
+ huggingface-hub==0.20.3
7
+ idna==3.6
8
+ Jinja2==3.1.3
9
+ Levenshtein==0.24.0
10
+ MarkupSafe==2.1.5
11
+ mpmath==1.3.0
12
+ networkx==3.2.1
13
+ numpy==1.26.4
14
+ packaging==23.2
15
+ psutil==5.9.8
16
+ PyYAML==6.0.1
17
+ rapidfuzz==3.6.1
18
+ regex==2023.12.25
19
+ requests==2.31.0
20
+ safetensors==0.4.2
21
+ sympy==1.12
22
+ tokenizers==0.15.1
23
+ torch==2.2.0
24
+ tqdm==4.66.2
25
+ transformers==4.37.2
26
+ typing_extensions==4.9.0
27
+ urllib3==2.2.0