Synesthesia / scripts /bootstrap_venvs.sh
Ashiedu's picture
Sync unified workbench
0490201 verified
#!/bin/bash
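# scripts/bootstrap_venvs.sh
# Dual-venv bootstrap for Synesthesia on ROCm 7.2.1.
# Creates .venv-jax and .venv-torch (relative to the current working
# directory) and installs their wheels; some packages are version-pinned.
#
# Usage: scripts/bootstrap_venvs.sh [--force]
#   --force            delete and recreate both venvs
#   SKIP_ROCM_CHECK=1  skip the ROCm presence check and GPU verification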
set -e
# --- Configuration ---
# FORCE=1 deletes and recreates both venvs even if they already exist.
FORCE=0

# Parse command-line flags into globals.
# Globals:   FORCE (written)
# Arguments: the script's arguments; only --force is recognised.
# Unknown arguments are reported on stderr (so a typo such as "--froce" is not
# silently ignored) but do not abort the run.
parse_args() {
  local arg
  for arg in "$@"; do
    case "$arg" in
      --force)
        FORCE=1
        ;;
      *)
        printf '\033[1;33m[WARN]\033[0m Ignoring unknown argument: %s\n' "$arg" >&2
        ;;
    esac
  done
}
parse_args "$@"
# Print an informational message to stdout with a blue [BOOTSTRAP] prefix.
# Arguments: message words (joined with single spaces).
# The message is printed verbatim: backslash sequences inside it are NOT
# interpreted, unlike `echo -e`.
log() {
  printf '\033[1;34m[BOOTSTRAP]\033[0m %s\n' "$*"
}
# Print an error message to stderr with a red [ERROR] prefix.
# Arguments: message words (joined with single spaces).
# The message is printed verbatim: backslash sequences inside it are NOT
# interpreted, unlike `echo -e`.
error() {
  printf '\033[1;31m[ERROR]\033[0m %s\n' "$*" >&2
}
# --- 1. ROCm 7.2.1 Presence Check ---
# Ensure a ROCm install is present before spending time creating venvs.
# The install root defaults to /opt/rocm and can be overridden via ROCM_PATH.
# A version other than 7.2.x only produces a warning, because the version file
# layout is not guaranteed; a missing install or missing SMI tool is fatal.
# Set SKIP_ROCM_CHECK=1 to bypass this section entirely.
if [[ "$SKIP_ROCM_CHECK" != "1" ]]; then
  rocm_root="${ROCM_PATH:-/opt/rocm}"
  log "Checking for ROCm 7.2.1..."
  if [[ ! -d "$rocm_root" ]]; then
    error "ROCm not found at $rocm_root. Please install ROCm 7.2.1."
    exit 1
  fi
  # amd-smi is accepted as an alternative to rocm-smi.
  if ! command -v rocm-smi &> /dev/null && ! command -v amd-smi &> /dev/null; then
    error "Neither rocm-smi nor amd-smi found. Ensure ROCm binaries are in your PATH."
    exit 1
  fi
  # NOTE(review): assumes the install records its release in .info/version
  # (presumably "7.2.1-<build>") — confirm on the target machines.
  rocm_version=""
  if [[ -r "$rocm_root/.info/version" ]]; then
    # `|| true`: read returns non-zero when the file lacks a trailing newline.
    IFS= read -r rocm_version < "$rocm_root/.info/version" || true
  fi
  if [[ -z "$rocm_version" ]]; then
    printf '\033[1;33m[WARN]\033[0m %s\n' "Could not determine ROCm version from $rocm_root/.info/version; expected 7.2.1." >&2
  elif [[ "$rocm_version" != 7.2.* ]]; then
    printf '\033[1;33m[WARN]\033[0m %s\n' "Found ROCm $rocm_version at $rocm_root; this script targets 7.2.1." >&2
  fi
  log "ROCm check passed (${rocm_version:-unknown version})."
else
  log "Skipping ROCm check (SKIP_ROCM_CHECK=1)."
fi
# --- 2. Venv Management Functions ---
# verify_jax VENV_DIR
#   Runs VENV_DIR/bin/python and asks JAX for its device list.
#   Returns 0 if any device reports platform 'gpu' or 'rocm', otherwise 1.
#   The device list is printed on stdout; Python's stderr is discarded.
verify_jax() {
  local venv_path=$1
  local interpreter="$venv_path/bin/python"
  local probe="import jax; devices = jax.devices(); print(f'Devices: {devices}'); exit(0) if any(d.platform == 'gpu' or d.platform == 'rocm' for d in devices) else exit(1)"

  log "Verifying GPU visibility in $venv_path..."
  "$interpreter" -c "$probe" 2>/dev/null || return 1
  return 0
}
# verify_torch VENV_DIR
#   Runs VENV_DIR/bin/python and checks torch.cuda.is_available() (on ROCm
#   builds of PyTorch this is presumably backed by HIP devices).
#   Returns 0 when a device is available, otherwise 1.
#   The availability flag is printed on stdout; Python's stderr is discarded.
verify_torch() {
  local venv_path=$1
  local interpreter="$venv_path/bin/python"
  local probe="import torch; print(f'CUDA Available: {torch.cuda.is_available()}'); exit(0) if torch.cuda.is_available() else exit(1)"

  log "Verifying GPU visibility in $venv_path..."
  "$interpreter" -c "$probe" 2>/dev/null || return 1
  return 0
}
# --- 3. Bootstrap .venv-jax ---
# This venv is used for JAX-based inference and IREE/ONNX model exports.
# An existing venv is reused when it verifies (or when SKIP_ROCM_CHECK=1);
# otherwise it is deleted and rebuilt from scratch.
VENV_JAX=".venv-jax"
if [[ $FORCE -eq 1 ]]; then
  log "Force flag detected. Deleting $VENV_JAX..."
  rm -rf -- "${VENV_JAX:?}"
fi
if [[ -d "$VENV_JAX" ]] && { [[ "$SKIP_ROCM_CHECK" == "1" ]] || verify_jax "$VENV_JAX"; }; then
  log "$VENV_JAX already exists and passes verification. Skipping."
else
  log "Creating $VENV_JAX..."
  rm -rf -- "${VENV_JAX:?}"
  python3 -m venv "$VENV_JAX"
  "$VENV_JAX/bin/pip" install --upgrade pip setuptools wheel

  # Install JAX plus the ROCm jaxlib from the JAX wheel index on
  # storage.googleapis.com. Arguments are kept in arrays so that each
  # package/flag stays on its own line without fragile `\` continuations.
  # NOTE(review): jax==0.4.30 is paired with jaxlib 0.4.19; jax normally
  # refuses to import with a jaxlib older than its minimum supported
  # version. Confirm these pins and that the index hosts the rocm7.2 wheel.
  log "Installing JAX-ROCm 7.2 wheels..."
  jax_install_args=(
    jax==0.4.30
    "jaxlib[rocm]==0.4.19+rocm7.2"
    --find-links https://storage.googleapis.com/jax-releases/jax_rocm_releases.html
    --extra-index-url https://pypi.org/simple
  )
  "$VENV_JAX/bin/pip" install "${jax_install_args[@]}"

  # Install TensorFlow-ROCm and conversion tools.
  # NOTE(review): flax/optax are unpinned. If the latest releases require a
  # newer jax, pip may upgrade jax past the pin above; consider pinning them.
  log "Installing TensorFlow-ROCm and conversion tools..."
  tf_install_args=(
    --find-links https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/
    tensorflow-rocm==2.16.1
    tf2jax==0.3.8
    tf2onnx
    iree-compiler
    iree-runtime
    optax
    flax
    huggingface_hub
  )
  "$VENV_JAX/bin/pip" install "${tf_install_args[@]}"

  if [[ "$SKIP_ROCM_CHECK" != "1" ]] && ! verify_jax "$VENV_JAX"; then
    error ".venv-jax verification FAILED: GPU not visible."
    exit 1
  fi
  log ".venv-jax PASS"
fi
# --- 4. Bootstrap .venv-torch ---
# This venv is used for PyTorch-based training and fine-tuning (Gemma 3, etc.)
# An existing venv is reused when it verifies (or when SKIP_ROCM_CHECK=1);
# otherwise it is deleted and rebuilt from scratch.
VENV_TORCH=".venv-torch"
if [[ $FORCE -eq 1 ]]; then
  log "Force flag detected. Deleting $VENV_TORCH..."
  rm -rf -- "${VENV_TORCH:?}"
fi
if [[ -d "$VENV_TORCH" ]] && { [[ "$SKIP_ROCM_CHECK" == "1" ]] || verify_torch "$VENV_TORCH"; }; then
  log "$VENV_TORCH already exists and passes verification. Skipping."
else
  log "Creating $VENV_TORCH..."
  rm -rf -- "${VENV_TORCH:?}"
  python3 -m venv "$VENV_TORCH"
  "$VENV_TORCH/bin/pip" install --upgrade pip setuptools wheel

  # Install PyTorch built for ROCm 7.2. The packages are unpinned, so pip
  # picks the newest builds on that index. --index-url replaces PyPI for this
  # call. Arguments are kept in arrays so that each package/flag stays on its
  # own line without fragile `\` continuations.
  log "Installing PyTorch-ROCm 7.2..."
  torch_install_args=(
    --index-url https://download.pytorch.org/whl/rocm7.2
    torch
    torchvision
    torchaudio
  )
  "$VENV_TORCH/bin/pip" install "${torch_install_args[@]}"

  # Install HuggingFace stack and utilities from PyPI.
  # NOTE(review): if any of these requires a torch version the ROCm build
  # above does not satisfy, pip may swap in a PyPI (non-ROCm) torch, and the
  # verification below would then fail. Consider a constraints file (-c).
  log "Installing HuggingFace stack and utilities..."
  hf_install_args=(
    transformers
    trl
    peft
    "bitsandbytes>=0.43"
    accelerate
    onnx
    onnxruntime-rocm
    tensorboard
    trackio
    huggingface_hub
    streamlit
    rich
    python-dotenv
  )
  "$VENV_TORCH/bin/pip" install "${hf_install_args[@]}"

  if [[ "$SKIP_ROCM_CHECK" != "1" ]] && ! verify_torch "$VENV_TORCH"; then
    error ".venv-torch verification FAILED: GPU not visible."
    exit 1
  fi
  log ".venv-torch PASS"
fi
# Final summary. GPU verification is skipped when SKIP_ROCM_CHECK=1, so only
# claim the venvs were verified when the check was actually enabled.
if [[ "$SKIP_ROCM_CHECK" == "1" ]]; then
  log "Dual-venv bootstrap complete (GPU verification skipped: SKIP_ROCM_CHECK=1)."
else
  log "Dual-venv bootstrap complete. Both venvs verified."
fi
exit 0