Shaltiel committed
Commit: c073d26
Parent: 81c680b

Upload BertForJointParsing.py

Files changed (1):
  1. BertForJointParsing.py (+19 -18)
BertForJointParsing.py CHANGED
@@ -81,6 +81,7 @@ class BertForJointParsing(BertPreTrainedModel):
 
     def set_output_embeddings(self, new_embeddings):
         if self.lex is not None:
+
             self.cls.predictions.decoder = new_embeddings
 
     def forward(
@@ -207,7 +208,7 @@ class BertForJointParsing(BertPreTrainedModel):
         inputs = {k:v.to(self.device) for k,v in inputs.items()}
         output = self.forward(**inputs, return_dict=True, compute_syntax_mst=compute_syntax_mst)
 
-        final_output = [dict(text=sentence, tokens=[dict(token=t) for t in combine_token_wordpieces(ids, tokenizer)]) for sentence, ids in zip(sentences, inputs['input_ids'])]
+        final_output = [dict(text=sentence, tokens=combine_token_wordpieces(ids, offsets, tokenizer)) for sentence, ids, offsets in zip(sentences, inputs['input_ids'], offset_mapping)]
         # Syntax logits: each sentence gets a dict(tree: List[dict(word,dep_head,dep_head_idx,dep_func)], root_idx: int)
         if output.syntax_logits is not None:
             for sent_idx,parsed in enumerate(syntax_parse_logits(inputs, sentences, tokenizer, output.syntax_logits)):
@@ -231,10 +232,10 @@ class BertForJointParsing(BertPreTrainedModel):
 
         # NER logits each sentence gets a list(tuple(word, ner))
         if output.ner_logits is not None:
-            for sent_idx,parsed in enumerate(ner_parse_logits(inputs, sentences, tokenizer, output.ner_logits, self.config.id2label, offset_mapping)):
+            for sent_idx,parsed in enumerate(ner_parse_logits(inputs, sentences, tokenizer, output.ner_logits, self.config.id2label)):
                 if per_token_ner:
                     merge_token_list(final_output[sent_idx]['tokens'], map(itemgetter(1), parsed), 'ner')
-                final_output[sent_idx]['ner_entities'] = aggregate_ner_tokens(parsed)
+                final_output[sent_idx]['ner_entities'] = aggregate_ner_tokens(final_output[sent_idx], parsed)
 
         if output_style in ['ud', 'iahlt_ud']:
             final_output = convert_output_to_ud(final_output, style='htb' if output_style == 'ud' else 'iahlt')
@@ -245,36 +246,39 @@ class BertForJointParsing(BertPreTrainedModel):
 
 
 
-def aggregate_ner_tokens(predictions):
+def aggregate_ner_tokens(final_output, parsed):
     entities = []
     prev = None
-    for word, pred, start, end in predictions:
+    for token_idx, (d, (word, pred)) in enumerate(zip(final_output['tokens'], parsed)):
         # O does nothing
         if pred == 'O': prev = None
         # B- || I-entity != prev (different entity or none)
         elif pred.startswith('B-') or pred[2:] != prev:
             prev = pred[2:]
-            entities.append([[word], prev, start, end])
+            entities.append([[word], dict(label=prev, start=d['offsets']['start'], end=d['offsets']['end'], token_start=token_idx, token_end=token_idx)])
         else:
             entities[-1][0].append(word)
-            entities[-1][3] = end
+            entities[-1][1]['end'] = d['offsets']['end']
+            entities[-1][1]['token_end'] = token_idx
 
-    return [dict(phrase=' '.join(words), label=label, start=start, end=end) for words, label, start, end in entities]
+    return [dict(phrase=' '.join(words), **d) for words, d in entities]
 
 def merge_token_list(src, update, key):
     for token_src, token_update in zip(src, update):
         token_src[key] = token_update
 
-def combine_token_wordpieces(input_ids: torch.Tensor, tokenizer: BertTokenizerFast):
+def combine_token_wordpieces(input_ids: torch.Tensor, offset_mapping: torch.Tensor, tokenizer: BertTokenizerFast):
+    offset_mapping = offset_mapping.tolist()
     ret = []
-    for token in tokenizer.convert_ids_to_tokens(input_ids):
+    for token, offsets in zip(tokenizer.convert_ids_to_tokens(input_ids), offset_mapping):
         if token in [tokenizer.cls_token, tokenizer.sep_token, tokenizer.pad_token]: continue
         if token.startswith('##'):
-            ret[-1] += token[2:]
-        else: ret.append(token)
+            ret[-1]['token'] += token[2:]
+            ret[-1]['offsets']['end'] = offsets[1]
+        else: ret.append(dict(token=token, offsets=dict(start=offsets[0], end=offsets[1])))
     return ret
 
-def ner_parse_logits(inputs: Dict[str, torch.Tensor], sentences: List[str], tokenizer: BertTokenizerFast, logits: torch.Tensor, id2label: Dict[int, str], offset_mapping):
+def ner_parse_logits(inputs: Dict[str, torch.Tensor], sentences: List[str], tokenizer: BertTokenizerFast, logits: torch.Tensor, id2label: Dict[int, str]):
     input_ids = inputs['input_ids']
 
     predictions = torch.argmax(logits, dim=-1)
@@ -289,16 +293,13 @@ def ner_parse_logits(inputs: Dict[str, torch.Tensor], sentences: List[str], toke
 
             token = tokenizer._convert_id_to_token(token_id)
 
-            # get the offsets for this token
-            start_pos, end_pos = offset_mapping[batch_idx, tok_idx]
             # wordpieces should just be appended to the previous word
             # we modify the last token in ret
             # by discarding the original end position and replacing it with the new token's end position
             if token.startswith('##'):
-                ret[-1] = (ret[-1][0] + token[2:], ret[-1][1], ret[-1][2], end_pos.item())
                 continue
-            # for each token, we append a tuple containing: token, label, start position, end position
-            ret.append((token, id2label[predictions[batch_idx, tok_idx].item()], start_pos.item(), end_pos.item()))
+            # for each token, we append a tuple containing: token, label, start position, end position
+            ret.append((token, id2label[predictions[batch_idx, tok_idx].item()]))
 
     return batch_ret
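
For context, the net effect of this commit is that character offsets are attached to each merged token by combine_token_wordpieces (taken from the tokenizer's offset_mapping), and aggregate_ner_tokens reads them back from the token dicts instead of having offset_mapping threaded through ner_parse_logits; each ner_entities entry therefore carries character and token spans alongside the phrase and label. A minimal consumer sketch follows. The predict(...) entry point and the dicta-il/dictabert-joint checkpoint name are assumptions inferred from the surrounding code, not stated in this diff.

# Sketch only: the checkpoint name and the predict(...) call are assumed, not confirmed by the diff.
from transformers import AutoModel, AutoTokenizer

model_name = 'dicta-il/dictabert-joint'  # assumed checkpoint hosting BertForJointParsing
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
model.eval()

sentence = 'This is an example sentence.'
prediction = model.predict([sentence], tokenizer)[0]

# Each token dict now carries the merged wordpiece text plus character offsets into the input.
for tok in prediction['tokens']:
    start, end = tok['offsets']['start'], tok['offsets']['end']
    print(tok['token'], (start, end), repr(sentence[start:end]))

# Each aggregated entity exposes phrase, label, character span and token span,
# so the surface form can be recovered by slicing the original sentence.
for ent in prediction['ner_entities']:
    print(ent['label'], repr(sentence[ent['start']:ent['end']]),
          (ent['token_start'], ent['token_end']))

As far as the diff shows, the design choice is to compute offsets once, at wordpiece-merge time, and let downstream consumers such as aggregate_ner_tokens read them from the shared token dicts rather than recomputing them from offset_mapping.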