sw32-seo commited on
Commit
5966102
1 Parent(s): 2e90bae

Example scurve is added.

Browse files
Files changed (44) hide show
  1. README.md +2 -0
  2. results_scurve/cnf-viz-00000.jpg +0 -0
  3. results_scurve/cnf-viz-00250.jpg +0 -0
  4. results_scurve/cnf-viz-00500.jpg +0 -0
  5. results_scurve/cnf-viz-00750.jpg +0 -0
  6. results_scurve/cnf-viz-01000.jpg +0 -0
  7. results_scurve/cnf-viz-01250.jpg +0 -0
  8. results_scurve/cnf-viz-01500.jpg +0 -0
  9. results_scurve/cnf-viz-01750.jpg +0 -0
  10. results_scurve/cnf-viz-02000.jpg +0 -0
  11. results_scurve/cnf-viz-02250.jpg +0 -0
  12. results_scurve/cnf-viz-02500.jpg +0 -0
  13. results_scurve/cnf-viz-02750.jpg +0 -0
  14. results_scurve/cnf-viz-03000.jpg +0 -0
  15. results_scurve/cnf-viz-03250.jpg +0 -0
  16. results_scurve/cnf-viz-03500.jpg +0 -0
  17. results_scurve/cnf-viz-03750.jpg +0 -0
  18. results_scurve/cnf-viz-04000.jpg +0 -0
  19. results_scurve/cnf-viz-04250.jpg +0 -0
  20. results_scurve/cnf-viz-04500.jpg +0 -0
  21. results_scurve/cnf-viz-04750.jpg +0 -0
  22. results_scurve/cnf-viz-05000.jpg +0 -0
  23. results_scurve/cnf-viz-05250.jpg +0 -0
  24. results_scurve/cnf-viz-05500.jpg +0 -0
  25. results_scurve/cnf-viz-05750.jpg +0 -0
  26. results_scurve/cnf-viz-06000.jpg +0 -0
  27. results_scurve/cnf-viz-06250.jpg +0 -0
  28. results_scurve/cnf-viz-06500.jpg +0 -0
  29. results_scurve/cnf-viz-06750.jpg +0 -0
  30. results_scurve/cnf-viz-07000.jpg +0 -0
  31. results_scurve/cnf-viz-07250.jpg +0 -0
  32. results_scurve/cnf-viz-07500.jpg +0 -0
  33. results_scurve/cnf-viz-07750.jpg +0 -0
  34. results_scurve/cnf-viz-08000.jpg +0 -0
  35. results_scurve/cnf-viz-08250.jpg +0 -0
  36. results_scurve/cnf-viz-08500.jpg +0 -0
  37. results_scurve/cnf-viz-08750.jpg +0 -0
  38. results_scurve/cnf-viz-09000.jpg +0 -0
  39. results_scurve/cnf-viz-09250.jpg +0 -0
  40. results_scurve/cnf-viz-09500.jpg +0 -0
  41. results_scurve/cnf-viz-09750.jpg +0 -0
  42. results_scurve/cnf-viz-10000.jpg +0 -0
  43. results_scurve/cnf-viz.gif +0 -0
  44. train_cnf.py +16 -2
README.md CHANGED
@@ -44,8 +44,10 @@ For Continuous Normalizing Flow,
44
  ```bash
45
  python main.py --model=cnf --sample_dataset=circles
46
  ```
 
47
 
48
  # Sample Results
49
 
50
  ![cnf-viz](https://user-images.githubusercontent.com/72425253/126124351-44e00438-055e-4b1c-90ee-758a545dd602.gif)
51
  ![cnf-viz](https://user-images.githubusercontent.com/72425253/126124648-dcb3f8f4-396a-447c-96cf-f9304377fa48.gif)
 
 
44
  ```bash
45
  python main.py --model=cnf --sample_dataset=circles
46
  ```
47
+ Sample datasets can be chosen as circles, moons, or scurve.
48
 
49
  # Sample Results
50
 
51
  ![cnf-viz](https://user-images.githubusercontent.com/72425253/126124351-44e00438-055e-4b1c-90ee-758a545dd602.gif)
52
  ![cnf-viz](https://user-images.githubusercontent.com/72425253/126124648-dcb3f8f4-396a-447c-96cf-f9304377fa48.gif)
53
+ ![cnf-viz](https://user-images.githubusercontent.com/72425253/126127269-4c02ee6a-a9a3-4b9f-b380-f8669f58872b.gif)
results_scurve/cnf-viz-00000.jpg ADDED
results_scurve/cnf-viz-00250.jpg ADDED
results_scurve/cnf-viz-00500.jpg ADDED
results_scurve/cnf-viz-00750.jpg ADDED
results_scurve/cnf-viz-01000.jpg ADDED
results_scurve/cnf-viz-01250.jpg ADDED
results_scurve/cnf-viz-01500.jpg ADDED
results_scurve/cnf-viz-01750.jpg ADDED
results_scurve/cnf-viz-02000.jpg ADDED
results_scurve/cnf-viz-02250.jpg ADDED
results_scurve/cnf-viz-02500.jpg ADDED
results_scurve/cnf-viz-02750.jpg ADDED
results_scurve/cnf-viz-03000.jpg ADDED
results_scurve/cnf-viz-03250.jpg ADDED
results_scurve/cnf-viz-03500.jpg ADDED
results_scurve/cnf-viz-03750.jpg ADDED
results_scurve/cnf-viz-04000.jpg ADDED
results_scurve/cnf-viz-04250.jpg ADDED
results_scurve/cnf-viz-04500.jpg ADDED
results_scurve/cnf-viz-04750.jpg ADDED
results_scurve/cnf-viz-05000.jpg ADDED
results_scurve/cnf-viz-05250.jpg ADDED
results_scurve/cnf-viz-05500.jpg ADDED
results_scurve/cnf-viz-05750.jpg ADDED
results_scurve/cnf-viz-06000.jpg ADDED
results_scurve/cnf-viz-06250.jpg ADDED
results_scurve/cnf-viz-06500.jpg ADDED
results_scurve/cnf-viz-06750.jpg ADDED
results_scurve/cnf-viz-07000.jpg ADDED
results_scurve/cnf-viz-07250.jpg ADDED
results_scurve/cnf-viz-07500.jpg ADDED
results_scurve/cnf-viz-07750.jpg ADDED
results_scurve/cnf-viz-08000.jpg ADDED
results_scurve/cnf-viz-08250.jpg ADDED
results_scurve/cnf-viz-08500.jpg ADDED
results_scurve/cnf-viz-08750.jpg ADDED
results_scurve/cnf-viz-09000.jpg ADDED
results_scurve/cnf-viz-09250.jpg ADDED
results_scurve/cnf-viz-09500.jpg ADDED
results_scurve/cnf-viz-09750.jpg ADDED
results_scurve/cnf-viz-10000.jpg ADDED
results_scurve/cnf-viz.gif ADDED
train_cnf.py CHANGED
@@ -16,7 +16,7 @@ from flax.core import freeze, unfreeze
16
  from flax import linen as nn
17
  from flax import serialization
18
  import optax
19
- from sklearn.datasets import make_circles, make_moons
20
  from tqdm import tqdm
21
 
22
 
@@ -116,6 +116,16 @@ def get_batch_moons(num_samples):
116
  return lax.concatenate((x, logp_diff_t1), 1)
117
 
118
 
 
 
 
 
 
 
 
 
 
 
119
  def multivariate_normal(z):
120
  """
121
  Log probability of multivariate_normal.
@@ -187,6 +197,8 @@ def train(learning_rate, n_iters, batch_size, in_out_dim, hidden_dim, width, t0,
187
  get_batch = lambda num_samples: get_batch_circles(num_samples)
188
  elif dataset == "moons":
189
  get_batch = lambda num_samples: get_batch_moons(num_samples)
 
 
190
 
191
  for itr in range(1, n_iters+1):
192
  batch = get_batch(batch_size)
@@ -222,6 +234,8 @@ def viz(neg_params, pos_params, in_out_dim, hidden_dim, width, t0, t1, dataset):
222
  get_batch = lambda num_samples: get_batch_circles(num_samples)
223
  elif dataset == "moons":
224
  get_batch = lambda num_samples: get_batch_moons(num_samples)
 
 
225
  target_sample = get_batch(viz_samples)[:, :2]
226
 
227
  if not os.path.exists('results_%s/' % dataset):
@@ -299,4 +313,4 @@ def create_plots(z_t_samples, z_t_density, logp_diff_t, t0, t1, viz_timesteps, t
299
 
300
 
301
  if __name__ == '__main__':
302
- train(0.001, 1, 512, 2, 32, 64, 0., 10., True, 'circles')
 
16
  from flax import linen as nn
17
  from flax import serialization
18
  import optax
19
+ from sklearn.datasets import make_circles, make_moons, make_s_curve
20
  from tqdm import tqdm
21
 
22
 
 
116
  return lax.concatenate((x, logp_diff_t1), 1)
117
 
118
 
119
+ def get_batch_scurve(num_samples):
120
+ points, _ = make_s_curve(n_samples=num_samples, noise=0.05, random_state=0)
121
+ x1 = jnp.array(points, dtype=jnp.float32)[:, :1]
122
+ x2 = jnp.array(points, dtype=jnp.float32)[:, 2:]
123
+ x = lax.concatenate((x1, x2), 1)
124
+ logp_diff_t1 = jnp.zeros((num_samples, 1), dtype=jnp.float32)
125
+
126
+ return lax.concatenate((x, logp_diff_t1), 1)
127
+
128
+
129
  def multivariate_normal(z):
130
  """
131
  Log probability of multivariate_normal.
 
197
  get_batch = lambda num_samples: get_batch_circles(num_samples)
198
  elif dataset == "moons":
199
  get_batch = lambda num_samples: get_batch_moons(num_samples)
200
+ elif dataset == "scurve":
201
+ get_batch = lambda num_samples: get_batch_scurve(num_samples)
202
 
203
  for itr in range(1, n_iters+1):
204
  batch = get_batch(batch_size)
 
234
  get_batch = lambda num_samples: get_batch_circles(num_samples)
235
  elif dataset == "moons":
236
  get_batch = lambda num_samples: get_batch_moons(num_samples)
237
+ elif dataset == "scurve":
238
+ get_batch = lambda num_samples: get_batch_scurve(num_samples)
239
  target_sample = get_batch(viz_samples)[:, :2]
240
 
241
  if not os.path.exists('results_%s/' % dataset):
 
313
 
314
 
315
  if __name__ == '__main__':
316
+ train(0.001, 1000, 512, 2, 32, 64, 0., 10., True, 'scurve')