Wrong log prob example behavior is fixed
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- README.md +2 -3
- cnf_torch.py +4 -4
- main.py +1 -1
- results_circles/cnf-viz-00000.jpg +0 -0
- results_circles/cnf-viz-00250.jpg +0 -0
- results_circles/cnf-viz-00500.jpg +0 -0
- results_circles/cnf-viz-00750.jpg +0 -0
- results_circles/cnf-viz-01000.jpg +0 -0
- results_circles/cnf-viz-01250.jpg +0 -0
- results_circles/cnf-viz-01500.jpg +0 -0
- results_circles/cnf-viz-01750.jpg +0 -0
- results_circles/cnf-viz-02000.jpg +0 -0
- results_circles/cnf-viz-02250.jpg +0 -0
- results_circles/cnf-viz-02500.jpg +0 -0
- results_circles/cnf-viz-02750.jpg +0 -0
- results_circles/cnf-viz-03000.jpg +0 -0
- results_circles/cnf-viz-03250.jpg +0 -0
- results_circles/cnf-viz-03500.jpg +0 -0
- results_circles/cnf-viz-03750.jpg +0 -0
- results_circles/cnf-viz-04000.jpg +0 -0
- results_circles/cnf-viz-04250.jpg +0 -0
- results_circles/cnf-viz-04500.jpg +0 -0
- results_circles/cnf-viz-04750.jpg +0 -0
- results_circles/cnf-viz-05000.jpg +0 -0
- results_circles/cnf-viz-05250.jpg +0 -0
- results_circles/cnf-viz-05500.jpg +0 -0
- results_circles/cnf-viz-05750.jpg +0 -0
- results_circles/cnf-viz-06000.jpg +0 -0
- results_circles/cnf-viz-06250.jpg +0 -0
- results_circles/cnf-viz-06500.jpg +0 -0
- results_circles/cnf-viz-06750.jpg +0 -0
- results_circles/cnf-viz-07000.jpg +0 -0
- results_circles/cnf-viz-07250.jpg +0 -0
- results_circles/cnf-viz-07500.jpg +0 -0
- results_circles/cnf-viz-07750.jpg +0 -0
- results_circles/cnf-viz-08000.jpg +0 -0
- results_circles/cnf-viz-08250.jpg +0 -0
- results_circles/cnf-viz-08500.jpg +0 -0
- results_circles/cnf-viz-08750.jpg +0 -0
- results_circles/cnf-viz-09000.jpg +0 -0
- results_circles/cnf-viz-09250.jpg +0 -0
- results_circles/cnf-viz-09500.jpg +0 -0
- results_circles/cnf-viz-09750.jpg +0 -0
- results_circles/cnf-viz-10000.jpg +0 -0
- results_circles/cnf-viz.gif +0 -0
- results_moons/cnf-viz-00000.jpg +0 -0
- results_moons/cnf-viz-00250.jpg +0 -0
- results_moons/cnf-viz-00500.jpg +0 -0
- results_moons/cnf-viz-00750.jpg +0 -0
- results_moons/cnf-viz-01000.jpg +0 -0
README.md
CHANGED
@@ -47,6 +47,5 @@ python main.py --model=cnf --sample_dataset=circles
|
|
47 |
|
48 |
# Sample Results
|
49 |
|
50 |
-
![cnf-viz](https://user-images.githubusercontent.com/72425253/
|
51 |
-
![cnf-viz](https://user-images.githubusercontent.com/72425253/
|
52 |
-
|
|
|
47 |
|
48 |
# Sample Results
|
49 |
|
50 |
+
![cnf-viz](https://user-images.githubusercontent.com/72425253/126124351-44e00438-055e-4b1c-90ee-758a545dd602.gif)
|
51 |
+
![cnf-viz](https://user-images.githubusercontent.com/72425253/126124648-dcb3f8f4-396a-447c-96cf-f9304377fa48.gif)
|
|
cnf_torch.py
CHANGED
@@ -15,7 +15,7 @@ import torch.optim as optim
|
|
15 |
parser = argparse.ArgumentParser()
|
16 |
parser.add_argument('--adjoint', action='store_true')
|
17 |
parser.add_argument('--viz', action='store_true', default=True)
|
18 |
-
parser.add_argument('--niters', type=int, default=
|
19 |
parser.add_argument('--lr', type=float, default=1e-3)
|
20 |
parser.add_argument('--num_samples', type=int, default=512)
|
21 |
parser.add_argument('--width', type=int, default=64)
|
@@ -148,9 +148,9 @@ if __name__ == '__main__':
|
|
148 |
|
149 |
# model
|
150 |
func = CNF(in_out_dim=2, hidden_dim=args.hidden_dim, width=args.width).to(device)
|
151 |
-
for param in func.parameters():
|
152 |
-
|
153 |
-
func(torch.tensor(0.).to(device), (torch.tensor([[0., 1.], [2., 3.], [4., 5.]]).to(device), torch.zeros((2, 1)).to(device)))
|
154 |
optimizer = optim.Adam(func.parameters(), lr=args.lr)
|
155 |
p_z0 = torch.distributions.MultivariateNormal(
|
156 |
loc=torch.tensor([0.0, 0.0]).to(device),
|
|
|
15 |
parser = argparse.ArgumentParser()
|
16 |
parser.add_argument('--adjoint', action='store_true')
|
17 |
parser.add_argument('--viz', action='store_true', default=True)
|
18 |
+
parser.add_argument('--niters', type=int, default=1)
|
19 |
parser.add_argument('--lr', type=float, default=1e-3)
|
20 |
parser.add_argument('--num_samples', type=int, default=512)
|
21 |
parser.add_argument('--width', type=int, default=64)
|
|
|
148 |
|
149 |
# model
|
150 |
func = CNF(in_out_dim=2, hidden_dim=args.hidden_dim, width=args.width).to(device)
|
151 |
+
# for param in func.parameters():
|
152 |
+
# nn.init.constant_(param, 0.1)
|
153 |
+
# func(torch.tensor(0.).to(device), (torch.tensor([[0., 1.], [2., 3.], [4., 5.]]).to(device), torch.zeros((2, 1)).to(device)))
|
154 |
optimizer = optim.Adam(func.parameters(), lr=args.lr)
|
155 |
p_z0 = torch.distributions.MultivariateNormal(
|
156 |
loc=torch.tensor([0.0, 0.0]).to(device),
|
main.py
CHANGED
@@ -22,7 +22,7 @@ if __name__ == '__main__':
|
|
22 |
parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate")
|
23 |
parser.add_argument("--n_epoch", type=int, default=10, help="Total number of epoch")
|
24 |
parser.add_argument("--batch_size", type=int, default=32, help="Number of images in batch")
|
25 |
-
parser.add_argument("--sample_dataset", type=str, choices=['
|
26 |
help="Sample dataset")
|
27 |
parser.add_argument("--viz", action='store_true')
|
28 |
|
|
|
22 |
parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate")
|
23 |
parser.add_argument("--n_epoch", type=int, default=10, help="Total number of epoch")
|
24 |
parser.add_argument("--batch_size", type=int, default=32, help="Number of images in batch")
|
25 |
+
parser.add_argument("--sample_dataset", type=str, choices=['circles', 'moons'], default="circles",
|
26 |
help="Sample dataset")
|
27 |
parser.add_argument("--viz", action='store_true')
|
28 |
|
results_circles/cnf-viz-00000.jpg
CHANGED
results_circles/cnf-viz-00250.jpg
CHANGED
results_circles/cnf-viz-00500.jpg
CHANGED
results_circles/cnf-viz-00750.jpg
CHANGED
results_circles/cnf-viz-01000.jpg
CHANGED
results_circles/cnf-viz-01250.jpg
CHANGED
results_circles/cnf-viz-01500.jpg
CHANGED
results_circles/cnf-viz-01750.jpg
CHANGED
results_circles/cnf-viz-02000.jpg
CHANGED
results_circles/cnf-viz-02250.jpg
CHANGED
results_circles/cnf-viz-02500.jpg
CHANGED
results_circles/cnf-viz-02750.jpg
CHANGED
results_circles/cnf-viz-03000.jpg
CHANGED
results_circles/cnf-viz-03250.jpg
CHANGED
results_circles/cnf-viz-03500.jpg
CHANGED
results_circles/cnf-viz-03750.jpg
CHANGED
results_circles/cnf-viz-04000.jpg
CHANGED
results_circles/cnf-viz-04250.jpg
CHANGED
results_circles/cnf-viz-04500.jpg
CHANGED
results_circles/cnf-viz-04750.jpg
CHANGED
results_circles/cnf-viz-05000.jpg
CHANGED
results_circles/cnf-viz-05250.jpg
CHANGED
results_circles/cnf-viz-05500.jpg
CHANGED
results_circles/cnf-viz-05750.jpg
CHANGED
results_circles/cnf-viz-06000.jpg
CHANGED
results_circles/cnf-viz-06250.jpg
CHANGED
results_circles/cnf-viz-06500.jpg
CHANGED
results_circles/cnf-viz-06750.jpg
CHANGED
results_circles/cnf-viz-07000.jpg
CHANGED
results_circles/cnf-viz-07250.jpg
CHANGED
results_circles/cnf-viz-07500.jpg
CHANGED
results_circles/cnf-viz-07750.jpg
CHANGED
results_circles/cnf-viz-08000.jpg
CHANGED
results_circles/cnf-viz-08250.jpg
CHANGED
results_circles/cnf-viz-08500.jpg
CHANGED
results_circles/cnf-viz-08750.jpg
CHANGED
results_circles/cnf-viz-09000.jpg
CHANGED
results_circles/cnf-viz-09250.jpg
CHANGED
results_circles/cnf-viz-09500.jpg
CHANGED
results_circles/cnf-viz-09750.jpg
CHANGED
results_circles/cnf-viz-10000.jpg
CHANGED
results_circles/cnf-viz.gif
CHANGED
results_moons/cnf-viz-00000.jpg
CHANGED
results_moons/cnf-viz-00250.jpg
CHANGED
results_moons/cnf-viz-00500.jpg
CHANGED
results_moons/cnf-viz-00750.jpg
CHANGED
results_moons/cnf-viz-01000.jpg
CHANGED