Add files using upload-large-folder tool
This view is limited to 50 files because it contains too many changes.
- .venv/lib/python3.11/site-packages/ray/train/v2/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/train/v2/_internal/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/ray/train/v2/_internal/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/train/v2/_internal/__pycache__/exceptions.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/train/v2/_internal/__pycache__/util.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/train/v2/_internal/callbacks/__init__.py +14 -0
- .venv/lib/python3.11/site-packages/ray/train/v2/_internal/callbacks/__pycache__/accelerators.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/train/v2/_internal/callbacks/__pycache__/backend_setup.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/train/v2/_internal/callbacks/__pycache__/datasets.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/train/v2/_internal/callbacks/__pycache__/metrics.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/train/v2/_internal/callbacks/__pycache__/user_callback.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/train/v2/_internal/callbacks/__pycache__/working_dir_setup.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/train/v2/_internal/callbacks/accelerators.py +151 -0
- .venv/lib/python3.11/site-packages/ray/train/v2/_internal/callbacks/backend_setup.py +27 -0
- .venv/lib/python3.11/site-packages/ray/train/v2/_internal/callbacks/datasets.py +76 -0
- .venv/lib/python3.11/site-packages/ray/train/v2/_internal/callbacks/metrics.py +250 -0
- .venv/lib/python3.11/site-packages/ray/train/v2/_internal/callbacks/user_callback.py +50 -0
- .venv/lib/python3.11/site-packages/ray/train/v2/_internal/callbacks/working_dir_setup.py +24 -0
- .venv/lib/python3.11/site-packages/ray/train/v2/_internal/constants.py +84 -0
- .venv/lib/python3.11/site-packages/ray/train/v2/_internal/exceptions.py +170 -0
- .venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/__pycache__/callback.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/__pycache__/context.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/__pycache__/controller.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/__pycache__/storage.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/callback.py +140 -0
- .venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/checkpoint/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/checkpoint/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/checkpoint/__pycache__/checkpoint_manager.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/checkpoint/__pycache__/report_handler.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/checkpoint/__pycache__/sync_actor.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/checkpoint/checkpoint_manager.py +271 -0
- .venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/checkpoint/report_handler.py +111 -0
- .venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/checkpoint/sync_actor.py +190 -0
- .venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/context.py +281 -0
- .venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/controller.py +377 -0
- .venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/failure_handling/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/failure_handling/__pycache__/default.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/failure_handling/__pycache__/factory.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/failure_handling/__pycache__/failure_policy.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/failure_handling/default.py +44 -0
- .venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/failure_handling/factory.py +13 -0
- .venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/scaling_policy/__init__.py +19 -0
- .venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/scaling_policy/__pycache__/factory.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/scaling_policy/__pycache__/fixed.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/scaling_policy/__pycache__/scaling_policy.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/scaling_policy/factory.py +13 -0
- .venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/scaling_policy/fixed.py +22 -0
- .venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/scaling_policy/scaling_policy.py +51 -0
- .venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/storage.py +551 -0
.venv/lib/python3.11/site-packages/ray/train/v2/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (185 Bytes).

.venv/lib/python3.11/site-packages/ray/train/v2/_internal/__init__.py
ADDED
File without changes.

.venv/lib/python3.11/site-packages/ray/train/v2/_internal/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (195 Bytes).

.venv/lib/python3.11/site-packages/ray/train/v2/_internal/__pycache__/exceptions.cpython-311.pyc
ADDED
Binary file (10.1 kB).

.venv/lib/python3.11/site-packages/ray/train/v2/_internal/__pycache__/util.cpython-311.pyc
ADDED
Binary file (8.22 kB).
.venv/lib/python3.11/site-packages/ray/train/v2/_internal/callbacks/__init__.py
ADDED
@@ -0,0 +1,14 @@
from .accelerators import AcceleratorSetupCallback
from .backend_setup import BackendSetupCallback
from .datasets import DatasetsSetupCallback
from .working_dir_setup import WorkingDirectorySetupCallback

__all__ = [
    "AcceleratorSetupCallback",
    "BackendSetupCallback",
    "DatasetsSetupCallback",
    "WorkingDirectorySetupCallback",
]


# DO NOT ADD ANYTHING AFTER THIS LINE.

.venv/lib/python3.11/site-packages/ray/train/v2/_internal/callbacks/__pycache__/accelerators.cpython-311.pyc
ADDED
Binary file (7.69 kB).

.venv/lib/python3.11/site-packages/ray/train/v2/_internal/callbacks/__pycache__/backend_setup.cpython-311.pyc
ADDED
Binary file (2.26 kB).

.venv/lib/python3.11/site-packages/ray/train/v2/_internal/callbacks/__pycache__/datasets.cpython-311.pyc
ADDED
Binary file (4.62 kB).

.venv/lib/python3.11/site-packages/ray/train/v2/_internal/callbacks/__pycache__/metrics.cpython-311.pyc
ADDED
Binary file (14.1 kB).

.venv/lib/python3.11/site-packages/ray/train/v2/_internal/callbacks/__pycache__/user_callback.cpython-311.pyc
ADDED
Binary file (2.52 kB).

.venv/lib/python3.11/site-packages/ray/train/v2/_internal/callbacks/__pycache__/working_dir_setup.cpython-311.pyc
ADDED
Binary file (1.93 kB).
.venv/lib/python3.11/site-packages/ray/train/v2/_internal/callbacks/accelerators.py
ADDED
@@ -0,0 +1,151 @@
import logging
import os
from collections import defaultdict
from typing import List

import ray._private.ray_constants as ray_constants
from ray._private.ray_constants import env_bool
from ray.train import BackendConfig
from ray.train.constants import ENABLE_SHARE_CUDA_VISIBLE_DEVICES_ENV
from ray.train.v2._internal.execution.callback import WorkerGroupCallback
from ray.train.v2._internal.execution.worker_group import ActorMetadata, WorkerGroup
from ray.train.v2._internal.util import ray_get_safe
from ray.train.v2.api.config import ScalingConfig

logger = logging.getLogger(__name__)


class AcceleratorSetupCallback(WorkerGroupCallback):
    """Perform accelerator setup for workers.

    For example, this callback can be used to share CUDA_VISIBLE_DEVICES
    among workers on the same node.
    """

    def __init__(self, backend_config: BackendConfig, scaling_config: ScalingConfig):
        self._backend = backend_config.backend_cls()
        self._scaling_config = scaling_config

    def after_worker_group_start(self, worker_group: WorkerGroup):
        self._maybe_share_cuda_visible_devices(worker_group)
        # TODO: Add support for sharing other accelerator resources.

    def _maybe_share_cuda_visible_devices(self, worker_group: WorkerGroup):
        share_cuda_visible_devices_enabled = env_bool(
            ENABLE_SHARE_CUDA_VISIBLE_DEVICES_ENV,
            self._backend.share_cuda_visible_devices,
        )

        if (
            self._scaling_config._resources_per_worker_not_none.get("GPU", 0) > 0
            and share_cuda_visible_devices_enabled
        ):
            _share_cuda_visible_devices(worker_group)


def _share_cuda_visible_devices(worker_group: WorkerGroup):
    """Sets CUDA_VISIBLE_DEVICES on all workers.

    For each worker, CUDA_VISIBLE_DEVICES will be set to the GPU IDs
    visible to all workers on that worker's node.
    This allows GPU workers on the same node to communicate with one
    another.

    Example:
        Setup:
        - Node1:
            - Worker1: {0, 1}
            - Worker2: {2, 3}
        - Node2:
            - Worker3: {0, 1}
        CUDA_VISIBLE_DEVICES:
        - Worker1: "0,1,2,3"
        - Worker2: "0,1,2,3"
        - Worker3: "0,1"
    """
    _share_accelerator_ids(
        worker_group, ray_constants.GPU, ray_constants.CUDA_VISIBLE_DEVICES_ENV_VAR
    )


def _share_accelerator_ids(
    worker_group: WorkerGroup, accelerator_name: str, env_var: str
):
    """Sets the given env_var on all workers.

    For each worker, the cores/devices are visible to all the
    workers on that worker's node. This allows workers on the
    same node to communicate with one another.

    Example:
        Setup:
        - Node1:
            - Worker1: {0, 1}
            - Worker2: {2, 3}
        - Node2:
            - Worker3: {0, 1}
        NEURON_RT_VISIBLE_CORES/TPU_VISIBLE_CHIPS/...:
        - Worker1: "0,1,2,3"
        - Worker2: "0,1,2,3"
        - Worker3: "0,1"

    Args:
        accelerator_name: The name of the accelerator.
        env_var: The name of the environment variable to set.
    """
    if not worker_group.has_started():
        raise RuntimeError(
            "WorkerGroup must be started before sharing accelerator IDs."
        )

    worker_metadatas = [worker.metadata for worker in worker_group.get_workers()]
    visible_accelerator_ids_per_worker = _get_visible_accelerator_ids_per_worker(
        worker_metadatas=worker_metadatas, accelerator_name=accelerator_name
    )

    def set_accelerator_ids(accelerator_ids):
        os.environ[env_var] = accelerator_ids

    futures = []
    for rank, visible_accelerator_ids in enumerate(visible_accelerator_ids_per_worker):
        futures.append(
            worker_group.execute_single_async(
                rank, set_accelerator_ids, accelerator_ids=visible_accelerator_ids
            )
        )
    ray_get_safe(futures)


def _get_visible_accelerator_ids_per_worker(
    worker_metadatas: List[ActorMetadata], accelerator_name: str
) -> List[str]:
    """Returns a list of comma-separated accelerator IDs visible to each worker.

    All workers on a node should have the same set of visible accelerators,
    which is the union of the accelerator ids of the workers.

    Returns:
        visible_accelerator_ids_per_worker: A list of comma-separated accelerator ID
            strings. This list is the same length as the number of workers.
    """
    for metadata in worker_metadatas:
        if accelerator_name not in metadata.accelerator_ids:
            raise ValueError(
                f"Accelerator '{accelerator_name}' is not available on all workers. "
                f"Got these available accelerators instead: {metadata.accelerator_ids}"
            )

    node_id_to_accelerator_ids = defaultdict(set)

    for metadata in worker_metadatas:
        node_id_to_accelerator_ids[metadata.node_id].update(
            metadata.accelerator_ids[accelerator_name]
        )

    visible_accelerator_ids_per_worker = []
    for worker_id in range(len(worker_metadatas)):
        node_id = worker_metadatas[worker_id].node_id
        accelerator_ids = sorted(node_id_to_accelerator_ids[node_id])
        all_resource_ids = ",".join([str(id) for id in accelerator_ids])
        visible_accelerator_ids_per_worker.append(all_resource_ids)

    return visible_accelerator_ids_per_worker

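
The per-node union computed by `_get_visible_accelerator_ids_per_worker` above can be illustrated with a small standalone sketch (plain Python, no Ray dependency; the worker layout follows the docstring's example, and the tuples simply stand in for `ActorMetadata`):

from collections import defaultdict

# (node_id, accelerator_ids) pairs standing in for ActorMetadata.
workers = [
    ("node1", {0, 1}),  # Worker1
    ("node1", {2, 3}),  # Worker2
    ("node2", {0, 1}),  # Worker3
]

# Union the IDs per node, then give every worker its node's full set.
node_to_ids = defaultdict(set)
for node_id, ids in workers:
    node_to_ids[node_id].update(ids)

visible = [
    ",".join(str(i) for i in sorted(node_to_ids[node_id])) for node_id, _ in workers
]
print(visible)  # ['0,1,2,3', '0,1,2,3', '0,1']
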
.venv/lib/python3.11/site-packages/ray/train/v2/_internal/callbacks/backend_setup.py
ADDED
@@ -0,0 +1,27 @@
import logging

from ray.exceptions import RayActorError
from ray.train.backend import BackendConfig
from ray.train.v2._internal.execution.callback import WorkerGroupCallback
from ray.train.v2._internal.execution.worker_group import WorkerGroup

logger = logging.getLogger(__name__)


class BackendSetupCallback(WorkerGroupCallback):
    def __init__(self, backend_config: BackendConfig):
        self._backend_config = backend_config
        self._backend = backend_config.backend_cls()

    def after_worker_group_start(self, worker_group: WorkerGroup):
        self._backend.on_start(worker_group, self._backend_config)
        self._backend.on_training_start(worker_group, self._backend_config)

    def before_worker_group_shutdown(self, worker_group: WorkerGroup):
        try:
            self._backend.on_shutdown(worker_group, self._backend_config)
        except RayActorError:
            logger.warning(
                "Graceful shutdown of backend failed. This is "
                "expected if one of the workers has crashed."
            )

.venv/lib/python3.11/site-packages/ray/train/v2/_internal/callbacks/datasets.py
ADDED
@@ -0,0 +1,76 @@
import copy
from typing import Any, Callable, Dict, List, Union

import ray.train
from ray.data import Dataset
from ray.data.context import DataContext
from ray.train.v2._internal.execution.callback import WorkerGroupCallback
from ray.train.v2._internal.execution.worker_group.worker_group import WorkerGroup

# A type representing either a ray.data.Dataset or a function that returns a
# ray.data.Dataset and accepts no arguments.
GenDataset = Union[Dataset, Callable[[], Dataset]]


class DatasetsSetupCallback(WorkerGroupCallback):
    """The callback to set up Ray Datasets for the worker group."""

    def __init__(
        self,
        datasets: Dict[str, GenDataset],
        data_config: ray.train.DataConfig,
        scaling_config: ray.train.ScalingConfig,
    ):
        self._datasets = datasets
        self._data_config = data_config
        self._scaling_config = scaling_config

        # Capture the current DataContext to propagate it to
        # the Train workers later.
        # The propagation works in the following way:
        # 1. This callback is created when the user creates the Trainer.
        # 2. Then this callback will be passed to the Controller actor.
        # 3. Lastly, when the worker group is initialized, the Controller
        #    will call the `after_worker_group_start` callback to propagate
        #    the DataContext to the Train workers.
        self._data_context = copy.deepcopy(DataContext.get_current())

    def get_train_total_resources(
        self, scaling_config: ray.train.ScalingConfig
    ) -> Dict[str, float]:
        """Return the resources reserved for training, so that Data can exclude
        these resources logically from its available pool."""
        return scaling_config.total_resources

    def before_init_train_context(
        self, worker_group: "WorkerGroup"
    ) -> Dict[str, List[Any]]:
        # Configure dataset shards
        datasets = {k: v() if callable(v) else v for k, v in self._datasets.items()}
        node_ids = [worker.metadata.node_id for worker in worker_group.get_workers()]

        # Notify the DataConfig about the total resources reserved for training.
        total_train_resources = self.get_train_total_resources(self._scaling_config)
        self._data_config.set_train_total_resources(
            total_train_resources.get("CPU", 0), total_train_resources.get("GPU", 0)
        )

        dataset_shards = self._data_config.configure(
            datasets,
            world_size=len(worker_group),
            worker_handles=None,
            worker_node_ids=node_ids,
        )
        assert len(dataset_shards) == len(worker_group)

        return {"dataset_shards": dataset_shards}

    def after_worker_group_start(self, worker_group: "WorkerGroup"):
        # Propagate DataContext
        def _propagate_data_context(ctx: DataContext):
            DataContext._set_current(ctx)

        worker_group.execute(
            _propagate_data_context,
            self._data_context,
        )

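
As a side note, the `GenDataset` entries above are resolved lazily in `before_init_train_context`: a zero-argument callable is invoked at configuration time, while an already-built dataset is passed through unchanged. A minimal sketch of that resolution step, with plain Python objects standing in for `ray.data.Dataset`:

# Hypothetical stand-ins for ray.data.Dataset objects and factories.
datasets = {
    "train": lambda: list(range(8)),  # factory: built only when resolved
    "val": [0, 1, 2, 3],              # already-built dataset
}

resolved = {k: v() if callable(v) else v for k, v in datasets.items()}
print(resolved["train"])  # [0, 1, 2, 3, 4, 5, 6, 7]
print(resolved["val"])    # [0, 1, 2, 3]
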
.venv/lib/python3.11/site-packages/ray/train/v2/_internal/callbacks/metrics.py
ADDED
@@ -0,0 +1,250 @@
import threading
import time
from contextlib import contextmanager
from dataclasses import asdict, dataclass, field, fields
from typing import Dict, Optional

from ray.train.v2._internal.execution.callback import (
    ControllerCallback,
    TrainContextCallback,
    WorkerCallback,
    WorkerGroupCallback,
)
from ray.train.v2._internal.execution.context import TrainRunContext, get_train_context
from ray.train.v2._internal.util import time_monotonic
from ray.util.metrics import Gauge

# Prometheus tag keys for the worker and controller metrics.
RUN_NAME_TAG_KEY = "ray_train_run_name"
WORKER_WORLD_RANK_TAG_KEY = "ray_train_worker_world_rank"


@dataclass
class ControllerMetrics:
    """A list of Train controller metrics.

    Metric metadata attributes:
    - description (required): A human-readable description of the metric, also used as
        the chart description on the Ray Train dashboard.
    """

    train_worker_group_start_total_time_s: float = field(
        default=0.0,
        metadata={
            "description": (
                "Cumulative time in seconds to start worker groups in the Train job."
            ),
        },
    )

    train_worker_group_shutdown_total_time_s: float = field(
        default=0.0,
        metadata={
            "description": (
                "Cumulative time in seconds to shutdown worker groups in the Train job."
            ),
        },
    )


@dataclass
class WorkerMetrics:
    """A list of Train worker metrics.

    Metric metadata attributes:
    - description (required): A human-readable description of the metric, also used as
        the chart description on the Ray Train dashboard.
    """

    train_report_total_blocked_time_s: float = field(
        default=0.0,
        metadata={
            "description": (
                "Cumulative time in seconds to report a checkpoint to the storage."
            ),
        },
    )


class ControllerMetricsCallback(ControllerCallback, WorkerGroupCallback):
    # Interval for pushing metrics to Prometheus.
    LOCAL_METRICS_PUSH_INTERVAL_S: float = 5.0
    CONTROLLER_TAG_KEYS = (RUN_NAME_TAG_KEY,)

    def __init__(self, train_run_context: TrainRunContext):
        """
        This callback is initialized on the driver process and then passed to the
        controller. This callback collects metrics from the controller actor as well
        as the metrics related to the worker groups.
        """
        self._run_name = train_run_context.get_run_config().name
        self._thread: Optional[threading.Thread] = None
        self._thread_stop_event: Optional[threading.Event] = None
        self._metrics: Optional[ControllerMetrics] = None
        self._metrics_lock: Optional[threading.Lock] = None
        self._controller_tag: Dict[str, str] = {}
        self._metrics_gauges: Dict[str, Gauge] = {}

    def _create_prometheus_controller_metrics(self) -> Dict[str, Gauge]:
        """Create Prometheus controller metrics for the ControllerMetrics dataclass."""
        metrics = {}
        for _field in fields(ControllerMetrics):
            metric_description = _field.metadata.get("description")
            metrics[_field.name] = Gauge(
                _field.name,
                description=metric_description,
                tag_keys=self.CONTROLLER_TAG_KEYS,
            )
        return metrics

    def after_controller_start(self):
        """
        Create a thread to periodically push local metrics to the gauges
        after the train controller starts.
        """
        self._controller_tag = {
            RUN_NAME_TAG_KEY: self._run_name,
        }
        self._thread_stop_event = threading.Event()
        self._metrics_lock = threading.Lock()
        self._metrics = ControllerMetrics()
        self._metrics_gauges = self._create_prometheus_controller_metrics()

        def push_local_metrics():
            while not self._thread_stop_event.is_set():
                with self._metrics_lock:
                    metrics_dict = asdict(self._metrics)
                    for metric_name, metric_value in metrics_dict.items():
                        self._metrics_gauges[metric_name].set(
                            metric_value, self._controller_tag
                        )
                time.sleep(ControllerMetricsCallback.LOCAL_METRICS_PUSH_INTERVAL_S)

        assert not self._thread
        self._thread = threading.Thread(target=push_local_metrics, daemon=True)
        self._thread.start()

    def before_controller_shutdown(self):
        """
        Stop the thread that pushes local metrics to the gauges before the
        controller shuts down.
        """
        # Stop the thread that pushes local metrics to the metrics gauges.
        assert not self._thread_stop_event.is_set()
        self._thread_stop_event.set()
        # Reset the metrics to their default values.
        for _field in fields(self._metrics):
            self._metrics_gauges[_field.name].set(_field.default, self._controller_tag)

    @contextmanager
    def on_worker_group_start(self):
        """
        Context manager to measure the time taken to start a worker group.
        """
        start_time_s = time_monotonic()
        yield
        elapsed_time_s = time_monotonic() - start_time_s
        with self._metrics_lock:
            self._metrics.train_worker_group_start_total_time_s += elapsed_time_s

    @contextmanager
    def on_worker_group_shutdown(self):
        """
        Context manager to measure the time taken to shut down a worker group.
        """
        start_time_s = time_monotonic()
        yield
        elapsed_time_s = time_monotonic() - start_time_s
        with self._metrics_lock:
            self._metrics.train_worker_group_shutdown_total_time_s += elapsed_time_s


class WorkerMetricsCallback(WorkerCallback, TrainContextCallback):
    # Interval for pushing metrics to Prometheus.
    LOCAL_METRICS_PUSH_INTERVAL_S: float = 5.0
    WORKER_TAG_KEYS = (RUN_NAME_TAG_KEY, WORKER_WORLD_RANK_TAG_KEY)

    def __init__(self, train_run_context: TrainRunContext):
        """
        This callback is initialized on the driver process and then passed to the
        workers. When adding more class attributes, make sure the attributes are
        serializable (picklable).

        TODO: Make Callbacks factory methods so that when they are initialized on the
        driver process, we do not need to worry about pickling the callback instances.
        """
        self._run_name = train_run_context.get_run_config().name
        self._thread: Optional[threading.Thread] = None
        self._thread_stop_event: Optional[threading.Event] = None
        self._metrics_lock: Optional[threading.Lock] = None
        self._metrics: Optional[WorkerMetrics] = None
        self._worker_tag: Dict[str, str] = {}
        self._metrics_gauges: Dict[str, Gauge] = {}

    def _create_prometheus_worker_metrics(self) -> Dict[str, Gauge]:
        """Create Prometheus worker metrics for the TrainMetrics dataclass."""
        metrics = {}
        for _field in fields(self._metrics):
            metric_description = _field.metadata.get("description")
            metrics[_field.name] = Gauge(
                _field.name,
                description=metric_description,
                tag_keys=self.WORKER_TAG_KEYS,
            )
        return metrics

    def after_init_train_context(self):
        """
        Create a thread to periodically push local metrics to the gauges
        after the train context is initialized.

        Note:
            This method should be called after the train context is initialized on
            each of the workers. The thread should not be created in the `__init__`
            method, which is called on the train driver process.
        """
        self._worker_tag = {
            RUN_NAME_TAG_KEY: self._run_name,
            WORKER_WORLD_RANK_TAG_KEY: str(get_train_context().get_world_rank()),
        }
        self._thread_stop_event = threading.Event()
        self._metrics_lock = threading.Lock()
        self._metrics = WorkerMetrics()
        self._metrics_gauges = self._create_prometheus_worker_metrics()

        def push_local_metrics():
            while not self._thread_stop_event.is_set():
                with self._metrics_lock:
                    metrics_dict = asdict(self._metrics)
                    for metric_name, metric_value in metrics_dict.items():
                        self._metrics_gauges[metric_name].set(
                            metric_value, self._worker_tag
                        )
                time.sleep(WorkerMetricsCallback.LOCAL_METRICS_PUSH_INTERVAL_S)

        assert not self._thread
        self._thread = threading.Thread(target=push_local_metrics, daemon=True)
        self._thread.start()

    def before_worker_shutdown(self):
        """
        Stop the thread that pushes local metrics to the metrics gauges before
        the worker group shuts down.
        """
        # Stop the thread that pushes local metrics to the gauges.
        assert not self._thread_stop_event.is_set()
        self._thread_stop_event.set()
        # Reset the metrics to their default values.
        for _field in fields(self._metrics):
            self._metrics_gauges[_field.name].set(_field.default, self._worker_tag)

    @contextmanager
    def on_report(self):
        """
        Context manager to measure the time taken to report a checkpoint to the storage.
        """
        start_time_s = time_monotonic()
        yield
        elapsed_time_s = time_monotonic() - start_time_s
        with self._metrics_lock:
            self._metrics.train_report_total_blocked_time_s += elapsed_time_s

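
The `on_worker_group_start`, `on_worker_group_shutdown`, and `on_report` hooks above all follow the same pattern: time the body of a `with` block and fold the elapsed time into a running counter under a lock. A standalone sketch of that pattern, with `time.monotonic` standing in for the internal `time_monotonic` helper:

import threading
import time
from contextlib import contextmanager

_lock = threading.Lock()
_total_blocked_time_s = 0.0


@contextmanager
def timed_section():
    # Accumulate wall-clock time spent inside the `with` body.
    global _total_blocked_time_s
    start = time.monotonic()
    yield
    elapsed = time.monotonic() - start
    with _lock:
        _total_blocked_time_s += elapsed


with timed_section():
    time.sleep(0.1)  # stands in for a blocking report or startup call

print(f"accumulated: {_total_blocked_time_s:.2f}s")
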
.venv/lib/python3.11/site-packages/ray/train/v2/_internal/callbacks/user_callback.py
ADDED
@@ -0,0 +1,50 @@
from typing import Any, Dict, List, Optional

from ray.train import Checkpoint
from ray.train.v2._internal.execution.callback import (
    ReportCallback,
    WorkerGroupCallback,
)
from ray.train.v2._internal.execution.context import TrainRunContext
from ray.train.v2._internal.execution.worker_group import WorkerGroupStatus
from ray.train.v2.api.callback import UserCallback


class UserCallbackHandler(WorkerGroupCallback, ReportCallback):
    """Responsible for calling methods of subscribers implementing
    the `UserCallback` interface.
    """

    def __init__(
        self, user_callbacks: List[UserCallback], train_run_context: TrainRunContext
    ):
        self._user_callbacks = user_callbacks
        self._train_run_context = train_run_context

    # --------------------------
    # ReportCallback
    # --------------------------

    def after_report(
        self, metrics: List[Dict[str, Any]], checkpoint: Optional[Checkpoint]
    ):
        for user_callback in self._user_callbacks:
            user_callback.after_report(
                run_context=self._train_run_context,
                metrics=metrics,
                checkpoint=checkpoint,
            )

    # --------------------------
    # WorkerGroupCallback
    # --------------------------

    def after_worker_group_poll_status(self, worker_group_status: WorkerGroupStatus):
        if not worker_group_status.errors:
            return

        for user_callback in self._user_callbacks:
            user_callback.after_exception(
                run_context=self._train_run_context,
                worker_exceptions=worker_group_status.errors,
            )

.venv/lib/python3.11/site-packages/ray/train/v2/_internal/callbacks/working_dir_setup.py
ADDED
@@ -0,0 +1,24 @@
import logging
import os

from ray.train.v2._internal.execution.callback import WorkerGroupCallback
from ray.train.v2._internal.execution.context import get_train_context
from ray.train.v2._internal.execution.worker_group import WorkerGroup

logger = logging.getLogger(__name__)


class WorkingDirectorySetupCallback(WorkerGroupCallback):
    def after_worker_group_start(self, worker_group: WorkerGroup):
        def chdir_to_working_dir() -> None:
            """Create the local working directory for the experiment."""
            local_working_directory = (
                get_train_context().get_storage().local_working_directory
            )
            os.makedirs(local_working_directory, exist_ok=True)
            logger.debug(
                f"Changing the working directory to: {local_working_directory}"
            )
            os.chdir(local_working_directory)

        worker_group.execute(chdir_to_working_dir)

.venv/lib/python3.11/site-packages/ray/train/v2/_internal/constants.py
ADDED
@@ -0,0 +1,84 @@
import os
from typing import Dict

from ray._private.ray_constants import env_bool, env_set_by_user

# Unsupported configs can use this value to detect if the user has set it.
_UNSUPPORTED = "UNSUPPORTED"
_DEPRECATED = "DEPRECATED"

# The name of the file that is used to validate the storage.
VALIDATE_STORAGE_MARKER_FILENAME = ".validate_storage_marker"
# The name of the file that is used to store the checkpoint manager snapshot.
CHECKPOINT_MANAGER_SNAPSHOT_FILENAME = "checkpoint_manager_snapshot.json"


# =====================
# Environment Variables
# =====================

# Polling interval for the Train controller.
# This determines how many seconds the controller will wait between
# polling the worker group for its status.
HEALTH_CHECK_INTERVAL_S_ENV_VAR = "RAY_TRAIN_HEALTH_CHECK_INTERVAL_S"
DEFAULT_HEALTH_CHECK_INTERVAL_S: float = 2.0

# The time in seconds a worker health check must be hanging for
# before the controller marks the worker as dead and handles the failure.
WORKER_HEALTH_CHECK_TIMEOUT_S_ENV_VAR = "RAY_TRAIN_WORKER_HEALTH_CHECK_TIMEOUT_S"
DEFAULT_WORKER_HEALTH_CHECK_TIMEOUT_S: float = 10 * 60

# Timeout in seconds for the worker group to start.
WORKER_GROUP_START_TIMEOUT_S_ENV_VAR = "RAY_TRAIN_WORKER_GROUP_START_TIMEOUT_S"
DEFAULT_WORKER_GROUP_START_TIMEOUT_S: float = 30.0

# Timeout in seconds for `ray.train.report` to block on synchronization barriers,
# after which a timeout error will be raised.
REPORT_BARRIER_TIMEOUT_S_ENV_VAR = "RAY_TRAIN_REPORT_BARRIER_TIMEOUT_S"
DEFAULT_REPORT_BARRIER_TIMEOUT_S: float = 60 * 30
# Time in seconds for `ray.train.report` to log a warning if it is waiting for sync
# actor notification of releasing.
REPORT_BARRIER_WARN_INTERVAL_S_ENV_VAR = "RAY_TRAIN_REPORT_BARRIER_WARN_INTERVAL_S"
DEFAULT_REPORT_BARRIER_WARN_INTERVAL_S: float = 60

# The environment variable to enable the Ray Train Metrics.
METRICS_ENABLED_ENV_VAR = "RAY_TRAIN_METRICS_ENABLED"

# Environment variable to enable the print function patching.
ENABLE_PRINT_PATCH_ENV_VAR = "RAY_TRAIN_ENABLE_PRINT_PATCH"
DEFAULT_ENABLE_PRINT_PATCH = "1"

# Whether or not to run the controller as an actor.
RUN_CONTROLLER_AS_ACTOR_ENV_VAR = "RAY_TRAIN_RUN_CONTROLLER_AS_ACTOR"
DEFAULT_RUN_CONTROLLER_AS_ACTOR = "1"

# V2 feature flag.
V2_ENABLED_ENV_VAR = "RAY_TRAIN_V2_ENABLED"


def is_v2_enabled() -> bool:
    return env_bool(V2_ENABLED_ENV_VAR, False)


ENV_VARS_TO_PROPAGATE = {
    V2_ENABLED_ENV_VAR,
    HEALTH_CHECK_INTERVAL_S_ENV_VAR,
    WORKER_HEALTH_CHECK_TIMEOUT_S_ENV_VAR,
    WORKER_GROUP_START_TIMEOUT_S_ENV_VAR,
    ENABLE_PRINT_PATCH_ENV_VAR,
}


def get_env_vars_to_propagate() -> Dict[str, str]:
    """Returns a dictionary of environment variables that should be propagated
    from the driver to the controller, and then from the controller
    to each training worker.

    This way, users only need to set environment variables in one place
    when launching the script instead of needing to manually set a runtime environment.
    """
    env_vars = {}
    for env_var in ENV_VARS_TO_PROPAGATE:
        if env_set_by_user(env_var):
            env_vars[env_var] = os.environ[env_var]
    return env_vars

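
For context, `get_env_vars_to_propagate` only forwards variables the user explicitly set, so exporting, for example, `RAY_TRAIN_V2_ENABLED=1` on the driver is enough for it to reach the controller and the training workers. A rough standalone sketch of that filter, approximating the private `env_set_by_user` helper with a simple membership check (variable names are taken from the file above):

import os

ENV_VARS_TO_PROPAGATE = {
    "RAY_TRAIN_V2_ENABLED",
    "RAY_TRAIN_HEALTH_CHECK_INTERVAL_S",
    "RAY_TRAIN_WORKER_HEALTH_CHECK_TIMEOUT_S",
    "RAY_TRAIN_WORKER_GROUP_START_TIMEOUT_S",
    "RAY_TRAIN_ENABLE_PRINT_PATCH",
}

os.environ["RAY_TRAIN_V2_ENABLED"] = "1"  # set by the user before launching

# Only variables actually present in the environment are forwarded.
propagated = {k: os.environ[k] for k in ENV_VARS_TO_PROPAGATE if k in os.environ}
print(propagated)  # {'RAY_TRAIN_V2_ENABLED': '1'} plus anything else the user exported
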
.venv/lib/python3.11/site-packages/ray/train/v2/_internal/exceptions.py
ADDED
@@ -0,0 +1,170 @@
import os
from typing import Dict, List, Optional

from ray.train.v2._internal.constants import (
    DEFAULT_WORKER_GROUP_START_TIMEOUT_S,
    DEFAULT_WORKER_HEALTH_CHECK_TIMEOUT_S,
    REPORT_BARRIER_TIMEOUT_S_ENV_VAR,
    WORKER_GROUP_START_TIMEOUT_S_ENV_VAR,
    WORKER_HEALTH_CHECK_TIMEOUT_S_ENV_VAR,
)


# TODO: Distinguish between user and system exceptions.
class RayTrainError(Exception):
    """Base class for all Ray Train exceptions."""


class WorkerHealthCheckTimeoutError(RayTrainError):
    """Exception raised when a worker health check hangs for long enough."""

    def __init__(self, message):
        timeout = os.getenv(
            WORKER_HEALTH_CHECK_TIMEOUT_S_ENV_VAR, DEFAULT_WORKER_HEALTH_CHECK_TIMEOUT_S
        )
        message += (
            f"\nSet the {WORKER_HEALTH_CHECK_TIMEOUT_S_ENV_VAR} "
            "environment variable to increase the timeout "
            f"(current value: {timeout} seconds)."
        )
        super().__init__(message)


class WorkerHealthCheckFailedError(RayTrainError):
    """Exception raised when a worker health check fails."""

    def __init__(self, message, failure: Exception):
        super().__init__(message)
        self._message = message
        self.health_check_failure = failure

    def __reduce__(self):
        return (self.__class__, (self._message, self.health_check_failure))


class TrainingFailedError(RayTrainError):
    """Exception raised when training fails."""

    def __init__(self, worker_failures: Dict[int, Exception]):
        super().__init__(
            "Training failed due to worker errors. "
            "Please inspect the error logs above, "
            "or access the latest worker failures in this "
            "exception's `worker_failures` attribute."
        )
        self.worker_failures = worker_failures

    def __reduce__(self):
        return (self.__class__, (self.worker_failures,))


class WorkerGroupStartupTimeoutError(RayTrainError):
    """Exception raised when the worker group startup times out.

    Example scenario: 4 GPUs are detected in the cluster, but when the workers
    are actually scheduled, one of the nodes goes down and only 3 GPUs are
    available. One of the worker tasks may be stuck pending until a timeout is reached.
    """

    def __init__(self, num_workers: int):
        timeout = float(
            os.environ.get(
                WORKER_GROUP_START_TIMEOUT_S_ENV_VAR,
                DEFAULT_WORKER_GROUP_START_TIMEOUT_S,
            )
        )
        self.num_workers = num_workers
        super().__init__(
            f"The worker group startup timed out after {timeout} seconds waiting "
            f"for {num_workers} workers. "
            "Potential causes include: "
            "(1) temporary insufficient cluster resources while waiting for "
            "autoscaling (ignore this warning in this case), "
            "(2) an infeasible resource request where the provided `ScalingConfig` "
            "cannot be satisfied, "
            "and (3) transient network issues. "
            f"Set the {WORKER_GROUP_START_TIMEOUT_S_ENV_VAR} "
            "environment variable to increase the timeout."
        )

    def __reduce__(self):
        return (self.__class__, (self.num_workers,))


class WorkerGroupStartupFailedError(RayTrainError):
    """Exception raised when the worker group fails to start.

    Example scenario: A worker is scheduled onto a node that dies while
    the worker actor is initializing.
    """


class CheckpointManagerInitializationError(RayTrainError):
    """Exception raised when the checkpoint manager fails to initialize from a snapshot.

    Example scenarios:
    1. The checkpoint manager snapshot version is old and
        incompatible with the current version of Ray Train.
    2. The checkpoint manager snapshot JSON file is corrupted.
    3. The checkpoint manager snapshot references checkpoints that cannot be found
        in the run storage path.
    """


class CollectiveTimeoutError(RayTrainError):
    """Exception raised when an internal Ray Train collective operation of
    the worker group times out.
    """


class BroadcastCollectiveTimeoutError(CollectiveTimeoutError):
    """Exception raised when the broadcast operation times out.

    There are two main timeout examples:
    1. If not all workers call `ray.train.report`, the entire worker group will
        hang until the timeout before raising. This prevents indefinite worker
        group hangs.
    2. If a worker is slow in the training loop and fails to reach the broadcast
        in time, the collective will time out.
    """

    def __init__(
        self, time_elapsed: Optional[float], missing_ranks: List[int], timeout_s: float
    ):
        self._time_elapsed = time_elapsed
        self._missing_ranks = missing_ranks
        self._timeout_s = timeout_s

        message = (
            f"The broadcast operation timed out after {time_elapsed:.2f} seconds. "
            "Please make sure all worker ranks call `ray.train.report`. \n"
            f"The following ranks have not called it: {missing_ranks}\n"
            f"You can set this timeout with the {REPORT_BARRIER_TIMEOUT_S_ENV_VAR} "
            f"environment variable (current value: {timeout_s:.2f} s)."
        )
        super().__init__(message)

    def __reduce__(self):
        return (
            self.__class__,
            (self._time_elapsed, self._missing_ranks, self._timeout_s),
        )


class UserExceptionWithTraceback(RayTrainError):
    """This class wraps a user code exception raised on the worker
    with its original traceback string, for logging and debugging purposes.

    This is needed because the original exception traceback is not serialized
    with the exception when it is *returned* back to the main process.
    """

    def __init__(self, exc: BaseException, traceback_str: str):
        self._base_exc = exc
        self._traceback_str = traceback_str

    def __reduce__(self):
        return (self.__class__, (self._base_exc, self._traceback_str))

    def __str__(self):
        return self._traceback_str

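
A short sketch of how a caller might surface per-rank failures from `TrainingFailedError`, assuming the module path added in this diff is importable from the installed Ray build; the exception is constructed directly here just to show the `worker_failures` attribute:

from ray.train.v2._internal.exceptions import TrainingFailedError

try:
    # A failed driver-side run would raise this; constructed directly for illustration.
    raise TrainingFailedError(worker_failures={0: RuntimeError("worker 0 OOM")})
except TrainingFailedError as e:
    for rank, err in e.worker_failures.items():
        print(f"rank {rank} failed with: {err!r}")
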
.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/__init__.py
ADDED
File without changes.

.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/__pycache__/callback.cpython-311.pyc
ADDED
Binary file (7.91 kB).

.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/__pycache__/context.cpython-311.pyc
ADDED
Binary file (13.5 kB).

.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/__pycache__/controller.cpython-311.pyc
ADDED
Binary file (18.3 kB).

.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/__pycache__/storage.cpython-311.pyc
ADDED
Binary file (28.2 kB).
.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/callback.py
ADDED
@@ -0,0 +1,140 @@
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Dict, List, Optional

from ray.train.v2.api.callback import RayTrainCallback
from ray.util.annotations import DeveloperAPI

if TYPE_CHECKING:
    from ray.train import Checkpoint
    from ray.train.v2._internal.execution.controller import TrainControllerState
    from ray.train.v2._internal.execution.failure_handling import FailureDecision
    from ray.train.v2._internal.execution.scaling_policy import ScalingDecision
    from ray.train.v2._internal.execution.worker_group import (
        WorkerGroup,
        WorkerGroupStatus,
    )


@DeveloperAPI
class WorkerGroupCallback(RayTrainCallback):
    def before_init_train_context(
        self, worker_group: "WorkerGroup"
    ) -> Dict[str, List[Any]]:
        """Called before initializing the TrainContext for the worker_group.

        Return:
            A dictionary of additional arguments for TrainContext.
            The key is the argument name and the value is a list of argument values
            to pass to the TrainContext constructor of each worker in the worker group.
        """
        return {}

    @contextmanager
    def on_worker_group_start(self):
        yield

    def after_worker_group_start(self, worker_group: "WorkerGroup"):
        """Called after the worker group actors are initialized.
        All workers should be ready to execute tasks."""
        pass

    def after_worker_group_training_start(self, worker_group: "WorkerGroup"):
        pass

    @contextmanager
    def on_worker_group_shutdown(self):
        yield

    def before_worker_group_shutdown(self, worker_group: "WorkerGroup"):
        """Called before the worker group is shut down.
        Workers may be dead at this point due to actor failures, so this method
        should catch and handle exceptions if attempting to execute tasks."""
        pass

    def after_worker_group_poll_status(self, worker_group_status: "WorkerGroupStatus"):
        pass


@DeveloperAPI
class ControllerCallback(RayTrainCallback):
    def after_controller_start(self):
        """Called immediately after `TrainController.run` is called,
        before the control loop starts executing."""
        pass

    def before_controller_shutdown(self):
        """Called before `TrainController.run` exits,
        after the control loop has exited."""
        pass

    def after_controller_state_update(
        self,
        previous_state: "TrainControllerState",
        current_state: "TrainControllerState",
    ):
        """Called whenever the controller state is updated."""
        pass

    def before_controller_execute_failure_decision(
        self,
        failure_decision: "FailureDecision",
        worker_group_status: "WorkerGroupStatus",
    ):
        """Called before the controller executes a failure decision."""
        pass

    def before_controller_execute_scaling_decision(
        self,
        scaling_decision: "ScalingDecision",
        worker_group_status: "WorkerGroupStatus",
    ):
        """Called before the controller executes a scaling decision."""
        pass


@DeveloperAPI
class ReportCallback(RayTrainCallback):
    def after_report(
        self, metrics: List[Dict[str, Any]], checkpoint: Optional["Checkpoint"]
    ):
        """Called after all workers have reported a training result.

        Note that this differs from `after_worker_group_poll_status`,
        which may only contain a subset of workers that have reported.
        For example, if only rank 0 is performing checkpointing, then
        rank 0 would report a training result the slowest.
        """
        pass


@DeveloperAPI
class WorkerCallback(RayTrainCallback):
    """
    Callbacks that are hooked to the worker events.

    These callbacks are created on the train driver process and then
    copied and passed to all the workers.
    The execution of these callbacks happens on each of the workers,
    not on the train driver process.
    """

    def after_init_train_context(self):
        pass

    def before_worker_shutdown(self):
        pass


@DeveloperAPI
class TrainContextCallback(RayTrainCallback):
    """
    Callbacks that are hooked to the train context events.

    These callbacks are created on the train driver process and then
    copied and passed to all the workers.
    The execution of these callbacks happens on the train context of the workers.
    """

    @contextmanager
    def on_report(self):
        yield

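
A minimal sketch of a custom `WorkerGroupCallback` built on the hooks above, just logging worker-group lifecycle events; the hook names and import path are taken from this diff, while the class itself is hypothetical:

import logging

from ray.train.v2._internal.execution.callback import WorkerGroupCallback

logger = logging.getLogger(__name__)


class LoggingWorkerGroupCallback(WorkerGroupCallback):
    """Logs worker group lifecycle events; does not change controller behavior."""

    def after_worker_group_start(self, worker_group):
        logger.info("Worker group started with %d workers", len(worker_group))

    def before_worker_group_shutdown(self, worker_group):
        logger.info("Worker group is about to shut down")
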
.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/checkpoint/__init__.py
ADDED
File without changes.

.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/checkpoint/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (216 Bytes).

.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/checkpoint/__pycache__/checkpoint_manager.cpython-311.pyc
ADDED
Binary file (12.8 kB).

.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/checkpoint/__pycache__/report_handler.cpython-311.pyc
ADDED
Binary file (5.79 kB).

.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/checkpoint/__pycache__/sync_actor.cpython-311.pyc
ADDED
Binary file (10.5 kB).
.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/checkpoint/checkpoint_manager.py
ADDED
|
@@ -0,0 +1,271 @@
import logging
from typing import Any, Dict, List, Optional

from ray.air.config import CheckpointConfig
from ray.train._checkpoint import Checkpoint
from ray.train._internal.checkpoint_manager import (
    _CheckpointManager,
    _insert_into_sorted_list,
)
from ray.train._internal.session import _TrainingResult
from ray.train.v2._internal.exceptions import CheckpointManagerInitializationError
from ray.train.v2._internal.execution.callback import ReportCallback
from ray.train.v2._internal.execution.context import StorageContext
from ray.train.v2._internal.execution.storage import _delete_fs_path, _exists_at_fs_path

try:
    from pydantic import BaseModel
    from pydantic_core import from_json
except (ImportError, ModuleNotFoundError) as exc:
    raise ImportError(
        "`ray.train.v2` requires the pydantic package, which is missing. "
        "Run the following command to fix this: `pip install pydantic`"
    ) from exc


logger = logging.getLogger(__name__)


class _TrainingResultState(BaseModel):
    # Increment version if the schema changes
    version: int = 0
    checkpoint_dir_name: str
    metrics: dict


class _CheckpointManagerState(BaseModel):
    # Increment version if the schema changes
    version: int = 0
    checkpoint_results: List[_TrainingResultState]
    latest_checkpoint_result: Optional[_TrainingResultState]


def _get_training_result_from_state(
    state: _TrainingResultState,
    storage_context: StorageContext,
) -> _TrainingResult:
    """Get a TrainingResult object from a Pydantic state object."""
    return _TrainingResult(
        checkpoint=Checkpoint(
            path=storage_context.build_checkpoint_path_from_name(
                state.checkpoint_dir_name
            ),
            filesystem=storage_context.storage_filesystem,
        ),
        metrics=state.metrics,
    )


def _get_state_from_training_result(
    training_result: _TrainingResult,
    storage_context: StorageContext,
) -> _TrainingResultState:
    """Get a Pydantic state object from a TrainingResult object."""
    return _TrainingResultState(
        checkpoint_dir_name=storage_context.extract_checkpoint_dir_name_from_path(
            training_result.checkpoint.path
        ),
        metrics=training_result.metrics,
    )


class CheckpointManager(_CheckpointManager, ReportCallback):
    def __init__(
        self,
        checkpoint_config: CheckpointConfig,
        storage_context: StorageContext,
    ):
        self._storage_context = storage_context
        self._checkpoint_config = checkpoint_config
        super().__init__(checkpoint_config)
        # If the snapshot is found, the checkpoint manager will restore its state.
        self._maybe_load_state_from_storage()

    def register_checkpoint(self, checkpoint_result: _TrainingResult):
        """Register a new checkpoint and add it to bookkeeping.

        This method will register a new checkpoint and add it to the internal
        bookkeeping logic. This means the checkpoint manager will decide if
        this checkpoint should be kept, and if older or worse performing
        checkpoints should be deleted.

        Args:
            checkpoint_result: Tracked checkpoint object to add to bookkeeping.
        """
        self._latest_checkpoint_result = checkpoint_result

        if self._checkpoint_config.checkpoint_score_attribute is not None:
            # If we're ordering by a score, insert the checkpoint
            # so that the list remains sorted.
            _insert_into_sorted_list(
                self._checkpoint_results,
                checkpoint_result,
                key=self._get_checkpoint_score,
            )
        else:
            # If no metric is provided, just append (ordering by time of registration).
            self._checkpoint_results.append(checkpoint_result)

        results_to_delete = {}
        if self._checkpoint_config.num_to_keep is not None:
            # Delete the bottom (N - K) checkpoints
            worst_results = set(
                self._checkpoint_results[: -self._checkpoint_config.num_to_keep]
            )
            # Except for the latest checkpoint.
            results_to_delete = worst_results - {self._latest_checkpoint_result}

            # Update internal state before actually deleting them.
            self._checkpoint_results = [
                checkpoint_result
                for checkpoint_result in self._checkpoint_results
                if checkpoint_result not in results_to_delete
            ]

        # Save the checkpoint manager state to storage.
        # Note: We save the state before deleting the old checkpoints.
        # If deletion happens first and the process crashes, our snapshot
        # may point to some stale checkpoints that are already deleted.
        # TODO: Make this writing operation non-blocking.
        self._write_state_to_storage()

        # Delete the old checkpoints.
        for checkpoint_result in results_to_delete:
            checkpoint = checkpoint_result.checkpoint
            logger.debug("Deleting checkpoint: %s", checkpoint)
            _delete_fs_path(fs=checkpoint.filesystem, fs_path=checkpoint.path)

    # --------------------------
    # CheckpointManager state
    # --------------------------

    def _save_state(self) -> str:
        """Save the checkpoint manager state to a JSON str."""

        checkpoint_results = [
            _get_state_from_training_result(checkpoint_result, self._storage_context)
            for checkpoint_result in self._checkpoint_results
        ]

        latest_checkpoint_result = (
            _get_state_from_training_result(
                self._latest_checkpoint_result, self._storage_context
            )
            if self._latest_checkpoint_result is not None
            else None
        )

        manager_snapshot = _CheckpointManagerState(
            checkpoint_results=checkpoint_results,
            latest_checkpoint_result=latest_checkpoint_result,
        )
        return manager_snapshot.model_dump_json()

    def _load_state(self, json_state: str):
        """Load the checkpoint manager state from a JSON str."""
        try:
            manager_snapshot = _CheckpointManagerState.model_validate(
                from_json(json_state)
            )
        except Exception as e:
            raise CheckpointManagerInitializationError(repr(e)) from e
        self._assert_checkpoints_exist()

        self._checkpoint_results = [
            _get_training_result_from_state(
                training_result_state, self._storage_context
            )
            for training_result_state in manager_snapshot.checkpoint_results
        ]

        self._latest_checkpoint_result = (
            _get_training_result_from_state(
                manager_snapshot.latest_checkpoint_result, self._storage_context
            )
            if manager_snapshot.latest_checkpoint_result is not None
            else None
        )

    def _maybe_load_state_from_storage(self):
        """Load the checkpoint manager state from storage.
        If no snapshot is found, start with a clean state.
        """
        if not _exists_at_fs_path(
            fs=self._storage_context.storage_filesystem,
            fs_path=self._storage_context.checkpoint_manager_snapshot_path,
        ):
            logger.debug(
                "No checkpoint manager snapshot found. "
                "No checkpoint will be available via `ray.train.get_checkpoint`, "
                "so training will start from scratch."
            )
            return
        with self._storage_context.storage_filesystem.open_input_stream(
            self._storage_context.checkpoint_manager_snapshot_path
        ) as f:
            logger.info(
                "A run snapshot was found in storage folder at: "
                f"'{self._storage_context.experiment_fs_path}'\n"
                "This snapshot contains a list of checkpoints reported via "
                "`ray.train.report` and will be loaded. "
                "This allows the latest checkpoint found in the snapshot to be "
                "accessible within your training function via "
                "`ray.train.get_checkpoint`.\n"
                "If you meant to start a brand new training job without any "
                "information about previous checkpoints found in this directory, "
                "please configure a new, unique `RunConfig(name)` or delete the "
                f"existing folder at '{self._storage_context.experiment_fs_path}'."
            )
            json_state = f.read().decode("utf-8")
        self._load_state(json_state)

    def _write_state_to_storage(self):
        """Write the checkpoint manager state to storage."""
        checkpoint_manager_snapshot = self._save_state()
        with self._storage_context.storage_filesystem.open_output_stream(
            self._storage_context.checkpoint_manager_snapshot_path
        ) as f:
            f.write(checkpoint_manager_snapshot.encode("utf-8"))

    def _assert_checkpoints_exist(self):
        """Validate the checkpoint manager state.

        This method will validate the checkpoint manager state by checking that
        the checkpoints specified in the manager snapshot are consistent with the
        checkpoint folders on the experiment storage filesystem.

        Raises:
            CheckpointManagerInitializationError: If the checkpoint manager snapshot
                is not consistent with the stored checkpoints.
        """
        for checkpoint_result in self._checkpoint_results:
            checkpoint = checkpoint_result.checkpoint
            assert checkpoint is not None
            if not _exists_at_fs_path(
                fs=checkpoint.filesystem, fs_path=checkpoint.path
            ):
                raise CheckpointManagerInitializationError(
                    message=(
                        "The run snapshot contains a reference to a checkpoint "
                        f"that does not exist anymore ({checkpoint}). You are "
                        "running in a corrupted run directory `experiment_fs_path`. "
                        "Please configure a new, unique `RunConfig(name)` "
                        "or delete the existing folder at "
                        f"`{self._storage_context.experiment_fs_path}`."
                    )
                )

    # --------------------------
    # ReportCallback
    # --------------------------

    def after_report(
        self, metrics: List[Dict[str, Any]], checkpoint: Optional[Checkpoint]
    ):
        if not checkpoint:
            return

        rank_0_metrics = metrics[0]
        self.register_checkpoint(
            _TrainingResult(checkpoint=checkpoint, metrics=rank_0_metrics)
        )
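Hypothetical standalone sketch (not from the diff) of the retention rule that `register_checkpoint` implements when `num_to_keep` is set: keep the best `num_to_keep` results by score, but never delete the most recently reported checkpoint. The helper name `retained` and the assumption that a higher score is better are illustrative only.

from typing import List, Tuple


def retained(
    results: List[Tuple[str, float]], num_to_keep: int
) -> List[Tuple[str, float]]:
    # `results` are (checkpoint_dir_name, score) pairs in registration order.
    ordered = sorted(results, key=lambda r: r[1])  # worst score first
    latest = results[-1]
    # Delete the bottom (N - K) results, but spare the latest checkpoint.
    to_delete = set(ordered[:-num_to_keep]) - {latest}
    return [r for r in results if r not in to_delete]


assert retained([("c1", 0.2), ("c2", 0.9), ("c3", 0.1)], num_to_keep=1) == [
    ("c2", 0.9),
    ("c3", 0.1),
]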
.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/checkpoint/report_handler.py
ADDED
@@ -0,0 +1,111 @@
from collections import deque
from typing import TYPE_CHECKING, Deque, List, Optional

from ray.train.v2._internal.execution.callback import (
    ReportCallback,
    WorkerGroupCallback,
)
from ray.train.v2._internal.execution.worker_group import WorkerGroup, WorkerGroupStatus

if TYPE_CHECKING:
    from ray.train._internal.session import _TrainingResult


class ReportCallbackHandler(WorkerGroupCallback):
    """Consolidate training results from multiple workers and call
    subscribers implementing the `ReportCallback` interface sequentially.
    """

    def __init__(self, report_callbacks: List[ReportCallback]):
        # Number of workers in the current worker group. It is initialized
        # to None, set when a new worker group starts, and reset to None
        # when the worker group shuts down, waiting to be updated again
        # when the next worker group starts.
        self._num_workers: Optional[int] = None
        # A list of queues holding training results from workers.
        self._training_result_queues: Optional[List[Deque["_TrainingResult"]]] = None

        self._report_callbacks = report_callbacks

    # --------------------------
    # WorkerGroupCallback
    # --------------------------

    def after_worker_group_poll_status(
        self, worker_group_status: WorkerGroupStatus
    ) -> None:
        """Handle training results as they roll in from worker status polls.

        Wait for all workers to report training results to collect
        a consolidated training result.
        """
        # Step 1: Internal state must already have been initialized by
        # `after_worker_group_start`, which runs when the worker group starts.
        assert (
            self._num_workers and self._training_result_queues
        ), "Internal state must be initialized by `after_worker_group_start` first."

        assert self._num_workers == worker_group_status.num_workers, (
            "The number of workers in the worker group has changed unexpectedly. "
            f"Expected: {self._num_workers}, got: {worker_group_status.num_workers}"
        )

        # Step 2: Update training_results_queues with poll_results.
        for i in range(self._num_workers):
            training_result = worker_group_status.worker_statuses[i].training_result
            if training_result:
                self._training_result_queues[i].append(training_result)

        # Directly return if any of the worker result queues are empty.
        if not all(self._training_result_queues):
            return

        training_results = [q.popleft() for q in self._training_result_queues]

        # Step 3: Consolidate the list of checkpoints into a single checkpoint.
        # Use the first checkpoint as the consolidated checkpoint.
        checkpoint_results = [
            tr for tr in training_results if tr.checkpoint is not None
        ]

        consolidated_checkpoint = None
        if checkpoint_results:
            # Double check the storage path of the checkpoints in the training results.
            unique_checkpoint_paths = {tr.checkpoint.path for tr in checkpoint_results}
            if len(unique_checkpoint_paths) > 1:
                # TODO: Support inconsistent checkpoint paths from workers
                # instead of hard raising an error. Maybe drop this iteration of
                # training results and continue with the next iteration.
                raise RuntimeError(
                    "The storage path of the checkpoints in the training results "
                    "is not the same. This means the checkpoints are not consistent. "
                    "Got a mix of the following checkpoint paths: "
                    f"{unique_checkpoint_paths}\n"
                    "This is unexpected -- please file a Github issue."
                )
            consolidated_checkpoint = checkpoint_results[0].checkpoint

        # Step 4: Invoke all dependent `ReportCallback`s.
        metrics_per_worker = [
            training_result.metrics for training_result in training_results
        ]
        for callback in self._report_callbacks:
            callback.after_report(
                metrics=metrics_per_worker,
                checkpoint=consolidated_checkpoint,
            )

    def after_worker_group_start(self, worker_group: WorkerGroup) -> None:
        """Handle worker group start. Initialize internal states."""
        self._num_workers = len(worker_group)
        self._training_result_queues = [deque() for _ in range(self._num_workers)]

    def before_worker_group_shutdown(self, worker_group: WorkerGroup) -> None:
        """Handle worker group shutdown. Clear internal states.

        None of the partial reported results are valid at this point.
        """
        self._num_workers = None
        self._training_result_queues = None
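Illustrative standalone sketch (not from the diff) of the consolidation rule above: a result batch is only consumed once every per-rank queue has at least one entry, so slow ranks gate the whole report.

from collections import deque

queues = [deque() for _ in range(3)]
queues[0].append({"loss": 0.5})
queues[1].append({"loss": 0.6})

# Rank 2 has not reported yet, so nothing is consumed.
assert not all(queues)

queues[2].append({"loss": 0.4})
if all(queues):
    batch = [q.popleft() for q in queues]
    print(batch)  # one metrics dict per rank, in rank order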
.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/checkpoint/sync_actor.py
ADDED
@@ -0,0 +1,190 @@
import asyncio
import logging
from contextlib import contextmanager
from typing import List, Optional, TypeVar

import ray
from ray.train.v2._internal.constants import (
    DEFAULT_REPORT_BARRIER_TIMEOUT_S,
    DEFAULT_REPORT_BARRIER_WARN_INTERVAL_S,
    REPORT_BARRIER_WARN_INTERVAL_S_ENV_VAR,
)
from ray.train.v2._internal.exceptions import BroadcastCollectiveTimeoutError

T = TypeVar("T", bound=Optional[object])
logger = logging.getLogger(__name__)


BROADCAST_PERIODIC_WARNING = """
`ray.train.report` has not been called by all {world_size} workers in the group.

The workers have been waiting for {max_time_elapsed_s:.2f} s for the following ranks
to join the `report` call: {missing_ranks}.

Please ensure that all workers call `ray.train.report` regardless of whether
they participate in checkpointing or not (e.g., pass `checkpoint=None` for ranks
that do not save a checkpoint). Also ensure that workers are not hanging on
other operations, causing them to miss this synchronization barrier.

You can set the {warn_interval_env_var} environment variable to change the frequency
of this warning (current value: {warn_interval_s} s).
"""


@ray.remote(num_cpus=0)  # type: ignore
class SynchronizationActor:
    """A Ray actor that synchronizes the workers in a distributed training job.

    This actor forms a synchronization barrier on a group of processes.
    Every time a worker calls the broadcast_from_rank_zero method,
    the counter is incremented. When the counter equals the world size,
    the actor notifies all the workers to continue.
    """

    def __init__(
        self,
        timeout_s: float = DEFAULT_REPORT_BARRIER_TIMEOUT_S,
        warn_interval_s: float = DEFAULT_REPORT_BARRIER_WARN_INTERVAL_S,
    ):
        self._counter: int = 0
        self._world_size: int = 0
        self._condition = asyncio.Condition()
        self._reduced_data = None
        # The times when workers from different ranks
        # enter the synchronization barrier.
        self._sync_start_times: List[Optional[float]] = []
        # The timeout in seconds for the synchronization barrier.
        self._timeout_s: float = timeout_s
        # The interval in seconds to log a warning when waiting for the barrier.
        self._warn_interval_s: float = warn_interval_s

    def get_counter(self):
        """Returns the current value of the counter."""
        return self._counter

    def get_world_size(self):
        """Returns the current value of the world_size."""
        return self._world_size

    def get_reduced_data(self):
        """Returns the current value of the reduced_data."""
        return self._reduced_data

    def _clear_states(self):
        """Clears the states of the actor. When the last worker has
        called the _clear_states method, the actor clears its states.
        """
        self._counter -= 1
        if self._counter == 0:
            self._reduced_data = None
            self._world_size = 0

    def _setup_or_validate_collective_op(self, world_size: int):
        """The setup method for the synchronization actor if it is not set up yet.
        It initializes the world size and the start times for the
        synchronization barrier.
        """
        if self._world_size == 0:
            self._world_size = world_size
            self._sync_start_times = [None] * world_size
        elif world_size != self._world_size:
            raise ValueError(
                "Expected all callers to provide the same world size. "
                f"Got {world_size} and expected {self._world_size}."
            )

    @contextmanager
    def _broadcast_collective_context_manager(
        self, world_rank: int, world_size: int, data: T
    ):
        """A context manager that ensures the synchronization barrier is lifted
        after the block of code is executed.
        """
        try:
            self._setup_or_validate_collective_op(world_size)
            if world_rank == 0:
                self._reduced_data = data
            if self._counter < self._world_size:
                self._counter += 1
            yield
        finally:
            self._clear_states()

    def _get_time_elapsed(self) -> Optional[float]:
        """Return the time elapsed since the first worker entered the barrier.
        If no workers have entered the barrier, returns None.
        """
        start_times = [t for t in self._sync_start_times if t is not None]
        if not start_times:
            return None

        return asyncio.get_event_loop().time() - min(start_times)

    def _get_missing_ranks(self) -> List[int]:
        """Returns the ranks that have not entered the synchronization barrier."""
        return [i for i, t in enumerate(self._sync_start_times) if t is None]

    async def _wait_with_logging(self, condition, world_rank: int):
        """Waits for the condition to be notified, logging a warning every
        `warn_interval_s` seconds. The overall timeout is enforced by the caller.
        """
        current_time = asyncio.get_event_loop().time()
        self._sync_start_times[world_rank] = current_time
        while True:
            try:
                await asyncio.wait_for(condition.wait(), timeout=self._warn_interval_s)
                return
            # asyncio.wait_for() raises `asyncio.TimeoutError` on Python <= 3.10
            # and raises `TimeoutError` on Python >= 3.11
            # https://docs.python.org/3/library/asyncio-task.html#asyncio.wait_for
            # TODO: (hpguo) Make only one worker log the warning message.
            except (asyncio.TimeoutError, TimeoutError):
                logger.warning(
                    BROADCAST_PERIODIC_WARNING.format(
                        world_size=self._world_size,
                        max_time_elapsed_s=self._get_time_elapsed(),
                        missing_ranks=self._get_missing_ranks(),
                        warn_interval_env_var=REPORT_BARRIER_WARN_INTERVAL_S_ENV_VAR,
                        warn_interval_s=self._warn_interval_s,
                    )
                )

    async def broadcast_from_rank_zero(
        self, world_rank: int, world_size: int, data: T
    ) -> T:
        """Broadcasts data from the worker with rank 0 to all other workers.

        This method is a coroutine that blocks until all workers have called this
        method with their data. The data from the worker with rank 0 will
        be returned.
        """
        # Ensures that all global state manipulation is done within the async context
        # manager, which makes the condition variable awaiting and the counter
        # incrementing an atomic operation.
        async with self._condition:
            with self._broadcast_collective_context_manager(
                world_rank, world_size, data
            ):
                # If the counter is equal to the world size, it means the last worker
                # has called the broadcast_from_rank_zero method. The actor notifies
                # all the workers to continue.
                if self._counter == self._world_size:
                    self._condition.notify_all()
                    return self._reduced_data
                # If the counter is less than the world size, the actor waits for the
                # other workers to call the broadcast_from_rank_zero method.
                try:
                    await asyncio.wait_for(
                        self._wait_with_logging(self._condition, world_rank),
                        timeout=self._timeout_s,
                    )
                    return self._reduced_data
                except (asyncio.TimeoutError, TimeoutError) as e:
                    raise BroadcastCollectiveTimeoutError(
                        time_elapsed=self._get_time_elapsed(),
                        missing_ranks=self._get_missing_ranks(),
                        timeout_s=self._timeout_s,
                    ) from e

    # TODO: Implement a general consensus_from_votes method that takes a callable
    # reduce_fn and a list of votes from each worker. The method returns the consensus
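Hypothetical usage sketch (not part of the diff): three Ray tasks join the broadcast barrier and all receive rank 0's payload. This assumes a running Ray cluster with at least three CPUs available for the tasks; the task function `worker` and the payload strings are made up for illustration.

import ray
from ray.train.v2._internal.execution.checkpoint.sync_actor import SynchronizationActor

ray.init()
sync_actor = SynchronizationActor.remote()


@ray.remote
def worker(rank: int):
    # Every caller blocks until all three ranks have joined; only rank 0's
    # data is kept, and everyone gets that value back.
    return ray.get(
        sync_actor.broadcast_from_rank_zero.remote(
            world_rank=rank, world_size=3, data=f"checkpoint_00000{rank}"
        )
    )


print(ray.get([worker.remote(r) for r in range(3)]))  # ['checkpoint_000000'] * 3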
.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/context.py
ADDED
@@ -0,0 +1,281 @@
import logging
import threading
from dataclasses import dataclass
from queue import Queue
from typing import TYPE_CHECKING, Any, Dict, List, Optional

import ray
from ray.data.iterator import DataIterator
from ray.train import Checkpoint
from ray.train._internal import session
from ray.train._internal.session import _TrainingResult
from ray.train.v2._internal.execution.checkpoint.sync_actor import SynchronizationActor
from ray.train.v2._internal.execution.storage import StorageContext
from ray.train.v2._internal.util import _copy_doc, invoke_context_managers
from ray.train.v2.api.config import RunConfig

if TYPE_CHECKING:
    from ray.train.v2._internal.execution.callback import TrainContextCallback
    from ray.train.v2._internal.execution.worker_group.thread_runner import ThreadRunner


logger = logging.getLogger(__file__)


@dataclass
class TrainRunContext:
    """Holds the metadata and context for the current training run."""

    # TODO: Make this dataclass immutable after refactoring the train context.

    # The run configuration for the current training run.
    run_config: RunConfig

    # TODO: Add more fields that are shared across all workers and controllers.
    # For example, StorageContext, ScalingConfig, etc.

    def get_run_config(self) -> RunConfig:
        """Returns the run config of the current training run."""
        return self.run_config


@dataclass(frozen=True)
class DistributedContext:
    world_rank: int
    world_size: int
    local_rank: int
    local_world_size: int
    node_rank: int


@dataclass(frozen=True)
class ExecutionContext:
    """Holds the execution context for the current worker process.

    Every worker process has a single execution context accessed via the
    `TrainContext`, which includes the training thread that is actually
    running the user code.
    """

    # A shared synchronization actor that helps broadcast data across ranks.
    synchronization_actor: SynchronizationActor

    # A queue that receives training results from the user training code.
    # `ray.train.report` in user code populates this queue.
    result_queue: Queue

    # The thread launcher that runs the user training loop.
    training_thread_runner: "ThreadRunner"

    # The callbacks that are run in the worker train context.
    train_context_callbacks: List["TrainContextCallback"]


@dataclass
class TrainContext(TrainRunContext):
    distributed_context: DistributedContext
    execution_context: ExecutionContext
    storage_context: StorageContext
    dataset_shards: Dict[str, DataIterator]
    checkpoint: Optional[Checkpoint] = None

    @_copy_doc(session.get_metadata)
    def get_metadata(self) -> Dict[str, Any]:
        raise NotImplementedError

    @_copy_doc(session.get_experiment_name)
    def get_experiment_name(self) -> str:
        # TODO: Resolve run_config.name if it is None
        return self.run_config.name

    @_copy_doc(session.get_trial_name)
    def get_trial_name(self) -> str:
        raise NotImplementedError

    @_copy_doc(session.get_trial_id)
    def get_trial_id(self) -> str:
        raise NotImplementedError

    @_copy_doc(session.get_trial_resources)
    def get_trial_resources(self):
        raise NotImplementedError

    @_copy_doc(session.get_trial_dir)
    def get_trial_dir(self) -> str:
        raise NotImplementedError

    @_copy_doc(session.get_world_size)
    def get_world_size(self) -> int:
        return self.distributed_context.world_size

    @_copy_doc(session.get_world_rank)
    def get_world_rank(self) -> int:
        return self.distributed_context.world_rank

    @_copy_doc(session.get_local_rank)
    def get_local_rank(self) -> int:
        return self.distributed_context.local_rank

    @_copy_doc(session.get_local_world_size)
    def get_local_world_size(self) -> int:
        return self.distributed_context.local_world_size

    @_copy_doc(session.get_node_rank)
    def get_node_rank(self) -> int:
        return self.distributed_context.node_rank

    @_copy_doc(session.get_storage)
    def get_storage(self):
        return self.storage_context

    def get_result_queue(self):
        return self.execution_context.result_queue

    def get_synchronization_actor(self):
        return self.execution_context.synchronization_actor

    def get_checkpoint(self):
        return self.checkpoint

    def get_dataset_shard(self, dataset_name: str) -> DataIterator:
        """Returns the :class:`ray.data.DataIterator` shard for this worker.

        Call :meth:`~ray.data.DataIterator.iter_torch_batches` or
        :meth:`~ray.data.DataIterator.to_tf` on this shard to convert it to the
        appropriate framework-specific data type.

        Args:
            dataset_name: Name of the dataset shard.
        Returns:
            The ``DataIterator`` shard with the given name for this worker.
        Raises:
            KeyError: If the dataset shard with the given name is not found.
        """
        try:
            return self.dataset_shards[dataset_name]
        except KeyError:
            raise KeyError(
                f"Dataset {dataset_name} not found. Available datasets: "
                f"{list(self.dataset_shards.keys())}."
            )

    def get_context_callbacks(self) -> List["TrainContextCallback"]:
        return self.execution_context.train_context_callbacks

    def _sync_checkpoint_dir_name_across_ranks(
        self, checkpoint_dir_name: Optional[str] = None
    ) -> str:
        """Sync the checkpoint dir name across ranks.

        Args:
            checkpoint_dir_name: The checkpoint dir name to sync.

        Returns:
            The synced checkpoint dir name.
        """
        # If checkpoint_dir_name is not set, use the default checkpoint_dir_name
        # created by the storage context.
        checkpoint_dir_name = (
            checkpoint_dir_name
            or self.storage_context.make_default_checkpoint_dir_name()
        )
        # Get a consensus across ranks on the remote storage path, so distributed
        # checkpoints will be stored to the same place.
        sync_actor = self.get_synchronization_actor()
        return ray.get(
            sync_actor.broadcast_from_rank_zero.remote(
                world_rank=self.distributed_context.world_rank,
                world_size=self.distributed_context.world_size,
                data=checkpoint_dir_name,
            )
        )

    def _save_checkpoint(
        self,
        checkpoint_dir_name: str,
        metrics: Dict[str, Any],
        checkpoint: Optional[Checkpoint] = None,
    ) -> _TrainingResult:
        """Save the checkpoint to remote storage.

        Returns:
            The training result object containing the persisted checkpoint.
        """

        if not checkpoint:
            return _TrainingResult(checkpoint=None, metrics=metrics)

        # Persist the checkpoint to the remote storage path.
        persisted_checkpoint = self.storage_context.persist_current_checkpoint(
            checkpoint, checkpoint_dir_name
        )
        # Update latest checkpoint as the persisted checkpoint.
        self.checkpoint = persisted_checkpoint

        return _TrainingResult(checkpoint=persisted_checkpoint, metrics=metrics)

    def report(
        self,
        metrics: Dict[str, Any],
        checkpoint: Optional[Checkpoint] = None,
        checkpoint_dir_name: Optional[str] = None,
    ):
        """
        Upload checkpoint to remote storage and put a training
        result on the result queue of this worker process.

        Args:
            metrics: The metrics to report.
            checkpoint: The checkpoint to report.
            checkpoint_dir_name: The name of the checkpoint dir
                in this iteration. Note: If not set, the checkpoint will
                be stored in the default storage path. If set, make sure
                this value is unique for each iteration.

        TODO: the report function should be implemented in the worker instead
        of in the train context. The train context should only keep the train
        related information and not the worker related actions. This refactor
        would also require the `TrainContextCallback` to be updated as well.
        """

        with invoke_context_managers(
            [
                callback.on_report
                for callback in self.execution_context.train_context_callbacks
            ]
        ):
            # Step 1: sync the checkpoint dir name across ranks.
            checkpoint_dir_name = self._sync_checkpoint_dir_name_across_ranks(
                checkpoint_dir_name
            )
            # Step 2: save the checkpoint to remote storage.
            training_result = self._save_checkpoint(
                checkpoint_dir_name, metrics, checkpoint
            )
            # Step 3: Report the training result to the result queue.
            # The queue size is set to 1 to avoid accumulating unprocessed results.
            # If the queue is full, the put operation blocks until a result is consumed.

            # TODO (hpguo): Add a metric to track the blocking time waiting for the
            # training result to be consumed by the controller.
            self.get_result_queue().put(training_result)


# The global variable holding the current TrainContext
_train_context: Optional[TrainContext] = None

# Thread lock to protect the global TrainContext
_context_lock = threading.Lock()


def get_train_context() -> TrainContext:
    with _context_lock:
        if _train_context is None:
            raise RuntimeError("TrainContext has not been initialized.")
        return _train_context


def set_train_context(context) -> None:
    global _train_context
    with _context_lock:
        _train_context = context
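Minimal standalone sketch (not from the diff) of the module-level accessor pattern used above: a process-global context guarded by a lock, set once by bootstrap code and read back by later library calls. The names `_Ctx`, `set_ctx`, and `get_ctx` are placeholders for illustration only.

import threading
from typing import Optional


class _Ctx:
    def __init__(self, world_rank: int):
        self.world_rank = world_rank


_ctx: Optional[_Ctx] = None
_lock = threading.Lock()


def set_ctx(ctx: _Ctx) -> None:
    global _ctx
    with _lock:
        _ctx = ctx


def get_ctx() -> _Ctx:
    with _lock:
        if _ctx is None:
            raise RuntimeError("Context has not been initialized.")
        return _ctx


set_ctx(_Ctx(world_rank=0))
print(get_ctx().world_rank)  # 0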
.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/controller.py
ADDED
@@ -0,0 +1,377 @@
import logging
import os
import time
from enum import Enum
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional

from ray._private.auto_init_hook import wrap_auto_init
from ray.train import Checkpoint
from ray.train.v2._internal.constants import (
    DEFAULT_HEALTH_CHECK_INTERVAL_S,
    HEALTH_CHECK_INTERVAL_S_ENV_VAR,
)
from ray.train.v2._internal.exceptions import (
    TrainingFailedError,
    WorkerGroupStartupFailedError,
    WorkerGroupStartupTimeoutError,
)
from ray.train.v2._internal.execution.callback import (
    ControllerCallback,
    ReportCallback,
    TrainContextCallback,
    WorkerCallback,
    WorkerGroupCallback,
)
from ray.train.v2._internal.execution.checkpoint.checkpoint_manager import (
    CheckpointManager,
)
from ray.train.v2._internal.execution.checkpoint.report_handler import (
    ReportCallbackHandler,
)
from ray.train.v2._internal.execution.context import TrainRunContext
from ray.train.v2._internal.execution.failure_handling import (
    FailureDecision,
    FailurePolicy,
)
from ray.train.v2._internal.execution.scaling_policy import (
    ResizeDecision,
    ScalingDecision,
    ScalingPolicy,
)
from ray.train.v2._internal.execution.storage import StorageContext, get_fs_and_path
from ray.train.v2._internal.execution.worker_group import WorkerGroup, WorkerGroupStatus
from ray.train.v2._internal.logging.logging import configure_controller_logger
from ray.train.v2._internal.util import time_monotonic
from ray.train.v2.api.result import Result
from ray.train.v2.api.callback import RayTrainCallback

logger = logging.getLogger(__name__)


class TrainControllerState(Enum):
    """The possible states that the training controller can be in
    while running the main execution control loop.

    States:
        RUNNING: The training controller is actively running training tasks.
        RECOVERING: The training controller is in the process of recovering
            from an error.
        INITIALIZING: The train controller is starting up.
            This is always the initial state of the controller.
        ERRORED: A terminal state indicating that training has encountered
            an error and cannot continue.
        FINISHED: A terminal state indicating that training has completed.
    """

    RUNNING = "RUNNING"
    INITIALIZING = "INITIALIZING"
    RECOVERING = "RECOVERING"
    ERRORED = "ERRORED"
    FINISHED = "FINISHED"


class TrainController:
    """Manages the execution of a distributed training job.

    Responsibilities include:
    * Triggering the training function to run on the worker group.
    * Monitoring the status of the worker group.
    * Handling scaling decisions by restarting the worker group.
    * Handling failure decisions by restarting the worker group or terminating training.
    * Running callback logic on different hooks in the control loop.
    """

    worker_group_cls = WorkerGroup

    def __init__(
        self,
        train_fn: Callable[[Dict[str, Any]], None],
        train_run_context: TrainRunContext,
        scaling_policy: ScalingPolicy,
        failure_policy: FailurePolicy,
        callbacks: Optional[List[RayTrainCallback]] = None,
        # TODO: [Deprecation]
        resume_from_checkpoint: Optional[Checkpoint] = None,
    ):
        self._train_run_context = train_run_context
        configure_controller_logger(self._train_run_context)
        self._train_fn = train_fn
        self._scaling_policy = scaling_policy
        self._failure_policy = failure_policy
        self._run_config = self._train_run_context.run_config
        self._callbacks = callbacks or []
        self._resume_from_checkpoint = resume_from_checkpoint
        self._storage_context = StorageContext(
            storage_path=self._run_config.storage_path,
            experiment_dir_name=self._run_config.name,
            storage_filesystem=self._run_config.storage_filesystem,
        )

        self._checkpoint_manager = CheckpointManager(
            checkpoint_config=self._run_config.checkpoint_config,
            storage_context=self._storage_context,
        )
        report_handler = ReportCallbackHandler(
            report_callbacks=(
                [self._checkpoint_manager]
                + [c for c in self._callbacks if isinstance(c, ReportCallback)]
            )
        )

        # Group callbacks by the hooks they're subscribed to.
        self._controller_callbacks = [self._scaling_policy] + [
            c for c in self._callbacks if isinstance(c, ControllerCallback)
        ]
        # Group callbacks that will be propagated to the worker group,
        # train worker and the train context.
        worker_group_callbacks_to_propagate = [report_handler] + [
            c
            for c in self._callbacks
            if isinstance(
                c, (WorkerGroupCallback, WorkerCallback, TrainContextCallback)
            )
        ]

        self._worker_group = self.worker_group_cls(
            train_run_context=self._train_run_context,
            callbacks=worker_group_callbacks_to_propagate,
        )
        self._state = TrainControllerState.INITIALIZING

        self._latest_poll_time = float("-inf")
        self._health_check_interval_s = float(
            os.getenv(HEALTH_CHECK_INTERVAL_S_ENV_VAR, DEFAULT_HEALTH_CHECK_INTERVAL_S)
        )
        self._training_failed_error: Optional[TrainingFailedError] = None

    def _execute_scaling_decision(
        self, decision: ScalingDecision, worker_group_status: WorkerGroupStatus
    ):
        """Executes scaling decisions."""
        for callback in self._controller_callbacks:
            callback.before_controller_execute_scaling_decision(
                decision, worker_group_status
            )

        if isinstance(decision, ResizeDecision):
            self._restart_worker_group(
                num_workers=decision.num_workers,
                resources_per_worker=decision.resources_per_worker,
            )

    def _execute_failure_decision(
        self, failure_decision: FailureDecision, worker_group_status: WorkerGroupStatus
    ):
        """Executes failure handling decisions (ex: restart, terminate)."""
        assert worker_group_status.errors

        for callback in self._controller_callbacks:
            callback.before_controller_execute_failure_decision(
                failure_decision, worker_group_status
            )

        if failure_decision == FailureDecision.NOOP:
            assert self._state == TrainControllerState.RUNNING
            return

        errors_str = "\n".join(
            [
                f"[Rank {worker_rank}]\n{error}"
                for worker_rank, error in worker_group_status.errors.items()
            ]
        )

        if failure_decision == FailureDecision.RESTART:
            logger.error(
                "Restarting training worker group after encountering "
                f"failures on {len(worker_group_status.errors)} worker(s):\n"
                f"{errors_str}"
            )
            # Shutdown the worker group so that we don't keep polling errored tasks.
            self._worker_group.shutdown()
            self._set_state(TrainControllerState.RECOVERING)
        elif failure_decision == FailureDecision.RAISE:
            logger.error(
                "Terminating training worker group after encountering "
                f"failure(s) on {len(worker_group_status.errors)} worker(s):\n"
                f"{errors_str}"
            )
            self._set_state(TrainControllerState.ERRORED)
            self._training_failed_error = TrainingFailedError(
                worker_failures=worker_group_status.errors
            )
        else:
            raise ValueError(f"Unexpected failure decision: {failure_decision}")

    def _poll_workers(self) -> WorkerGroupStatus:
        # Ensure that the time between polls is at least HEALTH_CHECK_INTERVAL_S.
        time_since_last_poll = time_monotonic() - self._latest_poll_time
        if time_since_last_poll < self._health_check_interval_s:
            remaining_time = max(
                self._health_check_interval_s - time_since_last_poll, 0
            )
            time.sleep(remaining_time)

        status = self._worker_group.poll_status(timeout=self._health_check_interval_s)
        self._latest_poll_time = time_monotonic()
        return status

    def _restart_worker_group(self, num_workers: int, resources_per_worker: dict):
        """Restart the worker group and launch the train function."""
        self._worker_group.shutdown()

        # If there's a latest checkpoint that's been committed,
        # use it to restore the worker group.
        latest_checkpoint_result = self._checkpoint_manager.latest_checkpoint_result
        latest_checkpoint = (
            latest_checkpoint_result.checkpoint if latest_checkpoint_result else None
        )
        placement_strategy = self._scaling_policy.scaling_config.placement_strategy

        # Start the worker group with the latest checkpoint if there is one.
        # Otherwise, start the worker group with the checkpoint set by controller.
        # Finally, if there is no checkpoint, start the worker group with None.
        try:
            self._worker_group.start(
                train_fn=self._train_fn,
                num_workers=num_workers,
                resources_per_worker=resources_per_worker,
                placement_strategy=placement_strategy,
                checkpoint=latest_checkpoint or self._resume_from_checkpoint,
            )
        except (WorkerGroupStartupTimeoutError, WorkerGroupStartupFailedError) as e:
            logger.error(
                "Retrying the launch of the training worker group. "
                f"The previous launch attempt encountered the following failure:\n{e}"
            )

            # TODO: Should this logic go through the failure policy?
            # The current logic will always try recovering unconditionally
            # on startup errors without a retry limit.
            self._set_state(TrainControllerState.RECOVERING)
            return

        # TODO: Consider starting the worker group asynchronously.
        self._set_state(TrainControllerState.RUNNING)

    def _start(self):
        for callback in self._controller_callbacks:
            callback.after_controller_start()

    def _shutdown(self):
        self._worker_group.shutdown()

        for callback in self._controller_callbacks:
            callback.before_controller_shutdown()

    def get_worker_group(self) -> WorkerGroup:
        return self._worker_group

    def get_state(self) -> TrainControllerState:
        return self._state

    def _set_state(self, state: TrainControllerState):
        previous_state = self._state
        self._state = state

        for callback in self._controller_callbacks:
            callback.after_controller_state_update(previous_state, state)

    def _run_control_loop_iteration(self):
        """Run a single iteration of the control loop.

        Steps:
        1. Poll the worker group for status.
        2. If the worker group is initializing or recovering from an error,
            make a scaling decision and execute it.
        3. If the worker group has finished, set the controller state to FINISHED.
        4. If the worker group has errors, make a failure decision and execute it.
        5. Otherwise, the worker group is running healthily.
            Query the scaling policy for a scaling decision and execute it.
        """
        assert self.get_state() in (
            TrainControllerState.RUNNING,
            TrainControllerState.RECOVERING,
            TrainControllerState.INITIALIZING,
        ), self.get_state()

        worker_group_status = self._poll_workers()

        if worker_group_status.finished and not worker_group_status.errors:
            self._set_state(TrainControllerState.FINISHED)
            return

        if self.get_state() in (
            TrainControllerState.INITIALIZING,
            TrainControllerState.RECOVERING,
        ):
            scaling_decision = (
                self._scaling_policy.make_decision_for_non_running_worker_group(
                    worker_group_status
                )
            )
            self._execute_scaling_decision(scaling_decision, worker_group_status)
        elif self.get_state() == TrainControllerState.RUNNING:
            if worker_group_status.errors:
                failure_decision = self._failure_policy.make_decision(
                    worker_group_status
                )
                self._execute_failure_decision(failure_decision, worker_group_status)
            else:
                scaling_decision = (
                    self._scaling_policy.make_decision_for_running_worker_group(
                        worker_group_status
                    )
                )
                self._execute_scaling_decision(scaling_decision, worker_group_status)

    @wrap_auto_init
    def run(self):
        """Run the main control loop. Exits when training is finished or errored."""
        self._start()

        while self.get_state() not in (
            TrainControllerState.ERRORED,
            TrainControllerState.FINISHED,
        ):
+
self._run_control_loop_iteration()
|
| 339 |
+
|
| 340 |
+
self._shutdown()
|
| 341 |
+
|
| 342 |
+
def get_result(self) -> Result:
|
| 343 |
+
"""Get the final training result from the TrainController."""
|
| 344 |
+
|
| 345 |
+
controller_state = self.get_state()
|
| 346 |
+
if controller_state not in (
|
| 347 |
+
TrainControllerState.FINISHED,
|
| 348 |
+
TrainControllerState.ERRORED,
|
| 349 |
+
):
|
| 350 |
+
raise ValueError(
|
| 351 |
+
f"Cannot get result when controller is in state {controller_state}"
|
| 352 |
+
)
|
| 353 |
+
|
| 354 |
+
latest_checkpoint_result = self._checkpoint_manager.latest_checkpoint_result
|
| 355 |
+
latest_metrics = (
|
| 356 |
+
latest_checkpoint_result.metrics if latest_checkpoint_result else None
|
| 357 |
+
)
|
| 358 |
+
latest_checkpoint = (
|
| 359 |
+
latest_checkpoint_result.checkpoint if latest_checkpoint_result else None
|
| 360 |
+
)
|
| 361 |
+
best_checkpoints = [
|
| 362 |
+
(r.checkpoint, r.metrics)
|
| 363 |
+
for r in self._checkpoint_manager.best_checkpoint_results
|
| 364 |
+
]
|
| 365 |
+
storage_filesystem, storage_fs_path = get_fs_and_path(
|
| 366 |
+
self._run_config.storage_path, self._run_config.storage_filesystem
|
| 367 |
+
)
|
| 368 |
+
experiment_fs_path = Path(storage_fs_path, self._run_config.name).as_posix()
|
| 369 |
+
|
| 370 |
+
return Result(
|
| 371 |
+
metrics=latest_metrics,
|
| 372 |
+
checkpoint=latest_checkpoint,
|
| 373 |
+
error=self._training_failed_error,
|
| 374 |
+
path=experiment_fs_path,
|
| 375 |
+
best_checkpoints=best_checkpoints,
|
| 376 |
+
_storage_filesystem=storage_filesystem,
|
| 377 |
+
)
|
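Editor's note: the controller above is essentially a small state machine driven by repeated polls of the worker group. The following standalone sketch mirrors that flow with hypothetical stubs (the State enum, the poll/decide/execute callables, and SimpleNamespace statuses are stand-ins, not the actual Ray Train API); it only illustrates the loop structure under those assumptions.

from enum import Enum, auto
from types import SimpleNamespace


class State(Enum):
    INITIALIZING = auto()
    RUNNING = auto()
    RECOVERING = auto()
    FINISHED = auto()
    ERRORED = auto()


def control_loop(poll, decide_scaling, decide_failure, execute):
    """Drive the stub controller until it reaches a terminal state."""
    state = State.INITIALIZING
    while state not in (State.FINISHED, State.ERRORED):
        status = poll()  # stand-in for polling the worker group status
        if status.finished and not status.errors:
            state = State.FINISHED
        elif state in (State.INITIALIZING, State.RECOVERING):
            state = execute(decide_scaling(status))
        elif status.errors:
            state = execute(decide_failure(status))
        else:
            state = execute(decide_scaling(status))
    return state


# Toy run: one healthy poll, then a finished poll.
polls = iter(
    [
        SimpleNamespace(finished=False, errors={}),
        SimpleNamespace(finished=True, errors={}),
    ]
)
final = control_loop(
    poll=lambda: next(polls),
    decide_scaling=lambda status: "resize",
    decide_failure=lambda status: "restart",
    execute=lambda decision: State.RUNNING,
)
print(final)  # State.FINISHED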
.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/failure_handling/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (511 Bytes).

.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/failure_handling/__pycache__/default.cpython-311.pyc
ADDED
Binary file (2.67 kB).

.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/failure_handling/__pycache__/factory.cpython-311.pyc
ADDED
Binary file (797 Bytes).

.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/failure_handling/__pycache__/failure_policy.cpython-311.pyc
ADDED
Binary file (1.7 kB).
.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/failure_handling/default.py
ADDED
@@ -0,0 +1,44 @@
import logging

from ray.train import FailureConfig
from ray.train.v2._internal.execution.failure_handling import (
    FailureDecision,
    FailurePolicy,
)
from ray.train.v2._internal.execution.worker_group import WorkerGroupStatus

logger = logging.getLogger(__name__)


class DefaultFailurePolicy(FailurePolicy):
    def __init__(self, failure_config: FailureConfig):
        super().__init__(failure_config)
        self._total_failures = 0

    def make_decision(self, worker_group_status: WorkerGroupStatus) -> FailureDecision:
        if not worker_group_status.errors:
            return FailureDecision.NOOP

        self._total_failures += 1

        if self.failure_config.max_failures == -1:
            logger.info(
                "Deciding to RESTART, since infinite retry is enabled. "
                f"Encountered {self._total_failures} failures so far."
            )
            return FailureDecision.RESTART

        if self._total_failures > self.failure_config.max_failures:
            logger.info(
                "Deciding to TERMINATE, since the total failure count "
                f"({self._total_failures}) exceeded the maximum allowed failures: "
                f"FailureConfig(max_failures={self.failure_config.max_failures})."
            )
            return FailureDecision.RAISE

        logger.info(
            "Deciding to RESTART, since the total "
            f"failure count ({self._total_failures}) <= "
            f"FailureConfig(max_failures={self.failure_config.max_failures})."
        )
        return FailureDecision.RESTART
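Editor's note: a minimal sketch of how this policy behaves, assuming the imports above resolve in an environment with Ray installed. The SimpleNamespace objects are stand-ins for WorkerGroupStatus, since the policy only inspects its `.errors` mapping; the worker rank and error values are illustrative.

from types import SimpleNamespace

from ray.train import FailureConfig
from ray.train.v2._internal.execution.failure_handling import (
    DefaultFailurePolicy,
    FailureDecision,
)

# Allow at most one worker-group failure before raising.
policy = DefaultFailurePolicy(FailureConfig(max_failures=1))

healthy = SimpleNamespace(errors={})
failed = SimpleNamespace(errors={0: RuntimeError("worker died")})

assert policy.make_decision(healthy) == FailureDecision.NOOP
assert policy.make_decision(failed) == FailureDecision.RESTART  # 1 <= max_failures
assert policy.make_decision(failed) == FailureDecision.RAISE    # 2 > max_failures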
.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/failure_handling/factory.py
ADDED
@@ -0,0 +1,13 @@
from ray.train import FailureConfig
from ray.train.v2._internal.execution.failure_handling import (
    DefaultFailurePolicy,
    FailurePolicy,
)


def create_failure_policy(failure_config: FailureConfig) -> FailurePolicy:
    """Create a failure policy from the given failure config.

    Defaults to the `DefaultFailurePolicy` implementation.
    """
    return DefaultFailurePolicy(failure_config=failure_config)
.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/scaling_policy/__init__.py
ADDED
@@ -0,0 +1,19 @@
# isort: off
from .scaling_policy import ScalingDecision, ScalingPolicy, NoopDecision, ResizeDecision
from .fixed import FixedScalingPolicy
from .factory import create_scaling_policy

# isort: on


__all__ = [
    "ScalingPolicy",
    "FixedScalingPolicy",
    "ScalingDecision",
    "NoopDecision",
    "ResizeDecision",
    "create_scaling_policy",
]


# DO NOT ADD ANYTHING AFTER THIS LINE.
.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/scaling_policy/__pycache__/factory.cpython-311.pyc
ADDED
Binary file (805 Bytes).

.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/scaling_policy/__pycache__/fixed.cpython-311.pyc
ADDED
Binary file (1.54 kB).

.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/scaling_policy/__pycache__/scaling_policy.cpython-311.pyc
ADDED
Binary file (3.09 kB).
.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/scaling_policy/factory.py
ADDED
@@ -0,0 +1,13 @@
from ray.train.v2._internal.execution.scaling_policy import (
    FixedScalingPolicy,
    ScalingPolicy,
)
from ray.train.v2.api.config import ScalingConfig


def create_scaling_policy(scaling_config: ScalingConfig) -> ScalingPolicy:
    """Create a scaling policy from the given scaling config.

    Defaults to the `FixedScalingPolicy` implementation.
    """
    return FixedScalingPolicy(scaling_config=scaling_config)
.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/scaling_policy/fixed.py
ADDED
@@ -0,0 +1,22 @@
from ray.train.v2._internal.execution.scaling_policy import (
    NoopDecision,
    ResizeDecision,
    ScalingDecision,
    ScalingPolicy,
)
from ray.train.v2._internal.execution.worker_group import WorkerGroupStatus


class FixedScalingPolicy(ScalingPolicy):
    def make_decision_for_non_running_worker_group(
        self, worker_group_status: WorkerGroupStatus
    ) -> ScalingDecision:
        return ResizeDecision(
            num_workers=self.scaling_config.num_workers,
            resources_per_worker=self.scaling_config._resources_per_worker_not_none,
        )

    def make_decision_for_running_worker_group(
        self, worker_group_status: WorkerGroupStatus
    ) -> ScalingDecision:
        return NoopDecision()
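Editor's note: a small sketch of the fixed policy's behavior. It assumes Ray is installed and uses SimpleNamespace stand-ins for the ScalingConfig and WorkerGroupStatus objects, since the policy only reads the attributes shown in the code above.

from types import SimpleNamespace

from ray.train.v2._internal.execution.scaling_policy import (
    FixedScalingPolicy,
    NoopDecision,
    ResizeDecision,
)

# Stand-in config exposing only the two attributes the policy reads.
config = SimpleNamespace(num_workers=4, _resources_per_worker_not_none={"CPU": 1})
policy = FixedScalingPolicy(scaling_config=config)

status = SimpleNamespace(errors={}, finished=False)  # stand-in status
decision = policy.make_decision_for_non_running_worker_group(status)
assert isinstance(decision, ResizeDecision) and decision.num_workers == 4
assert isinstance(policy.make_decision_for_running_worker_group(status), NoopDecision)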
.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/scaling_policy/scaling_policy.py
ADDED
@@ -0,0 +1,51 @@
import abc
from dataclasses import dataclass
from typing import Dict

from ray.train.v2._internal.execution.callback import ControllerCallback
from ray.train.v2._internal.execution.worker_group import WorkerGroupStatus
from ray.train.v2.api.config import ScalingConfig


@dataclass
class ScalingDecision:
    pass


@dataclass
class NoopDecision(ScalingDecision):
    pass


@dataclass
class ResizeDecision(ScalingDecision):
    num_workers: int
    resources_per_worker: Dict[str, float]


class ScalingPolicy(abc.ABC, ControllerCallback):
    """A policy that determines when and how to scale a worker group.

    This can be used to implement elasticity and fault tolerance.

    Recovery decisions are made when workers are in an inactive or unhealthy state.
    Upscale decisions are optional and are made when workers are healthy.
    """

    def __init__(self, scaling_config: ScalingConfig):
        self.scaling_config = scaling_config

    @abc.abstractmethod
    def make_decision_for_non_running_worker_group(
        self, worker_group_status: WorkerGroupStatus
    ) -> ScalingDecision:
        """Makes a scaling decision when the worker group is initializing
        or recovering from an error."""
        raise NotImplementedError

    @abc.abstractmethod
    def make_decision_for_running_worker_group(
        self, worker_group_status: WorkerGroupStatus
    ) -> ScalingDecision:
        """Makes a scaling decision when monitoring healthy, running workers."""
        raise NotImplementedError
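Editor's note: the abstract base above is the extension point for custom elasticity. The following hypothetical subclass (not part of Ray) sketches what an upscaling policy could look like under the stated interface: it asks for one worker during recovery and grows by one worker per healthy poll until the configured target is reached.

from ray.train.v2._internal.execution.scaling_policy import (
    NoopDecision,
    ResizeDecision,
    ScalingDecision,
    ScalingPolicy,
)
from ray.train.v2._internal.execution.worker_group import WorkerGroupStatus


class StepUpScalingPolicy(ScalingPolicy):
    """Hypothetical policy: start small, then grow by one worker per healthy
    poll until `scaling_config.num_workers` is reached."""

    def __init__(self, scaling_config):
        super().__init__(scaling_config)
        self._current_workers = 1

    def make_decision_for_non_running_worker_group(
        self, worker_group_status: WorkerGroupStatus
    ) -> ScalingDecision:
        # Recover with the current (possibly reduced) worker count.
        return ResizeDecision(
            num_workers=self._current_workers,
            resources_per_worker=self.scaling_config._resources_per_worker_not_none,
        )

    def make_decision_for_running_worker_group(
        self, worker_group_status: WorkerGroupStatus
    ) -> ScalingDecision:
        # Upscale decisions are optional and only made while workers are healthy.
        if self._current_workers < self.scaling_config.num_workers:
            self._current_workers += 1
            return ResizeDecision(
                num_workers=self._current_workers,
                resources_per_worker=self.scaling_config._resources_per_worker_not_none,
            )
        return NoopDecision()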
.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/storage.py
ADDED
@@ -0,0 +1,551 @@
# Try import ray[train] core requirements (defined in setup.py)
# isort: off
try:
    import fsspec  # noqa
    from fsspec.implementations.local import LocalFileSystem

except (ImportError, ModuleNotFoundError) as e:
    raise RuntimeError(
        "fsspec is a required dependency of Ray Train and Ray Tune. "
        "Please install with: `pip install fsspec`"
    ) from e

try:
    import pyarrow
    import pyarrow.fs

except (ImportError, ModuleNotFoundError) as e:
    raise RuntimeError(
        "pyarrow is a required dependency of Ray Train and Ray Tune. "
        "Please install with: `pip install pyarrow`"
    ) from e
# isort: on

import fnmatch
import logging
import os
import shutil
from pathlib import Path
from typing import TYPE_CHECKING, Callable, List, Optional, Tuple, Type, Union

from ray.air._internal.filelock import TempFileLock
from ray.train.constants import _get_ray_train_session_dir
from ray.train.v2._internal.constants import (
    CHECKPOINT_MANAGER_SNAPSHOT_FILENAME,
    VALIDATE_STORAGE_MARKER_FILENAME,
)
from ray.train.v2._internal.util import date_str
from ray.util.annotations import DeveloperAPI

if TYPE_CHECKING:
    from ray.train import Checkpoint


logger = logging.getLogger(__name__)


class _ExcludingLocalFilesystem(LocalFileSystem):
    """LocalFileSystem wrapper to exclude files according to patterns.

    Args:
        root_path: Root path to strip when matching with the exclude pattern.
            Ex: root_path="/tmp/a/b/c", exclude=["*a*"], will exclude
            /tmp/a/b/c/_a_.txt but not ALL of /tmp/a/*.
        exclude: List of patterns that are applied to files returned by
            ``self.find()``. If a file path matches this pattern, it will
            be excluded.

    """

    def __init__(self, root_path: Path, exclude: List[str], **kwargs):
        super().__init__(**kwargs)
        self._exclude = exclude
        self._root_path = root_path

    @property
    def fsid(self):
        return "_excluding_local"

    def _should_exclude(self, path: str) -> bool:
        """Return True if `path` (relative to `root_path`) matches any of the
        `self._exclude` patterns."""
        path = Path(path)
        relative_path = path.relative_to(self._root_path).as_posix()
        match_candidates = [relative_path]
        if path.is_dir():
            # Everything is in posix path format ('/')
            match_candidates.append(relative_path + "/")

        for excl in self._exclude:
            if any(fnmatch.fnmatch(candidate, excl) for candidate in match_candidates):
                return True
        return False

    def find(self, path, maxdepth=None, withdirs=False, detail=False, **kwargs):
        """Call parent find() and exclude from result."""
        paths = super().find(
            path, maxdepth=maxdepth, withdirs=withdirs, detail=detail, **kwargs
        )
        if detail:
            return {
                path: out
                for path, out in paths.items()
                if not self._should_exclude(path)
            }
        else:
            return [path for path in paths if not self._should_exclude(path)]


def _pyarrow_fs_copy_files(
    source, destination, source_filesystem=None, destination_filesystem=None, **kwargs
):
    if isinstance(destination_filesystem, pyarrow.fs.S3FileSystem):
        # Workaround multi-threading issue with pyarrow. Note that use_threads=True
        # is safe for download, just not for uploads, see:
        # https://github.com/apache/arrow/issues/32372
        kwargs.setdefault("use_threads", False)

    # Use a large chunk size to speed up large checkpoint transfers.
    kwargs.setdefault("chunk_size", 64 * 1024 * 1024)

    return pyarrow.fs.copy_files(
        source,
        destination,
        source_filesystem=source_filesystem,
        destination_filesystem=destination_filesystem,
        **kwargs,
    )


# TODO(justinvyu): Add unit tests for all these utils.


def _delete_fs_path(fs: pyarrow.fs.FileSystem, fs_path: str):
    is_dir = _is_directory(fs, fs_path)

    try:
        if is_dir:
            fs.delete_dir(fs_path)
        else:
            fs.delete_file(fs_path)
    except Exception:
        logger.exception(f"Caught exception when deleting path at ({fs}, {fs_path}):")


def _download_from_fs_path(
    fs: pyarrow.fs.FileSystem,
    fs_path: str,
    local_path: str,
    filelock: bool = True,
):
    """Downloads a directory or file from (fs, fs_path) to a local path.

    If fs_path points to a directory:
    - The full directory contents are downloaded directly into `local_path`,
      rather than to a subdirectory of `local_path`.

    If fs_path points to a file:
    - The file is downloaded to `local_path`, which is expected to be a file path.

    If the download fails, the `local_path` contents are
    cleaned up before raising, if the directory did not previously exist.

    NOTE: This method creates `local_path`'s parent directories if they do not
    already exist. If the download fails, this does NOT clean up all the parent
    directories that were created.

    Args:
        fs: The filesystem to download from.
        fs_path: The filesystem path (either a directory or a file) to download.
        local_path: The local path to download to.
        filelock: Whether to require a file lock before downloading, useful for
            multiple downloads to the same directory that may be happening in parallel.

    Raises:
        FileNotFoundError: if (fs, fs_path) doesn't exist.
    """

    _local_path = Path(local_path).resolve()
    exists_before = _local_path.exists()
    if _is_directory(fs=fs, fs_path=fs_path):
        _local_path.mkdir(parents=True, exist_ok=True)
    else:
        _local_path.parent.mkdir(parents=True, exist_ok=True)

    try:
        if filelock:
            with TempFileLock(f"{os.path.normpath(local_path)}.lock"):
                _pyarrow_fs_copy_files(fs_path, local_path, source_filesystem=fs)
        else:
            _pyarrow_fs_copy_files(fs_path, local_path, source_filesystem=fs)
    except Exception as e:
        # Clean up the directory if downloading was unsuccessful
        if not exists_before:
            shutil.rmtree(local_path, ignore_errors=True)
        raise e


def _upload_to_fs_path(
    local_path: str,
    fs: pyarrow.fs.FileSystem,
    fs_path: str,
    exclude: Optional[List[str]] = None,
) -> None:
    """Uploads a local directory or file to (fs, fs_path).

    NOTE: This will create all necessary parent directories at the destination.

    Args:
        local_path: The local path to upload.
        fs: The filesystem to upload to.
        fs_path: The filesystem path where the dir/file will be uploaded to.
        exclude: A list of filename matches to exclude from upload. This includes
            all files under subdirectories as well.
            This pattern will match with the relative paths of all files under
            `local_path`.
            Ex: ["*.png"] to exclude all .png images.
    """

    if not exclude:
        # TODO(justinvyu): uploading a single file doesn't work
        # (since we always create a directory at fs_path)
        _create_directory(fs=fs, fs_path=fs_path)
        _pyarrow_fs_copy_files(local_path, fs_path, destination_filesystem=fs)
        return

    _upload_to_uri_with_exclude_fsspec(
        local_path=local_path, fs=fs, fs_path=fs_path, exclude=exclude
    )


def _upload_to_uri_with_exclude_fsspec(
    local_path: str, fs: "pyarrow.fs", fs_path: str, exclude: Optional[List[str]]
) -> None:
    local_fs = _ExcludingLocalFilesystem(root_path=local_path, exclude=exclude)
    handler = pyarrow.fs.FSSpecHandler(local_fs)
    source_fs = pyarrow.fs.PyFileSystem(handler)

    _create_directory(fs=fs, fs_path=fs_path)
    _pyarrow_fs_copy_files(
        local_path, fs_path, source_filesystem=source_fs, destination_filesystem=fs
    )


def _list_at_fs_path(
    fs: pyarrow.fs.FileSystem,
    fs_path: str,
    file_filter: Callable[[pyarrow.fs.FileInfo], bool] = lambda x: True,
) -> List[str]:
    """Returns the list of filenames at (fs, fs_path), similar to os.listdir.

    If the path doesn't exist, returns an empty list.
    """
    selector = pyarrow.fs.FileSelector(fs_path, allow_not_found=True, recursive=False)
    return [
        os.path.relpath(file_info.path.lstrip("/"), start=fs_path.lstrip("/"))
        for file_info in fs.get_file_info(selector)
        if file_filter(file_info)
    ]


def _exists_at_fs_path(fs: pyarrow.fs.FileSystem, fs_path: str) -> bool:
    """Returns True if (fs, fs_path) exists."""

    valid = fs.get_file_info(fs_path)
    return valid.type != pyarrow.fs.FileType.NotFound


def _is_directory(fs: pyarrow.fs.FileSystem, fs_path: str) -> bool:
    """Checks if (fs, fs_path) is a directory or a file.

    Raises:
        FileNotFoundError: if (fs, fs_path) doesn't exist.
    """

    file_info = fs.get_file_info(fs_path)
    if file_info.type == pyarrow.fs.FileType.NotFound:
        raise FileNotFoundError(f"Path not found: ({fs}, {fs_path})")

    return not file_info.is_file


def _create_directory(fs: pyarrow.fs.FileSystem, fs_path: str) -> None:
    """Create directory at (fs, fs_path).

    Some external filesystems require directories to already exist, or at least
    the `netloc` to be created (e.g. PyArrows ``mock://`` filesystem).

    Generally this should be done before and outside of Ray applications. This
    utility is thus primarily used in testing, e.g. of ``mock://` URIs.
    """
    try:
        fs.create_dir(fs_path)
    except Exception:
        logger.exception(
            f"Caught exception when creating directory at ({fs}, {fs_path}):"
        )


def get_fs_and_path(
    storage_path: Union[str, os.PathLike],
    storage_filesystem: Optional[pyarrow.fs.FileSystem] = None,
) -> Tuple[pyarrow.fs.FileSystem, str]:
    """Returns the fs and path from a storage path and an optional custom fs.

    Args:
        storage_path: A storage path or URI. (ex: s3://bucket/path or /tmp/ray_results)
        storage_filesystem: A custom filesystem to use. If not provided,
            this will be auto-resolved by pyarrow. If provided, the storage_path
            is assumed to be prefix-stripped already, and must be a valid path
            on the filesystem.
    """
    storage_path = str(storage_path)

    if storage_filesystem:
        return storage_filesystem, storage_path

    return pyarrow.fs.FileSystem.from_uri(storage_path)


@DeveloperAPI
class StorageContext:
    """Shared context that holds the source of truth for all paths and
    storage utilities, passed along from the driver to workers.

    This object defines a few types of paths:
    1. *_fs_path: A path on the `storage_filesystem`. This is a regular path
        which has been prefix-stripped by pyarrow.fs.FileSystem.from_uri and
        can be joined with `Path(...).as_posix()`.
    2. *_driver_staging_path: The temporary staging directory on the local filesystem
        where driver artifacts are saved to before persisting them to storage.
    3. trial_working_directory: The local filesystem path that the remote
        actors' working directories are moved to by default.
        This is separated from the driver staging path so that driver syncing
        does not implicitly upload the trial working directory, for trials on the
        driver node.

    Example with storage_path="mock:///bucket/path?param=1":

        >>> import ray
        >>> from ray.train._internal.storage import StorageContext
        >>> import os
        >>> _ = ray.init()
        >>> storage = StorageContext(
        ...     storage_path="mock://netloc/bucket/path?param=1",
        ...     experiment_dir_name="exp_name",
        ... )
        >>> storage.storage_filesystem   # Auto-resolved  # doctest: +ELLIPSIS
        <pyarrow._fs._MockFileSystem object...
        >>> storage.experiment_fs_path
        'bucket/path/exp_name'
        >>> storage.experiment_driver_staging_path  # doctest: +ELLIPSIS
        '/tmp/ray/session_.../artifacts/.../exp_name/driver_artifacts'
        >>> storage.trial_dir_name = "trial_dir"
        >>> storage.trial_fs_path
        'bucket/path/exp_name/trial_dir'
        >>> storage.trial_driver_staging_path  # doctest: +ELLIPSIS
        '/tmp/ray/session_.../artifacts/.../exp_name/driver_artifacts/trial_dir'
        >>> storage.trial_working_directory  # doctest: +ELLIPSIS
        '/tmp/ray/session_.../artifacts/.../exp_name/working_dirs/trial_dir'
        >>> ray.shutdown()

    Example with storage_path="/tmp/ray_results":

        >>> from ray.train._internal.storage import StorageContext
        >>> storage = StorageContext(
        ...     storage_path="/tmp/ray_results",
        ...     experiment_dir_name="exp_name",
        ... )
        >>> storage.storage_fs_path
        '/tmp/ray_results'
        >>> storage.experiment_fs_path
        '/tmp/ray_results/exp_name'
        >>> storage.storage_filesystem   # Auto-resolved  # doctest: +ELLIPSIS
        <pyarrow._fs.LocalFileSystem object...

    Internal Usage Examples:
    - To copy files to the trial directory on the storage filesystem:

        pyarrow.fs.copy_files(
            local_dir,
            Path(storage.trial_fs_path, "subdir").as_posix(),
            destination_filesystem=storage.filesystem
        )

    .. warning::
        This is an experimental developer API and is subject to change
        without notice between versions.
    """

    def __init__(
        self,
        storage_path: Union[str, os.PathLike],
        experiment_dir_name: str,
        storage_filesystem: Optional[pyarrow.fs.FileSystem] = None,
    ):
        self.custom_fs_provided = storage_filesystem is not None

        # Invariant: (`storage_filesystem`, `storage_path`) is the location where
        # *all* results can be accessed.
        self.experiment_dir_name = experiment_dir_name

        self.storage_filesystem, self.storage_fs_path = get_fs_and_path(
            storage_path, storage_filesystem
        )
        self.storage_fs_path = Path(self.storage_fs_path).as_posix()

        self._create_validation_file()
        self._check_validation_file()

    def __str__(self):
        return (
            "StorageContext<\n"
            f"  storage_filesystem='{self.storage_filesystem.type_name}',\n"
            f"  storage_fs_path='{self.storage_fs_path}',\n"
            f"  experiment_dir_name='{self.experiment_dir_name}',\n"
            ">"
        )

    def _create_validation_file(self):
        """On the creation of a storage context, create a validation file at the
        storage path to verify that the storage path can be written to.
        This validation file is also used to check whether the storage path is
        accessible by all nodes in the cluster."""
        valid_file = Path(
            self.experiment_fs_path, VALIDATE_STORAGE_MARKER_FILENAME
        ).as_posix()
        self.storage_filesystem.create_dir(self.experiment_fs_path)
        with self.storage_filesystem.open_output_stream(valid_file):
            pass

    def _check_validation_file(self):
        """Checks that the validation file exists at the storage path."""
        valid_file = Path(
            self.experiment_fs_path, VALIDATE_STORAGE_MARKER_FILENAME
        ).as_posix()
        if not _exists_at_fs_path(fs=self.storage_filesystem, fs_path=valid_file):
            raise RuntimeError(
                f"Unable to set up cluster storage with the following settings:\n{self}"
                "\nCheck that all nodes in the cluster have read/write access "
                "to the configured storage path. `RunConfig(storage_path)` should be "
                "set to a cloud storage URI or a shared filesystem path accessible "
                "by all nodes in your cluster ('s3://bucket' or '/mnt/nfs'). "
                "A local path on the head node is not accessible by worker nodes. "
                "See: https://docs.ray.io/en/latest/train/user-guides/persistent-storage.html"  # noqa: E501
            )

    def persist_current_checkpoint(
        self, checkpoint: "Checkpoint", checkpoint_dir_name: str
    ) -> "Checkpoint":
        """Persists a given checkpoint to the current checkpoint path on the filesystem.

        This method copies the checkpoint files to the storage location.
        It's up to the user to delete the original checkpoint files if desired.

        For example, the original directory is typically a local temp directory.

        Args:
            checkpoint: The checkpoint to persist to
                (fs, experiment_fs_path / checkpoint_dir_name).

        Returns:
            Checkpoint: A Checkpoint pointing to the persisted checkpoint location.
        """
        # TODO(justinvyu): Fix this cyclical import.
        from ray.train import Checkpoint

        checkpoint_fs_path = self.build_checkpoint_path_from_name(checkpoint_dir_name)

        logger.debug(
            "Copying checkpoint files to storage path:\n"
            "({source_fs}, {source}) -> ({dest_fs}, {destination})".format(
                source=checkpoint.path,
                destination=checkpoint_fs_path,
                source_fs=checkpoint.filesystem,
                dest_fs=self.storage_filesystem,
            )
        )

        # Raise an error if the storage path is not accessible when
        # attempting to upload a checkpoint from a remote worker.
        # Ex: If storage_path is a local path, then a validation marker
        # will only exist on the head node but not the worker nodes.
        self._check_validation_file()

        self.storage_filesystem.create_dir(checkpoint_fs_path)
        _pyarrow_fs_copy_files(
            source=checkpoint.path,
            destination=checkpoint_fs_path,
            source_filesystem=checkpoint.filesystem,
            destination_filesystem=self.storage_filesystem,
        )

        persisted_checkpoint = Checkpoint(
            filesystem=self.storage_filesystem,
            path=checkpoint_fs_path,
        )
        logger.info(f"Checkpoint successfully created at: {persisted_checkpoint}")
        return persisted_checkpoint

    @property
    def experiment_fs_path(self) -> str:
        """The path on the `storage_filesystem` to the experiment directory.

        NOTE: This does not have a URI prefix anymore, since it has been stripped
        by pyarrow.fs.FileSystem.from_uri already. The URI scheme information is
        kept in `storage_filesystem` instead.
        """
        return Path(self.storage_fs_path, self.experiment_dir_name).as_posix()

    @property
    def local_working_directory(self) -> str:
        """Every ray train worker will set this directory as its working directory."""
        if self.experiment_dir_name is None:
            raise RuntimeError(
                "Cannot access `local_working_directory` without "
                "setting `experiment_dir_name`"
            )
        return Path(_get_ray_train_session_dir(), self.experiment_dir_name).as_posix()

    @property
    def checkpoint_manager_snapshot_path(self) -> str:
        """The path to the checkpoint manager snapshot file."""
        return Path(
            self.experiment_fs_path, CHECKPOINT_MANAGER_SNAPSHOT_FILENAME
        ).as_posix()

    @staticmethod
    def get_experiment_dir_name(run_obj: Union[str, Callable, Type]) -> str:
        from ray.tune.experiment import Experiment

        run_identifier = Experiment.get_trainable_name(run_obj)

        if bool(int(os.environ.get("TUNE_DISABLE_DATED_SUBDIR", 0))):
            dir_name = run_identifier
        else:
            dir_name = "{}_{}".format(run_identifier, date_str())
        return dir_name

    @staticmethod
    def make_default_checkpoint_dir_name():
        """Get the name of the checkpoint directory by timestamp."""
        return f"checkpoint_{date_str(include_ms=True)}"

    def extract_checkpoint_dir_name_from_path(self, checkpoint_path: str) -> str:
        """Get the checkpoint name from the checkpoint path.
        The parent directory of the checkpoint path should be the experiment directory.
        """
        # TODO: Use Pathlib to extract the name when supports at least Python 3.9
        experiment_fs_path = self.experiment_fs_path + "/"
        if not checkpoint_path.startswith(experiment_fs_path):
            raise ValueError(
                f"Checkpoint path {checkpoint_path} is not under the experiment "
                f"directory {self.experiment_fs_path}."
            )
        return checkpoint_path[len(experiment_fs_path) :]

    def build_checkpoint_path_from_name(self, checkpoint_name: str) -> str:
        """Get the checkpoint path from the checkpoint name.
        The parent directory of the checkpoint path should be the experiment directory.
        """
        return Path(self.experiment_fs_path, checkpoint_name).as_posix()
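Editor's note: a minimal sketch of the path-resolution convention used throughout this module, assuming pyarrow is installed and the local path exists; the "exp_name" directory name is illustrative. It mirrors the doctest above, where a plain local path auto-resolves to a LocalFileSystem and prefix-stripped paths are joined with Path(...).as_posix().

from pathlib import Path

import pyarrow.fs

from ray.train.v2._internal.execution.storage import get_fs_and_path

# Auto-resolve the filesystem from a local path (cloud URIs such as
# s3://bucket/path resolve the same way via pyarrow).
fs, fs_path = get_fs_and_path("/tmp/ray_results")
assert isinstance(fs, pyarrow.fs.LocalFileSystem)

# Join the prefix-stripped path into an experiment directory path.
experiment_fs_path = Path(fs_path, "exp_name").as_posix()
print(experiment_fs_path)  # /tmp/ray_results/exp_name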