# quiz-bowl-qa / qbigbird.py
# Last commit by yu3ufff: "Add nltk stopwords download" (c9cdc79)
from collections import Counter
import ssl
import nltk
from nltk.tokenize import sent_tokenize
import torch
from transformers import pipeline
import wikipedia as wiki
from utils import (
clean_last_sent,
add_proper_tail,
get_filtered_words,
get_nnp_query,
get_nn_query,
get_wiki_text,
get_text_chunks,
filter_answers,
)
# Fetch the NLTK resources needed at import time. As a workaround for download
# problems, ssl's default HTTPS context factory is first replaced with the
# unverified one, when this Python build exposes it.
# NOTE(review): the replacement is never undone, so it presumably affects every
# later user of ssl's default HTTPS context in this process -- confirm this is
# acceptable.
if hasattr(ssl, '_create_unverified_context'):
    _create_unverified_https_context = ssl._create_unverified_context
    ssl._create_default_https_context = _create_unverified_https_context

for _nltk_resource in ('punkt', 'stopwords', 'averaged_perceptron_tagger'):
    nltk.download(_nltk_resource)
class QBigBird:
    """Guess the answer to a quiz-bowl question and decide whether to buzz.

    A Wikipedia query is built from the question via ``get_nnp_query``, the
    text fetched for that query is split into chunks, and a Hugging Face
    question-answering pipeline is run on the relevant chunks. The best
    surviving answer is mapped to a Wikipedia title.
    """

    def __init__(
        self,
        model='valhalla/electra-base-discriminator-finetuned_squadv1',
        max_context_length=512,
        top_n=5,
        buzz_threshold=0.5
    ):
        """Load the question-answering pipeline and store the settings.

        Args:
            model: Model identifier handed to ``transformers.pipeline``.
            max_context_length: Limit passed to ``get_text_chunks`` when the
                Wikipedia text is split (unit is defined by that helper).
            top_n: Number of highest-scoring answers kept for ranking.
            buzz_threshold: Minimum score of the chosen answer for the buzz
                flag to be True.
        """
        # The pipeline takes a device index: first CUDA device, or -1 for CPU.
        device = 0 if torch.cuda.is_available() else -1
        self.qa = pipeline('question-answering', model=model, device=device)
        self.max_context_length = max_context_length
        self.top_n = top_n
        self.buzz_threshold = buzz_threshold

    @staticmethod
    def _rank_answers(answers_scores, top_n):
        """Order (answer, score) pairs by frequency among the top scores.

        The ``top_n`` highest-scoring pairs are kept, then sorted so that
        answers occurring most often among them come first, with the score
        breaking ties.
        """
        top_answers_scores = sorted(
            answers_scores, key=lambda tup: tup[1], reverse=True
        )[:top_n]
        answer_freq = Counter(answer for answer, _ in top_answers_scores)
        return sorted(
            top_answers_scores,
            key=lambda tup: (answer_freq[tup[0]], tup[1]),
            reverse=True,
        )

    def guess_and_buzz(self, question):
        """Return a guess for ``question`` and whether to buzz on it.

        Args:
            question: The quiz-bowl question text seen so far.

        Returns:
            A ``(guess, buzz)`` tuple. ``guess`` is the first Wikipedia search
            result for the best answer (or the raw answer if the search
            returns nothing). ``buzz`` is True when that answer's score is at
            least ``buzz_threshold``. ``('not enough pns', False)`` is
            returned when the question is empty or yields fewer than two
            query words, and ``(' ', False)`` when no answer survives
            filtering.
        """
        # An empty question has no sentences, so indexing the tokenizer output
        # below would raise IndexError; it also has no proper nouns.
        if not question or not question.strip():
            return 'not enough pns', False

        # get last sentence of question, clean and improve it
        text = sent_tokenize(question)[-1]
        text = clean_last_sent(text)
        text = add_proper_tail(text)

        # get a Wikipedia query using the proper nouns in the question
        query = get_nnp_query(question)
        query_words = query.split()

        # if not enough proper nouns, return wrong guess with False
        if len(query_words) < 2:
            return 'not enough pns', False

        wikitext = get_wiki_text(query)
        text_chunks = get_text_chunks(wikitext, self.max_context_length)

        # A set collapses identical (answer, score) pairs from different chunks.
        answer_set = set()
        for chunk in text_chunks:
            # Only run the model on chunks that mention a query word.
            if any(word in chunk for word in query_words):
                result = self.qa({'question': text, 'context': chunk})
                answer_set.add((result['answer'], result['score']))

        answer_set = filter_answers(answer_set, question)
        if not answer_set:
            return ' ', False

        ranked = self._rank_answers(list(answer_set), self.top_n)
        best_answer, best_score = ranked[0]

        # get the exact Wikipedia title; keep the raw answer when the search
        # comes back empty instead of failing on an empty result list
        titles = wiki.search(best_answer)
        if titles:
            best_answer = titles[0]

        buzz = best_score >= self.buzz_threshold
        return best_answer, buzz