amrithhun commited on
Commit
b0dbd35
1 Parent(s): cb0b4d5

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +289 -0
app.py ADDED
@@ -0,0 +1,289 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import PyPDF2
3
+ import random
4
+ import itertools
5
+ import streamlit as st
6
+ from io import StringIO
7
+ from langchain.vectorstores import FAISS
8
+ from langchain.chains import RetrievalQA
9
+ from langchain.chat_models import ChatOpenAI
10
+ from langchain.retrievers import SVMRetriever
11
+ from langchain.chains import QAGenerationChain
12
+ from langchain.embeddings.openai import OpenAIEmbeddings
13
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
14
+ from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
15
+ from langchain.callbacks.base import CallbackManager
16
+ from langchain.embeddings import HuggingFaceEmbeddings
17
+
18
+
19
+ st.set_page_config(page_title="PDF Analyzer",page_icon=':shark:')
20
+
21
@st.cache_data
def load_docs(files):
    """Read uploaded documents and return their concatenated text.

    Args:
        files: iterable of Streamlit ``UploadedFile`` objects; ``.pdf`` and
            ``.txt`` files are supported, anything else triggers a UI warning
            and is skipped.

    Returns:
        str: the extracted text of all supported files joined together
        (empty string if nothing could be read).
    """
    st.info("`Reading doc ...`")
    all_text = ""
    for file_path in files:
        # Lower-case so ".PDF"/".TXT" uploads are handled too.
        file_extension = os.path.splitext(file_path.name)[1].lower()
        if file_extension == ".pdf":
            pdf_reader = PyPDF2.PdfReader(file_path)
            text = ""
            for page in pdf_reader.pages:
                # extract_text() can return None for pages with no
                # extractable text; guard against TypeError on +=.
                text += page.extract_text() or ""
            all_text += text
        elif file_extension == ".txt":
            stringio = StringIO(file_path.getvalue().decode("utf-8"))
            text = stringio.read()
            all_text += text
        else:
            st.warning('Please provide txt or pdf.', icon="⚠️")
    return all_text
40
+
41
+
42
+
43
+
44
@st.cache_resource
def create_retriever(_embeddings, splits, retriever_type):
    """Build a retriever over the text chunks.

    Args:
        _embeddings: embeddings object (leading underscore so Streamlit's
            cache does not try to hash it).
        splits: list of text chunks to index.
        retriever_type: "SIMILARITY SEARCH" (FAISS) or
            "SUPPORT VECTOR MACHINES" (SVMRetriever).

    Returns:
        A retriever instance, or None when the vectorstore could not be
        built or the retriever type is unrecognized (an error is shown
        in the UI in both cases).
    """
    if retriever_type == "SIMILARITY SEARCH":
        try:
            vectorstore = FAISS.from_texts(splits, _embeddings)
        except (IndexError, ValueError) as e:
            st.error(f"Error creating vectorstore: {e}")
            return None
        retriever = vectorstore.as_retriever(k=5)
    elif retriever_type == "SUPPORT VECTOR MACHINES":
        retriever = SVMRetriever.from_texts(splits, _embeddings)
    else:
        # Previously an unknown type fell through to `return retriever`
        # and raised UnboundLocalError; fail with a clear message instead.
        st.error(f"Unknown retriever type: {retriever_type}")
        return None

    return retriever
57
+
58
@st.cache_resource
def split_texts(text, chunk_size, overlap, split_method):
    """Split a document into overlapping character chunks.

    Args:
        text: full document text.
        chunk_size: maximum characters per chunk.
        overlap: characters shared between consecutive chunks.
        split_method: kept for interface compatibility with callers, but
            intentionally ignored — RecursiveCharacterTextSplitter is the
            only splitter supported.

    Returns:
        list[str]: non-empty list of chunks; stops the Streamlit app via
        st.stop() if splitting produced nothing.
    """
    st.info("`Splitting doc ...`")

    # The original body dead-assigned `split_method` here, silently
    # shadowing the argument; the parameter is unused by design.
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=overlap)

    splits = text_splitter.split_text(text)
    if not splits:
        st.error("Failed to split document")
        st.stop()

    return splits
77
+
78
@st.cache_data
def generate_eval(text, N, chunk):
    """Generate N sample question-answer pairs from random slices of *text*.

    Args:
        text: document text to draw questions from.
        N: number of QA pairs to attempt.
        chunk: size (chars) of each random slice a question is drawn from.

    Returns:
        list: flattened list of QA dicts (each with 'question'/'answer'
        keys as produced by QAGenerationChain); failed generations are
        skipped with a UI warning.
    """
    st.info("`Generating sample questions ...`")
    n = len(text)
    # Clamp so randint's upper bound is never negative when the document
    # is shorter than `chunk` (the original raised ValueError there).
    chunk = min(chunk, n)
    starting_indices = [random.randint(0, n - chunk) for _ in range(N)]
    sub_sequences = [text[i:i + chunk] for i in starting_indices]
    chain = QAGenerationChain.from_llm(ChatOpenAI(temperature=0))
    eval_set = []
    for i, b in enumerate(sub_sequences):
        try:
            qa = chain.run(b)
            eval_set.append(qa)
            st.write("Creating Question:", i + 1)
        except Exception:
            # Narrowed from a bare `except:` so KeyboardInterrupt and
            # SystemExit are no longer swallowed.
            st.warning('Error generating question %s.' % str(i + 1), icon="⚠️")
    eval_set_full = list(itertools.chain.from_iterable(eval_set))
    return eval_set_full
100
+
101
+
102
+ # ...
103
+
104
def main():
    """Streamlit entry point: render the page chrome, collect the OpenAI
    API key, ingest uploaded documents, generate sample questions, and
    answer user questions against the indexed content."""

    # Footer pinned to the bottom of the viewport (author credit).
    foot = f"""
<div style="
position: fixed;
bottom: 0;
left: 30%;
right: 0;
width: 50%;
padding: 0px 0px;
text-align: center;
">
<p>Made by <a href='https://twitter.com/mehmet_ba7'>Mehmet Balioglu</a></p>
</div>
"""

    st.markdown(foot, unsafe_allow_html=True)

    # Add custom CSS
    # NOTE(review): the .css-* selectors target Streamlit's generated class
    # names, which change between Streamlit versions — verify on upgrade.
    st.markdown(
        """
<style>

#MainMenu {visibility: hidden;
# }
footer {visibility: hidden;
}
.css-card {
border-radius: 0px;
padding: 30px 10px 10px 10px;
background-color: #f8f9fa;
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
margin-bottom: 10px;
font-family: "IBM Plex Sans", sans-serif;
}

.card-tag {
border-radius: 0px;
padding: 1px 5px 1px 5px;
margin-bottom: 10px;
position: absolute;
left: 0px;
top: 0px;
font-size: 0.6rem;
font-family: "IBM Plex Sans", sans-serif;
color: white;
background-color: green;
}

.css-zt5igj {left:0;
}

span.css-10trblm {margin-left:0;
}

div.css-1kyxreq {margin-top: -40px;
}






</style>
""",
        unsafe_allow_html=True,
    )
    # Sidebar logo; assumes img/logo1.png exists relative to the app's
    # working directory — TODO confirm the asset ships with the app.
    st.sidebar.image("img/logo1.png")

    # Page title with a small "beta" badge.
    st.write(
        f"""
<div style="display: flex; align-items: center; margin-left: 0;">
<h1 style="display: inline-block;">PDF Analyzer</h1>
<sup style="margin-left:5px;font-size:small; color: green;">beta</sup>
</div>
""",
        unsafe_allow_html=True,
    )

    st.sidebar.title("Menu")

    # Embedding backend choice: OpenAI (remote API) or HuggingFace (local).
    embedding_option = st.sidebar.radio(
        "Choose Embeddings", ["OpenAI Embeddings", "HuggingFace Embeddings(slower)"])

    retriever_type = st.sidebar.selectbox(
        "Choose Retriever", ["SIMILARITY SEARCH", "SUPPORT VECTOR MACHINES"])

    # Use RecursiveCharacterTextSplitter as the default and only text splitter
    splitter_type = "RecursiveCharacterTextSplitter"

    # Ask for the OpenAI key once and cache it in session state; the key is
    # exported through the environment so the langchain clients pick it up.
    if 'openai_api_key' not in st.session_state:
        openai_api_key = st.text_input(
            'Please enter your OpenAI API key or [get one here](https://platform.openai.com/account/api-keys)', value="", placeholder="Enter the OpenAI API key which begins with sk-")
        if openai_api_key:
            st.session_state.openai_api_key = openai_api_key
            os.environ["OPENAI_API_KEY"] = openai_api_key
        else:
            #warning_text = 'Please enter your OpenAI API key. Get yours from here: [link](https://platform.openai.com/account/api-keys)'
            #warning_html = f'<span>{warning_text}</span>'
            #st.markdown(warning_html, unsafe_allow_html=True)
            # No key entered yet: stop rendering the rest of the app.
            return
    else:
        os.environ["OPENAI_API_KEY"] = st.session_state.openai_api_key

    uploaded_files = st.file_uploader("Upload a PDF or TXT Document", type=[
        "pdf", "txt"], accept_multiple_files=True)

    if uploaded_files:
        # Check if last_uploaded_files is not in session_state or if uploaded_files are different from last_uploaded_files
        if 'last_uploaded_files' not in st.session_state or st.session_state.last_uploaded_files != uploaded_files:
            st.session_state.last_uploaded_files = uploaded_files
            # A new upload invalidates previously generated sample questions.
            if 'eval_set' in st.session_state:
                del st.session_state['eval_set']

        # Load and process the uploaded PDF or TXT files.
        loaded_text = load_docs(uploaded_files)
        st.write("Documents uploaded and processed.")

        # Split the document into chunks
        splits = split_texts(loaded_text, chunk_size=1000,
                             overlap=0, split_method=splitter_type)

        # Display the number of text chunks
        num_chunks = len(splits)
        st.write(f"Number of text chunks: {num_chunks}")

        # Embed using OpenAI embeddings or HuggingFace embeddings
        if embedding_option == "OpenAI Embeddings":
            embeddings = OpenAIEmbeddings()
        elif embedding_option == "HuggingFace Embeddings(slower)":
            # Replace "bert-base-uncased" with the desired HuggingFace model
            embeddings = HuggingFaceEmbeddings()

        retriever = create_retriever(embeddings, splits, retriever_type)

        # Initialize the RetrievalQA chain with streaming output
        # NOTE(review): StreamingStdOutCallbackHandler streams tokens to the
        # server's stdout, not to the Streamlit UI.
        callback_handler = StreamingStdOutCallbackHandler()
        callback_manager = CallbackManager([callback_handler])

        chat_openai = ChatOpenAI(
            streaming=True, callback_manager=callback_manager, verbose=True, temperature=0)
        qa = RetrievalQA.from_chain_type(llm=chat_openai, retriever=retriever, chain_type="stuff", verbose=True)

        # Check if there are no generated question-answer pairs in the session state
        if 'eval_set' not in st.session_state:
            # Use the generate_eval function to generate question-answer pairs
            num_eval_questions = 10  # Number of question-answer pairs to generate
            st.session_state.eval_set = generate_eval(
                loaded_text, num_eval_questions, 3000)

        # Display the question-answer pairs in the sidebar with smaller text
        for i, qa_pair in enumerate(st.session_state.eval_set):
            st.sidebar.markdown(
                f"""
<div class="css-card">
<span class="card-tag">Question {i + 1}</span>
<p style="font-size: 12px;">{qa_pair['question']}</p>
<p style="font-size: 12px;">{qa_pair['answer']}</p>
</div>
""",
                unsafe_allow_html=True,
            )
            # <h4 style="font-size: 14px;">Question {i + 1}:</h4>
            # <h4 style="font-size: 14px;">Answer {i + 1}:</h4>
        st.write("Ready to answer questions.")

        # Question and answering
        user_question = st.text_input("Enter your question:")
        if user_question:
            answer = qa.run(user_question)
            st.write("Answer:", answer)
286
+
287
+
288
# Script entry point: run the Streamlit app.
if __name__ == "__main__":
    main()