|
from prompts.arxiv_prompt import combine_prompt_template, _myscale_prompt |
|
from callbacks.arxiv_callbacks import ChatDataSelfSearchCallBackHandler, \ |
|
ChatDataSelfAskCallBackHandler, ChatDataSQLSearchCallBackHandler, \ |
|
ChatDataSQLAskCallBackHandler |
|
from chains.arxiv_chains import ArXivQAwithSourcesChain, ArXivStuffDocumentChain |
|
from chains.arxiv_chains import VectorSQLRetrieveCustomOutputParser |
|
from langchain_experimental.sql.vector_sql import VectorSQLDatabaseChain |
|
from langchain_experimental.retrievers.vector_sql_database import VectorSQLDatabaseChainRetriever |
|
from langchain.utilities.sql_database import SQLDatabase |
|
from langchain.chains import LLMChain |
|
from sqlalchemy import create_engine, MetaData |
|
from langchain.prompts import PromptTemplate, ChatPromptTemplate, \ |
|
SystemMessagePromptTemplate, HumanMessagePromptTemplate |
|
from langchain.prompts.prompt import PromptTemplate |
|
from langchain.chat_models import ChatOpenAI |
|
from langchain import OpenAI |
|
import re |
|
import pandas as pd |
|
from os import environ |
|
import streamlit as st |
|
import datetime |
|
from helper import build_all, sel_map, display |
|
# Export the OpenAI API base URL from Streamlit secrets into the process
# environment (presumably picked up by the OpenAI / LangChain clients).
environ['OPENAI_API_BASE'] = st.secrets['OPENAI_API_BASE']

st.set_page_config(page_title="ChatData")

st.header("ChatData")

# Build the per-knowledge-base objects once per session. The guard tests the
# same key that is written below, so reruns triggered by widget interaction
# reuse the stored objects instead of calling build_all() again.
if 'sel_map_obj' not in st.session_state:
    st.session_state["sel_map_obj"] = build_all()

# Knowledge-base selector; `sel` is the key into both `sel_map` and
# `st.session_state.sel_map_obj` for the rest of the script.
sel = st.selectbox(
    'Choose the knowledge base you want to ask with:',
    options=['ArXiv Papers', 'Wikipedia'])
# Invoke the hint callable registered for the selected knowledge base.
sel_map[sel]['hint']()
|
# One tab per query mode; `tab_self_query` is consumed further down the script.
tab_sql, tab_self_query = st.tabs(['Vector SQL', 'Self-Query Retrievers'])

with tab_sql:
    sel_map[sel]['hint_sql']()
    st.text_input("Ask a question:", key='query_sql')
    cols = st.columns([1, 1, 7])
    cols[0].button("Query", key='search_sql')
    cols[1].button("Ask", key='ask_sql')
    # Single placeholder shared by both actions below; only one button can be
    # pressed per rerun, so there is no need to allocate a second one.
    plc_hldr = st.empty()

    # "Query": retrieve documents only and show them as a table.
    if st.session_state.search_sql:
        print(st.session_state.query_sql)
        with plc_hldr.expander('Query Log', expanded=True):
            callback = ChatDataSQLSearchCallBackHandler()
            try:
                docs = st.session_state.sel_map_obj[sel]["sql_retriever"].get_relevant_documents(
                    st.session_state.query_sql, callbacks=[callback])
                callback.progress_bar.progress(value=1.0, text="Done!")
                # Flatten each document into one row: its metadata fields
                # plus the page content under the 'abstract' column.
                docs = pd.DataFrame(
                    [{**d.metadata, 'abstract': d.page_content} for d in docs])
                display(docs)
            except Exception:
                st.write('Oops 😵 Something bad happened...')
                # Re-raise unchanged so the original traceback is preserved.
                raise

    # "Ask": run the QA chain, show the answer, then list its sources.
    if st.session_state.ask_sql:
        print(st.session_state.query_sql)
        with plc_hldr.expander('Chat Log', expanded=True):
            callback = ChatDataSQLAskCallBackHandler()
            try:
                ret = st.session_state.sel_map_obj[sel]["sql_chain"](
                    st.session_state.query_sql, callbacks=[callback])
                callback.progress_bar.progress(value=1.0, text="Done!")
                st.markdown(
                    f"### Answer from LLM\n{ret['answer']}\n### References")
                docs = ret['sources']
                docs = pd.DataFrame(
                    [{**d.metadata, 'abstract': d.page_content} for d in docs])
                # References are shown with 'ref_id' first and used as index.
                display(
                    docs, ['ref_id'] + sel_map[sel]["must_have_cols"], index='ref_id')
            except Exception:
                st.write('Oops 😵 Something bad happened...')
                raise
|
|
|
|
|
with tab_self_query:
    # `icon` must be a real emoji character for st.info to accept it.
    st.info("You can retrieve papers with button `Query` or ask questions based on retrieved papers with button `Ask`.", icon='💡')
    st.dataframe(st.session_state.sel_map_obj[sel]["metadata_columns"])
    st.text_input("Ask a question:", key='query_self')
    cols = st.columns([1, 1, 7])
    cols[0].button("Query", key='search_self')
    cols[1].button("Ask", key='ask_self')
    # Single placeholder shared by both actions below; only one button can be
    # pressed per rerun, so there is no need to allocate a second one.
    plc_hldr = st.empty()

    # "Query": retrieve documents only and show them as a table.
    if st.session_state.search_self:
        print(st.session_state.query_self)
        with plc_hldr.expander('Query Log', expanded=True):
            callback = ChatDataSelfSearchCallBackHandler()
            try:
                docs = st.session_state.sel_map_obj[sel]["retriever"].get_relevant_documents(
                    st.session_state.query_self, callbacks=[callback])
                print(docs)
                callback.progress_bar.progress(value=1.0, text="Done!")
                # Flatten each document into one row: its metadata fields
                # plus the page content under the 'abstract' column.
                docs = pd.DataFrame(
                    [{**d.metadata, 'abstract': d.page_content} for d in docs])
                display(docs, sel_map[sel]["must_have_cols"])
            except Exception:
                st.write('Oops 😵 Something bad happened...')
                # Re-raise unchanged so the original traceback is preserved.
                raise

    # "Ask": run the QA chain, show the answer, then list its sources.
    if st.session_state.ask_self:
        print(st.session_state.query_self)
        with plc_hldr.expander('Chat Log', expanded=True):
            callback = ChatDataSelfAskCallBackHandler()
            try:
                ret = st.session_state.sel_map_obj[sel]["chain"](
                    st.session_state.query_self, callbacks=[callback])
                callback.progress_bar.progress(value=1.0, text="Done!")
                st.markdown(
                    f"### Answer from LLM\n{ret['answer']}\n### References")
                docs = ret['sources']
                docs = pd.DataFrame(
                    [{**d.metadata, 'abstract': d.page_content} for d in docs])
                # References are shown with 'ref_id' first and used as index.
                display(
                    docs, ['ref_id'] + sel_map[sel]["must_have_cols"], index='ref_id')
            except Exception:
                st.write('Oops 😵 Something bad happened...')
                raise
|
|