emanuelaboros committed
Commit: 4fd1faf
Parent(s): 8d73145

Initial commit of the trained NER model with code
Files changed:
- config.json (+7, -0)
- generic_ner.py (+173, -0)
config.json CHANGED
@@ -5,6 +5,13 @@
   ],
   "attention_probs_dropout_prob": 0.1,
   "classifier_dropout": null,
+  "custom_pipelines": {
+    "generic-ner": {
+      "impl": "generic_ner.MultitaskTokenClassificationPipeline",
+      "pt": "models.ExtendedMultitaskModelForTokenClassification",
+      "tf": []
+    }
+  },
   "hidden_act": "gelu",
   "hidden_dropout_prob": 0.1,
   "hidden_size": 512,
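The "custom_pipelines" block added above tells transformers how to resolve the "generic-ner" task to the code shipped in this repository. As a rough illustration only (not part of this commit), loading such a model usually looks like the sketch below; the repository id and the label map are placeholders, and label_map is a constructor argument required by MultitaskTokenClassificationPipeline in generic_ner.py.

# Hypothetical usage sketch: load the custom "generic-ner" pipeline declared in
# config.json. trust_remote_code=True lets transformers download and import
# generic_ner.py and models.py from the repository.
from transformers import pipeline

ner = pipeline(
    model="<repo-id>",  # placeholder; the repository id is not given in this commit
    trust_remote_code=True,
    # placeholder label map, task name -> {label: id}, as expected by
    # MultitaskTokenClassificationPipeline.__init__
    label_map={"<task>": {"O": 0, "B-ent": 1, "I-ent": 2}},
)

print(ner("Albert Einstein was born in Ulm."))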
generic_ner.py ADDED
@@ -0,0 +1,173 @@
from transformers import Pipeline
import numpy as np
import torch
from nltk.chunk import conlltags2tree
from nltk import pos_tag
from nltk.tree import Tree
import string
import torch.nn.functional as F
import re
from models import ExtendedMultitaskModelForTokenClassification

# Register the custom pipeline
from transformers import pipeline


def tokenize(text):
    # print(text)
    for punctuation in string.punctuation:
        text = text.replace(punctuation, " " + punctuation + " ")
    return text.split()


def find_entity_indices(article, entity):
    """
    Find all occurrences of an entity in the article and return their indices.

    :param article: The complete article text.
    :param entity: The entity to search for.
    :return: A list of tuples (lArticleOffset, rArticleOffset) for each occurrence.
    """

    # normalized_target = normalize_text(entity)
    # normalized_document = normalize_text(article)

    entity_indices = []
    for match in re.finditer(re.escape(entity), article):
        start_idx = match.start()
        end_idx = match.end()
        entity_indices.append((start_idx, end_idx))

    return entity_indices


def get_entities(tokens, tags, confidences, text):

    tags = [tag.replace("S-", "B-").replace("E-", "I-") for tag in tags]
    pos_tags = [pos for token, pos in pos_tag(tokens)]

    conlltags = [(token, pos, tg) for token, pos, tg in zip(tokens, pos_tags, tags)]
    ne_tree = conlltags2tree(conlltags)

    entities = []
    idx: int = 0

    for subtree in ne_tree:
        # skipping 'O' tags
        if isinstance(subtree, Tree):
            original_label = subtree.label()
            original_string = " ".join([token for token, pos in subtree.leaves()])

            for indices in find_entity_indices(text, original_string):
                entity_start_position = indices[0]
                entity_end_position = indices[1]
                entities.append(
                    {
                        "entity": original_label,
                        "score": np.average(confidences[idx : idx + len(subtree)]),
                        "index": idx,
                        "word": original_string,
                        "start": entity_start_position,
                        "end": entity_end_position,
                    }
                )
                assert (
                    text[entity_start_position:entity_end_position] == original_string
                )
            # advance the token index past this subtree
            idx += len(subtree)
        else:
            token, pos = subtree
            # not a named entity, but the token index still has to advance
            idx += 1

    return entities


def realign(
    text_sentence, out_label_preds, softmax_scores, tokenizer, reverted_label_map
):
    preds_list, words_list, confidence_list = [], [], []
    word_ids = tokenizer(text_sentence, is_split_into_words=True).word_ids()
    for idx, word in enumerate(text_sentence):
        try:
            # word_ids.index(idx) raises ValueError for words dropped by
            # truncation, so it has to sit inside the try block
            beginning_index = word_ids.index(idx)
            preds_list.append(reverted_label_map[out_label_preds[beginning_index]])
            confidence_list.append(max(softmax_scores[beginning_index]))
        except Exception:  # the sentence was longer than max_length
            preds_list.append("O")
            confidence_list.append(0.0)
        words_list.append(word)

    return words_list, preds_list, confidence_list


class MultitaskTokenClassificationPipeline(Pipeline):
    def __init__(self, model, tokenizer, label_map, **kwargs):
        super().__init__(model=model, tokenizer=tokenizer, **kwargs)
        self.label_map = label_map
        self.id2label = {
            task: {id_: label for label, id_ in labels.items()}
            for task, labels in label_map.items()
        }

    def _sanitize_parameters(self, **kwargs):
        # Add any additional parameter handling if necessary
        return kwargs, {}, {}

    def preprocess(self, text, **kwargs):
        tokenized_inputs = self.tokenizer(
            text, padding="max_length", truncation=True, max_length=512
        )

        text_sentence = tokenize(text)
        return tokenized_inputs, text_sentence, text

    def _forward(self, inputs):
        inputs, text_sentence, text = inputs
        input_ids = torch.tensor([inputs["input_ids"]], dtype=torch.long).to(
            self.model.device
        )
        attention_mask = torch.tensor([inputs["attention_mask"]], dtype=torch.long).to(
            self.model.device
        )
        with torch.no_grad():
            outputs = self.model(input_ids, attention_mask)
        return outputs, text_sentence, text

    def postprocess(self, outputs, **kwargs):
        """
        Postprocess the outputs of the model
        :param outputs:
        :param kwargs:
        :return:
        """
        tokens_result, text_sentence, text = outputs

        predictions = {}
        confidence_scores = {}
        for task, logits in tokens_result.logits.items():
            predictions[task] = torch.argmax(logits, dim=-1).tolist()
            confidence_scores[task] = F.softmax(logits, dim=-1).tolist()

        decoded_predictions = {}
        for task, preds in predictions.items():
            decoded_predictions[task] = [
                [self.id2label[task][label] for label in seq] for seq in preds
            ]
        entities = {}
        for task, preds in predictions.items():
            words_list, preds_list, confidence_list = realign(
                text_sentence,
                preds[0],
                confidence_scores[task][0],
                self.tokenizer,
                self.id2label[task],
            )

            entities[task] = get_entities(words_list, preds_list, confidence_list, text)

        return entities
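For completeness, the "custom_pipelines" entry in config.json is the kind of record transformers writes when a custom pipeline is registered and pushed to the Hub. The snippet below is a sketch of that registration step, under the assumption that it is run next to generic_ner.py and models.py; it is not part of this commit.

# Hypothetical registration sketch: how a "custom_pipelines" entry such as the
# one added to config.json is typically produced before pushing to the Hub.
from transformers.pipelines import PIPELINE_REGISTRY

from generic_ner import MultitaskTokenClassificationPipeline
from models import ExtendedMultitaskModelForTokenClassification

PIPELINE_REGISTRY.register_pipeline(
    "generic-ner",
    pipeline_class=MultitaskTokenClassificationPipeline,
    pt_model=ExtendedMultitaskModelForTokenClassification,
)

Calling push_to_hub on a pipeline instance built from this class is what typically writes the custom_pipelines block into config.json and copies generic_ner.py alongside it.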