File size: 5,384 Bytes
104fd3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
#!/usr/bin/env python3
"""Step-by-step debugging of the grouped GEMM computation."""

import pathlib
import sys
from typing import Optional

import torch


def detect_variant(root: pathlib.Path) -> str:
    """Return the name of the staged build variant under ``root / "build"``.

    Resolution order:
      1. ``kernels.utils.build_variant()`` from the repo, when
         ``root/kernels/utils.py`` exists and the call succeeds.
      2. The first (sorted) entry matching ``build/torch*-rocm64-*``.
      3. The first (sorted) entry matching ``build/torch*-cu*``.

    Raises:
        SystemExit: if none of the above yields a variant.
    """
    build_dir = root / "build"
    variant: Optional[str] = None

    if (root / "kernels" / "utils.py").exists():
        root_entry = str(root)
        sys.path.insert(0, root_entry)
        try:
            from kernels.utils import build_variant as _build_variant  # type: ignore

            variant = _build_variant()
        except Exception:
            # Best effort: any failure here falls through to the glob lookup below.
            variant = None
        finally:
            # Remove our entry by value; the imported module may have inserted
            # its own path at index 0, so pop(0) could drop the wrong entry.
            try:
                sys.path.remove(root_entry)
            except ValueError:
                pass

    if variant is None:
        # Path.glob returns a lazy iterator that is always truthy, so the
        # fallback has to test the materialized (sorted) lists instead.
        candidates = sorted(build_dir.glob("torch*-rocm64-*")) or sorted(
            build_dir.glob("torch*-cu*")
        )
        if candidates:
            variant = candidates[0].name

    if variant is None:
        raise SystemExit("Could not determine build variant; run build.py first.")

    return variant


def manual_gmm_computation(a, b, batch_sizes, trans_b=False):
    """Compute a grouped GEMM one expert at a time, printing intermediates.

    Rows of ``a`` are split into consecutive groups whose sizes come from
    ``batch_sizes``. Group ``i`` is multiplied by ``b[i]`` (or ``b[i].T`` when
    ``trans_b``) in float32, and the product is cast back to ``a.dtype``.

    Args:
        a: 2-D tensor whose rows are grouped by ``batch_sizes``; assumes
            ``a.size(0) >= sum(batch_sizes)`` (not checked here).
        b: 3-D tensor with one matrix per expert along dim 0. Without
            ``trans_b`` each ``b[i]`` is used as-is; with ``trans_b`` it is
            transposed before the multiply.
        batch_sizes: 1-D integer tensor, one row count per expert.
        trans_b: Multiply by the transpose of each expert matrix.

    Returns:
        Tensor of shape ``[sum(batch_sizes), b.size(2)]`` (or
        ``b.size(1)`` when ``trans_b``) with ``a``'s dtype and device.
    """
    from itertools import accumulate

    print("=== Manual GMM computation ===")

    # Plain Python ints (not numpy scalars) for use as sizes and offsets.
    counts = [int(c) for c in batch_sizes.cpu().tolist()]
    num_experts = len(counts)

    # Inclusive prefix sums: prefix[i] is the (exclusive) end row of expert i.
    prefix = list(accumulate(counts))
    tokens = prefix[-1] if prefix else 0
    print(f"num_experts: {num_experts}, tokens: {tokens}")
    print(f"a.shape: {a.shape}, b.shape: {b.shape}")
    print(f"batch_sizes: {counts}")

    # With trans_b the per-expert matrix is used transposed, so its output
    # width is dim 1 instead of dim 2.
    hidden_in = b.size(1) if trans_b else b.size(2)
    out = torch.empty((tokens, hidden_in), dtype=a.dtype, device=a.device)
    print(f"Output shape: {out.shape} (tokens={tokens}, hidden_in={hidden_in})")

    b_contig = b.contiguous()

    start = 0
    for expert, end in enumerate(prefix):
        rows = end - start
        print(f"\nExpert {expert}: start={start}, end={end}, rows={rows}")

        if rows == 0:
            start = end
            continue

        a_slice = a.narrow(0, start, rows)
        b_slice = b_contig.select(0, expert)
        if trans_b:
            b_slice = b_slice.t()
        out_slice = out.narrow(0, start, rows)

        print(f"  a_slice.shape: {a_slice.shape}")
        print(f"  b_slice.shape: {b_slice.shape}")
        print(f"  a_slice range: [{a_slice.min().item():.8f}, {a_slice.max().item():.8f}]")
        print(f"  b_slice range: [{b_slice.min().item():.8f}, {b_slice.max().item():.8f}]")

        # Accumulate in float32, then cast back to the input dtype.
        prod = torch.matmul(a_slice.to(torch.float32), b_slice.to(torch.float32))
        print(f"  prod.shape: {prod.shape}")
        print(f"  prod range: [{prod.min().item():.8f}, {prod.max().item():.8f}]")

        out_slice.copy_(prod.to(a.dtype))
        start = end

    return out


def main() -> None:
    """Compare manual, reference, and megablocks grouped GEMM outputs.

    Resolves the staged build variant, makes it (and the repo root)
    importable, builds a fixed-seed problem, runs all three implementations,
    and prints their output ranges plus max absolute differences against the
    reference.
    """
    # This script sits one directory below the repo root (e.g. _dev/).
    repo_root = pathlib.Path(__file__).resolve().parent.parent
    variant = detect_variant(repo_root)
    staged_dir = repo_root / "build" / variant

    # staged_dir is inserted first and repo_root second, which leaves repo_root
    # at the front of sys.path.
    for entry in (str(staged_dir), str(repo_root)):
        if entry not in sys.path:
            sys.path.insert(0, entry)

    import megablocks  # type: ignore
    from tests.test_gg import gmm, randn  # type: ignore

    def span(t) -> str:
        # "[min, max]" with 8 decimals, shared by every range printout below.
        return f"[{t.min().item():.8f}, {t.max().item():.8f}]"

    print(f"Using staged variant: {variant}")

    torch.manual_seed(0)

    z = m = n = k = 128
    trans_b = False

    a = randn(z, m, k).view(-1, k)
    b_dims = (z, n, k) if trans_b else (z, k, n)
    b = randn(*b_dims)
    batch_sizes = torch.tensor([m] * z, device="cpu")

    print("=== Input Information ===")
    print(f"a.shape: {a.shape}, dtype: {a.dtype}")
    print(f"b.shape: {b.shape}, dtype: {b.dtype}")
    print(f"batch_sizes: {batch_sizes}")
    print(f"Input a range: {span(a)}")
    print(f"Input b range: {span(b)}")

    # 1) Python re-implementation of the per-expert loop.
    manual_out = manual_gmm_computation(a.clone(), b.clone(), batch_sizes, trans_b)
    print(f"\nManual output range: {span(manual_out)}")

    # 2) Reference implementation from the test suite, on detached copies.
    ref = gmm(a.detach().clone(), b.detach().clone(), batch_sizes.cpu(), trans_b)
    print(f"Reference output range: {span(ref)}")

    # 3) The compiled megablocks op under test.
    out = megablocks.gg_ops.gmm(a.clone(), b.clone(), batch_sizes, trans_b)
    print(f"Megablocks output range: {span(out)}")

    manual_err = (manual_out - ref).abs().max().item()
    kernel_err = (out - ref).abs().max().item()
    print(f"\nManual vs Reference max diff: {manual_err:.8e}")
    print(f"Megablocks vs Reference max diff: {kernel_err:.8e}")


if __name__ == "__main__":
    # Script entry point; main() returns None, so sys.exit reports status 0 on success.
    sys.exit(main())