| | |
| |
|
| | import gc |
| | import inspect |
| | import io |
| | import math |
| | import unittest |
| | from functools import partial, wraps |
| | from io import StringIO |
| |
|
| | import mlx.core as mx |
| | import mlx_tests |
| |
|
| |
|
class TestCompile(mlx_tests.MLXTestCase):
    """Tests for ``mx.compile``.

    Covers re-use of compiled functions across calls, closures, captured
    input/output state, composition with grad/vjp/jvp/vmap, shapeless
    compilation and non-array (constant) arguments.
    """

    def test_simple_compile(self):
        def fun(x, y):
            return x + y

        # Compiling the same function twice is harmless; the second wrapper
        # is the one exercised below.
        mx.compile(fun)
        compiled_fn = mx.compile(fun)
        x = mx.array(1.0)
        y = mx.array(1.0)
        out = compiled_fn(x, y)
        self.assertEqual(out.item(), 2.0)

        # Call again with the very same inputs.
        out = compiled_fn(x, y)
        self.assertEqual(out.item(), 2.0)

        # Change the shape of one input.
        x = mx.array([1.0, 2.0])
        out = compiled_fn(x, y)
        self.assertTrue(mx.array_equal(out, mx.array([2.0, 3.0])))

        y = mx.array([1.0, 2.0])
        out = compiled_fn(x, y)
        self.assertTrue(mx.array_equal(out, mx.array([2.0, 4.0])))

        # Change the dtype of the inputs; the output dtype must follow.
        x = mx.array([1, 2], mx.int32)
        y = mx.array([1, 2], mx.int32)
        out = compiled_fn(x, y)
        self.assertEqual(out.dtype, mx.int32)
        self.assertTrue(mx.array_equal(out, mx.array([2, 4])))

    def test_compile_grad(self):
        def loss_fn(x):
            return mx.exp(x).sum()

        grad_fn = mx.grad(loss_fn)

        x = mx.array([0.5, -0.5, 1.2])
        dfdx = grad_fn(x)
        compile_grad_fn = mx.compile(grad_fn)

        # First call of the compiled gradient function.
        c_dfdx = compile_grad_fn(x)
        self.assertTrue(mx.allclose(c_dfdx, dfdx))

        # Call the already-compiled function again.
        c_dfdx = compile_grad_fn(x)
        self.assertTrue(mx.allclose(c_dfdx, dfdx))

        # Compile a fresh wrapper and call it.
        c_dfdx = mx.compile(grad_fn)(x)
        self.assertTrue(mx.allclose(c_dfdx, dfdx))

        # value_and_grad of a function that also returns an auxiliary output.
        def loss_fn(x):
            return mx.exp(x).sum(), mx.sin(x)

        val_and_grad_fn = mx.value_and_grad(loss_fn)
        (loss, val), dfdx = val_and_grad_fn(x)
        (c_loss, c_val), c_dfdx = mx.compile(val_and_grad_fn)(x)

        self.assertTrue(mx.allclose(c_dfdx, dfdx))
        self.assertTrue(mx.allclose(c_loss, loss))
        self.assertTrue(mx.allclose(c_val, val))

    def test_compile_inputs_with_primitives(self):
        def make_inputs():
            # Build fresh, unevaluated inputs that still carry a graph of ops.
            x = mx.array([1, 2, 3])
            y = mx.array([1, 2, 3])
            for _ in range(5):
                x = x + y
                y = y + 1
            return x, y

        def fun(x, y):
            return x * y

        out = fun(*make_inputs())

        x, y = make_inputs()
        c_out = mx.compile(fun)(x, y)
        self.assertTrue(mx.array_equal(out, c_out))

        # Compile again and call with the same inputs.
        c_out = mx.compile(fun)(x, y)
        self.assertTrue(mx.array_equal(out, c_out))

    def test_compile_with_closure(self):
        x = mx.array(1)

        def closure(y):
            return x + y

        compiled = mx.compile(closure)
        out = compiled(mx.array(1))
        self.assertEqual(out.item(), 2)

        # Same call again.
        out = compiled(mx.array(1))
        self.assertEqual(out.item(), 2)

        # Rebind the closed-over name to an array of a different shape.
        x = mx.array([1, 2])
        out = compiled(mx.array(1))

        # The compiled function keeps using the value of x it was traced with.
        self.assertEqual(out.item(), 2)

        # Closure over a dict of arrays.
        x = {"a": mx.array(1), "b": mx.array(2)}

        def closure(y):
            return x["a"] + y + x["b"]

        compiled = mx.compile(closure)
        out = compiled(mx.array(1))
        self.assertEqual(out.item(), 4)

        # Mutating the captured dict does not change the compiled result.
        x["a"] = mx.array([4, 5])
        out = compiled(mx.array(1))
        self.assertEqual(out.item(), 4)

        x["b"] = mx.array([-6, -8])
        out = compiled(mx.array(1))
        self.assertEqual(out.item(), 4)

        # Closure over an array that is itself the result of an operation.
        x = mx.array(1)
        x = x + x

        def closure(y):
            return x + y

        compiled = mx.compile(closure)
        out = compiled(mx.array(2))
        self.assertEqual(out.item(), 4)

        # Same call again.
        out = compiled(mx.array(2))
        self.assertEqual(out.item(), 4)

    def test_function_creates_array(self):
        def fun(x):
            return x + mx.array(1)

        cfun = mx.compile(fun)
        out = cfun(mx.array(3))
        self.assertEqual(out.item(), 4)

        # Same call again.
        out = cfun(mx.array(3))
        self.assertEqual(out.item(), 4)

    def test_enable_disable(self):
        def fun(x):
            y = x + 1
            z = x + 1
            return y + z

        def count_prims(outputs):
            # Count the labelled nodes in the DOT export of the graph.
            buf = io.StringIO()
            mx.export_to_dot(buf, outputs)
            buf.seek(0)
            return sum(1 for token in buf.read().split() if "label" in token)

        x = mx.array(1.0)
        cfun = mx.compile(fun)
        n_compiled = count_prims(cfun(x))

        # With compilation globally disabled the exported graph is expected
        # to contain more nodes.
        mx.disable_compile()
        try:
            n_uncompiled = count_prims(cfun(x))
        finally:
            # Always restore the global setting so a failure here cannot
            # leak a disabled compiler into other tests.
            mx.enable_compile()
        self.assertLess(n_compiled, n_uncompiled)

        # After re-enabling, the graph matches the first compiled run.
        n_enable_compiled = count_prims(cfun(x))
        self.assertEqual(n_compiled, n_enable_compiled)

    def test_compile_two_input_grad(self):
        def loss(w, x):
            y = x * w
            return (y * mx.exp(y)).sum()

        x = mx.array([1.0, 0.5, 2.0, -0.5])
        w = mx.array([-1.0, 0.3, 1.0, -0.9])

        expected_grad = mx.grad(loss)(w, x)
        compiled_grad = mx.compile(mx.grad(loss))(w, x)
        self.assertTrue(mx.allclose(expected_grad, compiled_grad))

    def test_vmap_compiled(self):
        def simple_unary(x):
            return -mx.exp(x)

        x = mx.array([[1.0, 2.0], [2.0, 3.0]])

        expected_out = mx.vmap(simple_unary)(x)
        out = mx.vmap(mx.compile(simple_unary))(x)
        self.assertTrue(mx.allclose(expected_out, out))

        def simple_binary(x, y):
            return mx.abs(mx.exp(x + y) + y)

        x = mx.array([[1.0, -3.0], [0.5, -0.5]])
        y = mx.array([[2.0, -1.0], [0.25, -0.25]])

        expected_out = mx.vmap(simple_binary)(x, y)
        out = mx.vmap(mx.compile(simple_binary))(x, y)
        self.assertTrue(mx.allclose(expected_out, out))

        expected_out = mx.vmap(simple_binary, in_axes=(0, 1))(x, y)
        out = mx.vmap(mx.compile(simple_binary), in_axes=(0, 1))(x, y)
        self.assertTrue(mx.allclose(expected_out, out))

        # One input is not batched.
        y = mx.array([0.25, -0.25])
        expected_out = mx.vmap(simple_binary, in_axes=(0, None))(x, y)
        out = mx.vmap(mx.compile(simple_binary), in_axes=(0, None))(x, y)
        self.assertTrue(mx.allclose(expected_out, out))

        # A compiled function defined and called inside a vmapped function.
        def simple_unary_outer(x):
            x = mx.abs(x)

            @mx.compile
            def simple_unary_inner(z):
                return -mx.exp(z)

            return simple_unary_inner(x)

        expected_out = -mx.exp(mx.abs(x))
        out = mx.vmap(simple_unary_outer)(x)
        self.assertTrue(mx.allclose(expected_out, out))

    def test_vjp_vjp_compiled(self):
        def simple_unary(x):
            return -mx.exp(x)

        x = mx.array([[1.0, 2.0], [2.0, 3.0]])
        y = mx.array([[1.0, 1.0], [1.0, 1.0]])

        expected_out, expected_vjp_out = mx.vjp(simple_unary, (x,), (y,))
        out, vjp_out = mx.vjp(mx.compile(simple_unary), (x,), (y,))
        self.assertTrue(mx.allclose(expected_vjp_out[0], vjp_out[0]))
        self.assertTrue(mx.allclose(expected_out[0], out[0]))

        expected_out, expected_jvp_out = mx.jvp(simple_unary, (x,), (y,))
        out, jvp_out = mx.jvp(mx.compile(simple_unary), (x,), (y,))
        self.assertTrue(mx.allclose(expected_jvp_out[0], jvp_out[0]))
        self.assertTrue(mx.allclose(expected_out[0], out[0]))

        def simple_binary(x, y):
            return mx.abs(mx.exp(x + y) + y)

        x = mx.array([[1.0, -3.0], [0.5, -0.5]])
        y = mx.array([[2.0, -1.0], [0.25, -0.25]])
        cotans = mx.ones_like(x)

        expected_out, expected_vjp_out = mx.vjp(simple_binary, (x, y), (cotans,))
        out, vjp_out = mx.vjp(mx.compile(simple_binary), (x, y), (cotans,))
        self.assertTrue(mx.allclose(expected_out[0], out[0]))
        self.assertTrue(mx.allclose(expected_vjp_out[0], vjp_out[0]))
        self.assertTrue(mx.allclose(expected_vjp_out[1], vjp_out[1]))

        tans = (mx.ones_like(x), mx.ones_like(y))
        expected_out, expected_jvp_out = mx.jvp(simple_binary, (x, y), tans)
        out, jvp_out = mx.jvp(mx.compile(simple_binary), (x, y), tans)
        self.assertTrue(mx.allclose(expected_jvp_out[0], jvp_out[0]))
        self.assertTrue(mx.allclose(expected_out[0], out[0]))

    def test_transform_over_eval_compiled(self):
        def outer(x):
            y = mx.exp(mx.abs(x))
            mx.eval(y)
            return y.sum()

        x = mx.array([2.0, -1.0, 0.5])
        dfdx = mx.grad(outer)(x)

        @mx.compile
        def simple_unary(x):
            return mx.exp(mx.abs(x))

        # Same computation, but the inner part is compiled and evaluated
        # eagerly inside the function being differentiated.
        def outer(x):
            y = simple_unary(x)
            mx.eval(y)
            return y.sum()

        cdfdx = mx.grad(outer)(x)
        self.assertTrue(mx.allclose(dfdx, cdfdx))

    def test_compile_capture(self):
        # Arrays in a container passed as `inputs=` are re-read on each call.
        state = {"y": mx.array(2)}

        @partial(mx.compile, inputs=state)
        def test_state(x):
            x = x + state["y"]
            return x

        test_state(mx.array(1))
        # Calling the function leaves the captured input untouched.
        self.assertEqual(state["y"].item(), 2)

        # An updated captured input is used by the next call.
        state["y"] = mx.array(3)
        out = test_state(mx.array(1))
        self.assertEqual(out.item(), 4)

        # A list as the captured container.
        state = [mx.array(2)]

        @partial(mx.compile, inputs=state)
        def test_state(x):
            x = x + state[0]
            return x

        out = test_state(mx.array(1))
        self.assertEqual(out.item(), 3)
        state[0] = mx.array(3)
        out = test_state(mx.array(1))
        self.assertEqual(out.item(), 4)

        # A nested container.
        state = ([mx.array(2)],)

        @partial(mx.compile, inputs=state)
        def test_state(x):
            x = x + state[0][0]
            return x

        out = test_state(mx.array(1))
        self.assertEqual(out.item(), 3)
        state[0][0] = mx.array(3)
        out = test_state(mx.array(1))
        self.assertEqual(out.item(), 4)

        # Arrays written into an `outputs=` container during the call are
        # available afterwards.
        state = {}

        @partial(mx.compile, outputs=state)
        def test_state(x):
            state["y"] = x + 3
            return mx.abs(x)

        test_state(mx.array(-1))
        self.assertEqual(state["y"].item(), 2)

        # The same container used as both inputs and outputs carries state
        # from one call to the next.
        state = {}

        @partial(mx.compile, inputs=state, outputs=state)
        def test_state(x):
            y = state.get("y", mx.array(0))
            state["y"] = x + y
            return x + 2 * y

        test_state(mx.array(1))
        self.assertEqual(state["y"].item(), 1)
        test_state(mx.array(1))
        self.assertEqual(state["y"].item(), 2)

    def test_compile_rng(self):
        # Threading the random state through inputs and outputs makes
        # consecutive calls draw different samples.
        @partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
        def fun():
            return mx.random.uniform(shape=(10, 10))

        self.assertFalse(mx.allclose(fun(), fun(), 1e-2, 1e-2))

    def test_compile_kwargs(self):
        @mx.compile
        def fun(x, y, z):
            return x + y + z

        x = mx.array(1)
        y = mx.array(2)
        z = mx.array(3)
        out = fun(x, y=y, z=z)
        self.assertEqual(out.item(), 6)

    def test_shapeless_compile(self):
        y = 1

        @partial(mx.compile, shapeless=True)
        def fun(x):
            return x + y

        x = mx.array([1, 2])
        self.assertTrue(mx.array_equal(fun(x), mx.array([2, 3])))

        # A shape-only change does not retrace, so the new value of the
        # closed-over constant y is not picked up.
        y = 2
        x = mx.array([1, 2, 3])
        self.assertTrue(mx.array_equal(fun(x), mx.array([2, 3, 4])))

        # A dtype change retraces and picks up y = 2.
        x = mx.array([1.0, 2.0, 3.0])
        self.assertTrue(mx.array_equal(fun(x), mx.array([3.0, 4.0, 5.0])))

        # So does a change in the number of dimensions.
        x = mx.array([[1, 2, 3]])
        self.assertTrue(mx.array_equal(fun(x), mx.array([[3, 4, 5]])))

    def test_shapeless_compile_with_broadcasts(self):
        x = mx.ones((2, 2))
        y = mx.array([2, 2])

        def fun(x, y):
            return x * y

        cfun = mx.compile(fun, shapeless=True)
        self.assertTrue(mx.array_equal(cfun(x, y), fun(x, y)))
        self.assertTrue(mx.array_equal(cfun(y, x), fun(y, x)))
        y = mx.array([[3]])
        self.assertTrue(mx.array_equal(cfun(x, y), fun(x, y)))
        self.assertTrue(mx.array_equal(cfun(y, x), fun(y, x)))

    def test_shapeless_compile_with_reduction(self):
        # The closed-over constant z is captured when the function is traced.
        z = 1

        @partial(mx.compile, shapeless=True)
        def fun(x, y):
            return x + y.sum(0, keepdims=True) + z

        x = mx.ones((2, 2), mx.int32)
        y = mx.ones((2, 2), mx.int32)
        self.assertTrue(mx.array_equal(fun(x, y), mx.full(shape=(2, 2), vals=4)))
        x = mx.ones((3, 3), mx.int32)
        y = mx.ones((3, 3), mx.int32)
        # Shape-only change: no retrace, so z is still 1 (1 + 3 + 1 == 5).
        z = 2
        self.assertTrue(mx.array_equal(fun(x, y), mx.full(shape=(3, 3), vals=5)))

        x1 = mx.array([[1, 2], [3, 4], [5, 6]])
        x2 = mx.array([[1, 2]])

        def fun(x):
            return x * x.sum(-1, keepdims=True)

        cfun = mx.compile(fun, shapeless=True)
        mx.eval(cfun(x1))
        self.assertTrue(mx.array_equal(fun(x2), cfun(x2)))

        def fun(x):
            return x * x.sum(-1, keepdims=False)

        cfun = mx.compile(fun, shapeless=True)
        self.assertTrue(mx.array_equal(fun(x2), cfun(x2)))

    def test_shapeless_compile_unflatten(self):
        x = mx.zeros((1, 1, 4 * 32))

        def fun(x):
            return mx.unflatten(x, -1, (4, -1))

        self.assertEqual(mx.compile(fun, shapeless=True)(x).shape, (1, 1, 4, 32))

    def test_shapeless_compile_gather(self):
        x = mx.zeros((1, 1, 32))

        def fun(x):
            return x[:, -1, :]

        self.assertEqual(mx.compile(fun, shapeless=True)(x).shape, (1, 32))

    def test_compile_with_constant(self):
        # Python scalars as arguments.
        @mx.compile
        def fun(x, y):
            return x + y

        z = fun(mx.array(1.0), 1.0)
        self.assertEqual(z.item(), 2.0)

        z = fun(mx.array(1.0), 2.0)
        self.assertEqual(z.item(), 3.0)

        z = fun(mx.array(1.0), y=1.0)
        self.assertEqual(z.item(), 2.0)

        z = fun(mx.array(1.0), y=3.0)
        self.assertEqual(z.item(), 4.0)

        # A tuple of constants, including a default argument.
        @mx.compile
        def fun(x, y=(1, 2)):
            return x + y[0] + y[1]

        z = fun(mx.array(1))
        self.assertEqual(z.item(), 4)

        z = fun(mx.array(1), (2, 2))
        self.assertEqual(z.item(), 5)

        z = fun(mx.array(1), (2, 1))
        self.assertEqual(z.item(), 4)

        # A boolean constant selecting the branch.
        @mx.compile
        def fun(x, y):
            if y:
                return x + 1
            else:
                return x + 2

        z = fun(mx.array(1), True)
        self.assertEqual(z.item(), 2)

        z = fun(mx.array(1), False)
        self.assertEqual(z.item(), 3)

        # A string constant selecting the branch.
        @mx.compile
        def fun(x, y):
            if y == "one":
                return x + 1
            else:
                return x + 2

        z = fun(mx.array(1), "one")
        self.assertEqual(z.item(), 2)

        z = fun(mx.array(1), "two")
        self.assertEqual(z.item(), 3)

        # A nested list constant.
        @mx.compile
        def fun(x, y):
            if y[0][0] == 1:
                return x + 1
            else:
                return x + 2

        z = fun(mx.array(1), [[1]])
        self.assertEqual(z.item(), 2)

        z = fun(mx.array(1), [[0]])
        self.assertEqual(z.item(), 3)

        # Lists of different lengths change the traced graph.
        @mx.compile
        def fun(x, a, b):
            for ai in a:
                for bi in b:
                    x = bi * x + ai
            return x

        z = fun(mx.array(1), [1, 1], [2])
        self.assertEqual(z.item(), 7)

        z = fun(mx.array(1), [1], [1, 2])
        self.assertEqual(z.item(), 5)

        # Swapping which argument is the array vs. the constant runs the
        # Python function again.
        counter = [0]

        @mx.compile
        def fun(x, y):
            counter[0] += 1
            return x + y

        z = fun(mx.array(1), 1)
        self.assertEqual(z.item(), 2)

        z = fun(1, mx.array(1))
        self.assertEqual(z.item(), 2)

        self.assertEqual(counter[0], 2)

        # String constants are matched by value, not identity.
        y = 1.0

        @mx.compile
        def fun(x, constant):
            return x + y

        constant1 = "abc"
        out = fun(mx.array(0.0), constant1)
        self.assertEqual(out.item(), 1.0)

        # An equal string built as a different object re-uses the earlier
        # trace, so the updated closure value y is not picked up.
        y = 2.0
        constant2 = "abc".encode("utf-8").decode("utf-8")
        out = fun(mx.array(0.0), constant2)
        self.assertEqual(out.item(), 1.0)

        # A different string retraces and picks up y = 2.
        constant2 = "xyz"
        out = fun(mx.array(0.0), constant2)
        self.assertEqual(out.item(), 2.0)

    def test_compile_inf(self):
        @mx.compile
        def fun(x):
            return mx.isinf(x + 2)

        out = fun(mx.array([0.0]))
        self.assertEqual(out.item(), False)

    def test_unsupported_input_types(self):
        class MyClass:
            value = 1

        @mx.compile
        def fun(x, y):
            return x + y.value

        # Arbitrary objects are rejected both positionally and by keyword.
        with self.assertRaises(ValueError):
            fun(mx.array(0.0), MyClass())

        with self.assertRaises(ValueError):
            fun(mx.array(0.0), y=MyClass())

    def test_compile_create_list(self):
        @mx.compile
        def fun():
            return [0.1 * mx.zeros((2,)), 0.1 * mx.zeros((2,))]

        out = fun()
        mx.eval(out)

    def test_compile_vjp(self):
        def fun(w):
            w1 = w + w
            w2 = w + w
            return w @ w1 + w2 @ w2

        def step(w):
            out, grad = mx.vjp(fun, (w,), (mx.array([[1.0, 1.0], [1.0, 1.0]]),))
            return out[0], grad[0]

        w = mx.zeros((2, 2))
        mx.eval(w)

        expected = step(w)
        out = mx.compile(step)(w)
        self.assertTrue(mx.allclose(expected[0], out[0]))
        self.assertTrue(mx.allclose(expected[1], out[1]))

        def fun(w1, w2, x):
            x = x @ w1
            y = x @ w2
            x = x + y * y
            return (x * x).sum()

        w1 = mx.zeros((4, 4))
        w2 = mx.zeros((4, 4))
        x = mx.zeros((4, 4))

        def step(w1, w2, x):
            loss, gradient = mx.value_and_grad(fun)(w1, w2, x)
            w1 = w1 + gradient
            return loss, w1

        mx.eval(x, w1, w2)
        expected = step(w1, w2, x)
        out = mx.compile(step)(w1, w2, x)

        self.assertTrue(mx.allclose(expected[0], out[0]))
        self.assertTrue(mx.allclose(expected[1], out[1]))

    def test_shapeless_mean(self):
        def mean(x):
            return mx.mean(x, keepdims=True)

        cfun = mx.compile(mean)
        out = cfun(mx.ones((5, 5)))
        self.assertTrue(mx.allclose(out, mx.array(1.0)))

        # The shapeless version must agree with eager for several lengths.
        cmean = mx.compile(mean, shapeless=True)
        for length in (2, 4, 7):
            x = mx.ones(length)
            out = cmean(x)
            self.assertTrue(mx.allclose(out, mean(x)))

    def test_compile_broadcast_only(self):
        def fn(a):
            a = mx.broadcast_to(a, (1,))
            return a + a

        out = mx.compile(fn)(mx.array(2.0))
        # Printing the output must not fail.
        self.assertTrue(repr(out) is not None)
        self.assertTrue(mx.array_equal(out, mx.array([4.0])))

    def test_compile_with_long_name(self):
        def fn(a, b):
            for _ in range(10):
                a = a - 1.0
                b = b - 1.0
            return a + b

        out = mx.compile(fn)(mx.array(10.0), mx.array(20.0))
        self.assertEqual(out.item(), 10.0)

    def test_compile_multi_output(self):
        def fn(x):
            ys = [x]
            for _ in range(5):
                ys.append(ys[-1] + x)
            return ys, mx.sum(ys[-1])

        x = mx.ones(1, dtype=mx.int32)
        y1 = mx.compile(fn)(x)[1]
        y2 = fn(x)[1]
        self.assertEqual(y1.item(), y2.item())
        self.assertEqual(y1.item(), 6)

    def test_inf_constant(self):
        def fn(x):
            return mx.where(mx.isinf(x), 0, 1)

        x = mx.array([0, float("inf"), 1], dtype=mx.bfloat16)
        self.assertTrue(mx.array_equal(mx.compile(fn)(x), fn(x)))

    def test_max_into_equal(self):
        x = mx.random.uniform(shape=(1, 2, 2))
        mx.eval(x)

        def fn():
            maxes = mx.max(x, axis=(1, 2), keepdims=True)
            return x == maxes

        out = mx.compile(fn)()
        expected = fn()
        self.assertTrue(mx.array_equal(expected, out))

    def test_dtypes(self):
        x = mx.array([0, 1, 2, 3])
        dtypes = [mx.bool_, mx.int8, mx.uint8, mx.int16, mx.uint16]
        for dtype in dtypes:
            # Each cast starts from the previous iteration's array.
            x = x.astype(dtype)
            mx.eval(x)

            def fn(x):
                return x * 1 + 0

            out = mx.compile(fn)(x)
            expected = fn(x)
            self.assertTrue(mx.array_equal(expected, out))

    def test_compile_without_captured_inputs(self):
        # Reading an unevaluated array from the enclosing scope (rather than
        # through an argument or `inputs=`) raises ValueError.
        x = mx.array([1, 2, 3]) + 2

        def fn(a):
            y = x + 1
            return a + y

        with self.assertRaises(ValueError):
            mx.compile(fn)(x)

        # Same when the captured array was produced by an earlier eager call.
        x = mx.array([1.0, 2.0]) + mx.array([1.0, 2.0])
        y = None

        def fn(x):
            nonlocal y
            if y is None:
                y = mx.array([1.0, 2.0])

            y = y + x
            return y

        fn(x)
        with self.assertRaises(ValueError):
            mx.compile(fn)(x)

    def test_compile_dynamic_dims(self):
        # Ten-dimensional inputs, one of them transposed (non-contiguous).
        a = mx.random.uniform(shape=(2,) * 10)
        b = mx.random.uniform(shape=(2,) * 10)
        a = a.T
        mx.eval(a, b)

        def fn(a, b):
            return mx.abs(a + b)

        out = mx.compile(fn)(a, b)
        expected = fn(a, b)
        self.assertTrue(mx.allclose(out, expected))

    def test_compile_many_inputs(self):
        inputs = [mx.ones((2, 2, 2, 2)) for _ in range(20)]
        inputs[0] = inputs[0].T

        @mx.compile
        def fun(*inputs):
            x = inputs[0]
            for y in inputs[1:10]:
                x = x + y
            a = inputs[10]
            for b in inputs[11:]:
                a = a + b
            return x + a

        out = fun(*inputs)
        self.assertTrue(mx.allclose(out, mx.full((2, 2), 20)))

        # Pairwise tree reduction of 64 inputs.
        @mx.compile
        def fun(arrs):
            for _ in range(6):
                arrs = [x + y for x, y in zip(arrs[::2], arrs[1::2])]
            return arrs[0]

        arrs = [mx.array([1.0, 2.0]) for _ in range(64)]
        out = fun(arrs)
        self.assertTrue(mx.allclose(out, mx.array([64.0, 128.0])))

        inputs = [mx.arange(16384).astype(mx.float16) for _ in range(8)]

        def fun(inputs):
            a = inputs[0] + inputs[1]
            b = inputs[2] + inputs[3]
            c = inputs[4] + inputs[5]
            d = inputs[6] + inputs[7]
            return a * b * c * d

        out = mx.compile(fun)(inputs)
        expected = fun(inputs)
        self.assertTrue(mx.allclose(out, expected))

    def test_compile_many_outputs(self):
        @mx.compile
        def fun(arr):
            arrs = [arr] * 64
            first_arrs = None
            for _ in range(6):
                arrs = [x + y for x, y in zip(arrs[::2], arrs[1::2])]
                if first_arrs is None:
                    # Keep the 32 intermediate results as extra outputs.
                    first_arrs = arrs
            return arrs[0], first_arrs

        out = fun(mx.array([1.0, 2.0]))
        self.assertTrue(mx.allclose(out[0], mx.array([64.0, 128.0])))

    def test_shapeless_compile_matmul(self):
        a = mx.array([0.0, 1.0, 2.0])
        b = mx.array([0.0, 1.0, 2.0])

        fun = mx.compile(lambda a, b: a @ b, shapeless=True)
        self.assertTrue(mx.allclose(fun(a, b), a @ b))

    def test_shapeless_compile_slice_update(self):
        def fun(x):
            x[2] = mx.array([3.0])
            return x

        cfun = mx.compile(fun, shapeless=True)

        a = mx.array([0.0, 1.0, 2.0, 3.0])
        self.assertTrue(mx.allclose(cfun(a), fun(a)))

        a = mx.array([0.0, 1.0, 2.0, 3.0, 4.0])
        self.assertTrue(mx.allclose(cfun(a), fun(a)))

    def test_shapeless_compile_with_reshape(self):
        def fun(x):
            return x.reshape(x.shape[0] * x.shape[1], -1)

        compiled_fun = mx.compile(fun, shapeless=True)

        x = mx.zeros(shape=(2, 3, 4))
        out = compiled_fun(x)
        self.assertEqual(out.shape, (6, 4))

        x = mx.zeros(shape=(2, 3, 8))
        out = compiled_fun(x)
        self.assertEqual(out.shape, (6, 8))

        # x.shape is read in Python at trace time, so the leading dimension
        # stays 6; 125 elements cannot be reshaped to (6, -1).
        x = mx.zeros(shape=(5, 5, 5))

        with self.assertRaises(ValueError):
            compiled_fun(x)

    def test_compile_shapeless_with_broadcast(self):
        a = mx.array(0.0)
        b = mx.ones((2, 2))

        def fun(a):
            return mx.broadcast_to(a, b.shape)

        cfun = mx.compile(fun, shapeless=True)
        # Broadcasting a scalar to (2, 2) succeeds.
        cfun(a)

        # An input with more dimensions than the target cannot be broadcast.
        with self.assertRaises(ValueError):
            cfun(mx.array(0.0).reshape(1, 1, 1))

        def fun(a, b):
            return mx.broadcast_arrays(a, b)

        cfun = mx.compile(fun, shapeless=True)
        a, b = cfun(a, b)
        self.assertEqual(a.shape, (2, 2))
        self.assertEqual(b.shape, (2, 2))

        # Batched matmul with broadcasting batch dimensions.
        a = mx.zeros((2, 1, 4, 2))
        b = mx.zeros((3, 2, 5))

        def fun(a, b):
            return a @ b

        cfun = mx.compile(fun, shapeless=True)
        out = cfun(a, b)
        self.assertEqual(out.shape, (2, 3, 4, 5))

        # Gradients through an implicit broadcast have the input shapes.
        def fun(args):
            return sum(args).sum()

        a = mx.array(0.0)
        b = mx.ones((2, 2))

        cfun = mx.compile(mx.grad(fun), shapeless=True)
        out = cfun((a, b))

        self.assertEqual(out[0].shape, ())
        self.assertEqual(out[1].shape, (2, 2))

        out = cfun((b, a))

        self.assertEqual(out[0].shape, (2, 2))
        self.assertEqual(out[1].shape, ())

        # Same for gradients through a broadcasting matmul.
        def fun(args):
            return (args[0] @ args[1]).sum()

        a = mx.zeros((2, 1, 4, 2))
        b = mx.zeros((3, 2, 5))

        cfun = mx.compile(mx.grad(fun), shapeless=True)
        out = cfun((a, b))

        self.assertEqual(out[0].shape, (2, 1, 4, 2))
        self.assertEqual(out[1].shape, (3, 2, 5))

        a = mx.zeros((3, 1, 4, 2))
        b = mx.zeros((2, 2, 5))

        out = cfun((a, b))

        self.assertEqual(out[0].shape, (3, 1, 4, 2))
        self.assertEqual(out[1].shape, (2, 2, 5))

    def test_leaks(self):
        # A compiled function stored in a reference cycle with the array it
        # reads must be collectable: active memory returns to its start value.
        gc.collect()
        if mx.metal.is_available():
            mem_pre = mx.get_active_memory()
        else:
            mem_pre = 0

        def outer():
            d = {}

            def f(x):
                return d["x"]

            d["f"] = mx.compile(f)
            d["x"] = mx.array([0] * 1000)

        for _ in range(5):
            outer()
            gc.collect()

        if mx.metal.is_available():
            mem_post = mx.get_active_memory()
        else:
            mem_post = 0

        self.assertEqual(mem_pre, mem_post)

    def test_double_constant(self):
        # float64 work is done on the CPU stream.
        with mx.stream(mx.cpu):
            x = mx.array(1.0, dtype=mx.float64)

            def fun(x):
                return (x + math.pi) * 2.0

            y = fun(x).item()
            y_compiled = mx.compile(fun)(x).item()
            self.assertEqual(y, y_compiled)

    def test_shared_broadcast(self):
        def fun(x, y, z):
            yy = mx.broadcast_to(y, z.shape)
            return (x + yy * z), yy.sum()

        a = mx.random.normal((10, 10))
        b = mx.array(0.1)
        c = mx.random.normal((10, 10))
        mx.eval(a, b, c)
        fc = mx.compile(fun)
        d = fc(a, b, c)

        s = StringIO()
        mx.export_to_dot(s, a=a, b=b, c=c, d1=d[0], d2=d[1])
        s.seek(0)
        s = s.read()

        # The broadcast is expected to be fused into the compiled kernel.
        self.assertIn("CompiledBroadcastMultiplyAdd", s)
        d_hat = fun(a, b, c)
        self.assertTrue(mx.allclose(d[0], d_hat[0]))
        self.assertTrue(mx.allclose(d[1], d_hat[1]))

    def test_wrap_compiled(self):
        # Applying functools.wraps to a compiled function must not raise.
        @mx.compile
        def inner():
            pass

        @wraps(inner)
        def wrapper():
            pass

    def test_compiled_preserves_attributes(self):
        def inner(x: mx.array, y: str):
            """
            A useful function.
            """
            pass

        c_inner = mx.compile(inner)
        self.assertEqual(inner.__name__, c_inner.__name__)
        self.assertEqual(inner.__qualname__, c_inner.__qualname__)
        self.assertEqual(inner.__doc__, c_inner.__doc__)
        self.assertEqual(inspect.signature(inner), inspect.signature(c_inner))

    def test_compile_with_none(self):
        @mx.compile
        def fun(x, y):
            if y is None:
                return mx.abs(x - 2.0)
            else:
                return mx.abs(x + y)

        out = fun(mx.array(1.0), None)
        self.assertEqual(out.item(), 1.0)

        out = fun(mx.array(1.0), mx.array(2.0))
        self.assertEqual(out.item(), 3.0)

    def test_compile_changing_outputs(self):
        # The output structure may differ between calls of one compiled fn.
        @mx.compile
        def fun(x, y):
            if y is None:
                return 2 * x
            elif (
                isinstance(x, mx.array)
                and isinstance(y, mx.array)
                and x.dtype == y.dtype == mx.float32
            ):
                return [x + y]
            elif y.dtype == mx.bool_:
                return {"a": x, "b": y * x}
            else:
                return None

        a = fun(mx.array(1.0), mx.array(2.0))
        self.assertIsInstance(a, list)
        self.assertEqual(a[0].item(), 3.0)

        b = fun(mx.array(1.0), mx.array(True))
        self.assertIsInstance(b, dict)
        self.assertEqual(b["a"].item(), 1.0)
        self.assertEqual(b["b"].item(), 1.0)

        c = fun(mx.array(1.0), None)
        self.assertIsInstance(c, mx.array)
        self.assertEqual(c.item(), 2.0)

        d = fun(False, mx.array(1.0))
        self.assertIsNone(d)

    def test_compile_changing_outputs_with_state(self):
        state = [mx.array(1.0)]

        @partial(mx.compile, inputs=state, outputs=state)
        def fun(y):
            x = state[0]
            if y.dtype == mx.float32:
                state[0] = 2 * y
                return [x, y, x + y]
            elif y.dtype == mx.int32:
                state[0] *= 2
                return x + y

        # Alternate between the two branches; each pair leaves state == 4.
        for _ in range(10):
            fun(mx.array(1.0))
            fun(mx.array(1))

        self.assertEqual(state[0].item(), 4)

    def test_outputs_changing(self):
        @mx.compile
        def fun(x):
            x = mx.abs(mx.negative(x))
            y = mx.abs(x)
            return x, y

        @mx.compile
        def fun2(x):
            x = mx.abs(mx.negative(x))
            y = mx.abs(x)
            return y

        a, b = fun(mx.array(-1.0))
        mx.eval(a, b)

        # Same graph, but only the last value is an output.
        a = fun2(mx.array(-1.0))
        self.assertEqual(a.item(), 1.0)
|
| |
|
if __name__ == "__main__":
    # Run this module's test cases through the MLX-specific runner.
    mlx_tests.MLXTestRunner()
| |
|