Spaces:
Runtime error
Runtime error
Upload qbigbird.py
Browse files- qbigbird.py +89 -0
qbigbird.py
ADDED
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from collections import Counter
|
2 |
+
import ssl
|
3 |
+
|
4 |
+
import nltk
|
5 |
+
from nltk.tokenize import sent_tokenize
|
6 |
+
import torch
|
7 |
+
from transformers import pipeline
|
8 |
+
import wikipedia as wiki
|
9 |
+
|
10 |
+
from utils import (
|
11 |
+
clean_last_sent,
|
12 |
+
add_proper_tail,
|
13 |
+
get_filtered_words,
|
14 |
+
get_nnp_query,
|
15 |
+
get_nn_query,
|
16 |
+
get_wiki_text,
|
17 |
+
get_text_chunks,
|
18 |
+
filter_answers,
|
19 |
+
)
|
20 |
+
|
# Necessary NLTK data downloads, plus a workaround for environments where
# the download fails on SSL certificate verification: when CPython exposes
# the private unverified-context hook, install it as the default.
_unverified_ctx = getattr(ssl, '_create_unverified_context', None)
if _unverified_ctx is not None:
    ssl._create_default_https_context = _unverified_ctx

# Sentence tokenizer ('punkt') and POS tagger data used downstream.
for _nltk_pkg in ('punkt', 'averaged_perceptron_tagger'):
    nltk.download(_nltk_pkg)
class QBigBird:
    """Quizbowl-style question answerer with a buzz decision.

    Builds a Wikipedia query from the proper nouns in the question, runs an
    extractive QA pipeline over chunks of the retrieved Wikipedia text, and
    returns the best answer (re-ranked by frequency among the top scores)
    together with a flag saying whether the score clears the buzz threshold.
    """

    def __init__(
        self,
        model='valhalla/electra-base-discriminator-finetuned_squadv1',
        max_context_length=512,
        top_n=5,
        buzz_threshold=0.5
    ):
        """
        Args:
            model: Hugging Face model id for the question-answering pipeline.
            max_context_length: maximum chunk size passed to get_text_chunks.
            top_n: number of top-scoring answers kept for frequency voting.
            buzz_threshold: minimum QA score required to buzz.
        """
        # GPU 0 when available; -1 is the pipeline's code for CPU.
        device = 0 if torch.cuda.is_available() else -1
        self.qa = pipeline('question-answering', model=model, device=device)
        self.max_context_length = max_context_length
        self.top_n = top_n
        self.buzz_threshold = buzz_threshold

    def guess_and_buzz(self, question):
        """Answer *question* and decide whether to buzz.

        Returns:
            (guess, buzz): guessed answer string, and True when the top
            answer's QA score is at least ``buzz_threshold``.
        """
        # Get the last sentence of the question, clean and improve it —
        # that sentence is what we actually ask the QA model.
        text = sent_tokenize(question)[-1]
        text = clean_last_sent(text)
        text = add_proper_tail(text)

        # Build a Wikipedia query from the proper nouns in the question.
        query = get_nnp_query(question)
        query_words = query.split()
        # Not enough proper nouns: return a deliberately wrong guess, no buzz.
        if len(query_words) < 2:
            return 'not enough pns', False

        wikitext = get_wiki_text(query)
        text_chunks = get_text_chunks(wikitext, self.max_context_length)

        # Run extractive QA over every chunk that mentions a query word;
        # collect (answer, score) pairs, deduplicated by the set.
        answer_set = set()
        for chunk in text_chunks:
            if any(word in chunk for word in query_words):
                result = self.qa({'question': text, 'context': chunk})
                answer_set.add((result['answer'], result['score']))

        answer_set = filter_answers(answer_set, question)
        if not answer_set:
            return ' ', False

        # Keep the top_n answers by score, then re-rank so that answers
        # appearing more often among the top scores win ties first.
        top_answers_scores = sorted(
            answer_set, key=lambda tup: tup[1], reverse=True
        )[:self.top_n]
        answer_freq = Counter(answer for answer, _ in top_answers_scores)
        freq_top_answers_scores = sorted(
            top_answers_scores,
            key=lambda tup: (answer_freq[tup[0]], tup[1]),
            reverse=True,
        )
        freq_top_answer = freq_top_answers_scores[0][0]

        # Normalize the guess to an exact Wikipedia title. Guard against an
        # empty search result: the original indexed [0] unconditionally,
        # which raised IndexError when Wikipedia returned no matches.
        search_results = wiki.search(freq_top_answer)
        if search_results:
            freq_top_answer = search_results[0]

        buzz = freq_top_answers_scores[0][1] >= self.buzz_threshold
        return freq_top_answer, buzz