Spaces:
Running
Running
Fangrui Liu
commited on
Commit
·
19bd5a9
1
Parent(s):
45180a0
update chat
Browse files- README.md +1 -1
- app.py +1 -291
- chat.py +204 -0
- helper.py +506 -0
- requirements.txt +2 -1
README.md
CHANGED
@@ -5,7 +5,7 @@ colorFrom: pink
|
|
5 |
colorTo: purple
|
6 |
sdk: streamlit
|
7 |
sdk_version: 1.20.0
|
8 |
-
app_file:
|
9 |
pinned: true
|
10 |
license: mit
|
11 |
---
|
|
|
5 |
colorTo: purple
|
6 |
sdk: streamlit
|
7 |
sdk_version: 1.20.0
|
8 |
+
app_file: chat.py
|
9 |
pinned: true
|
10 |
license: mit
|
11 |
---
|
app.py
CHANGED
@@ -14,308 +14,18 @@ from langchain.prompts import PromptTemplate, ChatPromptTemplate, \
|
|
14 |
from langchain.prompts.prompt import PromptTemplate
|
15 |
from langchain.chat_models import ChatOpenAI
|
16 |
from langchain import OpenAI
|
17 |
-
from langchain.chains.query_constructor.base import AttributeInfo, VirtualColumnName
|
18 |
-
from langchain.retrievers.self_query.base import SelfQueryRetriever
|
19 |
-
from langchain.retrievers.self_query.myscale import MyScaleTranslator
|
20 |
-
from langchain.embeddings import HuggingFaceInstructEmbeddings, SentenceTransformerEmbeddings
|
21 |
-
from langchain.vectorstores import MyScaleSettings
|
22 |
-
from chains.arxiv_chains import MyScaleWithoutMetadataJson
|
23 |
import re
|
24 |
import pandas as pd
|
25 |
from os import environ
|
26 |
import streamlit as st
|
27 |
import datetime
|
28 |
-
|
29 |
environ['OPENAI_API_BASE'] = st.secrets['OPENAI_API_BASE']
|
30 |
|
31 |
-
|
32 |
st.set_page_config(page_title="ChatData")
|
33 |
|
34 |
st.header("ChatData")
|
35 |
|
36 |
-
# query_model_name = "gpt-3.5-turbo-instruct"
|
37 |
-
query_model_name = "text-davinci-003"
|
38 |
-
chat_model_name = "gpt-3.5-turbo-16k"
|
39 |
-
|
40 |
-
|
41 |
-
def hint_arxiv():
|
42 |
-
st.info("We provides you metadata columns below for query. Please choose a natural expression to describe filters on those columns.\n\n"
|
43 |
-
"For example: \n\n"
|
44 |
-
"*If you want to search papers with complex filters*:\n\n"
|
45 |
-
"- What is a Bayesian network? Please use articles published later than Feb 2018 and with more than 2 categories and whose title like `computer` and must have `cs.CV` in its category.\n\n"
|
46 |
-
"*If you want to ask questions based on papers in database*:\n\n"
|
47 |
-
"- What is PageRank?\n"
|
48 |
-
"- Did Geoffrey Hinton wrote paper about Capsule Neural Networks?\n"
|
49 |
-
"- Introduce some applications of GANs published around 2019.\n"
|
50 |
-
"- 请根据 2019 年左右的文章介绍一下 GAN 的应用都有哪些\n"
|
51 |
-
"- Veuillez présenter les applications du GAN sur la base des articles autour de 2019 ?\n"
|
52 |
-
"- Is it possible to synthesize room temperature super conductive material?")
|
53 |
-
|
54 |
-
|
55 |
-
def hint_sql_arxiv():
|
56 |
-
st.info("You can retrieve papers with button `Query` or ask questions based on retrieved papers with button `Ask`.", icon='💡')
|
57 |
-
st.markdown('''```sql
|
58 |
-
CREATE TABLE default.ChatArXiv (
|
59 |
-
`abstract` String,
|
60 |
-
`id` String,
|
61 |
-
`vector` Array(Float32),
|
62 |
-
`metadata` Object('JSON'),
|
63 |
-
`pubdate` DateTime,
|
64 |
-
`title` String,
|
65 |
-
`categories` Array(String),
|
66 |
-
`authors` Array(String),
|
67 |
-
`comment` String,
|
68 |
-
`primary_category` String,
|
69 |
-
VECTOR INDEX vec_idx vector TYPE MSTG('fp16_storage=1', 'metric_type=Cosine', 'disk_mode=3'),
|
70 |
-
CONSTRAINT vec_len CHECK length(vector) = 768)
|
71 |
-
ENGINE = ReplacingMergeTree ORDER BY id
|
72 |
-
```''')
|
73 |
-
|
74 |
-
|
75 |
-
def hint_wiki():
|
76 |
-
st.info("We provides you metadata columns below for query. Please choose a natural expression to describe filters on those columns.\n\n"
|
77 |
-
"For example: \n\n"
|
78 |
-
"- Which company did Elon Musk found?\n"
|
79 |
-
"- What is Iron Gwazi?\n"
|
80 |
-
"- What is a Ring in mathematics?\n"
|
81 |
-
"- 苹果的发源地是那里?\n")
|
82 |
-
|
83 |
-
|
84 |
-
def hint_sql_wiki():
|
85 |
-
st.info("You can retrieve papers with button `Query` or ask questions based on retrieved papers with button `Ask`.", icon='💡')
|
86 |
-
st.markdown('''```sql
|
87 |
-
CREATE TABLE wiki.Wikipedia (
|
88 |
-
`id` String,
|
89 |
-
`title` String,
|
90 |
-
`text` String,
|
91 |
-
`url` String,
|
92 |
-
`wiki_id` UInt64,
|
93 |
-
`views` Float32,
|
94 |
-
`paragraph_id` UInt64,
|
95 |
-
`langs` UInt32,
|
96 |
-
`emb` Array(Float32),
|
97 |
-
VECTOR INDEX vec_idx emb TYPE MSTG('fp16_storage=1', 'metric_type=Cosine', 'disk_mode=3'),
|
98 |
-
CONSTRAINT emb_len CHECK length(emb) = 768)
|
99 |
-
ENGINE = ReplacingMergeTree ORDER BY id
|
100 |
-
```''')
|
101 |
-
|
102 |
-
|
103 |
-
sel_map = {
|
104 |
-
'Wikipedia': {
|
105 |
-
"database": "wiki",
|
106 |
-
"table": "Wikipedia",
|
107 |
-
"hint": hint_wiki,
|
108 |
-
"hint_sql": hint_sql_wiki,
|
109 |
-
"doc_prompt": PromptTemplate(
|
110 |
-
input_variables=["page_content", "url", "title", "ref_id", "views"],
|
111 |
-
template="Title for Doc #{ref_id}: {title}\n\tviews: {views}\n\tcontent: {page_content}\nSOURCE: {url}"),
|
112 |
-
"metadata_cols": [
|
113 |
-
AttributeInfo(
|
114 |
-
name="title",
|
115 |
-
description="title of the wikipedia page",
|
116 |
-
type="string",
|
117 |
-
),
|
118 |
-
AttributeInfo(
|
119 |
-
name="text",
|
120 |
-
description="paragraph from this wiki page",
|
121 |
-
type="string",
|
122 |
-
),
|
123 |
-
AttributeInfo(
|
124 |
-
name="views",
|
125 |
-
description="number of views",
|
126 |
-
type="float"
|
127 |
-
),
|
128 |
-
],
|
129 |
-
"must_have_cols": ['id', 'title', 'url', 'text', 'views'],
|
130 |
-
"vector_col": "emb",
|
131 |
-
"text_col": "text",
|
132 |
-
"metadata_col": "metadata",
|
133 |
-
"emb_model": lambda: SentenceTransformerEmbeddings(
|
134 |
-
model_name='sentence-transformers/paraphrase-multilingual-mpnet-base-v2',)
|
135 |
-
},
|
136 |
-
'ArXiv Papers': {
|
137 |
-
"database": "default",
|
138 |
-
"table": "ChatArXiv",
|
139 |
-
"hint": hint_arxiv,
|
140 |
-
"hint_sql": hint_sql_arxiv,
|
141 |
-
"doc_prompt": PromptTemplate(
|
142 |
-
input_variables=["page_content", "id", "title", "ref_id",
|
143 |
-
"authors", "pubdate", "categories"],
|
144 |
-
template="Title for Doc #{ref_id}: {title}\n\tAbstract: {page_content}\n\tAuthors: {authors}\n\tDate of Publication: {pubdate}\n\tCategories: {categories}\nSOURCE: {id}"),
|
145 |
-
"metadata_cols": [
|
146 |
-
AttributeInfo(
|
147 |
-
name=VirtualColumnName(name="pubdate"),
|
148 |
-
description="The year the paper is published",
|
149 |
-
type="timestamp",
|
150 |
-
),
|
151 |
-
AttributeInfo(
|
152 |
-
name="authors",
|
153 |
-
description="List of author names",
|
154 |
-
type="list[string]",
|
155 |
-
),
|
156 |
-
AttributeInfo(
|
157 |
-
name="title",
|
158 |
-
description="Title of the paper",
|
159 |
-
type="string",
|
160 |
-
),
|
161 |
-
AttributeInfo(
|
162 |
-
name="categories",
|
163 |
-
description="arxiv categories to this paper",
|
164 |
-
type="list[string]"
|
165 |
-
),
|
166 |
-
AttributeInfo(
|
167 |
-
name="length(categories)",
|
168 |
-
description="length of arxiv categories to this paper",
|
169 |
-
type="int"
|
170 |
-
),
|
171 |
-
],
|
172 |
-
"must_have_cols": ['title', 'id', 'categories', 'abstract', 'authors', 'pubdate'],
|
173 |
-
"vector_col": "vector",
|
174 |
-
"text_col": "abstract",
|
175 |
-
"metadata_col": "metadata",
|
176 |
-
"emb_model": lambda: HuggingFaceInstructEmbeddings(
|
177 |
-
model_name='hkunlp/instructor-xl',
|
178 |
-
embed_instruction="Represent the question for retrieving supporting scientific papers: ")
|
179 |
-
}
|
180 |
-
}
|
181 |
-
|
182 |
-
|
183 |
-
def try_eval(x):
|
184 |
-
try:
|
185 |
-
return eval(x, {'datetime': datetime})
|
186 |
-
except:
|
187 |
-
return x
|
188 |
-
|
189 |
-
|
190 |
-
def display(dataframe, columns_=None, index=None):
|
191 |
-
if len(dataframe) > 0:
|
192 |
-
if index:
|
193 |
-
dataframe.set_index(index)
|
194 |
-
if columns_:
|
195 |
-
st.dataframe(dataframe[columns_])
|
196 |
-
else:
|
197 |
-
st.dataframe(dataframe)
|
198 |
-
else:
|
199 |
-
st.write("Sorry 😵 we didn't find any articles related to your query.\n\nMaybe the LLM is too naughty that does not follow our instruction... \n\nPlease try again and use verbs that may match the datatype.", unsafe_allow_html=True)
|
200 |
-
|
201 |
-
|
202 |
-
def build_embedding_model(_sel):
|
203 |
-
with st.spinner("Loading Model..."):
|
204 |
-
embeddings = sel_map[_sel]["emb_model"]()
|
205 |
-
return embeddings
|
206 |
-
|
207 |
-
|
208 |
-
def build_retriever(_sel):
|
209 |
-
with st.spinner(f"Connecting DB for {_sel}..."):
|
210 |
-
myscale_connection = {
|
211 |
-
"host": st.secrets['MYSCALE_HOST'],
|
212 |
-
"port": st.secrets['MYSCALE_PORT'],
|
213 |
-
"username": st.secrets['MYSCALE_USER'],
|
214 |
-
"password": st.secrets['MYSCALE_PASSWORD'],
|
215 |
-
}
|
216 |
-
|
217 |
-
config = MyScaleSettings(**myscale_connection,
|
218 |
-
database=sel_map[_sel]["database"],
|
219 |
-
table=sel_map[_sel]["table"],
|
220 |
-
column_map={
|
221 |
-
"id": "id",
|
222 |
-
"text": sel_map[_sel]["text_col"],
|
223 |
-
"vector": sel_map[_sel]["vector_col"],
|
224 |
-
"metadata": sel_map[_sel]["metadata_col"]
|
225 |
-
})
|
226 |
-
doc_search = MyScaleWithoutMetadataJson(st.session_state[f"emb_model_{_sel}"], config,
|
227 |
-
must_have_cols=sel_map[_sel]['must_have_cols'])
|
228 |
-
|
229 |
-
with st.spinner(f"Building Self Query Retriever for {_sel}..."):
|
230 |
-
metadata_field_info = sel_map[_sel]["metadata_cols"]
|
231 |
-
retriever = SelfQueryRetriever.from_llm(
|
232 |
-
OpenAI(model_name=query_model_name, openai_api_key=st.secrets['OPENAI_API_KEY'], temperature=0),
|
233 |
-
doc_search, "Scientific papers indexes with abstracts. All in English.", metadata_field_info,
|
234 |
-
use_original_query=False, structured_query_translator=MyScaleTranslator())
|
235 |
-
|
236 |
-
COMBINE_PROMPT = ChatPromptTemplate.from_strings(
|
237 |
-
string_messages=[(SystemMessagePromptTemplate, combine_prompt_template),
|
238 |
-
(HumanMessagePromptTemplate, '{question}')])
|
239 |
-
OPENAI_API_KEY = st.secrets['OPENAI_API_KEY']
|
240 |
-
|
241 |
-
with st.spinner(f'Building QA Chain with Self-query for {_sel}...'):
|
242 |
-
chain = ArXivQAwithSourcesChain(
|
243 |
-
retriever=retriever,
|
244 |
-
combine_documents_chain=ArXivStuffDocumentChain(
|
245 |
-
llm_chain=LLMChain(
|
246 |
-
prompt=COMBINE_PROMPT,
|
247 |
-
llm=ChatOpenAI(model_name=chat_model_name,
|
248 |
-
openai_api_key=OPENAI_API_KEY, temperature=0.6),
|
249 |
-
),
|
250 |
-
document_prompt=sel_map[_sel]["doc_prompt"],
|
251 |
-
document_variable_name="summaries",
|
252 |
-
|
253 |
-
),
|
254 |
-
return_source_documents=True,
|
255 |
-
max_tokens_limit=12000,
|
256 |
-
)
|
257 |
-
|
258 |
-
with st.spinner(f'Building Vector SQL Database Retriever for {_sel}...'):
|
259 |
-
MYSCALE_USER = st.secrets['MYSCALE_USER']
|
260 |
-
MYSCALE_PASSWORD = st.secrets['MYSCALE_PASSWORD']
|
261 |
-
MYSCALE_HOST = st.secrets['MYSCALE_HOST']
|
262 |
-
MYSCALE_PORT = st.secrets['MYSCALE_PORT']
|
263 |
-
engine = create_engine(
|
264 |
-
f'clickhouse://{MYSCALE_USER}:{MYSCALE_PASSWORD}@{MYSCALE_HOST}:{MYSCALE_PORT}/{sel_map[_sel]["database"]}?protocol=https')
|
265 |
-
metadata = MetaData(bind=engine)
|
266 |
-
PROMPT = PromptTemplate(
|
267 |
-
input_variables=["input", "table_info", "top_k"],
|
268 |
-
template=_myscale_prompt,
|
269 |
-
)
|
270 |
-
output_parser = VectorSQLRetrieveCustomOutputParser.from_embeddings(
|
271 |
-
model=st.session_state[f'emb_model_{_sel}'], must_have_columns=sel_map[_sel]["must_have_cols"])
|
272 |
-
sql_query_chain = VectorSQLDatabaseChain.from_llm(
|
273 |
-
llm=OpenAI(model_name=query_model_name, openai_api_key=OPENAI_API_KEY, temperature=0),
|
274 |
-
prompt=PROMPT,
|
275 |
-
top_k=10,
|
276 |
-
return_direct=True,
|
277 |
-
db=SQLDatabase(engine, None, metadata, max_string_length=1024),
|
278 |
-
sql_cmd_parser=output_parser,
|
279 |
-
native_format=True
|
280 |
-
)
|
281 |
-
sql_retriever = VectorSQLDatabaseChainRetriever(
|
282 |
-
sql_db_chain=sql_query_chain, page_content_key=sel_map[_sel]["text_col"])
|
283 |
-
|
284 |
-
with st.spinner(f'Building QA Chain with Vector SQL for {_sel}...'):
|
285 |
-
sql_chain = ArXivQAwithSourcesChain(
|
286 |
-
retriever=sql_retriever,
|
287 |
-
combine_documents_chain=ArXivStuffDocumentChain(
|
288 |
-
llm_chain=LLMChain(
|
289 |
-
prompt=COMBINE_PROMPT,
|
290 |
-
llm=ChatOpenAI(model_name=chat_model_name,
|
291 |
-
openai_api_key=OPENAI_API_KEY, temperature=0.6),
|
292 |
-
),
|
293 |
-
document_prompt=sel_map[_sel]["doc_prompt"],
|
294 |
-
document_variable_name="summaries",
|
295 |
-
|
296 |
-
),
|
297 |
-
return_source_documents=True,
|
298 |
-
max_tokens_limit=12000,
|
299 |
-
)
|
300 |
-
|
301 |
-
return {
|
302 |
-
"metadata_columns": [{'name': m.name.name if type(m.name) is VirtualColumnName else m.name, 'desc': m.description, 'type': m.type} for m in metadata_field_info],
|
303 |
-
"retriever": retriever,
|
304 |
-
"chain": chain,
|
305 |
-
"sql_retriever": sql_retriever,
|
306 |
-
"sql_chain": sql_chain
|
307 |
-
}
|
308 |
-
|
309 |
-
|
310 |
-
@st.cache_resource
|
311 |
-
def build_all():
|
312 |
-
sel_map_obj = {}
|
313 |
-
for k in sel_map:
|
314 |
-
st.session_state[f'emb_model_{k}'] = build_embedding_model(k)
|
315 |
-
sel_map_obj[k] = build_retriever(k)
|
316 |
-
return sel_map_obj
|
317 |
-
|
318 |
-
|
319 |
if 'retriever' not in st.session_state:
|
320 |
st.session_state["sel_map_obj"] = build_all()
|
321 |
|
|
|
14 |
from langchain.prompts.prompt import PromptTemplate
|
15 |
from langchain.chat_models import ChatOpenAI
|
16 |
from langchain import OpenAI
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
import re
|
18 |
import pandas as pd
|
19 |
from os import environ
|
20 |
import streamlit as st
|
21 |
import datetime
|
22 |
+
from helper import build_all, sel_map, display
|
23 |
environ['OPENAI_API_BASE'] = st.secrets['OPENAI_API_BASE']
|
24 |
|
|
|
25 |
st.set_page_config(page_title="ChatData")
|
26 |
|
27 |
st.header("ChatData")
|
28 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
if 'retriever' not in st.session_state:
|
30 |
st.session_state["sel_map_obj"] = build_all()
|
31 |
|
chat.py
ADDED
@@ -0,0 +1,204 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import time
|
3 |
+
import pandas as pd
|
4 |
+
from os import environ
|
5 |
+
import datetime
|
6 |
+
import streamlit as st
|
7 |
+
from langchain.schema import Document
|
8 |
+
|
9 |
+
from callbacks.arxiv_callbacks import ChatDataSelfSearchCallBackHandler, \
|
10 |
+
ChatDataSelfAskCallBackHandler, ChatDataSQLSearchCallBackHandler, \
|
11 |
+
ChatDataSQLAskCallBackHandler
|
12 |
+
|
13 |
+
from langchain.schema import BaseMessage, HumanMessage, AIMessage, FunctionMessage, SystemMessage
|
14 |
+
from auth0_component import login_button
|
15 |
+
|
16 |
+
|
17 |
+
from helper import build_tools, build_agents, build_all, sel_map, display
|
18 |
+
|
19 |
+
environ['OPENAI_API_BASE'] = st.secrets['OPENAI_API_BASE']
|
20 |
+
|
21 |
+
st.set_page_config(page_title="ChatData", page_icon="https://myscale.com/favicon.ico")
|
22 |
+
st.header("ChatData")
|
23 |
+
|
24 |
+
|
25 |
+
if 'retriever' not in st.session_state:
|
26 |
+
st.session_state["sel_map_obj"] = build_all()
|
27 |
+
st.session_state["tools"] = build_tools()
|
28 |
+
|
29 |
+
def on_chat_submit():
|
30 |
+
ret = st.session_state.agents[st.session_state.sel][st.session_state.ret_type]({"input": st.session_state.chat_input})
|
31 |
+
print(ret)
|
32 |
+
|
33 |
+
def clear_history():
|
34 |
+
st.session_state.agents[st.session_state.sel][st.session_state.ret_type].memory.clear()
|
35 |
+
|
36 |
+
AUTH0_CLIENT_ID = st.secrets['AUTH0_CLIENT_ID']
|
37 |
+
AUTH0_DOMAIN = st.secrets['AUTH0_DOMAIN']
|
38 |
+
|
39 |
+
def login():
|
40 |
+
if "user_name" in st.session_state or ("jump_query_ask" in st.session_state and st.session_state.jump_query_ask):
|
41 |
+
return True
|
42 |
+
st.subheader("🤗 Welcom to [MyScale](https://myscale.com)'s [ChatData](https://github.com/myscale/ChatData)! 🤗 ")
|
43 |
+
st.write("You can now chat with ArXiv and Wikipedia! You can also try to build your RAG system with those knowledge base via [our public read-only credentials!](https://github.com/myscale/ChatData#data-schema) 🌟\n")
|
44 |
+
st.write("Built purely with streamlit 👑 , LangChain 🦜🔗 and love for AI!")
|
45 |
+
st.write("Follow us on [Twitter](https://x.com/myscaledb) and [Discord](https://discord.gg/D2qpkqc4Jq)!")
|
46 |
+
st.warning("To use chat, please jump to [https://myscale-chatdata.hf.space](https://myscale-chatdata.hf.space)")
|
47 |
+
st.info("We used [Auth0](https://auth0.com) as our identity provider. "
|
48 |
+
"We will **NOT** collect any of your conversation in any form for any purpose.")
|
49 |
+
st.divider()
|
50 |
+
col1, col2 = st.columns(2, gap='large')
|
51 |
+
with col1.container():
|
52 |
+
st.write("Try out MyScale's Self-query and Vector SQL retrievers!")
|
53 |
+
st.write("In this demo, you will be able to see how those retrievers "
|
54 |
+
"**digest** -> **translate** -> **retrieve** -> **answer** to your question!")
|
55 |
+
st.write("It is a step-by-step tour to understand RAG pipeline.")
|
56 |
+
st.session_state["jump_query_ask"] = st.button("Query / Ask")
|
57 |
+
with col2.container():
|
58 |
+
st.write("Now with the power of LangChain's Conversantional Agents, we are able to build "
|
59 |
+
"conversational chatbot with RAG! The agent will decide when and what to retrieve "
|
60 |
+
"based on your question!")
|
61 |
+
st.write("All those conversation history management and retrievers are provided within one MyScale instance!")
|
62 |
+
st.write("Log in to Chat with RAG!")
|
63 |
+
login_button(AUTH0_CLIENT_ID, AUTH0_DOMAIN, "auth0")
|
64 |
+
if st.session_state.auth0 is not None:
|
65 |
+
st.session_state.user_info = dict(st.session_state.auth0)
|
66 |
+
if 'email' in st.session_state.user_info:
|
67 |
+
email = st.session_state.user_info["email"]
|
68 |
+
else:
|
69 |
+
email = f"{st.session_state.user_info['nickname']}@{st.session_state.user_info['sub']}"
|
70 |
+
st.session_state["user_name"] = email
|
71 |
+
del st.session_state.auth0
|
72 |
+
st.experimental_rerun()
|
73 |
+
if st.session_state.jump_query_ask:
|
74 |
+
st.experimental_rerun()
|
75 |
+
|
76 |
+
def back_to_main():
|
77 |
+
if "user_info" in st.session_state:
|
78 |
+
del st.session_state.user_info
|
79 |
+
if "user_name" in st.session_state:
|
80 |
+
del st.session_state.user_name
|
81 |
+
if "jump_query_ask" in st.session_state:
|
82 |
+
del st.session_state.jump_query_ask
|
83 |
+
|
84 |
+
if login():
|
85 |
+
if "user_name" in st.session_state:
|
86 |
+
st.session_state["agents"] = build_agents(st.session_state.user_name)
|
87 |
+
with st.sidebar:
|
88 |
+
st.radio("Retriever Type", ["Self-querying retriever", "Vector SQL"], key="ret_type")
|
89 |
+
st.selectbox("Knowledge Base", ["ArXiv Papers", "Wikipedia", "ArXiv + Wikipedia"], key="sel")
|
90 |
+
st.button("Clear Chat History", on_click=clear_history)
|
91 |
+
st.button("Logout", on_click=back_to_main)
|
92 |
+
for msg in st.session_state.agents[st.session_state.sel][st.session_state.ret_type].memory.chat_memory.messages:
|
93 |
+
speaker = "user" if isinstance(msg, HumanMessage) else "assistant"
|
94 |
+
if isinstance(msg, FunctionMessage):
|
95 |
+
with st.chat_message("Knowledge Base", avatar="📖"):
|
96 |
+
print(type(msg.content))
|
97 |
+
st.write(f"*{datetime.datetime.fromtimestamp(msg.additional_kwargs['timestamp']).isoformat()}*")
|
98 |
+
st.write("Retrieved from knowledge base:")
|
99 |
+
st.dataframe(pd.DataFrame.from_records(map(dict, eval(msg.content))))
|
100 |
+
else:
|
101 |
+
if len(msg.content) > 0:
|
102 |
+
with st.chat_message(speaker):
|
103 |
+
print(type(msg), msg.dict())
|
104 |
+
st.write(f"*{datetime.datetime.fromtimestamp(msg.additional_kwargs['timestamp']).isoformat()}*")
|
105 |
+
st.write(f"{msg.content}")
|
106 |
+
st.chat_input("Input Message", on_submit=on_chat_submit, key="chat_input")
|
107 |
+
elif "jump_query_ask" in st.session_state and st.session_state.jump_query_ask:
|
108 |
+
|
109 |
+
sel = st.selectbox('Choose the knowledge base you want to ask with:',
|
110 |
+
options=['ArXiv Papers', 'Wikipedia'])
|
111 |
+
sel_map[sel]['hint']()
|
112 |
+
tab_sql, tab_self_query = st.tabs(['Vector SQL', 'Self-Query Retrievers'])
|
113 |
+
with tab_sql:
|
114 |
+
sel_map[sel]['hint_sql']()
|
115 |
+
st.text_input("Ask a question:", key='query_sql')
|
116 |
+
cols = st.columns([1, 1, 1, 4])
|
117 |
+
cols[0].button("Query", key='search_sql')
|
118 |
+
cols[1].button("Ask", key='ask_sql')
|
119 |
+
cols[2].button("Back", key='back_sql', on_click=back_to_main)
|
120 |
+
plc_hldr = st.empty()
|
121 |
+
if st.session_state.search_sql:
|
122 |
+
plc_hldr = st.empty()
|
123 |
+
print(st.session_state.query_sql)
|
124 |
+
with plc_hldr.expander('Query Log', expanded=True):
|
125 |
+
callback = ChatDataSQLSearchCallBackHandler()
|
126 |
+
try:
|
127 |
+
docs = st.session_state.sel_map_obj[sel]["sql_retriever"].get_relevant_documents(
|
128 |
+
st.session_state.query_sql, callbacks=[callback])
|
129 |
+
callback.progress_bar.progress(value=1.0, text="Done!")
|
130 |
+
docs = pd.DataFrame(
|
131 |
+
[{**d.metadata, 'abstract': d.page_content} for d in docs])
|
132 |
+
display(docs)
|
133 |
+
except Exception as e:
|
134 |
+
st.write('Oops 😵 Something bad happened...')
|
135 |
+
raise e
|
136 |
+
|
137 |
+
if st.session_state.ask_sql:
|
138 |
+
plc_hldr = st.empty()
|
139 |
+
print(st.session_state.query_sql)
|
140 |
+
with plc_hldr.expander('Chat Log', expanded=True):
|
141 |
+
callback = ChatDataSQLAskCallBackHandler()
|
142 |
+
try:
|
143 |
+
ret = st.session_state.sel_map_obj[sel]["sql_chain"](
|
144 |
+
st.session_state.query_sql, callbacks=[callback])
|
145 |
+
callback.progress_bar.progress(value=1.0, text="Done!")
|
146 |
+
st.markdown(
|
147 |
+
f"### Answer from LLM\n{ret['answer']}\n### References")
|
148 |
+
docs = ret['sources']
|
149 |
+
docs = pd.DataFrame(
|
150 |
+
[{**d.metadata, 'abstract': d.page_content} for d in docs])
|
151 |
+
display(
|
152 |
+
docs, ['ref_id'] + sel_map[sel]["must_have_cols"], index='ref_id')
|
153 |
+
except Exception as e:
|
154 |
+
st.write('Oops 😵 Something bad happened...')
|
155 |
+
raise e
|
156 |
+
|
157 |
+
|
158 |
+
with tab_self_query:
|
159 |
+
st.info("You can retrieve papers with button `Query` or ask questions based on retrieved papers with button `Ask`.", icon='💡')
|
160 |
+
st.dataframe(st.session_state.sel_map_obj[sel]["metadata_columns"])
|
161 |
+
st.text_input("Ask a question:", key='query_self')
|
162 |
+
cols = st.columns([1, 1, 1, 4])
|
163 |
+
cols[0].button("Query", key='search_self')
|
164 |
+
cols[1].button("Ask", key='ask_self')
|
165 |
+
cols[2].button("Back", key='back_self', on_click=back_to_main)
|
166 |
+
plc_hldr = st.empty()
|
167 |
+
if st.session_state.search_self:
|
168 |
+
plc_hldr = st.empty()
|
169 |
+
print(st.session_state.query_self)
|
170 |
+
with plc_hldr.expander('Query Log', expanded=True):
|
171 |
+
call_back = None
|
172 |
+
callback = ChatDataSelfSearchCallBackHandler()
|
173 |
+
try:
|
174 |
+
docs = st.session_state.sel_map_obj[sel]["retriever"].get_relevant_documents(
|
175 |
+
st.session_state.query_self, callbacks=[callback])
|
176 |
+
print(docs)
|
177 |
+
callback.progress_bar.progress(value=1.0, text="Done!")
|
178 |
+
docs = pd.DataFrame(
|
179 |
+
[{**d.metadata, 'abstract': d.page_content} for d in docs])
|
180 |
+
display(docs, sel_map[sel]["must_have_cols"])
|
181 |
+
except Exception as e:
|
182 |
+
st.write('Oops 😵 Something bad happened...')
|
183 |
+
raise e
|
184 |
+
|
185 |
+
if st.session_state.ask_self:
|
186 |
+
plc_hldr = st.empty()
|
187 |
+
print(st.session_state.query_self)
|
188 |
+
with plc_hldr.expander('Chat Log', expanded=True):
|
189 |
+
call_back = None
|
190 |
+
callback = ChatDataSelfAskCallBackHandler()
|
191 |
+
try:
|
192 |
+
ret = st.session_state.sel_map_obj[sel]["chain"](
|
193 |
+
st.session_state.query_self, callbacks=[callback])
|
194 |
+
callback.progress_bar.progress(value=1.0, text="Done!")
|
195 |
+
st.markdown(
|
196 |
+
f"### Answer from LLM\n{ret['answer']}\n### References")
|
197 |
+
docs = ret['sources']
|
198 |
+
docs = pd.DataFrame(
|
199 |
+
[{**d.metadata, 'abstract': d.page_content} for d in docs])
|
200 |
+
display(
|
201 |
+
docs, ['ref_id'] + sel_map[sel]["must_have_cols"], index='ref_id')
|
202 |
+
except Exception as e:
|
203 |
+
st.write('Oops 😵 Something bad happened...')
|
204 |
+
raise e
|
helper.py
ADDED
@@ -0,0 +1,506 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import json
|
3 |
+
import time
|
4 |
+
import hashlib
|
5 |
+
from typing import Dict, Any
|
6 |
+
import re
|
7 |
+
import pandas as pd
|
8 |
+
from os import environ
|
9 |
+
import streamlit as st
|
10 |
+
import datetime
|
11 |
+
|
12 |
+
from sqlalchemy import Column, Text, create_engine, MetaData
|
13 |
+
from langchain.agents import AgentExecutor
|
14 |
+
try:
|
15 |
+
from sqlalchemy.orm import declarative_base
|
16 |
+
except ImportError:
|
17 |
+
from sqlalchemy.ext.declarative import declarative_base
|
18 |
+
from sqlalchemy.orm import sessionmaker
|
19 |
+
from clickhouse_sqlalchemy import (
|
20 |
+
Table, make_session, get_declarative_base, types, engines
|
21 |
+
)
|
22 |
+
from langchain_experimental.sql.vector_sql import VectorSQLDatabaseChain
|
23 |
+
from langchain_experimental.retrievers.vector_sql_database import VectorSQLDatabaseChainRetriever
|
24 |
+
from langchain.utilities.sql_database import SQLDatabase
|
25 |
+
from langchain.chains import LLMChain
|
26 |
+
from sqlalchemy import create_engine, MetaData
|
27 |
+
from langchain.prompts import PromptTemplate, ChatPromptTemplate, \
|
28 |
+
SystemMessagePromptTemplate, HumanMessagePromptTemplate
|
29 |
+
from langchain.prompts.prompt import PromptTemplate
|
30 |
+
from langchain.chat_models import ChatOpenAI
|
31 |
+
from langchain.schema import BaseRetriever
|
32 |
+
from langchain import OpenAI
|
33 |
+
from langchain.chains.query_constructor.base import AttributeInfo, VirtualColumnName
|
34 |
+
from langchain.retrievers.self_query.base import SelfQueryRetriever
|
35 |
+
from langchain.retrievers.self_query.myscale import MyScaleTranslator
|
36 |
+
from langchain.embeddings import HuggingFaceInstructEmbeddings, SentenceTransformerEmbeddings
|
37 |
+
from langchain.vectorstores import MyScaleSettings
|
38 |
+
from chains.arxiv_chains import MyScaleWithoutMetadataJson
|
39 |
+
from langchain.schema import Document
|
40 |
+
from langchain.prompts.prompt import PromptTemplate
|
41 |
+
from langchain.prompts.chat import MessagesPlaceholder
|
42 |
+
from langchain.agents.openai_functions_agent.agent_token_buffer_memory import AgentTokenBufferMemory
|
43 |
+
from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
|
44 |
+
from langchain.schema import BaseMessage, HumanMessage, AIMessage, FunctionMessage, SystemMessage
|
45 |
+
from langchain.memory import SQLChatMessageHistory
|
46 |
+
from langchain.memory.chat_message_histories.sql import \
|
47 |
+
BaseMessageConverter, DefaultMessageConverter
|
48 |
+
from langchain.schema.messages import BaseMessage, _message_to_dict, messages_from_dict
|
49 |
+
from langchain.agents.agent_toolkits import create_retriever_tool
|
50 |
+
from prompts.arxiv_prompt import combine_prompt_template, _myscale_prompt
|
51 |
+
from chains.arxiv_chains import ArXivQAwithSourcesChain, ArXivStuffDocumentChain
|
52 |
+
from chains.arxiv_chains import VectorSQLRetrieveCustomOutputParser
|
53 |
+
environ['TOKENIZERS_PARALLELISM'] = 'true'
|
54 |
+
environ['OPENAI_API_BASE'] = st.secrets['OPENAI_API_BASE']
|
55 |
+
|
56 |
+
# query_model_name = "gpt-3.5-turbo-instruct"
|
57 |
+
query_model_name = "text-davinci-003"
|
58 |
+
chat_model_name = "gpt-3.5-turbo-16k"
|
59 |
+
|
60 |
+
|
61 |
+
OPENAI_API_KEY = st.secrets['OPENAI_API_KEY']
|
62 |
+
OPENAI_API_BASE = st.secrets['OPENAI_API_BASE']
|
63 |
+
MYSCALE_USER = st.secrets['MYSCALE_USER']
|
64 |
+
MYSCALE_PASSWORD = st.secrets['MYSCALE_PASSWORD']
|
65 |
+
MYSCALE_HOST = st.secrets['MYSCALE_HOST']
|
66 |
+
MYSCALE_PORT = st.secrets['MYSCALE_PORT']
|
67 |
+
|
68 |
+
COMBINE_PROMPT = ChatPromptTemplate.from_strings(
|
69 |
+
string_messages=[(SystemMessagePromptTemplate, combine_prompt_template),
|
70 |
+
(HumanMessagePromptTemplate, '{question}')])
|
71 |
+
|
72 |
+
def hint_arxiv():
|
73 |
+
st.info("We provides you metadata columns below for query. Please choose a natural expression to describe filters on those columns.\n\n"
|
74 |
+
"For example: \n\n"
|
75 |
+
"*If you want to search papers with complex filters*:\n\n"
|
76 |
+
"- What is a Bayesian network? Please use articles published later than Feb 2018 and with more than 2 categories and whose title like `computer` and must have `cs.CV` in its category.\n\n"
|
77 |
+
"*If you want to ask questions based on papers in database*:\n\n"
|
78 |
+
"- What is PageRank?\n"
|
79 |
+
"- Did Geoffrey Hinton wrote paper about Capsule Neural Networks?\n"
|
80 |
+
"- Introduce some applications of GANs published around 2019.\n"
|
81 |
+
"- 请根据 2019 年左右的文章介绍一下 GAN 的应用都有哪些\n"
|
82 |
+
"- Veuillez présenter les applications du GAN sur la base des articles autour de 2019 ?\n"
|
83 |
+
"- Is it possible to synthesize room temperature super conductive material?")
|
84 |
+
|
85 |
+
|
86 |
+
def hint_sql_arxiv():
|
87 |
+
st.info("You can retrieve papers with button `Query` or ask questions based on retrieved papers with button `Ask`.", icon='💡')
|
88 |
+
st.markdown('''```sql
|
89 |
+
CREATE TABLE default.ChatArXiv (
|
90 |
+
`abstract` String,
|
91 |
+
`id` String,
|
92 |
+
`vector` Array(Float32),
|
93 |
+
`metadata` Object('JSON'),
|
94 |
+
`pubdate` DateTime,
|
95 |
+
`title` String,
|
96 |
+
`categories` Array(String),
|
97 |
+
`authors` Array(String),
|
98 |
+
`comment` String,
|
99 |
+
`primary_category` String,
|
100 |
+
VECTOR INDEX vec_idx vector TYPE MSTG('fp16_storage=1', 'metric_type=Cosine', 'disk_mode=3'),
|
101 |
+
CONSTRAINT vec_len CHECK length(vector) = 768)
|
102 |
+
ENGINE = ReplacingMergeTree ORDER BY id
|
103 |
+
```''')
|
104 |
+
|
105 |
+
|
106 |
+
def hint_wiki():
|
107 |
+
st.info("We provides you metadata columns below for query. Please choose a natural expression to describe filters on those columns.\n\n"
|
108 |
+
"For example: \n\n"
|
109 |
+
"- Which company did Elon Musk found?\n"
|
110 |
+
"- What is Iron Gwazi?\n"
|
111 |
+
"- What is a Ring in mathematics?\n"
|
112 |
+
"- 苹果的发源地是那里?\n")
|
113 |
+
|
114 |
+
|
115 |
+
def hint_sql_wiki():
|
116 |
+
st.info("You can retrieve papers with button `Query` or ask questions based on retrieved papers with button `Ask`.", icon='💡')
|
117 |
+
st.markdown('''```sql
|
118 |
+
CREATE TABLE wiki.Wikipedia (
|
119 |
+
`id` String,
|
120 |
+
`title` String,
|
121 |
+
`text` String,
|
122 |
+
`url` String,
|
123 |
+
`wiki_id` UInt64,
|
124 |
+
`views` Float32,
|
125 |
+
`paragraph_id` UInt64,
|
126 |
+
`langs` UInt32,
|
127 |
+
`emb` Array(Float32),
|
128 |
+
VECTOR INDEX vec_idx emb TYPE MSTG('fp16_storage=1', 'metric_type=Cosine', 'disk_mode=3'),
|
129 |
+
CONSTRAINT emb_len CHECK length(emb) = 768)
|
130 |
+
ENGINE = ReplacingMergeTree ORDER BY id
|
131 |
+
```''')
|
132 |
+
|
133 |
+
|
134 |
+
sel_map = {
|
135 |
+
'Wikipedia': {
|
136 |
+
"database": "wiki",
|
137 |
+
"table": "Wikipedia",
|
138 |
+
"hint": hint_wiki,
|
139 |
+
"hint_sql": hint_sql_wiki,
|
140 |
+
"doc_prompt": PromptTemplate(
|
141 |
+
input_variables=["page_content", "url", "title", "ref_id", "views"],
|
142 |
+
template="Title for Doc #{ref_id}: {title}\n\tviews: {views}\n\tcontent: {page_content}\nSOURCE: {url}"),
|
143 |
+
"metadata_cols": [
|
144 |
+
AttributeInfo(
|
145 |
+
name="title",
|
146 |
+
description="title of the wikipedia page",
|
147 |
+
type="string",
|
148 |
+
),
|
149 |
+
AttributeInfo(
|
150 |
+
name="text",
|
151 |
+
description="paragraph from this wiki page",
|
152 |
+
type="string",
|
153 |
+
),
|
154 |
+
AttributeInfo(
|
155 |
+
name="views",
|
156 |
+
description="number of views",
|
157 |
+
type="float"
|
158 |
+
),
|
159 |
+
],
|
160 |
+
"must_have_cols": ['id', 'title', 'url', 'text', 'views'],
|
161 |
+
"vector_col": "emb",
|
162 |
+
"text_col": "text",
|
163 |
+
"metadata_col": "metadata",
|
164 |
+
"emb_model": lambda: SentenceTransformerEmbeddings(
|
165 |
+
model_name='sentence-transformers/paraphrase-multilingual-mpnet-base-v2',),
|
166 |
+
"tool_desc": ("search_among_wikipedia", "Searches among Wikipedia and returns related wiki pages"),
|
167 |
+
},
|
168 |
+
'ArXiv Papers': {
|
169 |
+
"database": "default",
|
170 |
+
"table": "ChatArXiv",
|
171 |
+
"hint": hint_arxiv,
|
172 |
+
"hint_sql": hint_sql_arxiv,
|
173 |
+
"doc_prompt": PromptTemplate(
|
174 |
+
input_variables=["page_content", "id", "title", "ref_id",
|
175 |
+
"authors", "pubdate", "categories"],
|
176 |
+
template="Title for Doc #{ref_id}: {title}\n\tAbstract: {page_content}\n\tAuthors: {authors}\n\tDate of Publication: {pubdate}\n\tCategories: {categories}\nSOURCE: {id}"),
|
177 |
+
"metadata_cols": [
|
178 |
+
AttributeInfo(
|
179 |
+
name=VirtualColumnName(name="pubdate"),
|
180 |
+
description="The year the paper is published",
|
181 |
+
type="timestamp",
|
182 |
+
),
|
183 |
+
AttributeInfo(
|
184 |
+
name="authors",
|
185 |
+
description="List of author names",
|
186 |
+
type="list[string]",
|
187 |
+
),
|
188 |
+
AttributeInfo(
|
189 |
+
name="title",
|
190 |
+
description="Title of the paper",
|
191 |
+
type="string",
|
192 |
+
),
|
193 |
+
AttributeInfo(
|
194 |
+
name="categories",
|
195 |
+
description="arxiv categories to this paper",
|
196 |
+
type="list[string]"
|
197 |
+
),
|
198 |
+
AttributeInfo(
|
199 |
+
name="length(categories)",
|
200 |
+
description="length of arxiv categories to this paper",
|
201 |
+
type="int"
|
202 |
+
),
|
203 |
+
],
|
204 |
+
"must_have_cols": ['title', 'id', 'categories', 'abstract', 'authors', 'pubdate'],
|
205 |
+
"vector_col": "vector",
|
206 |
+
"text_col": "abstract",
|
207 |
+
"metadata_col": "metadata",
|
208 |
+
"emb_model": lambda: HuggingFaceInstructEmbeddings(
|
209 |
+
model_name='hkunlp/instructor-xl',
|
210 |
+
embed_instruction="Represent the question for retrieving supporting scientific papers: "),
|
211 |
+
"tool_desc": ("search_among_scientific_papers", "Searches among scientific papers from ArXiv and returns research papers"),
|
212 |
+
}
|
213 |
+
}
|
214 |
+
|
215 |
+
def build_embedding_model(_sel):
|
216 |
+
"""Build embedding model
|
217 |
+
"""
|
218 |
+
with st.spinner("Loading Model..."):
|
219 |
+
embeddings = sel_map[_sel]["emb_model"]()
|
220 |
+
return embeddings
|
221 |
+
|
222 |
+
|
223 |
+
def build_chains_retrievers(_sel: str) -> Dict[str, Any]:
|
224 |
+
"""build chains and retrievers
|
225 |
+
|
226 |
+
:param _sel: selected knowledge base
|
227 |
+
:type _sel: str
|
228 |
+
:return: _description_
|
229 |
+
:rtype: Dict[str, Any]
|
230 |
+
"""
|
231 |
+
metadata_field_info = sel_map[_sel]["metadata_cols"]
|
232 |
+
retriever = build_self_query(_sel)
|
233 |
+
chain = build_qa_chain(_sel, retriever, name="Self Query Retriever")
|
234 |
+
sql_retriever = build_vector_sql(_sel)
|
235 |
+
sql_chain = build_qa_chain(_sel, sql_retriever, name="Vector SQL")
|
236 |
+
|
237 |
+
return {
|
238 |
+
"metadata_columns": [{'name': m.name.name if type(m.name) is VirtualColumnName else m.name, 'desc': m.description, 'type': m.type} for m in metadata_field_info],
|
239 |
+
"retriever": retriever,
|
240 |
+
"chain": chain,
|
241 |
+
"sql_retriever": sql_retriever,
|
242 |
+
"sql_chain": sql_chain
|
243 |
+
}
|
244 |
+
|
245 |
+
def build_self_query(_sel: str) -> SelfQueryRetriever:
|
246 |
+
"""Build self querying retriever
|
247 |
+
|
248 |
+
:param _sel: selected knowledge base
|
249 |
+
:type _sel: str
|
250 |
+
:return: retriever used by chains
|
251 |
+
:rtype: SelfQueryRetriever
|
252 |
+
"""
|
253 |
+
with st.spinner(f"Connecting DB for {_sel}..."):
|
254 |
+
myscale_connection = {
|
255 |
+
"host": MYSCALE_HOST,
|
256 |
+
"port": MYSCALE_PORT,
|
257 |
+
"username": MYSCALE_USER,
|
258 |
+
"password": MYSCALE_PASSWORD,
|
259 |
+
}
|
260 |
+
config = MyScaleSettings(**myscale_connection,
|
261 |
+
database=sel_map[_sel]["database"],
|
262 |
+
table=sel_map[_sel]["table"],
|
263 |
+
column_map={
|
264 |
+
"id": "id",
|
265 |
+
"text": sel_map[_sel]["text_col"],
|
266 |
+
"vector": sel_map[_sel]["vector_col"],
|
267 |
+
"metadata": sel_map[_sel]["metadata_col"]
|
268 |
+
})
|
269 |
+
doc_search = MyScaleWithoutMetadataJson(st.session_state[f"emb_model_{_sel}"], config,
|
270 |
+
must_have_cols=sel_map[_sel]['must_have_cols'])
|
271 |
+
|
272 |
+
with st.spinner(f"Building Self Query Retriever for {_sel}..."):
|
273 |
+
metadata_field_info = sel_map[_sel]["metadata_cols"]
|
274 |
+
retriever = SelfQueryRetriever.from_llm(
|
275 |
+
OpenAI(model_name=query_model_name, openai_api_key=OPENAI_API_KEY, temperature=0),
|
276 |
+
doc_search, "Scientific papers indexes with abstracts. All in English.", metadata_field_info,
|
277 |
+
use_original_query=False, structured_query_translator=MyScaleTranslator())
|
278 |
+
return retriever
|
279 |
+
|
280 |
+
def build_vector_sql(_sel: str)->VectorSQLDatabaseChainRetriever:
|
281 |
+
"""Build Vector SQL Database Retriever
|
282 |
+
|
283 |
+
:param _sel: selected knowledge base
|
284 |
+
:type _sel: str
|
285 |
+
:return: retriever used by chains
|
286 |
+
:rtype: VectorSQLDatabaseChainRetriever
|
287 |
+
"""
|
288 |
+
with st.spinner(f'Building Vector SQL Database Retriever for {_sel}...'):
|
289 |
+
engine = create_engine(
|
290 |
+
f'clickhouse://{MYSCALE_USER}:{MYSCALE_PASSWORD}@{MYSCALE_HOST}:{MYSCALE_PORT}/{sel_map[_sel]["database"]}?protocol=https')
|
291 |
+
metadata = MetaData(bind=engine)
|
292 |
+
PROMPT = PromptTemplate(
|
293 |
+
input_variables=["input", "table_info", "top_k"],
|
294 |
+
template=_myscale_prompt,
|
295 |
+
)
|
296 |
+
output_parser = VectorSQLRetrieveCustomOutputParser.from_embeddings(
|
297 |
+
model=st.session_state[f'emb_model_{_sel}'], must_have_columns=sel_map[_sel]["must_have_cols"])
|
298 |
+
sql_query_chain = VectorSQLDatabaseChain.from_llm(
|
299 |
+
llm=OpenAI(model_name=query_model_name, openai_api_key=OPENAI_API_KEY, temperature=0),
|
300 |
+
prompt=PROMPT,
|
301 |
+
top_k=10,
|
302 |
+
return_direct=True,
|
303 |
+
db=SQLDatabase(engine, None, metadata, max_string_length=1024),
|
304 |
+
sql_cmd_parser=output_parser,
|
305 |
+
native_format=True
|
306 |
+
)
|
307 |
+
sql_retriever = VectorSQLDatabaseChainRetriever(
|
308 |
+
sql_db_chain=sql_query_chain, page_content_key=sel_map[_sel]["text_col"])
|
309 |
+
return sql_retriever
|
310 |
+
|
311 |
+
def build_qa_chain(_sel: str, retriever: BaseRetriever, name: str="Self-query") -> ArXivQAwithSourcesChain:
|
312 |
+
"""_summary_
|
313 |
+
|
314 |
+
:param _sel: selected knowledge base
|
315 |
+
:type _sel: str
|
316 |
+
:param retriever: retriever used by chains
|
317 |
+
:type retriever: BaseRetriever
|
318 |
+
:param name: display name, defaults to "Self-query"
|
319 |
+
:type name: str, optional
|
320 |
+
:return: QA chain interacts with user
|
321 |
+
:rtype: ArXivQAwithSourcesChain
|
322 |
+
"""
|
323 |
+
with st.spinner(f'Building QA Chain with {name} for {_sel}...'):
|
324 |
+
chain = ArXivQAwithSourcesChain(
|
325 |
+
retriever=retriever,
|
326 |
+
combine_documents_chain=ArXivStuffDocumentChain(
|
327 |
+
llm_chain=LLMChain(
|
328 |
+
prompt=COMBINE_PROMPT,
|
329 |
+
llm=ChatOpenAI(model_name=chat_model_name,
|
330 |
+
openai_api_key=OPENAI_API_KEY, temperature=0.6),
|
331 |
+
),
|
332 |
+
document_prompt=sel_map[_sel]["doc_prompt"],
|
333 |
+
document_variable_name="summaries",
|
334 |
+
|
335 |
+
),
|
336 |
+
return_source_documents=True,
|
337 |
+
max_tokens_limit=12000,
|
338 |
+
)
|
339 |
+
return chain
|
340 |
+
|
341 |
+
@st.cache_resource
|
342 |
+
def build_all() -> Dict[str, Any]:
|
343 |
+
"""build all resources
|
344 |
+
|
345 |
+
:return: sel_map_obj
|
346 |
+
:rtype: Dict[str, Any]
|
347 |
+
"""
|
348 |
+
sel_map_obj = {}
|
349 |
+
for k in sel_map:
|
350 |
+
st.session_state[f'emb_model_{k}'] = build_embedding_model(k)
|
351 |
+
sel_map_obj[k] = build_chains_retrievers(k)
|
352 |
+
return sel_map_obj
|
353 |
+
|
354 |
+
def create_message_model(table_name, DynamicBase): # type: ignore
|
355 |
+
"""
|
356 |
+
Create a message model for a given table name.
|
357 |
+
|
358 |
+
Args:
|
359 |
+
table_name: The name of the table to use.
|
360 |
+
DynamicBase: The base class to use for the model.
|
361 |
+
|
362 |
+
Returns:
|
363 |
+
The model class.
|
364 |
+
|
365 |
+
"""
|
366 |
+
|
367 |
+
# Model decleared inside a function to have a dynamic table name
|
368 |
+
class Message(DynamicBase):
|
369 |
+
__tablename__ = table_name
|
370 |
+
id = Column(types.Float64)
|
371 |
+
session_id = Column(Text)
|
372 |
+
msg_id = Column(Text, primary_key=True)
|
373 |
+
type = Column(Text)
|
374 |
+
addtionals = Column(Text)
|
375 |
+
message = Column(Text)
|
376 |
+
__table_args__ = (
|
377 |
+
engines.ReplacingMergeTree(
|
378 |
+
partition_by='session_id',
|
379 |
+
order_by=('id', 'msg_id')),
|
380 |
+
{'comment': 'Store Chat History'}
|
381 |
+
)
|
382 |
+
|
383 |
+
return Message
|
384 |
+
|
385 |
+
class DefaultClickhouseMessageConverter(DefaultMessageConverter):
|
386 |
+
"""The default message converter for SQLChatMessageHistory."""
|
387 |
+
|
388 |
+
def __init__(self, table_name: str):
|
389 |
+
self.model_class = create_message_model(table_name, declarative_base())
|
390 |
+
|
391 |
+
def to_sql_model(self, message: BaseMessage, session_id: str) -> Any:
|
392 |
+
tstamp = time.time()
|
393 |
+
msg_id = hashlib.sha256(f"{session_id}_{message}_{tstamp}".encode('utf-8')).hexdigest()
|
394 |
+
return self.model_class(
|
395 |
+
id=tstamp,
|
396 |
+
msg_id=msg_id,
|
397 |
+
session_id=session_id,
|
398 |
+
type=message.type,
|
399 |
+
addtionals=json.dumps(message.additional_kwargs),
|
400 |
+
message=json.dumps({
|
401 |
+
"type": message.type,
|
402 |
+
"additional_kwargs": {"timestamp": tstamp},
|
403 |
+
"data": message.dict()})
|
404 |
+
)
|
405 |
+
def from_sql_model(self, sql_message: Any) -> BaseMessage:
|
406 |
+
msg_dump = json.loads(sql_message.message)
|
407 |
+
msg = messages_from_dict([msg_dump])[0]
|
408 |
+
msg.additional_kwargs = msg_dump["additional_kwargs"]
|
409 |
+
return msg
|
410 |
+
|
411 |
+
def get_sql_model_class(self) -> Any:
|
412 |
+
return self.model_class
|
413 |
+
|
414 |
+
|
415 |
+
def create_agent_executor(name, session_id, llm, tools, **kwargs):
|
416 |
+
name = name.replace(" ", "_")
|
417 |
+
conn_str = f'clickhouse://{MYSCALE_USER}:{MYSCALE_PASSWORD}@{MYSCALE_HOST}:{MYSCALE_PORT}'
|
418 |
+
chat_memory = SQLChatMessageHistory(
|
419 |
+
session_id,
|
420 |
+
connection_string=f'{conn_str}/chat?protocol=https',
|
421 |
+
custom_message_converter=DefaultClickhouseMessageConverter(name))
|
422 |
+
memory = AgentTokenBufferMemory(llm=llm, chat_memory=chat_memory)
|
423 |
+
|
424 |
+
_system_message = SystemMessage(
|
425 |
+
content=(
|
426 |
+
"Do your best to answer the questions. "
|
427 |
+
"Feel free to use any tools available to look up "
|
428 |
+
"relevant information. Please keep all details in query "
|
429 |
+
"when calling search functions."
|
430 |
+
)
|
431 |
+
)
|
432 |
+
prompt = OpenAIFunctionsAgent.create_prompt(
|
433 |
+
system_message=_system_message,
|
434 |
+
extra_prompt_messages=[MessagesPlaceholder(variable_name="history")],
|
435 |
+
)
|
436 |
+
agent = OpenAIFunctionsAgent(llm=llm, tools=tools, prompt=prompt)
|
437 |
+
return AgentExecutor(
|
438 |
+
agent=agent,
|
439 |
+
tools=tools,
|
440 |
+
memory=memory,
|
441 |
+
verbose=True,
|
442 |
+
return_intermediate_steps=True,
|
443 |
+
**kwargs
|
444 |
+
)
|
445 |
+
|
446 |
+
@st.cache_resource
|
447 |
+
def build_tools():
|
448 |
+
"""build all resources
|
449 |
+
|
450 |
+
:return: sel_map_obj
|
451 |
+
:rtype: Dict[str, Any]
|
452 |
+
"""
|
453 |
+
sel_map_obj = {}
|
454 |
+
for k in sel_map:
|
455 |
+
if f'emb_model_{k}' not in st.session_state:
|
456 |
+
st.session_state[f'emb_model_{k}'] = build_embedding_model(k)
|
457 |
+
if "sel_map_obj" not in st.session_state:
|
458 |
+
st.session_state["sel_map_obj"] = {}
|
459 |
+
if k not in st.session_state.sel_map_obj:
|
460 |
+
st.session_state["sel_map_obj"][k] = {}
|
461 |
+
if "langchain_retriever" not in st.session_state.sel_map_obj[k] or "vecsql_retriever" not in st.session_state.sel_map_obj[k]:
|
462 |
+
st.session_state.sel_map_obj[k].update(build_chains_retrievers(k))
|
463 |
+
sel_map_obj[k] = {
|
464 |
+
"langchain_retriever_tool": create_retriever_tool(st.session_state.sel_map_obj[k]["retriever"], *sel_map[k]["tool_desc"],),
|
465 |
+
"vecsql_retriever_tool": create_retriever_tool(st.session_state.sel_map_obj[k]["sql_retriever"], *sel_map[k]["tool_desc"],),
|
466 |
+
}
|
467 |
+
return sel_map_obj
|
468 |
+
|
469 |
+
@st.cache_resource(max_entries=1)
|
470 |
+
def build_agents(username):
|
471 |
+
chat_llm = ChatOpenAI(model_name=chat_model_name, temperature=0.6, openai_api_base=OPENAI_API_BASE, openai_api_key=OPENAI_API_KEY)
|
472 |
+
agents = {}
|
473 |
+
cnt = 0
|
474 |
+
p = st.progress(0.0, "Building agents with different knowledge base...")
|
475 |
+
for k in [*sel_map.keys(), 'ArXiv + Wikipedia']:
|
476 |
+
for m, n in [("langchain_retriever_tool", "Self-querying retriever"), ("vecsql_retriever_tool", "Vector SQL")]:
|
477 |
+
if k == 'ArXiv + Wikipedia':
|
478 |
+
tools = [st.session_state.tools[k][m] for k in sel_map.keys()]
|
479 |
+
elif k == 'Null':
|
480 |
+
tools = []
|
481 |
+
else:
|
482 |
+
tools = [st.session_state.tools[k][m]]
|
483 |
+
if k not in agents:
|
484 |
+
agents[k] = {}
|
485 |
+
agents[k][n] = create_agent_executor(
|
486 |
+
"chat_memory",
|
487 |
+
username,
|
488 |
+
chat_llm,
|
489 |
+
tools=tools,
|
490 |
+
)
|
491 |
+
cnt += 1/6
|
492 |
+
p.progress(cnt, f"Building with Knowledge Base {k} via Retriever {n}...")
|
493 |
+
p.empty()
|
494 |
+
return agents
|
495 |
+
|
496 |
+
|
497 |
+
def display(dataframe, columns_=None, index=None):
|
498 |
+
if len(dataframe) > 0:
|
499 |
+
if index:
|
500 |
+
dataframe.set_index(index)
|
501 |
+
if columns_:
|
502 |
+
st.dataframe(dataframe[columns_])
|
503 |
+
else:
|
504 |
+
st.dataframe(dataframe)
|
505 |
+
else:
|
506 |
+
st.write("Sorry 😵 we didn't find any articles related to your query.\n\nMaybe the LLM is too naughty that does not follow our instruction... \n\nPlease try again and use verbs that may match the datatype.", unsafe_allow_html=True)
|
requirements.txt
CHANGED
@@ -3,7 +3,8 @@ langchain-experimental @ git+https://github.com/myscale/langchain.git@preview#eg
|
|
3 |
InstructorEmbedding
|
4 |
pandas
|
5 |
sentence_transformers
|
6 |
-
streamlit==1.
|
|
|
7 |
altair==4.2.2
|
8 |
clickhouse-connect
|
9 |
openai
|
|
|
3 |
InstructorEmbedding
|
4 |
pandas
|
5 |
sentence_transformers
|
6 |
+
streamlit==1.25
|
7 |
+
streamlit-auth0-component
|
8 |
altair==4.2.2
|
9 |
clickhouse-connect
|
10 |
openai
|