Shaltiel committed
Commit deb5cae • 1 Parent(s): 2963a45

Fixed bug with UNK tokens being discarded causing misalignment.

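The change keeps [UNK] tokens in the parsed output by removing the UNK token from the list of special tokens that get skipped. Below is a minimal sketch of the failure mode being fixed, assuming an arbitrary fast BERT tokenizer and an out-of-vocabulary word; the model name and example sentence are illustrative and not part of the commit.

from transformers import BertTokenizerFast

tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')  # illustrative; any fast BERT tokenizer

sentence = 'hello 🙂 world'  # assume the middle "word" is out-of-vocabulary and maps to [UNK]
tokens = tokenizer.convert_ids_to_tokens(tokenizer(sentence)['input_ids'])

# Old behaviour: [UNK] is a member of all_special_tokens, so it was skipped like
# [CLS]/[SEP], leaving fewer parsed tokens than input words and shifting every
# later token/word alignment by one.
old_kept = [t for t in tokens if t not in tokenizer.all_special_tokens]

# New behaviour (this commit): drop [UNK] from the skip list so the word keeps its slot.
special_toks = tokenizer.all_special_tokens
special_toks.remove(tokenizer.unk_token)
new_kept = [t for t in tokens if t not in special_toks]

print(len(old_kept), len(new_kept))  # the old list is shorter whenever an OOV word appears
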
BertForJointParsing.py CHANGED
@@ -199,7 +199,7 @@ class BertForJointParsing(BertPreTrainedModel):
 
         # predict the logits for the sentence
         if self.prefix is not None:
-            inputs = encode_sentences_for_bert_for_prefix_marking(tokenizer, sentences, padding)
+            inputs = encode_sentences_for_bert_for_prefix_marking(tokenizer, self.config.prefix_cfg, sentences, padding)
         else:
             inputs = tokenizer(sentences, padding=padding, truncation=truncation, return_offsets_mapping=True, return_tensors='pt')
 
@@ -218,7 +218,7 @@ class BertForJointParsing(BertPreTrainedModel):
 
         # Prefix logits: each sentence gets a list([prefix_segment, word_without_prefix]) - **WITH CLS & SEP**
         if output.prefix_logits is not None:
-            for sent_idx,parsed in enumerate(prefix_parse_logits(input_ids, sentences, tokenizer, output.prefix_logits)):
+            for sent_idx,parsed in enumerate(prefix_parse_logits(input_ids, sentences, tokenizer, output.prefix_logits, self.config.prefix_cfg)):
                 merge_token_list(final_output[sent_idx]['tokens'], map(tuple, parsed[1:-1]), 'seg')
 
         # Lex logits each sentence gets a list(tuple(word, lexeme))
@@ -272,6 +272,7 @@ def combine_token_wordpieces(input_ids: List[int], offset_mapping: torch.Tensor,
     offset_mapping = offset_mapping.tolist()
     ret = []
     special_toks = tokenizer.all_special_tokens
+    special_toks.remove(tokenizer.unk_token)
     for token, offsets in zip(tokenizer.convert_ids_to_tokens(input_ids), offset_mapping):
         if token in special_toks: continue
         if token.startswith('##'):
@@ -285,6 +286,7 @@ def ner_parse_logits(input_ids: List[List[int]], sentences: List[str], tokenizer
     batch_ret = []
 
     special_toks = tokenizer.all_special_tokens
+    special_toks.remove(tokenizer.unk_token)
     for batch_idx in range(len(sentences)):
 
         ret = []
@@ -311,6 +313,7 @@ def lex_parse_logits(input_ids: List[List[int]], sentences: List[str], tokenizer
     batch_ret = []
 
     special_toks = tokenizer.all_special_tokens
+    special_toks.remove(tokenizer.unk_token)
     for batch_idx in range(len(sentences)):
         intermediate_ret = []
         tokens = tokenizer.convert_ids_to_tokens(input_ids[batch_idx])
@@ -519,5 +522,4 @@ def ud_get_prefix_dep(pre, word, word_idx):
     if pre == 'ה':
         func = 'det' if 'DET' in word['morph']['prefixes'] else 'mark'
 
-    return (word['syntax']['dep_head_idx'] if does_follow_main else word_idx), func
-
+    return (word['syntax']['dep_head_idx'] if does_follow_main else word_idx), func
 
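One detail worth noting about the special_toks.remove(tokenizer.unk_token) lines added above (and repeated in the files below): all_special_tokens is a computed property that returns a fresh list on every access, so removing [UNK] from the local copy only affects that parsing pass and does not mutate the tokenizer. A small sketch, assuming any fast BERT tokenizer; the model name is illustrative.

from transformers import BertTokenizerFast

tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')  # illustrative

special_toks = tokenizer.all_special_tokens   # fresh list, not a live view
special_toks.remove(tokenizer.unk_token)      # the local copy no longer skips [UNK]

assert tokenizer.unk_token not in special_toks
assert tokenizer.unk_token in tokenizer.all_special_tokens  # tokenizer state unchanged
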
BertForMorphTagging.py CHANGED
@@ -176,6 +176,7 @@ def parse_logits(input_ids: List[List[int]], sentences: List[str], tokenizer: Be
     # Where tokens is a list of dicts, where each dict is:
     # { pos: str, feats: dict, prefixes: List[str], suffix: str | bool, suffix_feats: dict | None}
     special_toks = tokenizer.all_special_tokens
+    special_toks.remove(tokenizer.unk_token)
     ret = []
     for sent_idx,sentence in enumerate(sentences):
         input_id_strs = tokenizer.convert_ids_to_tokens(input_ids[sent_idx])

BertForPrefixMarking.py CHANGED
@@ -7,18 +7,31 @@ from transformers import BertPreTrainedModel, BertModel, BertTokenizerFast
 
 # define the classes, and the possible prefixes for each class
 POSSIBLE_PREFIX_CLASSES = [ ['לכש', 'כש', 'מש', 'בש', 'לש'], ['מ'], ['ש'], ['ה'], ['ו'], ['כ'], ['ל'], ['ב'] ]
-# map each individual prefix to it's class number
-PREFIXES_TO_CLASS = {w:i for i,l in enumerate(POSSIBLE_PREFIX_CLASSES) for w in l}
-# keep a list of all the prefixes, sorted by length, so that we can decompose
-# a given prefixes and figure out the classes
-ALL_PREFIX_ITEMS = list(sorted(PREFIXES_TO_CLASS.keys(), key=len, reverse=True))
-TOTAL_POSSIBLE_PREFIX_CLASSES = len(POSSIBLE_PREFIX_CLASSES)
-
-def get_prefixes_from_str(s, greedy=False):
+POSSIBLE_RABBINIC_PREFIX_CLASSES = [ ['לכש', 'כש', 'מש', 'בש', 'לש', 'לד', 'בד', 'מד', 'כד', 'לכד'], ['מ'], ['ש', 'ד'], ['ה'], ['ו'], ['כ'], ['ל'], ['ב'], ['א'], ['ק'] ]
+
+class PrefixConfig(dict):
+    def __init__(self, possible_classes, **kwargs): # added kwargs for previous version where all features were kept as dict values
+        super().__init__()
+        self.possible_classes = possible_classes
+        self.total_classes = len(possible_classes)
+        self.prefix_c2i = {w: i for i, l in enumerate(possible_classes) for w in l}
+        self.all_prefix_items = list(sorted(self.prefix_c2i.keys(), key=len, reverse=True))
+
+    @property
+    def possible_classes(self) -> List[List[str]]:
+        return self.get('possible_classes')
+
+    @possible_classes.setter
+    def possible_classes(self, value: List[List[str]]):
+        self['possible_classes'] = value
+
+DEFAULT_PREFIX_CONFIG = PrefixConfig(POSSIBLE_PREFIX_CLASSES)
+
+def get_prefixes_from_str(s, cfg: PrefixConfig, greedy=False):
     # keep trimming prefixes from the string
-    while len(s) > 0 and s[0] in PREFIXES_TO_CLASS:
+    while len(s) > 0 and s[0] in cfg.prefix_c2i:
         # find the longest string to trim
-        next_pre = next((pre for pre in ALL_PREFIX_ITEMS if s.startswith(pre)), None)
+        next_pre = next((pre for pre in cfg.all_prefix_items if s.startswith(pre)), None)
         if next_pre is None:
             return
         yield next_pre
@@ -30,9 +43,9 @@ def get_prefixes_from_str(s, greedy=False):
         yield next_pre[0]
         s = s[len(next_pre):]
 
-def get_prefix_classes_from_str(s, greedy=False):
-    for pre in get_prefixes_from_str(s, greedy):
-        yield PREFIXES_TO_CLASS[pre]
+def get_prefix_classes_from_str(s, cfg: PrefixConfig, greedy=False):
+    for pre in get_prefixes_from_str(s, cfg, greedy):
+        yield cfg.prefix_c2i[pre]
 
 @dataclass
 class PrefixesClassifiersOutput(ModelOutput):
@@ -46,16 +59,21 @@ class BertPrefixMarkingHead(nn.Module):
         super().__init__()
         self.config = config
 
+        if not hasattr(config, 'prefix_cfg') or config.prefix_cfg is None:
+            setattr(config, 'prefix_cfg', DEFAULT_PREFIX_CONFIG)
+        if isinstance(config.prefix_cfg, dict):
+            config.prefix_cfg = PrefixConfig(config.prefix_cfg['possible_classes'])
+
         # an embedding table containing an embedding for each prefix class + 1 for NONE
         # we will concatenate either the embedding/NONE for each class - and we want the concatenate
         # size to be the hidden_size
-        prefix_class_embed = config.hidden_size // TOTAL_POSSIBLE_PREFIX_CLASSES
-        self.prefix_class_embeddings = nn.Embedding(TOTAL_POSSIBLE_PREFIX_CLASSES + 1, prefix_class_embed)
+        prefix_class_embed = config.hidden_size // config.prefix_cfg.total_classes
+        self.prefix_class_embeddings = nn.Embedding(config.prefix_cfg.total_classes + 1, prefix_class_embed)
 
         # one layer for transformation, apply an activation, then another N classifiers for each prefix class
-        self.transform = nn.Linear(config.hidden_size + prefix_class_embed * TOTAL_POSSIBLE_PREFIX_CLASSES, config.hidden_size)
+        self.transform = nn.Linear(config.hidden_size + prefix_class_embed * config.prefix_cfg.total_classes, config.hidden_size)
         self.activation = nn.Tanh()
-        self.classifiers = nn.ModuleList([nn.Linear(config.hidden_size, 2) for _ in range(TOTAL_POSSIBLE_PREFIX_CLASSES)])
+        self.classifiers = nn.ModuleList([nn.Linear(config.hidden_size, 2) for _ in range(config.prefix_cfg.total_classes)])
 
     def forward(
             self,
@@ -66,8 +84,8 @@ class BertPrefixMarkingHead(nn.Module):
         # encode the prefix_class_id_options
         # If input_ids is batch x seq_len
         # Then sequence_output is batch x seq_len x hidden_dim
-        # So prefix_class_id_options is batch x seq_len x TOTAL_POSSIBLE_PREFIX_CLASSES
-        # Looking up the embeddings should give us batch x seq_len x TOTAL_POSSIBLE_PREFIX_CLASSES x hidden_dim / N
+        # So prefix_class_id_options is batch x seq_len x total_classes
+        # Looking up the embeddings should give us batch x seq_len x total_classes x hidden_dim / N
         possible_class_embed = self.prefix_class_embeddings(prefix_class_id_options)
         # then flatten the final dimension - now we have batch x seq_len x hidden_dim_2
         possible_class_embed = possible_class_embed.reshape(possible_class_embed.shape[:-2] + (-1,))
@@ -148,15 +166,15 @@ class BertForPrefixMarking(BertPreTrainedModel):
 
     def predict(self, sentences: List[str], tokenizer: BertTokenizerFast, padding='longest'):
         # step 1: encode the sentences through using the tokenizer, and get the input tensors + prefix id tensors
-        inputs = encode_sentences_for_bert_for_prefix_marking(tokenizer, sentences, padding)
+        inputs = encode_sentences_for_bert_for_prefix_marking(tokenizer, self.config.prefix_cfg, sentences, padding)
         inputs.pop('offset_mapping')
         inputs = {k:v.to(self.device) for k,v in inputs.items()}
 
         # run through bert
         logits = self.forward(**inputs, return_dict=True).logits
-        return parse_logits(inputs['input_ids'].tolist(), sentences, tokenizer, logits)
+        return parse_logits(inputs['input_ids'].tolist(), sentences, tokenizer, logits, self.config.prefix_cfg)
 
-def parse_logits(input_ids: List[List[int]], sentences: List[str], tokenizer: BertTokenizerFast, logits: torch.FloatTensor):
+def parse_logits(input_ids: List[List[int]], sentences: List[str], tokenizer: BertTokenizerFast, logits: torch.FloatTensor, config: PrefixConfig):
     # extract the predictions by argmaxing the final dimension (batch x sequence x prefixes x prediction)
     logit_preds = torch.argmax(logits, axis=3).tolist()
 
@@ -176,7 +194,7 @@ def parse_logits(input_ids: List[List[int]], sentences: List[str], tokenizer: Be
                 token += tokens[next_tok_idx][2:]
                 next_tok_idx += 1
 
-            prefix_len = get_predicted_prefix_len_from_logits(token, logit_preds[sent_idx][tok_idx])
+            prefix_len = get_predicted_prefix_len_from_logits(token, logit_preds[sent_idx][tok_idx], config)
 
             if not prefix_len:
                 ret[-1].append([token])
@@ -184,18 +202,18 @@ def parse_logits(input_ids: List[List[int]], sentences: List[str], tokenizer: Be
                 ret[-1].append([token[:prefix_len], token[prefix_len:]])
     return ret
 
-def encode_sentences_for_bert_for_prefix_marking(tokenizer: BertTokenizerFast, sentences: List[str], padding='longest', truncation=True):
+def encode_sentences_for_bert_for_prefix_marking(tokenizer: BertTokenizerFast, config: PrefixConfig, sentences: List[str], padding='longest', truncation=True):
     inputs = tokenizer(sentences, padding=padding, truncation=truncation, return_offsets_mapping=True, return_tensors='pt')
     # create our prefix_id_options array which will be like the input ids shape but with an addtional
     # dimension containing for each prefix whether it can be for that word
-    prefix_id_options = torch.full(inputs['input_ids'].shape + (TOTAL_POSSIBLE_PREFIX_CLASSES,), TOTAL_POSSIBLE_PREFIX_CLASSES, dtype=torch.long)
+    prefix_id_options = torch.full(inputs['input_ids'].shape + (config.total_classes,), config.total_classes, dtype=torch.long)
 
     # go through each token, and fill in the vector accordingly
     for sent_idx, sent_ids in enumerate(inputs['input_ids']):
         tokens = tokenizer.convert_ids_to_tokens(sent_ids)
         for tok_idx, token in enumerate(tokens):
             # if the first letter isn't a valid prefix letter, nothing to talk about
-            if len(token) < 2 or not token[0] in PREFIXES_TO_CLASS: continue
+            if len(token) < 2 or not token[0] in config.prefix_c2i: continue
 
             # combine the next tokens in? only if it's a breakup
             next_tok_idx = tok_idx + 1
@@ -204,13 +222,13 @@ def encode_sentences_for_bert_for_prefix_marking(tokenizer: BertTokenizerFast, s
                 next_tok_idx += 1
 
             # find all the possible prefixes - and mark them as 0 (and in the possible mark it as it's value for embed lookup)
-            for pre_class in get_prefix_classes_from_str(token):
+            for pre_class in get_prefix_classes_from_str(token, config):
                 prefix_id_options[sent_idx, tok_idx, pre_class] = pre_class
 
     inputs['prefix_class_id_options'] = prefix_id_options
     return inputs
 
-def get_predicted_prefix_len_from_logits(token, token_logits):
+def get_predicted_prefix_len_from_logits(token, token_logits, config: PrefixConfig):
     # Go through each possible prefix, and check if the prefix is yes - and if
     # so increase the counter of the matched length, otherwise break out. That will solve cases
     # of predicting prefix combinations that don't exist on the word.
@@ -221,7 +239,7 @@ def get_predicted_prefix_len_from_logits(token, token_logits):
     # 2] Always check that the word starts with that prefix - otherwise it's bad
     # (except for the case of multi-letter prefix, where we force the next to be last)
     cur_len, skip_next, last_check, seen_prefixes = 0, False, False, set()
-    for prefix in get_prefixes_from_str(token):
+    for prefix in get_prefixes_from_str(token, config):
         # Are we skipping this prefix? This will be the case where we matched כש, don't allow ש
         if skip_next:
             skip_next = False
@@ -232,7 +250,7 @@ def get_predicted_prefix_len_from_logits(token, token_logits):
         seen_prefixes.add(prefix)
 
         # check if we predicted this prefix
-        if token_logits[PREFIXES_TO_CLASS[prefix]]:
+        if token_logits[config.prefix_c2i[prefix]]:
             cur_len += len(prefix)
             if last_check: break
             skip_next = len(prefix) > 1
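
Beyond the UNK fix, this file also threads a new PrefixConfig container through the prefix-marking helpers and adds an optional Rabbinic-Hebrew prefix inventory. Below is a hedged usage sketch of that container, using only names defined in the hunks above; importing BertForPrefixMarking from the repo root is an assumption about how the remote code is laid out.

from BertForPrefixMarking import (  # assumes the repo root is on sys.path
    PrefixConfig,
    POSSIBLE_PREFIX_CLASSES,
    POSSIBLE_RABBINIC_PREFIX_CLASSES,
)

# Same inventory the old module-level globals hard-coded
modern = PrefixConfig(POSSIBLE_PREFIX_CLASSES)
print(modern.total_classes)          # 8

# New: a Rabbinic-Hebrew inventory with extra ד/א/ק prefix classes
rabbinic = PrefixConfig(POSSIBLE_RABBINIC_PREFIX_CLASSES)
print(rabbinic.total_classes)        # 10
print(rabbinic.prefix_c2i['ד'])      # 2, since ד shares a class with ש
print(rabbinic.all_prefix_items[0])  # one of the longest prefixes; longest-match order drives get_prefixes_from_str

# Because PrefixConfig subclasses dict and stores possible_classes under a key,
# it can round-trip through config.json; BertPrefixMarkingHead rebuilds it from
# the stored dict (see the isinstance(config.prefix_cfg, dict) check above).
print(dict(rabbinic))                # {'possible_classes': [...]}
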
BertForSyntaxParsing.py CHANGED
@@ -166,6 +166,7 @@ def parse_logits(input_ids: List[List[int]], sentences: List[str], tokenizer: Be
     outputs = []
 
     special_toks = tokenizer.all_special_tokens
+    special_toks.remove(tokenizer.unk_token)
     for i in range(len(sentences)):
         deps = logits.dependency_head_indices[i].tolist()
         funcs = logits.function_logits.argmax(-1)[i].tolist()