Spaces:
Build error
Build error
| #!/usr/bin/env python3 | |
| """GPU connection tests for Colab, HF Spaces, and local backends. | |
| Tests device detection, mixed precision, model placement, forward pass, | |
| and backward pass on all available GPU targets. | |
| Usage: | |
| python3 scripts/gpu_connection_test.py # auto-detect | |
| python3 scripts/gpu_connection_test.py --target cuda # force CUDA | |
| python3 scripts/gpu_connection_test.py --target mps # force MPS | |
| python3 scripts/gpu_connection_test.py --target cpu # force CPU | |
| python3 scripts/gpu_connection_test.py --full # include training loop test | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import sys | |
| import time | |
| from dataclasses import dataclass | |
| from typing import List, Tuple | |
@dataclass
class TestResult:
    """Outcome of one connection test.

    Bug fix: the original class had annotated fields but no ``@dataclass``
    decorator, so the positional construction in ``_run_test``
    (``TestResult(name, True, detail, elapsed)``) raised TypeError.
    """
    name: str                # human-readable test label
    passed: bool             # True when the test callable returned normally
    detail: str              # success detail string, or the exception message
    elapsed_ms: float = 0.0  # wall-clock duration of the test in milliseconds
def _run_test(name: str, fn) -> TestResult:
    """Run *fn* and capture its outcome (result or exception) as a TestResult."""
    start = time.time()
    try:
        detail = fn()
    except Exception as exc:  # broad on purpose: the harness must never crash
        return TestResult(name, False, str(exc), (time.time() - start) * 1000)
    return TestResult(name, True, detail, (time.time() - start) * 1000)
def test_torch_import() -> str:
    """Verify PyTorch imports and report its version string."""
    import torch
    version = torch.__version__
    return "torch " + version
def test_cuda_available() -> str:
    """Report the CUDA device name, compute capability, and total memory.

    Returns a skip message (not a failure) when CUDA is absent, which is
    normal on MPS/CPU hosts.
    """
    import torch
    if not torch.cuda.is_available():
        return "CUDA not available (expected on MPS/CPU)"
    name = torch.cuda.get_device_name(0)
    cap = torch.cuda.get_device_capability()
    # Bug fix: the device-properties attribute is `total_memory`;
    # `total_mem` does not exist and raised AttributeError on CUDA hosts.
    mem = torch.cuda.get_device_properties(0).total_memory / 1e9
    return f"{name}, compute {cap[0]}.{cap[1]}, {mem:.1f}GB"
def test_mps_available() -> str:
    """Check whether Apple's Metal (MPS) backend is present and usable."""
    import torch
    mps = getattr(torch.backends, 'mps', None)
    if mps is not None and mps.is_available():
        return "MPS available"
    return "MPS not available (expected on Linux/Colab)"
def test_accelerate_import() -> str:
    """Import HF Accelerate and report its resolved device and precision."""
    from accelerate import Accelerator
    accelerator = Accelerator()
    device = accelerator.device
    precision = accelerator.mixed_precision
    return f"device={device}, mp={precision}"
def test_device_resolution() -> str:
    """Resolve the preferred compute device: cuda, then mps, then cpu."""
    import torch
    if torch.cuda.is_available():
        return "cuda"
    mps = getattr(torch.backends, 'mps', None)
    if mps is not None and mps.is_available():
        return "mps"
    return "cpu"
def test_mixed_precision_support() -> str:
    """Smoke-test half-precision matmul: bf16 on compute 8.x+, else fp16."""
    import torch
    if not torch.cuda.is_available():
        return "skipped (no CUDA)"
    cap = torch.cuda.get_device_capability()
    # Compute capability 8.x (Ampere) and newer handle bfloat16 natively.
    use_bf16 = cap[0] >= 8
    dtype = torch.bfloat16 if use_bf16 else torch.float16
    x = torch.randn(4, 4, device="cuda", dtype=dtype)
    _ = x @ x.T  # a single matmul proves the dtype works on this device
    if use_bf16:
        return f"bf16 supported (compute {cap[0]}.{cap[1]})"
    return f"fp16 supported (compute {cap[0]}.{cap[1]})"
def test_model_placement(target: str) -> str:
    """Build the small model, move it to *target*, and report where it landed."""
    import torch
    from training.core.kan_jepa_generator import create_kan_jepa_model
    from training.core.bidirectional_generator import SimpleVocab
    model = create_kan_jepa_model(100, "small").to(torch.device(target))
    param_count = sum(p.numel() for p in model.parameters())
    # Read the device back off a parameter rather than trusting the request.
    landed_on = next(model.parameters()).device
    return f"{param_count:,} params on {landed_on}"
def test_forward_pass(target: str) -> str:
    """Run one inference-mode forward pass on *target* and report the output."""
    import torch
    from training.core.kan_jepa_generator import create_kan_jepa_model
    device = torch.device(target)
    model = create_kan_jepa_model(100, "small").to(device)
    model.eval()
    source_ids = torch.randint(1, 50, (2, 10), device=device)
    target_ids = torch.randint(1, 50, (2, 8), device=device)
    with torch.no_grad():
        logits, info = model(source_ids, target_ids)
    jepa = info['jepa_loss'].item()
    return f"logits={list(logits.shape)}, jepa_loss={jepa:.4f}"
def test_backward_pass(target: str) -> str:
    """Run a full forward+backward step on *target* and summarize gradients.

    Returns the loss value, the count of parameters that received a gradient,
    and the largest gradient norm observed.
    """
    import torch
    from training.core.kan_jepa_generator import create_kan_jepa_model
    device = torch.device(target)
    model = create_kan_jepa_model(100, "small").to(device)
    model.train()
    src = torch.randint(1, 50, (4, 12), device=device)
    tgt = torch.randint(1, 50, (4, 10), device=device)
    logits, info = model(src, tgt[:, :-1])
    loss = logits.mean() + info["jepa_loss"]
    loss.backward()
    grad_norms = [p.grad.norm().item()
                  for p in model.parameters() if p.grad is not None]
    # Robustness fix: max([]) raises ValueError. Some parameters (e.g. a
    # frozen JEPA target encoder) may legitimately have no gradient, so
    # report 0.0 instead of crashing when nothing received one.
    max_grad = max(grad_norms) if grad_norms else 0.0
    return (f"loss={loss.item():.4f}, grad_params={len(grad_norms)}, "
            f"max_grad={max_grad:.4f}")
| def test_mixed_precision_forward(target: str) -> str: | |
| import torch | |
| if target != "cuda": | |
| return "skipped (CUDA only)" | |
| from training.core.kan_jepa_generator import create_kan_jepa_model | |
| device = torch.device("cuda") | |
| model = create_kan_jepa_model(100, "small").to(device) | |
| model.train() | |
| cap = torch.cuda.get_device_capability() | |
| dtype = torch.bfloat16 if cap[0] >= 8 else torch.float16 | |
| src = torch.randint(1, 50, (4, 12), device=device) | |
| tgt = torch.randint(1, 50, (4, 10), device=device) | |
| with torch.autocast(device_type="cuda", dtype=dtype): | |
| logits, info = model(src, tgt[:, :-1]) | |
| loss = logits.mean() + info["jepa_loss"] | |
| loss.backward() | |
| return f"autocast {dtype} OK, loss={loss.item():.4f}" | |
def test_accelerate_training(target: str) -> str:
    """Run a 3-epoch smoke training loop through AccelerateTrainer."""
    try:
        from training.core.accelerate_trainer import AccelerateTrainer, AccelerateConfig
    except ImportError:
        return "accelerate_trainer not available"
    from training.core.kan_jepa_generator import create_kan_jepa_model
    from training.core.bidirectional_generator import SimpleVocab
    # A tiny NL -> Cypher corpus: enough to exercise the loop, not to learn.
    pairs = [
        ("Find all nodes", "MATCH (n) RETURN n"),
        ("Count people", "MATCH (p:Person) RETURN count(p)"),
        ("Find movies", "MATCH (m:Movie) RETURN m"),
        ("Who knows who", "MATCH (a)-[:KNOWS]->(b) RETURN a, b"),
    ]
    corpus = [text for pair in pairs for text in pair]
    vocab = SimpleVocab.build_from_corpus(corpus, max_size=100)
    model = create_kan_jepa_model(len(vocab), "small")
    cfg = AccelerateConfig(
        epochs=3, batch_size=2, gradient_accumulation_steps=1,
        mixed_precision="no", log_every=1, eval_samples=0)
    result = AccelerateTrainer(model, vocab, pairs, cfg).train(verbose=False)
    return (f"3 epochs OK, loss={result['final_loss']:.4f}, "
            f"{result['training_time_s']:.1f}s on {result['device']}")
def test_colab_detection() -> str:
    """Report whether the interpreter is running inside Google Colab."""
    # The google.colab package only exists inside Colab runtimes.
    try:
        import google.colab  # noqa: F401
    except ImportError:
        return "not in Colab (local environment)"
    return "running in Colab"
def test_hf_space_detection() -> str:
    """Report whether the process is running inside a HuggingFace Space."""
    import os
    # HF Spaces inject SPACE_ID into the container environment.
    space_id = os.environ.get("SPACE_ID")
    if space_id:
        return f"HF Space: {space_id}"
    return "not in HF Spaces"
def test_modular_max() -> str:
    """Check for the Modular MAX / Mojo Python package."""
    try:
        import max as _max
    except ImportError:
        return "MAX not installed (pip install modular)"
    version = getattr(_max, "__version__", "unknown")
    return f"MAX {version} available"
def test_mlx_available() -> str:
    """Check Apple MLX: import it and run a tiny matmul when present."""
    try:
        import mlx.core as mx
        version = mx.__version__ if hasattr(mx, "__version__") else "unknown"
        product = mx.ones((4, 4)) @ mx.ones((4, 4))
        # MLX is lazy; mx.eval forces the graph to actually execute.
        mx.eval(product)
        return f"MLX {version}, matmul OK, unified memory"
    except ImportError:
        return "MLX not installed (pip install mlx)"
    except Exception as e:
        # Import worked but the compute path is broken — surface the reason.
        return f"MLX import OK but compute failed: {e}"
def test_snowflake_available() -> str:
    """Check Snowflake: SPCS env vars first, then the ML SDK install."""
    import os
    account = os.environ.get("SNOWFLAKE_ACCOUNT")
    if account:
        return f"SPCS environment: {account}"
    try:
        import snowflake.ml  # noqa: F401
    except ImportError:
        return "snowflake-ml not installed (pip install snowflake-ml-python)"
    return "snowflake-ml-python installed (set SNOWFLAKE_ACCOUNT to connect)"
def test_unified_backend() -> str:
    """Probe every backend adapter and report which ones are usable."""
    import os
    import sys
    # Make the repo root importable when launched from scripts/.
    sys.path.insert(0, os.getcwd())
    from training.core.unified_backend import detect_backend, probe_all_backends
    selected = detect_backend()
    infos = probe_all_backends()
    found = [info.name for info in infos if info.available]
    missing = [info.name for info in infos if not info.available]
    return (f"selected={selected.name}, "
            f"available=[{', '.join(found)}], "
            f"not_found=[{', '.join(missing)}]")
def test_memory_estimate() -> str:
    """Estimate training memory for the full text2cypher model."""
    import torch
    from training.core.kan_jepa_generator import create_kan_jepa_model
    model = create_kan_jepa_model(2000, "text2cypher")
    n_params = sum(p.numel() for p in model.parameters())
    # Adam roughly triples the footprint: params + grads + two moment
    # buffers, approximated here as 3x bytes-per-param.
    mem_fp32 = n_params * 4 * 3 / 1e6  # MB at 4 bytes/param
    mem_fp16 = n_params * 2 * 3 / 1e6  # MB at 2 bytes/param
    return f"{n_params:,} params, est. {mem_fp32:.0f}MB fp32 / {mem_fp16:.0f}MB fp16"
def run_all(target: str, full: bool = False) -> List[TestResult]:
    """Run the whole suite against *target*; returns one result per test.

    Environment-detection tests run first, then model tests on the chosen
    device. The Accelerate training loop only runs when *full* is set.
    """
    suite = [
        ("torch import", test_torch_import),
        ("CUDA available", test_cuda_available),
        ("MPS available", test_mps_available),
        ("Colab detection", test_colab_detection),
        ("HF Space detection", test_hf_space_detection),
        ("device resolution", test_device_resolution),
        ("accelerate import", test_accelerate_import),
        ("mixed precision support", test_mixed_precision_support),
        ("Modular MAX / Mojo", test_modular_max),
        ("Apple MLX", test_mlx_available),
        ("Snowflake SPCS", test_snowflake_available),
        ("unified backend (7 adapters)", test_unified_backend),
        ("memory estimate", test_memory_estimate),
        (f"model placement [{target}]", lambda: test_model_placement(target)),
        (f"forward pass [{target}]", lambda: test_forward_pass(target)),
        (f"backward pass [{target}]", lambda: test_backward_pass(target)),
        (f"mixed precision fwd [{target}]",
         lambda: test_mixed_precision_forward(target)),
    ]
    if full:
        suite.append((f"accelerate training [{target}]",
                      lambda: test_accelerate_training(target)))
    return [_run_test(label, fn) for label, fn in suite]
def main():
    """CLI entry point: parse args, resolve the device, run and report tests.

    Exits with status 1 when any test fails, so CI can gate on it.
    """
    parser = argparse.ArgumentParser(description="GPU connection tests")
    parser.add_argument("--target", choices=["auto", "cuda", "mps", "cpu"],
                        default="auto", help="Target device")
    parser.add_argument("--full", action="store_true",
                        help="Include training loop test")
    args = parser.parse_args()

    if args.target != "auto":
        target = args.target
    else:
        # Auto-detect in priority order: CUDA, then Apple MPS, then CPU.
        import torch
        if torch.cuda.is_available():
            target = "cuda"
        elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
            target = "mps"
        else:
            target = "cpu"

    print(f"=== GPU Connection Tests (target: {target}) ===\n")
    results = run_all(target, args.full)

    passed = sum(1 for r in results if r.passed)
    total = len(results)
    for r in results:
        status = "PASS" if r.passed else "FAIL"
        print(f" [{status}] {r.name}: {r.detail} ({r.elapsed_ms:.0f}ms)")
    print(f"\n{passed}/{total} tests passed")

    if passed < total:
        print("\nFailed tests:")
        for r in results:
            if not r.passed:
                print(f" - {r.name}: {r.detail}")
        sys.exit(1)


if __name__ == "__main__":
    main()