"""Streamlit chat application for Birdseye documents.

Loads a document's OCR text from MySQL, builds a FAISS vector store over it,
and lets the user chat about the document through a LangChain
ConversationalRetrievalChain backed by an OpenAI chat model. Chat volume is
rate-limited per user unless the user belongs to the 'Unlimited Chat' group.
"""

import os

import mysql.connector
import streamlit as st
from dotenv import load_dotenv
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain.text_splitter import CharacterTextSplitter
from langchain_community.embeddings import HuggingFaceInstructEmbeddings
from langchain_community.llms import HuggingFaceHub
from langchain_community.vectorstores import FAISS
from langchain_openai import ChatOpenAI, OpenAIEmbeddings

# Maximum number of chats for users outside the 'Unlimited Chat' group.
CHAT_LIMIT = 20


def _connect():
    """Open a connection to the Birdseye_DB MySQL database.

    Credentials come from the SQL_USER / SQL_PWD / SQL_HOST environment
    variables (populated by load_dotenv()).

    Returns:
        mysql.connector.connection.MySQLConnection: An open connection.

    Raises:
        mysql.connector.Error: If the connection cannot be established.
    """
    return mysql.connector.connect(
        user=os.getenv("SQL_USER"),
        password=os.getenv("SQL_PWD"),
        host=os.getenv("SQL_HOST"),
        database="Birdseye_DB",
    )


def get_pdf_text(slug):
    """Fetch and concatenate the OCR text rows for the document `slug`.

    Parameters:
    - slug (str): Document identifier used in the WHERE clause.

    Returns:
    - str: Concatenated OCR text; empty string if no rows match or a
      database error occurs (the error is displayed in the Streamlit UI).
    """
    load_dotenv()
    text = ""
    conn = None
    cursor = None
    try:
        conn = _connect()
        cursor = conn.cursor()
        # Parameterized query — `slug` arrives from the URL and is untrusted.
        cursor.execute("SELECT ocr_text FROM birdseye_temp WHERE slug = %s", (slug,))
        for row in cursor.fetchall():
            if row[0]:
                text += row[0]
    except mysql.connector.Error as err:
        st.error(f"Error: {err}")
    finally:
        # `conn` stays None if connect() itself raised — guard to avoid
        # an UnboundLocalError that would mask the original error.
        if conn is not None and conn.is_connected():
            if cursor is not None:
                cursor.close()
            conn.close()
    return text


def get_text_chunks(text):
    """
    Splits the given text into chunks based on specified character settings.

    Parameters:
    - text (str): The text to be split into chunks.

    Returns:
    - list: A list of text chunks.
    """
    text_splitter = CharacterTextSplitter(
        separator="\n", chunk_size=1000, chunk_overlap=200, length_function=len
    )
    return text_splitter.split_text(text)


def get_vectorstore(text_chunks):
    """
    Generates a vector store from a list of text chunks using specified embeddings.

    Parameters:
    - text_chunks (list of str): Text segments to convert into vector embeddings.

    Returns:
    - FAISS: A FAISS vector store containing the embeddings of the text chunks.
    """
    embeddings = OpenAIEmbeddings()
    return FAISS.from_texts(texts=text_chunks, embedding=embeddings)


def get_conversation_chain(vectorstore):
    """
    Initializes a conversational retrieval chain that uses a large language
    model for generating responses based on the provided vector store.

    Parameters:
    - vectorstore (FAISS): A vector store to be used for retrieving relevant content.

    Returns:
    - ConversationalRetrievalChain: An initialized conversational chain object.
    """
    llm = ChatOpenAI(model_name="gpt-4o", temperature=0.5, top_p=0.5)
    memory = ConversationBufferMemory(
        memory_key="chat_history", return_messages=True
    )
    return ConversationalRetrievalChain.from_llm(
        llm=llm, retriever=vectorstore.as_retriever(), memory=memory
    )


def handle_userinput(user_question):
    """Send a user question through the conversation chain and render history.

    NOTE(review): `user_template` and `bot_template` are not defined in this
    module — their import (`from htmlTemplates import ...`) is commented out
    at the top of the file, so calling this function raises NameError.
    It appears to be superseded by the st.chat_message flow in `chat()`;
    confirm before removing.

    Parameters:
    - user_question (str): The question typed by the user.
    """
    response = st.session_state.conversation(
        {
            "question": f"Based on the memory and the provided document, answer the following user question: {user_question}. If the question is unrelated to memory or the document, just mention that you cannot provide an answer."
        }
    )
    st.session_state.chat_history = response["chat_history"]
    # Even indices are user turns, odd indices are bot turns; show newest first.
    for i, message in reversed(list(enumerate(st.session_state.chat_history))):
        if i % 2 == 0:
            st.write(
                user_template.replace("{{MSG}}", message.content),
                unsafe_allow_html=True,
            )
        else:
            st.write(
                bot_template.replace("{{MSG}}", message.content), unsafe_allow_html=True
            )


def get_user_chat_count(user_id):
    """
    Retrieves the chat count for the user from the MySQL database.

    A missing row is treated as a new user: a row with count 0 is inserted
    and 0 is returned.

    Parameters:
    - user_id: Identifier of the user in birdseye_chat.

    Returns:
    - int | None: The user's chat count, or None on a database error.
    """
    conn = None
    cursor = None
    try:
        conn = _connect()
        cursor = conn.cursor()
        cursor.execute("SELECT count FROM birdseye_chat WHERE user_id = %s", (user_id,))
        result = cursor.fetchone()
        if result:
            return result[0]
        else:
            # First chat for this user — create their counter row.
            cursor.execute(
                "INSERT INTO birdseye_chat (user_id, count) VALUES (%s, %s)",
                (user_id, 0),
            )
            conn.commit()
            return 0
    except mysql.connector.Error as err:
        st.error(f"Error: {err}")
        return None
    finally:
        # Guard: `conn` is None if connect() raised before assignment.
        if conn is not None and conn.is_connected():
            if cursor is not None:
                cursor.close()
            conn.close()


def increment_user_chat_count(user_id):
    """
    Increments the chat count for the user in the MySQL database.

    Parameters:
    - user_id: Identifier of the user in birdseye_chat.
    """
    conn = None
    cursor = None
    try:
        conn = _connect()
        cursor = conn.cursor()
        cursor.execute(
            "UPDATE birdseye_chat SET count = count + 1 WHERE user_id = %s ", (user_id,)
        )
        conn.commit()
    except mysql.connector.Error as err:
        st.error(f"Error: {err}")
    finally:
        # Guard: `conn` is None if connect() raised before assignment.
        if conn is not None and conn.is_connected():
            if cursor is not None:
                cursor.close()
            conn.close()


def is_user_in_unlimited_chat_group(user_id):
    """
    Checks if the user belongs to the 'Unlimited Chat' group.

    Parameters:
    - user_id: Identifier of the user in auth_user_groups.

    Returns:
    - bool: True if the user is in the group; False otherwise or on error.
    """
    conn = None
    cursor = None
    try:
        conn = _connect()
        cursor = conn.cursor()
        cursor.execute(
            """
            SELECT 1 FROM auth_user_groups
            JOIN auth_group ON auth_user_groups.group_id = auth_group.id
            WHERE auth_user_groups.user_id = %s AND auth_group.name = 'Unlimited Chat'
            """,
            (user_id,),
        )
        return cursor.fetchone() is not None
    except mysql.connector.Error as err:
        st.error(f"Error: {err}")
        return False
    finally:
        # Guard: `conn` is None if connect() raised before assignment.
        if conn is not None and conn.is_connected():
            if cursor is not None:
                cursor.close()
            conn.close()


def chat(slug, user_id):
    """
    Manages the chat interface in the Streamlit application, handling the
    conversation flow and displaying the chat history. Restricts chat based
    on user group and chat count.

    Parameters:
    - slug (str): Document identifier used to load the OCR text.
    - user_id: Identifier of the current user.
    """
    st.write(
        "**Please note:** Due to processing limitations, the chat may not fully comprehend the whole document."
    )
    text_chunks = get_text_chunks(get_pdf_text(slug))
    vectorstore = get_vectorstore(text_chunks)
    st.session_state.conversation = get_conversation_chain(vectorstore)

    # Enforce the chat quota for users outside the unlimited group.
    if not is_user_in_unlimited_chat_group(user_id):
        user_chat_count = get_user_chat_count(user_id)
        if user_chat_count is None or user_chat_count >= CHAT_LIMIT:
            st.write("You have reached your chat limit.")
            return

    # Replay the stored conversation (covers the single-greeting case too).
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.write(message["content"])

    # User-provided prompt
    if prompt := st.chat_input():
        st.session_state.messages.append({"role": "user", "content": prompt})
        st.session_state.prompts = prompt
        with st.chat_message("user"):
            st.write(prompt)

    # Only generate when the last turn is the user's (i.e. a prompt was just sent).
    if st.session_state.messages[-1]["role"] != "ai":
        with st.spinner("Generating response..."):
            response = st.session_state.conversation.invoke(
                {"question": st.session_state.prompts}
            )
        with st.chat_message("ai"):
            message_content = response["chat_history"][-1].content
            st.session_state.messages.append({"role": "ai", "content": message_content})
            st.write(message_content)
        if not is_user_in_unlimited_chat_group(user_id):
            increment_user_chat_count(user_id)  # Increment count after response


def init():
    """
    Initializes the session state variables used in the Streamlit application
    and loads environment variables.
    """
    if "pdf" not in st.session_state:
        st.session_state["pdf"] = False
    if "conversation" not in st.session_state:
        st.session_state.conversation = None
    if "chat_history" not in st.session_state:
        st.session_state.chat_history = None
    if "messages" not in st.session_state:
        st.session_state.messages = [
            {
                "role": "ai",
                "content": "What do you want to learn about the document? Ask me a question!",
            }
        ]


def main():
    """Entry point: read slug/user_id from the URL and start the chat UI."""
    init()
    query_params = st.query_params
    slug = query_params.get("slug")
    user_id = query_params.get("user_id")
    load_dotenv()
    st.title("Chat with GPT :books:")
    if slug and user_id:
        chat(slug, user_id)
    else:
        st.error("Please return to Birdseye and select a document.")


if __name__ == "__main__":
    main()