#
# Pyserini: Reproducible IR research with sparse and dense representations
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import argparse
import jsonlines
import spacy
import sys
from REL.mention_detection import MentionDetectionBase
from REL.utils import process_results, split_in_words
from REL.entity_disambiguation import EntityDisambiguation
from REL.ner import Span
from wikimapper import WikiMapper
from typing import Dict, List, Tuple
from tqdm import tqdm
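
# Note: spacy.load() in NERSpacyMD below assumes that the spaCy model passed via --spacy_model has
# already been installed, e.g. with `python -m spacy download en_core_web_sm` (en_core_web_sm is only
# an example; pass whichever model you intend to use).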


# Spacy Mention Detection class which overrides the NERBase class in the REL entity linking process
class NERSpacyMD(MentionDetectionBase):
    def __init__(self, base_url: str, wiki_version: str, spacy_model: str):
        super().__init__(base_url, wiki_version)
        # we only want to link entities of specific types
        self.ner_labels = ['PERSON', 'NORP', 'FAC', 'ORG', 'GPE', 'LOC', 'PRODUCT', 'EVENT', 'WORK_OF_ART',
                           'LAW', 'LANGUAGE', 'DATE', 'TIME', 'MONEY', 'QUANTITY']
        self.spacy_model = spacy_model
        spacy.prefer_gpu()
        self.tagger = spacy.load(spacy_model)

    # mandatory function which overrides NERBase.predict()
    def predict(self, doc: spacy.tokens.Doc) -> List[Span]:
        spans = []
        for ent in doc.ents:
            if ent.label_ in self.ner_labels:
                spans.append(Span(ent.text, ent.start_char, ent.end_char, 0, ent.label_))
        return spans
""" | |
Responsible for finding mentions given a set of documents in a batch-wise manner. More specifically, | |
it returns the mention, its left/right context and a set of candidates. | |
:return: Dictionary with mentions per document. | |
""" | |
def find_mentions(self, dataset: Dict[str, str]) -> Tuple[Dict[str, List[Dict]], int]: | |
        results = {}
        total_ment = 0
        for doc in tqdm(dataset, desc='Finding mentions', total=len(dataset)):
            result_doc = []
            doc_text = dataset[doc]
            spacy_doc = self.tagger(doc_text)
            spans = self.predict(spacy_doc)
            for entity in spans:
                text, start_pos, end_pos, conf, tag = (
                    entity.text,
                    entity.start_pos,
                    entity.end_pos,
                    entity.score,
                    entity.tag,
                )
                m = self.preprocess_mention(text)
                cands = self.get_candidates(m)
                if len(cands) == 0:
                    continue
                total_ment += 1
                # re-create the ngram, as 'text' is at times changed by the tagger (e.g. double spaces are removed)
                ngram = doc_text[start_pos:end_pos]
                left_ctxt = " ".join(split_in_words(doc_text[:start_pos])[-100:])
                right_ctxt = " ".join(split_in_words(doc_text[end_pos:])[:100])
                res = {
                    "mention": m,
                    "context": (left_ctxt, right_ctxt),
                    "candidates": cands,
                    "gold": ["NONE"],
                    "pos": start_pos,
                    "sent_idx": 0,
                    "ngram": ngram,
                    "end_pos": end_pos,
                    "sentence": doc_text,
                    "conf_md": conf,
                    "tag": tag,
                }
                result_doc.append(res)
            results[doc] = result_doc
        return results, total_ment


# run REL entity linking on the processed docs
def rel_entity_linking(docs: Dict[str, str], spacy_model: str, rel_base_url: str, rel_wiki_version: str,
                       rel_ed_model_path: str) -> Dict[str, List[Tuple]]:
    mention_detection = NERSpacyMD(rel_base_url, rel_wiki_version, spacy_model)
    mentions_dataset, _ = mention_detection.find_mentions(docs)

    config = {
        'mode': 'eval',
        'model_path': rel_ed_model_path,
    }
    ed_model = EntityDisambiguation(rel_base_url, rel_wiki_version, config)
    predictions, _ = ed_model.predict(mentions_dataset)

    linked_entities = process_results(mentions_dataset, predictions, docs)
    return linked_entities


# read input pyserini json docs into a dictionary
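# each input line is expected to be a Pyserini-style JSON document with 'id' and 'contents' fields;
# an illustrative (made-up) example line:
#   {"id": "doc1", "contents": "Barack Obama was the 44th president of the United States."}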
def read_docs(input_path: str) -> Dict[str, str]:
    docs = {}
    with jsonlines.open(input_path) as reader:
        for obj in tqdm(reader, desc='Reading docs'):
            docs[obj['id']] = obj['contents']
    return docs


# enrich REL entity linking results with entities' wikidata ids, and write final results as json objects
# rel_linked_entities: tuples of entities are composed of start_pos:int, mention_length:int, ent_text:str,
#                      ent_wikipedia_id:str, conf_score:float, ner_score:int, ent_type:str
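# each output object keeps the original 'id'/'contents' and adds an 'entities' list; an illustrative
# (made-up) example of a single entity entry:
#   {'start_pos': 0, 'end_pos': 12, 'ent_text': 'Barack Obama', 'wikipedia_id': 'Barack_Obama',
#    'wikidata_id': 'Q76', 'ent_type': 'PERSON'}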
def enrich_el_results(rel_linked_entities: Dict[str, List[Tuple]], docs: Dict[str, str], wikimapper_index: str) -> List[Dict]:
    wikimapper = WikiMapper(wikimapper_index)
    linked_entities_json = []
    for docid, doc_text in tqdm(docs.items(), desc='Enriching EL results', total=len(docs)):
        if docid not in rel_linked_entities:
            # no entities were linked in this document
            linked_entities_json.append({'id': docid, 'contents': doc_text, 'entities': []})
        else:
            linked_entities_info = []
            ents = rel_linked_entities[docid]
            for start_pos, mention_length, ent_text, ent_wikipedia_id, conf_score, ner_score, ent_type in ents:
                # find entities' wikidata ids using their REL results (i.e. linked wikipedia ids);
                # undo the HTML escaping of '&' in titles before looking them up
                ent_wikipedia_id = ent_wikipedia_id.replace('&amp;', '&')
                ent_wikidata_id = wikimapper.title_to_id(ent_wikipedia_id)
                # write results as json objects
                linked_entities_info.append({'start_pos': start_pos, 'end_pos': start_pos + mention_length,
                                             'ent_text': ent_text, 'wikipedia_id': ent_wikipedia_id,
                                             'wikidata_id': ent_wikidata_id, 'ent_type': ent_type})
            linked_entities_json.append({'id': docid, 'contents': doc_text, 'entities': linked_entities_info})
    return linked_entities_json
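

# example invocation (script name, paths, and values below are placeholders; substitute your own REL data
# directory, ED model, wiki version, wikimapper index, spaCy model, and input/output files):
#   python path/to/this_script.py --input_path docs.jsonl --rel_base_url /path/to/rel_data \
#       --rel_ed_model_path /path/to/ed_model --rel_wiki_version wiki_2019 \
#       --wikimapper_index index_enwiki.db --spacy_model en_core_web_sm --output_path linked_docs.jsonl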
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('-p', '--input_path', type=str, help='path to input texts')
    parser.add_argument('-u', '--rel_base_url', type=str, help='directory containing all required REL data folders')
    parser.add_argument('-m', '--rel_ed_model_path', type=str, help='path to the REL entity disambiguation model')
    parser.add_argument('-v', '--rel_wiki_version', type=str, help='wikipedia corpus version used for REL')
    parser.add_argument('-w', '--wikimapper_index', type=str, help='precomputed index used by Wikimapper')
    parser.add_argument('-s', '--spacy_model', type=str, help='spacy model type')
    parser.add_argument('-o', '--output_path', type=str, help='path to output json file')
    args = parser.parse_args()

    docs = read_docs(args.input_path)
    rel_linked_entities = rel_entity_linking(docs, args.spacy_model, args.rel_base_url, args.rel_wiki_version,
                                             args.rel_ed_model_path)
    linked_entities_json = enrich_el_results(rel_linked_entities, docs, args.wikimapper_index)
    with jsonlines.open(args.output_path, mode='w') as writer:
        writer.write_all(linked_entities_json)


if __name__ == '__main__':
    main()
    sys.exit(0)