Initial commit

Files changed:
- LICENSE          +21   -0
- README.md        +42   -0
- cnn.py           +219  -0
- cnn_ode.py       +256  -0
- jax_cnn_ode.py   +0    -0
- main.py          +23   -0
- mlp.py           +103  -0
- ode.py           +256  -0
- opts.py          +0    -0
- train.py         +130  -0
- train_cnf.py     +274  -0
- train_ode.py     +271  -0
- train_resnet.py  +195  -0
LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2021 Seung-woo Eric Seo

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
README.md
ADDED
@@ -0,0 +1,42 @@
# Neural ODE with Flax

This is the result of the project ["Reproduce Neural ODE and SDE"][projectlink] from the [HuggingFace Flax/JAX community week][comweeklink].

<code>main.py</code> trains a ResNet or an OdeNet on the MNIST dataset.

[projectlink]: https://discuss.huggingface.co/t/reproduce-neural-ode-and-neural-sde/7590

[comweeklink]: https://github.com/huggingface/transformers/tree/master/examples/research_projects/jax-projects#projects

## Dependency

### JAX and Flax

For JAX installation, please follow the instructions [here][jaxinstalllink],

or simply run
```bash
pip install jax jaxlib
```

For Flax installation,
```bash
pip install flax
```

[jaxinstalllink]: https://github.com/google/jax#installation


TensorFlow Datasets will download the MNIST dataset into the environment automatically.

## How to run training

For (small) ResNet training,
```bash
python main.py --model=resnet --lr=1e-4 --n_epoch=20 --batch_size=64
```

For Neural ODE training,
```bash
python main.py --model=odenet --lr=1e-4 --n_epoch=20 --batch_size=64
```
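
To verify the JAX install above before training, a quick sanity check (standard JAX calls, nothing repo-specific):

```python
# Quick check that jax/jaxlib are installed and a backend is visible.
import jax

print(jax.devices())                      # lists available CPU/GPU/TPU devices
print(jax.grad(lambda x: x ** 2)(3.0))    # autodiff smoke test -> 6.0
```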
cnn.py
ADDED
@@ -0,0 +1,219 @@
import jax
from typing import Any, Callable, Sequence, Optional
from jax import lax, random, vmap, numpy as jnp
from jax.experimental.ode import odeint
import flax
from flax.training import train_state
from flax.core import freeze, unfreeze
from flax import linen as nn
from flax import serialization
import optax
import tensorflow_datasets as tfds
import numpy as np


# Define model
class CNN(nn.Module):
    """A simple CNN model."""

    @nn.compact
    def __call__(self, inputs):
        x = inputs
        x = nn.Conv(features=32, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))

        x = nn.Conv(features=64, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))  # flatten

        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        x = nn.Dense(features=10)(x)
        x = nn.log_softmax(x)
        return x


# Define Residual Block
class ResBlock(nn.Module):
    """Single ResBlock w/o downsample"""

    @nn.compact
    def __call__(self, inputs):
        x = inputs
        f_x = nn.relu(nn.GroupNorm(64)(x))
        f_x = nn.Conv(features=64, kernel_size=(3, 3))(f_x)
        f_x = nn.relu(nn.GroupNorm(64)(f_x))
        f_x = nn.Conv(features=64, kernel_size=(3, 3))(f_x)
        x = f_x + x
        return x


class ResDownBlock(nn.Module):
    """Single ResBlock w/ downsample"""

    @nn.compact
    def __call__(self, inputs):
        x = inputs
        f_x = nn.relu(nn.GroupNorm(64)(x))
        x = nn.Conv(features=64, kernel_size=(1, 1), strides=(2, 2))(x)
        f_x = nn.Conv(features=64, kernel_size=(3, 3), strides=(2, 2))(f_x)
        f_x = nn.relu(nn.GroupNorm(64)(f_x))
        f_x = nn.Conv(features=64, kernel_size=(3, 3))(f_x)
        x = f_x + x
        return x


# Define Model for MNIST example in Neural ODE
class SmallResNet(nn.Module):
    res_down1: Callable = ResDownBlock()
    res_down2: Callable = ResDownBlock()
    resblock1: Callable = ResBlock()
    resblock2: Callable = ResBlock()
    resblock3: Callable = ResBlock()
    resblock4: Callable = ResBlock()
    resblock5: Callable = ResBlock()
    resblock6: Callable = ResBlock()

    @nn.compact
    def __call__(self, inputs):
        x = inputs
        x = nn.Conv(features=64, kernel_size=(3, 3))(x)
        x = self.res_down1(x)
        x = self.res_down2(x)

        x = self.resblock1(x)
        x = self.resblock2(x)
        x = self.resblock3(x)
        x = self.resblock4(x)
        x = self.resblock5(x)
        x = self.resblock6(x)

        x = nn.GroupNorm(64)(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, (1, 1))

        x = x.reshape((x.shape[0], -1))  # flatten

        x = nn.Dense(features=10)(x)
        x = nn.log_softmax(x)

        return x


# Define loss
def cross_entropy_loss(*, logits, labels):
    one_hot_labels = jax.nn.one_hot(labels, num_classes=10)
    return -jnp.mean(jnp.sum(one_hot_labels * logits, axis=-1))


# Metric computation
def compute_metrics(*, logits, labels):
    loss = cross_entropy_loss(logits=logits, labels=labels)
    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
    metrics = {
        'loss': loss,
        'accuracy': accuracy,
    }
    return metrics


def get_datasets():
    """Load MNIST train and test datasets into memory."""
    ds_builder = tfds.builder('mnist')
    ds_builder.download_and_prepare()
    train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))
    test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))
    train_ds['image'] = jnp.float32(train_ds['image']) / 255.
    test_ds['image'] = jnp.float32(test_ds['image']) / 255.
    return train_ds, test_ds


def create_train_state(rng, learning_rate):
    """Creates initial 'TrainState'."""
    cnn = SmallResNet()
    params = cnn.init(rng, jnp.ones([1, 28, 28, 1]))['params']
    tx = optax.adam(learning_rate)
    return train_state.TrainState.create(
        apply_fn=cnn.apply, params=params, tx=tx
    )


# Training step
@jax.jit
def train_step(state, batch):
    """Train for a single step."""
    def loss_fn(params):
        logits = SmallResNet().apply({'params': params}, batch['image'])
        loss = cross_entropy_loss(logits=logits, labels=batch['label'])
        return loss, logits
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (_, logits), grads = grad_fn(state.params)
    state = state.apply_gradients(grads=grads)
    metrics = compute_metrics(logits=logits, labels=batch['label'])
    return state, metrics


# Evaluation step
@jax.jit
def eval_step(params, batch):
    logits = SmallResNet().apply({'params': params}, batch['image'])
    return compute_metrics(logits=logits, labels=batch['label'])


# Train function
def train_epoch(state, train_ds, batch_size, epoch, rng):
    """Train for a single epoch"""
    train_ds_size = len(train_ds['image'])
    steps_per_epoch = train_ds_size // batch_size

    perms = jax.random.permutation(rng, len(train_ds['image']))
    perms = perms[:steps_per_epoch * batch_size]  # skip incomplete batch
    perms = perms.reshape((steps_per_epoch, batch_size))
    batch_metrics = []
    for perm in perms:
        batch = {k: v[perm, ...] for k, v in train_ds.items()}
        state, metrics = train_step(state, batch)
        batch_metrics.append(metrics)

    # compute mean of metrics across each batch in epoch.
    batch_metrics_np = jax.device_get(batch_metrics)
    epoch_metrics_np = {
        k: np.mean([metrics[k] for metrics in batch_metrics_np])
        for k in batch_metrics_np[0]
    }
    print('train epoch: %d, loss: %.4f, accuracy: %.2f' % (
        epoch, epoch_metrics_np['loss'], epoch_metrics_np['accuracy'] * 100
    ))

    return state


# Eval function
def eval_model(params, test_ds):
    metrics = eval_step(params, test_ds)
    metrics = jax.device_get(metrics)
    summary = jax.tree_map(lambda x: x.item(), metrics)
    return summary['loss'], summary['accuracy']


if __name__ == '__main__':
    train_ds, test_ds = get_datasets()
    rng = jax.random.PRNGKey(0)
    rng, init_rng = jax.random.split(rng)

    learning_rate = 0.0001

    state = create_train_state(init_rng, learning_rate)
    del init_rng  # Must not be used anymore.

    num_epochs = 40
    batch_size = 128

    for epoch in range(1, num_epochs + 1):
        rng, input_rng = jax.random.split(rng)
        state = train_epoch(state, train_ds, batch_size, epoch, input_rng)
        test_loss, test_accuracy = eval_model(state.params, test_ds)
        print(' test epoch: %d, loss: %.2f, accuracy: %.2f' % (
            epoch, test_loss, test_accuracy * 100
        ))
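
A minimal sketch (not part of this commit; assumes cnn.py is importable from the repo root) of how SmallResNet is initialized and applied, mirroring create_train_state and eval_step above:

```python
# Sketch: initialize SmallResNet and run one dummy MNIST-shaped batch.
import jax
from jax import numpy as jnp
from cnn import SmallResNet

model = SmallResNet()
rng = jax.random.PRNGKey(0)
x = jnp.ones([1, 28, 28, 1])                  # NHWC, as MNIST is loaded by get_datasets
params = model.init(rng, x)['params']         # same init call as create_train_state
logits = model.apply({'params': params}, x)   # log-probabilities, shape (1, 10)
print(logits.shape)
```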
cnn_ode.py
ADDED
@@ -0,0 +1,256 @@
from functools import partial
import jax
from typing import Any, Callable, Sequence, Optional, NewType
from jax import lax, random, vmap, numpy as jnp
from jax.experimental.ode import odeint
import flax
from flax.training import train_state
from flax import traverse_util
from flax.core import freeze, unfreeze
from flax import linen as nn
from flax import serialization
import optax
import tensorflow_datasets as tfds
import numpy as np
from tqdm import tqdm
import os


# TODO Add system argument for dim_out, ksize, tol, learning_rate, num_epoch and batch_size

# Define Residual Block
class ResDownBlock(nn.Module):
    """Single ResBlock w/ downsample"""
    dim_out: Any = 64

    @nn.compact
    def __call__(self, inputs):
        x = inputs
        f_x = nn.relu(nn.GroupNorm(self.dim_out)(x))
        x = nn.Conv(features=self.dim_out, kernel_size=(1, 1), strides=(2, 2))(x)
        f_x = nn.Conv(features=self.dim_out, kernel_size=(3, 3), strides=(2, 2))(f_x)
        f_x = nn.relu(nn.GroupNorm(self.dim_out)(f_x))
        f_x = nn.Conv(features=self.dim_out, kernel_size=(3, 3))(f_x)
        x = f_x + x
        return x


class ConcatConv2D(nn.Module):
    """Concat dynamics to hidden layer"""
    dim_out: Any = 64
    ksize: Any = 3

    @nn.compact
    def __call__(self, x, t):
        tt = jnp.ones_like(x[..., :1]) * t
        ttx = jnp.concatenate([tt, x], -1)
        return nn.Conv(features=self.dim_out, kernel_size=self.ksize)(ttx)


# Define Model for MNIST example in Neural ODE
class ODEfunc(nn.Module):
    """ODE function which replaces the ResNet blocks."""
    dim_out: Any = 64
    ksize: Any = 3

    @nn.compact
    def __call__(self, inputs, t):
        # TODO Count number of function evaluations
        # nfe_counter = NFEcounter()
        # nfe_counter()

        x = inputs
        out = nn.GroupNorm(self.dim_out)(x)
        out = nn.relu(out)
        out = ConcatConv2D(self.dim_out, self.ksize)(out, t)
        out = nn.GroupNorm(self.dim_out)(out)
        out = nn.relu(out)
        out = ConcatConv2D(self.dim_out, self.ksize)(out, t)
        out = nn.GroupNorm(self.dim_out)(out)

        return out


class NFEcounter(nn.Module):

    @nn.compact
    def __call__(self):
        is_initialized = self.has_variable('nfe', 'nfe')
        nfe = self.variable('nfe', 'nfe', jnp.array, [0])
        if is_initialized:
            nfe.value += 1


class ODEBlock(nn.Module):
    """ODE block which contains odeint"""
    tol = 1.

    @nn.compact
    def __call__(self, x, params):
        ode_func = ODEfunc()
        init_state, final_state = odeint(partial(ode_func.apply, {'params': params}),
                                         x, jnp.array([0., 1.]),
                                         rtol=self.tol, atol=self.tol)
        return final_state


class ODEBlockVmap(nn.Module):
    """Apply vmap to ODEBlock"""

    @nn.compact
    def __call__(self, x, params):
        vmap_odeblock = nn.vmap(ODEBlock,
                                variable_axes={'params': 0, 'nfe': None},
                                split_rngs={'params': True, 'nfe': False},
                                in_axes=(0, None))
        return vmap_odeblock(name='odeblock')(x, params)


class FullODENet(nn.Module):
    """Full ODE net which contains two downsampling layers, ODE block and linear classifier."""
    dim_out: Any = 64
    ksize: Any = 3

    @nn.compact
    def __call__(self, inputs):
        x = inputs
        x = nn.Conv(features=self.dim_out, kernel_size=(self.ksize, self.ksize))(x)
        x = ResDownBlock()(x)
        x = ResDownBlock()(x)

        ode_func = ODEfunc()
        init_fn = lambda rng, x: ode_func.init(random.split(rng)[-1], x, 0.)['params']
        ode_func_params = self.param('ode_func', init_fn, jnp.ones_like(x[0]))
        x = ODEBlockVmap()(x, ode_func_params)

        x = nn.GroupNorm(self.dim_out)(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, (1, 1))

        x = x.reshape((x.shape[0], -1))  # flatten

        x = nn.Dense(features=10)(x)
        x = nn.log_softmax(x)

        return x


# Define loss
@jax.jit
def cross_entropy_loss(logits, labels):
    one_hot_labels = jax.nn.one_hot(labels, num_classes=10)
    return -jnp.mean(jnp.sum(one_hot_labels * logits, axis=-1))


# Metric computation
@jax.jit
def compute_metrics(logits, labels):
    loss = cross_entropy_loss(logits=logits, labels=labels)
    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
    metrics = {
        'loss': loss,
        'accuracy': accuracy,
    }
    return metrics


def get_datasets():
    """Load MNIST train and test datasets into memory."""
    ds_builder = tfds.builder('mnist')
    ds_builder.download_and_prepare()
    train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))
    test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))
    train_ds['image'] = jnp.float32(train_ds['image']) / 255.
    test_ds['image'] = jnp.float32(test_ds['image']) / 255.
    return train_ds, test_ds


def create_train_state(rng, learning_rate):
    """Creates initial 'TrainState'."""
    cnn = FullODENet()
    params = cnn.init(rng, jnp.ones([1, 28, 28, 1]))['params']
    tx = optax.adam(learning_rate)
    return train_state.TrainState.create(
        apply_fn=cnn.apply, params=params, tx=tx
    )


# Training step
@jax.jit
def train_step(state, batch):
    """Train for a single step."""
    def loss_fn(params):
        logits = FullODENet().apply({'params': params}, batch['image'])
        loss = cross_entropy_loss(logits=logits, labels=batch['label'])
        return loss, logits
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (_, logits), grads = grad_fn(state.params)
    state = state.apply_gradients(grads=grads)
    metrics = compute_metrics(logits=logits, labels=batch['label'])
    return state, metrics


# Evaluation step
@jax.jit
def eval_step(params, batch):
    logits = FullODENet().apply({'params': params}, batch['image'])
    return compute_metrics(logits=logits, labels=batch['label'])


# Train function
def train_epoch(state, train_ds, batch_size, epoch, rng):
    """Train for a single epoch"""
    train_ds_size = len(train_ds['image'])
    steps_per_epoch = train_ds_size // batch_size

    perms = jax.random.permutation(rng, len(train_ds['image']))
    perms = perms[:steps_per_epoch * batch_size]  # skip incomplete batch
    perms = perms.reshape((steps_per_epoch, batch_size))
    batch_metrics = []
    for perm in tqdm(perms):
        batch = {k: v[perm, ...] for k, v in train_ds.items()}
        state, metrics = train_step(state, batch)
        batch_metrics.append(metrics)

    # compute mean of metrics across each batch in epoch.
    batch_metrics_np = jax.device_get(batch_metrics)
    epoch_metrics_np = {
        k: np.mean([metrics[k] for metrics in batch_metrics_np])
        for k in batch_metrics_np[0]
    }
    print('train epoch: %d, loss: %.4f, accuracy: %.2f' % (
        epoch, epoch_metrics_np['loss'], epoch_metrics_np['accuracy'] * 100
    ))

    return state


# Eval function
def eval_model(params, test_ds):
    metrics = eval_step(params, test_ds)
    metrics = jax.device_get(metrics)
    summary = jax.tree_map(lambda x: x.item(), metrics)
    return summary['loss'], summary['accuracy']


if __name__ == '__main__':
    train_ds, test_ds = get_datasets()
    rng = jax.random.PRNGKey(0)
    rng, init_rng = jax.random.split(rng)

    # Build learning rate decay as in the Neural ODE paper
    learning_rate = 0.0001

    state = create_train_state(init_rng, learning_rate)
    del init_rng  # Must not be used anymore.

    num_epochs = 20
    batch_size = 128

    for epoch in tqdm(range(1, num_epochs + 1)):
        rng, input_rng = jax.random.split(rng)
        state = train_epoch(state, train_ds, batch_size, epoch, input_rng)
        test_loss, test_accuracy = eval_model(state.params, test_ds)
        print(' test epoch: %d, loss: %.2f, accuracy: %.2f' % (
            epoch, test_loss, test_accuracy * 100
        ))
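
ODEBlock above delegates to jax.experimental.ode.odeint and keeps only the state at t=1. A tiny standalone sketch of the same call pattern on an illustrative scalar ODE (exponential decay, not the repo's dynamics):

```python
# Standalone sketch of the odeint pattern ODEBlock uses: integrate a
# dynamics function over t in [0, 1] and keep only the final state.
from jax import numpy as jnp
from jax.experimental.ode import odeint

def dynamics(y, t):
    return -y                      # dy/dt = -y, so y(t) = y0 * exp(-t)

y0 = jnp.array([1.0])
# odeint returns one state per requested time; unpack as in ODEBlock.
init_state, final_state = odeint(dynamics, y0, jnp.array([0., 1.]))
print(final_state)                 # ~exp(-1) = 0.3679
```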
jax_cnn_ode.py
ADDED
File without changes
main.py
ADDED
@@ -0,0 +1,23 @@
import argparse
import train_ode
import train_resnet


def main(args):
    if args.model == 'odenet':
        train_ode.train_and_evaluate(args.lr, args.n_epoch, args.batch_size, args.tol)
    else:
        train_resnet.train_and_evaluate(args.lr, args.n_epoch, args.batch_size)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='main.py')
    parser.add_argument("--model", type=str, choices=['odenet', 'resnet'], default="odenet", help="Type of model")
    parser.add_argument("--tol", type=float, default=1e-1,
                        help="Error tolerance for the ODE solver. This only works with odenet")
    parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate")
    parser.add_argument("--n_epoch", type=int, default=10, help="Total number of epochs")
    parser.add_argument("--batch_size", type=int, default=32, help="Number of images in a batch")

    args = parser.parse_args()
    main(args)
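
For reference, the programmatic equivalent of the odenet CLI path (a sketch; assumes the repo modules are on the Python path):

```python
# Sketch: the same call main.py makes for --model=odenet, without argparse.
import train_ode

# Positional args match the dispatch above: lr, n_epoch, batch_size, tol.
train_ode.train_and_evaluate(1e-4, 20, 64, 1e-1)
```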
mlp.py
ADDED
@@ -0,0 +1,103 @@
import jax
from typing import Any, Callable, Sequence, Optional
from jax import lax, random, numpy as jnp
import flax
from flax.training import train_state
from flax.core import freeze, unfreeze
from flax import linen as nn
from flax import serialization
import optax


class ExplicitMLP(nn.Module):
    features: Sequence[int]

    def setup(self):
        self.layers = [nn.Dense(feat) for feat in self.features]

    def __call__(self, inputs):
        x = inputs
        for i, lyr in enumerate(self.layers):
            x = lyr(x)
            if i != len(self.layers) - 1:
                x = nn.relu(x)
        return x


class SimpleMLP(nn.Module):
    features: Sequence[int]

    @nn.compact
    def __call__(self, inputs):
        x = inputs
        for i, feat in enumerate(self.features):
            x = nn.Dense(feat)(x)
            if i != len(self.features) - 1:  # no activation after the last layer
                x = nn.relu(x)
        return x


if __name__ == '__main__':
    key1, key2 = random.split(random.PRNGKey(0), 2)

    # Set problem dimensions
    nsamples = 20
    xdim = 10
    ydim = 5

    # Generate true W and b
    W = random.normal(key1, (xdim, ydim))
    b = random.normal(key2, (ydim,))
    true_params = freeze({'params': {'bias': b, 'kernel': W}})

    # Generate samples with additional noise
    ksample, knoise = random.split(key1)
    x_samples = random.normal(ksample, (nsamples, xdim))
    y_samples = jnp.dot(x_samples, W) + b
    y_samples += 0.1 * random.normal(knoise, (nsamples, ydim))  # Adding noise
    print('x shape:', x_samples.shape, '; y shape:', y_samples.shape)

    key_init, subkey = random.split(ksample, 2)
    model = ExplicitMLP(features=[5])
    params = model.init(subkey, x_samples)

    def make_mse_func(x_batched, y_batched):
        def mse(params):
            # Define the squared loss for a single pair (x, y)
            def squared_error(x, y):
                pred = model.apply(params, x)
                return jnp.inner(y - pred, y - pred) / 2.0

            # We vectorize the previous to compute the average of the loss on all samples.
            return jnp.mean(jax.vmap(squared_error)(x_batched, y_batched), axis=0)

        return jax.jit(mse)  # And finally we jit the result.

    # Get the sampled loss
    loss = make_mse_func(x_samples, y_samples)

    lr = 0.3
    tx = optax.sgd(learning_rate=lr)
    opt_state = tx.init(params)
    loss_grad_fn = jax.value_and_grad(loss)

    for i in range(101):
        loss_val, grads = loss_grad_fn(params)
        updates, opt_state = tx.update(grads, opt_state)
        params = optax.apply_updates(params, updates)

        if i % 10 == 0:
            print('Loss step {}: '.format(i), loss_val)

    # Serializing the result
    bytes_output = serialization.to_bytes(params)
    dict_output = serialization.to_state_dict(params)
    print('Dict output')
    print(dict_output)
    print('Bytes output')
    print(bytes_output)

    # Restore the parameters from the saved bytes
    saved_params = serialization.from_bytes(params, bytes_output)
    print(loss(saved_params))
    print(loss(params))
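
The demo above only exercises ExplicitMLP; SimpleMLP builds the same stack inline with @nn.compact. A minimal sketch of using it (assumes mlp.py is importable):

```python
# Sketch: SimpleMLP from mlp.py, the @nn.compact twin of ExplicitMLP.
from jax import random, numpy as jnp
from mlp import SimpleMLP

model = SimpleMLP(features=[16, 5])
variables = model.init(random.PRNGKey(0), jnp.ones((4, 10)))
out = model.apply(variables, jnp.ones((4, 10)))
print(out.shape)  # (4, 5)
```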
ode.py
ADDED
@@ -0,0 +1,256 @@
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""JAX-based Dormand-Prince ODE integration with adaptive stepsize.

Integrate systems of ordinary differential equations (ODEs) using the JAX
autograd/diff library and the Dormand-Prince method for adaptive integration
stepsize calculation. Provides improved integration accuracy over fixed
stepsize integration methods.

For details of the mixed 4th/5th order Runge-Kutta integration method, see
https://doi.org/10.1090/S0025-5718-1986-0815836-3

Adjoint algorithm based on Appendix C of https://arxiv.org/pdf/1806.07366.pdf
"""


from functools import partial
import operator as op

import jax
import jax.numpy as jnp
from jax import core
from jax import custom_derivatives
from jax import lax
from jax._src.util import safe_map, safe_zip
from jax.flatten_util import ravel_pytree
from jax.tree_util import tree_map
from jax import linear_util as lu

map = safe_map
zip = safe_zip


def ravel_first_arg(f, unravel):
    return ravel_first_arg_(lu.wrap_init(f), unravel).call_wrapped

@lu.transformation
def ravel_first_arg_(unravel, y_flat, *args):
    y = unravel(y_flat)
    ans = yield (y,) + args, {}
    ans_flat, _ = ravel_pytree(ans)
    yield ans_flat

def interp_fit_dopri(y0, y1, k, dt):
    # Fit a polynomial to the results of a Runge-Kutta step.
    dps_c_mid = jnp.array([
        6025192743 / 30085553152 / 2, 0, 51252292925 / 65400821598 / 2,
        -2691868925 / 45128329728 / 2, 187940372067 / 1594534317056 / 2,
        -1776094331 / 19743644256 / 2, 11237099 / 235043384 / 2])
    y_mid = y0 + dt * jnp.dot(dps_c_mid, k)
    return jnp.asarray(fit_4th_order_polynomial(y0, y1, y_mid, k[0], k[-1], dt))

def fit_4th_order_polynomial(y0, y1, y_mid, dy0, dy1, dt):
    a = -2.*dt*dy0 + 2.*dt*dy1 - 8.*y0 - 8.*y1 + 16.*y_mid
    b = 5.*dt*dy0 - 3.*dt*dy1 + 18.*y0 + 14.*y1 - 32.*y_mid
    c = -4.*dt*dy0 + dt*dy1 - 11.*y0 - 5.*y1 + 16.*y_mid
    d = dt * dy0
    e = y0
    return a, b, c, d, e

def initial_step_size(fun, t0, y0, order, rtol, atol, f0):
    # Algorithm from:
    # E. Hairer, S. P. Norsett, G. Wanner,
    # Solving Ordinary Differential Equations I: Nonstiff Problems, Sec. II.4.
    scale = atol + jnp.abs(y0) * rtol
    d0 = jnp.linalg.norm(y0 / scale)
    d1 = jnp.linalg.norm(f0 / scale)

    h0 = jnp.where((d0 < 1e-5) | (d1 < 1e-5), 1e-6, 0.01 * d0 / d1)

    y1 = y0 + h0 * f0

    f1 = fun(y1, t0 + h0)
    d2 = jnp.linalg.norm((f1 - f0) / scale) / h0

    h1 = jnp.where((d1 <= 1e-15) & (d2 <= 1e-15),
                   jnp.maximum(1e-6, h0 * 1e-3),
                   (0.01 / jnp.max(d1 + d2)) ** (1. / (order + 1.)))

    return jnp.minimum(100. * h0, h1)

def runge_kutta_step(func, y0, f0, t0, dt):
    # Dopri5 Butcher tableaux
    alpha = jnp.array([1 / 5, 3 / 10, 4 / 5, 8 / 9, 1., 1., 0])
    beta = jnp.array([
        [1 / 5, 0, 0, 0, 0, 0, 0],
        [3 / 40, 9 / 40, 0, 0, 0, 0, 0],
        [44 / 45, -56 / 15, 32 / 9, 0, 0, 0, 0],
        [19372 / 6561, -25360 / 2187, 64448 / 6561, -212 / 729, 0, 0, 0],
        [9017 / 3168, -355 / 33, 46732 / 5247, 49 / 176, -5103 / 18656, 0, 0],
        [35 / 384, 0, 500 / 1113, 125 / 192, -2187 / 6784, 11 / 84, 0]
    ])
    c_sol = jnp.array([35 / 384, 0, 500 / 1113, 125 / 192, -2187 / 6784, 11 / 84, 0])
    c_error = jnp.array([35 / 384 - 1951 / 21600, 0, 500 / 1113 - 22642 / 50085,
                         125 / 192 - 451 / 720, -2187 / 6784 - -12231 / 42400,
                         11 / 84 - 649 / 6300, -1. / 60.])

    def body_fun(i, k):
        ti = t0 + dt * alpha[i-1]
        yi = y0 + dt * jnp.dot(beta[i-1, :], k)
        ft = func(yi, ti)
        return k.at[i, :].set(ft)

    k = jnp.zeros((7, f0.shape[0]), f0.dtype).at[0, :].set(f0)
    k = lax.fori_loop(1, 7, body_fun, k)

    y1 = dt * jnp.dot(c_sol, k) + y0
    y1_error = dt * jnp.dot(c_error, k)
    f1 = k[-1]
    return y1, f1, y1_error, k

def abs2(x):
    if jnp.iscomplexobj(x):
        return x.real ** 2 + x.imag ** 2
    else:
        return x ** 2

def error_ratio(error_estimate, rtol, atol, y0, y1):
    err_tol = atol + rtol * jnp.maximum(jnp.abs(y0), jnp.abs(y1))
    err_ratio = error_estimate / err_tol
    return jnp.mean(abs2(err_ratio))

def optimal_step_size(last_step, mean_error_ratio, safety=0.9, ifactor=10.0,
                      dfactor=0.2, order=5.0):
    """Compute optimal Runge-Kutta stepsize."""
    mean_error_ratio = jnp.max(mean_error_ratio)
    dfactor = jnp.where(mean_error_ratio < 1, 1.0, dfactor)

    err_ratio = jnp.sqrt(mean_error_ratio)
    factor = jnp.maximum(1.0 / ifactor,
                         jnp.minimum(err_ratio**(1.0 / order) / safety, 1.0 / dfactor))
    return jnp.where(mean_error_ratio == 0, last_step * ifactor, last_step / factor)

def odeint(func, y0, t, *args, rtol=1.4e-8, atol=1.4e-8, mxstep=jnp.inf):
    """Adaptive stepsize (Dormand-Prince) Runge-Kutta odeint implementation.

    Args:
      func: function to evaluate the time derivative of the solution `y` at time
        `t` as `func(y, t, *args)`, producing the same shape/structure as `y0`.
      y0: array or pytree of arrays representing the initial value for the state.
      t: array of float times for evaluation, like `jnp.linspace(0., 10., 101)`,
        in which the values must be strictly increasing.
      *args: tuple of additional arguments for `func`, which must be arrays,
        scalars, or (nested) standard Python containers (tuples, lists, dicts,
        namedtuples, i.e. pytrees) of those types.
      rtol: float, relative local error tolerance for solver (optional).
      atol: float, absolute local error tolerance for solver (optional).
      mxstep: int, maximum number of steps to take for each timepoint (optional).

    Returns:
      Values of the solution `y` (i.e. integrated system values) at each time
      point in `t`, represented as an array (or pytree of arrays) with the same
      shape/structure as `y0` except with a new leading axis of length `len(t)`.
    """
    def _check_arg(arg):
        if not isinstance(arg, core.Tracer) and not core.valid_jaxtype(arg):
            msg = ("The contents of odeint *args must be arrays or scalars, but got "
                   "\n{}.")
            raise TypeError(msg.format(arg))

    converted, consts = custom_derivatives.closure_convert(func, y0, t[0], *args)
    return _odeint_wrapper(converted, rtol, atol, mxstep, y0, t, *args, *consts)

@partial(jax.jit, static_argnums=(0, 1, 2, 3))
def _odeint_wrapper(func, rtol, atol, mxstep, y0, ts, *args):
    y0, unravel = ravel_pytree(y0)
    func = ravel_first_arg(func, unravel)
    out = _odeint(func, rtol, atol, mxstep, y0, ts, *args)
    return jax.vmap(unravel)(out)

@partial(jax.custom_vjp, nondiff_argnums=(0, 1, 2, 3))
def _odeint(func, rtol, atol, mxstep, y0, ts, *args):
    func_ = lambda y, t: func(y, t, *args)

    def scan_fun(carry, target_t):

        def cond_fun(state):
            i, _, _, t, dt, _, _ = state
            return (t < target_t) & (i < mxstep) & (dt > 0)

        def body_fun(state):
            i, y, f, t, dt, last_t, interp_coeff = state
            next_y, next_f, next_y_error, k = runge_kutta_step(func_, y, f, t, dt)
            next_t = t + dt
            error_ratios = error_ratio(next_y_error, rtol, atol, y, next_y)
            new_interp_coeff = interp_fit_dopri(y, next_y, k, dt)
            dt = optimal_step_size(dt, error_ratios)

            new = [i + 1, next_y, next_f, next_t, dt, t, new_interp_coeff]
            old = [i + 1, y, f, t, dt, last_t, interp_coeff]
            return map(partial(jnp.where, jnp.all(error_ratios <= 1.)), new, old)

        _, *carry = lax.while_loop(cond_fun, body_fun, [0] + carry)
        _, _, t, _, last_t, interp_coeff = carry
        relative_output_time = (target_t - last_t) / (t - last_t)
        y_target = jnp.polyval(interp_coeff, relative_output_time)
        return carry, y_target

    # ODEfunc with NFE counter will give an auxiliary output for nfe.
    # The code below is modified to skip that output.
    f0 = func_(y0, ts[0])
    dt = initial_step_size(func_, ts[0], y0, 4, rtol, atol, f0)
    interp_coeff = jnp.array([y0] * 5)
    init_carry = [y0, f0, ts[0], dt, ts[0], interp_coeff]
    _, ys = lax.scan(scan_fun, init_carry, ts[1:])
    return jnp.concatenate((y0[None], ys))

def _odeint_fwd(func, rtol, atol, mxstep, y0, ts, *args):
    ys = _odeint(func, rtol, atol, mxstep, y0, ts, *args)
    return ys, (ys, ts, args)

def _odeint_rev(func, rtol, atol, mxstep, res, g):
    ys, ts, args = res

    def aug_dynamics(augmented_state, t, *args):
        """Original system augmented with vjp_y, vjp_t and vjp_args."""
        y, y_bar, *_ = augmented_state
        # `t` here is negative time, so we need to negate again to get back to
        # normal time. See the `odeint` invocation in `scan_fun` below.
        y_dot, vjpfun = jax.vjp(func, y, -t, *args)
        return (-y_dot, *vjpfun(y_bar))

    y_bar = g[-1]
    ts_bar = []
    t0_bar = 0.

    def scan_fun(carry, i):
        y_bar, t0_bar, args_bar = carry
        # Compute effect of moving measurement time
        # `t_bar` should not be complex as it represents time
        t_bar = jnp.dot(func(ys[i], ts[i], *args), g[i]).real
        t0_bar = t0_bar - t_bar
        # Run augmented system backwards to previous observation
        _, y_bar, t0_bar, args_bar = odeint(
            aug_dynamics, (ys[i], y_bar, t0_bar, args_bar),
            jnp.array([-ts[i], -ts[i - 1]]),
            *args, rtol=rtol, atol=atol, mxstep=mxstep)
        y_bar, t0_bar, args_bar = tree_map(op.itemgetter(1), (y_bar, t0_bar, args_bar))
        # Add gradient from current output
        y_bar = y_bar + g[i - 1]
        return (y_bar, t0_bar, args_bar), t_bar

    init_carry = (g[-1], 0., tree_map(jnp.zeros_like, args))
    (y_bar, t0_bar, args_bar), rev_ts_bar = lax.scan(
        scan_fun, init_carry, jnp.arange(len(ts) - 1, 0, -1))
    ts_bar = jnp.concatenate([jnp.array([t0_bar]), rev_ts_bar[::-1]])
    return (y_bar, ts_bar, *args_bar)

_odeint.defvjp(_odeint_fwd, _odeint_rev)
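
A short sketch of calling this odeint as its docstring describes, on an illustrative harmonic oscillator, and differentiating through it via the adjoint VJP defined above (assumes ode.py is importable; the dynamics and tolerances here are illustrative choices, not the repo's):

```python
# Sketch: integrate a 2D harmonic oscillator with the adaptive
# Dormand-Prince odeint above, then differentiate through the solve.
import jax
import jax.numpy as jnp
from ode import odeint

def oscillator(y, t, omega):
    # y = (position, velocity); d(pos)/dt = vel, d(vel)/dt = -omega^2 * pos.
    return jnp.array([y[1], -omega ** 2 * y[0]])

ts = jnp.linspace(0., 10., 101)
y0 = jnp.array([1., 0.])
ys = odeint(oscillator, y0, ts, 2.0, rtol=1e-6, atol=1e-6)  # shape (101, 2)

# Gradient of the final position w.r.t. the initial state, computed with
# the adjoint method implemented in _odeint_rev.
grad_fn = jax.grad(lambda y0: odeint(oscillator, y0, ts, 2.0)[-1, 0])
print(ys.shape, grad_fn(y0))
```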
opts.py
ADDED
File without changes
train.py
ADDED
@@ -0,0 +1,130 @@
from functools import partial
import jax
from typing import Any, Callable, Sequence, Optional, NewType
from jax import lax, random, vmap, numpy as jnp
from jax.experimental.ode import odeint
import flax
from flax.training import train_state
from flax import traverse_util
from flax.core import freeze, unfreeze
from flax import linen as nn
from flax import serialization
import optax
import tensorflow_datasets as tfds
import numpy as np
from tqdm import tqdm
import os


# Define loss
@jax.jit
def cross_entropy_loss(logits, labels):
    one_hot_labels = jax.nn.one_hot(labels, num_classes=10)
    return -jnp.mean(jnp.sum(one_hot_labels * logits, axis=-1))


# Metric computation
@jax.jit
def compute_metrics(logits, labels):
    loss = cross_entropy_loss(logits=logits, labels=labels)
    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
    metrics = {
        'loss': loss,
        'accuracy': accuracy,
    }
    return metrics


def get_datasets():
    """Load MNIST train and test datasets into memory."""
    ds_builder = tfds.builder('mnist')
    ds_builder.download_and_prepare()
    train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))
    test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))
    train_ds['image'] = jnp.float32(train_ds['image']) / 255.
    test_ds['image'] = jnp.float32(test_ds['image']) / 255.
    return train_ds, test_ds


def create_train_state(model, rng, learning_rate):
    """Creates initial 'TrainState'."""
    params = model.init(rng, jnp.ones([1, 28, 28, 1]))['params']
    tx = optax.adam(learning_rate)
    return train_state.TrainState.create(
        apply_fn=model.apply, params=params, tx=tx
    )


# Training step
@jax.jit
def train_step(state, batch):
    """Train for a single step."""
    def loss_fn(params):
        # state.apply_fn is the model.apply bound in create_train_state;
        # using it keeps this step model-agnostic.
        logits = state.apply_fn({'params': params}, batch['image'])
        loss = cross_entropy_loss(logits=logits, labels=batch['label'])
        return loss, logits
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (_, logits), grads = grad_fn(state.params)
    state = state.apply_gradients(grads=grads)
    metrics = compute_metrics(logits=logits, labels=batch['label'])
    return state, metrics


# Evaluation step
@jax.jit
def eval_step(state, batch):
    logits = state.apply_fn({'params': state.params}, batch['image'])
    return compute_metrics(logits=logits, labels=batch['label'])


# Train function
def train_epoch(state, train_ds, batch_size, epoch, rng):
    """Train for a single epoch"""
    train_ds_size = len(train_ds['image'])
    steps_per_epoch = train_ds_size // batch_size

    perms = jax.random.permutation(rng, len(train_ds['image']))
    perms = perms[:steps_per_epoch * batch_size]  # skip incomplete batch
    perms = perms.reshape((steps_per_epoch, batch_size))
    batch_metrics = []
    for perm in tqdm(perms):
        batch = {k: v[perm, ...] for k, v in train_ds.items()}
        state, metrics = train_step(state, batch)
        batch_metrics.append(metrics)

    # compute mean of metrics across each batch in epoch.
    batch_metrics_np = jax.device_get(batch_metrics)
    epoch_metrics_np = {
        k: np.mean([metrics[k] for metrics in batch_metrics_np])
        for k in batch_metrics_np[0]
    }
    print('train epoch: %d, loss: %.4f, accuracy: %.2f' % (
        epoch, epoch_metrics_np['loss'], epoch_metrics_np['accuracy'] * 100
    ))

    return state


# Eval function
def eval_model(state, test_ds):
    metrics = eval_step(state, test_ds)
    metrics = jax.device_get(metrics)
    summary = jax.tree_map(lambda x: x.item(), metrics)
    return summary['loss'], summary['accuracy']


def train_and_evaluate(model, learning_rate, n_epoch, batch_size):
    """Generic train/eval loop for a given Flax module."""
    train_ds, test_ds = get_datasets()
    rng = jax.random.PRNGKey(0)
    rng, init_rng = jax.random.split(rng)

    state = create_train_state(model, init_rng, learning_rate)
    del init_rng  # Must not be used anymore.

    for epoch in tqdm(range(1, n_epoch + 1)):
        rng, input_rng = jax.random.split(rng)
        state = train_epoch(state, train_ds, batch_size, epoch, input_rng)
        test_loss, test_accuracy = eval_model(state, test_ds)
        print(' test epoch: %d, loss: %.2f, accuracy: %.2f' % (
            epoch, test_loss, test_accuracy * 100
        ))
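
How this generic loop is meant to compose with a model module, sketched with SmallResNet from cnn.py (hypothetical wiring; train_resnet.py and train_ode.py are the repo's actual drivers):

```python
# Sketch: drive the generic loop in train.py with a concrete model.
from cnn import SmallResNet
import train

train.train_and_evaluate(SmallResNet(), learning_rate=1e-4, n_epoch=2, batch_size=64)
```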
train_cnf.py
ADDED
@@ -0,0 +1,274 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import matplotlib.pyplot as plt
|
3 |
+
import os
|
4 |
+
import glob
|
5 |
+
from PIL import Image
|
6 |
+
from functools import partial
|
7 |
+
import jax
|
8 |
+
from typing import Any, Callable, Sequence, Optional, NewType
|
9 |
+
from jax import lax, random, vmap, scipy, numpy as jnp
|
10 |
+
# from jax.experimental.ode import odeint
|
11 |
+
from models.ode import odeint
|
12 |
+
import flax
|
13 |
+
from flax.training import train_state
|
14 |
+
from flax import traverse_util
|
15 |
+
from flax.core import freeze, unfreeze
|
16 |
+
from flax import linen as nn
|
17 |
+
from flax import serialization
|
18 |
+
import optax
|
19 |
+
from sklearn.datasets import make_circles
|
20 |
+
from tqdm import tqdm
|
21 |
+
|
22 |
+
|
23 |
+
# os.environ['TF_FORCE_UNIFIED_MEMORY'] = '1'
|
24 |
+
# os.environ['XLA_PYTHON_CLIENT_ALLOCATOR'] = 'platform'
|
25 |
+
# os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
|
26 |
+
|
27 |
+
|
28 |
+
class HyperNetwork(nn.Module):
|
29 |
+
"""Hyper-network allowing f(z(t), t) to change with time.
|
30 |
+
|
31 |
+
Adapted from the Pytorch implementation at:
|
32 |
+
https://github.com/rtqichen/torchdiffeq/blob/master/examples/cnf.py
|
33 |
+
"""
|
34 |
+
in_out_dim: Any = 2
|
35 |
+
hidden_dim: Any = 32
|
36 |
+
width: Any = 64
|
37 |
+
|
38 |
+
@nn.compact
|
39 |
+
def __call__(self, t):
|
40 |
+
# predict params
|
41 |
+
blocksize = self.width * self.in_out_dim
|
42 |
+
params = lax.expand_dims(t, (0, 1))
|
43 |
+
params = nn.Dense(self.hidden_dim)(params)
|
44 |
+
params = nn.tanh(params)
|
45 |
+
params = nn.Dense(self.hidden_dim)(params)
|
46 |
+
params = nn.tanh(params)
|
47 |
+
params = nn.Dense(3 * blocksize + self.width)(params)
|
48 |
+
|
49 |
+
# restructure
|
50 |
+
params = lax.reshape(params, (3 * blocksize + self.width,))
|
51 |
+
W = lax.reshape(params[:blocksize], (self.width, self.in_out_dim, 1))
|
52 |
+
|
53 |
+
U = lax.reshape(params[blocksize:2 * blocksize], (self.width, 1, self.in_out_dim))
|
54 |
+
|
55 |
+
G = lax.reshape(params[2 * blocksize:3 * blocksize], (self.width, 1, self.in_out_dim))
|
56 |
+
U = U * nn.sigmoid(G)
|
57 |
+
|
58 |
+
B = lax.expand_dims(params[3 * blocksize:], (1, 2))
|
59 |
+
return W, B, U
|
60 |
+
|
61 |
+
|
62 |
+
class CNF(nn.Module):
|
63 |
+
"""Adapted from the Pytorch implementation at:
|
64 |
+
https://github.com/rtqichen/torchdiffeq/blob/master/examples/cnf.py
|
65 |
+
"""
|
66 |
+
in_out_dim: Any = 2
|
67 |
+
hidden_dim: Any = 32
|
68 |
+
width: Any = 64
|
69 |
+
|
70 |
+
@nn.compact
|
71 |
+
def __call__(self, t, states):
|
72 |
+
z, logp_z = states[..., :2], states[..., 2:]
|
73 |
+
W, B, U = HyperNetwork(self.in_out_dim, self.hidden_dim, self.width)(t)
|
74 |
+
|
75 |
+
# TODO Below should be converted using vmap
|
76 |
+
def dzdt(z):
|
77 |
+
Z = lax.expand_dims(z, (0,))
|
78 |
+
Z = jnp.repeat(Z, self.width, 0)
|
79 |
+
h = nn.tanh(jnp.matmul(Z, W) + B)
|
80 |
+
return jnp.matmul(h, U).mean(0)
|
81 |
+
|
82 |
+
dz_dt = dzdt(z)
|
83 |
+
sum_dzdt = lambda z: jnp.sum(dzdt(z), 1)
|
84 |
+
df_dz = jax.jacrev(sum_dzdt)(z)
|
85 |
+
dlogp_z_dt = -1.0 * jnp.trace(df_dz, 0, 1, 2)
|
86 |
+
|
87 |
+
return lax.concatenate((dz_dt, lax.expand_dims(dlogp_z_dt, (1,))), 1)
|
88 |
+
|
89 |
+
|
90 |
+
class Neg_CNF(nn.Module):
|
91 |
+
"""Negative CNF for jax's odeint."""
|
92 |
+
in_out_dim: Any = 2
|
93 |
+
hidden_dim: Any = 32
|
94 |
+
width: Any = 64
|
95 |
+
|
96 |
+
@nn.compact
|
97 |
+
def __call__(self, t, states):
|
98 |
+
outputs = CNF(self.in_out_dim, self.hidden_dim, self.width)(-1.0 * t, states)
|
99 |
+
|
100 |
+
return -1.0 * outputs
|
101 |
+
|
102 |
+
|
103 |
+
def get_batch(num_samples):
|
104 |
+
"""Adapted from the Pytorch implementation at:
|
105 |
+
https://github.com/rtqichen/torchdiffeq/blob/master/examples/cnf.py
|
106 |
+
"""
|
107 |
+
points, _ = make_circles(n_samples=num_samples, noise=0.06, factor=0.5)
|
108 |
+
x = jnp.array(points, dtype=jnp.float32)
|
109 |
+
logp_diff_t1 = jnp.zeros((num_samples, 1), dtype=jnp.float32)
|
110 |
+
|
111 |
+
return lax.concatenate((x, logp_diff_t1), 1)
|
112 |
+
|
113 |
+
|
114 |
+
def create_train_state(rng, learning_rate, in_out_dim, hidden_dim, width):
|
115 |
+
"""Creates initial 'TrainState'."""
|
116 |
+
inputs = get_batch(10)
|
117 |
+
neg_cnf = CNF(in_out_dim, hidden_dim, width)
|
118 |
+
params = neg_cnf.init(rng, jnp.array(10.), inputs)['params']
|
119 |
+
# set_params(params)
|
120 |
+
tx = optax.adam(learning_rate)
|
121 |
+
return train_state.TrainState.create(
|
122 |
+
apply_fn=neg_cnf.apply, params=params, tx=tx
|
123 |
+
)
|
124 |
+
|
125 |
+
|
126 |
+
def set_params(params):
|
127 |
+
# Convert all value of Params to certain constant
|
128 |
+
params = unfreeze(params)
|
129 |
+
# Get flattened-key: value list.
|
130 |
+
flat_params = {'/'.join(k): v for k, v in traverse_util.flatten_dict(params).items()}
|
131 |
+
unflat_params = traverse_util.unflatten_dict({tuple(k.split('/')): 0.2 * jnp.ones_like(v) for k, v in flat_params.items()})
|
132 |
+
new_params = freeze(unflat_params)
|
133 |
+
test_x = jnp.array([[0., 1.], [2., 3.]])
|
134 |
+
test_log_p = jnp.zeros(2, 1)
|
135 |
+
test_inputs = lax.concatenate((test_x, test_log_p), 1)
|
136 |
+
Neg_CNF().apply({'params': new_params}, jnp.array(0.), test_inputs)
|
137 |
+
|
138 |
+
|
139 |
+
# @partial(jax.jit, static_argnums=(2, 3, 4, 5, 6))
|
140 |
+
def train_step(state, batch, in_out_dim, hidden_dim, width, t0, t1):
|
141 |
+
p_z0 = lambda x: scipy.stats.multivariate_normal.logpdf(x,
|
142 |
+
mean=jnp.array([0., 0.]),
|
143 |
+
cov=jnp.array([[0.1, 0.], [0., 0.1]]))
|
144 |
+
def loss_fn(params):
|
145 |
+
func = lambda states, t: Neg_CNF(in_out_dim, hidden_dim, width).apply({'params': params}, t, states)
|
146 |
+
outputs = odeint(
|
147 |
+
func,
|
148 |
+
batch,
|
149 |
+
-1.0 * jnp.array([t1, t0]),
|
150 |
+
atol=1e-5,
|
151 |
+
rtol=1e-5
|
152 |
+
)
|
153 |
+
z_t, logp_diff_t = outputs[..., :2], outputs[..., 2:]
|
154 |
+
z_t0, logp_diff_t0 = z_t[-1], logp_diff_t[-1]
|
155 |
+
logp_x = p_z0(z_t0) - lax.squeeze(logp_diff_t0, dimensions=(1,))
|
156 |
+
loss = -logp_x.mean(0)
|
157 |
+
return loss
|
158 |
+
grad_fn = jax.value_and_grad(loss_fn)
|
159 |
+
loss, grads = grad_fn(state.params)
|
160 |
+
state = state.apply_gradients(grads=grads)
|
161 |
+
|
162 |
+
return state, loss
|
163 |
+
|
164 |
+
|
165 |
+
def train(learning_rate, n_iters, batch_size, in_out_dim, hidden_dim, width, t0, t1, visual):
    """Train the model."""
    rng = jax.random.PRNGKey(0)
    state = create_train_state(rng, learning_rate, in_out_dim, hidden_dim, width)

    for itr in range(1, n_iters + 1):
        batch = get_batch(batch_size)
        state, loss = train_step(state, batch, in_out_dim, hidden_dim, width, t0, t1)
        print("iter: %d, loss: %.2f" % (itr, loss))

    if visual:
        # Convert params of Neg_CNF to CNF: flatten the param dict, strip the
        # Neg_CNF prefix from every key, then unflatten back into a tree.
        neg_params = state.params
        neg_params = unfreeze(neg_params)
        neg_flat_params = {'/'.join(k): v for k, v in traverse_util.flatten_dict(neg_params).items()}
        pos_flat_params = {key[6:]: jnp.array(np.array(neg_flat_params[key])) for key in list(neg_flat_params.keys())}
        pos_unflat_params = traverse_util.unflatten_dict({tuple(k.split('/')): v for k, v in pos_flat_params.items()})
        pos_params = freeze(pos_unflat_params)
        output = viz(neg_params, pos_params, in_out_dim, hidden_dim, width, t0, t1)
        z_t_samples, z_t_density, logp_diff_t, viz_timesteps, target_sample, z_t1 = output
        create_plots(z_t_samples, z_t_density, logp_diff_t, t0, t1, viz_timesteps, target_sample, z_t1)


def solve_dynamics(dynamics_fn, initial_state, t):
    def f(initial_state, t):
        return odeint(dynamics_fn, initial_state, t, atol=1e-5, rtol=1e-5)
    return f(initial_state, t)


def viz(neg_params, pos_params, in_out_dim, hidden_dim, width, t0, t1):
    """Visualize the learned flow. Adapted from the PyTorch implementation."""
    viz_samples = 5000
    viz_timesteps = 2
    target_sample, _ = get_batch(viz_samples)

    if not os.path.exists('results/'):
        os.makedirs('results/')

    z_t0 = jnp.array(np.random.multivariate_normal(mean=np.array([0., 0.]),
                                                   cov=np.array([[0.1, 0.], [0., 0.1]]),
                                                   size=viz_samples))
    logp_diff_t0 = jnp.zeros((viz_samples, 1), dtype=jnp.float32)

    func_pos = lambda states, t: CNF(in_out_dim, hidden_dim, width).apply({'params': pos_params}, t, states)
    z_t_samples, _ = solve_dynamics(func_pos, (z_t0, logp_diff_t0), jnp.linspace(t0, t1, viz_timesteps))

    # Generate the evolution of the density on a grid.
    x = jnp.linspace(-1.5, 1.5, 100)
    y = jnp.linspace(-1.5, 1.5, 100)
    points = np.vstack(jnp.meshgrid(x, y)).reshape([2, -1]).T

    z_t1 = jnp.array(points, dtype=jnp.float32)
    logp_diff_t1 = jnp.zeros((z_t1.shape[0], 1), dtype=jnp.float32)
    func_neg = lambda states, t: Neg_CNF(in_out_dim, hidden_dim, width).apply({'params': neg_params}, -t, states)
    z_t_density, logp_diff_t = solve_dynamics(func_neg, (z_t1, logp_diff_t1), -jnp.linspace(t1, t0, viz_timesteps))

    return z_t_samples, z_t_density, logp_diff_t, viz_timesteps, target_sample, z_t1


def create_plots(z_t_samples, z_t_density, logp_diff_t, t0, t1, viz_timesteps, target_sample, z_t1):
    # Create plots for each timestep.
    for (t, z_sample, z_density, logp_diff) in zip(
            tqdm(np.linspace(t0, t1, viz_timesteps)),
            z_t_samples, z_t_density, logp_diff_t):
        fig = plt.figure(figsize=(12, 4), dpi=200)
        plt.tight_layout()
        plt.axis('off')
        plt.margins(0, 0)
        fig.suptitle(f'{t:.2f}s')

        ax1 = fig.add_subplot(1, 3, 1)
        ax1.set_title('Target')
        ax1.get_xaxis().set_ticks([])
        ax1.get_yaxis().set_ticks([])
        ax2 = fig.add_subplot(1, 3, 2)
        ax2.set_title('Samples')
        ax2.get_xaxis().set_ticks([])
        ax2.get_yaxis().set_ticks([])
        ax3 = fig.add_subplot(1, 3, 3)
        ax3.set_title('Log Probability')
        ax3.get_xaxis().set_ticks([])
        ax3.get_yaxis().set_ticks([])

        ax1.hist2d(*jnp.transpose(target_sample), bins=300, density=True,
                   range=[[-1.5, 1.5], [-1.5, 1.5]])

        ax2.hist2d(*jnp.transpose(z_sample), bins=300, density=True,
                   range=[[-1.5, 1.5], [-1.5, 1.5]])

        p_z0 = lambda x: scipy.stats.multivariate_normal.logpdf(x,
                                                                mean=jnp.array([0., 0.]),
                                                                cov=jnp.array([[0.1, 0.], [0., 0.1]]))
        # lax.reshape expects a shape tuple, hence (n,) rather than a bare int.
        logp = p_z0(z_density) - lax.reshape(logp_diff, (z_density.shape[0],))
        ax3.tricontourf(*jnp.transpose(z_t1),
                        jnp.exp(logp), 200)

        plt.savefig(os.path.join('results/', f"cnf-viz-{int(t * 1000):05d}.jpg"),
                    pad_inches=0.2, bbox_inches='tight')
        plt.close()

    img, *imgs = [Image.open(f) for f in sorted(glob.glob(os.path.join('results/', "cnf-viz-*.jpg")))]
    img.save(fp=os.path.join('results/', "cnf-viz.gif"), format='GIF', append_images=imgs,
             save_all=True, duration=250, loop=0)

    print('Saved visualization animation at {}'.format(os.path.join('results/', "cnf-viz.gif")))


if __name__ == '__main__':
    train(0.001, 100, 512, 2, 32, 64, 0., 10., True)
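The visualization branch above converts the trained `Neg_CNF` parameters into a tree the forward-time `CNF` can consume by flattening the parameter dict, stripping a fixed-length module prefix from each key (`key[6:]`), and unflattening. A minimal sketch of the same key-renaming trick on a toy parameter tree (the `'wrap'` module name and the shapes are made up for illustration):

```python
import jax.numpy as jnp
from flax import traverse_util
from flax.core import freeze

# Toy nested params: everything lives under a hypothetical 'wrap' submodule.
neg_params = {'wrap': {'Dense_0': {'kernel': jnp.ones((2, 2)), 'bias': jnp.zeros(2)}}}

# Flatten to {'wrap/Dense_0/kernel': ..., 'wrap/Dense_0/bias': ...}.
flat = {'/'.join(k): v for k, v in traverse_util.flatten_dict(neg_params).items()}

# Strip the 'wrap/' prefix from every key (analogous to key[6:] above).
renamed = {k[len('wrap/'):]: v for k, v in flat.items()}

# Unflatten back into a nested tree and freeze it for Module.apply.
pos_params = freeze(traverse_util.unflatten_dict(
    {tuple(k.split('/')): v for k, v in renamed.items()}))
print(list(pos_params.keys()))  # ['Dense_0']
```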
train_ode.py
ADDED
@@ -0,0 +1,271 @@
from functools import partial
import jax
from typing import Any, Callable, Sequence, Optional, NewType
from jax import lax, random, vmap, numpy as jnp
from jax.experimental.ode import odeint
from jax.experimental import host_callback
import flax
from flax.training import train_state
from flax import traverse_util
from flax.core import freeze, unfreeze
from flax import linen as nn
from flax import serialization
import optax
import tensorflow_datasets as tfds
import numpy as np
from tqdm import tqdm
import os


# Define residual block
class ResDownBlock(nn.Module):
    """Single ResBlock w/ downsample."""
    dim_out: Any = 64

    @nn.compact
    def __call__(self, inputs):
        x = inputs
        f_x = nn.relu(nn.GroupNorm(self.dim_out)(x))
        x = nn.Conv(features=self.dim_out, kernel_size=(1, 1), strides=(2, 2))(x)
        f_x = nn.Conv(features=self.dim_out, kernel_size=(3, 3), strides=(2, 2))(f_x)
        f_x = nn.relu(nn.GroupNorm(self.dim_out)(f_x))
        f_x = nn.Conv(features=self.dim_out, kernel_size=(3, 3))(f_x)
        x = f_x + x
        return x


class ConcatConv2D(nn.Module):
    """Concatenate the time variable to the hidden layer before a convolution."""
    dim_out: Any = 64
    ksize: Any = 3

    @nn.compact
    def __call__(self, inputs, t):
        x = inputs
        tt = jnp.ones_like(x[..., :1]) * t
        ttx = jnp.concatenate([tt, x], -1)
        return nn.Conv(features=self.dim_out, kernel_size=(self.ksize, self.ksize))(ttx)


# Define the Neural ODE for the MNIST example.
class ODEfunc(nn.Module):
    """ODE function which replaces the ResNet blocks."""
    dim_out: Any = 64
    ksize: Any = 3

    @nn.compact
    def __call__(self, inputs, t):
        # TODO: count the number of function evaluations.
        host_callback.call(nfecounter.count, 1)

        x = inputs
        out = nn.GroupNorm(self.dim_out)(x)
        out = nn.relu(out)
        out = ConcatConv2D(self.dim_out, self.ksize)(out, t)
        out = nn.GroupNorm(self.dim_out)(out)
        out = nn.relu(out)
        out = ConcatConv2D(self.dim_out, self.ksize)(out, t)
        out = nn.GroupNorm(self.dim_out)(out)

        return out


class NFEcounter:
    def __init__(self, init_nfe):
        self.nfe = init_nfe

    def count(self, increase):
        self.nfe += increase

    def set(self, target):
        self.nfe = target

# Define the global NFE counter.
nfecounter = NFEcounter(0)


class ODEBlock(nn.Module):
    """ODE block which contains odeint."""
    tol: Any = 1.

    @nn.compact
    def __call__(self, inputs, params):
        ode_func = ODEfunc()
        ode_func_apply = lambda x, t: ode_func.apply(variables={'params': params}, inputs=x, t=t)
        init_state, final_state = odeint(ode_func_apply,
                                         inputs, jnp.array([0., 1.]),
                                         rtol=self.tol, atol=self.tol)
        return final_state


class ODEBlockVmap(nn.Module):
    """Apply vmap to ODEBlock."""
    tol: Any = 1.

    @nn.compact
    def __call__(self, inputs, params):
        x = inputs
        vmap_odeblock = nn.vmap(ODEBlock,
                                variable_axes={'params': 0},
                                split_rngs={'params': True},
                                in_axes=(0, None))

        return vmap_odeblock(tol=self.tol, name='odeblock')(x, params)

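# A note on the nn.vmap wrapper above: in_axes=(0, None) maps the ODE block
# over the leading (batch) axis of the inputs while broadcasting the single
# shared set of ODE-function parameters to every example, so each image is
# integrated through the same dynamics.
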
class FullODENet(nn.Module):
    """Full ODE net: two downsampling layers, an ODE block, and a linear classifier."""
    dim_out: Any = 64
    ksize: Any = 3
    tol: Any = 1.

    @nn.compact
    def __call__(self, inputs):
        x = inputs
        x = nn.Conv(features=self.dim_out, kernel_size=(self.ksize, self.ksize))(x)
        x = ResDownBlock()(x)
        x = ResDownBlock()(x)

        ode_func = ODEfunc()
        init_fn = lambda rng, x: ode_func.init(random.split(rng)[-1], x, 0.)['params']
        ode_func_params = self.param('ode_func', init_fn, jnp.ones_like(x[0]))
        x = ODEBlockVmap(tol=self.tol)(x, ode_func_params)

        x = nn.GroupNorm(self.dim_out)(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, (1, 1))

        x = x.reshape((x.shape[0], -1))  # flatten

        x = nn.Dense(features=10)(x)
        x = nn.log_softmax(x)

        return x


# Define loss
@jax.jit
def cross_entropy_loss(logits, labels):
    one_hot_labels = jax.nn.one_hot(labels, num_classes=10)
    return -jnp.mean(jnp.sum(one_hot_labels * logits, axis=-1))


# Metric computation
@jax.jit
def compute_metrics(logits, labels, nfe_forward, nfe_backward):
    loss = cross_entropy_loss(logits=logits, labels=labels)
    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
    metrics = {
        'loss': loss,
        'accuracy': accuracy,
        'nfe_forward': nfe_forward,
        'nfe_backward': nfe_backward
    }
    return metrics


def get_datasets():
    """Load MNIST train and test datasets into memory."""
    ds_builder = tfds.builder('mnist')
    ds_builder.download_and_prepare()
    train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))
    test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))
    train_ds['image'] = jnp.float32(train_ds['image']) / 255.
    test_ds['image'] = jnp.float32(test_ds['image']) / 255.
    return train_ds, test_ds


def create_train_state(rng, learning_rate, tol):
    """Creates the initial 'TrainState'."""
    odenet = FullODENet(tol=tol)
    params = odenet.init(rng, jnp.ones([1, 28, 28, 1]))['params']
    tx = optax.adam(learning_rate)
    return train_state.TrainState.create(
        apply_fn=odenet.apply, params=params, tx=tx
    )


# Training step
@partial(jax.jit, static_argnums=(2,))
def train_step(state, batch, tol):
    """Train for a single step."""
    def loss_fn(params):
        logits = FullODENet(tol=tol).apply({'params': params}, batch['image'])
        loss = cross_entropy_loss(logits=logits, labels=batch['label'])
        return loss, logits
    grad_fn = jax.grad(loss_fn, has_aux=True)
    host_callback.call(nfecounter.set, 0)
    (_, logits) = loss_fn(state.params)
    nfe_forward = nfecounter.nfe
    host_callback.call(nfecounter.set, 0)
    grads, _ = grad_fn(state.params)
    nfe_backward = nfecounter.nfe
    state = state.apply_gradients(grads=grads)
    metrics = compute_metrics(logits=logits, labels=batch['label'],
                              nfe_forward=nfe_forward, nfe_backward=nfe_backward)
    return state, metrics


# Evaluation step
@partial(jax.jit, static_argnums=(2,))
def eval_step(params, batch, tol):
    logits = FullODENet(tol=tol).apply({'params': params}, batch['image'])
    return compute_metrics(logits=logits, labels=batch['label'], nfe_forward=0, nfe_backward=0)


# Train function
def train_epoch(state, train_ds, batch_size, epoch, rng, tol):
    """Train for a single epoch."""
    train_ds_size = len(train_ds['image'])
    steps_per_epoch = train_ds_size // batch_size

    perms = jax.random.permutation(rng, len(train_ds['image']))
    perms = perms[:steps_per_epoch * batch_size]  # skip incomplete batch
    perms = perms.reshape((steps_per_epoch, batch_size))
    batch_metrics = []
    for perm in tqdm(perms):
        batch = {k: v[perm, ...] for k, v in train_ds.items()}
        state, metrics = train_step(state, batch, tol)
        batch_metrics.append(metrics)

    # Compute the mean of each metric across the batches in the epoch.
    batch_metrics_np = jax.device_get(batch_metrics)
    epoch_metrics_np = {
        k: np.mean([metrics[k] for metrics in batch_metrics_np])
        for k in batch_metrics_np[0]
    }
    print('train epoch: %d, loss: %.4f, accuracy: %.2f, nfe_forward: %d, nfe_backward: %d' % (
        epoch, epoch_metrics_np['loss'], epoch_metrics_np['accuracy'] * 100,
        epoch_metrics_np['nfe_forward'], epoch_metrics_np['nfe_backward']
    ))

    return state


# Eval function
def eval_model(params, test_ds, tol):
    metrics = eval_step(params, test_ds, tol)
    metrics = jax.device_get(metrics)
    summary = jax.tree_map(lambda x: x.item(), metrics)
    return summary['loss'], summary['accuracy']


def train_and_evaluate(learning_rate, n_epoch, batch_size, tol):
    train_ds, test_ds = get_datasets()
    rng = jax.random.PRNGKey(0)
    rng, init_rng = jax.random.split(rng)

    state = create_train_state(init_rng, learning_rate, tol)
    del init_rng  # Must not be used anymore.

    for epoch in tqdm(range(1, n_epoch + 1)):
        rng, input_rng = jax.random.split(rng)
        state = train_epoch(state, train_ds, batch_size, epoch, input_rng, tol)
        test_loss, test_accuracy = eval_model(state.params, test_ds, tol)
        print(' test epoch: %d, loss: %.2f, accuracy: %.2f' % (
            epoch, test_loss, test_accuracy * 100
        ))


if __name__ == '__main__':
    train_and_evaluate(0.0001, 5, 128, 1.)
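The `nfecounter` trick above uses `host_callback` to bump a Python-side counter each time the traced dynamics function is evaluated by the adaptive solver, which is how the NFE (number of function evaluations) metrics are collected. A minimal, self-contained sketch of the same pattern, mirroring the module's `host_callback.call` usage (the `Counter` class and `decay` dynamics here are illustrative, not part of the repo):

```python
import jax.numpy as jnp
from jax.experimental import host_callback
from jax.experimental.ode import odeint

class Counter:
    """Python-side state, reachable from inside traced code via host_callback."""
    def __init__(self):
        self.n = 0
    def count(self, k):
        self.n += k

counter = Counter()

def decay(y, t):
    host_callback.call(counter.count, 1)  # runs on the host per solver evaluation
    return -y

counter.n = 0
odeint(decay, jnp.ones(2), jnp.array([0., 1.]), rtol=1e-3, atol=1e-3)
print('NFE at rtol=1e-3:', counter.n)

counter.n = 0
odeint(decay, jnp.ones(2), jnp.array([0., 1.]), rtol=1e-8, atol=1e-8)
print('NFE at rtol=1e-8:', counter.n)  # tighter tolerance => more evaluations
```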
train_resnet.py
ADDED
@@ -0,0 +1,195 @@
from functools import partial
import jax
from typing import Any, Callable, Sequence, Optional, NewType
from jax import lax, random, vmap, numpy as jnp
from jax.experimental.ode import odeint
import flax
from flax.training import train_state
from flax import traverse_util
from flax.core import freeze, unfreeze
from flax import linen as nn
from flax import serialization
import optax
import tensorflow_datasets as tfds
import numpy as np
from tqdm import tqdm
import os


# Define residual blocks
class ResDownBlock(nn.Module):
    """Single ResBlock w/ downsample."""
    dim_out: Any = 64

    @nn.compact
    def __call__(self, inputs):
        x = inputs
        f_x = nn.relu(nn.GroupNorm(self.dim_out)(x))
        x = nn.Conv(features=self.dim_out, kernel_size=(1, 1), strides=(2, 2))(x)
        f_x = nn.Conv(features=self.dim_out, kernel_size=(3, 3), strides=(2, 2))(f_x)
        f_x = nn.relu(nn.GroupNorm(self.dim_out)(f_x))
        f_x = nn.Conv(features=self.dim_out, kernel_size=(3, 3))(f_x)
        x = f_x + x
        return x


class ResBlock(nn.Module):
    """Single ResBlock w/o downsample."""
    dim_out: Any = 64
    ksize: Any = 3

    @nn.compact
    def __call__(self, inputs):
        x = inputs
        f_x = nn.relu(nn.GroupNorm(self.dim_out)(x))
        f_x = nn.Conv(features=self.dim_out, kernel_size=(self.ksize, self.ksize))(f_x)
        f_x = nn.relu(nn.GroupNorm(self.dim_out)(f_x))
        f_x = nn.Conv(features=self.dim_out, kernel_size=(self.ksize, self.ksize))(f_x)
        x = f_x + x
        return x


# Define a small ResNet for the MNIST example
class SmallResNet(nn.Module):
    dim_out: Any = 64
    ksize: Any = 3

    @nn.compact
    def __call__(self, inputs):
        x = inputs
        x = nn.Conv(features=self.dim_out, kernel_size=(self.ksize, self.ksize))(x)
        x = ResDownBlock()(x)
        x = ResDownBlock()(x)

        x = ResBlock()(x)
        x = ResBlock()(x)
        x = ResBlock()(x)
        x = ResBlock()(x)
        x = ResBlock()(x)
        x = ResBlock()(x)

        x = nn.GroupNorm(self.dim_out)(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, (1, 1))

        x = x.reshape((x.shape[0], -1))  # flatten

        x = nn.Dense(features=10)(x)
        x = nn.log_softmax(x)

        return x

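# Shape trace for a (N, 28, 28, 1) MNIST batch (assuming Flax's default SAME
# padding): the two stride-2 ResDownBlocks take 28x28 -> 14x14 -> 7x7 with 64
# channels, so the flatten yields (N, 7 * 7 * 64) = (N, 3136) before the
# 10-way Dense classifier.
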
# Define loss
@jax.jit
def cross_entropy_loss(logits, labels):
    one_hot_labels = jax.nn.one_hot(labels, num_classes=10)
    return -jnp.mean(jnp.sum(one_hot_labels * logits, axis=-1))


# Metric computation
@jax.jit
def compute_metrics(logits, labels):
    loss = cross_entropy_loss(logits=logits, labels=labels)
    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
    metrics = {
        'loss': loss,
        'accuracy': accuracy,
    }
    return metrics


def get_datasets():
    """Load MNIST train and test datasets into memory."""
    ds_builder = tfds.builder('mnist')
    ds_builder.download_and_prepare()
    train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))
    test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))
    train_ds['image'] = jnp.float32(train_ds['image']) / 255.
    test_ds['image'] = jnp.float32(test_ds['image']) / 255.
    return train_ds, test_ds


def create_train_state(rng, learning_rate):
    """Creates the initial 'TrainState'."""
    resnet = SmallResNet()
    params = resnet.init(rng, jnp.ones([1, 28, 28, 1]))['params']
    tx = optax.adam(learning_rate)
    return train_state.TrainState.create(
        apply_fn=resnet.apply, params=params, tx=tx
    )


# Training step
@jax.jit
def train_step(state, batch):
    """Train for a single step."""
    def loss_fn(params):
        logits = SmallResNet().apply({'params': params}, batch['image'])
        loss = cross_entropy_loss(logits=logits, labels=batch['label'])
        return loss, logits
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (_, logits), grads = grad_fn(state.params)
    state = state.apply_gradients(grads=grads)
    metrics = compute_metrics(logits=logits, labels=batch['label'])
    return state, metrics


# Evaluation step
@jax.jit
def eval_step(params, batch):
    logits = SmallResNet().apply({'params': params}, batch['image'])
    return compute_metrics(logits=logits, labels=batch['label'])


# Train function
def train_epoch(state, train_ds, batch_size, epoch, rng):
    """Train for a single epoch."""
    train_ds_size = len(train_ds['image'])
    steps_per_epoch = train_ds_size // batch_size

    perms = jax.random.permutation(rng, len(train_ds['image']))
    perms = perms[:steps_per_epoch * batch_size]  # skip incomplete batch
    perms = perms.reshape((steps_per_epoch, batch_size))
    batch_metrics = []
    for perm in tqdm(perms):
        batch = {k: v[perm, ...] for k, v in train_ds.items()}
        state, metrics = train_step(state, batch)
        batch_metrics.append(metrics)

    # Compute the mean of each metric across the batches in the epoch.
    batch_metrics_np = jax.device_get(batch_metrics)
    epoch_metrics_np = {
        k: np.mean([metrics[k] for metrics in batch_metrics_np])
        for k in batch_metrics_np[0]
    }
    print('train epoch: %d, loss: %.4f, accuracy: %.2f' % (
        epoch, epoch_metrics_np['loss'], epoch_metrics_np['accuracy'] * 100
    ))

    return state


# Eval function
def eval_model(params, test_ds):
    metrics = eval_step(params, test_ds)
    metrics = jax.device_get(metrics)
    summary = jax.tree_map(lambda x: x.item(), metrics)
    return summary['loss'], summary['accuracy']


def train_and_evaluate(learning_rate, n_epoch, batch_size):
    train_ds, test_ds = get_datasets()
    rng = jax.random.PRNGKey(0)
    rng, init_rng = jax.random.split(rng)

    state = create_train_state(init_rng, learning_rate)
    del init_rng  # Must not be used anymore.

    for epoch in tqdm(range(1, n_epoch + 1)):
        rng, input_rng = jax.random.split(rng)
        state = train_epoch(state, train_ds, batch_size, epoch, input_rng)
        test_loss, test_accuracy = eval_model(state.params, test_ds)
        print(' test epoch: %d, loss: %.2f, accuracy: %.2f' % (
            epoch, test_loss, test_accuracy * 100
        ))
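Note that `flax.serialization` is imported above but never used. If one wanted to persist `state.params` at the end of training, a minimal sketch could look like the following (the helper names and file path are hypothetical, not part of the repo):

```python
from flax import serialization

# Hypothetical helpers for persisting and restoring trained params.
def save_params(params, path='resnet_params.msgpack'):
    with open(path, 'wb') as f:
        f.write(serialization.to_bytes(params))

def load_params(template_params, path='resnet_params.msgpack'):
    # from_bytes needs a template with the same tree structure as the saved params.
    with open(path, 'rb') as f:
        return serialization.from_bytes(template_params, f.read())
```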