# Feature-tensor k-means utilities: sklearn and pure-PyTorch clustering backends.
import math
import random
from functools import partial

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

#from sklearn.cluster import KMeans, kmeans_plusplus, MeanShift, estimate_bandwidth
def tensor_kmeans_sklearn(data_vecs, n_clusters=7, metric='euclidean', need_layer_masks=False, max_iters=20):
    """
    cluster the pixels of a single (1,C,H,W) feature tensor with sklearn KMeans
    :param data_vecs: (torch.tensor) feature map of shape (1, C, H, W)
    :param n_clusters: (int) number of clusters
    :param metric: unused by the sklearn backend (always euclidean); kept for API parity
    :param need_layer_masks: (bool) if True return one-hot masks (1, K, H, W),
        otherwise an integer id map (1, 1, H, W)
    :param max_iters: unused by the sklearn backend; kept for API parity
    :return: (torch.tensor) cluster masks or id map, see need_layer_masks
    """
    # BUGFIX: the module-level sklearn import is commented out, so KMeans was an
    # unresolved name (NameError on every call). Import lazily here so the rest
    # of the module stays importable without sklearn installed.
    from sklearn.cluster import KMeans
    N,C,H,W = data_vecs.shape
    assert N == 1, 'only support single image tensor'
    ## (1,C,H,W) -> (HW,C)
    data_vecs = data_vecs.permute(0,2,3,1).view(-1,C)
    ## sklearn runs on cpu numpy arrays; detach in case the tensor requires grad
    data_vecs_np = data_vecs.squeeze().detach().to("cpu").numpy()
    km = KMeans(n_clusters=n_clusters, init='k-means++', n_init=10, max_iter=300)
    km.fit_predict(data_vecs_np)
    ## bring the labels back onto the input tensor's device
    cluster_ids_x = torch.from_numpy(km.labels_).to(data_vecs.device)
    id_maps = cluster_ids_x.reshape(1,1,H,W).long()
    if need_layer_masks:
        one_hot_labels = F.one_hot(id_maps.squeeze(1), num_classes=n_clusters).float()
        cluster_mask = one_hot_labels.permute(0,3,1,2)
        return cluster_mask
    return id_maps
def tensor_kmeans_pytorch(data_vecs, n_clusters=7, metric='euclidean', need_layer_masks=False, max_iters=20):
    """
    cluster the pixels of a single (1,C,H,W) feature tensor with the
    pure-pytorch kmeans defined in this module
    :param data_vecs: (torch.tensor) feature map of shape (1, C, H, W)
    :param n_clusters: (int) number of clusters
    :param metric: (str) 'euclidean' or 'cosine', forwarded to kmeans
    :param need_layer_masks: (bool) if True return one-hot masks (1, K, H, W),
        otherwise an integer id map (1, 1, H, W)
    :param max_iters: (int) hard iteration limit for kmeans
    :return: (torch.tensor) cluster masks or id map, see need_layer_masks
    """
    batch, channels, height, width = data_vecs.shape
    assert batch == 1, 'only support singe image tensor'
    ## flatten spatial grid: (1,C,H,W) -> (H*W, C)
    pixel_feats = data_vecs.permute(0, 2, 3, 1).view(-1, channels)
    ## metric: cosine | euclidean
    cluster_ids_x, cluster_centers = kmeans(
        X=pixel_feats, num_clusters=n_clusters, distance=metric,
        tqdm_flag=False, iter_limit=max_iters, device=pixel_feats.device)
    id_maps = cluster_ids_x.reshape(1, 1, height, width)
    if not need_layer_masks:
        return id_maps
    layer_masks = F.one_hot(id_maps.squeeze(1), num_classes=n_clusters).float()
    return layer_masks.permute(0, 3, 1, 2)
def batch_kmeans_pytorch(data_vecs, n_clusters=7, metric='euclidean', use_sklearn_kmeans=False):
    """
    cluster every sample of a (N,C,H,W) batch independently
    :param data_vecs: (torch.tensor) batch of feature maps, shape (N, C, H, W)
    :param n_clusters: (int) number of clusters per sample
    :param metric: (str) distance metric forwarded to the per-sample backend
    :param use_sklearn_kmeans: (bool) choose the sklearn backend over pytorch
    :return: (torch.tensor) stacked one-hot cluster masks, shape (N, K, H, W)
    """
    num_samples = data_vecs.shape[0]
    cluster_fn = tensor_kmeans_sklearn if use_sklearn_kmeans else tensor_kmeans_pytorch
    ## one sample at a time; each call returns a (1, K, H, W) mask
    masks = [cluster_fn(data_vecs[i:i + 1], n_clusters, metric, True)
             for i in range(num_samples)]
    return torch.cat(masks, dim=0)
def get_centroid_candidates(data_vecs, n_clusters=7, metric='euclidean', max_iters=20):
    """
    run pytorch kmeans over all pixels of a (N,C,H,W) tensor and return
    only the resulting (K, C) cluster centers
    """
    channels = data_vecs.shape[1]
    ## (N,C,H,W) -> (N*H*W, C)
    flat_feats = data_vecs.permute(0, 2, 3, 1).view(-1, channels)
    _, centers = kmeans(X=flat_feats, num_clusters=n_clusters, distance=metric,
                        tqdm_flag=False, iter_limit=max_iters, device=flat_feats.device)
    return centers
def find_distinctive_elements(data_tensor, n_clusters=7, topk=3, metric='euclidean'):
    """
    for each sample, mark the topk pixels closest to each kmeans centroid
    :param data_tensor: (torch.tensor) features of shape (N, C, H, W)
    :param n_clusters: (int) number of centroids K per sample
    :param topk: (int) how many closest pixels to keep per centroid
    :param metric: (str) distance forwarded to kmeans when finding centroids
    :return: (torch.tensor) binary masks of shape (N, K, H, W)
    """
    N,C,H,W = data_tensor.shape
    centroid_list = []
    for idx in range(N):
        cluster_centers = get_centroid_candidates(data_tensor[idx:idx+1,:,:,:], n_clusters, metric)
        centroid_list.append(cluster_centers)
    batch_centroids = torch.stack(centroid_list, dim=0)  # (N, K, C)
    data_vecs = data_tensor.flatten(2)                   # (N, C, HW)
    ## squared euclidean distance matrix: (N,K,HW) = |a|^2 - 2*a.b + |b|^2
    AtB = torch.matmul(batch_centroids, data_vecs)
    # PERF: the original materialized full (K,K) and (HW,HW) Gram matrices only
    # to read their diagonals -- O((HW)^2) memory. Compute the squared norms
    # directly and let broadcasting expand them.
    A2 = batch_centroids.pow(2).sum(dim=2, keepdim=True)  # (N, K, 1)
    B2 = data_vecs.pow(2).sum(dim=1, keepdim=True)        # (N, 1, HW)
    distance_map = A2 - 2*AtB + B2                        # broadcasts to (N, K, HW)
    values, _ = distance_map.topk(topk, dim=2, largest=False, sorted=True)
    ## keep every pixel whose distance is within the topk-th smallest per centroid
    cluster_mask = torch.where(distance_map <= values[:,:,topk-1:],
                               torch.ones_like(distance_map),
                               torch.zeros_like(distance_map))
    return cluster_mask.view(N, n_clusters, H, W)
##---------------------------------------------------------------------------------
'''
resource from github: https://github.com/subhadarship/kmeans_pytorch
'''
##---------------------------------------------------------------------------------
def initialize(X, num_clusters):
    """
    initialize cluster centers by sampling rows of X without replacement
    :param X: (torch.tensor) matrix of shape (num_samples, dim)
    :param num_clusters: (int) number of clusters
    :return: (torch.tensor) initial centers, shape (num_clusters, dim)
    """
    # NOTE: seeds numpy's *global* RNG, so the choice is reproducible across calls
    np.random.seed(1)
    chosen = np.random.choice(len(X), num_clusters, replace=False)
    return X[chosen]
def kmeans(
        X,
        num_clusters,
        distance='euclidean',
        cluster_centers=[],
        tol=1e-4,
        tqdm_flag=True,
        iter_limit=0,
        device=torch.device('cpu'),
        gamma_for_soft_dtw=0.001
):
    """
    perform kmeans
    :param X: (torch.tensor) matrix of shape (num_samples, dim)
    :param num_clusters: (int) number of clusters
    :param distance: (str) distance [options: 'euclidean', 'cosine'] [default: 'euclidean']
    :param cluster_centers: optional warm-start centers; an empty list means random init
    :param tol: (float) convergence threshold on squared center shift [default: 0.0001]
    :param tqdm_flag: Allows to turn logs on and off
    :param iter_limit: hard limit for max number of iterations (0 = unlimited)
    :param device: (torch.device) device [default: cpu]
    :param gamma_for_soft_dtw: unused here; kept for API parity with kmeans_predict
    :return: (torch.tensor, torch.tensor) cluster ids, cluster centers
    """
    if tqdm_flag:
        print(f'running k-means on {device}..')

    if distance == 'euclidean':
        pairwise_distance_function = partial(pairwise_distance, device=device, tqdm_flag=tqdm_flag)
    elif distance == 'cosine':
        pairwise_distance_function = partial(pairwise_cosine, device=device)
    else:
        raise NotImplementedError

    # convert to float and transfer to device
    X = X.float()
    X = X.to(device)

    # initialize
    if type(cluster_centers) == list:  # ToDo: make this less annoyingly weird
        initial_state = initialize(X, num_clusters)
    else:
        if tqdm_flag:
            print('resuming')
        # snap each provided center to the data point closest to it
        initial_state = cluster_centers
        dis = pairwise_distance_function(X, initial_state)
        choice_points = torch.argmin(dis, dim=0)
        initial_state = X[choice_points]
        initial_state = initial_state.to(device)

    iteration = 0
    if tqdm_flag:
        tqdm_meter = tqdm(desc='[running kmeans]')
    while True:
        # assign every point to its nearest center
        dis = pairwise_distance_function(X, initial_state)
        choice_cluster = torch.argmin(dis, dim=1)
        initial_state_pre = initial_state.clone()
        for index in range(num_clusters):
            # BUGFIX: the original used .squeeze(), which yields a 0-d tensor
            # when exactly one point falls in the cluster and then breaks
            # index_select and the shape[0] check below; .view(-1) always
            # produces a 1-D index tensor.
            selected = torch.nonzero(choice_cluster == index).view(-1).to(device)
            selected = torch.index_select(X, 0, selected)
            # https://github.com/subhadarship/kmeans_pytorch/issues/16
            if selected.shape[0] == 0:
                # empty cluster: re-seed its center with a random data point
                selected = X[torch.randint(len(X), (1,))]
            initial_state[index] = selected.mean(dim=0)
        # total euclidean movement of all centers this iteration
        center_shift = torch.sum(
            torch.sqrt(
                torch.sum((initial_state - initial_state_pre) ** 2, dim=1)
            ))
        # increment iteration
        iteration = iteration + 1
        # update tqdm meter
        if tqdm_flag:
            tqdm_meter.set_postfix(
                iteration=f'{iteration}',
                center_shift=f'{center_shift ** 2:0.6f}',
                tol=f'{tol:0.6f}'
            )
            tqdm_meter.update()
        if center_shift ** 2 < tol:
            break
        if iter_limit != 0 and iteration >= iter_limit:
            break
    if tqdm_flag:
        tqdm_meter.close()
    return choice_cluster.to(device), initial_state.to(device)
def kmeans_predict(
        X,
        cluster_centers,
        distance='euclidean',
        device=torch.device('cpu'),
        gamma_for_soft_dtw=0.001,
        tqdm_flag=True
):
    """
    predict using cluster centers
    :param X: (torch.tensor) matrix of shape (num_samples, dim)
    :param cluster_centers: (torch.tensor) cluster centers, shape (K, dim)
    :param distance: (str) distance [options: 'euclidean', 'cosine'] [default: 'euclidean']
    :param device: (torch.device) device [default: 'cpu']
    :param gamma_for_soft_dtw: approaches to (hard) DTW as gamma -> 0
    :param tqdm_flag: Allows to turn logs on and off
    :return: (torch.tensor) cluster ids on cpu
    """
    if tqdm_flag:
        print(f'predicting on {device}..')

    if distance == 'euclidean':
        pairwise_distance_function = partial(pairwise_distance, device=device, tqdm_flag=tqdm_flag)
    elif distance == 'cosine':
        pairwise_distance_function = partial(pairwise_cosine, device=device)
    elif distance == 'soft_dtw':
        # NOTE(review): SoftDTW and pairwise_soft_dtw are not defined in this
        # module; selecting 'soft_dtw' raises NameError, exactly as before.
        sdtw = SoftDTW(use_cuda=device.type == 'cuda', gamma=gamma_for_soft_dtw)
        pairwise_distance_function = partial(pairwise_soft_dtw, sdtw=sdtw, device=device)
    else:
        raise NotImplementedError

    # convert to float and transfer to the target device
    samples = X.float().to(device)
    dis = pairwise_distance_function(samples, cluster_centers)
    return torch.argmin(dis, dim=1).cpu()
def pairwise_distance(data1, data2, device=torch.device('cpu'), tqdm_flag=True):
    """
    squared euclidean distance between every row of data1 and every row of data2
    :param data1: (torch.tensor) shape (N1, M)
    :param data2: (torch.tensor) shape (N2, M)
    :return: (torch.tensor) distance matrix of shape (N1, N2), singleton dims squeezed
    """
    if tqdm_flag:
        print(f'device is :{device}')
    # transfer to device
    data1, data2 = data1.to(device), data2.to(device)
    # broadcast rows against rows: (N1,1,M) vs (1,N2,M)
    left = data1.unsqueeze(dim=1)
    right = data2.unsqueeze(dim=0)
    squared_diff = (left - right) ** 2.0
    # sum over the feature axis; squeeze drops any singleton dims
    return squared_diff.sum(dim=-1).squeeze()
def pairwise_cosine(data1, data2, device=torch.device('cpu')):
    """
    cosine distance (1 - cosine similarity) between every row of data1 and data2
    :param data1: (torch.tensor) shape (N1, M)
    :param data2: (torch.tensor) shape (N2, M)
    :return: (torch.tensor) distance matrix of shape (N1, N2), singleton dims squeezed
    """
    # transfer to device
    data1, data2 = data1.to(device), data2.to(device)
    # broadcast rows against rows: (N1,1,M) vs (1,N2,M)
    rows = data1.unsqueeze(dim=1)
    cols = data2.unsqueeze(dim=0)
    # unit-normalize each vector, e.g. [0.3, 0.4] -> [0.3/0.5, 0.4/0.5]
    rows = rows / rows.norm(dim=-1, keepdim=True)
    cols = cols / cols.norm(dim=-1, keepdim=True)
    # dot product of unit vectors is the cosine similarity
    similarity = (rows * cols).sum(dim=-1)
    return 1 - similarity.squeeze()