Fangrui Liu committed on
Commit
fab8405
·
1 Parent(s): c5b06e8

update chat memory schema

Browse files
Files changed (5) hide show
  1. app.py +112 -111
  2. callbacks/arxiv_callbacks.py +3 -2
  3. chat.py +28 -181
  4. helper.py +5 -2
  5. login.py +53 -0
app.py CHANGED
@@ -1,125 +1,126 @@
1
- from prompts.arxiv_prompt import combine_prompt_template, _myscale_prompt
2
- from callbacks.arxiv_callbacks import ChatDataSelfSearchCallBackHandler, \
3
- ChatDataSelfAskCallBackHandler, ChatDataSQLSearchCallBackHandler, \
4
- ChatDataSQLAskCallBackHandler
5
- from chains.arxiv_chains import ArXivQAwithSourcesChain, ArXivStuffDocumentChain
6
- from chains.arxiv_chains import VectorSQLRetrieveCustomOutputParser
7
- from langchain_experimental.sql.vector_sql import VectorSQLDatabaseChain
8
- from langchain_experimental.retrievers.vector_sql_database import VectorSQLDatabaseChainRetriever
9
- from langchain.utilities.sql_database import SQLDatabase
10
- from langchain.chains import LLMChain
11
- from sqlalchemy import create_engine, MetaData
12
- from langchain.prompts import PromptTemplate, ChatPromptTemplate, \
13
- SystemMessagePromptTemplate, HumanMessagePromptTemplate
14
- from langchain.prompts.prompt import PromptTemplate
15
- from langchain.chat_models import ChatOpenAI
16
- from langchain import OpenAI
17
- import re
18
  import pandas as pd
19
  from os import environ
20
  import streamlit as st
21
- import datetime
22
- from helper import build_all, sel_map, display
23
- environ['OPENAI_API_BASE'] = st.secrets['OPENAI_API_BASE']
24
 
25
- st.set_page_config(page_title="ChatData")
 
 
 
 
 
 
 
 
 
 
26
 
 
27
  st.header("ChatData")
28
 
29
  if 'retriever' not in st.session_state:
30
  st.session_state["sel_map_obj"] = build_all()
 
31
 
32
- sel = st.selectbox('Choose the knowledge base you want to ask with:',
33
- options=['ArXiv Papers', 'Wikipedia'])
34
- sel_map[sel]['hint']()
35
- tab_sql, tab_self_query = st.tabs(['Vector SQL', 'Self-Query Retrievers'])
36
- with tab_sql:
37
- sel_map[sel]['hint_sql']()
38
- st.text_input("Ask a question:", key='query_sql')
39
- cols = st.columns([1, 1, 7])
40
- cols[0].button("Query", key='search_sql')
41
- cols[1].button("Ask", key='ask_sql')
42
- plc_hldr = st.empty()
43
- if st.session_state.search_sql:
44
- plc_hldr = st.empty()
45
- print(st.session_state.query_sql)
46
- with plc_hldr.expander('Query Log', expanded=True):
47
- callback = ChatDataSQLSearchCallBackHandler()
48
- try:
49
- docs = st.session_state.sel_map_obj[sel]["sql_retriever"].get_relevant_documents(
50
- st.session_state.query_sql, callbacks=[callback])
51
- callback.progress_bar.progress(value=1.0, text="Done!")
52
- docs = pd.DataFrame(
53
- [{**d.metadata, 'abstract': d.page_content} for d in docs])
54
- display(docs)
55
- except Exception as e:
56
- st.write('Oops 😵 Something bad happened...')
57
- raise e
 
 
 
 
 
 
58
 
59
- if st.session_state.ask_sql:
60
- plc_hldr = st.empty()
61
- print(st.session_state.query_sql)
62
- with plc_hldr.expander('Chat Log', expanded=True):
63
- callback = ChatDataSQLAskCallBackHandler()
64
- try:
65
- ret = st.session_state.sel_map_obj[sel]["sql_chain"](
66
- st.session_state.query_sql, callbacks=[callback])
67
- callback.progress_bar.progress(value=1.0, text="Done!")
68
- st.markdown(
69
- f"### Answer from LLM\n{ret['answer']}\n### References")
70
- docs = ret['sources']
71
- docs = pd.DataFrame(
72
- [{**d.metadata, 'abstract': d.page_content} for d in docs])
73
- display(
74
- docs, ['ref_id'] + sel_map[sel]["must_have_cols"], index='ref_id')
75
- except Exception as e:
76
- st.write('Oops 😵 Something bad happened...')
77
- raise e
78
 
79
 
80
- with tab_self_query:
81
- st.info("You can retrieve papers with button `Query` or ask questions based on retrieved papers with button `Ask`.", icon='💡')
82
- st.dataframe(st.session_state.sel_map_obj[sel]["metadata_columns"])
83
- st.text_input("Ask a question:", key='query_self')
84
- cols = st.columns([1, 1, 7])
85
- cols[0].button("Query", key='search_self')
86
- cols[1].button("Ask", key='ask_self')
87
- plc_hldr = st.empty()
88
- if st.session_state.search_self:
89
- plc_hldr = st.empty()
90
- print(st.session_state.query_self)
91
- with plc_hldr.expander('Query Log', expanded=True):
92
- call_back = None
93
- callback = ChatDataSelfSearchCallBackHandler()
94
- try:
95
- docs = st.session_state.sel_map_obj[sel]["retriever"].get_relevant_documents(
96
- st.session_state.query_self, callbacks=[callback])
97
- print(docs)
98
- callback.progress_bar.progress(value=1.0, text="Done!")
99
- docs = pd.DataFrame(
100
- [{**d.metadata, 'abstract': d.page_content} for d in docs])
101
- display(docs, sel_map[sel]["must_have_cols"])
102
- except Exception as e:
103
- st.write('Oops 😵 Something bad happened...')
104
- raise e
 
105
 
106
- if st.session_state.ask_self:
107
- plc_hldr = st.empty()
108
- print(st.session_state.query_self)
109
- with plc_hldr.expander('Chat Log', expanded=True):
110
- call_back = None
111
- callback = ChatDataSelfAskCallBackHandler()
112
- try:
113
- ret = st.session_state.sel_map_obj[sel]["chain"](
114
- st.session_state.query_self, callbacks=[callback])
115
- callback.progress_bar.progress(value=1.0, text="Done!")
116
- st.markdown(
117
- f"### Answer from LLM\n{ret['answer']}\n### References")
118
- docs = ret['sources']
119
- docs = pd.DataFrame(
120
- [{**d.metadata, 'abstract': d.page_content} for d in docs])
121
- display(
122
- docs, ['ref_id'] + sel_map[sel]["must_have_cols"], index='ref_id')
123
- except Exception as e:
124
- st.write('Oops 😵 Something bad happened...')
125
- raise e
 
1
+ import json
2
+ import time
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  import pandas as pd
4
  from os import environ
5
  import streamlit as st
 
 
 
6
 
7
+ from callbacks.arxiv_callbacks import ChatDataSelfSearchCallBackHandler, \
8
+ ChatDataSelfAskCallBackHandler, ChatDataSQLSearchCallBackHandler, \
9
+ ChatDataSQLAskCallBackHandler
10
+
11
+ from chat import chat_page
12
+ from login import login, back_to_main
13
+
14
+
15
+ from helper import build_tools, build_agents, build_all, sel_map, display
16
+
17
+ environ['OPENAI_API_BASE'] = st.secrets['OPENAI_API_BASE']
18
 
19
+ st.set_page_config(page_title="ChatData", page_icon="https://myscale.com/favicon.ico")
20
  st.header("ChatData")
21
 
22
  if 'retriever' not in st.session_state:
23
  st.session_state["sel_map_obj"] = build_all()
24
+ st.session_state["tools"] = build_tools()
25
 
26
+ if login():
27
+ if "user_name" in st.session_state:
28
+ chat_page()
29
+ elif "jump_query_ask" in st.session_state and st.session_state.jump_query_ask:
30
+
31
+ sel = st.selectbox('Choose the knowledge base you want to ask with:',
32
+ options=['ArXiv Papers', 'Wikipedia'])
33
+ sel_map[sel]['hint']()
34
+ tab_sql, tab_self_query = st.tabs(['Vector SQL', 'Self-Query Retrievers'])
35
+ with tab_sql:
36
+ sel_map[sel]['hint_sql']()
37
+ st.text_input("Ask a question:", key='query_sql')
38
+ cols = st.columns([1, 1, 1, 4])
39
+ cols[0].button("Query", key='search_sql')
40
+ cols[1].button("Ask", key='ask_sql')
41
+ cols[2].button("Back", key='back_sql', on_click=back_to_main)
42
+ plc_hldr = st.empty()
43
+ if st.session_state.search_sql:
44
+ plc_hldr = st.empty()
45
+ print(st.session_state.query_sql)
46
+ with plc_hldr.expander('Query Log', expanded=True):
47
+ callback = ChatDataSQLSearchCallBackHandler()
48
+ try:
49
+ docs = st.session_state.sel_map_obj[sel]["sql_retriever"].get_relevant_documents(
50
+ st.session_state.query_sql, callbacks=[callback])
51
+ callback.progress_bar.progress(value=1.0, text="Done!")
52
+ docs = pd.DataFrame(
53
+ [{**d.metadata, 'abstract': d.page_content} for d in docs])
54
+ display(docs)
55
+ except Exception as e:
56
+ st.write('Oops 😵 Something bad happened...')
57
+ raise e
58
 
59
+ if st.session_state.ask_sql:
60
+ plc_hldr = st.empty()
61
+ print(st.session_state.query_sql)
62
+ with plc_hldr.expander('Chat Log', expanded=True):
63
+ callback = ChatDataSQLAskCallBackHandler()
64
+ try:
65
+ ret = st.session_state.sel_map_obj[sel]["sql_chain"](
66
+ st.session_state.query_sql, callbacks=[callback])
67
+ callback.progress_bar.progress(value=1.0, text="Done!")
68
+ st.markdown(
69
+ f"### Answer from LLM\n{ret['answer']}\n### References")
70
+ docs = ret['sources']
71
+ docs = pd.DataFrame(
72
+ [{**d.metadata, 'abstract': d.page_content} for d in docs])
73
+ display(
74
+ docs, ['ref_id'] + sel_map[sel]["must_have_cols"], index='ref_id')
75
+ except Exception as e:
76
+ st.write('Oops 😵 Something bad happened...')
77
+ raise e
78
 
79
 
80
+ with tab_self_query:
81
+ st.info("You can retrieve papers with button `Query` or ask questions based on retrieved papers with button `Ask`.", icon='💡')
82
+ st.dataframe(st.session_state.sel_map_obj[sel]["metadata_columns"])
83
+ st.text_input("Ask a question:", key='query_self')
84
+ cols = st.columns([1, 1, 1, 4])
85
+ cols[0].button("Query", key='search_self')
86
+ cols[1].button("Ask", key='ask_self')
87
+ cols[2].button("Back", key='back_self', on_click=back_to_main)
88
+ plc_hldr = st.empty()
89
+ if st.session_state.search_self:
90
+ plc_hldr = st.empty()
91
+ print(st.session_state.query_self)
92
+ with plc_hldr.expander('Query Log', expanded=True):
93
+ call_back = None
94
+ callback = ChatDataSelfSearchCallBackHandler()
95
+ try:
96
+ docs = st.session_state.sel_map_obj[sel]["retriever"].get_relevant_documents(
97
+ st.session_state.query_self, callbacks=[callback])
98
+ print(docs)
99
+ callback.progress_bar.progress(value=1.0, text="Done!")
100
+ docs = pd.DataFrame(
101
+ [{**d.metadata, 'abstract': d.page_content} for d in docs])
102
+ display(docs, sel_map[sel]["must_have_cols"])
103
+ except Exception as e:
104
+ st.write('Oops 😵 Something bad happened...')
105
+ raise e
106
 
107
+ if st.session_state.ask_self:
108
+ plc_hldr = st.empty()
109
+ print(st.session_state.query_self)
110
+ with plc_hldr.expander('Chat Log', expanded=True):
111
+ call_back = None
112
+ callback = ChatDataSelfAskCallBackHandler()
113
+ try:
114
+ ret = st.session_state.sel_map_obj[sel]["chain"](
115
+ st.session_state.query_self, callbacks=[callback])
116
+ callback.progress_bar.progress(value=1.0, text="Done!")
117
+ st.markdown(
118
+ f"### Answer from LLM\n{ret['answer']}\n### References")
119
+ docs = ret['sources']
120
+ docs = pd.DataFrame(
121
+ [{**d.metadata, 'abstract': d.page_content} for d in docs])
122
+ display(
123
+ docs, ['ref_id'] + sel_map[sel]["must_have_cols"], index='ref_id')
124
+ except Exception as e:
125
+ st.write('Oops 😵 Something bad happened...')
126
+ raise e
callbacks/arxiv_callbacks.py CHANGED
@@ -17,8 +17,9 @@ class ChatDataSelfSearchCallBackHandler(StreamlitCallbackHandler):
17
 
18
  def on_chain_end(self, outputs, **kwargs) -> None:
19
  self.progress_bar.progress(value=0.6, text='Searching in DB...')
20
- st.markdown('### Generated Filter')
21
- st.write(outputs['text'], unsafe_allow_html=True)
 
22
 
23
  def on_chain_start(self, serialized, inputs, **kwargs) -> None:
24
  pass
 
17
 
18
  def on_chain_end(self, outputs, **kwargs) -> None:
19
  self.progress_bar.progress(value=0.6, text='Searching in DB...')
20
+ if 'repr' in outputs:
21
+ st.markdown('### Generated Filter')
22
+ st.markdown(f"```python\n{outputs['repr']}\n```", unsafe_allow_html=True)
23
 
24
  def on_chain_start(self, serialized, inputs, **kwargs) -> None:
25
  pass
chat.py CHANGED
@@ -1,31 +1,14 @@
1
- import json
2
- import time
3
  import pandas as pd
4
  from os import environ
5
  import datetime
6
  import streamlit as st
7
- from langchain.schema import Document
8
 
9
- from callbacks.arxiv_callbacks import ChatDataSelfSearchCallBackHandler, \
10
- ChatDataSelfAskCallBackHandler, ChatDataSQLSearchCallBackHandler, \
11
- ChatDataSQLAskCallBackHandler
12
-
13
- from langchain.schema import BaseMessage, HumanMessage, AIMessage, FunctionMessage, SystemMessage
14
- from auth0_component import login_button
15
-
16
-
17
- from helper import build_tools, build_agents, build_all, sel_map, display
18
 
19
  environ['OPENAI_API_BASE'] = st.secrets['OPENAI_API_BASE']
20
 
21
- st.set_page_config(page_title="ChatData", page_icon="https://myscale.com/favicon.ico")
22
- st.header("ChatData")
23
-
24
-
25
- if 'retriever' not in st.session_state:
26
- st.session_state["sel_map_obj"] = build_all()
27
- st.session_state["tools"] = build_tools()
28
-
29
  def on_chat_submit():
30
  ret = st.session_state.agents[st.session_state.sel][st.session_state.ret_type]({"input": st.session_state.chat_input})
31
  print(ret)
@@ -33,45 +16,7 @@ def on_chat_submit():
33
  def clear_history():
34
  st.session_state.agents[st.session_state.sel][st.session_state.ret_type].memory.clear()
35
 
36
- AUTH0_CLIENT_ID = st.secrets['AUTH0_CLIENT_ID']
37
- AUTH0_DOMAIN = st.secrets['AUTH0_DOMAIN']
38
 
39
- def login():
40
- if "user_name" in st.session_state or ("jump_query_ask" in st.session_state and st.session_state.jump_query_ask):
41
- return True
42
- st.subheader("🤗 Welcom to [MyScale](https://myscale.com)'s [ChatData](https://github.com/myscale/ChatData)! 🤗 ")
43
- st.write("You can now chat with ArXiv and Wikipedia! 🌟\n")
44
- st.write("Built purely with streamlit 👑 , LangChain 🦜🔗 and love ❤️ for AI!")
45
- st.write("Follow us on [Twitter](https://x.com/myscaledb) and [Discord](https://discord.gg/D2qpkqc4Jq)!")
46
- st.write("For more details, please refer to [our repository on GitHub](https://github.com/myscale/ChatData)!")
47
- st.divider()
48
- col1, col2 = st.columns(2, gap='large')
49
- with col1.container():
50
- st.write("Try out MyScale's Self-query and Vector SQL retrievers!")
51
- st.write("In this demo, you will be able to see how those retrievers "
52
- "**digest** -> **translate** -> **retrieve** -> **answer** to your question!")
53
- st.session_state["jump_query_ask"] = st.button("Query / Ask")
54
- with col2.container():
55
- # st.warning("To use chat, please jump to [https://myscale-chatdata.hf.space](https://myscale-chatdata.hf.space)")
56
- st.write("Now with the power of LangChain's Conversantional Agents, we are able to build "
57
- "an RAG-enabled chatbot within one MyScale instance! ")
58
- st.write("Log in to Chat with RAG!")
59
- login_button(AUTH0_CLIENT_ID, AUTH0_DOMAIN, "auth0")
60
- st.divider()
61
- st.write("- [Privacy Policy](https://myscale.com/privacy/)\n"
62
- "- [Terms of Sevice](https://myscale.com/terms/)")
63
- if st.session_state.auth0 is not None:
64
- st.session_state.user_info = dict(st.session_state.auth0)
65
- if 'email' in st.session_state.user_info:
66
- email = st.session_state.user_info["email"]
67
- else:
68
- email = f"{st.session_state.user_info['nickname']}@{st.session_state.user_info['sub']}"
69
- st.session_state["user_name"] = email
70
- del st.session_state.auth0
71
- st.experimental_rerun()
72
- if st.session_state.jump_query_ask:
73
- st.experimental_rerun()
74
-
75
  def back_to_main():
76
  if "user_info" in st.session_state:
77
  del st.session_state.user_info
@@ -80,127 +25,29 @@ def back_to_main():
80
  if "jump_query_ask" in st.session_state:
81
  del st.session_state.jump_query_ask
82
 
83
- if login():
84
- if "user_name" in st.session_state:
85
- st.session_state["agents"] = build_agents(st.session_state.user_name)
86
- with st.sidebar:
87
- st.radio("Retriever Type", ["Self-querying retriever", "Vector SQL"], key="ret_type")
88
- st.selectbox("Knowledge Base", ["ArXiv Papers", "Wikipedia", "ArXiv + Wikipedia"], key="sel")
89
- st.button("Clear Chat History", on_click=clear_history)
90
- st.button("Logout", on_click=back_to_main)
91
- for msg in st.session_state.agents[st.session_state.sel][st.session_state.ret_type].memory.chat_memory.messages:
92
- speaker = "user" if isinstance(msg, HumanMessage) else "assistant"
93
- if isinstance(msg, FunctionMessage):
94
- with st.chat_message("Knowledge Base", avatar="📖"):
95
- print(type(msg.content))
 
 
 
 
 
 
 
 
 
96
  st.write(f"*{datetime.datetime.fromtimestamp(msg.additional_kwargs['timestamp']).isoformat()}*")
97
- st.write("Retrieved from knowledge base:")
98
- try:
99
- st.dataframe(pd.DataFrame.from_records(map(dict, eval(msg.content))))
100
- except:
101
- st.write(msg.content)
102
- else:
103
- if len(msg.content) > 0:
104
- with st.chat_message(speaker):
105
- print(type(msg), msg.dict())
106
- st.write(f"*{datetime.datetime.fromtimestamp(msg.additional_kwargs['timestamp']).isoformat()}*")
107
- st.write(f"{msg.content}")
108
- st.chat_input("Input Message", on_submit=on_chat_submit, key="chat_input")
109
- elif "jump_query_ask" in st.session_state and st.session_state.jump_query_ask:
110
-
111
- sel = st.selectbox('Choose the knowledge base you want to ask with:',
112
- options=['ArXiv Papers', 'Wikipedia'])
113
- sel_map[sel]['hint']()
114
- tab_sql, tab_self_query = st.tabs(['Vector SQL', 'Self-Query Retrievers'])
115
- with tab_sql:
116
- sel_map[sel]['hint_sql']()
117
- st.text_input("Ask a question:", key='query_sql')
118
- cols = st.columns([1, 1, 1, 4])
119
- cols[0].button("Query", key='search_sql')
120
- cols[1].button("Ask", key='ask_sql')
121
- cols[2].button("Back", key='back_sql', on_click=back_to_main)
122
- plc_hldr = st.empty()
123
- if st.session_state.search_sql:
124
- plc_hldr = st.empty()
125
- print(st.session_state.query_sql)
126
- with plc_hldr.expander('Query Log', expanded=True):
127
- callback = ChatDataSQLSearchCallBackHandler()
128
- try:
129
- docs = st.session_state.sel_map_obj[sel]["sql_retriever"].get_relevant_documents(
130
- st.session_state.query_sql, callbacks=[callback])
131
- callback.progress_bar.progress(value=1.0, text="Done!")
132
- docs = pd.DataFrame(
133
- [{**d.metadata, 'abstract': d.page_content} for d in docs])
134
- display(docs)
135
- except Exception as e:
136
- st.write('Oops 😵 Something bad happened...')
137
- raise e
138
-
139
- if st.session_state.ask_sql:
140
- plc_hldr = st.empty()
141
- print(st.session_state.query_sql)
142
- with plc_hldr.expander('Chat Log', expanded=True):
143
- callback = ChatDataSQLAskCallBackHandler()
144
- try:
145
- ret = st.session_state.sel_map_obj[sel]["sql_chain"](
146
- st.session_state.query_sql, callbacks=[callback])
147
- callback.progress_bar.progress(value=1.0, text="Done!")
148
- st.markdown(
149
- f"### Answer from LLM\n{ret['answer']}\n### References")
150
- docs = ret['sources']
151
- docs = pd.DataFrame(
152
- [{**d.metadata, 'abstract': d.page_content} for d in docs])
153
- display(
154
- docs, ['ref_id'] + sel_map[sel]["must_have_cols"], index='ref_id')
155
- except Exception as e:
156
- st.write('Oops 😵 Something bad happened...')
157
- raise e
158
-
159
-
160
- with tab_self_query:
161
- st.info("You can retrieve papers with button `Query` or ask questions based on retrieved papers with button `Ask`.", icon='💡')
162
- st.dataframe(st.session_state.sel_map_obj[sel]["metadata_columns"])
163
- st.text_input("Ask a question:", key='query_self')
164
- cols = st.columns([1, 1, 1, 4])
165
- cols[0].button("Query", key='search_self')
166
- cols[1].button("Ask", key='ask_self')
167
- cols[2].button("Back", key='back_self', on_click=back_to_main)
168
- plc_hldr = st.empty()
169
- if st.session_state.search_self:
170
- plc_hldr = st.empty()
171
- print(st.session_state.query_self)
172
- with plc_hldr.expander('Query Log', expanded=True):
173
- call_back = None
174
- callback = ChatDataSelfSearchCallBackHandler()
175
- try:
176
- docs = st.session_state.sel_map_obj[sel]["retriever"].get_relevant_documents(
177
- st.session_state.query_self, callbacks=[callback])
178
- print(docs)
179
- callback.progress_bar.progress(value=1.0, text="Done!")
180
- docs = pd.DataFrame(
181
- [{**d.metadata, 'abstract': d.page_content} for d in docs])
182
- display(docs, sel_map[sel]["must_have_cols"])
183
- except Exception as e:
184
- st.write('Oops 😵 Something bad happened...')
185
- raise e
186
-
187
- if st.session_state.ask_self:
188
- plc_hldr = st.empty()
189
- print(st.session_state.query_self)
190
- with plc_hldr.expander('Chat Log', expanded=True):
191
- call_back = None
192
- callback = ChatDataSelfAskCallBackHandler()
193
- try:
194
- ret = st.session_state.sel_map_obj[sel]["chain"](
195
- st.session_state.query_self, callbacks=[callback])
196
- callback.progress_bar.progress(value=1.0, text="Done!")
197
- st.markdown(
198
- f"### Answer from LLM\n{ret['answer']}\n### References")
199
- docs = ret['sources']
200
- docs = pd.DataFrame(
201
- [{**d.metadata, 'abstract': d.page_content} for d in docs])
202
- display(
203
- docs, ['ref_id'] + sel_map[sel]["must_have_cols"], index='ref_id')
204
- except Exception as e:
205
- st.write('Oops 😵 Something bad happened...')
206
- raise e
 
 
 
1
  import pandas as pd
2
  from os import environ
3
  import datetime
4
  import streamlit as st
5
+ from langchain.schema import HumanMessage, FunctionMessage
6
 
7
+ from helper import build_agents
8
+ from login import back_to_main
 
 
 
 
 
 
 
9
 
10
  environ['OPENAI_API_BASE'] = st.secrets['OPENAI_API_BASE']
11
 
 
 
 
 
 
 
 
 
12
  def on_chat_submit():
13
  ret = st.session_state.agents[st.session_state.sel][st.session_state.ret_type]({"input": st.session_state.chat_input})
14
  print(ret)
 
16
  def clear_history():
17
  st.session_state.agents[st.session_state.sel][st.session_state.ret_type].memory.clear()
18
 
 
 
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  def back_to_main():
21
  if "user_info" in st.session_state:
22
  del st.session_state.user_info
 
25
  if "jump_query_ask" in st.session_state:
26
  del st.session_state.jump_query_ask
27
 
28
+ def chat_page():
29
+ st.session_state["agents"] = build_agents(f"{st.session_state.user_name}?default")
30
+ with st.sidebar:
31
+ st.radio("Retriever Type", ["Self-querying retriever", "Vector SQL"], key="ret_type")
32
+ st.selectbox("Knowledge Base", ["ArXiv Papers", "Wikipedia", "ArXiv + Wikipedia"], key="sel")
33
+ st.button("Clear Chat History", on_click=clear_history)
34
+ st.button("Logout", on_click=back_to_main)
35
+ for msg in st.session_state.agents[st.session_state.sel][st.session_state.ret_type].memory.chat_memory.messages:
36
+ speaker = "user" if isinstance(msg, HumanMessage) else "assistant"
37
+ if isinstance(msg, FunctionMessage):
38
+ with st.chat_message("Knowledge Base", avatar="📖"):
39
+ print(type(msg.content))
40
+ st.write(f"*{datetime.datetime.fromtimestamp(msg.additional_kwargs['timestamp']).isoformat()}*")
41
+ st.write("Retrieved from knowledge base:")
42
+ try:
43
+ st.dataframe(pd.DataFrame.from_records(map(dict, eval(msg.content))))
44
+ except:
45
+ st.write(msg.content)
46
+ else:
47
+ if len(msg.content) > 0:
48
+ with st.chat_message(speaker):
49
+ print(type(msg), msg.dict())
50
  st.write(f"*{datetime.datetime.fromtimestamp(msg.additional_kwargs['timestamp']).isoformat()}*")
51
+ st.write(f"{msg.content}")
52
+ st.chat_input("Input Message", on_submit=on_chat_submit, key="chat_input")
53
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
helper.py CHANGED
@@ -369,6 +369,7 @@ def create_message_model(table_name, DynamicBase): # type: ignore
369
  __tablename__ = table_name
370
  id = Column(types.Float64)
371
  session_id = Column(Text)
 
372
  msg_id = Column(Text, primary_key=True)
373
  type = Column(Text)
374
  addtionals = Column(Text)
@@ -391,9 +392,11 @@ class DefaultClickhouseMessageConverter(DefaultMessageConverter):
391
  def to_sql_model(self, message: BaseMessage, session_id: str) -> Any:
392
  tstamp = time.time()
393
  msg_id = hashlib.sha256(f"{session_id}_{message}_{tstamp}".encode('utf-8')).hexdigest()
 
394
  return self.model_class(
395
  id=tstamp,
396
  msg_id=msg_id,
 
397
  session_id=session_id,
398
  type=message.type,
399
  addtionals=json.dumps(message.additional_kwargs),
@@ -467,7 +470,7 @@ def build_tools():
467
  return sel_map_obj
468
 
469
  @st.cache_resource(max_entries=1)
470
- def build_agents(username):
471
  chat_llm = ChatOpenAI(model_name=chat_model_name, temperature=0.6, openai_api_base=OPENAI_API_BASE, openai_api_key=OPENAI_API_KEY)
472
  agents = {}
473
  cnt = 0
@@ -484,7 +487,7 @@ def build_agents(username):
484
  agents[k] = {}
485
  agents[k][n] = create_agent_executor(
486
  "chat_memory",
487
- username,
488
  chat_llm,
489
  tools=tools,
490
  )
 
369
  __tablename__ = table_name
370
  id = Column(types.Float64)
371
  session_id = Column(Text)
372
+ user_id = Column(Text)
373
  msg_id = Column(Text, primary_key=True)
374
  type = Column(Text)
375
  addtionals = Column(Text)
 
392
  def to_sql_model(self, message: BaseMessage, session_id: str) -> Any:
393
  tstamp = time.time()
394
  msg_id = hashlib.sha256(f"{session_id}_{message}_{tstamp}".encode('utf-8')).hexdigest()
395
+ user_id, _ = session_id.split("?")
396
  return self.model_class(
397
  id=tstamp,
398
  msg_id=msg_id,
399
+ user_id=user_id,
400
  session_id=session_id,
401
  type=message.type,
402
  addtionals=json.dumps(message.additional_kwargs),
 
470
  return sel_map_obj
471
 
472
  @st.cache_resource(max_entries=1)
473
+ def build_agents(session_id):
474
  chat_llm = ChatOpenAI(model_name=chat_model_name, temperature=0.6, openai_api_base=OPENAI_API_BASE, openai_api_key=OPENAI_API_KEY)
475
  agents = {}
476
  cnt = 0
 
487
  agents[k] = {}
488
  agents[k][n] = create_agent_executor(
489
  "chat_memory",
490
+ session_id,
491
  chat_llm,
492
  tools=tools,
493
  )
login.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import time
3
+ import pandas as pd
4
+ from os import environ
5
+ import streamlit as st
6
+ from auth0_component import login_button
7
+
8
+ AUTH0_CLIENT_ID = st.secrets['AUTH0_CLIENT_ID']
9
+ AUTH0_DOMAIN = st.secrets['AUTH0_DOMAIN']
10
+
11
+ def login():
12
+ if "user_name" in st.session_state or ("jump_query_ask" in st.session_state and st.session_state.jump_query_ask):
13
+ return True
14
+ st.subheader("🤗 Welcom to [MyScale](https://myscale.com)'s [ChatData](https://github.com/myscale/ChatData)! 🤗 ")
15
+ st.write("You can now chat with ArXiv and Wikipedia! 🌟\n")
16
+ st.write("Built purely with streamlit 👑 , LangChain 🦜🔗 and love ❤️ for AI!")
17
+ st.write("Follow us on [Twitter](https://x.com/myscaledb) and [Discord](https://discord.gg/D2qpkqc4Jq)!")
18
+ st.write("For more details, please refer to [our repository on GitHub](https://github.com/myscale/ChatData)!")
19
+ st.divider()
20
+ col1, col2 = st.columns(2, gap='large')
21
+ with col1.container():
22
+ st.write("Try out MyScale's Self-query and Vector SQL retrievers!")
23
+ st.write("In this demo, you will be able to see how those retrievers "
24
+ "**digest** -> **translate** -> **retrieve** -> **answer** to your question!")
25
+ st.session_state["jump_query_ask"] = st.button("Query / Ask")
26
+ with col2.container():
27
+ # st.warning("To use chat, please jump to [https://myscale-chatdata.hf.space](https://myscale-chatdata.hf.space)")
28
+ st.write("Now with the power of LangChain's Conversantional Agents, we are able to build "
29
+ "an RAG-enabled chatbot within one MyScale instance! ")
30
+ st.write("Log in to Chat with RAG!")
31
+ login_button(AUTH0_CLIENT_ID, AUTH0_DOMAIN, "auth0")
32
+ st.divider()
33
+ st.write("- [Privacy Policy](https://myscale.com/privacy/)\n"
34
+ "- [Terms of Sevice](https://myscale.com/terms/)")
35
+ if st.session_state.auth0 is not None:
36
+ st.session_state.user_info = dict(st.session_state.auth0)
37
+ if 'email' in st.session_state.user_info:
38
+ email = st.session_state.user_info["email"]
39
+ else:
40
+ email = f"{st.session_state.user_info['nickname']}@{st.session_state.user_info['sub']}"
41
+ st.session_state["user_name"] = email
42
+ del st.session_state.auth0
43
+ st.experimental_rerun()
44
+ if st.session_state.jump_query_ask:
45
+ st.experimental_rerun()
46
+
47
+ def back_to_main():
48
+ if "user_info" in st.session_state:
49
+ del st.session_state.user_info
50
+ if "user_name" in st.session_state:
51
+ del st.session_state.user_name
52
+ if "jump_query_ask" in st.session_state:
53
+ del st.session_state.jump_query_ask