Spaces: Runtime error
Daniel Steinigen committed
Commit a50f42c • 1 Parent(s): 980b30f
add demonstrator
- README.md +1 -1
- app.py +209 -0
- classification.json +83 -0
- model_inference.py +380 -0
- requirements.txt +11 -0
- util/__init__.py +0 -0
- util/configuration.py +9 -0
- util/ontology.png +0 -0
- util/process_data.py +55 -0
- util/tokenizer.py +24 -0
README.md
CHANGED
@@ -1,6 +1,6 @@
 ---
 title: NLP Legal Texts
-emoji:
+emoji: ⚖
 colorFrom: red
 colorTo: gray
 sdk: streamlit
app.py
ADDED
@@ -0,0 +1,209 @@
import os
import streamlit as st
from spacy import displacy
from PIL import Image
import json
import requests
from pyvis.network import Network
import streamlit.components.v1 as components

from util.process_data import Entity, EntityType, Relation, Sample, SampleList
from util.tokenizer import Tokenizer
from model_inference import TransformersInference
from util.configuration import InferenceConfiguration

inference_config = InferenceConfiguration()
tokenizer = Tokenizer(inference_config.spacy_model)

SAMPLE_66 = "EStG § 66 Höhe des Kindergeldes, Zahlungszeitraum (1) Das Kindergeld beträgt monatlich für das erste und zweite Kind jeweils 219 Euro, für das dritte Kind 225 Euro und für das vierte und jedes weitere Kind jeweils 250 Euro."
SAMPLE_9 = "EStG § 9 Werbungskosten ... Zur Abgeltung dieser Aufwendungen ist für jeden Arbeitstag, an dem der Arbeitnehmer die erste Tätigkeitsstätte aufsucht eine Entfernungspauschale für jeden vollen Kilometer der Entfernung zwischen Wohnung und erster Tätigkeitsstätte von 0,30 Euro anzusetzen, höchstens jedoch 4 500 Euro im Kalenderjahr; ein höherer Betrag als 4 500 Euro ist anzusetzen, soweit der Arbeitnehmer einen eigenen oder ihm zur Nutzung überlassenen Kraftwagen benutzt."


############################################################
## Constants
############################################################
max_width_str = f"max-width: 60%;"
paragraph = None
style = "<style>mark.entity { display: inline-block }</style>"
graph_options = '''
var options = {
  "edges": {
    "arrows": {
      "to": {
        "enabled": true,
        "scaleFactor": 1.2
      }
    }
  }
}
'''

legend_content = {
    "text": "StatedKeyFigure StatedExpression Unit Range Factor Condition DeclarativeKeyFigure DeclarativeExpression",
    "ents": [
        {"start": 0, "end": 15, "label": "K"},
        {"start": 16, "end": 32, "label": "E"},
        {"start": 33, "end": 37, "label": "U"},
        {"start": 38, "end": 43, "label": "R"},
        {"start": 44, "end": 50, "label": "F"},
        {"start": 51, "end": 60, "label": "C"},
        {"start": 61, "end": 81, "label": "DK"},
        {"start": 82, "end": 103, "label": "DE"},
    ]}
legend_options = {
    "ents": ["K","U","E","R","F","C","DK","DE"],
    "colors": {'K': '#46d000',"U": "#e861ef", "E": "#538cff", "R": "#ffbe00", "F": "#0fd5dc", "C":"#ff484b", "DK":"#46d000", "DE":"#538cff"}
}
legend_mapping = {"StatedKeyFigure": "K","Unit": "U","StatedExpression": "E","Range": "R","Factor": "F","Condition": "C","DeclarativeKeyFigure": "DK","DeclarativeExpression": "DE"}
edge_colors = {'hasKeyFigure': '#46d000',"hasUnit": "#e861ef", "hasExpression": "#538cff", "hasRange": "#ffbe00", "hasFactor": "#0fd5dc", "hasCondition":"#ff484b", "join":"#aaa", "Typ":"#aaa", "hasParagraph": "#FF8B15"}


############################################################
## Function definitions
############################################################

def get_html(html: str, legend=False):
    """Convert HTML so it can be rendered."""
    WRAPPER = """<div style="overflow-x: auto; border: 1px solid #e6e9ef; border-radius: 0.25rem; padding: 1.5rem">{}</div>"""
    if legend: WRAPPER = """<div style="overflow-x: auto; padding: 1rem">{}</div>"""
    # Newlines seem to mess with the rendering
    html = html.replace("\n", " ")
    return WRAPPER.format(html)


def get_displacy_ent_obj(paragraph, bedingungen=False, send_request=False):
    entities = []
    for entity in paragraph['entities']:
        label = entity["entity"] if not send_request else entity["ent_type"]["label"]
        if (bedingungen and label == "Condition") or (not bedingungen and label != "Condition"):
            entities.append({
                'start': entity['start'],
                'end': entity["end"],
                'label': legend_mapping[label]
            })
    return [{'text': paragraph['text'], 'ents': entities}]


def request_extractor(text_data):
    try:
        data = SampleList(
            samples=[
                Sample(
                    idx=0,
                    text=str(text_data),
                    entities=[],
                    relations=[]
                )
            ]
        )
        tokenizer.run(data)

        model_inference = TransformersInference(inference_config)
        model_inference.run_inference(data)
        return data.dict()["samples"][0]
    except Exception as e:
        result = e
        return {"text":"error","entities":[], "relations":[]}


def generate_graph(nodes, edges, send_request=False):
    net = Network(height="450px", width="100%")  # , bgcolor="#222222", font_color="white", select_menu=True, filter_menu=True)
    for node in nodes:
        if "id" in node:
            label = node["entity"] if not send_request else node["ent_type"]["label"]
            node_color = legend_options["colors"][legend_mapping[label]]
            node_label = node["text"] if len(node["text"]) < 30 else (node["text"][:27]+" ...")
            if label in ["Kennzahl", "Kennzahlumschreibung"]:
                net.add_node(node["id"], label=node_label, title=node["text"], mass=2, shape="ellipse", color=node_color, physics=False)
            else:
                net.add_node(node["id"], label=node_label, title=node["text"], mass=1, shape="ellipse", color=node_color)
    for edge in edges:
        label = edge["relation"] if not send_request else edge["rel_type"]["label"]
        net.add_edge(edge["head"], edge["tail"], width=1, title=label, arrowStrikethrough=False, color=edge_colors[label])
    # net.force_atlas_2based()  # barnes_hut() force_atlas_2based() hrepulsion() repulsion()
    net.toggle_physics(True)
    net.set_edge_smooth("dynamic")  # dynamic, continuous, discrete, diagonalCross, straightCross, horizontal, vertical, curvedCW, curvedCCW, cubicBezier
    net.set_options(graph_options)
    html_graph = net.generate_html()
    return html_graph

############################################################
## Page configuration
############################################################
st.set_page_config(
    page_title="NLP Gesetzestexte",
    menu_items={
        'Get Help': None,
        'Report a bug': None,
        'About': "## Demonstrator NLP"
    }
    # layout="wide")
)

st.markdown(
    f"""
    <style>
    .appview-container .main .block-container{{
        {max_width_str}
    }}
    </style>
    """,
    unsafe_allow_html=True,
)

# radio button formatting in line
st.write('<style>div.row-widget.stRadio > div{flex-direction:row;justify-content: left;} </style>', unsafe_allow_html=True)

############################################################
## Page formating
############################################################
col3, col4 = st.columns([2.4,1.6])
st.write('\n')
st.write('\n')


with col3:
    st.subheader("Extraction of Key Figures")
    st.write("Demonstrator Application for Paper 'Semantic Extraction of Key Figures and Their Properties From Tax Legal Texts using Neural Models'")
with col4:
    st.caption("Semantic Model")
    image = Image.open('util/ontology.png')
    st.image(image, width=350)


text_option = st.radio("Select Example", ["Insert your paragraph", "EStG § 66 Kindergeld", "EStG § 9 Werbungskosten"])
st.write('\n')
if text_option == "EStG § 66 Kindergeld":
    text_area_input = st.text_area("Given paragraph", SAMPLE_66, height=200)
elif text_option == "EStG § 9 Werbungskosten":
    text_area_input = st.text_area("Given paragraph", SAMPLE_9, height=200)
else:
    text_area_input = st.text_area("Given paragraph", "", height=200)

if st.button("Start Extraction") and text_area_input != "":
    with st.spinner('Executing Extraction ...'):
        paragraph = request_extractor(text_area_input)
        if paragraph["text"] == "error":
            st.error("Error while executing extraction.")
        else:
            legend = displacy.render([legend_content], style="ent", options=legend_options, manual=True)
            st.write(f"{style}{get_html(legend, True)}", unsafe_allow_html=True)

            st.caption("Entities:")
            extracted_data = get_displacy_ent_obj(paragraph, False, True)
            html = displacy.render(extracted_data, style="ent", options=legend_options, manual=True)
            st.write(f"{style}{get_html(html)}", unsafe_allow_html=True)

            st.write('\n')
            st.caption("Conditions:")
            extracted_data = get_displacy_ent_obj(paragraph, True, True)
            html = displacy.render(extracted_data, style="ent", options=legend_options, manual=True)
            st.write(f"{style}{get_html(html)}", unsafe_allow_html=True)

            st.write('\n')
            st.caption("\n\nRelations:")
            html_graph_req = generate_graph(paragraph["entities"], paragraph["relations"], send_request=True)
            components.html(html_graph_req, height=500)
            st.write('\n')
            with st.expander("Show JSON"):
                st.json(paragraph)
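For orientation, `get_displacy_ent_obj` and `generate_graph` only rely on a few keys of the dict returned by `request_extractor` (a serialized `util.process_data.Sample`). A minimal hand-built example of that shape; the concrete values are illustrative, not taken from the repository:

```python
# illustrative shape of the dict that request_extractor returns
paragraph = {
    "text": "Das Kindergeld beträgt monatlich 219 Euro.",
    "entities": [
        {"id": 1, "text": "Kindergeld", "start": 4, "end": 14,
         "ent_type": {"idx": 1, "label": "StatedKeyFigure"}},
        {"id": 2, "text": "219", "start": 33, "end": 36,
         "ent_type": {"idx": 3, "label": "StatedExpression"}},
    ],
    "relations": [
        {"id": 1, "head": 1, "tail": 2,
         "rel_type": {"idx": 10, "label": "hasExpression"}},
    ],
}
```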
classification.json
ADDED
@@ -0,0 +1,83 @@
{
    "entity_types": [
        { "idx": 0, "label": "O" },
        { "idx": 1, "label": "StatedKeyFigure" },
        { "idx": 2, "label": "Condition" },
        { "idx": 3, "label": "StatedExpression" },
        { "idx": 4, "label": "Unit" },
        { "idx": 5, "label": "Range" },
        { "idx": 6, "label": "DeclarativeKeyFigure" },
        { "idx": 7, "label": "Factor" },
        { "idx": 8, "label": "DeclarativeExpression" }
    ],
    "relation_types": [
        { "idx": 9, "label": "hasCondition" },
        { "idx": 10, "label": "hasExpression" },
        { "idx": 11, "label": "hasUnit" },
        { "idx": 12, "label": "join" },
        { "idx": 13, "label": "hasRange" },
        { "idx": 14, "label": "hasFactor" }
    ],
    "id_of_non_entity": 0,
    "groups": [
        [0, 2],
        [0, 1, 3, 4, 5, 6, 7, 8]
    ]
}
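This label schema is what `model_inference.py` loads at start-up via `EntityTypeSet.parse_obj`. A minimal sketch of inspecting it with the pydantic models from `util/process_data.py` (assuming the file is read relative to the repository root, as in `TransformersInference.__init__`):

```python
import json
from util.process_data import EntityTypeSet

# parse the schema with the same pydantic model used by TransformersInference
with open("classification.json", encoding="utf-8") as f:
    type_set = EntityTypeSet.parse_obj(json.load(f))

print(len(type_set))                            # 15: entity types plus relation types
print([t.label for t in type_set.all_types()])  # "O", "StatedKeyFigure", ..., "hasFactor"
print(type_set.groups)                          # entity-type groups predicted in separate passes
```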
model_inference.py
ADDED
@@ -0,0 +1,380 @@
import json
import logging
from typing import List, Any
import copy

import torch
from torch.utils.data import Dataset
from transformers import AutoTokenizer, AutoModelForTokenClassification, Trainer

from util.process_data import Sample, Entity, EntityType, EntityTypeSet, SampleList, Token, Relation
from util.configuration import InferenceConfiguration

valid_relations = {  # head : [tail, ...]
    "StatedKeyFigure": ["StatedKeyFigure", "Condition", "StatedExpression", "DeclarativeExpression"],
    "DeclarativeKeyFigure": ["DeclarativeKeyFigure", "Condition", "StatedExpression", "DeclarativeExpression"],
    "StatedExpression": ["Unit", "Factor", "Range", "Condition"],
    "DeclarativeExpression": ["DeclarativeExpression", "Unit", "Factor", "Range", "Condition"],
    "Condition": ["Condition", "StatedExpression", "DeclarativeExpression"],
    "Range": ["Range"]
}

class TokenClassificationDataset(Dataset):
    """ Pytorch Dataset """

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.labels)


class TransformersInference():

    def __init__(self, config: InferenceConfiguration):
        super().__init__()
        self.__logger = logging.getLogger(self.__class__.__name__)
        self.__logger.info(f"Load Configuration: {config.dict()}")

        with open(f"classification.json", mode='r', encoding="utf-8") as f:
            self.__entity_type_set = EntityTypeSet.parse_obj(json.load(f))
        self.__entity_type_label_to_id_mapping = {x.label: x.idx for x in self.__entity_type_set.all_types()}
        self.__entity_type_id_to_label_mapping = {x.idx: x.label for x in self.__entity_type_set.all_types()}

        self.__logger.info("Load Model: " + config.model_path_keyfigure)
        self.__tokenizer = AutoTokenizer.from_pretrained(config.transformer_model,
                                                         padding="max_length", max_length=512, truncation=True)

        self.__model = AutoModelForTokenClassification.from_pretrained(config.model_path_keyfigure, num_labels=(
            len(self.__entity_type_set)))

        self.__trainer = Trainer(model=self.__model)
        self.__merge_entities = config.merge_entities
        self.__split_len = config.split_len
        self.__extract_relations = config.extract_relations

        # add special tokens
        entity_groups = self.__entity_type_set.groups
        num_entity_groups = len(entity_groups)

        lst_special_tokens = ["[REL]", "[SUB]", "[/SUB]", "[OBJ]", "[/OBJ]"]
        for grp_idx, grp in enumerate(entity_groups):
            lst_special_tokens.append(f"[GRP-{grp_idx:02d}]")
            lst_special_tokens.extend([f"[ENT-{ent:02d}]" for ent in grp if ent != self.__entity_type_set.id_of_non_entity])
            lst_special_tokens.extend([f"[/ENT-{ent:02d}]" for ent in grp if ent != self.__entity_type_set.id_of_non_entity])

        lst_special_tokens = sorted(list(set(lst_special_tokens)))
        special_tokens_dict = {'additional_special_tokens': lst_special_tokens }
        num_added_toks = self.__tokenizer.add_special_tokens(special_tokens_dict)
        self.__logger.info(f"Added {num_added_toks} new special tokens. All special tokens: '{self.__tokenizer.all_special_tokens}'")

        self.__logger.info("Initialization completed.")


    def run_inference(self, sample_list: SampleList):
        group_predictions = []
        group_entity_ids = []
        self.__logger.info("Predict Entities ...")
        for grp_idx, grp in enumerate(self.__entity_type_set.groups):
            token_lists = [[x.text for x in sample.tokens] for sample in sample_list.samples]
            predictions = self.__get_predictions(token_lists, f"[GRP-{grp_idx:02d}]")
            group_entity_ids_ = []
            for sample, prediction_per_tokens in zip(sample_list.samples, predictions):
                group_entity_ids_.append(self.generate_response_entities(sample, prediction_per_tokens, grp_idx))
            group_predictions.append(predictions)
            group_entity_ids.append(group_entity_ids_)

        if self.__extract_relations:
            self.__logger.info("Predict Relations ...")
            self.__do_extract_relations(sample_list, group_predictions, group_entity_ids)


    def __do_extract_relations(self, sample_list, group_predictions, group_entity_ids):
        id_of_non_entity = self.__entity_type_set.id_of_non_entity

        for sample_idx, sample in enumerate(sample_list.samples):
            masked_tokens = []
            masked_tokens_align = []
            # create SUB-Mask for every entity that can be a head
            head_entities = [entity_ for entity_ in sample.entities if entity_.ent_type.label in list(valid_relations.keys())]
            for entity_ in head_entities:
                ent_masked_tokens = []
                ent_masked_tokens_align = []
                last_preds = [id_of_non_entity for group in group_predictions]
                last_ent_ids = [-1 for group in group_entity_ids]
                for token_idx, token in enumerate(sample.tokens):
                    for group, ent_ids, last_pred, last_ent_id in zip(group_predictions, group_entity_ids, last_preds, last_ent_ids):
                        pred = group[sample_idx][token_idx]
                        ent_id = ent_ids[sample_idx][token_idx]
                        if last_pred != pred and last_pred != id_of_non_entity:
                            mask = "[/SUB]" if last_ent_id == entity_.id else "[/OBJ]"
                            ent_masked_tokens.extend([f"[/ENT-{last_pred:02d}]", mask])
                            ent_masked_tokens_align.extend([str(last_ent_id), str(last_ent_id)])

                    for group, ent_ids, last_pred, last_ent_id in zip(group_predictions, group_entity_ids, last_preds, last_ent_ids):
                        pred = group[sample_idx][token_idx]
                        ent_id = ent_ids[sample_idx][token_idx]
                        if last_pred != pred and pred != id_of_non_entity:
                            mask = "[SUB]" if ent_id == entity_.id else "[OBJ]"
                            ent_masked_tokens.extend([mask, f"[ENT-{pred:02d}]"])
                            ent_masked_tokens_align.extend([str(ent_id), str(ent_id)])

                    ent_masked_tokens.append(token.text)
                    ent_masked_tokens_align.append(token.text)
                    for idx, group in enumerate(group_predictions):
                        last_preds[idx] = group[sample_idx][token_idx]
                    for idx, group in enumerate(group_entity_ids):
                        last_ent_ids[idx] = group[sample_idx][token_idx]

                for group, ent_ids, last_pred, last_ent_id in zip(group_predictions, group_entity_ids, last_preds, last_ent_ids):
                    pred = group[sample_idx][token_idx]
                    ent_id = ent_ids[sample_idx][token_idx]
                    if last_pred != id_of_non_entity:
                        mask = "[/SUB]" if last_ent_id == entity_.id else "[/OBJ]"
                        ent_masked_tokens.extend([f"[/ENT-{last_pred:02d}]", mask])
                        ent_masked_tokens_align.extend([str(last_ent_id), str(last_ent_id)])

                masked_tokens.append(ent_masked_tokens)
                masked_tokens_align.append(ent_masked_tokens_align)

            rel_predictions = self.__get_predictions(masked_tokens, "[REL]")
            self.generate_response_relations(sample, head_entities, masked_tokens_align, rel_predictions)


    def generate_response_entities(self, sample: Sample, predictions_per_tokens: List[int], grp_idx: int):
        entities = []
        entity_ids = []
        id_of_non_entity = self.__entity_type_set.id_of_non_entity
        idx = grp_idx * 1000
        for token, prediction in zip(sample.tokens, predictions_per_tokens):
            if id_of_non_entity == prediction:
                entity_ids.append(-1)
                continue
            idx += 1
            entities.append(self.__build_entity(idx, prediction, token))
            entity_ids.append(idx)

        if self.__merge_entities:
            entities = self.__do_merge_entities(copy.deepcopy(entities))
            prev_pred = id_of_non_entity
            for idx, pred in enumerate(predictions_per_tokens):
                if prev_pred == pred and idx > 0:
                    entity_ids[idx] = entity_ids[idx-1]
                prev_pred = pred

        sample.entities += entities

        tags = sample.tags if len(sample.tags) > 0 else [self.__entity_type_set.id_of_non_entity] * len(sample.tokens)
        for tag_id, tok in enumerate(sample.tokens):
            for ent in entities:
                if tok.start >= ent.start and tok.start < ent.end:
                    tags[tag_id] = ent.ent_type.idx
        logging.info(tags)
        sample.tags = tags

        return entity_ids


    def generate_response_relations(self, sample: Sample, head_entities: List[Entity], masked_tokens_align: List[List[str]], rel_predictions: List[List[int]]):
        relations = []
        id_of_non_entity = self.__entity_type_set.id_of_non_entity
        idx = 0
        for entity_, align_per_ent, prediction_per_ent in zip(head_entities, masked_tokens_align, rel_predictions):
            for token, prediction in zip(align_per_ent, prediction_per_ent):
                if id_of_non_entity == prediction:
                    continue
                try:
                    tail = int(token)
                except:
                    continue
                if not self.__validate_relation(sample.entities, entity_.id, tail, prediction):
                    continue
                idx += 1
                relations.append(self.__build_relation(idx, entity_.id, tail, prediction))

        sample.relations = relations


    def __validate_relation(self, entities: List[Entity], head: int, tail: int, prediction: int):
        if head == tail: return False
        head_ents = [ent.ent_type.label for ent in entities if ent.id==head]
        tail_ents = [ent.ent_type.label for ent in entities if ent.id==tail]

        if len(head_ents) > 0:
            head_ent = head_ents[0]
        else:
            return False

        if len(tail_ents) > 0:
            tail_ent = tail_ents[0]
        else:
            return False

        return tail_ent in valid_relations[head_ent]


    def __build_entity(self, idx: int, prediction: int, token: Token) -> Entity:
        return Entity(
            id=idx,
            text=token.text,
            start=token.start,
            end=token.end,
            ent_type=EntityType(
                idx=prediction,
                label=self.__entity_type_id_to_label_mapping[prediction]
            )
        )

    def __build_relation(self, idx: int, head: int, tail: int, prediction: int) -> Relation:
        return Relation(
            id=idx,
            head=head,
            tail=tail,
            rel_type=EntityType(
                idx=prediction,
                label=self.__entity_type_id_to_label_mapping[prediction]
            )
        )

    def __do_merge_entities(self, input_ents_):
        out_ents = list()
        current_ent = None

        for ent in input_ents_:
            if current_ent is None:
                current_ent = ent
            else:
                idx_diff = ent.start - current_ent.end
                if ent.ent_type.idx == current_ent.ent_type.idx and idx_diff <= 1:
                    current_ent.end = ent.end
                    current_ent.text += (" " if idx_diff == 1 else "") + ent.text
                else:
                    out_ents.append(current_ent)
                    current_ent = ent

        if current_ent is not None:
            out_ents.append(current_ent)

        return out_ents


    def __get_predictions(self, token_lists: List[List[str]], trigger: str) -> List[List[int]]:
        """ Get predictions of Transformer Sequence Labeling model """
        if self.__split_len > 0:
            token_lists_split = self.__do_split_sentences(token_lists, self.__split_len)
            predictions = []
            for sample_token_lists in token_lists_split:
                sample_token_lists_trigger = [[trigger]+sample for sample in sample_token_lists]
                val_encodings = self.__tokenizer(sample_token_lists_trigger, is_split_into_words=True, padding='max_length', truncation=True)  # return_tensors="pt"
                val_labels = []
                for i in range(len(sample_token_lists_trigger)):
                    word_ids = val_encodings.word_ids(batch_index=i)
                    label_ids = [0 for _ in word_ids]
                    val_labels.append(label_ids)

                val_dataset = TokenClassificationDataset(val_encodings, val_labels)

                predictions_raw, _, _ = self.__trainer.predict(val_dataset)

                predictions_align = self.__align_predictions(predictions_raw, val_encodings)
                confidence = [[max(token) for token in sample] for sample in predictions_align]
                predictions_sample = [[token.index(max(token)) for token in sample][1:] for sample in predictions_align]
                predictions_part = []
                for tok, pred in zip(sample_token_lists_trigger, predictions_sample):
                    if trigger == "[REL]" and "[SUB]" not in tok:
                        predictions_part += [self.__entity_type_set.id_of_non_entity] * len(pred)
                    else:
                        predictions_part += pred
                predictions.append(predictions_part)
                # predictions.append([j for i in predictions_sample for j in i]))
        else:
            token_lists_trigger = [[trigger]+sample for sample in token_lists]
            val_encodings = self.__tokenizer(token_lists_trigger, is_split_into_words=True, padding='max_length', truncation=True)  # return_tensors="pt"
            val_labels = []
            for i in range(len(token_lists_trigger)):
                word_ids = val_encodings.word_ids(batch_index=i)
                label_ids = [0 for _ in word_ids]
                val_labels.append(label_ids)

            val_dataset = TokenClassificationDataset(val_encodings, val_labels)

            predictions_raw, _, _ = self.__trainer.predict(val_dataset)

            predictions_align = self.__align_predictions(predictions_raw, val_encodings)
            confidence = [[max(token) for token in sample] for sample in predictions_align]
            predictions = [[token.index(max(token)) for token in sample][1:] for sample in predictions_align]

        return predictions

    def __do_split_sentences(self, tokens_: List[List[str]], split_len_ = 200) -> List[List[List[str]]]:
        # split token lists into shorter lists
        res_tokens = []

        for tok_lst in tokens_:
            res_tokens_sample = []
            length = len(tok_lst)
            if length > split_len_:
                num_lists = length // split_len_ + (1 if (length % split_len_) > 0 else 0)
                new_length = int(length / num_lists) + 1
                self.__logger.info(f"Splitting a list of {length} elements into {num_lists} lists of length {new_length}..")
                start_idx = 0
                for i in range(num_lists):
                    end_idx = min(start_idx + new_length, length)
                    if "\n" in tok_lst[start_idx]: tok_lst[start_idx] = "."
                    if "\n" in tok_lst[end_idx-1]: tok_lst[end_idx-1] = "."
                    res_tokens_sample.append(tok_lst[start_idx:end_idx])
                    start_idx = end_idx

                res_tokens.append(res_tokens_sample)
            else:
                res_tokens.append([tok_lst])

        return res_tokens


    def __align_predictions(self, predictions, tokenized_inputs, sum_all_tokens=False) -> List[List[List[float]]]:
        """ Align predicted labels from Transformer Tokenizer """
        confidence = []
        id_of_non_entity = self.__entity_type_set.id_of_non_entity
        for i, tagset in enumerate(predictions):

            word_ids = tokenized_inputs.word_ids(batch_index=i)

            previous_word_idx = None
            token_confidence = []
            for k, word_idx in enumerate(word_ids):
                try:
                    tok_conf = [value for value in tagset[k]]
                except TypeError:
                    # use the object itself it if's not iterable
                    tok_conf = tagset[k]

                if word_idx is not None:
                    # add nonentity tokens if there is a gap in word ids (usually caused by a newline token)
                    if previous_word_idx is not None:
                        diff = word_idx - previous_word_idx
                        for i in range(diff - 1):
                            tmp = [0 for _ in tok_conf]
                            tmp[id_of_non_entity] = 1.0
                            token_confidence.append(tmp)

                    # add confidence value if this is the first token of the word
                    if word_idx != previous_word_idx:
                        token_confidence.append(tok_conf)
                    else:
                        # if sum_all_tokens=True the confidence for all tokens of one word will be summarized
                        if sum_all_tokens:
                            token_confidence[-1] = [a + b for a, b in zip(token_confidence[-1], tok_conf)]

                previous_word_idx = word_idx

            confidence.append(token_confidence)

        return confidence
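The demonstrator drives this class from `app.py` (see `request_extractor`). A minimal headless sketch of the same flow, assuming the model `danielsteinigen/KeyFiTax` and the spaCy pipeline referenced in `util/configuration.py` are available locally; the example sentence is illustrative:

```python
from util.configuration import InferenceConfiguration
from util.process_data import Sample, SampleList
from util.tokenizer import Tokenizer
from model_inference import TransformersInference

config = InferenceConfiguration()
data = SampleList(samples=[Sample(idx=0, text="Das Kindergeld beträgt monatlich 219 Euro.")])

Tokenizer(config.spacy_model).run(data)             # fills data.samples[0].tokens
TransformersInference(config).run_inference(data)   # fills entities, relations and tags in place

result = data.dict()["samples"][0]
print(result["entities"], result["relations"])
```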
requirements.txt
ADDED
@@ -0,0 +1,11 @@
# Python 3.8.11
numpy~=1.23.5
PyYAML~=5.4.1
pydantic==1.8.2
tqdm~=4.56.2
scikit-learn~=0.24.2
spacy==3.2.0
# MarkupSafe==2.0.1
torch==1.6.0
transformers[sentencepiece]==4.26.1
pyvis==0.3.2
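Note that the German spaCy pipeline `de_core_news_sm` referenced in `util/configuration.py` is not pinned here; presumably it is fetched separately, for example:

```python
# one-off download of the spaCy pipeline used by util/tokenizer.py
# (assumption: it is not bundled with the Space and must be installed once)
import spacy.cli
spacy.cli.download("de_core_news_sm")
```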
util/__init__.py
ADDED
File without changes
util/configuration.py
ADDED
@@ -0,0 +1,9 @@
from pydantic import BaseModel

class InferenceConfiguration(BaseModel):
    model_path_keyfigure: str = "danielsteinigen/KeyFiTax"
    spacy_model: str = "de_core_news_sm"
    transformer_model: str = "xlm-roberta-large"
    merge_entities: bool = True
    split_len: int = 200
    extract_relations: bool = True
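Since `InferenceConfiguration` is a pydantic `BaseModel`, each default above can be overridden per instance. A small sketch with illustrative override values, e.g. for a quicker entities-only run:

```python
from util.configuration import InferenceConfiguration

# hypothetical overrides: skip relation extraction, use shorter chunks
config = InferenceConfiguration(extract_relations=False, split_len=100)
print(config.dict())
```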
util/ontology.png
ADDED
util/process_data.py
ADDED
@@ -0,0 +1,55 @@
from typing import Optional, List

from pydantic import BaseModel, Extra

class EntityType(BaseModel):
    idx: int
    label: str


class EntityTypeSet(BaseModel):
    entity_types: List[EntityType]
    relation_types: List[EntityType]
    id_of_non_entity: int
    groups: List[List[int]]

    def __len__(self):
        return len(self.entity_types) + len(self.relation_types)

    def all_types(self):
        return [*self.entity_types, *self.relation_types]


class Token(BaseModel):
    text: str
    start: int
    end: int


class Entity(BaseModel):
    id: int
    text: str
    start: int
    end: int
    ent_type: EntityType
    confidence: Optional[float]


class Relation(BaseModel):
    id: int
    head: int
    tail: int
    rel_type: EntityType


class Sample(BaseModel):
    idx: int
    text: str
    entities: List[Entity] = []
    relations: List[Relation] = []
    tokens: List[Token] = []
    tags: List[int] = []


class SampleList(BaseModel):
    samples: List[Sample]
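These pydantic models are the interchange format of the whole pipeline; the JSON shown by the "Show JSON" expander in `app.py` is simply a serialized `Sample`. A short round-trip sketch with illustrative values:

```python
from util.process_data import Entity, EntityType, Sample

sample = Sample(idx=0, text="219 Euro", entities=[
    Entity(id=1, text="219", start=0, end=3,
           ent_type=EntityType(idx=3, label="StatedExpression"))
])
print(sample.dict())                    # what app.py renders via st.json
print(Sample.parse_raw(sample.json()))  # and the reverse direction
```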
util/tokenizer.py
ADDED
@@ -0,0 +1,24 @@
from typing import List

import spacy

from util.process_data import Token, Sample, SampleList

class Tokenizer():

    def __init__(self, spacy_model: str):
        self.__spacy_model = spacy.load(spacy_model)

    def run(self, sample_list: SampleList):
        self.__tokenize(sample_list.samples, self.__spacy_model)

    def __tokenize(self, samples: List[Sample], spacy_model):
        doc_pipe = spacy_model.pipe([sample.text.replace('\xa0', ' ') for sample in samples])
        for sample, doc in zip(samples, doc_pipe):
            sample.tokens = [Token(
                text=x.text,
                start=x.idx,
                end=x.idx + len(x.text)
            ) for x in doc]
            while '\n' in sample.tokens[-1].text or ' ' in sample.tokens[-1].text:
                sample.tokens = sample.tokens[:-1]
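A short sketch of what the tokenizer produces, assuming `de_core_news_sm` is installed: each `Token` carries character offsets into the original paragraph, which is what the entity spans built in `model_inference.py` refer to.

```python
from util.process_data import Sample, SampleList
from util.tokenizer import Tokenizer

data = SampleList(samples=[Sample(idx=0, text="Das Kindergeld beträgt 219 Euro.")])
Tokenizer("de_core_news_sm").run(data)
for tok in data.samples[0].tokens:
    print(tok.text, tok.start, tok.end)  # token text with character offsets
```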