# RegBotBeta / pages /llama_custom_demo.py
# hbui's picture
# Update pages/llama_custom_demo.py
# 12de659 verified
# raw
# history blame
# 8.88 kB
import streamlit as st
import os
import pathlib
from typing import List
# local imports
from models.llms import load_llm, integrated_llms
from models.embeddings import hf_embed_model, openai_embed_model
from models.llamaCustom import LlamaCustom
from models.llamaCustomV2 import LlamaCustomV2
# from models.vector_database import pinecone_vector_store
from utils.chatbox import show_previous_messages, show_chat_input
from utils.util import validate_openai_api_key
# llama_index
from llama_index.core import (
SimpleDirectoryReader,
Document,
VectorStoreIndex,
StorageContext,
Settings,
load_index_from_storage,
)
from llama_index.core.memory import ChatMemoryBuffer
from llama_index.core.base.llms.types import ChatMessage
# huggingface
from huggingface_hub import HfApi
# Directory holding the documents users can chat with; listed by the file picker
# and read by get_index.
SAVE_DIR = "uploaded_files"
# Root directory for persisted vector indexes; get_index uses one sub-directory
# per document (the filename with "." replaced by "_").
VECTOR_STORE_DIR = "vectorStores"
# Hugging Face repository id of this app (not referenced by the code in this page as shown).
HF_REPO_ID = "zhtet/RegBotBeta"
# Global llama_index setting: the embedding model used by every index built or
# loaded in this process (the Hugging Face alternative is kept commented out).
# NOTE(review): indexes persisted under VECTOR_STORE_DIR were presumably embedded
# with this model; switching models likely requires rebuilding them — confirm.
# Settings.embed_model = hf_embed_model
Settings.embed_model = openai_embed_model
# Hugging Face Hub API client (not used by the code in this page as shown).
hf_api = HfApi()
def init_session_state():
    """Populate ``st.session_state`` with defaults for every key this page uses.

    A key is only initialised when it is absent, so calling this on every
    Streamlit rerun keeps existing chat history, loaded model and credentials.
    Defaults are produced by factories so each session gets fresh list objects.
    """
    greeting = "How can I help you today?"
    default_factories = {
        # Messages rendered in the chat UI.
        "llama_messages": lambda: [{"role": "assistant", "content": greeting}],
        # TODO: create a chat history for each different document
        "llama_chat_history": lambda: [
            ChatMessage.from_str(role="assistant", content=greeting)
        ],
        # Set to a LlamaCustom instance once the user presses Submit.
        "llama_custom": lambda: None,
        "openai_api_key": lambda: "",
        "replicate_api_token": lambda: "",
        "hf_token": lambda: "",
    }
    for state_key, make_default in default_factories.items():
        if state_key in st.session_state:
            continue
        st.session_state[state_key] = make_default()
# @st.cache_resource
def get_index(
    filename: str,
) -> VectorStoreIndex:
    """Load the vector index for ``filename``, building and persisting it on first use.

    The index is stored under ``VECTOR_STORE_DIR/<filename with "." replaced by "_">``.
    If that directory exists, the persisted index is loaded from it. Otherwise the
    file ``SAVE_DIR/filename`` is read, embedded into a new ``VectorStoreIndex``
    (using the global llama_index ``Settings``), and persisted to that same
    directory so later calls can reuse it.

    Args:
        filename: Name of a file inside ``SAVE_DIR``.

    Returns:
        The loaded or newly built index.

    Raises:
        Exception: any error from reading, indexing or persisting is printed and
            re-raised unchanged.
    """
    try:
        # Derive the directory once so loading and persisting can never disagree
        # (persisting previously hard-coded "vectorStores" instead of VECTOR_STORE_DIR).
        index_path = pathlib.Path(VECTOR_STORE_DIR) / filename.replace(".", "_")
        if index_path.exists():
            print("Loading index from storage ...")
            storage_context = StorageContext.from_defaults(persist_dir=str(index_path))
            index = load_index_from_storage(storage_context=storage_context)
        else:
            reader = SimpleDirectoryReader(input_files=[f"{SAVE_DIR}/{filename}"])
            docs = reader.load_data(show_progress=True)
            index = VectorStoreIndex.from_documents(
                documents=docs,
                show_progress=True,
            )
            index.storage_context.persist(persist_dir=str(index_path))
    except Exception as e:
        print(f"Error: {e}")
        # Bare raise keeps the original traceback intact.
        raise
    return index
# def get_pinecone_index(filename: str) -> VectorStoreIndex:
# """This function builds a Pinecone-backed index from the document (it does not check for an existing index)."""
# reader = SimpleDirectoryReader(input_files=[f"{SAVE_DIR}/{filename}"])
# docs = reader.load_data(show_progress=True)
# storage_context = StorageContext.from_defaults(vector_store=pinecone_vector_store)
# index = VectorStoreIndex.from_documents(
# documents=docs, show_progress=True, storage_context=storage_context
# )
# return index
def get_chroma_index(filename: str) -> VectorStoreIndex:
    """Load or build a Chroma-backed index for ``filename`` — not implemented yet.

    Intended to load the index from Chroma if it exists, otherwise create a new
    index from the document. The previous stub silently returned ``None``
    despite the return annotation; raising makes accidental use fail loudly at
    the call site instead of deep inside the chat engine.

    Raises:
        NotImplementedError: always, until Chroma support is added.
    """
    raise NotImplementedError(
        f"Chroma-backed indexes are not implemented yet (requested for {filename!r})."
    )
def check_api_key(model_name: str, source: str):
    """Prompt for and store the credential needed by the selected model's provider.

    The prompt is shown only while the provider's credential is still missing
    from ``st.session_state``:

    * ``source`` starting with "openai": the entered key is validated with
      ``validate_openai_api_key``. On success it is saved to
      ``st.session_state.openai_api_key``; on failure an error is shown and the
      script run is halted with ``st.stop()``.
    * "replicate": the token is saved to session state and exported as the
      ``REPLICATE_API_TOKEN`` environment variable (no validation yet).
    * "huggingface": the token is saved to session state and exported as the
      ``HF_TOKEN`` environment variable (no validation).

    Args:
        model_name: Name of the selected model (currently unused; kept for callers).
        source: Provider identifier, matched by prefix.
    """
    if source.startswith("openai"):
        if not st.session_state.openai_api_key:
            with st.expander("OpenAI API Key", expanded=True):
                openai_api_key = st.text_input(
                    label="Enter your OpenAI API Key:",
                    type="password",
                    help="Get your key from https://platform.openai.com/account/api-keys",
                    value=st.session_state.openai_api_key,
                )
            if openai_api_key:
                # st.spinner is a context manager. The old `... and st.spinner(...)`
                # only tested the (always truthy) object and never showed the spinner.
                with st.spinner("Validating OpenAI API Key ..."):
                    result = validate_openai_api_key(openai_api_key)
                if result["status"] == "success":
                    st.session_state.openai_api_key = openai_api_key
                    st.success(result["message"])
                else:
                    st.error(result["message"])
                    st.info("You can still select a different model to proceed.")
                    st.stop()
    elif source.startswith("replicate"):
        if not st.session_state.replicate_api_token:
            with st.expander("Replicate API Token", expanded=True):
                replicate_api_token = st.text_input(
                    label="Enter your Replicate API Token:",
                    type="password",
                    help="Get your key from https://replicate.ai/account",
                    value=st.session_state.replicate_api_token,
                )
            # TODO: need to validate the token
            if replicate_api_token:
                st.session_state.replicate_api_token = replicate_api_token
                # set the environment variable
                # NOTE(review): os.environ is process-wide, so in a multi-user
                # deployment this token may be visible to other sessions — confirm.
                os.environ["REPLICATE_API_TOKEN"] = replicate_api_token
    elif source.startswith("huggingface"):
        if not st.session_state.hf_token:
            with st.expander("Hugging Face Token", expanded=True):
                hf_token = st.text_input(
                    label="Enter your Hugging Face Token:",
                    type="password",
                    help="Get your key from https://huggingface.co/settings/token",
                    value=st.session_state.hf_token,
                )
            if hf_token:
                st.session_state.hf_token = hf_token
                # set the environment variable
                # NOTE(review): os.environ is process-wide (see the Replicate branch).
                os.environ["HF_TOKEN"] = hf_token
init_session_state()
st.set_page_config(page_title="Llama", page_icon="🦙")
st.header("California Drinking Water Regulation Chatbot - RegBot with LlamaIndex Demo")
tab1, tab2 = st.tabs(["Config", "Chat"])

# Config tab: pick a model and a document, supply credentials, then build the engine.
with tab1:
    selected_llm_name = st.selectbox(
        label="Select a model:",
        options=[f"{key} | {value}" for key, value in integrated_llms.items()],
    )
    # Options are rendered as "<model> | <source>". Split only on the first "|"
    # so a "|" inside the source part cannot break the unpacking. Strip once here.
    raw_model_name, raw_source = selected_llm_name.split("|", 1)
    model_name, source = raw_model_name.strip(), raw_source.strip()
    check_api_key(model_name=model_name, source=source)

    # os.listdir returns entries in arbitrary order; sort so the list and the
    # default selection are stable.
    selected_file = st.selectbox(
        label="Choose a file to chat with: ", options=sorted(os.listdir(SAVE_DIR))
    )

    if st.button("Clear all api keys"):
        # NOTE(review): only session state is cleared; REPLICATE_API_TOKEN / HF_TOKEN
        # previously exported to os.environ by check_api_key are left in place.
        st.session_state.openai_api_key = ""
        st.session_state.replicate_api_token = ""
        st.session_state.hf_token = ""
        st.success("All API keys cleared!")
        st.rerun()

    if st.button("Submit", key="submit", help="Submit the form"):
        with st.status("Loading ...", expanded=True) as status:
            try:
                st.write("Loading Model ...")
                llama_llm = load_llm(model_name=model_name, source=source)
                if llama_llm is None:
                    raise ValueError("Model not found!")
                # Global LLM used by llama_index components created afterwards.
                Settings.llm = llama_llm

                st.write("Processing Data ...")
                index = get_index(selected_file)
                # index = get_pinecone_index(selected_file)

                st.write("Finishing Up ...")
                llama_custom = LlamaCustom(model_name=selected_llm_name, index=index)
                # llama_custom = LlamaCustomV2(model_name=selected_llm_name, index=index)
                st.session_state.llama_custom = llama_custom
                status.update(label="Ready to query!", state="complete", expanded=False)
            except Exception as e:
                status.update(label="Error!", state="error", expanded=False)
                st.error(f"Error: {e}")
                st.stop()
# Chat tab: transcript, input box, and a button to reset the conversation.
with tab2:
    # Fixed-height scrollable area holding the conversation transcript.
    messages_container = st.container(height=300)
    show_previous_messages(framework="llama", messages_container=messages_container)

    # None until the user has pressed Submit on the Config tab.
    current_model = st.session_state.llama_custom
    show_chat_input(
        disabled=False,
        framework="llama",
        model=current_model,
        messages_container=messages_container,
    )

    def clear_history():
        """Wipe the rendered transcript and reset both histories to the greeting."""
        messages_container.empty()
        opening_line = "How can I help you today?"
        st.session_state.llama_messages = [
            {"role": "assistant", "content": opening_line}
        ]
        st.session_state.llama_chat_history = [
            ChatMessage.from_str(role="assistant", content=opening_line)
        ]

    clear_clicked = st.button("Clear Chat History")
    if clear_clicked:
        clear_history()
        st.rerun()