koichi12 committed
Commit ec4d14c · verified · 1 Parent(s): d5967d1

Add files using upload-large-folder tool

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .venv/lib/python3.11/site-packages/ray/train/v2/__pycache__/__init__.cpython-311.pyc +0 -0
  2. .venv/lib/python3.11/site-packages/ray/train/v2/_internal/__init__.py +0 -0
  3. .venv/lib/python3.11/site-packages/ray/train/v2/_internal/__pycache__/__init__.cpython-311.pyc +0 -0
  4. .venv/lib/python3.11/site-packages/ray/train/v2/_internal/__pycache__/exceptions.cpython-311.pyc +0 -0
  5. .venv/lib/python3.11/site-packages/ray/train/v2/_internal/__pycache__/util.cpython-311.pyc +0 -0
  6. .venv/lib/python3.11/site-packages/ray/train/v2/_internal/callbacks/__init__.py +14 -0
  7. .venv/lib/python3.11/site-packages/ray/train/v2/_internal/callbacks/__pycache__/accelerators.cpython-311.pyc +0 -0
  8. .venv/lib/python3.11/site-packages/ray/train/v2/_internal/callbacks/__pycache__/backend_setup.cpython-311.pyc +0 -0
  9. .venv/lib/python3.11/site-packages/ray/train/v2/_internal/callbacks/__pycache__/datasets.cpython-311.pyc +0 -0
  10. .venv/lib/python3.11/site-packages/ray/train/v2/_internal/callbacks/__pycache__/metrics.cpython-311.pyc +0 -0
  11. .venv/lib/python3.11/site-packages/ray/train/v2/_internal/callbacks/__pycache__/user_callback.cpython-311.pyc +0 -0
  12. .venv/lib/python3.11/site-packages/ray/train/v2/_internal/callbacks/__pycache__/working_dir_setup.cpython-311.pyc +0 -0
  13. .venv/lib/python3.11/site-packages/ray/train/v2/_internal/callbacks/accelerators.py +151 -0
  14. .venv/lib/python3.11/site-packages/ray/train/v2/_internal/callbacks/backend_setup.py +27 -0
  15. .venv/lib/python3.11/site-packages/ray/train/v2/_internal/callbacks/datasets.py +76 -0
  16. .venv/lib/python3.11/site-packages/ray/train/v2/_internal/callbacks/metrics.py +250 -0
  17. .venv/lib/python3.11/site-packages/ray/train/v2/_internal/callbacks/user_callback.py +50 -0
  18. .venv/lib/python3.11/site-packages/ray/train/v2/_internal/callbacks/working_dir_setup.py +24 -0
  19. .venv/lib/python3.11/site-packages/ray/train/v2/_internal/constants.py +84 -0
  20. .venv/lib/python3.11/site-packages/ray/train/v2/_internal/exceptions.py +170 -0
  21. .venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/__init__.py +0 -0
  22. .venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/__pycache__/callback.cpython-311.pyc +0 -0
  23. .venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/__pycache__/context.cpython-311.pyc +0 -0
  24. .venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/__pycache__/controller.cpython-311.pyc +0 -0
  25. .venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/__pycache__/storage.cpython-311.pyc +0 -0
  26. .venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/callback.py +140 -0
  27. .venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/checkpoint/__init__.py +0 -0
  28. .venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/checkpoint/__pycache__/__init__.cpython-311.pyc +0 -0
  29. .venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/checkpoint/__pycache__/checkpoint_manager.cpython-311.pyc +0 -0
  30. .venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/checkpoint/__pycache__/report_handler.cpython-311.pyc +0 -0
  31. .venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/checkpoint/__pycache__/sync_actor.cpython-311.pyc +0 -0
  32. .venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/checkpoint/checkpoint_manager.py +271 -0
  33. .venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/checkpoint/report_handler.py +111 -0
  34. .venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/checkpoint/sync_actor.py +190 -0
  35. .venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/context.py +281 -0
  36. .venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/controller.py +377 -0
  37. .venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/failure_handling/__pycache__/__init__.cpython-311.pyc +0 -0
  38. .venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/failure_handling/__pycache__/default.cpython-311.pyc +0 -0
  39. .venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/failure_handling/__pycache__/factory.cpython-311.pyc +0 -0
  40. .venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/failure_handling/__pycache__/failure_policy.cpython-311.pyc +0 -0
  41. .venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/failure_handling/default.py +44 -0
  42. .venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/failure_handling/factory.py +13 -0
  43. .venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/scaling_policy/__init__.py +19 -0
  44. .venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/scaling_policy/__pycache__/factory.cpython-311.pyc +0 -0
  45. .venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/scaling_policy/__pycache__/fixed.cpython-311.pyc +0 -0
  46. .venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/scaling_policy/__pycache__/scaling_policy.cpython-311.pyc +0 -0
  47. .venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/scaling_policy/factory.py +13 -0
  48. .venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/scaling_policy/fixed.py +22 -0
  49. .venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/scaling_policy/scaling_policy.py +51 -0
  50. .venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/storage.py +551 -0
.venv/lib/python3.11/site-packages/ray/train/v2/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (185 Bytes).
 
.venv/lib/python3.11/site-packages/ray/train/v2/_internal/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/ray/train/v2/_internal/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (195 Bytes).
 
.venv/lib/python3.11/site-packages/ray/train/v2/_internal/__pycache__/exceptions.cpython-311.pyc ADDED
Binary file (10.1 kB).
 
.venv/lib/python3.11/site-packages/ray/train/v2/_internal/__pycache__/util.cpython-311.pyc ADDED
Binary file (8.22 kB).
 
.venv/lib/python3.11/site-packages/ray/train/v2/_internal/callbacks/__init__.py ADDED
@@ -0,0 +1,14 @@
+ from .accelerators import AcceleratorSetupCallback
+ from .backend_setup import BackendSetupCallback
+ from .datasets import DatasetsSetupCallback
+ from .working_dir_setup import WorkingDirectorySetupCallback
+
+ __all__ = [
+     "AcceleratorSetupCallback",
+     "BackendSetupCallback",
+     "DatasetsSetupCallback",
+     "WorkingDirectorySetupCallback",
+ ]
+
+
+ # DO NOT ADD ANYTHING AFTER THIS LINE.
.venv/lib/python3.11/site-packages/ray/train/v2/_internal/callbacks/__pycache__/accelerators.cpython-311.pyc ADDED
Binary file (7.69 kB).
 
.venv/lib/python3.11/site-packages/ray/train/v2/_internal/callbacks/__pycache__/backend_setup.cpython-311.pyc ADDED
Binary file (2.26 kB).
 
.venv/lib/python3.11/site-packages/ray/train/v2/_internal/callbacks/__pycache__/datasets.cpython-311.pyc ADDED
Binary file (4.62 kB).
 
.venv/lib/python3.11/site-packages/ray/train/v2/_internal/callbacks/__pycache__/metrics.cpython-311.pyc ADDED
Binary file (14.1 kB).
 
.venv/lib/python3.11/site-packages/ray/train/v2/_internal/callbacks/__pycache__/user_callback.cpython-311.pyc ADDED
Binary file (2.52 kB).
 
.venv/lib/python3.11/site-packages/ray/train/v2/_internal/callbacks/__pycache__/working_dir_setup.cpython-311.pyc ADDED
Binary file (1.93 kB).
 
.venv/lib/python3.11/site-packages/ray/train/v2/_internal/callbacks/accelerators.py ADDED
@@ -0,0 +1,151 @@
+ import logging
+ import os
+ from collections import defaultdict
+ from typing import List
+
+ import ray._private.ray_constants as ray_constants
+ from ray._private.ray_constants import env_bool
+ from ray.train import BackendConfig
+ from ray.train.constants import ENABLE_SHARE_CUDA_VISIBLE_DEVICES_ENV
+ from ray.train.v2._internal.execution.callback import WorkerGroupCallback
+ from ray.train.v2._internal.execution.worker_group import ActorMetadata, WorkerGroup
+ from ray.train.v2._internal.util import ray_get_safe
+ from ray.train.v2.api.config import ScalingConfig
+
+ logger = logging.getLogger(__name__)
+
+
+ class AcceleratorSetupCallback(WorkerGroupCallback):
+     """Perform accelerator setup for workers.
+
+     For example, this callback can be used to share CUDA_VISIBLE_DEVICES
+     among workers on the same node.
+     """
+
+     def __init__(self, backend_config: BackendConfig, scaling_config: ScalingConfig):
+         self._backend = backend_config.backend_cls()
+         self._scaling_config = scaling_config
+
+     def after_worker_group_start(self, worker_group: WorkerGroup):
+         self._maybe_share_cuda_visible_devices(worker_group)
+         # TODO: Add support for sharing other accelerator resources.
+
+     def _maybe_share_cuda_visible_devices(self, worker_group: WorkerGroup):
+         share_cuda_visible_devices_enabled = env_bool(
+             ENABLE_SHARE_CUDA_VISIBLE_DEVICES_ENV,
+             self._backend.share_cuda_visible_devices,
+         )
+
+         if (
+             self._scaling_config._resources_per_worker_not_none.get("GPU", 0) > 0
+             and share_cuda_visible_devices_enabled
+         ):
+             _share_cuda_visible_devices(worker_group)
+
+
+ def _share_cuda_visible_devices(worker_group: WorkerGroup):
+     """Sets CUDA_VISIBLE_DEVICES on all workers.
+     For each worker, CUDA_VISIBLE_DEVICES will be set to the GPU IDs
+     visible to all workers on that worker's node.
+     This allows GPU workers on the same node to communicate with one
+     another.
+
+     Example:
+         Setup:
+         - Node1:
+             - Worker1: {0, 1}
+             - Worker2: {2, 3}
+         - Node2:
+             - Worker3: {0, 1}
+         CUDA_VISIBLE_DEVICES:
+         - Worker1: "0,1,2,3"
+         - Worker2: "0,1,2,3"
+         - Worker3: "0,1"
+     """
+     _share_accelerator_ids(
+         worker_group, ray_constants.GPU, ray_constants.CUDA_VISIBLE_DEVICES_ENV_VAR
+     )
+
+
+ def _share_accelerator_ids(
+     worker_group: WorkerGroup, accelerator_name: str, env_var: str
+ ):
+     """Sets the given env_var on all workers.
+     For each worker, the cores/devices are visible to all the
+     workers on that worker's node. This allows workers on the
+     same node to communicate with one another.
+
+     Example:
+         Setup:
+         - Node1:
+             - Worker1: {0, 1}
+             - Worker2: {2, 3}
+         - Node2:
+             - Worker3: {0, 1}
+         NEURON_RT_VISIBLE_CORES/TPU_VISIBLE_CHIPS/...:
+         - Worker1: "0,1,2,3"
+         - Worker2: "0,1,2,3"
+         - Worker3: "0,1"
+
+     Args:
+         accelerator_name: The name of the accelerator.
+         env_var: The name of the environment variable to set.
+     """
+     if not worker_group.has_started():
+         raise RuntimeError(
+             "WorkerGroup must be started before sharing accelerator IDs."
+         )
+
+     worker_metadatas = [worker.metadata for worker in worker_group.get_workers()]
+     visible_accelerator_ids_per_worker = _get_visible_accelerator_ids_per_worker(
+         worker_metadatas=worker_metadatas, accelerator_name=accelerator_name
+     )
+
+     def set_accelerator_ids(accelerator_ids):
+         os.environ[env_var] = accelerator_ids
+
+     futures = []
+     for rank, visible_accelerator_ids in enumerate(visible_accelerator_ids_per_worker):
+         futures.append(
+             worker_group.execute_single_async(
+                 rank, set_accelerator_ids, accelerator_ids=visible_accelerator_ids
+             )
+         )
+     ray_get_safe(futures)
+
+
+ def _get_visible_accelerator_ids_per_worker(
+     worker_metadatas: List[ActorMetadata], accelerator_name: str
+ ) -> List[str]:
+     """Returns a list of comma-separated accelerator IDs visible to each worker.
+
+     All workers on a node should have the same set of visible accelerators,
+     which is the union of the accelerator IDs of the workers on that node.
+
+     Returns:
+         visible_accelerator_ids_per_worker: A list of comma-separated accelerator ID
+             strings. This list is the same length as the number of workers.
+     """
+     for metadata in worker_metadatas:
+         if accelerator_name not in metadata.accelerator_ids:
+             raise ValueError(
+                 f"Accelerator '{accelerator_name}' is not available on all workers. "
+                 f"Got these available accelerators instead: {metadata.accelerator_ids}"
+             )
+
+     node_id_to_accelerator_ids = defaultdict(set)
+
+     for metadata in worker_metadatas:
+         node_id_to_accelerator_ids[metadata.node_id].update(
+             metadata.accelerator_ids[accelerator_name]
+         )
+
+     visible_accelerator_ids_per_worker = []
+     for worker_id in range(len(worker_metadatas)):
+         node_id = worker_metadatas[worker_id].node_id
+         accelerator_ids = sorted(node_id_to_accelerator_ids[node_id])
+         all_resource_ids = ",".join([str(id) for id in accelerator_ids])
+         visible_accelerator_ids_per_worker.append(all_resource_ids)
+
+     return visible_accelerator_ids_per_worker
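
The docstring example above maps per-worker GPU assignments to a per-node union. The following is a minimal standalone sketch of that grouping logic (not part of the uploaded file; plain dictionaries stand in for `ActorMetadata`):

from collections import defaultdict

# Simplified stand-ins for ActorMetadata: node placement plus GPU IDs per worker.
worker_metadatas = [
    {"node_id": "node1", "gpu_ids": [0, 1]},  # Worker1
    {"node_id": "node1", "gpu_ids": [2, 3]},  # Worker2
    {"node_id": "node2", "gpu_ids": [0, 1]},  # Worker3
]

# Union the GPU IDs of all workers that share a node.
node_id_to_gpu_ids = defaultdict(set)
for meta in worker_metadatas:
    node_id_to_gpu_ids[meta["node_id"]].update(meta["gpu_ids"])

# Each worker then sees the full set of GPU IDs on its own node.
visible_per_worker = [
    ",".join(str(i) for i in sorted(node_id_to_gpu_ids[meta["node_id"]]))
    for meta in worker_metadatas
]
print(visible_per_worker)  # ['0,1,2,3', '0,1,2,3', '0,1']
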
.venv/lib/python3.11/site-packages/ray/train/v2/_internal/callbacks/backend_setup.py ADDED
@@ -0,0 +1,27 @@
+ import logging
+
+ from ray.exceptions import RayActorError
+ from ray.train.backend import BackendConfig
+ from ray.train.v2._internal.execution.callback import WorkerGroupCallback
+ from ray.train.v2._internal.execution.worker_group import WorkerGroup
+
+ logger = logging.getLogger(__name__)
+
+
+ class BackendSetupCallback(WorkerGroupCallback):
+     def __init__(self, backend_config: BackendConfig):
+         self._backend_config = backend_config
+         self._backend = backend_config.backend_cls()
+
+     def after_worker_group_start(self, worker_group: WorkerGroup):
+         self._backend.on_start(worker_group, self._backend_config)
+         self._backend.on_training_start(worker_group, self._backend_config)
+
+     def before_worker_group_shutdown(self, worker_group: WorkerGroup):
+         try:
+             self._backend.on_shutdown(worker_group, self._backend_config)
+         except RayActorError:
+             logger.warning(
+                 "Graceful shutdown of backend failed. This is "
+                 "expected if one of the workers has crashed."
+             )
.venv/lib/python3.11/site-packages/ray/train/v2/_internal/callbacks/datasets.py ADDED
@@ -0,0 +1,76 @@
+ import copy
+ from typing import Any, Callable, Dict, List, Union
+
+ import ray.train
+ from ray.data import Dataset
+ from ray.data.context import DataContext
+ from ray.train.v2._internal.execution.callback import WorkerGroupCallback
+ from ray.train.v2._internal.execution.worker_group.worker_group import WorkerGroup
+
+ # A type representing either a ray.data.Dataset or a function that returns a
+ # ray.data.Dataset and accepts no arguments.
+ GenDataset = Union[Dataset, Callable[[], Dataset]]
+
+
+ class DatasetsSetupCallback(WorkerGroupCallback):
+     """The callback to set up Ray Datasets for the worker group."""
+
+     def __init__(
+         self,
+         datasets: Dict[str, GenDataset],
+         data_config: ray.train.DataConfig,
+         scaling_config: ray.train.ScalingConfig,
+     ):
+         self._datasets = datasets
+         self._data_config = data_config
+         self._scaling_config = scaling_config
+
+         # Capture the current DataContext to propagate it to
+         # the Train workers later.
+         # The propagation works in the following way:
+         # 1. This callback is created when the user creates the Trainer.
+         # 2. This callback is then passed to the Controller actor.
+         # 3. Lastly, when the worker group is initialized, the Controller
+         #    calls the `after_worker_group_start` hook to propagate
+         #    the DataContext to the Train workers.
+         self._data_context = copy.deepcopy(DataContext.get_current())
+
+     def get_train_total_resources(
+         self, scaling_config: ray.train.ScalingConfig
+     ) -> Dict[str, float]:
+         """Return the resources reserved for training, so that Data can exclude
+         these resources logically from its available pool."""
+         return scaling_config.total_resources
+
+     def before_init_train_context(
+         self, worker_group: "WorkerGroup"
+     ) -> Dict[str, List[Any]]:
+         # Configure dataset shards
+         datasets = {k: v() if callable(v) else v for k, v in self._datasets.items()}
+         node_ids = [worker.metadata.node_id for worker in worker_group.get_workers()]
+
+         # Notify the DataConfig about the total resources reserved for training.
+         total_train_resources = self.get_train_total_resources(self._scaling_config)
+         self._data_config.set_train_total_resources(
+             total_train_resources.get("CPU", 0), total_train_resources.get("GPU", 0)
+         )
+
+         dataset_shards = self._data_config.configure(
+             datasets,
+             world_size=len(worker_group),
+             worker_handles=None,
+             worker_node_ids=node_ids,
+         )
+         assert len(dataset_shards) == len(worker_group)
+
+         return {"dataset_shards": dataset_shards}
+
+     def after_worker_group_start(self, worker_group: "WorkerGroup"):
+         # Propagate DataContext
+         def _propagate_data_context(ctx: DataContext):
+             DataContext._set_current(ctx)
+
+         worker_group.execute(
+             _propagate_data_context,
+             self._data_context,
+         )
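
The `GenDataset` alias above accepts either a concrete `ray.data.Dataset` or a zero-argument factory. A small sketch of the normalization performed in `before_init_train_context` (not part of the uploaded file; the example datasets are hypothetical and assume Ray Data is installed):

import ray

# Values may be a concrete Dataset or a zero-argument factory (GenDataset).
datasets = {
    "train": ray.data.range(100),       # already materialized
    "val": lambda: ray.data.range(10),  # built lazily when configuring shards
}

# The same normalization used above before calling DataConfig.configure().
materialized = {k: v() if callable(v) else v for k, v in datasets.items()}
assert all(isinstance(d, ray.data.Dataset) for d in materialized.values())
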
.venv/lib/python3.11/site-packages/ray/train/v2/_internal/callbacks/metrics.py ADDED
@@ -0,0 +1,250 @@
+ import threading
+ import time
+ from contextlib import contextmanager
+ from dataclasses import asdict, dataclass, field, fields
+ from typing import Dict, Optional
+
+ from ray.train.v2._internal.execution.callback import (
+     ControllerCallback,
+     TrainContextCallback,
+     WorkerCallback,
+     WorkerGroupCallback,
+ )
+ from ray.train.v2._internal.execution.context import TrainRunContext, get_train_context
+ from ray.train.v2._internal.util import time_monotonic
+ from ray.util.metrics import Gauge
+
+ # Prometheus tag keys for the worker and controller metrics.
+ RUN_NAME_TAG_KEY = "ray_train_run_name"
+ WORKER_WORLD_RANK_TAG_KEY = "ray_train_worker_world_rank"
+
+
+ @dataclass
+ class ControllerMetrics:
+     """A list of Train controller metrics.
+
+     Metric metadata attributes:
+     - description (required): A human-readable description of the metric, also used as
+       the chart description on the Ray Train dashboard.
+     """
+
+     train_worker_group_start_total_time_s: float = field(
+         default=0.0,
+         metadata={
+             "description": (
+                 "Cumulative time in seconds to start worker groups in the Train job."
+             ),
+         },
+     )
+
+     train_worker_group_shutdown_total_time_s: float = field(
+         default=0.0,
+         metadata={
+             "description": (
+                 "Cumulative time in seconds to shutdown worker groups in the Train job."
+             ),
+         },
+     )
+
+
+ @dataclass
+ class WorkerMetrics:
+     """A list of Train worker metrics.
+
+     Metric metadata attributes:
+     - description (required): A human-readable description of the metric, also used as
+       the chart description on the Ray Train dashboard.
+     """
+
+     train_report_total_blocked_time_s: float = field(
+         default=0.0,
+         metadata={
+             "description": (
+                 "Cumulative time in seconds to report a checkpoint to the storage."
+             ),
+         },
+     )
+
+
+ class ControllerMetricsCallback(ControllerCallback, WorkerGroupCallback):
+     # Interval for pushing metrics to Prometheus.
+     LOCAL_METRICS_PUSH_INTERVAL_S: float = 5.0
+     CONTROLLER_TAG_KEYS = (RUN_NAME_TAG_KEY,)
+
+     def __init__(self, train_run_context: TrainRunContext):
+         """
+         This callback is initialized on the driver process and then passed to the
+         controller. This callback collects metrics from the controller actor as well
+         as the metrics related to the worker groups.
+         """
+         self._run_name = train_run_context.get_run_config().name
+         self._thread: Optional[threading.Thread] = None
+         self._thread_stop_event: Optional[threading.Event] = None
+         self._metrics: Optional[ControllerMetrics] = None
+         self._metrics_lock: Optional[threading.Lock] = None
+         self._controller_tag: Dict[str, str] = {}
+         self._metrics_gauges: Dict[str, Gauge] = {}
+
+     def _create_prometheus_controller_metrics(self) -> Dict[str, Gauge]:
+         """Create Prometheus controller metrics for the ControllerMetrics dataclass."""
+         metrics = {}
+         for _field in fields(ControllerMetrics):
+             metric_description = _field.metadata.get("description")
+             metrics[_field.name] = Gauge(
+                 _field.name,
+                 description=metric_description,
+                 tag_keys=self.CONTROLLER_TAG_KEYS,
+             )
+         return metrics
+
+     def after_controller_start(self):
+         """
+         Create a thread to periodically push local metrics to the gauges
+         after the train controller starts.
+         """
+         self._controller_tag = {
+             RUN_NAME_TAG_KEY: self._run_name,
+         }
+         self._thread_stop_event = threading.Event()
+         self._metrics_lock = threading.Lock()
+         self._metrics = ControllerMetrics()
+         self._metrics_gauges = self._create_prometheus_controller_metrics()
+
+         def push_local_metrics():
+             while not self._thread_stop_event.is_set():
+                 with self._metrics_lock:
+                     metrics_dict = asdict(self._metrics)
+                     for metric_name, metric_value in metrics_dict.items():
+                         self._metrics_gauges[metric_name].set(
+                             metric_value, self._controller_tag
+                         )
+                 time.sleep(ControllerMetricsCallback.LOCAL_METRICS_PUSH_INTERVAL_S)
+
+         assert not self._thread
+         self._thread = threading.Thread(target=push_local_metrics, daemon=True)
+         self._thread.start()
+
+     def before_controller_shutdown(self):
+         """
+         Stop the thread that pushes local metrics to the gauges before the
+         controller shuts down.
+         """
+         # Stop the thread that pushes local metrics to the metrics gauges.
+         assert not self._thread_stop_event.is_set()
+         self._thread_stop_event.set()
+         # Reset the metrics to their default values.
+         for _field in fields(self._metrics):
+             self._metrics_gauges[_field.name].set(_field.default, self._controller_tag)
+
+     @contextmanager
+     def on_worker_group_start(self):
+         """
+         Context manager to measure the time taken to start a worker group.
+         """
+         start_time_s = time_monotonic()
+         yield
+         elapsed_time_s = time_monotonic() - start_time_s
+         with self._metrics_lock:
+             self._metrics.train_worker_group_start_total_time_s += elapsed_time_s
+
+     @contextmanager
+     def on_worker_group_shutdown(self):
+         """
+         Context manager to measure the time taken to shut down a worker group.
+         """
+         start_time_s = time_monotonic()
+         yield
+         elapsed_time_s = time_monotonic() - start_time_s
+         with self._metrics_lock:
+             self._metrics.train_worker_group_shutdown_total_time_s += elapsed_time_s
+
+
+ class WorkerMetricsCallback(WorkerCallback, TrainContextCallback):
+     # Interval for pushing metrics to Prometheus.
+     LOCAL_METRICS_PUSH_INTERVAL_S: float = 5.0
+     WORKER_TAG_KEYS = (RUN_NAME_TAG_KEY, WORKER_WORLD_RANK_TAG_KEY)
+
+     def __init__(self, train_run_context: TrainRunContext):
+         """
+         This callback is initialized on the driver process and then passed to the
+         workers. When adding more class attributes, make sure the attributes are
+         picklable.
+
+         TODO: Make callbacks factory methods so that, when they are initialized on
+         the driver process, we do not need to worry about pickling the callback
+         instances.
+         """
+         self._run_name = train_run_context.get_run_config().name
+         self._thread: Optional[threading.Thread] = None
+         self._thread_stop_event: Optional[threading.Event] = None
+         self._metrics_lock: Optional[threading.Lock] = None
+         self._metrics: Optional[WorkerMetrics] = None
+         self._worker_tag: Dict[str, str] = {}
+         self._metrics_gauges: Dict[str, Gauge] = {}
+
+     def _create_prometheus_worker_metrics(self) -> Dict[str, Gauge]:
+         """Create Prometheus worker metrics for the WorkerMetrics dataclass."""
+         metrics = {}
+         for _field in fields(self._metrics):
+             metric_description = _field.metadata.get("description")
+             metrics[_field.name] = Gauge(
+                 _field.name,
+                 description=metric_description,
+                 tag_keys=self.WORKER_TAG_KEYS,
+             )
+         return metrics
+
+     def after_init_train_context(self):
+         """
+         Create a thread to periodically push local metrics to the gauges
+         after the train context is initialized.
+
+         Note:
+             This method should be called after the train context is initialized on
+             each of the workers. The thread should not be created in the `__init__`
+             method, which is called on the train driver process.
+         """
+         self._worker_tag = {
+             RUN_NAME_TAG_KEY: self._run_name,
+             WORKER_WORLD_RANK_TAG_KEY: str(get_train_context().get_world_rank()),
+         }
+         self._thread_stop_event = threading.Event()
+         self._metrics_lock = threading.Lock()
+         self._metrics = WorkerMetrics()
+         self._metrics_gauges = self._create_prometheus_worker_metrics()
+
+         def push_local_metrics():
+             while not self._thread_stop_event.is_set():
+                 with self._metrics_lock:
+                     metrics_dict = asdict(self._metrics)
+                     for metric_name, metric_value in metrics_dict.items():
+                         self._metrics_gauges[metric_name].set(
+                             metric_value, self._worker_tag
+                         )
+                 time.sleep(WorkerMetricsCallback.LOCAL_METRICS_PUSH_INTERVAL_S)
+
+         assert not self._thread
+         self._thread = threading.Thread(target=push_local_metrics, daemon=True)
+         self._thread.start()
+
+     def before_worker_shutdown(self):
+         """
+         Stop the thread that pushes local metrics to the metrics gauges before
+         the worker shuts down.
+         """
+         # Stop the thread that pushes local metrics to the gauges.
+         assert not self._thread_stop_event.is_set()
+         self._thread_stop_event.set()
+         # Reset the metrics to their default values.
+         for _field in fields(self._metrics):
+             self._metrics_gauges[_field.name].set(_field.default, self._worker_tag)
+
+     @contextmanager
+     def on_report(self):
+         """
+         Context manager to measure the time taken to report a checkpoint to the storage.
+         """
+         start_time_s = time_monotonic()
+         yield
+         elapsed_time_s = time_monotonic() - start_time_s
+         with self._metrics_lock:
+             self._metrics.train_report_total_blocked_time_s += elapsed_time_s
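
Both callbacks above follow the same pattern: a dataclass of counters, a context manager that accumulates elapsed time under a lock, and a daemon thread that periodically publishes a snapshot. Below is a minimal sketch of that pattern without the Ray `Gauge` dependency (not part of the uploaded file; the interval is shortened for illustration):

import threading
import time
from contextlib import contextmanager
from dataclasses import asdict, dataclass


@dataclass
class Metrics:
    report_total_blocked_time_s: float = 0.0


metrics = Metrics()
metrics_lock = threading.Lock()
stop_event = threading.Event()


@contextmanager
def on_report():
    # Accumulate the time spent inside the block, like the callbacks above.
    start = time.monotonic()
    yield
    with metrics_lock:
        metrics.report_total_blocked_time_s += time.monotonic() - start


def push_local_metrics(interval_s: float = 0.5):
    # A real implementation would set Prometheus gauges instead of printing.
    while not stop_event.is_set():
        with metrics_lock:
            snapshot = asdict(metrics)
        print(snapshot)
        time.sleep(interval_s)


threading.Thread(target=push_local_metrics, daemon=True).start()
with on_report():
    time.sleep(1.0)  # simulate a blocking checkpoint report
stop_event.set()
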
.venv/lib/python3.11/site-packages/ray/train/v2/_internal/callbacks/user_callback.py ADDED
@@ -0,0 +1,50 @@
+ from typing import Any, Dict, List, Optional
+
+ from ray.train import Checkpoint
+ from ray.train.v2._internal.execution.callback import (
+     ReportCallback,
+     WorkerGroupCallback,
+ )
+ from ray.train.v2._internal.execution.context import TrainRunContext
+ from ray.train.v2._internal.execution.worker_group import WorkerGroupStatus
+ from ray.train.v2.api.callback import UserCallback
+
+
+ class UserCallbackHandler(WorkerGroupCallback, ReportCallback):
+     """Responsible for calling methods of subscribers implementing
+     the `UserCallback` interface.
+     """
+
+     def __init__(
+         self, user_callbacks: List[UserCallback], train_run_context: TrainRunContext
+     ):
+         self._user_callbacks = user_callbacks
+         self._train_run_context = train_run_context
+
+     # --------------------------
+     # ReportCallback
+     # --------------------------
+
+     def after_report(
+         self, metrics: List[Dict[str, Any]], checkpoint: Optional[Checkpoint]
+     ):
+         for user_callback in self._user_callbacks:
+             user_callback.after_report(
+                 run_context=self._train_run_context,
+                 metrics=metrics,
+                 checkpoint=checkpoint,
+             )
+
+     # --------------------------
+     # WorkerGroupCallback
+     # --------------------------
+
+     def after_worker_group_poll_status(self, worker_group_status: WorkerGroupStatus):
+         if not worker_group_status.errors:
+             return
+
+         for user_callback in self._user_callbacks:
+             user_callback.after_exception(
+                 run_context=self._train_run_context,
+                 worker_exceptions=worker_group_status.errors,
+             )
.venv/lib/python3.11/site-packages/ray/train/v2/_internal/callbacks/working_dir_setup.py ADDED
@@ -0,0 +1,24 @@
+ import logging
+ import os
+
+ from ray.train.v2._internal.execution.callback import WorkerGroupCallback
+ from ray.train.v2._internal.execution.context import get_train_context
+ from ray.train.v2._internal.execution.worker_group import WorkerGroup
+
+ logger = logging.getLogger(__name__)
+
+
+ class WorkingDirectorySetupCallback(WorkerGroupCallback):
+     def after_worker_group_start(self, worker_group: WorkerGroup):
+         def chdir_to_working_dir() -> None:
+             """Create the local working directory for the experiment."""
+             local_working_directory = (
+                 get_train_context().get_storage().local_working_directory
+             )
+             os.makedirs(local_working_directory, exist_ok=True)
+             logger.debug(
+                 f"Changing the working directory to: {local_working_directory}"
+             )
+             os.chdir(local_working_directory)
+
+         worker_group.execute(chdir_to_working_dir)
.venv/lib/python3.11/site-packages/ray/train/v2/_internal/constants.py ADDED
@@ -0,0 +1,84 @@
+ import os
+ from typing import Dict
+
+ from ray._private.ray_constants import env_bool, env_set_by_user
+
+ # Unsupported configs can use this value to detect if the user has set it.
+ _UNSUPPORTED = "UNSUPPORTED"
+ _DEPRECATED = "DEPRECATED"
+
+ # The name of the file that is used to validate the storage.
+ VALIDATE_STORAGE_MARKER_FILENAME = ".validate_storage_marker"
+ # The name of the file that is used to store the checkpoint manager snapshot.
+ CHECKPOINT_MANAGER_SNAPSHOT_FILENAME = "checkpoint_manager_snapshot.json"
+
+
+ # =====================
+ # Environment Variables
+ # =====================
+
+ # Polling interval for the Train controller.
+ # This determines how many seconds the controller will wait between
+ # polling the worker group for its status.
+ HEALTH_CHECK_INTERVAL_S_ENV_VAR = "RAY_TRAIN_HEALTH_CHECK_INTERVAL_S"
+ DEFAULT_HEALTH_CHECK_INTERVAL_S: float = 2.0
+
+ # The time in seconds a worker health check must be hanging for
+ # before the controller marks the worker as dead and handles the failure.
+ WORKER_HEALTH_CHECK_TIMEOUT_S_ENV_VAR = "RAY_TRAIN_WORKER_HEALTH_CHECK_TIMEOUT_S"
+ DEFAULT_WORKER_HEALTH_CHECK_TIMEOUT_S: float = 10 * 60
+
+ # Timeout in seconds for the worker group to start.
+ WORKER_GROUP_START_TIMEOUT_S_ENV_VAR = "RAY_TRAIN_WORKER_GROUP_START_TIMEOUT_S"
+ DEFAULT_WORKER_GROUP_START_TIMEOUT_S: float = 30.0
+
+ # Timeout in seconds for `ray.train.report` to block on synchronization barriers,
+ # after which a timeout error will be raised.
+ REPORT_BARRIER_TIMEOUT_S_ENV_VAR = "RAY_TRAIN_REPORT_BARRIER_TIMEOUT_S"
+ DEFAULT_REPORT_BARRIER_TIMEOUT_S: float = 60 * 30
+ # Time in seconds after which `ray.train.report` logs a warning while it is still
+ # waiting for the sync actor's release notification.
+ REPORT_BARRIER_WARN_INTERVAL_S_ENV_VAR = "RAY_TRAIN_REPORT_BARRIER_WARN_INTERVAL_S"
+ DEFAULT_REPORT_BARRIER_WARN_INTERVAL_S: float = 60
+
+ # The environment variable to enable the Ray Train metrics.
+ METRICS_ENABLED_ENV_VAR = "RAY_TRAIN_METRICS_ENABLED"
+
+ # Environment variable to enable the print function patching.
+ ENABLE_PRINT_PATCH_ENV_VAR = "RAY_TRAIN_ENABLE_PRINT_PATCH"
+ DEFAULT_ENABLE_PRINT_PATCH = "1"
+
+ # Whether or not to run the controller as an actor.
+ RUN_CONTROLLER_AS_ACTOR_ENV_VAR = "RAY_TRAIN_RUN_CONTROLLER_AS_ACTOR"
+ DEFAULT_RUN_CONTROLLER_AS_ACTOR = "1"
+
+ # V2 feature flag.
+ V2_ENABLED_ENV_VAR = "RAY_TRAIN_V2_ENABLED"
+
+
+ def is_v2_enabled() -> bool:
+     return env_bool(V2_ENABLED_ENV_VAR, False)
+
+
+ ENV_VARS_TO_PROPAGATE = {
+     V2_ENABLED_ENV_VAR,
+     HEALTH_CHECK_INTERVAL_S_ENV_VAR,
+     WORKER_HEALTH_CHECK_TIMEOUT_S_ENV_VAR,
+     WORKER_GROUP_START_TIMEOUT_S_ENV_VAR,
+     ENABLE_PRINT_PATCH_ENV_VAR,
+ }
+
+
+ def get_env_vars_to_propagate() -> Dict[str, str]:
+     """Returns a dictionary of environment variables that should be propagated
+     from the driver to the controller, and then from the controller
+     to each training worker.
+
+     This way, users only need to set environment variables in one place
+     when launching the script instead of needing to manually set a runtime environment.
+     """
+     env_vars = {}
+     for env_var in ENV_VARS_TO_PROPAGATE:
+         if env_set_by_user(env_var):
+             env_vars[env_var] = os.environ[env_var]
+     return env_vars
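
A sketch of how `get_env_vars_to_propagate` is intended to be used, assuming a Ray build containing this `ray.train.v2` module is installed; only variables explicitly set by the user are forwarded, for example into a `runtime_env` (not part of the uploaded file):

import os

# Set by the user before launching the training script.
os.environ["RAY_TRAIN_V2_ENABLED"] = "1"
os.environ["RAY_TRAIN_HEALTH_CHECK_INTERVAL_S"] = "5"

from ray.train.v2._internal.constants import (
    get_env_vars_to_propagate,
    is_v2_enabled,
)

assert is_v2_enabled()
# Forward only the user-set variables, e.g. via the runtime_env of the
# controller actor and the training workers.
runtime_env = {"env_vars": get_env_vars_to_propagate()}
print(runtime_env)
# e.g. {'env_vars': {'RAY_TRAIN_V2_ENABLED': '1',
#                    'RAY_TRAIN_HEALTH_CHECK_INTERVAL_S': '5'}}
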
.venv/lib/python3.11/site-packages/ray/train/v2/_internal/exceptions.py ADDED
@@ -0,0 +1,170 @@
+ import os
+ from typing import Dict, List, Optional
+
+ from ray.train.v2._internal.constants import (
+     DEFAULT_WORKER_GROUP_START_TIMEOUT_S,
+     DEFAULT_WORKER_HEALTH_CHECK_TIMEOUT_S,
+     REPORT_BARRIER_TIMEOUT_S_ENV_VAR,
+     WORKER_GROUP_START_TIMEOUT_S_ENV_VAR,
+     WORKER_HEALTH_CHECK_TIMEOUT_S_ENV_VAR,
+ )
+
+
+ # TODO: Distinguish between user and system exceptions.
+ class RayTrainError(Exception):
+     """Base class for all Ray Train exceptions."""
+
+
+ class WorkerHealthCheckTimeoutError(RayTrainError):
+     """Exception raised when a worker health check hangs for long enough."""
+
+     def __init__(self, message):
+         timeout = os.getenv(
+             WORKER_HEALTH_CHECK_TIMEOUT_S_ENV_VAR, DEFAULT_WORKER_HEALTH_CHECK_TIMEOUT_S
+         )
+         message += (
+             f"\nSet the {WORKER_HEALTH_CHECK_TIMEOUT_S_ENV_VAR} "
+             "environment variable to increase the timeout "
+             f"(current value: {timeout} seconds)."
+         )
+         super().__init__(message)
+
+
+ class WorkerHealthCheckFailedError(RayTrainError):
+     """Exception raised when a worker health check fails."""
+
+     def __init__(self, message, failure: Exception):
+         super().__init__(message)
+         self._message = message
+         self.health_check_failure = failure
+
+     def __reduce__(self):
+         return (self.__class__, (self._message, self.health_check_failure))
+
+
+ class TrainingFailedError(RayTrainError):
+     """Exception raised when training fails."""
+
+     def __init__(self, worker_failures: Dict[int, Exception]):
+         super().__init__(
+             "Training failed due to worker errors. "
+             "Please inspect the error logs above, "
+             "or access the latest worker failures in this "
+             "exception's `worker_failures` attribute."
+         )
+         self.worker_failures = worker_failures
+
+     def __reduce__(self):
+         return (self.__class__, (self.worker_failures,))
+
+
+ class WorkerGroupStartupTimeoutError(RayTrainError):
+     """Exception raised when the worker group startup times out.
+
+     Example scenario: 4 GPUs are detected in the cluster, but when the workers
+     are actually scheduled, one of the nodes goes down and only 3 GPUs are
+     available. One of the worker tasks may be stuck pending, until a timeout is reached.
+     """
+
+     def __init__(self, num_workers: int):
+         timeout = float(
+             os.environ.get(
+                 WORKER_GROUP_START_TIMEOUT_S_ENV_VAR,
+                 DEFAULT_WORKER_GROUP_START_TIMEOUT_S,
+             )
+         )
+         self.num_workers = num_workers
+         super().__init__(
+             f"The worker group startup timed out after {timeout} seconds waiting "
+             f"for {num_workers} workers. "
+             "Potential causes include: "
+             "(1) temporary insufficient cluster resources while waiting for "
+             "autoscaling (ignore this warning in this case), "
+             "(2) an infeasible resource request where the provided `ScalingConfig` "
+             "cannot be satisfied, "
+             "and (3) transient network issues. "
+             f"Set the {WORKER_GROUP_START_TIMEOUT_S_ENV_VAR} "
+             "environment variable to increase the timeout."
+         )
+
+     def __reduce__(self):
+         return (self.__class__, (self.num_workers,))
+
+
+ class WorkerGroupStartupFailedError(RayTrainError):
+     """Exception raised when the worker group fails to start.
+
+     Example scenario: A worker is scheduled onto a node that dies while
+     the worker actor is initializing.
+     """
+
+
+ class CheckpointManagerInitializationError(RayTrainError):
+     """Exception raised when the checkpoint manager fails to initialize from a snapshot.
+
+     Example scenarios:
+     1. The checkpoint manager snapshot version is old and
+        incompatible with the current version of Ray Train.
+     2. The checkpoint manager snapshot JSON file is corrupted.
+     3. The checkpoint manager snapshot references checkpoints that cannot be found
+        in the run storage path.
+     """
+
+
+ class CollectiveTimeoutError(RayTrainError):
+     """Exception raised when an internal Ray Train collective operation of
+     the worker group times out.
+     """
+
+
+ class BroadcastCollectiveTimeoutError(CollectiveTimeoutError):
+     """Exception raised when the broadcast operation times out.
+
+     There are two main timeout examples:
+     1. If not all workers call `ray.train.report`, the entire worker group will
+        hang until the timeout before raising. This prevents indefinite worker
+        group hangs.
+     2. If a worker is slow in the training loop and fails to reach the broadcast
+        in time, the collective will time out.
+     """
+
+     def __init__(
+         self, time_elapsed: Optional[float], missing_ranks: List[int], timeout_s: float
+     ):
+         self._time_elapsed = time_elapsed
+         self._missing_ranks = missing_ranks
+         self._timeout_s = timeout_s
+
+         message = (
+             f"The broadcast operation timed out after {time_elapsed:.2f} seconds. "
+             "Please make sure all worker ranks call `ray.train.report`. \n"
+             f"The following ranks have not called it: {missing_ranks}\n"
+             f"You can set this timeout with the {REPORT_BARRIER_TIMEOUT_S_ENV_VAR} "
+             f"environment variable (current value: {timeout_s:.2f} s)."
+         )
+         super().__init__(message)
+
+     def __reduce__(self):
+         return (
+             self.__class__,
+             (self._time_elapsed, self._missing_ranks, self._timeout_s),
+         )
+
+
+ class UserExceptionWithTraceback(RayTrainError):
+     """This class wraps a user code exception raised on the worker
+     with its original traceback string, for logging and debugging purposes.
+
+     This is needed because the original exception traceback is not serialized
+     with the exception when it is *returned* back to the main process.
+     """
+
+     def __init__(self, exc: BaseException, traceback_str: str):
+         self._base_exc = exc
+         self._traceback_str = traceback_str
+
+     def __reduce__(self):
+         return (self.__class__, (self._base_exc, self._traceback_str))
+
+     def __str__(self):
+         return self._traceback_str
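
A sketch of why `UserExceptionWithTraceback` exists: the wrapper keeps the formatted traceback across pickling, which a bare exception would lose when returned from a worker process (not part of the uploaded file; assumes this module is importable):

import pickle
import traceback

from ray.train.v2._internal.exceptions import UserExceptionWithTraceback


def user_train_fn():
    raise ValueError("bad hyperparameter")


try:
    user_train_fn()
except Exception as exc:
    wrapped = UserExceptionWithTraceback(exc, traceback.format_exc())

# __reduce__ makes the wrapper pickle cleanly, and str() shows the original
# traceback, which a plainly pickled exception would have dropped.
restored = pickle.loads(pickle.dumps(wrapped))
print(str(restored))
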
.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/__pycache__/callback.cpython-311.pyc ADDED
Binary file (7.91 kB).
 
.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/__pycache__/context.cpython-311.pyc ADDED
Binary file (13.5 kB).
 
.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/__pycache__/controller.cpython-311.pyc ADDED
Binary file (18.3 kB).
 
.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/__pycache__/storage.cpython-311.pyc ADDED
Binary file (28.2 kB).
 
.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/callback.py ADDED
@@ -0,0 +1,140 @@
+ from contextlib import contextmanager
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional
+
+ from ray.train.v2.api.callback import RayTrainCallback
+ from ray.util.annotations import DeveloperAPI
+
+ if TYPE_CHECKING:
+     from ray.train import Checkpoint
+     from ray.train.v2._internal.execution.controller import TrainControllerState
+     from ray.train.v2._internal.execution.failure_handling import FailureDecision
+     from ray.train.v2._internal.execution.scaling_policy import ScalingDecision
+     from ray.train.v2._internal.execution.worker_group import (
+         WorkerGroup,
+         WorkerGroupStatus,
+     )
+
+
+ @DeveloperAPI
+ class WorkerGroupCallback(RayTrainCallback):
+     def before_init_train_context(
+         self, worker_group: "WorkerGroup"
+     ) -> Dict[str, List[Any]]:
+         """Called before initializing the TrainContext for the worker_group.
+
+         Return:
+             A dictionary of additional arguments for TrainContext.
+             The key is the argument name and the value is a list of argument values
+             to pass to the TrainContext constructor of each worker in the worker group.
+         """
+         return {}
+
+     @contextmanager
+     def on_worker_group_start(self):
+         yield
+
+     def after_worker_group_start(self, worker_group: "WorkerGroup"):
+         """Called after the worker group actors are initialized.
+         All workers should be ready to execute tasks."""
+         pass
+
+     def after_worker_group_training_start(self, worker_group: "WorkerGroup"):
+         pass
+
+     @contextmanager
+     def on_worker_group_shutdown(self):
+         yield
+
+     def before_worker_group_shutdown(self, worker_group: "WorkerGroup"):
+         """Called before the worker group is shut down.
+         Workers may be dead at this point due to actor failures, so this method
+         should catch and handle exceptions if attempting to execute tasks."""
+         pass
+
+     def after_worker_group_poll_status(self, worker_group_status: "WorkerGroupStatus"):
+         pass
+
+
+ @DeveloperAPI
+ class ControllerCallback(RayTrainCallback):
+     def after_controller_start(self):
+         """Called immediately after `TrainController.run` is called,
+         before the control loop starts executing."""
+         pass
+
+     def before_controller_shutdown(self):
+         """Called before `TrainController.run` exits,
+         after the control loop has exited."""
+         pass
+
+     def after_controller_state_update(
+         self,
+         previous_state: "TrainControllerState",
+         current_state: "TrainControllerState",
+     ):
+         """Called whenever the controller state is updated."""
+         pass
+
+     def before_controller_execute_failure_decision(
+         self,
+         failure_decision: "FailureDecision",
+         worker_group_status: "WorkerGroupStatus",
+     ):
+         """Called before the controller executes a failure decision."""
+         pass
+
+     def before_controller_execute_scaling_decision(
+         self,
+         scaling_decision: "ScalingDecision",
+         worker_group_status: "WorkerGroupStatus",
+     ):
+         """Called before the controller executes a scaling decision."""
+         pass
+
+
+ @DeveloperAPI
+ class ReportCallback(RayTrainCallback):
+     def after_report(
+         self, metrics: List[Dict[str, Any]], checkpoint: Optional["Checkpoint"]
+     ):
+         """Called after all workers have reported a training result.
+
+         Note that this differs from `after_worker_group_poll_status`,
+         which may only contain a subset of workers that have reported.
+         For example, if only rank 0 is performing checkpointing, then
+         rank 0 would report a training result the slowest.
+         """
+         pass
+
+
+ @DeveloperAPI
+ class WorkerCallback(RayTrainCallback):
+     """
+     Callbacks that are hooked to the worker event.
+
+     These callbacks are created on the train driver process and then
+     copied and passed to all the workers.
+     The execution of these callbacks happens on each of the workers,
+     not on the train driver process.
+     """
+
+     def after_init_train_context(self):
+         pass
+
+     def before_worker_shutdown(self):
+         pass
+
+
+ @DeveloperAPI
+ class TrainContextCallback(RayTrainCallback):
+     """
+     Callbacks that are hooked to the train context event.
+
+     These callbacks are created on the train driver process and then
+     copied and passed to all the workers.
+     The execution of these callbacks happens on the train context of the workers.
+     """
+
+     @contextmanager
+     def on_report(self):
+         yield
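
A sketch of a custom hook built on the `DeveloperAPI` interfaces above; the subclass and its logging behavior are hypothetical, only the base class and hook names come from this file (not part of the uploaded file):

import logging
import time
from contextlib import contextmanager

from ray.train.v2._internal.execution.callback import WorkerGroupCallback

logger = logging.getLogger(__name__)


class WorkerGroupTimerCallback(WorkerGroupCallback):
    """Logs how long each worker group takes to start."""

    @contextmanager
    def on_worker_group_start(self):
        start = time.monotonic()
        yield
        logger.info("Worker group started in %.2f s", time.monotonic() - start)

    def after_worker_group_start(self, worker_group):
        logger.info("Worker group ready with %d workers", len(worker_group))
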
.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/checkpoint/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/checkpoint/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (216 Bytes).
 
.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/checkpoint/__pycache__/checkpoint_manager.cpython-311.pyc ADDED
Binary file (12.8 kB).
 
.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/checkpoint/__pycache__/report_handler.cpython-311.pyc ADDED
Binary file (5.79 kB).
 
.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/checkpoint/__pycache__/sync_actor.cpython-311.pyc ADDED
Binary file (10.5 kB).
 
.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/checkpoint/checkpoint_manager.py ADDED
@@ -0,0 +1,271 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from typing import Any, Dict, List, Optional
3
+
4
+ from ray.air.config import CheckpointConfig
5
+ from ray.train._checkpoint import Checkpoint
6
+ from ray.train._internal.checkpoint_manager import (
7
+ _CheckpointManager,
8
+ _insert_into_sorted_list,
9
+ )
10
+ from ray.train._internal.session import _TrainingResult
11
+ from ray.train.v2._internal.exceptions import CheckpointManagerInitializationError
12
+ from ray.train.v2._internal.execution.callback import ReportCallback
13
+ from ray.train.v2._internal.execution.context import StorageContext
14
+ from ray.train.v2._internal.execution.storage import _delete_fs_path, _exists_at_fs_path
15
+
16
+ try:
17
+ from pydantic import BaseModel
18
+ from pydantic_core import from_json
19
+ except (ImportError, ModuleNotFoundError) as exc:
20
+ raise ImportError(
21
+ "`ray.train.v2` requires the pydantic package, which is missing. "
22
+ "Run the following command to fix this: `pip install pydantic`"
23
+ ) from exc
24
+
25
+
26
+ logger = logging.getLogger(__name__)
27
+
28
+
29
+ class _TrainingResultState(BaseModel):
30
+ # Increment version if the schema changes
31
+ version: int = 0
32
+ checkpoint_dir_name: str
33
+ metrics: dict
34
+
35
+
36
+ class _CheckpointManagerState(BaseModel):
37
+ # Increment version if the schema changes
38
+ version: int = 0
39
+ checkpoint_results: List[_TrainingResultState]
40
+ latest_checkpoint_result: Optional[_TrainingResultState]
41
+
42
+
43
+ def _get_training_result_from_state(
44
+ state: _TrainingResultState,
45
+ storage_context: StorageContext,
46
+ ) -> _TrainingResult:
47
+ """Get a TrainingResult object from a Pydantic state object."""
48
+ return _TrainingResult(
49
+ checkpoint=Checkpoint(
50
+ path=storage_context.build_checkpoint_path_from_name(
51
+ state.checkpoint_dir_name
52
+ ),
53
+ filesystem=storage_context.storage_filesystem,
54
+ ),
55
+ metrics=state.metrics,
56
+ )
57
+
58
+
59
+ def _get_state_from_training_result(
60
+ training_result: _TrainingResult,
61
+ storage_context: StorageContext,
62
+ ) -> _TrainingResultState:
63
+ """Get a Pydantic state object from a TrainingResult object."""
64
+ return _TrainingResultState(
65
+ checkpoint_dir_name=storage_context.extract_checkpoint_dir_name_from_path(
66
+ training_result.checkpoint.path
67
+ ),
68
+ metrics=training_result.metrics,
69
+ )
70
+
71
+
72
+ class CheckpointManager(_CheckpointManager, ReportCallback):
73
+ def __init__(
74
+ self,
75
+ checkpoint_config: CheckpointConfig,
76
+ storage_context: StorageContext,
77
+ ):
78
+ self._storage_context = storage_context
79
+ self._checkpoint_config = checkpoint_config
80
+ super().__init__(checkpoint_config)
81
+ # If the snapshot is found, the checkpoint manager will restore its state.
82
+ self._maybe_load_state_from_storage()
83
+
84
+ def register_checkpoint(self, checkpoint_result: _TrainingResult):
85
+ """Register new checkpoint and add to bookkeeping.
86
+
87
+ This method will register a new checkpoint and add it to the internal
88
+ bookkeeping logic. This means the checkpoint manager will decide if
89
+ this checkpoint should be kept, and if older or worse performing
90
+ checkpoints should be deleted.
91
+
92
+ Args:
93
+ checkpoint: Tracked checkpoint object to add to bookkeeping.
94
+ """
95
+ self._latest_checkpoint_result = checkpoint_result
96
+
97
+ if self._checkpoint_config.checkpoint_score_attribute is not None:
98
+ # If we're ordering by a score, insert the checkpoint
99
+ # so that the list remains sorted.
100
+ _insert_into_sorted_list(
101
+ self._checkpoint_results,
102
+ checkpoint_result,
103
+ key=self._get_checkpoint_score,
104
+ )
105
+ else:
106
+ # If no metric is provided, just append (ordering by time of registration).
107
+ self._checkpoint_results.append(checkpoint_result)
108
+
109
+ results_to_delete = {}
110
+ if self._checkpoint_config.num_to_keep is not None:
111
+ # Delete the bottom (N - K) checkpoints
112
+ worst_results = set(
113
+ self._checkpoint_results[: -self._checkpoint_config.num_to_keep]
114
+ )
115
+ # Except for the latest checkpoint.
116
+ results_to_delete = worst_results - {self._latest_checkpoint_result}
117
+
118
+ # Update internal state before actually deleting them.
119
+ self._checkpoint_results = [
120
+ checkpoint_result
121
+ for checkpoint_result in self._checkpoint_results
122
+ if checkpoint_result not in results_to_delete
123
+ ]
124
+
125
+ # Save the checkpoint manager state to storage.
126
+ # Note: We save the state before deleting the old checkpoints.
127
+ # If deletion happens first and the process crashes, our snapshot
128
+ # may point to some stale checkpoints that are already deleted.
129
+ # TODO: Make this writing operation non-blocking.
130
+ self._write_state_to_storage()
131
+
132
+ # Delete the old checkpoints.
133
+ for checkpoint_result in results_to_delete:
134
+ checkpoint = checkpoint_result.checkpoint
135
+ logger.debug("Deleting checkpoint: ", checkpoint)
136
+ _delete_fs_path(fs=checkpoint.filesystem, fs_path=checkpoint.path)
137
+
138
+ # --------------------------
139
+ # CheckpointManager state
140
+ # --------------------------
141
+
142
+ def _save_state(self) -> str:
143
+ """Save the checkpoint manager state to a JSON str."""
144
+
145
+ checkpoint_results = [
146
+ _get_state_from_training_result(checkpoint_result, self._storage_context)
147
+ for checkpoint_result in self._checkpoint_results
148
+ ]
149
+
150
+ latest_checkpoint_result = (
151
+ _get_state_from_training_result(
152
+ self._latest_checkpoint_result, self._storage_context
153
+ )
154
+ if self._latest_checkpoint_result is not None
155
+ else None
156
+ )
157
+
158
+ manager_snapshot = _CheckpointManagerState(
159
+ checkpoint_results=checkpoint_results,
160
+ latest_checkpoint_result=latest_checkpoint_result,
161
+ )
162
+ return manager_snapshot.model_dump_json()
163
+
164
+ def _load_state(self, json_state: str):
165
+ """Load the checkpoint manager state from a JSON str."""
166
+ try:
167
+ manager_snapshot = _CheckpointManagerState.model_validate(
168
+ from_json(json_state)
169
+ )
170
+ except Exception as e:
171
+ raise CheckpointManagerInitializationError(repr(e)) from e
172
+ self._assert_checkpoints_exist()
173
+
174
+ self._checkpoint_results = [
175
+ _get_training_result_from_state(
176
+ training_result_state, self._storage_context
177
+ )
178
+ for training_result_state in manager_snapshot.checkpoint_results
179
+ ]
180
+
181
+ self._latest_checkpoint_result = (
182
+ _get_training_result_from_state(
183
+ manager_snapshot.latest_checkpoint_result, self._storage_context
184
+ )
185
+ if manager_snapshot.latest_checkpoint_result is not None
186
+ else None
187
+ )
188
+
189
+ def _maybe_load_state_from_storage(self):
190
+ """Load the checkpoint manager state from storage.
191
+ If no snapshot is found, start with a clean state.
192
+ """
193
+ if not _exists_at_fs_path(
194
+ fs=self._storage_context.storage_filesystem,
195
+ fs_path=self._storage_context.checkpoint_manager_snapshot_path,
196
+ ):
197
+ logger.debug(
198
+ "No checkpoint manager snapshot found. "
199
+ "No checkpoint will be available via `ray.train.get_checkpoint`, "
200
+ "so training will start from scratch."
201
+ )
202
+ return
203
+ with self._storage_context.storage_filesystem.open_input_stream(
204
+ self._storage_context.checkpoint_manager_snapshot_path
205
+ ) as f:
206
+ logger.info(
207
+ "A run snapshot was found in storage folder at: "
208
+ f"'{self._storage_context.experiment_fs_path}'\n"
209
+ "This snapshot contains a list of checkpoints reported via "
210
+ "`ray.train.report` and will be loaded. "
211
+ "This allows the latest checkpoint found in the snapshot to be "
212
+ "accessible within your training function via "
213
+ "`ray.train.get_checkpoint`.\n"
214
+ "If you meant to start a brand new training job without any "
215
+ "information about previous checkpoints found in this directory, "
216
+ "please configure a new, unique `RunConfig(name)` or delete the "
217
+ f"existing folder at '{self._storage_context.experiment_fs_path}'."
218
+ )
219
+ json_state = f.read().decode("utf-8")
220
+ self._load_state(json_state)
221
+
222
+ def _write_state_to_storage(self):
223
+ """Write the checkpoint manager state to storage."""
224
+ checkpoint_manager_snapshot = self._save_state()
225
+ with self._storage_context.storage_filesystem.open_output_stream(
226
+ self._storage_context.checkpoint_manager_snapshot_path
227
+ ) as f:
228
+ f.write(checkpoint_manager_snapshot.encode("utf-8"))
229
+
230
+ def _assert_checkpoints_exist(self):
231
+ """Validate the checkpoint manager state.
232
+
233
+ This method validates the checkpoint manager state by checking that
234
+ the checkpoints referenced in the manager snapshot are consistent with the
235
+ checkpoint folders on the experiment storage filesystem.
236
+
237
+ Raises:
238
+ CheckpointManagerInitializationError: If the checkpoint manager snapshot
239
+ is not consistent with the stored checkpoints.
240
+ """
241
+ for checkpoint_result in self._checkpoint_results:
242
+ checkpoint = checkpoint_result.checkpoint
243
+ assert checkpoint is not None
244
+ if not _exists_at_fs_path(
245
+ fs=checkpoint.filesystem, fs_path=checkpoint.path
246
+ ):
247
+ raise CheckpointManagerInitializationError(
248
+ message=(
249
+ "The run snapshot contains a reference to a checkpoint "
250
+ f"that does not exist anymore ({checkpoint}). You are "
251
+ "running in a corrupted run directory `experiment_fs_path`. "
252
+ "Please configure a new, unique `RunConfig(name)` "
253
+ "or delete the existing folder at "
254
+ f"`{self._storage_context.experiment_fs_path}`."
255
+ )
256
+ )
257
+
258
+ # --------------------------
259
+ # ReportCallback
260
+ # --------------------------
261
+
262
+ def after_report(
263
+ self, metrics: List[Dict[str, Any]], checkpoint: Optional[Checkpoint]
264
+ ):
265
+ if not checkpoint:
266
+ return
267
+
268
+ rank_0_metrics = metrics[0]
269
+ self.register_checkpoint(
270
+ _TrainingResult(checkpoint=checkpoint, metrics=rank_0_metrics)
271
+ )
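The `after_report` hook above is the only integration point a report subscriber needs: the handler (see `report_handler.py` below) passes one metrics dict per worker plus the consolidated checkpoint. As a hedged sketch that is not part of this diff, a minimal additional `ReportCallback` could look like the following; the `PrintingReportCallback` name is hypothetical, only the hook signature comes from the code above.

from typing import Any, Dict, List, Optional

from ray.train import Checkpoint
from ray.train.v2._internal.execution.callback import ReportCallback


class PrintingReportCallback(ReportCallback):
    # Hypothetical subscriber that prints the rank-0 metrics of every report.
    def after_report(
        self, metrics: List[Dict[str, Any]], checkpoint: Optional[Checkpoint]
    ):
        # `metrics` holds one dict per worker; index 0 corresponds to rank 0.
        print(f"rank-0 metrics: {metrics[0]}, checkpoint: {checkpoint}")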
.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/checkpoint/report_handler.py ADDED
@@ -0,0 +1,111 @@
1
+ from collections import deque
2
+ from typing import TYPE_CHECKING, Deque, List, Optional
3
+
4
+ from ray.train.v2._internal.execution.callback import (
5
+ ReportCallback,
6
+ WorkerGroupCallback,
7
+ )
8
+ from ray.train.v2._internal.execution.worker_group import WorkerGroup, WorkerGroupStatus
9
+
10
+ if TYPE_CHECKING:
11
+ from ray.train._internal.session import _TrainingResult
12
+
13
+
14
+ class ReportCallbackHandler(WorkerGroupCallback):
15
+ """Consolidate training results from multiple workers and call
16
+ subscribers implementing the `ReportCallback` interface sequentially.
17
+ """
18
+
19
+ def __init__(self, report_callbacks: List[ReportCallback]):
20
+ # Number of workers in the current worker group. It is initialized
21
+ # to be None. It is set to the number of workers when it receives the
22
+ # worker group status for the first time.
23
+ # When a worker group shuts down, self._num_workers is set to None,
24
+ # waiting to be updated when a new worker group status is received again.
25
+ self._num_workers: Optional[int] = None
26
+ # A list of queues holding training results from workers.
27
+ self._training_result_queues: Optional[List[Deque[_TrainingResult]]] = None
28
+
29
+ self._report_callbacks = report_callbacks
30
+
31
+ # --------------------------
32
+ # WorkerGroupCallback
33
+ # --------------------------
34
+
35
+ def after_worker_group_poll_status(
36
+ self, worker_group_status: WorkerGroupStatus
37
+ ) -> None:
38
+ """Handle training results as they roll in from worker status polls.
39
+
40
+ Wait for all workers to report training results to collect
41
+ a consolidated training result.
42
+ """
43
+ # Step 1: Ensure that the internal state (the number of workers and the
44
+ # per-worker training result queues) has already been initialized by
45
+ # `after_worker_group_start` before handling polled worker group statuses.
46
+ assert (
47
+ self._num_workers and self._training_result_queues
48
+ ), "Need to call initialize state with `after_worker_group_start` first."
49
+
50
+ assert self._num_workers == worker_group_status.num_workers, (
51
+ f"The number of workers in the worker group has changed unexpectedly. "
52
+ f"Expected: {self._num_workers}, got: {worker_group_status.num_workers}"
53
+ )
54
+
55
+ # Step 2: Update training_results_queues with poll_results.
56
+ for i in range(self._num_workers):
57
+ training_result = worker_group_status.worker_statuses[i].training_result
58
+ if training_result:
59
+ self._training_result_queues[i].append(training_result)
60
+
61
+ # Directly return if any of the worker result queues are empty.
62
+ if not all(self._training_result_queues):
63
+ return
64
+
65
+ training_results = [q.popleft() for q in self._training_result_queues]
66
+
67
+ # Step 3: Consolidate a list of checkpoints to single checkpoint.
68
+ # Use the first checkpoint as the consolidated checkpoint.
69
+ checkpoint_results = [
70
+ tr for tr in training_results if tr.checkpoint is not None
71
+ ]
72
+
73
+ consolidated_checkpoint = None
74
+ if checkpoint_results:
75
+ # Double check the storage path of the checkpoints in the training results.
76
+ unique_checkpoint_paths = {tr.checkpoint.path for tr in checkpoint_results}
77
+ if len(unique_checkpoint_paths) > 1:
78
+ # TODO: Support for inconsistent checkpoints path from workers
79
+ # instead of hard raising error. Maybe drop this iteration of
80
+ # training results and continue with the next iteration.
81
+ raise RuntimeError(
82
+ "The storage path of the checkpoints in the training results "
83
+ "is not the same. This means the checkpoints are not consistent. "
84
+ "Got a mix of the following checkpoint paths: "
85
+ f"{unique_checkpoint_paths}\n"
86
+ "This is unexpected -- please file a Github issue."
87
+ )
88
+ consolidated_checkpoint = checkpoint_results[0].checkpoint
89
+
90
+ # Step 4: Invoke all dependent `ReportCallback`s.
91
+ metrics_per_worker = [
92
+ training_result.metrics for training_result in training_results
93
+ ]
94
+ for callback in self._report_callbacks:
95
+ callback.after_report(
96
+ metrics=metrics_per_worker,
97
+ checkpoint=consolidated_checkpoint,
98
+ )
99
+
100
+ def after_worker_group_start(self, worker_group: WorkerGroup) -> None:
101
+ """Handle worker group start. Initialize internal states."""
102
+ self._num_workers = len(worker_group)
103
+ self._training_result_queues = [deque() for _ in range(self._num_workers)]
104
+
105
+ def before_worker_group_shutdown(self, worker_group: WorkerGroup) -> None:
106
+ """Handle worker group shutdown. Clear internal states.
107
+
108
+ None of the partial reported results are valid at this point.
109
+ """
110
+ self._num_workers = None
111
+ self._training_result_queues = None
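To make the consolidation rule above concrete, here is a small self-contained sketch (plain Python, not Ray code) of the per-rank buffering in `after_worker_group_poll_status`: results are queued per worker and only released as one consolidated batch once every worker has reported at least once.

from collections import deque

num_workers = 3
queues = [deque() for _ in range(num_workers)]


def on_poll(per_worker_results):
    # One optional result per worker for this poll, mirroring worker_statuses.
    for i, result in enumerate(per_worker_results):
        if result is not None:
            queues[i].append(result)
    # Wait until every worker has at least one buffered result.
    if not all(queues):
        return None
    return [q.popleft() for q in queues]


print(on_poll(["w0-step1", None, "w2-step1"]))  # None: rank 1 has not reported yet
print(on_poll([None, "w1-step1", None]))        # ['w0-step1', 'w1-step1', 'w2-step1']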
.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/checkpoint/sync_actor.py ADDED
@@ -0,0 +1,190 @@
1
+ import asyncio
2
+ import logging
3
+ from contextlib import contextmanager
4
+ from typing import List, Optional, TypeVar
5
+
6
+ import ray
7
+ from ray.train.v2._internal.constants import (
8
+ DEFAULT_REPORT_BARRIER_TIMEOUT_S,
9
+ DEFAULT_REPORT_BARRIER_WARN_INTERVAL_S,
10
+ REPORT_BARRIER_WARN_INTERVAL_S_ENV_VAR,
11
+ )
12
+ from ray.train.v2._internal.exceptions import BroadcastCollectiveTimeoutError
13
+
14
+ T = TypeVar("T", bound=Optional[object])
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
+ BROADCAST_PERIODIC_WARNING = """
19
+ `ray.train.report` has not been called by all {world_size} workers in the group.
20
+
21
+ The workers have been waiting for {max_time_elapsed_s:.2f} s for the following ranks
22
+ to join the `report` call: {missing_ranks}.
23
+
24
+ Please ensure that all workers call `ray.train.report` regardless of whether
25
+ they participate in checkpointing or not (e.g., pass `checkpoint=None` for ranks
26
+ that do not save a checkpoint). Also ensure that workers are not hanging on
27
+ other operations, causing them to miss this synchronization barrier.
28
+
29
+ You can set the {warn_interval_env_var} environment variable to change the frequency
30
+ of this warning (current value: {warn_interval_s} s).
31
+ """
32
+
33
+
34
+ @ray.remote(num_cpus=0) # type: ignore
35
+ class SynchronizationActor:
36
+ """A Ray actor that synchronizes the workers in a distributed training job.
37
+
38
+ This actor forms a synchronization barrier on a group of processes.
39
+ Every time a worker calls the broadcast_from_rank_zero method,
40
+ the counter is incremented. When the counter equals the world size,
41
+ the actor notifies all the workers to continue.
42
+ """
43
+
44
+ def __init__(
45
+ self,
46
+ timeout_s: float = DEFAULT_REPORT_BARRIER_TIMEOUT_S,
47
+ warn_interval_s: float = DEFAULT_REPORT_BARRIER_WARN_INTERVAL_S,
48
+ ):
49
+ self._counter: int = 0
50
+ self._world_size: int = 0
51
+ self._condition = asyncio.Condition()
52
+ self._reduced_data = None
53
+ # The time when workers from different ranks
54
+ # enters the synchronization barrier.
55
+ self._sync_start_times: List[Optional[float]] = []
56
+ # The timeout in seconds for the synchronization barrier.
57
+ self._timeout_s: float = timeout_s
58
+ # The interval in seconds to log a warning when waiting for the barrier.
59
+ self._warn_interval_s: float = warn_interval_s
60
+
61
+ def get_counter(self):
62
+ """Returns the current value of the counter."""
63
+ return self._counter
64
+
65
+ def get_world_size(self):
66
+ """Returns the current value of the world_size."""
67
+ return self._world_size
68
+
69
+ def get_reduced_data(self):
70
+ """Returns the current value of the reduced_data."""
71
+ return self._reduced_data
72
+
73
+ def _clear_states(self):
74
+ """Decrements the counter. When the last worker has called this
75
+ method (i.e. the counter reaches zero), the actor clears its state.
76
+ """
77
+ self._counter -= 1
78
+ if self._counter == 0:
79
+ self._reduced_data = None
80
+ self._world_size = 0
81
+
82
+ def _setup_or_validate_collective_op(self, world_size: int):
83
+ """Sets up the synchronization actor if it has not been set up yet.
84
+ It initializes the world size and the start times for the
85
+ synchronization barrier.
86
+ """
87
+ if self._world_size == 0:
88
+ self._world_size = world_size
89
+ self._sync_start_times = [None] * world_size
90
+ elif world_size != self._world_size:
91
+ raise ValueError(
92
+ f"Expected all callers to provide the same world size. "
93
+ f"Got {world_size} and expected {self._world_size}."
94
+ )
95
+
96
+ @contextmanager
97
+ def _broadcast_collective_context_manager(
98
+ self, world_rank: int, world_size: int, data: T
99
+ ):
100
+ """A context manager that ensures the synchronization barrier is lifted
101
+ after the block of code is executed.
102
+ """
103
+ try:
104
+ self._setup_or_validate_collective_op(world_size)
105
+ if world_rank == 0:
106
+ self._reduced_data = data
107
+ if self._counter < self._world_size:
108
+ self._counter += 1
109
+ yield
110
+ finally:
111
+ self._clear_states()
112
+
113
+ def _get_time_elapsed(self) -> Optional[float]:
114
+ """Return the time elapsed since the first worker entered the barrier.
115
+ If no workers have entered the barrier, returns None.
116
+ """
117
+ start_times = [t for t in self._sync_start_times if t is not None]
118
+ if not start_times:
119
+ return None
120
+
121
+ return asyncio.get_event_loop().time() - min(start_times)
122
+
123
+ def _get_missing_ranks(self) -> List[int]:
124
+ """Returns the ranks that have not entered the synchronization barrier."""
125
+ return [i for i, t in enumerate(self._sync_start_times) if t is None]
126
+
127
+ async def _wait_with_logging(self, condition, world_rank: int):
128
+ """Waits for the condition to be notified, logging a warning every
129
+ `warn_interval_s` seconds while waiting on the synchronization barrier.
130
+ """
131
+ current_time = asyncio.get_event_loop().time()
132
+ self._sync_start_times[world_rank] = current_time
133
+ while True:
134
+ try:
135
+ await asyncio.wait_for(condition.wait(), timeout=self._warn_interval_s)
136
+ return
137
+ # asyncio.wait_for() raises `asyncio.TimeoutError` for asyncio<=3.10
138
+ # and raises `TimeoutError` for asyncio>=3.11
139
+ # https://docs.python.org/3/library/asyncio-task.html#asyncio.wait_for
140
+ # TODO: (hpguo) Make only one worker log the warning message.
141
+ except (asyncio.TimeoutError, TimeoutError):
142
+ logger.warning(
143
+ BROADCAST_PERIODIC_WARNING.format(
144
+ world_size=self._world_size,
145
+ max_time_elapsed_s=self._get_time_elapsed(),
146
+ missing_ranks=self._get_missing_ranks(),
147
+ warn_interval_env_var=REPORT_BARRIER_WARN_INTERVAL_S_ENV_VAR,
148
+ warn_interval_s=self._warn_interval_s,
149
+ )
150
+ )
151
+
152
+ async def broadcast_from_rank_zero(
153
+ self, world_rank: int, world_size: int, data: T
154
+ ) -> T:
155
+ """Broadcasts data from the worker with rank 0 to all other workers.
156
+
157
+ This method is a coroutine that blocks until all workers have called this
158
+ method with their data. The data from the worker with rank 0 will
159
+ be returned.
160
+ """
161
+ # Ensures that all global states manipulation is done within the async context
162
+ # manager which makes the condition variable awaiting and the counter
163
+ # incrementing an atomic operation.
164
+ async with self._condition:
165
+ with self._broadcast_collective_context_manager(
166
+ world_rank, world_size, data
167
+ ):
168
+ # If the counter is equal to the world size, it means the last worker
169
+ # has called the broadcast_from_rank_zero method. The actor notifies
170
+ # all the workers to continue.
171
+ if self._counter == self._world_size:
172
+ self._condition.notify_all()
173
+ return self._reduced_data
174
+ # If the counter is less than the world size, the actor waits for the
175
+ # other workers to call the broadcast_from_rank_zero method.
176
+ try:
177
+ await asyncio.wait_for(
178
+ self._wait_with_logging(self._condition, world_rank),
179
+ timeout=self._timeout_s,
180
+ )
181
+ return self._reduced_data
182
+ except (asyncio.TimeoutError, TimeoutError) as e:
183
+ raise BroadcastCollectiveTimeoutError(
184
+ time_elapsed=self._get_time_elapsed(),
185
+ missing_ranks=self._get_missing_ranks(),
186
+ timeout_s=self._timeout_s,
187
+ ) from e
188
+
189
+ # TODO: Implement a general consensus_from_votes method that takes a callable
190
+ # reduce_fn and a list of votes from each worker. The method returns the consensus
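For reference, the caller side of this actor mirrors what `TrainContext._sync_checkpoint_dir_name_across_ranks` does in `context.py` below: every rank calls `broadcast_from_rank_zero.remote(...)` with its rank, the world size, and its own value, and all ranks get back rank 0's value. A hedged sketch, assuming a local Ray cluster with at least two CPUs:

import ray

from ray.train.v2._internal.execution.checkpoint.sync_actor import SynchronizationActor

ray.init()

# One shared actor per worker group; num_cpus=0 so it reserves no resources.
sync_actor = SynchronizationActor.remote()


@ray.remote
def worker(rank: int, world_size: int):
    # Each rank passes its own value, but only rank 0's value is broadcast back.
    return ray.get(
        sync_actor.broadcast_from_rank_zero.remote(
            world_rank=rank, world_size=world_size, data=f"value-from-rank-{rank}"
        )
    )


print(ray.get([worker.remote(rank, 2) for rank in range(2)]))
# Expected: ['value-from-rank-0', 'value-from-rank-0']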
.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/context.py ADDED
@@ -0,0 +1,281 @@
1
+ import logging
2
+ import threading
3
+ from dataclasses import dataclass
4
+ from queue import Queue
5
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional
6
+
7
+ import ray
8
+ from ray.data.iterator import DataIterator
9
+ from ray.train import Checkpoint
10
+ from ray.train._internal import session
11
+ from ray.train._internal.session import _TrainingResult
12
+ from ray.train.v2._internal.execution.checkpoint.sync_actor import SynchronizationActor
13
+ from ray.train.v2._internal.execution.storage import StorageContext
14
+ from ray.train.v2._internal.util import _copy_doc, invoke_context_managers
15
+ from ray.train.v2.api.config import RunConfig
16
+
17
+ if TYPE_CHECKING:
18
+ from ray.train.v2._internal.execution.callback import TrainContextCallback
19
+ from ray.train.v2._internal.execution.worker_group.thread_runner import ThreadRunner
20
+
21
+
22
+ logger = logging.getLogger(__file__)
23
+
24
+
25
+ @dataclass
26
+ class TrainRunContext:
27
+ """Holds the metadata and context for the current training run."""
28
+
29
+ # TODO: Make this dataclass immutable after refactoring the train context.
30
+
31
+ # The run configuration for the current training run.
32
+ run_config: RunConfig
33
+
34
+ # TODO: Add more fields that are shared across all workers and controllers.
35
+ # For example, StorageContext, ScalingConfig, etc.
36
+
37
+ def get_run_config(self) -> RunConfig:
38
+ """Returns the run config of the current training run."""
39
+ return self.run_config
40
+
41
+
42
+ @dataclass(frozen=True)
43
+ class DistributedContext:
44
+ world_rank: int
45
+ world_size: int
46
+ local_rank: int
47
+ local_world_size: int
48
+ node_rank: int
49
+
50
+
51
+ @dataclass(frozen=True)
52
+ class ExecutionContext:
53
+ """Holds the execution context for the current worker process.
54
+
55
+ Every worker process has a single execution context accessed via the
56
+ `TrainContext`, which includes the training thread that is actually
57
+ running the user code.
58
+ """
59
+
60
+ # A shared synchronization actor that helps broadcast data across ranks.
61
+ synchronization_actor: SynchronizationActor
62
+
63
+ # A queue that receives training results from the user training code.
64
+ # `ray.train.report` in user code populates this queue.
65
+ result_queue: Queue
66
+
67
+ # The thread launcher that runs the user training loop.
68
+ training_thread_runner: "ThreadRunner"
69
+
70
+ # The callbacks that are run in the worker train context.
71
+ train_context_callbacks: List["TrainContextCallback"]
72
+
73
+
74
+ @dataclass
75
+ class TrainContext(TrainRunContext):
76
+ distributed_context: DistributedContext
77
+ execution_context: ExecutionContext
78
+ storage_context: StorageContext
79
+ dataset_shards: Dict[str, DataIterator]
80
+ checkpoint: Optional[Checkpoint] = None
81
+
82
+ @_copy_doc(session.get_metadata)
83
+ def get_metadata(self) -> Dict[str, Any]:
84
+ raise NotImplementedError
85
+
86
+ @_copy_doc(session.get_experiment_name)
87
+ def get_experiment_name(self) -> str:
88
+ # TODO: Resolve run_config.name if it is None
89
+ return self.run_config.name
90
+
91
+ @_copy_doc(session.get_trial_name)
92
+ def get_trial_name(self) -> str:
93
+ raise NotImplementedError
94
+
95
+ @_copy_doc(session.get_trial_id)
96
+ def get_trial_id(self) -> str:
97
+ raise NotImplementedError
98
+
99
+ @_copy_doc(session.get_trial_resources)
100
+ def get_trial_resources(self):
101
+ raise NotImplementedError
102
+
103
+ @_copy_doc(session.get_trial_dir)
104
+ def get_trial_dir(self) -> str:
105
+ raise NotImplementedError
106
+
107
+ @_copy_doc(session.get_world_size)
108
+ def get_world_size(self) -> int:
109
+ return self.distributed_context.world_size
110
+
111
+ @_copy_doc(session.get_world_rank)
112
+ def get_world_rank(self) -> int:
113
+ return self.distributed_context.world_rank
114
+
115
+ @_copy_doc(session.get_local_rank)
116
+ def get_local_rank(self) -> int:
117
+ return self.distributed_context.local_rank
118
+
119
+ @_copy_doc(session.get_local_world_size)
120
+ def get_local_world_size(self) -> int:
121
+ return self.distributed_context.local_world_size
122
+
123
+ @_copy_doc(session.get_node_rank)
124
+ def get_node_rank(self) -> int:
125
+ return self.distributed_context.node_rank
126
+
127
+ @_copy_doc(session.get_storage)
128
+ def get_storage(self):
129
+ return self.storage_context
130
+
131
+ def get_result_queue(self):
132
+ return self.execution_context.result_queue
133
+
134
+ def get_synchronization_actor(self):
135
+ return self.execution_context.synchronization_actor
136
+
137
+ def get_checkpoint(self):
138
+ return self.checkpoint
139
+
140
+ def get_dataset_shard(self, dataset_name: str) -> DataIterator:
141
+ """Returns the :class:`ray.data.DataIterator` shard for this worker.
142
+
143
+ Call :meth:`~ray.data.DataIterator.iter_torch_batches` or
144
+ :meth:`~ray.data.DataIterator.to_tf` on this shard to convert it to the
145
+ appropriate framework-specific data type.
146
+
147
+ Args:
148
+ dataset_name: Name of the dataset shard.
149
+ Returns:
150
+ The ``DataIterator`` shard with the given name for this worker.
151
+ Raises:
152
+ KeyError: If the dataset shard with the given name is not found.
153
+ """
154
+ try:
155
+ return self.dataset_shards[dataset_name]
156
+ except KeyError:
157
+ raise KeyError(
158
+ f"Dataset {dataset_name} not found. Available datasets: "
159
+ f"{list(self.dataset_shards.keys())}."
160
+ )
161
+
162
+ def get_context_callbacks(self) -> List["TrainContextCallback"]:
163
+ return self.execution_context.train_context_callbacks
164
+
165
+ def _sync_checkpoint_dir_name_across_ranks(
166
+ self, checkpoint_dir_name: Optional[str] = None
167
+ ) -> str:
168
+ """Sync the checkpoint dir name across ranks.
169
+
170
+ Args:
171
+ checkpoint_dir_name: The checkpoint dir name to sync.
172
+
173
+ Returns:
174
+ The synced checkpoint dir name.
175
+ """
176
+ # If checkpoint_dir_name is not set, use default checkpoint_dir_name
177
+ # created by the storage context.
178
+ checkpoint_dir_name = (
179
+ checkpoint_dir_name
180
+ or self.storage_context.make_default_checkpoint_dir_name()
181
+ )
182
+ # Get a consensus across ranks on the remote storage path, so distributed
183
+ # checkpoints will be stored to the same place.
184
+ sync_actor = self.get_synchronization_actor()
185
+ return ray.get(
186
+ sync_actor.broadcast_from_rank_zero.remote(
187
+ world_rank=self.distributed_context.world_rank,
188
+ world_size=self.distributed_context.world_size,
189
+ data=checkpoint_dir_name,
190
+ )
191
+ )
192
+
193
+ def _save_checkpoint(
194
+ self,
195
+ checkpoint_dir_name: str,
196
+ metrics: Dict[str, Any],
197
+ checkpoint: Optional[Checkpoint] = None,
198
+ ) -> _TrainingResult:
199
+ """Save the checkpoint to remote storage.
200
+
201
+ Returns:
202
+ The training result object containing the persisted checkpoint.
203
+ """
204
+
205
+ if not checkpoint:
206
+ return _TrainingResult(checkpoint=None, metrics=metrics)
207
+
208
+ # Persist the checkpoint to the remote storage path.
209
+ persisted_checkpoint = self.storage_context.persist_current_checkpoint(
210
+ checkpoint, checkpoint_dir_name
211
+ )
212
+ # Update latest checkpoint as the persisted checkpoint.
213
+ self.checkpoint = persisted_checkpoint
214
+
215
+ return _TrainingResult(checkpoint=persisted_checkpoint, metrics=metrics)
216
+
217
+ def report(
218
+ self,
219
+ metrics: Dict[str, Any],
220
+ checkpoint: Optional[Checkpoint] = None,
221
+ checkpoint_dir_name: Optional[str] = None,
222
+ ):
223
+ """
224
+ Upload checkpoint to remote storage and put a training
225
+ result on the result queue of this worker process.
226
+
227
+ Args:
228
+ metrics: The metrics to report.
229
+ checkpoint: The checkpoint to report.
230
+ checkpoint_dir_name: The name of the checkpoint dir
231
+ in this iteration. Note: If not set, the checkpoint will
232
+ be stored in the default storage path. If set, make sure
233
+ this value is unique for each iteration.
234
+
235
+ TODO: the report function should be implemented in the worker instead
236
+ of in the train context. The train context should only keep the train
237
+ related information and not the worker related actions. This refactor
238
+ would also require the `TrainContextCallback` to be updated as well.
239
+ """
240
+
241
+ with invoke_context_managers(
242
+ [
243
+ callback.on_report
244
+ for callback in self.execution_context.train_context_callbacks
245
+ ]
246
+ ):
247
+ # Step 1: sync the checkpoint dir name across ranks.
248
+ checkpoint_dir_name = self._sync_checkpoint_dir_name_across_ranks(
249
+ checkpoint_dir_name
250
+ )
251
+ # Step 2: save the checkpoint to remote storage.
252
+ training_result = self._save_checkpoint(
253
+ checkpoint_dir_name, metrics, checkpoint
254
+ )
255
+ # Step 3: Report the training result to the result queue.
256
+ # The queue size is set to 1 to avoid accumulating unprocessed results.
257
+ # If the queue is full, the put operation blocks until a result is consumed.
258
+
259
+ # TODO (hpguo): Add a metrics to track the blocking time waiting for the
260
+ # training result to be consumed by the controller.
261
+ self.get_result_queue().put(training_result)
262
+
263
+
264
+ # The global variable holding the current TrainContext
265
+ _train_context: Optional[TrainContext] = None
266
+
267
+ # Thread lock to protect the global TrainContext
268
+ _context_lock = threading.Lock()
269
+
270
+
271
+ def get_train_context() -> TrainContext:
272
+ with _context_lock:
273
+ if _train_context is None:
274
+ raise RuntimeError("TrainContext has not been initialized.")
275
+ return _train_context
276
+
277
+
278
+ def set_train_context(context) -> None:
279
+ global _train_context
280
+ with _context_lock:
281
+ _train_context = context
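As a usage sketch (hypothetical, not part of this diff): code running inside a training worker process interacts with this module only through `get_train_context()`, whose accessors mirror the public `ray.train` session API. `set_train_context(...)` is called by the worker setup code before the user function runs.

from ray.train.v2._internal.execution.context import get_train_context

# Inside the training function, after the worker has set up the context:
ctx = get_train_context()
rank = ctx.get_world_rank()
world_size = ctx.get_world_size()

# Reporting goes through the same path as `ray.train.report`: sync the
# checkpoint dir name across ranks, persist the checkpoint (if any), then
# enqueue the training result for the controller to consume.
ctx.report(metrics={"loss": 0.1, "rank": rank}, checkpoint=None)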
.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/controller.py ADDED
@@ -0,0 +1,377 @@
1
+ import logging
2
+ import os
3
+ import time
4
+ from enum import Enum
5
+ from pathlib import Path
6
+ from typing import Any, Callable, Dict, List, Optional
7
+
8
+ from ray._private.auto_init_hook import wrap_auto_init
9
+ from ray.train import Checkpoint
10
+ from ray.train.v2._internal.constants import (
11
+ DEFAULT_HEALTH_CHECK_INTERVAL_S,
12
+ HEALTH_CHECK_INTERVAL_S_ENV_VAR,
13
+ )
14
+ from ray.train.v2._internal.exceptions import (
15
+ TrainingFailedError,
16
+ WorkerGroupStartupFailedError,
17
+ WorkerGroupStartupTimeoutError,
18
+ )
19
+ from ray.train.v2._internal.execution.callback import (
20
+ ControllerCallback,
21
+ ReportCallback,
22
+ TrainContextCallback,
23
+ WorkerCallback,
24
+ WorkerGroupCallback,
25
+ )
26
+ from ray.train.v2._internal.execution.checkpoint.checkpoint_manager import (
27
+ CheckpointManager,
28
+ )
29
+ from ray.train.v2._internal.execution.checkpoint.report_handler import (
30
+ ReportCallbackHandler,
31
+ )
32
+ from ray.train.v2._internal.execution.context import TrainRunContext
33
+ from ray.train.v2._internal.execution.failure_handling import (
34
+ FailureDecision,
35
+ FailurePolicy,
36
+ )
37
+ from ray.train.v2._internal.execution.scaling_policy import (
38
+ ResizeDecision,
39
+ ScalingDecision,
40
+ ScalingPolicy,
41
+ )
42
+ from ray.train.v2._internal.execution.storage import StorageContext, get_fs_and_path
43
+ from ray.train.v2._internal.execution.worker_group import WorkerGroup, WorkerGroupStatus
44
+ from ray.train.v2._internal.logging.logging import configure_controller_logger
45
+ from ray.train.v2._internal.util import time_monotonic
46
+ from ray.train.v2.api.result import Result
47
+ from ray.train.v2.api.callback import RayTrainCallback
48
+
49
+ logger = logging.getLogger(__name__)
50
+
51
+
52
+ class TrainControllerState(Enum):
53
+ """The possible states that the training controller can be in
54
+ while running the main execution control loop.
55
+
56
+ States:
57
+ RUNNING: The training controller is actively running training tasks.
58
+ RECOVERING: The training controller is in the process of recovering
59
+ from an error.
60
+ INITIALIZING: The train controller is starting up.
61
+ This is always the initial state of the controller.
62
+ ERRORED: A terminal state indicating that training has encountered
63
+ an error and cannot continue.
64
+ FINISHED: A terminal state indicating that training has completed.
65
+ """
66
+
67
+ RUNNING = "RUNNING"
68
+ INITIALIZING = "INITIALIZING"
69
+ RECOVERING = "RECOVERING"
70
+ ERRORED = "ERRORED"
71
+ FINISHED = "FINISHED"
72
+
73
+
74
+ class TrainController:
75
+ """Manages the execution of a distributed training job.
76
+
77
+ Responsibilities include:
78
+ * Triggering the training function to run on the worker group.
79
+ * Monitoring the status of the worker group.
80
+ * Handling scaling decisions by restarting the worker group.
81
+ * Handling failure decisions by restarting the worker group or terminating training.
82
+ * Running callback logic on different hooks in the control loop.
83
+ """
84
+
85
+ worker_group_cls = WorkerGroup
86
+
87
+ def __init__(
88
+ self,
89
+ train_fn: Callable[[Dict[str, Any]], None],
90
+ train_run_context: TrainRunContext,
91
+ scaling_policy: ScalingPolicy,
92
+ failure_policy: FailurePolicy,
93
+ callbacks: Optional[List[RayTrainCallback]] = None,
94
+ # TODO: [Deprecation]
95
+ resume_from_checkpoint: Optional[Checkpoint] = None,
96
+ ):
97
+ self._train_run_context = train_run_context
98
+ configure_controller_logger(self._train_run_context)
99
+ self._train_fn = train_fn
100
+ self._scaling_policy = scaling_policy
101
+ self._failure_policy = failure_policy
102
+ self._run_config = self._train_run_context.run_config
103
+ self._callbacks = callbacks or []
104
+ self._resume_from_checkpoint = resume_from_checkpoint
105
+ self._storage_context = StorageContext(
106
+ storage_path=self._run_config.storage_path,
107
+ experiment_dir_name=self._run_config.name,
108
+ storage_filesystem=self._run_config.storage_filesystem,
109
+ )
110
+
111
+ self._checkpoint_manager = CheckpointManager(
112
+ checkpoint_config=self._run_config.checkpoint_config,
113
+ storage_context=self._storage_context,
114
+ )
115
+ report_handler = ReportCallbackHandler(
116
+ report_callbacks=(
117
+ [self._checkpoint_manager]
118
+ + [c for c in self._callbacks if isinstance(c, ReportCallback)]
119
+ )
120
+ )
121
+
122
+ # Group callbacks by the hooks they're subscribed to.
123
+ self._controller_callbacks = [self._scaling_policy] + [
124
+ c for c in self._callbacks if isinstance(c, ControllerCallback)
125
+ ]
126
+ # Group callbacks that will be propagated to the worker group,
127
+ # train worker and the train context.
128
+ worker_group_callbacks_to_propagate = [report_handler] + [
129
+ c
130
+ for c in self._callbacks
131
+ if isinstance(
132
+ c, (WorkerGroupCallback, WorkerCallback, TrainContextCallback)
133
+ )
134
+ ]
135
+
136
+ self._worker_group = self.worker_group_cls(
137
+ train_run_context=self._train_run_context,
138
+ callbacks=worker_group_callbacks_to_propagate,
139
+ )
140
+ self._state = TrainControllerState.INITIALIZING
141
+
142
+ self._latest_poll_time = float("-inf")
143
+ self._health_check_interval_s = float(
144
+ os.getenv(HEALTH_CHECK_INTERVAL_S_ENV_VAR, DEFAULT_HEALTH_CHECK_INTERVAL_S)
145
+ )
146
+ self._training_failed_error: Optional[TrainingFailedError] = None
147
+
148
+ def _execute_scaling_decision(
149
+ self, decision: ScalingDecision, worker_group_status: WorkerGroupStatus
150
+ ):
151
+ """Executes scaling decisions."""
152
+ for callback in self._controller_callbacks:
153
+ callback.before_controller_execute_scaling_decision(
154
+ decision, worker_group_status
155
+ )
156
+
157
+ if isinstance(decision, ResizeDecision):
158
+ self._restart_worker_group(
159
+ num_workers=decision.num_workers,
160
+ resources_per_worker=decision.resources_per_worker,
161
+ )
162
+
163
+ def _execute_failure_decision(
164
+ self, failure_decision: FailureDecision, worker_group_status: WorkerGroupStatus
165
+ ):
166
+ """Executes failure handling decisions (ex: restart, terminate)."""
167
+ assert worker_group_status.errors
168
+
169
+ for callback in self._controller_callbacks:
170
+ callback.before_controller_execute_failure_decision(
171
+ failure_decision, worker_group_status
172
+ )
173
+
174
+ if failure_decision == FailureDecision.NOOP:
175
+ assert self._state == TrainControllerState.RUNNING
176
+ return
177
+
178
+ errors_str = "\n".join(
179
+ [
180
+ f"[Rank {worker_rank}]\n{error}"
181
+ for worker_rank, error in worker_group_status.errors.items()
182
+ ]
183
+ )
184
+
185
+ if failure_decision == FailureDecision.RESTART:
186
+ logger.error(
187
+ "Restarting training worker group after encountering "
188
+ f"failures on {len(worker_group_status.errors)} worker(s):\n"
189
+ f"{errors_str}"
190
+ )
191
+ # Shutdown the worker group so that we don't keep polling errored tasks.
192
+ self._worker_group.shutdown()
193
+ self._set_state(TrainControllerState.RECOVERING)
194
+ elif failure_decision == FailureDecision.RAISE:
195
+ logger.error(
196
+ "Terminating training worker group after encountering "
197
+ f"failure(s) on {len(worker_group_status.errors)} worker(s):\n"
198
+ f"{errors_str}"
199
+ )
200
+ self._set_state(TrainControllerState.ERRORED)
201
+ self._training_failed_error = TrainingFailedError(
202
+ worker_failures=worker_group_status.errors
203
+ )
204
+ else:
205
+ raise ValueError(f"Unexpected failure decision: {failure_decision}")
206
+
207
+ def _poll_workers(self) -> WorkerGroupStatus:
208
+ # Ensure that the time between polls is at least HEALTH_CHECK_INTERVAL_S.
209
+ time_since_last_poll = time_monotonic() - self._latest_poll_time
210
+ if time_since_last_poll < self._health_check_interval_s:
211
+ remaining_time = max(
212
+ self._health_check_interval_s - time_since_last_poll, 0
213
+ )
214
+ time.sleep(remaining_time)
215
+
216
+ status = self._worker_group.poll_status(timeout=self._health_check_interval_s)
217
+ self._latest_poll_time = time_monotonic()
218
+ return status
219
+
220
+ def _restart_worker_group(self, num_workers: int, resources_per_worker: dict):
221
+ """Restart the worker group and launch the train function."""
222
+ self._worker_group.shutdown()
223
+
224
+ # If there's a latest checkpoint that's been committed,
225
+ # use it to restore the worker group.
226
+ latest_checkpoint_result = self._checkpoint_manager.latest_checkpoint_result
227
+ latest_checkpoint = (
228
+ latest_checkpoint_result.checkpoint if latest_checkpoint_result else None
229
+ )
230
+ placement_strategy = self._scaling_policy.scaling_config.placement_strategy
231
+
232
+ # Start the worker group with the latest checkpoint if there is one.
233
+ # Otherwise, start the worker group with the checkpoint set by controller.
234
+ # Finally, if there is no checkpoint, start the worker group with None.
235
+ try:
236
+ self._worker_group.start(
237
+ train_fn=self._train_fn,
238
+ num_workers=num_workers,
239
+ resources_per_worker=resources_per_worker,
240
+ placement_strategy=placement_strategy,
241
+ checkpoint=latest_checkpoint or self._resume_from_checkpoint,
242
+ )
243
+ except (WorkerGroupStartupTimeoutError, WorkerGroupStartupFailedError) as e:
244
+ logger.error(
245
+ "Retrying the launch of the training worker group. "
246
+ f"The previous launch attempt encountered the following failure:\n{e}"
247
+ )
248
+
249
+ # TODO: Should this logic go through the failure policy?
250
+ # The current logic will always try recovering unconditionally
251
+ # on startup errors without a retry limit.
252
+ self._set_state(TrainControllerState.RECOVERING)
253
+ return
254
+
255
+ # TODO: Consider starting the worker group asynchronously.
256
+ self._set_state(TrainControllerState.RUNNING)
257
+
258
+ def _start(self):
259
+ for callback in self._controller_callbacks:
260
+ callback.after_controller_start()
261
+
262
+ def _shutdown(self):
263
+ self._worker_group.shutdown()
264
+
265
+ for callback in self._controller_callbacks:
266
+ callback.before_controller_shutdown()
267
+
268
+ def get_worker_group(self) -> WorkerGroup:
269
+ return self._worker_group
270
+
271
+ def get_state(self) -> TrainControllerState:
272
+ return self._state
273
+
274
+ def _set_state(self, state: TrainControllerState):
275
+ previous_state = self._state
276
+ self._state = state
277
+
278
+ for callback in self._controller_callbacks:
279
+ callback.after_controller_state_update(previous_state, state)
280
+
281
+ def _run_control_loop_iteration(self):
282
+ """Run a single iteration of the control loop.
283
+
284
+ Steps:
285
+ 1. Poll the worker group for status.
286
+ 2. If the worker group is initializing or recovering from an error,
287
+ make a scaling decision and execute it.
288
+ 3. If the worker group has finished, set the controller state to FINISHED.
289
+ 4. If the worker group has errors, make a failure decision and execute it.
290
+ 5. Otherwise, the worker group is running healthily.
291
+ Query the scaling policy for a scaling decision and execute it.
292
+ """
293
+ assert self.get_state() in (
294
+ TrainControllerState.RUNNING,
295
+ TrainControllerState.RECOVERING,
296
+ TrainControllerState.INITIALIZING,
297
+ ), self.get_state()
298
+
299
+ worker_group_status = self._poll_workers()
300
+
301
+ if worker_group_status.finished and not worker_group_status.errors:
302
+ self._set_state(TrainControllerState.FINISHED)
303
+ return
304
+
305
+ if self.get_state() in (
306
+ TrainControllerState.INITIALIZING,
307
+ TrainControllerState.RECOVERING,
308
+ ):
309
+ scaling_decision = (
310
+ self._scaling_policy.make_decision_for_non_running_worker_group(
311
+ worker_group_status
312
+ )
313
+ )
314
+ self._execute_scaling_decision(scaling_decision, worker_group_status)
315
+ elif self.get_state() == TrainControllerState.RUNNING:
316
+ if worker_group_status.errors:
317
+ failure_decision = self._failure_policy.make_decision(
318
+ worker_group_status
319
+ )
320
+ self._execute_failure_decision(failure_decision, worker_group_status)
321
+ else:
322
+ scaling_decision = (
323
+ self._scaling_policy.make_decision_for_running_worker_group(
324
+ worker_group_status
325
+ )
326
+ )
327
+ self._execute_scaling_decision(scaling_decision, worker_group_status)
328
+
329
+ @wrap_auto_init
330
+ def run(self):
331
+ """Run the main control loop. Exits when training is finished or errored."""
332
+ self._start()
333
+
334
+ while self.get_state() not in (
335
+ TrainControllerState.ERRORED,
336
+ TrainControllerState.FINISHED,
337
+ ):
338
+ self._run_control_loop_iteration()
339
+
340
+ self._shutdown()
341
+
342
+ def get_result(self) -> Result:
343
+ """Get the final training result from the TrainController."""
344
+
345
+ controller_state = self.get_state()
346
+ if controller_state not in (
347
+ TrainControllerState.FINISHED,
348
+ TrainControllerState.ERRORED,
349
+ ):
350
+ raise ValueError(
351
+ f"Cannot get result when controller is in state {controller_state}"
352
+ )
353
+
354
+ latest_checkpoint_result = self._checkpoint_manager.latest_checkpoint_result
355
+ latest_metrics = (
356
+ latest_checkpoint_result.metrics if latest_checkpoint_result else None
357
+ )
358
+ latest_checkpoint = (
359
+ latest_checkpoint_result.checkpoint if latest_checkpoint_result else None
360
+ )
361
+ best_checkpoints = [
362
+ (r.checkpoint, r.metrics)
363
+ for r in self._checkpoint_manager.best_checkpoint_results
364
+ ]
365
+ storage_filesystem, storage_fs_path = get_fs_and_path(
366
+ self._run_config.storage_path, self._run_config.storage_filesystem
367
+ )
368
+ experiment_fs_path = Path(storage_fs_path, self._run_config.name).as_posix()
369
+
370
+ return Result(
371
+ metrics=latest_metrics,
372
+ checkpoint=latest_checkpoint,
373
+ error=self._training_failed_error,
374
+ path=experiment_fs_path,
375
+ best_checkpoints=best_checkpoints,
376
+ _storage_filesystem=storage_filesystem,
377
+ )
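Putting the pieces together, here is a rough driver-side sketch of how a controller is assembled and run. The constructor arguments, `run()`, and `get_result()` come from the class above; the `train_fn`, run name, and storage path are placeholders, and in practice this wiring is normally done by Ray Train's public trainer APIs rather than by user code.

from ray.train import FailureConfig
from ray.train.v2._internal.execution.context import TrainRunContext
from ray.train.v2._internal.execution.controller import TrainController
from ray.train.v2._internal.execution.failure_handling.factory import create_failure_policy
from ray.train.v2._internal.execution.scaling_policy import create_scaling_policy
from ray.train.v2.api.config import RunConfig, ScalingConfig


def train_fn(config):
    # Placeholder user training loop.
    pass


controller = TrainController(
    train_fn=train_fn,
    train_run_context=TrainRunContext(
        run_config=RunConfig(name="example-run", storage_path="/tmp/ray_results")
    ),
    scaling_policy=create_scaling_policy(ScalingConfig(num_workers=2)),
    failure_policy=create_failure_policy(FailureConfig(max_failures=0)),
)
controller.run()  # Blocks until the run is FINISHED or ERRORED.
result = controller.get_result()
print(result.path, result.error)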
.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/failure_handling/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (511 Bytes).
 
.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/failure_handling/__pycache__/default.cpython-311.pyc ADDED
Binary file (2.67 kB).
 
.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/failure_handling/__pycache__/factory.cpython-311.pyc ADDED
Binary file (797 Bytes).
 
.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/failure_handling/__pycache__/failure_policy.cpython-311.pyc ADDED
Binary file (1.7 kB).
 
.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/failure_handling/default.py ADDED
@@ -0,0 +1,44 @@
1
+ import logging
2
+
3
+ from ray.train import FailureConfig
4
+ from ray.train.v2._internal.execution.failure_handling import (
5
+ FailureDecision,
6
+ FailurePolicy,
7
+ )
8
+ from ray.train.v2._internal.execution.worker_group import WorkerGroupStatus
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+
13
+ class DefaultFailurePolicy(FailurePolicy):
14
+ def __init__(self, failure_config: FailureConfig):
15
+ super().__init__(failure_config)
16
+ self._total_failures = 0
17
+
18
+ def make_decision(self, worker_group_status: WorkerGroupStatus) -> FailureDecision:
19
+ if not worker_group_status.errors:
20
+ return FailureDecision.NOOP
21
+
22
+ self._total_failures += 1
23
+
24
+ if self.failure_config.max_failures == -1:
25
+ logger.info(
26
+ "Deciding to RESTART, since infinite retry is enabled. "
27
+ f"Encountered {self._total_failures} failures so far."
28
+ )
29
+ return FailureDecision.RESTART
30
+
31
+ if self._total_failures > self.failure_config.max_failures:
32
+ logger.info(
33
+ "Deciding to TERMINATE, since the total failure count "
34
+ f"({self._total_failures}) exceeded the maximum allowed failures: "
35
+ f"FailureConfig(max_failures={self.failure_config.max_failures})."
36
+ )
37
+ return FailureDecision.RAISE
38
+
39
+ logger.info(
40
+ "Deciding to RESTART, since the total "
41
+ f"failure count ({self._total_failures}) <= "
42
+ f"FailureConfig(max_failures={self.failure_config.max_failures})."
43
+ )
44
+ return FailureDecision.RESTART
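The decision rule above is: with `FailureConfig(max_failures=N)`, the first N failures produce RESTART and failure N+1 produces RAISE, while `max_failures=-1` restarts forever. A hedged illustration follows; the fake status object is a stand-in, since `make_decision` only reads the `errors` attribute of the worker group status.

from types import SimpleNamespace

from ray.train import FailureConfig
from ray.train.v2._internal.execution.failure_handling import DefaultFailurePolicy

policy = DefaultFailurePolicy(FailureConfig(max_failures=1))

failed = SimpleNamespace(errors={0: RuntimeError("worker 0 died")})
healthy = SimpleNamespace(errors={})

print(policy.make_decision(healthy))  # FailureDecision.NOOP
print(policy.make_decision(failed))   # 1st failure -> FailureDecision.RESTART
print(policy.make_decision(failed))   # 2nd failure -> FailureDecision.RAISE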
.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/failure_handling/factory.py ADDED
@@ -0,0 +1,13 @@
1
+ from ray.train import FailureConfig
2
+ from ray.train.v2._internal.execution.failure_handling import (
3
+ DefaultFailurePolicy,
4
+ FailurePolicy,
5
+ )
6
+
7
+
8
+ def create_failure_policy(failure_config: FailureConfig) -> FailurePolicy:
9
+ """Create a failure policy from the given failure config.
10
+
11
+ Defaults to the `DefaultFailurePolicy` implementation.
12
+ """
13
+ return DefaultFailurePolicy(failure_config=failure_config)
.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/scaling_policy/__init__.py ADDED
@@ -0,0 +1,19 @@
1
+ # isort: off
2
+ from .scaling_policy import ScalingDecision, ScalingPolicy, NoopDecision, ResizeDecision
3
+ from .fixed import FixedScalingPolicy
4
+ from .factory import create_scaling_policy
5
+
6
+ # isort: on
7
+
8
+
9
+ __all__ = [
10
+ "ScalingPolicy",
11
+ "FixedScalingPolicy",
12
+ "ScalingDecision",
13
+ "NoopDecision",
14
+ "ResizeDecision",
15
+ "create_scaling_policy",
16
+ ]
17
+
18
+
19
+ # DO NOT ADD ANYTHING AFTER THIS LINE.
.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/scaling_policy/__pycache__/factory.cpython-311.pyc ADDED
Binary file (805 Bytes).
 
.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/scaling_policy/__pycache__/fixed.cpython-311.pyc ADDED
Binary file (1.54 kB).
 
.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/scaling_policy/__pycache__/scaling_policy.cpython-311.pyc ADDED
Binary file (3.09 kB).
 
.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/scaling_policy/factory.py ADDED
@@ -0,0 +1,13 @@
1
+ from ray.train.v2._internal.execution.scaling_policy import (
2
+ FixedScalingPolicy,
3
+ ScalingPolicy,
4
+ )
5
+ from ray.train.v2.api.config import ScalingConfig
6
+
7
+
8
+ def create_scaling_policy(scaling_config: ScalingConfig) -> ScalingPolicy:
9
+ """Create a scaling policy from the given scaling config.
10
+
11
+ Defaults to the `FixedScalingPolicy` implementation.
12
+ """
13
+ return FixedScalingPolicy(scaling_config=scaling_config)
.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/scaling_policy/fixed.py ADDED
@@ -0,0 +1,22 @@
1
+ from ray.train.v2._internal.execution.scaling_policy import (
2
+ NoopDecision,
3
+ ResizeDecision,
4
+ ScalingDecision,
5
+ ScalingPolicy,
6
+ )
7
+ from ray.train.v2._internal.execution.worker_group import WorkerGroupStatus
8
+
9
+
10
+ class FixedScalingPolicy(ScalingPolicy):
11
+ def make_decision_for_non_running_worker_group(
12
+ self, worker_group_status: WorkerGroupStatus
13
+ ) -> ScalingDecision:
14
+ return ResizeDecision(
15
+ num_workers=self.scaling_config.num_workers,
16
+ resources_per_worker=self.scaling_config._resources_per_worker_not_none,
17
+ )
18
+
19
+ def make_decision_for_running_worker_group(
20
+ self, worker_group_status: WorkerGroupStatus
21
+ ) -> ScalingDecision:
22
+ return NoopDecision()
.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/scaling_policy/scaling_policy.py ADDED
@@ -0,0 +1,51 @@
1
+ import abc
2
+ from dataclasses import dataclass
3
+ from typing import Dict
4
+
5
+ from ray.train.v2._internal.execution.callback import ControllerCallback
6
+ from ray.train.v2._internal.execution.worker_group import WorkerGroupStatus
7
+ from ray.train.v2.api.config import ScalingConfig
8
+
9
+
10
+ @dataclass
11
+ class ScalingDecision:
12
+ pass
13
+
14
+
15
+ @dataclass
16
+ class NoopDecision(ScalingDecision):
17
+ pass
18
+
19
+
20
+ @dataclass
21
+ class ResizeDecision(ScalingDecision):
22
+ num_workers: int
23
+ resources_per_worker: Dict[str, float]
24
+
25
+
26
+ class ScalingPolicy(abc.ABC, ControllerCallback):
27
+ """A policy that determines when and how to scale a worker group.
28
+
29
+ This can be used to implement elasticity and fault tolerance.
30
+
31
+ Recovery decisions are made when workers are in an inactive or unhealthy state.
32
+ Upscale decisions are optional and are made when workers are healthy.
33
+ """
34
+
35
+ def __init__(self, scaling_config: ScalingConfig):
36
+ self.scaling_config = scaling_config
37
+
38
+ @abc.abstractmethod
39
+ def make_decision_for_non_running_worker_group(
40
+ self, worker_group_status: WorkerGroupStatus
41
+ ) -> ScalingDecision:
42
+ """Makes a scaling decision when the worker group is initializing
43
+ or recovering from an error."""
44
+ raise NotImplementedError
45
+
46
+ @abc.abstractmethod
47
+ def make_decision_for_running_worker_group(
48
+ self, worker_group_status: WorkerGroupStatus
49
+ ) -> ScalingDecision:
50
+ """Makes a scaling decision when monitoring healthy, running workers."""
51
+ raise NotImplementedError
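Since `ScalingPolicy` is an abstract base class (and also a `ControllerCallback`), elasticity can be added by plugging in a different implementation than `FixedScalingPolicy`. A purely illustrative, hedged sketch of a custom policy that sizes the group to whatever currently fits in the cluster, capped at `num_workers`, and never resizes a healthy running group:

import ray

from ray.train.v2._internal.execution.scaling_policy import (
    NoopDecision,
    ResizeDecision,
    ScalingDecision,
    ScalingPolicy,
)
from ray.train.v2._internal.execution.worker_group import WorkerGroupStatus


class BestEffortScalingPolicy(ScalingPolicy):
    # Illustrative only: not part of Ray Train.

    def make_decision_for_non_running_worker_group(
        self, worker_group_status: WorkerGroupStatus
    ) -> ScalingDecision:
        resources = self.scaling_config._resources_per_worker_not_none
        cpus_per_worker = resources.get("CPU", 1) or 1
        available_cpus = ray.available_resources().get("CPU", 0)
        num_workers = min(
            self.scaling_config.num_workers,
            max(1, int(available_cpus // cpus_per_worker)),
        )
        return ResizeDecision(num_workers=num_workers, resources_per_worker=resources)

    def make_decision_for_running_worker_group(
        self, worker_group_status: WorkerGroupStatus
    ) -> ScalingDecision:
        return NoopDecision()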
.venv/lib/python3.11/site-packages/ray/train/v2/_internal/execution/storage.py ADDED
@@ -0,0 +1,551 @@
1
+ # Try import ray[train] core requirements (defined in setup.py)
2
+ # isort: off
3
+ try:
4
+ import fsspec # noqa
5
+ from fsspec.implementations.local import LocalFileSystem
6
+
7
+ except (ImportError, ModuleNotFoundError) as e:
8
+ raise RuntimeError(
9
+ "fsspec is a required dependency of Ray Train and Ray Tune. "
10
+ "Please install with: `pip install fsspec`"
11
+ ) from e
12
+
13
+ try:
14
+ import pyarrow
15
+ import pyarrow.fs
16
+
17
+ except (ImportError, ModuleNotFoundError) as e:
18
+ raise RuntimeError(
19
+ "pyarrow is a required dependency of Ray Train and Ray Tune. "
20
+ "Please install with: `pip install pyarrow`"
21
+ ) from e
22
+ # isort: on
23
+
24
+ import fnmatch
25
+ import logging
26
+ import os
27
+ import shutil
28
+ from pathlib import Path
29
+ from typing import TYPE_CHECKING, Callable, List, Optional, Tuple, Type, Union
30
+
31
+ from ray.air._internal.filelock import TempFileLock
32
+ from ray.train.constants import _get_ray_train_session_dir
33
+ from ray.train.v2._internal.constants import (
34
+ CHECKPOINT_MANAGER_SNAPSHOT_FILENAME,
35
+ VALIDATE_STORAGE_MARKER_FILENAME,
36
+ )
37
+ from ray.train.v2._internal.util import date_str
38
+ from ray.util.annotations import DeveloperAPI
39
+
40
+ if TYPE_CHECKING:
41
+ from ray.train import Checkpoint
42
+
43
+
44
+ logger = logging.getLogger(__name__)
45
+
46
+
47
+ class _ExcludingLocalFilesystem(LocalFileSystem):
48
+ """LocalFileSystem wrapper to exclude files according to patterns.
49
+
50
+ Args:
51
+ root_path: Root path to strip when matching with the exclude pattern.
52
+ Ex: root_path="/tmp/a/b/c", exclude=["*a*"], will exclude
53
+ /tmp/a/b/c/_a_.txt but not ALL of /tmp/a/*.
54
+ exclude: List of patterns that are applied to files returned by
55
+ ``self.find()``. If a file path matches this pattern, it will
56
+ be excluded.
57
+
58
+ """
59
+
60
+ def __init__(self, root_path: Path, exclude: List[str], **kwargs):
61
+ super().__init__(**kwargs)
62
+ self._exclude = exclude
63
+ self._root_path = root_path
64
+
65
+ @property
66
+ def fsid(self):
67
+ return "_excluding_local"
68
+
69
+ def _should_exclude(self, path: str) -> bool:
70
+ """Return True if `path` (relative to `root_path`) matches any of the
71
+ `self._exclude` patterns."""
72
+ path = Path(path)
73
+ relative_path = path.relative_to(self._root_path).as_posix()
74
+ match_candidates = [relative_path]
75
+ if path.is_dir():
76
+ # Everything is in posix path format ('/')
77
+ match_candidates.append(relative_path + "/")
78
+
79
+ for excl in self._exclude:
80
+ if any(fnmatch.fnmatch(candidate, excl) for candidate in match_candidates):
81
+ return True
82
+ return False
83
+
84
+ def find(self, path, maxdepth=None, withdirs=False, detail=False, **kwargs):
85
+ """Call parent find() and exclude from result."""
86
+ paths = super().find(
87
+ path, maxdepth=maxdepth, withdirs=withdirs, detail=detail, **kwargs
88
+ )
89
+ if detail:
90
+ return {
91
+ path: out
92
+ for path, out in paths.items()
93
+ if not self._should_exclude(path)
94
+ }
95
+ else:
96
+ return [path for path in paths if not self._should_exclude(path)]
97
+
98
+
99
+ def _pyarrow_fs_copy_files(
100
+ source, destination, source_filesystem=None, destination_filesystem=None, **kwargs
101
+ ):
102
+ if isinstance(destination_filesystem, pyarrow.fs.S3FileSystem):
103
+ # Workaround multi-threading issue with pyarrow. Note that use_threads=True
104
+ # is safe for download, just not for uploads, see:
105
+ # https://github.com/apache/arrow/issues/32372
106
+ kwargs.setdefault("use_threads", False)
107
+
108
+ # Use a large chunk size to speed up large checkpoint transfers.
109
+ kwargs.setdefault("chunk_size", 64 * 1024 * 1024)
110
+
111
+ return pyarrow.fs.copy_files(
112
+ source,
113
+ destination,
114
+ source_filesystem=source_filesystem,
115
+ destination_filesystem=destination_filesystem,
116
+ **kwargs,
117
+ )
118
+
119
+
120
+ # TODO(justinvyu): Add unit tests for all these utils.
121
+
122
+
123
+ def _delete_fs_path(fs: pyarrow.fs.FileSystem, fs_path: str):
124
+ is_dir = _is_directory(fs, fs_path)
125
+
126
+ try:
127
+ if is_dir:
128
+ fs.delete_dir(fs_path)
129
+ else:
130
+ fs.delete_file(fs_path)
131
+ except Exception:
132
+ logger.exception(f"Caught exception when deleting path at ({fs}, {fs_path}):")
133
+
134
+
135
+ def _download_from_fs_path(
136
+ fs: pyarrow.fs.FileSystem,
137
+ fs_path: str,
138
+ local_path: str,
139
+ filelock: bool = True,
140
+ ):
141
+ """Downloads a directory or file from (fs, fs_path) to a local path.
142
+
143
+ If fs_path points to a directory:
144
+ - The full directory contents are downloaded directly into `local_path`,
145
+ rather than to a subdirectory of `local_path`.
146
+
147
+ If fs_path points to a file:
148
+ - The file is downloaded to `local_path`, which is expected to be a file path.
149
+
150
+ If the download fails and `local_path` did not previously exist, its
151
+ contents are cleaned up before re-raising.
152
+
153
+ NOTE: This method creates `local_path`'s parent directories if they do not
154
+ already exist. If the download fails, this does NOT clean up all the parent
155
+ directories that were created.
156
+
157
+ Args:
158
+ fs: The filesystem to download from.
159
+ fs_path: The filesystem path (either a directory or a file) to download.
160
+ local_path: The local path to download to.
161
+ filelock: Whether to require a file lock before downloading, useful for
162
+ multiple downloads to the same directory that may be happening in parallel.
163
+
164
+ Raises:
165
+ FileNotFoundError: if (fs, fs_path) doesn't exist.
166
+ """
167
+
168
+ _local_path = Path(local_path).resolve()
169
+ exists_before = _local_path.exists()
170
+ if _is_directory(fs=fs, fs_path=fs_path):
171
+ _local_path.mkdir(parents=True, exist_ok=True)
172
+ else:
173
+ _local_path.parent.mkdir(parents=True, exist_ok=True)
174
+
175
+ try:
176
+ if filelock:
177
+ with TempFileLock(f"{os.path.normpath(local_path)}.lock"):
178
+ _pyarrow_fs_copy_files(fs_path, local_path, source_filesystem=fs)
179
+ else:
180
+ _pyarrow_fs_copy_files(fs_path, local_path, source_filesystem=fs)
181
+ except Exception as e:
182
+ # Clean up the directory if downloading was unsuccessful
183
+ if not exists_before:
184
+ shutil.rmtree(local_path, ignore_errors=True)
185
+ raise e
186
+
187
+
188
+ def _upload_to_fs_path(
189
+ local_path: str,
190
+ fs: pyarrow.fs.FileSystem,
191
+ fs_path: str,
192
+ exclude: Optional[List[str]] = None,
193
+ ) -> None:
194
+ """Uploads a local directory or file to (fs, fs_path).
195
+
196
+ NOTE: This will create all necessary parent directories at the destination.
197
+
198
+ Args:
199
+ local_path: The local path to upload.
200
+ fs: The filesystem to upload to.
201
+ fs_path: The filesystem path where the dir/file will be uploaded to.
202
+ exclude: A list of filename matches to exclude from upload. This includes
203
+ all files under subdirectories as well.
204
+ This pattern will match with the relative paths of all files under
205
+ `local_path`.
206
+ Ex: ["*.png"] to exclude all .png images.
207
+ """
208
+
209
+ if not exclude:
210
+ # TODO(justinvyu): uploading a single file doesn't work
211
+ # (since we always create a directory at fs_path)
212
+ _create_directory(fs=fs, fs_path=fs_path)
213
+ _pyarrow_fs_copy_files(local_path, fs_path, destination_filesystem=fs)
214
+ return
215
+
216
+ _upload_to_uri_with_exclude_fsspec(
217
+ local_path=local_path, fs=fs, fs_path=fs_path, exclude=exclude
218
+ )
219
+
220
+
221
+ def _upload_to_uri_with_exclude_fsspec(
222
+ local_path: str, fs: "pyarrow.fs", fs_path: str, exclude: Optional[List[str]]
223
+ ) -> None:
224
+ local_fs = _ExcludingLocalFilesystem(root_path=local_path, exclude=exclude)
225
+ handler = pyarrow.fs.FSSpecHandler(local_fs)
226
+ source_fs = pyarrow.fs.PyFileSystem(handler)
227
+
228
+ _create_directory(fs=fs, fs_path=fs_path)
229
+ _pyarrow_fs_copy_files(
230
+ local_path, fs_path, source_filesystem=source_fs, destination_filesystem=fs
231
+ )
232
+
233
+
+ def _list_at_fs_path(
+     fs: pyarrow.fs.FileSystem,
+     fs_path: str,
+     file_filter: Callable[[pyarrow.fs.FileInfo], bool] = lambda x: True,
+ ) -> List[str]:
+     """Returns the list of filenames at (fs, fs_path), similar to os.listdir.
+
+     If the path doesn't exist, returns an empty list.
+     """
+     selector = pyarrow.fs.FileSelector(fs_path, allow_not_found=True, recursive=False)
+     return [
+         os.path.relpath(file_info.path.lstrip("/"), start=fs_path.lstrip("/"))
+         for file_info in fs.get_file_info(selector)
+         if file_filter(file_info)
+     ]
+
+
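+ # A short sketch of `_list_at_fs_path` with a custom filter (illustrative only;
+ # the helper name and directory are hypothetical). The filter receives
+ # `pyarrow.fs.FileInfo` objects, so it can select by type, size, or name.
+ def _example_list_checkpoints() -> List[str]:  # pragma: no cover - illustrative sketch
+     import pyarrow.fs
+
+     fs = pyarrow.fs.LocalFileSystem()
+     # Keep only subdirectories whose names look like checkpoint directories.
+     return _list_at_fs_path(
+         fs=fs,
+         fs_path="/tmp/ray_results/exp_name",
+         file_filter=lambda info: not info.is_file
+         and info.base_name.startswith("checkpoint_"),
+     )
+
+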
+ def _exists_at_fs_path(fs: pyarrow.fs.FileSystem, fs_path: str) -> bool:
+     """Returns True if (fs, fs_path) exists."""
+
+     valid = fs.get_file_info(fs_path)
+     return valid.type != pyarrow.fs.FileType.NotFound
+
+
+ def _is_directory(fs: pyarrow.fs.FileSystem, fs_path: str) -> bool:
+     """Returns True if (fs, fs_path) is a directory and False if it is a file.
+
+     Raises:
+         FileNotFoundError: if (fs, fs_path) doesn't exist.
+     """
+
+     file_info = fs.get_file_info(fs_path)
+     if file_info.type == pyarrow.fs.FileType.NotFound:
+         raise FileNotFoundError(f"Path not found: ({fs}, {fs_path})")
+
+     return not file_info.is_file
+
+
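+ # A small sketch combining `_exists_at_fs_path` and `_is_directory`
+ # (illustrative only; the helper name and path are hypothetical).
+ def _example_inspect_path() -> None:  # pragma: no cover - illustrative sketch
+     import pyarrow.fs
+
+     fs = pyarrow.fs.LocalFileSystem()
+     fs_path = "/tmp/ray_results/exp_name"
+     if not _exists_at_fs_path(fs=fs, fs_path=fs_path):
+         print(f"{fs_path} does not exist")
+     elif _is_directory(fs=fs, fs_path=fs_path):
+         print(f"{fs_path} is a directory")
+     else:
+         print(f"{fs_path} is a file")
+
+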
+ def _create_directory(fs: pyarrow.fs.FileSystem, fs_path: str) -> None:
+     """Create a directory at (fs, fs_path).
+
+     Some external filesystems require directories to already exist, or at least
+     the `netloc` to be created (e.g. PyArrow's ``mock://`` filesystem).
+
+     Generally this should be done before and outside of Ray applications. This
+     utility is thus primarily used in testing, e.g. of ``mock://`` URIs.
+     """
+     try:
+         fs.create_dir(fs_path)
+     except Exception:
+         logger.exception(
+             f"Caught exception when creating directory at ({fs}, {fs_path}):"
+         )
+
+
+ def get_fs_and_path(
+     storage_path: Union[str, os.PathLike],
+     storage_filesystem: Optional[pyarrow.fs.FileSystem] = None,
+ ) -> Tuple[pyarrow.fs.FileSystem, str]:
+     """Returns the fs and path from a storage path and an optional custom fs.
+
+     Args:
+         storage_path: A storage path or URI. (ex: s3://bucket/path or /tmp/ray_results)
+         storage_filesystem: A custom filesystem to use. If not provided,
+             this will be auto-resolved by pyarrow. If provided, the storage_path
+             is assumed to be prefix-stripped already, and must be a valid path
+             on the filesystem.
+     """
+     storage_path = str(storage_path)
+
+     if storage_filesystem:
+         return storage_filesystem, storage_path
+
+     return pyarrow.fs.FileSystem.from_uri(storage_path)
+
+
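+ # A short sketch of the two ways to call `get_fs_and_path` (illustrative only;
+ # the helper name and paths are hypothetical).
+ def _example_get_fs_and_path() -> None:  # pragma: no cover - illustrative sketch
+     import pyarrow.fs
+
+     # 1. Let pyarrow auto-resolve the filesystem. A plain path resolves to the
+     #    local filesystem; a URI such as "s3://bucket/path" resolves to the
+     #    corresponding cloud filesystem and returns the prefix-stripped path,
+     #    though that may require credentials to be configured.
+     fs, path = get_fs_and_path("/tmp/ray_results")
+
+     # 2. Provide a custom filesystem explicitly. The path is then assumed to
+     #    already be a valid, prefix-stripped path on that filesystem.
+     custom_fs = pyarrow.fs.LocalFileSystem()
+     fs, path = get_fs_and_path("/tmp/ray_results", storage_filesystem=custom_fs)
+
+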
+ @DeveloperAPI
+ class StorageContext:
+     """Shared context that holds the source of truth for all paths and
+     storage utilities, passed along from the driver to workers.
+
+     This object defines a few types of paths:
+     1. *_fs_path: A path on the `storage_filesystem`. This is a regular path
+        which has been prefix-stripped by pyarrow.fs.FileSystem.from_uri and
+        can be joined with `Path(...).as_posix()`.
+     2. *_driver_staging_path: The temporary staging directory on the local filesystem
+        where driver artifacts are saved to before persisting them to storage.
+     3. trial_working_directory: The local filesystem path that the remote
+        actors' working directories are moved to by default.
+        This is separated from the driver staging path so that driver syncing
+        does not implicitly upload the trial working directory, for trials on the
+        driver node.
+
+     Example with storage_path="mock:///bucket/path?param=1":
+
+         >>> import ray
+         >>> from ray.train._internal.storage import StorageContext
+         >>> import os
+         >>> _ = ray.init()
+         >>> storage = StorageContext(
+         ...     storage_path="mock://netloc/bucket/path?param=1",
+         ...     experiment_dir_name="exp_name",
+         ... )
+         >>> storage.storage_filesystem  # Auto-resolved  # doctest: +ELLIPSIS
+         <pyarrow._fs._MockFileSystem object...
+         >>> storage.experiment_fs_path
+         'bucket/path/exp_name'
+         >>> storage.experiment_driver_staging_path  # doctest: +ELLIPSIS
+         '/tmp/ray/session_.../artifacts/.../exp_name/driver_artifacts'
+         >>> storage.trial_dir_name = "trial_dir"
+         >>> storage.trial_fs_path
+         'bucket/path/exp_name/trial_dir'
+         >>> storage.trial_driver_staging_path  # doctest: +ELLIPSIS
+         '/tmp/ray/session_.../artifacts/.../exp_name/driver_artifacts/trial_dir'
+         >>> storage.trial_working_directory  # doctest: +ELLIPSIS
+         '/tmp/ray/session_.../artifacts/.../exp_name/working_dirs/trial_dir'
+         >>> ray.shutdown()
+
+     Example with storage_path="/tmp/ray_results":
+
+         >>> from ray.train._internal.storage import StorageContext
+         >>> storage = StorageContext(
+         ...     storage_path="/tmp/ray_results",
+         ...     experiment_dir_name="exp_name",
+         ... )
+         >>> storage.storage_fs_path
+         '/tmp/ray_results'
+         >>> storage.experiment_fs_path
+         '/tmp/ray_results/exp_name'
+         >>> storage.storage_filesystem  # Auto-resolved  # doctest: +ELLIPSIS
+         <pyarrow._fs.LocalFileSystem object...
+
+     Internal Usage Examples:
+     - To copy files to the trial directory on the storage filesystem:
+
+         pyarrow.fs.copy_files(
+             local_dir,
+             Path(storage.trial_fs_path, "subdir").as_posix(),
+             destination_filesystem=storage.filesystem
+         )
+
+     .. warning::
+         This is an experimental developer API and is subject to change
+         without notice between versions.
+     """
+
+     def __init__(
+         self,
+         storage_path: Union[str, os.PathLike],
+         experiment_dir_name: str,
+         storage_filesystem: Optional[pyarrow.fs.FileSystem] = None,
+     ):
+         self.custom_fs_provided = storage_filesystem is not None
+
+         # Invariant: (`storage_filesystem`, `storage_path`) is the location where
+         # *all* results can be accessed.
+         self.experiment_dir_name = experiment_dir_name
+
+         self.storage_filesystem, self.storage_fs_path = get_fs_and_path(
+             storage_path, storage_filesystem
+         )
+         self.storage_fs_path = Path(self.storage_fs_path).as_posix()
+
+         self._create_validation_file()
+         self._check_validation_file()
+
+     def __str__(self):
+         return (
+             "StorageContext<\n"
+             f" storage_filesystem='{self.storage_filesystem.type_name}',\n"
+             f" storage_fs_path='{self.storage_fs_path}',\n"
+             f" experiment_dir_name='{self.experiment_dir_name}',\n"
+             ">"
+         )
+
+     def _create_validation_file(self):
+         """On the creation of a storage context, create a validation file at the
+         storage path to verify that the storage path can be written to.
+         This validation file is also used to check whether the storage path is
+         accessible by all nodes in the cluster."""
+         valid_file = Path(
+             self.experiment_fs_path, VALIDATE_STORAGE_MARKER_FILENAME
+         ).as_posix()
+         self.storage_filesystem.create_dir(self.experiment_fs_path)
+         with self.storage_filesystem.open_output_stream(valid_file):
+             pass
+
+     def _check_validation_file(self):
+         """Checks that the validation file exists at the storage path."""
+         valid_file = Path(
+             self.experiment_fs_path, VALIDATE_STORAGE_MARKER_FILENAME
+         ).as_posix()
+         if not _exists_at_fs_path(fs=self.storage_filesystem, fs_path=valid_file):
+             raise RuntimeError(
+                 f"Unable to set up cluster storage with the following settings:\n{self}"
+                 "\nCheck that all nodes in the cluster have read/write access "
+                 "to the configured storage path. `RunConfig(storage_path)` should be "
+                 "set to a cloud storage URI or a shared filesystem path accessible "
+                 "by all nodes in your cluster ('s3://bucket' or '/mnt/nfs'). "
+                 "A local path on the head node is not accessible by worker nodes. "
+                 "See: https://docs.ray.io/en/latest/train/user-guides/persistent-storage.html"  # noqa: E501
+             )
+
+     def persist_current_checkpoint(
+         self, checkpoint: "Checkpoint", checkpoint_dir_name: str
+     ) -> "Checkpoint":
+         """Persists a given checkpoint to the current checkpoint path on the filesystem.
+
+         This method copies the checkpoint files to the storage location.
+         It's up to the user to delete the original checkpoint files if desired.
+
+         (The original checkpoint directory is typically a local temporary directory.)
+
+         Args:
+             checkpoint: The checkpoint to persist to
+                 (fs, experiment_fs_path / checkpoint_dir_name).
+
+         Returns:
+             Checkpoint: A Checkpoint pointing to the persisted checkpoint location.
+         """
+         # TODO(justinvyu): Fix this cyclical import.
+         from ray.train import Checkpoint
+
+         checkpoint_fs_path = self.build_checkpoint_path_from_name(checkpoint_dir_name)
+
+         logger.debug(
+             "Copying checkpoint files to storage path:\n"
+             "({source_fs}, {source}) -> ({dest_fs}, {destination})".format(
+                 source=checkpoint.path,
+                 destination=checkpoint_fs_path,
+                 source_fs=checkpoint.filesystem,
+                 dest_fs=self.storage_filesystem,
+             )
+         )
+
+         # Raise an error if the storage path is not accessible when
+         # attempting to upload a checkpoint from a remote worker.
+         # Ex: If storage_path is a local path, then a validation marker
+         # will only exist on the head node but not the worker nodes.
+         self._check_validation_file()
+
+         self.storage_filesystem.create_dir(checkpoint_fs_path)
+         _pyarrow_fs_copy_files(
+             source=checkpoint.path,
+             destination=checkpoint_fs_path,
+             source_filesystem=checkpoint.filesystem,
+             destination_filesystem=self.storage_filesystem,
+         )
+
+         persisted_checkpoint = Checkpoint(
+             filesystem=self.storage_filesystem,
+             path=checkpoint_fs_path,
+         )
+         logger.info(f"Checkpoint successfully created at: {persisted_checkpoint}")
+         return persisted_checkpoint
+
+     @property
+     def experiment_fs_path(self) -> str:
+         """The path on the `storage_filesystem` to the experiment directory.
+
+         NOTE: This does not have a URI prefix anymore, since it has been stripped
+         by pyarrow.fs.FileSystem.from_uri already. The URI scheme information is
+         kept in `storage_filesystem` instead.
+         """
+         return Path(self.storage_fs_path, self.experiment_dir_name).as_posix()
+
+     @property
+     def local_working_directory(self) -> str:
+         """Every Ray Train worker sets this directory as its working directory."""
+         if self.experiment_dir_name is None:
+             raise RuntimeError(
+                 "Cannot access `local_working_directory` without "
+                 "setting `experiment_dir_name`"
+             )
+         return Path(_get_ray_train_session_dir(), self.experiment_dir_name).as_posix()
+
+     @property
+     def checkpoint_manager_snapshot_path(self) -> str:
+         """The path to the checkpoint manager snapshot file."""
+         return Path(
+             self.experiment_fs_path, CHECKPOINT_MANAGER_SNAPSHOT_FILENAME
+         ).as_posix()
+
+     @staticmethod
+     def get_experiment_dir_name(run_obj: Union[str, Callable, Type]) -> str:
+         from ray.tune.experiment import Experiment
+
+         run_identifier = Experiment.get_trainable_name(run_obj)
+
+         if bool(int(os.environ.get("TUNE_DISABLE_DATED_SUBDIR", 0))):
+             dir_name = run_identifier
+         else:
+             dir_name = "{}_{}".format(run_identifier, date_str())
+         return dir_name
+
+     @staticmethod
+     def make_default_checkpoint_dir_name():
+         """Get the name of the checkpoint directory, based on the current timestamp."""
+         return f"checkpoint_{date_str(include_ms=True)}"
+
+     def extract_checkpoint_dir_name_from_path(self, checkpoint_path: str) -> str:
+         """Get the checkpoint name from the checkpoint path.
+         The parent directory of the checkpoint path should be the experiment directory.
+         """
+         # TODO: Use pathlib to extract the name once the minimum supported
+         # Python version is at least 3.9.
+         experiment_fs_path = self.experiment_fs_path + "/"
+         if not checkpoint_path.startswith(experiment_fs_path):
+             raise ValueError(
+                 f"Checkpoint path {checkpoint_path} is not under the experiment "
+                 f"directory {self.experiment_fs_path}."
+             )
+         return checkpoint_path[len(experiment_fs_path) :]
+
+     def build_checkpoint_path_from_name(self, checkpoint_name: str) -> str:
+         """Get the checkpoint path from the checkpoint name.
+         The parent directory of the checkpoint path should be the experiment directory.
+         """
+         return Path(self.experiment_fs_path, checkpoint_name).as_posix()
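+
+     # A brief round-trip sketch for the two checkpoint-path helpers above
+     # (illustrative only; the method name and checkpoint name are hypothetical).
+     def _example_checkpoint_path_round_trip(self) -> None:  # pragma: no cover
+         # Building a checkpoint path and extracting its name back are inverses:
+         # with experiment_fs_path == "bucket/path/exp_name", the built path is
+         # "bucket/path/exp_name/checkpoint_000001".
+         checkpoint_path = self.build_checkpoint_path_from_name("checkpoint_000001")
+         assert (
+             self.extract_checkpoint_dir_name_from_path(checkpoint_path)
+             == "checkpoint_000001"
+         )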