emanuelaboros committed on
Commit
4fd1faf
1 Parent(s): 8d73145

Initial commit of the trained NER model with code

Files changed (2)
  1. config.json +7 -0
  2. generic_ner.py +173 -0
config.json CHANGED
@@ -5,6 +5,13 @@
   ],
   "attention_probs_dropout_prob": 0.1,
   "classifier_dropout": null,
+  "custom_pipelines": {
+    "generic-ner": {
+      "impl": "generic_ner.MultitaskTokenClassificationPipeline",
+      "pt": "models.ExtendedMultitaskModelForTokenClassification",
+      "tf": []
+    }
+  },
   "hidden_act": "gelu",
   "hidden_dropout_prob": 0.1,
   "hidden_size": 512,
generic_ner.py ADDED
@@ -0,0 +1,173 @@
+from transformers import Pipeline
+import numpy as np
+import torch
+from nltk.chunk import conlltags2tree
+from nltk import pos_tag
+from nltk.tree import Tree
+import string
+import torch.nn.functional as F
+import re
+from models import ExtendedMultitaskModelForTokenClassification
+
+# Register the custom pipeline
+from transformers import pipeline
+
+
+def tokenize(text):
+    # print(text)
+    for punctuation in string.punctuation:
+        text = text.replace(punctuation, " " + punctuation + " ")
+    return text.split()
+
+
+def find_entity_indices(article, entity):
+    """
+    Find all occurrences of an entity in the article and return their indices.
+
+    :param article: The complete article text.
+    :param entity: The entity to search for.
+    :return: A list of tuples (lArticleOffset, rArticleOffset) for each occurrence.
+    """
+
+    # normalized_target = normalize_text(entity)
+    # normalized_document = normalize_text(article)
+
+    entity_indices = []
+    for match in re.finditer(re.escape(entity), article):
+        start_idx = match.start()
+        end_idx = match.end()
+        entity_indices.append((start_idx, end_idx))
+
+    return entity_indices
+
+
+def get_entities(tokens, tags, confidences, text):
+
+    tags = [tag.replace("S-", "B-").replace("E-", "I-") for tag in tags]
+    pos_tags = [pos for token, pos in pos_tag(tokens)]
+
+    conlltags = [(token, pos, tg) for token, pos, tg in zip(tokens, pos_tags, tags)]
+    ne_tree = conlltags2tree(conlltags)
+
+    entities = []
+    idx: int = 0
+
+    for subtree in ne_tree:
+        # skipping 'O' tags
+        if isinstance(subtree, Tree):
+            original_label = subtree.label()
+            original_string = " ".join([token for token, pos in subtree.leaves()])
+
+            for indices in find_entity_indices(text, original_string):
+                entity_start_position = indices[0]
+                entity_end_position = indices[1]
+                entities.append(
+                    {
+                        "entity": original_label,
+                        "score": np.average(confidences[idx : idx + len(subtree)]),
+                        "index": idx,
+                        "word": original_string,
+                        "start": entity_start_position,
+                        "end": entity_end_position,
+                    }
+                )
+                assert (
+                    text[entity_start_position:entity_end_position] == original_string
+                )
+            idx += len(subtree)
+
+            # Update the current character position
+            # We add the length of the original string + 1 (for the space)
+        else:
+            token, pos = subtree
+            # If it's not a named entity, we still need to update the character
+            # position
+            idx += 1
+
+    return entities
+
+
+def realign(
+    text_sentence, out_label_preds, softmax_scores, tokenizer, reverted_label_map
+):
+    preds_list, words_list, confidence_list = [], [], []
+    word_ids = tokenizer(text_sentence, is_split_into_words=True).word_ids()
+    for idx, word in enumerate(text_sentence):
+        beginning_index = word_ids.index(idx)
+        try:
+            preds_list.append(reverted_label_map[out_label_preds[beginning_index]])
+            confidence_list.append(max(softmax_scores[beginning_index]))
+        except Exception as ex:  # the sentence was longer than max_length
+            preds_list.append("O")
+            confidence_list.append(0.0)
+        words_list.append(word)
+
+    return words_list, preds_list, confidence_list
+
+
+class MultitaskTokenClassificationPipeline(Pipeline):
+    def __init__(self, model, tokenizer, label_map, **kwargs):
+        super().__init__(model=model, tokenizer=tokenizer, **kwargs)
+        self.label_map = label_map
+        self.id2label = {
+            task: {id_: label for label, id_ in labels.items()}
+            for task, labels in label_map.items()
+        }
+
+    def _sanitize_parameters(self, **kwargs):
+        # Add any additional parameter handling if necessary
+        return kwargs, {}, {}
+
+    def preprocess(self, text, **kwargs):
+        tokenized_inputs = self.tokenizer(
+            text, padding="max_length", truncation=True, max_length=512
+        )
+
+        text_sentence = tokenize(text)
+        return tokenized_inputs, text_sentence, text
+
+    def _forward(self, inputs):
+        inputs, text_sentence, text = inputs
+        input_ids = torch.tensor([inputs["input_ids"]], dtype=torch.long).to(
+            self.model.device
+        )
+        attention_mask = torch.tensor([inputs["attention_mask"]], dtype=torch.long).to(
+            self.model.device
+        )
+        with torch.no_grad():
+            outputs = self.model(input_ids, attention_mask)
+        return outputs, text_sentence, text
+
+    def postprocess(self, outputs, **kwargs):
+        """
+        Postprocess the outputs of the model
+        :param outputs:
+        :param kwargs:
+        :return:
+        """
+        tokens_result, text_sentence, text = outputs
+
+        predictions = {}
+        confidence_scores = {}
+        for task, logits in tokens_result.logits.items():
+            predictions[task] = torch.argmax(logits, dim=-1).tolist()
+            confidence_scores[task] = F.softmax(logits, dim=-1).tolist()
+
+        decoded_predictions = {}
+        for task, preds in predictions.items():
+            decoded_predictions[task] = [
+                [self.id2label[task][label] for label in seq] for seq in preds
+            ]
+        entities = {}
+        for task, preds in predictions.items():
+            words_list, preds_list, confidence_list = realign(
+                text_sentence,
+                preds[0],
+                confidence_scores[task][0],
+                self.tokenizer,
+                self.id2label[task],
+            )
+
+            entities[task] = get_entities(words_list, preds_list, confidence_list, text)
+
+        return entities
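
For completeness, a minimal sketch of driving the pipeline class directly, assuming the companion models.py (imported above but not included in this commit) exposes ExtendedMultitaskModelForTokenClassification.from_pretrained; the checkpoint id and label map below are placeholders, not values from the commit.

    from transformers import AutoTokenizer
    from models import ExtendedMultitaskModelForTokenClassification
    from generic_ner import MultitaskTokenClassificationPipeline

    checkpoint = "<user>/<model-repo>"  # hypothetical repository id
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = ExtendedMultitaskModelForTokenClassification.from_pretrained(checkpoint)  # assumed API

    label_map = {"NE-COARSE-LIT": {"O": 0, "B-pers": 1, "I-pers": 2}}  # assumed example mapping
    ner = MultitaskTokenClassificationPipeline(
        model=model, tokenizer=tokenizer, label_map=label_map
    )

    # postprocess() returns a dict keyed by task; each value is a list of entity
    # dicts with "entity", "score", "word", "start" and "end" fields.
    print(ner("Albert Einstein was born in Ulm."))
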