# Copyright 2022 The T5X Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for partitioning."""

from typing import Any

from absl.testing import absltest
from flax import core as flax_core
from flax import optim
from flax.linen import partitioning as flax_partitioning
import jax
import numpy as np
from t5x import train_state as train_state_lib
from t5x.contrib.moe import partitioning as moe_partitioning
from t5x.contrib.moe import training_utils

mock = absltest.mock

AxisMetadata = flax_partitioning.AxisMetadata
DataLayout = moe_partitioning.DataLayout
FlaxOptimTrainState = train_state_lib.FlaxOptimTrainState
InferenceState = train_state_lib.InferenceState
PartitionSpec = moe_partitioning.PartitionSpec
PRNGKey = Any


class LogicalAdam(optim.Adam):
  """Subclass of Adam optimizer with T5X logical axis partitioning support."""

  def derive_logical_axes(self, optimizer_state, param_logical_axes):
    """Derives optimizer logical partitioning from model logical partitions."""
    del param_logical_axes  # Return fixed axes for test
    optimizer_logical_axes = {
        'state': {
            'param_states': {
                'logits_dense': {
                    'grad_ema': None,
                    'grad_sq_ema': None
                },
                'mlp': {
                    'wo': {
                        'kernel': {
                            'grad_ema': PartitionSpec('embed', 'mlp'),
                            'grad_sq_ema': None
                        }
                    }
                }
            },
            'step': None
        },
        'target': {
            'logits_dense': PartitionSpec('vocab', 'embed'),
            'mlp': {
                'wo': {
                    'kernel': PartitionSpec('embed', 'mlp'),
                },
            },
        }
    }
    return optimizer_state.restore_state(optimizer_logical_axes)


def create_optimizer():
  """Creates simple Adam optimizer."""
  target = {
      'logits_dense': np.ones((16, 16), np.float32),
      'mlp': {
          'wo': {
              'kernel': np.ones((32, 16), np.float32)
          }
      }
  }
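  # Parameter values are placeholders; the tests below only rely on the tree
  # structure and the logical axis annotations attached to it.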
  return LogicalAdam(learning_rate=1e-4).create(target)


class PartitioningTest(absltest.TestCase):

  def test_default_data_layout(self):
    # No expert replication required. Use default data layout.
    partitioner = moe_partitioning.MoePjitPartitioner(
        num_experts=8, num_partitions=1)
    self.assertFalse(partitioner.two_data_axes)
    self.assertEqual(
        partitioner.get_data_layout(batch_size=32),
        DataLayout(
            batch_size=32,
            shard_id=0,
            num_shards=1,
            is_first_host_in_replica_set=True))

  def test_two_data_axis_layout_override(self):
    partitioner = moe_partitioning.MoePjitPartitioner(
        num_experts=8, num_partitions=1)
    # Force override case to check layout is valid.
    partitioner.two_data_axes = True
    partitioner._data_axis = ('data', 'model')
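    # Even with two data axes, a single-host setup should resolve to one shard.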
    self.assertEqual(
        partitioner.get_data_layout(batch_size=8),
        DataLayout(
            batch_size=8,
            shard_id=0,
            num_shards=1,
            is_first_host_in_replica_set=True))

  def test_logical_axes_for_moe_partitioner_no_overrides(self):
    partitioner = moe_partitioning.MoePjitPartitioner(
        num_experts=8,
        num_partitions=1,
        state_filter_fn=training_utils.match_fn(r'no_state_matching'))
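    # The state_filter_fn regex matches no parameter paths, so no optimizer
    # state entries should receive an 'expert' axis.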
    optimizer = create_optimizer()
    train_state = FlaxOptimTrainState(
        optimizer,
        params_axes={
            'logits_dense_axes': AxisMetadata(names=('vocab', 'embed')),
            'mlp': {
                'wo': {
                    'kernel_axes': AxisMetadata(names=('embed', 'mlp'))
                }
            }
        })
    logical_axes = partitioner.get_logical_axes(train_state)
    # No updates to state. Should match what derive_logical_axes() returns.
    jax.tree_map(self.assertIsNone, logical_axes.param_states['logits_dense'])
    self.assertEqual(logical_axes.param_states['mlp']['wo']['kernel'].grad_ema,
                     PartitionSpec('embed', 'mlp'))
    self.assertIsNone(
        logical_axes.param_states['mlp']['wo']['kernel'].grad_sq_ema)
    self.assertEqual(
        logical_axes.params, {
            'logits_dense': PartitionSpec('vocab', 'embed'),
            'mlp': {
                'wo': {
                    'kernel': PartitionSpec('embed', 'mlp')
                }
            }
        })

  def test_logical_axes_for_moe_partitioner_with_overrides(self):
    partitioner = moe_partitioning.MoePjitPartitioner(
        num_experts=8,
        num_partitions=1,
        state_filter_fn=training_utils.match_fn(r'.*mlp.*'))
    optimizer = create_optimizer()
    train_state = FlaxOptimTrainState(
        optimizer,
        params_axes={
            'logits_dense_axes': AxisMetadata(names=('vocab', 'embed')),
            'mlp': {
                'wo': {
                    'kernel_axes': AxisMetadata(names=('embed', 'mlp'))
                }
            }
        })
    logical_axes = partitioner.get_logical_axes(train_state)
    jax.tree_map(self.assertIsNone, logical_axes.param_states['logits_dense'])
    # 'mlp' params should be prepended with 'expert' spec because
    # state_filter_fn matches '.*mlp.*'.
    self.assertEqual(logical_axes.param_states['mlp']['wo']['kernel'].grad_ema,
                     PartitionSpec('expert', 'embed', 'mlp'))
    self.assertEqual(
        logical_axes.param_states['mlp']['wo']['kernel'].grad_sq_ema,
        PartitionSpec('expert',))
    self.assertEqual(
        logical_axes.params, {
            'logits_dense': PartitionSpec('vocab', 'embed'),
            'mlp': {
                'wo': {
                    'kernel': PartitionSpec('embed', 'mlp')
                }
            }
        })

  def test_inference_state_logical_axes(self):
    partitioner = moe_partitioning.MoePjitPartitioner(
        num_experts=8, num_partitions=1)
    model_variables = flax_core.freeze({
        'params': {
            'dense': {
                'bias': np.zeros(4),
                'kernel': np.zeros((2, 4))
            }
        },
        'params_axes': {
            'dense': {
                'bias_axes': AxisMetadata(names=('embed',)),
                'kernel_axes': AxisMetadata(names=('vocab', 'embed')),
            }
        },
    })
    train_state = InferenceState.create(model_variables)
    logical_axes = partitioner.get_logical_axes(train_state)
    # No expert axis overrides to InferenceState. Partition specs should match
    # input axis metadata.
    self.assertEqual(
        logical_axes,
        InferenceState(
            step=None,
            params=flax_core.FrozenDict({
                'dense': {
                    'bias': PartitionSpec('embed',),
                    'kernel': PartitionSpec('vocab', 'embed'),
                },
            })))

  @mock.patch('jax.device_count')
  def test_overridden_logical_axis_rules(self, device_count: int):
    device_count.return_value = 4
    # Fewer experts than devices --> modified axis rules with two 'batch' axes.
    self.assertEqual(
        moe_partitioning.standard_logical_axis_rules(
            num_experts=1,
            num_partitions=1,
            model_parallel_submesh=None,
            additional_rules=[('additional', 'model'),
                              ('expert_magic', 'data')]),
        [
            ('batch', ('data', 'model')),  # Shard batch over entire mesh
            # No sharding of weights over model axis.
            ('vocab', None),
            ('embed', None),
            ('mlp', None),
            ('heads', None),
            ('kv', None),
            ('joined_kv', None),
            ('relpos_buckets', None),
            ('abspos_buckets', None),
            ('length', None),
            ('layers', None),
            ('stack', None),
            ('mlp_activations', None),
            ('expert', 'data'),  # Shard experts over "first" data axis only
            ('expert_mlp', None),
            ('expert_group', None),
            # Experts replicated along "second" data axis
            ('expert_replicas', 'model'),
            ('unmodeled', None),
            ('additional', None),
            ('expert_magic', 'data'),
        ])

  def test_default_logical_axis(self):
    # Model parallelism used --> default logical axis rules.
    self.assertEqual(
        moe_partitioning.standard_logical_axis_rules(
            num_experts=1,
            num_partitions=2,
            model_parallel_submesh=None,
            additional_rules=[('additional', 'model')]),
        [
            ('batch', 'data'),  # Shard batch over single data axis
            # Default model annotations used.
            ('vocab', 'model'),
            ('embed', None),
            ('mlp', 'model'),
            ('heads', 'model'),
            ('kv', None),
            ('joined_kv', 'model'),
            ('relpos_buckets', None),
            ('abspos_buckets', None),
            ('length', None),
            ('layers', None),
            ('stack', None),
            ('mlp_activations', None),
            ('expert', 'data'),  # Shard experts along data axis
            ('expert_mlp', 'model'),
            ('expert_group', None),
            ('expert_replicas', None),
            ('unmodeled', None),
            ('additional', 'model'),
        ])

  def test_2d_parameter_sharding_unsupported(self):
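    # 2D parameter sharding (parameter_partitioning_dims=2) is not supported
    # for MoE overrides and should raise a ValueError.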
    with self.assertRaisesRegex(ValueError, 'is not supported for MoE.'):
      moe_partitioning.standard_logical_axis_rules(
          num_experts=4, num_partitions=1, parameter_partitioning_dims=2)

  def test_data_partition_spec(self):
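    # Batches are sharded along the single 'data' axis by default, or jointly
    # along ('data', 'model') when two data axes are in use.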
    self.assertEqual(
        moe_partitioning.data_partition_spec(two_data_axes=False),
        PartitionSpec('data',))
    self.assertEqual(
        moe_partitioning.data_partition_spec(two_data_axes=True),
        PartitionSpec(('data', 'model'),))

  @mock.patch('jax.device_count')
  def test_when_to_override_model_axis(self, device_count: int):
    device_count.return_value = 4
    # More experts than devices.
    self.assertFalse(
        moe_partitioning._override_model_axis(
            num_experts=8, num_partitions=1, model_parallel_submesh=None))
    # Fewer experts than devices.
    self.assertTrue(
        moe_partitioning._override_model_axis(
            num_experts=1, num_partitions=1, model_parallel_submesh=None))
    # Model parallelism used.
    self.assertFalse(
        moe_partitioning._override_model_axis(
            num_experts=1, num_partitions=2, model_parallel_submesh=None))

  def test_axis_resource_overrides(self):
    input_resources = (PartitionSpec('data'), PartitionSpec('model'), None,
                       PartitionSpec('unrecognized'))
    overridden_resources = moe_partitioning._override_partition_specs(
        input_resources)
    # "data" -> ("data", "model"). "model" -> None.
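    # None and unrecognized resources pass through unchanged.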
    self.assertEqual(overridden_resources, (PartitionSpec(
        ('data', 'model'),), None, None, PartitionSpec('unrecognized',)))


if __name__ == '__main__':
  absltest.main()