Spaces:
Sleeping
Sleeping
import streamlit as st | |
# from htmlTemplates import css, bot_template, user_template | |
from dotenv import load_dotenv | |
# from PyPDF2 import PdfReader | |
import os | |
import mysql.connector | |
from langchain.text_splitter import CharacterTextSplitter | |
from langchain_community.embeddings import HuggingFaceInstructEmbeddings | |
from langchain_community.vectorstores import FAISS | |
from langchain_community.llms import HuggingFaceHub | |
from langchain_openai import ChatOpenAI | |
from langchain_openai import OpenAIEmbeddings | |
from langchain.memory import ConversationBufferMemory | |
from langchain.chains import ConversationalRetrievalChain | |
def get_pdf_text(slug):
    """
    Loads the OCR text stored for a document from the MySQL database.

    Connection credentials are read from the SQL_USER, SQL_PWD and SQL_HOST
    environment variables (after loading a .env file). All non-empty
    ``ocr_text`` values of the ``birdseye_temp`` rows matching the slug are
    concatenated in the order the database returns them.

    Parameters:
    - slug (str): Value matched against the ``slug`` column.

    Returns:
    - str: The concatenated text, or an empty string when no row matches or a
      database error occurs (the error is shown with ``st.error``).
    """
    load_dotenv()
    # Pre-bind both handles so the cleanup below is safe even when connect()
    # itself fails; otherwise `conn` would be unbound in the finally clause.
    conn = None
    cursor = None
    parts = []
    try:
        conn = mysql.connector.connect(
            user=os.getenv("SQL_USER"),
            password=os.getenv("SQL_PWD"),
            host=os.getenv("SQL_HOST"),
            database="Birdseye_DB",
        )
        cursor = conn.cursor()
        cursor.execute("SELECT ocr_text FROM birdseye_temp WHERE slug = %s", (slug,))
        # Skip NULL / empty ocr_text values.
        parts = [row[0] for row in cursor.fetchall() if row[0]]
    except mysql.connector.Error as err:
        st.error(f"Error: {err}")
    finally:
        if conn is not None and conn.is_connected():
            if cursor is not None:
                cursor.close()
            conn.close()
    return "".join(parts)
def get_text_chunks(text):
    """
    Breaks a text into overlapping chunks with a newline-separated
    CharacterTextSplitter.

    Parameters:
    - text (str): The text to be split.

    Returns:
    - list: The chunks produced by the splitter (chunk_size=1000,
      chunk_overlap=200, length measured with ``len``).
    """
    splitter = CharacterTextSplitter(
        separator="\n",
        chunk_size=1000,
        chunk_overlap=200,
        length_function=len,
    )
    return splitter.split_text(text)
def get_vectorstore(text_chunks):
    """
    Builds a FAISS vector store from text chunks using OpenAI embeddings.

    Parameters:
    - text_chunks (list of str): Text segments to embed and index.

    Returns:
    - FAISS: The vector store created by ``FAISS.from_texts``.
    """
    return FAISS.from_texts(texts=text_chunks, embedding=OpenAIEmbeddings())
def get_conversation_chain(vectorstore):
    """
    Creates a ConversationalRetrievalChain backed by a gpt-4o chat model, a
    retriever over the given vector store and a buffer memory stored under
    the "chat_history" key.

    Parameters:
    - vectorstore (FAISS): Vector store whose retriever supplies context.

    Returns:
    - ConversationalRetrievalChain: The initialized chain.

    Any exception raised while building the chain propagates to the caller.
    """
    chat_model = ChatOpenAI(model_name="gpt-4o", temperature=0.5, top_p=0.5)
    history = ConversationBufferMemory(
        memory_key="chat_history",
        return_messages=True,
    )
    return ConversationalRetrievalChain.from_llm(
        llm=chat_model,
        retriever=vectorstore.as_retriever(),
        memory=history,
    )
def handle_userinput(user_question):
    """
    Sends a user question to the conversation chain kept in the session state
    and renders the resulting chat history, newest message first.

    The question is wrapped in an instruction asking the model to answer only
    from the memory and the provided document.

    Parameters:
    - user_question (str): The question typed by the user.

    Side effects:
    - Stores the returned history in ``st.session_state.chat_history``.
    - Writes one chat bubble per history message to the page.
    """
    response = st.session_state.conversation.invoke(
        {
            "question": f"Based on the memory and the provided document, answer the following user question: {user_question}. If the question is unrelated to memory or the document, just mention that you cannot provide an answer."
        }
    )
    st.session_state.chat_history = response["chat_history"]
    for i, message in reversed(list(enumerate(st.session_state.chat_history))):
        # Assumes the history alternates user / AI messages starting with the
        # user, so even indices are user turns.
        role = "user" if i % 2 == 0 else "ai"
        # Render with Streamlit's chat elements: the HTML templates this
        # function used to rely on are not imported in this file, and writing
        # message text as raw HTML would expose the page to markup injection.
        with st.chat_message(role):
            st.write(message.content)
def get_user_chat_count(user_id):
    """
    Retrieves the chat count for the user from the MySQL database.

    If the user has no row in ``birdseye_chat`` yet, one is inserted with a
    count of 0.

    Parameters:
    - user_id: Value matched against the ``user_id`` column.

    Returns:
    - The stored count, 0 for a newly inserted user, or None when a database
      error occurs (the error is shown with ``st.error``).
    """
    # Pre-bind both handles so the cleanup below is safe even when connect()
    # itself fails; otherwise `conn` would be unbound in the finally clause.
    conn = None
    cursor = None
    try:
        conn = mysql.connector.connect(
            user=os.getenv("SQL_USER"),
            password=os.getenv("SQL_PWD"),
            host=os.getenv("SQL_HOST"),
            database="Birdseye_DB",
        )
        cursor = conn.cursor()
        cursor.execute("SELECT count FROM birdseye_chat WHERE user_id = %s", (user_id,))
        result = cursor.fetchone()
        if result:
            return result[0]
        # Insert a new row for the user if not found
        cursor.execute(
            "INSERT INTO birdseye_chat (user_id, count) VALUES (%s, %s)",
            (user_id, 0),
        )
        conn.commit()
        return 0
    except mysql.connector.Error as err:
        st.error(f"Error: {err}")
        return None
    finally:
        if conn is not None and conn.is_connected():
            if cursor is not None:
                cursor.close()
            conn.close()
def increment_user_chat_count(user_id):
    """
    Increments the chat count for the user in the MySQL database.

    Parameters:
    - user_id: Value matched against the ``user_id`` column of
      ``birdseye_chat``.

    Returns:
    - None. Database errors are shown with ``st.error`` and not re-raised.
    """
    # Pre-bind both handles so the cleanup below is safe even when connect()
    # itself fails; otherwise `conn` would be unbound in the finally clause.
    conn = None
    cursor = None
    try:
        conn = mysql.connector.connect(
            user=os.getenv("SQL_USER"),
            password=os.getenv("SQL_PWD"),
            host=os.getenv("SQL_HOST"),
            database="Birdseye_DB",
        )
        cursor = conn.cursor()
        cursor.execute(
            "UPDATE birdseye_chat SET count = count + 1 WHERE user_id = %s ", (user_id,)
        )
        conn.commit()
    except mysql.connector.Error as err:
        st.error(f"Error: {err}")
    finally:
        if conn is not None and conn.is_connected():
            if cursor is not None:
                cursor.close()
            conn.close()
def is_user_in_unlimited_chat_group(user_id):
    """
    Checks if the user belongs to the 'Unlimited Chat' group.

    Membership is looked up by joining ``auth_user_groups`` with
    ``auth_group`` on the group id.

    Parameters:
    - user_id: Value matched against ``auth_user_groups.user_id``.

    Returns:
    - bool: True when a matching membership row exists; False otherwise or
      when a database error occurs (the error is shown with ``st.error``).
    """
    # Pre-bind both handles so the cleanup below is safe even when connect()
    # itself fails; otherwise `conn` would be unbound in the finally clause.
    conn = None
    cursor = None
    try:
        conn = mysql.connector.connect(
            user=os.getenv("SQL_USER"),
            password=os.getenv("SQL_PWD"),
            host=os.getenv("SQL_HOST"),
            database="Birdseye_DB",
        )
        cursor = conn.cursor()
        cursor.execute(
            """
            SELECT 1
            FROM auth_user_groups
            JOIN auth_group ON auth_user_groups.group_id = auth_group.id
            WHERE auth_user_groups.user_id = %s AND auth_group.name = 'Unlimited Chat'
            """,
            (user_id,),
        )
        return cursor.fetchone() is not None
    except mysql.connector.Error as err:
        st.error(f"Error: {err}")
        return False
    finally:
        if conn is not None and conn.is_connected():
            if cursor is not None:
                cursor.close()
            conn.close()
def chat(slug, user_id, chat_limit=20):
    """
    Manages the chat interface in the Streamlit application, handling the conversation
    flow and displaying the chat history.
    Restricts chat based on user group and chat count.

    Parameters:
    - slug (str): Identifier of the document whose stored text is chatted about.
    - user_id: Identifier of the current user, used for the quota checks.
    - chat_limit (int): Number of answered questions allowed for users outside
      the 'Unlimited Chat' group. Defaults to 20.

    Expects ``st.session_state.messages`` to already hold the message list.
    """
    st.write(
        "**Please note:** Due to processing limitations, the chat may not fully comprehend the whole document."
    )

    # Check the quota before doing any embedding work so that users who are
    # over the limit do not trigger document loading and embedding calls.
    unlimited = is_user_in_unlimited_chat_group(user_id)
    if not unlimited:
        user_chat_count = get_user_chat_count(user_id)
        if user_chat_count is None or user_chat_count >= chat_limit:
            st.write("You have reached your chat limit.")
            return

    # Streamlit re-runs this function on every interaction. Build the chain
    # only once per document; rebuilding it each run would re-embed the whole
    # text and discard the chain's conversation memory between questions.
    if (
        st.session_state.get("conversation") is None
        or st.session_state.get("conversation_slug") != slug
    ):
        text_chunks = get_text_chunks(get_pdf_text(slug))
        if not text_chunks:
            # Nothing to index: building a vector store from no text fails.
            st.error("No text is available for this document.")
            return
        vectorstore = get_vectorstore(text_chunks)
        st.session_state.conversation = get_conversation_chain(vectorstore)
        st.session_state.conversation_slug = slug

    # Replay the stored conversation.
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.write(message["content"])

    # User-provided prompt
    if prompt := st.chat_input():
        st.session_state.messages.append({"role": "user", "content": prompt})
        st.session_state.prompts = prompt
        with st.chat_message("user"):
            st.write(prompt)

    # A trailing non-AI message means the latest question is still unanswered.
    if st.session_state.messages[-1]["role"] != "ai":
        with st.spinner("Generating response..."):
            response = st.session_state.conversation.invoke(
                {"question": st.session_state.prompts}
            )
        with st.chat_message("ai"):
            message_content = response["chat_history"][-1].content
            st.session_state.messages.append({"role": "ai", "content": message_content})
            st.write(message_content)
        if not unlimited:
            increment_user_chat_count(user_id)  # Increment count after response
def init():
    """
    Seeds the Streamlit session state with default values for every key the
    application uses, leaving keys that already exist untouched.
    """
    defaults = {
        "pdf": False,
        "conversation": None,
        "chat_history": None,
        # Greeting shown as the first chat message.
        "messages": [
            {
                "role": "ai",
                "content": "What do you want to learn about the document? Ask me a question!",
            }
        ],
    }
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value
def main():
    """
    Entry point of the app: prepares the session state, reads the ``slug``
    and ``user_id`` query parameters, loads environment variables and either
    starts the chat or asks the user to pick a document first.
    """
    init()
    params = st.query_params
    slug, user_id = params.get("slug"), params.get("user_id")
    load_dotenv()
    st.title("Chat with GPT :books:")
    # Both parameters are required to know which document and user to serve.
    if not (slug and user_id):
        st.error("Please return to Birdseye and select a document.")
        return
    chat(slug, user_id)


if __name__ == "__main__":
    main()