# File 8: scripts/cluster.py
from pathlib import Path

import numpy as np
import torch
from sklearn.cluster import KMeans

from utils.data_loader import UserFeedbackDataset
def update_latent_clusters(*, n_clusters=10,
                           output_path='storage/models/latent_prior.pt',
                           random_state=None):
    """Cluster the stored latent vectors and save the centers as the new prior.

    Iterates over ``UserFeedbackDataset``, takes element ``0`` of every sample
    as its latent tensor, fits k-means on all of them and writes the cluster
    centers (a float32 tensor, one row per cluster) to ``output_path`` with
    ``torch.save``.

    Args:
        n_clusters: Number of k-means clusters to fit. Defaults to 10.
        output_path: File the cluster centers are saved to. Missing parent
            directories are created.
        random_state: Optional seed forwarded to ``KMeans``. ``None`` (the
            default) keeps scikit-learn's unseeded behaviour.

    Raises:
        ValueError: If the dataset yields no samples. ``KMeans`` also raises
            ``ValueError`` when there are fewer samples than ``n_clusters``.
    """
    dataset = UserFeedbackDataset()

    rows = []
    for sample in dataset:
        # detach()/cpu() so the later .numpy() also works for tensors that
        # require grad or live on an accelerator.
        latent = sample[0].detach().cpu()
        # NOTE(review): a 0-D/1-D sample is assumed to be a single latent
        # vector and becomes one row; concatenating 1-D tensors directly would
        # produce one flat array that KMeans rejects. Confirm against the
        # dataset's actual sample shape.
        if latent.dim() < 2:
            latent = latent.reshape(1, -1)
        rows.append(latent)

    if not rows:
        raise ValueError("UserFeedbackDataset is empty; no latents to cluster")

    latents = torch.cat(rows).numpy()

    # Cluster latent vectors. Only the fitted centers are needed, so the
    # per-sample labels are not computed separately.
    kmeans = KMeans(n_clusters=n_clusters, random_state=random_state)
    kmeans.fit(latents)

    # Update VAE prior with cluster centers
    cluster_centers = torch.tensor(kmeans.cluster_centers_, dtype=torch.float32)

    # Save new prior distribution; make sure the target directory exists first.
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    torch.save(cluster_centers, output_path)