#
# Pyserini: Reproducible IR research with sparse and dense representations
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import argparse
import sys
from typing import Dict, List, Tuple

import jsonlines
import spacy
from REL.entity_disambiguation import EntityDisambiguation
from REL.mention_detection import MentionDetectionBase
from REL.ner import Span
from REL.utils import process_results, split_in_words
from tqdm import tqdm
from wikimapper import WikiMapper


# spaCy mention detection class for the REL entity linking process; it subclasses REL's
# MentionDetectionBase and provides the predict() hook that NERBase subclasses normally implement
class NERSpacyMD(MentionDetectionBase):
    def __init__(self, base_url: str, wiki_version: str, spacy_model: str):
        super().__init__(base_url, wiki_version)
        # we only want to link entities of specific types
        self.ner_labels = ['PERSON', 'NORP', 'FAC', 'ORG', 'GPE', 'LOC', 'PRODUCT', 'EVENT',
                           'WORK_OF_ART', 'LAW', 'LANGUAGE', 'DATE', 'TIME', 'MONEY', 'QUANTITY']
        self.spacy_model = spacy_model
        spacy.prefer_gpu()
        self.tagger = spacy.load(spacy_model)

    # mandatory function which overrides NERBase.predict()
    def predict(self, doc: spacy.tokens.Doc) -> List[Span]:
        spans = []
        for ent in doc.ents:
            if ent.label_ in self.ner_labels:
                spans.append(Span(ent.text, ent.start_char, ent.end_char, 0, ent.label_))
        return spans

    def find_mentions(self, dataset: Dict[str, str]) -> Tuple[Dict[str, List[Dict]], int]:
        """Find mentions in a set of documents in a batch-wise manner. More specifically,
        return each mention with its left/right context and a set of candidate entities.

        :return: Dictionary with mentions per document, and the total number of mentions.
        """
        results = {}
        total_ment = 0
        for doc in tqdm(dataset, desc='Finding mentions', total=len(dataset)):
            result_doc = []
            doc_text = dataset[doc]
            spacy_doc = self.tagger(doc_text)
            spans = self.predict(spacy_doc)
            for entity in spans:
                text, start_pos, end_pos, conf, tag = (
                    entity.text, entity.start_pos, entity.end_pos, entity.score, entity.tag
                )
                m = self.preprocess_mention(text)
                cands = self.get_candidates(m)
                if len(cands) == 0:
                    continue
                total_ment += 1
                # re-create the ngram from the raw document text, as 'text' is at times changed
                # by the tagger (e.g. double spaces are removed)
                ngram = doc_text[start_pos:end_pos]
                left_ctxt = ' '.join(split_in_words(doc_text[:start_pos])[-100:])
                right_ctxt = ' '.join(split_in_words(doc_text[end_pos:])[:100])
                res = {
                    'mention': m,
                    'context': (left_ctxt, right_ctxt),
                    'candidates': cands,
                    'gold': ['NONE'],
                    'pos': start_pos,
                    'sent_idx': 0,
                    'ngram': ngram,
                    'end_pos': end_pos,
                    'sentence': doc_text,
                    'conf_md': conf,
                    'tag': tag,
                }
                result_doc.append(res)
            results[doc] = result_doc
        return results, total_ment
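
# A minimal usage sketch of NERSpacyMD on its own (not part of this script's flow); the REL data
# directory, wiki version, and spaCy model below are placeholders for locally prepared resources:
#
#   md = NERSpacyMD('/path/to/rel_data', 'wiki_2019', 'en_core_web_sm')
#   mentions, total = md.find_mentions({'doc0': 'Paris is the capital of France.'})
#
# each mention carries its surface form, character offsets, context windows, and the candidate
# entities that EntityDisambiguation scores downstream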

# run REL entity linking on the processed docs
def rel_entity_linking(docs: Dict[str, str], spacy_model: str, rel_base_url: str, rel_wiki_version: str,
                       rel_ed_model_path: str) -> Dict[str, List[Tuple]]:
    mention_detection = NERSpacyMD(rel_base_url, rel_wiki_version, spacy_model)
    mentions_dataset, _ = mention_detection.find_mentions(docs)

    config = {
        'mode': 'eval',
        'model_path': rel_ed_model_path,
    }
    ed_model = EntityDisambiguation(rel_base_url, rel_wiki_version, config)
    predictions, _ = ed_model.predict(mentions_dataset)

    linked_entities = process_results(mentions_dataset, predictions, docs)
    return linked_entities


# read input Pyserini json docs into a dictionary
def read_docs(input_path: str) -> Dict[str, str]:
    docs = {}
    with jsonlines.open(input_path) as reader:
        for obj in tqdm(reader, desc='Reading docs'):
            docs[obj['id']] = obj['contents']
    return docs


# enrich REL entity linking results with entities' Wikidata ids, and return the final results as
# json objects; each tuple in rel_linked_entities is composed of start_pos:int, mention_length:int,
# ent_text:str, ent_wikipedia_id:str, conf_score:float, ner_score:int, ent_type:str
def enrich_el_results(rel_linked_entities: Dict[str, List[Tuple]], docs: Dict[str, str],
                      wikimapper_index: str) -> List[Dict]:
    wikimapper = WikiMapper(wikimapper_index)
    linked_entities_json = []
    for docid, doc_text in tqdm(docs.items(), desc='Enriching EL results', total=len(docs)):
        if docid not in rel_linked_entities:
            linked_entities_json.append({'id': docid, 'contents': doc_text, 'entities': []})
        else:
            linked_entities_info = []
            ents = rel_linked_entities[docid]
            for start_pos, mention_length, ent_text, ent_wikipedia_id, conf_score, ner_score, ent_type in ents:
                # find the entity's Wikidata id from its REL result (i.e. the linked Wikipedia id);
                # un-escape '&amp;' first so titles containing '&' resolve correctly
                ent_wikipedia_id = ent_wikipedia_id.replace('&amp;', '&')
                ent_wikidata_id = wikimapper.title_to_id(ent_wikipedia_id)
                # write results as json objects
                linked_entities_info.append({'start_pos': start_pos, 'end_pos': start_pos + mention_length,
                                             'ent_text': ent_text, 'wikipedia_id': ent_wikipedia_id,
                                             'wikidata_id': ent_wikidata_id, 'ent_type': ent_type})
            linked_entities_json.append({'id': docid, 'contents': doc_text, 'entities': linked_entities_info})
    return linked_entities_json
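
# for reference, input documents follow the standard Pyserini JSONL format, and the output adds an
# 'entities' field per document; the values below are illustrative only:
#
#   input:  {"id": "doc1", "contents": "Paris is the capital of France."}
#   output: {"id": "doc1", "contents": "Paris is the capital of France.",
#            "entities": [{"start_pos": 0, "end_pos": 5, "ent_text": "Paris", "wikipedia_id": "Paris",
#                          "wikidata_id": "Q90", "ent_type": "GPE"}]}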

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('-p', '--input_path', type=str, help='path to input texts')
    parser.add_argument('-u', '--rel_base_url', type=str, help='directory containing all required REL data folders')
    parser.add_argument('-m', '--rel_ed_model_path', type=str, help='path to the REL entity disambiguation model')
    parser.add_argument('-v', '--rel_wiki_version', type=str, help='Wikipedia corpus version used for REL')
    parser.add_argument('-w', '--wikimapper_index', type=str, help='precomputed index used by WikiMapper')
    parser.add_argument('-s', '--spacy_model', type=str, help='spaCy model type')
    parser.add_argument('-o', '--output_path', type=str, help='path to output json file')
    args = parser.parse_args()

    docs = read_docs(args.input_path)
    rel_linked_entities = rel_entity_linking(docs, args.spacy_model, args.rel_base_url, args.rel_wiki_version,
                                             args.rel_ed_model_path)
    linked_entities_json = enrich_el_results(rel_linked_entities, docs, args.wikimapper_index)
    with jsonlines.open(args.output_path, mode='w') as writer:
        writer.write_all(linked_entities_json)


if __name__ == '__main__':
    main()
    sys.exit(0)
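
# example invocation (all paths and names are placeholders for locally prepared resources):
#
#   python entity_linking.py \
#       --input_path docs.jsonl \
#       --rel_base_url /path/to/rel_data \
#       --rel_ed_model_path /path/to/rel_data/ed-wiki-2019/model \
#       --rel_wiki_version wiki_2019 \
#       --wikimapper_index /path/to/index_enwiki.db \
#       --spacy_model en_core_web_sm \
#       --output_path linked_docs.jsonl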