egerber1 committed on
Commit
8b513d0
0 Parent(s):

initial commit

.gitignore ADDED
@@ -0,0 +1,5 @@
+ data/*
+ .idea
+ *.log
+ .ipynb_checkpoints
+ data_spacy_entity_linker
README.md ADDED
@@ -0,0 +1,81 @@
+ # Spacy Entity Linker
+
+ ## Introduction
+ Spacy Entity Linker is a pipeline component for spaCy that performs linked entity extraction with Wikidata on
+ a given document.
+ The entity linking system operates by matching potential candidates from each sentence
+ (subject, object, prepositional phrase, compounds, etc.) against aliases
+ from Wikidata. The package makes it easy to find the category behind each entity (e.g. "banana" is of type "food", "Microsoft" is of type "company"). It
+ is therefore useful for information extraction and labeling tasks.
+
+ The package was written before a working linked entity solution existed inside spaCy. Compared to spaCy's linked entity system, it has the following advantages:
+ - no extensive training required (string matching is done against a database)
+ - the knowledge base can be updated dynamically without retraining
+ - entity categories can be resolved easily
+ - entities can be grouped by category
+
+ It also comes with a number of disadvantages:
+ - it is slower than the spaCy implementation due to the use of a database for finding entities
+ - no context sensitivity, due to the "max-prior method" used for entity disambiguation
+
+
+ ## Use
+ ```python
+ import spacy
+ from spacyEntityLinker import EntityLinker
+
+ # initialize entity linker
+ entityLinker = EntityLinker()
+
+ # initialize language model
+ nlp = spacy.load("en_core_web_sm")
+
+ # add the pipeline component
+ nlp.add_pipe(entityLinker, last=True, name="entityLinker")
+
+ doc = nlp("I watched the Pirates of the Caribbean last silvester")
+
+ # returns all entities in the whole document
+ all_linked_entities = doc._.linkedEntities
+ # iterates over sentences and prints linked entities
+ for sent in doc.sents:
+     sent._.linkedEntities.pretty_print()
+
+ '''
+ https://www.wikidata.org/wiki/Q194318     194318      Pirates of the Caribbean    Series of fantasy adventure films
+ https://www.wikidata.org/wiki/Q12525597   12525597    Silvester                   the day celebrated on 31 December (Roman Catholic Church) or 2 January (Eastern Orthodox Churches)
+ '''
+ ```
+
+ ## Example
+ In the following example we will use SpacyEntityLinker to extract all linked entities from a document and group them by their Wikidata category, as sketched below.
+
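+ A minimal sketch (assuming the pipeline has been set up as in the "Use" section above; the input sentence is purely illustrative). It uses the `print_categories` method of the linked entity collection:
+
+ ```python
+ doc = nlp("Elon Musk works in San Francisco and Seattle")
+
+ # prints each resolved Wikidata category together with the
+ # entities from the document that fall under it
+ doc._.linkedEntities.print_categories(max_depth=1)
+ ```
+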
+ ### Entity Linking Policy
+ Currently the only method for choosing an entity among different possible matches (e.g. Paris the city vs. Paris the first name) is max-prior. This method achieves around 70% accuracy in predicting
+ the correct entities behind link descriptions on Wikipedia.
+
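+ The idea in a nutshell (a sketch mirroring `_select_max_prior` in this package's `EntityClassifier`; the prior of a candidate is its popularity score from the knowledge base, independent of the sentence context):
+
+ ```python
+ import numpy as np
+
+ def select_max_prior(entities):
+     # pick the candidate with the highest prior (popularity),
+     # ignoring the surrounding sentence entirely
+     return entities[np.argmax([entity.get_prior() for entity in entities])]
+ ```
+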
+ ## Note
+ The Entity Linker in its current state is still experimental and should not be used in production.
+
+ ## Performance
+ The current implementation supports only SQLite. This is advantageous for development because
+ it does not require any special setup or configuration. However, for more performance-critical use cases, a different
+ database with in-memory access (e.g. Redis) should be used. This may be implemented in the future.
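+
+ Until then, the database file can at least be placed on fast storage. The internal `WikidataQueryController` also accepts a custom database path (a sketch relying on internals, not a stable public API; the path is illustrative):
+
+ ```python
+ from spacyEntityLinker.DatabaseConnection import get_wikidata_instance
+
+ # repoint the shared connection at a copy of the database on fast storage
+ get_wikidata_instance().init_database_connection("/mnt/ramdisk/wikidb_filtered.db")
+ ```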
+
+ ## Installation
+
+ To install the package, run: `pip install spacy-entity-linker`
+
+ Afterwards, the knowledge base (Wikidata) must be downloaded. This can be done by calling
+
+ `python -m spacyEntityLinker download_knowledge_base`
+
+ This will download and extract a ~500 MB file that contains a preprocessed version of Wikidata.
+
+ ## TODO
+ - [ ] implement an entity classifier based on sentence embeddings for improved accuracy
+ - [ ] implement get_picture_urls() on EntityElement
+ - [ ] retrieve statements for each EntityElement (inlinks + outlinks)
downloadKnowledgeBase.sh ADDED
@@ -0,0 +1,6 @@
+ #!/bin/bash
+
+ wget "https://wikidatafiles.nyc3.digitaloceanspaces.com/Hosting/Hosting/SpacyEntityLinker/datafiles.tar.gz" -O /tmp/knowledge_base.tar.gz
+ mkdir -p ./data_spacy_entity_linker
+ tar -xzf /tmp/knowledge_base.tar.gz --directory ./data_spacy_entity_linker
+ rm /tmp/knowledge_base.tar.gz
requirements.txt ADDED
@@ -0,0 +1 @@
+ spacy>=2.1.9
setup.py ADDED
@@ -0,0 +1,44 @@
+ #!/usr/bin/env python
+ # -*- coding: utf-8 -*-
+
+ from spacyEntityLinker import __version__
+ import os
+
+ try:
+     from setuptools import setup
+ except ImportError:
+     from distutils.core import setup
+
+
+ def open_file(fname):
+     return open(os.path.join(os.path.dirname(__file__), fname))
+
+
+ with open("README.md", "r") as fh:
+     long_description = fh.read()
+
+ setup(
+     name='spacy-entity-linker',
+     version=__version__,
+     author='Emanuel Gerber',
+     author_email='emanuel.j.gerber@gmail.com',
+     packages=['spacyEntityLinker'],
+     url='https://github.com/egerber/spacy-entity-linker',
+     license="MIT",
+     classifiers=["Environment :: Console",
+                  "Intended Audience :: Developers",
+                  "Intended Audience :: Science/Research",
+                  "License :: OSI Approved :: MIT License",
+                  "Operating System :: POSIX :: Linux",
+                  "Programming Language :: Cython",
+                  "Programming Language :: Python",
+                  "Programming Language :: Python :: 2",
+                  "Programming Language :: Python :: 2.7",
+                  "Programming Language :: Python :: 3",
+                  "Programming Language :: Python :: 3.4"
+                  ],
+     description='Linked Entity Pipeline for spaCy',
+     long_description=long_description,
+     long_description_content_type="text/markdown",
+     zip_safe=True,
+ )
spacyEntityLinker/DatabaseConnection.py ADDED
@@ -0,0 +1,192 @@
+ import sqlite3
+ import os
+
+ MAX_DEPTH_CHAIN = 10
+ P_INSTANCE_OF = 31
+ P_SUBCLASS = 279
+
+ MAX_ITEMS_CACHE = 100000
+
+ DB_DEFAULT_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), '../data_spacy_entity_linker/wikidb_filtered.db')
+
+ wikidata_instance = None
+
+
+ def get_wikidata_instance():
+     global wikidata_instance
+
+     if wikidata_instance is None:
+         wikidata_instance = WikidataQueryController()
+
+     return wikidata_instance
+
+
+ class WikidataQueryController:
+
+     def __init__(self):
+         self.conn = None
+
+         self.cache = {
+             "entity": {},
+             "chain": {},
+             "name": {}
+         }
+
+         self.init_database_connection()
+
+     def _get_cached_value(self, cache_type, key):
+         return self.cache[cache_type][key]
+
+     def _is_cached(self, cache_type, key):
+         return key in self.cache[cache_type]
+
+     def _add_to_cache(self, cache_type, key, value):
+         if len(self.cache[cache_type]) < MAX_ITEMS_CACHE:
+             self.cache[cache_type][key] = value
+
+     def init_database_connection(self, path=DB_DEFAULT_PATH):
+         self.conn = sqlite3.connect(path)
+
+     def clear_cache(self):
+         self.cache["entity"].clear()
+         self.cache["chain"].clear()
+         self.cache["name"].clear()
+
+     def get_entities_from_alias(self, alias):
+         if self._is_cached("entity", alias):
+             return self._get_cached_value("entity", alias).copy()
+
+         c = self.conn.cursor()
+         query_alias = """SELECT j.item_id,j.en_label,j.en_description,j.views,j.inlinks,a.en_alias from aliases as a
+             LEFT JOIN joined as j ON a.item_id = j.item_id
+             WHERE a.en_alias_lowercase = ? and j.item_id NOT NULL"""
+
+         c.execute(query_alias, [alias.lower()])
+         fetched_rows = c.fetchall()
+
+         self._add_to_cache("entity", alias, fetched_rows)
+         return fetched_rows
+
+     def get_instances_of(self, item_id, properties=[P_INSTANCE_OF, P_SUBCLASS], count=1000):
+         query = "SELECT source_item_id from statements where target_item_id={} and edge_property_id IN ({}) LIMIT {}".format(
+             item_id, ",".join([str(prop) for prop in properties]), count)
+
+         c = self.conn.cursor()
+         c.execute(query)
+
+         res = c.fetchall()
+
+         return [e[0] for e in res]
+
+     def get_entity_name(self, item_id):
+         if self._is_cached("name", item_id):
+             return self._get_cached_value("name", item_id)
+
+         c = self.conn.cursor()
+         query = "SELECT en_label from joined WHERE item_id=?"
+         c.execute(query, [item_id])
+         res = c.fetchone()
+
+         if res:
+             if res[0] is None:
+                 self._add_to_cache("name", item_id, 'no label')
+             else:
+                 self._add_to_cache("name", item_id, res[0])
+         else:
+             self._add_to_cache("name", item_id, '<none>')
+
+         return self._get_cached_value("name", item_id)
+
+     def get_entity(self, item_id):
+         c = self.conn.cursor()
+         query = "SELECT j.item_id,j.en_label,j.en_description,j.views,j.inlinks from joined as j " \
+                 "WHERE j.item_id={}".format(item_id)
+
+         res = c.execute(query)
+
+         return res.fetchone()
+
+     def get_children(self, item_id, limit=100):
+         c = self.conn.cursor()
+         query = "SELECT j.item_id,j.en_label,j.en_description,j.views,j.inlinks from joined as j " \
+                 "JOIN statements as s on j.item_id=s.source_item_id " \
+                 "WHERE s.target_item_id={} and s.edge_property_id IN (279,31) LIMIT {}".format(item_id, limit)
+
+         res = c.execute(query)
+
+         return res.fetchall()
+
+     def get_parents(self, item_id, limit=100):
+         c = self.conn.cursor()
+         query = "SELECT j.item_id,j.en_label,j.en_description,j.views,j.inlinks from joined as j " \
+                 "JOIN statements as s on j.item_id=s.target_item_id " \
+                 "WHERE s.source_item_id={} and s.edge_property_id IN (279,31) LIMIT {}".format(item_id, limit)
+
+         res = c.execute(query)
+
+         return res.fetchall()
+
+     def get_categories(self, item_id, max_depth=10):
+         chain = []
+         edges = []
+         self._append_chain_elements(item_id, 0, chain, edges, max_depth, [P_INSTANCE_OF, P_SUBCLASS])
+         return [el[0] for el in chain]
+
+     def get_chain(self, item_id, max_depth=10, property=P_INSTANCE_OF):
+         chain = []
+         edges = []
+         self._append_chain_elements(item_id, 0, chain, edges, max_depth, property)
+         return chain
+
+     def get_recursive_edges(self, item_id):
+         chain = []
+         edges = []
+         self._append_chain_elements(item_id, 0, chain, edges)
+         return edges
+
+     def _append_chain_elements(self, item_id, level=0, chain=None, edges=None, max_depth=10, property=P_INSTANCE_OF):
+         # avoid shared mutable default arguments
+         if chain is None:
+             chain = []
+         if edges is None:
+             edges = []
+
+         properties = property
+         if type(property) != list:
+             properties = [property]
+
+         if self._is_cached("chain", (item_id, max_depth)):
+             chain += self._get_cached_value("chain", (item_id, max_depth)).copy()
+             return
+
+         # prevent infinite recursion
+         if level >= max_depth:
+             return
+
+         c = self.conn.cursor()
+
+         query = "SELECT target_item_id,edge_property_id from statements where source_item_id={} and edge_property_id IN ({})".format(
+             item_id, ",".join([str(prop) for prop in properties]))
+
+         # set value for current item in order to prevent infinite recursion
+         self._add_to_cache("chain", (item_id, max_depth), [])
+
+         for target_item in c.execute(query):
+
+             chain_ids = [el[0] for el in chain]
+
+             if not (target_item[0] in chain_ids):
+                 chain += [(target_item[0], level + 1)]
+                 edges.append((item_id, target_item[0], target_item[1]))
+                 self._append_chain_elements(target_item[0], level=level + 1, chain=chain, edges=edges,
+                                             max_depth=max_depth,
+                                             property=property)
+
+         self._add_to_cache("chain", (item_id, max_depth), chain)
+
+
+ if __name__ == '__main__':
+     queryInstance = WikidataQueryController()
+
+     queryInstance.init_database_connection()
+     print(queryInstance.get_categories(13191, max_depth=1))
+     print(queryInstance.get_categories(13191, max_depth=1))
spacyEntityLinker/EntityCandidates.py ADDED
@@ -0,0 +1,22 @@
+ class EntityCandidates:
+
+     def __init__(self, entity_elements):
+         self.entity_elements = entity_elements
+
+     def __iter__(self):
+         for entity in self.entity_elements:
+             yield entity
+
+     def __len__(self):
+         return len(self.entity_elements)
+
+     def __getitem__(self, item):
+         return self.entity_elements[item]
+
+     def pretty_print(self):
+         for entity in self.entity_elements:
+             entity.pretty_print()
+
+     def __str__(self):
+         return str(["entity {}: {} (<{}>)".format(i, entity.get_label(), entity.get_description()) for i, entity in
+                     enumerate(self.entity_elements)])
spacyEntityLinker/EntityClassifier.py ADDED
@@ -0,0 +1,53 @@
+ from itertools import groupby
+ import numpy as np
+
+
+ class EntityClassifier:
+     def __init__(self):
+         pass
+
+     def _get_grouped_by_length(self, entities):
+         sorted_by_len = sorted(entities, key=lambda entity: len(entity.get_span()), reverse=True)
+
+         entities_by_length = {}
+         for length, group in groupby(sorted_by_len, lambda entity: len(entity.get_span())):
+             entities_by_length[length] = list(group)
+
+         return entities_by_length
+
+     def _filter_max_length(self, entities):
+         # keep only the candidates with the longest span
+         entities_by_length = self._get_grouped_by_length(entities)
+         max_length = max(entities_by_length.keys())
+
+         return entities_by_length[max_length]
+
+     def _select_max_prior(self, entities):
+         # pick the candidate with the highest prior (popularity)
+         priors = [entity.get_prior() for entity in entities]
+         return entities[np.argmax(priors)]
+
+     def _get_casing_difference(self, word1, original):
+         # count character positions where span text and alias differ (e.g. in casing)
+         difference = 0
+         for w1, w2 in zip(word1, original):
+             if w1 != w2:
+                 difference += 1
+
+         return difference
+
+     def _filter_most_similar(self, entities):
+         # keep only the candidates whose alias is closest to the span text
+         similarities = np.array(
+             [self._get_casing_difference(entity.get_span().text, entity.get_original_alias()) for entity in entities])
+
+         min_indices = np.where(similarities == similarities.min())[0].tolist()
+
+         return [entities[i] for i in min_indices]
+
+     def __call__(self, entities):
+         # selection pipeline: longest span -> closest casing -> highest prior
+         filtered_by_length = self._filter_max_length(entities)
+         filtered_by_casing = self._filter_most_similar(filtered_by_length)
+
+         return self._select_max_prior(filtered_by_casing)
spacyEntityLinker/EntityCollection.py ADDED
@@ -0,0 +1,66 @@
+ from collections import Counter, defaultdict
+ from spacyEntityLinker.DatabaseConnection import get_wikidata_instance
+
+
+ class EntityCollection:
+
+     def __init__(self, entities=None):
+         # avoid a shared mutable default list across instances
+         self.entities = entities if entities is not None else []
+
+     def __iter__(self):
+         for entity in self.entities:
+             yield entity
+
+     def __getitem__(self, item):
+         return self.entities[item]
+
+     def __len__(self):
+         return len(self.entities)
+
+     def append(self, entity):
+         self.entities.append(entity)
+
+     def get_categories(self, max_depth=1):
+         categories = []
+         for entity in self.entities:
+             categories += entity.get_categories(max_depth)
+
+         return categories
+
+     def print_categories(self, max_depth=1, limit=10):
+         wikidataInstance = get_wikidata_instance()
+
+         all_categories = []
+         category_to_entities = defaultdict(list)
+
+         for e in self.entities:
+             for category in e.get_categories(max_depth):
+                 category_to_entities[category].append(e)
+                 all_categories.append(category)
+
+         counter = Counter()
+         counter.update(all_categories)
+
+         for category, frequency in counter.most_common(limit):
+             print("{} ({}) : {}".format(wikidataInstance.get_entity_name(category), frequency,
+                                         ','.join([str(e) for e in category_to_entities[category]])))
+
+     def pretty_print(self):
+         for entity in self.entities:
+             entity.pretty_print()
+
+     def grouped_by_category(self, max_depth=1):
+         counter = Counter()
+         counter.update(self.get_categories(max_depth))
+
+         return counter
+
+     def get_distinct_categories(self, max_depth=1):
+         return list(set(self.get_categories(max_depth)))
+
+     def most_frequent_categories(self):
+         pass
+
+     def get_most_significant_categories(self, priors):
+         pass
spacyEntityLinker/EntityElement.py ADDED
@@ -0,0 +1,127 @@
+ from spacyEntityLinker.DatabaseConnection import get_wikidata_instance
+
+
+ class EntityElement:
+     def __init__(self, row, span):
+         self.identifier = row[0]
+         self.prior = 0
+         self.original_alias = None
+         self.in_degree = None
+         self.label = None
+         self.description = None
+
+         if len(row) > 1:
+             self.label = row[1]
+         if len(row) > 2:
+             self.description = row[2]
+         if len(row) > 3 and row[3]:
+             self.prior = row[3]
+         if len(row) > 4 and row[4]:
+             self.in_degree = row[4]
+         if len(row) > 5 and row[5]:
+             self.original_alias = row[5]
+
+         self.span = span
+
+         self.chain = None
+         self.chain_ids = None
+
+         self.wikidata_instance = get_wikidata_instance()
+
+     def get_in_degree(self):
+         return self.in_degree
+
+     def get_original_alias(self):
+         return self.original_alias
+
+     def is_singleton(self):
+         return len(self.get_chain()) == 0
+
+     def get_span(self):
+         return self.span
+
+     def get_label(self):
+         return self.label
+
+     def get_id(self):
+         return self.identifier
+
+     def get_prior(self):
+         return self.prior
+
+     def get_chain(self, max_depth=10):
+         if self.chain is None:
+             self.chain = self.wikidata_instance.get_chain(self.identifier, max_depth=max_depth, property=31)
+         return self.chain
+
+     def is_category(self):
+         pass
+
+     def is_leaf(self):
+         pass
+
+     def get_categories(self, max_depth=10):
+         return self.wikidata_instance.get_categories(self.identifier, max_depth=max_depth)
+
+     def get_children(self, limit=10):
+         return [EntityElement(row, None) for row in self.wikidata_instance.get_children(self.get_id(), limit)]
+
+     def get_parents(self, limit=10):
+         return [EntityElement(row, None) for row in self.wikidata_instance.get_parents(self.get_id(), limit)]
+
+     def get_subclass_hierarchy(self):
+         chain = self.wikidata_instance.get_chain(self.identifier, max_depth=5, property=279)
+         return [self.wikidata_instance.get_entity_name(el[0]) for el in chain]
+
+     def get_instance_of_hierarchy(self):
+         chain = self.wikidata_instance.get_chain(self.identifier, max_depth=5, property=31)
+         return [self.wikidata_instance.get_entity_name(el[0]) for el in chain]
+
+     def get_chain_ids(self, max_depth=10):
+         if self.chain_ids is None:
+             self.chain_ids = set([el[0] for el in self.get_chain(max_depth=max_depth)])
+
+         return self.chain_ids
+
+     def get_description(self):
+         if self.description:
+             return self.description
+         else:
+             return ""
+
+     def is_intersecting(self, other_element):
+         return len(self.get_chain_ids().intersection(other_element.get_chain_ids())) > 0
+
+     def serialize(self):
+         return {
+             "id": self.get_id(),
+             "label": self.get_label(),
+             "span": self.get_span()
+         }
+
+     def pretty_print(self):
+         print(
+             "https://www.wikidata.org/wiki/Q{0:<10} {1:<10} {2:<30} {3:<100}".format(self.get_id(),
+                                                                                     self.get_id(),
+                                                                                     self.get_label(),
+                                                                                     self.get_description()[:100]))
+
+     def pretty_string(self, description=False):
+         if description:
+             return ','.join([span.text for span in self.span]) + " => {} <{}>".format(self.get_label(),
+                                                                                       self.get_description())
+         else:
+             return ','.join([span.text for span in self.span]) + " => {}".format(self.get_label())
+
+     def save(self, category):
+         # note: assumes a "linked_entities" span extension registered elsewhere
+         for span in self.span:
+             span.sent._.linked_entities.append(
+                 {"id": self.identifier, "range": [span.start, span.end + 1], "category": category})
+
+     def __str__(self):
+         label = self.get_label()
+         if label:
+             return label
+         else:
+             return ""
spacyEntityLinker/EntityLinker.py ADDED
@@ -0,0 +1,30 @@
+ from spacyEntityLinker.EntityClassifier import EntityClassifier
+ from spacyEntityLinker.EntityCollection import EntityCollection
+ from spacyEntityLinker.TermCandidateExtractor import TermCandidateExtractor
+ from spacy.tokens import Doc, Span
+
+
+ class EntityLinker:
+
+     def __init__(self):
+         Doc.set_extension("linkedEntities", default=EntityCollection(), force=True)
+         Span.set_extension("linkedEntities", default=None, force=True)
+
+     def __call__(self, doc):
+         tce = TermCandidateExtractor(doc)
+         classifier = EntityClassifier()
+
+         for sent in doc.sents:
+             sent._.linkedEntities = EntityCollection([])
+
+         entities = []
+         for termCandidates in tce:
+             entityCandidates = termCandidates.get_entity_candidates()
+             if len(entityCandidates) > 0:
+                 entity = classifier(entityCandidates)
+                 entity.span.sent._.linkedEntities.append(entity)
+                 entities.append(entity)
+
+         doc._.linkedEntities = EntityCollection(entities)
+
+         return doc
spacyEntityLinker/TermCandidate.py ADDED
@@ -0,0 +1,38 @@
+ from spacyEntityLinker.EntityCandidates import EntityCandidates
+ from spacyEntityLinker.EntityElement import EntityElement
+ from spacyEntityLinker.DatabaseConnection import get_wikidata_instance
+
+
+ class TermCandidate:
+     def __init__(self, span):
+         self.variations = [span]
+
+     def pretty_print(self):
+         print("Term Candidates are [{}]".format(self))
+
+     def append(self, span):
+         self.variations.append(span)
+
+     def has_plural(self, variation):
+         return any([t.tag_ == "NNS" for t in variation])
+
+     def get_singular(self, variation):
+         return ' '.join([t.text if t.tag_ != "NNS" else t.lemma_ for t in variation])
+
+     def __str__(self):
+         return ', '.join([variation.text for variation in self.variations])
+
+     def get_entity_candidates(self):
+         wikidata_instance = get_wikidata_instance()
+         entities_by_variation = {}
+         for variation in self.variations:
+             entities_by_variation[variation] = wikidata_instance.get_entities_from_alias(variation.text)
+             if self.has_plural(variation):
+                 entities_by_variation[variation] += wikidata_instance.get_entities_from_alias(
+                     self.get_singular(variation))
+
+         entity_elements = []
+         for variation, entities in entities_by_variation.items():
+             entity_elements += [EntityElement(entity, variation) for entity in entities]
+
+         return EntityCandidates(entity_elements)
spacyEntityLinker/TermCandidateExtractor.py ADDED
@@ -0,0 +1,54 @@
+ from spacyEntityLinker.TermCandidate import TermCandidate
+
+
+ class TermCandidateExtractor:
+     def __init__(self, doc):
+         self.doc = doc
+
+     def __iter__(self):
+         for sent in self.doc.sents:
+             for candidate in self._get_candidates_in_sent(sent, self.doc):
+                 yield candidate
+
+     def _get_candidates_in_sent(self, sent, doc):
+         root = list(filter(lambda token: token.dep_ == "ROOT", sent))[0]
+
+         excluded_children = []
+         candidates = []
+
+         def get_candidates(node, doc):
+             # collect noun/proper-noun heads together with their compound,
+             # adjectival and "of"-prepositional variations
+             if node.pos_ in ["PROPN", "NOUN"]:
+                 term_candidates = TermCandidate(doc[node.i:node.i + 1])
+
+                 for child in node.children:
+
+                     start_index = min(node.i, child.i)
+                     end_index = max(node.i, child.i)
+
+                     if child.dep_ == "compound" or child.dep_ == "amod":
+                         subtree_tokens = list(child.subtree)
+                         if all([c.dep_ == "compound" for c in subtree_tokens]):
+                             start_index = min([c.i for c in subtree_tokens])
+                         term_candidates.append(doc[start_index:end_index + 1])
+
+                         if not child.dep_ == "amod":
+                             term_candidates.append(doc[start_index:start_index + 1])
+                         excluded_children.append(child)
+
+                     if child.dep_ == "prep" and child.text == "of":
+                         end_index = max([c.i for c in child.subtree])
+                         term_candidates.append(doc[start_index:end_index + 1])
+
+                 candidates.append(term_candidates)
+
+             # recurse depth-first over the dependency tree
+             for child in node.children:
+                 if child in excluded_children:
+                     continue
+                 get_candidates(child, doc)
+
+         get_candidates(root, doc)
+
+         return candidates
spacyEntityLinker/__init__.py ADDED
@@ -0,0 +1,4 @@
+ from .EntityLinker import EntityLinker
+
+ __version__ = '0.0.2'
+ __all__ = ['EntityLinker']
spacyEntityLinker/__main__.py ADDED
@@ -0,0 +1,27 @@
+ if __name__ == "__main__":
+     import sys
+     import urllib.request
+     import tarfile
+     import os
+
+     if len(sys.argv) < 2:
+         print("No arguments given")
+         sys.exit(1)
+
+     command = sys.argv.pop(1)
+
+     if command == "download_knowledge_base":
+         FILE_URL = "https://wikidatafiles.nyc3.digitaloceanspaces.com/Hosting/Hosting/SpacyEntityLinker/datafiles.tar.gz"
+
+         OUTPUT_TAR_FILE = os.path.abspath(
+             os.path.dirname(__file__)) + '/../data_spacy_entity_linker/wikidb_filtered.tar.gz'
+         OUTPUT_DB_PATH = os.path.abspath(os.path.dirname(__file__)) + '/../data_spacy_entity_linker'
+         if not os.path.exists(OUTPUT_DB_PATH):
+             os.makedirs(OUTPUT_DB_PATH)
+         urllib.request.urlretrieve(FILE_URL, OUTPUT_TAR_FILE)
+
+         tar = tarfile.open(OUTPUT_TAR_FILE)
+         tar.extractall(OUTPUT_DB_PATH)
+         tar.close()
+
+         os.remove(OUTPUT_TAR_FILE)
tests/test_EntityLinker.py ADDED
@@ -0,0 +1,24 @@
+ import unittest
+ import spacy
+ from spacyEntityLinker.EntityLinker import EntityLinker
+
+
+ class TestEntityLinker(unittest.TestCase):
+
+     def __init__(self, arg, *args, **kwargs):
+         super(TestEntityLinker, self).__init__(arg, *args, **kwargs)
+         self.nlp = spacy.load('en_core_web_sm')
+
+     def test_initialization(self):
+         entityLinker = EntityLinker()
+
+         self.nlp.add_pipe(entityLinker, last=True, name="entityLinker")
+
+         doc = self.nlp("I watched the Pirates of the Caribbean last silvester. Then I saw a snake. It was great.")
+
+         doc._.linkedEntities.pretty_print()
+
+         for sent in doc.sents:
+             sent._.linkedEntities.pretty_print()
+
+         self.nlp.remove_pipe("entityLinker")
tests/test_TermCandidateExtractor.py ADDED
@@ -0,0 +1,9 @@
+ import unittest
+ import spacy
+ import spacyEntityLinker.TermCandidateExtractor
+
+
+ class TestCandidateExtractor(unittest.TestCase):
+
+     def __init__(self, arg, *args, **kwargs):
+         super(TestCandidateExtractor, self).__init__(arg, *args, **kwargs)