import math,pdb |
import torch,pynvml |
from torch.nn.functional import normalize |
from time import time |
import numpy as np |
def _kpp(data: torch.Tensor, k: int, sample_size: int = -1) -> torch.Tensor:
    """Pick ``k`` initial centroids from ``data`` with k-means++ seeding.

    The first centroid is drawn uniformly at random. Every following one is
    drawn with probability proportional to the squared Euclidean distance
    from each point to its nearest already-chosen centroid.

    Parameters
    ----------
    data : torch.Tensor
        Rank 1 or rank 2 tensor. Rank 1 is treated as ``n`` observations of
        a single feature; for rank 2, each row is one observation.
    k : int
        Number of centroids to pick.
    sample_size : int, default -1
        If positive and smaller than the number of rows, seeding runs on a
        random subsample (drawn with replacement) of this many rows, to bound
        memory use. Non-positive values use every row.

    Returns
    -------
    init : torch.Tensor
        ``(k, n_features)`` tensor of initial centroids, with the same dtype
        and device as ``data``.

    Raises
    ------
    ValueError
        If ``data`` is not rank 1 or 2, or has no rows while ``k > 0``.

    References
    ----------
    .. [1] D. Arthur and S. Vassilvitskii, "k-means++: the advantages of
       careful seeding", Proceedings of the Eighteenth Annual ACM-SIAM Symposium
       on Discrete Algorithms, 2007.
    .. [2] scipy/cluster/vq.py: _kpp
    """
    if data.ndim == 1:
        # Rank-1 input: n observations of one feature.
        data = data[:, None]
    if data.ndim != 2:
        raise ValueError(f"data must be rank 1 or 2, got rank {data.ndim}")
    device = data.device
    n_points = data.shape[0]
    # Allocate with data's dtype so downstream matmuls against data do not mix dtypes.
    init = torch.zeros((k, data.shape[1]), dtype=data.dtype, device=device)
    if k == 0:
        return init
    if n_points == 0:
        raise ValueError("cannot pick centroids from empty data")
    if 0 < sample_size < n_points:
        data = data[torch.randint(0, n_points, (sample_size,), device=device)]
        n_points = sample_size

    # Accumulate distances in at least float32 so half inputs do not overflow or lose the cumsum.
    acc_dtype = torch.promote_types(data.dtype, torch.float32)
    init[:1] = data[torch.randint(0, n_points, (1,), device=device)]
    # Squared distance from every point to its nearest chosen centroid, updated incrementally
    # (O(k * n) overall instead of recomputing against all chosen centroids each step).
    min_d2 = torch.cdist(init[:1], data)[0].to(acc_dtype).pow(2)
    for i in range(1, k):
        total = min_d2.sum()
        # If every point coincides with a chosen centroid, fall back to uniform weights
        # instead of dividing by zero.
        weights = torch.where(total > 0, min_d2, torch.ones_like(min_d2))
        cum = torch.cumsum(weights, dim=0)
        u = torch.rand(1, dtype=acc_dtype, device=device) * cum[-1]
        # right=True picks the first index whose cumulative weight exceeds u, so zero-weight
        # points are never chosen. The clamp guards against float rounding at the upper end.
        idx = torch.searchsorted(cum, u, right=True).clamp_(max=n_points - 1)
        init[i : i + 1] = data[idx]
        new_d2 = torch.cdist(init[i : i + 1], data)[0].to(acc_dtype).pow(2)
        min_d2 = torch.minimum(min_d2, new_d2)
    return init
class KMeansGPU:
    '''
    K-means clustering implemented with PyTorch, sized from free GPU memory.

    Centroids are seeded with k-means++ on a random subsample and then refined
    with mini-batch k-means. The mini-batch size is derived from the free
    memory NVML reports for ``device`` when the object is constructed.

    Parameters:
        n_clusters: int
            Number of clusters.
        max_iter: int, default: 200
            Maximum number of update iterations.
        tol: float, default: 1e-4
            Stop once an iteration's summed squared distance between the batch
            means and the current centroids is at or below this value.
        verbose: int, default: 0
            1 prints a summary after fitting; 2 also prints every iteration.
        mode: {'euclidean', 'cosine'}, default: 'euclidean'
            Similarity used to assign points to centroids.
        device: torch.device or str, default: torch.device("cuda:0")
            CUDA device the computation runs on.

    Attributes:
        centroids: torch.Tensor, shape: [n_clusters, n_features]
            Cluster centroids (set by ``fit_predict``).
        minibatch: int
            Number of rows processed per iteration (at least 1).
    '''

    def __init__(self, n_clusters, max_iter=200, tol=1e-4, verbose=0, mode="euclidean", device=torch.device("cuda:0")):
        self.n_clusters = n_clusters
        self.max_iter = max_iter
        self.tol = tol
        self.verbose = verbose
        self.mode = mode
        self.device = torch.device(device)
        if self.device.type != "cuda":
            raise ValueError(f"KMeansGPU requires a CUDA device, got {self.device}")
        # torch.device("cuda") has no index; it refers to the current CUDA device.
        gpu_index = self.device.index if self.device.index is not None else torch.cuda.current_device()
        pynvml.nvmlInit()
        try:
            # NOTE(review): NVML enumerates physical GPUs, so under CUDA_VISIBLE_DEVICES this
            # index may not match the torch device index — confirm for your deployment.
            gpu_handle = pynvml.nvmlDeviceGetHandleByIndex(gpu_index)
            info = pynvml.nvmlDeviceGetMemoryInfo(gpu_handle)
        finally:
            pynvml.nvmlShutdown()
        # Empirical heuristic: rows per batch scale with free memory (GiB) and inversely with
        # n_clusters, because the similarity and assignment matrices are n_clusters x minibatch.
        # Clamp to 1 so a tiny free-memory reading never yields empty batches.
        self.minibatch = max(1, int(33e6 / self.n_clusters * info.free / 1024 / 1024 / 1024))
        print("free_mem/GB:", info.free / 1024 / 1024 / 1024, "minibatch:", self.minibatch)

    @staticmethod
    def cos_sim(a, b):
        """
        Compute the cosine similarity of 2 sets of vectors.
        Parameters:
            a: torch.Tensor, shape: [m, n_features]
            b: torch.Tensor, shape: [n, n_features]
        Return:
            torch.Tensor, shape: [m, n]
        """
        return normalize(a, dim=-1) @ normalize(b, dim=-1).transpose(-2, -1)

    @staticmethod
    def euc_sim(a, b):
        """
        Compute the Euclidean similarity (negative squared distance) of 2 sets of vectors.
        Parameters:
            a: torch.Tensor, shape: [m, n_features]
            b: torch.Tensor, shape: [n, n_features]
        Return:
            torch.Tensor, shape: [m, n], entry (i, j) = -||a_i - b_j||^2
        """
        return 2 * a @ b.transpose(-2, -1) - (a ** 2).sum(dim=-1)[..., :, None] - (b ** 2).sum(dim=-1)[..., None, :]

    def max_sim(self, a, b):
        """
        Compute the maximum similarity (minimum distance) of each vector in a
        to all of the vectors in b.
        Parameters:
            a: torch.Tensor, shape: [m, n_features]
            b: torch.Tensor, shape: [n, n_features]
        Return:
            (values, indices): the best similarity and the index into b of the
            most similar vector, both of shape [m].
        Raises:
            ValueError: if self.mode is neither 'cosine' nor 'euclidean'.
        """
        if self.mode == 'cosine':
            sim = self.cos_sim(a, b)
        elif self.mode == 'euclidean':
            sim = self.euc_sim(a, b)
        else:
            raise ValueError(f"unknown mode {self.mode!r}; expected 'euclidean' or 'cosine'")
        max_sim_v, max_sim_i = sim.max(dim=-1)
        return max_sim_v, max_sim_i

    def _update_centroids(self, x, num_points_in_clusters):
        """Run one mini-batch k-means step on ``x`` and return the step error.

        ``num_points_in_clusters`` is incremented in place by the number of rows
        of ``x`` assigned to each cluster.
        """
        labels = self.max_sim(a=x, b=self.centroids)[1]
        counts = torch.bincount(labels, minlength=self.n_clusters)
        cluster_ids = torch.arange(self.n_clusters, device=labels.device)
        # One-hot assignment matrix (n_clusters x rows); mask @ x sums the rows of each cluster.
        mask = (labels[None, :] == cluster_ids[:, None]).to(x.dtype)
        batch_means = (mask @ x) / counts.clamp(min=1).to(x.dtype)[:, None]
        # Clusters with no rows in this batch keep their centroid rather than being pulled to 0.
        c_grad = torch.where((counts > 0)[:, None], batch_means, self.centroids)
        error = (c_grad - self.centroids).pow(2).sum()
        # Per-cluster learning rate: 1 for a fresh cluster, decaying toward 0.1 as it gathers points.
        lr = (1 / num_points_in_clusters[:, None] * 0.9 + 0.1).to(x.dtype)
        num_points_in_clusters += counts
        self.centroids = self.centroids * (1 - lr) + c_grad * lr
        return error

    def _predict(self, X):
        """Return the index of the closest centroid for every row of ``X``.

        Rows are processed ``self.minibatch`` at a time to bound memory use.
        """
        chunk = max(1, self.minibatch)
        parts = [
            self.max_sim(a=X[start:start + chunk].to(self.device), b=self.centroids)[1]
            for start in range(0, X.shape[0], chunk)
        ]
        if not parts:
            return torch.empty(0, dtype=torch.long, device=self.device)
        return torch.cat(parts)

    def fit_predict(self, X):
        """
        Fit the centroids on X and return the cluster label of every row.

        If ``self.minibatch`` is smaller than half the number of rows, each
        iteration draws a fresh random mini-batch. If it is smaller than the
        number of rows (but at least half), one fixed random subsample (drawn
        with replacement) is used for every iteration. Otherwise all rows are
        used. After fitting, every row of X is labelled against the final
        centroids.

        Parameters:
            X: torch.Tensor, shape: [n_samples, n_features]
                Floating-point data, on the CPU or on ``self.device``.
        Return:
            labels: torch.Tensor (int64), shape: [n_samples]
                Index of the closest final centroid for each row of X.
        """
        assert isinstance(X, torch.Tensor), "input must be torch.Tensor"
        assert X.dtype in [torch.half, torch.float, torch.double], "input must be floating point"
        assert X.ndim == 2, "input must be a 2d tensor with shape: [n_samples, n_features] "
        # Heuristic that shrinks the k-means++ sample sizes as n_clusters grows.
        # NOTE(review): an earlier note described this as log2(k/1000)*1.5, but the
        # expression computes 1.5**ln(k/1000) / ln(2) — confirm which is intended.
        offset = np.power(1.5, np.log(self.n_clusters / 1000)) / np.log(2)
        with torch.no_grad():
            batch_size = X.shape[0]
            start_time = time()
            # Seed with k-means++ on a random subsample when X is larger than the seeding budget.
            if self.minibatch * 10 // offset < batch_size:
                seed_idx = torch.randint(0, batch_size, [int(self.minibatch * 10 / offset)], device=X.device)
                x = X[seed_idx].to(self.device)
            else:
                x = X.to(self.device)
            seeds = _kpp(x, self.n_clusters, min(int(self.minibatch / 12 / offset), batch_size))
            # Match X's dtype so the similarity matmuls never mix dtypes.
            self.centroids = seeds.to(device=self.device, dtype=X.dtype)
            del x, seeds
            torch.cuda.empty_cache()
            # Counts are kept in at least float32 so half inputs do not overflow them.
            count_dtype = torch.promote_types(X.dtype, torch.float32)
            num_points_in_clusters = torch.ones(self.n_clusters, device=self.device, dtype=count_dtype)

            sample_each_iter = self.minibatch < batch_size // 2
            if sample_each_iter:
                data = X
            elif self.minibatch < batch_size:
                data = X[torch.randint(0, batch_size, [self.minibatch], device=X.device)].to(self.device)
            else:
                data = X.to(self.device)

            n_iter = 0
            for i in range(self.max_iter):
                iter_time = time()
                if sample_each_iter:
                    batch_idx = torch.randint(0, batch_size, [self.minibatch], device=data.device)
                    x = data[batch_idx].to(self.device)
                else:
                    x = data
                error = self._update_centroids(x, num_points_in_clusters)
                n_iter = i + 1
                if self.verbose >= 2:
                    print('iter:', i, 'error:', error.item(), 'time spent:', round(time()-iter_time, 4))
                if error <= self.tol:
                    break

            labels = self._predict(X)
            if self.verbose >= 1:
                print(f'used {n_iter} iterations ({round(time()-start_time, 4)}s) to cluster {batch_size} items into {self.n_clusters} clusters')
        return labels