from transformers import Pipeline
import nltk
import requests
import torch

nltk.download("averaged_perceptron_tagger")
nltk.download("averaged_perceptron_tagger_eng")

NEL_MODEL = "nel-mgenre-multilingual"


def get_wikipedia_page_props(input_str: str):
    """
    Retrieves the QID for a given Wikipedia page name from the specified language Wikipedia.
    If the request fails, it falls back to using the OpenRefine Wikidata API.

    Args:
        input_str (str): The input string in the format "page_name >> language".

    Returns:
        str: The QID or "NIL" if the QID is not found.
    """
    # print(f"Input string: {input_str}")
    if ">>" not in input_str:
        page_name = input_str
        language = "en"
        print(
            f"<< was not found in {input_str} so we are checking with these values: Page name: {page_name}, Language: {language}"
        )
    else:
        # Preprocess the input string
        try:
            page_name, language = input_str.split(">>")
            page_name = page_name.strip()
            language = language.strip()
        except:
            page_name = input_str
            language = "en"
            print(
                f"<< was not found in {input_str} so we are checking with these values: Page name: {page_name}, Language: {language}"
            )
    wikipedia_url = f"https://{language}.wikipedia.org/w/api.php"
    wikipedia_params = {
        "action": "query",
        "prop": "pageprops",
        "format": "json",
        "titles": page_name,
    }

    qid = "NIL"
    try:
        # Attempt to fetch from Wikipedia API
        response = requests.get(wikipedia_url, params=wikipedia_params)
        response.raise_for_status()
        data = response.json()

        if "pages" in data["query"]:
            page_id = list(data["query"]["pages"].keys())[0]

            if "pageprops" in data["query"]["pages"][page_id]:
                page_props = data["query"]["pages"][page_id]["pageprops"]

                if "wikibase_item" in page_props:
                    # print(page_props["wikibase_item"], language)
                    return page_props["wikibase_item"], language
                else:
                    return qid, language
            else:
                return qid, language
        else:
            return qid, language
    except Exception as e:
        return qid, language
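
# Illustrative usage (sketch, not part of the original file): mGENRE-style predictions
# look like "Page name >> lang", which the helper above resolves to a Wikidata QID, e.g.
#     get_wikipedia_page_props("Douglas Adams >> en")
#     # -> ("Q42", "en") if the Wikipedia API call succeeds, ("NIL", "en") otherwise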


def get_wikipedia_title(qid, language="en"):
    url = f"https://www.wikidata.org/w/api.php"
    params = {
        "action": "wbgetentities",
        "format": "json",
        "ids": qid,
        "props": "sitelinks/urls",
        "sitefilter": f"{language}wiki",
    }

    response = requests.get(url, params=params)
    try:
        response.raise_for_status()  # Raise an HTTPError if the response was not 2xx
        data = response.json()
    except requests.exceptions.RequestException as e:
        print(f"HTTP error: {e}")
        return "NIL", "None"
    except ValueError:  # Catch JSON decode errors
        print(f"Invalid JSON response: {response.text}")
        return "NIL", "None"

    try:
        title = data["entities"][qid]["sitelinks"][f"{language}wiki"]["title"]
        url = data["entities"][qid]["sitelinks"][f"{language}wiki"]["url"]
        return title, url
    except KeyError:
        return "NIL", "None"


class NelPipeline(Pipeline):
    """
    Custom transformers pipeline for named entity linking: the input text marks the
    target mention with [START] ... [END], the seq2seq model generates a Wikipedia
    page prediction, and postprocessing links it to a Wikidata QID, title and URL.
    """

    def _sanitize_parameters(self, **kwargs):
        preprocess_kwargs = {}
        if "text" in kwargs:
            preprocess_kwargs["text"] = kwargs["text"]

        return preprocess_kwargs, {}, {}

    def preprocess(self, text, **kwargs):
        # Extract the entity between [START] and [END]
        start_token = "[START]"
        end_token = "[END]"

        if start_token in text and end_token in text:
            start_idx = text.index(start_token) + len(start_token)
            end_idx = text.index(end_token)
            enclosed_entity = text[start_idx:end_idx].strip()
            lOffset = start_idx  # left offset (start of the entity)
            rOffset = end_idx  # right offset (end of the entity)
        else:
            enclosed_entity = None
            lOffset = None
            rOffset = None

        # Generate predictions using the model
        outputs = self.model.generate(
            **self.tokenizer([text], return_tensors="pt").to(self.device),
            num_beams=1,
            num_return_sequences=1,
            max_new_tokens=30,
            return_dict_in_generate=True,
            output_scores=True,
        )
        # Decode the predictions into readable text
        wikipedia_prediction = self.tokenizer.batch_decode(
            outputs.sequences, skip_special_tokens=True
        )[0]
        # Process the scores for each token

        transition_scores = self.model.compute_transition_scores(
            outputs.sequences, outputs.scores, normalize_logits=True
        )
        log_prob_sum = sum(transition_scores[0])

        # Calculate the probability for the entire sequence by exponentiating the sum of log probabilities
        sequence_confidence = torch.exp(log_prob_sum)
        percentage = sequence_confidence.cpu().numpy() * 100.0

        # print(wikipedia_prediction, enclosed_entity, lOffset, rOffset, percentage)

        # Return the predictions along with the extracted entity, lOffset, and rOffset
        return wikipedia_prediction, enclosed_entity, lOffset, rOffset, percentage

    def _forward(self, inputs):
        # Generation already happens in preprocess, so forward is a pass-through.
        return inputs

    def postprocess(self, outputs, **kwargs):
        """
        Postprocess the outputs of the model
        :param outputs:
        :param kwargs:
        :return:
        """

        wikipedia_prediction, enclosed_entity, lOffset, rOffset, percentage = outputs
        qid, language = get_wikipedia_page_props(wikipedia_prediction)
        title, url = get_wikipedia_title(qid, language=language)

        # Cast to a built-in float so the confidence stays JSON-serializable
        percentage = round(float(percentage), 2)

        results = [
            {
                # "id": f"{lOffset}:{rOffset}:{enclosed_entity}:{NEL_MODEL}",
                "surface": enclosed_entity,
                "wkd_id": qid,
                "wkpedia_pagename": title,
                "wkpedia_url": url,
                "type": "UNK",
                "confidence_nel": percentage,
                "lOffset": lOffset,
                "rOffset": rOffset,
            }
        ]
        return results
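

# ---------------------------------------------------------------------------
# Minimal usage sketch (assumption, not part of the original module).
# The task name "generic-nel" and the checkpoint id below are illustrative
# placeholders; substitute the mGENRE-style seq2seq checkpoint this pipeline
# was written for.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoModelForSeq2SeqLM, pipeline
    from transformers.pipelines import PIPELINE_REGISTRY

    # Register the custom task so the pipeline() factory knows how to build NelPipeline.
    PIPELINE_REGISTRY.register_pipeline(
        "generic-nel",
        pipeline_class=NelPipeline,
        pt_model=AutoModelForSeq2SeqLM,
    )

    nel = pipeline(
        "generic-nel",
        model="your-org/your-mgenre-nel-checkpoint",  # hypothetical placeholder id
    )
    # The mention to link must be wrapped in [START] ... [END] markers.
    print(nel("Albert Einstein was born in [START] Ulm [END], Germany."))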