Spaces:
Runtime error
Runtime error
import torch | |
import numpy as np | |
def pairwise_distances(x, y=None):
    """
    Compute squared Euclidean distances between the rows of two matrices.

    :param x: (N, d) tensor
    :param y: optional (M, d) tensor; when omitted (None), ``y = x`` is used
    :return: (N, M) tensor where dist[i, j] = ||x[i, :] - y[j, :]||^2,
        clamped below at 0
    """
    if y is None:
        y = x
    x_norm = (x ** 2).sum(1).view(-1, 1)  # (N, 1) squared row norms of x
    y_norm = (y ** 2).sum(1).view(1, -1)  # (1, M) squared row norms of y
    # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b, broadcast into an (N, M) grid.
    dist = x_norm + y_norm - 2.0 * torch.mm(x, torch.transpose(y, 0, 1))
    # Floating-point cancellation can produce tiny negative values; clamp to 0.
    return torch.clamp(dist, min=0.0)
def meanshift_cluster(pts, bandwidth, weights=None, meanshift_step=15, step_size=0.3):
    """
    Run a fixed number of mean-shift iterations, written in PyTorch.

    Each iteration moves every point a fraction ``step_size`` of the way
    toward the kernel-weighted mean of the points around it. The kernel is
    ``relu(bandwidth**2 - ||p_i - p_j||^2)``, so only points whose squared
    distance is below ``bandwidth**2`` contribute.

    :param pts: input points, an (N, d) tensor
    :param bandwidth: kernel radius
    :param weights: optional weight per point during clustering. A 1-D
        tensor of length N is reshaped to an (N, 1) column so that
        weights[i] scales how strongly point i pulls the others; other
        shapes are used as given (broadcast against the (N, N) kernel)
    :param meanshift_step: number of iterations to run
    :param step_size: fraction of the mean-shift vector applied per iteration
    :return: list of length ``meanshift_step`` holding the points after
        each iteration
    """
    if weights is not None and weights.ndim == 1:
        # K[i, j] must be scaled by the weight of the contributing point i.
        # A 1-D vector would broadcast along columns (K[i, j] * w[j]) and then
        # cancel out in the column-wise L1 normalization below, so reshape it
        # into an (N, 1) column once, outside the loop.
        weights = weights.reshape(-1, 1)
    pts_steps = []
    for _ in range(meanshift_step):
        Y = pairwise_distances(pts, pts)
        K = torch.nn.functional.relu(bandwidth ** 2 - Y)
        if weights is not None:
            K = K * weights
        # Normalize each column j to sum to 1: column j then holds the
        # relative contribution of every point to the new position of point j.
        P = torch.nn.functional.normalize(K, p=1, dim=0, eps=1e-10)
        P = P.transpose(0, 1)
        pts = step_size * (torch.matmul(P, pts) - pts) + pts
        pts_steps.append(pts)
    return pts_steps
def distance(a, b):
    """
    Return the Euclidean (L2) distance between tensors ``a`` and ``b``.

    The squared element-wise differences are summed over every element,
    and the square root of that sum is returned as a 0-dim tensor.
    """
    diff = a - b
    return diff.pow(2).sum().sqrt()
def meanshift_assign(points, bandwidth):
    """
    Greedily assign each point to the nearest existing cluster center.

    Points are visited in order. The first point founds cluster 0. Each later
    point joins the nearest existing center when its Euclidean distance to
    that center is strictly less than ``bandwidth``. Otherwise it founds a new
    cluster and becomes that cluster's center. Centers are never updated
    after they are created.

    :param points: iterable of point tensors (presumably 1-D tensors of
        length d, e.g. the rows of an (N, d) tensor -- confirm with callers)
    :param bandwidth: distance threshold for joining an existing cluster
    :return: (cluster_ids, cluster_centers), where cluster_ids[i] is the int
        cluster index of the i-th point and cluster_centers[k] is the point
        that founded cluster k
    """
    cluster_ids = []
    cluster_centers = []
    for point in points:
        if cluster_centers:
            # (1, K) Euclidean distances from this point to every current center.
            cdist = torch.cdist(point.unsqueeze(0), torch.stack(cluster_centers), p=2)
            # Convert to a plain int so every id in cluster_ids has the same type.
            nearest_idx = int(torch.argmin(cdist, dim=1))
            if cdist[0, nearest_idx] < bandwidth:
                cluster_ids.append(nearest_idx)
                continue
        # No center yet, or none within bandwidth: this point founds a new cluster.
        cluster_ids.append(len(cluster_centers))
        cluster_centers.append(point)
    return cluster_ids, cluster_centers