import mlx.core as mx
import mlx.nn as nn
import mlx_tests
from mlx.nn.layers.distributed import shard_inplace, shard_linear
from mlx.nn.utils import average_gradients


class TestNCCLDistributed(mlx_tests.MLXTestCase):
    @classmethod
    def setUpClass(cls):
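        # Initialize the NCCL backend once for the whole test case and bind
        # each rank to a local GPU (rank modulo 8).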
        world = mx.distributed.init(strict=True, backend="nccl")
        rank = world.rank()
        mx.set_default_device(mx.Device(mx.gpu, rank % 8))

    def test_all_reduce(self):
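        # all_sum of this rank's slice should match summing the full array
        # locally, within a dtype dependent tolerance.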
        world = mx.distributed.init()
        dtypes = [
            (mx.int8, 0),
            (mx.uint8, 0),
            (mx.int32, 0),
            (mx.uint32, 0),
            (mx.float32, 1e-6),
            (mx.float16, 5e-3),
            (mx.bfloat16, 1e-1),
        ]
        sizes = [
            (7,),
            (10,),
            (1024,),
            (1024, 1024),
        ]
        key = mx.random.key(0)

        for dt, rtol in dtypes:
            for sh in sizes:
                x = (
                    mx.random.uniform(shape=(world.size(),) + sh, key=key) * 10
                ).astype(dt)

                y = mx.distributed.all_sum(x[world.rank()])
                z = x.sum(0)
                maxrelerror = (y - z).abs()
                if rtol > 0:
                    maxrelerror /= z.abs()
                maxrelerror = maxrelerror.max()
                self.assertLessEqual(maxrelerror, rtol)

    def test_average_gradients(self):
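        # average_gradients should return the averaged gradients while issuing
        # as few all_sum calls as the all_reduce_size budget allows. Patch
        # all_sum to count the calls and check the communication dtype.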
        original_all_sum = mx.distributed.all_sum
        n_calls = 0
        xtype = None

        def new_all_sum(x, **kwargs):
            nonlocal n_calls
            nonlocal xtype

            n_calls += 1
            if xtype is not None:
                self.assertEqual(xtype, x.dtype)

            return original_all_sum(x, **kwargs)

        mx.distributed.all_sum = new_all_sum
        try:
            grads = [mx.ones(10) for i in range(10)]
            new_grads = average_gradients(grads)
            mx.eval(new_grads)
            self.assertEqual(len(new_grads), 10)
            self.assertTrue(all(mx.all(g == 1) for g in new_grads))
            self.assertEqual(n_calls, 1)

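            # With a 200-byte all_reduce_size the ten 40-byte float32
            # gradients should be grouped into two all_sum calls.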
            n_calls = 0
            new_grads = average_gradients(grads, all_reduce_size=4 * 50)
            mx.eval(new_grads)
            self.assertEqual(len(new_grads), 10)
            self.assertTrue(all(mx.all(g == 1) for g in new_grads))
            self.assertEqual(n_calls, 2)

            # A budget of 0 disables grouping, so every gradient is reduced on
            # its own.
            n_calls = 0
            new_grads = average_gradients(grads, all_reduce_size=0)
            mx.eval(new_grads)
            self.assertEqual(len(new_grads), 10)
            self.assertTrue(all(mx.all(g == 1) for g in new_grads))
            self.assertEqual(n_calls, 10)

            # communication_type casts to float16 for the reduction while the
            # returned gradients keep their float32 dtype.
            n_calls = 0
            xtype = mx.float16
            new_grads = average_gradients(
                grads,
                all_reduce_size=2 * 50,
                communication_type=mx.float16,
            )
            mx.eval(new_grads)
            self.assertEqual(len(new_grads), 10)
            self.assertTrue(all(g.dtype == mx.float32 for g in new_grads))
            self.assertTrue(all(mx.all(g == 1) for g in new_grads))
            self.assertEqual(n_calls, 2)
        finally:
            mx.distributed.all_sum = original_all_sum

    def test_donation(self):
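        # The all_sum output buffer should be donated to the following
        # multiply, so adding the binary op must not increase peak memory.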
        x = mx.random.normal((1024,))
        mx.eval(x)
        mx.synchronize()

        mx.reset_peak_memory()
        scale = mx.array(2.0)
        y = mx.distributed.all_sum(x)
        mx.eval(y)
        mx.synchronize()
        all_sum_only = mx.get_peak_memory()
        y = mx.distributed.all_sum(x) * scale
        mx.eval(y)
        mx.synchronize()
        all_sum_with_binary = mx.get_peak_memory()

        self.assertEqual(all_sum_only, all_sum_with_binary)

    def test_shard_linear(self):
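        # Seed so every rank builds identical weights and inputs, then check
        # that sharded linear layers reproduce the unsharded forward pass and
        # gradients.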
        mx.random.seed(0xF0F0F0F0)

        world = mx.distributed.init()
        part = (
            slice(None),
            slice(
                world.rank() * 1024 // world.size(),
                (world.rank() + 1) * 1024 // world.size(),
            ),
        )
        x = mx.random.normal((4, 1024))

        lin = nn.Linear(1024, 1024, bias=True)
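        # "all-to-sharded" splits the output features across ranks, so its
        # output matches the local slice of the reference; "sharded-to-all"
        # consumes a sharded input and returns the full output.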
        slin1 = shard_linear(lin, "all-to-sharded")
        slin2 = shard_linear(lin, "sharded-to-all")
        y = lin(x)
        y1 = slin1(x)
        y2 = slin2(x[part])
        self.assertTrue(mx.allclose(y, y2, atol=1e-4, rtol=1e-4))
        self.assertTrue(mx.allclose(y[part], y1, atol=1e-4, rtol=1e-4))

        def dummy_loss(model, x, y):
            return (model(x) * y).sum()

        mod = nn.Sequential(
            nn.Linear(128, 128),
            nn.Linear(128, 128),
            nn.Linear(128, 128),
            nn.Linear(128, 128),
        )
        smod = nn.Sequential(
            shard_linear(mod.layers[0], "all-to-sharded"),
            shard_linear(mod.layers[1], "sharded-to-all"),
            shard_linear(mod.layers[2], "all-to-sharded"),
            shard_linear(mod.layers[3], "sharded-to-all"),
        )

        grad1 = nn.value_and_grad(mod, dummy_loss)
        grad2 = nn.value_and_grad(smod, dummy_loss)

        x = mx.random.normal((4, 128))
        y = mx.random.normal((4, 128))

        l1, g1 = grad1(mod, x, y)
        l2, g2 = grad2(smod, x, y)
        mx.eval(l1, g1, l2, g2)

        part = slice(
            world.rank() * 128 // world.size(), (world.rank() + 1) * 128 // world.size()
        )

        self.assertTrue(mx.allclose(l1, l2))
        self.assertTrue(
            mx.allclose(
                g1["layers"][0]["weight"][part],
                g2["layers"][0]["weight"],
                atol=1e-6,
                rtol=1e-4,
            )
        )
        self.assertTrue(
            mx.allclose(
                g1["layers"][2]["weight"][part],
                g2["layers"][2]["weight"],
                atol=1e-6,
                rtol=1e-4,
            )
        )
        self.assertTrue(
            mx.allclose(
                g1["layers"][1]["weight"][:, part],
                g2["layers"][1]["weight"],
                atol=1e-6,
                rtol=1e-4,
            )
        )
        self.assertTrue(
            mx.allclose(
                g1["layers"][3]["weight"][:, part],
                g2["layers"][3]["weight"],
                atol=1e-6,
                rtol=1e-4,
            )
        )
        self.assertTrue(
            mx.allclose(
                g1["layers"][0]["bias"][part],
                g2["layers"][0]["bias"],
                atol=1e-6,
                rtol=1e-4,
            )
        )
        self.assertTrue(
            mx.allclose(
                g1["layers"][2]["bias"][part],
                g2["layers"][2]["bias"],
                atol=1e-6,
                rtol=1e-4,
            )
        )
        self.assertTrue(
            mx.allclose(
                g1["layers"][1]["bias"], g2["layers"][1]["bias"], atol=1e-6, rtol=1e-4
            )
        )
        self.assertTrue(
            mx.allclose(
                g1["layers"][3]["bias"], g2["layers"][3]["bias"], atol=1e-6, rtol=1e-4
            )
        )

    def test_shard_predicate(self):
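        # shard_inplace with a custom predicate should leave the sharded
        # model numerically equivalent to the replicated one.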
        mx.random.seed(0xF0F0F0F0)

        class MyConv(nn.Module):
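            # Conv2d wrapper that optionally all-sums its output so that
            # partial results from sharded layers are combined across ranks.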
            def __init__(self, *args, **kwargs):
                super().__init__()
                self.aggregate = kwargs.pop("aggregate", False)
                self.conv = nn.Conv2d(*args, **kwargs)

            def __call__(self, x):
                x = self.conv(x)
                if self.aggregate:
                    x = mx.distributed.all_sum(x)
                return x

        def sharding(path, weight):
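            # Shard even-indexed layers along axis 0 (output channels) and
            # odd-indexed layers along the last axis (input channels), leaving
            # the odd layers' biases unsharded.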
            parts = path.split(".")
            even = int(parts[1]) % 2 == 0
            if even:
                return 0
            else:
                return -1 if parts[-1] != "bias" else None

        mod = nn.Sequential(
            MyConv(3, 128, kernel_size=3),
            MyConv(128, 128, kernel_size=3),
            MyConv(128, 128, kernel_size=3),
            MyConv(128, 3, kernel_size=3),
        )
        smod = nn.Sequential(
            MyConv(3, 128, kernel_size=3),
            MyConv(128, 128, kernel_size=3, aggregate=True),
            MyConv(128, 128, kernel_size=3),
            MyConv(128, 3, kernel_size=3, aggregate=True),
        )
        smod.update(mod.parameters())
        shard_inplace(smod, sharding)

        x = mx.random.normal((4, 16, 16, 3))
        y1 = mod(x)
        y2 = smod(x)
        self.assertTrue(mx.allclose(y1, y2, atol=1e-6, rtol=1e-4))


if __name__ == "__main__":
    mlx_tests.MLXTestRunner()