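# File: disco/models/clusterkit.py
# Repository-page metadata captured with this file (not part of the code):
#   uploader avatar caption: "menghanxia's picture"
#   latest commit: "fixed the randomness issue" (2a10b61)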
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
import numpy as np
import torch
from tqdm import tqdm
import math, random
#from sklearn.cluster import KMeans, kmeans_plusplus, MeanShift, estimate_bandwidth
def tensor_kmeans_sklearn(data_vecs, n_clusters=7, metric='euclidean', need_layer_masks=False, max_iters=20):
    """
    cluster the pixels of a single feature map with scikit-learn's KMeans
    :param data_vecs: (torch.tensor) feature map of shape (1,C,H,W)
    :param n_clusters: (int) number of clusters
    :param metric: (str) not used by this backend; kept for signature parity with tensor_kmeans_pytorch
    :param need_layer_masks: (bool) return one-hot layer masks instead of an id map
    :param max_iters: (int) not used by this backend; KMeans is run with max_iter=300
    :return: (torch.tensor) long id map of shape (1,1,H,W), or float one-hot masks
             of shape (1,n_clusters,H,W) when need_layer_masks is True
    """
    ## imported lazily: the module-level sklearn import is commented out, so
    ## sklearn is only required when this backend is actually selected
    from sklearn.cluster import KMeans

    N,C,H,W = data_vecs.shape
    assert N == 1, 'only support singe image tensor'
    ## (1,C,H,W) -> (HW,C); reshape (not view) so any memory layout is accepted
    data_vecs = data_vecs.permute(0,2,3,1).reshape(-1,C)
    ## convert tensor to array; keep it 2-D (no squeeze) even when C == 1 or HW == 1
    data_vecs_np = data_vecs.detach().to("cpu").numpy()
    km = KMeans(n_clusters=n_clusters, init='k-means++', n_init=10, max_iter=300)
    labels = km.fit_predict(data_vecs_np)
    cluster_ids_x = torch.from_numpy(labels).to(data_vecs.device)
    id_maps = cluster_ids_x.reshape(1,1,H,W).long()
    if need_layer_masks:
        ## (1,H,W) -> (1,H,W,K) -> (1,K,H,W)
        one_hot_labels = F.one_hot(id_maps.squeeze(1), num_classes=n_clusters).float()
        cluster_mask = one_hot_labels.permute(0,3,1,2)
        return cluster_mask
    return id_maps
def tensor_kmeans_pytorch(data_vecs, n_clusters=7, metric='euclidean', need_layer_masks=False, max_iters=20):
    """
    cluster the pixels of a single feature map with the torch kmeans in this file
    :param data_vecs: (torch.tensor) feature map of shape (1,C,H,W)
    :param n_clusters: (int) number of clusters
    :param metric: (str) distance passed to kmeans [options: 'euclidean', 'cosine']
    :param need_layer_masks: (bool) return one-hot layer masks instead of an id map
    :param max_iters: (int) iteration limit passed to kmeans
    :return: (torch.tensor) id map of shape (1,1,H,W), or float one-hot masks
             of shape (1,n_clusters,H,W) when need_layer_masks is True
    """
    N,C,H,W = data_vecs.shape
    assert N == 1, 'only support singe image tensor'
    ## (1,C,H,W) -> (HW,C); reshape (not view) so any memory layout is accepted
    data_vecs = data_vecs.permute(0,2,3,1).reshape(-1,C)
    ## cosine | euclidean
    cluster_ids_x, _ = kmeans(X=data_vecs, num_clusters=n_clusters, distance=metric,
                              tqdm_flag=False, iter_limit=max_iters, device=data_vecs.device)
    id_maps = cluster_ids_x.reshape(1,1,H,W)
    if need_layer_masks:
        ## (1,H,W) -> (1,H,W,K) -> (1,K,H,W)
        one_hot_labels = F.one_hot(id_maps.squeeze(1), num_classes=n_clusters).float()
        cluster_mask = one_hot_labels.permute(0,3,1,2)
        return cluster_mask
    return id_maps
def batch_kmeans_pytorch(data_vecs, n_clusters=7, metric='euclidean', use_sklearn_kmeans=False):
    """
    run k-means independently on every sample of a batch and stack the one-hot masks
    :param data_vecs: (torch.tensor) feature maps of shape (N,C,H,W)
    :param n_clusters: (int) number of clusters per sample
    :param metric: (str) distance forwarded to the per-sample clustering function
    :param use_sklearn_kmeans: (bool) use tensor_kmeans_sklearn instead of tensor_kmeans_pytorch
    :return: (torch.tensor) per-sample layer masks concatenated along dim 0
    """
    batch_size, _, _, _ = data_vecs.shape
    ## pick the backend once; both share the same positional signature
    cluster_one = tensor_kmeans_sklearn if use_sklearn_kmeans else tensor_kmeans_pytorch
    ## slice with i:i+1 so each sample keeps its leading batch dimension
    per_sample_masks = [
        cluster_one(data_vecs[i:i+1], n_clusters, metric, True)
        for i in range(batch_size)
    ]
    return torch.cat(per_sample_masks, dim=0)
def get_centroid_candidates(data_vecs, n_clusters=7, metric='euclidean', max_iters=20):
    """
    run k-means over all pixels of the given feature maps and return the cluster centers
    :param data_vecs: (torch.tensor) feature maps of shape (N,C,H,W); the pixels of
                      all N samples are pooled into one (NHW,C) matrix
    :param n_clusters: (int) number of clusters
    :param metric: (str) distance passed to kmeans [options: 'euclidean', 'cosine']
    :param max_iters: (int) iteration limit passed to kmeans
    :return: (torch.tensor) cluster centers as returned by kmeans
    """
    N,C,H,W = data_vecs.shape
    ## (N,C,H,W) -> (NHW,C); the permuted tensor cannot be viewed as (NHW,C)
    ## when N > 1 and C > 1, so reshape is required here
    data_vecs = data_vecs.permute(0,2,3,1).reshape(-1,C)
    _, cluster_centers = kmeans(X=data_vecs, num_clusters=n_clusters, distance=metric,
                                tqdm_flag=False, iter_limit=max_iters, device=data_vecs.device)
    return cluster_centers
def find_distinctive_elements(data_tensor, n_clusters=7, topk=3, metric='euclidean'):
    """
    for every sample, mark the topk pixels closest to each k-means centroid
    :param data_tensor: (torch.tensor) feature maps of shape (N,C,H,W)
    :param n_clusters: (int) number of centroids computed per sample
    :param topk: (int) number of nearest pixels kept per centroid (ties at the
                 k-th distance are all kept)
    :param metric: (str) distance used by k-means to find the centroids; the
                   pixel-to-centroid ranking below always uses squared euclidean distance
    :return: (torch.tensor) 0/1 mask of shape (N,n_clusters,H,W) in the dtype of the distance map
    """
    N,C,H,W = data_tensor.shape
    ## per-sample centroids stacked to (N,K,C)
    batch_centroids = torch.stack([
        get_centroid_candidates(data_tensor[idx:idx+1,:,:,:], n_clusters, metric)
        for idx in range(N)
    ], dim=0)
    ## (N,C,H,W) -> (N,C,HW)
    data_vecs = data_tensor.flatten(2)
    ## distance matrix: (N,K,HW) = |a|^2 - 2 a.b + |b|^2
    ## cross term: (N,K,HW) = (N,K,C) x (N,C,HW)
    AtB = torch.matmul(batch_centroids, data_vecs)
    ## squared norms computed directly; building the (N,HW,HW) Gram matrix just
    ## to read its diagonal would need memory quadratic in the pixel count
    A2 = (batch_centroids ** 2).sum(dim=2, keepdim=True)   # (N,K,1)
    B2 = (data_vecs ** 2).sum(dim=1, keepdim=True)         # (N,1,HW)
    distance_map = A2 - 2*AtB + B2
    ## the k-th smallest distance per centroid is the inclusion threshold: (N,K,1)
    values, _ = distance_map.topk(topk, dim=2, largest=False, sorted=True)
    kth_smallest = values[:,:,topk-1:]
    cluster_mask = (distance_map <= kth_smallest).to(distance_map.dtype)
    return cluster_mask.reshape(N,n_clusters,H,W)
##---------------------------------------------------------------------------------
'''
The k-means code below is adapted from the kmeans_pytorch project on GitHub: https://github.com/subhadarship/kmeans_pytorch
'''
##---------------------------------------------------------------------------------
def initialize(X, num_clusters):
    """
    initialize cluster centers by picking num_clusters distinct rows of X
    :param X: (torch.tensor) matrix
    :param num_clusters: (int) number of clusters
    :return: the selected rows of X (result of indexing X with the chosen row indices)
    """
    ## fixed seed keeps the initialization reproducible; a private generator is
    ## used so the process-wide numpy random state is left untouched while the
    ## chosen indices stay identical to np.random.seed(1) + np.random.choice
    rng = np.random.RandomState(1)
    num_samples = len(X)
    indices = rng.choice(num_samples, num_clusters, replace=False)
    initial_state = X[indices]
    return initial_state
def kmeans(
        X,
        num_clusters,
        distance='euclidean',
        cluster_centers=None,
        tol=1e-4,
        tqdm_flag=True,
        iter_limit=0,
        device=torch.device('cpu'),
        gamma_for_soft_dtw=0.001
):
    """
    perform kmeans
    :param X: (torch.tensor) matrix
    :param num_clusters: (int) number of clusters
    :param distance: (str) distance [options: 'euclidean', 'cosine'] [default: 'euclidean']
    :param cluster_centers: None or a list (of any content) starts from a fresh
                            initialization; any other value is treated as centers to
                            resume from, each snapped to its nearest data point
    :param tol: (float) threshold [default: 0.0001]
    :param device: (torch.device) device [default: cpu]
    :param tqdm_flag: Allows to turn logs on and off
    :param iter_limit: hard limit for max number of iterations (0 means no limit)
    :param gamma_for_soft_dtw: not used by this function
    :return: (torch.tensor, torch.tensor) cluster ids, cluster centers
    :raises NotImplementedError: if distance is not one of the supported options
    """
    if tqdm_flag:
        print(f'running k-means on {device}..')

    if distance == 'euclidean':
        pairwise_distance_function = partial(pairwise_distance, device=device, tqdm_flag=tqdm_flag)
    elif distance == 'cosine':
        pairwise_distance_function = partial(pairwise_cosine, device=device)
    else:
        raise NotImplementedError

    # convert to float
    X = X.float()

    # transfer to device
    X = X.to(device)

    # initialize
    if cluster_centers is None or isinstance(cluster_centers, list):
        initial_state = initialize(X, num_clusters)
    else:
        if tqdm_flag:
            print('resuming')
        # find data point closest to the initial cluster center
        dis = pairwise_distance_function(X, cluster_centers)
        choice_points = torch.argmin(dis, dim=0)
        initial_state = X[choice_points]
        initial_state = initial_state.to(device)

    iteration = 0
    tqdm_meter = tqdm(desc='[running kmeans]') if tqdm_flag else None
    try:
        while True:
            # assign every point to its nearest center
            dis = pairwise_distance_function(X, initial_state)
            choice_cluster = torch.argmin(dis, dim=1)

            initial_state_pre = initial_state.clone()

            for index in range(num_clusters):
                # boolean mask keeps the selection 2-D even for a single member
                selected = X[choice_cluster == index]

                # https://github.com/subhadarship/kmeans_pytorch/issues/16
                # an empty cluster is re-seeded with one randomly drawn data point
                if selected.shape[0] == 0:
                    selected = X[torch.randint(len(X), (1,))]

                initial_state[index] = selected.mean(dim=0)

            # total distance the centers moved in this iteration
            center_shift = torch.sum(
                torch.sqrt(
                    torch.sum((initial_state - initial_state_pre) ** 2, dim=1)
                ))

            # increment iteration
            iteration = iteration + 1

            # update tqdm meter
            if tqdm_meter is not None:
                tqdm_meter.set_postfix(
                    iteration=f'{iteration}',
                    center_shift=f'{center_shift ** 2:0.6f}',
                    tol=f'{tol:0.6f}'
                )
                tqdm_meter.update()
            if center_shift ** 2 < tol:
                break
            if iter_limit != 0 and iteration >= iter_limit:
                break
    finally:
        # release the progress meter even if an iteration raised
        if tqdm_meter is not None:
            tqdm_meter.close()

    return choice_cluster.to(device), initial_state.to(device)
def kmeans_predict(
        X,
        cluster_centers,
        distance='euclidean',
        device=torch.device('cpu'),
        gamma_for_soft_dtw=0.001,
        tqdm_flag=True
):
    """
    assign every row of X to its nearest cluster center
    :param X: (torch.tensor) matrix
    :param cluster_centers: (torch.tensor) cluster centers
    :param distance: (str) distance [options: 'euclidean', 'cosine', 'soft_dtw'] [default: 'euclidean']
    :param device: (torch.device) device [default: 'cpu']
    :param gamma_for_soft_dtw: gamma passed to SoftDTW; only read for distance='soft_dtw'
    :param tqdm_flag: Allows to turn logs on and off
    :return: (torch.tensor) cluster ids, moved to the cpu
    :raises NotImplementedError: if distance is not one of the options above
    """
    if tqdm_flag:
        print(f'predicting on {device}..')

    # pick the distance routine and bind its fixed arguments
    if distance == 'soft_dtw':
        # NOTE(review): SoftDTW and pairwise_soft_dtw are not defined in the
        # visible part of this file -- confirm they exist before using this option
        sdtw = SoftDTW(use_cuda=device.type == 'cuda', gamma=gamma_for_soft_dtw)
        distance_fn = partial(pairwise_soft_dtw, sdtw=sdtw, device=device)
    elif distance == 'cosine':
        distance_fn = partial(pairwise_cosine, device=device)
    elif distance == 'euclidean':
        distance_fn = partial(pairwise_distance, device=device, tqdm_flag=tqdm_flag)
    else:
        raise NotImplementedError

    # float copy of the data on the target device
    points = X.float().to(device)

    # nearest center per row
    nearest = torch.argmin(distance_fn(points, cluster_centers), dim=1)
    return nearest.cpu()
def pairwise_distance(data1, data2, device=torch.device('cpu'), tqdm_flag=True):
    """
    squared euclidean distance between every row of data1 and every row of data2
    :param data1: (torch.tensor) matrix of shape (N1,M)
    :param data2: (torch.tensor) matrix of shape (N2,M)
    :param device: (torch.device) device both inputs are moved to [default: cpu]
    :param tqdm_flag: print the device when True
    :return: (torch.tensor) distance matrix of shape (N1,N2)
    """
    if tqdm_flag:
        print(f'device is :{device}')

    # transfer to device
    data1, data2 = data1.to(device), data2.to(device)

    # N1*1*M
    A = data1.unsqueeze(dim=1)

    # 1*N2*M
    B = data2.unsqueeze(dim=0)

    dis = (A - B) ** 2.0
    # return N1*N2 matrix for pairwise distance; no squeeze, so the result stays
    # 2-D when either input has a single row and argmin over dim 0/1 keeps working
    dis = dis.sum(dim=-1)
    return dis
def pairwise_cosine(data1, data2, device=torch.device('cpu')):
    """
    cosine distance (1 - cosine similarity) between every row of data1 and every row of data2
    :param data1: (torch.tensor) matrix of shape (N1,M)
    :param data2: (torch.tensor) matrix of shape (N2,M)
    :param device: (torch.device) device both inputs are moved to [default: cpu]
    :return: (torch.tensor) distance matrix of shape (N1,N2)
    """
    # transfer to device
    data1, data2 = data1.to(device), data2.to(device)

    # N1*1*M
    A = data1.unsqueeze(dim=1)

    # 1*N2*M
    B = data2.unsqueeze(dim=0)

    # normalize the points  | [0.3, 0.4] -> [0.3/sqrt(0.09 + 0.16), 0.4/sqrt(0.09 + 0.16)] = [0.3/0.5, 0.4/0.5]
    # (a row with zero norm is divided by zero here)
    A_normalized = A / A.norm(dim=-1, keepdim=True)
    B_normalized = B / B.norm(dim=-1, keepdim=True)

    cosine = A_normalized * B_normalized

    # return N1*N2 matrix for pairwise distance; no squeeze, so the result stays
    # 2-D when either input has a single row
    cosine_dis = 1 - cosine.sum(dim=-1)
    return cosine_dis