Azza / test.py
Jingxiang Mo
Added deliverable 3
bde6562
raw
history blame
4.74 kB
import os
import gradio as gr
import numpy as np
import wikipediaapi as wk
from transformers import (
TokenClassificationPipeline,
AutoModelForTokenClassification,
AutoTokenizer,
)
import torch
from transformers.pipelines import AggregationStrategy
from transformers import BertForQuestionAnswering
from transformers import BertTokenizer
# =====[ DEFINE PIPELINE ]===== #
class KeyphraseExtractionPipeline(TokenClassificationPipeline):
    """Token-classification pipeline whose output is the set of keyphrases.

    The constructor takes a checkpoint name instead of ready-made objects and
    loads both the model and its tokenizer from it.
    """

    def __init__(self, model, *args, **kwargs):
        """Build the pipeline from the checkpoint identified by *model*.

        Args:
            model: Name or path handed to both ``from_pretrained`` loaders.
            *args: Extra positional arguments forwarded to the parent pipeline.
            **kwargs: Extra keyword arguments forwarded to the parent pipeline.
        """
        token_classifier = AutoModelForTokenClassification.from_pretrained(model)
        matching_tokenizer = AutoTokenizer.from_pretrained(model)
        super().__init__(
            *args,
            model=token_classifier,
            tokenizer=matching_tokenizer,
            **kwargs,
        )

    def postprocess(self, model_outputs):
        """Collapse the parent's aggregated entities into unique keyphrases.

        Returns:
            The stripped ``"word"`` value of every entity, de-duplicated with
            ``np.unique`` (so the result is a sorted NumPy array).
        """
        entities = super().postprocess(
            model_outputs=model_outputs,
            aggregation_strategy=AggregationStrategy.SIMPLE,
        )
        phrases = [entity.get("word").strip() for entity in entities]
        return np.unique(phrases)
# =====[ LOAD PIPELINE ]===== #
# Keyphrase extraction: checkpoint name plus the pipeline built from it.
keyPhraseExtractionModel = "ml6team/keyphrase-extraction-kbir-inspec"
extractor = KeyphraseExtractionPipeline(model=keyPhraseExtractionModel)

# Question answering: model and tokenizer are loaded from the same checkpoint,
# so the name is written once and shared by both loaders.
_qa_checkpoint = "bert-large-uncased-whole-word-masking-finetuned-squad"
model = BertForQuestionAnswering.from_pretrained(_qa_checkpoint)
tokenizer = BertTokenizer.from_pretrained(_qa_checkpoint)
#TODO: add further preprocessing
def keyphrases_extraction(text: str) -> np.ndarray:
    """Return the keyphrases that the module-level ``extractor`` finds in *text*.

    Args:
        text: The text to scan for keyphrases.

    Returns:
        Whatever ``extractor`` returns for *text*. With the pipeline defined
        in this file that is the ``np.unique`` array of keyphrase strings, not
        a single ``str`` as the previous annotation claimed.
    """
    return extractor(text)
def wikipedia_search(input: str) -> str:
    """Return the summary of the first usable Wikipedia page for *input*.

    Newlines in *input* are replaced by spaces, keyphrases are extracted from
    the result, and each keyphrase is looked up as an English Wikipedia page
    title in the order the extractor returned them.

    Args:
        input: The user's question. (The name shadows the builtin but is kept
            because it is part of the function's signature.)

    Returns:
        The summary of the first page that exists and whose summary contains
        a ``'.'``; otherwise the string ``"I cannot answer this question"``,
        which is also returned when no keyphrase was extracted or when a
        lookup raises.
    """
    fallback = "I cannot answer this question"

    input = input.replace("\n", " ")
    keyphrases = keyphrases_extraction(input)
    wiki = wk.Wikipedia('en')

    # TODO: add better extraction and search
    try:
        for keyphrase in keyphrases:
            page = wiki.page(keyphrase)
            # Check existence first so a missing page is rejected before its
            # summary is read. The '.' test presumably filters out empty or
            # stub summaries -- confirm the intent with the author.
            if page.exists() and '.' in page.summary:
                return page.summary
    except Exception:
        # Best-effort lookup: any failure while querying Wikipedia degrades to
        # the fallback answer. Unlike a bare ``except``, this still lets
        # KeyboardInterrupt and SystemExit propagate.
        return fallback

    # Either there were no keyphrases or none of them led to a usable page.
    return fallback
def _join_wordpieces(tokens):
    """Join WordPiece tokens into text, gluing '##' continuations to the previous token."""
    if not tokens:
        return ""
    pieces = [tokens[0]]
    for token in tokens[1:]:
        # A '##' prefix marks a sub-word continuation; anything else starts a new word.
        pieces.append(token[2:] if token[0:2] == '##' else ' ' + token)
    return "".join(pieces)


def answer_question(question):
    """Answer *question* from a Wikipedia summary using the module-level QA model.

    The context is fetched with ``wikipedia_search``; question and context are
    encoded as one text pair, run through ``model``, and the span between the
    highest-scoring start and end tokens is returned.

    Args:
        question: The user's question as plain text.

    Returns:
        ``'Answer: "<text>"'`` for the predicted span, or
        ``"I cannot answer this question"`` when no context could be found or
        the predicted span is inverted (end before start).
    """
    context = wikipedia_search(question)
    if context == "I cannot answer this question":
        return context

    # ======== Tokenize ========
    # Encode question and context together as a text pair.
    max_tokens = 512
    input_ids = tokenizer.encode(question, context)
    if len(input_ids) > max_tokens:
        # Cap the input at 512 tokens, but keep the sequence terminated by
        # [SEP]: cut one token earlier and re-append the separator instead of
        # simply dropping the tail (which also dropped the closing [SEP]).
        input_ids = input_ids[:max_tokens - 1] + [tokenizer.sep_token_id]
    print(f'Query has {len(input_ids):,} tokens.\n')

    # ======== Set Segment IDs ========
    # Segment A runs up to and including the first [SEP] token itself;
    # every remaining token belongs to segment B.
    num_seg_a = input_ids.index(tokenizer.sep_token_id) + 1
    num_seg_b = len(input_ids) - num_seg_a
    segment_ids = [0] * num_seg_a + [1] * num_seg_b

    # ======== Evaluate ========
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model(
            torch.tensor([input_ids]),  # The tokens representing our input text.
            token_type_ids=torch.tensor([segment_ids]),  # Question vs. context markers.
            return_dict=True,
        )

    # ======== Reconstruct Answer ========
    # Positions of the highest `start` and `end` scores, as plain ints so
    # they can be used directly for slicing.
    answer_start = int(torch.argmax(outputs.start_logits))
    answer_end = int(torch.argmax(outputs.end_logits))
    if answer_end < answer_start:
        # An inverted span is not a real answer; say so instead of returning
        # a lone, meaningless token.
        return "I cannot answer this question"

    # Only the tokens inside the answer span need their string form.
    answer_tokens = tokenizer.convert_ids_to_tokens(
        input_ids[answer_start:answer_end + 1]
    )
    return 'Answer: "' + _join_wordpieces(answer_tokens) + '"'
# =====[ DEFINE INTERFACE ]===== #
# Heading shown for the demo and the sample questions offered to the user.
title = "Azza Chatbot"
examples = [
    ["Where is the Eiffel Tower?"],
    ["What is the population of France?"],
]

# Text-in / text-out interface wired to the question-answering function.
demo = gr.Interface(
    fn=answer_question,
    inputs="text",
    outputs="text",
    title=title,
    examples=examples,
)

# Start the app only when this file is run as a script, not when imported.
if __name__ == "__main__":
    demo.launch(share=True)