# ---------------------------------------------------------------------------
# IMPORTS
# ---------------------------------------------------------------------------
import os
import pickle

import nltk
import numpy as np
import requests
from nltk import edit_distance, pos_tag
from nltk.tokenize import word_tokenize
from wikipedia2vec import Wikipedia2Vec

from src.stopwords import STOP_WORDS

# ---------------------------------------------------------------------------
# SETUP AND HELPER FUNCTIONS
# ---------------------------------------------------------------------------
nltk.download('averaged_perceptron_tagger')
nltk.download('punkt')

DATADIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'data')

with open(os.path.join(DATADIR, 'entity_anchors.bin'), 'rb') as f:
    prior_prob = pickle.load(f)

with open(os.path.join(DATADIR, 'entity_prior.bin'), 'rb') as f:
    entity_prior = pickle.load(f)


def get_edit_dist(x, y):
    """Levenshtein edit distance between two strings."""
    return edit_distance(x, y)


def get_entity_prior(entity):
    """Mention-independent prior probability of an entity."""
    try:
        return entity_prior[entity.replace('_', ' ')]
    except KeyError:
        return 0


def get_prior_prob(entity, mention):
    """P(entity | mention) estimated from anchor-text counts."""
    try:
        entity = entity.replace('_', ' ')
        mention = mention.lower()
        return prior_prob[mention][entity] / sum(prior_prob[mention].values())
    except (KeyError, ZeroDivisionError):
        return 0


def get_max_prior_prob(mentions, candidates):
    """For each candidate, its maximum prior probability over all mentions."""
    return {i: max(get_prior_prob(i, j) for j in mentions) for i in candidates}


def cosine_similarity(v1, v2):
    """Cosine similarity, returning 0 if either vector has zero norm."""
    v1v2 = np.linalg.norm(v1) * np.linalg.norm(v2)
    if v1v2 == 0:
        return 0
    return np.dot(v1, v2) / v1v2


def is_disamb_page(title):
    """Return True if the Wikipedia page with this title is a disambiguation page."""
    service_url = "https://en.wikipedia.org/w/api.php"
    params = {
        "action": "query",
        "prop": "pageprops",
        "ppprop": "disambiguation",
        "redirects": '',
        "format": "json",
        "titles": title
    }
    results = requests.get(service_url, params=params).json()
    return 'disambiguation' in str(results)


def wikidata_search(query, limit=3):
    """Candidate generation via the Wikidata entity-search API."""
    service_url = 'https://www.wikidata.org/w/api.php'
    params1 = {
        "action": "wbsearchentities",
        "search": query,
        "language": "en",
        "limit": limit,
        "format": "json"
    }
    params2 = {
        "action": "wbgetentities",
        "language": "en",
        "props": "sitelinks",
        "sitefilter": "enwiki",
        "format": "json"
    }
    results = requests.get(service_url, params=params1).json()
    entities = [i['id'] for i in results['search']]

    # Map the Wikidata IDs to their English Wikipedia titles.
    params2['ids'] = '|'.join(entities)
    results = requests.get(service_url, params=params2).json()

    candidates = []
    for i in entities:
        try:
            candidates.append(
                results['entities'][i]['sitelinks']['enwiki']['title'].replace(' ', '_'))
        except KeyError:
            pass
    return [i for i in candidates if not is_disamb_page(i)]


def wikipedia_search(query, limit=3):
    """Candidate generation via the Wikipedia opensearch API."""
    service_url = 'https://en.wikipedia.org/w/api.php'
    params = {
        'action': 'opensearch',
        'search': query,
        'namespace': 0,
        'limit': limit,
        'redirects': 'resolve',
    }
    results = requests.get(service_url, params=params).json()[1]
    results = [i.replace(' ', '_') for i in results if 'disambiguation' not in i.lower()]
    return [i for i in results if not is_disamb_page(i)]


def google_search(query, limit=10):
    """Candidate generation via a Google Custom Search engine restricted to Wikipedia."""
    service_url = "https://www.googleapis.com/customsearch/v1/siterestrict"
    params = {
        'q': query,
        'num': limit,
        'start': 0,
        'key': os.environ.get('APIKEY'),
        'cx': os.environ.get('CESCX')
    }
    res = requests.get(service_url, params=params)
    try:
        cands = [i['title'].replace(' - Wikipedia', '') for i in res.json()["items"]]
        return [i.replace(' ', '_') for i in cands if not is_disamb_page(i)]
    except Exception:
        return []


def get_entity_extract(entity_title, num_sentences=0):
    """Plain-text extract of a Wikipedia article, optionally limited to a few sentences."""
    service_url = 'https://en.wikipedia.org/w/api.php'
    params = {
        'action': 'query',
        'titles': entity_title,
        'prop': 'extracts',
        'redirects': 1,
        'format': 'json',
        'explaintext': 1,
        'exsectionformat': 'plain'
    }
    if num_sentences != 0:
        params['exsentences'] = num_sentences

    res = requests.get(service_url, params=params)
    try:
        res = res.json()['query']['pages']
        res = res[list(res.keys())[0]]
        return res['extract'] if 'extract' in res.keys() else ''
    except Exception:
        return ''


# ---------------------------------------------------------------------------
# NED SYSTEMS
# ---------------------------------------------------------------------------

### Base Model ###
class Base:
    def __init__(self):
        self.emb = Wikipedia2Vec.load(os.path.join(DATADIR, 'wiki2vec_w10_100d.bin'))
        self.stop_words = STOP_WORDS
        self.tokenizer = word_tokenize
        self.nouns_only = True
        self.vector_size = self.emb.train_params['dim_size']

    def get_nouns(self, tokens):
        """Keep only tokens tagged as nouns (NN, NNP, NNS, NNPS)."""
        nouns = []
        for word, pos in pos_tag(tokens):
            if pos in ('NN', 'NNP', 'NNS', 'NNPS'):
                nouns.extend(word.split(' '))
        return list(set(nouns))

    def filter(self, tokens):
        """Deduplicate tokens and drop stop words and non-alphanumeric tokens."""
        tokens = list(set(tokens))
        tokens = [w for w in tokens if w.lower() not in self.stop_words]
        tokens = [w for w in tokens if w.isalnum()]
        return self.get_nouns(tokens) if self.nouns_only else tokens

    def encode_entity(self, entity):
        """Entity embedding, falling back to encoding the article extract."""
        entity = entity.replace('_', ' ')
        if self.emb.get_entity(entity) is not None:
            return self.emb.get_entity_vector(entity)
        return self.encode_sentence(get_entity_extract(entity, num_sentences=10))

    def encode_sentence(self, s):
        """Mean word vector of the filtered tokens (n starts at 1 to avoid division by zero)."""
        words = self.filter(self.tokenizer(s.lower()))
        emb, n = np.zeros(self.vector_size), 1
        for w in words:
            try:
                emb += self.emb.get_word_vector(w)
                n += 1
            except KeyError:
                pass
        return emb / n


### Advanced Model ###
class GBRT(Base):
    def __init__(self):
        super().__init__()
        with open(os.path.join(DATADIR, 'model.bin'), 'rb') as f:
            self.model = pickle.load(f)

    def encode_context_entities(self, context_entities):
        """Mean embedding of the context entities (n starts at 1 to avoid division by zero)."""
        emb, n = np.zeros(self.vector_size), 1
        for i in context_entities:
            emb += self.encode_entity(i)
            n += 1
        return emb / n

    def link(self, mentions_cands, context):
        # Calculate max prior probability of all candidates.
        mentions = set([i for i, _ in mentions_cands])
        candidates = set([i for _, j in mentions_cands for i in j])
        max_prob = get_max_prior_prob(mentions, candidates)

        # Find unambiguous entities (prior probability above 0.95) and embed them.
        unamb_entities = [x for i, j in mentions_cands for x in j
                          if get_prior_prob(x, i) > 0.95]
        context_ent_emb = self.encode_context_entities(unamb_entities)

        # Make predictions.
        context_emb = self.encode_sentence(context)
        predictions = []
        for mention, candidates in mentions_cands:
            # Generate feature values for each candidate.
            num_cands = len(candidates)
            X = []
            for candidate in candidates:
                cand = candidate.replace('_', ' ').lower()
                ment = mention.lower()
                cand_emb = self.encode_entity(candidate)
                X.append([
                    candidate,
                    get_prior_prob(candidate, mention),
                    get_entity_prior(candidate),
                    max_prob[candidate],
                    num_cands,
                    get_edit_dist(ment, cand),
                    int(ment == cand),
                    int(ment in cand),
                    int(cand.startswith(ment) or cand.endswith(ment)),
                    cosine_similarity(cand_emb, context_emb),
                    cosine_similarity(cand_emb, context_ent_emb)
                ])

            # Add rank: order candidates by the sum of the two similarity scores.
            X.sort(key=lambda x: x[-1] + x[-2], reverse=True)
            X = [j + [i + 1] for i, j in enumerate(X)]

            # Predict, keeping 'NIL' unless a candidate scores above the 0.2 threshold.
            pred, conf = 'NIL', 0.2
            for i in X:
                c = self.model.predict(np.array([i[1:]]))[0]
                if c > conf:
                    pred = i[0]
                    conf = c
            predictions.append([mention, pred, conf])
        return predictions
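

# ---------------------------------------------------------------------------
# EXAMPLE USAGE (illustrative sketch)
# ---------------------------------------------------------------------------
# A minimal sketch of how the GBRT linker might be driven end to end. The
# mention strings and context sentence below are hypothetical examples, and
# running this requires the pickled data files in DATADIR, the trained
# model.bin, and network access for the Wikipedia candidate-generation API.
if __name__ == '__main__':
    context = "Tesla unveiled a new battery at its factory in Austin."
    mentions = ["Tesla", "Austin"]

    # Candidate generation: pair each mention with a few Wikipedia titles.
    mentions_cands = [(m, wikipedia_search(m, limit=3)) for m in mentions]

    # Disambiguation: each prediction is [mention, linked_title_or_NIL, confidence].
    model = GBRT()
    for mention, title, confidence in model.link(mentions_cands, context):
        print(f"{mention} -> {title} ({confidence:.2f})")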