G.Hemanth Sai
commited on
Commit
•
32d9382
1
Parent(s):
e1b10aa
qgen
Browse files- .gitattributes +1 -0
- .gitignore +11 -0
- README.md +83 -11
- app.py +222 -0
- models/s2v_reddit_2015_md.tar.gz +3 -0
- requirements.txt +11 -0
- src/Pipeline/QAhaystack.py +158 -0
- src/Pipeline/QuestGen.py +94 -0
- src/Pipeline/Reader.py +58 -0
- src/Pipeline/TextSummarization.py +50 -0
- src/PreviousVersionCode/QuestionGenerator.py +127 -0
- src/PreviousVersionCode/context.py +379 -0
.gitattributes
CHANGED
@@ -31,3 +31,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
31 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
32 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
33 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
31 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
32 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
33 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*tar.gz filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#editor specific files
|
2 |
+
.vscode
|
3 |
+
.idea
|
4 |
+
|
5 |
+
#cache files
|
6 |
+
__pycache__
|
7 |
+
tempCodeRunnerFile.py
|
8 |
+
|
9 |
+
# models
|
10 |
+
models/s2v_old
|
11 |
+
models/._s2v_old
|
README.md
CHANGED
@@ -1,13 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
---
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
|
13 |
-
|
|
|
|
|
|
1 |
+
# Internship-IVIS-labs
|
2 |
+
|
3 |
+
- The *Intelligent Question Generator* app is an easy-to-use interface built in Streamlit which uses [KeyBERT](https://github.com/MaartenGr/KeyBERT), [Sense2vec](https://github.com/explosion/sense2vec), [T5](https://huggingface.co/ramsrigouthamg/t5_paraphraser)
|
4 |
+
- It uses a minimal keyword extraction technique that leverages multiple NLP embeddings and relies on [Transformers](https://huggingface.co/transformers/) 🤗 to create keywords/keyphrases that are most similar to a document.
|
5 |
+
- [sense2vec](https://github.com/explosion/sense2vec) (Trask et. al, 2015) is a nice twist on word2vec that lets you learn more interesting and detailed word vectors.
|
6 |
+
|
7 |
+
## Repository Breakdown
|
8 |
+
### src Directory
|
9 |
---
|
10 |
+
- `src/Pipeline/QAhaystack.py`: This file contains the code of question answering using [haystack](https://haystack.deepset.ai/overview/intro).
|
11 |
+
- `src/Pipeline/QuestGen.py`: This file contains the code of question generation.
|
12 |
+
- `src/Pipeline/Reader.py`: This file contains the code of reading the document.
|
13 |
+
- `src/Pipeline/TextSummariztion.py`: This file contains the code of text summarization.
|
14 |
+
- `src/PreviousVersionCode/context.py`: This file contains the finding the context of the paragraph.
|
15 |
+
- `src/PreviousVersionCode/QuestionGenerator.py`: This file contains the code of first attempt of question generation.
|
16 |
+
|
17 |
+
## Installation
|
18 |
+
```shell
|
19 |
+
$ git clone https://github.com/HemanthSai7/Internship-IVIS-labs.git
|
20 |
+
```
|
21 |
+
```shell
|
22 |
+
$ cd Internship-IVIS-labs
|
23 |
+
```
|
24 |
+
```python
|
25 |
+
pip install -r requirements.txt
|
26 |
+
```
|
27 |
+
- For the running the app for the first time locally, you need to uncomment the the lines in `src/Pipeline/QuestGen.py` to download the models to the models directory.
|
28 |
+
|
29 |
+
```python
|
30 |
+
streamlit run app.py
|
31 |
+
```
|
32 |
+
- Once the app is running, you can access it at http://localhost:8501
|
33 |
+
```shell
|
34 |
+
You can now view your Streamlit app in your browser.
|
35 |
+
|
36 |
+
Local URL: http://localhost:8501
|
37 |
+
Network URL: http://192.168.0.103:8501
|
38 |
+
```
|
39 |
+
|
40 |
+
## Tech Stack Used
|
41 |
+
![image](https://img.shields.io/badge/Sense2vec-EF546D?style=for-the-badge&logo=Explosion.ai&logoColor=white)
|
42 |
+
![image](https://img.shields.io/badge/Spacy-09A3D5?style=for-the-badge&logo=spaCy&logoColor=white)
|
43 |
+
![image](https://img.shields.io/badge/Haystack-03AF9D?style=for-the-badge&logo=Haystackh&logoColor=white)
|
44 |
+
![image](https://img.shields.io/badge/Python-3776AB?style=for-the-badge&logo=python&logoColor=white)
|
45 |
+
![image](https://img.shields.io/badge/PyTorch-D04139?style=for-the-badge&logo=pytorch&logoColor=white)
|
46 |
+
![image](https://img.shields.io/badge/Numpy-013243?style=for-the-badge&logo=numpy&logoColor=white)
|
47 |
+
![image](https://img.shields.io/badge/Pandas-130654?style=for-the-badge&logo=pandas&logoColor=white)
|
48 |
+
![image](https://img.shields.io/badge/matplotlib-b2feb0?style=for-the-badge&logo=matplotlib&logoColor=white)
|
49 |
+
![image](https://img.shields.io/badge/scikit_learn-F7931E?style=for-the-badge&logo=scikit-learn&logoColor=white)
|
50 |
+
![image](https://img.shields.io/badge/Streamlit-EA6566?style=for-the-badge&logo=streamlit&logoColor=white)
|
51 |
+
|
52 |
+
## Timeline
|
53 |
+
### Week 1-2:
|
54 |
+
#### Tasks
|
55 |
+
- [x] Understanding and brushing up the concepts of NLP.
|
56 |
+
- [x] Extracting images and text from a pdf file and storing it in a texty file.
|
57 |
+
- [x] Exploring various open source tools for generating questions from a given text.
|
58 |
+
- [x] Read papers related to the project (Bert,T5,RoBERTa etc).
|
59 |
+
- [x] Summarizing the extracted text using T5 base pre-trained model from the pdf file.
|
60 |
+
|
61 |
+
### Week 3-4:
|
62 |
+
#### Tasks
|
63 |
+
- [x] Understanding the concept of QA systems.
|
64 |
+
- [x] Created a basic script for generating questions from the text.
|
65 |
+
- [x] Created a basic script for finding the context of the paragraph.
|
66 |
+
|
67 |
+
### Week 5-6:
|
68 |
+
#### Tasks
|
69 |
+
|
70 |
+
- [x] Understanding how Transformers models work for NLP tasks Question answering and generation
|
71 |
+
- [x] Understanding how to use the Haystack library for QA systems.
|
72 |
+
- [x] Understanding how to use the Haystack library for Question generation.
|
73 |
+
- [x] PreProcessed the document for Haystack QA for better results .
|
74 |
+
|
75 |
+
### Week 7-8:
|
76 |
+
#### Tasks
|
77 |
+
- [x] Understanding how to generate questions intelligently.
|
78 |
+
- [x] Explored wordnet to find synonyms
|
79 |
+
- [x] Used BertWSD for disambiguating the sentence provided.
|
80 |
+
- [x] Used KeyBERT for finding the keywords in the document.
|
81 |
+
- [x] Used sense2vec for finding better words with high relatedness for the keywords generated.
|
82 |
|
83 |
+
### Week 9-10:
|
84 |
+
#### Tasks
|
85 |
+
- [x] Create a streamlit app to demonstrate the project.
|
app.py
ADDED
@@ -0,0 +1,222 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import pandas as pd
|
3 |
+
from keybert import KeyBERT
|
4 |
+
|
5 |
+
import seaborn as sns
|
6 |
+
|
7 |
+
from src.Pipeline.TextSummarization import T5_Base
|
8 |
+
from src.Pipeline.QuestGen import sense2vec_get_words,get_question
|
9 |
+
|
10 |
+
|
11 |
+
st.title("❓ Intelligent Question Generator")
|
12 |
+
st.header("")
|
13 |
+
|
14 |
+
|
15 |
+
with st.expander("ℹ️ - About this app", expanded=True):
|
16 |
+
|
17 |
+
st.write(
|
18 |
+
"""
|
19 |
+
- The *Intelligent Question Generator* app is an easy-to-use interface built in Streamlit which uses [KeyBERT](https://github.com/MaartenGr/KeyBERT), [Sense2vec](https://github.com/explosion/sense2vec), [T5](https://huggingface.co/ramsrigouthamg/t5_paraphraser)
|
20 |
+
- It uses a minimal keyword extraction technique that leverages multiple NLP embeddings and relies on [Transformers](https://huggingface.co/transformers/) 🤗 to create keywords/keyphrases that are most similar to a document.
|
21 |
+
- [sense2vec](https://github.com/explosion/sense2vec) (Trask et. al, 2015) is a nice twist on word2vec that lets you learn more interesting and detailed word vectors.
|
22 |
+
"""
|
23 |
+
)
|
24 |
+
|
25 |
+
st.markdown("")
|
26 |
+
|
27 |
+
st.markdown("")
|
28 |
+
st.markdown("## 📌 Paste document ")
|
29 |
+
|
30 |
+
with st.form(key="my_form"):
|
31 |
+
ce, c1, ce, c2, c3 = st.columns([0.07, 2, 0.07, 5, 1])
|
32 |
+
with c1:
|
33 |
+
ModelType = st.radio(
|
34 |
+
"Choose your model",
|
35 |
+
["DistilBERT (Default)", "BERT", "RoBERTa", "ALBERT", "XLNet"],
|
36 |
+
help="At present, you can choose 1 model ie DistilBERT to embed your text. More to come!",
|
37 |
+
)
|
38 |
+
|
39 |
+
if ModelType == "Default (DistilBERT)":
|
40 |
+
# kw_model = KeyBERT(model=roberta)
|
41 |
+
|
42 |
+
@st.cache(allow_output_mutation=True)
|
43 |
+
def load_model(model):
|
44 |
+
return KeyBERT(model=model)
|
45 |
+
|
46 |
+
kw_model = load_model('roberta')
|
47 |
+
|
48 |
+
else:
|
49 |
+
@st.cache(allow_output_mutation=True)
|
50 |
+
def load_model(model):
|
51 |
+
return KeyBERT(model=model)
|
52 |
+
|
53 |
+
kw_model = load_model("distilbert-base-nli-mean-tokens")
|
54 |
+
|
55 |
+
top_N = st.slider(
|
56 |
+
"# of results",
|
57 |
+
min_value=1,
|
58 |
+
max_value=30,
|
59 |
+
value=10,
|
60 |
+
help="You can choose the number of keywords/keyphrases to display. Between 1 and 30, default number is 10.",
|
61 |
+
)
|
62 |
+
min_Ngrams = st.number_input(
|
63 |
+
"Minimum Ngram",
|
64 |
+
min_value=1,
|
65 |
+
max_value=4,
|
66 |
+
help="""The minimum value for the ngram range.
|
67 |
+
*Keyphrase_ngram_range* sets the length of the resulting keywords/keyphrases.To extract keyphrases, simply set *keyphrase_ngram_range* to (1, 2) or higher depending on the number of words you would like in the resulting keyphrases.""",
|
68 |
+
# help="Minimum value for the keyphrase_ngram_range. keyphrase_ngram_range sets the length of the resulting keywords/keyphrases. To extract keyphrases, simply set keyphrase_ngram_range to (1, # 2) or higher depending on the number of words you would like in the resulting keyphrases.",
|
69 |
+
)
|
70 |
+
|
71 |
+
max_Ngrams = st.number_input(
|
72 |
+
"Maximum Ngram",
|
73 |
+
value=1,
|
74 |
+
min_value=1,
|
75 |
+
max_value=4,
|
76 |
+
help="""The maximum value for the keyphrase_ngram_range.
|
77 |
+
*Keyphrase_ngram_range* sets the length of the resulting keywords/keyphrases.
|
78 |
+
To extract keyphrases, simply set *keyphrase_ngram_range* to (1, 2) or higher depending on the number of words you would like in the resulting keyphrases.""",
|
79 |
+
)
|
80 |
+
|
81 |
+
StopWordsCheckbox = st.checkbox(
|
82 |
+
"Remove stop words",
|
83 |
+
value=True,
|
84 |
+
help="Tick this box to remove stop words from the document (currently English only)",
|
85 |
+
)
|
86 |
+
|
87 |
+
use_MMR = st.checkbox(
|
88 |
+
"Use MMR",
|
89 |
+
value=True,
|
90 |
+
help="You can use Maximal Margin Relevance (MMR) to diversify the results. It creates keywords/keyphrases based on cosine similarity. Try high/low 'Diversity' settings below for interesting variations.",
|
91 |
+
)
|
92 |
+
|
93 |
+
Diversity = st.slider(
|
94 |
+
"Keyword diversity (MMR only)",
|
95 |
+
value=0.5,
|
96 |
+
min_value=0.0,
|
97 |
+
max_value=1.0,
|
98 |
+
step=0.1,
|
99 |
+
help="""The higher the setting, the more diverse the keywords.Note that the *Keyword diversity* slider only works if the *MMR* checkbox is ticked.""",
|
100 |
+
)
|
101 |
+
|
102 |
+
with c2:
|
103 |
+
doc = st.text_area(
|
104 |
+
"Paste your text below (max 500 words)",
|
105 |
+
height=510,
|
106 |
+
)
|
107 |
+
|
108 |
+
MAX_WORDS = 500
|
109 |
+
import re
|
110 |
+
res = len(re.findall(r"\w+", doc))
|
111 |
+
if res > MAX_WORDS:
|
112 |
+
st.warning(
|
113 |
+
"⚠️ Your text contains "
|
114 |
+
+ str(res)
|
115 |
+
+ " words."
|
116 |
+
+ " Only the first 500 words will be reviewed. Stay tuned as increased allowance is coming! 😊"
|
117 |
+
)
|
118 |
+
|
119 |
+
doc = doc[:MAX_WORDS]
|
120 |
+
# base=base=T5_Base("t5-base","cpu",2048)
|
121 |
+
# doc=base.summarize(doc)
|
122 |
+
|
123 |
+
submit_button = st.form_submit_button(label="✨ Get me the data!")
|
124 |
+
|
125 |
+
if use_MMR:
|
126 |
+
mmr = True
|
127 |
+
else:
|
128 |
+
mmr = False
|
129 |
+
|
130 |
+
if StopWordsCheckbox:
|
131 |
+
StopWords = "english"
|
132 |
+
else:
|
133 |
+
StopWords = None
|
134 |
+
|
135 |
+
if min_Ngrams > max_Ngrams:
|
136 |
+
st.warning("min_Ngrams can't be greater than max_Ngrams")
|
137 |
+
st.stop()
|
138 |
+
|
139 |
+
# Uses KeyBERT to extract the top keywords from a text
|
140 |
+
# Arguments: text (str)
|
141 |
+
# Returns: list of keywords (list)
|
142 |
+
keywords = kw_model.extract_keywords(
|
143 |
+
doc,
|
144 |
+
keyphrase_ngram_range=(min_Ngrams, max_Ngrams),
|
145 |
+
use_mmr=mmr,
|
146 |
+
stop_words=StopWords,
|
147 |
+
top_n=top_N,
|
148 |
+
diversity=Diversity,
|
149 |
+
)
|
150 |
+
# print(keywords)
|
151 |
+
|
152 |
+
st.markdown("## 🎈 Results ")
|
153 |
+
|
154 |
+
st.header("")
|
155 |
+
|
156 |
+
|
157 |
+
df = (
|
158 |
+
pd.DataFrame(keywords, columns=["Keyword/Keyphrase", "Relevancy"])
|
159 |
+
.sort_values(by="Relevancy", ascending=False)
|
160 |
+
.reset_index(drop=True)
|
161 |
+
)
|
162 |
+
|
163 |
+
df.index += 1
|
164 |
+
|
165 |
+
# Add styling
|
166 |
+
cmGreen = sns.light_palette("green", as_cmap=True)
|
167 |
+
cmRed = sns.light_palette("red", as_cmap=True)
|
168 |
+
df = df.style.background_gradient(
|
169 |
+
cmap=cmGreen,
|
170 |
+
subset=[
|
171 |
+
"Relevancy",
|
172 |
+
],
|
173 |
+
)
|
174 |
+
|
175 |
+
c1, c2, c3 = st.columns([1, 3, 1])
|
176 |
+
|
177 |
+
format_dictionary = {
|
178 |
+
"Relevancy": "{:.2%}",
|
179 |
+
}
|
180 |
+
|
181 |
+
df = df.format(format_dictionary)
|
182 |
+
|
183 |
+
with c2:
|
184 |
+
st.table(df)
|
185 |
+
|
186 |
+
with st.expander("Note about Quantitative Relevancy"):
|
187 |
+
st.markdown(
|
188 |
+
"""
|
189 |
+
- The relevancy score is a quantitative measure of how relevant the keyword/keyphrase is to the document. It is calculated using cosine similarity. The higher the score, the more relevant the keyword/keyphrase is to the document.
|
190 |
+
- So if you see a keyword/keyphrase with a high relevancy score, it means that it is a good keyword/keyphrase to use in question answering, generation ,summarization, and other NLP tasks.
|
191 |
+
"""
|
192 |
+
)
|
193 |
+
|
194 |
+
with st.form(key="ques_form"):
|
195 |
+
ice, ic1, ice, ic2 ,ic3= st.columns([0.07, 2, 0.07, 5,0.07])
|
196 |
+
with ic1:
|
197 |
+
TopN = st.slider(
|
198 |
+
"Top N sense2vec results",
|
199 |
+
value=20,
|
200 |
+
min_value=0,
|
201 |
+
max_value=50,
|
202 |
+
step=1,
|
203 |
+
help="""Get the n most similar terms.""",
|
204 |
+
)
|
205 |
+
|
206 |
+
with ic2:
|
207 |
+
input_keyword = st.text_input("Paste any keyword generated above")
|
208 |
+
keywrd_button = st.form_submit_button(label="✨ Get me the questions!")
|
209 |
+
|
210 |
+
if keywrd_button:
|
211 |
+
st.markdown("## 🎈 Questions ")
|
212 |
+
ext_keywrds=sense2vec_get_words(TopN,input_keyword)
|
213 |
+
if len(ext_keywrds)<1:
|
214 |
+
st.warning("Sorry questions couldn't be generated")
|
215 |
+
|
216 |
+
for answer in ext_keywrds:
|
217 |
+
sentence_for_T5=" ".join(doc.split())
|
218 |
+
ques=get_question(sentence_for_T5,answer)
|
219 |
+
ques=ques.replace("<pad>","").replace("</s>","").replace("<s>","")
|
220 |
+
st.markdown(f'> #### {ques} ')
|
221 |
+
|
222 |
+
|
models/s2v_reddit_2015_md.tar.gz
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:5afb7c665d7833e54b04dfaf181500acca0327b5509e5e1f8ccb3b5986f53713
|
3 |
+
size 600444501
|
requirements.txt
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
streamlit
|
2 |
+
numpy
|
3 |
+
pandas
|
4 |
+
seaborn
|
5 |
+
scikit-learn
|
6 |
+
PyPDF2
|
7 |
+
fitz
|
8 |
+
transformers
|
9 |
+
spacy
|
10 |
+
keybert
|
11 |
+
sense2vec
|
src/Pipeline/QAhaystack.py
ADDED
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
import logging
|
3 |
+
|
4 |
+
from haystack.document_stores import ElasticsearchDocumentStore
|
5 |
+
from haystack.utils import launch_es,print_answers
|
6 |
+
from haystack.nodes import FARMReader,TransformersReader,BM25Retriever
|
7 |
+
from haystack.pipelines import ExtractiveQAPipeline
|
8 |
+
from haystack.nodes import TextConverter,PDFToTextConverter,PreProcessor
|
9 |
+
from haystack.utils import convert_files_to_docs, fetch_archive_from_http
|
10 |
+
from Reader import PdfReader,ExtractedText
|
11 |
+
|
12 |
+
launch_es() # Launches an Elasticsearch instance on your local machine
|
13 |
+
|
14 |
+
# Install the latest release of Haystack in your own environment
|
15 |
+
#! pip install farm-haystack
|
16 |
+
|
17 |
+
"""Install the latest main of Haystack"""
|
18 |
+
# !pip install --upgrade pip
|
19 |
+
# !pip install git+https://github.com/deepset-ai/haystack.git#egg=farm-haystack[colab,ocr]
|
20 |
+
|
21 |
+
# # For Colab/linux based machines
|
22 |
+
# !wget --no-check-certificate https://dl.xpdfreader.com/xpdf-tools-linux-4.04.tar.gz
|
23 |
+
# !tar -xvf xpdf-tools-linux-4.04.tar.gz && sudo cp xpdf-tools-linux-4.04/bin64/pdftotext /usr/local/bin
|
24 |
+
|
25 |
+
# For Macos machines
|
26 |
+
# !wget --no-check-certificate https://dl.xpdfreader.com/xpdf-tools-mac-4.03.tar.gz
|
27 |
+
# !tar -xvf xpdf-tools-mac-4.03.tar.gz && sudo cp xpdf-tools-mac-4.03/bin64/pdftotext /usr/local/bin
|
28 |
+
|
29 |
+
"Run this script from the root of the project"
|
30 |
+
# # In Colab / No Docker environments: Start Elasticsearch from source
|
31 |
+
# ! wget https://artifacts.elastic.co/downloads/elasticsearch/elasticsearch-7.9.2-linux-x86_64.tar.gz -q
|
32 |
+
# ! tar -xzf elasticsearch-7.9.2-linux-x86_64.tar.gz
|
33 |
+
# ! chown -R daemon:daemon elasticsearch-7.9.2
|
34 |
+
|
35 |
+
# import os
|
36 |
+
# from subprocess import Popen, PIPE, STDOUT
|
37 |
+
|
38 |
+
# es_server = Popen(
|
39 |
+
# ["elasticsearch-7.9.2/bin/elasticsearch"], stdout=PIPE, stderr=STDOUT, preexec_fn=lambda: os.setuid(1) # as daemon
|
40 |
+
# )
|
41 |
+
# # wait until ES has started
|
42 |
+
# ! sleep 30
|
43 |
+
|
44 |
+
logging.basicConfig(format="%(levelname)s - %(name)s - %(message)s", level=logging.WARNING)
|
45 |
+
logging.getLogger("haystack").setLevel(logging.INFO)
|
46 |
+
|
47 |
+
class Connection:
|
48 |
+
def __init__(self,host="localhost",username="",password="",index="document"):
|
49 |
+
"""
|
50 |
+
host: Elasticsearch host. If no host is provided, the default host "localhost" is used.
|
51 |
+
|
52 |
+
port: Elasticsearch port. If no port is provided, the default port 9200 is used.
|
53 |
+
|
54 |
+
username: Elasticsearch username. If no username is provided, no username is used.
|
55 |
+
|
56 |
+
password: Elasticsearch password. If no password is provided, no password is used.
|
57 |
+
|
58 |
+
index: Elasticsearch index. If no index is provided, the default index "document" is used.
|
59 |
+
"""
|
60 |
+
self.host=host
|
61 |
+
self.username=username
|
62 |
+
self.password=password
|
63 |
+
self.index=index
|
64 |
+
|
65 |
+
def get_connection(self):
|
66 |
+
document_store=ElasticsearchDocumentStore(host=self.host,username=self.username,password=self.password,index=self.index)
|
67 |
+
return document_store
|
68 |
+
|
69 |
+
class QAHaystack:
|
70 |
+
def __init__(self, filename):
|
71 |
+
self.filename=filename
|
72 |
+
|
73 |
+
def preprocessing(self,data):
|
74 |
+
"""
|
75 |
+
This function is used to preprocess the data. Its a simple function which removes the special characters and converts the data to lower case.
|
76 |
+
"""
|
77 |
+
|
78 |
+
converter = TextConverter(remove_numeric_tables=True, valid_languages=["en"])
|
79 |
+
doc_txt = converter.convert(file_path=ExtractedText(self.filename,'data.txt').save(4,6), meta=None)[0]
|
80 |
+
|
81 |
+
converter = PDFToTextConverter(remove_numeric_tables=True, valid_languages=["en"])
|
82 |
+
doc_pdf = converter.convert(file_path="data/tutorial8/manibook.pdf", meta=None)[0]
|
83 |
+
|
84 |
+
preprocess_text=data.lower() # lowercase
|
85 |
+
preprocess_text = re.sub(r'\s+', ' ', preprocess_text) # remove extra spaces
|
86 |
+
return preprocess_text
|
87 |
+
|
88 |
+
def convert_to_document(self,data):
|
89 |
+
|
90 |
+
"""
|
91 |
+
Write the data to a text file. This is required since the haystack library requires the data to be in a text file so that it can then be converted to a document.
|
92 |
+
"""
|
93 |
+
data=self.preprocessing(data)
|
94 |
+
with open(self.filename,'w') as f:
|
95 |
+
f.write(data)
|
96 |
+
|
97 |
+
"""
|
98 |
+
Read the data from the text file.
|
99 |
+
"""
|
100 |
+
data=self.preprocessing(data)
|
101 |
+
with open(self.filename,'r') as f:
|
102 |
+
data=f.read()
|
103 |
+
data=data.split("\n")
|
104 |
+
|
105 |
+
"""
|
106 |
+
DocumentStores expect Documents in dictionary form, like that below. They are loaded using the DocumentStore.write_documents()
|
107 |
+
|
108 |
+
dicts=[
|
109 |
+
{
|
110 |
+
'content': DOCUMENT_TEXT_HERE,
|
111 |
+
'meta':{'name': DOCUMENT_NAME,...}
|
112 |
+
},...
|
113 |
+
]
|
114 |
+
|
115 |
+
(Optionally: you can also add more key-value-pairs here, that will be indexed as fields in Elasticsearch and can be accessed later for filtering or shown in the responses of the Pipeline)
|
116 |
+
"""
|
117 |
+
data_json=[{
|
118 |
+
'content':paragraph,
|
119 |
+
'meta':{
|
120 |
+
'name':self.filename
|
121 |
+
}
|
122 |
+
} for paragraph in data
|
123 |
+
]
|
124 |
+
|
125 |
+
document_store=Connection().get_connection()
|
126 |
+
document_store.write_documents(data_json)
|
127 |
+
return document_store
|
128 |
+
|
129 |
+
|
130 |
+
class Pipeline:
|
131 |
+
def __init__(self,filename,retriever=BM25Retriever,reader=FARMReader):
|
132 |
+
self.reader=reader
|
133 |
+
self.retriever=retriever
|
134 |
+
self.filename=filename
|
135 |
+
|
136 |
+
def get_prediction(self,data,query):
|
137 |
+
"""
|
138 |
+
Retrievers help narrowing down the scope for the Reader to smaller units of text where a given question could be answered. They use some simple but fast algorithm.
|
139 |
+
|
140 |
+
Here: We use Elasticsearch's default BM25 algorithm . I'll check out the other retrievers as well.
|
141 |
+
"""
|
142 |
+
retriever=self.retriever(document_store=QAHaystack(self.filename).convert_to_document(data))
|
143 |
+
|
144 |
+
"""
|
145 |
+
Readers scan the texts returned by retrievers in detail and extract k best answers. They are based on powerful, but slower deep learning models.Haystack currently supports Readers based on the frameworks FARM and Transformers.
|
146 |
+
"""
|
147 |
+
reader = self.reader(model_name_or_path="deepset/roberta-base-squad2", use_gpu=True)
|
148 |
+
|
149 |
+
"""
|
150 |
+
With a Haystack Pipeline we can stick together your building blocks to a search pipeline. Under the hood, Pipelines are Directed Acyclic Graphs (DAGs) that you can easily customize for our own use cases. To speed things up, Haystack also comes with a few predefined Pipelines. One of them is the ExtractiveQAPipeline that combines a retriever and a reader to answer our questions.
|
151 |
+
"""
|
152 |
+
pipe = ExtractiveQAPipeline(reader, retriever)
|
153 |
+
|
154 |
+
"""
|
155 |
+
This function is used to get the prediction from the pipeline.
|
156 |
+
"""
|
157 |
+
prediction = pipe.run(query=query, params={"Retriever":{"top_k":10}, "Reader":{"top_k":5}})
|
158 |
+
return prediction
|
src/Pipeline/QuestGen.py
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Download important files for the pipeline. Uncomment the following lines if you are running this script for the first time"""
|
2 |
+
# !wget https://github.com/explosion/sense2vec/releases/download/v1.0.0/s2v_reddit_2015_md.tar.gz
|
3 |
+
# !tar -xvf s2v_reddit_2015_md.tar.gz
|
4 |
+
# if tar file is already downloaded don't download it again
|
5 |
+
import os
|
6 |
+
import urllib.request
|
7 |
+
import tarfile
|
8 |
+
if not os.path.exists("models/s2v_reddit_2015_md.tar.gz"):
|
9 |
+
print ("Downloading Sense2Vec model")
|
10 |
+
urllib.request.urlretrieve(r"https://github.com/explosion/sense2vec/releases/download/v1.0.0/s2v_reddit_2015_md.tar.gz",filename=r"models/s2v_reddit_2015_md.tar.gz")
|
11 |
+
else:
|
12 |
+
print ("Sense2Vec model already downloaded")
|
13 |
+
|
14 |
+
reddit_s2v= "models/s2v_reddit_2015_md.tar.gz"
|
15 |
+
extract_s2v="models"
|
16 |
+
extract_s2v_folder=reddit_s2v.replace(".tar.gz","")
|
17 |
+
if not os.path.isdir(extract_s2v_folder):
|
18 |
+
with tarfile.open(reddit_s2v, 'r:gz') as tar:
|
19 |
+
tar.extractall(f"models/")
|
20 |
+
else:
|
21 |
+
print ("Already extracted")
|
22 |
+
|
23 |
+
"""Import required libraries"""
|
24 |
+
|
25 |
+
import warnings
|
26 |
+
warnings.filterwarnings('ignore')
|
27 |
+
|
28 |
+
from transformers import T5ForConditionalGeneration,T5Tokenizer
|
29 |
+
|
30 |
+
import streamlit as st
|
31 |
+
from sense2vec import Sense2Vec
|
32 |
+
|
33 |
+
@st.cache(allow_output_mutation=True)
|
34 |
+
def cache_models(paths2v,pathT5cond,pathT5):
|
35 |
+
s2v = Sense2Vec().from_disk(paths2v)
|
36 |
+
question_model = T5ForConditionalGeneration.from_pretrained(pathT5cond)
|
37 |
+
question_tokenizer = T5Tokenizer.from_pretrained(pathT5)
|
38 |
+
return (s2v,question_model,question_tokenizer)
|
39 |
+
s2v,question_model,question_tokenizer=cache_models("models/s2v_old",'ramsrigouthamg/t5_squad_v1','t5-base')
|
40 |
+
|
41 |
+
|
42 |
+
"""Filter out same sense words using sense2vec algorithm"""
|
43 |
+
|
44 |
+
def filter_same_sense_words(original,wordlist):
|
45 |
+
filtered_words=[]
|
46 |
+
base_sense =original.split('|')[1]
|
47 |
+
for eachword in wordlist:
|
48 |
+
if eachword[0].split('|')[1] == base_sense:
|
49 |
+
filtered_words.append(eachword[0].split('|')[0].replace("_", " ").title().strip())
|
50 |
+
return filtered_words
|
51 |
+
|
52 |
+
def sense2vec_get_words(topn,input_keyword):
|
53 |
+
word=input_keyword
|
54 |
+
output=[]
|
55 |
+
required_keywords=[]
|
56 |
+
output = []
|
57 |
+
try:
|
58 |
+
sense = s2v.get_best_sense(word)
|
59 |
+
most_similar = s2v.most_similar(sense, n=topn)
|
60 |
+
for i in range(len(most_similar)):
|
61 |
+
required_keywords.append(most_similar[i])
|
62 |
+
output = filter_same_sense_words(sense,required_keywords)
|
63 |
+
print (f"Similar:{output}")
|
64 |
+
except:
|
65 |
+
output =[]
|
66 |
+
|
67 |
+
return output
|
68 |
+
|
69 |
+
"""T5 Question generation"""
|
70 |
+
question_model = T5ForConditionalGeneration.from_pretrained('ramsrigouthamg/t5_squad_v1')
|
71 |
+
question_tokenizer = T5Tokenizer.from_pretrained('t5-base')
|
72 |
+
|
73 |
+
def get_question(sentence,answer):
|
74 |
+
text = f"context: {sentence} answer: {answer} </s>"
|
75 |
+
max_len = 256
|
76 |
+
encoding = question_tokenizer.encode_plus(text,max_length=max_len, pad_to_max_length=True, return_tensors="pt")
|
77 |
+
|
78 |
+
input_ids, attention_mask = encoding["input_ids"], encoding["attention_mask"]
|
79 |
+
|
80 |
+
outs = question_model.generate(input_ids=input_ids,
|
81 |
+
attention_mask=attention_mask,
|
82 |
+
early_stopping=True,
|
83 |
+
num_beams=5,
|
84 |
+
num_return_sequences=1,
|
85 |
+
no_repeat_ngram_size=2,
|
86 |
+
max_length=200)
|
87 |
+
|
88 |
+
|
89 |
+
dec = [question_tokenizer.decode(ids) for ids in outs]
|
90 |
+
|
91 |
+
|
92 |
+
Question = dec[0].replace("question:","")
|
93 |
+
Question= Question.strip()
|
94 |
+
return Question
|
src/Pipeline/Reader.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import PyPDF2
|
2 |
+
import fitz
|
3 |
+
|
4 |
+
class PdfReader:
|
5 |
+
def __init__(self, filename):
|
6 |
+
self.filename = filename
|
7 |
+
|
8 |
+
def total_pages(self):
|
9 |
+
with open(self.filename, 'rb') as f:
|
10 |
+
pdf_reader = PyPDF2.PdfFileReader(f)
|
11 |
+
return pdf_reader.numPages
|
12 |
+
|
13 |
+
def read(self):
|
14 |
+
with open(self.filename, 'rb') as f:
|
15 |
+
pdf_reader = PyPDF2.PdfFileReader(f)
|
16 |
+
num_pages = pdf_reader.numPages
|
17 |
+
count = 0
|
18 |
+
text = ''
|
19 |
+
while count < num_pages:
|
20 |
+
text += pdf_reader.getPage(count).extractText()
|
21 |
+
count += 1
|
22 |
+
return text
|
23 |
+
|
24 |
+
def read_pages(self, start_page, end_page):
|
25 |
+
with open(self.filename, 'rb') as f:
|
26 |
+
pdf_reader = PyPDF2.PdfFileReader(f)
|
27 |
+
text = ''
|
28 |
+
for page in range(start_page, end_page):
|
29 |
+
text += pdf_reader.getPage(page).extractText()
|
30 |
+
return text
|
31 |
+
|
32 |
+
def extract_images(self):
|
33 |
+
doc = fitz.open(self.filename)
|
34 |
+
for page_index in range(len(doc)):
|
35 |
+
for img in doc.get_page_images(page_index):
|
36 |
+
xref = img[0]
|
37 |
+
pix = fitz.Pixmap(doc, xref)
|
38 |
+
if pix.n < 5: # GRAY or RGB
|
39 |
+
pix.save(f"{xref}.png")
|
40 |
+
else: # convert to RGB
|
41 |
+
pix1 = fitz.Pixmap(fitz.csRGB, pix)
|
42 |
+
pix1.save(f"{xref}.png")
|
43 |
+
pix1 = None
|
44 |
+
pix = None
|
45 |
+
|
46 |
+
class ExtractedText(PdfReader):
|
47 |
+
def __init__(self, filename, output_filename):
|
48 |
+
super().__init__(filename)
|
49 |
+
self.output_filename = output_filename
|
50 |
+
|
51 |
+
def save(self,start_page, end_page):
|
52 |
+
with open(self.filename,'rb') as f:
|
53 |
+
pdf_reader = PyPDF2.PdfFileReader(f)
|
54 |
+
text = ''
|
55 |
+
for page in range(start_page, end_page):
|
56 |
+
text += pdf_reader.getPage(page).extractText()
|
57 |
+
with open(self.output_filename, 'w',encoding='utf-8') as f:
|
58 |
+
f.write(text)
|
src/Pipeline/TextSummarization.py
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from transformers import T5Tokenizer, T5ForConditionalGeneration
|
3 |
+
import random
|
4 |
+
import numpy as np
|
5 |
+
from nltk.tokenize import sent_tokenize
|
6 |
+
|
7 |
+
class T5_Base:
|
8 |
+
def __init__(self,path,device,model_max_length):
|
9 |
+
self.model=T5ForConditionalGeneration.from_pretrained(path)
|
10 |
+
self.tokenizer=T5Tokenizer.from_pretrained(path,model_max_length=model_max_length)
|
11 |
+
self.device=torch.device(device)
|
12 |
+
|
13 |
+
def set_seed(seed):
|
14 |
+
random.seed(seed)
|
15 |
+
np.random.seed(seed)
|
16 |
+
torch.manual_seed(seed)
|
17 |
+
torch.cuda.manual_seed_all(seed)
|
18 |
+
|
19 |
+
def preprocess(self,data):
|
20 |
+
preprocess_text=data.strip().replace('\n',' ')
|
21 |
+
return preprocess_text
|
22 |
+
|
23 |
+
def post_process(self,data):
|
24 |
+
final=""
|
25 |
+
for sent in sent_tokenize(data):
|
26 |
+
sent=sent.capitalize()
|
27 |
+
final+=sent+" "+sent
|
28 |
+
return final
|
29 |
+
|
30 |
+
def getSummary(self,data):
|
31 |
+
data=self.preprocess(data)
|
32 |
+
t5_prepared_Data="summarize: "+data
|
33 |
+
tokenized_text=self.tokenizer.encode_plus(t5_prepared_Data,max_length=512,pad_to_max_length=False,truncation=True,return_tensors='pt').to(self.device)
|
34 |
+
input_ids,attention_mask=tokenized_text['input_ids'],tokenized_text['attention_mask']
|
35 |
+
summary_ids=self.model.generate(input_ids=input_ids,
|
36 |
+
attention_mask=attention_mask,
|
37 |
+
early_stopping=True,
|
38 |
+
num_beams=3,
|
39 |
+
num_return_sequences=1,
|
40 |
+
no_repeat_ngram_size=2,
|
41 |
+
min_length = 75,
|
42 |
+
max_length=300)
|
43 |
+
|
44 |
+
output=[self.tokenizer.decode(ids,skip_special_tokens=True) for ids in summary_ids]
|
45 |
+
summary=output[0]
|
46 |
+
summary=self.post_process(summary)
|
47 |
+
summary=summary.strip()
|
48 |
+
return summary
|
49 |
+
|
50 |
+
|
src/PreviousVersionCode/QuestionGenerator.py
ADDED
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from TextSummarization import T5_Base
|
2 |
+
|
3 |
+
import spacy
|
4 |
+
import torch
|
5 |
+
from transformers import BertTokenizer, BertModel
|
6 |
+
from transformers import T5ForConditionalGeneration, T5Tokenizer, BertTokenizer, BertModel, AutoTokenizer
|
7 |
+
from sentence_transformers import SentenceTransformer
|
8 |
+
from sklearn.feature_extraction.text import CountVectorizer
|
9 |
+
from sklearn.metrics.pairwise import cosine_similarity
|
10 |
+
|
11 |
+
"""
|
12 |
+
spacy.load() returns a language model object containing all components and data needed to process text. It is usually called nlp. Calling the nlp object on a string of text will return a processed Doc
|
13 |
+
"""
|
14 |
+
nlp = spacy.load("en_core_web_sm") #spacy's trained pipeline model
|
15 |
+
|
16 |
+
from warnings import filterwarnings as filt
|
17 |
+
filt('ignore')
|
18 |
+
|
19 |
+
class QuestionGenerator:
|
20 |
+
def __init__(self,path,device,model_max_length):
|
21 |
+
self.model=T5ForConditionalGeneration.from_pretrained(path)
|
22 |
+
self.tokenizer=AutoTokenizer.from_pretrained(path,model_max_length=model_max_length)
|
23 |
+
self.device=torch.device(device)
|
24 |
+
|
25 |
+
def preprocess(self,data):
|
26 |
+
preprocess_text=data.strip().replace('\n','')
|
27 |
+
return preprocess_text
|
28 |
+
|
29 |
+
def gen_question(self,data,answer):
|
30 |
+
data=self.preprocess(data)
|
31 |
+
t5_prepared_data=f'context: {data} answer: {answer}'
|
32 |
+
encoding=self.tokenizer.encode_plus(t5_prepared_data,max_length=512,pad_to_max_length=True,truncation=True,return_tensors='pt').to(self.device)
|
33 |
+
input_ids,attention_mask=encoding['input_ids'],encoding['attention_mask']
|
34 |
+
output=self.model.generate(input_ids,
|
35 |
+
attention_mask=attention_mask,
|
36 |
+
num_beams=4,
|
37 |
+
num_return_sequences=1,
|
38 |
+
no_repeat_ngram_size=2,
|
39 |
+
min_length=30,
|
40 |
+
max_length=512,
|
41 |
+
early_stopping=True)
|
42 |
+
|
43 |
+
dec=[self.tokenizer.decode(ids,skip_special_tokens=True) for ids in output]
|
44 |
+
Question=dec[0].replace("question:","").strip()
|
45 |
+
return Question
|
46 |
+
class KeywordGenerator:
|
47 |
+
def __init__(self,path,device):
|
48 |
+
self.bert_model=BertModel.from_pretrained(path)
|
49 |
+
self.bert_tokenizer=BertTokenizer.from_pretrained(path)
|
50 |
+
self.sentence_model=SentenceTransformer('distilbert-base-nli-mean-tokens')
|
51 |
+
self.device=torch.device(device)
|
52 |
+
|
53 |
+
def get_embedding(self):
|
54 |
+
"""
|
55 |
+
Token Embedding
|
56 |
+
txt = '[CLS] ' + doc + ' [SEP]' where CLS (used for classification task) is the token for the start of the sentence and SEP is the token for the end of the sentence and doc is the document to be encoded.
|
57 |
+
Ex: Sentence A : Paris is a beautiful city.
|
58 |
+
Sentence B : I love Paris.
|
59 |
+
tokens =[[cls] , Paris, is , a , beautiful , city ,[sep] , I , love , Paris ]
|
60 |
+
Before feeding the tokens to the Bert we convert the tokens into embeddings using an embedding layer called token embedding layer.
|
61 |
+
"""
|
62 |
+
tokens=self.bert_tokenizer.tokenize(txt)
|
63 |
+
token_idx = self.bert_tokenizer.convert_tokens_to_ids(tokens)
|
64 |
+
|
65 |
+
"""
|
66 |
+
Segment Embedding
|
67 |
+
Segment embedding is used to distinguish between the two gives sentences.The segment embedding layer returns only either of the two embedding EA(embedding of Sentence A) or EB(embedding of Sentence B) i.e if the input token belongs to sentence A then EA else EB for sentence B.
|
68 |
+
"""
|
69 |
+
segment_ids=[1]*len(token_idx) #This is the segment_ids for the document. [1]*len(token_idxs) is a list of 1s of length len(token_idxs).
|
70 |
+
|
71 |
+
torch_token = torch.tensor([token_idx])
|
72 |
+
torch_segment = torch.tensor([segment_ids])
|
73 |
+
return self.bert_model(torch_token,torch_segment)[-1].detach().numpy() #
|
74 |
+
|
75 |
+
def get_posTags(self,context):
|
76 |
+
"""This function returns the POS tags of the words in the context. Uses Spacy's POS tagger"""
|
77 |
+
doc=nlp(context)
|
78 |
+
doc_pos=[document.pos_ for document in doc]
|
79 |
+
return doc_pos,context.split()
|
80 |
+
|
81 |
+
def get_sentence(self,context):
|
82 |
+
"""This function returns the sentences in the context. Uses Spacy's sentence tokenizer"""
|
83 |
+
doc=nlp(context)
|
84 |
+
return list(doc.sents)
|
85 |
+
|
86 |
+
def get_vector(self,doc):
|
87 |
+
"""
|
88 |
+
Machines cannot understand characters and words. So when dealing with text data we need to represent it in numbers to be understood by the machine. Countvectorizer is a method to convert text to numerical data.
|
89 |
+
"""
|
90 |
+
stop_words="english" #This is the list of stop words that we want to remove from the text
|
91 |
+
n_gram_range=(1,1) # This is the n-gram range. (1,1)->(unigram,unigram), (1,2)->(unigram,bigram), (1,3)->(unigram,trigram), (2,2)->(bigram,bigram) etc.
|
92 |
+
df=CountVectorizer(stop_words=stop_words,ngram_range=n_gram_range).fit([doc])
|
93 |
+
return df.get_feature_names() #This returns the list of words in the text.
|
94 |
+
|
95 |
+
def get_key_words(self,context,module_type='t'):
|
96 |
+
"""
|
97 |
+
module_type: 't' for token, 's' for sentence, 'v' for vector
|
98 |
+
"""
|
99 |
+
keywords=[]
|
100 |
+
top_n=5
|
101 |
+
for txt in self.get_sentence(context):
|
102 |
+
keyword=self.get_vector(str(txt))
|
103 |
+
print(f'vectors: {keyword}')
|
104 |
+
if module_type=='t':
|
105 |
+
doc_embedding=self.get_embedding(str(txt))
|
106 |
+
keyword_embedding=self.get_embedding(' '.join(keyword))
|
107 |
+
else:
|
108 |
+
doc_embedding=self.sentence_model.encode([str(txt)])
|
109 |
+
keyword_embedding=self.sentence_model.encode(keyword)
|
110 |
+
|
111 |
+
distances=cosine_similarity(doc_embedding,keyword_embedding)
|
112 |
+
print(distances)
|
113 |
+
keywords+=[(keyword[index],str(txt)) for index in distances.argsort()[0][-top_n:]]
|
114 |
+
|
115 |
+
return keywords
|
116 |
+
|
117 |
+
txt = """Enter text"""
|
118 |
+
for ans, context in KeywordGenerator('bert-base-uncased','cpu').get_key_words(txt,'st'):
|
119 |
+
print(QuestionGenerator('ramsrigouthamg/t5_squad_v1','cpu',512).gen_question(context, ans))
|
120 |
+
print()
|
121 |
+
|
122 |
+
|
123 |
+
|
124 |
+
|
125 |
+
|
126 |
+
|
127 |
+
|
src/PreviousVersionCode/context.py
ADDED
@@ -0,0 +1,379 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
"""context
|
3 |
+
|
4 |
+
Automatically generated by Colaboratory.
|
5 |
+
|
6 |
+
Original file is located at
|
7 |
+
https://colab.research.google.com/drive/1qLh1aASQj5HIENPZpHQltTuShZny_567
|
8 |
+
"""
|
9 |
+
|
10 |
+
# !pip install -q transformers
|
11 |
+
|
12 |
+
# Import important libraries
|
13 |
+
# Commented out IPython magic to ensure Python compatibility.
|
14 |
+
import os
|
15 |
+
import json
|
16 |
+
import wanb
|
17 |
+
from pprint import pprint
|
18 |
+
|
19 |
+
import torch
|
20 |
+
from torch.utils.data import Dataset
|
21 |
+
from torch.utils.data import DataLoader
|
22 |
+
from transformers import AdamW
|
23 |
+
from tqdm.notebook import tqdm
|
24 |
+
from transformers import BertForQuestionAnswering,BertTokenizer,BertTokenizerFast
|
25 |
+
|
26 |
+
import numpy as np
|
27 |
+
import matplotlib.pyplot as plt
|
28 |
+
import seaborn as sns
|
29 |
+
import pandas as pd
|
30 |
+
# %matplotlib inline
|
31 |
+
|
32 |
+
#connecting to wandb
|
33 |
+
wandb.login()
|
34 |
+
|
35 |
+
#Sweep Configuration
|
36 |
+
PROJECT_NAME="context"
|
37 |
+
ENTITY=None
|
38 |
+
|
39 |
+
sweep_config={
|
40 |
+
'method':'random'
|
41 |
+
}
|
42 |
+
|
43 |
+
#set metric information --> we want to minimize the loss function.
|
44 |
+
metric = {
|
45 |
+
'name': 'Validation accuracy',
|
46 |
+
'goal': 'maximize'
|
47 |
+
}
|
48 |
+
sweep_config['metric'] = metric
|
49 |
+
|
50 |
+
#set all other hyperparameters
|
51 |
+
parameters_dict = {
|
52 |
+
'epochs':{
|
53 |
+
'values': [1]
|
54 |
+
},
|
55 |
+
'optimizer':{
|
56 |
+
'values': ['sgd','adam']
|
57 |
+
},
|
58 |
+
'momentum':{
|
59 |
+
'distribution': 'uniform',
|
60 |
+
'min': 0.5,
|
61 |
+
'max': 0.99
|
62 |
+
},
|
63 |
+
'batch_size':{
|
64 |
+
'distribution': 'q_log_uniform_values',
|
65 |
+
'q': 8,
|
66 |
+
'min': 16,
|
67 |
+
'max': 256
|
68 |
+
}
|
69 |
+
}
|
70 |
+
sweep_config['parameters'] = parameters_dict
|
71 |
+
|
72 |
+
#print the configuration of the sweep
|
73 |
+
pprint(sweep_config)
|
74 |
+
|
75 |
+
#initialize the sweep
|
76 |
+
sweep_id=wandb.sweep(sweep_config,project=PROJECT_NAME,entity=ENTITY)
|
77 |
+
|
78 |
+
# Mount the Google Drive to save the model
|
79 |
+
from google.colab import drive
|
80 |
+
drive.mount('/content/drive')
|
81 |
+
|
82 |
+
if not os.path.exists('/content/drive/MyDrive/BERT-SQuAD'):
|
83 |
+
os.mkdir('/content/drive/MyDrive/BERT-SQuAD')
|
84 |
+
|
85 |
+
# Download SQuAD 2.0 data
|
86 |
+
# !wget -nc https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v2.0.json
|
87 |
+
# !wget -nc https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json
|
88 |
+
|
89 |
+
"""Load the training dataset and take a look at it"""
|
90 |
+
with open('train-v2.0.json','rb') as f:
|
91 |
+
squad=json.load(f)
|
92 |
+
|
93 |
+
# Each 'data' dict has two keys (title and paragraphs)
|
94 |
+
squad['data'][150]['paragraphs'][0]['context']
|
95 |
+
|
96 |
+
"""Load the dev dataset and take a look at it"""
|
97 |
+
def read_data(path):
|
98 |
+
|
99 |
+
with open(path,'rb') as f:
|
100 |
+
squad=json.load(f)
|
101 |
+
|
102 |
+
contexts=[]
|
103 |
+
questions=[]
|
104 |
+
answers=[]
|
105 |
+
for group in squad['data']:
|
106 |
+
for passage in group['paragraphs']:
|
107 |
+
context=passage['context']
|
108 |
+
for qna in passage['qas']:
|
109 |
+
question=qna['question']
|
110 |
+
for answer in qna['answers']:
|
111 |
+
contexts.append(context)
|
112 |
+
questions.append(question)
|
113 |
+
answers.append(answer)
|
114 |
+
return contexts,questions,answers
|
115 |
+
|
116 |
+
|
117 |
+
#Put the contexts, questions and answers for training and validation into the appropriate lists.
|
118 |
+
"""
|
119 |
+
The answers are dictionaries whith the answer text and an integer which indicates the start index of the answer in the context.
|
120 |
+
"""
|
121 |
+
train_contexts,train_questions,train_answers=read_data('train-v2.0.json')
|
122 |
+
valid_contexts,valid_questions,valid_answers=read_data('dev-v2.0.json')
|
123 |
+
# print(train_contexts[:10])
|
124 |
+
|
125 |
+
# Create a dictionary to map the words to their indices
|
126 |
+
def end_idx(answers,contexts):
|
127 |
+
for answers,context in zip(answers,contexts):
|
128 |
+
gold_text=answers['text']
|
129 |
+
start_idx=answers['answer_start']
|
130 |
+
end_idx=start_idx+len(gold_text)
|
131 |
+
|
132 |
+
# sometimes squad answers are off by a character or two so we fix this
|
133 |
+
if context[start_idx:end_idx] == gold_text:
|
134 |
+
answers['answer_end'] = end_idx
|
135 |
+
elif context[start_idx-1:end_idx-1] == gold_text:
|
136 |
+
answers['answer_start'] = start_idx - 1
|
137 |
+
answers['answer_end'] = end_idx - 1 # When the gold label is off by one character
|
138 |
+
elif context[start_idx-2:end_idx-2] == gold_text:
|
139 |
+
answers['answer_start'] = start_idx - 2
|
140 |
+
answers['answer_end'] = end_idx - 2 # When the gold label is off by two characters
|
141 |
+
|
142 |
+
|
143 |
+
""""Tokenization"""
|
144 |
+
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
|
145 |
+
train_encodings = tokenizer(train_contexts, train_questions, truncation=True, padding=True)
|
146 |
+
valid_encodings = tokenizer(valid_contexts, valid_questions, truncation=True, padding=True)
|
147 |
+
|
148 |
+
# print(train_encodings.keys()) ---> dict_keys(['input_ids', 'token_type_ids', 'attention_mask'])
|
149 |
+
|
150 |
+
# Positional encoding
|
151 |
+
def add_token_positions(encodings,answers):
|
152 |
+
start_positions=[]
|
153 |
+
end_positions=[]
|
154 |
+
for i in range(len(answers)):
|
155 |
+
start_positions.append(encodings.char_to_token(i,answers[i]['answer_start']))
|
156 |
+
end_positions.append(encodings.char_to_token(i,answers[i]['answer_end']))
|
157 |
+
|
158 |
+
# if start position is None, the answer passage has been truncated
|
159 |
+
if start_positions[-1] is None:
|
160 |
+
start_positions[-1] = tokenizer.model_max_length
|
161 |
+
if end_positions[-1] is None:
|
162 |
+
end_positions[-1] = tokenizer.model_max_length
|
163 |
+
|
164 |
+
encodings.update({'start_positions': start_positions, 'end_positions': end_positions})
|
165 |
+
|
166 |
+
|
167 |
+
"""Dataloader for the training dataset"""
|
168 |
+
class DatasetRetriever(Dataset):
|
169 |
+
def __init__(self,encodings):
|
170 |
+
self.encodings=encodings
|
171 |
+
|
172 |
+
def __getitem__(self,idx):
|
173 |
+
return {key:torch.tensor(val[idx]) for key,val in self.encodings.items()}
|
174 |
+
|
175 |
+
def __len__(self):
|
176 |
+
return len(self.encodings.input_ids)
|
177 |
+
|
178 |
+
#Split the dataset into train and validation
|
179 |
+
train_dataset=DatasetRetriever(train_encodings)
|
180 |
+
valid_dataset=DatasetRetriever(valid_encodings)
|
181 |
+
train_loader=DataLoader(train_dataset,batch_size=16,shuffle=True)
|
182 |
+
valid_loader=DataLoader(valid_dataset,batch_size=16)
|
183 |
+
model = BertForQuestionAnswering.from_pretrained("bert-base-uncased")
|
184 |
+
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
185 |
+
|
186 |
+
#Training and testing Loop
|
187 |
+
def pipeline():
|
188 |
+
epochs=1,
|
189 |
+
optimizer = torch.optim.AdamW(model.parameters(),lr=5e-5)
|
190 |
+
|
191 |
+
with wandb.init(config=None):
|
192 |
+
config=wandb.config
|
193 |
+
model.to(device)
|
194 |
+
|
195 |
+
#train the model
|
196 |
+
model.train()
|
197 |
+
for epoch in range(config.epochs):
|
198 |
+
loop = tqdm(train_loader, leave=True)
|
199 |
+
for batch in loop:
|
200 |
+
optimizer.zero_grad()
|
201 |
+
input_ids = batch['input_ids'].to(device)
|
202 |
+
attention_mask = batch['attention_mask'].to(device)
|
203 |
+
start_positions = batch['start_positions'].to(device)
|
204 |
+
end_positions = batch['end_positions'].to(device)
|
205 |
+
outputs = model(input_ids, attention_mask=attention_mask, start_positions=start_positions, end_positions=end_positions)
|
206 |
+
loss = outputs[0]
|
207 |
+
loss.backward()
|
208 |
+
optimizer.step()
|
209 |
+
|
210 |
+
loop.set_description(f'Epoch {epoch+1}')
|
211 |
+
loop.set_postfix(loss=loss.item())
|
212 |
+
wandb.log({'Validation Loss':loss})
|
213 |
+
|
214 |
+
#set the model to evaluation phase
|
215 |
+
model.eval()
|
216 |
+
acc=[]
|
217 |
+
for batch in tqdm(valid_loader):
|
218 |
+
with torch.no_grad():
|
219 |
+
input_ids=batch['input_ids'].to(device)
|
220 |
+
attention_mask=batch['attention_mask'].to(device)
|
221 |
+
start_true=batch['start_positions'].to(device)
|
222 |
+
end_true=batch['end_positions'].to(device)
|
223 |
+
|
224 |
+
outputs=model(input_ids,attention_mask=attention_mask)
|
225 |
+
|
226 |
+
start_pred=torch.argmax(outputs['start_logits'],dim=1)
|
227 |
+
end_pred=torch.argmax(outputs['end_logits'],dim=1)
|
228 |
+
|
229 |
+
acc.append(((start_pred == start_true).sum()/len(start_pred)).item())
|
230 |
+
acc.append(((end_pred == end_true).sum()/len(end_pred)).item())
|
231 |
+
|
232 |
+
acc = sum(acc)/len(acc)
|
233 |
+
|
234 |
+
print("\n\nT/P\tanswer_start\tanswer_end\n")
|
235 |
+
for i in range(len(start_true)):
|
236 |
+
print(f"true\t{start_true[i]}\t{end_true[i]}\n"
|
237 |
+
f"pred\t{start_pred[i]}\t{end_pred[i]}\n")
|
238 |
+
wandb.log({'Validation accuracy': acc})
|
239 |
+
|
240 |
+
#Run the pipeline
|
241 |
+
wandb.agent(sweep_id, pipeline, count = 4)
|
242 |
+
|
243 |
+
|
244 |
+
"""Save the model so we dont have to train it again"""
|
245 |
+
model_path = '/content/drive/MyDrive/BERT-SQuAD'
|
246 |
+
model.save_pretrained(model_path)
|
247 |
+
tokenizer.save_pretrained(model_path)
|
248 |
+
|
249 |
+
"""Load the model"""
|
250 |
+
model_path = '/content/drive/MyDrive/BERT-SQuAD'
|
251 |
+
model = BertForQuestionAnswering.from_pretrained(model_path)
|
252 |
+
tokenizer = BertTokenizerFast.from_pretrained(model_path)
|
253 |
+
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
254 |
+
model = model.to(device)
|
255 |
+
|
256 |
+
|
257 |
+
|
258 |
+
#Get predictions
|
259 |
+
def get_prediction(context,answer):
|
260 |
+
inputs=tokenizer.encode_plus(question,context,return_tensors='pt').to(device)
|
261 |
+
outputs=model(**inputs)
|
262 |
+
answer_start=torch.argmax(outputs[0]) # start position of the answer
|
263 |
+
answer_end=torch.argmax(outputs[1])+1 # end position of the answer
|
264 |
+
answer = tokenizer.convert_tokens_to_string(tokenizer. ## convert the tokens to string
|
265 |
+
convert_ids_to_tokens(inputs['input_ids'][0][answer_start:answer_end]))
|
266 |
+
return answer
|
267 |
+
|
268 |
+
|
269 |
+
"""
|
270 |
+
Question testing
|
271 |
+
|
272 |
+
Official SQuAD evaluation script-->
|
273 |
+
https://colab.research.google.com/github/fastforwardlabs/ff14_blog/blob/master/_notebooks/2020-06-09-Evaluating_BERT_on_SQuAD.ipynb#scrollTo=MzPlHgWEBQ8D
|
274 |
+
"""
|
275 |
+
|
276 |
+
def normalize_text(s):
|
277 |
+
"""Removing articles and punctuation, and standardizing whitespace are all typical text processing steps."""
|
278 |
+
import string, re
|
279 |
+
def remove_articles(text):
|
280 |
+
regex = re.compile(r"\b(a|an|the)\b", re.UNICODE)
|
281 |
+
return re.sub(regex, " ", text)
|
282 |
+
def white_space_fix(text):
|
283 |
+
return " ".join(text.split())
|
284 |
+
def remove_punc(text):
|
285 |
+
exclude = set(string.punctuation)
|
286 |
+
return "".join(ch for ch in text if ch not in exclude)
|
287 |
+
def lower(text):
|
288 |
+
return text.lower()
|
289 |
+
|
290 |
+
return white_space_fix(remove_articles(remove_punc(lower(s))))
|
291 |
+
|
292 |
+
def exact_match(prediction, truth):
|
293 |
+
return bool(normalize_text(prediction) == normalize_text(truth))
|
294 |
+
|
295 |
+
def compute_f1(prediction, truth):
|
296 |
+
pred_tokens = normalize_text(prediction).split()
|
297 |
+
truth_tokens = normalize_text(truth).split()
|
298 |
+
|
299 |
+
# if either the prediction or the truth is no-answer then f1 = 1 if they agree, 0 otherwise
|
300 |
+
if len(pred_tokens) == 0 or len(truth_tokens) == 0:
|
301 |
+
return int(pred_tokens == truth_tokens)
|
302 |
+
|
303 |
+
common_tokens = set(pred_tokens) & set(truth_tokens)
|
304 |
+
|
305 |
+
# if there are no common tokens then f1 = 0
|
306 |
+
if len(common_tokens) == 0:
|
307 |
+
return 0
|
308 |
+
|
309 |
+
prec = len(common_tokens) / len(pred_tokens)
|
310 |
+
rec = len(common_tokens) / len(truth_tokens)
|
311 |
+
|
312 |
+
return round(2 * (prec * rec) / (prec + rec), 2)
|
313 |
+
|
314 |
+
def question_answer(context, question,answer):
|
315 |
+
prediction = get_prediction(context,question)
|
316 |
+
em_score = exact_match(prediction, answer)
|
317 |
+
f1_score = compute_f1(prediction, answer)
|
318 |
+
|
319 |
+
print(f'Question: {question}')
|
320 |
+
print(f'Prediction: {prediction}')
|
321 |
+
print(f'True Answer: {answer}')
|
322 |
+
print(f'Exact match: {em_score}')
|
323 |
+
print(f'F1 score: {f1_score}\n')
|
324 |
+
|
325 |
+
context = """Space exploration is a very exciting field of research. It is the
|
326 |
+
frontier of Physics and no doubt will change the understanding of science.
|
327 |
+
However, it does come at a cost. A normal space shuttle costs about 1.5 billion dollars to make.
|
328 |
+
The annual budget of NASA, which is a premier space exploring organization is about 17 billion.
|
329 |
+
So the question that some people ask is that whether it is worth it."""
|
330 |
+
|
331 |
+
|
332 |
+
questions =["What wil change the understanding of science?",
|
333 |
+
"What is the main idea in the paragraph?"]
|
334 |
+
|
335 |
+
answers = ["Space Exploration",
|
336 |
+
"The cost of space exploration is too high"]
|
337 |
+
|
338 |
+
"""
|
339 |
+
VISUALISATION IN PROGRESS
|
340 |
+
|
341 |
+
for question, answer in zip(questions, answers):
|
342 |
+
question_answer(context, question, answer)
|
343 |
+
|
344 |
+
#Visualize the start scores
|
345 |
+
plt.rcParams["figure.figsize"]=(20,10)
|
346 |
+
ax=sns.barplot(x=token_labels,y=start_scores)
|
347 |
+
ax.set_xticklabels(ax.get_xticklabels(),rotation=90,ha="center")
|
348 |
+
ax.grid(True)
|
349 |
+
plt.title("Start word scores")
|
350 |
+
plt.show()
|
351 |
+
|
352 |
+
#Visualize the end scores
|
353 |
+
plt.rcParams["figure.figsize"]=(20,10)
|
354 |
+
ax=sns.barplot(x=token_labels,y=end_scores)
|
355 |
+
ax.set_xticklabels(ax.get_xticklabels(),rotation=90,ha="center")
|
356 |
+
ax.grid(True)
|
357 |
+
plt.title("End word scores")
|
358 |
+
plt.show()
|
359 |
+
|
360 |
+
#Visualize both the scores
|
361 |
+
scores=[]
|
362 |
+
for (i,token_label) in enumerate(token_labels):
|
363 |
+
# Add the token's start score as one row.
|
364 |
+
scores.append({'token_label':token_label,
|
365 |
+
'score':start_scores[i],
|
366 |
+
'marker':'start'})
|
367 |
+
|
368 |
+
# Add the token's end score as another row.
|
369 |
+
scores.append({'token_label': token_label,
|
370 |
+
'score': end_scores[i],
|
371 |
+
'marker': 'end'})
|
372 |
+
|
373 |
+
df=pd.DataFrame(scores)
|
374 |
+
group_plot=sns.catplot(x="token_label",y="score",hue="marker",data=df,
|
375 |
+
kind="bar",height=6,aspect=4)
|
376 |
+
|
377 |
+
group_plot.set_xticklabels(ax.get_xticklabels(),rotation=90,ha="center")
|
378 |
+
group_plot.ax.grid(True)
|
379 |
+
"""
|