File size: 2,148 Bytes
1b2a9b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import time
import numpy as np
import torch
import swapae.util as util

# Compact NumPy printing for the diagnostics below: fixed-point output
# (no scientific notation for small values), 4 decimals, and only 10
# leading/trailing items per axis when a long array is summarized.
np.set_printoptions(suppress=True, edgeitems=10, precision=4)


class PCA:
    """Principal component analysis of the rows of a 2-D torch tensor, via SVD.

    The fit centers the data on its column mean and keeps up to ``ndim``
    principal directions.  When ``l2_normalized`` is True, the last column of
    every input is set aside: it is excluded from the fit and carried through
    ``project``/``inverse`` unchanged.

    NOTE(review): ``project`` and ``inverse`` do not subtract/add ``self.mean``;
    they map raw coordinates through the principal axes.  Centering there was
    commented out in earlier revisions, so this looks deliberate -- confirm
    with callers before changing it.
    """

    def __init__(self, X, ndim=128, var_fraction=0.99, l2_normalized=True, first_direction=None):
        """Fit the principal directions of ``X``.

        Args:
            X: 2-D tensor with one sample per row.
            ndim: maximum number of principal components to keep (fewer are
                kept if ``X`` has fewer rows or columns).
            var_fraction: unused; accepted for interface compatibility.
            l2_normalized: if True, drop the last column of ``X`` before fitting.
            first_direction: unused; accepted for interface compatibility.

        Attributes set:
            mean: (1, D) column mean of the fitted data.
            std: always 1 (per-feature standardization is disabled).
            proj: (D, k) matrix whose columns are the principal directions.
            inv_proj: (k, D) transpose of ``proj``.
            sinvals: the k largest singular values of the centered data.
            N: number of rows in ``X``.
        """
        # Check rank before slicing so a non-2-D input fails on this assertion
        # rather than with an indexing error from the slice below.
        assert len(X.shape) == 2

        self.l2_normalized = l2_normalized
        if l2_normalized:
            X = X[:, :-1]

        self._synchronize(X)
        start_time = time.perf_counter()

        self.mean = torch.mean(X, dim=0, keepdim=True)
        # Per-feature standardization is intentionally disabled; only centering
        # is applied.  ``std`` is kept as an attribute for code that reads it.
        self.std = 1
        X = (X - self.mean) / self.std

        # torch.svd returns V (not V^T): its columns are the right singular
        # vectors, ordered by decreasing singular value.
        _, S, V = torch.svd(X)
        S = S[:ndim]
        V = V[:, :ndim]
        self.proj = V

        # Std of the centered data along each kept direction (diagnostic only).
        pc_std = torch.mm(X, self.proj).std(dim=0)
        self._synchronize(X)
        print("PCA time taken on vectors of size %s : %f" % (str(X.size()), time.perf_counter() - start_time))
        print("largest std of each PC: ", pc_std[:10].cpu().numpy())
        print("smallest std of each PC: ", pc_std[-10:].cpu().numpy())
        self.sinvals = S
        print("largest sinvals: ", self.sinvals[:10].cpu().numpy())
        self.inv_proj = V.transpose(0, 1)
        self.N = X.size(0)

    @staticmethod
    def _synchronize(t):
        """Wait for pending CUDA work on ``t``'s device; no-op for CPU tensors."""
        # An unconditional torch.cuda.synchronize() raises when CUDA is not
        # available, which made the class unusable on CPU-only machines.
        if t.is_cuda:
            torch.cuda.synchronize(t.device)

    def project(self, x):
        """Map rows of ``x`` onto the principal axes.

        Returns an (n, k) tensor, or (n, k + 1) when ``l2_normalized`` is True,
        in which case the last column of ``x`` is split off before projecting
        and appended to the result unchanged.  The mean is not subtracted
        (see the class docstring).
        """
        if self.l2_normalized:
            last_dim = x[:, -1:]
            x = x[:, :-1]
        z = torch.mm(x, self.proj)
        if self.l2_normalized:
            return torch.cat([z, last_dim], dim=1)
        return z

    def scale(self):
        """Return the per-component standard deviation of the fitted data.

        Computed as sinvals / sqrt(N): the population (1/N) std of the
        centered data along each kept principal direction.
        """
        return self.sinvals / np.sqrt(self.N)

    def pc(self, idx):
        """Return the ``idx``-th principal direction as a (1, D) row tensor."""
        return self.inv_proj[idx:idx + 1]

    def inverse(self, z):
        """Map principal-axis coordinates ``z`` back to the original space.

        Inverse of ``project`` up to the discarded components.  When
        ``l2_normalized`` is True, the last column of ``z`` is split off and
        appended to the reconstruction unchanged.  The mean is not added back.
        """
        if self.l2_normalized:
            last_dim = z[:, -1:]
            z = z[:, :-1]
        x = torch.mm(z, self.inv_proj)
        if self.l2_normalized:
            return torch.cat([x, last_dim], dim=1)
        return x