sw32-seo committed
Commit 5775f48
1 Parent(s): f559f97

Initial commit

Files changed (13):
  1. LICENSE +21 -0
  2. README.md +42 -0
  3. cnn.py +219 -0
  4. cnn_ode.py +256 -0
  5. jax_cnn_ode.py +0 -0
  6. main.py +23 -0
  7. mlp.py +103 -0
  8. ode.py +256 -0
  9. opts.py +0 -0
  10. train.py +130 -0
  11. train_cnf.py +274 -0
  12. train_ode.py +271 -0
  13. train_resnet.py +195 -0
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2021 Seung-woo Eric Seo
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,42 @@
+ # Neural ODE with Flax
+ This is the result of the project ["Reproduce Neural ODE and SDE"][projectlink] in the [HuggingFace Flax/JAX community week][comweeklink].
+
+ <code>main.py</code> runs training of ResNet or OdeNet on the MNIST dataset.
+
+ [projectlink]: https://discuss.huggingface.co/t/reproduce-neural-ode-and-neural-sde/7590
+
+ [comweeklink]: https://github.com/huggingface/transformers/tree/master/examples/research_projects/jax-projects#projects
+
+ ## Dependencies
+
+ ### JAX and Flax
+
+ For JAX installation, please follow the instructions [here][jaxinstalllink],
+
+ or simply type
+ ```bash
+ pip install jax jaxlib
+ ```
+
+ For Flax installation,
+ ```bash
+ pip install flax
+ ```
+
+ [jaxinstalllink]: https://github.com/google/jax#installation
+
+
+ TensorFlow Datasets will download the MNIST dataset to your environment.
+
+ ## How to run training
+
+ For (small) ResNet training,
+ ```bash
+ python main.py --model=resnet --lr=1e-4 --n_epoch=20 --batch_size=64
+ ```
+
+ For Neural ODE training,
+ ```bash
+ python main.py --model=odenet --lr=1e-4 --n_epoch=20 --batch_size=64
+ ```
+
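Beyond the CLI, the same entry points can be called from Python. A minimal sketch, assuming the flag defaults shown above and the positional argument order that `main.py` uses when it calls `train_and_evaluate`:

```python
# Minimal sketch: invoking the training entry points directly instead of via main.py.
# Values mirror the README commands above; argument order follows how main.py calls them.
import train_ode
import train_resnet

# Neural ODE training (the last argument is the ODE solver error tolerance).
train_ode.train_and_evaluate(1e-4, 20, 64, 1e-1)

# (small) ResNet training.
train_resnet.train_and_evaluate(1e-4, 20, 64)
```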
cnn.py ADDED
@@ -0,0 +1,219 @@
+ import jax
+ from typing import Any, Callable, Sequence, Optional
+ from jax import lax, random, vmap, numpy as jnp
+ from jax.experimental.ode import odeint
+ import flax
+ from flax.training import train_state
+ from flax.core import freeze, unfreeze
+ from flax import linen as nn
+ from flax import serialization
+ import optax
+ import tensorflow_datasets as tfds
+ import numpy as np
+
+
+ # Define model
+ class CNN(nn.Module):
+     """A simple CNN model."""
+
+     @nn.compact
+     def __call__(self, inputs):
+         x = inputs
+         x = nn.Conv(features=32, kernel_size=(3, 3))(x)
+         x = nn.relu(x)
+         x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
+
+         x = nn.Conv(features=64, kernel_size=(3, 3))(x)
+         x = nn.relu(x)
+         x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
+         x = x.reshape((x.shape[0], -1))  # flatten
+
+         x = nn.Dense(features=256)(x)
+         x = nn.relu(x)
+         x = nn.Dense(features=10)(x)
+         x = nn.log_softmax(x)
+         return x
+
+
+ # Define Residual Block
+ class ResBlock(nn.Module):
+     """Single ResBlock w/o downsample"""
+
+     @nn.compact
+     def __call__(self, inputs):
+         x = inputs
+         f_x = nn.relu(nn.GroupNorm(64)(x))
+         f_x = nn.Conv(features=64, kernel_size=(3, 3))(f_x)
+         f_x = nn.relu(nn.GroupNorm(64)(f_x))
+         f_x = nn.Conv(features=64, kernel_size=(3, 3))(f_x)
+         x = f_x + x
+         return x
+
+
+ class ResDownBlock(nn.Module):
+     """Single ResBlock w/ downsample"""
+
+     @nn.compact
+     def __call__(self, inputs):
+         x = inputs
+         f_x = nn.relu(nn.GroupNorm(64)(x))
+         x = nn.Conv(features=64, kernel_size=(1, 1), strides=(2, 2))(x)
+         f_x = nn.Conv(features=64, kernel_size=(3, 3), strides=(2, 2))(f_x)
+         f_x = nn.relu(nn.GroupNorm(64)(f_x))
+         f_x = nn.Conv(features=64, kernel_size=(3, 3))(f_x)
+         x = f_x + x
+         return x
+
+
+ # Define Model for Mnist example in Neural ODE
+ class SmallResNet(nn.Module):
+     res_down1: Callable = ResDownBlock()
+     res_down2: Callable = ResDownBlock()
+     resblock1: Callable = ResBlock()
+     resblock2: Callable = ResBlock()
+     resblock3: Callable = ResBlock()
+     resblock4: Callable = ResBlock()
+     resblock5: Callable = ResBlock()
+     resblock6: Callable = ResBlock()
+
+     @nn.compact
+     def __call__(self, inputs):
+         x = inputs
+         x = nn.Conv(features=64, kernel_size=(3, 3))(x)
+         x = self.res_down1(x)
+         x = self.res_down2(x)
+
+         x = self.resblock1(x)
+         x = self.resblock2(x)
+         x = self.resblock3(x)
+         x = self.resblock4(x)
+         x = self.resblock5(x)
+         x = self.resblock6(x)
+
+         x = nn.GroupNorm(64)(x)
+         x = nn.relu(x)
+         x = nn.avg_pool(x, (1, 1))
+
+         x = x.reshape((x.shape[0], -1))  # flatten
+
+         x = nn.Dense(features=10)(x)
+         x = nn.log_softmax(x)
+
+         return x
+
+
+ # Define loss
+ def cross_entropy_loss(*, logits, labels):
+     one_hot_labels = jax.nn.one_hot(labels, num_classes=10)
+     return -jnp.mean(jnp.sum(one_hot_labels * logits, axis=-1))
+
+
+ # Metric computation
+ def compute_metrics(*, logits, labels):
+     loss = cross_entropy_loss(logits=logits, labels=labels)
+     accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
+     metrics = {
+         'loss': loss,
+         'accuracy': accuracy,
+     }
+     return metrics
+
+
+ def get_datasets():
+     """Load MNIST train and test datasets into memory."""
+     ds_builder = tfds.builder('mnist')
+     ds_builder.download_and_prepare()
+     train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))
+     test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))
+     train_ds['image'] = jnp.float32(train_ds['image']) / 255.
+     test_ds['image'] = jnp.float32(test_ds['image']) / 255.
+     return train_ds, test_ds
+
+
+ def create_train_state(rng, learning_rate):
+     """Creates initial 'TrainState'."""
+     cnn = SmallResNet()
+     params = cnn.init(rng, jnp.ones([1, 28, 28, 1]))['params']
+     tx = optax.adam(learning_rate)
+     return train_state.TrainState.create(
+         apply_fn=cnn.apply, params=params, tx=tx
+     )
+
+
+ # Training step
+ @jax.jit
+ def train_step(state, batch):
+     """Train for a single step."""
+     def loss_fn(params):
+         logits = SmallResNet().apply({'params': params}, batch['image'])
+         loss = cross_entropy_loss(logits=logits, labels=batch['label'])
+         return loss, logits
+     grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
+     (_, logits), grads = grad_fn(state.params)
+     state = state.apply_gradients(grads=grads)
+     metrics = compute_metrics(logits=logits, labels=batch['label'])
+     return state, metrics
+
+
+ # Evaluation step
+ @jax.jit
+ def eval_step(params, batch):
+     logits = SmallResNet().apply({'params': params}, batch['image'])
+     return compute_metrics(logits=logits, labels=batch['label'])
+
+
+ # Train function
+ def train_epoch(state, train_ds, batch_size, epoch, rng):
+     """Train for a single epoch"""
+     train_ds_size = len(train_ds['image'])
+     steps_per_epoch = train_ds_size // batch_size
+
+     perms = jax.random.permutation(rng, len(train_ds['image']))
+     perms = perms[:steps_per_epoch * batch_size]  # skip incomplete batch
+     perms = perms.reshape((steps_per_epoch, batch_size))
+     batch_metrics = []
+     for perm in perms:
+         batch = {k: v[perm, ...] for k, v in train_ds.items()}
+         state, metrics = train_step(state, batch)
+         batch_metrics.append(metrics)
+
+     # compute mean of metrics across each batch in epoch.
+     batch_metrics_np = jax.device_get(batch_metrics)
+     epoch_metrics_np = {
+         k: np.mean([metrics[k] for metrics in batch_metrics_np])
+         for k in batch_metrics_np[0]
+     }
+     print('train epoch: %d, loss: %.4f, accuracy: %.2f' % (
+         epoch, epoch_metrics_np['loss'], epoch_metrics_np['accuracy'] * 100
+     ))
+
+     return state
+
+
+ # Eval function
+ def eval_model(params, test_ds):
+     metrics = eval_step(params, test_ds)
+     metrics = jax.device_get(metrics)
+     summary = jax.tree_map(lambda x: x.item(), metrics)
+     return summary['loss'], summary['accuracy']
+
+
+ if __name__ == '__main__':
+     train_ds, test_ds = get_datasets()
+     rng = jax.random.PRNGKey(0)
+     rng, init_rng = jax.random.split(rng)
+
+     learning_rate = 0.0001
+
+     state = create_train_state(init_rng, learning_rate)
+     del init_rng  # Must not be used anymore.
+
+     num_epochs = 40
+     batch_size = 128
+
+     for epoch in range(1, num_epochs + 1):
+         rng, input_rng = jax.random.split(rng)
+         state = train_epoch(state, train_ds, batch_size, epoch, input_rng)
+         test_loss, test_accuracy = eval_model(state.params, test_ds)
+         print(' test epoch: %d, loss: %.2f, accuracy: %.2f' % (
+             epoch, test_loss, test_accuracy * 100
+         ))
cnn_ode.py ADDED
@@ -0,0 +1,256 @@
+ from functools import partial
+ import jax
+ from typing import Any, Callable, Sequence, Optional, NewType
+ from jax import lax, random, vmap, numpy as jnp
+ from jax.experimental.ode import odeint
+ import flax
+ from flax.training import train_state
+ from flax import traverse_util
+ from flax.core import freeze, unfreeze
+ from flax import linen as nn
+ from flax import serialization
+ import optax
+ import tensorflow_datasets as tfds
+ import numpy as np
+ from tqdm import tqdm
+ import os
+
+
+ # TODO Add system arguments for dim_out, ksize, tol, learning_rate, num_epochs and batch_size
+
+ # Define Residual Block
+ class ResDownBlock(nn.Module):
+     """Single ResBlock w/ downsample"""
+     dim_out: Any = 64
+
+     @nn.compact
+     def __call__(self, inputs):
+         x = inputs
+         f_x = nn.relu(nn.GroupNorm(self.dim_out)(x))
+         x = nn.Conv(features=self.dim_out, kernel_size=(1, 1), strides=(2, 2))(x)
+         f_x = nn.Conv(features=self.dim_out, kernel_size=(3, 3), strides=(2, 2))(f_x)
+         f_x = nn.relu(nn.GroupNorm(self.dim_out)(f_x))
+         f_x = nn.Conv(features=self.dim_out, kernel_size=(3, 3))(f_x)
+         x = f_x + x
+         return x
+
+
+ class ConcatConv2D(nn.Module):
+     """Concat dynamics to hidden layer"""
+     dim_out: Any = 64
+     ksize: Any = 3
+
+     @nn.compact
+     def __call__(self, x, t):
+         tt = jnp.ones_like(x[..., :1]) * t
+         ttx = jnp.concatenate([tt, x], -1)
+         return nn.Conv(features=self.dim_out, kernel_size=(self.ksize, self.ksize))(ttx)
+
+
+ # Define Model for Mnist example in Neural ODE
+ class ODEfunc(nn.Module):
+     """ODE function which replaces ResNet"""
+     dim_out: Any = 64
+     ksize: Any = 3
+
+     @nn.compact
+     def __call__(self, inputs, t):
+         # TODO Count number of function evaluations
+         # nfe_counter = NFEcounter()
+         # nfe_counter()
+
+         x = inputs
+         out = nn.GroupNorm(self.dim_out)(x)
+         out = nn.relu(out)
+         out = ConcatConv2D(self.dim_out, self.ksize)(out, t)
+         out = nn.GroupNorm(self.dim_out)(out)
+         out = nn.relu(out)
+         out = ConcatConv2D(self.dim_out, self.ksize)(out, t)
+         out = nn.GroupNorm(self.dim_out)(out)
+
+         return out
+
+
+ class NFEcounter(nn.Module):
+
+     @nn.compact
+     def __call__(self):
+         is_initialized = self.has_variable('nfe', 'nfe')
+         nfe = self.variable('nfe', 'nfe', jnp.array, [0])
+         if is_initialized:
+             nfe.value += 1
+
+
+ class ODEBlock(nn.Module):
+     """ODE block which contains odeint"""
+     tol = 1.
+
+     @nn.compact
+     def __call__(self, x, params):
+         ode_func = ODEfunc()
+         init_state, final_state = odeint(partial(ode_func.apply, {'params': params}),
+                                          x, jnp.array([0., 1.]),
+                                          rtol=self.tol, atol=self.tol)
+         return final_state
+
+
+ class ODEBlockVmap(nn.Module):
+     """Apply vmap to ODEBlock"""
+
+     @nn.compact
+     def __call__(self, x, params):
+         vmap_odeblock = nn.vmap(ODEBlock,
+                                 variable_axes={'params': 0, 'nfe': None},
+                                 split_rngs={'params': True, 'nfe': False},
+                                 in_axes=(0, None))
+         return vmap_odeblock(name='odeblock')(x, params)
+
+
+ class FullODENet(nn.Module):
+     """Full ODE net which contains two downsampling layers, ODE block and linear classifier."""
+     dim_out: Any = 64
+     ksize: Any = 3
+
+     @nn.compact
+     def __call__(self, inputs):
+         x = inputs
+         x = nn.Conv(features=self.dim_out, kernel_size=(self.ksize, self.ksize))(x)
+         x = ResDownBlock()(x)
+         x = ResDownBlock()(x)
+
+         ode_func = ODEfunc()
+         init_fn = lambda rng, x: ode_func.init(random.split(rng)[-1], x, 0.)['params']
+         ode_func_params = self.param('ode_func', init_fn, jnp.ones_like(x[0]))
+         x = ODEBlockVmap()(x, ode_func_params)
+
+         x = nn.GroupNorm(self.dim_out)(x)
+         x = nn.relu(x)
+         x = nn.avg_pool(x, (1, 1))
+
+         x = x.reshape((x.shape[0], -1))  # flatten
+
+         x = nn.Dense(features=10)(x)
+         x = nn.log_softmax(x)
+
+         return x
+
+
+ # Define loss
+ @jax.jit
+ def cross_entropy_loss(logits, labels):
+     one_hot_labels = jax.nn.one_hot(labels, num_classes=10)
+     return -jnp.mean(jnp.sum(one_hot_labels * logits, axis=-1))
+
+
+ # Metric computation
+ @jax.jit
+ def compute_metrics(logits, labels):
+     loss = cross_entropy_loss(logits=logits, labels=labels)
+     accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
+     metrics = {
+         'loss': loss,
+         'accuracy': accuracy,
+     }
+     return metrics
+
+
+ def get_datasets():
+     """Load MNIST train and test datasets into memory."""
+     ds_builder = tfds.builder('mnist')
+     ds_builder.download_and_prepare()
+     train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))
+     test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))
+     train_ds['image'] = jnp.float32(train_ds['image']) / 255.
+     test_ds['image'] = jnp.float32(test_ds['image']) / 255.
+     return train_ds, test_ds
+
+
+ def create_train_state(rng, learning_rate):
+     """Creates initial 'TrainState'."""
+     cnn = FullODENet()
+     params = cnn.init(rng, jnp.ones([1, 28, 28, 1]))['params']
+     tx = optax.adam(learning_rate)
+     return train_state.TrainState.create(
+         apply_fn=cnn.apply, params=params, tx=tx
+     )
+
+
+ # Training step
+ @jax.jit
+ def train_step(state, batch):
+     """Train for a single step."""
+     def loss_fn(params):
+         logits = FullODENet().apply({'params': params}, batch['image'])
+         loss = cross_entropy_loss(logits=logits, labels=batch['label'])
+         return loss, logits
+     grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
+     (_, logits), grads = grad_fn(state.params)
+     state = state.apply_gradients(grads=grads)
+     metrics = compute_metrics(logits=logits, labels=batch['label'])
+     return state, metrics
+
+
+ # Evaluation step
+ @jax.jit
+ def eval_step(params, batch):
+     logits = FullODENet().apply({'params': params}, batch['image'])
+     return compute_metrics(logits=logits, labels=batch['label'])
+
+
+ # Train function
+ def train_epoch(state, train_ds, batch_size, epoch, rng):
+     """Train for a single epoch"""
+     train_ds_size = len(train_ds['image'])
+     steps_per_epoch = train_ds_size // batch_size
+
+     perms = jax.random.permutation(rng, len(train_ds['image']))
+     perms = perms[:steps_per_epoch * batch_size]  # skip incomplete batch
+     perms = perms.reshape((steps_per_epoch, batch_size))
+     batch_metrics = []
+     for perm in tqdm(perms):
+         batch = {k: v[perm, ...] for k, v in train_ds.items()}
+         state, metrics = train_step(state, batch)
+         batch_metrics.append(metrics)
+
+     # compute mean of metrics across each batch in epoch.
+     batch_metrics_np = jax.device_get(batch_metrics)
+     epoch_metrics_np = {
+         k: np.mean([metrics[k] for metrics in batch_metrics_np])
+         for k in batch_metrics_np[0]
+     }
+     print('train epoch: %d, loss: %.4f, accuracy: %.2f' % (
+         epoch, epoch_metrics_np['loss'], epoch_metrics_np['accuracy'] * 100
+     ))
+
+     return state
+
+
+ # Eval function
+ def eval_model(params, test_ds):
+     metrics = eval_step(params, test_ds)
+     metrics = jax.device_get(metrics)
+     summary = jax.tree_map(lambda x: x.item(), metrics)
+     return summary['loss'], summary['accuracy']
+
+
+ if __name__ == '__main__':
+     train_ds, test_ds = get_datasets()
+     rng = jax.random.PRNGKey(0)
+     rng, init_rng = jax.random.split(rng)
+
+     # TODO Build learning rate decay as in the Neural ODE paper
+     learning_rate = 0.0001
+
+     state = create_train_state(init_rng, learning_rate)
+     del init_rng  # Must not be used anymore.
+
+     num_epochs = 20
+     batch_size = 128
+
+     for epoch in tqdm(range(1, num_epochs + 1)):
+         rng, input_rng = jax.random.split(rng)
+         state = train_epoch(state, train_ds, batch_size, epoch, input_rng)
+         test_loss, test_accuracy = eval_model(state.params, test_ds)
+         print(' test epoch: %d, loss: %.2f, accuracy: %.2f' % (
+             epoch, test_loss, test_accuracy * 100
+         ))
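For reference, the `ODEBlock` above leans on `jax.experimental.ode.odeint` to integrate the ODE function from t = 0 to t = 1 and keep only the final state. A minimal standalone sketch of that call pattern, using a toy dynamics function rather than the model above:

```python
# Toy sketch of the odeint call pattern used inside ODEBlock:
# integrate dy/dt = f(y, t) over t in [0, 1] and take the state at t = 1.
import jax.numpy as jnp
from jax.experimental.ode import odeint

def f(y, t):
    # simple linear dynamics dy/dt = -y
    return -y

y0 = jnp.ones((3,))
# odeint returns one state per requested time point; index [-1] is the state at t = 1.
states = odeint(f, y0, jnp.array([0., 1.]), rtol=1e-3, atol=1e-3)
y1 = states[-1]  # approximately y0 * exp(-1)
```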
jax_cnn_ode.py ADDED
File without changes
main.py ADDED
@@ -0,0 +1,23 @@
+ import argparse
+ import train_ode
+ import train_resnet
+
+
+ def main(args):
+     if args.model == 'odenet':
+         train_ode.train_and_evaluate(args.lr, args.n_epoch, args.batch_size, args.tol)
+     else:
+         train_resnet.train_and_evaluate(args.lr, args.n_epoch, args.batch_size)
+
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser(description='main.py')
+     parser.add_argument("--model", type=str, choices=['odenet', 'resnet'], default="odenet", help="Type of model")
+     parser.add_argument("--tol", type=float, default=1e-1,
+                         help="Error tolerance for ODE solver. This only works with odenet")
+     parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate")
+     parser.add_argument("--n_epoch", type=int, default=10, help="Total number of epochs")
+     parser.add_argument("--batch_size", type=int, default=32, help="Number of images in a batch")
+
+     args = parser.parse_args()
+     main(args)
mlp.py ADDED
@@ -0,0 +1,103 @@
+ import jax
+ from typing import Any, Callable, Sequence, Optional
+ from jax import lax, random, numpy as jnp
+ import flax
+ from flax.training import train_state
+ from flax.core import freeze, unfreeze
+ from flax import linen as nn
+ from flax import serialization
+ import optax
+
+
+ class ExplicitMLP(nn.Module):
+     features: Sequence[int]
+
+     def setup(self):
+         self.layers = [nn.Dense(feat) for feat in self.features]
+
+     def __call__(self, inputs):
+         x = inputs
+         for i, lyr in enumerate(self.layers):
+             x = lyr(x)
+             if i != len(self.layers) - 1:
+                 x = nn.relu(x)
+         return x
+
+
+ class SimpleMLP(nn.Module):
+     features: Sequence[int]
+
+     @nn.compact
+     def __call__(self, inputs):
+         x = inputs
+         for i, feat in enumerate(self.features):
+             x = nn.Dense(feat)(x)
+             if i != len(self.features) - 1:
+                 x = nn.relu(x)
+         return x
+
+
+ if __name__ == '__main__':
+     key1, key2 = random.split(random.PRNGKey(0), 2)
+
+     # Set problem dimensions
+     nsamples = 20
+     xdim = 10
+     ydim = 5
+
+     # Generate true W and b
+     W = random.normal(key1, (xdim, ydim))
+     b = random.normal(key2, (ydim,))
+     true_params = freeze({'params': {'bias': b, 'kernel': W}})
+
+     # Generate samples with additional noise
+     ksample, knoise = random.split(key1)
+     x_samples = random.normal(ksample, (nsamples, xdim))
+     y_samples = jnp.dot(x_samples, W) + b
+     y_samples += 0.1 * random.normal(knoise, (nsamples, ydim))  # Adding noise
+     print('x shape:', x_samples.shape, '; y shape:', y_samples.shape)
+
+     key_init, subkey = random.split(ksample, 2)
+     model = ExplicitMLP(features=[5])
+     params = model.init(subkey, x_samples)
+
+     def make_mse_func(x_batched, y_batched):
+         def mse(params):
+             # Define the squared loss for a single pair (x, y)
+             def squared_error(x, y):
+                 pred = model.apply(params, x)
+                 return jnp.inner(y - pred, y - pred) / 2.0
+
+             # We vectorize the previous to compute the average of the loss on all samples.
+             return jnp.mean(jax.vmap(squared_error)(x_batched, y_batched), axis=0)
+
+         return jax.jit(mse)  # And finally we jit the result.
+
+     # Get the sampled loss
+     loss = make_mse_func(x_samples, y_samples)
+
+     lr = 0.3
+     tx = optax.sgd(learning_rate=lr)
+     opt_state = tx.init(params)
+     loss_grad_fn = jax.value_and_grad(loss)
+
+     for i in range(101):
+         loss_val, grads = loss_grad_fn(params)
+         updates, opt_state = tx.update(grads, opt_state)
+         params = optax.apply_updates(params, updates)
+
+         if i % 10 == 0:
+             print('Loss step {}: '.format(i), loss_val)
+
+     # Serializing the result
+     bytes_output = serialization.to_bytes(params)
+     dict_output = serialization.to_state_dict(params)
+     print('Dict output')
+     print(dict_output)
+     print('Bytes output')
+     print(bytes_output)
+
+     # Restore the parameters from the saved bytes
+     saved_params = serialization.from_bytes(params, bytes_output)
+     print(loss(saved_params))
+     print(loss(params))
ode.py ADDED
@@ -0,0 +1,256 @@
+ # Copyright 2018 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     https://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """JAX-based Dormand-Prince ODE integration with adaptive stepsize.
+ Integrate systems of ordinary differential equations (ODEs) using the JAX
+ autograd/diff library and the Dormand-Prince method for adaptive integration
+ stepsize calculation. Provides improved integration accuracy over fixed
+ stepsize integration methods.
+ For details of the mixed 4th/5th order Runge-Kutta integration method, see
+ https://doi.org/10.1090/S0025-5718-1986-0815836-3
+ Adjoint algorithm based on Appendix C of https://arxiv.org/pdf/1806.07366.pdf
+ """
+
+
+ from functools import partial
+ import operator as op
+
+ import jax
+ import jax.numpy as jnp
+ from jax import core
+ from jax import custom_derivatives
+ from jax import lax
+ from jax._src.util import safe_map, safe_zip
+ from jax.flatten_util import ravel_pytree
+ from jax.tree_util import tree_map
+ from jax import linear_util as lu
+
+ map = safe_map
+ zip = safe_zip
+
+
+ def ravel_first_arg(f, unravel):
+     return ravel_first_arg_(lu.wrap_init(f), unravel).call_wrapped
+
+
+ @lu.transformation
+ def ravel_first_arg_(unravel, y_flat, *args):
+     y = unravel(y_flat)
+     ans = yield (y,) + args, {}
+     ans_flat, _ = ravel_pytree(ans)
+     yield ans_flat
+
+
+ def interp_fit_dopri(y0, y1, k, dt):
+     # Fit a polynomial to the results of a Runge-Kutta step.
+     dps_c_mid = jnp.array([
+         6025192743 / 30085553152 / 2, 0, 51252292925 / 65400821598 / 2,
+         -2691868925 / 45128329728 / 2, 187940372067 / 1594534317056 / 2,
+         -1776094331 / 19743644256 / 2, 11237099 / 235043384 / 2])
+     y_mid = y0 + dt * jnp.dot(dps_c_mid, k)
+     return jnp.asarray(fit_4th_order_polynomial(y0, y1, y_mid, k[0], k[-1], dt))
+
+
+ def fit_4th_order_polynomial(y0, y1, y_mid, dy0, dy1, dt):
+     a = -2.*dt*dy0 + 2.*dt*dy1 - 8.*y0 - 8.*y1 + 16.*y_mid
+     b = 5.*dt*dy0 - 3.*dt*dy1 + 18.*y0 + 14.*y1 - 32.*y_mid
+     c = -4.*dt*dy0 + dt*dy1 - 11.*y0 - 5.*y1 + 16.*y_mid
+     d = dt * dy0
+     e = y0
+     return a, b, c, d, e
+
+
+ def initial_step_size(fun, t0, y0, order, rtol, atol, f0):
+     # Algorithm from:
+     # E. Hairer, S. P. Norsett, G. Wanner,
+     # Solving Ordinary Differential Equations I: Nonstiff Problems, Sec. II.4.
+     scale = atol + jnp.abs(y0) * rtol
+     d0 = jnp.linalg.norm(y0 / scale)
+     d1 = jnp.linalg.norm(f0 / scale)
+
+     h0 = jnp.where((d0 < 1e-5) | (d1 < 1e-5), 1e-6, 0.01 * d0 / d1)
+
+     y1 = y0 + h0 * f0
+
+     f1 = fun(y1, t0 + h0)
+     d2 = jnp.linalg.norm((f1 - f0) / scale) / h0
+
+     h1 = jnp.where((d1 <= 1e-15) & (d2 <= 1e-15),
+                    jnp.maximum(1e-6, h0 * 1e-3),
+                    (0.01 / jnp.max(d1 + d2)) ** (1. / (order + 1.)))
+
+     return jnp.minimum(100. * h0, h1)
+
+
+ def runge_kutta_step(func, y0, f0, t0, dt):
+     # Dopri5 Butcher tableaux
+     alpha = jnp.array([1 / 5, 3 / 10, 4 / 5, 8 / 9, 1., 1., 0])
+     beta = jnp.array([
+         [1 / 5, 0, 0, 0, 0, 0, 0],
+         [3 / 40, 9 / 40, 0, 0, 0, 0, 0],
+         [44 / 45, -56 / 15, 32 / 9, 0, 0, 0, 0],
+         [19372 / 6561, -25360 / 2187, 64448 / 6561, -212 / 729, 0, 0, 0],
+         [9017 / 3168, -355 / 33, 46732 / 5247, 49 / 176, -5103 / 18656, 0, 0],
+         [35 / 384, 0, 500 / 1113, 125 / 192, -2187 / 6784, 11 / 84, 0]
+     ])
+     c_sol = jnp.array([35 / 384, 0, 500 / 1113, 125 / 192, -2187 / 6784, 11 / 84, 0])
+     c_error = jnp.array([35 / 384 - 1951 / 21600, 0, 500 / 1113 - 22642 / 50085,
+                          125 / 192 - 451 / 720, -2187 / 6784 - -12231 / 42400,
+                          11 / 84 - 649 / 6300, -1. / 60.])
+
+     def body_fun(i, k):
+         ti = t0 + dt * alpha[i-1]
+         yi = y0 + dt * jnp.dot(beta[i-1, :], k)
+         ft = func(yi, ti)
+         return k.at[i, :].set(ft)
+
+     k = jnp.zeros((7, f0.shape[0]), f0.dtype).at[0, :].set(f0)
+     k = lax.fori_loop(1, 7, body_fun, k)
+
+     y1 = dt * jnp.dot(c_sol, k) + y0
+     y1_error = dt * jnp.dot(c_error, k)
+     f1 = k[-1]
+     return y1, f1, y1_error, k
+
+
+ def abs2(x):
+     if jnp.iscomplexobj(x):
+         return x.real ** 2 + x.imag ** 2
+     else:
+         return x ** 2
+
+
+ def error_ratio(error_estimate, rtol, atol, y0, y1):
+     err_tol = atol + rtol * jnp.maximum(jnp.abs(y0), jnp.abs(y1))
+     err_ratio = error_estimate / err_tol
+     return jnp.mean(abs2(err_ratio))
+
+
+ def optimal_step_size(last_step, mean_error_ratio, safety=0.9, ifactor=10.0,
+                       dfactor=0.2, order=5.0):
+     """Compute optimal Runge-Kutta stepsize."""
+     mean_error_ratio = jnp.max(mean_error_ratio)
+     dfactor = jnp.where(mean_error_ratio < 1, 1.0, dfactor)
+
+     err_ratio = jnp.sqrt(mean_error_ratio)
+     factor = jnp.maximum(1.0 / ifactor,
+                          jnp.minimum(err_ratio**(1.0 / order) / safety, 1.0 / dfactor))
+     return jnp.where(mean_error_ratio == 0, last_step * ifactor, last_step / factor)
+
+
+ def odeint(func, y0, t, *args, rtol=1.4e-8, atol=1.4e-8, mxstep=jnp.inf):
+     """Adaptive stepsize (Dormand-Prince) Runge-Kutta odeint implementation.
+     Args:
+       func: function to evaluate the time derivative of the solution `y` at time
+         `t` as `func(y, t, *args)`, producing the same shape/structure as `y0`.
+       y0: array or pytree of arrays representing the initial value for the state.
+       t: array of float times for evaluation, like `jnp.linspace(0., 10., 101)`,
+         in which the values must be strictly increasing.
+       *args: tuple of additional arguments for `func`, which must be arrays,
+         scalars, or (nested) standard Python containers (tuples, lists, dicts,
+         namedtuples, i.e. pytrees) of those types.
+       rtol: float, relative local error tolerance for solver (optional).
+       atol: float, absolute local error tolerance for solver (optional).
+       mxstep: int, maximum number of steps to take for each timepoint (optional).
+     Returns:
+       Values of the solution `y` (i.e. integrated system values) at each time
+       point in `t`, represented as an array (or pytree of arrays) with the same
+       shape/structure as `y0` except with a new leading axis of length `len(t)`.
+     """
+     def _check_arg(arg):
+         if not isinstance(arg, core.Tracer) and not core.valid_jaxtype(arg):
+             msg = ("The contents of odeint *args must be arrays or scalars, but got "
+                    "\n{}.")
+             raise TypeError(msg.format(arg))
+
+     converted, consts = custom_derivatives.closure_convert(func, y0, t[0], *args)
+     return _odeint_wrapper(converted, rtol, atol, mxstep, y0, t, *args, *consts)
+
+
+ @partial(jax.jit, static_argnums=(0, 1, 2, 3))
+ def _odeint_wrapper(func, rtol, atol, mxstep, y0, ts, *args):
+     y0, unravel = ravel_pytree(y0)
+     func = ravel_first_arg(func, unravel)
+     out = _odeint(func, rtol, atol, mxstep, y0, ts, *args)
+     return jax.vmap(unravel)(out)
+
+
+ @partial(jax.custom_vjp, nondiff_argnums=(0, 1, 2, 3))
+ def _odeint(func, rtol, atol, mxstep, y0, ts, *args):
+     func_ = lambda y, t: func(y, t, *args)
+
+     def scan_fun(carry, target_t):
+
+         def cond_fun(state):
+             i, _, _, t, dt, _, _ = state
+             return (t < target_t) & (i < mxstep) & (dt > 0)
+
+         def body_fun(state):
+             i, y, f, t, dt, last_t, interp_coeff = state
+             next_y, next_f, next_y_error, k = runge_kutta_step(func_, y, f, t, dt)
+             next_t = t + dt
+             error_ratios = error_ratio(next_y_error, rtol, atol, y, next_y)
+             new_interp_coeff = interp_fit_dopri(y, next_y, k, dt)
+             dt = optimal_step_size(dt, error_ratios)
+
+             new = [i + 1, next_y, next_f, next_t, dt, t, new_interp_coeff]
+             old = [i + 1, y, f, t, dt, last_t, interp_coeff]
+             return map(partial(jnp.where, jnp.all(error_ratios <= 1.)), new, old)
+
+         _, *carry = lax.while_loop(cond_fun, body_fun, [0] + carry)
+         _, _, t, _, last_t, interp_coeff = carry
+         relative_output_time = (target_t - last_t) / (t - last_t)
+         y_target = jnp.polyval(interp_coeff, relative_output_time)
+         return carry, y_target
+
+     # ODEfunc with NFE counter will give auxiliary output for nfe.
+     # Below code is modified to skip that output.
+     f0 = func_(y0, ts[0])
+     dt = initial_step_size(func_, ts[0], y0, 4, rtol, atol, f0)
+     interp_coeff = jnp.array([y0] * 5)
+     init_carry = [y0, f0, ts[0], dt, ts[0], interp_coeff]
+     _, ys = lax.scan(scan_fun, init_carry, ts[1:])
+     return jnp.concatenate((y0[None], ys))
+
+
+ def _odeint_fwd(func, rtol, atol, mxstep, y0, ts, *args):
+     ys = _odeint(func, rtol, atol, mxstep, y0, ts, *args)
+     return ys, (ys, ts, args)
+
+
+ def _odeint_rev(func, rtol, atol, mxstep, res, g):
+     ys, ts, args = res
+
+     def aug_dynamics(augmented_state, t, *args):
+         """Original system augmented with vjp_y, vjp_t and vjp_args."""
+         y, y_bar, *_ = augmented_state
+         # `t` here is negative time, so we need to negate again to get back to
+         # normal time. See the `odeint` invocation in `scan_fun` below.
+         y_dot, vjpfun = jax.vjp(func, y, -t, *args)
+         return (-y_dot, *vjpfun(y_bar))
+
+     y_bar = g[-1]
+     ts_bar = []
+     t0_bar = 0.
+
+     def scan_fun(carry, i):
+         y_bar, t0_bar, args_bar = carry
+         # Compute effect of moving measurement time
+         # `t_bar` should not be complex as it represents time
+         t_bar = jnp.dot(func(ys[i], ts[i], *args), g[i]).real
+         t0_bar = t0_bar - t_bar
+         # Run augmented system backwards to previous observation
+         _, y_bar, t0_bar, args_bar = odeint(
+             aug_dynamics, (ys[i], y_bar, t0_bar, args_bar),
+             jnp.array([-ts[i], -ts[i - 1]]),
+             *args, rtol=rtol, atol=atol, mxstep=mxstep)
+         y_bar, t0_bar, args_bar = tree_map(op.itemgetter(1), (y_bar, t0_bar, args_bar))
+         # Add gradient from current output
+         y_bar = y_bar + g[i - 1]
+         return (y_bar, t0_bar, args_bar), t_bar
+
+     init_carry = (g[-1], 0., tree_map(jnp.zeros_like, args))
+     (y_bar, t0_bar, args_bar), rev_ts_bar = lax.scan(
+         scan_fun, init_carry, jnp.arange(len(ts) - 1, 0, -1))
+     ts_bar = jnp.concatenate([jnp.array([t0_bar]), rev_ts_bar[::-1]])
+     return (y_bar, ts_bar, *args_bar)
+
+
+ _odeint.defvjp(_odeint_fwd, _odeint_rev)
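A minimal usage sketch for the `odeint` defined above, following its docstring: the dynamics function takes `(y, t, *args)` and the evaluation times must be strictly increasing. The dynamics and values here are illustrative only:

```python
# Toy sketch: integrate dy/dt = a * y with the odeint defined in this file (ode.py).
import jax.numpy as jnp
# from ode import odeint  # when calling this module from elsewhere in the repo

def dynamics(y, t, a):
    return a * y

y0 = jnp.array([1.0])
ts = jnp.linspace(0., 1., 5)
ys = odeint(dynamics, y0, ts, jnp.float32(-1.0), rtol=1e-6, atol=1e-6)
# ys has shape (len(ts),) + y0.shape; ys[i] approximates y0 * exp(-ts[i])
```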
opts.py ADDED
File without changes
train.py ADDED
@@ -0,0 +1,130 @@
+ from functools import partial
+ import jax
+ from typing import Any, Callable, Sequence, Optional, NewType
+ from jax import lax, random, vmap, numpy as jnp
+ from jax.experimental.ode import odeint
+ import flax
+ from flax.training import train_state
+ from flax import traverse_util
+ from flax.core import freeze, unfreeze
+ from flax import linen as nn
+ from flax import serialization
+ import optax
+ import tensorflow_datasets as tfds
+ import numpy as np
+ from tqdm import tqdm
+ import os
+
+
+ # Define loss
+ @jax.jit
+ def cross_entropy_loss(logits, labels):
+     one_hot_labels = jax.nn.one_hot(labels, num_classes=10)
+     return -jnp.mean(jnp.sum(one_hot_labels * logits, axis=-1))
+
+
+ # Metric computation
+ @jax.jit
+ def compute_metrics(logits, labels):
+     loss = cross_entropy_loss(logits=logits, labels=labels)
+     accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
+     metrics = {
+         'loss': loss,
+         'accuracy': accuracy,
+     }
+     return metrics
+
+
+ def get_datasets():
+     """Load MNIST train and test datasets into memory."""
+     ds_builder = tfds.builder('mnist')
+     ds_builder.download_and_prepare()
+     train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))
+     test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))
+     train_ds['image'] = jnp.float32(train_ds['image']) / 255.
+     test_ds['image'] = jnp.float32(test_ds['image']) / 255.
+     return train_ds, test_ds
+
+
+ def create_train_state(model, rng, learning_rate):
+     """Creates initial 'TrainState' for the given flax module."""
+     params = model.init(rng, jnp.ones([1, 28, 28, 1]))['params']
+     tx = optax.adam(learning_rate)
+     return train_state.TrainState.create(
+         apply_fn=model.apply, params=params, tx=tx
+     )
+
+
+ # Training step
+ @jax.jit
+ def train_step(state, batch):
+     """Train for a single step."""
+     def loss_fn(params):
+         logits = state.apply_fn({'params': params}, batch['image'])
+         loss = cross_entropy_loss(logits=logits, labels=batch['label'])
+         return loss, logits
+     grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
+     (_, logits), grads = grad_fn(state.params)
+     state = state.apply_gradients(grads=grads)
+     metrics = compute_metrics(logits=logits, labels=batch['label'])
+     return state, metrics
+
+
+ # Evaluation step
+ @jax.jit
+ def eval_step(state, batch):
+     logits = state.apply_fn({'params': state.params}, batch['image'])
+     return compute_metrics(logits=logits, labels=batch['label'])
+
+
+ # Train function
+ def train_epoch(state, train_ds, batch_size, epoch, rng):
+     """Train for a single epoch"""
+     train_ds_size = len(train_ds['image'])
+     steps_per_epoch = train_ds_size // batch_size
+
+     perms = jax.random.permutation(rng, len(train_ds['image']))
+     perms = perms[:steps_per_epoch * batch_size]  # skip incomplete batch
+     perms = perms.reshape((steps_per_epoch, batch_size))
+     batch_metrics = []
+     for perm in tqdm(perms):
+         batch = {k: v[perm, ...] for k, v in train_ds.items()}
+         state, metrics = train_step(state, batch)
+         batch_metrics.append(metrics)
+
+     # compute mean of metrics across each batch in epoch.
+     batch_metrics_np = jax.device_get(batch_metrics)
+     epoch_metrics_np = {
+         k: np.mean([metrics[k] for metrics in batch_metrics_np])
+         for k in batch_metrics_np[0]
+     }
+     print('train epoch: %d, loss: %.4f, accuracy: %.2f' % (
+         epoch, epoch_metrics_np['loss'], epoch_metrics_np['accuracy'] * 100
+     ))
+
+     return state
+
+
+ # Eval function
+ def eval_model(state, test_ds):
+     metrics = eval_step(state, test_ds)
+     metrics = jax.device_get(metrics)
+     summary = jax.tree_map(lambda x: x.item(), metrics)
+     return summary['loss'], summary['accuracy']
+
+
+ def train_and_evaluate(model, learning_rate, n_epoch, batch_size):
+     """Train the given flax module and evaluate it on the MNIST test set."""
+     train_ds, test_ds = get_datasets()
+     rng = jax.random.PRNGKey(0)
+     rng, init_rng = jax.random.split(rng)
+
+     state = create_train_state(model, init_rng, learning_rate)
+     del init_rng  # Must not be used anymore.
+
+     for epoch in tqdm(range(1, n_epoch + 1)):
+         rng, input_rng = jax.random.split(rng)
+         state = train_epoch(state, train_ds, batch_size, epoch, input_rng)
+         test_loss, test_accuracy = eval_model(state, test_ds)
+         print(' test epoch: %d, loss: %.2f, accuracy: %.2f' % (
+             epoch, test_loss, test_accuracy * 100
+         ))
train_cnf.py ADDED
@@ -0,0 +1,274 @@
+ import numpy as np
+ import matplotlib.pyplot as plt
+ import os
+ import glob
+ from PIL import Image
+ from functools import partial
+ import jax
+ from typing import Any, Callable, Sequence, Optional, NewType
+ from jax import lax, random, vmap, scipy, numpy as jnp
+ # from jax.experimental.ode import odeint
+ from ode import odeint
+ import flax
+ from flax.training import train_state
+ from flax import traverse_util
+ from flax.core import freeze, unfreeze
+ from flax import linen as nn
+ from flax import serialization
+ import optax
+ from sklearn.datasets import make_circles
+ from tqdm import tqdm
+
+
+ # os.environ['TF_FORCE_UNIFIED_MEMORY'] = '1'
+ # os.environ['XLA_PYTHON_CLIENT_ALLOCATOR'] = 'platform'
+ # os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
+
+
+ class HyperNetwork(nn.Module):
+     """Hyper-network allowing f(z(t), t) to change with time.
+
+     Adapted from the PyTorch implementation at:
+     https://github.com/rtqichen/torchdiffeq/blob/master/examples/cnf.py
+     """
+     in_out_dim: Any = 2
+     hidden_dim: Any = 32
+     width: Any = 64
+
+     @nn.compact
+     def __call__(self, t):
+         # predict params
+         blocksize = self.width * self.in_out_dim
+         params = lax.expand_dims(t, (0, 1))
+         params = nn.Dense(self.hidden_dim)(params)
+         params = nn.tanh(params)
+         params = nn.Dense(self.hidden_dim)(params)
+         params = nn.tanh(params)
+         params = nn.Dense(3 * blocksize + self.width)(params)
+
+         # restructure
+         params = lax.reshape(params, (3 * blocksize + self.width,))
+         W = lax.reshape(params[:blocksize], (self.width, self.in_out_dim, 1))
+
+         U = lax.reshape(params[blocksize:2 * blocksize], (self.width, 1, self.in_out_dim))
+
+         G = lax.reshape(params[2 * blocksize:3 * blocksize], (self.width, 1, self.in_out_dim))
+         U = U * nn.sigmoid(G)
+
+         B = lax.expand_dims(params[3 * blocksize:], (1, 2))
+         return W, B, U
+
+
+ class CNF(nn.Module):
+     """Adapted from the PyTorch implementation at:
+     https://github.com/rtqichen/torchdiffeq/blob/master/examples/cnf.py
+     """
+     in_out_dim: Any = 2
+     hidden_dim: Any = 32
+     width: Any = 64
+
+     @nn.compact
+     def __call__(self, t, states):
+         z, logp_z = states[..., :2], states[..., 2:]
+         W, B, U = HyperNetwork(self.in_out_dim, self.hidden_dim, self.width)(t)
+
+         # TODO Below should be converted using vmap
+         def dzdt(z):
+             Z = lax.expand_dims(z, (0,))
+             Z = jnp.repeat(Z, self.width, 0)
+             h = nn.tanh(jnp.matmul(Z, W) + B)
+             return jnp.matmul(h, U).mean(0)
+
+         dz_dt = dzdt(z)
+         sum_dzdt = lambda z: jnp.sum(dzdt(z), 1)
+         df_dz = jax.jacrev(sum_dzdt)(z)
+         dlogp_z_dt = -1.0 * jnp.trace(df_dz, 0, 1, 2)
+
+         return lax.concatenate((dz_dt, lax.expand_dims(dlogp_z_dt, (1,))), 1)
+
+
+ class Neg_CNF(nn.Module):
+     """Negative CNF for jax's odeint."""
+     in_out_dim: Any = 2
+     hidden_dim: Any = 32
+     width: Any = 64
+
+     @nn.compact
+     def __call__(self, t, states):
+         outputs = CNF(self.in_out_dim, self.hidden_dim, self.width)(-1.0 * t, states)
+
+         return -1.0 * outputs
+
+
+ def get_batch(num_samples):
+     """Adapted from the PyTorch implementation at:
+     https://github.com/rtqichen/torchdiffeq/blob/master/examples/cnf.py
+     """
+     points, _ = make_circles(n_samples=num_samples, noise=0.06, factor=0.5)
+     x = jnp.array(points, dtype=jnp.float32)
+     logp_diff_t1 = jnp.zeros((num_samples, 1), dtype=jnp.float32)
+
+     return lax.concatenate((x, logp_diff_t1), 1)
+
+
+ def create_train_state(rng, learning_rate, in_out_dim, hidden_dim, width):
+     """Creates initial 'TrainState'."""
+     inputs = get_batch(10)
+     neg_cnf = Neg_CNF(in_out_dim, hidden_dim, width)
+     params = neg_cnf.init(rng, jnp.array(10.), inputs)['params']
+     # set_params(params)
+     tx = optax.adam(learning_rate)
+     return train_state.TrainState.create(
+         apply_fn=neg_cnf.apply, params=params, tx=tx
+     )
+
+
+ def set_params(params):
+     # Debugging helper: convert every value of params to a constant
+     params = unfreeze(params)
+     # Get flattened-key: value list.
+     flat_params = {'/'.join(k): v for k, v in traverse_util.flatten_dict(params).items()}
+     unflat_params = traverse_util.unflatten_dict({tuple(k.split('/')): 0.2 * jnp.ones_like(v) for k, v in flat_params.items()})
+     new_params = freeze(unflat_params)
+     test_x = jnp.array([[0., 1.], [2., 3.]])
+     test_log_p = jnp.zeros((2, 1))
+     test_inputs = lax.concatenate((test_x, test_log_p), 1)
+     Neg_CNF().apply({'params': new_params}, jnp.array(0.), test_inputs)
+
+
+ # @partial(jax.jit, static_argnums=(2, 3, 4, 5, 6))
+ def train_step(state, batch, in_out_dim, hidden_dim, width, t0, t1):
+     p_z0 = lambda x: scipy.stats.multivariate_normal.logpdf(x,
+                                                             mean=jnp.array([0., 0.]),
+                                                             cov=jnp.array([[0.1, 0.], [0., 0.1]]))
+
+     def loss_fn(params):
+         func = lambda states, t: Neg_CNF(in_out_dim, hidden_dim, width).apply({'params': params}, t, states)
+         outputs = odeint(
+             func,
+             batch,
+             -1.0 * jnp.array([t1, t0]),
+             atol=1e-5,
+             rtol=1e-5
+         )
+         z_t, logp_diff_t = outputs[..., :2], outputs[..., 2:]
+         z_t0, logp_diff_t0 = z_t[-1], logp_diff_t[-1]
+         logp_x = p_z0(z_t0) - lax.squeeze(logp_diff_t0, dimensions=(1,))
+         loss = -logp_x.mean(0)
+         return loss
+     grad_fn = jax.value_and_grad(loss_fn)
+     loss, grads = grad_fn(state.params)
+     state = state.apply_gradients(grads=grads)
+
+     return state, loss
+
+
+ def train(learning_rate, n_iters, batch_size, in_out_dim, hidden_dim, width, t0, t1, visual):
+     """Train the model."""
+     rng = jax.random.PRNGKey(0)
+     state = create_train_state(rng, learning_rate, in_out_dim, hidden_dim, width)
+
+     for itr in range(1, n_iters + 1):
+         batch = get_batch(batch_size)
+         state, loss = train_step(state, batch, in_out_dim, hidden_dim, width, t0, t1)
+         print("iter: %d, loss: %.2f" % (itr, loss))
+
+     if visual is True:
+         # Convert params of Neg_CNF to CNF
+         neg_params = state.params
+         neg_params = unfreeze(neg_params)
+         # Get flattened-key: value list.
+         neg_flat_params = {'/'.join(k): v for k, v in traverse_util.flatten_dict(neg_params).items()}
+         pos_flat_params = {key[6:]: jnp.array(np.array(neg_flat_params[key])) for key in list(neg_flat_params.keys())}
+         pos_unflat_params = traverse_util.unflatten_dict({tuple(k.split('/')): v for k, v in pos_flat_params.items()})
+         pos_params = freeze(pos_unflat_params)
+         output = viz(neg_params, pos_params, in_out_dim, hidden_dim, width, t0, t1)
+         z_t_samples, z_t_density, logp_diff_t, viz_timesteps, target_sample, z_t1 = output
+         create_plots(z_t_samples, z_t_density, logp_diff_t, t0, t1, viz_timesteps, target_sample, z_t1)
+
+
+ def solve_dynamics(dynamics_fn, initial_state, t):
+     def f(initial_state, t):
+         return odeint(dynamics_fn, initial_state, t, atol=1e-5, rtol=1e-5)
+     return f(initial_state, t)
+
+
+ def viz(neg_params, pos_params, in_out_dim, hidden_dim, width, t0, t1):
+     """Adapted from the PyTorch implementation."""
+     viz_samples = 5000
+     viz_timesteps = 2
+     target_sample = get_batch(viz_samples)[..., :2]
+
+     if not os.path.exists('results/'):
+         os.makedirs('results/')
+
+     z_t0 = jnp.array(np.random.multivariate_normal(mean=np.array([0., 0.]),
+                                                    cov=np.array([[0.1, 0.], [0., 0.1]]),
+                                                    size=viz_samples), dtype=jnp.float32)
+     logp_diff_t0 = jnp.zeros((viz_samples, 1), dtype=jnp.float32)
+     states_t0 = lax.concatenate((z_t0, logp_diff_t0), 1)
+
+     func_pos = lambda states, t: CNF(in_out_dim, hidden_dim, width).apply({'params': pos_params}, t, states)
+     outputs = solve_dynamics(func_pos, states_t0, jnp.linspace(t0, t1, viz_timesteps))
+     z_t_samples = outputs[..., :2]
+
+     # Generate evolution of density
+     x = jnp.linspace(-1.5, 1.5, 100)
+     y = jnp.linspace(-1.5, 1.5, 100)
+     points = np.vstack(jnp.meshgrid(x, y)).reshape([2, -1]).T
+
+     z_t1 = jnp.array(points, dtype=jnp.float32)
+     logp_diff_t1 = jnp.zeros((z_t1.shape[0], 1), dtype=jnp.float32)
+     states_t1 = lax.concatenate((z_t1, logp_diff_t1), 1)
+     func_neg = lambda states, t: Neg_CNF(in_out_dim, hidden_dim, width).apply({'params': neg_params}, -t, states)
+     outputs = solve_dynamics(func_neg, states_t1, -jnp.linspace(t1, t0, viz_timesteps))
+     z_t_density, logp_diff_t = outputs[..., :2], outputs[..., 2:]
+
+     return z_t_samples, z_t_density, logp_diff_t, viz_timesteps, target_sample, z_t1
+
+
+ def create_plots(z_t_samples, z_t_density, logp_diff_t, t0, t1, viz_timesteps, target_sample, z_t1):
+     # Create plots for each timestep
+     for (t, z_sample, z_density, logp_diff) in zip(
+             tqdm(np.linspace(t0, t1, viz_timesteps)),
+             z_t_samples, z_t_density, logp_diff_t
+     ):
+         fig = plt.figure(figsize=(12, 4), dpi=200)
+         plt.tight_layout()
+         plt.axis('off')
+         plt.margins(0, 0)
+         fig.suptitle(f'{t:.2f}s')
+
+         ax1 = fig.add_subplot(1, 3, 1)
+         ax1.set_title('Target')
+         ax1.get_xaxis().set_ticks([])
+         ax1.get_yaxis().set_ticks([])
+         ax2 = fig.add_subplot(1, 3, 2)
+         ax2.set_title('Samples')
+         ax2.get_xaxis().set_ticks([])
+         ax2.get_yaxis().set_ticks([])
+         ax3 = fig.add_subplot(1, 3, 3)
+         ax3.set_title('Log Probability')
+         ax3.get_xaxis().set_ticks([])
+         ax3.get_yaxis().set_ticks([])
+
+         ax1.hist2d(*jnp.transpose(target_sample), bins=300, density=True,
+                    range=[[-1.5, 1.5], [-1.5, 1.5]])
+
+         ax2.hist2d(*jnp.transpose(z_sample), bins=300, density=True,
+                    range=[[-1.5, 1.5], [-1.5, 1.5]])
+         p_z0 = lambda x: scipy.stats.multivariate_normal.logpdf(x,
+                                                                 mean=jnp.array([0., 0.]),
+                                                                 cov=jnp.array([[0.1, 0.], [0., 0.1]]))
+         logp = p_z0(z_density) - lax.reshape(logp_diff, (z_density.shape[0],))
+         ax3.tricontourf(*jnp.transpose(z_t1),
+                         jnp.exp(logp), 200)
+
+         plt.savefig(os.path.join('results/', f"cnf-viz-{int(t * 1000):05d}.jpg"),
+                     pad_inches=0.2, bbox_inches='tight')
+         plt.close()
+
+     img, *imgs = [Image.open(f) for f in sorted(glob.glob(os.path.join('results/', f"cnf-viz-*.jpg")))]
+     img.save(fp=os.path.join('results/', "cnf-viz.gif"), format='GIF', append_images=imgs,
+              save_all=True, duration=250, loop=0)
+
+     print('Saved visualization animation at {}'.format(os.path.join('results/', "cnf-viz.gif")))
+
+
+ if __name__ == '__main__':
+     train(0.001, 100, 512, 2, 32, 64, 0., 10., True)
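The `CNF` module above implements the instantaneous change of variables formula, d log p(z(t))/dt = -tr(∂f/∂z). As a standalone reference (independent of the hypernetwork above, with toy dynamics), one common JAX pattern for getting per-sample Jacobian traces with `jax.jacrev` looks like this:

```python
# Toy sketch of the per-sample Jacobian-trace term used by continuous normalizing flows:
# d log p(z)/dt = -trace(df/dz), computed here with jax.jacrev on a batch of 2-D points.
import jax
import jax.numpy as jnp

def f(z):
    # toy per-sample dynamics, shape (batch, 2) -> (batch, 2)
    return jnp.tanh(z) * jnp.array([1.0, -0.5])

z = jax.random.normal(jax.random.PRNGKey(0), (4, 2))

# Summing over the batch axis makes jacrev return shape (2, batch, 2);
# the diagonal over the first and last axes is the per-sample Jacobian trace.
sum_f = lambda z: jnp.sum(f(z), 0)
df_dz = jax.jacrev(sum_f)(z)             # shape (2, 4, 2)
dlogp_dt = -jnp.trace(df_dz, 0, 0, 2)    # shape (4,)
```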
train_ode.py ADDED
@@ -0,0 +1,271 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import partial
2
+ import jax
3
+ from typing import Any, Callable, Sequence, Optional, NewType
4
+ from jax import lax, random, vmap, numpy as jnp
5
+ from jax.experimental.ode import odeint
6
+ from jax.experimental import host_callback
7
+ import flax
8
+ from flax.training import train_state
9
+ from flax import traverse_util
10
+ from flax.core import freeze, unfreeze
11
+ from flax import linen as nn
12
+ from flax import serialization
13
+ import optax
14
+ import tensorflow_datasets as tfds
15
+ import numpy as np
16
+ from tqdm import tqdm
17
+ import os
18
+
19
+
20
+ # Define Residual Block
21
+ class ResDownBlock(nn.Module):
22
+ """Single ResBlock w/ downsample"""
23
+ dim_out: Any = 64
24
+
25
+ @nn.compact
26
+ def __call__(self, inputs):
27
+ x = inputs
28
+ f_x = nn.relu(nn.GroupNorm(self.dim_out)(x))
29
+ x = nn.Conv(features=self.dim_out, kernel_size=(1, 1), strides=(2, 2))(x)
30
+ f_x = nn.Conv(features=self.dim_out, kernel_size=(3, 3), strides=(2, 2))(f_x)
31
+ f_x = nn.relu(nn.GroupNorm(self.dim_out)(f_x))
32
+ f_x = nn.Conv(features=self.dim_out, kernel_size=(3, 3))(f_x)
33
+ x = f_x + x
34
+ return x
35
+
36
+
37
+ class ConcatConv2D(nn.Module):
38
+ """Concat dynamics to hidden layer"""
39
+ dim_out: Any = 64
40
+ ksize: Any = 3
41
+
42
+ @nn.compact
43
+ def __call__(self, inputs, t):
44
+ x = inputs
45
+ tt = jnp.ones_like(x[..., :1]) * t
46
+ ttx = jnp.concatenate([tt, x], -1)
47
+ return nn.Conv(features=self.dim_out, kernel_size=(self.ksize, self.ksize))(ttx)
48
+
49
+
50
+ # Define Neural ODE for mnist example.
51
+ class ODEfunc(nn.Module):
52
+ """ODE function which replace ResNet"""
53
+ dim_out: Any = 64
54
+ ksize: Any = 3
55
+
56
+ @nn.compact
57
+ def __call__(self, inputs, t):
58
+ # TODO Count number of function estimation
59
+ host_callback.call(nfecounter.count, 1)
60
+
61
+ x = inputs
62
+ out = nn.GroupNorm(self.dim_out)(x)
63
+ out = nn.relu(out)
64
+ out = ConcatConv2D(self.dim_out, self.ksize)(out, t)
65
+ out = nn.GroupNorm(self.dim_out)(out)
66
+ out = nn.relu(out)
67
+ out = ConcatConv2D(self.dim_out, self.ksize)(out, t)
68
+ out = nn.GroupNorm(self.dim_out)(out)
69
+
70
+ return out
71
+
72
+
73
+ class NFEcounter:
74
+ def __init__(self, init_nfe):
75
+ self.nfe = init_nfe
76
+
77
+ def count(self, increase):
78
+ self.nfe += increase
79
+
80
+ def set(self, target):
81
+ self.nfe = target
82
+
83
+ # Define NFE counter
84
+ nfecounter = NFEcounter(0)
85
+
86
+
87
+ class ODEBlock(nn.Module):
88
+ """ODE block which contains odeint"""
89
+ tol: Any = 1.
90
+
91
+ @nn.compact
92
+ def __call__(self, inputs, params):
93
+ ode_func = ODEfunc()
94
+ ode_func_apply = lambda x, t: ode_func.apply(variables={'params': params}, inputs=x, t=t)
95
+ init_state, final_state = odeint(ode_func_apply,
96
+ inputs, jnp.array([0., 1.]),
97
+ rtol=self.tol, atol=self.tol)
+ return final_state
+
+
+ class ODEBlockVmap(nn.Module):
+ """Apply vmap to ODEBlock"""
+ tol: Any = 1.
+
+ @nn.compact
+ def __call__(self, inputs, params):
+ x = inputs
+ vmap_odeblock = nn.vmap(ODEBlock,
+ variable_axes={'params': 0},
+ split_rngs={'params': True},
+ in_axes=(0, None))
+
+ return vmap_odeblock(tol=self.tol, name='odeblock')(x, params)
+
+
+ class FullODENet(nn.Module):
+ """Full ODE net which contains two downsampling layers, ODE block and linear classifier."""
+ dim_out: Any = 64
+ ksize: Any = 3
+ tol: Any = 1.
+
+ @nn.compact
+ def __call__(self, inputs):
+ x = inputs
+ x = nn.Conv(features=self.dim_out, kernel_size=(self.ksize, self.ksize))(x)
+ x = ResDownBlock()(x)
+ x = ResDownBlock()(x)
+
+ ode_func = ODEfunc()
+ init_fn = lambda rng, x: ode_func.init(random.split(rng)[-1], x, 0.)['params']
+ ode_func_params = self.param('ode_func', init_fn, jnp.ones_like(x[0]))
+ x = ODEBlockVmap(tol=self.tol)(x, ode_func_params)
+
+ x = nn.GroupNorm(self.dim_out)(x)
+ x = nn.relu(x)
+ x = nn.avg_pool(x, (1, 1))
+
+ x = x.reshape((x.shape[0], -1)) # flatten
+
+ x = nn.Dense(features=10)(x)
+ x = nn.log_softmax(x)
+
+ return x
+
+
+ # Define loss
+ @jax.jit
+ def cross_entropy_loss(logits, labels):
+ one_hot_labels = jax.nn.one_hot(labels, num_classes=10)
+ return -jnp.mean(jnp.sum(one_hot_labels * logits, axis=-1))
+
+
+ # Metric computation
+ @jax.jit
+ def compute_metrics(logits, labels, nfe_forward, nfe_backward):
+ loss = cross_entropy_loss(logits=logits, labels=labels)
+ accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
+ metrics = {
+ 'loss': loss,
+ 'accuracy': accuracy,
+ 'nfe_forward': nfe_forward,
+ 'nfe_backward': nfe_backward
+ }
+ return metrics
+
+
+ def get_datasets():
+ """Load MNIST train and test datasets into memory."""
+ ds_builder = tfds.builder('mnist')
+ ds_builder.download_and_prepare()
+ train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))
+ test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))
+ train_ds['image'] = jnp.float32(train_ds['image']) / 255.
+ test_ds['image'] = jnp.float32(test_ds['image']) / 255.
+ return train_ds, test_ds
+
+
+ def create_train_state(rng, learning_rate, tol):
+ """Creates initial 'TrainState'."""
+ odenet = FullODENet(tol=tol)
+ params = odenet.init(rng, jnp.ones([1, 28, 28, 1]))['params']
+ tx = optax.adam(learning_rate)
+ return train_state.TrainState.create(
+ apply_fn=odenet.apply, params=params, tx=tx
+ )
+
+
+ # Training step
+ @partial(jax.jit, static_argnums=(2,))
+ def train_step(state, batch, tol):
+ """Train for a single step."""
+ def loss_fn(params):
+ logits = FullODENet(tol=tol).apply({'params': params}, batch['image'])
+ loss = cross_entropy_loss(logits=logits, labels=batch['label'])
+ return loss, logits
+ grad_fn = jax.grad(loss_fn, has_aux=True)
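+ # Note: train_step is jit-compiled, so the Python-side reads of nfecounter.nfe
+ # below happen at trace time and reflect the counter value at tracing rather
+ # than a per-step evaluation count.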
+ host_callback.call(nfecounter.set, 0)
+ (_, logits) = loss_fn(state.params)
+ nfe_forward = nfecounter.nfe
+ host_callback.call(nfecounter.set, 0)
+ grads, _ = grad_fn(state.params)
+ nfe_backward = nfecounter.nfe
+ state = state.apply_gradients(grads=grads)
+ metrics = compute_metrics(logits=logits, labels=batch['label'],
+ nfe_forward=nfe_forward, nfe_backward=nfe_backward)
+ return state, metrics
+
+
+ # Evaluation step
+ @partial(jax.jit, static_argnums=(2,))
+ def eval_step(params, batch, tol):
+ logits = FullODENet(tol=tol).apply({'params': params}, batch['image'])
+ return compute_metrics(logits=logits, labels=batch['label'], nfe_forward=0, nfe_backward=0)
+
+
+ # Train function
+ def train_epoch(state, train_ds, batch_size, epoch, rng, tol):
+ """Train for a single epoch"""
+ train_ds_size = len(train_ds['image'])
+ steps_per_epoch = train_ds_size // batch_size
+
+ perms = jax.random.permutation(rng, len(train_ds['image']))
+ perms = perms[:steps_per_epoch * batch_size] # skip incomplete batch
+ perms = perms.reshape((steps_per_epoch, batch_size))
+ batch_metrics = []
+ for perm in tqdm(perms):
+ batch = {k: v[perm, ...] for k, v in train_ds.items()}
+ state, metrics = train_step(state, batch, tol)
+ batch_metrics.append(metrics)
+
+ # compute mean of metrics across each batch in epoch.
+ batch_metrics_np = jax.device_get(batch_metrics)
+ epoch_metrics_np = {
+ k: np.mean([metrics[k] for metrics in batch_metrics_np])
+ for k in batch_metrics_np[0]
+ }
+ print('train epoch: %d, loss: %.4f, accuracy: %.2f, nfe_forward: %d, nfe_backward: %d' % (
+ epoch, epoch_metrics_np['loss'], epoch_metrics_np['accuracy'] * 100,
+ epoch_metrics_np['nfe_forward'], epoch_metrics_np['nfe_backward']
+ ))
+
+ return state
+
+
+ # Eval function
+ def eval_model(params, test_ds, tol):
+ metrics = eval_step(params, test_ds, tol)
+ metrics = jax.device_get(metrics)
+ summary = jax.tree_map(lambda x: x.item(), metrics)
+ return summary['loss'], summary['accuracy']
+
+
+ def train_and_evaluate(learning_rate, n_epoch, batch_size, tol):
+ train_ds, test_ds = get_datasets()
+ rng = jax.random.PRNGKey(0)
+ rng, init_rng = jax.random.split(rng)
+
+ state = create_train_state(init_rng, learning_rate, tol)
+ del init_rng # Must not be used anymore.
+
+ for epoch in tqdm(range(1, n_epoch + 1)):
+ rng, input_rng = jax.random.split(rng)
+ state = train_epoch(state, train_ds, batch_size, epoch, input_rng, tol)
+ test_loss, test_accuracy = eval_model(state.params, test_ds, tol)
+ print(' test epoch: %d, loss: %.2f, accuracy: %.2f' % (
+ epoch, test_loss, test_accuracy * 100
+ ))
+
+
+ if __name__ == '__main__':
+ train_and_evaluate(0.0001, 5, 128, 1.)
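
The NFE metrics above depend on `host_callback` pushing updates from traced code to a plain Python object. Below is a minimal, self-contained sketch of that mechanism (illustration only, not part of the commit; names such as `Counter` and `f` are made up, and it assumes the `jax.experimental.host_callback` API used in this repository):

```python
import jax
import jax.numpy as jnp
from jax.experimental import host_callback


class Counter:
    """Host-side counter updated from traced code."""
    def __init__(self):
        self.n = 0

    def add(self, k):
        self.n += int(k)


counter = Counter()


def f(x):
    # The callback runs on the host each time the compiled function executes,
    # not just when it is traced.
    host_callback.call(counter.add, 1)
    return jnp.sin(x)


fast_f = jax.jit(f)
fast_f(jnp.ones(3)).block_until_ready()
fast_f(jnp.ones(3)).block_until_ready()
print(counter.n)  # expected: 2, one host call per execution
```
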
train_resnet.py ADDED
@@ -0,0 +1,195 @@
+ from functools import partial
+ import jax
+ from typing import Any, Callable, Sequence, Optional, NewType
+ from jax import lax, random, vmap, numpy as jnp
+ from jax.experimental.ode import odeint
+ import flax
+ from flax.training import train_state
+ from flax import traverse_util
+ from flax.core import freeze, unfreeze
+ from flax import linen as nn
+ from flax import serialization
+ import optax
+ import tensorflow_datasets as tfds
+ import numpy as np
+ from tqdm import tqdm
+ import os
+
+
+ # Define residual blocks
+ class ResDownBlock(nn.Module):
+ """Single ResBlock w/ downsample"""
+ dim_out: Any = 64
+
+ @nn.compact
+ def __call__(self, inputs):
+ x = inputs
+ f_x = nn.relu(nn.GroupNorm(self.dim_out)(x))
+ x = nn.Conv(features=self.dim_out, kernel_size=(1, 1), strides=(2, 2))(x)
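+ # The 1x1 strided convolution projects the shortcut so that its shape matches
+ # the downsampled residual branch before the two are summed.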
29
+ f_x = nn.Conv(features=self.dim_out, kernel_size=(3, 3), strides=(2, 2))(f_x)
30
+ f_x = nn.relu(nn.GroupNorm(self.dim_out)(f_x))
31
+ f_x = nn.Conv(features=self.dim_out, kernel_size=(3, 3))(f_x)
32
+ x = f_x + x
33
+ return x
34
+
35
+
36
+ class ResBlock(nn.Module):
37
+ """Single Resblock w/o downsample"""
38
+ dim_out: Any = 64
39
+ ksize: Any = 3
40
+
41
+ @nn.compact
42
+ def __call__(self, inputs):
43
+ x = inputs
44
+ f_x = nn.relu(nn.GroupNorm(self.dim_out)(x))
45
+ f_x = nn.Conv(features=self.dim_out, kernel_size=(self.ksize, self.ksize))(f_x)
46
+ f_x = nn.relu(nn.GroupNorm(self.dim_out)(f_x))
47
+ f_x = nn.Conv(features=self.dim_out, kernel_size=(self.ksize, self.ksize))(f_x)
48
+ x = f_x + x
49
+ return x
50
+
51
+
52
+ # Define small ResNet for Mnist example
53
+ class SmallResNet(nn.Module):
54
+ dim_out: Any = 64
55
+ ksize: Any = 3
56
+
57
+ @nn.compact
58
+ def __call__(self, inputs):
59
+ x = inputs
60
+ x = nn.Conv(features=self.dim_out, kernel_size=(self.ksize, self.ksize))(x)
61
+ x = ResDownBlock()(x)
62
+ x = ResDownBlock()(x)
63
+
64
+ x = ResBlock()(x)
65
+ x = ResBlock()(x)
66
+ x = ResBlock()(x)
67
+ x = ResBlock()(x)
68
+ x = ResBlock()(x)
69
+ x = ResBlock()(x)
70
+
71
+ x = nn.GroupNorm(self.dim_out)(x)
72
+ x = nn.relu(x)
73
+ x = nn.avg_pool(x, (1, 1))
74
+
75
+ x = x.reshape((x.shape[0], -1)) # flatten
76
+
77
+ x = nn.Dense(features=10)(x)
78
+ x = nn.log_softmax(x)
79
+
80
+ return x
+
+
+ # Define loss
+ @jax.jit
+ def cross_entropy_loss(logits, labels):
+ one_hot_labels = jax.nn.one_hot(labels, num_classes=10)
+ return -jnp.mean(jnp.sum(one_hot_labels * logits, axis=-1))
+
+
+ # Metric computation
+ @jax.jit
+ def compute_metrics(logits, labels):
+ loss = cross_entropy_loss(logits=logits, labels=labels)
+ accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
+ metrics = {
+ 'loss': loss,
+ 'accuracy': accuracy,
+ }
+ return metrics
+
+
+ def get_datasets():
+ """Load MNIST train and test datasets into memory."""
+ ds_builder = tfds.builder('mnist')
+ ds_builder.download_and_prepare()
+ train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))
+ test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))
+ train_ds['image'] = jnp.float32(train_ds['image']) / 255.
+ test_ds['image'] = jnp.float32(test_ds['image']) / 255.
+ return train_ds, test_ds
+
+
+ def create_train_state(rng, learning_rate):
+ """Creates initial 'TrainState'."""
+ resnet = SmallResNet()
+ params = resnet.init(rng, jnp.ones([1, 28, 28, 1]))['params']
+ tx = optax.adam(learning_rate)
+ return train_state.TrainState.create(
+ apply_fn=resnet.apply, params=params, tx=tx
+ )
+
+
+ # Training step
+ @jax.jit
+ def train_step(state, batch):
+ """Train for a single step."""
+ def loss_fn(params):
+ logits = SmallResNet().apply({'params': params}, batch['image'])
+ loss = cross_entropy_loss(logits=logits, labels=batch['label'])
+ return loss, logits
+ grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
+ (_, logits), grads = grad_fn(state.params)
+ state = state.apply_gradients(grads=grads)
+ metrics = compute_metrics(logits=logits, labels=batch['label'])
+ return state, metrics
+
+
+ # Evaluation step
+ @jax.jit
+ def eval_step(params, batch):
+ logits = SmallResNet().apply({'params': params}, batch['image'])
+ return compute_metrics(logits=logits, labels=batch['label'])
+
+
+ # Train function
+ def train_epoch(state, train_ds, batch_size, epoch, rng):
+ """Train for a single epoch"""
+ train_ds_size = len(train_ds['image'])
+ steps_per_epoch = train_ds_size // batch_size
+
+ perms = jax.random.permutation(rng, len(train_ds['image']))
+ perms = perms[:steps_per_epoch * batch_size] # skip incomplete batch
+ perms = perms.reshape((steps_per_epoch, batch_size))
+ batch_metrics = []
+ for perm in tqdm(perms):
+ batch = {k: v[perm, ...] for k, v in train_ds.items()}
+ state, metrics = train_step(state, batch)
+ batch_metrics.append(metrics)
+
+ # compute mean of metrics across each batch in epoch.
+ batch_metrics_np = jax.device_get(batch_metrics)
+ epoch_metrics_np = {
+ k: np.mean([metrics[k] for metrics in batch_metrics_np])
+ for k in batch_metrics_np[0]
+ }
+ print('train epoch: %d, loss: %.4f, accuracy: %.2f' % (
+ epoch, epoch_metrics_np['loss'], epoch_metrics_np['accuracy'] * 100
+ ))
+
+ return state
+
+
+ # Eval function
+ def eval_model(params, test_ds):
+ metrics = eval_step(params, test_ds)
+ metrics = jax.device_get(metrics)
+ summary = jax.tree_map(lambda x: x.item(), metrics)
+ return summary['loss'], summary['accuracy']
+
+
+ def train_and_evaluate(learning_rate, n_epoch, batch_size):
+ train_ds, test_ds = get_datasets()
+ rng = jax.random.PRNGKey(0)
+ rng, init_rng = jax.random.split(rng)
+
+ state = create_train_state(init_rng, learning_rate)
+ del init_rng # Must not be used anymore.
+
+ for epoch in tqdm(range(1, n_epoch + 1)):
+ rng, input_rng = jax.random.split(rng)
+ state = train_epoch(state, train_ds, batch_size, epoch, input_rng)
+ test_loss, test_accuracy = eval_model(state.params, test_ds)
+ print(' test epoch: %d, loss: %.2f, accuracy: %.2f' % (
+ epoch, test_loss, test_accuracy * 100
+ ))
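
Unlike the ODE script above, `train_resnet.py` does not include a `__main__` entry point in this commit. A minimal way to invoke its training loop directly might look like the following (hypothetical usage; the hyperparameter values are placeholders, not values prescribed by the commit):

```python
# Hypothetical direct invocation of the training loop in train_resnet.py.
from train_resnet import train_and_evaluate

if __name__ == '__main__':
    # Placeholder hyperparameters for illustration only.
    train_and_evaluate(learning_rate=1e-4, n_epoch=5, batch_size=128)
```
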