| | |
| |
|
| | import argparse |
| |
|
| | import mlx.core as mx |
| | from time_utils import time_fn |
| |
|
# Benchmark problem size: batch of B sequences, each T tokens of dimension D.
B = 8  # batch size
T = 1024  # sequence length
D = 512  # feature dimension
| |
|
| |
|
def time_batch_matmul():
    """Time the batched (B, T, D) @ (D, D) matmul and both of its VJPs.

    The VJPs are obtained through ``mx.vjp`` so this measures the
    autodiff path; compare against ``time_unbatch_matmul`` which times
    the equivalent hand-written gradients on flattened inputs.
    """
    mx.random.seed(3)
    lhs = mx.random.uniform(shape=(B, T, D))
    rhs = mx.random.uniform(shape=(D, D))
    cotan = mx.random.uniform(shape=(B, T, D))
    # Materialize the inputs so setup cost is excluded from the timings.
    mx.eval(lhs, rhs, cotan)

    time_fn(mx.matmul, lhs, rhs)

    # Gradient w.r.t. the first operand via automatic VJP.
    def batch_vjp_first():
        return mx.vjp(mx.matmul, [lhs, rhs], [cotan])[1][0]

    time_fn(batch_vjp_first)

    # Gradient w.r.t. the second operand via automatic VJP.
    def batch_vjp_second():
        return mx.vjp(mx.matmul, [lhs, rhs], [cotan])[1][1]

    time_fn(batch_vjp_second)
| |
|
| |
|
def time_unbatch_matmul():
    """Time the flattened (B*T, D) @ (D, D) matmul and hand-written VJPs.

    Same total work as ``time_batch_matmul`` but with the batch folded
    into the row dimension, and with the matmul gradients expressed
    directly via transposes instead of ``mx.vjp``.
    """
    mx.random.seed(3)
    lhs = mx.random.uniform(shape=(B * T, D))
    rhs = mx.random.uniform(shape=(D, D))
    cotan = mx.random.uniform(shape=(B * T, D))
    mx.eval(lhs, rhs, cotan)  # materialize inputs before timing
    time_fn(mx.matmul, lhs, rhs)

    # Manual VJP w.r.t. the first operand: cotangent @ rhs^T.
    def unbatch_vjp_first():
        return mx.matmul(cotan, mx.transpose(rhs))

    time_fn(unbatch_vjp_first)

    # Manual VJP w.r.t. the second operand: lhs^T @ cotangent.
    def unbatch_vjp_second():
        return mx.matmul(mx.transpose(lhs), cotan)

    time_fn(unbatch_vjp_second)
| |
|
| |
|
| | if __name__ == "__main__": |
| | parser = argparse.ArgumentParser("MLX benchmarks.") |
| | parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.") |
| | args = parser.parse_args() |
| | if args.gpu: |
| | mx.set_default_device(mx.gpu) |
| | else: |
| | mx.set_default_device(mx.cpu) |
| |
|
| | time_batch_matmul() |
| | time_unbatch_matmul() |
| |
|