|
import jax |
|
from typing import Any, Callable, Sequence, Optional |
|
from jax import lax, random, vmap, numpy as jnp |
|
from jax.experimental.ode import odeint |
|
import flax |
|
from flax.training import train_state |
|
from flax.core import freeze, unfreeze |
|
from flax import linen as nn |
|
from flax import serialization |
|
import optax |
|
import tensorflow_datasets as tfds |
|
import numpy as np |
|
|
|
|
|
|
|
class CNN(nn.Module):
    """Small convolutional classifier returning log-probabilities.

    Two stages of 3x3 conv -> ReLU -> 2x2/stride-2 average pool (32 then 64
    filters), then a flatten, a 256-unit ReLU dense layer and a 10-way dense
    layer followed by ``log_softmax``.
    """

    @nn.compact
    def __call__(self, inputs):
        h = inputs
        # Each stage creates one Conv, so auto-names remain Conv_0, Conv_1.
        for width in (32, 64):
            h = nn.Conv(features=width, kernel_size=(3, 3))(h)
            h = nn.relu(h)
            h = nn.avg_pool(h, window_shape=(2, 2), strides=(2, 2))

        # Collapse every non-batch axis into one feature vector per example.
        h = h.reshape((h.shape[0], -1))

        h = nn.relu(nn.Dense(features=256)(h))
        h = nn.Dense(features=10)(h)
        return nn.log_softmax(h)
|
|
|
|
|
|
|
class ResBlock(nn.Module):
    """Pre-activation residual block without downsampling.

    The residual branch applies GroupNorm(64) -> ReLU -> 3x3 Conv(64) twice,
    and the result is added to the unchanged ``inputs``. The sum requires the
    branch output and ``inputs`` to have matching shapes, so ``inputs`` needs
    64 channels.
    """

    @nn.compact
    def __call__(self, inputs):
        residual = inputs
        # Two identical norm/activation/conv steps; creation order (and hence
        # auto-names GroupNorm_0, Conv_0, GroupNorm_1, Conv_1) is preserved.
        for _ in range(2):
            residual = nn.relu(nn.GroupNorm(64)(residual))
            residual = nn.Conv(features=64, kernel_size=(3, 3))(residual)
        return residual + inputs
|
|
|
class ResDownBlock(nn.Module):
    """Pre-activation residual block that halves the spatial resolution.

    Branch: GroupNorm(64) -> ReLU -> 3x3 stride-2 Conv(64) -> GroupNorm(64)
    -> ReLU -> 3x3 Conv(64). Shortcut: 1x1 stride-2 Conv(64). The block
    returns branch + shortcut.
    """

    @nn.compact
    def __call__(self, inputs):
        activated = nn.relu(nn.GroupNorm(64)(inputs))

        # NOTE(review): the projection shortcut consumes the raw ``inputs``,
        # not the normalized ``activated`` tensor -- confirm this is intended.
        shortcut = nn.Conv(features=64, kernel_size=(1, 1), strides=(2, 2))(inputs)

        branch = nn.Conv(features=64, kernel_size=(3, 3), strides=(2, 2))(activated)
        branch = nn.relu(nn.GroupNorm(64)(branch))
        branch = nn.Conv(features=64, kernel_size=(3, 3))(branch)

        return branch + shortcut
|
|
|
|
|
|
|
class SmallResNet(nn.Module):
    """Small pre-activation ResNet classifier returning log-probabilities.

    Pipeline: 3x3 Conv(64) stem -> two ``ResDownBlock``s -> six ``ResBlock``s
    -> GroupNorm(64) -> ReLU -> global average pool over H and W -> Dense(10)
    -> ``log_softmax``.

    The sub-blocks are dataclass fields, so callers may substitute their own
    modules, e.g. ``SmallResNet(resblock1=ResBlock())``.
    """

    res_down1: Callable = ResDownBlock()
    res_down2: Callable = ResDownBlock()
    resblock1: Callable = ResBlock()
    resblock2: Callable = ResBlock()
    resblock3: Callable = ResBlock()
    resblock4: Callable = ResBlock()
    resblock5: Callable = ResBlock()
    resblock6: Callable = ResBlock()

    @nn.compact
    def __call__(self, inputs):
        x = nn.Conv(features=64, kernel_size=(3, 3))(inputs)

        x = self.res_down1(x)
        x = self.res_down2(x)

        for block in (self.resblock1, self.resblock2, self.resblock3,
                      self.resblock4, self.resblock5, self.resblock6):
            x = block(x)

        x = nn.GroupNorm(64)(x)
        x = nn.relu(x)

        # Global average pool over the spatial axes. Inputs are channel-last
        # (N, H, W, C), as in ``create_train_state``'s [1, 28, 28, 1] dummy
        # batch, so H and W are axes 1 and 2. A 1x1 ``avg_pool`` window would
        # be an identity op and feed every spatial position to the classifier.
        x = jnp.mean(x, axis=(1, 2))

        x = nn.Dense(features=10)(x)
        return nn.log_softmax(x)
|
|
|
|
|
|
|
def cross_entropy_loss(*, logits, labels):
    """Return the mean negative log-likelihood of integer ``labels``.

    Args:
      logits: log-probabilities with the class axis last, e.g. the output of
        the ``log_softmax``-terminated models in this file. The value is a
        true cross-entropy only if these are already log-normalized.
      labels: integer class indices, one per leading position of ``logits``.

    Returns:
      Scalar loss averaged over all leading positions.
    """
    # Read the class count from the logits instead of hard-coding 10, so the
    # loss works for any classifier head (identical for 10-class logits).
    num_classes = logits.shape[-1]
    one_hot_labels = jax.nn.one_hot(labels, num_classes=num_classes)
    return -jnp.mean(jnp.sum(one_hot_labels * logits, axis=-1))
|
|
|
|
|
|
|
def compute_metrics(*, logits, labels):
    """Summarize a batch of predictions as scalar metrics.

    Args:
      logits: per-class scores with the class axis last.
      labels: integer class indices matching the leading axes of ``logits``.

    Returns:
      Dict with ``'loss'`` (from ``cross_entropy_loss``) and ``'accuracy'``
      (fraction of positions whose arg-max class equals the label).
    """
    predicted_class = jnp.argmax(logits, -1)
    return {
        'loss': cross_entropy_loss(logits=logits, labels=labels),
        'accuracy': jnp.mean(predicted_class == labels),
    }
|
|
|
|
|
def get_datasets():
    """Load the full MNIST train and test splits into memory.

    Each split is returned as a dict of NumPy arrays (``batch_size=-1`` loads
    the whole split at once). The ``'image'`` entry is cast to float32 and
    divided by 255.

    Returns:
      ``(train_ds, test_ds)``.
    """
    builder = tfds.builder('mnist')
    builder.download_and_prepare()

    loaded = []
    for split_name in ('train', 'test'):
        split = tfds.as_numpy(builder.as_dataset(split=split_name, batch_size=-1))
        split['image'] = jnp.float32(split['image']) / 255.
        loaded.append(split)

    train_ds, test_ds = loaded
    return train_ds, test_ds
|
|
|
|
|
def create_train_state(rng, learning_rate):
    """Build the initial ``TrainState`` for a freshly initialized SmallResNet.

    Args:
      rng: PRNG key used for parameter initialization.
      learning_rate: step size for the Adam optimizer.

    Returns:
      A ``TrainState`` holding the model's ``apply`` function, its initial
      parameters and an Adam optimizer.
    """
    model = SmallResNet()
    # A single dummy example in (N, H, W, C) layout fixes the parameter shapes.
    dummy_batch = jnp.ones([1, 28, 28, 1])
    variables = model.init(rng, dummy_batch)
    return train_state.TrainState.create(
        apply_fn=model.apply,
        params=variables['params'],
        tx=optax.adam(learning_rate),
    )
|
|
|
|
|
|
|
@jax.jit
def train_step(state, batch):
    """Apply one optimizer update to ``state`` using ``batch`` (jit-compiled).

    Args:
      state: a ``TrainState``; its ``apply_fn`` provides the forward pass.
      batch: dict with ``'image'`` inputs and integer ``'label'`` targets.

    Returns:
      ``(new_state, metrics)``. ``metrics`` is the ``compute_metrics`` dict
      for the logits computed with the parameters *before* the update.
    """
    def loss_fn(params):
        # Use the forward function stored in the state (``create_train_state``
        # sets it) rather than hard-coding SmallResNet, so a state built for
        # another module is trained with that module's own forward pass.
        logits = state.apply_fn({'params': params}, batch['image'])
        loss = cross_entropy_loss(logits=logits, labels=batch['label'])
        return loss, logits

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (_, logits), grads = grad_fn(state.params)
    state = state.apply_gradients(grads=grads)
    metrics = compute_metrics(logits=logits, labels=batch['label'])
    return state, metrics
|
|
|
|
|
|
|
@jax.jit
def eval_step(params, batch):
    """Compute metrics for a SmallResNet with ``params`` on ``batch`` (jit-compiled).

    Args:
      params: SmallResNet parameter tree.
      batch: dict with ``'image'`` inputs and integer ``'label'`` targets.

    Returns:
      The ``compute_metrics`` dict (``'loss'`` and ``'accuracy'``).
    """
    model = SmallResNet()
    variables = {'params': params}
    return compute_metrics(
        logits=model.apply(variables, batch['image']),
        labels=batch['label'],
    )
|
|
|
|
|
|
|
def train_epoch(state, train_ds, batch_size, epoch, rng):
    """Train for one shuffled pass over ``train_ds`` and print averaged metrics.

    Args:
      state: the ``TrainState`` to update.
      train_ds: dict of arrays indexed along axis 0 and sharing that length;
        ``'image'`` determines the dataset size.
      batch_size: examples per step. A trailing partial batch is dropped.
      epoch: epoch number, used only in the printed summary.
      rng: PRNG key used to shuffle the example order.

    Returns:
      The updated ``TrainState``.

    Raises:
      ValueError: if ``batch_size`` is not in ``[1, len(train_ds['image'])]``,
        since no full batch (and therefore no metric) could be produced.
    """
    train_ds_size = len(train_ds['image'])
    if batch_size <= 0 or batch_size > train_ds_size:
        raise ValueError(
            f'batch_size must be in [1, {train_ds_size}], got {batch_size}')
    steps_per_epoch = train_ds_size // batch_size

    # Shuffle once and drop the remainder. Converting the indices to NumPy up
    # front means the NumPy label array is indexed with host indices instead
    # of a device array on every step.
    perms = np.asarray(jax.random.permutation(rng, train_ds_size))
    perms = perms[:steps_per_epoch * batch_size]
    perms = perms.reshape((steps_per_epoch, batch_size))

    batch_metrics = []
    for perm in perms:
        batch = {k: v[perm, ...] for k, v in train_ds.items()}
        state, metrics = train_step(state, batch)
        batch_metrics.append(metrics)

    # Unweighted mean over steps; every batch has the same size.
    batch_metrics_np = jax.device_get(batch_metrics)
    epoch_metrics_np = {
        k: np.mean([metrics[k] for metrics in batch_metrics_np])
        for k in batch_metrics_np[0]
    }

    print('train epoch: %d, loss: %.4f, accuracy: %.2f' % (
        epoch, epoch_metrics_np['loss'], epoch_metrics_np['accuracy'] * 100
    ))

    return state
|
|
|
|
|
|
|
def eval_model(params, test_ds):
    """Evaluate ``params`` on all of ``test_ds`` with a single ``eval_step`` call.

    Args:
      params: parameter tree accepted by ``eval_step``.
      test_ds: dict with ``'image'`` and ``'label'`` arrays, passed as one batch.

    Returns:
      ``(loss, accuracy)`` as Python scalars.
    """
    metrics = jax.device_get(eval_step(params, test_ds))
    # ``jax.tree_map`` is deprecated and removed in recent JAX releases;
    # ``jax.tree_util.tree_map`` is the stable spelling.
    summary = jax.tree_util.tree_map(lambda x: x.item(), metrics)
    return summary['loss'], summary['accuracy']
|
|
|
|
|
if __name__ == '__main__':
    # Hyper-parameters for the run.
    learning_rate = 0.0001
    num_epochs = 40
    batch_size = 128

    train_ds, test_ds = get_datasets()

    # Split the root key: one half initializes the model, the other seeds
    # the per-epoch shuffling keys below.
    rng, init_rng = jax.random.split(jax.random.PRNGKey(0))
    state = create_train_state(init_rng, learning_rate)
    del init_rng  # consumed by initialization; drop it so it is not reused

    for epoch in range(1, num_epochs + 1):
        rng, input_rng = jax.random.split(rng)
        state = train_epoch(state, train_ds, batch_size, epoch, input_rng)

        test_loss, test_accuracy = eval_model(state.params, test_ds)
        print(' test epoch: %d, loss: %.2f, accuracy: %.2f' % (
            epoch, test_loss, test_accuracy * 100
        ))
|
|