Fangrui Liu committed
Commit · c6f6149
Parent(s): 04f0bde
add knowledge base management

Files changed:
- app.py +1 -1
- callbacks/arxiv_callbacks.py +64 -40
- chat.py +173 -42
- helper.py → lib/helper.py +7 -4
- lib/json_conv.py +21 -0
- lib/private_kb.py +95 -21
- lib/sessions.py +0 -1
app.py
CHANGED
@@ -10,7 +10,7 @@ from callbacks.arxiv_callbacks import ChatDataSelfSearchCallBackHandler, \
 
 from chat import chat_page
 from login import login, back_to_main
-from helper import build_tools,
+from lib.helper import build_tools, build_all, sel_map, display
 
 
 
callbacks/arxiv_callbacks.py
CHANGED
@@ -3,70 +3,79 @@ import json
 import textwrap
 from typing import Dict, Any, List
 from sql_formatter.core import format_sql
-from langchain.callbacks.streamlit.streamlit_callback_handler import
+from langchain.callbacks.streamlit.streamlit_callback_handler import (
+    LLMThought,
+    StreamlitCallbackHandler,
+)
 from langchain.schema.output import LLMResult
 from streamlit.delta_generator import DeltaGenerator
 
-
+
 class ChatDataSelfSearchCallBackHandler(StreamlitCallbackHandler):
     def __init__(self) -> None:
         self.progress_bar = st.progress(value=0.0, text="Working...")
         self.tokens_stream = ""
-
+
     def on_llm_start(self, serialized, prompts, **kwargs) -> None:
         pass
-
+
     def on_text(self, text: str, **kwargs) -> None:
         self.progress_bar.progress(value=0.2, text="Asking LLM...")
-
+
     def on_chain_end(self, outputs, **kwargs) -> None:
-        self.progress_bar.progress(value=0.6, text=
-        if
-            st.markdown(
+        self.progress_bar.progress(value=0.6, text="Searching in DB...")
+        if "repr" in outputs:
+            st.markdown("### Generated Filter")
             st.markdown(f"```python\n{outputs['repr']}\n```", unsafe_allow_html=True)
-
+
     def on_chain_start(self, serialized, inputs, **kwargs) -> None:
         pass
 
+
 class ChatDataSelfAskCallBackHandler(StreamlitCallbackHandler):
     def __init__(self) -> None:
-        self.progress_bar = st.progress(value=0.0, text=
+        self.progress_bar = st.progress(value=0.0, text="Searching DB...")
        self.status_bar = st.empty()
        self.prog_value = 0.0
        self.prog_map = {
-
-
-
+            "langchain.chains.qa_with_sources.retrieval.RetrievalQAWithSourcesChain": 0.2,
+            "langchain.chains.combine_documents.map_reduce.MapReduceDocumentsChain": 0.4,
+            "langchain.chains.combine_documents.stuff.StuffDocumentsChain": 0.8,
        }
 
     def on_llm_start(self, serialized, prompts, **kwargs) -> None:
         pass
-
+
     def on_text(self, text: str, **kwargs) -> None:
         pass
-
+
     def on_chain_start(self, serialized, inputs, **kwargs) -> None:
-        cid =
-        if cid !=
-            self.progress_bar.progress(
+        cid = ".".join(serialized["id"])
+        if cid != "langchain.chains.llm.LLMChain":
+            self.progress_bar.progress(
+                value=self.prog_map[cid], text=f"Running Chain `{cid}`..."
+            )
             self.prog_value = self.prog_map[cid]
         else:
             self.prog_value += 0.1
-            self.progress_bar.progress(
+            self.progress_bar.progress(
+                value=self.prog_value, text=f"Running Chain `{cid}`..."
+            )
 
     def on_chain_end(self, outputs, **kwargs) -> None:
         pass
-
+
 
 class ChatDataSQLSearchCallBackHandler(StreamlitCallbackHandler):
     def __init__(self) -> None:
-        self.progress_bar = st.progress(value=0.0, text=
+        self.progress_bar = st.progress(value=0.0, text="Writing SQL...")
         self.status_bar = st.empty()
         self.prog_value = 0
         self.prog_interval = 0.2
 
     def on_llm_start(self, serialized, prompts, **kwargs) -> None:
         pass
-
+
     def on_llm_end(
         self,
         response: LLMResult,
@@ -74,41 +83,56 @@ class ChatDataSQLSearchCallBackHandler(StreamlitCallbackHandler):
         **kwargs,
     ):
         text = response.generations[0][0].text
-        if text.replace(
-            st.write(
-            st.markdown(f
+        if text.replace(" ", "").upper().startswith("SELECT"):
+            st.write("We generated Vector SQL for you:")
+            st.markdown(f"""```sql\n{format_sql(text, max_len=80)}\n```""")
         print(f"Vector SQL: {text}")
         self.prog_value += self.prog_interval
         self.progress_bar.progress(value=self.prog_value, text="Searching in DB...")
-
+
     def on_chain_start(self, serialized, inputs, **kwargs) -> None:
-        cid =
+        cid = ".".join(serialized["id"])
         self.prog_value += self.prog_interval
-        self.progress_bar.progress(
-
+        self.progress_bar.progress(
+            value=self.prog_value, text=f"Running Chain `{cid}`..."
+        )
+
     def on_chain_end(self, outputs, **kwargs) -> None:
         pass
-
+
+
 class ChatDataSQLAskCallBackHandler(ChatDataSQLSearchCallBackHandler):
     def __init__(self) -> None:
-        self.progress_bar = st.progress(value=0.0, text=
+        self.progress_bar = st.progress(value=0.0, text="Writing SQL...")
         self.status_bar = st.empty()
         self.prog_value = 0
         self.prog_interval = 0.1
-
-
+
+
 class LLMThoughtWithKB(LLMThought):
-    def on_tool_end(
+    def on_tool_end(
+        self,
+        output: str,
+        color=None,
+        observation_prefix=None,
+        llm_prefix=None,
+        **kwargs: Any,
+    ) -> None:
         try:
-            self._container.markdown(
-
-
+            self._container.markdown(
+                "\n\n".join(
+                    ["### Retrieved Documents:"]
+                    + [
+                        f"**{i+1}**: {textwrap.shorten(r['page_content'], width=80)}"
+                        for i, r in enumerate(json.loads(output))
+                    ]
+                )
+            )
         except Exception as e:
             super().on_tool_end(output, color, observation_prefix, llm_prefix, **kwargs)
-
-
+
+
 class ChatDataAgentCallBackHandler(StreamlitCallbackHandler):
-
     def on_llm_start(
         self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
     ) -> None:
@@ -120,4 +144,4 @@ class ChatDataAgentCallBackHandler(StreamlitCallbackHandler):
             labeler=self._thought_labeler,
         )
 
-        self._current_thought.on_llm_start(serialized, prompts)
+        self._current_thought.on_llm_start(serialized, prompts)
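Usage note (not part of the commit): these handlers are ordinary LangChain callback handlers, so they are attached through the `callbacks` argument of a chain or agent call and render their progress bars into the running Streamlit app. A minimal sketch under that assumption; the LLMChain below is a hypothetical stand-in for the Vector SQL chain the app builds elsewhere:

import streamlit as st
from langchain.chains import LLMChain
from langchain.chat_models import ChatOpenAI
from langchain.prompts import PromptTemplate
from callbacks.arxiv_callbacks import ChatDataSQLSearchCallBackHandler

# Hypothetical stand-in chain; requires an OpenAI key to actually run.
chain = LLMChain(
    llm=ChatOpenAI(),
    prompt=PromptTemplate.from_template("Write a SQL query answering: {question}"),
)

handler = ChatDataSQLSearchCallBackHandler()  # draws st.progress(0.0, "Writing SQL...")
result = chain({"question": "papers about vector databases"}, callbacks=[handler])
st.write(result)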
chat.py
CHANGED
@@ -1,3 +1,4 @@
+import json
 import pandas as pd
 from os import environ
 from time import sleep
@@ -7,9 +8,12 @@ from lib.sessions import SessionManager
 from lib.private_kb import PrivateKnowledgeBase
 from langchain.schema import HumanMessage, FunctionMessage
 from callbacks.arxiv_callbacks import ChatDataAgentCallBackHandler
-from langchain.callbacks.streamlit.streamlit_callback_handler import
+from langchain.callbacks.streamlit.streamlit_callback_handler import (
+    StreamlitCallbackHandler,
+)
+from lib.json_conv import CustomJSONDecoder
 
-from helper import (
+from lib.helper import (
     build_agents,
     MYSCALE_HOST,
     MYSCALE_PASSWORD,
@@ -30,12 +34,16 @@ TOOL_NAMES = {
 
 def on_chat_submit():
     with st.session_state.next_round.container():
-        with st.chat_message(
+        with st.chat_message("user"):
             st.write(st.session_state.chat_input)
-        with st.chat_message(
+        with st.chat_message("assistant"):
             container = st.container()
-            st_callback = ChatDataAgentCallBackHandler(
-
+            st_callback = ChatDataAgentCallBackHandler(
+                container, collapse_completed_thoughts=False
+            )
+            ret = st.session_state.agent(
+                {"input": st.session_state.chat_input}, callbacks=[st_callback]
+            )
             print(ret)
 
 
@@ -105,7 +113,10 @@ def refresh_sessions():
     st.session_state[
         "current_sessions"
     ] = st.session_state.session_manager.list_sessions(st.session_state.user_name)
-    if
+    if (
+        type(st.session_state.current_sessions) is not dict
+        and len(st.session_state.current_sessions) <= 0
+    ):
         st.session_state.session_manager.add_session(
             st.session_state.user_name,
             f"{st.session_state.user_name}?default",
@@ -114,14 +125,64 @@ def refresh_sessions():
     st.session_state[
         "current_sessions"
     ] = st.session_state.session_manager.list_sessions(st.session_state.user_name)
-
+    st.session_state["user_files"] = st.session_state.private_kb.list_files(
+        st.session_state.user_name
+    )
+    st.session_state["user_tools"] = st.session_state.private_kb.list_tools(
+        st.session_state.user_name
+    )
+    st.session_state["tools_with_users"] = {
+        **st.session_state.tools,
+        **st.session_state.private_kb.as_tools(st.session_state.user_name),
+    }
     try:
-        dfl_indx = [x["session_id"] for x in st.session_state.current_sessions].index(
+        dfl_indx = [x["session_id"] for x in st.session_state.current_sessions].index(
+            "default"
+            if "" not in st.session_state
+            else st.session_state.sel_session["session_id"]
+        )
     except ValueError:
         dfl_indx = 0
     st.session_state.sel_sess = st.session_state.current_sessions[dfl_indx]
 
 
+def build_kb_as_tool():
+    if (
+        "b_tool_name" in st.session_state
+        and "b_tool_desc" in st.session_state
+        and "b_tool_files" in st.session_state
+        and len(st.session_state.b_tool_name) > 0
+        and len(st.session_state.b_tool_desc) > 0
+        and len(st.session_state.b_tool_files) > 0
+    ):
+        st.session_state.private_kb.create_tool(
+            st.session_state.user_name,
+            st.session_state.b_tool_name,
+            st.session_state.b_tool_desc,
+            [f["file_name"] for f in st.session_state.b_tool_files],
+        )
+        refresh_sessions()
+    else:
+        st.session_state.tool_status.error(
+            "You should fill all fields to build up a tool!"
+        )
+        sleep(2)
+
+
+def remove_kb():
+    if "r_tool_names" in st.session_state and len(st.session_state.r_tool_names) > 0:
+        st.session_state.private_kb.remove_tools(
+            st.session_state.user_name,
+            [f["tool_name"] for f in st.session_state.r_tool_names],
+        )
+        refresh_sessions()
+    else:
+        st.session_state.tool_status.error(
+            "You should specify at least one tool to delete!"
+        )
+        sleep(2)
+
+
 def refresh_agent():
     with st.spinner("Initializing session..."):
         print(
@@ -138,22 +199,29 @@ def refresh_agent():
            else st.session_state.sel_sess["system_prompt"],
        )
 
+
 def add_file():
-    if
+    if (
+        "uploaded_files" not in st.session_state
+        or len(st.session_state.uploaded_files) == 0
+    ):
         st.session_state.tool_status.error("Please upload files!", icon="⚠️")
         sleep(2)
         return
     try:
         st.session_state.tool_status.info("Uploading...")
-
-
-
+        st.session_state.private_kb.add_by_file(
+            st.session_state.user_name, st.session_state.uploaded_files
+        )
+        refresh_sessions()
     except ValueError as e:
         st.session_state.tool_status.error("Failed to upload! " + str(e))
         sleep(2)
-
+
+
 def clear_files():
     st.session_state.private_kb.clear(st.session_state.user_name)
+    refresh_sessions()
 
 
 def chat_page():
@@ -168,7 +236,7 @@ def chat_page():
            port=MYSCALE_PORT,
            username=MYSCALE_USER,
            password=MYSCALE_PASSWORD,
-            embedding=st.session_state.embeddings[
+            embedding=st.session_state.embeddings["Wikipedia"],
            parser_api_key=UNSTRUCTURED_API,
        )
    if "session_manager" not in st.session_state:
@@ -177,12 +245,21 @@ def chat_page():
     with st.expander("Session Management"):
         if "current_sessions" not in st.session_state:
             refresh_sessions()
-        st.info(
-
-
-
-
-
+        st.info(
+            "Here you can set up your session! \n\nYou can **change your prompt** here!",
+            icon="π€",
+        )
+        st.info(
+            (
+                "**Add columns by clicking the empty row**.\n"
+                "And **delete columns by selecting rows with a press on `DEL` Key**"
+            ),
+            icon="π‘",
+        )
+        st.info(
+            "Don't forget to **click `Submit Change` to save your change**!",
+            icon="π",
+        )
         st.data_editor(
             st.session_state.current_sessions,
             num_rows="dynamic",
@@ -191,12 +268,18 @@ def chat_page():
         )
         st.button("Submit Change!", on_click=on_session_change_submit)
     with st.expander("Session Selection", expanded=True):
-        st.info(
-
+        st.info(
+            "If no session is attach to your account, then we will add a default session to you!",
+            icon="❤️",
+        )
         try:
             dfl_indx = [
                 x["session_id"] for x in st.session_state.current_sessions
-            ].index(
+            ].index(
+                "default"
+                if "" not in st.session_state
+                else st.session_state.sel_session["session_id"]
+            )
         except Exception as e:
             print("*** ", str(e))
             dfl_indx = 0
@@ -210,39 +293,84 @@ def chat_page():
         )
         print(st.session_state.sel_sess)
     with st.expander("Tool Settings", expanded=True):
-        st.info(
-
+        st.info(
+            "We provides you several knowledge base tools for you. We are building more tools!",
+            icon="π§",
+        )
         st.session_state["tool_status"] = st.empty()
-        tab_kb, tab_file
+        tab_kb, tab_file = st.tabs(
+            [
+                "Knowledge Bases",
+                "File Upload",
+            ]
+        )
         with tab_kb:
+            st.markdown("#### Build You Own Knowledge")
+            st.multiselect(
+                "Select Files to Build up",
+                st.session_state.user_files,
+                placeholder="You should upload files first",
+                key="b_tool_files",
+                format_func=lambda x: x["file_name"],
+            )
+            st.text_input("Tool Name", "get_relevant_documents", key="b_tool_name")
+            st.text_input(
+                "Tool Description",
+                "Searches among user's private files and returns related documents",
+                key="b_tool_desc",
+            )
+            st.button("Build!", on_click=build_kb_as_tool)
+            st.markdown("### Knowledge Base Selection")
+            if (
+                "user_tools" in st.session_state
+                and len(st.session_state.user_tools) > 0
+            ):
+                st.markdown("***User Created Knowledge Bases***")
+                st.dataframe(st.session_state.user_tools)
             st.multiselect(
                 "Select a Knowledge Base Tool",
-                st.session_state.tools.keys()
+                st.session_state.tools.keys()
+                if "tools_with_users" not in st.session_state
+                else st.session_state.tools_with_users,
                 default=["Wikipedia + Self Querying"],
                 key="selected_tools",
                 on_change=refresh_agent,
             )
+            st.markdown("### Delete Knowledge Base")
+            st.multiselect(
+                "Choose Knowledge Base to Remove",
+                st.session_state.user_tools,
+                format_func=lambda x: x["tool_name"],
+                key="r_tool_names",
+            )
+            st.button("Delete", on_click=remove_kb)
         with tab_file:
-            st.
+            st.info(
+                (
+                    "We adopted [Unstructured API](https://unstructured.io/api-key) "
+                    "here and we only store the processed texts from your documents. "
+                    "For privacy concerns, please refer to "
+                    "[our policy issue](https://myscale.com/privacy/)."
+                ),
+                icon="π",
+            )
+            st.file_uploader(
+                "Upload files", key="uploaded_files", accept_multiple_files=True
+            )
             st.markdown("### Uploaded Files")
-            st.dataframe(
+            st.dataframe(
+                st.session_state.private_kb.list_files(st.session_state.user_name),
+                use_container_width=True,
+            )
            col_1, col_2 = st.columns(2)
            with col_1:
                st.button("Add Files", on_click=add_file)
            with col_2:
-                st.button("Clear Files", on_click=clear_files)
-
-            # st.text_input("Give this knowledge base a description:")
-            # col_3, col_4 = st.columns(2)
-            # with col_3:
-            #     st.button("Build Your KB!")
-            # with col_4:
-            #     st.button("Delete Your KB")
-
-
+                st.button("Clear Files and All Tools", on_click=clear_files)
+
     st.button("Clear Chat History", on_click=clear_history)
     st.button("Logout", on_click=back_to_main)
-    if
+    if "agent" not in st.session_state:
         refresh_agent()
     print("!!! ", st.session_state.agent.memory.chat_memory.session_id)
     for msg in st.session_state.agent.memory.chat_memory.messages:
@@ -255,7 +383,10 @@ def chat_page():
             st.write("Retrieved from knowledge base:")
             try:
                 st.dataframe(
-                    pd.DataFrame.from_records(
+                    pd.DataFrame.from_records(
+                        json.loads(msg.content, cls=CustomJSONDecoder)
+                    ),
+                    use_container_width=True,
                )
            except:
                st.write(msg.content)
helper.py → lib/helper.py
RENAMED
@@ -49,10 +49,12 @@ from langchain.memory import SQLChatMessageHistory
 from langchain.memory.chat_message_histories.sql import \
     BaseMessageConverter, DefaultMessageConverter
 from langchain.schema.messages import BaseMessage, _message_to_dict, messages_from_dict
-from langchain.agents.agent_toolkits import create_retriever_tool
+# from langchain.agents.agent_toolkits import create_retriever_tool
 from prompts.arxiv_prompt import combine_prompt_template, _myscale_prompt
 from chains.arxiv_chains import ArXivQAwithSourcesChain, ArXivStuffDocumentChain
 from chains.arxiv_chains import VectorSQLRetrieveCustomOutputParser
+from .json_conv import CustomJSONEncoder
+
 environ['TOKENIZERS_PARALLELISM'] = 'true'
 environ['OPENAI_API_BASE'] = st.secrets['OPENAI_API_BASE']
 
@@ -495,7 +497,7 @@ def create_retriever_tool(
     def wrap(func):
         def wrapped_retrieve(*args, **kwargs):
             docs: List[Document] = func(*args, **kwargs)
-            return json.dumps([d.dict() for d in docs])
+            return json.dumps([d.dict() for d in docs], cls=CustomJSONEncoder)
         return wrapped_retrieve
 
     return Tool(
@@ -533,12 +535,13 @@ def build_agents(session_id, tool_names, chat_model_name=chat_model_name, temper
     chat_llm = ChatOpenAI(model_name=chat_model_name, temperature=temperature,
                           openai_api_base=OPENAI_API_BASE, openai_api_key=OPENAI_API_KEY, streaming=True,
                           )
-    tools =
+    tools = st.session_state.tools if "tools_with_users" not in st.session_state else st.session_state.tools_with_users
+    sel_tools = [tools[k] for k in tool_names]
     agent = create_agent_executor(
         "chat_memory",
         session_id,
         chat_llm,
-        tools=
+        tools=sel_tools,
         system_prompt=system_prompt
     )
     return agent
lib/json_conv.py
ADDED
@@ -0,0 +1,21 @@
+import json
+import datetime
+
+class CustomJSONEncoder(json.JSONEncoder):
+    def default(self, obj):
+        if isinstance(obj, datetime.datetime):
+            return datetime.datetime.isoformat(obj)
+        return json.JSONEncoder.default(self, obj)
+
+class CustomJSONDecoder(json.JSONDecoder):
+    def __init__(self, *args, **kwargs):
+        json.JSONDecoder.__init__(self, object_hook=self.object_hook, *args, **kwargs)
+
+    def object_hook(self, source):
+        for k, v in source.items():
+            if isinstance(v, str):
+                try:
+                    source[k] = datetime.datetime.fromisoformat(str(v))
+                except:
+                    pass
+        return source
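For reference, a quick round-trip with the new helpers (a sketch, not part of the commit): the encoder serializes datetime values to ISO-8601 strings, and the decoder's object_hook converts any ISO-formatted string value back into a datetime, which is what lets retrieved documents with a created_by timestamp survive json.dumps/json.loads between the retriever tool and the chat history.

import json
import datetime
from lib.json_conv import CustomJSONEncoder, CustomJSONDecoder

record = {"file_name": "paper.pdf", "created_by": datetime.datetime.now()}
payload = json.dumps(record, cls=CustomJSONEncoder)     # datetime -> ISO-8601 string
restored = json.loads(payload, cls=CustomJSONDecoder)   # ISO-8601 string -> datetime
assert isinstance(restored["created_by"], datetime.datetime)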
lib/private_kb.py
CHANGED
@@ -1,18 +1,19 @@
 import pandas as pd
 import hashlib
 import requests
-from typing import List
+from typing import List, Optional
 from datetime import datetime
 from langchain.schema.embeddings import Embeddings
 from streamlit.runtime.uploaded_file_manager import UploadedFile
 from clickhouse_connect import get_client
 from multiprocessing.pool import ThreadPool
 from langchain.vectorstores.myscale import MyScaleWithoutJSON, MyScaleSettings
+from .helper import create_retriever_tool
 
 parser_url = "https://api.unstructured.io/general/v0/general"
 
 
-def parse_files(api_key, user_id, files: List[UploadedFile]
+def parse_files(api_key, user_id, files: List[UploadedFile]):
     def parse_file(file: UploadedFile):
         headers = {
             "accept": "application/json",
@@ -31,9 +32,10 @@ def parse_files(api_key, user_id, files: List[UploadedFile], collection="default
             {
                 "text": t["text"],
                 "file_name": t["metadata"]["filename"],
-                "entity_id": hashlib.sha256(
+                "entity_id": hashlib.sha256(
+                    (file_hash + t["text"]).encode()
+                ).hexdigest(),
                 "user_id": user_id,
-                "collection_id": collection,
                 "created_by": datetime.now(),
             }
             for t in json_response
@@ -43,7 +45,7 @@ def parse_files(api_key, user_id, files: List[UploadedFile], collection="default
 
     with ThreadPool(8) as p:
         rows = []
-        for r in
+        for r in p.imap_unordered(parse_file, files):
             rows.extend(r)
     return rows
 
@@ -68,21 +70,33 @@ class PrivateKnowledgeBase:
         parser_api_key,
         db="chat",
         kb_table="private_kb",
+        tool_table="private_tool",
     ) -> None:
         super().__init__()
-
+        kb_schema_ = f"""
        CREATE TABLE IF NOT EXISTS {db}.{kb_table}(
            entity_id String,
            file_name String,
            text String,
            user_id String,
-            collection_id String,
            created_by DateTime,
            vector Array(Float32),
            CONSTRAINT cons_vec_len CHECK length(vector) = 768,
            VECTOR INDEX vidx vector TYPE MSTG('metric_type=Cosine')
        ) ENGINE = ReplacingMergeTree ORDER BY entity_id
        """
+        tool_schema_ = f"""
+        CREATE TABLE IF NOT EXISTS {db}.{tool_table}(
+            tool_id String,
+            tool_name String,
+            file_names Array(String),
+            user_id String,
+            created_by DateTime,
+            tool_description String
+        ) ENGINE = ReplacingMergeTree ORDER BY tool_id
+        """
+        self.kb_table = kb_table
+        self.tool_table = tool_table
         config = MyScaleSettings(
             host=host,
             port=port,
@@ -98,41 +112,101 @@ class PrivateKnowledgeBase:
             password=config.password,
         )
         client.command("SET allow_experimental_object_type=1")
-        client.command(
+        client.command(kb_schema_)
+        client.command(tool_schema_)
         self.parser_api_key = parser_api_key
         self.vstore = MyScaleWithoutJSON(
             embedding=embedding,
             config=config,
-            must_have_cols=["file_name", "text", "
+            must_have_cols=["file_name", "text", "created_by"],
         )
-        self.retriever = self.vstore.as_retriever()
 
-    def list_files(self, user_id):
+    def list_files(self, user_id, tool_name=None):
         query = f"""
-        SELECT DISTINCT file_name
-
+        SELECT DISTINCT file_name, COUNT(entity_id) AS num_paragraph,
+            arrayMax(arrayMap(x->length(x), groupArray(text))) AS max_chars
+        FROM {self.vstore.config.database}.{self.kb_table}
+        WHERE user_id = '{user_id}' GROUP BY file_name
        """
        return [r for r in self.vstore.client.query(query).named_results()]
 
     def add_by_file(
-        self, user_id, files: List[UploadedFile],
+        self, user_id, files: List[UploadedFile], **kwargs
     ):
-        data = parse_files(self.parser_api_key, user_id, files
+        data = parse_files(self.parser_api_key, user_id, files)
         data = extract_embedding(self.vstore.embeddings, data)
         self.vstore.client.insert_df(
-            self.
+            self.kb_table,
             pd.DataFrame(data),
             database=self.vstore.config.database,
         )
 
     def clear(self, user_id):
         self.vstore.client.command(
-            f"DELETE FROM {self.vstore.config.database}.{self.
+            f"DELETE FROM {self.vstore.config.database}.{self.kb_table} "
             f"WHERE user_id='{user_id}'"
         )
+        query = f"""DELETE FROM {self.vstore.config.database}.{self.tool_table}
+                    WHERE user_id = '{user_id}'"""
+        self.vstore.client.command(query)
 
-    def
-
+    def create_tool(
+        self, user_id, tool_name, tool_description, files: Optional[List[str]] = None
+    ):
+        self.vstore.client.insert_df(
+            self.tool_table,
+            pd.DataFrame(
+                [
+                    {
+                        "tool_id": hashlib.sha256(
+                            (user_id + tool_name).encode("utf-8")
+                        ).hexdigest(),
+                        "tool_name": tool_name,
+                        "file_names": files,
+                        "user_id": user_id,
+                        "created_by": datetime.now(),
+                        "tool_description": tool_description,
+                    }
+                ]
+            ),
+            database=self.vstore.config.database,
+        )
 
-
-
+    def list_tools(self, user_id, tool_name=None):
+        extended_where = f"AND tool_name = '{tool_name}'" if tool_name else ""
+        query = f"""
+        SELECT tool_name, tool_description, length(file_names)
+        FROM {self.vstore.config.database}.{self.tool_table}
+        WHERE user_id = '{user_id}' {extended_where}
+        """
+        return [r for r in self.vstore.client.query(query).named_results()]
+
+    def remove_tools(self, user_id, tool_names):
+        tool_names = ",".join([f"'{t}'" for t in tool_names])
+        query = f"""DELETE FROM {self.vstore.config.database}.{self.tool_table}
+                    WHERE user_id = '{user_id}' AND tool_name IN [{tool_names}]"""
+        self.vstore.client.command(query)
+
+    def as_tools(self, user_id, tool_name=None):
+        tools = self.list_tools(user_id=user_id, tool_name=tool_name)
+        retrievers = {
+            t["tool_name"]: create_retriever_tool(
+                self.vstore.as_retriever(
+                    search_kwargs={
+                        "where_str": (
+                            f"user_id='{user_id}' "
+                            f"""AND file_name IN (
+                                SELECT arrayJoin(file_names) FROM (
+                                    SELECT file_names
+                                    FROM {self.vstore.config.database}.{self.tool_table}
+                                    WHERE user_id = '{user_id}' AND tool_name = '{t['tool_name']}')
+                            )"""
+                        )
+                    },
+                ),
+                name=t["tool_name"],
+                description=t["tool_description"],
+            )
+            for t in tools
+        }
+        return retrievers
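Taken together, the new table and methods support roughly the following lifecycle (a sketch, not from the commit; the credentials and the embedding model are assumptions, and in the app these values come from lib.helper, st.secrets and the st.file_uploader widget):

import streamlit as st
from langchain.embeddings import HuggingFaceEmbeddings
from lib.private_kb import PrivateKnowledgeBase

kb = PrivateKnowledgeBase(
    host="msc-xxxxxxxx.aws.myscale.com",   # hypothetical MyScale endpoint
    port=443,
    username="chat_user",                   # hypothetical credentials
    password="********",
    # assumption: any 768-dim Embeddings instance satisfies the vector length constraint
    embedding=HuggingFaceEmbeddings(model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"),
    parser_api_key="UNSTRUCTURED_API_KEY",
)

uploaded_files = st.file_uploader("Upload files", accept_multiple_files=True)
kb.add_by_file("alice", uploaded_files)     # parse via Unstructured, embed, insert into private_kb
kb.create_tool(
    "alice",
    "my_papers",
    "Searches among alice's uploaded papers",
    [f["file_name"] for f in kb.list_files("alice")],
)
tools = kb.as_tools("alice")                # {"my_papers": retriever Tool} merged into tools_with_users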
lib/sessions.py
CHANGED
@@ -8,7 +8,6 @@ from datetime import datetime
 from sqlalchemy import Column, Text, orm, create_engine
 from clickhouse_sqlalchemy import types, engines
 from .schemas import create_message_model, create_session_table
-from .private_kb import PrivateKnowledgeBase
 
 def get_sessions(engine, model_class, user_id):
     with orm.sessionmaker(engine)() as session: