vinhnx90 committed on
Commit
db70198
•
1 Parent(s): 5440da0

Refactor app with better Document retrieval embedding and better chat streaming

Browse files
Files changed (2) hide show
  1. app.py +121 -146
  2. requirements.txt +4 -8
app.py CHANGED
@@ -1,155 +1,130 @@
1
  import os
2
- import streamlit as st
3
 
4
- from chat_profile import ChatProfileRoleEnum
 
 
 
 
 
 
 
5
  from langchain.text_splitter import RecursiveCharacterTextSplitter
6
- from langchain_community.document_loaders import Docx2txtLoader, PyPDFLoader, TextLoader
7
- from langchain_community.vectorstores.chroma import Chroma
8
- from langchain_openai import ChatOpenAI, OpenAIEmbeddings
9
- from langchain_community.chat_message_histories import StreamlitChatMessageHistory
10
- from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
11
- from langchain_core.runnables.history import RunnableWithMessageHistory
12
-
13
- __import__("pysqlite3")
14
- import sys
15
-
16
- sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")
17
-
18
- # config page
19
- st.set_page_config(page_title="InkChatGPT", page_icon="📚")
20
-
21
- # Set up memory
22
- msgs = StreamlitChatMessageHistory(key="langchain_messages")
23
-
24
-
25
- def load_and_process_file(file_data):
26
- """
27
- Load and process the uploaded file.
28
- Returns a vector store containing the embedded chunks of the file.
29
- """
30
- file_name = os.path.join("./", file_data.name)
31
- with open(file_name, "wb") as f:
32
- f.write(file_data.getvalue())
33
-
34
- _, extension = os.path.splitext(file_name)
35
-
36
- # Load the file using the appropriate loader
37
- if extension == ".pdf":
38
- loader = PyPDFLoader(file_name)
39
- elif extension == ".docx":
40
- loader = Docx2txtLoader(file_name)
41
- elif extension == ".txt":
42
- loader = TextLoader(file_name)
43
- else:
44
- st.error("This document format is not supported!")
45
- return None
46
-
47
- documents = loader.load()
48
-
49
- text_splitter = RecursiveCharacterTextSplitter(
50
- chunk_size=1000,
51
- chunk_overlap=200,
52
  )
53
- chunks = text_splitter.split_documents(documents)
54
- embeddings = OpenAIEmbeddings(api_key=st.session_state.api_key)
55
- vector_store = Chroma.from_documents(chunks, embeddings)
56
- return vector_store
57
-
58
-
59
- def main():
60
- """
61
- The main function that runs the Streamlit app.
62
- """
63
 
64
- if not st.session_state.api_key:
65
- st.info("Please add your OpenAI API key to continue.")
66
 
67
- if len(msgs.messages) == 0:
68
- msgs.add_ai_message(
69
- """
70
- Hello, how can I help you?
71
 
72
- You can upload a document and chat with me to ask questions related to its content.
73
- """
74
- )
75
-
76
- # Render current messages from StreamlitChatMessageHistory
77
- for msg in msgs.messages:
78
- st.chat_message(msg.type).write(msg.content)
79
-
80
- # If user inputs a new prompt, generate and draw a new response
81
- if question := st.chat_input(
82
- placeholder="Chat with your document",
83
- disabled=(not st.session_state.api_key),
84
  ):
85
- st.chat_message(ChatProfileRoleEnum.Human).write(question)
86
- prompt = ChatPromptTemplate.from_messages(
87
- [
88
- ("system", "You are an AI chatbot having a conversation with a human."),
89
- MessagesPlaceholder(variable_name="history"),
90
- (ChatProfileRoleEnum.Human, f"{question}"),
91
- ]
92
- )
93
-
94
- llm = ChatOpenAI(
95
- api_key=st.session_state.api_key,
96
- temperature=0.0,
97
- model_name="gpt-3.5-turbo",
98
- )
99
-
100
- chain = prompt | llm
101
- chain_with_history = RunnableWithMessageHistory(
102
- chain,
103
- lambda session_id: msgs,
104
- input_messages_key="question",
105
- history_messages_key="history",
106
- )
107
-
108
- # Note: new messages are saved to history automatically by Langchain during run
109
- config = {"configurable": {"session_id": "any"}}
110
- response = chain_with_history.invoke({"question": question}, config)
111
- st.chat_message(ChatProfileRoleEnum.AI).write(response.content)
112
-
113
-
114
- def build_sidebar():
115
- with st.sidebar:
116
- st.subheader("📚 InkChatGPT")
117
-
118
- openai_api_key = st.text_input(
119
- "OpenAI API Key",
120
- type="password",
121
- placeholder="Enter your OpenAI API key",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
  )
123
- st.session_state.api_key = openai_api_key
124
-
125
- with st.form("my_form"):
126
-
127
- uploaded_file = st.file_uploader(
128
- "Select a file", type=["pdf", "docx", "txt"], key="file_uploader"
129
- )
130
-
131
- add_file = st.form_submit_button(
132
- "Process File",
133
- disabled=(not uploaded_file and not openai_api_key),
134
- )
135
- if (
136
- add_file
137
- and uploaded_file
138
- and st.session_state.api_key.startswith("sk-")
139
- ):
140
- with st.spinner("💭 Thinking..."):
141
- vector_store = load_and_process_file(uploaded_file)
142
-
143
- if vector_store:
144
- msgs.add_ai_message(
145
- f"""
146
- File: `{uploaded_file.name}`, processed successfully!
147
-
148
- Feel free to ask me any question.
149
- """
150
- )
151
-
152
-
153
- if __name__ == "__main__":
154
- build_sidebar()
155
- main()
 
1
  import os
2
+ import tempfile
3
 
4
+ import streamlit as st
5
+ from langchain.callbacks.base import BaseCallbackHandler
6
+ from langchain.chains import ConversationalRetrievalChain
7
+ from langchain.chat_models import ChatOpenAI
8
+ from langchain.document_loaders import PyPDFLoader
9
+ from langchain.embeddings import HuggingFaceEmbeddings
10
+ from langchain.memory import ConversationBufferMemory
11
+ from langchain.memory.chat_message_histories import StreamlitChatMessageHistory
12
  from langchain.text_splitter import RecursiveCharacterTextSplitter
13
+ from langchain_community.vectorstores import DocArrayInMemorySearch
14
+
15
+ st.set_page_config(page_title="📚 InkChatGPT: Chat with Documents", page_icon="📚")
16
+ st.subheader("📚 InkChatGPT")
17
+ st.write("Chat with Documents")
18
+
19
+ @st.cache_resource(ttl="1h")
20
+ def configure_retriever(uploaded_files):
21
+ # Read documents
22
+ docs = []
23
+ temp_dir = tempfile.TemporaryDirectory()
24
+ for file in uploaded_files:
25
+ temp_filepath = os.path.join(temp_dir.name, file.name)
26
+ with open(temp_filepath, "wb") as f:
27
+ f.write(file.getvalue())
28
+ loader = PyPDFLoader(temp_filepath)
29
+ docs.extend(loader.load())
30
+
31
+ # Split documents
32
+ text_splitter = RecursiveCharacterTextSplitter(chunk_size=1500, chunk_overlap=200)
33
+ splits = text_splitter.split_documents(docs)
34
+
35
+ # Create embeddings and store in vectordb
36
+ embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
37
+ vectordb = DocArrayInMemorySearch.from_documents(splits, embeddings)
38
+
39
+ # Define retriever
40
+ retriever = vectordb.as_retriever(
41
+ search_type="mmr", search_kwargs={"k": 2, "fetch_k": 4}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  )
 
 
 
 
 
 
 
 
 
 
43
 
44
+ return retriever
 
45
 
 
 
 
 
46
 
47
+ class StreamHandler(BaseCallbackHandler):
48
+ def __init__(
49
+ self, container: st.delta_generator.DeltaGenerator, initial_text: str = ""
 
 
 
 
 
 
 
 
 
50
  ):
51
+ self.container = container
52
+ self.text = initial_text
53
+ self.run_id_ignore_token = None
54
+
55
+ def on_llm_start(self, serialized: dict, prompts: list, **kwargs):
56
+ # Workaround to prevent showing the rephrased question as output
57
+ if prompts[0].startswith("Human"):
58
+ self.run_id_ignore_token = kwargs.get("run_id")
59
+
60
+ def on_llm_new_token(self, token: str, **kwargs) -> None:
61
+ if self.run_id_ignore_token == kwargs.get("run_id", False):
62
+ return
63
+ self.text += token
64
+ self.container.markdown(self.text)
65
+
66
+
67
+ class PrintRetrievalHandler(BaseCallbackHandler):
68
+ def __init__(self, container):
69
+ self.status = container.status("**Context Retrieval**")
70
+
71
+ def on_retriever_start(self, serialized: dict, query: str, **kwargs):
72
+ self.status.write(f"**Question:** {query}")
73
+ self.status.update(label=f"**Context Retrieval:** {query}")
74
+
75
+ def on_retriever_end(self, documents, **kwargs):
76
+ for idx, doc in enumerate(documents):
77
+ source = os.path.basename(doc.metadata["source"])
78
+ self.status.write(f"**Document {idx} from {source}**")
79
+ self.status.markdown(doc.page_content)
80
+ self.status.update(state="complete")
81
+
82
+
83
+ openai_api_key = st.sidebar.text_input("OpenAI API Key", type="password")
84
+ if not openai_api_key:
85
+ st.info("Please add your OpenAI API key to continue.")
86
+ st.stop()
87
+
88
+ uploaded_files = st.sidebar.file_uploader(
89
+ label="Upload PDF files", type=["pdf"], accept_multiple_files=True
90
+ )
91
+ if not uploaded_files:
92
+ st.info("Please upload PDF documents to continue.")
93
+ st.stop()
94
+
95
+ retriever = configure_retriever(uploaded_files)
96
+
97
+ # Setup memory for contextual conversation
98
+ msgs = StreamlitChatMessageHistory()
99
+ memory = ConversationBufferMemory(
100
+ memory_key="chat_history", chat_memory=msgs, return_messages=True
101
+ )
102
+
103
+ # Setup LLM and QA chain
104
+ llm = ChatOpenAI(
105
+ model_name="gpt-3.5-turbo",
106
+ openai_api_key=openai_api_key,
107
+ temperature=0,
108
+ streaming=True,
109
+ )
110
+ qa_chain = ConversationalRetrievalChain.from_llm(
111
+ llm, retriever=retriever, memory=memory, verbose=True
112
+ )
113
+
114
+ if len(msgs.messages) == 0 or st.sidebar.button("Clear message history"):
115
+ msgs.clear()
116
+ msgs.add_ai_message("How can I help you?")
117
+
118
+ avatars = {"human": "user", "ai": "assistant"}
119
+ for msg in msgs.messages:
120
+ st.chat_message(avatars[msg.type]).write(msg.content)
121
+
122
+ if user_query := st.chat_input(placeholder="Ask me anything!"):
123
+ st.chat_message("user").write(user_query)
124
+
125
+ with st.chat_message("assistant"):
126
+ retrieval_handler = PrintRetrievalHandler(st.container())
127
+ stream_handler = StreamHandler(st.empty())
128
+ response = qa_chain.run(
129
+ user_query, callbacks=[retrieval_handler, stream_handler]
130
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -1,11 +1,7 @@
 
 
 
1
  langchain
2
- langchain_openai
3
  streamlit
4
  streamlit_chat
5
- chromadb
6
- openai
7
- tiktoken
8
- pypdf
9
- docx2txt
10
- watchdog
11
- pysqlite3-binary
 
1
+ openai
2
+ sentence-transformers
3
+ docarray
4
  langchain
 
5
  streamlit
6
  streamlit_chat
7
+ pypdf