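"""Benchmark batched GEMM (matrix multiply) in MLX against PyTorch on the MPS
backend for several shapes, dtypes, and transpose layouts. Prints one
CSV-style row of GFLOP/s per configuration and flags cases where PyTorch is
at least 2x faster than MLX."""
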
import argparse
import math
import os
import subprocess
import time

import mlx.core as mx
import numpy as np
import torch

# Name of the host processor, e.g. "Apple M1 Max".
device_name = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
device_name = device_name.decode("utf-8").strip("\n")

N_warmup = 8  # un-timed warm-up calls before measurement starts
N_iter_bench = 80  # timed calls to the benchmarked function
N_iter_func = 5  # matmuls performed inside each call


def bench(f, a, b):
    # Warm up and make sure any pending MPS work is done before timing.
    for i in range(N_warmup):
        f(a, b)
    torch.mps.synchronize()

    # Time N_iter_bench calls; each call synchronizes internally.
    s = time.perf_counter_ns()
    for i in range(N_iter_bench):
        f(a, b)
    e = time.perf_counter_ns()
    return (e - s) * 1e-9


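# Each GEMM helper below runs N_iter_func matmuls per call and synchronizes
# (mx.eval for MLX, torch.mps.synchronize for PyTorch) so the work is
# actually executed before bench() stops the clock.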
def gemm_nn_mlx(a, b):
    ys = []
    for i in range(N_iter_func):
        y = a @ b
        ys.append(y)
    mx.eval(ys)
    return ys


def gemm_nt_mlx(a, b):
    ys = []
    for i in range(N_iter_func):
        y = a @ b.transpose((0, 2, 1))
        ys.append(y)
    mx.eval(ys)
    return ys


def gemm_tn_mlx(a, b):
    ys = []
    for i in range(N_iter_func):
        y = a.transpose((0, 2, 1)) @ b
        ys.append(y)
    mx.eval(ys)
    return ys


def gemm_tt_mlx(a, b):
    ys = []
    for i in range(N_iter_func):
        y = a.transpose((0, 2, 1)) @ b.transpose((0, 2, 1))
        ys.append(y)
    mx.eval(ys)
    return ys


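# PyTorch (MPS) counterparts of the MLX GEMM helpers above.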
@torch.no_grad()
def gemm_nn_torch(a, b):
    ys = []
    for i in range(N_iter_func):
        y = a @ b
        ys.append(y)
    torch.mps.synchronize()
    return ys


@torch.no_grad()
def gemm_nt_torch(a, b):
    ys = []
    for i in range(N_iter_func):
        y = a @ b.transpose(-1, -2)
        ys.append(y)
    torch.mps.synchronize()
    return ys


@torch.no_grad()
def gemm_tn_torch(a, b):
    ys = []
    for i in range(N_iter_func):
        y = a.transpose(-1, -2) @ b
        ys.append(y)
    torch.mps.synchronize()
    return ys


@torch.no_grad()
def gemm_tt_torch(a, b):
    ys = []
    for i in range(N_iter_func):
        y = a.transpose(-1, -2) @ b.transpose(-1, -2)
        ys.append(y)
    torch.mps.synchronize()
    return ys


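# Build random inputs for one (B, M, N, K) problem, time both frameworks,
# and sanity-check the MLX result against NumPy.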
def bench_shape(B, M, N, K, np_dtype, transpose="nn"):
    # Operands marked "t" are laid out transposed in memory; the GEMM helpers
    # apply a transpose view before multiplying.
    shape_a = (B, M, K) if transpose[0] == "n" else (B, K, M)
    shape_b = (B, K, N) if transpose[1] == "n" else (B, N, K)

    a_np = np.random.normal(0.0, 1.0 / math.sqrt(M + K), shape_a).astype(np_dtype)
    b_np = np.random.normal(0.0, 1.0 / math.sqrt(N + K), shape_b).astype(np_dtype)

    a_mx = mx.array(a_np)
    b_mx = mx.array(b_np)

    a_pt = torch.from_numpy(a_np).to("mps")
    b_pt = torch.from_numpy(b_np).to("mps")

    torch.mps.synchronize()

    f_mx = {
        "nn": gemm_nn_mlx,
        "nt": gemm_nt_mlx,
        "tn": gemm_tn_mlx,
        "tt": gemm_tt_mlx,
    }[transpose]

    f_pt = {
        "nn": gemm_nn_torch,
        "nt": gemm_nt_torch,
        "tn": gemm_tn_torch,
        "tt": gemm_tt_torch,
    }[transpose]

    time_torch = bench(f_pt, a_pt, b_pt)
    time_mlx = bench(f_mx, a_mx, b_mx)

    # Verify one MLX matmul against a NumPy reference.
    t_a = (0, 1, 2) if transpose[0] == "n" else (0, 2, 1)
    t_b = (0, 1, 2) if transpose[1] == "n" else (0, 2, 1)

    c_mlx = a_mx.transpose(t_a) @ b_mx.transpose(t_b)
    c_npy = a_np.transpose(t_a).astype(np_dtype) @ b_np.transpose(t_b).astype(np_dtype)

    atol = 1e-5 if np_dtype == np.float32 else 1e-4

    if not np.allclose(c_mlx, c_npy.astype(np_dtype), atol=atol):
        print(
            f"Failed at {(B, M, N, K)} [transpose = {transpose}] "
            f"with max(|a - b|) = {np.max(np.abs(c_npy - c_mlx))}"
        )

    return time_mlx, time_torch


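# FLOP count across all timed iterations: 2*M*N*K flops per matmul, with
# 1024**3 (rather than 1e9) used as the "giga" divisor.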
def get_gflop_count(B, M, N, K):
    return float(2.0 * N_iter_bench * N_iter_func * B * M * N * K) / float(1024.0**3)


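# Sweep dtypes, transpose modes, and shapes; print one CSV-style row per
# configuration comparing PyTorch (MPS) and MLX throughput.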
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run gemm benchmarks")
    # No options are defined yet; parsing still provides --help.
    parser.parse_args()

    dtypes = ("float32", "float16", "complex64")
    transposes = ("nn", "nt", "tn")
    shapes = (
        (16, 234, 768, 3072),
        (1, 64, 64, 25344),
        (16, 1024, 1024, 1024),
        (1, 1024, 1024, 2048),
        (4, 1024, 1024, 4096),
        (4, 1024, 4096, 1024),
        (1, 4096, 4096, 4096),
    )

    for dtype in dtypes:
        for transpose in transposes:
            for B, M, N, K in shapes:
                np_dtype = getattr(np, dtype)
                time_mlx, time_torch = bench_shape(B, M, N, K, np_dtype, transpose)

                gflop_count = get_gflop_count(B, M, N, K)
                gflops_mx = gflop_count / time_mlx
                gflops_pt = gflop_count / time_torch
                diff = gflops_mx / gflops_pt - 1.0

                # Columns: B, M, N, K, dtype, transpose, GFLOP/s (torch),
                # GFLOP/s (mlx), MLX speedup over torch in percent.
                print(
                    f"{B:3d}, {M:4d}, {N:4d}, {K:4d}, {dtype}, {transpose}, "
                    f"{gflops_pt:05.3f}, {gflops_mx:05.3f}, {100.0 * diff:+5.2f}%"
                )
                # Flag cases where PyTorch is at least 2x faster than MLX.
                if gflops_pt >= 2.0 * gflops_mx:
                    print("ATTENTION ^^^^^^^")