Shaltiel committed
Commit 297d804
1 Parent(s): 33aea67

add index from tokenizer

Files changed (1)
  1. BertForJointParsing.py +10 -47
BertForJointParsing.py CHANGED
@@ -186,7 +186,7 @@ class BertForJointParsing(BertPreTrainedModel):
             morph_logits=morph_logits
         )
 
-    def predict(self, sentences: Union[str, List[str]], tokenizer: BertTokenizerFast, padding='longest', truncation=True, compute_syntax_mst=True, per_token_ner=False, detailed_ner=False, output_style: Literal['json', 'ud', 'iahlt_ud'] = 'json'):
+    def predict(self, sentences: Union[str, List[str]], tokenizer: BertTokenizerFast, padding='longest', truncation=True, compute_syntax_mst=True, per_token_ner=False, output_style: Literal['json', 'ud', 'iahlt_ud'] = 'json'):
         is_single_sentence = isinstance(sentences, str)
         if is_single_sentence:
             sentences = [sentences]
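
For orientation, a minimal usage sketch of the updated predict() signature. The checkpoint id, the AutoModel/AutoTokenizer loading pattern, and the example sentence are illustrative assumptions, not part of this commit:

# Hedged sketch, not from this commit: assumes a checkpoint whose auto_map routes
# AutoModel to BertForJointParsing (loaded with trust_remote_code=True).
from transformers import AutoModel, AutoTokenizer

repo_id = 'your-org/your-joint-parsing-checkpoint'  # placeholder repo id
tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)
model.eval()

# predict() accepts a single string or a list of strings; per_token_ner and
# output_style are the parameters kept by this commit (detailed_ner was dropped).
result = model.predict('George Washington went to Washington', tokenizer,
                       per_token_ner=True, output_style='json')
print(result['tokens'])        # per-token analyses, with 'ner' merged in
print(result['ner_entities'])  # aggregated entities from aggregate_ner_tokens
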
@@ -234,66 +234,32 @@ class BertForJointParsing(BertPreTrainedModel):
             for sent_idx,parsed in enumerate(ner_parse_logits(inputs, sentences, tokenizer, output.ner_logits, self.config.id2label, offset_mapping)):
                 if per_token_ner:
                     merge_token_list(final_output[sent_idx]['tokens'], map(itemgetter(1), parsed), 'ner')
-                final_output[sent_idx]['ner_entities'] = aggregate_ner_tokens(parsed)
-
+                final_output[sent_idx]['ner_entities'] = aggregate_ner_tokens(parsed)
+
         if output_style in ['ud', 'iahlt_ud']:
             final_output = convert_output_to_ud(final_output, style='htb' if output_style == 'ud' else 'iahlt')
 
         if is_single_sentence:
             final_output = final_output[0]
-
-        words_index = parse_index(inputs['input_ids'], tokenizer)[0]
-        for idx, w in zip(words_index, final_output[0]['tokens']):
-            w['idx'] = idx
-
         return final_output
 
-def parse_index(input_ids: torch.Tensor, tokenizer: BertTokenizerFast):
-    # Create input_indices for each input_id, handling word-pieces
-    input_indices = []
-    for batch_idx, ids in enumerate(input_ids):
-        sentence_indices = []
-        current_word_indices = []
-        for idx, id_value in enumerate(ids):
-            # Skip special tokens
-            if id_value in [tokenizer.cls_token_id, tokenizer.sep_token_id, tokenizer.pad_token_id]:
-                continue
-
-            token_id = input_ids[batch_idx, idx]
-            token = tokenizer._convert_id_to_token(token_id)
-
-            # If the token is a continuation of a previous word (word-piece), append the index
-            if token.startswith('##'):
-                current_word_indices.append(idx)
-            else:
-                # If there's a current word, add it to sentence indices
-                if current_word_indices:
-                    sentence_indices.append(current_word_indices)
-                current_word_indices = [idx]
-
-        # Add the last word to sentence indices if not empty
-        if current_word_indices:
-            sentence_indices.append(current_word_indices)
-        input_indices.append(sentence_indices)
-    return input_indices
 
 
 def aggregate_ner_tokens(predictions):
     entities = []
     prev = None
-    for word, pred, start, end, idx in predictions:
+    for word, pred, start, end in predictions:
         # O does nothing
         if pred == 'O': prev = None
         # B- || I-entity != prev (different entity or none)
         elif pred.startswith('B-') or pred[2:] != prev:
             prev = pred[2:]
-            entities.append([[word], prev, start, end, idx])
+            entities.append([[word], prev, start, end])
         else:
             entities[-1][0].append(word)
             entities[-1][3] = end
-            entities[-1][4].extend(idx)
 
-    return [dict(idx=idx, phrase=' '.join(words), label=label, start=start, end=end) for words, label, start, end, idx in entities]
+    return [dict(phrase=' '.join(words), label=label, start=start, end=end) for words, label, start, end in entities]
 
 def merge_token_list(src, update, key):
     for token_src, token_update in zip(src, update):
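
As a quick illustration of the BIO aggregation performed by aggregate_ner_tokens after this change, where each prediction is now a plain (word, label, start, end) tuple. The tokens and label names below are made up; only the tuple shape and the output dict keys come from the code above, and aggregate_ner_tokens from this file is assumed to be in scope:

# Made-up input: one (word, label, start, end) tuple per token, as produced by
# ner_parse_logits once word-pieces have been merged.
parsed = [
    ('George', 'B-PER', 0, 6),
    ('Washington', 'I-PER', 7, 17),
    ('went', 'O', 18, 22),
    ('to', 'O', 23, 25),
    ('Washington', 'B-GPE', 26, 36),
]
print(aggregate_ner_tokens(parsed))
# [{'phrase': 'George Washington', 'label': 'PER', 'start': 0, 'end': 17},
#  {'phrase': 'Washington', 'label': 'GPE', 'start': 26, 'end': 36}]
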
@@ -310,6 +276,7 @@ def combine_token_wordpieces(input_ids: torch.Tensor, tokenizer: BertTokenizerFa
 
 def ner_parse_logits(inputs: Dict[str, torch.Tensor], sentences: List[str], tokenizer: BertTokenizerFast, logits: torch.Tensor, id2label: Dict[int, str], offset_mapping):
     input_ids = inputs['input_ids']
+
     predictions = torch.argmax(logits, dim=-1)
     batch_ret = []
     for batch_idx in range(len(sentences)):
@@ -328,15 +295,11 @@ def ner_parse_logits(inputs: Dict[str, torch.Tensor], sentences: List[str], toke
             # we modify the last token in ret
             # by discarding the original end position and replacing it with the new token's end position
             if token.startswith('##'):
-                ret[-1] = [ret[-1][0] + token[2:], ret[-1][1], ret[-1][2], end_pos.item()]
+                ret[-1] = (ret[-1][0] + token[2:], ret[-1][1], ret[-1][2], end_pos.item())
                 continue
             # for each token, we append a tuple containing: token, label, start position, end position
-            ret.append([token, id2label[predictions[batch_idx, tok_idx].item()], start_pos.item(), end_pos.item()])
-
-    words_index = parse_index(inputs['input_ids'], tokenizer)[0]
-    for idx, w in zip(words_index, batch_ret[0]):
-        w.append(idx)
-
+            ret.append((token, id2label[predictions[batch_idx, tok_idx].item()], start_pos.item(), end_pos.item()))
+
     return batch_ret
 
 def lex_parse_logits(inputs: Dict[str, torch.Tensor], sentences: List[str], tokenizer: BertTokenizerFast, logits: torch.Tensor):
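
Finally, a small standalone sketch of the '##' word-piece merge rule used in ner_parse_logits above, which now builds plain tuples again since the idx bookkeeping was removed. The token and offsets are invented:

# Invented illustration of the merge rule: a '##' continuation piece is folded
# into the previous entry, keeping its label and start offset and adopting the
# new end offset.
ret = [('play', 'O', 0, 4)]
token, end_pos = '##ing', 7
if token.startswith('##'):
    ret[-1] = (ret[-1][0] + token[2:], ret[-1][1], ret[-1][2], end_pos)
print(ret)  # [('playing', 'O', 0, 7)]
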
 