# Copyright 2024 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for optax."""
from absl.testing import absltest
from absl.testing import parameterized
from big_vision import optax as bv_optax
import chex
import jax
import jax.numpy as jnp
import ml_collections
import numpy as np
import optax


class OptaxTest(parameterized.TestCase):
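
  # get_count should read the step counter even when the transform is wrapped
  # (here inside optax.masked).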
  def test_get_count(self):
    params = jax.tree.map(jnp.array, {"a": 1.})
    tx = optax.masked(
        optax.scale_by_schedule(lambda step: step),
        {"a": True},
    )
    opt_state = tx.init(params)
    self.assertEqual(bv_optax.get_count(opt_state), 0)
    _, opt_state = tx.update(params, opt_state)
    self.assertEqual(bv_optax.get_count(opt_state), 1)
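
  # _split_frozen separates the boolean mask of frozen params (those whose
  # schedule is None) from the masks/schedules of the trainable groups.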
  def test_split_frozen(self):
    params = jax.tree.map(jnp.array, {
        "Dense_0": {"kernel": 1., "bias": 2.},
    })  # pyformat: disable

    sched1 = dict(decay_type="cosine")
    sched2 = dict(decay_type="linear")
    schedule = [
        (".*/kernel", sched1),
        (".*/bias", sched2),
    ]
    masks, scheds = bv_optax._make_mask_trees(params, schedule, log="schedule")
    frozen_mask, masks, scheds = bv_optax._split_frozen(masks, scheds)
    chex.assert_trees_all_equal(
        frozen_mask,
        {"Dense_0": {"kernel": False, "bias": False}},
    )  # pyformat: disable
    chex.assert_trees_all_equal(
        masks,
        (
            {"Dense_0": {"kernel": True, "bias": False}},
            {"Dense_0": {"kernel": False, "bias": True}},
        ),
    )  # pyformat: disable
    self.assertEqual(scheds, (sched1, sched2))

    # Freeze some of the params.
    schedule = [
        (".*/bias", None),
        ("Dense_0/.*", sched1),
        (".*", None),
    ]
    masks, scheds = bv_optax._make_mask_trees(params, schedule, log="schedule")
    frozen_mask, masks, scheds = bv_optax._split_frozen(masks, scheds)
    chex.assert_trees_all_equal(
        frozen_mask,
        {"Dense_0": {"kernel": False, "bias": True}},
    )  # pyformat: disable
    chex.assert_trees_all_equal(
        masks,
        ({"Dense_0": {"kernel": True, "bias": False}},),
    )  # pyformat: disable
    self.assertEqual(scheds, (sched1,))

    # A schedule that does not cover all params should fail.
    schedule = [
        (".*/kernel", None),
    ]
    masks, scheds = bv_optax._make_mask_trees(params, schedule, log="schedule")
    with self.assertRaisesRegex(AssertionError, "All params must be covered"):
      _ = bv_optax._split_frozen(masks, scheds)
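
  # replace_frozen substitutes the given value (here 0.) for every param whose
  # schedule entry is None, leaving scheduled params untouched.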
  def test_replace_frozen(self):
    params = jax.tree.map(jnp.array, {
        "Dense_0": {"kernel": 1., "bias": 2.},
    })  # pyformat: disable
    schedule = [
        (".*/kernel", {}),
        (".*", None),
    ]
    chex.assert_trees_all_equal(
        bv_optax.replace_frozen(schedule, params, 0.),
        {"Dense_0": {"kernel": 1., "bias": 0.}},
    )  # pyformat: disable
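
  # A single "scale" optimizer with a linear schedule: the expected update is
  # -sched(step) * lr * step_size * grad, with sched decaying from 1 towards 0.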
  def test_make_simple(self):
    params = jax.tree.map(jnp.array, {
        "Dense_0": {"kernel": 1., "bias": 2.},
    })  # pyformat: disable

    config = ml_collections.ConfigDict()
    config.lr = 0.01
    config.schedule = dict(decay_type="linear")
    config.optax_name = "scale"
    config.optax = ml_collections.ConfigDict()
    g_scale = 0.5
    config.optax.step_size = g_scale

    total_steps = 10
    sched_kw = dict(global_batch_size=1, total_steps=total_steps)
    tx, (schedule_fn,) = bv_optax.make(config, params, sched_kw=sched_kw)

    opt_state = tx.init(params)
    grads = jax.tree.map(jnp.ones_like, params)
    for step in range(total_steps):
      updates, opt_state = tx.update(grads, opt_state)
      self.assertEqual(bv_optax.get_count(opt_state), step + 1)
      sched = schedule_fn(step)
      np.testing.assert_almost_equal(
          sched, 1.0 / total_steps * (total_steps - step))
      make_tx = lambda sched: lambda g: -sched * config.lr * g_scale * g
      chex.assert_trees_all_close(updates, jax.tree.map(make_tx(sched), grads))
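
  # Weight-decay multipliers are matched by regex; decay is decoupled and
  # enters the expected update as an extra -sched * p * wd term on top of the
  # scaled gradient.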
  def test_make_wd(self):
    params = jax.tree.map(jnp.array, {
        "Dense_0": {"kernel": 1., "bias": 2., "other": 3.},
    })  # pyformat: disable
    wds = jax.tree.map(jnp.array, {
        "Dense_0": {"kernel": 2e-3, "bias": 5e-4, "other": 0.},
    })  # pyformat: disable

    config = ml_collections.ConfigDict()
    config.lr = 0.01
    config.wd = 1e-3
    config.wd_mults = [
        (".*/kernel", 2.0),
        (".*/bias", 0.5),
    ]
    config.schedule = dict(decay_type="linear")
    config.optax_name = "scale"
    config.optax = ml_collections.ConfigDict()
    g_scale = 0.5
    config.optax.step_size = g_scale

    total_steps = 10
    sched_kw = dict(global_batch_size=1, total_steps=total_steps)
    tx, (sched_fn,) = bv_optax.make(config, params, sched_kw=sched_kw)

    opt_state = tx.init(params)
    grads = jax.tree.map(jnp.ones_like, params)
    for step in range(total_steps):
      updates, opt_state = tx.update(grads, opt_state, params)
      self.assertEqual(bv_optax.get_count(opt_state), step + 1)
      sched = sched_fn(step)
      np.testing.assert_almost_equal(
          sched, 1.0 / total_steps * (total_steps - step))

      def make_tx(sched):
        def inner(p, g, wd):
          return -sched * (config.lr * g_scale * g + p * wd)
        return inner

      chex.assert_trees_all_close(
          updates, jax.tree.map(make_tx(sched), params, grads, wds))
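
  # Global-norm gradient clipping: grads are rescaled by
  # min(1, grad_clip_norm / ||g||_2) before the optimizer update.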
  def test_make_clip_norm(self):
    params = jax.tree.map(jnp.array, {
        "Dense_0": {"kernel": 1., "bias": 2., "other": 3.},
    })  # pyformat: disable

    config = ml_collections.ConfigDict()
    config.lr = 0.01
    config.schedule = dict(decay_type="linear")
    config.optax_name = "scale"
    config.grad_clip_norm = 1.0
    config.optax = ml_collections.ConfigDict()
    g_scale = 0.5
    config.optax.step_size = g_scale

    total_steps = 10
    sched_kw = dict(global_batch_size=1, total_steps=total_steps)
    tx, (sched_fn,) = bv_optax.make(config, params, sched_kw=sched_kw)

    opt_state = tx.init(params)
    grads = jax.tree.map(jnp.ones_like, params)
    gflat = jax.tree.leaves(grads)
    l2_g = jnp.sqrt(sum([jnp.vdot(p, p) for p in gflat]))
    grad_clip_factor = jnp.minimum(1.0, config.grad_clip_norm / l2_g)
    grads_scaled = jax.tree.map(lambda p: grad_clip_factor * p, grads)

    for step in range(total_steps):
      updates, opt_state = tx.update(grads, opt_state)
      self.assertEqual(bv_optax.get_count(opt_state), step + 1)
      sched = sched_fn(step)
      np.testing.assert_almost_equal(
          sched, 1.0 / total_steps * (total_steps - step))
      make_tx = lambda sched: lambda g: -sched * config.lr * g_scale * g
      chex.assert_trees_all_close(updates,
                                  jax.tree.map(make_tx(sched), grads_scaled))
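
  # Combines per-param lr multipliers, wd multipliers, several schedules,
  # frozen params and gradient clipping; frozen params get zero updates and
  # are excluded from the clipping norm.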
  def test_make_multi(self):
    params = jax.tree.map(
        jnp.array, {
            "Dense_0": {"kernel": 1.0, "bias": 2.0, "other": 3.0},
            "Dense_1": {"kernel": 4.0, "bias": 5.0, "other": 6.0},
            "Dense_2": {"kernel": 7.0, "bias": 8.0, "other": 9.0},
            "Dense_3": {"kernel": 10., "bias": 11., "other": 12.},
        })  # pyformat: disable

    # Manually specify lr + wd for computing expected values.
    lrb = 0.01
    lr1 = 2.0
    lr2 = 0.5
    lr_mults = {
        "Dense_0": {"kernel": lr1, "bias": lr1, "other": lr1},
        "Dense_1": {"kernel": lr2, "bias": lr2, "other": lr2},
        "Dense_2": {"kernel": 1.0, "bias": 1.0, "other": 1.0},
        "Dense_3": {"kernel": 1.0, "bias": 1.0, "other": 1.0},
    }  # pyformat: disable
    wdb = 1e-3
    wd1 = 10.0
    wd2 = 0.1
    wds = jax.tree.map(
        jnp.array, {
            "Dense_0": {"kernel": wd1 * wdb, "bias": wd2 * wdb, "other": 0.},
            "Dense_1": {"kernel": wd1 * wdb, "bias": wd2 * wdb, "other": 0.},
            "Dense_2": {"kernel": wd1 * wdb, "bias": wd2 * wdb, "other": 0.},
            "Dense_3": {"kernel": 0.0 * wdb, "bias": 0.0 * wdb, "other": 0.},
        })  # pyformat: disable

    config = ml_collections.ConfigDict()
    config.lr = lrb
    config.lr_mults = [
        ("Dense_0/.*", lr1),
        ("Dense_1/.*", lr2),
    ]
    config.wd = wdb
    config.wd_mults = [
        (".*/kernel", wd1),
        (".*/bias", wd2),
    ]
    mult1 = 1.0
    mult2 = 0.1
    config.schedule = [
        ("Dense_0/.*", dict(decay_type="linear", mult=mult1, linear_end=mult1)),
        ("Dense_[12]/.*", dict(decay_type="linear", mult=mult2)),
        (".*", None),
    ]
    config.optax_name = "scale"
    config.grad_clip_norm = 1.0
    config.optax = ml_collections.ConfigDict()
    g_scale = 0.5
    config.optax.step_size = g_scale

    total_steps = 10
    sched_kw = dict(global_batch_size=1, total_steps=total_steps)
    tx, (sched_fn1,
         sched_fn2) = bv_optax.make(config, params, sched_kw=sched_kw)
    opt_state = tx.init(params)

    # Manually specify schedules for computing expected values.
    frozen_fn = lambda _: jnp.array(0.)
    sched_fns = {
        "Dense_0": {"kernel": sched_fn1, "bias": sched_fn1, "other": sched_fn1},
        "Dense_1": {"kernel": sched_fn2, "bias": sched_fn2, "other": sched_fn2},
        "Dense_2": {"kernel": sched_fn2, "bias": sched_fn2, "other": sched_fn2},
        "Dense_3": {"kernel": frozen_fn, "bias": frozen_fn, "other": frozen_fn},
    }  # pyformat: disable
    grads = jax.tree.map(jnp.ones_like, params)
    gflat, _ = jax.tree.flatten(
        # Don't count frozen params towards gradient norm.
        jax.tree.map(lambda g, sched_fn: {frozen_fn: 0}.get(sched_fn, g),
                     grads, sched_fns))
    l2_g = jnp.sqrt(sum([jnp.vdot(p, p) for p in gflat]))
    grad_clip_factor = jnp.minimum(1.0, config.grad_clip_norm / l2_g)
    grads_scaled = jax.tree.map(lambda p: grad_clip_factor * p, grads)

    def make_tx(step):
      def get_update(p, g, wd, sched_fn, lr_mult):
        return -sched_fn(step) * (lrb * lr_mult * g_scale * g + p * wd)
      return get_update

    for step in range(total_steps):
      updates, opt_state = tx.update(grads, opt_state, params)
      self.assertEqual(bv_optax.get_count(opt_state), step + 1)
      sched1, sched2 = sched_fn1(step), sched_fn2(step)
      np.testing.assert_almost_equal(sched1, mult1)
      np.testing.assert_almost_equal(
          sched2, mult2 * (total_steps - step) / total_steps)
      chex.assert_trees_all_close(
          updates,
          jax.tree.map(
              make_tx(step), params, grads_scaled, wds, sched_fns, lr_mults))
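
  # Frozen params must not allocate optimizer state: the Adam state should
  # stay tiny even though the frozen "large" param has 1000 entries.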
  def test_frozen_no_state(self):
    params = {"small": jnp.zeros([1]), "large": jnp.zeros([1000])}

    config = ml_collections.ConfigDict()
    config.lr = 0.01
    config.schedule = [
        ("small", dict(decay_type="cosine")),
        ("large", None),
    ]
    config.optax_name = "scale_by_adam"

    sched_kw = dict(global_batch_size=1, total_steps=1)
    tx, _ = bv_optax.make(config, params, sched_kw=sched_kw)
    opt_state = tx.init(params)

    adam_state = bv_optax.find_states(opt_state, optax.ScaleByAdamState)
    nbytes = sum(
        jax.tree.flatten(jax.tree.map(lambda x: x.nbytes, adam_state))[0])
    self.assertLess(nbytes, 1_000)
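
  # Adafactor keeps factored second-moment statistics (one row and one column
  # vector per 2-D param) instead of a full-size slot, so the state for a
  # 1024x1024 kernel holds 2 * 1024 vector entries plus two scalar entries.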
  def test_adafactor(self):
    params = {"Dense_0": {"kernel": jnp.zeros([1024, 1024])}}

    config = ml_collections.ConfigDict()
    config.optax_name = "big_vision.scale_by_adafactor"
    config.lr = 0.01
    config.schedule = dict(decay_type="linear")

    sched_kw = dict(global_batch_size=1, total_steps=1)
    tx, _ = bv_optax.make(config, params, sched_kw=sched_kw)
    opt_state = tx.init(params)

    adafactor_state = bv_optax.find_states(opt_state, optax.FactoredState)
    n_state_params = sum(
        jax.tree.flatten(
            jax.tree.map(lambda x: np.prod(
                x.shape if hasattr(x, "shape") else 0), adafactor_state))[0])
    self.assertEqual(n_state_params, 2 * 1024 + 2)


if __name__ == "__main__":
  absltest.main()