Pennywise881 committed on
Commit
9f23e0b
1 Parent(s): 0b52499

uploaded code files

Files changed (4)
  1. Article.py +46 -0
  2. QueryProcessor.py +98 -0
  3. QuestionAnswer.py +129 -0
  4. app.py +84 -0
Article.py ADDED
@@ -0,0 +1,46 @@
import wikipediaapi


class Article:

    def __init__(self, article_name):
        self.article_data = {}
        self.article = wikipediaapi.Wikipedia('en').page(article_name)

    def article_exists(self):
        # exists() hits the network and can raise; treat any failure as "not found"
        try:
            return self.article.exists()
        except Exception:
            return False

    def get_sections_and_texts(self, sections):
        # store the summary once, then walk the section tree recursively
        if 'Summary' not in self.article_data:
            self.article_data['Summary'] = ''
            if self.article.summary:
                self.article_data['Summary'] = self.article.summary.lower().split('\n')

        for section in sections:
            if section.text:
                self.article_data[section.title] = section.text.lower().split('\n')
            if len(section.sections) > 0:
                self.get_sections_and_texts(section.sections)

    def remove_empty_sections(self):
        # rebuild each list rather than removing items while iterating over it
        for section, docs in self.article_data.items():
            self.article_data[section] = [d for d in docs if len(d) > 0]

    def get_article_data(self):
        self.get_sections_and_texts(self.article.sections)
        self.remove_empty_sections()

        # corpus statistics consumed by the BM25 scorer: document count and average length
        num_docs = sum(len(docs) for _, docs in self.article_data.items())
        avg_doc_len = sum(len(doc.split()) for _, docs in self.article_data.items() for doc in docs) / num_docs

        return {
            'article_data': self.article_data,
            'num_docs': num_docs,
            'avg_doc_len': avg_doc_len
        }
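
A minimal standalone sketch of how this class is meant to be driven (not part of the commit; the article title below is only an illustration). It shows the dictionary returned by get_article_data, whose num_docs and avg_doc_len fields feed the BM25 scorer in QueryProcessor.py.

    from Article import Article

    article = Article(article_name='Alan Turing')  # illustrative title; any Wikipedia page works
    if article.article_exists():
        data = article.get_article_data()
        print(data['num_docs'], data['avg_doc_len'])  # corpus statistics used by BM25
        print(list(data['article_data'].keys()))      # section titles, starting with 'Summary'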
QueryProcessor.py ADDED
@@ -0,0 +1,98 @@
import numpy as np

from nltk.corpus import stopwords
from nltk.tokenize import RegexpTokenizer


class QueryProcessor:

    def __init__(self, question, section_texts, N, avg_doc_len):
        self.section_texts = section_texts
        self.N = N
        self.avg_doc_len = avg_doc_len

        self.query_items = self.set_query(question)
        self.section_document_idx = None

    def set_query(self, question):
        # lower-case the question, drop punctuation and English stop words
        punct_regex = RegexpTokenizer(r'\w+')
        return [q for q in punct_regex.tokenize(question.lower()) if q not in stopwords.words('english')]

    def get_query(self):
        return self.query_items

    def bm25(self, word, paragraph, k=1.2, b=0.75):
        # frequency of the word in this document (paragraph)
        freq = paragraph.split().count(word)

        # length-normalised term frequency
        tf = (freq * (k + 1)) / (freq + k * (1 - b + b * len(paragraph.split()) / self.avg_doc_len))

        # number of documents that contain the word
        N_q = sum(1 for _, docs in self.section_texts.items() for doc in docs if word in doc.split())

        # inverse document frequency
        idf = np.log(((self.N - N_q + 0.5) / (N_q + 0.5)) + 1)

        return round(tf * idf, 4)

    def get_bm25_scores(self):
        # score every document against every query term, keeping only non-zero scores
        bm25_scores = {}

        for query in self.query_items:
            bm25_scores[query] = {}
            for section, docs in self.section_texts.items():
                bm25_scores[query][section] = {}
                for doc_index in range(len(docs)):
                    score = self.bm25(query, docs[doc_index])
                    if score > 0.0:
                        bm25_scores[query][section][doc_index] = score

                if len(bm25_scores[query][section]) <= 0:
                    del bm25_scores[query][section]

        return bm25_scores

    def filter_bad_documents(self, bm25_scores):
        # keep only documents whose BM25 score clears the fixed 0.5 threshold
        section_document_idx = {}

        for sec_docs in bm25_scores.values():
            for sec, doc_scores in sec_docs.items():
                if sec not in section_document_idx:
                    section_document_idx[sec] = []
                for doc_idx, score in doc_scores.items():
                    if score > 0.5 and doc_idx not in section_document_idx[sec]:
                        section_document_idx[sec].append(doc_idx)

                if len(section_document_idx[sec]) <= 0:
                    del section_document_idx[sec]

        return section_document_idx

    def get_context(self):
        bm25_scores = self.get_bm25_scores()
        self.section_document_idx = self.filter_bad_documents(bm25_scores)

        # concatenate all retained documents into a single context string for the reader model
        context = ' '.join([self.section_texts[section][d_id]
                            for section, doc_ids in self.section_document_idx.items()
                            for d_id in doc_ids])

        return context

    def match_section_with_answer_text(self, text):
        # return the section titles whose retained documents contain the answer text
        sections = []
        for sec, doc_ids in self.section_document_idx.items():
            for d_id in doc_ids:
                if self.section_texts[sec][d_id].find(text) > -1:
                    sections.append(sec)

        return sections
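
A toy retrieval sketch (not part of the commit; the section texts and question are invented). It illustrates what the class computes: set_query strips punctuation and NLTK English stop words, bm25 scores each document with tf * idf where tf = f(k+1) / (f + k(1 - b + b * |d| / avgdl)) and idf = ln((N - n + 0.5) / (n + 0.5) + 1) using k = 1.2 and b = 0.75, and get_context keeps only documents scoring above 0.5.

    import nltk
    nltk.download('stopwords', quiet=True)  # set_query relies on the NLTK stop-word list

    from QueryProcessor import QueryProcessor

    # two lower-cased "documents" in one section, mirroring Article.get_article_data() output
    section_texts = {
        'Summary': ['alan turing was a mathematician and computer scientist',
                    'he worked at bletchley park during the war'],
    }
    num_docs = 2
    avg_doc_len = sum(len(d.split()) for docs in section_texts.values() for d in docs) / num_docs

    qp = QueryProcessor('Where did Turing work?', section_texts, num_docs, avg_doc_len)
    print(qp.get_query())    # ['turing', 'work'] after stop-word removal
    print(qp.get_context())  # only documents whose BM25 score exceeds the 0.5 threshold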
QuestionAnswer.py ADDED
@@ -0,0 +1,129 @@
import torch
import numpy as np


class QuestionAnswer:

    def __init__(self, data, model, tokenizer, torch_device):
        self.max_length = 384
        self.doc_stride = 128

        self.tokenizer = tokenizer
        self.model = model
        self.data = data
        self.torch_device = torch_device

        self.output = None
        self.features = None
        self.results = None

    def get_output_from_model(self):
        # tokenize the question/context pair (long contexts overflow into strided chunks)
        # and run the model without tracking gradients
        with torch.no_grad():
            tokenized_data = self.tokenizer(
                self.data['question'],
                self.data['context'],
                truncation='only_second',
                max_length=self.max_length,
                stride=self.doc_stride,
                return_overflowing_tokens=True,
                return_offsets_mapping=True,
                padding='max_length',
                return_tensors='pt'
            ).to(self.torch_device)

            output = self.model(tokenized_data['input_ids'], tokenized_data['attention_mask'])

        return output

    def prepare_features(self, example):
        # re-tokenize without tensors to keep the character offset mappings for post-processing
        tokenized_example = self.tokenizer(
            example['question'],
            example['context'],
            truncation='only_second',
            max_length=self.max_length,
            stride=self.doc_stride,
            return_overflowing_tokens=True,
            return_offsets_mapping=True,
            padding='max_length',
        )

        for i in range(len(tokenized_example['input_ids'])):
            sequence_ids = tokenized_example.sequence_ids(i)
            context_index = 1

            # keep offsets only for context tokens; question and special tokens get None
            tokenized_example["offset_mapping"][i] = [
                (o if sequence_ids[k] == context_index else None)
                for k, o in enumerate(tokenized_example["offset_mapping"][i])
            ]

        return tokenized_example

    def postprocess_qa_predictions(self, data, features, raw_predictions, top_n_answers=5, max_answer_length=30):
        all_start_logits, all_end_logits = raw_predictions.start_logits, raw_predictions.end_logits

        results = []
        context = data['context']

        for i in range(len(features['input_ids'])):
            start_logits = all_start_logits[i].cpu().numpy()
            end_logits = all_end_logits[i].cpu().numpy()

            offset_mapping = features['offset_mapping'][i]

            # indices of the top_n_answers highest start and end logits
            start_indices = np.argsort(start_logits)[-1: -top_n_answers - 1: -1].tolist()
            end_indices = np.argsort(end_logits)[-1: -top_n_answers - 1: -1].tolist()

            for start_index in start_indices:
                for end_index in end_indices:
                    # skip spans outside the context, reversed spans, and spans that are too long
                    if (
                        start_index >= len(offset_mapping)
                        or end_index >= len(offset_mapping)
                        or offset_mapping[start_index] is None
                        or offset_mapping[end_index] is None
                        or end_index < start_index
                        or end_index - start_index + 1 > max_answer_length
                    ):
                        continue

                    start_char = offset_mapping[start_index][0]
                    end_char = offset_mapping[end_index][1]

                    score = start_logits[start_index] + end_logits[end_index]
                    results.append(
                        {
                            'score': float('%.*g' % (3, score)),
                            'text': context[start_char: end_char]
                        }
                    )

        results = sorted(results, key=lambda x: x["score"], reverse=True)[:top_n_answers]
        return results

    def get_results(self):
        self.output = self.get_output_from_model()
        self.features = self.prepare_features(self.data)
        self.results = self.postprocess_qa_predictions(self.data, self.features, self.output)

        return self.results
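
A standalone sketch of the reader stage (not part of the commit; the question and context strings are invented, and the checkpoint name is the one app.py loads). get_results tokenizes the pair, runs the model, and returns up to top_n_answers candidate spans sorted by the summed start and end logits.

    from transformers import AutoTokenizer, AutoModelForQuestionAnswering
    from QuestionAnswer import QuestionAnswer

    checkpoint = 'Pennywise881/distilbert-base-uncased-finetuned-squad-v2'  # same model as app.py
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForQuestionAnswering.from_pretrained(checkpoint)

    data = {
        'question': 'Who proposed the imitation game?',
        'context': 'the imitation game, later known as the turing test, was proposed by alan turing in 1950.',
    }
    qa = QuestionAnswer(data, model, tokenizer, 'cpu')
    for result in qa.get_results():  # highest-scoring spans first
        print(result['score'], result['text'])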
app.py ADDED
@@ -0,0 +1,84 @@
import streamlit as st
import wikipediaapi
from Article import Article
from QueryProcessor import QueryProcessor
from QuestionAnswer import QuestionAnswer

from transformers import AutoTokenizer, AutoModelForQuestionAnswering

model = AutoModelForQuestionAnswering.from_pretrained('Pennywise881/distilbert-base-uncased-finetuned-squad-v2')
tokenizer = AutoTokenizer.from_pretrained('Pennywise881/distilbert-base-uncased-finetuned-squad-v2')

st.write("""
# Wiki Q & A
""")

placeholder = st.empty()
wiki_wiki = wikipediaapi.Wikipedia('en')

if "found_article" not in st.session_state:
    st.session_state.page = 0
    st.session_state.found_article = False
    st.session_state.article = ''
    st.session_state.conversation = []
    st.session_state.article_data = {}


def get_article():
    article_name = placeholder.text_input('Enter the name of a Wikipedia article', '')

    if article_name:
        page = wiki_wiki.page(article_name)
        if page.exists():
            st.session_state.found_article = True
            st.session_state.article = article_name

            article = Article(article_name=article_name)
            st.session_state.article_data = article.get_article_data()

            ask_questions()
        else:
            st.write(f'Sorry, could not find Wikipedia article: {article_name}')


def ask_questions():
    question = placeholder.text_input(f"Ask questions about {st.session_state.article}", '')
    st.header("Questions and Answers:")

    if question:
        # retrieve the most relevant passages from the article with BM25
        query_processor = QueryProcessor(
            question=question,
            section_texts=st.session_state.article_data['article_data'],
            N=st.session_state.article_data['num_docs'],
            avg_doc_len=st.session_state.article_data['avg_doc_len']
        )

        context = query_processor.get_context()

        data = {
            'question': question,
            'context': context
        }

        # extract answer spans from the retrieved context
        qa = QuestionAnswer(data, model, tokenizer, 'cpu')
        results = qa.get_results()

        # join the top answers into a single comma-separated string
        answer = ''
        for r in results:
            answer += r['text'] + ", "

        answer = answer[:len(answer) - 2]
        st.session_state.conversation.append({'question': question, 'answer': answer})
        st.session_state.conversation.reverse()

    if len(st.session_state.conversation) > 0:
        for data in st.session_state.conversation:
            st.text("Question: " + data['question'] + "\n" + "Answer: " + data['answer'])


if st.session_state.found_article == False:
    get_article()
else:
    ask_questions()
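
To try the app locally, the standard Streamlit entry point (streamlit run app.py) should work once streamlit, wikipedia-api, nltk (with its stopwords corpus downloaded), transformers, and torch are installed; this commit does not include a requirements file, so the exact dependency versions are an assumption.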