Add files using upload-large-folder tool
This view is limited to 50 files because it contains too many changes.
- .venv/lib/python3.11/site-packages/ray/train/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/train/__pycache__/backend.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/train/__pycache__/base_trainer.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/train/__pycache__/constants.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/train/__pycache__/context.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/train/__pycache__/data_parallel_trainer.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/train/__pycache__/error.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/train/__pycache__/predictor.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/train/__pycache__/session.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/train/__pycache__/utils.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/train/_internal/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/ray/train/_internal/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/train/_internal/__pycache__/accelerator.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/train/_internal/__pycache__/backend_executor.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/train/_internal/__pycache__/checkpoint_manager.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/train/_internal/__pycache__/data_config.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/train/_internal/__pycache__/dl_predictor.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/train/_internal/__pycache__/framework_checkpoint.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/train/_internal/__pycache__/session.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/train/_internal/__pycache__/storage.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/train/_internal/__pycache__/syncer.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/train/_internal/__pycache__/utils.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/train/_internal/__pycache__/worker_group.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/train/_internal/accelerator.py +5 -0
- .venv/lib/python3.11/site-packages/ray/train/_internal/backend_executor.py +830 -0
- .venv/lib/python3.11/site-packages/ray/train/_internal/checkpoint_manager.py +185 -0
- .venv/lib/python3.11/site-packages/ray/train/_internal/data_config.py +139 -0
- .venv/lib/python3.11/site-packages/ray/train/_internal/dl_predictor.py +103 -0
- .venv/lib/python3.11/site-packages/ray/train/_internal/framework_checkpoint.py +45 -0
- .venv/lib/python3.11/site-packages/ray/train/_internal/session.py +1163 -0
- .venv/lib/python3.11/site-packages/ray/train/_internal/state/__init__.py +14 -0
- .venv/lib/python3.11/site-packages/ray/train/_internal/state/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/train/_internal/state/__pycache__/schema.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/train/_internal/state/__pycache__/state_actor.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/train/_internal/state/__pycache__/state_manager.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/train/_internal/state/schema.py +158 -0
- .venv/lib/python3.11/site-packages/ray/train/_internal/state/state_actor.py +62 -0
- .venv/lib/python3.11/site-packages/ray/train/_internal/state/state_manager.py +126 -0
- .venv/lib/python3.11/site-packages/ray/train/_internal/storage.py +725 -0
- .venv/lib/python3.11/site-packages/ray/train/_internal/syncer.py +490 -0
- .venv/lib/python3.11/site-packages/ray/train/_internal/utils.py +239 -0
- .venv/lib/python3.11/site-packages/ray/train/_internal/worker_group.py +426 -0
- .venv/lib/python3.11/site-packages/ray/train/horovod/__init__.py +22 -0
- .venv/lib/python3.11/site-packages/ray/train/horovod/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/train/horovod/__pycache__/config.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/train/horovod/__pycache__/horovod_trainer.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/train/horovod/config.py +159 -0
- .venv/lib/python3.11/site-packages/ray/train/horovod/horovod_trainer.py +202 -0
- .venv/lib/python3.11/site-packages/ray/train/lightning/__init__.py +39 -0
- .venv/lib/python3.11/site-packages/ray/train/lightning/__pycache__/__init__.cpython-311.pyc +0 -0
.venv/lib/python3.11/site-packages/ray/train/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (2.69 kB)

.venv/lib/python3.11/site-packages/ray/train/__pycache__/backend.cpython-311.pyc
ADDED
Binary file (3.3 kB)

.venv/lib/python3.11/site-packages/ray/train/__pycache__/base_trainer.cpython-311.pyc
ADDED
Binary file (37.6 kB)

.venv/lib/python3.11/site-packages/ray/train/__pycache__/constants.cpython-311.pyc
ADDED
Binary file (3.1 kB)

.venv/lib/python3.11/site-packages/ray/train/__pycache__/context.cpython-311.pyc
ADDED
Binary file (7.62 kB)

.venv/lib/python3.11/site-packages/ray/train/__pycache__/data_parallel_trainer.cpython-311.pyc
ADDED
Binary file (26.8 kB)

.venv/lib/python3.11/site-packages/ray/train/__pycache__/error.cpython-311.pyc
ADDED
Binary file (671 Bytes)

.venv/lib/python3.11/site-packages/ray/train/__pycache__/predictor.cpython-311.pyc
ADDED
Binary file (12.6 kB)

.venv/lib/python3.11/site-packages/ray/train/__pycache__/session.cpython-311.pyc
ADDED
Binary file (181 Bytes)

.venv/lib/python3.11/site-packages/ray/train/__pycache__/utils.cpython-311.pyc
ADDED
Binary file (896 Bytes)

.venv/lib/python3.11/site-packages/ray/train/_internal/__init__.py
ADDED
File without changes

.venv/lib/python3.11/site-packages/ray/train/_internal/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (192 Bytes)

.venv/lib/python3.11/site-packages/ray/train/_internal/__pycache__/accelerator.cpython-311.pyc
ADDED
Binary file (551 Bytes)

.venv/lib/python3.11/site-packages/ray/train/_internal/__pycache__/backend_executor.cpython-311.pyc
ADDED
Binary file (36.6 kB)

.venv/lib/python3.11/site-packages/ray/train/_internal/__pycache__/checkpoint_manager.cpython-311.pyc
ADDED
Binary file (8.59 kB)

.venv/lib/python3.11/site-packages/ray/train/_internal/__pycache__/data_config.cpython-311.pyc
ADDED
Binary file (6.97 kB)

.venv/lib/python3.11/site-packages/ray/train/_internal/__pycache__/dl_predictor.cpython-311.pyc
ADDED
Binary file (5.64 kB)

.venv/lib/python3.11/site-packages/ray/train/_internal/__pycache__/framework_checkpoint.cpython-311.pyc
ADDED
Binary file (2.52 kB)

.venv/lib/python3.11/site-packages/ray/train/_internal/__pycache__/session.cpython-311.pyc
ADDED
Binary file (46.4 kB)

.venv/lib/python3.11/site-packages/ray/train/_internal/__pycache__/storage.cpython-311.pyc
ADDED
Binary file (36.2 kB)

.venv/lib/python3.11/site-packages/ray/train/_internal/__pycache__/syncer.cpython-311.pyc
ADDED
Binary file (23.8 kB)

.venv/lib/python3.11/site-packages/ray/train/_internal/__pycache__/utils.cpython-311.pyc
ADDED
Binary file (11.5 kB)

.venv/lib/python3.11/site-packages/ray/train/_internal/__pycache__/worker_group.cpython-311.pyc
ADDED
Binary file (21.4 kB)
.venv/lib/python3.11/site-packages/ray/train/_internal/accelerator.py
ADDED
@@ -0,0 +1,5 @@
import abc


class Accelerator(abc.ABC):
    """A utility that contains methods to accelerate training."""
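The Accelerator added above is only an abstract marker class. As a hedged illustration (the subclass name and method below are invented for this sketch and are not part of the committed files), a backend would derive from it and attach its own helpers:

import abc


class Accelerator(abc.ABC):
    """A utility that contains methods to accelerate training."""


# Hypothetical subclass, shown only to illustrate how the abstract base is
# meant to be extended; it does not exist in the diff above.
class NoOpAccelerator(Accelerator):
    def describe(self) -> str:
        # Purely illustrative helper method.
        return "no-op accelerator used for demonstration"


print(NoOpAccelerator().describe())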
.venv/lib/python3.11/site-packages/ray/train/_internal/backend_executor.py
ADDED
@@ -0,0 +1,830 @@
import logging
import os
import time
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar

import ray
import ray._private.ray_constants as ray_constants
from ray._private.ray_constants import env_integer
from ray.data import Dataset
from ray.exceptions import RayActorError
from ray.train import Checkpoint, DataConfig
from ray.train._internal.session import (
    TrialInfo,
    _TrainingResult,
    get_session,
    init_session,
    shutdown_session,
)
from ray.train._internal.storage import StorageContext
from ray.train._internal.utils import check_for_failure
from ray.train._internal.worker_group import WorkerGroup
from ray.train.backend import BackendConfig
from ray.train.constants import (
    ENABLE_DETAILED_AUTOFILLED_METRICS_ENV,
    ENABLE_SHARE_CUDA_VISIBLE_DEVICES_ENV,
    ENABLE_SHARE_NEURON_CORES_ACCELERATOR_ENV,
    ENABLE_SHARE_NPU_RT_VISIBLE_DEVICES_ENV,
    ENABLE_SHARE_ROCR_VISIBLE_DEVICES_ENV,
    RAY_TRAIN_ENABLE_STATE_TRACKING,
    TRAIN_ENABLE_WORKER_SPREAD_ENV,
    TRAIN_PLACEMENT_GROUP_TIMEOUT_S_ENV,
)
from ray.util.placement_group import get_current_placement_group, remove_placement_group

T = TypeVar("T")

logger = logging.getLogger(__name__)


class TrainBackendError(Exception):
    """Errors with BackendExecutor that should not be exposed to user."""


class TrainingWorkerError(Exception):
    """Raised if a worker fails during training."""


@dataclass
class ResourceConfig:
    """
    Resource configuration for resource_ids to share between workers.

    Args:
        resource_name: The name of the resource to configure
            (Example: "neuron_cores" or "gpu").
        resource_enable_sharing_env_var: The environment variable to
            check if the resource should be shared.
        share_resource_ids_env_var: The environment variable to configure for
            sharing the resources with other workers.
    """

    resource_name: str
    resource_enable_sharing_env_var: str
    share_resource_ids_env_var: str


class BackendExecutor:
    """Main execution class for training backends.

    This class holds a worker group and is responsible for executing the
    training function on the workers, and collecting intermediate results
    from ``session.report()``.

    Args:
        backend_config: The configurations for this
            specific backend.
        num_workers: Number of workers to use for training.
        resources_per_worker (Optional[Dict[str, float]]):
            Dictionary specifying the resources that will be
            requested for each worker. Defaults to {"CPU": 1}.
        max_retries: Number of retries when Ray actors fail.
            Defaults to 3. Set to -1 for unlimited retries.
    """

    def __init__(
        self,
        backend_config: BackendConfig,
        # TODO(xwjiang): Legacy Ray Train trainer clean up!
        trial_info: Optional[TrialInfo] = None,
        num_workers: int = 1,
        resources_per_worker: Optional[Dict[str, float]] = None,
        max_retries: int = 3,
    ):
        if resources_per_worker is None:
            self._resources_per_worker = {"CPU": 1}
        else:
            self._resources_per_worker = resources_per_worker.copy()

        self._backend_config = backend_config
        self._backend = backend_config.backend_cls()
        self._num_workers = num_workers
        self._max_failures = max_retries
        if self._max_failures < 0:
            self._max_failures = float("inf")
        self._num_failures = 0
        self._last_failure = None
        self._initialization_hook = None
        self._placement_group = None

        self._trial_info = trial_info

        self.worker_group = InactiveWorkerGroup()
        self.dataset_shards = None

        self._resource_configs = [
            ResourceConfig(
                ray_constants.NEURON_CORES,
                ENABLE_SHARE_NEURON_CORES_ACCELERATOR_ENV,
                ray_constants.NEURON_RT_VISIBLE_CORES_ENV_VAR,
            ),
            ResourceConfig(
                ray_constants.NPU,
                ENABLE_SHARE_NPU_RT_VISIBLE_DEVICES_ENV,
                ray_constants.NPU_RT_VISIBLE_DEVICES_ENV_VAR,
            ),
            # For AMD GPUs, they are using ROCR_VISIBLE_DEVICES env var.
            ResourceConfig(
                ray_constants.GPU,
                ENABLE_SHARE_ROCR_VISIBLE_DEVICES_ENV,
                ray_constants.ROCR_VISIBLE_DEVICES_ENV_VAR,
            ),
        ]

        # Record the initialization time of BackendExecutor, which is
        # after trainer.fit() and before worker_group executes the training function.
        self._start_time_ms = int(time.time() * 1000)

        self.state_tracking_enabled = env_integer(RAY_TRAIN_ENABLE_STATE_TRACKING, 0)

    def start(
        self,
        initialization_hook: Optional[Callable[[], None]] = None,
        train_cls: Optional[Type] = None,
        train_cls_args: Optional[Tuple] = None,
        train_cls_kwargs: Optional[Dict] = None,
    ):
        """Starts the worker group."""
        self._create_placement_group()
        placement_group = self._placement_group or "default"
        self.worker_group = WorkerGroup(
            num_workers=self._num_workers,
            resources_per_worker=self._resources_per_worker,
            actor_cls=train_cls,
            actor_cls_args=train_cls_args,
            actor_cls_kwargs=train_cls_kwargs,
            placement_group=placement_group,
        )
        # Hack to avoid OOMs.
        # This is just a temporary solution for Train loading entire checkpoints
        # into memory by ensuring that the rank 0 worker is on the same node as
        # trainable, thus allowing for lazy checkpoint transfer to be used.
        # See https://github.com/ray-project/ray/issues/33073
        # for more context.
        # TODO remove passing in trial_driver_ip.

        trial_driver_node_id = (
            self._trial_info.driver_node_id if self._trial_info else None
        )
        self.worker_group.sort_workers_by_node_id_and_gpu_id(trial_driver_node_id)

        try:
            if initialization_hook:
                self._initialization_hook = initialization_hook
                self.worker_group.execute(initialization_hook)

            # Always propagate the driver's DataContext to each worker in the group.
            from ray.data import DataContext

            def _set_driver_dataset_context(ctx: DataContext):
                DataContext._set_current(ctx)

            self.worker_group.execute(
                _set_driver_dataset_context,
                DataContext.get_current(),
            )

            share_cuda_visible_devices_enabled = bool(
                env_integer(
                    ENABLE_SHARE_CUDA_VISIBLE_DEVICES_ENV,
                    self._backend.share_cuda_visible_devices,
                )
            )

            if (
                self._resources_per_worker.get("GPU", 0) > 0
                and share_cuda_visible_devices_enabled
            ):
                self._share_cuda_visible_devices()
            for resource_config in self._resource_configs:
                if self._is_share_resources_enabled(
                    resource_config.resource_name,
                    resource_config.resource_enable_sharing_env_var,
                ):
                    self._share_resource_ids(
                        resource_config.resource_name,
                        resource_config.share_resource_ids_env_var,
                    )
            self._backend.on_start(self.worker_group, self._backend_config)
        except RayActorError as exc:
            logger.exception(str(exc))
            logger.warning(
                "Failure occurred during startup. Restarting all workers and "
                "attempting to startup again."
            )
            self._increment_failures()
            self._restart()

        if self.state_tracking_enabled:
            from ray.train._internal.state import TrainRunStateManager
            from ray.train._internal.state.state_actor import get_state_actor

            self.state_manager = TrainRunStateManager(state_actor=get_state_actor())

    def _create_placement_group(self):
        """Creates a placement group if it does not exist.

        If a placement group is already detected (Tune) this will be a no-op.

        By default the placement group will be created with PACK strategy.
        This is optimized for colocating GPUs on a minimal number of nodes.
        This behavior can be overridden to use the SPREAD strategy by defining
        ``TRAIN_ENABLE_WORKER_SPREAD_ENV``

        If a placement group is created it will be stored as
        self._placement_group.
        """
        current_placement_group = get_current_placement_group()
        worker = ray._private.worker.global_worker
        should_capture_child_tasks_in_placement_group = (
            worker.should_capture_child_tasks_in_placement_group
        )
        should_create_placement_group = (
            current_placement_group is None
            or not should_capture_child_tasks_in_placement_group
        )

        if should_create_placement_group:
            bundles = [
                self._resources_per_worker.copy() for _ in range(self._num_workers)
            ]

            use_spread = bool(env_integer(TRAIN_ENABLE_WORKER_SPREAD_ENV, 0))
            strategy = "SPREAD" if use_spread else "PACK"

            placement_group = ray.util.placement_group(bundles, strategy=strategy)
            logger.debug("Waiting for placement group to start.")
            timeout = env_integer(TRAIN_PLACEMENT_GROUP_TIMEOUT_S_ENV, 100)
            ready, _ = ray.wait([placement_group.ready()], timeout=timeout)
            if ready:
                logger.debug("Placement group has started.")
            else:
                raise TimeoutError(
                    "Placement group creation timed out. Make sure your "
                    "cluster either has enough resources or use an "
                    "autoscaling cluster. If you are running on a cluster, "
                    "make sure you specify an address in `ray.init()`, for example, "
                    '`ray.init("auto")`. You can also increase the timeout by setting '
                    "the TRAIN_PLACEMENT_GROUP_TIMEOUT_S environment variable. "
                    "Current resources available: {}, resources requested by the "
                    "placement group: {}".format(
                        ray.available_resources(), placement_group.bundle_specs
                    )
                )
            self._placement_group = placement_group

    def _share_cuda_visible_devices(self):
        """Sets CUDA_VISIBLE_DEVICES on all workers.

        For each worker, CUDA_VISIBLE_DEVICES will be set to the GPU IDs
        visible to all workers on that worker's node.

        This allows GPU workers on the same node to communicate with one
        another.

        Example:

            Setup:
            - Node1:
                - Worker1: {0, 1}
                - Worker2: {2, 3}
            - Node2:
                - Worker3: {0, 1}

            CUDA_VISIBLE_DEVICES:
            - Worker1: "0,1,2,3"
            - Worker2: "0,1,2,3"
            - Worker3: "0,1"

        """
        self._share_resource_ids(
            ray_constants.GPU, ray_constants.CUDA_VISIBLE_DEVICES_ENV_VAR
        )

    def _share_resource_ids(self, resource: str, env_var: str):
        """Sets the given env_var on all workers.

        For each worker, the cores/devices are visible to all the
        workers on that worker's node.This allows workers on the
        same node to communicate with one another.

        Example:

            Setup:
            - Node1:
                - Worker1: {0, 1}
                - Worker2: {2, 3}
            - Node2:
                - Worker3: {0, 1}

            NEURON_RT_VISIBLE_CORES/TPU_VISIBLE_CHIPS/...:
            - Worker1: "0,1,2,3"
            - Worker2: "0,1,2,3"
            - Worker2: "0,1"

        Args:
            resource: The name of the resource/accelerator.
            env_var: The name of the environment variable to set.
        """
        node_ids_and_resource_ids = [
            (
                w.metadata.node_id,
                w.metadata.resource_ids[resource],
            )
            for w in self.worker_group.workers
        ]
        node_id_to_worker_id = defaultdict(set)
        node_id_to_resource_ids = defaultdict(set)

        for worker_id, (node_id, resource_ids) in enumerate(node_ids_and_resource_ids):
            node_id_to_worker_id[node_id].add(worker_id)
            node_id_to_resource_ids[node_id].update(resource_ids)

        futures = []
        for node_id, resource_ids in node_id_to_resource_ids.items():
            resource_ids = sorted(resource_ids)
            all_resource_ids = ",".join(resource_ids)

            def set_resource_ids():
                os.environ[env_var] = all_resource_ids

            for worker_id in node_id_to_worker_id[node_id]:
                futures.append(
                    self.worker_group.execute_single_async(worker_id, set_resource_ids)
                )
        ray.get(futures)

    def _is_share_resources_enabled(self, resource_name: str, enable_sharing_env: str):
        """Whether to share resource IDs on all workers
        based on enable_sharing_env.

        This will return true if resources are requested and greater than 0.
        Also, user can disable by configuring the `enable_sharing_env` to "0".

        Args:
            resource_name: The name of the resource/accelerator.
            enable_sharing_env: The name of the environment variable
                to check.
        """
        has_resource_requested = self._resources_per_worker.get(resource_name, 0) > 0
        return has_resource_requested and ray_constants.env_bool(
            enable_sharing_env, True
        )

    def _create_rank_world_size_mappings(self) -> List[Dict]:
        """Create rank and world size mappings for workers.
        There are three maps returned:
            - local_rank_map, which maps from worker world_rank to local_rank.
            - local_world_size_map, which maps from world_rank to local_world_size
            - node_rank_map, which maps from world rank to node rank

        Example:
            Worker 0: node 0
            Worker 1: node 0
            Worker 2: node 1
            Worker 3: node 0
            Worker 4: node 1

            Workers 0, 1, 3 are on node 0.
            Workers 2, 4 are on node 1.

            Expected local_rank_map:
            {
                0 -> 0,
                1 -> 1,
                2 -> 0,
                3 -> 2,
                4 -> 1
            }

            Expected local_world_size_map:
            {
                0 -> 3,
                1 -> 3,
                2 -> 2,
                3 -> 3,
                4 -> 2
            }

            Expected node_rank_map:
            {
                0 -> 0,
                1 -> 0,
                2 -> 1,
                3 -> 0,
                4 -> 1
            }

        """
        local_rank_map = {}  # map from world rank to local rank
        local_world_size_map = {}  # map from world rank to local world size
        node_rank_map = {}  # map from world rank to node rank
        node_ids = {}  # map from node id to node index
        node_cnt = 0  # count the number of nodes

        node_id_dict = defaultdict(
            int
        )  # map from node id to the number of workers on it.
        for world_rank in range(len(self.worker_group)):
            worker = self.worker_group.workers[world_rank]
            node_id = worker.metadata.node_id
            local_rank_map[world_rank] = node_id_dict[node_id]
            node_id_dict[node_id] += 1

            if node_id not in node_ids:
                node_ids[node_id] = node_cnt
                node_cnt += 1
            node_rank_map[world_rank] = node_ids[node_id]

        for world_rank in range(len(self.worker_group)):
            worker = self.worker_group.workers[world_rank]
            node_id = worker.metadata.node_id
            local_world_size_map[world_rank] = node_id_dict[node_id]

        workers_info = "\n".join(
            [
                f"- (node_id={w.metadata.node_id}, ip={w.metadata.node_ip}, "
                f"pid={w.metadata.pid}) world_rank={i}, "
                f"local_rank={local_rank_map[i]}, node_rank={node_rank_map[i]}"
                for i, w in enumerate(self.worker_group.workers)
            ]
        )
        logger.info(f"Started distributed worker processes: \n{workers_info}")

        return local_rank_map, local_world_size_map, node_rank_map

    def start_training(
        self,
        train_func: Callable[[], T],
        datasets: Dict[str, Dataset],
        metadata: Dict[str, Any],
        data_config: DataConfig,
        storage: StorageContext,
        checkpoint: Optional[Checkpoint] = None,
    ) -> None:
        """Executes a training function on all workers in a separate thread.

        ``finish_training`` should be called after this.

        Args:
            train_func: The training function to run on each worker.
            datasets: The base datasets.
            data_config: The config object for creating dataset shards for workers.
            checkpoint: The checkpoint data that
                should be loaded onto each worker and accessed by the
                training function via ``session.get_checkpoint()``. If this
                is ``None`` then no checkpoint will be loaded.
        """
        use_detailed_autofilled_metrics = env_integer(
            ENABLE_DETAILED_AUTOFILLED_METRICS_ENV, 0
        )

        # First initialize the session.
        def initialize_session(
            train_func,
            world_rank,
            local_rank,
            node_rank,
            local_world_size,
            world_size,
            trial_info,
            checkpoint,
            dataset_shard,
            metadata,
            storage,
        ):
            try:
                init_session(
                    training_func=train_func,
                    world_rank=world_rank,
                    local_rank=local_rank,
                    node_rank=node_rank,
                    local_world_size=local_world_size,
                    world_size=world_size,
                    trial_info=trial_info,
                    dataset_shard=dataset_shard,
                    metadata=metadata,
                    checkpoint=checkpoint,
                    detailed_autofilled_metrics=use_detailed_autofilled_metrics,
                    storage=storage,
                )
            except ValueError:
                raise TrainBackendError(
                    "Attempting to start training but a "
                    "previous training run is still ongoing. "
                    "You must call `finish_training` before "
                    "calling `start_training` again."
                )

        if self.dataset_shards is None:
            actors = [worker.actor for worker in self.worker_group.workers]
            node_ids = [worker.metadata.node_id for worker in self.worker_group.workers]
            self.dataset_shards = data_config.configure(
                datasets,
                world_size=len(self.worker_group),
                worker_handles=actors,
                worker_node_ids=node_ids,
            )

        (
            local_rank_map,
            local_world_size_map,
            node_rank_map,
        ) = self._create_rank_world_size_mappings()

        futures = []
        for index in range(len(self.worker_group)):
            futures.append(
                self.worker_group.execute_single_async(
                    index,
                    initialize_session,
                    world_rank=index,
                    local_rank=local_rank_map[index],
                    node_rank=node_rank_map[index],
                    local_world_size=local_world_size_map[index],
                    world_size=len(self.worker_group),
                    trial_info=self._trial_info,
                    train_func=train_func,
                    dataset_shard=self.dataset_shards[index],
                    metadata=metadata,
                    checkpoint=checkpoint,
                    storage=storage,
                )
            )

        self._backend.on_training_start(self.worker_group, self._backend_config)

        self.get_with_failure_handling(futures)

        # Register Train Run before training starts
        if self.state_tracking_enabled:
            from ray.train._internal.state.schema import RunStatusEnum

            core_context = ray.runtime_context.get_runtime_context()

            self.state_manager.register_train_run(
                run_id=self._trial_info.run_id,
                run_name=self._trial_info.experiment_name,
                job_id=core_context.get_job_id(),
                controller_actor_id=core_context.get_actor_id(),
                datasets=datasets,
                worker_group=self.worker_group,
                start_time_ms=self._start_time_ms,
                run_status=RunStatusEnum.RUNNING,
            )

        # Run the training function asynchronously in its own thread.
        def train_async():
            session = get_session()
            session.start()

        self.worker_group.execute_async(train_async)

    def get_next_results(self) -> Optional[List[_TrainingResult]]:
        """Fetches the next ``_TrainingResult`` from each worker.

        Each ``_TrainingResult`` is expected to correspond to the same step from
        each worker (e.g. the same call to ``train.report()``).

        Returns:
            A list of ``_TrainingResult``s or ``None`` if there are no more results
            since the training function has exited on all workers.
        """

        def get_next():
            session = _get_session("get_next_results")
            try:
                result = session.get_next()
            except RuntimeError:
                # Training thread has not been started yet.
                raise TrainBackendError(
                    "`get_next_results` has been called "
                    "before `start_training`. Please call "
                    "`start_training` before "
                    "`get_next_results`."
                )

            return result

        # Get next result from each worker.
        futures = self.worker_group.execute_async(get_next)
        results = self.get_with_failure_handling(futures)

        # Check if any worker returned None.
        if any(r is None for r in results):
            # Either all workers have results or none of them do.
            if not all(r is None for r in results):
                raise RuntimeError(
                    "Some workers returned results while "
                    "others didn't. Make sure that "
                    "`session.report()` are called the "
                    "same number of times on all workers."
                )
            else:
                # Return None if all results are None.
                return None

        return results

    def pause_reporting(self):
        """Disable workers from enqueuing results from ``session.report()``.

        Note: Already reported results may still be enqueued at this point,
              and should be handled appropriately.
        """

        def pause_session_reporting():
            session = _get_session("pause_reporting")
            return session.pause_reporting()

        futures = self.worker_group.execute_async(pause_session_reporting)
        self.get_with_failure_handling(futures)

    def finish_training(self):
        """Finish training and return final results. Propagate any exceptions.

        Blocks until training is finished on all workers.

        Assumes `start_training` has already been called.

        Returns:
            A list of return values from calling ``train_func`` on each worker.
                Each item corresponds to the return value from a single worker.
        """

        def end_training():
            session = _get_session("finish_training")
            try:
                # session.finish raises any Exceptions from training.
                output = session.finish()
            finally:
                # Shutdown session even if session.finish() raises an
                # Exception.
                shutdown_session()

            return output

        futures = self.worker_group.execute_async(end_training)
        results = self.get_with_failure_handling(futures)
        return results

    def report_final_run_status(
        self,
        errored: bool = False,
        failed_rank: Optional[int] = None,
        stack_trace: Optional[str] = None,
    ):
        """Report the final train run status, error, and end time to TrainStateActor."""
        if self.state_tracking_enabled:
            from ray.train._internal.state.schema import (
                MAX_ERROR_STACK_TRACE_LENGTH,
                RunStatusEnum,
            )

            if errored:
                run_status = RunStatusEnum.ERRORED
                status_detail = ""
                if failed_rank is not None:
                    status_detail += f"Rank {failed_rank} worker raised an error. \n"
                if stack_trace is not None:
                    # Keep only the last part of the stack trace if it's too long.
                    status_detail += stack_trace[-MAX_ERROR_STACK_TRACE_LENGTH:]
            else:
                run_status = RunStatusEnum.FINISHED
                status_detail = ""

            self.state_manager.end_train_run(
                run_id=self._trial_info.run_id,
                run_status=run_status,
                status_detail=status_detail,
                end_time_ms=int(time.time() * 1000),
            )

    def get_with_failure_handling(self, remote_values):
        """Gets the remote values while handling for worker failures.

        This method should be called instead of ``ray.get()`` directly in
        order to handle worker failures.

        If a worker failure is identified, backend specific failure handling
        is executed and a ``TrainingWorkerError`` is raised.

        Args:
            remote_values: List of object refs representing functions
                that may fail in the middle of execution. For example, running
                a Train training loop in multiple parallel actor calls.
        Returns:
            The resolved objects represented by the passed in ObjectRefs.
        """
        success, exception = check_for_failure(remote_values)
        if success:
            return ray.get(remote_values)
        else:
            self._last_failure = exception
            self._increment_failures()
            logger.warning(
                "Failure identified during training. Restarting all workers and "
                "continuing training from latest checkpoint."
            )
            self._restart()
            raise TrainingWorkerError

    def shutdown(self, graceful_termination: bool = True):
        """Shuts down the workers in the worker group.

        Args:
            graceful_termination: If set to True, attempt to clean up the backend
                before terminating the Ray actors.

        """
        if graceful_termination:
            try:
                self._backend.on_shutdown(self.worker_group, self._backend_config)
            except RayActorError:
                logger.warning(
                    "Graceful shutdown of backend failed. This is "
                    "expected if one of the workers has crashed."
                )

        if graceful_termination:
            self.worker_group.shutdown()
        else:
            self.worker_group.shutdown(patience_s=0)
        self.worker_group = InactiveWorkerGroup()

        if self._placement_group:
            remove_placement_group(self._placement_group)
            self._placement_group = None

        self.dataset_shards = None

    def is_started(self):
        return not isinstance(self.worker_group, InactiveWorkerGroup)

    def _restart(self):
        self.worker_group.shutdown()
        if self._initialization_hook is not None:
            initialization_hook = self._initialization_hook
        else:
            initialization_hook = None
        if self._placement_group:
            remove_placement_group(self._placement_group)
            self._placement_group = None
        self.start(initialization_hook=initialization_hook)

    def _increment_failures(self):
        self._num_failures += 1
        if self._num_failures >= self._max_failures:
            failure = self._last_failure
            self._last_failure = None
            if self._max_failures > 0:
                exc = RuntimeError(
                    "Training has failed after " f"{self._num_failures} " "attempts."
                )
                raise exc.with_traceback(None) from failure
            else:
                raise failure

    def get_worker_group(self):
        return self.worker_group

    def _get_num_failures(self):
        return self._num_failures


class InactiveWorkerGroupError(Exception):
    """Raised when underlying worker group is inactive."""


class InactiveWorkerGroup:
    # TODO: fix inheritence. perhaps create WorkerGroupInterface.

    # Need to define getstate and setstate so that getattr does not screwup
    # pickling. See https://stackoverflow.com/a/50888571/11249691
    def __getstate__(self):
        return vars(self)

    def __setstate__(self, state):
        vars(self).update(state)

    def __getattr__(self, name):
        raise InactiveWorkerGroupError()

    def __len__(self):
        raise InactiveWorkerGroupError()


def _get_session(method_name: str):
    # Get the session for this worker.
    session = get_session()
    if not session:
        # Session is not initialized yet.
        raise TrainBackendError(
            f"`{method_name}` has been called "
            "before `start_training`. Please call "
            "`start_training` before "
            f"`{method_name}`."
        )
    return session
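The worked example in the `_create_rank_world_size_mappings` docstring above is the clearest way to see the rank bookkeeping. The following self-contained sketch reproduces the same two-pass computation outside of Ray; representing each worker by a plain node-id string is an assumption made only for this example.

from collections import defaultdict


# Standalone sketch of the bookkeeping in
# BackendExecutor._create_rank_world_size_mappings, with a plain list of
# node ids standing in for worker metadata (an assumption for this sketch).
def rank_mappings(worker_node_ids):
    local_rank_map, local_world_size_map, node_rank_map = {}, {}, {}
    node_index = {}  # node id -> node rank (order of first appearance)
    workers_per_node = defaultdict(int)  # node id -> workers counted so far

    # First pass: assign local ranks and node ranks.
    for world_rank, node_id in enumerate(worker_node_ids):
        local_rank_map[world_rank] = workers_per_node[node_id]
        workers_per_node[node_id] += 1
        if node_id not in node_index:
            node_index[node_id] = len(node_index)
        node_rank_map[world_rank] = node_index[node_id]

    # Second pass: the local world size is the final per-node worker count.
    for world_rank, node_id in enumerate(worker_node_ids):
        local_world_size_map[world_rank] = workers_per_node[node_id]

    return local_rank_map, local_world_size_map, node_rank_map


# Layout from the docstring: workers 0, 1, 3 on node 0; workers 2, 4 on node 1.
local_rank, local_ws, node_rank = rank_mappings(["n0", "n0", "n1", "n0", "n1"])
print(local_rank)  # {0: 0, 1: 1, 2: 0, 3: 2, 4: 1}
print(local_ws)    # {0: 3, 1: 3, 2: 2, 3: 3, 4: 2}
print(node_rank)   # {0: 0, 1: 0, 2: 1, 3: 0, 4: 1}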
.venv/lib/python3.11/site-packages/ray/train/_internal/checkpoint_manager.py
ADDED
@@ -0,0 +1,185 @@
import logging
import numbers
from typing import Any, Callable, List, Optional, Tuple

from ray._private.dict import flatten_dict
from ray.air._internal.util import is_nan
from ray.air.config import MAX
from ray.train import CheckpointConfig
from ray.train._internal.session import _TrainingResult
from ray.train._internal.storage import _delete_fs_path

logger = logging.getLogger(__name__)


def _insert_into_sorted_list(list: List[Any], item: Any, key: Callable[[Any], Any]):
    """Insert an item into a sorted list with a custom key function.

    Examples:

        >>> list = []
        >>> _insert_into_sorted_list(list, {"a": 1, "b": 0}, lambda x: x["a"])
        >>> list
        [{'a': 1, 'b': 0}]
        >>> _insert_into_sorted_list(list, {"a": 3, "b": 1}, lambda x: x["a"])
        >>> list
        [{'a': 1, 'b': 0}, {'a': 3, 'b': 1}]
        >>> _insert_into_sorted_list(list, {"a": 4, "b": 2}, lambda x: x["a"])
        >>> list
        [{'a': 1, 'b': 0}, {'a': 3, 'b': 1}, {'a': 4, 'b': 2}]
        >>> _insert_into_sorted_list(list, {"a": 1, "b": 3}, lambda x: x["a"])
        >>> list
        [{'a': 1, 'b': 0}, {'a': 1, 'b': 3}, {'a': 3, 'b': 1}, {'a': 4, 'b': 2}]
    """
    i = 0
    while i < len(list):
        # Insert to the right of all duplicates.
        if key(list[i]) > key(item):
            break
        i += 1
    list.insert(i, item)


class _CheckpointManager:
    """Checkpoint manager that handles checkpoint book-keeping for a trial.

    The main purpose of this abstraction is to keep the top K checkpoints based on
    recency/a user-provided metric.

    NOTE: This class interacts with `_TrainingResult` objects, which are
    (checkpoint, metrics) pairs. This is to order checkpoints by metrics.

    Args:
        checkpoint_config: Defines how many and which checkpoints to keep.
    """

    def __init__(self, checkpoint_config: Optional[CheckpointConfig]):
        self._checkpoint_config = checkpoint_config or CheckpointConfig()

        # List of checkpoints ordered by ascending score.
        self._checkpoint_results: List[_TrainingResult] = []

        # The latest registered checkpoint.
        # This should never be immediately deleted upon registration,
        # even if it's not in the top K checkpoints, based on score.
        self._latest_checkpoint_result: Optional[_TrainingResult] = None

        if (
            self._checkpoint_config.num_to_keep is not None
            and self._checkpoint_config.num_to_keep <= 0
        ):
            raise ValueError(
                f"`num_to_keep` must >= 1, got: "
                f"{self._checkpoint_config.num_to_keep}"
            )

    @property
    def checkpoint_config(self):
        return self._checkpoint_config

    def register_checkpoint(self, checkpoint_result: _TrainingResult):
        """Register new checkpoint and add to bookkeeping.

        This method will register a new checkpoint and add it to the internal
        bookkeeping logic. This means the checkpoint manager will decide if
        this checkpoint should be kept, and if older or worse performing
        checkpoints should be deleted.

        Args:
            checkpoint: Tracked checkpoint object to add to bookkeeping.
        """
        self._latest_checkpoint_result = checkpoint_result

        if self._checkpoint_config.checkpoint_score_attribute is not None:
            # If we're ordering by a score, insert the checkpoint
            # so that the list remains sorted.
            _insert_into_sorted_list(
                self._checkpoint_results,
                checkpoint_result,
                key=self._get_checkpoint_score,
            )
        else:
            # If no metric is provided, just append (ordering by time of registration).
            self._checkpoint_results.append(checkpoint_result)

        if self._checkpoint_config.num_to_keep is not None:
            # Delete the bottom (N - K) checkpoints
            worst_results = set(
                self._checkpoint_results[: -self._checkpoint_config.num_to_keep]
            )
            # Except for the latest checkpoint.
            results_to_delete = worst_results - {self._latest_checkpoint_result}

            # Update internal state before actually deleting them.
            self._checkpoint_results = [
                checkpoint_result
                for checkpoint_result in self._checkpoint_results
                if checkpoint_result not in results_to_delete
            ]

            for checkpoint_result in results_to_delete:
                checkpoint = checkpoint_result.checkpoint
                logger.debug("Deleting checkpoint: ", checkpoint)
                _delete_fs_path(fs=checkpoint.filesystem, fs_path=checkpoint.path)

    def _get_checkpoint_score(
        self, checkpoint: _TrainingResult
    ) -> Tuple[bool, numbers.Number]:
        """Get the score for a checkpoint, according to checkpoint config.

        If `mode="min"`, the metric is negated so that the lowest score is
        treated as the best.

        Returns:
            Tuple: A tuple of (not_is_nan: bool, score: numbers.Number).
            This score orders: nan values < float("-inf") < valid numeric metrics
        """
        checkpoint_score_attribute = self._checkpoint_config.checkpoint_score_attribute
        if checkpoint_score_attribute:
            flat_metrics = flatten_dict(checkpoint.metrics)
            try:
                checkpoint_result = flat_metrics[checkpoint_score_attribute]
            except KeyError:
                valid_keys = list(flat_metrics.keys())
                logger.error(
                    f"Result dict has no key: {checkpoint_score_attribute}. "
                    f"checkpoint_score_attr must be set to a key in the "
                    f"result dict. Valid keys are: {valid_keys}"
                )
                checkpoint_result = float("-inf")
        else:
            checkpoint_result = float("-inf")

        checkpoint_score_order = self._checkpoint_config.checkpoint_score_order
        order_factor = 1.0 if checkpoint_score_order == MAX else -1.0

        checkpoint_score = order_factor * checkpoint_result

        if not isinstance(checkpoint_score, numbers.Number):
            raise ValueError(
                f"Unable to persist checkpoint for "
                f"checkpoint_score_attribute: "
                f"{checkpoint_score_attribute} with value "
                f"{checkpoint_score}. "
                f"This attribute must be numerical."
            )

        return (
            (not is_nan(checkpoint_score), checkpoint_score)
            if not is_nan(checkpoint_score)
            else (False, float("-inf"))
        )

    @property
    def best_checkpoint_result(self) -> Optional[_TrainingResult]:
        return self._checkpoint_results[-1] if self._checkpoint_results else None

    @property
    def latest_checkpoint_result(self) -> Optional[_TrainingResult]:
        return self._latest_checkpoint_result

    @property
    def best_checkpoint_results(self) -> List[_TrainingResult]:
        if self._checkpoint_config.num_to_keep is None:
            return self._checkpoint_results
        return self._checkpoint_results[-self._checkpoint_config.num_to_keep :]
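A note on the ordering used by `_CheckpointManager` above: checkpoints are kept sorted in ascending order of a `(not_is_nan, signed_score)` key, so NaN metrics always sort worst and, with a "min" metric, the sign flip makes the smallest value the best (last) entry. A minimal sketch of that key, using plain metric dicts in place of `_TrainingResult` objects (an assumption made only for this example):

import math


# Sketch of the checkpoint ordering key: (not_is_nan, signed_score).
# Plain dicts stand in for _TrainingResult objects here.
def score_key(metrics, attr, mode="max"):
    value = metrics.get(attr, float("-inf"))
    signed = value if mode == "max" else -value
    if math.isnan(signed):
        return (False, float("-inf"))
    return (True, signed)


results = [{"loss": 0.9}, {"loss": 0.4}, {"loss": float("nan")}, {"loss": 0.7}]
ordered = sorted(results, key=lambda m: score_key(m, "loss", mode="min"))
print(ordered)       # [{'loss': nan}, {'loss': 0.9}, {'loss': 0.7}, {'loss': 0.4}]
print(ordered[-1])   # {'loss': 0.4}, analogous to best_checkpoint_result
print(ordered[-2:])  # top 2, analogous to best_checkpoint_results with num_to_keep=2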
.venv/lib/python3.11/site-packages/ray/train/_internal/data_config.py
ADDED
|
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
+import copy
+from typing import Dict, List, Literal, Optional, Union
+
+import ray
+from ray.actor import ActorHandle
+from ray.data import DataIterator, Dataset, ExecutionOptions, NodeIdStr
+from ray.data._internal.execution.interfaces.execution_options import ExecutionResources
+from ray.util.annotations import DeveloperAPI, PublicAPI
+
+
+@PublicAPI(stability="stable")
+class DataConfig:
+    """Class responsible for configuring Train dataset preprocessing.
+
+    For advanced use cases, this class can be subclassed and the `configure()` method
+    overriden for custom data preprocessing.
+    """
+
+    def __init__(
+        self,
+        datasets_to_split: Union[Literal["all"], List[str]] = "all",
+        execution_options: Optional[ExecutionOptions] = None,
+    ):
+        """Construct a DataConfig.
+
+        Args:
+            datasets_to_split: Specifies which datasets should be split among workers.
+                Can be set to "all" or a list of dataset names. Defaults to "all",
+                i.e. split all datasets.
+            execution_options: The execution options to pass to Ray Data. By default,
+                the options will be optimized for data ingest. When overriding this,
+                base your options off of `DataConfig.default_ingest_options()`.
+        """
+        if isinstance(datasets_to_split, list) or datasets_to_split == "all":
+            self._datasets_to_split = datasets_to_split
+        else:
+            raise TypeError(
+                "`datasets_to_split` should be a 'all' or a list of strings of "
+                "dataset names. Received "
+                f"{type(datasets_to_split).__name__} with value {datasets_to_split}."
+            )
+
+        self._execution_options: ExecutionOptions = (
+            execution_options or DataConfig.default_ingest_options()
+        )
+
+        self._num_train_cpus = 0.0
+        self._num_train_gpus = 0.0
+
+    def set_train_total_resources(self, num_train_cpus: float, num_train_gpus: float):
+        """Set the total number of CPUs and GPUs used by training.
+
+        If CPU or GPU resource limits are not set, they will be set to the
+        total cluster resources minus the resources used by training.
+        """
+        # TODO: We may also include other resources besides CPU and GPU.
+        self._num_train_cpus = num_train_cpus
+        self._num_train_gpus = num_train_gpus
+
+    @DeveloperAPI
+    def configure(
+        self,
+        datasets: Dict[str, Dataset],
+        world_size: int,
+        worker_handles: Optional[List[ActorHandle]],
+        worker_node_ids: Optional[List[NodeIdStr]],
+        **kwargs,
+    ) -> List[Dict[str, DataIterator]]:
+        """Configure how Train datasets should be assigned to workers.
+
+        Args:
+            datasets: The datasets dict passed to Train by the user.
+            world_size: The number of Train workers in total.
+            worker_handles: The actor handles of the Train workers.
+            worker_node_ids: The node ids of the Train workers.
+            kwargs: Forwards compatibility placeholder.
+
+        Returns:
+            A list of dataset splits for each worker. The size of the list must be
+            equal to `world_size`. Each element of the list contains the assigned
+            `DataIterator` instances by name for the worker.
+        """
+        output = [{} for _ in range(world_size)]
+
+        if self._datasets_to_split == "all":
+            datasets_to_split = set(datasets.keys())
+        else:
+            datasets_to_split = set(self._datasets_to_split)
+
+        locality_hints = (
+            worker_node_ids if self._execution_options.locality_with_output else None
+        )
+        for name, ds in datasets.items():
+            execution_options = copy.deepcopy(self._execution_options)
+
+            if execution_options.is_resource_limits_default():
+                # If "resource_limits" is not overriden by the user,
+                # add training-reserved resources to Data's exclude_resources.
+                execution_options.exclude_resources = (
+                    execution_options.exclude_resources.add(
+                        ExecutionResources(
+                            cpu=self._num_train_cpus, gpu=self._num_train_gpus
+                        )
+                    )
+                )
+
+            ds = ds.copy(ds)
+            ds.context.execution_options = execution_options
+
+            if name in datasets_to_split:
+                for i, split in enumerate(
+                    ds.streaming_split(
+                        world_size, equal=True, locality_hints=locality_hints
+                    )
+                ):
+                    output[i][name] = split
+            else:
+                for i in range(world_size):
+                    output[i][name] = ds.iterator()
+
+        return output
+
+    @staticmethod
+    def default_ingest_options() -> ExecutionOptions:
+        """The default Ray Data options used for data ingest.
+
+        By default, configurations are carried over from what is already set
+        in DataContext.
+        """
+        ctx = ray.data.DataContext.get_current()
+        return ExecutionOptions(
+            # TODO(hchen): Re-enable `locality_with_output` by default after fixing
+            # https://github.com/ray-project/ray/issues/40607
+            locality_with_output=ctx.execution_options.locality_with_output,
+            resource_limits=ctx.execution_options.resource_limits,
+            exclude_resources=ctx.execution_options.exclude_resources,
+            preserve_order=ctx.execution_options.preserve_order,
+            verbose_progress=ctx.execution_options.verbose_progress,
+        )
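As the class docstring above notes, `configure()` is the extension point for custom shard assignment. A hedged sketch of what such a subclass might look like (the class name and behavior are illustrative, not part of the added file; splitting only "train" is already available via `datasets_to_split=["train"]`, so the subclass exists purely to show where custom logic would go):

from typing import Dict, List, Optional

from ray.actor import ActorHandle
from ray.data import DataIterator, Dataset, NodeIdStr
from ray.train import DataConfig


class UnsharedValidationDataConfig(DataConfig):
    """Illustrative subclass: shard only the "train" dataset and give every
    worker a full copy of everything else (e.g. a small validation set)."""

    def __init__(self):
        super().__init__(datasets_to_split=["train"])

    def configure(
        self,
        datasets: Dict[str, Dataset],
        world_size: int,
        worker_handles: Optional[List[ActorHandle]],
        worker_node_ids: Optional[List[NodeIdStr]],
        **kwargs,
    ) -> List[Dict[str, DataIterator]]:
        # Reuse the default splitting logic, then adjust the per-worker dicts
        # here if extra iterators or custom locality handling are needed.
        return super().configure(
            datasets, world_size, worker_handles, worker_node_ids, **kwargs
        )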
.venv/lib/python3.11/site-packages/ray/train/_internal/dl_predictor.py
ADDED
@@ -0,0 +1,103 @@
+import abc
+from typing import Dict, Optional, TypeVar, Union
+
+import numpy as np
+import pandas as pd
+
+from ray.air.util.data_batch_conversion import (
+    BatchFormat,
+    _convert_batch_type_to_pandas,
+    _convert_pandas_to_batch_type,
+)
+from ray.train.predictor import Predictor
+from ray.util.annotations import DeveloperAPI
+
+TensorType = TypeVar("TensorType")
+TensorDtype = TypeVar("TensorDtype")
+
+
+class DLPredictor(Predictor):
+    @abc.abstractmethod
+    def _arrays_to_tensors(
+        self,
+        numpy_arrays: Union[np.ndarray, Dict[str, np.ndarray]],
+        dtype: Optional[Union[TensorDtype, Dict[str, TensorDtype]]],
+    ) -> Union[TensorType, Dict[str, TensorType]]:
+        """Converts a NumPy ndarray batch to the tensor type for the DL framework.
+
+        Args:
+            numpy_array: The numpy array to convert to a tensor.
+            dtype: The tensor dtype to use when creating the DL tensor.
+            ndarray: A (dict of) NumPy ndarray(s) that we wish to convert to a (dict of)
+                tensor(s).
+            dtype: A (dict of) tensor dtype(s) to use when creating the DL tensor; if
+                None, the dtype will be inferred from the NumPy ndarray data.
+
+        Returns:
+            A deep learning framework specific tensor.
+        """
+        raise NotImplementedError
+
+    @abc.abstractmethod
+    def _tensor_to_array(self, tensor: TensorType) -> np.ndarray:
+        """Converts tensor framework specific tensor to a numpy array.
+
+        Args:
+            tensor: A framework specific tensor.
+
+        Returns:
+            A numpy array representing the input tensor.
+        """
+
+        raise NotImplementedError
+
+    @abc.abstractmethod
+    @DeveloperAPI
+    def call_model(
+        self, inputs: Union[TensorType, Dict[str, TensorType]]
+    ) -> Union[TensorType, Dict[str, TensorType]]:
+        """Inputs the tensor to the model for this Predictor and returns the result.
+
+        Args:
+            inputs: The tensor to input to the model.
+
+        Returns:
+            A tensor or dictionary of tensors containing the model output.
+        """
+        raise NotImplementedError
+
+    @classmethod
+    @DeveloperAPI
+    def preferred_batch_format(cls) -> BatchFormat:
+        return BatchFormat.NUMPY
+
+    def _predict_pandas(
+        self,
+        data: pd.DataFrame,
+        dtype: Optional[Union[TensorDtype, Dict[str, TensorDtype]]],
+    ) -> pd.DataFrame:
+        numpy_input = _convert_pandas_to_batch_type(
+            data,
+            BatchFormat.NUMPY,
+            self._cast_tensor_columns,
+        )
+        numpy_output = self._predict_numpy(numpy_input, dtype)
+        return _convert_batch_type_to_pandas(numpy_output)
+
+    def _predict_numpy(
+        self,
+        data: Union[np.ndarray, Dict[str, np.ndarray]],
+        dtype: Optional[Union[TensorDtype, Dict[str, TensorDtype]]],
+    ) -> Union[np.ndarray, Dict[str, np.ndarray]]:
+        # Single column selection return numpy array so preprocessors can be
+        # reused in both training and prediction
+        if isinstance(data, dict) and len(data) == 1:
+            data = next(iter(data.values()))
+        model_input = self._arrays_to_tensors(data, dtype)
+        model_output = self.call_model(model_input)
+        # TODO (jiaodong): Investigate perf implication of this.
+        # Move DL Tensor to CPU and convert to numpy.
+        if isinstance(model_output, dict):
+            return {k: self._tensor_to_array(v) for k, v in model_output.items()}
+        else:
+            return {"predictions": self._tensor_to_array(model_output)}
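The dict unwrapping and re-wrapping in `_predict_numpy` above is easy to miss. A framework-free sketch of just that batch handling (pure NumPy; `identity_model` and `predict_numpy_like` are illustrative stand-ins for the abstract `call_model()` and the method above):

import numpy as np

def identity_model(batch):
    # Stand-in for a framework model call; returns the input unchanged.
    return batch

def predict_numpy_like(data):
    # A single-column dict batch is unwrapped to a bare ndarray before the
    # model call, and a non-dict model output is re-wrapped under
    # a "predictions" key, mirroring DLPredictor._predict_numpy.
    if isinstance(data, dict) and len(data) == 1:
        data = next(iter(data.values()))
    output = identity_model(data)
    if isinstance(output, dict):
        return output
    return {"predictions": output}

print(predict_numpy_like({"x": np.arange(4)}))  # {'predictions': array([0, 1, 2, 3])}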
.venv/lib/python3.11/site-packages/ray/train/_internal/framework_checkpoint.py
ADDED
@@ -0,0 +1,45 @@
+from typing import Optional
+
+import ray.cloudpickle as ray_pickle
+from ray._private.utils import binary_to_hex, hex_to_binary
+from ray.data.preprocessor import Preprocessor
+from ray.train._checkpoint import Checkpoint
+
+PREPROCESSOR_KEY = "preprocessor_pkl"
+
+
+class FrameworkCheckpoint(Checkpoint):
+    """A checkpoint to preserve the functionality of legacy
+    framework-specific checkpoints.
+
+    Example:
+
+        >>> import tempfile
+        >>> checkpoint = FrameworkCheckpoint(tempfile.mkdtemp())
+        >>> checkpoint.get_preprocessor() is None
+        True
+        >>> preprocessor = Preprocessor()
+        >>> preprocessor._attr = 1234
+        >>> checkpoint.set_preprocessor(preprocessor)
+        >>> checkpoint.get_preprocessor()._attr
+        1234
+    """
+
+    def get_preprocessor(self) -> Optional[Preprocessor]:
+        """Return the preprocessor stored in the checkpoint.
+
+        Returns:
+            The preprocessor stored in the checkpoint, or ``None`` if no
+            preprocessor was stored.
+        """
+        metadata = self.get_metadata()
+        preprocessor_bytes = metadata.get(PREPROCESSOR_KEY)
+        if preprocessor_bytes is None:
+            return None
+        return ray_pickle.loads(hex_to_binary(preprocessor_bytes))
+
+    def set_preprocessor(self, preprocessor: Preprocessor):
+        """Store a preprocessor with the checkpoint."""
+        self.update_metadata(
+            {PREPROCESSOR_KEY: binary_to_hex(ray_pickle.dumps(preprocessor))}
+        )
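A usage sketch that round-trips a preprocessor through checkpoint metadata, adapted from the class docstring above (assumes a working Ray installation; the `_internal` import path mirrors this file's location and may shift between Ray versions):

import tempfile

from ray.data.preprocessor import Preprocessor
from ray.train._internal.framework_checkpoint import FrameworkCheckpoint

# The preprocessor is pickled into the checkpoint's metadata under
# PREPROCESSOR_KEY, so it survives without touching the checkpoint files.
checkpoint = FrameworkCheckpoint(tempfile.mkdtemp())
preprocessor = Preprocessor()
preprocessor._attr = 1234
checkpoint.set_preprocessor(preprocessor)
assert checkpoint.get_preprocessor()._attr == 1234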
.venv/lib/python3.11/site-packages/ray/train/_internal/session.py
ADDED
@@ -0,0 +1,1163 @@
+import functools
+import logging
+import os
+import platform
+import queue
+import sys
+import threading
+import time
+import warnings
+from dataclasses import dataclass
+from datetime import datetime
+from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Set, Type
+
+import ray
+from ray.air._internal.util import RunnerThread, StartTraceback
+from ray.air.constants import (
+    _ERROR_FETCH_TIMEOUT,
+    _RESULT_FETCH_TIMEOUT,
+    SESSION_MISUSE_LOG_ONCE_KEY,
+    TIME_THIS_ITER_S,
+    TIMESTAMP,
+)
+from ray.data import Dataset
+from ray.train import Checkpoint
+from ray.train._internal.accelerator import Accelerator
+from ray.train._internal.storage import StorageContext
+from ray.train.constants import (
+    CHECKPOINT_DIR_NAME,
+    DETAILED_AUTOFILLED_KEYS,
+    RAY_CHDIR_TO_TRIAL_DIR,
+    TIME_TOTAL_S,
+    WORKER_HOSTNAME,
+    WORKER_NODE_IP,
+    WORKER_PID,
+    _v2_migration_warnings_enabled,
+)
+from ray.train.error import SessionMisuseError
+from ray.train.utils import _log_deprecation_warning
+from ray.util.annotations import DeveloperAPI, PublicAPI
+from ray.util.debug import log_once
+from ray.util.placement_group import _valid_resource_shape
+from ray.util.scheduling_strategies import (
+    PlacementGroupSchedulingStrategy,
+    SchedulingStrategyT,
+)
+
+if TYPE_CHECKING:
+    from ray.data import DataIterator
+    from ray.tune.execution.placement_groups import PlacementGroupFactory
+
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class TrialInfo:
+    """The trial information to propagate to TrainSession."""
+
+    name: str
+    id: str
+    resources: Dict[str, float]
+    logdir: str
+    driver_ip: str
+    driver_node_id: str
+    experiment_name: Optional[str] = None
+    run_id: Optional[str] = None
+
+
+class _FutureTrainingResult:
+    """A future that will be resolved to a `_TrainingResult`.
+
+    This is needed for specific schedulers such as PBT that schedule saves.
+
+    This wrapper should be removed after refactoring PBT to not schedule saves anymore.
+    """
+
+    def __init__(self, future: ray.ObjectRef):
+        self.future = future
+
+    def resolve(self, block: bool = True) -> Optional["_TrainingResult"]:
+        """Resolve into ``_TrainingResult``.
+
+        This will return None for function trainables if no checkpoint has been
+        saved before.
+        """
+        if block:
+            timeout = None
+        else:
+            timeout = 1e-9
+        try:
+            return ray.get(self.future, timeout=timeout)
+        except TimeoutError:
+            # Not ready, yet
+            pass
+        except Exception as exc:
+            logger.error(f"Error resolving result: {exc}")
+
+
+class _TrainingResult:
+    """A (checkpoint, metrics) result reported by the user."""
+
+    def __init__(self, checkpoint: Optional[Checkpoint], metrics: Dict[str, Any]):
+        self.checkpoint = checkpoint
+        self.metrics = metrics
+
+    def __repr__(self) -> str:
+        return f"TrainingResult(checkpoint={self.checkpoint}, metrics={self.metrics})"
+
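`_TrainingResult` is the internal pairing of whatever a worker passes to `ray.train.report`: a metrics dict plus an optional `Checkpoint`, which the session places on its result queue. A hedged sketch of the user-side call that produces it, adapted from the `report()` docstring later in this file (meant to run inside a `train_loop_per_worker`, not as a standalone script; the metric values are illustrative):

import tempfile

from ray import train
from ray.train import Checkpoint

def train_func(config):
    for epoch in range(config.get("num_epochs", 2)):
        metrics = {"loss": 1.0 / (epoch + 1)}
        with tempfile.TemporaryDirectory() as tmpdir:
            # Save framework state into tmpdir here, then report it;
            # the session wraps (checkpoint, metrics) as a _TrainingResult.
            checkpoint = Checkpoint.from_directory(tmpdir)
            train.report(metrics, checkpoint=checkpoint)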
| 109 |
+
|
| 110 |
+
# TODO(xwjiang): This needs a better name.
|
| 111 |
+
@DeveloperAPI
|
| 112 |
+
class _TrainSession:
|
| 113 |
+
"""Holds information for training on each worker."""
|
| 114 |
+
|
| 115 |
+
def __init__(
|
| 116 |
+
self,
|
| 117 |
+
training_func: Callable,
|
| 118 |
+
world_rank: Optional[int],
|
| 119 |
+
local_rank: Optional[int],
|
| 120 |
+
node_rank: Optional[int],
|
| 121 |
+
local_world_size: Optional[int],
|
| 122 |
+
world_size: Optional[int],
|
| 123 |
+
trial_info: Optional[TrialInfo] = None,
|
| 124 |
+
dataset_shard: Optional[Dict[str, Dataset]] = None,
|
| 125 |
+
metadata: Dict[str, Any] = None,
|
| 126 |
+
checkpoint: Optional[Checkpoint] = None,
|
| 127 |
+
detailed_autofilled_metrics: bool = False,
|
| 128 |
+
storage: Optional[StorageContext] = None,
|
| 129 |
+
synchronous_result_reporting: bool = False,
|
| 130 |
+
):
|
| 131 |
+
# `synchronous_result_reporting` refers to whether or not the
|
| 132 |
+
# training function is immediately unblocked to continue running
|
| 133 |
+
# after the main thread receives its result.
|
| 134 |
+
# Ex 1: For 2 Ray Train workers with synchronous_result_reporting=True,
|
| 135 |
+
# the worker that produces a result first will immediately will continue
|
| 136 |
+
# onto the next iteration.
|
| 137 |
+
# Ex 2: For a Tune function Trainable with `synchronous_result_reporting=False`,
|
| 138 |
+
# training will only continue with an explicit call to `session.get_next`.
|
| 139 |
+
# Synchronous reporting in example 2 is needed for Tune schedulers to
|
| 140 |
+
# be able to stop the execution of the training function at will,
|
| 141 |
+
# for advanced pausing schedulers (PBT, BOHB) and actor reuse.
|
| 142 |
+
self.synchronous_result_reporting = synchronous_result_reporting
|
| 143 |
+
|
| 144 |
+
# Ray Train worker properties
|
| 145 |
+
# Note: These are set to None for Tune function Trainables.
|
| 146 |
+
self.dataset_shard = dataset_shard
|
| 147 |
+
self.metadata = metadata
|
| 148 |
+
|
| 149 |
+
self.world_rank = world_rank
|
| 150 |
+
self.local_rank = local_rank
|
| 151 |
+
self.node_rank = node_rank
|
| 152 |
+
self.local_world_size = local_world_size
|
| 153 |
+
self.world_size = world_size
|
| 154 |
+
|
| 155 |
+
assert storage
|
| 156 |
+
logger.debug(f"StorageContext on SESSION (rank={world_rank}):\n{storage}")
|
| 157 |
+
|
| 158 |
+
# NOTE: `reset` will initialize many properties needed to start running the
|
| 159 |
+
# training_func as a thread.
|
| 160 |
+
self.reset(
|
| 161 |
+
training_func=training_func,
|
| 162 |
+
trial_info=trial_info,
|
| 163 |
+
storage=storage,
|
| 164 |
+
loaded_checkpoint=checkpoint,
|
| 165 |
+
)
|
| 166 |
+
|
| 167 |
+
# Autofilled metrics attributes.
|
| 168 |
+
self.detailed_autofilled_metrics = detailed_autofilled_metrics
|
| 169 |
+
self.last_report_time = time.time()
|
| 170 |
+
self.iteration = 0
|
| 171 |
+
self.time_total = 0.0
|
| 172 |
+
self.local_ip = self.get_current_ip()
|
| 173 |
+
|
| 174 |
+
self.accelerator = None
|
| 175 |
+
self._state = {}
|
| 176 |
+
|
| 177 |
+
def get_state(self, key: str) -> Any:
|
| 178 |
+
return self._state.get(key)
|
| 179 |
+
|
| 180 |
+
def set_state(self, key: str, value: Any):
|
| 181 |
+
self._state[key] = value
|
| 182 |
+
|
| 183 |
+
def get_current_ip(self):
|
| 184 |
+
self.local_ip = ray.util.get_node_ip_address()
|
| 185 |
+
return self.local_ip
|
| 186 |
+
|
| 187 |
+
def start(self):
|
| 188 |
+
"""Starts the training thread."""
|
| 189 |
+
self.training_started = True
|
| 190 |
+
self.training_thread.start()
|
| 191 |
+
|
| 192 |
+
def reset(
|
| 193 |
+
self,
|
| 194 |
+
training_func: Callable,
|
| 195 |
+
trial_info: TrialInfo,
|
| 196 |
+
storage: StorageContext,
|
| 197 |
+
loaded_checkpoint=None,
|
| 198 |
+
):
|
| 199 |
+
# This lock is used to control the execution of the training thread.
|
| 200 |
+
self.continue_lock = threading.Semaphore(0)
|
| 201 |
+
|
| 202 |
+
# This event is used to signal the training thread to stop.
|
| 203 |
+
self.stop_event = threading.Event()
|
| 204 |
+
|
| 205 |
+
# Queue for sending results across threads.
|
| 206 |
+
self.result_queue = queue.Queue(1)
|
| 207 |
+
|
| 208 |
+
# Queue for raising exceptions from runner thread to main thread.
|
| 209 |
+
# The error queue has a max size of one to prevent stacking error and force
|
| 210 |
+
# error reporting to block until finished.
|
| 211 |
+
self.error_queue = queue.Queue(1)
|
| 212 |
+
|
| 213 |
+
# The Thread object that is running the training function.
|
| 214 |
+
self.training_thread = RunnerThread(
|
| 215 |
+
target=training_func, daemon=True, error_queue=self.error_queue
|
| 216 |
+
)
|
| 217 |
+
|
| 218 |
+
# Possibly override with new state
|
| 219 |
+
self.trial_info = trial_info
|
| 220 |
+
self.storage = storage
|
| 221 |
+
self.loaded_checkpoint = loaded_checkpoint
|
| 222 |
+
|
| 223 |
+
# Reset state
|
| 224 |
+
self._state = {}
|
| 225 |
+
self.ignore_report = False
|
| 226 |
+
self.training_started = False
|
| 227 |
+
self._first_report = True
|
| 228 |
+
|
| 229 |
+
# Change the working directory to a special trial folder.
|
| 230 |
+
# This is to ensure that all Ray Train workers have a common working directory.
|
| 231 |
+
os.makedirs(storage.trial_working_directory, exist_ok=True)
|
| 232 |
+
if bool(int(os.environ.get(RAY_CHDIR_TO_TRIAL_DIR, "1"))):
|
| 233 |
+
logger.debug(
|
| 234 |
+
f"Changing the working directory to: {storage.trial_working_directory}"
|
| 235 |
+
)
|
| 236 |
+
os.chdir(storage.trial_working_directory)
|
| 237 |
+
|
| 238 |
+
def pause_reporting(self):
|
| 239 |
+
"""Ignore all future ``session.report()`` calls."""
|
| 240 |
+
self.ignore_report = True
|
| 241 |
+
|
| 242 |
+
def finish(self, timeout: Optional[float] = None) -> Optional[Any]:
|
| 243 |
+
"""Finishes the training thread.
|
| 244 |
+
|
| 245 |
+
Raises any Exception from training.
|
| 246 |
+
"""
|
| 247 |
+
# Set the stop event for the training thread to gracefully exit.
|
| 248 |
+
self.stop_event.set()
|
| 249 |
+
|
| 250 |
+
# Release the lock so that training thread can process this event.
|
| 251 |
+
self.continue_lock.release()
|
| 252 |
+
|
| 253 |
+
# Force a final (blocking) sync of artifacts in the trial path to storage.
|
| 254 |
+
self.storage.persist_artifacts(force=True)
|
| 255 |
+
|
| 256 |
+
# Wait for training to finish.
|
| 257 |
+
# This will raise any errors that occur during training, including SystemError
|
| 258 |
+
# This returns the result of the training function.
|
| 259 |
+
output = None
|
| 260 |
+
if self.training_started:
|
| 261 |
+
output = self.training_thread.join(timeout=timeout)
|
| 262 |
+
|
| 263 |
+
return output
|
| 264 |
+
|
| 265 |
+
def get_next(self) -> Optional[_TrainingResult]:
|
| 266 |
+
"""Gets the next ``_TrainingResult`` from the result queue.
|
| 267 |
+
|
| 268 |
+
If the result queue is empty, then this function returns ``None``.
|
| 269 |
+
"""
|
| 270 |
+
if not self.training_started:
|
| 271 |
+
raise RuntimeError("Please call start before calling get_next.")
|
| 272 |
+
|
| 273 |
+
if self.synchronous_result_reporting:
|
| 274 |
+
# There's no need to release the lock on the first report
|
| 275 |
+
# since `start` already started the training thread.
|
| 276 |
+
if not self._first_report:
|
| 277 |
+
# Release the lock to trigger training to continue,
|
| 278 |
+
# until the next call to report.
|
| 279 |
+
self.continue_lock.release()
|
| 280 |
+
self._first_report = False
|
| 281 |
+
|
| 282 |
+
result = None
|
| 283 |
+
# While training is still ongoing, attempt to get the result.
|
| 284 |
+
while result is None and self.training_thread.is_alive():
|
| 285 |
+
try:
|
| 286 |
+
result = self.result_queue.get(
|
| 287 |
+
block=True, timeout=_RESULT_FETCH_TIMEOUT
|
| 288 |
+
)
|
| 289 |
+
except queue.Empty:
|
| 290 |
+
pass
|
| 291 |
+
|
| 292 |
+
# If no result was found, then the runner must no longer be alive.
|
| 293 |
+
if result is None:
|
| 294 |
+
# Try one last time to fetch results in case results were
|
| 295 |
+
# reported in between the time of the last check and the
|
| 296 |
+
# termination of the thread runner.
|
| 297 |
+
try:
|
| 298 |
+
result = self.result_queue.get(
|
| 299 |
+
block=False, timeout=_RESULT_FETCH_TIMEOUT
|
| 300 |
+
)
|
| 301 |
+
except queue.Empty:
|
| 302 |
+
pass
|
| 303 |
+
|
| 304 |
+
# check if error occurred inside the thread runner.
|
| 305 |
+
if result is None:
|
| 306 |
+
# only raise an error from the runner if all results are consumed
|
| 307 |
+
self._report_thread_runner_error(block=True)
|
| 308 |
+
else:
|
| 309 |
+
if not self.error_queue.empty():
|
| 310 |
+
logger.debug(
|
| 311 |
+
(
|
| 312 |
+
"Runner error waiting to be raised in main thread. "
|
| 313 |
+
"Logging all available results first."
|
| 314 |
+
)
|
| 315 |
+
)
|
| 316 |
+
|
| 317 |
+
if not self.synchronous_result_reporting:
|
| 318 |
+
# At this point, the training thread has reached
|
| 319 |
+
# the `train.report` and is blocked there.
|
| 320 |
+
# If performing asynchronous result reporting,
|
| 321 |
+
# release the lock to allow each worker to keep training
|
| 322 |
+
# immediately after the coordinator fetches their result.
|
| 323 |
+
self.continue_lock.release()
|
| 324 |
+
|
| 325 |
+
# Return None if there are no more results to fetch.
|
| 326 |
+
return result
|
| 327 |
+
|
| 328 |
+
def _auto_fill_metrics(self, result: dict) -> dict:
|
| 329 |
+
"""Add autofilled metrics and update attributes."""
|
| 330 |
+
current_time = time.time()
|
| 331 |
+
current_datetime = datetime.now()
|
| 332 |
+
if TIME_THIS_ITER_S in result:
|
| 333 |
+
time_this_iter = result[TIME_THIS_ITER_S]
|
| 334 |
+
else:
|
| 335 |
+
time_this_iter = current_time - self.last_report_time
|
| 336 |
+
self.iteration += 1
|
| 337 |
+
self.time_total += time_this_iter
|
| 338 |
+
self.last_report_time = current_time
|
| 339 |
+
|
| 340 |
+
auto_filled_metrics = {
|
| 341 |
+
TIMESTAMP: int(time.mktime(current_datetime.timetuple())),
|
| 342 |
+
TIME_TOTAL_S: self.time_total,
|
| 343 |
+
WORKER_PID: os.getpid(),
|
| 344 |
+
WORKER_HOSTNAME: platform.node(),
|
| 345 |
+
WORKER_NODE_IP: self.local_ip,
|
| 346 |
+
}
|
| 347 |
+
|
| 348 |
+
if not self.detailed_autofilled_metrics:
|
| 349 |
+
auto_filled_metrics = {
|
| 350 |
+
k: v
|
| 351 |
+
for k, v in auto_filled_metrics.items()
|
| 352 |
+
if k not in DETAILED_AUTOFILLED_KEYS
|
| 353 |
+
}
|
| 354 |
+
|
| 355 |
+
result = result.copy()
|
| 356 |
+
result.update(auto_filled_metrics)
|
| 357 |
+
return result
|
| 358 |
+
|
| 359 |
+
def _auto_fill_checkpoint_metrics(self, result: dict) -> dict:
|
| 360 |
+
"""Add autofilled metrics and update attributes."""
|
| 361 |
+
current_datetime = datetime.now()
|
| 362 |
+
|
| 363 |
+
auto_filled_metrics = {
|
| 364 |
+
TIMESTAMP: int(time.mktime(current_datetime.timetuple()))
|
| 365 |
+
}
|
| 366 |
+
result = result.copy()
|
| 367 |
+
result.update(auto_filled_metrics)
|
| 368 |
+
return result
|
| 369 |
+
|
| 370 |
+
def _report_thread_runner_error(self, block=False):
|
| 371 |
+
try:
|
| 372 |
+
e = self.error_queue.get(block=block, timeout=_ERROR_FETCH_TIMEOUT)
|
| 373 |
+
raise StartTraceback from e
|
| 374 |
+
except queue.Empty:
|
| 375 |
+
pass
|
| 376 |
+
|
| 377 |
+
def _report_training_result(self, training_result: _TrainingResult) -> None:
|
| 378 |
+
"""Place a training result on the result queue for the main thread to process,
|
| 379 |
+
then block until the main thread signals that training should continue.
|
| 380 |
+
|
| 381 |
+
NOTE: This is used internally to report results from Train to Tune
|
| 382 |
+
without persisting checkpoints to storage 2 times.
|
| 383 |
+
`report` is the public API that directly persists to storage, which
|
| 384 |
+
should only be called by user code.
|
| 385 |
+
"""
|
| 386 |
+
if training_result.checkpoint:
|
| 387 |
+
# NOTE: This populates `train.get_checkpoint`
|
| 388 |
+
self.loaded_checkpoint = training_result.checkpoint
|
| 389 |
+
|
| 390 |
+
# Add result to a thread-safe queue.
|
| 391 |
+
self.result_queue.put(training_result, block=True)
|
| 392 |
+
|
| 393 |
+
# Acquire lock to stop the training thread until main thread
|
| 394 |
+
# triggers resume.
|
| 395 |
+
self.continue_lock.acquire()
|
| 396 |
+
|
| 397 |
+
# If the trial should be terminated, exit gracefully.
|
| 398 |
+
# NOTE: This is only really useful if `synchronous_result_reporting=True`.
|
| 399 |
+
# Otherwise, the lock is immediately released on reporting, and this
|
| 400 |
+
# check is skipped before the main thread decides to set the stop event.
|
| 401 |
+
if self.stop_event.is_set():
|
| 402 |
+
self.stop_event.clear()
|
| 403 |
+
sys.exit(0)
|
| 404 |
+
|
| 405 |
+
def report(self, metrics: Dict, checkpoint: Optional[Checkpoint] = None) -> None:
|
| 406 |
+
# Special case: early fail for Torch tensors
|
| 407 |
+
if "torch" in sys.modules:
|
| 408 |
+
from ray.air._internal.torch_utils import contains_tensor
|
| 409 |
+
|
| 410 |
+
if contains_tensor(metrics):
|
| 411 |
+
raise ValueError(
|
| 412 |
+
"Passing objects containg Torch tensors as metrics "
|
| 413 |
+
"is not supported as it will throw an exception on "
|
| 414 |
+
"deserialization. You can either convert the tensors "
|
| 415 |
+
"to Python objects or report a `train.Checkpoint` "
|
| 416 |
+
"with `ray.train.report` to store your Torch objects."
|
| 417 |
+
)
|
| 418 |
+
|
| 419 |
+
if self.ignore_report:
|
| 420 |
+
return
|
| 421 |
+
|
| 422 |
+
metrics = self._auto_fill_metrics(metrics)
|
| 423 |
+
|
| 424 |
+
persisted_checkpoint = None
|
| 425 |
+
if checkpoint:
|
| 426 |
+
self.storage._update_checkpoint_index(metrics)
|
| 427 |
+
|
| 428 |
+
# Persist the reported checkpoint files to storage.
|
| 429 |
+
persisted_checkpoint = self.storage.persist_current_checkpoint(checkpoint)
|
| 430 |
+
|
| 431 |
+
metrics[CHECKPOINT_DIR_NAME] = self.storage.checkpoint_dir_name
|
| 432 |
+
else:
|
| 433 |
+
metrics[CHECKPOINT_DIR_NAME] = None
|
| 434 |
+
|
| 435 |
+
# Persist trial artifacts to storage.
|
| 436 |
+
force_artifact_sync = (
|
| 437 |
+
persisted_checkpoint
|
| 438 |
+
and self.storage.sync_config.sync_artifacts_on_checkpoint
|
| 439 |
+
)
|
| 440 |
+
self.storage.persist_artifacts(force=force_artifact_sync)
|
| 441 |
+
|
| 442 |
+
# Set additional user metadata from the Trainer.
|
| 443 |
+
if persisted_checkpoint and self.metadata:
|
| 444 |
+
user_metadata = persisted_checkpoint.get_metadata()
|
| 445 |
+
for k, v in self.metadata.items():
|
| 446 |
+
# Update keys not already set by the user. This gives user-set keys
|
| 447 |
+
# precedence over keys set at the Trainer level.
|
| 448 |
+
if k not in user_metadata:
|
| 449 |
+
user_metadata[k] = v
|
| 450 |
+
persisted_checkpoint.set_metadata(user_metadata)
|
| 451 |
+
|
| 452 |
+
result = _TrainingResult(checkpoint=persisted_checkpoint, metrics=metrics)
|
| 453 |
+
|
| 454 |
+
self._report_training_result(result)
|
| 455 |
+
|
| 456 |
+
@property
|
| 457 |
+
def experiment_name(self) -> str:
|
| 458 |
+
return self.trial_info.experiment_name
|
| 459 |
+
|
| 460 |
+
@property
|
| 461 |
+
def trial_name(self) -> str:
|
| 462 |
+
return self.trial_info.name
|
| 463 |
+
|
| 464 |
+
@property
|
| 465 |
+
def trial_id(self) -> str:
|
| 466 |
+
return self.trial_info.id
|
| 467 |
+
|
| 468 |
+
@property
|
| 469 |
+
def run_id(self) -> str:
|
| 470 |
+
return self.trial_info.run_id
|
| 471 |
+
|
| 472 |
+
@property
|
| 473 |
+
def trial_resources(self) -> "PlacementGroupFactory":
|
| 474 |
+
return self.trial_info.resources
|
| 475 |
+
|
| 476 |
+
@property
|
| 477 |
+
def trial_dir(self) -> str:
|
| 478 |
+
return self.trial_info.logdir
|
| 479 |
+
|
| 480 |
+
def get_dataset_shard(
|
| 481 |
+
self,
|
| 482 |
+
dataset_name: Optional[str] = None,
|
| 483 |
+
) -> Optional["DataIterator"]:
|
| 484 |
+
shard = self.dataset_shard
|
| 485 |
+
if shard is None:
|
| 486 |
+
warnings.warn(
|
| 487 |
+
"No dataset passed in. Returning None. Make sure to "
|
| 488 |
+
"pass in a Dataset to Trainer.run to use this "
|
| 489 |
+
"function."
|
| 490 |
+
)
|
| 491 |
+
elif isinstance(shard, dict):
|
| 492 |
+
if not dataset_name:
|
| 493 |
+
raise RuntimeError(
|
| 494 |
+
"Multiple datasets were passed into ``Trainer``, "
|
| 495 |
+
"but no ``dataset_name`` is passed into "
|
| 496 |
+
"``get_dataset_shard``. Please specify which "
|
| 497 |
+
"dataset shard to retrieve."
|
| 498 |
+
)
|
| 499 |
+
return shard.get(dataset_name)
|
| 500 |
+
return shard
|
| 501 |
+
|
| 502 |
+
|
| 503 |
+
# Cache of resource dicts that have been checked by the launch hook already.
|
| 504 |
+
_checked_resources: Set[frozenset] = set()
|
| 505 |
+
|
| 506 |
+
# Global _TrainSession object initialized by Ray Tune function trainables
|
| 507 |
+
# and Ray Train V1 workers.
|
| 508 |
+
_session: Optional[_TrainSession] = None
|
| 509 |
+
|
| 510 |
+
|
| 511 |
+
def _tune_task_and_actor_launch_hook(
|
| 512 |
+
fn, resources: Dict[str, float], strategy: Optional[SchedulingStrategyT]
|
| 513 |
+
):
|
| 514 |
+
"""Launch hook to catch nested tasks that can't fit in the placement group.
|
| 515 |
+
|
| 516 |
+
This gives users a nice warning in case they launch a nested task in a Tune trial
|
| 517 |
+
without reserving resources in the trial placement group to fit it.
|
| 518 |
+
"""
|
| 519 |
+
|
| 520 |
+
# Already checked, skip for performance reasons.
|
| 521 |
+
key = frozenset({(k, v) for k, v in resources.items() if v > 0})
|
| 522 |
+
if not key or key in _checked_resources:
|
| 523 |
+
return
|
| 524 |
+
|
| 525 |
+
# No need to check if placement group is None.
|
| 526 |
+
if (
|
| 527 |
+
not isinstance(strategy, PlacementGroupSchedulingStrategy)
|
| 528 |
+
or strategy.placement_group is None
|
| 529 |
+
):
|
| 530 |
+
return
|
| 531 |
+
|
| 532 |
+
# Check if the resource request is targeting the current placement group.
|
| 533 |
+
cur_pg = ray.util.get_current_placement_group()
|
| 534 |
+
if not cur_pg or strategy.placement_group.id != cur_pg.id:
|
| 535 |
+
return
|
| 536 |
+
|
| 537 |
+
_checked_resources.add(key)
|
| 538 |
+
|
| 539 |
+
# Check if the request can be fulfilled by the current placement group.
|
| 540 |
+
pgf = get_trial_resources()
|
| 541 |
+
|
| 542 |
+
if pgf.head_bundle_is_empty:
|
| 543 |
+
available_bundles = cur_pg.bundle_specs[0:]
|
| 544 |
+
else:
|
| 545 |
+
available_bundles = cur_pg.bundle_specs[1:]
|
| 546 |
+
|
| 547 |
+
# Check if the request can be fulfilled by the current placement group.
|
| 548 |
+
if _valid_resource_shape(resources, available_bundles):
|
| 549 |
+
return
|
| 550 |
+
|
| 551 |
+
if fn.class_name:
|
| 552 |
+
submitted = "actor"
|
| 553 |
+
name = fn.module_name + "." + fn.class_name + "." + fn.function_name
|
| 554 |
+
else:
|
| 555 |
+
submitted = "task"
|
| 556 |
+
name = fn.module_name + "." + fn.function_name
|
| 557 |
+
|
| 558 |
+
# Normalize the resource spec so it looks the same as the placement group bundle.
|
| 559 |
+
main_resources = cur_pg.bundle_specs[0]
|
| 560 |
+
resources = {k: float(v) for k, v in resources.items() if v > 0}
|
| 561 |
+
|
| 562 |
+
raise RuntimeError(
|
| 563 |
+
f"No trial resources are available for launching the {submitted} `{name}`. "
|
| 564 |
+
"To resolve this, specify the Tune option:\n\n"
|
| 565 |
+
"> resources_per_trial=tune.PlacementGroupFactory(\n"
|
| 566 |
+
f"> [{main_resources}] + [{resources}] * N\n"
|
| 567 |
+
"> )\n\n"
|
| 568 |
+
f"Where `N` is the number of slots to reserve for trial {submitted}s. "
|
| 569 |
+
"If you are using a Ray training library, there might be a utility function "
|
| 570 |
+
"to set this automatically for you. For more information, refer to "
|
| 571 |
+
"https://docs.ray.io/en/latest/tune/tutorials/tune-resources.html"
|
| 572 |
+
)
|
| 573 |
+
|
| 574 |
+
|
| 575 |
+
def init_session(*args, **kwargs) -> None:
|
| 576 |
+
global _session
|
| 577 |
+
if _session:
|
| 578 |
+
raise ValueError(
|
| 579 |
+
"A Train session is already in use. Do not call "
|
| 580 |
+
"`init_session()` manually."
|
| 581 |
+
)
|
| 582 |
+
|
| 583 |
+
# Setup hooks for generating placement group resource deadlock warnings.
|
| 584 |
+
from ray import actor, remote_function
|
| 585 |
+
|
| 586 |
+
if "TUNE_DISABLE_RESOURCE_CHECKS" not in os.environ:
|
| 587 |
+
actor._actor_launch_hook = _tune_task_and_actor_launch_hook
|
| 588 |
+
remote_function._task_launch_hook = _tune_task_and_actor_launch_hook
|
| 589 |
+
|
| 590 |
+
_session = _TrainSession(*args, **kwargs)
|
| 591 |
+
|
| 592 |
+
|
| 593 |
+
def get_session() -> Optional[_TrainSession]:
|
| 594 |
+
return _session
|
| 595 |
+
|
| 596 |
+
|
| 597 |
+
def shutdown_session():
|
| 598 |
+
"""Shuts down the initialized session."""
|
| 599 |
+
global _session
|
| 600 |
+
_session = None
|
| 601 |
+
|
| 602 |
+
|
| 603 |
+
def _raise_accelerator_session_misuse():
|
| 604 |
+
"""Raises a SessionMisuseError because a utility function was used improperly."""
|
| 605 |
+
raise SessionMisuseError(
|
| 606 |
+
"prepare/accelerate utility functions should be called inside a training "
|
| 607 |
+
"function executed by `Trainer.run`"
|
| 608 |
+
)
|
| 609 |
+
|
| 610 |
+
|
| 611 |
+
def get_accelerator(default_accelerator_cls: Type[Accelerator]) -> Accelerator:
|
| 612 |
+
"""The accelerator for this training session.
|
| 613 |
+
|
| 614 |
+
If an accelerator has not been set, then this method will construct an
|
| 615 |
+
accelerator using the provided accelerator class.
|
| 616 |
+
|
| 617 |
+
Raises:
|
| 618 |
+
SessionMisuseError: if the session is uninitialized.
|
| 619 |
+
"""
|
| 620 |
+
session = get_session()
|
| 621 |
+
if session is None:
|
| 622 |
+
_raise_accelerator_session_misuse()
|
| 623 |
+
if session.accelerator is None:
|
| 624 |
+
session.accelerator = default_accelerator_cls()
|
| 625 |
+
return session.accelerator
|
| 626 |
+
|
| 627 |
+
|
| 628 |
+
def set_accelerator(accelerator: Accelerator) -> None:
|
| 629 |
+
"""Sets the accelerator for this training session.
|
| 630 |
+
|
| 631 |
+
Args:
|
| 632 |
+
accelerator: The accelerator to use for training.
|
| 633 |
+
|
| 634 |
+
Raises:
|
| 635 |
+
SessionMisuseError: if the session is unitialized.
|
| 636 |
+
RuntimeError: if the accelerator has already been set.
|
| 637 |
+
"""
|
| 638 |
+
session = get_session()
|
| 639 |
+
if session is None:
|
| 640 |
+
_raise_accelerator_session_misuse()
|
| 641 |
+
if session.accelerator is not None:
|
| 642 |
+
raise RuntimeError("Cannot change accelerator once set.")
|
| 643 |
+
session.accelerator = accelerator
|
| 644 |
+
|
| 645 |
+
|
| 646 |
+
def _warn_session_misuse(default_value: Any = None):
|
| 647 |
+
"""Warns if fn is being used outside of session and returns ``default_value``."""
|
| 648 |
+
|
| 649 |
+
def inner(fn: Callable):
|
| 650 |
+
fn_name = fn.__name__
|
| 651 |
+
|
| 652 |
+
@functools.wraps(fn)
|
| 653 |
+
def wrapper(*args, **kwargs):
|
| 654 |
+
session = get_session()
|
| 655 |
+
if not session:
|
| 656 |
+
if log_once(f"{SESSION_MISUSE_LOG_ONCE_KEY}-{fn_name}"):
|
| 657 |
+
warnings.warn(
|
| 658 |
+
f"`{fn_name}` is meant to only be "
|
| 659 |
+
"called inside a function that is executed by a Tuner"
|
| 660 |
+
f" or Trainer. Returning `{default_value}`."
|
| 661 |
+
)
|
| 662 |
+
return default_value
|
| 663 |
+
return fn(*args, **kwargs)
|
| 664 |
+
|
| 665 |
+
return wrapper
|
| 666 |
+
|
| 667 |
+
return inner
|
| 668 |
+
|
| 669 |
+
|
| 670 |
+
@PublicAPI(stability="stable")
|
| 671 |
+
@_warn_session_misuse()
|
| 672 |
+
def report(metrics: Dict, *, checkpoint: Optional[Checkpoint] = None) -> None:
|
| 673 |
+
"""Report metrics and optionally save a checkpoint.
|
| 674 |
+
|
| 675 |
+
If a checkpoint is provided, it will be
|
| 676 |
+
:ref:`persisted to storage <persistent-storage-guide>`.
|
| 677 |
+
|
| 678 |
+
If this is called in multiple distributed training workers:
|
| 679 |
+
|
| 680 |
+
- Only the metrics reported by the rank 0 worker will be tracked by Ray Train.
|
| 681 |
+
See :ref:`the metrics logging guide <train-monitoring-and-logging>`.
|
| 682 |
+
- A checkpoint will be registered as long as one or more workers reports
|
| 683 |
+
checkpoint that is not None.
|
| 684 |
+
See the :ref:`checkpointing guide <train-dl-saving-checkpoints>`.
|
| 685 |
+
- Checkpoints from multiple workers will be merged into one directory
|
| 686 |
+
in persistent storage.
|
| 687 |
+
See :ref:`the distributed checkpointing guide <train-distributed-checkpointing>`.
|
| 688 |
+
|
| 689 |
+
.. note::
|
| 690 |
+
|
| 691 |
+
Each invocation of this method will automatically increment the underlying
|
| 692 |
+
``training_iteration`` number. The physical meaning of this "iteration" is
|
| 693 |
+
defined by user depending on how often they call ``report``.
|
| 694 |
+
It does not necessarily map to one epoch.
|
| 695 |
+
|
| 696 |
+
.. warning::
|
| 697 |
+
|
| 698 |
+
All workers must call `ray.train.report` the same number of times
|
| 699 |
+
so that Ray Train can properly synchronize the training state across
|
| 700 |
+
workers. Otherwise, your training will hang.
|
| 701 |
+
|
| 702 |
+
.. warning::
|
| 703 |
+
|
| 704 |
+
This method does NOT act as a barrier for distributed training workers.
|
| 705 |
+
Workers will upload their checkpoint, then continue training immediately.
|
| 706 |
+
If you need to synchronize workers, you can use a framework-native barrier
|
| 707 |
+
such as `torch.distributed.barrier()`.
|
| 708 |
+
|
| 709 |
+
Example:
|
| 710 |
+
|
| 711 |
+
.. testcode::
|
| 712 |
+
|
| 713 |
+
import tempfile
|
| 714 |
+
|
| 715 |
+
from ray import train
|
| 716 |
+
from ray.train import Checkpoint
|
| 717 |
+
from ray.train.torch import TorchTrainer
|
| 718 |
+
|
| 719 |
+
|
| 720 |
+
def train_func(config):
|
| 721 |
+
start_epoch = 0
|
| 722 |
+
checkpoint = train.get_checkpoint()
|
| 723 |
+
if checkpoint:
|
| 724 |
+
with checkpoint.as_directory() as checkpoint_dir:
|
| 725 |
+
# Load back training state
|
| 726 |
+
...
|
| 727 |
+
|
| 728 |
+
for epoch in range(start_epoch, config.get("num_epochs", 10)):
|
| 729 |
+
# Do training...
|
| 730 |
+
|
| 731 |
+
metrics = {"loss": ...}
|
| 732 |
+
|
| 733 |
+
with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
|
| 734 |
+
# Save the checkpoint...
|
| 735 |
+
# torch.save(...)
|
| 736 |
+
|
| 737 |
+
checkpoint = Checkpoint.from_directory(temp_checkpoint_dir)
|
| 738 |
+
|
| 739 |
+
# Example: Only the rank 0 worker uploads the checkpoint.
|
| 740 |
+
if ray.train.get_context().get_world_rank() == 0:
|
| 741 |
+
train.report(metrics, checkpoint=checkpoint)
|
| 742 |
+
else:
|
| 743 |
+
train.report(metrics, checkpoint=None)
|
| 744 |
+
|
| 745 |
+
trainer = TorchTrainer(
|
| 746 |
+
train_func, scaling_config=train.ScalingConfig(num_workers=2)
|
| 747 |
+
)
|
| 748 |
+
|
| 749 |
+
Args:
|
| 750 |
+
metrics: The metrics you want to report.
|
| 751 |
+
checkpoint: The optional checkpoint you want to report.
|
| 752 |
+
"""
|
| 753 |
+
# If we are running in a Tune function, switch to `ray.tune.report`.
|
| 754 |
+
from ray.tune.trainable.trainable_fn_utils import _in_tune_session
|
| 755 |
+
|
| 756 |
+
if _in_tune_session():
|
| 757 |
+
import ray.tune
|
| 758 |
+
|
| 759 |
+
if _v2_migration_warnings_enabled():
|
| 760 |
+
_log_deprecation_warning(
|
| 761 |
+
"`ray.train.report` should be switched to "
|
| 762 |
+
"`ray.tune.report` when running in a function "
|
| 763 |
+
"passed to Ray Tune. This will be an error in the future."
|
| 764 |
+
)
|
| 765 |
+
return ray.tune.report(metrics, checkpoint=checkpoint)
|
| 766 |
+
|
| 767 |
+
get_session().report(metrics, checkpoint=checkpoint)
|
| 768 |
+
|
| 769 |
+
|
| 770 |
+
@PublicAPI(stability="stable")
|
| 771 |
+
@_warn_session_misuse()
|
| 772 |
+
def get_checkpoint() -> Optional[Checkpoint]:
|
| 773 |
+
"""Access the latest reported checkpoint to resume from if one exists.
|
| 774 |
+
|
| 775 |
+
Example:
|
| 776 |
+
|
| 777 |
+
.. testcode::
|
| 778 |
+
|
| 779 |
+
import tempfile
|
| 780 |
+
|
| 781 |
+
from ray import train
|
| 782 |
+
from ray.train import Checkpoint
|
| 783 |
+
from ray.train.torch import TorchTrainer
|
| 784 |
+
|
| 785 |
+
|
| 786 |
+
def train_func(config):
|
| 787 |
+
start_epoch = 0
|
| 788 |
+
checkpoint = train.get_checkpoint()
|
| 789 |
+
if checkpoint:
|
| 790 |
+
with checkpoint.as_directory() as checkpoint_dir:
|
| 791 |
+
# Load back training state
|
| 792 |
+
...
|
| 793 |
+
|
| 794 |
+
for epoch in range(start_epoch, config.get("num_epochs", 10)):
|
| 795 |
+
# Do training...
|
| 796 |
+
|
| 797 |
+
metrics = {"loss": ...}
|
| 798 |
+
|
| 799 |
+
with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
|
| 800 |
+
# Save the checkpoint...
|
| 801 |
+
|
| 802 |
+
checkpoint = Checkpoint.from_directory(temp_checkpoint_dir)
|
| 803 |
+
train.report(metrics, checkpoint=checkpoint)
|
| 804 |
+
|
| 805 |
+
trainer = TorchTrainer(
|
| 806 |
+
train_func, scaling_config=train.ScalingConfig(num_workers=2)
|
| 807 |
+
)
|
| 808 |
+
|
| 809 |
+
Returns:
|
| 810 |
+
Checkpoint object if the session is currently being resumed.
|
| 811 |
+
Otherwise, return None.
|
| 812 |
+
"""
|
| 813 |
+
# If we are running in a Tune function, switch to `ray.tune.get_checkpoint`.
|
| 814 |
+
from ray.tune.trainable.trainable_fn_utils import _in_tune_session
|
| 815 |
+
|
| 816 |
+
if _in_tune_session():
|
| 817 |
+
import ray.tune
|
| 818 |
+
|
| 819 |
+
if _v2_migration_warnings_enabled():
|
| 820 |
+
_log_deprecation_warning(
|
| 821 |
+
"`ray.train.get_checkpoint` should be switched to "
|
| 822 |
+
"`ray.tune.get_checkpoint` when running in a function "
|
| 823 |
+
"passed to Ray Tune. This will be an error in the future."
|
| 824 |
+
)
|
| 825 |
+
return ray.tune.get_checkpoint()
|
| 826 |
+
|
| 827 |
+
return get_session().loaded_checkpoint
|
| 828 |
+
|
| 829 |
+
|
| 830 |
+
@PublicAPI(stability="beta")
|
| 831 |
+
@_warn_session_misuse()
|
| 832 |
+
def get_metadata() -> Dict[str, Any]:
|
| 833 |
+
"""User metadata dict passed to the Trainer constructor."""
|
| 834 |
+
return get_session().metadata
|
| 835 |
+
|
| 836 |
+
|
| 837 |
+
@PublicAPI(stability="beta")
|
| 838 |
+
@_warn_session_misuse()
|
| 839 |
+
def get_experiment_name() -> str:
|
| 840 |
+
"""Experiment name for the corresponding trial."""
|
| 841 |
+
return get_session().experiment_name
|
| 842 |
+
|
| 843 |
+
|
| 844 |
+
@PublicAPI(stability="beta")
|
| 845 |
+
@_warn_session_misuse()
|
| 846 |
+
def get_trial_name() -> str:
|
| 847 |
+
"""Trial name for the corresponding trial."""
|
| 848 |
+
return get_session().trial_name
|
| 849 |
+
|
| 850 |
+
|
| 851 |
+
@PublicAPI(stability="beta")
|
| 852 |
+
@_warn_session_misuse()
+def get_trial_id() -> str:
+    """Trial id for the corresponding trial."""
+    return get_session().trial_id
+
+
+@PublicAPI(stability="alpha")
+@_warn_session_misuse()
+def get_run_id() -> str:
+    """Unique Train Run id for the corresponding trial."""
+    return get_session().run_id
+
+
+@PublicAPI(stability="beta")
+@_warn_session_misuse()
+def get_trial_resources() -> "PlacementGroupFactory":
+    """Trial resources for the corresponding trial."""
+    return get_session().trial_resources
+
+
+@PublicAPI(stability="beta")
+@_warn_session_misuse()
+def get_trial_dir() -> str:
+    """Log directory corresponding to the trial directory for a Tune session.
+    If calling from a Train session, this will give the trial directory of its parent
+    Tune session.
+
+    .. testcode::
+
+        from ray import train, tune
+
+        def train_func(config):
+            print(train.get_context().get_trial_dir())
+
+        tuner = tune.Tuner(train_func)
+        tuner.fit()
+
+    .. testoutput::
+        :options: +MOCK
+
+        /Users/root/ray_results/train_func_2023-07-19_15-01-37/train_func_d620c_00000_0_2023-07-19_15-01-40
+    """
+    return get_session().trial_dir
+
+
+@PublicAPI(stability="beta")
+@_warn_session_misuse(default_value=1)
+def get_world_size() -> int:
+    """Get the current world size (i.e. total number of workers) for this run.
+
+    .. testcode::
+
+        import ray
+        from ray import train
+        from ray.train import ScalingConfig
+        from ray.train.tensorflow import TensorflowTrainer
+
+        NUM_WORKERS = 2
+
+        def train_loop_per_worker(config):
+            assert train.get_context().get_world_size() == NUM_WORKERS
+
+        train_dataset = ray.data.read_csv("s3://anonymous@ray-example-data/iris.csv")
+        trainer = TensorflowTrainer(
+            train_loop_per_worker,
+            scaling_config=ScalingConfig(num_workers=NUM_WORKERS),
+            datasets={"train": train_dataset}
+        )
+        trainer.fit()
+
+    .. testoutput::
+        :hide:
+
+        ...
+    """
+    session = get_session()
+    if not hasattr(session, "world_size"):
+        raise RuntimeError(
+            "`get_world_size` can only be called for TrainSession! "
+            "Make sure you only use that in `train_loop_per_worker` function "
+            "that is passed into `DataParallelTrainer`."
+        )
+    return session.world_size
+
+
+@PublicAPI(stability="beta")
+@_warn_session_misuse(default_value=0)
+def get_world_rank() -> int:
+    """Get the world rank of this worker.
+
+    .. testcode::
+
+        import ray
+        from ray import train
+        from ray.train import ScalingConfig
+        from ray.train.tensorflow import TensorflowTrainer
+
+        def train_loop_per_worker(config):
+            if train.get_context().get_world_rank() == 0:
+                print("Worker 0")
+
+        train_dataset = ray.data.read_csv("s3://anonymous@ray-example-data/iris.csv")
+        trainer = TensorflowTrainer(
+            train_loop_per_worker,
+            scaling_config=ScalingConfig(num_workers=2),
+            datasets={"train": train_dataset}
+        )
+        trainer.fit()
+
+    .. testoutput::
+        :hide:
+
+        ...
+    """
+    session = get_session()
+    if not hasattr(session, "world_rank"):
+        raise RuntimeError(
+            "`get_world_rank` can only be called for TrainSession! "
+            "Make sure you only use that in `train_loop_per_worker` function "
+            "that is passed into `DataParallelTrainer`."
+        )
+    return session.world_rank
+
+
+@PublicAPI(stability="beta")
+@_warn_session_misuse(default_value=0)
+def get_local_rank() -> int:
+    """Get the local rank of this worker (rank of the worker on its node).
+
+    .. testcode::
+
+        import torch
+
+        import ray
+        from ray import train
+        from ray.train import ScalingConfig
+        from ray.train.torch import TorchTrainer
+
+        def train_loop_per_worker(config):
+            if torch.cuda.is_available():
+                torch.cuda.set_device(train.get_context().get_local_rank())
+            ...
+
+        train_dataset = ray.data.read_csv("s3://anonymous@ray-example-data/iris.csv")
+        trainer = TorchTrainer(
+            train_loop_per_worker,
+            scaling_config=ScalingConfig(num_workers=2, use_gpu=True),
+            datasets={"train": train_dataset}
+        )
+        trainer.fit()
+
+    .. testoutput::
+        :hide:
+
+        ...
+    """
+    session = get_session()
+    if not hasattr(session, "local_rank"):
+        raise RuntimeError(
+            "`get_local_rank` can only be called for TrainSession! "
+            "Make sure you only use that in `train_loop_per_worker` function "
+            "that is passed into `DataParallelTrainer`."
+        )
+    return session.local_rank
+
+
+@PublicAPI(stability="beta")
+@_warn_session_misuse(default_value=0)
+def get_local_world_size() -> int:
+    """Get the local world size of this node (i.e. number of workers on this node).
+
+    Example:
+
+        .. testcode::
+
+            import ray
+            from ray import train
+            from ray.train import ScalingConfig
+            from ray.train.torch import TorchTrainer
+
+            def train_loop_per_worker():
+                print(train.get_context().get_local_world_size())
+
+            train_dataset = ray.data.from_items(
+                [{"x": x, "y": x + 1} for x in range(32)])
+            trainer = TorchTrainer(train_loop_per_worker,
+                scaling_config=ScalingConfig(num_workers=1),
+                datasets={"train": train_dataset})
+            trainer.fit()
+
+        .. testoutput::
+            :hide:
+
+            ...
+    """
+    session = get_session()
+    if not hasattr(session, "local_world_size"):
+        raise RuntimeError(
+            "`get_local_world_size` can only be called for TrainSession! "
+            "Make sure you only use that in `train_loop_per_worker` function "
+            "that is passed into `DataParallelTrainer`."
+        )
+    return session.local_world_size
+
+
+@PublicAPI(stability="beta")
+@_warn_session_misuse(default_value=0)
+def get_node_rank() -> int:
+    """Get the rank of this node.
+
+    Example:
+
+        .. testcode::
+
+            import ray
+            from ray import train
+            from ray.train import ScalingConfig
+            from ray.train.torch import TorchTrainer
+
+            def train_loop_per_worker():
+                print(train.get_context().get_node_rank())
+
+            train_dataset = ray.data.from_items(
+                [{"x": x, "y": x + 1} for x in range(32)])
+            trainer = TorchTrainer(train_loop_per_worker,
+                scaling_config=ScalingConfig(num_workers=1),
+                datasets={"train": train_dataset})
+            trainer.fit()
+
+        .. testoutput::
+            :hide:
+
+            ...
+    """
+    session = get_session()
+    if not hasattr(session, "node_rank"):
+        raise RuntimeError(
+            "`get_node_rank` can only be called for TrainSession! "
+            "Make sure you only use that in `train_loop_per_worker` function "
+            "that is passed into `DataParallelTrainer`."
+        )
+    return session.node_rank
+
+
+@PublicAPI(stability="stable")
+@_warn_session_misuse()
+def get_dataset_shard(
+    dataset_name: Optional[str] = None,
+) -> Optional["DataIterator"]:
+    """Returns the :class:`ray.data.DataIterator` shard for this worker.
+
+    Call :meth:`~ray.data.DataIterator.iter_torch_batches` or
+    :meth:`~ray.data.DataIterator.to_tf` on this shard to convert it to the
+    appropriate framework-specific data type.
+
+    .. testcode::
+
+        import ray
+        from ray import train
+        from ray.train import ScalingConfig
+        from ray.train.torch import TorchTrainer
+
+        def train_loop_per_worker(config):
+            ...
+            for epoch in range(2):
+                # Trainer will automatically handle sharding.
+                data_shard = train.get_dataset_shard("train")
+                for batch in data_shard.iter_torch_batches():
+                    ...
+
+        train_dataset = ray.data.read_csv("s3://anonymous@ray-example-data/iris.csv")
+        trainer = TorchTrainer(
+            train_loop_per_worker,
+            scaling_config=ScalingConfig(num_workers=2),
+            datasets={"train": train_dataset}
+        )
+        trainer.fit()
+
+    .. testoutput::
+        :hide:
+
+        ...
+
+    Args:
+        dataset_name: If a Dictionary of Datasets was passed to ``Trainer``, then
+            specifies which dataset shard to return.
+
+    Returns:
+        The ``DataIterator`` shard to use for this worker.
+        If no dataset is passed into Trainer, then return None.
+    """
+    session = get_session()
+    if not hasattr(session, "get_dataset_shard"):
+        raise RuntimeError(
+            "`get_dataset_shard` can only be called for TrainSession! "
+            "Make sure you only use that in `train_loop_per_worker` function "
+            "that is passed into `DataParallelTrainer`."
+        )
+    return session.get_dataset_shard(dataset_name)
+
+
+@DeveloperAPI
+@_warn_session_misuse()
+def get_storage() -> StorageContext:
+    """Returns the :class:`~ray.train._internal.storage.StorageContext` storage
+    context which gives advanced access to the filesystem and paths
+    configured through `RunConfig`.
+
+    NOTE: This is a developer API, and the `StorageContext` interface may change
+    without notice between minor versions.
+    """
+    return get_session().storage
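The helpers above are thin wrappers over the per-worker session; user code normally reaches them through ray.train.get_context() rather than importing this internal module. A minimal sketch of combining the rank/size getters inside a training loop, assuming ray[train] and torch are installed; the worker count and print statements are illustrative only:

import ray
from ray import train
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer


def train_loop_per_worker(config):
    # Each getter below delegates to the active TrainSession,
    # mirroring the module-level functions in this file.
    ctx = train.get_context()
    print(
        f"world rank {ctx.get_world_rank()} of {ctx.get_world_size()}, "
        f"local rank {ctx.get_local_rank()} of {ctx.get_local_world_size()}, "
        f"node rank {ctx.get_node_rank()}"
    )
    if ctx.get_world_rank() == 0:
        # Rank 0 commonly handles logging and artifact writing.
        print(f"trial dir: {ctx.get_trial_dir()}")


if __name__ == "__main__":
    ray.init()
    trainer = TorchTrainer(
        train_loop_per_worker,
        scaling_config=ScalingConfig(num_workers=2),
    )
    trainer.fit()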
.venv/lib/python3.11/site-packages/ray/train/_internal/state/__init__.py
ADDED
@@ -0,0 +1,14 @@
+from ray.train._internal.state.state_manager import TrainRunStateManager
+
+try:
+    import pydantic  # noqa: F401
+except ImportError:
+    raise ModuleNotFoundError(
+        "pydantic isn't installed. "
+        "To install pydantic, please run 'pip install pydantic'"
+    )
+
+
+__all__ = [
+    "TrainRunStateManager",
+]
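The guard above makes a missing pydantic fail at import time with an actionable message. A small sketch, assuming nothing beyond the module shown here, of how a caller could degrade gracefully instead of crashing (ModuleNotFoundError is a subclass of ImportError, so catching the parent class covers both the guard and a plain missing-module error):

# Sketch: tolerate the optional state-tracking dependency being absent.
try:
    from ray.train._internal.state import TrainRunStateManager
except ImportError as exc:
    TrainRunStateManager = None
    print(f"Train run state tracking unavailable: {exc}")

if TrainRunStateManager is None:
    # Fall back to running without dashboard run tracking.
    pass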
.venv/lib/python3.11/site-packages/ray/train/_internal/state/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (581 Bytes).
.venv/lib/python3.11/site-packages/ray/train/_internal/state/__pycache__/schema.cpython-311.pyc
ADDED
Binary file (8.55 kB).
.venv/lib/python3.11/site-packages/ray/train/_internal/state/__pycache__/state_actor.cpython-311.pyc
ADDED
Binary file (3.44 kB).
.venv/lib/python3.11/site-packages/ray/train/_internal/state/__pycache__/state_manager.cpython-311.pyc
ADDED
Binary file (6.71 kB).
.venv/lib/python3.11/site-packages/ray/train/_internal/state/schema.py
ADDED
@@ -0,0 +1,158 @@
+from enum import Enum
+from typing import List, Optional
+
+from ray._private.pydantic_compat import BaseModel, Field
+from ray.dashboard.modules.job.pydantic_models import JobDetails
+from ray.util.annotations import DeveloperAPI
+
+MAX_ERROR_STACK_TRACE_LENGTH = 50000
+
+
+@DeveloperAPI
+class RunStatusEnum(str, Enum):
+    """Enumeration for the status of a train run."""
+
+    # (Deprecated) Replaced by RUNNING.
+    # The train run has started
+    STARTED = "STARTED"
+    # The train run is running
+    RUNNING = "RUNNING"
+    # The train run was terminated as expected
+    FINISHED = "FINISHED"
+    # The train run was terminated early due to errors in the training function
+    ERRORED = "ERRORED"
+    # The train run was terminated early due to system errors or controller errors
+    ABORTED = "ABORTED"
+
+
+@DeveloperAPI
+class ActorStatusEnum(str, Enum):
+    DEAD = "DEAD"
+    ALIVE = "ALIVE"
+
+
+@DeveloperAPI
+class TrainWorkerInfo(BaseModel):
+    """Metadata of a Ray Train worker."""
+
+    actor_id: str = Field(description="Actor ID of the worker.")
+    world_rank: int = Field(description="World rank of the worker.")
+    local_rank: int = Field(description="Local rank of the worker.")
+    node_rank: int = Field(description="Node rank of the worker.")
+    node_id: str = Field(description="ID of the node that the worker is running on.")
+    node_ip: str = Field(
+        description="IP address of the node that the worker is running on."
+    )
+    pid: int = Field(description="Process ID of the worker.")
+    gpu_ids: List[int] = Field(
+        description="A list of GPU ids allocated to that worker."
+    )
+    status: Optional[ActorStatusEnum] = Field(
+        description="The status of the train worker actor. It can be ALIVE or DEAD."
+    )
+
+
+@DeveloperAPI
+class MemoryInfo(BaseModel):
+    rss: int
+    vms: int
+    pfaults: Optional[int]
+    pageins: Optional[int]
+
+
+@DeveloperAPI
+class ProcessStats(BaseModel):
+    cpuPercent: float
+    # total memory, free memory, memory used ratio
+    mem: Optional[List[int]]
+    memoryInfo: MemoryInfo
+
+
+class ProcessGPUUsage(BaseModel):
+    # This gpu usage stats from a process
+    pid: int
+    gpuMemoryUsage: int
+
+
+@DeveloperAPI
+class GPUStats(BaseModel):
+    uuid: str
+    index: int
+    name: str
+    utilizationGpu: Optional[float]
+    memoryUsed: float
+    memoryTotal: float
+    processInfo: ProcessGPUUsage
+
+
+@DeveloperAPI
+class TrainWorkerInfoWithDetails(TrainWorkerInfo):
+    """Metadata of a Ray Train worker."""
+
+    processStats: Optional[ProcessStats] = Field(
+        None, description="Process stats of the worker."
+    )
+    gpus: List[GPUStats] = Field(
+        default_factory=list,
+        description=(
+            "GPU stats of the worker. "
+            "Only returns GPUs that are attached to the worker process."
+        ),
+    )
+
+
+@DeveloperAPI
+class TrainDatasetInfo(BaseModel):
+    name: str = Field(
+        description="The key of the dataset dict specified in Ray Train Trainer."
+    )
+    dataset_uuid: str = Field(description="The uuid of the dataset.")
+    dataset_name: Optional[str] = Field(description="The name of the dataset.")
+
+
+@DeveloperAPI
+class TrainRunInfo(BaseModel):
+    """Metadata for a Ray Train run and information about its workers."""
+
+    name: str = Field(description="The name of the Train run.")
+    id: str = Field(description="The unique identifier for each Train run.")
+    job_id: str = Field(description="The Ray Job ID.")
+    controller_actor_id: str = Field(description="Actor Id of the Train controller.")
+    workers: List[TrainWorkerInfo] = Field(
+        description="A List of Train workers sorted by global ranks."
+    )
+    datasets: List[TrainDatasetInfo] = Field(
+        description="A List of dataset info for this Train run."
+    )
+    run_status: RunStatusEnum = Field(
+        description="The current status of the train run. It can be one of the "
+        "following: RUNNING, FINISHED, ERRORED, or ABORTED."
+    )
+    status_detail: str = Field(
+        description="Detailed information about the current run status, "
+        "such as error messages."
+    )
+    start_time_ms: int = Field(
+        description="The UNIX timestamp of the start time of this Train run."
+    )
+    end_time_ms: Optional[int] = Field(
+        description="The UNIX timestamp of the end time of this Train run. "
+        "If null, the Train run has not ended yet."
+    )
+
+
+@DeveloperAPI
+class TrainRunInfoWithDetails(TrainRunInfo):
+    """Metadata for a Ray Train run and information about its workers."""
+
+    workers: List[TrainWorkerInfoWithDetails] = Field(
+        description="A List of Train workers sorted by global ranks."
+    )
+    job_details: Optional[JobDetails] = Field(
+        None, description="Details of the job that started this Train run."
+    )
+
+
+@DeveloperAPI
+class TrainRunsResponse(BaseModel):
+    train_runs: List[TrainRunInfoWithDetails]
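For illustration, a sketch of how these models nest; every ID and timestamp below is a made-up placeholder, and the field names follow the definitions above:

from ray.train._internal.state.schema import (
    RunStatusEnum,
    TrainDatasetInfo,
    TrainRunInfo,
    TrainWorkerInfo,
)

# Placeholder values -- real IDs come from Ray core and the Train controller.
worker = TrainWorkerInfo(
    actor_id="abc123",
    world_rank=0,
    local_rank=0,
    node_rank=0,
    node_id="node-0",
    node_ip="10.0.0.1",
    pid=4242,
    gpu_ids=[0],
    status=None,
)

run = TrainRunInfo(
    name="demo_run",
    id="run-0001",
    job_id="02000000",
    controller_actor_id="ctrl456",
    workers=[worker],
    datasets=[
        TrainDatasetInfo(name="train", dataset_uuid="ds-uuid", dataset_name=None)
    ],
    run_status=RunStatusEnum.RUNNING,
    status_detail="",
    start_time_ms=1_700_000_000_000,
    end_time_ms=None,
)

print(run)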
.venv/lib/python3.11/site-packages/ray/train/_internal/state/state_actor.py
ADDED
@@ -0,0 +1,62 @@
+import logging
+import threading
+from typing import Dict, Optional
+
+import ray
+from ray.actor import ActorHandle
+from ray.train._internal.state.schema import TrainRunInfo
+
+logger = logging.getLogger(__name__)
+
+
+@ray.remote(num_cpus=0)
+class TrainStateActor:
+    def __init__(self):
+        self._run_infos: Dict[str, TrainRunInfo] = {}
+
+    def register_train_run(self, run_info: TrainRunInfo) -> None:
+        # Register a new train run.
+        self._run_infos[run_info.id] = run_info
+
+    def get_train_run(self, run_id: str) -> Optional[TrainRunInfo]:
+        # Retrieve a registered run by its id
+        return self._run_infos.get(run_id, None)
+
+    def get_all_train_runs(self) -> Dict[str, TrainRunInfo]:
+        # Retrieve all registered train runs
+        return self._run_infos
+
+
+TRAIN_STATE_ACTOR_NAME = "train_state_actor"
+TRAIN_STATE_ACTOR_NAMESPACE = "_train_state_actor"
+
+_state_actor_lock: threading.RLock = threading.RLock()
+
+
+def get_or_create_state_actor() -> ActorHandle:
+    """Get or create a `TrainStateActor` on the head node."""
+    with _state_actor_lock:
+        state_actor = TrainStateActor.options(
+            name=TRAIN_STATE_ACTOR_NAME,
+            namespace=TRAIN_STATE_ACTOR_NAMESPACE,
+            get_if_exists=True,
+            lifetime="detached",
+            resources={"node:__internal_head__": 0.001},
+            # Escape from the parent's placement group
+            scheduling_strategy="DEFAULT",
+        ).remote()
+
+    # Ensure the state actor is ready
+    ray.get(state_actor.__ray_ready__.remote())
+    return state_actor
+
+
+def get_state_actor() -> Optional[ActorHandle]:
+    """Get the `TrainStateActor` if it exists, otherwise return None."""
+    try:
+        return ray.get_actor(
+            name=TRAIN_STATE_ACTOR_NAME,
+            namespace=TRAIN_STATE_ACTOR_NAMESPACE,
+        )
+    except ValueError:
+        return None
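A driver-side sketch (not part of the library) showing how the two lookup helpers are intended to be combined; it assumes a running Ray cluster and nothing else:

import ray

from ray.train._internal.state.state_actor import (
    get_or_create_state_actor,
    get_state_actor,
)

ray.init()

# First call creates the detached, head-node-pinned actor; later calls
# return the same handle because of `get_if_exists=True`.
state_actor = get_or_create_state_actor()

# Other processes (e.g. the dashboard) can look it up by name and namespace.
handle = get_state_actor()
assert handle is not None

# Registered runs come back as a Dict[str, TrainRunInfo].
runs = ray.get(handle.get_all_train_runs.remote())
print(f"{len(runs)} train run(s) currently registered")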
.venv/lib/python3.11/site-packages/ray/train/_internal/state/state_manager.py
ADDED
@@ -0,0 +1,126 @@
+import logging
+import os
+from collections import defaultdict
+from typing import Any, Dict
+
+import ray
+from ray.data import Dataset
+from ray.train._internal.state.schema import (
+    RunStatusEnum,
+    TrainDatasetInfo,
+    TrainRunInfo,
+    TrainWorkerInfo,
+)
+from ray.train._internal.utils import check_for_failure
+from ray.train._internal.worker_group import WorkerGroup
+
+logger = logging.getLogger(__name__)
+
+
+class TrainRunStateManager:
+    """A class that aggregates and reports train run info to TrainStateActor.
+
+    This manager class is created on the train controller layer for each run.
+    """
+
+    def __init__(self, state_actor) -> None:
+        self.state_actor = state_actor
+        self.train_run_info_dict = defaultdict(dict)
+
+    def register_train_run(
+        self,
+        run_id: str,
+        job_id: str,
+        run_name: str,
+        run_status: str,
+        controller_actor_id: str,
+        datasets: Dict[str, Dataset],
+        worker_group: WorkerGroup,
+        start_time_ms: float,
+        status_detail: str = "",
+    ) -> None:
+        """Collect Train Run Info and report to StateActor."""
+
+        if not self.state_actor:
+            logger.warning(
+                "Unable to register train run since `TrainStateActor` is not started."
+            )
+            return
+
+        def collect_train_worker_info():
+            train_context = ray.train.get_context()
+            core_context = ray.runtime_context.get_runtime_context()
+
+            return TrainWorkerInfo(
+                world_rank=train_context.get_world_rank(),
+                local_rank=train_context.get_local_rank(),
+                node_rank=train_context.get_node_rank(),
+                actor_id=core_context.get_actor_id(),
+                node_id=core_context.get_node_id(),
+                node_ip=ray.util.get_node_ip_address(),
+                gpu_ids=ray.get_gpu_ids(),
+                pid=os.getpid(),
+            )
+
+        futures = [
+            worker_group.execute_single_async(index, collect_train_worker_info)
+            for index in range(len(worker_group))
+        ]
+        success, exception = check_for_failure(futures)
+
+        if not success:
+            logger.error(
+                "Failed to collect run information from the Ray Train "
+                f"workers:\n{exception}"
+            )
+            return
+
+        worker_info_list = ray.get(futures)
+        worker_info_list = sorted(worker_info_list, key=lambda info: info.world_rank)
+
+        dataset_info_list = [
+            TrainDatasetInfo(
+                name=ds_name,
+                dataset_name=ds._plan._dataset_name,
+                dataset_uuid=ds._plan._dataset_uuid,
+            )
+            for ds_name, ds in datasets.items()
+        ]
+
+        updates = dict(
+            id=run_id,
+            job_id=job_id,
+            name=run_name,
+            controller_actor_id=controller_actor_id,
+            workers=worker_info_list,
+            datasets=dataset_info_list,
+            start_time_ms=start_time_ms,
+            run_status=run_status,
+            status_detail=status_detail,
+        )
+
+        # Clear the cached info to avoid registering the same run twice
+        self.train_run_info_dict[run_id] = {}
+        self._update_train_run_info(run_id, updates)
+
+    def end_train_run(
+        self,
+        run_id: str,
+        run_status: RunStatusEnum,
+        status_detail: str,
+        end_time_ms: int,
+    ):
+        """Update the train run status when the training is finished."""
+        updates = dict(
+            run_status=run_status,
+            status_detail=status_detail,
+            end_time_ms=end_time_ms,
+        )
+        self._update_train_run_info(run_id, updates)
+
+    def _update_train_run_info(self, run_id: str, updates: Dict[str, Any]) -> None:
+        """Update specific fields of a registered TrainRunInfo instance."""
+        if run_id in self.train_run_info_dict:
+            self.train_run_info_dict[run_id].update(updates)
+            train_run_info = TrainRunInfo(**self.train_run_info_dict[run_id])
+            ray.get(self.state_actor.register_train_run.remote(train_run_info))
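The manager is normally driven by the Train controller, which owns the worker group. The sketch below only illustrates the call order and argument meanings; `worker_group` and `datasets` are assumed to come from that surrounding controller code, and the IDs are placeholders:

import time

from ray.train._internal.state.schema import RunStatusEnum
from ray.train._internal.state.state_actor import get_or_create_state_actor
from ray.train._internal.state.state_manager import TrainRunStateManager


def report_run_lifecycle(worker_group, datasets):
    # Assumption: worker_group is a started WorkerGroup and datasets is the
    # Dict[str, Dataset] passed to the Trainer; both are created elsewhere.
    manager = TrainRunStateManager(state_actor=get_or_create_state_actor())

    manager.register_train_run(
        run_id="run-0001",  # placeholder
        job_id="02000000",  # placeholder
        run_name="demo_run",
        run_status=RunStatusEnum.RUNNING,
        controller_actor_id="ctrl456",  # placeholder
        datasets=datasets,
        worker_group=worker_group,
        start_time_ms=int(time.time() * 1000),
    )

    # ... training runs to completion ...

    manager.end_train_run(
        run_id="run-0001",
        run_status=RunStatusEnum.FINISHED,
        status_detail="",
        end_time_ms=int(time.time() * 1000),
    )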
.venv/lib/python3.11/site-packages/ray/train/_internal/storage.py
ADDED
@@ -0,0 +1,725 @@
| 1 |
+
# Try import ray[train] core requirements (defined in setup.py)
|
| 2 |
+
# isort: off
|
| 3 |
+
try:
|
| 4 |
+
import fsspec # noqa
|
| 5 |
+
from fsspec.implementations.local import LocalFileSystem
|
| 6 |
+
|
| 7 |
+
except (ImportError, ModuleNotFoundError) as e:
|
| 8 |
+
raise RuntimeError(
|
| 9 |
+
"fsspec is a required dependency of Ray Train and Ray Tune. "
|
| 10 |
+
"Please install with: `pip install fsspec`"
|
| 11 |
+
) from e
|
| 12 |
+
|
| 13 |
+
try:
|
| 14 |
+
import pyarrow
|
| 15 |
+
import pyarrow.fs
|
| 16 |
+
|
| 17 |
+
except (ImportError, ModuleNotFoundError) as e:
|
| 18 |
+
raise RuntimeError(
|
| 19 |
+
"pyarrow is a required dependency of Ray Train and Ray Tune. "
|
| 20 |
+
"Please install with: `pip install pyarrow`"
|
| 21 |
+
) from e
|
| 22 |
+
|
| 23 |
+
try:
|
| 24 |
+
# check if Arrow has S3 support
|
| 25 |
+
from pyarrow.fs import S3FileSystem
|
| 26 |
+
except ImportError:
|
| 27 |
+
S3FileSystem = None
|
| 28 |
+
# isort: on
|
| 29 |
+
|
| 30 |
+
import fnmatch
|
| 31 |
+
import logging
|
| 32 |
+
import os
|
| 33 |
+
import shutil
|
| 34 |
+
from pathlib import Path
|
| 35 |
+
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Type, Union
|
| 36 |
+
|
| 37 |
+
from ray.air._internal.filelock import TempFileLock
|
| 38 |
+
from ray.train._internal.syncer import SyncConfig, Syncer, _BackgroundSyncer
|
| 39 |
+
from ray.train.constants import _get_ray_train_session_dir
|
| 40 |
+
from ray.util.annotations import DeveloperAPI
|
| 41 |
+
|
| 42 |
+
if TYPE_CHECKING:
|
| 43 |
+
from ray.train._checkpoint import Checkpoint
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
logger = logging.getLogger(__name__)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
_VALIDATE_STORAGE_MARKER_FILENAME = ".validate_storage_marker"
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class _ExcludingLocalFilesystem(LocalFileSystem):
|
| 53 |
+
"""LocalFileSystem wrapper to exclude files according to patterns.
|
| 54 |
+
|
| 55 |
+
Args:
|
| 56 |
+
root_path: Root path to strip when matching with the exclude pattern.
|
| 57 |
+
Ex: root_path="/tmp/a/b/c", exclude=["*a*"], will exclude
|
| 58 |
+
/tmp/a/b/c/_a_.txt but not ALL of /tmp/a/*.
|
| 59 |
+
exclude: List of patterns that are applied to files returned by
|
| 60 |
+
``self.find()``. If a file path matches this pattern, it will
|
| 61 |
+
be excluded.
|
| 62 |
+
|
| 63 |
+
"""
|
| 64 |
+
|
| 65 |
+
def __init__(self, root_path: Path, exclude: List[str], **kwargs):
|
| 66 |
+
super().__init__(**kwargs)
|
| 67 |
+
self._exclude = exclude
|
| 68 |
+
self._root_path = root_path
|
| 69 |
+
|
| 70 |
+
@property
|
| 71 |
+
def fsid(self):
|
| 72 |
+
return "_excluding_local"
|
| 73 |
+
|
| 74 |
+
def _should_exclude(self, path: str) -> bool:
|
| 75 |
+
"""Return True if `path` (relative to `root_path`) matches any of the
|
| 76 |
+
`self._exclude` patterns."""
|
| 77 |
+
path = Path(path)
|
| 78 |
+
relative_path = path.relative_to(self._root_path).as_posix()
|
| 79 |
+
match_candidates = [relative_path]
|
| 80 |
+
if path.is_dir():
|
| 81 |
+
# Everything is in posix path format ('/')
|
| 82 |
+
match_candidates.append(relative_path + "/")
|
| 83 |
+
|
| 84 |
+
for excl in self._exclude:
|
| 85 |
+
if any(fnmatch.fnmatch(candidate, excl) for candidate in match_candidates):
|
| 86 |
+
return True
|
| 87 |
+
return False
|
| 88 |
+
|
| 89 |
+
def find(self, path, maxdepth=None, withdirs=False, detail=False, **kwargs):
|
| 90 |
+
"""Call parent find() and exclude from result."""
|
| 91 |
+
paths = super().find(
|
| 92 |
+
path, maxdepth=maxdepth, withdirs=withdirs, detail=detail, **kwargs
|
| 93 |
+
)
|
| 94 |
+
if detail:
|
| 95 |
+
return {
|
| 96 |
+
path: out
|
| 97 |
+
for path, out in paths.items()
|
| 98 |
+
if not self._should_exclude(path)
|
| 99 |
+
}
|
| 100 |
+
else:
|
| 101 |
+
return [path for path in paths if not self._should_exclude(path)]
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def _pyarrow_fs_copy_files(
|
| 105 |
+
source, destination, source_filesystem=None, destination_filesystem=None, **kwargs
|
| 106 |
+
):
|
| 107 |
+
if S3FileSystem and isinstance(destination_filesystem, pyarrow.fs.S3FileSystem):
|
| 108 |
+
# Workaround multi-threading issue with pyarrow. Note that use_threads=True
|
| 109 |
+
# is safe for download, just not for uploads, see:
|
| 110 |
+
# https://github.com/apache/arrow/issues/32372
|
| 111 |
+
kwargs.setdefault("use_threads", False)
|
| 112 |
+
|
| 113 |
+
# Use a large chunk size to speed up large checkpoint transfers.
|
| 114 |
+
kwargs.setdefault("chunk_size", 64 * 1024 * 1024)
|
| 115 |
+
|
| 116 |
+
return pyarrow.fs.copy_files(
|
| 117 |
+
source,
|
| 118 |
+
destination,
|
| 119 |
+
source_filesystem=source_filesystem,
|
| 120 |
+
destination_filesystem=destination_filesystem,
|
| 121 |
+
**kwargs,
|
| 122 |
+
)
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
# TODO(justinvyu): Add unit tests for all these utils.
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def _delete_fs_path(fs: pyarrow.fs.FileSystem, fs_path: str):
|
| 129 |
+
is_dir = _is_directory(fs, fs_path)
|
| 130 |
+
|
| 131 |
+
try:
|
| 132 |
+
if is_dir:
|
| 133 |
+
fs.delete_dir(fs_path)
|
| 134 |
+
else:
|
| 135 |
+
fs.delete_file(fs_path)
|
| 136 |
+
except Exception:
|
| 137 |
+
logger.exception(f"Caught exception when deleting path at ({fs}, {fs_path}):")
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def _download_from_fs_path(
|
| 141 |
+
fs: pyarrow.fs.FileSystem,
|
| 142 |
+
fs_path: str,
|
| 143 |
+
local_path: str,
|
| 144 |
+
filelock: bool = True,
|
| 145 |
+
):
|
| 146 |
+
"""Downloads a directory or file from (fs, fs_path) to a local path.
|
| 147 |
+
|
| 148 |
+
If fs_path points to a directory:
|
| 149 |
+
- The full directory contents are downloaded directly into `local_path`,
|
| 150 |
+
rather than to a subdirectory of `local_path`.
|
| 151 |
+
|
| 152 |
+
If fs_path points to a file:
|
| 153 |
+
- The file is downloaded to `local_path`, which is expected to be a file path.
|
| 154 |
+
|
| 155 |
+
If the download fails, the `local_path` contents are
|
| 156 |
+
cleaned up before raising, if the directory did not previously exist.
|
| 157 |
+
|
| 158 |
+
NOTE: This method creates `local_path`'s parent directories if they do not
|
| 159 |
+
already exist. If the download fails, this does NOT clean up all the parent
|
| 160 |
+
directories that were created.
|
| 161 |
+
|
| 162 |
+
Args:
|
| 163 |
+
fs: The filesystem to download from.
|
| 164 |
+
fs_path: The filesystem path (either a directory or a file) to download.
|
| 165 |
+
local_path: The local path to download to.
|
| 166 |
+
filelock: Whether to require a file lock before downloading, useful for
|
| 167 |
+
multiple downloads to the same directory that may be happening in parallel.
|
| 168 |
+
|
| 169 |
+
Raises:
|
| 170 |
+
FileNotFoundError: if (fs, fs_path) doesn't exist.
|
| 171 |
+
"""
|
| 172 |
+
|
| 173 |
+
_local_path = Path(local_path).resolve()
|
| 174 |
+
exists_before = _local_path.exists()
|
| 175 |
+
if _is_directory(fs=fs, fs_path=fs_path):
|
| 176 |
+
_local_path.mkdir(parents=True, exist_ok=True)
|
| 177 |
+
else:
|
| 178 |
+
_local_path.parent.mkdir(parents=True, exist_ok=True)
|
| 179 |
+
|
| 180 |
+
try:
|
| 181 |
+
if filelock:
|
| 182 |
+
with TempFileLock(f"{os.path.normpath(local_path)}.lock"):
|
| 183 |
+
_pyarrow_fs_copy_files(fs_path, local_path, source_filesystem=fs)
|
| 184 |
+
else:
|
| 185 |
+
_pyarrow_fs_copy_files(fs_path, local_path, source_filesystem=fs)
|
| 186 |
+
except Exception as e:
|
| 187 |
+
# Clean up the directory if downloading was unsuccessful
|
| 188 |
+
if not exists_before:
|
| 189 |
+
shutil.rmtree(local_path, ignore_errors=True)
|
| 190 |
+
raise e
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
def _upload_to_fs_path(
|
| 194 |
+
local_path: str,
|
| 195 |
+
fs: pyarrow.fs.FileSystem,
|
| 196 |
+
fs_path: str,
|
| 197 |
+
exclude: Optional[List[str]] = None,
|
| 198 |
+
) -> None:
|
| 199 |
+
"""Uploads a local directory or file to (fs, fs_path).
|
| 200 |
+
|
| 201 |
+
NOTE: This will create all necessary parent directories at the destination.
|
| 202 |
+
|
| 203 |
+
Args:
|
| 204 |
+
local_path: The local path to upload.
|
| 205 |
+
fs: The filesystem to upload to.
|
| 206 |
+
fs_path: The filesystem path where the dir/file will be uploaded to.
|
| 207 |
+
exclude: A list of filename matches to exclude from upload. This includes
|
| 208 |
+
all files under subdirectories as well.
|
| 209 |
+
This pattern will match with the relative paths of all files under
|
| 210 |
+
`local_path`.
|
| 211 |
+
Ex: ["*.png"] to exclude all .png images.
|
| 212 |
+
"""
|
| 213 |
+
|
| 214 |
+
if not exclude:
|
| 215 |
+
# TODO(justinvyu): uploading a single file doesn't work
|
| 216 |
+
# (since we always create a directory at fs_path)
|
| 217 |
+
_create_directory(fs=fs, fs_path=fs_path)
|
| 218 |
+
_pyarrow_fs_copy_files(local_path, fs_path, destination_filesystem=fs)
|
| 219 |
+
return
|
| 220 |
+
|
| 221 |
+
_upload_to_uri_with_exclude_fsspec(
|
| 222 |
+
local_path=local_path, fs=fs, fs_path=fs_path, exclude=exclude
|
| 223 |
+
)
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
def _upload_to_uri_with_exclude_fsspec(
|
| 227 |
+
local_path: str, fs: "pyarrow.fs", fs_path: str, exclude: Optional[List[str]]
|
| 228 |
+
) -> None:
|
| 229 |
+
local_fs = _ExcludingLocalFilesystem(root_path=local_path, exclude=exclude)
|
| 230 |
+
handler = pyarrow.fs.FSSpecHandler(local_fs)
|
| 231 |
+
source_fs = pyarrow.fs.PyFileSystem(handler)
|
| 232 |
+
|
| 233 |
+
_create_directory(fs=fs, fs_path=fs_path)
|
| 234 |
+
_pyarrow_fs_copy_files(
|
| 235 |
+
local_path, fs_path, source_filesystem=source_fs, destination_filesystem=fs
|
| 236 |
+
)
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
def _list_at_fs_path(
|
| 240 |
+
fs: pyarrow.fs.FileSystem,
|
| 241 |
+
fs_path: str,
|
| 242 |
+
file_filter: Optional[Callable[[pyarrow.fs.FileInfo], bool]] = None,
|
| 243 |
+
) -> List[str]:
|
| 244 |
+
"""Returns the list of filenames at (fs, fs_path), similar to os.listdir.
|
| 245 |
+
|
| 246 |
+
If the path doesn't exist, returns an empty list.
|
| 247 |
+
"""
|
| 248 |
+
if file_filter is None:
|
| 249 |
+
file_filter = lambda x: True # noqa: E731
|
| 250 |
+
|
| 251 |
+
selector = pyarrow.fs.FileSelector(fs_path, allow_not_found=True, recursive=False)
|
| 252 |
+
return [
|
| 253 |
+
os.path.relpath(file_info.path.lstrip("/"), start=fs_path.lstrip("/"))
|
| 254 |
+
for file_info in fs.get_file_info(selector)
|
| 255 |
+
if file_filter(file_info)
|
| 256 |
+
]
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
def _exists_at_fs_path(fs: pyarrow.fs.FileSystem, fs_path: str) -> bool:
|
| 260 |
+
"""Returns True if (fs, fs_path) exists."""
|
| 261 |
+
|
| 262 |
+
valid = fs.get_file_info(fs_path)
|
| 263 |
+
return valid.type != pyarrow.fs.FileType.NotFound
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
def _is_directory(fs: pyarrow.fs.FileSystem, fs_path: str) -> bool:
|
| 267 |
+
"""Checks if (fs, fs_path) is a directory or a file.
|
| 268 |
+
|
| 269 |
+
Raises:
|
| 270 |
+
FileNotFoundError: if (fs, fs_path) doesn't exist.
|
| 271 |
+
"""
|
| 272 |
+
|
| 273 |
+
file_info = fs.get_file_info(fs_path)
|
| 274 |
+
if file_info.type == pyarrow.fs.FileType.NotFound:
|
| 275 |
+
raise FileNotFoundError(f"Path not found: ({fs}, {fs_path})")
|
| 276 |
+
|
| 277 |
+
return not file_info.is_file
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
def _create_directory(fs: pyarrow.fs.FileSystem, fs_path: str) -> None:
|
| 281 |
+
"""Create directory at (fs, fs_path).
|
| 282 |
+
|
| 283 |
+
Some external filesystems require directories to already exist, or at least
|
| 284 |
+
the `netloc` to be created (e.g. PyArrows ``mock://`` filesystem).
|
| 285 |
+
|
| 286 |
+
Generally this should be done before and outside of Ray applications. This
|
| 287 |
+
utility is thus primarily used in testing, e.g. of ``mock://` URIs.
|
| 288 |
+
"""
|
| 289 |
+
try:
|
| 290 |
+
fs.create_dir(fs_path)
|
| 291 |
+
except Exception:
|
| 292 |
+
logger.exception(
|
| 293 |
+
f"Caught exception when creating directory at ({fs}, {fs_path}):"
|
| 294 |
+
)
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
def get_fs_and_path(
|
| 298 |
+
storage_path: Union[str, os.PathLike],
|
| 299 |
+
storage_filesystem: Optional[pyarrow.fs.FileSystem] = None,
|
| 300 |
+
) -> Tuple[pyarrow.fs.FileSystem, str]:
|
| 301 |
+
"""Returns the fs and path from a storage path and an optional custom fs.
|
| 302 |
+
|
| 303 |
+
Args:
|
| 304 |
+
storage_path: A storage path or URI. (ex: s3://bucket/path or /tmp/ray_results)
|
| 305 |
+
storage_filesystem: A custom filesystem to use. If not provided,
|
| 306 |
+
this will be auto-resolved by pyarrow. If provided, the storage_path
|
| 307 |
+
is assumed to be prefix-stripped already, and must be a valid path
|
| 308 |
+
on the filesystem.
|
| 309 |
+
"""
|
| 310 |
+
storage_path = str(storage_path)
|
| 311 |
+
|
| 312 |
+
if storage_filesystem:
|
| 313 |
+
return storage_filesystem, storage_path
|
| 314 |
+
|
| 315 |
+
return pyarrow.fs.FileSystem.from_uri(storage_path)
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
class _FilesystemSyncer(_BackgroundSyncer):
|
| 319 |
+
"""Syncer between local filesystem and a `storage_filesystem`."""
|
| 320 |
+
|
| 321 |
+
def __init__(self, storage_filesystem: Optional["pyarrow.fs.FileSystem"], **kwargs):
|
| 322 |
+
self.storage_filesystem = storage_filesystem
|
| 323 |
+
super().__init__(**kwargs)
|
| 324 |
+
|
| 325 |
+
def _sync_up_command(
|
| 326 |
+
self, local_path: str, uri: str, exclude: Optional[List] = None
|
| 327 |
+
) -> Tuple[Callable, Dict]:
|
| 328 |
+
# TODO(justinvyu): Defer this cleanup up as part of the
|
| 329 |
+
# external-facing Syncer deprecation.
|
| 330 |
+
fs_path = uri
|
| 331 |
+
return (
|
| 332 |
+
_upload_to_fs_path,
|
| 333 |
+
dict(
|
| 334 |
+
local_path=local_path,
|
| 335 |
+
fs=self.storage_filesystem,
|
| 336 |
+
fs_path=fs_path,
|
| 337 |
+
exclude=exclude,
|
| 338 |
+
),
|
| 339 |
+
)
|
| 340 |
+
|
| 341 |
+
def _sync_down_command(self, uri: str, local_path: str) -> Tuple[Callable, Dict]:
|
| 342 |
+
fs_path = uri
|
| 343 |
+
return (
|
| 344 |
+
_download_from_fs_path,
|
| 345 |
+
dict(
|
| 346 |
+
fs=self.storage_filesystem,
|
| 347 |
+
fs_path=fs_path,
|
| 348 |
+
local_path=local_path,
|
| 349 |
+
),
|
| 350 |
+
)
|
| 351 |
+
|
| 352 |
+
def _delete_command(self, uri: str) -> Tuple[Callable, Dict]:
|
| 353 |
+
fs_path = uri
|
| 354 |
+
return _delete_fs_path, dict(fs=self.storage_filesystem, fs_path=fs_path)
|
| 355 |
+
|
| 356 |
+
|
| 357 |
+
@DeveloperAPI
|
| 358 |
+
class StorageContext:
|
| 359 |
+
"""Shared context that holds the source of truth for all paths and
|
| 360 |
+
storage utilities, passed along from the driver to workers.
|
| 361 |
+
|
| 362 |
+
This object defines a few types of paths:
|
| 363 |
+
1. *_fs_path: A path on the `storage_filesystem`. This is a regular path
|
| 364 |
+
which has been prefix-stripped by pyarrow.fs.FileSystem.from_uri and
|
| 365 |
+
can be joined with `Path(...).as_posix()`.
|
| 366 |
+
2. *_driver_staging_path: The temporary staging directory on the local filesystem
|
| 367 |
+
where driver artifacts are saved to before persisting them to storage.
|
| 368 |
+
3. trial_working_directory: The local filesystem path that the remote
|
| 369 |
+
actors' working directories are moved to by default.
|
| 370 |
+
This is separated from the driver staging path so that driver syncing
|
| 371 |
+
does not implicitly upload the trial working directory, for trials on the
|
| 372 |
+
driver node.
|
| 373 |
+
|
| 374 |
+
Example with storage_path="mock:///bucket/path?param=1":
|
| 375 |
+
|
| 376 |
+
>>> import ray
|
| 377 |
+
>>> from ray.train._internal.storage import StorageContext
|
| 378 |
+
>>> import os
|
| 379 |
+
>>> _ = ray.init()
|
| 380 |
+
>>> storage = StorageContext(
|
| 381 |
+
... storage_path="mock://netloc/bucket/path?param=1",
|
| 382 |
+
... experiment_dir_name="exp_name",
|
| 383 |
+
... )
|
| 384 |
+
>>> storage.storage_filesystem # Auto-resolved # doctest: +ELLIPSIS
|
| 385 |
+
<pyarrow._fs._MockFileSystem object...
|
| 386 |
+
>>> storage.experiment_fs_path
|
| 387 |
+
'bucket/path/exp_name'
|
| 388 |
+
>>> storage.experiment_driver_staging_path # doctest: +ELLIPSIS
|
| 389 |
+
'/tmp/ray/session_.../artifacts/.../exp_name/driver_artifacts'
|
| 390 |
+
>>> storage.trial_dir_name = "trial_dir"
|
| 391 |
+
>>> storage.trial_fs_path
|
| 392 |
+
'bucket/path/exp_name/trial_dir'
|
| 393 |
+
>>> storage.trial_driver_staging_path # doctest: +ELLIPSIS
|
| 394 |
+
'/tmp/ray/session_.../artifacts/.../exp_name/driver_artifacts/trial_dir'
|
| 395 |
+
>>> storage.trial_working_directory # doctest: +ELLIPSIS
|
| 396 |
+
'/tmp/ray/session_.../artifacts/.../exp_name/working_dirs/trial_dir'
|
| 397 |
+
>>> storage.current_checkpoint_index = 1
|
| 398 |
+
>>> storage.checkpoint_fs_path
|
| 399 |
+
'bucket/path/exp_name/trial_dir/checkpoint_000001'
|
| 400 |
+
>>> ray.shutdown()
|
| 401 |
+
|
| 402 |
+
Example with storage_path="/tmp/ray_results":
|
| 403 |
+
|
| 404 |
+
>>> from ray.train._internal.storage import StorageContext
|
| 405 |
+
>>> storage = StorageContext(
|
| 406 |
+
... storage_path="/tmp/ray_results",
|
| 407 |
+
... experiment_dir_name="exp_name",
|
| 408 |
+
... )
|
| 409 |
+
>>> storage.storage_fs_path
|
| 410 |
+
'/tmp/ray_results'
|
| 411 |
+
>>> storage.experiment_fs_path
|
| 412 |
+
'/tmp/ray_results/exp_name'
|
| 413 |
+
>>> storage.storage_filesystem # Auto-resolved # doctest: +ELLIPSIS
|
| 414 |
+
<pyarrow._fs.LocalFileSystem object...
|
| 415 |
+
|
| 416 |
+
Internal Usage Examples:
|
| 417 |
+
- To copy files to the trial directory on the storage filesystem:
|
| 418 |
+
|
| 419 |
+
pyarrow.fs.copy_files(
|
| 420 |
+
local_dir,
|
| 421 |
+
Path(storage.trial_fs_path, "subdir").as_posix(),
|
| 422 |
+
destination_filesystem=storage.filesystem
|
| 423 |
+
)
|
| 424 |
+
|
| 425 |
+
.. warning::
|
| 426 |
+
This is an experimental developer API and is subject to change
|
| 427 |
+
without notice between versions.
|
| 428 |
+
"""
|
| 429 |
+
|
| 430 |
+
def __init__(
|
| 431 |
+
self,
|
| 432 |
+
storage_path: Union[str, os.PathLike],
|
| 433 |
+
experiment_dir_name: str,
|
| 434 |
+
sync_config: Optional[SyncConfig] = None,
|
| 435 |
+
storage_filesystem: Optional[pyarrow.fs.FileSystem] = None,
|
| 436 |
+
trial_dir_name: Optional[str] = None,
|
| 437 |
+
current_checkpoint_index: int = -1,
|
| 438 |
+
):
|
| 439 |
+
from ray.tune.utils import date_str
|
| 440 |
+
|
| 441 |
+
self.custom_fs_provided = storage_filesystem is not None
|
| 442 |
+
|
| 443 |
+
# Invariant: (`storage_filesystem`, `storage_path`) is the location where
|
| 444 |
+
# *all* results can be accessed.
|
| 445 |
+
self.experiment_dir_name = experiment_dir_name
|
| 446 |
+
self.trial_dir_name = trial_dir_name
|
| 447 |
+
self.current_checkpoint_index = current_checkpoint_index
|
| 448 |
+
self.sync_config = sync_config or SyncConfig()
|
| 449 |
+
|
| 450 |
+
self.storage_filesystem, self.storage_fs_path = get_fs_and_path(
|
| 451 |
+
storage_path, storage_filesystem
|
| 452 |
+
)
|
| 453 |
+
self.storage_fs_path = Path(self.storage_fs_path).as_posix()
|
| 454 |
+
|
| 455 |
+
self.syncer: Syncer = _FilesystemSyncer(
|
| 456 |
+
storage_filesystem=self.storage_filesystem,
|
| 457 |
+
sync_period=self.sync_config.sync_period,
|
| 458 |
+
sync_timeout=self.sync_config.sync_timeout,
|
| 459 |
+
)
|
| 460 |
+
|
| 461 |
+
self._create_validation_file()
|
| 462 |
+
self._check_validation_file()
|
| 463 |
+
|
| 464 |
+
# Timestamp is used to create a unique session directory for the current
|
| 465 |
+
# training job. This is used to avoid conflicts when multiple training jobs
|
| 466 |
+
# run with the same name in the same cluster.
|
| 467 |
+
# This is set ONCE at the creation of the storage context, on the driver.
|
| 468 |
+
self._timestamp = date_str()
|
| 469 |
+
|
| 470 |
+
def __str__(self):
|
| 471 |
+
return (
|
| 472 |
+
"StorageContext<\n"
|
| 473 |
+
f" storage_filesystem='{self.storage_filesystem.type_name}',\n"
|
| 474 |
+
f" storage_fs_path='{self.storage_fs_path}',\n"
|
| 475 |
+
f" experiment_dir_name='{self.experiment_dir_name}',\n"
|
| 476 |
+
f" trial_dir_name='{self.trial_dir_name}',\n"
|
| 477 |
+
f" current_checkpoint_index={self.current_checkpoint_index},\n"
|
| 478 |
+
">"
|
| 479 |
+
)
|
| 480 |
+
|
| 481 |
+
def _create_validation_file(self):
|
| 482 |
+
"""On the creation of a storage context, create a validation file at the
|
| 483 |
+
storage path to verify that the storage path can be written to.
|
| 484 |
+
This validation file is also used to check whether the storage path is
|
| 485 |
+
accessible by all nodes in the cluster."""
|
| 486 |
+
valid_file = Path(
|
| 487 |
+
self.experiment_fs_path, _VALIDATE_STORAGE_MARKER_FILENAME
|
| 488 |
+
).as_posix()
|
| 489 |
+
self.storage_filesystem.create_dir(self.experiment_fs_path)
|
| 490 |
+
with self.storage_filesystem.open_output_stream(valid_file):
|
| 491 |
+
pass
|
| 492 |
+
|
| 493 |
+
def _check_validation_file(self):
|
| 494 |
+
"""Checks that the validation file exists at the storage path."""
|
| 495 |
+
valid_file = Path(
|
| 496 |
+
self.experiment_fs_path, _VALIDATE_STORAGE_MARKER_FILENAME
|
| 497 |
+
).as_posix()
|
| 498 |
+
if not _exists_at_fs_path(fs=self.storage_filesystem, fs_path=valid_file):
|
| 499 |
+
raise RuntimeError(
|
| 500 |
+
f"Unable to set up cluster storage with the following settings:\n{self}"
|
| 501 |
+
"\nCheck that all nodes in the cluster have read/write access "
|
| 502 |
+
"to the configured storage path. `RunConfig(storage_path)` should be "
|
| 503 |
+
"set to a cloud storage URI or a shared filesystem path accessible "
|
| 504 |
+
"by all nodes in your cluster ('s3://bucket' or '/mnt/nfs'). "
|
| 505 |
+
"A local path on the head node is not accessible by worker nodes. "
|
| 506 |
+
"See: https://docs.ray.io/en/latest/train/user-guides/persistent-storage.html" # noqa: E501
|
| 507 |
+
)
|
| 508 |
+
|
| 509 |
+
def _update_checkpoint_index(self, metrics: Dict):
|
| 510 |
+
# Per default, increase by 1. This can be overwritten to customize checkpoint
|
| 511 |
+
# directories.
|
| 512 |
+
self.current_checkpoint_index += 1
|
| 513 |
+
|
| 514 |
+
def persist_current_checkpoint(self, checkpoint: "Checkpoint") -> "Checkpoint":
|
| 515 |
+
"""Persists a given checkpoint to the current checkpoint path on the filesystem.
|
| 516 |
+
|
| 517 |
+
"Current" is defined by the `current_checkpoint_index` attribute of the
|
| 518 |
+
storage context.
|
| 519 |
+
|
| 520 |
+
This method copies the checkpoint files to the storage location.
|
| 521 |
+
It's up to the user to delete the original checkpoint files if desired.
|
| 522 |
+
|
| 523 |
+
For example, the original directory is typically a local temp directory.
|
| 524 |
+
|
| 525 |
+
Args:
|
| 526 |
+
checkpoint: The checkpoint to persist to (fs, checkpoint_fs_path).
|
| 527 |
+
|
| 528 |
+
Returns:
|
| 529 |
+
Checkpoint: A Checkpoint pointing to the persisted checkpoint location.
|
| 530 |
+
"""
|
| 531 |
+
# TODO(justinvyu): Fix this cyclical import.
|
| 532 |
+
from ray.train._checkpoint import Checkpoint
|
| 533 |
+
|
| 534 |
+
logger.debug(
|
| 535 |
+
"Copying checkpoint files to storage path:\n"
|
| 536 |
+
"({source_fs}, {source}) -> ({dest_fs}, {destination})".format(
|
| 537 |
+
source=checkpoint.path,
|
| 538 |
+
destination=self.checkpoint_fs_path,
|
| 539 |
+
source_fs=checkpoint.filesystem,
|
| 540 |
+
dest_fs=self.storage_filesystem,
|
| 541 |
+
)
|
| 542 |
+
)
|
| 543 |
+
|
| 544 |
+
# Raise an error if the storage path is not accessible when
|
| 545 |
+
# attempting to upload a checkpoint from a remote worker.
|
| 546 |
+
# Ex: If storage_path is a local path, then a validation marker
|
| 547 |
+
# will only exist on the head node but not the worker nodes.
|
| 548 |
+
self._check_validation_file()
|
| 549 |
+
|
| 550 |
+
self.storage_filesystem.create_dir(self.checkpoint_fs_path)
|
| 551 |
+
_pyarrow_fs_copy_files(
|
| 552 |
+
source=checkpoint.path,
|
| 553 |
+
destination=self.checkpoint_fs_path,
|
| 554 |
+
source_filesystem=checkpoint.filesystem,
|
| 555 |
+
destination_filesystem=self.storage_filesystem,
|
| 556 |
+
)
|
| 557 |
+
|
| 558 |
+
persisted_checkpoint = Checkpoint(
|
| 559 |
+
filesystem=self.storage_filesystem,
|
| 560 |
+
path=self.checkpoint_fs_path,
|
| 561 |
+
)
|
| 562 |
+
logger.info(f"Checkpoint successfully created at: {persisted_checkpoint}")
|
| 563 |
+
        return persisted_checkpoint

    def persist_artifacts(self, force: bool = False) -> None:
        """Persists all artifacts within `trial_local_dir` to storage.

        This method possibly launches a background task to sync the trial dir,
        depending on the `sync_period` + `sync_artifacts_on_checkpoint`
        settings of `SyncConfig`.

        `(local_fs, trial_working_dir) -> (storage_filesystem, trial_fs_path)`

        Args:
            force: If True, wait for a previous sync to finish, launch a new one,
                and wait for that one to finish. By the end of a `force=True` call, the
                latest version of the trial artifacts will be persisted.
        """
        if not self.sync_config.sync_artifacts:
            return

        # Skip if there are no artifacts to sync
        is_empty = not any(os.scandir(self.trial_working_directory))
        if is_empty:
            return

        if force:
            self.syncer.wait()
            self.syncer.sync_up(
                local_dir=self.trial_working_directory, remote_dir=self.trial_fs_path
            )
            self.syncer.wait()
        else:
            self.syncer.sync_up_if_needed(
                local_dir=self.trial_working_directory, remote_dir=self.trial_fs_path
            )

    @property
    def experiment_fs_path(self) -> str:
        """The path on the `storage_filesystem` to the experiment directory.

        NOTE: This does not have a URI prefix anymore, since it has been stripped
        by pyarrow.fs.FileSystem.from_uri already. The URI scheme information is
        kept in `storage_filesystem` instead.
        """
        return Path(self.storage_fs_path, self.experiment_dir_name).as_posix()

    def _get_session_path(self) -> str:
        """The Ray Train/Tune session local directory used to stage files
        before persisting to the storage filesystem."""
        return Path(
            _get_ray_train_session_dir(), self._timestamp, self.experiment_dir_name
        ).as_posix()

    @property
    def experiment_driver_staging_path(self) -> str:
        """The local filesystem path of the experiment directory on the driver node.

        The driver is the node where `Trainer.fit`/`Tuner.fit` is being called.

        This path is of the form:
        `/tmp/ray/session_<session_id>/artifacts/<ray-train-job-timestamp>/
        <experiment_dir_name>/driver_artifacts`

        This should be used as the temporary staging location for files *on the driver*
        before syncing them to `experiment_fs_path`.
        For example, the search algorithm should dump its state to this directory.
        See `trial_driver_staging_path` for writing trial-specific artifacts.

        The directory is synced to
        `{storage_path}/{experiment_dir_name}` periodically.
        See `_ExperimentCheckpointManager.checkpoint` for where that happens.
        """
        return Path(self._get_session_path(), "driver_artifacts").as_posix()

    @property
    def trial_fs_path(self) -> str:
        """The trial directory path on the `storage_filesystem`.

        Raises a ValueError if `trial_dir_name` is not set beforehand.
        """
        if self.trial_dir_name is None:
            raise RuntimeError(
                "Should not access `trial_fs_path` without setting `trial_dir_name`"
            )
        return Path(self.experiment_fs_path, self.trial_dir_name).as_posix()

    @property
    def trial_driver_staging_path(self) -> str:
        """The local filesystem path of the trial directory on the driver.

        The driver is the node where `Trainer.fit`/`Tuner.fit` is being called.

        This path is of the form:
        `/tmp/ray/session_<session_id>/artifacts/<ray-train-job-timestamp>/
        <experiment_dir_name>/driver_artifacts/<trial_dir_name>`

        This should be used as the temporary location for files on the driver
        before persisting them to `trial_fs_path`.

        For example, callbacks (e.g., JsonLoggerCallback) should write trial-specific
        logfiles within this directory.
        """
        if self.trial_dir_name is None:
            raise RuntimeError(
                "Should not access `trial_driver_staging_path` "
                "without setting `trial_dir_name`"
            )
        return Path(self.experiment_driver_staging_path, self.trial_dir_name).as_posix()

    @property
    def trial_working_directory(self) -> str:
        """The local filesystem path to trial working directory.

        This path is of the form:
        `/tmp/ray/session_<session_id>/artifacts/<ray-train-job-timestamp>/
        <experiment_dir_name>/working_dirs/<trial_dir_name>`

        Ray Train/Tune moves the remote actor's working directory to this path
        by default, unless disabled by `RAY_CHDIR_TO_TRIAL_DIR` environment variable.

        Writing files to this directory allows users to persist training artifacts
        if `SyncConfig(sync_artifacts=True)` is set.
        """
        if self.trial_dir_name is None:
            raise RuntimeError(
                "Cannot access `trial_working_directory` without "
                "setting `trial_dir_name`"
            )
        return Path(
            self._get_session_path(), "working_dirs", self.trial_dir_name
        ).as_posix()

    @property
    def checkpoint_fs_path(self) -> str:
        """The current checkpoint directory path on the `storage_filesystem`.

        "Current" refers to the checkpoint that is currently being created/persisted.
        The user of this class is responsible for setting the `current_checkpoint_index`
        (e.g., incrementing when needed).
        """
        return Path(self.trial_fs_path, self.checkpoint_dir_name).as_posix()

    @property
    def checkpoint_dir_name(self) -> str:
        """The current checkpoint directory name, based on the checkpoint index."""
        return StorageContext._make_checkpoint_dir_name(self.current_checkpoint_index)

    @staticmethod
    def get_experiment_dir_name(run_obj: Union[str, Callable, Type]) -> str:
        from ray.tune.experiment import Experiment
        from ray.tune.utils import date_str

        run_identifier = Experiment.get_trainable_name(run_obj)

        if bool(int(os.environ.get("TUNE_DISABLE_DATED_SUBDIR", 0))):
            dir_name = run_identifier
        else:
            dir_name = "{}_{}".format(run_identifier, date_str())
        return dir_name

    @staticmethod
    def _make_checkpoint_dir_name(index: int):
        """Get the name of the checkpoint directory, given an index."""
        return f"checkpoint_{index:06d}"

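A minimal sketch (not part of the vendored file above) of how the `experiment_fs_path`, `trial_fs_path`, and `checkpoint_fs_path` properties compose; the bucket, experiment, and trial names here are hypothetical:

from pathlib import Path

def make_checkpoint_dir_name(index: int) -> str:
    # Mirrors StorageContext._make_checkpoint_dir_name: zero-padded to six digits.
    return f"checkpoint_{index:06d}"

# Hypothetical values standing in for StorageContext attributes.
storage_fs_path = "my-bucket/ray-results"
experiment_dir_name = "TorchTrainer_2024-01-01_00-00-00"
trial_dir_name = "TorchTrainer_abc12_00000"

experiment_fs_path = Path(storage_fs_path, experiment_dir_name).as_posix()
trial_fs_path = Path(experiment_fs_path, trial_dir_name).as_posix()
checkpoint_fs_path = Path(trial_fs_path, make_checkpoint_dir_name(3)).as_posix()

print(checkpoint_fs_path)
# my-bucket/ray-results/TorchTrainer_2024-01-01_00-00-00/TorchTrainer_abc12_00000/checkpoint_000003
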
.venv/lib/python3.11/site-packages/ray/train/_internal/syncer.py
ADDED
@@ -0,0 +1,490 @@
import abc
import logging
import threading
import time
import traceback
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

from ray._private.thirdparty.tabulate.tabulate import tabulate
from ray.train.constants import _DEPRECATED_VALUE
from ray.util.annotations import DeveloperAPI, PublicAPI
from ray.widgets import Template

logger = logging.getLogger(__name__)

# Syncing period for syncing checkpoints between nodes or to cloud.
DEFAULT_SYNC_PERIOD = 300

# Default sync timeout after which syncing processes are aborted
DEFAULT_SYNC_TIMEOUT = 1800


@PublicAPI(stability="stable")
@dataclass
class SyncConfig:
    """Configuration object for Train/Tune file syncing to `RunConfig(storage_path)`.

    In Ray Train/Tune, here is where syncing (mainly uploading) happens:

    The experiment driver (on the head node) syncs the experiment directory to storage
    (which includes experiment state such as searcher state, the list of trials
    and their statuses, and trial metadata).

    It's also possible to sync artifacts from the trial directory to storage
    by setting `sync_artifacts=True`.
    For a Ray Tune run with many trials, each trial will upload its trial directory
    to storage, which includes arbitrary files that you dumped during the run.
    For a Ray Train run doing distributed training, each remote worker will similarly
    upload its trial directory to storage.

    See :ref:`persistent-storage-guide` for more details and examples.

    Args:
        sync_period: Minimum time in seconds to wait between two sync operations.
            A smaller ``sync_period`` will have the data in storage updated more often
            but introduces more syncing overhead. Defaults to 5 minutes.
        sync_timeout: Maximum time in seconds to wait for a sync process
            to finish running. A sync operation will run for at most this long
            before raising a `TimeoutError`. Defaults to 30 minutes.
        sync_artifacts: [Beta] Whether or not to sync artifacts that are saved to the
            trial directory (accessed via `train.get_context().get_trial_dir()`)
            to the persistent storage configured via `train.RunConfig(storage_path)`.
            The trial or remote worker will try to launch an artifact syncing
            operation every time `train.report` happens, subject to `sync_period`
            and `sync_artifacts_on_checkpoint`.
            Defaults to False -- no artifacts are persisted by default.
        sync_artifacts_on_checkpoint: If True, trial/worker artifacts are
            forcefully synced on every reported checkpoint.
            This only has an effect if `sync_artifacts` is True.
            Defaults to True.
    """

    sync_period: int = DEFAULT_SYNC_PERIOD
    sync_timeout: int = DEFAULT_SYNC_TIMEOUT
    sync_artifacts: bool = False
    sync_artifacts_on_checkpoint: bool = True
    upload_dir: Optional[str] = _DEPRECATED_VALUE
    syncer: Optional[Union[str, "Syncer"]] = _DEPRECATED_VALUE
    sync_on_checkpoint: bool = _DEPRECATED_VALUE

    # TODO(justinvyu): [Deprecated] Remove in 2.11.
    def _deprecation_warning(self, attr_name: str, extra_msg: str):
        if getattr(self, attr_name) != _DEPRECATED_VALUE:
            raise DeprecationWarning(
                f"`SyncConfig({attr_name})` is a deprecated configuration "
                "Please remove it from your `SyncConfig`. "
                f"{extra_msg}"
            )

    def __post_init__(self):
        for attr_name, extra_msg in [
            (
                "upload_dir",
                "\nPlease specify `ray.train.RunConfig(storage_path)` instead.",
            ),
            (
                "syncer",
                "\nPlease implement custom syncing logic with a custom "
                "`pyarrow.fs.FileSystem` instead, and pass it into "
                "`ray.train.RunConfig(storage_filesystem)`. "
                "See here: https://docs.ray.io/en/latest/train/user-guides/persistent-storage.html#custom-storage",  # noqa: E501
            ),
            ("sync_on_checkpoint", ""),
        ]:
            self._deprecation_warning(attr_name, extra_msg)

    def _repr_html_(self) -> str:
        """Generate an HTML representation of the SyncConfig."""
        return Template("scrollableTable.html.j2").render(
            table=tabulate(
                {
                    "Setting": ["Sync period", "Sync timeout"],
                    "Value": [self.sync_period, self.sync_timeout],
                },
                tablefmt="html",
                showindex=False,
                headers="keys",
            ),
            max_height="none",
        )


class _BackgroundProcess:
    def __init__(self, fn: Callable):
        self._fn = fn
        self._process = None
        self._result = {}
        self._start_time = float("-inf")

    @property
    def is_running(self):
        return self._process and self._process.is_alive()

    @property
    def start_time(self):
        return self._start_time

    def start(self, *args, **kwargs):
        if self.is_running:
            return False

        self._result = {}

        def entrypoint():
            try:
                result = self._fn(*args, **kwargs)
            except Exception as e:
                self._result["exception"] = e
                return

            self._result["result"] = result

        self._process = threading.Thread(target=entrypoint)
        self._process.daemon = True
        self._process.start()
        self._start_time = time.time()

    def wait(self, timeout: Optional[float] = None) -> Any:
        """Waits for the background process to finish running. Waits until the
        background process has run for at least `timeout` seconds, counting from
        the time when the process was started."""
        if not self._process:
            return None

        time_remaining = None
        if timeout:
            elapsed = time.time() - self.start_time
            time_remaining = max(timeout - elapsed, 0)

        self._process.join(timeout=time_remaining)

        if self._process.is_alive():
            self._process = None
            raise TimeoutError(
                f"{getattr(self._fn, '__name__', str(self._fn))} did not finish "
                f"running within the timeout of {timeout} seconds."
            )

        self._process = None

        exception = self._result.get("exception")
        if exception:
            raise exception

        result = self._result.get("result")

        self._result = {}
        return result


@DeveloperAPI
class Syncer(abc.ABC):
    """Syncer class for synchronizing data between Ray nodes and remote (cloud) storage.

    This class handles data transfer for two cases:

    1. Synchronizing data such as experiment state snapshots from the driver to
       cloud storage.
    2. Synchronizing data such as trial checkpoints from remote trainables to
       cloud storage.

    Synchronizing tasks are usually asynchronous and can be awaited using ``wait()``.
    The base class implements a ``wait_or_retry()`` API that will retry a failed
    sync command.

    The base class also exposes an API to only kick off syncs every ``sync_period``
    seconds.

    Args:
        sync_period: The minimum time in seconds between sync operations, as
            used by ``sync_up/down_if_needed``.
        sync_timeout: The maximum time to wait for a sync process to finish before
            issuing a new sync operation. Ex: should be used by ``wait`` if launching
            asynchronous sync tasks.
    """

    def __init__(
        self,
        sync_period: float = DEFAULT_SYNC_PERIOD,
        sync_timeout: float = DEFAULT_SYNC_TIMEOUT,
    ):
        self.sync_period = sync_period
        self.sync_timeout = sync_timeout
        self.last_sync_up_time = float("-inf")
        self.last_sync_down_time = float("-inf")

    @abc.abstractmethod
    def sync_up(
        self, local_dir: str, remote_dir: str, exclude: Optional[List] = None
    ) -> bool:
        """Synchronize local directory to remote directory.

        This function can spawn an asynchronous process that can be awaited in
        ``wait()``.

        Args:
            local_dir: Local directory to sync from.
            remote_dir: Remote directory to sync up to. This is an URI
                (``protocol://remote/path``).
            exclude: Pattern of files to exclude, e.g.
                ``["*/checkpoint_*]`` to exclude trial checkpoints.

        Returns:
            True if sync process has been spawned, False otherwise.

        """
        raise NotImplementedError

    @abc.abstractmethod
    def sync_down(
        self, remote_dir: str, local_dir: str, exclude: Optional[List] = None
    ) -> bool:
        """Synchronize remote directory to local directory.

        This function can spawn an asynchronous process that can be awaited in
        ``wait()``.

        Args:
            remote_dir: Remote directory to sync down from. This is an URI
                (``protocol://remote/path``).
            local_dir: Local directory to sync to.
            exclude: Pattern of files to exclude, e.g.
                ``["*/checkpoint_*]`` to exclude trial checkpoints.

        Returns:
            True if sync process has been spawned, False otherwise.

        """
        raise NotImplementedError

    @abc.abstractmethod
    def delete(self, remote_dir: str) -> bool:
        """Delete directory on remote storage.

        This function can spawn an asynchronous process that can be awaited in
        ``wait()``.

        Args:
            remote_dir: Remote directory to delete. This is an URI
                (``protocol://remote/path``).

        Returns:
            True if sync process has been spawned, False otherwise.

        """
        raise NotImplementedError

    def retry(self):
        """Retry the last sync up, sync down, or delete command.

        You should implement this method if you spawn asynchronous syncing
        processes.
        """
        pass

    def wait(self, timeout: Optional[float] = None):
        """Wait for asynchronous sync command to finish.

        You should implement this method if you spawn asynchronous syncing
        processes. This method should timeout after the asynchronous command
        has run for `sync_timeout` seconds and raise a `TimeoutError`.
        """
        pass

    def sync_up_if_needed(
        self, local_dir: str, remote_dir: str, exclude: Optional[List] = None
    ) -> bool:
        """Syncs up if time since last sync up is greater than sync_period.

        Args:
            local_dir: Local directory to sync from.
            remote_dir: Remote directory to sync up to. This is an URI
                (``protocol://remote/path``).
            exclude: Pattern of files to exclude, e.g.
                ``["*/checkpoint_*]`` to exclude trial checkpoints.
        """
        now = time.time()
        if now - self.last_sync_up_time >= self.sync_period:
            result = self.sync_up(
                local_dir=local_dir, remote_dir=remote_dir, exclude=exclude
            )
            self.last_sync_up_time = now
            return result

    def sync_down_if_needed(
        self, remote_dir: str, local_dir: str, exclude: Optional[List] = None
    ):
        """Syncs down if time since last sync down is greater than sync_period.

        Args:
            remote_dir: Remote directory to sync down from. This is an URI
                (``protocol://remote/path``).
            local_dir: Local directory to sync to.
            exclude: Pattern of files to exclude, e.g.
                ``["*/checkpoint_*]`` to exclude trial checkpoints.
        """
        now = time.time()
        if now - self.last_sync_down_time >= self.sync_period:
            result = self.sync_down(
                remote_dir=remote_dir, local_dir=local_dir, exclude=exclude
            )
            self.last_sync_down_time = now
            return result

    def wait_or_retry(self, max_retries: int = 2, backoff_s: int = 5):
        assert max_retries > 0
        last_error_traceback = None
        for i in range(max_retries + 1):
            try:
                self.wait()
            except Exception as e:
                attempts_remaining = max_retries - i

                # If we're out of retries, then save the full traceback of the last
                # error and show it when raising an exception.
                if attempts_remaining == 0:
                    last_error_traceback = traceback.format_exc()
                    break

                logger.error(
                    f"The latest sync operation failed with the following error: "
                    f"{repr(e)}\n"
                    f"Retrying {attempts_remaining} more time(s) after sleeping "
                    f"for {backoff_s} seconds..."
                )
                time.sleep(backoff_s)
                self.retry()
                continue
            # Succeeded!
            return
        raise RuntimeError(
            f"Failed sync even after {max_retries} retries. "
            f"The latest sync failed with the following error:\n{last_error_traceback}"
        )

    def reset(self):
        self.last_sync_up_time = float("-inf")
        self.last_sync_down_time = float("-inf")

    def close(self):
        pass

    def _repr_html_(self) -> str:
        return


class _BackgroundSyncer(Syncer):
    """Syncer using a background process for asynchronous file transfer."""

    def __init__(
        self,
        sync_period: float = DEFAULT_SYNC_PERIOD,
        sync_timeout: float = DEFAULT_SYNC_TIMEOUT,
    ):
        super(_BackgroundSyncer, self).__init__(
            sync_period=sync_period, sync_timeout=sync_timeout
        )
        self._sync_process = None
        self._current_cmd = None

    def _should_continue_existing_sync(self):
        """Returns whether a previous sync is still running within the timeout."""
        return (
            self._sync_process
            and self._sync_process.is_running
            and time.time() - self._sync_process.start_time < self.sync_timeout
        )

    def _launch_sync_process(self, sync_command: Tuple[Callable, Dict]):
        """Waits for the previous sync process to finish,
        then launches a new process that runs the given command."""
        if self._sync_process:
            try:
                self.wait()
            except Exception:
                logger.warning(
                    f"Last sync command failed with the following error:\n"
                    f"{traceback.format_exc()}"
                )

        self._current_cmd = sync_command
        self.retry()

    def sync_up(
        self, local_dir: str, remote_dir: str, exclude: Optional[List] = None
    ) -> bool:
        if self._should_continue_existing_sync():
            logger.debug(
                f"Last sync still in progress, "
                f"skipping sync up of {local_dir} to {remote_dir}"
            )
            return False

        sync_up_cmd = self._sync_up_command(
            local_path=local_dir, uri=remote_dir, exclude=exclude
        )
        self._launch_sync_process(sync_up_cmd)

        return True

    def _sync_up_command(
        self, local_path: str, uri: str, exclude: Optional[List] = None
    ) -> Tuple[Callable, Dict]:
        raise NotImplementedError

    def sync_down(
        self, remote_dir: str, local_dir: str, exclude: Optional[List] = None
    ) -> bool:
        if self._should_continue_existing_sync():
            logger.warning(
                f"Last sync still in progress, "
                f"skipping sync down of {remote_dir} to {local_dir}"
            )
            return False

        sync_down_cmd = self._sync_down_command(uri=remote_dir, local_path=local_dir)
        self._launch_sync_process(sync_down_cmd)

        return True

    def _sync_down_command(self, uri: str, local_path: str) -> Tuple[Callable, Dict]:
        raise NotImplementedError

    def delete(self, remote_dir: str) -> bool:
        if self._should_continue_existing_sync():
            logger.warning(
                f"Last sync still in progress, skipping deletion of {remote_dir}"
            )
            return False

        delete_cmd = self._delete_command(uri=remote_dir)
        self._launch_sync_process(delete_cmd)

        return True

    def _delete_command(self, uri: str) -> Tuple[Callable, Dict]:
        raise NotImplementedError

    def wait(self, timeout: Optional[float] = None):
        if self._sync_process:
            try:
                self._sync_process.wait(timeout=timeout or self.sync_timeout)
            except Exception as e:
                raise e
            finally:
                # Regardless of whether the sync process succeeded within the timeout,
                # clear the sync process so a new one can be created.
                self._sync_process = None

    def retry(self):
        if not self._current_cmd:
            raise RuntimeError("No sync command set, cannot retry.")
        cmd, kwargs = self._current_cmd
        self._sync_process = _BackgroundProcess(cmd)
        self._sync_process.start(**kwargs)

    def __getstate__(self):
        state = self.__dict__.copy()
        state["_sync_process"] = None
        return state

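A minimal sketch (not part of the vendored file above) of the command contract that `_BackgroundSyncer` relies on: `_sync_up_command`, `_sync_down_command`, and `_delete_command` each return a `(callable, kwargs)` tuple, which `retry()` wraps in a `_BackgroundProcess` and runs on a daemon thread. The transfer function and paths below are hypothetical placeholders, shown with plain threading and no Ray dependency:

import threading

def upload_dir(local_path: str, uri: str) -> None:
    # Placeholder for a real transfer (e.g. copying files to cloud storage).
    print(f"uploading {local_path} -> {uri}")

# The (callable, kwargs) tuple a _sync_up_command-style method would return.
sync_command = (upload_dir, {"local_path": "/tmp/trial_dir", "uri": "mock://bucket/trial_dir"})

fn, kwargs = sync_command
process = threading.Thread(target=fn, kwargs=kwargs, daemon=True)
process.start()      # roughly what _BackgroundProcess.start(**kwargs) does
process.join(5.0)    # roughly what wait(timeout=...) does
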
.venv/lib/python3.11/site-packages/ray/train/_internal/utils.py
ADDED
@@ -0,0 +1,239 @@
import abc
import functools
import inspect
import logging
import os
from pathlib import Path
from typing import (
    Any,
    Callable,
    ContextManager,
    Dict,
    List,
    Optional,
    Tuple,
    TypeVar,
    Union,
)

import ray
from ray.actor import ActorHandle
from ray.air._internal.util import (
    StartTraceback,
    StartTracebackWithWorkerRank,
    find_free_port,
)
from ray.exceptions import RayActorError
from ray.types import ObjectRef

T = TypeVar("T")

logger = logging.getLogger(__name__)


def check_for_failure(
    remote_values: List[ObjectRef],
) -> Tuple[bool, Optional[Exception]]:
    """Check for actor failure when retrieving the remote values.

    Args:
        remote_values: List of object references from Ray actor methods.

    Returns:
        A tuple of (bool, Exception). The bool is
        True if evaluating all object references is successful, False otherwise.
    """
    unfinished = remote_values.copy()

    while len(unfinished) > 0:
        finished, unfinished = ray.wait(unfinished)

        # If a failure occurs the ObjectRef will be marked as finished.
        # Calling ray.get will expose the failure as a RayActorError.
        for object_ref in finished:
            # Everything in finished has either failed or completed
            # successfully.
            try:
                ray.get(object_ref)
            except RayActorError as exc:
                failed_actor_rank = remote_values.index(object_ref)
                logger.info(f"Worker {failed_actor_rank} has failed.")
                return False, exc
            except Exception as exc:
                # Other (e.g. training) errors should be directly raised
                failed_worker_rank = remote_values.index(object_ref)
                raise StartTracebackWithWorkerRank(
                    worker_rank=failed_worker_rank
                ) from exc

    return True, None


def get_address_and_port() -> Tuple[str, int]:
    """Returns the IP address and a free port on this node."""
    addr = ray.util.get_node_ip_address()
    port = find_free_port()

    return addr, port


def construct_path(path: Path, parent_path: Path) -> Path:
    """Constructs a path relative to a parent.

    Args:
        path: A relative or absolute path.
        parent_path: A relative path or absolute path.

    Returns: An absolute path.
    """
    if path.expanduser().is_absolute():
        return path.expanduser().resolve()
    else:
        return parent_path.joinpath(path).expanduser().resolve()


def update_env_vars(env_vars: Dict[str, Any]):
    """Updates the environment variables on this worker process.

    Args:
        env_vars: Environment variables to set.
    """
    sanitized = {k: str(v) for k, v in env_vars.items()}
    os.environ.update(sanitized)


def count_required_parameters(fn: Callable) -> int:
    """Counts the number of required parameters of a function.

    NOTE: *args counts as 1 required parameter.

    Examples
    --------

    >>> def fn(a, b, /, c, *args, d=1, e=2, **kwargs):
    ...     pass
    >>> count_required_parameters(fn)
    4

    >>> fn = lambda: 1
    >>> count_required_parameters(fn)
    0

    >>> def fn(config, a, b=1, c=2):
    ...     pass
    >>> from functools import partial
    >>> count_required_parameters(partial(fn, a=0))
    1
    """
    params = inspect.signature(fn).parameters.values()

    positional_param_kinds = {
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.VAR_POSITIONAL,
    }
    return len(
        [
            p
            for p in params
            if p.default == inspect.Parameter.empty and p.kind in positional_param_kinds
        ]
    )


def construct_train_func(
    train_func: Union[Callable[[], T], Callable[[Dict[str, Any]], T]],
    config: Optional[Dict[str, Any]],
    train_func_context: ContextManager,
    fn_arg_name: Optional[str] = "train_func",
    discard_returns: bool = False,
) -> Callable[[], T]:
    """Validates and constructs the training function to execute.
    Args:
        train_func: The training function to execute.
            This can either take in no arguments or a ``config`` dict.
        config (Optional[Dict]): Configurations to pass into
            ``train_func``. If None then an empty Dict will be created.
        train_func_context: Context manager for user's `train_func`, which executes
            backend-specific logic before and after the training function.
        fn_arg_name (Optional[str]): The name of training function to use for error
            messages.
        discard_returns: Whether to discard any returns from train_func or not.
    Returns:
        A valid training function.
    Raises:
        ValueError: if the input ``train_func`` is invalid.
    """
    num_required_params = count_required_parameters(train_func)

    if discard_returns:
        # Discard any returns from the function so that
        # BackendExecutor doesn't try to deserialize them.
        # Those returns are inaccesible with AIR anyway.
        @functools.wraps(train_func)
        def discard_return_wrapper(*args, **kwargs):
            try:
                train_func(*args, **kwargs)
            except Exception as e:
                raise StartTraceback from e

        wrapped_train_func = discard_return_wrapper
    else:
        wrapped_train_func = train_func

    if num_required_params > 1:
        err_msg = (
            f"{fn_arg_name} should take in 0 or 1 required arguments, but it accepts "
            f"{num_required_params} required arguments instead."
        )
        raise ValueError(err_msg)
    elif num_required_params == 1:
        config = {} if config is None else config

        @functools.wraps(wrapped_train_func)
        def train_fn():
            try:
                with train_func_context():
                    return wrapped_train_func(config)
            except Exception as e:
                raise StartTraceback from e

    else:  # num_params == 0

        @functools.wraps(wrapped_train_func)
        def train_fn():
            try:
                with train_func_context():
                    return wrapped_train_func()
            except Exception as e:
                raise StartTraceback from e

    return train_fn


class Singleton(abc.ABCMeta):
    """Singleton Abstract Base Class

    https://stackoverflow.com/questions/33364070/implementing
    -singleton-as-metaclass-but-for-abstract-classes
    """

    _instances = {}

    def __call__(cls, *args, **kwargs):
        if cls not in cls._instances:
            cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
        return cls._instances[cls]


class ActorWrapper:
    """Wraps an actor to provide same API as using the base class directly."""

    def __init__(self, actor: ActorHandle):
        self.actor = actor

    def __getattr__(self, item):
        # The below will fail if trying to access an attribute (not a method) from the
        # actor.
        actor_method = getattr(self.actor, item)
        return lambda *args, **kwargs: ray.get(actor_method.remote(*args, **kwargs))

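A minimal sketch (not part of the vendored file above) of what `construct_train_func` effectively produces for a one-argument training function; `noop_context` here is a hypothetical stand-in for the backend-specific `train_func_context`:

import contextlib

@contextlib.contextmanager
def noop_context():
    # Stand-in for a backend-specific context manager.
    yield

def train_func(config):
    return config["batch_size"] * 2

# A zero-argument wrapper roughly equivalent to what construct_train_func returns
# for a function with one required parameter and config={"batch_size": 16}.
def train_fn():
    with noop_context():
        return train_func({"batch_size": 16})

assert train_fn() == 32
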
.venv/lib/python3.11/site-packages/ray/train/_internal/worker_group.py
ADDED
@@ -0,0 +1,426 @@
import logging
import os
import socket
from collections import defaultdict
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union

import ray
from ray.actor import ActorHandle
from ray.air._internal.util import exception_cause, skip_exceptions
from ray.types import ObjectRef
from ray.util.placement_group import PlacementGroup

T = TypeVar("T")

logger = logging.getLogger(__name__)


class RayTrainWorker:
    """A class to execute arbitrary functions. Does not hold any state."""

    def __execute(self, func: Callable[..., T], *args, **kwargs) -> T:
        """Executes the input function and returns the output.

        Args:
            func: The function to execute.
            args, kwargs: The arguments to pass into func.
        """
        try:
            return func(*args, **kwargs)
        except Exception as e:
            skipped = skip_exceptions(e)
            raise skipped from exception_cause(skipped)


@dataclass
class WorkerMetadata:
    """Metadata for each worker/actor.

    This information is expected to stay the same throughout the lifetime of
    actor.

    Args:
        node_id: ID of the node this worker is on.
        node_ip: IP address of the node this worker is on.
        hostname: Hostname that this worker is on.
        resource_ids: Map of accelerator resources
            ("GPU", "neuron_cores", ..) to their IDs.
        pid: Process ID of this worker.
    """

    node_id: str
    node_ip: str
    hostname: str
    resource_ids: Dict[str, List[str]]
    pid: int


@dataclass
class Worker:
    """Class representing a Worker."""

    actor: ActorHandle
    metadata: WorkerMetadata


def create_executable_class(executable_cls: Optional[Type] = None) -> Type:
    """Create the executable class to use as the Ray actors."""
    if not executable_cls:
        return RayTrainWorker
    elif issubclass(executable_cls, RayTrainWorker):
        return executable_cls
    else:

        class _WrappedExecutable(executable_cls, RayTrainWorker):
            def __init__(self, *args, **kwargs):
                super().__init__(*args, **kwargs)

        return _WrappedExecutable


def construct_metadata() -> WorkerMetadata:
    """Creates metadata for this worker.

    This function is expected to be run on the actor.
    """
    node_id = ray.get_runtime_context().get_node_id()
    node_ip = ray.util.get_node_ip_address()
    hostname = socket.gethostname()
    accelerator_ids = ray.get_runtime_context().get_accelerator_ids()
    pid = os.getpid()

    return WorkerMetadata(
        node_id=node_id,
        node_ip=node_ip,
        hostname=hostname,
        resource_ids=accelerator_ids,
        pid=pid,
    )


class WorkerGroup:
    """Group of Ray Actors that can execute arbitrary functions.

    ``WorkerGroup`` launches Ray actors according to the given
    specification. It can then execute arbitrary Python functions in each of
    these workers.

    If not enough resources are available to launch the actors, the Ray
    cluster will automatically scale up if autoscaling is enabled.

    Args:
        num_workers: The number of workers (Ray actors) to launch.
            Defaults to 1.
        resources_per_worker (Optional[Dict[str, float]]):
            Dictionary specifying the resources that will be
            requested for each worker. Defaults to {"CPU": 1}.
        actor_cls (Optional[Type]): If specified use this class as the
            remote actors.
        remote_cls_args, remote_cls_kwargs: If ``remote_cls`` is provided,
            these args will be used for the worker initialization.
        placement_group (PlacementGroup|str): The placement group that workers
            should be created in. Defaults to "default" which will inherit the
            parent placement group (if child tasks should be captured).


    Example:

    .. code_block:: python

        worker_group = WorkerGroup(num_workers=2)
        output = worker_group.execute(lambda: 1)
        assert len(output) == 2
        assert all(o == 1 for o in output)
    """

    def __init__(
        self,
        num_workers: int = 1,
        resources_per_worker: Optional[Dict[str, float]] = None,
        actor_cls: Type = None,
        actor_cls_args: Optional[Tuple] = None,
        actor_cls_kwargs: Optional[Dict] = None,
        placement_group: Union[PlacementGroup, str] = "default",
    ):
        if resources_per_worker is None:
            resources_per_worker = {"CPU": 1}
        else:
            resources_per_worker = resources_per_worker.copy()

        if num_workers <= 0:
            raise ValueError(
                "The provided `num_workers` must be greater "
                f"than 0. Received num_workers={num_workers} "
                f"instead."
            )

        if any(v < 0 for v in resources_per_worker.values()):
            raise ValueError(
                "The number of resources per worker must not be negative. "
                f"Received resources_per_worker={resources_per_worker}."
            )

        if (actor_cls_args or actor_cls_kwargs) and not actor_cls:
            raise ValueError(
                "`actor_cls_args` or `actor_class_kwargs` are "
                "passed in but no `actor_cls` is passed in."
            )

        self.num_workers = num_workers
        self.num_cpus_per_worker = resources_per_worker.pop("CPU", 0)
        self.num_gpus_per_worker = resources_per_worker.pop("GPU", 0)
        self.memory_per_worker = resources_per_worker.pop("memory", 0)
        self.workers = []
        self._base_cls = create_executable_class(actor_cls)
        assert issubclass(self._base_cls, RayTrainWorker)

        self._actor_cls_args = actor_cls_args or []
        self._actor_cls_kwargs = actor_cls_kwargs or {}

        self._placement_group = placement_group

        # TODO(matt): Validate resources. Fast-fail if it is impossible to
        # handle the request, rather than hang indefinitely.
        self._remote_cls = ray.remote(
            num_cpus=self.num_cpus_per_worker,
            num_gpus=self.num_gpus_per_worker,
            memory=self.memory_per_worker,
            resources=resources_per_worker,
        )(self._base_cls)
        self.start()

    def start(self):
        """Starts all the workers in this worker group."""
        if self.workers and len(self.workers) > 0:
            raise RuntimeError(
                "The workers have already been started. "
                "Please call `shutdown` first if you want to "
                "restart them."
            )

        logger.debug(f"Starting {self.num_workers} workers.")
        self.add_workers(self.num_workers)
        logger.debug(f"{len(self.workers)} workers have successfully started.")

    def shutdown(self, patience_s: float = 5):
        """Shutdown all the workers in this worker group.

        Args:
            patience_s: Attempt a graceful shutdown
                of the workers for this many seconds. Fallback to force kill
                if graceful shutdown is not complete after this time. If
                this is less than or equal to 0, immediately force kill all
                workers.
        """
        logger.debug(f"Shutting down {len(self.workers)} workers.")
        if patience_s <= 0:
            for worker in self.workers:
                ray.kill(worker.actor)
        else:
            done_refs = [w.actor.__ray_terminate__.remote() for w in self.workers]
            # Wait for actors to die gracefully.
            done, not_done = ray.wait(done_refs, timeout=patience_s)
            if not_done:
                logger.debug("Graceful termination failed. Falling back to force kill.")
                # If all actors are not able to die gracefully, then kill them.
                for worker in self.workers:
                    ray.kill(worker.actor)

        logger.debug("Shutdown successful.")
        self.workers = []

    def execute_async(self, func: Callable[..., T], *args, **kwargs) -> List[ObjectRef]:
        """Execute ``func`` on each worker and return the futures.

        Args:
            func: A function to call on each worker.
            args, kwargs: Passed directly into func.

        Returns:
            (List[ObjectRef]) A list of ``ObjectRef`` representing the
            output of ``func`` from each worker. The order is the same
            as ``self.workers``.

        """
        if len(self.workers) <= 0:
            raise RuntimeError(
                "There are no active workers. This worker "
                "group has most likely been shut down. Please"
                "create a new WorkerGroup or restart this one."
            )

        return [
            w.actor._RayTrainWorker__execute.options(
                name=f"_RayTrainWorker__execute.{func.__name__}"
            ).remote(func, *args, **kwargs)
            for w in self.workers
        ]

    def execute(self, func: Callable[..., T], *args, **kwargs) -> List[T]:
        """Execute ``func`` on each worker and return the outputs of ``func``.

        Args:
            func: A function to call on each worker.
            args, kwargs: Passed directly into func.

        Returns:
            (List[T]) A list containing the output of ``func`` from each
            worker. The order is the same as ``self.workers``.

        """
        return ray.get(self.execute_async(func, *args, **kwargs))

    def execute_single_async(
        self, worker_index: int, func: Callable[..., T], *args, **kwargs
    ) -> ObjectRef:
        """Execute ``func`` on worker ``worker_index`` and return futures.

        Args:
            worker_index: The index to execute func on.
            func: A function to call on the first worker.
            args, kwargs: Passed directly into func.

        Returns:
            (ObjectRef) An ObjectRef representing the output of func.

        """
        if worker_index >= len(self.workers):
            raise ValueError(
                f"The provided worker_index {worker_index} is "
                f"not valid for {self.num_workers} workers."
            )
        return (
            self.workers[worker_index]
            .actor._RayTrainWorker__execute.options(
                name=f"_RayTrainWorker__execute.{func.__name__}"
            )
            .remote(func, *args, **kwargs)
        )

    def execute_single(
        self, worker_index: int, func: Callable[..., T], *args, **kwargs
    ) -> T:
        """Execute ``func`` on worker with index ``worker_index``.

        Args:
            worker_index: The index to execute func on.
            func: A function to call on the first worker.
            args, kwargs: Passed directly into func.

        Returns:
            (T) The output of func.

        """

        return ray.get(self.execute_single_async(worker_index, func, *args, **kwargs))

    def remove_workers(self, worker_indexes: List[int]):
        """Removes the workers with the specified indexes.

        The removed workers will go out of scope and their actor processes
        will be terminated.

        Args:
            worker_indexes (List[int]): The indexes of the workers to remove.
        """
        new_workers = []
        for i in range(len(self.workers)):
            if i not in worker_indexes:
                new_workers.append(self.workers[i])
        self.workers = new_workers

    def add_workers(self, num_workers: int):
        """Adds ``num_workers`` to this WorkerGroup.

        Note: Adding workers when the cluster/placement group is at capacity
        may lead to undefined hanging behavior. If you are attempting to
        replace existing workers in the WorkerGroup, remove_workers() should
        be called first.

        Args:
            num_workers: The number of workers to add.
        """
        new_actors = []
        new_actor_metadata = []
        for _ in range(num_workers):
            actor = self._remote_cls.options(
                placement_group=self._placement_group
            ).remote(*self._actor_cls_args, **self._actor_cls_kwargs)
            new_actors.append(actor)
            new_actor_metadata.append(
                actor._RayTrainWorker__execute.options(
                    name="_RayTrainWorker__execute.construct_metadata"
                ).remote(construct_metadata)
            )

        # Get metadata from all actors.
        metadata = ray.get(new_actor_metadata)

        for i in range(len(new_actors)):
            self.workers.append(Worker(actor=new_actors[i], metadata=metadata[i]))

    def sort_workers_by_node_id_and_gpu_id(self, _first_node_id: Optional[str] = None):
        """Reorder the workers by their node id and the lowest GPU id.

        This is useful for collocating workers on the same node.

        Example:
            Given workers with the following attributes:
                worker_0: node_id=1, gpu_ids=[1]
                worker_1: node_id=0, gpu_ids=[0]
                worker_2: node_id=1, gpu_ids=[0]
                worker_3: node_id=0, gpu_ids=[1]

            The function will perform the following steps:
            1. Group by node ID:
                node_id=0: worker_1, worker_3
                node_id=1: worker_0, worker_2

            2. Sort each group by GPU ID:
                node_id=0: worker_1 (gpu_id=0), worker_3 (gpu_id=1)
                node_id=1: worker_2 (gpu_id=0), worker_0 (gpu_id=1)

            Resulting in the order: [worker_1, worker_3, worker_2, worker_0]

        Args:
            _first_node_id: The first ID to group by.
                Set this to the node ID of the trainer coordinator to ensure that the
                rank 0 worker is on the same node, allowing additional resources to
                be specified for rank 0 workers via
                `ScalingConfig(trainer_resources=)`.
        """
        node_id_to_workers = defaultdict(list)

        if _first_node_id is not None:
            node_id_to_workers[_first_node_id] = []

        for worker in self.workers:
            node_id_to_workers[worker.metadata.node_id].append(worker)

        # Sort workers on the same node by the lowest GPU id
        # More details: https://github.com/ray-project/ray/issues/40803
        def get_lowest_gpu_id(worker) -> int:
            gpu_ids = worker.metadata.resource_ids.get("GPU", [])
            # If there are no GPU IDs, return 0 as a default
            if not gpu_ids:
                return 0

            # Attempt to convert GPU IDs to integers and find the minimum ID.
            # Fallback to return the minimum string-based ID
            try:
                return min(int(gpu_id) for gpu_id in gpu_ids)
            except ValueError:
                return min(gpu_ids)

        for node_id in node_id_to_workers:
            node_id_to_workers[node_id].sort(key=get_lowest_gpu_id)

        sorted_workers = []
        for workers in node_id_to_workers.values():
            sorted_workers.extend(workers)

        self.workers = sorted_workers

    def __len__(self):
        return len(self.workers)

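A minimal sketch (not part of the vendored file above) of the group-then-sort logic described in `sort_workers_by_node_id_and_gpu_id`, reduced to plain dictionaries; the node and GPU IDs below are hypothetical:

from collections import defaultdict

workers = [
    {"name": "worker_0", "node_id": "1", "gpu_ids": ["1"]},
    {"name": "worker_1", "node_id": "0", "gpu_ids": ["0"]},
    {"name": "worker_2", "node_id": "1", "gpu_ids": ["0"]},
    {"name": "worker_3", "node_id": "0", "gpu_ids": ["1"]},
]

# 1. Group by node ID (insertion order of node IDs is preserved).
node_id_to_workers = defaultdict(list)
for w in workers:
    node_id_to_workers[w["node_id"]].append(w)

# 2. Sort each group by the lowest GPU ID.
for group in node_id_to_workers.values():
    group.sort(key=lambda w: min(int(gpu_id) for gpu_id in w["gpu_ids"]))

ordered = [w["name"] for group in node_id_to_workers.values() for w in group]
print(ordered)  # each node's workers appear together, lowest GPU ID first
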
.venv/lib/python3.11/site-packages/ray/train/horovod/__init__.py
ADDED
@@ -0,0 +1,22 @@
| 1 |
+
# isort: off
|
| 2 |
+
try:
|
| 3 |
+
import horovod # noqa: F401
|
| 4 |
+
except ModuleNotFoundError:
|
| 5 |
+
raise ModuleNotFoundError(
|
| 6 |
+
"Horovod isn't installed. To install Horovod with PyTorch support, run 'pip "
|
| 7 |
+
"install 'horovod[pytorch]''. To install Horovod with TensorFlow support, "
|
| 8 |
+
"run 'pip install 'horovod[tensorflow]''."
|
| 9 |
+
)
|
| 10 |
+
# isort: on
|
| 11 |
+
|
| 12 |
+
from ray.train.horovod.config import HorovodConfig
|
| 13 |
+
from ray.train.horovod.horovod_trainer import HorovodTrainer
|
| 14 |
+
from ray.train.v2._internal.constants import is_v2_enabled
|
| 15 |
+
|
| 16 |
+
if is_v2_enabled():
|
| 17 |
+
from ray.train.v2.horovod.horovod_trainer import HorovodTrainer # noqa: F811
|
| 18 |
+
|
| 19 |
+
__all__ = ["HorovodConfig", "HorovodTrainer"]
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
# DO NOT ADD ANYTHING AFTER THIS LINE.
|
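The `__init__` above conditionally re-binds the public `HorovodTrainer` name when `is_v2_enabled()` is true. The sketch below shows that feature-flag re-export pattern in isolation with hypothetical names (`MYLIB_TRAIN_V2`, `TrainerV1`, `TrainerV2`); it is illustrative only and does not reflect how Ray's flag is actually read.

# Sketch of the conditional re-export pattern used above: a feature flag decides
# which implementation a public name points to at import time.
import os


def _v2_enabled() -> bool:
    # Hypothetical flag, analogous in spirit to is_v2_enabled().
    return os.environ.get("MYLIB_TRAIN_V2", "0") == "1"


class TrainerV1:
    def fit(self) -> str:
        return "running v1 implementation"


class TrainerV2:
    def fit(self) -> str:
        return "running v2 implementation"


# Default binding, then override when the flag is on -- mirroring
# `from ray.train.v2... import HorovodTrainer  # noqa: F811` above.
Trainer = TrainerV1
if _v2_enabled():
    Trainer = TrainerV2

__all__ = ["Trainer"]

print(Trainer().fit())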
.venv/lib/python3.11/site-packages/ray/train/horovod/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (930 Bytes)

.venv/lib/python3.11/site-packages/ray/train/horovod/__pycache__/config.cpython-311.pyc
ADDED
Binary file (9.6 kB)

.venv/lib/python3.11/site-packages/ray/train/horovod/__pycache__/horovod_trainer.cpython-311.pyc
ADDED
Binary file (8.93 kB)
.venv/lib/python3.11/site-packages/ray/train/horovod/config.py
ADDED
@@ -0,0 +1,159 @@
+import os
+from dataclasses import dataclass
+from typing import Optional, Set
+
+from horovod.ray.runner import Coordinator
+from horovod.ray.utils import detect_nics, nics_to_env_var
+from horovod.runner.common.util import secret, timeout
+
+import ray
+from ray.train._internal.utils import update_env_vars
+from ray.train._internal.worker_group import Worker, WorkerGroup
+from ray.train.backend import Backend, BackendConfig
+from ray.util import PublicAPI
+
+
+@PublicAPI(stability="beta")
+@dataclass
+class HorovodConfig(BackendConfig):
+    """Configurations for Horovod setup.
+
+    See https://github.com/horovod/horovod/blob/master/horovod/runner/common/util/settings.py # noqa: E501
+
+    Args:
+        nics (Optional[Set[str]]): Network interfaces that can be used for
+            communication.
+        verbose: Horovod logging verbosity.
+        key (Optional[str]): Secret used for communication between workers.
+        ssh_port (Optional[int]): Port for SSH server running on worker nodes.
+        ssh_identity_file (Optional[str]): Path to the identity file to
+            ssh into different hosts on the cluster.
+        ssh_str (Optional[str]): CAUTION WHEN USING THIS. Private key
+            file contents. Writes the private key to ssh_identity_file.
+        timeout_s: Timeout parameter for Gloo rendezvous.
+        placement_group_timeout_s: Timeout parameter for Ray
+            Placement Group creation. Currently unused.
+    """
+
+    nics: Optional[Set[str]] = None
+    verbose: int = 1
+    key: Optional[str] = None
+    ssh_port: Optional[int] = None
+    ssh_identity_file: Optional[str] = None
+    ssh_str: Optional[str] = None
+    timeout_s: int = 300
+    placement_group_timeout_s: int = 100
+
+    @property
+    def start_timeout(self):
+        return timeout.Timeout(
+            self.timeout_s,
+            message="Timed out waiting for {activity}. Please "
+            "check connectivity between servers. You "
+            "may need to increase the --start-timeout "
+            "parameter if you have too many servers.",
+        )
+
+    def __post_init__(self):
+        if self.ssh_str and not os.path.exists(self.ssh_identity_file):
+            with open(self.ssh_identity_file, "w") as f:
+                os.chmod(self.ssh_identity_file, 0o600)
+                f.write(self.ssh_str)
+
+        if self.key is None:
+            self.key = secret.make_secret_key()
+
+    @property
+    def backend_cls(self):
+        return _HorovodBackend
+
+
+class _HorovodBackend(Backend):
+    share_cuda_visible_devices: bool = True
+
+    def on_start(self, worker_group: WorkerGroup, backend_config: HorovodConfig):
+        # TODO(matt): Implement placement group strategies in BackendExecutor.
+
+        # Initialize workers with Horovod environment variables
+        setup_futures = []
+        for rank in range(len(worker_group)):
+            worker_node_id = worker_group.workers[rank].metadata.node_id
+            setup_futures.append(
+                worker_group.execute_single_async(
+                    rank,
+                    _init_env_vars,
+                    rank,
+                    len(worker_group),
+                    worker_node_id,
+                )
+            )
+        ray.get(setup_futures)
+
+        # Use Horovod Ray Coordinator
+        # backend_config as settings
+        self.coordinator = Coordinator(backend_config)
+
+        # Get all the hostnames of all workers
+        node_ids = [w.metadata.node_id for w in worker_group.workers]
+        hostnames = [w.metadata.hostname for w in worker_group.workers]
+        # Register each hostname to the coordinator. Assumes the hostname
+        # ordering is the same.
+        for rank, (hostname, node_id) in enumerate(zip(hostnames, node_ids)):
+            self.coordinator.register(hostname, node_id, rank)
+        all_info = self.coordinator.finalize_registration()
+
+        setup_futures = []
+        for rank, local_cross_env_var in all_info.items():
+            setup_futures.append(
+                worker_group.execute_single_async(
+                    rank, update_env_vars, local_cross_env_var
+                )
+            )
+        ray.get(setup_futures)
+
+        coordinator_envs = self.coordinator.establish_rendezvous()
+
+        # Get one worker from each host/node.
+        node_worker_indexes = [node_ids.index(node_id) for node_id in set(node_ids)]
+        node_workers = [
+            _HorovodWorkerWrapper(worker_group.workers[worker_index])
+            for worker_index in node_worker_indexes
+        ]
+        assert len(node_workers) == len(self.coordinator.hostnames)
+
+        nics = detect_nics(
+            backend_config,
+            all_host_names=list(self.coordinator.hostnames),
+            node_workers=node_workers,
+        )
+        coordinator_envs.update(nics_to_env_var(nics))
+
+        worker_group.execute(update_env_vars, coordinator_envs)
+
+
+def _init_env_vars(world_rank: int, world_size: int, node_id: str):
+    """Initialize Horovod environment variables."""
+    os.environ["HOROVOD_HOSTNAME"] = node_id
+    os.environ["HOROVOD_RANK"] = str(world_rank)
+    os.environ["HOROVOD_SIZE"] = str(world_size)
+
+
+# TODO(tgaddair): temporary workaround for Horovod's worker discovery logic,
+# which requires passing in an extra parameter as part of the RayExecutor
+# API. This will be removed in the future as we migrate more of the
+# RayExecutor utils into Ray Train.
+# See: https://github.com/horovod/horovod/blob/v0.23.0/horovod/ray/driver_service.py#L9 # noqa: E501
+@dataclass
+class _HorovodWorkerWrapper:
+    w: Worker
+
+    @property
+    def execute(self):
+        w = self.w
+
+        class ExecuteHandle:
+            def remote(self, func, *args, **kwargs):
+                _ = None
+                return w.actor._RayTrainWorker__execute.remote(func, _, *args, **kwargs)
+
+        return ExecuteHandle()
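To make the configuration surface above concrete, here is a hedged usage sketch: it constructs a `HorovodConfig` with a longer rendezvous timeout and higher verbosity and hands it to `HorovodTrainer` (defined in the next file), which forwards it as the backend config. The field names come from the dataclass above; `train_fn` is a placeholder, and the snippet assumes Horovod and a running Ray cluster are available.

# Illustrative usage only (not part of the diff above).
from ray.train import ScalingConfig
from ray.train.horovod import HorovodConfig, HorovodTrainer


def train_fn():
    # Placeholder training loop; see the HorovodTrainer docstring example
    # in the next file for a complete loop.
    pass


horovod_config = HorovodConfig(
    timeout_s=600,          # allow more time for the Gloo rendezvous
    verbose=2,              # raise Horovod logging verbosity
    # nics={"eth0"},        # optionally pin the network interface(s) to use
)

trainer = HorovodTrainer(
    train_loop_per_worker=train_fn,
    horovod_config=horovod_config,
    scaling_config=ScalingConfig(num_workers=2, use_gpu=False),
)
# result = trainer.fit()  # launches the workers and runs the Horovod rendezvous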
.venv/lib/python3.11/site-packages/ray/train/horovod/horovod_trainer.py
ADDED
@@ -0,0 +1,202 @@
+from typing import Any, Callable, Dict, Optional, Union
+
+from ray.air.config import RunConfig, ScalingConfig
+from ray.train import Checkpoint, DataConfig
+from ray.train.data_parallel_trainer import DataParallelTrainer
+from ray.train.horovod.config import HorovodConfig
+from ray.train.trainer import GenDataset
+from ray.util.annotations import PublicAPI
+
+
+@PublicAPI(stability="beta")
+class HorovodTrainer(DataParallelTrainer):
+    """A Trainer for data parallel Horovod training.
+
+    This Trainer runs the function ``train_loop_per_worker`` on multiple Ray
+    Actors. These actors already have the necessary Horovod setup configured
+    for distributed Horovod training.
+
+    The ``train_loop_per_worker`` function is expected to take in either 0 or 1
+    arguments:
+
+    .. testcode::
+
+        def train_loop_per_worker():
+            ...
+
+    .. testcode::
+
+        def train_loop_per_worker(config: Dict):
+            ...
+
+    If ``train_loop_per_worker`` accepts an argument, then
+    ``train_loop_config`` will be passed in as the argument. This is useful if you
+    want to tune the values in ``train_loop_config`` as hyperparameters.
+
+    If the ``datasets`` dict contains a training dataset (denoted by
+    the "train" key), then it will be split into multiple dataset
+    shards that can then be accessed by ``ray.train.get_dataset_shard("train")`` inside
+    ``train_loop_per_worker``. All the other datasets will not be split and
+    ``ray.train.get_dataset_shard(...)`` will return the entire Dataset.
+
+    Inside the ``train_loop_per_worker`` function, you can use any of the
+    :ref:`Ray Train loop methods <train-loop-api>`.
+
+    .. testcode::
+
+        from ray import train
+
+        def train_loop_per_worker():
+            # Report intermediate results for callbacks or logging and
+            # checkpoint data.
+            train.report(...)
+
+            # Returns dict of last saved checkpoint.
+            train.get_checkpoint()
+
+            # Returns the Dataset shard for the given key.
+            train.get_dataset_shard("my_dataset")
+
+            # Returns the total number of workers executing training.
+            train.get_context().get_world_size()
+
+            # Returns the rank of this worker.
+            train.get_context().get_world_rank()
+
+            # Returns the rank of the worker on the current node.
+            train.get_context().get_local_rank()
+
+    Any returns from the ``train_loop_per_worker`` will be discarded and not
+    used or persisted anywhere.
+
+    You could use ``TensorflowPredictor`` or ``TorchPredictor`` in conjunction with
+    HorovodTrainer. You must save the model under the "model" kwarg in the
+    ``Checkpoint`` passed to ``train.report()``, so that it can be used by
+    corresponding predictors.
+
+    Example:
+
+
+        .. testcode::
+            :skipif: True
+
+            import os
+            import tempfile
+
+            import ray
+            import horovod.torch as hvd
+            import torch
+            import torch.nn as nn
+
+            from ray import train
+            import ray.train.torch  # Need this to use `train.torch.get_device()`
+            from ray.train import Checkpoint, ScalingConfig
+            from ray.train.horovod import HorovodTrainer
+
+            # If using GPUs, set this to True.
+            use_gpu = False
+
+            input_size = 1
+            layer_size = 15
+            output_size = 1
+            num_epochs = 3
+
+            class NeuralNetwork(nn.Module):
+                def __init__(self):
+                    super(NeuralNetwork, self).__init__()
+                    self.layer1 = nn.Linear(input_size, layer_size)
+                    self.relu = nn.ReLU()
+                    self.layer2 = nn.Linear(layer_size, output_size)
+                def forward(self, input):
+                    return self.layer2(self.relu(self.layer1(input)))
+
+            def train_loop_per_worker():
+                hvd.init()
+                dataset_shard = train.get_dataset_shard("train")
+                model = NeuralNetwork()
+                device = train.torch.get_device()
+                model.to(device)
+                loss_fn = nn.MSELoss()
+                lr_scaler = 1
+                optimizer = torch.optim.SGD(model.parameters(), lr=0.1 * lr_scaler)
+                # Horovod: wrap optimizer with DistributedOptimizer.
+                optimizer = hvd.DistributedOptimizer(
+                    optimizer,
+                    named_parameters=model.named_parameters(),
+                    op=hvd.Average,
+                )
+                for epoch in range(num_epochs):
+                    model.train()
+                    for batch in dataset_shard.iter_torch_batches(
+                        batch_size=32, dtypes=torch.float
+                    ):
+                        inputs, labels = torch.unsqueeze(batch["x"], 1), batch["y"]
+                        outputs = model(inputs)
+                        loss = loss_fn(outputs, labels)
+                        optimizer.zero_grad()
+                        loss.backward()
+                        optimizer.step()
+                        print(f"epoch: {epoch}, loss: {loss.item()}")
+
+                    # Save a model checkpoint at the end of each epoch
+                    with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
+                        ckpt_path = os.path.join(temp_checkpoint_dir, "model.pt")
+                        torch.save(model.state_dict(), ckpt_path)
+                        train.report(
+                            {"loss": loss.item(), "epoch": epoch},
+                            checkpoint=Checkpoint.from_directory(temp_checkpoint_dir),
+                        )
+
+            train_dataset = ray.data.from_items([{"x": x, "y": x + 1} for x in range(32)])
+            scaling_config = ScalingConfig(num_workers=3, use_gpu=use_gpu)
+            trainer = HorovodTrainer(
+                train_loop_per_worker=train_loop_per_worker,
+                scaling_config=scaling_config,
+                datasets={"train": train_dataset},
+            )
+            result = trainer.fit()
+
+    Args:
+        train_loop_per_worker: The training function to execute.
+            This can either take in no arguments or a ``config`` dict.
+        train_loop_config: Configurations to pass into
+            ``train_loop_per_worker`` if it accepts an argument.
+        horovod_config: Configuration for setting up the Horovod backend.
+            If set to None, use the default configuration. This replaces the
+            ``backend_config`` arg of ``DataParallelTrainer``.
+        scaling_config: Configuration for how to scale data parallel training.
+        dataset_config: Configuration for dataset ingest.
+        run_config: Configuration for the execution of the training run.
+        datasets: Any Datasets to use for training. Use
+            the key "train" to denote which dataset is the training
+            dataset.
+        resume_from_checkpoint: A checkpoint to resume training from.
+        metadata: Dict that should be made available via
+            `ray.train.get_context().get_metadata()` and in `checkpoint.get_metadata()`
+            for checkpoints saved from this Trainer. Must be JSON-serializable.
+    """
+
+    def __init__(
+        self,
+        train_loop_per_worker: Union[Callable[[], None], Callable[[Dict], None]],
+        *,
+        train_loop_config: Optional[Dict] = None,
+        horovod_config: Optional[HorovodConfig] = None,
+        scaling_config: Optional[ScalingConfig] = None,
+        dataset_config: Optional[DataConfig] = None,
+        run_config: Optional[RunConfig] = None,
+        datasets: Optional[Dict[str, GenDataset]] = None,
+        metadata: Optional[Dict[str, Any]] = None,
+        resume_from_checkpoint: Optional[Checkpoint] = None,
+    ):
+        super().__init__(
+            train_loop_per_worker=train_loop_per_worker,
+            train_loop_config=train_loop_config,
+            backend_config=horovod_config or HorovodConfig(),
+            scaling_config=scaling_config,
+            dataset_config=dataset_config,
+            run_config=run_config,
+            datasets=datasets,
+            resume_from_checkpoint=resume_from_checkpoint,
+            metadata=metadata,
+        )
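The docstring above notes that a one-argument `train_loop_per_worker` receives `train_loop_config`. Below is a minimal sketch of that flow, assuming only the public API shown in the diff; the loop body and the config values are placeholders.

# Illustrative sketch (not part of the diff above): train_loop_config is passed
# to a one-argument train_loop_per_worker, which is how hyperparameters reach
# each worker.
from ray.train import ScalingConfig
from ray.train.horovod import HorovodTrainer


def train_loop_per_worker(config: dict):
    # `config` is exactly the dict given as train_loop_config below.
    lr = config["lr"]
    num_epochs = config["num_epochs"]
    for epoch in range(num_epochs):
        # ... training step using lr ...
        pass


trainer = HorovodTrainer(
    train_loop_per_worker=train_loop_per_worker,
    train_loop_config={"lr": 0.1, "num_epochs": 3},
    scaling_config=ScalingConfig(num_workers=2, use_gpu=False),
)
# result = trainer.fit()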
.venv/lib/python3.11/site-packages/ray/train/lightning/__init__.py
ADDED
@@ -0,0 +1,39 @@
+# isort: off
+try:
+    import lightning  # noqa: F401
+except ModuleNotFoundError:
+    try:
+        import pytorch_lightning  # noqa: F401
+    except ModuleNotFoundError:
+        raise ModuleNotFoundError(
+            "PyTorch Lightning isn't installed. To install PyTorch Lightning, "
+            "please run 'pip install lightning'"
+        )
+# isort: on
+
+from ray.train.lightning._lightning_utils import (
+    RayDDPStrategy,
+    RayDeepSpeedStrategy,
+    RayFSDPStrategy,
+    RayLightningEnvironment,
+    RayTrainReportCallback,
+    prepare_trainer,
+)
+from ray.train.v2._internal.constants import is_v2_enabled
+
+if is_v2_enabled():
+    from ray.train.v2.lightning.lightning_utils import (  # noqa: F811
+        RayTrainReportCallback,
+    )
+
+__all__ = [
+    "prepare_trainer",
+    "RayDDPStrategy",
+    "RayFSDPStrategy",
+    "RayDeepSpeedStrategy",
+    "RayLightningEnvironment",
+    "RayTrainReportCallback",
+]
+
+
+# DO NOT ADD ANYTHING AFTER THIS LINE.
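For context on how the exports above are typically combined, here is a sketch of a per-worker training function following the usual Ray Train + Lightning pattern: `RayDDPStrategy` as the strategy, `RayLightningEnvironment` as a plugin, `RayTrainReportCallback` for reporting, and `prepare_trainer` applied before `fit()`. `TinyModule` and the synthetic dataset are placeholders, and the function is meant to run inside a Ray Train worker (for example launched via `TorchTrainer`), not as a standalone script.

# Illustrative pattern sketch (not part of the diff above).
import lightning.pytorch as pl
import torch
from torch.utils.data import DataLoader, TensorDataset

from ray.train.lightning import (
    RayDDPStrategy,
    RayLightningEnvironment,
    RayTrainReportCallback,
    prepare_trainer,
)


class TinyModule(pl.LightningModule):
    # Minimal placeholder LightningModule for the sketch.
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(1, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


def train_loop_per_worker():
    x = torch.arange(32, dtype=torch.float32).reshape(-1, 1)
    loader = DataLoader(TensorDataset(x, x + 1.0), batch_size=8)

    trainer = pl.Trainer(
        max_epochs=1,
        strategy=RayDDPStrategy(),            # Ray-aware DDP strategy
        plugins=[RayLightningEnvironment()],  # cluster environment from Ray Train
        callbacks=[RayTrainReportCallback()], # reports metrics/checkpoints to Ray
        enable_progress_bar=False,
    )
    trainer = prepare_trainer(trainer)        # final validation/patching by Ray
    trainer.fit(TinyModule(), train_dataloaders=loader)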
.venv/lib/python3.11/site-packages/ray/train/lightning/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (1.16 kB)