|
import collections |
|
import math |
|
import re |
|
from typing import Any, Dict, Sequence |
|
|
|
import torch |
|
import triton |
|
|
|
from .diff_engine import DiffCase |
|
|
|
|
|
def make_fwd_key(batch_size, seq_len, dim):
    """Return the forward-pass timing key, e.g. ``"forward : (2, 128, 64)"``."""
    return "forward : ({}, {}, {})".format(batch_size, seq_len, dim)
|
|
|
|
|
def make_bwd_key(batch_size, seq_len, dim):
    """Return the backward-pass timing key, e.g. ``"backward : (2, 128, 64)"``."""
    return "backward : ({}, {}, {})".format(batch_size, seq_len, dim)
|
|
|
|
|
def parse_config_string(config_str):
    """Parse a key string like ``"forward : (2, 128, 64)"``.

    Returns:
        A ``(batch_size, seq_len, dim)`` tuple of ints.

    Raises:
        ValueError: if *config_str* does not match the expected shape.
    """
    pattern = r"(\w+)\s*:\s*\(\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\)"
    m = re.match(pattern, config_str)
    if m is None:
        raise ValueError(f"Invalid config string: {config_str}")
    # Group 1 is the direction label ("forward"/"backward"); only the three
    # numeric groups are returned.
    return tuple(int(tok) for tok in m.groups()[1:])
|
|
|
|
|
def make_fwd_benchmark_for_case(
    *,
    case: DiffCase,
    configs: Sequence[tuple[int, int, int]],
    plot_name: str,
    ylabel: str = "us",
    line_vals=("naive", "cuda", "speedup"),
    line_names: Dict[str, str] | None = None,
    dtype=torch.bfloat16,
    eps: float = 1e-6,
    time_unit_scale: float = 1000,
):
    """Build a ``triton.testing.perf_report`` benchmark for *case*'s forward pass.

    Args:
        case: provides input construction, naive/cuda object construction,
            and the forward call that is timed.
        configs: one tuple per x-axis point, unpacked in ``x_names`` order
            ``(dim, batch_size, seq_len)`` — NOTE(review): confirm tuple
            order against callers.
        plot_name: title forwarded to triton's ``Benchmark``.
        ylabel: y-axis label; reported value is ``time_unit_scale * ms``
            (microseconds with the default scale of 1000).
        line_vals: providers to benchmark. "speedup" must come after both
            "naive" and "cuda" because it reads their cached timings.
        line_names: optional display name per provider; defaults to
            ``str.title()`` of each provider.
        dtype: tensor dtype passed to ``case.build_inputs``.
        eps: epsilon passed to ``case.build_inputs``.
        time_unit_scale: multiplier applied to the raw millisecond timing.

    Returns:
        The decorated benchmark function (use per triton's perf_report API).
    """
    timings_ms = collections.defaultdict(dict)
    line_vals = list(line_vals)
    line_names = line_names or {v: v.title() for v in line_vals}
    x_vals = [list(cfg) for cfg in configs]

    @triton.testing.perf_report(
        triton.testing.Benchmark(x_names=["dim", "batch_size", "seq_len"],
                                 x_vals=x_vals,
                                 line_arg="provider",
                                 line_vals=line_vals,
                                 line_names=[line_names[v] for v in line_vals],
                                 ylabel=ylabel,
                                 plot_name=plot_name,
                                 args={}))
    def bench(dim, batch_size, seq_len, provider):
        # Fix: pass arguments in make_fwd_key's declared
        # (batch_size, seq_len, dim) order; the original passed
        # (dim, batch_size, seq_len), mislabeling the key string. The key is
        # only used inside this closure, so the change is behavior-neutral.
        key = make_fwd_key(batch_size, seq_len, dim)
        if provider == "speedup":
            # Checked before building inputs to avoid a wasted allocation.
            # Runs after "naive"/"cuda" (line_vals order), so both timings
            # are already cached.
            return timings_ms["naive"][key] / timings_ms["cuda"][key]
        inputs = case.build_inputs(batch_size, seq_len, dim, dtype, eps)
        obj = case.make_naive(inputs) if provider == "naive" else case.make_cuda(inputs)

        def run():
            return case.forward(obj, inputs)

        ms = triton.testing.do_bench(run)
        timings_ms[provider][key] = ms
        return time_unit_scale * ms

    return bench
|
|
|
|
|
def make_fwd_benchmark_plot_for_case(
    *,
    case: DiffCase,
    configs: Sequence[tuple[int, int, int]],
    plot_name: str,
    ylabel: str = "Relative Speedup",
    line_vals=("naive", "cuda"),
    line_names: Dict[str, str] | None = None,
    dtype=torch.bfloat16,
    eps: float = 1e-6,
):
    """Build a perf_report plot of cuda-vs-naive forward speedups.

    Each x-axis entry is a key string from ``make_fwd_key`` plus a final
    "Geometric Mean" column summarizing all per-config speedups. The naive
    line is pinned at 1.00; the cuda line reports naive_ms / cuda_ms.

    "naive" must precede "cuda" in *line_vals* because the cuda branch reads
    the naive timing cached for the same config.

    Args mirror ``make_fwd_benchmark_for_case`` (minus the time scale), with
    *configs* tuples consumed as ``(batch_size, seq_len, dim)``.

    Returns:
        The decorated benchmark function.
    """
    timings_ms = collections.defaultdict(dict)
    spdup_ratio = []
    line_vals = list(line_vals)
    line_names = line_names or {v: v.title() for v in line_vals}
    x_vals = [make_fwd_key(*cfg) for cfg in configs]
    x_vals.append("Geometric Mean")

    @triton.testing.perf_report(
        triton.testing.Benchmark(x_names=["config"],
                                 x_vals=x_vals,
                                 line_arg="provider",
                                 line_vals=line_vals,
                                 line_names=[line_names[v] for v in line_vals],
                                 ylabel=ylabel,
                                 plot_name=plot_name,
                                 args={}))
    def bench(config, provider):
        if config == "Geometric Mean":
            # Fix: guard against an empty ratio list (no configs) so we
            # return a neutral 1.00 instead of raising ZeroDivisionError.
            if provider == "cuda" and spdup_ratio:
                return round(math.prod(spdup_ratio)**(1 / len(spdup_ratio)), 2)
            return 1.00
        batch_size, seq_len, dim = parse_config_string(config)
        inputs = case.build_inputs(batch_size, seq_len, dim, dtype, eps)
        obj = case.make_naive(inputs) if provider == "naive" else case.make_cuda(inputs)

        def run():
            return case.forward(obj, inputs)

        ms = triton.testing.do_bench(run)
        timings_ms[provider][config] = ms
        if provider != "cuda":
            # The naive baseline is, by definition, a 1.00 relative speedup.
            return 1.00
        ratio = timings_ms["naive"][config] / ms
        spdup_ratio.append(ratio)
        return round(ratio, 2)

    return bench
|
|
|
|
|
def make_bwd_benchmark_for_case(
    *,
    case: DiffCase,
    configs: Sequence[tuple[int, int, int]],
    plot_name: str,
    ylabel: str = "us",
    line_vals=("naive", "cuda", "speedup"),
    line_names: Dict[str, str] | None = None,
    dtype=torch.bfloat16,
    eps: float = 1e-6,
    time_unit_scale: float = 1000,
):
    """Build a ``triton.testing.perf_report`` benchmark for *case*'s backward pass.

    The forward pass runs once outside the timed region; only the
    ``torch.autograd.grad`` call is benchmarked (``retain_graph=True`` so the
    same graph can be replayed by ``do_bench``).

    Args:
        case: provides input construction, naive/cuda object construction,
            the forward call, and ``grad_inputs`` (the input tensors to
            differentiate with respect to, alongside ``obj.parameters()``).
        configs: one tuple per x-axis point, unpacked in ``x_names`` order
            ``(dim, batch_size, seq_len)`` — NOTE(review): confirm tuple
            order against callers.
        plot_name: title forwarded to triton's ``Benchmark``.
        ylabel: y-axis label; reported value is ``time_unit_scale * ms``.
        line_vals: providers to benchmark. "speedup" must come after both
            "naive" and "cuda" because it reads their cached timings.
        line_names: optional display name per provider.
        dtype: tensor dtype passed to ``case.build_inputs``.
        eps: epsilon passed to ``case.build_inputs``.
        time_unit_scale: multiplier applied to the raw millisecond timing.

    Returns:
        The decorated benchmark function.
    """
    timings_ms = collections.defaultdict(dict)
    line_vals = list(line_vals)
    line_names = line_names or {v: v.title() for v in line_vals}
    x_vals = [list(cfg) for cfg in configs]

    @triton.testing.perf_report(
        triton.testing.Benchmark(x_names=["dim", "batch_size", "seq_len"],
                                 x_vals=x_vals,
                                 line_arg="provider",
                                 line_vals=line_vals,
                                 line_names=[line_names[v] for v in line_vals],
                                 ylabel=ylabel,
                                 plot_name=plot_name,
                                 args={}))
    def bench(dim, batch_size, seq_len, provider):
        # Fix: pass arguments in make_bwd_key's declared
        # (batch_size, seq_len, dim) order; the original passed
        # (dim, batch_size, seq_len), mislabeling the key string. The key is
        # only used inside this closure, so the change is behavior-neutral.
        key = make_bwd_key(batch_size, seq_len, dim)
        if provider == "speedup":
            # Checked before building inputs to avoid a wasted allocation.
            return timings_ms["naive"][key] / timings_ms["cuda"][key]
        inputs = case.build_inputs(batch_size, seq_len, dim, dtype, eps)
        obj = case.make_naive(inputs) if provider == "naive" else case.make_cuda(inputs)
        y = case.forward(obj, inputs)
        gin = list(case.grad_inputs(inputs)) + list(obj.parameters())
        # Random upstream gradients, one per forward output.
        if isinstance(y, torch.Tensor):
            g = [torch.randn_like(y)]
        else:
            g = [torch.randn_like(r) for r in y]

        def run():
            return torch.autograd.grad(y,
                                       gin,
                                       g,
                                       retain_graph=True,
                                       create_graph=False,
                                       allow_unused=False)

        ms = triton.testing.do_bench(run)
        timings_ms[provider][key] = ms
        return time_unit_scale * ms

    return bench
|
|
|
|
|
def make_bwd_benchmark_plot_for_case(
    *,
    case: DiffCase,
    configs: Sequence[tuple[int, int, int]],
    plot_name: str,
    ylabel: str = "Relative Speedup",
    line_vals=("naive", "cuda"),
    line_names: Dict[str, str] | None = None,
    dtype=torch.bfloat16,
    eps: float = 1e-6,
):
    """Build a perf_report plot of cuda-vs-naive backward speedups.

    Each x-axis entry is a key string from ``make_bwd_key`` plus a final
    "Geometric Mean" column summarizing all per-config speedups. The naive
    line is pinned at 1.00; the cuda line reports naive_ms / cuda_ms. Only
    the ``torch.autograd.grad`` call is timed; the forward pass runs once
    outside the timed region.

    "naive" must precede "cuda" in *line_vals* because the cuda branch reads
    the naive timing cached for the same config.

    Args mirror ``make_bwd_benchmark_for_case`` (minus the time scale), with
    *configs* tuples consumed as ``(batch_size, seq_len, dim)``.

    Returns:
        The decorated benchmark function.
    """
    timings_ms = collections.defaultdict(dict)
    spdup_ratio = []
    line_vals = list(line_vals)
    line_names = line_names or {v: v.title() for v in line_vals}
    x_vals = [make_bwd_key(*cfg) for cfg in configs]
    x_vals.append("Geometric Mean")

    @triton.testing.perf_report(
        triton.testing.Benchmark(x_names=["config"],
                                 x_vals=x_vals,
                                 line_arg="provider",
                                 line_vals=line_vals,
                                 line_names=[line_names[v] for v in line_vals],
                                 ylabel=ylabel,
                                 plot_name=plot_name,
                                 args={}))
    def bench(config, provider):
        if config == "Geometric Mean":
            # Fix: guard against an empty ratio list (no configs) so we
            # return a neutral 1.00 instead of raising ZeroDivisionError.
            if provider == "cuda" and spdup_ratio:
                return round(math.prod(spdup_ratio)**(1 / len(spdup_ratio)), 2)
            return 1.00
        batch_size, seq_len, dim = parse_config_string(config)
        inputs = case.build_inputs(batch_size, seq_len, dim, dtype, eps)
        obj = case.make_naive(inputs) if provider == "naive" else case.make_cuda(inputs)
        y = case.forward(obj, inputs)
        gin = list(case.grad_inputs(inputs)) + list(obj.parameters())
        # Random upstream gradients, one per forward output.
        if isinstance(y, torch.Tensor):
            g = [torch.randn_like(y)]
        else:
            g = [torch.randn_like(r) for r in y]

        def run():
            return torch.autograd.grad(y,
                                       gin,
                                       g,
                                       retain_graph=True,
                                       create_graph=False,
                                       allow_unused=False)

        ms = triton.testing.do_bench(run)
        timings_ms[provider][config] = ms
        if provider != "cuda":
            # The naive baseline is, by definition, a 1.00 relative speedup.
            return 1.00
        ratio = timings_ms["naive"][config] / ms
        spdup_ratio.append(ratio)
        return round(ratio, 2)

    return bench
|
|