Add files using upload-large-folder tool
This view is limited to 50 files because it contains too many changes; see the raw diff for the complete change set.
- .venv/lib/python3.11/site-packages/ray/train/__init__.py +90 -0
- .venv/lib/python3.11/site-packages/ray/train/_checkpoint.py +424 -0
- .venv/lib/python3.11/site-packages/ray/train/backend.py +59 -0
- .venv/lib/python3.11/site-packages/ray/train/base_trainer.py +827 -0
- .venv/lib/python3.11/site-packages/ray/train/constants.py +118 -0
- .venv/lib/python3.11/site-packages/ray/train/context.py +139 -0
- .venv/lib/python3.11/site-packages/ray/train/data_parallel_trainer.py +587 -0
- .venv/lib/python3.11/site-packages/ray/train/error.py +6 -0
- .venv/lib/python3.11/site-packages/ray/train/examples/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/ray/train/examples/mlflow_simple_example.py +55 -0
- .venv/lib/python3.11/site-packages/ray/train/examples/tf/tune_tensorflow_autoencoder_example.py +77 -0
- .venv/lib/python3.11/site-packages/ray/train/huggingface/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/ray/train/huggingface/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/train/huggingface/transformers/__init__.py +12 -0
- .venv/lib/python3.11/site-packages/ray/train/huggingface/transformers/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/train/huggingface/transformers/__pycache__/_transformers_utils.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/train/huggingface/transformers/_transformers_utils.py +143 -0
- .venv/lib/python3.11/site-packages/ray/train/lightgbm/__init__.py +18 -0
- .venv/lib/python3.11/site-packages/ray/train/lightgbm/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/train/lightgbm/__pycache__/_lightgbm_utils.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/train/lightgbm/__pycache__/config.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/train/lightgbm/__pycache__/lightgbm_checkpoint.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/train/lightgbm/__pycache__/lightgbm_predictor.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/train/lightgbm/__pycache__/lightgbm_trainer.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/train/lightgbm/__pycache__/v2.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/train/lightgbm/_lightgbm_utils.py +170 -0
- .venv/lib/python3.11/site-packages/ray/train/lightgbm/config.py +89 -0
- .venv/lib/python3.11/site-packages/ray/train/lightgbm/lightgbm_checkpoint.py +70 -0
- .venv/lib/python3.11/site-packages/ray/train/lightgbm/lightgbm_predictor.py +152 -0
- .venv/lib/python3.11/site-packages/ray/train/lightgbm/lightgbm_trainer.py +221 -0
- .venv/lib/python3.11/site-packages/ray/train/lightgbm/v2.py +132 -0
- .venv/lib/python3.11/site-packages/ray/train/predictor.py +254 -0
- .venv/lib/python3.11/site-packages/ray/train/session.py +0 -0
- .venv/lib/python3.11/site-packages/ray/train/trainer.py +194 -0
- .venv/lib/python3.11/site-packages/ray/train/utils.py +19 -0
- .venv/lib/python3.11/site-packages/ray/train/xgboost/__init__.py +20 -0
- .venv/lib/python3.11/site-packages/ray/train/xgboost/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/train/xgboost/__pycache__/_xgboost_utils.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/train/xgboost/__pycache__/config.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/train/xgboost/__pycache__/v2.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/train/xgboost/__pycache__/xgboost_checkpoint.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/train/xgboost/__pycache__/xgboost_predictor.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/train/xgboost/__pycache__/xgboost_trainer.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/train/xgboost/_xgboost_utils.py +210 -0
- .venv/lib/python3.11/site-packages/ray/train/xgboost/config.py +202 -0
- .venv/lib/python3.11/site-packages/ray/train/xgboost/v2.py +133 -0
- .venv/lib/python3.11/site-packages/ray/train/xgboost/xgboost_checkpoint.py +75 -0
- .venv/lib/python3.11/site-packages/ray/train/xgboost/xgboost_predictor.py +160 -0
- .venv/lib/python3.11/site-packages/ray/train/xgboost/xgboost_trainer.py +222 -0
- .venv/lib/python3.11/site-packages/ray/tune/search/__pycache__/__init__.cpython-311.pyc +0 -0
.venv/lib/python3.11/site-packages/ray/train/__init__.py
ADDED
@@ -0,0 +1,90 @@
+# Try import ray[train] core requirements (defined in setup.py)
+# isort: off
+try:
+    import fsspec  # noqa: F401
+    import pandas  # noqa: F401
+    import pyarrow  # noqa: F401
+    import requests  # noqa: F401
+except ImportError as exc:
+    raise ImportError(
+        "Can't import ray.train as some dependencies are missing. "
+        'Run `pip install "ray[train]"` to fix.'
+    ) from exc
+# isort: on
+
+
+from ray._private.usage import usage_lib
+from ray.air.config import CheckpointConfig, FailureConfig, RunConfig, ScalingConfig
+from ray.air.result import Result
+
+# Import this first so it can be used in other modules
+from ray.train._checkpoint import Checkpoint
+from ray.train._internal.data_config import DataConfig
+from ray.train._internal.session import get_checkpoint, get_dataset_shard, report
+from ray.train._internal.syncer import SyncConfig
+from ray.train.backend import BackendConfig
+from ray.train.constants import TRAIN_DATASET_KEY
+from ray.train.context import get_context
+from ray.train.trainer import TrainingIterator
+from ray.train.v2._internal.constants import is_v2_enabled
+
+if is_v2_enabled():
+    from ray.train.v2.api.callback import UserCallback  # noqa: F811
+    from ray.train.v2.api.config import (  # noqa: F811
+        FailureConfig,
+        RunConfig,
+        ScalingConfig,
+    )
+    from ray.train.v2.api.result import Result  # noqa: F811
+    from ray.train.v2.api.train_fn_utils import (  # noqa: F811
+        get_checkpoint,
+        get_context,
+        get_dataset_shard,
+        report,
+    )
+
+
+usage_lib.record_library_usage("train")
+
+Checkpoint.__module__ = "ray.train"
+
+__all__ = [
+    "get_checkpoint",
+    "get_context",
+    "get_dataset_shard",
+    "report",
+    "BackendConfig",
+    "Checkpoint",
+    "CheckpointConfig",
+    "DataConfig",
+    "FailureConfig",
+    "Result",
+    "RunConfig",
+    "ScalingConfig",
+    "SyncConfig",
+    "TrainingIterator",
+    "TRAIN_DATASET_KEY",
+]
+
+get_checkpoint.__module__ = "ray.train"
+get_context.__module__ = "ray.train"
+get_dataset_shard.__module__ = "ray.train"
+report.__module__ = "ray.train"
+BackendConfig.__module__ = "ray.train"
+Checkpoint.__module__ = "ray.train"
+CheckpointConfig.__module__ = "ray.train"
+DataConfig.__module__ = "ray.train"
+FailureConfig.__module__ = "ray.train"
+Result.__module__ = "ray.train"
+RunConfig.__module__ = "ray.train"
+ScalingConfig.__module__ = "ray.train"
+SyncConfig.__module__ = "ray.train"
+TrainingIterator.__module__ = "ray.train"
+
+
+if is_v2_enabled():
+    __all__.append("UserCallback")
+    UserCallback.__module__ = "ray.train"
+
+
+# DO NOT ADD ANYTHING AFTER THIS LINE.
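This `__init__.py` only re-exports the public Ray Train API (`report`, `get_checkpoint`, `Checkpoint`, and the config classes) under the `ray.train` namespace. As a rough illustration of how those re-exports are typically consumed, here is a minimal sketch of a training function that reports a metric together with a directory-based checkpoint; the function name, metric name, and file name are illustrative and not part of this diff.

```python
import os
import tempfile

from ray.train import Checkpoint, report


def train_func(config: dict) -> None:
    """Illustrative training function using the re-exported ray.train API."""
    for epoch in range(config.get("epochs", 3)):
        loss = 1.0 / (epoch + 1)  # placeholder metric

        # Write state to a temporary directory and report it as a Checkpoint.
        with tempfile.TemporaryDirectory() as tmpdir:
            with open(os.path.join(tmpdir, "state.txt"), "w") as f:
                f.write(f"epoch={epoch}")
            report(
                {"loss": loss, "epoch": epoch},
                checkpoint=Checkpoint.from_directory(tmpdir),
            )
```

This sketch assumes it runs inside a Ray Train worker (for example, as the `train_loop_per_worker` of a data-parallel trainer), since `report` is only meaningful within a training session.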
.venv/lib/python3.11/site-packages/ray/train/_checkpoint.py
ADDED
@@ -0,0 +1,424 @@
+import contextlib
+import glob
+import json
+import logging
+import os
+import platform
+import shutil
+import tempfile
+import traceback
+import uuid
+from pathlib import Path
+from typing import Any, Dict, Iterator, List, Optional, Union
+
+import pyarrow.fs
+
+from ray.air._internal.filelock import TempFileLock
+from ray.train._internal.storage import _download_from_fs_path, _exists_at_fs_path
+from ray.util.annotations import PublicAPI
+
+logger = logging.getLogger(__name__)
+
+# The filename of the file that stores user metadata set on the checkpoint.
+_METADATA_FILE_NAME = ".metadata.json"
+
+# The prefix of the temp checkpoint directory that `to_directory` downloads to
+# on the local filesystem.
+_CHECKPOINT_TEMP_DIR_PREFIX = "checkpoint_tmp_"
+
+
+class _CheckpointMetaClass(type):
+    def __getattr__(self, item):
+        try:
+            return super().__getattribute__(item)
+        except AttributeError as exc:
+            if item in {
+                "from_dict",
+                "to_dict",
+                "from_bytes",
+                "to_bytes",
+                "get_internal_representation",
+            }:
+                raise _get_migration_error(item) from exc
+            elif item in {
+                "from_uri",
+                "to_uri",
+                "uri",
+            }:
+                raise _get_uri_error(item) from exc
+            elif item in {"get_preprocessor", "set_preprocessor"}:
+                raise _get_preprocessor_error(item) from exc
+
+            raise exc
+
+
+@PublicAPI(stability="beta")
+class Checkpoint(metaclass=_CheckpointMetaClass):
+    """A reference to data persisted as a directory in local or remote storage.
+
+    Access the checkpoint contents locally using ``checkpoint.to_directory()``
+    or ``checkpoint.as_directory``.
+
+    Attributes
+    ----------
+    path: A path on the filesystem containing the checkpoint contents.
+    filesystem: PyArrow FileSystem that can be used to access data at the `path`.
+
+    See Also
+    --------
+    ray.train.report : Report a checkpoint during training (with Ray Train/Tune).
+    ray.train.get_checkpoint : Get the latest checkpoint during training
+        (for restoration).
+
+    :ref:`train-checkpointing`
+    :ref:`persistent-storage-guide`
+
+    Examples
+    --------
+
+    Creating a checkpoint using ``Checkpoint.from_directory``:
+
+    >>> from ray.train import Checkpoint
+    >>> checkpoint = Checkpoint.from_directory("/tmp/example_checkpoint_dir")
+    >>> checkpoint.filesystem  # doctest: +ELLIPSIS
+    <pyarrow._fs.LocalFileSystem object...
+    >>> checkpoint.path
+    '/tmp/example_checkpoint_dir'
+
+    Creating a checkpoint from a remote URI:
+
+    >>> checkpoint = Checkpoint("s3://bucket/path/to/checkpoint")
+    >>> checkpoint.filesystem  # doctest: +ELLIPSIS
+    <pyarrow._s3fs.S3FileSystem object...
+    >>> checkpoint.path
+    'bucket/path/to/checkpoint'
+
+    Creating a checkpoint with a custom filesystem:
+
+    >>> checkpoint = Checkpoint(
+    ...     path="bucket/path/to/checkpoint",
+    ...     filesystem=pyarrow.fs.S3FileSystem(),
+    ... )
+    >>> checkpoint.filesystem  # doctest: +ELLIPSIS
+    <pyarrow._s3fs.S3FileSystem object...
+    >>> checkpoint.path
+    'bucket/path/to/checkpoint'
+
+    Accessing a checkpoint's contents:
+
+    >>> import os  # doctest: +SKIP
+    >>> with checkpoint.as_directory() as local_checkpoint_dir:  # doctest: +SKIP
+    ...     print(os.listdir(local_checkpoint_dir))  # doctest: +SKIP
+    ['model.pt', 'optimizer.pt', 'misc.pt']
+    """
+
+    def __init__(
+        self,
+        path: Union[str, os.PathLike],
+        filesystem: Optional["pyarrow.fs.FileSystem"] = None,
+    ):
+        """Construct a Checkpoint.
+
+        Args:
+            path: A local path or remote URI containing the checkpoint data.
+                If a filesystem is provided, then this path must NOT be a URI.
+                It should be a path on the filesystem with the prefix already stripped.
+            filesystem: PyArrow FileSystem to use to access data at the path.
+                If not specified, this is inferred from the URI scheme.
+        """
+        self.path = str(path)
+        self.filesystem = filesystem
+
+        if path and not filesystem:
+            self.filesystem, self.path = pyarrow.fs.FileSystem.from_uri(path)
+
+        # This random UUID is used to create a temporary directory name on the
+        # local filesystem, which will be used for downloading checkpoint data.
+        # This ensures that if multiple processes download the same checkpoint object
+        # only one process performs the actual download while the others wait.
+        # This prevents duplicated download efforts and data.
+        # NOTE: Calling `to_directory` from multiple `Checkpoint` objects
+        # that point to the same (fs, path) will still download the data multiple times.
+        # This only ensures a canonical temp directory name for a single `Checkpoint`.
+        self._uuid = uuid.uuid4()
+
+    def __repr__(self):
+        return f"Checkpoint(filesystem={self.filesystem.type_name}, path={self.path})"
+
+    def get_metadata(self) -> Dict[str, Any]:
+        """Return the metadata dict stored with the checkpoint.
+
+        If no metadata is stored, an empty dict is returned.
+        """
+        metadata_path = Path(self.path, _METADATA_FILE_NAME).as_posix()
+        if not _exists_at_fs_path(self.filesystem, metadata_path):
+            return {}
+
+        with self.filesystem.open_input_file(metadata_path) as f:
+            return json.loads(f.readall().decode("utf-8"))
+
+    def set_metadata(self, metadata: Dict[str, Any]) -> None:
+        """Set the metadata stored with this checkpoint.
+
+        This will overwrite any existing metadata stored with this checkpoint.
+        """
+        metadata_path = Path(self.path, _METADATA_FILE_NAME).as_posix()
+        with self.filesystem.open_output_stream(metadata_path) as f:
+            f.write(json.dumps(metadata).encode("utf-8"))
+
+    def update_metadata(self, metadata: Dict[str, Any]) -> None:
+        """Update the metadata stored with this checkpoint.
+
+        This will update any existing metadata stored with this checkpoint.
+        """
+        existing_metadata = self.get_metadata()
+        existing_metadata.update(metadata)
+        self.set_metadata(existing_metadata)
+
+    @classmethod
+    def from_directory(cls, path: Union[str, os.PathLike]) -> "Checkpoint":
+        """Create checkpoint object from a local directory.
+
+        Args:
+            path: Local directory containing checkpoint data.
+
+        Returns:
+            A ray.train.Checkpoint object.
+        """
+        return cls(path, filesystem=pyarrow.fs.LocalFileSystem())
+
+    def to_directory(self, path: Optional[Union[str, os.PathLike]] = None) -> str:
+        """Write checkpoint data to a local directory.
+
+        *If multiple processes on the same node call this method simultaneously,*
+        only a single process will perform the download, while the others
+        wait for the download to finish. Once the download finishes, all processes
+        receive the same local directory to read from.
+
+        Args:
+            path: Target directory to download data to. If not specified,
+                this method will use a temporary directory.
+
+        Returns:
+            str: Directory containing checkpoint data.
+        """
+        user_provided_path = path is not None
+        local_path = (
+            path if user_provided_path else self._get_temporary_checkpoint_dir()
+        )
+        local_path = os.path.normpath(os.path.expanduser(str(local_path)))
+        os.makedirs(local_path, exist_ok=True)
+
+        try:
+            # Timeout 0 means there will be only one attempt to acquire
+            # the file lock. If it cannot be acquired, throw a TimeoutError
+            with TempFileLock(local_path, timeout=0):
+                _download_from_fs_path(
+                    fs=self.filesystem, fs_path=self.path, local_path=local_path
+                )
+        except TimeoutError:
+            # if the directory is already locked, then wait but do not do anything.
+            with TempFileLock(local_path, timeout=-1):
+                pass
+            if not os.path.exists(local_path):
+                raise RuntimeError(
+                    f"Checkpoint directory {local_path} does not exist, "
+                    "even though it should have been created by "
+                    "another process. Please raise an issue on GitHub: "
+                    "https://github.com/ray-project/ray/issues"
+                )
+
+        return local_path
+
+    @contextlib.contextmanager
+    def as_directory(self) -> Iterator[str]:
+        """Returns checkpoint contents in a local directory as a context.
+
+        This function makes checkpoint data available as a directory while avoiding
+        unnecessary copies and left-over temporary data.
+
+        *If the checkpoint points to a local directory*, this method just returns the
+        local directory path without making a copy, and nothing will be cleaned up
+        after exiting the context.
+
+        *If the checkpoint points to a remote directory*, this method will download the
+        checkpoint to a local temporary directory and return the path
+        to the temporary directory.
+
+        *If multiple processes on the same node call this method simultaneously,*
+        only a single process will perform the download, while the others
+        wait for the download to finish. Once the download finishes, all processes
+        receive the same local (temporary) directory to read from.
+
+        Once all processes have finished working with the checkpoint,
+        the temporary directory is cleaned up.
+
+        Users should treat the returned checkpoint directory as read-only and avoid
+        changing any data within it, as it may be deleted when exiting the context.
+
+        Example:
+
+        .. testcode::
+            :hide:
+
+            from pathlib import Path
+            import tempfile
+
+            from ray.train import Checkpoint
+
+            temp_dir = tempfile.mkdtemp()
+            (Path(temp_dir) / "example.txt").write_text("example checkpoint data")
+            checkpoint = Checkpoint.from_directory(temp_dir)
+
+        .. testcode::
+
+            with checkpoint.as_directory() as checkpoint_dir:
+                # Do some read-only processing of files within checkpoint_dir
+                pass
+
+            # At this point, if a temporary directory was created, it will have
+            # been deleted.
+
+        """
+        if isinstance(self.filesystem, pyarrow.fs.LocalFileSystem):
+            yield self.path
+        else:
+            del_lock_path = _get_del_lock_path(self._get_temporary_checkpoint_dir())
+            open(del_lock_path, "a").close()
+
+            temp_dir = self.to_directory()
+            try:
+                yield temp_dir
+            finally:
+                # Always cleanup the del lock after we're done with the directory.
+                # This avoids leaving a lock file behind in the case of an exception
+                # in the user code.
+                try:
+                    os.remove(del_lock_path)
+                except Exception:
+                    logger.warning(
+                        f"Could not remove {del_lock_path} deletion file lock. "
+                        f"Traceback:\n{traceback.format_exc()}"
+                    )
+
+                # If there are no more lock files, that means there are no more
+                # readers of this directory, and we can safely delete it.
+                # In the edge case (process crash before del lock file is removed),
+                # we do not remove the directory at all.
+                # Since it's in /tmp, this is not that big of a deal.
+                # check if any lock files are remaining
+                remaining_locks = _list_existing_del_locks(temp_dir)
+                if not remaining_locks:
+                    try:
+                        # Timeout 0 means there will be only one attempt to acquire
+                        # the file lock. If it cannot be acquired, a TimeoutError
+                        # will be thrown.
+                        with TempFileLock(temp_dir, timeout=0):
+                            shutil.rmtree(temp_dir, ignore_errors=True)
+                    except TimeoutError:
+                        pass
+
+    def _get_temporary_checkpoint_dir(self) -> str:
+        """Return the name for the temporary checkpoint dir that this checkpoint
+        will get downloaded to, if accessing via `to_directory` or `as_directory`.
+        """
+        tmp_dir_path = tempfile.gettempdir()
+        checkpoint_dir_name = _CHECKPOINT_TEMP_DIR_PREFIX + self._uuid.hex
+        if platform.system() == "Windows":
+            # Max path on Windows is 260 chars, -1 for joining \
+            # Also leave a little for the del lock
+            del_lock_name = _get_del_lock_path("")
+            checkpoint_dir_name = (
+                _CHECKPOINT_TEMP_DIR_PREFIX
+                + self._uuid.hex[
+                    -259
+                    + len(_CHECKPOINT_TEMP_DIR_PREFIX)
+                    + len(tmp_dir_path)
+                    + len(del_lock_name) :
+                ]
+            )
+            if not checkpoint_dir_name.startswith(_CHECKPOINT_TEMP_DIR_PREFIX):
+                raise RuntimeError(
+                    "Couldn't create checkpoint directory due to length "
+                    "constraints. Try specifying a shorter checkpoint path."
+                )
+        return Path(tmp_dir_path, checkpoint_dir_name).as_posix()
+
+    def __fspath__(self):
+        raise TypeError(
+            "You cannot use `Checkpoint` objects directly as paths. "
+            "Use `Checkpoint.to_directory()` or `Checkpoint.as_directory()` instead."
+        )
+
+
+def _get_del_lock_path(path: str, suffix: str = None) -> str:
+    """Get the path to the deletion lock file for a file/directory at `path`.
+
+    Example:
+
+        >>> _get_del_lock_path("/tmp/checkpoint_tmp")  # doctest: +ELLIPSIS
+        '/tmp/checkpoint_tmp.del_lock_...
+        >>> _get_del_lock_path("/tmp/checkpoint_tmp/")  # doctest: +ELLIPSIS
+        '/tmp/checkpoint_tmp.del_lock_...
+        >>> _get_del_lock_path("/tmp/checkpoint_tmp.txt")  # doctest: +ELLIPSIS
+        '/tmp/checkpoint_tmp.txt.del_lock_...
+
+    """
+    suffix = suffix if suffix is not None else str(os.getpid())
+    return f"{path.rstrip('/')}.del_lock_{suffix}"
+
+
+def _list_existing_del_locks(path: str) -> List[str]:
+    """List all the deletion lock files for a file/directory at `path`.
+
+    For example, if 2 checkpoints are being read via `as_directory`,
+    then this should return a list of 2 deletion lock files.
+    """
+    return list(glob.glob(f"{_get_del_lock_path(path, suffix='*')}"))
+
+
+def _get_migration_error(name: str):
+    return AttributeError(
+        f"The new `ray.train.Checkpoint` class does not support `{name}()`. "
+        f"Instead, only directories are supported.\n\n"
+        f"Example to store a dictionary in a checkpoint:\n\n"
+        f"import os, tempfile\n"
+        f"import ray.cloudpickle as pickle\n"
+        f"from ray import train\n"
+        f"from ray.train import Checkpoint\n\n"
+        f"with tempfile.TemporaryDirectory() as checkpoint_dir:\n"
+        f"    with open(os.path.join(checkpoint_dir, 'data.pkl'), 'wb') as fp:\n"
+        f"        pickle.dump({{'data': 'value'}}, fp)\n\n"
+        f"    checkpoint = Checkpoint.from_directory(checkpoint_dir)\n"
+        f"    train.report(..., checkpoint=checkpoint)\n\n"
+        f"Example to load a dictionary from a checkpoint:\n\n"
+        f"if train.get_checkpoint():\n"
+        f"    with train.get_checkpoint().as_directory() as checkpoint_dir:\n"
+        f"        with open(os.path.join(checkpoint_dir, 'data.pkl'), 'rb') as fp:\n"
+        f"            data = pickle.load(fp)"
+    )
+
+
+def _get_uri_error(name: str):
+    return AttributeError(
+        f"The new `ray.train.Checkpoint` class does not support `{name}()`. "
+        f"To create a checkpoint from remote storage, create a `Checkpoint` using its "
+        f"constructor instead of `from_directory`.\n"
+        f'Example: `Checkpoint(path="s3://a/b/c")`.\n'
+        f"Then, access the contents of the checkpoint with "
+        f"`checkpoint.as_directory()` / `checkpoint.to_directory()`.\n"
+        f"To upload data to remote storage, use e.g. `pyarrow.fs.FileSystem` "
+        f"or your client of choice."
+    )
+
+
+def _get_preprocessor_error(name: str):
+    return AttributeError(
+        f"The new `ray.train.Checkpoint` class does not support `{name}()`. "
+        f"To include preprocessor information in checkpoints, "
+        f"pass it as metadata in the <Framework>Trainer constructor.\n"
+        f"Example: `TorchTrainer(..., metadata={{...}})`.\n"
+        f"After training, access it in the checkpoint via `checkpoint.get_metadata()`. "
+        f"See here: https://docs.ray.io/en/master/train/user-guides/"
+        f"data-loading-preprocessing.html#preprocessing-structured-data"
+    )
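Putting the `Checkpoint` methods above together, here is a small sketch of the round trip: wrap a local directory, attach metadata (persisted as `.metadata.json` next to the data), and read the contents back. The directory layout and metadata keys are made up for illustration.

```python
import os
import tempfile

from ray.train import Checkpoint

# Write some state into a local directory and wrap it in a Checkpoint.
checkpoint_dir = tempfile.mkdtemp()
with open(os.path.join(checkpoint_dir, "weights.txt"), "w") as f:
    f.write("fake-model-weights")

checkpoint = Checkpoint.from_directory(checkpoint_dir)

# Attach and then merge user metadata.
checkpoint.set_metadata({"epoch": 3})
checkpoint.update_metadata({"framework": "none"})
print(checkpoint.get_metadata())  # {'epoch': 3, 'framework': 'none'}

# Read the contents back. For a local checkpoint this yields the same
# directory without copying; for a remote one it downloads to a temp dir.
with checkpoint.as_directory() as local_dir:
    print(os.listdir(local_dir))
```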
.venv/lib/python3.11/site-packages/ray/train/backend.py
ADDED
@@ -0,0 +1,59 @@
+import logging
+from contextlib import nullcontext
+from typing import TypeVar
+
+from ray.train._internal.utils import Singleton
+from ray.train._internal.worker_group import WorkerGroup
+from ray.util.annotations import DeveloperAPI
+from ray.widgets import make_table_html_repr
+
+EncodedData = TypeVar("EncodedData")
+
+logger = logging.getLogger(__name__)
+
+
+@DeveloperAPI
+class BackendConfig:
+    """Parent class for configurations of training backend."""
+
+    @property
+    def backend_cls(self):
+        return Backend
+
+    @property
+    def train_func_context(self):
+        return nullcontext
+
+    def _repr_html_(self) -> str:
+        return make_table_html_repr(obj=self, title=type(self).__name__)
+
+
+@DeveloperAPI
+class Backend(metaclass=Singleton):
+    """Singleton for distributed communication backend.
+
+    Attributes:
+        share_cuda_visible_devices: If True, each worker
+            process will have CUDA_VISIBLE_DEVICES set as the visible device
+            IDs of all workers on the same node for this training instance.
+            If False, each worker will have CUDA_VISIBLE_DEVICES set to the
+            device IDs allocated by Ray for that worker.
+    """
+
+    share_cuda_visible_devices: bool = False
+
+    def on_start(self, worker_group: WorkerGroup, backend_config: BackendConfig):
+        """Logic for starting this backend."""
+        pass
+
+    def on_shutdown(self, worker_group: WorkerGroup, backend_config: BackendConfig):
+        """Logic for shutting down the backend."""
+        pass
+
+    def on_training_start(
+        self, worker_group: WorkerGroup, backend_config: BackendConfig
+    ):
+        """Logic ran right before training is started.
+
+        Session API is available at this point."""
+        pass
.venv/lib/python3.11/site-packages/ray/train/base_trainer.py
ADDED
@@ -0,0 +1,827 @@
+import abc
+import copy
+import inspect
+import json
+import logging
+import os
+import warnings
+from functools import partial
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, Union
+
+import pyarrow.fs
+
+import ray
+import ray.cloudpickle as pickle
+from ray._private.dict import deep_update
+from ray.air._internal import usage as air_usage
+from ray.air._internal.config import ensure_only_allowed_dataclass_keys_updated
+from ray.air._internal.usage import AirEntrypoint
+from ray.air.config import RunConfig, ScalingConfig
+from ray.air.result import Result
+from ray.train import Checkpoint
+from ray.train._internal.session import get_session
+from ray.train._internal.storage import (
+    StorageContext,
+    _exists_at_fs_path,
+    get_fs_and_path,
+)
+from ray.util import PublicAPI
+from ray.util.annotations import DeveloperAPI
+
+if TYPE_CHECKING:
+    from ray.data import Dataset
+    from ray.tune import Trainable
+
+_TRAINER_PKL = "trainer.pkl"
+
+# A type representing either a ray.data.Dataset or a function that returns a
+# ray.data.Dataset and accepts no arguments.
+GenDataset = Union["Dataset", Callable[[], "Dataset"]]
+
+
+logger = logging.getLogger(__name__)
+
+PREPROCESSOR_DEPRECATION_MESSAGE = (
+    "The `preprocessor` argument to Trainers is deprecated as of Ray 2.7. "
+    "Instead, use the Preprocessor `fit` and `transform` APIs directly on the Ray "
+    "Dataset. For any state that needs to be saved to the trained checkpoint, pass it "
+    "in using the `metadata` argument of the `Trainer`. "
+    "For a full example, see "
+    "https://docs.ray.io/en/master/train/user-guides/data-loading-preprocessing.html#preprocessing-structured-data "  # noqa:E501
+)
+
+
+@PublicAPI(stability="beta")
+class TrainingFailedError(RuntimeError):
+    """An error indicating that training has failed."""
+
+    _RESTORE_MSG = (
+        "The Ray Train run failed. Please inspect the previous error messages for a "
+        "cause. After fixing the issue (assuming that the error is not caused by "
+        "your own application logic, but rather an error such as OOM), you can restart "
+        "the run from scratch or continue this run.\n"
+        "To continue this run, you can use: "
+        '`trainer = {trainer_cls_name}.restore("{path}")`.'
+    )
+
+    _FAILURE_CONFIG_MSG = (
+        "To start a new run that will retry on training failures, set "
+        "`train.RunConfig(failure_config=train.FailureConfig(max_failures))` "
+        "in the Trainer's `run_config` with `max_failures > 0`, or `max_failures = -1` "
+        "for unlimited retries."
+    )
+
+
+def _train_coordinator_fn(
+    config: dict, trainer_cls: Type["BaseTrainer"], metadata: dict
+):
+    """This is the function that defines the logic of the Ray Train coordinator.
+    This is responsible for setting up a remote instance of the `trainer_cls`
+    (a different instance than the one calling `trainer.fit` on the driver!)
+    and running the training loop.
+    """
+    assert metadata is not None, metadata
+    # Propagate user metadata from the Trainer constructor.
+    get_session().metadata = metadata
+
+    # config already contains merged values.
+    # Instantiate new Trainer in Trainable.
+    trainer = trainer_cls(**config)
+
+    # Get the checkpoint from Tune and pass it to workers later on.
+    checkpoint = ray.train.get_checkpoint()
+    if checkpoint:
+        # Set `starting_checkpoint` for auto-recovery fault-tolerance
+        # as well as manual restoration.
+        trainer.starting_checkpoint = checkpoint
+    # else: Train will restore from the user-provided
+    # `resume_from_checkpoint` == `starting_checkpoint`.
+
+    # Evaluate datasets if they are wrapped in a factory.
+    trainer.datasets = {
+        k: d() if callable(d) else d for k, d in trainer.datasets.items()
+    }
+
+    trainer.setup()
+    trainer.training_loop()
+
+
+@DeveloperAPI
+class BaseTrainer(abc.ABC):
+    """Defines interface for distributed training on Ray.
+
+    Note: The base ``BaseTrainer`` class cannot be instantiated directly. Only
+    one of its subclasses can be used.
+
+    Note to developers: If a new trainer is added, please update
+    `air/_internal/usage.py`.
+
+    **How does a trainer work?**
+
+    - First, initialize the Trainer. The initialization runs locally,
+      so heavyweight setup should not be done in ``__init__``.
+    - Then, when you call ``trainer.fit()``, the Trainer is serialized
+      and copied to a remote Ray actor. The following methods are then
+      called in sequence on the remote actor.
+    - ``trainer.setup()``: Any heavyweight Trainer setup should be
+      specified here.
+    - ``trainer.training_loop()``: Executes the main training logic.
+    - Calling ``trainer.fit()`` will return a ``ray.result.Result``
+      object where you can access metrics from your training run, as well
+      as any checkpoints that may have been saved.
+
+    **How do I create a new Trainer?**
+
+    Subclass ``ray.train.trainer.BaseTrainer``, and override the ``training_loop``
+    method, and optionally ``setup``.
+
+    .. testcode::
+
+        import torch
+
+        from ray.train.trainer import BaseTrainer
+        from ray import train, tune
+
+
+        class MyPytorchTrainer(BaseTrainer):
+            def setup(self):
+                self.model = torch.nn.Linear(1, 1)
+                self.optimizer = torch.optim.SGD(
+                    self.model.parameters(), lr=0.1)
+
+            def training_loop(self):
+                # You can access any Trainer attributes directly in this method.
+                # self.datasets["train"] has already been
+                dataset = self.datasets["train"]
+
+                torch_ds = dataset.iter_torch_batches(dtypes=torch.float)
+                loss_fn = torch.nn.MSELoss()
+
+                for epoch_idx in range(10):
+                    loss = 0
+                    num_batches = 0
+                    torch_ds = dataset.iter_torch_batches(
+                        dtypes=torch.float, batch_size=2
+                    )
+                    for batch in torch_ds:
+                        X = torch.unsqueeze(batch["x"], 1)
+                        y = torch.unsqueeze(batch["y"], 1)
+                        # Compute prediction error
+                        pred = self.model(X)
+                        batch_loss = loss_fn(pred, y)
+
+                        # Backpropagation
+                        self.optimizer.zero_grad()
+                        batch_loss.backward()
+                        self.optimizer.step()
+
+                        loss += batch_loss.item()
+                        num_batches += 1
+                    loss /= num_batches
+
+                    # Use Tune functions to report intermediate
+                    # results.
+                    train.report({"loss": loss, "epoch": epoch_idx})
+
+
+        # Initialize the Trainer, and call Trainer.fit()
+        import ray
+        train_dataset = ray.data.from_items(
+            [{"x": i, "y": i} for i in range(10)])
+        my_trainer = MyPytorchTrainer(datasets={"train": train_dataset})
+        result = my_trainer.fit()
+
+    .. testoutput::
+        :hide:
+
+        ...
+
+    Args:
+        scaling_config: Configuration for how to scale training.
+        run_config: Configuration for the execution of the training run.
+        datasets: Any Datasets to use for training. Use the key "train"
+            to denote which dataset is the training dataset.
+        metadata: Dict that should be made available via
+            `train.get_context().get_metadata()` and in `checkpoint.get_metadata()`
+            for checkpoints saved from this Trainer. Must be JSON-serializable.
+        resume_from_checkpoint: A checkpoint to resume training from.
+    """
+
+    _scaling_config_allowed_keys: List[str] = [
+        "trainer_resources",
+    ]
+    _handles_checkpoint_freq: bool = False
+    _handles_checkpoint_at_end: bool = False
+
+    # fields to propagate to Tuner param_space.
+    # See `BaseTrainer._extract_fields_for_tuner_param_space` for more details.
+    _fields_for_tuner_param_space = []
+
+    def __init__(
+        self,
+        *,
+        scaling_config: Optional[ScalingConfig] = None,
+        run_config: Optional[RunConfig] = None,
+        datasets: Optional[Dict[str, GenDataset]] = None,
+        metadata: Optional[Dict[str, Any]] = None,
+        resume_from_checkpoint: Optional[Checkpoint] = None,
+    ):
+        self.scaling_config = (
+            scaling_config if scaling_config is not None else ScalingConfig()
+        )
+        self.run_config = (
+            copy.copy(run_config) if run_config is not None else RunConfig()
+        )
+        self.metadata = metadata
+        self.datasets = datasets if datasets is not None else {}
+        self.starting_checkpoint = resume_from_checkpoint
+
+        # These attributes should only be set through `BaseTrainer.restore`
+        self._restore_path = None
+        self._restore_storage_filesystem = None
+
+        self._validate_attributes()
+
+        air_usage.tag_air_trainer(self)
+
+    @PublicAPI(stability="alpha")
+    @classmethod
+    def restore(
+        cls: Type["BaseTrainer"],
+        path: Union[str, os.PathLike],
+        storage_filesystem: Optional[pyarrow.fs.FileSystem] = None,
+        datasets: Optional[Dict[str, GenDataset]] = None,
+        scaling_config: Optional[ScalingConfig] = None,
+        **kwargs,
+    ) -> "BaseTrainer":
+        """Restores a Train experiment from a previously interrupted/failed run.
+
+        Restore should be used for experiment-level fault tolerance in the event
+        that the head node crashes (e.g., OOM or some other runtime error) or the
+        entire cluster goes down (e.g., network error affecting all nodes).
+
+        A run that has already completed successfully will not be resumed from this API.
+        To continue training from a successful run, launch a new run with the
+        ``<Framework>Trainer(resume_from_checkpoint)`` API instead, passing in a
+        checkpoint from the previous run to start with.
+
+        .. note::
+
+            Restoring an experiment from a path that's pointing to a *different*
+            location than the original experiment path is supported. However, Ray Train
+            assumes that the full experiment directory is available
+            (including checkpoints) so that it's possible to resume trials from their
+            latest state.
+
+            For example, if the original experiment path was run locally, then the
+            results are uploaded to cloud storage, Ray Train expects the full contents
+            to be available in cloud storage if attempting to resume
+            via ``<Framework>Trainer.restore("s3://...")``. The restored run will
+            continue writing results to the same cloud storage location.
+
+        The following example can be paired with implementing job retry using
+        :ref:`Ray Jobs <jobs-overview>` to produce a Train experiment that will
+        attempt to resume on both experiment-level and trial-level failures:
+
+        .. testcode::
+
+            import os
+            import ray
+            from ray import train
+            from ray.train.trainer import BaseTrainer
+
+            experiment_name = "unique_experiment_name"
+            storage_path = os.path.expanduser("~/ray_results")
+            experiment_dir = os.path.join(storage_path, experiment_name)
+
+            # Define some dummy inputs for demonstration purposes
+            datasets = {"train": ray.data.from_items([{"a": i} for i in range(10)])}
+
+            class CustomTrainer(BaseTrainer):
+                def training_loop(self):
+                    pass
+
+            if CustomTrainer.can_restore(experiment_dir):
+                trainer = CustomTrainer.restore(
+                    experiment_dir, datasets=datasets
+                )
+            else:
+                trainer = CustomTrainer(
+                    datasets=datasets,
+                    run_config=train.RunConfig(
+                        name=experiment_name,
+                        storage_path=storage_path,
+                        # Tip: You can also enable retries on failure for
+                        # worker-level fault tolerance
+                        failure_config=train.FailureConfig(max_failures=3),
+                    ),
+                )
+
+            result = trainer.fit()
+
+        .. testoutput::
+            :hide:
+
+            ...
+
+        Args:
+            path: The path to the experiment directory of the training run to restore.
+                This can be a local path or a remote URI if the experiment was
+                uploaded to the cloud.
+            storage_filesystem: Custom ``pyarrow.fs.FileSystem``
+                corresponding to the ``path``. This may be necessary if the original
+                experiment passed in a custom filesystem.
+            datasets: Re-specified datasets used in the original training run.
+                This must include all the datasets that were passed in the
+                original trainer constructor.
+            scaling_config: Optionally re-specified scaling config. This can be
+                modified to be different from the original spec.
+            **kwargs: Other optionally re-specified arguments, passed in by subclasses.
+
+        Raises:
+            ValueError: If all datasets were not re-supplied on restore.
+
+        Returns:
+            BaseTrainer: A restored instance of the class that is calling this method.
+        """
+        if not cls.can_restore(path, storage_filesystem):
+            raise ValueError(
+                f"Invalid restore path: {path}. Make sure that this path exists and "
+                "is the experiment directory that results from a call to "
+                "`trainer.fit()`."
+            )
+        fs, fs_path = get_fs_and_path(path, storage_filesystem)
+        trainer_pkl_path = Path(fs_path, _TRAINER_PKL).as_posix()
+        with fs.open_input_file(trainer_pkl_path) as f:
+            trainer_cls, param_dict = pickle.loads(f.readall())
+
+        if trainer_cls is not cls:
+            warnings.warn(
+                f"Invalid trainer type. You are attempting to restore a trainer of type"
+                f" {trainer_cls} with `{cls.__name__}.restore`, "
+                "which will most likely fail. "
+                f"Use `{trainer_cls.__name__}.restore` instead."
+            )
+
+        original_datasets = param_dict.pop("datasets", {})
+        if original_datasets and not datasets:
+            raise ValueError(
+                "The following datasets need to be provided again on restore: "
+                f"{list(original_datasets.keys())}\n"
+                f"Use {cls.__name__}.restore(..., datasets=datasets) "
+                "with the datasets that were provided to the original trainer."
+            )
+        datasets = datasets or {}
+        if set(original_datasets) != set(datasets):
+            raise ValueError(
+                "The provided datasets don't match the original dataset keys.\n"
+                f"  Expected datasets for the keys: {list(original_datasets.keys())}\n"
+                f"  Actual datasets provided: {list(datasets.keys())}"
+            )
+        param_dict["datasets"] = datasets
+
+        if scaling_config:
+            param_dict["scaling_config"] = scaling_config
+
+        for param_name, val in kwargs.items():
+            # Overwrite the old value if something is passed into restore
+            if val is not None:
+                param_dict[param_name] = val
+
+        try:
+            trainer = cls(**param_dict)
+        except Exception as e:
+            raise ValueError(
+                "Trainer restoration failed (see above for the stack trace). "
+                "Make sure that you use the right trainer class to restore: "
+                f"`{cls.__name__}.restore`\n"
+            ) from e
+        trainer._restore_path = path
+        trainer._restore_storage_filesystem = storage_filesystem
+        return trainer
+
+    @PublicAPI(stability="alpha")
+    @classmethod
+    def can_restore(
+        cls: Type["BaseTrainer"],
+        path: Union[str, os.PathLike],
+        storage_filesystem: Optional[pyarrow.fs.FileSystem] = None,
+    ) -> bool:
+        """Checks whether a given directory contains a restorable Train experiment.
+
+        Args:
+            path: The path to the experiment directory of the Train experiment.
+                This can be either a local directory (e.g., ~/ray_results/exp_name)
+                or a remote URI (e.g., s3://bucket/exp_name).
+
+        Returns:
+            bool: Whether this path exists and contains the trainer state to resume from
+        """
+        fs, fs_path = get_fs_and_path(path, storage_filesystem)
+        trainer_pkl_path = Path(fs_path, _TRAINER_PKL).as_posix()
+        return _exists_at_fs_path(fs, trainer_pkl_path)
+
+    def __repr__(self):
+        # A dictionary that maps parameters to their default values.
+        default_values: Dict[str, Any] = {
+            "scaling_config": ScalingConfig(),
+            "run_config": RunConfig(),
+            "datasets": {},
+            "starting_checkpoint": None,
+        }
+
+        non_default_arguments = []
+        for parameter, default_value in default_values.items():
+            value = getattr(self, parameter)
+            if value != default_value:
+                non_default_arguments.append(f"{parameter}={value!r}")
+
+        if non_default_arguments:
+            return f"<{self.__class__.__name__} {' '.join(non_default_arguments)}>"
+
+        return f"<{self.__class__.__name__}>"
+
+    def __new__(cls, *args, **kwargs):
+        # Store the init args as attributes so this can be merged with Tune hparams.
+        trainer = super(BaseTrainer, cls).__new__(cls)
+        parameters = inspect.signature(cls.__init__).parameters
+        parameters = list(parameters.keys())
+        # Remove self.
+        parameters = parameters[1:]
+        arg_dict = dict(zip(parameters, args))
+        trainer._param_dict = {**arg_dict, **kwargs}
+        return trainer
+
+    def _validate_attributes(self):
+        """Called on __init()__ to validate trainer attributes."""
+        # Run config
+        if not isinstance(self.run_config, RunConfig):
+            raise ValueError(
+                f"`run_config` should be an instance of `ray.train.RunConfig`, "
+                f"found {type(self.run_config)} with value `{self.run_config}`."
+            )
+        # Scaling config
+        if not isinstance(self.scaling_config, ScalingConfig):
+            raise ValueError(
+                "`scaling_config` should be an instance of `ScalingConfig`, "
+                f"found {type(self.scaling_config)} with value `{self.scaling_config}`."
+            )
+        # Datasets
+        if not isinstance(self.datasets, dict):
+            raise ValueError(
+                f"`datasets` should be a dict mapping from a string to "
+                f"`ray.data.Dataset` objects, "
+                f"found {type(self.datasets)} with value `{self.datasets}`."
+            )
+        else:
+            for key, dataset in self.datasets.items():
+                if not isinstance(dataset, ray.data.Dataset) and not callable(dataset):
+                    raise ValueError(
+                        f"The Dataset under '{key}' key is not a "
+                        "`ray.data.Dataset`. "
+                        f"Received {dataset} instead."
+                    )
+        # Metadata.
+        self.metadata = self.metadata or {}
+        if not isinstance(self.metadata, dict):
+            raise TypeError(
+                f"The provided metadata must be a dict, was {type(self.metadata)}."
+            )
+        try:
+            self.metadata = json.loads(json.dumps(self.metadata))
+        except Exception as e:
+            raise ValueError(
+                "The provided metadata must be JSON-serializable: "
+                f"{self.metadata}: {e}"
+            )
+
+        if self.starting_checkpoint is not None and not isinstance(
+            self.starting_checkpoint, Checkpoint
+        ):
+            raise ValueError(
+                f"`resume_from_checkpoint` should be an instance of "
+                f"`ray.train.Checkpoint`, found {type(self.starting_checkpoint)} "
+                f"with value `{self.starting_checkpoint}`."
+            )
+
+    @classmethod
+    def _validate_scaling_config(cls, scaling_config: ScalingConfig) -> ScalingConfig:
+        """Returns scaling config dataclass after validating updated keys."""
+        ensure_only_allowed_dataclass_keys_updated(
+            dataclass=scaling_config,
+            allowed_keys=cls._scaling_config_allowed_keys,
+        )
+        return scaling_config
+
+    def setup(self) -> None:
+        """Called during fit() to perform initial setup on the Trainer.
+
+        .. note:: This method is run on a remote process.
+
+        This method will not be called on the driver, so any expensive setup
+        operations should be placed here and not in ``__init__``.
+
+        This method is called prior to ``preprocess_datasets`` and
|
| 526 |
+
``training_loop``.
|
| 527 |
+
"""
|
| 528 |
+
pass
|
| 529 |
+
|
| 530 |
+
def preprocess_datasets(self) -> None:
|
| 531 |
+
"""Deprecated."""
|
| 532 |
+
raise DeprecationWarning(
|
| 533 |
+
"`preprocess_datasets` is no longer used, since preprocessors "
|
| 534 |
+
f"are no longer accepted by Trainers.\n{PREPROCESSOR_DEPRECATION_MESSAGE}"
|
| 535 |
+
)
|
| 536 |
+
|
| 537 |
+
@abc.abstractmethod
|
| 538 |
+
def training_loop(self) -> None:
|
| 539 |
+
"""Loop called by fit() to run training and report results to Tune.
|
| 540 |
+
|
| 541 |
+
.. note:: This method runs on a remote process.
|
| 542 |
+
|
| 543 |
+
``self.datasets`` have already been evaluated if they were wrapped in a factory.
|
| 544 |
+
|
| 545 |
+
You can use the :ref:`Ray Train utilities <train-loop-api>`
|
| 546 |
+
(:func:`train.report() <ray.train.report>` and
|
| 547 |
+
:func:`train.get_checkpoint() <ray.train.get_checkpoint>`) inside
|
| 548 |
+
this training loop.
|
| 549 |
+
|
| 550 |
+
Example:
|
| 551 |
+
|
| 552 |
+
.. testcode::
|
| 553 |
+
|
| 554 |
+
from ray.train.trainer import BaseTrainer
|
| 555 |
+
from ray import train
|
| 556 |
+
|
| 557 |
+
class MyTrainer(BaseTrainer):
|
| 558 |
+
def training_loop(self):
|
| 559 |
+
for epoch_idx in range(5):
|
| 560 |
+
...
|
| 561 |
+
train.report({"epoch": epoch_idx})
|
| 562 |
+
|
| 563 |
+
"""
|
| 564 |
+
raise NotImplementedError
|
| 565 |
+
|
| 566 |
+
@PublicAPI(stability="beta")
|
| 567 |
+
def fit(self) -> Result:
|
| 568 |
+
"""Runs training.
|
| 569 |
+
|
| 570 |
+
Returns:
|
| 571 |
+
A Result object containing the training result.
|
| 572 |
+
|
| 573 |
+
Raises:
|
| 574 |
+
TrainingFailedError: If any failures during the execution
|
| 575 |
+
of ``self.as_trainable()``, or during the Tune execution loop.
|
| 576 |
+
"""
|
| 577 |
+
from ray.tune import ResumeConfig, TuneError
|
| 578 |
+
from ray.tune.tuner import Tuner
|
| 579 |
+
|
| 580 |
+
trainable = self.as_trainable()
|
| 581 |
+
param_space = self._extract_fields_for_tuner_param_space()
|
| 582 |
+
|
| 583 |
+
self.run_config.name = (
|
| 584 |
+
self.run_config.name or StorageContext.get_experiment_dir_name(trainable)
|
| 585 |
+
)
|
| 586 |
+
# The storage context here is only used to access the resolved
|
| 587 |
+
# storage fs and experiment path, in order to avoid duplicating that logic.
|
| 588 |
+
# This is NOT the storage context object that gets passed to remote workers.
|
| 589 |
+
storage = StorageContext(
|
| 590 |
+
storage_path=self.run_config.storage_path,
|
| 591 |
+
experiment_dir_name=self.run_config.name,
|
| 592 |
+
storage_filesystem=self.run_config.storage_filesystem,
|
| 593 |
+
)
|
| 594 |
+
|
| 595 |
+
if self._restore_path:
|
| 596 |
+
tuner = Tuner.restore(
|
| 597 |
+
path=self._restore_path,
|
| 598 |
+
trainable=trainable,
|
| 599 |
+
param_space=param_space,
|
| 600 |
+
_resume_config=ResumeConfig(
|
| 601 |
+
finished=ResumeConfig.ResumeType.RESUME,
|
| 602 |
+
unfinished=ResumeConfig.ResumeType.RESUME,
|
| 603 |
+
errored=ResumeConfig.ResumeType.RESUME,
|
| 604 |
+
),
|
| 605 |
+
storage_filesystem=self._restore_storage_filesystem,
|
| 606 |
+
)
|
| 607 |
+
else:
|
| 608 |
+
tuner = Tuner(
|
| 609 |
+
trainable=trainable,
|
| 610 |
+
param_space=param_space,
|
| 611 |
+
run_config=self.run_config,
|
| 612 |
+
_entrypoint=AirEntrypoint.TRAINER,
|
| 613 |
+
)
|
| 614 |
+
|
| 615 |
+
self._save(storage.storage_filesystem, storage.experiment_fs_path)
|
| 616 |
+
|
| 617 |
+
restore_msg = TrainingFailedError._RESTORE_MSG.format(
|
| 618 |
+
trainer_cls_name=self.__class__.__name__,
|
| 619 |
+
path=str(storage.experiment_fs_path),
|
| 620 |
+
)
|
| 621 |
+
|
| 622 |
+
try:
|
| 623 |
+
result_grid = tuner.fit()
|
| 624 |
+
except TuneError as e:
|
| 625 |
+
# Catch any `TuneError`s raised by the `Tuner.fit` call.
|
| 626 |
+
# Unwrap the `TuneError` if needed.
|
| 627 |
+
parent_error = e.__cause__ or e
|
| 628 |
+
|
| 629 |
+
# Raise it to the user as a `TrainingFailedError` with a message to restore.
|
| 630 |
+
raise TrainingFailedError(restore_msg) from parent_error
|
| 631 |
+
# Other exceptions get passed through directly (ex: on `fail_fast='raise'`)
|
| 632 |
+
|
| 633 |
+
assert len(result_grid) == 1
|
| 634 |
+
result = result_grid[0]
|
| 635 |
+
if result.error:
|
| 636 |
+
# Raise trainable errors to the user with a message to restore
|
| 637 |
+
# or configure `FailureConfig` in a new run.
|
| 638 |
+
raise TrainingFailedError(
|
| 639 |
+
"\n".join([restore_msg, TrainingFailedError._FAILURE_CONFIG_MSG])
|
| 640 |
+
) from result.error
|
| 641 |
+
return result
|
| 642 |
+
|
| 643 |
+
def _save(self, fs: pyarrow.fs.FileSystem, experiment_path: str):
|
| 644 |
+
"""Saves the current trainer's class along with the `param_dict` of
|
| 645 |
+
parameters passed to this trainer's constructor.
|
| 646 |
+
|
| 647 |
+
This is used to recreate the trainer on restore.
|
| 648 |
+
Unless a parameter is re-specified during restoration (only a subset
|
| 649 |
+
of parameters can be passed in again), that parameter will be loaded
|
| 650 |
+
from the saved copy.
|
| 651 |
+
|
| 652 |
+
Datasets should not be saved as part of the state. Instead, we save the
|
| 653 |
+
keys and replace the dataset values with dummy functions that will
|
| 654 |
+
raise an error if invoked. The error only serves as a guardrail for
|
| 655 |
+
misuse (e.g., manually unpickling and constructing the Trainer again)
|
| 656 |
+
and is not typically surfaced, since datasets must be re-specified
|
| 657 |
+
upon restoration.
|
| 658 |
+
"""
|
| 659 |
+
param_dict = self._param_dict.copy()
|
| 660 |
+
datasets = param_dict.pop("datasets", {})
|
| 661 |
+
|
| 662 |
+
def raise_fn():
|
| 663 |
+
raise RuntimeError
|
| 664 |
+
|
| 665 |
+
if datasets:
|
| 666 |
+
param_dict["datasets"] = {
|
| 667 |
+
dataset_name: raise_fn for dataset_name in datasets
|
| 668 |
+
}
|
| 669 |
+
|
| 670 |
+
cls_and_param_dict = (self.__class__, param_dict)
|
| 671 |
+
|
| 672 |
+
fs.create_dir(experiment_path)
|
| 673 |
+
with fs.open_output_stream(Path(experiment_path, _TRAINER_PKL).as_posix()) as f:
|
| 674 |
+
f.write(pickle.dumps(cls_and_param_dict))
|
| 675 |
+
|
| 676 |
+
def _extract_fields_for_tuner_param_space(self) -> Dict:
|
| 677 |
+
"""Extracts fields to be included in `Tuner.param_space`.
|
| 678 |
+
|
| 679 |
+
This is needed to leverage the full logging/integration offerings from Tune.
|
| 680 |
+
For example, `param_space` is logged automatically to wandb integration.
|
| 681 |
+
|
| 682 |
+
Currently only done for `train_loop_config`.
|
| 683 |
+
|
| 684 |
+
Returns:
|
| 685 |
+
A dictionary that should be passed to Tuner.param_space.
|
| 686 |
+
"""
|
| 687 |
+
result = {}
|
| 688 |
+
for key in self._fields_for_tuner_param_space:
|
| 689 |
+
if key in self._param_dict.keys():
|
| 690 |
+
result[key] = copy.deepcopy(self._param_dict[key])
|
| 691 |
+
return result
|
| 692 |
+
|
| 693 |
+
def _generate_trainable_cls(self) -> Type["Trainable"]:
|
| 694 |
+
"""Generates the base Trainable class.
|
| 695 |
+
|
| 696 |
+
Returns:
|
| 697 |
+
A Trainable class to use for training.
|
| 698 |
+
"""
|
| 699 |
+
|
| 700 |
+
from ray.tune.execution.placement_groups import PlacementGroupFactory
|
| 701 |
+
from ray.tune.trainable import wrap_function
|
| 702 |
+
|
| 703 |
+
trainer_cls = self.__class__
|
| 704 |
+
scaling_config = self.scaling_config
|
| 705 |
+
metadata = self.metadata
|
| 706 |
+
|
| 707 |
+
train_coordinator_fn = partial(
|
| 708 |
+
_train_coordinator_fn, trainer_cls=trainer_cls, metadata=metadata
|
| 709 |
+
)
|
| 710 |
+
# Change the name of the training function to match the name of the Trainer
|
| 711 |
+
# class. This will mean the Tune trial name will match the name of Trainer on
|
| 712 |
+
# stdout messages and the results directory.
|
| 713 |
+
train_coordinator_fn.__name__ = trainer_cls.__name__
|
| 714 |
+
|
| 715 |
+
trainable_cls = wrap_function(train_coordinator_fn)
|
| 716 |
+
has_base_dataset = bool(self.datasets)
|
| 717 |
+
if has_base_dataset:
|
| 718 |
+
from ray.data.context import DataContext
|
| 719 |
+
|
| 720 |
+
dataset_context = DataContext.get_current()
|
| 721 |
+
else:
|
| 722 |
+
dataset_context = None
|
| 723 |
+
|
| 724 |
+
class TrainTrainable(trainable_cls):
|
| 725 |
+
"""Adds default resources to the Trainable."""
|
| 726 |
+
|
| 727 |
+
_handles_checkpoint_freq = trainer_cls._handles_checkpoint_freq
|
| 728 |
+
_handles_checkpoint_at_end = trainer_cls._handles_checkpoint_at_end
|
| 729 |
+
|
| 730 |
+
@classmethod
|
| 731 |
+
def has_base_dataset(cls) -> bool:
|
| 732 |
+
"""Whether a dataset is provided through the Trainer."""
|
| 733 |
+
return has_base_dataset
|
| 734 |
+
|
| 735 |
+
@classmethod
|
| 736 |
+
def base_scaling_config(cls) -> ScalingConfig:
|
| 737 |
+
"""Returns the unchanged scaling config provided through the Trainer."""
|
| 738 |
+
return scaling_config
|
| 739 |
+
|
| 740 |
+
def setup(self, config, **kwargs):
|
| 741 |
+
base_config = dict(kwargs)
|
| 742 |
+
# Merge Tuner param space hyperparameters in `config` into the
|
| 743 |
+
# base config passed to the Trainer constructor, which is `base_config`.
|
| 744 |
+
# `base_config` is pulled from the object store from the usage of
|
| 745 |
+
# tune.with_parameters in `BaseTrainer.as_trainable`.
|
| 746 |
+
|
| 747 |
+
# run_config is not a tunable hyperparameter so it does not need to be
|
| 748 |
+
# merged.
|
| 749 |
+
run_config = base_config.pop("run_config", None)
|
| 750 |
+
self._merged_config = deep_update(
|
| 751 |
+
base_config, self.config, new_keys_allowed=True
|
| 752 |
+
)
|
| 753 |
+
self._merged_config["run_config"] = run_config
|
| 754 |
+
merged_scaling_config = self._merged_config.get(
|
| 755 |
+
"scaling_config", ScalingConfig()
|
| 756 |
+
)
|
| 757 |
+
if isinstance(merged_scaling_config, dict):
|
| 758 |
+
merged_scaling_config = ScalingConfig(**merged_scaling_config)
|
| 759 |
+
self._merged_config[
|
| 760 |
+
"scaling_config"
|
| 761 |
+
] = self._reconcile_scaling_config_with_trial_resources(
|
| 762 |
+
merged_scaling_config
|
| 763 |
+
)
|
| 764 |
+
if self.has_base_dataset():
|
| 765 |
+
# Set the DataContext on the Trainer actor to the DataContext
|
| 766 |
+
# specified on the driver.
|
| 767 |
+
DataContext._set_current(dataset_context)
|
| 768 |
+
super(TrainTrainable, self).setup(config)
|
| 769 |
+
|
| 770 |
+
def _reconcile_scaling_config_with_trial_resources(
|
| 771 |
+
self, scaling_config: ScalingConfig
|
| 772 |
+
) -> ScalingConfig:
|
| 773 |
+
"""
|
| 774 |
+
ResourceChangingScheduler workaround.
|
| 775 |
+
|
| 776 |
+
Ensures that the scaling config matches trial resources.
|
| 777 |
+
|
| 778 |
+
This should be replaced with RCS returning a ScalingConfig
|
| 779 |
+
in the future.
|
| 780 |
+
"""
|
| 781 |
+
|
| 782 |
+
trial_resources = self.trial_resources
|
| 783 |
+
# This will be false if the resources are default
|
| 784 |
+
if not isinstance(trial_resources, PlacementGroupFactory):
|
| 785 |
+
return scaling_config
|
| 786 |
+
|
| 787 |
+
# Ignore ResourceChangingScheduler workaround when resource bundles
|
| 788 |
+
# are unchanged
|
| 789 |
+
if self.trial_resources == scaling_config.as_placement_group_factory():
|
| 790 |
+
return scaling_config
|
| 791 |
+
|
| 792 |
+
trainer_cls._validate_scaling_config(scaling_config)
|
| 793 |
+
|
| 794 |
+
return ScalingConfig.from_placement_group_factory(trial_resources)
|
| 795 |
+
|
| 796 |
+
def _trainable_func(self, config):
|
| 797 |
+
# We ignore the config passed by Tune and instead use the merged
|
| 798 |
+
# config which includes the initial Trainer args.
|
| 799 |
+
super()._trainable_func(self._merged_config)
|
| 800 |
+
|
| 801 |
+
@classmethod
|
| 802 |
+
def default_resource_request(cls, config):
|
| 803 |
+
# `config["scaling_config"] is a dataclass when passed via the
|
| 804 |
+
# `scaling_config` argument in `Trainer` and is a dict when passed
|
| 805 |
+
# via the `scaling_config` key of `param_spec`.
|
| 806 |
+
|
| 807 |
+
# Conversion logic must be duplicated in `TrainTrainable.__init__`
|
| 808 |
+
# because this is a class method.
|
| 809 |
+
updated_scaling_config = config.get("scaling_config", scaling_config)
|
| 810 |
+
if isinstance(updated_scaling_config, dict):
|
| 811 |
+
updated_scaling_config = ScalingConfig(**updated_scaling_config)
|
| 812 |
+
validated_scaling_config = trainer_cls._validate_scaling_config(
|
| 813 |
+
updated_scaling_config
|
| 814 |
+
)
|
| 815 |
+
return validated_scaling_config.as_placement_group_factory()
|
| 816 |
+
|
| 817 |
+
return TrainTrainable
|
| 818 |
+
|
| 819 |
+
def as_trainable(self) -> Type["Trainable"]:
|
| 820 |
+
"""Converts self to a ``tune.Trainable`` class."""
|
| 821 |
+
from ray import tune
|
| 822 |
+
|
| 823 |
+
base_config = self._param_dict
|
| 824 |
+
trainable_cls = self._generate_trainable_cls()
|
| 825 |
+
|
| 826 |
+
# Wrap with `tune.with_parameters` to handle very large values in base_config
|
| 827 |
+
return tune.with_parameters(trainable_cls, **base_config)
|
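A minimal usage sketch of the restore flow implemented in `base_trainer.py` above (illustrative only; the experiment path and the `TorchTrainer` subclass are assumptions, not part of this file). Because `_save` replaces dataset values with placeholder functions, datasets must be re-supplied on restore:

    import ray
    from ray.train.torch import TorchTrainer

    experiment_path = "~/ray_results/my_experiment"  # hypothetical path
    train_dataset = ray.data.from_items([1, 2, 3])   # the dataset originally passed to the trainer

    # `can_restore` checks for the pickled trainer state; `restore` rebuilds the
    # trainer from it, with datasets passed in again explicitly.
    if TorchTrainer.can_restore(experiment_path):
        trainer = TorchTrainer.restore(experiment_path, datasets={"train": train_dataset})
        result = trainer.fit()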
.venv/lib/python3.11/site-packages/ray/train/constants.py
ADDED
@@ -0,0 +1,118 @@
from pathlib import Path

import ray
from ray._private.ray_constants import env_bool
from ray.air.constants import (  # noqa: F401
    COPY_DIRECTORY_CHECKPOINTS_INSTEAD_OF_MOVING_ENV,
    EVALUATION_DATASET_KEY,
    MODEL_KEY,
    PREPROCESSOR_KEY,
    TRAIN_DATASET_KEY,
)


def _get_ray_train_session_dir() -> str:
    assert ray.is_initialized(), "Ray must be initialized to get the session dir."
    return Path(
        ray._private.worker._global_node.get_session_dir_path(), "artifacts"
    ).as_posix()


DEFAULT_STORAGE_PATH = Path("~/ray_results").expanduser().as_posix()

# Autofilled ray.train.report() metrics. Keys should be consistent with Tune.
CHECKPOINT_DIR_NAME = "checkpoint_dir_name"
TIME_TOTAL_S = "_time_total_s"
WORKER_HOSTNAME = "_hostname"
WORKER_NODE_IP = "_node_ip"
WORKER_PID = "_pid"

# Only reported if the ENABLE_DETAILED_AUTOFILLED_METRICS_ENV
# env var is set to a non-zero value.
DETAILED_AUTOFILLED_KEYS = {WORKER_HOSTNAME, WORKER_NODE_IP, WORKER_PID, TIME_TOTAL_S}

# Default filename for JSON logger
RESULT_FILE_JSON = "results.json"

# The name of the subdirectory inside the trainer run_dir to store checkpoints.
TRAIN_CHECKPOINT_SUBDIR = "checkpoints"

# The key to use to specify the checkpoint id for Tune.
# This needs to be added to the checkpoint dictionary so if the Tune trial
# is restarted, the checkpoint_id can continue to increment.
TUNE_CHECKPOINT_ID = "_current_checkpoint_id"

# Deprecated configs can use this value to detect if the user has set it.
_DEPRECATED_VALUE = "DEPRECATED"

# ==================================================
# Environment Variables
# ==================================================

ENABLE_DETAILED_AUTOFILLED_METRICS_ENV = (
    "TRAIN_RESULT_ENABLE_DETAILED_AUTOFILLED_METRICS"
)

# Integer value which if set will override the value of
# Backend.share_cuda_visible_devices. 1 for True, 0 for False.
ENABLE_SHARE_CUDA_VISIBLE_DEVICES_ENV = "TRAIN_ENABLE_SHARE_CUDA_VISIBLE_DEVICES"

# Integer value which if set will not share ROCR accelerator visible devices
# across workers. 1 for True (default), 0 for False.
ENABLE_SHARE_ROCR_VISIBLE_DEVICES_ENV = "TRAIN_ENABLE_SHARE_ROCR_VISIBLE_DEVICES"

# Integer value which if set will not share neuron-core accelerator visible cores
# across workers. 1 for True (default), 0 for False.
ENABLE_SHARE_NEURON_CORES_ACCELERATOR_ENV = (
    "TRAIN_ENABLE_SHARE_NEURON_CORES_ACCELERATOR"
)

# Integer value which if set will not share npu visible devices
# across workers. 1 for True (default), 0 for False.
ENABLE_SHARE_NPU_RT_VISIBLE_DEVICES_ENV = "TRAIN_ENABLE_SHARE_ASCEND_RT_VISIBLE_DEVICES"

# Integer value which indicates the number of seconds to wait when creating
# the worker placement group before timing out.
TRAIN_PLACEMENT_GROUP_TIMEOUT_S_ENV = "TRAIN_PLACEMENT_GROUP_TIMEOUT_S"

# Integer value which if set will change the placement group strategy from
# PACK to SPREAD. 1 for True, 0 for False.
TRAIN_ENABLE_WORKER_SPREAD_ENV = "TRAIN_ENABLE_WORKER_SPREAD"

# Set this to 0 to disable changing the working directory of each Tune Trainable
# or Train worker to the trial directory. Defaults to 1.
RAY_CHDIR_TO_TRIAL_DIR = "RAY_CHDIR_TO_TRIAL_DIR"

# Set this to 1 to count preemption errors toward `FailureConfig(max_failures)`.
# Defaults to 0, which always retries on node preemption failures.
RAY_TRAIN_COUNT_PREEMPTION_AS_FAILURE = "RAY_TRAIN_COUNT_PREEMPTION_AS_FAILURE"

# Set this to 1 to start a StateActor and collect information about Train Runs.
# Defaults to 0.
RAY_TRAIN_ENABLE_STATE_TRACKING = "RAY_TRAIN_ENABLE_STATE_TRACKING"

# Set this to 1 to enable deprecation warnings for V2 migration.
ENABLE_V2_MIGRATION_WARNINGS_ENV_VAR = "RAY_TRAIN_ENABLE_V2_MIGRATION_WARNINGS"


def _v2_migration_warnings_enabled() -> bool:
    return env_bool(ENABLE_V2_MIGRATION_WARNINGS_ENV_VAR, False)


# NOTE: When adding a new environment variable, please track it in this list.
TRAIN_ENV_VARS = {
    ENABLE_DETAILED_AUTOFILLED_METRICS_ENV,
    ENABLE_SHARE_CUDA_VISIBLE_DEVICES_ENV,
    ENABLE_SHARE_NEURON_CORES_ACCELERATOR_ENV,
    TRAIN_PLACEMENT_GROUP_TIMEOUT_S_ENV,
    TRAIN_ENABLE_WORKER_SPREAD_ENV,
    RAY_CHDIR_TO_TRIAL_DIR,
    RAY_TRAIN_COUNT_PREEMPTION_AS_FAILURE,
    RAY_TRAIN_ENABLE_STATE_TRACKING,
}

# Key for AIR Checkpoint metadata in TrainingResult metadata
CHECKPOINT_METADATA_KEY = "checkpoint_metadata"

# Key for AIR Checkpoint world rank in TrainingResult metadata
CHECKPOINT_RANK_KEY = "checkpoint_rank"
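The names above are plain environment-variable strings that Ray Train reads with helpers such as `env_bool` (as `_v2_migration_warnings_enabled` does) or `env_integer` (as `data_parallel_trainer.py` below does). A small illustrative sketch, with the value "1" chosen purely as an example:

    import os

    from ray._private.ray_constants import env_integer
    from ray.train.constants import RAY_TRAIN_ENABLE_STATE_TRACKING

    # Opt in to Train run state tracking for this driver process (illustrative value).
    os.environ[RAY_TRAIN_ENABLE_STATE_TRACKING] = "1"
    assert env_integer(RAY_TRAIN_ENABLE_STATE_TRACKING, 0) == 1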
.venv/lib/python3.11/site-packages/ray/train/context.py
ADDED
@@ -0,0 +1,139 @@
import threading
from typing import TYPE_CHECKING, Any, Dict, Optional

from ray.train._internal import session
from ray.train._internal.storage import StorageContext
from ray.train.constants import _v2_migration_warnings_enabled
from ray.train.utils import _copy_doc, _log_deprecation_warning
from ray.util.annotations import Deprecated, DeveloperAPI, PublicAPI

if TYPE_CHECKING:
    from ray.tune.execution.placement_groups import PlacementGroupFactory


# The context singleton on this process.
_default_context: "Optional[TrainContext]" = None
_context_lock = threading.Lock()


_GET_METADATA_DEPRECATION_MESSAGE = (
    "`get_metadata` was an experimental API that accessed the metadata passed "
    "to `<Framework>Trainer(metadata=...)`. This API can be replaced by passing "
    "the metadata directly to the training function (e.g., via `train_loop_config`)."
)

_TUNE_SPECIFIC_CONTEXT_DEPRECATION_MESSAGE = (
    "`{}` is deprecated because the concept of a `Trial` will "
    "soon be removed in Ray Train (see here: "
    "https://github.com/ray-project/enhancements/pull/57). "
    "Ray Train will no longer assume that it's running within a Ray Tune `Trial` "
    "in the future."
)


@PublicAPI(stability="stable")
class TrainContext:
    """Context containing metadata that can be accessed within Ray Train workers."""

    @_copy_doc(session.get_experiment_name)
    def get_experiment_name(self) -> str:
        return session.get_experiment_name()

    @_copy_doc(session.get_world_size)
    def get_world_size(self) -> int:
        return session.get_world_size()

    @_copy_doc(session.get_world_rank)
    def get_world_rank(self) -> int:
        return session.get_world_rank()

    @_copy_doc(session.get_local_rank)
    def get_local_rank(self) -> int:
        return session.get_local_rank()

    @_copy_doc(session.get_local_world_size)
    def get_local_world_size(self) -> int:
        return session.get_local_world_size()

    @_copy_doc(session.get_node_rank)
    def get_node_rank(self) -> int:
        return session.get_node_rank()

    @DeveloperAPI
    @_copy_doc(session.get_storage)
    def get_storage(self) -> StorageContext:
        return session.get_storage()

    # Deprecated APIs

    @Deprecated(
        message=_GET_METADATA_DEPRECATION_MESSAGE,
        warning=_v2_migration_warnings_enabled(),
    )
    @_copy_doc(session.get_metadata)
    def get_metadata(self) -> Dict[str, Any]:
        return session.get_metadata()

    @Deprecated(
        message=_TUNE_SPECIFIC_CONTEXT_DEPRECATION_MESSAGE.format("get_trial_name"),
        warning=_v2_migration_warnings_enabled(),
    )
    @_copy_doc(session.get_trial_name)
    def get_trial_name(self) -> str:
        return session.get_trial_name()

    @Deprecated(
        message=_TUNE_SPECIFIC_CONTEXT_DEPRECATION_MESSAGE.format("get_trial_id"),
        warning=_v2_migration_warnings_enabled(),
    )
    @_copy_doc(session.get_trial_id)
    def get_trial_id(self) -> str:
        return session.get_trial_id()

    @Deprecated(
        message=_TUNE_SPECIFIC_CONTEXT_DEPRECATION_MESSAGE.format(
            "get_trial_resources"
        ),
        warning=_v2_migration_warnings_enabled(),
    )
    @_copy_doc(session.get_trial_resources)
    def get_trial_resources(self) -> "PlacementGroupFactory":
        return session.get_trial_resources()

    @Deprecated(
        message=_TUNE_SPECIFIC_CONTEXT_DEPRECATION_MESSAGE.format("get_trial_dir"),
        warning=_v2_migration_warnings_enabled(),
    )
    @_copy_doc(session.get_trial_dir)
    def get_trial_dir(self) -> str:
        return session.get_trial_dir()


@PublicAPI(stability="stable")
def get_context() -> TrainContext:
    """Get or create a singleton training context.

    The context is only available within a function passed to Ray Train.

    See the :class:`~ray.train.TrainContext` API reference to see available methods.
    """
    from ray.tune.trainable.trainable_fn_utils import _in_tune_session

    # If we are running in a Tune function, switch to Tune context.
    if _in_tune_session():
        from ray.tune import get_context as get_tune_context

        if _v2_migration_warnings_enabled():
            _log_deprecation_warning(
                "`ray.train.get_context()` should be switched to "
                "`ray.tune.get_context()` when running in a function "
                "passed to Ray Tune. This will be an error in the future."
            )
        return get_tune_context()

    global _default_context

    with _context_lock:
        if _default_context is None:
            _default_context = TrainContext()
        return _default_context
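A brief sketch of how the `get_context()` entrypoint above is used from inside a training function (illustrative only; the worker function is an assumption, and it relies solely on the accessors defined in this file):

    from ray import train

    def train_loop_per_worker():
        ctx = train.get_context()
        # Rank and world-size metadata are only available inside a Ray Train worker.
        print(ctx.get_world_rank(), ctx.get_local_rank(), ctx.get_world_size())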
.venv/lib/python3.11/site-packages/ray/train/data_parallel_trainer.py
ADDED
|
@@ -0,0 +1,587 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import uuid
|
| 3 |
+
from typing import Any, Callable, Dict, List, Optional, Type, Union
|
| 4 |
+
|
| 5 |
+
import ray
|
| 6 |
+
from ray._private.ray_constants import env_integer
|
| 7 |
+
from ray._private.thirdparty.tabulate.tabulate import tabulate
|
| 8 |
+
from ray.air.config import RunConfig, ScalingConfig
|
| 9 |
+
from ray.train import BackendConfig, Checkpoint, TrainingIterator
|
| 10 |
+
from ray.train._internal import session
|
| 11 |
+
from ray.train._internal.backend_executor import BackendExecutor, TrialInfo
|
| 12 |
+
from ray.train._internal.data_config import DataConfig
|
| 13 |
+
from ray.train._internal.session import _TrainingResult, get_session
|
| 14 |
+
from ray.train._internal.utils import construct_train_func, count_required_parameters
|
| 15 |
+
from ray.train.constants import RAY_TRAIN_ENABLE_STATE_TRACKING
|
| 16 |
+
from ray.train.trainer import BaseTrainer, GenDataset
|
| 17 |
+
from ray.util.annotations import DeveloperAPI, PublicAPI
|
| 18 |
+
from ray.widgets import Template
|
| 19 |
+
from ray.widgets.util import repr_with_fallback
|
| 20 |
+
|
| 21 |
+
logger = logging.getLogger(__name__)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@DeveloperAPI
|
| 25 |
+
class DataParallelTrainer(BaseTrainer):
|
| 26 |
+
"""A Trainer for data parallel training.
|
| 27 |
+
|
| 28 |
+
You should subclass this Trainer if your Trainer follows SPMD (single program,
|
| 29 |
+
multiple data) programming paradigm - you want multiple processes to run the same
|
| 30 |
+
function, but on different data.
|
| 31 |
+
|
| 32 |
+
This Trainer runs the function ``train_loop_per_worker`` on multiple Ray
|
| 33 |
+
Actors.
|
| 34 |
+
|
| 35 |
+
The ``train_loop_per_worker`` function is expected to take in either 0 or 1
|
| 36 |
+
arguments:
|
| 37 |
+
|
| 38 |
+
.. testcode::
|
| 39 |
+
|
| 40 |
+
def train_loop_per_worker():
|
| 41 |
+
...
|
| 42 |
+
|
| 43 |
+
.. testcode::
|
| 44 |
+
|
| 45 |
+
def train_loop_per_worker(config: Dict):
|
| 46 |
+
...
|
| 47 |
+
|
| 48 |
+
If ``train_loop_per_worker`` accepts an argument, then
|
| 49 |
+
``train_loop_config`` will be passed in as the argument. This is useful if you
|
| 50 |
+
want to tune the values in ``train_loop_config`` as hyperparameters.
|
| 51 |
+
|
| 52 |
+
If the ``datasets`` dict contains a training dataset (denoted by
|
| 53 |
+
the "train" key), then it will be split into multiple dataset
|
| 54 |
+
shards that can then be accessed by ``train.get_dataset_shard("train")`` inside
|
| 55 |
+
``train_loop_per_worker``. All the other datasets will not be split and
|
| 56 |
+
``train.get_dataset_shard(...)`` will return the the entire Dataset.
|
| 57 |
+
|
| 58 |
+
Inside the ``train_loop_per_worker`` function, you can use any of the
|
| 59 |
+
:ref:`Ray Train loop methods <train-loop-api>`.
|
| 60 |
+
|
| 61 |
+
.. testcode::
|
| 62 |
+
|
| 63 |
+
from ray import train
|
| 64 |
+
|
| 65 |
+
def train_loop_per_worker():
|
| 66 |
+
# Report intermediate results for callbacks or logging and
|
| 67 |
+
# checkpoint data.
|
| 68 |
+
train.report(...)
|
| 69 |
+
|
| 70 |
+
# Returns dict of last saved checkpoint.
|
| 71 |
+
train.get_checkpoint()
|
| 72 |
+
|
| 73 |
+
# Returns the Dataset shard for the given key.
|
| 74 |
+
train.get_dataset_shard("my_dataset")
|
| 75 |
+
|
| 76 |
+
# Returns the total number of workers executing training.
|
| 77 |
+
train.get_context().get_world_size()
|
| 78 |
+
|
| 79 |
+
# Returns the rank of this worker.
|
| 80 |
+
train.get_context().get_world_rank()
|
| 81 |
+
|
| 82 |
+
# Returns the rank of the worker on the current node.
|
| 83 |
+
train.get_context().get_local_rank()
|
| 84 |
+
|
| 85 |
+
Any returns from the ``train_loop_per_worker`` will be discarded and not
|
| 86 |
+
used or persisted anywhere.
|
| 87 |
+
|
| 88 |
+
**How do I use DataParallelTrainer or any of its subclasses?**
|
| 89 |
+
|
| 90 |
+
Example:
|
| 91 |
+
|
| 92 |
+
.. testcode::
|
| 93 |
+
|
| 94 |
+
import ray
|
| 95 |
+
from ray import train
|
| 96 |
+
from ray.train import ScalingConfig
|
| 97 |
+
from ray.train.data_parallel_trainer import DataParallelTrainer
|
| 98 |
+
|
| 99 |
+
def train_loop_for_worker():
|
| 100 |
+
dataset_shard_for_this_worker = train.get_dataset_shard("train")
|
| 101 |
+
|
| 102 |
+
# 3 items for 3 workers, each worker gets 1 item
|
| 103 |
+
batches = list(dataset_shard_for_this_worker.iter_batches(batch_size=1))
|
| 104 |
+
assert len(batches) == 1
|
| 105 |
+
|
| 106 |
+
train_dataset = ray.data.from_items([1, 2, 3])
|
| 107 |
+
assert train_dataset.count() == 3
|
| 108 |
+
trainer = DataParallelTrainer(
|
| 109 |
+
train_loop_for_worker,
|
| 110 |
+
scaling_config=ScalingConfig(num_workers=3),
|
| 111 |
+
datasets={"train": train_dataset},
|
| 112 |
+
)
|
| 113 |
+
result = trainer.fit()
|
| 114 |
+
|
| 115 |
+
.. testoutput::
|
| 116 |
+
:hide:
|
| 117 |
+
|
| 118 |
+
...
|
| 119 |
+
|
| 120 |
+
**How do I develop on top of DataParallelTrainer?**
|
| 121 |
+
|
| 122 |
+
In many cases, using DataParallelTrainer directly is sufficient to execute
|
| 123 |
+
functions on multiple actors.
|
| 124 |
+
|
| 125 |
+
However, you may want to subclass ``DataParallelTrainer`` and create a custom
|
| 126 |
+
Trainer for the following 2 use cases:
|
| 127 |
+
|
| 128 |
+
- **Use Case 1:** You want to do data parallel training, but want to have
|
| 129 |
+
a predefined ``training_loop_per_worker``.
|
| 130 |
+
|
| 131 |
+
- **Use Case 2:** You want to implement a custom
|
| 132 |
+
:py:class:`~ray.train.backend.Backend` that automatically handles
|
| 133 |
+
additional setup or teardown logic on each actor, so that the users of this
|
| 134 |
+
new trainer do not have to implement this logic. For example, a
|
| 135 |
+
``TensorflowTrainer`` can be built on top of ``DataParallelTrainer``
|
| 136 |
+
that automatically handles setting the proper environment variables for
|
| 137 |
+
distributed Tensorflow on each actor.
|
| 138 |
+
|
| 139 |
+
For 1, you can set a predefined training loop in __init__
|
| 140 |
+
|
| 141 |
+
.. testcode::
|
| 142 |
+
|
| 143 |
+
from ray.train.data_parallel_trainer import DataParallelTrainer
|
| 144 |
+
|
| 145 |
+
class MyDataParallelTrainer(DataParallelTrainer):
|
| 146 |
+
def __init__(self, *args, **kwargs):
|
| 147 |
+
predefined_train_loop_per_worker = lambda: 1
|
| 148 |
+
super().__init__(predefined_train_loop_per_worker, *args, **kwargs)
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
For 2, you can implement the ``ray.train.Backend`` and ``ray.train.BackendConfig``
|
| 152 |
+
interfaces.
|
| 153 |
+
|
| 154 |
+
.. testcode::
|
| 155 |
+
|
| 156 |
+
from dataclasses import dataclass
|
| 157 |
+
from ray.train.backend import Backend, BackendConfig
|
| 158 |
+
|
| 159 |
+
class MyBackend(Backend):
|
| 160 |
+
def on_start(self, worker_group, backend_config):
|
| 161 |
+
def set_env_var(env_var_value):
|
| 162 |
+
import os
|
| 163 |
+
os.environ["MY_ENV_VAR"] = env_var_value
|
| 164 |
+
|
| 165 |
+
worker_group.execute(set_env_var, backend_config.env_var)
|
| 166 |
+
|
| 167 |
+
@dataclass
|
| 168 |
+
class MyBackendConfig(BackendConfig):
|
| 169 |
+
env_var: str = "default_value"
|
| 170 |
+
|
| 171 |
+
def backend_cls(self):
|
| 172 |
+
return MyBackend
|
| 173 |
+
|
| 174 |
+
class MyTrainer(DataParallelTrainer):
|
| 175 |
+
def __init__(self, train_loop_per_worker, my_backend_config:
|
| 176 |
+
MyBackendConfig, **kwargs):
|
| 177 |
+
|
| 178 |
+
super().__init__(
|
| 179 |
+
train_loop_per_worker,
|
| 180 |
+
backend_config=my_backend_config, **kwargs)
|
| 181 |
+
|
| 182 |
+
Args:
|
| 183 |
+
train_loop_per_worker: The training function to execute.
|
| 184 |
+
This can either take in no arguments or a ``config`` dict.
|
| 185 |
+
train_loop_config: Configurations to pass into
|
| 186 |
+
``train_loop_per_worker`` if it accepts an argument.
|
| 187 |
+
backend_config: Configuration for setting up a Backend (e.g. Torch,
|
| 188 |
+
Tensorflow, Horovod) on each worker to enable distributed
|
| 189 |
+
communication. If no Backend should be set up, then set this to None.
|
| 190 |
+
scaling_config: Configuration for how to scale data parallel training.
|
| 191 |
+
dataset_config: Configuration for dataset ingest. This is merged with the
|
| 192 |
+
default dataset config for the given trainer (`cls._dataset_config`).
|
| 193 |
+
run_config: Configuration for the execution of the training run.
|
| 194 |
+
datasets: Ray Datasets to use for training and evaluation.
|
| 195 |
+
This is a dict where the key is the name of the dataset, which
|
| 196 |
+
can be accessed from within the ``train_loop_per_worker`` by calling
|
| 197 |
+
``train.get_dataset_shard(dataset_key)``.
|
| 198 |
+
By default, all datasets are sharded equally across workers.
|
| 199 |
+
This can be configured via ``dataset_config``.
|
| 200 |
+
metadata: Dict that should be made available via
|
| 201 |
+
`train.get_context().get_metadata()` and in `checkpoint.get_metadata()`
|
| 202 |
+
for checkpoints saved from this Trainer. Must be JSON-serializable.
|
| 203 |
+
resume_from_checkpoint: A checkpoint to resume training from.
|
| 204 |
+
"""
|
| 205 |
+
|
| 206 |
+
# Exposed here for testing purposes. Should never need
|
| 207 |
+
# to be overriden.
|
| 208 |
+
_backend_executor_cls: Type[BackendExecutor] = BackendExecutor
|
| 209 |
+
_training_iterator_cls: Type[TrainingIterator] = TrainingIterator
|
| 210 |
+
|
| 211 |
+
_scaling_config_allowed_keys = BaseTrainer._scaling_config_allowed_keys + [
|
| 212 |
+
"num_workers",
|
| 213 |
+
"resources_per_worker",
|
| 214 |
+
"use_gpu",
|
| 215 |
+
"placement_strategy",
|
| 216 |
+
"accelerator_type",
|
| 217 |
+
]
|
| 218 |
+
|
| 219 |
+
# For backwards compatibility with the legacy dataset config API.
|
| 220 |
+
_dataset_config = None
|
| 221 |
+
|
| 222 |
+
_fields_for_tuner_param_space = BaseTrainer._fields_for_tuner_param_space + [
|
| 223 |
+
"train_loop_config"
|
| 224 |
+
]
|
| 225 |
+
|
| 226 |
+
def __init__(
|
| 227 |
+
self,
|
| 228 |
+
train_loop_per_worker: Union[Callable[[], None], Callable[[Dict], None]],
|
| 229 |
+
*,
|
| 230 |
+
train_loop_config: Optional[Dict] = None,
|
| 231 |
+
backend_config: Optional[BackendConfig] = None,
|
| 232 |
+
scaling_config: Optional[ScalingConfig] = None,
|
| 233 |
+
dataset_config: Optional[DataConfig] = None,
|
| 234 |
+
run_config: Optional[RunConfig] = None,
|
| 235 |
+
datasets: Optional[Dict[str, GenDataset]] = None,
|
| 236 |
+
metadata: Optional[Dict[str, Any]] = None,
|
| 237 |
+
resume_from_checkpoint: Optional[Checkpoint] = None,
|
| 238 |
+
):
|
| 239 |
+
self._train_loop_per_worker = train_loop_per_worker
|
| 240 |
+
self._train_loop_config = train_loop_config
|
| 241 |
+
|
| 242 |
+
if dataset_config is None:
|
| 243 |
+
dataset_config = DataConfig()
|
| 244 |
+
|
| 245 |
+
if not isinstance(dataset_config, DataConfig):
|
| 246 |
+
raise ValueError(
|
| 247 |
+
"`dataset_config` must be an instance of ray.train.DataConfig, "
|
| 248 |
+
f"was: {dataset_config}"
|
| 249 |
+
)
|
| 250 |
+
self._data_config = dataset_config
|
| 251 |
+
|
| 252 |
+
backend_config = (
|
| 253 |
+
backend_config if backend_config is not None else BackendConfig()
|
| 254 |
+
)
|
| 255 |
+
self._backend_config = backend_config
|
| 256 |
+
|
| 257 |
+
super(DataParallelTrainer, self).__init__(
|
| 258 |
+
scaling_config=scaling_config,
|
| 259 |
+
run_config=run_config,
|
| 260 |
+
datasets=datasets,
|
| 261 |
+
metadata=metadata,
|
| 262 |
+
resume_from_checkpoint=resume_from_checkpoint,
|
| 263 |
+
)
|
| 264 |
+
|
| 265 |
+
train_total_resources = self.scaling_config.total_resources
|
| 266 |
+
self._data_config.set_train_total_resources(
|
| 267 |
+
train_total_resources.get("CPU", 0),
|
| 268 |
+
train_total_resources.get("GPU", 0),
|
| 269 |
+
)
|
| 270 |
+
|
| 271 |
+
if env_integer(RAY_TRAIN_ENABLE_STATE_TRACKING, 0):
|
| 272 |
+
from ray.train._internal.state.state_actor import get_or_create_state_actor
|
| 273 |
+
|
| 274 |
+
get_or_create_state_actor()
|
| 275 |
+
|
| 276 |
+
@PublicAPI(stability="beta")
|
| 277 |
+
@classmethod
|
| 278 |
+
def restore(
|
| 279 |
+
cls: Type["DataParallelTrainer"],
|
| 280 |
+
path: str,
|
| 281 |
+
train_loop_per_worker: Optional[
|
| 282 |
+
Union[Callable[[], None], Callable[[Dict], None]]
|
| 283 |
+
] = None,
|
| 284 |
+
train_loop_config: Optional[Dict] = None,
|
| 285 |
+
**kwargs,
|
| 286 |
+
) -> "DataParallelTrainer":
|
| 287 |
+
"""Restores a DataParallelTrainer from a previously interrupted/failed run.
|
| 288 |
+
|
| 289 |
+
Args:
|
| 290 |
+
train_loop_per_worker: Optionally re-specified train loop function.
|
| 291 |
+
This should be used to re-specify a function that is not
|
| 292 |
+
restorable in a new Ray cluster (e.g., it holds onto outdated
|
| 293 |
+
object references). This should be the same training loop
|
| 294 |
+
that was passed to the original trainer constructor.
|
| 295 |
+
train_loop_config: Optionally re-specified train config.
|
| 296 |
+
This should similarly be used if the original `train_loop_config`
|
| 297 |
+
contained outdated object references, and it should not be modified
|
| 298 |
+
from what was originally passed in.
|
| 299 |
+
|
| 300 |
+
See :meth:`BaseTrainer.restore() <ray.train.trainer.BaseTrainer.restore>`
|
| 301 |
+
for descriptions of the other arguments.
|
| 302 |
+
|
| 303 |
+
Returns:
|
| 304 |
+
DataParallelTrainer: A restored instance of the `DataParallelTrainer`
|
| 305 |
+
subclass that is calling this method.
|
| 306 |
+
"""
|
| 307 |
+
return super(DataParallelTrainer, cls).restore(
|
| 308 |
+
path=path,
|
| 309 |
+
train_loop_per_worker=train_loop_per_worker,
|
| 310 |
+
train_loop_config=train_loop_config,
|
| 311 |
+
**kwargs,
|
| 312 |
+
)
|
| 313 |
+
|
| 314 |
+
def _validate_attributes(self):
|
| 315 |
+
super()._validate_attributes()
|
| 316 |
+
|
| 317 |
+
self._validate_train_loop_per_worker(
|
| 318 |
+
self._train_loop_per_worker, "train_loop_per_worker"
|
| 319 |
+
)
|
| 320 |
+
|
| 321 |
+
def _validate_train_loop_per_worker(
|
| 322 |
+
self, train_loop_per_worker: Callable, fn_name: str
|
| 323 |
+
) -> None:
|
| 324 |
+
num_required_params = count_required_parameters(train_loop_per_worker)
|
| 325 |
+
if num_required_params > 1:
|
| 326 |
+
raise ValueError(
|
| 327 |
+
f"{fn_name} should take in 0 or 1 arguments, "
|
| 328 |
+
f"but it accepts {num_required_params} arguments instead."
|
| 329 |
+
)
|
| 330 |
+
|
| 331 |
+
@classmethod
|
| 332 |
+
def _validate_scaling_config(cls, scaling_config: ScalingConfig) -> ScalingConfig:
|
| 333 |
+
scaling_config = super(DataParallelTrainer, cls)._validate_scaling_config(
|
| 334 |
+
scaling_config
|
| 335 |
+
)
|
| 336 |
+
|
| 337 |
+
# This validation happens after the scaling config is updated from
|
| 338 |
+
# its specification in the Tuner `param_space`
|
| 339 |
+
if not scaling_config.use_gpu and "GPU" in ray.available_resources():
|
| 340 |
+
logger.info(
|
| 341 |
+
"GPUs are detected in your Ray cluster, but GPU "
|
| 342 |
+
"training is not enabled for this trainer. To enable "
|
| 343 |
+
"GPU training, make sure to set `use_gpu` to True "
|
| 344 |
+
"in your scaling config."
|
| 345 |
+
)
|
| 346 |
+
|
| 347 |
+
if scaling_config.num_workers is None:
|
| 348 |
+
raise ValueError(
|
| 349 |
+
"You must specify the 'num_workers' in `scaling_config` as either an "
|
| 350 |
+
f"argument of `{cls.__name__}` or through the `param_space` of a "
|
| 351 |
+
"`Tuner` (if performing hyperparameter tuning)."
|
| 352 |
+
)
|
| 353 |
+
|
| 354 |
+
if scaling_config.num_workers <= 0:
|
| 355 |
+
raise ValueError(
|
| 356 |
+
"'num_workers' in `scaling_config` must be a positive "
|
| 357 |
+
f"integer. Received {scaling_config.num_workers}"
|
| 358 |
+
)
|
| 359 |
+
|
| 360 |
+
return scaling_config
|
| 361 |
+
|
| 362 |
+
def _run_training(self, training_iterator: TrainingIterator) -> None:
|
| 363 |
+
"""This method loops over the `TrainingIterator`:
|
| 364 |
+
The actual iteration (for ... in ...) waits for the training function
|
| 365 |
+
on each worker to report a result and supplies it as a list of results.
|
| 366 |
+
Afterwards (in the body of the loop), it will report the result
|
| 367 |
+
to the Tune session.
|
| 368 |
+
The iterator ends after the training function on each worker has finished.
|
| 369 |
+
"""
|
| 370 |
+
for training_results in training_iterator:
|
| 371 |
+
# TODO(ml-team): add ability to report results from multiple workers.
|
| 372 |
+
self._propagate_results(training_results)
|
| 373 |
+
|
| 374 |
+
def _propagate_results(self, training_results: List[_TrainingResult]):
|
| 375 |
+
first_worker_result = training_results[0]
|
| 376 |
+
assert all(isinstance(result, _TrainingResult) for result in training_results)
|
| 377 |
+
|
| 378 |
+
tune_session = get_session()
|
| 379 |
+
|
| 380 |
+
# Check if any workers reported a checkpoint.
|
| 381 |
+
# If so, report a checkpoint pointing to the persisted location
|
| 382 |
+
# to Tune for book-keeping.
|
| 383 |
+
# NOTE: This removes the restriction for any individual worker
|
| 384 |
+
# (ex: global rank 0 worker) from needing to report a checkpoint.
|
| 385 |
+
# All workers reported a checkpoint to the same fs path, so there's
|
| 386 |
+
# no need to report multiple checkpoints to Tune.
|
| 387 |
+
worker_checkpoints = [
|
| 388 |
+
result.checkpoint
|
| 389 |
+
for result in training_results
|
| 390 |
+
if result.checkpoint is not None
|
| 391 |
+
]
|
| 392 |
+
at_least_one_reported_checkpoint = len(worker_checkpoints) > 0
|
| 393 |
+
|
| 394 |
+
if at_least_one_reported_checkpoint:
|
| 395 |
+
# Update the coordinator's checkpoint index to the latest.
|
| 396 |
+
# This is what keeps the checkpoint index in line with the workers.
|
| 397 |
+
tune_session.storage._update_checkpoint_index(first_worker_result.metrics)
|
| 398 |
+
|
| 399 |
+
# Make sure that all workers uploaded to the same location.
|
| 400 |
+
assert all(
|
| 401 |
+
checkpoint.path == tune_session.storage.checkpoint_fs_path
|
| 402 |
+
for checkpoint in worker_checkpoints
|
| 403 |
+
)
|
| 404 |
+
|
| 405 |
+
checkpoint = (
|
| 406 |
+
Checkpoint(
|
| 407 |
+
filesystem=tune_session.storage.storage_filesystem,
|
| 408 |
+
path=tune_session.storage.checkpoint_fs_path,
|
| 409 |
+
)
|
| 410 |
+
if at_least_one_reported_checkpoint
|
| 411 |
+
else None
|
| 412 |
+
)
|
| 413 |
+
|
| 414 |
+
tracked_training_result = _TrainingResult(
|
| 415 |
+
checkpoint=checkpoint,
|
| 416 |
+
metrics=first_worker_result.metrics,
|
| 417 |
+
)
|
| 418 |
+
|
| 419 |
+
logger.debug(
|
| 420 |
+
"Report (metrics, checkpoint) to the Tune session:\n"
|
| 421 |
+
f" metrics={tracked_training_result.metrics}\n"
|
| 422 |
+
f" checkpoint={tracked_training_result.checkpoint}"
|
| 423 |
+
)
|
| 424 |
+
|
| 425 |
+
# Report the metrics and checkpoint to Tune.
|
| 426 |
+
tune_session._report_training_result(tracked_training_result)
|
| 427 |
+
|
| 428 |
+
def training_loop(self) -> None:
|
| 429 |
+
scaling_config = self._validate_scaling_config(self.scaling_config)
|
| 430 |
+
|
| 431 |
+
train_loop_per_worker = construct_train_func(
|
| 432 |
+
self._train_loop_per_worker,
|
| 433 |
+
self._train_loop_config,
|
| 434 |
+
train_func_context=self._backend_config.train_func_context,
|
| 435 |
+
fn_arg_name="train_loop_per_worker",
|
| 436 |
+
discard_returns=True,
|
| 437 |
+
)
|
| 438 |
+
|
| 439 |
+
trial_info = TrialInfo(
|
| 440 |
+
name=session.get_trial_name(),
|
| 441 |
+
id=session.get_trial_id(),
|
| 442 |
+
resources=session.get_trial_resources(),
|
| 443 |
+
logdir=session.get_trial_dir(),
|
| 444 |
+
driver_ip=ray.util.get_node_ip_address(),
|
| 445 |
+
driver_node_id=ray.get_runtime_context().get_node_id(),
|
| 446 |
+
experiment_name=session.get_experiment_name(),
|
| 447 |
+
run_id=uuid.uuid4().hex,
|
| 448 |
+
)
|
| 449 |
+
|
| 450 |
+
backend_executor = self._backend_executor_cls(
|
| 451 |
+
backend_config=self._backend_config,
|
| 452 |
+
trial_info=trial_info,
|
| 453 |
+
num_workers=scaling_config.num_workers,
|
| 454 |
+
resources_per_worker=scaling_config._resources_per_worker_not_none,
|
| 455 |
+
max_retries=0,
|
| 456 |
+
)
|
| 457 |
+
|
| 458 |
+
# Start the remote actors.
|
| 459 |
+
backend_executor.start()
|
| 460 |
+
|
| 461 |
+
training_iterator = self._training_iterator_cls(
|
| 462 |
+
backend_executor=backend_executor,
|
| 463 |
+
backend_config=self._backend_config,
|
| 464 |
+
train_func=train_loop_per_worker,
|
| 465 |
+
datasets=self.datasets,
|
| 466 |
+
metadata=self.metadata,
|
| 467 |
+
data_config=self._data_config,
|
| 468 |
+
checkpoint=self.starting_checkpoint,
|
| 469 |
+
)
|
| 470 |
+
|
| 471 |
+
self._run_training(training_iterator)
|
| 472 |
+
|
| 473 |
+
# Shutdown workers.
|
| 474 |
+
backend_executor.shutdown()
|
| 475 |
+
|
| 476 |
+
def get_dataset_config(self) -> DataConfig:
|
| 477 |
+
"""Returns a copy of this Trainer's final dataset configs.
|
| 478 |
+
|
| 479 |
+
Returns:
|
| 480 |
+
The merged default + user-supplied dataset config.
|
| 481 |
+
"""
|
| 482 |
+
|
| 483 |
+
return self._data_config
|
| 484 |
+
|
| 485 |
+
@repr_with_fallback(["ipywidgets", "8"])
|
| 486 |
+
def _repr_mimebundle_(self, **kwargs):
|
| 487 |
+
"""Returns a mimebundle with an ipywidget repr and a simple text repr.
|
| 488 |
+
|
| 489 |
+
Depending on the frontend where the data is being displayed,
|
| 490 |
+
different mimetypes will be used from this bundle.
|
| 491 |
+
See https://ipython.readthedocs.io/en/stable/config/integrating.html
|
| 492 |
+
for information about this method, and
|
| 493 |
+
https://ipywidgets.readthedocs.io/en/latest/embedding.html
|
| 494 |
+
for more information about the jupyter widget mimetype.
|
| 495 |
+
|
| 496 |
+
Returns:
|
| 497 |
+
A mimebundle containing an ipywidget repr and a simple text repr.
|
| 498 |
+
"""
|
| 499 |
+
from ipywidgets import HTML, Layout, Tab, VBox
|
| 500 |
+
|
| 501 |
+
title = HTML(f"<h2>{self.__class__.__name__}</h2>")
|
| 502 |
+
|
| 503 |
+
children = []
|
| 504 |
+
titles = []
|
| 505 |
+
|
| 506 |
+
if self.datasets:
|
| 507 |
+
children.append(self._datasets_repr_())
|
| 508 |
+
titles.append("Datasets")
|
| 509 |
+
|
| 510 |
+
children.append(HTML(self._data_config_repr_html_()))
|
| 511 |
+
titles.append("Data Config")
|
| 512 |
+
|
| 513 |
+
if self._train_loop_config:
|
| 514 |
+
children.append(HTML(self._train_loop_config_repr_html_()))
|
| 515 |
+
titles.append("Train Loop Config")
|
| 516 |
+
|
| 517 |
+
if self.scaling_config:
|
| 518 |
+
children.append(HTML(self.scaling_config._repr_html_()))
|
| 519 |
+
titles.append("Scaling Config")
|
| 520 |
+
|
| 521 |
+
if self.run_config:
|
| 522 |
+
children.append(HTML(self.run_config._repr_html_()))
|
| 523 |
+
titles.append("Run Config")
|
| 524 |
+
|
| 525 |
+
if self._backend_config:
|
| 526 |
+
children.append(HTML(self._backend_config._repr_html_()))
|
| 527 |
+
titles.append("Backend Config")
|
| 528 |
+
|
| 529 |
+
tab = Tab(children, titles=titles)
|
| 530 |
+
widget = VBox([title, tab], layout=Layout(width="100%"))
|
| 531 |
+
bundle = widget._repr_mimebundle_(**kwargs)
|
| 532 |
+
bundle.update(
|
| 533 |
+
{
|
| 534 |
+
"text/plain": repr(self),
|
| 535 |
+
}
|
| 536 |
+
)
|
| 537 |
+
return bundle
|
| 538 |
+
|
| 539 |
+
def _train_loop_config_repr_html_(self) -> str:
|
| 540 |
+
if self._train_loop_config:
|
| 541 |
+
table_data = {}
|
| 542 |
+
for k, v in self._train_loop_config.items():
|
| 543 |
+
if isinstance(v, str) or str(v).isnumeric():
|
| 544 |
+
table_data[k] = v
|
| 545 |
+
elif hasattr(v, "_repr_html_"):
|
| 546 |
+
table_data[k] = v._repr_html_()
|
| 547 |
+
else:
|
| 548 |
+
table_data[k] = str(v)
|
| 549 |
+
|
| 550 |
+
return Template("title_data.html.j2").render(
|
| 551 |
+
title="Train Loop Config",
|
| 552 |
+
data=Template("scrollableTable.html.j2").render(
|
| 553 |
+
table=tabulate(
|
| 554 |
+
table_data.items(),
|
| 555 |
+
headers=["Setting", "Value"],
|
| 556 |
+
showindex=False,
|
| 557 |
+
tablefmt="unsafehtml",
|
| 558 |
+
),
|
| 559 |
+
max_height="none",
|
| 560 |
+
),
|
| 561 |
+
)
|
| 562 |
+
else:
|
| 563 |
+
return ""
|
| 564 |
+
|
| 565 |
+
def _data_config_repr_html_(self) -> str:
|
| 566 |
+
# TODO make this rendering nicer.
|
| 567 |
+
content = [str(self._data_config)]
|
| 568 |
+
return Template("rendered_html_common.html.j2").render(content=content)
|
| 569 |
+
|
| 570 |
+
def _datasets_repr_(self) -> str:
|
| 571 |
+
from ipywidgets import HTML, Layout, VBox
|
| 572 |
+
|
| 573 |
+
content = []
|
| 574 |
+
if self.datasets:
|
| 575 |
+
for name, config in self.datasets.items():
|
| 576 |
+
tab = config._tab_repr_()
|
| 577 |
+
if tab:
|
| 578 |
+
content.append(
|
| 579 |
+
HTML(
|
| 580 |
+
Template("title_data.html.j2").render(
|
| 581 |
+
title=f"Dataset - <code>{name}</code>", data=None
|
| 582 |
+
)
|
| 583 |
+
)
|
| 584 |
+
)
|
| 585 |
+
content.append(config._tab_repr_())
|
| 586 |
+
|
| 587 |
+
return VBox(content, layout=Layout(width="100%"))
|
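The value-rendering rule in `_train_loop_config_repr_html_` above can be shown with a small standalone sketch (illustrative only, not part of the uploaded file): string and numeric-looking values are kept as-is, objects that expose `_repr_html_()` use that representation, and everything else falls back to `str()`.

# Minimal sketch of the rendering rule used above (illustrative only).
def render_value(v):
    if isinstance(v, str) or str(v).isnumeric():
        return v
    elif hasattr(v, "_repr_html_"):
        return v._repr_html_()
    else:
        return str(v)

print(render_value(3))             # 3 (numeric-looking, kept as-is)
print(render_value("adam"))        # adam (string, kept as-is)
print(render_value([1e-3, 1e-4]))  # [0.001, 0.0001] (falls back to str)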
.venv/lib/python3.11/site-packages/ray/train/error.py
ADDED
|
@@ -0,0 +1,6 @@
|
| 1 |
+
from ray.util.annotations import PublicAPI
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
@PublicAPI(stability="beta")
|
| 5 |
+
class SessionMisuseError(Exception):
|
| 6 |
+
"""Indicates a method or function was used outside of a session."""
|
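A tiny illustrative sketch (not part of the uploaded file): `SessionMisuseError` is a plain `Exception` subclass, so it can be raised and caught like any other error; the message text below is made up.

from ray.train.error import SessionMisuseError

try:
    # Hypothetical message; real session utilities construct their own.
    raise SessionMisuseError("train.report() was called outside of a session")
except SessionMisuseError as exc:
    print(f"caught {type(exc).__name__}: {exc}")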
.venv/lib/python3.11/site-packages/ray/train/examples/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/ray/train/examples/mlflow_simple_example.py
ADDED
|
@@ -0,0 +1,55 @@
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
|
| 3 |
+
from ray import train
|
| 4 |
+
from ray.train import RunConfig, ScalingConfig
|
| 5 |
+
from ray.train.torch import TorchTrainer
|
| 6 |
+
from ray.tune.logger import TBXLoggerCallback
|
| 7 |
+
from ray.tune.logger.mlflow import MLflowLoggerCallback
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def train_func():
|
| 11 |
+
for i in range(3):
|
| 12 |
+
train.report(dict(epoch=i))
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
trainer = TorchTrainer(
|
| 16 |
+
train_func,
|
| 17 |
+
scaling_config=ScalingConfig(num_workers=2),
|
| 18 |
+
run_config=RunConfig(
|
| 19 |
+
callbacks=[
|
| 20 |
+
MLflowLoggerCallback(experiment_name="train_experiment"),
|
| 21 |
+
TBXLoggerCallback(),
|
| 22 |
+
],
|
| 23 |
+
),
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
# Run the training function, logging all the intermediate results
|
| 27 |
+
# to MLflow and Tensorboard.
|
| 28 |
+
result = trainer.fit()
|
| 29 |
+
|
| 30 |
+
# For MLflow logs:
|
| 31 |
+
|
| 32 |
+
# MLflow logs will by default be saved in an `mlflow` directory
|
| 33 |
+
# in the current working directory.
|
| 34 |
+
|
| 35 |
+
# $ cd mlflow
|
| 36 |
+
# # View the MLflow UI.
|
| 37 |
+
# $ mlflow ui
|
| 38 |
+
|
| 39 |
+
# You can change the directory by setting the `tracking_uri` argument
|
| 40 |
+
# in `MLflowLoggerCallback`.
|
| 41 |
+
|
| 42 |
+
# For TensorBoard logs:
|
| 43 |
+
|
| 44 |
+
# Print the latest run directory and keep note of it.
|
| 45 |
+
# For example: /home/ubuntu/ray_results/TorchTrainer_2022-06-13_20-31-06
|
| 46 |
+
print("Run directory:", Path(result.path).parent) # TensorBoard is saved in parent dir
|
| 47 |
+
|
| 48 |
+
# How to visualize the logs
|
| 49 |
+
|
| 50 |
+
# Navigate to the run directory of the trainer.
|
| 51 |
+
# For example `cd /home/ubuntu/ray_results/TorchTrainer_2022-06-13_20-31-06`
|
| 52 |
+
# $ cd <TRAINER_RUN_DIR>
|
| 53 |
+
#
|
| 54 |
+
# # View the tensorboard UI.
|
| 55 |
+
# $ tensorboard --logdir .
|
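Following up on the `tracking_uri` note in the comments above, a minimal sketch of pointing the MLflow callback at a custom location could look like this; the path is illustrative, and any URI supported by MLflow should work.

from ray.tune.logger.mlflow import MLflowLoggerCallback

# Hypothetical local path; replace with your own MLflow tracking location.
mlflow_callback = MLflowLoggerCallback(
    experiment_name="train_experiment",
    tracking_uri="file:///tmp/my_mlflow_logs",
)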
.venv/lib/python3.11/site-packages/ray/train/examples/tf/tune_tensorflow_autoencoder_example.py
ADDED
|
@@ -0,0 +1,77 @@
|
| 1 |
+
import argparse
|
| 2 |
+
|
| 3 |
+
import ray
|
| 4 |
+
from ray import tune
|
| 5 |
+
from ray.train import ScalingConfig
|
| 6 |
+
from ray.train.examples.tf.tensorflow_mnist_example import train_func
|
| 7 |
+
from ray.train.tensorflow import TensorflowTrainer
|
| 8 |
+
from ray.tune.tune_config import TuneConfig
|
| 9 |
+
from ray.tune.tuner import Tuner
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def tune_tensorflow_mnist(
|
| 13 |
+
num_workers: int = 2, num_samples: int = 2, use_gpu: bool = False
|
| 14 |
+
):
|
| 15 |
+
scaling_config = ScalingConfig(num_workers=num_workers, use_gpu=use_gpu)
|
| 16 |
+
trainer = TensorflowTrainer(
|
| 17 |
+
train_loop_per_worker=train_func,
|
| 18 |
+
scaling_config=scaling_config,
|
| 19 |
+
)
|
| 20 |
+
tuner = Tuner(
|
| 21 |
+
trainer,
|
| 22 |
+
tune_config=TuneConfig(
|
| 23 |
+
num_samples=num_samples, metric="binary_crossentropy", mode="min"
|
| 24 |
+
),
|
| 25 |
+
param_space={
|
| 26 |
+
"train_loop_config": {
|
| 27 |
+
"lr": tune.loguniform(1e-4, 1e-1),
|
| 28 |
+
"batch_size": tune.choice([32, 64, 128]),
|
| 29 |
+
"epochs": 3,
|
| 30 |
+
}
|
| 31 |
+
},
|
| 32 |
+
)
|
| 33 |
+
best_accuracy = tuner.fit().get_best_result().metrics["binary_crossentropy"]
|
| 34 |
+
print(f"Best accuracy config: {best_accuracy}")
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
if __name__ == "__main__":
|
| 38 |
+
parser = argparse.ArgumentParser()
|
| 39 |
+
parser.add_argument(
|
| 40 |
+
"--smoke-test",
|
| 41 |
+
action="store_true",
|
| 42 |
+
default=False,
|
| 43 |
+
help="Finish quickly for testing.",
|
| 44 |
+
)
|
| 45 |
+
parser.add_argument(
|
| 46 |
+
"--address", required=False, type=str, help="the address to use for Ray"
|
| 47 |
+
)
|
| 48 |
+
parser.add_argument(
|
| 49 |
+
"--num-workers",
|
| 50 |
+
"-n",
|
| 51 |
+
type=int,
|
| 52 |
+
default=2,
|
| 53 |
+
help="Sets number of workers for training.",
|
| 54 |
+
)
|
| 55 |
+
parser.add_argument(
|
| 56 |
+
"--num-samples",
|
| 57 |
+
type=int,
|
| 58 |
+
default=2,
|
| 59 |
+
help="Sets number of samples for training.",
|
| 60 |
+
)
|
| 61 |
+
parser.add_argument(
|
| 62 |
+
"--use-gpu", action="store_true", default=False, help="Enables GPU training"
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
args = parser.parse_args()
|
| 66 |
+
|
| 67 |
+
if args.smoke_test:
|
| 68 |
+
num_gpus = args.num_workers if args.use_gpu else 0
|
| 69 |
+
ray.init(num_cpus=8, num_gpus=num_gpus)
|
| 70 |
+
tune_tensorflow_mnist(num_workers=2, num_samples=2, use_gpu=args.use_gpu)
|
| 71 |
+
else:
|
| 72 |
+
ray.init(address=args.address)
|
| 73 |
+
tune_tensorflow_mnist(
|
| 74 |
+
num_workers=args.num_workers,
|
| 75 |
+
num_samples=args.num_samples,
|
| 76 |
+
use_gpu=args.use_gpu,
|
| 77 |
+
)
|
.venv/lib/python3.11/site-packages/ray/train/huggingface/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/ray/train/huggingface/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (194 Bytes).
|
|
|
.venv/lib/python3.11/site-packages/ray/train/huggingface/transformers/__init__.py
ADDED
|
@@ -0,0 +1,12 @@
|
| 1 |
+
from ray.train.huggingface.transformers._transformers_utils import (
|
| 2 |
+
RayTrainReportCallback,
|
| 3 |
+
prepare_trainer,
|
| 4 |
+
)
|
| 5 |
+
|
| 6 |
+
__all__ = [
|
| 7 |
+
"RayTrainReportCallback",
|
| 8 |
+
"prepare_trainer",
|
| 9 |
+
]
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
# DO NOT ADD ANYTHING AFTER THIS LINE.
|
.venv/lib/python3.11/site-packages/ray/train/huggingface/transformers/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (410 Bytes).
|
|
|
.venv/lib/python3.11/site-packages/ray/train/huggingface/transformers/__pycache__/_transformers_utils.cpython-311.pyc
ADDED
|
Binary file (8.74 kB).
|
|
|
.venv/lib/python3.11/site-packages/ray/train/huggingface/transformers/_transformers_utils.py
ADDED
|
@@ -0,0 +1,143 @@
|
| 1 |
+
import logging
|
| 2 |
+
import shutil
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from tempfile import TemporaryDirectory
|
| 5 |
+
from typing import Iterator, Optional, Type
|
| 6 |
+
|
| 7 |
+
from torch.utils.data import DataLoader, Dataset, IterableDataset
|
| 8 |
+
|
| 9 |
+
import ray
|
| 10 |
+
from ray._private.usage.usage_lib import TagKey, record_extra_usage_tag
|
| 11 |
+
from ray.data.iterator import _IterableFromIterator
|
| 12 |
+
from ray.train import Checkpoint
|
| 13 |
+
from ray.util import PublicAPI
|
| 14 |
+
|
| 15 |
+
logger = logging.getLogger(__name__)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
TRANSFORMERS_IMPORT_ERROR: Optional[ImportError] = None
|
| 19 |
+
|
| 20 |
+
try:
|
| 21 |
+
import transformers.trainer
|
| 22 |
+
from transformers import Trainer
|
| 23 |
+
from transformers.trainer_callback import TrainerCallback
|
| 24 |
+
except ImportError as e:
|
| 25 |
+
TRANSFORMERS_IMPORT_ERROR = e
|
| 26 |
+
TrainerCallback = object
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@PublicAPI(stability="beta")
|
| 30 |
+
class RayTrainReportCallback(TrainerCallback):
|
| 31 |
+
"""A simple callback to report checkpoints and metrics to Ray Train.
|
| 32 |
+
|
| 33 |
+
This callback is a subclass of `transformers.TrainerCallback
|
| 34 |
+
<https://huggingface.co/docs/transformers/main/en/main_classes/callback#transformers.TrainerCallback>`_
|
| 35 |
+
and overrides the `TrainerCallback.on_save()` method. After
|
| 36 |
+
a new checkpoint gets saved, it fetches the latest metric dictionary
|
| 37 |
+
from `TrainerState.log_history` and reports it with the latest checkpoint
|
| 38 |
+
to Ray Train.
|
| 39 |
+
|
| 40 |
+
Checkpoints will be saved in the following structure::
|
| 41 |
+
|
| 42 |
+
checkpoint_00000*/ Ray Train Checkpoint
|
| 43 |
+
└─ checkpoint/ Hugging Face Transformers Checkpoint
|
| 44 |
+
|
| 45 |
+
For customized reporting and checkpointing logic, implement your own
|
| 46 |
+
`transformers.TrainerCallback` following this user
|
| 47 |
+
guide: :ref:`Saving and Loading Checkpoints <train-dl-saving-checkpoints>`.
|
| 48 |
+
|
| 49 |
+
Note that users should ensure that the logging, evaluation, and saving frequencies
|
| 50 |
+
are properly configured so that the monitoring metric is always up-to-date
|
| 51 |
+
when `transformers.Trainer` saves a checkpoint.
|
| 52 |
+
|
| 53 |
+
Suppose the monitoring metric is reported from evaluation stage:
|
| 54 |
+
|
| 55 |
+
Some valid configurations:
|
| 56 |
+
- evaluation_strategy == save_strategy == "epoch"
|
| 57 |
+
- evaluation_strategy == save_strategy == "steps", save_steps % eval_steps == 0
|
| 58 |
+
|
| 59 |
+
Some invalid configurations:
|
| 60 |
+
- evaluation_strategy != save_strategy
|
| 61 |
+
- evaluation_strategy == save_strategy == "steps", save_steps % eval_steps != 0
|
| 62 |
+
|
| 63 |
+
"""
|
| 64 |
+
|
| 65 |
+
CHECKPOINT_NAME = "checkpoint"
|
| 66 |
+
|
| 67 |
+
def __init__(self, *args, **kwargs):
|
| 68 |
+
super().__init__(*args, **kwargs)
|
| 69 |
+
record_extra_usage_tag(TagKey.TRAIN_TRANSFORMERS_RAYTRAINREPORTCALLBACK, "1")
|
| 70 |
+
|
| 71 |
+
def on_save(self, args, state, control, **kwargs):
|
| 72 |
+
"""Event called after a checkpoint save."""
|
| 73 |
+
with TemporaryDirectory() as tmpdir:
|
| 74 |
+
# Aggregate all the logged metrics
|
| 75 |
+
metrics = {}
|
| 76 |
+
for log in state.log_history:
|
| 77 |
+
metrics.update(log)
|
| 78 |
+
|
| 79 |
+
# Copy ckpt files and construct a Ray Train Checkpoint
|
| 80 |
+
source_ckpt_path = transformers.trainer.get_last_checkpoint(args.output_dir)
|
| 81 |
+
if source_ckpt_path is not None:
|
| 82 |
+
target_ckpt_path = Path(tmpdir, self.CHECKPOINT_NAME).as_posix()
|
| 83 |
+
shutil.copytree(source_ckpt_path, target_ckpt_path)
|
| 84 |
+
checkpoint = Checkpoint.from_directory(tmpdir)
|
| 85 |
+
else:
|
| 86 |
+
checkpoint = None
|
| 87 |
+
|
| 88 |
+
# Report latest metrics and checkpoint to Ray Train
|
| 89 |
+
ray.train.report(metrics=metrics, checkpoint=checkpoint)
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
class RayTorchIterableDataset(IterableDataset):
|
| 93 |
+
"""Wrapper class for ray data iterables."""
|
| 94 |
+
|
| 95 |
+
def __init__(self, data_iterable) -> None:
|
| 96 |
+
super().__init__()
|
| 97 |
+
self.data_iterable = data_iterable
|
| 98 |
+
|
| 99 |
+
def __iter__(self) -> Iterator:
|
| 100 |
+
return iter(self.data_iterable)
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
@PublicAPI(stability="beta")
|
| 104 |
+
def prepare_trainer(trainer: "Trainer") -> "Trainer":
|
| 105 |
+
"""Prepare your HuggingFace Transformer Trainer for Ray Train.
|
| 106 |
+
|
| 107 |
+
This utility function enables the trainer to integrate with Ray Data.
|
| 108 |
+
Internally, it overrides the `get_train_dataloader` and `get_eval_dataloader`
|
| 109 |
+
methods and injects the data integration logic if the `train_dataset` and
|
| 110 |
+
`eval_dataset` are Ray Data Iterables.
|
| 111 |
+
"""
|
| 112 |
+
|
| 113 |
+
if TRANSFORMERS_IMPORT_ERROR is not None:
|
| 114 |
+
raise TRANSFORMERS_IMPORT_ERROR
|
| 115 |
+
|
| 116 |
+
base_trainer_class: Type[transformers.trainer.Trainer] = trainer.__class__
|
| 117 |
+
|
| 118 |
+
class RayTransformersTrainer(base_trainer_class):
|
| 119 |
+
"""A Wrapper of `transformers.Trainer` for Ray Data Integration."""
|
| 120 |
+
|
| 121 |
+
def get_train_dataloader(self) -> DataLoader:
|
| 122 |
+
if isinstance(self.train_dataset, _IterableFromIterator):
|
| 123 |
+
dataset = RayTorchIterableDataset(self.train_dataset)
|
| 124 |
+
return DataLoader(dataset, batch_size=1, collate_fn=lambda x: x[0])
|
| 125 |
+
else:
|
| 126 |
+
return super().get_train_dataloader()
|
| 127 |
+
|
| 128 |
+
def get_eval_dataloader(
|
| 129 |
+
self, eval_dataset: Optional[Dataset] = None
|
| 130 |
+
) -> DataLoader:
|
| 131 |
+
if eval_dataset is None:
|
| 132 |
+
eval_dataset = self.eval_dataset
|
| 133 |
+
|
| 134 |
+
if isinstance(eval_dataset, _IterableFromIterator):
|
| 135 |
+
dataset = RayTorchIterableDataset(eval_dataset)
|
| 136 |
+
return DataLoader(dataset, batch_size=1, collate_fn=lambda x: x[0])
|
| 137 |
+
else:
|
| 138 |
+
return super().get_eval_dataloader(eval_dataset)
|
| 139 |
+
|
| 140 |
+
trainer.__class__ = RayTransformersTrainer
|
| 141 |
+
|
| 142 |
+
record_extra_usage_tag(TagKey.TRAIN_TRANSFORMERS_PREPARE_TRAINER, "1")
|
| 143 |
+
return trainer
|
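A rough usage sketch for the two utilities above, pieced together from the docstrings in this file: add the callback to an existing `transformers.Trainer`, wrap it with `prepare_trainer`, and run the whole thing inside a Ray `TorchTrainer`. The `build_hf_trainer()` helper is hypothetical and stands in for however you normally construct a `transformers.Trainer` (model, `TrainingArguments`, datasets); this is a sketch, not the canonical recipe.

import ray.train
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer
from ray.train.huggingface.transformers import (
    RayTrainReportCallback,
    prepare_trainer,
)


def train_func(config):
    # Placeholder: build a transformers.Trainer the usual way.
    hf_trainer = build_hf_trainer()  # hypothetical helper, not a real API

    # Report metrics and checkpoints to Ray Train after each HF checkpoint save.
    hf_trainer.add_callback(RayTrainReportCallback())

    # Patch the dataloaders so Ray Data iterables are handled transparently.
    hf_trainer = prepare_trainer(hf_trainer)
    hf_trainer.train()


trainer = TorchTrainer(
    train_func,
    scaling_config=ScalingConfig(num_workers=2, use_gpu=False),
)
result = trainer.fit()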
.venv/lib/python3.11/site-packages/ray/train/lightgbm/__init__.py
ADDED
|
@@ -0,0 +1,18 @@
|
| 1 |
+
from ray.train.lightgbm._lightgbm_utils import RayTrainReportCallback
|
| 2 |
+
from ray.train.lightgbm.lightgbm_checkpoint import LightGBMCheckpoint
|
| 3 |
+
from ray.train.lightgbm.lightgbm_predictor import LightGBMPredictor
|
| 4 |
+
from ray.train.lightgbm.lightgbm_trainer import LightGBMTrainer
|
| 5 |
+
from ray.train.v2._internal.constants import is_v2_enabled
|
| 6 |
+
|
| 7 |
+
if is_v2_enabled():
|
| 8 |
+
from ray.train.v2.lightgbm.lightgbm_trainer import LightGBMTrainer # noqa: F811
|
| 9 |
+
|
| 10 |
+
__all__ = [
|
| 11 |
+
"RayTrainReportCallback",
|
| 12 |
+
"LightGBMCheckpoint",
|
| 13 |
+
"LightGBMPredictor",
|
| 14 |
+
"LightGBMTrainer",
|
| 15 |
+
]
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
# DO NOT ADD ANYTHING AFTER THIS LINE.
|
.venv/lib/python3.11/site-packages/ray/train/lightgbm/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (814 Bytes).
|
|
|
.venv/lib/python3.11/site-packages/ray/train/lightgbm/__pycache__/_lightgbm_utils.cpython-311.pyc
ADDED
|
Binary file (8.97 kB).
|
|
|
.venv/lib/python3.11/site-packages/ray/train/lightgbm/__pycache__/config.cpython-311.pyc
ADDED
|
Binary file (5.3 kB).
|
|
|
.venv/lib/python3.11/site-packages/ray/train/lightgbm/__pycache__/lightgbm_checkpoint.cpython-311.pyc
ADDED
|
Binary file (4.03 kB).
|
|
|
.venv/lib/python3.11/site-packages/ray/train/lightgbm/__pycache__/lightgbm_predictor.cpython-311.pyc
ADDED
|
Binary file (7.27 kB).
|
|
|
.venv/lib/python3.11/site-packages/ray/train/lightgbm/__pycache__/lightgbm_trainer.cpython-311.pyc
ADDED
|
Binary file (11 kB).
|
|
|
.venv/lib/python3.11/site-packages/ray/train/lightgbm/__pycache__/v2.cpython-311.pyc
ADDED
|
Binary file (6.73 kB).
|
|
|
.venv/lib/python3.11/site-packages/ray/train/lightgbm/_lightgbm_utils.py
ADDED
|
@@ -0,0 +1,170 @@
|
| 1 |
+
import tempfile
|
| 2 |
+
from contextlib import contextmanager
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from typing import Callable, Dict, List, Optional, Union
|
| 5 |
+
|
| 6 |
+
from lightgbm.basic import Booster
|
| 7 |
+
from lightgbm.callback import CallbackEnv
|
| 8 |
+
|
| 9 |
+
import ray.train
|
| 10 |
+
from ray.train import Checkpoint
|
| 11 |
+
from ray.tune.utils import flatten_dict
|
| 12 |
+
from ray.util.annotations import PublicAPI
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@PublicAPI(stability="beta")
|
| 16 |
+
class RayTrainReportCallback:
|
| 17 |
+
"""Creates a callback that reports metrics and checkpoints model.
|
| 18 |
+
|
| 19 |
+
Args:
|
| 20 |
+
metrics: Metrics to report. If this is a list,
|
| 21 |
+
each item should be a metric key reported by LightGBM,
|
| 22 |
+
and it will be reported to Ray Train/Tune under the same name.
|
| 23 |
+
This can also be a dict of {<key-to-report>: <lightgbm-metric-key>},
|
| 24 |
+
which can be used to rename LightGBM default metrics.
|
| 25 |
+
filename: Customize the saved checkpoint file type by passing
|
| 26 |
+
a filename. Defaults to "model.txt".
|
| 27 |
+
frequency: How often to save checkpoints, in terms of iterations.
|
| 28 |
+
Defaults to 0 (no checkpoints are saved during training).
|
| 29 |
+
checkpoint_at_end: Whether or not to save a checkpoint at the end of training.
|
| 30 |
+
results_postprocessing_fn: An optional Callable that takes in
|
| 31 |
+
the metrics dict that will be reported (after it has been flattened)
|
| 32 |
+
and returns a modified dict.
|
| 33 |
+
|
| 34 |
+
Examples
|
| 35 |
+
--------
|
| 36 |
+
|
| 37 |
+
Reporting checkpoints and metrics to Ray Tune when running many
|
| 38 |
+
independent LightGBM trials (without data parallelism within a trial).
|
| 39 |
+
|
| 40 |
+
.. testcode::
|
| 41 |
+
:skipif: True
|
| 42 |
+
|
| 43 |
+
import lightgbm
|
| 44 |
+
|
| 45 |
+
from ray.train.lightgbm import RayTrainReportCallback
|
| 46 |
+
|
| 47 |
+
config = {
|
| 48 |
+
# ...
|
| 49 |
+
"metric": ["binary_logloss", "binary_error"],
|
| 50 |
+
}
|
| 51 |
+
|
| 52 |
+
# Report only log loss to Tune after each validation epoch.
|
| 53 |
+
bst = lightgbm.train(
|
| 54 |
+
...,
|
| 55 |
+
callbacks=[
|
| 56 |
+
RayTrainReportCallback(
|
| 57 |
+
metrics={"loss": "eval-binary_logloss"}, frequency=1
|
| 58 |
+
)
|
| 59 |
+
],
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
+
Loading a model from a checkpoint reported by this callback.
|
| 63 |
+
|
| 64 |
+
.. testcode::
|
| 65 |
+
:skipif: True
|
| 66 |
+
|
| 67 |
+
from ray.train.lightgbm import RayTrainReportCallback
|
| 68 |
+
|
| 69 |
+
# Get a `Checkpoint` object that is saved by the callback during training.
|
| 70 |
+
result = trainer.fit()
|
| 71 |
+
booster = RayTrainReportCallback.get_model(result.checkpoint)
|
| 72 |
+
|
| 73 |
+
"""
|
| 74 |
+
|
| 75 |
+
CHECKPOINT_NAME = "model.txt"
|
| 76 |
+
|
| 77 |
+
def __init__(
|
| 78 |
+
self,
|
| 79 |
+
metrics: Optional[Union[str, List[str], Dict[str, str]]] = None,
|
| 80 |
+
filename: str = CHECKPOINT_NAME,
|
| 81 |
+
frequency: int = 0,
|
| 82 |
+
checkpoint_at_end: bool = True,
|
| 83 |
+
results_postprocessing_fn: Optional[
|
| 84 |
+
Callable[[Dict[str, Union[float, List[float]]]], Dict[str, float]]
|
| 85 |
+
] = None,
|
| 86 |
+
):
|
| 87 |
+
if isinstance(metrics, str):
|
| 88 |
+
metrics = [metrics]
|
| 89 |
+
self._metrics = metrics
|
| 90 |
+
self._filename = filename
|
| 91 |
+
self._frequency = frequency
|
| 92 |
+
self._checkpoint_at_end = checkpoint_at_end
|
| 93 |
+
self._results_postprocessing_fn = results_postprocessing_fn
|
| 94 |
+
|
| 95 |
+
@classmethod
|
| 96 |
+
def get_model(
|
| 97 |
+
cls, checkpoint: Checkpoint, filename: str = CHECKPOINT_NAME
|
| 98 |
+
) -> Booster:
|
| 99 |
+
"""Retrieve the model stored in a checkpoint reported by this callback.
|
| 100 |
+
|
| 101 |
+
Args:
|
| 102 |
+
checkpoint: The checkpoint object returned by a training run.
|
| 103 |
+
The checkpoint should be saved by an instance of this callback.
|
| 104 |
+
filename: The filename to load the model from, which should match
|
| 105 |
+
the filename used when creating the callback.
|
| 106 |
+
"""
|
| 107 |
+
with checkpoint.as_directory() as checkpoint_path:
|
| 108 |
+
return Booster(model_file=Path(checkpoint_path, filename).as_posix())
|
| 109 |
+
|
| 110 |
+
def _get_report_dict(self, evals_log: Dict[str, Dict[str, list]]) -> dict:
|
| 111 |
+
result_dict = flatten_dict(evals_log, delimiter="-")
|
| 112 |
+
if not self._metrics:
|
| 113 |
+
report_dict = result_dict
|
| 114 |
+
else:
|
| 115 |
+
report_dict = {}
|
| 116 |
+
for key in self._metrics:
|
| 117 |
+
if isinstance(self._metrics, dict):
|
| 118 |
+
metric = self._metrics[key]
|
| 119 |
+
else:
|
| 120 |
+
metric = key
|
| 121 |
+
report_dict[key] = result_dict[metric]
|
| 122 |
+
if self._results_postprocessing_fn:
|
| 123 |
+
report_dict = self._results_postprocessing_fn(report_dict)
|
| 124 |
+
return report_dict
|
| 125 |
+
|
| 126 |
+
def _get_eval_result(self, env: CallbackEnv) -> dict:
|
| 127 |
+
eval_result = {}
|
| 128 |
+
for entry in env.evaluation_result_list:
|
| 129 |
+
data_name, eval_name, result = entry[0:3]
|
| 130 |
+
if len(entry) > 4:
|
| 131 |
+
stdv = entry[4]
|
| 132 |
+
suffix = "-mean"
|
| 133 |
+
else:
|
| 134 |
+
stdv = None
|
| 135 |
+
suffix = ""
|
| 136 |
+
if data_name not in eval_result:
|
| 137 |
+
eval_result[data_name] = {}
|
| 138 |
+
eval_result[data_name][eval_name + suffix] = result
|
| 139 |
+
if stdv is not None:
|
| 140 |
+
eval_result[data_name][eval_name + "-stdv"] = stdv
|
| 141 |
+
return eval_result
|
| 142 |
+
|
| 143 |
+
@contextmanager
|
| 144 |
+
def _get_checkpoint(self, model: Booster) -> Optional[Checkpoint]:
|
| 145 |
+
if ray.train.get_context().get_world_rank() in (0, None):
|
| 146 |
+
with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
|
| 147 |
+
model.save_model(Path(temp_checkpoint_dir, self._filename).as_posix())
|
| 148 |
+
yield Checkpoint.from_directory(temp_checkpoint_dir)
|
| 149 |
+
else:
|
| 150 |
+
yield None
|
| 151 |
+
|
| 152 |
+
def __call__(self, env: CallbackEnv) -> None:
|
| 153 |
+
eval_result = self._get_eval_result(env)
|
| 154 |
+
report_dict = self._get_report_dict(eval_result)
|
| 155 |
+
|
| 156 |
+
# Ex: if frequency=2, checkpoint_at_end=True and num_boost_rounds=11,
|
| 157 |
+
# you will checkpoint at iterations 1, 3, 5, ..., 9, and 10 (checkpoint_at_end)
|
| 158 |
+
# (iterations count from 0)
|
| 159 |
+
on_last_iter = env.iteration == env.end_iteration - 1
|
| 160 |
+
should_checkpoint_at_end = on_last_iter and self._checkpoint_at_end
|
| 161 |
+
should_checkpoint_with_frequency = (
|
| 162 |
+
self._frequency != 0 and (env.iteration + 1) % self._frequency == 0
|
| 163 |
+
)
|
| 164 |
+
should_checkpoint = should_checkpoint_at_end or should_checkpoint_with_frequency
|
| 165 |
+
|
| 166 |
+
if should_checkpoint:
|
| 167 |
+
with self._get_checkpoint(model=env.model) as checkpoint:
|
| 168 |
+
ray.train.report(report_dict, checkpoint=checkpoint)
|
| 169 |
+
else:
|
| 170 |
+
ray.train.report(report_dict)
|
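The checkpointing schedule described in the comment inside `__call__` can be reproduced with a short standalone sketch; it re-implements the `should_checkpoint` logic above for `frequency=2`, `checkpoint_at_end=True`, and `num_boost_round=11` and confirms the iterations listed in that comment.

# Standalone sketch of the checkpointing schedule (illustrative only).
frequency, checkpoint_at_end, num_boost_round = 2, True, 11
end_iteration = num_boost_round

checkpointed = []
for iteration in range(end_iteration):
    on_last_iter = iteration == end_iteration - 1
    should_checkpoint = (on_last_iter and checkpoint_at_end) or (
        frequency != 0 and (iteration + 1) % frequency == 0
    )
    if should_checkpoint:
        checkpointed.append(iteration)

print(checkpointed)  # [1, 3, 5, 7, 9, 10]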
.venv/lib/python3.11/site-packages/ray/train/lightgbm/config.py
ADDED
|
@@ -0,0 +1,89 @@
|
| 1 |
+
import logging
|
| 2 |
+
import threading
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from typing import Any, Dict, Optional
|
| 5 |
+
|
| 6 |
+
import ray
|
| 7 |
+
from ray.train._internal.utils import get_address_and_port
|
| 8 |
+
from ray.train._internal.worker_group import WorkerGroup
|
| 9 |
+
from ray.train.backend import Backend, BackendConfig
|
| 10 |
+
|
| 11 |
+
logger = logging.getLogger(__name__)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
# Global LightGBM distributed network configuration for each worker process.
|
| 15 |
+
_lightgbm_network_params: Optional[Dict[str, Any]] = None
|
| 16 |
+
_lightgbm_network_params_lock = threading.Lock()
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def get_network_params() -> Dict[str, Any]:
|
| 20 |
+
"""Returns the network parameters to enable LightGBM distributed training."""
|
| 21 |
+
global _lightgbm_network_params
|
| 22 |
+
|
| 23 |
+
with _lightgbm_network_params_lock:
|
| 24 |
+
if not _lightgbm_network_params:
|
| 25 |
+
logger.warning(
|
| 26 |
+
"`ray.train.lightgbm.get_network_params` was called outside "
|
| 27 |
+
"the context of a `ray.train.lightgbm.LightGBMTrainer`. "
|
| 28 |
+
"The current process has no knowledge of the distributed training "
|
| 29 |
+
"worker group, so this method will return an empty dict. "
|
| 30 |
+
"Please call this within the training loop of a "
|
| 31 |
+
"`ray.train.lightgbm.LightGBMTrainer`. "
|
| 32 |
+
"If you are in fact calling this within a `LightGBMTrainer`, "
|
| 33 |
+
"this is unexpected: please file a bug report to the Ray Team."
|
| 34 |
+
)
|
| 35 |
+
return {}
|
| 36 |
+
|
| 37 |
+
return _lightgbm_network_params.copy()
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def _set_network_params(
|
| 41 |
+
num_machines: int,
|
| 42 |
+
local_listen_port: int,
|
| 43 |
+
machines: str,
|
| 44 |
+
):
|
| 45 |
+
global _lightgbm_network_params
|
| 46 |
+
|
| 47 |
+
with _lightgbm_network_params_lock:
|
| 48 |
+
assert (
|
| 49 |
+
_lightgbm_network_params is None
|
| 50 |
+
), "LightGBM network params are already initialized."
|
| 51 |
+
_lightgbm_network_params = dict(
|
| 52 |
+
num_machines=num_machines,
|
| 53 |
+
local_listen_port=local_listen_port,
|
| 54 |
+
machines=machines,
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
@dataclass
|
| 59 |
+
class LightGBMConfig(BackendConfig):
|
| 60 |
+
"""Configuration for LightGBM distributed data-parallel training setup.
|
| 61 |
+
|
| 62 |
+
See the LightGBM docs for more information on the "network parameters"
|
| 63 |
+
that Ray Train sets up for you:
|
| 64 |
+
https://lightgbm.readthedocs.io/en/latest/Parameters.html#network-parameters
|
| 65 |
+
"""
|
| 66 |
+
|
| 67 |
+
@property
|
| 68 |
+
def backend_cls(self):
|
| 69 |
+
return _LightGBMBackend
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class _LightGBMBackend(Backend):
|
| 73 |
+
def on_training_start(
|
| 74 |
+
self, worker_group: WorkerGroup, backend_config: LightGBMConfig
|
| 75 |
+
):
|
| 76 |
+
node_ips_and_ports = worker_group.execute(get_address_and_port)
|
| 77 |
+
ports = [port for _, port in node_ips_and_ports]
|
| 78 |
+
machines = ",".join(
|
| 79 |
+
[f"{node_ip}:{port}" for node_ip, port in node_ips_and_ports]
|
| 80 |
+
)
|
| 81 |
+
num_machines = len(worker_group)
|
| 82 |
+
ray.get(
|
| 83 |
+
[
|
| 84 |
+
worker_group.execute_single_async(
|
| 85 |
+
rank, _set_network_params, num_machines, ports[rank], machines
|
| 86 |
+
)
|
| 87 |
+
for rank in range(len(worker_group))
|
| 88 |
+
]
|
| 89 |
+
)
|
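The network parameters assembled by `_LightGBMBackend.on_training_start` above reduce to a small dictionary per worker. A standalone sketch with made-up addresses shows what rank 0 would receive:

# Made-up (ip, port) pairs standing in for the worker group's addresses.
node_ips_and_ports = [("10.0.0.1", 50001), ("10.0.0.2", 50002)]

machines = ",".join(f"{ip}:{port}" for ip, port in node_ips_and_ports)
num_machines = len(node_ips_and_ports)

# Rank 0 gets its own listen port plus the full machine list.
params_for_rank_0 = {
    "num_machines": num_machines,                   # 2
    "local_listen_port": node_ips_and_ports[0][1],  # 50001
    "machines": machines,  # "10.0.0.1:50001,10.0.0.2:50002"
}
print(params_for_rank_0)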
.venv/lib/python3.11/site-packages/ray/train/lightgbm/lightgbm_checkpoint.py
ADDED
|
@@ -0,0 +1,70 @@
|
| 1 |
+
import tempfile
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
from typing import TYPE_CHECKING, Optional
|
| 4 |
+
|
| 5 |
+
import lightgbm
|
| 6 |
+
|
| 7 |
+
from ray.train._internal.framework_checkpoint import FrameworkCheckpoint
|
| 8 |
+
from ray.util.annotations import PublicAPI
|
| 9 |
+
|
| 10 |
+
if TYPE_CHECKING:
|
| 11 |
+
from ray.data.preprocessor import Preprocessor
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@PublicAPI(stability="beta")
|
| 15 |
+
class LightGBMCheckpoint(FrameworkCheckpoint):
|
| 16 |
+
"""A :py:class:`~ray.train.Checkpoint` with LightGBM-specific functionality."""
|
| 17 |
+
|
| 18 |
+
MODEL_FILENAME = "model.txt"
|
| 19 |
+
|
| 20 |
+
@classmethod
|
| 21 |
+
def from_model(
|
| 22 |
+
cls,
|
| 23 |
+
booster: lightgbm.Booster,
|
| 24 |
+
*,
|
| 25 |
+
preprocessor: Optional["Preprocessor"] = None,
|
| 26 |
+
path: Optional[str] = None,
|
| 27 |
+
) -> "LightGBMCheckpoint":
|
| 28 |
+
"""Create a :py:class:`~ray.train.Checkpoint` that stores a LightGBM model.
|
| 29 |
+
|
| 30 |
+
Args:
|
| 31 |
+
booster: The LightGBM model to store in the checkpoint.
|
| 32 |
+
preprocessor: A fitted preprocessor to be applied before inference.
|
| 33 |
+
path: The path to the directory where the checkpoint file will be saved.
|
| 34 |
+
This should start as an empty directory, since the *entire*
|
| 35 |
+
directory will be treated as the checkpoint when reported.
|
| 36 |
+
By default, a temporary directory will be created.
|
| 37 |
+
|
| 38 |
+
Returns:
|
| 39 |
+
A :py:class:`LightGBMCheckpoint` containing the specified ``Booster``.
|
| 40 |
+
|
| 41 |
+
Examples:
|
| 42 |
+
>>> import lightgbm
|
| 43 |
+
>>> import numpy as np
|
| 44 |
+
>>> from ray.train.lightgbm import LightGBMCheckpoint
|
| 45 |
+
>>>
|
| 46 |
+
>>> train_X = np.array([[1, 2], [3, 4]])
|
| 47 |
+
>>> train_y = np.array([0, 1])
|
| 48 |
+
>>>
|
| 49 |
+
>>> model = lightgbm.LGBMClassifier().fit(train_X, train_y)
|
| 50 |
+
>>> checkpoint = LightGBMCheckpoint.from_model(model.booster_)
|
| 51 |
+
"""
|
| 52 |
+
checkpoint_path = Path(path or tempfile.mkdtemp())
|
| 53 |
+
|
| 54 |
+
if not checkpoint_path.is_dir():
|
| 55 |
+
raise ValueError(f"`path` must be a directory, but got: {checkpoint_path}")
|
| 56 |
+
|
| 57 |
+
booster.save_model(checkpoint_path.joinpath(cls.MODEL_FILENAME).as_posix())
|
| 58 |
+
|
| 59 |
+
checkpoint = cls.from_directory(checkpoint_path.as_posix())
|
| 60 |
+
if preprocessor:
|
| 61 |
+
checkpoint.set_preprocessor(preprocessor)
|
| 62 |
+
|
| 63 |
+
return checkpoint
|
| 64 |
+
|
| 65 |
+
def get_model(self) -> lightgbm.Booster:
|
| 66 |
+
"""Retrieve the LightGBM model stored in this checkpoint."""
|
| 67 |
+
with self.as_directory() as checkpoint_path:
|
| 68 |
+
return lightgbm.Booster(
|
| 69 |
+
model_file=Path(checkpoint_path, self.MODEL_FILENAME).as_posix()
|
| 70 |
+
)
|
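A round-trip sketch mirroring the docstring examples above (requires `lightgbm` installed): store a trained booster in a `LightGBMCheckpoint`, then load it back with `get_model()`.

import lightgbm
import numpy as np

from ray.train.lightgbm import LightGBMCheckpoint

train_X = np.array([[1, 2], [3, 4]])
train_y = np.array([0, 1])
model = lightgbm.LGBMClassifier().fit(train_X, train_y)

checkpoint = LightGBMCheckpoint.from_model(model.booster_)
booster = checkpoint.get_model()
print(booster.num_trees())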
.venv/lib/python3.11/site-packages/ray/train/lightgbm/lightgbm_predictor.py
ADDED
|
@@ -0,0 +1,152 @@
|
| 1 |
+
from typing import TYPE_CHECKING, List, Optional, Union
|
| 2 |
+
|
| 3 |
+
import lightgbm
|
| 4 |
+
import pandas as pd
|
| 5 |
+
from pandas.api.types import is_object_dtype
|
| 6 |
+
|
| 7 |
+
from ray.air.constants import TENSOR_COLUMN_NAME
|
| 8 |
+
from ray.air.data_batch_type import DataBatchType
|
| 9 |
+
from ray.air.util.data_batch_conversion import _unwrap_ndarray_object_type_if_needed
|
| 10 |
+
from ray.train.lightgbm import LightGBMCheckpoint
|
| 11 |
+
from ray.train.predictor import Predictor
|
| 12 |
+
from ray.util.annotations import PublicAPI
|
| 13 |
+
|
| 14 |
+
if TYPE_CHECKING:
|
| 15 |
+
from ray.data.preprocessor import Preprocessor
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@PublicAPI(stability="beta")
|
| 19 |
+
class LightGBMPredictor(Predictor):
|
| 20 |
+
"""A predictor for LightGBM models.
|
| 21 |
+
|
| 22 |
+
Args:
|
| 23 |
+
model: The LightGBM booster to use for predictions.
|
| 24 |
+
preprocessor: A preprocessor used to transform data batches prior
|
| 25 |
+
to prediction.
|
| 26 |
+
"""
|
| 27 |
+
|
| 28 |
+
def __init__(
|
| 29 |
+
self, model: lightgbm.Booster, preprocessor: Optional["Preprocessor"] = None
|
| 30 |
+
):
|
| 31 |
+
self.model = model
|
| 32 |
+
super().__init__(preprocessor)
|
| 33 |
+
|
| 34 |
+
def __repr__(self):
|
| 35 |
+
return (
|
| 36 |
+
f"{self.__class__.__name__}(model={self.model!r}, "
|
| 37 |
+
f"preprocessor={self._preprocessor!r})"
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
@classmethod
|
| 41 |
+
def from_checkpoint(cls, checkpoint: LightGBMCheckpoint) -> "LightGBMPredictor":
|
| 42 |
+
"""Instantiate the predictor from a LightGBMCheckpoint.
|
| 43 |
+
|
| 44 |
+
Args:
|
| 45 |
+
checkpoint: The checkpoint to load the model and preprocessor from.
|
| 46 |
+
|
| 47 |
+
"""
|
| 48 |
+
model = checkpoint.get_model()
|
| 49 |
+
preprocessor = checkpoint.get_preprocessor()
|
| 50 |
+
return cls(model=model, preprocessor=preprocessor)
|
| 51 |
+
|
| 52 |
+
def predict(
|
| 53 |
+
self,
|
| 54 |
+
data: DataBatchType,
|
| 55 |
+
feature_columns: Optional[Union[List[str], List[int]]] = None,
|
| 56 |
+
**predict_kwargs,
|
| 57 |
+
) -> DataBatchType:
|
| 58 |
+
"""Run inference on data batch.
|
| 59 |
+
|
| 60 |
+
Args:
|
| 61 |
+
data: A batch of input data.
|
| 62 |
+
feature_columns: The names or indices of the columns in the
|
| 63 |
+
data to use as features to predict on. If None, then use
|
| 64 |
+
all columns in ``data``.
|
| 65 |
+
**predict_kwargs: Keyword arguments passed to
|
| 66 |
+
``lightgbm.Booster.predict``.
|
| 67 |
+
|
| 68 |
+
Examples:
|
| 69 |
+
>>> import numpy as np
|
| 70 |
+
>>> import lightgbm as lgbm
|
| 71 |
+
>>> from ray.train.lightgbm import LightGBMPredictor
|
| 72 |
+
>>>
|
| 73 |
+
>>> train_X = np.array([[1, 2], [3, 4]])
|
| 74 |
+
>>> train_y = np.array([0, 1])
|
| 75 |
+
>>>
|
| 76 |
+
>>> model = lgbm.LGBMClassifier().fit(train_X, train_y)
|
| 77 |
+
>>> predictor = LightGBMPredictor(model=model.booster_)
|
| 78 |
+
>>>
|
| 79 |
+
>>> data = np.array([[1, 2], [3, 4]])
|
| 80 |
+
>>> predictions = predictor.predict(data)
|
| 81 |
+
>>>
|
| 82 |
+
>>> # Only use first and second column as the feature
|
| 83 |
+
>>> data = np.array([[1, 2, 8], [3, 4, 9]])
|
| 84 |
+
>>> predictions = predictor.predict(data, feature_columns=[0, 1])
|
| 85 |
+
|
| 86 |
+
>>> import pandas as pd
|
| 87 |
+
>>> import lightgbm as lgbm
|
| 88 |
+
>>> from ray.train.lightgbm import LightGBMPredictor
|
| 89 |
+
>>>
|
| 90 |
+
>>> train_X = pd.DataFrame([[1, 2], [3, 4]], columns=["A", "B"])
|
| 91 |
+
>>> train_y = pd.Series([0, 1])
|
| 92 |
+
>>>
|
| 93 |
+
>>> model = lgbm.LGBMClassifier().fit(train_X, train_y)
|
| 94 |
+
>>> predictor = LightGBMPredictor(model=model.booster_)
|
| 95 |
+
>>>
|
| 96 |
+
>>> # Pandas dataframe.
|
| 97 |
+
>>> data = pd.DataFrame([[1, 2], [3, 4]], columns=["A", "B"])
|
| 98 |
+
>>> predictions = predictor.predict(data)
|
| 99 |
+
>>>
|
| 100 |
+
>>> # Only use first and second column as the feature
|
| 101 |
+
>>> data = pd.DataFrame([[1, 2, 8], [3, 4, 9]], columns=["A", "B", "C"])
|
| 102 |
+
>>> predictions = predictor.predict(data, feature_columns=["A", "B"])
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
Returns:
|
| 106 |
+
Prediction result.
|
| 107 |
+
|
| 108 |
+
"""
|
| 109 |
+
return Predictor.predict(
|
| 110 |
+
self, data, feature_columns=feature_columns, **predict_kwargs
|
| 111 |
+
)
|
| 112 |
+
|
| 113 |
+
def _predict_pandas(
|
| 114 |
+
self,
|
| 115 |
+
data: "pd.DataFrame",
|
| 116 |
+
feature_columns: Optional[Union[List[str], List[int]]] = None,
|
| 117 |
+
**predict_kwargs,
|
| 118 |
+
) -> pd.DataFrame:
|
| 119 |
+
feature_names = None
|
| 120 |
+
if TENSOR_COLUMN_NAME in data:
|
| 121 |
+
data = data[TENSOR_COLUMN_NAME].to_numpy()
|
| 122 |
+
data = _unwrap_ndarray_object_type_if_needed(data)
|
| 123 |
+
if feature_columns:
|
| 124 |
+
# In this case feature_columns is a list of integers
|
| 125 |
+
data = data[:, feature_columns]
|
| 126 |
+
# Turn into dataframe to make dtype resolution easy
|
| 127 |
+
data = pd.DataFrame(data, columns=feature_names)
|
| 128 |
+
data = data.infer_objects()
|
| 129 |
+
|
| 130 |
+
# Pandas does not detect categorical dtypes. Any remaining object
|
| 131 |
+
# dtypes are probably categories, so convert them.
|
| 132 |
+
# This will fail if we have a category composed entirely of
|
| 133 |
+
# integers, but this is the best we can do here.
|
| 134 |
+
update_dtypes = {}
|
| 135 |
+
for column in data.columns:
|
| 136 |
+
dtype = data.dtypes[column]
|
| 137 |
+
if is_object_dtype(dtype):
|
| 138 |
+
update_dtypes[column] = pd.CategoricalDtype()
|
| 139 |
+
|
| 140 |
+
if update_dtypes:
|
| 141 |
+
data = data.astype(update_dtypes, copy=False)
|
| 142 |
+
elif feature_columns:
|
| 143 |
+
# feature_columns is a list of integers or strings
|
| 144 |
+
data = data[feature_columns]
|
| 145 |
+
|
| 146 |
+
df = pd.DataFrame(self.model.predict(data, **predict_kwargs))
|
| 147 |
+
df.columns = (
|
| 148 |
+
["predictions"]
|
| 149 |
+
if len(df.columns) == 1
|
| 150 |
+
else [f"predictions_{i}" for i in range(len(df.columns))]
|
| 151 |
+
)
|
| 152 |
+
return df
|
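The output-column naming rule at the end of `_predict_pandas` above can be shown in isolation: a single output column is named ``predictions``, while multi-output predictions become ``predictions_0``, ``predictions_1``, and so on. The values below are made up.

import pandas as pd

# Two made-up output columns, e.g. per-class probabilities.
df = pd.DataFrame([[0.2, 0.8], [0.9, 0.1]])
df.columns = (
    ["predictions"]
    if len(df.columns) == 1
    else [f"predictions_{i}" for i in range(len(df.columns))]
)
print(list(df.columns))  # ['predictions_0', 'predictions_1']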
.venv/lib/python3.11/site-packages/ray/train/lightgbm/lightgbm_trainer.py
ADDED
|
@@ -0,0 +1,221 @@
|
| 1 |
+
import logging
|
| 2 |
+
from functools import partial
|
| 3 |
+
from typing import Any, Dict, Optional
|
| 4 |
+
|
| 5 |
+
import lightgbm
|
| 6 |
+
|
| 7 |
+
import ray
|
| 8 |
+
from ray.train import Checkpoint
|
| 9 |
+
from ray.train.constants import _DEPRECATED_VALUE, TRAIN_DATASET_KEY
|
| 10 |
+
from ray.train.lightgbm import RayTrainReportCallback
|
| 11 |
+
from ray.train.lightgbm.v2 import LightGBMTrainer as SimpleLightGBMTrainer
|
| 12 |
+
from ray.train.trainer import GenDataset
|
| 13 |
+
from ray.util.annotations import PublicAPI
|
| 14 |
+
|
| 15 |
+
logger = logging.getLogger(__name__)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def _lightgbm_train_fn_per_worker(
|
| 19 |
+
config: dict,
|
| 20 |
+
label_column: str,
|
| 21 |
+
num_boost_round: int,
|
| 22 |
+
dataset_keys: set,
|
| 23 |
+
lightgbm_train_kwargs: dict,
|
| 24 |
+
):
|
| 25 |
+
checkpoint = ray.train.get_checkpoint()
|
| 26 |
+
starting_model = None
|
| 27 |
+
remaining_iters = num_boost_round
|
| 28 |
+
if checkpoint:
|
| 29 |
+
starting_model = RayTrainReportCallback.get_model(checkpoint)
|
| 30 |
+
starting_iter = starting_model.current_iteration()
|
| 31 |
+
remaining_iters = num_boost_round - starting_iter
|
| 32 |
+
logger.info(
|
| 33 |
+
f"Model loaded from checkpoint will train for "
|
| 34 |
+
f"additional {remaining_iters} iterations (trees) in order "
|
| 35 |
+
"to achieve the target number of iterations "
|
| 36 |
+
f"({num_boost_round=})."
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
train_ds_iter = ray.train.get_dataset_shard(TRAIN_DATASET_KEY)
|
| 40 |
+
train_df = train_ds_iter.materialize().to_pandas()
|
| 41 |
+
|
| 42 |
+
eval_ds_iters = {
|
| 43 |
+
k: ray.train.get_dataset_shard(k)
|
| 44 |
+
for k in dataset_keys
|
| 45 |
+
if k != TRAIN_DATASET_KEY
|
| 46 |
+
}
|
| 47 |
+
eval_dfs = {k: d.materialize().to_pandas() for k, d in eval_ds_iters.items()}
|
| 48 |
+
|
| 49 |
+
train_X, train_y = train_df.drop(label_column, axis=1), train_df[label_column]
|
| 50 |
+
train_set = lightgbm.Dataset(train_X, label=train_y)
|
| 51 |
+
|
| 52 |
+
# NOTE: Include the training dataset in the evaluation datasets.
|
| 53 |
+
# This allows `train-*` metrics to be calculated and reported.
|
| 54 |
+
valid_sets = [train_set]
|
| 55 |
+
valid_names = [TRAIN_DATASET_KEY]
|
| 56 |
+
|
| 57 |
+
for eval_name, eval_df in eval_dfs.items():
|
| 58 |
+
eval_X, eval_y = eval_df.drop(label_column, axis=1), eval_df[label_column]
|
| 59 |
+
valid_sets.append(lightgbm.Dataset(eval_X, label=eval_y))
|
| 60 |
+
valid_names.append(eval_name)
|
| 61 |
+
|
| 62 |
+
# Add network params of the worker group to enable distributed training.
|
| 63 |
+
config.update(ray.train.lightgbm.v2.get_network_params())
|
| 64 |
+
|
| 65 |
+
lightgbm.train(
|
| 66 |
+
params=config,
|
| 67 |
+
train_set=train_set,
|
| 68 |
+
num_boost_round=remaining_iters,
|
| 69 |
+
valid_sets=valid_sets,
|
| 70 |
+
valid_names=valid_names,
|
| 71 |
+
init_model=starting_model,
|
| 72 |
+
**lightgbm_train_kwargs,
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
@PublicAPI(stability="beta")
|
| 77 |
+
class LightGBMTrainer(SimpleLightGBMTrainer):
|
| 78 |
+
"""A Trainer for data parallel LightGBM training.
|
| 79 |
+
|
| 80 |
+
This Trainer runs the LightGBM training loop in a distributed manner
|
| 81 |
+
using multiple Ray Actors.
|
| 82 |
+
|
| 83 |
+
If you would like to take advantage of LightGBM's built-in handling
|
| 84 |
+
for features with the categorical data type, consider applying the
|
| 85 |
+
:class:`Categorizer` preprocessor to set the dtypes in the dataset.
|
| 86 |
+
|
| 87 |
+
.. note::
|
| 88 |
+
``LightGBMTrainer`` does not modify or otherwise alter the working
|
| 89 |
+
of the LightGBM distributed training algorithm.
|
| 90 |
+
Ray only provides orchestration, data ingest and fault tolerance.
|
| 91 |
+
For more information on LightGBM distributed training, refer to
|
| 92 |
+
`LightGBM documentation <https://lightgbm.readthedocs.io/>`__.
|
| 93 |
+
|
| 94 |
+
Example:
|
| 95 |
+
.. testcode::
|
| 96 |
+
|
| 97 |
+
import ray
|
| 98 |
+
|
| 99 |
+
from ray.train.lightgbm import LightGBMTrainer
|
| 100 |
+
from ray.train import ScalingConfig
|
| 101 |
+
|
| 102 |
+
train_dataset = ray.data.from_items(
|
| 103 |
+
[{"x": x, "y": x + 1} for x in range(32)]
|
| 104 |
+
)
|
| 105 |
+
trainer = LightGBMTrainer(
|
| 106 |
+
label_column="y",
|
| 107 |
+
params={"objective": "regression"},
|
| 108 |
+
scaling_config=ScalingConfig(num_workers=3),
|
| 109 |
+
datasets={"train": train_dataset},
|
| 110 |
+
)
|
| 111 |
+
result = trainer.fit()
|
| 112 |
+
|
| 113 |
+
.. testoutput::
|
| 114 |
+
:hide:
|
| 115 |
+
|
| 116 |
+
...
|
| 117 |
+
|
| 118 |
+
Args:
|
| 119 |
+
datasets: The Ray Datasets to use for training and validation. Must include a
|
| 120 |
+
"train" key denoting the training dataset. All non-training datasets will
|
| 121 |
+
be used as separate validation sets, each reporting a separate metric.
|
| 122 |
+
label_column: Name of the label column. A column with this name
|
| 123 |
+
must be present in the training dataset.
|
| 124 |
+
params: LightGBM training parameters passed to ``lightgbm.train()``.
|
| 125 |
+
Refer to `LightGBM documentation <https://lightgbm.readthedocs.io>`_
|
| 126 |
+
for a list of possible parameters.
|
| 127 |
+
num_boost_round: Target number of boosting iterations (trees in the model).
|
| 128 |
+
Note that unlike in ``lightgbm.train``, this is the target number
|
| 129 |
+
of trees, meaning that if you set ``num_boost_round=10`` and pass a model
|
| 130 |
+
that has already been trained for 5 iterations, it will be trained for 5
|
| 131 |
+
iterations more, instead of 10 more.
|
| 132 |
+
scaling_config: Configuration for how to scale data parallel training.
|
| 133 |
+
run_config: Configuration for the execution of the training run.
|
| 134 |
+
resume_from_checkpoint: A checkpoint to resume training from.
|
| 135 |
+
metadata: Dict that should be made available in `checkpoint.get_metadata()`
|
| 136 |
+
for checkpoints saved from this Trainer. Must be JSON-serializable.
|
| 137 |
+
**train_kwargs: Additional kwargs passed to ``lightgbm.train()`` function.
|
| 138 |
+
"""
|
| 139 |
+
|
| 140 |
+
_handles_checkpoint_freq = True
|
| 141 |
+
_handles_checkpoint_at_end = True
|
| 142 |
+
|
| 143 |
+
def __init__(
|
| 144 |
+
self,
|
| 145 |
+
*,
|
| 146 |
+
datasets: Dict[str, GenDataset],
|
| 147 |
+
label_column: str,
|
| 148 |
+
params: Dict[str, Any],
|
| 149 |
+
num_boost_round: int = 10,
|
| 150 |
+
scaling_config: Optional[ray.train.ScalingConfig] = None,
|
| 151 |
+
run_config: Optional[ray.train.RunConfig] = None,
|
| 152 |
+
dataset_config: Optional[ray.train.DataConfig] = None,
|
| 153 |
+
resume_from_checkpoint: Optional[Checkpoint] = None,
|
| 154 |
+
metadata: Optional[Dict[str, Any]] = None,
|
| 155 |
+
dmatrix_params: Optional[Dict[str, Dict[str, Any]]] = _DEPRECATED_VALUE,
|
| 156 |
+
**train_kwargs,
|
| 157 |
+
):
|
| 158 |
+
# TODO(justinvyu): [Deprecated] Remove in 2.11
|
| 159 |
+
if dmatrix_params != _DEPRECATED_VALUE:
|
| 160 |
+
raise DeprecationWarning(
|
| 161 |
+
"`dmatrix_params` is deprecated, since XGBoostTrainer no longer "
|
| 162 |
+
"depends on the `xgboost_ray.RayDMatrix` utility. "
|
| 163 |
+
"You can remove this argument and use `dataset_config` instead "
|
| 164 |
+
"to customize Ray Dataset ingestion."
|
| 165 |
+
)
|
| 166 |
+
|
| 167 |
+
# Initialize a default Ray Train metrics/checkpoint reporting callback if needed
|
| 168 |
+
callbacks = train_kwargs.get("callbacks", [])
|
| 169 |
+
user_supplied_callback = any(
|
| 170 |
+
isinstance(callback, RayTrainReportCallback) for callback in callbacks
|
| 171 |
+
)
|
| 172 |
+
callback_kwargs = {}
|
| 173 |
+
if run_config:
|
| 174 |
+
checkpoint_frequency = run_config.checkpoint_config.checkpoint_frequency
|
| 175 |
+
checkpoint_at_end = run_config.checkpoint_config.checkpoint_at_end
|
| 176 |
+
|
| 177 |
+
callback_kwargs["frequency"] = checkpoint_frequency
|
| 178 |
+
# Default `checkpoint_at_end=True` unless the user explicitly sets it.
|
| 179 |
+
callback_kwargs["checkpoint_at_end"] = (
|
| 180 |
+
checkpoint_at_end if checkpoint_at_end is not None else True
|
| 181 |
+
)
|
| 182 |
+
|
| 183 |
+
if not user_supplied_callback:
|
| 184 |
+
callbacks.append(RayTrainReportCallback(**callback_kwargs))
|
| 185 |
+
train_kwargs["callbacks"] = callbacks
|
| 186 |
+
|
| 187 |
+
train_fn_per_worker = partial(
|
| 188 |
+
_lightgbm_train_fn_per_worker,
|
| 189 |
+
label_column=label_column,
|
| 190 |
+
num_boost_round=num_boost_round,
|
| 191 |
+
dataset_keys=set(datasets),
|
| 192 |
+
lightgbm_train_kwargs=train_kwargs,
|
| 193 |
+
)
|
| 194 |
+
|
| 195 |
+
super(LightGBMTrainer, self).__init__(
|
| 196 |
+
train_loop_per_worker=train_fn_per_worker,
|
| 197 |
+
train_loop_config=params,
|
| 198 |
+
scaling_config=scaling_config,
|
| 199 |
+
run_config=run_config,
|
| 200 |
+
datasets=datasets,
|
| 201 |
+
dataset_config=dataset_config,
|
| 202 |
+
resume_from_checkpoint=resume_from_checkpoint,
|
| 203 |
+
metadata=metadata,
|
| 204 |
+
)
|
| 205 |
+
|
| 206 |
+
@classmethod
|
| 207 |
+
def get_model(
|
| 208 |
+
cls,
|
| 209 |
+
checkpoint: Checkpoint,
|
| 210 |
+
) -> lightgbm.Booster:
|
| 211 |
+
"""Retrieve the LightGBM model stored in this checkpoint."""
|
| 212 |
+
return RayTrainReportCallback.get_model(checkpoint)
|
| 213 |
+
|
| 214 |
+
def _validate_attributes(self):
|
| 215 |
+
super()._validate_attributes()
|
| 216 |
+
|
| 217 |
+
if TRAIN_DATASET_KEY not in self.datasets:
|
| 218 |
+
raise KeyError(
|
| 219 |
+
f"'{TRAIN_DATASET_KEY}' key must be preset in `datasets`. "
|
| 220 |
+
f"Got {list(self.datasets.keys())}"
|
| 221 |
+
)
|
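A short sketch of the `num_boost_round` semantics documented above: the value is a target total, so resuming from a checkpointed model that already contains 5 trees with `num_boost_round=10` trains only 5 additional iterations (mirroring the `remaining_iters` computation in `_lightgbm_train_fn_per_worker`).

num_boost_round = 10
starting_iter = 5  # e.g. starting_model.current_iteration() after loading a checkpoint
remaining_iters = num_boost_round - starting_iter
print(f"Will train for {remaining_iters} more iterations")  # 5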
.venv/lib/python3.11/site-packages/ray/train/lightgbm/v2.py
ADDED
|
@@ -0,0 +1,132 @@
|
| 1 |
+
import logging
|
| 2 |
+
from typing import Any, Callable, Dict, Optional, Union
|
| 3 |
+
|
| 4 |
+
import ray.train
|
| 5 |
+
from ray.train import Checkpoint
|
| 6 |
+
from ray.train.data_parallel_trainer import DataParallelTrainer
|
| 7 |
+
from ray.train.lightgbm.config import LightGBMConfig, get_network_params # noqa: F401
|
| 8 |
+
from ray.train.trainer import GenDataset
|
| 9 |
+
|
| 10 |
+
logger = logging.getLogger(__name__)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class LightGBMTrainer(DataParallelTrainer):
|
| 14 |
+
"""A Trainer for distributed data-parallel LightGBM training.
|
| 15 |
+
|
| 16 |
+
Example
|
| 17 |
+
-------
|
| 18 |
+
|
| 19 |
+
.. testcode::
|
| 20 |
+
|
| 21 |
+
import lightgbm as lgb
|
| 22 |
+
|
| 23 |
+
import ray.data
|
| 24 |
+
import ray.train
|
| 25 |
+
from ray.train.lightgbm import RayTrainReportCallback
|
| 26 |
+
from ray.train.lightgbm.v2 import LightGBMTrainer
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def train_fn_per_worker(config: dict):
|
| 30 |
+
# (Optional) Add logic to resume training state from a checkpoint.
|
| 31 |
+
# ray.train.get_checkpoint()
|
| 32 |
+
|
| 33 |
+
# 1. Get the dataset shard for the worker and convert to a `lgb.Dataset`
|
| 34 |
+
train_ds_iter, eval_ds_iter = (
|
| 35 |
+
ray.train.get_dataset_shard("train"),
|
| 36 |
+
ray.train.get_dataset_shard("validation"),
|
| 37 |
+
)
|
| 38 |
+
train_ds, eval_ds = train_ds_iter.materialize(), eval_ds_iter.materialize()
|
| 39 |
+
train_df, eval_df = train_ds.to_pandas(), eval_ds.to_pandas()
|
| 40 |
+
train_X, train_y = train_df.drop("y", axis=1), train_df["y"]
|
| 41 |
+
eval_X, eval_y = eval_df.drop("y", axis=1), eval_df["y"]
|
| 42 |
+
|
| 43 |
+
train_set = lgb.Dataset(train_X, label=train_y)
|
| 44 |
+
eval_set = lgb.Dataset(eval_X, label=eval_y)
|
| 45 |
+
|
| 46 |
+
# 2. Run distributed data-parallel training.
|
| 47 |
+
# `get_network_params` sets up the necessary configurations for LightGBM
|
| 48 |
+
# to set up the data parallel training worker group on your Ray cluster.
|
| 49 |
+
params = {
|
| 50 |
+
"objective": "regression",
|
| 51 |
+
# Adding the line below is the only change needed
|
| 52 |
+
# for your `lgb.train` call!
|
| 53 |
+
**ray.train.lightgbm.v2.get_network_params(),
|
| 54 |
+
}
|
| 55 |
+
lgb.train(
|
| 56 |
+
params,
|
| 57 |
+
train_set,
|
| 58 |
+
valid_sets=[eval_set],
|
| 59 |
+
valid_names=["eval"],
|
| 60 |
+
callbacks=[RayTrainReportCallback()],
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
train_ds = ray.data.from_items([{"x": x, "y": x + 1} for x in range(32)])
|
| 64 |
+
eval_ds = ray.data.from_items(
|
| 65 |
+
[{"x": x, "y": x + 1} for x in range(32, 32 + 16)]
|
| 66 |
+
)
|
| 67 |
+
trainer = LightGBMTrainer(
|
| 68 |
+
train_fn_per_worker,
|
| 69 |
+
datasets={"train": train_ds, "validation": eval_ds},
|
| 70 |
+
scaling_config=ray.train.ScalingConfig(num_workers=4),
|
| 71 |
+
)
|
| 72 |
+
result = trainer.fit()
|
| 73 |
+
booster = RayTrainReportCallback.get_model(result.checkpoint)
|
| 74 |
+
|
| 75 |
+
.. testoutput::
|
| 76 |
+
:hide:
|
| 77 |
+
|
| 78 |
+
...
|
| 79 |
+
|
| 80 |
+
Args:
|
| 81 |
+
train_loop_per_worker: The training function to execute on each worker.
|
| 82 |
+
This function can either take in zero arguments or a single ``Dict``
|
| 83 |
+
argument which is set by defining ``train_loop_config``.
|
| 84 |
+
Within this function you can use any of the
|
| 85 |
+
:ref:`Ray Train Loop utilities <train-loop-api>`.
|
| 86 |
+
train_loop_config: A configuration ``Dict`` to pass in as an argument to
|
| 87 |
+
``train_loop_per_worker``.
|
| 88 |
+
This is typically used for specifying hyperparameters.
|
| 89 |
+
lightgbm_config: The configuration for setting up the distributed lightgbm
|
| 90 |
+
backend. See :class:`~ray.train.lightgbm.LightGBMConfig` for more info.
|
| 91 |
+
datasets: The Ray Datasets to use for training and validation.
|
| 92 |
+
dataset_config: The configuration for ingesting the input ``datasets``.
|
| 93 |
+
By default, all the Ray Datasets are split equally across workers.
|
| 94 |
+
See :class:`~ray.train.DataConfig` for more details.
|
| 95 |
+
scaling_config: The configuration for how to scale data parallel training.
|
| 96 |
+
``num_workers`` determines how many Python processes are used for training,
|
| 97 |
+
and ``use_gpu`` determines whether or not each process should use GPUs.
|
| 98 |
+
See :class:`~ray.train.ScalingConfig` for more info.
|
| 99 |
+
run_config: The configuration for the execution of the training run.
|
| 100 |
+
See :class:`~ray.train.RunConfig` for more info.
|
| 101 |
+
resume_from_checkpoint: A checkpoint to resume training from.
|
| 102 |
+
This checkpoint can be accessed from within ``train_loop_per_worker``
|
| 103 |
+
by calling ``ray.train.get_checkpoint()``.
|
| 104 |
+
metadata: Dict that should be made available via
|
| 105 |
+
`ray.train.get_context().get_metadata()` and in `checkpoint.get_metadata()`
|
| 106 |
+
for checkpoints saved from this Trainer. Must be JSON-serializable.
|
| 107 |
+
"""
|
| 108 |
+
|
| 109 |
+
def __init__(
|
| 110 |
+
self,
|
| 111 |
+
train_loop_per_worker: Union[Callable[[], None], Callable[[Dict], None]],
|
| 112 |
+
*,
|
| 113 |
+
train_loop_config: Optional[Dict] = None,
|
| 114 |
+
lightgbm_config: Optional[LightGBMConfig] = None,
|
| 115 |
+
scaling_config: Optional[ray.train.ScalingConfig] = None,
|
| 116 |
+
run_config: Optional[ray.train.RunConfig] = None,
|
| 117 |
+
datasets: Optional[Dict[str, GenDataset]] = None,
|
| 118 |
+
dataset_config: Optional[ray.train.DataConfig] = None,
|
| 119 |
+
metadata: Optional[Dict[str, Any]] = None,
|
| 120 |
+
resume_from_checkpoint: Optional[Checkpoint] = None,
|
| 121 |
+
):
|
| 122 |
+
super(LightGBMTrainer, self).__init__(
|
| 123 |
+
train_loop_per_worker=train_loop_per_worker,
|
| 124 |
+
train_loop_config=train_loop_config,
|
| 125 |
+
backend_config=lightgbm_config or LightGBMConfig(),
|
| 126 |
+
scaling_config=scaling_config,
|
| 127 |
+
dataset_config=dataset_config,
|
| 128 |
+
run_config=run_config,
|
| 129 |
+
datasets=datasets,
|
| 130 |
+
resume_from_checkpoint=resume_from_checkpoint,
|
| 131 |
+
metadata=metadata,
|
| 132 |
+
)
|
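The docstring example above leaves checkpoint resumption as a comment (`# ray.train.get_checkpoint()`). Below is a minimal sketch of what that step could look like inside `train_fn_per_worker`, assuming the checkpoint was produced by `RayTrainReportCallback` in an earlier run; the toy in-memory dataset and the use of LightGBM's standard `init_model` keyword are illustrative, and the distributed network params from the docstring example are omitted for brevity.

import lightgbm as lgb
import numpy as np

import ray.train
from ray.train.lightgbm import RayTrainReportCallback


def train_fn_per_worker(config: dict):
    # Resume from the booster saved by RayTrainReportCallback, if a checkpoint exists.
    checkpoint = ray.train.get_checkpoint()
    init_model = RayTrainReportCallback.get_model(checkpoint) if checkpoint else None

    # Tiny in-memory dataset, purely for illustration.
    X = np.arange(32, dtype=float).reshape(-1, 1)
    y = X.ravel() + 1.0
    train_set = lgb.Dataset(X, label=y)

    lgb.train(
        {"objective": "regression"},
        train_set,
        init_model=init_model,  # continue boosting from the loaded model, if any
        callbacks=[RayTrainReportCallback()],
    )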
.venv/lib/python3.11/site-packages/ray/train/predictor.py
ADDED
@@ -0,0 +1,254 @@
import abc
from typing import Callable, Dict, Optional, Type, Union

import numpy as np
import pandas as pd

from ray.air.data_batch_type import DataBatchType
from ray.air.util.data_batch_conversion import (
    BatchFormat,
    _convert_batch_type_to_numpy,
    _convert_batch_type_to_pandas,
)
from ray.data import Preprocessor
from ray.train import Checkpoint
from ray.util.annotations import DeveloperAPI, PublicAPI

try:
    import pyarrow

    pa_table = pyarrow.Table
except ImportError:
    pa_table = None

# Reverse mapping from data batch type to batch format.
TYPE_TO_ENUM: Dict[Type[DataBatchType], BatchFormat] = {
    np.ndarray: BatchFormat.NUMPY,
    dict: BatchFormat.NUMPY,
    pd.DataFrame: BatchFormat.PANDAS,
}


@PublicAPI(stability="beta")
class PredictorNotSerializableException(RuntimeError):
    """Error raised when trying to serialize a Predictor instance."""

    pass


@PublicAPI(stability="beta")
class Predictor(abc.ABC):
    """Predictors load models from checkpoints to perform inference.

    .. note::
        The base ``Predictor`` class cannot be instantiated directly. Only one of
        its subclasses can be used.

    **How does a Predictor work?**

    Predictors expose a ``predict`` method that accepts an input batch of type
    ``DataBatchType`` and outputs predictions of the same type as the input batch.

    When the ``predict`` method is called the following occurs:

    - The input batch is converted into a pandas DataFrame. Tensor input (like a
      ``np.ndarray``) will be converted into a single column pandas DataFrame.
    - If there is a :ref:`Preprocessor <preprocessor-ref>` saved in the provided
      :class:`Checkpoint <ray.train.Checkpoint>`, the preprocessor will be used to
      transform the DataFrame.
    - The transformed DataFrame will be passed to the model for inference (via the
      ``predictor._predict_pandas`` method).
    - The predictions will be returned by ``predict`` in the same type as the
      original input.

    **How do I create a new Predictor?**

    To implement a new Predictor for your particular framework, you should subclass
    the base ``Predictor`` and implement the following methods:

    1. ``_predict_pandas``: Given a pandas.DataFrame input, return a
       pandas.DataFrame containing predictions.
    2. ``from_checkpoint``: Logic for creating a Predictor from a
       :class:`Checkpoint <ray.train.Checkpoint>`.
    3. Optionally ``_predict_numpy`` for better performance when working with
       tensor data to avoid extra copies from Pandas conversions.
    """

    def __init__(self, preprocessor: Optional[Preprocessor] = None):
        """Subclasses must call Predictor.__init__() to set a preprocessor."""
        self._preprocessor: Optional[Preprocessor] = preprocessor
        # Whether tensor columns should be automatically cast from/to the tensor
        # extension type at UDF boundaries. This can be overridden by subclasses.
        self._cast_tensor_columns = False

    @classmethod
    @abc.abstractmethod
    def from_checkpoint(cls, checkpoint: Checkpoint, **kwargs) -> "Predictor":
        """Create a specific predictor from a checkpoint.

        Args:
            checkpoint: Checkpoint to load predictor data from.
            kwargs: Arguments specific to predictor implementations.

        Returns:
            Predictor: Predictor object.
        """
        raise NotImplementedError

    @classmethod
    def from_pandas_udf(
        cls, pandas_udf: Callable[[pd.DataFrame], pd.DataFrame]
    ) -> "Predictor":
        """Create a Predictor from a Pandas UDF.

        Args:
            pandas_udf: A function that takes a pandas.DataFrame and other
                optional kwargs and returns a pandas.DataFrame.
        """

        class PandasUDFPredictor(Predictor):
            @classmethod
            def from_checkpoint(cls, checkpoint: Checkpoint, **kwargs) -> "Predictor":
                return PandasUDFPredictor()

            def _predict_pandas(self, df, **kwargs) -> "pd.DataFrame":
                return pandas_udf(df, **kwargs)

        return PandasUDFPredictor()

    def get_preprocessor(self) -> Optional[Preprocessor]:
        """Get the preprocessor to use prior to executing predictions."""
        return self._preprocessor

    def set_preprocessor(self, preprocessor: Optional[Preprocessor]) -> None:
        """Set the preprocessor to use prior to executing predictions."""
        self._preprocessor = preprocessor

    @classmethod
    @DeveloperAPI
    def preferred_batch_format(cls) -> BatchFormat:
        """Batch format hint for upstream producers to try yielding best block format.

        The preferred batch format to use if both `_predict_pandas` and
        `_predict_numpy` are implemented. Defaults to Pandas.

        Can be overridden by predictor classes depending on the framework type,
        e.g. TorchPredictor prefers Numpy and XGBoostPredictor prefers Pandas as
        native batch format.
        """
        return BatchFormat.PANDAS

    @classmethod
    def _batch_format_to_use(cls) -> BatchFormat:
        """Determine the batch format to use for the predictor."""
        has_pandas_implemented = cls._predict_pandas != Predictor._predict_pandas
        has_numpy_implemented = cls._predict_numpy != Predictor._predict_numpy
        if has_pandas_implemented and has_numpy_implemented:
            return cls.preferred_batch_format()
        elif has_pandas_implemented:
            return BatchFormat.PANDAS
        elif has_numpy_implemented:
            return BatchFormat.NUMPY
        else:
            raise NotImplementedError(
                f"Predictor {cls.__name__} must implement at least one of "
                "`_predict_pandas` and `_predict_numpy`."
            )

    def _set_cast_tensor_columns(self):
        """Enable automatic tensor column casting.

        If this is called on a predictor, the predictor will cast tensor columns to
        NumPy ndarrays in the input to the preprocessors and cast tensor columns back
        to the tensor extension type in the prediction outputs.
        """
        self._cast_tensor_columns = True

    def predict(self, data: DataBatchType, **kwargs) -> DataBatchType:
        """Perform inference on a batch of data.

        Args:
            data: A batch of input data of type ``DataBatchType``.
            kwargs: Arguments specific to predictor implementations. These are passed
                directly to ``_predict_numpy`` or ``_predict_pandas``.

        Returns:
            DataBatchType:
                Prediction result. The return type will be the same as the input type.
        """
        if not hasattr(self, "_preprocessor"):
            raise NotImplementedError(
                "Subclasses of Predictor must call Predictor.__init__(preprocessor)."
            )
        try:
            batch_format = TYPE_TO_ENUM[type(data)]
        except KeyError:
            raise RuntimeError(
                f"Invalid input data type of {type(data)}, supported "
                f"types: {list(TYPE_TO_ENUM.keys())}"
            )

        if self._preprocessor:
            data = self._preprocessor.transform_batch(data)

        batch_format_to_use = self._batch_format_to_use()

        # We can finish prediction as long as one predict method is implemented.
        # For prediction, we have to return back in the same format as the input.
        if batch_format == BatchFormat.PANDAS:
            if batch_format_to_use == BatchFormat.PANDAS:
                return self._predict_pandas(
                    _convert_batch_type_to_pandas(data), **kwargs
                )
            elif batch_format_to_use == BatchFormat.NUMPY:
                return _convert_batch_type_to_pandas(
                    self._predict_numpy(_convert_batch_type_to_numpy(data), **kwargs)
                )
        elif batch_format == BatchFormat.NUMPY:
            if batch_format_to_use == BatchFormat.PANDAS:
                return _convert_batch_type_to_numpy(
                    self._predict_pandas(_convert_batch_type_to_pandas(data), **kwargs)
                )
            elif batch_format_to_use == BatchFormat.NUMPY:
                return self._predict_numpy(_convert_batch_type_to_numpy(data), **kwargs)

    @DeveloperAPI
    def _predict_pandas(self, data: "pd.DataFrame", **kwargs) -> "pd.DataFrame":
        """Perform inference on a Pandas DataFrame.

        Args:
            data: A pandas DataFrame to perform predictions on.
            kwargs: Arguments specific to the predictor implementation.

        Returns:
            A pandas DataFrame containing the prediction result.
        """
        raise NotImplementedError

    @DeveloperAPI
    def _predict_numpy(
        self, data: Union[np.ndarray, Dict[str, np.ndarray]], **kwargs
    ) -> Union[np.ndarray, Dict[str, np.ndarray]]:
        """Perform inference on Numpy data.

        All Predictors working with tensor data (like deep learning predictors)
        should implement this method.

        Args:
            data: A Numpy ndarray or dictionary of ndarrays to perform predictions on.
            kwargs: Arguments specific to the predictor implementation.

        Returns:
            A Numpy ndarray or dictionary of ndarrays containing the prediction result.
        """
        raise NotImplementedError

    def __reduce__(self):
        raise PredictorNotSerializableException(
            "Predictor instances are not serializable. Instead, you may want "
            "to serialize a checkpoint and initialize the Predictor with "
            "Predictor.from_checkpoint."
        )
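As a concrete illustration of the subclassing contract described in the `Predictor` docstring above, here is a minimal sketch of a pandas-based predictor. The `ThresholdPredictor` class and its threshold-in-metadata layout are hypothetical; only the `from_checkpoint` and `_predict_pandas` overrides mirror the documented interface.

import pandas as pd

from ray.train import Checkpoint
from ray.train.predictor import Predictor


class ThresholdPredictor(Predictor):
    """Hypothetical predictor that labels rows by comparing the first column to a threshold."""

    def __init__(self, threshold: float, preprocessor=None):
        self.threshold = threshold
        super().__init__(preprocessor)

    @classmethod
    def from_checkpoint(cls, checkpoint: Checkpoint, **kwargs) -> "ThresholdPredictor":
        # A real implementation would load model files from the checkpoint directory;
        # here we just read a threshold from the checkpoint metadata (assumed layout).
        threshold = checkpoint.get_metadata().get("threshold", 0.5)
        return cls(threshold=threshold)

    def _predict_pandas(self, data: pd.DataFrame, **kwargs) -> pd.DataFrame:
        # Return predictions as a DataFrame, as the base class expects.
        labels = (data.iloc[:, 0] > self.threshold).astype(int)
        return pd.DataFrame({"predictions": labels})


# Usage: predictions come back in the same batch format that was passed in.
predictor = ThresholdPredictor(threshold=0.5)
print(predictor.predict(pd.DataFrame({"x": [0.1, 0.9]})))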
.venv/lib/python3.11/site-packages/ray/train/session.py
ADDED
File without changes
.venv/lib/python3.11/site-packages/ray/train/trainer.py
ADDED
@@ -0,0 +1,194 @@
import logging
import traceback
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, TypeVar, Union

from ray.air._internal.util import (
    StartTraceback,
    StartTracebackWithWorkerRank,
    skip_exceptions,
)
from ray.data import Dataset
from ray.train import Checkpoint, DataConfig
from ray.train._internal.backend_executor import (
    BackendExecutor,
    InactiveWorkerGroupError,
    TrainBackendError,
    TrainingWorkerError,
)
from ray.train._internal.session import _TrainingResult, _TrainSession, get_session
from ray.train._internal.utils import ActorWrapper
from ray.train.backend import BackendConfig
from ray.train.base_trainer import (  # noqa: F401
    BaseTrainer,
    GenDataset,
    TrainingFailedError,
)
from ray.util.annotations import DeveloperAPI

T = TypeVar("T")
S = TypeVar("S")

logger = logging.getLogger(__name__)


@DeveloperAPI
class TrainingIterator:
    """An iterator over Train results. Returned by ``trainer.run_iterator``."""

    def __init__(
        self,
        backend_executor: Union[BackendExecutor, ActorWrapper],
        backend_config: BackendConfig,
        train_func: Union[Callable[[], T], Callable[[Dict[str, Any]], T]],
        datasets: Dict[str, Dataset],
        metadata: Dict[str, Any],
        data_config: DataConfig,
        checkpoint: Optional[Union[Dict, str, Path, Checkpoint]],
    ):
        self._backend_executor = backend_executor
        self._backend = backend_config.backend_cls()
        self._train_func = train_func
        self._datasets = datasets
        self._metadata = metadata
        self._data_config = data_config

        self._start_training(
            train_func=train_func,
            datasets=self._datasets,
            metadata=self._metadata,
            data_config=self._data_config,
            checkpoint=checkpoint,
        )

        self._finished_training = False

    def __iter__(self):
        return self

    def _start_training(
        self,
        train_func,
        datasets,
        metadata,
        data_config,
        checkpoint: Optional[Checkpoint] = None,
    ):
        tune_session: _TrainSession = get_session()
        assert tune_session, "`_start_training` should only be called from within Tune"
        storage = tune_session.storage

        self._run_with_error_handling(
            lambda: self._backend_executor.start_training(
                train_func=train_func,
                datasets=datasets,
                metadata=metadata,
                data_config=data_config,
                storage=storage,
                checkpoint=checkpoint,
            )
        )

    def _run_with_error_handling(self, func: Callable):
        try:
            return func()
        except TrainingWorkerError:
            # TODO(ml-team): This Train fault-tolerance code doesn't get used
            # since max_retries=0
            # Workers have already been restarted.
            logger.info(
                "Workers have been successfully restarted. Resuming "
                "training from latest checkpoint."
            )
            self._start_training(
                self._train_func,
                self._datasets,
                self._metadata,
                self._data_config,
            )
            return self._run_with_error_handling(func)
        except InactiveWorkerGroupError:
            raise RuntimeError(
                "This Trainer is not active. It is either shutdown "
                "already or never started in the first place. "
                "Either create a new Trainer or start this one."
            ) from None
        except TrainBackendError:
            raise RuntimeError(
                "Training failed. You should not be seeing "
                "this error and this is a bug. Please create "
                "a new issue at "
                "https://github.com/ray-project/ray."
            ) from None

    def __next__(self):
        if self.is_finished():
            self._backend_executor.report_final_run_status(errored=False)
            raise StopIteration
        try:
            next_results = self._run_with_error_handling(self._fetch_next_result)
            if next_results is None:
                self._backend_executor.report_final_run_status(errored=False)
                self._run_with_error_handling(self._finish_training)
                self._finished_training = True
                raise StopIteration
            else:
                return next_results
        except StartTraceback as e:
            # If this is a StartTraceback, then this is a user error.
            # We raise it directly
            if isinstance(e, StartTracebackWithWorkerRank):
                failed_rank = e.worker_rank
            else:
                failed_rank = None

            # Extract the stack trace from the exception
            e = skip_exceptions(e)
            stack_trace = "".join(
                traceback.format_exception(type(e), e, e.__traceback__)
            )

            self._backend_executor.report_final_run_status(
                errored=True, stack_trace=stack_trace, failed_rank=failed_rank
            )
            try:
                # Exception raised in at least one training worker. Immediately raise
                # this error to the user and do not attempt to terminate gracefully.
                self._backend_executor.shutdown(graceful_termination=False)
                self._finished_training = True
            except Exception:
                pass
            raise

    def _fetch_next_result(self) -> Optional[List[Dict]]:
        """Fetch next results produced by ``session.report()`` from each worker.

        Assumes ``start_training`` has already been called.

        Returns:
            A list of dictionaries of values passed to ``session.report()`` from
                each worker. Each item corresponds to an intermediate result from
                a single worker. If there are no more items to fetch,
                returns None.
        """
        results = self._backend_executor.get_next_results()
        if results is None:
            return None
        assert all(isinstance(result, _TrainingResult) for result in results)
        return results

    def _finish_training(self):
        """Finish training and return final results. Propagate any exceptions.

        Blocks until training is finished on all workers.

        Assumes `start_training` has already been called.

        Returns:
            A list of return values from calling ``train_func`` on each worker.
                Each item corresponds to the return value from a single worker.
        """
        return self._backend_executor.finish_training()

    def is_finished(self) -> bool:
        return self._finished_training
.venv/lib/python3.11/site-packages/ray/train/utils.py
ADDED
@@ -0,0 +1,19 @@
import warnings

from ray.util.annotations import RayDeprecationWarning


def _copy_doc(copy_func):
    def wrapped(func):
        func.__doc__ = copy_func.__doc__
        return func

    return wrapped


def _log_deprecation_warning(message):
    warnings.warn(
        message,
        RayDeprecationWarning,
        stacklevel=2,
    )
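A small sketch of how the `_copy_doc` decorator above behaves; the `_internal_report`/`public_report` names are made up for the demonstration.

from ray.train.utils import _copy_doc


def _internal_report(metrics: dict) -> None:
    """Report a dictionary of metrics for the current iteration."""


@_copy_doc(_internal_report)
def public_report(metrics: dict) -> None:
    return _internal_report(metrics)


# The wrapper now carries the original docstring.
assert public_report.__doc__ == _internal_report.__doc__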
.venv/lib/python3.11/site-packages/ray/train/xgboost/__init__.py
ADDED
@@ -0,0 +1,20 @@
from ray.train.v2._internal.constants import is_v2_enabled
from ray.train.xgboost._xgboost_utils import RayTrainReportCallback
from ray.train.xgboost.config import XGBoostConfig
from ray.train.xgboost.xgboost_checkpoint import XGBoostCheckpoint
from ray.train.xgboost.xgboost_predictor import XGBoostPredictor
from ray.train.xgboost.xgboost_trainer import XGBoostTrainer

if is_v2_enabled():
    from ray.train.v2.xgboost.xgboost_trainer import XGBoostTrainer  # noqa: F811

__all__ = [
    "RayTrainReportCallback",
    "XGBoostCheckpoint",
    "XGBoostConfig",
    "XGBoostPredictor",
    "XGBoostTrainer",
]


# DO NOT ADD ANYTHING AFTER THIS LINE.
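Because the package `__init__` re-exports these names (per its `__all__`), user code can import the whole public surface from the top-level module; a minimal sketch:

from ray.train.xgboost import (
    RayTrainReportCallback,
    XGBoostCheckpoint,
    XGBoostConfig,
    XGBoostPredictor,
    XGBoostTrainer,
)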
.venv/lib/python3.11/site-packages/ray/train/xgboost/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (883 Bytes).
.venv/lib/python3.11/site-packages/ray/train/xgboost/__pycache__/_xgboost_utils.cpython-311.pyc
ADDED
Binary file (10.5 kB).
.venv/lib/python3.11/site-packages/ray/train/xgboost/__pycache__/config.cpython-311.pyc
ADDED
Binary file (10.7 kB).
.venv/lib/python3.11/site-packages/ray/train/xgboost/__pycache__/v2.cpython-311.pyc
ADDED
Binary file (6.67 kB).
.venv/lib/python3.11/site-packages/ray/train/xgboost/__pycache__/xgboost_checkpoint.cpython-311.pyc
ADDED
Binary file (4.11 kB).
.venv/lib/python3.11/site-packages/ray/train/xgboost/__pycache__/xgboost_predictor.cpython-311.pyc
ADDED
Binary file (7.94 kB).
.venv/lib/python3.11/site-packages/ray/train/xgboost/__pycache__/xgboost_trainer.cpython-311.pyc
ADDED
Binary file (11.1 kB).
.venv/lib/python3.11/site-packages/ray/train/xgboost/_xgboost_utils.py
ADDED
@@ -0,0 +1,210 @@
import tempfile
from collections import OrderedDict
from contextlib import contextmanager
from pathlib import Path
from typing import Callable, Dict, List, Optional, Union

from xgboost.core import Booster

import ray.train
from ray.train import Checkpoint
from ray.tune.utils import flatten_dict
from ray.util.annotations import PublicAPI

try:
    from xgboost.callback import TrainingCallback
except ImportError:

    class TrainingCallback:
        pass


class TuneCallback(TrainingCallback):
    # TODO(justinvyu): [code_removal] Remove this after enforcing min xgboost version.
    """Base class for Tune's XGBoost callbacks."""

    def __call__(self, env):
        """Compatibility with xgboost<1.3"""
        return self.after_iteration(
            env.model, env.iteration, env.evaluation_result_list
        )

    def after_iteration(self, model: Booster, epoch: int, evals_log: Dict):
        raise NotImplementedError


@PublicAPI(stability="beta")
class RayTrainReportCallback(TuneCallback):
    """XGBoost callback to save checkpoints and report metrics.

    Args:
        metrics: Metrics to report. If this is a list,
            each item describes the metric key reported to XGBoost,
            and it will be reported under the same name.
            This can also be a dict of {<key-to-report>: <xgboost-metric-key>},
            which can be used to rename xgboost default metrics.
        filename: Customize the saved checkpoint file type by passing
            a filename. Defaults to "model.ubj".
        frequency: How often to save checkpoints, in terms of iterations.
            Defaults to 0 (no checkpoints are saved during training).
        checkpoint_at_end: Whether or not to save a checkpoint at the end of training.
        results_postprocessing_fn: An optional Callable that takes in
            the metrics dict that will be reported (after it has been flattened)
            and returns a modified dict. For example, this can be used to
            average results across CV folds when using ``xgboost.cv``.

    Examples
    --------

    Reporting checkpoints and metrics to Ray Tune when running many
    independent xgboost trials (without data parallelism within a trial).

    .. testcode::
        :skipif: True

        import xgboost

        from ray.tune import Tuner
        from ray.train.xgboost import RayTrainReportCallback

        def train_fn(config):
            # Report log loss to Ray Tune after each validation epoch.
            bst = xgboost.train(
                ...,
                callbacks=[
                    RayTrainReportCallback(
                        metrics={"loss": "eval-logloss"}, frequency=1
                    )
                ],
            )

        tuner = Tuner(train_fn)
        results = tuner.fit()

    Loading a model from a checkpoint reported by this callback.

    .. testcode::
        :skipif: True

        from ray.train.xgboost import RayTrainReportCallback

        # Get a `Checkpoint` object that is saved by the callback during training.
        result = trainer.fit()
        booster = RayTrainReportCallback.get_model(result.checkpoint)

    """

    CHECKPOINT_NAME = "model.ubj"

    def __init__(
        self,
        metrics: Optional[Union[str, List[str], Dict[str, str]]] = None,
        filename: str = CHECKPOINT_NAME,
        frequency: int = 0,
        checkpoint_at_end: bool = True,
        results_postprocessing_fn: Optional[
            Callable[[Dict[str, Union[float, List[float]]]], Dict[str, float]]
        ] = None,
    ):
        if isinstance(metrics, str):
            metrics = [metrics]
        self._metrics = metrics
        self._filename = filename
        self._frequency = frequency
        self._checkpoint_at_end = checkpoint_at_end
        self._results_postprocessing_fn = results_postprocessing_fn

        # Keeps track of the eval metrics from the last iteration,
        # so that the latest metrics can be reported with the checkpoint
        # at the end of training.
        self._evals_log = None
        # Keep track of the last checkpoint iteration to avoid double-checkpointing
        # when using `checkpoint_at_end=True`.
        self._last_checkpoint_iteration = None

    @classmethod
    def get_model(
        cls, checkpoint: Checkpoint, filename: str = CHECKPOINT_NAME
    ) -> Booster:
        """Retrieve the model stored in a checkpoint reported by this callback.

        Args:
            checkpoint: The checkpoint object returned by a training run.
                The checkpoint should be saved by an instance of this callback.
            filename: The filename to load the model from, which should match
                the filename used when creating the callback.
        """
        with checkpoint.as_directory() as checkpoint_path:
            booster = Booster()
            booster.load_model(Path(checkpoint_path, filename).as_posix())
            return booster

    def _get_report_dict(self, evals_log):
        if isinstance(evals_log, OrderedDict):
            # xgboost>=1.3
            result_dict = flatten_dict(evals_log, delimiter="-")
            for k in list(result_dict):
                result_dict[k] = result_dict[k][-1]
        else:
            # xgboost<1.3
            result_dict = dict(evals_log)
        if not self._metrics:
            report_dict = result_dict
        else:
            report_dict = {}
            for key in self._metrics:
                if isinstance(self._metrics, dict):
                    metric = self._metrics[key]
                else:
                    metric = key
                report_dict[key] = result_dict[metric]

        if self._results_postprocessing_fn:
            report_dict = self._results_postprocessing_fn(report_dict)

        return report_dict

    @contextmanager
    def _get_checkpoint(self, model: Booster) -> Optional[Checkpoint]:
        # NOTE: The world rank returns None for Tune usage without Train.
        if ray.train.get_context().get_world_rank() in (0, None):
            with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
                model.save_model(Path(temp_checkpoint_dir, self._filename).as_posix())
                yield Checkpoint(temp_checkpoint_dir)
        else:
            yield None

    def after_iteration(self, model: Booster, epoch: int, evals_log: Dict):
        self._evals_log = evals_log

        checkpointing_disabled = self._frequency == 0
        # Ex: if frequency=2, checkpoint at epoch 1, 3, 5, ... (counting from 0)
        should_checkpoint = (
            not checkpointing_disabled and (epoch + 1) % self._frequency == 0
        )

        report_dict = self._get_report_dict(evals_log)
        if should_checkpoint:
            self._last_checkpoint_iteration = epoch
            with self._get_checkpoint(model=model) as checkpoint:
                ray.train.report(report_dict, checkpoint=checkpoint)
        else:
            ray.train.report(report_dict)

    def after_training(self, model: Booster) -> Booster:
        if not self._checkpoint_at_end:
            return model

        if (
            self._last_checkpoint_iteration is not None
            and model.num_boosted_rounds() - 1 == self._last_checkpoint_iteration
        ):
            # Avoids a duplicate checkpoint if the checkpoint frequency happens
            # to align with the last iteration.
            return model

        report_dict = self._get_report_dict(self._evals_log) if self._evals_log else {}
        with self._get_checkpoint(model=model) as checkpoint:
            ray.train.report(report_dict, checkpoint=checkpoint)

        return model
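As the docstrings above note, the `filename` passed to the callback and to `get_model` must match. A brief hedged sketch of that pairing; the `result` object is assumed to come from a finished `trainer.fit()` call and is not defined here.

from ray.train.xgboost import RayTrainReportCallback

# During training: save the booster under a custom filename every 10 iterations.
callback = RayTrainReportCallback(filename="my_model.ubj", frequency=10)

# After training: load it back with the same filename.
# `result` is assumed to be the Result returned by trainer.fit().
booster = RayTrainReportCallback.get_model(result.checkpoint, filename="my_model.ubj")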
.venv/lib/python3.11/site-packages/ray/train/xgboost/config.py
ADDED
@@ -0,0 +1,202 @@
import json
import logging
import os
import threading
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Optional

import xgboost
from packaging.version import Version
from xgboost import RabitTracker
from xgboost.collective import CommunicatorContext

import ray
from ray.train._internal.worker_group import WorkerGroup
from ray.train.backend import Backend, BackendConfig

logger = logging.getLogger(__name__)


@dataclass
class XGBoostConfig(BackendConfig):
    """Configuration for xgboost collective communication setup.

    Ray Train will set up the necessary coordinator processes and environment
    variables for your workers to communicate with each other.
    Additional configuration options can be passed into the
    `xgboost.collective.CommunicatorContext` that wraps your own `xgboost.train` code.

    See the `xgboost.collective` module for more information:
    https://github.com/dmlc/xgboost/blob/master/python-package/xgboost/collective.py

    Args:
        xgboost_communicator: The backend to use for collective communication for
            distributed xgboost training. For now, only "rabit" is supported.
    """

    xgboost_communicator: str = "rabit"

    @property
    def train_func_context(self):
        @contextmanager
        def collective_communication_context():
            with CommunicatorContext(**_get_xgboost_args()):
                yield

        return collective_communication_context

    @property
    def backend_cls(self):
        if self.xgboost_communicator == "rabit":
            return (
                _XGBoostRabitBackend
                if Version(xgboost.__version__) >= Version("2.1.0")
                else _XGBoostRabitBackend_pre_xgb210
            )

        raise NotImplementedError(f"Unsupported backend: {self.xgboost_communicator}")


class _XGBoostRabitBackend(Backend):
    def __init__(self):
        self._tracker: Optional[RabitTracker] = None
        self._wait_thread: Optional[threading.Thread] = None

    def _setup_xgboost_distributed_backend(self, worker_group: WorkerGroup):
        # Set up the rabit tracker on the Train driver.
        num_workers = len(worker_group)
        rabit_args = {"n_workers": num_workers}
        train_driver_ip = ray.util.get_node_ip_address()

        # NOTE: sortby="task" is needed to ensure that the xgboost worker ranks
        # align with Ray Train worker ranks.
        # The worker ranks will be sorted by `dmlc_task_id`,
        # which is defined below.
        self._tracker = RabitTracker(
            n_workers=num_workers, host_ip=train_driver_ip, sortby="task"
        )
        self._tracker.start()

        # The RabitTracker is started in a separate thread, and the
        # `wait_for` method must be called for `worker_args` to return.
        self._wait_thread = threading.Thread(target=self._tracker.wait_for, daemon=True)
        self._wait_thread.start()

        rabit_args.update(self._tracker.worker_args())

        start_log = (
            "RabitTracker coordinator started with parameters:\n"
            f"{json.dumps(rabit_args, indent=2)}"
        )
        logger.debug(start_log)

        def set_xgboost_communicator_args(args):
            import ray.train

            args["dmlc_task_id"] = (
                f"[xgboost.ray-rank={ray.train.get_context().get_world_rank():08}]:"
                f"{ray.get_runtime_context().get_actor_id()}"
            )

            _set_xgboost_args(args)

        worker_group.execute(set_xgboost_communicator_args, rabit_args)

    def on_training_start(
        self, worker_group: WorkerGroup, backend_config: XGBoostConfig
    ):
        assert backend_config.xgboost_communicator == "rabit"
        self._setup_xgboost_distributed_backend(worker_group)

    def on_shutdown(self, worker_group: WorkerGroup, backend_config: XGBoostConfig):
        timeout = 5

        if self._wait_thread is not None:
            self._wait_thread.join(timeout=timeout)

            if self._wait_thread.is_alive():
                logger.warning(
                    "During shutdown, the RabitTracker thread failed to join "
                    f"within {timeout} seconds. "
                    "The process will still be terminated as part of Ray actor cleanup."
                )


class _XGBoostRabitBackend_pre_xgb210(Backend):
    def __init__(self):
        self._tracker: Optional[RabitTracker] = None

    def _setup_xgboost_distributed_backend(self, worker_group: WorkerGroup):
        # Set up the rabit tracker on the Train driver.
        num_workers = len(worker_group)
        rabit_args = {"DMLC_NUM_WORKER": num_workers}
        train_driver_ip = ray.util.get_node_ip_address()

        # NOTE: sortby="task" is needed to ensure that the xgboost worker ranks
        # align with Ray Train worker ranks.
        # The worker ranks will be sorted by `DMLC_TASK_ID`,
        # which is defined below.
        self._tracker = RabitTracker(
            n_workers=num_workers, host_ip=train_driver_ip, sortby="task"
        )
        self._tracker.start(n_workers=num_workers)

        worker_args = self._tracker.worker_envs()
        rabit_args.update(worker_args)

        start_log = (
            "RabitTracker coordinator started with parameters:\n"
            f"{json.dumps(rabit_args, indent=2)}"
        )
        logger.debug(start_log)

        def set_xgboost_env_vars():
            import ray.train

            for k, v in rabit_args.items():
                os.environ[k] = str(v)

            # Ranks are assigned in increasing order of the worker's task id.
            # This task id will be sorted by increasing world rank.
            os.environ["DMLC_TASK_ID"] = (
                f"[xgboost.ray-rank={ray.train.get_context().get_world_rank():08}]:"
                f"{ray.get_runtime_context().get_actor_id()}"
            )

        worker_group.execute(set_xgboost_env_vars)

    def on_training_start(
        self, worker_group: WorkerGroup, backend_config: XGBoostConfig
    ):
        assert backend_config.xgboost_communicator == "rabit"
        self._setup_xgboost_distributed_backend(worker_group)

    def on_shutdown(self, worker_group: WorkerGroup, backend_config: XGBoostConfig):
        if not self._tracker:
            return

        timeout = 5
        self._tracker.thread.join(timeout=timeout)

        if self._tracker.thread.is_alive():
            logger.warning(
                "During shutdown, the RabitTracker thread failed to join "
                f"within {timeout} seconds. "
                "The process will still be terminated as part of Ray actor cleanup."
            )


_xgboost_args: dict = {}
_xgboost_args_lock = threading.Lock()


def _set_xgboost_args(args):
    with _xgboost_args_lock:
        global _xgboost_args
        _xgboost_args = args


def _get_xgboost_args() -> dict:
    with _xgboost_args_lock:
        return _xgboost_args
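Since `XGBoostConfig` currently supports only the "rabit" communicator, passing it explicitly mainly makes the backend choice visible in user code. A minimal sketch of wiring it into the trainer defined in `v2.py` below; the worker function body is elided and illustrative only.

import ray.train
from ray.train.xgboost import XGBoostConfig
from ray.train.xgboost.v2 import XGBoostTrainer


def train_fn_per_worker(config: dict):
    # Inside this function, calls into xgboost run under the CommunicatorContext
    # provided by XGBoostConfig.train_func_context.
    ...


trainer = XGBoostTrainer(
    train_fn_per_worker,
    xgboost_config=XGBoostConfig(xgboost_communicator="rabit"),
    scaling_config=ray.train.ScalingConfig(num_workers=2),
)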
.venv/lib/python3.11/site-packages/ray/train/xgboost/v2.py
ADDED
@@ -0,0 +1,133 @@
import logging
from typing import Any, Callable, Dict, Optional, Union

import ray.train
from ray.train import Checkpoint
from ray.train.data_parallel_trainer import DataParallelTrainer
from ray.train.trainer import GenDataset
from ray.train.xgboost import XGBoostConfig

logger = logging.getLogger(__name__)


class XGBoostTrainer(DataParallelTrainer):
    """A Trainer for distributed data-parallel XGBoost training.

    Example
    -------

    .. testcode::

        import xgboost

        import ray.data
        import ray.train
        from ray.train.xgboost import RayTrainReportCallback
        from ray.train.xgboost.v2 import XGBoostTrainer

        def train_fn_per_worker(config: dict):
            # (Optional) Add logic to resume training state from a checkpoint.
            # ray.train.get_checkpoint()

            # 1. Get the dataset shard for the worker and convert to a `xgboost.DMatrix`
            train_ds_iter, eval_ds_iter = (
                ray.train.get_dataset_shard("train"),
                ray.train.get_dataset_shard("validation"),
            )
            train_ds, eval_ds = train_ds_iter.materialize(), eval_ds_iter.materialize()

            train_df, eval_df = train_ds.to_pandas(), eval_ds.to_pandas()
            train_X, train_y = train_df.drop("y", axis=1), train_df["y"]
            eval_X, eval_y = eval_df.drop("y", axis=1), eval_df["y"]

            dtrain = xgboost.DMatrix(train_X, label=train_y)
            deval = xgboost.DMatrix(eval_X, label=eval_y)

            params = {
                "tree_method": "approx",
                "objective": "reg:squarederror",
                "eta": 1e-4,
                "subsample": 0.5,
                "max_depth": 2,
            }

            # 2. Do distributed data-parallel training.
            # Ray Train sets up the necessary coordinator processes and
            # environment variables for your workers to communicate with each other.
            bst = xgboost.train(
                params,
                dtrain=dtrain,
                evals=[(deval, "validation")],
                num_boost_round=10,
                callbacks=[RayTrainReportCallback()],
            )

        train_ds = ray.data.from_items([{"x": x, "y": x + 1} for x in range(32)])
        eval_ds = ray.data.from_items([{"x": x, "y": x + 1} for x in range(16)])
        trainer = XGBoostTrainer(
            train_fn_per_worker,
            datasets={"train": train_ds, "validation": eval_ds},
            scaling_config=ray.train.ScalingConfig(num_workers=4),
        )
        result = trainer.fit()
        booster = RayTrainReportCallback.get_model(result.checkpoint)

    .. testoutput::
        :hide:

        ...

    Args:
        train_loop_per_worker: The training function to execute on each worker.
            This function can either take in zero arguments or a single ``Dict``
            argument which is set by defining ``train_loop_config``.
            Within this function you can use any of the
            :ref:`Ray Train Loop utilities <train-loop-api>`.
        train_loop_config: A configuration ``Dict`` to pass in as an argument to
            ``train_loop_per_worker``.
            This is typically used for specifying hyperparameters.
        xgboost_config: The configuration for setting up the distributed xgboost
            backend. Defaults to using the "rabit" backend.
            See :class:`~ray.train.xgboost.XGBoostConfig` for more info.
        datasets: The Ray Datasets to use for training and validation.
        dataset_config: The configuration for ingesting the input ``datasets``.
            By default, all the Ray Datasets are split equally across workers.
            See :class:`~ray.train.DataConfig` for more details.
        scaling_config: The configuration for how to scale data parallel training.
            ``num_workers`` determines how many Python processes are used for training,
            and ``use_gpu`` determines whether or not each process should use GPUs.
            See :class:`~ray.train.ScalingConfig` for more info.
        run_config: The configuration for the execution of the training run.
            See :class:`~ray.train.RunConfig` for more info.
        resume_from_checkpoint: A checkpoint to resume training from.
            This checkpoint can be accessed from within ``train_loop_per_worker``
            by calling ``ray.train.get_checkpoint()``.
        metadata: Dict that should be made available via
            `ray.train.get_context().get_metadata()` and in `checkpoint.get_metadata()`
            for checkpoints saved from this Trainer. Must be JSON-serializable.
    """

    def __init__(
        self,
        train_loop_per_worker: Union[Callable[[], None], Callable[[Dict], None]],
        *,
        train_loop_config: Optional[Dict] = None,
        xgboost_config: Optional[XGBoostConfig] = None,
        scaling_config: Optional[ray.train.ScalingConfig] = None,
        run_config: Optional[ray.train.RunConfig] = None,
        datasets: Optional[Dict[str, GenDataset]] = None,
        dataset_config: Optional[ray.train.DataConfig] = None,
        metadata: Optional[Dict[str, Any]] = None,
        resume_from_checkpoint: Optional[Checkpoint] = None,
    ):
        super(XGBoostTrainer, self).__init__(
            train_loop_per_worker=train_loop_per_worker,
            train_loop_config=train_loop_config,
            backend_config=xgboost_config or XGBoostConfig(),
            scaling_config=scaling_config,
            dataset_config=dataset_config,
            run_config=run_config,
            datasets=datasets,
            resume_from_checkpoint=resume_from_checkpoint,
            metadata=metadata,
        )
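The ``train_loop_config`` argument documented above is forwarded to the worker function as its single dict argument. A minimal sketch of threading a hyperparameter through it; the key name and value are illustrative only.

import ray.train
from ray.train.xgboost.v2 import XGBoostTrainer


def train_fn_per_worker(config: dict):
    # `config` is the `train_loop_config` dict passed to the trainer.
    num_boost_round = config["num_boost_round"]
    # ... build DMatrix objects and call xgboost.train(..., num_boost_round=num_boost_round) ...
    ...


trainer = XGBoostTrainer(
    train_fn_per_worker,
    train_loop_config={"num_boost_round": 50},
    scaling_config=ray.train.ScalingConfig(num_workers=2),
)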
.venv/lib/python3.11/site-packages/ray/train/xgboost/xgboost_checkpoint.py
ADDED
@@ -0,0 +1,75 @@
import tempfile
from pathlib import Path
from typing import TYPE_CHECKING, Optional

import xgboost

from ray.train._internal.framework_checkpoint import FrameworkCheckpoint
from ray.util.annotations import PublicAPI

if TYPE_CHECKING:
    from ray.data.preprocessor import Preprocessor


@PublicAPI(stability="beta")
class XGBoostCheckpoint(FrameworkCheckpoint):
    """A :py:class:`~ray.train.Checkpoint` with XGBoost-specific functionality."""

    MODEL_FILENAME = "model.json"

    @classmethod
    def from_model(
        cls,
        booster: xgboost.Booster,
        *,
        preprocessor: Optional["Preprocessor"] = None,
        path: Optional[str] = None,
    ) -> "XGBoostCheckpoint":
        """Create a :py:class:`~ray.train.Checkpoint` that stores an XGBoost
        model.

        Args:
            booster: The XGBoost model to store in the checkpoint.
            preprocessor: A fitted preprocessor to be applied before inference.
            path: The path to the directory where the checkpoint file will be saved.
                This should start as an empty directory, since the *entire*
                directory will be treated as the checkpoint when reported.
                By default, a temporary directory will be created.

        Returns:
            An :py:class:`XGBoostCheckpoint` containing the specified ``Estimator``.

        Examples:

            .. testcode::

                import numpy as np
                import ray
                from ray.train.xgboost import XGBoostCheckpoint
                import xgboost

                train_X = np.array([[1, 2], [3, 4]])
                train_y = np.array([0, 1])

                model = xgboost.XGBClassifier().fit(train_X, train_y)
                checkpoint = XGBoostCheckpoint.from_model(model.get_booster())

        """
        checkpoint_path = Path(path or tempfile.mkdtemp())

        if not checkpoint_path.is_dir():
            raise ValueError(f"`path` must be a directory, but got: {checkpoint_path}")

        booster.save_model(checkpoint_path.joinpath(cls.MODEL_FILENAME).as_posix())

        checkpoint = cls.from_directory(checkpoint_path.as_posix())
        if preprocessor:
            checkpoint.set_preprocessor(preprocessor)
        return checkpoint

    def get_model(self) -> xgboost.Booster:
        """Retrieve the XGBoost model stored in this checkpoint."""
        with self.as_directory() as checkpoint_path:
            booster = xgboost.Booster()
            booster.load_model(Path(checkpoint_path, self.MODEL_FILENAME).as_posix())
            return booster
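A short sketch of the round trip supported by `XGBoostCheckpoint` above: save a booster with `from_model`, then recover it later with `get_model`. It assumes only what this class provides, plus `numpy` and `xgboost`; the tiny training data is illustrative.

import numpy as np
import xgboost

from ray.train.xgboost import XGBoostCheckpoint

train_X = np.array([[1, 2], [3, 4]])
train_y = np.array([0, 1])
booster = xgboost.XGBClassifier().fit(train_X, train_y).get_booster()

# Write the model into a (temporary) checkpoint directory...
checkpoint = XGBoostCheckpoint.from_model(booster)

# ...and load it back elsewhere, e.g. for batch inference.
restored = checkpoint.get_model()
print(restored.num_boosted_rounds())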
.venv/lib/python3.11/site-packages/ray/train/xgboost/xgboost_predictor.py
ADDED
@@ -0,0 +1,160 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

import pandas as pd
import xgboost

from ray.air.constants import TENSOR_COLUMN_NAME
from ray.air.data_batch_type import DataBatchType
from ray.air.util.data_batch_conversion import _unwrap_ndarray_object_type_if_needed
from ray.train.predictor import Predictor
from ray.train.xgboost import XGBoostCheckpoint
from ray.util.annotations import PublicAPI

if TYPE_CHECKING:
    from ray.data.preprocessor import Preprocessor


@PublicAPI(stability="beta")
class XGBoostPredictor(Predictor):
    """A predictor for XGBoost models.

    Args:
        model: The XGBoost booster to use for predictions.
        preprocessor: A preprocessor used to transform data batches prior
            to prediction.
    """

    def __init__(
        self, model: xgboost.Booster, preprocessor: Optional["Preprocessor"] = None
    ):
        self.model = model
        super().__init__(preprocessor)

    def __repr__(self):
        return (
            f"{self.__class__.__name__}(model={self.model!r}, "
            f"preprocessor={self._preprocessor!r})"
        )

    @classmethod
    def from_checkpoint(cls, checkpoint: XGBoostCheckpoint) -> "XGBoostPredictor":
        """Instantiate the predictor from a Checkpoint.

        This is a helper constructor that instantiates the predictor from a
        framework-specific XGBoost checkpoint.

        Args:
            checkpoint: The checkpoint to load the model and preprocessor from.
        """
        model = checkpoint.get_model()
        preprocessor = checkpoint.get_preprocessor()
        return cls(model=model, preprocessor=preprocessor)

    def predict(
        self,
        data: DataBatchType,
        feature_columns: Optional[Union[List[str], List[int]]] = None,
        dmatrix_kwargs: Optional[Dict[str, Any]] = None,
        **predict_kwargs,
    ) -> DataBatchType:
        """Run inference on a data batch.

        The data is converted into an XGBoost DMatrix before being passed to
        the model.

        Args:
            data: A batch of input data.
            feature_columns: The names or indices of the columns in the
                data to use as features to predict on. If None, then use
                all columns in ``data``.
            dmatrix_kwargs: Dict of keyword arguments passed to ``xgboost.DMatrix``.
            **predict_kwargs: Keyword arguments passed to ``xgboost.Booster.predict``.

        Examples:

            .. testcode::

                import numpy as np
                import xgboost as xgb
                from ray.train.xgboost import XGBoostPredictor

                train_X = np.array([[1, 2], [3, 4]])
                train_y = np.array([0, 1])

                model = xgb.XGBClassifier().fit(train_X, train_y)
                predictor = XGBoostPredictor(model=model.get_booster())

                data = np.array([[1, 2], [3, 4]])
                predictions = predictor.predict(data)

                # Only use first and second column as the feature
                data = np.array([[1, 2, 8], [3, 4, 9]])
                predictions = predictor.predict(data, feature_columns=[0, 1])

            .. testcode::

                import pandas as pd
                import xgboost as xgb
                from ray.train.xgboost import XGBoostPredictor

                train_X = pd.DataFrame([[1, 2], [3, 4]], columns=["A", "B"])
                train_y = pd.Series([0, 1])

                model = xgb.XGBClassifier().fit(train_X, train_y)
                predictor = XGBoostPredictor(model=model.get_booster())

                # Pandas dataframe.
                data = pd.DataFrame([[1, 2], [3, 4]], columns=["A", "B"])
                predictions = predictor.predict(data)

                # Only use first and second column as the feature
                data = pd.DataFrame([[1, 2, 8], [3, 4, 9]], columns=["A", "B", "C"])
                predictions = predictor.predict(data, feature_columns=["A", "B"])

        Returns:
            Prediction result.
        """
        return Predictor.predict(
            self,
            data,
            feature_columns=feature_columns,
            dmatrix_kwargs=dmatrix_kwargs,
            **predict_kwargs,
        )

    def _predict_pandas(
        self,
        data: "pd.DataFrame",
        feature_columns: Optional[Union[List[str], List[int]]] = None,
        dmatrix_kwargs: Optional[Dict[str, Any]] = None,
        **predict_kwargs,
    ) -> "pd.DataFrame":
        dmatrix_kwargs = dmatrix_kwargs or {}

        feature_names = None
        if TENSOR_COLUMN_NAME in data:
            data = data[TENSOR_COLUMN_NAME].to_numpy()
            data = _unwrap_ndarray_object_type_if_needed(data)
            if feature_columns:
                # In this case feature_columns is a list of integers
                data = data[:, feature_columns]
        elif feature_columns:
            # feature_columns is a list of integers or strings
            data = data[feature_columns].to_numpy()
            # Only set the feature names if they are strings
            if all(isinstance(fc, str) for fc in feature_columns):
                feature_names = feature_columns
        else:
            feature_columns = data.columns.tolist()
            data = data.to_numpy()

            if all(isinstance(fc, str) for fc in feature_columns):
                feature_names = feature_columns

        if feature_names:
            dmatrix_kwargs["feature_names"] = feature_names

        matrix = xgboost.DMatrix(data, **dmatrix_kwargs)
        df = pd.DataFrame(self.model.predict(matrix, **predict_kwargs))
        df.columns = (
            ["predictions"]
            if len(df.columns) == 1
            else [f"predictions_{i}" for i in range(len(df.columns))]
        )
        return df
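Tying the two classes together: ``XGBoostPredictor.from_checkpoint`` pulls both the booster and any stored preprocessor out of an ``XGBoostCheckpoint``, and ``predict`` forwards ``dmatrix_kwargs`` to ``xgboost.DMatrix`` and any remaining keyword arguments to ``Booster.predict``. A minimal sketch under the same assumptions as above (local ``ray``/``xgboost`` install, illustrative toy data):

    import numpy as np
    import xgboost

    from ray.train.xgboost import XGBoostCheckpoint, XGBoostPredictor

    train_X = np.array([[1, 2], [3, 4]])
    train_y = np.array([0, 1])
    booster = xgboost.XGBClassifier().fit(train_X, train_y).get_booster()
    checkpoint = XGBoostCheckpoint.from_model(booster)

    predictor = XGBoostPredictor.from_checkpoint(checkpoint)
    # "predictions" holds one prediction per input row.
    predictions = predictor.predict(np.array([[1, 2], [3, 4]]))
    print(predictions)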
.venv/lib/python3.11/site-packages/ray/train/xgboost/xgboost_trainer.py
ADDED
@@ -0,0 +1,222 @@
import logging
from functools import partial
from typing import Any, Dict, Optional

import xgboost
from packaging.version import Version

import ray.train
from ray.train import Checkpoint
from ray.train.constants import _DEPRECATED_VALUE, TRAIN_DATASET_KEY
from ray.train.trainer import GenDataset
from ray.train.xgboost import RayTrainReportCallback
from ray.train.xgboost.v2 import XGBoostTrainer as SimpleXGBoostTrainer
from ray.util.annotations import PublicAPI

logger = logging.getLogger(__name__)


def _xgboost_train_fn_per_worker(
    config: dict,
    label_column: str,
    num_boost_round: int,
    dataset_keys: set,
    xgboost_train_kwargs: dict,
):
    checkpoint = ray.train.get_checkpoint()
    starting_model = None
    remaining_iters = num_boost_round
    if checkpoint:
        starting_model = RayTrainReportCallback.get_model(checkpoint)
        starting_iter = starting_model.num_boosted_rounds()
        remaining_iters = num_boost_round - starting_iter
        logger.info(
            f"Model loaded from checkpoint will train for "
            f"additional {remaining_iters} iterations (trees) in order "
            "to achieve the target number of iterations "
            f"({num_boost_round=})."
        )

    train_ds_iter = ray.train.get_dataset_shard(TRAIN_DATASET_KEY)
    train_df = train_ds_iter.materialize().to_pandas()

    eval_ds_iters = {
        k: ray.train.get_dataset_shard(k)
        for k in dataset_keys
        if k != TRAIN_DATASET_KEY
    }
    eval_dfs = {k: d.materialize().to_pandas() for k, d in eval_ds_iters.items()}

    train_X, train_y = train_df.drop(label_column, axis=1), train_df[label_column]
    dtrain = xgboost.DMatrix(train_X, label=train_y)

    # NOTE: Include the training dataset in the evaluation datasets.
    # This allows `train-*` metrics to be calculated and reported.
    evals = [(dtrain, TRAIN_DATASET_KEY)]

    for eval_name, eval_df in eval_dfs.items():
        eval_X, eval_y = eval_df.drop(label_column, axis=1), eval_df[label_column]
        evals.append((xgboost.DMatrix(eval_X, label=eval_y), eval_name))

    evals_result = {}
    xgboost.train(
        config,
        dtrain=dtrain,
        evals=evals,
        evals_result=evals_result,
        num_boost_round=remaining_iters,
        xgb_model=starting_model,
        **xgboost_train_kwargs,
    )


@PublicAPI(stability="beta")
class XGBoostTrainer(SimpleXGBoostTrainer):
    """A Trainer for data parallel XGBoost training.

    This Trainer runs the XGBoost training loop in a distributed manner
    using multiple Ray Actors.

    .. note::
        ``XGBoostTrainer`` does not modify or otherwise alter the working
        of the XGBoost distributed training algorithm.
        Ray only provides orchestration, data ingestion, and fault tolerance.
        For more information on XGBoost distributed training, refer to
        the `XGBoost documentation <https://xgboost.readthedocs.io>`__.

    Example:
        .. testcode::

            import ray

            from ray.train.xgboost import XGBoostTrainer
            from ray.train import ScalingConfig

            train_dataset = ray.data.from_items(
                [{"x": x, "y": x + 1} for x in range(32)])
            trainer = XGBoostTrainer(
                label_column="y",
                params={"objective": "reg:squarederror"},
                scaling_config=ScalingConfig(num_workers=3),
                datasets={"train": train_dataset},
            )
            result = trainer.fit()

        .. testoutput::
            :hide:

            ...

    Args:
        datasets: The Ray Datasets to use for training and validation. Must include a
            "train" key denoting the training dataset. All non-training datasets will
            be used as separate validation sets, each reporting a separate metric.
        label_column: Name of the label column. A column with this name
            must be present in the training dataset.
        params: XGBoost training parameters.
            Refer to `XGBoost documentation <https://xgboost.readthedocs.io/>`_
            for a list of possible parameters.
        num_boost_round: Target number of boosting iterations (trees in the model).
            Note that unlike in ``xgboost.train``, this is the target number
            of trees, meaning that if you set ``num_boost_round=10`` and pass a model
            that has already been trained for 5 iterations, it will be trained for 5
            iterations more, instead of 10 more.
        scaling_config: Configuration for how to scale data parallel training.
        run_config: Configuration for the execution of the training run.
        dataset_config: The configuration for ingesting the input ``datasets``.
            By default, all the Ray Datasets are split equally across workers.
            See :class:`~ray.train.DataConfig` for more details.
        resume_from_checkpoint: A checkpoint to resume training from.
        metadata: Dict that should be made available in `checkpoint.get_metadata()`
            for checkpoints saved from this Trainer. Must be JSON-serializable.
        **train_kwargs: Additional kwargs passed to the ``xgboost.train()`` function.
    """

    _handles_checkpoint_freq = True
    _handles_checkpoint_at_end = True

    def __init__(
        self,
        *,
        datasets: Dict[str, GenDataset],
        label_column: str,
        params: Dict[str, Any],
        dmatrix_params: Optional[Dict[str, Dict[str, Any]]] = _DEPRECATED_VALUE,
        num_boost_round: int = 10,
        scaling_config: Optional[ray.train.ScalingConfig] = None,
        run_config: Optional[ray.train.RunConfig] = None,
        dataset_config: Optional[ray.train.DataConfig] = None,
        resume_from_checkpoint: Optional[Checkpoint] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **train_kwargs,
    ):
        if Version(xgboost.__version__) < Version("1.7.0"):
            raise ImportError(
                "`XGBoostTrainer` requires the `xgboost` version to be >= 1.7.0. "
                'Upgrade with: `pip install -U "xgboost>=1.7"`'
            )

        # TODO(justinvyu): [Deprecated] Remove in 2.11
        if dmatrix_params != _DEPRECATED_VALUE:
            raise DeprecationWarning(
                "`dmatrix_params` is deprecated, since XGBoostTrainer no longer "
                "depends on the `xgboost_ray.RayDMatrix` utility. "
                "You can remove this argument and use `dataset_config` instead "
                "to customize Ray Dataset ingestion."
            )

        # Initialize a default Ray Train metrics/checkpoint reporting callback if needed
        callbacks = train_kwargs.get("callbacks", [])
        user_supplied_callback = any(
            isinstance(callback, RayTrainReportCallback) for callback in callbacks
        )
        callback_kwargs = {}
        if run_config:
            checkpoint_frequency = run_config.checkpoint_config.checkpoint_frequency
            checkpoint_at_end = run_config.checkpoint_config.checkpoint_at_end

            callback_kwargs["frequency"] = checkpoint_frequency
            # Default `checkpoint_at_end=True` unless the user explicitly sets it.
            callback_kwargs["checkpoint_at_end"] = (
                checkpoint_at_end if checkpoint_at_end is not None else True
            )

        if not user_supplied_callback:
            callbacks.append(RayTrainReportCallback(**callback_kwargs))
        train_kwargs["callbacks"] = callbacks

        train_fn_per_worker = partial(
            _xgboost_train_fn_per_worker,
            label_column=label_column,
            num_boost_round=num_boost_round,
            dataset_keys=set(datasets),
            xgboost_train_kwargs=train_kwargs,
        )

        super(XGBoostTrainer, self).__init__(
            train_loop_per_worker=train_fn_per_worker,
            train_loop_config=params,
            scaling_config=scaling_config,
            run_config=run_config,
            datasets=datasets,
            dataset_config=dataset_config,
            resume_from_checkpoint=resume_from_checkpoint,
            metadata=metadata,
        )

    @classmethod
    def get_model(
        cls,
        checkpoint: Checkpoint,
    ) -> xgboost.Booster:
        """Retrieve the XGBoost model stored in this checkpoint."""
        return RayTrainReportCallback.get_model(checkpoint)

    def _validate_attributes(self):
        super()._validate_attributes()

        if TRAIN_DATASET_KEY not in self.datasets:
            raise KeyError(
                f"'{TRAIN_DATASET_KEY}' key must be present in `datasets`. "
                f"Got {list(self.datasets.keys())}"
            )
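After ``trainer.fit()`` returns, the final checkpoint written by ``RayTrainReportCallback`` can be turned back into a plain ``xgboost.Booster`` with the ``XGBoostTrainer.get_model`` classmethod and used for local scoring. A minimal sketch, not part of the packaged file, assuming a local Ray runtime is available; the dataset, parameters, and worker count are illustrative and mirror the docstring example:

    import numpy as np
    import ray
    import xgboost

    from ray.train import ScalingConfig
    from ray.train.xgboost import XGBoostTrainer

    train_dataset = ray.data.from_items([{"x": x, "y": x + 1} for x in range(32)])
    trainer = XGBoostTrainer(
        label_column="y",
        params={"objective": "reg:squarederror"},
        num_boost_round=5,
        scaling_config=ScalingConfig(num_workers=2),
        datasets={"train": train_dataset},
    )
    result = trainer.fit()

    # The final checkpoint was reported by RayTrainReportCallback during training.
    booster = XGBoostTrainer.get_model(result.checkpoint)
    preds = booster.predict(xgboost.DMatrix(np.array([[1.0]]), feature_names=["x"]))
    print(preds)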
.venv/lib/python3.11/site-packages/ray/tune/search/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (5.87 kB).