| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | """Tests for utils.""" |
| |
|
| | from functools import partial |
| | import os |
| |
|
| | from absl.testing import parameterized |
| | from big_vision import utils |
| | import chex |
| | import flax |
| | import jax |
| | from jax.experimental.array_serialization import serialization as array_serial |
| | import jax.numpy as jnp |
| | import numpy as np |
| | import tensorflow as tf |
| |
|
| | from tensorflow.io import gfile |
| |
|
| |
|
# Number of CPU devices requested from chex in `setUpModule`; the batch sizes
# used by `PadShardUnpadTest` are expressed relative to this value.
NDEV = 4
| |
|
| |
|
def setUpModule():
  """Module-level unittest fixture: asks chex for `NDEV` CPU devices."""
  chex.set_n_cpu_devices(NDEV)
| |
|
| |
|
class PadShardUnpadTest(chex.TestCase, tf.test.TestCase):
  """Exercises `utils.pad_shard_unpad` across batch sizes and dtypes.

  The batch sizes are chosen around multiples of `NDEV` (exactly one or five
  times `NDEV`, and one more / one fewer than each). Every test wraps a small
  elementwise function and checks that the wrapped call returns the same dtype
  and the same values as the plain computation.
  """
  BATCH_SIZES = [NDEV, NDEV + 1, NDEV - 1, 5 * NDEV, 5 * NDEV + 1, 5 * NDEV - 1]
  DTYPES = [np.float32, np.uint8, jnp.bfloat16, np.int32]

  def tearDown(self):
    # Reset chex's trace counters so `assert_max_traces` starts fresh in the
    # next test.
    chex.clear_trace_counter()
    super().tearDown()

  def _assert_same_dtype_and_values(self, actual, expected, reference):
    """Asserts `actual` has `reference`'s dtype and `expected`'s values.

    Values are compared after conversion to float64 on both sides.
    """
    chex.assert_type(actual.dtype, reference.dtype)
    np.testing.assert_allclose(np.float64(actual), np.float64(expected))

  @parameterized.product(dtype=DTYPES, bs=BATCH_SIZES)
  def test_basics(self, dtype, bs):
    """Wrapped addition matches plain addition when nothing is static."""

    @partial(utils.pad_shard_unpad, static_argnums=())
    def add(a, b):
      return a + b

    inputs = jnp.arange(bs, dtype=dtype)
    outputs = add(inputs, 10 * inputs)
    self._assert_same_dtype_and_values(outputs, inputs + 10*inputs, inputs)

  @parameterized.parameters(DTYPES)
  def test_min_device_batch_avoids_recompile(self, dtype):
    """With `min_device_batch=9`, all batch sizes share a single trace."""

    # `assert_max_traces(n=1)` makes the test fail if the jitted function is
    # traced more than once across the loop below.
    @partial(utils.pad_shard_unpad, static_argnums=())
    @jax.jit
    @chex.assert_max_traces(n=1)
    def add(a, b):
      return a + b

    chex.clear_trace_counter()

    for batch_size in self.BATCH_SIZES:
      inputs = jnp.arange(batch_size, dtype=dtype)
      outputs = add(inputs, 10 * inputs, min_device_batch=9)
      self._assert_same_dtype_and_values(outputs, inputs + 10*inputs, inputs)

  @parameterized.product(dtype=DTYPES, bs=BATCH_SIZES)
  def test_static_argnum(self, dtype, bs):
    """A scalar passed at a position listed in `static_argnums` works."""

    @partial(utils.pad_shard_unpad, static_argnums=(1,))
    def add(a, b):
      return a + b

    inputs = jnp.arange(bs, dtype=dtype)
    outputs = add(inputs, dtype(10))
    self._assert_same_dtype_and_values(outputs, inputs + 10, inputs)

  @parameterized.product(dtype=DTYPES, bs=BATCH_SIZES)
  def test_static_argnames(self, dtype, bs):
    """A scalar passed via a keyword listed in `static_argnames` works."""

    # The leading positional argument `params` is a scalar here, while `a`
    # carries the batch; `b` is keyword-only and declared static by name.
    @partial(utils.pad_shard_unpad, static_argnames=('b',))
    def add(params, a, *, b):
      return params * a + b

    inputs = jnp.arange(bs, dtype=dtype)
    outputs = add(dtype(5), inputs, b=dtype(10))
    self._assert_same_dtype_and_values(outputs, 5 * inputs + 10, inputs)
| |
|
| |
|
class TreeTest(tf.test.TestCase):
  """Tests for the name-aware pytree helpers in `utils`."""

  def setUp(self):
    """Builds three fixture trees together with their expected flattenings.

    For each tree `dN`:
      * `dN_named_flat` / `dN_flat` are the (name, value) pairs / bare values
        expected from `utils._traverse_with_names`.
      * `dN_named_flat_jax` / `dN_flat_jax` are the pairs expected from
        `utils.tree_flatten_with_names` and the leaves returned by
        `jax.tree.flatten`.
    """
    super().setUp()

    # Dict with a tuple leaf container; tuple entries are named by index.
    self.d1 = {'w1': 1, 'w2': 2, 'w34': (3, 4)}
    self.d1_flat = [1, 2, 3, 4]
    self.d1_flat_jax = jax.tree.flatten(self.d1)[0]
    self.d1_named_flat = [('w1', 1), ('w2', 2), ('w34/0', 3), ('w34/1', 4)]
    self.d1_named_flat_jax = [('w1', 1), ('w2', 2), ('w34/0', 3), ('w34/1', 4)]

    # Nested dicts whose keys are deliberately not inserted alphabetically.
    self.d2 = {'conv1': {'kernel': 0, 'bias': 1},
               'conv2': {'kernel': 2, 'bias': 3}}
    self.d2_flat = [1, 0, 3, 2]
    self.d2_flat_jax = jax.tree.flatten(self.d2)[0]
    self.d2_named_flat = [('conv1/bias', 1), ('conv1/kernel', 0),
                          ('conv2/bias', 3), ('conv2/kernel', 2)]
    self.d2_named_flat_jax = [('conv1/bias', 1), ('conv1/kernel', 0),
                              ('conv2/bias', 3), ('conv2/kernel', 2)]
    # With inner nodes, each subtree follows its own leaves and the root
    # (named '') comes last.
    self.d2_named_flat_inner = [
        ('conv1/bias', 1), ('conv1/kernel', 0), ('conv1', self.d2['conv1']),
        ('conv2/bias', 3), ('conv2/kernel', 2), ('conv2', self.d2['conv2']),
        ('', self.d2),
    ]

    # The fields below are declared in non-alphabetical order on purpose. The
    # expectations encode that `_traverse_with_names` visits them
    # alphabetically (v1, v2, v3) whereas `tree_flatten_with_names` has to
    # reproduce the order of `jax.tree.flatten`, which the fixture assumes is
    # the declaration order (v3, v2, v1).
    @flax.struct.dataclass
    class FlaxStruct():
      v3: float
      v2: int
      v1: str
    self.d3 = {'a': 0, 'flax': FlaxStruct(2.0, 1, 's')}
    self.d3_flat = [0, 's', 1, 2.0]
    self.d3_flat_jax = jax.tree.flatten(self.d3)[0]
    self.d3_named_flat = [
        ('a', 0), ('flax/v1', 's'), ('flax/v2', 1), ('flax/v3', 2.0)]
    self.d3_named_flat_jax = [
        ('a', 0), ('flax/v3', 2.0), ('flax/v2', 1), ('flax/v1', 's')]

  def test_traverse_with_names(self):
    """`_traverse_with_names` yields the expected (name, value) pairs."""
    names_and_vals = list(utils._traverse_with_names(self.d1))
    self.assertEqual(names_and_vals, self.d1_named_flat)
    self.assertEqual([x for _, x in names_and_vals], self.d1_flat)

    names_and_vals = list(utils._traverse_with_names(self.d2))
    self.assertEqual(names_and_vals, self.d2_named_flat)
    self.assertEqual([x for _, x in names_and_vals], self.d2_flat)

    names_and_vals = list(utils._traverse_with_names(
        self.d2, with_inner_nodes=True))
    self.assertEqual(names_and_vals, self.d2_named_flat_inner)

    names_and_vals = list(utils._traverse_with_names(self.d3))
    self.assertEqual(names_and_vals, self.d3_named_flat)
    self.assertEqual([x for _, x in names_and_vals], self.d3_flat)

  def test_tree_flatten_with_names(self):
    """`tree_flatten_with_names` agrees with `jax.tree.flatten` on order."""
    names_and_vals = utils.tree_flatten_with_names(self.d1)[0]
    self.assertEqual(names_and_vals, self.d1_named_flat_jax)
    self.assertEqual([x for _, x in names_and_vals], self.d1_flat_jax)

    names_and_vals = utils.tree_flatten_with_names(self.d2)[0]
    self.assertEqual(names_and_vals, self.d2_named_flat_jax)
    self.assertEqual([x for _, x in names_and_vals], self.d2_flat_jax)

    names_and_vals = utils.tree_flatten_with_names(self.d3)[0]
    self.assertEqual(names_and_vals, self.d3_named_flat_jax)
    self.assertEqual([x for _, x in names_and_vals], self.d3_flat_jax)

  def test_tree_map_with_names(self):
    """`tree_map_with_names` passes the leaf name to the mapped function."""
    # Single tree: only the leaf whose name contains 'w2' is negated.
    d1 = utils.tree_map_with_names(
        lambda name, x: -x if 'w2' in name else x, self.d1)
    self.assertEqual(d1, {'w1': 1, 'w2': -2, 'w34': (3, 4)})

    # Two trees: only the 'w2' leaves are summed, others come from the first.
    d1 = utils.tree_map_with_names(
        lambda name, x1, x2: x1 + x2 if 'w2' in name else x1, self.d1, self.d1)
    self.assertEqual(d1, {'w1': 1, 'w2': 4, 'w34': (3, 4)})

  def test_recover_tree(self):
    """`recover_tree` rebuilds nested dicts from '/'-separated keys."""
    keys = ['a/b', 'a/c/x', 'a/c/y', 'd']
    values = [0, 1, 2, 3]
    self.assertEqual(utils.recover_tree(keys, values),
                     {'a': {'b': 0, 'c': {'x': 1, 'y': 2}}, 'd': 3})

  def test_make_mask_trees(self):
    """`make_mask_trees` returns one boolean mask tree per pattern."""
    F, T = False, True
    tree = {'a': {'b': 0, 'x': 1}, 'b': {'x': 2, 'y': 3}}
    msk1 = {'a': {'b': F, 'x': T}, 'b': {'x': T, 'y': F}}
    msk2 = {'a': {'b': F, 'x': F}, 'b': {'x': F, 'y': T}}
    # The expected masks encode two things: 'b/x' is set only in the first
    # mask (for '.*/x') and not in the second, and 'a/b' is not selected by
    # the 'b/.*' pattern.
    self.assertEqual(
        utils.make_mask_trees(tree, ('.*/x', 'b/.*')), [msk1, msk2])

  def test_tree_get(self):
    """`tree_get` resolves '/'-separated paths to leaves and to subtrees."""
    tree = {'a': {'b': 0, 'x': 1}, 'b': {'x': 2, 'y': 3}}
    self.assertEqual(utils.tree_get(tree, 'a/b'), 0)
    self.assertEqual(utils.tree_get(tree, 'a/x'), 1)
    self.assertEqual(utils.tree_get(tree, 'b/x'), 2)
    self.assertEqual(utils.tree_get(tree, 'b/y'), 3)
    self.assertEqual(utils.tree_get(tree, 'a'), tree['a'])
    self.assertEqual(utils.tree_get(tree, 'b'), tree['b'])

  def test_tree_replace(self):
    """`tree_replace` renames, re-nests and drops entries by regex."""
    tree = {'a': {'b': 2, 'c': 3}, 'c': 4}
    # Mapping from name pattern to replacement; `None` is expected to remove
    # the matching entry ('a/c' is absent from the expected result).
    replacements = {
        'a/b': 'a/b/x',
        '.*c': 'C',
        'C': 'D',
        '.*/c': None,
    }
    tree2 = utils.tree_replace(tree, replacements)
    self.assertEqual(tree2, {'D': 4, 'a': {'b': {'x': 2}}})

  def test_tree_compare(self):
    """`tree_compare` reports unique keys and dtype/shape mismatches."""
    tree1_only, tree2_only, dtype_shape_mismatch = utils.tree_compare(
        {'a': {'b': jnp.array(2), 'c': jnp.array(3)}},
        {'a': {'B': jnp.array(2), 'c': jnp.array(3.)}},
    )
    self.assertEqual(tree1_only, {'a/b'})
    self.assertEqual(tree2_only, {'a/B'})
    # 'a/c' exists in both trees but is an int scalar in one and a float
    # scalar in the other.
    self.assertEqual(
        dtype_shape_mismatch,
        {'a/c': [(jnp.dtype('int32'), ()), (jnp.dtype('float32'), ())]})
| |
|
| |
|
class StepConversionTest(parameterized.TestCase, tf.test.TestCase):
  """Tests `utils.steps`, which turns a config entry into a step count."""

  @parameterized.named_parameters(
      ('nice_steps', 1000, None, None, dict(foo_steps=3), 3),
      ('nice_epochs', 1000, 100, None, dict(foo_epochs=3), 30),
      ('nice_examples', None, 100, None, dict(foo_examples=300), 3),
      ('nice_percent', None, None, 10, dict(foo_percent=0.30), 3),
      ('ignore_neg', 1000, 100, 10, dict(foo_steps=-1, foo_epochs=-1,
                                         foo_examples=-1, foo_percent=0.30), 3),
      ('zero_steps', None, None, 10, dict(foo_percent=0.0), 0),
      ('offbyone_steps', 1001, None, None, dict(foo_steps=3), 3),
      ('offbyone_epochs', 1001, 100, None, dict(foo_epochs=3), 30),
      ('offbyone_examples', None, 101, None, dict(foo_examples=300), 3),
      ('offbyone_percent', None, None, 11, dict(foo_percent=0.30), 3),
  )
  def test_steps(self, data_size, batch_size, total, cfg, expected):
    """Checks the conversion for a present entry and for a missing one."""
    sizes = dict(data_size=data_size, batch_size=batch_size, total_steps=total)

    # Entry present in the config under the 'foo' prefix.
    self.assertEqual(utils.steps('foo', cfg, **sizes), expected)

    # Entry absent from the config: an error without a default, the default
    # value otherwise.
    with self.assertRaises(ValueError):
      utils.steps('bar', cfg, **sizes)
    self.assertEqual(utils.steps('bar', cfg, default=1234, **sizes), 1234)
| |
|
| |
|
class CreateLearningRateScheduleTest(parameterized.TestCase, tf.test.TestCase):
  """Spot-checks `utils.create_learning_rate_schedule` at single steps."""

  @parameterized.named_parameters(
      ('linear', 'linear', {}, 13, .5),
      ('polynomial', 'polynomial', {'end': .1, 'power': 2}, 13, .325),
      ('cosine', 'cosine', {}, 13, .5),
      ('rsqrt', 'rsqrt', {'timescale': 1}, 13, 0.3333333),
      ('stair_5', 'stair', {'steps': [10], 'mults': [.5]}, 5, 1.),
      ('stair_10', 'stair', {'steps': [10], 'mults': [.5]}, 10, .5),
      ('warmup_before', 'rsqrt', {'timescale': 1}, 3, .6),
      ('cooldown_after', 'rsqrt', {'timescale': 1}, 20, .05),
  )
  def test_schedule(self, decay_type, extra_kwargs, step, expected_lr):
    """Evaluates one schedule at `step` and compares with `expected_lr`.

    Every case uses the same 21-step schedule with 5 warmup and 5 cooldown
    steps; only the decay type and its extra keyword arguments vary.
    """
    schedule = utils.create_learning_rate_schedule(
        total_steps=21,
        batch_size=512,
        base=.5,
        decay_type=decay_type,
        scale_with_batchsize=True,
        warmup_steps=5,
        cooldown_steps=5,
        **extra_kwargs)
    self.assertAlmostEqual(schedule(step), expected_lr)
| |
|
| |
|
class CheckpointTest(tf.test.TestCase):
  """Round-trip tests for the checkpoint save/load helpers in `utils`."""

  def setup(self):
    """Creates a checkpoint manager, a save path, a tree and its shardings.

    This is a plain helper that each test calls explicitly; it is not
    unittest's `setUp` hook.

    Returns:
      A `(gacm, save_path, ckpt, shardings)` tuple: the async checkpoint
      manager, the checkpoint path inside a fresh temp dir, a small nested
      tree of CPU arrays, and a matching tree that maps every leaf to a
      single-device sharding on the first local CPU device.
    """
    manager = array_serial.GlobalAsyncCheckpointManager()
    path = os.path.join(self.create_tempdir('workdir'), 'checkpoint.bv')

    tree = {
        'x': utils.put_cpu(np.array([1, 2, 3, 4])),
        'y': {'z': utils.put_cpu(np.array([5, 6, 7, 8]))},
    }

    first_cpu = jax.local_devices(backend='cpu')[0]
    sharding = jax.sharding.SingleDeviceSharding(first_cpu)
    tree_shardings = jax.tree.map(lambda _: sharding, tree)

    return manager, path, tree, tree_shardings

  def test_save_and_load(self):
    """A saved tree loads back unchanged, both via pointer and via step."""
    gacm, save_path, ckpt, shardings = self.setup()
    step = 100
    utils.save_checkpoint_ts(gacm, ckpt, save_path, step, keep=True)
    gacm.wait_until_finished()

    # Load through the checkpoint path itself.
    restored = utils.load_checkpoint_ts(
        save_path, tree=ckpt, shardings=shardings)
    chex.assert_trees_all_equal(restored, ckpt)

    # Load the step-specific checkpoint directly.
    restored_at_step = utils.tsload(
        f'{save_path}-{step:09d}', shardings=shardings)
    chex.assert_trees_all_equal(restored_at_step, ckpt)

  def test_save_and_partial_load(self):
    """Loading with a tree that lacks 'x' restores only the remainder."""
    gacm, save_path, ckpt, shardings = self.setup()
    utils.save_checkpoint_ts(gacm, ckpt, save_path, step=100)
    gacm.wait_until_finished()

    # Drop 'x' from both the target tree and its shardings after saving.
    shardings.pop('x')
    ckpt.pop('x')
    restored = utils.load_checkpoint_ts(
        save_path, tree=ckpt, shardings=shardings)
    chex.assert_trees_all_equal(restored, ckpt)

  def test_save_and_cpu_load(self):
    """Loading without `tree`/`shardings` still returns the full tree."""
    gacm, save_path, ckpt, _ = self.setup()
    utils.save_checkpoint_ts(gacm, ckpt, save_path, step=100)
    gacm.wait_until_finished()
    restored = utils.load_checkpoint_ts(save_path)
    chex.assert_trees_all_equal(restored, ckpt)

  def test_save_and_partial_cpu_load(self):
    """Loading with `regex='x.*'` returns only the 'x' entry."""
    gacm, save_path, ckpt, _ = self.setup()
    utils.save_checkpoint_ts(gacm, ckpt, save_path, step=100)
    gacm.wait_until_finished()
    ckpt.pop('y')
    restored = utils.load_checkpoint_ts(save_path, regex='x.*')
    chex.assert_trees_all_equal(restored, ckpt)

  def test_keep_deletes(self):
    """Checks which checkpoints remain after saves with mixed `keep` flags."""

    def scaled(tree, factor):
      # Multiplies every leaf so that each saved step is distinguishable.
      return jax.tree.map(lambda leaf: leaf * factor, tree)

    gacm, save_path, ckpt, _ = self.setup()
    utils.save_checkpoint_ts(gacm, ckpt, save_path, step=100, keep=False)
    utils.save_checkpoint_ts(
        gacm, scaled(ckpt, 2), save_path, step=200, keep=True)
    utils.save_checkpoint_ts(
        gacm, scaled(ckpt, 3), save_path, step=300, keep=False)
    gacm.wait_until_finished()

    # The kept step is available under its final name.
    at_200 = utils.tsload(f'{save_path}-{200:09d}')
    chex.assert_trees_all_equal(at_200, scaled(ckpt, 2))

    # The latest non-kept step lives under a '-tmp' suffix.
    at_300 = utils.tsload(f'{save_path}-{300:09d}-tmp')
    chex.assert_trees_all_equal(at_300, scaled(ckpt, 3))

    # Loading via the checkpoint path yields the most recent save.
    latest = utils.load_checkpoint_ts(save_path)
    chex.assert_trees_all_equal(latest, scaled(ckpt, 3))

    # The earlier non-kept step is gone under both names.
    with self.assertRaises(Exception):
      _ = utils.tsload(f'{save_path}-{100:09d}')
    self.assertFalse(gfile.exists(f'{save_path}-{100:09d}-tmp'))
| |
|
| |
|
# Run the suite through TensorFlow's test runner when executed as a script.
if __name__ == '__main__':
  tf.test.main()
| |
|