tave-st committed on
Commit
ce71a2d
1 Parent(s): f6249e2

Add range in heatmap

Browse files
Files changed (3) hide show
  1. pages/clustering.py +53 -30
  2. recommender.py +1 -1
  3. recommender_system.py +0 -1
pages/clustering.py CHANGED
@@ -7,7 +7,7 @@ import altair as alt
7
  from sklearn.mixture import GaussianMixture
8
  import plotly.express as px
9
  import itertools
10
- from typing import Dict, List
11
 
12
 
13
  SIDEBAR_DESCRIPTION = """
@@ -15,10 +15,10 @@ SIDEBAR_DESCRIPTION = """
15
 
16
  To cluster a client, we adopt the RFM metrics. They stand for:
17
 
18
- - R = recency, that is the number of days since the last purchase
19
  in the store
20
  - F = frequency, that is the number of times a customer has ordered something
21
- - M = monetary value, that is how much a customer has spent buying
22
  from your business.
23
 
24
  Given these 3 metrics, we can cluster the customers and find a suitable
@@ -28,8 +28,8 @@ we're using right now has about 5000 distinct customers, we identify
28
 
29
  ## How we compute the clusters
30
 
31
- We resort to a GaussianMixture algorithm. We can think of GaussianMixture
32
- as generalized k-means clustering that incorporates information about
33
  the covariance structure of the data as well as the centers of the clusters.
34
  """.lstrip()
35
 
@@ -46,7 +46,7 @@ There 3 available clusters for this metric:
46
  """.lstrip()
47
 
48
  RECENCY_CLUSTERS_EXPLAIN = """
49
- The **recency** refers to how recently a customer has bought;
50
 
51
  There 3 available clusters for this metric:
52
 
@@ -58,7 +58,7 @@ There 3 available clusters for this metric:
58
  """.lstrip()
59
 
60
  MONETARY_CLUSTERS_EXPLAIN = """
61
- The **revenue** refers to how much a customer has spent buying
62
  from your business.
63
 
64
  There 3 available clusters for this metric:
@@ -115,7 +115,7 @@ def cluster_clients(df: pd.DataFrame):
115
 
116
 
117
  def _order_cluster(cluster_model: GaussianMixture, clusters, order="ascending"):
118
- """Orders the cluster by order."""
119
  centroids = cluster_model.means_.sum(axis=1)
120
 
121
  if order.lower() == "descending":
@@ -191,7 +191,10 @@ def explain_cluster(cluster_info):
191
  " and values"
192
  )
193
  for cluster, info in cluster_info.items():
194
- st.write(EXPLANATION_DICT[cluster].format(*info))
 
 
 
195
 
196
 
197
  def categorize_user(recency_cluster, frequency_cluster, monetary_cluster):
@@ -231,7 +234,9 @@ def categorize_user(recency_cluster, frequency_cluster, monetary_cluster):
231
  st.write(f"The customer can be described as: **{description}**")
232
 
233
 
234
- def plot_rfm_distribution(df_rfm: pd.DataFrame, cluster_info: Dict[str, List[int]]):
 
 
235
  """Plots 3 histograms for the RFM metrics."""
236
 
237
  for x, to_reverse in zip(("Revenue", "Frequency", "Recency"), (False, False, True)):
@@ -241,20 +246,21 @@ def plot_rfm_distribution(df_rfm: pd.DataFrame, cluster_info: Dict[str, List[int
241
  log_y=True,
242
  title=f"{x} metric",
243
  )
244
- # Get the max value in the cluster info. The cluster info is a list of min - max
245
- # values per cluster.
246
- values = cluster_info[f"{x}_cluster"]
 
247
  print(values)
248
  # Add vertical bar on each cluster end. But skip the last cluster.
249
- loop_range = list(enumerate(range(1, len(values)-1, 2)))
250
  if to_reverse:
251
- # @todo: remove hardcoded values
252
- loop_range = zip((2, 1), range(len(values)-1, 1, -2))
253
- for n_cluster, i in loop_range:
254
  print(x)
255
- print(values[i])
256
  fig.add_vline(
257
- x=values[i],
258
  annotation_text=f"End of cluster {n_cluster+1}",
259
  line_dash="dot",
260
  annotation=dict(textangle=90, font_color="red"),
@@ -267,13 +273,20 @@ def plot_rfm_distribution(df_rfm: pd.DataFrame, cluster_info: Dict[str, List[int
267
  st.plotly_chart(fig)
268
 
269
 
270
- def display_dataframe_heatmap(df_rfm: pd.DataFrame):
271
  """Displays a heatmap of how many clients lie in the clusters.
272
 
273
  This method uses some black magic coming from the dataframe
274
  styling guide.
275
  """
276
 
 
 
 
 
 
 
 
277
  # Create a dataframe with the count of clients for each group
278
  # of cluster.
279
 
@@ -291,6 +304,13 @@ def display_dataframe_heatmap(df_rfm: pd.DataFrame):
291
  ["Revenue_cluster", "Frequency_cluster", "Recency_cluster"]
292
  )
293
 
 
 
 
 
 
 
 
294
  # Use the count column as values, then index with the clusters.
295
  count = count.pivot(
296
  index=["Revenue_cluster", "Frequency_cluster"],
@@ -301,15 +321,15 @@ def display_dataframe_heatmap(df_rfm: pd.DataFrame):
301
  # Style manipulation
302
  cell_hover = {
303
  "selector": "td",
304
- "props": "font-size:1.5em",
305
  }
306
  index_names = {
307
  "selector": ".index_name",
308
- "props": "font-style: italic; color: Black; font-weight:normal;font-size:1.5em;",
309
  }
310
  headers = {
311
  "selector": "th:not(.index_name)",
312
- "props": "background-color: White; color: black; font-size:1.5em",
313
  }
314
 
315
  # Finally, display
@@ -336,7 +356,7 @@ def main():
336
  "# Dataset "
337
  "\nThis is the processed dataset with information about the clients, such as"
338
  " the RFM values and the clusters they belong to."
339
- )
340
  st.dataframe(df_rfm.style.format(formatter={"Revenue": "{:.2f}"}))
341
 
342
  cluster_info_dict = defaultdict(list)
@@ -351,15 +371,14 @@ def main():
351
  )
352
  min_cluster = cluster_info["min"].astype(int)
353
  max_cluster = cluster_info["max"].astype(int)
354
- min_max_interlieved = list(itertools.chain(*zip(min_cluster, max_cluster)))
355
- cluster_info_dict[cluster].extend(min_max_interlieved)
356
  st.dataframe(cluster_info)
357
 
358
  st.markdown("## RFM metric distribution")
359
 
360
  plot_rfm_distribution(df_rfm, cluster_info_dict)
361
 
362
- display_dataframe_heatmap(df_rfm)
363
 
364
  st.markdown("## Interactive exploration")
365
 
@@ -369,9 +388,13 @@ def main():
369
  )
370
 
371
  client_to_select = (
372
- df_rfm.groupby(["Recency_cluster", "Frequency_cluster", "Revenue_cluster"])["CustomerID"].first().values
373
- if filter_by_cluster
374
- else df["CustomerID"].unique()
 
 
 
 
375
  )
376
 
377
  # Let the user select the user to investigate
 
7
  from sklearn.mixture import GaussianMixture
8
  import plotly.express as px
9
  import itertools
10
+ from typing import Dict, List, Tuple
11
 
12
 
13
  SIDEBAR_DESCRIPTION = """
 
15
 
16
  To cluster a client, we adopt the RFM metrics. They stand for:
17
 
18
+ - R = recency, that is the number of days since the last purchase
19
  in the store
20
  - F = frequency, that is the number of times a customer has ordered something
21
+ - M = monetary value, that is how much a customer has spent buying
22
  from your business.
23
 
24
  Given these 3 metrics, we can cluster the customers and find a suitable
 
28
 
29
  ## How we compute the clusters
30
 
31
+ We resort to a GaussianMixture algorithm. We can think of GaussianMixture
32
+ as generalized k-means clustering that incorporates information about
33
  the covariance structure of the data as well as the centers of the clusters.
34
  """.lstrip()
35
 
 
46
  """.lstrip()
47
 
48
  RECENCY_CLUSTERS_EXPLAIN = """
49
+ The **recency** refers to how recently a customer has bought;
50
 
51
  There 3 available clusters for this metric:
52
 
 
58
  """.lstrip()
59
 
60
  MONETARY_CLUSTERS_EXPLAIN = """
61
+ The **revenue** refers to how much a customer has spent buying
62
  from your business.
63
 
64
  There 3 available clusters for this metric:
 
115
 
116
 
117
  def _order_cluster(cluster_model: GaussianMixture, clusters, order="ascending"):
118
+ """Orders the cluster by `order`."""
119
  centroids = cluster_model.means_.sum(axis=1)
120
 
121
  if order.lower() == "descending":
 
191
  " and values"
192
  )
193
  for cluster, info in cluster_info.items():
194
+ # Transform the (mins, maxs) tuple into
195
+ # [min_1, max_1, min_2, max_2, ...] list.
196
+ min_max_interleaved = list(itertools.chain(*zip(info[0], info[1])))
197
+ st.write(EXPLANATION_DICT[cluster].format(*min_max_interleaved))
198
 
199
 
200
  def categorize_user(recency_cluster, frequency_cluster, monetary_cluster):
 
234
  st.write(f"The customer can be described as: **{description}**")
235
 
236
 
237
+ def plot_rfm_distribution(
238
+ df_rfm: pd.DataFrame, cluster_info: Dict[str, Tuple[List[int], List[int]]]
239
+ ):
240
  """Plots 3 histograms for the RFM metrics."""
241
 
242
  for x, to_reverse in zip(("Revenue", "Frequency", "Recency"), (False, False, True)):
 
246
  log_y=True,
247
  title=f"{x} metric",
248
  )
249
+ # Get the max values from the cluster info. Each cluster_info entry is a
250
+ # tuple whose first element holds the min values of the clusters, and whose
251
+ # second element holds the max values of the clusters.
252
+ values = cluster_info[f"{x}_cluster"][1] # get max values
253
  print(values)
254
  # Add vertical bar on each cluster end. But skip the last cluster.
255
+ loop_range = range(len(values) - 1)
256
  if to_reverse:
257
+ # Skip the last element
258
+ loop_range = range(len(values) - 1, 0, -1)
259
+ for n_cluster in loop_range:
260
  print(x)
261
+ print(values[n_cluster])
262
  fig.add_vline(
263
+ x=values[n_cluster],
264
  annotation_text=f"End of cluster {n_cluster+1}",
265
  line_dash="dot",
266
  annotation=dict(textangle=90, font_color="red"),
 
273
  st.plotly_chart(fig)
274
 
275
 
276
+ def display_dataframe_heatmap(df_rfm: pd.DataFrame, cluster_info_dict):
277
  """Displays a heatmap of how many clients lie in the clusters.
278
 
279
  This method uses some black magic coming from the dataframe
280
  styling guide.
281
  """
282
 
283
+ def style_with_limits(x, column, cluster_limit_dict):
284
+ """Simple function to transform the cluster number into
285
+ a cluster + range string."""
286
+ min_v = cluster_limit_dict[column][0][x - 1]
287
+ max_v = cluster_limit_dict[column][1][x - 1]
288
+ return f"{x}: [{int(min_v)}, {int(max_v)}]"
289
+
290
  # Create a dataframe with the count of clients for each group
291
  # of cluster.
292
 
 
304
  ["Revenue_cluster", "Frequency_cluster", "Recency_cluster"]
305
  )
306
 
307
+ # Add limits to the cells. In this way, we can better display
308
+ # the heatmap.
309
+ for cluster in ["Revenue_cluster", "Frequency_cluster", "Recency_cluster"]:
310
+ count[cluster] = count[cluster].apply(
311
+ lambda x: style_with_limits(x, cluster, cluster_info_dict)
312
+ )
313
+
314
  # Use the count column as values, then index with the clusters.
315
  count = count.pivot(
316
  index=["Revenue_cluster", "Frequency_cluster"],
 
321
  # Style manipulation
322
  cell_hover = {
323
  "selector": "td",
324
+ "props": "font-size:1.2em",
325
  }
326
  index_names = {
327
  "selector": ".index_name",
328
+ "props": "font-style: italic; color: Black; font-weight:normal;font-size:1.2em;",
329
  }
330
  headers = {
331
  "selector": "th:not(.index_name)",
332
+ "props": "background-color: White; color: black; font-size:1.2em",
333
  }
334
 
335
  # Finally, display
 
356
  "# Dataset "
357
  "\nThis is the processed dataset with information about the clients, such as"
358
  " the RFM values and the clusters they belong to."
359
+ )
360
  st.dataframe(df_rfm.style.format(formatter={"Revenue": "{:.2f}"}))
361
 
362
  cluster_info_dict = defaultdict(list)
 
371
  )
372
  min_cluster = cluster_info["min"].astype(int)
373
  max_cluster = cluster_info["max"].astype(int)
374
+ cluster_info_dict[cluster] = (min_cluster, max_cluster)
 
375
  st.dataframe(cluster_info)
376
 
377
  st.markdown("## RFM metric distribution")
378
 
379
  plot_rfm_distribution(df_rfm, cluster_info_dict)
380
 
381
+ display_dataframe_heatmap(df_rfm, cluster_info_dict)
382
 
383
  st.markdown("## Interactive exploration")
384
 
 
388
  )
389
 
390
  client_to_select = (
391
+ df_rfm.groupby(["Recency_cluster", "Frequency_cluster", "Revenue_cluster"])[
392
+ "CustomerID"
393
+ ]
394
+ .first()
395
+ .values
396
+ if filter_by_cluster
397
+ else df["CustomerID"].unique()
398
  )
399
 
400
  # Let the user select the user to investigate
recommender.py CHANGED
@@ -82,7 +82,7 @@ class Recommender:
82
  def recommend_products(
83
  self,
84
  user_id,
85
- items_to_recommend = 5,
86
  ):
87
  """Finds the recommended items for the user.
88
 
 
82
  def recommend_products(
83
  self,
84
  user_id,
85
+ items_to_recommend=5,
86
  ):
87
  """Finds the recommended items for the user.
88
 
recommender_system.py CHANGED
@@ -242,7 +242,6 @@ def display_recommendation_plots(
242
  items_other_description = _extract_description(df, bought_by_similar_users)
243
  suggestion_description = _extract_description(df, suggestions)
244
 
245
-
246
  # Plot the scatterplot
247
 
248
  fig = go.Figure()
 
242
  items_other_description = _extract_description(df, bought_by_similar_users)
243
  suggestion_description = _extract_description(df, suggestions)
244
 
 
245
  # Plot the scatterplot
246
 
247
  fig = go.Figure()