Spaces:
Runtime error
Runtime error
#!/usr/bin/env bash
# scripts/bootstrap_venvs.sh
# Dual-venv bootstrap for Synesthesia ROCm 7.2.1
# Creates .venv-jax and .venv-torch with pinned wheels.
#
# Usage: scripts/bootstrap_venvs.sh [--force]
#   --force             Delete and recreate both venvs.
# Environment:
#   SKIP_ROCM_CHECK=1   Skip the ROCm presence check and the GPU-visibility
#                       checks on each venv.
#
# Bash is required (the script uses [[ ]] tests).
set -euo pipefail

# --- Configuration ---
FORCE=0
# "${1:-}" so that running with no arguments does not trip `set -u`.
if [[ "${1:-}" == "--force" ]]; then
  FORCE=1
fi
# Give SKIP_ROCM_CHECK an explicit value so every later "$SKIP_ROCM_CHECK"
# reference is safe under `set -u`. Only the exact value "1" skips checks,
# so defaulting to "0" keeps the unset case behaving exactly as before.
SKIP_ROCM_CHECK="${SKIP_ROCM_CHECK:-0}"
# Print an informational message to stdout, prefixed with a bold-blue
# [BOOTSTRAP] tag. All arguments are joined with single spaces.
# The message goes through printf '%s' so backslashes in it are printed
# literally (`echo -e` would turn e.g. "\t" or "\n" into control characters).
log() {
  printf '\033[1;34m[BOOTSTRAP]\033[0m %s\n' "$*"
}
# Print an error message to stderr, prefixed with a bold-red [ERROR] tag.
# All arguments are joined with single spaces; backslashes in the message
# are printed literally (printf '%s' instead of `echo -e`).
error() {
  printf '\033[1;31m[ERROR]\033[0m %s\n' "$*" >&2
}
# --- 1. ROCm 7.2.1 Presence Check ---
# We ensure the hardware environment is correct before proceeding.
#
# Install root: $ROCM_PATH if set, otherwise /opt/rocm. The root directory
# and rocm-smi on PATH are hard requirements. The installed version is read
# from <root>/.info/version when that file is readable; a mismatch or an
# unreadable file only produces a warning, so nearby point releases are not
# rejected outright.
ROCM_ROOT="${ROCM_PATH:-/opt/rocm}"
ROCM_EXPECTED_VERSION="${ROCM_EXPECTED_VERSION:-7.2.1}"
if [[ "${SKIP_ROCM_CHECK:-}" != "1" ]]; then
  log "Checking for ROCm $ROCM_EXPECTED_VERSION..."
  if [[ ! -d "$ROCM_ROOT" ]]; then
    error "ROCm not found at $ROCM_ROOT. Please install ROCm $ROCM_EXPECTED_VERSION."
    exit 1
  fi
  if ! command -v rocm-smi &> /dev/null; then
    error "rocm-smi not found. Ensure ROCm binaries are in your PATH."
    exit 1
  fi

  rocm_version_file="$ROCM_ROOT/.info/version"
  rocm_version=""
  if [[ -r "$rocm_version_file" ]]; then
    # `read` returns non-zero when the file has no trailing newline even
    # though it still fills the variable; `|| true` keeps `set -e` quiet.
    IFS= read -r rocm_version < "$rocm_version_file" || true
  fi

  if [[ -z "$rocm_version" ]]; then
    log "WARNING: could not read ROCm version from $rocm_version_file; version not verified."
  elif [[ "$rocm_version" == "$ROCM_EXPECTED_VERSION" || "$rocm_version" == "$ROCM_EXPECTED_VERSION"[!0-9]* ]]; then
    # Exact match, or the expected version followed by a non-digit suffix
    # such as "-<build>" (so 7.2.1 does not accidentally match 7.2.10).
    log "ROCm $ROCM_EXPECTED_VERSION check passed (found $rocm_version)."
  else
    log "WARNING: expected ROCm $ROCM_EXPECTED_VERSION but $rocm_version_file reports $rocm_version. Continuing."
  fi
else
  log "Skipping ROCm check (SKIP_ROCM_CHECK=1)."
fi
# --- 2. Venv Management Functions ---
# Check whether the JAX installed in a venv can see a GPU.
# Arguments: $1 - venv root directory; its bin/python runs the probe.
# Outputs:   a log line, plus the device list the probe prints to stdout.
# Returns:   0 when some device reports platform 'gpu' or 'rocm'; 1 on any
#            other outcome (import failure included). The probe's stderr
#            is discarded.
verify_jax() {
  local target_venv=$1
  local jax_probe="import jax; devices = jax.devices(); print(f'Devices: {devices}'); exit(0) if any(d.platform == 'gpu' or d.platform == 'rocm' for d in devices) else exit(1)"

  log "Verifying GPU visibility in $target_venv..."
  "$target_venv/bin/python" -c "$jax_probe" 2>/dev/null || return 1
  return 0
}
# Check whether the PyTorch installed in a venv reports a usable GPU
# (torch.cuda.is_available(), the API the probe queries).
# Arguments: $1 - venv root directory; its bin/python runs the probe.
# Outputs:   a log line, plus the availability line the probe prints.
# Returns:   0 when the probe exits successfully, otherwise 1. The probe's
#            stderr is discarded.
verify_torch() {
  local venv_dir=$1
  local torch_probe="import torch; print(f'CUDA Available: {torch.cuda.is_available()}'); exit(0) if torch.cuda.is_available() else exit(1)"

  log "Verifying GPU visibility in $venv_dir..."
  if ! "$venv_dir/bin/python" -c "$torch_probe" 2>/dev/null; then
    return 1
  fi
  return 0
}
# --- 3. Bootstrap .venv-jax ---
# This venv is used for JAX-based inference and IREE/ONNX model exports.
#
# An existing venv is reused only if it contains the completion marker that
# is written at the very end of a successful build AND (unless
# SKIP_ROCM_CHECK=1) its JAX reports a GPU device. Without the marker, a
# venv left half-installed by a failed or interrupted pip run would be
# accepted on the next SKIP_ROCM_CHECK=1 invocation. Venvs built before the
# marker existed are rebuilt once.
#
# Package lists are bash arrays, not backslash-continued lines: a lost
# trailing backslash leaves `pip install` with no arguments and runs every
# following package name as a separate (failing) command.
VENV_JAX=".venv-jax"
VENV_JAX_MARKER="$VENV_JAX/.bootstrap-complete"

if [[ "$FORCE" -eq 1 ]]; then
  log "Force flag detected. Deleting $VENV_JAX..."
  rm -rf -- "$VENV_JAX"
fi

if [[ -f "$VENV_JAX_MARKER" ]] && { [[ "${SKIP_ROCM_CHECK:-}" == "1" ]] || verify_jax "$VENV_JAX"; }; then
  log "$VENV_JAX already exists and passes verification. Skipping."
else
  log "Creating $VENV_JAX..."
  rm -rf -- "$VENV_JAX"
  python3 -m venv "$VENV_JAX"
  "$VENV_JAX/bin/pip" install --upgrade pip setuptools wheel

  # Install JAX with ROCm 7.2 wheels via the JAX ROCm release index.
  # NOTE(review): jax refuses to import when the installed jaxlib is older
  # than the minimum it supports, and jaxlib 0.4.19 is many releases behind
  # jax 0.4.30, so this default pair looks incompatible. The specs can be
  # overridden with JAX_SPEC / JAXLIB_SPEC. TODO: confirm which ROCm 7.2
  # jaxlib build actually exists and align the defaults.
  jax_pip_args=(
    "${JAX_SPEC:-jax==0.4.30}"
    "${JAXLIB_SPEC:-jaxlib[rocm]==0.4.19+rocm7.2}"
    --find-links https://storage.googleapis.com/jax-releases/jax_rocm_releases.html
    --extra-index-url https://pypi.org/simple
  )
  log "Installing JAX-ROCm 7.2 wheels from GitHub..."
  "$VENV_JAX/bin/pip" install "${jax_pip_args[@]}"

  # Install TensorFlow-ROCm and conversion tools.
  # NOTE(review): flax/optax are unpinned; if their latest releases need a
  # newer jax, pip may upgrade jax here and undo the pin above. Confirm.
  tf_pip_args=(
    --find-links https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/
    tensorflow-rocm==2.16.1
    tf2jax==0.3.8
    tf2onnx
    iree-compiler
    iree-runtime
    optax
    flax
    huggingface_hub
  )
  log "Installing TensorFlow-ROCm and conversion tools..."
  "$VENV_JAX/bin/pip" install "${tf_pip_args[@]}"

  if [[ "${SKIP_ROCM_CHECK:-}" != "1" ]] && ! verify_jax "$VENV_JAX"; then
    error ".venv-jax verification FAILED: GPU not visible."
    exit 1
  fi
  # Only a fully installed (and, when checks are enabled, verified) venv
  # gets the marker that allows it to be reused on later runs.
  touch -- "$VENV_JAX_MARKER"
  log ".venv-jax PASS"
fi
# --- 4. Bootstrap .venv-torch ---
# This venv is used for PyTorch-based training and fine-tuning (Gemma 3, etc.)
#
# An existing venv is reused only if it contains the completion marker that
# is written at the very end of a successful build AND (unless
# SKIP_ROCM_CHECK=1) its PyTorch reports a GPU. Without the marker, a venv
# left half-installed by a failed or interrupted pip run would be accepted
# on the next SKIP_ROCM_CHECK=1 invocation. Venvs built before the marker
# existed are rebuilt once.
#
# Package lists are bash arrays, not backslash-continued lines: a lost
# trailing backslash leaves `pip install` with no arguments and runs every
# following package name as a separate (failing) command.
VENV_TORCH=".venv-torch"
VENV_TORCH_MARKER="$VENV_TORCH/.bootstrap-complete"

if [[ "$FORCE" -eq 1 ]]; then
  log "Force flag detected. Deleting $VENV_TORCH..."
  rm -rf -- "$VENV_TORCH"
fi

if [[ -f "$VENV_TORCH_MARKER" ]] && { [[ "${SKIP_ROCM_CHECK:-}" == "1" ]] || verify_torch "$VENV_TORCH"; }; then
  log "$VENV_TORCH already exists and passes verification. Skipping."
else
  log "Creating $VENV_TORCH..."
  rm -rf -- "$VENV_TORCH"
  python3 -m venv "$VENV_TORCH"
  "$VENV_TORCH/bin/pip" install --upgrade pip setuptools wheel

  # Install PyTorch with ROCm 7.2 (latest for performance and security)
  torch_pip_args=(
    --index-url https://download.pytorch.org/whl/rocm7.2
    torch
    torchvision
    torchaudio
  )
  log "Installing PyTorch-ROCm 7.2..."
  "$VENV_TORCH/bin/pip" install "${torch_pip_args[@]}"

  # Install HuggingFace stack and utilities.
  # NOTE(review): this step uses the default index only. If any package here
  # requires a torch version the ROCm build above does not satisfy, pip may
  # replace it with a non-ROCm wheel from PyPI. Also, AMD's install docs
  # fetch onnxruntime-rocm from repo.radeon.com via --find-links; confirm it
  # resolves from the default index.
  hf_pip_args=(
    transformers
    trl
    peft
    "bitsandbytes>=0.43"
    accelerate
    onnx
    onnxruntime-rocm
    tensorboard
    trackio
    huggingface_hub
    streamlit
    rich
    python-dotenv
  )
  log "Installing HuggingFace stack and utilities..."
  "$VENV_TORCH/bin/pip" install "${hf_pip_args[@]}"

  if [[ "${SKIP_ROCM_CHECK:-}" != "1" ]] && ! verify_torch "$VENV_TORCH"; then
    error ".venv-torch verification FAILED: GPU not visible."
    exit 1
  fi
  # Only a fully installed (and, when checks are enabled, verified) venv
  # gets the marker that allows it to be reused on later runs.
  touch -- "$VENV_TORCH_MARKER"
  log ".venv-torch PASS"
fi
# Final status. GPU verification is bypassed entirely when
# SKIP_ROCM_CHECK=1, so only claim "verified" when the checks actually ran.
if [[ "${SKIP_ROCM_CHECK:-}" == "1" ]]; then
  log "Dual-venv bootstrap complete. GPU verification skipped (SKIP_ROCM_CHECK=1)."
else
  log "Dual-venv bootstrap complete. Both venvs verified."
fi
exit 0