# Synesthesia / scripts/quick_env_check.py
# Last commit: 0490201 "Sync unified workbench" (verified) by Ashiedu
#!/usr/bin/env python3
"""Quick environment verification for Synesthesia ROCm stack.
Checks:
1. ROCm SMI visibility
2. env.py imports and exports correct vars
3. PyTorch GPU detection
4. HF_TOKEN is set
5. JAX ROCm visibility (optional)
Returns exit code 0 only if all critical checks pass.
"""
import os
import subprocess
import sys
from pathlib import Path
def check_rocm_smi():
    """Check if ROCm SMI detects the GPU.

    Runs ``rocm-smi --showproductname`` with a 10 second timeout.

    Returns:
        tuple[bool, str]: ``(passed, message)``. ``passed`` is True whenever
        ``rocm-smi`` exits with status 0. The message is the first informative
        output line that mentions "6700" if there is one. Otherwise it is a
        generic "ROCm detected" note quoting the first informative line.
    """
    try:
        result = subprocess.run(
            ["rocm-smi", "--showproductname"],
            capture_output=True,
            text=True,
            timeout=10,
        )
    except FileNotFoundError:
        return False, "rocm-smi not found (ROCm not installed?)"
    except subprocess.TimeoutExpired:
        return False, "rocm-smi timed out"
    except Exception as e:
        return False, str(e)

    if result.returncode != 0:
        # stderr can be empty on failure; fall back to the exit status so the
        # message is never just "rocm-smi failed: ".
        detail = result.stderr.strip() or f"exit code {result.returncode}"
        return False, f"rocm-smi failed: {detail}"

    # rocm-smi typically frames its report with '=' banner lines; skip those
    # (and blank lines) so the message shows actual product information.
    informative = []
    for raw_line in result.stdout.splitlines():
        line = raw_line.strip()
        if line and not line.startswith("="):
            informative.append(line)

    target = next((line for line in informative if "6700" in line), None)
    if target is not None:
        return True, target
    first = informative[0] if informative else "unknown"
    return True, f"ROCm detected (GPU: {first})"
def check_env_py():
    """Check that ``ML_Pipeline.shared.env`` imports and exports the ROCm vars.

    The project root (two directories above this script) is added to
    ``sys.path`` if it is not already there. The env module is then imported,
    and ``env.get_env_dict()`` must provide non-empty values for
    ``HSA_OVERRIDE_GFX_VERSION``, ``HSA_ENABLE_SDMA`` and ``JAX_PLATFORMS``.

    Returns:
        tuple[bool, str]: ``(passed, message)``. The success message also
        reports whether ``HF_TOKEN`` is present. A missing token does not fail
        this check.
    """
    try:
        project_root = str(Path(__file__).resolve().parent.parent)
        # Insert only once: repeated calls (and other checks that do the same)
        # would otherwise keep growing sys.path with duplicate entries.
        if project_root not in sys.path:
            sys.path.insert(0, project_root)
        from ML_Pipeline.shared import env

        env_vars = env.get_env_dict()
        required = (
            "HSA_OVERRIDE_GFX_VERSION",
            "HSA_ENABLE_SDMA",
            "JAX_PLATFORMS",
        )
        missing = [name for name in required if not env_vars.get(name)]
        if missing:
            return False, f"Missing env vars: {missing}"
        hf_state = "set" if env_vars.get("HF_TOKEN") else "NOT SET"
        return True, f"All env vars set (HF_TOKEN: {hf_state})"
    except ImportError as exc:
        return False, f"Import error: {exc}"
    except Exception as exc:
        return False, f"Error: {exc}"
def check_torch_gpu():
    """Check if PyTorch can see a ROCm GPU.

    The check passes only when both of these hold:

    * PyTorch is a ROCm/HIP build (``torch.version.hip`` is not None).
    * At least one device is visible via ``torch.cuda``. PyTorch uses the
      CUDA API (``torch.cuda.is_available()``) for ROCm devices too.

    Returns:
        tuple[bool, str]: ``(passed, message)``.
    """
    try:
        import torch

        if getattr(torch.version, "hip", None) is None:
            return False, "PyTorch ROCm not available"
        if not torch.cuda.is_available():
            # A HIP build with no visible device cannot run GPU work, so this
            # critical GPU check must not report success.
            return False, "ROCm available but no GPU visible"
        return True, f"ROCm GPU detected: {torch.cuda.get_device_name(0)}"
    except ImportError:
        return False, "PyTorch not installed"
    except Exception as e:
        return False, f"Error: {e}"
def check_hf_token():
    """Check that a Hugging Face token is available.

    The ``HF_TOKEN`` environment variable is checked first. A non-empty value
    of 10 characters or fewer is treated as likely invalid and fails. If the
    variable is unset or empty, the check falls back to ``HF_TOKEN`` exported
    by ``ML_Pipeline.shared.env``, where any truthy value passes.

    Returns:
        tuple[bool, str]: ``(passed, message)``.
    """
    token = os.environ.get("HF_TOKEN")
    if token:
        if len(token) > 10:
            return True, f"HF_TOKEN set ({len(token)} chars)"
        return False, "HF_TOKEN set but too short (may be invalid)"

    # Best-effort fallback: any failure to import the project's env module just
    # means "no token found". Catch Exception (not a bare except) so that
    # KeyboardInterrupt / SystemExit still propagate.
    try:
        project_root = str(Path(__file__).resolve().parent.parent)
        if project_root not in sys.path:
            sys.path.insert(0, project_root)
        from ML_Pipeline.shared.env import HF_TOKEN as env_token
    except Exception:
        env_token = None
    if env_token:
        return True, "HF_TOKEN loaded from env.py"
    return False, "HF_TOKEN not set"
def check_jax_rocm():
    """Check if JAX can see ROCm devices (optional).

    Returns:
        tuple[bool | None, str]: ``(passed, message)``. ``passed`` is None
        when JAX is not installed, which marks the check as skipped rather
        than failed. On success, the message gives the device count and the
        sorted set of device platform names.
    """
    try:
        import jax

        devices = jax.devices()
    except ImportError:
        return None, "JAX not installed (optional)"
    except Exception as e:
        return False, f"Error: {e}"
    if not devices:
        return False, "JAX installed but no devices found"
    # Report each device's platform (e.g. "cpu") rather than its Python class
    # name, which can be a generic "Device" for every backend. Sort so the
    # message is stable: iterating a set of strings varies between runs under
    # hash randomisation.
    platforms = sorted({getattr(d, "platform", type(d).__name__) for d in devices})
    return True, f"JAX devices: {len(devices)} ({', '.join(platforms)})"
def main():
    """Run all checks, print a PASS/WARN/FAIL line for each, then a summary.

    Each check returns ``(passed, message)``; a check that raises is reported
    as failed. A non-passing critical check fails the run. A non-passing
    (False or None) non-critical check is reported only as a warning.

    Returns:
        int: process exit code. It is 0 if every critical check passed,
        otherwise 1.
    """
    print("=" * 60)
    print("Synesthesia Environment Verification")
    print("=" * 60)
    print()
    # (display name, check function, is_critical)
    checks = [
        ("ROCm SMI", check_rocm_smi, True),
        ("env.py", check_env_py, True),
        ("PyTorch GPU", check_torch_gpu, True),
        ("HF_TOKEN", check_hf_token, True),
        ("JAX ROCm", check_jax_rocm, False),
    ]
    critical_failed = False
    for name, check_fn, is_critical in checks:
        try:
            passed, message = check_fn()
        except Exception as e:
            passed = False
            message = f"Check failed: {e}"
        if passed:
            status = "✓ PASS"
        elif is_critical:
            status = "✗ FAIL"
            critical_failed = True
        else:
            status = "⚠ WARN"
        print(f"[{status}] {name}: {message}")
    print()
    print("=" * 60)
    if critical_failed:
        print("RESULT: CRITICAL CHECKS FAILED")
        print("Fix the above errors before proceeding.")
        return 1
    print("RESULT: ALL CRITICAL CHECKS PASSED")
    print("Environment is ready for runtime module implementation.")
    return 0
if __name__ == "__main__":
sys.exit(main())