#!/usr/bin/env python3
"""Step-by-step debugging of the grouped GEMM computation."""
import pathlib
import sys
from typing import Optional
import torch
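
# Assumed context (inferred from the paths used below): this script lives in the
# repo's _dev/ directory and is run after the extension has been staged under
# build/<variant> by build.py, e.g.
#   python _dev/<this_script>.py
# It compares three ways of running the same grouped GEMM: the manual per-expert
# loop below, the reference gmm() from tests/test_gg.py, and megablocks.gg_ops.gmm.
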
def detect_variant(root: pathlib.Path) -> str:
    """Determine the staged build variant, preferring the repo's own helper."""
    build_dir = root / "build"
    variant: Optional[str] = None
    if (root / "kernels" / "utils.py").exists():
        try:
            sys.path.insert(0, str(root))
            from kernels.utils import build_variant as _build_variant  # type: ignore
            variant = _build_variant()
        except Exception:
            variant = None
        finally:
            sys.path.pop(0)
    if variant is None:
        # Fall back to scanning the build directory. glob() yields lazily, so
        # sort each pattern into a list before deciding whether to fall back
        # from the ROCm-style name to the CUDA-style one.
        candidates = sorted(build_dir.glob("torch*-rocm64-*")) or sorted(build_dir.glob("torch*-cu*"))
        if candidates:
            variant = candidates[0].name
    if variant is None:
        raise SystemExit("Could not determine build variant; run build.py first.")
    return variant

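# Note (an assumption based on the glob patterns above): detect_variant() expects
# the staged layout produced by build.py, i.e. a directory such as
# build/<torch-and-backend-specific-name>/ containing an importable `megablocks`
# package; main() below prepends that directory to sys.path so `import megablocks`
# resolves to the freshly staged build rather than any installed copy.
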
def manual_gmm_computation(a, b, batch_sizes, trans_b=False):
    """Manual step-by-step computation like the C++ code does."""
    print("=== Manual GMM computation ===")
    # Convert to CPU for batch sizes
    batch_sizes_cpu = batch_sizes.cpu()
    counts_ptr = batch_sizes_cpu.numpy()
    num_experts = len(counts_ptr)
    # Calculate prefix sums like the C++ code
    prefix = []
    running = 0
    for i in range(num_experts):
        running += counts_ptr[i]
        prefix.append(running)
    tokens = prefix[-1] if prefix else 0
    print(f"num_experts: {num_experts}, tokens: {tokens}")
    print(f"a.shape: {a.shape}, b.shape: {b.shape}")
    print(f"batch_sizes: {counts_ptr}")
    # Create output tensor
    if not trans_b:  # default case
        hidden_out = a.size(1)  # 128
        hidden_in = b.size(2)  # 128
        out = torch.empty((tokens, hidden_in), dtype=a.dtype, device=a.device)
        print(f"Output shape: {out.shape} (tokens={tokens}, hidden_in={hidden_in})")
        b_contig = b.contiguous()
        start = 0
        for expert in range(num_experts):
            end = prefix[expert]
            rows = end - start
            print(f"\nExpert {expert}: start={start}, end={end}, rows={rows}")
            if rows == 0:
                start = end
                continue
            # Get slices like C++ code
            a_slice = a.narrow(0, start, rows)      # [rows, hidden_out]
            b_slice = b_contig.select(0, expert)    # [hidden_out, hidden_in]
            out_slice = out.narrow(0, start, rows)  # [rows, hidden_in]
            print(f" a_slice.shape: {a_slice.shape}")
            print(f" b_slice.shape: {b_slice.shape}")
            print(f" a_slice range: [{a_slice.min().item():.8f}, {a_slice.max().item():.8f}]")
            print(f" b_slice range: [{b_slice.min().item():.8f}, {b_slice.max().item():.8f}]")
            # Convert to FP32 like C++ code
            a_f32 = a_slice.to(torch.float32)
            b_f32 = b_slice.to(torch.float32)
            # Do the matmul
            prod = torch.matmul(a_f32, b_f32)
            print(f" prod.shape: {prod.shape}")
            print(f" prod range: [{prod.min().item():.8f}, {prod.max().item():.8f}]")
            # Convert back and copy
            prod_bf16 = prod.to(a.dtype)
            out_slice.copy_(prod_bf16)
            start = end
        return out
    else:
        raise NotImplementedError("trans_b case not implemented")

def main() -> None:
    repo_root = pathlib.Path(__file__).resolve().parent.parent  # Go up from _dev/ to repo root
    variant = detect_variant(repo_root)
    staged_dir = repo_root / "build" / variant
    if str(staged_dir) not in sys.path:
        sys.path.insert(0, str(staged_dir))
    if str(repo_root) not in sys.path:
        sys.path.insert(0, str(repo_root))
    import megablocks  # type: ignore
    from tests.test_gg import gmm, randn  # type: ignore
    print(f"Using staged variant: {variant}")
    torch.manual_seed(0)
    z = m = n = k = 128
    trans_b = False
    a = randn(z, m, k).view(-1, k)
    b = randn(z, k, n) if not trans_b else randn(z, n, k)
    batch_sizes = torch.tensor([m] * z, device="cpu")
    print("=== Input Information ===")
    print(f"a.shape: {a.shape}, dtype: {a.dtype}")
    print(f"b.shape: {b.shape}, dtype: {b.dtype}")
    print(f"batch_sizes: {batch_sizes}")
    print(f"Input a range: [{a.min().item():.8f}, {a.max().item():.8f}]")
    print(f"Input b range: [{b.min().item():.8f}, {b.max().item():.8f}]")
    # Manual computation
    manual_out = manual_gmm_computation(a.clone(), b.clone(), batch_sizes, trans_b)
    print(f"\nManual output range: [{manual_out.min().item():.8f}, {manual_out.max().item():.8f}]")
    # Reference computation
    a_ref = a.detach().clone()
    b_ref = b.detach().clone()
    ref = gmm(a_ref, b_ref, batch_sizes.cpu(), trans_b)
    print(f"Reference output range: [{ref.min().item():.8f}, {ref.max().item():.8f}]")
    # Megablocks computation
    out = megablocks.gg_ops.gmm(a.clone(), b.clone(), batch_sizes, trans_b)
    print(f"Megablocks output range: [{out.min().item():.8f}, {out.max().item():.8f}]")
    # Compare
    manual_vs_ref = (manual_out - ref).abs().max().item()
    megablocks_vs_ref = (out - ref).abs().max().item()
    print(f"\nManual vs Reference max diff: {manual_vs_ref:.8e}")
    print(f"Megablocks vs Reference max diff: {megablocks_vs_ref:.8e}")


if __name__ == "__main__":
    main()