| """Launch N RunPod A40 pods, deploy the codebase, kick off training. |
| |
| Usage: |
| python scripts/runpod_launch.py --models A B C F --gpu A40 \ |
| --image runpod/pytorch:2.4.0-py3.11-cuda12.4.1-devel-ubuntu22.04 |
| |
| For each model letter: |
| 1. create pod |
| 2. wait for SSH |
| 3. rsync repo + .env via scp |
| 4. run pod_bootstrap.sh on the pod (in tmux/nohup) |
| 5. record pod id + run name in runs/launch_manifest.json |
| |
| Polling/log retrieval is left to scripts/runpod_status.py. |
| """ |
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import os |
| import shutil |
| import subprocess |
| import sys |
| import tempfile |
| import time |
| from pathlib import Path |
|
|
| from dotenv import load_dotenv |
|
|
# Load variables from a local .env file into the process environment.
load_dotenv()
# Fail fast at import time (KeyError) if the RunPod API key is not configured.
RUNPOD_API_KEY = os.environ["RUNPOD_API_KEY"]
|
|
# Short CLI GPU names -> full RunPod GPU identifiers
# (the value is what `runpodctl pod create --gpu-id` expects).
GPU_IDS = {
    "A40": "NVIDIA A40",
    "A6000": "NVIDIA RTX A6000",
    "A100": "NVIDIA A100-SXM4-80GB",
    "H100": "NVIDIA H100 80GB HBM3",
}


# Default container image for new pods (PyTorch 2.4 / CUDA 12.4 / Ubuntu 22.04).
DEFAULT_IMAGE = "runpod/pytorch:2.4.0-py3.11-cuda12.4.1-devel-ubuntu22.04"
|
|
|
|
def runpodctl(args: list[str], capture: bool = True) -> str:
    """Run the ``runpodctl`` CLI with the configured API key and return stdout.

    Raises:
        RuntimeError: if the command exits with a non-zero status.
    """
    proc = subprocess.run(
        ["runpodctl", *args],
        env={**os.environ, "RUNPOD_API_KEY": RUNPOD_API_KEY},
        capture_output=capture,
        text=True,
    )
    if proc.returncode == 0:
        return proc.stdout
    raise RuntimeError(f"runpodctl {' '.join(args)} failed: {proc.stderr}\n{proc.stdout}")
|
|
|
|
def create_pod(name: str, gpu_id: str, image: str, container_disk: int = 50,
               volume_gb: int = 100) -> dict:
    """Create a single-GPU community-cloud pod and return the parsed JSON reply.

    The persistent volume is mounted at /workspace and SSH (port 22) is enabled.
    """
    flag_pairs = [
        ("--name", name),
        ("--gpu-id", gpu_id),
        ("--gpu-count", "1"),
        ("--image", image),
        ("--cloud-type", "COMMUNITY"),
        ("--container-disk-in-gb", str(container_disk)),
        ("--volume-in-gb", str(volume_gb)),
        ("--volume-mount-path", "/workspace"),
        ("--ports", "22/tcp"),
    ]
    cmd = ["pod", "create"]
    for flag, value in flag_pairs:
        cmd.extend([flag, value])
    cmd.append("--ssh")
    return json.loads(runpodctl(cmd))
|
|
|
|
def wait_for_ssh(pod_id: str, timeout: int = 600) -> tuple[str, int]:
    """Poll ``runpodctl ssh info`` until the pod exposes a (host, port) pair.

    Retries every 15 seconds, remembering the last error seen.

    Raises:
        TimeoutError: if SSH is not reachable within `timeout` seconds.
    """
    deadline = time.time() + timeout
    last_err = ""
    while time.time() < deadline:
        try:
            # Field names vary across runpodctl versions; accept either spelling.
            info = json.loads(runpodctl(["ssh", "info", pod_id]))
            host = info.get("publicIp") or info.get("ip")
            port = info.get("port") or info.get("sshPort")
            if host and port:
                return host, int(port)
        except Exception as exc:
            last_err = str(exc)
        time.sleep(15)
    raise TimeoutError(f"SSH not ready for {pod_id}: {last_err}")
|
|
|
|
def ssh(host: str, port: int, cmd: str, user: str = "root", timeout: int = 60) -> str:
    """Run `cmd` on the pod over SSH (host-key checking disabled) and return stdout.

    Raises:
        RuntimeError: if the remote command exits non-zero.
    """
    argv = [
        "ssh",
        "-o", "StrictHostKeyChecking=no",
        "-o", "UserKnownHostsFile=/dev/null",
        "-o", "ConnectTimeout=15",
        "-p", str(port),
        f"{user}@{host}",
        cmd,
    ]
    proc = subprocess.run(argv, capture_output=True, text=True, timeout=timeout)
    if proc.returncode != 0:
        raise RuntimeError(f"ssh {host}:{port} {cmd!r} failed: {proc.stderr}")
    return proc.stdout
|
|
|
|
def scp(host: str, port: int, local_path: Path, remote_path: str, user: str = "root") -> None:
    """Copy a local file (or directory, recursively) to the pod via scp.

    Raises:
        RuntimeError: if scp exits non-zero.
    """
    recurse = ["-r"] if local_path.is_dir() else []
    argv = [
        "scp",
        "-o", "StrictHostKeyChecking=no",
        "-o", "UserKnownHostsFile=/dev/null",
        "-P", str(port),
        *recurse,
        str(local_path),
        f"{user}@{host}:{remote_path}",
    ]
    proc = subprocess.run(argv, capture_output=True, text=True, timeout=900)
    if proc.returncode != 0:
        raise RuntimeError(f"scp {local_path} -> {host}:{remote_path} failed: {proc.stderr}")
|
|
|
|
def deploy_and_launch(host: str, port: int, model: str, run_name: str, repo_root: Path) -> None:
    """Ship the repo + .env to the pod and start training detached.

    Steps: tar the repo (minus heavy/local-only dirs), scp tar + .env to
    /workspace, unpack remotely, then launch scripts/pod_bootstrap.sh under
    nohup so it survives the SSH session.

    Args:
        host, port: SSH endpoint of the pod.
        model: model letter forwarded to pod_bootstrap.sh.
        run_name: run identifier; also names the bootstrap log file.
        repo_root: local checkout root to deploy.

    Raises:
        FileNotFoundError: if repo_root/.env does not exist.
        RuntimeError: if any scp/ssh step fails.
    """
    with tempfile.TemporaryDirectory() as td:
        tar = Path(td) / "physiojepa.tar.gz"
        # Skip heavy / local-only directories. FIX: "docs/paperes" was a typo
        # for "docs/papers" and never matched, so that dir got shipped too.
        excludes = [".venv", ".git", "__pycache__", "runs", "cache",
                    "docs/figures", "docs/papers"]
        excl_args = [arg for e in excludes for arg in ("--exclude", e)]
        subprocess.run(
            ["tar", "-czf", str(tar), *excl_args, "-C", str(repo_root.parent),
             repo_root.name],
            check=True,
        )
        scp(host, port, tar, "/workspace/physiojepa.tar.gz")

        env_file = repo_root / ".env"
        if not env_file.exists():
            # Fail with a clear message instead of an opaque scp error.
            raise FileNotFoundError(f"{env_file} not found; the pod needs it for API keys")
        scp(host, port, env_file, "/workspace/.env")
        ssh(host, port, "set -e; cd /workspace && rm -rf physiojepa && "
            "tar -xzf physiojepa.tar.gz && rm physiojepa.tar.gz")

        # nohup + disown detaches training from this SSH session; the trailing
        # "sleep 1" gives the background process time to start before ssh exits.
        bootstrap = (
            f"set -e; mkdir -p /workspace/runs; "
            f"cd /workspace/physiojepa && chmod +x scripts/pod_bootstrap.sh && "
            f"nohup bash scripts/pod_bootstrap.sh {model} {run_name} "
            f"> /workspace/runs/{run_name}.bootstrap.log 2>&1 &"
            f" disown; echo started; sleep 1"
        )
        ssh(host, port, bootstrap)
|
|
|
|
def main() -> None:
    """Launch one pod per model letter and record them in a JSON manifest.

    Pods that never become SSH-reachable are deleted (best effort) and skipped.
    The manifest is rewritten after every successful launch so a mid-run crash
    still leaves a usable partial record.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--models", nargs="+", default=["A", "B", "C", "F"])
    ap.add_argument("--gpu", default="A40", choices=list(GPU_IDS.keys()))
    ap.add_argument("--image", default=DEFAULT_IMAGE)
    ap.add_argument("--repo_root", default=str(Path(__file__).resolve().parents[1]))
    ap.add_argument("--manifest", default="runs/launch_manifest.json")
    args = ap.parse_args()

    repo_root = Path(args.repo_root)
    Path(args.manifest).parent.mkdir(parents=True, exist_ok=True)
    gpu_id = GPU_IDS[args.gpu]
    manifest = []

    for model in args.models:
        # FIX: run name now reflects the chosen GPU instead of hard-coding
        # "a40" (unchanged for the default --gpu A40).
        run_name = f"e2_{model}_{args.gpu.lower()}"
        pod_name = f"pj-{model.lower()}-{int(time.time()) % 100000:05d}"
        print(f"[launch] creating pod {pod_name} (model={model}, gpu={args.gpu})")
        pod = create_pod(pod_name, gpu_id, args.image)
        # Response schema differs across runpodctl versions.
        pod_id = pod.get("id") or pod.get("podId")
        print(f"[launch] pod_id={pod_id}, waiting for SSH...")
        try:
            host, port = wait_for_ssh(pod_id)
        except TimeoutError as e:
            # Best-effort cleanup of the unreachable pod, then move on.
            print(f"[launch] WARN: {e}; deleting pod and continuing")
            try:
                runpodctl(["pod", "delete", pod_id])
            except Exception:
                pass
            continue
        print(f"[launch] SSH up @ {host}:{port}, deploying code")
        deploy_and_launch(host, port, model, run_name, repo_root)
        manifest.append({"pod_id": pod_id, "pod_name": pod_name, "host": host,
                         "port": port, "model": model, "run_name": run_name,
                         "started_at": time.time()})
        # Persist incrementally so partial progress survives a crash.
        Path(args.manifest).write_text(json.dumps(manifest, indent=2))
        print(f"[launch] {model} kicked off; manifest -> {args.manifest}")

    print(f"[launch] all done. manifest:\n{Path(args.manifest).read_text()}")
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|