# projectai / cluster.py
# Author: Matthew Frazer
# Commit b516268 ("Create cluster.py"), 622 bytes
# File 8: scripts/cluster.py
from pathlib import Path

import numpy as np
import torch
from sklearn.cluster import KMeans

from utils.data_loader import UserFeedbackDataset
def update_latent_clusters(
    *,
    n_clusters=10,
    output_path='storage/models/latent_prior.pt',
    random_state=None,
    dataset=None,
):
    """Fit K-means on stored latent vectors and save the cluster centers.

    Each dataset item's latent is taken from ``item[0]``. The K-means
    cluster centers are written with ``torch.save`` as a float32 tensor of
    shape ``(n_clusters, n_features)`` and are intended to serve as the new
    latent prior.

    Args:
        n_clusters: Number of K-means clusters (default 10, as before).
        output_path: File the centers tensor is saved to. Missing parent
            directories are created.
        random_state: Seed forwarded to ``KMeans`` for reproducible
            centers. ``None`` keeps the previous unseeded behaviour.
        dataset: Iterable whose items have a latent tensor at index 0.
            Defaults to a fresh ``UserFeedbackDataset()``.

    Returns:
        torch.Tensor: The cluster centers that were saved.

    Raises:
        ValueError: If ``n_clusters < 1``, the dataset yields no latents,
            or it yields fewer latent rows than ``n_clusters``.
    """
    if n_clusters < 1:
        raise ValueError(f"n_clusters must be >= 1, got {n_clusters}")
    if dataset is None:
        dataset = UserFeedbackDataset()

    # detach()/cpu() let .numpy() work for tensors that require grad or live
    # on an accelerator. atleast_2d turns a 1-D latent of shape (d,) into a
    # single row; otherwise torch.cat would flatten every latent into one
    # long 1-D vector, which KMeans cannot cluster.
    rows = [torch.atleast_2d(item[0].detach().cpu()) for item in dataset]
    if not rows:
        raise ValueError("dataset yielded no latent vectors to cluster")
    latents = torch.cat(rows).numpy()

    if latents.shape[0] < n_clusters:
        raise ValueError(
            f"need at least n_clusters={n_clusters} latent vectors, "
            f"got {latents.shape[0]}"
        )

    # Only the centers are needed; per-sample labels are not used.
    kmeans = KMeans(n_clusters=n_clusters, random_state=random_state)
    kmeans.fit(latents)

    cluster_centers = torch.tensor(kmeans.cluster_centers_, dtype=torch.float32)

    # torch.save does not create directories itself.
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    torch.save(cluster_centers, output_path)
    return cluster_centers