Commit 2d8a802
Parent(s): 1e407f0

Add ROCm build debugging utilities

Files changed:
- .gitignore +2 -0
- _dev/debug-build-1-env.sh +79 -0
- _dev/debug-build-2-hipcc.sh +154 -0
- _dev/debug-build-3-torch-ext.sh +187 -0
- _dev/debug-build-4-megablocks.sh +257 -0
- _dev/debug-build-all.sh +91 -0
- _dev/debug_build.py +68 -0
- build-fixed.sh +43 -0
- build-strace.sh +54 -0
- build.sh +39 -9
- run-tests.sh +138 -0
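
The workflow these files suggest (a hedged sketch assembled only from the scripts added in this commit, not from any documentation): run the staged diagnostics in `_dev/` first, then retry the real build with one of the wrapper scripts at the repository root.

    # Usage sketch based on the files in this commit.
    # debug-build-all.sh invokes the numbered scripts by relative path,
    # so it expects to be run from inside _dev/.
    cd _dev && ./debug-build-all.sh && cd ..

    # If the diagnostics pass, retry the build with the constrained
    # environment (MAX_JOBS=1, fixed gfx942 arch) ...
    ./build-fixed.sh

    # ... or capture a syscall trace of a hanging build.
    ./build-strace.sh   # writes strace.log (now ignored via .gitignore)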
.gitignore CHANGED
@@ -5,3 +5,5 @@ megablocks-moe/.bak
 .pytest_cache
 .readme_example.py.swp
 .torch_extensions/
+.torch_extensions_debug/
+strace.log
_dev/debug-build-1-env.sh ADDED
@@ -0,0 +1,79 @@
#!/usr/bin/env bash

# Debug script 1: Basic ROCm environment and tool availability check

set -euo pipefail

echo "=== ROCm Environment Debug Script 1 ==="
echo "Testing basic ROCm/HIP environment setup"
echo

# Set ROCm environment variables
export ROCM_PATH="${ROCM_PATH:-/opt/rocm-7.0.1}"
export ROCM_HOME="${ROCM_HOME:-$ROCM_PATH}"
export HIP_PATH="${HIP_PATH:-$ROCM_PATH}"
export HIP_HOME="${HIP_HOME:-$ROCM_PATH}"
export PATH="$ROCM_HOME/bin:$PATH"
export TORCH_HIP_ARCH_LIST="${TORCH_HIP_ARCH_LIST:-gfx942}"
export HSA_OVERRIDE_GFX_VERSION="${HSA_OVERRIDE_GFX_VERSION:-gfx942}"

echo "Environment Variables:"
echo "ROCM_PATH=$ROCM_PATH"
echo "ROCM_HOME=$ROCM_HOME"
echo "HIP_PATH=$HIP_PATH"
echo "HIP_HOME=$HIP_HOME"
echo "TORCH_HIP_ARCH_LIST=$TORCH_HIP_ARCH_LIST"
echo "HSA_OVERRIDE_GFX_VERSION=$HSA_OVERRIDE_GFX_VERSION"
echo "PATH (ROCm portion): $(echo $PATH | tr ':' '\n' | grep rocm || echo 'No ROCm in PATH')"
echo

echo "=== Directory Checks ==="
echo "ROCm installation directory exists: $(test -d "$ROCM_PATH" && echo 'YES' || echo 'NO')"
echo "ROCm bin directory exists: $(test -d "$ROCM_PATH/bin" && echo 'YES' || echo 'NO')"
echo "ROCm include directory exists: $(test -d "$ROCM_PATH/include" && echo 'YES' || echo 'NO')"
echo "ROCm lib directory exists: $(test -d "$ROCM_PATH/lib" && echo 'YES' || echo 'NO')"
echo

echo "=== Tool Availability ==="
echo "hipcc available: $(which hipcc >/dev/null 2>&1 && echo 'YES' || echo 'NO')"
echo "hip-clang available: $(which hip-clang >/dev/null 2>&1 && echo 'YES' || echo 'NO')"
echo "rocm-smi available: $(which rocm-smi >/dev/null 2>&1 && echo 'YES' || echo 'NO')"
echo "hipconfig available: $(which hipconfig >/dev/null 2>&1 && echo 'YES' || echo 'NO')"
echo

echo "=== Tool Versions ==="
if which hipcc >/dev/null 2>&1; then
    echo "hipcc version:"
    hipcc --version || echo "Failed to get hipcc version"
    echo
fi

if which hipconfig >/dev/null 2>&1; then
    echo "HIP config:"
    hipconfig --full || echo "Failed to get hipconfig"
    echo
fi

if which rocm-smi >/dev/null 2>&1; then
    echo "ROCm SMI:"
    rocm-smi --showproductname || echo "Failed to get ROCm SMI info"
    echo
fi

echo "=== Python Environment ==="
python3 --version || echo "Python3 not available"
python3 -c "import torch; print(f'PyTorch version: {torch.__version__}')" || echo "PyTorch not available"
python3 -c "import torch; print(f'CUDA available: {torch.cuda.is_available()}')" || echo "Failed to check CUDA availability"
python3 -c "import torch; print(f'HIP available: {hasattr(torch.version, \"hip\") and torch.version.hip is not None}')" || echo "Failed to check HIP availability"

echo
echo "=== Basic HIP Device Check ==="
if which hipinfo >/dev/null 2>&1; then
    echo "HIP devices:"
    hipinfo || echo "hipinfo failed"
else
    echo "hipinfo not available"
fi

echo
echo "=== Debug Script 1 Complete ==="
_dev/debug-build-2-hipcc.sh ADDED
@@ -0,0 +1,154 @@
#!/usr/bin/env bash

# Debug script 2: HIP compiler basic compilation test

set -euo pipefail

echo "=== HIP Compiler Debug Script 2 ==="
echo "Testing basic HIP compilation capabilities"
echo

# Set ROCm environment variables
export ROCM_PATH="${ROCM_PATH:-/opt/rocm-7.0.1}"
export ROCM_HOME="${ROCM_HOME:-$ROCM_PATH}"
export HIP_PATH="${HIP_PATH:-$ROCM_PATH}"
export HIP_HOME="${HIP_HOME:-$ROCM_PATH}"
export PATH="$ROCM_HOME/bin:$PATH"
export TORCH_HIP_ARCH_LIST="${TORCH_HIP_ARCH_LIST:-gfx942}"
export HSA_OVERRIDE_GFX_VERSION="${HSA_OVERRIDE_GFX_VERSION:-gfx942}"

# Create a simple test directory
mkdir -p /tmp/hip_test
cd /tmp/hip_test

echo "=== Creating Simple HIP Test Program ==="

# Create a minimal HIP program
cat > test_simple.hip << 'EOF'
#include <hip/hip_runtime.h>
#include <iostream>

__global__ void simple_kernel(float* data, int n) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) {
        data[idx] = idx * 2.0f;
    }
}

int main() {
    const int n = 1024;
    const size_t size = n * sizeof(float);

    float* h_data = new float[n];
    float* d_data;

    // Allocate device memory
    hipError_t err = hipMalloc(&d_data, size);
    if (err != hipSuccess) {
        std::cout << "Failed to allocate device memory: " << hipGetErrorString(err) << std::endl;
        return 1;
    }

    // Launch kernel
    dim3 block(256);
    dim3 grid((n + block.x - 1) / block.x);
    hipLaunchKernelGGL(simple_kernel, grid, block, 0, 0, d_data, n);

    // Copy back
    err = hipMemcpy(h_data, d_data, size, hipMemcpyDeviceToHost);
    if (err != hipSuccess) {
        std::cout << "Failed to copy from device: " << hipGetErrorString(err) << std::endl;
        return 1;
    }

    // Check results
    bool success = true;
    for (int i = 0; i < 10; i++) {
        if (h_data[i] != i * 2.0f) {
            success = false;
            break;
        }
    }

    std::cout << "Simple HIP test: " << (success ? "PASSED" : "FAILED") << std::endl;

    hipFree(d_data);
    delete[] h_data;

    return success ? 0 : 1;
}
EOF

echo "=== Testing HIP Compilation ==="
echo "Compiling simple HIP program..."

if hipcc -o test_simple test_simple.hip --amdgpu-target=gfx942 -v; then
    echo "✓ HIP compilation successful"
    echo
    echo "=== Testing HIP Execution ==="
    if ./test_simple; then
        echo "✓ HIP execution successful"
    else
        echo "✗ HIP execution failed"
    fi
else
    echo "✗ HIP compilation failed"
fi

echo
echo "=== Testing HIP with C++ Extensions Flags ==="

# Test with flags similar to what torch uses
echo "Compiling with torch-like flags..."
if hipcc -o test_simple_torch test_simple.hip \
    --amdgpu-target=gfx942 \
    -O3 \
    -std=c++17 \
    -fPIC \
    -shared \
    -v; then
    echo "✓ HIP compilation with torch flags successful"
else
    echo "✗ HIP compilation with torch flags failed"
fi

echo
echo "=== Testing hipblaslt Library ==="
echo "Checking if hipblaslt is available for linking..."

cat > test_hipblaslt.cpp << 'EOF'
#include <iostream>

// Just test if we can link against hipblaslt
int main() {
    std::cout << "hipblaslt linkage test" << std::endl;
    return 0;
}
EOF

if hipcc -o test_hipblaslt test_hipblaslt.cpp -lhipblaslt -v; then
    echo "✓ hipblaslt linkage successful"
    if ./test_hipblaslt; then
        echo "✓ hipblaslt execution successful"
    else
        echo "✗ hipblaslt execution failed"
    fi
else
    echo "✗ hipblaslt linkage failed"
fi

echo
echo "=== Library Search Paths ==="
echo "ROCm lib directory contents:"
ls -la "$ROCM_PATH/lib" | head -20 || echo "Could not list ROCm lib directory"

echo
echo "hipblaslt library files:"
find "$ROCM_PATH" -name "*hipblaslt*" 2>/dev/null || echo "No hipblaslt files found"

echo
echo "=== Debug Script 2 Complete ==="

# Cleanup
cd /
rm -rf /tmp/hip_test
_dev/debug-build-3-torch-ext.sh ADDED
@@ -0,0 +1,187 @@
#!/usr/bin/env bash

# Debug script 3: PyTorch C++ extension compilation test

set -euo pipefail

echo "=== PyTorch C++ Extension Debug Script 3 ==="
echo "Testing PyTorch C++ extension compilation with HIP"
echo

# Set ROCm environment variables
export ROCM_PATH="${ROCM_PATH:-/opt/rocm-7.0.1}"
export ROCM_HOME="${ROCM_HOME:-$ROCM_PATH}"
export HIP_PATH="${HIP_PATH:-$ROCM_PATH}"
export HIP_HOME="${HIP_HOME:-$ROCM_PATH}"
export PATH="$ROCM_HOME/bin:$PATH"
export TORCH_HIP_ARCH_LIST="${TORCH_HIP_ARCH_LIST:-gfx942}"
export HSA_OVERRIDE_GFX_VERSION="${HSA_OVERRIDE_GFX_VERSION:-gfx942}"
export TORCH_EXTENSIONS_DIR="${TORCH_EXTENSIONS_DIR:-$PWD/.torch_extensions_debug}"

# Create a test directory
mkdir -p /tmp/torch_ext_test
cd /tmp/torch_ext_test

echo "=== Creating Simple PyTorch Extension ==="

# Create a minimal CUDA/HIP kernel similar to megablocks
cat > simple_kernel.cu << 'EOF'
#include <torch/extension.h>
#include <vector>

#ifdef __HIP_PLATFORM_AMD__
#include <hip/hip_runtime.h>
#define CUDA_LAUNCH_KERNEL(kernel, grid, block, smem, stream, ...) \
    hipLaunchKernelGGL(kernel, grid, block, smem, stream, __VA_ARGS__)
#else
#include <cuda_runtime.h>
#define CUDA_LAUNCH_KERNEL(kernel, grid, block, smem, stream, ...) \
    kernel<<<grid, block, smem, stream>>>(__VA_ARGS__)
#endif

__global__ void add_kernel(const float* a, const float* b, float* c, int n) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) {
        c[idx] = a[idx] + b[idx];
    }
}

torch::Tensor add_tensors_cuda(torch::Tensor a, torch::Tensor b) {
    auto c = torch::zeros_like(a);
    int n = a.numel();

    const int block_size = 256;
    const int grid_size = (n + block_size - 1) / block_size;

    CUDA_LAUNCH_KERNEL(
        add_kernel,
        dim3(grid_size),
        dim3(block_size),
        0,
        0,
        a.data_ptr<float>(),
        b.data_ptr<float>(),
        c.data_ptr<float>(),
        n
    );

    return c;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("add_tensors", &add_tensors_cuda, "Add two tensors (CUDA/HIP)");
}
EOF

# Create Python test script
cat > test_extension.py << 'EOF'
import os
import sys
import torch
from torch.utils.cpp_extension import load

print("=== PyTorch Extension Load Test ===")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"Device count: {torch.cuda.device_count()}")

if hasattr(torch.version, 'hip') and torch.version.hip:
    print(f"HIP version: {torch.version.hip}")

print("\n=== Loading Extension ===")
print("This may take a while and will show compilation output...")
print("If this hangs, it indicates the same issue as build.py")

try:
    # Mimic the same load call as build.py
    simple_ext = load(
        name="simple_test_ext",
        sources=["simple_kernel.cu"],
        extra_cflags=["-O3", "-std=c++17"],
        extra_cuda_cflags=["-O3"],  # torch switches this to hipcc on ROCm
        verbose=True,
        is_python_module=False
    )
    print("✓ Extension compilation successful!")

    # Test the extension
    print("\n=== Testing Extension ===")
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    a = torch.randn(1000, device=device)
    b = torch.randn(1000, device=device)

    if device == 'cuda':
        result = simple_ext.add_tensors(a, b)
        expected = a + b
        if torch.allclose(result, expected):
            print("✓ Extension execution successful!")
        else:
            print("✗ Extension execution failed - results don't match")
    else:
        print("⚠ No CUDA device, skipping execution test")

except Exception as e:
    print(f"✗ Extension compilation/loading failed: {e}")
    import traceback
    traceback.print_exc()
EOF

echo "=== Running PyTorch Extension Test ==="
echo "This test mimics the same compilation process as build.py"
echo "If this hangs, it shows the same issue as the main build"
echo

# Set a timeout to prevent infinite hang
timeout 300 python3 test_extension.py || {
    exit_code=$?
    if [ $exit_code -eq 124 ]; then
        echo "✗ Extension compilation timed out after 5 minutes (same as build.py hang)"
    else
        echo "✗ Extension compilation failed with exit code $exit_code"
    fi
}

echo
echo "=== Testing with Minimal Sources ==="

# Create an even simpler version
cat > minimal_kernel.cu << 'EOF'
#include <torch/extension.h>

torch::Tensor dummy_function(torch::Tensor input) {
    return input.clone();
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("dummy", &dummy_function, "Dummy function");
}
EOF

cat > test_minimal.py << 'EOF'
import torch
from torch.utils.cpp_extension import load

print("=== Minimal Extension Test ===")

try:
    minimal_ext = load(
        name="minimal_test_ext",
        sources=["minimal_kernel.cu"],
        extra_cflags=["-O3"],
        verbose=True,
        with_cuda=False  # Skip CUDA/HIP compilation
    )
    print("✓ Minimal extension (CPU only) successful!")
except Exception as e:
    print(f"✗ Even minimal extension failed: {e}")
EOF

echo "Testing minimal CPU-only extension..."
timeout 120 python3 test_minimal.py || echo "Minimal extension also failed/timed out"

echo
echo "=== Debug Script 3 Complete ==="

# Cleanup
cd /
rm -rf /tmp/torch_ext_test
_dev/debug-build-4-megablocks.sh ADDED
@@ -0,0 +1,257 @@
#!/usr/bin/env bash

# Debug script 4: MegaBlocks-specific build debugging

set -euo pipefail

echo "=== MegaBlocks Build Debug Script 4 ==="
echo "Testing MegaBlocks-specific compilation components"
echo

# Set ROCm environment variables
export ROCM_PATH="${ROCM_PATH:-/opt/rocm-7.0.1}"
export ROCM_HOME="${ROCM_HOME:-$ROCM_PATH}"
export HIP_PATH="${HIP_PATH:-$ROCM_PATH}"
export HIP_HOME="${HIP_HOME:-$ROCM_PATH}"
export PATH="$ROCM_HOME/bin:$PATH"
export TORCH_HIP_ARCH_LIST="${TORCH_HIP_ARCH_LIST:-gfx942}"
export HSA_OVERRIDE_GFX_VERSION="${HSA_OVERRIDE_GFX_VERSION:-gfx942}"
export TORCH_EXTENSIONS_DIR="${TORCH_EXTENSIONS_DIR:-$PWD/.torch_extensions_debug}"

echo "Working directory: $(pwd)"
echo

echo "=== Checking MegaBlocks Source Files ==="
echo "Verifying all source files exist:"

sources=(
    "torch-ext/torch_binding.cpp"
    "csrc/new_cumsum.cu"
    "csrc/new_histogram.cu"
    "csrc/new_indices.cu"
    "csrc/new_replicate.cu"
    "csrc/new_sort.cu"
    "csrc/grouped_gemm/grouped_gemm.cu"
)

all_exist=true
for src in "${sources[@]}"; do
    if [ -f "$src" ]; then
        echo "✓ $src exists ($(wc -l < "$src") lines)"
    else
        echo "✗ $src missing"
        all_exist=false
    fi
done

if [ "$all_exist" = false ]; then
    echo "Cannot proceed - missing source files"
    exit 1
fi

echo
echo "=== Checking Include Directories ==="
if [ -d "csrc" ]; then
    echo "✓ csrc include directory exists"
    echo "Headers in csrc/:"
    find csrc -name "*.h" -o -name "*.hpp" | head -10
else
    echo "✗ csrc include directory missing"
fi

echo
echo "=== Testing Individual Source Compilation ==="

# Test compiling each .cu file individually
# Note: import the cpp_extension submodule explicitly; "import torch" alone
# does not reliably expose torch.utils.cpp_extension.include_paths().
for src in csrc/*.cu; do
    if [ -f "$src" ]; then
        echo "Testing compilation of $(basename "$src")..."
        if timeout 60 hipcc -c "$src" -o "/tmp/$(basename "$src" .cu).o" \
            --amdgpu-target=gfx942 \
            -I./csrc \
            -I"$(python3 -c 'import torch.utils.cpp_extension as ext; print(ext.include_paths()[0])')" \
            -std=c++17 \
            -O3 \
            -fPIC; then
            echo "✓ $(basename "$src") compiled successfully"
        else
            echo "✗ $(basename "$src") compilation failed"
        fi
    fi
done

echo
echo "=== Testing grouped_gemm.cu Specifically ==="
echo "This is often the most complex kernel..."

if timeout 120 hipcc -c csrc/grouped_gemm/grouped_gemm.cu -o /tmp/grouped_gemm.o \
    --amdgpu-target=gfx942 \
    -I./csrc \
    -I"$(python3 -c 'import torch.utils.cpp_extension as ext; print(ext.include_paths()[0])')" \
    -std=c++17 \
    -O3 \
    -fPIC \
    -lhipblaslt \
    -v; then
    echo "✓ grouped_gemm.cu compiled successfully"
else
    echo "✗ grouped_gemm.cu compilation failed"
fi

echo
echo "=== Testing torch_binding.cpp ==="
if timeout 60 hipcc -c torch-ext/torch_binding.cpp -o /tmp/torch_binding.o \
    -I./csrc \
    -I"$(python3 -c 'import torch.utils.cpp_extension as ext; print(ext.include_paths()[0])')" \
    -std=c++17 \
    -O3 \
    -fPIC; then
    echo "✓ torch_binding.cpp compiled successfully"
else
    echo "✗ torch_binding.cpp compilation failed"
fi

echo
echo "=== Testing Incremental PyTorch Extension Build ==="

cat > debug_build.py << 'EOF'
import os
import pathlib
import sys
import signal
import time
from torch.utils.cpp_extension import load

def timeout_handler(signum, frame):
    print("Build timed out - this indicates a hanging issue")
    sys.exit(1)

# Set up timeout
signal.signal(signal.SIGALRM, timeout_handler)
signal.alarm(180)  # 3 minute timeout

repo = pathlib.Path(".").resolve()
os.environ.setdefault("TORCH_EXTENSIONS_DIR", str(repo / ".torch_extensions_debug"))

print("=== Testing with Single Source File ===")
try:
    print("Building with just new_cumsum.cu...")
    mod = load(
        name="_megablocks_debug_single",
        sources=["csrc/new_cumsum.cu"],
        extra_include_paths=["csrc"],
        extra_cflags=["-O3", "-std=c++17"],
        extra_cuda_cflags=["-O3"],
        verbose=True,
        is_python_module=False,
    )
    print("✓ Single source build successful")
except Exception as e:
    print(f"✗ Single source build failed: {e}")

print("\n=== Testing with Two Source Files ===")
try:
    print("Building with new_cumsum.cu and new_histogram.cu...")
    mod = load(
        name="_megablocks_debug_double",
        sources=["csrc/new_cumsum.cu", "csrc/new_histogram.cu"],
        extra_include_paths=["csrc"],
        extra_cflags=["-O3", "-std=c++17"],
        extra_cuda_cflags=["-O3"],
        verbose=True,
        is_python_module=False,
    )
    print("✓ Double source build successful")
except Exception as e:
    print(f"✗ Double source build failed: {e}")

print("\n=== Testing with grouped_gemm.cu Only ===")
try:
    print("Building with just grouped_gemm.cu (most complex)...")
    mod = load(
        name="_megablocks_debug_gemm",
        sources=["csrc/grouped_gemm/grouped_gemm.cu"],
        extra_include_paths=["csrc"],
        extra_cflags=["-O3", "-std=c++17"],
        extra_cuda_cflags=["-O3"],
        extra_ldflags=["-lhipblaslt"],
        verbose=True,
        is_python_module=False,
    )
    print("✓ grouped_gemm build successful")
except Exception as e:
    print(f"✗ grouped_gemm build failed: {e}")

signal.alarm(0)  # Cancel timeout
EOF

echo "Running incremental build test..."
python3 debug_build.py

echo
echo "=== Testing Full Build with Timeout ==="

cat > debug_full_build.py << 'EOF'
import os
import pathlib
import sys
import signal
from torch.utils.cpp_extension import load

def timeout_handler(signum, frame):
    print("Full build timed out - this confirms the hanging issue")
    sys.exit(124)  # timeout exit code

# Set up 5 minute timeout
signal.signal(signal.SIGALRM, timeout_handler)
signal.alarm(300)

repo = pathlib.Path(".").resolve()
os.environ.setdefault("TORCH_EXTENSIONS_DIR", str(repo / ".torch_extensions_debug"))

sources = [
    "torch-ext/torch_binding.cpp",
    "csrc/new_cumsum.cu",
    "csrc/new_histogram.cu",
    "csrc/new_indices.cu",
    "csrc/new_replicate.cu",
    "csrc/new_sort.cu",
    "csrc/grouped_gemm/grouped_gemm.cu",
]

print("=== Attempting Full MegaBlocks Build ===")
print("This mimics the exact build.py process...")
print("Sources:", sources)

try:
    mod = load(
        name="_megablocks_debug_full",
        sources=sources,
        extra_include_paths=["csrc"],
        extra_cflags=["-O3", "-std=c++17"],
        extra_cuda_cflags=["-O3"],
        extra_ldflags=["-lhipblaslt"],
        verbose=True,
        is_python_module=False,
    )
    print("✓ Full build successful!")
    print("Built:", mod)

except Exception as e:
    print(f"✗ Full build failed: {e}")
    import traceback
    traceback.print_exc()

signal.alarm(0)
EOF

echo "Running full build test (with timeout)..."
python3 debug_full_build.py

echo
echo "=== Cleanup ==="
rm -f /tmp/*.o
rm -f debug_build.py debug_full_build.py

echo
echo "=== Debug Script 4 Complete ==="
_dev/debug-build-all.sh ADDED
@@ -0,0 +1,91 @@
#!/usr/bin/env bash

# Debug script: Run all debug tests in sequence

set -euo pipefail

echo "=== MegaBlocks Build Debugging Suite ==="
echo "Running progressive debug tests to identify build hang issue"
echo "ROCm installation: /opt/rocm-7.0.1"
echo

# Make all scripts executable
chmod +x debug-build-1-env.sh debug-build-2-hipcc.sh debug-build-3-torch-ext.sh debug-build-4-megablocks.sh

scripts=(
    "debug-build-1-env.sh"
    "debug-build-2-hipcc.sh"
    "debug-build-3-torch-ext.sh"
    "debug-build-4-megablocks.sh"
)

results=()
start_time=$(date +%s)

for script in "${scripts[@]}"; do
    echo
    echo "========================================"
    echo "Running $script"
    echo "========================================"

    script_start=$(date +%s)

    if ./"$script"; then
        script_end=$(date +%s)
        duration=$((script_end - script_start))
        echo "✓ $script completed successfully in ${duration}s"
        results+=("✓ $script: SUCCESS (${duration}s)")
    else
        script_end=$(date +%s)
        duration=$((script_end - script_start))
        echo "✗ $script failed in ${duration}s"
        results+=("✗ $script: FAILED (${duration}s)")
    fi

    echo "----------------------------------------"
done

end_time=$(date +%s)
total_duration=$((end_time - start_time))

echo
echo "========================================"
echo "SUMMARY REPORT"
echo "========================================"
echo "Total runtime: ${total_duration}s"
echo

for result in "${results[@]}"; do
    echo "$result"
done

echo
echo "=== Analysis ==="
echo "1. If debug-build-1-env.sh fails: ROCm installation/environment issue"
echo "2. If debug-build-2-hipcc.sh fails: HIP compiler issue"
echo "3. If debug-build-3-torch-ext.sh hangs: PyTorch extension compilation issue"
echo "4. If debug-build-4-megablocks.sh hangs: MegaBlocks-specific compilation issue"
echo
echo "=== Next Steps Based on Results ==="
echo "- If all pass: The issue may be intermittent or environment-specific"
echo "- If script 3 or 4 hangs: Run with strace to see where it hangs:"
echo "  strace -f -e trace=process,signal python3 build.py"
echo "- Check compilation log files in .torch_extensions for more details"
echo "- Consider using PYTORCH_JIT_LOG_LEVEL=1 for more verbose output"

echo
echo "=== Additional Debugging Commands ==="
echo "# Check for stuck processes:"
echo "ps aux | grep -E '(hipcc|hip-clang|python)'"
echo
echo "# Monitor system resources during build:"
echo "htop"
echo
echo "# Check for device issues:"
echo "dmesg | tail -20"
echo
echo "# Force clean rebuild:"
echo "rm -rf .torch_extensions* && ./build.sh"

echo
echo "Debug suite complete."
_dev/debug_build.py ADDED
@@ -0,0 +1,68 @@
import os
import pathlib
import sys
import signal
import time
from torch.utils.cpp_extension import load

def timeout_handler(signum, frame):
    print("Build timed out - this indicates a hanging issue")
    sys.exit(1)

# Set up timeout
signal.signal(signal.SIGALRM, timeout_handler)
signal.alarm(180)  # 3 minute timeout

repo = pathlib.Path(".").resolve()
os.environ.setdefault("TORCH_EXTENSIONS_DIR", str(repo / ".torch_extensions_debug"))

print("=== Testing with Single Source File ===")
try:
    print("Building with just new_cumsum.cu...")
    mod = load(
        name="_megablocks_debug_single",
        sources=["csrc/new_cumsum.cu"],
        extra_include_paths=["csrc"],
        extra_cflags=["-O3", "-std=c++17"],
        extra_cuda_cflags=["-O3"],
        verbose=True,
        is_python_module=False,
    )
    print("✓ Single source build successful")
except Exception as e:
    print(f"✗ Single source build failed: {e}")

print("\n=== Testing with Two Source Files ===")
try:
    print("Building with new_cumsum.cu and new_histogram.cu...")
    mod = load(
        name="_megablocks_debug_double",
        sources=["csrc/new_cumsum.cu", "csrc/new_histogram.cu"],
        extra_include_paths=["csrc"],
        extra_cflags=["-O3", "-std=c++17"],
        extra_cuda_cflags=["-O3"],
        verbose=True,
        is_python_module=False,
    )
    print("✓ Double source build successful")
except Exception as e:
    print(f"✗ Double source build failed: {e}")

print("\n=== Testing with grouped_gemm.cu Only ===")
try:
    print("Building with just grouped_gemm.cu (most complex)...")
    mod = load(
        name="_megablocks_debug_gemm",
        sources=["csrc/grouped_gemm/grouped_gemm.cu"],
        extra_include_paths=["csrc"],
        extra_cflags=["-O3", "-std=c++17"],
        extra_cuda_cflags=["-O3"],
        extra_ldflags=["-lhipblaslt"],
        verbose=True,
        is_python_module=False,
    )
    print("✓ grouped_gemm build successful")
except Exception as e:
    print(f"✗ grouped_gemm build failed: {e}")

signal.alarm(0)  # Cancel timeout
build-fixed.sh ADDED
@@ -0,0 +1,43 @@
#!/usr/bin/env bash

set -euo pipefail

# Fixed build script with proper ROCm/HIP environment

echo "=== Fixed Build Script ==="
echo "Configuring environment for ROCm 7.0.1 with proper exports"

# Default to the ROCm 7.0.1 install unless the caller overrides it.
export ROCM_PATH="${ROCM_PATH:-/opt/rocm-7.0.1}"
export ROCM_HOME="${ROCM_HOME:-$ROCM_PATH}"
export HIP_PATH="${HIP_PATH:-$ROCM_PATH}"
export HIP_HOME="${HIP_HOME:-$ROCM_PATH}"

export PATH="$ROCM_HOME/bin:$PATH"

# Fix architecture specifications - use gfx942 consistently
export TORCH_HIP_ARCH_LIST="gfx942"
export PYTORCH_ROCM_ARCH="gfx942"

# Remove HSA_OVERRIDE_GFX_VERSION - not needed since MI300X is already gfx942
unset HSA_OVERRIDE_GFX_VERSION

# Force single-threaded compilation to avoid ninja hanging
export MAX_JOBS=1

# Enable PyTorch JIT logging for debugging
export PYTORCH_JIT_LOG_LEVEL=1

export TORCH_EXTENSIONS_DIR="${TORCH_EXTENSIONS_DIR:-$PWD/.torch_extensions}"

echo "Environment configured:"
echo "ROCM_PATH=$ROCM_PATH"
echo "TORCH_HIP_ARCH_LIST=$TORCH_HIP_ARCH_LIST"
echo "PYTORCH_ROCM_ARCH=$PYTORCH_ROCM_ARCH"
echo "MAX_JOBS=$MAX_JOBS"
echo "PYTORCH_JIT_LOG_LEVEL=$PYTORCH_JIT_LOG_LEVEL"
echo "HSA_OVERRIDE_GFX_VERSION=${HSA_OVERRIDE_GFX_VERSION:-unset}"
echo

echo "Starting build..."
python -u build.py
build-strace.sh ADDED
@@ -0,0 +1,54 @@
#!/usr/bin/env bash

set -euo pipefail

# Build script with strace to debug hanging

echo "=== Build with strace debugging ==="
echo "This will trace system calls to identify where the build hangs"

# Same environment as build-fixed.sh but without MAX_JOBS limit
export ROCM_PATH="${ROCM_PATH:-/opt/rocm-7.0.1}"
export ROCM_HOME="${ROCM_HOME:-$ROCM_PATH}"
export HIP_PATH="${HIP_PATH:-$ROCM_PATH}"
export HIP_HOME="${HIP_HOME:-$ROCM_PATH}"

export PATH="$ROCM_HOME/bin:$PATH"

# Fix architecture specifications
export TORCH_HIP_ARCH_LIST="gfx942"
export PYTORCH_ROCM_ARCH="gfx942"

# Remove HSA_OVERRIDE_GFX_VERSION
unset HSA_OVERRIDE_GFX_VERSION

# Remove MAX_JOBS limit to see parallel compilation hang
unset MAX_JOBS

# Enable PyTorch JIT logging
export PYTORCH_JIT_LOG_LEVEL=1

export TORCH_EXTENSIONS_DIR="${TORCH_EXTENSIONS_DIR:-$PWD/.torch_extensions}"

echo "Environment configured for strace:"
echo "ROCM_PATH=$ROCM_PATH"
echo "TORCH_HIP_ARCH_LIST=$TORCH_HIP_ARCH_LIST"
echo "PYTORCH_ROCM_ARCH=$PYTORCH_ROCM_ARCH"
echo "MAX_JOBS=${MAX_JOBS:-unset}"
echo "PYTORCH_JIT_LOG_LEVEL=$PYTORCH_JIT_LOG_LEVEL"
echo

echo "Starting build with strace..."
echo "Tracing process creation, signals, and file operations..."
echo "Output will be saved to strace.log"

# Use strace to trace the build process
# -f: follow child processes
# -e trace=process,signal: trace process creation and signals
# -e trace=file: trace file operations
# -o strace.log: save output to file
# -T: show time spent in each syscall
strace -f -e trace=process,signal,file -o strace.log -T python -u build.py

echo "Build completed or interrupted"
echo "Check strace.log for detailed system call trace"
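
Once strace.log exists, the quickest way to see where the build stalled is to look at the end of the trace and at the last processes spawned or awaited. A small follow-up sketch, not part of this commit; the grep patterns are illustrative:

    # Inspect the trace written by build-strace.sh above.
    tail -n 40 strace.log                                  # last syscalls before the hang
    grep -E 'execve|clone|wait4' strace.log | tail -n 20   # last compiler processes launched/awaited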
build.sh CHANGED
@@ -1,9 +1,39 @@
#!/usr/bin/env bash

set -euo pipefail

# 3-4min build with cleanup to prevent lock file hangs

echo "=== MegaBlocks Build Script ==="
echo "Cleaning up any previous build processes and lock files..."

# Kill any hanging build processes
echo "Killing any running build.py processes..."
pkill -f "python.*build\.py" 2>/dev/null || true
pkill -f "ninja" 2>/dev/null || true
pkill -f "hipcc" 2>/dev/null || true

# Wait a moment for processes to terminate
sleep 2

# Clean up lock files that cause infinite loops
echo "Removing stale lock files..."
if [ -d ".torch_extensions" ]; then
    find .torch_extensions -name "lock" -delete 2>/dev/null || true
    find .torch_extensions -name ".ninja_lock" -delete 2>/dev/null || true
fi

# Default to the ROCm 7.0.1 install unless the caller overrides it.
export ROCM_PATH="${ROCM_PATH:-/opt/rocm-7.0.1}"
export ROCM_HOME="${ROCM_HOME:-$ROCM_PATH}"
export HIP_PATH="${HIP_PATH:-$ROCM_PATH}"
export HIP_HOME="${HIP_HOME:-$ROCM_PATH}"

export PATH="$ROCM_HOME/bin:$PATH"
export LD_LIBRARY_PATH="$ROCM_HOME/lib:$ROCM_HOME/lib64:${LD_LIBRARY_PATH:-}"
export TORCH_HIP_ARCH_LIST="${TORCH_HIP_ARCH_LIST:-gfx942}"
export HSA_OVERRIDE_GFX_VERSION="${HSA_OVERRIDE_GFX_VERSION:-gfx942}"
export TORCH_EXTENSIONS_DIR="${TORCH_EXTENSIONS_DIR:-$PWD/.torch_extensions}"

echo "Environment configured. Starting build..."
python -u build.py
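
The lock cleanup above matters because torch's JIT extension loader serializes builds with a file baton, a file literally named "lock" inside each extension's build directory, and a stale one left behind by a killed build makes the next load() wait indefinitely. A quick manual check, illustrative only and not part of this commit:

    # List any leftover baton files before retrying a build.
    find .torch_extensions -name lock -print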
run-tests.sh ADDED
@@ -0,0 +1,138 @@
#!/usr/bin/env bash
set -euo pipefail

KERNEL_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
cd "$KERNEL_DIR"

export KERNEL_DIR

detect_variant() {
    python - <<'PY'
import os
import pathlib

root = pathlib.Path(os.environ["KERNEL_DIR"])
build_dir = root / "build"
variant = None

try:
    from kernels.utils import build_variant as _build_variant
except Exception:
    _build_variant = None

if _build_variant is not None:
    try:
        variant = _build_variant()
    except Exception:
        variant = None

if variant is None:
    # glob() returns a lazy iterator, so materialize each pattern before
    # applying the fallback; otherwise the CUDA pattern is never tried.
    candidates = sorted(build_dir.glob("torch*-rocm64-*")) or sorted(build_dir.glob("torch*-cu*"))
    if candidates:
        variant = candidates[0].name

if variant is None:
    raise SystemExit("Could not determine MegaBlocks build variant. Run build.py first.")

print(variant)
PY
}

VARIANT=$(detect_variant)

STAGED_DIR="$KERNEL_DIR/build/$VARIANT"
find_staged_lib() {
    local base="$1"
    local candidates=(
        "$base/_megablocks_rocm.so"
        "$base/megablocks/_megablocks_rocm.so"
    )
    for path in "${candidates[@]}"; do
        if [[ -f "$path" ]]; then
            echo "$path"
            return 0
        fi
    done
    return 1
}

STAGED_LIB=$(find_staged_lib "$STAGED_DIR") || true

if [[ -z "${STAGED_LIB:-}" ]]; then
    echo "Staged ROCm extension not found under $STAGED_DIR; rebuilding kernels..."
    python build.py
    VARIANT=$(detect_variant)
    STAGED_DIR="$KERNEL_DIR/build/$VARIANT"
    STAGED_LIB=$(find_staged_lib "$STAGED_DIR") || true
    if [[ -z "${STAGED_LIB:-}" ]]; then
        echo "ERROR: build.py completed but no extension was found under $STAGED_DIR" >&2
        exit 1
    fi
fi

export PYTHONPATH="$STAGED_DIR:${PYTHONPATH:-}"

echo "Using MegaBlocks build variant: $VARIANT"

declare -i GPU_COUNT
GPU_COUNT=$(python - <<'PY'
import torch
print(torch.cuda.device_count() if torch.cuda.is_available() else 0)
PY
)

if (( GPU_COUNT == 0 )); then
    echo "ERROR: No HIP/CUDA GPUs detected. Tests require at least one visible accelerator." >&2
    exit 1
fi

echo "Detected $GPU_COUNT visible GPU(s)."

log() {
    echo
    echo "==> $1"
}

run_pytest() {
    local label="$1"
    shift
    log "$label"
    set -x
    "$@"
    { set +x; } 2>/dev/null || true
}

SINGLE_GPU_ENV=(HIP_VISIBLE_DEVICES=0 CUDA_VISIBLE_DEVICES=0 WORLD_SIZE=1)
MULTI2_GPU_ENV=(HIP_VISIBLE_DEVICES=0,1 CUDA_VISIBLE_DEVICES=0,1 WORLD_SIZE=2)
MULTI8_GPU_ENV=(HIP_VISIBLE_DEVICES=$(seq -s, 0 7) CUDA_VISIBLE_DEVICES=$(seq -s, 0 7) WORLD_SIZE=8)

SINGLE_TESTS=(
    "test_mb_moe.py"
    "test_mb_moe_shared_expert.py"
    "layer_test.py"
    "test_gg.py"
    "ops_test.py"
)

for test in "${SINGLE_TESTS[@]}"; do
    run_pytest "Single-GPU pytest ${test}" env "${SINGLE_GPU_ENV[@]}" python -m pytest "tests/${test}" -q
done

if (( GPU_COUNT >= 2 )); then
    run_pytest "Distributed layer smoke (2 GPUs)" env "${MULTI2_GPU_ENV[@]}" python -m pytest "tests/parallel_layer_test.py::test_megablocks_moe_mlp_functionality" -q
else
    log "Skipping 2-GPU distributed layer test (requires >=2 GPUs, detected ${GPU_COUNT})."
fi

run_pytest "Shared expert functionality (world_size=1)" env "${SINGLE_GPU_ENV[@]}" python -m pytest 'tests/test_mb_moe_shared_expert_multi.py::test_shared_expert_distributed_functionality[1]' -q
run_pytest "Shared expert weighted sum (world_size=1)" env "${SINGLE_GPU_ENV[@]}" python -m pytest 'tests/test_mb_moe_shared_expert_multi.py::test_shared_expert_distributed_weighted_sum[1]' -q

if (( GPU_COUNT >= 8 )); then
    run_pytest "Shared expert functionality (world_size=8)" env "${MULTI8_GPU_ENV[@]}" python -m pytest 'tests/test_mb_moe_shared_expert_multi.py::test_shared_expert_distributed_functionality[8]' -q
    run_pytest "Shared expert weighted sum (world_size=8)" env "${MULTI8_GPU_ENV[@]}" python -m pytest 'tests/test_mb_moe_shared_expert_multi.py::test_shared_expert_distributed_weighted_sum[8]' -q
else
    log "Skipping 8-GPU shared expert tests (requires >=8 GPUs, detected ${GPU_COUNT})."
fi

echo
echo "All requested tests completed."