File size: 6,102 Bytes
c1811af
9831428
41cb046
a381bc0
80e614a
77e7345
9831428
 
 
 
064fc00
77e7345
9831428
 
064fc00
d4e5967
3d68848
77e7345
9522bb7
9831428
 
 
 
 
 
77e7345
9831428
 
 
 
 
 
 
 
9522bb7
d4e5967
 
 
9522bb7
41cb046
 
77e7345
 
 
 
 
 
 
d4e5967
 
a381bc0
d4e5967
 
 
 
 
 
77e7345
171f660
77e7345
 
 
80e614a
 
77e7345
80e614a
171f660
80e614a
 
 
22eefa0
80e614a
 
77e7345
80e614a
 
22eefa0
80e614a
 
4970856
 
41cb046
77e7345
d4e5967
 
 
 
 
 
 
 
 
 
41cb046
77e7345
 
 
41cb046
 
d4e5967
41cb046
77e7345
41cb046
288a5de
77eba15
77e7345
 
 
 
 
 
41cb046
288a5de
77e7345
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
288a5de
 
77e7345
 
288a5de
 
77e7345
288a5de
41cb046
77e7345
 
 
 
 
 
 
 
41cb046
288a5de
 
 
 
 
 
 
 
 
77e7345
288a5de
77e7345
41cb046
77e7345
 
288a5de
 
77e7345
288a5de
 
 
d4e5967
 
 
 
 
 
 
 
41cb046
9522bb7
4071dd4
77e7345
4970856
77e7345
41cb046
77e7345
 
e537f35
4071dd4
77e7345
c1811af
4970856
77e7345
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
import os
import gradio as gr
import numpy as np
import wikipediaapi as wk
import wikipedia
import openai
from transformers import (
    TokenClassificationPipeline,
    AutoModelForTokenClassification,
    AutoTokenizer,
    BertForQuestionAnswering,
    BertTokenizer,
)
from transformers.pipelines import AggregationStrategy
import torch
from dotenv import load_dotenv


# =====[ DEFINE PIPELINE ]===== #
class KeyphraseExtractionPipeline(TokenClassificationPipeline):
    """Token-classification pipeline that yields the distinct keyphrases of a text.

    Both the model and the tokenizer are loaded from the single checkpoint
    name passed as ``model``.
    """

    def __init__(self, model, *args, **kwargs):
        """Load the checkpoint's model and tokenizer and initialise the pipeline.

        Args:
            model: Checkpoint identifier handed to ``from_pretrained``.
            *args: Extra positional arguments forwarded to the parent class.
            **kwargs: Extra keyword arguments forwarded to the parent class.
        """
        classifier = AutoModelForTokenClassification.from_pretrained(model)
        checkpoint_tokenizer = AutoTokenizer.from_pretrained(model)
        super().__init__(
            *args,
            model=classifier,
            tokenizer=checkpoint_tokenizer,
            **kwargs,
        )

    def postprocess(self, model_outputs):
        """Aggregate the raw predictions and return the unique phrases.

        Args:
            model_outputs: Raw outputs passed on to the parent ``postprocess``.

        Returns:
            The array produced by ``np.unique`` over the whitespace-stripped
            ``"word"`` field of every aggregated result.
        """
        entities = super().postprocess(
            model_outputs=model_outputs,
            aggregation_strategy=AggregationStrategy.SIMPLE,
        )
        phrases = []
        for entity in entities:
            phrases.append(entity.get("word").strip())
        return np.unique(phrases)

# Run load_dotenv() before looking up OPENAI_API_KEY in the environment.
load_dotenv()
openai.api_key = os.environ.get("OPENAI_API_KEY")

# =====[ LOAD PIPELINE ]===== #
# Keyphrase extractor used to turn a question into a Wikipedia query.
keyPhraseExtractionModel = "ml6team/keyphrase-extraction-kbir-inspec"
extractor = KeyphraseExtractionPipeline(model=keyPhraseExtractionModel)

# The question-answering model and its tokenizer share one checkpoint.
_QA_CHECKPOINT = "bert-large-uncased-whole-word-masking-finetuned-squad"
model = BertForQuestionAnswering.from_pretrained(_QA_CHECKPOINT)
tokenizer = BertTokenizer.from_pretrained(_QA_CHECKPOINT)

def wikipedia_search(input: str) -> str:
    """Look up a Wikipedia summary for the first keyphrase of the input.

    Keyphrases are extracted from the input (newlines replaced by spaces).
    The first keyphrase, or Wikipedia's spelling suggestion for it when one
    exists, is used as the search query. The summary of the first search
    result that exists and contains a full stop is returned.

    Args:
        input (str): The input text.

    Returns:
        str: The page summary, or one of two fixed messages:
        ``"Can you add more details to your question?"`` when no keyphrase
        was extracted, and ``"I cannot answer this question"`` when the
        lookup fails or no result has a usable summary.
    """
    no_answer = "I cannot answer this question"

    keyphrases = extractor(input.replace("\n", " "))
    wiki = wk.Wikipedia("en")

    # `len(...) == 0` rather than truthiness: the extractor returns a NumPy
    # array, whose truth value is ambiguous for more than one element.
    if len(keyphrases) == 0:
        return "Can you add more details to your question?"

    try:
        top_phrase = keyphrases[0]
        suggestion = wikipedia.suggest(top_phrase)
        query = suggestion if suggestion is not None else top_phrase

        for result_title in wikipedia.search(query):
            page = wiki.page(result_title)
            # Check existence first so a missing page's summary is never read.
            if page.exists() and "." in page.summary:
                return page.summary
    except Exception:
        # Deliberate best-effort: any lookup failure becomes the fallback
        # message. `Exception` (not a bare `except`) lets KeyboardInterrupt
        # and SystemExit propagate.
        return no_answer

    return no_answer


def answer_question(question: str) -> str:
    """Answer the question using the context from the Wikipedia search.

    The Wikipedia summary for the question is split into windows that fit
    the 512-token input limit used here. The QA model is run on every
    window, and the span with the highest combined start/end score is
    passed to the OpenAI completion endpoint to be phrased as an answer.

    Args:
        question (str): The input question.

    Returns:
        str: The answer to the question, or the fallback message returned by
        ``wikipedia_search`` when no context could be found.
    """
    no_answer = "I cannot answer this question"

    context = wikipedia_search(question)
    if context in (no_answer, "Can you add more details to your question?"):
        return context

    # Tokenize as "[CLS] question [SEP] context [SEP]" and split at the first [SEP].
    input_ids = tokenizer.encode(question, context)
    sep_token_id = tokenizer.sep_token_id
    first_sep = input_ids.index(sep_token_id)
    question_ids = input_ids[: first_sep + 1]
    context_ids = input_ids[first_sep + 1 :]
    # Drop the closing [SEP]; each window gets its own one back.
    if context_ids and context_ids[-1] == sep_token_id:
        context_ids = context_ids[:-1]

    windows = _split_into_windows(question_ids, context_ids, sep_token_id)
    if not windows:
        # Either the question alone fills the model input or there is no context.
        return no_answer
    print(f"Query has {len(input_ids)} tokens, divided in {len(windows)}.\n")

    candidates = []
    num_question_tokens = len(question_ids)
    for window_ids in windows:
        # Segment 0 marks the question (up to the first [SEP]), segment 1 the context.
        segment_ids = [0] * num_question_tokens + [1] * (
            len(window_ids) - num_question_tokens
        )

        # Inference only: no gradients are needed.
        with torch.no_grad():
            outputs = model(
                torch.tensor([window_ids]),
                token_type_ids=torch.tensor([segment_ids]),
                return_dict=True,
            )

        start_scores = outputs.start_logits
        end_scores = outputs.end_logits
        max_start_score = float(torch.max(start_scores))
        max_end_score = float(torch.max(end_scores))

        print(max_start_score)
        print(max_end_score)

        # Decode from THIS window's tokens: the predicted indices are
        # positions inside the window, not inside the full input.
        tokens = tokenizer.convert_ids_to_tokens(window_ids)
        span_text = _join_wordpieces(
            tokens,
            int(torch.argmax(start_scores)),
            int(torch.argmax(end_scores)),
        )
        candidates.append((max_start_score + max_end_score, span_text))

    # Keep the span with the best combined score across all windows.
    answer = max(candidates, key=lambda candidate: candidate[0])[1]

    # NOTE(review): `openai.Completion` looks like the legacy pre-1.0 SDK
    # interface — confirm the pinned openai version and model availability.
    response = openai.Completion.create(
        model="text-davinci-003",
        prompt=f"Answer the question {question} using this answer: {answer}",
        max_tokens=3000,
    )
    return response.choices[0].text.replace("\n\n", " ")


def _split_into_windows(question_ids, context_ids, sep_token_id, max_length=512):
    """Pair the question with consecutive, non-overlapping context slices.

    Each returned sequence is ``question_ids + context slice + [sep_token_id]``
    and is at most ``max_length`` tokens long. Returns an empty list when the
    question leaves no room for context or when the context is empty.
    """
    # Reserve one position for the closing separator of every window.
    window_size = max_length - len(question_ids) - 1
    if window_size <= 0:
        return []
    return [
        question_ids + context_ids[offset : offset + window_size] + [sep_token_id]
        for offset in range(0, len(context_ids), window_size)
    ]


def _join_wordpieces(tokens, start, end):
    """Join ``tokens[start..end]`` into text, gluing ``##`` continuation pieces.

    When ``end`` is before ``start`` only the start token is returned.
    """
    text = tokens[start]
    for piece in tokens[start + 1 : end + 1]:
        if piece.startswith("##"):
            text += piece[2:]
        else:
            text += " " + piece
    return text

# =====[ DEFINE INTERFACE ]===== #
title = "Azza Knowledge Agent"
examples = [
    ["Where is the Eiffel Tower?"],
    ["What is the population of France?"],
]

# Text-in / text-out interface around answer_question, with flagging set to "never".
demo = gr.Interface(
    fn=answer_question,
    inputs="text",
    outputs="text",
    title=title,
    examples=examples,
    allow_flagging="never",
)

if __name__ == "__main__":
    # Launch only when executed as a script, not when imported.
    demo.launch()