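"""Streamlit chat UI for the AirQuality AI assistant.

Connects to the Hopsworks feature store, loads the saved XGBoost air-quality
model and label encoder, and answers user questions through either a local
Hermes LLM chain or the OpenAI API.

Launch with (assuming this file is saved as app.py):

    streamlit run app.py
"""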
import streamlit as st
import hopsworks
import joblib
from openai import OpenAI
from functions.llm_chain import (
    load_model, 
    get_llm_chain, 
    generate_response, 
    generate_response_openai,
)
import warnings
warnings.filterwarnings('ignore')

st.title("🌤️ AirQuality AI assistant 💬")
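
# Streamlit re-runs this script on every user interaction, so the expensive
# setup below (Hopsworks connection, model download, local LLM load) is cached
# with st.cache_resource and runs once rather than on every rerun.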

@st.cache_resource()
def connect_to_hopsworks():
    # Initialize Hopsworks feature store connection
    project = hopsworks.login()
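    # (hopsworks.login() typically reads credentials from the HOPSWORKS_API_KEY
    # environment variable, or prompts for an API key interactively.)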
    fs = project.get_feature_store()
    
    # Retrieve the model registry
    mr = project.get_model_registry()

    # Retrieve the 'air_quality_fv' feature view
    feature_view = fs.get_feature_view(
        name="air_quality_fv", 
        version=1,
    )

    # Initialize batch scoring with version 1 of the training dataset
    feature_view.init_batch_scoring(1)
    
    # Retrieve the 'air_quality_xgboost_model' from the model registry
    retrieved_model = mr.get_model(
        name="air_quality_xgboost_model",
        version=1,
    )

    # Download the saved model artifacts to a local directory
    saved_model_dir = retrieved_model.download()

    # Load the XGBoost regressor model and label encoder from the saved model directory
    model_air_quality = joblib.load(saved_model_dir + "/xgboost_regressor.pkl")
    encoder = joblib.load(saved_model_dir + "/label_encoder.pkl")

    return feature_view, model_air_quality, encoder


@st.cache_resource()
def retrieve_llm_chain():
    # Load the LLM and its corresponding tokenizer.
    model_llm, tokenizer = load_model()
    
    # Create and configure a language model chain.
    llm_chain = get_llm_chain(
        model_llm, 
        tokenizer,
    )
    
    return model_llm, tokenizer, llm_chain


# Retrieve the feature view, the air quality model, and the label encoder for the city_name column
feature_view, model_air_quality, encoder = connect_to_hopsworks()

# Initialize the chat history and response source in session state on first run
if "response_source" not in st.session_state or "messages" not in st.session_state:
    st.session_state.messages = []
    st.session_state.response_source = ""

# User choice for model selection in the sidebar with OpenAI API as the default
new_response_source = st.sidebar.radio(
    "Choose the response generation method:",
    ('Hermes LLM', 'OpenAI API'),
    index=1  # Sets "OpenAI API" as the default selection
)

# If the user switches the response generation method, clear the chat
if new_response_source != st.session_state.response_source:
    st.session_state.messages = []  # Clear previous chat messages
    st.session_state.response_source = new_response_source  # Update response source in session state

    # Rerun the app so the cleared chat is reflected immediately
    # (st.rerun() supersedes the deprecated st.experimental_rerun())
    st.rerun()
    
    
if new_response_source == 'OpenAI API':
    openai_api_key = st.sidebar.text_input("Enter your OpenAI API key:", type="password")
    if openai_api_key:
        client = OpenAI(
            api_key=openai_api_key
        )
        st.sidebar.success("API key set for this session ✅")
        
elif new_response_source == 'Hermes LLM':
    # Load the LLM, tokenizer, and LLM chain only when Hermes LLM is selected
    model_llm, tokenizer, llm_chain = retrieve_llm_chain()

    
# Display chat messages from history on app rerun
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

# React to user input
if user_query := st.chat_input("How can I help you?"):
    # Display user message in chat message container
    st.chat_message("user").markdown(user_query)
    # Add user message to chat history
    st.session_state.messages.append({"role": "user", "content": user_query})

    st.write('⚙️ Generating Response...')

    if new_response_source == 'Hermes LLM':
        # Generate a response to the user query
        response = generate_response(
            user_query,
            feature_view,
            model_air_quality,
            encoder,
            model_llm,
            tokenizer,
            llm_chain,
            verbose=False,
        )
        
    elif new_response_source == 'OpenAI API' and openai_api_key:
        response = generate_response_openai(
            user_query,
            feature_view,
            model_air_quality,
            encoder,
            client,
            verbose=False,
        )
        
    else:
        response = "Please provide your OpenAI API key in the sidebar to generate a response."

    # Display assistant response in chat message container
    with st.chat_message("assistant"):
        st.markdown(response)
    # Add assistant response to chat history
    st.session_state.messages.append({"role": "assistant", "content": response})