#!/usr/bin/env python3
"""Step-by-step debugging of the grouped GEMM computation."""

import pathlib
import sys
from typing import Optional

import torch


def detect_variant(root: pathlib.Path) -> str:
    build_dir = root / "build"
    variant: Optional[str] = None
    if (root / "kernels" / "utils.py").exists():
        try:
            sys.path.insert(0, str(root))
            from kernels.utils import build_variant as _build_variant  # type: ignore
            variant = _build_variant()
        except Exception:
            variant = None
        finally:
            sys.path.pop(0)
    if variant is None:
        # glob() yields a generator, which is always truthy even when empty,
        # so materialize each pattern before applying the ROCm -> CUDA fallback.
        candidates = sorted(build_dir.glob("torch*-rocm64-*")) or sorted(build_dir.glob("torch*-cu*"))
        if candidates:
            variant = candidates[0].name
    if variant is None:
        raise SystemExit("Could not determine build variant; run build.py first.")
    return variant


def manual_gmm_computation(a, b, batch_sizes, trans_b=False):
    """Manual step-by-step computation like the C++ code does."""
    print("=== Manual GMM computation ===")

    # Move the per-expert token counts to the CPU before reading them.
    batch_sizes_cpu = batch_sizes.cpu()
    counts_ptr = batch_sizes_cpu.numpy()
    num_experts = len(counts_ptr)

    # Calculate prefix sums like the C++ code does.
    prefix = []
    running = 0
    for i in range(num_experts):
        running += counts_ptr[i]
        prefix.append(running)
    tokens = prefix[-1] if prefix else 0

    print(f"num_experts: {num_experts}, tokens: {tokens}")
    print(f"a.shape: {a.shape}, b.shape: {b.shape}")
    print(f"batch_sizes: {counts_ptr}")

    # Create the output tensor.
    if not trans_b:  # default case
        hidden_in = a.size(1)   # contraction dim (k), 128 here
        hidden_out = b.size(2)  # output dim (n), 128 here
        out = torch.empty((tokens, hidden_out), dtype=a.dtype, device=a.device)
        print(f"Output shape: {out.shape} (tokens={tokens}, hidden_out={hidden_out})")

        b_contig = b.contiguous()

        start = 0
        for expert in range(num_experts):
            end = prefix[expert]
            rows = end - start
            print(f"\nExpert {expert}: start={start}, end={end}, rows={rows}")
            if rows == 0:
                start = end
                continue

            # Get slices like the C++ code does.
            a_slice = a.narrow(0, start, rows)      # [rows, hidden_in]
            b_slice = b_contig.select(0, expert)    # [hidden_in, hidden_out]
            out_slice = out.narrow(0, start, rows)  # [rows, hidden_out]

            print(f"  a_slice.shape: {a_slice.shape}")
            print(f"  b_slice.shape: {b_slice.shape}")
            print(f"  a_slice range: [{a_slice.min().item():.8f}, {a_slice.max().item():.8f}]")
            print(f"  b_slice range: [{b_slice.min().item():.8f}, {b_slice.max().item():.8f}]")

            # Upcast to FP32 like the C++ code does.
            a_f32 = a_slice.to(torch.float32)
            b_f32 = b_slice.to(torch.float32)

            # Do the matmul.
            prod = torch.matmul(a_f32, b_f32)
            print(f"  prod.shape: {prod.shape}")
            print(f"  prod range: [{prod.min().item():.8f}, {prod.max().item():.8f}]")

            # Downcast and copy into the output slice.
            prod_bf16 = prod.to(a.dtype)
            out_slice.copy_(prod_bf16)

            start = end

        return out
    else:
        raise NotImplementedError("trans_b case not implemented")


def main() -> None:
    repo_root = pathlib.Path(__file__).resolve().parent.parent  # go up from _dev/ to the repo root
    variant = detect_variant(repo_root)
    staged_dir = repo_root / "build" / variant
    if str(staged_dir) not in sys.path:
        sys.path.insert(0, str(staged_dir))
    if str(repo_root) not in sys.path:
        sys.path.insert(0, str(repo_root))

    import megablocks  # type: ignore
    from tests.test_gg import gmm, randn  # type: ignore

    print(f"Using staged variant: {variant}")

    torch.manual_seed(0)
    z = m = n = k = 128
    trans_b = False

    a = randn(z, m, k).view(-1, k)
    b = randn(z, k, n) if not trans_b else randn(z, n, k)
    batch_sizes = torch.tensor([m] * z, device="cpu")

    print("=== Input Information ===")
    print(f"a.shape: {a.shape}, dtype: {a.dtype}")
    print(f"b.shape: {b.shape}, dtype: {b.dtype}")
    print(f"batch_sizes: {batch_sizes}")
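    # Because the batch sizes here are uniform ([m] * z), the non-transposed
    # grouped GEMM reduces to a plain batched matmul. A minimal cross-check
    # sketch; this dense shortcut is an illustrative addition, not part of
    # the original debugging flow, and only holds for uniform batch sizes:
    dense = torch.bmm(a.view(z, m, k).float(), b.float()).to(a.dtype).view(-1, n)
    print(f"Dense bmm cross-check range: [{dense.min().item():.8f}, {dense.max().item():.8f}]")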
print(f"Input a range: [{a.min().item():.8f}, {a.max().item():.8f}]") print(f"Input b range: [{b.min().item():.8f}, {b.max().item():.8f}]") # Manual computation manual_out = manual_gmm_computation(a.clone(), b.clone(), batch_sizes, trans_b) print(f"\nManual output range: [{manual_out.min().item():.8f}, {manual_out.max().item():.8f}]") # Reference computation a_ref = a.detach().clone() b_ref = b.detach().clone() ref = gmm(a_ref, b_ref, batch_sizes.cpu(), trans_b) print(f"Reference output range: [{ref.min().item():.8f}, {ref.max().item():.8f}]") # Megablocks computation out = megablocks.gg_ops.gmm(a.clone(), b.clone(), batch_sizes, trans_b) print(f"Megablocks output range: [{out.min().item():.8f}, {out.max().item():.8f}]") # Compare manual_vs_ref = (manual_out - ref).abs().max().item() megablocks_vs_ref = (out - ref).abs().max().item() print(f"\nManual vs Reference max diff: {manual_vs_ref:.8e}") print(f"Megablocks vs Reference max diff: {megablocks_vs_ref:.8e}") if __name__ == "__main__": main()