Annikaijak commited on
Commit
2ca0b2e
β€’
1 Parent(s): a1aa76f

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +197 -0
app.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import time
3
+ import pickle
4
+ import joblib
5
+
6
+ import hopsworks
7
+ import streamlit as st
8
+ from geopy import distance
9
+
10
+ import plotly.express as px
11
+ import folium
12
+ from streamlit_folium import st_folium
13
+
14
+ from functions import *
15
+
16
+
17
+
18
+ def print_fancy_header(text, font_size=22, color="#ff5f27"):
19
+ res = f'<span style="color:{color}; font-size: {font_size}px;">{text}</span>'
20
+ st.markdown(res, unsafe_allow_html=True)
21
+
22
+ @st.cache_data()
23
+ def get_batch_data_from_fs(td_version, date_threshold):
24
+ st.write(f"Retrieving the Batch data since {date_threshold}")
25
+ feature_view.init_batch_scoring(training_dataset_version=td_version)
26
+
27
+ batch_data = feature_view.get_batch_data(start_time=date_threshold)
28
+ return batch_data
29
+
30
+
31
+ @st.cache_data()
32
+ def download_model(name="air_quality_xgboost_model", version=1):
33
+ mr = project.get_model_registry()
34
+ retrieved_model = mr.get_model(
35
+ name="air_quality_xgboost_model",
36
+ version=1
37
+ )
38
+ saved_model_dir = retrieved_model.download()
39
+ return saved_model_dir
40
+
41
+
42
+
43
+ def plot_pm2_5(df):
44
+ # create figure with plotly express
45
+ fig = px.line(df, x='date', y='pm2_5', color='city_name')
46
+
47
+ # customize line colors and styles
48
+ fig.update_traces(mode='lines+markers')
49
+ fig.update_layout({
50
+ 'plot_bgcolor': 'rgba(0, 0, 0, 0)',
51
+ 'paper_bgcolor': 'rgba(0, 0, 0, 0)',
52
+ 'legend_title': 'City',
53
+ 'legend_font': {'size': 12},
54
+ 'legend_bgcolor': 'rgba(0, 0, 0, 0)',
55
+ 'xaxis': {'title': 'Date'},
56
+ 'yaxis': {'title': 'PM2.5'},
57
+ 'shapes': [{
58
+ 'type': 'line',
59
+ 'x0': datetime.datetime.now().strftime('%Y-%m-%d'),
60
+ 'y0': 0,
61
+ 'x1': datetime.datetime.now().strftime('%Y-%m-%d'),
62
+ 'y1': df['pm2_5'].max(),
63
+ 'line': {'color': 'red', 'width': 2, 'dash': 'dashdot'}
64
+ }]
65
+ })
66
+
67
+ # show plot
68
+ st.plotly_chart(fig, use_container_width=True)
69
+
70
+
71
+ with open('target_cities.json') as json_file:
72
+ target_cities = json.load(json_file)
73
+
74
+
75
+ #########################
76
+ st.title('🌫 Air Quality Prediction 🌦')
77
+
78
+ st.write(3 * "-")
79
+ print_fancy_header('\nπŸ“‘ Connecting to Hopsworks Feature Store...')
80
+
81
+ st.write("Logging... ")
82
+ # (Attention! If the app has stopped at this step,
83
+ # please enter your Hopsworks API Key in the commmand prompt.)
84
+ project = hopsworks.login()
85
+ fs = project.get_feature_store()
86
+ st.write("βœ… Logged in successfully!")
87
+
88
+ st.write("Getting the Feature View...")
89
+ feature_view = fs.get_feature_view(
90
+ name = 'air_quality_fv',
91
+ version = 1
92
+ )
93
+ st.write("βœ… Success!")
94
+
95
+ # I am going to load data for of last 60 days (for feature engineering)
96
+ today = datetime.date.today()
97
+ date_threshold = today - datetime.timedelta(days=60)
98
+
99
+ st.write(3 * "-")
100
+ print_fancy_header('\n☁️ Retriving batch data from Feature Store...')
101
+ batch_data = get_batch_data_from_fs(td_version=1,
102
+ date_threshold=date_threshold)
103
+
104
+ st.write("Batch data:")
105
+ st.write(batch_data.sample(5))
106
+
107
+
108
+ saved_model_dir = download_model(
109
+ name="air_quality_xgboost_model",
110
+ version=1
111
+ )
112
+
113
+ pipeline = joblib.load(saved_model_dir + "/xgboost_pipeline.pkl")
114
+ st.write("\n")
115
+ st.write("βœ… Model was downloaded and cached.")
116
+
117
+ st.write(3 * '-')
118
+ st.write("\n")
119
+ print_fancy_header(text="πŸ– Select the cities using the form below. \
120
+ Click the 'Submit' button at the bottom of the form to continue.",
121
+ font_size=22)
122
+ dict_for_streamlit = {}
123
+ for continent in target_cities:
124
+ for city_name, coords in target_cities[continent].items():
125
+ dict_for_streamlit[city_name] = coords
126
+ selected_cities_full_list = []
127
+
128
+ with st.form(key="user_inputs"):
129
+ print_fancy_header(text='\nπŸ—Ί Here you can choose cities from the drop-down menu',
130
+ font_size=20, color="#00FFFF")
131
+
132
+ cities_multiselect = st.multiselect(label='',
133
+ options=dict_for_streamlit.keys())
134
+ selected_cities_full_list.extend(cities_multiselect)
135
+ st.write("_" * 3)
136
+ print_fancy_header(text="\nπŸ“Œ To add a city using the interactive map, click somewhere \
137
+ (for the coordinates to appear)",
138
+ font_size=20, color="#00FFFF")
139
+
140
+ my_map = folium.Map(location=[42.57, -44.092], zoom_start=2)
141
+ # Add markers for each city
142
+ for city_name, coords in dict_for_streamlit.items():
143
+ folium.CircleMarker(
144
+ location=coords
145
+ ).add_to(my_map)
146
+
147
+ my_map.add_child(folium.LatLngPopup())
148
+ res_map = st_folium(my_map, width=640, height=480)
149
+
150
+ try:
151
+ new_lat, new_long = res_map["last_clicked"]["lat"], res_map["last_clicked"]["lng"]
152
+
153
+ # Calculate the distance between the clicked location and each city
154
+ distances = {city: distance.distance(coord, (new_lat, new_long)).km for city, coord in dict_for_streamlit.items()}
155
+
156
+ # Find the city with the minimum distance and print its name
157
+ nearest_city = min(distances, key=distances.get)
158
+ print_fancy_header(text=f"You have selected {nearest_city} using map", font_size=18, color="#52fa23")
159
+
160
+ selected_cities_full_list.append(nearest_city)
161
+ st.write(label_encoder.transform([nearest_city])[0])
162
+
163
+ except Exception as err:
164
+ print(err)
165
+ pass
166
+
167
+ submit_button = st.form_submit_button(label='Submit')
168
+
169
+ if submit_button:
170
+ st.write('Selected cities:', selected_cities_full_list)
171
+
172
+ st.write(3*'-')
173
+
174
+ dataset = batch_data
175
+
176
+ dataset = dataset.sort_values(by=["city_name", "date"])
177
+
178
+ st.write("\n")
179
+ print_fancy_header(text='\n🧠 Predicting PM2.5 for selected cities...',
180
+ font_size=18, color="#FDF4F5")
181
+ st.write("")
182
+ preds = pd.DataFrame(columns=dataset.columns)
183
+ for city_name in selected_cities_full_list:
184
+ st.write(f"\t * {city_name}...")
185
+ features = dataset.loc[dataset['city_name'] == city_name]
186
+ print(features.head())
187
+ features['pm2_5'] = pipeline.predict(features)
188
+ preds = pd.concat([preds, features])
189
+
190
+ st.write("")
191
+ print_fancy_header(text="πŸ“ˆResults πŸ“‰",
192
+ font_size=22)
193
+ plot_pm2_5(preds[preds['city_name'].isin(selected_cities_full_list)])
194
+
195
+ st.write(3 * "-")
196
+ st.subheader('\nπŸŽ‰ πŸ“ˆ 🀝 App Finished Successfully 🀝 πŸ“ˆ πŸŽ‰')
197
+ st.button("Re-run")