# sunshineatnoon
# Add application file
# 1b2a9b1
import time
import numpy as np
import torch
import swapae.util as util
# Module-wide NumPy print settings used by the diagnostic prints in PCA.__init__:
# 4 digits of precision, no scientific notation for small values, and up to
# 10 items shown at each edge of a summarized array.
np.set_printoptions(precision=4, suppress=True, edgeitems=10)
class PCA:
    """Linear PCA basis fitted with an SVD of mean-centred data.

    The constructor centres ``X`` column-wise, takes ``torch.svd`` of the
    result and keeps the first ``ndim`` right-singular vectors as the
    projection matrix.

    When ``l2_normalized`` is true the last column of every input (to the
    constructor, ``project`` and ``inverse``) is excluded from the linear map;
    ``project`` and ``inverse`` pass that column through unchanged.

    Note that ``project`` and ``inverse`` apply only the matrix product: the
    mean computed during fitting is stored in ``self.mean`` but is neither
    subtracted in ``project`` nor added back in ``inverse``.

    Attributes:
        l2_normalized: whether the trailing column is handled separately.
        mean: column-wise mean of the fitted data, shape ``(1, D)``.
        std: fixed to ``1`` (per-feature standardisation is disabled).
        proj: ``(D, k)`` matrix of the leading right-singular vectors.
        inv_proj: transpose of ``proj``, shape ``(k, D)``.
        sinvals: the ``k`` leading singular values.
        N: number of rows in the fitted data.
    """

    def __init__(self, X, ndim=128, var_fraction=0.99, l2_normalized=True, first_direction=None):
        """Fit the projection on ``X``.

        Args:
            X: 2-D tensor of samples (rows) by features (columns).
            ndim: maximum number of principal directions to keep.
            var_fraction: accepted for interface compatibility; not used.
            l2_normalized: if true, drop the last column of ``X`` before
                fitting.
            first_direction: accepted for interface compatibility; not used.
        """
        self.l2_normalized = l2_normalized
        if l2_normalized:
            X = X[:, :-1]
        assert len(X.shape) == 2, "PCA expects a 2-D (samples x features) tensor"

        # Synchronise only for CUDA inputs so that the timing below is
        # meaningful on GPU and construction does not fail on CPU-only hosts.
        self._sync_if_cuda(X)
        start_time = time.time()

        self.mean = torch.mean(X, dim=0, keepdim=True)
        # Standardisation is intentionally disabled; data is only centred.
        self.std = 1
        X = (X - self.mean) / self.std

        _, S, V = torch.svd(X)
        S = S[:ndim]
        V = V[:, :ndim]
        self.proj = V

        # Standard deviation of the fitted data along each kept direction
        # (diagnostic output only).
        pc_std = torch.mm(X, self.proj).std(dim=0)
        self._sync_if_cuda(X)

        print("PCA time taken on vectors of size %s : %f" % (str(X.size()), time.time() - start_time))
        print("largest std of each PC: ", pc_std[:10].cpu().numpy())
        print("smallest std of each PC: ", pc_std[-10:].cpu().numpy())

        self.sinvals = S
        print("largest sinvals: ", self.sinvals[:10].cpu().numpy())

        self.inv_proj = V.transpose(0, 1)
        self.N = X.size(0)

    @staticmethod
    def _sync_if_cuda(tensor):
        """Wait for pending CUDA work on ``tensor``'s device; no-op on CPU."""
        if tensor.is_cuda:
            torch.cuda.synchronize(tensor.device)

    def project(self, x):
        """Map ``x`` onto the kept directions via ``x @ proj``.

        With ``l2_normalized`` set, the last column of ``x`` is split off
        before the product and re-attached, unchanged, as the last column of
        the result. The stored mean is not subtracted.
        """
        if not self.l2_normalized:
            return torch.mm(x, self.proj)

        last_dim = x[:, -1:]
        z = torch.mm(x[:, :-1], self.proj)
        return torch.cat([z, last_dim], dim=1)

    def scale(self):
        """Return the kept singular values divided by ``sqrt(N)``."""
        return self.sinvals / np.sqrt(self.N)

    def pc(self, idx):
        """Return direction ``idx`` as a ``(1, D)`` row slice of ``inv_proj``."""
        return self.inv_proj[idx:idx + 1]

    def inverse(self, z):
        """Map coefficients ``z`` back to feature space via ``z @ inv_proj``.

        With ``l2_normalized`` set, the last column of ``z`` is split off
        before the product and re-attached, unchanged, as the last column of
        the result. The stored mean is not added back.
        """
        if not self.l2_normalized:
            return torch.mm(z, self.inv_proj)

        last_dim = z[:, -1:]
        x = torch.mm(z[:, :-1], self.inv_proj)
        return torch.cat([x, last_dim], dim=1)