# File size: 2,296 Bytes
# 874b2d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from collections.abc import Sequence
from matplotlib.figure import Figure


def plot_cluster_counts(labels: Sequence[int]) -> Figure:
    """
    Build a bar chart of how many samples fall into each cluster.

    Args:
        labels: Sequence of integer cluster labels, one per sample.
    Returns:
        Matplotlib Figure whose single Axes shows one bar per distinct
        label, with bars ordered by label value (not by frequency).
    """
    # Tally samples per label, then order the bars by label value.
    sizes = pd.Series(labels).value_counts()
    sizes = sizes.sort_index()
    # Stringify the labels so the bars are drawn as categories.
    bar_names = sizes.index.astype(str)
    bar_heights = sizes.values

    fig, ax = plt.subplots(figsize=(8, 5))
    ax.bar(bar_names, bar_heights, edgecolor="black")

    ax.set_title("Cluster Size Distribution", fontsize=14, fontweight="bold")
    ax.set_xlabel("Cluster Label", fontsize=12)
    ax.set_ylabel("Number of Samples", fontsize=12)
    ax.grid(axis="y", linestyle="--", alpha=0.6)

    plt.tight_layout()
    return fig


def visualize_clusters(
    X: np.ndarray,
    labels: Sequence[int],
    centers: np.ndarray,
    *,
    xlabel: str = "Annual Income ($K)",
    ylabel: str = "Spending Score",
) -> Figure:
    """
    Scatter plot of clustered 2-D data with the cluster centroids overlaid.

    Args:
        X: Array of shape (n_samples, 2); column 0 is drawn on the x-axis
            and column 1 on the y-axis.
        labels: Cluster label for each row of ``X`` (list, tuple or array).
        centers: Array of shape (n_centers, 2) holding centroid coordinates.
        xlabel: Text for the x-axis label.
        ylabel: Text for the y-axis label.
    Returns:
        Matplotlib Figure with one scatter series per distinct label plus
        one series for the centroids.
    """
    X = np.asarray(X)
    # Convert to an array so ``labels == cluster`` yields an element-wise
    # boolean mask; with a plain list the comparison collapses to one bool
    # and every cluster would be plotted empty.
    labels = np.asarray(labels)
    centers = np.asarray(centers)

    cmap = plt.get_cmap("tab10")

    fig, ax = plt.subplots(figsize=(8, 6))
    for idx, cluster in enumerate(np.unique(labels)):
        mask = labels == cluster
        ax.scatter(
            X[mask, 0], X[mask, 1],
            s=50,
            label=f"Cluster {cluster}",
            # tab10 holds only cmap.N (10) colours; larger integer indices
            # are clamped to the last colour, so cycle through the palette.
            color=cmap(idx % cmap.N),
            edgecolor="k",
            alpha=0.7,
        )

    # Centroids are drawn last so they sit on top of the sample points.
    ax.scatter(
        centers[:, 0], centers[:, 1],
        s=200,
        marker="X",
        c="black",
        label="Centroids",
        linewidths=2,
    )

    ax.set_title("Cluster Visualization", fontsize=14, fontweight="bold")
    ax.set_xlabel(xlabel, fontsize=14)
    ax.set_ylabel(ylabel, fontsize=14)
    ax.legend(title="Clusters", fontsize=10, title_fontsize=12)
    ax.grid(True, linestyle="--", alpha=0.6)
    fig.tight_layout()
    return fig