Spaces:
Runtime error
Runtime error
import time | |
import numpy as np | |
import torch | |
import swapae.util as util | |
np.set_printoptions(precision=4, suppress=True, edgeitems=10) | |
class PCA:
    """Principal component analysis of a 2-D tensor, computed via ``torch.svd``.

    The decomposition is fitted once in ``__init__``.  Afterwards the object
    exposes the projection onto the leading principal directions
    (``project``), the mapping back (``inverse``), the individual directions
    (``pc``) and their per-component scale (``scale``).

    Attributes:
        l2_normalized: whether the last input column is treated as a
            pass-through column (see ``__init__``).
        mean: per-column mean of the fitted data, shape ``(1, D)``.
        std: fixed to ``1`` (per-column standardisation is disabled).
        proj: ``(D, k)`` matrix whose columns are the leading right-singular
            vectors of the centred data, ``k = min(ndim, D)``.
        inv_proj: transpose of ``proj``, shape ``(k, D)``.
        sinvals: the ``k`` leading singular values.
        N: number of rows the decomposition was fitted on.
    """

    def __init__(self, X, ndim=128, var_fraction=0.99, l2_normalized=True, first_direction=None):
        """Fit the decomposition on ``X`` and print timing/diagnostics.

        Args:
            X: 2-D ``torch.Tensor`` with one sample per row.
            ndim: maximum number of principal directions to keep.
            var_fraction: accepted for interface compatibility; not used here.
            l2_normalized: if true, the last column of ``X`` is excluded from
                the decomposition, and ``project``/``inverse`` carry the last
                column of their input through unchanged.
                NOTE(review): presumably that column holds the vector norm --
                confirm against the caller.
            first_direction: accepted for interface compatibility; not used
                here.
        """
        self.l2_normalized = l2_normalized
        if l2_normalized:
            # Fit on every column except the last one.
            X = X[:, :-1]
        assert len(X.shape) == 2

        # Synchronise only when the data lives on a CUDA device.  An
        # unconditional torch.cuda.synchronize() raises on CPU-only machines,
        # and it is only needed to make the wall-clock timing below meaningful
        # for asynchronous GPU kernels.
        on_gpu = X.is_cuda
        if on_gpu:
            torch.cuda.synchronize()
        start_time = time.time()

        self.mean = torch.mean(X, dim=0, keepdim=True)
        # Per-column standardisation is intentionally disabled; the attribute
        # is kept so the centring expression below stays in its general form.
        self.std = 1
        X = (X - self.mean) / self.std

        # torch.svd returns V (not its transpose); the left factor is unused.
        _, S, V = torch.svd(X)
        S = S[:ndim]
        V = V[:, :ndim]
        self.proj = V

        # Standard deviation of the centred data along each kept direction;
        # used for the diagnostic printout only.
        pc_std = torch.mm(X, self.proj).std(dim=0)

        if on_gpu:
            torch.cuda.synchronize()
        print("PCA time taken on vectors of size %s : %f" % (str(X.size()), time.time() - start_time))
        print("largest std of each PC: ", pc_std[:10].cpu().numpy())
        print("smallest std of each PC: ", pc_std[-10:].cpu().numpy())

        self.sinvals = S
        print("largest sinvals: ", self.sinvals[:10].cpu().numpy())
        self.inv_proj = V.transpose(0, 1)
        self.N = X.size(0)

    def project(self, x):
        """Map rows of ``x`` onto the kept principal directions.

        NOTE: the fitted ``mean``/``std`` are *not* applied here; the input is
        multiplied by ``proj`` as-is, and ``inverse`` correspondingly does not
        add the mean back.

        Args:
            x: 2-D tensor.  With ``l2_normalized`` its last column is split
                off before the projection and re-attached afterwards.

        Returns:
            ``x @ proj``, with the untouched last column of ``x`` appended
            when ``l2_normalized`` is set.
        """
        if self.l2_normalized:
            last_dim = x[:, -1:]
            x = x[:, :-1]
        z = torch.mm(x, self.proj)
        if self.l2_normalized:
            return torch.cat([z, last_dim], dim=1)
        return z

    def scale(self):
        """Return the kept singular values divided by ``sqrt(N)``."""
        return self.sinvals / np.sqrt(self.N)

    def pc(self, idx):
        """Return principal direction ``idx`` as a ``(1, D)`` row of ``inv_proj``."""
        return self.inv_proj[idx:idx + 1]

    def inverse(self, z):
        """Map coefficients ``z`` back through ``inv_proj``.

        NOTE: mirrors ``project`` -- the fitted ``mean``/``std`` are not
        re-applied.

        Args:
            z: 2-D tensor of coefficients.  With ``l2_normalized`` its last
                column is split off before the multiplication and re-attached
                afterwards.

        Returns:
            ``z @ inv_proj``, with the untouched last column of ``z`` appended
            when ``l2_normalized`` is set.
        """
        if self.l2_normalized:
            last_dim = z[:, -1:]
            z = z[:, :-1]
        x = torch.mm(z, self.inv_proj)
        if self.l2_normalized:
            return torch.cat([x, last_dim], dim=1)
        return x