koichi12 committed
Commit f39d59b · verified · Parent: 68246e2

Add files using upload-large-folder tool

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full set.
Files changed (50)
  1. .venv/lib/python3.11/site-packages/ray/train/__init__.py +90 -0
  2. .venv/lib/python3.11/site-packages/ray/train/_checkpoint.py +424 -0
  3. .venv/lib/python3.11/site-packages/ray/train/backend.py +59 -0
  4. .venv/lib/python3.11/site-packages/ray/train/base_trainer.py +827 -0
  5. .venv/lib/python3.11/site-packages/ray/train/constants.py +118 -0
  6. .venv/lib/python3.11/site-packages/ray/train/context.py +139 -0
  7. .venv/lib/python3.11/site-packages/ray/train/data_parallel_trainer.py +587 -0
  8. .venv/lib/python3.11/site-packages/ray/train/error.py +6 -0
  9. .venv/lib/python3.11/site-packages/ray/train/examples/__init__.py +0 -0
  10. .venv/lib/python3.11/site-packages/ray/train/examples/mlflow_simple_example.py +55 -0
  11. .venv/lib/python3.11/site-packages/ray/train/examples/tf/tune_tensorflow_autoencoder_example.py +77 -0
  12. .venv/lib/python3.11/site-packages/ray/train/huggingface/__init__.py +0 -0
  13. .venv/lib/python3.11/site-packages/ray/train/huggingface/__pycache__/__init__.cpython-311.pyc +0 -0
  14. .venv/lib/python3.11/site-packages/ray/train/huggingface/transformers/__init__.py +12 -0
  15. .venv/lib/python3.11/site-packages/ray/train/huggingface/transformers/__pycache__/__init__.cpython-311.pyc +0 -0
  16. .venv/lib/python3.11/site-packages/ray/train/huggingface/transformers/__pycache__/_transformers_utils.cpython-311.pyc +0 -0
  17. .venv/lib/python3.11/site-packages/ray/train/huggingface/transformers/_transformers_utils.py +143 -0
  18. .venv/lib/python3.11/site-packages/ray/train/lightgbm/__init__.py +18 -0
  19. .venv/lib/python3.11/site-packages/ray/train/lightgbm/__pycache__/__init__.cpython-311.pyc +0 -0
  20. .venv/lib/python3.11/site-packages/ray/train/lightgbm/__pycache__/_lightgbm_utils.cpython-311.pyc +0 -0
  21. .venv/lib/python3.11/site-packages/ray/train/lightgbm/__pycache__/config.cpython-311.pyc +0 -0
  22. .venv/lib/python3.11/site-packages/ray/train/lightgbm/__pycache__/lightgbm_checkpoint.cpython-311.pyc +0 -0
  23. .venv/lib/python3.11/site-packages/ray/train/lightgbm/__pycache__/lightgbm_predictor.cpython-311.pyc +0 -0
  24. .venv/lib/python3.11/site-packages/ray/train/lightgbm/__pycache__/lightgbm_trainer.cpython-311.pyc +0 -0
  25. .venv/lib/python3.11/site-packages/ray/train/lightgbm/__pycache__/v2.cpython-311.pyc +0 -0
  26. .venv/lib/python3.11/site-packages/ray/train/lightgbm/_lightgbm_utils.py +170 -0
  27. .venv/lib/python3.11/site-packages/ray/train/lightgbm/config.py +89 -0
  28. .venv/lib/python3.11/site-packages/ray/train/lightgbm/lightgbm_checkpoint.py +70 -0
  29. .venv/lib/python3.11/site-packages/ray/train/lightgbm/lightgbm_predictor.py +152 -0
  30. .venv/lib/python3.11/site-packages/ray/train/lightgbm/lightgbm_trainer.py +221 -0
  31. .venv/lib/python3.11/site-packages/ray/train/lightgbm/v2.py +132 -0
  32. .venv/lib/python3.11/site-packages/ray/train/predictor.py +254 -0
  33. .venv/lib/python3.11/site-packages/ray/train/session.py +0 -0
  34. .venv/lib/python3.11/site-packages/ray/train/trainer.py +194 -0
  35. .venv/lib/python3.11/site-packages/ray/train/utils.py +19 -0
  36. .venv/lib/python3.11/site-packages/ray/train/xgboost/__init__.py +20 -0
  37. .venv/lib/python3.11/site-packages/ray/train/xgboost/__pycache__/__init__.cpython-311.pyc +0 -0
  38. .venv/lib/python3.11/site-packages/ray/train/xgboost/__pycache__/_xgboost_utils.cpython-311.pyc +0 -0
  39. .venv/lib/python3.11/site-packages/ray/train/xgboost/__pycache__/config.cpython-311.pyc +0 -0
  40. .venv/lib/python3.11/site-packages/ray/train/xgboost/__pycache__/v2.cpython-311.pyc +0 -0
  41. .venv/lib/python3.11/site-packages/ray/train/xgboost/__pycache__/xgboost_checkpoint.cpython-311.pyc +0 -0
  42. .venv/lib/python3.11/site-packages/ray/train/xgboost/__pycache__/xgboost_predictor.cpython-311.pyc +0 -0
  43. .venv/lib/python3.11/site-packages/ray/train/xgboost/__pycache__/xgboost_trainer.cpython-311.pyc +0 -0
  44. .venv/lib/python3.11/site-packages/ray/train/xgboost/_xgboost_utils.py +210 -0
  45. .venv/lib/python3.11/site-packages/ray/train/xgboost/config.py +202 -0
  46. .venv/lib/python3.11/site-packages/ray/train/xgboost/v2.py +133 -0
  47. .venv/lib/python3.11/site-packages/ray/train/xgboost/xgboost_checkpoint.py +75 -0
  48. .venv/lib/python3.11/site-packages/ray/train/xgboost/xgboost_predictor.py +160 -0
  49. .venv/lib/python3.11/site-packages/ray/train/xgboost/xgboost_trainer.py +222 -0
  50. .venv/lib/python3.11/site-packages/ray/tune/search/__pycache__/__init__.cpython-311.pyc +0 -0
.venv/lib/python3.11/site-packages/ray/train/__init__.py ADDED
@@ -0,0 +1,90 @@
+# Try import ray[train] core requirements (defined in setup.py)
+# isort: off
+try:
+    import fsspec  # noqa: F401
+    import pandas  # noqa: F401
+    import pyarrow  # noqa: F401
+    import requests  # noqa: F401
+except ImportError as exc:
+    raise ImportError(
+        "Can't import ray.train as some dependencies are missing. "
+        'Run `pip install "ray[train]"` to fix.'
+    ) from exc
+# isort: on
+
+
+from ray._private.usage import usage_lib
+from ray.air.config import CheckpointConfig, FailureConfig, RunConfig, ScalingConfig
+from ray.air.result import Result
+
+# Import this first so it can be used in other modules
+from ray.train._checkpoint import Checkpoint
+from ray.train._internal.data_config import DataConfig
+from ray.train._internal.session import get_checkpoint, get_dataset_shard, report
+from ray.train._internal.syncer import SyncConfig
+from ray.train.backend import BackendConfig
+from ray.train.constants import TRAIN_DATASET_KEY
+from ray.train.context import get_context
+from ray.train.trainer import TrainingIterator
+from ray.train.v2._internal.constants import is_v2_enabled
+
+if is_v2_enabled():
+    from ray.train.v2.api.callback import UserCallback  # noqa: F811
+    from ray.train.v2.api.config import (  # noqa: F811
+        FailureConfig,
+        RunConfig,
+        ScalingConfig,
+    )
+    from ray.train.v2.api.result import Result  # noqa: F811
+    from ray.train.v2.api.train_fn_utils import (  # noqa: F811
+        get_checkpoint,
+        get_context,
+        get_dataset_shard,
+        report,
+    )
+
+
+usage_lib.record_library_usage("train")
+
+Checkpoint.__module__ = "ray.train"
+
+__all__ = [
+    "get_checkpoint",
+    "get_context",
+    "get_dataset_shard",
+    "report",
+    "BackendConfig",
+    "Checkpoint",
+    "CheckpointConfig",
+    "DataConfig",
+    "FailureConfig",
+    "Result",
+    "RunConfig",
+    "ScalingConfig",
+    "SyncConfig",
+    "TrainingIterator",
+    "TRAIN_DATASET_KEY",
+]
+
+get_checkpoint.__module__ = "ray.train"
+get_context.__module__ = "ray.train"
+get_dataset_shard.__module__ = "ray.train"
+report.__module__ = "ray.train"
+BackendConfig.__module__ = "ray.train"
+Checkpoint.__module__ = "ray.train"
+CheckpointConfig.__module__ = "ray.train"
+DataConfig.__module__ = "ray.train"
+FailureConfig.__module__ = "ray.train"
+Result.__module__ = "ray.train"
+RunConfig.__module__ = "ray.train"
+ScalingConfig.__module__ = "ray.train"
+SyncConfig.__module__ = "ray.train"
+TrainingIterator.__module__ = "ray.train"
+
+
+if is_v2_enabled():
+    __all__.append("UserCallback")
+    UserCallback.__module__ = "ray.train"
+
+
+# DO NOT ADD ANYTHING AFTER THIS LINE.
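
For reference, the surface re-exported above (`report`, `get_checkpoint`, `Checkpoint`, the config classes) is what user training code imports. A minimal usage sketch, not part of this commit; it assumes the function is launched by a Ray Train trainer (e.g. `DataParallelTrainer`), since `report`/`get_checkpoint` only work inside a training worker:

    import os
    import tempfile

    from ray import train
    from ray.train import Checkpoint


    def train_fn(config):
        # Resume from the most recently reported checkpoint, if any.
        start_epoch = 0
        checkpoint = train.get_checkpoint()
        if checkpoint:
            with checkpoint.as_directory() as ckpt_dir:
                with open(os.path.join(ckpt_dir, "epoch.txt")) as f:
                    start_epoch = int(f.read()) + 1

        for epoch in range(start_epoch, config.get("epochs", 3)):
            # ... one epoch of training would go here ...
            with tempfile.TemporaryDirectory() as tmp_dir:
                with open(os.path.join(tmp_dir, "epoch.txt"), "w") as f:
                    f.write(str(epoch))
                # Attach a checkpoint directory to the reported metrics.
                train.report(
                    {"epoch": epoch},
                    checkpoint=Checkpoint.from_directory(tmp_dir),
                )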
.venv/lib/python3.11/site-packages/ray/train/_checkpoint.py ADDED
@@ -0,0 +1,424 @@
+import contextlib
+import glob
+import json
+import logging
+import os
+import platform
+import shutil
+import tempfile
+import traceback
+import uuid
+from pathlib import Path
+from typing import Any, Dict, Iterator, List, Optional, Union
+
+import pyarrow.fs
+
+from ray.air._internal.filelock import TempFileLock
+from ray.train._internal.storage import _download_from_fs_path, _exists_at_fs_path
+from ray.util.annotations import PublicAPI
+
+logger = logging.getLogger(__name__)
+
+# The filename of the file that stores user metadata set on the checkpoint.
+_METADATA_FILE_NAME = ".metadata.json"
+
+# The prefix of the temp checkpoint directory that `to_directory` downloads to
+# on the local filesystem.
+_CHECKPOINT_TEMP_DIR_PREFIX = "checkpoint_tmp_"
+
+
+class _CheckpointMetaClass(type):
+    def __getattr__(self, item):
+        try:
+            return super().__getattribute__(item)
+        except AttributeError as exc:
+            if item in {
+                "from_dict",
+                "to_dict",
+                "from_bytes",
+                "to_bytes",
+                "get_internal_representation",
+            }:
+                raise _get_migration_error(item) from exc
+            elif item in {
+                "from_uri",
+                "to_uri",
+                "uri",
+            }:
+                raise _get_uri_error(item) from exc
+            elif item in {"get_preprocessor", "set_preprocessor"}:
+                raise _get_preprocessor_error(item) from exc
+
+            raise exc
+
+
+@PublicAPI(stability="beta")
+class Checkpoint(metaclass=_CheckpointMetaClass):
+    """A reference to data persisted as a directory in local or remote storage.
+
+    Access the checkpoint contents locally using ``checkpoint.to_directory()``
+    or ``checkpoint.as_directory``.
+
+    Attributes
+    ----------
+    path: A path on the filesystem containing the checkpoint contents.
+    filesystem: PyArrow FileSystem that can be used to access data at the `path`.
+
+    See Also
+    --------
+    ray.train.report : Report a checkpoint during training (with Ray Train/Tune).
+    ray.train.get_checkpoint : Get the latest checkpoint during training
+        (for restoration).
+
+    :ref:`train-checkpointing`
+    :ref:`persistent-storage-guide`
+
+    Examples
+    --------
+
+    Creating a checkpoint using ``Checkpoint.from_directory``:
+
+    >>> from ray.train import Checkpoint
+    >>> checkpoint = Checkpoint.from_directory("/tmp/example_checkpoint_dir")
+    >>> checkpoint.filesystem  # doctest: +ELLIPSIS
+    <pyarrow._fs.LocalFileSystem object...
+    >>> checkpoint.path
+    '/tmp/example_checkpoint_dir'
+
+    Creating a checkpoint from a remote URI:
+
+    >>> checkpoint = Checkpoint("s3://bucket/path/to/checkpoint")
+    >>> checkpoint.filesystem  # doctest: +ELLIPSIS
+    <pyarrow._s3fs.S3FileSystem object...
+    >>> checkpoint.path
+    'bucket/path/to/checkpoint'
+
+    Creating a checkpoint with a custom filesystem:
+
+    >>> checkpoint = Checkpoint(
+    ...     path="bucket/path/to/checkpoint",
+    ...     filesystem=pyarrow.fs.S3FileSystem(),
+    ... )
+    >>> checkpoint.filesystem  # doctest: +ELLIPSIS
+    <pyarrow._s3fs.S3FileSystem object...
+    >>> checkpoint.path
+    'bucket/path/to/checkpoint'
+
+    Accessing a checkpoint's contents:
+
+    >>> import os  # doctest: +SKIP
+    >>> with checkpoint.as_directory() as local_checkpoint_dir:  # doctest: +SKIP
+    ...     print(os.listdir(local_checkpoint_dir))  # doctest: +SKIP
+    ['model.pt', 'optimizer.pt', 'misc.pt']
+    """
+
+    def __init__(
+        self,
+        path: Union[str, os.PathLike],
+        filesystem: Optional["pyarrow.fs.FileSystem"] = None,
+    ):
+        """Construct a Checkpoint.
+
+        Args:
+            path: A local path or remote URI containing the checkpoint data.
+                If a filesystem is provided, then this path must NOT be a URI.
+                It should be a path on the filesystem with the prefix already stripped.
+            filesystem: PyArrow FileSystem to use to access data at the path.
+                If not specified, this is inferred from the URI scheme.
+        """
+        self.path = str(path)
+        self.filesystem = filesystem
+
+        if path and not filesystem:
+            self.filesystem, self.path = pyarrow.fs.FileSystem.from_uri(path)
+
+        # This random UUID is used to create a temporary directory name on the
+        # local filesystem, which will be used for downloading checkpoint data.
+        # This ensures that if multiple processes download the same checkpoint object
+        # only one process performs the actual download while the others wait.
+        # This prevents duplicated download efforts and data.
+        # NOTE: Calling `to_directory` from multiple `Checkpoint` objects
+        # that point to the same (fs, path) will still download the data multiple times.
+        # This only ensures a canonical temp directory name for a single `Checkpoint`.
+        self._uuid = uuid.uuid4()
+
+    def __repr__(self):
+        return f"Checkpoint(filesystem={self.filesystem.type_name}, path={self.path})"
+
+    def get_metadata(self) -> Dict[str, Any]:
+        """Return the metadata dict stored with the checkpoint.
+
+        If no metadata is stored, an empty dict is returned.
+        """
+        metadata_path = Path(self.path, _METADATA_FILE_NAME).as_posix()
+        if not _exists_at_fs_path(self.filesystem, metadata_path):
+            return {}
+
+        with self.filesystem.open_input_file(metadata_path) as f:
+            return json.loads(f.readall().decode("utf-8"))
+
+    def set_metadata(self, metadata: Dict[str, Any]) -> None:
+        """Set the metadata stored with this checkpoint.
+
+        This will overwrite any existing metadata stored with this checkpoint.
+        """
+        metadata_path = Path(self.path, _METADATA_FILE_NAME).as_posix()
+        with self.filesystem.open_output_stream(metadata_path) as f:
+            f.write(json.dumps(metadata).encode("utf-8"))
+
+    def update_metadata(self, metadata: Dict[str, Any]) -> None:
+        """Update the metadata stored with this checkpoint.
+
+        This will update any existing metadata stored with this checkpoint.
+        """
+        existing_metadata = self.get_metadata()
+        existing_metadata.update(metadata)
+        self.set_metadata(existing_metadata)
+
+    @classmethod
+    def from_directory(cls, path: Union[str, os.PathLike]) -> "Checkpoint":
+        """Create checkpoint object from a local directory.
+
+        Args:
+            path: Local directory containing checkpoint data.
+
+        Returns:
+            A ray.train.Checkpoint object.
+        """
+        return cls(path, filesystem=pyarrow.fs.LocalFileSystem())
+
+    def to_directory(self, path: Optional[Union[str, os.PathLike]] = None) -> str:
+        """Write checkpoint data to a local directory.
+
+        *If multiple processes on the same node call this method simultaneously,*
+        only a single process will perform the download, while the others
+        wait for the download to finish. Once the download finishes, all processes
+        receive the same local directory to read from.
+
+        Args:
+            path: Target directory to download data to. If not specified,
+                this method will use a temporary directory.
+
+        Returns:
+            str: Directory containing checkpoint data.
+        """
+        user_provided_path = path is not None
+        local_path = (
+            path if user_provided_path else self._get_temporary_checkpoint_dir()
+        )
+        local_path = os.path.normpath(os.path.expanduser(str(local_path)))
+        os.makedirs(local_path, exist_ok=True)
+
+        try:
+            # Timeout 0 means there will be only one attempt to acquire
+            # the file lock. If it cannot be acquired, throw a TimeoutError
+            with TempFileLock(local_path, timeout=0):
+                _download_from_fs_path(
+                    fs=self.filesystem, fs_path=self.path, local_path=local_path
+                )
+        except TimeoutError:
+            # if the directory is already locked, then wait but do not do anything.
+            with TempFileLock(local_path, timeout=-1):
+                pass
+            if not os.path.exists(local_path):
+                raise RuntimeError(
+                    f"Checkpoint directory {local_path} does not exist, "
+                    "even though it should have been created by "
+                    "another process. Please raise an issue on GitHub: "
+                    "https://github.com/ray-project/ray/issues"
+                )
+
+        return local_path
+
+    @contextlib.contextmanager
+    def as_directory(self) -> Iterator[str]:
+        """Returns checkpoint contents in a local directory as a context.
+
+        This function makes checkpoint data available as a directory while avoiding
+        unnecessary copies and left-over temporary data.
+
+        *If the checkpoint points to a local directory*, this method just returns the
+        local directory path without making a copy, and nothing will be cleaned up
+        after exiting the context.
+
+        *If the checkpoint points to a remote directory*, this method will download the
+        checkpoint to a local temporary directory and return the path
+        to the temporary directory.
+
+        *If multiple processes on the same node call this method simultaneously,*
+        only a single process will perform the download, while the others
+        wait for the download to finish. Once the download finishes, all processes
+        receive the same local (temporary) directory to read from.
+
+        Once all processes have finished working with the checkpoint,
+        the temporary directory is cleaned up.
+
+        Users should treat the returned checkpoint directory as read-only and avoid
+        changing any data within it, as it may be deleted when exiting the context.
+
+        Example:
+
+        .. testcode::
+            :hide:
+
+            from pathlib import Path
+            import tempfile
+
+            from ray.train import Checkpoint
+
+            temp_dir = tempfile.mkdtemp()
+            (Path(temp_dir) / "example.txt").write_text("example checkpoint data")
+            checkpoint = Checkpoint.from_directory(temp_dir)
+
+        .. testcode::
+
+            with checkpoint.as_directory() as checkpoint_dir:
+                # Do some read-only processing of files within checkpoint_dir
+                pass
+
+            # At this point, if a temporary directory was created, it will have
+            # been deleted.
+
+        """
+        if isinstance(self.filesystem, pyarrow.fs.LocalFileSystem):
+            yield self.path
+        else:
+            del_lock_path = _get_del_lock_path(self._get_temporary_checkpoint_dir())
+            open(del_lock_path, "a").close()
+
+            temp_dir = self.to_directory()
+            try:
+                yield temp_dir
+            finally:
+                # Always cleanup the del lock after we're done with the directory.
+                # This avoids leaving a lock file behind in the case of an exception
+                # in the user code.
+                try:
+                    os.remove(del_lock_path)
+                except Exception:
+                    logger.warning(
+                        f"Could not remove {del_lock_path} deletion file lock. "
+                        f"Traceback:\n{traceback.format_exc()}"
+                    )
+
+                # If there are no more lock files, that means there are no more
+                # readers of this directory, and we can safely delete it.
+                # In the edge case (process crash before del lock file is removed),
+                # we do not remove the directory at all.
+                # Since it's in /tmp, this is not that big of a deal.
+                # check if any lock files are remaining
+                remaining_locks = _list_existing_del_locks(temp_dir)
+                if not remaining_locks:
+                    try:
+                        # Timeout 0 means there will be only one attempt to acquire
+                        # the file lock. If it cannot be acquired, a TimeoutError
+                        # will be thrown.
+                        with TempFileLock(temp_dir, timeout=0):
+                            shutil.rmtree(temp_dir, ignore_errors=True)
+                    except TimeoutError:
+                        pass
+
+    def _get_temporary_checkpoint_dir(self) -> str:
+        """Return the name for the temporary checkpoint dir that this checkpoint
+        will get downloaded to, if accessing via `to_directory` or `as_directory`.
+        """
+        tmp_dir_path = tempfile.gettempdir()
+        checkpoint_dir_name = _CHECKPOINT_TEMP_DIR_PREFIX + self._uuid.hex
+        if platform.system() == "Windows":
+            # Max path on Windows is 260 chars, -1 for joining \
+            # Also leave a little for the del lock
+            del_lock_name = _get_del_lock_path("")
+            checkpoint_dir_name = (
+                _CHECKPOINT_TEMP_DIR_PREFIX
+                + self._uuid.hex[
+                    -259
+                    + len(_CHECKPOINT_TEMP_DIR_PREFIX)
+                    + len(tmp_dir_path)
+                    + len(del_lock_name) :
+                ]
+            )
+            if not checkpoint_dir_name.startswith(_CHECKPOINT_TEMP_DIR_PREFIX):
+                raise RuntimeError(
+                    "Couldn't create checkpoint directory due to length "
+                    "constraints. Try specifying a shorter checkpoint path."
+                )
+        return Path(tmp_dir_path, checkpoint_dir_name).as_posix()
+
+    def __fspath__(self):
+        raise TypeError(
+            "You cannot use `Checkpoint` objects directly as paths. "
+            "Use `Checkpoint.to_directory()` or `Checkpoint.as_directory()` instead."
+        )
+
+
+def _get_del_lock_path(path: str, suffix: str = None) -> str:
+    """Get the path to the deletion lock file for a file/directory at `path`.
+
+    Example:
+
+    >>> _get_del_lock_path("/tmp/checkpoint_tmp")  # doctest: +ELLIPSIS
+    '/tmp/checkpoint_tmp.del_lock_...
+    >>> _get_del_lock_path("/tmp/checkpoint_tmp/")  # doctest: +ELLIPSIS
+    '/tmp/checkpoint_tmp.del_lock_...
+    >>> _get_del_lock_path("/tmp/checkpoint_tmp.txt")  # doctest: +ELLIPSIS
+    '/tmp/checkpoint_tmp.txt.del_lock_...
+
+    """
+    suffix = suffix if suffix is not None else str(os.getpid())
+    return f"{path.rstrip('/')}.del_lock_{suffix}"
+
+
+def _list_existing_del_locks(path: str) -> List[str]:
+    """List all the deletion lock files for a file/directory at `path`.
+
+    For example, if 2 checkpoints are being read via `as_directory`,
+    then this should return a list of 2 deletion lock files.
+    """
+    return list(glob.glob(f"{_get_del_lock_path(path, suffix='*')}"))
+
+
+def _get_migration_error(name: str):
+    return AttributeError(
+        f"The new `ray.train.Checkpoint` class does not support `{name}()`. "
+        f"Instead, only directories are supported.\n\n"
+        f"Example to store a dictionary in a checkpoint:\n\n"
+        f"import os, tempfile\n"
+        f"import ray.cloudpickle as pickle\n"
+        f"from ray import train\n"
+        f"from ray.train import Checkpoint\n\n"
+        f"with tempfile.TemporaryDirectory() as checkpoint_dir:\n"
+        f"    with open(os.path.join(checkpoint_dir, 'data.pkl'), 'wb') as fp:\n"
+        f"        pickle.dump({{'data': 'value'}}, fp)\n\n"
+        f"    checkpoint = Checkpoint.from_directory(checkpoint_dir)\n"
+        f"    train.report(..., checkpoint=checkpoint)\n\n"
+        f"Example to load a dictionary from a checkpoint:\n\n"
+        f"if train.get_checkpoint():\n"
+        f"    with train.get_checkpoint().as_directory() as checkpoint_dir:\n"
+        f"        with open(os.path.join(checkpoint_dir, 'data.pkl'), 'rb') as fp:\n"
+        f"            data = pickle.load(fp)"
+    )
+
+
+def _get_uri_error(name: str):
+    return AttributeError(
+        f"The new `ray.train.Checkpoint` class does not support `{name}()`. "
+        f"To create a checkpoint from remote storage, create a `Checkpoint` using its "
+        f"constructor instead of `from_directory`.\n"
+        f'Example: `Checkpoint(path="s3://a/b/c")`.\n'
+        f"Then, access the contents of the checkpoint with "
+        f"`checkpoint.as_directory()` / `checkpoint.to_directory()`.\n"
+        f"To upload data to remote storage, use e.g. `pyarrow.fs.FileSystem` "
+        f"or your client of choice."
+    )
+
+
+def _get_preprocessor_error(name: str):
+    return AttributeError(
+        f"The new `ray.train.Checkpoint` class does not support `{name}()`. "
+        f"To include preprocessor information in checkpoints, "
+        f"pass it as metadata in the <Framework>Trainer constructor.\n"
+        f"Example: `TorchTrainer(..., metadata={{...}})`.\n"
+        f"After training, access it in the checkpoint via `checkpoint.get_metadata()`. "
+        f"See here: https://docs.ray.io/en/master/train/user-guides/"
+        f"data-loading-preprocessing.html#preprocessing-structured-data"
+    )
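
A short sketch tying together the metadata and directory APIs defined above (illustrative only, not part of this commit; assumes `ray[train]` is installed):

    import os
    import tempfile

    from ray.train import Checkpoint

    # Wrap a local directory as a checkpoint (contents are arbitrary files).
    local_dir = tempfile.mkdtemp()
    with open(os.path.join(local_dir, "weights.bin"), "wb") as f:
        f.write(b"\x00")
    ckpt = Checkpoint.from_directory(local_dir)

    # Metadata is stored next to the data in `.metadata.json`.
    ckpt.set_metadata({"framework": "torch"})
    ckpt.update_metadata({"epoch": 5})  # merged into the existing dict
    assert ckpt.get_metadata() == {"framework": "torch", "epoch": 5}

    # For a local checkpoint, `as_directory()` yields `ckpt.path` directly;
    # for remote storage it would download into a shared temp directory first.
    with ckpt.as_directory() as checkpoint_dir:
        print(os.listdir(checkpoint_dir))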
.venv/lib/python3.11/site-packages/ray/train/backend.py ADDED
@@ -0,0 +1,59 @@
+import logging
+from contextlib import nullcontext
+from typing import TypeVar
+
+from ray.train._internal.utils import Singleton
+from ray.train._internal.worker_group import WorkerGroup
+from ray.util.annotations import DeveloperAPI
+from ray.widgets import make_table_html_repr
+
+EncodedData = TypeVar("EncodedData")
+
+logger = logging.getLogger(__name__)
+
+
+@DeveloperAPI
+class BackendConfig:
+    """Parent class for configurations of training backend."""
+
+    @property
+    def backend_cls(self):
+        return Backend
+
+    @property
+    def train_func_context(self):
+        return nullcontext
+
+    def _repr_html_(self) -> str:
+        return make_table_html_repr(obj=self, title=type(self).__name__)
+
+
+@DeveloperAPI
+class Backend(metaclass=Singleton):
+    """Singleton for distributed communication backend.
+
+    Attributes:
+        share_cuda_visible_devices: If True, each worker
+            process will have CUDA_VISIBLE_DEVICES set as the visible device
+            IDs of all workers on the same node for this training instance.
+            If False, each worker will have CUDA_VISIBLE_DEVICES set to the
+            device IDs allocated by Ray for that worker.
+    """
+
+    share_cuda_visible_devices: bool = False
+
+    def on_start(self, worker_group: WorkerGroup, backend_config: BackendConfig):
+        """Logic for starting this backend."""
+        pass
+
+    def on_shutdown(self, worker_group: WorkerGroup, backend_config: BackendConfig):
+        """Logic for shutting down the backend."""
+        pass
+
+    def on_training_start(
+        self, worker_group: WorkerGroup, backend_config: BackendConfig
+    ):
+        """Logic ran right before training is started.
+
+        Session API is available at this point."""
+        pass
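
`BackendConfig` and `Backend` above are the extension hooks for framework-specific worker setup. A minimal sketch of a hypothetical custom backend pair (the class names and the `port` field are made up for illustration, not part of this commit):

    from dataclasses import dataclass

    from ray.train.backend import Backend, BackendConfig


    class _MyCommBackend(Backend):
        def on_start(self, worker_group, backend_config):
            # e.g., have every worker initialize its communication library here
            pass

        def on_shutdown(self, worker_group, backend_config):
            # e.g., tear the communication library back down
            pass


    @dataclass
    class MyCommConfig(BackendConfig):
        port: int = 29500

        @property
        def backend_cls(self):
            # Tells Ray Train which Backend implementation to drive with this config.
            return _MyCommBackend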
.venv/lib/python3.11/site-packages/ray/train/base_trainer.py ADDED
@@ -0,0 +1,827 @@
1
+ import abc
2
+ import copy
3
+ import inspect
4
+ import json
5
+ import logging
6
+ import os
7
+ import warnings
8
+ from functools import partial
9
+ from pathlib import Path
10
+ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, Union
11
+
12
+ import pyarrow.fs
13
+
14
+ import ray
15
+ import ray.cloudpickle as pickle
16
+ from ray._private.dict import deep_update
17
+ from ray.air._internal import usage as air_usage
18
+ from ray.air._internal.config import ensure_only_allowed_dataclass_keys_updated
19
+ from ray.air._internal.usage import AirEntrypoint
20
+ from ray.air.config import RunConfig, ScalingConfig
21
+ from ray.air.result import Result
22
+ from ray.train import Checkpoint
23
+ from ray.train._internal.session import get_session
24
+ from ray.train._internal.storage import (
25
+ StorageContext,
26
+ _exists_at_fs_path,
27
+ get_fs_and_path,
28
+ )
29
+ from ray.util import PublicAPI
30
+ from ray.util.annotations import DeveloperAPI
31
+
32
+ if TYPE_CHECKING:
33
+ from ray.data import Dataset
34
+ from ray.tune import Trainable
35
+
36
+ _TRAINER_PKL = "trainer.pkl"
37
+
38
+ # A type representing either a ray.data.Dataset or a function that returns a
39
+ # ray.data.Dataset and accepts no arguments.
40
+ GenDataset = Union["Dataset", Callable[[], "Dataset"]]
41
+
42
+
43
+ logger = logging.getLogger(__name__)
44
+
45
+ PREPROCESSOR_DEPRECATION_MESSAGE = (
46
+ "The `preprocessor` argument to Trainers is deprecated as of Ray 2.7. "
47
+ "Instead, use the Preprocessor `fit` and `transform` APIs directly on the Ray "
48
+ "Dataset. For any state that needs to be saved to the trained checkpoint, pass it "
49
+ "in using the `metadata` argument of the `Trainer`. "
50
+ "For a full example, see "
51
+ "https://docs.ray.io/en/master/train/user-guides/data-loading-preprocessing.html#preprocessing-structured-data " # noqa:E501
52
+ )
53
+
54
+
55
+ @PublicAPI(stability="beta")
56
+ class TrainingFailedError(RuntimeError):
57
+ """An error indicating that training has failed."""
58
+
59
+ _RESTORE_MSG = (
60
+ "The Ray Train run failed. Please inspect the previous error messages for a "
61
+ "cause. After fixing the issue (assuming that the error is not caused by "
62
+ "your own application logic, but rather an error such as OOM), you can restart "
63
+ "the run from scratch or continue this run.\n"
64
+ "To continue this run, you can use: "
65
+ '`trainer = {trainer_cls_name}.restore("{path}")`.'
66
+ )
67
+
68
+ _FAILURE_CONFIG_MSG = (
69
+ "To start a new run that will retry on training failures, set "
70
+ "`train.RunConfig(failure_config=train.FailureConfig(max_failures))` "
71
+ "in the Trainer's `run_config` with `max_failures > 0`, or `max_failures = -1` "
72
+ "for unlimited retries."
73
+ )
74
+
75
+
76
+ def _train_coordinator_fn(
77
+ config: dict, trainer_cls: Type["BaseTrainer"], metadata: dict
78
+ ):
79
+ """This is the function that defines the logic of the Ray Train coordinator.
80
+ This is responsible for setting up a remote instance of the `trainer_cls`
81
+ (a different instance than the one calling `trainer.fit` on the driver!)
82
+ and running the training loop.
83
+ """
84
+ assert metadata is not None, metadata
85
+ # Propagate user metadata from the Trainer constructor.
86
+ get_session().metadata = metadata
87
+
88
+ # config already contains merged values.
89
+ # Instantiate new Trainer in Trainable.
90
+ trainer = trainer_cls(**config)
91
+
92
+ # Get the checkpoint from Tune and pass it to workers later on.
93
+ checkpoint = ray.train.get_checkpoint()
94
+ if checkpoint:
95
+ # Set `starting_checkpoint` for auto-recovery fault-tolerance
96
+ # as well as manual restoration.
97
+ trainer.starting_checkpoint = checkpoint
98
+ # else: Train will restore from the user-provided
99
+ # `resume_from_checkpoint` == `starting_checkpoint`.
100
+
101
+ # Evaluate datasets if they are wrapped in a factory.
102
+ trainer.datasets = {
103
+ k: d() if callable(d) else d for k, d in trainer.datasets.items()
104
+ }
105
+
106
+ trainer.setup()
107
+ trainer.training_loop()
108
+
109
+
110
+ @DeveloperAPI
111
+ class BaseTrainer(abc.ABC):
112
+ """Defines interface for distributed training on Ray.
113
+
114
+ Note: The base ``BaseTrainer`` class cannot be instantiated directly. Only
115
+ one of its subclasses can be used.
116
+
117
+ Note to developers: If a new trainer is added, please update
118
+ `air/_internal/usage.py`.
119
+
120
+ **How does a trainer work?**
121
+
122
+ - First, initialize the Trainer. The initialization runs locally,
123
+ so heavyweight setup should not be done in ``__init__``.
124
+ - Then, when you call ``trainer.fit()``, the Trainer is serialized
125
+ and copied to a remote Ray actor. The following methods are then
126
+ called in sequence on the remote actor.
127
+ - ``trainer.setup()``: Any heavyweight Trainer setup should be
128
+ specified here.
129
+ - ``trainer.training_loop()``: Executes the main training logic.
130
+ - Calling ``trainer.fit()`` will return a ``ray.result.Result``
131
+ object where you can access metrics from your training run, as well
132
+ as any checkpoints that may have been saved.
133
+
134
+ **How do I create a new Trainer?**
135
+
136
+ Subclass ``ray.train.trainer.BaseTrainer``, and override the ``training_loop``
137
+ method, and optionally ``setup``.
138
+
139
+ .. testcode::
140
+
141
+ import torch
142
+
143
+ from ray.train.trainer import BaseTrainer
144
+ from ray import train, tune
145
+
146
+
147
+ class MyPytorchTrainer(BaseTrainer):
148
+ def setup(self):
149
+ self.model = torch.nn.Linear(1, 1)
150
+ self.optimizer = torch.optim.SGD(
151
+ self.model.parameters(), lr=0.1)
152
+
153
+ def training_loop(self):
154
+ # You can access any Trainer attributes directly in this method.
155
+ # self.datasets["train"] has already been
156
+ dataset = self.datasets["train"]
157
+
158
+ torch_ds = dataset.iter_torch_batches(dtypes=torch.float)
159
+ loss_fn = torch.nn.MSELoss()
160
+
161
+ for epoch_idx in range(10):
162
+ loss = 0
163
+ num_batches = 0
164
+ torch_ds = dataset.iter_torch_batches(
165
+ dtypes=torch.float, batch_size=2
166
+ )
167
+ for batch in torch_ds:
168
+ X = torch.unsqueeze(batch["x"], 1)
169
+ y = torch.unsqueeze(batch["y"], 1)
170
+ # Compute prediction error
171
+ pred = self.model(X)
172
+ batch_loss = loss_fn(pred, y)
173
+
174
+ # Backpropagation
175
+ self.optimizer.zero_grad()
176
+ batch_loss.backward()
177
+ self.optimizer.step()
178
+
179
+ loss += batch_loss.item()
180
+ num_batches += 1
181
+ loss /= num_batches
182
+
183
+ # Use Tune functions to report intermediate
184
+ # results.
185
+ train.report({"loss": loss, "epoch": epoch_idx})
186
+
187
+
188
+ # Initialize the Trainer, and call Trainer.fit()
189
+ import ray
190
+ train_dataset = ray.data.from_items(
191
+ [{"x": i, "y": i} for i in range(10)])
192
+ my_trainer = MyPytorchTrainer(datasets={"train": train_dataset})
193
+ result = my_trainer.fit()
194
+
195
+ .. testoutput::
196
+ :hide:
197
+
198
+ ...
199
+
200
+ Args:
201
+ scaling_config: Configuration for how to scale training.
202
+ run_config: Configuration for the execution of the training run.
203
+ datasets: Any Datasets to use for training. Use the key "train"
204
+ to denote which dataset is the training dataset.
205
+ metadata: Dict that should be made available via
206
+ `train.get_context().get_metadata()` and in `checkpoint.get_metadata()`
207
+ for checkpoints saved from this Trainer. Must be JSON-serializable.
208
+ resume_from_checkpoint: A checkpoint to resume training from.
209
+ """
210
+
211
+ _scaling_config_allowed_keys: List[str] = [
212
+ "trainer_resources",
213
+ ]
214
+ _handles_checkpoint_freq: bool = False
215
+ _handles_checkpoint_at_end: bool = False
216
+
217
+ # fields to propagate to Tuner param_space.
218
+ # See `BaseTrainer._extract_fields_for_tuner_param_space` for more details.
219
+ _fields_for_tuner_param_space = []
220
+
221
+ def __init__(
222
+ self,
223
+ *,
224
+ scaling_config: Optional[ScalingConfig] = None,
225
+ run_config: Optional[RunConfig] = None,
226
+ datasets: Optional[Dict[str, GenDataset]] = None,
227
+ metadata: Optional[Dict[str, Any]] = None,
228
+ resume_from_checkpoint: Optional[Checkpoint] = None,
229
+ ):
230
+ self.scaling_config = (
231
+ scaling_config if scaling_config is not None else ScalingConfig()
232
+ )
233
+ self.run_config = (
234
+ copy.copy(run_config) if run_config is not None else RunConfig()
235
+ )
236
+ self.metadata = metadata
237
+ self.datasets = datasets if datasets is not None else {}
238
+ self.starting_checkpoint = resume_from_checkpoint
239
+
240
+ # These attributes should only be set through `BaseTrainer.restore`
241
+ self._restore_path = None
242
+ self._restore_storage_filesystem = None
243
+
244
+ self._validate_attributes()
245
+
246
+ air_usage.tag_air_trainer(self)
247
+
248
+ @PublicAPI(stability="alpha")
249
+ @classmethod
250
+ def restore(
251
+ cls: Type["BaseTrainer"],
252
+ path: Union[str, os.PathLike],
253
+ storage_filesystem: Optional[pyarrow.fs.FileSystem] = None,
254
+ datasets: Optional[Dict[str, GenDataset]] = None,
255
+ scaling_config: Optional[ScalingConfig] = None,
256
+ **kwargs,
257
+ ) -> "BaseTrainer":
258
+ """Restores a Train experiment from a previously interrupted/failed run.
259
+
260
+ Restore should be used for experiment-level fault tolerance in the event
261
+ that the head node crashes (e.g., OOM or some other runtime error) or the
262
+ entire cluster goes down (e.g., network error affecting all nodes).
263
+
264
+ A run that has already completed successfully will not be resumed from this API.
265
+ To continue training from a successful run, launch a new run with the
266
+ ``<Framework>Trainer(resume_from_checkpoint)`` API instead, passing in a
267
+ checkpoint from the previous run to start with.
268
+
269
+ .. note::
270
+
271
+ Restoring an experiment from a path that's pointing to a *different*
272
+ location than the original experiment path is supported. However, Ray Train
273
+ assumes that the full experiment directory is available
274
+ (including checkpoints) so that it's possible to resume trials from their
275
+ latest state.
276
+
277
+ For example, if the original experiment path was run locally, then the
278
+ results are uploaded to cloud storage, Ray Train expects the full contents
279
+ to be available in cloud storage if attempting to resume
280
+ via ``<Framework>Trainer.restore("s3://...")``. The restored run will
281
+ continue writing results to the same cloud storage location.
282
+
283
+ The following example can be paired with implementing job retry using
284
+ :ref:`Ray Jobs <jobs-overview>` to produce a Train experiment that will
285
+ attempt to resume on both experiment-level and trial-level failures:
286
+
287
+ .. testcode::
288
+
289
+ import os
290
+ import ray
291
+ from ray import train
292
+ from ray.train.trainer import BaseTrainer
293
+
294
+ experiment_name = "unique_experiment_name"
295
+ storage_path = os.path.expanduser("~/ray_results")
296
+ experiment_dir = os.path.join(storage_path, experiment_name)
297
+
298
+ # Define some dummy inputs for demonstration purposes
299
+ datasets = {"train": ray.data.from_items([{"a": i} for i in range(10)])}
300
+
301
+ class CustomTrainer(BaseTrainer):
302
+ def training_loop(self):
303
+ pass
304
+
305
+ if CustomTrainer.can_restore(experiment_dir):
306
+ trainer = CustomTrainer.restore(
307
+ experiment_dir, datasets=datasets
308
+ )
309
+ else:
310
+ trainer = CustomTrainer(
311
+ datasets=datasets,
312
+ run_config=train.RunConfig(
313
+ name=experiment_name,
314
+ storage_path=storage_path,
315
+ # Tip: You can also enable retries on failure for
316
+ # worker-level fault tolerance
317
+ failure_config=train.FailureConfig(max_failures=3),
318
+ ),
319
+ )
320
+
321
+ result = trainer.fit()
322
+
323
+ .. testoutput::
324
+ :hide:
325
+
326
+ ...
327
+
328
+ Args:
329
+ path: The path to the experiment directory of the training run to restore.
330
+ This can be a local path or a remote URI if the experiment was
331
+ uploaded to the cloud.
332
+ storage_filesystem: Custom ``pyarrow.fs.FileSystem``
333
+ corresponding to the ``path``. This may be necessary if the original
334
+ experiment passed in a custom filesystem.
335
+ datasets: Re-specified datasets used in the original training run.
336
+ This must include all the datasets that were passed in the
337
+ original trainer constructor.
338
+ scaling_config: Optionally re-specified scaling config. This can be
339
+ modified to be different from the original spec.
340
+ **kwargs: Other optionally re-specified arguments, passed in by subclasses.
341
+
342
+ Raises:
343
+ ValueError: If all datasets were not re-supplied on restore.
344
+
345
+ Returns:
346
+ BaseTrainer: A restored instance of the class that is calling this method.
347
+ """
348
+ if not cls.can_restore(path, storage_filesystem):
349
+ raise ValueError(
350
+ f"Invalid restore path: {path}. Make sure that this path exists and "
351
+ "is the experiment directory that results from a call to "
352
+ "`trainer.fit()`."
353
+ )
354
+ fs, fs_path = get_fs_and_path(path, storage_filesystem)
355
+ trainer_pkl_path = Path(fs_path, _TRAINER_PKL).as_posix()
356
+ with fs.open_input_file(trainer_pkl_path) as f:
357
+ trainer_cls, param_dict = pickle.loads(f.readall())
358
+
359
+ if trainer_cls is not cls:
360
+ warnings.warn(
361
+ f"Invalid trainer type. You are attempting to restore a trainer of type"
362
+ f" {trainer_cls} with `{cls.__name__}.restore`, "
363
+ "which will most likely fail. "
364
+ f"Use `{trainer_cls.__name__}.restore` instead."
365
+ )
366
+
367
+ original_datasets = param_dict.pop("datasets", {})
368
+ if original_datasets and not datasets:
369
+ raise ValueError(
370
+ "The following datasets need to be provided again on restore: "
371
+ f"{list(original_datasets.keys())}\n"
372
+ f"Use {cls.__name__}.restore(..., datasets=datasets) "
373
+ "with the datasets that were provided to the original trainer."
374
+ )
375
+ datasets = datasets or {}
376
+ if set(original_datasets) != set(datasets):
377
+ raise ValueError(
378
+ "The provided datasets don't match the original dataset keys.\n"
379
+ f" Expected datasets for the keys: {list(original_datasets.keys())}\n"
380
+ f" Actual datasets provided: {list(datasets.keys())}"
381
+ )
382
+ param_dict["datasets"] = datasets
383
+
384
+ if scaling_config:
385
+ param_dict["scaling_config"] = scaling_config
386
+
387
+ for param_name, val in kwargs.items():
388
+ # Overwrite the old value if something is passed into restore
389
+ if val is not None:
390
+ param_dict[param_name] = val
391
+
392
+ try:
393
+ trainer = cls(**param_dict)
394
+ except Exception as e:
395
+ raise ValueError(
396
+ "Trainer restoration failed (see above for the stack trace). "
397
+ "Make sure that you use the right trainer class to restore: "
398
+ f"`{cls.__name__}.restore`\n"
399
+ ) from e
400
+ trainer._restore_path = path
401
+ trainer._restore_storage_filesystem = storage_filesystem
402
+ return trainer
403
+
404
+ @PublicAPI(stability="alpha")
405
+ @classmethod
406
+ def can_restore(
407
+ cls: Type["BaseTrainer"],
408
+ path: Union[str, os.PathLike],
409
+ storage_filesystem: Optional[pyarrow.fs.FileSystem] = None,
410
+ ) -> bool:
411
+ """Checks whether a given directory contains a restorable Train experiment.
412
+
413
+ Args:
414
+ path: The path to the experiment directory of the Train experiment.
415
+ This can be either a local directory (e.g., ~/ray_results/exp_name)
416
+ or a remote URI (e.g., s3://bucket/exp_name).
417
+
418
+ Returns:
419
+ bool: Whether this path exists and contains the trainer state to resume from
420
+ """
421
+ fs, fs_path = get_fs_and_path(path, storage_filesystem)
422
+ trainer_pkl_path = Path(fs_path, _TRAINER_PKL).as_posix()
423
+ return _exists_at_fs_path(fs, trainer_pkl_path)
424
+
425
+ def __repr__(self):
426
+ # A dictionary that maps parameters to their default values.
427
+ default_values: Dict[str, Any] = {
428
+ "scaling_config": ScalingConfig(),
429
+ "run_config": RunConfig(),
430
+ "datasets": {},
431
+ "starting_checkpoint": None,
432
+ }
433
+
434
+ non_default_arguments = []
435
+ for parameter, default_value in default_values.items():
436
+ value = getattr(self, parameter)
437
+ if value != default_value:
438
+ non_default_arguments.append(f"{parameter}={value!r}")
439
+
440
+ if non_default_arguments:
441
+ return f"<{self.__class__.__name__} {' '.join(non_default_arguments)}>"
442
+
443
+ return f"<{self.__class__.__name__}>"
444
+
445
+ def __new__(cls, *args, **kwargs):
446
+ # Store the init args as attributes so this can be merged with Tune hparams.
447
+ trainer = super(BaseTrainer, cls).__new__(cls)
448
+ parameters = inspect.signature(cls.__init__).parameters
449
+ parameters = list(parameters.keys())
450
+ # Remove self.
451
+ parameters = parameters[1:]
452
+ arg_dict = dict(zip(parameters, args))
453
+ trainer._param_dict = {**arg_dict, **kwargs}
454
+ return trainer
455
+
456
+ def _validate_attributes(self):
457
+ """Called on __init()__ to validate trainer attributes."""
458
+ # Run config
459
+ if not isinstance(self.run_config, RunConfig):
460
+ raise ValueError(
461
+ f"`run_config` should be an instance of `ray.train.RunConfig`, "
462
+ f"found {type(self.run_config)} with value `{self.run_config}`."
463
+ )
464
+ # Scaling config
465
+ if not isinstance(self.scaling_config, ScalingConfig):
466
+ raise ValueError(
467
+ "`scaling_config` should be an instance of `ScalingConfig`, "
468
+ f"found {type(self.scaling_config)} with value `{self.scaling_config}`."
469
+ )
470
+ # Datasets
471
+ if not isinstance(self.datasets, dict):
472
+ raise ValueError(
473
+ f"`datasets` should be a dict mapping from a string to "
474
+ f"`ray.data.Dataset` objects, "
475
+ f"found {type(self.datasets)} with value `{self.datasets}`."
476
+ )
477
+ else:
478
+ for key, dataset in self.datasets.items():
479
+ if not isinstance(dataset, ray.data.Dataset) and not callable(dataset):
480
+ raise ValueError(
481
+ f"The Dataset under '{key}' key is not a "
482
+ "`ray.data.Dataset`. "
483
+ f"Received {dataset} instead."
484
+ )
485
+ # Metadata.
486
+ self.metadata = self.metadata or {}
487
+ if not isinstance(self.metadata, dict):
488
+ raise TypeError(
489
+ f"The provided metadata must be a dict, was {type(self.metadata)}."
490
+ )
491
+ try:
492
+ self.metadata = json.loads(json.dumps(self.metadata))
493
+ except Exception as e:
494
+ raise ValueError(
495
+ "The provided metadata must be JSON-serializable: "
496
+ f"{self.metadata}: {e}"
497
+ )
498
+
499
+ if self.starting_checkpoint is not None and not isinstance(
500
+ self.starting_checkpoint, Checkpoint
501
+ ):
502
+ raise ValueError(
503
+ f"`resume_from_checkpoint` should be an instance of "
504
+ f"`ray.train.Checkpoint`, found {type(self.starting_checkpoint)} "
505
+ f"with value `{self.starting_checkpoint}`."
506
+ )
507
+
508
+ @classmethod
509
+ def _validate_scaling_config(cls, scaling_config: ScalingConfig) -> ScalingConfig:
510
+ """Returns scaling config dataclass after validating updated keys."""
511
+ ensure_only_allowed_dataclass_keys_updated(
512
+ dataclass=scaling_config,
513
+ allowed_keys=cls._scaling_config_allowed_keys,
514
+ )
515
+ return scaling_config
516
+
517
+ def setup(self) -> None:
518
+ """Called during fit() to perform initial setup on the Trainer.
519
+
520
+ .. note:: This method is run on a remote process.
521
+
522
+ This method will not be called on the driver, so any expensive setup
523
+ operations should be placed here and not in ``__init__``.
524
+
525
+ This method is called prior to ``preprocess_datasets`` and
526
+ ``training_loop``.
527
+ """
528
+ pass
529
+
530
+ def preprocess_datasets(self) -> None:
531
+ """Deprecated."""
532
+ raise DeprecationWarning(
533
+ "`preprocess_datasets` is no longer used, since preprocessors "
534
+ f"are no longer accepted by Trainers.\n{PREPROCESSOR_DEPRECATION_MESSAGE}"
535
+ )
536
+
537
+ @abc.abstractmethod
538
+ def training_loop(self) -> None:
539
+ """Loop called by fit() to run training and report results to Tune.
540
+
541
+ .. note:: This method runs on a remote process.
542
+
543
+ ``self.datasets`` have already been evaluated if they were wrapped in a factory.
544
+
545
+ You can use the :ref:`Ray Train utilities <train-loop-api>`
546
+ (:func:`train.report() <ray.train.report>` and
547
+ :func:`train.get_checkpoint() <ray.train.get_checkpoint>`) inside
548
+ this training loop.
549
+
550
+ Example:
551
+
552
+ .. testcode::
553
+
554
+ from ray.train.trainer import BaseTrainer
555
+ from ray import train
556
+
557
+ class MyTrainer(BaseTrainer):
558
+ def training_loop(self):
559
+ for epoch_idx in range(5):
560
+ ...
561
+ train.report({"epoch": epoch_idx})
562
+
563
+ """
564
+ raise NotImplementedError
565
+
566
+ @PublicAPI(stability="beta")
567
+ def fit(self) -> Result:
568
+ """Runs training.
569
+
570
+ Returns:
571
+ A Result object containing the training result.
572
+
573
+ Raises:
574
+ TrainingFailedError: If any failures during the execution
575
+ of ``self.as_trainable()``, or during the Tune execution loop.
576
+ """
577
+ from ray.tune import ResumeConfig, TuneError
578
+ from ray.tune.tuner import Tuner
579
+
580
+ trainable = self.as_trainable()
581
+ param_space = self._extract_fields_for_tuner_param_space()
582
+
583
+ self.run_config.name = (
584
+ self.run_config.name or StorageContext.get_experiment_dir_name(trainable)
585
+ )
586
+ # The storage context here is only used to access the resolved
587
+ # storage fs and experiment path, in order to avoid duplicating that logic.
588
+ # This is NOT the storage context object that gets passed to remote workers.
589
+ storage = StorageContext(
590
+ storage_path=self.run_config.storage_path,
591
+ experiment_dir_name=self.run_config.name,
592
+ storage_filesystem=self.run_config.storage_filesystem,
593
+ )
594
+
595
+ if self._restore_path:
596
+ tuner = Tuner.restore(
597
+ path=self._restore_path,
598
+ trainable=trainable,
599
+ param_space=param_space,
600
+ _resume_config=ResumeConfig(
601
+ finished=ResumeConfig.ResumeType.RESUME,
602
+ unfinished=ResumeConfig.ResumeType.RESUME,
603
+ errored=ResumeConfig.ResumeType.RESUME,
604
+ ),
605
+ storage_filesystem=self._restore_storage_filesystem,
606
+ )
607
+ else:
608
+ tuner = Tuner(
609
+ trainable=trainable,
610
+ param_space=param_space,
611
+ run_config=self.run_config,
612
+ _entrypoint=AirEntrypoint.TRAINER,
613
+ )
614
+
615
+ self._save(storage.storage_filesystem, storage.experiment_fs_path)
616
+
617
+ restore_msg = TrainingFailedError._RESTORE_MSG.format(
618
+ trainer_cls_name=self.__class__.__name__,
619
+ path=str(storage.experiment_fs_path),
620
+ )
621
+
622
+ try:
623
+ result_grid = tuner.fit()
624
+ except TuneError as e:
625
+ # Catch any `TuneError`s raised by the `Tuner.fit` call.
626
+ # Unwrap the `TuneError` if needed.
627
+ parent_error = e.__cause__ or e
628
+
629
+ # Raise it to the user as a `TrainingFailedError` with a message to restore.
630
+ raise TrainingFailedError(restore_msg) from parent_error
631
+ # Other exceptions get passed through directly (ex: on `fail_fast='raise'`)
632
+
633
+ assert len(result_grid) == 1
634
+ result = result_grid[0]
635
+ if result.error:
636
+ # Raise trainable errors to the user with a message to restore
637
+ # or configure `FailureConfig` in a new run.
638
+ raise TrainingFailedError(
639
+ "\n".join([restore_msg, TrainingFailedError._FAILURE_CONFIG_MSG])
640
+ ) from result.error
641
+ return result
642
+
643
+ def _save(self, fs: pyarrow.fs.FileSystem, experiment_path: str):
644
+ """Saves the current trainer's class along with the `param_dict` of
645
+ parameters passed to this trainer's constructor.
646
+
647
+ This is used to recreate the trainer on restore.
648
+ Unless a parameter is re-specified during restoration (only a subset
649
+ of parameters can be passed in again), that parameter will be loaded
650
+ from the saved copy.
651
+
652
+ Datasets should not be saved as part of the state. Instead, we save the
653
+ keys and replace the dataset values with dummy functions that will
654
+ raise an error if invoked. The error only serves as a guardrail for
655
+ misuse (e.g., manually unpickling and constructing the Trainer again)
656
+ and is not typically surfaced, since datasets must be re-specified
657
+ upon restoration.
658
+ """
659
+ param_dict = self._param_dict.copy()
660
+ datasets = param_dict.pop("datasets", {})
661
+
662
+ def raise_fn():
663
+ raise RuntimeError
664
+
665
+ if datasets:
666
+ param_dict["datasets"] = {
667
+ dataset_name: raise_fn for dataset_name in datasets
668
+ }
669
+
670
+ cls_and_param_dict = (self.__class__, param_dict)
671
+
672
+ fs.create_dir(experiment_path)
673
+ with fs.open_output_stream(Path(experiment_path, _TRAINER_PKL).as_posix()) as f:
674
+ f.write(pickle.dumps(cls_and_param_dict))
675
+
676
+ def _extract_fields_for_tuner_param_space(self) -> Dict:
677
+ """Extracts fields to be included in `Tuner.param_space`.
678
+
679
+ This is needed to leverage the full logging/integration offerings from Tune.
680
+ For example, `param_space` is logged automatically to wandb integration.
681
+
682
+ Currently only done for `train_loop_config`.
683
+
684
+ Returns:
685
+ A dictionary that should be passed to Tuner.param_space.
686
+ """
687
+ result = {}
688
+ for key in self._fields_for_tuner_param_space:
689
+ if key in self._param_dict.keys():
690
+ result[key] = copy.deepcopy(self._param_dict[key])
691
+ return result
692
+
693
+ def _generate_trainable_cls(self) -> Type["Trainable"]:
694
+ """Generates the base Trainable class.
695
+
696
+ Returns:
697
+ A Trainable class to use for training.
698
+ """
699
+
700
+ from ray.tune.execution.placement_groups import PlacementGroupFactory
701
+ from ray.tune.trainable import wrap_function
702
+
703
+ trainer_cls = self.__class__
704
+ scaling_config = self.scaling_config
705
+ metadata = self.metadata
706
+
707
+ train_coordinator_fn = partial(
708
+ _train_coordinator_fn, trainer_cls=trainer_cls, metadata=metadata
709
+ )
710
+ # Change the name of the training function to match the name of the Trainer
711
+ # class. This will mean the Tune trial name will match the name of Trainer on
712
+ # stdout messages and the results directory.
713
+ train_coordinator_fn.__name__ = trainer_cls.__name__
714
+
715
+ trainable_cls = wrap_function(train_coordinator_fn)
716
+ has_base_dataset = bool(self.datasets)
717
+ if has_base_dataset:
718
+ from ray.data.context import DataContext
719
+
720
+ dataset_context = DataContext.get_current()
721
+ else:
722
+ dataset_context = None
723
+
724
+ class TrainTrainable(trainable_cls):
725
+ """Adds default resources to the Trainable."""
726
+
727
+ _handles_checkpoint_freq = trainer_cls._handles_checkpoint_freq
728
+ _handles_checkpoint_at_end = trainer_cls._handles_checkpoint_at_end
729
+
730
+ @classmethod
731
+ def has_base_dataset(cls) -> bool:
732
+ """Whether a dataset is provided through the Trainer."""
733
+ return has_base_dataset
734
+
735
+ @classmethod
736
+ def base_scaling_config(cls) -> ScalingConfig:
737
+ """Returns the unchanged scaling config provided through the Trainer."""
738
+ return scaling_config
739
+
740
+ def setup(self, config, **kwargs):
741
+ base_config = dict(kwargs)
742
+ # Merge Tuner param space hyperparameters in `config` into the
743
+ # base config passed to the Trainer constructor, which is `base_config`.
744
+ # `base_config` is pulled from the object store from the usage of
745
+ # tune.with_parameters in `BaseTrainer.as_trainable`.
746
+
747
+ # run_config is not a tunable hyperparameter so it does not need to be
748
+ # merged.
749
+ run_config = base_config.pop("run_config", None)
750
+ self._merged_config = deep_update(
751
+ base_config, self.config, new_keys_allowed=True
752
+ )
753
+ self._merged_config["run_config"] = run_config
754
+ merged_scaling_config = self._merged_config.get(
755
+ "scaling_config", ScalingConfig()
756
+ )
757
+ if isinstance(merged_scaling_config, dict):
758
+ merged_scaling_config = ScalingConfig(**merged_scaling_config)
759
+ self._merged_config[
760
+ "scaling_config"
761
+ ] = self._reconcile_scaling_config_with_trial_resources(
762
+ merged_scaling_config
763
+ )
764
+ if self.has_base_dataset():
765
+ # Set the DataContext on the Trainer actor to the DataContext
766
+ # specified on the driver.
767
+ DataContext._set_current(dataset_context)
768
+ super(TrainTrainable, self).setup(config)
769
+
770
+ def _reconcile_scaling_config_with_trial_resources(
771
+ self, scaling_config: ScalingConfig
772
+ ) -> ScalingConfig:
773
+ """
774
+ ResourceChangingScheduler workaround.
775
+
776
+ Ensures that the scaling config matches trial resources.
777
+
778
+ This should be replaced with RCS returning a ScalingConfig
779
+ in the future.
780
+ """
781
+
782
+ trial_resources = self.trial_resources
783
+ # This will be false if the resources are default
784
+ if not isinstance(trial_resources, PlacementGroupFactory):
785
+ return scaling_config
786
+
787
+ # Ignore ResourceChangingScheduler workaround when resource bundles
788
+ # are unchanged
789
+ if self.trial_resources == scaling_config.as_placement_group_factory():
790
+ return scaling_config
791
+
792
+ trainer_cls._validate_scaling_config(scaling_config)
793
+
794
+ return ScalingConfig.from_placement_group_factory(trial_resources)
795
+
796
+ def _trainable_func(self, config):
797
+ # We ignore the config passed by Tune and instead use the merged
798
+ # config which includes the initial Trainer args.
799
+ super()._trainable_func(self._merged_config)
800
+
801
+ @classmethod
802
+ def default_resource_request(cls, config):
803
+ # `config["scaling_config"] is a dataclass when passed via the
804
+ # `scaling_config` argument in `Trainer` and is a dict when passed
805
+ # via the `scaling_config` key of `param_spec`.
806
+
807
+ # Conversion logic must be duplicated in `TrainTrainable.__init__`
808
+ # because this is a class method.
809
+ updated_scaling_config = config.get("scaling_config", scaling_config)
810
+ if isinstance(updated_scaling_config, dict):
811
+ updated_scaling_config = ScalingConfig(**updated_scaling_config)
812
+ validated_scaling_config = trainer_cls._validate_scaling_config(
813
+ updated_scaling_config
814
+ )
815
+ return validated_scaling_config.as_placement_group_factory()
816
+
817
+ return TrainTrainable
818
+
819
+ def as_trainable(self) -> Type["Trainable"]:
820
+ """Converts self to a ``tune.Trainable`` class."""
821
+ from ray import tune
822
+
823
+ base_config = self._param_dict
824
+ trainable_cls = self._generate_trainable_cls()
825
+
826
+ # Wrap with `tune.with_parameters` to handle very large values in base_config
827
+ return tune.with_parameters(trainable_cls, **base_config)
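For context, `as_trainable()` is what `ray.tune.Tuner` calls when it is handed a Trainer: the hyperparameters supplied via `param_space` are deep-merged into the Trainer's constructor arguments by `TrainTrainable.setup()` above. A minimal sketch of that flow, assuming a Torch trainer and an illustrative search space (not taken from this diff):

from ray import train, tune
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer


def train_loop_per_worker(config):
    # `config` is the (possibly tuned) `train_loop_config`.
    train.report({"lr_used": config["lr"]})


trainer = TorchTrainer(
    train_loop_per_worker,
    train_loop_config={"lr": 1e-3},
    scaling_config=ScalingConfig(num_workers=2),
)

# Tuner calls `trainer.as_trainable()` internally and merges `param_space`
# into the trainer's arguments inside `TrainTrainable.setup()`.
tuner = tune.Tuner(
    trainer,
    param_space={"train_loop_config": {"lr": tune.loguniform(1e-4, 1e-1)}},
)
results = tuner.fit()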
.venv/lib/python3.11/site-packages/ray/train/constants.py ADDED
@@ -0,0 +1,118 @@
1
+ from pathlib import Path
2
+
3
+ import ray
4
+ from ray._private.ray_constants import env_bool
5
+ from ray.air.constants import ( # noqa: F401
6
+ COPY_DIRECTORY_CHECKPOINTS_INSTEAD_OF_MOVING_ENV,
7
+ EVALUATION_DATASET_KEY,
8
+ MODEL_KEY,
9
+ PREPROCESSOR_KEY,
10
+ TRAIN_DATASET_KEY,
11
+ )
12
+
13
+
14
+ def _get_ray_train_session_dir() -> str:
15
+ assert ray.is_initialized(), "Ray must be initialized to get the session dir."
16
+ return Path(
17
+ ray._private.worker._global_node.get_session_dir_path(), "artifacts"
18
+ ).as_posix()
19
+
20
+
21
+ DEFAULT_STORAGE_PATH = Path("~/ray_results").expanduser().as_posix()
22
+
23
+ # Autofilled ray.train.report() metrics. Keys should be consistent with Tune.
24
+ CHECKPOINT_DIR_NAME = "checkpoint_dir_name"
25
+ TIME_TOTAL_S = "_time_total_s"
26
+ WORKER_HOSTNAME = "_hostname"
27
+ WORKER_NODE_IP = "_node_ip"
28
+ WORKER_PID = "_pid"
29
+
30
+ # Will not be reported unless ENABLE_DETAILED_AUTOFILLED_METRICS_ENV
31
+ # env var is set to a non-zero value
32
+ DETAILED_AUTOFILLED_KEYS = {WORKER_HOSTNAME, WORKER_NODE_IP, WORKER_PID, TIME_TOTAL_S}
33
+
34
+ # Default filename for JSON logger
35
+ RESULT_FILE_JSON = "results.json"
36
+
37
+ # The name of the subdirectory inside the trainer run_dir to store checkpoints.
38
+ TRAIN_CHECKPOINT_SUBDIR = "checkpoints"
39
+
40
+ # The key to use to specify the checkpoint id for Tune.
41
+ # This needs to be added to the checkpoint dictionary so if the Tune trial
42
+ # is restarted, the checkpoint_id can continue to increment.
43
+ TUNE_CHECKPOINT_ID = "_current_checkpoint_id"
44
+
45
+ # Deprecated configs can use this value to detect if the user has set it.
46
+ _DEPRECATED_VALUE = "DEPRECATED"
47
+
48
+ # ==================================================
49
+ # Environment Variables
50
+ # ==================================================
51
+
52
+ ENABLE_DETAILED_AUTOFILLED_METRICS_ENV = (
53
+ "TRAIN_RESULT_ENABLE_DETAILED_AUTOFILLED_METRICS"
54
+ )
55
+
56
+ # Integer value which if set will override the value of
57
+ # Backend.share_cuda_visible_devices. 1 for True, 0 for False.
58
+ ENABLE_SHARE_CUDA_VISIBLE_DEVICES_ENV = "TRAIN_ENABLE_SHARE_CUDA_VISIBLE_DEVICES"
59
+
60
+ # Integer value which if set will not share ROCR accelerator visible devices
61
+ # across workers. 1 for True (default), 0 for False.
62
+ ENABLE_SHARE_ROCR_VISIBLE_DEVICES_ENV = "TRAIN_ENABLE_SHARE_ROCR_VISIBLE_DEVICES"
63
+
64
+ # Integer value which if set will not share neuron-core accelerator visible cores
65
+ # across workers. 1 for True (default), 0 for False.
66
+ ENABLE_SHARE_NEURON_CORES_ACCELERATOR_ENV = (
67
+ "TRAIN_ENABLE_SHARE_NEURON_CORES_ACCELERATOR"
68
+ )
69
+
70
+ # Integer value which if set will not share npu visible devices
71
+ # across workers. 1 for True (default), 0 for False.
72
+ ENABLE_SHARE_NPU_RT_VISIBLE_DEVICES_ENV = "TRAIN_ENABLE_SHARE_ASCEND_RT_VISIBLE_DEVICES"
73
+
74
+ # Integer value which indicates the number of seconds to wait when creating
75
+ # the worker placement group before timing out.
76
+ TRAIN_PLACEMENT_GROUP_TIMEOUT_S_ENV = "TRAIN_PLACEMENT_GROUP_TIMEOUT_S"
77
+
78
+ # Integer value which if set will change the placement group strategy from
79
+ # PACK to SPREAD. 1 for True, 0 for False.
80
+ TRAIN_ENABLE_WORKER_SPREAD_ENV = "TRAIN_ENABLE_WORKER_SPREAD"
81
+
82
+ # Set this to 0 to disable changing the working directory of each Tune Trainable
83
+ # or Train worker to the trial directory. Defaults to 1.
84
+ RAY_CHDIR_TO_TRIAL_DIR = "RAY_CHDIR_TO_TRIAL_DIR"
85
+
86
+ # Set this to 1 to count preemption errors toward `FailureConfig(max_failures)`.
87
+ # Defaults to 0, which always retries on node preemption failures.
88
+ RAY_TRAIN_COUNT_PREEMPTION_AS_FAILURE = "RAY_TRAIN_COUNT_PREEMPTION_AS_FAILURE"
89
+
90
+ # Set this to 1 to start a StateActor and collect information Train Runs
91
+ # Defaults to 0
92
+ RAY_TRAIN_ENABLE_STATE_TRACKING = "RAY_TRAIN_ENABLE_STATE_TRACKING"
93
+
94
+ # Set this to 1 to enable deprecation warnings for V2 migration.
95
+ ENABLE_V2_MIGRATION_WARNINGS_ENV_VAR = "RAY_TRAIN_ENABLE_V2_MIGRATION_WARNINGS"
96
+
97
+
98
+ def _v2_migration_warnings_enabled() -> bool:
99
+ return env_bool(ENABLE_V2_MIGRATION_WARNINGS_ENV_VAR, False)
100
+
101
+
102
+ # NOTE: When adding a new environment variable, please track it in this list.
103
+ TRAIN_ENV_VARS = {
104
+ ENABLE_DETAILED_AUTOFILLED_METRICS_ENV,
105
+ ENABLE_SHARE_CUDA_VISIBLE_DEVICES_ENV,
106
+ ENABLE_SHARE_NEURON_CORES_ACCELERATOR_ENV,
107
+ TRAIN_PLACEMENT_GROUP_TIMEOUT_S_ENV,
108
+ TRAIN_ENABLE_WORKER_SPREAD_ENV,
109
+ RAY_CHDIR_TO_TRIAL_DIR,
110
+ RAY_TRAIN_COUNT_PREEMPTION_AS_FAILURE,
111
+ RAY_TRAIN_ENABLE_STATE_TRACKING,
112
+ }
113
+
114
+ # Key for AIR Checkpoint metadata in TrainingResult metadata
115
+ CHECKPOINT_METADATA_KEY = "checkpoint_metadata"
116
+
117
+ # Key for AIR Checkpoint world rank in TrainingResult metadata
118
+ CHECKPOINT_RANK_KEY = "checkpoint_rank"
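Most of these switches are plain environment variables read at process start. A small sketch of how a driver script might opt into two of them; it assumes the variables are set before `ray.init()` so that both the driver and the spawned workers observe them:

import os

# Opt into the detailed autofilled metrics (_hostname, _node_ip, _pid, ...)
# and the Train run state actor; both default to off.
os.environ["TRAIN_RESULT_ENABLE_DETAILED_AUTOFILLED_METRICS"] = "1"
os.environ["RAY_TRAIN_ENABLE_STATE_TRACKING"] = "1"

import ray  # noqa: E402  (imported after setting the env vars)

ray.init()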
.venv/lib/python3.11/site-packages/ray/train/context.py ADDED
@@ -0,0 +1,139 @@
1
+ import threading
2
+ from typing import TYPE_CHECKING, Any, Dict, Optional
3
+
4
+ from ray.train._internal import session
5
+ from ray.train._internal.storage import StorageContext
6
+ from ray.train.constants import _v2_migration_warnings_enabled
7
+ from ray.train.utils import _copy_doc, _log_deprecation_warning
8
+ from ray.util.annotations import Deprecated, DeveloperAPI, PublicAPI
9
+
10
+ if TYPE_CHECKING:
11
+ from ray.tune.execution.placement_groups import PlacementGroupFactory
12
+
13
+
14
+ # The context singleton on this process.
15
+ _default_context: "Optional[TrainContext]" = None
16
+ _context_lock = threading.Lock()
17
+
18
+
19
+ _GET_METADATA_DEPRECATION_MESSAGE = (
20
+ "`get_metadata` was an experimental API that accessed the metadata passed "
21
+ "to `<Framework>Trainer(metadata=...)`. This API can be replaced by passing "
22
+ "the metadata directly to the training function (e.g., via `train_loop_config`)."
23
+ )
24
+
25
+ _TUNE_SPECIFIC_CONTEXT_DEPRECATION_MESSAGE = (
26
+ "`{}` is deprecated because the concept of a `Trial` will "
27
+ "soon be removed in Ray Train (see here: "
28
+ "https://github.com/ray-project/enhancements/pull/57). "
29
+ "Ray Train will no longer assume that it's running within a Ray Tune `Trial` "
30
+ "in the future."
31
+ )
32
+
33
+
34
+ @PublicAPI(stability="stable")
35
+ class TrainContext:
36
+ """Context containing metadata that can be accessed within Ray Train workers."""
37
+
38
+ @_copy_doc(session.get_experiment_name)
39
+ def get_experiment_name(self) -> str:
40
+ return session.get_experiment_name()
41
+
42
+ @_copy_doc(session.get_world_size)
43
+ def get_world_size(self) -> int:
44
+ return session.get_world_size()
45
+
46
+ @_copy_doc(session.get_world_rank)
47
+ def get_world_rank(self) -> int:
48
+ return session.get_world_rank()
49
+
50
+ @_copy_doc(session.get_local_rank)
51
+ def get_local_rank(self) -> int:
52
+ return session.get_local_rank()
53
+
54
+ @_copy_doc(session.get_local_world_size)
55
+ def get_local_world_size(self) -> int:
56
+ return session.get_local_world_size()
57
+
58
+ @_copy_doc(session.get_node_rank)
59
+ def get_node_rank(self) -> int:
60
+ return session.get_node_rank()
61
+
62
+ @DeveloperAPI
63
+ @_copy_doc(session.get_storage)
64
+ def get_storage(self) -> StorageContext:
65
+ return session.get_storage()
66
+
67
+ # Deprecated APIs
68
+
69
+ @Deprecated(
70
+ message=_GET_METADATA_DEPRECATION_MESSAGE,
71
+ warning=_v2_migration_warnings_enabled(),
72
+ )
73
+ @_copy_doc(session.get_metadata)
74
+ def get_metadata(self) -> Dict[str, Any]:
75
+ return session.get_metadata()
76
+
77
+ @Deprecated(
78
+ message=_TUNE_SPECIFIC_CONTEXT_DEPRECATION_MESSAGE.format("get_trial_name"),
79
+ warning=_v2_migration_warnings_enabled(),
80
+ )
81
+ @_copy_doc(session.get_trial_name)
82
+ def get_trial_name(self) -> str:
83
+ return session.get_trial_name()
84
+
85
+ @Deprecated(
86
+ message=_TUNE_SPECIFIC_CONTEXT_DEPRECATION_MESSAGE.format("get_trial_id"),
87
+ warning=_v2_migration_warnings_enabled(),
88
+ )
89
+ @_copy_doc(session.get_trial_id)
90
+ def get_trial_id(self) -> str:
91
+ return session.get_trial_id()
92
+
93
+ @Deprecated(
94
+ message=_TUNE_SPECIFIC_CONTEXT_DEPRECATION_MESSAGE.format(
95
+ "get_trial_resources"
96
+ ),
97
+ warning=_v2_migration_warnings_enabled(),
98
+ )
99
+ @_copy_doc(session.get_trial_resources)
100
+ def get_trial_resources(self) -> "PlacementGroupFactory":
101
+ return session.get_trial_resources()
102
+
103
+ @Deprecated(
104
+ message=_TUNE_SPECIFIC_CONTEXT_DEPRECATION_MESSAGE.format("get_trial_dir"),
105
+ warning=_v2_migration_warnings_enabled(),
106
+ )
107
+ @_copy_doc(session.get_trial_dir)
108
+ def get_trial_dir(self) -> str:
109
+ return session.get_trial_dir()
110
+
111
+
112
+ @PublicAPI(stability="stable")
113
+ def get_context() -> TrainContext:
114
+ """Get or create a singleton training context.
115
+
116
+ The context is only available within a function passed to Ray Train.
117
+
118
+ See the :class:`~ray.train.TrainContext` API reference to see available methods.
119
+ """
120
+ from ray.tune.trainable.trainable_fn_utils import _in_tune_session
121
+
122
+ # If we are running in a Tune function, switch to Tune context.
123
+ if _in_tune_session():
124
+ from ray.tune import get_context as get_tune_context
125
+
126
+ if _v2_migration_warnings_enabled():
127
+ _log_deprecation_warning(
128
+ "`ray.train.get_context()` should be switched to "
129
+ "`ray.tune.get_context()` when running in a function "
130
+ "passed to Ray Tune. This will be an error in the future."
131
+ )
132
+ return get_tune_context()
133
+
134
+ global _default_context
135
+
136
+ with _context_lock:
137
+ if _default_context is None:
138
+ _default_context = TrainContext()
139
+ return _default_context
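As a quick illustration of the context API above, a minimal training function that queries the singleton context on each worker might look like this (only the stable rank and size accessors are used):

import ray.train


def train_loop_per_worker():
    ctx = ray.train.get_context()
    # Rank and size accessors are the stable (non-deprecated) part of the API.
    if ctx.get_world_rank() == 0:
        print(
            f"experiment={ctx.get_experiment_name()} "
            f"world_size={ctx.get_world_size()} "
            f"local_rank={ctx.get_local_rank()} "
            f"node_rank={ctx.get_node_rank()}"
        )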
.venv/lib/python3.11/site-packages/ray/train/data_parallel_trainer.py ADDED
@@ -0,0 +1,587 @@
1
+ import logging
2
+ import uuid
3
+ from typing import Any, Callable, Dict, List, Optional, Type, Union
4
+
5
+ import ray
6
+ from ray._private.ray_constants import env_integer
7
+ from ray._private.thirdparty.tabulate.tabulate import tabulate
8
+ from ray.air.config import RunConfig, ScalingConfig
9
+ from ray.train import BackendConfig, Checkpoint, TrainingIterator
10
+ from ray.train._internal import session
11
+ from ray.train._internal.backend_executor import BackendExecutor, TrialInfo
12
+ from ray.train._internal.data_config import DataConfig
13
+ from ray.train._internal.session import _TrainingResult, get_session
14
+ from ray.train._internal.utils import construct_train_func, count_required_parameters
15
+ from ray.train.constants import RAY_TRAIN_ENABLE_STATE_TRACKING
16
+ from ray.train.trainer import BaseTrainer, GenDataset
17
+ from ray.util.annotations import DeveloperAPI, PublicAPI
18
+ from ray.widgets import Template
19
+ from ray.widgets.util import repr_with_fallback
20
+
21
+ logger = logging.getLogger(__name__)
22
+
23
+
24
+ @DeveloperAPI
25
+ class DataParallelTrainer(BaseTrainer):
26
+ """A Trainer for data parallel training.
27
+
28
+ You should subclass this Trainer if your Trainer follows SPMD (single program,
29
+ multiple data) programming paradigm - you want multiple processes to run the same
30
+ function, but on different data.
31
+
32
+ This Trainer runs the function ``train_loop_per_worker`` on multiple Ray
33
+ Actors.
34
+
35
+ The ``train_loop_per_worker`` function is expected to take in either 0 or 1
36
+ arguments:
37
+
38
+ .. testcode::
39
+
40
+ def train_loop_per_worker():
41
+ ...
42
+
43
+ .. testcode::
44
+
45
+ def train_loop_per_worker(config: Dict):
46
+ ...
47
+
48
+ If ``train_loop_per_worker`` accepts an argument, then
49
+ ``train_loop_config`` will be passed in as the argument. This is useful if you
50
+ want to tune the values in ``train_loop_config`` as hyperparameters.
51
+
52
+ If the ``datasets`` dict contains a training dataset (denoted by
53
+ the "train" key), then it will be split into multiple dataset
54
+ shards that can then be accessed by ``train.get_dataset_shard("train")`` inside
55
+ ``train_loop_per_worker``. All the other datasets will not be split and
56
+ ``train.get_dataset_shard(...)`` will return the entire Dataset.
57
+
58
+ Inside the ``train_loop_per_worker`` function, you can use any of the
59
+ :ref:`Ray Train loop methods <train-loop-api>`.
60
+
61
+ .. testcode::
62
+
63
+ from ray import train
64
+
65
+ def train_loop_per_worker():
66
+ # Report intermediate results for callbacks or logging and
67
+ # checkpoint data.
68
+ train.report(...)
69
+
70
+ # Returns dict of last saved checkpoint.
71
+ train.get_checkpoint()
72
+
73
+ # Returns the Dataset shard for the given key.
74
+ train.get_dataset_shard("my_dataset")
75
+
76
+ # Returns the total number of workers executing training.
77
+ train.get_context().get_world_size()
78
+
79
+ # Returns the rank of this worker.
80
+ train.get_context().get_world_rank()
81
+
82
+ # Returns the rank of the worker on the current node.
83
+ train.get_context().get_local_rank()
84
+
85
+ Any returns from the ``train_loop_per_worker`` will be discarded and not
86
+ used or persisted anywhere.
87
+
88
+ **How do I use DataParallelTrainer or any of its subclasses?**
89
+
90
+ Example:
91
+
92
+ .. testcode::
93
+
94
+ import ray
95
+ from ray import train
96
+ from ray.train import ScalingConfig
97
+ from ray.train.data_parallel_trainer import DataParallelTrainer
98
+
99
+ def train_loop_for_worker():
100
+ dataset_shard_for_this_worker = train.get_dataset_shard("train")
101
+
102
+ # 3 items for 3 workers, each worker gets 1 item
103
+ batches = list(dataset_shard_for_this_worker.iter_batches(batch_size=1))
104
+ assert len(batches) == 1
105
+
106
+ train_dataset = ray.data.from_items([1, 2, 3])
107
+ assert train_dataset.count() == 3
108
+ trainer = DataParallelTrainer(
109
+ train_loop_for_worker,
110
+ scaling_config=ScalingConfig(num_workers=3),
111
+ datasets={"train": train_dataset},
112
+ )
113
+ result = trainer.fit()
114
+
115
+ .. testoutput::
116
+ :hide:
117
+
118
+ ...
119
+
120
+ **How do I develop on top of DataParallelTrainer?**
121
+
122
+ In many cases, using DataParallelTrainer directly is sufficient to execute
123
+ functions on multiple actors.
124
+
125
+ However, you may want to subclass ``DataParallelTrainer`` and create a custom
126
+ Trainer for the following 2 use cases:
127
+
128
+ - **Use Case 1:** You want to do data parallel training, but want to have
129
+ a predefined ``training_loop_per_worker``.
130
+
131
+ - **Use Case 2:** You want to implement a custom
132
+ :py:class:`~ray.train.backend.Backend` that automatically handles
133
+ additional setup or teardown logic on each actor, so that the users of this
134
+ new trainer do not have to implement this logic. For example, a
135
+ ``TensorflowTrainer`` can be built on top of ``DataParallelTrainer``
136
+ that automatically handles setting the proper environment variables for
137
+ distributed Tensorflow on each actor.
138
+
139
+ For 1, you can set a predefined training loop in __init__
140
+
141
+ .. testcode::
142
+
143
+ from ray.train.data_parallel_trainer import DataParallelTrainer
144
+
145
+ class MyDataParallelTrainer(DataParallelTrainer):
146
+ def __init__(self, *args, **kwargs):
147
+ predefined_train_loop_per_worker = lambda: 1
148
+ super().__init__(predefined_train_loop_per_worker, *args, **kwargs)
149
+
150
+
151
+ For 2, you can implement the ``ray.train.Backend`` and ``ray.train.BackendConfig``
152
+ interfaces.
153
+
154
+ .. testcode::
155
+
156
+ from dataclasses import dataclass
157
+ from ray.train.backend import Backend, BackendConfig
158
+
159
+ class MyBackend(Backend):
160
+ def on_start(self, worker_group, backend_config):
161
+ def set_env_var(env_var_value):
162
+ import os
163
+ os.environ["MY_ENV_VAR"] = env_var_value
164
+
165
+ worker_group.execute(set_env_var, backend_config.env_var)
166
+
167
+ @dataclass
168
+ class MyBackendConfig(BackendConfig):
169
+ env_var: str = "default_value"
170
+
171
+ def backend_cls(self):
172
+ return MyBackend
173
+
174
+ class MyTrainer(DataParallelTrainer):
175
+ def __init__(self, train_loop_per_worker, my_backend_config:
176
+ MyBackendConfig, **kwargs):
177
+
178
+ super().__init__(
179
+ train_loop_per_worker,
180
+ backend_config=my_backend_config, **kwargs)
181
+
182
+ Args:
183
+ train_loop_per_worker: The training function to execute.
184
+ This can either take in no arguments or a ``config`` dict.
185
+ train_loop_config: Configurations to pass into
186
+ ``train_loop_per_worker`` if it accepts an argument.
187
+ backend_config: Configuration for setting up a Backend (e.g. Torch,
188
+ Tensorflow, Horovod) on each worker to enable distributed
189
+ communication. If no Backend should be set up, then set this to None.
190
+ scaling_config: Configuration for how to scale data parallel training.
191
+ dataset_config: Configuration for dataset ingest. This is merged with the
192
+ default dataset config for the given trainer (`cls._dataset_config`).
193
+ run_config: Configuration for the execution of the training run.
194
+ datasets: Ray Datasets to use for training and evaluation.
195
+ This is a dict where the key is the name of the dataset, which
196
+ can be accessed from within the ``train_loop_per_worker`` by calling
197
+ ``train.get_dataset_shard(dataset_key)``.
198
+ By default, all datasets are sharded equally across workers.
199
+ This can be configured via ``dataset_config``.
200
+ metadata: Dict that should be made available via
201
+ `train.get_context().get_metadata()` and in `checkpoint.get_metadata()`
202
+ for checkpoints saved from this Trainer. Must be JSON-serializable.
203
+ resume_from_checkpoint: A checkpoint to resume training from.
204
+ """
205
+
206
+ # Exposed here for testing purposes. Should never need
207
+ # to be overridden.
208
+ _backend_executor_cls: Type[BackendExecutor] = BackendExecutor
209
+ _training_iterator_cls: Type[TrainingIterator] = TrainingIterator
210
+
211
+ _scaling_config_allowed_keys = BaseTrainer._scaling_config_allowed_keys + [
212
+ "num_workers",
213
+ "resources_per_worker",
214
+ "use_gpu",
215
+ "placement_strategy",
216
+ "accelerator_type",
217
+ ]
218
+
219
+ # For backwards compatibility with the legacy dataset config API.
220
+ _dataset_config = None
221
+
222
+ _fields_for_tuner_param_space = BaseTrainer._fields_for_tuner_param_space + [
223
+ "train_loop_config"
224
+ ]
225
+
226
+ def __init__(
227
+ self,
228
+ train_loop_per_worker: Union[Callable[[], None], Callable[[Dict], None]],
229
+ *,
230
+ train_loop_config: Optional[Dict] = None,
231
+ backend_config: Optional[BackendConfig] = None,
232
+ scaling_config: Optional[ScalingConfig] = None,
233
+ dataset_config: Optional[DataConfig] = None,
234
+ run_config: Optional[RunConfig] = None,
235
+ datasets: Optional[Dict[str, GenDataset]] = None,
236
+ metadata: Optional[Dict[str, Any]] = None,
237
+ resume_from_checkpoint: Optional[Checkpoint] = None,
238
+ ):
239
+ self._train_loop_per_worker = train_loop_per_worker
240
+ self._train_loop_config = train_loop_config
241
+
242
+ if dataset_config is None:
243
+ dataset_config = DataConfig()
244
+
245
+ if not isinstance(dataset_config, DataConfig):
246
+ raise ValueError(
247
+ "`dataset_config` must be an instance of ray.train.DataConfig, "
248
+ f"was: {dataset_config}"
249
+ )
250
+ self._data_config = dataset_config
251
+
252
+ backend_config = (
253
+ backend_config if backend_config is not None else BackendConfig()
254
+ )
255
+ self._backend_config = backend_config
256
+
257
+ super(DataParallelTrainer, self).__init__(
258
+ scaling_config=scaling_config,
259
+ run_config=run_config,
260
+ datasets=datasets,
261
+ metadata=metadata,
262
+ resume_from_checkpoint=resume_from_checkpoint,
263
+ )
264
+
265
+ train_total_resources = self.scaling_config.total_resources
266
+ self._data_config.set_train_total_resources(
267
+ train_total_resources.get("CPU", 0),
268
+ train_total_resources.get("GPU", 0),
269
+ )
270
+
271
+ if env_integer(RAY_TRAIN_ENABLE_STATE_TRACKING, 0):
272
+ from ray.train._internal.state.state_actor import get_or_create_state_actor
273
+
274
+ get_or_create_state_actor()
275
+
276
+ @PublicAPI(stability="beta")
277
+ @classmethod
278
+ def restore(
279
+ cls: Type["DataParallelTrainer"],
280
+ path: str,
281
+ train_loop_per_worker: Optional[
282
+ Union[Callable[[], None], Callable[[Dict], None]]
283
+ ] = None,
284
+ train_loop_config: Optional[Dict] = None,
285
+ **kwargs,
286
+ ) -> "DataParallelTrainer":
287
+ """Restores a DataParallelTrainer from a previously interrupted/failed run.
288
+
289
+ Args:
290
+ train_loop_per_worker: Optionally re-specified train loop function.
291
+ This should be used to re-specify a function that is not
292
+ restorable in a new Ray cluster (e.g., it holds onto outdated
293
+ object references). This should be the same training loop
294
+ that was passed to the original trainer constructor.
295
+ train_loop_config: Optionally re-specified train config.
296
+ This should similarly be used if the original `train_loop_config`
297
+ contained outdated object references, and it should not be modified
298
+ from what was originally passed in.
299
+
300
+ See :meth:`BaseTrainer.restore() <ray.train.trainer.BaseTrainer.restore>`
301
+ for descriptions of the other arguments.
302
+
303
+ Returns:
304
+ DataParallelTrainer: A restored instance of the `DataParallelTrainer`
305
+ subclass that is calling this method.
306
+ """
307
+ return super(DataParallelTrainer, cls).restore(
308
+ path=path,
309
+ train_loop_per_worker=train_loop_per_worker,
310
+ train_loop_config=train_loop_config,
311
+ **kwargs,
312
+ )
313
+
314
+ def _validate_attributes(self):
315
+ super()._validate_attributes()
316
+
317
+ self._validate_train_loop_per_worker(
318
+ self._train_loop_per_worker, "train_loop_per_worker"
319
+ )
320
+
321
+ def _validate_train_loop_per_worker(
322
+ self, train_loop_per_worker: Callable, fn_name: str
323
+ ) -> None:
324
+ num_required_params = count_required_parameters(train_loop_per_worker)
325
+ if num_required_params > 1:
326
+ raise ValueError(
327
+ f"{fn_name} should take in 0 or 1 arguments, "
328
+ f"but it accepts {num_required_params} arguments instead."
329
+ )
330
+
331
+ @classmethod
332
+ def _validate_scaling_config(cls, scaling_config: ScalingConfig) -> ScalingConfig:
333
+ scaling_config = super(DataParallelTrainer, cls)._validate_scaling_config(
334
+ scaling_config
335
+ )
336
+
337
+ # This validation happens after the scaling config is updated from
338
+ # its specification in the Tuner `param_space`
339
+ if not scaling_config.use_gpu and "GPU" in ray.available_resources():
340
+ logger.info(
341
+ "GPUs are detected in your Ray cluster, but GPU "
342
+ "training is not enabled for this trainer. To enable "
343
+ "GPU training, make sure to set `use_gpu` to True "
344
+ "in your scaling config."
345
+ )
346
+
347
+ if scaling_config.num_workers is None:
348
+ raise ValueError(
349
+ "You must specify the 'num_workers' in `scaling_config` as either an "
350
+ f"argument of `{cls.__name__}` or through the `param_space` of a "
351
+ "`Tuner` (if performing hyperparameter tuning)."
352
+ )
353
+
354
+ if scaling_config.num_workers <= 0:
355
+ raise ValueError(
356
+ "'num_workers' in `scaling_config` must be a positive "
357
+ f"integer. Received {scaling_config.num_workers}"
358
+ )
359
+
360
+ return scaling_config
361
+
362
+ def _run_training(self, training_iterator: TrainingIterator) -> None:
363
+ """This method loops over the `TrainingIterator`:
364
+ The actual iteration (for ... in ...) waits for the training function
365
+ on each worker to report a result and supplies it as a list of results.
366
+ Afterwards (in the body of the loop), it will report the result
367
+ to the Tune session.
368
+ The iterator ends after the training function on each worker has finished.
369
+ """
370
+ for training_results in training_iterator:
371
+ # TODO(ml-team): add ability to report results from multiple workers.
372
+ self._propagate_results(training_results)
373
+
374
+ def _propagate_results(self, training_results: List[_TrainingResult]):
375
+ first_worker_result = training_results[0]
376
+ assert all(isinstance(result, _TrainingResult) for result in training_results)
377
+
378
+ tune_session = get_session()
379
+
380
+ # Check if any workers reported a checkpoint.
381
+ # If so, report a checkpoint pointing to the persisted location
382
+ # to Tune for book-keeping.
383
+ # NOTE: This removes the restriction for any individual worker
384
+ # (ex: global rank 0 worker) from needing to report a checkpoint.
385
+ # All workers reported a checkpoint to the same fs path, so there's
386
+ # no need to report multiple checkpoints to Tune.
387
+ worker_checkpoints = [
388
+ result.checkpoint
389
+ for result in training_results
390
+ if result.checkpoint is not None
391
+ ]
392
+ at_least_one_reported_checkpoint = len(worker_checkpoints) > 0
393
+
394
+ if at_least_one_reported_checkpoint:
395
+ # Update the coordinator's checkpoint index to the latest.
396
+ # This is what keeps the checkpoint index in line with the workers.
397
+ tune_session.storage._update_checkpoint_index(first_worker_result.metrics)
398
+
399
+ # Make sure that all workers uploaded to the same location.
400
+ assert all(
401
+ checkpoint.path == tune_session.storage.checkpoint_fs_path
402
+ for checkpoint in worker_checkpoints
403
+ )
404
+
405
+ checkpoint = (
406
+ Checkpoint(
407
+ filesystem=tune_session.storage.storage_filesystem,
408
+ path=tune_session.storage.checkpoint_fs_path,
409
+ )
410
+ if at_least_one_reported_checkpoint
411
+ else None
412
+ )
413
+
414
+ tracked_training_result = _TrainingResult(
415
+ checkpoint=checkpoint,
416
+ metrics=first_worker_result.metrics,
417
+ )
418
+
419
+ logger.debug(
420
+ "Report (metrics, checkpoint) to the Tune session:\n"
421
+ f" metrics={tracked_training_result.metrics}\n"
422
+ f" checkpoint={tracked_training_result.checkpoint}"
423
+ )
424
+
425
+ # Report the metrics and checkpoint to Tune.
426
+ tune_session._report_training_result(tracked_training_result)
427
+
428
+ def training_loop(self) -> None:
429
+ scaling_config = self._validate_scaling_config(self.scaling_config)
430
+
431
+ train_loop_per_worker = construct_train_func(
432
+ self._train_loop_per_worker,
433
+ self._train_loop_config,
434
+ train_func_context=self._backend_config.train_func_context,
435
+ fn_arg_name="train_loop_per_worker",
436
+ discard_returns=True,
437
+ )
438
+
439
+ trial_info = TrialInfo(
440
+ name=session.get_trial_name(),
441
+ id=session.get_trial_id(),
442
+ resources=session.get_trial_resources(),
443
+ logdir=session.get_trial_dir(),
444
+ driver_ip=ray.util.get_node_ip_address(),
445
+ driver_node_id=ray.get_runtime_context().get_node_id(),
446
+ experiment_name=session.get_experiment_name(),
447
+ run_id=uuid.uuid4().hex,
448
+ )
449
+
450
+ backend_executor = self._backend_executor_cls(
451
+ backend_config=self._backend_config,
452
+ trial_info=trial_info,
453
+ num_workers=scaling_config.num_workers,
454
+ resources_per_worker=scaling_config._resources_per_worker_not_none,
455
+ max_retries=0,
456
+ )
457
+
458
+ # Start the remote actors.
459
+ backend_executor.start()
460
+
461
+ training_iterator = self._training_iterator_cls(
462
+ backend_executor=backend_executor,
463
+ backend_config=self._backend_config,
464
+ train_func=train_loop_per_worker,
465
+ datasets=self.datasets,
466
+ metadata=self.metadata,
467
+ data_config=self._data_config,
468
+ checkpoint=self.starting_checkpoint,
469
+ )
470
+
471
+ self._run_training(training_iterator)
472
+
473
+ # Shutdown workers.
474
+ backend_executor.shutdown()
475
+
476
+ def get_dataset_config(self) -> DataConfig:
477
+ """Returns a copy of this Trainer's final dataset configs.
478
+
479
+ Returns:
480
+ The merged default + user-supplied dataset config.
481
+ """
482
+
483
+ return self._data_config
484
+
485
+ @repr_with_fallback(["ipywidgets", "8"])
486
+ def _repr_mimebundle_(self, **kwargs):
487
+ """Returns a mimebundle with an ipywidget repr and a simple text repr.
488
+
489
+ Depending on the frontend where the data is being displayed,
490
+ different mimetypes will be used from this bundle.
491
+ See https://ipython.readthedocs.io/en/stable/config/integrating.html
492
+ for information about this method, and
493
+ https://ipywidgets.readthedocs.io/en/latest/embedding.html
494
+ for more information about the jupyter widget mimetype.
495
+
496
+ Returns:
497
+ A mimebundle containing an ipywidget repr and a simple text repr.
498
+ """
499
+ from ipywidgets import HTML, Layout, Tab, VBox
500
+
501
+ title = HTML(f"<h2>{self.__class__.__name__}</h2>")
502
+
503
+ children = []
504
+ titles = []
505
+
506
+ if self.datasets:
507
+ children.append(self._datasets_repr_())
508
+ titles.append("Datasets")
509
+
510
+ children.append(HTML(self._data_config_repr_html_()))
511
+ titles.append("Data Config")
512
+
513
+ if self._train_loop_config:
514
+ children.append(HTML(self._train_loop_config_repr_html_()))
515
+ titles.append("Train Loop Config")
516
+
517
+ if self.scaling_config:
518
+ children.append(HTML(self.scaling_config._repr_html_()))
519
+ titles.append("Scaling Config")
520
+
521
+ if self.run_config:
522
+ children.append(HTML(self.run_config._repr_html_()))
523
+ titles.append("Run Config")
524
+
525
+ if self._backend_config:
526
+ children.append(HTML(self._backend_config._repr_html_()))
527
+ titles.append("Backend Config")
528
+
529
+ tab = Tab(children, titles=titles)
530
+ widget = VBox([title, tab], layout=Layout(width="100%"))
531
+ bundle = widget._repr_mimebundle_(**kwargs)
532
+ bundle.update(
533
+ {
534
+ "text/plain": repr(self),
535
+ }
536
+ )
537
+ return bundle
538
+
539
+ def _train_loop_config_repr_html_(self) -> str:
540
+ if self._train_loop_config:
541
+ table_data = {}
542
+ for k, v in self._train_loop_config.items():
543
+ if isinstance(v, str) or str(v).isnumeric():
544
+ table_data[k] = v
545
+ elif hasattr(v, "_repr_html_"):
546
+ table_data[k] = v._repr_html_()
547
+ else:
548
+ table_data[k] = str(v)
549
+
550
+ return Template("title_data.html.j2").render(
551
+ title="Train Loop Config",
552
+ data=Template("scrollableTable.html.j2").render(
553
+ table=tabulate(
554
+ table_data.items(),
555
+ headers=["Setting", "Value"],
556
+ showindex=False,
557
+ tablefmt="unsafehtml",
558
+ ),
559
+ max_height="none",
560
+ ),
561
+ )
562
+ else:
563
+ return ""
564
+
565
+ def _data_config_repr_html_(self) -> str:
566
+ # TODO make this rendering nicer.
567
+ content = [str(self._data_config)]
568
+ return Template("rendered_html_common.html.j2").render(content=content)
569
+
570
+ def _datasets_repr_(self) -> str:
571
+ from ipywidgets import HTML, Layout, VBox
572
+
573
+ content = []
574
+ if self.datasets:
575
+ for name, config in self.datasets.items():
576
+ tab = config._tab_repr_()
577
+ if tab:
578
+ content.append(
579
+ HTML(
580
+ Template("title_data.html.j2").render(
581
+ title=f"Dataset - <code>{name}</code>", data=None
582
+ )
583
+ )
584
+ )
585
+ content.append(config._tab_repr_())
586
+
587
+ return VBox(content, layout=Layout(width="100%"))
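To make the dataset-sharding behavior described in the class docstring concrete, here is a small sketch that shards only the "train" dataset across workers while keeping an evaluation dataset whole. It assumes `DataConfig` accepts a `datasets_to_split` argument, as in recent Ray releases:

import ray
import ray.train
from ray.train import DataConfig, ScalingConfig
from ray.train.data_parallel_trainer import DataParallelTrainer


def train_loop_per_worker():
    # "train" is sharded across the two workers; "valid" is returned whole.
    train_shard = ray.train.get_dataset_shard("train")
    for batch in train_shard.iter_batches(batch_size=32):
        pass


trainer = DataParallelTrainer(
    train_loop_per_worker,
    scaling_config=ScalingConfig(num_workers=2),
    dataset_config=DataConfig(datasets_to_split=["train"]),
    datasets={"train": ray.data.range(1000), "valid": ray.data.range(100)},
)
result = trainer.fit()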
.venv/lib/python3.11/site-packages/ray/train/error.py ADDED
@@ -0,0 +1,6 @@
1
+ from ray.util.annotations import PublicAPI
2
+
3
+
4
+ @PublicAPI(stability="beta")
5
+ class SessionMisuseError(Exception):
6
+ """Indicates a method or function was used outside of a session."""
.venv/lib/python3.11/site-packages/ray/train/examples/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/ray/train/examples/mlflow_simple_example.py ADDED
@@ -0,0 +1,55 @@
1
+ from pathlib import Path
2
+
3
+ from ray import train
4
+ from ray.train import RunConfig, ScalingConfig
5
+ from ray.train.torch import TorchTrainer
6
+ from ray.tune.logger import TBXLoggerCallback
7
+ from ray.tune.logger.mlflow import MLflowLoggerCallback
8
+
9
+
10
+ def train_func():
11
+ for i in range(3):
12
+ train.report(dict(epoch=i))
13
+
14
+
15
+ trainer = TorchTrainer(
16
+ train_func,
17
+ scaling_config=ScalingConfig(num_workers=2),
18
+ run_config=RunConfig(
19
+ callbacks=[
20
+ MLflowLoggerCallback(experiment_name="train_experiment"),
21
+ TBXLoggerCallback(),
22
+ ],
23
+ ),
24
+ )
25
+
26
+ # Run the training function, logging all the intermediate results
27
+ # to MLflow and Tensorboard.
28
+ result = trainer.fit()
29
+
30
+ # For MLflow logs:
31
+
32
+ # MLFlow logs will by default be saved in an `mlflow` directory
33
+ # in the current working directory.
34
+
35
+ # $ cd mlflow
36
+ # # View the MLflow UI.
37
+ # $ mlflow ui
38
+
39
+ # You can change the directory by setting the `tracking_uri` argument
40
+ # in `MLflowLoggerCallback`.
41
+
42
+ # For TensorBoard logs:
43
+
44
+ # Print the latest run directory and keep note of it.
45
+ # For example: /home/ubuntu/ray_results/TorchTrainer_2022-06-13_20-31-06
46
+ print("Run directory:", Path(result.path).parent) # TensorBoard is saved in parent dir
47
+
48
+ # How to visualize the logs
49
+
50
+ # Navigate to the run directory of the trainer.
51
+ # For example `cd /home/ubuntu/ray_results/TorchTrainer_2022-06-13_20-31-06`
52
+ # $ cd <TRAINER_RUN_DIR>
53
+ #
54
+ # # View the tensorboard UI.
55
+ # $ tensorboard --logdir .
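As the comments above note, the MLflow output location is controlled by the `tracking_uri` argument; a hedged variant of the callback setup (the URI below is only an example path, not a default):

from ray.tune.logger.mlflow import MLflowLoggerCallback

mlflow_callback = MLflowLoggerCallback(
    experiment_name="train_experiment",
    tracking_uri="file:///tmp/mlflow_tracking",  # example location
)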
.venv/lib/python3.11/site-packages/ray/train/examples/tf/tune_tensorflow_autoencoder_example.py ADDED
@@ -0,0 +1,77 @@
1
+ import argparse
2
+
3
+ import ray
4
+ from ray import tune
5
+ from ray.train import ScalingConfig
6
+ from ray.train.examples.tf.tensorflow_mnist_example import train_func
7
+ from ray.train.tensorflow import TensorflowTrainer
8
+ from ray.tune.tune_config import TuneConfig
9
+ from ray.tune.tuner import Tuner
10
+
11
+
12
+ def tune_tensorflow_mnist(
13
+ num_workers: int = 2, num_samples: int = 2, use_gpu: bool = False
14
+ ):
15
+ scaling_config = ScalingConfig(num_workers=num_workers, use_gpu=use_gpu)
16
+ trainer = TensorflowTrainer(
17
+ train_loop_per_worker=train_func,
18
+ scaling_config=scaling_config,
19
+ )
20
+ tuner = Tuner(
21
+ trainer,
22
+ tune_config=TuneConfig(
23
+ num_samples=num_samples, metric="binary_crossentropy", mode="min"
24
+ ),
25
+ param_space={
26
+ "train_loop_config": {
27
+ "lr": tune.loguniform(1e-4, 1e-1),
28
+ "batch_size": tune.choice([32, 64, 128]),
29
+ "epochs": 3,
30
+ }
31
+ },
32
+ )
33
+ best_accuracy = tuner.fit().get_best_result().metrics["binary_crossentropy"]
34
+ print(f"Best accuracy config: {best_accuracy}")
35
+
36
+
37
+ if __name__ == "__main__":
38
+ parser = argparse.ArgumentParser()
39
+ parser.add_argument(
40
+ "--smoke-test",
41
+ action="store_true",
42
+ default=False,
43
+ help="Finish quickly for testing.",
44
+ )
45
+ parser.add_argument(
46
+ "--address", required=False, type=str, help="the address to use for Ray"
47
+ )
48
+ parser.add_argument(
49
+ "--num-workers",
50
+ "-n",
51
+ type=int,
52
+ default=2,
53
+ help="Sets number of workers for training.",
54
+ )
55
+ parser.add_argument(
56
+ "--num-samples",
57
+ type=int,
58
+ default=2,
59
+ help="Sets number of samples for training.",
60
+ )
61
+ parser.add_argument(
62
+ "--use-gpu", action="store_true", default=False, help="Enables GPU training"
63
+ )
64
+
65
+ args = parser.parse_args()
66
+
67
+ if args.smoke_test:
68
+ num_gpus = args.num_workers if args.use_gpu else 0
69
+ ray.init(num_cpus=8, num_gpus=num_gpus)
70
+ tune_tensorflow_mnist(num_workers=2, num_samples=2, use_gpu=args.use_gpu)
71
+ else:
72
+ ray.init(address=args.address)
73
+ tune_tensorflow_mnist(
74
+ num_workers=args.num_workers,
75
+ num_samples=args.num_samples,
76
+ use_gpu=args.use_gpu,
77
+ )
.venv/lib/python3.11/site-packages/ray/train/huggingface/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/ray/train/huggingface/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (194 Bytes). View file
 
.venv/lib/python3.11/site-packages/ray/train/huggingface/transformers/__init__.py ADDED
@@ -0,0 +1,12 @@
1
+ from ray.train.huggingface.transformers._transformers_utils import (
2
+ RayTrainReportCallback,
3
+ prepare_trainer,
4
+ )
5
+
6
+ __all__ = [
7
+ "RayTrainReportCallback",
8
+ "prepare_trainer",
9
+ ]
10
+
11
+
12
+ # DO NOT ADD ANYTHING AFTER THIS LINE.
.venv/lib/python3.11/site-packages/ray/train/huggingface/transformers/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (410 Bytes). View file
 
.venv/lib/python3.11/site-packages/ray/train/huggingface/transformers/__pycache__/_transformers_utils.cpython-311.pyc ADDED
Binary file (8.74 kB). View file
 
.venv/lib/python3.11/site-packages/ray/train/huggingface/transformers/_transformers_utils.py ADDED
@@ -0,0 +1,143 @@
1
+ import logging
2
+ import shutil
3
+ from pathlib import Path
4
+ from tempfile import TemporaryDirectory
5
+ from typing import Iterator, Optional, Type
6
+
7
+ from torch.utils.data import DataLoader, Dataset, IterableDataset
8
+
9
+ import ray
10
+ from ray._private.usage.usage_lib import TagKey, record_extra_usage_tag
11
+ from ray.data.iterator import _IterableFromIterator
12
+ from ray.train import Checkpoint
13
+ from ray.util import PublicAPI
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
+ TRANSFORMERS_IMPORT_ERROR: Optional[ImportError] = None
19
+
20
+ try:
21
+ import transformers.trainer
22
+ from transformers import Trainer
23
+ from transformers.trainer_callback import TrainerCallback
24
+ except ImportError as e:
25
+ TRANSFORMERS_IMPORT_ERROR = e
26
+ TrainerCallback = object
27
+
28
+
29
+ @PublicAPI(stability="beta")
30
+ class RayTrainReportCallback(TrainerCallback):
31
+ """A simple callback to report checkpoints and metrics to Ray Train.
32
+
33
+ This callback is a subclass of `transformers.TrainerCallback
34
+ <https://huggingface.co/docs/transformers/main/en/main_classes/callback#transformers.TrainerCallback>`_
35
+ and overrides the `TrainerCallback.on_save()` method. After
36
+ a new checkpoint gets saved, it fetches the latest metric dictionary
37
+ from `TrainerState.log_history` and reports it with the latest checkpoint
38
+ to Ray Train.
39
+
40
+ Checkpoints will be saved in the following structure::
41
+
42
+ checkpoint_00000*/ Ray Train Checkpoint
43
+ └─ checkpoint/ Hugging Face Transformers Checkpoint
44
+
45
+ For customized reporting and checkpointing logic, implement your own
46
+ `transformers.TrainerCallback` following this user
47
+ guide: :ref:`Saving and Loading Checkpoints <train-dl-saving-checkpoints>`.
48
+
49
+ Note that users should ensure that the logging, evaluation, and saving frequencies
50
+ are properly configured so that the monitoring metric is always up-to-date
51
+ when `transformers.Trainer` saves a checkpoint.
52
+
53
+ Suppose the monitoring metric is reported from the evaluation stage:
54
+
55
+ Some valid configurations:
56
+ - evaluation_strategy == save_strategy == "epoch"
57
+ - evaluation_strategy == save_strategy == "steps", save_steps % eval_steps == 0
58
+
59
+ Some invalid configurations:
60
+ - evaluation_strategy != save_strategy
61
+ - evaluation_strategy == save_strategy == "steps", save_steps % eval_steps != 0
62
+
63
+ """
64
+
65
+ CHECKPOINT_NAME = "checkpoint"
66
+
67
+ def __init__(self, *args, **kwargs):
68
+ super().__init__(*args, **kwargs)
69
+ record_extra_usage_tag(TagKey.TRAIN_TRANSFORMERS_RAYTRAINREPORTCALLBACK, "1")
70
+
71
+ def on_save(self, args, state, control, **kwargs):
72
+ """Event called after a checkpoint save."""
73
+ with TemporaryDirectory() as tmpdir:
74
+ # Aggregate all the logged metrics
75
+ metrics = {}
76
+ for log in state.log_history:
77
+ metrics.update(log)
78
+
79
+ # Copy ckpt files and construct a Ray Train Checkpoint
80
+ source_ckpt_path = transformers.trainer.get_last_checkpoint(args.output_dir)
81
+ if source_ckpt_path is not None:
82
+ target_ckpt_path = Path(tmpdir, self.CHECKPOINT_NAME).as_posix()
83
+ shutil.copytree(source_ckpt_path, target_ckpt_path)
84
+ checkpoint = Checkpoint.from_directory(tmpdir)
85
+ else:
86
+ checkpoint = None
87
+
88
+ # Report latest metrics and checkpoint to Ray Train
89
+ ray.train.report(metrics=metrics, checkpoint=checkpoint)
90
+
91
+
92
+ class RayTorchIterableDataset(IterableDataset):
93
+ """Wrapper class for ray data iterables."""
94
+
95
+ def __init__(self, data_iterable) -> None:
96
+ super().__init__()
97
+ self.data_iterable = data_iterable
98
+
99
+ def __iter__(self) -> Iterator:
100
+ return iter(self.data_iterable)
101
+
102
+
103
+ @PublicAPI(stability="beta")
104
+ def prepare_trainer(trainer: "Trainer") -> "Trainer":
105
+ """Prepare your HuggingFace Transformer Trainer for Ray Train.
106
+
107
+ This utility function enables the trainer to integrate with Ray Data.
108
+ Internally, it overrides the `get_train_dataloader` and `get_eval_dataloader`
109
+ methods and injects the data integration logic if the `train_dataset` and
110
+ `eval_dataset` are Ray Data Iterables.
111
+ """
112
+
113
+ if TRANSFORMERS_IMPORT_ERROR is not None:
114
+ raise TRANSFORMERS_IMPORT_ERROR
115
+
116
+ base_trainer_class: Type[transformers.trainer.Trainer] = trainer.__class__
117
+
118
+ class RayTransformersTrainer(base_trainer_class):
119
+ """A Wrapper of `transformers.Trainer` for Ray Data Integration."""
120
+
121
+ def get_train_dataloader(self) -> DataLoader:
122
+ if isinstance(self.train_dataset, _IterableFromIterator):
123
+ dataset = RayTorchIterableDataset(self.train_dataset)
124
+ return DataLoader(dataset, batch_size=1, collate_fn=lambda x: x[0])
125
+ else:
126
+ return super().get_train_dataloader()
127
+
128
+ def get_eval_dataloader(
129
+ self, eval_dataset: Optional[Dataset] = None
130
+ ) -> DataLoader:
131
+ if eval_dataset is None:
132
+ eval_dataset = self.eval_dataset
133
+
134
+ if isinstance(eval_dataset, _IterableFromIterator):
135
+ dataset = RayTorchIterableDataset(eval_dataset)
136
+ return DataLoader(dataset, batch_size=1, collate_fn=lambda x: x[0])
137
+ else:
138
+ return super().get_eval_dataloader(eval_dataset)
139
+
140
+ trainer.__class__ = RayTransformersTrainer
141
+
142
+ record_extra_usage_tag(TagKey.TRAIN_TRANSFORMERS_PREPARE_TRAINER, "1")
143
+ return trainer
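Putting the two public pieces of this module together, a typical Ray Train worker function adds the report callback and then wraps the `transformers.Trainer` with `prepare_trainer()`. The `build_hf_trainer` helper below is hypothetical and stands in for the usual model/`TrainingArguments` setup:

from ray.train.huggingface.transformers import (
    RayTrainReportCallback,
    prepare_trainer,
)


def train_func(config):
    # `build_hf_trainer` is a hypothetical helper that constructs a
    # transformers.Trainer from your model, TrainingArguments, and
    # (optionally) Ray Data iterables as train/eval datasets.
    trainer = build_hf_trainer(config)

    trainer.add_callback(RayTrainReportCallback())
    trainer = prepare_trainer(trainer)
    trainer.train()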
.venv/lib/python3.11/site-packages/ray/train/lightgbm/__init__.py ADDED
@@ -0,0 +1,18 @@
1
+ from ray.train.lightgbm._lightgbm_utils import RayTrainReportCallback
2
+ from ray.train.lightgbm.lightgbm_checkpoint import LightGBMCheckpoint
3
+ from ray.train.lightgbm.lightgbm_predictor import LightGBMPredictor
4
+ from ray.train.lightgbm.lightgbm_trainer import LightGBMTrainer
5
+ from ray.train.v2._internal.constants import is_v2_enabled
6
+
7
+ if is_v2_enabled():
8
+ from ray.train.v2.lightgbm.lightgbm_trainer import LightGBMTrainer # noqa: F811
9
+
10
+ __all__ = [
11
+ "RayTrainReportCallback",
12
+ "LightGBMCheckpoint",
13
+ "LightGBMPredictor",
14
+ "LightGBMTrainer",
15
+ ]
16
+
17
+
18
+ # DO NOT ADD ANYTHING AFTER THIS LINE.
.venv/lib/python3.11/site-packages/ray/train/lightgbm/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (814 Bytes). View file
 
.venv/lib/python3.11/site-packages/ray/train/lightgbm/__pycache__/_lightgbm_utils.cpython-311.pyc ADDED
Binary file (8.97 kB). View file
 
.venv/lib/python3.11/site-packages/ray/train/lightgbm/__pycache__/config.cpython-311.pyc ADDED
Binary file (5.3 kB). View file
 
.venv/lib/python3.11/site-packages/ray/train/lightgbm/__pycache__/lightgbm_checkpoint.cpython-311.pyc ADDED
Binary file (4.03 kB). View file
 
.venv/lib/python3.11/site-packages/ray/train/lightgbm/__pycache__/lightgbm_predictor.cpython-311.pyc ADDED
Binary file (7.27 kB). View file
 
.venv/lib/python3.11/site-packages/ray/train/lightgbm/__pycache__/lightgbm_trainer.cpython-311.pyc ADDED
Binary file (11 kB). View file
 
.venv/lib/python3.11/site-packages/ray/train/lightgbm/__pycache__/v2.cpython-311.pyc ADDED
Binary file (6.73 kB). View file
 
.venv/lib/python3.11/site-packages/ray/train/lightgbm/_lightgbm_utils.py ADDED
@@ -0,0 +1,170 @@
1
+ import tempfile
2
+ from contextlib import contextmanager
3
+ from pathlib import Path
4
+ from typing import Callable, Dict, List, Optional, Union
5
+
6
+ from lightgbm.basic import Booster
7
+ from lightgbm.callback import CallbackEnv
8
+
9
+ import ray.train
10
+ from ray.train import Checkpoint
11
+ from ray.tune.utils import flatten_dict
12
+ from ray.util.annotations import PublicAPI
13
+
14
+
15
+ @PublicAPI(stability="beta")
16
+ class RayTrainReportCallback:
17
+ """Creates a callback that reports metrics and checkpoints model.
18
+
19
+ Args:
20
+ metrics: Metrics to report. If this is a list,
21
+ each item should be a metric key reported by LightGBM,
22
+ and it will be reported to Ray Train/Tune under the same name.
23
+ This can also be a dict of {<key-to-report>: <lightgbm-metric-key>},
24
+ which can be used to rename LightGBM default metrics.
25
+ filename: Customize the saved checkpoint file type by passing
26
+ a filename. Defaults to "model.txt".
27
+ frequency: How often to save checkpoints, in terms of iterations.
28
+ Defaults to 0 (no checkpoints are saved during training).
29
+ checkpoint_at_end: Whether or not to save a checkpoint at the end of training.
30
+ results_postprocessing_fn: An optional Callable that takes in
31
+ the metrics dict that will be reported (after it has been flattened)
32
+ and returns a modified dict.
33
+
34
+ Examples
35
+ --------
36
+
37
+ Reporting checkpoints and metrics to Ray Tune when running many
38
+ independent LightGBM trials (without data parallelism within a trial).
39
+
40
+ .. testcode::
41
+ :skipif: True
42
+
43
+ import lightgbm
44
+
45
+ from ray.train.lightgbm import RayTrainReportCallback
46
+
47
+ config = {
48
+ # ...
49
+ "metric": ["binary_logloss", "binary_error"],
50
+ }
51
+
52
+ # Report only log loss to Tune after each validation epoch.
53
+ bst = lightgbm.train(
54
+ ...,
55
+ callbacks=[
56
+ RayTrainReportCallback(
57
+ metrics={"loss": "eval-binary_logloss"}, frequency=1
58
+ )
59
+ ],
60
+ )
61
+
62
+ Loading a model from a checkpoint reported by this callback.
63
+
64
+ .. testcode::
65
+ :skipif: True
66
+
67
+ from ray.train.lightgbm import RayTrainReportCallback
68
+
69
+ # Get a `Checkpoint` object that is saved by the callback during training.
70
+ result = trainer.fit()
71
+ booster = RayTrainReportCallback.get_model(result.checkpoint)
72
+
73
+ """
74
+
75
+ CHECKPOINT_NAME = "model.txt"
76
+
77
+ def __init__(
78
+ self,
79
+ metrics: Optional[Union[str, List[str], Dict[str, str]]] = None,
80
+ filename: str = CHECKPOINT_NAME,
81
+ frequency: int = 0,
82
+ checkpoint_at_end: bool = True,
83
+ results_postprocessing_fn: Optional[
84
+ Callable[[Dict[str, Union[float, List[float]]]], Dict[str, float]]
85
+ ] = None,
86
+ ):
87
+ if isinstance(metrics, str):
88
+ metrics = [metrics]
89
+ self._metrics = metrics
90
+ self._filename = filename
91
+ self._frequency = frequency
92
+ self._checkpoint_at_end = checkpoint_at_end
93
+ self._results_postprocessing_fn = results_postprocessing_fn
94
+
95
+ @classmethod
96
+ def get_model(
97
+ cls, checkpoint: Checkpoint, filename: str = CHECKPOINT_NAME
98
+ ) -> Booster:
99
+ """Retrieve the model stored in a checkpoint reported by this callback.
100
+
101
+ Args:
102
+ checkpoint: The checkpoint object returned by a training run.
103
+ The checkpoint should be saved by an instance of this callback.
104
+ filename: The filename to load the model from, which should match
105
+ the filename used when creating the callback.
106
+ """
107
+ with checkpoint.as_directory() as checkpoint_path:
108
+ return Booster(model_file=Path(checkpoint_path, filename).as_posix())
109
+
110
+ def _get_report_dict(self, evals_log: Dict[str, Dict[str, list]]) -> dict:
111
+ result_dict = flatten_dict(evals_log, delimiter="-")
112
+ if not self._metrics:
113
+ report_dict = result_dict
114
+ else:
115
+ report_dict = {}
116
+ for key in self._metrics:
117
+ if isinstance(self._metrics, dict):
118
+ metric = self._metrics[key]
119
+ else:
120
+ metric = key
121
+ report_dict[key] = result_dict[metric]
122
+ if self._results_postprocessing_fn:
123
+ report_dict = self._results_postprocessing_fn(report_dict)
124
+ return report_dict
125
+
126
+ def _get_eval_result(self, env: CallbackEnv) -> dict:
127
+ eval_result = {}
128
+ for entry in env.evaluation_result_list:
129
+ data_name, eval_name, result = entry[0:3]
130
+ if len(entry) > 4:
131
+ stdv = entry[4]
132
+ suffix = "-mean"
133
+ else:
134
+ stdv = None
135
+ suffix = ""
136
+ if data_name not in eval_result:
137
+ eval_result[data_name] = {}
138
+ eval_result[data_name][eval_name + suffix] = result
139
+ if stdv is not None:
140
+ eval_result[data_name][eval_name + "-stdv"] = stdv
141
+ return eval_result
142
+
143
+ @contextmanager
144
+ def _get_checkpoint(self, model: Booster) -> Optional[Checkpoint]:
145
+ if ray.train.get_context().get_world_rank() in (0, None):
146
+ with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
147
+ model.save_model(Path(temp_checkpoint_dir, self._filename).as_posix())
148
+ yield Checkpoint.from_directory(temp_checkpoint_dir)
149
+ else:
150
+ yield None
151
+
152
+ def __call__(self, env: CallbackEnv) -> None:
153
+ eval_result = self._get_eval_result(env)
154
+ report_dict = self._get_report_dict(eval_result)
155
+
156
+ # Ex: if frequency=2, checkpoint_at_end=True and num_boost_rounds=11,
157
+ # you will checkpoint at iterations 1, 3, 5, ..., 9, and 10 (checkpoint_at_end)
158
+ # (iterations count from 0)
159
+ on_last_iter = env.iteration == env.end_iteration - 1
160
+ should_checkpoint_at_end = on_last_iter and self._checkpoint_at_end
161
+ should_checkpoint_with_frequency = (
162
+ self._frequency != 0 and (env.iteration + 1) % self._frequency == 0
163
+ )
164
+ should_checkpoint = should_checkpoint_at_end or should_checkpoint_with_frequency
165
+
166
+ if should_checkpoint:
167
+ with self._get_checkpoint(model=env.model) as checkpoint:
168
+ ray.train.report(report_dict, checkpoint=checkpoint)
169
+ else:
170
+ ray.train.report(report_dict)
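Tying the checkpointing rules in `__call__` back to the constructor arguments: with `frequency=2` and `checkpoint_at_end=True`, an 11-round run checkpoints at iterations 1, 3, 5, 7, 9 and again at the final iteration 10. A small construction sketch; the `"valid-binary_logloss"` key assumes a validation set named `valid` was passed to `lightgbm.train`:

from ray.train.lightgbm import RayTrainReportCallback

# Passed to lightgbm.train(..., callbacks=[callback]) inside the Ray Train
# worker function; reporting happens via ray.train.report() in __call__.
callback = RayTrainReportCallback(
    metrics={"loss": "valid-binary_logloss"},
    frequency=2,
    checkpoint_at_end=True,
)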
.venv/lib/python3.11/site-packages/ray/train/lightgbm/config.py ADDED
@@ -0,0 +1,89 @@
1
+ import logging
2
+ import threading
3
+ from dataclasses import dataclass
4
+ from typing import Any, Dict, Optional
5
+
6
+ import ray
7
+ from ray.train._internal.utils import get_address_and_port
8
+ from ray.train._internal.worker_group import WorkerGroup
9
+ from ray.train.backend import Backend, BackendConfig
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
+ # Global LightGBM distributed network configuration for each worker process.
15
+ _lightgbm_network_params: Optional[Dict[str, Any]] = None
16
+ _lightgbm_network_params_lock = threading.Lock()
17
+
18
+
19
+ def get_network_params() -> Dict[str, Any]:
20
+ """Returns the network parameters to enable LightGBM distributed training."""
21
+ global _lightgbm_network_params
22
+
23
+ with _lightgbm_network_params_lock:
24
+ if not _lightgbm_network_params:
25
+ logger.warning(
26
+ "`ray.train.lightgbm.get_network_params` was called outside "
27
+ "the context of a `ray.train.lightgbm.LightGBMTrainer`. "
28
+ "The current process has no knowledge of the distributed training "
29
+ "worker group, so this method will return an empty dict. "
30
+ "Please call this within the training loop of a "
31
+ "`ray.train.lightgbm.LightGBMTrainer`. "
32
+ "If you are in fact calling this within a `LightGBMTrainer`, "
33
+ "this is unexpected: please file a bug report to the Ray Team."
34
+ )
35
+ return {}
36
+
37
+ return _lightgbm_network_params.copy()
38
+
39
+
40
+ def _set_network_params(
41
+ num_machines: int,
42
+ local_listen_port: int,
43
+ machines: str,
44
+ ):
45
+ global _lightgbm_network_params
46
+
47
+ with _lightgbm_network_params_lock:
48
+ assert (
49
+ _lightgbm_network_params is None
50
+ ), "LightGBM network params are already initialized."
51
+ _lightgbm_network_params = dict(
52
+ num_machines=num_machines,
53
+ local_listen_port=local_listen_port,
54
+ machines=machines,
55
+ )
56
+
57
+
58
+ @dataclass
59
+ class LightGBMConfig(BackendConfig):
60
+ """Configuration for LightGBM distributed data-parallel training setup.
61
+
62
+ See the LightGBM docs for more information on the "network parameters"
63
+ that Ray Train sets up for you:
64
+ https://lightgbm.readthedocs.io/en/latest/Parameters.html#network-parameters
65
+ """
66
+
67
+ @property
68
+ def backend_cls(self):
69
+ return _LightGBMBackend
70
+
71
+
72
+ class _LightGBMBackend(Backend):
73
+ def on_training_start(
74
+ self, worker_group: WorkerGroup, backend_config: LightGBMConfig
75
+ ):
76
+ node_ips_and_ports = worker_group.execute(get_address_and_port)
77
+ ports = [port for _, port in node_ips_and_ports]
78
+ machines = ",".join(
79
+ [f"{node_ip}:{port}" for node_ip, port in node_ips_and_ports]
80
+ )
81
+ num_machines = len(worker_group)
82
+ ray.get(
83
+ [
84
+ worker_group.execute_single_async(
85
+ rank, _set_network_params, num_machines, ports[rank], machines
86
+ )
87
+ for rank in range(len(worker_group))
88
+ ]
89
+ )
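For reference, the per-worker network parameters installed by `_LightGBMBackend.on_training_start` above have the shape sketched below; the addresses and ports are illustrative only. `get_network_params()` returns a copy of this dict so it can be merged into the LightGBM params inside the worker's training loop.

# Illustrative contents of `get_network_params()` on a 2-worker group.
network_params = {
    "num_machines": 2,
    "local_listen_port": 54001,  # this worker's own port
    # All workers as "ip:port", comma-separated, in worker-rank order.
    "machines": "10.0.0.1:54001,10.0.0.2:54002",
}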
.venv/lib/python3.11/site-packages/ray/train/lightgbm/lightgbm_checkpoint.py ADDED
@@ -0,0 +1,70 @@
1
+ import tempfile
2
+ from pathlib import Path
3
+ from typing import TYPE_CHECKING, Optional
4
+
5
+ import lightgbm
6
+
7
+ from ray.train._internal.framework_checkpoint import FrameworkCheckpoint
8
+ from ray.util.annotations import PublicAPI
9
+
10
+ if TYPE_CHECKING:
11
+ from ray.data.preprocessor import Preprocessor
12
+
13
+
14
+ @PublicAPI(stability="beta")
15
+ class LightGBMCheckpoint(FrameworkCheckpoint):
16
+ """A :py:class:`~ray.train.Checkpoint` with LightGBM-specific functionality."""
17
+
18
+ MODEL_FILENAME = "model.txt"
19
+
20
+ @classmethod
21
+ def from_model(
22
+ cls,
23
+ booster: lightgbm.Booster,
24
+ *,
25
+ preprocessor: Optional["Preprocessor"] = None,
26
+ path: Optional[str] = None,
27
+ ) -> "LightGBMCheckpoint":
28
+ """Create a :py:class:`~ray.train.Checkpoint` that stores a LightGBM model.
29
+
30
+ Args:
31
+ booster: The LightGBM model to store in the checkpoint.
32
+ preprocessor: A fitted preprocessor to be applied before inference.
33
+ path: The path to the directory where the checkpoint file will be saved.
34
+ This should start as an empty directory, since the *entire*
35
+ directory will be treated as the checkpoint when reported.
36
+ By default, a temporary directory will be created.
37
+
38
+ Returns:
39
+ A :py:class:`LightGBMCheckpoint` containing the specified ``Booster``.
40
+
41
+ Examples:
42
+ >>> import lightgbm
43
+ >>> import numpy as np
44
+ >>> from ray.train.lightgbm import LightGBMCheckpoint
45
+ >>>
46
+ >>> train_X = np.array([[1, 2], [3, 4]])
47
+ >>> train_y = np.array([0, 1])
48
+ >>>
49
+ >>> model = lightgbm.LGBMClassifier().fit(train_X, train_y)
50
+ >>> checkpoint = LightGBMCheckpoint.from_model(model.booster_)
51
+ """
52
+ checkpoint_path = Path(path or tempfile.mkdtemp())
53
+
54
+ if not checkpoint_path.is_dir():
55
+ raise ValueError(f"`path` must be a directory, but got: {checkpoint_path}")
56
+
57
+ booster.save_model(checkpoint_path.joinpath(cls.MODEL_FILENAME).as_posix())
58
+
59
+ checkpoint = cls.from_directory(checkpoint_path.as_posix())
60
+ if preprocessor:
61
+ checkpoint.set_preprocessor(preprocessor)
62
+
63
+ return checkpoint
64
+
65
+ def get_model(self) -> lightgbm.Booster:
66
+ """Retrieve the LightGBM model stored in this checkpoint."""
67
+ with self.as_directory() as checkpoint_path:
68
+ return lightgbm.Booster(
69
+ model_file=Path(checkpoint_path, self.MODEL_FILENAME).as_posix()
70
+ )
.venv/lib/python3.11/site-packages/ray/train/lightgbm/lightgbm_predictor.py ADDED
@@ -0,0 +1,152 @@
1
+ from typing import TYPE_CHECKING, List, Optional, Union
2
+
3
+ import lightgbm
4
+ import pandas as pd
5
+ from pandas.api.types import is_object_dtype
6
+
7
+ from ray.air.constants import TENSOR_COLUMN_NAME
8
+ from ray.air.data_batch_type import DataBatchType
9
+ from ray.air.util.data_batch_conversion import _unwrap_ndarray_object_type_if_needed
10
+ from ray.train.lightgbm import LightGBMCheckpoint
11
+ from ray.train.predictor import Predictor
12
+ from ray.util.annotations import PublicAPI
13
+
14
+ if TYPE_CHECKING:
15
+ from ray.data.preprocessor import Preprocessor
16
+
17
+
18
+ @PublicAPI(stability="beta")
19
+ class LightGBMPredictor(Predictor):
20
+ """A predictor for LightGBM models.
21
+
22
+ Args:
23
+ model: The LightGBM booster to use for predictions.
24
+ preprocessor: A preprocessor used to transform data batches prior
25
+ to prediction.
26
+ """
27
+
28
+ def __init__(
29
+ self, model: lightgbm.Booster, preprocessor: Optional["Preprocessor"] = None
30
+ ):
31
+ self.model = model
32
+ super().__init__(preprocessor)
33
+
34
+ def __repr__(self):
35
+ return (
36
+ f"{self.__class__.__name__}(model={self.model!r}, "
37
+ f"preprocessor={self._preprocessor!r})"
38
+ )
39
+
40
+ @classmethod
41
+ def from_checkpoint(cls, checkpoint: LightGBMCheckpoint) -> "LightGBMPredictor":
42
+ """Instantiate the predictor from a LightGBMCheckpoint.
43
+
44
+ Args:
45
+ checkpoint: The checkpoint to load the model and preprocessor from.
46
+
47
+ """
48
+ model = checkpoint.get_model()
49
+ preprocessor = checkpoint.get_preprocessor()
50
+ return cls(model=model, preprocessor=preprocessor)
51
+
52
+ def predict(
53
+ self,
54
+ data: DataBatchType,
55
+ feature_columns: Optional[Union[List[str], List[int]]] = None,
56
+ **predict_kwargs,
57
+ ) -> DataBatchType:
58
+ """Run inference on data batch.
59
+
60
+ Args:
61
+ data: A batch of input data.
62
+ feature_columns: The names or indices of the columns in the
63
+ data to use as features to predict on. If None, then use
64
+ all columns in ``data``.
65
+ **predict_kwargs: Keyword arguments passed to
66
+ ``lightgbm.Booster.predict``.
67
+
68
+ Examples:
69
+ >>> import numpy as np
70
+ >>> import lightgbm as lgbm
71
+ >>> from ray.train.lightgbm import LightGBMPredictor
72
+ >>>
73
+ >>> train_X = np.array([[1, 2], [3, 4]])
74
+ >>> train_y = np.array([0, 1])
75
+ >>>
76
+ >>> model = lgbm.LGBMClassifier().fit(train_X, train_y)
77
+ >>> predictor = LightGBMPredictor(model=model.booster_)
78
+ >>>
79
+ >>> data = np.array([[1, 2], [3, 4]])
80
+ >>> predictions = predictor.predict(data)
81
+ >>>
82
+ >>> # Only use first and second column as the feature
83
+ >>> data = np.array([[1, 2, 8], [3, 4, 9]])
84
+ >>> predictions = predictor.predict(data, feature_columns=[0, 1])
85
+
86
+ >>> import pandas as pd
87
+ >>> import lightgbm as lgbm
88
+ >>> from ray.train.lightgbm import LightGBMPredictor
89
+ >>>
90
+ >>> train_X = pd.DataFrame([[1, 2], [3, 4]], columns=["A", "B"])
91
+ >>> train_y = pd.Series([0, 1])
92
+ >>>
93
+ >>> model = lgbm.LGBMClassifier().fit(train_X, train_y)
94
+ >>> predictor = LightGBMPredictor(model=model.booster_)
95
+ >>>
96
+ >>> # Pandas dataframe.
97
+ >>> data = pd.DataFrame([[1, 2], [3, 4]], columns=["A", "B"])
98
+ >>> predictions = predictor.predict(data)
99
+ >>>
100
+ >>> # Only use first and second column as the feature
101
+ >>> data = pd.DataFrame([[1, 2, 8], [3, 4, 9]], columns=["A", "B", "C"])
102
+ >>> predictions = predictor.predict(data, feature_columns=["A", "B"])
103
+
104
+
105
+ Returns:
106
+ Prediction result.
107
+
108
+ """
109
+ return Predictor.predict(
110
+ self, data, feature_columns=feature_columns, **predict_kwargs
111
+ )
112
+
113
+ def _predict_pandas(
114
+ self,
115
+ data: "pd.DataFrame",
116
+ feature_columns: Optional[Union[List[str], List[int]]] = None,
117
+ **predict_kwargs,
118
+ ) -> pd.DataFrame:
119
+ feature_names = None
120
+ if TENSOR_COLUMN_NAME in data:
121
+ data = data[TENSOR_COLUMN_NAME].to_numpy()
122
+ data = _unwrap_ndarray_object_type_if_needed(data)
123
+ if feature_columns:
124
+ # In this case feature_columns is a list of integers
125
+ data = data[:, feature_columns]
126
+ # Turn into dataframe to make dtype resolution easy
127
+ data = pd.DataFrame(data, columns=feature_names)
128
+ data = data.infer_objects()
129
+
130
+ # Pandas does not detect categorical dtypes. Any remaining object
131
+ # dtypes are probably categories, so convert them.
132
+ # This will fail if we have a category composed entirely of
133
+ # integers, but this is the best we can do here.
134
+ update_dtypes = {}
135
+ for column in data.columns:
136
+ dtype = data.dtypes[column]
137
+ if is_object_dtype(dtype):
138
+ update_dtypes[column] = pd.CategoricalDtype()
139
+
140
+ if update_dtypes:
141
+ data = data.astype(update_dtypes, copy=False)
142
+ elif feature_columns:
143
+ # feature_columns is a list of integers or strings
144
+ data = data[feature_columns]
145
+
146
+ df = pd.DataFrame(self.model.predict(data, **predict_kwargs))
147
+ df.columns = (
148
+ ["predictions"]
149
+ if len(df.columns) == 1
150
+ else [f"predictions_{i}" for i in range(len(df.columns))]
151
+ )
152
+ return df
.venv/lib/python3.11/site-packages/ray/train/lightgbm/lightgbm_trainer.py ADDED
@@ -0,0 +1,221 @@
1
+ import logging
2
+ from functools import partial
3
+ from typing import Any, Dict, Optional
4
+
5
+ import lightgbm
6
+
7
+ import ray
8
+ from ray.train import Checkpoint
9
+ from ray.train.constants import _DEPRECATED_VALUE, TRAIN_DATASET_KEY
10
+ from ray.train.lightgbm import RayTrainReportCallback
11
+ from ray.train.lightgbm.v2 import LightGBMTrainer as SimpleLightGBMTrainer
12
+ from ray.train.trainer import GenDataset
13
+ from ray.util.annotations import PublicAPI
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
+ def _lightgbm_train_fn_per_worker(
19
+ config: dict,
20
+ label_column: str,
21
+ num_boost_round: int,
22
+ dataset_keys: set,
23
+ lightgbm_train_kwargs: dict,
24
+ ):
25
+ checkpoint = ray.train.get_checkpoint()
26
+ starting_model = None
27
+ remaining_iters = num_boost_round
28
+ if checkpoint:
29
+ starting_model = RayTrainReportCallback.get_model(checkpoint)
30
+ starting_iter = starting_model.current_iteration()
31
+ remaining_iters = num_boost_round - starting_iter
32
+ logger.info(
33
+ f"Model loaded from checkpoint will train for "
34
+ f"additional {remaining_iters} iterations (trees) in order "
35
+ "to achieve the target number of iterations "
36
+ f"({num_boost_round=})."
37
+ )
38
+
39
+ train_ds_iter = ray.train.get_dataset_shard(TRAIN_DATASET_KEY)
40
+ train_df = train_ds_iter.materialize().to_pandas()
41
+
42
+ eval_ds_iters = {
43
+ k: ray.train.get_dataset_shard(k)
44
+ for k in dataset_keys
45
+ if k != TRAIN_DATASET_KEY
46
+ }
47
+ eval_dfs = {k: d.materialize().to_pandas() for k, d in eval_ds_iters.items()}
48
+
49
+ train_X, train_y = train_df.drop(label_column, axis=1), train_df[label_column]
50
+ train_set = lightgbm.Dataset(train_X, label=train_y)
51
+
52
+ # NOTE: Include the training dataset in the evaluation datasets.
53
+ # This allows `train-*` metrics to be calculated and reported.
54
+ valid_sets = [train_set]
55
+ valid_names = [TRAIN_DATASET_KEY]
56
+
57
+ for eval_name, eval_df in eval_dfs.items():
58
+ eval_X, eval_y = eval_df.drop(label_column, axis=1), eval_df[label_column]
59
+ valid_sets.append(lightgbm.Dataset(eval_X, label=eval_y))
60
+ valid_names.append(eval_name)
61
+
62
+ # Add network params of the worker group to enable distributed training.
63
+ config.update(ray.train.lightgbm.v2.get_network_params())
64
+
65
+ lightgbm.train(
66
+ params=config,
67
+ train_set=train_set,
68
+ num_boost_round=remaining_iters,
69
+ valid_sets=valid_sets,
70
+ valid_names=valid_names,
71
+ init_model=starting_model,
72
+ **lightgbm_train_kwargs,
73
+ )
74
+
75
+
76
+ @PublicAPI(stability="beta")
77
+ class LightGBMTrainer(SimpleLightGBMTrainer):
78
+ """A Trainer for data parallel LightGBM training.
79
+
80
+ This Trainer runs the LightGBM training loop in a distributed manner
81
+ using multiple Ray Actors.
82
+
83
+ If you would like to take advantage of LightGBM's built-in handling
84
+ for features with the categorical data type, consider applying the
85
+ :class:`Categorizer` preprocessor to set the dtypes in the dataset.
86
+
87
+ .. note::
88
+ ``LightGBMTrainer`` does not modify or otherwise alter the working
89
+ of the LightGBM distributed training algorithm.
90
+ Ray only provides orchestration, data ingest and fault tolerance.
91
+ For more information on LightGBM distributed training, refer to
92
+ `LightGBM documentation <https://lightgbm.readthedocs.io/>`__.
93
+
94
+ Example:
95
+ .. testcode::
96
+
97
+ import ray
98
+
99
+ from ray.train.lightgbm import LightGBMTrainer
100
+ from ray.train import ScalingConfig
101
+
102
+ train_dataset = ray.data.from_items(
103
+ [{"x": x, "y": x + 1} for x in range(32)]
104
+ )
105
+ trainer = LightGBMTrainer(
106
+ label_column="y",
107
+ params={"objective": "regression"},
108
+ scaling_config=ScalingConfig(num_workers=3),
109
+ datasets={"train": train_dataset},
110
+ )
111
+ result = trainer.fit()
112
+
113
+ .. testoutput::
114
+ :hide:
115
+
116
+ ...
117
+
118
+ Args:
119
+ datasets: The Ray Datasets to use for training and validation. Must include a
120
+ "train" key denoting the training dataset. All non-training datasets will
121
+ be used as separate validation sets, each reporting a separate metric.
122
+ label_column: Name of the label column. A column with this name
123
+ must be present in the training dataset.
124
+ params: LightGBM training parameters passed to ``lightgbm.train()``.
125
+ Refer to `LightGBM documentation <https://lightgbm.readthedocs.io>`_
126
+ for a list of possible parameters.
127
+ num_boost_round: Target number of boosting iterations (trees in the model).
128
+ Note that unlike in ``lightgbm.train``, this is the target number
129
+ of trees, meaning that if you set ``num_boost_round=10`` and pass a model
130
+ that has already been trained for 5 iterations, it will be trained for 5
131
+ iterations more, instead of 10 more.
132
+ scaling_config: Configuration for how to scale data parallel training.
133
+ run_config: Configuration for the execution of the training run.
134
+ resume_from_checkpoint: A checkpoint to resume training from.
135
+ metadata: Dict that should be made available in `checkpoint.get_metadata()`
136
+ for checkpoints saved from this Trainer. Must be JSON-serializable.
137
+ **train_kwargs: Additional kwargs passed to ``lightgbm.train()`` function.
138
+ """
139
+
140
+ _handles_checkpoint_freq = True
141
+ _handles_checkpoint_at_end = True
142
+
143
+ def __init__(
144
+ self,
145
+ *,
146
+ datasets: Dict[str, GenDataset],
147
+ label_column: str,
148
+ params: Dict[str, Any],
149
+ num_boost_round: int = 10,
150
+ scaling_config: Optional[ray.train.ScalingConfig] = None,
151
+ run_config: Optional[ray.train.RunConfig] = None,
152
+ dataset_config: Optional[ray.train.DataConfig] = None,
153
+ resume_from_checkpoint: Optional[Checkpoint] = None,
154
+ metadata: Optional[Dict[str, Any]] = None,
155
+ dmatrix_params: Optional[Dict[str, Dict[str, Any]]] = _DEPRECATED_VALUE,
156
+ **train_kwargs,
157
+ ):
158
+ # TODO(justinvyu): [Deprecated] Remove in 2.11
159
+ if dmatrix_params != _DEPRECATED_VALUE:
160
+ raise DeprecationWarning(
161
+ "`dmatrix_params` is deprecated, since LightGBMTrainer no longer "
162
+ "depends on the `xgboost_ray.RayDMatrix` utility. "
163
+ "You can remove this argument and use `dataset_config` instead "
164
+ "to customize Ray Dataset ingestion."
165
+ )
166
+
167
+ # Initialize a default Ray Train metrics/checkpoint reporting callback if needed
168
+ callbacks = train_kwargs.get("callbacks", [])
169
+ user_supplied_callback = any(
170
+ isinstance(callback, RayTrainReportCallback) for callback in callbacks
171
+ )
172
+ callback_kwargs = {}
173
+ if run_config:
174
+ checkpoint_frequency = run_config.checkpoint_config.checkpoint_frequency
175
+ checkpoint_at_end = run_config.checkpoint_config.checkpoint_at_end
176
+
177
+ callback_kwargs["frequency"] = checkpoint_frequency
178
+ # Default `checkpoint_at_end=True` unless the user explicitly sets it.
179
+ callback_kwargs["checkpoint_at_end"] = (
180
+ checkpoint_at_end if checkpoint_at_end is not None else True
181
+ )
182
+
183
+ if not user_supplied_callback:
184
+ callbacks.append(RayTrainReportCallback(**callback_kwargs))
185
+ train_kwargs["callbacks"] = callbacks
186
+
187
+ train_fn_per_worker = partial(
188
+ _lightgbm_train_fn_per_worker,
189
+ label_column=label_column,
190
+ num_boost_round=num_boost_round,
191
+ dataset_keys=set(datasets),
192
+ lightgbm_train_kwargs=train_kwargs,
193
+ )
194
+
195
+ super(LightGBMTrainer, self).__init__(
196
+ train_loop_per_worker=train_fn_per_worker,
197
+ train_loop_config=params,
198
+ scaling_config=scaling_config,
199
+ run_config=run_config,
200
+ datasets=datasets,
201
+ dataset_config=dataset_config,
202
+ resume_from_checkpoint=resume_from_checkpoint,
203
+ metadata=metadata,
204
+ )
205
+
206
+ @classmethod
207
+ def get_model(
208
+ cls,
209
+ checkpoint: Checkpoint,
210
+ ) -> lightgbm.Booster:
211
+ """Retrieve the LightGBM model stored in this checkpoint."""
212
+ return RayTrainReportCallback.get_model(checkpoint)
213
+
214
+ def _validate_attributes(self):
215
+ super()._validate_attributes()
216
+
217
+ if TRAIN_DATASET_KEY not in self.datasets:
218
+ raise KeyError(
219
+ f"'{TRAIN_DATASET_KEY}' key must be present in `datasets`. "
220
+ f"Got {list(self.datasets.keys())}"
221
+ )
.venv/lib/python3.11/site-packages/ray/train/lightgbm/v2.py ADDED
@@ -0,0 +1,132 @@
1
+ import logging
2
+ from typing import Any, Callable, Dict, Optional, Union
3
+
4
+ import ray.train
5
+ from ray.train import Checkpoint
6
+ from ray.train.data_parallel_trainer import DataParallelTrainer
7
+ from ray.train.lightgbm.config import LightGBMConfig, get_network_params # noqa: F401
8
+ from ray.train.trainer import GenDataset
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+
13
+ class LightGBMTrainer(DataParallelTrainer):
14
+ """A Trainer for distributed data-parallel LightGBM training.
15
+
16
+ Example
17
+ -------
18
+
19
+ .. testcode::
20
+
21
+ import lightgbm as lgb
22
+
23
+ import ray.data
24
+ import ray.train
25
+ from ray.train.lightgbm import RayTrainReportCallback
26
+ from ray.train.lightgbm.v2 import LightGBMTrainer
27
+
28
+
29
+ def train_fn_per_worker(config: dict):
30
+ # (Optional) Add logic to resume training state from a checkpoint.
31
+ # ray.train.get_checkpoint()
32
+
33
+ # 1. Get the dataset shard for the worker and convert to a `lgb.Dataset`
34
+ train_ds_iter, eval_ds_iter = (
35
+ ray.train.get_dataset_shard("train"),
36
+ ray.train.get_dataset_shard("validation"),
37
+ )
38
+ train_ds, eval_ds = train_ds_iter.materialize(), eval_ds_iter.materialize()
39
+ train_df, eval_df = train_ds.to_pandas(), eval_ds.to_pandas()
40
+ train_X, train_y = train_df.drop("y", axis=1), train_df["y"]
41
+ eval_X, eval_y = eval_df.drop("y", axis=1), eval_df["y"]
42
+
43
+ train_set = lgb.Dataset(train_X, label=train_y)
44
+ eval_set = lgb.Dataset(eval_X, label=eval_y)
45
+
46
+ # 2. Run distributed data-parallel training.
47
+ # `get_network_params` sets up the necessary configurations for LightGBM
48
+ # to set up the data parallel training worker group on your Ray cluster.
49
+ params = {
50
+ "objective": "regression",
51
+ # Adding the line below is the only change needed
52
+ # for your `lgb.train` call!
53
+ **ray.train.lightgbm.v2.get_network_params(),
54
+ }
55
+ lgb.train(
56
+ params,
57
+ train_set,
58
+ valid_sets=[eval_set],
59
+ valid_names=["eval"],
60
+ callbacks=[RayTrainReportCallback()],
61
+ )
62
+
63
+ train_ds = ray.data.from_items([{"x": x, "y": x + 1} for x in range(32)])
64
+ eval_ds = ray.data.from_items(
65
+ [{"x": x, "y": x + 1} for x in range(32, 32 + 16)]
66
+ )
67
+ trainer = LightGBMTrainer(
68
+ train_fn_per_worker,
69
+ datasets={"train": train_ds, "validation": eval_ds},
70
+ scaling_config=ray.train.ScalingConfig(num_workers=4),
71
+ )
72
+ result = trainer.fit()
73
+ booster = RayTrainReportCallback.get_model(result.checkpoint)
74
+
75
+ .. testoutput::
76
+ :hide:
77
+
78
+ ...
79
+
80
+ Args:
81
+ train_loop_per_worker: The training function to execute on each worker.
82
+ This function can either take in zero arguments or a single ``Dict``
83
+ argument which is set by defining ``train_loop_config``.
84
+ Within this function you can use any of the
85
+ :ref:`Ray Train Loop utilities <train-loop-api>`.
86
+ train_loop_config: A configuration ``Dict`` to pass in as an argument to
87
+ ``train_loop_per_worker``.
88
+ This is typically used for specifying hyperparameters.
89
+ lightgbm_config: The configuration for setting up the distributed lightgbm
90
+ backend. See :class:`~ray.train.lightgbm.LightGBMConfig` for more info.
91
+ datasets: The Ray Datasets to use for training and validation.
92
+ dataset_config: The configuration for ingesting the input ``datasets``.
93
+ By default, all the Ray Dataset are split equally across workers.
94
+ See :class:`~ray.train.DataConfig` for more details.
95
+ scaling_config: The configuration for how to scale data parallel training.
96
+ ``num_workers`` determines how many Python processes are used for training,
97
+ and ``use_gpu`` determines whether or not each process should use GPUs.
98
+ See :class:`~ray.train.ScalingConfig` for more info.
99
+ run_config: The configuration for the execution of the training run.
100
+ See :class:`~ray.train.RunConfig` for more info.
101
+ resume_from_checkpoint: A checkpoint to resume training from.
102
+ This checkpoint can be accessed from within ``train_loop_per_worker``
103
+ by calling ``ray.train.get_checkpoint()``.
104
+ metadata: Dict that should be made available via
105
+ `ray.train.get_context().get_metadata()` and in `checkpoint.get_metadata()`
106
+ for checkpoints saved from this Trainer. Must be JSON-serializable.
107
+ """
108
+
109
+ def __init__(
110
+ self,
111
+ train_loop_per_worker: Union[Callable[[], None], Callable[[Dict], None]],
112
+ *,
113
+ train_loop_config: Optional[Dict] = None,
114
+ lightgbm_config: Optional[LightGBMConfig] = None,
115
+ scaling_config: Optional[ray.train.ScalingConfig] = None,
116
+ run_config: Optional[ray.train.RunConfig] = None,
117
+ datasets: Optional[Dict[str, GenDataset]] = None,
118
+ dataset_config: Optional[ray.train.DataConfig] = None,
119
+ metadata: Optional[Dict[str, Any]] = None,
120
+ resume_from_checkpoint: Optional[Checkpoint] = None,
121
+ ):
122
+ super(LightGBMTrainer, self).__init__(
123
+ train_loop_per_worker=train_loop_per_worker,
124
+ train_loop_config=train_loop_config,
125
+ backend_config=lightgbm_config or LightGBMConfig(),
126
+ scaling_config=scaling_config,
127
+ dataset_config=dataset_config,
128
+ run_config=run_config,
129
+ datasets=datasets,
130
+ resume_from_checkpoint=resume_from_checkpoint,
131
+ metadata=metadata,
132
+ )
.venv/lib/python3.11/site-packages/ray/train/predictor.py ADDED
@@ -0,0 +1,254 @@
1
+ import abc
2
+ from typing import Callable, Dict, Optional, Type, Union
3
+
4
+ import numpy as np
5
+ import pandas as pd
6
+
7
+ from ray.air.data_batch_type import DataBatchType
8
+ from ray.air.util.data_batch_conversion import (
9
+ BatchFormat,
10
+ _convert_batch_type_to_numpy,
11
+ _convert_batch_type_to_pandas,
12
+ )
13
+ from ray.data import Preprocessor
14
+ from ray.train import Checkpoint
15
+ from ray.util.annotations import DeveloperAPI, PublicAPI
16
+
17
+ try:
18
+ import pyarrow
19
+
20
+ pa_table = pyarrow.Table
21
+ except ImportError:
22
+ pa_table = None
23
+
24
+ # Reverse mapping from data batch type to batch format.
25
+ TYPE_TO_ENUM: Dict[Type[DataBatchType], BatchFormat] = {
26
+ np.ndarray: BatchFormat.NUMPY,
27
+ dict: BatchFormat.NUMPY,
28
+ pd.DataFrame: BatchFormat.PANDAS,
29
+ }
30
+
31
+
32
+ @PublicAPI(stability="beta")
33
+ class PredictorNotSerializableException(RuntimeError):
34
+ """Error raised when trying to serialize a Predictor instance."""
35
+
36
+ pass
37
+
38
+
39
+ @PublicAPI(stability="beta")
40
+ class Predictor(abc.ABC):
41
+ """Predictors load models from checkpoints to perform inference.
42
+
43
+ .. note::
44
+ The base ``Predictor`` class cannot be instantiated directly. Only one of
45
+ its subclasses can be used.
46
+
47
+ **How does a Predictor work?**
48
+
49
+ Predictors expose a ``predict`` method that accepts an input batch of type
50
+ ``DataBatchType`` and outputs predictions of the same type as the input batch.
51
+
52
+ When the ``predict`` method is called the following occurs:
53
+
54
+ - The input batch is converted into a pandas DataFrame. Tensor input (like a
55
+ ``np.ndarray``) will be converted into a single column Pandas Dataframe.
56
+ - If there is a :ref:`Preprocessor <preprocessor-ref>` saved in the provided
57
+ :class:`Checkpoint <ray.train.Checkpoint>`, the preprocessor will be used to
58
+ transform the DataFrame.
59
+ - The transformed DataFrame will be passed to the model for inference (via the
60
+ ``predictor._predict_pandas`` method).
61
+ - The predictions will be outputted by ``predict`` in the same type as the
62
+ original input.
63
+
64
+ **How do I create a new Predictor?**
65
+
66
+ To implement a new Predictor for your particular framework, you should subclass
67
+ the base ``Predictor`` and implement the following two methods:
68
+
69
+ 1. ``_predict_pandas``: Given a pandas.DataFrame input, return a
70
+ pandas.DataFrame containing predictions.
71
+ 2. ``from_checkpoint``: Logic for creating a Predictor from a
72
+ :class:`Checkpoint <ray.train.Checkpoint>`.
73
+ 3. Optionally ``_predict_numpy`` for better performance when working with
74
+ tensor data to avoid extra copies from Pandas conversions.
75
+ """
76
+
77
+ def __init__(self, preprocessor: Optional[Preprocessor] = None):
78
+ """Subclasses must call Predictor.__init__() to set a preprocessor."""
79
+ self._preprocessor: Optional[Preprocessor] = preprocessor
80
+ # Whether tensor columns should be automatically cast from/to the tensor
81
+ # extension type at UDF boundaries. This can be overridden by subclasses.
82
+ self._cast_tensor_columns = False
83
+
84
+ @classmethod
85
+ @abc.abstractmethod
86
+ def from_checkpoint(cls, checkpoint: Checkpoint, **kwargs) -> "Predictor":
87
+ """Create a specific predictor from a checkpoint.
88
+
89
+ Args:
90
+ checkpoint: Checkpoint to load predictor data from.
91
+ kwargs: Arguments specific to predictor implementations.
92
+
93
+ Returns:
94
+ Predictor: Predictor object.
95
+ """
96
+ raise NotImplementedError
97
+
98
+ @classmethod
99
+ def from_pandas_udf(
100
+ cls, pandas_udf: Callable[[pd.DataFrame], pd.DataFrame]
101
+ ) -> "Predictor":
102
+ """Create a Predictor from a Pandas UDF.
103
+
104
+ Args:
105
+ pandas_udf: A function that takes a pandas.DataFrame and other
106
+ optional kwargs and returns a pandas.DataFrame.
107
+ """
108
+
109
+ class PandasUDFPredictor(Predictor):
110
+ @classmethod
111
+ def from_checkpoint(cls, checkpoint: Checkpoint, **kwargs) -> "Predictor":
112
+ return PandasUDFPredictor()
113
+
114
+ def _predict_pandas(self, df, **kwargs) -> "pd.DataFrame":
115
+ return pandas_udf(df, **kwargs)
116
+
117
+ return PandasUDFPredictor()
118
+
119
+ def get_preprocessor(self) -> Optional[Preprocessor]:
120
+ """Get the preprocessor to use prior to executing predictions."""
121
+ return self._preprocessor
122
+
123
+ def set_preprocessor(self, preprocessor: Optional[Preprocessor]) -> None:
124
+ """Set the preprocessor to use prior to executing predictions."""
125
+ self._preprocessor = preprocessor
126
+
127
+ @classmethod
128
+ @DeveloperAPI
129
+ def preferred_batch_format(cls) -> BatchFormat:
130
+ """Batch format hint for upstream producers to try yielding best block format.
131
+
132
+ The preferred batch format to use if both `_predict_pandas` and
133
+ `_predict_numpy` are implemented. Defaults to Pandas.
134
+
135
+ Can be overridden by predictor classes depending on the framework type,
136
+ e.g. TorchPredictor prefers Numpy and XGBoostPredictor prefers Pandas as
137
+ native batch format.
138
+
139
+ """
140
+ return BatchFormat.PANDAS
141
+
142
+ @classmethod
143
+ def _batch_format_to_use(cls) -> BatchFormat:
144
+ """Determine the batch format to use for the predictor."""
145
+ has_pandas_implemented = cls._predict_pandas != Predictor._predict_pandas
146
+ has_numpy_implemented = cls._predict_numpy != Predictor._predict_numpy
147
+ if has_pandas_implemented and has_numpy_implemented:
148
+ return cls.preferred_batch_format()
149
+ elif has_pandas_implemented:
150
+ return BatchFormat.PANDAS
151
+ elif has_numpy_implemented:
152
+ return BatchFormat.NUMPY
153
+ else:
154
+ raise NotImplementedError(
155
+ f"Predictor {cls.__name__} must implement at least one of "
156
+ "`_predict_pandas` and `_predict_numpy`."
157
+ )
158
+
159
+ def _set_cast_tensor_columns(self):
160
+ """Enable automatic tensor column casting.
161
+
162
+ If this is called on a predictor, the predictor will cast tensor columns to
163
+ NumPy ndarrays in the input to the preprocessors and cast tensor columns back to
164
+ the tensor extension type in the prediction outputs.
165
+ """
166
+ self._cast_tensor_columns = True
167
+
168
+ def predict(self, data: DataBatchType, **kwargs) -> DataBatchType:
169
+ """Perform inference on a batch of data.
170
+
171
+ Args:
172
+ data: A batch of input data of type ``DataBatchType``.
173
+ kwargs: Arguments specific to predictor implementations. These are passed
174
+ directly to ``_predict_numpy`` or ``_predict_pandas``.
175
+
176
+ Returns:
177
+ DataBatchType:
178
+ Prediction result. The return type will be the same as the input type.
179
+ """
180
+ if not hasattr(self, "_preprocessor"):
181
+ raise NotImplementedError(
182
+ "Subclasses of Predictor must call Predictor.__init__(preprocessor)."
183
+ )
184
+ try:
185
+ batch_format = TYPE_TO_ENUM[type(data)]
186
+ except KeyError:
187
+ raise RuntimeError(
188
+ f"Invalid input data type of {type(data)}, supported "
189
+ f"types: {list(TYPE_TO_ENUM.keys())}"
190
+ )
191
+
192
+ if self._preprocessor:
193
+ data = self._preprocessor.transform_batch(data)
194
+
195
+ batch_format_to_use = self._batch_format_to_use()
196
+
197
+ # We can finish prediction as long as one predict method is implemented.
198
+ # For prediction, we have to return back in the same format as the input.
199
+ if batch_format == BatchFormat.PANDAS:
200
+ if batch_format_to_use == BatchFormat.PANDAS:
201
+ return self._predict_pandas(
202
+ _convert_batch_type_to_pandas(data), **kwargs
203
+ )
204
+ elif batch_format_to_use == BatchFormat.NUMPY:
205
+ return _convert_batch_type_to_pandas(
206
+ self._predict_numpy(_convert_batch_type_to_numpy(data), **kwargs)
207
+ )
208
+ elif batch_format == BatchFormat.NUMPY:
209
+ if batch_format_to_use == BatchFormat.PANDAS:
210
+ return _convert_batch_type_to_numpy(
211
+ self._predict_pandas(_convert_batch_type_to_pandas(data), **kwargs)
212
+ )
213
+ elif batch_format_to_use == BatchFormat.NUMPY:
214
+ return self._predict_numpy(_convert_batch_type_to_numpy(data), **kwargs)
215
+
216
+ @DeveloperAPI
217
+ def _predict_pandas(self, data: "pd.DataFrame", **kwargs) -> "pd.DataFrame":
218
+ """Perform inference on a Pandas DataFrame.
219
+
220
+ Args:
221
+ data: A pandas DataFrame to perform predictions on.
222
+ kwargs: Arguments specific to the predictor implementation.
223
+
224
+ Returns:
225
+ A pandas DataFrame containing the prediction result.
226
+
227
+ """
228
+ raise NotImplementedError
229
+
230
+ @DeveloperAPI
231
+ def _predict_numpy(
232
+ self, data: Union[np.ndarray, Dict[str, np.ndarray]], **kwargs
233
+ ) -> Union[np.ndarray, Dict[str, np.ndarray]]:
234
+ """Perform inference on Numpy data.
235
+
236
+ All Predictors working with tensor data (like deep learning predictors)
237
+ should implement this method.
238
+
239
+ Args:
240
+ data: A Numpy ndarray or dictionary of ndarrays to perform predictions on.
241
+ kwargs: Arguments specific to the predictor implementation.
242
+
243
+ Returns:
244
+ A Numpy ndarray or dictionary of ndarray containing the prediction result.
245
+
246
+ """
247
+ raise NotImplementedError
248
+
249
+ def __reduce__(self):
250
+ raise PredictorNotSerializableException(
251
+ "Predictor instances are not serializable. Instead, you may want "
252
+ "to serialize a checkpoint and initialize the Predictor with "
253
+ "Predictor.from_checkpoint."
254
+ )
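A minimal sketch of the subclassing contract described in the Predictor docstring above. MyPredictor and load_my_model are hypothetical placeholders, not part of this diff; the points illustrated are implementing from_checkpoint, implementing one of the _predict_* methods, and calling Predictor.__init__.

import pandas as pd

from ray.train import Checkpoint
from ray.train.predictor import Predictor


class MyPredictor(Predictor):
    def __init__(self, model, preprocessor=None):
        self.model = model  # hypothetical model object with a `.predict(df)` method
        super().__init__(preprocessor)  # required before `predict()` can be called

    @classmethod
    def from_checkpoint(cls, checkpoint: Checkpoint, **kwargs) -> "MyPredictor":
        with checkpoint.as_directory() as path:
            model = load_my_model(path)  # hypothetical loading helper
        return cls(model)

    def _predict_pandas(self, data: pd.DataFrame, **kwargs) -> pd.DataFrame:
        # Return a DataFrame, matching the convention used by LightGBMPredictor above.
        return pd.DataFrame({"predictions": self.model.predict(data)})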
.venv/lib/python3.11/site-packages/ray/train/session.py ADDED
File without changes
.venv/lib/python3.11/site-packages/ray/train/trainer.py ADDED
@@ -0,0 +1,194 @@
1
+ import logging
2
+ import traceback
3
+ from pathlib import Path
4
+ from typing import Any, Callable, Dict, List, Optional, TypeVar, Union
5
+
6
+ from ray.air._internal.util import (
7
+ StartTraceback,
8
+ StartTracebackWithWorkerRank,
9
+ skip_exceptions,
10
+ )
11
+ from ray.data import Dataset
12
+ from ray.train import Checkpoint, DataConfig
13
+ from ray.train._internal.backend_executor import (
14
+ BackendExecutor,
15
+ InactiveWorkerGroupError,
16
+ TrainBackendError,
17
+ TrainingWorkerError,
18
+ )
19
+ from ray.train._internal.session import _TrainingResult, _TrainSession, get_session
20
+ from ray.train._internal.utils import ActorWrapper
21
+ from ray.train.backend import BackendConfig
22
+ from ray.train.base_trainer import ( # noqa: F401
23
+ BaseTrainer,
24
+ GenDataset,
25
+ TrainingFailedError,
26
+ )
27
+ from ray.util.annotations import DeveloperAPI
28
+
29
+ T = TypeVar("T")
30
+ S = TypeVar("S")
31
+
32
+ logger = logging.getLogger(__name__)
33
+
34
+
35
+ @DeveloperAPI
36
+ class TrainingIterator:
37
+ """An iterator over Train results. Returned by ``trainer.run_iterator``."""
38
+
39
+ def __init__(
40
+ self,
41
+ backend_executor: Union[BackendExecutor, ActorWrapper],
42
+ backend_config: BackendConfig,
43
+ train_func: Union[Callable[[], T], Callable[[Dict[str, Any]], T]],
44
+ datasets: Dict[str, Dataset],
45
+ metadata: Dict[str, Any],
46
+ data_config: DataConfig,
47
+ checkpoint: Optional[Union[Dict, str, Path, Checkpoint]],
48
+ ):
49
+ self._backend_executor = backend_executor
50
+ self._backend = backend_config.backend_cls()
51
+ self._train_func = train_func
52
+ self._datasets = datasets
53
+ self._metadata = metadata
54
+ self._data_config = data_config
55
+
56
+ self._start_training(
57
+ train_func=train_func,
58
+ datasets=self._datasets,
59
+ metadata=self._metadata,
60
+ data_config=self._data_config,
61
+ checkpoint=checkpoint,
62
+ )
63
+
64
+ self._finished_training = False
65
+
66
+ def __iter__(self):
67
+ return self
68
+
69
+ def _start_training(
70
+ self,
71
+ train_func,
72
+ datasets,
73
+ metadata,
74
+ data_config,
75
+ checkpoint: Optional[Checkpoint] = None,
76
+ ):
77
+ tune_session: _TrainSession = get_session()
78
+ assert tune_session, "`_start_training` should only be called from within Tune"
79
+ storage = tune_session.storage
80
+
81
+ self._run_with_error_handling(
82
+ lambda: self._backend_executor.start_training(
83
+ train_func=train_func,
84
+ datasets=datasets,
85
+ metadata=metadata,
86
+ data_config=data_config,
87
+ storage=storage,
88
+ checkpoint=checkpoint,
89
+ )
90
+ )
91
+
92
+ def _run_with_error_handling(self, func: Callable):
93
+ try:
94
+ return func()
95
+ except TrainingWorkerError:
96
+ # TODO(ml-team): This Train fault-tolerance code doesn't get used
97
+ # since max_retries=0
98
+ # Workers have already been restarted.
99
+ logger.info(
100
+ "Workers have been successfully restarted. Resuming "
101
+ "training from latest checkpoint."
102
+ )
103
+ self._start_training(
104
+ self._train_func,
105
+ self._datasets,
106
+ self._metadata,
107
+ self._data_config,
108
+ )
109
+ return self._run_with_error_handling(func)
110
+ except InactiveWorkerGroupError:
111
+ raise RuntimeError(
112
+ "This Trainer is not active. It is either shutdown "
113
+ "already or never started in the first place. "
114
+ "Either create a new Trainer or start this one."
115
+ ) from None
116
+ except TrainBackendError:
117
+ raise RuntimeError(
118
+ "Training failed. You should not be seeing "
119
+ "this error and this is a bug. Please create "
120
+ "a new issue at "
121
+ "https://github.com/ray-project/ray."
122
+ ) from None
123
+
124
+ def __next__(self):
125
+ if self.is_finished():
126
+ self._backend_executor.report_final_run_status(errored=False)
127
+ raise StopIteration
128
+ try:
129
+ next_results = self._run_with_error_handling(self._fetch_next_result)
130
+ if next_results is None:
131
+ self._backend_executor.report_final_run_status(errored=False)
132
+ self._run_with_error_handling(self._finish_training)
133
+ self._finished_training = True
134
+ raise StopIteration
135
+ else:
136
+ return next_results
137
+ except StartTraceback as e:
138
+ # If this is a StartTraceback, then this is a user error.
139
+ # We raise it directly
140
+ if isinstance(e, StartTracebackWithWorkerRank):
141
+ failed_rank = e.worker_rank
142
+ else:
143
+ failed_rank = None
144
+
145
+ # Extract the stack trace from the exception
146
+ e = skip_exceptions(e)
147
+ stack_trace = "".join(
148
+ traceback.format_exception(type(e), e, e.__traceback__)
149
+ )
150
+
151
+ self._backend_executor.report_final_run_status(
152
+ errored=True, stack_trace=stack_trace, failed_rank=failed_rank
153
+ )
154
+ try:
155
+ # Exception raised in at least one training worker. Immediately raise
156
+ # this error to the user and do not attempt to terminate gracefully.
157
+ self._backend_executor.shutdown(graceful_termination=False)
158
+ self._finished_training = True
159
+ except Exception:
160
+ pass
161
+ raise
162
+
163
+ def _fetch_next_result(self) -> Optional[List[Dict]]:
164
+ """Fetch next results produced by ``session.report()`` from each worker.
165
+
166
+ Assumes ``start_training`` has already been called.
167
+
168
+ Returns:
169
+ A list of dictionaries of values passed to ``session.report()`` from
170
+ each worker. Each item corresponds to an intermediate result
171
+ a single worker. If there are no more items to fetch,
172
+ returns None.
173
+ """
174
+ results = self._backend_executor.get_next_results()
175
+ if results is None:
176
+ return None
177
+ assert all(isinstance(result, _TrainingResult) for result in results)
178
+ return results
179
+
180
+ def _finish_training(self):
181
+ """Finish training and return final results. Propagate any exceptions.
182
+
183
+ Blocks until training is finished on all workers.
184
+
185
+ Assumes `start_training` has already been called.
186
+
187
+ Returns:
188
+ A list of return values from calling ``train_func`` on each worker.
189
+ Each item corresponds to the return value from a single worker.
190
+ """
191
+ return self._backend_executor.finish_training()
192
+
193
+ def is_finished(self) -> bool:
194
+ return self._finished_training
.venv/lib/python3.11/site-packages/ray/train/utils.py ADDED
@@ -0,0 +1,19 @@
1
+ import warnings
2
+
3
+ from ray.util.annotations import RayDeprecationWarning
4
+
5
+
6
+ def _copy_doc(copy_func):
7
+ def wrapped(func):
8
+ func.__doc__ = copy_func.__doc__
9
+ return func
10
+
11
+ return wrapped
12
+
13
+
14
+ def _log_deprecation_warning(message):
15
+ warnings.warn(
16
+ message,
17
+ RayDeprecationWarning,
18
+ stacklevel=2,
19
+ )
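A quick illustration of the _copy_doc helper above; both functions are hypothetical.

def _original(x):
    """Add one to ``x``."""
    return x + 1

@_copy_doc(_original)
def _alias(x):
    return _original(x)

assert _alias.__doc__ == "Add one to ``x``."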
.venv/lib/python3.11/site-packages/ray/train/xgboost/__init__.py ADDED
@@ -0,0 +1,20 @@
1
+ from ray.train.v2._internal.constants import is_v2_enabled
2
+ from ray.train.xgboost._xgboost_utils import RayTrainReportCallback
3
+ from ray.train.xgboost.config import XGBoostConfig
4
+ from ray.train.xgboost.xgboost_checkpoint import XGBoostCheckpoint
5
+ from ray.train.xgboost.xgboost_predictor import XGBoostPredictor
6
+ from ray.train.xgboost.xgboost_trainer import XGBoostTrainer
7
+
8
+ if is_v2_enabled():
9
+ from ray.train.v2.xgboost.xgboost_trainer import XGBoostTrainer # noqa: F811
10
+
11
+ __all__ = [
12
+ "RayTrainReportCallback",
13
+ "XGBoostCheckpoint",
14
+ "XGBoostConfig",
15
+ "XGBoostPredictor",
16
+ "XGBoostTrainer",
17
+ ]
18
+
19
+
20
+ # DO NOT ADD ANYTHING AFTER THIS LINE.
.venv/lib/python3.11/site-packages/ray/train/xgboost/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (883 Bytes).
.venv/lib/python3.11/site-packages/ray/train/xgboost/__pycache__/_xgboost_utils.cpython-311.pyc ADDED
Binary file (10.5 kB).
.venv/lib/python3.11/site-packages/ray/train/xgboost/__pycache__/config.cpython-311.pyc ADDED
Binary file (10.7 kB).
.venv/lib/python3.11/site-packages/ray/train/xgboost/__pycache__/v2.cpython-311.pyc ADDED
Binary file (6.67 kB).
.venv/lib/python3.11/site-packages/ray/train/xgboost/__pycache__/xgboost_checkpoint.cpython-311.pyc ADDED
Binary file (4.11 kB).
.venv/lib/python3.11/site-packages/ray/train/xgboost/__pycache__/xgboost_predictor.cpython-311.pyc ADDED
Binary file (7.94 kB).
.venv/lib/python3.11/site-packages/ray/train/xgboost/__pycache__/xgboost_trainer.cpython-311.pyc ADDED
Binary file (11.1 kB).
.venv/lib/python3.11/site-packages/ray/train/xgboost/_xgboost_utils.py ADDED
@@ -0,0 +1,210 @@
1
+ import tempfile
2
+ from collections import OrderedDict
3
+ from contextlib import contextmanager
4
+ from pathlib import Path
5
+ from typing import Callable, Dict, List, Optional, Union
6
+
7
+ from xgboost.core import Booster
8
+
9
+ import ray.train
10
+ from ray.train import Checkpoint
11
+ from ray.tune.utils import flatten_dict
12
+ from ray.util.annotations import PublicAPI
13
+
14
+ try:
15
+ from xgboost.callback import TrainingCallback
16
+ except ImportError:
17
+
18
+ class TrainingCallback:
19
+ pass
20
+
21
+
22
+ class TuneCallback(TrainingCallback):
23
+ # TODO(justinvyu): [code_removal] Remove this after enforcing min xgboost version.
24
+ """Base class for Tune's XGBoost callbacks."""
25
+
26
+ def __call__(self, env):
27
+ """Compatibility with xgboost<1.3"""
28
+ return self.after_iteration(
29
+ env.model, env.iteration, env.evaluation_result_list
30
+ )
31
+
32
+ def after_iteration(self, model: Booster, epoch: int, evals_log: Dict):
33
+ raise NotImplementedError
34
+
35
+
36
+ @PublicAPI(stability="beta")
37
+ class RayTrainReportCallback(TuneCallback):
38
+ """XGBoost callback to save checkpoints and report metrics.
39
+
40
+ Args:
41
+ metrics: Metrics to report. If this is a list,
42
+ each item describes the metric key reported to XGBoost,
43
+ and it will be reported under the same name.
44
+ This can also be a dict of {<key-to-report>: <xgboost-metric-key>},
45
+ which can be used to rename xgboost default metrics.
46
+ filename: Customize the saved checkpoint file type by passing
47
+ a filename. Defaults to "model.ubj".
48
+ frequency: How often to save checkpoints, in terms of iterations.
49
+ Defaults to 0 (no checkpoints are saved during training).
50
+ checkpoint_at_end: Whether or not to save a checkpoint at the end of training.
51
+ results_postprocessing_fn: An optional Callable that takes in
52
+ the metrics dict that will be reported (after it has been flattened)
53
+ and returns a modified dict. For example, this can be used to
54
+ average results across CV folds when using ``xgboost.cv``.
55
+
56
+ Examples
57
+ --------
58
+
59
+ Reporting checkpoints and metrics to Ray Tune when running many
60
+ independent xgboost trials (without data parallelism within a trial).
61
+
62
+ .. testcode::
63
+ :skipif: True
64
+
65
+ import xgboost
66
+
67
+ from ray.tune import Tuner
68
+ from ray.train.xgboost import RayTrainReportCallback
69
+
70
+ def train_fn(config):
71
+ # Report log loss to Ray Tune after each validation epoch.
72
+ bst = xgboost.train(
73
+ ...,
74
+ callbacks=[
75
+ RayTrainReportCallback(
76
+ metrics={"loss": "eval-logloss"}, frequency=1
77
+ )
78
+ ],
79
+ )
80
+
81
+ tuner = Tuner(train_fn)
82
+ results = tuner.fit()
83
+
84
+ Loading a model from a checkpoint reported by this callback.
85
+
86
+ .. testcode::
87
+ :skipif: True
88
+
89
+ from ray.train.xgboost import RayTrainReportCallback
90
+
91
+ # Get a `Checkpoint` object that is saved by the callback during training.
92
+ result = trainer.fit()
93
+ booster = RayTrainReportCallback.get_model(result.checkpoint)
94
+
95
+ """
96
+
97
+ CHECKPOINT_NAME = "model.ubj"
98
+
99
+ def __init__(
100
+ self,
101
+ metrics: Optional[Union[str, List[str], Dict[str, str]]] = None,
102
+ filename: str = CHECKPOINT_NAME,
103
+ frequency: int = 0,
104
+ checkpoint_at_end: bool = True,
105
+ results_postprocessing_fn: Optional[
106
+ Callable[[Dict[str, Union[float, List[float]]]], Dict[str, float]]
107
+ ] = None,
108
+ ):
109
+ if isinstance(metrics, str):
110
+ metrics = [metrics]
111
+ self._metrics = metrics
112
+ self._filename = filename
113
+ self._frequency = frequency
114
+ self._checkpoint_at_end = checkpoint_at_end
115
+ self._results_postprocessing_fn = results_postprocessing_fn
116
+
117
+ # Keeps track of the eval metrics from the last iteration,
118
+ # so that the latest metrics can be reported with the checkpoint
119
+ # at the end of training.
120
+ self._evals_log = None
121
+ # Keep track of the last checkpoint iteration to avoid double-checkpointing
122
+ # when using `checkpoint_at_end=True`.
123
+ self._last_checkpoint_iteration = None
124
+
125
+ @classmethod
126
+ def get_model(
127
+ cls, checkpoint: Checkpoint, filename: str = CHECKPOINT_NAME
128
+ ) -> Booster:
129
+ """Retrieve the model stored in a checkpoint reported by this callback.
130
+
131
+ Args:
132
+ checkpoint: The checkpoint object returned by a training run.
133
+ The checkpoint should be saved by an instance of this callback.
134
+ filename: The filename to load the model from, which should match
135
+ the filename used when creating the callback.
136
+ """
137
+ with checkpoint.as_directory() as checkpoint_path:
138
+ booster = Booster()
139
+ booster.load_model(Path(checkpoint_path, filename).as_posix())
140
+ return booster
141
+
142
+ def _get_report_dict(self, evals_log):
143
+ if isinstance(evals_log, OrderedDict):
144
+ # xgboost>=1.3
145
+ result_dict = flatten_dict(evals_log, delimiter="-")
146
+ for k in list(result_dict):
147
+ result_dict[k] = result_dict[k][-1]
148
+ else:
149
+ # xgboost<1.3
150
+ result_dict = dict(evals_log)
151
+ if not self._metrics:
152
+ report_dict = result_dict
153
+ else:
154
+ report_dict = {}
155
+ for key in self._metrics:
156
+ if isinstance(self._metrics, dict):
157
+ metric = self._metrics[key]
158
+ else:
159
+ metric = key
160
+ report_dict[key] = result_dict[metric]
161
+
162
+ if self._results_postprocessing_fn:
163
+ report_dict = self._results_postprocessing_fn(report_dict)
164
+
165
+ return report_dict
166
+
167
+ @contextmanager
168
+ def _get_checkpoint(self, model: Booster) -> Optional[Checkpoint]:
169
+ # NOTE: The world rank returns None for Tune usage without Train.
170
+ if ray.train.get_context().get_world_rank() in (0, None):
171
+ with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
172
+ model.save_model(Path(temp_checkpoint_dir, self._filename).as_posix())
173
+ yield Checkpoint(temp_checkpoint_dir)
174
+ else:
175
+ yield None
176
+
177
+ def after_iteration(self, model: Booster, epoch: int, evals_log: Dict):
178
+ self._evals_log = evals_log
179
+
180
+ checkpointing_disabled = self._frequency == 0
181
+ # Ex: if frequency=2, checkpoint at epoch 1, 3, 5, ... (counting from 0)
182
+ should_checkpoint = (
183
+ not checkpointing_disabled and (epoch + 1) % self._frequency == 0
184
+ )
185
+
186
+ report_dict = self._get_report_dict(evals_log)
187
+ if should_checkpoint:
188
+ self._last_checkpoint_iteration = epoch
189
+ with self._get_checkpoint(model=model) as checkpoint:
190
+ ray.train.report(report_dict, checkpoint=checkpoint)
191
+ else:
192
+ ray.train.report(report_dict)
193
+
194
+ def after_training(self, model: Booster) -> Booster:
195
+ if not self._checkpoint_at_end:
196
+ return model
197
+
198
+ if (
199
+ self._last_checkpoint_iteration is not None
200
+ and model.num_boosted_rounds() - 1 == self._last_checkpoint_iteration
201
+ ):
202
+ # Avoids a duplicate checkpoint if the checkpoint frequency happens
203
+ # to align with the last iteration.
204
+ return model
205
+
206
+ report_dict = self._get_report_dict(self._evals_log) if self._evals_log else {}
207
+ with self._get_checkpoint(model=model) as checkpoint:
208
+ ray.train.report(report_dict, checkpoint=checkpoint)
209
+
210
+ return model
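A compact sketch of using the XGBoost callback above with both frequency and checkpoint_at_end. It has to run inside a Ray Train worker (or Tune trial) so that ray.train.report has an active session; the toy DMatrix data is illustrative only.

import numpy as np
import xgboost

from ray.train.xgboost import RayTrainReportCallback

def train_fn_per_worker(config: dict):
    # Toy data so the sketch is self-contained.
    rng = np.random.default_rng(0)
    X = rng.normal(size=(64, 4))
    y = rng.integers(0, 2, size=64)
    dtrain = xgboost.DMatrix(X[:48], label=y[:48])
    deval = xgboost.DMatrix(X[48:], label=y[48:])
    xgboost.train(
        {"objective": "binary:logistic"},
        dtrain,
        num_boost_round=10,
        evals=[(deval, "eval")],
        # Checkpoints at epochs 1, 3, 5, 7, 9; `after_training` then skips the
        # end-of-training checkpoint because epoch 9 was already checkpointed
        # (see `_last_checkpoint_iteration` above).
        callbacks=[RayTrainReportCallback(frequency=2, checkpoint_at_end=True)],
    )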
.venv/lib/python3.11/site-packages/ray/train/xgboost/config.py ADDED
@@ -0,0 +1,202 @@
1
+ import json
2
+ import logging
3
+ import os
4
+ import threading
5
+ from contextlib import contextmanager
6
+ from dataclasses import dataclass
7
+ from typing import Optional
8
+
9
+ import xgboost
10
+ from packaging.version import Version
11
+ from xgboost import RabitTracker
12
+ from xgboost.collective import CommunicatorContext
13
+
14
+ import ray
15
+ from ray.train._internal.worker_group import WorkerGroup
16
+ from ray.train.backend import Backend, BackendConfig
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+
21
+ @dataclass
22
+ class XGBoostConfig(BackendConfig):
23
+ """Configuration for xgboost collective communication setup.
24
+
25
+ Ray Train will set up the necessary coordinator processes and environment
26
+ variables for your workers to communicate with each other.
27
+ Additional configuration options can be passed into the
28
+ `xgboost.collective.CommunicatorContext` that wraps your own `xgboost.train` code.
29
+
30
+ See the `xgboost.collective` module for more information:
31
+ https://github.com/dmlc/xgboost/blob/master/python-package/xgboost/collective.py
32
+
33
+ Args:
34
+ xgboost_communicator: The backend to use for collective communication for
35
+ distributed xgboost training. For now, only "rabit" is supported.
36
+ """
37
+
38
+ xgboost_communicator: str = "rabit"
39
+
40
+ @property
41
+ def train_func_context(self):
42
+ @contextmanager
43
+ def collective_communication_context():
44
+ with CommunicatorContext(**_get_xgboost_args()):
45
+ yield
46
+
47
+ return collective_communication_context
48
+
49
+ @property
50
+ def backend_cls(self):
51
+ if self.xgboost_communicator == "rabit":
52
+ return (
53
+ _XGBoostRabitBackend
54
+ if Version(xgboost.__version__) >= Version("2.1.0")
55
+ else _XGBoostRabitBackend_pre_xgb210
56
+ )
57
+
58
+ raise NotImplementedError(f"Unsupported backend: {self.xgboost_communicator}")
59
+
60
+
61
+ class _XGBoostRabitBackend(Backend):
62
+ def __init__(self):
63
+ self._tracker: Optional[RabitTracker] = None
64
+ self._wait_thread: Optional[threading.Thread] = None
65
+
66
+ def _setup_xgboost_distributed_backend(self, worker_group: WorkerGroup):
67
+ # Set up the rabit tracker on the Train driver.
68
+ num_workers = len(worker_group)
69
+ rabit_args = {"n_workers": num_workers}
70
+ train_driver_ip = ray.util.get_node_ip_address()
71
+
72
+ # NOTE: sortby="task" is needed to ensure that the xgboost worker ranks
73
+ # align with Ray Train worker ranks.
74
+ # The worker ranks will be sorted by `dmlc_task_id`,
75
+ # which is defined below.
76
+ self._tracker = RabitTracker(
77
+ n_workers=num_workers, host_ip=train_driver_ip, sortby="task"
78
+ )
79
+ self._tracker.start()
80
+
81
+ # The RabitTracker is started in a separate thread, and the
82
+ # `wait_for` method must be called for `worker_args` to return.
83
+ self._wait_thread = threading.Thread(target=self._tracker.wait_for, daemon=True)
84
+ self._wait_thread.start()
85
+
86
+ rabit_args.update(self._tracker.worker_args())
87
+
88
+ start_log = (
89
+ "RabitTracker coordinator started with parameters:\n"
90
+ f"{json.dumps(rabit_args, indent=2)}"
91
+ )
92
+ logger.debug(start_log)
93
+
94
+ def set_xgboost_communicator_args(args):
95
+ import ray.train
96
+
97
+ args["dmlc_task_id"] = (
98
+ f"[xgboost.ray-rank={ray.train.get_context().get_world_rank():08}]:"
99
+ f"{ray.get_runtime_context().get_actor_id()}"
100
+ )
101
+
102
+ _set_xgboost_args(args)
103
+
104
+ worker_group.execute(set_xgboost_communicator_args, rabit_args)
105
+
106
+ def on_training_start(
107
+ self, worker_group: WorkerGroup, backend_config: XGBoostConfig
108
+ ):
109
+ assert backend_config.xgboost_communicator == "rabit"
110
+ self._setup_xgboost_distributed_backend(worker_group)
111
+
112
+ def on_shutdown(self, worker_group: WorkerGroup, backend_config: XGBoostConfig):
113
+ timeout = 5
114
+
115
+ if self._wait_thread is not None:
116
+ self._wait_thread.join(timeout=timeout)
117
+
118
+ if self._wait_thread.is_alive():
119
+ logger.warning(
120
+ "During shutdown, the RabitTracker thread failed to join "
121
+ f"within {timeout} seconds. "
122
+ "The process will still be terminated as part of Ray actor cleanup."
123
+ )
124
+
125
+
126
+ class _XGBoostRabitBackend_pre_xgb210(Backend):
127
+ def __init__(self):
128
+ self._tracker: Optional[RabitTracker] = None
129
+
130
+ def _setup_xgboost_distributed_backend(self, worker_group: WorkerGroup):
131
+ # Set up the rabit tracker on the Train driver.
132
+ num_workers = len(worker_group)
133
+ rabit_args = {"DMLC_NUM_WORKER": num_workers}
134
+ train_driver_ip = ray.util.get_node_ip_address()
135
+
136
+ # NOTE: sortby="task" is needed to ensure that the xgboost worker ranks
137
+ # align with Ray Train worker ranks.
138
+ # The worker ranks will be sorted by `DMLC_TASK_ID`,
139
+ # which is defined below.
140
+ self._tracker = RabitTracker(
141
+ n_workers=num_workers, host_ip=train_driver_ip, sortby="task"
142
+ )
143
+ self._tracker.start(n_workers=num_workers)
144
+
145
+ worker_args = self._tracker.worker_envs()
146
+ rabit_args.update(worker_args)
147
+
148
+ start_log = (
149
+ "RabitTracker coordinator started with parameters:\n"
150
+ f"{json.dumps(rabit_args, indent=2)}"
151
+ )
152
+ logger.debug(start_log)
153
+
154
+ def set_xgboost_env_vars():
155
+ import ray.train
156
+
157
+ for k, v in rabit_args.items():
158
+ os.environ[k] = str(v)
159
+
160
+ # Ranks are assigned in increasing order of the worker's task id.
161
+ # This task id will be sorted by increasing world rank.
162
+ os.environ["DMLC_TASK_ID"] = (
163
+ f"[xgboost.ray-rank={ray.train.get_context().get_world_rank():08}]:"
164
+ f"{ray.get_runtime_context().get_actor_id()}"
165
+ )
166
+
167
+ worker_group.execute(set_xgboost_env_vars)
168
+
169
+ def on_training_start(
170
+ self, worker_group: WorkerGroup, backend_config: XGBoostConfig
171
+ ):
172
+ assert backend_config.xgboost_communicator == "rabit"
173
+ self._setup_xgboost_distributed_backend(worker_group)
174
+
175
+ def on_shutdown(self, worker_group: WorkerGroup, backend_config: XGBoostConfig):
176
+ if not self._tracker:
177
+ return
178
+
179
+ timeout = 5
180
+ self._tracker.thread.join(timeout=timeout)
181
+
182
+ if self._tracker.thread.is_alive():
183
+ logger.warning(
184
+ "During shutdown, the RabitTracker thread failed to join "
185
+ f"within {timeout} seconds. "
186
+ "The process will still be terminated as part of Ray actor cleanup."
187
+ )
188
+
189
+
190
+ _xgboost_args: dict = {}
191
+ _xgboost_args_lock = threading.Lock()
192
+
193
+
194
+ def _set_xgboost_args(args):
195
+ with _xgboost_args_lock:
196
+ global _xgboost_args
197
+ _xgboost_args = args
198
+
199
+
200
+ def _get_xgboost_args() -> dict:
201
+ with _xgboost_args_lock:
202
+ return _xgboost_args
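
The config above only selects the communication backend; the per-worker training call is plain `xgboost.train` code that Ray Train wraps in the communicator context returned by `train_func_context`. Below is a minimal sketch of passing a non-default `XGBoostConfig` to the function-based trainer defined in v2.py further down; the worker function and toy data are illustrative only, and a running Ray cluster is assumed before calling `fit()`.

    import numpy as np
    import xgboost

    from ray.train import ScalingConfig
    from ray.train.xgboost import XGBoostConfig
    from ray.train.xgboost.v2 import XGBoostTrainer


    def train_fn_per_worker(config: dict):
        # Runs inside the CommunicatorContext set up by the rabit backend above,
        # so this xgboost.train call participates in the distributed job.
        X = np.random.rand(16, 4)
        y = np.random.randint(0, 2, size=16)
        xgboost.train(
            {"objective": "binary:logistic"},
            dtrain=xgboost.DMatrix(X, label=y),
            num_boost_round=2,
        )


    trainer = XGBoostTrainer(
        train_fn_per_worker,
        xgboost_config=XGBoostConfig(xgboost_communicator="rabit"),  # the default
        scaling_config=ScalingConfig(num_workers=2),
    )
    # result = trainer.fit()  # requires a running Ray cluster
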
.venv/lib/python3.11/site-packages/ray/train/xgboost/v2.py ADDED
@@ -0,0 +1,133 @@
1
+ import logging
2
+ from typing import Any, Callable, Dict, Optional, Union
3
+
4
+ import ray.train
5
+ from ray.train import Checkpoint
6
+ from ray.train.data_parallel_trainer import DataParallelTrainer
7
+ from ray.train.trainer import GenDataset
8
+ from ray.train.xgboost import XGBoostConfig
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+
13
+ class XGBoostTrainer(DataParallelTrainer):
14
+ """A Trainer for distributed data-parallel XGBoost training.
15
+
16
+ Example
17
+ -------
18
+
19
+ .. testcode::
20
+
21
+ import xgboost
22
+
23
+ import ray.data
24
+ import ray.train
25
+ from ray.train.xgboost import RayTrainReportCallback
26
+ from ray.train.xgboost.v2 import XGBoostTrainer
27
+
28
+ def train_fn_per_worker(config: dict):
29
+ # (Optional) Add logic to resume training state from a checkpoint.
30
+ # ray.train.get_checkpoint()
31
+
32
+ # 1. Get the dataset shard for the worker and convert to a `xgboost.DMatrix`
33
+ train_ds_iter, eval_ds_iter = (
34
+ ray.train.get_dataset_shard("train"),
35
+ ray.train.get_dataset_shard("validation"),
36
+ )
37
+ train_ds, eval_ds = train_ds_iter.materialize(), eval_ds_iter.materialize()
38
+
39
+ train_df, eval_df = train_ds.to_pandas(), eval_ds.to_pandas()
40
+ train_X, train_y = train_df.drop("y", axis=1), train_df["y"]
41
+ eval_X, eval_y = eval_df.drop("y", axis=1), eval_df["y"]
42
+
43
+ dtrain = xgboost.DMatrix(train_X, label=train_y)
44
+ deval = xgboost.DMatrix(eval_X, label=eval_y)
45
+
46
+ params = {
47
+ "tree_method": "approx",
48
+ "objective": "reg:squarederror",
49
+ "eta": 1e-4,
50
+ "subsample": 0.5,
51
+ "max_depth": 2,
52
+ }
53
+
54
+ # 2. Do distributed data-parallel training.
55
+ # Ray Train sets up the necessary coordinator processes and
56
+ # environment variables for your workers to communicate with each other.
57
+ bst = xgboost.train(
58
+ params,
59
+ dtrain=dtrain,
60
+ evals=[(deval, "validation")],
61
+ num_boost_round=10,
62
+ callbacks=[RayTrainReportCallback()],
63
+ )
64
+
65
+ train_ds = ray.data.from_items([{"x": x, "y": x + 1} for x in range(32)])
66
+ eval_ds = ray.data.from_items([{"x": x, "y": x + 1} for x in range(16)])
67
+ trainer = XGBoostTrainer(
68
+ train_fn_per_worker,
69
+ datasets={"train": train_ds, "validation": eval_ds},
70
+ scaling_config=ray.train.ScalingConfig(num_workers=4),
71
+ )
72
+ result = trainer.fit()
73
+ booster = RayTrainReportCallback.get_model(result.checkpoint)
74
+
75
+ .. testoutput::
76
+ :hide:
77
+
78
+ ...
79
+
80
+ Args:
81
+ train_loop_per_worker: The training function to execute on each worker.
82
+ This function can either take in zero arguments or a single ``Dict``
83
+ argument which is set by defining ``train_loop_config``.
84
+ Within this function you can use any of the
85
+ :ref:`Ray Train Loop utilities <train-loop-api>`.
86
+ train_loop_config: A configuration ``Dict`` to pass in as an argument to
87
+ ``train_loop_per_worker``.
88
+ This is typically used for specifying hyperparameters.
89
+ xgboost_config: The configuration for setting up the distributed xgboost
90
+ backend. Defaults to using the "rabit" backend.
91
+ See :class:`~ray.train.xgboost.XGBoostConfig` for more info.
92
+ datasets: The Ray Datasets to use for training and validation.
93
+ dataset_config: The configuration for ingesting the input ``datasets``.
94
+ By default, all the Ray Datasets are split equally across workers.
95
+ See :class:`~ray.train.DataConfig` for more details.
96
+ scaling_config: The configuration for how to scale data parallel training.
97
+ ``num_workers`` determines how many Python processes are used for training,
98
+ and ``use_gpu`` determines whether or not each process should use GPUs.
99
+ See :class:`~ray.train.ScalingConfig` for more info.
100
+ run_config: The configuration for the execution of the training run.
101
+ See :class:`~ray.train.RunConfig` for more info.
102
+ resume_from_checkpoint: A checkpoint to resume training from.
103
+ This checkpoint can be accessed from within ``train_loop_per_worker``
104
+ by calling ``ray.train.get_checkpoint()``.
105
+ metadata: Dict that should be made available via
106
+ `ray.train.get_context().get_metadata()` and in `checkpoint.get_metadata()`
107
+ for checkpoints saved from this Trainer. Must be JSON-serializable.
108
+ """
109
+
110
+ def __init__(
111
+ self,
112
+ train_loop_per_worker: Union[Callable[[], None], Callable[[Dict], None]],
113
+ *,
114
+ train_loop_config: Optional[Dict] = None,
115
+ xgboost_config: Optional[XGBoostConfig] = None,
116
+ scaling_config: Optional[ray.train.ScalingConfig] = None,
117
+ run_config: Optional[ray.train.RunConfig] = None,
118
+ datasets: Optional[Dict[str, GenDataset]] = None,
119
+ dataset_config: Optional[ray.train.DataConfig] = None,
120
+ metadata: Optional[Dict[str, Any]] = None,
121
+ resume_from_checkpoint: Optional[Checkpoint] = None,
122
+ ):
123
+ super(XGBoostTrainer, self).__init__(
124
+ train_loop_per_worker=train_loop_per_worker,
125
+ train_loop_config=train_loop_config,
126
+ backend_config=xgboost_config or XGBoostConfig(),
127
+ scaling_config=scaling_config,
128
+ dataset_config=dataset_config,
129
+ run_config=run_config,
130
+ datasets=datasets,
131
+ resume_from_checkpoint=resume_from_checkpoint,
132
+ metadata=metadata,
133
+ )
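
The docstring example above leaves checkpoint resumption as a comment. A minimal sketch of what that branch can look like inside `train_fn_per_worker`, mirroring what `_xgboost_train_fn_per_worker` in xgboost_trainer.py does further below; the in-memory toy data stands in for a real dataset shard and is an assumption of this sketch.

    import numpy as np
    import xgboost

    import ray.train
    from ray.train.xgboost import RayTrainReportCallback


    def train_fn_per_worker(config: dict):
        # Resume from the latest reported checkpoint, if any (e.g. after a
        # worker failure, or when `resume_from_checkpoint` was passed).
        checkpoint = ray.train.get_checkpoint()
        starting_model = None
        if checkpoint is not None:
            starting_model = RayTrainReportCallback.get_model(checkpoint)

        # Toy data; a real loop would read ray.train.get_dataset_shard("train").
        X = np.arange(32, dtype=float).reshape(-1, 1)
        y = X.squeeze() + 1.0
        dtrain = xgboost.DMatrix(X, label=y)

        xgboost.train(
            config,
            dtrain=dtrain,
            num_boost_round=10,
            xgb_model=starting_model,  # continue boosting instead of restarting
            callbacks=[RayTrainReportCallback()],
        )
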
.venv/lib/python3.11/site-packages/ray/train/xgboost/xgboost_checkpoint.py ADDED
@@ -0,0 +1,75 @@
1
+ import tempfile
2
+ from pathlib import Path
3
+ from typing import TYPE_CHECKING, Optional
4
+
5
+ import xgboost
6
+
7
+ from ray.train._internal.framework_checkpoint import FrameworkCheckpoint
8
+ from ray.util.annotations import PublicAPI
9
+
10
+ if TYPE_CHECKING:
11
+ from ray.data.preprocessor import Preprocessor
12
+
13
+
14
+ @PublicAPI(stability="beta")
15
+ class XGBoostCheckpoint(FrameworkCheckpoint):
16
+ """A :py:class:`~ray.train.Checkpoint` with XGBoost-specific functionality."""
17
+
18
+ MODEL_FILENAME = "model.json"
19
+
20
+ @classmethod
21
+ def from_model(
22
+ cls,
23
+ booster: xgboost.Booster,
24
+ *,
25
+ preprocessor: Optional["Preprocessor"] = None,
26
+ path: Optional[str] = None,
27
+ ) -> "XGBoostCheckpoint":
28
+ """Create a :py:class:`~ray.train.Checkpoint` that stores an XGBoost
29
+ model.
30
+
31
+ Args:
32
+ booster: The XGBoost model to store in the checkpoint.
33
+ preprocessor: A fitted preprocessor to be applied before inference.
34
+ path: The path to the directory where the checkpoint file will be saved.
35
+ This should start as an empty directory, since the *entire*
36
+ directory will be treated as the checkpoint when reported.
37
+ By default, a temporary directory will be created.
38
+
39
+ Returns:
40
+            An :py:class:`XGBoostCheckpoint` containing the specified ``Booster``.
41
+
42
+ Examples:
43
+
44
+            .. testcode::
45
+
46
+ import numpy as np
47
+ import ray
48
+ from ray.train.xgboost import XGBoostCheckpoint
49
+ import xgboost
50
+
51
+ train_X = np.array([[1, 2], [3, 4]])
52
+ train_y = np.array([0, 1])
53
+
54
+ model = xgboost.XGBClassifier().fit(train_X, train_y)
55
+ checkpoint = XGBoostCheckpoint.from_model(model.get_booster())
56
+
57
+ """
58
+ checkpoint_path = Path(path or tempfile.mkdtemp())
59
+
60
+ if not checkpoint_path.is_dir():
61
+ raise ValueError(f"`path` must be a directory, but got: {checkpoint_path}")
62
+
63
+ booster.save_model(checkpoint_path.joinpath(cls.MODEL_FILENAME).as_posix())
64
+
65
+ checkpoint = cls.from_directory(checkpoint_path.as_posix())
66
+ if preprocessor:
67
+ checkpoint.set_preprocessor(preprocessor)
68
+ return checkpoint
69
+
70
+ def get_model(self) -> xgboost.Booster:
71
+ """Retrieve the XGBoost model stored in this checkpoint."""
72
+ with self.as_directory() as checkpoint_path:
73
+ booster = xgboost.Booster()
74
+ booster.load_model(Path(checkpoint_path, self.MODEL_FILENAME).as_posix())
75
+ return booster
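
Putting `from_model` and `get_model` together, a short sketch of the round trip; the classifier and toy data are illustrative assumptions, not part of the file above.

    import numpy as np
    import xgboost

    from ray.train.xgboost import XGBoostCheckpoint

    train_X = np.array([[1, 2], [3, 4]])
    train_y = np.array([0, 1])
    booster = xgboost.XGBClassifier(n_estimators=2).fit(train_X, train_y).get_booster()

    # Writes model.json into a temporary checkpoint directory.
    checkpoint = XGBoostCheckpoint.from_model(booster)

    # Later (possibly in another process), restore the booster for inference.
    restored = checkpoint.get_model()
    predictions = restored.predict(xgboost.DMatrix(train_X))
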
.venv/lib/python3.11/site-packages/ray/train/xgboost/xgboost_predictor.py ADDED
@@ -0,0 +1,160 @@
1
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
2
+
3
+ import pandas as pd
4
+ import xgboost
5
+
6
+ from ray.air.constants import TENSOR_COLUMN_NAME
7
+ from ray.air.data_batch_type import DataBatchType
8
+ from ray.air.util.data_batch_conversion import _unwrap_ndarray_object_type_if_needed
9
+ from ray.train.predictor import Predictor
10
+ from ray.train.xgboost import XGBoostCheckpoint
11
+ from ray.util.annotations import PublicAPI
12
+
13
+ if TYPE_CHECKING:
14
+ from ray.data.preprocessor import Preprocessor
15
+
16
+
17
+ @PublicAPI(stability="beta")
18
+ class XGBoostPredictor(Predictor):
19
+ """A predictor for XGBoost models.
20
+
21
+ Args:
22
+ model: The XGBoost booster to use for predictions.
23
+ preprocessor: A preprocessor used to transform data batches prior
24
+ to prediction.
25
+ """
26
+
27
+ def __init__(
28
+ self, model: xgboost.Booster, preprocessor: Optional["Preprocessor"] = None
29
+ ):
30
+ self.model = model
31
+ super().__init__(preprocessor)
32
+
33
+ def __repr__(self):
34
+ return (
35
+ f"{self.__class__.__name__}(model={self.model!r}, "
36
+ f"preprocessor={self._preprocessor!r})"
37
+ )
38
+
39
+ @classmethod
40
+ def from_checkpoint(cls, checkpoint: XGBoostCheckpoint) -> "XGBoostPredictor":
41
+ """Instantiate the predictor from a Checkpoint.
42
+
43
+ This is a helper constructor that instantiates the predictor from a
44
+ framework-specific XGBoost checkpoint.
45
+
46
+ Args:
47
+ checkpoint: The checkpoint to load the model and preprocessor from.
48
+
49
+ """
50
+ model = checkpoint.get_model()
51
+ preprocessor = checkpoint.get_preprocessor()
52
+ return cls(model=model, preprocessor=preprocessor)
53
+
54
+ def predict(
55
+ self,
56
+ data: DataBatchType,
57
+ feature_columns: Optional[Union[List[str], List[int]]] = None,
58
+ dmatrix_kwargs: Optional[Dict[str, Any]] = None,
59
+ **predict_kwargs,
60
+ ) -> DataBatchType:
61
+ """Run inference on data batch.
62
+
63
+ The data is converted into an XGBoost DMatrix before being inputted to
64
+ the model.
65
+
66
+ Args:
67
+ data: A batch of input data.
68
+ feature_columns: The names or indices of the columns in the
69
+ data to use as features to predict on. If None, then use
70
+ all columns in ``data``.
71
+ dmatrix_kwargs: Dict of keyword arguments passed to ``xgboost.DMatrix``.
72
+ **predict_kwargs: Keyword arguments passed to ``xgboost.Booster.predict``.
73
+
74
+
75
+ Examples:
76
+
77
+ .. testcode::
78
+
79
+ import numpy as np
80
+ import xgboost as xgb
81
+ from ray.train.xgboost import XGBoostPredictor
82
+ train_X = np.array([[1, 2], [3, 4]])
83
+ train_y = np.array([0, 1])
84
+ model = xgb.XGBClassifier().fit(train_X, train_y)
85
+ predictor = XGBoostPredictor(model=model.get_booster())
86
+ data = np.array([[1, 2], [3, 4]])
87
+ predictions = predictor.predict(data)
88
+ # Only use first and second column as the feature
89
+                # Only use the first and second columns as features
90
+ predictions = predictor.predict(data, feature_columns=[0, 1])
91
+
92
+ .. testcode::
93
+
94
+ import pandas as pd
95
+ import xgboost as xgb
96
+ from ray.train.xgboost import XGBoostPredictor
97
+ train_X = pd.DataFrame([[1, 2], [3, 4]], columns=["A", "B"])
98
+ train_y = pd.Series([0, 1])
99
+ model = xgb.XGBClassifier().fit(train_X, train_y)
100
+ predictor = XGBoostPredictor(model=model.get_booster())
101
+ # Pandas dataframe.
102
+ data = pd.DataFrame([[1, 2], [3, 4]], columns=["A", "B"])
103
+ predictions = predictor.predict(data)
104
+                # Only use the first and second columns as features
105
+ data = pd.DataFrame([[1, 2, 8], [3, 4, 9]], columns=["A", "B", "C"])
106
+ predictions = predictor.predict(data, feature_columns=["A", "B"])
107
+
108
+
109
+ Returns:
110
+ Prediction result.
111
+
112
+ """
113
+ return Predictor.predict(
114
+ self,
115
+ data,
116
+ feature_columns=feature_columns,
117
+ dmatrix_kwargs=dmatrix_kwargs,
118
+ **predict_kwargs,
119
+ )
120
+
121
+ def _predict_pandas(
122
+ self,
123
+ data: "pd.DataFrame",
124
+ feature_columns: Optional[Union[List[str], List[int]]] = None,
125
+ dmatrix_kwargs: Optional[Dict[str, Any]] = None,
126
+ **predict_kwargs,
127
+ ) -> "pd.DataFrame":
128
+ dmatrix_kwargs = dmatrix_kwargs or {}
129
+
130
+ feature_names = None
131
+ if TENSOR_COLUMN_NAME in data:
132
+ data = data[TENSOR_COLUMN_NAME].to_numpy()
133
+ data = _unwrap_ndarray_object_type_if_needed(data)
134
+ if feature_columns:
135
+ # In this case feature_columns is a list of integers
136
+ data = data[:, feature_columns]
137
+ elif feature_columns:
138
+ # feature_columns is a list of integers or strings
139
+ data = data[feature_columns].to_numpy()
140
+ # Only set the feature names if they are strings
141
+ if all(isinstance(fc, str) for fc in feature_columns):
142
+ feature_names = feature_columns
143
+ else:
144
+ feature_columns = data.columns.tolist()
145
+ data = data.to_numpy()
146
+
147
+ if all(isinstance(fc, str) for fc in feature_columns):
148
+ feature_names = feature_columns
149
+
150
+ if feature_names:
151
+ dmatrix_kwargs["feature_names"] = feature_names
152
+
153
+ matrix = xgboost.DMatrix(data, **dmatrix_kwargs)
154
+ df = pd.DataFrame(self.model.predict(matrix, **predict_kwargs))
155
+ df.columns = (
156
+ ["predictions"]
157
+ if len(df.columns) == 1
158
+ else [f"predictions_{i}" for i in range(len(df.columns))]
159
+ )
160
+ return df
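
The predictor composes with `XGBoostCheckpoint` through `from_checkpoint`. A minimal sketch with toy data (no preprocessor is attached here, so only the booster is restored); the classifier and inputs are illustrative assumptions.

    import numpy as np
    import xgboost

    from ray.train.xgboost import XGBoostCheckpoint, XGBoostPredictor

    train_X = np.array([[1, 2], [3, 4]])
    train_y = np.array([0, 1])
    booster = xgboost.XGBClassifier(n_estimators=2).fit(train_X, train_y).get_booster()

    checkpoint = XGBoostCheckpoint.from_model(booster)

    # Restores the booster (and a preprocessor, if one had been stored).
    predictor = XGBoostPredictor.from_checkpoint(checkpoint)
    predictions = predictor.predict(np.array([[1, 2], [3, 4]]))
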
.venv/lib/python3.11/site-packages/ray/train/xgboost/xgboost_trainer.py ADDED
@@ -0,0 +1,222 @@
1
+ import logging
2
+ from functools import partial
3
+ from typing import Any, Dict, Optional
4
+
5
+ import xgboost
6
+ from packaging.version import Version
7
+
8
+ import ray.train
9
+ from ray.train import Checkpoint
10
+ from ray.train.constants import _DEPRECATED_VALUE, TRAIN_DATASET_KEY
11
+ from ray.train.trainer import GenDataset
12
+ from ray.train.xgboost import RayTrainReportCallback
13
+ from ray.train.xgboost.v2 import XGBoostTrainer as SimpleXGBoostTrainer
14
+ from ray.util.annotations import PublicAPI
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
+ def _xgboost_train_fn_per_worker(
20
+ config: dict,
21
+ label_column: str,
22
+ num_boost_round: int,
23
+ dataset_keys: set,
24
+ xgboost_train_kwargs: dict,
25
+ ):
26
+ checkpoint = ray.train.get_checkpoint()
27
+ starting_model = None
28
+ remaining_iters = num_boost_round
29
+ if checkpoint:
30
+ starting_model = RayTrainReportCallback.get_model(checkpoint)
31
+ starting_iter = starting_model.num_boosted_rounds()
32
+ remaining_iters = num_boost_round - starting_iter
33
+ logger.info(
34
+ f"Model loaded from checkpoint will train for "
35
+ f"additional {remaining_iters} iterations (trees) in order "
36
+ "to achieve the target number of iterations "
37
+ f"({num_boost_round=})."
38
+ )
39
+
40
+ train_ds_iter = ray.train.get_dataset_shard(TRAIN_DATASET_KEY)
41
+ train_df = train_ds_iter.materialize().to_pandas()
42
+
43
+ eval_ds_iters = {
44
+ k: ray.train.get_dataset_shard(k)
45
+ for k in dataset_keys
46
+ if k != TRAIN_DATASET_KEY
47
+ }
48
+ eval_dfs = {k: d.materialize().to_pandas() for k, d in eval_ds_iters.items()}
49
+
50
+ train_X, train_y = train_df.drop(label_column, axis=1), train_df[label_column]
51
+ dtrain = xgboost.DMatrix(train_X, label=train_y)
52
+
53
+ # NOTE: Include the training dataset in the evaluation datasets.
54
+ # This allows `train-*` metrics to be calculated and reported.
55
+ evals = [(dtrain, TRAIN_DATASET_KEY)]
56
+
57
+ for eval_name, eval_df in eval_dfs.items():
58
+ eval_X, eval_y = eval_df.drop(label_column, axis=1), eval_df[label_column]
59
+ evals.append((xgboost.DMatrix(eval_X, label=eval_y), eval_name))
60
+
61
+ evals_result = {}
62
+ xgboost.train(
63
+ config,
64
+ dtrain=dtrain,
65
+ evals=evals,
66
+ evals_result=evals_result,
67
+ num_boost_round=remaining_iters,
68
+ xgb_model=starting_model,
69
+ **xgboost_train_kwargs,
70
+ )
71
+
72
+
73
+ @PublicAPI(stability="beta")
74
+ class XGBoostTrainer(SimpleXGBoostTrainer):
75
+ """A Trainer for data parallel XGBoost training.
76
+
77
+ This Trainer runs the XGBoost training loop in a distributed manner
78
+ using multiple Ray Actors.
79
+
80
+ .. note::
81
+ ``XGBoostTrainer`` does not modify or otherwise alter the working
82
+ of the XGBoost distributed training algorithm.
83
+ Ray only provides orchestration, data ingest and fault tolerance.
84
+ For more information on XGBoost distributed training, refer to
85
+ `XGBoost documentation <https://xgboost.readthedocs.io>`__.
86
+
87
+ Example:
88
+ .. testcode::
89
+
90
+ import ray
91
+
92
+ from ray.train.xgboost import XGBoostTrainer
93
+ from ray.train import ScalingConfig
94
+
95
+ train_dataset = ray.data.from_items(
96
+ [{"x": x, "y": x + 1} for x in range(32)])
97
+ trainer = XGBoostTrainer(
98
+ label_column="y",
99
+ params={"objective": "reg:squarederror"},
100
+ scaling_config=ScalingConfig(num_workers=3),
101
+ datasets={"train": train_dataset},
102
+ )
103
+ result = trainer.fit()
104
+
105
+ .. testoutput::
106
+ :hide:
107
+
108
+ ...
109
+
110
+ Args:
111
+ datasets: The Ray Datasets to use for training and validation. Must include a
112
+ "train" key denoting the training dataset. All non-training datasets will
113
+ be used as separate validation sets, each reporting a separate metric.
114
+ label_column: Name of the label column. A column with this name
115
+ must be present in the training dataset.
116
+ params: XGBoost training parameters.
117
+ Refer to `XGBoost documentation <https://xgboost.readthedocs.io/>`_
118
+ for a list of possible parameters.
119
+ num_boost_round: Target number of boosting iterations (trees in the model).
120
+ Note that unlike in ``xgboost.train``, this is the target number
121
+ of trees, meaning that if you set ``num_boost_round=10`` and pass a model
122
+ that has already been trained for 5 iterations, it will be trained for 5
123
+ iterations more, instead of 10 more.
124
+ scaling_config: Configuration for how to scale data parallel training.
125
+ run_config: Configuration for the execution of the training run.
126
+ dataset_config: The configuration for ingesting the input ``datasets``.
127
+ By default, all the Ray Datasets are split equally across workers.
128
+ See :class:`~ray.train.DataConfig` for more details.
129
+ resume_from_checkpoint: A checkpoint to resume training from.
130
+ metadata: Dict that should be made available in `checkpoint.get_metadata()`
131
+ for checkpoints saved from this Trainer. Must be JSON-serializable.
132
+ **train_kwargs: Additional kwargs passed to ``xgboost.train()`` function.
133
+ """
134
+
135
+ _handles_checkpoint_freq = True
136
+ _handles_checkpoint_at_end = True
137
+
138
+ def __init__(
139
+ self,
140
+ *,
141
+ datasets: Dict[str, GenDataset],
142
+ label_column: str,
143
+ params: Dict[str, Any],
144
+ dmatrix_params: Optional[Dict[str, Dict[str, Any]]] = _DEPRECATED_VALUE,
145
+ num_boost_round: int = 10,
146
+ scaling_config: Optional[ray.train.ScalingConfig] = None,
147
+ run_config: Optional[ray.train.RunConfig] = None,
148
+ dataset_config: Optional[ray.train.DataConfig] = None,
149
+ resume_from_checkpoint: Optional[Checkpoint] = None,
150
+ metadata: Optional[Dict[str, Any]] = None,
151
+ **train_kwargs,
152
+ ):
153
+ if Version(xgboost.__version__) < Version("1.7.0"):
154
+ raise ImportError(
155
+ "`XGBoostTrainer` requires the `xgboost` version to be >= 1.7.0. "
156
+ 'Upgrade with: `pip install -U "xgboost>=1.7"`'
157
+ )
158
+
159
+ # TODO(justinvyu): [Deprecated] Remove in 2.11
160
+ if dmatrix_params != _DEPRECATED_VALUE:
161
+ raise DeprecationWarning(
162
+ "`dmatrix_params` is deprecated, since XGBoostTrainer no longer "
163
+ "depends on the `xgboost_ray.RayDMatrix` utility. "
164
+ "You can remove this argument and use `dataset_config` instead "
165
+ "to customize Ray Dataset ingestion."
166
+ )
167
+
168
+ # Initialize a default Ray Train metrics/checkpoint reporting callback if needed
169
+ callbacks = train_kwargs.get("callbacks", [])
170
+ user_supplied_callback = any(
171
+ isinstance(callback, RayTrainReportCallback) for callback in callbacks
172
+ )
173
+ callback_kwargs = {}
174
+ if run_config:
175
+ checkpoint_frequency = run_config.checkpoint_config.checkpoint_frequency
176
+ checkpoint_at_end = run_config.checkpoint_config.checkpoint_at_end
177
+
178
+ callback_kwargs["frequency"] = checkpoint_frequency
179
+ # Default `checkpoint_at_end=True` unless the user explicitly sets it.
180
+ callback_kwargs["checkpoint_at_end"] = (
181
+ checkpoint_at_end if checkpoint_at_end is not None else True
182
+ )
183
+
184
+ if not user_supplied_callback:
185
+ callbacks.append(RayTrainReportCallback(**callback_kwargs))
186
+ train_kwargs["callbacks"] = callbacks
187
+
188
+ train_fn_per_worker = partial(
189
+ _xgboost_train_fn_per_worker,
190
+ label_column=label_column,
191
+ num_boost_round=num_boost_round,
192
+ dataset_keys=set(datasets),
193
+ xgboost_train_kwargs=train_kwargs,
194
+ )
195
+
196
+ super(XGBoostTrainer, self).__init__(
197
+ train_loop_per_worker=train_fn_per_worker,
198
+ train_loop_config=params,
199
+ scaling_config=scaling_config,
200
+ run_config=run_config,
201
+ datasets=datasets,
202
+ dataset_config=dataset_config,
203
+ resume_from_checkpoint=resume_from_checkpoint,
204
+ metadata=metadata,
205
+ )
206
+
207
+ @classmethod
208
+ def get_model(
209
+ cls,
210
+ checkpoint: Checkpoint,
211
+ ) -> xgboost.Booster:
212
+ """Retrieve the XGBoost model stored in this checkpoint."""
213
+ return RayTrainReportCallback.get_model(checkpoint)
214
+
215
+ def _validate_attributes(self):
216
+ super()._validate_attributes()
217
+
218
+ if TRAIN_DATASET_KEY not in self.datasets:
219
+ raise KeyError(
220
+                f"'{TRAIN_DATASET_KEY}' key must be present in `datasets`. "
221
+ f"Got {list(self.datasets.keys())}"
222
+ )
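
An end-to-end sketch of the high-level trainer above, including periodic checkpointing and model retrieval; the dataset, parameters, and worker count are illustrative, and a running Ray cluster is assumed.

    import ray
    from ray.train import CheckpointConfig, RunConfig, ScalingConfig
    from ray.train.xgboost import XGBoostTrainer

    train_ds = ray.data.from_items([{"x": float(x), "y": float(x + 1)} for x in range(32)])

    trainer = XGBoostTrainer(
        datasets={"train": train_ds},
        label_column="y",
        params={"objective": "reg:squarederror"},
        num_boost_round=10,  # target *total* number of trees, even when resuming
        scaling_config=ScalingConfig(num_workers=2),
        run_config=RunConfig(
            checkpoint_config=CheckpointConfig(checkpoint_frequency=5),
        ),
    )
    result = trainer.fit()

    # The final reported checkpoint holds the trained booster.
    booster = XGBoostTrainer.get_model(result.checkpoint)
    print(booster.num_boosted_rounds())
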
.venv/lib/python3.11/site-packages/ray/tune/search/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (5.87 kB).