"""Streamlit entry script for the ChatData demo.

After a successful ``login()`` the script either renders the chat page (when a
user name is stored in the session) or, when ``jump_query_ask`` is set, a
two-tab page where a knowledge base can be searched ("Query") or questioned
("Ask") through a Vector SQL pipeline or a self-query retriever.
"""
from contextlib import contextmanager
from os import environ

import pandas as pd
import streamlit as st

# Project-local imports are kept in their original relative order so any
# import-time side effects run in the same sequence as before.
from callbacks.arxiv_callbacks import (
    ChatDataSelfSearchCallBackHandler,
    ChatDataSelfAskCallBackHandler,
    ChatDataSQLSearchCallBackHandler,
    ChatDataSQLAskCallBackHandler,
)
from chat import chat_page
from login import login, back_to_main
from lib.helper import build_tools, build_all, sel_map, display


def _docs_to_frame(docs) -> pd.DataFrame:
    """Return a DataFrame with one row per document.

    Each row contains the document's ``metadata`` entries plus its
    ``page_content`` stored under the ``abstract`` column.
    """
    return pd.DataFrame(
        [{**d.metadata, 'abstract': d.page_content} for d in docs])


@contextmanager
def _log_panel(title, handler_cls):
    """Open an expanded log expander and yield a fresh callback handler.

    Args:
        title: Label of the expander (e.g. ``'Query Log'``).
        handler_cls: Zero-argument callable that builds the callback handler.

    Yields:
        The handler instance, created inside the expander.

    Raises:
        Exception: Anything raised by the managed body is re-raised unchanged
            after an apology message has been written into the expander.
    """
    with st.empty().expander(title, expanded=True):
        # Built outside the ``try`` on purpose: a failing handler constructor
        # was never covered by the "Oops" message in the original flow either.
        handler = handler_cls()
        try:
            yield handler
        except Exception:
            st.write('Oops 😵 Something bad happened...')
            raise  # bare raise keeps the original traceback intact


def _mark_done(handler) -> None:
    """Push the handler's progress bar to 100% with the text "Done!"."""
    handler.progress_bar.progress(value=1.0, text="Done!")


def _show_answer(ret, sel) -> None:
    """Render the chain's answer followed by its source documents.

    Args:
        ret: Chain result; its ``'answer'`` and ``'sources'`` entries are read.
        sel: Selected knowledge-base name, used to look up the
            ``must_have_cols`` list in ``sel_map``.
    """
    st.markdown(
        f"### Answer from LLM\n{ret['answer']}\n### References")
    display(
        _docs_to_frame(ret['sources']),
        ['ref_id'] + sel_map[sel]["must_have_cols"],
        index='ref_id')


environ['OPENAI_API_BASE'] = st.secrets['OPENAI_API_BASE']

st.set_page_config(page_title="ChatData",
                   page_icon="https://myscale.com/favicon.ico")
# NOTE(review): the injected markup is blank in this view of the file --
# confirm whether inline HTML/CSS is supposed to live here.
st.markdown(
    f""" """,
    unsafe_allow_html=True,
)
st.header("ChatData")

# Build the per-knowledge-base objects and the tools only when the session
# does not already hold them.
if 'sel_map_obj' not in st.session_state or 'embeddings' not in st.session_state:
    st.session_state["sel_map_obj"], st.session_state["embeddings"] = build_all()
    st.session_state["tools"] = build_tools()

if login():
    if "user_name" in st.session_state:
        chat_page()
    elif "jump_query_ask" in st.session_state and st.session_state.jump_query_ask:
        sel = st.selectbox('Choose the knowledge base you want to ask with:',
                           options=['ArXiv Papers', 'Wikipedia'])
        sel_map[sel]['hint']()
        tab_sql, tab_self_query = st.tabs(
            ['Vector SQL', 'Self-Query Retrievers'])

        with tab_sql:
            sel_map[sel]['hint_sql']()
            st.text_input("Ask a question:", key='query_sql')
            cols = st.columns([1, 1, 1, 4])
            cols[0].button("Query", key='search_sql')
            cols[1].button("Ask", key='ask_sql')
            cols[2].button("Back", key='back_sql', on_click=back_to_main)

            if st.session_state.search_sql:
                print(st.session_state.query_sql)
                with _log_panel('Query Log',
                                ChatDataSQLSearchCallBackHandler) as handler:
                    docs = st.session_state.sel_map_obj[sel]["sql_retriever"].get_relevant_documents(
                        st.session_state.query_sql, callbacks=[handler])
                    _mark_done(handler)
                    display(_docs_to_frame(docs))

            if st.session_state.ask_sql:
                print(st.session_state.query_sql)
                with _log_panel('Chat Log',
                                ChatDataSQLAskCallBackHandler) as handler:
                    ret = st.session_state.sel_map_obj[sel]["sql_chain"](
                        st.session_state.query_sql, callbacks=[handler])
                    _mark_done(handler)
                    _show_answer(ret, sel)

        with tab_self_query:
            st.info("You can retrieve papers with button `Query` or ask questions based on retrieved papers with button `Ask`.", icon='💡')
            st.dataframe(st.session_state.sel_map_obj[sel]["metadata_columns"])
            st.text_input("Ask a question:", key='query_self')
            cols = st.columns([1, 1, 1, 4])
            cols[0].button("Query", key='search_self')
            cols[1].button("Ask", key='ask_self')
            cols[2].button("Back", key='back_self', on_click=back_to_main)

            if st.session_state.search_self:
                print(st.session_state.query_self)
                with _log_panel('Query Log',
                                ChatDataSelfSearchCallBackHandler) as handler:
                    docs = st.session_state.sel_map_obj[sel]["retriever"].get_relevant_documents(
                        st.session_state.query_self, callbacks=[handler])
                    print(docs)
                    _mark_done(handler)
                    display(_docs_to_frame(docs),
                            sel_map[sel]["must_have_cols"])

            if st.session_state.ask_self:
                print(st.session_state.query_self)
                with _log_panel('Chat Log',
                                ChatDataSelfAskCallBackHandler) as handler:
                    ret = st.session_state.sel_map_obj[sel]["chain"](
                        st.session_state.query_self, callbacks=[handler])
                    _mark_done(handler)
                    _show_answer(ret, sel)