# activation/benchmarks/common/bench_framework.py
import collections
import math
import re
from typing import Dict, Sequence

import torch
import triton

from .diff_engine import DiffCase


def make_fwd_key(batch_size, seq_len, dim):
    """Label for a forward config; also used as the timing-cache key."""
    return f"forward : ({batch_size}, {seq_len}, {dim})"


def make_bwd_key(batch_size, seq_len, dim):
    """Label for a backward config; also used as the timing-cache key."""
    return f"backward : ({batch_size}, {seq_len}, {dim})"


def parse_config_string(config_str):
    """Invert make_fwd_key / make_bwd_key: recover (batch_size, seq_len, dim)."""
    match = re.match(r"(\w+)\s*:\s*\(\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\)",
                     config_str)
    if not match:
        raise ValueError(f"Invalid config string: {config_str}")
    _, bs, sl, d = match.groups()
    return int(bs), int(sl), int(d)
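
# Hedged round-trip illustration of the two helpers above (values arbitrary):
#     make_fwd_key(2, 128, 1024)                      -> "forward : (2, 128, 1024)"
#     parse_config_string("forward : (2, 128, 1024)") -> (2, 128, 1024)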


def make_fwd_benchmark_for_case(
    *,
    case: DiffCase,
    configs: Sequence[tuple[int, int, int]],
    plot_name: str,
    ylabel: str = "us",
    line_vals=("naive", "cuda", "speedup"),
    line_names: Dict[str, str] | None = None,
    dtype=torch.bfloat16,
    eps: float = 1e-6,
    time_unit_scale: float = 1000,
):
    """Build a forward-pass benchmark; do_bench reports ms, scaled to us by default."""
    timings_ms = collections.defaultdict(dict)
    line_vals = list(line_vals)
    line_names = line_names or {v: v.title() for v in line_vals}
    # configs holds (batch_size, seq_len, dim) tuples, so x_names must list the
    # axes in that same order for triton to bind the values correctly.
    x_vals = [list(cfg) for cfg in configs]

    @triton.testing.perf_report(
        triton.testing.Benchmark(x_names=["batch_size", "seq_len", "dim"],
                                 x_vals=x_vals,
                                 line_arg="provider",
                                 line_vals=line_vals,
                                 line_names=[line_names[v] for v in line_vals],
                                 ylabel=ylabel,
                                 plot_name=plot_name,
                                 args={}))
    def bench(batch_size, seq_len, dim, provider):
        key = make_fwd_key(batch_size, seq_len, dim)
        if provider == "speedup":
            # "naive" and "cuda" precede "speedup" in line_vals, so both
            # timings for this key are already recorded.
            return timings_ms["naive"][key] / timings_ms["cuda"][key]
        inputs = case.build_inputs(batch_size, seq_len, dim, dtype, eps)
        obj = case.make_naive(inputs) if provider == "naive" else case.make_cuda(inputs)
        run = lambda: case.forward(obj, inputs)
        ms = triton.testing.do_bench(run)
        timings_ms[provider][key] = ms
        return time_unit_scale * ms

    return bench
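
# Hedged usage sketch: `triton.testing.perf_report` wraps `bench` in a Mark,
# so the factory's return value is executed via `.run()`. `my_case` is a
# placeholder for any concrete DiffCase, not an object this module defines.
#
#     fwd_bench = make_fwd_benchmark_for_case(
#         case=my_case,
#         configs=[(1, 1024, 4096), (4, 2048, 8192)],
#         plot_name="fused_add_rms_norm_fwd",
#     )
#     fwd_bench.run(print_data=True, save_path="./bench_out")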


def make_fwd_benchmark_plot_for_case(
    *,
    case: DiffCase,
    configs: Sequence[tuple[int, int, int]],
    plot_name: str,
    ylabel: str = "Relative Speedup",
    line_vals=("naive", "cuda"),
    line_names: Dict[str, str] | None = None,
    dtype=torch.bfloat16,
    eps: float = 1e-6,
):
    """Build a forward-pass plot normalized to the naive baseline (always 1.00)."""
    timings_ms = collections.defaultdict(dict)
    speedup_ratios = []
    line_vals = list(line_vals)
    line_names = line_names or {v: v.title() for v in line_vals}
    x_vals = [make_fwd_key(*cfg) for cfg in configs]
    # Extra tick that summarizes all configs after they have been measured.
    x_vals.append("Geometric Mean")

    @triton.testing.perf_report(
        triton.testing.Benchmark(x_names=["config"],
                                 x_vals=x_vals,
                                 line_arg="provider",
                                 line_vals=line_vals,
                                 line_names=[line_names[v] for v in line_vals],
                                 ylabel=ylabel,
                                 plot_name=plot_name,
                                 args={}))
    def bench(config, provider):
        if config == "Geometric Mean":
            # This tick comes last in x_vals, so every per-config ratio is in.
            if provider == "cuda":
                return round(math.prod(speedup_ratios)**(1 / len(speedup_ratios)), 2)
            return 1.00
        batch_size, seq_len, dim = parse_config_string(config)
        inputs = case.build_inputs(batch_size, seq_len, dim, dtype, eps)
        obj = case.make_naive(inputs) if provider == "naive" else case.make_cuda(inputs)
        run = lambda: case.forward(obj, inputs)
        ms = triton.testing.do_bench(run)
        timings_ms[provider][config] = ms
        if provider == "cuda":
            # "naive" runs first (line_vals order), so its timing is available.
            ratio = timings_ms["naive"][config] / timings_ms["cuda"][config]
            speedup_ratios.append(ratio)
            return round(ratio, 2)
        return 1.00

    return bench


def make_bwd_benchmark_for_case(
    *,
    case: DiffCase,
    configs: Sequence[tuple[int, int, int]],
    plot_name: str,
    ylabel: str = "us",
    line_vals=("naive", "cuda", "speedup"),
    line_names: Dict[str, str] | None = None,
    dtype=torch.bfloat16,
    eps: float = 1e-6,
    time_unit_scale: float = 1000,
):
    """Build a backward-pass benchmark; only the autograd call is timed."""
    timings_ms = collections.defaultdict(dict)
    line_vals = list(line_vals)
    line_names = line_names or {v: v.title() for v in line_vals}
    # configs holds (batch_size, seq_len, dim) tuples, so x_names must list the
    # axes in that same order for triton to bind the values correctly.
    x_vals = [list(cfg) for cfg in configs]

    @triton.testing.perf_report(
        triton.testing.Benchmark(x_names=["batch_size", "seq_len", "dim"],
                                 x_vals=x_vals,
                                 line_arg="provider",
                                 line_vals=line_vals,
                                 line_names=[line_names[v] for v in line_vals],
                                 ylabel=ylabel,
                                 plot_name=plot_name,
                                 args={}))
    def bench(batch_size, seq_len, dim, provider):
        key = make_bwd_key(batch_size, seq_len, dim)
        if provider == "speedup":
            # "naive" and "cuda" precede "speedup" in line_vals, so both
            # timings for this key are already recorded.
            return timings_ms["naive"][key] / timings_ms["cuda"][key]
        inputs = case.build_inputs(batch_size, seq_len, dim, dtype, eps)
        obj = case.make_naive(inputs) if provider == "naive" else case.make_cuda(inputs)
        # One forward pass outside the timed region builds the autograd graph.
        y = case.forward(obj, inputs)
        gin = list(case.grad_inputs(inputs)) + list(obj.parameters())
        if isinstance(y, torch.Tensor):
            g = [torch.randn_like(y)]
        else:
            g = [torch.randn_like(r) for r in y]
        # retain_graph=True lets do_bench replay the backward pass repeatedly
        # over the same graph.
        run = lambda: torch.autograd.grad(y,
                                          gin,
                                          g,
                                          retain_graph=True,
                                          create_graph=False,
                                          allow_unused=False)
        ms = triton.testing.do_bench(run)
        timings_ms[provider][key] = ms
        return time_unit_scale * ms

    return bench
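

def _bwd_timing_sketch():
    """Hedged standalone sketch of the backward-timing pattern used above.

    The tensors and the toy op here are illustrative assumptions, not part of
    the framework; the function only demonstrates timing `torch.autograd.grad`
    with `triton.testing.do_bench` while reusing a single forward graph.
    """
    x = torch.randn(4, 128, 256, device="cuda", requires_grad=True)
    w = torch.randn(256, device="cuda", requires_grad=True)
    y = (x * w).sum(dim=-1)    # forward pass, built once
    g = [torch.randn_like(y)]  # fixed upstream gradient
    run = lambda: torch.autograd.grad(y, [x, w], g, retain_graph=True)
    return triton.testing.do_bench(run)  # latency in ms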


def make_bwd_benchmark_plot_for_case(
    *,
    case: DiffCase,
    configs: Sequence[tuple[int, int, int]],
    plot_name: str,
    ylabel: str = "Relative Speedup",
    line_vals=("naive", "cuda"),
    line_names: Dict[str, str] | None = None,
    dtype=torch.bfloat16,
    eps: float = 1e-6,
):
    """Build a backward-pass plot normalized to the naive baseline (always 1.00)."""
    timings_ms = collections.defaultdict(dict)
    speedup_ratios = []
    line_vals = list(line_vals)
    line_names = line_names or {v: v.title() for v in line_vals}
    x_vals = [make_bwd_key(*cfg) for cfg in configs]
    # Extra tick that summarizes all configs after they have been measured.
    x_vals.append("Geometric Mean")

    @triton.testing.perf_report(
        triton.testing.Benchmark(x_names=["config"],
                                 x_vals=x_vals,
                                 line_arg="provider",
                                 line_vals=line_vals,
                                 line_names=[line_names[v] for v in line_vals],
                                 ylabel=ylabel,
                                 plot_name=plot_name,
                                 args={}))
    def bench(config, provider):
        if config == "Geometric Mean":
            # This tick comes last in x_vals, so every per-config ratio is in.
            if provider == "cuda":
                return round(math.prod(speedup_ratios)**(1 / len(speedup_ratios)), 2)
            return 1.00
        batch_size, seq_len, dim = parse_config_string(config)
        inputs = case.build_inputs(batch_size, seq_len, dim, dtype, eps)
        obj = case.make_naive(inputs) if provider == "naive" else case.make_cuda(inputs)
        # One forward pass outside the timed region builds the autograd graph.
        y = case.forward(obj, inputs)
        gin = list(case.grad_inputs(inputs)) + list(obj.parameters())
        if isinstance(y, torch.Tensor):
            g = [torch.randn_like(y)]
        else:
            g = [torch.randn_like(r) for r in y]
        run = lambda: torch.autograd.grad(y,
                                          gin,
                                          g,
                                          retain_graph=True,
                                          create_graph=False,
                                          allow_unused=False)
        ms = triton.testing.do_bench(run)
        timings_ms[provider][config] = ms
        if provider == "cuda":
            # "naive" runs first (line_vals order), so its timing is available.
            ratio = timings_ms["naive"][config] / timings_ms["cuda"][config]
            speedup_ratios.append(ratio)
            return round(ratio, 2)
        return 1.00

    return bench
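

if __name__ == "__main__":
    # Hedged smoke-test sketch: `_ToyCase` is a duck-typed stand-in built only
    # from the methods this module calls on a DiffCase (build_inputs,
    # make_naive, make_cuda, forward, grad_inputs). It is an illustrative
    # assumption, not one of the repo's real cases, and "cuda" here reuses the
    # eager module, so the reported speedup is ~1.
    class _ToyCase:
        def build_inputs(self, batch_size, seq_len, dim, dtype, eps):
            x = torch.randn(batch_size, seq_len, dim, device="cuda",
                            dtype=dtype, requires_grad=True)
            return {"x": x, "dim": dim, "eps": eps, "dtype": dtype}

        def make_naive(self, inputs):
            return torch.nn.RMSNorm(inputs["dim"], eps=inputs["eps"],
                                    dtype=inputs["dtype"]).cuda()

        make_cuda = make_naive  # a real case would build the fused kernel here

        def forward(self, obj, inputs):
            return obj(inputs["x"])

        def grad_inputs(self, inputs):
            return [inputs["x"]]

    bench = make_fwd_benchmark_for_case(case=_ToyCase(),
                                        configs=[(1, 1024, 4096)],
                                        plot_name="toy_fwd")
    bench.run(print_data=True)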