Spaces:
Sleeping
Sleeping
add password
Browse files- main_page.py +82 -80
- pages/go_further.py +280 -277
- pages/image_classification.py +205 -204
- pages/object_detection.py +179 -178
- pages/recommendation_system.py +351 -352
- pages/sentiment_analysis.py +196 -197
- pages/supervised_unsupervised_page.py +631 -632
- pages/timeseries_analysis.py +202 -203
- pages/topic_modeling.py +147 -148
- utils.py +4 -5
main_page.py
CHANGED
@@ -1,11 +1,7 @@
|
|
1 |
-
import os
|
2 |
import streamlit as st
|
3 |
-
import pandas as pd
|
4 |
-
import numpy as np
|
5 |
-
|
6 |
from st_pages import Page, show_pages, Section, add_indentation
|
7 |
from PIL import Image
|
8 |
-
|
9 |
|
10 |
|
11 |
|
@@ -14,10 +10,16 @@ from PIL import Image
|
|
14 |
##################################################################################
|
15 |
|
16 |
st.set_page_config(layout="wide")
|
17 |
-
#add_indentation()
|
18 |
|
19 |
|
20 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
|
22 |
##################################################################################
|
23 |
# GOOGLE DRIVE CONNEXION #
|
@@ -36,107 +38,107 @@ st.set_page_config(layout="wide")
|
|
36 |
##################################################################################
|
37 |
|
38 |
|
39 |
-
st.image("images/AI.jpg")
|
40 |
-
st.markdown(" ")
|
41 |
|
42 |
-
col1, col2 = st.columns([0.65,0.35], gap="medium")
|
43 |
|
44 |
-
with col1:
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
#st.markdown("in collaboration with Hi! PARIS engineers: Laurène DAVID, Salma HOUIDI and Maeva N'GUESSAN")
|
50 |
|
51 |
-
# with col2:
|
52 |
-
#Hi! PARIS collaboration mention
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
|
63 |
|
64 |
-
st.markdown(" ")
|
65 |
-
st.divider()
|
66 |
|
67 |
|
68 |
-
# #Hi! PARIS collaboration mention
|
69 |
-
# st.markdown(" ")
|
70 |
-
# image_hiparis = Image.open('images/hi-paris.png')
|
71 |
-
# st.image(image_hiparis, width=150)
|
72 |
-
# url = "https://www.hi-paris.fr/"
|
73 |
-
# st.markdown("**The app was made in collaboration with [Hi! PARIS](%s)**" % url)
|
74 |
|
75 |
|
76 |
|
77 |
|
78 |
-
##################################################################################
|
79 |
-
# DASHBOARD/SIDEBAR #
|
80 |
-
##################################################################################
|
81 |
|
82 |
|
83 |
-
# AI use case pages
|
84 |
-
show_pages(
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
|
103 |
-
|
104 |
-
|
105 |
-
)
|
106 |
|
107 |
|
108 |
|
109 |
-
##################################################################################
|
110 |
-
# PAGE CONTENT #
|
111 |
-
##################################################################################
|
112 |
|
113 |
|
114 |
-
st.header("About the app")
|
115 |
|
116 |
|
117 |
-
st.info("""The goal of the **AI and Data Science Examples** is to give an introduction to Data Science by showcasing real-life applications.
|
118 |
-
|
119 |
|
120 |
-
st.markdown(" ")
|
121 |
|
122 |
-
st.markdown("""The app contains four sections:
|
123 |
- 1️⃣ **Machine Learning**: This first section covers use cases where structured data (data in a tabular format) is used to train an AI model.
|
124 |
-
|
125 |
- 2️⃣ **Natural Language Processing** (NLP): This second section showcases AI applications where large amounts of text data is analyzed using Deep Learning models.
|
126 |
-
|
127 |
- 3️⃣ **Computer Vision**: This third section covers a sub-field of AI called Computer Vision, which deals with image/video data.
|
128 |
-
|
129 |
- 🚀 **Go further**: In the final section, you will gain a deeper understanding of AI models and how they function.
|
130 |
-
|
131 |
-
|
132 |
|
133 |
-
st.image("images/ML_domains.png",
|
134 |
-
|
135 |
-
|
136 |
|
137 |
|
138 |
-
# st.markdown(" ")
|
139 |
-
# st.markdown(" ")
|
140 |
-
# st.markdown("## Want to learn more about AI ?")
|
141 |
-
# st.markdown("""**Hi! PARIS**, a multidisciplinary center on Data Analysis and AI founded by Institut Polytechnique de Paris and HEC Paris,
|
142 |
-
# hosts every year a **Data Science Bootcamp** for students of all levels.""")
|
|
|
|
|
1 |
import streamlit as st
|
|
|
|
|
|
|
2 |
from st_pages import Page, show_pages, Section, add_indentation
|
3 |
from PIL import Image
|
4 |
+
from utils import check_password
|
5 |
|
6 |
|
7 |
|
|
|
10 |
##################################################################################
|
11 |
|
12 |
st.set_page_config(layout="wide")
|
|
|
13 |
|
14 |
|
15 |
|
16 |
+
##################################################################################
|
17 |
+
# PASSWORD CHECK #
|
18 |
+
##################################################################################
|
19 |
+
|
20 |
+
if check_password():
|
21 |
+
|
22 |
+
|
23 |
|
24 |
##################################################################################
|
25 |
# GOOGLE DRIVE CONNEXION #
|
|
|
38 |
##################################################################################
|
39 |
|
40 |
|
41 |
+
st.image("images/AI.jpg")
|
42 |
+
st.markdown(" ")
|
43 |
|
44 |
+
col1, col2 = st.columns([0.65,0.35], gap="medium")
|
45 |
|
46 |
+
with col1:
|
47 |
+
st.title("AI and Data Science Examples")
|
48 |
+
st.subheader("HEC Paris, 2023-2024")
|
49 |
+
# st.markdown("""**Course provided by Shirish C. SRIVASTAVA** <br>
|
50 |
+
# **Hi! PARIS Engineering team**: Laurène DAVID, Salma HOUIDI and Maeva N'GUESSAN""", unsafe_allow_html=True)
|
51 |
+
#st.markdown("in collaboration with Hi! PARIS engineers: Laurène DAVID, Salma HOUIDI and Maeva N'GUESSAN")
|
52 |
|
53 |
+
# with col2:
|
54 |
+
#Hi! PARIS collaboration mention
|
55 |
+
# st.markdown(" ")
|
56 |
+
# st.markdown(" ")
|
57 |
+
#st.markdown(" ")
|
58 |
+
|
59 |
+
url = "https://www.hi-paris.fr/"
|
60 |
+
#st.markdown("This app was funded by the Hi! PARIS Center")
|
61 |
+
st.markdown("""###### **The app was made in collaboration with [Hi! PARIS](%s)** """ % url, unsafe_allow_html=True)
|
62 |
+
image_hiparis = Image.open('images/hi-paris.png')
|
63 |
+
st.image(image_hiparis, width=150)
|
64 |
|
65 |
|
66 |
+
st.markdown(" ")
|
67 |
+
st.divider()
|
68 |
|
69 |
|
70 |
+
# #Hi! PARIS collaboration mention
|
71 |
+
# st.markdown(" ")
|
72 |
+
# image_hiparis = Image.open('images/hi-paris.png')
|
73 |
+
# st.image(image_hiparis, width=150)
|
74 |
+
# url = "https://www.hi-paris.fr/"
|
75 |
+
# st.markdown("**The app was made in collaboration with [Hi! PARIS](%s)**" % url)
|
76 |
|
77 |
|
78 |
|
79 |
|
80 |
+
##################################################################################
|
81 |
+
# DASHBOARD/SIDEBAR #
|
82 |
+
##################################################################################
|
83 |
|
84 |
|
85 |
+
# AI use case pages
|
86 |
+
show_pages(
|
87 |
+
[
|
88 |
+
Page("main_page.py", "Home Page", "🏠"),
|
89 |
+
Section(name=" ", icon=""),
|
90 |
+
Section(name=" ", icon=""),
|
91 |
+
|
92 |
+
Section(name="Machine Learning", icon="1️⃣"),
|
93 |
+
Page("pages/supervised_unsupervised_page.py", "1| Supervised vs Unsupervised 🔍", ""),
|
94 |
+
Page("pages/timeseries_analysis.py", "2| Time Series Forecasting 📈", ""),
|
95 |
+
Page("pages/recommendation_system.py", "3| Recommendation systems 🛒", ""),
|
96 |
+
|
97 |
+
Section(name="Natural Language Processing", icon="2️⃣"),
|
98 |
+
Page("pages/topic_modeling.py", "1| Topic Modeling 📚", ""),
|
99 |
+
Page("pages/sentiment_analysis.py", "2| Sentiment Analysis 👍", ""),
|
100 |
+
|
101 |
+
Section(name="Computer Vision", icon="3️⃣"),
|
102 |
+
Page("pages/image_classification.py", "1| Image Classification 🖼️", ""),
|
103 |
+
Page("pages/object_detection.py", "2| Object Detection 📹", ""),
|
104 |
|
105 |
+
Page("pages/go_further.py", "🚀 Go further")
|
106 |
+
]
|
107 |
+
)
|
108 |
|
109 |
|
110 |
|
111 |
+
##################################################################################
|
112 |
+
# PAGE CONTENT #
|
113 |
+
##################################################################################
|
114 |
|
115 |
|
116 |
+
st.header("About the app")
|
117 |
|
118 |
|
119 |
+
st.info("""The goal of the **AI and Data Science Examples** is to give an introduction to Data Science by showcasing real-life applications.
|
120 |
+
The app includes use cases using traditional Machine Learning algorithms on structured data, as well as models that analyze unstructured data (text, images,...).""")
|
121 |
|
122 |
+
st.markdown(" ")
|
123 |
|
124 |
+
st.markdown("""The app contains four sections:
|
125 |
- 1️⃣ **Machine Learning**: This first section covers use cases where structured data (data in a tabular format) is used to train an AI model.
|
126 |
+
You will find pages on *Supervised/Unsupervised Learning*, *Time Series Forecasting* and AI powered *Recommendation Systems*.
|
127 |
- 2️⃣ **Natural Language Processing** (NLP): This second section showcases AI applications where large amounts of text data is analyzed using Deep Learning models.
|
128 |
+
Pages on *Topic Modeling* and *Sentiment Analysis*, which are different kinds of NLP models, can be found in this section.
|
129 |
- 3️⃣ **Computer Vision**: This third section covers a sub-field of AI called Computer Vision, which deals with image/video data.
|
130 |
+
The field of Computer Vision includes *Image classification* and *Object Detection*, which are both featured in this section.
|
131 |
- 🚀 **Go further**: In the final section, you will gain a deeper understanding of AI models and how they function.
|
132 |
+
The page features multiple models to try, as well as different datasets to train a model on.
|
133 |
+
""")
|
134 |
|
135 |
+
st.image("images/ML_domains.png",
|
136 |
+
caption="""This figure showcases a selection of sub-fields of AI, which includes
|
137 |
+
Machine Learning, NLP and Computer Vision.""")
|
138 |
|
139 |
|
140 |
+
# st.markdown(" ")
|
141 |
+
# st.markdown(" ")
|
142 |
+
# st.markdown("## Want to learn more about AI ?")
|
143 |
+
# st.markdown("""**Hi! PARIS**, a multidisciplinary center on Data Analysis and AI founded by Institut Polytechnique de Paris and HEC Paris,
|
144 |
+
# hosts every year a **Data Science Bootcamp** for students of all levels.""")
|
pages/go_further.py
CHANGED
@@ -9,7 +9,7 @@ import altair as alt
|
|
9 |
import plotly.express as px
|
10 |
|
11 |
from st_pages import add_indentation
|
12 |
-
from utils import load_data_csv
|
13 |
|
14 |
from sklearn.datasets import fetch_california_housing
|
15 |
from sklearn.compose import make_column_selector as selector
|
@@ -123,350 +123,353 @@ scores = np.diag(cm)
|
|
123 |
# START OF THE PAGE
|
124 |
##############################################################################################
|
125 |
|
126 |
-
|
127 |
-
|
128 |
-
st.
|
129 |
-
|
130 |
-
|
|
|
|
|
131 |
**Explainability** is also given for most models.
|
132 |
These results give an indication on which variable had the most impact on the model's final prediction. <br>
|
133 |
Note that each model has its own way of measuring explainability, which makes comparisions between model explainabilities difficult.
|
134 |
|
135 |
All of the classification models used in this page come from `scikit-learn`, which is a popular Data Science library in Python.
|
136 |
-
|
137 |
-
try:
|
138 |
-
|
139 |
-
except:
|
140 |
-
|
141 |
|
142 |
-
st.markdown(" ")
|
143 |
-
st.divider()
|
144 |
|
145 |
|
146 |
-
path_data = r'data/other_data'
|
147 |
|
148 |
-
st.markdown("# Classification ")
|
149 |
-
st.markdown("""**Reminder**: Classification models are AI models that are trained to predict a finite number of values/categories.
|
150 |
-
|
151 |
-
st.markdown(" ")
|
152 |
-
st.markdown(" ")
|
153 |
|
154 |
|
155 |
|
156 |
|
157 |
-
########################## SELECT A DATASET ###############################
|
158 |
|
159 |
-
st.markdown("### Select a dataset 📋")
|
160 |
-
st.markdown("""To perform the classification task, you can choose between three different datasets: **Titanic**, **Car evaluation**, **Wine quality** and **Diabetes prevention** <br>
|
161 |
-
|
162 |
-
|
163 |
|
164 |
-
st.warning("""**Note:** The performance of a Machine Learning model is sensitive to the data being used to train it.
|
165 |
-
|
166 |
|
167 |
-
select_data = st.selectbox("Choose an option", ["Titanic 🚢", "Car evaluation 🚙", "Wine quality 🍷", "Diabetes prevention 👩⚕️"]) #label_visibility="collapsed")
|
168 |
-
st.markdown(" ")
|
169 |
|
170 |
-
if select_data =="Wine quality 🍷":
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
|
178 |
-
|
179 |
-
|
180 |
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
|
190 |
-
|
191 |
-
if select_data == "Titanic 🚢":
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
|
203 |
-
|
204 |
-
|
205 |
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
-
- **Survived**: Survival (Died or Survived)
|
221 |
-
- **Pclass**: Ticket class of the passenger (1=First, 2=Second, 3=Third)
|
222 |
-
- **Gender**: Gender
|
223 |
-
- **Age**: Age in years
|
224 |
-
- **SibSp**: Number of siblings aboard the Titanic
|
225 |
-
- **Parch**: Number of parents/children aboard the Titanic
|
226 |
-
- **Fare**: Passenger fare
|
227 |
-
- **Embarked**: Port of Embarkation (C=Cherbourg, Q=Queenstown, S=Southampton)""")
|
228 |
-
|
229 |
-
if select_data == "Car evaluation 🚙":
|
230 |
-
|
231 |
-
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
-
|
236 |
-
|
237 |
-
|
238 |
-
|
239 |
-
|
240 |
-
|
241 |
-
|
242 |
-
|
243 |
-
|
244 |
-
|
245 |
-
|
246 |
-
|
247 |
-
|
248 |
-
|
249 |
-
|
250 |
-
|
251 |
-
|
252 |
-
- **Buying**: Buying price of the vehicule (Very high, high, medium, low)
|
253 |
-
- **Maintenance**: Price for maintenance (Very high, high, medium, low)
|
254 |
-
- **Doors**: Number of doors in the vehicule (2, 3, 4, 5 or more)
|
255 |
-
- **Persons**: Capacity in terms of persons to carry (2, 4, more)
|
256 |
-
- **Luggage boot**: Size of luggage boot
|
257 |
-
- **Safety**: Estimated safety of the car (low, medium, high)
|
258 |
-
- **Evaluation**: Evaluation level (unacceptable, acceptable)""")
|
259 |
-
|
260 |
|
261 |
-
if select_data == "Diabetes prevention 👩⚕️":
|
262 |
-
|
263 |
-
|
264 |
-
|
265 |
-
|
266 |
-
|
267 |
-
|
268 |
-
|
269 |
|
270 |
-
|
271 |
-
|
272 |
|
273 |
|
274 |
-
|
275 |
-
|
276 |
-
|
277 |
|
278 |
-
|
279 |
-
|
280 |
-
|
281 |
-
|
282 |
|
283 |
-
|
284 |
-
|
285 |
-
|
286 |
-
|
287 |
-
- **Pregnancies**: Number of pregnancies had
|
288 |
-
- **Glucose**: The level of glucose in the patient's blood
|
289 |
-
- **BloodPressure**: Blood pressure measurement
|
290 |
-
- **SkinThickness**: Thickness of the skin
|
291 |
-
- **Insulin**: Level of insulin in the blood
|
292 |
-
- **BMI**: Body mass index
|
293 |
-
- **DiabetesPedigreeFunction**: Likelihood of diabetes depending on the patient's age and diabetic family history
|
294 |
-
- **Age**: Age of the patient
|
295 |
-
- **Outcome**: Whether the patient has diabetes (Yes or No)""")
|
296 |
|
297 |
-
st.markdown(" ")
|
298 |
-
st.markdown(" ")
|
299 |
|
300 |
|
301 |
|
302 |
|
303 |
-
########################## SELECT A MODEL ###############################
|
304 |
|
305 |
-
st.markdown("### Select a model 📚")
|
306 |
-
st.markdown("""You can choose between three types of classification models: **K nearest neighbors (KNN)**, **Decision Trees** and **Random Forests**. <br>
|
307 |
-
|
308 |
-
|
309 |
|
310 |
-
st.warning("""**Note**: Different types of models exists for most Machine Learning tasks.
|
311 |
-
|
312 |
-
|
313 |
-
|
314 |
|
315 |
-
select_model = st.selectbox("**Choose an option**", ["K-nearest-neighbor 🏘️", "Decision Tree 🌳", "Random Forest 🏕️"])
|
316 |
-
st.markdown(" ")
|
317 |
|
318 |
|
319 |
-
if select_model == "K-nearest-neighbor 🏘️":
|
320 |
-
|
321 |
-
|
322 |
-
|
323 |
-
|
324 |
-
|
325 |
-
|
326 |
-
|
327 |
|
328 |
-
|
329 |
-
|
330 |
-
|
331 |
-
|
332 |
- When k is equal to 3 (the small dotted circle in the image below), the most common class is **Class B**. The red point will then be predicted as Classe B.
|
333 |
- When k is equal to 6 (the large dotted circle in the image below), the the most common class is **Class A**. The red point will then be predicted as Classe A.""",
|
334 |
-
|
335 |
-
|
336 |
-
|
337 |
-
|
338 |
-
|
339 |
-
|
340 |
-
see_code_box = st.checkbox("See the code", key='knn_code')
|
341 |
-
if see_code_box:
|
342 |
-
see_code(select_model)
|
343 |
-
|
344 |
-
|
345 |
-
if select_model == "Decision Tree 🌳":
|
346 |
-
st.info("""**About the model**: Decision trees are classification model that split the prediction task into a succession of decisions, each with only two possible outcomes.
|
347 |
-
These decisions can be visualized as a tree, with data points arriving from the top of the tree and landing at final "prediction regions".""")
|
348 |
-
|
349 |
-
select_param = 8
|
350 |
-
model_dict = {"model":select_model, "param":select_param}
|
351 |
-
|
352 |
-
learn_model = st.checkbox("Learn more about the model", key="tree")
|
353 |
-
if learn_model:
|
354 |
-
st.markdown("""The following image showcases a decision tree which predicts whether a **bank should give out a loan** to a client. <br>
|
355 |
-
The data used to train the model has each client's **age**, **salary** and **number of children**.""", unsafe_allow_html=True)
|
356 |
|
357 |
-
st.
|
358 |
-
|
|
|
359 |
|
360 |
-
|
361 |
-
|
362 |
-
|
|
|
363 |
|
364 |
-
|
365 |
-
|
366 |
-
see_code(select_model)
|
367 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
368 |
|
369 |
|
370 |
-
if select_model == "Random Forest 🏕️":
|
371 |
-
|
372 |
-
|
373 |
-
|
374 |
-
|
375 |
-
select_param = 8
|
376 |
-
model_dict = {"model":select_model, "param":select_param}
|
377 |
-
|
378 |
-
learn_model = st.checkbox("Learn more about the model", key="tree")
|
379 |
-
if learn_model:
|
380 |
-
st.markdown("""Random Forests classifiers combine the results of multiple trees by apply **majority voting**, which means selecting the class that was most often predicted by trees as the final prediction.
|
381 |
-
In the following image, the random forest model built four decision trees, who each have made their own class prediction. <br>"""
|
382 |
-
, unsafe_allow_html=True)
|
383 |
-
|
384 |
-
st.markdown("""Class C was predicted twice, whereas Class B et D where only predicted once. <br>
|
385 |
-
The final prediction of the random forest model is thus Class C.""", unsafe_allow_html=True)
|
386 |
|
387 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
388 |
|
389 |
-
|
390 |
-
|
391 |
-
|
392 |
|
393 |
|
394 |
|
395 |
-
st.markdown(" ")
|
396 |
-
st.markdown(" ")
|
397 |
|
398 |
-
########################## RUN THE MODEL ###############################
|
399 |
|
400 |
-
st.markdown("### Train the model ⚙️")
|
401 |
-
st.markdown("""Now, you can build the chosen classification model and use the selected dataset to train it. <br>
|
402 |
-
|
403 |
|
404 |
-
st.warning("""**Note**: Most machine learning models have an element of randomness in their predictions.
|
405 |
-
|
406 |
|
407 |
-
st.markdown(f"""You've selected the **{select_data}** dataset and the **{select_model}** model.""")
|
408 |
|
409 |
|
410 |
-
run_model = st.button("Run model", type="primary")
|
411 |
-
st.markdown(" ")
|
|
|
412 |
|
413 |
-
if run_model:
|
414 |
-
|
415 |
-
|
416 |
-
|
417 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
418 |
|
419 |
-
|
420 |
-
st.markdown("
|
421 |
st.markdown("""The values below represent the model's accuracy for each possible class.
|
422 |
-
|
|
|
|
|
|
|
|
|
423 |
if select_data == "Diabetes prevention 👩⚕️":
|
424 |
st.warning("""**Note**: The Diabetes dataset only contains information on 768 patients. 500 patients don't have diabetes and 268 do have the disease.
|
425 |
-
|
426 |
-
|
427 |
-
|
428 |
score_df = pd.DataFrame({"label":labels, "accuracy":np.round(score*100)})
|
429 |
-
fig = px.bar(score_df, x="label", y="accuracy", color="label", text_auto=True)
|
430 |
-
st.plotly_chart(fig, use_container_width=True)
|
431 |
-
|
432 |
-
st.warning("""**Note**: To improve the results of a model, practionners often conduct *hyperparameter tuning*.
|
433 |
-
It consists of trying different combination of the model's parameters to maximise the accuracy score.
|
434 |
-
Hyperparameter tuning wasn't conduct here in order to insure the app doesn't lag.""")
|
435 |
-
|
436 |
-
|
437 |
-
with tab2:
|
438 |
-
st.markdown("#### Explainability")
|
439 |
-
st.markdown("""Variables with a high explainability score had the most impact on the model's predictions.
|
440 |
-
Variables with a low explainability score had a much smaller impact.""")
|
441 |
-
|
442 |
-
df_feature_imp = pd.DataFrame({"variable":feature_names, "importance":feature_imp})
|
443 |
-
df_feature_imp = df_feature_imp.groupby("variable").mean().reset_index()
|
444 |
-
df_feature_imp["importance"] = df_feature_imp["importance"].round(2)
|
445 |
-
df_feature_imp.sort_values(by=["importance"], ascending=False, inplace=True)
|
446 |
-
|
447 |
-
fig = px.bar(df_feature_imp, x="importance", y="variable", color="importance")
|
448 |
st.plotly_chart(fig, use_container_width=True)
|
449 |
|
450 |
-
|
451 |
-
|
452 |
-
|
453 |
-
The lowest possible accuracy is 0 and the highest 100.""")
|
454 |
-
|
455 |
-
st.markdown("""The K-nearest-neighbor algorithm doesn't have a built-in solution to compute model explainability with `scikit-learn`.
|
456 |
-
You can use other python packages such as `SHAP` to compute explainability, which we didn't use here since they usually take a long time to output results.""")
|
457 |
-
|
458 |
-
if select_data == "Diabetes prevention 👩⚕️":
|
459 |
-
st.warning("""**Note**: The Diabetes dataset only contains information on 768 patients. 500 patients don't have diabetes and 268 do have the disease.
|
460 |
-
This small number of patient data explains why the model's performance isn't optimal.
|
461 |
-
Additional data collection as well as hyperparameter tuning can be conducted to improve results.""")
|
462 |
-
|
463 |
-
score_df = pd.DataFrame({"label":labels, "accuracy":np.round(score*100)})
|
464 |
-
fig = px.bar(score_df, x="label", y="accuracy", color="label", title="Accuracy results", text_auto=True)
|
465 |
-
st.plotly_chart(fig, use_container_width=True)
|
466 |
-
|
467 |
-
st.warning("""**Note**: To improve the results of a model, practionners often conduct *hyperparameter tuning*.
|
468 |
-
It consists of trying different combination of the model's parameters to maximise the accuracy score.
|
469 |
-
Hyperparameter tuning wasn't conduct here in order to insure the app doesn't lag.""")
|
470 |
|
471 |
|
472 |
|
|
|
9 |
import plotly.express as px
|
10 |
|
11 |
from st_pages import add_indentation
|
12 |
+
from utils import load_data_csv, check_password
|
13 |
|
14 |
from sklearn.datasets import fetch_california_housing
|
15 |
from sklearn.compose import make_column_selector as selector
|
|
|
123 |
# START OF THE PAGE
|
124 |
##############################################################################################
|
125 |
|
126 |
+
if check_password():
|
127 |
+
|
128 |
+
st.image("images/ML_header.jpg")
|
129 |
+
st.markdown("# Go further 🚀")
|
130 |
+
st.markdown("""This page allows you to test and compare results between different AI models, and gain a deeper understanding of how they make predictions. <br>
|
131 |
+
It includes three different types of **classification models** with Python code illustrations, as well as four datasets to choose from.
|
132 |
+
|
133 |
**Explainability** is also given for most models.
|
134 |
These results give an indication on which variable had the most impact on the model's final prediction. <br>
|
135 |
Note that each model has its own way of measuring explainability, which makes comparisions between model explainabilities difficult.
|
136 |
|
137 |
All of the classification models used in this page come from `scikit-learn`, which is a popular Data Science library in Python.
|
138 |
+
""", unsafe_allow_html=True)
|
139 |
+
try:
|
140 |
+
st.link_button("Go to the scikit-learn website", "https://scikit-learn.org/stable/index.html")
|
141 |
+
except:
|
142 |
+
st.markdown("You need internet connexion to access the link.")
|
143 |
|
144 |
+
st.markdown(" ")
|
145 |
+
st.divider()
|
146 |
|
147 |
|
148 |
+
path_data = r'data/other_data'
|
149 |
|
150 |
+
st.markdown("# Classification ")
|
151 |
+
st.markdown("""**Reminder**: Classification models are AI models that are trained to predict a finite number of values/categories.
|
152 |
+
Examples can be found in the *Supervised vs Unsupervised* page with the credit score classification and customer churn prediction use cases.""")
|
153 |
+
st.markdown(" ")
|
154 |
+
st.markdown(" ")
|
155 |
|
156 |
|
157 |
|
158 |
|
159 |
+
########################## SELECT A DATASET ###############################
|
160 |
|
161 |
+
st.markdown("### Select a dataset 📋")
|
162 |
+
st.markdown("""To perform the classification task, you can choose between three different datasets: **Titanic**, **Car evaluation**, **Wine quality** and **Diabetes prevention** <br>
|
163 |
+
Each dataset will be shown in its original format and will go through pre-processing steps to insure its quality and usability for the chosen model.
|
164 |
+
""", unsafe_allow_html=True)
|
165 |
|
166 |
+
st.warning("""**Note:** The performance of a Machine Learning model is sensitive to the data being used to train it.
|
167 |
+
Data cleaning and pre-processing are usually as important as training the AI model. These steps can include removing missing values, identifying outliers and transforming columns from text to numbers.""")
|
168 |
|
169 |
+
select_data = st.selectbox("Choose an option", ["Titanic 🚢", "Car evaluation 🚙", "Wine quality 🍷", "Diabetes prevention 👩⚕️"]) #label_visibility="collapsed")
|
170 |
+
st.markdown(" ")
|
171 |
|
172 |
+
if select_data =="Wine quality 🍷":
|
173 |
+
# Load data and clean it
|
174 |
+
data = load_data_csv(path_data, "winequality.csv")
|
175 |
+
data = data.loc[data["residual sugar"] < 40]
|
176 |
+
data = data.loc[data["free sulfur dioxide"] < 200]
|
177 |
+
data = data.loc[data["total sulfur dioxide"] < 400]
|
178 |
+
data.drop(columns=["free sulfur dioxide"], inplace=True)
|
179 |
|
180 |
+
X = data.drop(columns=["quality"])
|
181 |
+
y = data["quality"]
|
182 |
|
183 |
+
# Information on the data
|
184 |
+
st.info("""**About the data**: The goal of the wine quality dataset is to **predict the quality** of different wines using their formulation.
|
185 |
+
The target in this use case is the `quality` variable which has two possible values (Good and Mediocre).""")
|
186 |
|
187 |
+
# View data
|
188 |
+
view_data = st.checkbox("View the data", key="wine")
|
189 |
+
if view_data:
|
190 |
+
st.dataframe(data)
|
191 |
|
192 |
+
|
193 |
+
if select_data == "Titanic 🚢":
|
194 |
+
# Load data and clean it
|
195 |
+
data = load_data_csv(path_data, "titanic.csv")
|
196 |
+
data = data.drop(columns=["Name","Cabin","Ticket","PassengerId"]).dropna()
|
197 |
+
data["Survived"] = data["Survived"].map({0: "Died", 1:"Survived"})
|
198 |
+
data.rename({"Sex":"Gender"}, axis=1, inplace=True)
|
199 |
+
data["Age"] = data["Age"].astype(int)
|
200 |
+
data["Fare"] = data["Fare"].round(2)
|
201 |
+
|
202 |
+
cat_columns = data.select_dtypes(include="object").columns
|
203 |
+
data[cat_columns] = data[cat_columns].astype("category")
|
204 |
|
205 |
+
X = data.drop(columns=["Survived"])
|
206 |
+
y = data["Survived"]
|
207 |
|
208 |
+
# Information on the data
|
209 |
+
st.info("""**About the data**: The goal of the titanic dataset is to **predict whether a passenger on the ship survived**.
|
210 |
+
The target in this use case is the `Survived` variable which has two possible values (Died or Survived).
|
211 |
+
""")
|
212 |
|
213 |
+
# View data
|
214 |
+
view_data = st.checkbox("View the data", key="titanic")
|
215 |
+
if view_data:
|
216 |
+
st.dataframe(data)
|
217 |
+
|
218 |
+
# About the variables
|
219 |
+
about_var = st.checkbox("Information on the variables", key="titanic-var")
|
220 |
+
if about_var:
|
221 |
+
st.markdown("""
|
222 |
+
- **Survived**: Survival (Died or Survived)
|
223 |
+
- **Pclass**: Ticket class of the passenger (1=First, 2=Second, 3=Third)
|
224 |
+
- **Gender**: Gender
|
225 |
+
- **Age**: Age in years
|
226 |
+
- **SibSp**: Number of siblings aboard the Titanic
|
227 |
+
- **Parch**: Number of parents/children aboard the Titanic
|
228 |
+
- **Fare**: Passenger fare
|
229 |
+
- **Embarked**: Port of Embarkation (C=Cherbourg, Q=Queenstown, S=Southampton)""")
|
230 |
+
|
231 |
+
if select_data == "Car evaluation 🚙":
|
232 |
+
# Load data and clean it
|
233 |
+
data = load_data_csv(path_data, "car.csv")
|
234 |
+
data.rename({"Price":"Buying"}, axis=1, inplace=True)
|
235 |
+
cat_columns = data.select_dtypes(include="object").columns
|
236 |
+
data[cat_columns] = data[cat_columns].astype("category")
|
237 |
+
|
238 |
+
X = data.drop(columns="Evaluation")
|
239 |
+
y = data["Evaluation"]
|
240 |
+
|
241 |
+
# Information on the data
|
242 |
+
st.info("""**About the data**: The goal of the car evaluation dataset is to predict the evaluation made about a car before being sold.
|
243 |
+
The target in this use case is the `Evaluation` variable, which has two possible values (Not acceptable or acceptable)""")
|
244 |
+
|
245 |
+
# View data
|
246 |
+
view_data = st.checkbox("View the data", key="car")
|
247 |
+
if view_data:
|
248 |
+
st.dataframe(data)
|
249 |
+
|
250 |
+
# View data
|
251 |
+
about_var = st.checkbox("Information on the variables", key="car-var")
|
252 |
+
if about_var:
|
253 |
+
st.markdown("""
|
254 |
+
- **Buying**: Buying price of the vehicule (Very high, high, medium, low)
|
255 |
+
- **Maintenance**: Price for maintenance (Very high, high, medium, low)
|
256 |
+
- **Doors**: Number of doors in the vehicule (2, 3, 4, 5 or more)
|
257 |
+
- **Persons**: Capacity in terms of persons to carry (2, 4, more)
|
258 |
+
- **Luggage boot**: Size of luggage boot
|
259 |
+
- **Safety**: Estimated safety of the car (low, medium, high)
|
260 |
+
- **Evaluation**: Evaluation level (unacceptable, acceptable)""")
|
261 |
+
|
262 |
|
263 |
+
if select_data == "Diabetes prevention 👩⚕️":
|
264 |
+
# Load data and clean it
|
265 |
+
data = load_data_csv(path_data, "diabetes.csv")
|
266 |
+
data["Outcome"] = data["Outcome"].map({1:"Yes", 0:"No"})
|
267 |
+
#data.drop(columns=["DiabetesPedigreeFunction"], inplace=True)
|
268 |
+
# data.rename({"Price":"Buying"}, axis=1, inplace=True)
|
269 |
+
cat_columns = data.select_dtypes(include="object").columns
|
270 |
+
data[cat_columns] = data[cat_columns].astype("category")
|
271 |
|
272 |
+
X = data.drop(columns="Outcome")
|
273 |
+
y = data["Outcome"]
|
274 |
|
275 |
|
276 |
+
# Information on the data
|
277 |
+
st.info("""**About the data**: The goal of the diabetes dataset is to predict whether a patient has diabetes.
|
278 |
+
The target in this use case is the `Outcome` variable, which has two possible values (Yes or No)""")
|
279 |
|
280 |
+
# View data
|
281 |
+
view_data = st.checkbox("View the data", key="diabetes")
|
282 |
+
if view_data:
|
283 |
+
st.dataframe(data)
|
284 |
|
285 |
+
# View data
|
286 |
+
about_var = st.checkbox("Information on the variables", key="car-var")
|
287 |
+
if about_var:
|
288 |
+
st.markdown("""
|
289 |
+
- **Pregnancies**: Number of pregnancies had
|
290 |
+
- **Glucose**: The level of glucose in the patient's blood
|
291 |
+
- **BloodPressure**: Blood pressure measurement
|
292 |
+
- **SkinThickness**: Thickness of the skin
|
293 |
+
- **Insulin**: Level of insulin in the blood
|
294 |
+
- **BMI**: Body mass index
|
295 |
+
- **DiabetesPedigreeFunction**: Likelihood of diabetes depending on the patient's age and diabetic family history
|
296 |
+
- **Age**: Age of the patient
|
297 |
+
- **Outcome**: Whether the patient has diabetes (Yes or No)""")
|
298 |
|
299 |
+
st.markdown(" ")
|
300 |
+
st.markdown(" ")
|
301 |
|
302 |
|
303 |
|
304 |
|
305 |
+
########################## SELECT A MODEL ###############################
|
306 |
|
307 |
+
st.markdown("### Select a model 📚")
|
308 |
+
st.markdown("""You can choose between three types of classification models: **K nearest neighbors (KNN)**, **Decision Trees** and **Random Forests**. <br>
|
309 |
+
For each model, you will be given a short explanation as to how they function.
|
310 |
+
""", unsafe_allow_html=True)
|
311 |
|
312 |
+
st.warning("""**Note**: Different types of models exists for most Machine Learning tasks.
|
313 |
+
Models tend to vary in complexity and picking which one to train for a specific use case isn't always straightforward.
|
314 |
+
Complex model might output better results but take longer to make predictions.
|
315 |
+
The model selection step requires a good amount of testing by practitioners.""")
|
316 |
|
317 |
+
select_model = st.selectbox("**Choose an option**", ["K-nearest-neighbor 🏘️", "Decision Tree 🌳", "Random Forest 🏕️"])
|
318 |
+
st.markdown(" ")
|
319 |
|
320 |
|
321 |
+
if select_model == "K-nearest-neighbor 🏘️":
|
322 |
+
#st.markdown("#### Model: K-nearest-neighbor")
|
323 |
+
st.info("""**About the model**: K-nearest-neighbor (or KNN) is a type of classification model that uses neighboring points to classify new data.
|
324 |
+
When trying to predict a class to new data point, the algorithm will look at points in close proximity (or in its neighborhood) to make a decision.
|
325 |
+
The most common class in the points' neighborhood will then be chosen as the final prediction.""")
|
326 |
+
|
327 |
+
select_param = 6
|
328 |
+
model_dict = {"model":select_model, "param":select_param}
|
329 |
|
330 |
+
learn_model = st.checkbox("Learn more about the model", key="knn")
|
331 |
+
if learn_model:
|
332 |
+
st.markdown("""An important parameter in KNN algorithms is the number of points to choose as neighboors. <br>
|
333 |
+
The image below shows two cases where the number of neighboors (k) are equal to 3 and 6.
|
334 |
- When k is equal to 3 (the small dotted circle in the image below), the most common class is **Class B**. The red point will then be predicted as Classe B.
|
335 |
- When k is equal to 6 (the large dotted circle in the image below), the the most common class is **Class A**. The red point will then be predicted as Classe A.""",
|
336 |
+
unsafe_allow_html=True)
|
337 |
+
|
338 |
+
st.image("images/knn.png", width=600)
|
339 |
+
st.markdown("""K-nearest-neighbor algorithm are popular for their simplicity. <br>
|
340 |
+
This can be a drawback for use cases/dataset that require a more complex approach to make accurate predictions.""", unsafe_allow_html=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
341 |
|
342 |
+
see_code_box = st.checkbox("See the code", key='knn_code')
|
343 |
+
if see_code_box:
|
344 |
+
see_code(select_model)
|
345 |
|
346 |
+
|
347 |
+
if select_model == "Decision Tree 🌳":
|
348 |
+
st.info("""**About the model**: Decision trees are classification model that split the prediction task into a succession of decisions, each with only two possible outcomes.
|
349 |
+
These decisions can be visualized as a tree, with data points arriving from the top of the tree and landing at final "prediction regions".""")
|
350 |
|
351 |
+
select_param = 8
|
352 |
+
model_dict = {"model":select_model, "param":select_param}
|
|
|
353 |
|
354 |
+
learn_model = st.checkbox("Learn more about the model", key="tree")
|
355 |
+
if learn_model:
|
356 |
+
st.markdown("""The following image showcases a decision tree which predicts whether a **bank should give out a loan** to a client. <br>
|
357 |
+
The data used to train the model has each client's **age**, **salary** and **number of children**.""", unsafe_allow_html=True)
|
358 |
+
|
359 |
+
st.markdown("""To predict whether a client gets a loan, the client's data goes through each 'leaf' in the tree (leaves are the blue box question in the image below) and **gets assigned the class of the final leaf it fell into** (either Get loan or Don't get loan).
|
360 |
+
For example, a client that is under 30 years old and has a lower salary than 2500$ will not be awarded a loan by the model.""", unsafe_allow_html=True)
|
361 |
+
|
362 |
+
st.image("images/decisiontree.png", width=800)
|
363 |
+
st.markdown("""Decision tree models are popular as they are easy to interpret. <br>
|
364 |
+
The higher the variable is on the tree, the more important it is in the decision process.""", unsafe_allow_html=True)
|
365 |
+
|
366 |
+
see_code_box = st.checkbox("See the code", key='tree_code')
|
367 |
+
if see_code_box:
|
368 |
+
see_code(select_model)
|
369 |
+
|
370 |
|
371 |
|
372 |
+
if select_model == "Random Forest 🏕️":
|
373 |
+
st.info("""**About the model:** Random Forest models generate multiple decision tree models to make predictions.
|
374 |
+
The main drawback of decision trees is that their predictions can be unstable, meaning that their output often changes.
|
375 |
+
Random Forest models combine the predictions of multiple decision trees to reduce this unstability and improve robustness.""")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
376 |
|
377 |
+
select_param = 8
|
378 |
+
model_dict = {"model":select_model, "param":select_param}
|
379 |
+
|
380 |
+
learn_model = st.checkbox("Learn more about the model", key="tree")
|
381 |
+
if learn_model:
|
382 |
+
st.markdown("""Random Forests classifiers combine the results of multiple trees by apply **majority voting**, which means selecting the class that was most often predicted by trees as the final prediction.
|
383 |
+
In the following image, the random forest model built four decision trees, who each have made their own class prediction. <br>"""
|
384 |
+
, unsafe_allow_html=True)
|
385 |
+
|
386 |
+
st.markdown("""Class C was predicted twice, whereas Class B et D where only predicted once. <br>
|
387 |
+
The final prediction of the random forest model is thus Class C.""", unsafe_allow_html=True)
|
388 |
+
|
389 |
+
st.image("images/randomforest.png", width=800)
|
390 |
|
391 |
+
see_code_box = st.checkbox("See the code", key='forest_code')
|
392 |
+
if see_code_box:
|
393 |
+
see_code(select_model)
|
394 |
|
395 |
|
396 |
|
397 |
+
st.markdown(" ")
|
398 |
+
st.markdown(" ")
|
399 |
|
400 |
+
########################## RUN THE MODEL ###############################
|
401 |
|
402 |
+
st.markdown("### Train the model ⚙️")
|
403 |
+
st.markdown("""Now, you can build the chosen classification model and use the selected dataset to train it. <br>
|
404 |
+
You will get the model's accuracy in predicting each category, as well as the importance of each variable in the final predictions.""", unsafe_allow_html=True)
|
405 |
|
406 |
+
st.warning("""**Note**: Most machine learning models have an element of randomness in their predictions.
|
407 |
+
This explains why a model's accuracy might change even if you run it with the same dataset.""")
|
408 |
|
409 |
+
st.markdown(f"""You've selected the **{select_data}** dataset and the **{select_model}** model.""")
|
410 |
|
411 |
|
412 |
+
run_model = st.button("Run model", type="primary")
|
413 |
+
st.markdown(" ")
|
414 |
+
st.markdown(" ")
|
415 |
|
416 |
+
if run_model:
|
417 |
+
score, feature_imp, feature_names, labels = model_training(X, y, model_dict, _num_transformer=StandardScaler())
|
418 |
+
|
419 |
+
if select_model in ["Decision Tree 🌳", "Random Forest 🏕️"]: # show explainability for decision tree, random firest
|
420 |
+
tab1, tab2 = st.tabs(["Results", "Explainability"])
|
421 |
+
|
422 |
+
with tab1:
|
423 |
+
st.markdown("### Results")
|
424 |
+
st.markdown("""The values below represent the model's accuracy for each possible class.
|
425 |
+
The lowest possible accuracy is 0 and the highest 100.""")
|
426 |
+
if select_data == "Diabetes prevention 👩⚕️":
|
427 |
+
st.warning("""**Note**: The Diabetes dataset only contains information on 768 patients. 500 patients don't have diabetes and 268 do have the disease.
|
428 |
+
This small number of patient data explains why the model's performance isn't optimal.
|
429 |
+
Additional data collection as well as hyperparameter tuning can be conducted to improve results.""")
|
430 |
+
|
431 |
+
score_df = pd.DataFrame({"label":labels, "accuracy":np.round(score*100)})
|
432 |
+
fig = px.bar(score_df, x="label", y="accuracy", color="label", text_auto=True)
|
433 |
+
st.plotly_chart(fig, use_container_width=True)
|
434 |
+
|
435 |
+
st.warning("""**Note**: To improve the results of a model, practionners often conduct *hyperparameter tuning*.
|
436 |
+
It consists of trying different combination of the model's parameters to maximise the accuracy score.
|
437 |
+
Hyperparameter tuning wasn't conduct here in order to insure the app doesn't lag.""")
|
438 |
+
|
439 |
+
|
440 |
+
with tab2:
|
441 |
+
st.markdown("### Explainability")
|
442 |
+
st.markdown("""Variables with a high explainability score had the most impact on the model's predictions.
|
443 |
+
Variables with a low explainability score had a much smaller impact.""")
|
444 |
+
|
445 |
+
df_feature_imp = pd.DataFrame({"variable":feature_names, "importance":feature_imp})
|
446 |
+
df_feature_imp = df_feature_imp.groupby("variable").mean().reset_index()
|
447 |
+
df_feature_imp["importance"] = df_feature_imp["importance"].round(2)
|
448 |
+
df_feature_imp.sort_values(by=["importance"], ascending=False, inplace=True)
|
449 |
+
|
450 |
+
fig = px.bar(df_feature_imp, x="importance", y="variable", color="importance")
|
451 |
+
st.plotly_chart(fig, use_container_width=True)
|
452 |
|
453 |
+
else: # only show results for knn
|
454 |
+
st.markdown("### Results")
|
455 |
st.markdown("""The values below represent the model's accuracy for each possible class.
|
456 |
+
The lowest possible accuracy is 0 and the highest 100.""")
|
457 |
+
|
458 |
+
st.warning("""**Note**: The K-nearest-neighbor algorithm doesn't have a built-in solution to compute model explainability with `scikit-learn`.
|
459 |
+
You can use other python packages such as `SHAP` to compute explainability, which we didn't use here since they usually take a long time to output results.""")
|
460 |
+
|
461 |
if select_data == "Diabetes prevention 👩⚕️":
|
462 |
st.warning("""**Note**: The Diabetes dataset only contains information on 768 patients. 500 patients don't have diabetes and 268 do have the disease.
|
463 |
+
This small number of patient data explains why the model's performance isn't optimal.
|
464 |
+
Additional data collection as well as hyperparameter tuning can be conducted to improve results.""")
|
465 |
+
|
466 |
score_df = pd.DataFrame({"label":labels, "accuracy":np.round(score*100)})
|
467 |
+
fig = px.bar(score_df, x="label", y="accuracy", color="label", title="Accuracy results", text_auto=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
468 |
st.plotly_chart(fig, use_container_width=True)
|
469 |
|
470 |
+
st.warning("""**Note**: To improve the results of a model, practionners often conduct *hyperparameter tuning*.
|
471 |
+
It consists of trying different combination of the model's parameters to maximise the accuracy score.
|
472 |
+
Hyperparameter tuning wasn't conduct here in order to insure the app doesn't lag.""")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
473 |
|
474 |
|
475 |
|
pages/image_classification.py
CHANGED
@@ -5,7 +5,7 @@ import os
|
|
5 |
import pandas as pd
|
6 |
import plotly.express as px
|
7 |
from PIL import Image
|
8 |
-
from utils import load_data_pickle
|
9 |
|
10 |
|
11 |
# import gradcam
|
@@ -72,259 +72,260 @@ gradcam_images_paths = ["images/meningioma_tumor.png", "images/no_tumor.png", "i
|
|
72 |
|
73 |
###################################### TITLE ####################################
|
74 |
|
75 |
-
|
76 |
|
77 |
-
st.markdown("
|
78 |
-
st.info("""**Image classification** is a process in Machine Learning and Computer Vision where an algorithm is trained to recognize and categorize images into predefined classes. It involves analyzing the visual content of an image and assigning it to a specific label based on its features.""")
|
79 |
-
#unsafe_allow_html=True)
|
80 |
-
st.markdown(" ")
|
81 |
-
st.markdown("""State-of-the-art image classification models use **neural networks** to predict whether an image belongs to a specific class.<br>
|
82 |
-
Each of the possible predicted classes are given a probability then the class with the highest value is assigned to the input image.""",
|
83 |
-
unsafe_allow_html=True)
|
84 |
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
st.
|
89 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
90 |
|
91 |
-
st.markdown(" ")
|
92 |
|
93 |
-
st.markdown("""Real-life applications of image classification includes:
|
94 |
- **Medical Imaging 👨⚕️**: Diagnose diseases and medical conditions from images such as X-rays, MRIs and CT scans to, for example, identify tumors and classify different types of cancers.
|
95 |
- **Autonomous Vehicules** 🏎️: Classify objects such as pedestrians, vehicles, traffic signs, lane markings, and obstacles, which is crucial for navigation and collision avoidance.
|
96 |
- **Satellite and Remote Sensing 🛰️**: Analyze satellite imagery to identify land use patterns, monitor vegetation health, assess environmental changes, and detect natural disasters such as wildfires and floods.
|
97 |
- **Quality Control 🛂**: Inspect products and identify defects to ensure compliance with quality standards during the manufacturying process.
|
98 |
-
|
99 |
|
100 |
-
# st.markdown("""Real-life applications of Brain Tumor includes:
|
101 |
-
# - **Research and development💰**: The technologies and methodologies developed for brain tumor classification can advance research in neuroscience, oncology, and the development of new diagnostic tools and treatments.
|
102 |
-
# - **Healthcare👨⚕️**: Data derived from the classification and analysis of brain tumors can inform public health decisions, healthcare policies, and resource allocation, emphasizing areas with higher incidences of certain types of tumors.
|
103 |
-
# - **Insurance Industry 🏬**: Predict future demand for products to optimize inventory levels, reduce holding costs, and improve supply chain efficiency.
|
104 |
-
# """)
|
105 |
|
106 |
|
107 |
-
|
108 |
|
109 |
|
110 |
-
# BEGINNING OF USE CASE
|
111 |
-
st.divider()
|
112 |
-
st.markdown("# Brain Tumor Classification 🧠")
|
113 |
|
114 |
-
st.info("""In this use case, a **brain tumor classification** model is leveraged to accurately identify the presence of tumors in MRI scans of the brain.
|
115 |
-
|
116 |
|
117 |
-
st.markdown(" ")
|
118 |
-
_, col, _ = st.columns([0.1,0.8,0.1])
|
119 |
-
with col:
|
120 |
-
|
121 |
|
122 |
-
st.markdown(" ")
|
123 |
-
st.markdown(" ")
|
124 |
|
125 |
-
### WHAT ARE BRAIN TUMORS ?
|
126 |
-
st.markdown(" ### What is a Brain Tumor ?")
|
127 |
-
st.markdown("""
|
128 |
-
A brain tumor occurs when **abnormal cells form within the brain**. Two main types of tumors exist: **cancerous (malignant) tumors** and **benign tumors**.
|
129 |
- **Cancerous tumors** are malignant tumors that have the ability to invade nearby tissues and spread to other parts of the body through a process called metastasis.
|
130 |
- **Benign tumors** can become quite large but will not invade nearby tissue or spread to other parts of the body. They can still cause serious health problems depending on their size, location and rate of growth.
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
|
135 |
-
st.markdown(" ")
|
136 |
-
st.markdown(" ")
|
137 |
-
st.markdown("### About the data 📋")
|
138 |
|
139 |
-
st.markdown("""You were provided with a large dataset which contains **anonymized patient MRI scans** categorized into three distinct classes: **pituitary tumor** (in most cases benign), **meningioma tumor** (cancerous) and **no tumor**.
|
140 |
-
This dataset will serve as the foundation for training our classification model, offering a comprehensive view of varied tumor presentations within the brain.""")
|
141 |
|
142 |
-
|
143 |
-
|
144 |
-
st.
|
145 |
|
146 |
-
|
147 |
-
|
148 |
-
# st.warning("You can view here a few examples of the MRI training data.")
|
149 |
-
# # image selection
|
150 |
-
# images = os.listdir(DATA_DIR)
|
151 |
-
# selected_image1 = st.selectbox("Choose an image to visualize 🔎 :", images, key="selectionbox_key_2")
|
152 |
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
# st.image(image, caption="Image selected", width=450)
|
157 |
|
158 |
-
# st.
|
159 |
-
#
|
|
|
|
|
|
|
|
|
160 |
|
161 |
-
|
162 |
-
|
|
|
|
|
163 |
|
|
|
|
|
164 |
|
|
|
|
|
165 |
|
166 |
-
st.markdown("### Train the algorithm ⚙️")
|
167 |
-
st.markdown("""**Training an AI model** means feeding it data that contains multiple examples/images each type of tumor to be detected.
|
168 |
-
By analyzing the provided MRI images, the model learns to discern the subtle differences between each classes, thereby enabling the precise identification of tumor types.""")
|
169 |
|
170 |
|
171 |
-
###
|
|
|
|
|
172 |
|
173 |
-
# Initialisation de l'état du modèle
|
174 |
-
if 'model_train' not in st.session_state:
|
175 |
-
st.session_state['model_train'] = False
|
176 |
|
177 |
-
|
178 |
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
with st.spinner('Training the model...'):
|
183 |
-
time.sleep(2)
|
184 |
-
st.success("The model has been trained.")
|
185 |
-
else:
|
186 |
-
# Afficher le statut
|
187 |
-
st.info("The model hasn't been trained yet.")
|
188 |
|
189 |
-
|
190 |
-
if st.session_state.model_train:
|
191 |
-
st.markdown(" ")
|
192 |
-
st.markdown(" ")
|
193 |
-
st.markdown("### See the results ☑️")
|
194 |
-
tab1, tab2 = st.tabs(["Performance", "Explainability"])
|
195 |
-
|
196 |
-
with tab1:
|
197 |
-
#st.subheader("Performance")
|
198 |
-
st.info("""**Evaluating a model's performance** helps provide a quantitative measurement of it's ability to make accurate predictions.
|
199 |
-
In this use case, the performance of the brain tumor classification model was measured by comparing the patient's true diagnosis with the class predicted by the trained model.""")
|
200 |
-
|
201 |
-
class_accuracy_path = "data/image_classification/class_accuracies.pkl"
|
202 |
-
|
203 |
-
# Charger les données depuis le fichier Pickle
|
204 |
-
try:
|
205 |
-
with open(class_accuracy_path, 'rb') as file:
|
206 |
-
class_accuracy = pickle.load(file)
|
207 |
-
except Exception as e:
|
208 |
-
st.error(f"Erreur lors du chargement du fichier : {e}")
|
209 |
-
class_accuracy = {}
|
210 |
-
|
211 |
-
if not isinstance(class_accuracy, dict):
|
212 |
-
st.error(f"Expected a dictionary, but got: {type(class_accuracy)}")
|
213 |
-
else:
|
214 |
-
# Conversion des données en DataFrame
|
215 |
-
df_accuracy = pd.DataFrame(list(class_accuracy.items()), columns=['Tumor Type', 'Accuracy'])
|
216 |
-
df_accuracy['Accuracy'] = ((df_accuracy['Accuracy'] * 100).round()).astype(int)
|
217 |
-
|
218 |
-
# Générer le graphique à barres avec Plotly
|
219 |
-
fig = px.bar(df_accuracy, x='Tumor Type', y='Accuracy',
|
220 |
-
text='Accuracy', color='Tumor Type',
|
221 |
-
title="Model Performance",
|
222 |
-
labels={'Accuracy': 'Accuracy (%)', 'Tumor Type': 'Tumor Type'})
|
223 |
-
|
224 |
-
fig.update_traces(texttemplate='%{text}%', textposition='outside')
|
225 |
-
|
226 |
-
# Afficher le graphique dans Streamlit
|
227 |
-
st.plotly_chart(fig, use_container_width=True)
|
228 |
|
229 |
-
|
230 |
-
|
231 |
-
|
232 |
-
|
|
|
|
|
|
|
|
|
|
|
233 |
|
|
|
|
|
234 |
st.markdown(" ")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
235 |
|
236 |
-
|
237 |
-
|
238 |
-
|
239 |
-
|
240 |
-
|
241 |
-
|
242 |
-
|
243 |
-
|
244 |
-
|
245 |
-
|
246 |
-
|
247 |
-
|
248 |
-
|
249 |
-
Pixels that are colored in 'red' had a larger impact on the model's output and thus its ability to distinguish different tumor types (or none).
|
250 |
-
|
251 |
-
""", unsafe_allow_html=True)
|
252 |
|
253 |
-
|
254 |
-
|
255 |
-
|
256 |
-
|
257 |
-
|
258 |
-
st.
|
259 |
-
|
260 |
-
# st.markdown("""
|
261 |
-
# <b>Interpretation</b>: <br>
|
262 |
|
263 |
-
|
264 |
-
# **Meningiomas** are tumors that originate from the meninges, the layers of tissue
|
265 |
-
# that envelop the brain and spinal cord. Although they are most often benign
|
266 |
-
# (noncancerous) and grow slowly, their location can cause significant issues by
|
267 |
-
# exerting pressure on the brain or spinal cord. Meningiomas can occur at various
|
268 |
-
# places around the brain and spinal cord and are more common in women than in men.
|
269 |
-
|
270 |
-
# ### Pituitary Tumors <br>
|
271 |
-
# **Pituitary** are growths that develop in the pituitary gland, a small gland located at the
|
272 |
-
# base of the brain, behind the nose, and between the ears. Despite their critical location,
|
273 |
-
# the majority of pituitary tumors are benign and grow slowly. This gland regulates many of the
|
274 |
-
# hormones that control various body functions, so even a small tumor can affect hormone production,
|
275 |
-
# leading to a variety of symptoms.""", unsafe_allow_html=True)
|
276 |
-
|
277 |
-
|
278 |
-
#################################################
|
279 |
|
280 |
-
st.markdown("
|
281 |
-
|
282 |
-
|
283 |
-
|
284 |
-
|
285 |
-
|
286 |
-
|
287 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
288 |
|
|
|
289 |
|
290 |
-
|
291 |
-
|
292 |
|
293 |
-
# Selection des images
|
294 |
-
images = os.listdir(DATA_DIR)
|
295 |
-
selected_image2 = st.selectbox("Choose an image", images, key="selectionbox_key_1")
|
296 |
|
297 |
-
#
|
298 |
-
|
299 |
-
image = Image.open(image_path)
|
300 |
-
st.markdown("#### You've selected the following image.")
|
301 |
-
st.image(image, caption="Image selected", width=300)
|
302 |
|
|
|
|
|
|
|
303 |
|
304 |
-
|
305 |
-
|
306 |
-
|
307 |
-
|
308 |
-
|
309 |
-
# # Prétraitement et prédiction
|
310 |
-
# image_preprocessed = preprocess(image)
|
311 |
-
# predicted_tensor, _ = predict(image_preprocessed, model)
|
312 |
|
313 |
-
# predicted_idx = predicted_tensor.item()
|
314 |
-
# predicted_category = categories[predicted_idx]
|
315 |
|
316 |
-
|
317 |
-
|
318 |
-
|
319 |
-
|
320 |
-
|
321 |
-
|
322 |
-
|
323 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
324 |
|
325 |
-
|
326 |
-
|
327 |
-
|
328 |
|
329 |
-
|
330 |
|
|
|
5 |
import pandas as pd
|
6 |
import plotly.express as px
|
7 |
from PIL import Image
|
8 |
+
from utils import load_data_pickle, check_password
|
9 |
|
10 |
|
11 |
# import gradcam
|
|
|
72 |
|
73 |
###################################### TITLE ####################################
|
74 |
|
75 |
+
if check_password():
|
76 |
|
77 |
+
st.markdown("# Image Classification 🖼️")
|
|
|
|
|
|
|
|
|
|
|
|
|
78 |
|
79 |
+
st.markdown("### What is Image classification ?")
|
80 |
+
st.info("""**Image classification** is a process in Machine Learning and Computer Vision where an algorithm is trained to recognize and categorize images into predefined classes. It involves analyzing the visual content of an image and assigning it to a specific label based on its features.""")
|
81 |
+
#unsafe_allow_html=True)
|
82 |
+
st.markdown(" ")
|
83 |
+
st.markdown("""State-of-the-art image classification models use **neural networks** to predict whether an image belongs to a specific class.<br>
|
84 |
+
Each of the possible predicted classes are given a probability then the class with the highest value is assigned to the input image.""",
|
85 |
+
unsafe_allow_html=True)
|
86 |
+
|
87 |
+
image_ts = Image.open('images/cnn_example.png')
|
88 |
+
_, col, _ = st.columns([0.2,0.8,0.2])
|
89 |
+
with col:
|
90 |
+
st.image(image_ts,
|
91 |
+
caption="An example of an image classification model, with the 'backbone model' as the neural network.")
|
92 |
|
93 |
+
st.markdown(" ")
|
94 |
|
95 |
+
st.markdown("""Real-life applications of image classification includes:
|
96 |
- **Medical Imaging 👨⚕️**: Diagnose diseases and medical conditions from images such as X-rays, MRIs and CT scans to, for example, identify tumors and classify different types of cancers.
|
97 |
- **Autonomous Vehicules** 🏎️: Classify objects such as pedestrians, vehicles, traffic signs, lane markings, and obstacles, which is crucial for navigation and collision avoidance.
|
98 |
- **Satellite and Remote Sensing 🛰️**: Analyze satellite imagery to identify land use patterns, monitor vegetation health, assess environmental changes, and detect natural disasters such as wildfires and floods.
|
99 |
- **Quality Control 🛂**: Inspect products and identify defects to ensure compliance with quality standards during the manufacturying process.
|
100 |
+
""")
|
101 |
|
102 |
+
# st.markdown("""Real-life applications of Brain Tumor includes:
|
103 |
+
# - **Research and development💰**: The technologies and methodologies developed for brain tumor classification can advance research in neuroscience, oncology, and the development of new diagnostic tools and treatments.
|
104 |
+
# - **Healthcare👨⚕️**: Data derived from the classification and analysis of brain tumors can inform public health decisions, healthcare policies, and resource allocation, emphasizing areas with higher incidences of certain types of tumors.
|
105 |
+
# - **Insurance Industry 🏬**: Predict future demand for products to optimize inventory levels, reduce holding costs, and improve supply chain efficiency.
|
106 |
+
# """)
|
107 |
|
108 |
|
109 |
+
###################################### USE CASE #######################################
|
110 |
|
111 |
|
112 |
+
# BEGINNING OF USE CASE
|
113 |
+
st.divider()
|
114 |
+
st.markdown("# Brain Tumor Classification 🧠")
|
115 |
|
116 |
+
st.info("""In this use case, a **brain tumor classification** model is leveraged to accurately identify the presence of tumors in MRI scans of the brain.
|
117 |
+
This application can be a great resource for healthcare professionals to facilite early detection and consequently improve treatment outcomes for patients.""")
|
118 |
|
119 |
+
st.markdown(" ")
|
120 |
+
_, col, _ = st.columns([0.1,0.8,0.1])
|
121 |
+
with col:
|
122 |
+
st.image("images/brain_tumor.jpg")
|
123 |
|
124 |
+
st.markdown(" ")
|
125 |
+
st.markdown(" ")
|
126 |
|
127 |
+
### WHAT ARE BRAIN TUMORS ?
|
128 |
+
st.markdown(" ### What is a Brain Tumor ?")
|
129 |
+
st.markdown("""A brain tumor occurs when **abnormal cells form within the brain**. Two main types of tumors exist: **cancerous (malignant) tumors** and **benign tumors**.
|
|
|
130 |
- **Cancerous tumors** are malignant tumors that have the ability to invade nearby tissues and spread to other parts of the body through a process called metastasis.
|
131 |
- **Benign tumors** can become quite large but will not invade nearby tissue or spread to other parts of the body. They can still cause serious health problems depending on their size, location and rate of growth.
|
132 |
+
""", unsafe_allow_html=True)
|
|
|
|
|
133 |
|
|
|
|
|
|
|
134 |
|
|
|
|
|
135 |
|
136 |
+
st.markdown(" ")
|
137 |
+
st.markdown(" ")
|
138 |
+
st.markdown("### About the data 📋")
|
139 |
|
140 |
+
st.markdown("""You were provided with a large dataset which contains **anonymized patient MRI scans** categorized into three distinct classes: **pituitary tumor** (in most cases benign), **meningioma tumor** (cancerous) and **no tumor**.
|
141 |
+
This dataset will serve as the foundation for training our classification model, offering a comprehensive view of varied tumor presentations within the brain.""")
|
|
|
|
|
|
|
|
|
142 |
|
143 |
+
_, col, _ = st.columns([0.15,0.7,0.15])
|
144 |
+
with col:
|
145 |
+
st.image("images/tumors_types_class.png")
|
|
|
146 |
|
147 |
+
# see_data = st.checkbox('**See the data**', key="image_class\seedata")
|
148 |
+
# if see_data:
|
149 |
+
# st.warning("You can view here a few examples of the MRI training data.")
|
150 |
+
# # image selection
|
151 |
+
# images = os.listdir(DATA_DIR)
|
152 |
+
# selected_image1 = st.selectbox("Choose an image to visualize 🔎 :", images, key="selectionbox_key_2")
|
153 |
|
154 |
+
# # show image
|
155 |
+
# image_path = os.path.join(DATA_DIR, selected_image1)
|
156 |
+
# image = Image.open(image_path)
|
157 |
+
# st.image(image, caption="Image selected", width=450)
|
158 |
|
159 |
+
# st.info("""**Note**: This dataset will serve as the foundation for training our classification model, offering a comprehensive view of varied tumor presentations within the brain.
|
160 |
+
# By analyzing these images, the model learns to discern the subtle differences between each class, thereby enabling the precise identification of tumor types.""")
|
161 |
|
162 |
+
st.markdown(" ")
|
163 |
+
st.markdown(" ")
|
164 |
|
|
|
|
|
|
|
165 |
|
166 |
|
167 |
+
st.markdown("### Train the algorithm ⚙️")
|
168 |
+
st.markdown("""**Training an AI model** means feeding it data that contains multiple examples/images each type of tumor to be detected.
|
169 |
+
By analyzing the provided MRI images, the model learns to discern the subtle differences between each classes, thereby enabling the precise identification of tumor types.""")
|
170 |
|
|
|
|
|
|
|
171 |
|
172 |
+
### CONDITION ##
|
173 |
|
174 |
+
# Initialisation de l'état du modèle
|
175 |
+
if 'model_train' not in st.session_state:
|
176 |
+
st.session_state['model_train'] = False
|
|
|
|
|
|
|
|
|
|
|
|
|
177 |
|
178 |
+
run_model = st.button("Train the model")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
179 |
|
180 |
+
if run_model:
|
181 |
+
# Simuler l'entraînement du modèle
|
182 |
+
st.session_state.model_train = True
|
183 |
+
with st.spinner('Training the model...'):
|
184 |
+
time.sleep(2)
|
185 |
+
st.success("The model has been trained.")
|
186 |
+
else:
|
187 |
+
# Afficher le statut
|
188 |
+
st.info("The model hasn't been trained yet.")
|
189 |
|
190 |
+
# Afficher les résultats
|
191 |
+
if st.session_state.model_train:
|
192 |
st.markdown(" ")
|
193 |
+
st.markdown(" ")
|
194 |
+
st.markdown("### See the results ☑️")
|
195 |
+
tab1, tab2 = st.tabs(["Performance", "Explainability"])
|
196 |
+
|
197 |
+
with tab1:
|
198 |
+
#st.subheader("Performance")
|
199 |
+
st.info("""**Evaluating a model's performance** helps provide a quantitative measurement of it's ability to make accurate predictions.
|
200 |
+
In this use case, the performance of the brain tumor classification model was measured by comparing the patient's true diagnosis with the class predicted by the trained model.""")
|
201 |
+
|
202 |
+
class_accuracy_path = "data/image_classification/class_accuracies.pkl"
|
203 |
+
|
204 |
+
# Charger les données depuis le fichier Pickle
|
205 |
+
try:
|
206 |
+
with open(class_accuracy_path, 'rb') as file:
|
207 |
+
class_accuracy = pickle.load(file)
|
208 |
+
except Exception as e:
|
209 |
+
st.error(f"Erreur lors du chargement du fichier : {e}")
|
210 |
+
class_accuracy = {}
|
211 |
+
|
212 |
+
if not isinstance(class_accuracy, dict):
|
213 |
+
st.error(f"Expected a dictionary, but got: {type(class_accuracy)}")
|
214 |
+
else:
|
215 |
+
# Conversion des données en DataFrame
|
216 |
+
df_accuracy = pd.DataFrame(list(class_accuracy.items()), columns=['Tumor Type', 'Accuracy'])
|
217 |
+
df_accuracy['Accuracy'] = ((df_accuracy['Accuracy'] * 100).round()).astype(int)
|
218 |
+
|
219 |
+
# Générer le graphique à barres avec Plotly
|
220 |
+
fig = px.bar(df_accuracy, x='Tumor Type', y='Accuracy',
|
221 |
+
text='Accuracy', color='Tumor Type',
|
222 |
+
title="Model Performance",
|
223 |
+
labels={'Accuracy': 'Accuracy (%)', 'Tumor Type': 'Tumor Type'})
|
224 |
+
|
225 |
+
fig.update_traces(texttemplate='%{text}%', textposition='outside')
|
226 |
+
|
227 |
+
# Afficher le graphique dans Streamlit
|
228 |
+
st.plotly_chart(fig, use_container_width=True)
|
229 |
|
230 |
+
|
231 |
+
st.markdown("""<i>The model's accuracy was evaluated across two types of tumors (pituitary and meningioma) and no tumor type.</i>
|
232 |
+
<i>This evaluation is vital for determining if the model performs consistently across different tumor classifications, or if it encounters difficulties accurately distinguishing between two classes.""",
|
233 |
+
unsafe_allow_html=True)
|
234 |
+
|
235 |
+
st.markdown(" ")
|
236 |
+
|
237 |
+
st.markdown("""**Interpretation**: <br>
|
238 |
+
Our model demonstrates high accuracy in predicting cancerous type tumors (meningioma) as well as 'healthy' brain scans (no tumor) with a 98% accuracy for both.
|
239 |
+
It is observed that the model's performance is lower for pituitary type tumors, as it is around 81%.
|
240 |
+
This discrepancy may indicate that the model finds it more challenging to distinguish pituitary tumors from other tumor
|
241 |
+
types, possibly due to their unique characteristics or lower representation in the training data.
|
242 |
+
""", unsafe_allow_html=True)
|
|
|
|
|
|
|
243 |
|
244 |
+
with tab2:
|
245 |
+
#st.subheader("Model Explainability with Grad-CAM")
|
246 |
+
st.info("""**Explainability in AI** refers to the ability to **understand and interpret how AI systems make predictions** and how to quantify the impact of the provided data on its results.
|
247 |
+
In the case of image classification, explainability can be measured by analyzing which of the image's pixel had the most impact on the model's output.""")
|
248 |
+
st.markdown(" ")
|
249 |
+
st.markdown("""The following images show the output of image classification explainability applied on three images used during training. <br>
|
250 |
+
Pixels that are colored in 'red' had a larger impact on the model's output and thus its ability to distinguish different tumor types (or none).
|
|
|
|
|
251 |
|
252 |
+
""", unsafe_allow_html=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
253 |
|
254 |
+
st.markdown(" ")
|
255 |
+
gradcam_images_paths = ["images/meningioma_tumor.png", "images/no_tumor.png", "images/pituitary.png"]
|
256 |
+
class_names = ["Meningioma Tumor", "No Tumor", "Pituitary Tumor"]
|
257 |
+
|
258 |
+
for path, class_name in zip(gradcam_images_paths, class_names):
|
259 |
+
st.image(path, caption=f"Explainability for {class_name}")
|
260 |
+
|
261 |
+
# st.markdown("""
|
262 |
+
# <b>Interpretation</b>: <br>
|
263 |
+
|
264 |
+
# ### Meningioma Tumors <br>
|
265 |
+
# **Meningiomas** are tumors that originate from the meninges, the layers of tissue
|
266 |
+
# that envelop the brain and spinal cord. Although they are most often benign
|
267 |
+
# (noncancerous) and grow slowly, their location can cause significant issues by
|
268 |
+
# exerting pressure on the brain or spinal cord. Meningiomas can occur at various
|
269 |
+
# places around the brain and spinal cord and are more common in women than in men.
|
270 |
+
|
271 |
+
# ### Pituitary Tumors <br>
|
272 |
+
# **Pituitary** are growths that develop in the pituitary gland, a small gland located at the
|
273 |
+
# base of the brain, behind the nose, and between the ears. Despite their critical location,
|
274 |
+
# the majority of pituitary tumors are benign and grow slowly. This gland regulates many of the
|
275 |
+
# hormones that control various body functions, so even a small tumor can affect hormone production,
|
276 |
+
# leading to a variety of symptoms.""", unsafe_allow_html=True)
|
277 |
+
|
278 |
+
|
279 |
+
#################################################
|
280 |
+
|
281 |
+
st.markdown(" ")
|
282 |
+
st.markdown(" ")
|
283 |
+
st.markdown("### Classify new MRI scans 🆕")
|
284 |
|
285 |
+
st.info("**Note**: The brain tumor classification model can classify new MRI images only if it has been previously trained.")
|
286 |
|
287 |
+
st.markdown("""Here, you are provided the MRI scans of nine new patients.
|
288 |
+
Select an image and press 'run the model' to classify the MRI as either a pituitary tumor, a meningioma tumor or no tumor.""")
|
289 |
|
|
|
|
|
|
|
290 |
|
291 |
+
# Définition des catégories de tumeurs
|
292 |
+
categories = ["pituitary tumor", "no tumor", "meningioma tumor"]
|
|
|
|
|
|
|
293 |
|
294 |
+
# Selection des images
|
295 |
+
images = os.listdir(DATA_DIR)
|
296 |
+
selected_image2 = st.selectbox("Choose an image", images, key="selectionbox_key_1")
|
297 |
|
298 |
+
# show image
|
299 |
+
image_path = os.path.join(DATA_DIR, selected_image2)
|
300 |
+
image = Image.open(image_path)
|
301 |
+
st.markdown("#### You've selected the following image.")
|
302 |
+
st.image(image, caption="Image selected", width=300)
|
|
|
|
|
|
|
303 |
|
|
|
|
|
304 |
|
305 |
+
if st.button('**Make predictions**', key='another_action_button'):
|
306 |
+
results_path = r"data/image_classification"
|
307 |
+
df_results = load_data_pickle(results_path, "results.pkl")
|
308 |
+
predicted_category = df_results.loc[df_results["image"]==selected_image2,"class"].to_numpy()
|
309 |
+
|
310 |
+
# # Prétraitement et prédiction
|
311 |
+
# image_preprocessed = preprocess(image)
|
312 |
+
# predicted_tensor, _ = predict(image_preprocessed, model)
|
313 |
+
|
314 |
+
# predicted_idx = predicted_tensor.item()
|
315 |
+
# predicted_category = categories[predicted_idx]
|
316 |
+
|
317 |
+
# Affichage de la prédiction avec la catégorie prédite
|
318 |
+
if predicted_category == "pituitary":
|
319 |
+
st.warning(f"**Results**: Pituitary tumor was detected. ")
|
320 |
+
elif predicted_category == "no tumor":
|
321 |
+
st.success(f"**Results**: No tumor was detected.")
|
322 |
+
elif predicted_category == "meningnoma":
|
323 |
+
st.error(f"**Results**: Meningioma was detected.")
|
324 |
+
|
325 |
|
326 |
+
# image_path = os.path.join(DATA_DIR, selected_image2)
|
327 |
+
# image = Image.open(image_path)
|
328 |
+
# st.image(image, caption="Image selected", width=450)
|
329 |
|
330 |
+
|
331 |
|
pages/object_detection.py
CHANGED
@@ -8,6 +8,7 @@ import plotly.express as px
|
|
8 |
import pickle
|
9 |
import random
|
10 |
|
|
|
11 |
from PIL import Image
|
12 |
from transformers import YolosFeatureExtractor, YolosForObjectDetection
|
13 |
from torchvision.transforms import ToTensor, ToPILImage
|
@@ -137,257 +138,257 @@ cats = ['shirt, blouse', 'top, t-shirt, sweatshirt', 'sweater', 'cardigan', 'jac
|
|
137 |
|
138 |
######################################################################################################################################
|
139 |
|
140 |
-
|
141 |
-
st.markdown("# Object Detection 📹")
|
142 |
|
143 |
-
st.markdown("### What is Object Detection ?")
|
144 |
-
|
145 |
-
#st.markdown("""Object detection involves **identifying** and **locating objects** within an image or video frame through bounding boxes. """)
|
146 |
-
st.info("""Object Detection is a computer vision task in which the goal is to **detect** and **locate objects** of interest in an image or video.
|
147 |
-
|
148 |
|
149 |
|
150 |
-
st.markdown("Here is an example of Object Detection for Traffic Analysis.")
|
151 |
-
#image_od = Image.open('images/od_2.png')
|
152 |
-
#st.image(image_od, width=600)
|
153 |
-
st.video(data='https://www.youtube.com/watch?v=PVCGDoTZHaI')
|
154 |
|
155 |
-
st.markdown(" ")
|
156 |
|
157 |
-
st.markdown("""Common applications of Object Detection include:
|
158 |
- **Autonomous Vehicles** :car: : Object detection is crucial for self-driving cars to track pedestrians, cyclists, other vehicles, and obstacles on the road.
|
159 |
- **Retail** 🏬 : Implementing smart shelves and checkout systems that use object detection to track inventory and monitor stock levels.
|
160 |
- **Healthcare** 👨⚕️: Detecting and tracking anomalies in medical images, such as tumors or abnormalities, for diagnostic purposes or prevention.
|
161 |
- **Manufacturing** 🏭: Quality control on production lines by detecting defects or irregularities in manufactured products. Ensuring workplace safety by monitoring the movement of workers and equipment.
|
162 |
-
""")
|
163 |
-
|
164 |
|
165 |
|
166 |
-
############################# USE CASE #############################
|
167 |
-
st.markdown(" ")
|
168 |
-
st.divider()
|
169 |
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
st.info("""**Object detection models** can very valuable for fashion retailers wishing to improve customer experience. They can provide, for example, **product recognition**, **visual search**
|
174 |
-
and even **virtual try-ons**.""")
|
175 |
|
176 |
-
st.markdown("
|
|
|
|
|
|
|
|
|
177 |
|
178 |
-
st.markdown("
|
179 |
-
st.markdown(" ")
|
180 |
|
181 |
-
|
182 |
-
|
183 |
-
# for img, col in zip(images_dior,columns_img):
|
184 |
-
# with col:
|
185 |
-
# st.image(img)
|
186 |
|
187 |
-
|
188 |
-
|
189 |
-
|
|
|
|
|
190 |
|
|
|
|
|
|
|
191 |
|
192 |
-
st.markdown(" ")
|
193 |
-
st.markdown(" ")
|
194 |
|
|
|
|
|
195 |
|
196 |
-
st.markdown("### About the model 📚")
|
197 |
-
st.markdown("""The object detection model was trained to **detect specific clothing items** on images. <br>
|
198 |
-
Below is a list of the <b>46</b> different types of clothing items the model can identify and locate.""", unsafe_allow_html=True)
|
199 |
|
200 |
-
|
|
|
|
|
201 |
|
202 |
-
|
203 |
-
annotated_text([cats_annotated])
|
204 |
|
205 |
-
|
206 |
-
|
207 |
-
# 'coat', 'dress', 'jumpsuit', 'cape', 'glasses', 'hat', 'headband, head covering, hair accessory', 'tie', 'glove', 'watch',
|
208 |
-
# 'belt', 'leg warmer', 'tights, stockings', 'sock', 'shoe', 'bag, wallet', 'scarf', 'umbrella', 'hood', 'collar', 'lapel',
|
209 |
-
# 'epaulette', 'sleeve', 'pocket', 'neckline', 'buckle', 'zipper', 'applique', 'bead', 'bow', 'flower', 'fringe', 'ribbon', 'rivet',
|
210 |
-
# 'ruffle', 'sequin', 'tassel'""", unsafe_allow_html=True)
|
211 |
|
212 |
-
st.markdown("
|
213 |
-
|
214 |
-
|
|
|
|
|
|
|
215 |
|
|
|
|
|
|
|
216 |
|
217 |
|
218 |
-
############## SELECT AN IMAGE ###############
|
219 |
|
220 |
-
|
221 |
-
st.markdown("""The images provided were taken from **Dior's 2020 Fall Women Fashion Show**""")
|
222 |
|
223 |
-
|
224 |
-
|
225 |
-
list_images = os.listdir(fashion_images_path)
|
226 |
-
image_name = st.selectbox("Choose an image", list_images)
|
227 |
-
image_ = os.path.join(fashion_images_path, image_name)
|
228 |
-
st.image(image_, width=300)
|
229 |
|
|
|
|
|
|
|
|
|
|
|
|
|
230 |
|
231 |
-
# image_ = None
|
232 |
-
# select_image_box = st.radio(
|
233 |
-
# "**Select the image you wish to run the model on**",
|
234 |
-
# ["Choose an existing image", "Load your own image"],
|
235 |
-
# index=None,)# #label_visibility="collapsed")
|
236 |
|
237 |
-
#
|
238 |
-
#
|
239 |
-
#
|
240 |
-
#
|
241 |
-
|
242 |
-
# if image_ is not None:
|
243 |
-
# image_ = os.path.join(fashion_images_path,image_)
|
244 |
-
# st.markdown("You've selected the following image:")
|
245 |
-
# st.image(image_, width=300)
|
246 |
-
|
247 |
-
# elif select_image_box == "Load your own image":
|
248 |
-
# image_ = st.file_uploader("Load an image here",
|
249 |
-
# key="OD_dior", type=['jpg','jpeg','png'], label_visibility="collapsed")
|
250 |
-
|
251 |
-
# st.warning("""**Note**: The model tends to perform better with images of people/clothing items facing forward.
|
252 |
-
# Choose this type of image if you want optimal results.""")
|
253 |
-
# st.warning("""**Note:** The model was trained to detect clothing items on a single person.
|
254 |
-
# If your image contains more than one person, the model won't detect the items of the other persons.""")
|
255 |
|
256 |
-
#
|
257 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
258 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
259 |
|
260 |
-
|
261 |
-
st.
|
262 |
|
263 |
|
|
|
|
|
264 |
|
265 |
-
########## SELECT AN ELEMENT TO DETECT ##################
|
266 |
|
267 |
|
268 |
-
|
269 |
|
270 |
-
# st.markdown("#### Choose the elements you want to detect 👉")
|
271 |
|
272 |
-
|
273 |
-
# container = st.container()
|
274 |
-
# selected_options = None
|
275 |
-
# all = st.checkbox("Select all")
|
276 |
|
277 |
-
#
|
278 |
-
# selected_options = container.multiselect("**Select one or more items**", cats, cats)
|
279 |
-
# else:
|
280 |
-
# selected_options = container.multiselect("**Select one or more items**", cats)
|
281 |
|
282 |
-
#
|
283 |
-
|
284 |
-
|
|
|
285 |
|
|
|
|
|
|
|
|
|
286 |
|
287 |
-
#
|
288 |
-
|
|
|
289 |
|
290 |
|
|
|
|
|
291 |
|
292 |
-
############## SELECT A THRESHOLD ###############
|
293 |
|
294 |
-
st.markdown("### Define a threshold for predictions 🔎")
|
295 |
-
st.markdown("""In this section, you can select a threshold for the model's final predictions. <br>
|
296 |
-
Objects that are given a lower score than the chosen threshold will be ignored in the final results""", unsafe_allow_html=True)
|
297 |
|
298 |
-
|
299 |
-
Each object is given a class based on a probability score computed by the model. A high probability signals that the model is confident in its prediction.
|
300 |
-
On the contrary, a lower probability score signals a level of uncertainty.""")
|
301 |
|
302 |
-
st.markdown(" ")
|
303 |
-
|
|
|
304 |
|
305 |
-
|
306 |
-
|
307 |
-
|
308 |
-
caption="Examples of object detection with bounding boses and probability scores")
|
309 |
|
310 |
-
st.markdown(" ")
|
|
|
311 |
|
312 |
-
st.
|
|
|
|
|
|
|
313 |
|
314 |
-
|
315 |
-
# Elements that are identified with a lower probability than the given threshold will be ignored in the final results.""")
|
316 |
|
317 |
-
|
318 |
|
|
|
|
|
319 |
|
320 |
-
|
321 |
-
# st.error("""**Warning**: Selecting a low threshold (below 0.6) could lead the model to make errors and detect too many objects.""")
|
322 |
|
323 |
-
st.write("You've selected a threshold at", threshold)
|
324 |
-
st.markdown(" ")
|
325 |
|
|
|
|
|
326 |
|
|
|
|
|
327 |
|
328 |
-
pickle_file_path = r"data/dior_show/results"
|
329 |
|
330 |
|
331 |
-
|
332 |
|
333 |
-
run_model = st.button("**Run the model**", type="primary")
|
334 |
|
335 |
-
|
336 |
-
if image_ != None and selected_options != None and threshold!= None:
|
337 |
-
with st.spinner('Wait for it...'):
|
338 |
-
## SELECT IMAGE
|
339 |
-
#st.write(image_)
|
340 |
-
image = Image.open(image_)
|
341 |
-
image = fix_channels(ToTensor()(image))
|
342 |
|
343 |
-
|
344 |
-
FEATURE_EXTRACTOR_PATH = "hustvl/yolos-small"
|
345 |
-
MODEL_PATH = "valentinafeve/yolos-fashionpedia"
|
346 |
-
# feature_extractor, model = load_model(FEATURE_EXTRACTOR_PATH, MODEL_PATH)
|
347 |
|
348 |
-
|
349 |
-
|
350 |
-
|
351 |
-
|
352 |
-
|
353 |
-
|
354 |
-
|
355 |
-
# with open(os.path.join(pickle_file_path, f"{image_name}_results.pkl"), 'wb') as file:
|
356 |
-
# pickle.dump(outputs, file)
|
357 |
-
|
358 |
-
image_name = image_name[:5]
|
359 |
-
path_load_pickle = os.path.join(pickle_file_path, f"{image_name}_results.pkl")
|
360 |
-
with open(path_load_pickle, 'rb') as pickle_file:
|
361 |
-
outputs = pickle.load(pickle_file)
|
362 |
-
|
363 |
-
probas, keep = return_probas(outputs, threshold)
|
364 |
|
365 |
-
|
|
|
|
|
|
|
366 |
|
367 |
-
|
368 |
-
|
369 |
-
|
370 |
-
|
371 |
-
|
372 |
-
|
373 |
-
|
374 |
-
|
375 |
-
|
376 |
-
|
377 |
-
|
378 |
-
|
379 |
-
|
380 |
-
|
381 |
-
|
382 |
-
|
383 |
-
st.info("""**Note**: Some items might have been detected more than once on the image.
|
384 |
-
For these items, we've computed the average probability score across all detections.""")
|
385 |
-
visualize_probas(probas, threshold, colors_used)
|
386 |
|
387 |
-
|
388 |
|
389 |
-
|
390 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
391 |
|
392 |
|
393 |
|
|
|
8 |
import pickle
|
9 |
import random
|
10 |
|
11 |
+
from utils import check_password
|
12 |
from PIL import Image
|
13 |
from transformers import YolosFeatureExtractor, YolosForObjectDetection
|
14 |
from torchvision.transforms import ToTensor, ToPILImage
|
|
|
138 |
|
139 |
######################################################################################################################################
|
140 |
|
141 |
+
if check_password():
|
142 |
+
st.markdown("# Object Detection 📹")
|
143 |
|
144 |
+
st.markdown("### What is Object Detection ?")
|
145 |
+
|
146 |
+
#st.markdown("""Object detection involves **identifying** and **locating objects** within an image or video frame through bounding boxes. """)
|
147 |
+
st.info("""Object Detection is a computer vision task in which the goal is to **detect** and **locate objects** of interest in an image or video.
|
148 |
+
The task involves identifying the position and boundaries of objects (or **bounding boxes**) in an image, and classifying the objects into different categories.""")
|
149 |
|
150 |
|
151 |
+
st.markdown("Here is an example of Object Detection for Traffic Analysis.")
|
152 |
+
#image_od = Image.open('images/od_2.png')
|
153 |
+
#st.image(image_od, width=600)
|
154 |
+
st.video(data='https://www.youtube.com/watch?v=PVCGDoTZHaI')
|
155 |
|
156 |
+
st.markdown(" ")
|
157 |
|
158 |
+
st.markdown("""Common applications of Object Detection include:
|
159 |
- **Autonomous Vehicles** :car: : Object detection is crucial for self-driving cars to track pedestrians, cyclists, other vehicles, and obstacles on the road.
|
160 |
- **Retail** 🏬 : Implementing smart shelves and checkout systems that use object detection to track inventory and monitor stock levels.
|
161 |
- **Healthcare** 👨⚕️: Detecting and tracking anomalies in medical images, such as tumors or abnormalities, for diagnostic purposes or prevention.
|
162 |
- **Manufacturing** 🏭: Quality control on production lines by detecting defects or irregularities in manufactured products. Ensuring workplace safety by monitoring the movement of workers and equipment.
|
163 |
+
""")
|
|
|
164 |
|
165 |
|
|
|
|
|
|
|
166 |
|
167 |
+
############################# USE CASE #############################
|
168 |
+
st.markdown(" ")
|
169 |
+
st.divider()
|
|
|
|
|
170 |
|
171 |
+
st.markdown("# Fashion Object Detection 👗")
|
172 |
+
# st.info("""This use case showcases the application of **Object detection** to detect clothing items/features on images. <br>
|
173 |
+
# The images used were gathered from Dior's""")
|
174 |
+
st.info("""**Object detection models** can very valuable for fashion retailers wishing to improve customer experience. They can provide, for example, **product recognition**, **visual search**
|
175 |
+
and even **virtual try-ons**.""")
|
176 |
|
177 |
+
st.markdown("In this use case, we are going to show an object detection model that as able to identify and locate different articles of clothings on fashion show images.")
|
|
|
178 |
|
179 |
+
st.markdown(" ")
|
180 |
+
st.markdown(" ")
|
|
|
|
|
|
|
181 |
|
182 |
+
# images_dior = [os.path.join("data/dior_show/images",url) for url in os.listdir("data/dior_show/images") if url != "results"]
|
183 |
+
# columns_img = st.columns(4)
|
184 |
+
# for img, col in zip(images_dior,columns_img):
|
185 |
+
# with col:
|
186 |
+
# st.image(img)
|
187 |
|
188 |
+
_, col, _ = st.columns([0.1,0.8,0.1])
|
189 |
+
with col:
|
190 |
+
st.image("images/fashion_od2.png")
|
191 |
|
|
|
|
|
192 |
|
193 |
+
st.markdown(" ")
|
194 |
+
st.markdown(" ")
|
195 |
|
|
|
|
|
|
|
196 |
|
197 |
+
st.markdown("### About the model 📚")
|
198 |
+
st.markdown("""The object detection model was trained to **detect specific clothing items** on images. <br>
|
199 |
+
Below is a list of the <b>46</b> different types of clothing items the model can identify and locate.""", unsafe_allow_html=True)
|
200 |
|
201 |
+
colors = ["#8ef", "#faa", "#afa", "#fea", "#8ef","#afa"]*7 + ["#8ef", "#faa", "#afa", "#fea"]
|
|
|
202 |
|
203 |
+
cats_annotated = [(g,"","#afa") for g in cats]
|
204 |
+
annotated_text([cats_annotated])
|
|
|
|
|
|
|
|
|
205 |
|
206 |
+
# st.markdown("""**Here are the 'objects' the model is able to detect**: <br>
|
207 |
+
# 'shirt, blouse', 'top, t-shirt, sweatshirt', 'sweater', 'cardigan', 'jacket', 'vest', 'pants', 'shorts', 'skirt',
|
208 |
+
# 'coat', 'dress', 'jumpsuit', 'cape', 'glasses', 'hat', 'headband, head covering, hair accessory', 'tie', 'glove', 'watch',
|
209 |
+
# 'belt', 'leg warmer', 'tights, stockings', 'sock', 'shoe', 'bag, wallet', 'scarf', 'umbrella', 'hood', 'collar', 'lapel',
|
210 |
+
# 'epaulette', 'sleeve', 'pocket', 'neckline', 'buckle', 'zipper', 'applique', 'bead', 'bow', 'flower', 'fringe', 'ribbon', 'rivet',
|
211 |
+
# 'ruffle', 'sequin', 'tassel'""", unsafe_allow_html=True)
|
212 |
|
213 |
+
st.markdown("Credits for the model: https://huggingface.co/valentinafeve/yolos-fashionpedia")
|
214 |
+
st.markdown("")
|
215 |
+
st.markdown("")
|
216 |
|
217 |
|
|
|
218 |
|
219 |
+
############## SELECT AN IMAGE ###############
|
|
|
220 |
|
221 |
+
st.markdown("### Select an image 🖼️")
|
222 |
+
st.markdown("""The images provided were taken from **Dior's 2020 Fall Women Fashion Show**""")
|
|
|
|
|
|
|
|
|
223 |
|
224 |
+
image_ = None
|
225 |
+
fashion_images_path = r"data/dior_show/images"
|
226 |
+
list_images = os.listdir(fashion_images_path)
|
227 |
+
image_name = st.selectbox("Choose an image", list_images)
|
228 |
+
image_ = os.path.join(fashion_images_path, image_name)
|
229 |
+
st.image(image_, width=300)
|
230 |
|
|
|
|
|
|
|
|
|
|
|
231 |
|
232 |
+
# image_ = None
|
233 |
+
# select_image_box = st.radio(
|
234 |
+
# "**Select the image you wish to run the model on**",
|
235 |
+
# ["Choose an existing image", "Load your own image"],
|
236 |
+
# index=None,)# #label_visibility="collapsed")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
237 |
|
238 |
+
# if select_image_box == "Choose an existing image":
|
239 |
+
# fashion_images_path = r"data/dior_show/images"
|
240 |
+
# list_images = os.listdir(fashion_images_path)
|
241 |
+
# image_ = st.selectbox("", list_images, label_visibility="collapsed")
|
242 |
+
|
243 |
+
# if image_ is not None:
|
244 |
+
# image_ = os.path.join(fashion_images_path,image_)
|
245 |
+
# st.markdown("You've selected the following image:")
|
246 |
+
# st.image(image_, width=300)
|
247 |
|
248 |
+
# elif select_image_box == "Load your own image":
|
249 |
+
# image_ = st.file_uploader("Load an image here",
|
250 |
+
# key="OD_dior", type=['jpg','jpeg','png'], label_visibility="collapsed")
|
251 |
+
|
252 |
+
# st.warning("""**Note**: The model tends to perform better with images of people/clothing items facing forward.
|
253 |
+
# Choose this type of image if you want optimal results.""")
|
254 |
+
# st.warning("""**Note:** The model was trained to detect clothing items on a single person.
|
255 |
+
# If your image contains more than one person, the model won't detect the items of the other persons.""")
|
256 |
|
257 |
+
# if image_ is not None:
|
258 |
+
# st.image(Image.open(image_), width=300)
|
259 |
|
260 |
|
261 |
+
st.markdown(" ")
|
262 |
+
st.markdown(" ")
|
263 |
|
|
|
264 |
|
265 |
|
266 |
+
########## SELECT AN ELEMENT TO DETECT ##################
|
267 |
|
|
|
268 |
|
269 |
+
dict_cats = dict(zip(np.arange(len(cats)), cats))
|
|
|
|
|
|
|
270 |
|
271 |
+
# st.markdown("#### Choose the elements you want to detect 👉")
|
|
|
|
|
|
|
272 |
|
273 |
+
# # Select one or more elements to detect
|
274 |
+
# container = st.container()
|
275 |
+
# selected_options = None
|
276 |
+
# all = st.checkbox("Select all")
|
277 |
|
278 |
+
# if all:
|
279 |
+
# selected_options = container.multiselect("**Select one or more items**", cats, cats)
|
280 |
+
# else:
|
281 |
+
# selected_options = container.multiselect("**Select one or more items**", cats)
|
282 |
|
283 |
+
#cats = selected_options
|
284 |
+
selected_options = cats
|
285 |
+
dict_cats_final = {key:value for (key,value) in dict_cats.items() if value in selected_options}
|
286 |
|
287 |
|
288 |
+
# st.markdown(" ")
|
289 |
+
# st.markdown(" ")
|
290 |
|
|
|
291 |
|
|
|
|
|
|
|
292 |
|
293 |
+
############## SELECT A THRESHOLD ###############
|
|
|
|
|
294 |
|
295 |
+
st.markdown("### Define a threshold for predictions 🔎")
|
296 |
+
st.markdown("""In this section, you can select a threshold for the model's final predictions. <br>
|
297 |
+
Objects that are given a lower score than the chosen threshold will be ignored in the final results""", unsafe_allow_html=True)
|
298 |
|
299 |
+
st.info("""**Note**: Object detection models detect objects using bounding boxes as well as assign objects to specific classes.
|
300 |
+
Each object is given a class based on a probability score computed by the model. A high probability signals that the model is confident in its prediction.
|
301 |
+
On the contrary, a lower probability score signals a level of uncertainty.""")
|
|
|
302 |
|
303 |
+
st.markdown(" ")
|
304 |
+
#st.markdown("The images below are examples of probability scores given by object detection models for each element detected.")
|
305 |
|
306 |
+
_, col, _ = st.columns([0.2,0.6,0.2])
|
307 |
+
with col:
|
308 |
+
st.image("images/probability_od.png",
|
309 |
+
caption="Examples of object detection with bounding boses and probability scores")
|
310 |
|
311 |
+
st.markdown(" ")
|
|
|
312 |
|
313 |
+
st.markdown("**Select a threshold** ")
|
314 |
|
315 |
+
# st.warning("""**Note**: The threshold helps you decide how confident you want your model to be with its predictions.
|
316 |
+
# Elements that are identified with a lower probability than the given threshold will be ignored in the final results.""")
|
317 |
|
318 |
+
threshold = st.slider('**Select a threshold**', min_value=0.5, step=0.05, max_value=1.0, value=0.75, label_visibility="collapsed")
|
|
|
319 |
|
|
|
|
|
320 |
|
321 |
+
# if threshold < 0.6:
|
322 |
+
# st.error("""**Warning**: Selecting a low threshold (below 0.6) could lead the model to make errors and detect too many objects.""")
|
323 |
|
324 |
+
st.write("You've selected a threshold at", threshold)
|
325 |
+
st.markdown(" ")
|
326 |
|
|
|
327 |
|
328 |
|
329 |
+
pickle_file_path = r"data/dior_show/results"
|
330 |
|
|
|
331 |
|
332 |
+
############# RUN MODEL ################
|
|
|
|
|
|
|
|
|
|
|
|
|
333 |
|
334 |
+
run_model = st.button("**Run the model**", type="primary")
|
|
|
|
|
|
|
335 |
|
336 |
+
if run_model:
|
337 |
+
if image_ != None and selected_options != None and threshold!= None:
|
338 |
+
with st.spinner('Wait for it...'):
|
339 |
+
## SELECT IMAGE
|
340 |
+
#st.write(image_)
|
341 |
+
image = Image.open(image_)
|
342 |
+
image = fix_channels(ToTensor()(image))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
343 |
|
344 |
+
## LOAD OBJECT DETECTION MODEL
|
345 |
+
FEATURE_EXTRACTOR_PATH = "hustvl/yolos-small"
|
346 |
+
MODEL_PATH = "valentinafeve/yolos-fashionpedia"
|
347 |
+
# feature_extractor, model = load_model(FEATURE_EXTRACTOR_PATH, MODEL_PATH)
|
348 |
|
349 |
+
# # RUN MODEL ON IMAGE
|
350 |
+
# inputs = feature_extractor(images=image, return_tensors="pt")
|
351 |
+
# outputs = model(**inputs)
|
352 |
+
|
353 |
+
# Save results
|
354 |
+
# pickle_file_path = r"data/dior_show/results"
|
355 |
+
# image_name = image_.split('\\')[1][:5]
|
356 |
+
# with open(os.path.join(pickle_file_path, f"{image_name}_results.pkl"), 'wb') as file:
|
357 |
+
# pickle.dump(outputs, file)
|
358 |
+
|
359 |
+
image_name = image_name[:5]
|
360 |
+
path_load_pickle = os.path.join(pickle_file_path, f"{image_name}_results.pkl")
|
361 |
+
with open(path_load_pickle, 'rb') as pickle_file:
|
362 |
+
outputs = pickle.load(pickle_file)
|
363 |
+
|
364 |
+
probas, keep = return_probas(outputs, threshold)
|
|
|
|
|
|
|
365 |
|
366 |
+
st.markdown("#### See the results ☑️")
|
367 |
|
368 |
+
# PLOT BOUNDING BOX AND BARS/PROBA
|
369 |
+
col1, col2 = st.columns(2)
|
370 |
+
with col1:
|
371 |
+
st.markdown(" ")
|
372 |
+
st.markdown("##### 1. Bounding box results")
|
373 |
+
bboxes_scaled = rescale_bboxes(outputs.pred_boxes[0, keep].cpu(), image.size)
|
374 |
+
colors_used = plot_results(image, probas[keep], bboxes_scaled)
|
375 |
+
|
376 |
+
with col2:
|
377 |
+
#st.markdown("**Probability scores**")
|
378 |
+
if not any(keep.tolist()):
|
379 |
+
st.error("""No objects were detected on the image.
|
380 |
+
Decrease your threshold or choose differents items to detect.""")
|
381 |
+
else:
|
382 |
+
st.markdown(" ")
|
383 |
+
st.markdown("##### 2. Probability score of each object")
|
384 |
+
st.info("""**Note**: Some items might have been detected more than once on the image.
|
385 |
+
For these items, we've computed the average probability score across all detections.""")
|
386 |
+
visualize_probas(probas, threshold, colors_used)
|
387 |
+
|
388 |
+
|
389 |
+
|
390 |
+
else:
|
391 |
+
st.error("You must select an **image**, **elements to detect** and a **threshold** to run the model !")
|
392 |
|
393 |
|
394 |
|
pages/recommendation_system.py
CHANGED
@@ -10,7 +10,7 @@ import plotly.express as px
|
|
10 |
from sklearn.preprocessing import MinMaxScaler
|
11 |
from sklearn.metrics.pairwise import cosine_similarity
|
12 |
from annotated_text import annotated_text
|
13 |
-
from utils import
|
14 |
from st_pages import add_indentation
|
15 |
|
16 |
#add_indentation()
|
@@ -18,445 +18,444 @@ from st_pages import add_indentation
|
|
18 |
|
19 |
st.set_page_config(layout="wide")
|
20 |
|
21 |
-
|
22 |
-
st.markdown("# Recommendation systems 🛒")
|
|
|
23 |
|
24 |
-
st.
|
|
|
25 |
|
26 |
-
st.
|
27 |
-
They are very common in social media platforms such as TikTok, Youtube or Instagram or e-commerce websites as they help improve and personalize a consumer's experience.""")
|
28 |
-
|
29 |
-
st.markdown("""There are two main types of recommendation systems:
|
30 |
- **Content-based filtering**: Recommendations are made based on the user's own preferences
|
31 |
- **Collaborative filtering**: Recommendations are made based on the preferences and behavior of similar users""", unsafe_allow_html=True)
|
32 |
-
|
33 |
-
# st.markdown("""Here is an example of **Content-based filtering versus Collaborative filtering** for movie recommendations.""")
|
34 |
-
st.markdown(" ")
|
35 |
-
st.markdown(" ")
|
36 |
|
37 |
-
# _, col_img, _ = st.columns(spec=[0.2,0.6,0.2])
|
38 |
-
# with col_img:
|
39 |
-
# st.image("images/rs.png")
|
40 |
|
41 |
-
st.image("images/rs.png")
|
42 |
|
43 |
-
st.markdown(" ")
|
44 |
|
45 |
-
st.markdown("""Common applications of Recommendation systems include:
|
46 |
- **E-Commerce Platforms** 🛍️: Suggest products to users based on their browsing history, purchase patterns, and preferences.
|
47 |
- **Streaming Services** 📽️: Recommend movies, TV shows, or songs based on users' viewing/listening history and preferences.
|
48 |
- **Social Media Platforms** 📱: Suggest friends, groups, or content based on users' connections, interests, and engagement history.
|
49 |
- **Automotive and Navigation Systems** 🗺️: Suggest optimal routes based on real-time traffic conditions, historical data, and user preferences.
|
50 |
-
""")
|
51 |
|
52 |
-
st.markdown(" ")
|
53 |
|
54 |
-
select_usecase = st.selectbox("**Choose a use case**",
|
55 |
-
|
56 |
-
|
57 |
|
58 |
-
st.divider()
|
59 |
|
60 |
|
61 |
|
62 |
-
#####################################################################################################
|
63 |
-
# MOVIE RECOMMENDATION SYSTEM #
|
64 |
-
#####################################################################################################
|
65 |
|
66 |
-
# Recommendation function
|
67 |
-
def recommend(movie_name, nb):
|
68 |
-
|
69 |
-
|
70 |
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
|
75 |
-
|
76 |
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
|
84 |
-
poster = fetch_poster(temp_movie_id)
|
85 |
-
recommend_posters.append(poster)
|
86 |
-
|
87 |
-
# fetch poster
|
88 |
-
try:
|
89 |
poster = fetch_poster(temp_movie_id)
|
90 |
recommend_posters.append(poster)
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
return "https://image.tmdb.org/t/p/w500/" + data["poster_path"]
|
102 |
-
|
103 |
|
|
|
|
|
|
|
|
|
|
|
104 |
|
105 |
-
if select_usecase == "Movie recommendation system 📽️":
|
106 |
|
107 |
-
colors = ["#8ef", "#faa", "#afa", "#fea", "#8ef","#afa"]
|
108 |
-
#api_key = st.secrets["recommendation_system"]["key"]
|
109 |
-
api_key = os.environ["MOVIE_RECOM_API"]
|
110 |
|
111 |
-
|
112 |
-
path_data = r"data/movies"
|
113 |
-
path_models = r"pretrained_models/recommendation_system"
|
114 |
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
|
119 |
-
|
120 |
-
|
|
|
121 |
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
csr_data = pickle.load(file)
|
126 |
|
|
|
|
|
127 |
|
128 |
-
|
129 |
-
|
|
|
|
|
130 |
|
131 |
-
#st.info(""" """)
|
132 |
|
133 |
-
|
134 |
-
|
135 |
-
""", unsafe_allow_html=True)
|
136 |
-
st.markdown(" ")
|
137 |
|
138 |
-
|
139 |
-
# User selection
|
140 |
-
selected_movie = st.selectbox("**Select a movie**", movies["title"].values[:-3])
|
141 |
-
selected_nb_movies = st.selectbox("**Select a number of movies to recommend**", np.arange(2,7), index=3)
|
142 |
-
|
143 |
-
# Show user selection on the app
|
144 |
-
c1, c2 = st.columns([0.7,0.3], gap="medium")
|
145 |
-
with c1:
|
146 |
-
new_movies = movies.rename({"movie_id":"id"},axis=1).merge(vote, on="id", how="left")
|
147 |
-
description = new_movies.loc[new_movies["title"]==selected_movie,"description"].to_list()[0]
|
148 |
-
genre = new_movies.loc[new_movies["title"]==selected_movie,"genre"].to_list()[0]
|
149 |
-
vote_ = new_movies.loc[new_movies["title"]==selected_movie,"vote_average"].to_list()[0]
|
150 |
-
vote_count = new_movies.loc[new_movies["title"]==selected_movie,"vote_count"].to_list()[0]
|
151 |
-
|
152 |
-
list_genres = [(g.strip(),"",color) for color,g in zip(colors, genre.split(", "))]
|
153 |
-
|
154 |
-
st.header(selected_movie, divider="grey")
|
155 |
-
st.markdown(f"**Synopsis**: {description}")
|
156 |
-
annotated_text(["**Genre(s)**: ", list_genres])
|
157 |
-
st.markdown(f"**Rating**: {vote_}:star:")
|
158 |
-
st.markdown(f"**Votes**: {vote_count}")
|
159 |
|
160 |
-
st.
|
|
|
|
|
161 |
st.markdown(" ")
|
162 |
-
|
163 |
-
recommend_button = st.button("**Recommend movies**")
|
164 |
-
|
165 |
-
with c2:
|
166 |
-
try:
|
167 |
-
poster = fetch_poster(movies.loc[movies["title"]==selected_movie,"movie_id"].to_list()[0])
|
168 |
-
st.image(poster, width=300)
|
169 |
-
except:
|
170 |
-
pass
|
171 |
-
|
172 |
-
|
173 |
-
# Run model and show results
|
174 |
-
if recommend_button:
|
175 |
-
st.text("Here are few Recommendations..")
|
176 |
-
names,posters,movie_ids = recommend(selected_movie, selected_nb_movies)
|
177 |
-
tab1, tab2 = st.tabs(["View movies", "View genres"])
|
178 |
-
|
179 |
-
with tab1:
|
180 |
-
cols=st.columns(int(selected_nb_movies))
|
181 |
-
#cols=[col1,col2,col3,col4,col5]
|
182 |
-
for i in range(0,selected_nb_movies):
|
183 |
-
with cols[i]:
|
184 |
-
expander = st.expander("See movie details")
|
185 |
-
|
186 |
-
# if posters[i] == None:
|
187 |
-
# pass
|
188 |
-
# else:
|
189 |
-
# st.image(posters[i])
|
190 |
-
|
191 |
-
st.markdown(f"##### **{i+1}. {names[i]}**")
|
192 |
-
id = movie_ids[i]
|
193 |
|
194 |
-
genre = movies.loc[movies["movie_id"]==id,"genre"].to_list()[0]
|
195 |
-
list_genres = [(g.strip(),"",color) for color,g in zip(colors, genre.split(", "))]
|
196 |
-
|
197 |
-
synopsis = movies.loc[movies['movie_id']==id, "description"].to_list()[0]
|
198 |
-
st.markdown(synopsis)
|
199 |
-
|
200 |
-
vote_avg, vote_count = vote[vote["id"] == id].vote_average , vote[vote["id"] == id].vote_count
|
201 |
-
annotated_text(["**Genre(s)**: ", list_genres])
|
202 |
-
st.markdown(f"""**Rating**: {list(vote_avg.values)[0]}:star:""")
|
203 |
-
st.markdown(f"**Votes**: {list(vote_count.values)[0]}")
|
204 |
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
214 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
215 |
|
|
|
|
|
216 |
|
217 |
|
218 |
|
219 |
|
220 |
|
221 |
-
#####################################################################################################
|
222 |
-
# HOTEL RECOMMENDATION SYSTEM #
|
223 |
-
#####################################################################################################
|
224 |
-
|
225 |
|
226 |
-
# Load scaler with caching
|
227 |
-
|
228 |
-
|
229 |
-
|
230 |
-
if select_usecase == "Hotel recommendation system 🛎️":
|
231 |
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
-
|
236 |
-
return scaler
|
237 |
|
238 |
-
|
239 |
-
# Start with the full dataset
|
240 |
-
filtered_df = df.copy()
|
241 |
-
|
242 |
-
# Filter by Location if specified (either city or country)
|
243 |
-
if 'Location' in preferences and preferences['Location']:
|
244 |
-
filtered_df = filtered_df[(filtered_df['City'].str.contains(preferences['Location'], case=False, na=False)) |
|
245 |
-
(filtered_df['Country'].str.contains(preferences['Location'], case=False, na=False))]
|
246 |
-
|
247 |
-
# Filter by Number of beds if specified
|
248 |
-
if 'Number of beds' in preferences:
|
249 |
-
filtered_df = filtered_df[filtered_df['Number of bed'] == preferences['Number of beds']]
|
250 |
|
251 |
-
# Filter by Rating if specified
|
252 |
-
if 'Rating' in preferences:
|
253 |
-
min_rating, max_rating = preferences['Rating']
|
254 |
-
filtered_df = filtered_df[filtered_df['Rating'].between(min_rating, max_rating)]
|
255 |
|
256 |
-
|
257 |
-
|
258 |
-
|
259 |
-
|
260 |
-
|
261 |
-
|
262 |
-
|
263 |
-
|
264 |
-
|
265 |
-
|
266 |
-
|
267 |
-
|
268 |
-
|
269 |
-
|
270 |
-
|
271 |
-
|
272 |
-
|
273 |
-
|
274 |
-
|
275 |
-
|
276 |
-
|
277 |
-
|
278 |
-
|
279 |
-
|
280 |
-
|
281 |
-
|
282 |
-
|
283 |
-
|
284 |
-
|
285 |
-
|
286 |
-
|
287 |
-
|
288 |
-
|
289 |
-
|
290 |
-
|
291 |
-
|
292 |
-
|
293 |
-
|
294 |
-
|
295 |
-
|
296 |
-
|
297 |
-
|
298 |
-
|
299 |
-
|
300 |
-
|
301 |
-
|
302 |
-
|
303 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
304 |
- **Bangkok**: Immerse yourself in the hustle and bustle of Bangkok's streets, adorned with glittering temples and bustling markets. The Grand Palace and Khao San Road showcase the city's unique blend of tradition and modernity.
|
305 |
- **Chiang Mai**: Nestled in the misty mountains of Northern Thailand, Chiang Mai captivates with ancient temples, lush landscapes, and vibrant night markets. The Old City exudes a unique atmosphere, while the surrounding hills offer tranquility.
|
306 |
- **Phuket**: Thailand's largest island, Phuket, beckons beach lovers with its stunning white sands, vibrant nightlife, and water activities. It's a perfect blend of relaxation and excitement."""
|
307 |
|
308 |
-
|
309 |
-
|
310 |
-
|
311 |
-
|
312 |
-
Indulge in the countries rich tapestry of art, culture, and gastronomy.
|
313 |
-
From the romantic allure of Paris to the sun-kissed vineyards of Provence, every corner of this diverse country tells a unique story, promising an unforgettable journey for every traveler."""
|
314 |
-
|
315 |
- **Paris**: Dive into the city's iconic landmarks such as the Eiffel Tower, Louvre Museum, and Notre-Dame Cathedral grace the skyline.
|
316 |
- **Provence**: Visit the stunning Palais des Papes in Avignon, explore the colorful markets of Aix-en-Provence, and unwind in the serene beauty of the Luberon region.
|
317 |
- **Côte d'Azur**: This stunning stretch of the French coastline is a captivating blend of azure waters, picturesque landscapes and charming villages.
|
318 |
-
"""
|
319 |
-
|
320 |
-
|
321 |
-
|
322 |
-
|
323 |
-
|
324 |
-
Embark on an unforgettable journey where tradition and modernity coexist in harmony.
|
325 |
-
From the lively streets of Barcelona to the sun-soaked beaches of Andalusia, Spain offers a captivating blend of history, culture, and natural beauty.
|
326 |
-
"""
|
327 |
-
|
328 |
- **Barcelona**: Explore the iconic Sagrada Familia, stroll down the vibrant La Rambla, and soak in the Mediterranean ambiance at Barceloneta Beach.
|
329 |
- **Seville**: Visit the awe-inspiring Alcázar, marvel at the Giralda Tower, and wander through the enchanting alleys of the Santa Cruz neighborhood.
|
330 |
- **Granada**: Explore the Generalife Gardens, stroll through the Albayzín quarter with its narrow streets and white houses, and savor the views of the city from the Mirador de San Nicolás.
|
331 |
-
"""
|
332 |
|
333 |
-
|
334 |
-
|
335 |
-
|
336 |
-
|
337 |
-
From gleaming skyscrapers to vibrant neighborhoods, this cosmopolitan gem in Southeast Asia promises an immersive journey into a world where tradition meets cutting-edge technology."""
|
338 |
|
339 |
-
|
340 |
- **Marina Bay Sands**: Enjoy panoramic views from the SkyPark, take a dip in the infinity pool, and explore The Shoppes for luxury shopping and entertainment. At night, witness the mesmerizing light and water show at the Marina Bay Sands Skypark.
|
341 |
- **Gardens by the Bay**: Explore the Flower Dome and Cloud Forest conservatories, and stroll through the scenic OCBC Skyway for breathtaking views of the gardens and city.
|
342 |
- **Sentosa Island**: Escape to Sentosa Island, a resort destination offering a myriad of attractions. Relax on pristine beaches, visit Universal Studios Singapore for thrilling rides, and explore S.E.A. Aquarium for an underwater adventure.
|
343 |
|
344 |
-
"""
|
345 |
|
346 |
-
|
347 |
-
|
348 |
-
|
349 |
-
|
350 |
|
351 |
-
|
352 |
-
|
353 |
-
|
354 |
|
355 |
-
|
356 |
-
|
357 |
-
|
358 |
-
|
359 |
|
360 |
-
|
361 |
-
|
362 |
-
|
363 |
-
|
364 |
|
365 |
|
366 |
-
|
367 |
|
368 |
-
|
369 |
-
|
370 |
|
371 |
-
|
372 |
-
|
373 |
-
|
374 |
-
|
375 |
-
|
376 |
-
|
377 |
-
|
378 |
-
|
379 |
-
|
380 |
-
|
381 |
|
382 |
-
|
383 |
-
|
384 |
|
385 |
-
|
386 |
|
387 |
|
388 |
-
|
389 |
|
390 |
-
|
391 |
-
|
392 |
-
|
393 |
-
|
394 |
-
|
395 |
-
|
396 |
-
|
397 |
-
|
398 |
|
399 |
-
|
400 |
-
|
401 |
-
|
402 |
|
403 |
-
|
404 |
-
|
405 |
|
406 |
-
|
407 |
-
|
408 |
-
|
409 |
|
410 |
-
|
411 |
-
|
412 |
|
413 |
|
414 |
-
|
415 |
-
|
416 |
-
|
417 |
-
|
418 |
-
|
419 |
-
|
420 |
-
|
421 |
|
422 |
-
|
423 |
-
|
424 |
-
|
425 |
-
|
426 |
-
|
427 |
-
|
428 |
-
|
429 |
-
recommended_hotels, message = recommend_hotels_with_location_and_beds(df, preferences, max_recommendations)
|
430 |
-
|
431 |
-
# If no recommendations, reduce the maximum number of recommendations and try again
|
432 |
-
if recommended_hotels.empty:
|
433 |
-
max_recommendations -= 1
|
434 |
recommended_hotels, message = recommend_hotels_with_location_and_beds(df, preferences, max_recommendations)
|
|
|
|
|
435 |
if recommended_hotels.empty:
|
436 |
-
|
437 |
-
|
438 |
-
|
439 |
-
|
440 |
-
|
441 |
-
|
442 |
-
|
443 |
-
|
444 |
-
|
445 |
-
|
446 |
-
|
447 |
-
st.
|
448 |
-
|
449 |
-
|
450 |
-
|
451 |
-
|
452 |
-
|
453 |
-
|
454 |
-
|
455 |
-
|
456 |
-
|
457 |
-
|
458 |
-
|
459 |
-
|
460 |
-
|
461 |
-
|
462 |
-
|
|
|
|
|
|
|
|
10 |
from sklearn.preprocessing import MinMaxScaler
|
11 |
from sklearn.metrics.pairwise import cosine_similarity
|
12 |
from annotated_text import annotated_text
|
13 |
+
from utils import load_model_pickle, load_data_csv, check_password
|
14 |
from st_pages import add_indentation
|
15 |
|
16 |
#add_indentation()
|
|
|
18 |
|
19 |
st.set_page_config(layout="wide")
|
20 |
|
21 |
+
if check_password():
|
22 |
+
st.markdown("# Recommendation systems 🛒")
|
23 |
+
st.markdown("### What is a Recommendation System ?")
|
24 |
|
25 |
+
st.info("""**Recommendation systems** are algorithms built to **suggest** or **recommend** **products** to consumers.
|
26 |
+
They are very common in social media platforms such as TikTok, Youtube or Instagram or e-commerce websites as they help improve and personalize a consumer's experience.""")
|
27 |
|
28 |
+
st.markdown("""There are two main types of recommendation systems:
|
|
|
|
|
|
|
29 |
- **Content-based filtering**: Recommendations are made based on the user's own preferences
|
30 |
- **Collaborative filtering**: Recommendations are made based on the preferences and behavior of similar users""", unsafe_allow_html=True)
|
31 |
+
|
32 |
+
# st.markdown("""Here is an example of **Content-based filtering versus Collaborative filtering** for movie recommendations.""")
|
33 |
+
st.markdown(" ")
|
34 |
+
st.markdown(" ")
|
35 |
|
36 |
+
# _, col_img, _ = st.columns(spec=[0.2,0.6,0.2])
|
37 |
+
# with col_img:
|
38 |
+
# st.image("images/rs.png")
|
39 |
|
40 |
+
st.image("images/rs.png")
|
41 |
|
42 |
+
st.markdown(" ")
|
43 |
|
44 |
+
st.markdown("""Common applications of Recommendation systems include:
|
45 |
- **E-Commerce Platforms** 🛍️: Suggest products to users based on their browsing history, purchase patterns, and preferences.
|
46 |
- **Streaming Services** 📽️: Recommend movies, TV shows, or songs based on users' viewing/listening history and preferences.
|
47 |
- **Social Media Platforms** 📱: Suggest friends, groups, or content based on users' connections, interests, and engagement history.
|
48 |
- **Automotive and Navigation Systems** 🗺️: Suggest optimal routes based on real-time traffic conditions, historical data, and user preferences.
|
49 |
+
""")
|
50 |
|
51 |
+
st.markdown(" ")
|
52 |
|
53 |
+
select_usecase = st.selectbox("**Choose a use case**",
|
54 |
+
["Movie recommendation system 📽️",
|
55 |
+
"Hotel recommendation system 🛎️"])
|
56 |
|
57 |
+
st.divider()
|
58 |
|
59 |
|
60 |
|
61 |
+
#####################################################################################################
|
62 |
+
# MOVIE RECOMMENDATION SYSTEM #
|
63 |
+
#####################################################################################################
|
64 |
|
65 |
+
# Recommendation function
|
66 |
+
def recommend(movie_name, nb):
|
67 |
+
n_movies_to_recommend = nb
|
68 |
+
idx = movies[movies['title'] == movie_name].index[0]
|
69 |
|
70 |
+
distances, indices = model.kneighbors(csr_data[idx], n_neighbors=n_movies_to_recommend + 1)
|
71 |
+
idx = list(indices.squeeze())
|
72 |
+
df = np.take(movies, idx, axis=0)
|
73 |
|
74 |
+
movies_list = list(df.title[1:])
|
75 |
|
76 |
+
recommend_movies_names = []
|
77 |
+
recommend_posters = []
|
78 |
+
movie_ids = []
|
79 |
+
for i in movies_list:
|
80 |
+
temp_movie_id = (movies[movies.title ==i].movie_id).values[0]
|
81 |
+
movie_ids.append(temp_movie_id)
|
82 |
|
|
|
|
|
|
|
|
|
|
|
83 |
poster = fetch_poster(temp_movie_id)
|
84 |
recommend_posters.append(poster)
|
85 |
+
|
86 |
+
# fetch poster
|
87 |
+
try:
|
88 |
+
poster = fetch_poster(temp_movie_id)
|
89 |
+
recommend_posters.append(poster)
|
90 |
+
except:
|
91 |
+
recommend_posters.append(None)
|
92 |
+
|
93 |
+
recommend_movies_names.append(i)
|
94 |
+
return recommend_movies_names, recommend_posters, movie_ids
|
|
|
|
|
95 |
|
96 |
+
# Get poster
|
97 |
+
def fetch_poster(movie_id):
|
98 |
+
response = requests.get(f'https://api.themoviedb.org/3/movie/{movie_id}?api_key={api_key}')
|
99 |
+
data = response.json()
|
100 |
+
return "https://image.tmdb.org/t/p/w500/" + data["poster_path"]
|
101 |
|
|
|
102 |
|
|
|
|
|
|
|
103 |
|
104 |
+
if select_usecase == "Movie recommendation system 📽️":
|
|
|
|
|
105 |
|
106 |
+
colors = ["#8ef", "#faa", "#afa", "#fea", "#8ef","#afa"]
|
107 |
+
#api_key = st.secrets["recommendation_system"]["key"]
|
108 |
+
api_key = os.environ["MOVIE_RECOM_API"]
|
109 |
|
110 |
+
# Load data
|
111 |
+
path_data = r"data/movies"
|
112 |
+
path_models = r"pretrained_models/recommendation_system"
|
113 |
|
114 |
+
movies_dict = pickle.load(open(os.path.join(path_data,"movies_dict2.pkl"),"rb"))
|
115 |
+
movies = pd.DataFrame(movies_dict)
|
116 |
+
movies.drop_duplicates(inplace=True)
|
|
|
117 |
|
118 |
+
vote_info = pickle.load(open(os.path.join(path_data,"vote_info.pkl"),"rb"))
|
119 |
+
vote = pd.DataFrame(vote_info)
|
120 |
|
121 |
+
# Load model
|
122 |
+
model = load_model_pickle(path_models,"model.pkl")
|
123 |
+
with open(os.path.join(path_data,'csr_data_tf.pkl'), 'rb') as file:
|
124 |
+
csr_data = pickle.load(file)
|
125 |
|
|
|
126 |
|
127 |
+
# Description of the use case
|
128 |
+
st.markdown("""# Movie Recommendation System 📽️""")
|
|
|
|
|
129 |
|
130 |
+
#st.info(""" """)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
131 |
|
132 |
+
st.markdown("""This use case showcases the use of recommender systems for **movie recommendations** using **collaborative filtering**. <br>
|
133 |
+
The model recommends and ranks movies based on what users, who have also watched the chosen movie, have watched else on the platform. <br>
|
134 |
+
""", unsafe_allow_html=True)
|
135 |
st.markdown(" ")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
136 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
137 |
|
138 |
+
# User selection
|
139 |
+
selected_movie = st.selectbox("**Select a movie**", movies["title"].values[:-3])
|
140 |
+
selected_nb_movies = st.selectbox("**Select a number of movies to recommend**", np.arange(2,7), index=3)
|
141 |
+
|
142 |
+
# Show user selection on the app
|
143 |
+
c1, c2 = st.columns([0.7,0.3], gap="medium")
|
144 |
+
with c1:
|
145 |
+
new_movies = movies.rename({"movie_id":"id"},axis=1).merge(vote, on="id", how="left")
|
146 |
+
description = new_movies.loc[new_movies["title"]==selected_movie,"description"].to_list()[0]
|
147 |
+
genre = new_movies.loc[new_movies["title"]==selected_movie,"genre"].to_list()[0]
|
148 |
+
vote_ = new_movies.loc[new_movies["title"]==selected_movie,"vote_average"].to_list()[0]
|
149 |
+
vote_count = new_movies.loc[new_movies["title"]==selected_movie,"vote_count"].to_list()[0]
|
150 |
+
|
151 |
+
list_genres = [(g.strip(),"",color) for color,g in zip(colors, genre.split(", "))]
|
152 |
+
|
153 |
+
st.header(selected_movie, divider="grey")
|
154 |
+
st.markdown(f"**Synopsis**: {description}")
|
155 |
+
annotated_text(["**Genre(s)**: ", list_genres])
|
156 |
+
st.markdown(f"**Rating**: {vote_}:star:")
|
157 |
+
st.markdown(f"**Votes**: {vote_count}")
|
158 |
|
159 |
+
st.info(f"You've selected {selected_nb_movies} movies to recommend")
|
160 |
+
st.markdown(" ")
|
161 |
+
|
162 |
+
recommend_button = st.button("**Recommend movies**")
|
163 |
+
|
164 |
+
with c2:
|
165 |
+
try:
|
166 |
+
poster = fetch_poster(movies.loc[movies["title"]==selected_movie,"movie_id"].to_list()[0])
|
167 |
+
st.image(poster, width=300)
|
168 |
+
except:
|
169 |
+
pass
|
170 |
+
|
171 |
+
|
172 |
+
# Run model and show results
|
173 |
+
if recommend_button:
|
174 |
+
st.text("Here are few Recommendations..")
|
175 |
+
names,posters,movie_ids = recommend(selected_movie, selected_nb_movies)
|
176 |
+
tab1, tab2 = st.tabs(["View movies", "View genres"])
|
177 |
+
|
178 |
+
with tab1:
|
179 |
+
cols=st.columns(int(selected_nb_movies))
|
180 |
+
#cols=[col1,col2,col3,col4,col5]
|
181 |
+
for i in range(0,selected_nb_movies):
|
182 |
+
with cols[i]:
|
183 |
+
expander = st.expander("See movie details")
|
184 |
+
|
185 |
+
# if posters[i] == None:
|
186 |
+
# pass
|
187 |
+
# else:
|
188 |
+
# st.image(posters[i])
|
189 |
+
|
190 |
+
st.markdown(f"##### **{i+1}. {names[i]}**")
|
191 |
+
id = movie_ids[i]
|
192 |
+
|
193 |
+
genre = movies.loc[movies["movie_id"]==id,"genre"].to_list()[0]
|
194 |
+
list_genres = [(g.strip(),"",color) for color,g in zip(colors, genre.split(", "))]
|
195 |
+
|
196 |
+
synopsis = movies.loc[movies['movie_id']==id, "description"].to_list()[0]
|
197 |
+
st.markdown(synopsis)
|
198 |
+
|
199 |
+
vote_avg, vote_count = vote[vote["id"] == id].vote_average , vote[vote["id"] == id].vote_count
|
200 |
+
annotated_text(["**Genre(s)**: ", list_genres])
|
201 |
+
st.markdown(f"""**Rating**: {list(vote_avg.values)[0]}:star:""")
|
202 |
+
st.markdown(f"**Votes**: {list(vote_count.values)[0]}")
|
203 |
+
|
204 |
+
|
205 |
+
with tab2:
|
206 |
+
recommended_genres = movies.loc[movies["movie_id"].isin(movie_ids[:5]),"genre"].to_list()
|
207 |
+
list_recom_genres = [genre for list_genres in recommended_genres for genre in list_genres.split(", ")]
|
208 |
+
df_recom_genres = pd.Series(list_recom_genres).value_counts().to_frame().reset_index(names="genre")
|
209 |
+
df_recom_genres["proportion (%)"] = (100*df_recom_genres["count"]/df_recom_genres["count"].sum())
|
210 |
|
211 |
+
fig = px.bar(df_recom_genres, x='count', y='genre', color="genre", title='Most recommended genres', orientation="h")
|
212 |
+
st.plotly_chart(fig, use_container_width=True)
|
213 |
|
214 |
|
215 |
|
216 |
|
217 |
|
|
|
|
|
|
|
|
|
218 |
|
|
|
|
|
|
|
|
|
|
|
219 |
|
220 |
+
#####################################################################################################
|
221 |
+
# HOTEL RECOMMENDATION SYSTEM #
|
222 |
+
#####################################################################################################
|
223 |
+
|
|
|
224 |
|
225 |
+
# Load scaler with caching
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
226 |
|
|
|
|
|
|
|
|
|
227 |
|
228 |
+
|
229 |
+
if select_usecase == "Hotel recommendation system 🛎️":
|
230 |
+
|
231 |
+
@st.cache_data(ttl=3600)
|
232 |
+
def get_scaler(df):
|
233 |
+
scaler = MinMaxScaler()
|
234 |
+
scaler.fit(df[['Rating', 'Price']])
|
235 |
+
return scaler
|
236 |
+
|
237 |
+
def recommend_hotels_with_location_and_beds(df, preferences, max_recommendations=5):
|
238 |
+
# Start with the full dataset
|
239 |
+
filtered_df = df.copy()
|
240 |
+
|
241 |
+
# Filter by Location if specified (either city or country)
|
242 |
+
if 'Location' in preferences and preferences['Location']:
|
243 |
+
filtered_df = filtered_df[(filtered_df['City'].str.contains(preferences['Location'], case=False, na=False)) |
|
244 |
+
(filtered_df['Country'].str.contains(preferences['Location'], case=False, na=False))]
|
245 |
+
|
246 |
+
# Filter by Number of beds if specified
|
247 |
+
if 'Number of beds' in preferences:
|
248 |
+
filtered_df = filtered_df[filtered_df['Number of bed'] == preferences['Number of beds']]
|
249 |
+
|
250 |
+
# Filter by Rating if specified
|
251 |
+
if 'Rating' in preferences:
|
252 |
+
min_rating, max_rating = preferences['Rating']
|
253 |
+
filtered_df = filtered_df[filtered_df['Rating'].between(min_rating, max_rating)]
|
254 |
+
|
255 |
+
# Filter by Price range if specified
|
256 |
+
if 'Price' in preferences:
|
257 |
+
min_price, max_price = preferences['Price']
|
258 |
+
filtered_df = filtered_df[filtered_df['Price'].between(min_price, max_price)]
|
259 |
+
|
260 |
+
# Ensure there are still hotels after filtering
|
261 |
+
if filtered_df.empty:
|
262 |
+
# Send a notification if no hotels match the criteria
|
263 |
+
send_notification("No hotels were found matching the specified criteria.")
|
264 |
+
return pd.DataFrame(), "No hotels were found matching the specified criteria."
|
265 |
+
|
266 |
+
preferences["Rating"] = np.mean(np.array(preferences["Rating"]))
|
267 |
+
preferences["Price"] = np.mean(np.array(preferences["Price"]))
|
268 |
+
|
269 |
+
# Normalize the preferences vector (excluding location and number of beds for similarity calculation)
|
270 |
+
preferences_vector = np.array([[preferences.get('Rating', 0),
|
271 |
+
preferences.get('Price', 0)]])
|
272 |
+
preferences_vector_normalized = scaler.transform(preferences_vector)
|
273 |
+
|
274 |
+
# Calculate similarity scores for the filtered hotels
|
275 |
+
filtered_numerical_features = filtered_df[['Rating', 'Price']]
|
276 |
+
filtered_numerical_features_normalized = scaler.transform(filtered_numerical_features)
|
277 |
+
similarity_scores = cosine_similarity(preferences_vector_normalized, filtered_numerical_features_normalized)[0]
|
278 |
+
|
279 |
+
# Get the indices of top_n similar hotels
|
280 |
+
top_indices = similarity_scores.argsort()[-max_recommendations:][::-1]
|
281 |
+
recommended_indices = filtered_df.iloc[top_indices].index
|
282 |
+
|
283 |
+
# Return the recommended hotels with relevant details (including specified columns)
|
284 |
+
return df.loc[recommended_indices], None
|
285 |
+
|
286 |
+
|
287 |
+
def send_notification(message):
|
288 |
+
"""
|
289 |
+
Placeholder function to send a notification.
|
290 |
+
This function can be replaced with the actual notification mechanism (e.g., email, SMS).
|
291 |
+
"""
|
292 |
+
print("Notification:", message)
|
293 |
+
|
294 |
+
|
295 |
+
def country_info(country):
|
296 |
+
if country == "Thailand":
|
297 |
+
image = "images/thailand.jpeg"
|
298 |
+
emoji = "🏝️"
|
299 |
+
description = """**Description**:
|
300 |
+
Thailand seamlessly fuses ancient traditions with modern dynamism, creating an unparalleled tapestry for travelers.
|
301 |
+
Renowned for its warm hospitality, vibrant culture, and delectable cuisine, Thailand offers an unforgettable experience for every adventurer."""
|
302 |
+
top_places = """
|
303 |
- **Bangkok**: Immerse yourself in the hustle and bustle of Bangkok's streets, adorned with glittering temples and bustling markets. The Grand Palace and Khao San Road showcase the city's unique blend of tradition and modernity.
|
304 |
- **Chiang Mai**: Nestled in the misty mountains of Northern Thailand, Chiang Mai captivates with ancient temples, lush landscapes, and vibrant night markets. The Old City exudes a unique atmosphere, while the surrounding hills offer tranquility.
|
305 |
- **Phuket**: Thailand's largest island, Phuket, beckons beach lovers with its stunning white sands, vibrant nightlife, and water activities. It's a perfect blend of relaxation and excitement."""
|
306 |
|
307 |
+
if country == "France":
|
308 |
+
image = "images/france.jpeg"
|
309 |
+
emoji = "⚜️"
|
310 |
+
description ="""**Description**:
|
311 |
+
Indulge in the countries rich tapestry of art, culture, and gastronomy.
|
312 |
+
From the romantic allure of Paris to the sun-kissed vineyards of Provence, every corner of this diverse country tells a unique story, promising an unforgettable journey for every traveler."""
|
313 |
+
top_places = """
|
314 |
- **Paris**: Dive into the city's iconic landmarks such as the Eiffel Tower, Louvre Museum, and Notre-Dame Cathedral grace the skyline.
|
315 |
- **Provence**: Visit the stunning Palais des Papes in Avignon, explore the colorful markets of Aix-en-Provence, and unwind in the serene beauty of the Luberon region.
|
316 |
- **Côte d'Azur**: This stunning stretch of the French coastline is a captivating blend of azure waters, picturesque landscapes and charming villages.
|
317 |
+
"""
|
318 |
+
|
319 |
+
if country == "Spain":
|
320 |
+
image = "images/spain-banner.jpg"
|
321 |
+
emoji = "☀️"
|
322 |
+
description = """**Description**:
|
323 |
+
Embark on an unforgettable journey where tradition and modernity coexist in harmony.
|
324 |
+
From the lively streets of Barcelona to the sun-soaked beaches of Andalusia, Spain offers a captivating blend of history, culture, and natural beauty.
|
325 |
+
"""
|
326 |
+
top_places = """
|
327 |
- **Barcelona**: Explore the iconic Sagrada Familia, stroll down the vibrant La Rambla, and soak in the Mediterranean ambiance at Barceloneta Beach.
|
328 |
- **Seville**: Visit the awe-inspiring Alcázar, marvel at the Giralda Tower, and wander through the enchanting alleys of the Santa Cruz neighborhood.
|
329 |
- **Granada**: Explore the Generalife Gardens, stroll through the Albayzín quarter with its narrow streets and white houses, and savor the views of the city from the Mirador de San Nicolás.
|
330 |
+
"""
|
331 |
|
332 |
+
if country == "Singapore":
|
333 |
+
image = "images/singapore.jpg"
|
334 |
+
emoji = "🏙️"
|
335 |
+
description = """**Description**:
|
336 |
+
From gleaming skyscrapers to vibrant neighborhoods, this cosmopolitan gem in Southeast Asia promises an immersive journey into a world where tradition meets cutting-edge technology."""
|
337 |
|
338 |
+
top_places = """
|
339 |
- **Marina Bay Sands**: Enjoy panoramic views from the SkyPark, take a dip in the infinity pool, and explore The Shoppes for luxury shopping and entertainment. At night, witness the mesmerizing light and water show at the Marina Bay Sands Skypark.
|
340 |
- **Gardens by the Bay**: Explore the Flower Dome and Cloud Forest conservatories, and stroll through the scenic OCBC Skyway for breathtaking views of the gardens and city.
|
341 |
- **Sentosa Island**: Escape to Sentosa Island, a resort destination offering a myriad of attractions. Relax on pristine beaches, visit Universal Studios Singapore for thrilling rides, and explore S.E.A. Aquarium for an underwater adventure.
|
342 |
|
343 |
+
"""
|
344 |
|
345 |
+
###### STREAMLIT MARKDOWN ######
|
346 |
+
st.header(f"{country} {emoji}", divider="grey")
|
347 |
+
st.image(image)
|
348 |
+
st.markdown(description)
|
349 |
|
350 |
+
see_top_places = st.checkbox("**Top places to visit**", key={country})
|
351 |
+
if see_top_places:
|
352 |
+
st.markdown(top_places)
|
353 |
|
354 |
+
|
355 |
+
|
356 |
+
|
357 |
+
st.markdown("""# Hotel Recommendation System 🛎️""")
|
358 |
|
359 |
+
st.info("""This use case shows how you can create personalized hotel recommendations using a recommendation system with **content-based Filtering**.
|
360 |
+
Analyzing location, amenities, price, and reviews, the model suggests tailored hotel recommendation based on the user's preference.
|
361 |
+
""")
|
362 |
+
st.markdown(" ")
|
363 |
|
364 |
|
365 |
+
path_hotels_data = r"data/hotels"
|
366 |
|
367 |
+
# Load hotel data
|
368 |
+
df = load_data_csv(path_hotels_data,"booking_df.csv")
|
369 |
|
370 |
+
# clean data
|
371 |
+
df.drop_duplicates(inplace=True)
|
372 |
+
df["Country"] = df["Country"].apply(lambda x: "Spain" if x=="Espagne" else x)
|
373 |
+
list_cities = df["City"].value_counts().to_frame().reset_index()
|
374 |
+
list_cities = list_cities.loc[list_cities["count"]>=5,"City"].to_numpy()
|
375 |
+
df = df.loc[(df["City"].isin(list_cities)) & (df["Number of bed"]<=6)]
|
376 |
+
df["Price"] = df["Price"].astype(int)
|
377 |
+
df.loc[(df["Number of bed"]==0) & (df["Price"]<1000),"Number of bed"] = 1
|
378 |
+
df.loc[(df["Number of bed"]==0) & (df["Price"].between(1000,2000)),"Number of bed"] = 2
|
379 |
+
df.loc[(df["Number of bed"]==0) & (df["Price"]>2000),"Number of bed"] = 3
|
380 |
|
381 |
+
df["Rating"] = df["Rating"].apply(lambda x: np.nan if x==0 else x)
|
382 |
+
df["Rating"].fillna(np.round(df["Rating"].mean(), 1), inplace=True)
|
383 |
|
384 |
+
scaler = get_scaler(df)
|
385 |
|
386 |
|
387 |
+
col1, col2 = st.columns([0.3,0.7], gap="large")
|
388 |
|
389 |
+
with col1:
|
390 |
+
# Collect user preferences
|
391 |
+
st.markdown(" ")
|
392 |
+
st.markdown(" ")
|
393 |
+
st.markdown("")
|
394 |
+
#st.markdown("#### Filter preferences")
|
395 |
+
list_countries = df["Country"].unique()
|
396 |
+
location = st.selectbox("Select a Country",list_countries, index=0)
|
397 |
|
398 |
+
list_nb_beds = df["Number of bed"].unique()
|
399 |
+
num_beds = st.selectbox("Number of beds", list_nb_beds, index=0)
|
400 |
+
#if num_beds == "No information"
|
401 |
|
402 |
+
min_rating, max_rating = st.slider("Range of ratings", min_value=df["Rating"].min(), max_value=df["Rating"].max(), step=0.1, value=(5.0, df["Rating"].max()))
|
403 |
+
min_price, max_price = st.slider("Range of room prices", min_value=df["Price"].min(), max_value=df["Price"].max(), step=10, value=(df["Price"].min(), 10000))
|
404 |
|
405 |
+
# Convert price range sliders to integer values
|
406 |
+
min_price = int(min_price)
|
407 |
+
max_price = int(max_price)
|
408 |
|
409 |
+
with col2:
|
410 |
+
country_info(location)
|
411 |
|
412 |
|
413 |
+
preferences = {
|
414 |
+
'Location': location,
|
415 |
+
'Number of beds': num_beds,
|
416 |
+
'Rating': [min_rating, max_rating],
|
417 |
+
'Price': [min_price, max_price],
|
418 |
+
}
|
419 |
+
|
420 |
|
421 |
+
if st.button("Recommend Hotels"):
|
422 |
+
st.info("Hotels were recommended based on how similar they were to the users preferences.")
|
423 |
+
|
424 |
+
# Default number of recommendations to show
|
425 |
+
max_recommendations = 5
|
426 |
+
|
427 |
+
# Call the recommendation function
|
|
|
|
|
|
|
|
|
|
|
428 |
recommended_hotels, message = recommend_hotels_with_location_and_beds(df, preferences, max_recommendations)
|
429 |
+
|
430 |
+
# If no recommendations, reduce the maximum number of recommendations and try again
|
431 |
if recommended_hotels.empty:
|
432 |
+
max_recommendations -= 1
|
433 |
+
recommended_hotels, message = recommend_hotels_with_location_and_beds(df, preferences, max_recommendations)
|
434 |
+
if recommended_hotels.empty:
|
435 |
+
st.error(message)
|
436 |
+
# else:
|
437 |
+
# st.write(recommended_hotels)
|
438 |
+
else:
|
439 |
+
st.markdown(" ")
|
440 |
+
for i in range(len(recommended_hotels)):
|
441 |
+
#st.dataframe(recommended_hotels)
|
442 |
+
df_result = recommended_hotels.iloc[i,:]
|
443 |
+
col1_, col2_ = st.columns([0.4,0.6], gap="medium")
|
444 |
+
|
445 |
+
with col1_:
|
446 |
+
st.image("images/room.jpg",width=100)
|
447 |
+
st.markdown(f"### {i+1}: {df_result['Hotel Name']}")
|
448 |
+
st.markdown(f"""**{df_result['Room Type']}** <br>
|
449 |
+
with {df_result['Bed Type']}
|
450 |
+
""", unsafe_allow_html=True)
|
451 |
+
with col2_:
|
452 |
+
st.markdown(" ")
|
453 |
+
st.markdown(" ")
|
454 |
+
annotated_text("**Number of beds :** ",(f"{df_result['Number of bed']}","","#faa"))
|
455 |
+
#st.markdown(f"**Bed type**: {df_result['Bed Type']}")
|
456 |
+
annotated_text("**City:** ",(f"{df_result['City']}","","#afa"))
|
457 |
+
annotated_text("**Rating:** ",(f"{df_result['Rating']}","","#8ef"))
|
458 |
+
annotated_text("**Price:** ",(f"{df_result['Price']}$","","#fea"))
|
459 |
+
|
460 |
+
st.divider()
|
461 |
+
|
pages/sentiment_analysis.py
CHANGED
@@ -10,7 +10,7 @@ import plotly.express as px
|
|
10 |
from st_pages import add_indentation
|
11 |
|
12 |
from pysentimiento import create_analyzer
|
13 |
-
from utils import load_data_pickle
|
14 |
|
15 |
st.set_page_config(layout="wide")
|
16 |
#add_indentation()
|
@@ -33,227 +33,226 @@ def load_sa_model():
|
|
33 |
|
34 |
|
35 |
|
36 |
-
|
37 |
-
st.markdown("# Sentiment Analysis 👍")
|
|
|
38 |
|
39 |
-
st.
|
|
|
|
|
40 |
|
41 |
-
st.
|
42 |
-
Sentiment analysis is a **Natural Language Processing** (NLP) task that involves determining the sentiment or emotion expressed in a piece of text.
|
43 |
-
It has a wide range of use cases across various industries, as it helps organizations gain insights into the opinions, emotions, and attitudes expressed in text data.""")
|
44 |
|
45 |
-
|
|
|
|
|
46 |
|
47 |
-
|
48 |
-
with col:
|
49 |
-
st.image("images/sentiment_analysis.png") #, width=800)
|
50 |
-
|
51 |
-
st.markdown(" ")
|
52 |
-
|
53 |
-
st.markdown("""
|
54 |
-
Common applications of Natural Language Processing include:
|
55 |
-
- **Customer Feedback and Reviews** 💯: Assessing reviews on products or services to understand customer satisfaction and identify areas for improvement.
|
56 |
-
- **Market Research** 🔍: Analyzing survey responses or online forums to gauge public opinion on products, services, or emerging trends.
|
57 |
-
- **Financial Market Analysis** 📉: Monitoring financial news, reports, and social media to gauge investor sentiment and predict market trends.
|
58 |
-
- **Government and Public Policy** 📣: Analyzing public opinion on government policies, initiatives, and political decisions to gauge public sentiment and inform decision-making.
|
59 |
-
""")
|
60 |
-
|
61 |
-
st.divider()
|
62 |
-
|
63 |
-
#sa_pages = ["Starbucks Customer Reviews (Text)", "Tiktok's US Congressional Hearing (Audio)"]
|
64 |
-
#st.markdown("### Select a use case ")
|
65 |
-
#use_case = st.selectbox("", sa_pages, label_visibility="collapsed")
|
66 |
-
|
67 |
-
|
68 |
-
st.markdown("# Customer Review Analysis 📝")
|
69 |
-
st.info("""In this use case, **sentiment analysis** is used to predict the **polarity** (negative, neutral, positive) of customer reviews.
|
70 |
-
You can try the application by using the provided starbucks customer reviews, or by writing your own.""")
|
71 |
-
st.markdown(" ")
|
72 |
-
|
73 |
-
_, col, _ = st.columns([0.2,0.6,0.2])
|
74 |
-
with col:
|
75 |
-
st.image("images/reviews.png",use_column_width=True)
|
76 |
|
77 |
-
st.markdown("
|
|
|
|
|
|
|
|
|
|
|
|
|
78 |
|
|
|
79 |
|
80 |
-
#
|
81 |
-
|
82 |
-
|
83 |
-
reviews_df.reset_index(drop=True, inplace=True)
|
84 |
-
reviews_df["Date"] = reviews_df["Date"].dt.date
|
85 |
-
reviews_df["Year"] = reviews_df["Year"].astype(int)
|
86 |
|
87 |
|
|
|
|
|
|
|
|
|
88 |
|
89 |
-
st.
|
90 |
-
|
|
|
91 |
|
92 |
-
with tab1_:
|
93 |
-
# FILTER DATA
|
94 |
st.markdown(" ")
|
95 |
|
96 |
-
col1, col2 = st.columns([0.2, 0.8], gap="medium")
|
97 |
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
index=3, label_visibility="collapsed")
|
105 |
-
|
106 |
-
if select_image_box == "Filter by Date (Year)":
|
107 |
-
selected_date = st.multiselect("Date (Year)", reviews_df["Year"].unique(), default=reviews_df["Year"].unique()[0])
|
108 |
-
reviews_df = reviews_df.loc[reviews_df["Year"].isin(selected_date)]
|
109 |
-
|
110 |
-
if select_image_box == "Filter by State":
|
111 |
-
selected_state = st.multiselect("State", reviews_df["State"].unique(), default=reviews_df["State"].unique()[0])
|
112 |
-
reviews_df = reviews_df.loc[reviews_df["State"].isin(selected_state)]
|
113 |
-
|
114 |
-
if select_image_box == "Filter by Rating":
|
115 |
-
selected_rating = st.multiselect("Rating", sorted(list(reviews_df["Rating"].dropna().unique())),
|
116 |
-
default = sorted(list(reviews_df["Rating"].dropna().unique()))[0])
|
117 |
-
reviews_df = reviews_df.loc[reviews_df["Rating"].isin(selected_rating)]
|
118 |
-
|
119 |
-
if select_image_box == "No filters":
|
120 |
-
pass
|
121 |
-
|
122 |
-
#st.slider()
|
123 |
-
run_model1 = st.button("**Run the model**", type="primary", key="tab1")
|
124 |
-
st.info("The model has already been trained in this use case.")
|
125 |
-
|
126 |
-
with col2:
|
127 |
-
# VIEW DATA
|
128 |
-
st.markdown("""<b>View the reviews:</b> <br>
|
129 |
-
The dataset contains the location (State), date, rating, text and images (if provided) for each review.""",
|
130 |
-
unsafe_allow_html=True)
|
131 |
-
|
132 |
-
st.data_editor(
|
133 |
-
reviews_df.drop(columns=["Year"]),
|
134 |
-
column_config={"Image 1": st.column_config.ImageColumn("Image 1"),
|
135 |
-
"Image 2": st.column_config.ImageColumn("Image 2")},
|
136 |
-
hide_index=True)
|
137 |
-
|
138 |
|
139 |
-
######### SHOW RESULTS ########
|
140 |
-
if run_model1:
|
141 |
-
with st.spinner('Wait for it...'):
|
142 |
-
df_results = load_data_pickle(path_sa,"reviews_results.pkl")
|
143 |
-
df_results.reset_index(drop=True, inplace=True)
|
144 |
|
145 |
-
index_row = np.array(reviews_df.index)
|
146 |
-
df_results = df_results.iloc[index_row].reset_index(drop=True)
|
147 |
-
df_results["Review"] = reviews_df["Review"]
|
148 |
-
st.markdown(" ")
|
149 |
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
with tab1: # Overall results (tab_1)
|
154 |
-
# get results df
|
155 |
-
df_results_tab1 = df_results[["ID","Review","Rating","Negative","Neutral","Positive","Result"]]
|
156 |
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
perct_negative = df_warning.loc[df_warning["Result"]=="Negative","Percentage"].to_numpy()[0]
|
162 |
-
if perct_negative > 50:
|
163 |
-
st.error(f"**Negative reviews alert** ⚠️: The proportion of negative reviews is {perct_negative}% !")
|
164 |
-
|
165 |
-
# show dataframe results
|
166 |
-
st.data_editor(
|
167 |
-
df_results_tab1, #.loc[df_results_tab1["Customer ID"].isin(filter_customers)],
|
168 |
-
column_config={
|
169 |
-
"Negative": st.column_config.ProgressColumn(
|
170 |
-
"Negative 👎",
|
171 |
-
help="Negative score of the review",
|
172 |
-
format="%d%%",
|
173 |
-
min_value=0,
|
174 |
-
max_value=100),
|
175 |
-
"Neutral": st.column_config.ProgressColumn(
|
176 |
-
"Neutral ✋",
|
177 |
-
help="Neutral score of the review",
|
178 |
-
format="%d%%",
|
179 |
-
min_value=0,
|
180 |
-
max_value=100),
|
181 |
-
"Positive": st.column_config.ProgressColumn(
|
182 |
-
"Positive 👍",
|
183 |
-
help="Positive score of the review",
|
184 |
-
format="%d%%",
|
185 |
-
min_value=0,
|
186 |
-
max_value=100)},
|
187 |
-
hide_index=True,
|
188 |
-
)
|
189 |
-
|
190 |
-
with tab2: # Results by state (tab_1)
|
191 |
-
avg_state = df_results[["State","Negative","Neutral","Positive"]].groupby(["State"]).mean().round()
|
192 |
-
avg_state = avg_state.reset_index().melt(id_vars="State", var_name="Sentiment", value_name="Score (%)")
|
193 |
-
|
194 |
-
chart_state = alt.Chart(avg_state, title="Review polarity per state").mark_bar().encode(
|
195 |
-
x=alt.X('Sentiment', axis=alt.Axis(title=None, labels=False, ticks=False)),
|
196 |
-
y=alt.Y('Score (%)', axis=alt.Axis(grid=False)),
|
197 |
-
color='Sentiment',
|
198 |
-
column=alt.Column('State', header=alt.Header(title=None, labelOrient='bottom'))
|
199 |
-
).configure_view(
|
200 |
-
stroke='transparent'
|
201 |
-
).interactive()
|
202 |
|
203 |
-
|
204 |
-
st.altair_chart(chart_state)
|
205 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
206 |
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
|
212 |
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
).interactive()
|
218 |
|
219 |
-
st.markdown(" ")
|
220 |
-
st.
|
221 |
-
|
222 |
-
|
223 |
-
|
224 |
-
|
225 |
-
|
226 |
-
|
227 |
-
|
228 |
-
|
229 |
-
|
230 |
-
|
231 |
-
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
-
|
236 |
-
|
237 |
-
|
238 |
-
|
239 |
-
|
240 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
241 |
|
242 |
-
|
243 |
-
with st.spinner('Wait for it...'):
|
244 |
-
#sentiment_analyzer = create_analyzer(task="sentiment", lang="en")
|
245 |
-
# Load model with cache
|
246 |
-
sentiment_analyzer = load_sa_model()
|
247 |
-
q = sentiment_analyzer.predict(txt_review)
|
248 |
-
|
249 |
-
df_review_user = pd.DataFrame({"Polarity":["Positive","Neutral","Negative"],
|
250 |
-
"Score":[q.probas['POS'], q.probas['NEU'], q.probas['NEG']]})
|
251 |
|
252 |
-
|
253 |
-
st.
|
254 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
255 |
|
256 |
-
|
257 |
-
|
258 |
|
259 |
|
|
|
10 |
from st_pages import add_indentation
|
11 |
|
12 |
from pysentimiento import create_analyzer
|
13 |
+
from utils import load_data_pickle, check_password
|
14 |
|
15 |
st.set_page_config(layout="wide")
|
16 |
#add_indentation()
|
|
|
33 |
|
34 |
|
35 |
|
36 |
+
if check_password():
|
37 |
+
st.markdown("# Sentiment Analysis 👍")
|
38 |
+
st.markdown("### What is Sentiment Analysis ?")
|
39 |
|
40 |
+
st.info("""
|
41 |
+
Sentiment analysis is a **Natural Language Processing** (NLP) task that involves determining the sentiment or emotion expressed in a piece of text.
|
42 |
+
It has a wide range of use cases across various industries, as it helps organizations gain insights into the opinions, emotions, and attitudes expressed in text data.""")
|
43 |
|
44 |
+
st.markdown("Here is an example of Sentiment analysis used to analyze **Customer Satisfaction** for perfums.")
|
|
|
|
|
45 |
|
46 |
+
_, col, _ = st.columns([0.1,0.8,0.1])
|
47 |
+
with col:
|
48 |
+
st.image("images/sentiment_analysis.png") #, width=800)
|
49 |
|
50 |
+
st.markdown(" ")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
|
52 |
+
st.markdown("""
|
53 |
+
Common applications of Natural Language Processing include:
|
54 |
+
- **Customer Feedback and Reviews** 💯: Assessing reviews on products or services to understand customer satisfaction and identify areas for improvement.
|
55 |
+
- **Market Research** 🔍: Analyzing survey responses or online forums to gauge public opinion on products, services, or emerging trends.
|
56 |
+
- **Financial Market Analysis** 📉: Monitoring financial news, reports, and social media to gauge investor sentiment and predict market trends.
|
57 |
+
- **Government and Public Policy** 📣: Analyzing public opinion on government policies, initiatives, and political decisions to gauge public sentiment and inform decision-making.
|
58 |
+
""")
|
59 |
|
60 |
+
st.divider()
|
61 |
|
62 |
+
#sa_pages = ["Starbucks Customer Reviews (Text)", "Tiktok's US Congressional Hearing (Audio)"]
|
63 |
+
#st.markdown("### Select a use case ")
|
64 |
+
#use_case = st.selectbox("", sa_pages, label_visibility="collapsed")
|
|
|
|
|
|
|
65 |
|
66 |
|
67 |
+
st.markdown("# Customer Review Analysis 📝")
|
68 |
+
st.info("""In this use case, **sentiment analysis** is used to predict the **polarity** (negative, neutral, positive) of customer reviews.
|
69 |
+
You can try the application by using the provided starbucks customer reviews, or by writing your own.""")
|
70 |
+
st.markdown(" ")
|
71 |
|
72 |
+
_, col, _ = st.columns([0.2,0.6,0.2])
|
73 |
+
with col:
|
74 |
+
st.image("images/reviews.png",use_column_width=True)
|
75 |
|
|
|
|
|
76 |
st.markdown(" ")
|
77 |
|
|
|
78 |
|
79 |
+
# Load data
|
80 |
+
path_sa = "data/sa_data"
|
81 |
+
reviews_df = load_data_pickle(path_sa,"reviews_raw.pkl")
|
82 |
+
reviews_df.reset_index(drop=True, inplace=True)
|
83 |
+
reviews_df["Date"] = reviews_df["Date"].dt.date
|
84 |
+
reviews_df["Year"] = reviews_df["Year"].astype(int)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
85 |
|
|
|
|
|
|
|
|
|
|
|
86 |
|
|
|
|
|
|
|
|
|
87 |
|
88 |
+
st.markdown("#### Predict polarity 🤔")
|
89 |
+
tab1_, tab2_ = st.tabs(["Starbucks reviews", "Write a review"])
|
|
|
|
|
|
|
|
|
90 |
|
91 |
+
with tab1_:
|
92 |
+
# FILTER DATA
|
93 |
+
st.markdown(" ")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
94 |
|
95 |
+
col1, col2 = st.columns([0.2, 0.8], gap="medium")
|
|
|
96 |
|
97 |
+
with col1:
|
98 |
+
st.markdown("""<b>Filter reviews: </b> <br>
|
99 |
+
You can filter the dataset by Date, State or Rating""", unsafe_allow_html=True)
|
100 |
+
|
101 |
+
select_image_box = st.radio("",
|
102 |
+
["Filter by Date (Year)", "Filter by State", "Filter by Rating", "No filters"],
|
103 |
+
index=3, label_visibility="collapsed")
|
104 |
+
|
105 |
+
if select_image_box == "Filter by Date (Year)":
|
106 |
+
selected_date = st.multiselect("Date (Year)", reviews_df["Year"].unique(), default=reviews_df["Year"].unique()[0])
|
107 |
+
reviews_df = reviews_df.loc[reviews_df["Year"].isin(selected_date)]
|
108 |
+
|
109 |
+
if select_image_box == "Filter by State":
|
110 |
+
selected_state = st.multiselect("State", reviews_df["State"].unique(), default=reviews_df["State"].unique()[0])
|
111 |
+
reviews_df = reviews_df.loc[reviews_df["State"].isin(selected_state)]
|
112 |
+
|
113 |
+
if select_image_box == "Filter by Rating":
|
114 |
+
selected_rating = st.multiselect("Rating", sorted(list(reviews_df["Rating"].dropna().unique())),
|
115 |
+
default = sorted(list(reviews_df["Rating"].dropna().unique()))[0])
|
116 |
+
reviews_df = reviews_df.loc[reviews_df["Rating"].isin(selected_rating)]
|
117 |
+
|
118 |
+
if select_image_box == "No filters":
|
119 |
+
pass
|
120 |
+
|
121 |
+
#st.slider()
|
122 |
+
run_model1 = st.button("**Run the model**", type="primary", key="tab1")
|
123 |
+
st.info("The model has already been trained in this use case.")
|
124 |
+
|
125 |
+
with col2:
|
126 |
+
# VIEW DATA
|
127 |
+
st.markdown("""<b>View the reviews:</b> <br>
|
128 |
+
The dataset contains the location (State), date, rating, text and images (if provided) for each review.""",
|
129 |
+
unsafe_allow_html=True)
|
130 |
+
|
131 |
+
st.data_editor(
|
132 |
+
reviews_df.drop(columns=["Year"]),
|
133 |
+
column_config={"Image 1": st.column_config.ImageColumn("Image 1"),
|
134 |
+
"Image 2": st.column_config.ImageColumn("Image 2")},
|
135 |
+
hide_index=True)
|
136 |
+
|
137 |
|
138 |
+
######### SHOW RESULTS ########
|
139 |
+
if run_model1:
|
140 |
+
with st.spinner('Wait for it...'):
|
141 |
+
df_results = load_data_pickle(path_sa,"reviews_results.pkl")
|
142 |
+
df_results.reset_index(drop=True, inplace=True)
|
143 |
|
144 |
+
index_row = np.array(reviews_df.index)
|
145 |
+
df_results = df_results.iloc[index_row].reset_index(drop=True)
|
146 |
+
df_results["Review"] = reviews_df["Review"]
|
147 |
+
st.markdown(" ")
|
|
|
148 |
|
149 |
+
st.markdown("#### See the results ☑️")
|
150 |
+
tab1, tab2, tab3 = st.tabs(["All results", "Results per state", "Results per year"])
|
151 |
+
|
152 |
+
with tab1: # Overall results (tab_1)
|
153 |
+
# get results df
|
154 |
+
df_results_tab1 = df_results[["ID","Review","Rating","Negative","Neutral","Positive","Result"]]
|
155 |
+
|
156 |
+
# warning message
|
157 |
+
df_warning = df_results_tab1["Result"].value_counts().to_frame().reset_index()
|
158 |
+
df_warning["Percentage"] = (100*df_warning["count"]/df_warning["count"].sum()).round(2)
|
159 |
+
|
160 |
+
perct_negative = df_warning.loc[df_warning["Result"]=="Negative","Percentage"].to_numpy()[0]
|
161 |
+
if perct_negative > 50:
|
162 |
+
st.error(f"**Negative reviews alert** ⚠️: The proportion of negative reviews is {perct_negative}% !")
|
163 |
+
|
164 |
+
# show dataframe results
|
165 |
+
st.data_editor(
|
166 |
+
df_results_tab1, #.loc[df_results_tab1["Customer ID"].isin(filter_customers)],
|
167 |
+
column_config={
|
168 |
+
"Negative": st.column_config.ProgressColumn(
|
169 |
+
"Negative 👎",
|
170 |
+
help="Negative score of the review",
|
171 |
+
format="%d%%",
|
172 |
+
min_value=0,
|
173 |
+
max_value=100),
|
174 |
+
"Neutral": st.column_config.ProgressColumn(
|
175 |
+
"Neutral ✋",
|
176 |
+
help="Neutral score of the review",
|
177 |
+
format="%d%%",
|
178 |
+
min_value=0,
|
179 |
+
max_value=100),
|
180 |
+
"Positive": st.column_config.ProgressColumn(
|
181 |
+
"Positive 👍",
|
182 |
+
help="Positive score of the review",
|
183 |
+
format="%d%%",
|
184 |
+
min_value=0,
|
185 |
+
max_value=100)},
|
186 |
+
hide_index=True,
|
187 |
+
)
|
188 |
+
|
189 |
+
with tab2: # Results by state (tab_1)
|
190 |
+
avg_state = df_results[["State","Negative","Neutral","Positive"]].groupby(["State"]).mean().round()
|
191 |
+
avg_state = avg_state.reset_index().melt(id_vars="State", var_name="Sentiment", value_name="Score (%)")
|
192 |
+
|
193 |
+
chart_state = alt.Chart(avg_state, title="Review polarity per state").mark_bar().encode(
|
194 |
+
x=alt.X('Sentiment', axis=alt.Axis(title=None, labels=False, ticks=False)),
|
195 |
+
y=alt.Y('Score (%)', axis=alt.Axis(grid=False)),
|
196 |
+
color='Sentiment',
|
197 |
+
column=alt.Column('State', header=alt.Header(title=None, labelOrient='bottom'))
|
198 |
+
).configure_view(
|
199 |
+
stroke='transparent'
|
200 |
+
).interactive()
|
201 |
+
|
202 |
+
st.markdown(" ")
|
203 |
+
st.altair_chart(chart_state)
|
204 |
+
|
205 |
+
|
206 |
+
with tab3: # Results by year (tab_1)
|
207 |
+
avg_year = df_results[["Year","Negative","Neutral","Positive"]]
|
208 |
+
#avg_year["Year"] = avg_year["Year"].astype(str)
|
209 |
+
avg_year = avg_year.groupby(["Year"]).mean().round()
|
210 |
+
avg_year = avg_year.reset_index().melt(id_vars="Year", var_name="Sentiment", value_name="Score (%)")
|
211 |
+
|
212 |
+
chart_year = alt.Chart(avg_year, title="Evolution of review polarity").mark_area(opacity=0.5).encode(
|
213 |
+
x='Year',
|
214 |
+
y='Score (%)',
|
215 |
+
color='Sentiment',
|
216 |
+
).interactive()
|
217 |
+
|
218 |
+
st.markdown(" ")
|
219 |
+
st.altair_chart(chart_year, use_container_width=True)
|
220 |
+
|
221 |
+
# else:
|
222 |
+
# st.warning("You must select at least one review to run the model.")
|
223 |
+
|
224 |
+
|
225 |
+
#### WRITE YOUR OWN REVIEW #####""
|
226 |
+
with tab2_:
|
227 |
+
st.markdown("**Write your own review**")
|
228 |
+
|
229 |
+
txt_review = st.text_area(
|
230 |
+
"Write your review",
|
231 |
+
"I recently visited a local Starbucks, and unfortunately, my experience was far from satisfactory. "
|
232 |
+
"From the moment I stepped in, the atmosphere felt chaotic and disorganized. "
|
233 |
+
"The staff appeared overwhelmed, leading to a significant delay in receiving my order. "
|
234 |
+
"The quality of my drink further added to my disappointment. "
|
235 |
+
"The coffee tasted burnt, as if it had been sitting on the burner for far too long.",
|
236 |
+
label_visibility="collapsed"
|
237 |
+
)
|
238 |
|
239 |
+
run_model2 = st.button("**Run the model**", type="primary", key="tab2")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
240 |
|
241 |
+
if run_model2:
|
242 |
+
with st.spinner('Wait for it...'):
|
243 |
+
#sentiment_analyzer = create_analyzer(task="sentiment", lang="en")
|
244 |
+
# Load model with cache
|
245 |
+
sentiment_analyzer = load_sa_model()
|
246 |
+
q = sentiment_analyzer.predict(txt_review)
|
247 |
+
|
248 |
+
df_review_user = pd.DataFrame({"Polarity":["Positive","Neutral","Negative"],
|
249 |
+
"Score":[q.probas['POS'], q.probas['NEU'], q.probas['NEG']]})
|
250 |
+
|
251 |
+
st.markdown(" ")
|
252 |
+
st.info(f"""Your review was **{int(q.probas['POS']*100)}%** positive, **{int(q.probas['NEU']*100)}%** neutral
|
253 |
+
and **{int(q.probas['NEG']*100)}%** negative.""")
|
254 |
|
255 |
+
fig = px.bar(df_review_user, x='Score', y='Polarity', color="Polarity", title='Sentiment analysis results', orientation="h")
|
256 |
+
st.plotly_chart(fig, use_container_width=True)
|
257 |
|
258 |
|
pages/supervised_unsupervised_page.py
CHANGED
@@ -6,9 +6,8 @@ import numpy as np
|
|
6 |
import plotly.express as px
|
7 |
from PIL import Image
|
8 |
|
9 |
-
from utils import load_data_pickle, load_model_pickle
|
10 |
from st_pages import add_indentation
|
11 |
-
from annotated_text import annotated_text
|
12 |
|
13 |
#####################################################################################
|
14 |
# PAGE CONFIG
|
@@ -23,763 +22,763 @@ st.set_page_config(layout="wide")
|
|
23 |
# INTRO
|
24 |
#####################################################################################
|
25 |
|
26 |
-
|
27 |
-
st.markdown("# Supervised vs Unsupervised Learning 🔍")
|
28 |
|
29 |
-
st.info("""Data Science models are often split into two categories: **Supervised** and **Unsupervised Learning**.
|
30 |
-
|
31 |
-
|
32 |
|
33 |
-
st.markdown(" ")
|
34 |
-
#st.markdown("## What are the differences between both ?")
|
35 |
|
36 |
-
col1, col2 = st.columns(2, gap="large")
|
37 |
|
38 |
-
with col1:
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
- A model is first **trained** to make predictions using labeled data, which doesn't contain the desired output.
|
43 |
- The trained model can then be used to **predict values** for new data.
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
|
48 |
-
with col2:
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
- They can be useful for applications where the goal is to discover **unknown groupings** in the data.
|
53 |
- They are also used to identify unusual patterns or **outliers**.
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
|
59 |
-
st.markdown(" ")
|
60 |
|
61 |
-
learning_type = st.selectbox("**Select an AI task**",
|
62 |
-
|
63 |
-
|
64 |
|
65 |
|
66 |
|
67 |
|
68 |
|
69 |
-
#######################################################################################################################
|
70 |
-
# SUPERVISED LEARNING
|
71 |
-
#######################################################################################################################
|
72 |
|
73 |
|
74 |
-
if learning_type == "Supervised Learning":
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
st.markdown(" ")
|
80 |
-
|
81 |
-
# st.divider()
|
82 |
-
|
83 |
-
path_data_supervised = r"data/classification"
|
84 |
-
path_pretrained_supervised = r"pretrained_models/supervised_learning"
|
85 |
-
|
86 |
-
################################# CREDIT SCORE ######################################
|
87 |
-
|
88 |
-
if sl_usecase == "Credit score classification 💯":
|
89 |
-
|
90 |
-
path_credit = os.path.join(path_data_supervised,"credit_score")
|
91 |
-
|
92 |
-
## Description of the use case
|
93 |
-
st.divider()
|
94 |
-
st.markdown("# Credit score classification 💯")
|
95 |
-
st.info("""**Classification models** are supervised learning models whose goal is to categorize data into predefined categories.
|
96 |
-
As opposed to unsupervised learning models, these categories are known beforehand.
|
97 |
-
Other types of supervised learning models include Regression models, which learn how to predict numerical values, instead of a set number of categories.""")
|
98 |
|
99 |
-
st.markdown("In this use case, we will build a **credit score classification model** which predicts whether a client has a 'Bad', 'Standard', or 'Good' credit score.")
|
100 |
st.markdown(" ")
|
101 |
-
|
102 |
-
_, col, _ = st.columns([0.25,0.5,0.25])
|
103 |
-
with col:
|
104 |
-
st.image("images/credit_score.jpg")
|
105 |
-
|
106 |
-
## Learn about the data
|
107 |
-
st.markdown("#### About the data 📋")
|
108 |
-
st.markdown("""To train the credit classification model, you were provided a **labeled** database with 7600 clients and containing bank and credit-related client information. <br>
|
109 |
-
This dataset is 'labeled' since it contains information on what we are trying to predict, which is the **Credit_Score** variable.""",
|
110 |
-
unsafe_allow_html=True)
|
111 |
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
# View data
|
121 |
-
see_data = st.checkbox('**See the data**', key="credit_score\data")
|
122 |
-
if see_data:
|
123 |
-
st.warning("The data of only the first 30 clients are shown.")
|
124 |
-
st.dataframe(credit_train.head(30).reset_index(drop=True))
|
125 |
-
|
126 |
-
learn_data = st.checkbox('**Learn more about the data**', key="credit_score_var")
|
127 |
-
if learn_data:
|
128 |
-
st.markdown("""
|
129 |
-
- **Age**: The client's age
|
130 |
-
- **Occupation**: The client's occupation/job
|
131 |
-
- **Credit_Mix**: Score for the different type of credit accounts a client has (mortgages, loans, credit cards, ...)
|
132 |
-
- **Payment_of_Min_Amount**: Whether the client is making the minimum required payments on their credit accounts (Yes, No, NM:Not mentioned)
|
133 |
-
- **Annual_Income**: The client's annual income
|
134 |
-
- **Num_Bank_Accounts**: Number of bank accounts opened
|
135 |
-
- **Num_Credit_Card**: Number of credit cards owned
|
136 |
-
- **Interest_Rate**: The client's average interest rate
|
137 |
-
- **Num_of_Loan**: Number of loans of the client
|
138 |
-
- **Changed_Credit_Limit**: Whether a client changed his credit limit once or not (Yes, No) -
|
139 |
-
- **Outstanding Debt**: A client's outstanding debt
|
140 |
-
- **Credit_History_Age**: The length of a client's credit history (in months)
|
141 |
-
""")
|
142 |
-
|
143 |
-
st.markdown(" ")
|
144 |
-
st.markdown(" ")
|
145 |
-
|
146 |
-
## Train the algorithm
|
147 |
-
st.markdown("#### Train the algorithm ⚙️")
|
148 |
-
st.info("""**Training** an AI model means feeding it data that contains multiple examples of clients with their credit scores.
|
149 |
-
Using the labeled data provided, the model will **learn relationships** between a client's credit score and the other bank/credit-related variables provided.
|
150 |
-
Using these learned relationships, the model will then try to make **accurate predictions**.""")
|
151 |
|
152 |
-
|
153 |
-
|
154 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
155 |
|
156 |
-
|
157 |
-
|
|
|
158 |
|
159 |
-
|
160 |
-
st.
|
161 |
-
|
162 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
163 |
|
164 |
-
|
|
|
|
|
|
|
|
|
165 |
|
166 |
-
|
167 |
-
st.
|
168 |
-
|
169 |
-
st.markdown(" ")
|
170 |
-
st.markdown(" ")
|
171 |
-
time.sleep(2)
|
172 |
-
st.markdown("#### See the results ☑️")
|
173 |
-
tab1, tab2 = st.tabs(["Performance", "Explainability"])
|
174 |
-
|
175 |
-
######## MODEL PERFORMANCE
|
176 |
-
with tab1:
|
177 |
-
results_train = load_data_pickle(path_credit,"credit_score_cm_train")
|
178 |
-
results_train = results_train.to_numpy()
|
179 |
-
accuracy = np.round(results_train.diagonal()*100)
|
180 |
-
df_accuracy = pd.DataFrame({"Credit Score":["Good","Poor","Standard"],
|
181 |
-
"Accuracy":accuracy})
|
182 |
|
183 |
-
|
184 |
-
|
185 |
-
In this use case, the performance of the credit score model was measured by comparing clients' true credit scores with the scores predicted by the trained model.""")
|
186 |
|
187 |
-
|
188 |
-
|
|
|
|
|
189 |
|
190 |
-
|
191 |
-
<i>This is crucial as to understand whether the model is consistant in its performance, or whether it has trouble distinguishing between two kinds of credit score.</i>""",
|
192 |
-
unsafe_allow_html=True)
|
193 |
|
|
|
|
|
|
|
194 |
st.markdown(" ")
|
195 |
-
|
196 |
-
st.markdown("""**Interpretation**: <br>
|
197 |
-
Our model's is overall quite accurate in predicting all types of credit scores with an accuracy that is above 85% for each.
|
198 |
-
We do note that is slighly more accuracte in predicting a good credit score (92%) and less for a standard credit score (86%).
|
199 |
-
This can be due to the model having a harder time distinguishing between clients with a standard credit score and other more "extreme" credit scores (Good, Bad).
|
200 |
-
""", unsafe_allow_html=True)
|
201 |
-
|
202 |
-
##### MODEL EXPLAINABILITY
|
203 |
-
with tab2:
|
204 |
st.markdown(" ")
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
# Create feature importance dataframe
|
209 |
-
df_var_importance = pd.DataFrame({"variable":credit_test_pp.columns,
|
210 |
-
"score":credit_model.feature_importances_})
|
211 |
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
216 |
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
-
|
221 |
-
|
222 |
-
|
223 |
-
|
224 |
-
|
225 |
-
|
226 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
227 |
|
228 |
-
|
229 |
-
|
230 |
-
|
231 |
-
""", unsafe_allow_html=True)
|
232 |
-
|
233 |
|
234 |
-
|
235 |
-
|
236 |
-
|
237 |
-
|
238 |
-
|
239 |
-
st.markdown("#### Predict credit score 🆕")
|
240 |
-
st.info("You can only predict the credit score of new clients once the **model has been trained.**")
|
241 |
-
st.markdown(" ")
|
242 |
-
|
243 |
-
col1, col2 = st.columns([0.25,0.75], gap="medium")
|
244 |
-
|
245 |
-
credit_test = load_data_pickle(path_credit,"credit_score_test_raw.pkl")
|
246 |
-
credit_test.reset_index(drop=True, inplace=True)
|
247 |
-
credit_test.insert(0, "Client ID", [f"{i}" for i in range(credit_test.shape[0])])
|
248 |
-
credit_test = credit_test.loc[credit_test["Num_Bank_Accounts"]>0]
|
249 |
-
#credit_test.drop(columns=["Credit_Score"], inplace=True)
|
250 |
-
|
251 |
-
with col1:
|
252 |
-
st.markdown("""<b>Filter the data</b> <br>
|
253 |
-
You can select clients based on their *Age*, *Annual income* or *Oustanding Debt*.""",
|
254 |
-
unsafe_allow_html=True)
|
255 |
-
|
256 |
-
select_image_box = st.radio(" ",
|
257 |
-
["Filter by Age", "Filter by Income", "Filter by Outstanding Debt", "No filters"],
|
258 |
-
label_visibility="collapsed")
|
259 |
|
260 |
-
|
261 |
-
|
262 |
-
|
263 |
-
|
264 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
265 |
|
266 |
-
|
267 |
-
|
268 |
-
|
269 |
-
(7000, 100000), label_visibility="collapsed", key="income")
|
270 |
-
credit_test = credit_test.loc[credit_test["Annual_Income"].between(min_income, max_income)]
|
271 |
|
272 |
-
|
273 |
-
|
274 |
-
|
275 |
-
|
|
|
276 |
|
277 |
-
|
278 |
-
|
|
|
|
|
|
|
279 |
|
280 |
-
|
281 |
-
|
282 |
-
|
283 |
-
""
|
284 |
-
warning_threshold = st.slider('Select a value', min_value=20, max_value=100, step=10,
|
285 |
-
label_visibility="collapsed", key="warning")
|
286 |
|
287 |
-
|
288 |
-
|
289 |
-
|
290 |
|
291 |
-
|
292 |
-
|
293 |
-
|
294 |
-
|
295 |
-
|
296 |
-
|
297 |
|
298 |
-
|
299 |
-
|
300 |
-
X_test = credit_test_pp.iloc[credit_test.index,:]
|
301 |
-
predictions = credit_model.predict(X_test)
|
302 |
-
|
303 |
-
df_results_pred = credit_test.copy()
|
304 |
-
df_results_pred["Credit Score"] = predictions
|
305 |
-
df_mean_pred = df_results_pred["Credit Score"].value_counts().to_frame().reset_index()
|
306 |
-
df_mean_pred.columns = ["Credit Score", "Proportion"]
|
307 |
-
df_mean_pred["Proportion"] = (100*df_mean_pred["Proportion"]/df_results_pred.shape[0]).round()
|
308 |
|
309 |
-
perct_bad_score = df_mean_pred.loc[df_mean_pred["Credit Score"]=="Poor"]["Proportion"].to_numpy()
|
310 |
|
311 |
-
|
312 |
-
|
|
|
|
|
|
|
|
|
313 |
|
314 |
-
|
315 |
-
|
316 |
-
|
317 |
-
|
318 |
-
#title="Proportion of credit scores")
|
319 |
-
st.plotly_chart(fig, use_container_width=True)
|
320 |
-
|
321 |
-
with col2:
|
322 |
-
df_show_results = df_results_pred[["Credit Score","Client ID"] + [col for col in df_results_pred.columns if col not in ["Client ID","Credit Score"]]]
|
323 |
-
columns_float = df_show_results.select_dtypes(include="float").columns
|
324 |
-
df_show_results[columns_float] = df_show_results[columns_float].astype(int)
|
325 |
-
|
326 |
-
def highlight_score(val):
|
327 |
-
if val == "Good":
|
328 |
-
color = 'red'
|
329 |
-
if val == 'Standard':
|
330 |
-
color= "cornflowerblue"
|
331 |
-
if val == "Poor":
|
332 |
-
color = 'blue'
|
333 |
-
return f'color: {color}'
|
334 |
|
335 |
-
|
|
|
|
|
|
|
|
|
336 |
|
337 |
-
|
338 |
-
st.dataframe(df_show_results_color)
|
339 |
-
|
340 |
-
else:
|
341 |
-
st.error("You have to train the credit score model first.")
|
342 |
-
|
343 |
|
|
|
|
|
344 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
345 |
|
|
|
|
|
346 |
|
347 |
|
348 |
-
################################# CUSTOMER CHURN #####################################
|
349 |
-
|
350 |
-
elif sl_usecase == "Customer churn prediction ❌":
|
351 |
-
#st.warning("This page is under construction")
|
352 |
-
path_churn = r"data/classification/churn"
|
353 |
-
|
354 |
-
## Description of the use case
|
355 |
-
st.divider()
|
356 |
-
st.markdown("# Customer churn prediction ❌")
|
357 |
-
|
358 |
-
st.info("""**Classification models** are supervised learning models whose goal is to categorize data into predefined categories.
|
359 |
-
As opposed to unsupervised learning models, these categories are known beforehand.
|
360 |
-
Other types of supervised learning models include Regression models, which learn how to predict numerical values, instead of a set number of categories.""")
|
361 |
-
|
362 |
-
st.markdown("For this use case, we will build a **customer churn classification model** that can predict whether a person will stop being a customer using historical data.")
|
363 |
-
|
364 |
-
st.markdown(" ")
|
365 |
|
366 |
-
## Load data
|
367 |
-
churn_data = load_data_pickle(path_churn, "churn_train_raw.pkl")
|
368 |
-
|
369 |
-
_, col, _ = st.columns([0.1,0.8,0.1])
|
370 |
-
with col:
|
371 |
-
st.image("images/customer-churn.png", use_column_width=True)
|
372 |
-
|
373 |
-
st.markdown(" ")
|
374 |
|
375 |
-
## Learn about the data
|
376 |
-
st.markdown("#### About the data 📋")
|
377 |
-
st.markdown("""To train the customer churn model, you were provided a **labeled** database with around 7000 clients of a telecommunications company. <br>
|
378 |
-
The data contains information on which services the customer has signed for, account information as well as whether the customer churned or not (our label here).""",
|
379 |
-
unsafe_allow_html=True)
|
380 |
-
# st.markdown("This dataset is 'labeled' since it contains information on what we are trying to predict, which is the **Churn** variable.")
|
381 |
-
st.info("**Note**: The variables that had two possible values (Yes or No) where transformed into binary variables (0 or 1) with 0 being 'No' and 1 being 'Yes'.")
|
382 |
|
383 |
-
see_data = st.checkbox('**See the data**', key="churn-data")
|
384 |
|
385 |
-
|
386 |
-
st.warning("You can only view the first 30 customers in this section.")
|
387 |
-
churn_data = load_data_pickle(path_churn, "churn_train_raw.pkl")
|
388 |
-
st.dataframe(churn_data)
|
389 |
-
|
390 |
-
learn_data = st.checkbox('**Learn more about the data**', key="churn-var")
|
391 |
-
if learn_data:
|
392 |
-
st.markdown("""
|
393 |
-
- **SeniorCitizen**: Whether the customer is a senior citizen or not (1, 0)
|
394 |
-
- **Partner**: Whether the customer has a partner or not (Yes, No)
|
395 |
-
- **Dependents**: Whether the customer has dependents or not (Yes, No)
|
396 |
-
- **tenure**: Number of months the customer has stayed with the company
|
397 |
-
- **PhoneService**: Whether the customer has a phone service or not (Yes, No)
|
398 |
-
- **MultipleLines**: Whether the customer has multiple lines or not (Yes, No)
|
399 |
-
- **InternetService**: Customer’s internet service provider (DSL, Fiber optic, No)
|
400 |
-
- **OnlineSecurity**: Whether the customer has online security or not (Yes, No)
|
401 |
-
- **OnlineBackup**: Whether the customer has online backup or not (Yes, No)
|
402 |
-
- **DeviceProtection**: Whether the customer has device protection or not (Yes, No)
|
403 |
-
- **TechSupport**: Whether the customer has tech support or not (Yes, No)
|
404 |
-
- **StreamingTV**: Whether the customer has streaming TV or not (Yes, No)
|
405 |
-
- **StreamingMovies**: Whether the customer has streaming movies or not (Yes, No)
|
406 |
-
- **Contract**: The contract term of the customer (Month-to-month, One year, Two year)
|
407 |
-
- **PaperlessBilling**: Whether the customer has paperless billing or not (Yes, No)
|
408 |
-
- **PaymentMethod**: The customer’s payment method (Electronic check, Mailed check, Bank transfer (automatic), Credit card (automatic))
|
409 |
-
- **MonthlyCharges**: The amount charged to the customer monthly
|
410 |
-
- **TotalCharges**: The total amount charged to the customer
|
411 |
-
- <span style="color: red;"> **Churn** (the variable we want to predict): Whether the customer churned or not (Yes or No) </span>
|
412 |
-
""", unsafe_allow_html=True)
|
413 |
|
414 |
-
|
415 |
-
|
|
|
416 |
|
|
|
|
|
|
|
417 |
|
418 |
-
|
419 |
-
|
420 |
-
|
421 |
-
It helps practitioners understand the structure, patterns, and characteristics of the data they are working with.
|
422 |
-
For this use case, we will perform EDA by analyzing the **proportion of clients who have churned or not** based on the dataset's other variables.""")
|
423 |
-
|
424 |
-
st.info("**Note**: EDA is usually preformed before model training as it helps inform decisions made by the model throughout the modeling process.")
|
425 |
|
426 |
-
|
427 |
-
if see_EDA:
|
428 |
-
st.markdown(" ")
|
429 |
|
430 |
-
# Show EDA image
|
431 |
-
st.markdown("""Exploratory Data Analysis has been preformed between the predicted variable `Churn` with 15 other variables present in the dataset. <br>
|
432 |
-
Each graphs shows the proportion of churned and not churned customer based on the variable's possible values.""", unsafe_allow_html=True)
|
433 |
st.markdown(" ")
|
434 |
|
435 |
-
|
436 |
-
|
437 |
|
438 |
-
st.
|
|
|
|
|
439 |
|
440 |
-
|
441 |
-
st.markdown("""**Interpretation** <br>
|
442 |
-
For variables such as `Contract`, `PaperlessBilling`, `PaymentMethod` and `InternetService`, we can see a significant difference in the proportion of churned customers based on the variable's value.
|
443 |
-
In the *Contract* graph, clients with a 'Month-to-Month' tend to churn more often than those with a longer contract.
|
444 |
-
In the *InternetService* graph, clients with a 'Fiber optic' service are more likely to churn than those with DSL or no internet service. """, unsafe_allow_html=True)
|
445 |
-
|
446 |
-
st.info("""**Note**: Performing EDA can give us an indication as to which variables might be more significant in the customer churn model.
|
447 |
-
It can be a valuable tool to study the relationship between two variables but can sometimes be too simplistic. Some relationships might be top complex to be seen through EDA.""")
|
448 |
-
|
449 |
-
|
450 |
-
|
451 |
-
st.markdown(""" """)
|
452 |
-
st.markdown(""" """)
|
453 |
|
454 |
-
|
455 |
-
|
456 |
-
|
457 |
-
|
458 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
459 |
|
|
|
|
|
460 |
|
461 |
-
if 'model_train_churn' not in st.session_state:
|
462 |
-
st.session_state['model_train_churn'] = False
|
463 |
-
|
464 |
-
if st.session_state.model_train_churn:
|
465 |
-
st.write("The model has already been trained.")
|
466 |
-
else:
|
467 |
-
st.write("The model hasn't been trained yet")
|
468 |
|
469 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
470 |
|
|
|
|
|
|
|
471 |
|
472 |
-
|
473 |
-
|
474 |
-
|
475 |
st.markdown(" ")
|
|
|
|
|
|
|
|
|
476 |
st.markdown(" ")
|
477 |
-
time.sleep(2)
|
478 |
-
st.markdown("#### See the results ☑️")
|
479 |
-
tab1, tab2 = st.tabs(["Performance", "Explainability"])
|
480 |
|
481 |
-
|
482 |
-
|
483 |
-
|
484 |
-
|
485 |
-
|
486 |
-
|
487 |
-
|
|
|
488 |
|
489 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
490 |
|
491 |
-
st.markdown(" ")
|
492 |
-
st.info("""**Note**: Evaluating a model's performance helps provide a quantitative measure of the model's ability to make accurate decisions.
|
493 |
-
In this use case, the performance of the customer churn model was measured by comparing the clients' churn variables with the value predicted by the trained model.""")
|
494 |
|
495 |
-
|
496 |
-
|
497 |
-
#fig.update_traces(textposition='inside', textfont=dict(color='white'))
|
498 |
-
st.plotly_chart(fig, use_container_width=True)
|
499 |
|
500 |
-
|
501 |
-
|
502 |
-
|
|
|
503 |
|
504 |
-
|
505 |
|
506 |
-
|
507 |
-
|
508 |
-
|
509 |
-
|
510 |
-
|
511 |
-
##### MODEL EXPLAINABILITY
|
512 |
-
with tab2:
|
513 |
st.markdown(" ")
|
514 |
-
st.
|
515 |
-
|
516 |
-
|
517 |
-
|
518 |
-
df_var_importance = load_data_pickle(path_churn, "churn_feature_importance.pkl")
|
519 |
-
df_var_importance.rename({"importance":"score"}, axis=1, inplace=True)
|
520 |
-
df_var_importance.sort_values(by=["score"], inplace=True)
|
521 |
-
df_var_importance["score"] = df_var_importance["score"].round(3)
|
522 |
-
|
523 |
-
# Feature importance plot with plotly
|
524 |
-
fig = px.bar(df_var_importance, x='score', y='variable', color="score", orientation="h", title="Model explainability")
|
525 |
-
st.plotly_chart(fig, use_container_width=True)
|
526 |
-
|
527 |
-
st.markdown("""<b>Interpretation</b> <br>
|
528 |
-
The client's tenure, amount of Monthly and Total Charges, as well as the type of Contract had the most impact on the model's churn predictions.
|
529 |
-
On the other hand, whether the client is subscribed to a streaming platform, he is covered by device protection or he has or not phone service had a very contribution in the final predictions.
|
530 |
-
""", unsafe_allow_html=True)
|
531 |
|
532 |
-
|
533 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
534 |
|
535 |
-
|
536 |
-
|
|
|
|
|
537 |
|
538 |
-
|
|
|
|
|
539 |
|
540 |
-
|
541 |
-
|
542 |
-
churn_test = load_data_pickle(path_churn,"churn_test_raw.pkl")
|
543 |
-
churn_test.reset_index(drop=True, inplace=True)
|
544 |
-
churn_test.insert(0, "Client ID", [f"{i}" for i in range(churn_test.shape[0])])
|
545 |
|
546 |
-
|
547 |
-
|
548 |
-
|
549 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
550 |
|
551 |
-
|
552 |
-
|
553 |
-
label_visibility="collapsed")
|
554 |
|
555 |
-
|
556 |
-
st.markdown(" ")
|
557 |
-
min_tenure, max_tenure = st.slider('Select a range', churn_test["tenure"].astype(int).min(), churn_test["tenure"].astype(int).max(), (1,50),
|
558 |
-
key="tenure", label_visibility="collapsed")
|
559 |
-
churn_test = churn_test.loc[churn_test["tenure"].between(min_tenure,max_tenure)]
|
560 |
|
561 |
-
|
562 |
-
|
563 |
-
|
564 |
-
|
565 |
-
|
566 |
|
567 |
-
|
568 |
-
|
569 |
-
|
570 |
-
|
571 |
|
572 |
-
|
573 |
-
|
|
|
574 |
|
575 |
-
|
576 |
-
|
577 |
-
|
578 |
-
|
579 |
-
|
580 |
-
label_visibility="collapsed", key="warning")
|
581 |
-
|
582 |
-
st.markdown(" ")
|
583 |
-
st.write("The threshold is at", warning_threshold, "%")
|
584 |
-
|
585 |
|
586 |
-
|
587 |
-
|
588 |
-
|
|
|
|
|
589 |
|
|
|
|
|
|
|
|
|
590 |
|
591 |
-
|
592 |
-
|
593 |
-
st.markdown(" ")
|
594 |
|
595 |
-
|
596 |
-
|
597 |
-
|
598 |
-
|
599 |
-
|
600 |
-
|
601 |
|
602 |
-
|
603 |
-
|
604 |
-
predictions = ["No" if x==0 else "Yes" for x in predictions]
|
605 |
|
606 |
-
df_results_pred = churn_test.copy()
|
607 |
-
df_results_pred["Churn"] = predictions
|
608 |
-
df_mean_pred = df_results_pred["Churn"].value_counts().to_frame().reset_index()
|
609 |
-
df_mean_pred.columns = ["Churn", "Proportion"]
|
610 |
-
df_mean_pred["Proportion"] = (100*df_mean_pred["Proportion"]/df_results_pred.shape[0]).round()
|
611 |
|
612 |
-
|
|
|
|
|
613 |
|
614 |
-
if perct_churned >= warning_threshold:
|
615 |
-
st.error(f"The proportion of clients that have churned is above {warning_threshold}% (at {perct_churned[0]}%)⚠️")
|
616 |
|
617 |
-
|
|
|
|
|
618 |
|
619 |
-
|
620 |
-
|
621 |
-
st.markdown("**Proporition of predicted churn**")
|
622 |
-
fig = px.pie(df_mean_pred, values='Proportion', names='Churn', color="Churn",
|
623 |
-
color_discrete_map={'No':'royalblue', 'Yes':'red'})
|
624 |
-
st.plotly_chart(fig, use_container_width=True)
|
625 |
-
|
626 |
-
with col2:
|
627 |
-
df_show_results = df_results_pred[["Churn","Client ID"] + [col for col in df_results_pred.columns if col not in ["Client ID","Churn"]]]
|
628 |
-
columns_float = df_show_results.select_dtypes(include="float").columns
|
629 |
-
df_show_results[columns_float] = df_show_results[columns_float].astype(int)
|
630 |
-
|
631 |
-
def highlight_score(val):
|
632 |
-
if val == "No":
|
633 |
-
color = 'royalblue'
|
634 |
-
if val == 'Yes':
|
635 |
-
color= "red"
|
636 |
-
return f'color: {color}'
|
637 |
-
|
638 |
-
df_show_results_color = df_show_results.style.applymap(highlight_score, subset=['Churn'])
|
639 |
|
640 |
-
|
641 |
-
|
|
|
642 |
|
643 |
-
|
644 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
645 |
|
|
|
646 |
|
|
|
|
|
647 |
|
|
|
648 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
649 |
|
|
|
|
|
650 |
|
651 |
-
#######################################################################################################################
|
652 |
-
# UNSUPERVISED LEARNING
|
653 |
-
#######################################################################################################################
|
654 |
|
655 |
|
656 |
-
def markdown_general_info(df):
|
657 |
-
text = st.markdown(f"""
|
658 |
-
- **Age**: {int(np.round(df.Age))}
|
659 |
-
- **Yearly income**: {int(df.Income)} $
|
660 |
-
- **Number of kids**: {df.Kids}
|
661 |
-
- **Days of enrollment**: {int(np.round(df.Days_subscription))}
|
662 |
-
- **Web visits per month**: {df.WebVisitsMonth}
|
663 |
-
""")
|
664 |
-
return text
|
665 |
|
666 |
|
667 |
|
668 |
-
|
669 |
-
|
670 |
-
|
671 |
-
|
672 |
|
673 |
-
#################################### CUSTOMER SEGMENTATION ##################################
|
674 |
|
675 |
-
|
676 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
677 |
|
678 |
-
if usl_usecase == "Customer segmentation (clustering) 🧑🤝🧑":
|
679 |
|
680 |
-
# st.divider()
|
681 |
-
st.divider()
|
682 |
-
st.markdown("# Customer Segmentation (clustering) 🧑🤝🧑")
|
683 |
|
684 |
-
|
685 |
-
|
686 |
-
|
687 |
|
688 |
-
|
689 |
-
|
690 |
-
|
691 |
-
|
692 |
-
|
693 |
-
|
694 |
-
|
695 |
-
|
696 |
-
|
697 |
-
st.markdown("""You are giving a database that contains information on around **2000 customers** of a mass-market retailer.
|
698 |
-
The database's contains **personal information** (age, income, number of kids...), as well as information on what types of products were purchased by the client, how long has he been enrolled as a client and where these purchases were made. """, unsafe_allow_html=True)
|
699 |
-
|
700 |
-
see_data = st.checkbox('**See the data**', key="dataframe")
|
701 |
-
|
702 |
-
if see_data:
|
703 |
-
customer_data = load_data_pickle(path_clustering, "clean_marketing.pkl")
|
704 |
-
st.dataframe(customer_data.head(10))
|
705 |
-
|
706 |
-
learn_data = st.checkbox('**Learn more about the variables**', key="variable")
|
707 |
-
|
708 |
-
if learn_data:
|
709 |
-
st.markdown("""
|
710 |
-
- **Age**: Customer's age
|
711 |
-
- **Income**: Customer's yearly household income
|
712 |
-
- **Kids**: Number of children/teenagers in customer's household
|
713 |
-
- **Days_subscription**: Number of days since a customer's enrollment with the company
|
714 |
-
- **Recency**: Number of days since customer's last purchase
|
715 |
-
- **Wines**: Proportion of money spent on wine in last 2 years
|
716 |
-
- **Fruits**: Proportion of money spent on fruits in last 2 years
|
717 |
-
- **MeatProducts**: Proportion of money spent on meat in last 2 years
|
718 |
-
- **FishProducts**: Proportion of money spent on fish in last 2 years
|
719 |
-
- **SweetProducts**: Proportion of money spent sweets in last 2 years
|
720 |
-
- **DealsPurchases**: Proportion of purchases made with a discount
|
721 |
-
- **WebPurchases**: Proportion of purchases made through the company’s website
|
722 |
-
- **CatalogPurchases**: Proporition of purchases made using a catalogue
|
723 |
-
- **StorePurchases**: Proportion of purchases made directly in stores
|
724 |
-
- **WebVisitsMonth**: Proportion of visits to company’s website in the last month""")
|
725 |
st.divider()
|
|
|
726 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
727 |
|
728 |
-
st.markdown(" ")
|
729 |
-
st.markdown(" ")
|
730 |
|
731 |
-
|
|
|
732 |
|
733 |
-
|
734 |
-
In our case, a data points represents a customer that will be assigned to an unknown group.""")
|
735 |
-
|
736 |
-
# st.markdown("""
|
737 |
-
# - The clustering algorithm used in this use case allows a specific number of groups to be identified, which isn't the case for all clustering models.""")
|
738 |
|
739 |
-
|
740 |
-
|
741 |
-
|
742 |
-
|
743 |
-
|
744 |
-
|
745 |
-
|
746 |
-
|
747 |
-
|
748 |
-
|
749 |
-
|
750 |
-
|
751 |
-
|
752 |
-
|
753 |
-
|
754 |
-
|
755 |
-
|
756 |
-
|
757 |
-
|
758 |
-
|
759 |
-
|
760 |
-
|
761 |
-
|
762 |
-
|
763 |
-
|
764 |
-
|
765 |
-
|
766 |
-
|
767 |
-
|
768 |
-
|
769 |
-
|
770 |
-
|
771 |
-
|
772 |
-
|
773 |
-
|
774 |
-
|
775 |
-
|
776 |
-
|
777 |
-
|
778 |
-
|
779 |
-
|
780 |
-
|
781 |
-
|
782 |
-
|
783 |
-
|
784 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
785 |
|
|
|
6 |
import plotly.express as px
|
7 |
from PIL import Image
|
8 |
|
9 |
+
from utils import load_data_pickle, load_model_pickle, check_password
|
10 |
from st_pages import add_indentation
|
|
|
11 |
|
12 |
#####################################################################################
|
13 |
# PAGE CONFIG
|
|
|
22 |
# INTRO
|
23 |
#####################################################################################
|
24 |
|
25 |
+
if check_password():
|
26 |
+
st.markdown("# Supervised vs Unsupervised Learning 🔍")
|
27 |
|
28 |
+
st.info("""Data Science models are often split into two categories: **Supervised** and **Unsupervised Learning**.
|
29 |
+
The goal of this page is to present these two kinds of Data Science models, as well as give you multiple use cases to try them with.
|
30 |
+
Note that other kinds of AI models exist such as Reinforcement Learning or Federated Learning, which we won't cover in this app.""")
|
31 |
|
32 |
+
st.markdown(" ")
|
33 |
+
#st.markdown("## What are the differences between both ?")
|
34 |
|
35 |
+
col1, col2 = st.columns(2, gap="large")
|
36 |
|
37 |
+
with col1:
|
38 |
+
st.markdown("## Supervised Learning")
|
39 |
+
st.markdown("""Supervised learning models are trained by learning from **labeled data**. <br>
|
40 |
+
Labeled data provides to the model the desired output, which it will then use to learn relevant patterns and make predictions.
|
41 |
- A model is first **trained** to make predictions using labeled data, which doesn't contain the desired output.
|
42 |
- The trained model can then be used to **predict values** for new data.
|
43 |
+
""", unsafe_allow_html=True)
|
44 |
+
st.markdown(" ")
|
45 |
+
st.image("images/supervised_learner.png", caption="An example of supervised learning")
|
46 |
|
47 |
+
with col2:
|
48 |
+
st.markdown("## Unsupervised Learning")
|
49 |
+
st.markdown("""Unsupervised learning models learn the data's inherent structure without any explicit guidance on what to look for.
|
50 |
+
The algorithm will identify any naturally occurring patterns in the dataset using **unlabeled data**.
|
51 |
- They can be useful for applications where the goal is to discover **unknown groupings** in the data.
|
52 |
- They are also used to identify unusual patterns or **outliers**.
|
53 |
+
""", unsafe_allow_html=True)
|
54 |
+
st.markdown(" ")
|
55 |
+
st.image("images/unsupervised_learning.png", caption="An example of unsupervised Learning",
|
56 |
+
use_column_width=True)
|
57 |
|
58 |
+
st.markdown(" ")
|
59 |
|
60 |
+
learning_type = st.selectbox("**Select an AI task**",
|
61 |
+
["Supervised Learning",
|
62 |
+
"Unsupervised Learning"])
|
63 |
|
64 |
|
65 |
|
66 |
|
67 |
|
68 |
+
#######################################################################################################################
|
69 |
+
# SUPERVISED LEARNING
|
70 |
+
#######################################################################################################################
|
71 |
|
72 |
|
73 |
+
if learning_type == "Supervised Learning":
|
74 |
+
sl_usecase = st.selectbox("**Choose a use case**",
|
75 |
+
["Credit score classification 💯",
|
76 |
+
"Customer churn prediction ❌"])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
77 |
|
|
|
78 |
st.markdown(" ")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
79 |
|
80 |
+
# st.divider()
|
81 |
+
|
82 |
+
path_data_supervised = r"data/classification"
|
83 |
+
path_pretrained_supervised = r"pretrained_models/supervised_learning"
|
84 |
+
|
85 |
+
################################# CREDIT SCORE ######################################
|
86 |
+
|
87 |
+
if sl_usecase == "Credit score classification 💯":
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
88 |
|
89 |
+
path_credit = os.path.join(path_data_supervised,"credit_score")
|
90 |
+
|
91 |
+
## Description of the use case
|
92 |
+
st.divider()
|
93 |
+
st.markdown("# Credit score classification 💯")
|
94 |
+
st.info("""**Classification models** are supervised learning models whose goal is to categorize data into predefined categories.
|
95 |
+
As opposed to unsupervised learning models, these categories are known beforehand.
|
96 |
+
Other types of supervised learning models include Regression models, which learn how to predict numerical values, instead of a set number of categories.""")
|
97 |
+
|
98 |
+
st.markdown("In this use case, we will build a **credit score classification model** which predicts whether a client has a 'Bad', 'Standard', or 'Good' credit score.")
|
99 |
+
st.markdown(" ")
|
100 |
|
101 |
+
_, col, _ = st.columns([0.25,0.5,0.25])
|
102 |
+
with col:
|
103 |
+
st.image("images/credit_score.jpg")
|
104 |
|
105 |
+
## Learn about the data
|
106 |
+
st.markdown("#### About the data 📋")
|
107 |
+
st.markdown("""To train the credit classification model, you were provided a **labeled** database with 7600 clients and containing bank and credit-related client information. <br>
|
108 |
+
This dataset is 'labeled' since it contains information on what we are trying to predict, which is the **Credit_Score** variable.""",
|
109 |
+
unsafe_allow_html=True)
|
110 |
+
|
111 |
+
## Load data
|
112 |
+
credit_train = load_data_pickle(path_credit, "credit_score_train_raw.pkl")
|
113 |
+
credit_test_pp = load_data_pickle(path_credit, "credit_score_test_pp.pkl")
|
114 |
+
labels = ["Good","Poor","Standard"]
|
115 |
+
|
116 |
+
## Load model
|
117 |
+
credit_model = load_model_pickle(path_pretrained_supervised,"credit_score_model.pkl")
|
118 |
+
|
119 |
+
# View data
|
120 |
+
see_data = st.checkbox('**See the data**', key="credit_score\data")
|
121 |
+
if see_data:
|
122 |
+
st.warning("The data of only the first 30 clients are shown.")
|
123 |
+
st.dataframe(credit_train.head(30).reset_index(drop=True))
|
124 |
+
|
125 |
+
learn_data = st.checkbox('**Learn more about the data**', key="credit_score_var")
|
126 |
+
if learn_data:
|
127 |
+
st.markdown("""
|
128 |
+
- **Age**: The client's age
|
129 |
+
- **Occupation**: The client's occupation/job
|
130 |
+
- **Credit_Mix**: Score for the different type of credit accounts a client has (mortgages, loans, credit cards, ...)
|
131 |
+
- **Payment_of_Min_Amount**: Whether the client is making the minimum required payments on their credit accounts (Yes, No, NM:Not mentioned)
|
132 |
+
- **Annual_Income**: The client's annual income
|
133 |
+
- **Num_Bank_Accounts**: Number of bank accounts opened
|
134 |
+
- **Num_Credit_Card**: Number of credit cards owned
|
135 |
+
- **Interest_Rate**: The client's average interest rate
|
136 |
+
- **Num_of_Loan**: Number of loans of the client
|
137 |
+
- **Changed_Credit_Limit**: Whether a client changed his credit limit once or not (Yes, No) -
|
138 |
+
- **Outstanding Debt**: A client's outstanding debt
|
139 |
+
- **Credit_History_Age**: The length of a client's credit history (in months)
|
140 |
+
""")
|
141 |
+
|
142 |
+
st.markdown(" ")
|
143 |
+
st.markdown(" ")
|
144 |
|
145 |
+
## Train the algorithm
|
146 |
+
st.markdown("#### Train the algorithm ⚙️")
|
147 |
+
st.info("""**Training** an AI model means feeding it data that contains multiple examples of clients with their credit scores.
|
148 |
+
Using the labeled data provided, the model will **learn relationships** between a client's credit score and the other bank/credit-related variables provided.
|
149 |
+
Using these learned relationships, the model will then try to make **accurate predictions**.""")
|
150 |
|
151 |
+
# st.markdown("""Before feeding the model data for training, exploratory data analysis is often conducted to discover if patterns can discovered beforehand.""")
|
152 |
+
# st.image("images/models/credit_score/EDA_numeric_credit.png")
|
153 |
+
#st.markdown("In our case, the training data is the dataset containing the bank and credit information of our 7600 customers.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
154 |
|
155 |
+
if 'model_train' not in st.session_state:
|
156 |
+
st.session_state['model_train'] = False
|
|
|
157 |
|
158 |
+
if st.session_state.model_train:
|
159 |
+
st.write("The model has been trained.")
|
160 |
+
else:
|
161 |
+
st.write("The model hasn't been trained yet")
|
162 |
|
163 |
+
run_credit_model = st.button("**Train the model**")
|
|
|
|
|
164 |
|
165 |
+
if run_credit_model:
|
166 |
+
st.session_state.model_train = True
|
167 |
+
with st.spinner('Wait for it...'):
|
168 |
st.markdown(" ")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
169 |
st.markdown(" ")
|
170 |
+
time.sleep(2)
|
171 |
+
st.markdown("#### See the results ☑️")
|
172 |
+
tab1, tab2 = st.tabs(["Performance", "Explainability"])
|
|
|
|
|
|
|
173 |
|
174 |
+
######## MODEL PERFORMANCE
|
175 |
+
with tab1:
|
176 |
+
results_train = load_data_pickle(path_credit,"credit_score_cm_train")
|
177 |
+
results_train = results_train.to_numpy()
|
178 |
+
accuracy = np.round(results_train.diagonal()*100)
|
179 |
+
df_accuracy = pd.DataFrame({"Credit Score":["Good","Poor","Standard"],
|
180 |
+
"Accuracy":accuracy})
|
181 |
+
|
182 |
+
st.markdown(" ")
|
183 |
+
st.info("""**Evaluating a model's performance** helps provide a quantitative measure of the model's ability to make accurate decisions.
|
184 |
+
In this use case, the performance of the credit score model was measured by comparing clients' true credit scores with the scores predicted by the trained model.""")
|
185 |
+
|
186 |
+
fig = px.bar(df_accuracy, y='Accuracy', x='Credit Score', color="Credit Score", title="Model performance")
|
187 |
+
st.plotly_chart(fig, use_container_width=True)
|
188 |
+
|
189 |
+
st.markdown("""<i>The model's accuracy was measured for every type of credit score (Good, Standard, Poor).</i>
|
190 |
+
<i>This is crucial as to understand whether the model is consistant in its performance, or whether it has trouble distinguishing between two kinds of credit score.</i>""",
|
191 |
+
unsafe_allow_html=True)
|
192 |
+
|
193 |
+
st.markdown(" ")
|
194 |
+
|
195 |
+
st.markdown("""**Interpretation**: <br>
|
196 |
+
Our model's is overall quite accurate in predicting all types of credit scores with an accuracy that is above 85% for each.
|
197 |
+
We do note that is slighly more accuracte in predicting a good credit score (92%) and less for a standard credit score (86%).
|
198 |
+
This can be due to the model having a harder time distinguishing between clients with a standard credit score and other more "extreme" credit scores (Good, Bad).
|
199 |
+
""", unsafe_allow_html=True)
|
200 |
|
201 |
+
##### MODEL EXPLAINABILITY
|
202 |
+
with tab2:
|
203 |
+
st.markdown(" ")
|
204 |
+
st.info("""**Explainability** in AI refers to the ability to understand which variable used by a model during training had the most impact on the final predictions and how to quantify this impact.
|
205 |
+
Understanding the inner workings of a model helps build trust among users and stakeholders, as well as increase acceptance.""")
|
206 |
+
|
207 |
+
# Create feature importance dataframe
|
208 |
+
df_var_importance = pd.DataFrame({"variable":credit_test_pp.columns,
|
209 |
+
"score":credit_model.feature_importances_})
|
210 |
+
|
211 |
+
# Compute average score for categorical variables
|
212 |
+
for column in ["Occupation","Credit_Mix","Payment_of_Min_Amount"]:
|
213 |
+
col_remove = [col for col in credit_test_pp.columns if f"{column}_" in col]
|
214 |
+
avg_score = df_var_importance.loc[df_var_importance["variable"].isin(col_remove)]["score"].mean()
|
215 |
+
|
216 |
+
df_var_importance = df_var_importance.loc[~df_var_importance["variable"].isin(col_remove)]
|
217 |
+
new_row = pd.DataFrame([[column, avg_score]], columns=["variable","score"])
|
218 |
+
df_var_importance = pd.concat([df_var_importance, new_row], ignore_index=True)
|
219 |
+
|
220 |
+
df_var_importance.sort_values(by=["score"], inplace=True)
|
221 |
+
df_var_importance["score"] = df_var_importance["score"].round(3)
|
222 |
|
223 |
+
# Feature importance plot with plotly
|
224 |
+
fig = px.bar(df_var_importance, x='score', y='variable', color="score", orientation="h", title="Model explainability")
|
225 |
+
st.plotly_chart(fig, use_container_width=True)
|
|
|
|
|
226 |
|
227 |
+
st.markdown("""<b>Interpretation</b>: <br>
|
228 |
+
A client's outstanding debt, interest rate and delay from due date were the most crucial factors in explaining their final credit score. <br>
|
229 |
+
Whether a client is making their minimum required payments on their credit accounts (Payment_min_amount), their occupation and their number of loans had a very limited impact on their credit score.,
|
230 |
+
""", unsafe_allow_html=True)
|
231 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
232 |
|
233 |
+
st.markdown(" ")
|
234 |
+
st.markdown(" ")
|
235 |
+
|
236 |
+
|
237 |
+
## Make predictions
|
238 |
+
st.markdown("#### Predict credit score 🆕")
|
239 |
+
st.info("You can only predict the credit score of new clients once the **model has been trained.**")
|
240 |
+
st.markdown(" ")
|
241 |
+
|
242 |
+
col1, col2 = st.columns([0.25,0.75], gap="medium")
|
243 |
+
|
244 |
+
credit_test = load_data_pickle(path_credit,"credit_score_test_raw.pkl")
|
245 |
+
credit_test.reset_index(drop=True, inplace=True)
|
246 |
+
credit_test.insert(0, "Client ID", [f"{i}" for i in range(credit_test.shape[0])])
|
247 |
+
credit_test = credit_test.loc[credit_test["Num_Bank_Accounts"]>0]
|
248 |
+
#credit_test.drop(columns=["Credit_Score"], inplace=True)
|
249 |
+
|
250 |
+
with col1:
|
251 |
+
st.markdown("""<b>Filter the data</b> <br>
|
252 |
+
You can select clients based on their *Age*, *Annual income* or *Oustanding Debt*.""",
|
253 |
+
unsafe_allow_html=True)
|
254 |
|
255 |
+
select_image_box = st.radio(" ",
|
256 |
+
["Filter by Age", "Filter by Income", "Filter by Outstanding Debt", "No filters"],
|
257 |
+
label_visibility="collapsed")
|
|
|
|
|
258 |
|
259 |
+
if select_image_box == "Filter by Age":
|
260 |
+
st.markdown(" ")
|
261 |
+
min_age, max_age = st.slider('Select a range', credit_test["Age"].astype(int).min(), credit_test["Age"].astype(int).max(), (19,50),
|
262 |
+
key="age", label_visibility="collapsed")
|
263 |
+
credit_test = credit_test.loc[credit_test["Age"].between(min_age,max_age)]
|
264 |
|
265 |
+
if select_image_box == "Filter by Income":
|
266 |
+
st.markdown(" ")
|
267 |
+
min_income, max_income = st.slider('Select a range', credit_test["Annual_Income"].astype(int).min(), 180000,
|
268 |
+
(7000, 100000), label_visibility="collapsed", key="income")
|
269 |
+
credit_test = credit_test.loc[credit_test["Annual_Income"].between(min_income, max_income)]
|
270 |
|
271 |
+
if select_image_box == "Filter by Outstanding Debt":
|
272 |
+
min_debt, max_debt = st.slider('Select a range', credit_test["Outstanding_Debt"].astype(int).min(), credit_test["Outstanding_Debt"].astype(int).max(),
|
273 |
+
(0,2000), label_visibility="collapsed", key="debt")
|
274 |
+
credit_test = credit_test.loc[credit_test["Outstanding_Debt"].between(min_debt, max_debt)]
|
|
|
|
|
275 |
|
276 |
+
if select_image_box == "No filters":
|
277 |
+
pass
|
|
|
278 |
|
279 |
+
st.markdown(" ")
|
280 |
+
st.markdown("""<b>Select a threshold for the alert</b> <br>
|
281 |
+
A warning message will be displayed if the percentage of poor credit scores exceeds this threshold.
|
282 |
+
""", unsafe_allow_html=True)
|
283 |
+
warning_threshold = st.slider('Select a value', min_value=20, max_value=100, step=10,
|
284 |
+
label_visibility="collapsed", key="warning")
|
285 |
|
286 |
+
st.markdown(" ")
|
287 |
+
st.write("The threshold is at", warning_threshold, "%")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
288 |
|
|
|
289 |
|
290 |
+
with col2:
|
291 |
+
#st.markdown("**View the database**")
|
292 |
+
st.dataframe(credit_test)
|
293 |
+
|
294 |
+
make_predictions = st.button("**Make predictions**")
|
295 |
+
st.markdown(" ")
|
296 |
|
297 |
+
if make_predictions:
|
298 |
+
if st.session_state.model_train is True:
|
299 |
+
X_test = credit_test_pp.iloc[credit_test.index,:]
|
300 |
+
predictions = credit_model.predict(X_test)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
301 |
|
302 |
+
df_results_pred = credit_test.copy()
|
303 |
+
df_results_pred["Credit Score"] = predictions
|
304 |
+
df_mean_pred = df_results_pred["Credit Score"].value_counts().to_frame().reset_index()
|
305 |
+
df_mean_pred.columns = ["Credit Score", "Proportion"]
|
306 |
+
df_mean_pred["Proportion"] = (100*df_mean_pred["Proportion"]/df_results_pred.shape[0]).round()
|
307 |
|
308 |
+
perct_bad_score = df_mean_pred.loc[df_mean_pred["Credit Score"]=="Poor"]["Proportion"].to_numpy()
|
|
|
|
|
|
|
|
|
|
|
309 |
|
310 |
+
if perct_bad_score >= warning_threshold:
|
311 |
+
st.error(f"The proportion of clients with a bad credit score is above {warning_threshold}% (at {perct_bad_score[0]}%)⚠️")
|
312 |
|
313 |
+
col1, col2 = st.columns([0.4,0.6], gap="large")
|
314 |
+
with col1:
|
315 |
+
st.markdown("**Proporition of predicted credit scores**")
|
316 |
+
fig = px.pie(df_mean_pred, values='Proportion', names='Credit Score')
|
317 |
+
#title="Proportion of credit scores")
|
318 |
+
st.plotly_chart(fig, use_container_width=True)
|
319 |
+
|
320 |
+
with col2:
|
321 |
+
df_show_results = df_results_pred[["Credit Score","Client ID"] + [col for col in df_results_pred.columns if col not in ["Client ID","Credit Score"]]]
|
322 |
+
columns_float = df_show_results.select_dtypes(include="float").columns
|
323 |
+
df_show_results[columns_float] = df_show_results[columns_float].astype(int)
|
324 |
+
|
325 |
+
def highlight_score(val):
|
326 |
+
if val == "Good":
|
327 |
+
color = 'red'
|
328 |
+
if val == 'Standard':
|
329 |
+
color= "cornflowerblue"
|
330 |
+
if val == "Poor":
|
331 |
+
color = 'blue'
|
332 |
+
return f'color: {color}'
|
333 |
+
|
334 |
+
df_show_results_color = df_show_results.style.applymap(highlight_score, subset=['Credit Score'])
|
335 |
+
|
336 |
+
st.markdown("**Overall results**")
|
337 |
+
st.dataframe(df_show_results_color)
|
338 |
|
339 |
+
else:
|
340 |
+
st.error("You have to train the credit score model first.")
|
341 |
|
342 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
343 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
344 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
345 |
|
|
|
346 |
|
347 |
+
################################# CUSTOMER CHURN #####################################
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
348 |
|
349 |
+
elif sl_usecase == "Customer churn prediction ❌":
|
350 |
+
#st.warning("This page is under construction")
|
351 |
+
path_churn = r"data/classification/churn"
|
352 |
|
353 |
+
## Description of the use case
|
354 |
+
st.divider()
|
355 |
+
st.markdown("# Customer churn prediction ❌")
|
356 |
|
357 |
+
st.info("""**Classification models** are supervised learning models whose goal is to categorize data into predefined categories.
|
358 |
+
As opposed to unsupervised learning models, these categories are known beforehand.
|
359 |
+
Other types of supervised learning models include Regression models, which learn how to predict numerical values, instead of a set number of categories.""")
|
|
|
|
|
|
|
|
|
360 |
|
361 |
+
st.markdown("For this use case, we will build a **customer churn classification model** that can predict whether a person will stop being a customer using historical data.")
|
|
|
|
|
362 |
|
|
|
|
|
|
|
363 |
st.markdown(" ")
|
364 |
|
365 |
+
## Load data
|
366 |
+
churn_data = load_data_pickle(path_churn, "churn_train_raw.pkl")
|
367 |
|
368 |
+
_, col, _ = st.columns([0.1,0.8,0.1])
|
369 |
+
with col:
|
370 |
+
st.image("images/customer-churn.png", use_column_width=True)
|
371 |
|
372 |
+
st.markdown(" ")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
373 |
|
374 |
+
## Learn about the data
|
375 |
+
st.markdown("#### About the data 📋")
|
376 |
+
st.markdown("""To train the customer churn model, you were provided a **labeled** database with around 7000 clients of a telecommunications company. <br>
|
377 |
+
The data contains information on which services the customer has signed for, account information as well as whether the customer churned or not (our label here).""",
|
378 |
+
unsafe_allow_html=True)
|
379 |
+
# st.markdown("This dataset is 'labeled' since it contains information on what we are trying to predict, which is the **Churn** variable.")
|
380 |
+
st.info("**Note**: The variables that had two possible values (Yes or No) where transformed into binary variables (0 or 1) with 0 being 'No' and 1 being 'Yes'.")
|
381 |
+
|
382 |
+
see_data = st.checkbox('**See the data**', key="churn-data")
|
383 |
+
|
384 |
+
if see_data:
|
385 |
+
st.warning("You can only view the first 30 customers in this section.")
|
386 |
+
churn_data = load_data_pickle(path_churn, "churn_train_raw.pkl")
|
387 |
+
st.dataframe(churn_data)
|
388 |
+
|
389 |
+
learn_data = st.checkbox('**Learn more about the data**', key="churn-var")
|
390 |
+
if learn_data:
|
391 |
+
st.markdown("""
|
392 |
+
- **SeniorCitizen**: Whether the customer is a senior citizen or not (1, 0)
|
393 |
+
- **Partner**: Whether the customer has a partner or not (Yes, No)
|
394 |
+
- **Dependents**: Whether the customer has dependents or not (Yes, No)
|
395 |
+
- **tenure**: Number of months the customer has stayed with the company
|
396 |
+
- **PhoneService**: Whether the customer has a phone service or not (Yes, No)
|
397 |
+
- **MultipleLines**: Whether the customer has multiple lines or not (Yes, No)
|
398 |
+
- **InternetService**: Customer’s internet service provider (DSL, Fiber optic, No)
|
399 |
+
- **OnlineSecurity**: Whether the customer has online security or not (Yes, No)
|
400 |
+
- **OnlineBackup**: Whether the customer has online backup or not (Yes, No)
|
401 |
+
- **DeviceProtection**: Whether the customer has device protection or not (Yes, No)
|
402 |
+
- **TechSupport**: Whether the customer has tech support or not (Yes, No)
|
403 |
+
- **StreamingTV**: Whether the customer has streaming TV or not (Yes, No)
|
404 |
+
- **StreamingMovies**: Whether the customer has streaming movies or not (Yes, No)
|
405 |
+
- **Contract**: The contract term of the customer (Month-to-month, One year, Two year)
|
406 |
+
- **PaperlessBilling**: Whether the customer has paperless billing or not (Yes, No)
|
407 |
+
- **PaymentMethod**: The customer’s payment method (Electronic check, Mailed check, Bank transfer (automatic), Credit card (automatic))
|
408 |
+
- **MonthlyCharges**: The amount charged to the customer monthly
|
409 |
+
- **TotalCharges**: The total amount charged to the customer
|
410 |
+
- <span style="color: red;"> **Churn** (the variable we want to predict): Whether the customer churned or not (Yes or No) </span>
|
411 |
+
""", unsafe_allow_html=True)
|
412 |
|
413 |
+
st.markdown(" ")
|
414 |
+
st.markdown(" ")
|
415 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
416 |
|
417 |
+
## Exploratory data analysis
|
418 |
+
st.markdown("#### Exploratory Data Analysis 🔎")
|
419 |
+
st.markdown("""Exploratory Data Analysis (EDA) is a crucial step in the machine learning workflow.
|
420 |
+
It helps practitioners understand the structure, patterns, and characteristics of the data they are working with.
|
421 |
+
For this use case, we will perform EDA by analyzing the **proportion of clients who have churned or not** based on the dataset's other variables.""")
|
422 |
+
|
423 |
+
st.info("**Note**: EDA is usually preformed before model training as it helps inform decisions made by the model throughout the modeling process.")
|
424 |
|
425 |
+
see_EDA = st.checkbox('**View the analysis**', key="churn-EDA")
|
426 |
+
if see_EDA:
|
427 |
+
st.markdown(" ")
|
428 |
|
429 |
+
# Show EDA image
|
430 |
+
st.markdown("""Exploratory Data Analysis has been preformed between the predicted variable `Churn` with 15 other variables present in the dataset. <br>
|
431 |
+
Each graphs shows the proportion of churned and not churned customer based on the variable's possible values.""", unsafe_allow_html=True)
|
432 |
st.markdown(" ")
|
433 |
+
|
434 |
+
img_eda = os.path.join(path_churn, "EDA_churn.png")
|
435 |
+
st.image(img_eda)
|
436 |
+
|
437 |
st.markdown(" ")
|
|
|
|
|
|
|
438 |
|
439 |
+
# Intepretation
|
440 |
+
st.markdown("""**Interpretation** <br>
|
441 |
+
For variables such as `Contract`, `PaperlessBilling`, `PaymentMethod` and `InternetService`, we can see a significant difference in the proportion of churned customers based on the variable's value.
|
442 |
+
In the *Contract* graph, clients with a 'Month-to-Month' tend to churn more often than those with a longer contract.
|
443 |
+
In the *InternetService* graph, clients with a 'Fiber optic' service are more likely to churn than those with DSL or no internet service. """, unsafe_allow_html=True)
|
444 |
+
|
445 |
+
st.info("""**Note**: Performing EDA can give us an indication as to which variables might be more significant in the customer churn model.
|
446 |
+
It can be a valuable tool to study the relationship between two variables but can sometimes be too simplistic. Some relationships might be top complex to be seen through EDA.""")
|
447 |
|
448 |
+
|
449 |
+
|
450 |
+
st.markdown(""" """)
|
451 |
+
st.markdown(""" """)
|
452 |
+
|
453 |
+
## Train the algorithm
|
454 |
+
st.markdown("#### Train the algorithm ⚙️")
|
455 |
+
st.markdown("""**Training the model** means feeding it data that contains multiple examples of what you are trying to predict (here it is `Churn`).
|
456 |
+
This allows the model to **learn relationships** between the `Churn` variable and the additional variables provided for the analysis and make accuracte predictions.""")
|
457 |
+
st.info("**Note**: A model is always trained before it can used to make predictions on new 'unlabeled' data.")
|
458 |
|
|
|
|
|
|
|
459 |
|
460 |
+
if 'model_train_churn' not in st.session_state:
|
461 |
+
st.session_state['model_train_churn'] = False
|
|
|
|
|
462 |
|
463 |
+
if st.session_state.model_train_churn:
|
464 |
+
st.write("The model has already been trained.")
|
465 |
+
else:
|
466 |
+
st.write("The model hasn't been trained yet")
|
467 |
|
468 |
+
run_churn_model = st.button("**Train the model**")
|
469 |
|
470 |
+
|
471 |
+
if run_churn_model:
|
472 |
+
st.session_state.model_train_churn = True
|
473 |
+
with st.spinner('Wait for it...'):
|
|
|
|
|
|
|
474 |
st.markdown(" ")
|
475 |
+
st.markdown(" ")
|
476 |
+
time.sleep(2)
|
477 |
+
st.markdown("#### See the results ☑️")
|
478 |
+
tab1, tab2 = st.tabs(["Performance", "Explainability"])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
479 |
|
480 |
+
######## MODEL PERFORMANCE
|
481 |
+
with tab1:
|
482 |
+
results_train = load_data_pickle(path_churn,"churn_cm_train.pkl")
|
483 |
+
results_train = results_train.to_numpy()
|
484 |
+
accuracy = np.round(results_train.diagonal()*100)
|
485 |
+
df_accuracy = pd.DataFrame({"Churn":["No","Yes"],
|
486 |
+
"Accuracy":accuracy})
|
487 |
+
|
488 |
+
df_accuracy["Accuracy"] = np.round(df_accuracy["Accuracy"]/100)
|
489 |
+
|
490 |
+
st.markdown(" ")
|
491 |
+
st.info("""**Note**: Evaluating a model's performance helps provide a quantitative measure of the model's ability to make accurate decisions.
|
492 |
+
In this use case, the performance of the customer churn model was measured by comparing the clients' churn variables with the value predicted by the trained model.""")
|
493 |
|
494 |
+
fig = px.bar(df_accuracy, y='Accuracy', x='Churn', color="Churn", title="Model performance", text_auto=True)
|
495 |
+
fig.update_traces(textfont_size=16)
|
496 |
+
#fig.update_traces(textposition='inside', textfont=dict(color='white'))
|
497 |
+
st.plotly_chart(fig, use_container_width=True)
|
498 |
|
499 |
+
# st.markdown("""<i>The model's accuracy was measured for both Churn and No Churn.</i>
|
500 |
+
# <i>This is crucial as to understand whether the model is consistant in its performance, or whether it has trouble distinguishing between two kinds of credit score.</i>""",
|
501 |
+
# unsafe_allow_html=True)
|
502 |
|
503 |
+
st.markdown(" ")
|
|
|
|
|
|
|
|
|
504 |
|
505 |
+
st.markdown("""**Interpretation**: <br>
|
506 |
+
The model has a 88% accuracy in predicting customer that haven't churned, and a 94% accurate in predicting customer who have churned. <br>
|
507 |
+
This means that the model's overall performance is good (at around 91%) but isn't equally as good for both predicted classes.
|
508 |
+
""", unsafe_allow_html=True)
|
509 |
+
|
510 |
+
##### MODEL EXPLAINABILITY
|
511 |
+
with tab2:
|
512 |
+
st.markdown(" ")
|
513 |
+
st.info("""**Note**: Explainability in AI refers to the ability to understand which variable used by a model during training had the most impact on the final predictions and how to quantify this impact.
|
514 |
+
Understanding the inner workings of a model helps build trust among users and stakeholders, as well as increase acceptance.""")
|
515 |
+
|
516 |
+
# Import feature importance dataframe
|
517 |
+
df_var_importance = load_data_pickle(path_churn, "churn_feature_importance.pkl")
|
518 |
+
df_var_importance.rename({"importance":"score"}, axis=1, inplace=True)
|
519 |
+
df_var_importance.sort_values(by=["score"], inplace=True)
|
520 |
+
df_var_importance["score"] = df_var_importance["score"].round(3)
|
521 |
+
|
522 |
+
# Feature importance plot with plotly
|
523 |
+
fig = px.bar(df_var_importance, x='score', y='variable', color="score", orientation="h", title="Model explainability")
|
524 |
+
st.plotly_chart(fig, use_container_width=True)
|
525 |
+
|
526 |
+
st.markdown("""<b>Interpretation</b> <br>
|
527 |
+
The client's tenure, amount of Monthly and Total Charges, as well as the type of Contract had the most impact on the model's churn predictions.
|
528 |
+
On the other hand, whether the client is subscribed to a streaming platform, he is covered by device protection or he has or not phone service had a very contribution in the final predictions.
|
529 |
+
""", unsafe_allow_html=True)
|
530 |
+
|
531 |
+
st.markdown(" ")
|
532 |
+
st.markdown(" ")
|
533 |
|
534 |
+
st.markdown("#### Predict customer churn 🆕")
|
535 |
+
st.markdown("Once you have trained the model, you can use it predict whether a client will churn or not on new data.")
|
|
|
536 |
|
537 |
+
st.markdown(" ")
|
|
|
|
|
|
|
|
|
538 |
|
539 |
+
col1, col2 = st.columns([0.25,0.75], gap="medium")
|
540 |
+
|
541 |
+
churn_test = load_data_pickle(path_churn,"churn_test_raw.pkl")
|
542 |
+
churn_test.reset_index(drop=True, inplace=True)
|
543 |
+
churn_test.insert(0, "Client ID", [f"{i}" for i in range(churn_test.shape[0])])
|
544 |
|
545 |
+
with col1:
|
546 |
+
st.markdown("""<b>Filter the data</b> <br>
|
547 |
+
You can select clients based on their *Tenure*, *Total Charges* or *Contract*.""",
|
548 |
+
unsafe_allow_html=True)
|
549 |
|
550 |
+
select_image_box = st.radio(" ",
|
551 |
+
["Filter by Tenure", "Filter by Total Charges", "Filter by Contract", "No filters"],
|
552 |
+
label_visibility="collapsed")
|
553 |
|
554 |
+
if select_image_box == "Filter by Tenure":
|
555 |
+
st.markdown(" ")
|
556 |
+
min_tenure, max_tenure = st.slider('Select a range', churn_test["tenure"].astype(int).min(), churn_test["tenure"].astype(int).max(), (1,50),
|
557 |
+
key="tenure", label_visibility="collapsed")
|
558 |
+
churn_test = churn_test.loc[churn_test["tenure"].between(min_tenure,max_tenure)]
|
|
|
|
|
|
|
|
|
|
|
559 |
|
560 |
+
if select_image_box == "Filter by Total Charges":
|
561 |
+
st.markdown(" ")
|
562 |
+
min_charges, max_charges = st.slider('Select a range', churn_test["TotalCharges"].astype(int).min(), churn_test["TotalCharges"].astype(int).max(), (50, 5000),
|
563 |
+
label_visibility="collapsed", key="totalcharges")
|
564 |
+
churn_test = churn_test.loc[churn_test["TotalCharges"].between(min_charges, max_charges)]
|
565 |
|
566 |
+
if select_image_box == "Filter by Contract":
|
567 |
+
contract = st.selectbox('Select a type of contract', churn_test["Contract"].unique(), index=0, label_visibility="collapsed", key="contract",
|
568 |
+
placeholder = "Choose one or more options")
|
569 |
+
churn_test = churn_test.loc[churn_test["Contract"]==contract]
|
570 |
|
571 |
+
if select_image_box == "No filters":
|
572 |
+
pass
|
|
|
573 |
|
574 |
+
st.markdown(" ")
|
575 |
+
st.markdown("""<b>Select a threshold for the alert</b> <br>
|
576 |
+
A warning message will be displayed if the percentage of churned customers exceeds this threshold.
|
577 |
+
""", unsafe_allow_html=True)
|
578 |
+
warning_threshold = st.slider('Select a value', min_value=20, max_value=100, step=10,
|
579 |
+
label_visibility="collapsed", key="warning")
|
580 |
|
581 |
+
st.markdown(" ")
|
582 |
+
st.write("The threshold is at", warning_threshold, "%")
|
|
|
583 |
|
|
|
|
|
|
|
|
|
|
|
584 |
|
585 |
+
with col2:
|
586 |
+
#st.markdown("**View the database**")
|
587 |
+
st.dataframe(churn_test)
|
588 |
|
|
|
|
|
589 |
|
590 |
+
# Button to make predictions
|
591 |
+
make_predictions = st.button("**Make predictions**")
|
592 |
+
st.markdown(" ")
|
593 |
|
594 |
+
if make_predictions:
|
595 |
+
if st.session_state.model_train_churn is True:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
596 |
|
597 |
+
## Load preprocessed test data and model
|
598 |
+
churn_test_pp = load_data_pickle(path_churn, "churn_test_pp.pkl")
|
599 |
+
churn_model = load_model_pickle(path_pretrained_supervised,"churn_model.pkl")
|
600 |
|
601 |
+
X_test = churn_test_pp.iloc[churn_test.index,:].to_numpy()
|
602 |
+
predictions = churn_model.predict(X_test)
|
603 |
+
predictions = ["No" if x==0 else "Yes" for x in predictions]
|
604 |
+
|
605 |
+
df_results_pred = churn_test.copy()
|
606 |
+
df_results_pred["Churn"] = predictions
|
607 |
+
df_mean_pred = df_results_pred["Churn"].value_counts().to_frame().reset_index()
|
608 |
+
df_mean_pred.columns = ["Churn", "Proportion"]
|
609 |
+
df_mean_pred["Proportion"] = (100*df_mean_pred["Proportion"]/df_results_pred.shape[0]).round()
|
610 |
|
611 |
+
perct_churned = df_mean_pred.loc[df_mean_pred["Churn"]=="Yes"]["Proportion"].to_numpy()
|
612 |
|
613 |
+
if perct_churned >= warning_threshold:
|
614 |
+
st.error(f"The proportion of clients that have churned is above {warning_threshold}% (at {perct_churned[0]}%)⚠️")
|
615 |
|
616 |
+
st.markdown(" ")
|
617 |
|
618 |
+
col1, col2 = st.columns([0.4,0.6], gap="large")
|
619 |
+
with col1:
|
620 |
+
st.markdown("**Proporition of predicted churn**")
|
621 |
+
fig = px.pie(df_mean_pred, values='Proportion', names='Churn', color="Churn",
|
622 |
+
color_discrete_map={'No':'royalblue', 'Yes':'red'})
|
623 |
+
st.plotly_chart(fig, use_container_width=True)
|
624 |
+
|
625 |
+
with col2:
|
626 |
+
df_show_results = df_results_pred[["Churn","Client ID"] + [col for col in df_results_pred.columns if col not in ["Client ID","Churn"]]]
|
627 |
+
columns_float = df_show_results.select_dtypes(include="float").columns
|
628 |
+
df_show_results[columns_float] = df_show_results[columns_float].astype(int)
|
629 |
+
|
630 |
+
def highlight_score(val):
|
631 |
+
if val == "No":
|
632 |
+
color = 'royalblue'
|
633 |
+
if val == 'Yes':
|
634 |
+
color= "red"
|
635 |
+
return f'color: {color}'
|
636 |
+
|
637 |
+
df_show_results_color = df_show_results.style.applymap(highlight_score, subset=['Churn'])
|
638 |
+
|
639 |
+
st.markdown("**Overall results**")
|
640 |
+
st.dataframe(df_show_results_color)
|
641 |
|
642 |
+
else:
|
643 |
+
st.error("You have to train the credit score model first.")
|
644 |
|
|
|
|
|
|
|
645 |
|
646 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
647 |
|
648 |
|
649 |
|
650 |
+
#######################################################################################################################
|
651 |
+
# UNSUPERVISED LEARNING
|
652 |
+
#######################################################################################################################
|
|
|
653 |
|
|
|
654 |
|
655 |
+
def markdown_general_info(df):
|
656 |
+
text = st.markdown(f"""
|
657 |
+
- **Age**: {int(np.round(df.Age))}
|
658 |
+
- **Yearly income**: {int(df.Income)} $
|
659 |
+
- **Number of kids**: {df.Kids}
|
660 |
+
- **Days of enrollment**: {int(np.round(df.Days_subscription))}
|
661 |
+
- **Web visits per month**: {df.WebVisitsMonth}
|
662 |
+
""")
|
663 |
+
return text
|
664 |
|
|
|
665 |
|
|
|
|
|
|
|
666 |
|
667 |
+
if learning_type == "Unsupervised Learning":
|
668 |
+
usl_usecase = st.selectbox("**Choose a use case**",
|
669 |
+
["Customer segmentation (clustering) 🧑🤝🧑"])
|
670 |
|
671 |
+
|
672 |
+
#################################### CUSTOMER SEGMENTATION ##################################
|
673 |
+
|
674 |
+
path_clustering = r"data/clustering"
|
675 |
+
path_clustering_results = r"data/clustering/results"
|
676 |
+
|
677 |
+
if usl_usecase == "Customer segmentation (clustering) 🧑🤝🧑":
|
678 |
+
|
679 |
+
# st.divider()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
680 |
st.divider()
|
681 |
+
st.markdown("# Customer Segmentation (clustering) 🧑🤝🧑")
|
682 |
|
683 |
+
st.markdown("""In this use case, we will use a clustering model, a type of Unsupervised Learning model, to perform **Customer Segmentation**. <br>
|
684 |
+
Our model will allow similar groups of clients to be identified within company's consumer database based on consumer habits and caracteristics.
|
685 |
+
""", unsafe_allow_html=True)
|
686 |
+
|
687 |
+
st.markdown(" ")
|
688 |
+
|
689 |
+
## Show image
|
690 |
+
_, col, _ = st.columns([0.2,0.5,0.3])
|
691 |
+
with col:
|
692 |
+
st.image("images/cs.webp")
|
693 |
+
|
694 |
+
## About the use case
|
695 |
+
st.markdown("#### About the use case 📋")
|
696 |
+
st.markdown("""You are giving a database that contains information on around **2000 customers** of a mass-market retailer.
|
697 |
+
The database's contains **personal information** (age, income, number of kids...), as well as information on what types of products were purchased by the client, how long has he been enrolled as a client and where these purchases were made. """, unsafe_allow_html=True)
|
698 |
+
|
699 |
+
see_data = st.checkbox('**See the data**', key="dataframe")
|
700 |
+
|
701 |
+
if see_data:
|
702 |
+
customer_data = load_data_pickle(path_clustering, "clean_marketing.pkl")
|
703 |
+
st.dataframe(customer_data.head(10))
|
704 |
+
|
705 |
+
learn_data = st.checkbox('**Learn more about the variables**', key="variable")
|
706 |
+
|
707 |
+
if learn_data:
|
708 |
+
st.markdown("""
|
709 |
+
- **Age**: Customer's age
|
710 |
+
- **Income**: Customer's yearly household income
|
711 |
+
- **Kids**: Number of children/teenagers in customer's household
|
712 |
+
- **Days_subscription**: Number of days since a customer's enrollment with the company
|
713 |
+
- **Recency**: Number of days since customer's last purchase
|
714 |
+
- **Wines**: Proportion of money spent on wine in last 2 years
|
715 |
+
- **Fruits**: Proportion of money spent on fruits in last 2 years
|
716 |
+
- **MeatProducts**: Proportion of money spent on meat in last 2 years
|
717 |
+
- **FishProducts**: Proportion of money spent on fish in last 2 years
|
718 |
+
- **SweetProducts**: Proportion of money spent sweets in last 2 years
|
719 |
+
- **DealsPurchases**: Proportion of purchases made with a discount
|
720 |
+
- **WebPurchases**: Proportion of purchases made through the company’s website
|
721 |
+
- **CatalogPurchases**: Proporition of purchases made using a catalogue
|
722 |
+
- **StorePurchases**: Proportion of purchases made directly in stores
|
723 |
+
- **WebVisitsMonth**: Proportion of visits to company’s website in the last month""")
|
724 |
+
st.divider()
|
725 |
|
|
|
|
|
726 |
|
727 |
+
st.markdown(" ")
|
728 |
+
st.markdown(" ")
|
729 |
|
730 |
+
st.markdown("#### Clustering algorithm ⚙️")
|
|
|
|
|
|
|
|
|
731 |
|
732 |
+
st.info("""**Clustering** is a type of unsupervised learning method that learns how to group similar data points together into "clusters", without needing supervision.
|
733 |
+
In our case, a data points represents a customer that will be assigned to an unknown group.""")
|
734 |
+
|
735 |
+
# st.markdown("""
|
736 |
+
# - The clustering algorithm used in this use case allows a specific number of groups to be identified, which isn't the case for all clustering models.""")
|
737 |
+
|
738 |
+
st.markdown(" ")
|
739 |
+
st.markdown("Here is an example of grouped data using a clustering model.")
|
740 |
+
st.image("images/clustering.webp")
|
741 |
+
|
742 |
+
st.warning("**Note**: The number of clusters chosen by the user can have a strong impact on the quality of the segmentation. Try to run the model multiple times with different number of clusters and see which number leads to groups with more distinct customer behaviors/preferences.")
|
743 |
+
|
744 |
+
nb_groups = st.selectbox("Choose a number of customer groups to identify", np.arange(2,6))
|
745 |
+
df_results = load_data_pickle(path_clustering_results, f"results_{nb_groups}_clusters.pkl")
|
746 |
+
|
747 |
+
st.markdown(" ")
|
748 |
+
run_model = st.button("**Run the model**")
|
749 |
+
#tab1, tab2 = st.tabs(["Results per product type", "Results per channel"])
|
750 |
+
#st.divider()
|
751 |
+
|
752 |
+
if run_model:
|
753 |
+
cols_group = st.columns(int(nb_groups))
|
754 |
+
for nb in range(nb_groups):
|
755 |
+
df_nb = df_results[nb]
|
756 |
+
|
757 |
+
col1, col2 = st.columns([0.3,0.7])
|
758 |
+
with col1:
|
759 |
+
st.image("images/group.png", width=200)
|
760 |
+
st.header(f"Group {nb+1}", divider="grey")
|
761 |
+
markdown_general_info(df_nb)
|
762 |
+
|
763 |
+
with col2:
|
764 |
+
tab1, tab2 = st.tabs(["Results per product type", "Results per channel"])
|
765 |
+
list_product_col = [col for col in list(df_nb.index) if "Products" in col]
|
766 |
+
df_products = df_nb.reset_index()
|
767 |
+
df_products = df_products.loc[df_products["variable"].isin(list_product_col)]
|
768 |
+
df_products.columns = ["variables", "values"]
|
769 |
+
|
770 |
+
with tab1:
|
771 |
+
fig = px.pie(df_products, values='values', names='variables',
|
772 |
+
title="Amount spent per product type (in %)")
|
773 |
+
st.plotly_chart(fig, width=300)
|
774 |
+
|
775 |
+
list_purchases_col = [col for col in list(df_nb.index) if "Purchases" in col]
|
776 |
+
df_products = df_nb.reset_index()
|
777 |
+
df_products = df_products.loc[df_products["variable"].isin(list_purchases_col)]
|
778 |
+
df_products.columns = ["variables", "values"]
|
779 |
+
|
780 |
+
with tab2:
|
781 |
+
fig = px.pie(df_products, values='values', names='variables',
|
782 |
+
title='Proportion of purchases made per channel (in %)')
|
783 |
+
st.plotly_chart(fig, width=300)
|
784 |
|
pages/timeseries_analysis.py
CHANGED
@@ -7,7 +7,7 @@ import numpy as np
|
|
7 |
from PIL import Image
|
8 |
from prophet import Prophet
|
9 |
from datetime import date
|
10 |
-
from utils import load_data_pickle
|
11 |
from sklearn.metrics import root_mean_squared_error
|
12 |
from st_pages import add_indentation
|
13 |
|
@@ -30,181 +30,180 @@ def forecast_prophet(train, test, col=None):
|
|
30 |
|
31 |
###################################### TITLE ####################################
|
32 |
|
33 |
-
|
34 |
-
|
35 |
-
st.markdown("# Time Series Forecasting 📈")
|
36 |
|
37 |
-
st.markdown("### What is Time Series Forecasting ?")
|
38 |
-
st.info("""Time series forecasting models are AI models built to make predictions about future values using historical data.
|
39 |
-
|
40 |
-
|
41 |
|
42 |
-
st.markdown(" ")
|
43 |
-
image_ts = Image.open('images/ts_patterns.png')
|
44 |
-
_, col, _ = st.columns([0.15,0.7,0.15])
|
45 |
-
with col:
|
46 |
-
|
47 |
|
48 |
-
st.markdown(" ")
|
49 |
|
50 |
-
st.markdown("""Real-life applications of time series forecasting include:
|
51 |
- **Finance 💰**: Predict stock prices based on historical data to assist investors and traders in making informed decisions.
|
52 |
- **Energy ⚡**: Forecast energy consumption patterns to optimize resource allocation, plan maintenance, and manage energy grids more efficiently.
|
53 |
- **Retail 🏬**: Predict future demand for products to optimize inventory levels, reduce holding costs, and improve supply chain efficiency.
|
54 |
- **Transportation and Traffic flow :car:**: Forecasting traffic patterns to optimize route planning, reduce congestion, and improve overall transportation efficiency.
|
55 |
- **Healthcare** 👨⚕️: Predicting the number of patient admissions to hospitals, helping healthcare providers allocate resources effectively and manage staffing levels.
|
56 |
- **Weather 🌦️**: Predicting weather conditions over time, which is crucial for planning various activities, agricultural decisions, and disaster preparedness.
|
57 |
-
""")
|
58 |
|
59 |
|
60 |
|
61 |
-
st.markdown(" ")
|
62 |
|
63 |
|
64 |
|
65 |
|
66 |
-
###################################### USE CASE #######################################
|
67 |
|
68 |
-
# LOAD DATASET
|
69 |
-
path_timeseries = r"data/household"
|
70 |
-
data_model = load_data_pickle(path_timeseries,"household_power_consumption_clean.pkl")
|
71 |
-
data_model.rename({"Date":"ds", "Global_active_power":"y"}, axis=1, inplace=True)
|
72 |
-
data_model.dropna(inplace=True)
|
73 |
-
data_model["ds"] = pd.to_datetime(data_model["ds"])
|
74 |
|
75 |
-
# BEGINNING OF USE CASE
|
76 |
-
st.divider()
|
77 |
-
st.markdown("# Power Consumption Forecasting ⚡")
|
78 |
|
79 |
-
#st.markdown(" ")
|
80 |
-
st.info("""In this use case, a time series forecasting model learns how to accuratly predict the **energy consumption** (or global active power in the dataset) of a household using historical data.
|
81 |
-
|
82 |
|
83 |
-
st.markdown(" ")
|
84 |
|
85 |
-
_, col, _ = st.columns([0.15,0.7,0.15])
|
86 |
-
with col:
|
87 |
-
|
88 |
|
89 |
-
st.markdown(" ")
|
90 |
-
st.markdown(" ")
|
91 |
|
92 |
-
st.markdown("### About the data 📋")
|
93 |
|
94 |
-
st.markdown("""You were provided data from the **daily energy consumption** of a household between January 2007 and November 2010 (46 months). <br>
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
|
99 |
-
st.markdown(" ")
|
100 |
|
101 |
-
st.info("""The data has been split into "historical data" and "to be predicted". Since forecasting models are **supervised**, we will use the household's energy data from January 2007 to December 2009 as historical data to train the model.
|
102 |
-
|
103 |
|
104 |
-
select_cutoff_date = date(2010, 1, 1)
|
105 |
-
select_cutoff_date = select_cutoff_date.strftime('%Y-%m-%d')
|
106 |
|
107 |
-
# SELECT TRAIN/TEST SET
|
108 |
-
train = data_model[data_model["ds"] <= select_cutoff_date]
|
109 |
-
test = data_model[data_model["ds"] > select_cutoff_date]
|
110 |
|
111 |
-
# PLOT TRAIN/TEST SET
|
112 |
-
train_plot = train.copy()
|
113 |
-
train_plot["split"] = ["historical data"]*len(train_plot)
|
114 |
|
115 |
-
test_plot = test.copy()
|
116 |
-
test_plot["split"] = ["to be predicted"]*len(test_plot)
|
117 |
-
data_clean_plot = pd.concat([train_plot, test_plot]) # plot dataset
|
118 |
|
119 |
-
st.markdown(" ")
|
120 |
-
tab1, tab2, tab3, tab4, tab5 = st.tabs(["Global active power", "Sub metering 1", "Sub metering 2", "Sub metering 3", "Global Intensity"])
|
121 |
|
122 |
-
with tab1:
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
133 |
|
134 |
-
with tab2:
|
135 |
-
ts_chart = alt.Chart(data_clean_plot.loc[data_clean_plot["split"]=="historical data"]).mark_line().encode(
|
136 |
-
x=alt.X('ds:T', axis=alt.Axis(format='%b %Y', tickCount=12), title="Date"),
|
137 |
-
y=alt.Y('Sub_metering_1:Q', title="Sub metering 1"),
|
138 |
-
color=alt.Color('split:N')) #, scale=custom_color_scale))
|
139 |
-
|
140 |
-
st.markdown("**View Sub-metering 1** (additional)")
|
141 |
-
st.altair_chart(ts_chart.interactive(), use_container_width=True)
|
142 |
-
st.success("**Sub-metering 1** is the total active power consumed by the kitchen in the house (in kilowatts).")
|
143 |
|
144 |
|
|
|
|
|
|
|
|
|
|
|
145 |
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
y=alt.Y('Sub_metering_2:Q', title="Sub metering 2"),
|
150 |
-
color=alt.Color('split:N')) #, scale=custom_color_scale))
|
151 |
|
152 |
-
st.markdown("**View Sub-metering 2** (additional)")
|
153 |
-
st.altair_chart(ts_chart.interactive(), use_container_width=True)
|
154 |
-
st.success("**Sub-metering 2** is the total active power consumed by the laundry room in the house (in kilowatts).")
|
155 |
|
|
|
|
|
|
|
|
|
|
|
156 |
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
y=alt.Y('Sub_metering_3:Q', title="Sub metering 3"),
|
161 |
-
color=alt.Color('split:N')) #scale=custom_color_scale))
|
162 |
|
163 |
-
st.markdown("**View Sub-metering 3** (additional)")
|
164 |
-
st.altair_chart(ts_chart.interactive(), use_container_width=True)
|
165 |
-
st.success("**Sub-metering 3** is the active power consumed by the electric water heater and air conditioner in the household (in kilowatts).")
|
166 |
|
|
|
|
|
|
|
|
|
|
|
|
|
167 |
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
x=alt.X('ds:T', axis=alt.Axis(format='%b %Y', tickCount=12), title="Date"),
|
172 |
-
y=alt.Y('Global_intensity:Q', title="Global active power"),
|
173 |
-
color=alt.Color('split:N')) # scale=custom_color_scale))
|
174 |
|
175 |
-
st.markdown("**View Global intensity** (additional)")
|
176 |
-
st.altair_chart(ts_chart.interactive(), use_container_width=True)
|
177 |
-
st.success("**Global intensity** is the average current intensity delivered to the household (amps).")
|
178 |
|
179 |
|
|
|
|
|
|
|
|
|
|
|
180 |
|
181 |
-
st.markdown(" ")
|
182 |
-
st.markdown(" ")
|
183 |
-
st.markdown("### Forecast model 📈")
|
184 |
-
st.markdown("""The forecasting model used in this use case allows **additional data** to be used for training.
|
185 |
-
Try adding more data to the model as it can help improve its performance and accuracy.""")
|
186 |
|
187 |
|
|
|
|
|
|
|
|
|
188 |
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
|
|
|
|
|
|
193 |
|
194 |
-
|
195 |
-
st.
|
196 |
-
|
197 |
-
# if st.session_state.model_train:
|
198 |
-
# text = "The model has alerady been trained."
|
199 |
-
# else:
|
200 |
-
# st.write("The model hasn't been trained yet")
|
201 |
|
202 |
-
st.markdown("")
|
203 |
-
run_model = st.button("**Run the model**")
|
204 |
|
205 |
-
|
206 |
-
st.markdown(" ")
|
207 |
-
st.markdown(" ")
|
208 |
|
209 |
|
210 |
|
@@ -212,107 +211,107 @@ st.markdown(" ")
|
|
212 |
|
213 |
################################## SEE RESULTS ###############################
|
214 |
|
215 |
-
if "saved_model" not in st.session_state:
|
216 |
-
|
217 |
|
218 |
|
219 |
-
if run_model:
|
220 |
-
|
221 |
-
|
222 |
-
|
223 |
-
|
224 |
|
225 |
-
|
226 |
-
|
227 |
-
|
228 |
|
229 |
-
|
230 |
-
|
231 |
-
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
|
236 |
-
|
237 |
|
238 |
-
|
239 |
-
|
240 |
-
|
241 |
-
|
242 |
-
with col2:
|
243 |
-
# Create df for true vs predicted plot
|
244 |
-
df_results = pd.concat([test_plot.reset_index(drop=True), forecast.drop(columns=["ds"]).reset_index(drop=True)], axis=1)[["ds","y","yhat"]]
|
245 |
-
df_results = df_results.melt(id_vars="ds")
|
246 |
-
df_results["variable"] = df_results["variable"].map({"y":"true values", "yhat":"predicted values"})
|
247 |
-
df_results.columns = ["Date", "Variable", "Global Active Power"]
|
248 |
|
249 |
-
|
250 |
-
|
251 |
-
|
252 |
-
|
253 |
-
|
254 |
-
|
255 |
-
|
256 |
-
|
257 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
258 |
st.plotly_chart(fig, use_container_width=True)
|
|
|
|
|
259 |
|
260 |
-
|
261 |
-
|
262 |
-
|
263 |
-
|
264 |
-
|
265 |
-
fig.update_layout(title="Trend", xaxis_title="Date", yaxis_title="Trend")
|
266 |
-
st.plotly_chart(fig, use_container_width=True)
|
267 |
-
st.markdown("""**Interpretation** <br>
|
268 |
-
No trend in the household's energy consumption has been detected by the model.""", unsafe_allow_html=True)
|
269 |
|
270 |
-
|
271 |
-
|
272 |
-
|
273 |
-
forecast_weekly = forecast.copy()
|
274 |
-
forecast_weekly["dayweek"] = forecast_weekly["ds"].apply(lambda x: x.isoweekday()).map(days_week)
|
275 |
-
|
276 |
-
fig = px.area(forecast_weekly, x="dayweek", y="weekly", color_discrete_sequence=["purple"])
|
277 |
-
fig.update_layout(title="Weekly seasonality", xaxis_title="Date", yaxis_title="Weekly")
|
278 |
-
st.plotly_chart(fig, use_container_width=True)
|
279 |
|
280 |
-
|
281 |
-
|
282 |
-
|
283 |
|
284 |
-
|
285 |
-
|
286 |
-
|
287 |
-
|
288 |
-
|
289 |
-
|
290 |
|
291 |
-
|
292 |
-
|
293 |
-
|
294 |
-
|
295 |
-
|
296 |
|
297 |
-
|
298 |
|
299 |
-
|
300 |
-
|
301 |
-
|
302 |
-
|
303 |
|
304 |
|
305 |
-
################################## MAKE FUTURE PREDICTIONS ###############################
|
306 |
|
307 |
-
# st.markdown("#### Forecast new values ")
|
308 |
|
309 |
-
# st.info("**The model needs to be trained before it can predict new values.**")
|
310 |
|
311 |
-
# st.
|
312 |
|
313 |
-
# make_predictions = st.button("**Forecast new values**")
|
314 |
|
315 |
-
# if make_predictions is True:
|
316 |
-
# if st.session_state.saved_model is True:
|
317 |
|
318 |
|
|
|
7 |
from PIL import Image
|
8 |
from prophet import Prophet
|
9 |
from datetime import date
|
10 |
+
from utils import load_data_pickle, check_password
|
11 |
from sklearn.metrics import root_mean_squared_error
|
12 |
from st_pages import add_indentation
|
13 |
|
|
|
30 |
|
31 |
###################################### TITLE ####################################
|
32 |
|
33 |
+
if check_password():
|
34 |
+
st.markdown("# Time Series Forecasting 📈")
|
|
|
35 |
|
36 |
+
st.markdown("### What is Time Series Forecasting ?")
|
37 |
+
st.info("""Time series forecasting models are AI models built to make predictions about future values using historical data.
|
38 |
+
These types of models take into account temporal patterns, such as **trends** (long-term movements), **seasonality** (repeating patterns at fixed intervals), and **cyclic patterns** (repeating patterns not necessarily at fixed intervals)""")
|
39 |
+
#unsafe_allow_html=True)
|
40 |
|
41 |
+
st.markdown(" ")
|
42 |
+
image_ts = Image.open('images/ts_patterns.png')
|
43 |
+
_, col, _ = st.columns([0.15,0.7,0.15])
|
44 |
+
with col:
|
45 |
+
st.image(image_ts)
|
46 |
|
47 |
+
st.markdown(" ")
|
48 |
|
49 |
+
st.markdown("""Real-life applications of time series forecasting include:
|
50 |
- **Finance 💰**: Predict stock prices based on historical data to assist investors and traders in making informed decisions.
|
51 |
- **Energy ⚡**: Forecast energy consumption patterns to optimize resource allocation, plan maintenance, and manage energy grids more efficiently.
|
52 |
- **Retail 🏬**: Predict future demand for products to optimize inventory levels, reduce holding costs, and improve supply chain efficiency.
|
53 |
- **Transportation and Traffic flow :car:**: Forecasting traffic patterns to optimize route planning, reduce congestion, and improve overall transportation efficiency.
|
54 |
- **Healthcare** 👨⚕️: Predicting the number of patient admissions to hospitals, helping healthcare providers allocate resources effectively and manage staffing levels.
|
55 |
- **Weather 🌦️**: Predicting weather conditions over time, which is crucial for planning various activities, agricultural decisions, and disaster preparedness.
|
56 |
+
""")
|
57 |
|
58 |
|
59 |
|
60 |
+
st.markdown(" ")
|
61 |
|
62 |
|
63 |
|
64 |
|
65 |
+
###################################### USE CASE #######################################
|
66 |
|
67 |
+
# LOAD DATASET
|
68 |
+
path_timeseries = r"data/household"
|
69 |
+
data_model = load_data_pickle(path_timeseries,"household_power_consumption_clean.pkl")
|
70 |
+
data_model.rename({"Date":"ds", "Global_active_power":"y"}, axis=1, inplace=True)
|
71 |
+
data_model.dropna(inplace=True)
|
72 |
+
data_model["ds"] = pd.to_datetime(data_model["ds"])
|
73 |
|
74 |
+
# BEGINNING OF USE CASE
|
75 |
+
st.divider()
|
76 |
+
st.markdown("# Power Consumption Forecasting ⚡")
|
77 |
|
78 |
+
#st.markdown(" ")
|
79 |
+
st.info("""In this use case, a time series forecasting model learns how to accuratly predict the **energy consumption** (or global active power in the dataset) of a household using historical data.
|
80 |
+
A forecasting model can be a valuable tool for energy consumption analysis as it can help **optimize resource planning** and **avoid overloads** during peak demand periods.""")
|
81 |
|
82 |
+
st.markdown(" ")
|
83 |
|
84 |
+
_, col, _ = st.columns([0.15,0.7,0.15])
|
85 |
+
with col:
|
86 |
+
st.image("images/energy_consumption.jpg")
|
87 |
|
88 |
+
st.markdown(" ")
|
89 |
+
st.markdown(" ")
|
90 |
|
91 |
+
st.markdown("### About the data 📋")
|
92 |
|
93 |
+
st.markdown("""You were provided data from the **daily energy consumption** of a household between January 2007 and November 2010 (46 months). <br>
|
94 |
+
The goal is to forecast the **Global active power** being produced daily by the household.
|
95 |
+
Additional variables such as *Global Intensity* and three levels of *Sub-metering* are also available for the forecast.
|
96 |
+
""", unsafe_allow_html=True)
|
97 |
|
98 |
+
st.markdown(" ")
|
99 |
|
100 |
+
st.info("""The data has been split into "historical data" and "to be predicted". Since forecasting models are **supervised**, we will use the household's energy data from January 2007 to December 2009 as historical data to train the model.
|
101 |
+
We will then use the rest of the available data (starting January 2010) to test the performance of the model.""")
|
102 |
|
103 |
+
select_cutoff_date = date(2010, 1, 1)
|
104 |
+
select_cutoff_date = select_cutoff_date.strftime('%Y-%m-%d')
|
105 |
|
106 |
+
# SELECT TRAIN/TEST SET
|
107 |
+
train = data_model[data_model["ds"] <= select_cutoff_date]
|
108 |
+
test = data_model[data_model["ds"] > select_cutoff_date]
|
109 |
|
110 |
+
# PLOT TRAIN/TEST SET
|
111 |
+
train_plot = train.copy()
|
112 |
+
train_plot["split"] = ["historical data"]*len(train_plot)
|
113 |
|
114 |
+
test_plot = test.copy()
|
115 |
+
test_plot["split"] = ["to be predicted"]*len(test_plot)
|
116 |
+
data_clean_plot = pd.concat([train_plot, test_plot]) # plot dataset
|
117 |
|
118 |
+
st.markdown(" ")
|
119 |
+
tab1, tab2, tab3, tab4, tab5 = st.tabs(["Global active power", "Sub metering 1", "Sub metering 2", "Sub metering 3", "Global Intensity"])
|
120 |
|
121 |
+
with tab1:
|
122 |
+
ts_chart = alt.Chart(data_clean_plot).mark_line().encode(
|
123 |
+
x=alt.X('ds:T', axis=alt.Axis(format='%b %Y', tickCount=12), title="Date"),
|
124 |
+
y=alt.Y('y:Q', title="Global active power"),
|
125 |
+
color='split:N',
|
126 |
+
).interactive()
|
127 |
+
|
128 |
+
st.markdown("**View Global active power** (to be forecasted)")
|
129 |
+
st.altair_chart(ts_chart, use_container_width=True)
|
130 |
+
st.success("""**Global active power** refers to the total real power consumed by electrical devices in the house (in kilowatts).""")
|
131 |
+
|
132 |
+
|
133 |
+
with tab2:
|
134 |
+
ts_chart = alt.Chart(data_clean_plot.loc[data_clean_plot["split"]=="historical data"]).mark_line().encode(
|
135 |
+
x=alt.X('ds:T', axis=alt.Axis(format='%b %Y', tickCount=12), title="Date"),
|
136 |
+
y=alt.Y('Sub_metering_1:Q', title="Sub metering 1"),
|
137 |
+
color=alt.Color('split:N')) #, scale=custom_color_scale))
|
138 |
+
|
139 |
+
st.markdown("**View Sub-metering 1** (additional)")
|
140 |
+
st.altair_chart(ts_chart.interactive(), use_container_width=True)
|
141 |
+
st.success("**Sub-metering 1** is the total active power consumed by the kitchen in the house (in kilowatts).")
|
142 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
143 |
|
144 |
|
145 |
+
with tab3:
|
146 |
+
ts_chart = alt.Chart(data_clean_plot.loc[data_clean_plot["split"]=="historical data"]).mark_line().encode(
|
147 |
+
x=alt.X('ds:T', axis=alt.Axis(format='%b %Y', tickCount=12), title="Date"),
|
148 |
+
y=alt.Y('Sub_metering_2:Q', title="Sub metering 2"),
|
149 |
+
color=alt.Color('split:N')) #, scale=custom_color_scale))
|
150 |
|
151 |
+
st.markdown("**View Sub-metering 2** (additional)")
|
152 |
+
st.altair_chart(ts_chart.interactive(), use_container_width=True)
|
153 |
+
st.success("**Sub-metering 2** is the total active power consumed by the laundry room in the house (in kilowatts).")
|
|
|
|
|
154 |
|
|
|
|
|
|
|
155 |
|
156 |
+
with tab4:
|
157 |
+
ts_chart = alt.Chart(data_clean_plot.loc[data_clean_plot["split"]=="historical data"]).mark_line().encode(
|
158 |
+
x=alt.X('ds:T', axis=alt.Axis(format='%b %Y', tickCount=12), title="Date"),
|
159 |
+
y=alt.Y('Sub_metering_3:Q', title="Sub metering 3"),
|
160 |
+
color=alt.Color('split:N')) #scale=custom_color_scale))
|
161 |
|
162 |
+
st.markdown("**View Sub-metering 3** (additional)")
|
163 |
+
st.altair_chart(ts_chart.interactive(), use_container_width=True)
|
164 |
+
st.success("**Sub-metering 3** is the active power consumed by the electric water heater and air conditioner in the household (in kilowatts).")
|
|
|
|
|
165 |
|
|
|
|
|
|
|
166 |
|
167 |
+
with tab5:
|
168 |
+
custom_color_scale = alt.Scale(range=['red', 'lightcoral'])
|
169 |
+
ts_chart = alt.Chart(data_clean_plot.loc[data_clean_plot["split"]=="historical data"]).mark_line().encode(
|
170 |
+
x=alt.X('ds:T', axis=alt.Axis(format='%b %Y', tickCount=12), title="Date"),
|
171 |
+
y=alt.Y('Global_intensity:Q', title="Global active power"),
|
172 |
+
color=alt.Color('split:N')) # scale=custom_color_scale))
|
173 |
|
174 |
+
st.markdown("**View Global intensity** (additional)")
|
175 |
+
st.altair_chart(ts_chart.interactive(), use_container_width=True)
|
176 |
+
st.success("**Global intensity** is the average current intensity delivered to the household (amps).")
|
|
|
|
|
|
|
177 |
|
|
|
|
|
|
|
178 |
|
179 |
|
180 |
+
st.markdown(" ")
|
181 |
+
st.markdown(" ")
|
182 |
+
st.markdown("### Forecast model 📈")
|
183 |
+
st.markdown("""The forecasting model used in this use case allows **additional data** to be used for training.
|
184 |
+
Try adding more data to the model as it can help improve its performance and accuracy.""")
|
185 |
|
|
|
|
|
|
|
|
|
|
|
186 |
|
187 |
|
188 |
+
# ADD VARIABLES TO ANALYSIS
|
189 |
+
add_var = ["Sub_metering_1", "Sub_metering_2", "Sub_metering_3","Global_intensity"]
|
190 |
+
st.markdown("")
|
191 |
+
select_add_var = st.multiselect("**Add variables to the model**", add_var)
|
192 |
|
193 |
+
if 'model_train' not in st.session_state:
|
194 |
+
st.session_state['model_train'] = False
|
195 |
+
|
196 |
+
# if st.session_state.model_train:
|
197 |
+
# text = "The model has alerady been trained."
|
198 |
+
# else:
|
199 |
+
# st.write("The model hasn't been trained yet")
|
200 |
|
201 |
+
st.markdown("")
|
202 |
+
run_model = st.button("**Run the model**")
|
|
|
|
|
|
|
|
|
|
|
203 |
|
|
|
|
|
204 |
|
205 |
+
st.markdown(" ")
|
206 |
+
st.markdown(" ")
|
|
|
207 |
|
208 |
|
209 |
|
|
|
211 |
|
212 |
################################## SEE RESULTS ###############################
|
213 |
|
214 |
+
if "saved_model" not in st.session_state:
|
215 |
+
st.session_state["saved_model"] = False
|
216 |
|
217 |
|
218 |
+
if run_model:
|
219 |
+
with st.spinner('Wait for it...'):
|
220 |
+
fbmodel, forecast = forecast_prophet(train, test, col=select_add_var)
|
221 |
+
st.session_state.model_train = True
|
222 |
+
st.session_state.saved_model = fbmodel
|
223 |
|
224 |
+
####################### SEE RESULTS ########################
|
225 |
+
st.markdown("#### See the results ☑️")
|
226 |
+
st.info("The model is able to forecast energy consumption as well as learn the predicted data's **trend**, **weekly** and **yearly seasonality**.")
|
227 |
|
228 |
+
tab1_result, tab2_result, tab3_result, tab4_result = st.tabs(["Performance", "Trend", "Weekly seasonality", "Yearly seasonality"])
|
229 |
+
with tab1_result:
|
230 |
+
# Compute model root mean squared error
|
231 |
+
y_true = test_plot["y"]
|
232 |
+
y_pred = forecast["yhat"]
|
233 |
+
error = str(np.round(root_mean_squared_error(y_true, y_pred, ),3))
|
234 |
|
235 |
+
col1, col2 = st.columns([0.1,0.9])
|
236 |
|
237 |
+
with col1:
|
238 |
+
st.markdown("")
|
239 |
+
st.metric(label="**Average error**", value=error)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
240 |
|
241 |
+
with col2:
|
242 |
+
# Create df for true vs predicted plot
|
243 |
+
df_results = pd.concat([test_plot.reset_index(drop=True), forecast.drop(columns=["ds"]).reset_index(drop=True)], axis=1)[["ds","y","yhat"]]
|
244 |
+
df_results = df_results.melt(id_vars="ds")
|
245 |
+
df_results["variable"] = df_results["variable"].map({"y":"true values", "yhat":"predicted values"})
|
246 |
+
df_results.columns = ["Date", "Variable", "Global Active Power"]
|
247 |
+
|
248 |
+
fig = px.line(df_results, x="Date", y="Global Active Power", color="Variable",
|
249 |
+
color_discrete_sequence=["lightblue", "black"], line_dash = 'Variable')
|
250 |
+
|
251 |
+
fig.update_layout(
|
252 |
+
title=f'True vs predicted power consumption',
|
253 |
+
width=1200,
|
254 |
+
height=600
|
255 |
+
)
|
256 |
+
|
257 |
+
st.plotly_chart(fig, use_container_width=True)
|
258 |
+
|
259 |
+
with tab2_result:
|
260 |
+
ymin = forecast["trend"].min()
|
261 |
+
ymax = forecast["trend"].max()
|
262 |
+
|
263 |
+
fig = px.area(forecast, x="ds", y="trend", color_discrete_sequence=["red"]) #range_y=[ymin, ymax])
|
264 |
+
fig.update_layout(title="Trend", xaxis_title="Date", yaxis_title="Trend")
|
265 |
st.plotly_chart(fig, use_container_width=True)
|
266 |
+
st.markdown("""**Interpretation** <br>
|
267 |
+
No trend in the household's energy consumption has been detected by the model.""", unsafe_allow_html=True)
|
268 |
|
269 |
+
with tab3_result:
|
270 |
+
#st.success("**Weekly seasonality** refers to a repeating pattern or variation that occurs on a weekly basis on the energy consumption data.")
|
271 |
+
days_week = dict(zip(np.arange(1,8),["Monday", "Tuedsay", "Wednesday", "Thursday", "Friday", "Saturday", "Sunday"]))
|
272 |
+
forecast_weekly = forecast.copy()
|
273 |
+
forecast_weekly["dayweek"] = forecast_weekly["ds"].apply(lambda x: x.isoweekday()).map(days_week)
|
|
|
|
|
|
|
|
|
274 |
|
275 |
+
fig = px.area(forecast_weekly, x="dayweek", y="weekly", color_discrete_sequence=["purple"])
|
276 |
+
fig.update_layout(title="Weekly seasonality", xaxis_title="Date", yaxis_title="Weekly")
|
277 |
+
st.plotly_chart(fig, use_container_width=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
278 |
|
279 |
+
st.markdown("""**Interpretation** <br>
|
280 |
+
The household consumes more electrical power during the week-end (Saturday and Sunday) then during the week.
|
281 |
+
""", unsafe_allow_html=True)
|
282 |
|
283 |
+
with tab4_result:
|
284 |
+
forecast_year = forecast[["ds","yearly"]].copy()
|
285 |
+
forecast_year["ds_year"] = forecast_year["ds"].apply(lambda x: x.strftime("%B %d"))
|
286 |
+
forecast_year["ds"] = forecast_year["ds"].apply(lambda x: x.strftime("%m-%d"))
|
287 |
+
forecast_year.sort_values(by=["ds"], inplace=True)
|
288 |
+
forecast_year = forecast_year.groupby(["ds","ds_year"]).mean().reset_index()
|
289 |
|
290 |
+
st.markdown("")
|
291 |
+
ts_chart = alt.Chart(forecast_year, title="Yearly seasonality").mark_area(opacity=0.5,line = {'color':'darkblue'}).encode(
|
292 |
+
x=alt.X('ds_year:T', axis=alt.Axis(format='%b', tickCount=12), title="Date"),
|
293 |
+
y=alt.Y('yearly:Q', title="Yearly seasonality"),
|
294 |
+
).interactive()
|
295 |
|
296 |
+
st.altair_chart(ts_chart, use_container_width=True)
|
297 |
|
298 |
+
st.markdown("""**Interpretation** <br>
|
299 |
+
The household consumes more energy during the winter (November to February) and less during the warmer months.
|
300 |
+
""", unsafe_allow_html=True)
|
301 |
+
|
302 |
|
303 |
|
304 |
+
################################## MAKE FUTURE PREDICTIONS ###############################
|
305 |
|
306 |
+
# st.markdown("#### Forecast new values ")
|
307 |
|
308 |
+
# st.info("**The model needs to be trained before it can predict new values.**")
|
309 |
|
310 |
+
# st.
|
311 |
|
312 |
+
# make_predictions = st.button("**Forecast new values**")
|
313 |
|
314 |
+
# if make_predictions is True:
|
315 |
+
# if st.session_state.saved_model is True:
|
316 |
|
317 |
|
pages/topic_modeling.py
CHANGED
@@ -6,193 +6,192 @@ import pandas as pd
|
|
6 |
import matplotlib.pyplot as plt
|
7 |
import plotly.express as px
|
8 |
|
9 |
-
from utils import
|
10 |
from st_pages import add_indentation
|
11 |
|
12 |
|
13 |
st.set_page_config(layout="wide")
|
14 |
|
15 |
-
|
16 |
-
st.title("Topic Modeling 📚")
|
|
|
17 |
|
|
|
|
|
|
|
|
|
18 |
|
19 |
-
st.markdown("
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
It is a useful tool for organizing and summarizing vast amounts of textual data as well as automate the discovery of hidden thematic structures in a corpus of text data, without any prior knowledge.
|
24 |
-
""")
|
25 |
-
|
26 |
-
st.markdown(" ")
|
27 |
-
_, col, _ = st.columns([0.25,0.4,0.35])
|
28 |
-
with col:
|
29 |
-
st.image("images/topic_modeling.gif", caption="An example of Topic Modeling", use_column_width=True)
|
30 |
|
31 |
|
32 |
-
st.markdown("""Common applications of Topic Modeling include:
|
33 |
- **Search Engine Optimization (SEO): 🔎** Determine the main topics/keywords present on a website to optimize content and improve search engine rankings.
|
34 |
- **Customer Support** ✍️: Analyze customer support tickets, emails, and chat transcripts to identify common questions and complaints.
|
35 |
- **Fraud Detection and Risk Management: 🏦** : Detect fraudulent activities, compliance violations, and operational risks by analyzing textual data such as transaction descriptions, audit reports and regulatory filings.
|
36 |
- **Market Research 🌎**: Gain competitive intelligence and make informed decisions regarding product development, marketing strategies, and market positioning by analyzing research reports and industry news.
|
37 |
-
""")
|
38 |
-
|
39 |
-
|
40 |
-
st.markdown(" ")
|
41 |
-
st.divider()
|
42 |
-
|
43 |
-
st.markdown("# Topic modeling on product descriptions 🛍️")
|
44 |
-
st.markdown("""In this use case, we will use a **topic model** to categorize around **20 000 e-commerce products** as well as identify
|
45 |
-
the main types of products solds.""")
|
46 |
|
47 |
-
_, col, _ = st.columns([0.2,0.6,0.2])
|
48 |
-
with col:
|
49 |
-
st.image("images/e-commerce.jpg")
|
50 |
|
51 |
-
st.markdown("
|
|
|
52 |
|
53 |
-
#
|
54 |
-
|
55 |
-
|
56 |
|
57 |
-
|
58 |
-
|
|
|
59 |
|
60 |
-
|
61 |
-
|
62 |
-
##### ABOUT THE USE CASE
|
63 |
-
st.markdown("#### About the data 📋")
|
64 |
-
st.markdown("""You were provided a dataset with around 20 000 products from a large e-commerce retailer. <br>
|
65 |
-
This dataset contains the products' title and description on the website.""", unsafe_allow_html=True)
|
66 |
-
st.info("""**Note**: Some of the descriptions featured below are shown in their 'raw' form, meaning they contain unprocessed html code and special characters.
|
67 |
-
These descriptions were first 'cleaned' (by removing unwanted characters) before being used in the model.""")
|
68 |
-
see_data = st.checkbox('**See the data**', key="credit_score_data") # Corrected the key to use an underscore
|
69 |
-
if see_data:
|
70 |
st.markdown(" ")
|
71 |
-
st.warning("This view only shows a subset of the 20 000 product description used.")
|
72 |
-
data = load_data_pickle(path_data,"data-tm-view.pkl")
|
73 |
-
data_show = data[["TITLE", "DESCRIPTION"]]
|
74 |
-
st.dataframe(data_show.reset_index(drop=True), use_container_width=True)
|
75 |
-
|
76 |
|
77 |
-
|
78 |
-
|
|
|
79 |
|
|
|
|
|
80 |
|
81 |
|
82 |
-
# RUN THE MODEL
|
83 |
-
st.markdown("#### About the model 📚")
|
84 |
-
st.markdown("""**Topic models** can be seen as unsupervised clustering models where text documents are grouped into topics/clusters based on their similarities.
|
85 |
-
We will use here a topic model to automatically categorize/group the retailer's products based on their description,
|
86 |
-
as well as understand what are the most common type of products being sold.""", unsafe_allow_html=True)
|
87 |
|
88 |
-
|
89 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
90 |
|
91 |
-
def show_results():
|
92 |
-
st.markdown("#### See the results ☑️")
|
93 |
-
tab1, tab2 = st.tabs(["Overall results", "Specific Topic Details", ])# "Search Similar Topics"])
|
94 |
-
st.markdown(" ")
|
95 |
|
96 |
-
|
97 |
-
|
98 |
-
st.header("Overall results")
|
99 |
-
st.markdown("""This tab showcases all of the **topics identified** within the product dataset, each topic's most significant words (**top words**), as well as the **proportion**
|
100 |
-
of products that were assigned to the specific topic.""")
|
101 |
|
102 |
-
summary_table = topic_info[['Title','Representation', 'Percentage']].copy()
|
103 |
-
summary_table['Top Words'] = summary_table['Representation'].apply(lambda x: x[:5]) #:5
|
104 |
-
summary_table = summary_table[["Title","Top Words","Percentage"]]
|
105 |
-
summary_table.rename({"Title":"Topic Title"}, axis=1, inplace=True)
|
106 |
-
|
107 |
-
st.data_editor(
|
108 |
-
summary_table, #.loc[df_results_tab1["Customer ID"].isin(filter_customers)],
|
109 |
-
column_config={
|
110 |
-
"Percentage": st.column_config.ProgressColumn(
|
111 |
-
"Proportion %",
|
112 |
-
help="Propotion of documents within each topic",
|
113 |
-
format="%.1f%%",
|
114 |
-
min_value=0,
|
115 |
-
max_value=100)},
|
116 |
-
use_container_width=True
|
117 |
-
)
|
118 |
-
|
119 |
-
st.info("""**Note**: The topic 'titles' were not provided by the model but instead were generated by feeding the topic's top words to an LLM.
|
120 |
-
Traditional topic models define topics using representative/top words but weren't built to generate a specific title to each topic.""")
|
121 |
-
|
122 |
-
# Tab 2: Specific Topic Details
|
123 |
-
with tab2:
|
124 |
-
|
125 |
-
# Load top words
|
126 |
-
with open(os.path.join(path_data,"topics_top_words.json"), "r") as json_file:
|
127 |
-
top_words_dict = json.load(json_file)
|
128 |
-
|
129 |
-
# Load similarity df and scores
|
130 |
-
similarity_df = load_data_pickle(path_data, "similarity_topic_df.pkl")
|
131 |
-
similarity_scores = load_numpy(path_data, "similarity_topic_scores.npy")
|
132 |
|
133 |
-
#st.markdown(" ")
|
134 |
-
st.header("Learn more about each topic")
|
135 |
-
st.markdown("""You can **select a specific topic** to get more information on its **top words**, as well as the
|
136 |
-
**other topics that are most similar to it**.""")
|
137 |
-
# st.info("""In this section, you can find more information on each of the topics identified by the model.
|
138 |
-
# This includes the topic's a full list of its top words, the importance of each of these words, as well as the top five topics that are most similar to it.""")
|
139 |
|
140 |
-
|
|
|
|
|
|
|
|
|
141 |
|
142 |
-
|
143 |
-
|
144 |
-
selected_topic = st.selectbox('**Select a Topic**', topics)
|
145 |
-
selected_topic_id = topic_info[topic_info['Title'] == selected_topic]["Topic"].to_numpy()[0] + 1
|
146 |
|
|
|
|
|
|
|
147 |
st.markdown(" ")
|
148 |
-
col1, col2 = st.columns(2)
|
149 |
|
150 |
-
#
|
151 |
-
with
|
152 |
-
|
153 |
-
|
154 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
155 |
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
168 |
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
st.
|
175 |
-
st.
|
176 |
-
|
177 |
-
|
178 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
179 |
|
180 |
-
if 'button_clicked' not in st.session_state:
|
181 |
-
|
182 |
|
183 |
-
def run_model():
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
|
188 |
-
|
189 |
-
|
|
|
|
|
|
|
190 |
show_results()
|
191 |
-
|
192 |
-
|
193 |
-
show_results()
|
194 |
-
|
195 |
-
run_model()
|
196 |
|
197 |
|
198 |
|
|
|
6 |
import matplotlib.pyplot as plt
|
7 |
import plotly.express as px
|
8 |
|
9 |
+
from utils import load_data_pickle, load_numpy, check_password
|
10 |
from st_pages import add_indentation
|
11 |
|
12 |
|
13 |
st.set_page_config(layout="wide")
|
14 |
|
15 |
+
if check_password():
|
16 |
+
st.title("Topic Modeling 📚")
|
17 |
+
st.markdown("### What is Topic Modeling ?")
|
18 |
|
19 |
+
st.info("""
|
20 |
+
Topic modeling is a text-mining technique used to **identify topics within a collection of documents**.
|
21 |
+
It is a useful tool for organizing and summarizing vast amounts of textual data as well as automate the discovery of hidden thematic structures in a corpus of text data, without any prior knowledge.
|
22 |
+
""")
|
23 |
|
24 |
+
st.markdown(" ")
|
25 |
+
_, col, _ = st.columns([0.25,0.4,0.35])
|
26 |
+
with col:
|
27 |
+
st.image("images/topic_modeling.gif", caption="An example of Topic Modeling", use_column_width=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
|
29 |
|
30 |
+
st.markdown("""Common applications of Topic Modeling include:
|
31 |
- **Search Engine Optimization (SEO): 🔎** Determine the main topics/keywords present on a website to optimize content and improve search engine rankings.
|
32 |
- **Customer Support** ✍️: Analyze customer support tickets, emails, and chat transcripts to identify common questions and complaints.
|
33 |
- **Fraud Detection and Risk Management: 🏦** : Detect fraudulent activities, compliance violations, and operational risks by analyzing textual data such as transaction descriptions, audit reports and regulatory filings.
|
34 |
- **Market Research 🌎**: Gain competitive intelligence and make informed decisions regarding product development, marketing strategies, and market positioning by analyzing research reports and industry news.
|
35 |
+
""")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
|
|
|
|
|
|
|
37 |
|
38 |
+
st.markdown(" ")
|
39 |
+
st.divider()
|
40 |
|
41 |
+
st.markdown("# Topic modeling on product descriptions 🛍️")
|
42 |
+
st.markdown("""In this use case, we will use a **topic model** to categorize around **20 000 e-commerce products** as well as identify
|
43 |
+
the main types of products solds.""")
|
44 |
|
45 |
+
_, col, _ = st.columns([0.2,0.6,0.2])
|
46 |
+
with col:
|
47 |
+
st.image("images/e-commerce.jpg")
|
48 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
st.markdown(" ")
|
|
|
|
|
|
|
|
|
|
|
50 |
|
51 |
+
# Load data
|
52 |
+
path_data = "data/topic-modeling"
|
53 |
+
# data = load_data_csv(path_data,"data-topicmodeling.csv")
|
54 |
|
55 |
+
# Load the topic data
|
56 |
+
topic_info = load_data_pickle(path_data, 'topic_info.pkl')
|
57 |
|
58 |
|
|
|
|
|
|
|
|
|
|
|
59 |
|
60 |
+
##### ABOUT THE USE CASE
|
61 |
+
st.markdown("#### About the data 📋")
|
62 |
+
st.markdown("""You were provided a dataset with around 20 000 products from a large e-commerce retailer. <br>
|
63 |
+
This dataset contains the products' title and description on the website.""", unsafe_allow_html=True)
|
64 |
+
st.info("""**Note**: Some of the descriptions featured below are shown in their 'raw' form, meaning they contain unprocessed html code and special characters.
|
65 |
+
These descriptions were first 'cleaned' (by removing unwanted characters) before being used in the model.""")
|
66 |
+
see_data = st.checkbox('**See the data**', key="credit_score_data") # Corrected the key to use an underscore
|
67 |
+
if see_data:
|
68 |
+
st.markdown(" ")
|
69 |
+
st.warning("This view only shows a subset of the 20 000 product description used.")
|
70 |
+
data = load_data_pickle(path_data,"data-tm-view.pkl")
|
71 |
+
data_show = data[["TITLE", "DESCRIPTION"]]
|
72 |
+
st.dataframe(data_show.reset_index(drop=True), use_container_width=True)
|
73 |
|
|
|
|
|
|
|
|
|
74 |
|
75 |
+
st.markdown(" ")
|
76 |
+
st.markdown(" ")
|
|
|
|
|
|
|
77 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
78 |
|
|
|
|
|
|
|
|
|
|
|
|
|
79 |
|
80 |
+
# RUN THE MODEL
|
81 |
+
st.markdown("#### About the model 📚")
|
82 |
+
st.markdown("""**Topic models** can be seen as unsupervised clustering models where text documents are grouped into topics/clusters based on their similarities.
|
83 |
+
We will use here a topic model to automatically categorize/group the retailer's products based on their description,
|
84 |
+
as well as understand what are the most common type of products being sold.""", unsafe_allow_html=True)
|
85 |
|
86 |
+
st.info("""**Note**: In topic modeling, the final topics are represented by the model using 'top words'.
|
87 |
+
A topic's top words are chosen based on how much they appear in the topic's documents.""")
|
|
|
|
|
88 |
|
89 |
+
def show_results():
|
90 |
+
st.markdown("#### See the results ☑️")
|
91 |
+
tab1, tab2 = st.tabs(["Overall results", "Specific Topic Details", ])# "Search Similar Topics"])
|
92 |
st.markdown(" ")
|
|
|
93 |
|
94 |
+
# Tab 1: Summary Table
|
95 |
+
with tab1:
|
96 |
+
st.header("Overall results")
|
97 |
+
st.markdown("""This tab showcases all of the **topics identified** within the product dataset. <br>
|
98 |
+
Each topic's <b>most significant words</b> (top words), as well as the <b>proportion</b> of products that were assigned to it are given.""",
|
99 |
+
unsafe_allow_html=True)
|
100 |
+
|
101 |
+
summary_table = topic_info[['Title','Representation', 'Percentage']].copy()
|
102 |
+
summary_table['Top Words'] = summary_table['Representation'].apply(lambda x: x[:5]) #:5
|
103 |
+
summary_table = summary_table[["Title","Top Words","Percentage"]]
|
104 |
+
summary_table.rename({"Title":"Topic Title"}, axis=1, inplace=True)
|
105 |
|
106 |
+
st.data_editor(
|
107 |
+
summary_table, #.loc[df_results_tab1["Customer ID"].isin(filter_customers)],
|
108 |
+
column_config={
|
109 |
+
"Percentage": st.column_config.ProgressColumn(
|
110 |
+
"Proportion %",
|
111 |
+
help="Propotion of documents within each topic",
|
112 |
+
format="%.1f%%",
|
113 |
+
min_value=0,
|
114 |
+
max_value=100)},
|
115 |
+
use_container_width=True
|
116 |
+
)
|
117 |
+
|
118 |
+
st.info("""**Note**: The topic 'titles' were not provided by the model but instead were generated by feeding the topic's top words to an LLM.
|
119 |
+
Traditional topic models define topics using representative/top words but weren't built to generate a specific title to each topic.""")
|
120 |
+
|
121 |
+
# Tab 2: Specific Topic Details
|
122 |
+
with tab2:
|
123 |
+
|
124 |
+
# Load top words
|
125 |
+
with open(os.path.join(path_data,"topics_top_words.json"), "r") as json_file:
|
126 |
+
top_words_dict = json.load(json_file)
|
127 |
|
128 |
+
# Load similarity df and scores
|
129 |
+
similarity_df = load_data_pickle(path_data, "similarity_topic_df.pkl")
|
130 |
+
similarity_scores = load_numpy(path_data, "similarity_topic_scores.npy")
|
131 |
+
|
132 |
+
#st.markdown(" ")
|
133 |
+
st.header("Learn more about each topic")
|
134 |
+
st.markdown("""You can **select a specific topic** to get more information on its **top words**, as well as the
|
135 |
+
**other topics that are most similar to it**.""")
|
136 |
+
# st.info("""In this section, you can find more information on each of the topics identified by the model.
|
137 |
+
# This includes the topic's a full list of its top words, the importance of each of these words, as well as the top five topics that are most similar to it.""")
|
138 |
+
|
139 |
+
st.markdown(" ")
|
140 |
+
|
141 |
+
# Select topic
|
142 |
+
topics = topic_info["Title"].sort_values().to_list()
|
143 |
+
selected_topic = st.selectbox('**Select a Topic**', topics)
|
144 |
+
selected_topic_id = topic_info[topic_info['Title'] == selected_topic]["Topic"].to_numpy()[0] + 1
|
145 |
+
|
146 |
+
st.markdown(" ")
|
147 |
+
col1, col2 = st.columns(2)
|
148 |
+
|
149 |
+
# Top words
|
150 |
+
with col1:
|
151 |
+
top_words_df = pd.DataFrame(top_words_dict[selected_topic], columns=["Word", "Importance"])
|
152 |
+
top_words_df.sort_values(by=["Importance"], ascending=False, inplace=True)
|
153 |
+
top_words_df["Importance"] = top_words_df["Importance"].round(2)
|
154 |
+
|
155 |
+
fig = px.bar(top_words_df, x='Word', y='Importance', color="Importance", title="Top words", text_auto=True)
|
156 |
+
fig.update_layout(yaxis=dict(range=[0, 1]), xaxis_title="", showlegend=False)
|
157 |
+
st.plotly_chart(fig, use_container_width=True)
|
158 |
+
st.info("""**Note:** Each score was computed based on the words importance in the particular topic using
|
159 |
+
a popular metric in NLP called TF-IDF (Term Frequency-Inverse Document Frequency). """)
|
160 |
+
|
161 |
+
|
162 |
+
# Similar topics to the selected topic
|
163 |
+
with col2:
|
164 |
+
similarity_df = similarity_df.loc[similarity_df["Topic"]==selected_topic]
|
165 |
+
similarity_df["scores"] = 100*similarity_scores[selected_topic_id,:]
|
166 |
+
similarity_df.columns = ["Original Topic", "Rank", "Topic", "Similarity (%)"]
|
167 |
+
|
168 |
+
fig = px.bar(similarity_df, y='Similarity (%)', x='Topic', color="Topic", title="Five most similar topics", text_auto=True)
|
169 |
+
fig.update_layout(yaxis=dict(range=[0, 100]),
|
170 |
+
xaxis_title="",
|
171 |
+
showlegend=False)
|
172 |
+
|
173 |
+
st.plotly_chart(fig, use_container_width=True)
|
174 |
+
st.info("""**Note:** Topics with a high similarity score can be merged together as to reduce the number of topics, as
|
175 |
+
well as improve the topics' coherence.""")
|
176 |
+
|
177 |
+
return None
|
178 |
|
179 |
+
if 'button_clicked' not in st.session_state:
|
180 |
+
st.session_state['button_clicked'] = False
|
181 |
|
182 |
+
def run_model():
|
183 |
+
run_model = st.button("**Run the model**", type="primary")
|
184 |
+
st.markdown(" ")
|
185 |
+
st.markdown(" ")
|
186 |
|
187 |
+
if not st.session_state['button_clicked']:
|
188 |
+
if run_model:
|
189 |
+
show_results()
|
190 |
+
st.session_state['button_clicked'] = True
|
191 |
+
else:
|
192 |
show_results()
|
193 |
+
|
194 |
+
run_model()
|
|
|
|
|
|
|
195 |
|
196 |
|
197 |
|
utils.py
CHANGED
@@ -108,9 +108,12 @@ def load_model_huggingface(repo_id, token, task=None):
|
|
108 |
def check_password():
|
109 |
"""Returns `True` if the user had the correct password."""
|
110 |
|
|
|
|
|
|
|
111 |
def password_entered():
|
112 |
"""Checks whether a password entered by the user is correct."""
|
113 |
-
if "password" in st.session_state and st.session_state["password"] ==
|
114 |
st.session_state["password_correct"] = True
|
115 |
del st.session_state["password"] # don't store password
|
116 |
else:
|
@@ -134,10 +137,6 @@ def check_password():
|
|
134 |
return True
|
135 |
|
136 |
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
###################### OTHER ######################
|
142 |
|
143 |
def img_to_bytes(img_path):
|
|
|
108 |
def check_password():
|
109 |
"""Returns `True` if the user had the correct password."""
|
110 |
|
111 |
+
password_key = os.environ["PASSWORD"]
|
112 |
+
#password_key = st.secrets["password"]
|
113 |
+
|
114 |
def password_entered():
|
115 |
"""Checks whether a password entered by the user is correct."""
|
116 |
+
if "password" in st.session_state and st.session_state["password"] == password_key:
|
117 |
st.session_state["password_correct"] = True
|
118 |
del st.session_state["password"] # don't store password
|
119 |
else:
|
|
|
137 |
return True
|
138 |
|
139 |
|
|
|
|
|
|
|
|
|
140 |
###################### OTHER ######################
|
141 |
|
142 |
def img_to_bytes(img_path):
|