Spaces:
Runtime error
Runtime error
| import torch | |
| import numpy as np | |
def pairwise_distances(x, y=None):
    """Return the matrix of squared Euclidean distances between rows of x and y.

    :param x: (N, d) tensor.
    :param y: optional (M, d) tensor. When omitted, ``y = x`` is used, which
        gives the (N, N) self-distance matrix.
    :return: (N, M) tensor where ``dist[i, j] = ||x[i, :] - y[j, :]||^2``.
    """
    if y is None:
        y = x
    # Expand ||x - y||^2 = ||x||^2 + ||y||^2 - 2 x.y so the whole matrix comes
    # from one matrix product instead of an explicit (N, M, d) difference.
    x_norm = (x ** 2).sum(1).view(-1, 1)
    y_norm = (y ** 2).sum(1).view(1, -1)
    dist = x_norm + y_norm - 2.0 * torch.mm(x, y.t())
    # Floating-point cancellation in the expansion can leave tiny negative
    # values; a squared distance cannot be negative, so clamp at zero.
    return torch.clamp(dist, min=0.0)
def meanshift_cluster(pts, bandwidth, weights=None, meanshift_step=15, step_size=0.3):
    """Run mean-shift on a set of points, written in PyTorch.

    Each iteration moves every point a fraction ``step_size`` of the way toward
    the kernel-weighted mean of the points around it. The kernel is
    ``relu(bandwidth**2 - squared_distance)``. Points farther than
    ``bandwidth`` therefore contribute nothing, and closer points contribute
    more.

    :param pts: (N, d) tensor of input points.
    :param bandwidth: kernel radius, in the same units as ``pts``.
    :param weights: optional weight per point during clustering, given as an
        (N,) or (N, 1) tensor. A point's weight scales its pull on the others.
    :param meanshift_step: number of mean-shift iterations to run.
    :param step_size: fraction of the mean-shift vector applied per iteration
        (1.0 jumps straight to the weighted mean).
    :return: list of length ``meanshift_step`` holding the (N, d) point
        positions after each iteration. The last entry is the clustered result.
    """
    if torch.is_tensor(weights) and weights.dim() == 1:
        # K[i, j] below is the influence of point i on point j, so a per-point
        # weight must scale rows. A 1-D (N,) tensor would broadcast across
        # columns instead, and the column normalisation would cancel it out
        # entirely. Reshape it into a column vector.
        weights = weights.unsqueeze(1)

    pts_steps = []
    for _ in range(meanshift_step):
        Y = pairwise_distances(pts, pts)
        K = torch.nn.functional.relu(bandwidth ** 2 - Y)
        if weights is not None:
            K = K * weights
        # Make each column sum to 1. Column j then holds the averaging weights
        # of the mean that point j moves toward; eps guards an all-zero column.
        P = torch.nn.functional.normalize(K, p=1, dim=0, eps=1e-10)
        P = P.transpose(0, 1)
        # Partial step toward the weighted mean.
        pts = step_size * (torch.matmul(P, pts) - pts) + pts
        pts_steps.append(pts)
    return pts_steps
def distance(a, b):
    """Return the Euclidean (L2) distance between tensors ``a`` and ``b``.

    All elements of ``a - b`` are squared and summed before the square root.
    For multi-dimensional inputs this is the Frobenius norm of the difference.
    The result is a 0-d tensor.
    """
    diff = a - b
    return diff.pow(2).sum().sqrt()
def meanshift_assign(points, bandwidth):
    """Greedily assign points to clusters in a single pass over ``points``.

    The first point seeds cluster 0. Each later point joins the nearest
    existing cluster if its Euclidean distance to that cluster's center is
    strictly less than ``bandwidth``. Otherwise it seeds a new cluster and
    becomes that cluster's center. Centers are never updated afterwards, so
    the result depends on the order of ``points``.

    :param points: (N, d) tensor of points to assign.
    :param bandwidth: exclusive distance threshold for joining an existing
        cluster.
    :return: ``(cluster_ids, cluster_centers)``. ``cluster_ids[i]`` is the int
        cluster index of ``points[i]``; ``cluster_centers[k]`` is the point
        that seeded cluster ``k``.
    """
    cluster_ids = []
    cluster_centers = []
    for point in points:
        if cluster_centers:
            # Distances from this point to every current center, shape (K,).
            dists = torch.cdist(point.unsqueeze(0), torch.stack(cluster_centers), p=2)[0]
            nearest_idx = int(torch.argmin(dists))
            if dists[nearest_idx] < bandwidth:
                # Store a plain int so every id has the same type as the ids
                # handed out to new clusters below.
                cluster_ids.append(nearest_idx)
                continue
        # No center close enough (or none yet): this point seeds a new cluster.
        cluster_ids.append(len(cluster_centers))
        cluster_centers.append(point)
    return cluster_ids, cluster_centers