Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python3 | |
| """Quick environment verification for Synesthesia ROCm stack. | |
| Checks: | |
| 1. ROCm SMI visibility | |
| 2. env.py imports and exports correct vars | |
| 3. PyTorch GPU detection | |
| 4. HF_TOKEN is set | |
| 5. JAX ROCm visibility (optional) | |
| Returns exit code 0 only if all critical checks pass. | |
| """ | |
| import os | |
| import subprocess | |
| import sys | |
| from pathlib import Path | |
def check_rocm_smi():
    """Confirm that ``rocm-smi`` can report a GPU product name.

    Runs ``rocm-smi --showproductname`` with a 10 second timeout.

    Returns:
        tuple[bool, str]: ``(True, info)`` whenever the command exits with
        status 0. ``info`` is the first line of the trimmed output when the
        output mentions "6700", otherwise that line wrapped in
        "ROCm detected (GPU: ...)". ``(False, reason)`` when the command
        exits non-zero, is not installed, times out, or raises anything else.
    """
    try:
        proc = subprocess.run(
            ["rocm-smi", "--showproductname"],
            capture_output=True,
            text=True,
            timeout=10,
        )
    except FileNotFoundError:
        return False, "rocm-smi not found (ROCm not installed?)"
    except subprocess.TimeoutExpired:
        return False, "rocm-smi timed out"
    except Exception as exc:
        return False, str(exc)

    if proc.returncode != 0:
        return False, f"rocm-smi failed: {proc.stderr.strip()}"

    # Only the first line of the trimmed output is reported back.
    headline = proc.stdout.strip().split("\n")[0]
    if "6700" in proc.stdout.lower():
        return True, headline
    return True, f"ROCm detected (GPU: {headline})"
def check_env_py():
    """Check that ``ML_Pipeline.shared.env`` imports and exports the ROCm vars.

    The directory two levels above this file is added to ``sys.path`` (once)
    so the ``ML_Pipeline`` package can be imported when the script is run
    directly.

    Returns:
        tuple[bool, str]: ``(True, summary)`` when ``HSA_OVERRIDE_GFX_VERSION``,
        ``HSA_ENABLE_SDMA`` and ``JAX_PLATFORMS`` all have values in
        ``env.get_env_dict()``; the summary also says whether ``HF_TOKEN`` is
        present. ``(False, reason)`` when any of them is missing, or when the
        import or the lookup fails.
    """
    try:
        project_root = str(Path(__file__).resolve().parent.parent)
        # Insert only once so repeated calls don't keep growing sys.path.
        if project_root not in sys.path:
            sys.path.insert(0, project_root)
        from ML_Pipeline.shared import env

        env_vars = env.get_env_dict()
        required = ("HSA_OVERRIDE_GFX_VERSION", "HSA_ENABLE_SDMA", "JAX_PLATFORMS")
        # Only absent/None/empty values count as missing: a falsy but real
        # setting such as HSA_ENABLE_SDMA=0 must not be flagged.
        missing = [name for name in required if env_vars.get(name) in (None, "")]
        if missing:
            return False, f"Missing env vars: {missing}"
        hf_state = "set" if env_vars.get("HF_TOKEN") else "NOT SET"
        return True, f"All env vars set (HF_TOKEN: {hf_state})"
    except ImportError as exc:
        return False, f"Import error: {exc}"
    except Exception as exc:
        return False, f"Error: {exc}"
def check_torch_gpu():
    """Check that PyTorch was built for ROCm and can see at least one GPU.

    A ROCm build exposes a non-None ``torch.version.hip``, and PyTorch reports
    ROCm devices through the ``torch.cuda`` API.

    Returns:
        tuple[bool, str]: ``(True, "ROCm GPU detected: <name>")`` only when a
        ROCm build sees a GPU. ``(False, reason)`` when the ROCm build sees no
        GPU, when PyTorch lacks ROCm support, when PyTorch is not installed,
        or when any query raises.
    """
    try:
        import torch

        has_hip = getattr(torch.version, "hip", None) is not None
        has_gpu = torch.cuda.is_available() if hasattr(torch, "cuda") else False
        if has_hip and has_gpu:
            return True, f"ROCm GPU detected: {torch.cuda.get_device_name(0)}"
        if has_hip:
            # This is a critical "PyTorch GPU" check: a ROCm build without a
            # visible device cannot run GPU work, so it must not pass.
            return False, "ROCm available but no GPU visible"
        return False, "PyTorch ROCm not available"
    except ImportError:
        return False, "PyTorch not installed"
    except Exception as e:
        return False, f"Error: {e}"
def check_hf_token():
    """Check that a Hugging Face token is available.

    The ``HF_TOKEN`` environment variable is checked first; a value of 10
    characters or fewer is reported as suspicious. If the variable is unset
    or empty, ``HF_TOKEN`` from ``ML_Pipeline.shared.env`` is tried as a
    best-effort fallback.

    Returns:
        tuple[bool, str]: pass flag and a human-readable message.
    """
    token = os.environ.get("HF_TOKEN")
    if token and len(token) > 10:
        return True, f"HF_TOKEN set ({len(token)} chars)"
    if token:
        return False, "HF_TOKEN set but too short (may be invalid)"

    # Fallback to env.py. This is best-effort, so ordinary import/evaluation
    # errors are ignored -- but unlike a bare ``except:`` this no longer
    # swallows KeyboardInterrupt/SystemExit.
    try:
        project_root = str(Path(__file__).resolve().parent.parent)
        # Insert only once so repeated calls don't keep growing sys.path.
        if project_root not in sys.path:
            sys.path.insert(0, project_root)
        from ML_Pipeline.shared.env import HF_TOKEN as env_token
    except Exception:
        env_token = None
    if env_token:
        return True, "HF_TOKEN loaded from env.py"
    return False, "HF_TOKEN not set"
def check_jax_rocm():
    """Check whether JAX can see a non-CPU (e.g. ROCm) device. Optional check.

    Returns:
        tuple[bool | None, str]: ``(None, ...)`` when JAX is not installed;
        ``(True, ...)`` when at least one device reports a platform other than
        ``"cpu"``; ``(False, ...)`` when no devices, or only CPU devices, are
        visible, or when the device query raises.
    """
    try:
        import jax

        devices = jax.devices()
        if not devices:
            return False, "JAX installed but no devices found"
        # Sorted so the message is deterministic (set iteration order is not).
        device_types = sorted({type(d).__name__ for d in devices})
        summary = f"JAX devices: {len(devices)} ({', '.join(device_types)})"
        # jax.devices() returns the default backend's devices, which is CPU
        # when no accelerator backend is available -- that does not prove
        # ROCm visibility, so a CPU-only list is reported as not passing.
        if all(getattr(d, "platform", "") == "cpu" for d in devices):
            return False, f"{summary} - CPU only, no ROCm device visible"
        return True, summary
    except ImportError:
        return None, "JAX not installed (optional)"
    except Exception as e:
        return False, f"Error: {e}"
def main():
    """Run every environment check, print a report, and return an exit code.

    Each entry in the check table is ``(label, function, is_critical)``. A
    check function returns ``(passed, message)``; an exception raised by a
    check is turned into a failure. A check that does not pass is reported
    as FAIL when critical (which makes the run fail) and WARN otherwise.
    ``passed`` may be ``None`` (the JAX check returns it when JAX is not
    installed), which is treated like a non-pass.

    Returns:
        int: 0 when every critical check passed, 1 otherwise.
    """
    rule = "=" * 60
    print(rule)
    print("Synesthesia Environment Verification")
    print(rule)
    print()

    checks = [
        ("ROCm SMI", check_rocm_smi, True),
        ("env.py", check_env_py, True),
        ("PyTorch GPU", check_torch_gpu, True),
        ("HF_TOKEN", check_hf_token, True),
        ("JAX ROCm", check_jax_rocm, False),
    ]

    critical_failed = False
    for name, check_fn, is_critical in checks:
        try:
            passed, message = check_fn()
        except Exception as e:
            passed, message = False, f"Check failed: {e}"

        if passed:
            status = "✓ PASS"
        elif is_critical:
            status = "✗ FAIL"
            critical_failed = True
        else:
            status = "⚠ WARN"
        print(f"[{status}] {name}: {message}")

    print()
    print(rule)
    if critical_failed:
        print("RESULT: CRITICAL CHECKS FAILED")
        print("Fix the above errors before proceeding.")
        return 1
    print("RESULT: ALL CRITICAL CHECKS PASSED")
    print("Environment is ready for runtime module implementation.")
    return 0
if __name__ == "__main__":
    # Use main()'s return value (0 or 1) as the process exit status.
    raise SystemExit(main())