John Graham Reynolds
try to change path to css file and add newer, non-experimental decorator for caching
4abddf8

import os
import threading
import streamlit as st
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_databricks.vectorstores import DatabricksVectorSearch
from itertools import tee

DATABRICKS_HOST = os.environ.get("DATABRICKS_HOST")
DATABRICKS_TOKEN = os.environ.get("DATABRICKS_TOKEN")
VS_ENDPOINT_NAME = os.environ.get("VS_ENDPOINT_NAME")
VS_INDEX_NAME = os.environ.get("VS_INDEX_NAME")

if DATABRICKS_HOST is None:
    raise ValueError("DATABRICKS_HOST environment variable must be set")
if DATABRICKS_TOKEN is None:
    raise ValueError("DATABRICKS_TOKEN environment variable must be set")
MODEL_AVATAR_URL = "./VU.jpeg"

# MSG_MAX_TURNS_EXCEEDED = f"Sorry! The Vanderbilt AI assistant playground is limited to {MAX_CHAT_TURNS} turns. Click the 'Clear Chat' button or refresh the page to start a new conversation."
# MSG_CLIPPED_AT_MAX_OUT_TOKENS = "Reached maximum output tokens for DBRX Playground"
EXAMPLE_PROMPTS = [
    "Tell me about maximum out-of-pocket costs in healthcare.",
    "Write a haiku about Nashville, Tennessee.",
    "How is a data lake used at Vanderbilt University Medical Center?",
    "In a table, what are some of the greatest hurdles to healthcare in the United States?",
    "What does EDW stand for in the context of Vanderbilt University Medical Center?",
    "Code a sql statement that can query a database named 'VUMC'.",
    "Write a short story about a country concert in Nashville, Tennessee.",
]
| TITLE = "VUMC Chatbot" | |
| DESCRIPTION="""Welcome to the first generation Vanderbilt AI assistant! \n This AI assistant is built atop the Databricks DBRX large language model | |
| and is augmented with additional organization-specific knowledge. Specifically, it has been preliminarily augmented with knowledge of Vanderbilt University Medical Center | |
| terms like **Data Lake**, **EDW** (Enterprise Data Warehouse), **HCERA** (Health Care and Education Reconciliation Act), and **thousands more!** The model has **no access to PHI**. | |
| Try querying the model with any of the examples prompts below for a simple introduction to both Vanderbilt-specific and general knowledge queries. The purpose of this | |
| model is to allow VUMC employees access to an intelligent assistant that improves and expedites VUMC work. \n | |
| Feedback and ideas are very welcome! Please provide any feedback, ideas, or issues to the email: **john.graham.reynolds@vumc.org**. | |
| We hope to gradually improve this AI assistant to create a large-scale, all-inclusive tool to compliment the work of all VUMC staff.""" | |
| GENERAL_ERROR_MSG = "An error occurred. Please refresh the page to start a new conversation." | |
# @st.cache_resource
# def get_global_semaphore():
#     return threading.BoundedSemaphore(QUEUE_SIZE)
# global_semaphore = get_global_semaphore()

st.set_page_config(layout="wide")

# To prevent streaming too fast, chunk the output into TOKEN_CHUNK_SIZE chunks
TOKEN_CHUNK_SIZE = 1
# if TOKEN_CHUNK_SIZE_ENV is not None:
#     TOKEN_CHUNK_SIZE = int(TOKEN_CHUNK_SIZE_ENV)

st.title(TITLE)
# st.image("sunrise.jpg", caption="Sunrise by the mountains")  # add a Vanderbilt-related picture to the head of our Space!
st.markdown(DESCRIPTION)
st.markdown("\n")

# use this to format later
with open("./style.css") as css:
    st.markdown(f"<style>{css.read()}</style>", unsafe_allow_html=True)
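# Related to the "change path to css file" note in this commit, one defensive variant
# (a sketch, not part of the original file): resolve style.css relative to the app script
# and skip styling when the file is missing instead of raising FileNotFoundError.
# from pathlib import Path
# css_path = Path(__file__).parent / "style.css"
# if css_path.exists():
#     st.markdown(f"<style>{css_path.read_text()}</style>", unsafe_allow_html=True)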
| if "messages" not in st.session_state: | |
| st.session_state["messages"] = [] | |
| def clear_chat_history(): | |
| st.session_state["messages"] = [] | |
| st.button('Clear Chat', on_click=clear_chat_history) | |
| def last_role_is_user(): | |
| return len(st.session_state["messages"]) > 0 and st.session_state["messages"][-1]["role"] == "user" | |
| def get_system_prompt(): | |
| return "" | |
# ** working logic for querying glossary embeddings
# Same embedding model we used to create embeddings of terms
# make sure we cache this so that it doesn't redownload each time, hindering Space start time if sleeping
# try adding this st caching decorator to ensure the embeddings class gets cached after downloading the entirety of the model
# does this cache to the given folder though? It does appear to populate the folder as expected after being run
# will this work here?
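# Per the commit message ("add newer, non-experimental decorator for caching"), a caching
# decorator belongs here; st.cache_resource is assumed to be the one intended, since it is the
# stable replacement for the experimental decorators and is meant for heavyweight objects like models.
@st.cache_resource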
def load_embedding_model():
    embeddings = HuggingFaceEmbeddings(model_name="BAAI/bge-large-en", cache_folder="./langchain_cache/")
    return embeddings

embeddings = load_embedding_model()
# instantiate the vector store for similarity search in our chain
# need to make this a function and decorate it with @st.experimental_memo as above?
# We are only calling this initially when the Space starts. Can we expedite this process for users when opening up this Space?
# @st.cache_data  # TODO add this in
vector_store = DatabricksVectorSearch(
    endpoint=VS_ENDPOINT_NAME,
    index_name=VS_INDEX_NAME,
    embedding=embeddings,
    text_column="name",
    columns=["name", "description"],
)
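# One possible shape for the TODO above (a sketch, not part of the original file): wrap the
# store construction in a cached factory so it is built once per Space session rather than on
# every rerun. st.cache_resource is likely a better fit than st.cache_data here, since the store
# wraps a live index client rather than plain serializable data.
# @st.cache_resource
# def get_vector_store():
#     return DatabricksVectorSearch(
#         endpoint=VS_ENDPOINT_NAME,
#         index_name=VS_INDEX_NAME,
#         embedding=embeddings,
#         text_column="name",
#         columns=["name", "description"],
#     )
# vector_store = get_vector_store()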
def text_stream(stream):
    for chunk in stream:
        if chunk["content"] is not None:
            yield chunk["content"]

def get_stream_warning_error(stream):
    error = None
    warning = None
    # for chunk in stream:
    #     if chunk["error"] is not None:
    #         error = chunk["error"]
    #     if chunk["warning"] is not None:
    #         warning = chunk["warning"]
    return warning, error
# @retry(wait=wait_random_exponential(min=0.5, max=2), stop=stop_after_attempt(3))
def chat_api_call(history):
    # *** original code for instantiating the DBRX model through the OpenAI client *** skip this and introduce our chain eventually
    # extra_body = {}
    # if SAFETY_FILTER:
    #     extra_body["enable_safety_filter"] = SAFETY_FILTER
    # chat_completion = client.chat.completions.create(
    #     messages=[
    #         {"role": m["role"], "content": m["content"]}
    #         for m in history
    #     ],
    #     model="databricks-dbrx-instruct",
    #     stream=True,
    #     max_tokens=MAX_TOKENS,
    #     temperature=0.7,
    #     extra_body=extra_body
    # )
    # ** TODO update this next to take and do similarity search on user input!
    search_result = vector_store.similarity_search(query="Tell me about what a data lake is.", k=5)
    chat_completion = search_result  # TODO update this after we implement our chain
    return chat_completion
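# A sketch of the TODO above (an assumption, not part of the original file): run the similarity
# search against the most recent user turn instead of a hard-coded query.
# def chat_api_call(history):
#     user_query = history[-1]["content"]  # latest user message in the DBRX-format history
#     return vector_store.similarity_search(query=user_query, k=5)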
def write_response():
    stream = chat_completion(st.session_state["messages"])
    content_stream, error_stream = tee(stream)
    response = st.write_stream(text_stream(content_stream))
    stream_warning, stream_error = get_stream_warning_error(error_stream)
    if stream_warning is not None:
        st.warning(stream_warning, icon="⚠️")
    if stream_error is not None:
        st.error(stream_error, icon="🚨")
    # if there was an error, a list will be returned instead of a string: https://docs.streamlit.io/library/api-reference/write-magic/st.write_stream
    if isinstance(response, list):
        response = None
    return response, stream_warning, stream_error
def chat_completion(messages):
    history_dbrx_format = [
        {"role": "system", "content": get_system_prompt()}
    ]
    history_dbrx_format = history_dbrx_format + messages
    # if (len(history_dbrx_format) - 1) // 2 >= MAX_CHAT_TURNS:
    #     yield {"content": None, "error": MSG_MAX_TURNS_EXCEEDED, "warning": None}
    #     return
    chat_completion = None
    error = None
    # *** original code for querying DBRX through the OpenAI client for chat completion
    # wait to be in queue
    # with global_semaphore:
    #     try:
    #         chat_completion = chat_api_call(history_dbrx_format)
    #     except Exception as e:
    #         error = e
    chat_completion = chat_api_call(history_dbrx_format)
    if error is not None:
        yield {"content": None, "error": GENERAL_ERROR_MSG, "warning": None}
        print(error)
        return
    max_token_warning = None
    partial_message = ""
    chunk_counter = 0
    for chunk in chat_completion:
        # if chunk.choices[0].delta.content is not None:
        if chunk.page_content is not None:
            chunk_counter += 1
            # partial_message += chunk.choices[0].delta.content
            partial_message += f"* {chunk.page_content} [{chunk.metadata}]"
            if chunk_counter % TOKEN_CHUNK_SIZE == 0:
                chunk_counter = 0
                yield {"content": partial_message, "error": None, "warning": None}
                partial_message = ""
        # if chunk.choices[0].finish_reason == "length":
        #     max_token_warning = MSG_CLIPPED_AT_MAX_OUT_TOKENS
    yield {"content": partial_message, "error": None, "warning": max_token_warning}
# if assistant is the last message, we need to prompt the user
# if user is the last message, we need to retry the assistant.
def handle_user_input(user_input):
    with history:
        response, stream_warning, stream_error = [None, None, None]
        if last_role_is_user():
            # retry the assistant if the user tries to send a new message
            with st.chat_message("assistant", avatar=MODEL_AVATAR_URL):
                response, stream_warning, stream_error = write_response()
        else:
            st.session_state["messages"].append({"role": "user", "content": user_input, "warning": None, "error": None})
            with st.chat_message("user"):
                st.markdown(user_input)
            stream = chat_completion(st.session_state["messages"])
            with st.chat_message("assistant", avatar=MODEL_AVATAR_URL):
                response, stream_warning, stream_error = write_response()
            st.session_state["messages"].append({"role": "assistant", "content": response, "warning": stream_warning, "error": stream_error})
main = st.container()
with main:
    history = st.container(height=400)
    with history:
        for message in st.session_state["messages"]:
            avatar = "🧑‍💻"
            if message["role"] == "assistant":
                avatar = MODEL_AVATAR_URL
            with st.chat_message(message["role"], avatar=avatar):
                if message["content"] is not None:
                    st.markdown(message["content"])
                # if message["error"] is not None:
                #     st.error(message["error"], icon="🚨")
                # if message["warning"] is not None:
                #     st.warning(message["warning"], icon="⚠️")
    if prompt := st.chat_input("Type a message!", max_chars=1000):
        handle_user_input(prompt)
    st.markdown("\n")  # add some space for iPhone users

with st.sidebar:
    with st.container():
        st.title("Examples")
        for prompt in EXAMPLE_PROMPTS:
            st.button(prompt, args=(prompt,), on_click=handle_user_input)
