Daniel Steinigen committed on
Commit a50f42c
1 Parent(s): 980b30f

add demonstrator

README.md CHANGED
@@ -1,6 +1,6 @@
  ---
  title: NLP Legal Texts
- emoji: 🐠
+ emoji: ⚖️
  colorFrom: red
  colorTo: gray
  sdk: streamlit
app.py ADDED
@@ -0,0 +1,209 @@
+ import os
+ import streamlit as st
+ from spacy import displacy
+ from PIL import Image
+ import json
+ import requests
+ from pyvis.network import Network
+ import streamlit.components.v1 as components
+
+ from util.process_data import Entity, EntityType, Relation, Sample, SampleList
+ from util.tokenizer import Tokenizer
+ from model_inference import TransformersInference
+ from util.configuration import InferenceConfiguration
+
+ inference_config = InferenceConfiguration()
+ tokenizer = Tokenizer(inference_config.spacy_model)
+
+ SAMPLE_66 = "EStG § 66 Höhe des Kindergeldes, Zahlungszeitraum (1) Das Kindergeld beträgt monatlich für das erste und zweite Kind jeweils 219 Euro, für das dritte Kind 225 Euro und für das vierte und jedes weitere Kind jeweils 250 Euro."
+ SAMPLE_9 = "EStG § 9 Werbungskosten ... Zur Abgeltung dieser Aufwendungen ist für jeden Arbeitstag, an dem der Arbeitnehmer die erste Tätigkeitsstätte aufsucht eine Entfernungspauschale für jeden vollen Kilometer der Entfernung zwischen Wohnung und erster Tätigkeitsstätte von 0,30 Euro anzusetzen, höchstens jedoch 4 500 Euro im Kalenderjahr; ein höherer Betrag als 4 500 Euro ist anzusetzen, soweit der Arbeitnehmer einen eigenen oder ihm zur Nutzung überlassenen Kraftwagen benutzt."
+
+
+ ############################################################
+ ## Constants
+ ############################################################
+ max_width_str = f"max-width: 60%;"
+ paragraph = None
+ style = "<style>mark.entity { display: inline-block }</style>"
+ graph_options = '''
+ var options = {
+   "edges": {
+     "arrows": {
+       "to": {
+         "enabled": true,
+         "scaleFactor": 1.2
+       }
+     }
+   }
+ }
+ '''
+
+ legend_content = {
+     "text": "StatedKeyFigure StatedExpression Unit Range Factor Condition DeclarativeKeyFigure DeclarativeExpression",
+     "ents": [
+         {"start": 0, "end": 15, "label": "K"},
+         {"start": 16, "end": 32, "label": "E"},
+         {"start": 33, "end": 37, "label": "U"},
+         {"start": 38, "end": 43, "label": "R"},
+         {"start": 44, "end": 50, "label": "F"},
+         {"start": 51, "end": 60, "label": "C"},
+         {"start": 61, "end": 81, "label": "DK"},
+         {"start": 82, "end": 103, "label": "DE"},
+     ]}
+ legend_options = {
+     "ents": ["K","U","E","R","F","C","DK","DE"],
+     "colors": {'K': '#46d000',"U": "#e861ef", "E": "#538cff", "R": "#ffbe00", "F": "#0fd5dc", "C":"#ff484b", "DK":"#46d000", "DE":"#538cff"}
+ }
+ legend_mapping = {"StatedKeyFigure": "K","Unit": "U","StatedExpression": "E","Range": "R","Factor": "F","Condition": "C","DeclarativeKeyFigure": "DK","DeclarativeExpression": "DE"}
+ edge_colors = {'hasKeyFigure': '#46d000',"hasUnit": "#e861ef", "hasExpression": "#538cff", "hasRange": "#ffbe00", "hasFactor": "#0fd5dc", "hasCondition":"#ff484b", "join":"#aaa", "Typ":"#aaa", "hasParagraph": "#FF8B15"}
+
+
+ ############################################################
+ ## Function definitions
+ ############################################################
+
+ def get_html(html: str, legend=False):
+     """Convert HTML so it can be rendered."""
+     WRAPPER = """<div style="overflow-x: auto; border: 1px solid #e6e9ef; border-radius: 0.25rem; padding: 1.5rem">{}</div>"""
+     if legend: WRAPPER = """<div style="overflow-x: auto; padding: 1rem">{}</div>"""
+     # Newlines seem to mess with the rendering
+     html = html.replace("\n", " ")
+     return WRAPPER.format(html)
+
+
+ def get_displacy_ent_obj(paragraph, bedingungen=False, send_request=False):
+     entities = []
+     for entity in paragraph['entities']:
+         label = entity["entity"] if not send_request else entity["ent_type"]["label"]
+         if (bedingungen and label == "Condition") or (not bedingungen and label != "Condition"):
+             entities.append({
+                 'start': entity['start'],
+                 'end': entity["end"],
+                 'label': legend_mapping[label]
+             })
+     return [{'text': paragraph['text'], 'ents': entities}]
+
+
+ def request_extractor(text_data):
+     try:
+         data = SampleList(
+             samples=[
+                 Sample(
+                     idx=0,
+                     text=str(text_data),
+                     entities=[],
+                     relations=[]
+                 )
+             ]
+         )
+         tokenizer.run(data)
+
+         model_inference = TransformersInference(inference_config)
+         model_inference.run_inference(data)
+         return data.dict()["samples"][0]
+     except Exception as e:
+         result = e
+         return {"text":"error","entities":[], "relations":[]}
+
+
+ def generate_graph(nodes, edges, send_request=False):
+     net = Network(height="450px", width="100%")#, bgcolor="#222222", font_color="white", select_menu=True, filter_menu=True)
+     for node in nodes:
+         if "id" in node:
+             label = node["entity"] if not send_request else node["ent_type"]["label"]
+             node_color = legend_options["colors"][legend_mapping[label]]
+             node_label = node["text"] if len(node["text"]) < 30 else (node["text"][:27]+" ...")
+             if label in ["Kennzahl", "Kennzahlumschreibung"]:
+                 net.add_node(node["id"], label=node_label, title=node["text"], mass=2, shape="ellipse", color=node_color, physics=False)
+             else:
+                 net.add_node(node["id"], label=node_label, title=node["text"], mass=1, shape="ellipse", color=node_color)
+     for edge in edges:
+         label = edge["relation"] if not send_request else edge["rel_type"]["label"]
+         net.add_edge(edge["head"], edge["tail"], width=1, title=label, arrowStrikethrough=False, color=edge_colors[label])
+     # net.force_atlas_2based() # barnes_hut() force_atlas_2based() hrepulsion() repulsion()
+     net.toggle_physics(True)
+     net.set_edge_smooth("dynamic") # dynamic, continuous, discrete, diagonalCross, straightCross, horizontal, vertical, curvedCW, curvedCCW, cubicBezier
+     net.set_options(graph_options)
+     html_graph = net.generate_html()
+     return html_graph
+
+ ############################################################
+ ## Page configuration
+ ############################################################
+ st.set_page_config(
+     page_title="NLP Gesetzestexte",
+     menu_items={
+         'Get Help': None,
+         'Report a bug': None,
+         'About': "## Demonstrator NLP"
+     }
+     # layout="wide")
+ )
+
+ st.markdown(
+     f"""
+     <style>
+     .appview-container .main .block-container{{
+         {max_width_str}
+     }}
+     </style>
+     """,
+     unsafe_allow_html=True,
+ )
+
+ # radio button formatting in line
+ st.write('<style>div.row-widget.stRadio > div{flex-direction:row;justify-content: left;} </style>', unsafe_allow_html=True)
+
+ ############################################################
+ ## Page formatting
+ ############################################################
+ col3, col4 = st.columns([2.4,1.6])
+ st.write('\n')
+ st.write('\n')
+
+
+ with col3:
+     st.subheader("Extraction of Key Figures")
+     st.write("Demonstrator Application for Paper 'Semantic Extraction of Key Figures and Their Properties From Tax Legal Texts using Neural Models'")
+ with col4:
+     st.caption("Semantic Model")
+     image = Image.open('util/ontology.png')
+     st.image(image, width=350)
+
+
+ text_option = st.radio("Select Example", ["Insert your paragraph", "EStG § 66 Kindergeld", "EStG § 9 Werbungskosten"])
+ st.write('\n')
+ if text_option == "EStG § 66 Kindergeld":
+     text_area_input = st.text_area("Given paragraph", SAMPLE_66, height=200)
+ elif text_option == "EStG § 9 Werbungskosten":
+     text_area_input = st.text_area("Given paragraph", SAMPLE_9, height=200)
+ else:
+     text_area_input = st.text_area("Given paragraph", "", height=200)
+
+ if st.button("Start Extraction") and text_area_input != "":
+     with st.spinner('Executing Extraction ...'):
+         paragraph = request_extractor(text_area_input)
+     if paragraph["text"] == "error":
+         st.error("Error while executing extraction.")
+     else:
+         legend = displacy.render([legend_content], style="ent", options=legend_options, manual=True)
+         st.write(f"{style}{get_html(legend, True)}", unsafe_allow_html=True)
+
+         st.caption("Entities:")
+         extracted_data = get_displacy_ent_obj(paragraph, False, True)
+         html = displacy.render(extracted_data, style="ent", options=legend_options, manual=True)
+         st.write(f"{style}{get_html(html)}", unsafe_allow_html=True)
+
+         st.write('\n')
+         st.caption("Conditions:")
+         extracted_data = get_displacy_ent_obj(paragraph, True, True)
+         html = displacy.render(extracted_data, style="ent", options=legend_options, manual=True)
+         st.write(f"{style}{get_html(html)}", unsafe_allow_html=True)
+
+         st.write('\n')
+         st.caption("\n\nRelations:")
+         html_graph_req = generate_graph(paragraph["entities"], paragraph["relations"], send_request=True)
+         components.html(html_graph_req, height=500)
+         st.write('\n')
+         with st.expander("Show JSON"):
+             st.json(paragraph)
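
The Streamlit layer above is a thin wrapper around the pipeline modules added below (util/ and model_inference.py). As a minimal headless sketch (not part of the commit, using only the names imported above), the flow that request_extractor() implements can also be run without the UI:

from util.configuration import InferenceConfiguration
from util.process_data import Sample, SampleList
from util.tokenizer import Tokenizer
from model_inference import TransformersInference

text = "Das Kindergeld beträgt monatlich für das erste und zweite Kind jeweils 219 Euro."
config = InferenceConfiguration()                   # defaults: danielsteinigen/KeyFiTax, de_core_news_sm
data = SampleList(samples=[Sample(idx=0, text=text, entities=[], relations=[])])
Tokenizer(config.spacy_model).run(data)             # fills sample.tokens in place
TransformersInference(config).run_inference(data)   # fills entities, relations and tags
print(data.dict()["samples"][0])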
classification.json ADDED
@@ -0,0 +1,83 @@
+ {
+   "entity_types": [
+     {
+       "idx": 0,
+       "label": "O"
+     },
+     {
+       "idx": 1,
+       "label": "StatedKeyFigure"
+     },
+     {
+       "idx": 2,
+       "label": "Condition"
+     },
+     {
+       "idx": 3,
+       "label": "StatedExpression"
+     },
+     {
+       "idx": 4,
+       "label": "Unit"
+     },
+     {
+       "idx": 5,
+       "label": "Range"
+     },
+     {
+       "idx": 6,
+       "label": "DeclarativeKeyFigure"
+     },
+     {
+       "idx": 7,
+       "label": "Factor"
+     },
+     {
+       "idx": 8,
+       "label": "DeclarativeExpression"
+     }
+   ],
+   "relation_types": [
+     {
+       "idx": 9,
+       "label": "hasCondition"
+     },
+     {
+       "idx": 10,
+       "label": "hasExpression"
+     },
+     {
+       "idx": 11,
+       "label": "hasUnit"
+     },
+     {
+       "idx": 12,
+       "label": "join"
+     },
+     {
+       "idx": 13,
+       "label": "hasRange"
+     },
+     {
+       "idx": 14,
+       "label": "hasFactor"
+     }
+   ],
+   "id_of_non_entity": 0,
+   "groups": [
+     [
+       0,
+       2
+     ],
+     [
+       0,
+       1,
+       3,
+       4,
+       5,
+       6,
+       7,
+       8
+     ]
+   ]
+ }
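
For context, model_inference.py below parses this file into the EntityTypeSet model from util/process_data.py and runs one tagging pass per entry in groups (conditions first, then the remaining types). A minimal loading sketch, assuming the files from this commit are on the import path:

import json
from util.process_data import EntityTypeSet

with open("classification.json", encoding="utf-8") as f:
    type_set = EntityTypeSet.parse_obj(json.load(f))

print(len(type_set))              # 15: nine entity types plus six relation types
print(type_set.id_of_non_entity)  # 0, the "O" tag
print(type_set.groups)            # [[0, 2], [0, 1, 3, 4, 5, 6, 7, 8]] -> one labeling pass each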
model_inference.py ADDED
@@ -0,0 +1,380 @@
+ import json
+ import logging
+ from typing import List, Any
+ import copy
+
+ import torch
+ from torch.utils.data import Dataset
+ from transformers import AutoTokenizer, AutoModelForTokenClassification, Trainer
+
+ from util.process_data import Sample, Entity, EntityType, EntityTypeSet, SampleList, Token, Relation
+ from util.configuration import InferenceConfiguration
+
+ valid_relations = { # head : [tail, ...]
+     "StatedKeyFigure": ["StatedKeyFigure", "Condition", "StatedExpression", "DeclarativeExpression"],
+     "DeclarativeKeyFigure": ["DeclarativeKeyFigure", "Condition", "StatedExpression", "DeclarativeExpression"],
+     "StatedExpression": ["Unit", "Factor", "Range", "Condition"],
+     "DeclarativeExpression": ["DeclarativeExpression", "Unit", "Factor", "Range", "Condition"],
+     "Condition": ["Condition", "StatedExpression", "DeclarativeExpression"],
+     "Range": ["Range"]
+ }
+
+ class TokenClassificationDataset(Dataset):
+     """ PyTorch Dataset """
+
+     def __init__(self, encodings, labels):
+         self.encodings = encodings
+         self.labels = labels
+
+     def __getitem__(self, idx):
+         item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
+         item['labels'] = torch.tensor(self.labels[idx])
+         return item
+
+     def __len__(self):
+         return len(self.labels)
+
+
+ class TransformersInference():
+
+     def __init__(self, config: InferenceConfiguration):
+         super().__init__()
+         self.__logger = logging.getLogger(self.__class__.__name__)
+         self.__logger.info(f"Load Configuration: {config.dict()}")
+
+         with open(f"classification.json", mode='r', encoding="utf-8") as f:
+             self.__entity_type_set = EntityTypeSet.parse_obj(json.load(f))
+         self.__entity_type_label_to_id_mapping = {x.label: x.idx for x in self.__entity_type_set.all_types()}
+         self.__entity_type_id_to_label_mapping = {x.idx: x.label for x in self.__entity_type_set.all_types()}
+
+         self.__logger.info("Load Model: " + config.model_path_keyfigure)
+         self.__tokenizer = AutoTokenizer.from_pretrained(config.transformer_model,
+                                                          padding="max_length", max_length=512, truncation=True)
+
+         self.__model = AutoModelForTokenClassification.from_pretrained(config.model_path_keyfigure, num_labels=(
+             len(self.__entity_type_set)))
+
+         self.__trainer = Trainer(model=self.__model)
+         self.__merge_entities = config.merge_entities
+         self.__split_len = config.split_len
+         self.__extract_relations = config.extract_relations
+
+         # add special tokens
+         entity_groups = self.__entity_type_set.groups
+         num_entity_groups = len(entity_groups)
+
+         lst_special_tokens = ["[REL]", "[SUB]", "[/SUB]", "[OBJ]", "[/OBJ]"]
+         for grp_idx, grp in enumerate(entity_groups):
+             lst_special_tokens.append(f"[GRP-{grp_idx:02d}]")
+             lst_special_tokens.extend([f"[ENT-{ent:02d}]" for ent in grp if ent != self.__entity_type_set.id_of_non_entity])
+             lst_special_tokens.extend([f"[/ENT-{ent:02d}]" for ent in grp if ent != self.__entity_type_set.id_of_non_entity])
+
+         lst_special_tokens = sorted(list(set(lst_special_tokens)))
+         special_tokens_dict = {'additional_special_tokens': lst_special_tokens }
+         num_added_toks = self.__tokenizer.add_special_tokens(special_tokens_dict)
+         self.__logger.info(f"Added {num_added_toks} new special tokens. All special tokens: '{self.__tokenizer.all_special_tokens}'")
+
+         self.__logger.info("Initialization completed.")
+
+
+
+     def run_inference(self, sample_list: SampleList):
+         group_predictions = []
+         group_entity_ids = []
+         self.__logger.info("Predict Entities ...")
+         for grp_idx, grp in enumerate(self.__entity_type_set.groups):
+             token_lists = [[x.text for x in sample.tokens] for sample in sample_list.samples]
+             predictions = self.__get_predictions(token_lists, f"[GRP-{grp_idx:02d}]")
+             group_entity_ids_ = []
+             for sample, prediction_per_tokens in zip(sample_list.samples, predictions):
+                 group_entity_ids_.append(self.generate_response_entities(sample, prediction_per_tokens, grp_idx))
+             group_predictions.append(predictions)
+             group_entity_ids.append(group_entity_ids_)
+
+         if self.__extract_relations:
+             self.__logger.info("Predict Relations ...")
+             self.__do_extract_relations(sample_list, group_predictions, group_entity_ids)
+
+
+     def __do_extract_relations(self, sample_list, group_predictions, group_entity_ids):
+         id_of_non_entity = self.__entity_type_set.id_of_non_entity
+
+         for sample_idx, sample in enumerate(sample_list.samples):
+             masked_tokens = []
+             masked_tokens_align = []
+             # create SUB-Mask for every entity that can be a head
+             head_entities = [entity_ for entity_ in sample.entities if entity_.ent_type.label in list(valid_relations.keys())]
+             for entity_ in head_entities:
+                 ent_masked_tokens = []
+                 ent_masked_tokens_align = []
+                 last_preds = [id_of_non_entity for group in group_predictions]
+                 last_ent_ids = [-1 for group in group_entity_ids]
+                 for token_idx, token in enumerate(sample.tokens):
+                     for group, ent_ids, last_pred, last_ent_id in zip(group_predictions, group_entity_ids, last_preds, last_ent_ids):
+                         pred = group[sample_idx][token_idx]
+                         ent_id = ent_ids[sample_idx][token_idx]
+                         if last_pred != pred and last_pred != id_of_non_entity:
+                             mask = "[/SUB]" if last_ent_id == entity_.id else "[/OBJ]"
+                             ent_masked_tokens.extend([f"[/ENT-{last_pred:02d}]", mask])
+                             ent_masked_tokens_align.extend([str(last_ent_id), str(last_ent_id)])
+
+                     for group, ent_ids, last_pred, last_ent_id in zip(group_predictions, group_entity_ids, last_preds, last_ent_ids):
+                         pred = group[sample_idx][token_idx]
+                         ent_id = ent_ids[sample_idx][token_idx]
+                         if last_pred != pred and pred != id_of_non_entity:
+                             mask = "[SUB]" if ent_id == entity_.id else "[OBJ]"
+                             ent_masked_tokens.extend([mask, f"[ENT-{pred:02d}]"])
+                             ent_masked_tokens_align.extend([str(ent_id), str(ent_id)])
+
+                     ent_masked_tokens.append(token.text)
+                     ent_masked_tokens_align.append(token.text)
+                     for idx, group in enumerate(group_predictions):
+                         last_preds[idx] = group[sample_idx][token_idx]
+                     for idx, group in enumerate(group_entity_ids):
+                         last_ent_ids[idx] = group[sample_idx][token_idx]
+
+                 for group, ent_ids, last_pred, last_ent_id in zip(group_predictions, group_entity_ids, last_preds, last_ent_ids):
+                     pred = group[sample_idx][token_idx]
+                     ent_id = ent_ids[sample_idx][token_idx]
+                     if last_pred != id_of_non_entity:
+                         mask = "[/SUB]" if last_ent_id == entity_.id else "[/OBJ]"
+                         ent_masked_tokens.extend([f"[/ENT-{last_pred:02d}]", mask])
+                         ent_masked_tokens_align.extend([str(last_ent_id), str(last_ent_id)])
+
+                 masked_tokens.append(ent_masked_tokens)
+                 masked_tokens_align.append(ent_masked_tokens_align)
+
+             rel_predictions = self.__get_predictions(masked_tokens, "[REL]")
+             self.generate_response_relations(sample, head_entities, masked_tokens_align, rel_predictions)
+
+
+     def generate_response_entities(self, sample: Sample, predictions_per_tokens: List[int], grp_idx: int):
+         entities = []
+         entity_ids = []
+         id_of_non_entity = self.__entity_type_set.id_of_non_entity
+         idx = grp_idx * 1000
+         for token, prediction in zip(sample.tokens, predictions_per_tokens):
+             if id_of_non_entity == prediction:
+                 entity_ids.append(-1)
+                 continue
+             idx += 1
+             entities.append(self.__build_entity(idx, prediction, token))
+             entity_ids.append(idx)
+
+         if self.__merge_entities:
+             entities = self.__do_merge_entities(copy.deepcopy(entities))
+             prev_pred = id_of_non_entity
+             for idx, pred in enumerate(predictions_per_tokens):
+                 if prev_pred == pred and idx > 0:
+                     entity_ids[idx] = entity_ids[idx-1]
+                 prev_pred = pred
+
+         sample.entities += entities
+
+         tags = sample.tags if len(sample.tags) > 0 else [self.__entity_type_set.id_of_non_entity] * len(sample.tokens)
+         for tag_id, tok in enumerate(sample.tokens):
+             for ent in entities:
+                 if tok.start >= ent.start and tok.start < ent.end:
+                     tags[tag_id] = ent.ent_type.idx
+         logging.info(tags)
+         sample.tags = tags
+
+         return entity_ids
+
+
+     def generate_response_relations(self, sample: Sample, head_entities: List[Entity], masked_tokens_align: List[List[str]], rel_predictions: List[List[int]]):
+         relations = []
+         id_of_non_entity = self.__entity_type_set.id_of_non_entity
+         idx = 0
+         for entity_, align_per_ent, prediction_per_ent in zip(head_entities, masked_tokens_align, rel_predictions):
+             for token, prediction in zip(align_per_ent, prediction_per_ent):
+                 if id_of_non_entity == prediction:
+                     continue
+                 try:
+                     tail = int(token)
+                 except ValueError:
+                     continue
+                 if not self.__validate_relation(sample.entities, entity_.id, tail, prediction):
+                     continue
+                 idx += 1
+                 relations.append(self.__build_relation(idx, entity_.id, tail, prediction))
+
+         sample.relations = relations
+
+
+     def __validate_relation(self, entities: List[Entity], head: int, tail: int, prediction: int):
+         if head == tail: return False
+         head_ents = [ent.ent_type.label for ent in entities if ent.id==head]
+         tail_ents = [ent.ent_type.label for ent in entities if ent.id==tail]
+
+         if len(head_ents) > 0:
+             head_ent = head_ents[0]
+         else:
+             return False
+
+         if len(tail_ents) > 0:
+             tail_ent = tail_ents[0]
+         else:
+             return False
+
+         return tail_ent in valid_relations[head_ent]
+
+
+     def __build_entity(self, idx: int, prediction: int, token: Token) -> Entity:
+         return Entity(
+             id=idx,
+             text=token.text,
+             start=token.start,
+             end=token.end,
+             ent_type=EntityType(
+                 idx=prediction,
+                 label=self.__entity_type_id_to_label_mapping[prediction]
+             )
+         )
+
+     def __build_relation(self, idx: int, head: int, tail: int, prediction: int) -> Relation:
+         return Relation(
+             id=idx,
+             head=head,
+             tail=tail,
+             rel_type=EntityType(
+                 idx=prediction,
+                 label=self.__entity_type_id_to_label_mapping[prediction]
+             )
+         )
+
+     def __do_merge_entities(self, input_ents_):
+         out_ents = list()
+         current_ent = None
+
+         for ent in input_ents_:
+             if current_ent is None:
+                 current_ent = ent
+             else:
+                 idx_diff = ent.start - current_ent.end
+                 if ent.ent_type.idx == current_ent.ent_type.idx and idx_diff <= 1:
+                     current_ent.end = ent.end
+                     current_ent.text += (" " if idx_diff == 1 else "") + ent.text
+                 else:
+                     out_ents.append(current_ent)
+                     current_ent = ent
+
+         if current_ent is not None:
+             out_ents.append(current_ent)
+
+         return out_ents
+
+
+     def __get_predictions(self, token_lists: List[List[str]], trigger: str) -> List[List[int]]:
+         """ Get predictions of Transformer Sequence Labeling model """
+         if self.__split_len > 0:
+             token_lists_split = self.__do_split_sentences(token_lists, self.__split_len)
+             predictions = []
+             for sample_token_lists in token_lists_split:
+                 sample_token_lists_trigger = [[trigger]+sample for sample in sample_token_lists]
+                 val_encodings = self.__tokenizer(sample_token_lists_trigger, is_split_into_words=True, padding='max_length', truncation=True) # return_tensors="pt"
+                 val_labels = []
+                 for i in range(len(sample_token_lists_trigger)):
+                     word_ids = val_encodings.word_ids(batch_index=i)
+                     label_ids = [0 for _ in word_ids]
+                     val_labels.append(label_ids)
+
+                 val_dataset = TokenClassificationDataset(val_encodings, val_labels)
+
+                 predictions_raw, _, _ = self.__trainer.predict(val_dataset)
+
+                 predictions_align = self.__align_predictions(predictions_raw, val_encodings)
+                 confidence = [[max(token) for token in sample] for sample in predictions_align]
+                 predictions_sample = [[token.index(max(token)) for token in sample][1:] for sample in predictions_align]
+                 predictions_part = []
+                 for tok, pred in zip(sample_token_lists_trigger, predictions_sample):
+                     if trigger == "[REL]" and "[SUB]" not in tok:
+                         predictions_part += [self.__entity_type_set.id_of_non_entity] * len(pred)
+                     else:
+                         predictions_part += pred
+                 predictions.append(predictions_part)
+                 # predictions.append([j for i in predictions_sample for j in i]))
+         else:
+             token_lists_trigger = [[trigger]+sample for sample in token_lists]
+             val_encodings = self.__tokenizer(token_lists_trigger, is_split_into_words=True, padding='max_length', truncation=True) # return_tensors="pt"
+             val_labels = []
+             for i in range(len(token_lists_trigger)):
+                 word_ids = val_encodings.word_ids(batch_index=i)
+                 label_ids = [0 for _ in word_ids]
+                 val_labels.append(label_ids)
+
+             val_dataset = TokenClassificationDataset(val_encodings, val_labels)
+
+             predictions_raw, _, _ = self.__trainer.predict(val_dataset)
+
+             predictions_align = self.__align_predictions(predictions_raw, val_encodings)
+             confidence = [[max(token) for token in sample] for sample in predictions_align]
+             predictions = [[token.index(max(token)) for token in sample][1:] for sample in predictions_align]
+
+         return predictions
+
+     def __do_split_sentences(self, tokens_: List[List[str]], split_len_ = 200) -> List[List[List[str]]]:
+         # split token lists into shorter lists
+         res_tokens = []
+
+         for tok_lst in tokens_:
+             res_tokens_sample = []
+             length = len(tok_lst)
+             if length > split_len_:
+                 num_lists = length // split_len_ + (1 if (length % split_len_) > 0 else 0)
+                 new_length = int(length / num_lists) + 1
+                 self.__logger.info(f"Splitting a list of {length} elements into {num_lists} lists of length {new_length}..")
+                 start_idx = 0
+                 for i in range(num_lists):
+                     end_idx = min(start_idx + new_length, length)
+                     if "\n" in tok_lst[start_idx]: tok_lst[start_idx] = "."
+                     if "\n" in tok_lst[end_idx-1]: tok_lst[end_idx-1] = "."
+                     res_tokens_sample.append(tok_lst[start_idx:end_idx])
+                     start_idx = end_idx
+
+                 res_tokens.append(res_tokens_sample)
+             else:
+                 res_tokens.append([tok_lst])
+
+         return res_tokens
+
+
+     def __align_predictions(self, predictions, tokenized_inputs, sum_all_tokens=False) -> List[List[List[float]]]:
+         """ Align predicted labels from Transformer Tokenizer """
+         confidence = []
+         id_of_non_entity = self.__entity_type_set.id_of_non_entity
+         for i, tagset in enumerate(predictions):
+
+             word_ids = tokenized_inputs.word_ids(batch_index=i)
+
+             previous_word_idx = None
+             token_confidence = []
+             for k, word_idx in enumerate(word_ids):
+                 try:
+                     tok_conf = [value for value in tagset[k]]
+                 except TypeError:
+                     # use the object itself if it's not iterable
+                     tok_conf = tagset[k]
+
+                 if word_idx is not None:
+                     # add non-entity tokens if there is a gap in word ids (usually caused by a newline token)
+                     if previous_word_idx is not None:
+                         diff = word_idx - previous_word_idx
+                         for _ in range(diff - 1):
+                             tmp = [0 for _ in tok_conf]
+                             tmp[id_of_non_entity] = 1.0
+                             token_confidence.append(tmp)
+
+                     # add confidence value if this is the first token of the word
+                     if word_idx != previous_word_idx:
+                         token_confidence.append(tok_conf)
+                     else:
+                         # if sum_all_tokens=True the confidence for all tokens of one word will be summarized
+                         if sum_all_tokens:
+                             token_confidence[-1] = [a + b for a, b in zip(token_confidence[-1], tok_conf)]
+
+                 previous_word_idx = word_idx
+
+             confidence.append(token_confidence)
+
+         return confidence
requirements.txt ADDED
@@ -0,0 +1,11 @@
+ # Python 3.8.11
+ numpy~=1.23.5
+ PyYAML~=5.4.1
+ pydantic==1.8.2
+ tqdm~=4.56.2
+ scikit-learn~=0.24.2
+ spacy==3.2.0
+ # MarkupSafe==2.0.1
+ torch==1.6.0
+ transformers[sentencepiece]==4.26.1
+ pyvis==0.3.2
util/__init__.py ADDED
File without changes
util/configuration.py ADDED
@@ -0,0 +1,9 @@
+ from pydantic import BaseModel
+
+ class InferenceConfiguration(BaseModel):
+     model_path_keyfigure: str = "danielsteinigen/KeyFiTax"
+     spacy_model: str = "de_core_news_sm"
+     transformer_model: str = "xlm-roberta-large"
+     merge_entities: bool = True
+     split_len: int = 200
+     extract_relations: bool = True
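
Since InferenceConfiguration is a pydantic BaseModel, the defaults used by app.py can be overridden at construction time without editing this file. A small illustrative sketch (the values shown are arbitrary):

from util.configuration import InferenceConfiguration

# Unspecified fields keep the defaults defined above.
config = InferenceConfiguration(split_len=120, extract_relations=False)
print(config.dict())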
util/ontology.png ADDED
util/process_data.py ADDED
@@ -0,0 +1,55 @@
+ from typing import Optional, List
+
+ from pydantic import BaseModel, Extra
+
+ class EntityType(BaseModel):
+     idx: int
+     label: str
+
+
+ class EntityTypeSet(BaseModel):
+     entity_types: List[EntityType]
+     relation_types: List[EntityType]
+     id_of_non_entity: int
+     groups: List[List[int]]
+
+     def __len__(self):
+         return len(self.entity_types) + len(self.relation_types)
+
+     def all_types(self):
+         return [*self.entity_types, *self.relation_types]
+
+
+ class Token(BaseModel):
+     text: str
+     start: int
+     end: int
+
+
+ class Entity(BaseModel):
+     id: int
+     text: str
+     start: int
+     end: int
+     ent_type: EntityType
+     confidence: Optional[float]
+
+
+ class Relation(BaseModel):
+     id: int
+     head: int
+     tail: int
+     rel_type: EntityType
+
+
+ class Sample(BaseModel):
+     idx: int
+     text: str
+     entities: List[Entity] = []
+     relations: List[Relation] = []
+     tokens: List[Token] = []
+     tags: List[int] = []
+
+
+ class SampleList(BaseModel):
+     samples: List[Sample]
util/tokenizer.py ADDED
@@ -0,0 +1,24 @@
+ from typing import List
+
+ import spacy
+
+ from util.process_data import Token, Sample, SampleList
+
+ class Tokenizer():
+
+     def __init__(self, spacy_model: str):
+         self.__spacy_model = spacy.load(spacy_model)
+
+     def run(self, sample_list: SampleList):
+         self.__tokenize(sample_list.samples, self.__spacy_model)
+
+     def __tokenize(self, samples: List[Sample], spacy_model):
+         doc_pipe = spacy_model.pipe([sample.text.replace('\xa0', ' ') for sample in samples])
+         for sample, doc in zip(samples, doc_pipe):
+             sample.tokens = [Token(
+                 text=x.text,
+                 start=x.idx,
+                 end=x.idx + len(x.text)
+             ) for x in doc]
+             while '\n' in sample.tokens[-1].text or ' ' in sample.tokens[-1].text:
+                 sample.tokens = sample.tokens[:-1]