Atharva committed on
Commit f74445c • 1 Parent(s): 3d3358f

initial commit

.gitignore ADDED
@@ -0,0 +1,2 @@
+ __pycache__/
+ *.py[cod]
app.py ADDED
@@ -0,0 +1,119 @@
+ import pandas as pd
+ import streamlit as st
+ import streamlit.components.v1 as components
+ from transformers import AutoModelForTokenClassification, AutoTokenizer, pipeline
+
+ from src import GBRT, wikipedia_search, google_search
+
+
+ TYPE = {
+     'LOC': ' location',
+     'PER': ' person',
+     'ORG': ' organization',
+     'MISC': ''
+ }
+
+ COLOR = {
+     'LOC': '#40E0D0',
+     'PER': '#6495ED',
+     'ORG': '#CCCCFF',
+     'MISC': '#FF7F50'
+ }
+
+ # ---------------------------------------------------------------------------
+ # Loading models
+ # ---------------------------------------------------------------------------
+
+
+ @st.cache(allow_output_mutation=True, show_spinner=True)
+ def load_models():
+     # NER
+     tokenizer = AutoTokenizer.from_pretrained("dslim/bert-base-NER")
+     bert_ner = AutoModelForTokenClassification.from_pretrained("dslim/bert-base-NER")
+     tagger = pipeline("token-classification", model=bert_ner, tokenizer=tokenizer,
+                       device=0, aggregation_strategy="average")
+     # NED
+     model = GBRT()
+     return model, tagger
+
+
+ # ---------------------------------------------------------------------------
+ # Page setup
+ # ---------------------------------------------------------------------------
+ st.set_page_config(layout="wide", page_title='Named Entity Disambiguation')
+ st.write("## Named Entity Disambiguation")
+ col1, col2 = st.columns(2)
+
+
+ # ---------------------------------------------------------------------------
+ # Candidate Generation
+ # ---------------------------------------------------------------------------
+ def get_candidates(mentions_tags):
+     candidates = []
+     cache = {}
+     for mention, tag in mentions_tags:
+         if (mention, tag) in cache:
+             candidates.append((mention, cache[(mention, tag)]))
+         else:
+             res1 = google_search(mention + TYPE[tag])
+             res2 = wikipedia_search(mention, limit=10)
+             cands = list(set(res1 + res2))
+             cache[(mention, tag)] = cands
+             candidates.append((mention, cands))
+     return candidates
+
+
+ # ---------------------------------------------------------------------------
+ # Rendering Setup
+ # ---------------------------------------------------------------------------
+ def display_tag(text, typ, label):
+     if label != 'NIL':
+         label = "https://en.wikipedia.org/wiki/" + label
+     return f"""
+     <a style="margin: 0 5px; padding: 2px 4px; border-radius: 4px; text-decoration:none;
+         background-color:{COLOR[typ]}; color: white; cursor:pointer" href="{label}" target="_blank">
+         <span style="margin-right:3px">{text}</span>
+         <span style="border: 1px solid white; padding: 2px;">{typ}</span>
+     </a>"""
+
+
+ # ---------------------------------------------------------------------------
+ # Full Pipeline
+ # ---------------------------------------------------------------------------
+ def main(text):
+     ner_results = tagger(text)
+     tagged, last_pos = '', 0
+
+     with st.spinner('Generating Candidates'):
+         mentions_cands = get_candidates([(res['word'], res['entity_group']) for res in ner_results])
+
+     with st.spinner('Disambiguating Mentions'):
+         predictions = model.link(mentions_cands, text)
+
+     with st.spinner('Rendering Results'):
+         for i, res in enumerate(ner_results):
+             tag = display_tag(res['word'], res['entity_group'], predictions[i][1])
+             tagged += text[last_pos:res['start']] + tag
+             last_pos = res['end']
+         tagged += text[last_pos:]
+
+     with col2:
+         st.write("### Disambiguated Text")
+         components.html(f'<p style="line-height: 1.8; margin-top:30px; font-family: sans-serif">{tagged}</p>',
+                         scrolling=True, height=500)
+
+         df = pd.DataFrame(data=predictions, columns=['Mention', 'Prediction', 'Confidence'])
+         st.write("**Additional Information**")
+         st.dataframe(df)
+
+
+ if __name__ == '__main__':
+     model, tagger = load_models()
+     with col1:
+         st.write("### Input Text")
+         user_input = st.text_area('Press Ctrl + Enter to update results',
+                                   'George Washington went to Washington.', height=350)
+     if user_input:
+         main(user_input)
+
+
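For reference, a short sketch (not part of this commit) of the tagger output that main() consumes: with aggregation_strategy="average", the transformers token-classification pipeline returns one dict per detected span; get_candidates() only uses the word and entity_group fields, while the rendering loop uses start and end. The values below are illustrative, not real model output.

# Illustrative only: the shape of ner_results that app.py's main() iterates over.
ner_results = [
    {'entity_group': 'PER', 'score': 0.99, 'word': 'George Washington', 'start': 0, 'end': 17},
    {'entity_group': 'LOC', 'score': 0.99, 'word': 'Washington', 'start': 26, 'end': 36},
]
# get_candidates() receives (word, entity_group) pairs and appends TYPE[tag] to the Google query.
mentions_tags = [(res['word'], res['entity_group']) for res in ner_results]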
data/entity_anchors.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c7d51bcbca1f5f4486ef73cc26867ee723d7e62a1d707da24b9e2017657d15fe
+ size 1191130865
data/entity_prior.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:469c177f0e6236d1e3dad0d1efe907dcb4c8004acaf7451a17b18754b5cfbcd7
+ size 525736013
data/model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:675668d7f25be41a7b388061081cf6fe7c04f344bb866561718f57c3a2fbc6a5
+ size 21047161
data/wiki2vec_w10_100d.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:da0a561df04532687acd4f018de60aa5bcdefa57a75bbfd871f7ed7b72f06b76
+ size 3858917918
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ nltk
+ numpy
+ pandas
+ requests
+ streamlit
+ torch
+ transformers
+ wikipedia2vec
src/__init__.py ADDED
@@ -0,0 +1,208 @@
+ # ---------------------------------------------------------------------------
+ # IMPORTS
+ # ---------------------------------------------------------------------------
+ import os
+ import pickle
+
+ import nltk
+ import numpy as np
+ import requests
+ from nltk import edit_distance, pos_tag
+ from nltk.tokenize import word_tokenize
+ from wikipedia2vec import Wikipedia2Vec
+
+ from src.stopwords import STOP_WORDS
+
+ # ---------------------------------------------------------------------------
+ # SETUP AND HELPER FUNCTIONS
+ # ---------------------------------------------------------------------------
+ nltk.download('punkt')  # needed by word_tokenize
+ nltk.download('averaged_perceptron_tagger')
+
+ DATADIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'data')
+ with open(os.path.join(DATADIR, 'entity_anchors.bin'), 'rb') as f:
+     prior_prob = pickle.load(f)
+ with open(os.path.join(DATADIR, 'entity_prior.bin'), 'rb') as f:
+     entity_prior = pickle.load(f)
+
+
+ def get_edit_dist(x, y):
+     return edit_distance(x, y)
+
+
+ def get_entity_prior(entity):
+     try:
+         return entity_prior[entity.replace('_', ' ')]
+     except KeyError:
+         return 0
+
+
+ def get_prior_prob(entity, mention):
+     try:
+         entity = entity.replace('_', ' ')
+         mention = mention.lower()
+         return prior_prob[mention][entity] / sum(prior_prob[mention].values())
+     except (KeyError, ZeroDivisionError):
+         return 0
+
+
+ def get_max_prior_prob(mentions, candidates):
+     max_prob = {i: max([get_prior_prob(i, j) for j in mentions])
+                 for i in candidates}
+     return max_prob
+
+
+ def cosine_similarity(v1, v2):
+     v1v2 = np.linalg.norm(v1) * np.linalg.norm(v2)
+     if v1v2 == 0:
+         return 0
+     else:
+         return np.dot(v2, v1) / v1v2
+
+
+ def wikipedia_search(query, limit=20):
+     service_url = 'https://en.wikipedia.org/w/api.php'
+     params = {
+         'action': 'opensearch',
+         'search': query,
+         'namespace': 0,
+         'limit': limit,
+         'redirects': 'resolve',
+     }
+
+     results = requests.get(service_url, params=params).json()[1]
+     results = [i.replace(' ', '_')
+                for i in results if 'disambiguation' not in i.lower()]
+     return results
+
+
+ def google_search(query, limit=10):
+     service_url = "https://www.googleapis.com/customsearch/v1/siterestrict"
+     params = {
+         'q': query,
+         'num': limit,
+         'start': 0,
+         'key': os.environ.get('APIKEY'),
+         'cx': os.environ.get('CESCX')
+     }
+     res = requests.get(service_url, params=params)
+     try:
+         cands = [i['title'].replace(' - Wikipedia', '') for i in res.json()["items"]]
+         return [i.replace(' ', '_') for i in cands]
+     except (KeyError, ValueError):
+         return []
+
+ # ---------------------------------------------------------------------------
+ # NED SYSTEMS
+ # ---------------------------------------------------------------------------
+
+ ### Base Model ###
+
+
+ class Base:
+     def __init__(self):
+         self.emb = Wikipedia2Vec.load(os.path.join(DATADIR, 'wiki2vec_w10_100d.bin'))
+         self.stop_words = STOP_WORDS
+         self.tokenizer = word_tokenize
+         self.nouns_only = True
+         self.vector_size = self.emb.train_params['dim_size']
+
+     def get_nouns(self, tokens):
+         nouns = []
+         for word, pos in pos_tag(tokens):
+             if pos in ('NN', 'NNP', 'NNS', 'NNPS'):
+                 nouns.extend(word.split(' '))
+         return list(set(nouns))
+
+     def filter(self, tokens):
+         tokens = list(set(tokens))
+         tokens = [w for w in tokens if w.lower() not in self.stop_words]
+         tokens = [w for w in tokens if w.isalnum()]
+         return self.get_nouns(tokens) if self.nouns_only else tokens
+
+     def encode_entity(self, entity):
+         entity = entity.replace('_', ' ')
+         if self.emb.get_entity(entity) is not None:
+             return self.emb.get_entity_vector(entity)
+         else:
+             return np.zeros(self.vector_size)
+
+     def encode_sentence(self, s):
+         words = self.filter(self.tokenizer(s.lower()))
+         emb, n = np.zeros(self.vector_size), 1
+         for w in words:
+             try:
+                 emb += self.emb.get_word_vector(w)
+                 n += 1
+             except KeyError:
+                 pass
+
+         return emb/n
+
+
+ ### Advanced Model ###
+ class GBRT(Base):
+     def __init__(self):
+         super().__init__()
+         with open(os.path.join(DATADIR, 'model.bin'), 'rb') as f:
+             self.model = pickle.load(f)
+
+     def encode_context_entities(self, context_entities):
+         emb, n = np.zeros(self.vector_size), 1
+         for i in context_entities:
+             emb += self.encode_entity(i)
+             n += 1
+         return emb/n
+
+     def link(self, mentions_cands, context):
+         n_features = self.model.n_features_in_
+
+         # Calculate max prior probability of all candidates.
+         mentions = set([i for i, _ in mentions_cands])
+         candidates = set([i for _, j in mentions_cands for i in j])
+         max_prob = get_max_prior_prob(mentions, candidates)
+
+         # Find unambiguous entities
+         unamb_entities = [x for i, j in mentions_cands for x in j if get_prior_prob(x, i) > 0.95]
+         context_ent_emb = self.encode_context_entities(unamb_entities)
+
+         # Make predictions
+         context_emb = self.encode_sentence(context)
+         predictions = []
+         for mention, candidates in mentions_cands:
+             # Generate feature values
+             num_cands = len(candidates)
+             X = []
+             for candidate in candidates:
+                 cand = candidate.replace('_', ' ').lower()
+                 ment = mention.lower()
+                 cand_emb = self.encode_entity(candidate)
+
+                 X.append([
+                     candidate,
+                     get_prior_prob(candidate, mention),
+                     get_entity_prior(candidate),
+                     max_prob[candidate],
+                     num_cands,
+                     get_edit_dist(ment, cand),
+                     int(ment == cand),
+                     int(ment in cand),
+                     int(cand.startswith(ment) or cand.endswith(ment)),
+                     cosine_similarity(cand_emb, context_emb),
+                     cosine_similarity(cand_emb, context_ent_emb)
+                 ])
+
+             # Add rank
+             X.sort(key=lambda x: x[-1] + x[-2], reverse=True)
+             X = [j + [i + 1] for i, j in enumerate(X)]
+
+             # Predict
+             pred, conf = 'NIL', 0
+             for i in X:
+                 c = self.model.predict(np.array([i[1:]]))[0]
+                 if c > conf:
+                     pred = i[0]
+                     conf = c
+             predictions.append([mention, pred, conf])
+
+         return predictions
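For context, a minimal usage sketch of the helpers above (not part of the commit). It assumes the Git LFS files under data/ have been pulled and, for google_search, that the APIKEY and CESCX environment variables are set; without them google_search returns [] and only the Wikipedia candidates remain.

from src import GBRT, google_search, wikipedia_search

text = "George Washington went to Washington."
mention = "George Washington"

# Candidate generation: merge Google CSE and Wikipedia opensearch titles.
candidates = list(set(google_search(mention + " person") + wikipedia_search(mention, limit=10)))

# Disambiguation: link() returns one [mention, predicted_title, confidence] triple per mention.
model = GBRT()  # loads wiki2vec_w10_100d.bin and model.bin from data/
print(model.link([(mention, candidates)], text))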
src/stopwords.py ADDED
@@ -0,0 +1,71 @@
+ # Stop words
+ STOP_WORDS = set("""
+ a about above across after afterwards again against all almost alone along
+ already also although always am among amongst amount an and another any anyhow
+ anyone anything anyway anywhere are around as at
+
+ back be became because become becomes becoming been before beforehand behind
+ being below beside besides between beyond both bottom but by
+
+ call can cannot ca could
+
+ did do does doing done down due during
+
+ each eight either eleven else elsewhere empty enough even ever every
+ everyone everything everywhere except
+
+ few fifteen fifty first five for former formerly forty four from front full
+ further
+
+ get give go
+
+ had has have he hence her here hereafter hereby herein hereupon hers herself
+ him himself his how however hundred
+
+ i if in indeed into is it its itself
+
+ keep
+
+ last latter latterly least less
+
+ just
+
+ made make many may me meanwhile might mine more moreover most mostly move much
+ must my myself
+
+ name namely neither never nevertheless next nine no nobody none noone nor not
+ nothing now nowhere
+
+ of off often on once one only onto or other others otherwise our ours ourselves
+ out over own
+
+ part per perhaps please put
+
+ quite
+
+ rather re really regarding
+
+ same say see seem seemed seeming seems serious several she should show side
+ since six sixty so some somehow someone something sometime sometimes somewhere
+ still such
+
+ take ten than that the their them themselves then thence there thereafter
+ thereby therefore therein thereupon these they third this those though three
+ through throughout thru thus to together too top toward towards twelve twenty
+ two
+
+ under until up unless upon us used using
+
+ various very very via was we well were what whatever when whence whenever where
+ whereafter whereas whereby wherein whereupon wherever whether which while
+ whither who whoever whole whom whose why will with within without would
+
+ yet you your yours yourself yourselves
+ """.split())
+
+ contractions = ["n't", "'d", "'ll", "'m", "'re", "'s", "'ve"]
+ STOP_WORDS.update(contractions)
+
+ for apostrophe in ["‘", "’"]:
+     for stopword in contractions:
+         STOP_WORDS.add(stopword.replace("'", apostrophe))
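A quick check (not part of the commit) of what the loop above adds: the contraction stop words are present with straight and curly apostrophes alike, so filter() drops them regardless of how the tokenizer renders the apostrophe.

from src.stopwords import STOP_WORDS

print("n't" in STOP_WORDS)   # expected: True (added via STOP_WORDS.update(contractions))
print("n’t" in STOP_WORDS)   # expected: True (curly-apostrophe variant added by the loop above)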