|
import jax |
|
from typing import Any, Callable, Sequence, Optional |
|
from jax import lax, random, numpy as jnp |
|
import flax |
|
from flax.training import train_state |
|
from flax.core import freeze, unfreeze |
|
from flax import linen as nn |
|
from flax import serialization |
|
import optax |
|
|
|
|
|
class ExplicitMLP(nn.Module):
    """Multi-layer perceptron whose sub-layers are created in ``setup``.

    One ``nn.Dense`` layer is built per entry of ``features``. A ReLU is
    applied after every layer except the last, whose output is returned
    without an activation.
    """

    # Output width of each successive Dense layer.
    features: Sequence[int]

    def setup(self):
        self.layers = [nn.Dense(width) for width in self.features]

    def __call__(self, inputs):
        out = inputs
        depth = len(self.layers)
        # Counting from 1 makes `position < depth` true for every layer
        # except the final one, which is left un-activated.
        for position, layer in enumerate(self.layers, start=1):
            out = layer(out)
            if position < depth:
                out = nn.relu(out)
        return out
|
|
|
|
|
class SimpleMLP(nn.Module):
    """Multi-layer perceptron whose layers are declared inline via ``@nn.compact``.

    One ``nn.Dense`` layer is applied per entry of ``features``, with a ReLU
    between consecutive layers. The final layer's output is returned without
    an activation.
    """

    # Output width of each successive Dense layer.
    features: Sequence[int]

    @nn.compact
    def __call__(self, inputs):
        x = inputs
        # Index of the final layer. The previous code computed
        # len(self.features - 1), which subtracts an int from a sequence
        # and raised TypeError on every call.
        last = len(self.features) - 1
        for i, feat in enumerate(self.features):
            x = nn.Dense(feat)(x)
            if i != last:
                x = nn.relu(x)
        return x
|
|
|
|
|
if __name__ == '__main__':
    # Split one independent key per random quantity up front. The previous
    # version fed key1 to random.normal and then split it again (and did the
    # same with ksample). JAX keys must not be reused once consumed, or the
    # resulting draws are not independent.
    key_w, key_b, key_x, key_noise, key_init = random.split(
        random.PRNGKey(0), 5)

    nsamples = 20
    xdim = 10
    ydim = 5

    # Ground-truth affine map y = x @ W + b that training should recover.
    W = random.normal(key_w, (xdim, ydim))
    b = random.normal(key_b, (ydim,))

    # Synthetic regression data: standard-normal inputs and targets from the
    # ground-truth map plus Gaussian noise scaled by 0.1.
    x_samples = random.normal(key_x, (nsamples, xdim))
    y_samples = jnp.dot(x_samples, W) + b
    y_samples += 0.1 * random.normal(key_noise, (nsamples, ydim))
    print('x shape:', x_samples.shape, '; y shape:', y_samples.shape)

    # A single Dense layer (ExplicitMLP applies no activation after its last
    # layer) with one output per target dimension.
    model = ExplicitMLP(features=[ydim])
    params = model.init(key_init, x_samples)

    def make_mse_func(x_batched, y_batched):
        """Return a jitted ``loss(params)`` over the closed-over dataset.

        The loss is the mean, over the leading (sample) axis, of
        ``||y - model(x)||^2 / 2``.
        """

        def mse(params):
            def squared_error(x, y):
                pred = model.apply(params, x)
                return jnp.inner(y - pred, y - pred) / 2.0

            # vmap maps squared_error over the leading axis of both arrays.
            per_sample = jax.vmap(squared_error)(x_batched, y_batched)
            return jnp.mean(per_sample, axis=0)

        return jax.jit(mse)

    loss = make_mse_func(x_samples, y_samples)

    lr = 0.3
    tx = optax.sgd(learning_rate=lr)
    opt_state = tx.init(params)
    # Compile the combined forward+backward computation once so each
    # training step is a single compiled call.
    loss_grad_fn = jax.jit(jax.value_and_grad(loss))

    for i in range(101):
        loss_val, grads = loss_grad_fn(params)
        updates, opt_state = tx.update(grads, opt_state)
        params = optax.apply_updates(params, updates)

        if i % 10 == 0:
            print('Loss step {}: '.format(i), loss_val)

    # Serialize the trained parameters in both bytes and state-dict form.
    bytes_output = serialization.to_bytes(params)
    dict_output = serialization.to_state_dict(params)
    print('Dict output')
    print(dict_output)
    print('Bytes output')
    print(bytes_output)

    # Deserialize into the structure of `params`. The two losses printed
    # below should agree if the round trip is lossless.
    saved_params = serialization.from_bytes(params, bytes_output)
    print(loss(saved_params))
    print(loss(params))
|
|