Shaltiel committed on
Commit
b003ea2
•
1 Parent(s): 7ee2fd1

Upload 2 files

Browse files
Files changed (2) hide show
  1. BertForPrefixMarking.py +220 -0
  2. config.json +3 -0
BertForPrefixMarking.py ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers.utils import ModelOutput
2
+ import torch
3
+ from torch import nn
4
+ from typing import List, Tuple, Optional
5
+ from dataclasses import dataclass
6
+ from transformers import BertPreTrainedModel, BertModel, BertTokenizerFast
7
+
8
# The prefix classes: a class id is the index into this list, and each entry
# holds every surface form that belongs to that class.
POSSIBLE_PREFIX_CLASSES = [ ['לכש', 'כש', 'מש', 'בש', 'לש'], ['מ'], ['ש'], ['ה'], ['ו'], ['כ'], ['ל'], ['ב'] ]
# Reverse lookup: each individual prefix string -> its class id.
PREFIXES_TO_CLASS = {
    prefix: class_id
    for class_id, class_prefixes in enumerate(POSSIBLE_PREFIX_CLASSES)
    for prefix in class_prefixes
}
# Every prefix string, longest first, so that scanning this list front to back
# finds the longest prefix a word starts with.
ALL_PREFIX_ITEMS = sorted(PREFIXES_TO_CLASS, key=len, reverse=True)
# Number of prefix classes (also used as the id of the extra "NONE" embedding row).
TOTAL_POSSIBLE_PREFIX_CLASSES = len(POSSIBLE_PREFIX_CLASSES)
16
+
17
def get_prefixes_from_str(s, greedy=False):
    """Yield, left to right, the candidate prefixes found at the start of ``s``.

    The string is consumed from the front: at each step the longest entry of
    ``ALL_PREFIX_ITEMS`` that ``s`` starts with is yielded and stripped.
    Scanning stops at the first character that is not a key of
    ``PREFIXES_TO_CLASS`` (or when the string is exhausted).

    Args:
        s: the word to decompose.
        greedy: when False (the default), a multi-letter match is immediately
            followed by its own first letter as an alternative candidate.

    Yields:
        Prefix strings, in the order they were matched.
    """
    remaining = s
    while remaining and remaining[0] in PREFIXES_TO_CLASS:
        # ALL_PREFIX_ITEMS is ordered longest-first, so the first hit is the longest match
        longest = None
        for candidate in ALL_PREFIX_ITEMS:
            if remaining.startswith(candidate):
                longest = candidate
                break
        if longest is None:
            return

        yield longest
        # A multi-letter prefix might really be just its first letter, so offer
        # that letter as a valid candidate too. We still advance by the full
        # length of the longer match: if the next two/three letters form a
        # prefix, they have to be the longest one.
        if len(longest) > 1 and not greedy:
            yield longest[0]
        remaining = remaining[len(longest):]
32
+
33
def get_prefix_classes_from_str(s, greedy=False):
    """Yield the class id of every candidate prefix found at the start of ``s``.

    This is ``get_prefixes_from_str(s, greedy)`` with each yielded prefix
    translated through ``PREFIXES_TO_CLASS``; the same class id may be yielded
    more than once.
    """
    yield from (
        PREFIXES_TO_CLASS[prefix] for prefix in get_prefixes_from_str(s, greedy)
    )
36
+
37
@dataclass
class PrefixesClassifiersOutput(ModelOutput):
    """Return value of ``BertForPrefixMarking.forward`` when ``return_dict`` is set."""

    # Per-token, per-prefix-class binary scores; built in ``forward`` as
    # batch x seq_len x TOTAL_POSSIBLE_PREFIX_CLASSES x 2.
    logits: Optional[torch.FloatTensor] = None
    # Passed through unchanged from the underlying BERT encoder outputs.
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    # Passed through unchanged from the underlying BERT encoder outputs.
    attentions: Optional[Tuple[torch.FloatTensor]] = None
42
+
43
class BertForPrefixMarking(BertPreTrainedModel):
    """BERT encoder with one binary classification head per prefix class.

    For every token the model receives, in addition to the usual BERT inputs,
    the set of prefix classes that are textually possible for the word
    (``prefix_class_id_options``). Those options are embedded, concatenated to
    the encoder output, passed through a transform layer, and then scored by
    ``TOTAL_POSSIBLE_PREFIX_CLASSES`` independent 2-way classifiers.
    """

    def __init__(self, config):
        super().__init__(config)

        self.bert = BertModel(config, add_pooling_layer=False)
        self.dropout = nn.Dropout(0.1)

        # an embedding table containing an embedding for each prefix class + 1 for NONE
        # we will concatenate either the embedding/NONE for each class - and we want the
        # concatenated size to be (roughly) the hidden_size
        prefix_class_embed = config.hidden_size // TOTAL_POSSIBLE_PREFIX_CLASSES
        self.prefix_class_embeddings = nn.Embedding(TOTAL_POSSIBLE_PREFIX_CLASSES + 1, prefix_class_embed)

        # one layer for transformation, apply an activation, then another N classifiers for each prefix class
        self.transform = nn.Linear(config.hidden_size + prefix_class_embed * TOTAL_POSSIBLE_PREFIX_CLASSES, config.hidden_size)
        self.activation = nn.Tanh()
        self.classifiers = nn.ModuleList([nn.Linear(config.hidden_size, 2) for _ in range(TOTAL_POSSIBLE_PREFIX_CLASSES)])

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        prefix_class_id_options: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        r"""
        Score every prefix class for every token.

        prefix_class_id_options (`torch.LongTensor` of shape `(batch_size, sequence_length, TOTAL_POSSIBLE_PREFIX_CLASSES)`):
            Required. For each token and each prefix class, the class index itself when that prefix is possible
            for the word, or `TOTAL_POSSIBLE_PREFIX_CLASSES` (the NONE embedding row) when it is not - this is
            the tensor produced by `encode_sentences_for_bert_for_prefix_marking`.

        All remaining arguments are forwarded unchanged to the underlying `BertModel`.

        Returns a `PrefixesClassifiersOutput` when `return_dict` resolves to true, otherwise the tuple
        `(logits,) + bert_outputs[2:]`. `logits` has shape
        `(batch_size, sequence_length, TOTAL_POSSIBLE_PREFIX_CLASSES, 2)`.

        Raises `ValueError` if `prefix_class_id_options` is not provided.
        """
        if prefix_class_id_options is None:
            # without this the embedding lookup below fails with an opaque error
            raise ValueError('prefix_class_id_options must be provided')

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        bert_outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = bert_outputs[0]
        sequence_output = self.dropout(sequence_output)

        # encode the prefix_class_id_options
        # If input_ids is batch x seq_len
        # Then sequence_output is batch x seq_len x hidden_dim
        # So prefix_class_id_options is batch x seq_len x TOTAL_POSSIBLE_PREFIX_CLASSES
        # Looking up the embeddings should give us batch x seq_len x TOTAL_POSSIBLE_PREFIX_CLASSES x hidden_dim / N
        possible_class_embed = self.prefix_class_embeddings(prefix_class_id_options)
        # then flatten the final two dimensions - now we have batch x seq_len x hidden_dim_2
        possible_class_embed = possible_class_embed.reshape(possible_class_embed.shape[:-2] + (-1,))

        # concatenate the new class embed into the sequence output before the transform
        pre_transform_output = torch.cat((sequence_output, possible_class_embed), dim=-1)  # batch x seq_len x (hidden_dim + hidden_dim_2)
        pre_logits_output = self.activation(self.transform(pre_transform_output))  # batch x seq_len x hidden_dim
        # run each of the classifiers on the transformed output and stack them
        # into batch x seq_len x TOTAL_POSSIBLE_PREFIX_CLASSES x 2
        logits = torch.stack([classifier(pre_logits_output) for classifier in self.classifiers], dim=-2)

        if not return_dict:
            return (logits,) + bert_outputs[2:]

        return PrefixesClassifiersOutput(
            logits=logits,
            hidden_states=bert_outputs.hidden_states,
            attentions=bert_outputs.attentions,
        )

    def predict(self, sentences: List[str], tokenizer: BertTokenizerFast, padding='longest'):
        """Segment the prefixes off every word of every sentence.

        Args:
            sentences: the raw sentences to process.
            tokenizer: the tokenizer used to encode ``sentences``.
            padding: padding strategy handed to the tokenizer.

        Returns:
            One list per sentence, holding one entry per word: ``[word]`` when no
            prefix was predicted, otherwise ``[prefix, rest_of_word]``. Only padding
            tokens and ``##`` continuation pieces are skipped, so any other token
            the tokenizer emits appears as a word of its own.
        """
        # step 1: encode the sentences using the tokenizer, and get the input tensors + prefix id tensors
        inputs = encode_sentences_for_bert_for_prefix_marking(tokenizer, sentences, padding)
        # keep the original ids for decoding the tokens; the model gets copies on its own device
        input_ids = inputs['input_ids']
        model_inputs = {name: tensor.to(self.device) for name, tensor in inputs.items()}

        # run through bert
        logits = self.forward(**model_inputs, return_dict=True).logits

        # extract the predictions by argmaxing the final dimension (batch x sequence x prefixes x prediction)
        logit_preds = torch.argmax(logits, dim=-1)

        ret = []

        for sent_idx, sent_ids in enumerate(input_ids):
            tokens = tokenizer.convert_ids_to_tokens(sent_ids)
            ret.append([])
            for tok_idx, token in enumerate(tokens):
                # padding carries no word, and '##' pieces were already merged into their head token
                if token == tokenizer.pad_token: continue
                if token.startswith('##'): continue

                # merge in all the following word pieces so that we look at the whole word
                next_tok_idx = tok_idx + 1
                while next_tok_idx < len(tokens) and tokens[next_tok_idx].startswith('##'):
                    token += tokens[next_tok_idx][2:]
                    # advance to the next piece - without this the loop never terminates
                    next_tok_idx += 1

                prefix_len = get_predicted_prefix_len_from_logits(token, logit_preds[sent_idx, tok_idx])

                if not prefix_len:
                    ret[-1].append([token])
                else:
                    ret[-1].append([token[:prefix_len], token[prefix_len:]])

        return ret
156
+
157
+
158
+
159
def encode_sentences_for_bert_for_prefix_marking(tokenizer: BertTokenizerFast, sentences: List[str], padding='longest'):
    """Tokenize ``sentences`` and attach the per-token prefix options tensor.

    Args:
        tokenizer: the tokenizer used to encode the sentences.
        sentences: the raw sentences to encode.
        padding: padding strategy handed to the tokenizer.

    Returns:
        The tokenizer output (PyTorch tensors) with one extra entry,
        ``'prefix_class_id_options'``: a long tensor shaped like ``input_ids``
        plus a trailing dimension of size ``TOTAL_POSSIBLE_PREFIX_CLASSES``.
        Position ``[sent, tok, c]`` holds ``c`` when prefix class ``c`` is a
        possible prefix of the word starting at that token, and
        ``TOTAL_POSSIBLE_PREFIX_CLASSES`` otherwise.
    """
    inputs = tokenizer(sentences, padding=padding, return_tensors='pt')

    # create our prefix_id_options array which will be like the input ids shape but with an additional
    # dimension containing for each prefix whether it can be for that word
    prefix_id_options = torch.full(inputs['input_ids'].shape + (TOTAL_POSSIBLE_PREFIX_CLASSES,), TOTAL_POSSIBLE_PREFIX_CLASSES, dtype=torch.long)

    # go through each token, and fill in the vector accordingly
    for sent_idx, sent_ids in enumerate(inputs['input_ids']):
        tokens = tokenizer.convert_ids_to_tokens(sent_ids)
        for tok_idx, token in enumerate(tokens):
            # if the first letter isn't a valid prefix letter, nothing to talk about
            if len(token) < 2 or token[0] not in PREFIXES_TO_CLASS: continue

            # merge in all the following word pieces so that we look at the whole word
            next_tok_idx = tok_idx + 1
            while next_tok_idx < len(tokens) and tokens[next_tok_idx].startswith('##'):
                token += tokens[next_tok_idx][2:]
                # advance to the next piece - without this the loop never terminates
                next_tok_idx += 1

            # find all the possible prefixes - and store the class id in that class's slot
            # (the stored value is what gets looked up in the embedding table)
            for pre_class in get_prefix_classes_from_str(token):
                prefix_id_options[sent_idx, tok_idx, pre_class] = pre_class

    inputs['prefix_class_id_options'] = prefix_id_options
    return inputs
184
+
185
def get_predicted_prefix_len_from_logits(token, token_logits):
    """Return how many leading characters of ``token`` form the predicted prefix.

    Walks the candidate prefixes of ``token`` in order and accumulates the
    length of each one the model marked as present, stopping at the first
    candidate that was not predicted. This rules out predicted combinations
    that cannot exist on the word: for ושכשהלכתי with predictions ו & כש, only
    the ו is taken, because reaching the כש would require the ש as well.

    Two extra rules:
      1] The same prefix is never accepted twice.
      2] A prefix only counts if the word actually starts with it at that
         point - except after a rejected multi-letter prefix, where its first
         letter gets one final chance.

    Args:
        token: the full word (word pieces already merged).
        token_logits: per-class predictions for this token, indexed by class id;
            each entry is read with ``.item()`` and tested for truthiness.

    Returns:
        The length of the predicted prefix, 0 when there is none.
    """
    matched_len = 0
    skip_following = False
    final_attempt = False
    used_prefixes = set()

    for candidate in get_prefixes_from_str(token):
        # We just accepted a multi-letter prefix (e.g. כש), so the single-letter
        # alternative that follows it must not be considered.
        if skip_following:
            skip_following = False
            continue

        # A repeated prefix ends the scan: two of the same prefix are not allowed.
        if candidate in used_prefixes:
            break
        used_prefixes.add(candidate)

        is_multi_letter = len(candidate) > 1
        if token_logits[PREFIXES_TO_CLASS[candidate]].item():
            # predicted: consume it
            matched_len += len(candidate)
            if final_attempt:
                break
            skip_following = is_multi_letter
        elif is_multi_letter:
            # Not predicted, but its first letter alone may still be a prefix
            # (if כש doesn't match we still allow כ). Nothing may follow that
            # letter though: the word continues with ש, and since it isn't כש
            # it can't be כ-ש- either.
            final_attempt = True
        else:
            # a single-letter prefix that wasn't predicted: the prefix ends here
            break

    return matched_len
config.json CHANGED
@@ -2,6 +2,9 @@
2
  "architectures": [
3
  "BertForMaskedLM"
4
  ],
 
 
 
5
  "attention_probs_dropout_prob": 0.1,
6
  "gradient_checkpointing": false,
7
  "hidden_act": "gelu",
 
2
  "architectures": [
3
  "BertForMaskedLM"
4
  ],
5
+ "auto_map": {
6
+ "AutoModel": "BertForPrefixMarking.BertForPrefixMarking"
7
+ },
8
  "attention_probs_dropout_prob": 0.1,
9
  "gradient_checkpointing": false,
10
  "hidden_act": "gelu",