Air_quality / app.py
Annikaijak's picture
Update app.py
cd8951a verified
raw
history blame contribute delete
No virus
4.76 kB
import streamlit as st
import hopsworks
import joblib
from openai import OpenAI
from functions.llm_chain import (
load_model,
get_llm_chain,
generate_response,
generate_response_openai,
)
import warnings
warnings.filterwarnings('ignore')
st.title("🌀️ AirQuality AI assistant πŸ’¬")
@st.cache_resource()
def connect_to_hopsworks():
# Initialize Hopsworks feature store connection
project = hopsworks.login()
fs = project.get_feature_store()
# Retrieve the model registry
mr = project.get_model_registry()
# Retrieve the 'air_quality_fv' feature view
feature_view = fs.get_feature_view(
name="air_quality_fv",
version=1,
)
# Initialize batch scoring
feature_view.init_batch_scoring(1)
# Retrieve the 'air_quality_xgboost_model' from the model registry
retrieved_model = mr.get_model(
name="air_quality_xgboost_model",
version=1,
)
# Download the saved model artifacts to a local directory
saved_model_dir = retrieved_model.download()
# Load the XGBoost regressor model and label encoder from the saved model directory
model_air_quality = joblib.load(saved_model_dir + "/xgboost_regressor.pkl")
encoder = joblib.load(saved_model_dir + "/label_encoder.pkl")
return feature_view, model_air_quality, encoder
@st.cache_resource()
def retrieve_llm_chain():
# Load the LLM and its corresponding tokenizer.
model_llm, tokenizer = load_model()
# Create and configure a language model chain.
llm_chain = get_llm_chain(
model_llm,
tokenizer,
)
return model_llm, tokenizer, llm_chain
# Retrieve the feature view, air quality model and encoder for the city_name column
feature_view, model_air_quality, encoder = connect_to_hopsworks()
# Initialize or clear chat messages based on response source change
if "response_source" not in st.session_state or "messages" not in st.session_state:
st.session_state.messages = []
st.session_state.response_source = ""
# User choice for model selection in the sidebar with OpenAI API as the default
new_response_source = st.sidebar.radio(
"Choose the response generation method:",
('Hermes LLM', 'OpenAI API'),
index=1 # Sets "OpenAI API" as the default selection
)
# If the user switches the response generation method, clear the chat
if new_response_source != st.session_state.response_source:
st.session_state.messages = [] # Clear previous chat messages
st.session_state.response_source = new_response_source # Update response source in session state
# Display a message indicating chat was cleared (optional)
st.experimental_rerun() # Rerun the app to reflect changes immediately
if new_response_source == 'OpenAI API':
openai_api_key = st.sidebar.text_input("Enter your OpenAI API key:", type="password")
if openai_api_key:
client = OpenAI(
api_key=openai_api_key
)
st.sidebar.success("API key saved successfully βœ…")
elif new_response_source == 'Hermes LLM':
# Conditionally load the LLM, tokenizer, and llm_chain if Local Model is selected
model_llm, tokenizer, llm_chain = retrieve_llm_chain()
# Display chat messages from history on app rerun
for message in st.session_state.messages:
with st.chat_message(message["role"]):
st.markdown(message["content"])
# React to user input
if user_query := st.chat_input("How can I help you?"):
# Display user message in chat message container
st.chat_message("user").markdown(user_query)
# Add user message to chat history
st.session_state.messages.append({"role": "user", "content": user_query})
st.write('βš™οΈ Generating Response...')
if new_response_source == 'Hermes LLM':
# Generate a response to the user query
response = generate_response(
user_query,
feature_view,
model_air_quality,
encoder,
model_llm,
tokenizer,
llm_chain,
verbose=False,
)
elif new_response_source == 'OpenAI API' and openai_api_key:
response = generate_response_openai(
user_query,
feature_view,
model_air_quality,
encoder,
client,
verbose=False,
)
else:
response = "Please select a response generation method and provide necessary details."
# Display assistant response in chat message container
with st.chat_message("assistant"):
st.markdown(response)
# Add assistant response to chat history
st.session_state.messages.append({"role": "assistant", "content": response})