#
# Pyserini: Reproducible IR research with sparse and dense representations
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import argparse
import jsonlines
import spacy
import sys
from REL.mention_detection import MentionDetectionBase
from REL.utils import process_results, split_in_words
from REL.entity_disambiguation import EntityDisambiguation
from REL.ner import Span
from wikimapper import WikiMapper
from typing import Dict, List, Tuple
from tqdm import tqdm
# Spacy Mention Detection class which overrides the NERBase class in the REL entity linking process
class NERSpacyMD(MentionDetectionBase):
    """Mention detection for REL entity linking backed by a spaCy NER pipeline.

    Overrides the NER step of the REL pipeline: spaCy tags each document and
    the resulting entities (restricted to `self.ner_labels`) are turned into
    REL `Span`s and candidate mentions.
    """

    def __init__(self, base_url: str, wiki_version: str, spacy_model: str):
        """
        :param base_url: directory containing all required REL data folders.
        :param wiki_version: wikipedia corpus version used by REL.
        :param spacy_model: name of the spaCy model to load (e.g. 'en_core_web_lg').
        """
        super().__init__(base_url, wiki_version)
        # we only want to link entities of specific types
        self.ner_labels = ['PERSON', 'NORP', 'FAC', 'ORG', 'GPE', 'LOC', 'PRODUCT', 'EVENT', 'WORK_OF_ART',
                           'LAW', 'LANGUAGE', 'DATE', 'TIME', 'MONEY', 'QUANTITY']
        self.spacy_model = spacy_model
        spacy.prefer_gpu()  # falls back to CPU silently if no GPU is available
        self.tagger = spacy.load(spacy_model)

    # mandatory function which overrides NERBase.predict()
    def predict(self, doc: spacy.tokens.Doc) -> List[Span]:
        """Return REL Spans for the entities in `doc` whose NER label we link."""
        # confidence is fixed at 0: spaCy's NER does not expose per-entity scores
        return [Span(ent.text, ent.start_char, ent.end_char, 0, ent.label_)
                for ent in doc.ents if ent.label_ in self.ner_labels]

    def find_mentions(self, dataset: Dict[str, str]) -> Tuple[Dict[str, List[Dict]], int]:
        """
        Responsible for finding mentions given a set of documents in a batch-wise manner. More specifically,
        it returns the mention, its left/right context and a set of candidates.

        :param dataset: mapping of doc id -> document text.
        :return: (dictionary with mentions per document, total number of mentions kept).
        """
        results = {}
        total_ment = 0
        for doc in tqdm(dataset, desc='Finding mentions', total=len(dataset)):
            result_doc = []
            doc_text = dataset[doc]
            spacy_doc = self.tagger(doc_text)
            for entity in self.predict(spacy_doc):
                text, start_pos, end_pos, conf, tag = (
                    entity.text,
                    entity.start_pos,
                    entity.end_pos,
                    entity.score,
                    entity.tag,
                )
                mention = self.preprocess_mention(text)
                cands = self.get_candidates(mention)
                if len(cands) == 0:
                    # a mention without candidate entities cannot be disambiguated; drop it
                    continue
                total_ment += 1
                # Re-create ngram as 'text' is at times changed by Flair (e.g. double spaces are removed).
                ngram = doc_text[start_pos:end_pos]
                # context is capped at 100 words on each side of the mention
                left_ctxt = " ".join(split_in_words(doc_text[:start_pos])[-100:])
                right_ctxt = " ".join(split_in_words(doc_text[end_pos:])[:100])
                result_doc.append({
                    "mention": mention,
                    "context": (left_ctxt, right_ctxt),
                    "candidates": cands,
                    "gold": ["NONE"],
                    "pos": start_pos,
                    "sent_idx": 0,
                    "ngram": ngram,
                    "end_pos": end_pos,
                    "sentence": doc_text,
                    "conf_md": conf,
                    "tag": tag,
                })
            results[doc] = result_doc
        return results, total_ment
# run REL entity linking on processed doc
def rel_entity_linking(docs: Dict[str, str], spacy_model: str, rel_base_url: str, rel_wiki_version: str, rel_ed_model_path: str) -> Dict[str, List[Tuple]]:
    """Run the full REL pipeline (mention detection, then entity disambiguation) over `docs`.

    :param docs: mapping of doc id -> document text.
    :param spacy_model: spaCy model name used for mention detection.
    :param rel_base_url: directory containing all required REL data folders.
    :param rel_wiki_version: wikipedia corpus version used by REL.
    :param rel_ed_model_path: path to the REL entity disambiguation model.
    :return: per-document lists of linked-entity tuples produced by REL.
    """
    detector = NERSpacyMD(rel_base_url, rel_wiki_version, spacy_model)
    mentions_dataset, _ = detector.find_mentions(docs)
    ed_config = {
        'mode': 'eval',
        'model_path': rel_ed_model_path,
    }
    disambiguator = EntityDisambiguation(rel_base_url, rel_wiki_version, ed_config)
    predictions, _ = disambiguator.predict(mentions_dataset)
    return process_results(mentions_dataset, predictions, docs)
# read input pyserini json docs into a dictionary
def read_docs(input_path: str) -> Dict[str, str]:
    """Load a Pyserini JSONL collection as a {doc id: contents} mapping."""
    with jsonlines.open(input_path) as reader:
        return {obj['id']: obj['contents'] for obj in tqdm(reader, desc='Reading docs')}
# enrich REL entity linking results with entities' wikidata ids, and write final results as json objects
# rel_linked_entities: Tuples of entities are composed by start_pos:int, mention_length:int, ent_text:str, ent_wikipedia_id:str, conf_score:float, ner_score:int, ent_type:str
def enrich_el_results(rel_linked_entities: Dict[str, List[Tuple]], docs: Dict[str, str], wikimapper_index: str) -> List[Dict]:
    """Attach Wikidata ids to REL entity-linking results and shape them as JSON records.

    :param rel_linked_entities: per-doc-id lists of REL tuples
        (start_pos, mention_length, ent_text, ent_wikipedia_id, conf_score, ner_score, ent_type).
    :param docs: mapping of doc id -> document text; every doc gets an output record,
        including docs with no linked entities.
    :param wikimapper_index: precomputed index used by Wikimapper.
    :return: list of dicts with 'id', 'contents' and 'entities' keys, one per doc.
    """
    wikimapper = WikiMapper(wikimapper_index)
    linked_entities_json = []
    # fix: progress total must match the iterable (docs), not just the docs that have entities
    for docid, doc_text in tqdm(docs.items(), desc='Enriching EL results', total=len(docs)):
        if docid not in rel_linked_entities:
            # no entities were linked in this doc; keep it with an empty entity list
            linked_entities_json.append({'id': docid, 'contents': doc_text, 'entities': []})
        else:
            linked_entities_info = []
            ents = rel_linked_entities[docid]
            for start_pos, mention_length, ent_text, ent_wikipedia_id, conf_score, ner_score, ent_type in ents:
                # find entities' wikidata ids using their REL results (i.e. linked wikipedia ids);
                # fix: unescape HTML-escaped titles ('&amp;' -> '&') so the Wikimapper lookup succeeds
                # (the original no-op replace('&', '&') was an extraction-garbled '&amp;')
                ent_wikipedia_id = ent_wikipedia_id.replace('&amp;', '&')
                ent_wikidata_id = wikimapper.title_to_id(ent_wikipedia_id)
                # write results as json objects
                linked_entities_info.append({'start_pos': start_pos, 'end_pos': start_pos + mention_length,
                                             'ent_text': ent_text, 'wikipedia_id': ent_wikipedia_id,
                                             'wikidata_id': ent_wikidata_id, 'ent_type': ent_type})
            linked_entities_json.append({'id': docid, 'contents': doc_text, 'entities': linked_entities_info})
    return linked_entities_json
def main():
    """CLI entry point: read docs, run REL entity linking, enrich with Wikidata ids, write JSONL."""
    arg_parser = argparse.ArgumentParser()
    # (short flag, long flag, help text) for every string-valued option
    for short_flag, long_flag, help_text in [
        ('-p', '--input_path', 'path to input texts'),
        ('-u', '--rel_base_url', 'directory containing all required REL data folders'),
        ('-m', '--rel_ed_model_path', 'path to the REL entity disambiguation model'),
        ('-v', '--rel_wiki_version', 'wikipedia corpus version used for REL'),
        ('-w', '--wikimapper_index', 'precomputed index used by Wikimapper'),
        ('-s', '--spacy_model', 'spacy model type'),
        ('-o', '--output_path', 'path to output json file'),
    ]:
        arg_parser.add_argument(short_flag, long_flag, type=str, help=help_text)
    args = arg_parser.parse_args()

    docs = read_docs(args.input_path)
    rel_linked_entities = rel_entity_linking(docs, args.spacy_model, args.rel_base_url,
                                             args.rel_wiki_version, args.rel_ed_model_path)
    linked_entities_json = enrich_el_results(rel_linked_entities, docs, args.wikimapper_index)
    with jsonlines.open(args.output_path, mode='w') as writer:
        writer.write_all(linked_entities_json)
# script entry point; exits with status 0 on success
if __name__ == '__main__':
    main()
    sys.exit(0)