Spaces:
Running
on
Zero
Running
on
Zero
File size: 6,543 Bytes
9e15541 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 |
# from pykeops.torch import LazyTensor
from typing import Tuple
import matplotlib.pyplot as plt
import torch
from torch import nn, Tensor
class VisualizationModule(nn.Module):
    """Helpers that turn high-dimensional feature maps into RGB images.

    Two visualizations are provided:

    * PCA: :meth:`fit_pca` fits ``3 * reduce_images`` principal components and
      :meth:`transform_pca` projects features onto three of them at a time.
    * K-means: :meth:`fit_transform_kmeans_batch` clusters feature vectors with
      a cosine-similarity Lloyd's algorithm and colours each cluster with the
      ``tab10`` colormap.

    Args:
        in_channels: Number of feature channels ``C``.
        reduce_images: Number of 3-component groups fitted by PCA.
    """

    def __init__(self, in_channels, reduce_images=3):
        super().__init__()
        # Defaults so transform_pca can be called before any fit: zero mean and
        # the first three channels as "components". These are plain tensor
        # attributes (not registered buffers), so Module.to() does not move them.
        self.batch_rgb_mean = torch.zeros(in_channels)
        self.batch_rgb_comp = torch.eye(in_channels, 3)  # (C, 3)
        self.reduce_images = reduce_images
        self.fitted_pca = False
        self.n_kmeans_clusters = 8
        self.kmeans_cluster_centers = torch.zeros(self.n_kmeans_clusters, in_channels)
        self.cmap_kmeans = plt.get_cmap("tab10")

    def fit_pca(self, batch_features, refit):
        """Fit the PCA basis on a ``(N, C)`` feature matrix.

        The fit only runs on the first call or when ``refit`` is true; otherwise
        the previously fitted basis is kept. Rows containing any NaN are dropped
        before fitting.

        Args:
            batch_features: Feature matrix of shape ``(N, C)``.
            refit: Force re-fitting even if a basis was already fitted.

        Raises:
            ValueError: If ``batch_features`` has more than two dimensions.
        """
        if batch_features.dim() > 2:
            raise ValueError(f"Wrong dims for PCA: {batch_features.shape}")
        if self.fitted_pca and not refit:
            return
        finite_rows = ~torch.isnan(batch_features).any(dim=1)
        self._pca_fast(batch_features[finite_rows], num_components=3 * self.reduce_images)
        self.fitted_pca = True

    def transform_pca(self, features, norm, from_dim):
        """Project features onto three consecutive fitted principal components.

        Args:
            features: Tensor whose last dimension is the channel dimension ``C``.
            norm: If true, each centred feature vector is L2-normalised before
                projection. An all-zero vector stays zero instead of becoming NaN.
            from_dim: Index of the first of the three components to use.

        Returns:
            Tensor of shape ``features.shape[:-1] + (k,)`` where ``k`` is 3, or
            fewer if fewer than three components remain after ``from_dim``.
        """
        centred = features - self.batch_rgb_mean
        if norm:
            # normalize() clamps the norm at a tiny eps, avoiding 0/0 = NaN.
            centred = torch.nn.functional.normalize(centred, p=2, dim=-1)
        return centred @ self.batch_rgb_comp[..., from_dim:from_dim + 3]

    def _pca_fast(self, data: Tensor, num_components: int = 3) -> None:
        """Fit PCA with PyTorch's randomized low-rank approximation.

        The data are centred and divided by their per-channel standard deviation,
        decomposed with ``torch.pca_lowrank``, and sign-fixed with
        :meth:`_svd_flip`. Nothing is returned: the mean is stored in
        ``self.batch_rgb_mean`` (shape ``[1, C]`` or ``[B, 1, C]``) and the
        leading components in ``self.batch_rgb_comp`` (shape
        ``[C, num_components]`` or ``[B, C, num_components]``).

        Args:
            data: Data matrix of shape ``[N, C]`` or ``[B, N, C]``.
            num_components: Number of principal components to keep.
        """
        data_mean = data.mean(dim=-2, keepdim=True)
        data_normalized = (data - data_mean) / (data.std(dim=-2, keepdim=True) + 1e-08)
        # At least 6 components are requested from the randomized solver.
        u, _, v = torch.pca_lowrank(data_normalized, q=max(num_components, 6), niter=2, center=True)
        v = v.transpose(-1, -2)  # rows are now components: (..., q, C)
        u, v = self._svd_flip(u, v)
        # Keep the leading components; works for both 2-D and batched input.
        pca_components = v[..., :num_components, :]
        # NOTE(review): the std used above is not stored, so transform_pca
        # projects centred-but-unscaled features onto components fitted in
        # standardised space. Confirm this is intended.
        self.batch_rgb_mean = data_mean
        self.batch_rgb_comp = pca_components.transpose(-1, -2)

    def _svd_flip(self, u: Tensor, v: Tensor) -> Tuple[Tensor, Tensor]:
        """Make SVD output sign-deterministic.

        For every column of ``u``, the sign of its largest-magnitude entry is
        made positive. The matching row of ``v`` is flipped as well, so
        ``u @ diag(s) @ v`` is unchanged.

        Args:
            u: Left singular vectors, shape ``[N, q]`` or ``[B, N, q]``.
            v: Right singular vectors as rows, shape ``[q, C]`` or ``[B, q, C]``.

        Returns:
            The sign-corrected ``(u, v)`` with unchanged shapes.
        """
        # Row index of the max-|u| entry per column: (..., 1, q).
        max_abs_rows = torch.abs(u).argmax(dim=-2, keepdim=True)
        # Gather those entries and take their signs in one vectorised step,
        # for both the 2-D and the batched case.
        signs = torch.sign(torch.gather(u, -2, max_abs_rows))  # (..., 1, q)
        u = u * signs
        v = v * signs.transpose(-1, -2)  # (..., q, 1) scales rows of v
        return u, v

    def old_fit_transform_kmeans_batch(self, batch_features, subsample_size=20000):
        """Legacy k-means colouring based on the third-party ``torch_kmeans`` package.

        Unlike :meth:`fit_transform_kmeans_batch`, the leading dimension is kept:
        features are flattened to ``(B, M, C)`` and passed to ``torch_kmeans``.
        This assumes ``batch_features`` has at least 3 dims (TODO confirm).
        When ``subsample_size`` is not None and smaller than ``M``, the model is
        fitted on a random subset of positions, but every position is labelled.

        Returns:
            Float tensor of RGB colours for each position, with a size-1
            dimension at position -2 squeezed out if present.
        """
        feats_map_flattened = batch_features.flatten(1, -2)
        # Optional dependency, imported only when this legacy path is used.
        from torch_kmeans import KMeans, CosineSimilarity
        kmeans_engine = KMeans(n_clusters=self.n_kmeans_clusters, distance=CosineSimilarity)
        n = feats_map_flattened.size(1)
        if subsample_size is not None and subsample_size < n:
            indices = torch.randperm(n)[:subsample_size]
            kmeans_engine.fit(feats_map_flattened[:, indices])
        else:
            kmeans_engine.fit(feats_map_flattened)
        labels = kmeans_engine.predict(feats_map_flattened)
        labels = labels.reshape(batch_features.shape[:-1]).float().cpu().numpy()
        # Scale labels into [0, 1] for the colormap; the max() avoids 0/0 with one cluster.
        label_map = self.cmap_kmeans(labels / max(self.n_kmeans_clusters - 1, 1))[..., :3]
        return torch.Tensor(label_map).squeeze(-2)

    def fit_transform_kmeans_batch(self, batch_features):
        """Cluster all feature vectors jointly and colour them by cluster.

        Every leading dimension of ``batch_features`` (shape ``(..., C)``) is
        flattened, so one set of centroids is shared across the whole batch.
        The centroids are stored in ``self.kmeans_cluster_centers``.

        Returns:
            Float tensor of RGB colours of shape ``batch_features.shape[:-1] + (3,)``,
            with a size-1 dimension at position -2 squeezed out if present.
        """
        flat_features = batch_features.flatten(0, -2)
        with torch.no_grad():
            labels, centers = self._KMeans_cosine(flat_features.float(), K=self.n_kmeans_clusters)
        self.kmeans_cluster_centers = centers
        labels = labels.reshape(batch_features.shape[:-1]).float().cpu().numpy()
        # Scale labels into [0, 1] for the colormap; the max() avoids 0/0 with one cluster.
        label_map = self.cmap_kmeans(labels / max(self.n_kmeans_clusters - 1, 1))[..., :3]
        return torch.Tensor(label_map).squeeze(-2)

    def _KMeans_cosine(self, x, K=19, Niter=100):
        """Run Lloyd's algorithm with the cosine-similarity metric.

        This is a dense PyTorch implementation: similarities come from an
        ``(N, D) @ (D, K)`` matrix product. No KeOps ``LazyTensor`` is needed.

        Args:
            x: Samples of shape ``(N, D)``.
            K: Number of clusters. The first ``K`` samples seed the centroids,
                so only ``min(K, N)`` clusters exist when ``N < K``.
            Niter: Number of Lloyd iterations; must be at least 1.

        Returns:
            ``(cl, c)``:
            ``cl`` is the ``(N,)`` long tensor of labels from the last
            assignment step.
            ``c`` holds the ``(min(K, N), D)`` centroids from the last update
            step. Each centroid is unit-norm; an empty cluster is all zeros.

        Raises:
            ValueError: If ``Niter < 1``.
        """
        if Niter < 1:
            raise ValueError(f"Niter must be >= 1, got {Niter}")
        # Simplistic initialisation: the first K samples, normalised.
        centroids = torch.nn.functional.normalize(x[:K, :].clone(), dim=1, p=2)
        labels = None
        for _ in range(Niter):
            # E step: the centroids are unit-norm and ||x_i|| is the same for
            # every k, so the argmax of the dot product is the argmax of the
            # cosine similarity.
            labels = (x @ centroids.T).argmax(dim=1)
            # M step: sum the points of each cluster, then renormalise.
            summed = torch.zeros_like(centroids).index_add_(0, labels, x)
            centroids = torch.nn.functional.normalize(summed, dim=1, p=2)
        return labels, centroids