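"""CacheChat: a Streamlit chat app for asking questions about uploaded files and URLs.

Each source is split into token-limited chunks and embedded with OpenAI's
text-embedding-ada-002 model; at question time, the chunks most similar to the
question are passed as context to gpt-3.5-turbo.
"""
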
# Importing the libraries
import os
import math
import requests
import bs4
from dotenv import load_dotenv
import nltk
import numpy as np
import openai
import streamlit as st
from streamlit_chat import message as show_message
import textract
import tiktoken
import uuid
import validators


# Helper variables
load_dotenv()
openai.api_key = os.environ["openapi"]  # OpenAI API key, loaded from the .env file under the "openapi" variable

llm_model = "gpt-3.5-turbo"  # https://platform.openai.com/docs/guides/chat/introduction
llm_context_window = (
    4097  # https://platform.openai.com/docs/guides/chat/managing-tokens
)
embed_context_window, embed_model = (
    8191,
    "text-embedding-ada-002",
)  # https://platform.openai.com/docs/guides/embeddings/second-generation-models
nltk.download(
    "punkt"
)  # Download the nltk punkt tokenizer for splitting text into sentences
tokenizer = tiktoken.get_encoding(
    "cl100k_base"
)  # Load the cl100k_base tokenizer which is designed to work with the ada-002 model (engine)

download_chunk_size = 128  # TODO: Find optimal chunk size for downloading files
split_chunk_tokens = 300  # TODO: Find optimal chunk size for splitting text
num_citations = 5  # TODO: Find optimal number of citations to give context to the LLM
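# Note: with these defaults, roughly num_citations * split_chunk_tokens = 1500 tokens
# of source text can be appended to a single question out of the 4097-token window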

# Streamlit settings
user_avatar_style = "fun-emoji"  # https://www.dicebear.com/styles
assistant_avatar_style = "bottts-neutral"


# Helper functions
def get_num_tokens(text):  # Count the number of tokens in a string
    return len(
        tokenizer.encode(text, disallowed_special=())
    )  # disallowed_special=() keeps tiktoken from raising an error on text that contains special tokens


#   TODO:
#   Currently, any sentence that is longer than the max number of tokens becomes its own chunk
#   This is not ideal, since it doesn't guarantee that every chunk stays under the maximum size
#   Find a way to split such sentences into pieces of a maximum size (see the sketch below)
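#   One possible approach (untested sketch, not wired in): split an oversized
#   sentence at the token level so no piece exceeds split_chunk_tokens:
#       tokens = tokenizer.encode(sentence, disallowed_special=())
#       pieces = [
#           tokenizer.decode(tokens[i : i + split_chunk_tokens])
#           for i in range(0, len(tokens), split_chunk_tokens)
#       ]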
def split_into_many(text):  # Split text into chunks of a maximum number of tokens
    sentences = nltk.tokenize.sent_tokenize(text)  # Split the text into sentences
    total_tokens = [
        get_num_tokens(sentence) for sentence in sentences
    ]  # Get the number of tokens for each sentence

    chunks = []
    tokens_so_far = 0
    chunk = []
    for sentence, num_tokens in zip(sentences, total_tokens):
        if not tokens_so_far:  # If this is the first sentence in the chunk
            if (
                num_tokens > split_chunk_tokens
            ):  # If the sentence alone exceeds the max number of tokens, make it its own chunk
                chunks.append(sentence)
                continue  # Skip the append below so the sentence isn't repeated in the next chunk
        elif (
            tokens_so_far + num_tokens > split_chunk_tokens
        ):  # If the sentence would push the chunk past the max number of tokens, close the current chunk
            chunks.append(" ".join(chunk))
            chunk = []
            tokens_so_far = 0

        # Add the sentence to the current chunk and count its tokens (+1 for the joining space)
        chunk.append(sentence)
        tokens_so_far += num_tokens + 1

    # Flush the trailing chunk so the final sentences aren't dropped
    if chunk:
        chunks.append(" ".join(chunk))
    return chunks


def embed(prompt):  # Embed the prompt
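    # Returns a list of OpenAI embedding objects (one per input string/chunk);
    # the vector itself is available as item["embedding"]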
    embeds = []
    if isinstance(prompt, str):
        if (
            get_num_tokens(prompt) > embed_context_window
        ):  # If token_length of prompt > context_window
            prompt = split_into_many(prompt)  # Split prompt into multiple chunks
        else:  # If token_length of prompt <= context_window
            embeds = openai.Embedding.create(input=prompt, model=embed_model)[
                "data"
            ]  # Embed prompt
    if not embeds:  # If the prompt was split into chunks or was passed in as a list of chunks
        max_num_chunks = (
            embed_context_window // split_chunk_tokens
        )  # Number of chunks that can fit in the context window
        for i in range(
            0, math.ceil(len(prompt) / max_num_chunks)
        ):  # For each batch of chunks
            embeds.extend(
                openai.Embedding.create(
                    input=prompt[i * max_num_chunks : (i + 1) * max_num_chunks],
                    model=embed_model,
                )["data"]
            )  # Embed the batch of chunks
    return embeds  # Return the list of embeddings


def embed_file(filename):  # Create embeddings for a file
    source_type = "file"  # To help distinguish between local/URL files and URLs
    file_source = ""  # Source of the file
    file_chunks = []  # List of file chunks (from the file)
    file_vectors = []  # List of lists of file embeddings (from each chunk)

    try:
        extracted_text = (
            textract.process(filename)
            .decode("utf-8")  # Extracted text is in bytes, convert to string
            .encode("ascii", "ignore")  # Remove non-ascii characters
            .decode()  # Convert back to string
        )
        if not extracted_text:  # If the file is empty
            raise Exception
        os.remove(
            filename
        )  # Remove the file from the server since it is no longer needed
        file_source = filename
        file_chunks = split_into_many(extracted_text)  # Split the text into chunks
        file_vectors = [x["embedding"] for x in embed(file_chunks)]  # Embed the chunks
    except Exception:  # If the file cannot be extracted, return empty values
        if os.path.exists(filename):  # If the file still exists
            os.remove(
                filename
            )  # Remove the file from the server since it is no longer needed
        source_type = ""
        file_source = ""
        file_chunks = []
        file_vectors = []

    return source_type, file_source, file_chunks, file_vectors


def embed_url(url):  # Create embeddings for a url
    source_type = "url"  # To help distinguish between local/URL files and URLs
    url_source = ""  # Source of the url
    url_chunks = []  # List of url chunks (for the url)
    url_vectors = []  # List of list of url embeddings (for each chunk)
    filename = ""  # Filename of the url if it is a file

    try:
        if validators.url(url, public=True):  # Verify the url is valid and public
            response = requests.get(url)  # Get the url info
            header = response.headers["Content-Type"]  # Get the header of the url
            is_application = (
                header.split("/")[0] == "application"
            )  # Check if the url is a file

            if is_application:  # If url is a file, call embed_file on the file
                filetype = header.split("/")[1]  # Get the filetype
                url_parts = url.split("/")  # Get the parts of the url
                filename = str(
                    "./"
                    + " ".join(
                        url_parts[:-1] + [url_parts[-1].split(".")[0]]
                    )  # Replace / with whitespace in the filename to avoid issues with the file path and remove the file extension since it may not match the actual filetype
                    + "."
                    + filetype
                )  # Create the filename
                with requests.get(
                    url, stream=True
                ) as stream_response:  # Download the file
                    stream_response.raise_for_status()
                    with open(filename, "wb") as file:
                        for chunk in stream_response.iter_content(
                            chunk_size=download_chunk_size
                        ):
                            file.write(chunk)
                return embed_file(filename)  # Embed the file
            else:  # If url is a webpage, use BeautifulSoup to extract the text
                soup = bs4.BeautifulSoup(
                    response.text, "html.parser"
                )  # Create a BeautifulSoup object with an explicit parser to avoid a warning
                extracted_text = (
                    soup.get_text()  # Extract the text from the webpage
                    .encode("ascii", "ignore")  # Remove non-ascii characters
                    .decode()  # Convert back to string
                )
                if not extracted_text:  # If the webpage is empty
                    raise Exception
                url_source = url
                url_chunks = split_into_many(
                    extracted_text
                )  # Split the text into chunks
                url_vectors = [
                    x["embedding"] for x in embed(url_chunks)
                ]  # Embed all of the chunks (not just the last one)
        else:  # If url is not valid or public
            raise Exception
    except Exception:  # If the url cannot be extracted, return empty values
        source_type = ""
        url_source = ""
        url_chunks = []
        url_vectors = []

    return source_type, url_source, url_chunks, url_vectors


def get_most_relevant(
    prompt_embedding, sources_embeddings
):  # Get which sources/chunks are most relevant to the prompt
    sources_indices = []  # List of indices of the most relevant sources
    sources_cosine_sims = []  # List of cosine similarities of the most relevant sources

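    # Cosine similarity between vectors a and b: cos(a, b) = (a . b) / (||a|| * ||b||),
    # computed below in vectorized form against every chunk embedding at once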
    for (
        source_embeddings
    ) in (
        sources_embeddings
    ):  # source_embeddings contains all the embeddings of each chunk in a source
        cosine_sims = np.array(
            (source_embeddings @ prompt_embedding)
            / (
                np.linalg.norm(source_embeddings, axis=1)
                * np.linalg.norm(prompt_embedding)
            )
        )  # Calculate the cosine similarity between the prompt and each chunk's vector
        # Get the indices of the most relevant chunks: https://stackoverflow.com/questions/6910641/how-do-i-get-indices-of-n-maximum-values-in-a-numpy-array
        num_chunks = min(
            num_citations, len(cosine_sims)
        )  # In case there are fewer chunks than num_citations
        indices = np.argpartition(cosine_sims, -num_chunks)[
            -num_chunks:
        ]  # Get the indices of the most relevant chunks
        indices = indices[np.argsort(cosine_sims[indices])]  # Sort the indices
        cosine_sims = cosine_sims[
            indices
        ]  # Get the cosine similarities of the most relevant chunks
        sources_indices.append(indices)  # Add the indices to sources_indices
        sources_cosine_sims.append(
            cosine_sims
        )  # Add the cosine similarities to sources_cosine_sims

    # Use sources_indices and sources_cosine_sims to get the most relevant sources/chunks
    indexes = []
    max_cosine_sims = []
    for source_idx in range(len(sources_indices)):  # For each source
        for chunk_idx in range(len(sources_indices[source_idx])):  # For each chunk
            sources_chunk_idx = sources_indices[source_idx][
                chunk_idx
            ]  # Get the index of the chunk
            similarity = sources_cosine_sims[source_idx][
                chunk_idx
            ]  # Get the cosine similarity of the chunk
            if len(max_cosine_sims) < num_citations:  # If max_values is not full
                indexes.append(
                    [source_idx, sources_chunk_idx]
                )  # Add the source/chunk index pair to indexes
                max_cosine_sims.append(
                    similarity
                )  # Add the cosine similarity to max_values
            elif len(max_cosine_sims) == num_citations and similarity > min(
                max_cosine_sims
            ):  # If max_values is full and the current cosine similarity is greater than the minimum cosine similarity in max_values
                indexes.append(
                    [source_idx, sources_chunk_idx]
                )  # Add the source/chunk index pair to indexes
                max_cosine_sims.append(
                    similarity
                )  # Add the cosine similarity to max_values
                min_idx = max_cosine_sims.index(
                    min(max_cosine_sims)
                )  # Get the index of the minimum cosine similarity in max_values
                indexes.pop(
                    min_idx
                )  # Remove the source/chunk index pair at the minimum cosine similarity index in indexes
                max_cosine_sims.pop(
                    min_idx
                )  # Remove the minimum cosine similarity in max_values
            else:  # If max_values is full and the current cosine similarity is less than the minimum cosine similarity in max_values
                pass
    return indexes


def process_source(
    source, source_type
):  # Process the source name to be used in a message, since URL files are processed differently
    return (
        source if source_type == "file" else source.replace(" ", "/")
    )  # In case this is a URL, reverse what was done in embed_url
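# e.g. process_source("https:  www.africau.edu images default sample.pdf", "url")
# returns "https://www.africau.edu/images/default/sample.pdf"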


#   TODO: Find a better way to create/store messages instead of rebuilding them every time a new question is asked
def ask():  # Ask a question
    messages = [
        {
            "role": "system",
            "content": str(
                "You are a helpful chatbot that answers questions a user may have about a topic. "
                + "Sometimes, the user may give you external data from which you can use as needed. "
                + "They will give it to you in the following way:\n"
                + "Source 1: the source's name\n"
                + "Text 1: the relevant text from the source\n"
                + "Source 2: the source's name\n"
                + "Text 2: the relevant text from the source\n"
                + "...\n"
                + "You can use this data to answer the user's questions or to ask the user questions. "
                + "Take note that if you plan to reference a source, ALWAYS do so using the source's name.\n"
            ),
        },
        {"role": "user", "content": st.session_state["questions"][0]},
    ]  # Add the system's introduction message and the user's first question to messages
    show_message(
        st.session_state["questions"][0],
        is_user=True,
        key=str(uuid.uuid4()),
        avatar_style=user_avatar_style,
    )  # Display user's first question

    if (
        len(st.session_state["questions"]) > 1 and st.session_state["answers"]
    ):  # If this is not the first question
        for interaction, message in enumerate(
            [
                message
                for pair in zip(
                    st.session_state["answers"], st.session_state["questions"][1:]
                )
                for message in pair
            ]  # Get the messages from the previous conversation in the order of [answer, question, answer, question, ...]: https://stackoverflow.com/questions/7946798/interleave-multiple-lists-of-the-same-length-in-python
        ):
            if interaction % 2 == 0:  # If the message is an answer
                messages.append(
                    {"role": "assistant", "content": message}
                )  # Add the answer to messages
                show_message(
                    message,
                    key=str(uuid.uuid4()),
                    avatar_style=assistant_avatar_style,
                )  # Display the answer
            else:  # If the message is a question
                messages.append(
                    {"role": "user", "content": message}
                )  # Add the question to messages
                show_message(
                    message,
                    is_user=True,
                    key=str(uuid.uuid4()),
                    avatar_style=user_avatar_style,
                )  # Display the question

    if (
        st.session_state["sources_types"]
        and st.session_state["sources"]
        and st.session_state["chunks"]
        and st.session_state["vectors"]
    ):  # If there are sources that were uploaded
        prompt_embedding = np.array(
            embed(st.session_state["questions"][-1])[0]["embedding"]
        )  # Embed the last question
        indexes = get_most_relevant(
            prompt_embedding, st.session_state["vectors"]
        )  # Get the most relevant chunks
        if indexes:  # If there are relevant chunks
            messages[-1]["content"] += str(
                "Here are some sources that may be helpful:\n"
            )  # Add the sources to the last message
            for idx, ind in enumerate(indexes):
                source_idx, chunk_idx = ind[0], ind[1]  # Get the source and chunk index
                messages[-1]["content"] += str(
                    "Source "
                    + str(idx + 1)
                    + ": "
                    + process_source(
                        st.session_state["sources"][source_idx],
                        st.session_state["sources_types"][source_idx],
                    )
                    + "\n"
                    + "Text "
                    + str(idx + 1)
                    + ": "
                    + st.session_state["chunks"][source_idx][chunk_idx]  # Get the chunk
                    + "\n"
                )

    while (
        get_num_tokens("\n".join([message["content"] for message in messages]))
        > llm_context_window
    ):  # If the context window is too large
        if (
            len(messages) == 2
        ):  # If there is only the introduction message and the user's most recent question
            max_tokens_left = llm_context_window - get_num_tokens(
                messages[0]["content"]
            )  # Get the maximum number of tokens that can be present in the question
            messages[1]["content"] = messages[1]["content"][
                :max_tokens_left
            ]  # Truncate the question, from https://help.openai.com/en/articles/4936856-what-are-tokens-and-how-to-count-them 4 chars ~= 1 token, but it isn't certain that this is the case, so we will just truncate the question to max_tokens_left characters to be safe
        else:  # If there are more than 2 messages
            messages.pop(1)  # Remove the oldest question
            messages.pop(1)  # Remove the oldest answer, which shifted to index 1 after the pop above

    answer = openai.ChatCompletion.create(model=llm_model, messages=messages)[
        "choices"
    ][0]["message"][
        "content"
    ]  # Get the answer from the chatbot
    st.session_state["answers"].append(answer)  # Add the answer to answers
    show_message(
        st.session_state["answers"][-1],
        key=str(uuid.uuid4()),
        avatar_style=assistant_avatar_style,
    )  # Display the answer


# Main function, defines layout of the app
def main():
    # Initialize session state variables
    if "questions" not in st.session_state:
        st.session_state["questions"] = []
    if "answers" not in st.session_state:
        st.session_state["answers"] = []
    if "sources_types" not in st.session_state:
        st.session_state["sources_types"] = []
    if "sources" not in st.session_state:
        st.session_state["sources"] = []
    if "chunks" not in st.session_state:
        st.session_state["chunks"] = []
    if "vectors" not in st.session_state:
        st.session_state["vectors"] = []

    st.title("CacheChat :money_with_wings:")  # Title
    st.markdown(
        "Check out the repo [here](https://github.com/andrewhinh/CacheChat) and notes on using the app [here](https://github.com/andrewhinh/CacheChat#notes)."
    )  # Link to repo

    uploaded_files = st.file_uploader(
        "Choose file(s):", accept_multiple_files=True, key="files"
    )  # File upload section
    if uploaded_files:  # If any files were uploaded, create embeddings for them
        with st.spinner("Processing..."):  # Show loading spinner
            for uploaded_file in uploaded_files:
                if (
                    uploaded_file.name not in st.session_state["sources"]
                ):  # If the file has not been uploaded yet, process it
                    with open(uploaded_file.name, "wb") as file:  # Save file to disk
                        file.write(uploaded_file.getbuffer())
                    source_type, file_source, file_chunks, file_vectors = embed_file(
                        uploaded_file.name
                    )  # Embed file
                    if (
                        not source_type
                        and not file_source
                        and not file_chunks
                        and not file_vectors
                    ):  # If the file is invalid
                        st.error("Invalid file(s). Please try again.")
                    else:  # If the file is valid
                        st.session_state["sources_types"].append(source_type)
                        st.session_state["sources"].append(file_source)
                        st.session_state["chunks"].append(file_chunks)
                        st.session_state["vectors"].append(file_vectors)

    with st.form(key="url", clear_on_submit=True):  # form for question input
        uploaded_url = st.text_input(
            "Enter a URL:",
            placeholder="https://www.africau.edu/images/default/sample.pdf",
        )  # URL input text box
        upload_url_button = st.form_submit_button(label="Add URL")  # Add URL button
    if upload_url_button and uploaded_url:  # If a URL is entered, create embeddings
        with st.spinner("Processing..."):  # Show loading spinner
            if not (
                uploaded_url in st.session_state["sources"]  # Non-file URL in sources
                or "./" + uploaded_url.replace("/", " ")  # File URL in sources
                in st.session_state["sources"]
            ):  # If the URL has not been uploaded, process it
                source_type, url_source, url_chunks, url_vectors = embed_url(
                    uploaded_url
                )  # Embed URL
                if (
                    not source_type
                    and not url_source
                    and not url_chunks
                    and not url_vectors
                ):  # If the URL is invalid
                    st.error("Invalid URL. Please try again.")
                else:  # If the URL is valid
                    st.session_state["sources_types"].append(source_type)
                    st.session_state["sources"].append(url_source)
                    st.session_state["chunks"].append(url_chunks)
                    st.session_state["vectors"].append(url_vectors)

    st.divider()  # Create a divider between the uploads and the chat

    input_container = (
        st.container()
    )  # container for inputs/uploads, https://docs.streamlit.io/library/api-reference/layout/st.container
    response_container = (
        st.container()
    )  # container for chat history, https://docs.streamlit.io/library/api-reference/layout/st.container

    with input_container:
        with st.form(key="question", clear_on_submit=True):  # form for question input
            uploaded_question = st.text_input(
                "Enter your input:",
                placeholder="e.g: Summarize the research paper in 3 sentences.",
                key="input",
            )  # question text box
            uploaded_question_button = st.form_submit_button(
                label="Send"
            )  # send button

    with response_container:
        if (
            uploaded_question_button and uploaded_question
        ):  # if send button is pressed and text box is not empty
            with st.spinner("Thinking..."):  # show loading spinner
                st.session_state["questions"].append(
                    uploaded_question
                )  # add question to questions
                ask()  # ask question to chatbot


if __name__ == "__main__":
    main()