# Importing the libraries
import os
import math
import requests
import bs4
from dotenv import load_dotenv
import nltk
import numpy as np
import openai
import streamlit as st
from streamlit_chat import message as show_message
import textract
import tiktoken
import uuid
import validators

# Helper variables
load_dotenv()
openai.api_key = os.environ["openapi"]  # Load OpenAI API key from .env file
llm_model = "gpt-3.5-turbo"  # https://platform.openai.com/docs/guides/chat/introduction
llm_context_window = (
    4097  # https://platform.openai.com/docs/guides/chat/managing-tokens
)
embed_context_window, embed_model = (
    8191,
    "text-embedding-ada-002",
)  # https://platform.openai.com/docs/guides/embeddings/second-generation-models
nltk.download(
    "punkt"
)  # Download the nltk punkt tokenizer for splitting text into sentences
tokenizer = tiktoken.get_encoding(
    "cl100k_base"
)  # Load the cl100k_base tokenizer, which is designed to work with the ada-002 model (engine)
download_chunk_size = 128  # TODO: Find optimal chunk size for downloading files
split_chunk_tokens = 300  # TODO: Find optimal chunk size for splitting text
num_citations = 5  # TODO: Find optimal number of citations to give context to the LLM

# Streamlit settings
user_avatar_style = "fun-emoji"  # https://www.dicebear.com/styles
assistant_avatar_style = "bottts-neutral"


# Helper functions
def get_num_tokens(text):  # Count the number of tokens in a string
    return len(
        tokenizer.encode(text, disallowed_special=())
    )  # disallowed_special=() removes the special tokens


# TODO:
# Currently, any sentence that is longer than the max number of tokens will be its own chunk
# This is not ideal, since this doesn't ensure that the chunks are of a maximum size
# Find a way to split the sentence into chunks of a maximum size
def split_into_many(text):  # Split text into chunks of a maximum number of tokens
    sentences = nltk.tokenize.sent_tokenize(text)  # Split the text into sentences
    total_tokens = [
        get_num_tokens(sentence) for sentence in sentences
    ]  # Get the number of tokens for each sentence

    chunks = []
    tokens_so_far = 0
    chunk = []
    for sentence, num_tokens in zip(sentences, total_tokens):
        if not tokens_so_far:  # If this is the first sentence in the chunk
            if (
                num_tokens > split_chunk_tokens
            ):  # If the sentence is longer than the max number of tokens, add it as its own chunk
                chunk.append(sentence)
                chunks.append(" ".join(chunk))
                chunk = []
                continue  # Skip the fall-through below so the long sentence isn't also duplicated into the next chunk
        else:  # If this is not the first sentence in the chunk
            if (
                tokens_so_far + num_tokens > split_chunk_tokens
            ):  # If the sentence would make the chunk longer than the max number of tokens, add the chunk to the list of chunks
                chunks.append(" ".join(chunk))
                chunk = []
                tokens_so_far = 0

        # Otherwise, add the sentence to the chunk and add the number of tokens to the total
        chunk.append(sentence)
        tokens_so_far += num_tokens + 1

    if chunk:  # Add the trailing chunk, which also covers the case where the whole text fits in a single chunk
        chunks.append(" ".join(chunk))

    return chunks


def embed(prompt):  # Embed the prompt
    embeds = []
    if isinstance(prompt, str):
        if (
            get_num_tokens(prompt) > embed_context_window
        ):  # If token_length of prompt > context_window
            prompt = split_into_many(prompt)  # Split prompt into multiple chunks
        else:  # If token_length of prompt <= context_window
            embeds = openai.Embedding.create(input=prompt, model=embed_model)[
                "data"
            ]  # Embed prompt
    if not embeds:  # If the prompt was split into/is a set of chunks
        max_num_chunks = (
            embed_context_window // split_chunk_tokens
        )  # Number of chunks that can fit in the context window
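        # For the constants above, max_num_chunks = 8191 // 300 = 27, so e.g. a
        # 100-chunk prompt is embedded in ceil(100 / 27) = 4 batched API calls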
        for i in range(
            0, math.ceil(len(prompt) / max_num_chunks)
        ):  # For each batch of chunks
            embeds.extend(
                openai.Embedding.create(
                    input=prompt[i * max_num_chunks : (i + 1) * max_num_chunks],
                    model=embed_model,
                )["data"]
            )  # Embed the batch of chunks
    return embeds  # Return the list of embeddings


def embed_file(filename):  # Create embeddings for a file
    source_type = "file"  # To help distinguish between local/URL files and URLs
    file_source = ""  # Source of the file
    file_chunks = []  # List of file chunks (from the file)
    file_vectors = []  # List of lists of file embeddings (from each chunk)
    try:
        extracted_text = (
            textract.process(filename)
            .decode("utf-8")  # Extracted text is in bytes, convert to string
            .encode("ascii", "ignore")  # Remove non-ascii characters
            .decode()  # Convert back to string
        )
        if not extracted_text:  # If the file is empty
            raise Exception
        os.remove(
            filename
        )  # Remove the file from the server since it is no longer needed
        file_source = filename
        file_chunks = split_into_many(extracted_text)  # Split the text into chunks
        file_vectors = [x["embedding"] for x in embed(file_chunks)]  # Embed the chunks
    except Exception:  # If the file cannot be extracted, return empty values
        if os.path.exists(filename):  # If the file still exists
            os.remove(
                filename
            )  # Remove the file from the server since it is no longer needed
        source_type = ""
        file_source = ""
        file_chunks = []
        file_vectors = []
    return source_type, file_source, file_chunks, file_vectors


def embed_url(url):  # Create embeddings for a url
    source_type = "url"  # To help distinguish between local/URL files and URLs
    url_source = ""  # Source of the url
    url_chunks = []  # List of url chunks (for the url)
    url_vectors = []  # List of lists of url embeddings (for each chunk)
    filename = ""  # Filename of the url if it is a file
    try:
        if validators.url(url, public=True):  # Verify the url is valid and public
            response = requests.get(url)  # Get the url info
            header = response.headers["Content-Type"]  # Get the header of the url
            is_application = (
                header.split("/")[0] == "application"
            )  # Check if the url is a file
            if is_application:  # If url is a file, call embed_file on the file
                filetype = header.split("/")[1]  # Get the filetype
                url_parts = url.split("/")  # Get the parts of the url
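                # Illustration (hypothetical URL): "https://example.com/docs/sample.pdf"
                # served as application/pdf is saved as
                # "./https:  example.com docs sample.pdf" (the "//" leaves a double
                # space); process_source() later reverses this replacement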
                filename = str(
                    "./"
                    + " ".join(
                        url_parts[:-1] + [url_parts[-1].split(".")[0]]
                    )  # Replace / with whitespace in the filename to avoid issues with the file path, and drop the file extension since it may not match the actual filetype
                    + "."
                    + filetype
                )  # Create the filename
                with requests.get(
                    url, stream=True
                ) as stream_response:  # Download the file
                    stream_response.raise_for_status()
                    with open(filename, "wb") as file:
                        for chunk in stream_response.iter_content(
                            chunk_size=download_chunk_size
                        ):
                            file.write(chunk)
                return embed_file(filename)  # Embed the file
            else:  # If url is a webpage, use BeautifulSoup to extract the text
                soup = bs4.BeautifulSoup(
                    response.text, "html.parser"
                )  # Create a BeautifulSoup object (an explicit parser avoids a bs4 warning)
                extracted_text = (
                    soup.get_text()  # Extract the text from the webpage
                    .encode("ascii", "ignore")  # Remove non-ascii characters
                    .decode()  # Convert back to string
                )
                if not extracted_text:  # If the webpage is empty
                    raise Exception
                url_source = url
                url_chunks = split_into_many(
                    extracted_text
                )  # Split the text into chunks
                url_vectors = [
                    x["embedding"] for x in embed(url_chunks)
                ]  # Embed the chunks (all of them, matching embed_file)
        else:  # If url is not valid or public
            raise Exception
    except Exception:  # If the url cannot be extracted, return empty values
        source_type = ""
        url_source = ""
        url_chunks = []
        url_vectors = []
    return source_type, url_source, url_chunks, url_vectors


def get_most_relevant(
    prompt_embedding, sources_embeddings
):  # Get which sources/chunks are most relevant to the prompt
    sources_indices = []  # List of indices of the most relevant sources
    sources_cosine_sims = []  # List of cosine similarities of the most relevant sources
    for (
        source_embeddings
    ) in (
        sources_embeddings
    ):  # source_embeddings contains all the embeddings of each chunk in a source
        cosine_sims = np.array(
            (source_embeddings @ prompt_embedding)
            / (
                np.linalg.norm(source_embeddings, axis=1)
                * np.linalg.norm(prompt_embedding)
            )
        )  # Calculate the cosine similarity between the prompt and each chunk's vector

        # Get the indices of the most relevant chunks: https://stackoverflow.com/questions/6910641/how-do-i-get-indices-of-n-maximum-values-in-a-numpy-array
        num_chunks = min(
            num_citations, len(cosine_sims)
        )  # In case there are fewer chunks than num_citations
        indices = np.argpartition(cosine_sims, -num_chunks)[
            -num_chunks:
        ]  # Get the indices of the most relevant chunks
        indices = indices[np.argsort(cosine_sims[indices])]  # Sort the indices
        cosine_sims = cosine_sims[
            indices
        ]  # Get the cosine similarities of the most relevant chunks
        sources_indices.append(indices)  # Add the indices to sources_indices
        sources_cosine_sims.append(
            cosine_sims
        )  # Add the cosine similarities to sources_cosine_sims

    # Use sources_indices and sources_cosine_sims to get the most relevant sources/chunks
    indexes = []
    max_cosine_sims = []
    for source_idx in range(len(sources_indices)):  # For each source
        for chunk_idx in range(len(sources_indices[source_idx])):  # For each chunk
            sources_chunk_idx = sources_indices[source_idx][
                chunk_idx
            ]  # Get the index of the chunk
            similarity = sources_cosine_sims[source_idx][
                chunk_idx
            ]  # Get the cosine similarity of the chunk
            if len(max_cosine_sims) < num_citations:  # If max_values is not full
                indexes.append(
                    [source_idx, sources_chunk_idx]
                )  # Add the source/chunk index pair to indexes
                max_cosine_sims.append(
                    similarity
                )  # Add the cosine similarity to max_values
            elif len(max_cosine_sims) == num_citations and similarity > min(
                max_cosine_sims
            ):  # If max_values is full and the current cosine similarity is greater than the minimum cosine similarity in max_values
                indexes.append(
                    [source_idx, sources_chunk_idx]
                )  # Add the source/chunk index pair to indexes
                max_cosine_sims.append(
                    similarity
                )  # Add the cosine similarity to max_values
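                # max_cosine_sims temporarily holds num_citations + 1 entries here;
                # evicting the minimum below restores the top-num_citations invariant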
                min_idx = max_cosine_sims.index(
                    min(max_cosine_sims)
                )  # Get the index of the minimum cosine similarity in max_values
                indexes.pop(
                    min_idx
                )  # Remove the source/chunk index pair at the minimum cosine similarity index in indexes
                max_cosine_sims.pop(
                    min_idx
                )  # Remove the minimum cosine similarity in max_values
            else:  # If max_values is full and the current cosine similarity is less than the minimum cosine similarity in max_values
                pass
    return indexes


def process_source(
    source, source_type
):  # Process the source name to be used in a message, since URL files are processed differently
    return (
        source if source_type == "file" else source.replace(" ", "/")
    )  # In case this is a URL, reverse what was done in embed_url


# TODO: Find a better way to create/store messages instead of rebuilding them every time a new question is asked
def ask():  # Ask a question
    messages = [
        {
            "role": "system",
            "content": str(
                "You are a helpful chatbot that answers questions a user may have about a topic. "
                + "Sometimes, the user may give you external data which you can use as needed. "
                + "They will give it to you in the following way:\n"
                + "Source 1: the source's name\n"
                + "Text 1: the relevant text from the source\n"
                + "Source 2: the source's name\n"
                + "Text 2: the relevant text from the source\n"
                + "...\n"
                + "You can use this data to answer the user's questions or to ask the user questions. "
                + "Take note that if you plan to reference a source, ALWAYS do so using the source's name.\n"
            ),
        },
        {"role": "user", "content": st.session_state["questions"][0]},
    ]  # Add the system's introduction message and the user's first question to messages
    show_message(
        st.session_state["questions"][0],
        is_user=True,
        key=str(uuid.uuid4()),
        avatar_style=user_avatar_style,
    )  # Display the user's first question

    if (
        len(st.session_state["questions"]) > 1 and st.session_state["answers"]
    ):  # If this is not the first question
        for interaction, message in enumerate(
            [
                message
                for pair in zip(
                    st.session_state["answers"], st.session_state["questions"][1:]
                )
                for message in pair
            ]  # Get the messages from the previous conversation in the order [answer, question, answer, question, ...]: https://stackoverflow.com/questions/7946798/interleave-multiple-lists-of-the-same-length-in-python
        ):
            if interaction % 2 == 0:  # If the message is an answer
                messages.append(
                    {"role": "assistant", "content": message}
                )  # Add the answer to messages
                show_message(
                    message,
                    key=str(uuid.uuid4()),
                    avatar_style=assistant_avatar_style,
                )  # Display the answer
            else:  # If the message is a question
                messages.append(
                    {"role": "user", "content": message}
                )  # Add the question to messages
                show_message(
                    message,
                    is_user=True,
                    key=str(uuid.uuid4()),
                    avatar_style=user_avatar_style,
                )  # Display the question

    if (
        st.session_state["sources_types"]
        and st.session_state["sources"]
        and st.session_state["chunks"]
        and st.session_state["vectors"]
    ):  # If there are sources that were uploaded
        prompt_embedding = np.array(
            embed(st.session_state["questions"][-1])[0]["embedding"]
        )  # Embed the last question
        indexes = get_most_relevant(
            prompt_embedding, st.session_state["vectors"]
        )  # Get the most relevant chunks
        if indexes:  # If there are relevant chunks
            messages[-1]["content"] += str(
                "Here are some sources that may be helpful:\n"
            )  # Add the sources to the last message
            for idx, ind in enumerate(indexes):
                source_idx, chunk_idx = ind[0], ind[1]  # Get the source and chunk index
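                # Each citation is appended in the exact format the system message
                # describes, e.g. "Source 1: <name>\nText 1: <chunk text>\n"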
) + "\n" + "Text " + str(idx + 1) + ": " + st.session_state["chunks"][source_idx][chunk_idx] # Get the chunk + "\n" ) while ( get_num_tokens("\n".join([message["content"] for message in messages])) > llm_context_window ): # If the context window is too large if ( len(messages) == 2 ): # If there is only the introduction message and the user's most recent question max_tokens_left = llm_context_window - get_num_tokens( messages[0]["content"] ) # Get the maximum number of tokens that can be present in the question messages[1]["content"] = messages[1]["content"][ :max_tokens_left ] # Truncate the question, from https://help.openai.com/en/articles/4936856-what-are-tokens-and-how-to-count-them 4 chars ~= 1 token, but it isn't certain that this is the case, so we will just truncate the question to max_tokens_left characters to be safe else: # If there are more than 2 messages messages.pop(1) # Remove the oldest question messages.pop(2) # Remove the oldest answer answer = openai.ChatCompletion.create(model=llm_model, messages=messages)[ "choices" ][0]["message"][ "content" ] # Get the answer from the chatbot st.session_state["answers"].append(answer) # Add the answer to answers show_message( st.session_state["answers"][-1], key=str(uuid.uuid4()), avatar_style=assistant_avatar_style, ) # Display the answer # Main function, defines layout of the app def main(): # Initialize session state variables if "questions" not in st.session_state: st.session_state["questions"] = [] if "answers" not in st.session_state: st.session_state["answers"] = [] if "sources_types" not in st.session_state: st.session_state["sources_types"] = [] if "sources" not in st.session_state: st.session_state["sources"] = [] if "chunks" not in st.session_state: st.session_state["chunks"] = [] if "vectors" not in st.session_state: st.session_state["vectors"] = [] st.title("CacheChat :money_with_wings:") # Title st.markdown( "Check out the repo [here](https://github.com/andrewhinh/CacheChat) and notes on using the app [here](https://github.com/andrewhinh/CacheChat#notes)." ) # Link to repo uploaded_files = st.file_uploader( "Choose file(s):", accept_multiple_files=True, key="files" ) # File upload section if uploaded_files: # If (a) file(s) is/are uploaded, create embeddings with st.spinner("Processing..."): # Show loading spinner for uploaded_file in uploaded_files: if not ( uploaded_file.name in st.session_state["sources"] ): # If the file has not been uploaded, process it with open(uploaded_file.name, "wb") as file: # Save file to disk file.write(uploaded_file.getbuffer()) source_type, file_source, file_chunks, file_vectors = embed_file( uploaded_file.name ) # Embed file if ( not source_type and not file_source and not file_chunks and not file_vectors ): # If the file is invalid st.error("Invalid file(s). 
Please try again.") else: # If the file is valid st.session_state["sources_types"].append(source_type) st.session_state["sources"].append(file_source) st.session_state["chunks"].append(file_chunks) st.session_state["vectors"].append(file_vectors) with st.form(key="url", clear_on_submit=True): # form for question input uploaded_url = st.text_input( "Enter a URL:", placeholder="https://www.africau.edu/images/default/sample.pdf", ) # URL input text box upload_url_button = st.form_submit_button(label="Add URL") # Add URL button if upload_url_button and uploaded_url: # If a URL is entered, create embeddings with st.spinner("Processing..."): # Show loading spinner if not ( uploaded_url in st.session_state["sources"] # Non-file URL in sources or "./" + uploaded_url.replace("/", " ") # File URL in sources in st.session_state["sources"] ): # If the URL has not been uploaded, process it source_type, url_source, url_chunks, url_vectors = embed_url( uploaded_url ) # Embed URL if ( not source_type and not url_source and not url_chunks and not url_vectors ): # If the URL is invalid st.error("Invalid URL. Please try again.") else: # If the URL is valid st.session_state["sources_types"].append(source_type) st.session_state["sources"].append(url_source) st.session_state["chunks"].append(url_chunks) st.session_state["vectors"].append(url_vectors) st.divider() # Create a divider between the uploads and the chat input_container = ( st.container() ) # container for inputs/uploads, https://docs.streamlit.io/library/api-reference/layout/st.container response_container = ( st.container() ) # container for chat history, https://docs.streamlit.io/library/api-reference/layout/st.container with input_container: with st.form(key="question", clear_on_submit=True): # form for question input uploaded_question = st.text_input( "Enter your input:", placeholder="e.g: Summarize the research paper in 3 sentences.", key="input", ) # question text box uploaded_question_button = st.form_submit_button( label="Send" ) # send button with response_container: if ( uploaded_question_button and uploaded_question ): # if send button is pressed and text box is not empty with st.spinner("Thinking..."): # show loading spinner st.session_state["questions"].append( uploaded_question ) # add question to questions ask() # ask question to chatbot if __name__ == "__main__": main()