import streamlit as st
import os
import pathlib
from typing import List

# local imports
from models.llms import load_llm, integrated_llms
from models.embeddings import hf_embed_model, openai_embed_model
from models.llamaCustom import LlamaCustom
from models.llamaCustomV2 import LlamaCustomV2

# from models.vector_database import pinecone_vector_store
from utils.chatbox import show_previous_messages, show_chat_input
from utils.util import validate_openai_api_key

# llama_index
from llama_index.core import (
    SimpleDirectoryReader,
    Document,
    VectorStoreIndex,
    StorageContext,
    Settings,
    load_index_from_storage,
)
from llama_index.core.memory import ChatMemoryBuffer
from llama_index.core.base.llms.types import ChatMessage

# huggingface
from huggingface_hub import HfApi

SAVE_DIR = "uploaded_files"
VECTOR_STORE_DIR = "vectorStores"
HF_REPO_ID = "zhtet/RegBotBeta"

# Opening assistant message; shared by the initial session state and by the
# "Clear Chat History" reset so the two can never drift apart.
GREETING = "How can I help you today?"

# global
# Settings.embed_model = hf_embed_model
Settings.embed_model = openai_embed_model

# huggingface api
hf_api = HfApi()


def init_session_state():
    """Seed ``st.session_state`` with every key this page reads.

    Keys that already exist are left untouched, so calling this on every
    script rerun does not overwrite an ongoing conversation or stored keys.
    """
    if "llama_messages" not in st.session_state:
        st.session_state.llama_messages = [
            {"role": "assistant", "content": GREETING}
        ]

    # TODO: create a chat history for each different document
    if "llama_chat_history" not in st.session_state:
        st.session_state.llama_chat_history = [
            ChatMessage.from_str(role="assistant", content=GREETING)
        ]

    if "llama_custom" not in st.session_state:
        st.session_state.llama_custom = None

    if "openai_api_key" not in st.session_state:
        st.session_state.openai_api_key = ""

    if "replicate_api_token" not in st.session_state:
        st.session_state.replicate_api_token = ""

    if "hf_token" not in st.session_state:
        st.session_state.hf_token = ""


# @st.cache_resource
def get_index(
    filename: str,
) -> VectorStoreIndex:
    """Return the vector index for ``filename``, building it if necessary.

    If a persisted index directory exists under ``VECTOR_STORE_DIR`` it is
    loaded from there; otherwise the file is read from ``SAVE_DIR``, indexed,
    and the new index is persisted to that same directory.

    Args:
        filename: Name of a file inside ``SAVE_DIR``. Dots in the name are
            replaced with underscores to form the index directory name.

    Returns:
        The loaded or newly built index.

    Raises:
        Exception: Any error raised while loading, reading or indexing is
            printed and then re-raised unchanged.
    """
    try:
        # A single path is used for both the existence check and persisting,
        # so the load and save locations always agree.
        index_path = pathlib.Path(VECTOR_STORE_DIR) / filename.replace(".", "_")

        if index_path.exists():
            print("Loading index from storage ...")
            storage_context = StorageContext.from_defaults(
                persist_dir=str(index_path)
            )
            index = load_index_from_storage(storage_context=storage_context)
        else:
            reader = SimpleDirectoryReader(input_files=[f"{SAVE_DIR}/{filename}"])
            docs = reader.load_data(show_progress=True)
            index = VectorStoreIndex.from_documents(
                documents=docs,
                show_progress=True,
            )
            index.storage_context.persist(persist_dir=str(index_path))
    except Exception as e:
        print(f"Error: {e}")
        # Bare raise keeps the original traceback intact.
        raise

    return index


# def get_pinecone_index(filename: str) -> VectorStoreIndex:
#     """This function loads the index from Pinecone if it exists, otherwise it creates a new index from the document."""
#     reader = SimpleDirectoryReader(input_files=[f"{SAVE_DIR}/{filename}"])
#     docs = reader.load_data(show_progress=True)
#     storage_context = StorageContext.from_defaults(vector_store=pinecone_vector_store)
#     index = VectorStoreIndex.from_documents(
#         documents=docs, show_progress=True, storage_context=storage_context
#     )
#     return index


def get_chroma_index(filename: str) -> VectorStoreIndex:
    """Placeholder for a Chroma-backed counterpart of ``get_index``.

    Not implemented yet: the body is empty, so the call currently returns
    ``None`` regardless of ``filename``.
    """
    pass


def check_api_key(model_name: str, source: str):
    """Prompt for the credential required by the selected model source.

    Nothing is rendered when the matching credential is already present in
    ``st.session_state``.

    Args:
        model_name: Name of the selected model. Currently unused; kept so
            existing callers keep working.
        source: Provider identifier. Prefixes ``"openai"``, ``"replicate"``
            and ``"huggingface"`` are recognised; anything else is ignored.
    """
    if source.startswith("openai"):
        if not st.session_state.openai_api_key:
            with st.expander("OpenAI API Key", expanded=True):
                openai_api_key = st.text_input(
                    label="Enter your OpenAI API Key:",
                    type="password",
                    help="Get your key from https://platform.openai.com/account/api-keys",
                    value=st.session_state.openai_api_key,
                )

                if openai_api_key:
                    # st.spinner is a context manager: it has to be entered
                    # with `with` for the spinner to be shown while the key
                    # is being validated.
                    with st.spinner("Validating OpenAI API Key ..."):
                        result = validate_openai_api_key(openai_api_key)

                    if result["status"] == "success":
                        st.session_state.openai_api_key = openai_api_key
                        st.success(result["message"])
                    else:
                        st.error(result["message"])
                        st.info("You can still select a different model to proceed.")
                        # Halt this run so nothing below uses the rejected key.
                        st.stop()

    elif source.startswith("replicate"):
        if not st.session_state.replicate_api_token:
            with st.expander("Replicate API Token", expanded=True):
                replicate_api_token = st.text_input(
                    label="Enter your Replicate API Token:",
                    type="password",
                    help="Get your key from https://replicate.ai/account",
                    value=st.session_state.replicate_api_token,
                )

                # TODO: need to validate the token
                if replicate_api_token:
                    st.session_state.replicate_api_token = replicate_api_token
                    # set the environment variable
                    os.environ["REPLICATE_API_TOKEN"] = replicate_api_token

    elif source.startswith("huggingface"):
        if not st.session_state.hf_token:
            with st.expander("Hugging Face Token", expanded=True):
                hf_token = st.text_input(
                    label="Enter your Hugging Face Token:",
                    type="password",
                    help="Get your key from https://huggingface.co/settings/token",
                    value=st.session_state.hf_token,
                )

                if hf_token:
                    st.session_state.hf_token = hf_token
                    # set the environment variable
                    os.environ["HF_TOKEN"] = hf_token


init_session_state()

st.set_page_config(page_title="Llama", page_icon="🦙")

st.header("California Drinking Water Regulation Chatbot - RegBot with LlamaIndex Demo")

tab1, tab2 = st.tabs(["Config", "Chat"])

with tab1:
    selected_llm_name = st.selectbox(
        label="Select a model:",
        options=[f"{key} | {value}" for key, value in integrated_llms.items()],
    )
    model_name, source = selected_llm_name.split("|")

    check_api_key(model_name=model_name.strip(), source=source.strip())

    # A missing upload directory is treated as "no files yet" instead of
    # crashing the whole page with FileNotFoundError.
    available_files = os.listdir(SAVE_DIR) if os.path.isdir(SAVE_DIR) else []
    selected_file = st.selectbox(
        label="Choose a file to chat with: ", options=available_files
    )

    if st.button("Clear all api keys"):
        st.session_state.openai_api_key = ""
        st.session_state.replicate_api_token = ""
        st.session_state.hf_token = ""
        st.success("All API keys cleared!")
        st.rerun()

    if st.button("Submit", key="submit", help="Submit the form"):
        with st.status("Loading ...", expanded=True) as status:
            try:
                # With no files available there is no selection; fail with a
                # clear message rather than an AttributeError in get_index.
                if not selected_file:
                    raise ValueError("No file selected!")

                st.write("Loading Model ...")
                llama_llm = load_llm(
                    model_name=model_name.strip(), source=source.strip()
                )
                if llama_llm is None:
                    raise ValueError("Model not found!")
                Settings.llm = llama_llm

                st.write("Processing Data ...")
                index = get_index(selected_file)
                # index = get_pinecone_index(selected_file)

                st.write("Finishing Up ...")
                llama_custom = LlamaCustom(model_name=selected_llm_name, index=index)
                # llama_custom = LlamaCustomV2(model_name=selected_llm_name, index=index)
                st.session_state.llama_custom = llama_custom

                status.update(label="Ready to query!", state="complete", expanded=False)
            except Exception as e:
                # Page-level boundary: surface any failure in the UI and halt
                # this run.
                status.update(label="Error!", state="error", expanded=False)
                st.error(f"Error: {e}")
                st.stop()

with tab2:
    messages_container = st.container(height=300)
    show_previous_messages(framework="llama", messages_container=messages_container)
    show_chat_input(
        disabled=False,
        framework="llama",
        model=st.session_state.llama_custom,
        messages_container=messages_container,
    )

    def clear_history():
        """Reset both chat transcripts to the single opening greeting."""
        messages_container.empty()
        st.session_state.llama_messages = [
            {"role": "assistant", "content": GREETING}
        ]
        st.session_state.llama_chat_history = [
            ChatMessage.from_str(role="assistant", content=GREETING)
        ]

    if st.button("Clear Chat History"):
        clear_history()
        # Rerun so the page is redrawn from the freshly reset session state.
        st.rerun()