vinhnx90 committed on
Commit
18a32c9
•
1 Parent(s): 48881b2
Files changed (3)
  1. app.py +102 -74
  2. stream_handler.py +0 -37
  3. token_stream_handler.py +13 -0
app.py CHANGED
@@ -2,48 +2,52 @@ import os
 import tempfile
 
 import streamlit as st
+from chat_profile import ChatProfileRoleEnum
 
 from langchain.callbacks.base import BaseCallbackHandler
 from langchain.chains import ConversationalRetrievalChain
 from langchain.chat_models import ChatOpenAI
+from langchain_community.document_loaders import Docx2txtLoader, PyPDFLoader, TextLoader
 from langchain.embeddings import HuggingFaceEmbeddings
 from langchain.memory import ConversationBufferMemory
 from langchain.memory.chat_message_histories import StreamlitChatMessageHistory
 from langchain.text_splitter import RecursiveCharacterTextSplitter
-from langchain_community.document_loaders import Docx2txtLoader, PyPDFLoader, TextLoader
 from langchain_community.vectorstores import DocArrayInMemorySearch
+from streamlit_extras.add_vertical_space import add_vertical_space
 
-from chat_profile import ChatProfileRoleEnum
-from stream_handler import PrintRetrievalHandler, StreamHandler
+# TODO: refactor
+# TODO: extract class
+# TODO: modularize
+# TODO: hide side bar
+# TODO: make the page attractive
 
-# Configuration
+# configs
 LLM_MODEL_NAME = "gpt-3.5-turbo"
 EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2"
 
-
-# Set up Streamlit app
-def setup_streamlit_app():
-    st.set_page_config(
-        page_title=":books: InkChatGPT: Chat with Documents",
-        page_icon="📚",
-        initial_sidebar_state="collapsed",
-        menu_items={
-            "Get Help": "https://x.com/vinhnx",
-            "Report a bug": "https://github.com/vinhnx/InkChatGPT/issues",
-            "About": "InkChatGPT is a Streamlit application that allows users to upload PDF documents and engage in a conversational Q&A with a language model (LLM) based on the content of those documents.",
-        },
-    )
-
-    st.image("./assets/icon.jpg", width=100)
-    st.header(":gray[:books: InkChatGPT]", divider="blue")
-    st.write("**Chat** with Documents")
-
+st.set_page_config(
+    page_title=":books: InkChatGPT: Chat with Documents",
+    page_icon="📚",
+    initial_sidebar_state="collapsed",
+    menu_items={
+        "Get Help": "https://x.com/vinhnx",
+        "Report a bug": "https://github.com/vinhnx/InkChatGPT/issues",
+        "About": "InkChatGPT is a Streamlit application that allows users to upload PDF documents and engage in a conversational Q&A with a language model (LLM) based on the content of those documents.",
+    },
+)
+
+st.image("./assets/icon.jpg", width=100)
+st.header(
+    ":gray[:books: InkChatGPT]",
+    divider="blue",
+)
+st.write("**Chat** with Documents")
 
 # Setup memory for contextual conversation
 msgs = StreamlitChatMessageHistory()
 
 
-# Load and process documents
+@st.cache_resource(ttl="1h")
 def configure_retriever(uploaded_files):
     # Read documents
     docs = []
@@ -66,6 +70,7 @@ def configure_retriever(uploaded_files):
             st.write("This document format is not supported!")
             return None
 
+        # loader = PyPDFLoader(temp_filepath)
         docs.extend(loader.load())
 
     # Split documents
@@ -84,68 +89,91 @@ def configure_retriever(uploaded_files):
     return retriever
 
 
-# Main function
-def main():
-    setup_streamlit_app()
-
-    with st.sidebar.expander("Documents"):
-        st.subheader("Files")
-        uploaded_files = st.file_uploader(
-            label="Select files",
-            type=["pdf", "txt", "docx"],
-            accept_multiple_files=True,
-        )
-
-    with st.sidebar.expander("Setup"):
-        st.subheader("API Key")
-        openai_api_key = st.text_input("OpenAI API Key", type="password")
-
-    is_empty_chat_messages = len(msgs.messages) == 0
-    if is_empty_chat_messages or st.button("Clear message history"):
-        msgs.clear()
-        msgs.add_ai_message("How can I help you?")
-
-    if not openai_api_key:
-        st.info("Please add your OpenAI API key in the sidebar to continue.")
-        st.stop()
-
-    if uploaded_files:
-        retriever = configure_retriever(uploaded_files)
-
-        memory = ConversationBufferMemory(
-            memory_key="chat_history", chat_memory=msgs, return_messages=True
-        )
-
-        # Setup LLM and QA chain
-        llm = ChatOpenAI(
-            model_name=LLM_MODEL_NAME,
-            openai_api_key=openai_api_key,
-            temperature=0,
-            streaming=True,
-        )
-
-        chain = ConversationalRetrievalChain.from_llm(
-            llm, retriever=retriever, memory=memory, verbose=False
-        )
-
-        avatars = {
-            ChatProfileRoleEnum.Human: "user",
-            ChatProfileRoleEnum.AI: "assistant",
-        }
-
-        for msg in msgs.messages:
-            st.chat_message(avatars[msg.type]).write(msg.content)
-
-        if user_query := st.chat_input(placeholder="Ask me anything!"):
-            st.chat_message("user").write(user_query)
-
-            with st.chat_message("assistant"):
-                retrieval_handler = PrintRetrievalHandler(st.empty())
-                stream_handler = StreamHandler(st.empty())
-                response = chain.run(
-                    user_query, callbacks=[retrieval_handler, stream_handler]
-                )
-
-if __name__ == "__main__":
-    main()
+class StreamHandler(BaseCallbackHandler):
+    def __init__(
+        self, container: st.delta_generator.DeltaGenerator, initial_text: str = ""
+    ):
+        self.container = container
+        self.text = initial_text
+        self.run_id_ignore_token = None
+
+    def on_llm_start(self, serialized: dict, prompts: list, **kwargs):
+        # Workaround to prevent showing the rephrased question as output
+        if prompts[0].startswith("Human"):
+            self.run_id_ignore_token = kwargs.get("run_id")
+
+    def on_llm_new_token(self, token: str, **kwargs) -> None:
+        if self.run_id_ignore_token == kwargs.get("run_id", False):
+            return
+        self.text += token
+        self.container.markdown(self.text)
+
+
+class PrintRetrievalHandler(BaseCallbackHandler):
+    def __init__(self, container):
+        self.status = container.status("**Thinking...**")
+        self.container = container
+
+    def on_retriever_start(self, serialized: dict, query: str, **kwargs):
+        self.status.write(f"**Checking document for query:** `{query}`. Please wait...")
+
+    def on_retriever_end(self, documents, **kwargs):
+        self.container.empty()
+
+
+with st.sidebar.expander("Documents"):
+    st.subheader("Files")
+    uploaded_files = st.file_uploader(
+        label="Select files", type=["pdf", "txt", "docx"], accept_multiple_files=True
+    )
+
+
+with st.sidebar.expander("Setup"):
+    st.subheader("API Key")
+    openai_api_key = st.text_input("OpenAI API Key", type="password")
+
+is_empty_chat_messages = len(msgs.messages) == 0
+if is_empty_chat_messages or st.button("Clear message history"):
+    msgs.clear()
+    msgs.add_ai_message("How can I help you?")
+
+if not openai_api_key:
+    st.info("Please add your OpenAI API key in the sidebar to continue.")
+    st.stop()
+
+if uploaded_files:
+    retriever = configure_retriever(uploaded_files)
+
+    memory = ConversationBufferMemory(
+        memory_key="chat_history", chat_memory=msgs, return_messages=True
+    )
+
+    # Setup LLM and QA chain
+    llm = ChatOpenAI(
+        model_name=LLM_MODEL_NAME,
+        openai_api_key=openai_api_key,
+        temperature=0,
+        streaming=True,
+    )
+
+    chain = ConversationalRetrievalChain.from_llm(
+        llm, retriever=retriever, memory=memory, verbose=False
+    )
+
+    avatars = {
+        ChatProfileRoleEnum.Human: "user",
+        ChatProfileRoleEnum.AI: "assistant",
+    }
+
+    for msg in msgs.messages:
+        st.chat_message(avatars[msg.type]).write(msg.content)
+
+    if user_query := st.chat_input(placeholder="Ask me anything!"):
+        st.chat_message("user").write(user_query)

+        with st.chat_message("assistant"):
+            retrieval_handler = PrintRetrievalHandler(st.empty())
+            stream_handler = StreamHandler(st.empty())
+            response = chain.run(
+                user_query, callbacks=[retrieval_handler, stream_handler]
+            )
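
Note that the hunks above only touch the edges of `configure_retriever`; the loader-selection and indexing code between them is unchanged, so the diff elides it. The new `@st.cache_resource(ttl="1h")` decorator makes Streamlit build the retriever once per unique set of uploaded files and reuse it across reruns for up to an hour. For orientation, here is a minimal sketch of what such a pipeline looks like using only the imports this file declares; the chunk sizes, search type, and `k` below are illustrative assumptions, not values taken from this commit.

import os
import tempfile

import streamlit as st
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import Docx2txtLoader, PyPDFLoader, TextLoader
from langchain_community.vectorstores import DocArrayInMemorySearch

EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2"


def configure_retriever_sketch(uploaded_files):
    # Persist each upload to disk so the loaders can read it by path.
    docs = []
    temp_dir = tempfile.TemporaryDirectory()
    for file in uploaded_files:
        temp_filepath = os.path.join(temp_dir.name, file.name)
        with open(temp_filepath, "wb") as f:
            f.write(file.getvalue())

        # Pick a loader by extension, mirroring the supported types in the diff.
        if temp_filepath.endswith(".pdf"):
            loader = PyPDFLoader(temp_filepath)
        elif temp_filepath.endswith(".docx"):
            loader = Docx2txtLoader(temp_filepath)
        elif temp_filepath.endswith(".txt"):
            loader = TextLoader(temp_filepath)
        else:
            st.write("This document format is not supported!")
            return None

        docs.extend(loader.load())

    # Split, embed, and index in memory (assumed parameters).
    splits = RecursiveCharacterTextSplitter(
        chunk_size=1500, chunk_overlap=200
    ).split_documents(docs)
    embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)
    vectordb = DocArrayInMemorySearch.from_documents(splits, embeddings)
    return vectordb.as_retriever(search_type="mmr", search_kwargs={"k": 2})

DocArrayInMemorySearch holds the index in process memory, which pairs naturally with the one-hour cache: nothing is persisted between sessions.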
stream_handler.py DELETED
@@ -1,37 +0,0 @@
-import os
-import streamlit as st
-
-from langchain.callbacks.base import BaseCallbackHandler
-
-
-# Callback handlers
-class StreamHandler(BaseCallbackHandler):
-    def __init__(
-        self, container: st.delta_generator.DeltaGenerator, initial_text: str = ""
-    ):
-        self.container = container
-        self.text = initial_text
-        self.run_id_ignore_token = None
-
-    def on_llm_start(self, serialized: dict, prompts: list, **kwargs):
-        # Workaround to prevent showing the rephrased question as output
-        if prompts[0].startswith("Human"):
-            self.run_id_ignore_token = kwargs.get("run_id")
-
-    def on_llm_new_token(self, token: str, **kwargs) -> None:
-        if self.run_id_ignore_token == kwargs.get("run_id", False):
-            return
-        self.text += token
-        self.container.markdown(self.text)
-
-
-class PrintRetrievalHandler(BaseCallbackHandler):
-    def __init__(self, container):
-        self.status = container.status("**Thinking...**")
-        self.container = container
-
-    def on_retriever_start(self, serialized: dict, query: str, **kwargs):
-        self.status.write(f"**Checking document for query:** `{query}`. Please wait...")
-
-    def on_retriever_end(self, documents, **kwargs):
-        self.container.empty()
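
Both deleted classes reappear verbatim in app.py above, so this commit relocates the handlers rather than changing their behavior. The `run_id` bookkeeping exists because `ConversationalRetrievalChain` issues two streaming LLM calls per turn, one to condense the follow-up question against chat history and one to answer it, and both report tokens to the same callbacks. Below is a self-contained simulation of the filter; it is not from the commit, and `FakeContainer` is a stand-in for a Streamlit placeholder.

from uuid import uuid4


class FakeContainer:
    # Stand-in for st.empty(); just echoes the accumulated text.
    def markdown(self, text):
        print(text)


class StreamHandler:
    # The deleted handler's logic, re-declared inline for the simulation.
    def __init__(self, container, initial_text=""):
        self.container = container
        self.text = initial_text
        self.run_id_ignore_token = None

    def on_llm_start(self, serialized, prompts, **kwargs):
        # The condense-question prompt renders as a single "Human: ..."
        # string, so remember that run and skip its tokens.
        if prompts[0].startswith("Human"):
            self.run_id_ignore_token = kwargs.get("run_id")

    def on_llm_new_token(self, token, **kwargs):
        if self.run_id_ignore_token == kwargs.get("run_id", False):
            return
        self.text += token
        self.container.markdown(self.text)


handler = StreamHandler(FakeContainer())

# Run 1: the condense-question call; its tokens are suppressed.
condense_id = uuid4()
handler.on_llm_start({}, ["Human: ...chat history..."], run_id=condense_id)
handler.on_llm_new_token("standalone question", run_id=condense_id)  # ignored

# Run 2: the answer call; its tokens stream to the container.
answer_id = uuid4()
handler.on_llm_start({}, ["Use the following context..."], run_id=answer_id)
handler.on_llm_new_token("The answer...", run_id=answer_id)  # prints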
token_stream_handler.py ADDED
@@ -0,0 +1,13 @@
+import os
+
+from langchain.callbacks.base import BaseCallbackHandler
+
+
+class StreamHandler(BaseCallbackHandler):
+    def __init__(self, container, initial_text=""):
+        self.container = container
+        self.text = initial_text
+
+    def on_llm_new_token(self, token: str, **kwargs) -> None:
+        self.text += token
+        self.container.markdown(self.text)
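
Nothing imports this new module yet; app.py defines its own inline handlers in this commit. The simplified `StreamHandler` here also drops the `run_id` filter, so it suits a single streaming LLM call rather than the two-call `ConversationalRetrievalChain`. A minimal usage sketch, assuming a Streamlit session and an `OPENAI_API_KEY` environment variable:

import os

import streamlit as st
from langchain.chat_models import ChatOpenAI

from token_stream_handler import StreamHandler

# Stream tokens from a single chat-model call into a Streamlit placeholder.
placeholder = st.empty()
llm = ChatOpenAI(
    model_name="gpt-3.5-turbo",
    openai_api_key=os.environ["OPENAI_API_KEY"],
    streaming=True,
    callbacks=[StreamHandler(placeholder)],
)
llm.predict("Summarize the uploaded document in one sentence.")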