YangKai0616 committed
Commit a32f88a · 1 Parent(s): c1e53ae

Update test_matmul.py to support UT of XPU

tests/conftest.py CHANGED
@@ -1,5 +1,5 @@
 import pytest
-
+import triton
 
 def pytest_addoption(parser):
     parser.addoption("--device", action="store", default="cuda")
@@ -12,8 +12,19 @@ def device(request):
 
 @pytest.fixture
 def fresh_knobs(monkeypatch):
+    try:
+        _ver_str = getattr(triton, "__version__", "0.0.0").split("+")[0]
+        _parts = _ver_str.split(".")
+        _ver_tuple = tuple(int(p) for p in (_parts + ["0", "0", "0"])[:3])
+    except Exception:
+        _ver_tuple = (0, 0, 0)
+
     from triton._internal_testing import _fresh_knobs_impl
-    fresh_function, reset_function = _fresh_knobs_impl(monkeypatch)
+    if _ver_tuple > (3, 4, 0):
+        fresh_function, reset_function = _fresh_knobs_impl()
+    else:
+        fresh_function, reset_function = _fresh_knobs_impl(monkeypatch)
+
     try:
         yield fresh_function()
     finally:
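The conftest.py change works around a signature change in Triton's test helper: the patch assumes that Triton newer than 3.4.0 calls `_fresh_knobs_impl()` without the `monkeypatch` argument, so the fixture parses `triton.__version__` into a comparable tuple and picks the matching call. A standalone sketch of just the parsing step, where `parse_triton_version` is a hypothetical name and not part of the patch:

# Minimal sketch of the version parsing used by the fixture above.
# `parse_triton_version` is a hypothetical helper, not part of the patch.
def parse_triton_version(ver_str: str) -> tuple:
    base = ver_str.split("+")[0]   # drop local suffixes, e.g. "3.4.0+gitabc" -> "3.4.0"
    parts = base.split(".")
    return tuple(int(p) for p in (parts + ["0", "0", "0"])[:3])

assert parse_triton_version("3.4.0+gitabc") == (3, 4, 0)
assert parse_triton_version("3.5") == (3, 5, 0)
# Anything newer than 3.4.0 takes the no-argument _fresh_knobs_impl() path.
assert parse_triton_version("3.5") > (3, 4, 0)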
tests/test_matmul.py CHANGED
@@ -20,7 +20,7 @@ from triton_kernels.numerics_details.mxfp import downcast_to_mxfp, upcast_from_m
 # testing utilities
 from triton_kernels.testing import assert_close, compute_actual_scale
 # target-specific utilities
-from triton_kernels.target_info import is_hip, is_hip_cdna3, is_cuda, is_hip_cdna4
+from triton_kernels.target_info import is_hip, is_xpu, is_hip_cdna3, is_cuda, is_hip_cdna4
 
 # ---------------
 # initialize data
@@ -70,7 +70,7 @@ def init_compute_data(m, n, k, gindx, sindx, n_expts_tot, n_expts_act, n_expt_sh
     if mode == 'batched' or (not has_y_gammas) or (has_y_gammas and (gindx is not None) and act_dtype.itemsize >= 2):
         gs0 = None
         gs1 = None
-    if "float8" in str(weight_dtype) and torch.cuda.get_device_capability()[0] < 10:
+    if is_cuda() and "float8" in str(weight_dtype) and torch.cuda.get_device_capability()[0] < 10:
         w = w.transpose(-1, -2).contiguous().transpose(-1, -2)
     return x, w, bias, gs0, gs1
 
@@ -291,14 +291,15 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas
     if hbm_swizzling:
         if is_hip():
             pytest.skip("NYI. HBM swizzling just implemented for CUDA.")
-        if torch.cuda.get_device_capability()[0] < 9:
-            pytest.skip("NYI. Ampere swizzling.")
-        if torch.cuda.get_device_capability()[0] < 10:
-            if "mxfloat4" not in weight_dtype_str:
-                pytest.skip("NYI. Hopper swizzling just implemented for mxfp4.")
-            if k % 64 != 0 or n % 64 != 0:
-                # Automatic padding not implemented for Hopper swizzle
-                pytest.skip("Hopper swizzling acts on a 64x64 tile (4x1 mma tiles).")
+        if is_cuda():
+            if torch.cuda.get_device_capability()[0] < 9:
+                pytest.skip("NYI. Ampere swizzling.")
+            if torch.cuda.get_device_capability()[0] < 10:
+                if "mxfloat4" not in weight_dtype_str:
+                    pytest.skip("NYI. Hopper swizzling just implemented for mxfp4.")
+                if k % 64 != 0 or n % 64 != 0:
+                    # Automatic padding not implemented for Hopper swizzle
+                    pytest.skip("Hopper swizzling acts on a 64x64 tile (4x1 mma tiles).")
 
     # launch metadata for batched / mx types may not work yet.
     test_launch_metadata = (mode == "ragged") and ("mx" not in weight_dtype_str)
@@ -306,7 +307,7 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas
     torch.manual_seed(0)
 
     block_k = None
-    if is_persistent and weight_dtype_str.startswith("mx") and torch.cuda.get_device_capability()[0] < 10:
+    if is_cuda() and is_persistent and weight_dtype_str.startswith("mx") and torch.cuda.get_device_capability()[0] < 10:
        # Override block_k for testing correctness. The default is temporarily 128 for
        # performance reasons which doesn't work with persistent matmul.
        # TODO: revisit when Triton is better for H100 + MXFP4
@@ -462,7 +463,7 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas
 
     round_y = lambda y: (y / y_scale).to(act_dtype).to(torch.float32) * y_scale if sep_scatter else y
     ref_y = matmul_ogs_torch(x_ref, w_ref, bias_ref, #
-                             rdata, gindx, sindx, round_x=round_x, round_y=round_y, gammas=gs1_ref)
+                             rdata, gindx, sindx, round_x=round_x, round_y=round_y, gammas=gs1_ref, device=device)
     scale = lambda val, scal: val if scal is None else val / scal
     if n_expt_shards > 1:
         if do_scatter:
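The common thread in these hunks is that `torch.cuda.get_device_capability()` is only queried after an `is_cuda()` check, since on an XPU-only machine there is no CUDA device to ask, and the reference path now forwards `device=device` to `matmul_ogs_torch` so the torch reference runs on the same device as the kernel under test. A minimal sketch of the guard pattern; the `is_cuda()` stand-in below is an assumption, not the repository's `triton_kernels.target_info` implementation, and `pre_blackwell_cuda` is a hypothetical name:

import torch

def is_cuda() -> bool:
    # Assumption: stand-in for triton_kernels.target_info.is_cuda().
    return torch.cuda.is_available()

def pre_blackwell_cuda() -> bool:
    # The capability query is CUDA-only, so is_cuda() must short-circuit first;
    # on XPU or HIP this branch is simply skipped.
    return is_cuda() and torch.cuda.get_device_capability()[0] < 10

With this shape, the same test body can run when pytest is pointed at another backend via the `--device` option that conftest.py already defines, and CUDA-specific skips or block-size overrides never fire on XPU.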
torch-ext/triton_kernels/swiglu.py CHANGED
@@ -35,7 +35,7 @@ class SwiGLU(torch.autograd.Function):
         # optimization hyperparameters
         BLOCK_M, BLOCK_N = 32 // a.itemsize, 128
         num_warps = 4
-        kwargs = {'maxnreg': 64} if not target_info.is_hip() else {}
+        kwargs = {'maxnreg': 64} if not target_info.is_hip() and not target_info.is_xpu() else {}
         # launch semi-persistent kernel
         N_BLOCKS = triton.cdiv(N // 2, BLOCK_N)
         num_sms = target_info.num_sms()
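Finally, the swiglu.py change drops the `maxnreg` launch hint on XPU as well as on HIP; `maxnreg` roughly corresponds to CUDA's per-thread register cap, and the diff treats it as unsupported on the other backends. A small sketch of that gating as a pure function, with the backend flags passed in explicitly instead of read from `target_info` (names are illustrative):

def launch_kwargs(on_hip: bool, on_xpu: bool) -> dict:
    # Pass the register cap only on backends that accept it (CUDA).
    return {'maxnreg': 64} if not on_hip and not on_xpu else {}

assert launch_kwargs(on_hip=False, on_xpu=False) == {'maxnreg': 64}  # CUDA keeps the cap
assert launch_kwargs(on_hip=True, on_xpu=False) == {}                # HIP: no maxnreg
assert launch_kwargs(on_hip=False, on_xpu=True) == {}                # XPU: no maxnreg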