leonardlin committed
Commit 2d8a802 · Parent: 1e407f0

Add ROCm build debugging utilities

.gitignore CHANGED
@@ -5,3 +5,5 @@ megablocks-moe/.bak
  .pytest_cache
  .readme_example.py.swp
  .torch_extensions/
+ .torch_extensions_debug/
+ strace.log
_dev/debug-build-1-env.sh ADDED
@@ -0,0 +1,79 @@
+ #!/usr/bin/env bash
+
+ # Debug script 1: Basic ROCm environment and tool availability check
+
+ set -euo pipefail
+
+ echo "=== ROCm Environment Debug Script 1 ==="
+ echo "Testing basic ROCm/HIP environment setup"
+ echo
+
+ # Set ROCm environment variables
+ export ROCM_PATH="${ROCM_PATH:-/opt/rocm-7.0.1}"
+ export ROCM_HOME="${ROCM_HOME:-$ROCM_PATH}"
+ export HIP_PATH="${HIP_PATH:-$ROCM_PATH}"
+ export HIP_HOME="${HIP_HOME:-$ROCM_PATH}"
+ export PATH="$ROCM_HOME/bin:$PATH"
+ export TORCH_HIP_ARCH_LIST="${TORCH_HIP_ARCH_LIST:-gfx942}"
+ export HSA_OVERRIDE_GFX_VERSION="${HSA_OVERRIDE_GFX_VERSION:-gfx942}"
+
+ echo "Environment Variables:"
+ echo "ROCM_PATH=$ROCM_PATH"
+ echo "ROCM_HOME=$ROCM_HOME"
+ echo "HIP_PATH=$HIP_PATH"
+ echo "HIP_HOME=$HIP_HOME"
+ echo "TORCH_HIP_ARCH_LIST=$TORCH_HIP_ARCH_LIST"
+ echo "HSA_OVERRIDE_GFX_VERSION=$HSA_OVERRIDE_GFX_VERSION"
+ echo "PATH (ROCm portion): $(echo $PATH | tr ':' '\n' | grep rocm || echo 'No ROCm in PATH')"
+ echo
+
+ echo "=== Directory Checks ==="
+ echo "ROCm installation directory exists: $(test -d "$ROCM_PATH" && echo 'YES' || echo 'NO')"
+ echo "ROCm bin directory exists: $(test -d "$ROCM_PATH/bin" && echo 'YES' || echo 'NO')"
+ echo "ROCm include directory exists: $(test -d "$ROCM_PATH/include" && echo 'YES' || echo 'NO')"
+ echo "ROCm lib directory exists: $(test -d "$ROCM_PATH/lib" && echo 'YES' || echo 'NO')"
+ echo
+
+ echo "=== Tool Availability ==="
+ echo "hipcc available: $(which hipcc >/dev/null 2>&1 && echo 'YES' || echo 'NO')"
+ echo "hip-clang available: $(which hip-clang >/dev/null 2>&1 && echo 'YES' || echo 'NO')"
+ echo "rocm-smi available: $(which rocm-smi >/dev/null 2>&1 && echo 'YES' || echo 'NO')"
+ echo "hipconfig available: $(which hipconfig >/dev/null 2>&1 && echo 'YES' || echo 'NO')"
+ echo
+
+ echo "=== Tool Versions ==="
+ if which hipcc >/dev/null 2>&1; then
+     echo "hipcc version:"
+     hipcc --version || echo "Failed to get hipcc version"
+     echo
+ fi
+
+ if which hipconfig >/dev/null 2>&1; then
+     echo "HIP config:"
+     hipconfig --full || echo "Failed to get hipconfig"
+     echo
+ fi
+
+ if which rocm-smi >/dev/null 2>&1; then
+     echo "ROCm SMI:"
+     rocm-smi --showproductname || echo "Failed to get ROCm SMI info"
+     echo
+ fi
+
+ echo "=== Python Environment ==="
+ python3 --version || echo "Python3 not available"
+ python3 -c "import torch; print(f'PyTorch version: {torch.__version__}')" || echo "PyTorch not available"
+ python3 -c "import torch; print(f'CUDA available: {torch.cuda.is_available()}')" || echo "Failed to check CUDA availability"
+ python3 -c "import torch; print(f'HIP available: {hasattr(torch.version, \"hip\") and torch.version.hip is not None}')" || echo "Failed to check HIP availability"
+
+ echo
+ echo "=== Basic HIP Device Check ==="
+ if which hipinfo >/dev/null 2>&1; then
+     echo "HIP devices:"
+     hipinfo || echo "hipinfo failed"
+ else
+     echo "hipinfo not available"
+ fi
+
+ echo
+ echo "=== Debug Script 1 Complete ==="
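A quick cross-check that pairs with the probes above: asking the ROCm build of PyTorch which GPU architecture it actually sees. This is a hedged sketch rather than part of the script; `gcnArchName` is only exposed on ROCm builds of recent PyTorch (the `getattr` fallback keeps the one-liner harmless elsewhere), and it assumes at least one visible GPU:

    python3 -c "import torch; p = torch.cuda.get_device_properties(0); print('hip:', torch.version.hip, 'arch:', getattr(p, 'gcnArchName', 'n/a'))"

If the reported architecture is not gfx942, the TORCH_HIP_ARCH_LIST exported above is targeting the wrong hardware.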
_dev/debug-build-2-hipcc.sh ADDED
@@ -0,0 +1,154 @@
+ #!/usr/bin/env bash
+
+ # Debug script 2: HIP compiler basic compilation test
+
+ set -euo pipefail
+
+ echo "=== HIP Compiler Debug Script 2 ==="
+ echo "Testing basic HIP compilation capabilities"
+ echo
+
+ # Set ROCm environment variables
+ export ROCM_PATH="${ROCM_PATH:-/opt/rocm-7.0.1}"
+ export ROCM_HOME="${ROCM_HOME:-$ROCM_PATH}"
+ export HIP_PATH="${HIP_PATH:-$ROCM_PATH}"
+ export HIP_HOME="${HIP_HOME:-$ROCM_PATH}"
+ export PATH="$ROCM_HOME/bin:$PATH"
+ export TORCH_HIP_ARCH_LIST="${TORCH_HIP_ARCH_LIST:-gfx942}"
+ export HSA_OVERRIDE_GFX_VERSION="${HSA_OVERRIDE_GFX_VERSION:-gfx942}"
+
+ # Create a simple test directory
+ mkdir -p /tmp/hip_test
+ cd /tmp/hip_test
+
+ echo "=== Creating Simple HIP Test Program ==="
+
+ # Create a minimal HIP program
+ cat > test_simple.hip << 'EOF'
+ #include <hip/hip_runtime.h>
+ #include <iostream>
+
+ __global__ void simple_kernel(float* data, int n) {
+     int idx = blockIdx.x * blockDim.x + threadIdx.x;
+     if (idx < n) {
+         data[idx] = idx * 2.0f;
+     }
+ }
+
+ int main() {
+     const int n = 1024;
+     const size_t size = n * sizeof(float);
+
+     float* h_data = new float[n];
+     float* d_data;
+
+     // Allocate device memory
+     hipError_t err = hipMalloc(&d_data, size);
+     if (err != hipSuccess) {
+         std::cout << "Failed to allocate device memory: " << hipGetErrorString(err) << std::endl;
+         return 1;
+     }
+
+     // Launch kernel
+     dim3 block(256);
+     dim3 grid((n + block.x - 1) / block.x);
+     hipLaunchKernelGGL(simple_kernel, grid, block, 0, 0, d_data, n);
+
+     // Copy back
+     err = hipMemcpy(h_data, d_data, size, hipMemcpyDeviceToHost);
+     if (err != hipSuccess) {
+         std::cout << "Failed to copy from device: " << hipGetErrorString(err) << std::endl;
+         return 1;
+     }
+
+     // Check results
+     bool success = true;
+     for (int i = 0; i < 10; i++) {
+         if (h_data[i] != i * 2.0f) {
+             success = false;
+             break;
+         }
+     }
+
+     std::cout << "Simple HIP test: " << (success ? "PASSED" : "FAILED") << std::endl;
+
+     hipFree(d_data);
+     delete[] h_data;
+
+     return success ? 0 : 1;
+ }
+ EOF
+
+ echo "=== Testing HIP Compilation ==="
+ echo "Compiling simple HIP program..."
+
+ if hipcc -o test_simple test_simple.hip --amdgpu-target=gfx942 -v; then
+     echo "✓ HIP compilation successful"
+     echo
+     echo "=== Testing HIP Execution ==="
+     if ./test_simple; then
+         echo "✓ HIP execution successful"
+     else
+         echo "✗ HIP execution failed"
+     fi
+ else
+     echo "✗ HIP compilation failed"
+ fi
+
+ echo
+ echo "=== Testing HIP with C++ Extensions Flags ==="
+
+ # Test with flags similar to what torch uses
+ echo "Compiling with torch-like flags..."
+ if hipcc -o test_simple_torch test_simple.hip \
+     --amdgpu-target=gfx942 \
+     -O3 \
+     -std=c++17 \
+     -fPIC \
+     -shared \
+     -v; then
+     echo "✓ HIP compilation with torch flags successful"
+ else
+     echo "✗ HIP compilation with torch flags failed"
+ fi
+
+ echo
+ echo "=== Testing hipblaslt Library ==="
+ echo "Checking if hipblaslt is available for linking..."
+
+ cat > test_hipblaslt.cpp << 'EOF'
+ #include <iostream>
+
+ // Just test if we can link against hipblaslt
+ int main() {
+     std::cout << "hipblaslt linkage test" << std::endl;
+     return 0;
+ }
+ EOF
+
+ if hipcc -o test_hipblaslt test_hipblaslt.cpp -lhipblaslt -v; then
+     echo "✓ hipblaslt linkage successful"
+     if ./test_hipblaslt; then
+         echo "✓ hipblaslt execution successful"
+     else
+         echo "✗ hipblaslt execution failed"
+     fi
+ else
+     echo "✗ hipblaslt linkage failed"
+ fi
+
+ echo
+ echo "=== Library Search Paths ==="
+ echo "ROCm lib directory contents:"
+ ls -la "$ROCM_PATH/lib" | head -20 || echo "Could not list ROCm lib directory"
+
+ echo
+ echo "hipblaslt library files:"
+ find "$ROCM_PATH" -name "*hipblaslt*" 2>/dev/null || echo "No hipblaslt files found"
+
+ echo
+ echo "=== Debug Script 2 Complete ==="
+
+ # Cleanup
+ cd /
+ rm -rf /tmp/hip_test
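Recent hipcc releases deprecate `--amdgpu-target` in favour of `--offload-arch`. If script 2 prints a deprecation warning (or a future ROCm drops the old spelling), the equivalent hedged invocations for the same test program would be:

    hipcc --offload-arch=gfx942 -O3 -std=c++17 -o test_simple test_simple.hip
    hipcc --offload-arch=gfx942 -O3 -std=c++17 -fPIC -shared -o test_simple_torch.so test_simple.hip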
_dev/debug-build-3-torch-ext.sh ADDED
@@ -0,0 +1,187 @@
+ #!/usr/bin/env bash
+
+ # Debug script 3: PyTorch C++ extension compilation test
+
+ set -euo pipefail
+
+ echo "=== PyTorch C++ Extension Debug Script 3 ==="
+ echo "Testing PyTorch C++ extension compilation with HIP"
+ echo
+
+ # Set ROCm environment variables
+ export ROCM_PATH="${ROCM_PATH:-/opt/rocm-7.0.1}"
+ export ROCM_HOME="${ROCM_HOME:-$ROCM_PATH}"
+ export HIP_PATH="${HIP_PATH:-$ROCM_PATH}"
+ export HIP_HOME="${HIP_HOME:-$ROCM_PATH}"
+ export PATH="$ROCM_HOME/bin:$PATH"
+ export TORCH_HIP_ARCH_LIST="${TORCH_HIP_ARCH_LIST:-gfx942}"
+ export HSA_OVERRIDE_GFX_VERSION="${HSA_OVERRIDE_GFX_VERSION:-gfx942}"
+ export TORCH_EXTENSIONS_DIR="${TORCH_EXTENSIONS_DIR:-$PWD/.torch_extensions_debug}"
+
+ # Create a test directory
+ mkdir -p /tmp/torch_ext_test
+ cd /tmp/torch_ext_test
+
+ echo "=== Creating Simple PyTorch Extension ==="
+
+ # Create a minimal CUDA/HIP kernel similar to megablocks
+ cat > simple_kernel.cu << 'EOF'
+ #include <torch/extension.h>
+ #include <vector>
+
+ #ifdef __HIP_PLATFORM_AMD__
+ #include <hip/hip_runtime.h>
+ #define CUDA_LAUNCH_KERNEL(kernel, grid, block, smem, stream, ...) \
+     hipLaunchKernelGGL(kernel, grid, block, smem, stream, __VA_ARGS__)
+ #else
+ #include <cuda_runtime.h>
+ #define CUDA_LAUNCH_KERNEL(kernel, grid, block, smem, stream, ...) \
+     kernel<<<grid, block, smem, stream>>>(__VA_ARGS__)
+ #endif
+
+ __global__ void add_kernel(const float* a, const float* b, float* c, int n) {
+     int idx = blockIdx.x * blockDim.x + threadIdx.x;
+     if (idx < n) {
+         c[idx] = a[idx] + b[idx];
+     }
+ }
+
+ torch::Tensor add_tensors_cuda(torch::Tensor a, torch::Tensor b) {
+     auto c = torch::zeros_like(a);
+     int n = a.numel();
+
+     const int block_size = 256;
+     const int grid_size = (n + block_size - 1) / block_size;
+
+     CUDA_LAUNCH_KERNEL(
+         add_kernel,
+         dim3(grid_size),
+         dim3(block_size),
+         0,
+         0,
+         a.data_ptr<float>(),
+         b.data_ptr<float>(),
+         c.data_ptr<float>(),
+         n
+     );
+
+     return c;
+ }
+
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+     m.def("add_tensors", &add_tensors_cuda, "Add two tensors (CUDA/HIP)");
+ }
+ EOF
+
+ # Create Python test script
+ cat > test_extension.py << 'EOF'
+ import os
+ import sys
+ import torch
+ from torch.utils.cpp_extension import load
+
+ print("=== PyTorch Extension Load Test ===")
+ print(f"PyTorch version: {torch.__version__}")
+ print(f"CUDA available: {torch.cuda.is_available()}")
+ print(f"Device count: {torch.cuda.device_count()}")
+
+ if hasattr(torch.version, 'hip') and torch.version.hip:
+     print(f"HIP version: {torch.version.hip}")
+
+ print("\n=== Loading Extension ===")
+ print("This may take a while and will show compilation output...")
+ print("If this hangs, it indicates the same issue as build.py")
+
+ try:
+     # Mimic the same load call as build.py
+     simple_ext = load(
+         name="simple_test_ext",
+         sources=["simple_kernel.cu"],
+         extra_cflags=["-O3", "-std=c++17"],
+         extra_cuda_cflags=["-O3"],  # torch switches this to hipcc on ROCm
+         verbose=True,
+         is_python_module=False
+     )
+     print("✓ Extension compilation successful!")
+
+     # Test the extension
+     print("\n=== Testing Extension ===")
+     device = 'cuda' if torch.cuda.is_available() else 'cpu'
+     a = torch.randn(1000, device=device)
+     b = torch.randn(1000, device=device)
+
+     if device == 'cuda':
+         result = simple_ext.add_tensors(a, b)
+         expected = a + b
+         if torch.allclose(result, expected):
+             print("✓ Extension execution successful!")
+         else:
+             print("✗ Extension execution failed - results don't match")
+     else:
+         print("⚠ No CUDA device, skipping execution test")
+
+ except Exception as e:
+     print(f"✗ Extension compilation/loading failed: {e}")
+     import traceback
+     traceback.print_exc()
+ EOF
+
+ echo "=== Running PyTorch Extension Test ==="
+ echo "This test mimics the same compilation process as build.py"
+ echo "If this hangs, it shows the same issue as the main build"
+ echo
+
+ # Set a timeout to prevent infinite hang
+ timeout 300 python3 test_extension.py || {
+     exit_code=$?
+     if [ $exit_code -eq 124 ]; then
+         echo "✗ Extension compilation timed out after 5 minutes (same as build.py hang)"
+     else
+         echo "✗ Extension compilation failed with exit code $exit_code"
+     fi
+ }
+
+ echo
+ echo "=== Testing with Minimal Sources ==="
+
+ # Create an even simpler version
+ cat > minimal_kernel.cu << 'EOF'
+ #include <torch/extension.h>
+
+ torch::Tensor dummy_function(torch::Tensor input) {
+     return input.clone();
+ }
+
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+     m.def("dummy", &dummy_function, "Dummy function");
+ }
+ EOF
+
+ cat > test_minimal.py << 'EOF'
+ import torch
+ from torch.utils.cpp_extension import load
+
+ print("=== Minimal Extension Test ===")
+
+ try:
+     minimal_ext = load(
+         name="minimal_test_ext",
+         sources=["minimal_kernel.cu"],
+         extra_cflags=["-O3"],
+         verbose=True,
+         with_cuda=False  # Skip CUDA/HIP compilation
+     )
+     print("✓ Minimal extension (CPU only) successful!")
+ except Exception as e:
+     print(f"✗ Even minimal extension failed: {e}")
+ EOF
+
+ echo "Testing minimal CPU-only extension..."
+ timeout 120 python3 test_minimal.py || echo "Minimal extension also failed/timed out"
+
+ echo
+ echo "=== Debug Script 3 Complete ==="
+
+ # Cleanup
+ cd /
+ rm -rf /tmp/torch_ext_test
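When the `load()` call in script 3 hangs, the JIT build directory that `torch.utils.cpp_extension` generates under TORCH_EXTENSIONS_DIR is the first place to look: the build.ninja it writes shows the exact hipcc command lines being run, and the process table shows whether a compiler is actually working or everything is merely waiting. A hedged sketch for a second shell (the per-extension subdirectory layout varies by PyTorch version, and `~/.cache/torch_extensions` is only the default when the variable is unset):

    find "${TORCH_EXTENSIONS_DIR:-$HOME/.cache/torch_extensions}" -name build.ninja 2>/dev/null
    ps -ef | grep -E 'ninja|hipcc' | grep -v grep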
_dev/debug-build-4-megablocks.sh ADDED
@@ -0,0 +1,257 @@
+ #!/usr/bin/env bash
+
+ # Debug script 4: MegaBlocks-specific build debugging
+
+ set -euo pipefail
+
+ echo "=== MegaBlocks Build Debug Script 4 ==="
+ echo "Testing MegaBlocks-specific compilation components"
+ echo
+
+ # Set ROCm environment variables
+ export ROCM_PATH="${ROCM_PATH:-/opt/rocm-7.0.1}"
+ export ROCM_HOME="${ROCM_HOME:-$ROCM_PATH}"
+ export HIP_PATH="${HIP_PATH:-$ROCM_PATH}"
+ export HIP_HOME="${HIP_HOME:-$ROCM_PATH}"
+ export PATH="$ROCM_HOME/bin:$PATH"
+ export TORCH_HIP_ARCH_LIST="${TORCH_HIP_ARCH_LIST:-gfx942}"
+ export HSA_OVERRIDE_GFX_VERSION="${HSA_OVERRIDE_GFX_VERSION:-gfx942}"
+ export TORCH_EXTENSIONS_DIR="${TORCH_EXTENSIONS_DIR:-$PWD/.torch_extensions_debug}"
+
+ echo "Working directory: $(pwd)"
+ echo
+
+ echo "=== Checking MegaBlocks Source Files ==="
+ echo "Verifying all source files exist:"
+
+ sources=(
+     "torch-ext/torch_binding.cpp"
+     "csrc/new_cumsum.cu"
+     "csrc/new_histogram.cu"
+     "csrc/new_indices.cu"
+     "csrc/new_replicate.cu"
+     "csrc/new_sort.cu"
+     "csrc/grouped_gemm/grouped_gemm.cu"
+ )
+
+ all_exist=true
+ for src in "${sources[@]}"; do
+     if [ -f "$src" ]; then
+         echo "✓ $src exists ($(wc -l < "$src") lines)"
+     else
+         echo "✗ $src missing"
+         all_exist=false
+     fi
+ done
+
+ if [ "$all_exist" = false ]; then
+     echo "Cannot proceed - missing source files"
+     exit 1
+ fi
+
+ echo
+ echo "=== Checking Include Directories ==="
+ if [ -d "csrc" ]; then
+     echo "✓ csrc include directory exists"
+     echo "Headers in csrc/:"
+     find csrc -name "*.h" -o -name "*.hpp" | head -10
+ else
+     echo "✗ csrc include directory missing"
+ fi
+
+ echo
+ echo "=== Testing Individual Source Compilation ==="
+
+ # Test compiling each .cu file individually
+ for src in csrc/*.cu; do
+     if [ -f "$src" ]; then
+         echo "Testing compilation of $(basename "$src")..."
+         if timeout 60 hipcc -c "$src" -o "/tmp/$(basename "$src" .cu).o" \
+             --amdgpu-target=gfx942 \
+             -I./csrc \
+             -I"$(python3 -c 'import torch; print(torch.utils.cpp_extension.include_paths()[0])')" \
+             -std=c++17 \
+             -O3 \
+             -fPIC; then
+             echo "✓ $(basename "$src") compiled successfully"
+         else
+             echo "✗ $(basename "$src") compilation failed"
+         fi
+     fi
+ done
+
+ echo
+ echo "=== Testing grouped_gemm.cu Specifically ==="
+ echo "This is often the most complex kernel..."
+
+ if timeout 120 hipcc -c csrc/grouped_gemm/grouped_gemm.cu -o /tmp/grouped_gemm.o \
+     --amdgpu-target=gfx942 \
+     -I./csrc \
+     -I"$(python3 -c 'import torch; print(torch.utils.cpp_extension.include_paths()[0])')" \
+     -std=c++17 \
+     -O3 \
+     -fPIC \
+     -lhipblaslt \
+     -v; then
+     echo "✓ grouped_gemm.cu compiled successfully"
+ else
+     echo "✗ grouped_gemm.cu compilation failed"
+ fi
+
+ echo
+ echo "=== Testing torch_binding.cpp ==="
+ if timeout 60 hipcc -c torch-ext/torch_binding.cpp -o /tmp/torch_binding.o \
+     -I./csrc \
+     -I"$(python3 -c 'import torch; print(torch.utils.cpp_extension.include_paths()[0])')" \
+     -std=c++17 \
+     -O3 \
+     -fPIC; then
+     echo "✓ torch_binding.cpp compiled successfully"
+ else
+     echo "✗ torch_binding.cpp compilation failed"
+ fi
+
+ echo
+ echo "=== Testing Incremental PyTorch Extension Build ==="
+
+ cat > debug_build.py << 'EOF'
+ import os
+ import pathlib
+ import sys
+ import signal
+ import time
+ from torch.utils.cpp_extension import load
+
+ def timeout_handler(signum, frame):
+     print("Build timed out - this indicates a hanging issue")
+     sys.exit(1)
+
+ # Set up timeout
+ signal.signal(signal.SIGALRM, timeout_handler)
+ signal.alarm(180)  # 3 minute timeout
+
+ repo = pathlib.Path(".").resolve()
+ os.environ.setdefault("TORCH_EXTENSIONS_DIR", str(repo / ".torch_extensions_debug"))
+
+ print("=== Testing with Single Source File ===")
+ try:
+     print("Building with just new_cumsum.cu...")
+     mod = load(
+         name="_megablocks_debug_single",
+         sources=["csrc/new_cumsum.cu"],
+         extra_include_paths=["csrc"],
+         extra_cflags=["-O3", "-std=c++17"],
+         extra_cuda_cflags=["-O3"],
+         verbose=True,
+         is_python_module=False,
+     )
+     print("✓ Single source build successful")
+ except Exception as e:
+     print(f"✗ Single source build failed: {e}")
+
+ print("\n=== Testing with Two Source Files ===")
+ try:
+     print("Building with new_cumsum.cu and new_histogram.cu...")
+     mod = load(
+         name="_megablocks_debug_double",
+         sources=["csrc/new_cumsum.cu", "csrc/new_histogram.cu"],
+         extra_include_paths=["csrc"],
+         extra_cflags=["-O3", "-std=c++17"],
+         extra_cuda_cflags=["-O3"],
+         verbose=True,
+         is_python_module=False,
+     )
+     print("✓ Double source build successful")
+ except Exception as e:
+     print(f"✗ Double source build failed: {e}")
+
+ print("\n=== Testing with grouped_gemm.cu Only ===")
+ try:
+     print("Building with just grouped_gemm.cu (most complex)...")
+     mod = load(
+         name="_megablocks_debug_gemm",
+         sources=["csrc/grouped_gemm/grouped_gemm.cu"],
+         extra_include_paths=["csrc"],
+         extra_cflags=["-O3", "-std=c++17"],
+         extra_cuda_cflags=["-O3"],
+         extra_ldflags=["-lhipblaslt"],
+         verbose=True,
+         is_python_module=False,
+     )
+     print("✓ grouped_gemm build successful")
+ except Exception as e:
+     print(f"✗ grouped_gemm build failed: {e}")
+
+ signal.alarm(0)  # Cancel timeout
+ EOF
+
+ echo "Running incremental build test..."
+ python3 debug_build.py
+
+ echo
+ echo "=== Testing Full Build with Timeout ==="
+
+ cat > debug_full_build.py << 'EOF'
+ import os
+ import pathlib
+ import sys
+ import signal
+ from torch.utils.cpp_extension import load
+
+ def timeout_handler(signum, frame):
+     print("Full build timed out - this confirms the hanging issue")
+     sys.exit(124)  # timeout exit code
+
+ # Set up 5 minute timeout
+ signal.signal(signal.SIGALRM, timeout_handler)
+ signal.alarm(300)
+
+ repo = pathlib.Path(".").resolve()
+ os.environ.setdefault("TORCH_EXTENSIONS_DIR", str(repo / ".torch_extensions_debug"))
+
+ sources = [
+     "torch-ext/torch_binding.cpp",
+     "csrc/new_cumsum.cu",
+     "csrc/new_histogram.cu",
+     "csrc/new_indices.cu",
+     "csrc/new_replicate.cu",
+     "csrc/new_sort.cu",
+     "csrc/grouped_gemm/grouped_gemm.cu",
+ ]
+
+ print("=== Attempting Full MegaBlocks Build ===")
+ print("This mimics the exact build.py process...")
+ print("Sources:", sources)
+
+ try:
+     mod = load(
+         name="_megablocks_debug_full",
+         sources=sources,
+         extra_include_paths=["csrc"],
+         extra_cflags=["-O3", "-std=c++17"],
+         extra_cuda_cflags=["-O3"],
+         extra_ldflags=["-lhipblaslt"],
+         verbose=True,
+         is_python_module=False,
+     )
+     print("✓ Full build successful!")
+     print("Built:", mod)
+
+ except Exception as e:
+     print(f"✗ Full build failed: {e}")
+     import traceback
+     traceback.print_exc()
+
+ signal.alarm(0)
+ EOF
+
+ echo "Running full build test (with timeout)..."
+ python3 debug_full_build.py
+
+ echo
+ echo "=== Cleanup ==="
+ rm -f /tmp/*.o
+ rm -f debug_build.py debug_full_build.py
+
+ echo
+ echo "=== Debug Script 4 Complete ==="
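If the single- and double-source loads in script 4 succeed but the full seven-source build hangs, a useful differential test is to serialise the compile jobs by re-running the generated debug_full_build.py before the cleanup step removes it; `torch.utils.cpp_extension` honours MAX_JOBS when it drives ninja (the same knob build-fixed.sh relies on). A sketch:

    MAX_JOBS=1 timeout 600 python3 debug_full_build.py

A serialised build that completes points at parallel hipcc invocations rather than at any single translation unit.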
_dev/debug-build-all.sh ADDED
@@ -0,0 +1,91 @@
+ #!/usr/bin/env bash
+
+ # Debug script: Run all debug tests in sequence
+
+ set -euo pipefail
+
+ echo "=== MegaBlocks Build Debugging Suite ==="
+ echo "Running progressive debug tests to identify build hang issue"
+ echo "ROCm installation: /opt/rocm-7.0.1"
+ echo
+
+ # Make all scripts executable
+ chmod +x debug-build-1-env.sh debug-build-2-hipcc.sh debug-build-3-torch-ext.sh debug-build-4-megablocks.sh
+
+ scripts=(
+     "debug-build-1-env.sh"
+     "debug-build-2-hipcc.sh"
+     "debug-build-3-torch-ext.sh"
+     "debug-build-4-megablocks.sh"
+ )
+
+ results=()
+ start_time=$(date +%s)
+
+ for script in "${scripts[@]}"; do
+     echo
+     echo "========================================"
+     echo "Running $script"
+     echo "========================================"
+
+     script_start=$(date +%s)
+
+     if ./"$script"; then
+         script_end=$(date +%s)
+         duration=$((script_end - script_start))
+         echo "✓ $script completed successfully in ${duration}s"
+         results+=("✓ $script: SUCCESS (${duration}s)")
+     else
+         script_end=$(date +%s)
+         duration=$((script_end - script_start))
+         echo "✗ $script failed in ${duration}s"
+         results+=("✗ $script: FAILED (${duration}s)")
+     fi
+
+     echo "----------------------------------------"
+ done
+
+ end_time=$(date +%s)
+ total_duration=$((end_time - start_time))
+
+ echo
+ echo "========================================"
+ echo "SUMMARY REPORT"
+ echo "========================================"
+ echo "Total runtime: ${total_duration}s"
+ echo
+
+ for result in "${results[@]}"; do
+     echo "$result"
+ done
+
+ echo
+ echo "=== Analysis ==="
+ echo "1. If debug-build-1-env.sh fails: ROCm installation/environment issue"
+ echo "2. If debug-build-2-hipcc.sh fails: HIP compiler issue"
+ echo "3. If debug-build-3-torch-ext.sh hangs: PyTorch extension compilation issue"
+ echo "4. If debug-build-4-megablocks.sh hangs: MegaBlocks-specific compilation issue"
+ echo
+ echo "=== Next Steps Based on Results ==="
+ echo "- If all pass: The issue may be intermittent or environment-specific"
+ echo "- If script 3 or 4 hangs: Run with strace to see where it hangs:"
+ echo "  strace -f -e trace=process,signal python3 build.py"
+ echo "- Check compilation log files in .torch_extensions for more details"
+ echo "- Consider using PYTORCH_JIT_LOG_LEVEL=1 for more verbose output"
+
+ echo
+ echo "=== Additional Debugging Commands ==="
+ echo "# Check for stuck processes:"
+ echo "ps aux | grep -E '(hipcc|hip-clang|python)'"
+ echo
+ echo "# Monitor system resources during build:"
+ echo "htop"
+ echo
+ echo "# Check for device issues:"
+ echo "dmesg | tail -20"
+ echo
+ echo "# Force clean rebuild:"
+ echo "rm -rf .torch_extensions* && ./build.sh"
+
+ echo
+ echo "Debug suite complete."
_dev/debug_build.py ADDED
@@ -0,0 +1,68 @@
+ import os
+ import pathlib
+ import sys
+ import signal
+ import time
+ from torch.utils.cpp_extension import load
+
+ def timeout_handler(signum, frame):
+     print("Build timed out - this indicates a hanging issue")
+     sys.exit(1)
+
+ # Set up timeout
+ signal.signal(signal.SIGALRM, timeout_handler)
+ signal.alarm(180)  # 3 minute timeout
+
+ repo = pathlib.Path(".").resolve()
+ os.environ.setdefault("TORCH_EXTENSIONS_DIR", str(repo / ".torch_extensions_debug"))
+
+ print("=== Testing with Single Source File ===")
+ try:
+     print("Building with just new_cumsum.cu...")
+     mod = load(
+         name="_megablocks_debug_single",
+         sources=["csrc/new_cumsum.cu"],
+         extra_include_paths=["csrc"],
+         extra_cflags=["-O3", "-std=c++17"],
+         extra_cuda_cflags=["-O3"],
+         verbose=True,
+         is_python_module=False,
+     )
+     print("✓ Single source build successful")
+ except Exception as e:
+     print(f"✗ Single source build failed: {e}")
+
+ print("\n=== Testing with Two Source Files ===")
+ try:
+     print("Building with new_cumsum.cu and new_histogram.cu...")
+     mod = load(
+         name="_megablocks_debug_double",
+         sources=["csrc/new_cumsum.cu", "csrc/new_histogram.cu"],
+         extra_include_paths=["csrc"],
+         extra_cflags=["-O3", "-std=c++17"],
+         extra_cuda_cflags=["-O3"],
+         verbose=True,
+         is_python_module=False,
+     )
+     print("✓ Double source build successful")
+ except Exception as e:
+     print(f"✗ Double source build failed: {e}")
+
+ print("\n=== Testing with grouped_gemm.cu Only ===")
+ try:
+     print("Building with just grouped_gemm.cu (most complex)...")
+     mod = load(
+         name="_megablocks_debug_gemm",
+         sources=["csrc/grouped_gemm/grouped_gemm.cu"],
+         extra_include_paths=["csrc"],
+         extra_cflags=["-O3", "-std=c++17"],
+         extra_cuda_cflags=["-O3"],
+         extra_ldflags=["-lhipblaslt"],
+         verbose=True,
+         is_python_module=False,
+     )
+     print("✓ grouped_gemm build successful")
+ except Exception as e:
+     print(f"✗ grouped_gemm build failed: {e}")
+
+ signal.alarm(0)  # Cancel timeout
build-fixed.sh ADDED
@@ -0,0 +1,43 @@
+ #!/usr/bin/env bash
+
+ set -euo pipefail
+
+ # Fixed build script with proper ROCm/HIP environment
+
+ echo "=== Fixed Build Script ==="
+ echo "Configuring environment for ROCm 7.0.1 with proper exports"
+
+ # Default to the ROCm 7.0.1 install unless the caller overrides it.
+ export ROCM_PATH="${ROCM_PATH:-/opt/rocm-7.0.1}"
+ export ROCM_HOME="${ROCM_HOME:-$ROCM_PATH}"
+ export HIP_PATH="${HIP_PATH:-$ROCM_PATH}"
+ export HIP_HOME="${HIP_HOME:-$ROCM_PATH}"
+
+ export PATH="$ROCM_HOME/bin:$PATH"
+
+ # Fix architecture specifications - use gfx942 consistently
+ export TORCH_HIP_ARCH_LIST="gfx942"
+ export PYTORCH_ROCM_ARCH="gfx942"
+
+ # Remove HSA_OVERRIDE_GFX_VERSION - not needed since MI300X is already gfx942
+ unset HSA_OVERRIDE_GFX_VERSION
+
+ # Force single-threaded compilation to avoid ninja hanging
+ export MAX_JOBS=1
+
+ # Enable PyTorch JIT logging for debugging
+ export PYTORCH_JIT_LOG_LEVEL=1
+
+ export TORCH_EXTENSIONS_DIR="${TORCH_EXTENSIONS_DIR:-$PWD/.torch_extensions}"
+
+ echo "Environment configured:"
+ echo "ROCM_PATH=$ROCM_PATH"
+ echo "TORCH_HIP_ARCH_LIST=$TORCH_HIP_ARCH_LIST"
+ echo "PYTORCH_ROCM_ARCH=$PYTORCH_ROCM_ARCH"
+ echo "MAX_JOBS=$MAX_JOBS"
+ echo "PYTORCH_JIT_LOG_LEVEL=$PYTORCH_JIT_LOG_LEVEL"
+ echo "HSA_OVERRIDE_GFX_VERSION=${HSA_OVERRIDE_GFX_VERSION:-unset}"
+ echo
+
+ echo "Starting build..."
+ python -u build.py
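One note on the override this script unsets: when HSA_OVERRIDE_GFX_VERSION is genuinely needed (forcing the runtime to treat a card as a different architecture), it takes the numeric major.minor.stepping form, so the "gfx942" string used by the earlier debug scripts is unlikely to be interpreted as intended. A hedged example of the correct shape, for reference only:

    export HSA_OVERRIDE_GFX_VERSION=9.4.2    # numeric form, not "gfx942"
    unset HSA_OVERRIDE_GFX_VERSION           # on MI300X (natively gfx942) leave it unset, as build-fixed.sh does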
build-strace.sh ADDED
@@ -0,0 +1,54 @@
+ #!/usr/bin/env bash
+
+ set -euo pipefail
+
+ # Build script with strace to debug hanging
+
+ echo "=== Build with strace debugging ==="
+ echo "This will trace system calls to identify where the build hangs"
+
+ # Same environment as build-fixed.sh but without MAX_JOBS limit
+ export ROCM_PATH="${ROCM_PATH:-/opt/rocm-7.0.1}"
+ export ROCM_HOME="${ROCM_HOME:-$ROCM_PATH}"
+ export HIP_PATH="${HIP_PATH:-$ROCM_PATH}"
+ export HIP_HOME="${HIP_HOME:-$ROCM_PATH}"
+
+ export PATH="$ROCM_HOME/bin:$PATH"
+
+ # Fix architecture specifications
+ export TORCH_HIP_ARCH_LIST="gfx942"
+ export PYTORCH_ROCM_ARCH="gfx942"
+
+ # Remove HSA_OVERRIDE_GFX_VERSION
+ unset HSA_OVERRIDE_GFX_VERSION
+
+ # Remove MAX_JOBS limit to see parallel compilation hang
+ unset MAX_JOBS
+
+ # Enable PyTorch JIT logging
+ export PYTORCH_JIT_LOG_LEVEL=1
+
+ export TORCH_EXTENSIONS_DIR="${TORCH_EXTENSIONS_DIR:-$PWD/.torch_extensions}"
+
+ echo "Environment configured for strace:"
+ echo "ROCM_PATH=$ROCM_PATH"
+ echo "TORCH_HIP_ARCH_LIST=$TORCH_HIP_ARCH_LIST"
+ echo "PYTORCH_ROCM_ARCH=$PYTORCH_ROCM_ARCH"
+ echo "MAX_JOBS=${MAX_JOBS:-unset}"
+ echo "PYTORCH_JIT_LOG_LEVEL=$PYTORCH_JIT_LOG_LEVEL"
+ echo
+
+ echo "Starting build with strace..."
+ echo "Tracing process creation, signals, and file operations..."
+ echo "Output will be saved to strace.log"
+
+ # Use strace to trace the build process
+ # -f: follow child processes
+ # -e trace=process,signal: trace process creation and signals
+ # -e trace=file: trace file operations
+ # -o strace.log: save output to file
+ # -T: show time spent in each syscall
+ strace -f -e trace=process,signal,file -o strace.log -T python -u build.py
+
+ echo "Build completed or interrupted"
+ echo "Check strace.log for detailed system call trace"
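Reading strace.log afterwards: with `-e trace=process,signal,file`, a JIT-compile hang usually appears either as a wait4() that never returns (a child hipcc/ninja is stuck) or as the Python process polling the same lock-file path over and over (a stale extension-dir baton, the case build.sh cleans up). A hedged starting point:

    tail -n 50 strace.log
    grep -nE 'execve|wait4' strace.log | tail -n 40
    grep -c '/lock' strace.log    # many hits on one lock path suggests a stale baton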
build.sh CHANGED
@@ -1,9 +1,39 @@
- # export TORCH_EXTENSIONS_DIR=/root/shisa-v2/train/v2.1/megablocks.kernels-community/.torch_extensions; export ROCM_HOME=/opt/rocm-6.4.1; export HIP_HOME=$ROCM_HOME; export TORCH_HIP_ARCH_LIST=gfx942; export HSA_OVERRIDE_GFX_VERSION=gfx942; python megablocks.kernels-community/build.py
-
- # 3-4min build
- export ROCM_HOME=/opt/rocm-6.4.1
- export HIP_HOME=$ROCM_HOME
- export TORCH_HIP_ARCH_LIST=gfx942
- export HSA_OVERRIDE_GFX_VERSION=gfx942
- export TORCH_EXTENSIONS_DIR="$PWD/megablocks.kernels-community/.torch_extensions"
- python megablocks.kernels-community/build.py
+ #!/usr/bin/env bash
+
+ set -euo pipefail
+
+ # 3-4min build with cleanup to prevent lock file hangs
+
+ echo "=== MegaBlocks Build Script ==="
+ echo "Cleaning up any previous build processes and lock files..."
+
+ # Kill any hanging build processes
+ echo "Killing any running build.py processes..."
+ pkill -f "python.*build\.py" 2>/dev/null || true
+ pkill -f "ninja" 2>/dev/null || true
+ pkill -f "hipcc" 2>/dev/null || true
+
+ # Wait a moment for processes to terminate
+ sleep 2
+
+ # Clean up lock files that cause infinite loops
+ echo "Removing stale lock files..."
+ if [ -d ".torch_extensions" ]; then
+     find .torch_extensions -name "lock" -delete 2>/dev/null || true
+     find .torch_extensions -name ".ninja_lock" -delete 2>/dev/null || true
+ fi
+
+ # Default to the ROCm 7.0.1 install unless the caller overrides it.
+ export ROCM_PATH="${ROCM_PATH:-/opt/rocm-7.0.1}"
+ export ROCM_HOME="${ROCM_HOME:-$ROCM_PATH}"
+ export HIP_PATH="${HIP_PATH:-$ROCM_PATH}"
+ export HIP_HOME="${HIP_HOME:-$ROCM_PATH}"
+
+ export PATH="$ROCM_HOME/bin:$PATH"
+ export LD_LIBRARY_PATH="$ROCM_HOME/lib:$ROCM_HOME/lib64:${LD_LIBRARY_PATH:-}"
+ export TORCH_HIP_ARCH_LIST="${TORCH_HIP_ARCH_LIST:-gfx942}"
+ export HSA_OVERRIDE_GFX_VERSION="${HSA_OVERRIDE_GFX_VERSION:-gfx942}"
+ export TORCH_EXTENSIONS_DIR="${TORCH_EXTENSIONS_DIR:-$PWD/.torch_extensions}"
+
+ echo "Environment configured. Starting build..."
+ python -u build.py
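The lock files deleted above are the batons torch's extension builder drops next to each JIT build: a file literally named `lock` inside the per-extension directory under TORCH_EXTENSIONS_DIR. A killed build leaves it behind, and every later `load()` then waits on it indefinitely. A hedged pre-check, assuming the layout this repo uses:

    find "${TORCH_EXTENSIONS_DIR:-.torch_extensions}" -name lock -print 2>/dev/null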
run-tests.sh ADDED
@@ -0,0 +1,138 @@
+ #!/usr/bin/env bash
+ set -euo pipefail
+
+ KERNEL_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
+ cd "$KERNEL_DIR"
+
+ export KERNEL_DIR
+
+ detect_variant() {
+     python - <<'PY'
+ import os
+ import pathlib
+
+ root = pathlib.Path(os.environ["KERNEL_DIR"])
+ build_dir = root / "build"
+ variant = None
+
+ try:
+     from kernels.utils import build_variant as _build_variant
+ except Exception:
+     _build_variant = None
+
+ if _build_variant is not None:
+     try:
+         variant = _build_variant()
+     except Exception:
+         variant = None
+
+ if variant is None:
+     # glob() returns an iterator, so wrap in list() before "or" to make the CUDA fallback reachable
+     candidates = sorted(list(build_dir.glob("torch*-rocm64-*")) or list(build_dir.glob("torch*-cu*")))
+     if candidates:
+         variant = candidates[0].name
+
+ if variant is None:
+     raise SystemExit("Could not determine MegaBlocks build variant. Run build.py first.")
+
+ print(variant)
+ PY
+ }
+
+ VARIANT=$(detect_variant)
+
+ STAGED_DIR="$KERNEL_DIR/build/$VARIANT"
+ find_staged_lib() {
+     local base="$1"
+     local candidates=(
+         "$base/_megablocks_rocm.so"
+         "$base/megablocks/_megablocks_rocm.so"
+     )
+     for path in "${candidates[@]}"; do
+         if [[ -f "$path" ]]; then
+             echo "$path"
+             return 0
+         fi
+     done
+     return 1
+ }
+
+ STAGED_LIB=$(find_staged_lib "$STAGED_DIR") || true
+
+ if [[ -z "${STAGED_LIB:-}" ]]; then
+     echo "Staged ROCm extension not found under $STAGED_DIR; rebuilding kernels..."
+     python build.py
+     VARIANT=$(detect_variant)
+     STAGED_DIR="$KERNEL_DIR/build/$VARIANT"
+     STAGED_LIB=$(find_staged_lib "$STAGED_DIR") || true
+     if [[ -z "${STAGED_LIB:-}" ]]; then
+         echo "ERROR: build.py completed but no extension was found under $STAGED_DIR" >&2
+         exit 1
+     fi
+ fi
+
+ export PYTHONPATH="$STAGED_DIR:${PYTHONPATH:-}"
+
+ echo "Using MegaBlocks build variant: $VARIANT"
+
+ declare -i GPU_COUNT
+ GPU_COUNT=$(python - <<'PY'
+ import torch
+ print(torch.cuda.device_count() if torch.cuda.is_available() else 0)
+ PY
+ )
+
+ if (( GPU_COUNT == 0 )); then
+     echo "ERROR: No HIP/CUDA GPUs detected. Tests require at least one visible accelerator." >&2
+     exit 1
+ fi
+
+ echo "Detected $GPU_COUNT visible GPU(s)."
+
+ log() {
+     echo
+     echo "==> $1"
+ }
+
+ run_pytest() {
+     local label="$1"
+     shift
+     log "$label"
+     set -x
+     "$@"
+     { set +x; } 2>/dev/null || true
+ }
+
+ SINGLE_GPU_ENV=(HIP_VISIBLE_DEVICES=0 CUDA_VISIBLE_DEVICES=0 WORLD_SIZE=1)
+ MULTI2_GPU_ENV=(HIP_VISIBLE_DEVICES=0,1 CUDA_VISIBLE_DEVICES=0,1 WORLD_SIZE=2)
+ MULTI8_GPU_ENV=(HIP_VISIBLE_DEVICES=$(seq -s, 0 7) CUDA_VISIBLE_DEVICES=$(seq -s, 0 7) WORLD_SIZE=8)
+
+ SINGLE_TESTS=(
+     "test_mb_moe.py"
+     "test_mb_moe_shared_expert.py"
+     "layer_test.py"
+     "test_gg.py"
+     "ops_test.py"
+ )
+
+ for test in "${SINGLE_TESTS[@]}"; do
+     run_pytest "Single-GPU pytest ${test}" env "${SINGLE_GPU_ENV[@]}" python -m pytest "tests/${test}" -q
+ done
+
+ if (( GPU_COUNT >= 2 )); then
+     run_pytest "Distributed layer smoke (2 GPUs)" env "${MULTI2_GPU_ENV[@]}" python -m pytest "tests/parallel_layer_test.py::test_megablocks_moe_mlp_functionality" -q
+ else
+     log "Skipping 2-GPU distributed layer test (requires >=2 GPUs, detected ${GPU_COUNT})."
+ fi
+
+ run_pytest "Shared expert functionality (world_size=1)" env "${SINGLE_GPU_ENV[@]}" python -m pytest 'tests/test_mb_moe_shared_expert_multi.py::test_shared_expert_distributed_functionality[1]' -q
+ run_pytest "Shared expert weighted sum (world_size=1)" env "${SINGLE_GPU_ENV[@]}" python -m pytest 'tests/test_mb_moe_shared_expert_multi.py::test_shared_expert_distributed_weighted_sum[1]' -q
+
+ if (( GPU_COUNT >= 8 )); then
+     run_pytest "Shared expert functionality (world_size=8)" env "${MULTI8_GPU_ENV[@]}" python -m pytest 'tests/test_mb_moe_shared_expert_multi.py::test_shared_expert_distributed_functionality[8]' -q
+     run_pytest "Shared expert weighted sum (world_size=8)" env "${MULTI8_GPU_ENV[@]}" python -m pytest 'tests/test_mb_moe_shared_expert_multi.py::test_shared_expert_distributed_weighted_sum[8]' -q
+ else
+     log "Skipping 8-GPU shared expert tests (requires >=8 GPUs, detected ${GPU_COUNT})."
+ fi
+
+ echo
+ echo "All requested tests completed."
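A typical invocation, as a usage sketch (the script resolves its own directory, so it can be launched from anywhere; teeing keeps the per-test output for later inspection):

    ./run-tests.sh 2>&1 | tee run-tests.log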