File size: 622 Bytes
b516268
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# File 8: scripts/cluster.py
import torch
import numpy as np
from sklearn.cluster import KMeans
from utils.data_loader import UserFeedbackDataset

def update_latent_clusters(*, n_clusters=10, dataset=None,
                           output_path='storage/models/latent_prior.pt',
                           random_state=None):
    """Cluster feedback latent vectors with k-means and save the centers.

    Element ``0`` of every dataset item is treated as latent data. Each such
    element is promoted to at least 2-D (a 1-D vector becomes a single row),
    all rows are concatenated, clustered with ``KMeans``, and the resulting
    cluster centers are saved with ``torch.save`` as a float32 tensor.

    Args:
        n_clusters: Number of k-means clusters (centers) to fit.
        dataset: Iterable of indexable items whose element ``0`` is a tensor
            (or array-like) of latent data. Defaults to a fresh
            ``UserFeedbackDataset()``.
        output_path: File the centers tensor is written to; missing parent
            directories are created.
        random_state: Forwarded to ``KMeans``; ``None`` keeps the unseeded
            (non-reproducible) behavior, an int makes clustering repeatable.

    Returns:
        The saved cluster-center tensor (float32).

    Raises:
        ValueError: If the dataset yields no latent data, or fewer latent
            rows than ``n_clusters``.
    """
    from pathlib import Path

    if dataset is None:
        dataset = UserFeedbackDataset()

    # atleast_2d turns a per-item 1-D latent vector into one row; items that
    # are already 2-D batches are concatenated along dim 0 exactly as before.
    rows = [torch.atleast_2d(torch.as_tensor(item[0])) for item in dataset]
    if not rows:
        raise ValueError("dataset yielded no latent vectors to cluster")

    # detach/cpu so .numpy() also works for tensors that require grad or are on GPU.
    latents = torch.cat(rows).detach().cpu().numpy()
    if latents.shape[0] < n_clusters:
        raise ValueError(
            f"need at least {n_clusters} latent vectors for {n_clusters} "
            f"clusters, got {latents.shape[0]}"
        )

    # Only the fitted centers are used; per-sample labels are not needed.
    kmeans = KMeans(n_clusters=n_clusters, random_state=random_state)
    kmeans.fit(latents)

    cluster_centers = torch.tensor(kmeans.cluster_centers_, dtype=torch.float32)

    # Save the centers as the new latent prior, creating the directory if needed.
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    torch.save(cluster_centers, output_path)
    return cluster_centers