Spaces:
Sleeping
Sleeping
File size: 4,760 Bytes
cd8951a 2ca0b2e cd8951a 2ca0b2e cd8951a 2ca0b2e cd8951a 2ca0b2e cd8951a 2ca0b2e cd8951a 2ca0b2e cd8951a 2ca0b2e cd8951a 2ca0b2e cd8951a 2ca0b2e cd8951a 2ca0b2e cd8951a 2ca0b2e cd8951a 2ca0b2e cd8951a 2ca0b2e cd8951a 2ca0b2e cd8951a 2ca0b2e cd8951a 2ca0b2e cd8951a 2ca0b2e cd8951a 2ca0b2e cd8951a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 |
import streamlit as st
import hopsworks
import joblib
from openai import OpenAI
from functions.llm_chain import (
load_model,
get_llm_chain,
generate_response,
generate_response_openai,
)
import warnings
warnings.filterwarnings('ignore')
st.title("🌤️ AirQuality AI assistant 💬")
@st.cache_resource()
def connect_to_hopsworks():
# Initialize Hopsworks feature store connection
project = hopsworks.login()
fs = project.get_feature_store()
# Retrieve the model registry
mr = project.get_model_registry()
# Retrieve the 'air_quality_fv' feature view
feature_view = fs.get_feature_view(
name="air_quality_fv",
version=1,
)
# Initialize batch scoring
feature_view.init_batch_scoring(1)
# Retrieve the 'air_quality_xgboost_model' from the model registry
retrieved_model = mr.get_model(
name="air_quality_xgboost_model",
version=1,
)
# Download the saved model artifacts to a local directory
saved_model_dir = retrieved_model.download()
# Load the XGBoost regressor model and label encoder from the saved model directory
model_air_quality = joblib.load(saved_model_dir + "/xgboost_regressor.pkl")
encoder = joblib.load(saved_model_dir + "/label_encoder.pkl")
return feature_view, model_air_quality, encoder
@st.cache_resource()
def retrieve_llm_chain():
# Load the LLM and its corresponding tokenizer.
model_llm, tokenizer = load_model()
# Create and configure a language model chain.
llm_chain = get_llm_chain(
model_llm,
tokenizer,
)
return model_llm, tokenizer, llm_chain
# Retrieve the feature view, air quality model and encoder for the city_name column
feature_view, model_air_quality, encoder = connect_to_hopsworks()
# Initialize or clear chat messages based on response source change
if "response_source" not in st.session_state or "messages" not in st.session_state:
st.session_state.messages = []
st.session_state.response_source = ""
# User choice for model selection in the sidebar with OpenAI API as the default
new_response_source = st.sidebar.radio(
"Choose the response generation method:",
('Hermes LLM', 'OpenAI API'),
index=1 # Sets "OpenAI API" as the default selection
)
# If the user switches the response generation method, clear the chat
if new_response_source != st.session_state.response_source:
st.session_state.messages = [] # Clear previous chat messages
st.session_state.response_source = new_response_source # Update response source in session state
# Display a message indicating chat was cleared (optional)
st.experimental_rerun() # Rerun the app to reflect changes immediately
if new_response_source == 'OpenAI API':
openai_api_key = st.sidebar.text_input("Enter your OpenAI API key:", type="password")
if openai_api_key:
client = OpenAI(
api_key=openai_api_key
)
st.sidebar.success("API key saved successfully ✅")
elif new_response_source == 'Hermes LLM':
# Conditionally load the LLM, tokenizer, and llm_chain if Local Model is selected
model_llm, tokenizer, llm_chain = retrieve_llm_chain()
# Display chat messages from history on app rerun
for message in st.session_state.messages:
with st.chat_message(message["role"]):
st.markdown(message["content"])
# React to user input
if user_query := st.chat_input("How can I help you?"):
# Display user message in chat message container
st.chat_message("user").markdown(user_query)
# Add user message to chat history
st.session_state.messages.append({"role": "user", "content": user_query})
st.write('⚙️ Generating Response...')
if new_response_source == 'Hermes LLM':
# Generate a response to the user query
response = generate_response(
user_query,
feature_view,
model_air_quality,
encoder,
model_llm,
tokenizer,
llm_chain,
verbose=False,
)
elif new_response_source == 'OpenAI API' and openai_api_key:
response = generate_response_openai(
user_query,
feature_view,
model_air_quality,
encoder,
client,
verbose=False,
)
else:
response = "Please select a response generation method and provide necessary details."
# Display assistant response in chat message container
with st.chat_message("assistant"):
st.markdown(response)
# Add assistant response to chat history
st.session_state.messages.append({"role": "assistant", "content": response}) |