import streamlit as st
# from htmlTemplates import css, bot_template, user_template
from dotenv import load_dotenv
# from PyPDF2 import PdfReader
import os
import mysql.connector
from langchain.text_splitter import CharacterTextSplitter
from langchain_community.embeddings import HuggingFaceInstructEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_community.llms import HuggingFaceHub
from langchain_openai import ChatOpenAI
from langchain_openai import OpenAIEmbeddings
from langchain.memory import ConversationBufferMemory
from langchain.chains import ConversationalRetrievalChain


def get_pdf_text(slug):
    """
    Fetches the OCR text for the given document slug from the MySQL database.

    Parameters:
    - slug (str): The document slug to look up.

    Returns:
    - str: The concatenated OCR text for the document (empty if none was found).
    """
    load_dotenv()
    text = ""
    conn = None
    try:
        conn = mysql.connector.connect(
            user=os.getenv("SQL_USER"),
            password=os.getenv("SQL_PWD"),
            host=os.getenv("SQL_HOST"),
            database="Birdseye_DB",
        )
        cursor = conn.cursor()
        # Execute a parameterized query for the document's OCR text
        cursor.execute("SELECT ocr_text FROM birdseye_temp WHERE slug = %s", (slug,))
        # Fetch the results
        rows = cursor.fetchall()
        for row in rows:
            if row[0]:
                text += row[0]
    except mysql.connector.Error as err:
        st.error(f"Error: {err}")
    finally:
        # Guard against `conn` never being assigned if the connection attempt failed
        if conn is not None and conn.is_connected():
            cursor.close()
            conn.close()
    return text
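
# The database helpers in this file and the OpenAI-backed components below read
# their credentials from the environment. A minimal .env sketch (all values are
# placeholders, not real credentials):
#
#   SQL_USER=...
#   SQL_PWD=...
#   SQL_HOST=...
#   OPENAI_API_KEY=...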


def get_text_chunks(text):
    """
    Splits the given text into chunks based on specified character settings.

    Parameters:
    - text (str): The text to be split into chunks.

    Returns:
    - list: A list of text chunks.
    """
    text_splitter = CharacterTextSplitter(
        separator="\n", chunk_size=1000, chunk_overlap=200, length_function=len
    )
    chunks = text_splitter.split_text(text)
    return chunks
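
# Illustrative only: with the settings above, the OCR text is split on newlines
# into chunks of up to ~1000 characters with ~200 characters of overlap between
# consecutive chunks ("example-slug" is a placeholder):
#
#   chunks = get_text_chunks(get_pdf_text("example-slug"))
#   len(chunks[0])  # <= 1000, unless a single line exceeds the chunk size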


def get_vectorstore(text_chunks):
    """
    Generates a vector store from a list of text chunks using specified embeddings.

    Parameters:
    - text_chunks (list of str): Text segments to convert into vector embeddings.

    Returns:
    - FAISS: A FAISS vector store containing the embeddings of the text chunks.
    """
    embeddings = OpenAIEmbeddings()
    vectorstore = FAISS.from_texts(texts=text_chunks, embedding=embeddings)
    return vectorstore
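
# Illustrative usage (assumes OPENAI_API_KEY is set so OpenAIEmbeddings can
# authenticate; the query string is a placeholder):
#
#   store = get_vectorstore(chunks)
#   top_docs = store.similarity_search("What is this document about?", k=4)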


def get_conversation_chain(vectorstore):
    """
    Initializes a conversational retrieval chain that uses a large language model
    for generating responses based on the provided vector store.

    Parameters:
    - vectorstore (FAISS): A vector store to be used for retrieving relevant content.

    Returns:
    - ConversationalRetrievalChain: An initialized conversational chain object.
    """
    try:
        llm = ChatOpenAI(model_name="gpt-4o", temperature=0.5, top_p=0.5)
        memory = ConversationBufferMemory(
            memory_key="chat_history", return_messages=True
        )
        conversation_chain = ConversationalRetrievalChain.from_llm(
            llm=llm, retriever=vectorstore.as_retriever(), memory=memory
        )
        return conversation_chain
    except Exception as e:
        raise  # Re-raise exception to handle it or log it properly elsewhere
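
# Illustrative only: the chain built above expects a "question" key and, because
# of the ConversationBufferMemory, also exposes the running "chat_history":
#
#   chain = get_conversation_chain(get_vectorstore(chunks))
#   result = chain.invoke({"question": "Summarise the document."})
#   result["answer"]                    # the model's reply
#   result["chat_history"][-1].content  # how this app reads the latest reply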


def handle_userinput(user_question):
    """
    Sends the user's question through the conversation chain and renders the
    updated chat history with the HTML templates.

    Note: relies on `user_template` and `bot_template` from htmlTemplates
    (import currently commented out at the top of this file).
    """
    response = st.session_state.conversation(
        {
            "question": f"Based on the memory and the provided document, answer the following user question: {user_question}. If the question is unrelated to memory or the document, just mention that you cannot provide an answer."
        }
    )
    st.session_state.chat_history = response["chat_history"]
    for i, message in reversed(list(enumerate(st.session_state.chat_history))):
        if i % 2 == 0:
            st.write(
                user_template.replace("{{MSG}}", message.content),
                unsafe_allow_html=True,
            )
        else:
            st.write(
                bot_template.replace("{{MSG}}", message.content),
                unsafe_allow_html=True,
            )


def get_user_chat_count(user_id):
    """
    Retrieves the chat count for the user from the MySQL database.
    """
    conn = None
    try:
        conn = mysql.connector.connect(
            user=os.getenv("SQL_USER"),
            password=os.getenv("SQL_PWD"),
            host=os.getenv("SQL_HOST"),
            database="Birdseye_DB",
        )
        cursor = conn.cursor()
        cursor.execute("SELECT count FROM birdseye_chat WHERE user_id = %s", (user_id,))
        result = cursor.fetchone()
        if result:
            return result[0]
        else:
            # Insert a new row for the user if not found
            cursor.execute(
                "INSERT INTO birdseye_chat (user_id, count) VALUES (%s, %s)",
                (user_id, 0),
            )
            conn.commit()
            return 0
    except mysql.connector.Error as err:
        st.error(f"Error: {err}")
        return None
    finally:
        # Guard against `conn` never being assigned if the connection attempt failed
        if conn is not None and conn.is_connected():
            cursor.close()
            conn.close()
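
# Assumed shape of the chat-count table used above, inferred from the queries
# in this file; the authoritative schema is defined elsewhere:
#
#   birdseye_chat(user_id, count)  -- one row per user, count starts at 0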


def increment_user_chat_count(user_id):
    """
    Increments the chat count for the user in the MySQL database.
    """
    conn = None
    try:
        conn = mysql.connector.connect(
            user=os.getenv("SQL_USER"),
            password=os.getenv("SQL_PWD"),
            host=os.getenv("SQL_HOST"),
            database="Birdseye_DB",
        )
        cursor = conn.cursor()
        cursor.execute(
            "UPDATE birdseye_chat SET count = count + 1 WHERE user_id = %s", (user_id,)
        )
        conn.commit()
    except mysql.connector.Error as err:
        st.error(f"Error: {err}")
    finally:
        # Guard against `conn` never being assigned if the connection attempt failed
        if conn is not None and conn.is_connected():
            cursor.close()
            conn.close()


def is_user_in_unlimited_chat_group(user_id):
    """
    Checks if the user belongs to the 'Unlimited Chat' group.
    """
    conn = None
    try:
        conn = mysql.connector.connect(
            user=os.getenv("SQL_USER"),
            password=os.getenv("SQL_PWD"),
            host=os.getenv("SQL_HOST"),
            database="Birdseye_DB",
        )
        cursor = conn.cursor()
        cursor.execute(
            """
            SELECT 1
            FROM auth_user_groups
            JOIN auth_group ON auth_user_groups.group_id = auth_group.id
            WHERE auth_user_groups.user_id = %s AND auth_group.name = 'Unlimited Chat'
            """,
            (user_id,),
        )
        return cursor.fetchone() is not None
    except mysql.connector.Error as err:
        st.error(f"Error: {err}")
        return False
    finally:
        # Guard against `conn` never being assigned if the connection attempt failed
        if conn is not None and conn.is_connected():
            cursor.close()
            conn.close()


def chat(slug, user_id):
    """
    Manages the chat interface in the Streamlit application, handling the conversation
    flow and displaying the chat history.

    Restricts chat based on user group and chat count.
    """
    st.write(
        "**Please note:** Due to processing limitations, the chat may not fully comprehend the whole document."
    )
    text_chunks = get_text_chunks(get_pdf_text(slug))
    vectorstore = get_vectorstore(text_chunks)
    st.session_state.conversation = get_conversation_chain(vectorstore)

    # Check if the user can chat
    if not is_user_in_unlimited_chat_group(user_id):
        user_chat_count = get_user_chat_count(user_id)
        if user_chat_count is None or user_chat_count >= 20:
            st.write("You have reached your chat limit.")
            return

    # Display the chat history so far
    if len(st.session_state.messages) == 1:
        message = st.session_state.messages[0]
        with st.chat_message(message["role"]):
            st.write(message["content"])
    else:
        for message in st.session_state.messages:
            with st.chat_message(message["role"]):
                st.write(message["content"])

    # User-provided prompt
    if prompt := st.chat_input():
        # increment_user_chat_count(user_id)
        st.session_state.messages.append({"role": "user", "content": prompt})
        st.session_state.prompts = prompt
        with st.chat_message("user"):
            st.write(prompt)

    # Generate a reply whenever the latest message is from the user
    if st.session_state.messages[-1]["role"] != "ai":
        with st.spinner("Generating response..."):
            response = st.session_state.conversation.invoke(
                {"question": st.session_state.prompts}
            )
        with st.chat_message("ai"):
            message_content = response["chat_history"][-1].content
            st.session_state.messages.append({"role": "ai", "content": message_content})
            st.write(message_content)
        if not is_user_in_unlimited_chat_group(user_id):
            increment_user_chat_count(user_id)  # Increment count after response


def init():
    """
    Initializes the session state variables used in the Streamlit application.
    """
    if "pdf" not in st.session_state:
        st.session_state["pdf"] = False
    if "conversation" not in st.session_state:
        st.session_state.conversation = None
    if "chat_history" not in st.session_state:
        st.session_state.chat_history = None
    if "messages" not in st.session_state:
        st.session_state.messages = [
            {
                "role": "ai",
                "content": "What do you want to learn about the document? Ask me a question!",
            }
        ]


def main():
    """
    Entry point: reads the document slug and user id from the URL query
    parameters and starts the chat interface.
    """
    init()
    query_params = st.query_params
    slug = query_params.get("slug")
    user_id = query_params.get("user_id")
    load_dotenv()
    st.title("Chat with GPT :books:")
    if slug and user_id:
        chat(slug, user_id)
    else:
        st.error("Please return to Birdseye and select a document.")


if __name__ == "__main__":
    main()
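
# Illustrative only: to run this app locally (assuming the environment values
# sketched near the top of this file are set), launch Streamlit and pass the
# document slug and user id as URL query parameters; the values below are
# placeholders:
#
#   streamlit run app.py
#   # then open http://localhost:8501/?slug=example-slug&user_id=1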