# Copyright 2022 The T5X Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for t5x.partitioning."""

import collections

from absl.testing import absltest
from absl.testing import parameterized
import flax.core
from flax.linen import partitioning as nn_partitioning
import jax
import numpy as np
from t5x import adafactor
from t5x import optimizers
from t5x import partitioning
from t5x import test_utils as ptu
from t5x import train_state

jax.config.parse_flags_with_absl()

mock = absltest.mock
TpuDevice = ptu.TpuDevice
# A mock 4x4x1 TPU v3 slice with 2 cores per chip: 32 devices on 4 processes.
TPUV3_32 = ptu.make_devices(4, 4, 1, 2, kind='TPU v3')
AxisMetadata = nn_partitioning.AxisMetadata
PartitionSpec = partitioning.PartitionSpec


class PartitioningTest(absltest.TestCase):

  @mock.patch('jax.host_count')
  @mock.patch('jax.local_device_count')
  def test_bounds_from_last_device(self, local_device_count, host_count):
    # TPU devices expose chip coords plus a core index; the hardware bounds
    # are one greater than the coords of the last device.
    last_device = mock.Mock(coords=(3, 3, 3), core_on_chip=1)
    tpu_bounds = partitioning.bounds_from_last_device(last_device)
    self.assertEqual(tpu_bounds, (4, 4, 4, 2))

    # Non-TPU devices fall back to (host count, local device count).
    last_device = mock.Mock(spec=[])
    host_count.return_value = 1
    local_device_count.return_value = 4
    non_tpu_bounds = partitioning.bounds_from_last_device(last_device)
    self.assertEqual(non_tpu_bounds, (1, 4))

  @mock.patch('jax.local_device_count')
  def test_get_coords(self, local_device_count):
    device = mock.Mock(coords=(1, 0, 1), core_on_chip=1)
    coords = partitioning.get_coords(device)
    self.assertEqual(coords, (1, 0, 1, 1))

    # Without TPU coords, (process_index, id % local_device_count) is used.
    device = mock.Mock(spec=['process_index', 'id'])
    device.process_index = 1
    device.id = 9
    local_device_count.return_value = 8
    coords = partitioning.get_coords(device)
    self.assertEqual(coords, (1, 1))

  @mock.patch('jax.local_devices')
  @mock.patch('jax.devices')
  @mock.patch('jax._src.lib.xla_bridge.process_index')
  def test_default_mesh(self, process_index_fn, devices_fn, local_devices_fn):
    devices_fn.return_value = TPUV3_32
    local_devices_fn.return_value = [
        d for d in TPUV3_32 if d.process_index == 0
    ]
    process_index_fn.return_value = 0
    global_mesh = partitioning.default_mesh(4)
    self.assertEqual(global_mesh.axis_names, ('data', 'model'))
    self.assertEqual(global_mesh.shape,
                     collections.OrderedDict((('data', 8), ('model', 4))))
    self.assertEqual(global_mesh.size, 32)

    # Every process sees the same 2x4 local mesh shape.
    for process_index in (0, 1, 2, 3):
      process_index_fn.return_value = process_index
      local_mesh = global_mesh.local_mesh
      self.assertEqual(local_mesh.axis_names, ('data', 'model'))
      self.assertEqual(local_mesh.shape,
                       collections.OrderedDict((('data', 2), ('model', 4))))
      self.assertEqual(local_mesh.size, 8)

    process_index_fn.return_value = 0
    local_mesh = global_mesh.local_mesh
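    # Expected devices for process 0: a 2 ('data') x 4 ('model') slice of
    # the global 8x4 mesh.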
    lds = np.array(
        [[
            TpuDevice(id=0, process_index=0, coords=(0, 0, 0), core_on_chip=0),
            TpuDevice(id=1, process_index=0, coords=(0, 0, 0), core_on_chip=1),
            TpuDevice(id=2, process_index=0, coords=(1, 0, 0), core_on_chip=0),
            TpuDevice(id=3, process_index=0, coords=(1, 0, 0), core_on_chip=1)
        ],
        [
            TpuDevice(id=8, process_index=0, coords=(0, 1, 0), core_on_chip=0),
            TpuDevice(id=9, process_index=0, coords=(0, 1, 0), core_on_chip=1),
            TpuDevice(id=10, process_index=0, coords=(1, 1, 0), core_on_chip=0),
            TpuDevice(id=11, process_index=0, coords=(1, 1, 0), core_on_chip=1)
        ]],
        dtype=object)
    np.testing.assert_array_equal(local_mesh.devices, lds)

  @mock.patch('jax.local_devices')
  @mock.patch('jax.devices')
  @mock.patch('jax._src.lib.xla_bridge.process_index')
  def test_local_chunker(self, process_index_fn, devices_fn, local_devices_fn):
    devices_fn.return_value = TPUV3_32
    local_devices_fn.return_value = [
        d for d in TPUV3_32 if d.process_index == 0
    ]
    process_index_fn.return_value = 0
    global_mesh = partitioning.default_mesh(4)
    local_chunker = partitioning.LocalChunker(global_mesh)
    self.assertEqual(local_chunker.num_chunks['data'], 4)
    self.assertEqual(local_chunker.num_chunks['model'], 1)

    # Derive the chunk order along the first 'data' dim for testing.
    host_ordering = []
    for d in global_mesh.devices[:, 0]:
      if d.process_index not in host_ordering:
        host_ordering.append(d.process_index)
    process_index_to_data_pos = {
        process_index: idx for idx, process_index in enumerate(host_ordering)
    }

    for process_index in (0, 1, 2, 3):
      process_index_fn.return_value = process_index
      global_mesh = partitioning.default_mesh(4)
      local_chunker = partitioning.LocalChunker(global_mesh)
      # Get the expected chunk for the 'data' axis.
      expected_chunk = process_index_to_data_pos[process_index]
      self.assertEqual(local_chunker.chunk_ids['data'], expected_chunk)
      self.assertEqual(local_chunker.chunk_ids['model'], 0)

      # Sharded along both axes.
      local_chunk_info = local_chunker.get_local_chunk_info((128, 16),
                                                            ['data', 'model'])
      self.assertEqual(local_chunk_info.replica_id, 0)
      self.assertEqual(
          local_chunk_info.slice,
          (slice(32 * expected_chunk, 32 * (expected_chunk + 1)),
           slice(0, 16)))

      # Replicated across the first axis.
      local_chunk_info = local_chunker.get_local_chunk_info((128, 16),
                                                            [None, 'model'])
      self.assertEqual(local_chunk_info.replica_id, expected_chunk)
      self.assertEqual(local_chunk_info.slice, (slice(None), slice(0, 16)))


class ModelBasedPartitionerTest(parameterized.TestCase):

  def get_axes_spec(self, partitioner, factored, momentum):
    opt_def = adafactor.Adafactor(
        learning_rate=0.1,
        factored=factored,
        min_dim_size_to_factor=8,
        beta1=0.1 if momentum else None,
        logical_factor_rules={
            'batch': adafactor.FactorDim.NONE,
            'embed': adafactor.FactorDim.ROW,
            'vocab': adafactor.FactorDim.COLUMN,
            'mlp': adafactor.FactorDim.COLUMN,
        })
    # Build a two-parameter model state with logical axis annotations.
    state = train_state.FlaxOptimTrainState.create(
        opt_def,
        flax.core.freeze({
            'params': {
                'logits_dense': np.ones((16, 16), np.float32),
                'mlp': {
                    'wo': {
                        'kernel': np.ones((32, 16), np.float32)
                    }
                }
            },
            'params_axes': {
                'logits_dense_axes': AxisMetadata(names=('vocab', 'embed')),
                'mlp': {
                    'wo': {
                        'kernel_axes': AxisMetadata(names=('embed', 'mlp'))
                    }
                }
            }
        }))
    return partitioner.get_mesh_axes(state).state_dict()

  def get_expected_axes_spec(self,
                             spec_0,
                             spec_1,
                             kernel_spec=PartitionSpec(None, 'model')):
    return train_state.FlaxOptimTrainState(
        optimizers.Optimizer(
            # opt_def is not compared, so any instance will do.
            adafactor.Adafactor(0.1),
            state=optimizers.OptimizerState(
                step=None,
                param_states={
                    'logits_dense': spec_0,
                    'mlp': {
                        'wo': {
                            'kernel': spec_1
                        }
                    }
                }),
            target={
                'logits_dense': PartitionSpec('model', None),
                'mlp': {
                    'wo': {
                        'kernel': kernel_spec
                    }
                }
            })).state_dict()

  def test_get_mesh_axes(self):
    partitioner = partitioning.PjitPartitioner(
        num_partitions=1,
        logical_axis_rules=(('batch', 'data'), ('embed', None),
                            ('vocab', 'model'), ('mlp', 'model')))

    p0_spec = PartitionSpec('model', None)
    p1_spec = PartitionSpec(None, 'model')

    # Test the quadrant of conditions: factored or not / momentum or not.
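    # Factored, no momentum: none of the optimizer slots carry a spec.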
    axes_spec = self.get_axes_spec(partitioner, factored=True, momentum=False)
    expected_axes_spec = self.get_expected_axes_spec(
        adafactor._AdafactorParamState(m=None, v=None, v_col=None, v_row=None),
        adafactor._AdafactorParamState(m=None, v=None, v_col=None, v_row=None))
    jax.tree_util.tree_map(self.assertEqual, axes_spec, expected_axes_spec)

    # Factored, with momentum: only the momentum slot `m` is partitioned.
    axes_spec = self.get_axes_spec(partitioner, factored=True, momentum=True)
    expected_axes_spec = self.get_expected_axes_spec(
        adafactor._AdafactorParamState(
            m=p0_spec, v=None, v_col=None, v_row=None),
        adafactor._AdafactorParamState(
            m=p1_spec, v=None, v_col=None, v_row=None))
    jax.tree_util.tree_map(self.assertEqual, axes_spec, expected_axes_spec)

    # Unfactored, with momentum: both `m` and `v` follow the parameter spec.
    axes_spec = self.get_axes_spec(partitioner, factored=False, momentum=True)
    expected_axes_spec = self.get_expected_axes_spec(
        adafactor._AdafactorParamState(
            m=p0_spec, v=p0_spec, v_col=None, v_row=None),
        adafactor._AdafactorParamState(
            m=p1_spec, v=p1_spec, v_col=None, v_row=None))
    jax.tree_util.tree_map(self.assertEqual, axes_spec, expected_axes_spec)

    # Unfactored, no momentum: only `v` follows the parameter spec.
    axes_spec = self.get_axes_spec(partitioner, factored=False, momentum=False)
    expected_axes_spec = self.get_expected_axes_spec(
        adafactor._AdafactorParamState(
            m=None, v=p0_spec, v_col=None, v_row=None),
        adafactor._AdafactorParamState(
            m=None, v=p1_spec, v_col=None, v_row=None))
    jax.tree_util.tree_map(self.assertEqual, axes_spec, expected_axes_spec)

  @parameterized.product(activation_dims=(1, 2), param_dims=(1, 2))
  def test_standard_logical_axis_rules(self, activation_dims, param_dims):
    default_rules = partitioning.standard_logical_axis_rules(
        activation_dims, param_dims, additional_rules=None)
    custom_rules = (('my-new-axis', 'data'), ('another-axis', None),
                    ('another-one', 'model'))
    new_rules = partitioning.standard_logical_axis_rules(
        activation_dims, param_dims, additional_rules=custom_rules)
    # Additional rules are appended after the defaults, in order.
    self.assertEqual(new_rules[:len(default_rules)], default_rules)
    self.assertEqual(new_rules[len(default_rules):], list(custom_rules))


if __name__ == '__main__':
  absltest.main()