ANASDAVOODTK committed on
Commit 734a77e
1 Parent(s): 8978600

Update app.py

Files changed (1)
  1. app.py +529 -187
app.py CHANGED
@@ -1,189 +1,531 @@
- import numpy as np
  import os
- import cv2
- from PIL import Image
- from io import BytesIO
- import streamlit as st
  import openai
- import PyPDF2
- import base64
- import pypdfium2 as pdfium
- import docx
- from docx import Document
- import fitz
- import pytesseract
-
- COMPLETIONS_MODEL = "gpt-4"
- openai.api_key = os.environ['openapi']
- COMPLETIONS_API_PARAMS = {
-     "temperature": 0.0,
-     "max_tokens": 1000,
-     "model": COMPLETIONS_MODEL,
- }
-
- @st.cache(allow_output_mutation=True)
- def run_on_chunks(data):
-     response = []
-     chunk = data_chunk(data , chunk_size = 2500)
-     num = 0
-     text = st.empty()
-
-     for i in chunk:
-         num = num + 1
-         text.write(f"{num}th API request sent out of {len(chunk)}")
-         response.append(GPT_4_API(i))
-     text.empty()
-
-     return response
-
- def data_chunk(lst , chunk_size):
-     return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]
-
- def check_file_format(filename):
-     return filename.rsplit('.', 1)[1].lower()
-
- def pdf_to_images(pdf_file):
-     images = []
-     with fitz.open(pdf_file) as doc:
-         for page in doc:
-             pix = page.get_pixmap(alpha=False)
-             img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
-             images.append(img)
-     return images
-
- def OCR(pdf_file):
-     pdf_reader = PyPDF2.PdfReader(pdf_file)
-     pdf_writer = PyPDF2.PdfWriter()
-     for page_num in range(len(pdf_reader.pages)):
-         page = pdf_reader.pages[page_num]
-         page.scale_by(2)
-         pdf_writer.add_page(page)
-
-     with open('enlarged.pdf', 'wb') as f:
-         pdf_writer.write(f)
-
-     images = pdf_to_images('enlarged.pdf')
-     text = ''
-     for image in images:
-         size = (image.width * 2, image.height * 2)
-         image = image.resize(size, Image.ANTIALIAS)
-         text += pytesseract.image_to_string(image)
-
-     pdf_file.close()
-     return text
-
- def txt_extraction(file_path):
-     file_contents = file_path.read().decode("utf-8")
-     return file_contents
-
- def docx_extraction(path):
-     doc = docx.Document(path)
-     full_text = []
-     for para in doc.paragraphs:
-         full_text.append(para.text)
-     return '\n'.join(full_text)
-
-
- def download_docx(text):
-     document = Document()
-     document.add_paragraph(text)
-     output = BytesIO()
-     document.save(output)
-     output.seek(0)
-     st.download_button(
-         label="Download as .docx",
-         data=output,
-         file_name="document.docx",
-         mime="application/vnd.openxmlformats-officedocument.wordprocessingml.document"
-     )
-
- def GPT_4_API(data):
-     header = """ create 12 question and answeres from given paragraph it is imporant to not use numbers to point out questions and answers, Answers should strictly be exact lines from this paragraph"."\n\nContext:\n"""
-     QA = header + "".join(str(list(data)))
-     response = openai.ChatCompletion.create(messages = [{"role": "user", "content": f"{QA}"},],**COMPLETIONS_API_PARAMS)
-     return response["choices"][0]["message"]["content"]
-
- def my_text_editor(text , default_text, key, height=800):
-     string = ""
-     for i in default_text:
-         string = string + i
-     textarea = text.text_area(key, height=height, value=string)
-     return textarea , text
-
- def get_base64_of_bin_file(bin_file):
-     with open(bin_file, 'rb') as f:
-         data = f.read()
-     return base64.b64encode(data).decode()
-
- def set_png_as_page_bg(png_file):
-
-     bin_str = get_base64_of_bin_file(png_file)
-     page_bg_img = '''
-     <style>
-     .stApp {
-     background-image: url("data:image/png;base64,%s");
-     background-size: cover;
-     }
-     </style>
-     ''' % bin_str
-     st.markdown(page_bg_img, unsafe_allow_html=True)
-     return
-
- def Extract_pdf_content(pdf_name):
-
-     page_text = ""
-     pdf_reader = PyPDF2.PdfReader(pdf_name)
-     num_pages = len(pdf_reader.pages)
-
-     for page in range(num_pages):
-         pdf_page = pdf_reader.pages[page]
-         page_text = page_text + pdf_page.extract_text()
-
-     return page_text
-
- def process(uploaded_file):
-
-     data = Extract_pdf_content(uploaded_file)
-     return data
-
- if __name__=="__main__":
-
-     pytesseract.pytesseract.tesseract_cmd = r'C:\\Program Files\\Tesseract-OCR\\tesseract.exe'
-     PAGE_CONFIG = {"page_title":"StColab.io","page_icon":":smiley:","layout":"centered"}
-     st.set_page_config(**PAGE_CONFIG)
-     main_bg = 'bkgnd1.jpg'
-     set_png_as_page_bg(main_bg)
-
-     st.title("Advanced Text processing Tool")
-     uploaded_file = st.file_uploader("Upload a Files here", type = ["pdf","docx","txt"])
-
-     if uploaded_file is not None:
-
-         if check_file_format(uploaded_file.name) == "pdf":
-             data = process(uploaded_file)
-
-             text = st.empty()
-             if data == '':
-                 text.write("applying OCR")
-                 data = OCR(uploaded_file)
-                 text.empty()
-
-         elif check_file_format(uploaded_file.name) == "docx":
-             data = docx_extraction(uploaded_file)
-
-         else:
-             data = txt_extraction(uploaded_file)
-
-
-         if st.button("re-generate set of questions and answers"):
-             text = st.empty()
-             st.caching.clear_cache()
-             response = run_on_chunks(data)
-             textdata , text = my_text_editor(text ,response,"text-editor-1", height=650)
-             download_docx(textdata)
-
-         else:
-             text = st.empty()
-             response = run_on_chunks(data)
-             textdata , text = my_text_editor(text ,response,"text-editor-1", height=650)
-             download_docx(textdata)
+ # Importing the libraries
  import os
+ import math
+ import requests
+ import bs4
+ from dotenv import load_dotenv
+ import nltk
+ import numpy as np
  import openai
+ import streamlit as st
+ from streamlit_chat import message as show_message
+ import textract
+ import tiktoken
+ import uuid
+ import validators
+
+
+ # Helper variables
+ load_dotenv()
+ openai.api_key = os.environ['openapi']  # Load OpenAI API key from .env file
+
+ llm_model = "gpt-3.5-turbo"  # https://platform.openai.com/docs/guides/chat/introduction
+ llm_context_window = (
+     4097  # https://platform.openai.com/docs/guides/chat/managing-tokens
+ )
+ embed_context_window, embed_model = (
+     8191,
+     "text-embedding-ada-002",
+ )  # https://platform.openai.com/docs/guides/embeddings/second-generation-models
+ nltk.download(
+     "punkt"
+ )  # Download the nltk punkt tokenizer for splitting text into sentences
+ tokenizer = tiktoken.get_encoding(
+     "cl100k_base"
+ )  # Load the cl100k_base tokenizer which is designed to work with the ada-002 model (engine)
+
+ download_chunk_size = 128  # TODO: Find optimal chunk size for downloading files
+ split_chunk_tokens = 300  # TODO: Find optimal chunk size for splitting text
+ num_citations = 5  # TODO: Find optimal number of citations to give context to the LLM
+
+ # Streamlit settings
+ user_avatar_style = "fun-emoji"  # https://www.dicebear.com/styles
+ assistant_avatar_style = "bottts-neutral"
+
+
+ # Helper functions
+ def get_num_tokens(text):  # Count the number of tokens in a string
+     return len(
+         tokenizer.encode(text, disallowed_special=())
+     )  # disallowed_special=() removes the special tokens
+
+
+ # TODO:
+ # Currently, any sentence that is longer than the max number of tokens will be its own chunk
+ # This is not ideal, since this doesn't ensure that the chunks are of a maximum size
+ # Find a way to split the sentence into chunks of a maximum size
+ def split_into_many(text):  # Split text into chunks of a maximum number of tokens
+     sentences = nltk.tokenize.sent_tokenize(text)  # Split the text into sentences
+     total_tokens = [
+         get_num_tokens(sentence) for sentence in sentences
+     ]  # Get the number of tokens for each sentence
+
+     chunks = []
+     tokens_so_far = 0
+     chunk = []
+     for sentence, num_tokens in zip(sentences, total_tokens):
+         if not tokens_so_far:  # If this is the first sentence in the chunk
+             if (
+                 num_tokens > split_chunk_tokens
+             ):  # If the sentence is longer than the max number of tokens, add it as its own chunk
+                 chunk.append(sentence)
+                 chunks.append(" ".join(chunk))
+                 chunk = []
+         else:  # If this is not the first sentence in the chunk
+             if (
+                 tokens_so_far + num_tokens > split_chunk_tokens
+             ):  # If the sentence would make the chunk longer than the max number of tokens, add the chunk to the list of chunks
+                 chunks.append(" ".join(chunk))
+                 chunk = []
+                 tokens_so_far = 0
+
+         # Otherwise, add the sentence to the chunk and add the number of tokens to the total
+         chunk.append(sentence)
+         tokens_so_far += num_tokens + 1
+
+     # In case the file is smaller than the max number of tokens, add the last chunk
+     if not chunks:
+         chunks.append(" ".join(chunk))
+     return chunks
+
+
+ def embed(prompt):  # Embed the prompt
+     embeds = []
+     if type(prompt) == str:
+         if (
+             get_num_tokens(prompt) > embed_context_window
+         ):  # If token_length of prompt > context_window
+             prompt = split_into_many(prompt)  # Split prompt into multiple chunks
+         else:  # If token_length of prompt <= context_window
+             embeds = openai.Embedding.create(input=prompt, model=embed_model)[
+                 "data"
+             ]  # Embed prompt
+     if not embeds:  # If the prompt was split into/is a set of chunks
+         max_num_chunks = (
+             embed_context_window // split_chunk_tokens
+         )  # Number of chunks that can fit in the context window
+         for i in range(
+             0, math.ceil(len(prompt) / max_num_chunks)
+         ):  # For each batch of chunks
+             embeds.extend(
+                 openai.Embedding.create(
+                     input=prompt[i * max_num_chunks : (i + 1) * max_num_chunks],
+                     model=embed_model,
+                 )["data"]
+             )  # Embed the batch of chunks
+     return embeds  # Return the list of embeddings
+
+
+ def embed_file(filename):  # Create embeddings for a file
+     source_type = "file"  # To help distinguish between local/URL files and URLs
+     file_source = ""  # Source of the file
+     file_chunks = []  # List of file chunks (from the file)
+     file_vectors = []  # List of lists of file embeddings (from each chunk)
+
+     try:
+         extracted_text = (
+             textract.process(filename)
+             .decode("utf-8")  # Extracted text is in bytes, convert to string
+             .encode("ascii", "ignore")  # Remove non-ascii characters
+             .decode()  # Convert back to string
+         )
+         if not extracted_text:  # If the file is empty
+             raise Exception
+         os.remove(
+             filename
+         )  # Remove the file from the server since it is no longer needed
+         file_source = filename
+         file_chunks = split_into_many(extracted_text)  # Split the text into chunks
+         file_vectors = [x["embedding"] for x in embed(file_chunks)]  # Embed the chunks
+     except Exception:  # If the file cannot be extracted, return empty values
+         if os.path.exists(filename):  # If the file still exists
+             os.remove(
+                 filename
+             )  # Remove the file from the server since it is no longer needed
+         source_type = ""
+         file_source = ""
+         file_chunks = []
+         file_vectors = []
+
+     return source_type, file_source, file_chunks, file_vectors
+
+
+ def embed_url(url):  # Create embeddings for a URL
+     source_type = "url"  # To help distinguish between local/URL files and URLs
+     url_source = ""  # Source of the url
+     url_chunks = []  # List of url chunks (for the url)
+     url_vectors = []  # List of lists of url embeddings (for each chunk)
+     filename = ""  # Filename of the url if it is a file
+
+     try:
+         if validators.url(url, public=True):  # Verify the url is valid and public
+             response = requests.get(url)  # Get the url info
+             header = response.headers["Content-Type"]  # Get the header of the url
+             is_application = (
+                 header.split("/")[0] == "application"
+             )  # Check if the url is a file
+
+             if is_application:  # If url is a file, call embed_file on the file
+                 filetype = header.split("/")[1]  # Get the filetype
+                 url_parts = url.split("/")  # Get the parts of the url
+                 filename = str(
+                     "./"
+                     + " ".join(
+                         url_parts[:-1] + [url_parts[-1].split(".")[0]]
+                     )  # Replace / with whitespace in the filename to avoid issues with the file path and remove the file extension since it may not match the actual filetype
+                     + "."
+                     + filetype
+                 )  # Create the filename
+                 with requests.get(
+                     url, stream=True
+                 ) as stream_response:  # Download the file
+                     stream_response.raise_for_status()
+                     with open(filename, "wb") as file:
+                         for chunk in stream_response.iter_content(
+                             chunk_size=download_chunk_size
+                         ):
+                             file.write(chunk)
+                 return embed_file(filename)  # Embed the file
+             else:  # If url is a webpage, use BeautifulSoup to extract the text
+                 soup = bs4.BeautifulSoup(response.text)  # Create a BeautifulSoup object
+                 extracted_text = (
+                     soup.get_text()  # Extract the text from the webpage
+                     .encode("ascii", "ignore")  # Remove non-ascii characters
+                     .decode()  # Convert back to string
+                 )
+                 if not extracted_text:  # If the webpage is empty
+                     raise Exception
+                 url_source = url
+                 url_chunks = split_into_many(
+                     extracted_text
+                 )  # Split the text into chunks
+                 url_vectors = [
+                     x["embedding"] for x in embed(url_chunks[-1])
+                 ]  # Embed the chunks
+         else:  # If url is not valid or public
+             raise Exception
+     except Exception:  # If the url cannot be extracted, return empty values
+         source_type = ""
+         url_source = ""
+         url_chunks = []
+         url_vectors = []
+
+     return source_type, url_source, url_chunks, url_vectors
+
+
+ def get_most_relevant(
+     prompt_embedding, sources_embeddings
+ ):  # Get which sources/chunks are most relevant to the prompt
+     sources_indices = []  # List of indices of the most relevant sources
+     sources_cosine_sims = []  # List of cosine similarities of the most relevant sources
+
+     for (
+         source_embeddings
+     ) in (
+         sources_embeddings
+     ):  # source_embeddings contains all the embeddings of each chunk in a source
+         cosine_sims = np.array(
+             (source_embeddings @ prompt_embedding)
+             / (
+                 np.linalg.norm(source_embeddings, axis=1)
+                 * np.linalg.norm(prompt_embedding)
+             )
+         )  # Calculate the cosine similarity between the prompt and each chunk's vector
+         # Get the indices of the most relevant chunks: https://stackoverflow.com/questions/6910641/how-do-i-get-indices-of-n-maximum-values-in-a-numpy-array
+         num_chunks = min(
+             num_citations, len(cosine_sims)
+         )  # In case there are less chunks than num_citations
+         indices = np.argpartition(cosine_sims, -num_chunks)[
+             -num_chunks:
+         ]  # Get the indices of the most relevant chunks
+         indices = indices[np.argsort(cosine_sims[indices])]  # Sort the indices
+         cosine_sims = cosine_sims[
+             indices
+         ]  # Get the cosine similarities of the most relevant chunks
+         sources_indices.append(indices)  # Add the indices to sources_indices
+         sources_cosine_sims.append(
+             cosine_sims
+         )  # Add the cosine similarities to sources_cosine_sims
+
+     # Use sources_indices and sources_cosine_sims to get the most relevant sources/chunks
+     indexes = []
+     max_cosine_sims = []
+     for source_idx in range(len(sources_indices)):  # For each source
+         for chunk_idx in range(len(sources_indices[source_idx])):  # For each chunk
+             sources_chunk_idx = sources_indices[source_idx][
+                 chunk_idx
+             ]  # Get the index of the chunk
+             similarity = sources_cosine_sims[source_idx][
+                 chunk_idx
+             ]  # Get the cosine similarity of the chunk
+             if len(max_cosine_sims) < num_citations:  # If max_values is not full
+                 indexes.append(
+                     [source_idx, sources_chunk_idx]
+                 )  # Add the source/chunk index pair to indexes
+                 max_cosine_sims.append(
+                     similarity
+                 )  # Add the cosine similarity to max_values
+             elif len(max_cosine_sims) == num_citations and similarity > min(
+                 max_cosine_sims
+             ):  # If max_values is full and the current cosine similarity is greater than the minimum cosine similarity in max_values
+                 indexes.append(
+                     [source_idx, sources_chunk_idx]
+                 )  # Add the source/chunk index pair to indexes
+                 max_cosine_sims.append(
+                     similarity
+                 )  # Add the cosine similarity to max_values
+                 min_idx = max_cosine_sims.index(
+                     min(max_cosine_sims)
+                 )  # Get the index of the minimum cosine similarity in max_values
+                 indexes.pop(
+                     min_idx
+                 )  # Remove the source/chunk index pair at the minimum cosine similarity index in indexes
+                 max_cosine_sims.pop(
+                     min_idx
+                 )  # Remove the minimum cosine similarity in max_values
+             else:  # If max_values is full and the current cosine similarity is less than the minimum cosine similarity in max_values
+                 pass
+     return indexes
+
+
+ def process_source(
+     source, source_type
+ ):  # Process the source name to be used in a message, since URL files are processed differently
+     return (
+         source if source_type == "file" else source.replace(" ", "/")
+     )  # In case this is a URL, reverse what was done in embed_url
+
+
+ # TODO: Find a better way to create/store messages instead of rebuilding them every time a new question is asked
+ def ask():  # Ask a question
+     messages = [
+         {
+             "role": "system",
+             "content": str(
+                 "You are a helpful chatbot that answers questions a user may have about a topic. "
+                 + "Sometimes, the user may give you external data from which you can use as needed. "
+                 + "They will give it to you in the following way:\n"
+                 + "Source 1: the source's name\n"
+                 + "Text 1: the relevant text from the source\n"
+                 + "Source 2: the source's name\n"
+                 + "Text 2: the relevant text from the source\n"
+                 + "...\n"
+                 + "You can use this data to answer the user's questions or to ask the user questions. "
+                 + "Take note that if you plan to reference a source, ALWAYS do so using the source's name.\n"
+             ),
+         },
+         {"role": "user", "content": st.session_state["questions"][0]},
+     ]  # Add the system's introduction message and the user's first question to messages
+     show_message(
+         st.session_state["questions"][0],
+         is_user=True,
+         key=str(uuid.uuid4()),
+         avatar_style=user_avatar_style,
+     )  # Display user's first question
+
+     if (
+         len(st.session_state["questions"]) > 1 and st.session_state["answers"]
+     ):  # If this is not the first question
+         for interaction, message in enumerate(
+             [
+                 message
+                 for pair in zip(
+                     st.session_state["answers"], st.session_state["questions"][1:]
+                 )
+                 for message in pair
+             ]  # Get the messages from the previous conversation in the order of [answer, question, answer, question, ...]: https://stackoverflow.com/questions/7946798/interleave-multiple-lists-of-the-same-length-in-python
+         ):
+             if interaction % 2 == 0:  # If the message is an answer
+                 messages.append(
+                     {"role": "assistant", "content": message}
+                 )  # Add the answer to messages
+                 show_message(
+                     message,
+                     key=str(uuid.uuid4()),
+                     avatar_style=assistant_avatar_style,
+                 )  # Display the answer
+             else:  # If the message is a question
+                 messages.append(
+                     {"role": "user", "content": message}
+                 )  # Add the question to messages
+                 show_message(
+                     message,
+                     is_user=True,
+                     key=str(uuid.uuid4()),
+                     avatar_style=user_avatar_style,
+                 )  # Display the question
+
+     if (
+         st.session_state["sources_types"]
+         and st.session_state["sources"]
+         and st.session_state["chunks"]
+         and st.session_state["vectors"]
+     ):  # If there are sources that were uploaded
+         prompt_embedding = np.array(
+             embed(st.session_state["questions"][-1])[0]["embedding"]
+         )  # Embed the last question
+         indexes = get_most_relevant(
+             prompt_embedding, st.session_state["vectors"]
+         )  # Get the most relevant chunks
+         if indexes:  # If there are relevant chunks
+             messages[-1]["content"] += str(
+                 "Here are some sources that may be helpful:\n"
+             )  # Add the sources to the last message
+             for idx, ind in enumerate(indexes):
+                 source_idx, chunk_idx = ind[0], ind[1]  # Get the source and chunk index
+                 messages[-1]["content"] += str(
+                     "Source "
+                     + str(idx + 1)
+                     + ": "
+                     + process_source(
+                         st.session_state["sources"][source_idx],
+                         st.session_state["sources_types"][source_idx],
+                     )
+                     + "\n"
+                     + "Text "
+                     + str(idx + 1)
+                     + ": "
+                     + st.session_state["chunks"][source_idx][chunk_idx]  # Get the chunk
+                     + "\n"
+                 )
+
+     while (
+         get_num_tokens("\n".join([message["content"] for message in messages]))
+         > llm_context_window
+     ):  # While the context window is too large
+         if (
+             len(messages) == 2
+         ):  # If there is only the introduction message and the user's most recent question
+             max_tokens_left = llm_context_window - get_num_tokens(
+                 messages[0]["content"]
+             )  # Get the maximum number of tokens that can be present in the question
+             messages[1]["content"] = messages[1]["content"][
+                 :max_tokens_left
+             ]  # Truncate the question; per https://help.openai.com/en/articles/4936856-what-are-tokens-and-how-to-count-them, 4 chars ~= 1 token, but since that isn't guaranteed, truncate the question to max_tokens_left characters to be safe
+         else:  # If there are more than 2 messages
+             messages.pop(1)  # Remove the oldest question
+             messages.pop(2)  # Remove the oldest answer
+
+     answer = openai.ChatCompletion.create(model=llm_model, messages=messages)[
+         "choices"
+     ][0]["message"][
+         "content"
+     ]  # Get the answer from the chatbot
+     st.session_state["answers"].append(answer)  # Add the answer to answers
+     show_message(
+         st.session_state["answers"][-1],
+         key=str(uuid.uuid4()),
+         avatar_style=assistant_avatar_style,
+     )  # Display the answer
+
+
+ # Main function, defines the layout of the app
+ def main():
+     # Initialize session state variables
+     if "questions" not in st.session_state:
+         st.session_state["questions"] = []
+     if "answers" not in st.session_state:
+         st.session_state["answers"] = []
+     if "sources_types" not in st.session_state:
+         st.session_state["sources_types"] = []
+     if "sources" not in st.session_state:
+         st.session_state["sources"] = []
+     if "chunks" not in st.session_state:
+         st.session_state["chunks"] = []
+     if "vectors" not in st.session_state:
+         st.session_state["vectors"] = []
+
+     st.title("CacheChat :money_with_wings:")  # Title
+     st.markdown(
+         "Check out the repo [here](https://github.com/andrewhinh/CacheChat) and notes on using the app [here](https://github.com/andrewhinh/CacheChat#notes)."
+     )  # Link to repo
+
+     uploaded_files = st.file_uploader(
+         "Choose file(s):", accept_multiple_files=True, key="files"
+     )  # File upload section
+     if uploaded_files:  # If file(s) are uploaded, create embeddings
+         with st.spinner("Processing..."):  # Show loading spinner
+             for uploaded_file in uploaded_files:
+                 if not (
+                     uploaded_file.name in st.session_state["sources"]
+                 ):  # If the file has not been uploaded, process it
+                     with open(uploaded_file.name, "wb") as file:  # Save file to disk
+                         file.write(uploaded_file.getbuffer())
+                     source_type, file_source, file_chunks, file_vectors = embed_file(
+                         uploaded_file.name
+                     )  # Embed file
+                     if (
+                         not source_type
+                         and not file_source
+                         and not file_chunks
+                         and not file_vectors
+                     ):  # If the file is invalid
+                         st.error("Invalid file(s). Please try again.")
+                     else:  # If the file is valid
+                         st.session_state["sources_types"].append(source_type)
+                         st.session_state["sources"].append(file_source)
+                         st.session_state["chunks"].append(file_chunks)
+                         st.session_state["vectors"].append(file_vectors)
+
+     with st.form(key="url", clear_on_submit=True):  # Form for URL input
+         uploaded_url = st.text_input(
+             "Enter a URL:",
+             placeholder="https://www.africau.edu/images/default/sample.pdf",
+         )  # URL input text box
+         upload_url_button = st.form_submit_button(label="Add URL")  # Add URL button
+     if upload_url_button and uploaded_url:  # If a URL is entered, create embeddings
+         with st.spinner("Processing..."):  # Show loading spinner
+             if not (
+                 uploaded_url in st.session_state["sources"]  # Non-file URL in sources
+                 or "./" + uploaded_url.replace("/", " ")  # File URL in sources
+                 in st.session_state["sources"]
+             ):  # If the URL has not been uploaded, process it
+                 source_type, url_source, url_chunks, url_vectors = embed_url(
+                     uploaded_url
+                 )  # Embed URL
+                 if (
+                     not source_type
+                     and not url_source
+                     and not url_chunks
+                     and not url_vectors
+                 ):  # If the URL is invalid
+                     st.error("Invalid URL. Please try again.")
+                 else:  # If the URL is valid
+                     st.session_state["sources_types"].append(source_type)
+                     st.session_state["sources"].append(url_source)
+                     st.session_state["chunks"].append(url_chunks)
+                     st.session_state["vectors"].append(url_vectors)
+
+     st.divider()  # Create a divider between the uploads and the chat
+
+     input_container = (
+         st.container()
+     )  # Container for inputs/uploads, https://docs.streamlit.io/library/api-reference/layout/st.container
+     response_container = (
+         st.container()
+     )  # Container for chat history, https://docs.streamlit.io/library/api-reference/layout/st.container
+
+     with input_container:
+         with st.form(key="question", clear_on_submit=True):  # Form for question input
+             uploaded_question = st.text_input(
+                 "Enter your input:",
+                 placeholder="e.g: Summarize the research paper in 3 sentences.",
+                 key="input",
+             )  # Question text box
+             uploaded_question_button = st.form_submit_button(
+                 label="Send"
+             )  # Send button
+
+     with response_container:
+         if (
+             uploaded_question_button and uploaded_question
+         ):  # If the send button is pressed and the text box is not empty
+             with st.spinner("Thinking..."):  # Show loading spinner
+                 st.session_state["questions"].append(
+                     uploaded_question
+                 )  # Add question to questions
+                 ask()  # Ask the question to the chatbot
+
+
+ if __name__ == "__main__":
+     main()
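
Note on the retrieval step added in this commit: get_most_relevant ranks every chunk of every source by cosine similarity against the question embedding, then keeps the num_citations best source/chunk pairs. A minimal standalone sketch of that ranking, using made-up 3-dimensional vectors in place of real ada-002 embeddings (top_k here is a stand-in for num_citations; rank_chunks is an illustrative helper, not part of the commit):

import numpy as np

top_k = 2  # stand-in for num_citations

def rank_chunks(prompt_vec, chunk_vecs):
    # Cosine similarity between the prompt and each chunk vector,
    # the same formula used inside get_most_relevant above
    sims = (chunk_vecs @ prompt_vec) / (
        np.linalg.norm(chunk_vecs, axis=1) * np.linalg.norm(prompt_vec)
    )
    num = min(top_k, len(sims))  # in case there are fewer chunks than top_k
    idx = np.argpartition(sims, -num)[-num:]  # indices of the num best chunks, unordered
    return idx[np.argsort(sims[idx])]  # sorted ascending by similarity

prompt = np.array([0.1, 0.9, 0.2])
chunks = np.array([[0.1, 0.8, 0.3], [0.9, 0.1, 0.0], [0.2, 0.7, 0.1]])
print(rank_chunks(prompt, chunks))  # [2 0]: chunk 0 is the closest match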