# NetsPresso_QA/scripts/entity_linking.py
#
# Pyserini: Reproducible IR research with sparse and dense representations
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
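#
# This script runs REL (Radboud Entity Linker) entity linking over a Pyserini
# JSONL collection: spaCy NER proposes mentions, REL disambiguates them against
# Wikipedia, and WikiMapper enriches the results with Wikidata IDs.
#
# Example invocation (all paths and model names below are placeholders; substitute
# your own REL data directory, ED model, wikimapper index, and spaCy model):
#   python entity_linking.py \
#       --input_path docs.jsonl \
#       --rel_base_url /path/to/rel_data \
#       --rel_ed_model_path /path/to/ed_model \
#       --rel_wiki_version wiki_2019 \
#       --wikimapper_index index_enwiki.db \
#       --spacy_model en_core_web_sm \
#       --output_path linked_docs.jsonl
#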
import argparse
import jsonlines
import spacy
import sys
from REL.mention_detection import MentionDetectionBase
from REL.utils import process_results, split_in_words
from REL.entity_disambiguation import EntityDisambiguation
from REL.ner import Span
from wikimapper import WikiMapper
from typing import Dict, List, Tuple
from tqdm import tqdm
# spaCy-based mention detection class for the REL entity linking pipeline; its predict() stands in for REL's NERBase.predict()
class NERSpacyMD(MentionDetectionBase):
def __init__(self, base_url:str, wiki_version:str, spacy_model:str):
super().__init__(base_url, wiki_version)
# we only want to link entities of specific types
self.ner_labels = ['PERSON', 'NORP', 'FAC', 'ORG', 'GPE', 'LOC', 'PRODUCT', 'EVENT', 'WORK_OF_ART',
'LAW', 'LANGUAGE', 'DATE', 'TIME', 'MONEY', 'QUANTITY']
self.spacy_model = spacy_model
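        # run spaCy on the GPU when one is available; otherwise spaCy silently falls back to CPU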
spacy.prefer_gpu()
self.tagger = spacy.load(spacy_model)
# mandatory function which overrides NERBase.predict()
def predict(self, doc: spacy.tokens.Doc) -> List[Span]:
spans = []
for ent in doc.ents:
if ent.label_ in self.ner_labels:
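                # REL's Span is (text, start_pos, end_pos, score, tag); spaCy exposes no NER confidence, so the score is fixed at 0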
spans.append(Span(ent.text, ent.start_char, ent.end_char, 0, ent.label_))
return spans
"""
Responsible for finding mentions given a set of documents in a batch-wise manner. More specifically,
it returns the mention, its left/right context and a set of candidates.
:return: Dictionary with mentions per document.
"""
def find_mentions(self, dataset: Dict[str, str]) -> Tuple[Dict[str, List[Dict]], int]:
results = {}
total_ment = 0
        for doc in tqdm(dataset, desc='Finding mentions', total=len(dataset)):
result_doc = []
doc_text = dataset[doc]
spacy_doc = self.tagger(doc_text)
spans = self.predict(spacy_doc)
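            # convert each detected span into the mention dictionary format expected by REL's disambiguation step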
for entity in spans:
text, start_pos, end_pos, conf, tag = (
entity.text,
entity.start_pos,
entity.end_pos,
entity.score,
entity.tag,
)
m = self.preprocess_mention(text)
cands = self.get_candidates(m)
if len(cands) == 0:
continue
total_ment += 1
                # Re-create the ngram from the source text, since the tagger may normalize 'text' (e.g. collapse double spaces).
ngram = doc_text[start_pos:end_pos]
left_ctxt = " ".join(split_in_words(doc_text[:start_pos])[-100:])
right_ctxt = " ".join(split_in_words(doc_text[end_pos:])[:100])
res = {
"mention": m,
"context": (left_ctxt, right_ctxt),
"candidates": cands,
"gold": ["NONE"],
"pos": start_pos,
"sent_idx": 0,
"ngram": ngram,
"end_pos": end_pos,
"sentence": doc_text,
"conf_md": conf,
"tag": tag,
}
result_doc.append(res)
results[doc] = result_doc
return results, total_ment
# run REL entity linking on the processed docs
def rel_entity_linking(docs: Dict[str,str], spacy_model:str, rel_base_url:str, rel_wiki_version:str, rel_ed_model_path:str) -> Dict[str, List[Tuple]]:
mention_detection = NERSpacyMD(rel_base_url, rel_wiki_version, spacy_model)
mentions_dataset, _ = mention_detection.find_mentions(docs)
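    # run REL's entity disambiguation in inference ('eval') mode with the pretrained ED model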
config = {
'mode': 'eval',
'model_path': rel_ed_model_path,
}
ed_model = EntityDisambiguation(rel_base_url, rel_wiki_version, config)
predictions, _ = ed_model.predict(mentions_dataset)
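    # merge the detected mention spans with the disambiguation predictions into per-document result tuples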
linked_entities = process_results(mentions_dataset, predictions, docs)
return linked_entities
# read input Pyserini JSONL docs into a dictionary
def read_docs(input_path: str) -> Dict[str, str]:
docs = {}
with jsonlines.open(input_path) as reader:
for obj in tqdm(reader, desc='Reading docs'):
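            # Pyserini JSONL docs carry an 'id' and a 'contents' field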
docs[obj['id']] = obj['contents']
return docs
# enrich REL entity linking results with the entities' Wikidata IDs, and write the final results as JSON objects
# rel_linked_entities: each entity tuple is composed of (start_pos:int, mention_length:int, ent_text:str,
# ent_wikipedia_id:str, conf_score:float, ner_score:int, ent_type:str)
def enrich_el_results(rel_linked_entities: Dict[str, List[Tuple]], docs: Dict[str, str], wikimapper_index:str) -> List[Dict]:
wikimapper = WikiMapper(wikimapper_index)
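    # WikiMapper resolves Wikipedia page titles to Wikidata QIDs via a precomputed SQLite index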
linked_entities_json = []
    for docid, doc_text in tqdm(docs.items(), desc='Enriching EL results', total=len(docs)):
if docid not in rel_linked_entities:
linked_entities_json.append({'id': docid, 'contents': doc_text, 'entities': []})
else:
linked_entities_info = []
ents = rel_linked_entities[docid]
for start_pos, mention_length, ent_text, ent_wikipedia_id, conf_score, ner_score, ent_type in ents:
                # find the entity's Wikidata ID using its REL result (i.e. the linked Wikipedia title)
                ent_wikipedia_id = ent_wikipedia_id.replace('&amp;', '&')  # unescape HTML-encoded ampersands in titles
ent_wikidata_id = wikimapper.title_to_id(ent_wikipedia_id)
# write results as json objects
linked_entities_info.append({'start_pos': start_pos, 'end_pos': start_pos + mention_length, 'ent_text': ent_text,
'wikipedia_id': ent_wikipedia_id, 'wikidata_id': ent_wikidata_id,
'ent_type': ent_type})
linked_entities_json.append({'id': docid, 'contents': doc_text, 'entities': linked_entities_info})
return linked_entities_json
def main():
parser = argparse.ArgumentParser()
parser.add_argument('-p', '--input_path', type=str, help='path to input texts')
parser.add_argument('-u', '--rel_base_url', type=str, help='directory containing all required REL data folders')
parser.add_argument('-m', '--rel_ed_model_path', type=str, help='path to the REL entity disambiguation model')
parser.add_argument('-v', '--rel_wiki_version', type=str, help='wikipedia corpus version used for REL')
parser.add_argument('-w', '--wikimapper_index', type=str, help='precomputed index used by Wikimapper')
parser.add_argument('-s', '--spacy_model', type=str, help='spacy model type')
parser.add_argument('-o', '--output_path', type=str, help='path to output json file')
args = parser.parse_args()
docs = read_docs(args.input_path)
rel_linked_entities = rel_entity_linking(docs, args.spacy_model, args.rel_base_url, args.rel_wiki_version,
args.rel_ed_model_path)
linked_entities_json = enrich_el_results(rel_linked_entities, docs, args.wikimapper_index)
with jsonlines.open(args.output_path, mode='w') as writer:
writer.write_all(linked_entities_json)
if __name__ == '__main__':
main()
sys.exit(0)