Spaces:
Runtime error
Runtime error
| # evaluate.py | |
| # Purpose: small evaluation and visualization for clusters | |
| import numpy as np | |
| import pandas as pd | |
| from sklearn.metrics import silhouette_score | |
| import matplotlib.pyplot as plt | |
| import seaborn as sns | |
def silhouette(embs, labels):
    """Return the silhouette score of a clustering, ignoring noise points.

    Points whose label is negative (e.g. the ``-1`` noise label) are dropped
    before scoring.

    Args:
        embs: array-like of per-point vectors, one row per entry of ``labels``.
        labels: array-like of integer cluster labels.

    Returns:
        The value of ``sklearn.metrics.silhouette_score`` on the non-noise
        points, or ``None`` when the score is undefined: fewer than two
        distinct clusters remain, or every remaining point is its own cluster.
    """
    labels = np.asarray(labels)
    mask = labels >= 0
    kept = labels[mask]
    n_clusters = np.unique(kept).size
    # silhouette_score requires 2 <= n_clusters <= n_samples - 1 and raises
    # ValueError otherwise; report those degenerate clusterings as None.
    if n_clusters < 2 or n_clusters >= kept.size:
        return None
    return silhouette_score(np.asarray(embs)[mask], kept)
def cluster_stats(df_original, labels, *, agg=None):
    """Summarise each cluster with per-column aggregates.

    Args:
        df_original: feature table with one row per entry of ``labels``. It is
            copied, so the caller's DataFrame is not modified.
        labels: cluster label for each row. Negative/noise labels are grouped
            like any other label value.
        agg: optional mapping of column name -> aggregation, passed to
            ``DataFrameGroupBy.agg``. Defaults to the non-null count of
            ``customer_id`` and the medians of ``annual_income`` and
            ``spend_score``.

    Returns:
        DataFrame indexed by cluster label with one column per ``agg`` key.

    Raises:
        KeyError: if a column named in ``agg`` is absent from ``df_original``.
        ValueError: if ``labels`` does not match the number of rows.
    """
    if agg is None:
        agg = {
            'customer_id': 'count',
            'annual_income': 'median',
            'spend_score': 'median',
        }
    df = df_original.copy()
    df['cluster'] = labels
    return df.groupby('cluster').agg(agg)
def main(argv=None):
    """Command-line entry point: print the silhouette score and cluster stats.

    Args:
        argv: optional argument list for ``argparse``; ``None`` means
            ``sys.argv[1:]``.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description='Evaluate cluster labels against embeddings and features.')
    parser.add_argument('--features', default='data/features.parquet')
    parser.add_argument('--emb', default='data/embeddings.npy')
    parser.add_argument('--labels', default='data/cluster_labels.npy')
    args = parser.parse_args(argv)

    df = pd.read_parquet(args.features)
    embs = np.load(args.emb)
    labels = np.load(args.labels)
    # silhouette() indexes embs with a boolean mask built from labels; without
    # this check a length mismatch surfaces as an opaque IndexError.
    if len(embs) != len(labels):
        parser.error(
            f'{args.emb} has {len(embs)} rows but {args.labels} has '
            f'{len(labels)} labels')

    s = silhouette(embs, labels)
    print('Silhouette score (ignoring noise labels -1):', s)

    # Descriptive stats are best-effort output: report why they failed
    # instead of attributing every failure to missing columns.
    try:
        stats = cluster_stats(df, labels)
    except KeyError as err:
        print('Could not compute descriptive stats (missing columns).', err)
    except (ValueError, TypeError) as err:
        # e.g. label count != number of feature rows, or a non-numeric column.
        print(f'Could not compute descriptive stats: {err}')
    else:
        print(stats)


if __name__ == '__main__':
    main()