Commit a32f88a
Parent(s): c1e53ae

Update test_matmul.py to support UT of XPU

Files changed:
- tests/conftest.py +13 -2
- tests/test_matmul.py +13 -12
- torch-ext/triton_kernels/swiglu.py +1 -1
tests/conftest.py CHANGED
@@ -1,5 +1,5 @@
 import pytest
-
+import triton
 
 def pytest_addoption(parser):
     parser.addoption("--device", action="store", default="cuda")
@@ -12,8 +12,19 @@ def device(request):
 
 @pytest.fixture
 def fresh_knobs(monkeypatch):
+    try:
+        _ver_str = getattr(triton, "__version__", "0.0.0").split("+")[0]
+        _parts = _ver_str.split(".")
+        _ver_tuple = tuple(int(p) for p in (_parts + ["0", "0", "0"])[:3])
+    except Exception:
+        _ver_tuple = (0, 0, 0)
+
     from triton._internal_testing import _fresh_knobs_impl
-    fresh_function, reset_function = _fresh_knobs_impl(monkeypatch)
+    if _ver_tuple > (3, 4, 0):
+        fresh_function, reset_function = _fresh_knobs_impl()
+    else:
+        fresh_function, reset_function = _fresh_knobs_impl(monkeypatch)
+
     try:
         yield fresh_function()
     finally:
tests/test_matmul.py CHANGED
@@ -20,7 +20,7 @@ from triton_kernels.numerics_details.mxfp import downcast_to_mxfp, upcast_from_m
 # testing utilities
 from triton_kernels.testing import assert_close, compute_actual_scale
 # target-specific utilities
-from triton_kernels.target_info import is_hip, is_hip_cdna3, is_cuda, is_hip_cdna4
+from triton_kernels.target_info import is_hip, is_xpu, is_hip_cdna3, is_cuda, is_hip_cdna4
 
 # ---------------
 # initialize data
@@ -70,7 +70,7 @@ def init_compute_data(m, n, k, gindx, sindx, n_expts_tot, n_expts_act, n_expt_sh
     if mode == 'batched' or (not has_y_gammas) or (has_y_gammas and (gindx is not None) and act_dtype.itemsize >= 2):
         gs0 = None
         gs1 = None
-    if "float8" in str(weight_dtype) and torch.cuda.get_device_capability()[0] < 10:
+    if is_cuda() and "float8" in str(weight_dtype) and torch.cuda.get_device_capability()[0] < 10:
         w = w.transpose(-1, -2).contiguous().transpose(-1, -2)
     return x, w, bias, gs0, gs1
 
@@ -291,14 +291,15 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas
     if hbm_swizzling:
         if is_hip():
             pytest.skip("NYI. HBM swizzling just implemented for CUDA.")
-        if torch.cuda.get_device_capability()[0] < 9:
-            pytest.skip("NYI. Ampere swizzling.")
-        if torch.cuda.get_device_capability()[0] < 10:
-            if "mxfloat4" not in weight_dtype_str:
-                pytest.skip("NYI. Hopper swizzling just implemented for mxfp4.")
-            if k % 64 != 0 or n % 64 != 0:
-                # Automatic padding not implemented for Hopper swizzle
-                pytest.skip("Hopper swizzling acts on a 64x64 tile (4x1 mma tiles).")
+        if is_cuda():
+            if torch.cuda.get_device_capability()[0] < 9:
+                pytest.skip("NYI. Ampere swizzling.")
+            if torch.cuda.get_device_capability()[0] < 10:
+                if "mxfloat4" not in weight_dtype_str:
+                    pytest.skip("NYI. Hopper swizzling just implemented for mxfp4.")
+                if k % 64 != 0 or n % 64 != 0:
+                    # Automatic padding not implemented for Hopper swizzle
+                    pytest.skip("Hopper swizzling acts on a 64x64 tile (4x1 mma tiles).")
 
     # launch metadata for batched / mx types may not work yet.
     test_launch_metadata = (mode == "ragged") and ("mx" not in weight_dtype_str)
@@ -306,7 +307,7 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas
     torch.manual_seed(0)
 
     block_k = None
-    if is_persistent and weight_dtype_str.startswith("mx") and torch.cuda.get_device_capability()[0] < 10:
+    if is_cuda() and is_persistent and weight_dtype_str.startswith("mx") and torch.cuda.get_device_capability()[0] < 10:
         # Override block_k for testing correctness. The default is temporarily 128 for
         # performance reasons which doesn't work with persistent matmul.
         # TODO: revisit when Triton is better for H100 + MXFP4
@@ -462,7 +463,7 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas
 
     round_y = lambda y: (y / y_scale).to(act_dtype).to(torch.float32) * y_scale if sep_scatter else y
     ref_y = matmul_ogs_torch(x_ref, w_ref, bias_ref,  #
-                             rdata, gindx, sindx, round_x=round_x, round_y=round_y, gammas=gs1_ref)
+                             rdata, gindx, sindx, round_x=round_x, round_y=round_y, gammas=gs1_ref, device=device)
     scale = lambda val, scal: val if scal is None else val / scal
     if n_expt_shards > 1:
         if do_scatter:
torch-ext/triton_kernels/swiglu.py CHANGED
@@ -35,7 +35,7 @@ class SwiGLU(torch.autograd.Function):
         # optimization hyperparameters
         BLOCK_M, BLOCK_N = 32 // a.itemsize, 128
         num_warps = 4
-        kwargs = {'maxnreg': 64} if not target_info.is_hip() else {}
+        kwargs = {'maxnreg': 64} if not target_info.is_hip() and not target_info.is_xpu() else {}
         # launch semi-persistent kernel
         N_BLOCKS = triton.cdiv(N // 2, BLOCK_N)
         num_sms = target_info.num_sms()