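"""Quiz-bowl guesser: answers questions by running an extractive QA
model over Wikipedia text and decides when to buzz in with an answer."""
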
from collections import Counter
import ssl

import nltk
from nltk.tokenize import sent_tokenize
import torch
from transformers import pipeline
import wikipedia as wiki

from utils import (
    clean_last_sent,
    add_proper_tail,
    get_nnp_query,
    get_wiki_text,
    get_text_chunks,
    filter_answers,
)

# NLTK resource downloads (plus a workaround for SSL certificate errors
# that can break the downloads on some systems)
try:
    _create_unverified_https_context = ssl._create_unverified_context
except AttributeError:
    pass
else:
    ssl._create_default_https_context = _create_unverified_https_context

nltk.download('punkt')
nltk.download('stopwords')
nltk.download('averaged_perceptron_tagger')


class QBigBird:
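    """Guesser for quiz-bowl style questions.

    Builds a Wikipedia query from the question's proper nouns, runs an
    extractive QA pipeline over chunks of the retrieved article text, and
    decides whether the best answer is confident enough to buzz on.
    """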

    def __init__(
            self,
            model='valhalla/electra-base-discriminator-finetuned_squadv1',
            max_context_length=512,
            top_n=5,
            buzz_threshold=0.5
    ):
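        # device=0 selects the first GPU; -1 runs the pipeline on CPU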
        device = 0 if torch.cuda.is_available() else -1
        self.qa = pipeline('question-answering', model=model, device=device)
        self.max_context_length = max_context_length
        self.top_n = top_n
        self.buzz_threshold = buzz_threshold

    def guess_and_buzz(self, question):
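        """Return a (guess, buzz) pair for a question.

        guess is a Wikipedia title (or a dummy string on failure);
        buzz is True when the top answer's score reaches buzz_threshold.
        """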
        # get the last sentence of the question, then clean and improve it
        text = sent_tokenize(question)[-1]
        text = clean_last_sent(text)
        text = add_proper_tail(text)

        # get a Wikipedia query using the proper nouns in the question
        query = get_nnp_query(question)
        query_words = query.split()
        # with fewer than two proper nouns the query is too vague;
        # return a dummy guess without buzzing
        if len(query_words) < 2:
            return 'not enough pns', False

        wikitext = get_wiki_text(query)
        answer_set = set()
        text_chunks = get_text_chunks(wikitext, self.max_context_length)
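        # run the QA model only on chunks that mention a query word,
        # skipping parts of the article unrelated to the question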
        for chunk in text_chunks:
            if any(word in chunk for word in query_words):
                result = self.qa({'question': text, 'context': chunk})
                answer = result['answer']
                score = result['score']
                answer_set.add((answer, score))

        answer_set = filter_answers(answer_set, question)
        if len(answer_set) == 0:
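            # nothing survived filtering: return a blank guess, don't buzz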
            return ' ', False

        answers_scores = list(answer_set)
        top_answers_scores = sorted(
            answers_scores, key=lambda tup: tup[1], reverse=True
        )[:self.top_n]

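        # re-rank: favor answers extracted from multiple chunks, breaking
        # ties by the model's confidence score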
        answer_freq = Counter(answer for answer, score in top_answers_scores)
        freq_top_answers_scores = sorted(
            top_answers_scores,
            key=lambda tup: (answer_freq[tup[0]], tup[1]),
            reverse=True,
        )
        freq_top_answer = freq_top_answers_scores[0][0]
        # snap the answer to an exact Wikipedia title; guard against an
        # empty search result, which would otherwise raise an IndexError
        search_results = wiki.search(freq_top_answer)
        if search_results:
            freq_top_answer = search_results[0]

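        # buzz only when the best answer's score clears the threshold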
        buzz = freq_top_answers_scores[0][1] >= self.buzz_threshold

        return freq_top_answer, buzz