vinhnx90 commited on
Commit
375bd04
β€’
1 Parent(s): 568c54e

Refactor to extract token stream handler

Browse files
Files changed (2) hide show
  1. app.py +29 -38
  2. token_stream_handler.py +13 -0
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import os
2
  import streamlit as st
3
- from langchain.callbacks.base import BaseCallbackHandler
 
4
  from langchain.chains import ConversationalRetrievalChain
5
  from langchain.schema import ChatMessage
6
  from langchain.text_splitter import RecursiveCharacterTextSplitter
@@ -10,21 +11,6 @@ from langchain_openai import ChatOpenAI, OpenAIEmbeddings
10
 
11
  st.set_page_config(page_title="InkChatGPT", page_icon="πŸ“š")
12
 
13
- __import__("pysqlite3")
14
- import sys
15
-
16
- sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")
17
-
18
-
19
- class StreamHandler(BaseCallbackHandler):
20
- def __init__(self, container, initial_text=""):
21
- self.container = container
22
- self.text = initial_text
23
-
24
- def on_llm_new_token(self, token: str, **kwargs) -> None:
25
- self.text += token
26
- self.container.markdown(self.text)
27
-
28
 
29
  def load_and_process_file(file_data):
30
  """
@@ -156,38 +142,43 @@ def clear_history():
156
  del st.session_state["history"]
157
 
158
 
 
 
 
 
 
 
 
 
 
 
 
159
  def build_sidebar():
160
  with st.sidebar:
161
  st.title("πŸ“š InkChatGPT")
162
 
163
- openai_api_key = st.text_input(
164
- "OpenAI API Key", type="password", placeholder="Enter your OpenAI API key"
165
- )
166
- st.session_state.api_key = openai_api_key
 
 
167
 
168
- if not openai_api_key:
169
- st.info("Please add your OpenAI API key to continue.")
 
170
 
171
- uploaded_file = st.file_uploader(
172
- "Select a file", type=["pdf", "docx", "txt"], key="file_uploader"
173
- )
174
 
175
- if uploaded_file and openai_api_key.startswith("sk-"):
176
- add_file = st.button(
177
  "Process File",
178
- on_click=clear_history,
179
- key="process_button",
 
180
  )
181
 
182
- if uploaded_file and add_file:
183
- with st.spinner("πŸ’­ Thinking..."):
184
- vector_store = load_and_process_file(uploaded_file)
185
-
186
- if vector_store:
187
- crc = initialize_chat_model(vector_store)
188
- st.session_state.crc = crc
189
- st.success("File processed successfully!")
190
-
191
 
192
  if __name__ == "__main__":
193
  build_sidebar()
 
1
  import os
2
  import streamlit as st
3
+
4
+ from token_stream_handler import StreamHandler
5
  from langchain.chains import ConversationalRetrievalChain
6
  from langchain.schema import ChatMessage
7
  from langchain.text_splitter import RecursiveCharacterTextSplitter
 
11
 
12
  st.set_page_config(page_title="InkChatGPT", page_icon="πŸ“š")
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
  def load_and_process_file(file_data):
16
  """
 
142
  del st.session_state["history"]
143
 
144
 
145
+ def process_data(uploaded_file, openai_api_key):
146
+ if uploaded_file and openai_api_key.startswith("sk-"):
147
+ with st.spinner("πŸ’­ Thinking..."):
148
+ vector_store = load_and_process_file(uploaded_file)
149
+
150
+ if vector_store:
151
+ crc = initialize_chat_model(vector_store)
152
+ st.session_state.crc = crc
153
+ st.success(f"File: `{uploaded_file.name}`, processed successfully!")
154
+
155
+
156
  def build_sidebar():
157
  with st.sidebar:
158
  st.title("πŸ“š InkChatGPT")
159
 
160
+ with st.form(key="input_form"):
161
+ openai_api_key = st.text_input(
162
+ "OpenAI API Key",
163
+ type="password",
164
+ placeholder="Enter your OpenAI API key",
165
+ )
166
 
167
+ st.session_state.api_key = openai_api_key
168
+ if not openai_api_key:
169
+ st.info("Please add your OpenAI API key to continue.")
170
 
171
+ uploaded_file = st.file_uploader(
172
+ "Select a file", type=["pdf", "docx", "txt"], key="file_uploader"
173
+ )
174
 
175
+ st.form_submit_button(
 
176
  "Process File",
177
+ on_click=process_data(
178
+ uploaded_file=uploaded_file, openai_api_key=openai_api_key
179
+ ),
180
  )
181
 
 
 
 
 
 
 
 
 
 
182
 
183
  if __name__ == "__main__":
184
  build_sidebar()
token_stream_handler.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ from langchain.callbacks.base import BaseCallbackHandler
4
+
5
+
6
+ class StreamHandler(BaseCallbackHandler):
7
+ def __init__(self, container, initial_text=""):
8
+ self.container = container
9
+ self.text = initial_text
10
+
11
+ def on_llm_new_token(self, token: str, **kwargs) -> None:
12
+ self.text += token
13
+ self.container.markdown(self.text)