MarioBarbeque's picture
fix typo
2946731 verified
import os
import threading
import streamlit as st
from itertools import tee
from chain import ChainBuilder
# Databricks credentials are read from the environment. Fail fast at import
# time when either one is missing so the app never starts half-configured.
# TODO: remove these secrets from the container
# VS_INDEX_NAME = os.environ.get("VS_INDEX_NAME")
DATABRICKS_HOST = os.environ.get("DATABRICKS_HOST")
DATABRICKS_TOKEN = os.environ.get("DATABRICKS_TOKEN")

# The raises must be guarded: an unconditional raise aborts every import.
if DATABRICKS_HOST is None:
    raise ValueError("DATABRICKS_HOST environment variable must be set")
if DATABRICKS_TOKEN is None:
    raise ValueError("DATABRICKS_TOKEN environment variable must be set")
# Upper bound on the number of turns kept in one conversation history
# (kept low for preliminary testing).
MAX_CHAT_TURNS = 10
# Shown to the user once the turn limit above has been reached.
MSG_MAX_TURNS_EXCEEDED = (
    "Sorry! The Vanderbilt AI assistant playground is limited to "
    f"{MAX_CHAT_TURNS} turns in a single history. "
    "Click the 'Clear Chat' button or refresh the page to start a new conversation."
)
# MSG_CLIPPED_AT_MAX_OUT_TOKENS = "Reached maximum output tokens for DBRX Playground"
# Canned queries rendered as sidebar buttons; clicking one submits it to the
# assistant as if the user had typed it.
EXAMPLE_PROMPTS = [
    "How is a data lake used at Vanderbilt University Medical Center?",
    "In a table, what are some of the greatest hurdles to healthcare in the United States?",
    "What does EDW stand for in the context of Vanderbilt University Medical Center?",
    "Code a sql statement that can query a database named 'VUMC'.",
    "Write a short story about a country concert in Nashville, Tennessee.",
    "Tell me about maximum out-of-pocket costs in healthcare.",
]
# Page title and the markdown blurb describing the assistant.
TITLE = "Vanderbilt AI Assistant"
# NOTE(review): the two contact e-mail placeholders in this text (the empty
# backticks and the empty bold span) carry no address in this view; confirm
# the intended address before publishing.
# The thumbs-up / thumbs-down emoji were mis-encoded and are restored here.
DESCRIPTION= """Welcome to the first generation Vanderbilt AI assistant! \n
**WARNING**: Unfortunately this space is currently deprecated. The serving endpoint used to serve the pay-per-token Databricks DBRX language model has been rate-limited by
staff for security reasons to accept no queries. I am in the process of reworking this augmented model to hit an available endpoint for community use. Nonetheless, if you are interested
in seeing this model's functionality in a 24 hour time window, send me an email at ``, or the email below, and I will temporarily activate the serving endpoint
for you to query the model. \n
**Overview and Usage**: This AI assistant is built atop the Databricks DBRX large language model
and is augmented with additional organization-specific knowledge. Particularly, it has been preliminarily augmented with knowledge of Vanderbilt University Medical Center
terms like **EDW**, **HCERA**, **NRHA** and **thousands more**. (Ask the assistant if you don't know what any of these terms mean!) On the left is a sidebar of **Examples**;
click any of these examples to issue the corresponding query to the AI.
**Feedback**: Feedback is welcomed, encouraged, and invaluable! To give feedback in regards to one of the model's responses, click the **Give Feedback on Last Response** button just below
the user input bar. This allows you to provide either positive or negative feedback in regards to the model's most recent response. A **Feedback Form** will appear above the model's title.
Please be sure to select either 👍 or 👎 before adding additional notes about your choice. Be as brief or as detailed as you like! Note that you are making a difference; this
feedback allows us to later improve this model for your usage through a training technique known as reinforcement learning through human feedback. \n
**Disclaimer**: The model has **no access to PHI**. \n
Please provide any additional, larger feedback, ideas, or issues to the email: ****. Happy chatting!"""
# Generic user-facing message; chat_completion yields it when an error has
# been recorded for the model call.
GENERAL_ERROR_MSG = "An error occurred. Please refresh the page to start a new conversation."
# To prevent streaming too fast, chat_completion accumulates model output and
# emits it once every TOKEN_CHUNK_SIZE chunks.
TOKEN_CHUNK_SIZE = 1 # TODO: test/tune this number
# if TOKEN_CHUNK_SIZE_ENV is not None:
# Capacity intended for the global BoundedSemaphore queue (that code is
# currently commented out, so this value is unused in the visible file).
QUEUE_SIZE = 20 # TODO: should this be maximized so the global queue has enough places?
# if QUEUE_SIZE_ENV is not None:
# @st.cache_resource
# def get_global_semaphore():
# return threading.BoundedSemaphore(QUEUE_SIZE)
# global_semaphore = get_global_semaphore()
# st.image("sunrise.jpg", caption="Sunrise by the mountains") # TODO add a Vanderbilt related picture to the head of our Space!
# use this to format later
# Inject the project's stylesheet into the page. The file contents must be
# interpolated into the <style> tag (an empty `{}` in an f-string is a
# SyntaxError), and unsafe_allow_html is required for Streamlit to emit raw HTML.
with open("./style.css", encoding="utf-8") as css:
    st.markdown(f"<style>{css.read()}</style>", unsafe_allow_html=True)
# Seed per-session state on the first run only; later reruns keep whatever
# the session has accumulated.
for _state_key, _initial_value in (("messages", []), ("feedback", [None])):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _initial_value
def clear_chat_history():
    """Reset the session's chat transcript to an empty list."""
    st.session_state["messages"] = []


# Button that wipes the conversation; Streamlit runs the callback on click.
st.button("Clear Chat", on_click=clear_chat_history)
# Build the chain once and reuse it: Streamlit re-executes this script on
# every interaction, so a bare module-level call would rebuild the chain on
# each rerun. st.cache_resource keeps a single instance alive across reruns.
# The chain is simply passed the chat history for each completion.
@st.cache_resource
def _get_chain():
    """Return the shared chain instance, constructing it on first use."""
    return ChainBuilder().build_chain()


chain = _get_chain()
def last_role_is_user():
    """Return True when the transcript is non-empty and its last message is from the user."""
    if len(st.session_state["messages"]) == 0:
        return False
    newest = st.session_state["messages"][-1]
    return newest["role"] == "user"
def text_stream(stream):
    """Lazily yield the ``"content"`` of each chunk in ``stream``, skipping chunks whose content is None."""
    yield from (
        piece["content"] for piece in stream if piece["content"] is not None
    )
def get_stream_warning_error(stream):
    """Drain ``stream`` and return ``(warning, error)``.

    Each is the last non-None ``"warning"`` / ``"error"`` value seen across
    the chunks, or None when no chunk carried one.
    """
    last_warning = last_error = None
    for item in stream:
        item_error = item["error"]
        if item_error is not None:
            last_error = item_error
        item_warning = item["warning"]
        if item_warning is not None:
            last_warning = item_warning
    return last_warning, last_error
# @retry(wait=wait_random_exponential(min=0.5, max=2), stop=stop_after_attempt(3))
# Build the request payload from the chat history and return the chain's
# completion for it. Only each entry's "role" and "content" keys are copied
# into the payload; any other keys on the history entries are dropped.
def chain_call(history):
# NOTE(review): `input` shadows the Python builtin of the same name.
input = {'messages': [{"role": m["role"], "content": m["content"]} for m in history]}
# NOTE(review): the right-hand side of this assignment is missing in this
# view; presumably it invokes the module-level `chain` with `input` — confirm.
chat_completion =
return chat_completion
# Stream the assistant's reply for the current session transcript into the UI
# and return a (response, stream_warning, stream_error) triple.
def write_response():
stream = chat_completion(st.session_state["messages"])
# Duplicate the generator so the text and the warning/error metadata can be
# consumed independently of each other.
content_stream, error_stream = tee(stream)
response = st.write_stream(text_stream(content_stream))
stream_warning, stream_error = get_stream_warning_error(error_stream)
# NOTE(review): the bodies of the next two `if` statements are missing in this
# view; presumably they surface the warning / error in the UI — confirm.
if stream_warning is not None:
if stream_error is not None:
# if there was an error, a list will be returned instead of a string:
if isinstance(response, list):
response = None
return response, stream_warning, stream_error
def chat_completion(messages):
    """Stream the assistant's reply to ``messages`` as a sequence of chunk dicts.

    Args:
        messages: The chat history, forwarded unchanged to ``chain_call``.

    Yields:
        Dicts with exactly the keys ``"content"``, ``"error"`` and
        ``"warning"``. Text chunks carry ``"content"``; a failure is reported
        as a single chunk whose ``"error"`` is a user-facing message, after
        which the stream ends. The final chunk flushes any buffered text and
        carries the warning, if any.
    """
    import logging

    # Refuse to query the model once the conversation exceeds the turn limit.
    if (len(messages) - 1) // 2 >= MAX_CHAT_TURNS:
        yield {"content": None, "error": MSG_MAX_TURNS_EXCEEDED, "warning": None}
        return  # without this the model would still be called after the error

    # TODO: implement a global queue with a bounded semaphore around this call.
    try:
        token_stream = chain_call(messages)
    except Exception:
        # Top-level boundary: log the details, show the user a generic message.
        logging.getLogger(__name__).exception("chain_call failed")
        yield {"content": None, "error": GENERAL_ERROR_MSG, "warning": None}
        return

    max_token_warning = None
    pending_text = ""
    pending_count = 0
    for token in token_stream:
        if token is None:
            continue
        pending_count += 1
        pending_text += token
        # Emit once every TOKEN_CHUNK_SIZE tokens to avoid streaming too fast.
        if pending_count % TOKEN_CHUNK_SIZE == 0:
            pending_count = 0
            yield {"content": pending_text, "error": None, "warning": None}
            pending_text = ""
        # if chunk.choices[0].finish_reason == "length":
        #     max_token_warning = MSG_CLIPPED_AT_MAX_OUT_TOKENS

    # Flush the remainder (possibly empty); this chunk also carries the warning.
    yield {"content": pending_text, "error": None, "warning": max_token_warning}
# if assistant is the last message, we need to prompt the user
# if user is the last message, we need to retry the assistant.
# Handle a submitted user message: render inside the module-level `history`
# container, obtain the assistant's reply via write_response(), and append the
# exchange to st.session_state["messages"].
# NOTE(review): this block has lost its indentation and appears to be missing
# lines in this view (e.g. an `else:` between the retry branch and the
# new-message branch, and the body that renders the user's text) — confirm
# against the real file before changing any logic.
# NOTE(review): MODEL_AVATAR_URL is not defined anywhere in the visible source.
def handle_user_input(user_input):
with history:
response, stream_warning, stream_error = [None, None, None]
if last_role_is_user():
# retry the assistant if the user tries to send a new message
with st.chat_message("assistant", avatar=MODEL_AVATAR_URL):
response, stream_warning, stream_error = write_response()
# Record the user's message with the same keys used for assistant entries.
st.session_state["messages"].append({"role": "user", "content": user_input, "warning": None, "error": None})
# NOTE(review): the avatar string below looks mis-encoded (probably an emoji
# read with the wrong codec) — confirm the file's encoding.
with st.chat_message("user", avatar="πŸ§‘β€πŸ’»"):
# NOTE(review): `stream` is assigned here but never used in the visible code;
# write_response() creates its own stream.
stream = chat_completion(st.session_state["messages"])
with st.chat_message("assistant", avatar=MODEL_AVATAR_URL):
response, stream_warning, stream_error = write_response()
# Persist the assistant's reply together with any warning/error from the stream.
st.session_state["messages"].append({"role": "assistant", "content": response, "warning": stream_warning, "error": stream_error})
# Render the feedback form (title, instructions, free-text input and a submit
# button) inside an st.form named "feedback_form".
# NOTE(review): nothing is returned or stored in the visible code; `submitted`
# is assigned but not used here.
def feedback():
with st.form("feedback_form"):
st.title("Feedback Form")
# NOTE(review): the thumbs glyphs in this string look mis-encoded — confirm
# the file's encoding.
st.markdown("Please select either πŸ‘ or πŸ‘Ž before providing a reason for your review of the most recent response. Dont forget to click submit!")
# NOTE(review): the right-hand side of this assignment is missing in this
# view; presumably a thumbs up/down rating widget — confirm.
rating =
# NOTE(review): this local shadows the enclosing function's own name.
feedback = st.text_input("Please detail your feedback: ")
# TODO: implement a method for writing these responses to storage!
submitted = st.form_submit_button("Submit Feedback")
# Page layout: a main container holding the scrollable chat history, the chat
# input and the feedback button, plus a sidebar of example prompts.
# NOTE(review): indentation is lost and several `if` bodies below are missing
# in this view; the notes mark each gap rather than guessing the contents.
main = st.container()
with main:
# The last entry of st.session_state["feedback"] is non-None once feedback exists.
if st.session_state["feedback"][-1] is not None: # TODO clean this up in a fn?
st.markdown("Thank you! Feedback received! Type a new message to continue your conversation.")
# Fixed-height container that handle_user_input also writes into.
history = st.container(height=400)
with history:
# Replay the stored transcript on every rerun.
for message in st.session_state["messages"]:
# NOTE(review): this avatar string looks mis-encoded — confirm file encoding.
avatar = "πŸ§‘β€πŸ’»"
# NOTE(review): body missing in this view; presumably switches `avatar` for
# assistant messages — confirm.
if message["role"] == "assistant":
with st.chat_message(message["role"], avatar=avatar):
# NOTE(review): the bodies of the next three `if` statements are missing in
# this view; presumably they render the content / error / warning — confirm.
if message["content"] is not None:
if message["error"] is not None:
if message["warning"] is not None:
# NOTE(review): the body handling the submitted `prompt` is missing in this view.
if prompt := st.chat_input("Type a message!", max_chars=5000):
st.markdown("\n") #add some space for iphone users
# The feedback() function is registered as this button's click callback.
gave_feedback = st.button('Give Feedback on Last Response', on_click=feedback)
# NOTE(review): body missing in this view.
if gave_feedback: # TODO clean up the conditions here with a function
with st.sidebar:
with st.container():
# One button per example; clicking calls handle_user_input(prompt).
# NOTE(review): this loop rebinds `prompt`, the same name bound by the chat
# input above.
for prompt in EXAMPLE_PROMPTS:
st.button(prompt, args=(prompt,), on_click=handle_user_input)