# File 8: scripts/cluster.py
from pathlib import Path
from typing import Optional

import numpy as np
import torch
from sklearn.cluster import KMeans

from utils.data_loader import UserFeedbackDataset


def update_latent_clusters(
    *,
    n_clusters: int = 10,
    output_path: str = 'storage/models/latent_prior.pt',
    random_state: Optional[int] = None,
) -> None:
    """Cluster the stored latent vectors and save the cluster centers.

    Loads every item from ``UserFeedbackDataset``, takes the first element
    of each item as its latent tensor, runs k-means over all of them, and
    writes the resulting centers to ``output_path`` as a float32 tensor.
    The original comments describe the saved tensor as the new VAE prior;
    how the file is consumed is outside this module.

    Args:
        n_clusters: Number of k-means clusters (and saved centers).
        output_path: Destination file for the center tensor. Missing parent
            directories are created.
        random_state: Seed forwarded to ``KMeans``. ``None`` keeps the
            original, unseeded behaviour.

    Raises:
        RuntimeError: If the dataset yields no items.
        ValueError: Propagated from ``KMeans`` when the data cannot be
            clustered (for example fewer samples than ``n_clusters``).
    """
    dataset = UserFeedbackDataset()

    # NOTE(review): assumes item[0] is the latent tensor of each dataset
    # item -- confirm against UserFeedbackDataset.
    # A 1-D latent is promoted to a single row so the concatenation below
    # is a 2-D matrix, which KMeans requires; 2-D inputs are left as-is.
    samples = [torch.atleast_2d(item[0]) for item in dataset]
    if not samples:
        raise RuntimeError(
            "UserFeedbackDataset yielded no items; cannot compute latent clusters"
        )

    # detach()/cpu() make the NumPy conversion safe for tensors that track
    # gradients or live on an accelerator; both are no-ops otherwise.
    latents = torch.cat(samples).detach().cpu().numpy()

    # Only the fitted centers are needed, so per-sample labels are not
    # computed separately.
    kmeans = KMeans(n_clusters=n_clusters, random_state=random_state)
    kmeans.fit(latents)

    cluster_centers = torch.tensor(kmeans.cluster_centers_, dtype=torch.float32)

    # Ensure the target directory exists so the save cannot fail on a
    # fresh checkout.
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    torch.save(cluster_centers, output_path)