"""Guess answers to questions with an extractive QA model over Wikipedia text.

``QBigBird.guess_and_buzz`` returns a guess plus a boolean "buzz" flag that is
True when the chosen answer's QA score reaches a configurable threshold.
"""

from collections import Counter
import ssl

import nltk
from nltk.tokenize import sent_tokenize
import torch
from transformers import pipeline
import wikipedia as wiki

from utils import (
    clean_last_sent,
    add_proper_tail,
    get_filtered_words,
    get_nnp_query,
    get_nn_query,
    get_wiki_text,
    get_text_chunks,
    filter_answers,
)


# NLTK data packages downloaded at import time.
_NLTK_RESOURCES = ('punkt', 'stopwords', 'averaged_perceptron_tagger')


def _download_nltk_resources():
    """Download the NLTK data this module needs.

    Certificate verification is disabled only while the downloads run (a
    workaround for environments where fetching NLTK data fails certificate
    checks). The original default HTTPS context is restored afterwards so
    later urllib HTTPS traffic in the process is verified again.
    """
    original_context = ssl._create_default_https_context
    # HACK: an unverified context is a security risk; keep the window minimal.
    # Very old Pythons lack _create_unverified_context, hence the getattr guard.
    unverified_context = getattr(ssl, '_create_unverified_context', None)
    if unverified_context is not None:
        ssl._create_default_https_context = unverified_context
    try:
        for resource in _NLTK_RESOURCES:
            nltk.download(resource)
    finally:
        ssl._create_default_https_context = original_context


_download_nltk_resources()


class QBigBird:
    """Answer questions with a Hugging Face QA pipeline run over Wikipedia text.

    The last sentence of a question is used as the QA question; the context is
    Wikipedia text retrieved with a query built from the question's proper nouns.
    """

    def __init__(
        self,
        model='valhalla/electra-base-discriminator-finetuned_squadv1',
        max_context_length=512,
        top_n=5,
        buzz_threshold=0.5,
    ):
        """Build the question-answering pipeline.

        Args:
            model: model identifier passed to ``transformers.pipeline``.
            max_context_length: size limit passed to ``get_text_chunks`` when
                splitting Wikipedia text (units defined by ``utils`` —
                presumably characters or tokens; confirm there).
            top_n: how many highest-scoring candidate answers are considered
                when picking the final answer.
            buzz_threshold: minimum QA score of the chosen answer for the
                returned buzz flag to be True.
        """
        # Use the first CUDA device when present; -1 selects the CPU.
        device = 0 if torch.cuda.is_available() else -1
        self.qa = pipeline('question-answering', model=model, device=device)
        self.max_context_length = max_context_length
        self.top_n = top_n
        self.buzz_threshold = buzz_threshold

    def guess_and_buzz(self, question):
        """Guess the answer to *question* and decide whether to buzz.

        Args:
            question: the question text, possibly several sentences long.

        Returns:
            A ``(guess, buzz)`` tuple:

            * ``('not enough pns', False)`` when the question is empty or its
              proper-noun query has fewer than two words;
            * ``(' ', False)`` when no candidate answer survives filtering;
            * otherwise ``guess`` is the first Wikipedia search title for the
              chosen answer (or the answer itself if the search finds nothing)
              and ``buzz`` is True when that answer's QA score is at least
              ``self.buzz_threshold``.
        """
        sentences = sent_tokenize(question)
        if not sentences:
            # Nothing to ask and no proper nouns to query with.
            return 'not enough pns', False

        # The last sentence is the one sent to the QA model, after clean-up.
        text = add_proper_tail(clean_last_sent(sentences[-1]))

        # Wikipedia query built from the question's proper nouns.
        query = get_nnp_query(question)
        query_words = query.split()
        if len(query_words) < 2:
            return 'not enough pns', False

        candidates = self._candidate_answers(text, query, query_words)
        answer_set = filter_answers(candidates, question)
        if not answer_set:
            return ' ', False

        answer, score = self._choose_answer(answer_set)

        # Map the extracted span to a Wikipedia title; keep the raw span when
        # the search returns nothing instead of raising IndexError.
        titles = wiki.search(answer)
        guess = titles[0] if titles else answer
        return guess, score >= self.buzz_threshold

    def _candidate_answers(self, text, query, query_words):
        """Run the QA model on each Wikipedia chunk that mentions a query word.

        Returns a set of distinct ``(answer, score)`` tuples; the same answer
        text may appear several times with different scores.
        """
        wikitext = get_wiki_text(query)
        candidates = set()
        for chunk in get_text_chunks(wikitext, self.max_context_length):
            # Plain substring test: skip chunks that mention no query word.
            if not any(word in chunk for word in query_words):
                continue
            result = self.qa(question=text, context=chunk)
            candidates.add((result['answer'], result['score']))
        return candidates

    def _choose_answer(self, answer_set):
        """Pick the most frequent answer among the ``top_n`` best-scoring ones.

        Frequency ties are broken by score. Returns the winning
        ``(answer, score)`` pair.
        """
        top = sorted(answer_set, key=lambda pair: pair[1], reverse=True)[:self.top_n]
        freq = Counter(answer for answer, _ in top)
        # max() returns the first maximal element, which is the same element a
        # stable descending sort on this key would put first.
        return max(top, key=lambda pair: (freq[pair[0]], pair[1]))