import re
import string

import numpy as np
import torch
import torch.nn.functional as F
from nltk import pos_tag
from nltk.chunk import conlltags2tree
from nltk.tree import Tree
from transformers import Pipeline


def tokenize(text):
    """Split text into whitespace-separated tokens, isolating every punctuation mark."""
    for punctuation in string.punctuation:
        text = text.replace(punctuation, " " + punctuation + " ")
    return text.split()
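

# Example (illustrative only):
#   tokenize("Hello, world!") -> ["Hello", ",", "world", "!"]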


def find_entity_indices(article, entity):
    """
    Find all occurrences of an entity in the article and return their indices.

    :param article: The complete article text.
    :param entity: The entity string to search for.
    :return: A list of (start, end) character offsets, one per occurrence.
    """
    entity_indices = []
    for match in re.finditer(re.escape(entity), article):
        start_idx = match.start()
        end_idx = match.end()
        entity_indices.append((start_idx, end_idx))

    return entity_indices
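

# Example (illustrative only):
#   find_entity_indices("Paris is in France. Paris!", "Paris") -> [(0, 5), (20, 25)]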


def get_entities(tokens, tags, confidences, text):
    """Group BIO-tagged tokens into entity dicts with label, confidence score and character offsets."""
    # Normalise IOBES-style tags (S-/E-) to plain BIO so nltk can build the tree.
    tags = [tag.replace("S-", "B-").replace("E-", "I-") for tag in tags]
    pos_tags = [pos for token, pos in pos_tag(tokens)]

    conlltags = [(token, pos, tg) for token, pos, tg in zip(tokens, pos_tags, tags)]
    ne_tree = conlltags2tree(conlltags)

    entities = []
    idx: int = 0

    for subtree in ne_tree:
        # Tree nodes are contiguous entity spans; bare tuples are "O" tokens.
        if isinstance(subtree, Tree):
            original_label = subtree.label()
            original_string = " ".join([token for token, pos in subtree.leaves()])

            for indices in find_entity_indices(text, original_string):
                entity_start_position = indices[0]
                entity_end_position = indices[1]
                entities.append(
                    {
                        "entity": original_label,
                        "score": np.average(confidences[idx : idx + len(subtree)]),
                        "index": idx,
                        "word": original_string,
                        "start": entity_start_position,
                        "end": entity_end_position,
                    }
                )
                assert (
                    text[entity_start_position:entity_end_position] == original_string
                )
            idx += len(subtree)
        else:
            # Non-entity token: just advance the running token index.
            idx += 1

    return entities
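

# Example (illustrative only): for tokens ["Barack", "Obama", "visited", "Paris"] with
# tags ["B-PER", "I-PER", "O", "B-LOC"], the tree groups "Barack Obama" (PER) and
# "Paris" (LOC); each group is then located in the raw text to recover character offsets.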


def realign(
    text_sentence, out_label_preds, softmax_scores, tokenizer, reverted_label_map
):
    """Map subword-level predictions back onto the pre-tokenized words."""
    preds_list, words_list, confidence_list = [], [], []
    word_ids = tokenizer(text_sentence, is_split_into_words=True).word_ids()
    for idx, word in enumerate(text_sentence):
        try:
            # Keep only the prediction for the first subword of each word.
            beginning_index = word_ids.index(idx)
            preds_list.append(reverted_label_map[out_label_preds[beginning_index]])
            confidence_list.append(max(softmax_scores[beginning_index]))
        except Exception:
            # Word has no aligned subword (e.g. lost to truncation): fall back to "O".
            preds_list.append("O")
            confidence_list.append(0.0)
        words_list.append(word)

    return words_list, preds_list, confidence_list
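

# Example (illustrative only): if the tokenizer splits "playing" into ["play", "##ing"],
# only the label and softmax confidence predicted for "play" are kept for the word "playing".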


class MultitaskTokenClassificationPipeline(Pipeline):
    def __init__(self, model, tokenizer, label_map, **kwargs):
        super().__init__(model=model, tokenizer=tokenizer, **kwargs)
        self.label_map = label_map
        # Invert each task's label -> id mapping so predicted ids can be decoded back to labels.
        self.id2label = {
            task: {id_: label for label, id_ in labels.items()}
            for task, labels in label_map.items()
        }

    def _sanitize_parameters(self, **kwargs):
        # Route all call-time kwargs to preprocess; _forward and postprocess take none.
        return kwargs, {}, {}

    def preprocess(self, text, **kwargs):
        tokenized_inputs = self.tokenizer(
            text, padding="max_length", truncation=True, max_length=512
        )

        # Word-level tokenization of the raw text, used later to realign predictions.
        text_sentence = tokenize(text)
        return tokenized_inputs, text_sentence, text

    def _forward(self, inputs):
        inputs, text_sentence, text = inputs
        input_ids = torch.tensor([inputs["input_ids"]], dtype=torch.long).to(
            self.model.device
        )
        attention_mask = torch.tensor([inputs["attention_mask"]], dtype=torch.long).to(
            self.model.device
        )
        with torch.no_grad():
            outputs = self.model(input_ids, attention_mask)
        return outputs, text_sentence, text

    def postprocess(self, outputs, **kwargs):
        """
        Postprocess the model outputs into per-task entity lists.

        :param outputs: Tuple of (model outputs, word-level tokens, raw text) from _forward.
        :return: A dict mapping each task name to its list of entity dicts.
        """
        tokens_result, text_sentence, text = outputs

        # One set of logits per task; convert to label ids and softmax confidences.
        predictions = {}
        confidence_scores = {}
        for task, logits in tokens_result.logits.items():
            predictions[task] = torch.argmax(logits, dim=-1).tolist()
            confidence_scores[task] = F.softmax(logits, dim=-1).tolist()

        entities = {}
        for task, preds in predictions.items():
            words_list, preds_list, confidence_list = realign(
                text_sentence,
                preds[0],
                confidence_scores[task][0],
                self.tokenizer,
                self.id2label[task],
            )

            entities[task] = get_entities(words_list, preds_list, confidence_list, text)

        return entities
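

# Minimal usage sketch (illustrative only). The checkpoint path, task name and label_map
# below are hypothetical placeholders, not part of this module; the pipeline assumes a
# multitask model whose output exposes per-task logits as `outputs.logits[task]`.
if __name__ == "__main__":
    from transformers import AutoModel, AutoTokenizer

    checkpoint = "path/to/multitask-token-classification-checkpoint"  # hypothetical placeholder
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    # The custom multitask model class is assumed to be shipped with the checkpoint.
    model = AutoModel.from_pretrained(checkpoint, trust_remote_code=True)

    # Hypothetical label map: one label -> id dict per task.
    label_map = {"ner": {"O": 0, "B-PER": 1, "I-PER": 2, "B-LOC": 3, "I-LOC": 4}}

    pipe = MultitaskTokenClassificationPipeline(
        model=model, tokenizer=tokenizer, label_map=label_map
    )
    print(pipe("Barack Obama visited Paris."))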