|
import os |
|
import time |
|
|
|
import streamlit as st |
|
|
|
from backend.constants.streamlit_keys import DATA_INITIALIZE_NOT_STATED, DATA_INITIALIZE_COMPLETED, \ |
|
DATA_INITIALIZE_STARTED |
|
from backend.constants.variables import DATA_INITIALIZE_STATUS, JUMP_QUERY_ASK, CHAINS_RETRIEVERS_MAPPING, \ |
|
TABLE_EMBEDDINGS_MAPPING, RETRIEVER_TOOLS, USER_NAME, GLOBAL_CONFIG, update_global_config |
|
from backend.construct.build_all import build_chains_and_retrievers, load_embedding_models, update_retriever_tools |
|
from backend.types.global_config import GlobalConfig |
|
from logger import logger |
|
from ui.chat_page import chat_page |
|
from ui.home import render_home |
|
from ui.retrievers import render_retrievers |
|
|
|
|
|
|
|
|
|
def prepare_environment(): |
|
os.environ['TOKENIZERS_PARALLELISM'] = 'true' |
|
os.environ["LANGCHAIN_TRACING_V2"] = "false" |
|
|
|
os.environ["OPENAI_API_BASE"] = st.secrets['OPENAI_API_BASE'] |
|
os.environ["OPENAI_API_KEY"] = st.secrets['OPENAI_API_KEY'] |
|
os.environ["AUTH0_CLIENT_ID"] = st.secrets['AUTH0_CLIENT_ID'] |
|
os.environ["AUTH0_DOMAIN"] = st.secrets['AUTH0_DOMAIN'] |
|
|
|
update_global_config(GlobalConfig( |
|
openai_api_base=st.secrets['OPENAI_API_BASE'], |
|
openai_api_key=st.secrets['OPENAI_API_KEY'], |
|
auth0_client_id=st.secrets['AUTH0_CLIENT_ID'], |
|
auth0_domain=st.secrets['AUTH0_DOMAIN'], |
|
myscale_user=st.secrets['MYSCALE_USER'], |
|
myscale_password=st.secrets['MYSCALE_PASSWORD'], |
|
myscale_host=st.secrets['MYSCALE_HOST'], |
|
myscale_port=st.secrets['MYSCALE_PORT'], |
|
query_model="gpt-3.5-turbo-0125", |
|
chat_model="gpt-3.5-turbo-0125", |
|
untrusted_api=st.secrets['UNSTRUCTURED_API'], |
|
myscale_enable_https=st.secrets.get('MYSCALE_ENABLE_HTTPS', True), |
|
)) |
|
|
|
|
|
|
|
def initialize_session_state(): |
|
if DATA_INITIALIZE_STATUS not in st.session_state: |
|
st.session_state[DATA_INITIALIZE_STATUS] = DATA_INITIALIZE_NOT_STATED |
|
logger.info(f"Initialize session state key: {DATA_INITIALIZE_STATUS}") |
|
if JUMP_QUERY_ASK not in st.session_state: |
|
st.session_state[JUMP_QUERY_ASK] = False |
|
logger.info(f"Initialize session state key: {JUMP_QUERY_ASK}") |
|
|
|
|
|
def initialize_chat_data(): |
|
if st.session_state[DATA_INITIALIZE_STATUS] != DATA_INITIALIZE_COMPLETED: |
|
start_time = time.time() |
|
st.session_state[DATA_INITIALIZE_STATUS] = DATA_INITIALIZE_STARTED |
|
st.session_state[TABLE_EMBEDDINGS_MAPPING] = load_embedding_models() |
|
st.session_state[CHAINS_RETRIEVERS_MAPPING] = build_chains_and_retrievers() |
|
st.session_state[RETRIEVER_TOOLS] = update_retriever_tools() |
|
|
|
st.session_state[DATA_INITIALIZE_STATUS] = DATA_INITIALIZE_COMPLETED |
|
end_time = time.time() |
|
logger.info(f"ChatData initialized finished in {round(end_time - start_time, 3)} seconds, " |
|
f"session state keys: {list(st.session_state.keys())}") |
|
|
|
|
|
st.set_page_config( |
|
page_title="ChatData", |
|
page_icon="https://myscale.com/favicon.ico", |
|
initial_sidebar_state="expanded", |
|
layout="wide", |
|
) |
|
|
|
prepare_environment() |
|
initialize_session_state() |
|
initialize_chat_data() |
|
|
|
if USER_NAME in st.session_state: |
|
chat_page() |
|
else: |
|
if st.session_state[JUMP_QUERY_ASK]: |
|
render_retrievers() |
|
else: |
|
render_home() |
|
|