|
import streamlit as st |
|
from streamlit_feedback import streamlit_feedback |
|
|
|
import os |
|
import pandas as pd |
|
import base64 |
|
from io import BytesIO |
|
import sqlite3 |
|
import uuid |
|
import yaml |
|
|
|
import chromadb |
|
from llama_index.core import ( |
|
VectorStoreIndex, |
|
SimpleDirectoryReader, |
|
StorageContext, |
|
Document |
|
) |
|
from llama_index.vector_stores.chroma.base import ChromaVectorStore |
|
from llama_index.embeddings.huggingface.base import HuggingFaceEmbedding |
|
from llama_index.llms.openai import OpenAI |
|
from llama_index.core.memory import ChatMemoryBuffer |
|
from llama_index.core.tools import QueryEngineTool |
|
from llama_index.agent.openai import OpenAIAgent |
|
from llama_index.core import Settings |
|
|
|
from vision_api import get_transcribed_text |
|
from qna_prompting import get_qna_question_tool, evaluate_qna_answer_tool |
|
from prompt_engineering import ( |
|
system_content, |
|
textbook_content, |
|
winnie_the_pooh_prompt, |
|
introduction_line |
|
) |
|
|
|
import nest_asyncio |
|
# Patch asyncio (via nest_asyncio) so an already-running event loop can be
# re-entered; presumably needed by llama_index's sync wrappers inside Streamlit.
nest_asyncio.apply()

# Page configuration must be the first Streamlit call of the script.
st.set_page_config(page_title="π»π Study Bear π―")

# API key from the environment; None when the variable is unset.
openai_api = os.environ.get("OPENAI_API_KEY")
|
|
|
# Load paths and model names from the YAML config. `safe_load` only builds
# plain Python types, so the file cannot instantiate arbitrary objects.
# The encoding is explicit so parsing does not depend on the platform locale.
with open("./config/model_config_advanced.yml", "r", encoding="utf-8") as file_reader:
    model_config = yaml.safe_load(file_reader)

input_files = model_config["input_data"]["source"]
embedding_model = model_config["embeddings"]["embedding_base_model"]
# May be None (YAML null); get_embedding_model falls back to HuggingFace then.
fine_tuned_path = model_config["embeddings"]["fine_tuned_embedding_model"]
persisted_vector_db = model_config["vector_store"]["persisted_path"]
questionaire_db_path = model_config["questionaire_data"]["db_path"]
|
|
|
# Hard-coded per-chapter completion percentages shown in the sidebar's
# progress table.
data_df = pd.DataFrame(
    {"Completion": [30, 40, 100, 10]},
    index=[f"Chapter {chapter_no}" for chapter_no in range(1, 5)],
)

# Chat avatars: the bear is used for assistant messages, piglet for the user.
bear_img_path = "./resource/disney-cuties-little-winnie-the-pooh-emoticon.png"
piglet_img_path = "./resource/disney-cuties-piglet-emoticon.png"
|
|
|
|
|
# Sidebar: persona toggle, API-key resolution, model parameters and the
# chapter progress table.
with st.sidebar:
    st.title("π―π Study Bear π»π")
    st.write("Just like Pooh needs honey, success requires hard work β no shortcuts allowed!")

    wtp_mode = st.toggle('Winnie-the-Pooh mode', value=False)
    if wtp_mode:
        system_content = system_content + winnie_the_pooh_prompt
        # NOTE(review): this line's original indentation was lost; it is
        # assumed to belong to the Winnie-the-Pooh branch. Confirm.
        textbook_content = system_content + textbook_content

    # Key resolution order: environment, then st.secrets, then a password box.
    if openai_api:
        pass
    elif "OPENAI_API_KEY" in st.secrets:
        st.success("API key already provided!", icon="✅")
        openai_api = st.secrets["OPENAI_API_KEY"]
    else:
        openai_api = st.text_input("Enter OpenAI API token:", type="password")
        # Only the prefix is checked. Legacy keys were 51 characters, but
        # project/service keys ("sk-proj-...", "sk-svcacct-...") are longer,
        # so an exact-length check wrongly warns on valid keys.
        if not openai_api.startswith("sk-"):
            st.warning("Please enter your credentials!", icon="⚠️")
        else:
            st.success("Proceed to entering your prompt message!", icon="👉")

    os.environ["OPENAI_API_KEY"] = openai_api

    st.subheader("Models and parameters")
    selected_model = st.sidebar.selectbox(
        label="Choose an OpenAI model",
        options=["gpt-3.5-turbo-0125", "gpt-4-0125-preview"],
        index=1,
        key="selected_model",
    )
    temperature = st.sidebar.slider(
        "temperature", min_value=0.0, max_value=2.0, value=0.0, step=0.01
    )

    st.data_editor(
        data_df,
        column_config={
            "Completion": st.column_config.ProgressColumn(
                "Completion %",
                help="Percentage of content covered",
                format="%.1f%%",
                min_value=0,
                max_value=100,
            ),
        },
        hide_index=False,
    )

    st.markdown("π Reach out to SakiMilo to learn how to create this app!")
|
|
|
# One-time per-session flags: `init["warm_started"]` gates the first-run
# agent reset, and `feedback` controls whether the feedback widget is shown.
if "init" not in st.session_state:
    st.session_state.init = {"warm_started": "No"}
    st.session_state.feedback = False

# Remaining per-session keys as (name, factory) pairs. A factory runs only
# when its key is missing, so existing values (e.g. user_id) survive reruns.
_SESSION_DEFAULTS = (
    ("image_prompt", lambda: False),
    ("messages", lambda: [{"role": "assistant",
                           "content": introduction_line,
                           "type": "text"}]),
    ("feedback_key", lambda: 0),
    ("release_file", lambda: "false"),
    ("question_id", lambda: None),
    ("qna_answer_int", lambda: None),
    ("qna_answer_str", lambda: None),
    ("reasons", lambda: None),
    ("user_id", lambda: str(uuid.uuid4())),
)
for _state_key, _make_default in _SESSION_DEFAULTS:
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _make_default()
|
|
|
def clear_chat_history():
    """Reset the visible conversation and the agent's chat state.

    Replaces ``st.session_state.messages`` with the single introduction
    message, fetches the (cached) agent for the current sidebar settings,
    calls its ``reset()``, and shows a toast.
    """
    intro_message = {
        "role": "assistant",
        "content": introduction_line,
        "type": "text",
    }
    st.session_state.messages = [intro_message]

    agent = get_query_engine(
        input_files=input_files,
        llm_model=selected_model,
        temperature=temperature,
        embedding_model=embedding_model,
        fine_tuned_path=fine_tuned_path,
        system_content=system_content,
        persisted_vector_db=persisted_vector_db,
    )
    agent.reset()

    st.toast("yumyum, what was I saying again? π»π¬", icon="π―")
|
|
|
def clear_question_history(user_id):
    """Delete every stored answer row for ``user_id`` from the questionnaire DB.

    Args:
        user_id: Identifier whose rows in ``answer_tbl`` are removed.

    The value is bound as a query parameter rather than spliced into the SQL
    text. The connection is closed even if the DELETE raises.
    """
    con = sqlite3.connect(questionaire_db_path)
    try:
        # `with con` commits on success and rolls back on error; it does not
        # close the connection, hence the surrounding try/finally.
        with con:
            con.execute("DELETE FROM answer_tbl WHERE user_id = ?", (user_id,))
    finally:
        con.close()
    st.toast("the tale of one thousand and one questions, reset! π§¨π§¨", icon="π")
|
|
|
# Sidebar actions: two reset buttons wired to callbacks, plus a button that
# opens the feedback widget rendered at the bottom of the page.
st.sidebar.button("Clear Chat History", on_click=clear_chat_history)
st.sidebar.button(
    "Clear Question History",
    on_click=clear_question_history,
    kwargs=dict(user_id=st.session_state.user_id),
)

feedback_requested = st.sidebar.button("I want to submit a feedback!")
if feedback_requested:
    st.session_state.feedback = True
    # Bumping the counter gives the feedback widget a new key (fresh state).
    st.session_state.feedback_key += 1
|
|
|
@st.cache_resource
def get_document_object(input_files):
    """Load ``input_files`` and merge them into one llama_index Document.

    The loaded documents' texts are joined with a blank line between them.
    Cached by Streamlit, so loading happens once per distinct argument.
    """
    loaded_docs = SimpleDirectoryReader(input_files=input_files).load_data()
    merged_text = "\n\n".join(doc.text for doc in loaded_docs)
    return Document(text=merged_text)
|
|
|
@st.cache_resource
def get_llm_object(selected_model, temperature):
    """Return a cached llama_index ``OpenAI`` LLM for a model name and temperature."""
    return OpenAI(temperature=temperature, model=selected_model)
|
|
|
@st.cache_resource
def get_embedding_model(model_name, fine_tuned_path=None):
    """Return the embedding model to install in llama_index ``Settings``.

    Args:
        model_name: HuggingFace model id, used when no fine-tuned path is given.
        fine_tuned_path: Optional local model spec.

    Returns:
        A ``HuggingFaceEmbedding`` when ``fine_tuned_path`` is None.
        Otherwise ``fine_tuned_path`` itself, unchanged.
    """
    if fine_tuned_path is not None:
        print(f"loading from local `{fine_tuned_path}`")
        # The raw value is handed to Settings.embed_model, which presumably
        # resolves string specs (e.g. "local:<dir>"). TODO confirm the format.
        return fine_tuned_path

    print(f"loading from `{model_name}` from huggingface")
    return HuggingFaceEmbedding(model_name=model_name)
|
|
|
# Description the agent sees for the textbook query tool.
_HI_TEXTBOOK_QUERY_DESCRIPTION = """
Use this tool to extract content from the query engine,
which is built by ingesting textbook content from `Health Insurance 7th Edition`,
that has 15 chapters in total. When user wants to learn more about a
particular chapter, this tool will help to assist user to get better
understanding of the content of the textbook.
"""


def _open_chroma_vector_store(persisted_vector_db):
    """Open (creating if needed) the "quickstart" Chroma collection at the path."""
    client = chromadb.PersistentClient(path=persisted_vector_db)
    collection = client.get_or_create_collection("quickstart")
    return ChromaVectorStore(chroma_collection=collection)


@st.cache_resource
def get_query_engine(input_files, llm_model, temperature,
                     embedding_model, fine_tuned_path,
                     system_content, persisted_vector_db):
    """Build the tool-using OpenAI agent that answers chat messages.

    Steps:
        1. Create the LLM and embedding model and install them, with
           ``chunk_size=1024``, in llama_index's global ``Settings``.
        2. If ``persisted_vector_db`` exists, reuse the Chroma index stored
           there. Otherwise read ``input_files``, chunk them with
           ``Settings.node_parser`` and index them into a new Chroma
           collection at that path.
        3. Expose the index as a streaming, top-10 query engine tool named
           "health_insurance_textbook_query_engine".
        4. Combine that tool with the QnA question/evaluation tools.

    Returns:
        An ``OpenAIAgent`` limited to one function call per turn, using
        ``textbook_content`` as its system prompt.

    Cached by Streamlit per argument combination. The global ``Settings``
    are therefore only updated on a cache miss. ``textbook_content`` is read
    from module scope, so it is not part of the cache key.
    """
    llm = get_llm_object(llm_model, temperature)
    embed_model = get_embedding_model(
        model_name=embedding_model,
        fine_tuned_path=fine_tuned_path,
    )
    Settings.llm = llm
    Settings.chunk_size = 1024
    Settings.embed_model = embed_model

    if not os.path.exists(persisted_vector_db):
        print("create new chroma vector database..")
        documents = SimpleDirectoryReader(input_files=input_files).load_data()
        vector_store = _open_chroma_vector_store(persisted_vector_db)
        nodes = Settings.node_parser.get_nodes_from_documents(documents)
        storage_context = StorageContext.from_defaults(vector_store=vector_store)
        storage_context.docstore.add_documents(nodes)
        index = VectorStoreIndex(nodes, storage_context=storage_context)
    else:
        print("loading from vector database - chroma")
        vector_store = _open_chroma_vector_store(persisted_vector_db)
        storage_context = StorageContext.from_defaults(vector_store=vector_store)
        index = VectorStoreIndex.from_vector_store(
            vector_store=vector_store,
            storage_context=storage_context,
        )

    # NOTE(review): this memory buffer is handed to the query engine, while
    # the agent below is created without a `memory=` argument. Confirm which
    # component is meant to own the conversation memory.
    memory = ChatMemoryBuffer.from_defaults(token_limit=100_000)
    textbook_engine = index.as_query_engine(
        memory=memory,
        system_prompt=system_content,
        similarity_top_k=10,
        verbose=True,
        streaming=True,
    )
    textbook_tool = QueryEngineTool.from_defaults(
        query_engine=textbook_engine,
        name="health_insurance_textbook_query_engine",
        description=_HI_TEXTBOOK_QUERY_DESCRIPTION,
    )

    agent_tools = [textbook_tool, get_qna_question_tool, evaluate_qna_answer_tool]
    agent = OpenAIAgent.from_tools(
        tools=agent_tools,
        max_function_calls=1,
        llm=llm,
        verbose=True,
        system_prompt=textbook_content,
    )

    print("loaded AI agent, let's begin the chat!")
    print("=" * 50)
    print("")
    return agent
|
|
|
def generate_llm_response(prompt_input, tool_choice="auto"):
    """Send ``prompt_input`` to the cached agent and return its streaming reply.

    Args:
        prompt_input: Text to send (typed prompt or image transcription).
        tool_choice: Passed straight to ``stream_chat``. "auto" leaves tool
            selection to the agent; callers pass a tool name to steer it.

    Returns:
        The agent's streaming chat response; callers iterate ``response_gen``.
    """
    engine_kwargs = dict(
        input_files=input_files,
        llm_model=selected_model,
        temperature=temperature,
        embedding_model=embedding_model,
        fine_tuned_path=fine_tuned_path,
        system_content=system_content,
        persisted_vector_db=persisted_vector_db,
    )
    agent = get_query_engine(**engine_kwargs)
    return agent.stream_chat(prompt_input, tool_choice=tool_choice)
|
|
|
def handle_feedback(user_response):
    """``on_submit`` callback for the feedback widget.

    Shows an acknowledgement toast and hides the widget. The submitted
    ``user_response`` is currently not stored anywhere.
    """
    st.toast("βοΈ Feedback received!")
    st.session_state["feedback"] = False
|
|
|
def handle_image_upload():
    """``on_change`` callback for the uploader: mark the new file as pending."""
    st.session_state["release_file"] = "true"
|
|
|
|
|
# First run of a session: build (or fetch) the agent and reset it once.
# The `init` dict is mutated in place, so the flag persists in session state.
_session_flags = st.session_state.init
if _session_flags["warm_started"] == "No":
    clear_chat_history()
    _session_flags["warm_started"] = "Yes"
|
|
|
|
|
with st.sidebar:
    image_file = st.file_uploader(
        "Upload your image here...",
        type=["png", "jpeg", "jpg"],
        on_change=handle_image_upload,
    )

# Process an upload once: the on_change callback raises the flag, and it is
# lowered after the image and its transcription have been queued.
if st.session_state.release_file == "true" and image_file:
    with st.spinner("Uploading..."):
        encoded_image = base64.b64encode(image_file.read()).decode('utf-8')
        st.session_state.messages.append(
            {"role": "user", "content": encoded_image, "type": "image"}
        )

        # The transcription is stored as a hidden "admin" message. The chat
        # renderer skips "admin" messages, and the prompt logic forwards the
        # latest one to the agent.
        transcription = get_transcribed_text(encoded_image)
        st.session_state.messages.append(
            {"role": "admin", "content": transcription, "type": "text"}
        )
        st.session_state.release_file = "false"
|
|
|
|
|
# Replay the stored conversation. "admin" entries (image transcriptions) are
# internal and never displayed.
for chat_entry in st.session_state.messages:
    entry_role = chat_entry["role"]
    if entry_role == "admin":
        continue
    if entry_role == "user":
        avatar = piglet_img_path
    elif entry_role == "assistant":
        avatar = bear_img_path

    with st.chat_message(entry_role, avatar=avatar):
        entry_type = chat_entry["type"]
        if entry_type == "text":
            st.write(chat_entry["content"])
        elif entry_type == "image":
            # Images are stored base64-encoded; decode back to bytes for st.image.
            image_bytes = base64.b64decode(chat_entry["content"].encode("utf-8"))
            st.image(BytesIO(image_bytes))
|
|
|
|
|
# Input box is disabled until an API key is available. A submitted prompt is
# stored and echoed immediately.
prompt = st.chat_input(disabled=not openai_api)
if prompt:
    st.session_state.messages.append(
        {"role": "user", "content": prompt, "type": "text"}
    )
    with st.chat_message("user", avatar=piglet_img_path):
        st.write(prompt)
|
|
|
|
|
# With no typed prompt this run, a freshly queued image transcription (the
# last message, role "admin") becomes the prompt. It is flagged so the
# textbook tool is used.
pending_transcription = (
    prompt is None and st.session_state.messages[-1]["role"] == "admin"
)
if pending_transcription:
    st.session_state.image_prompt = True
    prompt = st.session_state.messages[-1]["content"]
|
|
|
|
|
def _to_streamlit_markdown(text):
    r"""Adapt raw LLM output for ``st.markdown``.

    - Every newline becomes a Markdown hard break ("  \n"). CommonMark needs
      two trailing spaces; a single space is only a soft break.
    - Literal dollar signs are escaped so they are not parsed as math.
    - LaTeX display delimiters ``\[`` and ``\]`` both become ``$$``, so
      display math renders. Previously only the opening one was converted.

    The dollar escaping runs before the delimiter conversion, so the inserted
    ``$$`` are not escaped.
    """
    return (
        text.replace("\n", "  \n")
        .replace("$", r"\$")
        .replace(r"\[", "$$")
        .replace(r"\]", "$$")
    )


# Produce an assistant reply whenever the latest message is not one already.
if st.session_state.messages[-1]["role"] != "assistant":
    with st.chat_message("assistant", avatar=bear_img_path):
        with st.spinner("π§Έπ€ Thinking... π»π"):
            if st.session_state.image_prompt:
                # Image transcriptions are always routed to the textbook tool.
                response = generate_llm_response(
                    prompt,
                    tool_choice="health_insurance_textbook_query_engine",
                )
                st.session_state.image_prompt = False
            else:
                response = generate_llm_response(prompt, tool_choice="auto")

            placeholder = st.empty()
            raw_response = ""
            for token in response.response_gen:
                raw_response += token
                # Reformat the accumulated text, not each token, so a
                # delimiter split across tokens (e.g. "\" then "[") is still
                # converted.
                placeholder.markdown(_to_streamlit_markdown(raw_response))
            full_response = _to_streamlit_markdown(raw_response)
            placeholder.markdown(full_response)

    message = {"role": "assistant",
               "content": full_response,
               "type": "text"}
    st.session_state.messages.append(message)
|
|
|
|
|
# Thumbs-up/down widget, shown after the sidebar feedback button is pressed.
# Its key includes the counter so each request gets a fresh widget.
if st.session_state.feedback:
    feedback_widget_key = "feedback_{}".format(st.session_state.feedback_key)
    result = streamlit_feedback(
        key=feedback_widget_key,
        feedback_type="thumbs",
        optional_text_label="[Optional] Please provide an explanation",
        on_submit=handle_feedback,
    )