Spaces:
Sleeping
Sleeping
Commit
•
3861db2
1
Parent(s):
e6fd8db
Upload pages_clustering.py
Browse files- pages_clustering.py +415 -0
pages_clustering.py
ADDED
@@ -0,0 +1,415 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from collections import defaultdict
|
2 |
+
import streamlit as st
|
3 |
+
from utils import load_and_preprocess_data
|
4 |
+
import pandas as pd
|
5 |
+
import numpy as np
|
6 |
+
import altair as alt
|
7 |
+
from sklearn.mixture import GaussianMixture
|
8 |
+
import plotly.express as px
|
9 |
+
import itertools
|
10 |
+
from typing import Dict, List, Tuple
|
11 |
+
|
12 |
+
|
13 |
+
SIDEBAR_DESCRIPTION = """
|
14 |
+
# Client clustering
|
15 |
+
|
16 |
+
To cluster a client, we adopt the RFM metrics. They stand for:
|
17 |
+
|
18 |
+
- R = recency, that is the number of days since the last purchase
|
19 |
+
in the store
|
20 |
+
- F = frequency, that is the number of times a customer has ordered something
|
21 |
+
- M = monetary value, that is how much a customer has spent buying
|
22 |
+
from your business.
|
23 |
+
|
24 |
+
Given these 3 metrics, we can cluster the customers and find a suitable
|
25 |
+
"definition" based on the clusters they belong to. Since the dataset
|
26 |
+
we're using right now has about 5000 distinct customers, we identify
|
27 |
+
3 clusters for each metric.
|
28 |
+
|
29 |
+
## How we compute the clusters
|
30 |
+
|
31 |
+
We resort to a GaussianMixture algorithm. We can think of GaussianMixture
|
32 |
+
as generalized k-means clustering that incorporates information about
|
33 |
+
the covariance structure of the data as well as the centers of the clusters.
|
34 |
+
""".lstrip()
|
35 |
+
|
36 |
+
FREQUENCY_CLUSTERS_EXPLAIN = """
|
37 |
+
The **frequency** denotes how frequently a customer has ordered.
|
38 |
+
|
39 |
+
There 3 available clusters for this metric:
|
40 |
+
|
41 |
+
- cluster 1: denotes a customer that purchases one or few times (range [{}, {}])
|
42 |
+
- cluster 2: these customer have a discrete amount of orders (range [{}, {}])
|
43 |
+
- cluster 3: these customer purchases lots of times (range [{}, {}])
|
44 |
+
|
45 |
+
-------
|
46 |
+
""".lstrip()
|
47 |
+
|
48 |
+
RECENCY_CLUSTERS_EXPLAIN = """
|
49 |
+
The **recency** refers to how recently a customer has bought;
|
50 |
+
|
51 |
+
There 3 available clusters for this metric:
|
52 |
+
|
53 |
+
- cluster 1: the last order of these client is long time ago (range [{}, {}])
|
54 |
+
- cluster 2: these are clients that purchases something not very recently (range [{}, {}])
|
55 |
+
- cluster 3: the last order of these client is a few days/weeks ago (range [{}, {}])
|
56 |
+
|
57 |
+
-------
|
58 |
+
""".lstrip()
|
59 |
+
|
60 |
+
MONETARY_CLUSTERS_EXPLAIN = """
|
61 |
+
The **revenue** refers to how much a customer has spent buying
|
62 |
+
from your business.
|
63 |
+
|
64 |
+
There 3 available clusters for this metric:
|
65 |
+
|
66 |
+
- cluster 1: these clients spent little money (range [{}, {}])
|
67 |
+
- cluster 2: these clients spent a considerable amount of money (range [{}, {}])
|
68 |
+
- cluster 3: these clients spent lots of money (range [{}, {}])
|
69 |
+
|
70 |
+
-------
|
71 |
+
""".lstrip()
|
72 |
+
|
73 |
+
EXPLANATION_DICT = {
|
74 |
+
"Frequency_cluster": FREQUENCY_CLUSTERS_EXPLAIN,
|
75 |
+
"Recency_cluster": RECENCY_CLUSTERS_EXPLAIN,
|
76 |
+
"Revenue_cluster": MONETARY_CLUSTERS_EXPLAIN,
|
77 |
+
}
|
78 |
+
|
79 |
+
|
80 |
+
def create_features(df: pd.DataFrame):
|
81 |
+
"""Creates a new dataframe with the RFM features for each client."""
|
82 |
+
# Compute frequency, the number of distinct time a user purchased.
|
83 |
+
client_features = df.groupby("CustomerID")["InvoiceDate"].nunique().reset_index()
|
84 |
+
client_features.columns = ["CustomerID", "Frequency"]
|
85 |
+
|
86 |
+
# Add monetary value, the total revenue for each single user.
|
87 |
+
client_takings = df.groupby("CustomerID")["Price"].sum()
|
88 |
+
client_features["Revenue"] = client_takings.values
|
89 |
+
|
90 |
+
# Add recency, i.e. the days since the last purchase in the store.
|
91 |
+
max_date = df.groupby("CustomerID")["InvoiceDate"].max().reset_index()
|
92 |
+
max_date.columns = ["CustomerID", "LastPurchaseDate"]
|
93 |
+
|
94 |
+
client_features["Recency"] = (
|
95 |
+
max_date["LastPurchaseDate"].max() - max_date["LastPurchaseDate"]
|
96 |
+
).dt.days
|
97 |
+
|
98 |
+
return client_features
|
99 |
+
|
100 |
+
|
101 |
+
@st.cache
|
102 |
+
def cluster_clients(df: pd.DataFrame):
|
103 |
+
"""Computes the RFM features and clusters for each user based on the RFM metrics."""
|
104 |
+
|
105 |
+
df_rfm = create_features(df)
|
106 |
+
|
107 |
+
for to_cluster, order in zip(
|
108 |
+
["Revenue", "Frequency", "Recency"], ["ascending", "ascending", "descending"]
|
109 |
+
):
|
110 |
+
kmeans = GaussianMixture(n_components=3, random_state=42)
|
111 |
+
labels = kmeans.fit_predict(df_rfm[[to_cluster]])
|
112 |
+
df_rfm[f"{to_cluster}_cluster"] = _order_cluster(kmeans, labels, order)
|
113 |
+
|
114 |
+
return df_rfm
|
115 |
+
|
116 |
+
|
117 |
+
def _order_cluster(cluster_model: GaussianMixture, clusters, order="ascending"):
|
118 |
+
"""Orders the cluster by `order`."""
|
119 |
+
centroids = cluster_model.means_.sum(axis=1)
|
120 |
+
|
121 |
+
if order.lower() == "descending":
|
122 |
+
centroids *= -1
|
123 |
+
|
124 |
+
ascending_order = np.argsort(centroids)
|
125 |
+
lookup_table = np.zeros_like(ascending_order)
|
126 |
+
# Cluster will start from 1
|
127 |
+
lookup_table[ascending_order] = np.arange(cluster_model.n_components) + 1
|
128 |
+
return lookup_table[clusters]
|
129 |
+
|
130 |
+
|
131 |
+
def show_purhcase_history(user: int, df: pd.DataFrame):
|
132 |
+
user_purchases = df.loc[df.CustomerID == user, ["Price", "InvoiceDate"]]
|
133 |
+
expenses = user_purchases.groupby(user_purchases.InvoiceDate).sum()
|
134 |
+
expenses.columns = ["Expenses"]
|
135 |
+
expenses = expenses.reset_index()
|
136 |
+
|
137 |
+
c = (
|
138 |
+
alt.Chart(expenses)
|
139 |
+
.mark_line(point=True)
|
140 |
+
.encode(
|
141 |
+
x=alt.X("InvoiceDate", timeUnit="yearmonthdate", title="Date"),
|
142 |
+
y="Expenses",
|
143 |
+
)
|
144 |
+
.properties(title="User expenses")
|
145 |
+
)
|
146 |
+
|
147 |
+
st.altair_chart(c)
|
148 |
+
|
149 |
+
|
150 |
+
def show_user_info(user: int, df_rfm: pd.DataFrame):
|
151 |
+
"""Prints some information about the user.
|
152 |
+
|
153 |
+
The main information are the total expenses, how
|
154 |
+
many times he purchases in the store, and the clusters
|
155 |
+
he belongs to.
|
156 |
+
"""
|
157 |
+
|
158 |
+
user_row = df_rfm[df_rfm["CustomerID"] == user]
|
159 |
+
if len(user_row) == 0:
|
160 |
+
st.write(f"No user with id {user}")
|
161 |
+
|
162 |
+
output = []
|
163 |
+
|
164 |
+
output.append(f"The user purchased **{user_row['Frequency'].squeeze()} times**.\n")
|
165 |
+
output.append(
|
166 |
+
f"She/he spent **{user_row['Revenue'].squeeze()} dollars** in total.\n"
|
167 |
+
)
|
168 |
+
output.append(
|
169 |
+
f"The last time she/he bought something was **{user_row['Recency'].squeeze()} days ago**.\n"
|
170 |
+
)
|
171 |
+
output.append(f"She/he belongs to the clusters: ")
|
172 |
+
for cluster in [column for column in user_row.columns if "_cluster" in column]:
|
173 |
+
output.append(f"- {cluster} = {user_row[cluster].squeeze()}")
|
174 |
+
|
175 |
+
st.write("\n".join(output))
|
176 |
+
|
177 |
+
return (
|
178 |
+
user_row["Recency_cluster"].squeeze(),
|
179 |
+
user_row["Frequency_cluster"].squeeze(),
|
180 |
+
user_row["Revenue_cluster"].squeeze(),
|
181 |
+
)
|
182 |
+
|
183 |
+
|
184 |
+
def explain_cluster(cluster_info):
|
185 |
+
"""Displays a popup menu explinging the meanining of the clusters."""
|
186 |
+
|
187 |
+
with st.expander("Show information about the clusters"):
|
188 |
+
st.write(
|
189 |
+
"**Note**: these values are valid for these dataset."
|
190 |
+
"Different dataset will have different number of clusters"
|
191 |
+
" and values"
|
192 |
+
)
|
193 |
+
for cluster, info in cluster_info.items():
|
194 |
+
# Transform the (mins, maxs) tuple into
|
195 |
+
# [min_1, max_1, min_2, max_2, ...] list.
|
196 |
+
min_max_interleaved = list(itertools.chain(*zip(info[0], info[1])))
|
197 |
+
st.write(EXPLANATION_DICT[cluster].format(*min_max_interleaved))
|
198 |
+
|
199 |
+
|
200 |
+
def categorize_user(recency_cluster, frequency_cluster, monetary_cluster):
|
201 |
+
"""Describe the user with few words based on the cluster he belongs to."""
|
202 |
+
|
203 |
+
score = f"{recency_cluster}{frequency_cluster}{monetary_cluster}"
|
204 |
+
|
205 |
+
# @fixme: find a better approeach. These elif chains don't scale at all.
|
206 |
+
|
207 |
+
description = ""
|
208 |
+
|
209 |
+
if score == "111":
|
210 |
+
description = "Tourist"
|
211 |
+
elif score.startswith("2"):
|
212 |
+
description = "Losing interest"
|
213 |
+
elif score == "133":
|
214 |
+
description = "Former lover"
|
215 |
+
elif score == "123":
|
216 |
+
description = "Former passionate client"
|
217 |
+
elif score == "113":
|
218 |
+
description = "Spent a lot, but never come back"
|
219 |
+
elif score.startswith("1"):
|
220 |
+
description = "About to dump"
|
221 |
+
elif score == "313":
|
222 |
+
description = "Potential lover"
|
223 |
+
elif score == "312":
|
224 |
+
description = "Interesting new client"
|
225 |
+
elif score == "311":
|
226 |
+
description = "New customer"
|
227 |
+
elif score == "333":
|
228 |
+
description = "Gold client"
|
229 |
+
elif score == "322":
|
230 |
+
description = "Lovers"
|
231 |
+
else:
|
232 |
+
description = "Average client"
|
233 |
+
|
234 |
+
st.write(f"The customer can be described as: **{description}**")
|
235 |
+
|
236 |
+
|
237 |
+
def plot_rfm_distribution(
|
238 |
+
df_rfm: pd.DataFrame, cluster_info: Dict[str, Tuple[List[int], List[int]]]
|
239 |
+
):
|
240 |
+
"""Plots 3 histograms for the RFM metrics."""
|
241 |
+
|
242 |
+
for x, to_reverse in zip(("Revenue", "Frequency", "Recency"), (False, False, True)):
|
243 |
+
fig = px.histogram(
|
244 |
+
df_rfm,
|
245 |
+
x=x,
|
246 |
+
log_y=True,
|
247 |
+
title=f"{x} metric",
|
248 |
+
)
|
249 |
+
# Get the max value in the cluster info. The cluster_info_dict is a
|
250 |
+
# tuple with first element the min values of the cluster, and second
|
251 |
+
# element the max values of the cluster.
|
252 |
+
values = cluster_info[f"{x}_cluster"][1] # get max values
|
253 |
+
print(values)
|
254 |
+
# Add vertical bar on each cluster end. But skip the last cluster.
|
255 |
+
loop_range = range(len(values) - 1)
|
256 |
+
if to_reverse:
|
257 |
+
# Skip the last element
|
258 |
+
loop_range = range(len(values) - 1, 0, -1)
|
259 |
+
for n_cluster in loop_range:
|
260 |
+
print(x)
|
261 |
+
print(values[n_cluster])
|
262 |
+
fig.add_vline(
|
263 |
+
x=values[n_cluster],
|
264 |
+
annotation_text=f"End of cluster {n_cluster+1}",
|
265 |
+
line_dash="dot",
|
266 |
+
annotation=dict(textangle=90, font_color="red"),
|
267 |
+
)
|
268 |
+
|
269 |
+
fig.update_layout(
|
270 |
+
yaxis_title="Count (log scale)",
|
271 |
+
)
|
272 |
+
|
273 |
+
st.plotly_chart(fig)
|
274 |
+
|
275 |
+
|
276 |
+
def display_dataframe_heatmap(df_rfm: pd.DataFrame, cluster_info_dict):
|
277 |
+
"""Displays an heatmap of how many clients lay in the clusters.
|
278 |
+
|
279 |
+
This method uses some black magic coming from the dataframe
|
280 |
+
styling guide.
|
281 |
+
"""
|
282 |
+
|
283 |
+
def style_with_limits(x, column, cluster_limit_dict):
|
284 |
+
"""Simple function to transform the cluster number into
|
285 |
+
a cluster + range string."""
|
286 |
+
min_v = cluster_limit_dict[column][0][x - 1]
|
287 |
+
max_v = cluster_limit_dict[column][1][x - 1]
|
288 |
+
return f"{x}: [{int(min_v)}, {int(max_v)}]"
|
289 |
+
|
290 |
+
# Create a dataframe with the count of clients for each group
|
291 |
+
# of cluster.
|
292 |
+
|
293 |
+
count = (
|
294 |
+
df_rfm.groupby(["Recency_cluster", "Frequency_cluster", "Revenue_cluster"])[
|
295 |
+
"CustomerID"
|
296 |
+
]
|
297 |
+
.count()
|
298 |
+
.reset_index()
|
299 |
+
)
|
300 |
+
count = count.rename(columns={"CustomerID": "Count"})
|
301 |
+
|
302 |
+
# Remove duplicates
|
303 |
+
count = count.drop_duplicates(
|
304 |
+
["Revenue_cluster", "Frequency_cluster", "Recency_cluster"]
|
305 |
+
)
|
306 |
+
|
307 |
+
# Add limits to the cells. In this way, we can better display
|
308 |
+
# the heatmap.
|
309 |
+
for cluster in ["Revenue_cluster", "Frequency_cluster", "Recency_cluster"]:
|
310 |
+
count[cluster] = count[cluster].apply(
|
311 |
+
lambda x: style_with_limits(x, cluster, cluster_info_dict)
|
312 |
+
)
|
313 |
+
|
314 |
+
# Use the count column as values, then index with the clusters.
|
315 |
+
count = count.pivot(
|
316 |
+
index=["Revenue_cluster", "Frequency_cluster"],
|
317 |
+
columns="Recency_cluster",
|
318 |
+
values="Count",
|
319 |
+
)
|
320 |
+
|
321 |
+
# Style manipulation
|
322 |
+
cell_hover = {
|
323 |
+
"selector": "td",
|
324 |
+
"props": "font-size:1.2em",
|
325 |
+
}
|
326 |
+
index_names = {
|
327 |
+
"selector": ".index_name",
|
328 |
+
"props": "font-style: italic; color: Black; font-weight:normal;font-size:1.2em;",
|
329 |
+
}
|
330 |
+
headers = {
|
331 |
+
"selector": "th:not(.index_name)",
|
332 |
+
"props": "background-color: White; color: black; font-size:1.2em",
|
333 |
+
}
|
334 |
+
|
335 |
+
# Finally, display
|
336 |
+
# We cannot directly print the dataframe since the streamlit
|
337 |
+
# functin remove the multiindex. Thus, we extract the html representation
|
338 |
+
# and then display it.
|
339 |
+
st.markdown("## Heatmap: how the client are distributed between clusters")
|
340 |
+
st.write(
|
341 |
+
count.style.format(thousands=" ", precision=0, na_rep="0")
|
342 |
+
.set_table_styles([cell_hover, index_names, headers])
|
343 |
+
.background_gradient(cmap="coolwarm")
|
344 |
+
.to_html(),
|
345 |
+
unsafe_allow_html=True,
|
346 |
+
)
|
347 |
+
|
348 |
+
|
349 |
+
def main():
|
350 |
+
st.sidebar.markdown(SIDEBAR_DESCRIPTION)
|
351 |
+
|
352 |
+
df, _, _ = load_and_preprocess_data()
|
353 |
+
df_rfm = cluster_clients(df)
|
354 |
+
|
355 |
+
st.markdown(
|
356 |
+
"# Dataset "
|
357 |
+
"\nThis is the processed dataset with information about the clients, such as"
|
358 |
+
" the RFM values and the clusters they belong to."
|
359 |
+
)
|
360 |
+
st.dataframe(df_rfm.style.format(formatter={"Revenue": "{:.2f}"}))
|
361 |
+
|
362 |
+
cluster_info_dict = defaultdict(list)
|
363 |
+
|
364 |
+
with st.expander("Show more details about the clusters"):
|
365 |
+
for cluster in [column for column in df_rfm.columns if "_cluster" in column]:
|
366 |
+
st.write(cluster)
|
367 |
+
cluster_info = (
|
368 |
+
df_rfm.groupby(cluster)[cluster.split("_")[0]]
|
369 |
+
.describe()
|
370 |
+
.reset_index(names="Cluster")
|
371 |
+
)
|
372 |
+
min_cluster = cluster_info["min"].astype(int)
|
373 |
+
max_cluster = cluster_info["max"].astype(int)
|
374 |
+
cluster_info_dict[cluster] = (min_cluster, max_cluster)
|
375 |
+
st.dataframe(cluster_info)
|
376 |
+
|
377 |
+
st.markdown("## RFM metric distribution")
|
378 |
+
|
379 |
+
plot_rfm_distribution(df_rfm, cluster_info_dict)
|
380 |
+
|
381 |
+
display_dataframe_heatmap(df_rfm, cluster_info_dict)
|
382 |
+
|
383 |
+
st.markdown("## Interactive exploration")
|
384 |
+
|
385 |
+
filter_by_cluster = st.checkbox(
|
386 |
+
"Filter client: only one client per cluster type",
|
387 |
+
value=True,
|
388 |
+
)
|
389 |
+
|
390 |
+
client_to_select = (
|
391 |
+
df_rfm.groupby(["Recency_cluster", "Frequency_cluster", "Revenue_cluster"])[
|
392 |
+
"CustomerID"
|
393 |
+
]
|
394 |
+
.first()
|
395 |
+
.values
|
396 |
+
if filter_by_cluster
|
397 |
+
else df["CustomerID"].unique()
|
398 |
+
)
|
399 |
+
|
400 |
+
# Let the user select the user to investigate
|
401 |
+
user = st.selectbox(
|
402 |
+
"Select a customer to show more information about him.",
|
403 |
+
client_to_select,
|
404 |
+
)
|
405 |
+
|
406 |
+
show_purhcase_history(user, df)
|
407 |
+
|
408 |
+
recency, frequency, revenue = show_user_info(user, df_rfm)
|
409 |
+
|
410 |
+
categorize_user(recency, frequency, revenue)
|
411 |
+
|
412 |
+
explain_cluster(cluster_info_dict)
|
413 |
+
|
414 |
+
|
415 |
+
main()
|