Annikaijak commited on
Commit
cd8951a
1 Parent(s): 3cf3b3b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +123 -174
app.py CHANGED
@@ -1,197 +1,146 @@
1
- import json
2
- import time
3
- import pickle
4
  import joblib
 
 
 
 
 
 
 
 
 
5
 
6
- import hopsworks
7
- import streamlit as st
8
- from geopy import distance
9
 
10
- import plotly.express as px
11
- import folium
12
- from streamlit_folium import st_folium
 
 
 
 
 
13
 
14
- from functions import *
 
 
 
 
15
 
 
 
 
 
 
 
 
 
16
 
 
 
17
 
18
- def print_fancy_header(text, font_size=22, color="#ff5f27"):
19
- res = f'<span style="color:{color}; font-size: {font_size}px;">{text}</span>'
20
- st.markdown(res, unsafe_allow_html=True)
21
 
22
- @st.cache_data()
23
- def get_batch_data_from_fs(td_version, date_threshold):
24
- st.write(f"Retrieving the Batch data since {date_threshold}")
25
- feature_view.init_batch_scoring(training_dataset_version=td_version)
26
 
27
- batch_data = feature_view.get_batch_data(start_time=date_threshold)
28
- return batch_data
29
 
 
 
30
 
31
- @st.cache_data()
32
- def download_model(name="air_quality_xgboost_model", version=1):
33
- mr = project.get_model_registry()
34
- retrieved_model = mr.get_model(
35
- name="air_quality_xgboost_model",
36
- version=1
 
37
  )
38
- saved_model_dir = retrieved_model.download()
39
- return saved_model_dir
40
-
41
-
42
-
43
- def plot_pm2_5(df):
44
- # create figure with plotly express
45
- fig = px.line(df, x='date', y='pm2_5', color='city_name')
46
-
47
- # customize line colors and styles
48
- fig.update_traces(mode='lines+markers')
49
- fig.update_layout({
50
- 'plot_bgcolor': 'rgba(0, 0, 0, 0)',
51
- 'paper_bgcolor': 'rgba(0, 0, 0, 0)',
52
- 'legend_title': 'City',
53
- 'legend_font': {'size': 12},
54
- 'legend_bgcolor': 'rgba(0, 0, 0, 0)',
55
- 'xaxis': {'title': 'Date'},
56
- 'yaxis': {'title': 'PM2.5'},
57
- 'shapes': [{
58
- 'type': 'line',
59
- 'x0': datetime.datetime.now().strftime('%Y-%m-%d'),
60
- 'y0': 0,
61
- 'x1': datetime.datetime.now().strftime('%Y-%m-%d'),
62
- 'y1': df['pm2_5'].max(),
63
- 'line': {'color': 'red', 'width': 2, 'dash': 'dashdot'}
64
- }]
65
- })
66
-
67
- # show plot
68
- st.plotly_chart(fig, use_container_width=True)
69
-
70
-
71
- with open('target_cities.json') as json_file:
72
- target_cities = json.load(json_file)
73
-
74
-
75
- #########################
76
- st.title('🌫 Air Quality Prediction 🌦')
77
-
78
- st.write(3 * "-")
79
- print_fancy_header('\n📡 Connecting to Hopsworks Feature Store...')
80
-
81
- st.write("Logging... ")
82
- # (Attention! If the app has stopped at this step,
83
- # please enter your Hopsworks API Key in the commmand prompt.)
84
- project = hopsworks.login()
85
- fs = project.get_feature_store()
86
- st.write("✅ Logged in successfully!")
87
-
88
- st.write("Getting the Feature View...")
89
- feature_view = fs.get_feature_view(
90
- name = 'air_quality_fv',
91
- version = 1
92
- )
93
- st.write("✅ Success!")
94
-
95
- # I am going to load data for of last 60 days (for feature engineering)
96
- today = datetime.date.today()
97
- date_threshold = today - datetime.timedelta(days=60)
98
 
99
- st.write(3 * "-")
100
- print_fancy_header('\n☁️ Retriving batch data from Feature Store...')
101
- batch_data = get_batch_data_from_fs(td_version=1,
102
- date_threshold=date_threshold)
103
 
104
- st.write("Batch data:")
105
- st.write(batch_data.sample(5))
106
 
 
 
 
 
107
 
108
- saved_model_dir = download_model(
109
- name="air_quality_xgboost_model",
110
- version=1
 
 
111
  )
112
 
113
- pipeline = joblib.load(saved_model_dir + "/xgboost_pipeline.pkl")
114
- st.write("\n")
115
- st.write("✅ Model was downloaded and cached.")
116
-
117
- st.write(3 * '-')
118
- st.write("\n")
119
- print_fancy_header(text="🖍 Select the cities using the form below. \
120
- Click the 'Submit' button at the bottom of the form to continue.",
121
- font_size=22)
122
- dict_for_streamlit = {}
123
- for continent in target_cities:
124
- for city_name, coords in target_cities[continent].items():
125
- dict_for_streamlit[city_name] = coords
126
- selected_cities_full_list = []
127
-
128
- with st.form(key="user_inputs"):
129
- print_fancy_header(text='\n🗺 Here you can choose cities from the drop-down menu',
130
- font_size=20, color="#00FFFF")
131
-
132
- cities_multiselect = st.multiselect(label='',
133
- options=dict_for_streamlit.keys())
134
- selected_cities_full_list.extend(cities_multiselect)
135
- st.write("_" * 3)
136
- print_fancy_header(text="\n📌 To add a city using the interactive map, click somewhere \
137
- (for the coordinates to appear)",
138
- font_size=20, color="#00FFFF")
139
 
140
- my_map = folium.Map(location=[42.57, -44.092], zoom_start=2)
141
- # Add markers for each city
142
- for city_name, coords in dict_for_streamlit.items():
143
- folium.CircleMarker(
144
- location=coords
145
- ).add_to(my_map)
146
-
147
- my_map.add_child(folium.LatLngPopup())
148
- res_map = st_folium(my_map, width=640, height=480)
149
 
150
- try:
151
- new_lat, new_long = res_map["last_clicked"]["lat"], res_map["last_clicked"]["lng"]
152
-
153
- # Calculate the distance between the clicked location and each city
154
- distances = {city: distance.distance(coord, (new_lat, new_long)).km for city, coord in dict_for_streamlit.items()}
155
-
156
- # Find the city with the minimum distance and print its name
157
- nearest_city = min(distances, key=distances.get)
158
- print_fancy_header(text=f"You have selected {nearest_city} using map", font_size=18, color="#52fa23")
159
 
160
- selected_cities_full_list.append(nearest_city)
161
- st.write(label_encoder.transform([nearest_city])[0])
 
162
 
163
- except Exception as err:
164
- print(err)
165
- pass
166
 
167
- submit_button = st.form_submit_button(label='Submit')
168
-
169
- if submit_button:
170
- st.write('Selected cities:', selected_cities_full_list)
171
-
172
- st.write(3*'-')
173
-
174
- dataset = batch_data
175
-
176
- dataset = dataset.sort_values(by=["city_name", "date"])
177
-
178
- st.write("\n")
179
- print_fancy_header(text='\n🧠 Predicting PM2.5 for selected cities...',
180
- font_size=18, color="#FDF4F5")
181
- st.write("")
182
- preds = pd.DataFrame(columns=dataset.columns)
183
- for city_name in selected_cities_full_list:
184
- st.write(f"\t * {city_name}...")
185
- features = dataset.loc[dataset['city_name'] == city_name]
186
- print(features.head())
187
- features['pm2_5'] = pipeline.predict(features)
188
- preds = pd.concat([preds, features])
189
-
190
- st.write("")
191
- print_fancy_header(text="📈Results 📉",
192
- font_size=22)
193
- plot_pm2_5(preds[preds['city_name'].isin(selected_cities_full_list)])
194
-
195
- st.write(3 * "-")
196
- st.subheader('\n🎉 📈 🤝 App Finished Successfully 🤝 📈 🎉')
197
- st.button("Re-run")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import hopsworks
 
3
  import joblib
4
+ from openai import OpenAI
5
+ from functions.llm_chain import (
6
+ load_model,
7
+ get_llm_chain,
8
+ generate_response,
9
+ generate_response_openai,
10
+ )
11
+ import warnings
12
+ warnings.filterwarnings('ignore')
13
 
14
+ st.title("🌤️ AirQuality AI assistant 💬")
 
 
15
 
16
+ @st.cache_resource()
17
+ def connect_to_hopsworks():
18
+ # Initialize Hopsworks feature store connection
19
+ project = hopsworks.login()
20
+ fs = project.get_feature_store()
21
+
22
+ # Retrieve the model registry
23
+ mr = project.get_model_registry()
24
 
25
+ # Retrieve the 'air_quality_fv' feature view
26
+ feature_view = fs.get_feature_view(
27
+ name="air_quality_fv",
28
+ version=1,
29
+ )
30
 
31
+ # Initialize batch scoring
32
+ feature_view.init_batch_scoring(1)
33
+
34
+ # Retrieve the 'air_quality_xgboost_model' from the model registry
35
+ retrieved_model = mr.get_model(
36
+ name="air_quality_xgboost_model",
37
+ version=1,
38
+ )
39
 
40
+ # Download the saved model artifacts to a local directory
41
+ saved_model_dir = retrieved_model.download()
42
 
43
+ # Load the XGBoost regressor model and label encoder from the saved model directory
44
+ model_air_quality = joblib.load(saved_model_dir + "/xgboost_regressor.pkl")
45
+ encoder = joblib.load(saved_model_dir + "/label_encoder.pkl")
46
 
47
+ return feature_view, model_air_quality, encoder
 
 
 
48
 
 
 
49
 
50
+ @st.cache_resource()
51
+ def retrieve_llm_chain():
52
 
53
+ # Load the LLM and its corresponding tokenizer.
54
+ model_llm, tokenizer = load_model()
55
+
56
+ # Create and configure a language model chain.
57
+ llm_chain = get_llm_chain(
58
+ model_llm,
59
+ tokenizer,
60
  )
61
+
62
+ return model_llm, tokenizer, llm_chain
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
 
 
 
 
 
64
 
65
+ # Retrieve the feature view, air quality model and encoder for the city_name column
66
+ feature_view, model_air_quality, encoder = connect_to_hopsworks()
67
 
68
+ # Initialize or clear chat messages based on response source change
69
+ if "response_source" not in st.session_state or "messages" not in st.session_state:
70
+ st.session_state.messages = []
71
+ st.session_state.response_source = ""
72
 
73
+ # User choice for model selection in the sidebar with OpenAI API as the default
74
+ new_response_source = st.sidebar.radio(
75
+ "Choose the response generation method:",
76
+ ('Hermes LLM', 'OpenAI API'),
77
+ index=1 # Sets "OpenAI API" as the default selection
78
  )
79
 
80
+ # If the user switches the response generation method, clear the chat
81
+ if new_response_source != st.session_state.response_source:
82
+ st.session_state.messages = [] # Clear previous chat messages
83
+ st.session_state.response_source = new_response_source # Update response source in session state
84
+
85
+ # Display a message indicating chat was cleared (optional)
86
+ st.experimental_rerun() # Rerun the app to reflect changes immediately
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
 
 
 
 
 
 
 
 
 
 
88
 
89
+ if new_response_source == 'OpenAI API':
90
+ openai_api_key = st.sidebar.text_input("Enter your OpenAI API key:", type="password")
91
+ if openai_api_key:
92
+ client = OpenAI(
93
+ api_key=openai_api_key
94
+ )
95
+ st.sidebar.success("API key saved successfully ✅")
 
 
96
 
97
+ elif new_response_source == 'Hermes LLM':
98
+ # Conditionally load the LLM, tokenizer, and llm_chain if Local Model is selected
99
+ model_llm, tokenizer, llm_chain = retrieve_llm_chain()
100
 
 
 
 
101
 
102
+ # Display chat messages from history on app rerun
103
+ for message in st.session_state.messages:
104
+ with st.chat_message(message["role"]):
105
+ st.markdown(message["content"])
106
+
107
+ # React to user input
108
+ if user_query := st.chat_input("How can I help you?"):
109
+ # Display user message in chat message container
110
+ st.chat_message("user").markdown(user_query)
111
+ # Add user message to chat history
112
+ st.session_state.messages.append({"role": "user", "content": user_query})
113
+
114
+ st.write('⚙️ Generating Response...')
115
+
116
+ if new_response_source == 'Hermes LLM':
117
+ # Generate a response to the user query
118
+ response = generate_response(
119
+ user_query,
120
+ feature_view,
121
+ model_air_quality,
122
+ encoder,
123
+ model_llm,
124
+ tokenizer,
125
+ llm_chain,
126
+ verbose=False,
127
+ )
128
+
129
+ elif new_response_source == 'OpenAI API' and openai_api_key:
130
+ response = generate_response_openai(
131
+ user_query,
132
+ feature_view,
133
+ model_air_quality,
134
+ encoder,
135
+ client,
136
+ verbose=False,
137
+ )
138
+
139
+ else:
140
+ response = "Please select a response generation method and provide necessary details."
141
+
142
+ # Display assistant response in chat message container
143
+ with st.chat_message("assistant"):
144
+ st.markdown(response)
145
+ # Add assistant response to chat history
146
+ st.session_state.messages.append({"role": "assistant", "content": response})