Fangrui Liu commited on
Commit
19bd5a9
·
1 Parent(s): 45180a0

update chat

Browse files
Files changed (5) hide show
  1. README.md +1 -1
  2. app.py +1 -291
  3. chat.py +204 -0
  4. helper.py +506 -0
  5. requirements.txt +2 -1
README.md CHANGED
@@ -5,7 +5,7 @@ colorFrom: pink
5
  colorTo: purple
6
  sdk: streamlit
7
  sdk_version: 1.20.0
8
- app_file: app.py
9
  pinned: true
10
  license: mit
11
  ---
 
5
  colorTo: purple
6
  sdk: streamlit
7
  sdk_version: 1.20.0
8
+ app_file: chat.py
9
  pinned: true
10
  license: mit
11
  ---
app.py CHANGED
@@ -14,308 +14,18 @@ from langchain.prompts import PromptTemplate, ChatPromptTemplate, \
14
  from langchain.prompts.prompt import PromptTemplate
15
  from langchain.chat_models import ChatOpenAI
16
  from langchain import OpenAI
17
- from langchain.chains.query_constructor.base import AttributeInfo, VirtualColumnName
18
- from langchain.retrievers.self_query.base import SelfQueryRetriever
19
- from langchain.retrievers.self_query.myscale import MyScaleTranslator
20
- from langchain.embeddings import HuggingFaceInstructEmbeddings, SentenceTransformerEmbeddings
21
- from langchain.vectorstores import MyScaleSettings
22
- from chains.arxiv_chains import MyScaleWithoutMetadataJson
23
  import re
24
  import pandas as pd
25
  from os import environ
26
  import streamlit as st
27
  import datetime
28
- environ['TOKENIZERS_PARALLELISM'] = 'true'
29
  environ['OPENAI_API_BASE'] = st.secrets['OPENAI_API_BASE']
30
 
31
-
32
  st.set_page_config(page_title="ChatData")
33
 
34
  st.header("ChatData")
35
 
36
- # query_model_name = "gpt-3.5-turbo-instruct"
37
- query_model_name = "text-davinci-003"
38
- chat_model_name = "gpt-3.5-turbo-16k"
39
-
40
-
41
- def hint_arxiv():
42
- st.info("We provides you metadata columns below for query. Please choose a natural expression to describe filters on those columns.\n\n"
43
- "For example: \n\n"
44
- "*If you want to search papers with complex filters*:\n\n"
45
- "- What is a Bayesian network? Please use articles published later than Feb 2018 and with more than 2 categories and whose title like `computer` and must have `cs.CV` in its category.\n\n"
46
- "*If you want to ask questions based on papers in database*:\n\n"
47
- "- What is PageRank?\n"
48
- "- Did Geoffrey Hinton wrote paper about Capsule Neural Networks?\n"
49
- "- Introduce some applications of GANs published around 2019.\n"
50
- "- 请根据 2019 年左右的文章介绍一下 GAN 的应用都有哪些\n"
51
- "- Veuillez présenter les applications du GAN sur la base des articles autour de 2019 ?\n"
52
- "- Is it possible to synthesize room temperature super conductive material?")
53
-
54
-
55
- def hint_sql_arxiv():
56
- st.info("You can retrieve papers with button `Query` or ask questions based on retrieved papers with button `Ask`.", icon='💡')
57
- st.markdown('''```sql
58
- CREATE TABLE default.ChatArXiv (
59
- `abstract` String,
60
- `id` String,
61
- `vector` Array(Float32),
62
- `metadata` Object('JSON'),
63
- `pubdate` DateTime,
64
- `title` String,
65
- `categories` Array(String),
66
- `authors` Array(String),
67
- `comment` String,
68
- `primary_category` String,
69
- VECTOR INDEX vec_idx vector TYPE MSTG('fp16_storage=1', 'metric_type=Cosine', 'disk_mode=3'),
70
- CONSTRAINT vec_len CHECK length(vector) = 768)
71
- ENGINE = ReplacingMergeTree ORDER BY id
72
- ```''')
73
-
74
-
75
- def hint_wiki():
76
- st.info("We provides you metadata columns below for query. Please choose a natural expression to describe filters on those columns.\n\n"
77
- "For example: \n\n"
78
- "- Which company did Elon Musk found?\n"
79
- "- What is Iron Gwazi?\n"
80
- "- What is a Ring in mathematics?\n"
81
- "- 苹果的发源地是那里?\n")
82
-
83
-
84
- def hint_sql_wiki():
85
- st.info("You can retrieve papers with button `Query` or ask questions based on retrieved papers with button `Ask`.", icon='💡')
86
- st.markdown('''```sql
87
- CREATE TABLE wiki.Wikipedia (
88
- `id` String,
89
- `title` String,
90
- `text` String,
91
- `url` String,
92
- `wiki_id` UInt64,
93
- `views` Float32,
94
- `paragraph_id` UInt64,
95
- `langs` UInt32,
96
- `emb` Array(Float32),
97
- VECTOR INDEX vec_idx emb TYPE MSTG('fp16_storage=1', 'metric_type=Cosine', 'disk_mode=3'),
98
- CONSTRAINT emb_len CHECK length(emb) = 768)
99
- ENGINE = ReplacingMergeTree ORDER BY id
100
- ```''')
101
-
102
-
103
- sel_map = {
104
- 'Wikipedia': {
105
- "database": "wiki",
106
- "table": "Wikipedia",
107
- "hint": hint_wiki,
108
- "hint_sql": hint_sql_wiki,
109
- "doc_prompt": PromptTemplate(
110
- input_variables=["page_content", "url", "title", "ref_id", "views"],
111
- template="Title for Doc #{ref_id}: {title}\n\tviews: {views}\n\tcontent: {page_content}\nSOURCE: {url}"),
112
- "metadata_cols": [
113
- AttributeInfo(
114
- name="title",
115
- description="title of the wikipedia page",
116
- type="string",
117
- ),
118
- AttributeInfo(
119
- name="text",
120
- description="paragraph from this wiki page",
121
- type="string",
122
- ),
123
- AttributeInfo(
124
- name="views",
125
- description="number of views",
126
- type="float"
127
- ),
128
- ],
129
- "must_have_cols": ['id', 'title', 'url', 'text', 'views'],
130
- "vector_col": "emb",
131
- "text_col": "text",
132
- "metadata_col": "metadata",
133
- "emb_model": lambda: SentenceTransformerEmbeddings(
134
- model_name='sentence-transformers/paraphrase-multilingual-mpnet-base-v2',)
135
- },
136
- 'ArXiv Papers': {
137
- "database": "default",
138
- "table": "ChatArXiv",
139
- "hint": hint_arxiv,
140
- "hint_sql": hint_sql_arxiv,
141
- "doc_prompt": PromptTemplate(
142
- input_variables=["page_content", "id", "title", "ref_id",
143
- "authors", "pubdate", "categories"],
144
- template="Title for Doc #{ref_id}: {title}\n\tAbstract: {page_content}\n\tAuthors: {authors}\n\tDate of Publication: {pubdate}\n\tCategories: {categories}\nSOURCE: {id}"),
145
- "metadata_cols": [
146
- AttributeInfo(
147
- name=VirtualColumnName(name="pubdate"),
148
- description="The year the paper is published",
149
- type="timestamp",
150
- ),
151
- AttributeInfo(
152
- name="authors",
153
- description="List of author names",
154
- type="list[string]",
155
- ),
156
- AttributeInfo(
157
- name="title",
158
- description="Title of the paper",
159
- type="string",
160
- ),
161
- AttributeInfo(
162
- name="categories",
163
- description="arxiv categories to this paper",
164
- type="list[string]"
165
- ),
166
- AttributeInfo(
167
- name="length(categories)",
168
- description="length of arxiv categories to this paper",
169
- type="int"
170
- ),
171
- ],
172
- "must_have_cols": ['title', 'id', 'categories', 'abstract', 'authors', 'pubdate'],
173
- "vector_col": "vector",
174
- "text_col": "abstract",
175
- "metadata_col": "metadata",
176
- "emb_model": lambda: HuggingFaceInstructEmbeddings(
177
- model_name='hkunlp/instructor-xl',
178
- embed_instruction="Represent the question for retrieving supporting scientific papers: ")
179
- }
180
- }
181
-
182
-
183
- def try_eval(x):
184
- try:
185
- return eval(x, {'datetime': datetime})
186
- except:
187
- return x
188
-
189
-
190
- def display(dataframe, columns_=None, index=None):
191
- if len(dataframe) > 0:
192
- if index:
193
- dataframe.set_index(index)
194
- if columns_:
195
- st.dataframe(dataframe[columns_])
196
- else:
197
- st.dataframe(dataframe)
198
- else:
199
- st.write("Sorry 😵 we didn't find any articles related to your query.\n\nMaybe the LLM is too naughty that does not follow our instruction... \n\nPlease try again and use verbs that may match the datatype.", unsafe_allow_html=True)
200
-
201
-
202
- def build_embedding_model(_sel):
203
- with st.spinner("Loading Model..."):
204
- embeddings = sel_map[_sel]["emb_model"]()
205
- return embeddings
206
-
207
-
208
- def build_retriever(_sel):
209
- with st.spinner(f"Connecting DB for {_sel}..."):
210
- myscale_connection = {
211
- "host": st.secrets['MYSCALE_HOST'],
212
- "port": st.secrets['MYSCALE_PORT'],
213
- "username": st.secrets['MYSCALE_USER'],
214
- "password": st.secrets['MYSCALE_PASSWORD'],
215
- }
216
-
217
- config = MyScaleSettings(**myscale_connection,
218
- database=sel_map[_sel]["database"],
219
- table=sel_map[_sel]["table"],
220
- column_map={
221
- "id": "id",
222
- "text": sel_map[_sel]["text_col"],
223
- "vector": sel_map[_sel]["vector_col"],
224
- "metadata": sel_map[_sel]["metadata_col"]
225
- })
226
- doc_search = MyScaleWithoutMetadataJson(st.session_state[f"emb_model_{_sel}"], config,
227
- must_have_cols=sel_map[_sel]['must_have_cols'])
228
-
229
- with st.spinner(f"Building Self Query Retriever for {_sel}..."):
230
- metadata_field_info = sel_map[_sel]["metadata_cols"]
231
- retriever = SelfQueryRetriever.from_llm(
232
- OpenAI(model_name=query_model_name, openai_api_key=st.secrets['OPENAI_API_KEY'], temperature=0),
233
- doc_search, "Scientific papers indexes with abstracts. All in English.", metadata_field_info,
234
- use_original_query=False, structured_query_translator=MyScaleTranslator())
235
-
236
- COMBINE_PROMPT = ChatPromptTemplate.from_strings(
237
- string_messages=[(SystemMessagePromptTemplate, combine_prompt_template),
238
- (HumanMessagePromptTemplate, '{question}')])
239
- OPENAI_API_KEY = st.secrets['OPENAI_API_KEY']
240
-
241
- with st.spinner(f'Building QA Chain with Self-query for {_sel}...'):
242
- chain = ArXivQAwithSourcesChain(
243
- retriever=retriever,
244
- combine_documents_chain=ArXivStuffDocumentChain(
245
- llm_chain=LLMChain(
246
- prompt=COMBINE_PROMPT,
247
- llm=ChatOpenAI(model_name=chat_model_name,
248
- openai_api_key=OPENAI_API_KEY, temperature=0.6),
249
- ),
250
- document_prompt=sel_map[_sel]["doc_prompt"],
251
- document_variable_name="summaries",
252
-
253
- ),
254
- return_source_documents=True,
255
- max_tokens_limit=12000,
256
- )
257
-
258
- with st.spinner(f'Building Vector SQL Database Retriever for {_sel}...'):
259
- MYSCALE_USER = st.secrets['MYSCALE_USER']
260
- MYSCALE_PASSWORD = st.secrets['MYSCALE_PASSWORD']
261
- MYSCALE_HOST = st.secrets['MYSCALE_HOST']
262
- MYSCALE_PORT = st.secrets['MYSCALE_PORT']
263
- engine = create_engine(
264
- f'clickhouse://{MYSCALE_USER}:{MYSCALE_PASSWORD}@{MYSCALE_HOST}:{MYSCALE_PORT}/{sel_map[_sel]["database"]}?protocol=https')
265
- metadata = MetaData(bind=engine)
266
- PROMPT = PromptTemplate(
267
- input_variables=["input", "table_info", "top_k"],
268
- template=_myscale_prompt,
269
- )
270
- output_parser = VectorSQLRetrieveCustomOutputParser.from_embeddings(
271
- model=st.session_state[f'emb_model_{_sel}'], must_have_columns=sel_map[_sel]["must_have_cols"])
272
- sql_query_chain = VectorSQLDatabaseChain.from_llm(
273
- llm=OpenAI(model_name=query_model_name, openai_api_key=OPENAI_API_KEY, temperature=0),
274
- prompt=PROMPT,
275
- top_k=10,
276
- return_direct=True,
277
- db=SQLDatabase(engine, None, metadata, max_string_length=1024),
278
- sql_cmd_parser=output_parser,
279
- native_format=True
280
- )
281
- sql_retriever = VectorSQLDatabaseChainRetriever(
282
- sql_db_chain=sql_query_chain, page_content_key=sel_map[_sel]["text_col"])
283
-
284
- with st.spinner(f'Building QA Chain with Vector SQL for {_sel}...'):
285
- sql_chain = ArXivQAwithSourcesChain(
286
- retriever=sql_retriever,
287
- combine_documents_chain=ArXivStuffDocumentChain(
288
- llm_chain=LLMChain(
289
- prompt=COMBINE_PROMPT,
290
- llm=ChatOpenAI(model_name=chat_model_name,
291
- openai_api_key=OPENAI_API_KEY, temperature=0.6),
292
- ),
293
- document_prompt=sel_map[_sel]["doc_prompt"],
294
- document_variable_name="summaries",
295
-
296
- ),
297
- return_source_documents=True,
298
- max_tokens_limit=12000,
299
- )
300
-
301
- return {
302
- "metadata_columns": [{'name': m.name.name if type(m.name) is VirtualColumnName else m.name, 'desc': m.description, 'type': m.type} for m in metadata_field_info],
303
- "retriever": retriever,
304
- "chain": chain,
305
- "sql_retriever": sql_retriever,
306
- "sql_chain": sql_chain
307
- }
308
-
309
-
310
- @st.cache_resource
311
- def build_all():
312
- sel_map_obj = {}
313
- for k in sel_map:
314
- st.session_state[f'emb_model_{k}'] = build_embedding_model(k)
315
- sel_map_obj[k] = build_retriever(k)
316
- return sel_map_obj
317
-
318
-
319
  if 'retriever' not in st.session_state:
320
  st.session_state["sel_map_obj"] = build_all()
321
 
 
14
  from langchain.prompts.prompt import PromptTemplate
15
  from langchain.chat_models import ChatOpenAI
16
  from langchain import OpenAI
 
 
 
 
 
 
17
  import re
18
  import pandas as pd
19
  from os import environ
20
  import streamlit as st
21
  import datetime
22
+ from helper import build_all, sel_map, display
23
  environ['OPENAI_API_BASE'] = st.secrets['OPENAI_API_BASE']
24
 
 
25
  st.set_page_config(page_title="ChatData")
26
 
27
  st.header("ChatData")
28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  if 'retriever' not in st.session_state:
30
  st.session_state["sel_map_obj"] = build_all()
31
 
chat.py ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import time
3
+ import pandas as pd
4
+ from os import environ
5
+ import datetime
6
+ import streamlit as st
7
+ from langchain.schema import Document
8
+
9
+ from callbacks.arxiv_callbacks import ChatDataSelfSearchCallBackHandler, \
10
+ ChatDataSelfAskCallBackHandler, ChatDataSQLSearchCallBackHandler, \
11
+ ChatDataSQLAskCallBackHandler
12
+
13
+ from langchain.schema import BaseMessage, HumanMessage, AIMessage, FunctionMessage, SystemMessage
14
+ from auth0_component import login_button
15
+
16
+
17
+ from helper import build_tools, build_agents, build_all, sel_map, display
18
+
19
+ environ['OPENAI_API_BASE'] = st.secrets['OPENAI_API_BASE']
20
+
21
+ st.set_page_config(page_title="ChatData", page_icon="https://myscale.com/favicon.ico")
22
+ st.header("ChatData")
23
+
24
+
25
+ if 'retriever' not in st.session_state:
26
+ st.session_state["sel_map_obj"] = build_all()
27
+ st.session_state["tools"] = build_tools()
28
+
29
+ def on_chat_submit():
30
+ ret = st.session_state.agents[st.session_state.sel][st.session_state.ret_type]({"input": st.session_state.chat_input})
31
+ print(ret)
32
+
33
+ def clear_history():
34
+ st.session_state.agents[st.session_state.sel][st.session_state.ret_type].memory.clear()
35
+
36
+ AUTH0_CLIENT_ID = st.secrets['AUTH0_CLIENT_ID']
37
+ AUTH0_DOMAIN = st.secrets['AUTH0_DOMAIN']
38
+
39
+ def login():
40
+ if "user_name" in st.session_state or ("jump_query_ask" in st.session_state and st.session_state.jump_query_ask):
41
+ return True
42
+ st.subheader("🤗 Welcom to [MyScale](https://myscale.com)'s [ChatData](https://github.com/myscale/ChatData)! 🤗 ")
43
+ st.write("You can now chat with ArXiv and Wikipedia! You can also try to build your RAG system with those knowledge base via [our public read-only credentials!](https://github.com/myscale/ChatData#data-schema) 🌟\n")
44
+ st.write("Built purely with streamlit 👑 , LangChain 🦜🔗 and love for AI!")
45
+ st.write("Follow us on [Twitter](https://x.com/myscaledb) and [Discord](https://discord.gg/D2qpkqc4Jq)!")
46
+ st.warning("To use chat, please jump to [https://myscale-chatdata.hf.space](https://myscale-chatdata.hf.space)")
47
+ st.info("We used [Auth0](https://auth0.com) as our identity provider. "
48
+ "We will **NOT** collect any of your conversation in any form for any purpose.")
49
+ st.divider()
50
+ col1, col2 = st.columns(2, gap='large')
51
+ with col1.container():
52
+ st.write("Try out MyScale's Self-query and Vector SQL retrievers!")
53
+ st.write("In this demo, you will be able to see how those retrievers "
54
+ "**digest** -> **translate** -> **retrieve** -> **answer** to your question!")
55
+ st.write("It is a step-by-step tour to understand RAG pipeline.")
56
+ st.session_state["jump_query_ask"] = st.button("Query / Ask")
57
+ with col2.container():
58
+ st.write("Now with the power of LangChain's Conversantional Agents, we are able to build "
59
+ "conversational chatbot with RAG! The agent will decide when and what to retrieve "
60
+ "based on your question!")
61
+ st.write("All those conversation history management and retrievers are provided within one MyScale instance!")
62
+ st.write("Log in to Chat with RAG!")
63
+ login_button(AUTH0_CLIENT_ID, AUTH0_DOMAIN, "auth0")
64
+ if st.session_state.auth0 is not None:
65
+ st.session_state.user_info = dict(st.session_state.auth0)
66
+ if 'email' in st.session_state.user_info:
67
+ email = st.session_state.user_info["email"]
68
+ else:
69
+ email = f"{st.session_state.user_info['nickname']}@{st.session_state.user_info['sub']}"
70
+ st.session_state["user_name"] = email
71
+ del st.session_state.auth0
72
+ st.experimental_rerun()
73
+ if st.session_state.jump_query_ask:
74
+ st.experimental_rerun()
75
+
76
+ def back_to_main():
77
+ if "user_info" in st.session_state:
78
+ del st.session_state.user_info
79
+ if "user_name" in st.session_state:
80
+ del st.session_state.user_name
81
+ if "jump_query_ask" in st.session_state:
82
+ del st.session_state.jump_query_ask
83
+
84
+ if login():
85
+ if "user_name" in st.session_state:
86
+ st.session_state["agents"] = build_agents(st.session_state.user_name)
87
+ with st.sidebar:
88
+ st.radio("Retriever Type", ["Self-querying retriever", "Vector SQL"], key="ret_type")
89
+ st.selectbox("Knowledge Base", ["ArXiv Papers", "Wikipedia", "ArXiv + Wikipedia"], key="sel")
90
+ st.button("Clear Chat History", on_click=clear_history)
91
+ st.button("Logout", on_click=back_to_main)
92
+ for msg in st.session_state.agents[st.session_state.sel][st.session_state.ret_type].memory.chat_memory.messages:
93
+ speaker = "user" if isinstance(msg, HumanMessage) else "assistant"
94
+ if isinstance(msg, FunctionMessage):
95
+ with st.chat_message("Knowledge Base", avatar="📖"):
96
+ print(type(msg.content))
97
+ st.write(f"*{datetime.datetime.fromtimestamp(msg.additional_kwargs['timestamp']).isoformat()}*")
98
+ st.write("Retrieved from knowledge base:")
99
+ st.dataframe(pd.DataFrame.from_records(map(dict, eval(msg.content))))
100
+ else:
101
+ if len(msg.content) > 0:
102
+ with st.chat_message(speaker):
103
+ print(type(msg), msg.dict())
104
+ st.write(f"*{datetime.datetime.fromtimestamp(msg.additional_kwargs['timestamp']).isoformat()}*")
105
+ st.write(f"{msg.content}")
106
+ st.chat_input("Input Message", on_submit=on_chat_submit, key="chat_input")
107
+ elif "jump_query_ask" in st.session_state and st.session_state.jump_query_ask:
108
+
109
+ sel = st.selectbox('Choose the knowledge base you want to ask with:',
110
+ options=['ArXiv Papers', 'Wikipedia'])
111
+ sel_map[sel]['hint']()
112
+ tab_sql, tab_self_query = st.tabs(['Vector SQL', 'Self-Query Retrievers'])
113
+ with tab_sql:
114
+ sel_map[sel]['hint_sql']()
115
+ st.text_input("Ask a question:", key='query_sql')
116
+ cols = st.columns([1, 1, 1, 4])
117
+ cols[0].button("Query", key='search_sql')
118
+ cols[1].button("Ask", key='ask_sql')
119
+ cols[2].button("Back", key='back_sql', on_click=back_to_main)
120
+ plc_hldr = st.empty()
121
+ if st.session_state.search_sql:
122
+ plc_hldr = st.empty()
123
+ print(st.session_state.query_sql)
124
+ with plc_hldr.expander('Query Log', expanded=True):
125
+ callback = ChatDataSQLSearchCallBackHandler()
126
+ try:
127
+ docs = st.session_state.sel_map_obj[sel]["sql_retriever"].get_relevant_documents(
128
+ st.session_state.query_sql, callbacks=[callback])
129
+ callback.progress_bar.progress(value=1.0, text="Done!")
130
+ docs = pd.DataFrame(
131
+ [{**d.metadata, 'abstract': d.page_content} for d in docs])
132
+ display(docs)
133
+ except Exception as e:
134
+ st.write('Oops 😵 Something bad happened...')
135
+ raise e
136
+
137
+ if st.session_state.ask_sql:
138
+ plc_hldr = st.empty()
139
+ print(st.session_state.query_sql)
140
+ with plc_hldr.expander('Chat Log', expanded=True):
141
+ callback = ChatDataSQLAskCallBackHandler()
142
+ try:
143
+ ret = st.session_state.sel_map_obj[sel]["sql_chain"](
144
+ st.session_state.query_sql, callbacks=[callback])
145
+ callback.progress_bar.progress(value=1.0, text="Done!")
146
+ st.markdown(
147
+ f"### Answer from LLM\n{ret['answer']}\n### References")
148
+ docs = ret['sources']
149
+ docs = pd.DataFrame(
150
+ [{**d.metadata, 'abstract': d.page_content} for d in docs])
151
+ display(
152
+ docs, ['ref_id'] + sel_map[sel]["must_have_cols"], index='ref_id')
153
+ except Exception as e:
154
+ st.write('Oops 😵 Something bad happened...')
155
+ raise e
156
+
157
+
158
+ with tab_self_query:
159
+ st.info("You can retrieve papers with button `Query` or ask questions based on retrieved papers with button `Ask`.", icon='💡')
160
+ st.dataframe(st.session_state.sel_map_obj[sel]["metadata_columns"])
161
+ st.text_input("Ask a question:", key='query_self')
162
+ cols = st.columns([1, 1, 1, 4])
163
+ cols[0].button("Query", key='search_self')
164
+ cols[1].button("Ask", key='ask_self')
165
+ cols[2].button("Back", key='back_self', on_click=back_to_main)
166
+ plc_hldr = st.empty()
167
+ if st.session_state.search_self:
168
+ plc_hldr = st.empty()
169
+ print(st.session_state.query_self)
170
+ with plc_hldr.expander('Query Log', expanded=True):
171
+ call_back = None
172
+ callback = ChatDataSelfSearchCallBackHandler()
173
+ try:
174
+ docs = st.session_state.sel_map_obj[sel]["retriever"].get_relevant_documents(
175
+ st.session_state.query_self, callbacks=[callback])
176
+ print(docs)
177
+ callback.progress_bar.progress(value=1.0, text="Done!")
178
+ docs = pd.DataFrame(
179
+ [{**d.metadata, 'abstract': d.page_content} for d in docs])
180
+ display(docs, sel_map[sel]["must_have_cols"])
181
+ except Exception as e:
182
+ st.write('Oops 😵 Something bad happened...')
183
+ raise e
184
+
185
+ if st.session_state.ask_self:
186
+ plc_hldr = st.empty()
187
+ print(st.session_state.query_self)
188
+ with plc_hldr.expander('Chat Log', expanded=True):
189
+ call_back = None
190
+ callback = ChatDataSelfAskCallBackHandler()
191
+ try:
192
+ ret = st.session_state.sel_map_obj[sel]["chain"](
193
+ st.session_state.query_self, callbacks=[callback])
194
+ callback.progress_bar.progress(value=1.0, text="Done!")
195
+ st.markdown(
196
+ f"### Answer from LLM\n{ret['answer']}\n### References")
197
+ docs = ret['sources']
198
+ docs = pd.DataFrame(
199
+ [{**d.metadata, 'abstract': d.page_content} for d in docs])
200
+ display(
201
+ docs, ['ref_id'] + sel_map[sel]["must_have_cols"], index='ref_id')
202
+ except Exception as e:
203
+ st.write('Oops 😵 Something bad happened...')
204
+ raise e
helper.py ADDED
@@ -0,0 +1,506 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import json
3
+ import time
4
+ import hashlib
5
+ from typing import Dict, Any
6
+ import re
7
+ import pandas as pd
8
+ from os import environ
9
+ import streamlit as st
10
+ import datetime
11
+
12
+ from sqlalchemy import Column, Text, create_engine, MetaData
13
+ from langchain.agents import AgentExecutor
14
+ try:
15
+ from sqlalchemy.orm import declarative_base
16
+ except ImportError:
17
+ from sqlalchemy.ext.declarative import declarative_base
18
+ from sqlalchemy.orm import sessionmaker
19
+ from clickhouse_sqlalchemy import (
20
+ Table, make_session, get_declarative_base, types, engines
21
+ )
22
+ from langchain_experimental.sql.vector_sql import VectorSQLDatabaseChain
23
+ from langchain_experimental.retrievers.vector_sql_database import VectorSQLDatabaseChainRetriever
24
+ from langchain.utilities.sql_database import SQLDatabase
25
+ from langchain.chains import LLMChain
26
+ from sqlalchemy import create_engine, MetaData
27
+ from langchain.prompts import PromptTemplate, ChatPromptTemplate, \
28
+ SystemMessagePromptTemplate, HumanMessagePromptTemplate
29
+ from langchain.prompts.prompt import PromptTemplate
30
+ from langchain.chat_models import ChatOpenAI
31
+ from langchain.schema import BaseRetriever
32
+ from langchain import OpenAI
33
+ from langchain.chains.query_constructor.base import AttributeInfo, VirtualColumnName
34
+ from langchain.retrievers.self_query.base import SelfQueryRetriever
35
+ from langchain.retrievers.self_query.myscale import MyScaleTranslator
36
+ from langchain.embeddings import HuggingFaceInstructEmbeddings, SentenceTransformerEmbeddings
37
+ from langchain.vectorstores import MyScaleSettings
38
+ from chains.arxiv_chains import MyScaleWithoutMetadataJson
39
+ from langchain.schema import Document
40
+ from langchain.prompts.prompt import PromptTemplate
41
+ from langchain.prompts.chat import MessagesPlaceholder
42
+ from langchain.agents.openai_functions_agent.agent_token_buffer_memory import AgentTokenBufferMemory
43
+ from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
44
+ from langchain.schema import BaseMessage, HumanMessage, AIMessage, FunctionMessage, SystemMessage
45
+ from langchain.memory import SQLChatMessageHistory
46
+ from langchain.memory.chat_message_histories.sql import \
47
+ BaseMessageConverter, DefaultMessageConverter
48
+ from langchain.schema.messages import BaseMessage, _message_to_dict, messages_from_dict
49
+ from langchain.agents.agent_toolkits import create_retriever_tool
50
+ from prompts.arxiv_prompt import combine_prompt_template, _myscale_prompt
51
+ from chains.arxiv_chains import ArXivQAwithSourcesChain, ArXivStuffDocumentChain
52
+ from chains.arxiv_chains import VectorSQLRetrieveCustomOutputParser
53
+ environ['TOKENIZERS_PARALLELISM'] = 'true'
54
+ environ['OPENAI_API_BASE'] = st.secrets['OPENAI_API_BASE']
55
+
56
+ # query_model_name = "gpt-3.5-turbo-instruct"
57
+ query_model_name = "text-davinci-003"
58
+ chat_model_name = "gpt-3.5-turbo-16k"
59
+
60
+
61
+ OPENAI_API_KEY = st.secrets['OPENAI_API_KEY']
62
+ OPENAI_API_BASE = st.secrets['OPENAI_API_BASE']
63
+ MYSCALE_USER = st.secrets['MYSCALE_USER']
64
+ MYSCALE_PASSWORD = st.secrets['MYSCALE_PASSWORD']
65
+ MYSCALE_HOST = st.secrets['MYSCALE_HOST']
66
+ MYSCALE_PORT = st.secrets['MYSCALE_PORT']
67
+
68
+ COMBINE_PROMPT = ChatPromptTemplate.from_strings(
69
+ string_messages=[(SystemMessagePromptTemplate, combine_prompt_template),
70
+ (HumanMessagePromptTemplate, '{question}')])
71
+
72
+ def hint_arxiv():
73
+ st.info("We provides you metadata columns below for query. Please choose a natural expression to describe filters on those columns.\n\n"
74
+ "For example: \n\n"
75
+ "*If you want to search papers with complex filters*:\n\n"
76
+ "- What is a Bayesian network? Please use articles published later than Feb 2018 and with more than 2 categories and whose title like `computer` and must have `cs.CV` in its category.\n\n"
77
+ "*If you want to ask questions based on papers in database*:\n\n"
78
+ "- What is PageRank?\n"
79
+ "- Did Geoffrey Hinton wrote paper about Capsule Neural Networks?\n"
80
+ "- Introduce some applications of GANs published around 2019.\n"
81
+ "- 请根据 2019 年左右的文章介绍一下 GAN 的应用都有哪些\n"
82
+ "- Veuillez présenter les applications du GAN sur la base des articles autour de 2019 ?\n"
83
+ "- Is it possible to synthesize room temperature super conductive material?")
84
+
85
+
86
+ def hint_sql_arxiv():
87
+ st.info("You can retrieve papers with button `Query` or ask questions based on retrieved papers with button `Ask`.", icon='💡')
88
+ st.markdown('''```sql
89
+ CREATE TABLE default.ChatArXiv (
90
+ `abstract` String,
91
+ `id` String,
92
+ `vector` Array(Float32),
93
+ `metadata` Object('JSON'),
94
+ `pubdate` DateTime,
95
+ `title` String,
96
+ `categories` Array(String),
97
+ `authors` Array(String),
98
+ `comment` String,
99
+ `primary_category` String,
100
+ VECTOR INDEX vec_idx vector TYPE MSTG('fp16_storage=1', 'metric_type=Cosine', 'disk_mode=3'),
101
+ CONSTRAINT vec_len CHECK length(vector) = 768)
102
+ ENGINE = ReplacingMergeTree ORDER BY id
103
+ ```''')
104
+
105
+
106
+ def hint_wiki():
107
+ st.info("We provides you metadata columns below for query. Please choose a natural expression to describe filters on those columns.\n\n"
108
+ "For example: \n\n"
109
+ "- Which company did Elon Musk found?\n"
110
+ "- What is Iron Gwazi?\n"
111
+ "- What is a Ring in mathematics?\n"
112
+ "- 苹果的发源地是那里?\n")
113
+
114
+
115
+ def hint_sql_wiki():
116
+ st.info("You can retrieve papers with button `Query` or ask questions based on retrieved papers with button `Ask`.", icon='💡')
117
+ st.markdown('''```sql
118
+ CREATE TABLE wiki.Wikipedia (
119
+ `id` String,
120
+ `title` String,
121
+ `text` String,
122
+ `url` String,
123
+ `wiki_id` UInt64,
124
+ `views` Float32,
125
+ `paragraph_id` UInt64,
126
+ `langs` UInt32,
127
+ `emb` Array(Float32),
128
+ VECTOR INDEX vec_idx emb TYPE MSTG('fp16_storage=1', 'metric_type=Cosine', 'disk_mode=3'),
129
+ CONSTRAINT emb_len CHECK length(emb) = 768)
130
+ ENGINE = ReplacingMergeTree ORDER BY id
131
+ ```''')
132
+
133
+
134
+ sel_map = {
135
+ 'Wikipedia': {
136
+ "database": "wiki",
137
+ "table": "Wikipedia",
138
+ "hint": hint_wiki,
139
+ "hint_sql": hint_sql_wiki,
140
+ "doc_prompt": PromptTemplate(
141
+ input_variables=["page_content", "url", "title", "ref_id", "views"],
142
+ template="Title for Doc #{ref_id}: {title}\n\tviews: {views}\n\tcontent: {page_content}\nSOURCE: {url}"),
143
+ "metadata_cols": [
144
+ AttributeInfo(
145
+ name="title",
146
+ description="title of the wikipedia page",
147
+ type="string",
148
+ ),
149
+ AttributeInfo(
150
+ name="text",
151
+ description="paragraph from this wiki page",
152
+ type="string",
153
+ ),
154
+ AttributeInfo(
155
+ name="views",
156
+ description="number of views",
157
+ type="float"
158
+ ),
159
+ ],
160
+ "must_have_cols": ['id', 'title', 'url', 'text', 'views'],
161
+ "vector_col": "emb",
162
+ "text_col": "text",
163
+ "metadata_col": "metadata",
164
+ "emb_model": lambda: SentenceTransformerEmbeddings(
165
+ model_name='sentence-transformers/paraphrase-multilingual-mpnet-base-v2',),
166
+ "tool_desc": ("search_among_wikipedia", "Searches among Wikipedia and returns related wiki pages"),
167
+ },
168
+ 'ArXiv Papers': {
169
+ "database": "default",
170
+ "table": "ChatArXiv",
171
+ "hint": hint_arxiv,
172
+ "hint_sql": hint_sql_arxiv,
173
+ "doc_prompt": PromptTemplate(
174
+ input_variables=["page_content", "id", "title", "ref_id",
175
+ "authors", "pubdate", "categories"],
176
+ template="Title for Doc #{ref_id}: {title}\n\tAbstract: {page_content}\n\tAuthors: {authors}\n\tDate of Publication: {pubdate}\n\tCategories: {categories}\nSOURCE: {id}"),
177
+ "metadata_cols": [
178
+ AttributeInfo(
179
+ name=VirtualColumnName(name="pubdate"),
180
+ description="The year the paper is published",
181
+ type="timestamp",
182
+ ),
183
+ AttributeInfo(
184
+ name="authors",
185
+ description="List of author names",
186
+ type="list[string]",
187
+ ),
188
+ AttributeInfo(
189
+ name="title",
190
+ description="Title of the paper",
191
+ type="string",
192
+ ),
193
+ AttributeInfo(
194
+ name="categories",
195
+ description="arxiv categories to this paper",
196
+ type="list[string]"
197
+ ),
198
+ AttributeInfo(
199
+ name="length(categories)",
200
+ description="length of arxiv categories to this paper",
201
+ type="int"
202
+ ),
203
+ ],
204
+ "must_have_cols": ['title', 'id', 'categories', 'abstract', 'authors', 'pubdate'],
205
+ "vector_col": "vector",
206
+ "text_col": "abstract",
207
+ "metadata_col": "metadata",
208
+ "emb_model": lambda: HuggingFaceInstructEmbeddings(
209
+ model_name='hkunlp/instructor-xl',
210
+ embed_instruction="Represent the question for retrieving supporting scientific papers: "),
211
+ "tool_desc": ("search_among_scientific_papers", "Searches among scientific papers from ArXiv and returns research papers"),
212
+ }
213
+ }
214
+
215
+ def build_embedding_model(_sel):
216
+ """Build embedding model
217
+ """
218
+ with st.spinner("Loading Model..."):
219
+ embeddings = sel_map[_sel]["emb_model"]()
220
+ return embeddings
221
+
222
+
223
+ def build_chains_retrievers(_sel: str) -> Dict[str, Any]:
224
+ """build chains and retrievers
225
+
226
+ :param _sel: selected knowledge base
227
+ :type _sel: str
228
+ :return: _description_
229
+ :rtype: Dict[str, Any]
230
+ """
231
+ metadata_field_info = sel_map[_sel]["metadata_cols"]
232
+ retriever = build_self_query(_sel)
233
+ chain = build_qa_chain(_sel, retriever, name="Self Query Retriever")
234
+ sql_retriever = build_vector_sql(_sel)
235
+ sql_chain = build_qa_chain(_sel, sql_retriever, name="Vector SQL")
236
+
237
+ return {
238
+ "metadata_columns": [{'name': m.name.name if type(m.name) is VirtualColumnName else m.name, 'desc': m.description, 'type': m.type} for m in metadata_field_info],
239
+ "retriever": retriever,
240
+ "chain": chain,
241
+ "sql_retriever": sql_retriever,
242
+ "sql_chain": sql_chain
243
+ }
244
+
245
+ def build_self_query(_sel: str) -> SelfQueryRetriever:
246
+ """Build self querying retriever
247
+
248
+ :param _sel: selected knowledge base
249
+ :type _sel: str
250
+ :return: retriever used by chains
251
+ :rtype: SelfQueryRetriever
252
+ """
253
+ with st.spinner(f"Connecting DB for {_sel}..."):
254
+ myscale_connection = {
255
+ "host": MYSCALE_HOST,
256
+ "port": MYSCALE_PORT,
257
+ "username": MYSCALE_USER,
258
+ "password": MYSCALE_PASSWORD,
259
+ }
260
+ config = MyScaleSettings(**myscale_connection,
261
+ database=sel_map[_sel]["database"],
262
+ table=sel_map[_sel]["table"],
263
+ column_map={
264
+ "id": "id",
265
+ "text": sel_map[_sel]["text_col"],
266
+ "vector": sel_map[_sel]["vector_col"],
267
+ "metadata": sel_map[_sel]["metadata_col"]
268
+ })
269
+ doc_search = MyScaleWithoutMetadataJson(st.session_state[f"emb_model_{_sel}"], config,
270
+ must_have_cols=sel_map[_sel]['must_have_cols'])
271
+
272
+ with st.spinner(f"Building Self Query Retriever for {_sel}..."):
273
+ metadata_field_info = sel_map[_sel]["metadata_cols"]
274
+ retriever = SelfQueryRetriever.from_llm(
275
+ OpenAI(model_name=query_model_name, openai_api_key=OPENAI_API_KEY, temperature=0),
276
+ doc_search, "Scientific papers indexes with abstracts. All in English.", metadata_field_info,
277
+ use_original_query=False, structured_query_translator=MyScaleTranslator())
278
+ return retriever
279
+
280
+ def build_vector_sql(_sel: str)->VectorSQLDatabaseChainRetriever:
281
+ """Build Vector SQL Database Retriever
282
+
283
+ :param _sel: selected knowledge base
284
+ :type _sel: str
285
+ :return: retriever used by chains
286
+ :rtype: VectorSQLDatabaseChainRetriever
287
+ """
288
+ with st.spinner(f'Building Vector SQL Database Retriever for {_sel}...'):
289
+ engine = create_engine(
290
+ f'clickhouse://{MYSCALE_USER}:{MYSCALE_PASSWORD}@{MYSCALE_HOST}:{MYSCALE_PORT}/{sel_map[_sel]["database"]}?protocol=https')
291
+ metadata = MetaData(bind=engine)
292
+ PROMPT = PromptTemplate(
293
+ input_variables=["input", "table_info", "top_k"],
294
+ template=_myscale_prompt,
295
+ )
296
+ output_parser = VectorSQLRetrieveCustomOutputParser.from_embeddings(
297
+ model=st.session_state[f'emb_model_{_sel}'], must_have_columns=sel_map[_sel]["must_have_cols"])
298
+ sql_query_chain = VectorSQLDatabaseChain.from_llm(
299
+ llm=OpenAI(model_name=query_model_name, openai_api_key=OPENAI_API_KEY, temperature=0),
300
+ prompt=PROMPT,
301
+ top_k=10,
302
+ return_direct=True,
303
+ db=SQLDatabase(engine, None, metadata, max_string_length=1024),
304
+ sql_cmd_parser=output_parser,
305
+ native_format=True
306
+ )
307
+ sql_retriever = VectorSQLDatabaseChainRetriever(
308
+ sql_db_chain=sql_query_chain, page_content_key=sel_map[_sel]["text_col"])
309
+ return sql_retriever
310
+
311
+ def build_qa_chain(_sel: str, retriever: BaseRetriever, name: str="Self-query") -> ArXivQAwithSourcesChain:
312
+ """_summary_
313
+
314
+ :param _sel: selected knowledge base
315
+ :type _sel: str
316
+ :param retriever: retriever used by chains
317
+ :type retriever: BaseRetriever
318
+ :param name: display name, defaults to "Self-query"
319
+ :type name: str, optional
320
+ :return: QA chain interacts with user
321
+ :rtype: ArXivQAwithSourcesChain
322
+ """
323
+ with st.spinner(f'Building QA Chain with {name} for {_sel}...'):
324
+ chain = ArXivQAwithSourcesChain(
325
+ retriever=retriever,
326
+ combine_documents_chain=ArXivStuffDocumentChain(
327
+ llm_chain=LLMChain(
328
+ prompt=COMBINE_PROMPT,
329
+ llm=ChatOpenAI(model_name=chat_model_name,
330
+ openai_api_key=OPENAI_API_KEY, temperature=0.6),
331
+ ),
332
+ document_prompt=sel_map[_sel]["doc_prompt"],
333
+ document_variable_name="summaries",
334
+
335
+ ),
336
+ return_source_documents=True,
337
+ max_tokens_limit=12000,
338
+ )
339
+ return chain
340
+
341
+ @st.cache_resource
342
+ def build_all() -> Dict[str, Any]:
343
+ """build all resources
344
+
345
+ :return: sel_map_obj
346
+ :rtype: Dict[str, Any]
347
+ """
348
+ sel_map_obj = {}
349
+ for k in sel_map:
350
+ st.session_state[f'emb_model_{k}'] = build_embedding_model(k)
351
+ sel_map_obj[k] = build_chains_retrievers(k)
352
+ return sel_map_obj
353
+
354
+ def create_message_model(table_name, DynamicBase): # type: ignore
355
+ """
356
+ Create a message model for a given table name.
357
+
358
+ Args:
359
+ table_name: The name of the table to use.
360
+ DynamicBase: The base class to use for the model.
361
+
362
+ Returns:
363
+ The model class.
364
+
365
+ """
366
+
367
+ # Model decleared inside a function to have a dynamic table name
368
+ class Message(DynamicBase):
369
+ __tablename__ = table_name
370
+ id = Column(types.Float64)
371
+ session_id = Column(Text)
372
+ msg_id = Column(Text, primary_key=True)
373
+ type = Column(Text)
374
+ addtionals = Column(Text)
375
+ message = Column(Text)
376
+ __table_args__ = (
377
+ engines.ReplacingMergeTree(
378
+ partition_by='session_id',
379
+ order_by=('id', 'msg_id')),
380
+ {'comment': 'Store Chat History'}
381
+ )
382
+
383
+ return Message
384
+
385
+ class DefaultClickhouseMessageConverter(DefaultMessageConverter):
386
+ """The default message converter for SQLChatMessageHistory."""
387
+
388
+ def __init__(self, table_name: str):
389
+ self.model_class = create_message_model(table_name, declarative_base())
390
+
391
+ def to_sql_model(self, message: BaseMessage, session_id: str) -> Any:
392
+ tstamp = time.time()
393
+ msg_id = hashlib.sha256(f"{session_id}_{message}_{tstamp}".encode('utf-8')).hexdigest()
394
+ return self.model_class(
395
+ id=tstamp,
396
+ msg_id=msg_id,
397
+ session_id=session_id,
398
+ type=message.type,
399
+ addtionals=json.dumps(message.additional_kwargs),
400
+ message=json.dumps({
401
+ "type": message.type,
402
+ "additional_kwargs": {"timestamp": tstamp},
403
+ "data": message.dict()})
404
+ )
405
+ def from_sql_model(self, sql_message: Any) -> BaseMessage:
406
+ msg_dump = json.loads(sql_message.message)
407
+ msg = messages_from_dict([msg_dump])[0]
408
+ msg.additional_kwargs = msg_dump["additional_kwargs"]
409
+ return msg
410
+
411
+ def get_sql_model_class(self) -> Any:
412
+ return self.model_class
413
+
414
+
415
+ def create_agent_executor(name, session_id, llm, tools, **kwargs):
416
+ name = name.replace(" ", "_")
417
+ conn_str = f'clickhouse://{MYSCALE_USER}:{MYSCALE_PASSWORD}@{MYSCALE_HOST}:{MYSCALE_PORT}'
418
+ chat_memory = SQLChatMessageHistory(
419
+ session_id,
420
+ connection_string=f'{conn_str}/chat?protocol=https',
421
+ custom_message_converter=DefaultClickhouseMessageConverter(name))
422
+ memory = AgentTokenBufferMemory(llm=llm, chat_memory=chat_memory)
423
+
424
+ _system_message = SystemMessage(
425
+ content=(
426
+ "Do your best to answer the questions. "
427
+ "Feel free to use any tools available to look up "
428
+ "relevant information. Please keep all details in query "
429
+ "when calling search functions."
430
+ )
431
+ )
432
+ prompt = OpenAIFunctionsAgent.create_prompt(
433
+ system_message=_system_message,
434
+ extra_prompt_messages=[MessagesPlaceholder(variable_name="history")],
435
+ )
436
+ agent = OpenAIFunctionsAgent(llm=llm, tools=tools, prompt=prompt)
437
+ return AgentExecutor(
438
+ agent=agent,
439
+ tools=tools,
440
+ memory=memory,
441
+ verbose=True,
442
+ return_intermediate_steps=True,
443
+ **kwargs
444
+ )
445
+
446
+ @st.cache_resource
447
+ def build_tools():
448
+ """build all resources
449
+
450
+ :return: sel_map_obj
451
+ :rtype: Dict[str, Any]
452
+ """
453
+ sel_map_obj = {}
454
+ for k in sel_map:
455
+ if f'emb_model_{k}' not in st.session_state:
456
+ st.session_state[f'emb_model_{k}'] = build_embedding_model(k)
457
+ if "sel_map_obj" not in st.session_state:
458
+ st.session_state["sel_map_obj"] = {}
459
+ if k not in st.session_state.sel_map_obj:
460
+ st.session_state["sel_map_obj"][k] = {}
461
+ if "langchain_retriever" not in st.session_state.sel_map_obj[k] or "vecsql_retriever" not in st.session_state.sel_map_obj[k]:
462
+ st.session_state.sel_map_obj[k].update(build_chains_retrievers(k))
463
+ sel_map_obj[k] = {
464
+ "langchain_retriever_tool": create_retriever_tool(st.session_state.sel_map_obj[k]["retriever"], *sel_map[k]["tool_desc"],),
465
+ "vecsql_retriever_tool": create_retriever_tool(st.session_state.sel_map_obj[k]["sql_retriever"], *sel_map[k]["tool_desc"],),
466
+ }
467
+ return sel_map_obj
468
+
469
+ @st.cache_resource(max_entries=1)
470
+ def build_agents(username):
471
+ chat_llm = ChatOpenAI(model_name=chat_model_name, temperature=0.6, openai_api_base=OPENAI_API_BASE, openai_api_key=OPENAI_API_KEY)
472
+ agents = {}
473
+ cnt = 0
474
+ p = st.progress(0.0, "Building agents with different knowledge base...")
475
+ for k in [*sel_map.keys(), 'ArXiv + Wikipedia']:
476
+ for m, n in [("langchain_retriever_tool", "Self-querying retriever"), ("vecsql_retriever_tool", "Vector SQL")]:
477
+ if k == 'ArXiv + Wikipedia':
478
+ tools = [st.session_state.tools[k][m] for k in sel_map.keys()]
479
+ elif k == 'Null':
480
+ tools = []
481
+ else:
482
+ tools = [st.session_state.tools[k][m]]
483
+ if k not in agents:
484
+ agents[k] = {}
485
+ agents[k][n] = create_agent_executor(
486
+ "chat_memory",
487
+ username,
488
+ chat_llm,
489
+ tools=tools,
490
+ )
491
+ cnt += 1/6
492
+ p.progress(cnt, f"Building with Knowledge Base {k} via Retriever {n}...")
493
+ p.empty()
494
+ return agents
495
+
496
+
497
+ def display(dataframe, columns_=None, index=None):
498
+ if len(dataframe) > 0:
499
+ if index:
500
+ dataframe.set_index(index)
501
+ if columns_:
502
+ st.dataframe(dataframe[columns_])
503
+ else:
504
+ st.dataframe(dataframe)
505
+ else:
506
+ st.write("Sorry 😵 we didn't find any articles related to your query.\n\nMaybe the LLM is too naughty that does not follow our instruction... \n\nPlease try again and use verbs that may match the datatype.", unsafe_allow_html=True)
requirements.txt CHANGED
@@ -3,7 +3,8 @@ langchain-experimental @ git+https://github.com/myscale/langchain.git@preview#eg
3
  InstructorEmbedding
4
  pandas
5
  sentence_transformers
6
- streamlit==1.20
 
7
  altair==4.2.2
8
  clickhouse-connect
9
  openai
 
3
  InstructorEmbedding
4
  pandas
5
  sentence_transformers
6
+ streamlit==1.25
7
+ streamlit-auth0-component
8
  altair==4.2.2
9
  clickhouse-connect
10
  openai