sw32-seo commited on
Commit
2e90bae
1 Parent(s): 0db4000

Wrong log prob example behavior is fixed

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. README.md +2 -3
  2. cnf_torch.py +4 -4
  3. main.py +1 -1
  4. results_circles/cnf-viz-00000.jpg +0 -0
  5. results_circles/cnf-viz-00250.jpg +0 -0
  6. results_circles/cnf-viz-00500.jpg +0 -0
  7. results_circles/cnf-viz-00750.jpg +0 -0
  8. results_circles/cnf-viz-01000.jpg +0 -0
  9. results_circles/cnf-viz-01250.jpg +0 -0
  10. results_circles/cnf-viz-01500.jpg +0 -0
  11. results_circles/cnf-viz-01750.jpg +0 -0
  12. results_circles/cnf-viz-02000.jpg +0 -0
  13. results_circles/cnf-viz-02250.jpg +0 -0
  14. results_circles/cnf-viz-02500.jpg +0 -0
  15. results_circles/cnf-viz-02750.jpg +0 -0
  16. results_circles/cnf-viz-03000.jpg +0 -0
  17. results_circles/cnf-viz-03250.jpg +0 -0
  18. results_circles/cnf-viz-03500.jpg +0 -0
  19. results_circles/cnf-viz-03750.jpg +0 -0
  20. results_circles/cnf-viz-04000.jpg +0 -0
  21. results_circles/cnf-viz-04250.jpg +0 -0
  22. results_circles/cnf-viz-04500.jpg +0 -0
  23. results_circles/cnf-viz-04750.jpg +0 -0
  24. results_circles/cnf-viz-05000.jpg +0 -0
  25. results_circles/cnf-viz-05250.jpg +0 -0
  26. results_circles/cnf-viz-05500.jpg +0 -0
  27. results_circles/cnf-viz-05750.jpg +0 -0
  28. results_circles/cnf-viz-06000.jpg +0 -0
  29. results_circles/cnf-viz-06250.jpg +0 -0
  30. results_circles/cnf-viz-06500.jpg +0 -0
  31. results_circles/cnf-viz-06750.jpg +0 -0
  32. results_circles/cnf-viz-07000.jpg +0 -0
  33. results_circles/cnf-viz-07250.jpg +0 -0
  34. results_circles/cnf-viz-07500.jpg +0 -0
  35. results_circles/cnf-viz-07750.jpg +0 -0
  36. results_circles/cnf-viz-08000.jpg +0 -0
  37. results_circles/cnf-viz-08250.jpg +0 -0
  38. results_circles/cnf-viz-08500.jpg +0 -0
  39. results_circles/cnf-viz-08750.jpg +0 -0
  40. results_circles/cnf-viz-09000.jpg +0 -0
  41. results_circles/cnf-viz-09250.jpg +0 -0
  42. results_circles/cnf-viz-09500.jpg +0 -0
  43. results_circles/cnf-viz-09750.jpg +0 -0
  44. results_circles/cnf-viz-10000.jpg +0 -0
  45. results_circles/cnf-viz.gif +0 -0
  46. results_moons/cnf-viz-00000.jpg +0 -0
  47. results_moons/cnf-viz-00250.jpg +0 -0
  48. results_moons/cnf-viz-00500.jpg +0 -0
  49. results_moons/cnf-viz-00750.jpg +0 -0
  50. results_moons/cnf-viz-01000.jpg +0 -0
README.md CHANGED
@@ -47,6 +47,5 @@ python main.py --model=cnf --sample_dataset=circles
47
 
48
  # Sample Results
49
 
50
- ![cnf-viz](https://user-images.githubusercontent.com/72425253/126116823-a014f13a-1171-4309-898f-0b6aedd84649.gif)
51
- ![cnf-viz](https://user-images.githubusercontent.com/72425253/126117205-fa68c16b-fba1-48a0-a965-3ac6cb5e201c.gif)
52
-
 
47
 
48
  # Sample Results
49
 
50
+ ![cnf-viz](https://user-images.githubusercontent.com/72425253/126124351-44e00438-055e-4b1c-90ee-758a545dd602.gif)
51
+ ![cnf-viz](https://user-images.githubusercontent.com/72425253/126124648-dcb3f8f4-396a-447c-96cf-f9304377fa48.gif)
 
cnf_torch.py CHANGED
@@ -15,7 +15,7 @@ import torch.optim as optim
15
  parser = argparse.ArgumentParser()
16
  parser.add_argument('--adjoint', action='store_true')
17
  parser.add_argument('--viz', action='store_true', default=True)
18
- parser.add_argument('--niters', type=int, default=10000)
19
  parser.add_argument('--lr', type=float, default=1e-3)
20
  parser.add_argument('--num_samples', type=int, default=512)
21
  parser.add_argument('--width', type=int, default=64)
@@ -148,9 +148,9 @@ if __name__ == '__main__':
148
 
149
  # model
150
  func = CNF(in_out_dim=2, hidden_dim=args.hidden_dim, width=args.width).to(device)
151
- for param in func.parameters():
152
- nn.init.constant_(param, 0.1)
153
- func(torch.tensor(0.).to(device), (torch.tensor([[0., 1.], [2., 3.], [4., 5.]]).to(device), torch.zeros((2, 1)).to(device)))
154
  optimizer = optim.Adam(func.parameters(), lr=args.lr)
155
  p_z0 = torch.distributions.MultivariateNormal(
156
  loc=torch.tensor([0.0, 0.0]).to(device),
 
15
  parser = argparse.ArgumentParser()
16
  parser.add_argument('--adjoint', action='store_true')
17
  parser.add_argument('--viz', action='store_true', default=True)
18
+ parser.add_argument('--niters', type=int, default=1)
19
  parser.add_argument('--lr', type=float, default=1e-3)
20
  parser.add_argument('--num_samples', type=int, default=512)
21
  parser.add_argument('--width', type=int, default=64)
 
148
 
149
  # model
150
  func = CNF(in_out_dim=2, hidden_dim=args.hidden_dim, width=args.width).to(device)
151
+ # for param in func.parameters():
152
+ # nn.init.constant_(param, 0.1)
153
+ # func(torch.tensor(0.).to(device), (torch.tensor([[0., 1.], [2., 3.], [4., 5.]]).to(device), torch.zeros((2, 1)).to(device)))
154
  optimizer = optim.Adam(func.parameters(), lr=args.lr)
155
  p_z0 = torch.distributions.MultivariateNormal(
156
  loc=torch.tensor([0.0, 0.0]).to(device),
main.py CHANGED
@@ -22,7 +22,7 @@ if __name__ == '__main__':
22
  parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate")
23
  parser.add_argument("--n_epoch", type=int, default=10, help="Total number of epoch")
24
  parser.add_argument("--batch_size", type=int, default=32, help="Number of images in batch")
25
- parser.add_argument("--sample_dataset", type=str, choices=['circels', 'moons'], default="circles",
26
  help="Sample dataset")
27
  parser.add_argument("--viz", action='store_true')
28
 
 
22
  parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate")
23
  parser.add_argument("--n_epoch", type=int, default=10, help="Total number of epoch")
24
  parser.add_argument("--batch_size", type=int, default=32, help="Number of images in batch")
25
+ parser.add_argument("--sample_dataset", type=str, choices=['circles', 'moons'], default="circles",
26
  help="Sample dataset")
27
  parser.add_argument("--viz", action='store_true')
28
 
results_circles/cnf-viz-00000.jpg CHANGED
results_circles/cnf-viz-00250.jpg CHANGED
results_circles/cnf-viz-00500.jpg CHANGED
results_circles/cnf-viz-00750.jpg CHANGED
results_circles/cnf-viz-01000.jpg CHANGED
results_circles/cnf-viz-01250.jpg CHANGED
results_circles/cnf-viz-01500.jpg CHANGED
results_circles/cnf-viz-01750.jpg CHANGED
results_circles/cnf-viz-02000.jpg CHANGED
results_circles/cnf-viz-02250.jpg CHANGED
results_circles/cnf-viz-02500.jpg CHANGED
results_circles/cnf-viz-02750.jpg CHANGED
results_circles/cnf-viz-03000.jpg CHANGED
results_circles/cnf-viz-03250.jpg CHANGED
results_circles/cnf-viz-03500.jpg CHANGED
results_circles/cnf-viz-03750.jpg CHANGED
results_circles/cnf-viz-04000.jpg CHANGED
results_circles/cnf-viz-04250.jpg CHANGED
results_circles/cnf-viz-04500.jpg CHANGED
results_circles/cnf-viz-04750.jpg CHANGED
results_circles/cnf-viz-05000.jpg CHANGED
results_circles/cnf-viz-05250.jpg CHANGED
results_circles/cnf-viz-05500.jpg CHANGED
results_circles/cnf-viz-05750.jpg CHANGED
results_circles/cnf-viz-06000.jpg CHANGED
results_circles/cnf-viz-06250.jpg CHANGED
results_circles/cnf-viz-06500.jpg CHANGED
results_circles/cnf-viz-06750.jpg CHANGED
results_circles/cnf-viz-07000.jpg CHANGED
results_circles/cnf-viz-07250.jpg CHANGED
results_circles/cnf-viz-07500.jpg CHANGED
results_circles/cnf-viz-07750.jpg CHANGED
results_circles/cnf-viz-08000.jpg CHANGED
results_circles/cnf-viz-08250.jpg CHANGED
results_circles/cnf-viz-08500.jpg CHANGED
results_circles/cnf-viz-08750.jpg CHANGED
results_circles/cnf-viz-09000.jpg CHANGED
results_circles/cnf-viz-09250.jpg CHANGED
results_circles/cnf-viz-09500.jpg CHANGED
results_circles/cnf-viz-09750.jpg CHANGED
results_circles/cnf-viz-10000.jpg CHANGED
results_circles/cnf-viz.gif CHANGED
results_moons/cnf-viz-00000.jpg CHANGED
results_moons/cnf-viz-00250.jpg CHANGED
results_moons/cnf-viz-00500.jpg CHANGED
results_moons/cnf-viz-00750.jpg CHANGED
results_moons/cnf-viz-01000.jpg CHANGED