yu3ufff commited on
Commit
162792e
1 Parent(s): 6fa87dd

Upload qbigbird.py

Browse files
Files changed (1) hide show
  1. qbigbird.py +89 -0
qbigbird.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import Counter
2
+ import ssl
3
+
4
+ import nltk
5
+ from nltk.tokenize import sent_tokenize
6
+ import torch
7
+ from transformers import pipeline
8
+ import wikipedia as wiki
9
+
10
+ from utils import (
11
+ clean_last_sent,
12
+ add_proper_tail,
13
+ get_filtered_words,
14
+ get_nnp_query,
15
+ get_nn_query,
16
+ get_wiki_text,
17
+ get_text_chunks,
18
+ filter_answers,
19
+ )
20
+
21
# necessary downloads (+ workaround for some download problems)
# NLTK's downloader can fail on hosts whose certificate store rejects the
# download server; when Python exposes its private unverified-context hook,
# install it as the default HTTPS context so the downloads go through.
try:
    _create_unverified_https_context = ssl._create_unverified_context
except AttributeError:
    # this Python build has no such hook — nothing to patch
    pass
else:
    ssl._create_default_https_context = _create_unverified_https_context

# fetch the sentence tokenizer and POS-tagger data NLTK needs at runtime
for _nltk_pkg in ('punkt', 'averaged_perceptron_tagger'):
    nltk.download(_nltk_pkg)
31
+
32
+
33
class QBigBird:
    """Quiz-bowl question answerer.

    Builds a Wikipedia search query from the proper nouns in the question,
    splits the retrieved article text into chunks, runs an extractive
    question-answering pipeline over the chunks that mention the query
    words, and picks the answer that appears most often among the
    top-scoring candidates.
    """

    def __init__(
        self,
        model='valhalla/electra-base-discriminator-finetuned_squadv1',
        max_context_length=512,
        top_n=5,
        buzz_threshold=0.5
    ):
        """Create the QA pipeline and store tuning parameters.

        Args:
            model: Hugging Face model id for the 'question-answering'
                pipeline.
            max_context_length: chunk size passed to get_text_chunks
                (presumably a token/character budget — defined in utils).
            top_n: how many highest-scoring candidate answers to keep
                for frequency voting.
            buzz_threshold: minimum top score needed to buzz (i.e. commit
                to the answer).
        """
        # HF pipeline convention: device 0 = first GPU, -1 = CPU
        device = 0 if torch.cuda.is_available() else -1
        self.qa = pipeline('question-answering', model=model, device=device)
        self.max_context_length = max_context_length
        self.top_n = top_n
        self.buzz_threshold = buzz_threshold

    def guess_and_buzz(self, question):
        """Answer a (possibly partial) quiz-bowl question.

        Args:
            question: the question text seen so far.

        Returns:
            A (guess, buzz) tuple. guess is the predicted Wikipedia title
            (or a placeholder string when no usable query/answer exists);
            buzz is True when the best candidate's score reaches
            buzz_threshold.
        """
        # get last sentence of question, clean and improve it
        text = sent_tokenize(question)[-1]
        text = clean_last_sent(text)
        text = add_proper_tail(text)

        # get a Wikipedia query using the proper nouns in the question
        query = get_nnp_query(question)
        query_words = query.split()
        # if not enough proper nouns, return wrong guess with False
        if len(query_words) < 2:
            return 'not enough pns', False

        wikitext = get_wiki_text(query)
        answer_set = set()
        text_chunks = get_text_chunks(wikitext, self.max_context_length)
        for chunk in text_chunks:
            # only run the (expensive) QA pipeline on chunks that mention
            # at least one query word
            if any(word in chunk for word in query_words):
                result = self.qa({'question': text, 'context': chunk})
                answer_set.add((result['answer'], result['score']))

        answer_set = filter_answers(answer_set, question)
        if not answer_set:
            return ' ', False

        # keep the top_n candidates by pipeline score
        top_answers_scores = sorted(
            answer_set, key=lambda tup: tup[1], reverse=True
        )[:self.top_n]

        # re-rank by (frequency among top candidates, score) so an answer
        # extracted from several chunks beats a one-off high scorer
        answer_freq = Counter(answer for answer, _ in top_answers_scores)
        freq_top_answers_scores = sorted(
            top_answers_scores,
            key=lambda tup: (answer_freq[tup[0]], tup[1]),
            reverse=True,
        )
        freq_top_answer = freq_top_answers_scores[0][0]

        # normalize to the exact Wikipedia title; wiki.search may return an
        # empty list, so fall back to the raw answer instead of crashing
        # with an IndexError
        search_results = wiki.search(freq_top_answer)
        if search_results:
            freq_top_answer = search_results[0]

        buzz = freq_top_answers_scores[0][1] >= self.buzz_threshold

        return freq_top_answer, buzz