Fangrui Liu commited on
Commit
c6f6149
Β·
1 Parent(s): 04f0bde

add knowledge base management

Browse files
app.py CHANGED
@@ -10,7 +10,7 @@ from callbacks.arxiv_callbacks import ChatDataSelfSearchCallBackHandler, \
10
 
11
  from chat import chat_page
12
  from login import login, back_to_main
13
- from helper import build_tools, build_agents, build_all, sel_map, display
14
 
15
 
16
 
 
10
 
11
  from chat import chat_page
12
  from login import login, back_to_main
13
+ from lib.helper import build_tools, build_all, sel_map, display
14
 
15
 
16
 
callbacks/arxiv_callbacks.py CHANGED
@@ -3,70 +3,79 @@ import json
3
  import textwrap
4
  from typing import Dict, Any, List
5
  from sql_formatter.core import format_sql
6
- from langchain.callbacks.streamlit.streamlit_callback_handler import LLMThought, StreamlitCallbackHandler
 
 
 
7
  from langchain.schema.output import LLMResult
8
  from streamlit.delta_generator import DeltaGenerator
9
 
 
10
  class ChatDataSelfSearchCallBackHandler(StreamlitCallbackHandler):
11
  def __init__(self) -> None:
12
  self.progress_bar = st.progress(value=0.0, text="Working...")
13
  self.tokens_stream = ""
14
-
15
  def on_llm_start(self, serialized, prompts, **kwargs) -> None:
16
  pass
17
-
18
  def on_text(self, text: str, **kwargs) -> None:
19
  self.progress_bar.progress(value=0.2, text="Asking LLM...")
20
-
21
  def on_chain_end(self, outputs, **kwargs) -> None:
22
- self.progress_bar.progress(value=0.6, text='Searching in DB...')
23
- if 'repr' in outputs:
24
- st.markdown('### Generated Filter')
25
  st.markdown(f"```python\n{outputs['repr']}\n```", unsafe_allow_html=True)
26
-
27
  def on_chain_start(self, serialized, inputs, **kwargs) -> None:
28
  pass
29
 
 
30
  class ChatDataSelfAskCallBackHandler(StreamlitCallbackHandler):
31
  def __init__(self) -> None:
32
- self.progress_bar = st.progress(value=0.0, text='Searching DB...')
33
  self.status_bar = st.empty()
34
  self.prog_value = 0.0
35
  self.prog_map = {
36
- 'langchain.chains.qa_with_sources.retrieval.RetrievalQAWithSourcesChain': 0.2,
37
- 'langchain.chains.combine_documents.map_reduce.MapReduceDocumentsChain': 0.4,
38
- 'langchain.chains.combine_documents.stuff.StuffDocumentsChain': 0.8
39
  }
40
 
41
  def on_llm_start(self, serialized, prompts, **kwargs) -> None:
42
  pass
43
-
44
  def on_text(self, text: str, **kwargs) -> None:
45
  pass
46
-
47
  def on_chain_start(self, serialized, inputs, **kwargs) -> None:
48
- cid = '.'.join(serialized['id'])
49
- if cid != 'langchain.chains.llm.LLMChain':
50
- self.progress_bar.progress(value=self.prog_map[cid], text=f'Running Chain `{cid}`...')
 
 
51
  self.prog_value = self.prog_map[cid]
52
  else:
53
  self.prog_value += 0.1
54
- self.progress_bar.progress(value=self.prog_value, text=f'Running Chain `{cid}`...')
 
 
55
 
56
  def on_chain_end(self, outputs, **kwargs) -> None:
57
  pass
58
-
59
 
60
  class ChatDataSQLSearchCallBackHandler(StreamlitCallbackHandler):
61
  def __init__(self) -> None:
62
- self.progress_bar = st.progress(value=0.0, text='Writing SQL...')
63
  self.status_bar = st.empty()
64
  self.prog_value = 0
65
  self.prog_interval = 0.2
66
 
67
  def on_llm_start(self, serialized, prompts, **kwargs) -> None:
68
  pass
69
-
70
  def on_llm_end(
71
  self,
72
  response: LLMResult,
@@ -74,41 +83,56 @@ class ChatDataSQLSearchCallBackHandler(StreamlitCallbackHandler):
74
  **kwargs,
75
  ):
76
  text = response.generations[0][0].text
77
- if text.replace(' ', '').upper().startswith('SELECT'):
78
- st.write('We generated Vector SQL for you:')
79
- st.markdown(f'''```sql\n{format_sql(text, max_len=80)}\n```''')
80
  print(f"Vector SQL: {text}")
81
  self.prog_value += self.prog_interval
82
  self.progress_bar.progress(value=self.prog_value, text="Searching in DB...")
83
-
84
  def on_chain_start(self, serialized, inputs, **kwargs) -> None:
85
- cid = '.'.join(serialized['id'])
86
  self.prog_value += self.prog_interval
87
- self.progress_bar.progress(value=self.prog_value, text=f'Running Chain `{cid}`...')
88
-
 
 
89
  def on_chain_end(self, outputs, **kwargs) -> None:
90
  pass
91
-
 
92
  class ChatDataSQLAskCallBackHandler(ChatDataSQLSearchCallBackHandler):
93
  def __init__(self) -> None:
94
- self.progress_bar = st.progress(value=0.0, text='Writing SQL...')
95
  self.status_bar = st.empty()
96
  self.prog_value = 0
97
  self.prog_interval = 0.1
98
-
99
-
100
  class LLMThoughtWithKB(LLMThought):
101
- def on_tool_end(self, output: str, color: str | None = None, observation_prefix: str | None = None, llm_prefix: str | None = None, **kwargs: Any) -> None:
 
 
 
 
 
 
 
102
  try:
103
- self._container.markdown("\n\n".join(["### Retrieved Documents:"] + \
104
- [f"**{i+1}**: {textwrap.shorten(r['page_content'], width=80)}"
105
- for i, r in enumerate(json.loads(output))]))
 
 
 
 
 
 
106
  except Exception as e:
107
  super().on_tool_end(output, color, observation_prefix, llm_prefix, **kwargs)
108
-
109
-
110
  class ChatDataAgentCallBackHandler(StreamlitCallbackHandler):
111
-
112
  def on_llm_start(
113
  self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
114
  ) -> None:
@@ -120,4 +144,4 @@ class ChatDataAgentCallBackHandler(StreamlitCallbackHandler):
120
  labeler=self._thought_labeler,
121
  )
122
 
123
- self._current_thought.on_llm_start(serialized, prompts)
 
3
  import textwrap
4
  from typing import Dict, Any, List
5
  from sql_formatter.core import format_sql
6
+ from langchain.callbacks.streamlit.streamlit_callback_handler import (
7
+ LLMThought,
8
+ StreamlitCallbackHandler,
9
+ )
10
  from langchain.schema.output import LLMResult
11
  from streamlit.delta_generator import DeltaGenerator
12
 
13
+
14
  class ChatDataSelfSearchCallBackHandler(StreamlitCallbackHandler):
15
  def __init__(self) -> None:
16
  self.progress_bar = st.progress(value=0.0, text="Working...")
17
  self.tokens_stream = ""
18
+
19
  def on_llm_start(self, serialized, prompts, **kwargs) -> None:
20
  pass
21
+
22
  def on_text(self, text: str, **kwargs) -> None:
23
  self.progress_bar.progress(value=0.2, text="Asking LLM...")
24
+
25
  def on_chain_end(self, outputs, **kwargs) -> None:
26
+ self.progress_bar.progress(value=0.6, text="Searching in DB...")
27
+ if "repr" in outputs:
28
+ st.markdown("### Generated Filter")
29
  st.markdown(f"```python\n{outputs['repr']}\n```", unsafe_allow_html=True)
30
+
31
  def on_chain_start(self, serialized, inputs, **kwargs) -> None:
32
  pass
33
 
34
+
35
  class ChatDataSelfAskCallBackHandler(StreamlitCallbackHandler):
36
  def __init__(self) -> None:
37
+ self.progress_bar = st.progress(value=0.0, text="Searching DB...")
38
  self.status_bar = st.empty()
39
  self.prog_value = 0.0
40
  self.prog_map = {
41
+ "langchain.chains.qa_with_sources.retrieval.RetrievalQAWithSourcesChain": 0.2,
42
+ "langchain.chains.combine_documents.map_reduce.MapReduceDocumentsChain": 0.4,
43
+ "langchain.chains.combine_documents.stuff.StuffDocumentsChain": 0.8,
44
  }
45
 
46
  def on_llm_start(self, serialized, prompts, **kwargs) -> None:
47
  pass
48
+
49
  def on_text(self, text: str, **kwargs) -> None:
50
  pass
51
+
52
  def on_chain_start(self, serialized, inputs, **kwargs) -> None:
53
+ cid = ".".join(serialized["id"])
54
+ if cid != "langchain.chains.llm.LLMChain":
55
+ self.progress_bar.progress(
56
+ value=self.prog_map[cid], text=f"Running Chain `{cid}`..."
57
+ )
58
  self.prog_value = self.prog_map[cid]
59
  else:
60
  self.prog_value += 0.1
61
+ self.progress_bar.progress(
62
+ value=self.prog_value, text=f"Running Chain `{cid}`..."
63
+ )
64
 
65
  def on_chain_end(self, outputs, **kwargs) -> None:
66
  pass
67
+
68
 
69
  class ChatDataSQLSearchCallBackHandler(StreamlitCallbackHandler):
70
  def __init__(self) -> None:
71
+ self.progress_bar = st.progress(value=0.0, text="Writing SQL...")
72
  self.status_bar = st.empty()
73
  self.prog_value = 0
74
  self.prog_interval = 0.2
75
 
76
  def on_llm_start(self, serialized, prompts, **kwargs) -> None:
77
  pass
78
+
79
  def on_llm_end(
80
  self,
81
  response: LLMResult,
 
83
  **kwargs,
84
  ):
85
  text = response.generations[0][0].text
86
+ if text.replace(" ", "").upper().startswith("SELECT"):
87
+ st.write("We generated Vector SQL for you:")
88
+ st.markdown(f"""```sql\n{format_sql(text, max_len=80)}\n```""")
89
  print(f"Vector SQL: {text}")
90
  self.prog_value += self.prog_interval
91
  self.progress_bar.progress(value=self.prog_value, text="Searching in DB...")
92
+
93
  def on_chain_start(self, serialized, inputs, **kwargs) -> None:
94
+ cid = ".".join(serialized["id"])
95
  self.prog_value += self.prog_interval
96
+ self.progress_bar.progress(
97
+ value=self.prog_value, text=f"Running Chain `{cid}`..."
98
+ )
99
+
100
  def on_chain_end(self, outputs, **kwargs) -> None:
101
  pass
102
+
103
+
104
  class ChatDataSQLAskCallBackHandler(ChatDataSQLSearchCallBackHandler):
105
  def __init__(self) -> None:
106
+ self.progress_bar = st.progress(value=0.0, text="Writing SQL...")
107
  self.status_bar = st.empty()
108
  self.prog_value = 0
109
  self.prog_interval = 0.1
110
+
111
+
112
  class LLMThoughtWithKB(LLMThought):
113
+ def on_tool_end(
114
+ self,
115
+ output: str,
116
+ color=None,
117
+ observation_prefix=None,
118
+ llm_prefix=None,
119
+ **kwargs: Any,
120
+ ) -> None:
121
  try:
122
+ self._container.markdown(
123
+ "\n\n".join(
124
+ ["### Retrieved Documents:"]
125
+ + [
126
+ f"**{i+1}**: {textwrap.shorten(r['page_content'], width=80)}"
127
+ for i, r in enumerate(json.loads(output))
128
+ ]
129
+ )
130
+ )
131
  except Exception as e:
132
  super().on_tool_end(output, color, observation_prefix, llm_prefix, **kwargs)
133
+
134
+
135
  class ChatDataAgentCallBackHandler(StreamlitCallbackHandler):
 
136
  def on_llm_start(
137
  self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
138
  ) -> None:
 
144
  labeler=self._thought_labeler,
145
  )
146
 
147
+ self._current_thought.on_llm_start(serialized, prompts)
chat.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import pandas as pd
2
  from os import environ
3
  from time import sleep
@@ -7,9 +8,12 @@ from lib.sessions import SessionManager
7
  from lib.private_kb import PrivateKnowledgeBase
8
  from langchain.schema import HumanMessage, FunctionMessage
9
  from callbacks.arxiv_callbacks import ChatDataAgentCallBackHandler
10
- from langchain.callbacks.streamlit.streamlit_callback_handler import StreamlitCallbackHandler
 
 
 
11
 
12
- from helper import (
13
  build_agents,
14
  MYSCALE_HOST,
15
  MYSCALE_PASSWORD,
@@ -30,12 +34,16 @@ TOOL_NAMES = {
30
 
31
  def on_chat_submit():
32
  with st.session_state.next_round.container():
33
- with st.chat_message('user'):
34
  st.write(st.session_state.chat_input)
35
- with st.chat_message('assistant'):
36
  container = st.container()
37
- st_callback = ChatDataAgentCallBackHandler(container, collapse_completed_thoughts=False)
38
- ret = st.session_state.agent({"input": st.session_state.chat_input}, callbacks=[st_callback])
 
 
 
 
39
  print(ret)
40
 
41
 
@@ -105,7 +113,10 @@ def refresh_sessions():
105
  st.session_state[
106
  "current_sessions"
107
  ] = st.session_state.session_manager.list_sessions(st.session_state.user_name)
108
- if type(st.session_state.current_sessions) is not dict and len(st.session_state.current_sessions) <= 0:
 
 
 
109
  st.session_state.session_manager.add_session(
110
  st.session_state.user_name,
111
  f"{st.session_state.user_name}?default",
@@ -114,14 +125,64 @@ def refresh_sessions():
114
  st.session_state[
115
  "current_sessions"
116
  ] = st.session_state.session_manager.list_sessions(st.session_state.user_name)
117
-
 
 
 
 
 
 
 
 
 
118
  try:
119
- dfl_indx = [x["session_id"] for x in st.session_state.current_sessions].index("default" if "" not in st.session_state else st.session_state.sel_session["session_id"])
 
 
 
 
120
  except ValueError:
121
  dfl_indx = 0
122
  st.session_state.sel_sess = st.session_state.current_sessions[dfl_indx]
123
 
124
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
  def refresh_agent():
126
  with st.spinner("Initializing session..."):
127
  print(
@@ -138,22 +199,29 @@ def refresh_agent():
138
  else st.session_state.sel_sess["system_prompt"],
139
  )
140
 
 
141
  def add_file():
142
- if 'uploaded_files' not in st.session_state or len(st.session_state.uploaded_files) == 0:
 
 
 
143
  st.session_state.tool_status.error("Please upload files!", icon="⚠️")
144
  sleep(2)
145
  return
146
  try:
147
  st.session_state.tool_status.info("Uploading...")
148
- print([(f.name, f.type) for f in st.session_state.uploaded_files])
149
- st.session_state.private_kb.add_by_file(st.session_state.user_name,
150
- st.session_state.uploaded_files)
 
151
  except ValueError as e:
152
  st.session_state.tool_status.error("Failed to upload! " + str(e))
153
  sleep(2)
154
-
 
155
  def clear_files():
156
  st.session_state.private_kb.clear(st.session_state.user_name)
 
157
 
158
 
159
  def chat_page():
@@ -168,7 +236,7 @@ def chat_page():
168
  port=MYSCALE_PORT,
169
  username=MYSCALE_USER,
170
  password=MYSCALE_PASSWORD,
171
- embedding=st.session_state.embeddings['Wikipedia'],
172
  parser_api_key=UNSTRUCTURED_API,
173
  )
174
  if "session_manager" not in st.session_state:
@@ -177,12 +245,21 @@ def chat_page():
177
  with st.expander("Session Management"):
178
  if "current_sessions" not in st.session_state:
179
  refresh_sessions()
180
- st.info("Here you can set up your session! \n\nYou can **change your prompt** here!",
181
- icon="πŸ€–")
182
- st.info(("**Add columns by clicking the empty row**.\n"
183
- "And **delete columns by selecting rows with a press on `DEL` Key**"),
184
- icon="πŸ’‘")
185
- st.info("Don't forget to **click `Submit Change` to save your change**!", icon="πŸ“’")
 
 
 
 
 
 
 
 
 
186
  st.data_editor(
187
  st.session_state.current_sessions,
188
  num_rows="dynamic",
@@ -191,12 +268,18 @@ def chat_page():
191
  )
192
  st.button("Submit Change!", on_click=on_session_change_submit)
193
  with st.expander("Session Selection", expanded=True):
194
- st.info("Here you can select your session!", icon="πŸ€–")
195
- st.info("If no session is attach to your account, then we will add a default session to you!", icon="❀️")
 
 
196
  try:
197
  dfl_indx = [
198
  x["session_id"] for x in st.session_state.current_sessions
199
- ].index("default" if "" not in st.session_state else st.session_state.sel_session["session_id"])
 
 
 
 
200
  except Exception as e:
201
  print("*** ", str(e))
202
  dfl_indx = 0
@@ -210,39 +293,84 @@ def chat_page():
210
  )
211
  print(st.session_state.sel_sess)
212
  with st.expander("Tool Settings", expanded=True):
213
- st.info("Here you can select your tools.", icon="πŸ”§")
214
- st.info("We provides you several knowledge base tools for you. We are building more tools!", icon="πŸ‘·β€β™‚οΈ")
 
 
215
  st.session_state["tool_status"] = st.empty()
216
- tab_kb, tab_file, tab_build = st.tabs(["Knowledge Bases", "File Upload", "KB Builder"])
 
 
 
 
 
217
  with tab_kb:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
218
  st.multiselect(
219
  "Select a Knowledge Base Tool",
220
- st.session_state.tools.keys(),
 
 
221
  default=["Wikipedia + Self Querying"],
222
  key="selected_tools",
223
  on_change=refresh_agent,
224
  )
 
 
 
 
 
 
 
 
225
  with tab_file:
226
- st.file_uploader("Upload files", key="uploaded_files", accept_multiple_files=True)
 
 
 
 
 
 
 
 
 
 
 
227
  st.markdown("### Uploaded Files")
228
- st.dataframe(st.session_state.private_kb.list_files(st.session_state.user_name))
 
 
 
229
  col_1, col_2 = st.columns(2)
230
  with col_1:
231
  st.button("Add Files", on_click=add_file)
232
  with col_2:
233
- st.button("Clear Files", on_click=clear_files)
234
- # with tab_build:
235
- # st.text_input("Give this knowledge base a description:")
236
- # col_3, col_4 = st.columns(2)
237
- # with col_3:
238
- # st.button("Build Your KB!")
239
- # with col_4:
240
- # st.button("Delete Your KB")
241
-
242
-
243
  st.button("Clear Chat History", on_click=clear_history)
244
  st.button("Logout", on_click=back_to_main)
245
- if 'agent' not in st.session_state:
246
  refresh_agent()
247
  print("!!! ", st.session_state.agent.memory.chat_memory.session_id)
248
  for msg in st.session_state.agent.memory.chat_memory.messages:
@@ -255,7 +383,10 @@ def chat_page():
255
  st.write("Retrieved from knowledge base:")
256
  try:
257
  st.dataframe(
258
- pd.DataFrame.from_records(map(dict, eval(msg.content)))
 
 
 
259
  )
260
  except:
261
  st.write(msg.content)
 
1
+ import json
2
  import pandas as pd
3
  from os import environ
4
  from time import sleep
 
8
  from lib.private_kb import PrivateKnowledgeBase
9
  from langchain.schema import HumanMessage, FunctionMessage
10
  from callbacks.arxiv_callbacks import ChatDataAgentCallBackHandler
11
+ from langchain.callbacks.streamlit.streamlit_callback_handler import (
12
+ StreamlitCallbackHandler,
13
+ )
14
+ from lib.json_conv import CustomJSONDecoder
15
 
16
+ from lib.helper import (
17
  build_agents,
18
  MYSCALE_HOST,
19
  MYSCALE_PASSWORD,
 
34
 
35
  def on_chat_submit():
36
  with st.session_state.next_round.container():
37
+ with st.chat_message("user"):
38
  st.write(st.session_state.chat_input)
39
+ with st.chat_message("assistant"):
40
  container = st.container()
41
+ st_callback = ChatDataAgentCallBackHandler(
42
+ container, collapse_completed_thoughts=False
43
+ )
44
+ ret = st.session_state.agent(
45
+ {"input": st.session_state.chat_input}, callbacks=[st_callback]
46
+ )
47
  print(ret)
48
 
49
 
 
113
  st.session_state[
114
  "current_sessions"
115
  ] = st.session_state.session_manager.list_sessions(st.session_state.user_name)
116
+ if (
117
+ type(st.session_state.current_sessions) is not dict
118
+ and len(st.session_state.current_sessions) <= 0
119
+ ):
120
  st.session_state.session_manager.add_session(
121
  st.session_state.user_name,
122
  f"{st.session_state.user_name}?default",
 
125
  st.session_state[
126
  "current_sessions"
127
  ] = st.session_state.session_manager.list_sessions(st.session_state.user_name)
128
+ st.session_state["user_files"] = st.session_state.private_kb.list_files(
129
+ st.session_state.user_name
130
+ )
131
+ st.session_state["user_tools"] = st.session_state.private_kb.list_tools(
132
+ st.session_state.user_name
133
+ )
134
+ st.session_state["tools_with_users"] = {
135
+ **st.session_state.tools,
136
+ **st.session_state.private_kb.as_tools(st.session_state.user_name),
137
+ }
138
  try:
139
+ dfl_indx = [x["session_id"] for x in st.session_state.current_sessions].index(
140
+ "default"
141
+ if "" not in st.session_state
142
+ else st.session_state.sel_session["session_id"]
143
+ )
144
  except ValueError:
145
  dfl_indx = 0
146
  st.session_state.sel_sess = st.session_state.current_sessions[dfl_indx]
147
 
148
 
149
+ def build_kb_as_tool():
150
+ if (
151
+ "b_tool_name" in st.session_state
152
+ and "b_tool_desc" in st.session_state
153
+ and "b_tool_files" in st.session_state
154
+ and len(st.session_state.b_tool_name) > 0
155
+ and len(st.session_state.b_tool_desc) > 0
156
+ and len(st.session_state.b_tool_files) > 0
157
+ ):
158
+ st.session_state.private_kb.create_tool(
159
+ st.session_state.user_name,
160
+ st.session_state.b_tool_name,
161
+ st.session_state.b_tool_desc,
162
+ [f["file_name"] for f in st.session_state.b_tool_files],
163
+ )
164
+ refresh_sessions()
165
+ else:
166
+ st.session_state.tool_status.error(
167
+ "You should fill all fields to build up a tool!"
168
+ )
169
+ sleep(2)
170
+
171
+
172
+ def remove_kb():
173
+ if "r_tool_names" in st.session_state and len(st.session_state.r_tool_names) > 0:
174
+ st.session_state.private_kb.remove_tools(
175
+ st.session_state.user_name,
176
+ [f["tool_name"] for f in st.session_state.r_tool_names],
177
+ )
178
+ refresh_sessions()
179
+ else:
180
+ st.session_state.tool_status.error(
181
+ "You should specify at least one tool to delete!"
182
+ )
183
+ sleep(2)
184
+
185
+
186
  def refresh_agent():
187
  with st.spinner("Initializing session..."):
188
  print(
 
199
  else st.session_state.sel_sess["system_prompt"],
200
  )
201
 
202
+
203
  def add_file():
204
+ if (
205
+ "uploaded_files" not in st.session_state
206
+ or len(st.session_state.uploaded_files) == 0
207
+ ):
208
  st.session_state.tool_status.error("Please upload files!", icon="⚠️")
209
  sleep(2)
210
  return
211
  try:
212
  st.session_state.tool_status.info("Uploading...")
213
+ st.session_state.private_kb.add_by_file(
214
+ st.session_state.user_name, st.session_state.uploaded_files
215
+ )
216
+ refresh_sessions()
217
  except ValueError as e:
218
  st.session_state.tool_status.error("Failed to upload! " + str(e))
219
  sleep(2)
220
+
221
+
222
  def clear_files():
223
  st.session_state.private_kb.clear(st.session_state.user_name)
224
+ refresh_sessions()
225
 
226
 
227
  def chat_page():
 
236
  port=MYSCALE_PORT,
237
  username=MYSCALE_USER,
238
  password=MYSCALE_PASSWORD,
239
+ embedding=st.session_state.embeddings["Wikipedia"],
240
  parser_api_key=UNSTRUCTURED_API,
241
  )
242
  if "session_manager" not in st.session_state:
 
245
  with st.expander("Session Management"):
246
  if "current_sessions" not in st.session_state:
247
  refresh_sessions()
248
+ st.info(
249
+ "Here you can set up your session! \n\nYou can **change your prompt** here!",
250
+ icon="πŸ€–",
251
+ )
252
+ st.info(
253
+ (
254
+ "**Add columns by clicking the empty row**.\n"
255
+ "And **delete columns by selecting rows with a press on `DEL` Key**"
256
+ ),
257
+ icon="πŸ’‘",
258
+ )
259
+ st.info(
260
+ "Don't forget to **click `Submit Change` to save your change**!",
261
+ icon="πŸ“’",
262
+ )
263
  st.data_editor(
264
  st.session_state.current_sessions,
265
  num_rows="dynamic",
 
268
  )
269
  st.button("Submit Change!", on_click=on_session_change_submit)
270
  with st.expander("Session Selection", expanded=True):
271
+ st.info(
272
+ "If no session is attach to your account, then we will add a default session to you!",
273
+ icon="❀️",
274
+ )
275
  try:
276
  dfl_indx = [
277
  x["session_id"] for x in st.session_state.current_sessions
278
+ ].index(
279
+ "default"
280
+ if "" not in st.session_state
281
+ else st.session_state.sel_session["session_id"]
282
+ )
283
  except Exception as e:
284
  print("*** ", str(e))
285
  dfl_indx = 0
 
293
  )
294
  print(st.session_state.sel_sess)
295
  with st.expander("Tool Settings", expanded=True):
296
+ st.info(
297
+ "We provides you several knowledge base tools for you. We are building more tools!",
298
+ icon="πŸ”§",
299
+ )
300
  st.session_state["tool_status"] = st.empty()
301
+ tab_kb, tab_file = st.tabs(
302
+ [
303
+ "Knowledge Bases",
304
+ "File Upload",
305
+ ]
306
+ )
307
  with tab_kb:
308
+ st.markdown("#### Build You Own Knowledge")
309
+ st.multiselect(
310
+ "Select Files to Build up",
311
+ st.session_state.user_files,
312
+ placeholder="You should upload files first",
313
+ key="b_tool_files",
314
+ format_func=lambda x: x["file_name"],
315
+ )
316
+ st.text_input("Tool Name", "get_relevant_documents", key="b_tool_name")
317
+ st.text_input(
318
+ "Tool Description",
319
+ "Searches among user's private files and returns related documents",
320
+ key="b_tool_desc",
321
+ )
322
+ st.button("Build!", on_click=build_kb_as_tool)
323
+ st.markdown("### Knowledge Base Selection")
324
+ if (
325
+ "user_tools" in st.session_state
326
+ and len(st.session_state.user_tools) > 0
327
+ ):
328
+ st.markdown("***User Created Knowledge Bases***")
329
+ st.dataframe(st.session_state.user_tools)
330
  st.multiselect(
331
  "Select a Knowledge Base Tool",
332
+ st.session_state.tools.keys()
333
+ if "tools_with_users" not in st.session_state
334
+ else st.session_state.tools_with_users,
335
  default=["Wikipedia + Self Querying"],
336
  key="selected_tools",
337
  on_change=refresh_agent,
338
  )
339
+ st.markdown("### Delete Knowledge Base")
340
+ st.multiselect(
341
+ "Choose Knowledge Base to Remove",
342
+ st.session_state.user_tools,
343
+ format_func=lambda x: x["tool_name"],
344
+ key="r_tool_names",
345
+ )
346
+ st.button("Delete", on_click=remove_kb)
347
  with tab_file:
348
+ st.info(
349
+ (
350
+ "We adopted [Unstructured API](https://unstructured.io/api-key) "
351
+ "here and we only store the processed texts from your documents. "
352
+ "For privacy concerns, please refer to "
353
+ "[our policy issue](https://myscale.com/privacy/)."
354
+ ),
355
+ icon="πŸ“ƒ",
356
+ )
357
+ st.file_uploader(
358
+ "Upload files", key="uploaded_files", accept_multiple_files=True
359
+ )
360
  st.markdown("### Uploaded Files")
361
+ st.dataframe(
362
+ st.session_state.private_kb.list_files(st.session_state.user_name),
363
+ use_container_width=True,
364
+ )
365
  col_1, col_2 = st.columns(2)
366
  with col_1:
367
  st.button("Add Files", on_click=add_file)
368
  with col_2:
369
+ st.button("Clear Files and All Tools", on_click=clear_files)
370
+
 
 
 
 
 
 
 
 
371
  st.button("Clear Chat History", on_click=clear_history)
372
  st.button("Logout", on_click=back_to_main)
373
+ if "agent" not in st.session_state:
374
  refresh_agent()
375
  print("!!! ", st.session_state.agent.memory.chat_memory.session_id)
376
  for msg in st.session_state.agent.memory.chat_memory.messages:
 
383
  st.write("Retrieved from knowledge base:")
384
  try:
385
  st.dataframe(
386
+ pd.DataFrame.from_records(
387
+ json.loads(msg.content, cls=CustomJSONDecoder)
388
+ ),
389
+ use_container_width=True,
390
  )
391
  except:
392
  st.write(msg.content)
helper.py β†’ lib/helper.py RENAMED
@@ -49,10 +49,12 @@ from langchain.memory import SQLChatMessageHistory
49
  from langchain.memory.chat_message_histories.sql import \
50
  BaseMessageConverter, DefaultMessageConverter
51
  from langchain.schema.messages import BaseMessage, _message_to_dict, messages_from_dict
52
- from langchain.agents.agent_toolkits import create_retriever_tool
53
  from prompts.arxiv_prompt import combine_prompt_template, _myscale_prompt
54
  from chains.arxiv_chains import ArXivQAwithSourcesChain, ArXivStuffDocumentChain
55
  from chains.arxiv_chains import VectorSQLRetrieveCustomOutputParser
 
 
56
  environ['TOKENIZERS_PARALLELISM'] = 'true'
57
  environ['OPENAI_API_BASE'] = st.secrets['OPENAI_API_BASE']
58
 
@@ -495,7 +497,7 @@ def create_retriever_tool(
495
  def wrap(func):
496
  def wrapped_retrieve(*args, **kwargs):
497
  docs: List[Document] = func(*args, **kwargs)
498
- return json.dumps([d.dict() for d in docs])
499
  return wrapped_retrieve
500
 
501
  return Tool(
@@ -533,12 +535,13 @@ def build_agents(session_id, tool_names, chat_model_name=chat_model_name, temper
533
  chat_llm = ChatOpenAI(model_name=chat_model_name, temperature=temperature,
534
  openai_api_base=OPENAI_API_BASE, openai_api_key=OPENAI_API_KEY, streaming=True,
535
  )
536
- tools = [st.session_state.tools[k] for k in tool_names]
 
537
  agent = create_agent_executor(
538
  "chat_memory",
539
  session_id,
540
  chat_llm,
541
- tools=tools,
542
  system_prompt=system_prompt
543
  )
544
  return agent
 
49
  from langchain.memory.chat_message_histories.sql import \
50
  BaseMessageConverter, DefaultMessageConverter
51
  from langchain.schema.messages import BaseMessage, _message_to_dict, messages_from_dict
52
+ # from langchain.agents.agent_toolkits import create_retriever_tool
53
  from prompts.arxiv_prompt import combine_prompt_template, _myscale_prompt
54
  from chains.arxiv_chains import ArXivQAwithSourcesChain, ArXivStuffDocumentChain
55
  from chains.arxiv_chains import VectorSQLRetrieveCustomOutputParser
56
+ from .json_conv import CustomJSONEncoder
57
+
58
  environ['TOKENIZERS_PARALLELISM'] = 'true'
59
  environ['OPENAI_API_BASE'] = st.secrets['OPENAI_API_BASE']
60
 
 
497
  def wrap(func):
498
  def wrapped_retrieve(*args, **kwargs):
499
  docs: List[Document] = func(*args, **kwargs)
500
+ return json.dumps([d.dict() for d in docs], cls=CustomJSONEncoder)
501
  return wrapped_retrieve
502
 
503
  return Tool(
 
535
  chat_llm = ChatOpenAI(model_name=chat_model_name, temperature=temperature,
536
  openai_api_base=OPENAI_API_BASE, openai_api_key=OPENAI_API_KEY, streaming=True,
537
  )
538
+ tools = st.session_state.tools if "tools_with_users" not in st.session_state else st.session_state.tools_with_users
539
+ sel_tools = [tools[k] for k in tool_names]
540
  agent = create_agent_executor(
541
  "chat_memory",
542
  session_id,
543
  chat_llm,
544
+ tools=sel_tools,
545
  system_prompt=system_prompt
546
  )
547
  return agent
lib/json_conv.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import datetime
3
+
4
+ class CustomJSONEncoder(json.JSONEncoder):
5
+ def default(self, obj):
6
+ if isinstance(obj, datetime.datetime):
7
+ return datetime.datetime.isoformat(obj)
8
+ return json.JSONEncoder.default(self, obj)
9
+
10
+ class CustomJSONDecoder(json.JSONDecoder):
11
+ def __init__(self, *args, **kwargs):
12
+ json.JSONDecoder.__init__(self, object_hook=self.object_hook, *args, **kwargs)
13
+
14
+ def object_hook(self, source):
15
+ for k, v in source.items():
16
+ if isinstance(v, str):
17
+ try:
18
+ source[k] = datetime.datetime.fromisoformat(str(v))
19
+ except:
20
+ pass
21
+ return source
lib/private_kb.py CHANGED
@@ -1,18 +1,19 @@
1
  import pandas as pd
2
  import hashlib
3
  import requests
4
- from typing import List
5
  from datetime import datetime
6
  from langchain.schema.embeddings import Embeddings
7
  from streamlit.runtime.uploaded_file_manager import UploadedFile
8
  from clickhouse_connect import get_client
9
  from multiprocessing.pool import ThreadPool
10
  from langchain.vectorstores.myscale import MyScaleWithoutJSON, MyScaleSettings
 
11
 
12
  parser_url = "https://api.unstructured.io/general/v0/general"
13
 
14
 
15
- def parse_files(api_key, user_id, files: List[UploadedFile], collection="default"):
16
  def parse_file(file: UploadedFile):
17
  headers = {
18
  "accept": "application/json",
@@ -31,9 +32,10 @@ def parse_files(api_key, user_id, files: List[UploadedFile], collection="default
31
  {
32
  "text": t["text"],
33
  "file_name": t["metadata"]["filename"],
34
- "entity_id": hashlib.sha256((file_hash + t["text"]).encode()).hexdigest(),
 
 
35
  "user_id": user_id,
36
- "collection_id": collection,
37
  "created_by": datetime.now(),
38
  }
39
  for t in json_response
@@ -43,7 +45,7 @@ def parse_files(api_key, user_id, files: List[UploadedFile], collection="default
43
 
44
  with ThreadPool(8) as p:
45
  rows = []
46
- for r in map(parse_file, files):
47
  rows.extend(r)
48
  return rows
49
 
@@ -68,21 +70,33 @@ class PrivateKnowledgeBase:
68
  parser_api_key,
69
  db="chat",
70
  kb_table="private_kb",
 
71
  ) -> None:
72
  super().__init__()
73
- schema_ = f"""
74
  CREATE TABLE IF NOT EXISTS {db}.{kb_table}(
75
  entity_id String,
76
  file_name String,
77
  text String,
78
  user_id String,
79
- collection_id String,
80
  created_by DateTime,
81
  vector Array(Float32),
82
  CONSTRAINT cons_vec_len CHECK length(vector) = 768,
83
  VECTOR INDEX vidx vector TYPE MSTG('metric_type=Cosine')
84
  ) ENGINE = ReplacingMergeTree ORDER BY entity_id
85
  """
 
 
 
 
 
 
 
 
 
 
 
 
86
  config = MyScaleSettings(
87
  host=host,
88
  port=port,
@@ -98,41 +112,101 @@ class PrivateKnowledgeBase:
98
  password=config.password,
99
  )
100
  client.command("SET allow_experimental_object_type=1")
101
- client.command(schema_)
 
102
  self.parser_api_key = parser_api_key
103
  self.vstore = MyScaleWithoutJSON(
104
  embedding=embedding,
105
  config=config,
106
- must_have_cols=["file_name", "text", "create_by"],
107
  )
108
- self.retriever = self.vstore.as_retriever()
109
 
110
- def list_files(self, user_id):
111
  query = f"""
112
- SELECT DISTINCT file_name FROM {self.vstore.config.database}.{self.vstore.config.table}
113
- WHERE user_id = '{user_id}'
 
 
114
  """
115
  return [r for r in self.vstore.client.query(query).named_results()]
116
 
117
  def add_by_file(
118
- self, user_id, files: List[UploadedFile], collection="default", **kwargs
119
  ):
120
- data = parse_files(self.parser_api_key, user_id, files, collection=collection)
121
  data = extract_embedding(self.vstore.embeddings, data)
122
  self.vstore.client.insert_df(
123
- self.vstore.config.table,
124
  pd.DataFrame(data),
125
  database=self.vstore.config.database,
126
  )
127
 
128
  def clear(self, user_id):
129
  self.vstore.client.command(
130
- f"DELETE FROM {self.vstore.config.database}.{self.vstore.config.table} "
131
  f"WHERE user_id='{user_id}'"
132
  )
 
 
 
133
 
134
- def _get_relevant_documents(self, query, *args, **kwargs):
135
- return self.retriever._get_relevant_documents(query, *args, **kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
 
137
- async def _aget_relevant_documents(self, *args, **kwargs):
138
- return self.retriever._aget_relevant_documents(*args, **kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import pandas as pd
2
  import hashlib
3
  import requests
4
+ from typing import List, Optional
5
  from datetime import datetime
6
  from langchain.schema.embeddings import Embeddings
7
  from streamlit.runtime.uploaded_file_manager import UploadedFile
8
  from clickhouse_connect import get_client
9
  from multiprocessing.pool import ThreadPool
10
  from langchain.vectorstores.myscale import MyScaleWithoutJSON, MyScaleSettings
11
+ from .helper import create_retriever_tool
12
 
13
  parser_url = "https://api.unstructured.io/general/v0/general"
14
 
15
 
16
+ def parse_files(api_key, user_id, files: List[UploadedFile]):
17
  def parse_file(file: UploadedFile):
18
  headers = {
19
  "accept": "application/json",
 
32
  {
33
  "text": t["text"],
34
  "file_name": t["metadata"]["filename"],
35
+ "entity_id": hashlib.sha256(
36
+ (file_hash + t["text"]).encode()
37
+ ).hexdigest(),
38
  "user_id": user_id,
 
39
  "created_by": datetime.now(),
40
  }
41
  for t in json_response
 
45
 
46
  with ThreadPool(8) as p:
47
  rows = []
48
+ for r in p.imap_unordered(parse_file, files):
49
  rows.extend(r)
50
  return rows
51
 
 
70
  parser_api_key,
71
  db="chat",
72
  kb_table="private_kb",
73
+ tool_table="private_tool",
74
  ) -> None:
75
  super().__init__()
76
+ kb_schema_ = f"""
77
  CREATE TABLE IF NOT EXISTS {db}.{kb_table}(
78
  entity_id String,
79
  file_name String,
80
  text String,
81
  user_id String,
 
82
  created_by DateTime,
83
  vector Array(Float32),
84
  CONSTRAINT cons_vec_len CHECK length(vector) = 768,
85
  VECTOR INDEX vidx vector TYPE MSTG('metric_type=Cosine')
86
  ) ENGINE = ReplacingMergeTree ORDER BY entity_id
87
  """
88
+ tool_schema_ = f"""
89
+ CREATE TABLE IF NOT EXISTS {db}.{tool_table}(
90
+ tool_id String,
91
+ tool_name String,
92
+ file_names Array(String),
93
+ user_id String,
94
+ created_by DateTime,
95
+ tool_description String
96
+ ) ENGINE = ReplacingMergeTree ORDER BY tool_id
97
+ """
98
+ self.kb_table = kb_table
99
+ self.tool_table = tool_table
100
  config = MyScaleSettings(
101
  host=host,
102
  port=port,
 
112
  password=config.password,
113
  )
114
  client.command("SET allow_experimental_object_type=1")
115
+ client.command(kb_schema_)
116
+ client.command(tool_schema_)
117
  self.parser_api_key = parser_api_key
118
  self.vstore = MyScaleWithoutJSON(
119
  embedding=embedding,
120
  config=config,
121
+ must_have_cols=["file_name", "text", "created_by"],
122
  )
 
123
 
124
+ def list_files(self, user_id, tool_name=None):
125
  query = f"""
126
+ SELECT DISTINCT file_name, COUNT(entity_id) AS num_paragraph,
127
+ arrayMax(arrayMap(x->length(x), groupArray(text))) AS max_chars
128
+ FROM {self.vstore.config.database}.{self.kb_table}
129
+ WHERE user_id = '{user_id}' GROUP BY file_name
130
  """
131
  return [r for r in self.vstore.client.query(query).named_results()]
132
 
133
  def add_by_file(
134
+ self, user_id, files: List[UploadedFile], **kwargs
135
  ):
136
+ data = parse_files(self.parser_api_key, user_id, files)
137
  data = extract_embedding(self.vstore.embeddings, data)
138
  self.vstore.client.insert_df(
139
+ self.kb_table,
140
  pd.DataFrame(data),
141
  database=self.vstore.config.database,
142
  )
143
 
144
  def clear(self, user_id):
145
  self.vstore.client.command(
146
+ f"DELETE FROM {self.vstore.config.database}.{self.kb_table} "
147
  f"WHERE user_id='{user_id}'"
148
  )
149
+ query = f"""DELETE FROM {self.vstore.config.database}.{self.tool_table}
150
+ WHERE user_id = '{user_id}'"""
151
+ self.vstore.client.command(query)
152
 
153
+ def create_tool(
154
+ self, user_id, tool_name, tool_description, files: Optional[List[str]] = None
155
+ ):
156
+ self.vstore.client.insert_df(
157
+ self.tool_table,
158
+ pd.DataFrame(
159
+ [
160
+ {
161
+ "tool_id": hashlib.sha256(
162
+ (user_id + tool_name).encode("utf-8")
163
+ ).hexdigest(),
164
+ "tool_name": tool_name,
165
+ "file_names": files,
166
+ "user_id": user_id,
167
+ "created_by": datetime.now(),
168
+ "tool_description": tool_description,
169
+ }
170
+ ]
171
+ ),
172
+ database=self.vstore.config.database,
173
+ )
174
 
175
+ def list_tools(self, user_id, tool_name=None):
176
+ extended_where = f"AND tool_name = '{tool_name}'" if tool_name else ""
177
+ query = f"""
178
+ SELECT tool_name, tool_description, length(file_names)
179
+ FROM {self.vstore.config.database}.{self.tool_table}
180
+ WHERE user_id = '{user_id}' {extended_where}
181
+ """
182
+ return [r for r in self.vstore.client.query(query).named_results()]
183
+
184
+ def remove_tools(self, user_id, tool_names):
185
+ tool_names = ",".join([f"'{t}'" for t in tool_names])
186
+ query = f"""DELETE FROM {self.vstore.config.database}.{self.tool_table}
187
+ WHERE user_id = '{user_id}' AND tool_name IN [{tool_names}]"""
188
+ self.vstore.client.command(query)
189
+
190
+ def as_tools(self, user_id, tool_name=None):
191
+ tools = self.list_tools(user_id=user_id, tool_name=tool_name)
192
+ retrievers = {
193
+ t["tool_name"]: create_retriever_tool(
194
+ self.vstore.as_retriever(
195
+ search_kwargs={
196
+ "where_str": (
197
+ f"user_id='{user_id}' "
198
+ f"""AND file_name IN (
199
+ SELECT arrayJoin(file_names) FROM (
200
+ SELECT file_names
201
+ FROM {self.vstore.config.database}.{self.tool_table}
202
+ WHERE user_id = '{user_id}' AND tool_name = '{t['tool_name']}')
203
+ )"""
204
+ )
205
+ },
206
+ ),
207
+ name=t["tool_name"],
208
+ description=t["tool_description"],
209
+ )
210
+ for t in tools
211
+ }
212
+ return retrievers
lib/sessions.py CHANGED
@@ -8,7 +8,6 @@ from datetime import datetime
8
  from sqlalchemy import Column, Text, orm, create_engine
9
  from clickhouse_sqlalchemy import types, engines
10
  from .schemas import create_message_model, create_session_table
11
- from .private_kb import PrivateKnowledgeBase
12
 
13
  def get_sessions(engine, model_class, user_id):
14
  with orm.sessionmaker(engine)() as session:
 
8
  from sqlalchemy import Column, Text, orm, create_engine
9
  from clickhouse_sqlalchemy import types, engines
10
  from .schemas import create_message_model, create_session_table
 
11
 
12
  def get_sessions(engine, model_class, user_id):
13
  with orm.sessionmaker(engine)() as session: