import os

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

from config import NUM_CLUSTERS, OUTPUT_DIR


class ClusterAnalyzer:
    """Cluster feature vectors with a StandardScaler -> PCA -> KMeans pipeline.

    The fitted scaler, PCA and KMeans objects are kept on the instance
    (``scaler``, ``pca``, ``kmeans``) so they can be inspected after
    :meth:`fit_predict` has run.
    """

    def __init__(self, n_clusters=NUM_CLUSTERS):
        """Store the requested cluster count and create an unfitted scaler.

        Args:
            n_clusters: Upper bound on the number of KMeans clusters; the
                effective number is capped by the sample count at fit time.
        """
        self.n_clusters = n_clusters
        self.scaler = StandardScaler()
        self.kmeans = None
        self.pca = None

    def fit_predict(self, features):
        """Standardize, PCA-reduce and cluster *features*.

        Args:
            features: 2-D array-like of shape (n_samples, n_features).

        Returns:
            A ``(labels, features_reduced)`` tuple: the KMeans label of each
            sample and the PCA-reduced matrix the clustering was run on.
        """
        features_scaled = self.scaler.fit_transform(features)
        n_samples, n_features = features_scaled.shape

        # Adaptive PCA: at most 50 components, bounded by n_samples - 1 and
        # n_features, but never fewer than one.
        n_components = max(1, min(n_samples - 1, n_features, 50))
        print(f"Using {n_components} PCA components (adapted to data size)")
        self.pca = PCA(n_components=n_components)
        features_reduced = self.pca.fit_transform(features_scaled)

        # KMeans cannot produce more clusters than there are samples.
        n_clusters = max(1, min(self.n_clusters, len(features_reduced)))
        print(f"Using {n_clusters} clusters")
        self.kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
        labels = self.kmeans.fit_predict(features_reduced)

        return labels, features_reduced

    def get_cluster_centers(self):
        """Return the fitted KMeans cluster centers, or None before fitting."""
        if self.kmeans is None:
            return None
        return self.kmeans.cluster_centers_

    def visualize_clusters(self, features, labels, image_paths, save_path=None):
        """Draw a 2-D scatter plot of the samples coloured by cluster label.

        Args:
            features: 2-D array-like of shape (n_samples, n_features). It is
                projected to two dimensions with PCA when there are more
                than two samples and more than two features; otherwise the
                first two columns are used, zero-padded if fewer exist.
            labels: Cluster label of each sample.
            image_paths: Paths matching the samples; the first (up to) 15
                points are annotated with their file name.
            save_path: If truthy, the figure is written there at 300 dpi.
        """
        features = np.asarray(features)
        labels = np.asarray(labels)
        n_samples, n_features = features.shape

        if n_samples > 2 and n_features > 2:
            features_2d = PCA(n_components=2).fit_transform(features)
        elif n_features >= 2:
            features_2d = features[:, :2]
        else:
            # Fewer than two columns: pad with zeros so there are two axes.
            padding = np.zeros((n_samples, 2 - n_features))
            features_2d = np.hstack([features, padding])

        fig = plt.figure(figsize=(12, 8))
        try:
            unique_labels = np.unique(labels)
            title = 'Drill Core Sample Clusters (PCA Visualization)'
            if unique_labels.size > 1:
                scatter = plt.scatter(features_2d[:, 0], features_2d[:, 1],
                                      c=labels, cmap='tab10', alpha=0.7, s=100)
                plt.colorbar(scatter)
            else:
                plt.scatter(features_2d[:, 0], features_2d[:, 1],
                            c='blue', alpha=0.7, s=100)
                if unique_labels.size == 1:
                    # Previously this title was set and then immediately
                    # overwritten by the generic one, so it never showed.
                    title = ('All samples in single cluster '
                             f'(Cluster {unique_labels[0]})')

            plt.title(title, fontsize=16)
            plt.xlabel('Feature Dimension 1')
            plt.ylabel('Feature Dimension 2')

            # Annotate at most the first 15 points; zip also stops early if
            # fewer paths than points were supplied.
            for path, (x, y) in zip(image_paths, features_2d[:15]):
                filename = os.path.basename(path)
                if len(filename) > 15:
                    # Only mark names that were actually truncated.
                    filename = filename[:15] + "..."
                plt.annotate(filename, (x, y), xytext=(5, 5),
                             textcoords='offset points', fontsize=8, alpha=0.7)

            plt.tight_layout()

            if save_path:
                plt.savefig(save_path, dpi=300, bbox_inches='tight')
                print(f"Cluster visualization saved to {save_path}")

            plt.show()
        finally:
            # Release the figure so repeated calls do not accumulate them.
            plt.close(fig)

    def create_cluster_map(self, image_paths, labels):
        """Group image paths by cluster label.

        Returns:
            A dict mapping each label to the list of paths assigned to it,
            in the order the paths were given.
        """
        cluster_map = {}
        for path, label in zip(image_paths, labels):
            cluster_map.setdefault(label, []).append(path)
        return cluster_map

    def analyze_cluster_characteristics(self, features, labels, image_paths):
        """Compute per-cluster summary statistics.

        Args:
            features: 2-D feature matrix, one row per sample.
            labels: Cluster label of each sample.
            image_paths: Paths matching the samples.

        Returns:
            A dict keyed by cluster id. Each value holds ``'count'`` (number
            of images in the cluster), ``'mean_features'`` and
            ``'std_features'`` (column-wise over the cluster's rows) and
            ``'sample_images'`` (the first five paths of the cluster).
        """
        # Element-wise comparison below needs an array, not a plain list.
        labels = np.asarray(labels)
        cluster_stats = {}

        for cluster_id in np.unique(labels):
            mask = labels == cluster_id
            cluster_features = features[mask]
            cluster_images = [
                path for path, selected in zip(image_paths, mask) if selected
            ]

            cluster_stats[cluster_id] = {
                'count': len(cluster_images),
                'mean_features': np.mean(cluster_features, axis=0),
                'std_features': np.std(cluster_features, axis=0),
                'sample_images': cluster_images[:5],  # First 5 samples
            }

        return cluster_stats

    def analyze_clusters(self, features, image_paths):
        """Run clustering, summarise the clusters, plot them and print a report.

        Args:
            features: 2-D feature matrix, one row per image.
            image_paths: Paths matching the rows of *features*.

        Returns:
            A ``(labels, cluster_map, cluster_stats)`` tuple as produced by
            :meth:`fit_predict`, :meth:`create_cluster_map` and
            :meth:`analyze_cluster_characteristics`.
        """
        print(f"Performing clustering analysis on {len(image_paths)} samples...")
        print(f"Feature dimensions: {features.shape}")

        labels, features_reduced = self.fit_predict(features)
        cluster_map = self.create_cluster_map(image_paths, labels)

        # Statistics are computed on the original features, not the reduced ones.
        cluster_stats = self.analyze_cluster_characteristics(
            features, labels, image_paths
        )

        # Visualize if we have enough samples
        if len(image_paths) > 2:
            if OUTPUT_DIR:
                # savefig does not create missing directories.
                os.makedirs(OUTPUT_DIR, exist_ok=True)
            viz_path = os.path.join(OUTPUT_DIR, "clusters.png")
            # Plot the standardized, PCA-reduced features that KMeans was
            # actually fitted on so the picture matches the clustering.
            self.visualize_clusters(features_reduced, labels, image_paths, viz_path)

        print("\n" + "="*60)
        print("CLUSTER ANALYSIS RESULTS")
        print("="*60)

        for cluster_id, stats in cluster_stats.items():
            print(f"\nCluster {cluster_id}:")
            print(f" Samples: {stats['count']} images")
            print(" Sample files:")
            for path in stats['sample_images']:
                print(f" - {os.path.basename(path)}")

        return labels, cluster_map, cluster_stats


if __name__ == "__main__":
    pass