Example scurve is added.
Browse files- README.md +2 -0
- results_scurve/cnf-viz-00000.jpg +0 -0
- results_scurve/cnf-viz-00250.jpg +0 -0
- results_scurve/cnf-viz-00500.jpg +0 -0
- results_scurve/cnf-viz-00750.jpg +0 -0
- results_scurve/cnf-viz-01000.jpg +0 -0
- results_scurve/cnf-viz-01250.jpg +0 -0
- results_scurve/cnf-viz-01500.jpg +0 -0
- results_scurve/cnf-viz-01750.jpg +0 -0
- results_scurve/cnf-viz-02000.jpg +0 -0
- results_scurve/cnf-viz-02250.jpg +0 -0
- results_scurve/cnf-viz-02500.jpg +0 -0
- results_scurve/cnf-viz-02750.jpg +0 -0
- results_scurve/cnf-viz-03000.jpg +0 -0
- results_scurve/cnf-viz-03250.jpg +0 -0
- results_scurve/cnf-viz-03500.jpg +0 -0
- results_scurve/cnf-viz-03750.jpg +0 -0
- results_scurve/cnf-viz-04000.jpg +0 -0
- results_scurve/cnf-viz-04250.jpg +0 -0
- results_scurve/cnf-viz-04500.jpg +0 -0
- results_scurve/cnf-viz-04750.jpg +0 -0
- results_scurve/cnf-viz-05000.jpg +0 -0
- results_scurve/cnf-viz-05250.jpg +0 -0
- results_scurve/cnf-viz-05500.jpg +0 -0
- results_scurve/cnf-viz-05750.jpg +0 -0
- results_scurve/cnf-viz-06000.jpg +0 -0
- results_scurve/cnf-viz-06250.jpg +0 -0
- results_scurve/cnf-viz-06500.jpg +0 -0
- results_scurve/cnf-viz-06750.jpg +0 -0
- results_scurve/cnf-viz-07000.jpg +0 -0
- results_scurve/cnf-viz-07250.jpg +0 -0
- results_scurve/cnf-viz-07500.jpg +0 -0
- results_scurve/cnf-viz-07750.jpg +0 -0
- results_scurve/cnf-viz-08000.jpg +0 -0
- results_scurve/cnf-viz-08250.jpg +0 -0
- results_scurve/cnf-viz-08500.jpg +0 -0
- results_scurve/cnf-viz-08750.jpg +0 -0
- results_scurve/cnf-viz-09000.jpg +0 -0
- results_scurve/cnf-viz-09250.jpg +0 -0
- results_scurve/cnf-viz-09500.jpg +0 -0
- results_scurve/cnf-viz-09750.jpg +0 -0
- results_scurve/cnf-viz-10000.jpg +0 -0
- results_scurve/cnf-viz.gif +0 -0
- train_cnf.py +16 -2
README.md
CHANGED
@@ -44,8 +44,10 @@ For Continuous Normalizing Flow,
|
|
44 |
```bash
|
45 |
python main.py --model=cnf --sample_dataset=circles
|
46 |
```
|
|
|
47 |
|
48 |
# Sample Results
|
49 |
|
50 |
![cnf-viz](https://user-images.githubusercontent.com/72425253/126124351-44e00438-055e-4b1c-90ee-758a545dd602.gif)
|
51 |
![cnf-viz](https://user-images.githubusercontent.com/72425253/126124648-dcb3f8f4-396a-447c-96cf-f9304377fa48.gif)
|
|
|
|
44 |
```bash
|
45 |
python main.py --model=cnf --sample_dataset=circles
|
46 |
```
|
47 |
+
Sample datasets can be chosen as circles, moons, or scurve.
|
48 |
|
49 |
# Sample Results
|
50 |
|
51 |
![cnf-viz](https://user-images.githubusercontent.com/72425253/126124351-44e00438-055e-4b1c-90ee-758a545dd602.gif)
|
52 |
![cnf-viz](https://user-images.githubusercontent.com/72425253/126124648-dcb3f8f4-396a-447c-96cf-f9304377fa48.gif)
|
53 |
+
![cnf-viz](https://user-images.githubusercontent.com/72425253/126127269-4c02ee6a-a9a3-4b9f-b380-f8669f58872b.gif)
|
results_scurve/cnf-viz-00000.jpg
ADDED
results_scurve/cnf-viz-00250.jpg
ADDED
results_scurve/cnf-viz-00500.jpg
ADDED
results_scurve/cnf-viz-00750.jpg
ADDED
results_scurve/cnf-viz-01000.jpg
ADDED
results_scurve/cnf-viz-01250.jpg
ADDED
results_scurve/cnf-viz-01500.jpg
ADDED
results_scurve/cnf-viz-01750.jpg
ADDED
results_scurve/cnf-viz-02000.jpg
ADDED
results_scurve/cnf-viz-02250.jpg
ADDED
results_scurve/cnf-viz-02500.jpg
ADDED
results_scurve/cnf-viz-02750.jpg
ADDED
results_scurve/cnf-viz-03000.jpg
ADDED
results_scurve/cnf-viz-03250.jpg
ADDED
results_scurve/cnf-viz-03500.jpg
ADDED
results_scurve/cnf-viz-03750.jpg
ADDED
results_scurve/cnf-viz-04000.jpg
ADDED
results_scurve/cnf-viz-04250.jpg
ADDED
results_scurve/cnf-viz-04500.jpg
ADDED
results_scurve/cnf-viz-04750.jpg
ADDED
results_scurve/cnf-viz-05000.jpg
ADDED
results_scurve/cnf-viz-05250.jpg
ADDED
results_scurve/cnf-viz-05500.jpg
ADDED
results_scurve/cnf-viz-05750.jpg
ADDED
results_scurve/cnf-viz-06000.jpg
ADDED
results_scurve/cnf-viz-06250.jpg
ADDED
results_scurve/cnf-viz-06500.jpg
ADDED
results_scurve/cnf-viz-06750.jpg
ADDED
results_scurve/cnf-viz-07000.jpg
ADDED
results_scurve/cnf-viz-07250.jpg
ADDED
results_scurve/cnf-viz-07500.jpg
ADDED
results_scurve/cnf-viz-07750.jpg
ADDED
results_scurve/cnf-viz-08000.jpg
ADDED
results_scurve/cnf-viz-08250.jpg
ADDED
results_scurve/cnf-viz-08500.jpg
ADDED
results_scurve/cnf-viz-08750.jpg
ADDED
results_scurve/cnf-viz-09000.jpg
ADDED
results_scurve/cnf-viz-09250.jpg
ADDED
results_scurve/cnf-viz-09500.jpg
ADDED
results_scurve/cnf-viz-09750.jpg
ADDED
results_scurve/cnf-viz-10000.jpg
ADDED
results_scurve/cnf-viz.gif
ADDED
train_cnf.py
CHANGED
@@ -16,7 +16,7 @@ from flax.core import freeze, unfreeze
|
|
16 |
from flax import linen as nn
|
17 |
from flax import serialization
|
18 |
import optax
|
19 |
-
from sklearn.datasets import make_circles, make_moons
|
20 |
from tqdm import tqdm
|
21 |
|
22 |
|
@@ -116,6 +116,16 @@ def get_batch_moons(num_samples):
|
|
116 |
return lax.concatenate((x, logp_diff_t1), 1)
|
117 |
|
118 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
119 |
def multivariate_normal(z):
|
120 |
"""
|
121 |
Log probability of multivariate_normal.
|
@@ -187,6 +197,8 @@ def train(learning_rate, n_iters, batch_size, in_out_dim, hidden_dim, width, t0,
|
|
187 |
get_batch = lambda num_samples: get_batch_circles(num_samples)
|
188 |
elif dataset == "moons":
|
189 |
get_batch = lambda num_samples: get_batch_moons(num_samples)
|
|
|
|
|
190 |
|
191 |
for itr in range(1, n_iters+1):
|
192 |
batch = get_batch(batch_size)
|
@@ -222,6 +234,8 @@ def viz(neg_params, pos_params, in_out_dim, hidden_dim, width, t0, t1, dataset):
|
|
222 |
get_batch = lambda num_samples: get_batch_circles(num_samples)
|
223 |
elif dataset == "moons":
|
224 |
get_batch = lambda num_samples: get_batch_moons(num_samples)
|
|
|
|
|
225 |
target_sample = get_batch(viz_samples)[:, :2]
|
226 |
|
227 |
if not os.path.exists('results_%s/' % dataset):
|
@@ -299,4 +313,4 @@ def create_plots(z_t_samples, z_t_density, logp_diff_t, t0, t1, viz_timesteps, t
|
|
299 |
|
300 |
|
301 |
if __name__ == '__main__':
|
302 |
-
train(0.001,
|
|
|
16 |
from flax import linen as nn
|
17 |
from flax import serialization
|
18 |
import optax
|
19 |
+
from sklearn.datasets import make_circles, make_moons, make_s_curve
|
20 |
from tqdm import tqdm
|
21 |
|
22 |
|
|
|
116 |
return lax.concatenate((x, logp_diff_t1), 1)
|
117 |
|
118 |
|
119 |
+
def get_batch_scurve(num_samples):
|
120 |
+
points, _ = make_s_curve(n_samples=num_samples, noise=0.05, random_state=0)
|
121 |
+
x1 = jnp.array(points, dtype=jnp.float32)[:, :1]
|
122 |
+
x2 = jnp.array(points, dtype=jnp.float32)[:, 2:]
|
123 |
+
x = lax.concatenate((x1, x2), 1)
|
124 |
+
logp_diff_t1 = jnp.zeros((num_samples, 1), dtype=jnp.float32)
|
125 |
+
|
126 |
+
return lax.concatenate((x, logp_diff_t1), 1)
|
127 |
+
|
128 |
+
|
129 |
def multivariate_normal(z):
|
130 |
"""
|
131 |
Log probability of multivariate_normal.
|
|
|
197 |
get_batch = lambda num_samples: get_batch_circles(num_samples)
|
198 |
elif dataset == "moons":
|
199 |
get_batch = lambda num_samples: get_batch_moons(num_samples)
|
200 |
+
elif dataset == "scurve":
|
201 |
+
get_batch = lambda num_samples: get_batch_scurve(num_samples)
|
202 |
|
203 |
for itr in range(1, n_iters+1):
|
204 |
batch = get_batch(batch_size)
|
|
|
234 |
get_batch = lambda num_samples: get_batch_circles(num_samples)
|
235 |
elif dataset == "moons":
|
236 |
get_batch = lambda num_samples: get_batch_moons(num_samples)
|
237 |
+
elif dataset == "scurve":
|
238 |
+
get_batch = lambda num_samples: get_batch_scurve(num_samples)
|
239 |
target_sample = get_batch(viz_samples)[:, :2]
|
240 |
|
241 |
if not os.path.exists('results_%s/' % dataset):
|
|
|
313 |
|
314 |
|
315 |
if __name__ == '__main__':
|
316 |
+
train(0.001, 1000, 512, 2, 32, 64, 0., 10., True, 'scurve')
|