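"""Streamlit chat UI for an air quality AI assistant.

The app loads an XGBoost air quality model and a feature view from Hopsworks,
then answers user questions either with a locally hosted Hermes LLM or through
the OpenAI API, depending on the method chosen in the sidebar.
"""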

import streamlit as st
import hopsworks
import joblib
from openai import OpenAI

from functions.llm_chain import (
    load_model,
    get_llm_chain,
    generate_response,
    generate_response_openai,
)

import warnings
warnings.filterwarnings('ignore')
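
# functions/llm_chain.py is assumed to be a project-local helper module shipped
# alongside this app; it provides the model loading, chain construction, and
# response generation helpers imported above.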

st.title("🌤️ AirQuality AI assistant 💬")


def connect_to_hopsworks():
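    """Log in to Hopsworks and load the feature view, XGBoost model, and label encoder."""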
    # Initialize the Hopsworks feature store connection
    project = hopsworks.login()
    fs = project.get_feature_store()

    # Retrieve the model registry
    mr = project.get_model_registry()

    # Retrieve the 'air_quality_fv' feature view
    feature_view = fs.get_feature_view(
        name="air_quality_fv",
        version=1,
    )

    # Initialize batch scoring against version 1 of the training dataset
    feature_view.init_batch_scoring(1)

    # Retrieve the 'air_quality_xgboost_model' from the model registry
    retrieved_model = mr.get_model(
        name="air_quality_xgboost_model",
        version=1,
    )

    # Download the saved model artifacts to a local directory
    saved_model_dir = retrieved_model.download()

    # Load the XGBoost regressor and the label encoder from the artifact directory
    model_air_quality = joblib.load(saved_model_dir + "/xgboost_regressor.pkl")
    encoder = joblib.load(saved_model_dir + "/label_encoder.pkl")

    return feature_view, model_air_quality, encoder


def retrieve_llm_chain():
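    """Load the Hermes LLM with its tokenizer and build the language model chain."""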
    # Load the LLM and its corresponding tokenizer
    model_llm, tokenizer = load_model()

    # Create and configure the language model chain
    llm_chain = get_llm_chain(
        model_llm,
        tokenizer,
    )

    return model_llm, tokenizer, llm_chain
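

# Streamlit re-executes this script on every interaction, so the loaders above
# re-download artifacts and re-load the LLM on each rerun. A minimal sketch of
# memoizing them, assuming Streamlit >= 1.18 where st.cache_resource exists:
#
#     @st.cache_resource()
#     def connect_to_hopsworks():
#         ...
#
#     @st.cache_resource()
#     def retrieve_llm_chain():
#         ...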

# Retrieve the feature view, the air quality model, and the encoder for the city_name column
feature_view, model_air_quality, encoder = connect_to_hopsworks()

# Initialize the chat state on the first run of the app
if "response_source" not in st.session_state or "messages" not in st.session_state:
    st.session_state.messages = []
    st.session_state.response_source = ""
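
# st.session_state persists across Streamlit reruns, so the chat history and
# the chosen response source survive each user interaction.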

# User choice for the generation method in the sidebar, with the OpenAI API as the default
new_response_source = st.sidebar.radio(
    "Choose the response generation method:",
    ('Hermes LLM', 'OpenAI API'),
    index=1,  # Sets "OpenAI API" as the default selection
)

# If the user switches the response generation method, clear the chat
if new_response_source != st.session_state.response_source:
    st.session_state.messages = []  # Clear previous chat messages
    st.session_state.response_source = new_response_source  # Update response source in session state
    st.experimental_rerun()  # Rerun the app so the cleared chat is reflected immediately
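
# Note: Streamlit 1.27+ renames st.experimental_rerun to st.rerun; on a recent
# release, swap in st.rerun() above.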

if new_response_source == 'OpenAI API':
    openai_api_key = st.sidebar.text_input("Enter your OpenAI API key:", type="password")
    if openai_api_key:
        client = OpenAI(
            api_key=openai_api_key,
        )
        st.sidebar.success("API key saved successfully ✅")

elif new_response_source == 'Hermes LLM':
    # Load the LLM, tokenizer, and llm_chain only when the local Hermes model is selected
    model_llm, tokenizer, llm_chain = retrieve_llm_chain()

# Display chat messages from history on app rerun
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

# React to user input
if user_query := st.chat_input("How can I help you?"):
    # Display the user message in a chat message container
    st.chat_message("user").markdown(user_query)
    # Add the user message to the chat history
    st.session_state.messages.append({"role": "user", "content": user_query})

    st.write('⚙️ Generating Response...')

    if new_response_source == 'Hermes LLM':
        # Generate a response to the user query with the local LLM chain
        response = generate_response(
            user_query,
            feature_view,
            model_air_quality,
            encoder,
            model_llm,
            tokenizer,
            llm_chain,
            verbose=False,
        )
    elif new_response_source == 'OpenAI API' and openai_api_key:
        # Generate a response to the user query with the OpenAI API
        response = generate_response_openai(
            user_query,
            feature_view,
            model_air_quality,
            encoder,
            client,
            verbose=False,
        )
    else:
        response = "Please enter your OpenAI API key in the sidebar to generate a response."

    # Display the assistant response in a chat message container
    with st.chat_message("assistant"):
        st.markdown(response)
    # Add the assistant response to the chat history
    st.session_state.messages.append({"role": "assistant", "content": response})
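
# To run the app locally (assuming this file is saved as app.py; adjust the
# file name to match your project):
#     streamlit run app.py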