activation/tests/conftest.py
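
"""Pytest configuration for the activation kernel tests.

Performance tests are opt-in: run them with `pytest --run-perf`, and pass
`--do-plot` as well to render the collected results to perf_result.html.
"""
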
import logging

import numpy as np
import plotly.graph_objects as go
import pytest

from .kernels.test_poly_norm_perf import PERF_RESULTS, PerfResult
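# Each PerfResult is expected to provide .type, .shape, .dtype, and .speedup
# (these are the fields consumed by plot() below).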
logger = logging.getLogger(__name__)

DO_PLOT = False


def plot(perf_results: list[PerfResult]):
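    """Render a grouped bar chart of kernel vs. torch speedups.

    Appends a geometric-mean summary bar and writes the chart to
    perf_result.html.
    """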
x_labels = [f"{r.type}, {r.shape}, {r.dtype}" for r in perf_results]
kernel_speedup = [r.speedup for r in perf_results]
    torch_speedup = [1.0 for _ in perf_results]
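    # Summarize overall performance with the geometric mean of the per-config speedups.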
geo_mean = float(np.exp(np.mean(np.log(kernel_speedup))))
x_labels.append("Geometric Mean")
kernel_speedup.append(geo_mean)
torch_speedup.append(1.0)
fig = go.Figure()
fig.add_trace(
go.Bar(
x=x_labels,
y=kernel_speedup,
name="Activation",
marker_color="rgb(100, 100, 100)",
text=[f"x{v:.2f}" for v in kernel_speedup],
textfont=dict(size=14),
textposition="outside",
)
)
fig.add_trace(
go.Bar(
x=x_labels,
y=torch_speedup,
name="Torch",
marker_color="rgb(30, 30, 30)",
text=[f"x{v:.2f}" for v in torch_speedup],
textfont=dict(size=14),
textposition="outside",
)
)
fig.update_layout(
title=dict(
text="<b>Speedup over torch (higher is better) (MI250, torch 2.7, ROCm 6.3)</b>",
font=dict(size=24),
),
legend=dict(
x=0.01,
y=0.99,
xanchor="left",
yanchor="top",
bgcolor="rgba(0,0,0,0)",
bordercolor="black",
borderwidth=1,
),
font=dict(size=16),
yaxis_title="Speedup (torch / activation)",
barmode="group",
bargroupgap=0,
bargap=0.2,
xaxis_tickangle=-45,
template="plotly_white",
yaxis_type="log",
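        # Outline the full plot area with a thin black border rectangle.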
shapes=[
dict(
type="rect",
xref="x",
yref="paper", # y축 전체 범위 (0~1)
x0=-0.5,
x1=len(x_labels) - 0.5,
y0=0,
y1=1,
line=dict(
color="black",
width=1.5,
),
                fillcolor="rgba(0,0,0,0)",  # transparent fill
                layer="above",  # draw the border on top of the bars
)
],
)
output_file = "perf_result.html"
fig.write_html(output_file)
logger.info(f"Plotting performance results to {output_file}")


def pytest_addoption(parser):
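    """Register the --run-perf and --do-plot command-line options."""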
parser.addoption(
"--run-perf", action="store_true", default=False, help="Run perf tests"
)
parser.addoption(
"--do-plot", action="store_true", default=False, help="Plot performance results"
)


@pytest.fixture
def do_plot(request):
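    """Expose the value of the --do-plot option to tests."""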
return request.config.getoption("--do-plot")


def pytest_configure(config):
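    """Validate the option combination and register the `perf` marker."""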
global DO_PLOT
DO_PLOT = config.getoption("--do-plot")
run_perf = config.getoption("--run-perf")
if DO_PLOT and not run_perf:
        raise pytest.UsageError(
"Cannot plot performance results without running performance tests. "
"Please use --run-perf option."
)
config.addinivalue_line("markers", "perf: mark test as performance-related")


def pytest_collection_modifyitems(config, items):
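    """Skip perf-marked tests unless --run-perf is given; with it, skip the normal tests instead."""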
run_perf = config.getoption("--run-perf")
skip_perf = pytest.mark.skip(reason="need --run-perf option to run")
skip_normal = pytest.mark.skip(
reason="normal tests skipped when --run-perf is used"
)
for item in items:
if "perf" in item.keywords and not run_perf:
item.add_marker(skip_perf)
elif "perf" not in item.keywords and run_perf:
item.add_marker(skip_normal)


def pytest_sessionfinish(session, exitstatus) -> None:
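    """After the test session, plot or log the collected performance results."""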
if DO_PLOT:
plot(PERF_RESULTS)
    else:
        logger.info("Performance results: %s", PERF_RESULTS)