- normflows.py +354 -0
- requirements.txt +63 -0
normflows.py
ADDED
@@ -0,0 +1,354 @@
import torch.nn as nn
import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau, OneCycleLR, CyclicLR
import pandas as pd
from sklearn.preprocessing import StandardScaler, MinMaxScaler
import matplotlib.pyplot as plt
from torch.distributions import MultivariateNormal, LogNormal, Normal, Chi2
from torch.distributions.distribution import Distribution
from sklearn.metrics import r2_score
import numpy as np

# A Distribution implementing a kernel density estimate with Gaussian kernels.
class GaussianKDE(Distribution):
    def __init__(self, X, bw):
        """
        X : tensor (n, d)
            `n` points with `d` dimensions to which KDE will be fit
        bw : numeric
            bandwidth for Gaussian kernel
        """
        self.X = X
        self.bw = bw
        self.dims = X.shape[-1]
        self.n = X.shape[0]
        self.mvn = MultivariateNormal(loc=torch.zeros(self.dims),
                                      scale_tril=torch.eye(self.dims))

    def sample(self, num_samples):
        """
        Sample from a normal distribution whose mean is a randomly chosen data
        point and whose standard deviation is the bandwidth.

        :param num_samples: the number of samples to draw from the KDE
        :return: a sample of size num_samples from the KDE
        """
        idxs = (np.random.uniform(0, 1, num_samples) * self.n).astype(int)
        norm = Normal(loc=self.X[idxs], scale=self.bw)
        return norm.sample()

    def score_samples(self, Y, X=None):
        """Returns the kernel density estimates of each point in `Y`.

        Parameters
        ----------
        Y : tensor (m, d)
            `m` points with `d` dimensions for which the probability density will
            be calculated
        X : tensor (n, d), optional
            `n` points with `d` dimensions to which KDE will be fit. Provided to
            allow batch calculations in `log_prob`. By default, `X` is None and
            all points used to initialize KernelDensityEstimator are included.

        Returns
        -------
        log_probs : tensor (m)
            log probability densities for each of the queried points in `Y`
        """
        if X is None:
            X = self.X
        log_probs = self.mvn.log_prob(X.unsqueeze(1) - Y).sum(dim=0)

        return log_probs

    def log_prob(self, Y):
        """Returns the total log probability of one or more points, `Y`, using
        a Multivariate Normal kernel fit to `X` and scaled using `bw`.

        Parameters
        ----------
        Y : tensor (m, d)
            `m` points with `d` dimensions for which the probability density will
            be calculated

        Returns
        -------
        log_prob : numeric
            total log probability density for the queried points, `Y`
        """
        X_chunks = self.X
        Y_chunks = Y
        self.Y = Y
        log_prob = 0

        # Accumulate the kernel scores over every (data point, query point) pair.
        for x in X_chunks:
            for y in Y_chunks:
                log_prob += self.score_samples(y, x).sum(dim=0)

        return log_prob

class Chi2KDE(Distribution):
    def __init__(self, X, bw):
        """
        X : tensor (n, d)
            `n` points with `d` dimensions to which KDE will be fit
        bw : numeric
            bandwidth for the kernel
        """
        self.X = X
        self.bw = bw
        self.dims = X.shape[-1]
        self.n = X.shape[0]
        self.mvn = Chi2(self.dims)

    def sample(self, num_samples):
        idxs = (np.random.uniform(0, 1, num_samples) * self.n).astype(int)
        norm = LogNormal(loc=self.X[idxs], scale=self.bw)
        return norm.sample()

    def score_samples(self, Y, X=None):
        """Returns the kernel density estimates of each point in `Y`.

        Parameters
        ----------
        Y : tensor (m, d)
            `m` points with `d` dimensions for which the probability density will
            be calculated
        X : tensor (n, d), optional
            `n` points with `d` dimensions to which KDE will be fit. Provided to
            allow batch calculations in `log_prob`. By default, `X` is None and
            all points used to initialize KernelDensityEstimator are included.

        Returns
        -------
        log_probs : tensor (m)
            log probability densities for each of the queried points in `Y`
        """
        if X is None:
            X = self.X
        log_probs = self.mvn.log_prob(abs(X.unsqueeze(1) - Y)).sum()

        return log_probs

    def log_prob(self, Y):
        """Returns the total log probability of one or more points, `Y`, using
        a chi-squared kernel fit to `X` and scaled using `bw`.

        Parameters
        ----------
        Y : tensor (m, d)
            `m` points with `d` dimensions for which the probability density will
            be calculated

        Returns
        -------
        log_prob : numeric
            total log probability density for the queried points, `Y`
        """
        X_chunks = self.X
        Y_chunks = Y
        self.Y = Y
        log_prob = 0

        for x in X_chunks:
            for y in Y_chunks:
                log_prob += self.score_samples(y, x).sum(dim=0)

        return log_prob

class PlanarFlow(nn.Module):
    """
    A single planar flow; computes T(x) and log(det(jac_T)).
    """
    def __init__(self, D):
        super(PlanarFlow, self).__init__()
        self.u = nn.Parameter(torch.Tensor(1, D), requires_grad=True)
        self.w = nn.Parameter(torch.Tensor(1, D), requires_grad=True)
        self.b = nn.Parameter(torch.Tensor(1), requires_grad=True)
        self.h = torch.tanh
        self.init_params()

    def init_params(self):
        self.w.data.uniform_(0.4, 1)
        self.b.data.uniform_(0.4, 1)
        self.u.data.uniform_(0.4, 1)

    def forward(self, z):
        linear_term = torch.mm(z, self.w.T) + self.b
        return z + self.u * self.h(linear_term)

    def h_prime(self, x):
        """
        Derivative of tanh.
        """
        return 1 - self.h(x) ** 2

    def psi(self, z):
        inner = torch.mm(z, self.w.T) + self.b
        return self.h_prime(inner) * self.w

    def log_det(self, z):
        inner = 1 + torch.mm(self.psi(z), self.u.T)
        return torch.log(torch.abs(inner))

# A normalizing flow: a stack of invertible transforms applied to base samples.
class NormalizingFlow(nn.Module):
    """
    A normalizing flow composed of a sequence of planar flows.
    """
    def __init__(self, D, n_flows=2):
        """
        Builds a list of `n_flows` PlanarFlow modules operating on D-dimensional data.

        :param D: the dimensionality of the data
        :param n_flows: number of flows to use, defaults to 2 (optional)
        """
        super(NormalizingFlow, self).__init__()
        self.flows = nn.ModuleList(
            [PlanarFlow(D) for _ in range(n_flows)])

    def sample(self, base_samples):
        """
        Transform samples from a simple base distribution
        by passing them through a sequence of planar flows.
        """
        samples = base_samples
        for flow in self.flows:
            samples = flow(samples)
        return samples

    def forward(self, x):
        """
        Computes and returns the sum of log_det_jacobians
        and the transformed samples T(x).
        """
        sum_log_det = 0
        transformed_sample = x

        for i in range(len(self.flows)):
            log_det_i = self.flows[i].log_det(transformed_sample)
            sum_log_det += log_det_i
            transformed_sample = self.flows[i](transformed_sample)

        return transformed_sample, sum_log_det


def random_normal_samples(n, dim=2):
    return torch.zeros(n, dim).normal_(mean=0, std=1.5)


class nflow:
    def __init__(self, dim=2, latent=16, batchsize: int = 1, dataset=None):
        """
        Wraps a NormalizingFlow model together with its dataset and training state.

        :param dim: the dimension of the data, defaults to 2 (optional)
        :param latent: the number of planar flows in the model, defaults to 16 (optional)
        :param batchsize: the number of samples to draw per training step, defaults to 1
        :type batchsize: int (optional)
        :param dataset: path to the CSV dataset
        """
        self.dim = dim
        self.batchsize = batchsize
        self.model = NormalizingFlow(dim, latent)
        self.dataset = dataset

    def compile(self, optim: torch.optim = torch.optim.Adam, distribution: str = 'GaussianKDE',
                lr: float = 0.00015, bw: float = 0.1, wd=0.0015):
        """
        Sets up the optimizer, scales the dataset, and builds the target KDE density.

        :param optim: the optimizer to use
        :type optim: torch.optim
        :param distribution: the type of distribution to use, defaults to GaussianKDE
        :type distribution: str (optional)
        :param lr: learning rate
        :type lr: float
        :param bw: bandwidth for the KDE
        :type bw: float
        :param wd: weight decay; if falsy, the optimizer is built without it
        """
        if wd:
            self.opt = optim(
                params=self.model.parameters(),
                lr=lr,
                weight_decay=wd
            )
        else:
            self.opt = optim(
                params=self.model.parameters(),
                lr=lr,
            )
        self.scaler = StandardScaler()
        self.scaler_mm = MinMaxScaler(feature_range=(0, 1))

        df = pd.read_csv(self.dataset)
        df = df.iloc[:, 1:]

        # The chi-squared kernel expects non-negative inputs, so scale to [0, 1];
        # otherwise standardize.
        if 'Chi2' in distribution:
            self.scaled = self.scaler_mm.fit_transform(df)
        else:
            self.scaled = self.scaler.fit_transform(df)

        self.density = globals()[distribution](X=torch.tensor(self.scaled, dtype=torch.float32), bw=bw)

        self.scheduler = ReduceLROnPlateau(self.opt, patience=10000)
        self.losses = []

    def train(self, iters: int = 1000):
        """
        Sample from the base normal distribution, pass the samples through the model,
        and minimize the reverse-KL loss.

        :param iters: number of iterations to train for, defaults to 1000
        :type iters: int (optional)
        """
        for idx in range(iters):
            if idx % 100 == 0:
                print("Iteration {}".format(idx))

            samples = random_normal_samples(self.batchsize, self.dim)

            z_k, sum_log_det = self.model(samples)
            log_p_x = self.density.log_prob(z_k)
            # Reverse KL since we can evaluate target density but can't sample
            loss = (-sum_log_det - log_p_x).mean()

            self.opt.zero_grad()
            loss.backward()
            self.opt.step()
            self.scheduler.step(loss)

            self.losses.append(loss.item())

            if idx % 100 == 0:
                print("Loss {}".format(loss.item()))

        plt.plot(self.losses)

    def performance(self):
        """
        Passes the scaled data through the flow and prints the r2 score between
        the transformed samples and the scaled data.
        """
        samples = self.model.sample(torch.tensor(self.scaled).float()).detach().numpy()

        print('r2', r2_score(self.scaled, samples))
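
A minimal usage sketch of the module above (not part of the commit): the dataset path and `dim` below are placeholders; `dim` must match the number of feature columns left after the CSV's first (index) column is dropped.

    import torch
    from normflows import nflow

    # Hypothetical dataset path and dimensionality, chosen only for illustration.
    flow = nflow(dim=6, latent=16, batchsize=1, dataset="data/dataset.csv")
    flow.compile(optim=torch.optim.Adam, distribution="GaussianKDE",
                 lr=0.00015, bw=0.1, wd=0.0015)
    flow.train(iters=1000)   # logs the loss every 100 iterations and plots the loss curve
    flow.performance()       # prints the r2 score between the flow output and the scaled data
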
requirements.txt
ADDED
@@ -0,0 +1,63 @@
altair==4.2.2
attrs==22.2.0
blinker==1.5
cachetools==5.3.0
certifi==2022.12.7
charset-normalizer==3.1.0
click==8.1.3
contourpy==1.0.7
cycler==0.11.0
decorator==5.1.1
entrypoints==0.4
filelock==3.10.0
fonttools==4.39.2
gitdb==4.0.10
GitPython==3.1.31
idna==3.4
importlib-metadata==6.0.0
Jinja2==3.1.2
joblib==1.2.0
jsonschema==4.17.3
kiwisolver==1.4.4
markdown-it-py==2.2.0
MarkupSafe==2.1.2
matplotlib==3.7.1
mdurl==0.1.2
mpmath==1.3.0
networkx==3.0
numpy==1.24.2
packaging==23.0
pandas==1.5.3
Pillow==9.4.0
protobuf==3.20.3
pyarrow==11.0.0
pydeck==0.8.0
Pygments==2.14.0
Pympler==1.0.1
pyparsing==3.0.9
pyrsistent==0.19.3
python-dateutil==2.8.2
pytz==2022.7.1
pytz-deprecation-shim==0.1.0.post0
requests==2.28.2
rich==13.3.2
scikit-learn==1.2.2
scipy==1.10.1
seaborn==0.12.2
semver==2.13.0
six==1.16.0
smmap==5.0.0
streamlit==1.20.0
sympy==1.11.1
threadpoolctl==3.1.0
toml==0.10.2
toolz==0.12.0
torch==2.0.0
tornado==6.2
typing_extensions==4.5.0
tzdata==2022.7
tzlocal==4.2
urllib3==1.26.15
validators==0.20.0
watchdog==2.3.1
zipp==3.15.0