# Importing the libraries
import os
import math
import requests
import bs4
from dotenv import load_dotenv
import nltk
import numpy as np
import openai
import streamlit as st
from streamlit_chat import message as show_message
import textract
import tiktoken
import uuid
import validators
# Helper variables
load_dotenv()
openai.api_key = os.environ["openapi"]  # Load the OpenAI API key (stored under "openapi") from the .env file
llm_model = "gpt-3.5-turbo" # https://platform.openai.com/docs/guides/chat/introduction
llm_context_window = (
4097 # https://platform.openai.com/docs/guides/chat/managing-tokens
)
embed_context_window, embed_model = (
8191,
"text-embedding-ada-002",
) # https://platform.openai.com/docs/guides/embeddings/second-generation-models
nltk.download(
"punkt"
) # Download the nltk punkt tokenizer for splitting text into sentences
tokenizer = tiktoken.get_encoding(
"cl100k_base"
) # Load the cl100k_base tokenizer which is designed to work with the ada-002 model (engine)
download_chunk_size = 128 # TODO: Find optimal chunk size for downloading files
split_chunk_tokens = 300 # TODO: Find optimal chunk size for splitting text
num_citations = 5 # TODO: Find optimal number of citations to give context to the LLM
# Streamlit settings
user_avatar_style = "fun-emoji" # https://www.dicebear.com/styles
assistant_avatar_style = "bottts-neutral"
# Helper functions
def get_num_tokens(text):  # Count the number of tokens in a string
    return len(
        tokenizer.encode(text, disallowed_special=())
    )  # disallowed_special=() lets text that looks like a special token be encoded as ordinary text instead of raising an error
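# Illustrative sanity check (not executed by the app): under cl100k_base,
# get_num_tokens("Hello world") returns 2, one token for "Hello" and one for " world".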
# TODO:
# Currently, any sentence that is longer than the max number of tokens will be its own chunk
# This is not ideal, since this doesn't ensure that the chunks are of a maximum size
# Find a way to split the sentence into chunks of a maximum size
def split_into_many(text):  # Split text into chunks of a maximum number of tokens
    sentences = nltk.tokenize.sent_tokenize(text)  # Split the text into sentences
    total_tokens = [
        get_num_tokens(sentence) for sentence in sentences
    ]  # Get the number of tokens for each sentence
    chunks = []
    tokens_so_far = 0
    chunk = []
    for sentence, num_tokens in zip(sentences, total_tokens):
        if (
            not tokens_so_far and num_tokens > split_chunk_tokens
        ):  # If a sentence starting a chunk is longer than the max number of tokens, add it as its own chunk
            chunks.append(sentence)
            continue  # Skip ahead so the long sentence is not also added to the next chunk
        if (
            tokens_so_far + num_tokens > split_chunk_tokens
        ):  # If the sentence would make the chunk longer than the max number of tokens, close the current chunk
            chunks.append(" ".join(chunk))
            chunk = []
            tokens_so_far = 0
        # Otherwise, add the sentence to the chunk and add the number of tokens to the total
        chunk.append(sentence)
        tokens_so_far += num_tokens + 1
    if chunk:  # Add any remaining sentences as the final chunk
        chunks.append(" ".join(chunk))
    return chunks
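# Illustrative example (hypothetical sizes): with split_chunk_tokens = 300, three
# consecutive 100-token sentences become ["s1 s2", "s3"], since each sentence costs its
# token count plus 1 and adding the third sentence would push the first chunk past 300.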
def embed(prompt):  # Embed the prompt
    embeds = []
    if isinstance(prompt, str):
        if (
            get_num_tokens(prompt) > embed_context_window
        ):  # If the token length of the prompt exceeds the embedding context window
            prompt = split_into_many(prompt)  # Split the prompt into multiple chunks
        else:  # If the prompt fits in the embedding context window
            embeds = openai.Embedding.create(input=prompt, model=embed_model)[
                "data"
            ]  # Embed the prompt directly
    if not embeds:  # If the prompt was split into, or already is, a list of chunks
        max_num_chunks = (
            embed_context_window // split_chunk_tokens
        )  # Number of chunks that can fit in the context window
        for i in range(
            0, math.ceil(len(prompt) / max_num_chunks)
        ):  # For each batch of chunks
            embeds.extend(
                openai.Embedding.create(
                    input=prompt[i * max_num_chunks : (i + 1) * max_num_chunks],
                    model=embed_model,
                )["data"]
            )  # Embed the batch of chunks
    return embeds  # Return the list of embeddings
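# Note: each element of the returned list carries an "embedding" key holding a
# 1536-dimensional float vector (the output size of text-embedding-ada-002), in the
# same order as the input chunks.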
def embed_file(filename): # Create embeddings for a file
source_type = "file" # To help distinguish between local/URL files and URLs
file_source = "" # Source of the file
file_chunks = [] # List of file chunks (from the file)
file_vectors = [] # List of lists of file embeddings (from each chunk)
try:
extracted_text = (
textract.process(filename)
.decode("utf-8") # Extracted text is in bytes, convert to string
.encode("ascii", "ignore") # Remove non-ascii characters
.decode() # Convert back to string
)
if not extracted_text: # If the file is empty
raise Exception
os.remove(
filename
) # Remove the file from the server since it is no longer needed
file_source = filename
file_chunks = split_into_many(extracted_text) # Split the text into chunks
file_vectors = [x["embedding"] for x in embed(file_chunks)] # Embed the chunks
except Exception: # If the file cannot be extracted, return empty values
if os.path.exists(filename): # If the file still exists
os.remove(
filename
) # Remove the file from the server since it is no longer needed
source_type = ""
file_source = ""
file_chunks = []
file_vectors = []
return source_type, file_source, file_chunks, file_vectors
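# Illustrative contract (hypothetical filename): on success, embed_file("./notes.pdf")
# returns ("file", "./notes.pdf", [chunk strings], [one embedding list per chunk]);
# on any extraction failure it returns ("", "", [], []).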
def embed_url(url): # Create embeddings for a url
source_type = "url" # To help distinguish between local/URL files and URLs
url_source = "" # Source of the url
url_chunks = [] # List of url chunks (for the url)
    url_vectors = []  # List of url embeddings (one for each chunk)
filename = "" # Filename of the url if it is a file
try:
        if validators.url(url, public=True):  # Verify the URL is valid and public
response = requests.get(url) # Get the url info
header = response.headers["Content-Type"] # Get the header of the url
is_application = (
header.split("/")[0] == "application"
) # Check if the url is a file
if is_application: # If url is a file, call embed_file on the file
filetype = header.split("/")[1] # Get the filetype
url_parts = url.split("/") # Get the parts of the url
filename = str(
"./"
+ " ".join(
url_parts[:-1] + [url_parts[-1].split(".")[0]]
) # Replace / with whitespace in the filename to avoid issues with the file path and remove the file extension since it may not match the actual filetype
+ "."
+ filetype
) # Create the filename
with requests.get(
url, stream=True
) as stream_response: # Download the file
stream_response.raise_for_status()
with open(filename, "wb") as file:
for chunk in stream_response.iter_content(
chunk_size=download_chunk_size
):
file.write(chunk)
return embed_file(filename) # Embed the file
else: # If url is a webpage, use BeautifulSoup to extract the text
            soup = bs4.BeautifulSoup(response.text, "html.parser")  # Create a BeautifulSoup object with an explicit parser
extracted_text = (
soup.get_text() # Extract the text from the webpage
.encode("ascii", "ignore") # Remove non-ascii characters
.decode() # Convert back to string
)
if not extracted_text: # If the webpage is empty
raise Exception
url_source = url
url_chunks = split_into_many(
extracted_text
) # Split the text into chunks
            url_vectors = [
                x["embedding"] for x in embed(url_chunks)
            ]  # Embed all the chunks, so the vectors line up with url_chunks
        else:  # If the url is not valid or not public
raise Exception
except Exception: # If the url cannot be extracted, return empty values
source_type = ""
url_source = ""
url_chunks = []
url_vectors = []
return source_type, url_source, url_chunks, url_vectors
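# Illustrative contract (hypothetical URLs): a webpage such as "https://example.com/post"
# yields ("url", url, chunks, vectors); a direct file link (Content-Type "application/...")
# is downloaded and delegated to embed_file, so it comes back with source_type "file";
# any failure yields ("", "", [], []).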
def get_most_relevant(
prompt_embedding, sources_embeddings
): # Get which sources/chunks are most relevant to the prompt
sources_indices = [] # List of indices of the most relevant sources
sources_cosine_sims = [] # List of cosine similarities of the most relevant sources
for (
source_embeddings
) in (
sources_embeddings
): # source_embeddings contains all the embeddings of each chunk in a source
cosine_sims = np.array(
(source_embeddings @ prompt_embedding)
/ (
np.linalg.norm(source_embeddings, axis=1)
* np.linalg.norm(prompt_embedding)
)
) # Calculate the cosine similarity between the prompt and each chunk's vector
# Get the indices of the most relevant chunks: https://stackoverflow.com/questions/6910641/how-do-i-get-indices-of-n-maximum-values-in-a-numpy-array
        num_chunks = min(
            num_citations, len(cosine_sims)
        )  # In case there are fewer chunks than num_citations
indices = np.argpartition(cosine_sims, -num_chunks)[
-num_chunks:
] # Get the indices of the most relevant chunks
indices = indices[np.argsort(cosine_sims[indices])] # Sort the indices
cosine_sims = cosine_sims[
indices
] # Get the cosine similarities of the most relevant chunks
sources_indices.append(indices) # Add the indices to sources_indices
sources_cosine_sims.append(
cosine_sims
) # Add the cosine similarities to sources_cosine_sims
# Use sources_indices and sources_cosine_sims to get the most relevant sources/chunks
indexes = []
max_cosine_sims = []
for source_idx in range(len(sources_indices)): # For each source
for chunk_idx in range(len(sources_indices[source_idx])): # For each chunk
sources_chunk_idx = sources_indices[source_idx][
chunk_idx
] # Get the index of the chunk
similarity = sources_cosine_sims[source_idx][
chunk_idx
] # Get the cosine similarity of the chunk
            if len(max_cosine_sims) < num_citations:  # If max_cosine_sims is not full
                indexes.append(
                    [source_idx, sources_chunk_idx]
                )  # Add the source/chunk index pair to indexes
                max_cosine_sims.append(
                    similarity
                )  # Add the cosine similarity to max_cosine_sims
            elif similarity > min(
                max_cosine_sims
            ):  # If max_cosine_sims is full and the current cosine similarity beats its minimum, replace the minimum entry
                indexes.append(
                    [source_idx, sources_chunk_idx]
                )  # Add the source/chunk index pair to indexes
                max_cosine_sims.append(
                    similarity
                )  # Add the cosine similarity to max_cosine_sims
                min_idx = max_cosine_sims.index(
                    min(max_cosine_sims)
                )  # Get the index of the minimum cosine similarity in max_cosine_sims
                indexes.pop(
                    min_idx
                )  # Remove the source/chunk index pair at that position in indexes
                max_cosine_sims.pop(
                    min_idx
                )  # Remove the minimum cosine similarity from max_cosine_sims
            # Otherwise, the chunk's similarity is too low to be cited, so skip it
return indexes
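# Worked example (hypothetical numbers): with num_citations = 2, a source whose chunk
# similarities are [0.9, 0.1] and a second source with [0.8] yield
# indexes == [[0, 0], [1, 0]]: chunk 0 of source 0 (0.9) and chunk 0 of source 1 (0.8),
# the two highest cosine similarities across all chunks.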
def process_source(
source, source_type
): # Process the source name to be used in a message, since URL files are processed differently
return (
source if source_type == "file" else source.replace(" ", "/")
) # In case this is a URL, reverse what was done in embed_url
# TODO: Find a better way to create/store messages instead of rebuilding them every time a new question is asked
def ask(): # Ask a question
messages = [
{
"role": "system",
"content": str(
"You are a helpful chatbot that answers questions a user may have about a topic. "
                + "Sometimes, the user may give you external data which you can use as needed. "
+ "They will give it to you in the following way:\n"
+ "Source 1: the source's name\n"
+ "Text 1: the relevant text from the source\n"
+ "Source 2: the source's name\n"
+ "Text 2: the relevant text from the source\n"
+ "...\n"
+ "You can use this data to answer the user's questions or to ask the user questions. "
+ "Take note that if you plan to reference a source, ALWAYS do so using the source's name.\n"
),
},
{"role": "user", "content": st.session_state["questions"][0]},
] # Add the system's introduction message and the user's first question to messages
show_message(
st.session_state["questions"][0],
is_user=True,
key=str(uuid.uuid4()),
avatar_style=user_avatar_style,
) # Display user's first question
if (
len(st.session_state["questions"]) > 1 and st.session_state["answers"]
): # If this is not the first question
for interaction, message in enumerate(
[
message
for pair in zip(
st.session_state["answers"], st.session_state["questions"][1:]
)
for message in pair
] # Get the messages from the previous conversation in the order of [answer, question, answer, question, ...]: https://stackoverflow.com/questions/7946798/interleave-multiple-lists-of-the-same-length-in-python
):
if interaction % 2 == 0: # If the message is an answer
messages.append(
{"role": "assistant", "content": message}
) # Add the answer to messages
show_message(
message,
key=str(uuid.uuid4()),
avatar_style=assistant_avatar_style,
) # Display the answer
else: # If the message is a question
messages.append(
{"role": "user", "content": message}
) # Add the question to messages
show_message(
message,
is_user=True,
key=str(uuid.uuid4()),
avatar_style=user_avatar_style,
) # Display the question
if (
st.session_state["sources_types"]
and st.session_state["sources"]
and st.session_state["chunks"]
and st.session_state["vectors"]
): # If there are sources that were uploaded
prompt_embedding = np.array(
embed(st.session_state["questions"][-1])[0]["embedding"]
) # Embed the last question
indexes = get_most_relevant(
prompt_embedding, st.session_state["vectors"]
) # Get the most relevant chunks
if indexes: # If there are relevant chunks
messages[-1]["content"] += str(
"Here are some sources that may be helpful:\n"
) # Add the sources to the last message
for idx, ind in enumerate(indexes):
source_idx, chunk_idx = ind[0], ind[1] # Get the source and chunk index
messages[-1]["content"] += str(
"Source "
+ str(idx + 1)
+ ": "
+ process_source(
st.session_state["sources"][source_idx],
st.session_state["sources_types"][source_idx],
)
+ "\n"
+ "Text "
+ str(idx + 1)
+ ": "
+ st.session_state["chunks"][source_idx][chunk_idx] # Get the chunk
+ "\n"
)
while (
get_num_tokens("\n".join([message["content"] for message in messages]))
> llm_context_window
): # If the context window is too large
if (
len(messages) == 2
): # If there is only the introduction message and the user's most recent question
max_tokens_left = llm_context_window - get_num_tokens(
messages[0]["content"]
) # Get the maximum number of tokens that can be present in the question
messages[1]["content"] = messages[1]["content"][
:max_tokens_left
] # Truncate the question, from https://help.openai.com/en/articles/4936856-what-are-tokens-and-how-to-count-them 4 chars ~= 1 token, but it isn't certain that this is the case, so we will just truncate the question to max_tokens_left characters to be safe
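            # Illustrative bound (hypothetical sizes): with llm_context_window = 4097 and a
            # 500-token system message, the question is cut to its first 3597 characters; at
            # roughly 4 characters per token, that is about 900 tokens, well inside the window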
        else:  # If there are more than 2 messages
            messages.pop(1)  # Remove the oldest question
            messages.pop(1)  # Remove the oldest answer, which moved to index 1 after the first pop
answer = openai.ChatCompletion.create(model=llm_model, messages=messages)[
"choices"
][0]["message"][
"content"
] # Get the answer from the chatbot
st.session_state["answers"].append(answer) # Add the answer to answers
show_message(
st.session_state["answers"][-1],
key=str(uuid.uuid4()),
avatar_style=assistant_avatar_style,
) # Display the answer
# Main function, defines layout of the app
def main():
# Initialize session state variables
if "questions" not in st.session_state:
st.session_state["questions"] = []
if "answers" not in st.session_state:
st.session_state["answers"] = []
if "sources_types" not in st.session_state:
st.session_state["sources_types"] = []
if "sources" not in st.session_state:
st.session_state["sources"] = []
if "chunks" not in st.session_state:
st.session_state["chunks"] = []
if "vectors" not in st.session_state:
st.session_state["vectors"] = []
st.title("CacheChat :money_with_wings:") # Title
st.markdown(
"Check out the repo [here](https://github.com/andrewhinh/CacheChat) and notes on using the app [here](https://github.com/andrewhinh/CacheChat#notes)."
) # Link to repo
uploaded_files = st.file_uploader(
"Choose file(s):", accept_multiple_files=True, key="files"
) # File upload section
    if uploaded_files:  # If any files were uploaded, create embeddings
with st.spinner("Processing..."): # Show loading spinner
for uploaded_file in uploaded_files:
                if uploaded_file.name not in st.session_state["sources"]:  # If the file has not been uploaded before, process it
with open(uploaded_file.name, "wb") as file: # Save file to disk
file.write(uploaded_file.getbuffer())
source_type, file_source, file_chunks, file_vectors = embed_file(
uploaded_file.name
) # Embed file
if (
not source_type
and not file_source
and not file_chunks
and not file_vectors
): # If the file is invalid
st.error("Invalid file(s). Please try again.")
else: # If the file is valid
st.session_state["sources_types"].append(source_type)
st.session_state["sources"].append(file_source)
st.session_state["chunks"].append(file_chunks)
st.session_state["vectors"].append(file_vectors)
    with st.form(key="url", clear_on_submit=True):  # form for URL input
uploaded_url = st.text_input(
"Enter a URL:",
placeholder="https://www.africau.edu/images/default/sample.pdf",
) # URL input text box
upload_url_button = st.form_submit_button(label="Add URL") # Add URL button
if upload_url_button and uploaded_url: # If a URL is entered, create embeddings
with st.spinner("Processing..."): # Show loading spinner
if not (
uploaded_url in st.session_state["sources"] # Non-file URL in sources
or "./" + uploaded_url.replace("/", " ") # File URL in sources
in st.session_state["sources"]
): # If the URL has not been uploaded, process it
source_type, url_source, url_chunks, url_vectors = embed_url(
uploaded_url
) # Embed URL
if (
not source_type
and not url_source
and not url_chunks
and not url_vectors
): # If the URL is invalid
st.error("Invalid URL. Please try again.")
else: # If the URL is valid
st.session_state["sources_types"].append(source_type)
st.session_state["sources"].append(url_source)
st.session_state["chunks"].append(url_chunks)
st.session_state["vectors"].append(url_vectors)
st.divider() # Create a divider between the uploads and the chat
input_container = (
st.container()
) # container for inputs/uploads, https://docs.streamlit.io/library/api-reference/layout/st.container
response_container = (
st.container()
) # container for chat history, https://docs.streamlit.io/library/api-reference/layout/st.container
with input_container:
with st.form(key="question", clear_on_submit=True): # form for question input
uploaded_question = st.text_input(
"Enter your input:",
placeholder="e.g: Summarize the research paper in 3 sentences.",
key="input",
) # question text box
uploaded_question_button = st.form_submit_button(
label="Send"
) # send button
with response_container:
if (
uploaded_question_button and uploaded_question
): # if send button is pressed and text box is not empty
with st.spinner("Thinking..."): # show loading spinner
st.session_state["questions"].append(
uploaded_question
) # add question to questions
ask() # ask question to chatbot
if __name__ == "__main__":
main()