sumit committed
Commit d8d5ce9
1 Parent(s): 81c680b

add index from tokenizer

Files changed (1):
  BertForJointParsing.py  +47 -10
BertForJointParsing.py CHANGED
@@ -186,7 +186,7 @@ class BertForJointParsing(BertPreTrainedModel):
            morph_logits=morph_logits
        )

-    def predict(self, sentences: Union[str, List[str]], tokenizer: BertTokenizerFast, padding='longest', truncation=True, compute_syntax_mst=True, per_token_ner=False, output_style: Literal['json', 'ud', 'iahlt_ud'] = 'json'):
+    def predict(self, sentences: Union[str, List[str]], tokenizer: BertTokenizerFast, padding='longest', truncation=True, compute_syntax_mst=True, per_token_ner=False, detailed_ner=False, output_style: Literal['json', 'ud', 'iahlt_ud'] = 'json'):
        is_single_sentence = isinstance(sentences, str)
        if is_single_sentence:
            sentences = [sentences]
@@ -234,32 +234,66 @@ class BertForJointParsing(BertPreTrainedModel):
            for sent_idx,parsed in enumerate(ner_parse_logits(inputs, sentences, tokenizer, output.ner_logits, self.config.id2label, offset_mapping)):
                if per_token_ner:
                    merge_token_list(final_output[sent_idx]['tokens'], map(itemgetter(1), parsed), 'ner')
-                final_output[sent_idx]['ner_entities'] = aggregate_ner_tokens(parsed)
-
+                final_output[sent_idx]['ner_entities'] = aggregate_ner_tokens(parsed)
+
        if output_style in ['ud', 'iahlt_ud']:
            final_output = convert_output_to_ud(final_output, style='htb' if output_style == 'ud' else 'iahlt')

        if is_single_sentence:
            final_output = final_output[0]
+
+        words_index = parse_index(inputs['input_ids'], tokenizer)[0]
+        for idx, w in zip(words_index, final_output[0]['tokens']):
+            w['idx'] = idx
+
        return final_output

+def parse_index(input_ids: torch.Tensor, tokenizer: BertTokenizerFast):
+    # Create input_indices for each input_id, handling word-pieces
+    input_indices = []
+    for batch_idx, ids in enumerate(input_ids):
+        sentence_indices = []
+        current_word_indices = []
+        for idx, id_value in enumerate(ids):
+            # Skip special tokens
+            if id_value in [tokenizer.cls_token_id, tokenizer.sep_token_id, tokenizer.pad_token_id]:
+                continue
+
+            token_id = input_ids[batch_idx, idx]
+            token = tokenizer._convert_id_to_token(token_id)
+
+            # If the token is a continuation of a previous word (word-piece), append the index
+            if token.startswith('##'):
+                current_word_indices.append(idx)
+            else:
+                # If there's a current word, add it to sentence indices
+                if current_word_indices:
+                    sentence_indices.append(current_word_indices)
+                current_word_indices = [idx]
+
+        # Add the last word to sentence indices if not empty
+        if current_word_indices:
+            sentence_indices.append(current_word_indices)
+        input_indices.append(sentence_indices)
+    return input_indices


def aggregate_ner_tokens(predictions):
    entities = []
    prev = None
-    for word, pred, start, end in predictions:
+    for word, pred, start, end, idx in predictions:
        # O does nothing
        if pred == 'O': prev = None
        # B- || I-entity != prev (different entity or none)
        elif pred.startswith('B-') or pred[2:] != prev:
            prev = pred[2:]
-            entities.append([[word], prev, start, end])
+            entities.append([[word], prev, start, end, idx])
        else:
            entities[-1][0].append(word)
            entities[-1][3] = end
+            entities[-1][4].extend(idx)

-    return [dict(phrase=' '.join(words), label=label, start=start, end=end) for words, label, start, end in entities]
+    return [dict(idx=idx, phrase=' '.join(words), label=label, start=start, end=end) for words, label, start, end, idx in entities]

def merge_token_list(src, update, key):
    for token_src, token_update in zip(src, update):
@@ -276,7 +310,6 @@ def combine_token_wordpieces(input_ids: torch.Tensor, tokenizer: BertTokenizerFa

def ner_parse_logits(inputs: Dict[str, torch.Tensor], sentences: List[str], tokenizer: BertTokenizerFast, logits: torch.Tensor, id2label: Dict[int, str], offset_mapping):
    input_ids = inputs['input_ids']
-
    predictions = torch.argmax(logits, dim=-1)
    batch_ret = []
    for batch_idx in range(len(sentences)):
@@ -295,11 +328,15 @@ def ner_parse_logits(inputs: Dict[str, torch.Tensor], sentences: List[str], toke
            # we modify the last token in ret
            # by discarding the original end position and replacing it with the new token's end position
            if token.startswith('##'):
-                ret[-1] = (ret[-1][0] + token[2:], ret[-1][1], ret[-1][2], end_pos.item())
+                ret[-1] = [ret[-1][0] + token[2:], ret[-1][1], ret[-1][2], end_pos.item()]
                continue
            # for each token, we append a tuple containing: token, label, start position, end position
-            ret.append((token, id2label[predictions[batch_idx, tok_idx].item()], start_pos.item(), end_pos.item()))
-
+            ret.append([token, id2label[predictions[batch_idx, tok_idx].item()], start_pos.item(), end_pos.item()])
+
+    words_index = parse_index(inputs['input_ids'], tokenizer)[0]
+    for idx, w in zip(words_index, batch_ret[0]):
+        w.append(idx)
+
    return batch_ret

def lex_parse_logits(inputs: Dict[str, torch.Tensor], sentences: List[str], tokenizer: BertTokenizerFast, logits: torch.Tensor):
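
Note (not part of the commit): the word-piece grouping that the new parse_index helper performs can be sketched in isolation. The snippet below is a minimal illustration under assumed inputs; the helper name group_wordpiece_indices and the example tokens are invented, and it works on token strings rather than on input_ids, but the grouping rule is the same: special tokens are skipped, a token starting with '##' is folded into the current word's index list, and any other token starts a new word.

    def group_wordpiece_indices(tokens):
        # tokens: the tokenizer's string pieces for one sentence (e.g. from convert_ids_to_tokens)
        groups = []    # one list of piece indices per word
        current = []   # indices of the word currently being built
        for i, tok in enumerate(tokens):
            if tok in ('[CLS]', '[SEP]', '[PAD]'):
                continue                    # skip special tokens, as parse_index does via their ids
            if tok.startswith('##'):
                current.append(i)           # continuation piece: belongs to the same word
            else:
                if current:
                    groups.append(current)  # close the previous word
                current = [i]               # start a new word
        if current:
            groups.append(current)          # flush the last word
        return groups

    # ['[CLS]', 'un', '##believ', '##able', 'story', '[SEP]'] -> [[1, 2, 3], [4]]
    print(group_wordpiece_indices(['[CLS]', 'un', '##believ', '##able', 'story', '[SEP]']))

After this commit, predict attaches these per-word index lists to each token dict under the key 'idx', and aggregate_ner_tokens carries them through to each entry in ner_entities.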