#
# Pyserini: Reproducible IR research with sparse and dense representations
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import argparse
import jsonlines
import spacy
import sys
from REL.mention_detection import MentionDetectionBase
from REL.utils import process_results, split_in_words
from REL.entity_disambiguation import EntityDisambiguation
from REL.ner import Span
from wikimapper import WikiMapper
from typing import Dict, List, Tuple
from tqdm import tqdm

# spaCy-based mention detection class, extending REL's MentionDetectionBase; it provides the NER step
# of the REL entity linking process
class NERSpacyMD(MentionDetectionBase):
    def __init__(self, base_url: str, wiki_version: str, spacy_model: str):
        super().__init__(base_url, wiki_version)
        # we only want to link entities of specific types
        self.ner_labels = ['PERSON', 'NORP', 'FAC', 'ORG', 'GPE', 'LOC', 'PRODUCT', 'EVENT', 'WORK_OF_ART',
                           'LAW', 'LANGUAGE', 'DATE', 'TIME', 'MONEY', 'QUANTITY']
        self.spacy_model = spacy_model
        spacy.prefer_gpu()
        self.tagger = spacy.load(spacy_model)

    # mandatory method: returns the NER predictions (as REL Span objects) that the REL pipeline expects
    def predict(self, doc: spacy.tokens.Doc) -> List[Span]:
        spans = []
        for ent in doc.ents:
            if ent.label_ in self.ner_labels:
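                # REL's Span fields, as accessed in find_mentions() below, are (text, start_pos, end_pos,
                # score, tag); spaCy's doc.ents does not expose a per-entity confidence, so the score is
                # hard-coded to 0 here.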
                spans.append(Span(ent.text, ent.start_char, ent.end_char, 0, ent.label_))
        return spans

    """
    Responsible for finding mentions given a set of documents in a batch-wise manner. More specifically,
    it returns the mention, its left/right context and a set of candidates.
    :return: Dictionary with mentions per document.
    """

    def find_mentions(self, dataset: Dict[str, str]) -> Tuple[Dict[str, List[Dict]], int]:
        results = {}
        total_ment = 0
        for i, doc in tqdm(enumerate(dataset), desc='Finding mentions', total=len(dataset)):
            result_doc = []
            doc_text = dataset[doc]
            spacy_doc = self.tagger(doc_text)
            spans = self.predict(spacy_doc)
            for entity in spans:
                text, start_pos, end_pos, conf, tag = (
                    entity.text,
                    entity.start_pos,
                    entity.end_pos,
                    entity.score,
                    entity.tag,
                )
                m = self.preprocess_mention(text)
                cands = self.get_candidates(m)
                if len(cands) == 0:
                    continue
                total_ment += 1
                # Re-create the ngram from the original text, as the mention text may have been normalized (e.g. double spaces removed).
                ngram = doc_text[start_pos:end_pos]
                left_ctxt = " ".join(split_in_words(doc_text[:start_pos])[-100:])
                right_ctxt = " ".join(split_in_words(doc_text[end_pos:])[:100])
                res = {
                    "mention": m,
                    "context": (left_ctxt, right_ctxt),
                    "candidates": cands,
                    "gold": ["NONE"],
                    "pos": start_pos,
                    "sent_idx": 0,
                    "ngram": ngram,
                    "end_pos": end_pos,
                    "sentence": doc_text,
                    "conf_md": conf,
                    "tag": tag,
                }
                result_doc.append(res)
            results[doc] = result_doc
        return results, total_ment


# run REL entity linking on the processed docs
def rel_entity_linking(docs: Dict[str, str], spacy_model: str, rel_base_url: str, rel_wiki_version: str,
                       rel_ed_model_path: str) -> Dict[str, List[Tuple]]:
    mention_detection = NERSpacyMD(rel_base_url, rel_wiki_version, spacy_model)
    mentions_dataset, _ = mention_detection.find_mentions(docs)
    config = {
        'mode': 'eval',
        'model_path': rel_ed_model_path,
    }
    ed_model = EntityDisambiguation(rel_base_url, rel_wiki_version, config)
    predictions, _ = ed_model.predict(mentions_dataset)

    linked_entities = process_results(mentions_dataset, predictions, docs)
    return linked_entities


# read input Pyserini JSONL docs into a dictionary mapping doc id to contents
def read_docs(input_path: str) -> Dict[str, str]:
    docs = {}
    with jsonlines.open(input_path) as reader:
        for obj in tqdm(reader, desc='Reading docs'):
            docs[obj['id']] = obj['contents']
    return docs
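
# A sketch of the expected input, assuming the standard Pyserini JSONL collection format: each line is a
# JSON object with 'id' and 'contents' fields (the example document below is purely illustrative), e.g.
#   {"id": "doc0", "contents": "Barack Obama visited Paris in 2016."}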


# enrich REL entity linking results with entities' Wikidata ids, and format the final results as json objects
# rel_linked_entities: each entity tuple is composed of start_pos:int, mention_length:int, ent_text:str,
#                      ent_wikipedia_id:str, conf_score:float, ner_score:int, ent_type:str
def enrich_el_results(rel_linked_entities: Dict[str, List[Tuple]], docs: Dict[str, str],
                      wikimapper_index: str) -> List[Dict]:
    wikimapper = WikiMapper(wikimapper_index)
    linked_entities_json = []
    for docid, doc_text in tqdm(docs.items(), desc='Enriching EL results', total=len(docs)):
        if docid not in rel_linked_entities:
            linked_entities_json.append({'id': docid, 'contents': doc_text, 'entities': []})
        else:
            linked_entities_info = []
            ents = rel_linked_entities[docid]
            for start_pos, mention_length, ent_text, ent_wikipedia_id, conf_score, ner_score, ent_type in ents:
                # find entities' wikidata ids using their REL results (i.e. linked wikipedia ids)
                ent_wikipedia_id = ent_wikipedia_id.replace('&amp;', '&')
                ent_wikidata_id = wikimapper.title_to_id(ent_wikipedia_id)

                # write results as json objects
                linked_entities_info.append({'start_pos': start_pos, 'end_pos': start_pos + mention_length, 'ent_text': ent_text,
                                             'wikipedia_id': ent_wikipedia_id, 'wikidata_id': ent_wikidata_id,
                                             'ent_type': ent_type})
            linked_entities_json.append({'id': docid, 'contents': doc_text, 'entities': linked_entities_info})
    return linked_entities_json
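
# A sketch of one enriched output record (field values below are illustrative, not taken from a real run):
#   {"id": "doc0", "contents": "...", "entities": [{"start_pos": 0, "end_pos": 12, "ent_text": "Barack Obama",
#    "wikipedia_id": "Barack_Obama", "wikidata_id": "Q76", "ent_type": "PERSON"}]}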

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('-p', '--input_path', type=str, help='path to input texts')
    parser.add_argument('-u', '--rel_base_url', type=str, help='directory containing all required REL data folders')
    parser.add_argument('-m', '--rel_ed_model_path', type=str, help='path to the REL entity disambiguation model')
    parser.add_argument('-v', '--rel_wiki_version', type=str, help='wikipedia corpus version used for REL')
    parser.add_argument('-w', '--wikimapper_index', type=str, help='precomputed index used by Wikimapper')
    parser.add_argument('-s', '--spacy_model', type=str, help='spacy model type')
    parser.add_argument('-o', '--output_path', type=str, help='path to output json file')
    args = parser.parse_args()

    docs = read_docs(args.input_path)
    rel_linked_entities = rel_entity_linking(docs, args.spacy_model, args.rel_base_url, args.rel_wiki_version,
                                             args.rel_ed_model_path)
    linked_entities_json = enrich_el_results(rel_linked_entities, docs, args.wikimapper_index)
    with jsonlines.open(args.output_path, mode='w') as writer:
        writer.write_all(linked_entities_json)
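
# Example invocation (the script name and all paths below are placeholders, shown only to illustrate the arguments):
#   python entity_linking.py --input_path docs.jsonl --rel_base_url /path/to/rel_data \
#       --rel_ed_model_path /path/to/ed_model --rel_wiki_version wiki_2019 \
#       --wikimapper_index /path/to/index_enwiki.db --spacy_model en_core_web_sm --output_path docs_el.jsonl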


if __name__ == '__main__':
    main()
    sys.exit(0)