"""Tests for optax.""" |
|
|
|
from absl.testing import absltest |
|
from absl.testing import parameterized |
|
from big_vision import optax as bv_optax |
|
import chex |
|
import jax |
|
import jax.numpy as jnp |
|
import ml_collections |
|
import numpy as np |
|
import optax |
|
|
|
|
|
class OptaxTest(parameterized.TestCase):

  def test_get_count(self):
    params = jax.tree.map(jnp.array, {"a": 1.})
    tx = optax.masked(
        optax.scale_by_schedule(lambda step: step),
        {"a": True},
    )
    opt_state = tx.init(params)
    self.assertEqual(bv_optax.get_count(opt_state), 0)
    _, opt_state = tx.update(params, opt_state)
    self.assertEqual(bv_optax.get_count(opt_state), 1)

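  # The schedule config is a list of (regex, schedule) pairs; a schedule of
  # `None` freezes the matching parameters. `_split_frozen` turns the mask
  # trees produced by `_make_mask_trees` into a frozen mask plus one mask per
  # real schedule, and asserts that every parameter is covered by some pattern.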
  def test_split_frozen(self):
    params = jax.tree.map(jnp.array, {
        "Dense_0": {"kernel": 1., "bias": 2.},
    })
    sched1 = dict(decay_type="cosine")
    sched2 = dict(decay_type="linear")
    schedule = [
        (".*/kernel", sched1),
        (".*/bias", sched2),
    ]
    masks, scheds = bv_optax._make_mask_trees(params, schedule, log="schedule")
    frozen_mask, masks, scheds = bv_optax._split_frozen(masks, scheds)
    chex.assert_trees_all_equal(
        frozen_mask,
        {"Dense_0": {"kernel": False, "bias": False}},
    )
    chex.assert_trees_all_equal(
        masks,
        (
            {"Dense_0": {"kernel": True, "bias": False}},
            {"Dense_0": {"kernel": False, "bias": True}},
        ),
    )
    self.assertEqual(scheds, (sched1, sched2))

    schedule = [
        (".*/bias", None),
        ("Dense_0/.*", sched1),
        (".*", None),
    ]
    masks, scheds = bv_optax._make_mask_trees(params, schedule, log="schedule")
    frozen_mask, masks, scheds = bv_optax._split_frozen(masks, scheds)
    chex.assert_trees_all_equal(
        frozen_mask,
        {"Dense_0": {"kernel": False, "bias": True}},
    )
    chex.assert_trees_all_equal(
        masks,
        ({"Dense_0": {"kernel": True, "bias": False}},),
    )
    self.assertEqual(scheds, (sched1,))

    schedule = [
        (".*/kernel", None),
    ]
    masks, scheds = bv_optax._make_mask_trees(params, schedule, log="schedule")
    with self.assertRaisesRegex(AssertionError, "All params must be covered"):
      _ = bv_optax._split_frozen(masks, scheds)

  def test_replace_frozen(self):
    params = jax.tree.map(jnp.array, {
        "Dense_0": {"kernel": 1., "bias": 2.},
    })
    schedule = [
        (".*/kernel", {}),
        (".*", None),
    ]
    chex.assert_trees_all_equal(
        bv_optax.replace_frozen(schedule, params, 0.),
        {"Dense_0": {"kernel": 1., "bias": 0.}},
    )

  def test_make_simple(self):
    params = jax.tree.map(jnp.array, {
        "Dense_0": {"kernel": 1., "bias": 2.},
    })

    config = ml_collections.ConfigDict()
    config.lr = 0.01
    config.schedule = dict(decay_type="linear")
    config.optax_name = "scale"
    config.optax = ml_collections.ConfigDict()
    g_scale = 0.5
    config.optax.step_size = g_scale

    total_steps = 10
    sched_kw = dict(global_batch_size=1, total_steps=total_steps)
    tx, (schedule_fn,) = bv_optax.make(config, params, sched_kw=sched_kw)
    opt_state = tx.init(params)
    grads = jax.tree.map(jnp.ones_like, params)
    for step in range(total_steps):
      updates, opt_state = tx.update(grads, opt_state)
      self.assertEqual(bv_optax.get_count(opt_state), step + 1)
      sched = schedule_fn(step)
      np.testing.assert_almost_equal(
          sched, 1.0 / total_steps * (total_steps - step))
      make_tx = lambda sched: lambda g: -sched * config.lr * g_scale * g
      chex.assert_trees_all_close(updates, jax.tree.map(make_tx(sched), grads))

  def test_make_wd(self):
    params = jax.tree.map(jnp.array, {
        "Dense_0": {"kernel": 1., "bias": 2., "other": 3.},
    })
    wds = jax.tree.map(jnp.array, {
        "Dense_0": {"kernel": 2e-3, "bias": 5e-4, "other": 0.},
    })

    config = ml_collections.ConfigDict()
    config.lr = 0.01
    config.wd = 1e-3
    config.wd_mults = [
        (".*/kernel", 2.0),
        (".*/bias", 0.5),
    ]
    config.schedule = dict(decay_type="linear")
    config.optax_name = "scale"
    config.optax = ml_collections.ConfigDict()
    g_scale = 0.5
    config.optax.step_size = g_scale

    total_steps = 10
    sched_kw = dict(global_batch_size=1, total_steps=total_steps)
    tx, (sched_fn,) = bv_optax.make(config, params, sched_kw=sched_kw)
    opt_state = tx.init(params)
    grads = jax.tree.map(jnp.ones_like, params)
    for step in range(total_steps):
      updates, opt_state = tx.update(grads, opt_state, params)
      self.assertEqual(bv_optax.get_count(opt_state), step + 1)
      sched = sched_fn(step)
      np.testing.assert_almost_equal(
          sched, 1.0 / total_steps * (total_steps - step))

      def make_tx(sched):
        def inner(p, g, wd):
          return -sched * (config.lr * g_scale * g + p * wd)
        return inner

      chex.assert_trees_all_close(
          updates, jax.tree.map(make_tx(sched), params, grads, wds))

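  # With `grad_clip_norm` set, gradients are rescaled by
  # min(1, grad_clip_norm / ||g||_2) before the schedule/learning-rate scaling;
  # the test reproduces that factor by hand and compares the resulting updates.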
  def test_make_clip_norm(self):
    params = jax.tree.map(jnp.array, {
        "Dense_0": {"kernel": 1., "bias": 2., "other": 3.},
    })

    config = ml_collections.ConfigDict()
    config.lr = 0.01
    config.schedule = dict(decay_type="linear")
    config.optax_name = "scale"
    config.grad_clip_norm = 1.0
    config.optax = ml_collections.ConfigDict()
    g_scale = 0.5
    config.optax.step_size = g_scale

    total_steps = 10
    sched_kw = dict(global_batch_size=1, total_steps=total_steps)
    tx, (sched_fn,) = bv_optax.make(config, params, sched_kw=sched_kw)
    opt_state = tx.init(params)

    grads = jax.tree.map(jnp.ones_like, params)
    gflat = jax.tree.leaves(grads)
    l2_g = jnp.sqrt(sum([jnp.vdot(p, p) for p in gflat]))
    grad_clip_factor = jnp.minimum(1.0, config.grad_clip_norm / l2_g)
    grads_scaled = jax.tree.map(lambda p: grad_clip_factor * p, grads)

    for step in range(total_steps):
      updates, opt_state = tx.update(grads, opt_state)
      self.assertEqual(bv_optax.get_count(opt_state), step + 1)
      sched = sched_fn(step)
      np.testing.assert_almost_equal(
          sched, 1.0 / total_steps * (total_steps - step))
      make_tx = lambda sched: lambda g: -sched * config.lr * g_scale * g
      chex.assert_trees_all_close(updates,
                                  jax.tree.map(make_tx(sched), grads_scaled))

  def test_make_multi(self):
    params = jax.tree.map(
        jnp.array, {
            "Dense_0": {"kernel": 1.0, "bias": 2.0, "other": 3.0},
            "Dense_1": {"kernel": 4.0, "bias": 5.0, "other": 6.0},
            "Dense_2": {"kernel": 7.0, "bias": 8.0, "other": 9.0},
            "Dense_3": {"kernel": 10., "bias": 11., "other": 12.},
        })

    lrb = 0.01
    lr1 = 2.0
    lr2 = 0.5
    lr_mults = {
        "Dense_0": {"kernel": lr1, "bias": lr1, "other": lr1},
        "Dense_1": {"kernel": lr2, "bias": lr2, "other": lr2},
        "Dense_2": {"kernel": 1.0, "bias": 1.0, "other": 1.0},
        "Dense_3": {"kernel": 1.0, "bias": 1.0, "other": 1.0},
    }
    wdb = 1e-3
    wd1 = 10.0
    wd2 = 0.1
    wds = jax.tree.map(
        jnp.array, {
            "Dense_0": {"kernel": wd1 * wdb, "bias": wd2 * wdb, "other": 0.},
            "Dense_1": {"kernel": wd1 * wdb, "bias": wd2 * wdb, "other": 0.},
            "Dense_2": {"kernel": wd1 * wdb, "bias": wd2 * wdb, "other": 0.},
            "Dense_3": {"kernel": 0.0 * wdb, "bias": 0.0 * wdb, "other": 0.},
        })

    config = ml_collections.ConfigDict()
    config.lr = lrb
    config.lr_mults = [
        ("Dense_0/.*", lr1),
        ("Dense_1/.*", lr2),
    ]
    config.wd = wdb
    config.wd_mults = [
        (".*/kernel", wd1),
        (".*/bias", wd2),
    ]
    mult1 = 1.0
    mult2 = 0.1
    config.schedule = [
        ("Dense_0/.*", dict(decay_type="linear", mult=mult1, linear_end=mult1)),
        ("Dense_[12]/.*", dict(decay_type="linear", mult=mult2)),
        (".*", None),
    ]
    config.optax_name = "scale"
    config.grad_clip_norm = 1.0
    config.optax = ml_collections.ConfigDict()
    g_scale = 0.5
    config.optax.step_size = g_scale

    total_steps = 10
    sched_kw = dict(global_batch_size=1, total_steps=total_steps)
    tx, (sched_fn1,
         sched_fn2) = bv_optax.make(config, params, sched_kw=sched_kw)
    opt_state = tx.init(params)

    # Expected schedule per leaf: "Dense_3" only matches the final catch-all
    # (None) entry and is therefore frozen, i.e. its schedule is constant zero.
    frozen_fn = lambda _: jnp.array(0.)
    sched_fns = {
        "Dense_0": {"kernel": sched_fn1, "bias": sched_fn1, "other": sched_fn1},
        "Dense_1": {"kernel": sched_fn2, "bias": sched_fn2, "other": sched_fn2},
        "Dense_2": {"kernel": sched_fn2, "bias": sched_fn2, "other": sched_fn2},
        "Dense_3": {"kernel": frozen_fn, "bias": frozen_fn, "other": frozen_fn},
    }

    grads = jax.tree.map(jnp.ones_like, params)
    # Gradients of frozen leaves do not enter the expected global clipping
    # norm, so map them to 0 (the dict lookup replaces `frozen_fn` leaves)
    # before computing the clipping factor.
    gflat, _ = jax.tree.flatten(
        jax.tree.map(lambda g, sched_fn: {frozen_fn: 0}.get(sched_fn, g),
                     grads, sched_fns))
    l2_g = jnp.sqrt(sum([jnp.vdot(p, p) for p in gflat]))
    grad_clip_factor = jnp.minimum(1.0, config.grad_clip_norm / l2_g)
    grads_scaled = jax.tree.map(lambda p: grad_clip_factor * p, grads)

    def make_tx(step):
      def get_update(p, g, wd, sched_fn, lr_mult):
        return -sched_fn(step) * (lrb * lr_mult * g_scale * g + p * wd)
      return get_update

    for step in range(total_steps):
      updates, opt_state = tx.update(grads, opt_state, params)
      self.assertEqual(bv_optax.get_count(opt_state), step + 1)
      sched1, sched2 = sched_fn1(step), sched_fn2(step)
      np.testing.assert_almost_equal(sched1, mult1)
      np.testing.assert_almost_equal(sched2,
                                     mult2 * (total_steps - step) / total_steps)
      chex.assert_trees_all_close(
          updates,
          jax.tree.map(
              make_tx(step), params, grads_scaled, wds, sched_fns, lr_mults))

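  # Parameters frozen via a `None` schedule should not allocate optimizer
  # state: with the 1000-element "large" parameter frozen, the Adam state
  # gathered from opt_state must stay well under 1 kB.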
  def test_frozen_no_state(self):
    params = {"small": jnp.zeros([1]), "large": jnp.zeros([1000])}
    config = ml_collections.ConfigDict()
    config.lr = 0.01
    config.schedule = [
        ("small", dict(decay_type="cosine")),
        ("large", None),
    ]
    config.optax_name = "scale_by_adam"

    sched_kw = dict(global_batch_size=1, total_steps=1)
    tx, _ = bv_optax.make(config, params, sched_kw=sched_kw)

    opt_state = tx.init(params)
    adam_state = bv_optax.find_states(opt_state, optax.ScaleByAdamState)
    nbytes = sum(
        jax.tree.flatten(jax.tree.map(lambda x: x.nbytes, adam_state))[0])
    self.assertLess(nbytes, 1_000)

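  # Adafactor keeps factored second-moment statistics: for the 1024x1024
  # kernel this is one row vector and one column vector (2 * 1024 entries)
  # plus two scalar-sized leaves, rather than a full 1024x1024 accumulator.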
  def test_adafactor(self):
    params = {"Dense_0": {"kernel": jnp.zeros([1024, 1024])}}

    config = ml_collections.ConfigDict()
    config.optax_name = "big_vision.scale_by_adafactor"
    config.lr = 0.01
    config.schedule = dict(decay_type="linear")
    sched_kw = dict(global_batch_size=1, total_steps=1)

    tx, _ = bv_optax.make(config, params, sched_kw=sched_kw)

    opt_state = tx.init(params)
    adafactor_state = bv_optax.find_states(opt_state, optax.FactoredState)
    n_state_params = sum(
        jax.tree.flatten(
            jax.tree.map(lambda x: np.prod(
                x.shape if hasattr(x, "shape") else 0), adafactor_state))[0])
    self.assertEqual(n_state_params, 2 * 1024 + 2)


if __name__ == "__main__":
  absltest.main()