koichi12 committed on
Commit e8a93e7 · verified · 1 Parent(s): de7cd93

Add files using upload-large-folder tool

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. .venv/lib/python3.11/site-packages/ray/_private/__init__.py +0 -0
  2. .venv/lib/python3.11/site-packages/ray/_private/arrow_serialization.py +816 -0
  3. .venv/lib/python3.11/site-packages/ray/_private/async_compat.py +52 -0
  4. .venv/lib/python3.11/site-packages/ray/_private/async_utils.py +52 -0
  5. .venv/lib/python3.11/site-packages/ray/_private/auto_init_hook.py +31 -0
  6. .venv/lib/python3.11/site-packages/ray/_private/client_mode_hook.py +184 -0
  7. .venv/lib/python3.11/site-packages/ray/_private/collections_utils.py +10 -0
  8. .venv/lib/python3.11/site-packages/ray/_private/compat.py +40 -0
  9. .venv/lib/python3.11/site-packages/ray/_private/conftest_utils.py +14 -0
  10. .venv/lib/python3.11/site-packages/ray/_private/dict.py +247 -0
  11. .venv/lib/python3.11/site-packages/ray/_private/external_storage.py +707 -0
  12. .venv/lib/python3.11/site-packages/ray/_private/function_manager.py +706 -0
  13. .venv/lib/python3.11/site-packages/ray/_private/gcs_aio_client.py +47 -0
  14. .venv/lib/python3.11/site-packages/ray/_private/gcs_pubsub.py +311 -0
  15. .venv/lib/python3.11/site-packages/ray/_private/gcs_utils.py +163 -0
  16. .venv/lib/python3.11/site-packages/ray/_private/inspect_util.py +49 -0
  17. .venv/lib/python3.11/site-packages/ray/_private/internal_api.py +255 -0
  18. .venv/lib/python3.11/site-packages/ray/_private/log.py +117 -0
  19. .venv/lib/python3.11/site-packages/ray/_private/log_monitor.py +581 -0
  20. .venv/lib/python3.11/site-packages/ray/_private/logging_utils.py +29 -0
  21. .venv/lib/python3.11/site-packages/ray/_private/memory_monitor.py +162 -0
  22. .venv/lib/python3.11/site-packages/ray/_private/metrics_agent.py +675 -0
  23. .venv/lib/python3.11/site-packages/ray/_private/node.py +1862 -0
  24. .venv/lib/python3.11/site-packages/ray/_private/parameter.py +483 -0
  25. .venv/lib/python3.11/site-packages/ray/_private/process_watcher.py +198 -0
  26. .venv/lib/python3.11/site-packages/ray/_private/profiling.py +240 -0
  27. .venv/lib/python3.11/site-packages/ray/_private/prometheus_exporter.py +365 -0
  28. .venv/lib/python3.11/site-packages/ray/_private/protobuf_compat.py +46 -0
  29. .venv/lib/python3.11/site-packages/ray/_private/pydantic_compat.py +108 -0
  30. .venv/lib/python3.11/site-packages/ray/_private/ray_client_microbenchmark.py +117 -0
  31. .venv/lib/python3.11/site-packages/ray/_private/ray_cluster_perf.py +50 -0
  32. .venv/lib/python3.11/site-packages/ray/_private/ray_constants.py +554 -0
  33. .venv/lib/python3.11/site-packages/ray/_private/ray_experimental_perf.py +337 -0
  34. .venv/lib/python3.11/site-packages/ray/_private/ray_microbenchmark_helpers.py +91 -0
  35. .venv/lib/python3.11/site-packages/ray/_private/ray_option_utils.py +387 -0
  36. .venv/lib/python3.11/site-packages/ray/_private/ray_perf.py +328 -0
  37. .venv/lib/python3.11/site-packages/ray/_private/ray_process_reaper.py +60 -0
  38. .venv/lib/python3.11/site-packages/ray/_private/resource_spec.py +317 -0
  39. .venv/lib/python3.11/site-packages/ray/_private/runtime_env/__init__.py +3 -0
  40. .venv/lib/python3.11/site-packages/ray/_private/runtime_env/_clonevirtualenv.py +334 -0
  41. .venv/lib/python3.11/site-packages/ray/_private/runtime_env/conda.py +407 -0
  42. .venv/lib/python3.11/site-packages/ray/_private/runtime_env/conda_utils.py +278 -0
  43. .venv/lib/python3.11/site-packages/ray/_private/runtime_env/constants.py +28 -0
  44. .venv/lib/python3.11/site-packages/ray/_private/runtime_env/context.py +108 -0
  45. .venv/lib/python3.11/site-packages/ray/_private/runtime_env/default_impl.py +11 -0
  46. .venv/lib/python3.11/site-packages/ray/_private/runtime_env/dependency_utils.py +113 -0
  47. .venv/lib/python3.11/site-packages/ray/_private/runtime_env/image_uri.py +195 -0
  48. .venv/lib/python3.11/site-packages/ray/_private/runtime_env/java_jars.py +103 -0
  49. .venv/lib/python3.11/site-packages/ray/_private/runtime_env/mpi.py +114 -0
  50. .venv/lib/python3.11/site-packages/ray/_private/runtime_env/mpi_runner.py +32 -0
.venv/lib/python3.11/site-packages/ray/_private/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/ray/_private/arrow_serialization.py ADDED
@@ -0,0 +1,816 @@
1
+ # arrow_serialization.py must reside outside of ray.data, otherwise
2
+ # it causes circular dependency issues for AsyncActors due to
3
+ # ray.data's lazy import.
4
+ # see https://github.com/ray-project/ray/issues/30498 for more context.
5
+ from dataclasses import dataclass
6
+ import logging
7
+ import os
8
+ import sys
9
+ from typing import List, Tuple, Optional, TYPE_CHECKING
10
+
11
+ if TYPE_CHECKING:
12
+ import pyarrow
13
+ from ray.data.extensions import ArrowTensorArray
14
+
15
+ RAY_DISABLE_CUSTOM_ARROW_JSON_OPTIONS_SERIALIZATION = (
16
+ "RAY_DISABLE_CUSTOM_ARROW_JSON_OPTIONS_SERIALIZATION"
17
+ )
18
+ RAY_DISABLE_CUSTOM_ARROW_DATA_SERIALIZATION = (
19
+ "RAY_DISABLE_CUSTOM_ARROW_DATA_SERIALIZATION"
20
+ )
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+ # Whether we have already warned the user about bloated fallback serialization.
25
+ _serialization_fallback_set = set()
26
+
27
+ # Whether we're currently running in a test, either local or CI.
28
+ _in_test = None
29
+
30
+
31
+ def _is_in_test():
32
+ global _in_test
33
+
34
+ if _in_test is None:
35
+ _in_test = any(
36
+ env_var in os.environ
37
+ # These environment variables are always set by pytest and Buildkite,
38
+ # respectively.
39
+ for env_var in ("PYTEST_CURRENT_TEST", "BUILDKITE")
40
+ )
41
+ return _in_test
42
+
43
+
44
+ def _register_custom_datasets_serializers(serialization_context):
45
+ try:
46
+ import pyarrow as pa # noqa: F401
47
+ except ModuleNotFoundError:
48
+ # No pyarrow installed so not using Arrow, so no need for custom serializers.
49
+ return
50
+
51
+ # Register all custom serializers required by Datasets.
52
+ _register_arrow_data_serializer(serialization_context)
53
+ _register_arrow_json_readoptions_serializer(serialization_context)
54
+ _register_arrow_json_parseoptions_serializer(serialization_context)
55
+
56
+
57
+ # Register custom Arrow JSON ReadOptions serializer to work around it not being picklable
58
+ # in Arrow < 8.0.0.
59
+ def _register_arrow_json_readoptions_serializer(serialization_context):
60
+ if (
61
+ os.environ.get(
62
+ RAY_DISABLE_CUSTOM_ARROW_JSON_OPTIONS_SERIALIZATION,
63
+ "0",
64
+ )
65
+ == "1"
66
+ ):
67
+ return
68
+
69
+ import pyarrow.json as pajson
70
+
71
+ serialization_context._register_cloudpickle_serializer(
72
+ pajson.ReadOptions,
73
+ custom_serializer=lambda opts: (opts.use_threads, opts.block_size),
74
+ custom_deserializer=lambda args: pajson.ReadOptions(*args),
75
+ )
76
+
77
+
78
+ def _register_arrow_json_parseoptions_serializer(serialization_context):
79
+ if (
80
+ os.environ.get(
81
+ RAY_DISABLE_CUSTOM_ARROW_JSON_OPTIONS_SERIALIZATION,
82
+ "0",
83
+ )
84
+ == "1"
85
+ ):
86
+ return
87
+
88
+ import pyarrow.json as pajson
89
+
90
+ serialization_context._register_cloudpickle_serializer(
91
+ pajson.ParseOptions,
92
+ custom_serializer=lambda opts: (
93
+ opts.explicit_schema,
94
+ opts.newlines_in_values,
95
+ opts.unexpected_field_behavior,
96
+ ),
97
+ custom_deserializer=lambda args: pajson.ParseOptions(*args),
98
+ )
99
+
100
+
101
+ # Register custom Arrow data serializer to work around zero-copy slice pickling bug.
102
+ # See https://issues.apache.org/jira/browse/ARROW-10739.
103
+ def _register_arrow_data_serializer(serialization_context):
104
+ """Custom reducer for Arrow data that works around a zero-copy slicing pickling
105
+ bug by using the Arrow IPC format for the underlying serialization.
106
+
107
+ Background:
108
+ Arrow has both array-level slicing and buffer-level slicing; both are zero-copy,
109
+ but the former has a serialization bug where the entire buffer is serialized
110
+ instead of just the slice, while the latter's serialization works as expected
111
+ and only serializes the slice of the buffer. I.e., array-level slicing doesn't
112
+ propagate the slice down to the buffer when serializing the array.
113
+
114
+ We work around this by registering a custom cloudpickle reducers for Arrow
115
+ Tables that delegates serialization to the Arrow IPC format; thankfully, Arrow's
116
+ IPC serialization has fixed this buffer truncation bug.
117
+
118
+ See https://issues.apache.org/jira/browse/ARROW-10739.
119
+ """
120
+ if os.environ.get(RAY_DISABLE_CUSTOM_ARROW_DATA_SERIALIZATION, "0") == "1":
121
+ return
122
+
123
+ import pyarrow as pa
124
+
125
+ serialization_context._register_cloudpickle_reducer(pa.Table, _arrow_table_reduce)
126
+
127
+
128
+ def _arrow_table_reduce(t: "pyarrow.Table"):
129
+ """Custom reducer for Arrow Tables that works around a zero-copy slice pickling bug.
130
+ Background:
131
+ Arrow has both array-level slicing and buffer-level slicing; both are zero-copy,
132
+ but the former has a serialization bug where the entire buffer is serialized
133
+ instead of just the slice, while the latter's serialization works as expected
134
+ and only serializes the slice of the buffer. I.e., array-level slicing doesn't
135
+ propagate the slice down to the buffer when serializing the array.
136
+ All that these copy methods do is, at serialization time, take the array-level
137
+ slicing and translate them to buffer-level slicing, so only the buffer slice is
138
+ sent over the wire instead of the entire buffer.
139
+ See https://issues.apache.org/jira/browse/ARROW-10739.
140
+ """
141
+ global _serialization_fallback_set
142
+
143
+ # Reduce the ChunkedArray columns.
144
+ reduced_columns = []
145
+ for column_name in t.column_names:
146
+ column = t[column_name]
147
+ try:
148
+ # Delegate to ChunkedArray reducer.
149
+ reduced_column = _arrow_chunked_array_reduce(column)
150
+ except Exception as e:
151
+ if not _is_dense_union(column.type) and _is_in_test():
152
+ # If running in a test and the column is not a dense union array
153
+ # (which we expect to need a fallback), we want to raise the error,
154
+ # not fall back.
155
+ raise e from None
156
+ if type(column.type) not in _serialization_fallback_set:
157
+ logger.warning(
158
+ "Failed to complete optimized serialization of Arrow Table, "
159
+ f"serialization of column '{column_name}' of type {column.type} "
160
+ "failed, so we're falling back to Arrow IPC serialization for the "
161
+ "table. Note that this may result in slower serialization and more "
162
+ "worker memory utilization. Serialization error:",
163
+ exc_info=True,
164
+ )
165
+ _serialization_fallback_set.add(type(column.type))
166
+ # Fall back to Arrow IPC-based workaround for the entire table.
167
+ return _arrow_table_ipc_reduce(t)
168
+ else:
169
+ # Column reducer succeeded, add reduced column to list.
170
+ reduced_columns.append(reduced_column)
171
+ return _reconstruct_table, (reduced_columns, t.schema)
172
+
173
+
174
+ def _reconstruct_table(
175
+ reduced_columns: List[Tuple[List["pyarrow.Array"], "pyarrow.DataType"]],
176
+ schema: "pyarrow.Schema",
177
+ ) -> "pyarrow.Table":
178
+ """Restore a serialized Arrow Table, reconstructing each reduced column."""
179
+ import pyarrow as pa
180
+
181
+ # Reconstruct each reduced column.
182
+ columns = []
183
+ for chunks_payload, type_ in reduced_columns:
184
+ columns.append(_reconstruct_chunked_array(chunks_payload, type_))
185
+
186
+ return pa.Table.from_arrays(columns, schema=schema)
187
+
188
+
189
+ def _arrow_chunked_array_reduce(
190
+ ca: "pyarrow.ChunkedArray",
191
+ ) -> Tuple[List["PicklableArrayPayload"], "pyarrow.DataType"]:
192
+ """Custom reducer for Arrow ChunkedArrays that works around a zero-copy slice
193
+ pickling bug. This reducer does not return a reconstruction function, since it's
194
+ expected to be reconstructed by the Arrow Table reconstructor.
195
+ """
196
+ # Convert chunks to serialization payloads.
197
+ chunk_payloads = []
198
+ for chunk in ca.chunks:
199
+ chunk_payload = PicklableArrayPayload.from_array(chunk)
200
+ chunk_payloads.append(chunk_payload)
201
+ return chunk_payloads, ca.type
202
+
203
+
204
+ def _reconstruct_chunked_array(
205
+ chunks: List["PicklableArrayPayload"], type_: "pyarrow.DataType"
206
+ ) -> "pyarrow.ChunkedArray":
207
+ """Restore a serialized Arrow ChunkedArray from chunks and type."""
208
+ import pyarrow as pa
209
+
210
+ # Reconstruct chunks from serialization payloads.
211
+ chunks = [chunk.to_array() for chunk in chunks]
212
+
213
+ return pa.chunked_array(chunks, type_)
214
+
215
+
216
+ @dataclass
217
+ class PicklableArrayPayload:
218
+ """Picklable array payload, holding data buffers and array metadata.
219
+
220
+ This is a helper container for pickling and reconstructing nested Arrow Arrays while
221
+ ensuring that the buffers that underlie zero-copy slice views are properly truncated.
222
+ """
223
+
224
+ # Array type.
225
+ type: "pyarrow.DataType"
226
+ # Length of array.
227
+ length: int
228
+ # Underlying data buffers.
229
+ buffers: List["pyarrow.Buffer"]
230
+ # Cached null count.
231
+ null_count: int
232
+ # Slice offset into base array.
233
+ offset: int
234
+ # Serialized array payloads for nested (child) arrays.
235
+ children: List["PicklableArrayPayload"]
236
+
237
+ @classmethod
238
+ def from_array(self, a: "pyarrow.Array") -> "PicklableArrayPayload":
239
+ """Create a picklable array payload from an Arrow Array.
240
+
241
+ This will recursively accumulate data buffer and metadata payloads that are
242
+ ready for pickling; namely, the data buffers underlying zero-copy slice views
243
+ will be properly truncated.
244
+ """
245
+ return _array_to_array_payload(a)
246
+
247
+ def to_array(self) -> "pyarrow.Array":
248
+ """Reconstruct an Arrow Array from this picklable payload."""
249
+ return _array_payload_to_array(self)
250
+
251
+
252
+ def _array_payload_to_array(payload: "PicklableArrayPayload") -> "pyarrow.Array":
253
+ """Reconstruct an Arrow Array from a possibly nested PicklableArrayPayload."""
254
+ import pyarrow as pa
255
+ from ray.air.util.tensor_extensions.arrow import get_arrow_extension_tensor_types
256
+
257
+ children = [child_payload.to_array() for child_payload in payload.children]
258
+
259
+ tensor_extension_types = get_arrow_extension_tensor_types()
260
+
261
+ if pa.types.is_dictionary(payload.type):
262
+ # Dedicated path for reconstructing a DictionaryArray, since
263
+ # Array.from_buffers() doesn't work for DictionaryArrays.
264
+ assert len(children) == 2, len(children)
265
+ indices, dictionary = children
266
+ return pa.DictionaryArray.from_arrays(indices, dictionary)
267
+ elif pa.types.is_map(payload.type) and len(children) > 1:
268
+ # In pyarrow<7.0.0, the underlying map child array is not exposed, so we work
269
+ # with the key and item arrays.
270
+ assert len(children) == 3, len(children)
271
+ offsets, keys, items = children
272
+ return pa.MapArray.from_arrays(offsets, keys, items)
273
+ elif isinstance(
274
+ payload.type,
275
+ tensor_extension_types,
276
+ ):
277
+ # Dedicated path for reconstructing an ArrowTensorArray or
278
+ # ArrowVariableShapedTensorArray, both of which can't be reconstructed by the
279
+ # Array.from_buffers() API.
280
+ assert len(children) == 1, len(children)
281
+ storage = children[0]
282
+ return pa.ExtensionArray.from_storage(payload.type, storage)
283
+ else:
284
+ # Common case: use Array.from_buffers() to construct an array of a certain type.
285
+ return pa.Array.from_buffers(
286
+ type=payload.type,
287
+ length=payload.length,
288
+ buffers=payload.buffers,
289
+ null_count=payload.null_count,
290
+ offset=payload.offset,
291
+ children=children,
292
+ )
293
+
294
+
295
+ def _array_to_array_payload(a: "pyarrow.Array") -> "PicklableArrayPayload":
296
+ """Serialize an Arrow Array to an PicklableArrayPayload for later pickling.
297
+
298
+ This function's primary purpose is to dispatch to the handler for the input array
299
+ type.
300
+ """
301
+ import pyarrow as pa
302
+
303
+ from ray.air.util.tensor_extensions.arrow import get_arrow_extension_tensor_types
304
+
305
+ tensor_extension_types = get_arrow_extension_tensor_types()
306
+
307
+ if _is_dense_union(a.type):
308
+ # Dense unions are not supported.
309
+ # TODO(Clark): Support dense unions.
310
+ raise NotImplementedError(
311
+ "Custom slice view serialization of dense union arrays is not yet "
312
+ "supported."
313
+ )
314
+
315
+ # Dispatch to handler for array type.
316
+ if pa.types.is_null(a.type):
317
+ return _null_array_to_array_payload(a)
318
+ elif _is_primitive(a.type):
319
+ return _primitive_array_to_array_payload(a)
320
+ elif _is_binary(a.type):
321
+ return _binary_array_to_array_payload(a)
322
+ elif pa.types.is_list(a.type) or pa.types.is_large_list(a.type):
323
+ return _list_array_to_array_payload(a)
324
+ elif pa.types.is_fixed_size_list(a.type):
325
+ return _fixed_size_list_array_to_array_payload(a)
326
+ elif pa.types.is_struct(a.type):
327
+ return _struct_array_to_array_payload(a)
328
+ elif pa.types.is_union(a.type):
329
+ return _union_array_to_array_payload(a)
330
+ elif pa.types.is_dictionary(a.type):
331
+ return _dictionary_array_to_array_payload(a)
332
+ elif pa.types.is_map(a.type):
333
+ return _map_array_to_array_payload(a)
334
+ elif isinstance(a.type, tensor_extension_types):
335
+ return _tensor_array_to_array_payload(a)
336
+ elif isinstance(a.type, pa.ExtensionType):
337
+ return _extension_array_to_array_payload(a)
338
+ else:
339
+ raise ValueError("Unhandled Arrow array type:", a.type)
340
+
341
+
342
+ def _is_primitive(type_: "pyarrow.DataType") -> bool:
343
+ """Whether the provided Array type is primitive (boolean, numeric, temporal or
344
+ fixed-size binary)."""
345
+ import pyarrow as pa
346
+
347
+ return (
348
+ pa.types.is_integer(type_)
349
+ or pa.types.is_floating(type_)
350
+ or pa.types.is_decimal(type_)
351
+ or pa.types.is_boolean(type_)
352
+ or pa.types.is_temporal(type_)
353
+ or pa.types.is_fixed_size_binary(type_)
354
+ )
355
+
356
+
357
+ def _is_binary(type_: "pyarrow.DataType") -> bool:
358
+ """Whether the provided Array type is a variable-sized binary type."""
359
+ import pyarrow as pa
360
+
361
+ return (
362
+ pa.types.is_string(type_)
363
+ or pa.types.is_large_string(type_)
364
+ or pa.types.is_binary(type_)
365
+ or pa.types.is_large_binary(type_)
366
+ )
367
+
368
+
369
+ def _null_array_to_array_payload(a: "pyarrow.NullArray") -> "PicklableArrayPayload":
370
+ """Serialize null array to PicklableArrayPayload."""
371
+ # Buffer scheme: [None]
372
+ return PicklableArrayPayload(
373
+ type=a.type,
374
+ length=len(a),
375
+ buffers=[None], # Single null buffer is expected.
376
+ null_count=a.null_count,
377
+ offset=0,
378
+ children=[],
379
+ )
380
+
381
+
382
+ def _primitive_array_to_array_payload(a: "pyarrow.Array") -> "PicklableArrayPayload":
383
+ """Serialize primitive (numeric, temporal, boolean) arrays to
384
+ PicklableArrayPayload.
385
+ """
386
+ assert _is_primitive(a.type), a.type
387
+ # Buffer scheme: [bitmap, data]
388
+ buffers = a.buffers()
389
+ assert len(buffers) == 2, len(buffers)
390
+
391
+ # Copy bitmap buffer, if needed.
392
+ bitmap_buf = buffers[0]
393
+ if a.null_count > 0:
394
+ bitmap_buf = _copy_bitpacked_buffer_if_needed(bitmap_buf, a.offset, len(a))
395
+ else:
396
+ bitmap_buf = None
397
+
398
+ # Copy data buffer, if needed.
399
+ data_buf = buffers[1]
400
+ if data_buf is not None:
401
+ data_buf = _copy_buffer_if_needed(buffers[1], a.type, a.offset, len(a))
402
+
403
+ return PicklableArrayPayload(
404
+ type=a.type,
405
+ length=len(a),
406
+ buffers=[bitmap_buf, data_buf],
407
+ null_count=a.null_count,
408
+ offset=0,
409
+ children=[],
410
+ )
411
+
412
+
413
+ def _binary_array_to_array_payload(a: "pyarrow.Array") -> "PicklableArrayPayload":
414
+ """Serialize binary (variable-sized binary, string) arrays to
415
+ PicklableArrayPayload.
416
+ """
417
+ assert _is_binary(a.type), a.type
418
+ # Buffer scheme: [bitmap, value_offsets, data]
419
+ buffers = a.buffers()
420
+ assert len(buffers) == 3, len(buffers)
421
+
422
+ # Copy bitmap buffer, if needed.
423
+ if a.null_count > 0:
424
+ bitmap_buf = _copy_bitpacked_buffer_if_needed(buffers[0], a.offset, len(a))
425
+ else:
426
+ bitmap_buf = None
427
+
428
+ # Copy offset buffer, if needed.
429
+ offset_buf = buffers[1]
430
+ offset_buf, data_offset, data_length = _copy_offsets_buffer_if_needed(
431
+ offset_buf, a.type, a.offset, len(a)
432
+ )
433
+ data_buf = buffers[2]
434
+ data_buf = _copy_buffer_if_needed(data_buf, None, data_offset, data_length)
435
+ return PicklableArrayPayload(
436
+ type=a.type,
437
+ length=len(a),
438
+ buffers=[bitmap_buf, offset_buf, data_buf],
439
+ null_count=a.null_count,
440
+ offset=0,
441
+ children=[],
442
+ )
443
+
444
+
445
+ def _list_array_to_array_payload(a: "pyarrow.Array") -> "PicklableArrayPayload":
446
+ """Serialize list (regular and large) arrays to PicklableArrayPayload."""
447
+ # Dedicated path for ListArrays. These arrays have a nested set of bitmap and
448
+ # offset buffers, eventually bottoming out on a data buffer.
449
+ # Buffer scheme:
450
+ # [bitmap, offsets, bitmap, offsets, ..., bitmap, data]
451
+ buffers = a.buffers()
452
+ assert len(buffers) > 1, len(buffers)
453
+
454
+ # Copy bitmap buffer, if needed.
455
+ if a.null_count > 0:
456
+ bitmap_buf = _copy_bitpacked_buffer_if_needed(buffers[0], a.offset, len(a))
457
+ else:
458
+ bitmap_buf = None
459
+
460
+ # Copy offset buffer, if needed.
461
+ offset_buf = buffers[1]
462
+ offset_buf, child_offset, child_length = _copy_offsets_buffer_if_needed(
463
+ offset_buf, a.type, a.offset, len(a)
464
+ )
465
+
466
+ # Propagate slice to child.
467
+ child = a.values.slice(child_offset, child_length)
468
+
469
+ return PicklableArrayPayload(
470
+ type=a.type,
471
+ length=len(a),
472
+ buffers=[bitmap_buf, offset_buf],
473
+ null_count=a.null_count,
474
+ offset=0,
475
+ children=[_array_to_array_payload(child)],
476
+ )
477
+
478
+
479
+ def _fixed_size_list_array_to_array_payload(
480
+ a: "pyarrow.FixedSizeListArray",
481
+ ) -> "PicklableArrayPayload":
482
+ """Serialize fixed size list arrays to PicklableArrayPayload."""
483
+ # Dedicated path for fixed-size lists.
484
+ # Buffer scheme:
485
+ # [bitmap, values_bitmap, values_data, values_subbuffers...]
486
+ buffers = a.buffers()
487
+ assert len(buffers) >= 1, len(buffers)
488
+
489
+ # Copy bitmap buffer, if needed.
490
+ if a.null_count > 0:
491
+ bitmap_buf = _copy_bitpacked_buffer_if_needed(buffers[0], a.offset, len(a))
492
+ else:
493
+ bitmap_buf = None
494
+
495
+ # Propagate slice to child.
496
+ child_offset = a.type.list_size * a.offset
497
+ child_length = a.type.list_size * len(a)
498
+ child = a.values.slice(child_offset, child_length)
499
+
500
+ return PicklableArrayPayload(
501
+ type=a.type,
502
+ length=len(a),
503
+ buffers=[bitmap_buf],
504
+ null_count=a.null_count,
505
+ offset=0,
506
+ children=[_array_to_array_payload(child)],
507
+ )
508
+
509
+
510
+ def _struct_array_to_array_payload(a: "pyarrow.StructArray") -> "PicklableArrayPayload":
511
+ """Serialize struct arrays to PicklableArrayPayload."""
512
+ # Dedicated path for StructArrays.
513
+ # StructArrays have a top-level bitmap buffer and one or more children arrays.
514
+ # Buffer scheme: [bitmap, None, child_bitmap, child_data, ...]
515
+ buffers = a.buffers()
516
+ assert len(buffers) >= 1, len(buffers)
517
+
518
+ # Copy bitmap buffer, if needed.
519
+ if a.null_count > 0:
520
+ bitmap_buf = _copy_bitpacked_buffer_if_needed(buffers[0], a.offset, len(a))
521
+ else:
522
+ bitmap_buf = None
523
+
524
+ # Get field children payload.
525
+ # Offsets and truncations are already propagated to the field arrays, so we can
526
+ # serialize them as-is.
527
+ children = [_array_to_array_payload(a.field(i)) for i in range(a.type.num_fields)]
528
+ return PicklableArrayPayload(
529
+ type=a.type,
530
+ length=len(a),
531
+ buffers=[bitmap_buf],
532
+ null_count=a.null_count,
533
+ offset=0,
534
+ children=children,
535
+ )
536
+
537
+
538
+ def _union_array_to_array_payload(a: "pyarrow.UnionArray") -> "PicklableArrayPayload":
539
+ """Serialize union arrays to PicklableArrayPayload."""
540
+ import pyarrow as pa
541
+
542
+ # Dedicated path for UnionArrays.
543
+ # UnionArrays have a top-level bitmap buffer and type code buffer, and one or
544
+ # more children arrays.
545
+ # Buffer scheme: [None, typecodes, child_bitmap, child_data, ...]
546
+ assert not _is_dense_union(a.type)
547
+ buffers = a.buffers()
548
+ assert len(buffers) > 1, len(buffers)
549
+
550
+ bitmap_buf = buffers[0]
551
+ assert bitmap_buf is None, bitmap_buf
552
+
553
+ # Copy type code buffer, if needed.
554
+ type_code_buf = buffers[1]
555
+ type_code_buf = _copy_buffer_if_needed(type_code_buf, pa.int8(), a.offset, len(a))
556
+
557
+ # Get field children payload.
558
+ # Offsets and truncations are already propagated to the field arrays, so we can
559
+ # serialize them as-is.
560
+ children = [_array_to_array_payload(a.field(i)) for i in range(a.type.num_fields)]
561
+ return PicklableArrayPayload(
562
+ type=a.type,
563
+ length=len(a),
564
+ buffers=[bitmap_buf, type_code_buf],
565
+ null_count=a.null_count,
566
+ offset=0,
567
+ children=children,
568
+ )
569
+
570
+
571
+ def _dictionary_array_to_array_payload(
572
+ a: "pyarrow.DictionaryArray",
573
+ ) -> "PicklableArrayPayload":
574
+ """Serialize dictionary arrays to PicklableArrayPayload."""
575
+ # Dedicated path for DictionaryArrays.
576
+ # Buffer scheme: [indices_bitmap, indices_data] (dictionary stored separately)
577
+ indices_payload = _array_to_array_payload(a.indices)
578
+ dictionary_payload = _array_to_array_payload(a.dictionary)
579
+ return PicklableArrayPayload(
580
+ type=a.type,
581
+ length=len(a),
582
+ buffers=[],
583
+ null_count=a.null_count,
584
+ offset=0,
585
+ children=[indices_payload, dictionary_payload],
586
+ )
587
+
588
+
589
+ def _map_array_to_array_payload(a: "pyarrow.MapArray") -> "PicklableArrayPayload":
590
+ """Serialize map arrays to PicklableArrayPayload."""
591
+ import pyarrow as pa
592
+
593
+ # Dedicated path for MapArrays.
594
+ # Buffer scheme: [bitmap, offsets, child_struct_array_buffers, ...]
595
+ buffers = a.buffers()
596
+ assert len(buffers) > 0, len(buffers)
597
+
598
+ # Copy bitmap buffer, if needed.
599
+ if a.null_count > 0:
600
+ bitmap_buf = _copy_bitpacked_buffer_if_needed(buffers[0], a.offset, len(a))
601
+ else:
602
+ bitmap_buf = None
603
+
604
+ new_buffers = [bitmap_buf]
605
+
606
+ # Copy offsets buffer, if needed.
607
+ offset_buf = buffers[1]
608
+ offset_buf, data_offset, data_length = _copy_offsets_buffer_if_needed(
609
+ offset_buf, a.type, a.offset, len(a)
610
+ )
611
+
612
+ if isinstance(a, pa.lib.ListArray):
613
+ # Map arrays directly expose the one child struct array in pyarrow>=7.0.0, which
614
+ # is easier to work with than the raw buffers.
615
+ new_buffers.append(offset_buf)
616
+ children = [_array_to_array_payload(a.values.slice(data_offset, data_length))]
617
+ else:
618
+ # In pyarrow<7.0.0, the child struct array is not exposed, so we work with the
619
+ # key and item arrays.
620
+ buffers = a.buffers()
621
+ assert len(buffers) > 2, len(buffers)
622
+ # Reconstruct offsets array.
623
+ offsets = pa.Array.from_buffers(
624
+ pa.int32(), len(a) + 1, [bitmap_buf, offset_buf]
625
+ )
626
+ # Propagate slice to keys.
627
+ keys = a.keys.slice(data_offset, data_length)
628
+ # Propagate slice to items.
629
+ items = a.items.slice(data_offset, data_length)
630
+ children = [
631
+ _array_to_array_payload(offsets),
632
+ _array_to_array_payload(keys),
633
+ _array_to_array_payload(items),
634
+ ]
635
+ return PicklableArrayPayload(
636
+ type=a.type,
637
+ length=len(a),
638
+ buffers=new_buffers,
639
+ null_count=a.null_count,
640
+ offset=0,
641
+ children=children,
642
+ )
643
+
644
+
645
+ def _tensor_array_to_array_payload(a: "ArrowTensorArray") -> "PicklableArrayPayload":
646
+ """Serialize tensor arrays to PicklableArrayPayload."""
647
+ # Offset is propagated to storage array, and the storage array items align with the
648
+ # tensor elements, so we only need to do the straightforward creation of the storage
649
+ # array payload.
650
+ storage_payload = _array_to_array_payload(a.storage)
651
+ return PicklableArrayPayload(
652
+ type=a.type,
653
+ length=len(a),
654
+ buffers=[],
655
+ null_count=a.null_count,
656
+ offset=0,
657
+ children=[storage_payload],
658
+ )
659
+
660
+
661
+ def _extension_array_to_array_payload(
662
+ a: "pyarrow.ExtensionArray",
663
+ ) -> "PicklableArrayPayload":
664
+ payload = _array_to_array_payload(a.storage)
665
+ payload.type = a.type
666
+ payload.length = len(a)
667
+ payload.null_count = a.null_count
668
+ return payload
669
+
670
+
671
+ def _copy_buffer_if_needed(
672
+ buf: "pyarrow.Buffer",
673
+ type_: Optional["pyarrow.DataType"],
674
+ offset: int,
675
+ length: int,
676
+ ) -> "pyarrow.Buffer":
677
+ """Copy buffer, if needed."""
678
+ import pyarrow as pa
679
+
680
+ if type_ is not None and pa.types.is_boolean(type_):
681
+ # Arrow boolean array buffers are bit-packed, with 8 entries per byte,
682
+ # and are accessed via bit offsets.
683
+ buf = _copy_bitpacked_buffer_if_needed(buf, offset, length)
684
+ else:
685
+ type_bytewidth = type_.bit_width // 8 if type_ is not None else 1
686
+ buf = _copy_normal_buffer_if_needed(buf, type_bytewidth, offset, length)
687
+ return buf
688
+
689
+
690
+ def _copy_normal_buffer_if_needed(
691
+ buf: "pyarrow.Buffer",
692
+ byte_width: int,
693
+ offset: int,
694
+ length: int,
695
+ ) -> "pyarrow.Buffer":
696
+ """Copy buffer, if needed."""
697
+ byte_offset = offset * byte_width
698
+ byte_length = length * byte_width
699
+ if offset > 0 or byte_length < buf.size:
700
+ # Array is a zero-copy slice, so we need to copy to a new buffer before
701
+ # serializing; this slice of the underlying buffer (not the array) will ensure
702
+ # that the buffer is properly copied at pickle-time.
703
+ buf = buf.slice(byte_offset, byte_length)
704
+ return buf
705
+
706
+
707
+ def _copy_bitpacked_buffer_if_needed(
708
+ buf: "pyarrow.Buffer",
709
+ offset: int,
710
+ length: int,
711
+ ) -> "pyarrow.Buffer":
712
+ """Copy bit-packed binary buffer, if needed."""
713
+ bit_offset = offset % 8
714
+ byte_offset = offset // 8
715
+ byte_length = _bytes_for_bits(bit_offset + length) // 8
716
+ if offset > 0 or byte_length < buf.size:
717
+ buf = buf.slice(byte_offset, byte_length)
718
+ if bit_offset != 0:
719
+ # Need to manually shift the buffer to eliminate the bit offset.
720
+ buf = _align_bit_offset(buf, bit_offset, byte_length)
721
+ return buf
722
+
723
+
724
+ def _copy_offsets_buffer_if_needed(
725
+ buf: "pyarrow.Buffer",
726
+ arr_type: "pyarrow.DataType",
727
+ offset: int,
728
+ length: int,
729
+ ) -> Tuple["pyarrow.Buffer", int, int]:
730
+ """Copy the provided offsets buffer, returning the copied buffer and the
731
+ offset + length of the underlying data.
732
+ """
733
+ import pyarrow as pa
734
+ import pyarrow.compute as pac
735
+
736
+ if (
737
+ pa.types.is_large_list(arr_type)
738
+ or pa.types.is_large_string(arr_type)
739
+ or pa.types.is_large_binary(arr_type)
740
+ or pa.types.is_large_unicode(arr_type)
741
+ ):
742
+ offset_type = pa.int64()
743
+ else:
744
+ offset_type = pa.int32()
745
+ # Copy offset buffer, if needed.
746
+ buf = _copy_buffer_if_needed(buf, offset_type, offset, length + 1)
747
+ # Reconstruct the offset array so we can determine the offset and length
748
+ # of the child array.
749
+ offsets = pa.Array.from_buffers(offset_type, length + 1, [None, buf])
750
+ child_offset = offsets[0].as_py()
751
+ child_length = offsets[-1].as_py() - child_offset
752
+ # Create new offsets aligned to 0 for the copied data buffer slice.
753
+ offsets = pac.subtract(offsets, child_offset)
754
+ if pa.types.is_int32(offset_type):
755
+ # We need to cast the resulting Int64Array back down to an Int32Array.
756
+ offsets = offsets.cast(offset_type, safe=False)
757
+ buf = offsets.buffers()[1]
758
+ return buf, child_offset, child_length
759
+
760
+
761
+ def _bytes_for_bits(n: int) -> int:
762
+ """Round up n to the nearest multiple of 8.
763
+ This is used to get the byte-padded number of bits for n bits.
764
+ """
765
+ return (n + 7) & (-8)
766
+
767
+
768
+ def _align_bit_offset(
769
+ buf: "pyarrow.Buffer",
770
+ bit_offset: int,
771
+ byte_length: int,
772
+ ) -> "pyarrow.Buffer":
773
+ """Align the bit offset into the buffer with the front of the buffer by shifting
774
+ the buffer and eliminating the offset.
775
+ """
776
+ import pyarrow as pa
777
+
778
+ bytes_ = buf.to_pybytes()
779
+ bytes_as_int = int.from_bytes(bytes_, sys.byteorder)
780
+ bytes_as_int >>= bit_offset
781
+ bytes_ = bytes_as_int.to_bytes(byte_length, sys.byteorder)
782
+ return pa.py_buffer(bytes_)
783
+
784
+
785
+ def _arrow_table_ipc_reduce(table: "pyarrow.Table"):
786
+ """Custom reducer for Arrow Table that works around a zero-copy slicing pickling
787
+ bug by using the Arrow IPC format for the underlying serialization.
788
+
789
+ This is currently used as a fallback for unsupported types (or unknown bugs) for
790
+ the manual buffer truncation workaround, e.g. for dense unions.
791
+ """
792
+ from pyarrow.ipc import RecordBatchStreamWriter
793
+ from pyarrow.lib import BufferOutputStream
794
+
795
+ output_stream = BufferOutputStream()
796
+ with RecordBatchStreamWriter(output_stream, schema=table.schema) as wr:
797
+ wr.write_table(table)
798
+ # NOTE: output_stream.getvalue() materializes the serialized table to a single
799
+ # contiguous bytestring, resulting in a few copies. This adds 1-2 extra copies on the
800
+ # serialization side, and 1 extra copy on the deserialization side.
801
+ return _restore_table_from_ipc, (output_stream.getvalue(),)
802
+
803
+
804
+ def _restore_table_from_ipc(buf: bytes) -> "pyarrow.Table":
805
+ """Restore an Arrow Table serialized to Arrow IPC format."""
806
+ from pyarrow.ipc import RecordBatchStreamReader
807
+
808
+ with RecordBatchStreamReader(buf) as reader:
809
+ return reader.read_all()
810
+
811
+
812
+ def _is_dense_union(type_: "pyarrow.DataType") -> bool:
813
+ """Whether the provided Arrow type is a dense union."""
814
+ import pyarrow as pa
815
+
816
+ return pa.types.is_union(type_) and type_.mode == "dense"
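
A minimal sketch of the slice-pickling problem the module above works around, assuming only that pyarrow and pickle are available; the observed sizes depend on the pyarrow version, and the table and column names are made up for illustration:

import pickle

import pyarrow as pa

# A 1M-row table and a 10-row zero-copy (array-level) slice of it.
big = pa.table({"x": list(range(1_000_000))})
tiny_slice = big.slice(0, 10)

# Depending on the pyarrow version, pickling the slice directly may drag the
# entire backing buffer along; the custom reducers above avoid this by
# truncating buffers at serialization time (or by falling back to Arrow IPC).
print("pickled slice:", len(pickle.dumps(tiny_slice)), "bytes")
print("pickled materialized copy:",
      len(pickle.dumps(pa.table({"x": tiny_slice["x"].to_pylist()}))), "bytes")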
.venv/lib/python3.11/site-packages/ray/_private/async_compat.py ADDED
@@ -0,0 +1,52 @@
1
+ """
2
+ This file should only be imported from Python 3.
3
+ It will raise SyntaxError when importing from Python 2.
4
+ """
5
+ import asyncio
6
+ import inspect
7
+ from functools import lru_cache
8
+
9
+ try:
10
+ import uvloop
11
+ except ImportError:
12
+ uvloop = None
13
+
14
+
15
+ def get_new_event_loop():
16
+ """Construct a new event loop. Ray will use uvloop if it exists"""
17
+ if uvloop:
18
+ return uvloop.new_event_loop()
19
+ else:
20
+ return asyncio.new_event_loop()
21
+
22
+
23
+ def try_install_uvloop():
24
+ """Installs uvloop as event-loop implementation for asyncio (if available)"""
25
+ if uvloop:
26
+ uvloop.install()
27
+ else:
28
+ pass
29
+
30
+
31
+ def is_async_func(func) -> bool:
32
+ """Return True if the function is an async or async generator method."""
33
+ return inspect.iscoroutinefunction(func) or inspect.isasyncgenfunction(func)
34
+
35
+
36
+ @lru_cache(maxsize=2**10)
37
+ def has_async_methods(cls: object) -> bool:
38
+ """Return True if the class has any async methods."""
39
+ return len(inspect.getmembers(cls, predicate=is_async_func)) > 0
40
+
41
+
42
+ @lru_cache(maxsize=2**10)
43
+ def sync_to_async(func):
44
+ """Wrap a blocking function in an async function"""
45
+
46
+ if is_async_func(func):
47
+ return func
48
+
49
+ async def wrapper(*args, **kwargs):
50
+ return func(*args, **kwargs)
51
+
52
+ return wrapper
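
A short usage sketch for the helpers above (hypothetical caller code, not from the commit): wrap a blocking function with sync_to_async and run it on a loop obtained from get_new_event_loop.

from ray._private.async_compat import get_new_event_loop, sync_to_async


def blocking_add(a, b):
    return a + b


# sync_to_async is lru_cache'd, so wrapping the same function repeatedly is cheap.
async_add = sync_to_async(blocking_add)

loop = get_new_event_loop()  # uvloop loop if uvloop is installed, else asyncio
try:
    print(loop.run_until_complete(async_add(1, 2)))  # 3
finally:
    loop.close()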
.venv/lib/python3.11/site-packages/ray/_private/async_utils.py ADDED
@@ -0,0 +1,52 @@
1
+ # Adapted from [aiodebug](https://gitlab.com/quantlane/libs/aiodebug)
2
+
3
+ # Copyright 2016-2022 Quantlane s.r.o.
4
+
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ # Modifications:
18
+ # - Removed the dependency to `logwood`.
19
+ # - Renamed `monitor_loop_lag.enable()` to just `enable_monitor_loop_lag()`.
20
+ # - Miscellaneous changes to make it work with Ray.
21
+
22
+ from typing import Callable, Optional
23
+ import asyncio
24
+ import asyncio.events
25
+
26
+
27
+ def enable_monitor_loop_lag(
28
+ callback: Callable[[float], None],
29
+ interval_s: float = 0.25,
30
+ loop: Optional[asyncio.AbstractEventLoop] = None,
31
+ ) -> None:
32
+ """
33
+ Start logging event loop lags to the callback. In ideal circumstances they should be
34
+ very close to zero. Lags may increase if event loop callbacks block for too long.
35
+
36
+ Note: this works for all event loops, including uvloop.
37
+
38
+ :param callback: Callback to call with the lag in seconds.
39
+ """
40
+ if loop is None:
41
+ loop = asyncio.get_running_loop()
42
+ if loop is None:
43
+ raise ValueError("No provided loop, nor running loop found.")
44
+
45
+ async def monitor():
46
+ while loop.is_running():
47
+ t0 = loop.time()
48
+ await asyncio.sleep(interval_s)
49
+ lag = loop.time() - t0 - interval_s # Should be close to zero.
50
+ callback(lag)
51
+
52
+ loop.create_task(monitor(), name="async_utils.monitor_loop_lag")
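
A hypothetical usage sketch for enable_monitor_loop_lag (not from the commit): block the running loop briefly so the lag callback reports a visibly non-zero value.

import asyncio
import time

from ray._private.async_utils import enable_monitor_loop_lag


async def main():
    # Report measured event-loop lag; ideally this stays near zero.
    enable_monitor_loop_lag(lambda lag: print(f"loop lag: {lag:.3f}s"))
    await asyncio.sleep(0)   # let the monitor task start
    time.sleep(0.5)          # deliberately block the loop to create lag
    await asyncio.sleep(1.0)


asyncio.run(main())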
.venv/lib/python3.11/site-packages/ray/_private/auto_init_hook.py ADDED
@@ -0,0 +1,31 @@
1
+ import ray
2
+ import os
3
+ from functools import wraps
4
+ import threading
5
+
6
+ auto_init_lock = threading.Lock()
7
+ enable_auto_connect = os.environ.get("RAY_ENABLE_AUTO_CONNECT", "") != "0"
8
+
9
+
10
+ def auto_init_ray():
11
+ if enable_auto_connect and not ray.is_initialized():
12
+ with auto_init_lock:
13
+ if not ray.is_initialized():
14
+ ray.init()
15
+
16
+
17
+ def wrap_auto_init(fn):
18
+ @wraps(fn)
19
+ def auto_init_wrapper(*args, **kwargs):
20
+ auto_init_ray()
21
+ return fn(*args, **kwargs)
22
+
23
+ return auto_init_wrapper
24
+
25
+
26
+ def wrap_auto_init_for_all_apis(api_names):
27
+ """Wrap public APIs with automatic ray.init."""
28
+ for api_name in api_names:
29
+ api = getattr(ray, api_name, None)
30
+ assert api is not None, api_name
31
+ setattr(ray, api_name, wrap_auto_init(api))
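
A hypothetical sketch of how the hook above is meant to be used: any API wrapped with wrap_auto_init triggers ray.init() on first call, unless RAY_ENABLE_AUTO_CONNECT=0 is set.

import ray

from ray._private.auto_init_hook import wrap_auto_init


@wrap_auto_init
def put_value(x):
    # ray.init() has already run by the time the body executes.
    return ray.put(x)


ref = put_value(42)    # auto-connects to (or starts) a Ray instance here
print(ray.get(ref))    # 42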
.venv/lib/python3.11/site-packages/ray/_private/client_mode_hook.py ADDED
@@ -0,0 +1,184 @@
1
+ import os
2
+ import threading
3
+ from contextlib import contextmanager
4
+ from functools import wraps
5
+ from ray._private.auto_init_hook import auto_init_ray
6
+
7
+ # Attr set on func defs to mark they have been converted to client mode.
8
+ RAY_CLIENT_MODE_ATTR = "__ray_client_mode_key__"
9
+
10
+ # Global setting of whether client mode is enabled. This defaults to OFF,
11
+ # but is enabled upon ray.client(...).connect() or in tests.
12
+ is_client_mode_enabled = os.environ.get("RAY_CLIENT_MODE", "0") == "1"
13
+
14
+ # When RAY_CLIENT_MODE == 1, we treat it as default enabled client mode
15
+ # This is useful for testing
16
+ is_client_mode_enabled_by_default = is_client_mode_enabled
17
+ os.environ.update({"RAY_CLIENT_MODE": "0"})
18
+
19
+ is_init_called = False
20
+
21
+ # Local setting of whether to ignore client hook conversion. This defaults
22
+ # to TRUE and is disabled when the underlying 'real' Ray function is needed.
23
+ _client_hook_status_on_thread = threading.local()
24
+ _client_hook_status_on_thread.status = True
25
+
26
+
27
+ def _get_client_hook_status_on_thread():
28
+ """Get's the value of `_client_hook_status_on_thread`.
29
+ Since `_client_hook_status_on_thread` is a thread-local variable, we may
30
+ need to add and set the 'status' attribute.
31
+ """
32
+ global _client_hook_status_on_thread
33
+ if not hasattr(_client_hook_status_on_thread, "status"):
34
+ _client_hook_status_on_thread.status = True
35
+ return _client_hook_status_on_thread.status
36
+
37
+
38
+ def _set_client_hook_status(val: bool):
39
+ global _client_hook_status_on_thread
40
+ _client_hook_status_on_thread.status = val
41
+
42
+
43
+ def _disable_client_hook():
44
+ global _client_hook_status_on_thread
45
+ out = _get_client_hook_status_on_thread()
46
+ _client_hook_status_on_thread.status = False
47
+ return out
48
+
49
+
50
+ def _explicitly_enable_client_mode():
51
+ """Force client mode to be enabled.
52
+ NOTE: This should not be used in tests, use `enable_client_mode`.
53
+ """
54
+ global is_client_mode_enabled
55
+ is_client_mode_enabled = True
56
+
57
+
58
+ def _explicitly_disable_client_mode():
59
+ global is_client_mode_enabled
60
+ is_client_mode_enabled = False
61
+
62
+
63
+ @contextmanager
64
+ def disable_client_hook():
65
+ val = _disable_client_hook()
66
+ try:
67
+ yield None
68
+ finally:
69
+ _set_client_hook_status(val)
70
+
71
+
72
+ @contextmanager
73
+ def enable_client_mode():
74
+ _explicitly_enable_client_mode()
75
+ try:
76
+ yield None
77
+ finally:
78
+ _explicitly_disable_client_mode()
79
+
80
+
81
+ def client_mode_hook(func: callable):
82
+ """Decorator for whether to use the 'regular' ray version of a function,
83
+ or the Ray Client version of that function.
84
+
85
+ Args:
86
+ func: This function. This is set when this function is used
87
+ as a decorator.
88
+ """
89
+
90
+ from ray.util.client import ray
91
+
92
+ @wraps(func)
93
+ def wrapper(*args, **kwargs):
94
+ # NOTE(hchen): DO NOT use "import" inside this function.
95
+ # Because when it's called within a `__del__` method, this error
96
+ # will be raised (see #35114):
97
+ # ImportError: sys.meta_path is None, Python is likely shutting down.
98
+ if client_mode_should_convert():
99
+ # Legacy code
100
+ # we only convert init function if RAY_CLIENT_MODE=1
101
+ if func.__name__ != "init" or is_client_mode_enabled_by_default:
102
+ return getattr(ray, func.__name__)(*args, **kwargs)
103
+ return func(*args, **kwargs)
104
+
105
+ return wrapper
106
+
107
+
108
+ def client_mode_should_convert():
109
+ """Determines if functions should be converted to client mode."""
110
+
111
+ # `is_client_mode_enabled_by_default` is used for testing with
112
+ # `RAY_CLIENT_MODE=1`. This flag means all tests run with client mode.
113
+ return (
114
+ is_client_mode_enabled or is_client_mode_enabled_by_default
115
+ ) and _get_client_hook_status_on_thread()
116
+
117
+
118
+ def client_mode_wrap(func):
119
+ """Wraps a function called during client mode for execution as a remote
120
+ task.
121
+
122
+ Can be used to implement public features of ray client which do not
123
+ belong in the main ray API (`ray.*`), yet require server-side execution.
124
+ An example is the creation of placement groups:
125
+ `ray.util.placement_group.placement_group()`. When called on the client
126
+ side, this function is wrapped in a task to facilitate interaction with
127
+ the GCS.
128
+ """
129
+
130
+ @wraps(func)
131
+ def wrapper(*args, **kwargs):
132
+ from ray.util.client import ray
133
+
134
+ auto_init_ray()
135
+ # Directly pass this through since `client_mode_wrap` is for
136
+ # Placement Group APIs
137
+ if client_mode_should_convert():
138
+ f = ray.remote(num_cpus=0)(func)
139
+ ref = f.remote(*args, **kwargs)
140
+ return ray.get(ref)
141
+ return func(*args, **kwargs)
142
+
143
+ return wrapper
144
+
145
+
146
+ def client_mode_convert_function(func_cls, in_args, in_kwargs, **kwargs):
147
+ """Runs a preregistered ray RemoteFunction through the ray client.
148
+
149
+ The common case for this is to transparently convert that RemoteFunction
150
+ to a ClientRemoteFunction. This happens in circumstances where the
151
+ RemoteFunction is declared early, in a library and only then is Ray used in
152
+ client mode -- necessitating a conversion.
153
+ """
154
+ from ray.util.client import ray
155
+
156
+ key = getattr(func_cls, RAY_CLIENT_MODE_ATTR, None)
157
+
158
+ # Second part of "or" is needed in case func_cls is reused between Ray
159
+ # client sessions in one Python interpreter session.
160
+ if (key is None) or (not ray._converted_key_exists(key)):
161
+ key = ray._convert_function(func_cls)
162
+ setattr(func_cls, RAY_CLIENT_MODE_ATTR, key)
163
+ client_func = ray._get_converted(key)
164
+ return client_func._remote(in_args, in_kwargs, **kwargs)
165
+
166
+
167
+ def client_mode_convert_actor(actor_cls, in_args, in_kwargs, **kwargs):
168
+ """Runs a preregistered actor class on the ray client
169
+
170
+ The common case for this decorator is for instantiating an ActorClass
171
+ transparently as a ClientActorClass. This happens in circumstances where
172
+ the ActorClass is declared early, in a library and only then is Ray used in
173
+ client mode -- necessitating a conversion.
174
+ """
175
+ from ray.util.client import ray
176
+
177
+ key = getattr(actor_cls, RAY_CLIENT_MODE_ATTR, None)
178
+ # Second part of "or" is needed in case actor_cls is reused between Ray
179
+ # client sessions in one Python interpreter session.
180
+ if (key is None) or (not ray._converted_key_exists(key)):
181
+ key = ray._convert_actor(actor_cls)
182
+ setattr(actor_cls, RAY_CLIENT_MODE_ATTR, key)
183
+ client_actor = ray._get_converted(key)
184
+ return client_actor._remote(in_args, in_kwargs, **kwargs)
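
A small hypothetical sketch for the hook machinery above: disable_client_hook temporarily forces decorated APIs to fall through to the local ("real") Ray implementation, which is reflected by client_mode_should_convert().

from ray._private.client_mode_hook import (
    client_mode_should_convert,
    disable_client_hook,
    enable_client_mode,
)

with enable_client_mode():
    assert client_mode_should_convert()          # calls would route through ray.util.client
    with disable_client_hook():
        assert not client_mode_should_convert()  # back to the local implementation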
.venv/lib/python3.11/site-packages/ray/_private/collections_utils.py ADDED
@@ -0,0 +1,10 @@
1
+ from typing import List, Any
2
+
3
+
4
+ def split(items: List[Any], chunk_size: int):
5
+ """Splits provided list into chunks of given size"""
6
+
7
+ assert chunk_size > 0, "Chunk size has to be > 0"
8
+
9
+ for i in range(0, len(items), chunk_size):
10
+ yield items[i : i + chunk_size]
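
For illustration (hypothetical caller code), split yields successive fixed-size chunks, with the last chunk holding the remainder:

from ray._private.collections_utils import split

chunks = list(split(list(range(10)), chunk_size=4))
print(chunks)  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]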
.venv/lib/python3.11/site-packages/ray/_private/compat.py ADDED
@@ -0,0 +1,40 @@
1
+ import io
2
+ import platform
3
+
4
+
5
+ def patch_psutil():
6
+ """WSL's /proc/meminfo has an inconsistency where it
7
+ nondeterministically omits a space after colons (after "SwapFree:"
8
+ in my case).
9
+ psutil then splits on spaces and then parses the wrong field,
10
+ crashing on the 'int(fields[1])' expression in
11
+ psutil._pslinux.virtual_memory().
12
+ Workaround: We ensure there is a space following each colon.
13
+ """
14
+ assert (
15
+ platform.system() == "Linux"
16
+ and "Microsoft".lower() in platform.release().lower()
17
+ )
18
+
19
+ try:
20
+ import psutil._pslinux
21
+ except ImportError:
22
+ psutil = None
23
+ psutil_open_binary = None
24
+ if psutil:
25
+ try:
26
+ psutil_open_binary = psutil._pslinux.open_binary
27
+ except AttributeError:
28
+ pass
29
+ # Only patch it if it doesn't seem to have been patched already
30
+ if psutil_open_binary and psutil_open_binary.__name__ == "open_binary":
31
+
32
+ def psutil_open_binary_patched(fname, *args, **kwargs):
33
+ f = psutil_open_binary(fname, *args, **kwargs)
34
+ if fname == "/proc/meminfo":
35
+ with f:
36
+ # Make sure there's a space after colons
37
+ return io.BytesIO(f.read().replace(b":", b": "))
38
+ return f
39
+
40
+ psutil._pslinux.open_binary = psutil_open_binary_patched
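
The essence of the patch above is a byte-level rewrite of /proc/meminfo contents; a standalone sketch of the same transformation, with made-up meminfo text:

# WSL can emit "SwapFree:8388604 kB" with no space after the colon, which
# breaks psutil's whitespace-based field parsing.
raw = b"MemTotal:       16318464 kB\nSwapFree:8388604 kB\n"
fixed = raw.replace(b":", b": ")   # the same replacement patch_psutil applies
print(fixed.decode())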
.venv/lib/python3.11/site-packages/ray/_private/conftest_utils.py ADDED
@@ -0,0 +1,14 @@
1
+ import pytest
2
+ import ray._private.ray_constants as ray_constants
3
+
4
+
5
+ @pytest.fixture
6
+ def set_override_dashboard_url(monkeypatch, request):
7
+ override_url = getattr(request, "param", "https://external_dashboard_url")
8
+ with monkeypatch.context() as m:
9
+ if override_url:
10
+ m.setenv(
11
+ ray_constants.RAY_OVERRIDE_DASHBOARD_URL,
12
+ override_url,
13
+ )
14
+ yield
.venv/lib/python3.11/site-packages/ray/_private/dict.py ADDED
@@ -0,0 +1,247 @@
1
+ import copy
2
+ from collections import deque
3
+ from collections.abc import Mapping, Sequence
4
+ from typing import Dict, List, Optional, TypeVar, Union
5
+
6
+ from ray.util.annotations import Deprecated
7
+
8
+ T = TypeVar("T")
9
+
10
+
11
+ @Deprecated
12
+ def merge_dicts(d1: dict, d2: dict) -> dict:
13
+ """
14
+ Args:
15
+ d1 (dict): Dict 1.
16
+ d2 (dict): Dict 2.
17
+
18
+ Returns:
19
+ dict: A new dict that is d1 and d2 deep merged.
20
+ """
21
+ merged = copy.deepcopy(d1)
22
+ deep_update(merged, d2, True, [])
23
+ return merged
24
+
25
+
26
+ @Deprecated
27
+ def deep_update(
28
+ original: dict,
29
+ new_dict: dict,
30
+ new_keys_allowed: bool = False,
31
+ allow_new_subkey_list: Optional[List[str]] = None,
32
+ override_all_if_type_changes: Optional[List[str]] = None,
33
+ override_all_key_list: Optional[List[str]] = None,
34
+ ) -> dict:
35
+ """Updates original dict with values from new_dict recursively.
36
+
37
+ If new key is introduced in new_dict, then if new_keys_allowed is not
38
+ True, an error will be thrown. Further, for sub-dicts, if the key is
39
+ in the allow_new_subkey_list, then new subkeys can be introduced.
40
+
41
+ Args:
42
+ original: Dictionary with default values.
43
+ new_dict: Dictionary with values to be updated
44
+ new_keys_allowed: Whether new keys are allowed.
45
+ allow_new_subkey_list: List of keys that
46
+ correspond to dict values where new subkeys can be introduced.
47
+ This is only at the top level.
48
+ override_all_if_type_changes: List of top level
49
+ keys with value=dict, for which we always simply override the
50
+ entire value (dict), iff the "type" key in that value dict changes.
51
+ override_all_key_list: List of top level keys
52
+ for which we override the entire value if the key is in the new_dict.
53
+ """
54
+ allow_new_subkey_list = allow_new_subkey_list or []
55
+ override_all_if_type_changes = override_all_if_type_changes or []
56
+ override_all_key_list = override_all_key_list or []
57
+
58
+ for k, value in new_dict.items():
59
+ if k not in original and not new_keys_allowed:
60
+ raise Exception("Unknown config parameter `{}` ".format(k))
61
+
62
+ # Both original value and new one are dicts.
63
+ if (
64
+ isinstance(original.get(k), dict)
65
+ and isinstance(value, dict)
66
+ and k not in override_all_key_list
67
+ ):
68
+ # Check old type vs new one. If different, override entire value.
69
+ if (
70
+ k in override_all_if_type_changes
71
+ and "type" in value
72
+ and "type" in original[k]
73
+ and value["type"] != original[k]["type"]
74
+ ):
75
+ original[k] = value
76
+ # Allowed key -> ok to add new subkeys.
77
+ elif k in allow_new_subkey_list:
78
+ deep_update(
79
+ original[k],
80
+ value,
81
+ True,
82
+ override_all_key_list=override_all_key_list,
83
+ )
84
+ # Non-allowed key.
85
+ else:
86
+ deep_update(
87
+ original[k],
88
+ value,
89
+ new_keys_allowed,
90
+ override_all_key_list=override_all_key_list,
91
+ )
92
+ # Original value not a dict OR new value not a dict:
93
+ # Override entire value.
94
+ else:
95
+ original[k] = value
96
+ return original
97
+
98
+
99
+ @Deprecated
100
+ def flatten_dict(
101
+ dt: Dict,
102
+ delimiter: str = "/",
103
+ prevent_delimiter: bool = False,
104
+ flatten_list: bool = False,
105
+ ):
106
+ """Flatten dict.
107
+
108
+ Output and input are of the same dict type.
109
+ Input dict remains the same after the operation.
110
+ """
111
+
112
+ def _raise_delimiter_exception():
113
+ raise ValueError(
114
+ f"Found delimiter `{delimiter}` in key when trying to flatten "
115
+ f"array. Please avoid using the delimiter in your specification."
116
+ )
117
+
118
+ dt = copy.copy(dt)
119
+ if prevent_delimiter and any(delimiter in key for key in dt):
120
+ # Raise if delimiter is any of the keys
121
+ _raise_delimiter_exception()
122
+
123
+ while_check = (dict, list) if flatten_list else dict
124
+
125
+ while any(isinstance(v, while_check) for v in dt.values()):
126
+ remove = []
127
+ add = {}
128
+ for key, value in dt.items():
129
+ if isinstance(value, dict):
130
+ for subkey, v in value.items():
131
+ if prevent_delimiter and delimiter in subkey:
132
+ # Raise if delimiter is in any of the subkeys
133
+ _raise_delimiter_exception()
134
+
135
+ add[delimiter.join([key, str(subkey)])] = v
136
+ remove.append(key)
137
+ elif flatten_list and isinstance(value, list):
138
+ for i, v in enumerate(value):
139
+ if prevent_delimiter and delimiter in subkey:
140
+ # Raise if delimiter is in any of the subkeys
141
+ _raise_delimiter_exception()
142
+
143
+ add[delimiter.join([key, str(i)])] = v
144
+ remove.append(key)
145
+
146
+ dt.update(add)
147
+ for k in remove:
148
+ del dt[k]
149
+ return dt
150
+
151
+
152
+ @Deprecated
153
+ def unflatten_dict(dt: Dict[str, T], delimiter: str = "/") -> Dict[str, T]:
154
+ """Unflatten dict. Does not support unflattening lists."""
155
+ dict_type = type(dt)
156
+ out = dict_type()
157
+ for key, val in dt.items():
158
+ path = key.split(delimiter)
159
+ item = out
160
+ for k in path[:-1]:
161
+ item = item.setdefault(k, dict_type())
162
+ if not isinstance(item, dict_type):
163
+ raise TypeError(
164
+ f"Cannot unflatten dict due the key '{key}' "
165
+ f"having a parent key '{k}', which value is not "
166
+ f"of type {dict_type} (got {type(item)}). "
167
+ "Change the key names to resolve the conflict."
168
+ )
169
+ item[path[-1]] = val
170
+ return out
171
+
172
+
173
+ @Deprecated
174
+ def unflatten_list_dict(dt: Dict[str, T], delimiter: str = "/") -> Dict[str, T]:
175
+ """Unflatten nested dict and list.
176
+
177
+ This function has some limitations:
+ (1) The keys of dt must be str.
+ (2) If the unflattened dt (the result) contains a list, the indices must
+ appear in ascending order when iterating over dt. Otherwise, this
+ function raises an AssertionError.
+ (3) The unflattened dt (the result) shouldn't contain a dict with number
+ keys.
+
+ Use this function with care. If you want to improve this function,
+ please also improve the unit test. See #14487 for more details.
187
+
188
+ Args:
189
+ dt: Flattened dictionary that is originally nested by multiple
190
+ list and dict.
191
+ delimiter: Delimiter of keys.
192
+
193
+ Example:
194
+ >>> dt = {"aaa/0/bb": 12, "aaa/1/cc": 56, "aaa/1/dd": 92}
195
+ >>> unflatten_list_dict(dt)
196
+ {'aaa': [{'bb': 12}, {'cc': 56, 'dd': 92}]}
197
+ """
198
+ out_type = list if list(dt)[0].split(delimiter, 1)[0].isdigit() else type(dt)
199
+ out = out_type()
200
+ for key, val in dt.items():
201
+ path = key.split(delimiter)
202
+
203
+ item = out
204
+ for i, k in enumerate(path[:-1]):
205
+ next_type = list if path[i + 1].isdigit() else dict
206
+ if isinstance(item, dict):
207
+ item = item.setdefault(k, next_type())
208
+ elif isinstance(item, list):
209
+ if int(k) >= len(item):
210
+ item.append(next_type())
211
+ assert int(k) == len(item) - 1
212
+ item = item[int(k)]
213
+
214
+ if isinstance(item, dict):
215
+ item[path[-1]] = val
216
+ elif isinstance(item, list):
217
+ item.append(val)
218
+ assert int(path[-1]) == len(item) - 1
219
+ return out
220
+
221
+
222
+ @Deprecated
223
+ def unflattened_lookup(
224
+ flat_key: str, lookup: Union[Mapping, Sequence], delimiter: str = "/", **kwargs
225
+ ) -> Union[Mapping, Sequence]:
226
+ """
227
+ Unflatten `flat_key` and iteratively look up in `lookup`. E.g.
228
+ `flat_key="a/0/b"` will try to return `lookup["a"][0]["b"]`.
229
+ """
230
+ if flat_key in lookup:
231
+ return lookup[flat_key]
232
+ keys = deque(flat_key.split(delimiter))
233
+ base = lookup
234
+ while keys:
235
+ key = keys.popleft()
236
+ try:
237
+ if isinstance(base, Mapping):
238
+ base = base[key]
239
+ elif isinstance(base, Sequence):
240
+ base = base[int(key)]
241
+ else:
242
+ raise KeyError()
243
+ except KeyError as e:
244
+ if "default" in kwargs:
245
+ return kwargs["default"]
246
+ raise e
247
+ return base
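A small sketch of the lookup helper above; the config dict is hypothetical:

```python
cfg = {"workers": [{"resources": {"CPU": 4}}]}

unflattened_lookup("workers/0/resources/CPU", cfg)             # -> 4
unflattened_lookup("workers/0/resources/GPU", cfg, default=0)  # -> 0 (missing key)
```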
.venv/lib/python3.11/site-packages/ray/_private/external_storage.py ADDED
@@ -0,0 +1,707 @@
1
+ import abc
2
+ import logging
3
+ import os
4
+ import random
5
+ import shutil
6
+ import time
7
+ import urllib
8
+ import uuid
9
+ from collections import namedtuple
10
+ from typing import IO, List, Optional, Tuple, Union
11
+
12
+ import ray
13
+ from ray._private.ray_constants import DEFAULT_OBJECT_PREFIX
14
+ from ray._raylet import ObjectRef
15
+
16
+ ParsedURL = namedtuple("ParsedURL", "base_url, offset, size")
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
+ def create_url_with_offset(*, url: str, offset: int, size: int) -> str:
21
+ """Methods to create a URL with offset.
22
+
23
+ When ray spills objects, it fuses multiple objects
24
+ into one file to optimize the performance. That says, each object
25
+ needs to keep tracking of its own special url to store metadata.
26
+
27
+ This method creates an url_with_offset, which is used internally
28
+ by Ray.
29
+
30
+ Created url_with_offset can be passed to the self._get_base_url method
31
+ to parse the filename used to store files.
32
+
33
+ Example) file://path/to/file?offset=""&size=""
34
+
35
+ Args:
36
+ url: url to the object stored in the external storage.
37
+ offset: Offset from the beginning of the file to
38
+ the first bytes of this object.
39
+ size: Size of the object that is stored in the url.
40
+ It is used to calculate the last offset.
41
+
42
+ Returns:
43
+ url_with_offset stored internally to find
44
+ objects from external storage.
45
+ """
46
+ return f"{url}?offset={offset}&size={size}"
47
+
48
+
49
+ def parse_url_with_offset(url_with_offset: str) -> Tuple[str, int, int]:
50
+ """Parse url_with_offset to retrieve information.
51
+
52
+ base_url is the url where the object ref
53
+ is stored in the external storage.
54
+
55
+ Args:
56
+ url_with_offset: url created by create_url_with_offset.
57
+
58
+ Returns:
59
+ named tuple of base_url, offset, and size.
60
+ """
61
+ parsed_result = urllib.parse.urlparse(url_with_offset)
62
+ query_dict = urllib.parse.parse_qs(parsed_result.query)
63
+ # Split by ? to remove the query from the url.
64
+ base_url = parsed_result.geturl().split("?")[0]
65
+ if "offset" not in query_dict or "size" not in query_dict:
66
+ raise ValueError(f"Failed to parse URL: {url_with_offset}")
67
+ offset = int(query_dict["offset"][0])
68
+ size = int(query_dict["size"][0])
69
+ return ParsedURL(base_url=base_url, offset=offset, size=size)
70
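A quick sketch of the two helpers above; the file path is hypothetical:

```python
url = create_url_with_offset(url="file:///tmp/ray/spill/abc123", offset=0, size=1024)
# 'file:///tmp/ray/spill/abc123?offset=0&size=1024'

parsed = parse_url_with_offset(url)
assert parsed.base_url == "file:///tmp/ray/spill/abc123"
assert (parsed.offset, parsed.size) == (0, 1024)
```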
+
71
+
72
+ class ExternalStorage(metaclass=abc.ABCMeta):
73
+ """The base class for external storage.
74
+
75
+ This class provides some useful functions for zero-copy object
76
+ put/get from plasma store. Also it specifies the interface for
77
+ object spilling.
78
+
79
+ When inheriting this class, please make sure to implement validation
+ logic inside the __init__ method. When a Ray instance starts, it
+ instantiates the external storage to validate the config.
82
+
83
+ Raises:
84
+ ValueError: when given configuration for
85
+ the external storage is invalid.
86
+ """
87
+
88
+ HEADER_LENGTH = 24
89
+
90
+ def _get_objects_from_store(self, object_refs):
91
+ worker = ray._private.worker.global_worker
92
+ # Since the object should always exist in the plasma store before
93
+ # spilling, it can directly get the object from the local plasma
94
+ # store.
95
+ # issue: https://github.com/ray-project/ray/pull/13831
96
+ ray_object_pairs = worker.core_worker.get_if_local(object_refs)
97
+ return ray_object_pairs
98
+
99
+ def _put_object_to_store(
100
+ self, metadata, data_size, file_like, object_ref, owner_address
101
+ ):
102
+ worker = ray._private.worker.global_worker
103
+ worker.core_worker.put_file_like_object(
104
+ metadata, data_size, file_like, object_ref, owner_address
105
+ )
106
+
107
+ def _write_multiple_objects(
108
+ self, f: IO, object_refs: List[ObjectRef], owner_addresses: List[str], url: str
109
+ ) -> List[str]:
110
+ """Fuse all given objects into a given file handle.
111
+
112
+ Args:
113
+ f: File handle to fuse all given object refs into.
+ object_refs: Object references to fuse into a single file.
115
+ owner_addresses: Owner addresses for the provided objects.
116
+ url: url where the object ref is stored
117
+ in the external storage.
118
+
119
+ Returns:
+ List of urls_with_offset of the fused objects.
+ The order of the returned keys matches the order of the
+ given object_refs.
123
+ """
124
+ keys = []
125
+ offset = 0
126
+ ray_object_pairs = self._get_objects_from_store(object_refs)
127
+ for ref, (buf, metadata), owner_address in zip(
128
+ object_refs, ray_object_pairs, owner_addresses
129
+ ):
130
+ address_len = len(owner_address)
131
+ metadata_len = len(metadata)
132
+ if buf is None and len(metadata) == 0:
133
+ error = f"Object {ref.hex()} does not exist."
134
+ raise ValueError(error)
135
+ buf_len = 0 if buf is None else len(buf)
136
+ payload = (
137
+ address_len.to_bytes(8, byteorder="little")
138
+ + metadata_len.to_bytes(8, byteorder="little")
139
+ + buf_len.to_bytes(8, byteorder="little")
140
+ + owner_address
141
+ + metadata
142
+ + (memoryview(buf) if buf_len else b"")
143
+ )
144
+ # 24 bytes to store owner address, metadata, and buffer lengths.
145
+ payload_len = len(payload)
146
+ assert (
147
+ self.HEADER_LENGTH + address_len + metadata_len + buf_len == payload_len
148
+ )
149
+ written_bytes = f.write(payload)
150
+ assert written_bytes == payload_len
151
+ url_with_offset = create_url_with_offset(
152
+ url=url, offset=offset, size=written_bytes
153
+ )
154
+ keys.append(url_with_offset.encode())
155
+ offset += written_bytes
156
+ # Necessary because pyarrow.io.NativeFile does not flush() on close().
157
+ f.flush()
158
+ return keys
159
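The fused layout written above is three little-endian 8-byte length fields followed by the owner address, metadata, and object payload. A self-contained sketch with made-up byte strings, using only the standard library:

```python
import io

owner_address, metadata, payload = b"owner", b"meta", b"0123456789"

buf = io.BytesIO()
buf.write(len(owner_address).to_bytes(8, "little"))
buf.write(len(metadata).to_bytes(8, "little"))
buf.write(len(payload).to_bytes(8, "little"))
buf.write(owner_address + metadata + payload)

# Reading mirrors the restore_spilled_objects implementations further below.
buf.seek(0)
address_len = int.from_bytes(buf.read(8), "little")
metadata_len = int.from_bytes(buf.read(8), "little")
buf_len = int.from_bytes(buf.read(8), "little")
assert buf.read(address_len) == owner_address
assert buf.read(metadata_len) == metadata
assert buf.read(buf_len) == payload
```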
+
160
+ def _size_check(self, address_len, metadata_len, buffer_len, obtained_data_size):
161
+ """Check whether or not the obtained_data_size is as expected.
162
+
163
+ Args:
+ address_len: Actual owner address length of the object.
+ metadata_len: Actual metadata length of the object.
+ buffer_len: Actual buffer length of the object.
+ obtained_data_size: Data size specified in the
+ url_with_offset.
+
+ Raises:
+ ValueError if obtained_data_size is different from
+ address_len + metadata_len + buffer_len +
+ HEADER_LENGTH (24 bytes for the three 8-byte length fields).
173
+ """
174
+ data_size_in_bytes = (
175
+ address_len + metadata_len + buffer_len + self.HEADER_LENGTH
176
+ )
177
+ if data_size_in_bytes != obtained_data_size:
178
+ raise ValueError(
179
+ f"Obtained data has a size of {data_size_in_bytes}, "
180
+ "although it is supposed to have the "
181
+ f"size of {obtained_data_size}."
182
+ )
183
+
184
+ @abc.abstractmethod
185
+ def spill_objects(self, object_refs, owner_addresses) -> List[str]:
186
+ """Spill objects to the external storage. Objects are specified
187
+ by their object refs.
188
+
189
+ Args:
190
+ object_refs: The list of the refs of the objects to be spilled.
191
+ owner_addresses: Owner addresses for the provided objects.
192
+ Returns:
193
+ A list of internal URLs with object offset.
194
+ """
195
+
196
+ @abc.abstractmethod
197
+ def restore_spilled_objects(
198
+ self, object_refs: List[ObjectRef], url_with_offset_list: List[str]
199
+ ) -> int:
200
+ """Restore objects from the external storage.
201
+
202
+ Args:
203
+ object_refs: List of object IDs (note that it is not ref).
204
+ url_with_offset_list: List of url_with_offset.
205
+
206
+ Returns:
207
+ The total number of bytes restored.
208
+ """
209
+
210
+ @abc.abstractmethod
211
+ def delete_spilled_objects(self, urls: List[str]):
212
+ """Delete objects that are spilled to the external storage.
213
+
214
+ Args:
215
+ urls: URLs that store spilled object files.
216
+
217
+ NOTE: This function should not fail if some of the urls
218
+ do not exist.
219
+ """
220
+
221
+ @abc.abstractmethod
222
+ def destroy_external_storage(self):
223
+ """Destroy external storage when a head node is down.
224
+
225
+ NOTE: This currently only works when the cluster is
+ started by ray.init.
227
+ """
228
+
229
+
230
+ class NullStorage(ExternalStorage):
231
+ """The class that represents an uninitialized external storage."""
232
+
233
+ def spill_objects(self, object_refs, owner_addresses) -> List[str]:
234
+ raise NotImplementedError("External storage is not initialized")
235
+
236
+ def restore_spilled_objects(self, object_refs, url_with_offset_list):
237
+ raise NotImplementedError("External storage is not initialized")
238
+
239
+ def delete_spilled_objects(self, urls: List[str]):
240
+ raise NotImplementedError("External storage is not initialized")
241
+
242
+ def destroy_external_storage(self):
243
+ raise NotImplementedError("External storage is not initialized")
244
+
245
+
246
+ class FileSystemStorage(ExternalStorage):
247
+ """The class for filesystem-like external storage.
248
+
249
+ Raises:
250
+ ValueError: If the directory path to
+ spill objects doesn't exist or cannot be created.
252
+ """
253
+
254
+ def __init__(
255
+ self,
256
+ node_id: str,
257
+ directory_path: Union[str, List[str]],
258
+ buffer_size: Optional[int] = None,
259
+ ):
260
+ # -- A list of directory paths to spill objects --
261
+ self._directory_paths = []
262
+ # -- Current directory to spill objects --
263
+ self._current_directory_index = 0
264
+ # -- File buffer size to spill objects --
265
+ self._buffer_size = -1
266
+
267
+ # Validation.
268
+ assert (
269
+ directory_path is not None
270
+ ), "directory_path should be provided to use object spilling."
271
+ if isinstance(directory_path, str):
272
+ directory_path = [directory_path]
273
+ assert isinstance(
274
+ directory_path, list
275
+ ), "Directory_path must be either a single string or a list of strings"
276
+ if buffer_size is not None:
277
+ assert isinstance(buffer_size, int), "buffer_size must be an integer."
278
+ self._buffer_size = buffer_size
279
+
280
+ # Create directories.
281
+ for path in directory_path:
282
+ full_dir_path = os.path.join(path, f"{DEFAULT_OBJECT_PREFIX}_{node_id}")
283
+ os.makedirs(full_dir_path, exist_ok=True)
284
+ if not os.path.exists(full_dir_path):
285
+ raise ValueError(
286
+ "The given directory path to store objects, "
287
+ f"{full_dir_path}, could not be created."
288
+ )
289
+ self._directory_paths.append(full_dir_path)
290
+ assert len(self._directory_paths) == len(directory_path)
291
+ # Choose the current directory.
+ # A random starting index is chosen to spread load across multiple
+ # directories that may be mounted at different mount points.
294
+ self._current_directory_index = random.randrange(0, len(self._directory_paths))
295
+
296
+ def spill_objects(self, object_refs, owner_addresses) -> List[str]:
297
+ if len(object_refs) == 0:
298
+ return []
299
+ # Choose the current directory path by round robin order.
300
+ self._current_directory_index = (self._current_directory_index + 1) % len(
301
+ self._directory_paths
302
+ )
303
+ directory_path = self._directory_paths[self._current_directory_index]
304
+
305
+ filename = _get_unique_spill_filename(object_refs)
306
+ url = f"{os.path.join(directory_path, filename)}"
307
+ with open(url, "wb", buffering=self._buffer_size) as f:
308
+ return self._write_multiple_objects(f, object_refs, owner_addresses, url)
309
+
310
+ def restore_spilled_objects(
311
+ self, object_refs: List[ObjectRef], url_with_offset_list: List[str]
312
+ ):
313
+ total = 0
314
+ for i in range(len(object_refs)):
315
+ object_ref = object_refs[i]
316
+ url_with_offset = url_with_offset_list[i].decode()
317
+ # Retrieve the information needed.
318
+ parsed_result = parse_url_with_offset(url_with_offset)
319
+ base_url = parsed_result.base_url
320
+ offset = parsed_result.offset
321
+ # Read a part of the file and recover the object.
322
+ with open(base_url, "rb") as f:
323
+ f.seek(offset)
324
+ address_len = int.from_bytes(f.read(8), byteorder="little")
325
+ metadata_len = int.from_bytes(f.read(8), byteorder="little")
326
+ buf_len = int.from_bytes(f.read(8), byteorder="little")
327
+ self._size_check(address_len, metadata_len, buf_len, parsed_result.size)
328
+ total += buf_len
329
+ owner_address = f.read(address_len)
330
+ metadata = f.read(metadata_len)
331
+ # read remaining data to our buffer
332
+ self._put_object_to_store(
333
+ metadata, buf_len, f, object_ref, owner_address
334
+ )
335
+ return total
336
+
337
+ def delete_spilled_objects(self, urls: List[str]):
338
+ for url in urls:
339
+ path = parse_url_with_offset(url.decode()).base_url
340
+ try:
341
+ os.remove(path)
342
+ except FileNotFoundError:
343
+ # Occurs when the urls are retried during worker crash/failure.
344
+ pass
345
+
346
+ def destroy_external_storage(self):
347
+ for directory_path in self._directory_paths:
348
+ self._destroy_external_storage(directory_path)
349
+
350
+ def _destroy_external_storage(self, directory_path):
351
+ # There's a race condition where IO workers are still
+ # deleting individual objects while we try to delete the
+ # whole directory, so we should keep retrying until
+ # the directory is actually deleted.
+ while os.path.isdir(directory_path):
+ try:
+ shutil.rmtree(directory_path)
+ except FileNotFoundError:
+ # Raised when other IO workers are deleting
+ # files in this directory at the same time.
361
+ pass
362
+ except Exception:
363
+ logger.exception(
364
+ "Error cleaning up spill files. "
365
+ "You might still have remaining spilled "
366
+ "objects inside `ray_spilled_objects` directory."
367
+ )
368
+ break
369
+
370
+
371
+ class ExternalStorageRayStorageImpl(ExternalStorage):
372
+ """Implements the external storage interface using the ray storage API."""
373
+
374
+ def __init__(
375
+ self,
376
+ node_id: str,
377
+ session_name: str,
378
+ # For remote spilling, at least 1MB is recommended.
379
+ buffer_size=1024 * 1024,
380
+ # Override the storage config for unit tests.
381
+ _force_storage_for_testing: Optional[str] = None,
382
+ ):
383
+ from ray._private import storage
384
+
385
+ if _force_storage_for_testing:
386
+ storage._reset()
387
+ storage._init_storage(_force_storage_for_testing, True)
388
+
389
+ self._fs, storage_prefix = storage._get_filesystem_internal()
390
+ self._buffer_size = buffer_size
391
+ self._prefix = os.path.join(
392
+ storage_prefix, f"{DEFAULT_OBJECT_PREFIX}_{node_id}", session_name
393
+ )
394
+ self._fs.create_dir(self._prefix)
395
+
396
+ def spill_objects(self, object_refs, owner_addresses) -> List[str]:
397
+ if len(object_refs) == 0:
398
+ return []
399
+ filename = _get_unique_spill_filename(object_refs)
400
+ url = f"{os.path.join(self._prefix, filename)}"
401
+ with self._fs.open_output_stream(url, buffer_size=self._buffer_size) as f:
402
+ return self._write_multiple_objects(f, object_refs, owner_addresses, url)
403
+
404
+ def restore_spilled_objects(
405
+ self, object_refs: List[ObjectRef], url_with_offset_list: List[str]
406
+ ):
407
+ total = 0
408
+ for i in range(len(object_refs)):
409
+ object_ref = object_refs[i]
410
+ url_with_offset = url_with_offset_list[i].decode()
411
+ # Retrieve the information needed.
412
+ parsed_result = parse_url_with_offset(url_with_offset)
413
+ base_url = parsed_result.base_url
414
+ offset = parsed_result.offset
415
+ # Read a part of the file and recover the object.
416
+ with self._fs.open_input_file(base_url) as f:
417
+ f.seek(offset)
418
+ address_len = int.from_bytes(f.read(8), byteorder="little")
419
+ metadata_len = int.from_bytes(f.read(8), byteorder="little")
420
+ buf_len = int.from_bytes(f.read(8), byteorder="little")
421
+ self._size_check(address_len, metadata_len, buf_len, parsed_result.size)
422
+ total += buf_len
423
+ owner_address = f.read(address_len)
424
+ metadata = f.read(metadata_len)
425
+ # read remaining data to our buffer
426
+ self._put_object_to_store(
427
+ metadata, buf_len, f, object_ref, owner_address
428
+ )
429
+ return total
430
+
431
+ def delete_spilled_objects(self, urls: List[str]):
432
+ for url in urls:
433
+ path = parse_url_with_offset(url.decode()).base_url
434
+ try:
435
+ self._fs.delete_file(path)
436
+ except FileNotFoundError:
437
+ # Occurs when the urls are retried during worker crash/failure.
438
+ pass
439
+
440
+ def destroy_external_storage(self):
441
+ try:
442
+ self._fs.delete_dir(self._prefix)
443
+ except Exception:
444
+ logger.exception(
445
+ "Error cleaning up spill files. "
446
+ "You might still have remaining spilled "
447
+ "objects inside `{}`.".format(self._prefix)
448
+ )
449
+
450
+
451
+ class ExternalStorageSmartOpenImpl(ExternalStorage):
452
+ """The external storage class implemented by smart_open.
453
+ (https://github.com/RaRe-Technologies/smart_open)
454
+
455
+ Smart open supports multiple backends with the same APIs.
456
+
457
+ To use this implementation, you should pre-create the given uri.
458
+ For example, if your uri is a local file path, you should pre-create
459
+ the directory.
460
+
461
+ Args:
462
+ uri: Storage URI used for smart open.
463
+ prefix: Prefix of objects that are stored.
464
+ override_transport_params: Overriding the default value of
465
+ transport_params for smart-open library.
466
+
467
+ Raises:
468
+ ModuleNotFoundError: If setup fails, for example
+ because the smart_open library
+ is not installed.
471
+ """
472
+
473
+ def __init__(
474
+ self,
475
+ node_id: str,
476
+ uri: Union[str, List[str]],
477
+ override_transport_params: dict = None,
478
+ buffer_size=1024 * 1024, # For remote spilling, at least 1MB is recommended.
479
+ ):
480
+ try:
481
+ from smart_open import open # noqa
482
+ except ModuleNotFoundError as e:
483
+ raise ModuleNotFoundError(
484
+ "Smart open is chosen to be a object spilling "
485
+ "external storage, but smart_open and boto3 "
486
+ f"is not downloaded. Original error: {e}"
487
+ )
488
+
489
+ # Validation
490
+ assert uri is not None, "uri should be provided to use object spilling."
491
+ if isinstance(uri, str):
492
+ uri = [uri]
493
+ assert isinstance(uri, list), "uri must be a single string or list of strings."
494
+ assert isinstance(buffer_size, int), "buffer_size must be an integer."
495
+
496
+ uri_is_s3 = [u.startswith("s3://") for u in uri]
497
+ self.is_for_s3 = all(uri_is_s3)
498
+ if not self.is_for_s3:
499
+ assert not any(uri_is_s3), "all uri's must be s3 or none can be s3."
500
+ self._uris = uri
501
+ else:
502
+ self._uris = [u.strip("/") for u in uri]
503
+ assert len(self._uris) == len(uri)
504
+
505
+ self._current_uri_index = random.randrange(0, len(self._uris))
506
+ self.prefix = f"{DEFAULT_OBJECT_PREFIX}_{node_id}"
507
+ self.override_transport_params = override_transport_params or {}
508
+
509
+ if self.is_for_s3:
510
+ import boto3 # noqa
511
+
512
+ # Setup boto3. It is essential because if we don't create boto
513
+ # session, smart_open will create a new session for every
514
+ # open call.
515
+ self.s3 = boto3.resource(service_name="s3")
516
+
517
+ # smart_open always seek to 0 if we don't set this argument.
518
+ # This will lead us to call a Object.get when it is not necessary,
519
+ # so defer seek and call seek before reading objects instead.
520
+ self.transport_params = {
521
+ "defer_seek": True,
522
+ "resource": self.s3,
523
+ "buffer_size": buffer_size,
524
+ }
525
+ else:
526
+ self.transport_params = {}
527
+
528
+ self.transport_params.update(self.override_transport_params)
529
+
530
+ def spill_objects(self, object_refs, owner_addresses) -> List[str]:
531
+ if len(object_refs) == 0:
532
+ return []
533
+ from smart_open import open
534
+
535
+ # Choose the current uri by round robin order.
536
+ self._current_uri_index = (self._current_uri_index + 1) % len(self._uris)
537
+ uri = self._uris[self._current_uri_index]
538
+
539
+ key = f"{self.prefix}-{_get_unique_spill_filename(object_refs)}"
540
+ url = f"{uri}/{key}"
541
+
542
+ with open(
543
+ url,
544
+ mode="wb",
545
+ transport_params=self.transport_params,
546
+ ) as file_like:
547
+ return self._write_multiple_objects(
548
+ file_like, object_refs, owner_addresses, url
549
+ )
550
+
551
+ def restore_spilled_objects(
552
+ self, object_refs: List[ObjectRef], url_with_offset_list: List[str]
553
+ ):
554
+ from smart_open import open
555
+
556
+ total = 0
557
+ for i in range(len(object_refs)):
558
+ object_ref = object_refs[i]
559
+ url_with_offset = url_with_offset_list[i].decode()
560
+
561
+ # Retrieve the information needed.
562
+ parsed_result = parse_url_with_offset(url_with_offset)
563
+ base_url = parsed_result.base_url
564
+ offset = parsed_result.offset
565
+
566
+ with open(base_url, "rb", transport_params=self.transport_params) as f:
567
+ # smart open seek reads the file from offset-end_of_the_file
568
+ # when the seek is called.
569
+ f.seek(offset)
570
+ address_len = int.from_bytes(f.read(8), byteorder="little")
571
+ metadata_len = int.from_bytes(f.read(8), byteorder="little")
572
+ buf_len = int.from_bytes(f.read(8), byteorder="little")
573
+ self._size_check(address_len, metadata_len, buf_len, parsed_result.size)
574
+ owner_address = f.read(address_len)
575
+ total += buf_len
576
+ metadata = f.read(metadata_len)
577
+ # read remaining data to our buffer
578
+ self._put_object_to_store(
579
+ metadata, buf_len, f, object_ref, owner_address
580
+ )
581
+ return total
582
+
583
+ def delete_spilled_objects(self, urls: List[str]):
584
+ pass
585
+
586
+ def destroy_external_storage(self):
587
+ pass
588
+
589
+
590
+ _external_storage = NullStorage()
591
+
592
+
593
+ class UnstableFileStorage(FileSystemStorage):
594
+ """This class is for testing with writing failure."""
595
+
596
+ def __init__(self, node_id: str, **kwargs):
597
+ super().__init__(node_id, **kwargs)
598
+ self._failure_rate = 0.1
599
+ self._partial_failure_ratio = 0.2
600
+
601
+ def spill_objects(self, object_refs, owner_addresses) -> List[str]:
602
+ r = random.random()
603
+ failed = r < self._failure_rate
604
+ partial_failed = r < self._partial_failure_ratio
605
+ if failed:
606
+ raise IOError("Spilling object failed")
607
+ elif partial_failed:
608
+ i = random.choice(range(len(object_refs)))
609
+ return super().spill_objects(object_refs[:i], owner_addresses)
610
+ else:
611
+ return super().spill_objects(object_refs, owner_addresses)
612
+
613
+
614
+ class SlowFileStorage(FileSystemStorage):
615
+ """This class is for testing slow object spilling."""
616
+
617
+ def __init__(self, node_id: str, **kwargs):
618
+ super().__init__(node_id, **kwargs)
619
+ self._min_delay = 1
620
+ self._max_delay = 2
621
+
622
+ def spill_objects(self, object_refs, owner_addresses) -> List[str]:
623
+ delay = random.random() * (self._max_delay - self._min_delay) + self._min_delay
624
+ time.sleep(delay)
625
+ return super().spill_objects(object_refs, owner_addresses)
626
+
627
+
628
+ def setup_external_storage(config, node_id, session_name):
629
+ """Setup the external storage according to the config."""
630
+ assert node_id is not None, "node_id should be provided."
631
+ global _external_storage
632
+ if config:
633
+ storage_type = config["type"]
634
+ if storage_type == "filesystem":
635
+ _external_storage = FileSystemStorage(node_id, **config["params"])
636
+ elif storage_type == "ray_storage":
637
+ _external_storage = ExternalStorageRayStorageImpl(
638
+ node_id, session_name, **config["params"]
639
+ )
640
+ elif storage_type == "smart_open":
641
+ _external_storage = ExternalStorageSmartOpenImpl(
642
+ node_id, **config["params"]
643
+ )
644
+ elif storage_type == "mock_distributed_fs":
645
+ # This storage is used to unit test distributed external storages.
646
+ # TODO(sang): Delete it after introducing the mock S3 test.
647
+ _external_storage = FileSystemStorage(node_id, **config["params"])
648
+ elif storage_type == "unstable_fs":
649
+ # This storage is used to unit test unstable file system for fault
650
+ # tolerance.
651
+ _external_storage = UnstableFileStorage(node_id, **config["params"])
652
+ elif storage_type == "slow_fs":
653
+ # This storage is used to unit test slow filesystems.
654
+ _external_storage = SlowFileStorage(node_id, **config["params"])
655
+ else:
656
+ raise ValueError(f"Unknown external storage type: {storage_type}")
657
+ else:
658
+ _external_storage = NullStorage()
659
+ return _external_storage
660
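A sketch of the config shapes this dispatcher accepts; the paths, node id, and session name below are hypothetical, and the `params` keys mirror the constructors above:

```python
# Spill to one or more local directories (FileSystemStorage).
fs_config = {
    "type": "filesystem",
    "params": {
        "directory_path": ["/tmp/spill_a", "/tmp/spill_b"],
        "buffer_size": 1024 * 1024,
    },
}
storage = setup_external_storage(fs_config, node_id="node_1", session_name="session_x")

# Spill to S3 via smart_open (ExternalStorageSmartOpenImpl).
s3_config = {"type": "smart_open", "params": {"uri": "s3://my-bucket/ray-spill"}}
```

Passing an empty or falsy config resets the module to NullStorage, whose methods raise until storage is configured.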
+
661
+
662
+ def reset_external_storage():
663
+ global _external_storage
664
+ _external_storage = NullStorage()
665
+
666
+
667
+ def spill_objects(object_refs, owner_addresses):
668
+ """Spill objects to the external storage. Objects are specified
669
+ by their object refs.
670
+
671
+ Args:
672
+ object_refs: The list of the refs of the objects to be spilled.
673
+ owner_addresses: The owner addresses of the provided object refs.
674
+ Returns:
675
+ A list of keys corresponding to the input object refs.
676
+ """
677
+ return _external_storage.spill_objects(object_refs, owner_addresses)
678
+
679
+
680
+ def restore_spilled_objects(
681
+ object_refs: List[ObjectRef], url_with_offset_list: List[str]
682
+ ):
683
+ """Restore objects from the external storage.
684
+
685
+ Args:
686
+ object_refs: List of object IDs (note that it is not ref).
687
+ url_with_offset_list: List of url_with_offset.
688
+ """
689
+ return _external_storage.restore_spilled_objects(object_refs, url_with_offset_list)
690
+
691
+
692
+ def delete_spilled_objects(urls: List[str]):
693
+ """Delete objects that are spilled to the external storage.
694
+
695
+ Args:
696
+ urls: URLs that store spilled object files.
697
+ """
698
+ _external_storage.delete_spilled_objects(urls)
699
+
700
+
701
+ def _get_unique_spill_filename(object_refs: List[ObjectRef]):
702
+ """Generate a unqiue spill file name.
703
+
704
+ Args:
705
+ object_refs: objects to be spilled in this file.
706
+ """
707
+ return f"{uuid.uuid4().hex}-multi-{len(object_refs)}"
.venv/lib/python3.11/site-packages/ray/_private/function_manager.py ADDED
@@ -0,0 +1,706 @@
1
+ import dis
2
+ import sys
3
+ import hashlib
4
+ import importlib
5
+ import inspect
6
+ import json
7
+ import logging
8
+ import os
9
+ import threading
10
+ import time
11
+ import traceback
12
+ from collections import defaultdict, namedtuple
13
+ from typing import Optional, Callable
14
+
15
+ import ray
16
+ from ray.remote_function import RemoteFunction
17
+ import ray._private.profiling as profiling
18
+ from ray import cloudpickle as pickle
19
+ from ray._private import ray_constants
20
+ from ray._private.inspect_util import (
21
+ is_class_method,
22
+ is_function_or_method,
23
+ is_static_method,
24
+ )
25
+ from ray._private.ray_constants import KV_NAMESPACE_FUNCTION_TABLE
26
+ from ray._private.utils import (
27
+ check_oversized_function,
28
+ ensure_str,
29
+ format_error_message,
30
+ )
31
+ from ray._private.serialization import pickle_dumps
32
+ from ray._raylet import (
33
+ JobID,
34
+ PythonFunctionDescriptor,
35
+ WORKER_PROCESS_SETUP_HOOK_KEY_NAME_GCS,
36
+ )
37
+
38
+ FunctionExecutionInfo = namedtuple(
39
+ "FunctionExecutionInfo", ["function", "function_name", "max_calls"]
40
+ )
41
+ ImportedFunctionInfo = namedtuple(
42
+ "ImportedFunctionInfo",
43
+ ["job_id", "function_id", "function_name", "function", "module", "max_calls"],
44
+ )
45
+
46
+ """FunctionExecutionInfo: A named tuple storing remote function information."""
47
+
48
+ logger = logging.getLogger(__name__)
49
+
50
+
51
+ def make_function_table_key(key_type: bytes, job_id: JobID, key: Optional[bytes]):
52
+ if key is None:
53
+ return b":".join([key_type, job_id.hex().encode()])
54
+ else:
55
+ return b":".join([key_type, job_id.hex().encode(), key])
56
+
57
+
58
+ class FunctionActorManager:
59
+ """A class used to export/load remote functions and actors.
60
+ Attributes:
61
+ _worker: The associated worker that this manager related.
62
+ _functions_to_export: The remote functions to export when
63
+ the worker gets connected.
64
+ _actors_to_export: The actors to export when the worker gets
65
+ connected.
66
+ _function_execution_info: The function_id
67
+ and execution_info.
68
+ _num_task_executions: The function
69
+ execution times.
70
+ imported_actor_classes: The set of actor classes keys (format:
71
+ ActorClass:function_id) that are already in GCS.
72
+ """
73
+
74
+ def __init__(self, worker):
75
+ self._worker = worker
76
+ self._functions_to_export = []
77
+ self._actors_to_export = []
78
+ # This field is a dictionary that maps function IDs
79
+ # to a FunctionExecutionInfo object. This should only be used on
80
+ # workers that execute remote functions.
81
+ self._function_execution_info = defaultdict(lambda: {})
82
+ self._num_task_executions = defaultdict(lambda: {})
83
+ # A set of all of the actor class keys that have been imported by the
84
+ # import thread. It is safe to convert this worker into an actor of
85
+ # these types.
86
+ self.imported_actor_classes = set()
87
+ self._loaded_actor_classes = {}
88
+ # Deserialize an ActorHandle will call load_actor_class(). If a
89
+ # function closure captured an ActorHandle, the deserialization of the
90
+ # function will be:
91
+ # -> fetch_and_register_remote_function (acquire lock)
92
+ # -> _load_actor_class_from_gcs (acquire lock, too)
93
+ # So, the lock should be a reentrant lock.
94
+ self.lock = threading.RLock()
95
+
96
+ self.execution_infos = {}
97
+ # This is the counter to keep track of how many keys have already
98
+ # been exported so that we can find next key quicker.
99
+ self._num_exported = 0
100
+ # This is to protect self._num_exported when doing exporting
101
+ self._export_lock = threading.Lock()
102
+
103
+ def increase_task_counter(self, function_descriptor):
104
+ function_id = function_descriptor.function_id
105
+ self._num_task_executions[function_id] += 1
106
+
107
+ def get_task_counter(self, function_descriptor):
108
+ function_id = function_descriptor.function_id
109
+ return self._num_task_executions[function_id]
110
+
111
+ def compute_collision_identifier(self, function_or_class):
112
+ """The identifier is used to detect excessive duplicate exports.
113
+ The identifier is used to determine when the same function or class is
114
+ exported many times. This can yield false positives.
115
+ Args:
116
+ function_or_class: The function or class to compute an identifier
117
+ for.
118
+ Returns:
119
+ The identifier. Note that different functions or classes can give
120
+ rise to same identifier. However, the same function should
121
+ hopefully always give rise to the same identifier. TODO(rkn):
122
+ verify if this is actually the case. Note that if the
123
+ identifier is incorrect in any way, then we may give warnings
124
+ unnecessarily or fail to give warnings, but the application's
125
+ behavior won't change.
126
+ """
127
+ import io
128
+
129
+ string_file = io.StringIO()
130
+ dis.dis(function_or_class, file=string_file, depth=2)
131
+ collision_identifier = function_or_class.__name__ + ":" + string_file.getvalue()
132
+
133
+ # Return a hash of the identifier in case it is too large.
134
+ return hashlib.sha1(collision_identifier.encode("utf-8")).digest()
135
+
136
+ def load_function_or_class_from_local(self, module_name, function_or_class_name):
137
+ """Try to load a function or class in the module from local."""
138
+ module = importlib.import_module(module_name)
139
+ parts = [part for part in function_or_class_name.split(".") if part]
140
+ object = module
141
+ try:
142
+ for part in parts:
143
+ object = getattr(object, part)
144
+ return object
145
+ except Exception:
146
+ return None
147
+
148
+ def export_setup_func(
149
+ self, setup_func: Callable, timeout: Optional[int] = None
150
+ ) -> bytes:
151
+ """Export the setup hook function and return the key."""
152
+ pickled_function = pickle_dumps(
153
+ setup_func,
154
+ "Cannot serialize the worker_process_setup_hook " f"{setup_func.__name__}",
155
+ )
156
+
157
+ function_to_run_id = hashlib.shake_128(pickled_function).digest(
158
+ ray_constants.ID_SIZE
159
+ )
160
+ key = make_function_table_key(
161
+ # This value should match with gcs_function_manager.h.
162
+ # Otherwise, it won't be GC'ed.
163
+ WORKER_PROCESS_SETUP_HOOK_KEY_NAME_GCS.encode(),
164
+ # b"FunctionsToRun",
165
+ self._worker.current_job_id.binary(),
166
+ function_to_run_id,
167
+ )
168
+
169
+ check_oversized_function(
170
+ pickled_function, setup_func.__name__, "function", self._worker
171
+ )
172
+
173
+ try:
174
+ self._worker.gcs_client.internal_kv_put(
175
+ key,
176
+ pickle.dumps(
177
+ {
178
+ "job_id": self._worker.current_job_id.binary(),
179
+ "function_id": function_to_run_id,
180
+ "function": pickled_function,
181
+ }
182
+ ),
183
+ # overwrite
184
+ True,
185
+ ray_constants.KV_NAMESPACE_FUNCTION_TABLE,
186
+ timeout=timeout,
187
+ )
188
+ except Exception as e:
189
+ logger.exception(
190
+ "Failed to export the setup hook " f"{setup_func.__name__}."
191
+ )
192
+ raise e
193
+
194
+ return key
195
+
196
+ def export(self, remote_function):
197
+ """Pickle a remote function and export it to redis.
198
+ Args:
199
+ remote_function: the RemoteFunction object.
200
+ """
201
+ if self._worker.load_code_from_local:
202
+ function_descriptor = remote_function._function_descriptor
203
+ module_name, function_name = (
204
+ function_descriptor.module_name,
205
+ function_descriptor.function_name,
206
+ )
207
+ # If the function is dynamic, we still export it to GCS
208
+ # even if load_code_from_local is set True.
209
+ if (
210
+ self.load_function_or_class_from_local(module_name, function_name)
211
+ is not None
212
+ ):
213
+ return
214
+ function = remote_function._function
215
+ pickled_function = remote_function._pickled_function
216
+
217
+ check_oversized_function(
218
+ pickled_function,
219
+ remote_function._function_name,
220
+ "remote function",
221
+ self._worker,
222
+ )
223
+ key = make_function_table_key(
224
+ b"RemoteFunction",
225
+ self._worker.current_job_id,
226
+ remote_function._function_descriptor.function_id.binary(),
227
+ )
228
+ if self._worker.gcs_client.internal_kv_exists(key, KV_NAMESPACE_FUNCTION_TABLE):
229
+ return
230
+ val = pickle.dumps(
231
+ {
232
+ "job_id": self._worker.current_job_id.binary(),
233
+ "function_id": remote_function._function_descriptor.function_id.binary(), # noqa: E501
234
+ "function_name": remote_function._function_name,
235
+ "module": function.__module__,
236
+ "function": pickled_function,
237
+ "collision_identifier": self.compute_collision_identifier(function),
238
+ "max_calls": remote_function._max_calls,
239
+ }
240
+ )
241
+ self._worker.gcs_client.internal_kv_put(
242
+ key, val, True, KV_NAMESPACE_FUNCTION_TABLE
243
+ )
244
+
245
+ def fetch_registered_method(
246
+ self, key: str, timeout: Optional[int] = None
247
+ ) -> Optional[ImportedFunctionInfo]:
248
+ vals = self._worker.gcs_client.internal_kv_get(
249
+ key, KV_NAMESPACE_FUNCTION_TABLE, timeout=timeout
250
+ )
251
+ if vals is None:
252
+ return None
253
+ else:
254
+ vals = pickle.loads(vals)
255
+ fields = [
256
+ "job_id",
257
+ "function_id",
258
+ "function_name",
259
+ "function",
260
+ "module",
261
+ "max_calls",
262
+ ]
263
+ return ImportedFunctionInfo._make(vals.get(field) for field in fields)
264
+
265
+ def fetch_and_register_remote_function(self, key):
266
+ """Import a remote function."""
267
+ remote_function_info = self.fetch_registered_method(key)
268
+ if not remote_function_info:
269
+ return False
270
+ (
271
+ job_id_str,
272
+ function_id_str,
273
+ function_name,
274
+ serialized_function,
275
+ module,
276
+ max_calls,
277
+ ) = remote_function_info
278
+
279
+ function_id = ray.FunctionID(function_id_str)
280
+ job_id = ray.JobID(job_id_str)
281
+ max_calls = int(max_calls)
282
+
283
+ # This function is called by ImportThread. This operation needs to be
284
+ # atomic. Otherwise, there is race condition. Another thread may use
285
+ # the temporary function above before the real function is ready.
286
+ with self.lock:
287
+ self._num_task_executions[function_id] = 0
288
+
289
+ try:
290
+ function = pickle.loads(serialized_function)
291
+ except Exception:
292
+ # If an exception was thrown when the remote function was
293
+ # imported, we record the traceback and notify the scheduler
294
+ # of the failure.
295
+ traceback_str = format_error_message(traceback.format_exc())
296
+
297
+ def f(*args, **kwargs):
298
+ raise RuntimeError(
299
+ "The remote function failed to import on the "
300
+ "worker. This may be because needed library "
301
+ "dependencies are not installed in the worker "
302
+ "environment or cannot be found from sys.path "
303
+ f"{sys.path}:\n\n{traceback_str}"
304
+ )
305
+
306
+ # Use a placeholder method when function pickled failed
307
+ self._function_execution_info[function_id] = FunctionExecutionInfo(
308
+ function=f, function_name=function_name, max_calls=max_calls
309
+ )
310
+
311
+ # Log the error message. Log at DEBUG level to avoid overly
312
+ # spamming the log on import failure. The user gets the error
313
+ # via the RuntimeError message above.
314
+ logger.debug(
315
+ "Failed to unpickle the remote function "
316
+ f"'{function_name}' with "
317
+ f"function ID {function_id.hex()}. "
318
+ f"Job ID:{job_id}."
319
+ f"Traceback:\n{traceback_str}. "
320
+ )
321
+ else:
322
+ # The below line is necessary. Because in the driver process,
323
+ # if the function is defined in the file where the python
324
+ # script was started from, its module is `__main__`.
325
+ # However in the worker process, the `__main__` module is a
326
+ # different module, which is `default_worker.py`
327
+ function.__module__ = module
328
+ self._function_execution_info[function_id] = FunctionExecutionInfo(
329
+ function=function, function_name=function_name, max_calls=max_calls
330
+ )
331
+ return True
332
+
333
+ def get_execution_info(self, job_id, function_descriptor):
334
+ """Get the FunctionExecutionInfo of a remote function.
335
+ Args:
336
+ job_id: ID of the job that the function belongs to.
337
+ function_descriptor: The FunctionDescriptor of the function to get.
338
+ Returns:
339
+ A FunctionExecutionInfo object.
340
+ """
341
+ function_id = function_descriptor.function_id
342
+ # If the function has already been loaded,
343
+ # There's no need to load again
344
+ if function_id in self._function_execution_info:
345
+ return self._function_execution_info[function_id]
346
+ if self._worker.load_code_from_local:
347
+ # Load function from local code.
348
+ if not function_descriptor.is_actor_method():
349
+ # If the function is not able to be loaded,
350
+ # try to load it from GCS,
351
+ # even if load_code_from_local is set True
352
+ if self._load_function_from_local(function_descriptor) is True:
353
+ return self._function_execution_info[function_id]
354
+ # Load function from GCS.
355
+ # Wait until the function to be executed has actually been
356
+ # registered on this worker. We will push warnings to the user if
357
+ # we spend too long in this loop.
358
+ # The driver function may not be found in sys.path. Try to load
359
+ # the function from GCS.
360
+ with profiling.profile("wait_for_function"):
361
+ self._wait_for_function(function_descriptor, job_id)
362
+ try:
363
+ function_id = function_descriptor.function_id
364
+ info = self._function_execution_info[function_id]
365
+ except KeyError as e:
366
+ message = (
367
+ "Error occurs in get_execution_info: "
368
+ "job_id: %s, function_descriptor: %s. Message: %s"
369
+ % (job_id, function_descriptor, e)
370
+ )
371
+ raise KeyError(message)
372
+ return info
373
+
374
+ def _load_function_from_local(self, function_descriptor):
375
+ assert not function_descriptor.is_actor_method()
376
+ function_id = function_descriptor.function_id
377
+
378
+ module_name, function_name = (
379
+ function_descriptor.module_name,
380
+ function_descriptor.function_name,
381
+ )
382
+
383
+ object = self.load_function_or_class_from_local(module_name, function_name)
384
+ if object is not None:
385
+ # Directly importing from local may break function with dynamic ray.remote,
386
+ # such as the _start_controller function utilized for the Ray service.
387
+ if isinstance(object, RemoteFunction):
388
+ function = object._function
389
+ else:
390
+ function = object
391
+ self._function_execution_info[function_id] = FunctionExecutionInfo(
392
+ function=function,
393
+ function_name=function_name,
394
+ max_calls=0,
395
+ )
396
+ self._num_task_executions[function_id] = 0
397
+ return True
398
+ else:
399
+ return False
400
+
401
+ def _wait_for_function(self, function_descriptor, job_id: str, timeout=10):
402
+ """Wait until the function to be executed is present on this worker.
403
+ This method will simply loop until the import thread has imported the
404
+ relevant function. If we spend too long in this loop, that may indicate
405
+ a problem somewhere and we will push an error message to the user.
406
+ If this worker is an actor, then this will wait until the actor has
407
+ been defined.
408
+ Args:
409
+ function_descriptor : The FunctionDescriptor of the function that
410
+ we want to execute.
411
+ job_id: The ID of the job to push the error message to
412
+ if this times out.
413
+ """
414
+ start_time = time.time()
415
+ # Only send the warning once.
416
+ warning_sent = False
417
+ while True:
418
+ with self.lock:
419
+ if self._worker.actor_id.is_nil():
420
+ if function_descriptor.function_id in self._function_execution_info:
421
+ break
422
+ else:
423
+ key = make_function_table_key(
424
+ b"RemoteFunction",
425
+ job_id,
426
+ function_descriptor.function_id.binary(),
427
+ )
428
+ if self.fetch_and_register_remote_function(key) is True:
429
+ break
430
+ else:
431
+ assert not self._worker.actor_id.is_nil()
432
+ # Actor loading will happen when execute_task is called.
433
+ assert self._worker.actor_id in self._worker.actors
434
+ break
435
+
436
+ if time.time() - start_time > timeout:
437
+ warning_message = (
438
+ "This worker was asked to execute a function "
439
+ f"that has not been registered ({function_descriptor}, "
440
+ f"node={self._worker.node_ip_address}, "
441
+ f"worker_id={self._worker.worker_id.hex()}, "
442
+ f"pid={os.getpid()}). You may have to restart Ray."
443
+ )
444
+ if not warning_sent:
445
+ logger.error(warning_message)
446
+ ray._private.utils.push_error_to_driver(
447
+ self._worker,
448
+ ray_constants.WAIT_FOR_FUNCTION_PUSH_ERROR,
449
+ warning_message,
450
+ job_id=job_id,
451
+ )
452
+ warning_sent = True
453
+ time.sleep(0.001)
454
+
455
+ def export_actor_class(
456
+ self, Class, actor_creation_function_descriptor, actor_method_names
457
+ ):
458
+ if self._worker.load_code_from_local:
459
+ module_name, class_name = (
460
+ actor_creation_function_descriptor.module_name,
461
+ actor_creation_function_descriptor.class_name,
462
+ )
463
+ # If the class is dynamic, we still export it to GCS
464
+ # even if load_code_from_local is set True.
465
+ if (
466
+ self.load_function_or_class_from_local(module_name, class_name)
467
+ is not None
468
+ ):
469
+ return
470
+
471
+ # `current_job_id` shouldn't be NIL, unless:
472
+ # 1) This worker isn't an actor;
473
+ # 2) And a previous task started a background thread, which didn't
474
+ # finish before the task finished, and still uses Ray API
475
+ # after that.
476
+ assert not self._worker.current_job_id.is_nil(), (
477
+ "You might have started a background thread in a non-actor "
478
+ "task, please make sure the thread finishes before the "
479
+ "task finishes."
480
+ )
481
+ job_id = self._worker.current_job_id
482
+ key = make_function_table_key(
483
+ b"ActorClass",
484
+ job_id,
485
+ actor_creation_function_descriptor.function_id.binary(),
486
+ )
487
+ serialized_actor_class = pickle_dumps(
488
+ Class,
489
+ f"Could not serialize the actor class "
490
+ f"{actor_creation_function_descriptor.repr}",
491
+ )
492
+ actor_class_info = {
493
+ "class_name": actor_creation_function_descriptor.class_name.split(".")[-1],
494
+ "module": actor_creation_function_descriptor.module_name,
495
+ "class": serialized_actor_class,
496
+ "job_id": job_id.binary(),
497
+ "collision_identifier": self.compute_collision_identifier(Class),
498
+ "actor_method_names": json.dumps(list(actor_method_names)),
499
+ }
500
+
501
+ check_oversized_function(
502
+ actor_class_info["class"],
503
+ actor_class_info["class_name"],
504
+ "actor",
505
+ self._worker,
506
+ )
507
+
508
+ self._worker.gcs_client.internal_kv_put(
509
+ key, pickle.dumps(actor_class_info), True, KV_NAMESPACE_FUNCTION_TABLE
510
+ )
511
+ # TODO(rkn): Currently we allow actor classes to be defined
512
+ # within tasks. I tried to disable this, but it may be necessary
513
+ # because of https://github.com/ray-project/ray/issues/1146.
514
+
515
+ def load_actor_class(self, job_id, actor_creation_function_descriptor):
516
+ """Load the actor class.
517
+ Args:
518
+ job_id: job ID of the actor.
519
+ actor_creation_function_descriptor: Function descriptor of
520
+ the actor constructor.
521
+ Returns:
522
+ The actor class.
523
+ """
524
+ function_id = actor_creation_function_descriptor.function_id
525
+ # Check if the actor class already exists in the cache.
526
+ actor_class = self._loaded_actor_classes.get(function_id, None)
527
+ if actor_class is None:
528
+ # Load actor class.
529
+ if self._worker.load_code_from_local:
530
+ # Load actor class from local code first.
531
+ actor_class = self._load_actor_class_from_local(
532
+ actor_creation_function_descriptor
533
+ )
534
+ # If the actor is unable to be loaded
535
+ # from local, try to load it
536
+ # from GCS even if load_code_from_local is set True
537
+ if actor_class is None:
538
+ actor_class = self._load_actor_class_from_gcs(
539
+ job_id, actor_creation_function_descriptor
540
+ )
541
+
542
+ else:
543
+ # Load actor class from GCS.
544
+ actor_class = self._load_actor_class_from_gcs(
545
+ job_id, actor_creation_function_descriptor
546
+ )
547
+ # Save the loaded actor class in cache.
548
+ self._loaded_actor_classes[function_id] = actor_class
549
+
550
+ # Generate execution info for the methods of this actor class.
551
+ module_name = actor_creation_function_descriptor.module_name
552
+ actor_class_name = actor_creation_function_descriptor.class_name
553
+ actor_methods = inspect.getmembers(
554
+ actor_class, predicate=is_function_or_method
555
+ )
556
+ for actor_method_name, actor_method in actor_methods:
557
+ # Actor creation function descriptor use a unique function
558
+ # hash to solve actor name conflict. When constructing an
559
+ # actor, the actor creation function descriptor will be the
560
+ # key to find __init__ method execution info. So, here we
561
+ # use actor creation function descriptor as method descriptor
562
+ # for generating __init__ method execution info.
563
+ if actor_method_name == "__init__":
564
+ method_descriptor = actor_creation_function_descriptor
565
+ else:
566
+ method_descriptor = PythonFunctionDescriptor(
567
+ module_name, actor_method_name, actor_class_name
568
+ )
569
+ method_id = method_descriptor.function_id
570
+ executor = self._make_actor_method_executor(
571
+ actor_method_name,
572
+ actor_method,
573
+ actor_imported=True,
574
+ )
575
+ self._function_execution_info[method_id] = FunctionExecutionInfo(
576
+ function=executor,
577
+ function_name=actor_method_name,
578
+ max_calls=0,
579
+ )
580
+ self._num_task_executions[method_id] = 0
581
+ self._num_task_executions[function_id] = 0
582
+ return actor_class
583
+
584
+ def _load_actor_class_from_local(self, actor_creation_function_descriptor):
585
+ """Load actor class from local code."""
586
+ module_name, class_name = (
587
+ actor_creation_function_descriptor.module_name,
588
+ actor_creation_function_descriptor.class_name,
589
+ )
590
+
591
+ object = self.load_function_or_class_from_local(module_name, class_name)
592
+
593
+ if object is not None:
594
+ if isinstance(object, ray.actor.ActorClass):
595
+ return object.__ray_metadata__.modified_class
596
+ else:
597
+ return object
598
+ else:
599
+ return None
600
+
601
+ def _create_fake_actor_class(
602
+ self, actor_class_name, actor_method_names, traceback_str
603
+ ):
604
+ class TemporaryActor:
605
+ pass
606
+
607
+ def temporary_actor_method(*args, **kwargs):
608
+ raise RuntimeError(
609
+ f"The actor with name {actor_class_name} "
610
+ "failed to import on the worker. This may be because "
611
+ "needed library dependencies are not installed in the "
612
+ f"worker environment:\n\n{traceback_str}"
613
+ )
614
+
615
+ for method in actor_method_names:
616
+ setattr(TemporaryActor, method, temporary_actor_method)
617
+
618
+ return TemporaryActor
619
+
620
+ def _load_actor_class_from_gcs(self, job_id, actor_creation_function_descriptor):
621
+ """Load actor class from GCS."""
622
+ key = make_function_table_key(
623
+ b"ActorClass",
624
+ job_id,
625
+ actor_creation_function_descriptor.function_id.binary(),
626
+ )
627
+
628
+ # Fetch raw data from GCS.
629
+ vals = self._worker.gcs_client.internal_kv_get(key, KV_NAMESPACE_FUNCTION_TABLE)
630
+ fields = ["job_id", "class_name", "module", "class", "actor_method_names"]
631
+ if vals is None:
632
+ vals = {}
633
+ else:
634
+ vals = pickle.loads(vals)
635
+ (job_id_str, class_name, module, pickled_class, actor_method_names) = (
636
+ vals.get(field) for field in fields
637
+ )
638
+
639
+ class_name = ensure_str(class_name)
640
+ module_name = ensure_str(module)
641
+ job_id = ray.JobID(job_id_str)
642
+ actor_method_names = json.loads(ensure_str(actor_method_names))
643
+
644
+ actor_class = None
645
+ try:
646
+ with self.lock:
647
+ actor_class = pickle.loads(pickled_class)
648
+ except Exception:
649
+ logger.debug("Failed to load actor class %s.", class_name)
650
+ # If an exception was thrown when the actor was imported, we record
651
+ # the traceback and notify the scheduler of the failure.
652
+ traceback_str = format_error_message(traceback.format_exc())
653
+ # The actor class failed to be unpickled, create a fake actor
654
+ # class instead (just to produce error messages and to prevent
655
+ # the driver from hanging).
656
+ actor_class = self._create_fake_actor_class(
657
+ class_name, actor_method_names, traceback_str
658
+ )
659
+
660
+ # The below line is necessary. Because in the driver process,
661
+ # if the function is defined in the file where the python script
662
+ # was started from, its module is `__main__`.
663
+ # However in the worker process, the `__main__` module is a
664
+ # different module, which is `default_worker.py`
665
+ actor_class.__module__ = module_name
666
+ return actor_class
667
+
668
+ def _make_actor_method_executor(
669
+ self, method_name: str, method, actor_imported: bool
670
+ ):
671
+ """Make an executor that wraps a user-defined actor method.
672
+ The wrapped method updates the worker's internal state and performs any
673
+ necessary checkpointing operations.
674
+ Args:
675
+ method_name: The name of the actor method.
676
+ method: The actor method to wrap. This should be a
677
+ method defined on the actor class and should therefore take an
678
+ instance of the actor as the first argument.
679
+ actor_imported: Whether the actor has been imported.
680
+ Checkpointing operations will not be run if this is set to
681
+ False.
682
+ Returns:
683
+ A function that executes the given actor method on the worker's
684
+ stored instance of the actor. The function also updates the
685
+ worker's internal state to record the executed method.
686
+ """
687
+
688
+ def actor_method_executor(__ray_actor, *args, **kwargs):
689
+ # Execute the assigned method.
690
+ is_bound = is_class_method(method) or is_static_method(
691
+ type(__ray_actor), method_name
692
+ )
693
+ if is_bound:
694
+ return method(*args, **kwargs)
695
+ else:
696
+ return method(__ray_actor, *args, **kwargs)
697
+
698
+ # Set method_name and method as attributes to the executor closure
699
+ # so we can make decisions based on these attributes in the task executor.
+ # Specifically, asyncio support requires knowing whether:
701
+ # - the method is a ray internal method: starts with __ray
702
+ # - the method is a coroutine function: defined by async def
703
+ actor_method_executor.name = method_name
704
+ actor_method_executor.method = method
705
+
706
+ return actor_method_executor
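
The dispatch in `actor_method_executor` above passes the actor instance explicitly only for plain instance methods; class and static methods are invoked as-is. A minimal standalone sketch of that rule, using a toy class that is not part of Ray (`Greeter` and `execute` are hypothetical names, and `is_bound` stands in for the `is_class_method`/`is_static_method` checks):

    class Greeter:  # toy actor class for illustration only
        def hello(self, name):
            return f"hello {name}"

        @staticmethod
        def version():
            return "1.0"

    actor_instance = Greeter()

    def execute(method, *args, is_bound=False):
        # Plain methods get the stored actor instance prepended;
        # class/static methods are already callable without it.
        return method(*args) if is_bound else method(actor_instance, *args)

    assert execute(Greeter.hello, "ray") == "hello ray"
    assert execute(Greeter.version, is_bound=True) == "1.0"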
.venv/lib/python3.11/site-packages/ray/_private/gcs_aio_client.py ADDED
@@ -0,0 +1,47 @@
1
+ import logging
2
+ from typing import Optional
3
+ import ray
4
+ from ray._raylet import InnerGcsClient
5
+
6
+ logger = logging.getLogger(__name__)
7
+
8
+
9
+ class GcsAioClient:
10
+ """
11
+ Async GCS client.
12
+
13
+ Historical note: there was a `ray::gcs::PythonGcsClient` C++ binding that had
+ only a sync API, which we wrapped with a ThreadPoolExecutor in Python. It has
+ been removed in favor of `ray::gcs::GcsClient`, which provides an async API.
16
+ """
17
+
18
+ def __init__(
19
+ self,
20
+ address: str = None,
21
+ loop=None,
22
+ executor=None,
23
+ nums_reconnect_retry: int = 5,
24
+ cluster_id: Optional[str] = None,
25
+ ):
26
+ # This must be consistent with GcsClient.__cinit__ in _raylet.pyx
27
+ timeout_ms = ray._config.py_gcs_connect_timeout_s() * 1000
28
+ self.inner = InnerGcsClient.standalone(
29
+ str(address), cluster_id=cluster_id, timeout_ms=timeout_ms
30
+ )
31
+ # Forwarded Methods. Not using __getattr__ because we want one fewer layer of
32
+ # indirection.
33
+ self.internal_kv_get = self.inner.async_internal_kv_get
34
+ self.internal_kv_multi_get = self.inner.async_internal_kv_multi_get
35
+ self.internal_kv_put = self.inner.async_internal_kv_put
36
+ self.internal_kv_del = self.inner.async_internal_kv_del
37
+ self.internal_kv_exists = self.inner.async_internal_kv_exists
38
+ self.internal_kv_keys = self.inner.async_internal_kv_keys
39
+ self.check_alive = self.inner.async_check_alive
40
+ self.get_all_job_info = self.inner.async_get_all_job_info
41
+ # Forwarded Properties.
42
+ self.address = self.inner.address
43
+ self.cluster_id = self.inner.cluster_id
44
+ # Note: these only exist in the new client.
45
+ self.get_all_actor_info = self.inner.async_get_all_actor_info
46
+ self.get_all_node_info = self.inner.async_get_all_node_info
47
+ self.kill_actor = self.inner.async_kill_actor
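
A hedged usage sketch for the client above. The GCS address is a placeholder, the snippet must run inside an asyncio event loop, and the argument order of the kv methods is an assumption here (mirroring the synchronous client's `internal_kv_get(key, namespace)` usage seen elsewhere in this diff):

    import asyncio

    from ray._private.gcs_aio_client import GcsAioClient

    async def read_key():
        # Placeholder address; assumes a running GCS at this host:port.
        client = GcsAioClient(address="127.0.0.1:6379")
        # Assumed to mirror the sync client: (key, namespace).
        value = await client.internal_kv_get(b"my_key", None)
        print(value)

    asyncio.run(read_key())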
.venv/lib/python3.11/site-packages/ray/_private/gcs_pubsub.py ADDED
@@ -0,0 +1,311 @@
1
+ import asyncio
2
+ from collections import deque
3
+ import logging
4
+ import random
5
+ from typing import Tuple, List
6
+
7
+ import grpc
8
+ from ray._private.utils import get_or_create_event_loop
9
+
10
+ try:
11
+ from grpc import aio as aiogrpc
12
+ except ImportError:
13
+ from grpc.experimental import aio as aiogrpc
14
+
15
+ import ray._private.gcs_utils as gcs_utils
16
+ from ray.core.generated import gcs_service_pb2_grpc
17
+ from ray.core.generated import gcs_service_pb2
18
+ from ray.core.generated import gcs_pb2
19
+ from ray.core.generated import common_pb2
20
+ from ray.core.generated import pubsub_pb2
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+ # Max retries for GCS publisher connection error
25
+ MAX_GCS_PUBLISH_RETRIES = 60
26
+
27
+
28
+ class _PublisherBase:
29
+ @staticmethod
30
+ def _create_node_resource_usage_request(key: str, json: str):
31
+ return gcs_service_pb2.GcsPublishRequest(
32
+ pub_messages=[
33
+ pubsub_pb2.PubMessage(
34
+ channel_type=pubsub_pb2.RAY_NODE_RESOURCE_USAGE_CHANNEL,
35
+ key_id=key.encode(),
36
+ node_resource_usage_message=common_pb2.NodeResourceUsage(json=json),
37
+ )
38
+ ]
39
+ )
40
+
41
+
42
+ class _SubscriberBase:
43
+ def __init__(self, worker_id: bytes = None):
44
+ self._worker_id = worker_id
45
+ # self._subscriber_id needs to match the binary format of a random
46
+ # SubscriberID / UniqueID, which is 28 (kUniqueIDSize) random bytes.
47
+ self._subscriber_id = bytes(bytearray(random.getrandbits(8) for _ in range(28)))
48
+ self._last_batch_size = 0
49
+ self._max_processed_sequence_id = 0
50
+ self._publisher_id = b""
51
+
52
+ # Batch size of the result from last poll. Used to indicate whether the
53
+ # subscriber can keep up.
54
+ @property
55
+ def last_batch_size(self):
56
+ return self._last_batch_size
57
+
58
+ def _subscribe_request(self, channel):
59
+ cmd = pubsub_pb2.Command(channel_type=channel, subscribe_message={})
60
+ req = gcs_service_pb2.GcsSubscriberCommandBatchRequest(
61
+ subscriber_id=self._subscriber_id, sender_id=self._worker_id, commands=[cmd]
62
+ )
63
+ return req
64
+
65
+ def _poll_request(self):
66
+ return gcs_service_pb2.GcsSubscriberPollRequest(
67
+ subscriber_id=self._subscriber_id,
68
+ max_processed_sequence_id=self._max_processed_sequence_id,
69
+ publisher_id=self._publisher_id,
70
+ )
71
+
72
+ def _unsubscribe_request(self, channels):
73
+ req = gcs_service_pb2.GcsSubscriberCommandBatchRequest(
74
+ subscriber_id=self._subscriber_id, sender_id=self._worker_id, commands=[]
75
+ )
76
+ for channel in channels:
77
+ req.commands.append(
78
+ pubsub_pb2.Command(channel_type=channel, unsubscribe_message={})
79
+ )
80
+ return req
81
+
82
+ @staticmethod
83
+ def _should_terminate_polling(e: grpc.RpcError) -> bool:
84
+ # Caller only expects polling to be terminated after deadline exceeded.
85
+ if e.code() == grpc.StatusCode.DEADLINE_EXCEEDED:
86
+ return True
87
+ # Could be a temporary connection issue. Suppress error.
88
+ # TODO: reconnect GRPC channel?
89
+ if e.code() == grpc.StatusCode.UNAVAILABLE:
90
+ return True
91
+ return False
92
+
93
+
94
+ class GcsAioPublisher(_PublisherBase):
95
+ """Publisher to GCS. Uses async io."""
96
+
97
+ def __init__(self, address: str = None, channel: aiogrpc.Channel = None):
98
+ if address:
99
+ assert channel is None, "address and channel cannot both be specified"
100
+ channel = gcs_utils.create_gcs_channel(address, aio=True)
101
+ else:
102
+ assert channel is not None, "One of address and channel must be specified"
103
+ self._stub = gcs_service_pb2_grpc.InternalPubSubGcsServiceStub(channel)
104
+
105
+ async def publish_resource_usage(self, key: str, json: str) -> None:
106
+ """Publishes logs to GCS."""
107
+ req = self._create_node_resource_usage_request(key, json)
108
+ await self._stub.GcsPublish(req)
109
+
110
+
111
+ class _AioSubscriber(_SubscriberBase):
112
+ """Async io subscriber to GCS.
113
+
114
+ Usage example common to Aio subscribers:
115
+ subscriber = GcsAioXxxSubscriber(address="...")
116
+ await subscriber.subscribe()
117
+ while running:
118
+ ...... = await subscriber.poll()
119
+ ......
120
+ await subscriber.close()
121
+ """
122
+
123
+ def __init__(
124
+ self,
125
+ pubsub_channel_type,
126
+ worker_id: bytes = None,
127
+ address: str = None,
128
+ channel: aiogrpc.Channel = None,
129
+ ):
130
+ super().__init__(worker_id)
131
+
132
+ if address:
133
+ assert channel is None, "address and channel cannot both be specified"
134
+ channel = gcs_utils.create_gcs_channel(address, aio=True)
135
+ else:
136
+ assert channel is not None, "One of address and channel must be specified"
137
+ # GRPC stub to GCS pubsub.
138
+ self._stub = gcs_service_pb2_grpc.InternalPubSubGcsServiceStub(channel)
139
+
140
+ # Type of the channel.
141
+ self._channel = pubsub_channel_type
142
+ # A queue of received PubMessage.
143
+ self._queue = deque()
144
+ # Indicates whether the subscriber has closed.
145
+ self._close = asyncio.Event()
146
+
147
+ async def subscribe(self) -> None:
148
+ """Registers a subscription for the subscriber's channel type.
149
+
150
+ Before the registration, published messages in the channel will not be
151
+ saved for the subscriber.
152
+ """
153
+ if self._close.is_set():
154
+ return
155
+ req = self._subscribe_request(self._channel)
156
+ await self._stub.GcsSubscriberCommandBatch(req, timeout=30)
157
+
158
+ async def _poll_call(self, req, timeout=None):
159
+ # Wrap GRPC _AioCall as a coroutine.
160
+ return await self._stub.GcsSubscriberPoll(req, timeout=timeout)
161
+
162
+ async def _poll(self, timeout=None) -> None:
163
+ while len(self._queue) == 0:
164
+ req = self._poll_request()
165
+ poll = get_or_create_event_loop().create_task(
166
+ self._poll_call(req, timeout=timeout)
167
+ )
168
+ close = get_or_create_event_loop().create_task(self._close.wait())
169
+ done, others = await asyncio.wait(
170
+ [poll, close], timeout=timeout, return_when=asyncio.FIRST_COMPLETED
171
+ )
172
+ # Cancel the other task if needed to prevent memory leak.
173
+ other_task = others.pop()
174
+ if not other_task.done():
175
+ other_task.cancel()
176
+ if poll not in done or close in done:
177
+ # Request timed out or subscriber closed.
178
+ break
179
+ try:
180
+ self._last_batch_size = len(poll.result().pub_messages)
181
+ if poll.result().publisher_id != self._publisher_id:
182
+ if self._publisher_id != "":
183
+ logger.debug(
184
+ f"replied publisher_id {poll.result().publisher_id}"
185
+ f"different from {self._publisher_id}, this should "
186
+ "only happens during gcs failover."
187
+ )
188
+ self._publisher_id = poll.result().publisher_id
189
+ self._max_processed_sequence_id = 0
190
+ for msg in poll.result().pub_messages:
191
+ if msg.sequence_id <= self._max_processed_sequence_id:
192
+ logger.warning(f"Ignoring out of order message {msg}")
193
+ continue
194
+ self._max_processed_sequence_id = msg.sequence_id
195
+ self._queue.append(msg)
196
+ except grpc.RpcError as e:
197
+ if self._should_terminate_polling(e):
198
+ return
199
+ raise
200
+
201
+ async def close(self) -> None:
202
+ """Closes the subscriber and its active subscription."""
203
+
204
+ # Mark close to terminate inflight polling and prevent future requests.
205
+ if self._close.is_set():
206
+ return
207
+ self._close.set()
208
+ req = self._unsubscribe_request(channels=[self._channel])
209
+ try:
210
+ await self._stub.GcsSubscriberCommandBatch(req, timeout=5)
211
+ except Exception:
212
+ pass
213
+ self._stub = None
214
+
215
+
216
+ class GcsAioResourceUsageSubscriber(_AioSubscriber):
217
+ def __init__(
218
+ self,
219
+ worker_id: bytes = None,
220
+ address: str = None,
221
+ channel: grpc.Channel = None,
222
+ ):
223
+ super().__init__(
224
+ pubsub_pb2.RAY_NODE_RESOURCE_USAGE_CHANNEL, worker_id, address, channel
225
+ )
226
+
227
+ async def poll(self, timeout=None) -> Tuple[bytes, str]:
228
+ """Polls for new resource usage message.
229
+
230
+ Returns:
231
+ A tuple of string reporter ID and resource usage json string.
232
+ """
233
+ await self._poll(timeout=timeout)
234
+ return self._pop_resource_usage(self._queue)
235
+
236
+ @staticmethod
237
+ def _pop_resource_usage(queue):
238
+ if len(queue) == 0:
239
+ return None, None
240
+ msg = queue.popleft()
241
+ return msg.key_id.decode(), msg.node_resource_usage_message.json
242
+
243
+
244
+ class GcsAioActorSubscriber(_AioSubscriber):
245
+ def __init__(
246
+ self,
247
+ worker_id: bytes = None,
248
+ address: str = None,
249
+ channel: grpc.Channel = None,
250
+ ):
251
+ super().__init__(pubsub_pb2.GCS_ACTOR_CHANNEL, worker_id, address, channel)
252
+
253
+ @property
254
+ def queue_size(self):
255
+ return len(self._queue)
256
+
257
+ async def poll(
258
+ self, batch_size, timeout=None
259
+ ) -> List[Tuple[bytes, gcs_pb2.ActorTableData]]:
260
+ """Polls for new actor message.
261
+
262
+ Returns:
263
+ A list of tuples of binary actor ID and actor table data.
264
+ """
265
+ await self._poll(timeout=timeout)
266
+ return self._pop_actors(self._queue, batch_size=batch_size)
267
+
268
+ @staticmethod
269
+ def _pop_actors(queue, batch_size):
270
+ if len(queue) == 0:
271
+ return []
272
+ popped = 0
273
+ msgs = []
274
+ while len(queue) > 0 and popped < batch_size:
275
+ msg = queue.popleft()
276
+ msgs.append((msg.key_id, msg.actor_message))
277
+ popped += 1
278
+ return msgs
279
+
280
+
281
+ class GcsAioNodeInfoSubscriber(_AioSubscriber):
282
+ def __init__(
283
+ self,
284
+ worker_id: bytes = None,
285
+ address: str = None,
286
+ channel: grpc.Channel = None,
287
+ ):
288
+ super().__init__(pubsub_pb2.GCS_NODE_INFO_CHANNEL, worker_id, address, channel)
289
+
290
+ async def poll(
291
+ self, batch_size, timeout=None
292
+ ) -> List[Tuple[bytes, gcs_pb2.GcsNodeInfo]]:
293
+ """Polls for new node info message.
294
+
295
+ Returns:
296
+ A list of tuples of (node_id, GcsNodeInfo).
297
+ """
298
+ await self._poll(timeout=timeout)
299
+ return self._pop_node_infos(self._queue, batch_size=batch_size)
300
+
301
+ @staticmethod
302
+ def _pop_node_infos(queue, batch_size):
303
+ if len(queue) == 0:
304
+ return []
305
+ popped = 0
306
+ msgs = []
307
+ while len(queue) > 0 and popped < batch_size:
308
+ msg = queue.popleft()
309
+ msgs.append((msg.key_id, msg.node_info_message))
310
+ popped += 1
311
+ return msgs
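
A sketch of the subscribe/poll/close pattern documented in `_AioSubscriber` above, applied to `GcsAioActorSubscriber`. The GCS address is a placeholder; each polled item is a `(actor_id_bytes, gcs_pb2.ActorTableData)` pair, as described in `poll`:

    import asyncio

    from ray._private.gcs_pubsub import GcsAioActorSubscriber

    async def watch_actor_updates():
        subscriber = GcsAioActorSubscriber(address="127.0.0.1:6379")  # placeholder
        await subscriber.subscribe()
        try:
            updates = await subscriber.poll(batch_size=100, timeout=5)
            for actor_id, actor_data in updates:
                print(actor_id.hex(), actor_data.state)
        finally:
            await subscriber.close()

    asyncio.run(watch_actor_updates())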
.venv/lib/python3.11/site-packages/ray/_private/gcs_utils.py ADDED
@@ -0,0 +1,163 @@
1
+ import logging
2
+ from typing import Optional
3
+
4
+ from ray._private import ray_constants
5
+
6
+ import ray._private.gcs_aio_client
7
+
8
+ from ray.core.generated.common_pb2 import ErrorType, JobConfig
9
+ from ray.core.generated.gcs_pb2 import (
10
+ ActorTableData,
11
+ AvailableResources,
12
+ TotalResources,
13
+ ErrorTableData,
14
+ GcsEntry,
15
+ GcsNodeInfo,
16
+ JobTableData,
17
+ PlacementGroupTableData,
18
+ PubSubMessage,
19
+ ResourceDemand,
20
+ ResourceLoad,
21
+ ResourcesData,
22
+ ResourceUsageBatchData,
23
+ TablePrefix,
24
+ TablePubsub,
25
+ TaskEvents,
26
+ WorkerTableData,
27
+ )
28
+
29
+ logger = logging.getLogger(__name__)
30
+
31
+ __all__ = [
32
+ "ActorTableData",
33
+ "GcsNodeInfo",
34
+ "AvailableResources",
35
+ "TotalResources",
36
+ "JobTableData",
37
+ "JobConfig",
38
+ "ErrorTableData",
39
+ "ErrorType",
40
+ "GcsEntry",
41
+ "ResourceUsageBatchData",
42
+ "ResourcesData",
43
+ "TablePrefix",
44
+ "TablePubsub",
45
+ "TaskEvents",
46
+ "ResourceDemand",
47
+ "ResourceLoad",
48
+ "PubSubMessage",
49
+ "WorkerTableData",
50
+ "PlacementGroupTableData",
51
+ ]
52
+
53
+
54
+ WORKER = 0
55
+ DRIVER = 1
56
+
57
+ # Cap messages at 512MB
58
+ _MAX_MESSAGE_LENGTH = 512 * 1024 * 1024
59
+ # Send keepalive every 60s
60
+ _GRPC_KEEPALIVE_TIME_MS = 60 * 1000
61
+ # Keepalive should be replied < 60s
62
+ _GRPC_KEEPALIVE_TIMEOUT_MS = 60 * 1000
63
+
64
+ # Also relying on these defaults:
65
+ # grpc.keepalive_permit_without_calls=0: No keepalive without inflight calls.
66
+ # grpc.use_local_subchannel_pool=0: Subchannels are shared.
67
+ _GRPC_OPTIONS = [
68
+ *ray_constants.GLOBAL_GRPC_OPTIONS,
69
+ ("grpc.max_send_message_length", _MAX_MESSAGE_LENGTH),
70
+ ("grpc.max_receive_message_length", _MAX_MESSAGE_LENGTH),
71
+ ("grpc.keepalive_time_ms", _GRPC_KEEPALIVE_TIME_MS),
72
+ ("grpc.keepalive_timeout_ms", _GRPC_KEEPALIVE_TIMEOUT_MS),
73
+ ]
74
+
75
+
76
+ def create_gcs_channel(address: str, aio=False):
77
+ """Returns a GRPC channel to GCS.
78
+
79
+ Args:
80
+ address: GCS address string, e.g. ip:port
81
+ aio: Whether using grpc.aio
82
+ Returns:
83
+ grpc.Channel or grpc.aio.Channel to GCS
84
+ """
85
+ from ray._private.utils import init_grpc_channel
86
+
87
+ return init_grpc_channel(address, options=_GRPC_OPTIONS, asynchronous=aio)
88
+
89
+
90
+ class GcsChannel:
91
+ def __init__(self, gcs_address: Optional[str] = None, aio: bool = False):
92
+ self._gcs_address = gcs_address
93
+ self._aio = aio
94
+
95
+ @property
96
+ def address(self):
97
+ return self._gcs_address
98
+
99
+ def connect(self):
100
+ # GCS server uses a cached port, so it should use the same port after
101
+ # restarting. This means GCS address should stay the same for the
102
+ # lifetime of the Ray cluster.
103
+ self._channel = create_gcs_channel(self._gcs_address, self._aio)
104
+
105
+ def channel(self):
106
+ return self._channel
107
+
108
+
109
+ # re-export
110
+ GcsAioClient = ray._private.gcs_aio_client.GcsAioClient
111
+
112
+
113
+ def cleanup_redis_storage(
114
+ host: str,
115
+ port: int,
116
+ password: str,
117
+ use_ssl: bool,
118
+ storage_namespace: str,
119
+ username: Optional[str] = None,
120
+ ):
121
+ """This function is used to cleanup the storage. Before we having
122
+ a good design for storage backend, it can be used to delete the old
123
+ data. It support redis cluster and non cluster mode.
124
+
125
+ Args:
126
+ host: The host address of the Redis.
127
+ port: The port of the Redis.
128
+ username: The username of the Redis.
129
+ password: The password of the Redis.
130
+ use_ssl: Whether to encrypt the connection.
131
+ storage_namespace: The namespace of the storage to be deleted.
132
+ """
133
+
134
+ from ray._raylet import del_key_prefix_from_storage # type: ignore
135
+
136
+ if not isinstance(host, str):
137
+ raise ValueError("Host must be a string")
138
+
139
+ if username is None:
140
+ username = ""
141
+
142
+ if not isinstance(username, str):
143
+ raise ValueError("Username must be a string")
144
+
145
+ if not isinstance(password, str):
146
+ raise ValueError("Password must be a string")
147
+
148
+ if port < 0:
149
+ raise ValueError(f"Invalid port: {port}")
150
+
151
+ if not isinstance(use_ssl, bool):
152
+ raise TypeError("use_ssl must be a boolean")
153
+
154
+ if not isinstance(storage_namespace, str):
155
+ raise ValueError("storage namespace must be a string")
156
+
157
+ # Right now, GCS stores all data into multiple hashes with keys prefixed by
158
+ # storage_namespace. So we only need to delete the specific key prefix to cleanup
159
+ # the cluster.
160
+ # Note this deletes all keys with prefix `RAY{key_prefix}@`, not `{key_prefix}`.
161
+ return del_key_prefix_from_storage(
162
+ host, port, username, password, use_ssl, storage_namespace
163
+ )
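
A hedged invocation sketch for `cleanup_redis_storage` above; the host and namespace values are placeholders. Note that, per the comment above, it deletes keys prefixed with `RAY{storage_namespace}@`, not the bare namespace:

    from ray._private.gcs_utils import cleanup_redis_storage

    cleanup_redis_storage(
        host="redis.example.internal",  # placeholder Redis host
        port=6379,
        password="",
        use_ssl=False,
        storage_namespace="my-ray-cluster",
    )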
.venv/lib/python3.11/site-packages/ray/_private/inspect_util.py ADDED
@@ -0,0 +1,49 @@
1
+ import inspect
2
+
3
+
4
+ def is_cython(obj):
5
+ """Check if an object is a Cython function or method"""
6
+
7
+ # TODO(suo): We could split these into two functions, one for Cython
8
+ # functions and another for Cython methods.
9
+ # TODO(suo): There doesn't appear to be a Cython function 'type' we can
10
+ # check against via isinstance. Please correct me if I'm wrong.
11
+ def check_cython(x):
12
+ return type(x).__name__ == "cython_function_or_method"
13
+
14
+ # Check if function or method, respectively
15
+ return check_cython(obj) or (
16
+ hasattr(obj, "__func__") and check_cython(obj.__func__)
17
+ )
18
+
19
+
20
+ def is_function_or_method(obj):
21
+ """Check if an object is a function or method.
22
+
23
+ Args:
24
+ obj: The Python object in question.
25
+
26
+ Returns:
27
+ True if the object is an function or method.
28
+ """
29
+ return inspect.isfunction(obj) or inspect.ismethod(obj) or is_cython(obj)
30
+
31
+
32
+ def is_class_method(f):
33
+ """Returns whether the given method is a class_method."""
34
+ return hasattr(f, "__self__") and f.__self__ is not None
35
+
36
+
37
+ def is_static_method(cls, f_name):
38
+ """Returns whether the class has a static method with the given name.
39
+
40
+ Args:
41
+ cls: The Python class (i.e. object of type `type`) to
42
+ search for the method in.
43
+ f_name: The name of the method to look up in this class
44
+ and check whether or not it is static.
45
+ """
46
+ for base_cls in inspect.getmro(cls):
47
+ if f_name in base_cls.__dict__:
48
+ return isinstance(base_cls.__dict__[f_name], staticmethod)
49
+ return False
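
A quick illustration of the three helpers above against a toy class (the class is not part of Ray):

    from ray._private.inspect_util import (
        is_class_method,
        is_function_or_method,
        is_static_method,
    )

    class Example:  # toy class for illustration
        def plain(self):
            pass

        @classmethod
        def cls_method(cls):
            pass

        @staticmethod
        def static_method():
            pass

    assert is_function_or_method(Example.plain)
    assert is_class_method(Example.cls_method)       # __self__ is the class
    assert not is_class_method(Example.plain)
    assert is_static_method(Example, "static_method")
    assert not is_static_method(Example, "plain")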
.venv/lib/python3.11/site-packages/ray/_private/internal_api.py ADDED
@@ -0,0 +1,255 @@
1
+ from typing import List, Tuple
2
+
3
+ import ray
4
+ import ray._private.profiling as profiling
5
+ import ray._private.services as services
6
+ import ray._private.utils as utils
7
+ import ray._private.worker
8
+ from ray._private.state import GlobalState
9
+ from ray._raylet import GcsClientOptions
10
+ from ray.core.generated import common_pb2
11
+
12
+ __all__ = ["free", "global_gc"]
13
+ MAX_MESSAGE_LENGTH = ray._config.max_grpc_message_size()
14
+
15
+
16
+ def global_gc():
17
+ """Trigger gc.collect() on all workers in the cluster."""
18
+
19
+ worker = ray._private.worker.global_worker
20
+ worker.core_worker.global_gc()
21
+
22
+
23
+ def get_state_from_address(address=None):
24
+ address = services.canonicalize_bootstrap_address_or_die(address)
25
+
26
+ state = GlobalState()
27
+ options = GcsClientOptions.create(
28
+ address, None, allow_cluster_id_nil=True, fetch_cluster_id_if_nil=False
29
+ )
30
+ state._initialize_global_state(options)
31
+ return state
32
+
33
+
34
+ def memory_summary(
35
+ address=None,
36
+ group_by="NODE_ADDRESS",
37
+ sort_by="OBJECT_SIZE",
38
+ units="B",
39
+ line_wrap=True,
40
+ stats_only=False,
41
+ num_entries=None,
42
+ ):
43
+ from ray.dashboard.memory_utils import memory_summary
44
+
45
+ state = get_state_from_address(address)
46
+ reply = get_memory_info_reply(state)
47
+
48
+ if stats_only:
49
+ return store_stats_summary(reply)
50
+ return memory_summary(
51
+ state, group_by, sort_by, line_wrap, units, num_entries
52
+ ) + store_stats_summary(reply)
53
+
54
+
55
+ def get_memory_info_reply(state, node_manager_address=None, node_manager_port=None):
56
+ """Returns global memory info."""
57
+
58
+ from ray.core.generated import node_manager_pb2, node_manager_pb2_grpc
59
+
60
+ # We can ask any Raylet for the global memory info; that Raylet internally
+ # asks all nodes in the cluster for memory stats.
62
+ if node_manager_address is None or node_manager_port is None:
63
+ # We should ask for a raylet that is alive.
64
+ raylet = None
65
+ for node in state.node_table():
66
+ if node["Alive"]:
67
+ raylet = node
68
+ break
69
+ assert raylet is not None, "Every raylet is dead"
70
+ raylet_address = "{}:{}".format(
71
+ raylet["NodeManagerAddress"], raylet["NodeManagerPort"]
72
+ )
73
+ else:
74
+ raylet_address = "{}:{}".format(node_manager_address, node_manager_port)
75
+
76
+ channel = utils.init_grpc_channel(
77
+ raylet_address,
78
+ options=[
79
+ ("grpc.max_send_message_length", MAX_MESSAGE_LENGTH),
80
+ ("grpc.max_receive_message_length", MAX_MESSAGE_LENGTH),
81
+ ],
82
+ )
83
+
84
+ stub = node_manager_pb2_grpc.NodeManagerServiceStub(channel)
85
+ reply = stub.FormatGlobalMemoryInfo(
86
+ node_manager_pb2.FormatGlobalMemoryInfoRequest(include_memory_info=False),
87
+ timeout=60.0,
88
+ )
89
+ return reply
90
+
91
+
92
+ def node_stats(
93
+ node_manager_address=None, node_manager_port=None, include_memory_info=True
94
+ ):
95
+ """Returns NodeStats object describing memory usage in the cluster."""
96
+
97
+ from ray.core.generated import node_manager_pb2, node_manager_pb2_grpc
98
+
99
+ # We can ask any Raylet for the global memory info.
100
+ assert node_manager_address is not None and node_manager_port is not None
101
+ raylet_address = "{}:{}".format(node_manager_address, node_manager_port)
102
+ channel = utils.init_grpc_channel(
103
+ raylet_address,
104
+ options=[
105
+ ("grpc.max_send_message_length", MAX_MESSAGE_LENGTH),
106
+ ("grpc.max_receive_message_length", MAX_MESSAGE_LENGTH),
107
+ ],
108
+ )
109
+
110
+ stub = node_manager_pb2_grpc.NodeManagerServiceStub(channel)
111
+ node_stats = stub.GetNodeStats(
112
+ node_manager_pb2.GetNodeStatsRequest(include_memory_info=include_memory_info),
113
+ timeout=30.0,
114
+ )
115
+ return node_stats
116
+
117
+
118
+ def store_stats_summary(reply):
119
+ """Returns formatted string describing object store stats in all nodes."""
120
+ store_summary = "--- Aggregate object store stats across all nodes ---\n"
121
+ # TODO(ekl) it would be nice if we could provide a full memory usage
122
+ # breakdown by type (e.g., pinned by worker, primary, etc.)
123
+ store_summary += (
124
+ "Plasma memory usage {} MiB, {} objects, {}% full, {}% "
125
+ "needed\n".format(
126
+ int(reply.store_stats.object_store_bytes_used / (1024 * 1024)),
127
+ reply.store_stats.num_local_objects,
128
+ round(
129
+ 100
130
+ * reply.store_stats.object_store_bytes_used
131
+ / reply.store_stats.object_store_bytes_avail,
132
+ 2,
133
+ ),
134
+ round(
135
+ 100
136
+ * reply.store_stats.object_store_bytes_primary_copy
137
+ / reply.store_stats.object_store_bytes_avail,
138
+ 2,
139
+ ),
140
+ )
141
+ )
142
+ if reply.store_stats.object_store_bytes_fallback > 0:
143
+ store_summary += "Plasma filesystem mmap usage: {} MiB\n".format(
144
+ int(reply.store_stats.object_store_bytes_fallback / (1024 * 1024))
145
+ )
146
+ if reply.store_stats.spill_time_total_s > 0:
147
+ store_summary += (
148
+ "Spilled {} MiB, {} objects, avg write throughput {} MiB/s\n".format(
149
+ int(reply.store_stats.spilled_bytes_total / (1024 * 1024)),
150
+ reply.store_stats.spilled_objects_total,
151
+ int(
152
+ reply.store_stats.spilled_bytes_total
153
+ / (1024 * 1024)
154
+ / reply.store_stats.spill_time_total_s
155
+ ),
156
+ )
157
+ )
158
+ if reply.store_stats.restore_time_total_s > 0:
159
+ store_summary += (
160
+ "Restored {} MiB, {} objects, avg read throughput {} MiB/s\n".format(
161
+ int(reply.store_stats.restored_bytes_total / (1024 * 1024)),
162
+ reply.store_stats.restored_objects_total,
163
+ int(
164
+ reply.store_stats.restored_bytes_total
165
+ / (1024 * 1024)
166
+ / reply.store_stats.restore_time_total_s
167
+ ),
168
+ )
169
+ )
170
+ if reply.store_stats.consumed_bytes > 0:
171
+ store_summary += "Objects consumed by Ray tasks: {} MiB.\n".format(
172
+ int(reply.store_stats.consumed_bytes / (1024 * 1024))
173
+ )
174
+ if reply.store_stats.object_pulls_queued:
175
+ store_summary += "Object fetches queued, waiting for available memory."
176
+
177
+ return store_summary
178
+
179
+
180
+ def free(object_refs: list, local_only: bool = False):
181
+ """Free a list of IDs from the in-process and plasma object stores.
182
+
183
+ This function is a low-level API which should be used in restricted
184
+ scenarios.
185
+
186
+ If local_only is false, the request will be sent to all object stores.
187
+
188
+ This method will not return any value to indicate whether the deletion is
189
+ successful or not. This function is an instruction to the object store. If
190
+ some of the objects are in use, the object stores will delete them later
191
+ when the ref count is down to 0.
192
+
193
+ Examples:
194
+
195
+ .. testcode::
196
+
197
+ import ray
198
+
199
+ @ray.remote
200
+ def f():
201
+ return 0
202
+
203
+ obj_ref = f.remote()
204
+ ray.get(obj_ref) # wait for object to be created first
205
+ free([obj_ref]) # unpin & delete object globally
206
+
207
+ Args:
208
+ object_refs (List[ObjectRef]): List of object refs to delete.
209
+ local_only: Whether only deleting the list of objects in local
210
+ object store or all object stores.
211
+ """
212
+ worker = ray._private.worker.global_worker
213
+
214
+ if isinstance(object_refs, ray.ObjectRef):
215
+ object_refs = [object_refs]
216
+
217
+ if not isinstance(object_refs, list):
218
+ raise TypeError(
219
+ "free() expects a list of ObjectRef, got {}".format(type(object_refs))
220
+ )
221
+
222
+ # Make sure that the values are object refs.
223
+ for object_ref in object_refs:
224
+ if not isinstance(object_ref, ray.ObjectRef):
225
+ raise TypeError(
226
+ "Attempting to call `free` on the value {}, "
227
+ "which is not an ray.ObjectRef.".format(object_ref)
228
+ )
229
+
230
+ worker.check_connected()
231
+ with profiling.profile("ray.free"):
232
+ if len(object_refs) == 0:
233
+ return
234
+
235
+ worker.core_worker.free_objects(object_refs, local_only)
236
+
237
+
238
+ def get_local_ongoing_lineage_reconstruction_tasks() -> List[
239
+ Tuple[common_pb2.LineageReconstructionTask, int]
240
+ ]:
241
+ """Return the locally submitted ongoing retry tasks
242
+ triggered by lineage reconstruction.
243
+
244
+ NOTE: for the lineage reconstruction task status,
245
+ this method only returns the status known to the submitter
246
+ (i.e. it returns SUBMITTED_TO_WORKER instead of RUNNING).
247
+
248
+ The return type is a list of pairs where pair.first is the
249
+ lineage reconstruction task info and pair.second is the number
250
+ of ongoing lineage reconstruction tasks of this type.
251
+ """
252
+
253
+ worker = ray._private.worker.global_worker
254
+ worker.check_connected()
255
+ return worker.core_worker.get_local_ongoing_lineage_reconstruction_tasks()
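
A hedged sketch of how `memory_summary` above might be called against a running cluster; `address="auto"` assumes a cluster has already been started on this node:

    import ray
    from ray._private.internal_api import memory_summary

    ray.init(address="auto")  # assumes an existing local cluster
    # stats_only=True returns just the aggregate object store stats block.
    print(memory_summary(stats_only=True))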
.venv/lib/python3.11/site-packages/ray/_private/log.py ADDED
@@ -0,0 +1,117 @@
1
+ import logging
2
+ import threading
3
+ from typing import Union
4
+ import time
5
+
6
+ INTERNAL_TIMESTAMP_LOG_KEY = "_ray_timestamp_ns"
7
+
8
+
9
+ def _print_loggers():
10
+ """Print a formatted list of loggers and their handlers for debugging."""
11
+ loggers = {logging.root.name: logging.root}
12
+ loggers.update(dict(sorted(logging.root.manager.loggerDict.items())))
13
+ for name, logger in loggers.items():
14
+ if isinstance(logger, logging.Logger):
15
+ print(f" {name}: disabled={logger.disabled}, propagate={logger.propagate}")
16
+ for handler in logger.handlers:
17
+ print(f" {handler}")
18
+
19
+
20
+ def clear_logger(logger: Union[str, logging.Logger]):
21
+ """Reset a logger, clearing its handlers and enabling propagation.
22
+
23
+ Args:
24
+ logger: Logger to be cleared
25
+ """
26
+ if isinstance(logger, str):
27
+ logger = logging.getLogger(logger)
28
+ logger.propagate = True
29
+ logger.handlers.clear()
30
+
31
+
32
+ class PlainRayHandler(logging.StreamHandler):
33
+ """A plain log handler.
34
+
35
+ This handler writes to whatever sys.stderr points to at emit-time,
36
+ not at instantiation time. See docs for logging._StderrHandler.
37
+ """
38
+
39
+ def __init__(self):
40
+ super().__init__()
41
+ self.plain_handler = logging._StderrHandler()
42
+ self.plain_handler.level = self.level
43
+ self.plain_handler.formatter = logging.Formatter(fmt="%(message)s")
44
+
45
+ def emit(self, record: logging.LogRecord):
46
+ """Emit the log message.
47
+
48
+ If this is a worker, bypass fancy logging and just emit the log record.
49
+ If this is the driver, emit the message using the appropriate console handler.
50
+
51
+ Args:
52
+ record: Log record to be emitted
53
+ """
54
+ import ray
55
+
56
+ if (
57
+ hasattr(ray, "_private")
58
+ and hasattr(ray._private, "worker")
59
+ and ray._private.worker.global_worker.mode
60
+ == ray._private.worker.WORKER_MODE
61
+ ):
62
+ self.plain_handler.emit(record)
63
+ else:
64
+ logging._StderrHandler.emit(self, record)
65
+
66
+
67
+ logger_initialized = False
68
+ logging_config_lock = threading.Lock()
69
+
70
+
71
+ def _setup_log_record_factory():
72
+ """Setup log record factory to add _ray_timestamp_ns to LogRecord."""
73
+ old_factory = logging.getLogRecordFactory()
74
+
75
+ def record_factory(*args, **kwargs):
76
+ record = old_factory(*args, **kwargs)
77
+ # Starting with Python 3.13, the logging module uses `time.time_ns()` to
+ # generate `created`, avoiding the precision loss of the float type.
+ # Here we compute `created` for the LogRecord ourselves to support older
+ # Python versions as well.
81
+ ct = time.time_ns()
82
+ record.created = ct / 1e9
83
+
84
+ record.__dict__[INTERNAL_TIMESTAMP_LOG_KEY] = ct
85
+
86
+ return record
87
+
88
+ logging.setLogRecordFactory(record_factory)
89
+
90
+
91
+ def generate_logging_config():
92
+ """Generate the default Ray logging configuration."""
93
+ with logging_config_lock:
94
+ global logger_initialized
95
+ if logger_initialized:
96
+ return
97
+ logger_initialized = True
98
+
99
+ plain_formatter = logging.Formatter(
100
+ "%(asctime)s\t%(levelname)s %(filename)s:%(lineno)s -- %(message)s"
101
+ )
102
+
103
+ default_handler = PlainRayHandler()
104
+ default_handler.setFormatter(plain_formatter)
105
+
106
+ ray_logger = logging.getLogger("ray")
107
+ ray_logger.setLevel(logging.INFO)
108
+ ray_logger.addHandler(default_handler)
109
+ ray_logger.propagate = False
110
+
111
+ # Special handling for ray.rllib: only warning-level messages passed through
112
+ # See https://github.com/ray-project/ray/pull/31858 for related PR
113
+ rllib_logger = logging.getLogger("ray.rllib")
114
+ rllib_logger.setLevel(logging.WARN)
115
+
116
+ # Set up the LogRecord factory.
117
+ _setup_log_record_factory()
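
A small check of the record-factory behavior set up above: after `generate_logging_config()`, every `LogRecord` carries a nanosecond timestamp under `_ray_timestamp_ns`:

    import logging

    from ray._private.log import INTERNAL_TIMESTAMP_LOG_KEY, generate_logging_config

    generate_logging_config()
    record = logging.getLogger("ray").makeRecord(
        "ray", logging.INFO, __file__, 0, "hello", (), None
    )
    print(getattr(record, INTERNAL_TIMESTAMP_LOG_KEY))  # nanosecond timestamp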
.venv/lib/python3.11/site-packages/ray/_private/log_monitor.py ADDED
@@ -0,0 +1,581 @@
1
+ import argparse
2
+ import errno
3
+ import glob
4
+ import logging
5
+ import logging.handlers
6
+ import os
7
+ import platform
8
+ import re
9
+ import shutil
10
+ import time
11
+ import traceback
12
+ from typing import Callable, List, Optional, Set
13
+
14
+ from ray._raylet import GcsClient
15
+ import ray._private.ray_constants as ray_constants
16
+ import ray._private.services as services
17
+ import ray._private.utils
18
+ from ray._private.ray_logging import setup_component_logger
19
+
20
+ # Logger for this module. It should be configured at the entry point
21
+ # into the program using Ray. Ray provides a default configuration at
22
+ # entry/init points.
23
+ logger = logging.getLogger(__name__)
24
+
25
+ # The groups are job id, and pid.
26
+ WORKER_LOG_PATTERN = re.compile(".*worker.*-([0-9a-f]+)-(\d+)")
27
+ # The groups are job id.
28
+ RUNTIME_ENV_SETUP_PATTERN = re.compile(".*runtime_env_setup-(\d+).log")
29
+ # Log name update interval under pressure.
+ # We need it because updating log names is CPU intensive and can use
+ # 100% of a CPU when there are many log files.
32
+ LOG_NAME_UPDATE_INTERVAL_S = float(os.getenv("LOG_NAME_UPDATE_INTERVAL_S", 0.5))
33
+ # Once there are more files than this threshold, the log monitor starts
+ # applying backpressure to lower CPU usage.
35
+ RAY_LOG_MONITOR_MANY_FILES_THRESHOLD = int(
36
+ os.getenv("RAY_LOG_MONITOR_MANY_FILES_THRESHOLD", 1000)
37
+ )
38
+ RAY_RUNTIME_ENV_LOG_TO_DRIVER_ENABLED = int(
39
+ os.getenv("RAY_RUNTIME_ENV_LOG_TO_DRIVER_ENABLED", 0)
40
+ )
41
+
42
+
43
+ class LogFileInfo:
44
+ def __init__(
45
+ self,
46
+ filename=None,
47
+ size_when_last_opened=None,
48
+ file_position=None,
49
+ file_handle=None,
50
+ is_err_file=False,
51
+ job_id=None,
52
+ worker_pid=None,
53
+ ):
54
+ assert (
55
+ filename is not None
56
+ and size_when_last_opened is not None
57
+ and file_position is not None
58
+ )
59
+ self.filename = filename
60
+ self.size_when_last_opened = size_when_last_opened
61
+ self.file_position = file_position
62
+ self.file_handle = file_handle
63
+ self.is_err_file = is_err_file
64
+ self.job_id = job_id
65
+ self.worker_pid = worker_pid
66
+ self.actor_name = None
67
+ self.task_name = None
68
+
69
+ def reopen_if_necessary(self):
70
+ """Check if the file's inode has changed and reopen it if necessary.
71
+ There are a variety of reasons why what we would logically consider a
+ single file can end up with different inodes, such as log rotation or
+ file syncing semantics.
74
+ """
75
+ try:
76
+ open_inode = None
77
+ if self.file_handle and not self.file_handle.closed:
78
+ open_inode = os.fstat(self.file_handle.fileno()).st_ino
79
+
80
+ new_inode = os.stat(self.filename).st_ino
81
+ if open_inode != new_inode:
82
+ self.file_handle = open(self.filename, "rb")
83
+ self.file_handle.seek(self.file_position)
84
+ except Exception:
85
+ logger.debug(f"file no longer exists, skip re-opening of {self.filename}")
86
+
87
+ def __repr__(self):
88
+ return (
89
+ "FileInfo(\n"
90
+ f"\tfilename: {self.filename}\n"
91
+ f"\tsize_when_last_opened: {self.size_when_last_opened}\n"
92
+ f"\tfile_position: {self.file_position}\n"
93
+ f"\tfile_handle: {self.file_handle}\n"
94
+ f"\tis_err_file: {self.is_err_file}\n"
95
+ f"\tjob_id: {self.job_id}\n"
96
+ f"\tworker_pid: {self.worker_pid}\n"
97
+ f"\tactor_name: {self.actor_name}\n"
98
+ f"\ttask_name: {self.task_name}\n"
99
+ ")"
100
+ )
101
+
102
+
103
+ class LogMonitor:
104
+ """A monitor process for monitoring Ray log files.
105
+
106
+ This class maintains a list of open files and a list of closed log files. We
107
+ can't simply leave all files open because we'll run out of file
108
+ descriptors.
109
+
110
+ The "run" method of this class will cycle between doing several things:
111
+ 1. First, it will check if any new files have appeared in the log
112
+ directory. If so, they will be added to the list of closed files.
113
+ 2. Then, if we are unable to open any new files, we will close all of the
114
+ files.
115
+ 3. Then, we will open as many closed files as we can that may have new
116
+ lines (judged by an increase in file size since the last time the file
117
+ was opened).
118
+ 4. Then we will loop through the open files and see if there are any new
119
+ lines in the file. If so, we will publish them to Ray pubsub.
120
+
121
+ Attributes:
122
+ ip: The hostname of this machine, for grouping log messages.
123
+ logs_dir: The directory that the log files are in.
124
+ log_filenames: This is the set of filenames of all files in
125
+ open_file_infos and closed_file_infos.
126
+ open_file_infos (list[LogFileInfo]): Info for all of the open files.
127
+ closed_file_infos (list[LogFileInfo]): Info for all of the closed
128
+ files.
129
+ can_open_more_files: True if we can still open more files and
130
+ false otherwise.
131
+ max_files_open: The maximum number of files that can be open.
132
+ """
133
+
134
+ def __init__(
135
+ self,
136
+ node_ip_address: str,
137
+ logs_dir: str,
138
+ gcs_publisher: ray._raylet.GcsPublisher,
139
+ is_proc_alive_fn: Callable[[int], bool],
140
+ max_files_open: int = ray_constants.LOG_MONITOR_MAX_OPEN_FILES,
141
+ gcs_address: Optional[str] = None,
142
+ ):
143
+ """Initialize the log monitor object."""
144
+ self.ip: str = node_ip_address
145
+ self.logs_dir: str = logs_dir
146
+ self.publisher = gcs_publisher
147
+ self.log_filenames: Set[str] = set()
148
+ self.open_file_infos: List[LogFileInfo] = []
149
+ self.closed_file_infos: List[LogFileInfo] = []
150
+ self.can_open_more_files: bool = True
151
+ self.max_files_open: int = max_files_open
152
+ self.is_proc_alive_fn: Callable[[int], bool] = is_proc_alive_fn
153
+ self.is_autoscaler_v2: bool = self.get_is_autoscaler_v2(gcs_address)
154
+
155
+ logger.info(
156
+ f"Starting log monitor with [max open files={max_files_open}],"
157
+ f" [is_autoscaler_v2={self.is_autoscaler_v2}]"
158
+ )
159
+
160
+ def get_is_autoscaler_v2(self, gcs_address: Optional[str]) -> bool:
161
+ """Check if autoscaler v2 is enabled."""
162
+ if gcs_address is None:
163
+ return False
164
+
165
+ if not ray.experimental.internal_kv._internal_kv_initialized():
166
+ gcs_client = GcsClient(address=gcs_address)
167
+ ray.experimental.internal_kv._initialize_internal_kv(gcs_client)
168
+ from ray.autoscaler.v2.utils import is_autoscaler_v2
169
+
170
+ return is_autoscaler_v2()
171
+
172
+ def _close_all_files(self):
173
+ """Close all open files (so that we can open more)."""
174
+ while len(self.open_file_infos) > 0:
175
+ file_info = self.open_file_infos.pop(0)
176
+ file_info.file_handle.close()
177
+ file_info.file_handle = None
178
+
179
+ proc_alive = True
180
+ # Test if the worker process that generated the log file
181
+ # is still alive. Only applies to worker processes.
182
+ # For all other system components, we always assume they are alive.
183
+ if (
184
+ file_info.worker_pid != "raylet"
185
+ and file_info.worker_pid != "gcs_server"
186
+ and file_info.worker_pid != "autoscaler"
187
+ and file_info.worker_pid != "runtime_env"
188
+ and file_info.worker_pid is not None
189
+ ):
190
+ assert not isinstance(file_info.worker_pid, str), (
191
+ "PID should be an int type. " f"Given PID: {file_info.worker_pid}."
192
+ )
193
+ proc_alive = self.is_proc_alive_fn(file_info.worker_pid)
194
+ if not proc_alive:
195
+ # The process is not alive any more, so move the log file
196
+ # out of the log directory so glob.glob will not be slowed
197
+ # by it.
198
+ target = os.path.join(
199
+ self.logs_dir, "old", os.path.basename(file_info.filename)
200
+ )
201
+ try:
202
+ shutil.move(file_info.filename, target)
203
+ except (IOError, OSError) as e:
204
+ if e.errno == errno.ENOENT:
205
+ logger.warning(
206
+ f"Warning: The file {file_info.filename} was not found."
207
+ )
208
+ else:
209
+ raise e
210
+
211
+ if proc_alive:
212
+ self.closed_file_infos.append(file_info)
213
+
214
+ self.can_open_more_files = True
215
+
216
+ def update_log_filenames(self):
217
+ """Update the list of log files to monitor."""
218
+ monitor_log_paths = []
219
+ # output of user code is written here
220
+ monitor_log_paths += glob.glob(
221
+ f"{self.logs_dir}/worker*[.out|.err]"
222
+ ) + glob.glob(f"{self.logs_dir}/java-worker*.log")
223
+ # segfaults and other serious errors are logged here
224
+ monitor_log_paths += glob.glob(f"{self.logs_dir}/raylet*.err")
225
+ # monitor logs are needed to report autoscaler events
226
+ # TODO(rickyx): remove this after migration.
227
+ if not self.is_autoscaler_v2:
228
+ # We publish monitor logs in autoscaler v1
229
+ monitor_log_paths += glob.glob(f"{self.logs_dir}/monitor.log")
230
+ else:
231
+ # We publish autoscaler events directly in autoscaler v2
232
+ monitor_log_paths += glob.glob(
233
+ f"{self.logs_dir}/events/event_AUTOSCALER.log"
234
+ )
235
+
236
+ # If gcs server restarts, there can be multiple log files.
237
+ monitor_log_paths += glob.glob(f"{self.logs_dir}/gcs_server*.err")
238
+
239
+ # runtime_env setup process is logged here
240
+ if RAY_RUNTIME_ENV_LOG_TO_DRIVER_ENABLED:
241
+ monitor_log_paths += glob.glob(f"{self.logs_dir}/runtime_env*.log")
242
+ for file_path in monitor_log_paths:
243
+ if os.path.isfile(file_path) and file_path not in self.log_filenames:
244
+ worker_match = WORKER_LOG_PATTERN.match(file_path)
245
+ if worker_match:
246
+ worker_pid = int(worker_match.group(2))
247
+ else:
248
+ worker_pid = None
249
+ job_id = None
250
+
251
+ # Perform the substring check first because most files will not
+ # include runtime_env. This saves some CPU cycles.
253
+ if "runtime_env" in file_path:
254
+ runtime_env_job_match = RUNTIME_ENV_SETUP_PATTERN.match(file_path)
255
+ if runtime_env_job_match:
256
+ job_id = runtime_env_job_match.group(1)
257
+
258
+ is_err_file = file_path.endswith("err")
259
+
260
+ self.log_filenames.add(file_path)
261
+ self.closed_file_infos.append(
262
+ LogFileInfo(
263
+ filename=file_path,
264
+ size_when_last_opened=0,
265
+ file_position=0,
266
+ file_handle=None,
267
+ is_err_file=is_err_file,
268
+ job_id=job_id,
269
+ worker_pid=worker_pid,
270
+ )
271
+ )
272
+ log_filename = os.path.basename(file_path)
273
+ logger.info(f"Beginning to track file {log_filename}")
274
+
275
+ def open_closed_files(self):
276
+ """Open some closed files if they may have new lines.
277
+
278
+ Opening more files may require us to close some of the already open
279
+ files.
280
+ """
281
+ if not self.can_open_more_files:
282
+ # If we can't open any more files. Close all of the files.
283
+ self._close_all_files()
284
+
285
+ files_with_no_updates = []
286
+ while len(self.closed_file_infos) > 0:
287
+ if len(self.open_file_infos) >= self.max_files_open:
288
+ self.can_open_more_files = False
289
+ break
290
+
291
+ file_info = self.closed_file_infos.pop(0)
292
+ assert file_info.file_handle is None
293
+ # Get the file size to see if it has gotten bigger since we last
294
+ # opened it.
295
+ try:
296
+ file_size = os.path.getsize(file_info.filename)
297
+ except (IOError, OSError) as e:
298
+ # Catch "file not found" errors.
299
+ if e.errno == errno.ENOENT:
300
+ logger.warning(
301
+ f"Warning: The file {file_info.filename} was not found."
302
+ )
303
+ self.log_filenames.remove(file_info.filename)
304
+ continue
305
+ raise e
306
+
307
+ # If some new lines have been added to this file, try to reopen the
308
+ # file.
309
+ if file_size > file_info.size_when_last_opened:
310
+ try:
311
+ f = open(file_info.filename, "rb")
312
+ except (IOError, OSError) as e:
313
+ if e.errno == errno.ENOENT:
314
+ logger.warning(
315
+ f"Warning: The file {file_info.filename} was not found."
316
+ )
317
+ self.log_filenames.remove(file_info.filename)
318
+ continue
319
+ else:
320
+ raise e
321
+
322
+ f.seek(file_info.file_position)
323
+ file_info.size_when_last_opened = file_size
324
+ file_info.file_handle = f
325
+ self.open_file_infos.append(file_info)
326
+ else:
327
+ files_with_no_updates.append(file_info)
328
+
329
+ if len(self.open_file_infos) >= self.max_files_open:
330
+ self.can_open_more_files = False
331
+ # Add the files with no changes back to the list of closed files.
332
+ self.closed_file_infos += files_with_no_updates
333
+
334
+ def check_log_files_and_publish_updates(self):
335
+ """Gets updates to the log files and publishes them.
336
+
337
+ Returns:
338
+ True if anything was published and false otherwise.
339
+ """
340
+ anything_published = False
341
+ lines_to_publish = []
342
+
343
+ def flush():
344
+ nonlocal lines_to_publish
345
+ nonlocal anything_published
346
+ if len(lines_to_publish) > 0:
347
+ data = {
348
+ "ip": self.ip,
349
+ "pid": file_info.worker_pid,
350
+ "job": file_info.job_id,
351
+ "is_err": file_info.is_err_file,
352
+ "lines": lines_to_publish,
353
+ "actor_name": file_info.actor_name,
354
+ "task_name": file_info.task_name,
355
+ }
356
+ try:
357
+ self.publisher.publish_logs(data)
358
+ except Exception:
359
+ logger.exception(f"Failed to publish log messages {data}")
360
+ anything_published = True
361
+ lines_to_publish = []
362
+
363
+ for file_info in self.open_file_infos:
364
+ assert not file_info.file_handle.closed
365
+ file_info.reopen_if_necessary()
366
+
367
+ max_num_lines_to_read = ray_constants.LOG_MONITOR_NUM_LINES_TO_READ
368
+ for _ in range(max_num_lines_to_read):
369
+ try:
370
+ next_line = file_info.file_handle.readline()
371
+ # Replace any characters not in UTF-8 with
372
+ # a replacement character, see
373
+ # https://stackoverflow.com/a/38565489/10891801
374
+ next_line = next_line.decode("utf-8", "replace")
375
+ if next_line == "":
376
+ break
377
+ next_line = next_line.rstrip("\r\n")
378
+
379
+ if next_line.startswith(ray_constants.LOG_PREFIX_ACTOR_NAME):
380
+ flush() # Possible change of task/actor name.
381
+ file_info.actor_name = next_line.split(
382
+ ray_constants.LOG_PREFIX_ACTOR_NAME, 1
383
+ )[1]
384
+ file_info.task_name = None
385
+ elif next_line.startswith(ray_constants.LOG_PREFIX_TASK_NAME):
386
+ flush() # Possible change of task/actor name.
387
+ file_info.task_name = next_line.split(
388
+ ray_constants.LOG_PREFIX_TASK_NAME, 1
389
+ )[1]
390
+ elif next_line.startswith(ray_constants.LOG_PREFIX_JOB_ID):
391
+ file_info.job_id = next_line.split(
392
+ ray_constants.LOG_PREFIX_JOB_ID, 1
393
+ )[1]
394
+ elif next_line.startswith(
395
+ "Windows fatal exception: access violation"
396
+ ):
397
+ # We are suppressing the
398
+ # 'Windows fatal exception: access violation'
399
+ # message on workers on Windows here.
400
+ # As far as we know it is harmless,
401
+ # but is frequently popping up if Python
402
+ # functions are run inside the core
403
+ # worker C extension. See the investigation in
404
+ # github.com/ray-project/ray/issues/18944
405
+ # Also skip the following line, which is an
406
+ # empty line.
407
+ file_info.file_handle.readline()
408
+ else:
409
+ lines_to_publish.append(next_line)
410
+ except Exception:
411
+ logger.error(
412
+ f"Error: Reading file: {file_info.filename}, "
413
+ f"position: {file_info.file_info.file_handle.tell()} "
414
+ "failed."
415
+ )
416
+ raise
417
+
418
+ if file_info.file_position == 0:
419
+ # make filename windows-agnostic
420
+ filename = file_info.filename.replace("\\", "/")
421
+ if "/raylet" in filename:
422
+ file_info.worker_pid = "raylet"
423
+ elif "/gcs_server" in filename:
424
+ file_info.worker_pid = "gcs_server"
425
+ elif "/monitor" in filename or "event_AUTOSCALER" in filename:
426
+ file_info.worker_pid = "autoscaler"
427
+ elif "/runtime_env" in filename:
428
+ file_info.worker_pid = "runtime_env"
429
+
430
+ # Record the current position in the file.
431
+ file_info.file_position = file_info.file_handle.tell()
432
+ flush()
433
+
434
+ return anything_published
435
+
436
+ def should_update_filenames(self, last_file_updated_time: float) -> bool:
437
+ """Return true if filenames should be updated.
438
+
439
+ This method is used to apply the backpressure on file updates because
440
+ that requires heavy glob operations which use lots of CPUs.
441
+
442
+ Args:
443
+ last_file_updated_time: The last time filenames were updated.
444
+
445
+ Returns:
446
+ True if filenames should be updated. False otherwise.
447
+ """
448
+ elapsed_seconds = float(time.time() - last_file_updated_time)
449
+ return (
450
+ len(self.log_filenames) < RAY_LOG_MONITOR_MANY_FILES_THRESHOLD
451
+ or elapsed_seconds > LOG_NAME_UPDATE_INTERVAL_S
452
+ )
453
+
454
+ def run(self):
455
+ """Run the log monitor.
456
+
457
+ This will scan the file system once every LOG_NAME_UPDATE_INTERVAL_S to
458
+ check if there are new log files to monitor. It will also publish new
459
+ log lines.
460
+ """
461
+ last_updated = time.time()
462
+ while True:
463
+ if self.should_update_filenames(last_updated):
464
+ self.update_log_filenames()
465
+ last_updated = time.time()
466
+
467
+ self.open_closed_files()
468
+ anything_published = self.check_log_files_and_publish_updates()
469
+ # If nothing was published, then wait a little bit before checking
470
+ # for logs to avoid using too much CPU.
471
+ if not anything_published:
472
+ time.sleep(0.1)
473
+
474
+
475
+ def is_proc_alive(pid):
476
+ # Import locally to make sure the bundled version is used if needed
477
+ import psutil
478
+
479
+ try:
480
+ return psutil.Process(pid).is_running()
481
+ except psutil.NoSuchProcess:
482
+ # The process does not exist.
483
+ return False
484
+
485
+
486
+ if __name__ == "__main__":
487
+ parser = argparse.ArgumentParser(
488
+ description=("Parse GCS server address for the log monitor to connect to.")
489
+ )
490
+ parser.add_argument(
491
+ "--gcs-address", required=False, type=str, help="The address (ip:port) of GCS."
492
+ )
493
+ parser.add_argument(
494
+ "--logging-level",
495
+ required=False,
496
+ type=str,
497
+ default=ray_constants.LOGGER_LEVEL,
498
+ choices=ray_constants.LOGGER_LEVEL_CHOICES,
499
+ help=ray_constants.LOGGER_LEVEL_HELP,
500
+ )
501
+ parser.add_argument(
502
+ "--logging-format",
503
+ required=False,
504
+ type=str,
505
+ default=ray_constants.LOGGER_FORMAT,
506
+ help=ray_constants.LOGGER_FORMAT_HELP,
507
+ )
508
+ parser.add_argument(
509
+ "--logging-filename",
510
+ required=False,
511
+ type=str,
512
+ default=ray_constants.LOG_MONITOR_LOG_FILE_NAME,
513
+ help="Specify the name of log file, "
514
+ "log to stdout if set empty, default is "
515
+ f'"{ray_constants.LOG_MONITOR_LOG_FILE_NAME}"',
516
+ )
517
+ parser.add_argument(
518
+ "--session-dir",
519
+ required=True,
520
+ type=str,
521
+ help="Specify the path of the session directory used by Ray processes.",
522
+ )
523
+ parser.add_argument(
524
+ "--logs-dir",
525
+ required=True,
526
+ type=str,
527
+ help="Specify the path of the log directory used by Ray processes.",
528
+ )
529
+ parser.add_argument(
530
+ "--logging-rotate-bytes",
531
+ required=False,
532
+ type=int,
533
+ default=ray_constants.LOGGING_ROTATE_BYTES,
534
+ help="Specify the max bytes for rotating "
535
+ "log file, default is "
536
+ f"{ray_constants.LOGGING_ROTATE_BYTES} bytes.",
537
+ )
538
+ parser.add_argument(
539
+ "--logging-rotate-backup-count",
540
+ required=False,
541
+ type=int,
542
+ default=ray_constants.LOGGING_ROTATE_BACKUP_COUNT,
543
+ help="Specify the backup count of rotated log file, default is "
544
+ f"{ray_constants.LOGGING_ROTATE_BACKUP_COUNT}.",
545
+ )
546
+ args = parser.parse_args()
547
+ setup_component_logger(
548
+ logging_level=args.logging_level,
549
+ logging_format=args.logging_format,
550
+ log_dir=args.logs_dir,
551
+ filename=args.logging_filename,
552
+ max_bytes=args.logging_rotate_bytes,
553
+ backup_count=args.logging_rotate_backup_count,
554
+ )
555
+
556
+ node_ip = services.get_cached_node_ip_address(args.session_dir)
557
+ log_monitor = LogMonitor(
558
+ node_ip,
559
+ args.logs_dir,
560
+ ray._raylet.GcsPublisher(address=args.gcs_address),
561
+ is_proc_alive,
562
+ gcs_address=args.gcs_address,
563
+ )
564
+
565
+ try:
566
+ log_monitor.run()
567
+ except Exception as e:
568
+ # Something went wrong, so push an error to all drivers.
569
+ gcs_publisher = ray._raylet.GcsPublisher(address=args.gcs_address)
570
+ traceback_str = ray._private.utils.format_error_message(traceback.format_exc())
571
+ message = (
572
+ f"The log monitor on node {platform.node()} "
573
+ f"failed with the following error:\n{traceback_str}"
574
+ )
575
+ ray._private.utils.publish_error_to_driver(
576
+ ray_constants.LOG_MONITOR_DIED_ERROR,
577
+ message,
578
+ gcs_publisher=gcs_publisher,
579
+ )
580
+ logger.error(message)
581
+ raise e
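
An illustration of `WORKER_LOG_PATTERN` above. The path is made up but follows the `worker-<worker id>-<job id>-<pid>.out` convention; as documented, group 1 is the job id and group 2 is the pid:

    from ray._private.log_monitor import WORKER_LOG_PATTERN

    path = "/tmp/ray/session_latest/logs/worker-0123abcd-01000000-4242.out"
    match = WORKER_LOG_PATTERN.match(path)
    job_id, worker_pid = match.group(1), int(match.group(2))
    print(job_id, worker_pid)  # 01000000 4242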
.venv/lib/python3.11/site-packages/ray/_private/logging_utils.py ADDED
@@ -0,0 +1,29 @@
1
+ from ray.core.generated.logging_pb2 import LogBatch
2
+
3
+
4
+ def log_batch_dict_to_proto(log_json: dict) -> LogBatch:
5
+ """Converts a dict containing a batch of logs to a LogBatch proto."""
6
+ return LogBatch(
7
+ ip=log_json.get("ip"),
8
+ # Cast to support string pid like "gcs".
9
+ pid=str(log_json.get("pid")) if log_json.get("pid") else None,
10
+ # Job ID as a hex string.
11
+ job_id=log_json.get("job"),
12
+ is_error=bool(log_json.get("is_err")),
13
+ lines=log_json.get("lines"),
14
+ actor_name=log_json.get("actor_name"),
15
+ task_name=log_json.get("task_name"),
16
+ )
17
+
18
+
19
+ def log_batch_proto_to_dict(log_batch: LogBatch) -> dict:
20
+ """Converts a LogBatch proto to a dict containing a batch of logs."""
21
+ return {
22
+ "ip": log_batch.ip,
23
+ "pid": log_batch.pid,
24
+ "job": log_batch.job_id,
25
+ "is_err": log_batch.is_error,
26
+ "lines": log_batch.lines,
27
+ "actor_name": log_batch.actor_name,
28
+ "task_name": log_batch.task_name,
29
+ }
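The two helpers above are intended to be inverses of each other. A minimal round-trip sketch, assuming Ray and its generated protobuf modules are importable; the field values are invented for illustration:

from ray._private.logging_utils import (
    log_batch_dict_to_proto,
    log_batch_proto_to_dict,
)

# A hypothetical batch of log lines, shaped like what the log monitor publishes.
batch = {
    "ip": "127.0.0.1",
    "pid": 12345,
    "job": "0100",
    "is_err": False,
    "lines": ["hello from a worker"],
    "actor_name": None,
    "task_name": None,
}

proto = log_batch_dict_to_proto(batch)
roundtrip = log_batch_proto_to_dict(proto)
# pid is cast to a string on the way in, so it comes back as "12345".
print(roundtrip["pid"], list(roundtrip["lines"]))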
.venv/lib/python3.11/site-packages/ray/_private/memory_monitor.py ADDED
@@ -0,0 +1,162 @@
1
+ import logging
2
+ import os
3
+ import platform
4
+ import sys
5
+ import time
6
+
7
+ # Importing ray before psutil ensures we use the psutil version bundled with Ray
8
+ import ray # noqa F401
9
+ import psutil # noqa E402
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
+ def get_rss(memory_info):
15
+ """Get the estimated non-shared memory usage from psutil memory_info."""
16
+ mem = memory_info.rss
17
+ # OSX doesn't have the shared attribute
18
+ if hasattr(memory_info, "shared"):
19
+ mem -= memory_info.shared
20
+ return mem
21
+
22
+
23
+ def get_shared(virtual_memory):
24
+ """Get the estimated shared memory usage from psutil virtual mem info."""
25
+ # OSX doesn't have the shared attribute
26
+ if hasattr(virtual_memory, "shared"):
27
+ return virtual_memory.shared
28
+ else:
29
+ return 0
30
+
31
+
32
+ def get_top_n_memory_usage(n: int = 10):
33
+ """Get the top n processes by memory usage.
34
+
35
+ Params:
36
+ n: Number of top processes to include.
37
+ Returns:
38
+ (str) The formatted string of top n process memory usage.
39
+ """
40
+ pids = psutil.pids()
41
+ proc_stats = []
42
+ for pid in pids:
43
+ try:
44
+ proc = psutil.Process(pid)
45
+ proc_stats.append((get_rss(proc.memory_info()), pid, proc.cmdline()))
46
+ except psutil.NoSuchProcess:
47
+ # We should skip processes that have exited. Refer to this
48
+ # issue for more detail:
49
+ # https://github.com/ray-project/ray/issues/14929
50
+ continue
51
+ except psutil.AccessDenied:
52
+ # On MacOS, the proc_pidinfo call (used to get per-process
53
+ # memory info) fails with a permission denied error when used
54
+ # on a process that isn’t owned by the same user. For now, we
55
+ # drop the memory info of any such process, assuming that
56
+ # processes owned by other users (e.g. root) aren't Ray
57
+ # processes and will be of less interest when an OOM happens
58
+ # on a Ray node.
59
+ # See issue for more detail:
60
+ # https://github.com/ray-project/ray/issues/11845#issuecomment-849904019 # noqa: E501
61
+ continue
62
+ proc_str = "PID\tMEM\tCOMMAND"
63
+ for rss, pid, cmdline in sorted(proc_stats, reverse=True)[:n]:
64
+ proc_str += "\n{}\t{}GiB\t{}".format(
65
+ pid, round(rss / (1024**3), 2), " ".join(cmdline)[:100].strip()
66
+ )
67
+ return proc_str
68
+
69
+
70
+ class RayOutOfMemoryError(Exception):
71
+ def __init__(self, msg):
72
+ Exception.__init__(self, msg)
73
+
74
+ @staticmethod
75
+ def get_message(used_gb, total_gb, threshold):
76
+ proc_str = get_top_n_memory_usage(n=10)
77
+ return (
78
+ "More than {}% of the memory on ".format(int(100 * threshold))
79
+ + "node {} is used ({} / {} GB). ".format(
80
+ platform.node(), round(used_gb, 2), round(total_gb, 2)
81
+ )
82
+ + f"The top 10 memory consumers are:\n\n{proc_str}"
83
+ + "\n\nIn addition, up to {} GiB of shared memory is ".format(
84
+ round(get_shared(psutil.virtual_memory()) / (1024**3), 2)
85
+ )
86
+ + "currently being used by the Ray object store.\n---\n"
87
+ "--- Tip: Use the `ray memory` command to list active "
88
+ "objects in the cluster.\n"
89
+ "--- To disable OOM exceptions, set "
90
+ "RAY_DISABLE_MEMORY_MONITOR=1.\n---\n"
91
+ )
92
+
93
+
94
+ class MemoryMonitor:
95
+ """Helper class for raising errors on low memory.
96
+
97
+ This presents a much cleaner error message to users than what would happen
98
+ if we actually ran out of memory.
99
+
100
+ The monitor tries to use the cgroup memory limit and usage if it is set
101
+ and available so that it is more reasonable inside containers. Otherwise,
102
+ it uses `psutil` to check the memory usage.
103
+
104
+ The environment variable `RAY_MEMORY_MONITOR_ERROR_THRESHOLD` can be used
105
+ to overwrite the default error_threshold setting.
106
+
107
+ Used by test only. For production code use memory_monitor.cc
108
+ """
109
+
110
+ def __init__(self, error_threshold=0.95, check_interval=1):
111
+ # Note: it takes ~50us to check the memory usage through psutil, so
112
+ # throttle this check at most once a second or so.
113
+ self.check_interval = check_interval
114
+ self.last_checked = 0
115
+ try:
116
+ self.error_threshold = float(
117
+ os.getenv("RAY_MEMORY_MONITOR_ERROR_THRESHOLD")
118
+ )
119
+ except (ValueError, TypeError):
120
+ self.error_threshold = error_threshold
121
+ # Try to read the cgroup memory limit if it is available.
122
+ try:
123
+ with open("/sys/fs/cgroup/memory/memory.limit_in_bytes", "rb") as f:
124
+ self.cgroup_memory_limit_gb = int(f.read()) / (1024**3)
125
+ except IOError:
126
+ self.cgroup_memory_limit_gb = sys.maxsize / (1024**3)
127
+ if not psutil:
128
+ logger.warning(
129
+ "WARNING: Not monitoring node memory since `psutil` "
130
+ "is not installed. Install this with "
131
+ "`pip install psutil` to enable "
132
+ "debugging of memory-related crashes."
133
+ )
134
+ self.disabled = (
135
+ "RAY_DEBUG_DISABLE_MEMORY_MONITOR" in os.environ
136
+ or "RAY_DISABLE_MEMORY_MONITOR" in os.environ
137
+ )
138
+
139
+ def get_memory_usage(self):
140
+ from ray._private.utils import get_system_memory, get_used_memory
141
+
142
+ total_gb = get_system_memory() / (1024**3)
143
+ used_gb = get_used_memory() / (1024**3)
144
+
145
+ return used_gb, total_gb
146
+
147
+ def raise_if_low_memory(self):
148
+ if self.disabled:
149
+ return
150
+
151
+ if time.time() - self.last_checked > self.check_interval:
152
+ self.last_checked = time.time()
153
+ used_gb, total_gb = self.get_memory_usage()
154
+
155
+ if used_gb > total_gb * self.error_threshold:
156
+ raise RayOutOfMemoryError(
157
+ RayOutOfMemoryError.get_message(
158
+ used_gb, total_gb, self.error_threshold
159
+ )
160
+ )
161
+ else:
162
+ logger.debug(f"Memory usage is {used_gb} / {total_gb}")
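As the class docstring notes, MemoryMonitor is only used by tests; a minimal usage sketch under that assumption (the threshold below is arbitrary):

from ray._private.memory_monitor import MemoryMonitor, RayOutOfMemoryError

# Raise once used memory exceeds 90% of total; checks are throttled to at most
# one per `check_interval` seconds.
monitor = MemoryMonitor(error_threshold=0.9, check_interval=1)

try:
    monitor.raise_if_low_memory()
except RayOutOfMemoryError as e:
    # The message embeds the top-10 memory consumers on the node.
    print(e)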
.venv/lib/python3.11/site-packages/ray/_private/metrics_agent.py ADDED
@@ -0,0 +1,675 @@
1
+ import json
2
+ import logging
3
+ import os
4
+ import re
5
+ import threading
6
+ import time
7
+ import traceback
8
+ from collections import namedtuple
9
+ from typing import List, Tuple, Any, Dict, Set
10
+
11
+ from prometheus_client.core import (
12
+ CounterMetricFamily,
13
+ GaugeMetricFamily,
14
+ HistogramMetricFamily,
15
+ )
16
+ from opencensus.metrics.export.value import ValueDouble
17
+ from opencensus.metrics.export.metric_descriptor import MetricDescriptorType
18
+ from opencensus.stats import aggregation
19
+ from opencensus.stats import measure as measure_module
20
+ from opencensus.stats.view_manager import ViewManager
21
+ from opencensus.stats.stats_recorder import StatsRecorder
22
+ from opencensus.stats.base_exporter import StatsExporter
23
+ from prometheus_client.core import Metric as PrometheusMetric
24
+ from opencensus.stats.aggregation_data import (
25
+ CountAggregationData,
26
+ DistributionAggregationData,
27
+ LastValueAggregationData,
28
+ SumAggregationData,
29
+ )
30
+ from opencensus.stats.view import View
31
+ from opencensus.tags import tag_key as tag_key_module
32
+ from opencensus.tags import tag_map as tag_map_module
33
+ from opencensus.tags import tag_value as tag_value_module
34
+
35
+ import ray
36
+ from ray._raylet import GcsClient
37
+
38
+ from ray.core.generated.metrics_pb2 import Metric
39
+ from ray._private.ray_constants import env_bool
40
+
41
+ logger = logging.getLogger(__name__)
42
+
43
+ # Env var key to decide worker timeout.
44
+ # If the worker doesn't report for more than
45
+ # this time, we treat workers as dead.
46
+ RAY_WORKER_TIMEOUT_S = "RAY_WORKER_TIMEOUT_S"
47
+ GLOBAL_COMPONENT_KEY = "CORE"
48
+ RE_NON_ALPHANUMS = re.compile(r"[^a-zA-Z0-9]")
49
+
50
+
51
+ class Gauge(View):
52
+ """Gauge representation of opencensus view.
53
+
54
+ This class is used to collect process metrics from the reporter agent.
55
+ Cpp metrics should be collected in a different way.
56
+ """
57
+
58
+ def __init__(self, name, description, unit, tags: List[str]):
59
+ self._measure = measure_module.MeasureInt(name, description, unit)
60
+ tags = [tag_key_module.TagKey(tag) for tag in tags]
61
+ self._view = View(
62
+ name, description, tags, self.measure, aggregation.LastValueAggregation()
63
+ )
64
+
65
+ @property
66
+ def measure(self):
67
+ return self._measure
68
+
69
+ @property
70
+ def view(self):
71
+ return self._view
72
+
73
+ @property
74
+ def name(self):
75
+ return self.measure.name
76
+
77
+
78
+ Record = namedtuple("Record", ["gauge", "value", "tags"])
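`Gauge` wraps an OpenCensus measure/view pair, and `Record` is the unit consumed by `MetricsAgent.record_and_export` further below. A hedged construction sketch; the metric name, unit, and tag key are invented for illustration:

from ray._private.metrics_agent import Gauge, Record

gauge = Gauge(
    name="example_queue_size",
    description="Number of queued items (illustrative only).",
    unit="items",
    tags=["Component"],
)
record = Record(gauge=gauge, value=42.0, tags={"Component": "demo"})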
79
+
80
+
81
+ def fix_grpc_metric(metric: Metric):
82
+ """
83
+ Fix the inbound `opencensus.proto.metrics.v1.Metric` protos to make it acceptable
84
+ by opencensus.stats.DistributionAggregationData.
85
+
86
+ - metric name: gRPC OpenCensus metrics have names with slashes and dots, e.g.
87
+ `grpc.io/client/server_latency`[1]. However Prometheus metric names only take
88
+ alphanums, underscores and colons[2]. We sanitize the name by replacing non-alphanum
89
+ chars with underscores, like the official opencensus prometheus exporter[3].
90
+ - distribution bucket bounds: The Metric proto asks distribution bucket bounds to
91
+ be > 0 [4]. However, gRPC OpenCensus metrics have their first bucket bound == 0 [1].
92
+ This makes the `DistributionAggregationData` constructor to raise Exceptions. This
93
+ applies to all bytes and milliseconds (latencies). The fix: we update the initial 0
94
+ bounds to be 0.000_000_1. This will not affect the precision of the metrics, since
95
+ we don't expect any less-than-1 bytes, or less-than-1-nanosecond times.
96
+
97
+ [1] https://github.com/census-instrumentation/opencensus-specs/blob/master/stats/gRPC.md#units # noqa: E501
98
+ [2] https://prometheus.io/docs/concepts/data_model/#metric-names-and-labels
99
+ [3] https://github.com/census-instrumentation/opencensus-cpp/blob/50eb5de762e5f87e206c011a4f930adb1a1775b1/opencensus/exporters/stats/prometheus/internal/prometheus_utils.cc#L39 # noqa: E501
100
+ [4] https://github.com/census-instrumentation/opencensus-proto/blob/master/src/opencensus/proto/metrics/v1/metrics.proto#L218 # noqa: E501
101
+ """
102
+
103
+ if not metric.metric_descriptor.name.startswith("grpc.io/"):
104
+ return
105
+
106
+ metric.metric_descriptor.name = RE_NON_ALPHANUMS.sub(
107
+ "_", metric.metric_descriptor.name
108
+ )
109
+
110
+ for series in metric.timeseries:
111
+ for point in series.points:
112
+ if point.HasField("distribution_value"):
113
+ dist_value = point.distribution_value
114
+ bucket_bounds = dist_value.bucket_options.explicit.bounds
115
+ if len(bucket_bounds) > 0 and bucket_bounds[0] == 0:
116
+ bucket_bounds[0] = 0.000_000_1
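The name sanitization described in the docstring is a plain regex substitution; a small standalone illustration using the same pattern:

import re

RE_NON_ALPHANUMS = re.compile(r"[^a-zA-Z0-9]")  # same pattern as defined above

# A standard gRPC OpenCensus metric name becomes Prometheus-safe:
print(RE_NON_ALPHANUMS.sub("_", "grpc.io/client/server_latency"))
# -> grpc_io_client_server_latency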
117
+
118
+
119
+ class OpencensusProxyMetric:
120
+ def __init__(self, name: str, desc: str, unit: str, label_keys: List[str]):
121
+ """Represents the OpenCensus metrics that will be proxy exported."""
122
+ self._name = name
123
+ self._desc = desc
124
+ self._unit = unit
125
+ # -- The label keys of the metric --
126
+ self._label_keys = label_keys
127
+ # -- The data that needs to be proxy exported --
128
+ # tuple of label values -> data (OpenCensus aggregation data)
129
+ self._data = {}
130
+
131
+ @property
132
+ def name(self):
133
+ return self._name
134
+
135
+ @property
136
+ def desc(self):
137
+ return self._desc
138
+
139
+ @property
140
+ def unit(self):
141
+ return self._unit
142
+
143
+ @property
144
+ def label_keys(self):
145
+ return self._label_keys
146
+
147
+ @property
148
+ def data(self):
149
+ return self._data
150
+
151
+ def record(self, metric: Metric):
152
+ """Parse the Opencensus Protobuf and store the data.
153
+
154
+ The data can be accessed via `data` API once recorded.
155
+ """
156
+ timeseries = metric.timeseries
157
+
158
+ if len(timeseries) == 0:
159
+ return
160
+
161
+ # Create the aggregation and fill in our stats
162
+ for series in timeseries:
163
+ labels = tuple(val.value for val in series.label_values)
164
+
165
+ # Aggregate points.
166
+ for point in series.points:
167
+ if (
168
+ metric.metric_descriptor.type
169
+ == MetricDescriptorType.CUMULATIVE_INT64
170
+ ):
171
+ data = CountAggregationData(point.int64_value)
172
+ elif (
173
+ metric.metric_descriptor.type
174
+ == MetricDescriptorType.CUMULATIVE_DOUBLE
175
+ ):
176
+ data = SumAggregationData(ValueDouble, point.double_value)
177
+ elif metric.metric_descriptor.type == MetricDescriptorType.GAUGE_DOUBLE:
178
+ data = LastValueAggregationData(ValueDouble, point.double_value)
179
+ elif (
180
+ metric.metric_descriptor.type
181
+ == MetricDescriptorType.CUMULATIVE_DISTRIBUTION
182
+ ):
183
+ dist_value = point.distribution_value
184
+ counts_per_bucket = [bucket.count for bucket in dist_value.buckets]
185
+ bucket_bounds = dist_value.bucket_options.explicit.bounds
186
+ data = DistributionAggregationData(
187
+ dist_value.sum / dist_value.count,
188
+ dist_value.count,
189
+ dist_value.sum_of_squared_deviation,
190
+ counts_per_bucket,
191
+ bucket_bounds,
192
+ )
193
+ else:
194
+ raise ValueError("Summary is not supported")
195
+ self._data[labels] = data
196
+
197
+
198
+ class Component:
199
+ def __init__(self, id: str):
200
+ """Represent a component that requests to proxy export metrics
201
+
202
+ Args:
203
+ id: Id of this component.
204
+ """
205
+ self.id = id
206
+ # -- The time this component reported its metrics last time --
207
+ # It is used to figure out if this component is stale.
208
+ self._last_reported_time = time.monotonic()
209
+ # -- Metrics requested to proxy export from this component --
210
+ # metrics_name (str) -> metric (OpencensusProxyMetric)
211
+ self._metrics = {}
212
+
213
+ @property
214
+ def metrics(self) -> Dict[str, OpencensusProxyMetric]:
215
+ """Return the metrics requested to proxy export from this component."""
216
+ return self._metrics
217
+
218
+ @property
219
+ def last_reported_time(self):
220
+ return self._last_reported_time
221
+
222
+ def record(self, metrics: List[Metric]):
223
+ """Parse the Opencensus protobuf and store metrics.
224
+
225
+ Metrics can be accessed via `metrics` API for proxy export.
226
+
227
+ Args:
228
+ metrics: A list of Opencensus protobuf for proxy export.
229
+ """
230
+ self._last_reported_time = time.monotonic()
231
+ for metric in metrics:
232
+ fix_grpc_metric(metric)
233
+ descriptor = metric.metric_descriptor
234
+ name = descriptor.name
235
+ label_keys = [label_key.key for label_key in descriptor.label_keys]
236
+
237
+ if name not in self._metrics:
238
+ self._metrics[name] = OpencensusProxyMetric(
239
+ name, descriptor.description, descriptor.unit, label_keys
240
+ )
241
+ self._metrics[name].record(metric)
242
+
243
+
244
+ class OpenCensusProxyCollector:
245
+ def __init__(self, namespace: str, component_timeout_s: int = 60):
246
+ """Prometheus collector implementation for opencensus proxy export.
247
+
248
+ A Prometheus collector is required to implement `collect`, which is
249
+ invoked whenever Prometheus queries the endpoint.
250
+
251
+ The class is thread-safe.
252
+
253
+ Args:
254
+ namespace: Prometheus namespace.
255
+ """
256
+ # -- Protect `self._components` --
257
+ self._components_lock = threading.Lock()
258
+ # -- Timeout until the component is marked as stale --
259
+ # Once the component is considered as stale,
260
+ # the metrics from that worker won't be exported.
261
+ self._component_timeout_s = component_timeout_s
262
+ # -- Prometheus namespace --
263
+ self._namespace = namespace
264
+ # -- Component that requests to proxy export metrics --
265
+ # Component means core worker, raylet, and GCS.
266
+ # component_id -> Components
267
+ # For workers, they contain worker ids.
268
+ # For other components (raylet, GCS),
269
+ # they contain the global key `GLOBAL_COMPONENT_KEY`.
270
+ self._components = {}
271
+ # Whether we want to export counter as gauge.
272
+ # This is for bug compatibility.
273
+ # See https://github.com/ray-project/ray/pull/43795.
274
+ self._export_counter_as_gauge = env_bool("RAY_EXPORT_COUNTER_AS_GAUGE", True)
275
+
276
+ def record(self, metrics: List[Metric], worker_id_hex: str = None):
277
+ """Record the metrics reported from the component that reports it.
278
+
279
+ Args:
280
+ metrics: A list of opencensus protobuf to proxy export metrics.
281
+ worker_id_hex: A worker id that reports these metrics.
282
+ If None, it means they are reported from Raylet or GCS.
283
+ """
284
+ key = GLOBAL_COMPONENT_KEY if not worker_id_hex else worker_id_hex
285
+ with self._components_lock:
286
+ if key not in self._components:
287
+ self._components[key] = Component(key)
288
+ self._components[key].record(metrics)
289
+
290
+ def clean_stale_components(self):
291
+ """Clean up stale components.
292
+
293
+ Stale means the component is dead or unresponsive.
294
+
295
+ Stale components won't be reported to Prometheus anymore.
296
+ """
297
+ with self._components_lock:
298
+ stale_components = []
299
+ stale_component_ids = []
300
+ for id, component in self._components.items():
301
+ elapsed = time.monotonic() - component.last_reported_time
302
+ if elapsed > self._component_timeout_s:
303
+ stale_component_ids.append(id)
304
+ logger.info(
305
+ "Metrics from a worker ({}) are cleaned up due to "
306
+ "timeout. Time since last report {}s".format(id, elapsed)
307
+ )
308
+ for id in stale_component_ids:
309
+ stale_components.append(self._components.pop(id))
310
+ return stale_components
311
+
312
+ # TODO(sang): add start and end timestamp
313
+ def to_metrics(
314
+ self,
315
+ metric_name: str,
316
+ metric_description: str,
317
+ label_keys: List[str],
318
+ metric_units: str,
319
+ label_values: Tuple[tag_value_module.TagValue],
320
+ agg_data: Any,
321
+ metrics_map: Dict[str, List[PrometheusMetric]],
322
+ ):
323
+ """to_metrics translates the data that OpenCensus creates
324
+ into Prometheus format, using Prometheus Metric objects.
325
+
326
+ This method is from Opencensus Prometheus Exporter.
327
+
328
+ Args:
329
+ metric_name: Name of the metric.
330
+ metric_description: Description of the metric.
331
+ label_keys: The fixed label keys of the metric.
332
+ metric_units: Units of the metric.
333
+ label_values: The values of `label_keys`.
334
+ agg_data: `opencensus.stats.aggregation_data.AggregationData` object.
335
+ Aggregated data that needs to be converted to Prometheus samples.
336
+ metrics_map: The converted metric is added to this map.
337
+
338
+ """
339
+ assert self._components_lock.locked()
340
+ metric_name = f"{self._namespace}_{metric_name}"
341
+ assert len(label_values) == len(label_keys), (label_values, label_keys)
342
+ # Prometheus requires that all tag values be strings hence
343
+ # the need to cast none to the empty string before exporting. See
344
+ # https://github.com/census-instrumentation/opencensus-python/issues/480
345
+ label_values = [tv if tv else "" for tv in label_values]
346
+
347
+ if isinstance(agg_data, CountAggregationData):
348
+ metrics = metrics_map.get(metric_name)
349
+ if not metrics:
350
+ metric = CounterMetricFamily(
351
+ name=metric_name,
352
+ documentation=metric_description,
353
+ unit=metric_units,
354
+ labels=label_keys,
355
+ )
356
+ metrics = [metric]
357
+ metrics_map[metric_name] = metrics
358
+ metrics[0].add_metric(labels=label_values, value=agg_data.count_data)
359
+ return
360
+
361
+ if isinstance(agg_data, SumAggregationData):
362
+ # This should be emitted as prometheus counter
363
+ # but we used to emit it as prometheus gauge.
364
+ # To keep the backward compatibility
365
+ # (changing from counter to gauge changes the metric name
366
+ # since prometheus client will add "_total" suffix to counter
367
+ # per OpenMetrics specification),
368
+ # we now emit both counter and gauge and in the
369
+ # next major Ray release (3.0) we can stop emitting gauge.
370
+ # This leaves people enough time to migrate their dashboards.
371
+ # See https://github.com/ray-project/ray/pull/43795.
372
+ metrics = metrics_map.get(metric_name)
373
+ if not metrics:
374
+ metric = CounterMetricFamily(
375
+ name=metric_name,
376
+ documentation=metric_description,
377
+ labels=label_keys,
378
+ )
379
+ metrics = [metric]
380
+ metrics_map[metric_name] = metrics
381
+ metrics[0].add_metric(labels=label_values, value=agg_data.sum_data)
382
+
383
+ if not self._export_counter_as_gauge:
384
+ pass
385
+ elif metric_name.endswith("_total"):
386
+ # In this case, we only need to emit prometheus counter
387
+ # since for metric name already ends with _total suffix
388
+ # prometheus client won't change it
389
+ # so there is no backward compatibility issue.
390
+ # See https://prometheus.github.io/client_python/instrumenting/counter/
391
+ pass
392
+ else:
393
+ if len(metrics) == 1:
394
+ metric = GaugeMetricFamily(
395
+ name=metric_name,
396
+ documentation=(
397
+ f"(DEPRECATED, use {metric_name}_total metric instead) "
398
+ f"{metric_description}"
399
+ ),
400
+ labels=label_keys,
401
+ )
402
+ metrics.append(metric)
403
+ assert len(metrics) == 2
404
+ metrics[1].add_metric(labels=label_values, value=agg_data.sum_data)
405
+ return
406
+
407
+ elif isinstance(agg_data, DistributionAggregationData):
408
+
409
+ assert agg_data.bounds == sorted(agg_data.bounds)
410
+ # buckets are a list of buckets. Each bucket is another list with
411
+ # a pair of bucket name and value, or a triple of bucket name,
412
+ # value, and exemplar. buckets need to be in order.
413
+ buckets = []
414
+ cum_count = 0 # Prometheus buckets expect cumulative count.
415
+ for ii, bound in enumerate(agg_data.bounds):
416
+ cum_count += agg_data.counts_per_bucket[ii]
417
+ bucket = [str(bound), cum_count]
418
+ buckets.append(bucket)
419
+ # Prometheus requires buckets to be sorted, and +Inf present.
420
+ # In OpenCensus we don't have +Inf in the bucket bounds so we need to
421
+ # append it here.
422
+ buckets.append(["+Inf", agg_data.count_data])
423
+ metrics = metrics_map.get(metric_name)
424
+ if not metrics:
425
+ metric = HistogramMetricFamily(
426
+ name=metric_name,
427
+ documentation=metric_description,
428
+ labels=label_keys,
429
+ )
430
+ metrics = [metric]
431
+ metrics_map[metric_name] = metrics
432
+ metrics[0].add_metric(
433
+ labels=label_values,
434
+ buckets=buckets,
435
+ sum_value=agg_data.sum,
436
+ )
437
+ return
438
+
439
+ elif isinstance(agg_data, LastValueAggregationData):
440
+ metrics = metrics_map.get(metric_name)
441
+ if not metrics:
442
+ metric = GaugeMetricFamily(
443
+ name=metric_name,
444
+ documentation=metric_description,
445
+ labels=label_keys,
446
+ )
447
+ metrics = [metric]
448
+ metrics_map[metric_name] = metrics
449
+ metrics[0].add_metric(labels=label_values, value=agg_data.value)
450
+ return
451
+
452
+ else:
453
+ raise ValueError(f"unsupported aggregation type {type(agg_data)}")
454
+
455
+ def collect(self): # pragma: NO COVER
456
+ """Collect fetches the statistics from OpenCensus
457
+ and delivers them as Prometheus Metrics.
458
+ Collect is invoked every time a prometheus.Gatherer is run
459
+ for example when the HTTP endpoint is invoked by Prometheus.
460
+
461
+ This method is required as a Prometheus Collector.
462
+ """
463
+ with self._components_lock:
464
+ metrics_map = {}
465
+ for component in self._components.values():
466
+ for metric in component.metrics.values():
467
+ for label_values, data in metric.data.items():
468
+ self.to_metrics(
469
+ metric.name,
470
+ metric.desc,
471
+ metric.label_keys,
472
+ metric.unit,
473
+ label_values,
474
+ data,
475
+ metrics_map,
476
+ )
477
+
478
+ for metrics in metrics_map.values():
479
+ for metric in metrics:
480
+ yield metric
481
+
482
+
483
+ class MetricsAgent:
484
+ def __init__(
485
+ self,
486
+ view_manager: ViewManager,
487
+ stats_recorder: StatsRecorder,
488
+ stats_exporter: StatsExporter = None,
489
+ ):
490
+ """A class to record and export metrics.
491
+
492
+ The class exports metrics in 2 different ways.
493
+ - Directly record and export metrics using OpenCensus.
494
+ - Proxy metrics from other core components
495
+ (e.g., raylet, GCS, core workers).
496
+
497
+ This class is thread-safe.
498
+ """
499
+ # Lock required because gRPC server uses
500
+ # multiple threads to process requests.
501
+ self._lock = threading.Lock()
502
+
503
+ #
504
+ # Opencensus components to record metrics.
505
+ #
506
+
507
+ # Managing views to export metrics
508
+ # If the stats_exporter is None, we disable all metrics export.
509
+ self.view_manager = view_manager
510
+ # A class that's used to record metrics
511
+ # emitted from the current process.
512
+ self.stats_recorder = stats_recorder
513
+ # A class to export metrics.
514
+ self.stats_exporter = stats_exporter
515
+ # -- A Prometheus custom collector to proxy export metrics --
516
+ # `None` if the prometheus server is not started.
517
+ self.proxy_exporter_collector = None
518
+
519
+ if self.stats_exporter is None:
520
+ # If the exporter is not given,
521
+ # we disable metrics collection.
522
+ self.view_manager = None
523
+ else:
524
+ self.view_manager.register_exporter(stats_exporter)
525
+ self.proxy_exporter_collector = OpenCensusProxyCollector(
526
+ self.stats_exporter.options.namespace,
527
+ component_timeout_s=int(os.getenv(RAY_WORKER_TIMEOUT_S, 120)),
528
+ )
529
+
530
+ # Registered view names.
531
+ self._registered_views: Set[str] = set()
532
+
533
+ def record_and_export(self, records: List[Record], global_tags=None):
534
+ """Directly record and export stats from the same process."""
535
+ global_tags = global_tags or {}
536
+ with self._lock:
537
+ if not self.view_manager:
538
+ return
539
+
540
+ for record in records:
541
+ gauge = record.gauge
542
+ value = record.value
543
+ tags = record.tags
544
+ self._record_gauge(gauge, value, {**tags, **global_tags})
545
+
546
+ def _record_gauge(self, gauge: Gauge, value: float, tags: dict):
547
+ if gauge.name not in self._registered_views:
548
+ self.view_manager.register_view(gauge.view)
549
+ self._registered_views.add(gauge.name)
550
+ measurement_map = self.stats_recorder.new_measurement_map()
551
+ tag_map = tag_map_module.TagMap()
552
+ for key, tag_val in tags.items():
553
+ tag_key = tag_key_module.TagKey(key)
554
+ tag_value = tag_value_module.TagValue(tag_val)
555
+ tag_map.insert(tag_key, tag_value)
556
+ measurement_map.measure_float_put(gauge.measure, value)
557
+ # NOTE: When we record this metric, timestamp will be renewed.
558
+ measurement_map.record(tag_map)
559
+
560
+ def proxy_export_metrics(self, metrics: List[Metric], worker_id_hex: str = None):
561
+ """Proxy export metrics specified by a Opencensus Protobuf.
562
+
563
+ This API is used to export metrics emitted from
564
+ core components.
565
+
566
+ Args:
567
+ metrics: A list of protobuf Metric defined from OpenCensus.
568
+ worker_id_hex: The ID of the worker whose metrics are proxy exported. None
569
+ if the metric is not from a worker (i.e., raylet, GCS).
570
+ """
571
+ with self._lock:
572
+ if not self.view_manager:
573
+ return
574
+
575
+ self._proxy_export_metrics(metrics, worker_id_hex)
576
+
577
+ def _proxy_export_metrics(self, metrics: List[Metric], worker_id_hex: str = None):
578
+ self.proxy_exporter_collector.record(metrics, worker_id_hex)
579
+
580
+ def clean_all_dead_worker_metrics(self):
581
+ """Clean dead worker's metrics.
582
+
583
+ Worker metrics are cleaned up and won't be exported once
584
+ it is considered as dead.
585
+
586
+ This method has to be periodically called by a caller.
587
+ """
588
+ with self._lock:
589
+ if not self.view_manager:
590
+ return
591
+
592
+ self.proxy_exporter_collector.clean_stale_components()
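Tying the pieces together, a hedged sketch of constructing a `MetricsAgent` from the global OpenCensus stats object (assuming the opencensus package Ray depends on is installed) and recording one gauge value. Passing `stats_exporter=None` disables export per `__init__` above, so `record_and_export` is effectively a no-op here; a real exporter (e.g. Ray's Prometheus exporter) would be needed to publish anything:

from opencensus.stats import stats as stats_module

from ray._private.metrics_agent import Gauge, MetricsAgent, Record

stats = stats_module.stats  # global OpenCensus Stats object

agent = MetricsAgent(
    view_manager=stats.view_manager,
    stats_recorder=stats.stats_recorder,
    stats_exporter=None,  # export disabled; see __init__ above
)

gauge = Gauge("example_queue_size", "Illustrative gauge.", "items", ["Component"])
agent.record_and_export(
    [Record(gauge=gauge, value=42.0, tags={"Component": "demo"})],
    global_tags={"Version": "illustrative"},
)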
593
+
594
+
595
+ class PrometheusServiceDiscoveryWriter(threading.Thread):
596
+ """A class to support Prometheus service discovery.
597
+
598
+ It supports file-based service discovery. Checkout
599
+ https://prometheus.io/docs/guides/file-sd/ for more details.
600
+
601
+ Args:
602
+ gcs_address: Gcs address for this cluster.
603
+ temp_dir: Temporary directory used by
604
+ Ray to store logs and metadata.
605
+ """
606
+
607
+ def __init__(self, gcs_address, temp_dir):
608
+ gcs_client_options = ray._raylet.GcsClientOptions.create(
609
+ gcs_address, None, allow_cluster_id_nil=True, fetch_cluster_id_if_nil=False
610
+ )
611
+ self.gcs_address = gcs_address
612
+
613
+ ray._private.state.state._initialize_global_state(gcs_client_options)
614
+ self.temp_dir = temp_dir
615
+ self.default_service_discovery_flush_period = 5
616
+ super().__init__()
617
+
618
+ def get_file_discovery_content(self):
619
+ """Return the content for Prometheus service discovery."""
620
+ nodes = ray.nodes()
621
+ metrics_export_addresses = [
622
+ "{}:{}".format(node["NodeManagerAddress"], node["MetricsExportPort"])
623
+ for node in nodes
624
+ if node["alive"] is True
625
+ ]
626
+ gcs_client = GcsClient(address=self.gcs_address)
627
+ autoscaler_addr = gcs_client.internal_kv_get(b"AutoscalerMetricsAddress", None)
628
+ if autoscaler_addr:
629
+ metrics_export_addresses.append(autoscaler_addr.decode("utf-8"))
630
+ dashboard_addr = gcs_client.internal_kv_get(b"DashboardMetricsAddress", None)
631
+ if dashboard_addr:
632
+ metrics_export_addresses.append(dashboard_addr.decode("utf-8"))
633
+ return json.dumps(
634
+ [{"labels": {"job": "ray"}, "targets": metrics_export_addresses}]
635
+ )
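The JSON returned above follows Prometheus's file-based service discovery format: a single entry whose `targets` are the metrics endpoints of alive nodes plus, when registered, the autoscaler and dashboard addresses. An illustrative (invented) example of the resulting file content:

# Illustrative only; the real addresses come from ray.nodes() and the GCS KV store.
example_content = [
    {
        "labels": {"job": "ray"},
        "targets": [
            "10.0.0.1:8080",   # a node's MetricsExportPort
            "10.0.0.2:8080",
            "10.0.0.1:44217",  # autoscaler or dashboard metrics address, if set
        ],
    }
]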
636
+
637
+ def write(self):
638
+ # Write a file based on https://prometheus.io/docs/guides/file-sd/
639
+ # Write should be atomic. Otherwise, Prometheus raises an error that
640
+ # json file format is invalid because it reads a file when
641
+ # file is re-written. Note that Prometheus still works although we
642
+ # have this error.
643
+ temp_file_name = self.get_temp_file_name()
644
+ with open(temp_file_name, "w") as json_file:
645
+ json_file.write(self.get_file_discovery_content())
646
+ # NOTE: os.replace is atomic on both Linux and Windows, so we won't
647
+ # have race condition reading this file.
648
+ os.replace(temp_file_name, self.get_target_file_name())
649
+
650
+ def get_target_file_name(self):
651
+ return os.path.join(
652
+ self.temp_dir, ray._private.ray_constants.PROMETHEUS_SERVICE_DISCOVERY_FILE
653
+ )
654
+
655
+ def get_temp_file_name(self):
656
+ return os.path.join(
657
+ self.temp_dir,
658
+ "{}_{}".format(
659
+ "tmp", ray._private.ray_constants.PROMETHEUS_SERVICE_DISCOVERY_FILE
660
+ ),
661
+ )
662
+
663
+ def run(self):
664
+ while True:
665
+ # This thread won't be broken by exceptions.
666
+ try:
667
+ self.write()
668
+ except Exception as e:
669
+ logger.warning(
670
+ "Writing a service discovery file, {}, "
671
+ "failed.".format(self.get_target_file_name())
672
+ )
673
+ logger.warning(traceback.format_exc())
674
+ logger.warning(f"Error message: {e}")
675
+ time.sleep(self.default_service_discovery_flush_period)
.venv/lib/python3.11/site-packages/ray/_private/node.py ADDED
@@ -0,0 +1,1862 @@
1
+ import atexit
2
+ import collections
3
+ import datetime
4
+ import errno
5
+ import json
6
+ import logging
7
+ import os
8
+ import random
9
+ import signal
10
+ import socket
11
+ import subprocess
12
+ import sys
13
+ import tempfile
14
+ import threading
15
+ import time
16
+ import traceback
17
+ from collections import defaultdict
18
+ from typing import Dict, Optional, Tuple, IO, AnyStr
19
+
20
+ from filelock import FileLock
21
+
22
+ import ray
23
+ import ray._private.ray_constants as ray_constants
24
+ import ray._private.services
25
+ from ray._private import storage
26
+ from ray._raylet import GcsClient, get_session_key_from_storage
27
+ from ray._private.resource_spec import ResourceSpec
28
+ from ray._private.services import serialize_config, get_address
29
+ from ray._private.utils import open_log, try_to_create_directory, try_to_symlink
30
+
31
+ # Logger for this module. It should be configured at the entry point
32
+ # into the program using Ray. Ray configures it by default automatically
33
+ # using logging.basicConfig in its entry/init points.
34
+ logger = logging.getLogger(__name__)
35
+
36
+
37
+ class Node:
38
+ """An encapsulation of the Ray processes on a single node.
39
+
40
+ This class is responsible for starting Ray processes and killing them,
41
+ and it also controls the temp file policy.
42
+
43
+ Attributes:
44
+ all_processes: A mapping from process type (str) to a list of
45
+ ProcessInfo objects. All lists have length one except for the Redis
46
+ server list, which has multiple.
47
+ """
48
+
49
+ def __init__(
50
+ self,
51
+ ray_params,
52
+ head: bool = False,
53
+ shutdown_at_exit: bool = True,
54
+ spawn_reaper: bool = True,
55
+ connect_only: bool = False,
56
+ default_worker: bool = False,
57
+ ray_init_cluster: bool = False,
58
+ ):
59
+ """Start a node.
60
+
61
+ Args:
62
+ ray_params: The RayParams to use to configure the node.
63
+ head: True if this is the head node, which means it will
64
+ start additional processes like the Redis servers, monitor
65
+ processes, and web UI.
66
+ shutdown_at_exit: If true, spawned processes will be cleaned
67
+ up if this process exits normally.
68
+ spawn_reaper: If true, spawns a process that will clean up
69
+ other spawned processes if this process dies unexpectedly.
70
+ connect_only: If true, connect to the node without starting
71
+ new processes.
72
+ default_worker: Whether it's running from a ray worker or not
73
+ ray_init_cluster: Whether it's a cluster created by ray.init()
74
+ """
75
+ if shutdown_at_exit:
76
+ if connect_only:
77
+ raise ValueError(
78
+ "'shutdown_at_exit' and 'connect_only' cannot both be true."
79
+ )
80
+ self._register_shutdown_hooks()
81
+ self._default_worker = default_worker
82
+ self.head = head
83
+ self.kernel_fate_share = bool(
84
+ spawn_reaper and ray._private.utils.detect_fate_sharing_support()
85
+ )
86
+ self.all_processes: dict = {}
87
+ self.removal_lock = threading.Lock()
88
+
89
+ self.ray_init_cluster = ray_init_cluster
90
+ if ray_init_cluster:
91
+ assert head, "ray.init() created cluster only has the head node"
92
+
93
+ # Set up external Redis when `RAY_REDIS_ADDRESS` is specified.
94
+ redis_address_env = os.environ.get("RAY_REDIS_ADDRESS")
95
+ if ray_params.external_addresses is None and redis_address_env is not None:
96
+ external_redis = redis_address_env.split(",")
97
+
98
+ # Reuse primary Redis as Redis shard when there's only one
99
+ # instance provided.
100
+ if len(external_redis) == 1:
101
+ external_redis.append(external_redis[0])
102
+ [primary_redis_ip, port] = external_redis[0].rsplit(":", 1)
103
+ ray_params.external_addresses = external_redis
104
+ ray_params.num_redis_shards = len(external_redis) - 1
105
+
106
+ if (
107
+ ray_params._system_config
108
+ and len(ray_params._system_config) > 0
109
+ and (not head and not connect_only)
110
+ ):
111
+ raise ValueError(
112
+ "System config parameters can only be set on the head node."
113
+ )
114
+
115
+ ray_params.update_if_absent(
116
+ include_log_monitor=True,
117
+ resources={},
118
+ worker_path=os.path.join(
119
+ os.path.dirname(os.path.abspath(__file__)),
120
+ "workers",
121
+ "default_worker.py",
122
+ ),
123
+ setup_worker_path=os.path.join(
124
+ os.path.dirname(os.path.abspath(__file__)),
125
+ "workers",
126
+ ray_constants.SETUP_WORKER_FILENAME,
127
+ ),
128
+ )
129
+
130
+ self._resource_spec = None
131
+ self._localhost = socket.gethostbyname("localhost")
132
+ self._ray_params = ray_params
133
+ self._config = ray_params._system_config or {}
134
+
135
+ self._dashboard_agent_listen_port = ray_params.dashboard_agent_listen_port
136
+ self._dashboard_grpc_port = ray_params.dashboard_grpc_port
137
+
138
+ # Configure log rotation parameters.
139
+ self.max_bytes = int(
140
+ os.getenv("RAY_ROTATION_MAX_BYTES", ray_constants.LOGGING_ROTATE_BYTES)
141
+ )
142
+ self.backup_count = int(
143
+ os.getenv(
144
+ "RAY_ROTATION_BACKUP_COUNT", ray_constants.LOGGING_ROTATE_BACKUP_COUNT
145
+ )
146
+ )
147
+
148
+ assert self.max_bytes >= 0
149
+ assert self.backup_count >= 0
150
+
151
+ self._redis_address = ray_params.redis_address
152
+ if head:
153
+ ray_params.update_if_absent(num_redis_shards=1)
154
+ self._gcs_address = ray_params.gcs_address
155
+ self._gcs_client = None
156
+
157
+ if not self.head:
158
+ self.validate_ip_port(self.address)
159
+ self._init_gcs_client()
160
+
161
+ # Register the temp dir.
162
+ self._session_name = ray_params.session_name
163
+ if self._session_name is None:
164
+ if head:
165
+ # We expect this the first time we initialize a cluster, but not during
166
+ # subsequent restarts of the head node.
167
+ maybe_key = self.check_persisted_session_name()
168
+ if maybe_key is None:
169
+ # date including microsecond
170
+ date_str = datetime.datetime.today().strftime(
171
+ "%Y-%m-%d_%H-%M-%S_%f"
172
+ )
173
+ self._session_name = f"session_{date_str}_{os.getpid()}"
174
+ else:
175
+ self._session_name = ray._private.utils.decode(maybe_key)
176
+ else:
177
+ assert not self._default_worker
178
+ session_name = ray._private.utils.internal_kv_get_with_retry(
179
+ self.get_gcs_client(),
180
+ "session_name",
181
+ ray_constants.KV_NAMESPACE_SESSION,
182
+ num_retries=ray_constants.NUM_REDIS_GET_RETRIES,
183
+ )
184
+ self._session_name = ray._private.utils.decode(session_name)
185
+
186
+ # Initialize webui url
187
+ if head:
188
+ self._webui_url = None
189
+ else:
190
+ if ray_params.webui is None:
191
+ assert not self._default_worker
192
+ self._webui_url = ray._private.services.get_webui_url_from_internal_kv()
193
+ else:
194
+ self._webui_url = (
195
+ f"{ray_params.dashboard_host}:{ray_params.dashboard_port}"
196
+ )
197
+
198
+ # It creates a session_dir.
199
+ self._init_temp()
200
+
201
+ node_ip_address = ray_params.node_ip_address
202
+ if node_ip_address is None:
203
+ if connect_only:
204
+ node_ip_address = self._wait_and_get_for_node_address()
205
+ else:
206
+ node_ip_address = ray.util.get_node_ip_address()
207
+
208
+ assert node_ip_address is not None
209
+ ray_params.update_if_absent(
210
+ node_ip_address=node_ip_address, raylet_ip_address=node_ip_address
211
+ )
212
+ self._node_ip_address = node_ip_address
213
+ if not connect_only:
214
+ ray._private.services.write_node_ip_address(
215
+ self.get_session_dir_path(), node_ip_address
216
+ )
217
+
218
+ if ray_params.raylet_ip_address:
219
+ raylet_ip_address = ray_params.raylet_ip_address
220
+ else:
221
+ raylet_ip_address = node_ip_address
222
+
223
+ if raylet_ip_address != node_ip_address and (not connect_only or head):
224
+ raise ValueError(
225
+ "The raylet IP address should only be different than the node "
226
+ "IP address when connecting to an existing raylet; i.e., when "
227
+ "head=False and connect_only=True."
228
+ )
229
+ self._raylet_ip_address = raylet_ip_address
230
+
231
+ # Validate and initialize the persistent storage API.
232
+ if head:
233
+ storage._init_storage(ray_params.storage, is_head=True)
234
+ else:
235
+ if not self._default_worker:
236
+ storage_uri = ray._private.services.get_storage_uri_from_internal_kv()
237
+ else:
238
+ storage_uri = ray_params.storage
239
+ storage._init_storage(storage_uri, is_head=False)
240
+
241
+ # If it is a head node, try validating if
242
+ # external storage is configurable.
243
+ if head:
244
+ self.validate_external_storage()
245
+
246
+ if connect_only:
247
+ # Get socket names from the configuration.
248
+ self._plasma_store_socket_name = ray_params.plasma_store_socket_name
249
+ self._raylet_socket_name = ray_params.raylet_socket_name
250
+ self._node_id = ray_params.node_id
251
+
252
+ # If user does not provide the socket name, get it from Redis.
253
+ if (
254
+ self._plasma_store_socket_name is None
255
+ or self._raylet_socket_name is None
256
+ or self._ray_params.node_manager_port is None
257
+ or self._node_id is None
258
+ ):
259
+ # Get the address info of the processes to connect to
260
+ # from Redis or GCS.
261
+ node_info = ray._private.services.get_node_to_connect_for_driver(
262
+ self.gcs_address,
263
+ self._raylet_ip_address,
264
+ )
265
+ self._plasma_store_socket_name = node_info["object_store_socket_name"]
266
+ self._raylet_socket_name = node_info["raylet_socket_name"]
267
+ self._ray_params.node_manager_port = node_info["node_manager_port"]
268
+ self._node_id = node_info["node_id"]
269
+ else:
270
+ # If the user specified a socket name, use it.
271
+ self._plasma_store_socket_name = self._prepare_socket_file(
272
+ self._ray_params.plasma_store_socket_name, default_prefix="plasma_store"
273
+ )
274
+ self._raylet_socket_name = self._prepare_socket_file(
275
+ self._ray_params.raylet_socket_name, default_prefix="raylet"
276
+ )
277
+ if (
278
+ self._ray_params.env_vars is not None
279
+ and "RAY_OVERRIDE_NODE_ID_FOR_TESTING" in self._ray_params.env_vars
280
+ ):
281
+ node_id = self._ray_params.env_vars["RAY_OVERRIDE_NODE_ID_FOR_TESTING"]
282
+ logger.debug(
283
+ f"Setting node ID to {node_id} "
284
+ "based on ray_params.env_vars override"
285
+ )
286
+ self._node_id = node_id
287
+ elif os.environ.get("RAY_OVERRIDE_NODE_ID_FOR_TESTING"):
288
+ node_id = os.environ["RAY_OVERRIDE_NODE_ID_FOR_TESTING"]
289
+ logger.debug(f"Setting node ID to {node_id} based on env override")
290
+ self._node_id = node_id
291
+ else:
292
+ node_id = ray.NodeID.from_random().hex()
293
+ logger.debug(f"Setting node ID to {node_id}")
294
+ self._node_id = node_id
295
+
296
+ # The dashboard agent port is assigned first to avoid
297
+ # other processes accidentally taking its default port
298
+ self._dashboard_agent_listen_port = self._get_cached_port(
299
+ "dashboard_agent_listen_port",
300
+ default_port=ray_params.dashboard_agent_listen_port,
301
+ )
302
+
303
+ self.metrics_agent_port = self._get_cached_port(
304
+ "metrics_agent_port", default_port=ray_params.metrics_agent_port
305
+ )
306
+ self._metrics_export_port = self._get_cached_port(
307
+ "metrics_export_port", default_port=ray_params.metrics_export_port
308
+ )
309
+ self._runtime_env_agent_port = self._get_cached_port(
310
+ "runtime_env_agent_port",
311
+ default_port=ray_params.runtime_env_agent_port,
312
+ )
313
+
314
+ ray_params.update_if_absent(
315
+ metrics_agent_port=self.metrics_agent_port,
316
+ metrics_export_port=self._metrics_export_port,
317
+ dashboard_agent_listen_port=self._dashboard_agent_listen_port,
318
+ runtime_env_agent_port=self._runtime_env_agent_port,
319
+ )
320
+
321
+ # Pick a GCS server port.
322
+ if head:
323
+ gcs_server_port = os.getenv(ray_constants.GCS_PORT_ENVIRONMENT_VARIABLE)
324
+ if gcs_server_port:
325
+ ray_params.update_if_absent(gcs_server_port=int(gcs_server_port))
326
+ if ray_params.gcs_server_port is None or ray_params.gcs_server_port == 0:
327
+ ray_params.gcs_server_port = self._get_cached_port("gcs_server_port")
328
+
329
+ if not connect_only and spawn_reaper and not self.kernel_fate_share:
330
+ self.start_reaper_process()
331
+ if not connect_only:
332
+ self._ray_params.update_pre_selected_port()
333
+
334
+ # Start processes.
335
+ if head:
336
+ self.start_head_processes()
337
+
338
+ if not connect_only:
339
+ self.start_ray_processes()
340
+ # we should update the address info after the node has been started
341
+ try:
342
+ ray._private.services.wait_for_node(
343
+ self.gcs_address,
344
+ self._plasma_store_socket_name,
345
+ )
346
+ except TimeoutError as te:
347
+ raise Exception(
348
+ "The current node timed out during startup. This "
349
+ "could happen because some of the Ray processes "
350
+ "failed to startup."
351
+ ) from te
352
+ node_info = ray._private.services.get_node(
353
+ self.gcs_address,
354
+ self._node_id,
355
+ )
356
+ if self._ray_params.node_manager_port == 0:
357
+ self._ray_params.node_manager_port = node_info["node_manager_port"]
358
+
359
+ # Makes sure the Node object has valid addresses after setup.
360
+ self.validate_ip_port(self.address)
361
+ self.validate_ip_port(self.gcs_address)
362
+
363
+ if not connect_only:
364
+ self._record_stats()
365
+
366
+ def check_persisted_session_name(self):
367
+ if self._ray_params.external_addresses is None:
368
+ return None
369
+ self._redis_address = self._ray_params.external_addresses[0]
370
+ redis_ip_address, redis_port, enable_redis_ssl = get_address(
371
+ self._redis_address,
372
+ )
373
+ # Address is ip:port or redis://ip:port
374
+ if int(redis_port) < 0:
375
+ raise ValueError(
376
+ f"Invalid Redis port provided: {redis_port}."
377
+ "The port must be a non-negative integer."
378
+ )
379
+
380
+ return get_session_key_from_storage(
381
+ redis_ip_address,
382
+ int(redis_port),
383
+ self._ray_params.redis_username,
384
+ self._ray_params.redis_password,
385
+ enable_redis_ssl,
386
+ serialize_config(self._config),
387
+ b"session_name",
388
+ )
389
+
390
+ @staticmethod
391
+ def validate_ip_port(ip_port):
392
+ """Validates the address is in the ip:port format"""
393
+ _, _, port = ip_port.rpartition(":")
394
+ if port == ip_port:
395
+ raise ValueError(f"Port is not specified for address {ip_port}")
396
+ try:
397
+ _ = int(port)
398
+ except ValueError:
399
+ raise ValueError(
400
+ f"Unable to parse port number from {port} (full address = {ip_port})"
401
+ )
402
+
403
+ def check_version_info(self):
404
+ """Check if the Python and Ray version of this process matches that in GCS.
405
+
406
+ This will be used to detect if workers or drivers are started using
407
+ different versions of Python, or Ray.
408
+
409
+ Raises:
410
+ Exception: An exception is raised if there is a version mismatch.
411
+ """
412
+ import ray._private.usage.usage_lib as ray_usage_lib
413
+
414
+ cluster_metadata = ray_usage_lib.get_cluster_metadata(self.get_gcs_client())
415
+ if cluster_metadata is None:
416
+ cluster_metadata = ray_usage_lib.get_cluster_metadata(self.get_gcs_client())
417
+
418
+ if not cluster_metadata:
419
+ return
420
+ node_ip_address = ray._private.services.get_node_ip_address()
421
+ ray._private.utils.check_version_info(
422
+ cluster_metadata, f"node {node_ip_address}"
423
+ )
424
+
425
+ def _register_shutdown_hooks(self):
426
+ # Register the atexit handler. In this case, we shouldn't call sys.exit
427
+ # as we're already in the exit procedure.
428
+ def atexit_handler(*args):
429
+ self.kill_all_processes(check_alive=False, allow_graceful=True)
430
+
431
+ atexit.register(atexit_handler)
432
+
433
+ # Register the handler to be called if we get a SIGTERM.
434
+ # In this case, we want to exit with an error code (1) after
435
+ # cleaning up child processes.
436
+ def sigterm_handler(signum, frame):
437
+ self.kill_all_processes(check_alive=False, allow_graceful=True)
438
+ sys.exit(1)
439
+
440
+ ray._private.utils.set_sigterm_handler(sigterm_handler)
441
+
442
+ def _init_temp(self):
443
+ # Create a dictionary to store temp file index.
444
+ self._incremental_dict = collections.defaultdict(lambda: 0)
445
+
446
+ if self.head:
447
+ self._ray_params.update_if_absent(
448
+ temp_dir=ray._private.utils.get_ray_temp_dir()
449
+ )
450
+ self._temp_dir = self._ray_params.temp_dir
451
+ else:
452
+ if self._ray_params.temp_dir is None:
453
+ assert not self._default_worker
454
+ temp_dir = ray._private.utils.internal_kv_get_with_retry(
455
+ self.get_gcs_client(),
456
+ "temp_dir",
457
+ ray_constants.KV_NAMESPACE_SESSION,
458
+ num_retries=ray_constants.NUM_REDIS_GET_RETRIES,
459
+ )
460
+ self._temp_dir = ray._private.utils.decode(temp_dir)
461
+ else:
462
+ self._temp_dir = self._ray_params.temp_dir
463
+
464
+ try_to_create_directory(self._temp_dir)
465
+
466
+ if self.head:
467
+ self._session_dir = os.path.join(self._temp_dir, self._session_name)
468
+ else:
469
+ if self._temp_dir is None or self._session_name is None:
470
+ assert not self._default_worker
471
+ session_dir = ray._private.utils.internal_kv_get_with_retry(
472
+ self.get_gcs_client(),
473
+ "session_dir",
474
+ ray_constants.KV_NAMESPACE_SESSION,
475
+ num_retries=ray_constants.NUM_REDIS_GET_RETRIES,
476
+ )
477
+ self._session_dir = ray._private.utils.decode(session_dir)
478
+ else:
479
+ self._session_dir = os.path.join(self._temp_dir, self._session_name)
480
+ session_symlink = os.path.join(self._temp_dir, ray_constants.SESSION_LATEST)
481
+
482
+ # Send a warning message if the session exists.
483
+ try_to_create_directory(self._session_dir)
484
+ try_to_symlink(session_symlink, self._session_dir)
485
+ # Create a directory to be used for socket files.
486
+ self._sockets_dir = os.path.join(self._session_dir, "sockets")
487
+ try_to_create_directory(self._sockets_dir)
488
+ # Create a directory to be used for process log files.
489
+ self._logs_dir = os.path.join(self._session_dir, "logs")
490
+ try_to_create_directory(self._logs_dir)
491
+ old_logs_dir = os.path.join(self._logs_dir, "old")
492
+ try_to_create_directory(old_logs_dir)
493
+ # Create a directory to be used for runtime environment.
494
+ self._runtime_env_dir = os.path.join(
495
+ self._session_dir, self._ray_params.runtime_env_dir_name
496
+ )
497
+ try_to_create_directory(self._runtime_env_dir)
498
+
499
+ def _get_node_labels(self):
500
+ def merge_labels(env_override_labels, params_labels):
501
+ """Merges two dictionaries, picking from the
502
+ first in the event of a conflict. Also emit a warning on every
503
+ conflict.
504
+ """
505
+
506
+ result = params_labels.copy()
507
+ result.update(env_override_labels)
508
+
509
+ for key in set(env_override_labels.keys()).intersection(
510
+ set(params_labels.keys())
511
+ ):
512
+ if params_labels[key] != env_override_labels[key]:
513
+ logger.warning(
514
+ "Autoscaler is overriding your label:"
515
+ f"{key}: {params_labels[key]} to "
516
+ f"{key}: {env_override_labels[key]}."
517
+ )
518
+ return result
519
+
520
+ env_override_labels = {}
521
+ env_override_labels_string = os.getenv(
522
+ ray_constants.LABELS_ENVIRONMENT_VARIABLE
523
+ )
524
+ if env_override_labels_string:
525
+ try:
526
+ env_override_labels = json.loads(env_override_labels_string)
527
+ except Exception:
528
+ logger.exception(f"Failed to load {env_override_labels_string}")
529
+ raise
530
+ logger.info(f"Autoscaler overriding labels: {env_override_labels}.")
531
+
532
+ return merge_labels(env_override_labels, self._ray_params.labels or {})
533
+
534
+ def get_resource_spec(self):
535
+ """Resolve and return the current resource spec for the node."""
536
+
537
+ def merge_resources(env_dict, params_dict):
538
+ """Separates special case params and merges two dictionaries, picking from the
539
+ first in the event of a conflict. Also emit a warning on every
540
+ conflict.
541
+ """
542
+ num_cpus = env_dict.pop("CPU", None)
543
+ num_gpus = env_dict.pop("GPU", None)
544
+ memory = env_dict.pop("memory", None)
545
+ object_store_memory = env_dict.pop("object_store_memory", None)
546
+
547
+ result = params_dict.copy()
548
+ result.update(env_dict)
549
+
550
+ for key in set(env_dict.keys()).intersection(set(params_dict.keys())):
551
+ if params_dict[key] != env_dict[key]:
552
+ logger.warning(
553
+ "Autoscaler is overriding your resource:"
554
+ f"{key}: {params_dict[key]} with {env_dict[key]}."
555
+ )
556
+ return num_cpus, num_gpus, memory, object_store_memory, result
557
+
558
+ if not self._resource_spec:
559
+ env_resources = {}
560
+ env_string = os.getenv(ray_constants.RESOURCES_ENVIRONMENT_VARIABLE)
561
+ if env_string:
562
+ try:
563
+ env_resources = json.loads(env_string)
564
+ except Exception:
565
+ logger.exception(f"Failed to load {env_string}")
566
+ raise
567
+ logger.debug(f"Autoscaler overriding resources: {env_resources}.")
568
+ (
569
+ num_cpus,
570
+ num_gpus,
571
+ memory,
572
+ object_store_memory,
573
+ resources,
574
+ ) = merge_resources(env_resources, self._ray_params.resources)
575
+ self._resource_spec = ResourceSpec(
576
+ self._ray_params.num_cpus if num_cpus is None else num_cpus,
577
+ self._ray_params.num_gpus if num_gpus is None else num_gpus,
578
+ self._ray_params.memory if memory is None else memory,
579
+ (
580
+ self._ray_params.object_store_memory
581
+ if object_store_memory is None
582
+ else object_store_memory
583
+ ),
584
+ resources,
585
+ self._ray_params.redis_max_memory,
586
+ ).resolve(is_head=self.head, node_ip_address=self.node_ip_address)
587
+ return self._resource_spec
588
+
589
+ @property
590
+ def node_id(self):
591
+ """Get the node ID."""
592
+ return self._node_id
593
+
594
+ @property
595
+ def session_name(self):
596
+ """Get the session name (cluster ID)."""
597
+ return self._session_name
598
+
599
+ @property
600
+ def node_ip_address(self):
601
+ """Get the IP address of this node."""
602
+ return self._node_ip_address
603
+
604
+ @property
605
+ def raylet_ip_address(self):
606
+ """Get the IP address of the raylet that this node connects to."""
607
+ return self._raylet_ip_address
608
+
609
+ @property
610
+ def address(self):
611
+ """Get the address for bootstrapping, e.g. the address to pass to
612
+ `ray start` or `ray.init()` to start worker nodes, that has been
613
+ converted to ip:port format.
614
+ """
615
+ return self._gcs_address
616
+
617
+ @property
618
+ def gcs_address(self):
619
+ """Get the gcs address."""
620
+ assert self._gcs_address is not None, "Gcs address is not set"
621
+ return self._gcs_address
622
+
623
+ @property
624
+ def redis_address(self):
625
+ """Get the cluster Redis address."""
626
+ return self._redis_address
627
+
628
+ @property
629
+ def redis_username(self):
630
+ """Get the cluster Redis username."""
631
+ return self._ray_params.redis_username
632
+
633
+ @property
634
+ def redis_password(self):
635
+ """Get the cluster Redis password."""
636
+ return self._ray_params.redis_password
637
+
638
+ @property
639
+ def object_ref_seed(self):
640
+ """Get the seed for deterministic generation of object refs"""
641
+ return self._ray_params.object_ref_seed
642
+
643
+ @property
644
+ def plasma_store_socket_name(self):
645
+ """Get the node's plasma store socket name."""
646
+ return self._plasma_store_socket_name
647
+
648
+ @property
649
+ def unique_id(self):
650
+ """Get a unique identifier for this node."""
651
+ return f"{self.node_ip_address}:{self._plasma_store_socket_name}"
652
+
653
+ @property
654
+ def webui_url(self):
655
+ """Get the cluster's web UI url."""
656
+ return self._webui_url
657
+
658
+ @property
659
+ def raylet_socket_name(self):
660
+ """Get the node's raylet socket name."""
661
+ return self._raylet_socket_name
662
+
663
+ @property
664
+ def node_manager_port(self):
665
+ """Get the node manager's port."""
666
+ return self._ray_params.node_manager_port
667
+
668
+ @property
669
+ def metrics_export_port(self):
670
+ """Get the port that exposes metrics"""
671
+ return self._metrics_export_port
672
+
673
+ @property
674
+ def runtime_env_agent_port(self):
675
+ """Get the port that exposes runtime env agent as http"""
676
+ return self._runtime_env_agent_port
677
+
678
+ @property
679
+ def runtime_env_agent_address(self):
680
+ """Get the address that exposes runtime env agent as http"""
681
+ return f"http://{self._raylet_ip_address}:{self._runtime_env_agent_port}"
682
+
683
+ @property
684
+ def dashboard_agent_listen_port(self):
685
+ """Get the dashboard agent's listen port"""
686
+ return self._dashboard_agent_listen_port
687
+
688
+ @property
689
+ def dashboard_grpc_port(self):
690
+ """Get the dashboard head grpc port"""
691
+ return self._dashboard_grpc_port
692
+
693
+ @property
694
+ def logging_config(self):
695
+ """Get the logging config of the current node."""
696
+ return {
697
+ "log_rotation_max_bytes": self.max_bytes,
698
+ "log_rotation_backup_count": self.backup_count,
699
+ }
700
+
701
+ @property
702
+ def address_info(self):
703
+ """Get a dictionary of addresses."""
704
+ return {
705
+ "node_ip_address": self._node_ip_address,
706
+ "raylet_ip_address": self._raylet_ip_address,
707
+ "redis_address": self.redis_address,
708
+ "object_store_address": self._plasma_store_socket_name,
709
+ "raylet_socket_name": self._raylet_socket_name,
710
+ "webui_url": self._webui_url,
711
+ "session_dir": self._session_dir,
712
+ "metrics_export_port": self._metrics_export_port,
713
+ "gcs_address": self.gcs_address,
714
+ "address": self.address,
715
+ "dashboard_agent_listen_port": self.dashboard_agent_listen_port,
716
+ }
717
+
718
+ def is_head(self):
719
+ return self.head
720
+
721
+ def get_gcs_client(self):
722
+ if self._gcs_client is None:
723
+ self._init_gcs_client()
724
+ return self._gcs_client
725
+
726
+ def _init_gcs_client(self):
727
+ if self.head:
728
+ gcs_process = self.all_processes[ray_constants.PROCESS_TYPE_GCS_SERVER][
729
+ 0
730
+ ].process
731
+ else:
732
+ gcs_process = None
733
+
734
+ # TODO(ryw) instead of create a new GcsClient, wrap the one from
735
+ # CoreWorkerProcess to save a grpc channel.
736
+ for _ in range(ray_constants.NUM_REDIS_GET_RETRIES):
737
+ gcs_address = None
738
+ last_ex = None
739
+ try:
740
+ gcs_address = self.gcs_address
741
+ client = GcsClient(
742
+ address=gcs_address,
743
+ cluster_id=self._ray_params.cluster_id, # Hex string
744
+ )
745
+ self.cluster_id = client.cluster_id
746
+ if self.head:
747
+ # Send a simple request to make sure GCS is alive
748
+ # if it's a head node.
749
+ client.internal_kv_get(b"dummy", None)
750
+ self._gcs_client = client
751
+ break
752
+ except Exception:
753
+ if gcs_process is not None and gcs_process.poll() is not None:
754
+ # GCS has exited.
755
+ break
756
+ last_ex = traceback.format_exc()
757
+ logger.debug(f"Connecting to GCS: {last_ex}")
758
+ time.sleep(1)
759
+
760
+ if self._gcs_client is None:
761
+ if hasattr(self, "_logs_dir"):
762
+ with open(os.path.join(self._logs_dir, "gcs_server.err")) as err:
763
+ # Use " C " or " E " to exclude the stacktrace.
764
+ # This should work for most cases, especitally
765
+ # it's when GCS is starting. Only display last 10 lines of logs.
766
+ errors = [e for e in err.readlines() if " C " in e or " E " in e][
767
+ -10:
768
+ ]
769
+ error_msg = "\n" + "".join(errors) + "\n"
770
+ raise RuntimeError(
771
+ f"Failed to {'start' if self.head else 'connect to'} GCS. "
772
+ f" Last {len(errors)} lines of error files:"
773
+ f"{error_msg}."
774
+ f"Please check {os.path.join(self._logs_dir, 'gcs_server.out')}"
775
+ f" for details. Last connection error: {last_ex}"
776
+ )
777
+ else:
778
+ raise RuntimeError(
779
+ f"Failed to {'start' if self.head else 'connect to'} GCS. Last "
780
+ f"connection error: {last_ex}"
781
+ )
782
+
783
+ ray.experimental.internal_kv._initialize_internal_kv(self._gcs_client)
784
+
785
+ def get_temp_dir_path(self):
786
+ """Get the path of the temporary directory."""
787
+ return self._temp_dir
788
+
789
+ def get_runtime_env_dir_path(self):
790
+ """Get the path of the runtime env."""
791
+ return self._runtime_env_dir
792
+
793
+ def get_session_dir_path(self):
794
+ """Get the path of the session directory."""
795
+ return self._session_dir
796
+
797
+ def get_logs_dir_path(self):
798
+ """Get the path of the log files directory."""
799
+ return self._logs_dir
800
+
801
+ def get_sockets_dir_path(self):
802
+ """Get the path of the sockets directory."""
803
+ return self._sockets_dir
804
+
805
+ def _make_inc_temp(
806
+ self, suffix: str = "", prefix: str = "", directory_name: Optional[str] = None
807
+ ):
808
+ """Return an incremental temporary file name. The file is not created.
809
+
810
+ Args:
811
+ suffix: The suffix of the temp file.
812
+ prefix: The prefix of the temp file.
813
+ directory_name (str) : The base directory of the temp file.
814
+
815
+ Returns:
816
+ A file name as a string. If a file with the same name already
+ exists, the returned name will look like
+ "{directory_name}/{prefix}.{unique_index}{suffix}"
819
+ """
820
+ if directory_name is None:
821
+ directory_name = ray._private.utils.get_ray_temp_dir()
822
+ directory_name = os.path.expanduser(directory_name)
823
+ index = self._incremental_dict[suffix, prefix, directory_name]
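+ # Resume from the last index recorded for this (suffix, prefix, directory)
+ # combination so repeated calls keep counting up.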
824
+ # `tempfile.TMP_MAX` could be extremely large,
825
+ # so using `range` in Python2.x should be avoided.
826
+ while index < tempfile.TMP_MAX:
827
+ if index == 0:
828
+ filename = os.path.join(directory_name, prefix + suffix)
829
+ else:
830
+ filename = os.path.join(
831
+ directory_name, prefix + "." + str(index) + suffix
832
+ )
833
+ index += 1
834
+ if not os.path.exists(filename):
835
+ # Save the index.
836
+ self._incremental_dict[suffix, prefix, directory_name] = index
837
+ return filename
838
+
839
+ raise FileExistsError(errno.EEXIST, "No usable temporary filename found")
840
+
841
+ def should_redirect_logs(self):
842
+ redirect_output = self._ray_params.redirect_output
843
+ if redirect_output is None:
844
+ # Fall back to stderr redirect environment variable.
845
+ redirect_output = (
846
+ os.environ.get(
847
+ ray_constants.LOGGING_REDIRECT_STDERR_ENVIRONMENT_VARIABLE
848
+ )
849
+ != "1"
850
+ )
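+ # Redirection stays enabled unless the env var is explicitly set to "1",
+ # which sends component logs to stderr instead of files.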
851
+ return redirect_output
852
+
853
+ def get_log_file_names(
854
+ self,
855
+ name: str,
856
+ unique: bool = False,
857
+ create_out: bool = True,
858
+ create_err: bool = True,
859
+ ) -> Tuple[Optional[str], Optional[str]]:
860
+ """Get filename to dump logs for stdout and stderr, with no files opened.
861
+ If output redirection has been disabled, no files will
862
+ be opened and `(None, None)` will be returned.
863
+
864
+ Args:
865
+ name: descriptive string for this log file.
866
+ unique: if true, a counter will be attached to `name` to
867
+ ensure the returned filename is not already used.
868
+ create_out: if True, create a .out file.
869
+ create_err: if True, create a .err file.
870
+
871
+ Returns:
872
+ A tuple of two file names for the optional (stdout, stderr) logs,
873
+ or `(None, None)` if output redirection is disabled.
874
+ """
875
+ if not self.should_redirect_logs():
876
+ return None, None
877
+
878
+ log_stdout = None
879
+ log_stderr = None
880
+
881
+ if create_out:
882
+ log_stdout = self._get_log_file_name(name, "out", unique=unique)
883
+ if create_err:
884
+ log_stderr = self._get_log_file_name(name, "err", unique=unique)
885
+ return log_stdout, log_stderr
886
+
887
+ def get_log_file_handles(
888
+ self,
889
+ name: str,
890
+ unique: bool = False,
891
+ create_out: bool = True,
892
+ create_err: bool = True,
893
+ ) -> Tuple[Optional[IO[AnyStr]], Optional[IO[AnyStr]]]:
894
+ """Open log files with partially randomized filenames, returning the
895
+ file handles. If output redirection has been disabled, no files will
896
+ be opened and `(None, None)` will be returned.
897
+
898
+ Args:
899
+ name: descriptive string for this log file.
900
+ unique: if true, a counter will be attached to `name` to
901
+ ensure the returned filename is not already used.
902
+ create_out: if True, create a .out file.
903
+ create_err: if True, create a .err file.
904
+
905
+ Returns:
906
+ A tuple of two file handles for redirecting optional (stdout, stderr),
907
+ or `(None, None)` if output redirection is disabled.
908
+ """
909
+ log_stdout_fname, log_stderr_fname = self.get_log_file_names(
910
+ name, unique=unique, create_out=create_out, create_err=create_err
911
+ )
912
+ log_stdout = None if log_stdout_fname is None else open_log(log_stdout_fname)
913
+ log_stderr = None if log_stderr_fname is None else open_log(log_stderr_fname)
914
+ return log_stdout, log_stderr
915
+
916
+ def _get_log_file_name(
917
+ self,
918
+ name: str,
919
+ suffix: str,
920
+ unique: bool = False,
921
+ ) -> str:
922
+ """Generate partially randomized filenames for log files.
923
+
924
+ Args:
925
+ name: descriptive string for this log file.
926
+ suffix: suffix of the file. Usually it is .out or .err.
927
+ unique: if true, a counter will be attached to `name` to
928
+ ensure the returned filename is not already used.
929
+
930
+ Returns:
931
+ The generated log file name.
932
+ """
933
+ # strip if the suffix is something like .out.
934
+ suffix = suffix.strip(".")
935
+
936
+ if unique:
937
+ filename = self._make_inc_temp(
938
+ suffix=f".{suffix}", prefix=name, directory_name=self._logs_dir
939
+ )
940
+ else:
941
+ filename = os.path.join(self._logs_dir, f"{name}.{suffix}")
942
+ return filename
943
+
944
+ def _get_unused_port(self, allocated_ports=None):
945
+ if allocated_ports is None:
946
+ allocated_ports = set()
947
+
948
+ s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
949
+ s.bind(("", 0))
950
+ port = s.getsockname()[1]
951
+
952
+ # Try to generate a port that is far above the 'next available' one.
953
+ # This solves issue #8254 where GRPC fails because the port assigned
954
+ # from this method has been used by a different process.
955
+ for _ in range(ray_constants.NUM_PORT_RETRIES):
956
+ new_port = random.randint(port, 65535)
957
+ if new_port in allocated_ports:
958
+ # This port is allocated for other usage already,
959
+ # so we shouldn't use it even if it's not in use right now.
960
+ continue
961
+ new_s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
962
+ try:
963
+ new_s.bind(("", new_port))
964
+ except OSError:
965
+ new_s.close()
966
+ continue
967
+ s.close()
968
+ new_s.close()
969
+ return new_port
970
+ logger.error("Unable to succeed in selecting a random port.")
971
+ s.close()
972
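+ # All retries failed; fall back to the port from the initial bind,
+ # even though another process may grab it first.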
+ return port
973
+
974
+ def _prepare_socket_file(self, socket_path: str, default_prefix: str):
975
+ """Prepare the socket file for raylet and plasma.
976
+
977
+ This method helps to prepare a socket file.
978
+ 1. Make the directory if the directory does not exist.
979
+ 2. If the socket file exists, do nothing (this just means we aren't the
980
+ first worker on the node).
981
+
982
+ Args:
983
+ socket_path: the socket file to prepare.
984
+ """
985
+ result = socket_path
986
+ is_mac = sys.platform.startswith("darwin")
987
+ if sys.platform == "win32":
988
+ if socket_path is None:
989
+ result = f"tcp://{self._localhost}" f":{self._get_unused_port()}"
990
+ else:
991
+ if socket_path is None:
992
+ result = self._make_inc_temp(
993
+ prefix=default_prefix, directory_name=self._sockets_dir
994
+ )
995
+ else:
996
+ try_to_create_directory(os.path.dirname(socket_path))
997
+
998
+ # Check socket path length to make sure it's short enough
999
+ maxlen = (104 if is_mac else 108) - 1 # sockaddr_un->sun_path
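+ # The subtracted byte leaves room for the trailing NUL terminator in sockaddr_un.sun_path.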
1000
+ if len(result.split("://", 1)[-1].encode("utf-8")) > maxlen:
1001
+ raise OSError(
1002
+ f"AF_UNIX path length cannot exceed {maxlen} bytes: {result!r}"
1003
+ )
1004
+ return result
1005
+
1006
+ def _get_cached_port(
1007
+ self, port_name: str, default_port: Optional[int] = None
1008
+ ) -> int:
1009
+ """Get a port number from a cache on this node.
1010
+
1011
+ Different driver processes on a node should use the same ports for
1012
+ some purposes, e.g. exporting metrics. This method returns a port
1013
+ number for the given port name and caches it in a file. If the
1014
+ port isn't already cached, an unused port is generated and cached.
1015
+
1016
+ Args:
1017
+ port_name: the name of the port, e.g. metrics_export_port
1018
+ default_port (Optional[int]): The port to return and cache if no
1019
+ port has already been cached for the given port_name. If None, an
1020
+ unused port is generated and cached.
1021
+ Returns:
1022
+ port: the port number.
1023
+ """
1024
+ file_path = os.path.join(self.get_session_dir_path(), "ports_by_node.json")
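+ # All driver processes on this node share this JSON file (guarded by the
+ # FileLock below) so they agree on which ports have been cached.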
1025
+
1026
+ # Make sure only the ports in RAY_CACHED_PORTS are cached.
1027
+ assert port_name in ray_constants.RAY_ALLOWED_CACHED_PORTS
1028
+
1029
+ # Maps a Node.unique_id to a dict that maps port names to port numbers.
1030
+ ports_by_node: Dict[str, Dict[str, int]] = defaultdict(dict)
1031
+
1032
+ with FileLock(file_path + ".lock"):
1033
+ if not os.path.exists(file_path):
1034
+ with open(file_path, "w") as f:
1035
+ json.dump({}, f)
1036
+
1037
+ with open(file_path, "r") as f:
1038
+ ports_by_node.update(json.load(f))
1039
+
1040
+ if (
1041
+ self.unique_id in ports_by_node
1042
+ and port_name in ports_by_node[self.unique_id]
1043
+ ):
1044
+ # The port has already been cached at this node, so use it.
1045
+ port = int(ports_by_node[self.unique_id][port_name])
1046
+ else:
1047
+ # Pick a new port to use and cache it at this node.
1048
+ allocated_ports = set(ports_by_node[self.unique_id].values())
1049
+
1050
+ if default_port is not None and default_port in allocated_ports:
1051
+ # The default port is already in use, so don't use it.
1052
+ default_port = None
1053
+
1054
+ port = default_port or self._get_unused_port(allocated_ports)
1055
+
1056
+ ports_by_node[self.unique_id][port_name] = port
1057
+ with open(file_path, "w") as f:
1058
+ json.dump(ports_by_node, f)
1059
+
1060
+ return port
1061
+
1062
+ def _wait_and_get_for_node_address(self, timeout_s: int = 60) -> str:
1063
+ """Wait until the RAY_NODE_IP_FILENAME file is avialable.
1064
+
1065
+ RAY_NODE_IP_FILENAME is created when a ray instance is started.
1066
+
1067
+ Args:
1068
+ timeout_s: If the ip address is not found within this
1069
+ timeout, it will raise ValueError.
1070
+ Returns:
1071
+ The node_ip_address of the current session if it finds it
1072
+ within timeout_s.
1073
+ """
1074
+ for i in range(timeout_s):
1075
+ node_ip_address = ray._private.services.get_cached_node_ip_address(
1076
+ self.get_session_dir_path()
1077
+ )
1078
+
1079
+ if node_ip_address is not None:
1080
+ return node_ip_address
1081
+
1082
+ time.sleep(1)
1083
+ if i % 10 == 0:
1084
+ logger.info(
1085
+ f"Can't find a `{ray_constants.RAY_NODE_IP_FILENAME}` "
1086
+ f"file from {self.get_session_dir_path()}. "
1087
+ "Have you started Ray instance using "
1088
+ "`ray start` or `ray.init`?"
1089
+ )
1090
+
1091
+ raise ValueError(
1092
+ f"Can't find a `{ray_constants.RAY_NODE_IP_FILENAME}` "
1093
+ f"file from {self.get_session_dir_path()}. "
1094
+ f"for {timeout_s} seconds. "
1095
+ "A ray instance hasn't started. "
1096
+ "Did you do `ray start` or `ray.init` on this host?"
1097
+ )
1098
+
1099
+ def start_reaper_process(self):
1100
+ """
1101
+ Start the reaper process.
1102
+
1103
+ This must be the first process spawned and should only be called when
1104
+ ray processes should be cleaned up if this process dies.
1105
+ """
1106
+ assert (
1107
+ not self.kernel_fate_share
1108
+ ), "a reaper should not be used with kernel fate-sharing"
1109
+ process_info = ray._private.services.start_reaper(fate_share=False)
1110
+ assert ray_constants.PROCESS_TYPE_REAPER not in self.all_processes
1111
+ if process_info is not None:
1112
+ self.all_processes[ray_constants.PROCESS_TYPE_REAPER] = [
1113
+ process_info,
1114
+ ]
1115
+
1116
+ def start_log_monitor(self):
1117
+ """Start the log monitor."""
1118
+ # Only redirect logs to .err. .err file is only useful when the
1119
+ # component has an unexpected output to stdout/stderr.
1120
+ _, stderr_file = self.get_log_file_handles(
1121
+ "log_monitor", unique=True, create_out=False
1122
+ )
1123
+ process_info = ray._private.services.start_log_monitor(
1124
+ self.get_session_dir_path(),
1125
+ self._logs_dir,
1126
+ self.gcs_address,
1127
+ fate_share=self.kernel_fate_share,
1128
+ max_bytes=self.max_bytes,
1129
+ backup_count=self.backup_count,
1130
+ redirect_logging=self.should_redirect_logs(),
1131
+ stdout_file=stderr_file,
1132
+ stderr_file=stderr_file,
1133
+ )
1134
+ assert ray_constants.PROCESS_TYPE_LOG_MONITOR not in self.all_processes
1135
+ self.all_processes[ray_constants.PROCESS_TYPE_LOG_MONITOR] = [
1136
+ process_info,
1137
+ ]
1138
+
1139
+ def start_api_server(
1140
+ self, *, include_dashboard: Optional[bool], raise_on_failure: bool
1141
+ ):
1142
+ """Start the dashboard.
1143
+
1144
+ Args:
1145
+ include_dashboard: If true, this will load all dashboard-related modules
1146
+ when starting the API server. Otherwise, it will only
1147
+ start the modules that are not relevant to the dashboard.
1148
+ raise_on_failure: If true, this will raise an exception
1149
+ if we fail to start the API server. Otherwise it will print
1150
+ a warning if we fail to start the API server.
1151
+ """
1152
+ # Only redirect logs to .err. .err file is only useful when the
1153
+ # component has an unexpected output to stdout/stderr.
1154
+ _, stderr_file = self.get_log_file_handles(
1155
+ "dashboard", unique=True, create_out=False
1156
+ )
1157
+ self._webui_url, process_info = ray._private.services.start_api_server(
1158
+ include_dashboard,
1159
+ raise_on_failure,
1160
+ self._ray_params.dashboard_host,
1161
+ self.gcs_address,
1162
+ self.cluster_id.hex(),
1163
+ self._node_ip_address,
1164
+ self._temp_dir,
1165
+ self._logs_dir,
1166
+ self._session_dir,
1167
+ port=self._ray_params.dashboard_port,
1168
+ dashboard_grpc_port=self._ray_params.dashboard_grpc_port,
1169
+ fate_share=self.kernel_fate_share,
1170
+ max_bytes=self.max_bytes,
1171
+ backup_count=self.backup_count,
1172
+ redirect_logging=self.should_redirect_logs(),
1173
+ stdout_file=stderr_file,
1174
+ stderr_file=stderr_file,
1175
+ )
1176
+ assert ray_constants.PROCESS_TYPE_DASHBOARD not in self.all_processes
1177
+ if process_info is not None:
1178
+ self.all_processes[ray_constants.PROCESS_TYPE_DASHBOARD] = [
1179
+ process_info,
1180
+ ]
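+ # Publish the dashboard URL in the GCS internal KV so other components can look it up.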
1181
+ self.get_gcs_client().internal_kv_put(
1182
+ b"webui:url",
1183
+ self._webui_url.encode(),
1184
+ True,
1185
+ ray_constants.KV_NAMESPACE_DASHBOARD,
1186
+ )
1187
+
1188
+ def start_gcs_server(self):
1189
+ """Start the gcs server."""
1190
+ gcs_server_port = self._ray_params.gcs_server_port
1191
+ assert gcs_server_port > 0
1192
+ assert self._gcs_address is None, "GCS server is already running."
1193
+ assert self._gcs_client is None, "GCS client is already connected."
1194
+
1195
+ # TODO(hjiang): Update stderr to pass filename and get spdlog to handle
1196
+ # logging as well.
1197
+ stdout_log_fname, _ = self.get_log_file_names(
1198
+ "gcs_server", unique=True, create_out=True, create_err=False
1199
+ )
1200
+ _, stderr_file = self.get_log_file_handles(
1201
+ "gcs_server", unique=True, create_out=False, create_err=True
1202
+ )
1203
+ process_info = ray._private.services.start_gcs_server(
1204
+ self.redis_address,
1205
+ log_dir=self._logs_dir,
1206
+ ray_log_filepath=stdout_log_fname,
1207
+ stderr_file=stderr_file,
1208
+ session_name=self.session_name,
1209
+ redis_username=self._ray_params.redis_username,
1210
+ redis_password=self._ray_params.redis_password,
1211
+ config=self._config,
1212
+ fate_share=self.kernel_fate_share,
1213
+ gcs_server_port=gcs_server_port,
1214
+ metrics_agent_port=self._ray_params.metrics_agent_port,
1215
+ node_ip_address=self._node_ip_address,
1216
+ )
1217
+ assert ray_constants.PROCESS_TYPE_GCS_SERVER not in self.all_processes
1218
+ self.all_processes[ray_constants.PROCESS_TYPE_GCS_SERVER] = [
1219
+ process_info,
1220
+ ]
1221
+ # Connecting via non-localhost address may be blocked by firewall rule,
1222
+ # e.g. https://github.com/ray-project/ray/issues/15780
1223
+ # TODO(mwtian): figure out a way to use 127.0.0.1 for local connection
1224
+ # when possible.
1225
+ self._gcs_address = f"{self._node_ip_address}:" f"{gcs_server_port}"
1226
+
1227
+ def start_raylet(
1228
+ self,
1229
+ plasma_directory: str,
1230
+ object_store_memory: int,
1231
+ use_valgrind: bool = False,
1232
+ use_profiler: bool = False,
1233
+ enable_physical_mode: bool = False,
1234
+ ):
1235
+ """Start the raylet.
1236
+
1237
+ Args:
1238
+ use_valgrind: True if we should start the process in
1239
+ valgrind.
1240
+ use_profiler: True if we should start the process in the
1241
+ valgrind profiler.
1242
+ """
1243
+ stdout_log_fname, _ = self.get_log_file_names(
1244
+ "raylet", unique=True, create_out=True, create_err=False
1245
+ )
1246
+ _, stderr_file = self.get_log_file_handles(
1247
+ "raylet", unique=True, create_out=False, create_err=True
1248
+ )
1249
+ process_info = ray._private.services.start_raylet(
1250
+ self.redis_address,
1251
+ self.gcs_address,
1252
+ self._node_id,
1253
+ self._node_ip_address,
1254
+ self._ray_params.node_manager_port,
1255
+ self._raylet_socket_name,
1256
+ self._plasma_store_socket_name,
1257
+ self.cluster_id.hex(),
1258
+ self._ray_params.worker_path,
1259
+ self._ray_params.setup_worker_path,
1260
+ self._ray_params.storage,
1261
+ self._temp_dir,
1262
+ self._session_dir,
1263
+ self._runtime_env_dir,
1264
+ self._logs_dir,
1265
+ self.get_resource_spec(),
1266
+ plasma_directory,
1267
+ object_store_memory,
1268
+ self.session_name,
1269
+ is_head_node=self.is_head(),
1270
+ min_worker_port=self._ray_params.min_worker_port,
1271
+ max_worker_port=self._ray_params.max_worker_port,
1272
+ worker_port_list=self._ray_params.worker_port_list,
1273
+ object_manager_port=self._ray_params.object_manager_port,
1274
+ redis_username=self._ray_params.redis_username,
1275
+ redis_password=self._ray_params.redis_password,
1276
+ metrics_agent_port=self._ray_params.metrics_agent_port,
1277
+ runtime_env_agent_port=self._ray_params.runtime_env_agent_port,
1278
+ metrics_export_port=self._metrics_export_port,
1279
+ dashboard_agent_listen_port=self._ray_params.dashboard_agent_listen_port,
1280
+ use_valgrind=use_valgrind,
1281
+ use_profiler=use_profiler,
1282
+ ray_log_filepath=stdout_log_fname,
1283
+ stderr_file=stderr_file,
1284
+ huge_pages=self._ray_params.huge_pages,
1285
+ fate_share=self.kernel_fate_share,
1286
+ socket_to_use=None,
1287
+ max_bytes=self.max_bytes,
1288
+ backup_count=self.backup_count,
1289
+ ray_debugger_external=self._ray_params.ray_debugger_external,
1290
+ env_updates=self._ray_params.env_vars,
1291
+ node_name=self._ray_params.node_name,
1292
+ webui=self._webui_url,
1293
+ labels=self._get_node_labels(),
1294
+ enable_physical_mode=enable_physical_mode,
1295
+ )
1296
+ assert ray_constants.PROCESS_TYPE_RAYLET not in self.all_processes
1297
+ self.all_processes[ray_constants.PROCESS_TYPE_RAYLET] = [process_info]
1298
+
1299
+ def start_worker(self):
1300
+ """Start a worker process."""
1301
+ raise NotImplementedError
1302
+
1303
+ def start_monitor(self):
1304
+ """Start the monitor.
1305
+
1306
+ Autoscaling output goes to the monitor.err/out files, and
1307
+ any modification to these files may break existing
1308
+ cluster launching commands.
1309
+ """
1310
+ from ray.autoscaler.v2.utils import is_autoscaler_v2
1311
+
1312
+ stdout_file, stderr_file = self.get_log_file_handles("monitor", unique=True)
1313
+ process_info = ray._private.services.start_monitor(
1314
+ self.gcs_address,
1315
+ self._logs_dir,
1316
+ stdout_file=stdout_file,
1317
+ stderr_file=stderr_file,
1318
+ autoscaling_config=self._ray_params.autoscaling_config,
1319
+ fate_share=self.kernel_fate_share,
1320
+ max_bytes=self.max_bytes,
1321
+ backup_count=self.backup_count,
1322
+ monitor_ip=self._node_ip_address,
1323
+ autoscaler_v2=is_autoscaler_v2(fetch_from_server=True),
1324
+ )
1325
+ assert ray_constants.PROCESS_TYPE_MONITOR not in self.all_processes
1326
+ self.all_processes[ray_constants.PROCESS_TYPE_MONITOR] = [process_info]
1327
+
1328
+ def start_ray_client_server(self):
1329
+ """Start the ray client server process."""
1330
+ stdout_file, stderr_file = self.get_log_file_handles(
1331
+ "ray_client_server", unique=True
1332
+ )
1333
+ process_info = ray._private.services.start_ray_client_server(
1334
+ self.address,
1335
+ self._node_ip_address,
1336
+ self._ray_params.ray_client_server_port,
1337
+ stdout_file=stdout_file,
1338
+ stderr_file=stderr_file,
1339
+ redis_username=self._ray_params.redis_username,
1340
+ redis_password=self._ray_params.redis_password,
1341
+ fate_share=self.kernel_fate_share,
1342
+ runtime_env_agent_address=self.runtime_env_agent_address,
1343
+ )
1344
+ assert ray_constants.PROCESS_TYPE_RAY_CLIENT_SERVER not in self.all_processes
1345
+ self.all_processes[ray_constants.PROCESS_TYPE_RAY_CLIENT_SERVER] = [
1346
+ process_info
1347
+ ]
1348
+
1349
+ def _write_cluster_info_to_kv(self):
1350
+ """Write the cluster metadata to GCS.
1351
+ Cluster metadata is always recorded, but it is
+ not reported unless usage reporting is enabled.
1353
+ Check `usage_stats_head.py` for more details.
1354
+ """
1355
+ # Make sure the cluster metadata wasn't reported before.
1356
+ import ray._private.usage.usage_lib as ray_usage_lib
1357
+
1358
+ ray_usage_lib.put_cluster_metadata(
1359
+ self.get_gcs_client(), ray_init_cluster=self.ray_init_cluster
1360
+ )
1361
+ # Make sure GCS is up.
1362
+ added = self.get_gcs_client().internal_kv_put(
1363
+ b"session_name",
1364
+ self._session_name.encode(),
1365
+ False,
1366
+ ray_constants.KV_NAMESPACE_SESSION,
1367
+ )
1368
+ if not added:
1369
+ curr_val = self.get_gcs_client().internal_kv_get(
1370
+ b"session_name", ray_constants.KV_NAMESPACE_SESSION
1371
+ )
1372
+ assert curr_val == self._session_name.encode("utf-8"), (
1373
+ f"Session name {self._session_name} does not match "
1374
+ f"persisted value {curr_val}. Perhaps there was an "
1375
+ f"error connecting to Redis."
1376
+ )
1377
+
1378
+ self.get_gcs_client().internal_kv_put(
1379
+ b"session_dir",
1380
+ self._session_dir.encode(),
1381
+ True,
1382
+ ray_constants.KV_NAMESPACE_SESSION,
1383
+ )
1384
+ self.get_gcs_client().internal_kv_put(
1385
+ b"temp_dir",
1386
+ self._temp_dir.encode(),
1387
+ True,
1388
+ ray_constants.KV_NAMESPACE_SESSION,
1389
+ )
1390
+ if self._ray_params.storage is not None:
1391
+ self.get_gcs_client().internal_kv_put(
1392
+ b"storage",
1393
+ self._ray_params.storage.encode(),
1394
+ True,
1395
+ ray_constants.KV_NAMESPACE_SESSION,
1396
+ )
1397
+ # Add tracing_startup_hook to redis / internal kv manually
1398
+ # since internal kv is not yet initialized.
1399
+ if self._ray_params.tracing_startup_hook:
1400
+ self.get_gcs_client().internal_kv_put(
1401
+ b"tracing_startup_hook",
1402
+ self._ray_params.tracing_startup_hook.encode(),
1403
+ True,
1404
+ ray_constants.KV_NAMESPACE_TRACING,
1405
+ )
1406
+
1407
+ def start_head_processes(self):
1408
+ """Start head processes on the node."""
1409
+ logger.debug(
1410
+ f"Process STDOUT and STDERR is being " f"redirected to {self._logs_dir}."
1411
+ )
1412
+ assert self._gcs_address is None
1413
+ assert self._gcs_client is None
1414
+
1415
+ self.start_gcs_server()
1416
+ assert self.get_gcs_client() is not None
1417
+ self._write_cluster_info_to_kv()
1418
+
1419
+ if not self._ray_params.no_monitor:
1420
+ self.start_monitor()
1421
+
1422
+ if self._ray_params.ray_client_server_port:
1423
+ self.start_ray_client_server()
1424
+
1425
+ if self._ray_params.include_dashboard is None:
1426
+ # Default
1427
+ raise_on_api_server_failure = False
1428
+ else:
1429
+ raise_on_api_server_failure = self._ray_params.include_dashboard
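+ # Failures starting the API server are fatal only when the dashboard was
+ # explicitly requested; the default (None) only produces a warning.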
1430
+
1431
+ self.start_api_server(
1432
+ include_dashboard=self._ray_params.include_dashboard,
1433
+ raise_on_failure=raise_on_api_server_failure,
1434
+ )
1435
+
1436
+ def start_ray_processes(self):
1437
+ """Start all of the processes on the node."""
1438
+ logger.debug(
1439
+ f"Process STDOUT and STDERR is being " f"redirected to {self._logs_dir}."
1440
+ )
1441
+
1442
+ if not self.head:
1443
+ # Get the system config from GCS first if this is a non-head node.
1444
+ gcs_options = ray._raylet.GcsClientOptions.create(
1445
+ self.gcs_address,
1446
+ self.cluster_id.hex(),
1447
+ allow_cluster_id_nil=False,
1448
+ fetch_cluster_id_if_nil=False,
1449
+ )
1450
+ global_state = ray._private.state.GlobalState()
1451
+ global_state._initialize_global_state(gcs_options)
1452
+ new_config = global_state.get_system_config()
1453
+ assert self._config.items() <= new_config.items(), (
1454
+ "The system config from GCS is not a superset of the local"
1455
+ " system config. There might be a configuration inconsistency"
1456
+ " issue between the head node and non-head nodes."
1457
+ f" Local system config: {self._config},"
1458
+ f" GCS system config: {new_config}"
1459
+ )
1460
+ self._config = new_config
1461
+
1462
+ # Make sure we don't call `determine_plasma_store_config` multiple
1463
+ # times to avoid printing multiple warnings.
1464
+ resource_spec = self.get_resource_spec()
1465
+ (
1466
+ plasma_directory,
1467
+ object_store_memory,
1468
+ ) = ray._private.services.determine_plasma_store_config(
1469
+ resource_spec.object_store_memory,
1470
+ plasma_directory=self._ray_params.plasma_directory,
1471
+ huge_pages=self._ray_params.huge_pages,
1472
+ )
1473
+ self.start_raylet(plasma_directory, object_store_memory)
1474
+ if self._ray_params.include_log_monitor:
1475
+ self.start_log_monitor()
1476
+
1477
+ def _kill_process_type(
1478
+ self,
1479
+ process_type,
1480
+ allow_graceful: bool = False,
1481
+ check_alive: bool = True,
1482
+ wait: bool = False,
1483
+ ):
1484
+ """Kill a process of a given type.
1485
+
1486
+ If the process type is PROCESS_TYPE_REDIS_SERVER, then we will kill all
1487
+ of the Redis servers.
1488
+
1489
+ If the process was started in valgrind, then we will raise an exception
1490
+ if the process has a non-zero exit code.
1491
+
1492
+ Args:
1493
+ process_type: The type of the process to kill.
1494
+ allow_graceful: Send a SIGTERM first and give the process
1495
+ time to exit gracefully. If that doesn't work, then use
1496
+ SIGKILL. We usually want to do this outside of tests.
1497
+ check_alive: If true, then we expect the process to be alive
1498
+ and will raise an exception if the process is already dead.
1499
+ wait: If true, then this method will not return until the
1500
+ process in question has exited.
1501
+
1502
+ Raises:
1503
+ This method raises an exception in the following cases:
1504
+ 1. The process had already died and check_alive is true.
1505
+ 2. The process had been started in valgrind and had a non-zero
1506
+ exit code.
1507
+ """
1508
+
1509
+ # Ensure thread safety
1510
+ with self.removal_lock:
1511
+ self._kill_process_impl(
1512
+ process_type,
1513
+ allow_graceful=allow_graceful,
1514
+ check_alive=check_alive,
1515
+ wait=wait,
1516
+ )
1517
+
1518
+ def _kill_process_impl(
1519
+ self, process_type, allow_graceful=False, check_alive=True, wait=False
1520
+ ):
1521
+ """See `_kill_process_type`."""
1522
+ if process_type not in self.all_processes:
1523
+ return
1524
+ process_infos = self.all_processes[process_type]
1525
+ if process_type != ray_constants.PROCESS_TYPE_REDIS_SERVER:
1526
+ assert len(process_infos) == 1
1527
+ for process_info in process_infos:
1528
+ process = process_info.process
1529
+ # Handle the case where the process has already exited.
1530
+ if process.poll() is not None:
1531
+ if check_alive:
1532
+ raise RuntimeError(
1533
+ "Attempting to kill a process of type "
1534
+ f"'{process_type}', but this process is already dead."
1535
+ )
1536
+ else:
1537
+ continue
1538
+
1539
+ if process_info.use_valgrind:
1540
+ process.terminate()
1541
+ process.wait()
1542
+ if process.returncode != 0:
1543
+ message = (
1544
+ "Valgrind detected some errors in process of "
1545
+ f"type {process_type}. Error code {process.returncode}."
1546
+ )
1547
+ if process_info.stdout_file is not None:
1548
+ with open(process_info.stdout_file, "r") as f:
1549
+ message += "\nPROCESS STDOUT:\n" + f.read()
1550
+ if process_info.stderr_file is not None:
1551
+ with open(process_info.stderr_file, "r") as f:
1552
+ message += "\nPROCESS STDERR:\n" + f.read()
1553
+ raise RuntimeError(message)
1554
+ continue
1555
+
1556
+ if process_info.use_valgrind_profiler:
1557
+ # Give process signal to write profiler data.
1558
+ os.kill(process.pid, signal.SIGINT)
1559
+ # Wait for profiling data to be written.
1560
+ time.sleep(0.1)
1561
+
1562
+ if allow_graceful:
1563
+ process.terminate()
1564
+ # Allow the process one second to exit gracefully.
1565
+ timeout_seconds = 1
1566
+ try:
1567
+ process.wait(timeout_seconds)
1568
+ except subprocess.TimeoutExpired:
1569
+ pass
1570
+
1571
+ # If the process did not exit, force kill it.
1572
+ if process.poll() is None:
1573
+ process.kill()
1574
+ # The reason we usually don't call process.wait() here is that
1575
+ # there's some chance we'd end up waiting a really long time.
1576
+ if wait:
1577
+ process.wait()
1578
+
1579
+ del self.all_processes[process_type]
1580
+
1581
+ def kill_redis(self, check_alive: bool = True):
1582
+ """Kill the Redis servers.
1583
+
1584
+ Args:
1585
+ check_alive: Raise an exception if any of the processes
1586
+ were already dead.
1587
+ """
1588
+ self._kill_process_type(
1589
+ ray_constants.PROCESS_TYPE_REDIS_SERVER, check_alive=check_alive
1590
+ )
1591
+
1592
+ def kill_raylet(self, check_alive: bool = True):
1593
+ """Kill the raylet.
1594
+
1595
+ Args:
1596
+ check_alive: Raise an exception if the process was already
1597
+ dead.
1598
+ """
1599
+ self._kill_process_type(
1600
+ ray_constants.PROCESS_TYPE_RAYLET, check_alive=check_alive
1601
+ )
1602
+
1603
+ def kill_log_monitor(self, check_alive: bool = True):
1604
+ """Kill the log monitor.
1605
+
1606
+ Args:
1607
+ check_alive: Raise an exception if the process was already
1608
+ dead.
1609
+ """
1610
+ self._kill_process_type(
1611
+ ray_constants.PROCESS_TYPE_LOG_MONITOR, check_alive=check_alive
1612
+ )
1613
+
1614
+ def kill_reporter(self, check_alive: bool = True):
1615
+ """Kill the reporter.
1616
+
1617
+ Args:
1618
+ check_alive: Raise an exception if the process was already
1619
+ dead.
1620
+ """
1621
+ self._kill_process_type(
1622
+ ray_constants.PROCESS_TYPE_REPORTER, check_alive=check_alive
1623
+ )
1624
+
1625
+ def kill_dashboard(self, check_alive: bool = True):
1626
+ """Kill the dashboard.
1627
+
1628
+ Args:
1629
+ check_alive: Raise an exception if the process was already
1630
+ dead.
1631
+ """
1632
+ self._kill_process_type(
1633
+ ray_constants.PROCESS_TYPE_DASHBOARD, check_alive=check_alive
1634
+ )
1635
+
1636
+ def kill_monitor(self, check_alive: bool = True):
1637
+ """Kill the monitor.
1638
+
1639
+ Args:
1640
+ check_alive: Raise an exception if the process was already
1641
+ dead.
1642
+ """
1643
+ self._kill_process_type(
1644
+ ray_constants.PROCESS_TYPE_MONITOR, check_alive=check_alive
1645
+ )
1646
+
1647
+ def kill_gcs_server(self, check_alive: bool = True):
1648
+ """Kill the gcs server.
1649
+
1650
+ Args:
1651
+ check_alive: Raise an exception if the process was already
1652
+ dead.
1653
+ """
1654
+ self._kill_process_type(
1655
+ ray_constants.PROCESS_TYPE_GCS_SERVER, check_alive=check_alive, wait=True
1656
+ )
1657
+ # Clear GCS client and address to indicate no GCS server is running.
1658
+ self._gcs_address = None
1659
+ self._gcs_client = None
1660
+
1661
+ def kill_reaper(self, check_alive: bool = True):
1662
+ """Kill the reaper process.
1663
+
1664
+ Args:
1665
+ check_alive: Raise an exception if the process was already
1666
+ dead.
1667
+ """
1668
+ self._kill_process_type(
1669
+ ray_constants.PROCESS_TYPE_REAPER, check_alive=check_alive
1670
+ )
1671
+
1672
+ def kill_all_processes(self, check_alive=True, allow_graceful=False, wait=False):
1673
+ """Kill all of the processes.
1674
+
1675
+ Note that this is slower than necessary because it calls kill, wait,
1676
+ kill, wait, ... instead of kill, kill, ..., wait, wait, ...
1677
+
1678
+ Args:
1679
+ check_alive: Raise an exception if any of the processes were
1680
+ already dead.
1681
+ wait: If true, then this method will not return until the
1682
+ process in question has exited.
1683
+ """
1684
+ # Kill the raylet first. This is important for suppressing errors at
1685
+ # shutdown because we give the raylet a chance to exit gracefully and
1686
+ # clean up its child worker processes. If we were to kill the plasma
1687
+ # store (or Redis) first, that could cause the raylet to exit
1688
+ # ungracefully, leading to more verbose output from the workers.
1689
+ if ray_constants.PROCESS_TYPE_RAYLET in self.all_processes:
1690
+ self._kill_process_type(
1691
+ ray_constants.PROCESS_TYPE_RAYLET,
1692
+ check_alive=check_alive,
1693
+ allow_graceful=allow_graceful,
1694
+ wait=wait,
1695
+ )
1696
+
1697
+ if ray_constants.PROCESS_TYPE_GCS_SERVER in self.all_processes:
1698
+ self._kill_process_type(
1699
+ ray_constants.PROCESS_TYPE_GCS_SERVER,
1700
+ check_alive=check_alive,
1701
+ allow_graceful=allow_graceful,
1702
+ wait=wait,
1703
+ )
1704
+
1705
+ # We call "list" to copy the keys because we are modifying the
1706
+ # dictionary while iterating over it.
1707
+ for process_type in list(self.all_processes.keys()):
1708
+ # Need to kill the reaper process last in case we die unexpectedly
1709
+ # while cleaning up.
1710
+ if process_type != ray_constants.PROCESS_TYPE_REAPER:
1711
+ self._kill_process_type(
1712
+ process_type,
1713
+ check_alive=check_alive,
1714
+ allow_graceful=allow_graceful,
1715
+ wait=wait,
1716
+ )
1717
+
1718
+ if ray_constants.PROCESS_TYPE_REAPER in self.all_processes:
1719
+ self._kill_process_type(
1720
+ ray_constants.PROCESS_TYPE_REAPER,
1721
+ check_alive=check_alive,
1722
+ allow_graceful=allow_graceful,
1723
+ wait=wait,
1724
+ )
1725
+
1726
+ def live_processes(self):
1727
+ """Return a list of the live processes.
1728
+
1729
+ Returns:
1730
+ A list of the live processes.
1731
+ """
1732
+ result = []
1733
+ for process_type, process_infos in self.all_processes.items():
1734
+ for process_info in process_infos:
1735
+ if process_info.process.poll() is None:
1736
+ result.append((process_type, process_info.process))
1737
+ return result
1738
+
1739
+ def dead_processes(self):
1740
+ """Return a list of the dead processes.
1741
+
1742
+ Note that this ignores processes that have been explicitly killed,
1743
+ e.g., via a command like node.kill_raylet().
1744
+
1745
+ Returns:
1746
+ A list of the dead processes ignoring the ones that have been
1747
+ explicitly killed.
1748
+ """
1749
+ result = []
1750
+ for process_type, process_infos in self.all_processes.items():
1751
+ for process_info in process_infos:
1752
+ if process_info.process.poll() is not None:
1753
+ result.append((process_type, process_info.process))
1754
+ return result
1755
+
1756
+ def any_processes_alive(self):
1757
+ """Return true if any processes are still alive.
1758
+
1759
+ Returns:
1760
+ True if any process is still alive.
1761
+ """
1762
+ return any(self.live_processes())
1763
+
1764
+ def remaining_processes_alive(self):
1765
+ """Return true if all remaining processes are still alive.
1766
+
1767
+ Note that this ignores processes that have been explicitly killed,
1768
+ e.g., via a command like node.kill_raylet().
1769
+
1770
+ Returns:
1771
+ True if every process that wasn't explicitly killed is still alive.
1772
+ """
1773
+ return not any(self.dead_processes())
1774
+
1775
+ def destroy_external_storage(self):
1776
+ object_spilling_config = self._config.get("object_spilling_config", {})
1777
+ if object_spilling_config:
1778
+ object_spilling_config = json.loads(object_spilling_config)
1779
+ from ray._private import external_storage
1780
+
1781
+ storage = external_storage.setup_external_storage(
1782
+ object_spilling_config, self._node_id, self._session_name
1783
+ )
1784
+ storage.destroy_external_storage()
1785
+
1786
+ def validate_external_storage(self):
1787
+ """Make sure we can setup the object spilling external storage.
1788
+ This will also fill up the default setting for object spilling
1789
+ if not specified.
1790
+ """
1791
+ object_spilling_config = self._config.get("object_spilling_config", {})
1792
+ automatic_spilling_enabled = self._config.get(
1793
+ "automatic_object_spilling_enabled", True
1794
+ )
1795
+ if not automatic_spilling_enabled:
1796
+ return
1797
+
1798
+ if not object_spilling_config:
1799
+ object_spilling_config = os.environ.get("RAY_object_spilling_config", "")
1800
+
1801
+ # If the config is not specified, we fill up the default.
1802
+ if not object_spilling_config:
1803
+ object_spilling_config = json.dumps(
1804
+ {"type": "filesystem", "params": {"directory_path": self._session_dir}}
1805
+ )
1806
+
1807
+ # Try setting up the storage.
1808
+ # Configure the proper system config.
1809
+ # We need to set both ray param's system config and self._config
1810
+ # because they could've been diverged at this point.
1811
+ deserialized_config = json.loads(object_spilling_config)
1812
+ self._ray_params._system_config[
1813
+ "object_spilling_config"
1814
+ ] = object_spilling_config
1815
+ self._config["object_spilling_config"] = object_spilling_config
1816
+
1817
+ is_external_storage_type_fs = deserialized_config["type"] == "filesystem"
1818
+ self._ray_params._system_config[
1819
+ "is_external_storage_type_fs"
1820
+ ] = is_external_storage_type_fs
1821
+ self._config["is_external_storage_type_fs"] = is_external_storage_type_fs
1822
+
1823
+ # Validate external storage usage.
1824
+ from ray._private import external_storage
1825
+
1826
+ # Node ID is available only after GCS is connected. However,
1827
+ # validate_external_storage() needs to be called before it to
1828
+ # be able to validate the configs early. Therefore, we use a
1829
+ # dummy node ID here and make sure external storage can be set
1830
+ # up based on the provided config. This storage is destroyed
1831
+ # right after the validation.
1832
+ dummy_node_id = ray.NodeID.from_random().hex()
1833
+ storage = external_storage.setup_external_storage(
1834
+ deserialized_config, dummy_node_id, self._session_name
1835
+ )
1836
+ storage.destroy_external_storage()
1837
+ external_storage.reset_external_storage()
1838
+
1839
+ def _record_stats(self):
1840
+ # This is only called when a new node is started.
1841
+ # Initialize the internal kv so that the metrics can be put
1842
+ from ray._private.usage.usage_lib import (
1843
+ TagKey,
1844
+ record_extra_usage_tag,
1845
+ record_hardware_usage,
1846
+ )
1847
+
1848
+ if not ray.experimental.internal_kv._internal_kv_initialized():
1849
+ ray.experimental.internal_kv._initialize_internal_kv(self.get_gcs_client())
1850
+ assert ray.experimental.internal_kv._internal_kv_initialized()
1851
+ if self.head:
1852
+ # record head node stats
1853
+ gcs_storage_type = (
1854
+ "redis" if os.environ.get("RAY_REDIS_ADDRESS") is not None else "memory"
1855
+ )
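+ # The GCS is backed by external Redis only when RAY_REDIS_ADDRESS is set;
+ # otherwise it keeps its state in memory.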
1856
+ record_extra_usage_tag(TagKey.GCS_STORAGE, gcs_storage_type)
1857
+ cpu_model_name = ray._private.utils.get_current_node_cpu_model_name()
1858
+ if cpu_model_name:
1859
+ # CPU model name can be an arbitrary long string
1860
+ # so we truncate it to the first 50 characters
1861
+ # to avoid any issues.
1862
+ record_hardware_usage(cpu_model_name[:50])
.venv/lib/python3.11/site-packages/ray/_private/parameter.py ADDED
@@ -0,0 +1,483 @@
1
+ import logging
2
+ import os
3
+ from typing import Dict, List, Optional
4
+
5
+ import ray._private.ray_constants as ray_constants
6
+ from ray._private.utils import (
7
+ validate_node_labels,
8
+ check_ray_client_dependencies_installed,
9
+ )
10
+
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
+ class RayParams:
16
+ """A class used to store the parameters used by Ray.
17
+
18
+ Attributes:
19
+ redis_address: The address of the Redis server to connect to. If
20
+ this address is not provided, then this command will start Redis, a
21
+ raylet, a plasma store, a plasma manager, and some workers.
22
+ It will also kill these processes when Python exits.
23
+ redis_port: The port that the primary Redis shard should listen
24
+ to. If None, then it will fall back to
25
+ ray._private.ray_constants.DEFAULT_PORT, or a random port if the default is
26
+ not available.
27
+ redis_shard_ports: A list of the ports to use for the non-primary Redis
28
+ shards. If None, then it will fall back to the ports right after
29
+ redis_port, or random ports if those are not available.
30
+ num_cpus: Number of CPUs to configure the raylet with.
31
+ num_gpus: Number of GPUs to configure the raylet with.
32
+ resources: A dictionary mapping the name of a resource to the quantity
33
+ of that resource available.
34
+ labels: The key-value labels of the node.
35
+ memory: Total available memory for workers requesting memory.
36
+ object_store_memory: The amount of memory (in bytes) to start the
37
+ object store with.
38
+ redis_max_memory: The max amount of memory (in bytes) to allow redis
39
+ to use, or None for no limit. Once the limit is exceeded, redis
40
+ will start LRU eviction of entries. This only applies to the
41
+ sharded redis tables (task and object tables).
42
+ object_manager_port: The port to use for the object manager.
43
+ node_manager_port: The port to use for the node manager.
44
+ gcs_server_port: The port to use for the GCS server.
45
+ node_ip_address: The IP address of the node that we are on.
46
+ raylet_ip_address: The IP address of the raylet that this node
47
+ connects to.
48
+ min_worker_port: The lowest port number that workers will bind
49
+ on. If not set or set to 0, random ports will be chosen.
50
+ max_worker_port: The highest port number that workers will bind
51
+ on. If set, min_worker_port must also be set.
52
+ worker_port_list: An explicit list of ports to be used for
53
+ workers (comma-separated). Overrides min_worker_port and
54
+ max_worker_port.
55
+ ray_client_server_port: The port number the ray client server
56
+ will bind on. If not set, the ray client server will not
57
+ be started.
58
+ object_ref_seed: Used to seed the deterministic generation of
59
+ object refs. The same value can be used across multiple runs of the
60
+ same job in order to generate the object refs in a consistent
61
+ manner. However, the same ID should not be used for different jobs.
62
+ redirect_output: True if stdout and stderr for non-worker
63
+ processes should be redirected to files and false otherwise.
64
+ external_addresses: The address of external Redis server to
65
+ connect to, in format of "ip1:port1,ip2:port2,...". If this
66
+ address is provided, then ray won't start Redis instances in the
67
+ head node but use external Redis server(s) instead.
68
+ num_redis_shards: The number of Redis shards to start in addition to
69
+ the primary Redis shard.
70
+ redis_max_clients: If provided, attempt to configure Redis with this
71
+ maxclients number.
72
+ redis_username: Prevents external clients without the username
73
+ from connecting to Redis if provided.
74
+ redis_password: Prevents external clients without the password
75
+ from connecting to Redis if provided.
76
+ plasma_directory: A directory where the Plasma memory mapped files will
77
+ be created.
78
+ worker_path: The path of the source code that will be run by the
79
+ worker.
80
+ setup_worker_path: The path of the Python file that will set up
81
+ the environment for the worker process.
82
+ huge_pages: Boolean flag indicating whether to start the Object
83
+ Store with hugetlbfs support. Requires plasma_directory.
84
+ include_dashboard: Boolean flag indicating whether to start the web
85
+ UI, which displays the status of the Ray cluster. If this value is
86
+ None, then the UI will be started if the relevant dependencies are
87
+ present.
88
+ dashboard_host: The host to bind the web UI server to. Can either be
89
+ localhost (127.0.0.1) or 0.0.0.0 (available from all interfaces).
90
+ By default, this is set to localhost to prevent access from
91
+ external machines.
92
+ dashboard_port: The port to bind the dashboard server to.
93
+ Defaults to 8265.
94
+ dashboard_agent_listen_port: The port for dashboard agents to listen on
95
+ for HTTP requests.
96
+ Defaults to 52365.
97
+ dashboard_grpc_port: The port for the dashboard head process to listen
98
+ for gRPC on.
99
+ Defaults to random available port.
100
+ runtime_env_agent_port: The port at which the runtime env agent
101
+ listens to for HTTP.
102
+ Defaults to random available port.
103
+ plasma_store_socket_name: If provided, it specifies the socket
104
+ name used by the plasma store.
105
+ raylet_socket_name: If provided, it specifies the socket path
106
+ used by the raylet process.
107
+ temp_dir: If provided, it will specify the root temporary
108
+ directory for the Ray process. Must be an absolute path.
109
+ storage: Specify a URI for persistent cluster-wide storage. This storage path
110
+ must be accessible by all nodes of the cluster, otherwise an error will be
111
+ raised.
112
+ runtime_env_dir_name: If provided, specifies the directory that
113
+ will be created in the session dir to hold runtime_env files.
114
+ include_log_monitor: If True, then start a log monitor to
115
+ monitor the log files for all processes on this node and push their
116
+ contents to Redis.
117
+ autoscaling_config: path to autoscaling config file.
118
+ metrics_agent_port: The port to bind metrics agent.
119
+ metrics_export_port: The port at which metrics are exposed
120
+ through a Prometheus endpoint.
121
+ no_monitor: If True, the ray autoscaler monitor for this cluster
122
+ will not be started.
123
+ _system_config: Configuration for overriding RayConfig
124
+ defaults. Used to set system configuration and for experimental Ray
125
+ core feature flags.
126
+ enable_object_reconstruction: Enable plasma reconstruction on
127
+ failure.
128
+ ray_debugger_external: If true, make the Ray debugger for a
129
+ worker available externally to the node it is running on. This will
130
+ bind on 0.0.0.0 instead of localhost.
131
+ env_vars: Override environment variables for the raylet.
132
+ session_name: The name of the session of the ray cluster.
133
+ webui: The url of the UI.
134
+ cluster_id: The cluster ID in hex string.
135
+ enable_physical_mode: Whether physical mode is enabled, which applies
136
+ constraint to tasks' resource consumption. As of now, only memory resource
137
+ is supported.
138
+ """
139
+
140
+ def __init__(
141
+ self,
142
+ redis_address: Optional[str] = None,
143
+ gcs_address: Optional[str] = None,
144
+ num_cpus: Optional[int] = None,
145
+ num_gpus: Optional[int] = None,
146
+ resources: Optional[Dict[str, float]] = None,
147
+ labels: Optional[Dict[str, str]] = None,
148
+ memory: Optional[float] = None,
149
+ object_store_memory: Optional[float] = None,
150
+ redis_max_memory: Optional[float] = None,
151
+ redis_port: Optional[int] = None,
152
+ redis_shard_ports: Optional[List[int]] = None,
153
+ object_manager_port: Optional[int] = None,
154
+ node_manager_port: int = 0,
155
+ gcs_server_port: Optional[int] = None,
156
+ node_ip_address: Optional[str] = None,
157
+ node_name: Optional[str] = None,
158
+ raylet_ip_address: Optional[str] = None,
159
+ min_worker_port: Optional[int] = None,
160
+ max_worker_port: Optional[int] = None,
161
+ worker_port_list: Optional[List[int]] = None,
162
+ ray_client_server_port: Optional[int] = None,
163
+ object_ref_seed: Optional[int] = None,
164
+ driver_mode=None,
165
+ redirect_output: Optional[bool] = None,
166
+ external_addresses: Optional[List[str]] = None,
167
+ num_redis_shards: Optional[int] = None,
168
+ redis_max_clients: Optional[int] = None,
169
+ redis_username: Optional[str] = ray_constants.REDIS_DEFAULT_USERNAME,
170
+ redis_password: Optional[str] = ray_constants.REDIS_DEFAULT_PASSWORD,
171
+ plasma_directory: Optional[str] = None,
172
+ worker_path: Optional[str] = None,
173
+ setup_worker_path: Optional[str] = None,
174
+ huge_pages: Optional[bool] = False,
175
+ include_dashboard: Optional[bool] = None,
176
+ dashboard_host: Optional[str] = ray_constants.DEFAULT_DASHBOARD_IP,
177
+ dashboard_port: Optional[bool] = ray_constants.DEFAULT_DASHBOARD_PORT,
178
+ dashboard_agent_listen_port: Optional[
179
+ int
180
+ ] = ray_constants.DEFAULT_DASHBOARD_AGENT_LISTEN_PORT,
181
+ runtime_env_agent_port: Optional[int] = None,
182
+ dashboard_grpc_port: Optional[int] = None,
183
+ plasma_store_socket_name: Optional[str] = None,
184
+ raylet_socket_name: Optional[str] = None,
185
+ temp_dir: Optional[str] = None,
186
+ storage: Optional[str] = None,
187
+ runtime_env_dir_name: Optional[str] = None,
188
+ include_log_monitor: Optional[str] = None,
189
+ autoscaling_config: Optional[str] = None,
190
+ ray_debugger_external: bool = False,
191
+ _system_config: Optional[Dict[str, str]] = None,
192
+ enable_object_reconstruction: Optional[bool] = False,
193
+ metrics_agent_port: Optional[int] = None,
194
+ metrics_export_port: Optional[int] = None,
195
+ tracing_startup_hook=None,
196
+ no_monitor: Optional[bool] = False,
197
+ env_vars: Optional[Dict[str, str]] = None,
198
+ session_name: Optional[str] = None,
199
+ webui: Optional[str] = None,
200
+ cluster_id: Optional[str] = None,
201
+ node_id: Optional[str] = None,
202
+ enable_physical_mode: bool = False,
203
+ ):
204
+ self.redis_address = redis_address
205
+ self.gcs_address = gcs_address
206
+ self.num_cpus = num_cpus
207
+ self.num_gpus = num_gpus
208
+ self.memory = memory
209
+ self.object_store_memory = object_store_memory
210
+ self.resources = resources
211
+ self.redis_max_memory = redis_max_memory
212
+ self.redis_port = redis_port
213
+ self.redis_shard_ports = redis_shard_ports
214
+ self.object_manager_port = object_manager_port
215
+ self.node_manager_port = node_manager_port
216
+ self.gcs_server_port = gcs_server_port
217
+ self.node_ip_address = node_ip_address
218
+ self.node_name = node_name
219
+ self.raylet_ip_address = raylet_ip_address
220
+ self.min_worker_port = min_worker_port
221
+ self.max_worker_port = max_worker_port
222
+ self.worker_port_list = worker_port_list
223
+ self.ray_client_server_port = ray_client_server_port
224
+ self.driver_mode = driver_mode
225
+ self.redirect_output = redirect_output
226
+ self.external_addresses = external_addresses
227
+ self.num_redis_shards = num_redis_shards
228
+ self.redis_max_clients = redis_max_clients
229
+ self.redis_username = redis_username
230
+ self.redis_password = redis_password
231
+ self.plasma_directory = plasma_directory
232
+ self.worker_path = worker_path
233
+ self.setup_worker_path = setup_worker_path
234
+ self.huge_pages = huge_pages
235
+ self.include_dashboard = include_dashboard
236
+ self.dashboard_host = dashboard_host
237
+ self.dashboard_port = dashboard_port
238
+ self.dashboard_agent_listen_port = dashboard_agent_listen_port
239
+ self.dashboard_grpc_port = dashboard_grpc_port
240
+ self.runtime_env_agent_port = runtime_env_agent_port
241
+ self.plasma_store_socket_name = plasma_store_socket_name
242
+ self.raylet_socket_name = raylet_socket_name
243
+ self.temp_dir = temp_dir
244
+ self.storage = storage or os.environ.get(
245
+ ray_constants.RAY_STORAGE_ENVIRONMENT_VARIABLE
246
+ )
247
+ self.runtime_env_dir_name = (
248
+ runtime_env_dir_name or ray_constants.DEFAULT_RUNTIME_ENV_DIR_NAME
249
+ )
250
+ self.include_log_monitor = include_log_monitor
251
+ self.autoscaling_config = autoscaling_config
252
+ self.metrics_agent_port = metrics_agent_port
253
+ self.metrics_export_port = metrics_export_port
254
+ self.tracing_startup_hook = tracing_startup_hook
255
+ self.no_monitor = no_monitor
256
+ self.object_ref_seed = object_ref_seed
257
+ self.ray_debugger_external = ray_debugger_external
258
+ self.env_vars = env_vars
259
+ self.session_name = session_name
260
+ self.webui = webui
261
+ self._system_config = _system_config or {}
262
+ self._enable_object_reconstruction = enable_object_reconstruction
263
+ self.labels = labels
264
+ self._check_usage()
265
+ self.cluster_id = cluster_id
266
+ self.node_id = node_id
267
+ self.enable_physical_mode = enable_physical_mode
268
+
269
+ # Set the internal config options for object reconstruction.
270
+ if enable_object_reconstruction:
271
+ # Turn off object pinning.
272
+ if self._system_config is None:
273
+ self._system_config = dict()
274
+ print(self._system_config)
275
+ self._system_config["lineage_pinning_enabled"] = True
276
+
277
+ def update(self, **kwargs):
278
+ """Update the settings according to the keyword arguments.
279
+
280
+ Args:
281
+ kwargs: The keyword arguments to set corresponding fields.
282
+ """
283
+ for arg in kwargs:
284
+ if hasattr(self, arg):
285
+ setattr(self, arg, kwargs[arg])
286
+ else:
287
+ raise ValueError(f"Invalid RayParams parameter in update: {arg}")
288
+
289
+ self._check_usage()
290
+
291
+ def update_if_absent(self, **kwargs):
292
+ """Update the settings when the target fields are None.
293
+
294
+ Args:
295
+ kwargs: The keyword arguments to set corresponding fields.
296
+ """
297
+ for arg in kwargs:
298
+ if hasattr(self, arg):
299
+ if getattr(self, arg) is None:
300
+ setattr(self, arg, kwargs[arg])
301
+ else:
302
+ raise ValueError(
303
+ f"Invalid RayParams parameter in update_if_absent: {arg}"
304
+ )
305
+
306
+ self._check_usage()
307
+
308
+ def update_pre_selected_port(self):
309
+ """Update the pre-selected port information
310
+
311
+ Returns:
312
+ The dictionary mapping of component -> ports.
313
+ """
314
+
315
+ def wrap_port(port):
316
+ # 0 port means select a random port for the grpc server.
317
+ if port is None or port == 0:
318
+ return []
319
+ else:
320
+ return [port]
321
+
322
+ # Create a dictionary of the component -> port mapping.
323
+ pre_selected_ports = {
324
+ "gcs": wrap_port(self.redis_port),
325
+ "object_manager": wrap_port(self.object_manager_port),
326
+ "node_manager": wrap_port(self.node_manager_port),
327
+ "gcs_server": wrap_port(self.gcs_server_port),
328
+ "client_server": wrap_port(self.ray_client_server_port),
329
+ "dashboard": wrap_port(self.dashboard_port),
330
+ "dashboard_agent_grpc": wrap_port(self.metrics_agent_port),
331
+ "dashboard_agent_http": wrap_port(self.dashboard_agent_listen_port),
332
+ "dashboard_grpc": wrap_port(self.dashboard_grpc_port),
333
+ "runtime_env_agent": wrap_port(self.runtime_env_agent_port),
334
+ "metrics_export": wrap_port(self.metrics_export_port),
335
+ }
336
+ redis_shard_ports = self.redis_shard_ports
337
+ if redis_shard_ports is None:
338
+ redis_shard_ports = []
339
+ pre_selected_ports["redis_shards"] = redis_shard_ports
340
+ if self.worker_port_list is None:
341
+ if self.min_worker_port is not None and self.max_worker_port is not None:
342
+ pre_selected_ports["worker_ports"] = list(
343
+ range(self.min_worker_port, self.max_worker_port + 1)
344
+ )
345
+ else:
346
+ # The dict is not updated when it requires random ports.
347
+ pre_selected_ports["worker_ports"] = []
348
+ else:
349
+ pre_selected_ports["worker_ports"] = [
350
+ int(port) for port in self.worker_port_list.split(",")
351
+ ]
352
+
353
+ # Update the pre selected port set.
354
+ self.reserved_ports = set()
355
+ for comp, port_list in pre_selected_ports.items():
356
+ for port in port_list:
357
+ if port in self.reserved_ports:
358
+ raise ValueError(
359
+ f"Ray component {comp} is trying to use "
360
+ f"a port number {port} that is used by other components.\n"
361
+ f"Port information: {self._format_ports(pre_selected_ports)}\n"
362
+ "If you allocate ports, please make sure the same port "
363
+ "is not used by multiple components."
364
+ )
365
+ self.reserved_ports.add(port)
366
+
367
+ def _check_usage(self):
368
+ if self.worker_port_list is not None:
369
+ for port_str in self.worker_port_list.split(","):
370
+ try:
371
+ port = int(port_str)
372
+ except ValueError as e:
373
+ raise ValueError(
374
+ "worker_port_list must be a comma-separated "
375
+ f"list of integers: {e}"
376
+ ) from None
377
+
378
+ if port < 1024 or port > 65535:
379
+ raise ValueError(
380
+ "Ports in worker_port_list must be "
381
+ f"between 1024 and 65535. Got: {port}"
382
+ )
383
+
384
+ # Used primarily for testing.
385
+ if os.environ.get("RAY_USE_RANDOM_PORTS", False):
386
+ if self.min_worker_port is None and self.max_worker_port is None:
387
+ self.min_worker_port = 0
388
+ self.max_worker_port = 0
389
+
390
+ if self.min_worker_port is not None:
391
+ if self.min_worker_port != 0 and (
392
+ self.min_worker_port < 1024 or self.min_worker_port > 65535
393
+ ):
394
+ raise ValueError(
395
+ "min_worker_port must be 0 or an integer between 1024 and 65535."
396
+ )
397
+
398
+ if self.max_worker_port is not None:
399
+ if self.min_worker_port is None:
400
+ raise ValueError(
401
+ "If max_worker_port is set, min_worker_port must also be set."
402
+ )
403
+ elif self.max_worker_port != 0:
404
+ if self.max_worker_port < 1024 or self.max_worker_port > 65535:
405
+ raise ValueError(
406
+ "max_worker_port must be 0 or an integer between "
407
+ "1024 and 65535."
408
+ )
409
+ elif self.max_worker_port <= self.min_worker_port:
410
+ raise ValueError(
411
+ "max_worker_port must be higher than min_worker_port."
412
+ )
413
+ if self.ray_client_server_port is not None:
414
+ if not check_ray_client_dependencies_installed():
415
+ raise ValueError(
416
+ "Ray Client requires pip package `ray[client]`. "
417
+ "If you installed the minimal Ray (e.g. `pip install ray`), "
418
+ "please reinstall by executing `pip install ray[client]`."
419
+ )
420
+ if (
421
+ self.ray_client_server_port < 1024
422
+ or self.ray_client_server_port > 65535
423
+ ):
424
+ raise ValueError(
425
+ "ray_client_server_port must be an integer "
426
+ "between 1024 and 65535."
427
+ )
428
+ if self.runtime_env_agent_port is not None:
429
+ if (
430
+ self.runtime_env_agent_port < 1024
431
+ or self.runtime_env_agent_port > 65535
432
+ ):
433
+ raise ValueError(
434
+ "runtime_env_agent_port must be an integer "
435
+ "between 1024 and 65535."
436
+ )
437
+
438
+ if self.resources is not None:
439
+
440
+ def build_error(resource, alternative):
441
+ return (
442
+ f"{self.resources} -> `{resource}` cannot be a "
443
+ "custom resource because it is one of the default resources "
444
+ f"({ray_constants.DEFAULT_RESOURCES}). "
445
+ f"Use `{alternative}` instead. For example, use `ray start "
446
+ f"--{alternative.replace('_', '-')}=1` instead of "
447
+ f"`ray start --resources={{'{resource}': 1}}`"
448
+ )
449
+
450
+ assert "CPU" not in self.resources, build_error("CPU", "num_cpus")
451
+ assert "GPU" not in self.resources, build_error("GPU", "num_gpus")
452
+ assert "memory" not in self.resources, build_error("memory", "memory")
453
+ assert "object_store_memory" not in self.resources, build_error(
454
+ "object_store_memory", "object_store_memory"
455
+ )
456
+
457
+ if self.redirect_output is not None:
458
+ raise DeprecationWarning("The redirect_output argument is deprecated.")
459
+
460
+ if self.temp_dir is not None and not os.path.isabs(self.temp_dir):
461
+ raise ValueError("temp_dir must be absolute path or None.")
462
+
463
+ validate_node_labels(self.labels)
464
+
465
+ def _format_ports(self, pre_selected_ports):
466
+ """Format the pre-selected ports information to be more human-readable."""
467
+ ports = pre_selected_ports.copy()
468
+
469
+ for comp, port_list in ports.items():
470
+ if len(port_list) == 1:
471
+ ports[comp] = port_list[0]
472
+ elif len(port_list) == 0:
473
+ # Nothing is selected, meaning it will be randomly selected.
474
+ ports[comp] = "random"
475
+ elif comp == "worker_ports":
476
+ min_port = port_list[0]
477
+ max_port = port_list[len(port_list) - 1]
478
+ if len(port_list) < 50:
479
+ port_range_str = str(port_list)
480
+ else:
481
+ port_range_str = f"from {min_port} to {max_port}"
482
+ ports[comp] = f"{len(port_list)} ports {port_range_str}"
483
+ return ports
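
For readers skimming this diff, the snippet below is a minimal, hypothetical sketch of how the `RayParams` fields above behave; it assumes the ray wheel added in this commit is importable and that the example port numbers are free on the machine.

```python
# Hedged sketch: exercises update_if_absent() and update_pre_selected_port()
# exactly as defined above; the concrete values are illustrative only.
from ray._private.parameter import RayParams

params = RayParams(
    num_cpus=4,
    dashboard_port=8265,
    min_worker_port=10002,
    max_worker_port=10005,
)

# update_if_absent only fills fields that are still None, so num_cpus=4 wins
# while object_store_memory picks up the new value.
params.update_if_absent(num_cpus=8, object_store_memory=10**9)
assert params.num_cpus == 4

# Collect every explicitly chosen port; a duplicate across components would
# raise ValueError. The result lands in params.reserved_ports.
params.update_pre_selected_port()
print(sorted(params.reserved_ports))
```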
.venv/lib/python3.11/site-packages/ray/_private/process_watcher.py ADDED
@@ -0,0 +1,198 @@
1
+ import asyncio
2
+ import io
3
+ import logging
4
+ import sys
5
+ import os
6
+
7
+ from concurrent.futures import ThreadPoolExecutor
8
+
9
+ import ray
10
+ from ray.dashboard.consts import _PARENT_DEATH_THREASHOLD
11
+ import ray.dashboard.consts as dashboard_consts
12
+ import ray._private.ray_constants as ray_constants
13
+ from ray._private.utils import run_background_task
14
+
15
+ # Import psutil after ray so the packaged version is used.
16
+ import psutil
17
+
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+ # TODO: move all consts from dashboard_consts to ray_constants and rename to remove
22
+ # DASHBOARD_ prefixes.
23
+
24
+ # Publishes at most this number of lines of Raylet logs, when the Raylet dies
25
+ # unexpectedly.
26
+ _RAYLET_LOG_MAX_PUBLISH_LINES = 20
27
+
28
+ # Reads at most this amount of Raylet logs from the tail, for publishing and
29
+ # checking if the Raylet was terminated gracefully.
30
+ _RAYLET_LOG_MAX_TAIL_SIZE = 1 * 1024**2
31
+
32
+ try:
33
+ create_task = asyncio.create_task
34
+ except AttributeError:
35
+ create_task = asyncio.ensure_future
36
+
37
+
38
+ def get_raylet_pid():
39
+ # TODO(edoakes): RAY_RAYLET_PID isn't properly set on Windows. This is
40
+ # only used for fate-sharing with the raylet and we need a different
41
+ # fate-sharing mechanism for Windows anyways.
42
+ if sys.platform in ["win32", "cygwin"]:
43
+ return None
44
+ raylet_pid = int(os.environ["RAY_RAYLET_PID"])
45
+ assert raylet_pid > 0
46
+ logger.info("raylet pid is %s", raylet_pid)
47
+ return raylet_pid
48
+
49
+
50
+ def create_check_raylet_task(log_dir, gcs_address, parent_dead_callback, loop):
51
+ """
52
+ Creates an asyncio task to periodically check if the raylet process is still
53
+ running. If raylet is dead for _PARENT_DEATH_THREASHOLD (5) times, prepare to exit
54
+ as follows:
55
+
56
+ - Write logs about whether the raylet exit is graceful, by looking into the raylet
57
+ log and search for term "SIGTERM",
58
+ - Flush the logs via GcsPublisher,
59
+ - Exit.
60
+ """
61
+ if sys.platform in ["win32", "cygwin"]:
62
+ raise RuntimeError("can't check raylet process in Windows.")
63
+ raylet_pid = get_raylet_pid()
64
+
65
+ if dashboard_consts.PARENT_HEALTH_CHECK_BY_PIPE:
66
+ logger.info("check_parent_via_pipe")
67
+ check_parent_task = _check_parent_via_pipe(
68
+ log_dir, gcs_address, loop, parent_dead_callback
69
+ )
70
+ else:
71
+ logger.info("_check_parent")
72
+ check_parent_task = _check_parent(
73
+ raylet_pid, log_dir, gcs_address, parent_dead_callback
74
+ )
75
+
76
+ return run_background_task(check_parent_task)
77
+
78
+
79
+ def report_raylet_error_logs(log_dir: str, gcs_address: str):
80
+ log_path = os.path.join(log_dir, "raylet.out")
81
+ error = False
82
+ msg = "Raylet is terminated. "
83
+ try:
84
+ with open(log_path, "r", encoding="utf-8") as f:
85
+ # Seek to _RAYLET_LOG_MAX_TAIL_SIZE from the end if the
86
+ # file is larger than that.
87
+ f.seek(0, io.SEEK_END)
88
+ pos = max(0, f.tell() - _RAYLET_LOG_MAX_TAIL_SIZE)
89
+ f.seek(pos, io.SEEK_SET)
90
+ # Read remaining logs by lines.
91
+ raylet_logs = f.readlines()
92
+ # Assume the SIGTERM message must exist within the last
93
+ # _RAYLET_LOG_MAX_TAIL_SIZE of the log file.
94
+ if any("Raylet received SIGTERM" in line for line in raylet_logs):
95
+ msg += "Termination is graceful."
96
+ logger.info(msg)
97
+ else:
98
+ msg += (
99
+ "Termination is unexpected. Possible reasons "
100
+ "include: (1) SIGKILL by the user or system "
101
+ "OOM killer, (2) Invalid memory access from "
102
+ "Raylet causing SIGSEGV or SIGBUS, "
103
+ "(3) Other termination signals. "
104
+ f"Last {_RAYLET_LOG_MAX_PUBLISH_LINES} lines "
105
+ "of the Raylet logs:\n"
106
+ )
107
+ msg += " " + " ".join(
108
+ raylet_logs[-_RAYLET_LOG_MAX_PUBLISH_LINES:]
109
+ )
110
+ error = True
111
+ except Exception as e:
112
+ msg += f"Failed to read Raylet logs at {log_path}: {e}!"
113
+ logger.exception(msg)
114
+ error = True
115
+ if error:
116
+ logger.error(msg)
117
+ # TODO: switch to async if necessary.
118
+ ray._private.utils.publish_error_to_driver(
119
+ ray_constants.RAYLET_DIED_ERROR,
120
+ msg,
121
+ gcs_publisher=ray._raylet.GcsPublisher(address=gcs_address),
122
+ )
123
+ else:
124
+ logger.info(msg)
125
+
126
+
127
+ async def _check_parent_via_pipe(
128
+ log_dir: str, gcs_address: str, loop, parent_dead_callback
129
+ ):
130
+ while True:
131
+ try:
132
+ # Read input asynchronously.
133
+ # The parent (raylet) should have redirected its pipe
134
+ # to stdin. If we read 0 bytes from stdin, it means
135
+ # the process is dead.
136
+ with ThreadPoolExecutor(max_workers=1) as executor:
137
+ input_data = await loop.run_in_executor(
138
+ executor, lambda: sys.stdin.readline()
139
+ )
140
+ if len(input_data) == 0:
141
+ # cannot read bytes from parent == parent is dead.
142
+ parent_dead_callback("_check_parent_via_pipe: The parent is dead.")
143
+ report_raylet_error_logs(log_dir, gcs_address)
144
+ sys.exit(0)
145
+ except Exception as e:
146
+ logger.exception(
147
+ "raylet health checking is failed. "
148
+ f"The agent process may leak. Exception: {e}"
149
+ )
150
+
151
+
152
+ async def _check_parent(raylet_pid, log_dir, gcs_address, parent_dead_callback):
153
+ """Check if raylet is dead and fate-share if it is."""
154
+ try:
155
+ curr_proc = psutil.Process()
156
+ parent_death_cnt = 0
157
+ while True:
158
+ parent = curr_proc.parent()
159
+ # If the parent is dead, it is None.
160
+ parent_gone = parent is None
161
+ init_assigned_for_parent = False
162
+ parent_changed = False
163
+
164
+ if parent:
165
+ # Sometimes, the parent is changed to the `init` process.
166
+ # In this case, the parent.pid is 1.
167
+ init_assigned_for_parent = parent.pid == 1
168
+ # Sometimes, the parent is dead, and the pid is reused
169
+ # by other processes. In this case, this condition is triggered.
170
+ parent_changed = raylet_pid != parent.pid
171
+
172
+ if parent_gone or init_assigned_for_parent or parent_changed:
173
+ parent_death_cnt += 1
174
+ logger.warning(
175
+ f"Raylet is considered dead {parent_death_cnt} X. "
176
+ f"If it reaches to {_PARENT_DEATH_THREASHOLD}, the agent "
177
+ f"will kill itself. Parent: {parent}, "
178
+ f"parent_gone: {parent_gone}, "
179
+ f"init_assigned_for_parent: {init_assigned_for_parent}, "
180
+ f"parent_changed: {parent_changed}."
181
+ )
182
+ if parent_death_cnt < _PARENT_DEATH_THREASHOLD:
183
+ await asyncio.sleep(
184
+ dashboard_consts.DASHBOARD_AGENT_CHECK_PARENT_INTERVAL_S
185
+ )
186
+ continue
187
+
188
+ parent_dead_callback("_check_parent: The parent is dead.")
189
+ report_raylet_error_logs(log_dir, gcs_address)
190
+ sys.exit(0)
191
+ else:
192
+ parent_death_cnt = 0
193
+ await asyncio.sleep(
194
+ dashboard_consts.DASHBOARD_AGENT_CHECK_PARENT_INTERVAL_S
195
+ )
196
+ except Exception:
197
+ logger.exception("Failed to check parent PID, exiting.")
198
+ sys.exit(1)
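
As a side note for reviewers, the parent-liveness logic in `_check_parent` above boils down to the psutil checks sketched below; this standalone version uses illustrative threshold and interval values rather than Ray's dashboard constants.

```python
# Hedged sketch of the fate-sharing check above, using only psutil.
import time

import psutil

EXPECTED_PARENT_PID = psutil.Process().ppid()  # captured once at startup


def parent_is_gone(expected_pid: int) -> bool:
    parent = psutil.Process().parent()
    if parent is None:  # parent exited and was reaped
        return True
    if parent.pid == 1:  # re-parented to init after the parent died
        return True
    return parent.pid != expected_pid  # PID was reused by another process


def watch(threshold: int = 5, interval_s: float = 0.4) -> None:
    misses = 0
    while True:
        misses = misses + 1 if parent_is_gone(EXPECTED_PARENT_PID) else 0
        if misses >= threshold:
            raise SystemExit("parent process is dead; exiting to fate-share")
        time.sleep(interval_s)
```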
.venv/lib/python3.11/site-packages/ray/_private/profiling.py ADDED
@@ -0,0 +1,240 @@
1
+ import os
2
+ import json
3
+ from collections import defaultdict
4
+ from dataclasses import dataclass, asdict
5
+ from typing import List, Dict, Union
6
+
7
+ import ray
8
+
9
+
10
+ class _NullLogSpan:
11
+ """A log span context manager that does nothing"""
12
+
13
+ def __enter__(self):
14
+ pass
15
+
16
+ def __exit__(self, type, value, tb):
17
+ pass
18
+
19
+
20
+ PROFILING_ENABLED = "RAY_PROFILING" in os.environ
21
+ NULL_LOG_SPAN = _NullLogSpan()
22
+
23
+ # Colors are specified at
24
+ # https://github.com/catapult-project/catapult/blob/master/tracing/tracing/base/color_scheme.html. # noqa: E501
25
+ _default_color_mapping = defaultdict(
26
+ lambda: "generic_work",
27
+ {
28
+ "worker_idle": "cq_build_abandoned",
29
+ "task": "rail_response",
30
+ "task:deserialize_arguments": "rail_load",
31
+ "task:execute": "rail_animation",
32
+ "task:store_outputs": "rail_idle",
33
+ "wait_for_function": "detailed_memory_dump",
34
+ "ray.get": "good",
35
+ "ray.put": "terrible",
36
+ "ray.wait": "vsync_highlight_color",
37
+ "submit_task": "background_memory_dump",
38
+ "fetch_and_run_function": "detailed_memory_dump",
39
+ "register_remote_function": "detailed_memory_dump",
40
+ },
41
+ )
42
+
43
+
44
+ @dataclass(init=True)
45
+ class ChromeTracingCompleteEvent:
46
+ # https://docs.google.com/document/d/1CvAClvFfyA5R-PhYUmn5OOQtYMH4h6I0nSsKchNAySU/preview#heading=h.lpfof2aylapb # noqa
47
+ # The event categories. This is a comma separated list of categories
48
+ # for the event. The categories can be used to hide events in
49
+ # the Trace Viewer UI.
50
+ cat: str
51
+ # The string displayed on the event.
52
+ name: str
53
+ # The identifier for the group of rows that the event
54
+ # appears in.
55
+ pid: int
56
+ # The identifier for the row that the event appears in.
57
+ tid: int
58
+ # The start time in microseconds.
59
+ ts: int
60
+ # The duration in microseconds.
61
+ dur: int
62
+ # This is the name of the color to display the box in.
63
+ cname: str
64
+ # The extra user-defined data.
65
+ args: Dict[str, Union[str, int]]
66
+ # The event type (X means the complete event).
67
+ ph: str = "X"
68
+
69
+
70
+ @dataclass(init=True)
71
+ class ChromeTracingMetadataEvent:
72
+ # https://docs.google.com/document/d/1CvAClvFfyA5R-PhYUmn5OOQtYMH4h6I0nSsKchNAySU/preview#bookmark=id.iycbnb4z7i9g # noqa
73
+ name: str
74
+ # Metadata arguments. E.g., name: <metadata_name>
75
+ args: Dict[str, str]
76
+ # The process id of this event. In Ray, pid indicates the node.
77
+ pid: int
78
+ # The thread id of this event. In Ray, tid indicates each worker.
79
+ tid: int = None
80
+ # M means the metadata event.
81
+ ph: str = "M"
82
+
83
+
84
+ def profile(event_type, extra_data=None):
85
+ """Profile a span of time so that it appears in the timeline visualization.
86
+
87
+ Note that this only works in the raylet code path.
88
+
89
+ This function can be used as follows (both on the driver or within a task).
90
+
91
+ .. testcode::
92
+ import ray._private.profiling as profiling
93
+
94
+ with profiling.profile("custom event", extra_data={'key': 'val'}):
95
+ # Do some computation here.
96
+ x = 1 * 2
97
+
98
+ Optionally, a dictionary can be passed as the "extra_data" argument, and
99
+ it can have keys "name" and "cname" if you want to override the default
100
+ timeline display text and box color. Other values will appear at the bottom
101
+ of the chrome tracing GUI when you click on the box corresponding to this
102
+ profile span.
103
+
104
+ Args:
105
+ event_type: A string describing the type of the event.
106
+ extra_data: This must be a dictionary mapping strings to strings. This
107
+ data will be added to the json objects that are used to populate
108
+ the timeline, so if you want to set a particular color, you can
109
+ simply set the "cname" attribute to an appropriate color.
110
+ Similarly, if you set the "name" attribute, then that will set the
111
+ text displayed on the box in the timeline.
112
+
113
+ Returns:
114
+ An object that can profile a span of time via a "with" statement.
115
+ """
116
+ if not PROFILING_ENABLED:
117
+ return NULL_LOG_SPAN
118
+ worker = ray._private.worker.global_worker
119
+ if worker.mode == ray._private.worker.LOCAL_MODE:
120
+ return NULL_LOG_SPAN
121
+ return worker.core_worker.profile_event(event_type.encode("ascii"), extra_data)
122
+
123
+
124
+ def chrome_tracing_dump(
125
+ tasks: List[dict],
126
+ ) -> str:
127
+ """Generate a chrome/perfetto tracing dump using task events.
128
+
129
+ Args:
130
+ tasks: List of tasks generated by a state API list_tasks(detail=True).
131
+
132
+ Returns:
133
+ Json serialized dump to create a chrome/perfetto tracing.
134
+ """
135
+ # All events from given tasks.
136
+ all_events = []
137
+
138
+ # Chrome tracing doesn't have a concept of "node". Instead, we use
139
+ # chrome tracing's pid == ray's node.
140
+ # chrome tracing's tid == ray's process.
141
+ # Note that pid or tid is usually integer, but ray's node/process has
142
+ # ids in string.
143
+ # Unfortunately, perfetto doesn't allow to have string as a value of pid/tid.
144
+ # To workaround it, we use Metadata event from chrome tracing schema
145
+ # (https://docs.google.com/document/d/1CvAClvFfyA5R-PhYUmn5OOQtYMH4h6I0nSsKchNAySU/preview#heading=h.xqopa5m0e28f) # noqa
146
+ # which allows pid/tid -> name mapping. In order to use this schema
147
+ # we build node_ip/(node_ip, worker_id) -> arbitrary index mapping.
148
+
149
+ # node ip address -> node idx.
150
+ node_to_index = {}
151
+ # Arbitrary index mapped to the ip address.
152
+ node_idx = 0
153
+ # (node index, worker id) -> worker idx
154
+ worker_to_index = {}
155
+ # Arbitrary index mapped to the (node index, worker id).
156
+ worker_idx = 0
157
+
158
+ for task in tasks:
159
+ profiling_data = task.get("profiling_data", [])
160
+ if profiling_data:
161
+ node_ip_address = profiling_data["node_ip_address"]
162
+ component_events = profiling_data["events"]
163
+ component_type = profiling_data["component_type"]
164
+ component_id = component_type + ":" + profiling_data["component_id"]
165
+
166
+ if component_type not in ["worker", "driver"]:
167
+ continue
168
+
169
+ for event in component_events:
170
+ extra_data = event["extra_data"]
171
+ # Propagate extra data.
172
+ extra_data["task_id"] = task["task_id"]
173
+ extra_data["job_id"] = task["job_id"]
174
+ extra_data["attempt_number"] = task["attempt_number"]
175
+ extra_data["func_or_class_name"] = task["func_or_class_name"]
176
+ extra_data["actor_id"] = task["actor_id"]
177
+ event_name = event["event_name"]
178
+
179
+ # build a id -> arbitrary index mapping
180
+ if node_ip_address not in node_to_index:
181
+ node_to_index[node_ip_address] = node_idx
182
+ # Whenever new node ip is introduced, we increment the index.
183
+ node_idx += 1
184
+
185
+ if (
186
+ node_to_index[node_ip_address],
187
+ component_id,
188
+ ) not in worker_to_index: # noqa
189
+ worker_to_index[
190
+ (node_to_index[node_ip_address], component_id)
191
+ ] = worker_idx # noqa
192
+ worker_idx += 1
193
+
194
+ # Modify the name with the additional user-defined extra data.
195
+ cname = _default_color_mapping[event["event_name"]]
196
+ name = event_name
197
+
198
+ if "cname" in extra_data:
199
+ cname = _default_color_mapping[event["extra_data"]["cname"]]
200
+ if "name" in extra_data:
201
+ name = extra_data["name"]
202
+
203
+ new_event = ChromeTracingCompleteEvent(
204
+ cat=event_name,
205
+ name=name,
206
+ pid=node_to_index[node_ip_address],
207
+ tid=worker_to_index[(node_to_index[node_ip_address], component_id)],
208
+ ts=event["start_time"] * 1e3,
209
+ dur=(event["end_time"] * 1e3) - (event["start_time"] * 1e3),
210
+ cname=cname,
211
+ args=extra_data,
212
+ )
213
+ all_events.append(asdict(new_event))
214
+
215
+ for node, i in node_to_index.items():
216
+ all_events.append(
217
+ asdict(
218
+ ChromeTracingMetadataEvent(
219
+ name="process_name",
220
+ pid=i,
221
+ args={"name": f"Node {node}"},
222
+ )
223
+ )
224
+ )
225
+
226
+ for worker, i in worker_to_index.items():
227
+ all_events.append(
228
+ asdict(
229
+ ChromeTracingMetadataEvent(
230
+ name="thread_name",
231
+ ph="M",
232
+ tid=i,
233
+ pid=worker[0],
234
+ args={"name": worker[1]},
235
+ )
236
+ )
237
+ )
238
+
239
+ # Handle task event disabled.
240
+ return json.dumps(all_events)
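
To make the pid/tid bookkeeping in `chrome_tracing_dump` easier to follow, here is a small sketch that builds the same event shapes by hand; the node IP, component id, and timing values are fabricated for illustration, and it assumes the dataclasses above are importable from `ray._private.profiling`.

```python
# Hedged sketch: chrome tracing pid == node index, tid == worker index.
import json
from dataclasses import asdict

from ray._private.profiling import (
    ChromeTracingCompleteEvent,
    ChromeTracingMetadataEvent,
)

node_to_index = {"10.0.0.1": 0}            # node ip -> chrome "pid"
worker_to_index = {(0, "worker:abc"): 0}   # (node idx, component id) -> "tid"

events = [
    asdict(
        ChromeTracingCompleteEvent(
            cat="task:execute",
            name="my_task",
            pid=node_to_index["10.0.0.1"],
            tid=worker_to_index[(0, "worker:abc")],
            ts=1_000_000,  # start, microseconds
            dur=2_500,     # duration, microseconds
            cname="rail_animation",
            args={"task_id": "t1"},
        )
    ),
    # Metadata events give the integer pid/tid human-readable names.
    asdict(
        ChromeTracingMetadataEvent(
            name="process_name", pid=0, args={"name": "Node 10.0.0.1"}
        )
    ),
]
print(json.dumps(events))
```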
.venv/lib/python3.11/site-packages/ray/_private/prometheus_exporter.py ADDED
@@ -0,0 +1,365 @@
1
+ # NOTE: This file has been copied from OpenCensus Python exporter.
2
+ # It is because OpenCensus Prometheus exporter hasn't released for a while
3
+ # and the latest version has a compatibility issue with the latest OpenCensus
4
+ # library.
5
+
6
+ import re
7
+
8
+ from prometheus_client import start_http_server
9
+ from prometheus_client.core import (
10
+ REGISTRY,
11
+ CounterMetricFamily,
12
+ GaugeMetricFamily,
13
+ HistogramMetricFamily,
14
+ UnknownMetricFamily,
15
+ )
16
+
17
+ from opencensus.common.transports import sync
18
+ from opencensus.stats import aggregation_data as aggregation_data_module
19
+ from opencensus.stats import base_exporter
20
+ import logging
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
+ class Options(object):
26
+ """Options contains options for configuring the exporter.
27
+ The address can be empty as the prometheus client will
28
+ assume it's localhost
29
+ :type namespace: str
30
+ :param namespace: The prometheus namespace to be used. Defaults to ''.
31
+ :type port: int
32
+ :param port: The Prometheus port to be used. Defaults to 8000.
33
+ :type address: str
34
+ :param address: The Prometheus address to be used. Defaults to ''.
35
+ :type registry: registry
36
+ :param registry: The Prometheus address to be used. Defaults to ''.
37
+ :type registry: :class:`~prometheus_client.core.CollectorRegistry`
38
+ :param registry: A Prometheus collector registry instance.
39
+ """
40
+
41
+ def __init__(self, namespace="", port=8000, address="", registry=REGISTRY):
42
+ self._namespace = namespace
43
+ self._registry = registry
44
+ self._port = int(port)
45
+ self._address = address
46
+
47
+ @property
48
+ def registry(self):
49
+ """Prometheus Collector Registry instance"""
50
+ return self._registry
51
+
52
+ @property
53
+ def namespace(self):
54
+ """Prefix to be used with view name"""
55
+ return self._namespace
56
+
57
+ @property
58
+ def port(self):
59
+ """Port number to listen"""
60
+ return self._port
61
+
62
+ @property
63
+ def address(self):
64
+ """Endpoint address (default is localhost)"""
65
+ return self._address
66
+
67
+
68
+ class Collector(object):
69
+ """Collector represents the Prometheus Collector object"""
70
+
71
+ def __init__(self, options=Options(), view_name_to_data_map=None):
72
+ if view_name_to_data_map is None:
73
+ view_name_to_data_map = {}
74
+ self._options = options
75
+ self._registry = options.registry
76
+ self._view_name_to_data_map = view_name_to_data_map
77
+ self._registered_views = {}
78
+
79
+ @property
80
+ def options(self):
81
+ """Options to be used to configure the exporter"""
82
+ return self._options
83
+
84
+ @property
85
+ def registry(self):
86
+ """Prometheus Collector Registry instance"""
87
+ return self._registry
88
+
89
+ @property
90
+ def view_name_to_data_map(self):
91
+ """Map with all view data objects
92
+ that will be sent to Prometheus
93
+ """
94
+ return self._view_name_to_data_map
95
+
96
+ @property
97
+ def registered_views(self):
98
+ """Map with all registered views"""
99
+ return self._registered_views
100
+
101
+ def register_view(self, view):
102
+ """register_view will create the needed structure
103
+ in order to be able to sent all data to Prometheus
104
+ """
105
+ v_name = get_view_name(self.options.namespace, view)
106
+
107
+ if v_name not in self.registered_views:
108
+ desc = {
109
+ "name": v_name,
110
+ "documentation": view.description,
111
+ "labels": list(map(sanitize, view.columns)),
112
+ "units": view.measure.unit,
113
+ }
114
+ self.registered_views[v_name] = desc
115
+
116
+ def add_view_data(self, view_data):
117
+ """Add view data object to be sent to server"""
118
+ self.register_view(view_data.view)
119
+ v_name = get_view_name(self.options.namespace, view_data.view)
120
+ self.view_name_to_data_map[v_name] = view_data
121
+
122
+ # TODO: add start and end timestamp
123
+ def to_metric(self, desc, tag_values, agg_data, metrics_map):
124
+ """to_metric translate the data that OpenCensus create
125
+ to Prometheus format, using Prometheus Metric object
126
+ :type desc: dict
127
+ :param desc: The map that describes view definition
128
+ :type tag_values: tuple of :class:
129
+ `~opencensus.tags.tag_value.TagValue`
130
+ :param object of opencensus.tags.tag_value.TagValue:
131
+ TagValue object used as label values
132
+ :type agg_data: object of :class:
133
+ `~opencensus.stats.aggregation_data.AggregationData`
134
+ :param object of opencensus.stats.aggregation_data.AggregationData:
135
+ Aggregated data that needs to be converted as Prometheus samples
136
+ :rtype: :class:`~prometheus_client.core.CounterMetricFamily` or
137
+ :class:`~prometheus_client.core.HistogramMetricFamily` or
138
+ :class:`~prometheus_client.core.UnknownMetricFamily` or
139
+ :class:`~prometheus_client.core.GaugeMetricFamily`
140
+ """
141
+ metric_name = desc["name"]
142
+ metric_description = desc["documentation"]
143
+ label_keys = desc["labels"]
144
+ metric_units = desc["units"]
145
+ assert len(tag_values) == len(label_keys), (tag_values, label_keys)
146
+ # Prometheus requires that all tag values be strings hence
147
+ # the need to cast none to the empty string before exporting. See
148
+ # https://github.com/census-instrumentation/opencensus-python/issues/480
149
+ tag_values = [tv if tv else "" for tv in tag_values]
150
+
151
+ if isinstance(agg_data, aggregation_data_module.CountAggregationData):
152
+ metric = metrics_map.get(metric_name)
153
+ if not metric:
154
+ metric = CounterMetricFamily(
155
+ name=metric_name,
156
+ documentation=metric_description,
157
+ unit=metric_units,
158
+ labels=label_keys,
159
+ )
160
+ metrics_map[metric_name] = metric
161
+ metric.add_metric(labels=tag_values, value=agg_data.count_data)
162
+ return
163
+
164
+ elif isinstance(agg_data, aggregation_data_module.DistributionAggregationData):
165
+
166
+ assert agg_data.bounds == sorted(agg_data.bounds)
167
+ # buckets are a list of buckets. Each bucket is another list with
168
+ # a pair of bucket name and value, or a triple of bucket name,
169
+ # value, and exemplar. buckets need to be in order.
170
+ buckets = []
171
+ cum_count = 0 # Prometheus buckets expect cumulative count.
172
+ for ii, bound in enumerate(agg_data.bounds):
173
+ cum_count += agg_data.counts_per_bucket[ii]
174
+ bucket = [str(bound), cum_count]
175
+ buckets.append(bucket)
176
+ # Prometheus requires buckets to be sorted, and +Inf present.
177
+ # In OpenCensus we don't have +Inf in the bucket bonds so need to
178
+ # append it here.
179
+ buckets.append(["+Inf", agg_data.count_data])
180
+ metric = metrics_map.get(metric_name)
181
+ if not metric:
182
+ metric = HistogramMetricFamily(
183
+ name=metric_name,
184
+ documentation=metric_description,
185
+ labels=label_keys,
186
+ )
187
+ metrics_map[metric_name] = metric
188
+ metric.add_metric(
189
+ labels=tag_values,
190
+ buckets=buckets,
191
+ sum_value=agg_data.sum,
192
+ )
193
+ return
194
+
195
+ elif isinstance(agg_data, aggregation_data_module.SumAggregationData):
196
+ metric = metrics_map.get(metric_name)
197
+ if not metric:
198
+ metric = UnknownMetricFamily(
199
+ name=metric_name,
200
+ documentation=metric_description,
201
+ labels=label_keys,
202
+ )
203
+ metrics_map[metric_name] = metric
204
+ metric.add_metric(labels=tag_values, value=agg_data.sum_data)
205
+ return
206
+
207
+ elif isinstance(agg_data, aggregation_data_module.LastValueAggregationData):
208
+ metric = metrics_map.get(metric_name)
209
+ if not metric:
210
+ metric = GaugeMetricFamily(
211
+ name=metric_name,
212
+ documentation=metric_description,
213
+ labels=label_keys,
214
+ )
215
+ metrics_map[metric_name] = metric
216
+ metric.add_metric(labels=tag_values, value=agg_data.value)
217
+ return
218
+
219
+ else:
220
+ raise ValueError(f"unsupported aggregation type {type(agg_data)}")
221
+
222
+ def collect(self): # pragma: NO COVER
223
+ """Collect fetches the statistics from OpenCensus
224
+ and delivers them as Prometheus Metrics.
225
+ Collect is invoked every time a prometheus.Gatherer is run
226
+ for example when the HTTP endpoint is invoked by Prometheus.
227
+ """
228
+ # Make a shallow copy of self._view_name_to_data_map, to avoid seeing
229
+ # concurrent modifications when iterating through the dictionary.
230
+ metrics_map = {}
231
+ for v_name, view_data in self._view_name_to_data_map.copy().items():
232
+ if v_name not in self.registered_views:
233
+ continue
234
+ desc = self.registered_views[v_name]
235
+ for tag_values in view_data.tag_value_aggregation_data_map:
236
+ agg_data = view_data.tag_value_aggregation_data_map[tag_values]
237
+ self.to_metric(desc, tag_values, agg_data, metrics_map)
238
+
239
+ for metric in metrics_map.values():
240
+ yield metric
241
+
242
+
243
+ class PrometheusStatsExporter(base_exporter.StatsExporter):
244
+ """Exporter exports stats to Prometheus, users need
245
+ to register the exporter as an HTTP Handler to be
246
+ able to export.
247
+ :type options:
248
+ :class:`~opencensus.ext.prometheus.stats_exporter.Options`
249
+ :param options: An options object with the parameters to instantiate the
250
+ prometheus exporter.
251
+ :type gatherer: :class:`~prometheus_client.core.CollectorRegistry`
252
+ :param gatherer: A Prometheus collector registry instance.
253
+ :type transport:
254
+ :class:`opencensus.common.transports.sync.SyncTransport` or
255
+ :class:`opencensus.common.transports.async_.AsyncTransport`
256
+ :param transport: An instance of a Transpor to send data with.
257
+ :type collector:
258
+ :class:`~opencensus.ext.prometheus.stats_exporter.Collector`
259
+ :param collector: An instance of the Prometheus Collector object.
260
+ """
261
+
262
+ def __init__(
263
+ self, options, gatherer, transport=sync.SyncTransport, collector=Collector()
264
+ ):
265
+ self._options = options
266
+ self._gatherer = gatherer
267
+ self._collector = collector
268
+ self._transport = transport(self)
269
+ self.serve_http()
270
+ REGISTRY.register(self._collector)
271
+
272
+ @property
273
+ def transport(self):
274
+ """The transport way to be sent data to server
275
+ (default is sync).
276
+ """
277
+ return self._transport
278
+
279
+ @property
280
+ def collector(self):
281
+ """Collector class instance to be used
282
+ to communicate with Prometheus
283
+ """
284
+ return self._collector
285
+
286
+ @property
287
+ def gatherer(self):
288
+ """Prometheus Collector Registry instance"""
289
+ return self._gatherer
290
+
291
+ @property
292
+ def options(self):
293
+ """Options to be used to configure the exporter"""
294
+ return self._options
295
+
296
+ def export(self, view_data):
297
+ """export send the data to the transport class
298
+ in order to be sent to Prometheus in a sync or async way.
299
+ """
300
+ if view_data is not None: # pragma: NO COVER
301
+ self.transport.export(view_data)
302
+
303
+ def on_register_view(self, view):
304
+ return NotImplementedError("Not supported by Prometheus")
305
+
306
+ def emit(self, view_data): # pragma: NO COVER
307
+ """Emit exports to the Prometheus if view data has one or more rows.
308
+ Each OpenCensus AggregationData will be converted to
309
+ corresponding Prometheus Metric: SumData will be converted
310
+ to Untyped Metric, CountData will be a Counter Metric
311
+ DistributionData will be a Histogram Metric.
312
+ """
313
+
314
+ for v_data in view_data:
315
+ if v_data.tag_value_aggregation_data_map is None:
316
+ v_data.tag_value_aggregation_data_map = {}
317
+
318
+ self.collector.add_view_data(v_data)
319
+
320
+ def serve_http(self):
321
+ """serve_http serves the Prometheus endpoint."""
322
+ address = str(self.options.address)
323
+ kwargs = {"addr": address} if address else {}
324
+ start_http_server(port=self.options.port, **kwargs)
325
+
326
+
327
+ def new_stats_exporter(option):
328
+ """new_stats_exporter returns an exporter
329
+ that exports stats to Prometheus.
330
+ """
331
+ if option.namespace == "":
332
+ raise ValueError("Namespace can not be empty string.")
333
+
334
+ collector = new_collector(option)
335
+
336
+ exporter = PrometheusStatsExporter(
337
+ options=option, gatherer=option.registry, collector=collector
338
+ )
339
+ return exporter
340
+
341
+
342
+ def new_collector(options):
343
+ """new_collector should be used
344
+ to create instance of Collector class in order to
345
+ prevent the usage of constructor directly
346
+ """
347
+ return Collector(options=options)
348
+
349
+
350
+ def get_view_name(namespace, view):
351
+ """create the name for the view"""
352
+ name = ""
353
+ if namespace != "":
354
+ name = namespace + "_"
355
+ return sanitize(name + view.name)
356
+
357
+
358
+ _NON_LETTERS_NOR_DIGITS_RE = re.compile(r"[^\w]", re.UNICODE | re.IGNORECASE)
359
+
360
+
361
+ def sanitize(key):
362
+ """sanitize the given metric name or label according to Prometheus rule.
363
+ Replace all characters other than [A-Za-z0-9_] with '_'.
364
+ """
365
+ return _NON_LETTERS_NOR_DIGITS_RE.sub("_", key)
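
For context, a plausible way to stand up this exporter looks like the sketch below; the namespace, port, and address are arbitrary example values, it assumes `opencensus` and `prometheus_client` are installed, and wiring it into an OpenCensus view manager is only hinted at in the comment.

```python
# Hedged sketch: start a /metrics endpoint backed by the Collector above.
from ray._private.prometheus_exporter import Options, new_stats_exporter

exporter = new_stats_exporter(
    Options(namespace="ray_demo", port=9090, address="127.0.0.1")
)
# serve_http() has now bound 127.0.0.1:9090; feeding data in still requires
# registering the exporter with an OpenCensus view manager, e.g. something
# like view_manager.register_exporter(exporter) in the stats pipeline.
```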
.venv/lib/python3.11/site-packages/ray/_private/protobuf_compat.py ADDED
@@ -0,0 +1,46 @@
+ from google.protobuf.json_format import MessageToDict
+ import inspect
+
+ """
+ This module provides a compatibility layer for different versions of the protobuf
+ library.
+ """
+
+ _protobuf_has_old_arg_name_cached = None
+
+
+ def _protobuf_has_old_arg_name():
+     """Cache the inspect result to avoid doing it for every single message."""
+     global _protobuf_has_old_arg_name_cached
+     if _protobuf_has_old_arg_name_cached is None:
+         params = inspect.signature(MessageToDict).parameters
+         _protobuf_has_old_arg_name_cached = "including_default_value_fields" in params
+     return _protobuf_has_old_arg_name_cached
+
+
+ def rename_always_print_fields_with_no_presence(kwargs):
+     """
+     Protobuf version 5.26.0rc2 renamed argument for `MessageToDict`:
+     `including_default_value_fields` -> `always_print_fields_with_no_presence`.
+     See https://github.com/protocolbuffers/protobuf/commit/06e7caba58ede0220b110b89d08f329e5f8a7537#diff-8de817c14d6a087981503c9aea38730b1b3e98f4e306db5ff9d525c7c304f234L129 # noqa: E501
+
+     We choose to always use the new argument name. If user used the old arg, we raise an
+     error.
+
+     If protobuf does not have the new arg name but have the old arg name, we rename our
+     arg to the old one.
+     """
+     old_arg_name = "including_default_value_fields"
+     new_arg_name = "always_print_fields_with_no_presence"
+     if old_arg_name in kwargs:
+         raise ValueError(f"{old_arg_name} is deprecated, please use {new_arg_name}")
+
+     if new_arg_name in kwargs and _protobuf_has_old_arg_name():
+         kwargs[old_arg_name] = kwargs.pop(new_arg_name)
+
+     return kwargs
+
+
+ def message_to_dict(*args, **kwargs):
+     kwargs = rename_always_print_fields_with_no_presence(kwargs)
+     return MessageToDict(*args, **kwargs)
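
A short usage sketch of the shim above: the caller always passes the new keyword and lets `message_to_dict` rename it when an older protobuf is installed. The `FileDescriptorProto` message is just a convenient stand-in available in any protobuf install.

```python
# Hedged sketch: one call site that works across protobuf generations.
from google.protobuf.descriptor_pb2 import FileDescriptorProto

from ray._private.protobuf_compat import message_to_dict

msg = FileDescriptorProto(name="example.proto")
# Passing the old name would raise ValueError; the new name is rewritten to
# the old one only when the installed MessageToDict still expects it.
print(message_to_dict(msg, always_print_fields_with_no_presence=True))
```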
.venv/lib/python3.11/site-packages/ray/_private/pydantic_compat.py ADDED
@@ -0,0 +1,108 @@
+ # ruff: noqa
+ import packaging.version
+
+ # Pydantic is a dependency of `ray["default"]` but not the minimal installation,
+ # so handle the case where it isn't installed.
+ try:
+     import pydantic
+
+     PYDANTIC_INSTALLED = True
+ except ImportError:
+     pydantic = None
+     PYDANTIC_INSTALLED = False
+
+
+ if not PYDANTIC_INSTALLED:
+     IS_PYDANTIC_2 = False
+     BaseModel = None
+     Extra = None
+     Field = None
+     NonNegativeFloat = None
+     NonNegativeInt = None
+     PositiveFloat = None
+     PositiveInt = None
+     PrivateAttr = None
+     StrictInt = None
+     ValidationError = None
+     root_validator = None
+     validator = None
+     is_subclass_of_base_model = lambda obj: False
+ # In pydantic <1.9.0, __version__ attribute is missing, issue ref:
+ # https://github.com/pydantic/pydantic/issues/2572, so we need to check
+ # the existence prior to comparison.
+ elif not hasattr(pydantic, "__version__") or packaging.version.parse(
+     pydantic.__version__
+ ) < packaging.version.parse("2.0"):
+     IS_PYDANTIC_2 = False
+     from pydantic import (
+         BaseModel,
+         Extra,
+         Field,
+         NonNegativeFloat,
+         NonNegativeInt,
+         PositiveFloat,
+         PositiveInt,
+         PrivateAttr,
+         StrictInt,
+         ValidationError,
+         root_validator,
+         validator,
+     )
+
+     def is_subclass_of_base_model(obj):
+         return issubclass(obj, BaseModel)
+
+ else:
+     IS_PYDANTIC_2 = True
+     from pydantic.v1 import (
+         BaseModel,
+         Extra,
+         Field,
+         NonNegativeFloat,
+         NonNegativeInt,
+         PositiveFloat,
+         PositiveInt,
+         PrivateAttr,
+         StrictInt,
+         ValidationError,
+         root_validator,
+         validator,
+     )
+
+     def is_subclass_of_base_model(obj):
+         from pydantic import BaseModel as BaseModelV2
+         from pydantic.v1 import BaseModel as BaseModelV1
+
+         return issubclass(obj, BaseModelV1) or issubclass(obj, BaseModelV2)
+
+
+ def register_pydantic_serializers(serialization_context):
+     if not PYDANTIC_INSTALLED:
+         return
+
+     if IS_PYDANTIC_2:
+         # TODO(edoakes): compare against the version that has the fixes.
+         from pydantic.v1.fields import ModelField
+     else:
+         from pydantic.fields import ModelField
+
+     # Pydantic's Cython validators are not serializable.
+     # https://github.com/cloudpipe/cloudpickle/issues/408
+     serialization_context._register_cloudpickle_serializer(
+         ModelField,
+         custom_serializer=lambda o: {
+             "name": o.name,
+             # outer_type_ is the original type for ModelFields,
+             # while type_ can be updated later with the nested type
+             # like int for List[int].
+             "type_": o.outer_type_,
+             "class_validators": o.class_validators,
+             "model_config": o.model_config,
+             "default": o.default,
+             "default_factory": o.default_factory,
+             "required": o.required,
+             "alias": o.alias,
+             "field_info": o.field_info,
+         },
+         custom_deserializer=lambda kwargs: ModelField(**kwargs),
+     )
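
The point of this shim is that callers import v1-style symbols from one place regardless of which pydantic major version is present. The sketch below defines a throwaway model through it; `JobSpec` is a hypothetical name used only for illustration.

```python
# Hedged sketch: same model code under pydantic 1.x or 2.x (v1 compat layer).
from ray._private.pydantic_compat import (
    PYDANTIC_INSTALLED,
    BaseModel,
    PositiveInt,
    validator,
)

if PYDANTIC_INSTALLED:

    class JobSpec(BaseModel):  # hypothetical model, for illustration only
        name: str
        num_workers: PositiveInt = 1

        @validator("name")
        def name_not_blank(cls, v):
            if not v.strip():
                raise ValueError("name must not be blank")
            return v

    print(JobSpec(name="demo", num_workers=2))
```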
.venv/lib/python3.11/site-packages/ray/_private/ray_client_microbenchmark.py ADDED
@@ -0,0 +1,117 @@
1
+ import inspect
2
+ import logging
3
+ import numpy as np
4
+ import sys
5
+
6
+ from ray.util.client.ray_client_helpers import ray_start_client_server
7
+
8
+ from ray._private.ray_microbenchmark_helpers import timeit
9
+
10
+
11
+ def benchmark_get_calls(ray, results):
12
+ value = ray.put(0)
13
+
14
+ def get_small():
15
+ ray.get(value)
16
+
17
+ results += timeit("client: get calls", get_small)
18
+
19
+
20
+ def benchmark_tasks_and_get_batch(ray, results):
21
+ @ray.remote
22
+ def small_value():
23
+ return b"ok"
24
+
25
+ def small_value_batch():
26
+ submitted = [small_value.remote() for _ in range(1000)]
27
+ ray.get(submitted)
28
+ return 0
29
+
30
+ results += timeit("client: tasks and get batch", small_value_batch)
31
+
32
+
33
+ def benchmark_put_calls(ray, results):
34
+ def put_small():
35
+ ray.put(0)
36
+
37
+ results += timeit("client: put calls", put_small)
38
+
39
+
40
+ def benchmark_remote_put_calls(ray, results):
41
+ @ray.remote
42
+ def do_put_small():
43
+ for _ in range(100):
44
+ ray.put(0)
45
+
46
+ def put_multi_small():
47
+ ray.get([do_put_small.remote() for _ in range(10)])
48
+
49
+ results += timeit("client: tasks and put batch", put_multi_small, 1000)
50
+
51
+
52
+ def benchmark_put_large(ray, results):
53
+ arr = np.zeros(100 * 1024 * 1024, dtype=np.int64)
54
+
55
+ def put_large():
56
+ ray.put(arr)
57
+
58
+ results += timeit("client: put gigabytes", put_large, 8 * 0.1)
59
+
60
+
61
+ def benchmark_simple_actor(ray, results):
62
+ @ray.remote(num_cpus=0)
63
+ class Actor:
64
+ def small_value(self):
65
+ return b"ok"
66
+
67
+ def small_value_arg(self, x):
68
+ return b"ok"
69
+
70
+ def small_value_batch(self, n):
71
+ ray.get([self.small_value.remote() for _ in range(n)])
72
+
73
+ a = Actor.remote()
74
+
75
+ def actor_sync():
76
+ ray.get(a.small_value.remote())
77
+
78
+ results += timeit("client: 1:1 actor calls sync", actor_sync)
79
+
80
+ def actor_async():
81
+ ray.get([a.small_value.remote() for _ in range(1000)])
82
+
83
+ results += timeit("client: 1:1 actor calls async", actor_async, 1000)
84
+
85
+ a = Actor.options(max_concurrency=16).remote()
86
+
87
+ def actor_concurrent():
88
+ ray.get([a.small_value.remote() for _ in range(1000)])
89
+
90
+ results += timeit("client: 1:1 actor calls concurrent", actor_concurrent, 1000)
91
+
92
+
93
+ def main(results=None):
94
+ results = results or []
95
+
96
+ ray_config = {"logging_level": logging.WARNING}
97
+
98
+ def ray_connect_handler(job_config=None, **ray_init_kwargs):
99
+ from ray._private.client_mode_hook import disable_client_hook
100
+
101
+ with disable_client_hook():
102
+ import ray as real_ray
103
+
104
+ if not real_ray.is_initialized():
105
+ real_ray.init(**ray_config)
106
+
107
+ for name, obj in inspect.getmembers(sys.modules[__name__]):
108
+ if not name.startswith("benchmark_"):
109
+ continue
110
+ with ray_start_client_server(ray_connect_handler=ray_connect_handler) as ray:
111
+ obj(ray, results)
112
+
113
+ return results
114
+
115
+
116
+ if __name__ == "__main__":
117
+ main()
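
The `main()` above discovers benchmarks purely by name prefix via `inspect.getmembers`. The standalone toy below shows that discovery pattern on its own, without Ray or the client server; the benchmark functions here are fabricated stand-ins.

```python
# Hedged sketch of prefix-based benchmark discovery, no Ray required.
import inspect
import sys


def benchmark_addition(results):
    results.append(("addition", 1 + 1))


def benchmark_concat(results):
    results.append(("concat", "a" + "b"))


def run_all():
    results = []
    for name, obj in inspect.getmembers(sys.modules[__name__]):
        if name.startswith("benchmark_") and callable(obj):
            obj(results)
    return results


print(run_all())
```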
.venv/lib/python3.11/site-packages/ray/_private/ray_cluster_perf.py ADDED
@@ -0,0 +1,50 @@
+ """This is the script for `ray clusterbenchmark`."""
+
+ import time
+ import numpy as np
+ import ray
+
+ from ray.cluster_utils import Cluster
+
+
+ def main():
+     cluster = Cluster(
+         initialize_head=True,
+         connect=True,
+         head_node_args={"object_store_memory": 20 * 1024 * 1024 * 1024, "num_cpus": 16},
+     )
+     cluster.add_node(
+         object_store_memory=20 * 1024 * 1024 * 1024, num_gpus=1, num_cpus=16
+     )
+
+     object_ref_list = []
+     for i in range(0, 10):
+         object_ref = ray.put(np.random.rand(1024 * 128, 1024))
+         object_ref_list.append(object_ref)
+
+     @ray.remote(num_gpus=1)
+     def f(object_ref_list):
+         diffs = []
+         for object_ref in object_ref_list:
+             before = time.time()
+             ray.get(object_ref)
+             after = time.time()
+             diffs.append(after - before)
+             time.sleep(1)
+         return np.mean(diffs), np.std(diffs)
+
+     time_diff, time_diff_std = ray.get(f.remote(object_ref_list))
+
+     print(
+         "latency to get an 1G object over network",
+         round(time_diff, 2),
+         "+-",
+         round(time_diff_std, 2),
+     )
+
+     ray.shutdown()
+     cluster.shutdown()
+
+
+ if __name__ == "__main__":
+     main()
.venv/lib/python3.11/site-packages/ray/_private/ray_constants.py ADDED
@@ -0,0 +1,554 @@
1
+ """Ray constants used in the Python code."""
2
+
3
+ import logging
4
+ import os
5
+ import sys
6
+ import json
7
+
8
+ logger = logging.getLogger(__name__)
9
+
10
+
11
+ def env_integer(key, default):
12
+ if key in os.environ:
13
+ value = os.environ[key]
14
+ if value.isdigit():
15
+ return int(os.environ[key])
16
+
17
+ logger.debug(
18
+ f"Found {key} in environment, but value must "
19
+ f"be an integer. Got: {value}. Returning "
20
+ f"provided default {default}."
21
+ )
22
+ return default
23
+ return default
24
+
25
+
26
+ def env_float(key, default):
27
+ if key in os.environ:
28
+ value = os.environ[key]
29
+ try:
30
+ return float(value)
31
+ except ValueError:
32
+ logger.debug(
33
+ f"Found {key} in environment, but value must "
34
+ f"be a float. Got: {value}. Returning "
35
+ f"provided default {default}."
36
+ )
37
+ return default
38
+ return default
39
+
40
+
41
+ def env_bool(key, default):
42
+ if key in os.environ:
43
+ return (
44
+ True
45
+ if os.environ[key].lower() == "true" or os.environ[key] == "1"
46
+ else False
47
+ )
48
+ return default
49
+
50
+
51
+ def env_set_by_user(key):
52
+ return key in os.environ
53
+
54
+
55
+ # Whether event logging to driver is enabled. Set to 0 to disable.
56
+ AUTOSCALER_EVENTS = env_integer("RAY_SCHEDULER_EVENTS", 1)
57
+
58
+ RAY_LOG_TO_DRIVER = env_bool("RAY_LOG_TO_DRIVER", True)
59
+
60
+ # Filter level under which events will be filtered out, i.e. not printing to driver
61
+ RAY_LOG_TO_DRIVER_EVENT_LEVEL = os.environ.get("RAY_LOG_TO_DRIVER_EVENT_LEVEL", "INFO")
62
+
63
+ # Internal kv keys for storing monitor debug status.
64
+ DEBUG_AUTOSCALING_ERROR = "__autoscaling_error"
65
+ DEBUG_AUTOSCALING_STATUS = "__autoscaling_status"
66
+ DEBUG_AUTOSCALING_STATUS_LEGACY = "__autoscaling_status_legacy"
67
+
68
+ ID_SIZE = 28
69
+
70
+ # The default maximum number of bytes to allocate to the object store unless
71
+ # overridden by the user.
72
+ DEFAULT_OBJECT_STORE_MAX_MEMORY_BYTES = env_integer(
73
+ "RAY_DEFAULT_OBJECT_STORE_MAX_MEMORY_BYTES", 200 * 10**9 # 200 GB
74
+ )
75
+ # The default proportion of available memory allocated to the object store
76
+ DEFAULT_OBJECT_STORE_MEMORY_PROPORTION = env_float(
77
+ "RAY_DEFAULT_OBJECT_STORE_MEMORY_PROPORTION", 0.3
78
+ )
79
+ # The smallest cap on the memory used by the object store that we allow.
80
+ # This must be greater than MEMORY_RESOURCE_UNIT_BYTES
81
+ OBJECT_STORE_MINIMUM_MEMORY_BYTES = 75 * 1024 * 1024
82
+ # Each ObjectRef currently uses about 3KB of caller memory.
83
+ CALLER_MEMORY_USAGE_PER_OBJECT_REF = 3000
84
+ # Match max_direct_call_object_size in
85
+ # src/ray/common/ray_config_def.h.
86
+ # TODO(swang): Ideally this should be pulled directly from the
87
+ # config in case the user overrides it.
88
+ DEFAULT_MAX_DIRECT_CALL_OBJECT_SIZE = 100 * 1024
89
+ # The default maximum number of bytes that the non-primary Redis shards are
90
+ # allowed to use unless overridden by the user.
91
+ DEFAULT_REDIS_MAX_MEMORY_BYTES = 10**10
92
+ # The smallest cap on the memory used by Redis that we allow.
93
+ REDIS_MINIMUM_MEMORY_BYTES = 10**7
94
+ # Above this number of bytes, raise an error by default unless the user sets
95
+ # RAY_ALLOW_SLOW_STORAGE=1. This avoids swapping with large object stores.
96
+ REQUIRE_SHM_SIZE_THRESHOLD = 10**10
97
+ # Mac with 16GB memory has degraded performance when the object store size is
98
+ # greater than 2GB.
99
+ # (see https://github.com/ray-project/ray/issues/20388 for details)
100
+ # The workaround here is to limit capacity to 2GB for Mac by default,
101
+ # and raise an error if the capacity is overridden by the user.
102
+ MAC_DEGRADED_PERF_MMAP_SIZE_LIMIT = 2 * 2**30
103
+ # If a user does not specify a port for the primary Ray service,
104
+ # we attempt to start the service running at this port.
105
+ DEFAULT_PORT = 6379
106
+
107
+ RAY_ADDRESS_ENVIRONMENT_VARIABLE = "RAY_ADDRESS"
108
+ RAY_NAMESPACE_ENVIRONMENT_VARIABLE = "RAY_NAMESPACE"
109
+ RAY_RUNTIME_ENV_ENVIRONMENT_VARIABLE = "RAY_RUNTIME_ENV"
110
+ RAY_RUNTIME_ENV_URI_PIN_EXPIRATION_S_ENV_VAR = (
111
+ "RAY_RUNTIME_ENV_TEMPORARY_REFERENCE_EXPIRATION_S"
112
+ )
113
+ # Ray populates this env var to the working dir in the creation of a runtime env.
114
+ # For example, `pip` and `conda` users can use this environment variable to locate the
115
+ # `requirements.txt` file.
116
+ RAY_RUNTIME_ENV_CREATE_WORKING_DIR_ENV_VAR = "RAY_RUNTIME_ENV_CREATE_WORKING_DIR"
117
+ # Defaults to 10 minutes. This should be longer than the total time it takes for
118
+ # the local working_dir and py_modules to be uploaded, or these files might get
119
+ # garbage collected before the job starts.
120
+ RAY_RUNTIME_ENV_URI_PIN_EXPIRATION_S_DEFAULT = 10 * 60
121
+ # If set to 1, then `.gitignore` files will not be parsed and loaded into "excludes"
122
+ # when using a local working_dir or py_modules.
123
+ RAY_RUNTIME_ENV_IGNORE_GITIGNORE = "RAY_RUNTIME_ENV_IGNORE_GITIGNORE"
124
+ RAY_STORAGE_ENVIRONMENT_VARIABLE = "RAY_STORAGE"
125
+ # Hook for running a user-specified runtime-env hook. This hook will be called
126
+ # unconditionally given the runtime_env dict passed for ray.init. It must return
127
+ # a rewritten runtime_env dict. Example: "your.module.runtime_env_hook".
128
+ RAY_RUNTIME_ENV_HOOK = "RAY_RUNTIME_ENV_HOOK"
129
+ # Hook that is invoked on `ray start`. It will be given the cluster parameters and
130
+ # whether we are the head node as arguments. The function can modify the params class,
131
+ # but otherwise returns void. Example: "your.module.ray_start_hook".
132
+ RAY_START_HOOK = "RAY_START_HOOK"
133
+ # Hook that is invoked on `ray job submit`. It will be given all the same args as the
134
+ # job.cli.submit() function gets, passed as kwargs to this function.
135
+ RAY_JOB_SUBMIT_HOOK = "RAY_JOB_SUBMIT_HOOK"
136
+ # Headers to pass when using the Job CLI. They will be used to
137
+ # instantiate a Job SubmissionClient.
138
+ RAY_JOB_HEADERS = "RAY_JOB_HEADERS"
139
+
140
+ DEFAULT_DASHBOARD_IP = "127.0.0.1"
141
+ DEFAULT_DASHBOARD_PORT = 8265
142
+ DASHBOARD_ADDRESS = "dashboard"
143
+ PROMETHEUS_SERVICE_DISCOVERY_FILE = "prom_metrics_service_discovery.json"
144
+ DEFAULT_DASHBOARD_AGENT_LISTEN_PORT = 52365
145
+ # Default resource requirements for actors when no resource requirements are
146
+ # specified.
147
+ DEFAULT_ACTOR_METHOD_CPU_SIMPLE = 1
148
+ DEFAULT_ACTOR_CREATION_CPU_SIMPLE = 0
149
+ # Default resource requirements for actors when some resource requirements are
150
+ # specified.
151
+ DEFAULT_ACTOR_METHOD_CPU_SPECIFIED = 0
152
+ DEFAULT_ACTOR_CREATION_CPU_SPECIFIED = 1
153
+ # Default number of return values for each actor method.
154
+ DEFAULT_ACTOR_METHOD_NUM_RETURN_VALS = 1
155
+
156
+ # Wait 30 seconds for client to reconnect after unexpected disconnection
157
+ DEFAULT_CLIENT_RECONNECT_GRACE_PERIOD = 30
158
+
159
+ # If a remote function or actor (or some other export) has serialized size
160
+ # greater than this quantity, print a warning.
161
+ FUNCTION_SIZE_WARN_THRESHOLD = 10**7
162
+ FUNCTION_SIZE_ERROR_THRESHOLD = env_integer("FUNCTION_SIZE_ERROR_THRESHOLD", (10**8))
163
+
164
+ # If remote functions with the same source are imported this many times, then
165
+ # print a warning.
166
+ DUPLICATE_REMOTE_FUNCTION_THRESHOLD = 100
167
+
168
+ # The maximum resource quantity that is allowed. TODO(rkn): This could be
169
+ # relaxed, but the current implementation of the node manager will be slower
170
+ # for large resource quantities due to bookkeeping of specific resource IDs.
171
+ MAX_RESOURCE_QUANTITY = 100e12
172
+
173
+ # Number of units 1 resource can be subdivided into.
174
+ MIN_RESOURCE_GRANULARITY = 0.0001
175
+
176
+ # Set this environment variable to populate the dashboard URL with
177
+ # an externally hosted Ray dashboard URL (e.g. because the
178
+ # dashboard is behind a proxy or load balancer). This only overrides
179
+ # the dashboard URL when returning or printing to a user through a public
180
+ # API, but not in the internal KV store.
181
+ RAY_OVERRIDE_DASHBOARD_URL = "RAY_OVERRIDE_DASHBOARD_URL"
182
+
183
+
184
+ # Different types of Ray errors that can be pushed to the driver.
185
+ # TODO(rkn): These should be defined in flatbuffers and must be synced with
186
+ # the existing C++ definitions.
187
+ PICKLING_LARGE_OBJECT_PUSH_ERROR = "pickling_large_object"
188
+ WAIT_FOR_FUNCTION_PUSH_ERROR = "wait_for_function"
189
+ VERSION_MISMATCH_PUSH_ERROR = "version_mismatch"
190
+ WORKER_CRASH_PUSH_ERROR = "worker_crash"
191
+ WORKER_DIED_PUSH_ERROR = "worker_died"
192
+ WORKER_POOL_LARGE_ERROR = "worker_pool_large"
193
+ PUT_RECONSTRUCTION_PUSH_ERROR = "put_reconstruction"
194
+ RESOURCE_DEADLOCK_ERROR = "resource_deadlock"
195
+ REMOVED_NODE_ERROR = "node_removed"
196
+ MONITOR_DIED_ERROR = "monitor_died"
197
+ LOG_MONITOR_DIED_ERROR = "log_monitor_died"
198
+ DASHBOARD_AGENT_DIED_ERROR = "dashboard_agent_died"
199
+ DASHBOARD_DIED_ERROR = "dashboard_died"
200
+ RAYLET_DIED_ERROR = "raylet_died"
201
+ DETACHED_ACTOR_ANONYMOUS_NAMESPACE_ERROR = "detached_actor_anonymous_namespace"
202
+ EXCESS_QUEUEING_WARNING = "excess_queueing_warning"
203
+
204
+ # Used in gpu detection
205
+ RESOURCE_CONSTRAINT_PREFIX = "accelerator_type:"
206
+
207
+ # Used by autoscaler to set the node custom resources and labels
208
+ # from cluster.yaml.
209
+ RESOURCES_ENVIRONMENT_VARIABLE = "RAY_OVERRIDE_RESOURCES"
210
+ LABELS_ENVIRONMENT_VARIABLE = "RAY_OVERRIDE_LABELS"
211
+
212
+ # Temporary flag to disable log processing in the dashboard. This is useful
213
+ # if the dashboard is overloaded by logs and failing to process other
214
+ # dashboard API requests (e.g. Job Submission).
215
+ DISABLE_DASHBOARD_LOG_INFO = env_integer("RAY_DISABLE_DASHBOARD_LOG_INFO", 0)
216
+
217
+ LOGGER_FORMAT = "%(asctime)s\t%(levelname)s %(filename)s:%(lineno)s -- %(message)s"
218
+ LOGGER_FORMAT_ESCAPE = json.dumps(LOGGER_FORMAT.replace("%", "%%"))
219
+ LOGGER_FORMAT_HELP = f"The logging format. default={LOGGER_FORMAT_ESCAPE}"
220
+ # Configure the default logging levels for various Ray components.
221
+ # TODO (kevin85421): Currently, I don't encourage Ray users to configure
222
+ # `RAY_LOGGER_LEVEL` until its scope and expected behavior are clear and
223
+ # easy to understand. Now, only Ray developers should use it.
224
+ LOGGER_LEVEL = os.environ.get("RAY_LOGGER_LEVEL", "info")
225
+ LOGGER_LEVEL_CHOICES = ["debug", "info", "warning", "error", "critical"]
226
+ LOGGER_LEVEL_HELP = (
227
+ "The logging level threshold, choices=['debug', 'info',"
228
+ " 'warning', 'error', 'critical'], default='info'"
229
+ )
230
+
231
+ LOGGING_ROTATE_BYTES = 512 * 1024 * 1024 # 512MB.
232
+ LOGGING_ROTATE_BACKUP_COUNT = 5 # 5 Backup files at max.
233
+
234
+ LOGGING_REDIRECT_STDERR_ENVIRONMENT_VARIABLE = "RAY_LOG_TO_STDERR"
235
+ # Logging format when logging stderr. This should be formatted with the
236
+ # component before setting the formatter, e.g. via
237
+ # format = LOGGER_FORMAT_STDERR.format(component="dashboard")
238
+ # handler.setFormatter(logging.Formatter(format))
239
+ LOGGER_FORMAT_STDERR = (
240
+ "%(asctime)s\t%(levelname)s ({component}) %(filename)s:%(lineno)s -- %(message)s"
241
+ )
242
+
243
+ # Constants used to define the different process types.
244
+ PROCESS_TYPE_REAPER = "reaper"
245
+ PROCESS_TYPE_MONITOR = "monitor"
246
+ PROCESS_TYPE_RAY_CLIENT_SERVER = "ray_client_server"
247
+ PROCESS_TYPE_LOG_MONITOR = "log_monitor"
248
+ # TODO(sang): Delete it.
249
+ PROCESS_TYPE_REPORTER = "reporter"
250
+ PROCESS_TYPE_DASHBOARD = "dashboard"
251
+ PROCESS_TYPE_DASHBOARD_AGENT = "dashboard_agent"
252
+ PROCESS_TYPE_RUNTIME_ENV_AGENT = "runtime_env_agent"
253
+ PROCESS_TYPE_WORKER = "worker"
254
+ PROCESS_TYPE_RAYLET = "raylet"
255
+ PROCESS_TYPE_REDIS_SERVER = "redis_server"
256
+ PROCESS_TYPE_WEB_UI = "web_ui"
257
+ PROCESS_TYPE_GCS_SERVER = "gcs_server"
258
+ PROCESS_TYPE_PYTHON_CORE_WORKER_DRIVER = "python-core-driver"
259
+ PROCESS_TYPE_PYTHON_CORE_WORKER = "python-core-worker"
260
+
261
+ # Log file names
262
+ MONITOR_LOG_FILE_NAME = f"{PROCESS_TYPE_MONITOR}.log"
263
+ LOG_MONITOR_LOG_FILE_NAME = f"{PROCESS_TYPE_LOG_MONITOR}.log"
264
+
265
+ # Enable log deduplication.
266
+ RAY_DEDUP_LOGS = env_bool("RAY_DEDUP_LOGS", True)
267
+
268
+ # How many seconds of messages to buffer for log deduplication.
269
+ RAY_DEDUP_LOGS_AGG_WINDOW_S = env_integer("RAY_DEDUP_LOGS_AGG_WINDOW_S", 5)
270
+
271
+ # Regex for log messages to never deduplicate, or None. This takes precedence over
272
+ # the skip regex below. A default pattern is set for testing.
273
+ TESTING_NEVER_DEDUP_TOKEN = "__ray_testing_never_deduplicate__"
274
+ RAY_DEDUP_LOGS_ALLOW_REGEX = os.environ.get(
275
+ "RAY_DEDUP_LOGS_ALLOW_REGEX", TESTING_NEVER_DEDUP_TOKEN
276
+ )
277
+
278
+ # Regex for log messages to always skip / suppress, or None.
279
+ RAY_DEDUP_LOGS_SKIP_REGEX = os.environ.get("RAY_DEDUP_LOGS_SKIP_REGEX")
280
+
281
+ WORKER_PROCESS_TYPE_IDLE_WORKER = "ray::IDLE"
282
+ WORKER_PROCESS_TYPE_SPILL_WORKER_NAME = "SpillWorker"
283
+ WORKER_PROCESS_TYPE_RESTORE_WORKER_NAME = "RestoreWorker"
284
+ WORKER_PROCESS_TYPE_SPILL_WORKER_IDLE = (
285
+ f"ray::IDLE_{WORKER_PROCESS_TYPE_SPILL_WORKER_NAME}"
286
+ )
287
+ WORKER_PROCESS_TYPE_RESTORE_WORKER_IDLE = (
288
+ f"ray::IDLE_{WORKER_PROCESS_TYPE_RESTORE_WORKER_NAME}"
289
+ )
290
+ WORKER_PROCESS_TYPE_SPILL_WORKER = f"ray::SPILL_{WORKER_PROCESS_TYPE_SPILL_WORKER_NAME}"
291
+ WORKER_PROCESS_TYPE_RESTORE_WORKER = (
292
+ f"ray::RESTORE_{WORKER_PROCESS_TYPE_RESTORE_WORKER_NAME}"
293
+ )
294
+ WORKER_PROCESS_TYPE_SPILL_WORKER_DELETE = (
295
+ f"ray::DELETE_{WORKER_PROCESS_TYPE_SPILL_WORKER_NAME}"
296
+ )
297
+ WORKER_PROCESS_TYPE_RESTORE_WORKER_DELETE = (
298
+ f"ray::DELETE_{WORKER_PROCESS_TYPE_RESTORE_WORKER_NAME}"
299
+ )
300
+
301
+ # The number of files the log monitor will open. If more files exist, they will
302
+ # be ignored.
303
+ LOG_MONITOR_MAX_OPEN_FILES = int(
304
+ os.environ.get("RAY_LOG_MONITOR_MAX_OPEN_FILES", "200")
305
+ )
306
+
307
+ # The maximum batch of lines to be read in a single iteration. We _always_ try
308
+ # to read this number of lines even if there aren't any new lines.
309
+ LOG_MONITOR_NUM_LINES_TO_READ = int(
310
+ os.environ.get("RAY_LOG_MONITOR_NUM_LINES_TO_READ", "1000")
311
+ )
312
+
313
+ # Autoscaler events are denoted by the ":event_summary:" magic token.
314
+ LOG_PREFIX_EVENT_SUMMARY = ":event_summary:"
315
+ # Cluster-level info events are denoted by the ":info_message:" magic token. These may
316
+ # be emitted in the stderr of Ray components.
317
+ LOG_PREFIX_INFO_MESSAGE = ":info_message:"
318
+ # Actor names are recorded in the logs with this magic token as a prefix.
319
+ LOG_PREFIX_ACTOR_NAME = ":actor_name:"
320
+ # Task names are recorded in the logs with this magic token as a prefix.
321
+ LOG_PREFIX_TASK_NAME = ":task_name:"
322
+ # Job ids are recorded in the logs with this magic token as a prefix.
323
+ LOG_PREFIX_JOB_ID = ":job_id:"
324
+
325
+ # The object metadata field uses the following format: It is a comma
326
+ # separated list of fields. The first field is mandatory and is the
327
+ # type of the object (see types below) or an integer, which is interpreted
328
+ # as an error value. The second part is optional and if present has the
329
+ # form DEBUG:<breakpoint_id>; it is used for implementing the debugger.
330
+
331
+ # A constant used as object metadata to indicate the object is cross language.
332
+ OBJECT_METADATA_TYPE_CROSS_LANGUAGE = b"XLANG"
333
+ # A constant used as object metadata to indicate the object is python specific.
334
+ OBJECT_METADATA_TYPE_PYTHON = b"PYTHON"
335
+ # A constant used as object metadata to indicate the object is raw bytes.
336
+ OBJECT_METADATA_TYPE_RAW = b"RAW"
337
+
338
+ # A constant used as object metadata to indicate the object is an actor handle.
339
+ # This value should be synchronized with the Java definition in
340
+ # ObjectSerializer.java
341
+ # TODO(fyrestone): Serialize the ActorHandle via the custom type feature
342
+ # of XLANG.
343
+ OBJECT_METADATA_TYPE_ACTOR_HANDLE = b"ACTOR_HANDLE"
344
+
345
+ # A constant indicating the debugging part of the metadata (see above).
346
+ OBJECT_METADATA_DEBUG_PREFIX = b"DEBUG:"
347
+
348
+ AUTOSCALER_RESOURCE_REQUEST_CHANNEL = b"autoscaler_resource_request"
349
+
350
+ REDIS_DEFAULT_USERNAME = ""
351
+
352
+ REDIS_DEFAULT_PASSWORD = ""
353
+
354
+ # The default ip address to bind to.
355
+ NODE_DEFAULT_IP = "127.0.0.1"
356
+
357
+ # The Mach kernel page size in bytes.
358
+ MACH_PAGE_SIZE_BYTES = 4096
359
+
360
+ # The max number of bytes for task execution error message.
361
+ MAX_APPLICATION_ERROR_LEN = 500
362
+
363
+ # Max 64 bit integer value, which is needed to ensure against overflow
364
+ # in C++ when passing integer values cross-language.
365
+ MAX_INT64_VALUE = 9223372036854775807
366
+
367
+ # Object Spilling related constants
368
+ DEFAULT_OBJECT_PREFIX = "ray_spilled_objects"
369
+
370
+ GCS_PORT_ENVIRONMENT_VARIABLE = "RAY_GCS_SERVER_PORT"
371
+
372
+ HEALTHCHECK_EXPIRATION_S = os.environ.get("RAY_HEALTHCHECK_EXPIRATION_S", 10)
373
+
374
+ # Filename of "shim process" that sets up Python worker environment.
375
+ # Should be kept in sync with kSetupWorkerFilename in
376
+ # src/ray/common/constants.h.
377
+ SETUP_WORKER_FILENAME = "setup_worker.py"
378
+
379
+ # Directory name where runtime_env resources will be created & cached.
380
+ DEFAULT_RUNTIME_ENV_DIR_NAME = "runtime_resources"
381
+
382
+ # The timeout in seconds for the creation of the runtime env;
383
+ # the default timeout is 10 minutes.
384
+ DEFAULT_RUNTIME_ENV_TIMEOUT_SECONDS = 600
385
+
386
+ # Used to separate lines when formatting the call stack where an ObjectRef was
387
+ # created.
388
+ CALL_STACK_LINE_DELIMITER = " | "
389
+
390
+ # The default gRPC max message size is 4 MiB; we use a larger limit of 250 MiB.
391
+ # NOTE: This is equal to the C++ limit of (RAY_CONFIG::max_grpc_message_size)
392
+ GRPC_CPP_MAX_MESSAGE_SIZE = 250 * 1024 * 1024
393
+
394
+ # The gRPC send & receive max length for "dashboard agent" server.
395
+ # NOTE: This is equal to the C++ limit of RayConfig::max_grpc_message_size
396
+ # and HAVE TO STAY IN SYNC with it (ie, meaning that both of these values
397
+ # have to be set at the same time)
398
+ AGENT_GRPC_MAX_MESSAGE_LENGTH = env_integer(
399
+ "AGENT_GRPC_MAX_MESSAGE_LENGTH", 20 * 1024 * 1024 # 20MB
400
+ )
401
+
402
+
403
+ # GRPC options
404
+ GRPC_ENABLE_HTTP_PROXY = (
405
+ 1
406
+ if os.environ.get("RAY_grpc_enable_http_proxy", "0").lower() in ("1", "true")
407
+ else 0
408
+ )
409
+ GLOBAL_GRPC_OPTIONS = (("grpc.enable_http_proxy", GRPC_ENABLE_HTTP_PROXY),)
410
+
411
+ # Internal kv namespaces
412
+ KV_NAMESPACE_DASHBOARD = b"dashboard"
413
+ KV_NAMESPACE_SESSION = b"session"
414
+ KV_NAMESPACE_TRACING = b"tracing"
415
+ KV_NAMESPACE_PDB = b"ray_pdb"
416
+ KV_NAMESPACE_HEALTHCHECK = b"healthcheck"
417
+ KV_NAMESPACE_JOB = b"job"
418
+ KV_NAMESPACE_CLUSTER = b"cluster"
419
+ KV_HEAD_NODE_ID_KEY = b"head_node_id"
420
+ # TODO: Set package for runtime env
421
+ # We need to update ray client for this since runtime env uses ray client.
422
+ # This might introduce some compatibility issues so leave it here for now.
423
+ KV_NAMESPACE_PACKAGE = None
424
+ KV_NAMESPACE_SERVE = b"serve"
425
+ KV_NAMESPACE_FUNCTION_TABLE = b"fun"
426
+
427
+ LANGUAGE_WORKER_TYPES = ["python", "java", "cpp"]
428
+
429
+ # Accelerator constants
430
+ NOSET_CUDA_VISIBLE_DEVICES_ENV_VAR = "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES"
431
+
432
+ CUDA_VISIBLE_DEVICES_ENV_VAR = "CUDA_VISIBLE_DEVICES"
433
+ ROCR_VISIBLE_DEVICES_ENV_VAR = "ROCR_VISIBLE_DEVICES"
434
+ NEURON_RT_VISIBLE_CORES_ENV_VAR = "NEURON_RT_VISIBLE_CORES"
435
+ TPU_VISIBLE_CHIPS_ENV_VAR = "TPU_VISIBLE_CHIPS"
436
+ NPU_RT_VISIBLE_DEVICES_ENV_VAR = "ASCEND_RT_VISIBLE_DEVICES"
437
+
438
+ NEURON_CORES = "neuron_cores"
439
+ GPU = "GPU"
440
+ TPU = "TPU"
441
+ NPU = "NPU"
442
+ HPU = "HPU"
443
+
444
+
445
+ RAY_WORKER_NICENESS = "RAY_worker_niceness"
446
+
447
+ # Default max_retries option in @ray.remote for non-actor
448
+ # tasks.
449
+ DEFAULT_TASK_MAX_RETRIES = 3
450
+
451
+ # Default max_concurrency option in @ray.remote for threaded actors.
452
+ DEFAULT_MAX_CONCURRENCY_THREADED = 1
453
+
454
+ # Default max_concurrency option in @ray.remote for async actors.
455
+ DEFAULT_MAX_CONCURRENCY_ASYNC = 1000
456
+
457
+ # Prefix for namespaces which are used internally by ray.
458
+ # Jobs within these namespaces should be hidden from users
459
+ # and should not be considered user activity.
460
+ # Please keep this in sync with the definition kRayInternalNamespacePrefix
461
+ # in /src/ray/gcs/gcs_server/gcs_job_manager.h.
462
+ RAY_INTERNAL_NAMESPACE_PREFIX = "_ray_internal_"
463
+ RAY_INTERNAL_DASHBOARD_NAMESPACE = f"{RAY_INTERNAL_NAMESPACE_PREFIX}dashboard"
464
+
465
+ # Ray internal flags. These flags should not be set by users, and we strip them on job
466
+ # submission.
467
+ # This should be consistent with src/ray/common/ray_internal_flag_def.h
468
+ RAY_INTERNAL_FLAGS = [
469
+ "RAY_JOB_ID",
470
+ "RAY_RAYLET_PID",
471
+ "RAY_OVERRIDE_NODE_ID_FOR_TESTING",
472
+ ]
473
+
474
+
475
+ def gcs_actor_scheduling_enabled():
476
+ return os.environ.get("RAY_gcs_actor_scheduling_enabled") == "true"
477
+
478
+
479
+ DEFAULT_RESOURCES = {"CPU", "GPU", "memory", "object_store_memory"}
480
+
481
+ # Supported Python versions for runtime env's "conda" field. Ray downloads
482
+ # Ray wheels into the conda environment, so the Ray wheels for these Python
483
+ # versions must be available online.
484
+ RUNTIME_ENV_CONDA_PY_VERSIONS = [(3, 9), (3, 10), (3, 11), (3, 12)]
485
+
486
+ # Whether to enable Ray clusters (in addition to local Ray).
487
+ # Ray clusters are not explicitly supported for Windows and OSX.
488
+ IS_WINDOWS_OR_OSX = sys.platform == "darwin" or sys.platform == "win32"
489
+ ENABLE_RAY_CLUSTERS_ENV_VAR = "RAY_ENABLE_WINDOWS_OR_OSX_CLUSTER"
490
+ ENABLE_RAY_CLUSTER = env_bool(
491
+ ENABLE_RAY_CLUSTERS_ENV_VAR,
492
+ not IS_WINDOWS_OR_OSX,
493
+ )
494
+
495
+ SESSION_LATEST = "session_latest"
496
+ NUM_PORT_RETRIES = 40
497
+ NUM_REDIS_GET_RETRIES = int(os.environ.get("RAY_NUM_REDIS_GET_RETRIES", "20"))
498
+
499
+ # The allowed cached ports in Ray. Refer to Port configuration for more details:
500
+ # https://docs.ray.io/en/latest/ray-core/configure.html#ports-configurations
501
+ RAY_ALLOWED_CACHED_PORTS = {
502
+ "metrics_agent_port",
503
+ "metrics_export_port",
504
+ "dashboard_agent_listen_port",
505
+ "runtime_env_agent_port",
506
+ "gcs_server_port", # the `port` option for gcs port.
507
+ }
508
+
509
+ # Turn this on if actor task log offsets are expected to be recorded.
510
+ # With this enabled, actor task logs can be queried by task id.
511
+ RAY_ENABLE_RECORD_ACTOR_TASK_LOGGING = env_bool(
512
+ "RAY_ENABLE_RECORD_ACTOR_TASK_LOGGING", False
513
+ )
514
+
515
+ # RuntimeEnv env var to indicate it exports a function
516
+ WORKER_PROCESS_SETUP_HOOK_ENV_VAR = "__RAY_WORKER_PROCESS_SETUP_HOOK_ENV_VAR"
517
+ RAY_WORKER_PROCESS_SETUP_HOOK_LOAD_TIMEOUT_ENV_VAR = (
518
+ "RAY_WORKER_PROCESS_SETUP_HOOK_LOAD_TIMEOUT" # noqa
519
+ )
520
+
521
+ RAY_DEFAULT_LABEL_KEYS_PREFIX = "ray.io/"
522
+
523
+ RAY_TPU_MAX_CONCURRENT_CONNECTIONS_ENV_VAR = "RAY_TPU_MAX_CONCURRENT_ACTIVE_CONNECTIONS"
524
+
525
+ RAY_NODE_IP_FILENAME = "node_ip_address.json"
526
+
527
+ PLACEMENT_GROUP_BUNDLE_RESOURCE_NAME = "bundle"
528
+
529
+ RAY_LOGGING_CONFIG_ENCODING = os.environ.get("RAY_LOGGING_CONFIG_ENCODING")
530
+
531
+ RAY_BACKEND_LOG_JSON_ENV_VAR = "RAY_BACKEND_LOG_JSON"
532
+
533
+ # Write export API events of all resource types to file if enabled.
534
+ # RAY_enable_export_api_write_config will not be considered if
535
+ # this is enabled.
536
+ RAY_ENABLE_EXPORT_API_WRITE = env_bool("RAY_enable_export_api_write", False)
537
+
538
+ # Comma-separated string of individual resource types
540
+ # to write export API events for. This configuration is only used if
540
+ # RAY_enable_export_api_write is not enabled. Full list of valid
541
+ # resource types in ExportEvent.SourceType enum in
542
+ # src/ray/protobuf/export_api/export_event.proto
543
+ # Example config:
544
+ # `export RAY_enable_export_api_write_config='EXPORT_SUBMISSION_JOB,EXPORT_ACTOR'`
545
+ RAY_ENABLE_EXPORT_API_WRITE_CONFIG_STR = os.environ.get(
546
+ "RAY_enable_export_api_write_config", ""
547
+ )
548
+ RAY_ENABLE_EXPORT_API_WRITE_CONFIG = RAY_ENABLE_EXPORT_API_WRITE_CONFIG_STR.split(",")
549
+
550
+ RAY_EXPORT_EVENT_MAX_FILE_SIZE_BYTES = env_integer(
551
+ "RAY_EXPORT_EVENT_MAX_FILE_SIZE_BYTES", 100 * 10**6
552
+ )
553
+
554
+ RAY_EXPORT_EVENT_MAX_BACKUP_COUNT = env_integer("RAY_EXPORT_EVENT_MAX_BACKUP_COUNT", 20)
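
A minimal usage sketch of the env_integer / env_float / env_bool helpers defined at the top of this file. The environment variable names (MY_INT, MY_BAD_INT, MY_FLAG, MY_MISSING) are illustrative only; the behavior shown (malformed integers fall back to the default, "1"/"true" map to True) follows directly from the helper code above.

# Hypothetical usage sketch for the env helpers in ray_constants.py.
import os

from ray._private import ray_constants as rc

os.environ["MY_INT"] = "42"        # parsed as an integer
os.environ["MY_BAD_INT"] = "4.5"   # not .isdigit(), so env_integer returns the default
os.environ["MY_FLAG"] = "1"        # "1" or "true" (case-insensitive) -> True

assert rc.env_integer("MY_INT", 0) == 42
assert rc.env_integer("MY_BAD_INT", 7) == 7     # falls back to the provided default
assert rc.env_float("MY_BAD_INT", 0.0) == 4.5   # env_float accepts "4.5"
assert rc.env_bool("MY_FLAG", False) is True
assert rc.env_integer("MY_MISSING", 3) == 3     # unset -> default
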
.venv/lib/python3.11/site-packages/ray/_private/ray_experimental_perf.py ADDED
@@ -0,0 +1,337 @@
1
+ """This is the script for `ray microbenchmark`."""
2
+
3
+ import asyncio
4
+ import logging
5
+ from ray._private.ray_microbenchmark_helpers import timeit, asyncio_timeit
6
+ import multiprocessing
7
+ import ray
8
+ from ray.dag.compiled_dag_node import CompiledDAG
9
+ from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy
10
+
11
+ import ray.experimental.channel as ray_channel
12
+ from ray.dag import InputNode, MultiOutputNode
13
+ from ray._private.utils import (
14
+ get_or_create_event_loop,
15
+ )
16
+ from ray._private.test_utils import get_actor_node_id
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+
21
+ @ray.remote
22
+ class DAGActor:
23
+ def echo(self, x):
24
+ return x
25
+
26
+ def echo_multiple(self, *x):
27
+ return x
28
+
29
+
30
+ def check_optimized_build():
31
+ if not ray._raylet.OPTIMIZED:
32
+ msg = (
33
+ "WARNING: Unoptimized build! "
34
+ "To benchmark an optimized build, try:\n"
35
+ "\tbazel build -c opt //:ray_pkg\n"
36
+ "You can also make this permanent by adding\n"
37
+ "\tbuild --compilation_mode=opt\n"
38
+ "to your user-wide ~/.bazelrc file. "
39
+ "(Do not add this to the project-level .bazelrc file.)"
40
+ )
41
+ logger.warning(msg)
42
+
43
+
44
+ def create_driver_actor():
45
+ return CompiledDAG.DAGDriverProxyActor.options(
46
+ scheduling_strategy=NodeAffinitySchedulingStrategy(
47
+ ray.get_runtime_context().get_node_id(), soft=False
48
+ )
49
+ ).remote()
50
+
51
+
52
+ def main(results=None):
53
+ results = results or []
54
+ loop = get_or_create_event_loop()
55
+
56
+ check_optimized_build()
57
+
58
+ print("Tip: set TESTS_TO_RUN='pattern' to run a subset of benchmarks")
59
+
60
+ #################################################
61
+ # Perf tests for channels, used in compiled DAGs.
62
+ #################################################
63
+ ray.init()
64
+
65
+ def put_channel_small(chans, do_get=False):
66
+ for chan in chans:
67
+ chan.write(b"0")
68
+ if do_get:
69
+ chan.read()
70
+
71
+ @ray.remote
72
+ class ChannelReader:
73
+ def ready(self):
74
+ return
75
+
76
+ def read(self, chans):
77
+ while True:
78
+ for chan in chans:
79
+ chan.read()
80
+
81
+ driver_actor = create_driver_actor()
82
+ driver_node = get_actor_node_id(driver_actor)
83
+ chans = [ray_channel.Channel(None, [(driver_actor, driver_node)], 1000)]
84
+ results += timeit(
85
+ "[unstable] local put:local get, single channel calls",
86
+ lambda: put_channel_small(chans, do_get=True),
87
+ )
88
+
89
+ reader = ChannelReader.remote()
90
+ reader_node = get_actor_node_id(reader)
91
+ chans = [ray_channel.Channel(None, [(reader, reader_node)], 1000)]
92
+ ray.get(reader.ready.remote())
93
+ reader.read.remote(chans)
94
+ results += timeit(
95
+ "[unstable] local put:1 remote get, single channel calls",
96
+ lambda: put_channel_small(chans),
97
+ )
98
+ ray.kill(reader)
99
+
100
+ n_cpu = multiprocessing.cpu_count() // 2
101
+ print(f"Testing multiple readers/channels, n={n_cpu}")
102
+
103
+ reader_and_node_list = []
104
+ for _ in range(n_cpu):
105
+ reader = ChannelReader.remote()
106
+ reader_node = get_actor_node_id(reader)
107
+ reader_and_node_list.append((reader, reader_node))
108
+ chans = [ray_channel.Channel(None, reader_and_node_list, 1000)]
109
+ ray.get([reader.ready.remote() for reader, _ in reader_and_node_list])
110
+ for reader, _ in reader_and_node_list:
111
+ reader.read.remote(chans)
112
+ results += timeit(
113
+ "[unstable] local put:n remote get, single channel calls",
114
+ lambda: put_channel_small(chans),
115
+ )
116
+ for reader, _ in reader_and_node_list:
117
+ ray.kill(reader)
118
+
119
+ reader = ChannelReader.remote()
120
+ reader_node = get_actor_node_id(reader)
121
+ chans = [
122
+ ray_channel.Channel(None, [(reader, reader_node)], 1000) for _ in range(n_cpu)
123
+ ]
124
+ ray.get(reader.ready.remote())
125
+ reader.read.remote(chans)
126
+ results += timeit(
127
+ "[unstable] local put:1 remote get, n channels calls",
128
+ lambda: put_channel_small(chans),
129
+ )
130
+ ray.kill(reader)
131
+
132
+ reader_and_node_list = []
133
+ for _ in range(n_cpu):
134
+ reader = ChannelReader.remote()
135
+ reader_node = get_actor_node_id(reader)
136
+ reader_and_node_list.append((reader, reader_node))
137
+ chans = [
138
+ ray_channel.Channel(None, [reader_and_node_list[i]], 1000) for i in range(n_cpu)
139
+ ]
140
+ ray.get([reader.ready.remote() for reader, _ in reader_and_node_list])
141
+ for chan, reader_node_tuple in zip(chans, reader_and_node_list):
142
+ reader = reader_node_tuple[0]
143
+ reader.read.remote([chan])
144
+ results += timeit(
145
+ "[unstable] local put:n remote get, n channels calls",
146
+ lambda: put_channel_small(chans),
147
+ )
148
+ for reader, _ in reader_and_node_list:
149
+ ray.kill(reader)
150
+
151
+ # Tests for compiled DAGs.
152
+
153
+ def _exec(dag, num_args=1, payload_size=1):
154
+ output_ref = dag.execute(*[b"x" * payload_size for _ in range(num_args)])
155
+ ray.get(output_ref)
156
+
157
+ async def exec_async(tag):
158
+ async def _exec_async():
159
+ fut = await compiled_dag.execute_async(b"x")
160
+ if not isinstance(fut, list):
161
+ await fut
162
+ else:
163
+ await asyncio.gather(*fut)
164
+
165
+ return await asyncio_timeit(
166
+ tag,
167
+ _exec_async,
168
+ )
169
+
170
+ # Single-actor DAG calls
171
+
172
+ a = DAGActor.remote()
173
+ with InputNode() as inp:
174
+ dag = a.echo.bind(inp)
175
+
176
+ results += timeit(
177
+ "[unstable] single-actor DAG calls", lambda: ray.get(dag.execute(b"x"))
178
+ )
179
+ compiled_dag = dag.experimental_compile()
180
+ results += timeit(
181
+ "[unstable] compiled single-actor DAG calls", lambda: _exec(compiled_dag)
182
+ )
183
+ del a
184
+
185
+ # Single-actor asyncio DAG calls
186
+
187
+ a = DAGActor.remote()
188
+ with InputNode() as inp:
189
+ dag = a.echo.bind(inp)
190
+ compiled_dag = dag.experimental_compile(enable_asyncio=True)
191
+ results += loop.run_until_complete(
192
+ exec_async(
193
+ "[unstable] compiled single-actor asyncio DAG calls",
194
+ )
195
+ )
196
+ del a
197
+
198
+ # Scatter-gather DAG calls
199
+
200
+ n_cpu = multiprocessing.cpu_count() // 2
201
+ actors = [DAGActor.remote() for _ in range(n_cpu)]
202
+ with InputNode() as inp:
203
+ dag = MultiOutputNode([a.echo.bind(inp) for a in actors])
204
+ results += timeit(
205
+ f"[unstable] scatter-gather DAG calls, n={n_cpu} actors",
206
+ lambda: ray.get(dag.execute(b"x")),
207
+ )
208
+ compiled_dag = dag.experimental_compile()
209
+ results += timeit(
210
+ f"[unstable] compiled scatter-gather DAG calls, n={n_cpu} actors",
211
+ lambda: _exec(compiled_dag),
212
+ )
213
+
214
+ # Scatter-gather asyncio DAG calls
215
+
216
+ actors = [DAGActor.remote() for _ in range(n_cpu)]
217
+ with InputNode() as inp:
218
+ dag = MultiOutputNode([a.echo.bind(inp) for a in actors])
219
+ compiled_dag = dag.experimental_compile(enable_asyncio=True)
220
+ results += loop.run_until_complete(
221
+ exec_async(
222
+ f"[unstable] compiled scatter-gather asyncio DAG calls, n={n_cpu} actors",
223
+ )
224
+ )
225
+
226
+ # Chain DAG calls
227
+
228
+ actors = [DAGActor.remote() for _ in range(n_cpu)]
229
+ with InputNode() as inp:
230
+ dag = inp
231
+ for a in actors:
232
+ dag = a.echo.bind(dag)
233
+ results += timeit(
234
+ f"[unstable] chain DAG calls, n={n_cpu} actors",
235
+ lambda: ray.get(dag.execute(b"x")),
236
+ )
237
+ compiled_dag = dag.experimental_compile()
238
+ results += timeit(
239
+ f"[unstable] compiled chain DAG calls, n={n_cpu} actors",
240
+ lambda: _exec(compiled_dag),
241
+ )
242
+
243
+ # Chain asyncio DAG calls
244
+
245
+ actors = [DAGActor.remote() for _ in range(n_cpu)]
246
+ with InputNode() as inp:
247
+ dag = inp
248
+ for a in actors:
249
+ dag = a.echo.bind(dag)
250
+ compiled_dag = dag.experimental_compile(enable_asyncio=True)
251
+ results += loop.run_until_complete(
252
+ exec_async(f"[unstable] compiled chain asyncio DAG calls, n={n_cpu} actors")
253
+ )
254
+
255
+ # Multiple args with small payloads
256
+
257
+ n_actors = 8
258
+ assert (
259
+ n_cpu > n_actors
260
+ ), f"n_cpu ({n_cpu}) must be greater than n_actors ({n_actors})"
261
+
262
+ actors = [DAGActor.remote() for _ in range(n_actors)]
263
+ with InputNode() as inp:
264
+ dag = MultiOutputNode([actors[i].echo.bind(inp[i]) for i in range(n_actors)])
265
+ payload_size = 1
266
+ results += timeit(
267
+ f"[unstable] multiple args with small payloads DAG calls, n={n_actors} actors",
268
+ lambda: ray.get(dag.execute(*[b"x" * payload_size for _ in range(n_actors)])),
269
+ )
270
+ compiled_dag = dag.experimental_compile()
271
+ results += timeit(
272
+ f"[unstable] compiled multiple args with small payloads DAG calls, "
273
+ f"n={n_actors} actors",
274
+ lambda: _exec(compiled_dag, num_args=n_actors, payload_size=payload_size),
275
+ )
276
+
277
+ # Multiple args with medium payloads
278
+
279
+ actors = [DAGActor.remote() for _ in range(n_actors)]
280
+ with InputNode() as inp:
281
+ dag = MultiOutputNode([actors[i].echo.bind(inp[i]) for i in range(n_actors)])
282
+ payload_size = 1024 * 1024
283
+ results += timeit(
284
+ f"[unstable] multiple args with medium payloads DAG calls, n={n_actors} actors",
285
+ lambda: ray.get(dag.execute(*[b"x" * payload_size for _ in range(n_actors)])),
286
+ )
287
+ compiled_dag = dag.experimental_compile()
288
+ results += timeit(
289
+ "[unstable] compiled multiple args with medium payloads DAG calls, "
290
+ f"n={n_actors} actors",
291
+ lambda: _exec(compiled_dag, num_args=n_actors, payload_size=payload_size),
292
+ )
293
+
294
+ # Multiple args with large payloads
295
+
296
+ actors = [DAGActor.remote() for _ in range(n_actors)]
297
+ with InputNode() as inp:
298
+ dag = MultiOutputNode([actors[i].echo.bind(inp[i]) for i in range(n_actors)])
299
+ payload_size = 10 * 1024 * 1024
300
+ results += timeit(
301
+ f"[unstable] multiple args with large payloads DAG calls, n={n_actors} actors",
302
+ lambda: ray.get(dag.execute(*[b"x" * payload_size for _ in range(n_actors)])),
303
+ )
304
+ compiled_dag = dag.experimental_compile()
305
+ results += timeit(
306
+ "[unstable] compiled multiple args with large payloads DAG calls, "
307
+ f"n={n_actors} actors",
308
+ lambda: _exec(compiled_dag, num_args=n_actors, payload_size=payload_size),
309
+ )
310
+
311
+ # Worst case for multiple arguments: a single actor takes all the arguments
312
+ # with small payloads.
313
+
314
+ actor = DAGActor.remote()
315
+ n_args = 8
316
+ with InputNode() as inp:
317
+ dag = actor.echo_multiple.bind(*[inp[i] for i in range(n_args)])
318
+ payload_size = 1
319
+ results += timeit(
320
+ "[unstable] single-actor with all args with small payloads DAG calls, "
321
+ "n=1 actors",
322
+ lambda: ray.get(dag.execute(*[b"x" * payload_size for _ in range(n_args)])),
323
+ )
324
+ compiled_dag = dag.experimental_compile()
325
+ results += timeit(
326
+ "[unstable] single-actor with all args with small payloads DAG calls, "
327
+ "n=1 actors",
328
+ lambda: _exec(compiled_dag, num_args=n_args, payload_size=payload_size),
329
+ )
330
+
331
+ ray.shutdown()
332
+
333
+ return results
334
+
335
+
336
+ if __name__ == "__main__":
337
+ main()
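
For readers unfamiliar with the DAG API exercised in this benchmark, here is a minimal, self-contained sketch of the pattern it times: bind a single actor method to an InputNode, execute it eagerly, then compile it with experimental_compile() and execute again. The actor name EchoActor is illustrative; the imports mirror the ones used in the script above.

# Minimal sketch of the DAG pattern benchmarked above (illustrative only).
import ray
from ray.dag import InputNode


@ray.remote
class EchoActor:
    def echo(self, x):
        return x


ray.init()
actor = EchoActor.remote()

with InputNode() as inp:
    dag = actor.echo.bind(inp)

# Eager execution: each call returns a reference that ray.get resolves.
assert ray.get(dag.execute(b"x")) == b"x"

# Compiled execution: the accelerated path used by the
# "[unstable] compiled ..." benchmarks above.
compiled_dag = dag.experimental_compile()
assert ray.get(compiled_dag.execute(b"x")) == b"x"

ray.shutdown()
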
.venv/lib/python3.11/site-packages/ray/_private/ray_microbenchmark_helpers.py ADDED
@@ -0,0 +1,91 @@
1
+ import time
2
+ from typing import List, Optional, Tuple
3
+ import os
4
+ import ray
5
+ import numpy as np
6
+
7
+ from contextlib import contextmanager
8
+
9
+ # Only run tests matching this filter pattern.
10
+
11
+ filter_pattern = os.environ.get("TESTS_TO_RUN", "")
12
+ skip_pattern = os.environ.get("TESTS_TO_SKIP", "")
13
+
14
+
15
+ def timeit(
16
+ name, fn, multiplier=1, warmup_time_sec=10
17
+ ) -> List[Optional[Tuple[str, float, float]]]:
18
+ if filter_pattern and filter_pattern not in name:
19
+ return [None]
20
+ if skip_pattern and skip_pattern in name:
21
+ return [None]
22
+ # sleep for a while to avoid noisy neighbors.
23
+ # related issue: https://github.com/ray-project/ray/issues/22045
24
+ time.sleep(warmup_time_sec)
25
+ # warmup
26
+ start = time.perf_counter()
27
+ count = 0
28
+ while time.perf_counter() - start < 1:
29
+ fn()
30
+ count += 1
31
+ # real run
32
+ step = count // 10 + 1
33
+ stats = []
34
+ for _ in range(4):
35
+ start = time.perf_counter()
36
+ count = 0
37
+ while time.perf_counter() - start < 2:
38
+ for _ in range(step):
39
+ fn()
40
+ count += step
41
+ end = time.perf_counter()
42
+ stats.append(multiplier * count / (end - start))
43
+
44
+ mean = np.mean(stats)
45
+ sd = np.std(stats)
46
+ print(name, "per second", round(mean, 2), "+-", round(sd, 2))
47
+ return [(name, mean, sd)]
48
+
49
+
50
+ async def asyncio_timeit(
51
+ name, async_fn, multiplier=1, warmup_time_sec=10
52
+ ) -> List[Optional[Tuple[str, float, float]]]:
53
+ if filter_pattern and filter_pattern not in name:
54
+ return [None]
55
+ if skip_pattern and skip_pattern in name:
56
+ return [None]
57
+ # sleep for a while to avoid noisy neighbors.
58
+ # related issue: https://github.com/ray-project/ray/issues/22045
59
+ time.sleep(warmup_time_sec)
60
+ # warmup
61
+ start = time.perf_counter()
62
+ count = 0
63
+ while time.perf_counter() - start < 1:
64
+ await async_fn()
65
+ count += 1
66
+ # real run
67
+ step = count // 10 + 1
68
+ stats = []
69
+ for _ in range(4):
70
+ start = time.perf_counter()
71
+ count = 0
72
+ while time.perf_counter() - start < 2:
73
+ for _ in range(step):
74
+ await async_fn()
75
+ count += step
76
+ end = time.perf_counter()
77
+ stats.append(multiplier * count / (end - start))
78
+
79
+ mean = np.mean(stats)
80
+ sd = np.std(stats)
81
+ print(name, "per second", round(mean, 2), "+-", round(sd, 2))
82
+ return [(name, mean, sd)]
83
+
84
+
85
+ @contextmanager
86
+ def ray_setup_and_teardown(**init_args):
87
+ ray.init(**init_args)
88
+ try:
89
+ yield None
90
+ finally:
91
+ ray.shutdown()
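
A small usage sketch combining ray_setup_and_teardown with timeit, the same pattern ray_perf.py relies on. The benchmark name and the no-op workload below are made up for illustration; warmup_time_sec=0 only skips the initial sleep, not the measurement loops.

# Illustrative driver for the helpers above; not part of the upstream file.
import ray
from ray._private.ray_microbenchmark_helpers import (
    ray_setup_and_teardown,
    timeit,
)


@ray.remote
def noop():
    return b"ok"


def bench_noop():
    ray.get(noop.remote())


with ray_setup_and_teardown(num_cpus=2):
    # Prints "<name> per second <mean> +- <sd>" and returns [(name, mean, sd)].
    # Note: if TESTS_TO_RUN is set and does not match the name, [None] is returned.
    results = timeit("example noop round trip", bench_noop, warmup_time_sec=0)
    print(results)
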
.venv/lib/python3.11/site-packages/ray/_private/ray_option_utils.py ADDED
@@ -0,0 +1,387 @@
1
+ """Manage, parse and validate options for Ray tasks, actors and actor methods."""
2
+ import warnings
3
+ from dataclasses import dataclass
4
+ from typing import Any, Callable, Dict, Optional, Tuple, Union
5
+
6
+ import ray
7
+ from ray._private import ray_constants
8
+ from ray._private.utils import get_ray_doc_version
9
+ from ray.util.placement_group import PlacementGroup
10
+ from ray.util.scheduling_strategies import (
11
+ NodeAffinitySchedulingStrategy,
12
+ PlacementGroupSchedulingStrategy,
13
+ NodeLabelSchedulingStrategy,
14
+ )
15
+
16
+
17
+ @dataclass
18
+ class Option:
19
+ # Type constraint of an option.
20
+ type_constraint: Optional[Union[type, Tuple[type]]] = None
21
+ # Value constraint of an option.
22
+ # The callable should return None if there is no error.
23
+ # Otherwise, return the error message.
24
+ value_constraint: Optional[Callable[[Any], Optional[str]]] = None
25
+ # Default value.
26
+ default_value: Any = None
27
+
28
+ def validate(self, keyword: str, value: Any):
29
+ """Validate the option."""
30
+ if self.type_constraint is not None:
31
+ if not isinstance(value, self.type_constraint):
32
+ raise TypeError(
33
+ f"The type of keyword '{keyword}' must be {self.type_constraint}, "
34
+ f"but received type {type(value)}"
35
+ )
36
+ if self.value_constraint is not None:
37
+ possible_error_message = self.value_constraint(value)
38
+ if possible_error_message:
39
+ raise ValueError(possible_error_message)
40
+
41
+
42
+ def _counting_option(name: str, infinite: bool = True, default_value: Any = None):
43
+ """This is used for positive and discrete options.
44
+
45
+ Args:
46
+ name: The name of the option keyword.
47
+ infinite: If True, user could use -1 to represent infinity.
48
+ default_value: The default value for this option.
49
+ """
50
+ if infinite:
51
+ return Option(
52
+ (int, type(None)),
53
+ lambda x: None
54
+ if (x is None or x >= -1)
55
+ else f"The keyword '{name}' only accepts None, 0, -1"
56
+ " or a positive integer, where -1 represents infinity.",
57
+ default_value=default_value,
58
+ )
59
+ return Option(
60
+ (int, type(None)),
61
+ lambda x: None
62
+ if (x is None or x >= 0)
63
+ else f"The keyword '{name}' only accepts None, 0 or a positive integer.",
64
+ default_value=default_value,
65
+ )
66
+
67
+
68
+ def _validate_resource_quantity(name, quantity):
69
+ if quantity < 0:
70
+ return f"The quantity of resource {name} cannot be negative"
71
+ if (
72
+ isinstance(quantity, float)
73
+ and quantity != 0.0
74
+ and int(quantity * ray._raylet.RESOURCE_UNIT_SCALING) == 0
75
+ ):
76
+ return (
77
+ f"The precision of the fractional quantity of resource {name}"
78
+ " cannot go beyond 0.0001"
79
+ )
80
+ resource_name = "GPU" if name == "num_gpus" else name
81
+ if resource_name in ray._private.accelerators.get_all_accelerator_resource_names():
82
+ (
83
+ valid,
84
+ error_message,
85
+ ) = ray._private.accelerators.get_accelerator_manager_for_resource(
86
+ resource_name
87
+ ).validate_resource_request_quantity(
88
+ quantity
89
+ )
90
+ if not valid:
91
+ return error_message
92
+ return None
93
+
94
+
95
+ def _resource_option(name: str, default_value: Any = None):
96
+ """This is used for resource related options."""
97
+ return Option(
98
+ (float, int, type(None)),
99
+ lambda x: None if (x is None) else _validate_resource_quantity(name, x),
100
+ default_value=default_value,
101
+ )
102
+
103
+
104
+ def _validate_resources(resources: Optional[Dict[str, float]]) -> Optional[str]:
105
+ if resources is None:
106
+ return None
107
+
108
+ if "CPU" in resources or "GPU" in resources:
109
+ return (
110
+ "Use the 'num_cpus' and 'num_gpus' keyword instead of 'CPU' and 'GPU' "
111
+ "in 'resources' keyword"
112
+ )
113
+
114
+ for name, quantity in resources.items():
115
+ possible_error_message = _validate_resource_quantity(name, quantity)
116
+ if possible_error_message:
117
+ return possible_error_message
118
+
119
+ return None
120
+
121
+
122
+ _common_options = {
123
+ "accelerator_type": Option((str, type(None))),
124
+ "memory": _resource_option("memory"),
125
+ "name": Option((str, type(None))),
126
+ "num_cpus": _resource_option("num_cpus"),
127
+ "num_gpus": _resource_option("num_gpus"),
128
+ "object_store_memory": _counting_option("object_store_memory", False),
129
+ # TODO(suquark): "placement_group", "placement_group_bundle_index"
130
+ # and "placement_group_capture_child_tasks" are deprecated,
131
+ # use "scheduling_strategy" instead.
132
+ "placement_group": Option(
133
+ (type(None), str, PlacementGroup), default_value="default"
134
+ ),
135
+ "placement_group_bundle_index": Option(int, default_value=-1),
136
+ "placement_group_capture_child_tasks": Option((bool, type(None))),
137
+ "resources": Option((dict, type(None)), lambda x: _validate_resources(x)),
138
+ "runtime_env": Option((dict, type(None))),
139
+ "scheduling_strategy": Option(
140
+ (
141
+ type(None),
142
+ str,
143
+ PlacementGroupSchedulingStrategy,
144
+ NodeAffinitySchedulingStrategy,
145
+ NodeLabelSchedulingStrategy,
146
+ )
147
+ ),
148
+ "_metadata": Option((dict, type(None))),
149
+ "enable_task_events": Option(bool, default_value=True),
150
+ "_labels": Option((dict, type(None))),
151
+ }
152
+
153
+
154
+ def issubclass_safe(obj: Any, cls_: type) -> bool:
155
+ try:
156
+ return issubclass(obj, cls_)
157
+ except TypeError:
158
+ return False
159
+
160
+
161
+ _task_only_options = {
162
+ "max_calls": _counting_option("max_calls", False, default_value=0),
163
+ # Normal tasks may be retried on failure this many times.
164
+ # TODO(swang): Allow this to be set globally for an application.
165
+ "max_retries": _counting_option(
166
+ "max_retries", default_value=ray_constants.DEFAULT_TASK_MAX_RETRIES
167
+ ),
168
+ # override "_common_options"
169
+ "num_cpus": _resource_option("num_cpus", default_value=1),
170
+ "num_returns": Option(
171
+ (int, str, type(None)),
172
+ lambda x: None
173
+ if (x is None or x == "dynamic" or x == "streaming" or x >= 0)
174
+ else "Default None. When None is passed, "
175
+ "the default value is 1 for a task and actor task, and "
176
+ "'streaming' for generator tasks and generator actor tasks. "
177
+ "The keyword 'num_returns' only accepts None, "
178
+ "a non-negative integer, "
179
+ "'streaming' (for generators), or 'dynamic'. 'dynamic' flag "
180
+ "will be deprecated in the future, and it is recommended to use "
181
+ "'streaming' instead.",
182
+ default_value=None,
183
+ ),
184
+ "object_store_memory": Option( # override "_common_options"
185
+ (int, type(None)),
186
+ lambda x: None
187
+ if (x is None)
188
+ else "Setting 'object_store_memory' is not implemented for tasks",
189
+ ),
190
+ "retry_exceptions": Option(
191
+ (bool, list, tuple),
192
+ lambda x: None
193
+ if (
194
+ isinstance(x, bool)
195
+ or (
196
+ isinstance(x, (list, tuple))
197
+ and all(issubclass_safe(x_, Exception) for x_ in x)
198
+ )
199
+ )
200
+ else "retry_exceptions must be either a boolean or a list of exceptions",
201
+ default_value=False,
202
+ ),
203
+ "_generator_backpressure_num_objects": Option(
204
+ (int, type(None)),
205
+ lambda x: None
206
+ if x != 0
207
+ else (
208
+ "_generator_backpressure_num_objects=0 is not allowed. "
209
+ "is identical to a Python generator (generating 1 object "
210
+ "is identical to Python generator (generator 1 object "
211
+ "whenever `next` is called). Use -1 to disable this feature. "
212
+ ),
213
+ ),
214
+ }
215
+
216
+ _actor_only_options = {
217
+ "concurrency_groups": Option((list, dict, type(None))),
218
+ "lifetime": Option(
219
+ (str, type(None)),
220
+ lambda x: None
221
+ if x in (None, "detached", "non_detached")
222
+ else "actor `lifetime` argument must be one of 'detached', "
223
+ "'non_detached' and 'None'.",
224
+ ),
225
+ "max_concurrency": _counting_option("max_concurrency", False),
226
+ "max_restarts": _counting_option("max_restarts", default_value=0),
227
+ "max_task_retries": _counting_option("max_task_retries", default_value=0),
228
+ "max_pending_calls": _counting_option("max_pending_calls", default_value=-1),
229
+ "namespace": Option((str, type(None))),
230
+ "get_if_exists": Option(bool, default_value=False),
231
+ }
232
+
233
+ # Priority is important here because during dictionary update, same key with higher
234
+ # priority overrides the same key with lower priority. We make use of priority
235
+ # to set the correct default value for tasks / actors.
236
+
237
+ # priority: _common_options > _actor_only_options > _task_only_options
238
+ valid_options: Dict[str, Option] = {
239
+ **_task_only_options,
240
+ **_actor_only_options,
241
+ **_common_options,
242
+ }
243
+ # priority: _task_only_options > _common_options
244
+ task_options: Dict[str, Option] = {**_common_options, **_task_only_options}
245
+ # priority: _actor_only_options > _common_options
246
+ actor_options: Dict[str, Option] = {**_common_options, **_actor_only_options}
247
+
248
+ remote_args_error_string = (
249
+ "The @ray.remote decorator must be applied either with no arguments and no "
250
+ "parentheses, for example '@ray.remote', or it must be applied using some of "
251
+ f"the arguments in the list {list(valid_options.keys())}, for example "
252
+ "'@ray.remote(num_returns=2, resources={\"CustomResource\": 1})'."
253
+ )
254
+
255
+
256
+ def _check_deprecate_placement_group(options: Dict[str, Any]):
257
+ """Check if deprecated placement group option exists."""
258
+ placement_group = options.get("placement_group", "default")
259
+ scheduling_strategy = options.get("scheduling_strategy")
260
+ # TODO(suquark): @ray.remote(placement_group=None) is used in
261
+ # "python/ray.data._internal/remote_fn.py" and many other places,
262
+ # while "ray.data.read_api.read_datasource" set "scheduling_strategy=SPREAD".
263
+ # This might be a bug, but it is also ok to allow them co-exist.
264
+ if (placement_group not in ("default", None)) and (scheduling_strategy is not None):
265
+ raise ValueError(
266
+ "Placement groups should be specified via the "
267
+ "scheduling_strategy option. "
268
+ "The placement_group option is deprecated."
269
+ )
270
+
271
+
272
+ def _warn_if_using_deprecated_placement_group(
273
+ options: Dict[str, Any], caller_stacklevel: int
274
+ ):
275
+ placement_group = options["placement_group"]
276
+ placement_group_bundle_index = options["placement_group_bundle_index"]
277
+ placement_group_capture_child_tasks = options["placement_group_capture_child_tasks"]
278
+ if placement_group != "default":
279
+ warnings.warn(
280
+ "placement_group parameter is deprecated. Use "
281
+ "scheduling_strategy=PlacementGroupSchedulingStrategy(...) "
282
+ "instead, see the usage at "
283
+ f"https://docs.ray.io/en/{get_ray_doc_version()}/ray-core/package-ref.html#ray-remote.", # noqa: E501
284
+ DeprecationWarning,
285
+ stacklevel=caller_stacklevel + 1,
286
+ )
287
+ if placement_group_bundle_index != -1:
288
+ warnings.warn(
289
+ "placement_group_bundle_index parameter is deprecated. Use "
290
+ "scheduling_strategy=PlacementGroupSchedulingStrategy(...) "
291
+ "instead, see the usage at "
292
+ f"https://docs.ray.io/en/{get_ray_doc_version()}/ray-core/package-ref.html#ray-remote.", # noqa: E501
293
+ DeprecationWarning,
294
+ stacklevel=caller_stacklevel + 1,
295
+ )
296
+ if placement_group_capture_child_tasks:
297
+ warnings.warn(
298
+ "placement_group_capture_child_tasks parameter is deprecated. Use "
299
+ "scheduling_strategy=PlacementGroupSchedulingStrategy(...) "
300
+ "instead, see the usage at "
301
+ f"https://docs.ray.io/en/{get_ray_doc_version()}/ray-core/package-ref.html#ray-remote.", # noqa: E501
302
+ DeprecationWarning,
303
+ stacklevel=caller_stacklevel + 1,
304
+ )
305
+
306
+
307
+ def validate_task_options(options: Dict[str, Any], in_options: bool):
308
+ """Options check for Ray tasks.
309
+
310
+ Args:
311
+ options: Options for Ray tasks.
312
+ in_options: If True, we are checking the options under the context of
313
+ ".options()".
314
+ """
315
+ for k, v in options.items():
316
+ if k not in task_options:
317
+ raise ValueError(
318
+ f"Invalid option keyword {k} for remote functions. "
319
+ f"Valid ones are {list(task_options)}."
320
+ )
321
+ task_options[k].validate(k, v)
322
+ if in_options and "max_calls" in options:
323
+ raise ValueError("Setting 'max_calls' is not supported in '.options()'.")
324
+ _check_deprecate_placement_group(options)
325
+
326
+
327
+ def validate_actor_options(options: Dict[str, Any], in_options: bool):
328
+ """Options check for Ray actors.
329
+
330
+ Args:
331
+ options: Options for Ray actors.
332
+ in_options: If True, we are checking the options under the context of
333
+ ".options()".
334
+ """
335
+ for k, v in options.items():
336
+ if k not in actor_options:
337
+ raise ValueError(
338
+ f"Invalid option keyword {k} for actors. "
339
+ f"Valid ones are {list(actor_options)}."
340
+ )
341
+ actor_options[k].validate(k, v)
342
+
343
+ if in_options and "concurrency_groups" in options:
344
+ raise ValueError(
345
+ "Setting 'concurrency_groups' is not supported in '.options()'."
346
+ )
347
+
348
+ if options.get("get_if_exists") and not options.get("name"):
349
+ raise ValueError("The actor name must be specified to use `get_if_exists`.")
350
+
351
+ if "object_store_memory" in options:
352
+ warnings.warn(
353
+ "Setting 'object_store_memory'"
354
+ " for actors is deprecated since it doesn't actually"
355
+ " reserve the required object store memory."
356
+ f" Use object spilling that's enabled by default (https://docs.ray.io/en/{get_ray_doc_version()}/ray-core/objects/object-spilling.html) " # noqa: E501
357
+ "instead to bypass the object store memory size limitation.",
358
+ DeprecationWarning,
359
+ stacklevel=1,
360
+ )
361
+
362
+ _check_deprecate_placement_group(options)
363
+
364
+
365
+ def update_options(
366
+ original_options: Dict[str, Any], new_options: Dict[str, Any]
367
+ ) -> Dict[str, Any]:
368
+ """Update original options with new options and return.
369
+ The returned updated options contain a shallow copy of the original options.
370
+ """
371
+
372
+ updated_options = {**original_options, **new_options}
373
+ # Ensure we update each namespace in "_metadata" independently.
374
+ # "_metadata" is a dict like {namespace1: config1, namespace2: config2}
375
+ if (
376
+ original_options.get("_metadata") is not None
377
+ and new_options.get("_metadata") is not None
378
+ ):
379
+ # make a shallow copy to avoid messing up the metadata dict in
380
+ # the original options.
381
+ metadata = original_options["_metadata"].copy()
382
+ for namespace, config in new_options["_metadata"].items():
383
+ metadata[namespace] = {**metadata.get(namespace, {}), **config}
384
+
385
+ updated_options["_metadata"] = metadata
386
+
387
+ return updated_options
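
A short sketch of how update_options performs a shallow merge while merging each "_metadata" namespace independently rather than overwriting it wholesale. The option values and the namespace key below are arbitrary examples, not taken from this diff.

# Illustrative example for update_options; values are arbitrary.
from ray._private.ray_option_utils import update_options

original = {
    "num_cpus": 1,
    "_metadata": {"example.io/options": {"task_id": "a"}},
}
new = {
    "num_cpus": 2,
    "_metadata": {"example.io/options": {"checkpoint": True}},
}

merged = update_options(original, new)
assert merged["num_cpus"] == 2
# Both keys survive because each "_metadata" namespace is merged per key.
assert merged["_metadata"]["example.io/options"] == {
    "task_id": "a",
    "checkpoint": True,
}
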
.venv/lib/python3.11/site-packages/ray/_private/ray_perf.py ADDED
@@ -0,0 +1,328 @@
1
+ """This is the script for `ray microbenchmark`."""
2
+
3
+ import asyncio
4
+ import logging
5
+ from ray._private.ray_microbenchmark_helpers import timeit
6
+ from ray._private.ray_client_microbenchmark import main as client_microbenchmark_main
7
+ import numpy as np
8
+ import multiprocessing
9
+ import ray
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
+ @ray.remote(num_cpus=0)
15
+ class Actor:
16
+ def small_value(self):
17
+ return b"ok"
18
+
19
+ def small_value_arg(self, x):
20
+ return b"ok"
21
+
22
+ def small_value_batch(self, n):
23
+ ray.get([small_value.remote() for _ in range(n)])
24
+
25
+
26
+ @ray.remote
27
+ class AsyncActor:
28
+ async def small_value(self):
29
+ return b"ok"
30
+
31
+ async def small_value_with_arg(self, x):
32
+ return b"ok"
33
+
34
+ async def small_value_batch(self, n):
35
+ await asyncio.wait([small_value.remote() for _ in range(n)])
36
+
37
+
38
+ @ray.remote(num_cpus=0)
39
+ class Client:
40
+ def __init__(self, servers):
41
+ if not isinstance(servers, list):
42
+ servers = [servers]
43
+ self.servers = servers
44
+
45
+ def small_value_batch(self, n):
46
+ results = []
47
+ for s in self.servers:
48
+ results.extend([s.small_value.remote() for _ in range(n)])
49
+ ray.get(results)
50
+
51
+ def small_value_batch_arg(self, n):
52
+ x = ray.put(0)
53
+ results = []
54
+ for s in self.servers:
55
+ results.extend([s.small_value_arg.remote(x) for _ in range(n)])
56
+ ray.get(results)
57
+
58
+
59
+ @ray.remote
60
+ def small_value():
61
+ return b"ok"
62
+
63
+
64
+ @ray.remote
65
+ def small_value_batch(n):
66
+ submitted = [small_value.remote() for _ in range(n)]
67
+ ray.get(submitted)
68
+ return 0
69
+
70
+
71
+ @ray.remote
72
+ def create_object_containing_ref():
73
+ obj_refs = []
74
+ for _ in range(10000):
75
+ obj_refs.append(ray.put(1))
76
+ return obj_refs
77
+
78
+
79
+ def check_optimized_build():
80
+ if not ray._raylet.OPTIMIZED:
81
+ msg = (
82
+ "WARNING: Unoptimized build! "
83
+ "To benchmark an optimized build, try:\n"
84
+ "\tbazel build -c opt //:ray_pkg\n"
85
+ "You can also make this permanent by adding\n"
86
+ "\tbuild --compilation_mode=opt\n"
87
+ "to your user-wide ~/.bazelrc file. "
88
+ "(Do not add this to the project-level .bazelrc file.)"
89
+ )
90
+ logger.warning(msg)
91
+
92
+
93
+ def main(results=None):
94
+ results = results or []
95
+
96
+ check_optimized_build()
97
+
98
+ print("Tip: set TESTS_TO_RUN='pattern' to run a subset of benchmarks")
99
+
100
+ ray.init()
101
+
102
+ value = ray.put(0)
103
+
104
+ def get_small():
105
+ ray.get(value)
106
+
107
+ def put_small():
108
+ ray.put(0)
109
+
110
+ @ray.remote
111
+ def do_put_small():
112
+ for _ in range(100):
113
+ ray.put(0)
114
+
115
+ def put_multi_small():
116
+ ray.get([do_put_small.remote() for _ in range(10)])
117
+
118
+ arr = np.zeros(100 * 1024 * 1024, dtype=np.int64)
119
+
120
+ results += timeit("single client get calls (Plasma Store)", get_small)
121
+
122
+ results += timeit("single client put calls (Plasma Store)", put_small)
123
+
124
+ results += timeit("multi client put calls (Plasma Store)", put_multi_small, 1000)
125
+
126
+ def put_large():
127
+ ray.put(arr)
128
+
129
+ results += timeit("single client put gigabytes", put_large, 8 * 0.1)
130
+
131
+ def small_value_batch():
132
+ submitted = [small_value.remote() for _ in range(1000)]
133
+ ray.get(submitted)
134
+ return 0
135
+
136
+ results += timeit("single client tasks and get batch", small_value_batch)
137
+
138
+ @ray.remote
139
+ def do_put():
140
+ for _ in range(10):
141
+ ray.put(np.zeros(10 * 1024 * 1024, dtype=np.int64))
142
+
143
+ def put_multi():
144
+ ray.get([do_put.remote() for _ in range(10)])
145
+
146
+ results += timeit("multi client put gigabytes", put_multi, 10 * 8 * 0.1)
147
+
148
+ obj_containing_ref = create_object_containing_ref.remote()
149
+
150
+ def get_containing_object_ref():
151
+ ray.get(obj_containing_ref)
152
+
153
+ results += timeit(
154
+ "single client get object containing 10k refs", get_containing_object_ref
155
+ )
156
+
157
+ def wait_multiple_refs():
158
+ num_objs = 1000
159
+ not_ready = [small_value.remote() for _ in range(num_objs)]
160
+ # We only need to trigger the fetch_local once for each object,
161
+ # raylet will persist these fetch requests even after ray.wait returns.
162
+ # See https://github.com/ray-project/ray/issues/30375.
163
+ fetch_local = True
164
+ for _ in range(num_objs):
165
+ _ready, not_ready = ray.wait(not_ready, fetch_local=fetch_local)
166
+ if fetch_local:
167
+ fetch_local = False
168
+
169
+ results += timeit("single client wait 1k refs", wait_multiple_refs)
170
+
171
+ def small_task():
172
+ ray.get(small_value.remote())
173
+
174
+ results += timeit("single client tasks sync", small_task)
175
+
176
+ def small_task_async():
177
+ ray.get([small_value.remote() for _ in range(1000)])
178
+
179
+ results += timeit("single client tasks async", small_task_async, 1000)
180
+
181
+ n = 10000
182
+ m = 4
183
+ actors = [Actor.remote() for _ in range(m)]
184
+
185
+ def multi_task():
186
+ submitted = [a.small_value_batch.remote(n) for a in actors]
187
+ ray.get(submitted)
188
+
189
+ results += timeit("multi client tasks async", multi_task, n * m)
190
+
191
+ a = Actor.remote()
192
+
193
+ def actor_sync():
194
+ ray.get(a.small_value.remote())
195
+
196
+ results += timeit("1:1 actor calls sync", actor_sync)
197
+
198
+ a = Actor.remote()
199
+
200
+ def actor_async():
201
+ ray.get([a.small_value.remote() for _ in range(1000)])
202
+
203
+ results += timeit("1:1 actor calls async", actor_async, 1000)
204
+
205
+ a = Actor.options(max_concurrency=16).remote()
206
+
207
+ def actor_concurrent():
208
+ ray.get([a.small_value.remote() for _ in range(1000)])
209
+
210
+ results += timeit("1:1 actor calls concurrent", actor_concurrent, 1000)
211
+
212
+ n = 5000
213
+ n_cpu = multiprocessing.cpu_count() // 2
214
+ actors = [Actor._remote() for _ in range(n_cpu)]
215
+ client = Client.remote(actors)
216
+
217
+ def actor_async_direct():
218
+ ray.get(client.small_value_batch.remote(n))
219
+
220
+ results += timeit("1:n actor calls async", actor_async_direct, n * len(actors))
221
+
222
+ n_cpu = multiprocessing.cpu_count() // 2
223
+ a = [Actor.remote() for _ in range(n_cpu)]
224
+
225
+ @ray.remote
226
+ def work(actors):
227
+ ray.get([actors[i % n_cpu].small_value.remote() for i in range(n)])
228
+
229
+ def actor_multi2():
230
+ ray.get([work.remote(a) for _ in range(m)])
231
+
232
+ results += timeit("n:n actor calls async", actor_multi2, m * n)
233
+
234
+ n = 1000
235
+ actors = [Actor._remote() for _ in range(n_cpu)]
236
+ clients = [Client.remote(a) for a in actors]
237
+
238
+ def actor_multi2_direct_arg():
239
+ ray.get([c.small_value_batch_arg.remote(n) for c in clients])
240
+
241
+ results += timeit(
242
+ "n:n actor calls with arg async", actor_multi2_direct_arg, n * len(clients)
243
+ )
244
+
245
+ a = AsyncActor.remote()
246
+
247
+ def actor_sync():
248
+ ray.get(a.small_value.remote())
249
+
250
+ results += timeit("1:1 async-actor calls sync", actor_sync)
251
+
252
+ a = AsyncActor.remote()
253
+
254
+ def async_actor():
255
+ ray.get([a.small_value.remote() for _ in range(1000)])
256
+
257
+ results += timeit("1:1 async-actor calls async", async_actor, 1000)
258
+
259
+ a = AsyncActor.remote()
260
+
261
+ def async_actor():
262
+ ray.get([a.small_value_with_arg.remote(i) for i in range(1000)])
263
+
264
+ results += timeit("1:1 async-actor calls with args async", async_actor, 1000)
265
+
266
+ n = 5000
267
+ n_cpu = multiprocessing.cpu_count() // 2
268
+ actors = [AsyncActor.remote() for _ in range(n_cpu)]
269
+ client = Client.remote(actors)
270
+
271
+ def async_actor_async():
272
+ ray.get(client.small_value_batch.remote(n))
273
+
274
+ results += timeit("1:n async-actor calls async", async_actor_async, n * len(actors))
275
+
276
+ n = 5000
277
+ m = 4
278
+ n_cpu = multiprocessing.cpu_count() // 2
279
+ a = [AsyncActor.remote() for _ in range(n_cpu)]
280
+
281
+ @ray.remote
282
+ def async_actor_work(actors):
283
+ ray.get([actors[i % n_cpu].small_value.remote() for i in range(n)])
284
+
285
+ def async_actor_multi():
286
+ ray.get([async_actor_work.remote(a) for _ in range(m)])
287
+
288
+ results += timeit("n:n async-actor calls async", async_actor_multi, m * n)
289
+ ray.shutdown()
290
+
291
+ ############################
292
+ # End of channel perf tests.
293
+ ############################
294
+
295
+ NUM_PGS = 100
296
+ NUM_BUNDLES = 1
297
+ ray.init(resources={"custom": 100})
298
+
299
+ def placement_group_create_removal(num_pgs):
300
+ pgs = [
301
+ ray.util.placement_group(
302
+ bundles=[{"custom": 0.001} for _ in range(NUM_BUNDLES)]
303
+ )
304
+ for _ in range(num_pgs)
305
+ ]
306
+ [pg.wait(timeout_seconds=30) for pg in pgs]
307
+ # Include placement group removal here to clean up.
308
+ # If we don't clean up placement groups, overall performance
309
+ # degrades as more iterations run.
310
+ # Since the timeit helper calls this function repeatedly without any
311
+ # separate cleanup step, the removal has to happen inside it.
312
+ for pg in pgs:
313
+ ray.util.remove_placement_group(pg)
314
+
315
+ results += timeit(
316
+ "placement group create/removal",
317
+ lambda: placement_group_create_removal(NUM_PGS),
318
+ NUM_PGS,
319
+ )
320
+ ray.shutdown()
321
+
322
+ client_microbenchmark_main(results)
323
+
324
+ return results
325
+
326
+
327
+ if __name__ == "__main__":
328
+ main()
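
A minimal usage sketch (not part of the commit above), assuming the file is importable as ray._private.ray_perf and a local Ray installation is available; main() starts Ray, runs every microbenchmark, and returns the accumulated timeit results:

# Sketch only: run the benchmark suite programmatically and print the results.
from ray._private.ray_perf import main

results = main()  # runs all benchmarks against a locally started Ray instance
for entry in results:
    print(entry)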
.venv/lib/python3.11/site-packages/ray/_private/ray_process_reaper.py ADDED
@@ -0,0 +1,60 @@
1
+ import atexit
2
+ import os
3
+ import signal
4
+ import sys
5
+ import time
6
+
7
+ """
8
+ This is a lightweight "reaper" process used to ensure that ray processes are
9
+ cleaned up properly when the main ray process dies unexpectedly (e.g.,
10
+ segfaults or gets SIGKILLed). Note that processes may not be cleaned up
11
+ properly if this process is SIGTERMed or SIGKILLed.
12
+
13
+ It detects that its parent has died by reading from stdin, which must be
14
+ inherited from the parent process so that the OS will deliver an EOF if the
15
+ parent dies. When this happens, the reaper process kills the rest of its
16
+ process group (first attempting graceful shutdown with SIGTERM, then escalating
17
+ to SIGKILL).
18
+ """
19
+
20
+ SIGTERM_GRACE_PERIOD_SECONDS = 1
21
+
22
+
23
+ def reap_process_group(*args):
24
+ def sigterm_handler(*args):
25
+ # Give a one-second grace period for other processes to clean up.
26
+ time.sleep(SIGTERM_GRACE_PERIOD_SECONDS)
27
+ # SIGKILL the pgroup (including ourselves) as a last-resort.
28
+ if sys.platform == "win32":
29
+ atexit.unregister(sigterm_handler)
30
+ os.kill(0, signal.CTRL_BREAK_EVENT)
31
+ else:
32
+ os.killpg(0, signal.SIGKILL)
33
+
34
+ # Set a SIGTERM handler to handle SIGTERMing ourselves with the group.
35
+ if sys.platform == "win32":
36
+ atexit.register(sigterm_handler)
37
+ else:
38
+ signal.signal(signal.SIGTERM, sigterm_handler)
39
+
40
+ # Our parent must have died, SIGTERM the group (including ourselves).
41
+ if sys.platform == "win32":
42
+ os.kill(0, signal.CTRL_C_EVENT)
43
+ else:
44
+ os.killpg(0, signal.SIGTERM)
45
+
46
+
47
+ def main():
48
+ # Read from stdin forever. Because stdin is a file descriptor
49
+ # inherited from our parent process, we will get an EOF if the parent
50
+ # dies, which is signaled by an empty return from read().
51
+ # We intentionally don't set any signal handlers here, so a SIGTERM from
52
+ # the parent can be used to kill this process gracefully without it killing
53
+ # the rest of the process group.
54
+ while len(sys.stdin.read()) != 0:
55
+ pass
56
+ reap_process_group()
57
+
58
+
59
+ if __name__ == "__main__":
60
+ main()
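
A minimal sketch (not part of the commit above) of the stdin-EOF pattern the module docstring describes: the parent binds a pipe to the child's stdin, so the child's blocking read returns an empty string (EOF) once the write end closes, e.g. because the parent process died. The child command here is a stand-in, not the actual reaper invocation.

import subprocess
import sys

# The child blocks on stdin; it unblocks only when the parent-side write end
# of the pipe is closed (for example, because the parent process exited).
child = subprocess.Popen(
    [sys.executable, "-c", "import sys; sys.stdin.read(); print('parent gone')"],
    stdin=subprocess.PIPE,
)
# Closing our end simulates the parent dying; the child then proceeds.
child.stdin.close()
child.wait()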
.venv/lib/python3.11/site-packages/ray/_private/resource_spec.py ADDED
@@ -0,0 +1,317 @@
1
+ import logging
2
+ import sys
3
+ from collections import namedtuple
4
+ from typing import Optional
5
+
6
+ import ray
7
+ import ray._private.ray_constants as ray_constants
8
+
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+ # Prefix for the node id resource that is automatically added to each node.
13
+ # For example, a node may have id `node:172.23.42.1`.
14
+ NODE_ID_PREFIX = "node:"
15
+ # The system resource that the head node has.
16
+ HEAD_NODE_RESOURCE_NAME = NODE_ID_PREFIX + "__internal_head__"
17
+
18
+
19
+ class ResourceSpec(
20
+ namedtuple(
21
+ "ResourceSpec",
22
+ [
23
+ "num_cpus",
24
+ "num_gpus",
25
+ "memory",
26
+ "object_store_memory",
27
+ "resources",
28
+ "redis_max_memory",
29
+ ],
30
+ )
31
+ ):
32
+ """Represents the resource configuration passed to a raylet.
33
+
34
+ All fields can be None. Before starting services, resolve() should be
35
+ called to return a ResourceSpec with unknown values filled in with
36
+ defaults based on the local machine specifications.
37
+
38
+ Attributes:
39
+ num_cpus: The CPUs allocated for this raylet.
40
+ num_gpus: The GPUs allocated for this raylet.
41
+ memory: The memory allocated for this raylet.
42
+ object_store_memory: The object store memory allocated for this raylet.
43
+ Note that when calling to_resource_dict(), this will be scaled down
44
+ by 30% to account for the global plasma LRU reserve.
45
+ resources: The custom resources allocated for this raylet.
46
+ redis_max_memory: The max amount of memory (in bytes) to allow each
47
+ redis shard to use. Once the limit is exceeded, redis will start
48
+ LRU eviction of entries. This only applies to the sharded redis
49
+ tables (task, object, and profile tables). By default, this is
50
+ capped at 10GB but can be set higher.
51
+ """
52
+
53
+ def __new__(
54
+ cls,
55
+ num_cpus=None,
56
+ num_gpus=None,
57
+ memory=None,
58
+ object_store_memory=None,
59
+ resources=None,
60
+ redis_max_memory=None,
61
+ ):
62
+ return super(ResourceSpec, cls).__new__(
63
+ cls,
64
+ num_cpus,
65
+ num_gpus,
66
+ memory,
67
+ object_store_memory,
68
+ resources,
69
+ redis_max_memory,
70
+ )
71
+
72
+ def resolved(self):
73
+ """Returns if this ResourceSpec has default values filled out."""
74
+ for v in self._asdict().values():
75
+ if v is None:
76
+ return False
77
+ return True
78
+
79
+ def to_resource_dict(self):
80
+ """Returns a dict suitable to pass to raylet initialization.
81
+
82
+ This renames num_cpus / num_gpus to "CPU" / "GPU",
83
+ translates memory from bytes into 100MB memory units, and checks types.
84
+ """
85
+ assert self.resolved()
86
+
87
+ resources = dict(
88
+ self.resources,
89
+ CPU=self.num_cpus,
90
+ GPU=self.num_gpus,
91
+ memory=int(self.memory),
92
+ object_store_memory=int(self.object_store_memory),
93
+ )
94
+
95
+ resources = {
96
+ resource_label: resource_quantity
97
+ for resource_label, resource_quantity in resources.items()
98
+ if resource_quantity != 0
99
+ }
100
+
101
+ # Check types.
102
+ for resource_label, resource_quantity in resources.items():
103
+ assert isinstance(resource_quantity, int) or isinstance(
104
+ resource_quantity, float
105
+ ), (
106
+ f"{resource_label} ({type(resource_quantity)}): " f"{resource_quantity}"
107
+ )
108
+ if (
109
+ isinstance(resource_quantity, float)
110
+ and not resource_quantity.is_integer()
111
+ ):
112
+ raise ValueError(
113
+ "Resource quantities must all be whole numbers. "
114
+ "Violated by resource '{}' in {}.".format(resource_label, resources)
115
+ )
116
+ if resource_quantity < 0:
117
+ raise ValueError(
118
+ "Resource quantities must be nonnegative. "
119
+ "Violated by resource '{}' in {}.".format(resource_label, resources)
120
+ )
121
+ if resource_quantity > ray_constants.MAX_RESOURCE_QUANTITY:
122
+ raise ValueError(
123
+ "Resource quantities must be at most {}. "
124
+ "Violated by resource '{}' in {}.".format(
125
+ ray_constants.MAX_RESOURCE_QUANTITY, resource_label, resources
126
+ )
127
+ )
128
+
129
+ return resources
130
+
131
+ def resolve(self, is_head: bool, node_ip_address: Optional[str] = None):
132
+ """Returns a copy with values filled out with system defaults.
133
+
134
+ Args:
135
+ is_head: Whether this is the head node.
136
+ node_ip_address: The IP address of the node that we are on.
137
+ This is used to automatically create a node id resource.
138
+ """
139
+
140
+ resources = (self.resources or {}).copy()
141
+ assert "CPU" not in resources, resources
142
+ assert "GPU" not in resources, resources
143
+ assert "memory" not in resources, resources
144
+ assert "object_store_memory" not in resources, resources
145
+
146
+ if node_ip_address is None:
147
+ node_ip_address = ray.util.get_node_ip_address()
148
+
149
+ # Automatically create a node id resource on each node. This is
150
+ # queryable with ray._private.state.node_ids() and
151
+ # ray._private.state.current_node_id().
152
+ resources[NODE_ID_PREFIX + node_ip_address] = 1.0
153
+
154
+ # Automatically create a head node resource.
155
+ if HEAD_NODE_RESOURCE_NAME in resources:
156
+ raise ValueError(
157
+ f"{HEAD_NODE_RESOURCE_NAME}"
158
+ " is a reserved resource name, use another name instead."
159
+ )
160
+ if is_head:
161
+ resources[HEAD_NODE_RESOURCE_NAME] = 1.0
162
+
163
+ num_cpus = self.num_cpus
164
+ if num_cpus is None:
165
+ num_cpus = ray._private.utils.get_num_cpus()
166
+
167
+ num_gpus = 0
168
+ for (
169
+ accelerator_resource_name
170
+ ) in ray._private.accelerators.get_all_accelerator_resource_names():
171
+ accelerator_manager = (
172
+ ray._private.accelerators.get_accelerator_manager_for_resource(
173
+ accelerator_resource_name
174
+ )
175
+ )
176
+ num_accelerators = None
177
+ if accelerator_resource_name == "GPU":
178
+ num_accelerators = self.num_gpus
179
+ else:
180
+ num_accelerators = resources.get(accelerator_resource_name, None)
181
+ visible_accelerator_ids = (
182
+ accelerator_manager.get_current_process_visible_accelerator_ids()
183
+ )
184
+ # Check that the number of accelerators that the raylet wants doesn't
185
+ # exceed the amount allowed by visible accelerator ids.
186
+ if (
187
+ num_accelerators is not None
188
+ and visible_accelerator_ids is not None
189
+ and num_accelerators > len(visible_accelerator_ids)
190
+ ):
191
+ raise ValueError(
192
+ f"Attempting to start raylet with {num_accelerators} "
193
+ f"{accelerator_resource_name}, "
194
+ f"but {accelerator_manager.get_visible_accelerator_ids_env_var()} "
195
+ f"contains {visible_accelerator_ids}."
196
+ )
197
+ if num_accelerators is None:
198
+ # Try to automatically detect the number of accelerators.
199
+ num_accelerators = (
200
+ accelerator_manager.get_current_node_num_accelerators()
201
+ )
202
+ # Don't use more accelerators than allowed by visible accelerator ids.
203
+ if visible_accelerator_ids is not None:
204
+ num_accelerators = min(
205
+ num_accelerators, len(visible_accelerator_ids)
206
+ )
207
+
208
+ if num_accelerators:
209
+ if accelerator_resource_name == "GPU":
210
+ num_gpus = num_accelerators
211
+ else:
212
+ resources[accelerator_resource_name] = num_accelerators
213
+
214
+ accelerator_type = (
215
+ accelerator_manager.get_current_node_accelerator_type()
216
+ )
217
+ if accelerator_type:
218
+ resources[
219
+ f"{ray_constants.RESOURCE_CONSTRAINT_PREFIX}{accelerator_type}"
220
+ ] = 1
221
+
222
+ from ray._private.usage import usage_lib
223
+
224
+ usage_lib.record_hardware_usage(accelerator_type)
225
+ additional_resources = (
226
+ accelerator_manager.get_current_node_additional_resources()
227
+ )
228
+ if additional_resources:
229
+ resources.update(additional_resources)
230
+ # Choose a default object store size.
231
+ system_memory = ray._private.utils.get_system_memory()
232
+ avail_memory = ray._private.utils.estimate_available_memory()
233
+ object_store_memory = self.object_store_memory
234
+ if object_store_memory is None:
235
+ object_store_memory = int(
236
+ avail_memory * ray_constants.DEFAULT_OBJECT_STORE_MEMORY_PROPORTION
237
+ )
238
+
239
+ # Set the object_store_memory size to 2GB on Mac
240
+ # to avoid degraded performance.
241
+ # (https://github.com/ray-project/ray/issues/20388)
242
+ if sys.platform == "darwin":
243
+ object_store_memory = min(
244
+ object_store_memory, ray_constants.MAC_DEGRADED_PERF_MMAP_SIZE_LIMIT
245
+ )
246
+
247
+ object_store_memory_cap = (
248
+ ray_constants.DEFAULT_OBJECT_STORE_MAX_MEMORY_BYTES
249
+ )
250
+
251
+ # Cap by shm size by default to avoid low performance, but don't
252
+ # go lower than REQUIRE_SHM_SIZE_THRESHOLD.
253
+ if sys.platform == "linux" or sys.platform == "linux2":
254
+ # Multiply by 0.95 to give a bit of wiggle-room.
255
+ # https://github.com/ray-project/ray/pull/23034/files
256
+ shm_avail = ray._private.utils.get_shared_memory_bytes() * 0.95
257
+ shm_cap = max(ray_constants.REQUIRE_SHM_SIZE_THRESHOLD, shm_avail)
258
+
259
+ object_store_memory_cap = min(object_store_memory_cap, shm_cap)
260
+
261
+ # Cap memory to avoid memory waste and perf issues on large nodes
262
+ if (
263
+ object_store_memory_cap
264
+ and object_store_memory > object_store_memory_cap
265
+ ):
266
+ logger.debug(
267
+ "Warning: Capping object memory store to {}GB. ".format(
268
+ object_store_memory_cap // 1e9
269
+ )
270
+ + "To increase this further, specify `object_store_memory` "
271
+ "when calling ray.init() or ray start."
272
+ )
273
+ object_store_memory = object_store_memory_cap
274
+
275
+ redis_max_memory = self.redis_max_memory
276
+ if redis_max_memory is None:
277
+ redis_max_memory = min(
278
+ ray_constants.DEFAULT_REDIS_MAX_MEMORY_BYTES,
279
+ max(int(avail_memory * 0.1), ray_constants.REDIS_MINIMUM_MEMORY_BYTES),
280
+ )
281
+ if redis_max_memory < ray_constants.REDIS_MINIMUM_MEMORY_BYTES:
282
+ raise ValueError(
283
+ "Attempting to cap Redis memory usage at {} bytes, "
284
+ "but the minimum allowed is {} bytes.".format(
285
+ redis_max_memory, ray_constants.REDIS_MINIMUM_MEMORY_BYTES
286
+ )
287
+ )
288
+
289
+ memory = self.memory
290
+ if memory is None:
291
+ memory = (
292
+ avail_memory
293
+ - object_store_memory
294
+ - (redis_max_memory if is_head else 0)
295
+ )
296
+ if memory < 100e6 and memory < 0.05 * system_memory:
297
+ raise ValueError(
298
+ "After taking into account object store and redis memory "
299
+ "usage, the amount of memory on this node available for "
300
+ "tasks and actors ({} GB) is less than {}% of total. "
301
+ "You can adjust these settings with "
302
+ "ray.init(memory=<bytes>, "
303
+ "object_store_memory=<bytes>).".format(
304
+ round(memory / 1e9, 2), int(100 * (memory / system_memory))
305
+ )
306
+ )
307
+
308
+ spec = ResourceSpec(
309
+ num_cpus,
310
+ num_gpus,
311
+ memory,
312
+ object_store_memory,
313
+ resources,
314
+ redis_max_memory,
315
+ )
316
+ assert spec.resolved()
317
+ return spec
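
An illustrative sketch (not part of the commit above) of how to_resource_dict() renames and filters fields; the numbers are arbitrary and it assumes the module is importable as ray._private.resource_spec:

from ray._private.resource_spec import ResourceSpec

spec = ResourceSpec(
    num_cpus=4,
    num_gpus=0,
    memory=2 * 10**9,
    object_store_memory=10**9,
    resources={"custom": 1},
    redis_max_memory=10**8,
)
# GPU is dropped because zero-valued resources are filtered out.
print(spec.to_resource_dict())
# {'custom': 1, 'CPU': 4, 'memory': 2000000000, 'object_store_memory': 1000000000}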
.venv/lib/python3.11/site-packages/ray/_private/runtime_env/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ # List of files to exclude from the Ray directory when using runtime_env for
2
+ # Ray development. These are not necessary in the Ray workers.
3
+ RAY_WORKER_DEV_EXCLUDES = ["raylet", "gcs_server", "cpp/", "tests/", "core/src"]
.venv/lib/python3.11/site-packages/ray/_private/runtime_env/_clonevirtualenv.py ADDED
@@ -0,0 +1,334 @@
1
+ #!/usr/bin/env python
2
+
3
+ from __future__ import with_statement
4
+
5
+ import logging
6
+ import optparse
7
+ import os
8
+ import os.path
9
+ import re
10
+ import shutil
11
+ import subprocess
12
+ import sys
13
+ import itertools
14
+
15
+ __version__ = "0.5.7"
16
+
17
+
18
+ logger = logging.getLogger()
19
+
20
+
21
+ env_bin_dir = "bin"
22
+ if sys.platform == "win32":
23
+ env_bin_dir = "Scripts"
24
+ _WIN32 = True
25
+ else:
26
+ _WIN32 = False
27
+
28
+
29
+ class UserError(Exception):
30
+ pass
31
+
32
+
33
+ def _dirmatch(path, matchwith):
34
+ """Check if path is within matchwith's tree.
35
+ >>> _dirmatch('/home/foo/bar', '/home/foo/bar')
36
+ True
37
+ >>> _dirmatch('/home/foo/bar/', '/home/foo/bar')
38
+ True
39
+ >>> _dirmatch('/home/foo/bar/etc', '/home/foo/bar')
40
+ True
41
+ >>> _dirmatch('/home/foo/bar2', '/home/foo/bar')
42
+ False
43
+ >>> _dirmatch('/home/foo/bar2/etc', '/home/foo/bar')
44
+ False
45
+ """
46
+ matchlen = len(matchwith)
47
+ if path.startswith(matchwith) and path[matchlen : matchlen + 1] in [os.sep, ""]:
48
+ return True
49
+ return False
50
+
51
+
52
+ def _virtualenv_sys(venv_path):
53
+ """obtain version and path info from a virtualenv."""
54
+ executable = os.path.join(venv_path, env_bin_dir, "python")
55
+ if _WIN32:
56
+ env = os.environ.copy()
57
+ else:
58
+ env = {}
59
+ # Must use "executable" as the first argument rather than as the
60
+ # keyword argument "executable" to get correct value from sys.path
61
+ p = subprocess.Popen(
62
+ [
63
+ executable,
64
+ "-c",
65
+ "import sys;"
66
+ 'print ("%d.%d" % (sys.version_info.major, sys.version_info.minor));'
67
+ 'print ("\\n".join(sys.path));',
68
+ ],
69
+ env=env,
70
+ stdout=subprocess.PIPE,
71
+ )
72
+ stdout, err = p.communicate()
73
+ assert not p.returncode and stdout
74
+ lines = stdout.decode("utf-8").splitlines()
75
+ return lines[0], list(filter(bool, lines[1:]))
76
+
77
+
78
+ def clone_virtualenv(src_dir, dst_dir):
79
+ if not os.path.exists(src_dir):
80
+ raise UserError("src dir %r does not exist" % src_dir)
81
+ if os.path.exists(dst_dir):
82
+ raise UserError("dest dir %r exists" % dst_dir)
83
+ # sys_path = _virtualenv_syspath(src_dir)
84
+ logger.info("cloning virtualenv '%s' => '%s'..." % (src_dir, dst_dir))
85
+ shutil.copytree(
86
+ src_dir, dst_dir, symlinks=True, ignore=shutil.ignore_patterns("*.pyc")
87
+ )
88
+ version, sys_path = _virtualenv_sys(dst_dir)
89
+ logger.info("fixing scripts in bin...")
90
+ fixup_scripts(src_dir, dst_dir, version)
91
+
92
+ has_old = lambda s: any(i for i in s if _dirmatch(i, src_dir)) # noqa: E731
93
+
94
+ if has_old(sys_path):
95
+ # only need to fix stuff in sys.path if we have old
96
+ # paths in the sys.path of new python env. right?
97
+ logger.info("fixing paths in sys.path...")
98
+ fixup_syspath_items(sys_path, src_dir, dst_dir)
99
+ v_sys = _virtualenv_sys(dst_dir)
100
+ remaining = has_old(v_sys[1])
101
+ assert not remaining, v_sys
102
+ fix_symlink_if_necessary(src_dir, dst_dir)
103
+
104
+
105
+ def fix_symlink_if_necessary(src_dir, dst_dir):
106
+ # sometimes the source virtual environment has symlinks that point to itself
107
+ # one example is $OLD_VIRTUAL_ENV/local/lib points to $OLD_VIRTUAL_ENV/lib
108
+ # this function makes sure
109
+ # $NEW_VIRTUAL_ENV/local/lib will point to $NEW_VIRTUAL_ENV/lib
110
+ # usually this goes unnoticed unless one tries to upgrade a package though pip,
111
+ # so this bug is hard to find.
112
+ logger.info("scanning for internal symlinks that point to the original virtual env")
113
+ for dirpath, dirnames, filenames in os.walk(dst_dir):
114
+ for a_file in itertools.chain(filenames, dirnames):
115
+ full_file_path = os.path.join(dirpath, a_file)
116
+ if os.path.islink(full_file_path):
117
+ target = os.path.realpath(full_file_path)
118
+ if target.startswith(src_dir):
119
+ new_target = target.replace(src_dir, dst_dir)
120
+ logger.debug("fixing symlink in %s" % (full_file_path,))
121
+ os.remove(full_file_path)
122
+ os.symlink(new_target, full_file_path)
123
+
124
+
125
+ def fixup_scripts(old_dir, new_dir, version, rewrite_env_python=False):
126
+ bin_dir = os.path.join(new_dir, env_bin_dir)
127
+ root, dirs, files = next(os.walk(bin_dir))
128
+ pybinre = re.compile(r"pythonw?([0-9]+(\.[0-9]+(\.[0-9]+)?)?)?$")
129
+ for file_ in files:
130
+ filename = os.path.join(root, file_)
131
+ if file_ in ["python", "python%s" % version, "activate_this.py"]:
132
+ continue
133
+ elif file_.startswith("python") and pybinre.match(file_):
134
+ # ignore other possible python binaries
135
+ continue
136
+ elif file_.endswith(".pyc"):
137
+ # ignore compiled files
138
+ continue
139
+ elif file_ == "activate" or file_.startswith("activate."):
140
+ fixup_activate(os.path.join(root, file_), old_dir, new_dir)
141
+ elif os.path.islink(filename):
142
+ fixup_link(filename, old_dir, new_dir)
143
+ elif os.path.isfile(filename):
144
+ fixup_script_(
145
+ root,
146
+ file_,
147
+ old_dir,
148
+ new_dir,
149
+ version,
150
+ rewrite_env_python=rewrite_env_python,
151
+ )
152
+
153
+
154
+ def fixup_script_(root, file_, old_dir, new_dir, version, rewrite_env_python=False):
155
+ old_shebang = "#!%s/bin/python" % os.path.normcase(os.path.abspath(old_dir))
156
+ new_shebang = "#!%s/bin/python" % os.path.normcase(os.path.abspath(new_dir))
157
+ env_shebang = "#!/usr/bin/env python"
158
+
159
+ filename = os.path.join(root, file_)
160
+ with open(filename, "rb") as f:
161
+ if f.read(2) != b"#!":
162
+ # no shebang
163
+ return
164
+ f.seek(0)
165
+ lines = f.readlines()
166
+
167
+ if not lines:
168
+ # warn: empty script
169
+ return
170
+
171
+ def rewrite_shebang(version=None):
172
+ logger.debug("fixing %s" % filename)
173
+ shebang = new_shebang
174
+ if version:
175
+ shebang = shebang + version
176
+ shebang = (shebang + "\n").encode("utf-8")
177
+ with open(filename, "wb") as f:
178
+ f.write(shebang)
179
+ f.writelines(lines[1:])
180
+
181
+ try:
182
+ bang = lines[0].decode("utf-8").strip()
183
+ except UnicodeDecodeError:
184
+ # binary file
185
+ return
186
+
187
+ # This takes care of the scheme in which shebang is of type
188
+ # '#!/venv/bin/python3' while the version of system python
189
+ # is of type 3.x e.g. 3.5.
190
+ short_version = bang[len(old_shebang) :]
191
+
192
+ if not bang.startswith("#!"):
193
+ return
194
+ elif bang == old_shebang:
195
+ rewrite_shebang()
196
+ elif bang.startswith(old_shebang) and bang[len(old_shebang) :] == version:
197
+ rewrite_shebang(version)
198
+ elif (
199
+ bang.startswith(old_shebang)
200
+ and short_version
201
+ and bang[len(old_shebang) :] == short_version
202
+ ):
203
+ rewrite_shebang(short_version)
204
+ elif rewrite_env_python and bang.startswith(env_shebang):
205
+ if bang == env_shebang:
206
+ rewrite_shebang()
207
+ elif bang[len(env_shebang) :] == version:
208
+ rewrite_shebang(version)
209
+ else:
210
+ # can't do anything
211
+ return
212
+
213
+
214
+ def fixup_activate(filename, old_dir, new_dir):
215
+ logger.debug("fixing %s" % filename)
216
+ with open(filename, "rb") as f:
217
+ data = f.read().decode("utf-8")
218
+
219
+ data = data.replace(old_dir, new_dir)
220
+ with open(filename, "wb") as f:
221
+ f.write(data.encode("utf-8"))
222
+
223
+
224
+ def fixup_link(filename, old_dir, new_dir, target=None):
225
+ logger.debug("fixing %s" % filename)
226
+ if target is None:
227
+ target = os.readlink(filename)
228
+
229
+ origdir = os.path.dirname(os.path.abspath(filename)).replace(new_dir, old_dir)
230
+ if not os.path.isabs(target):
231
+ target = os.path.abspath(os.path.join(origdir, target))
232
+ rellink = True
233
+ else:
234
+ rellink = False
235
+
236
+ if _dirmatch(target, old_dir):
237
+ if rellink:
238
+ # keep relative links, but don't keep original in case it
239
+ # traversed up out of, then back into the venv.
240
+ # so, recreate a relative link from absolute.
241
+ target = target[len(origdir) :].lstrip(os.sep)
242
+ else:
243
+ target = target.replace(old_dir, new_dir, 1)
244
+
245
+ # else: links outside the venv, replaced with absolute path to target.
246
+ _replace_symlink(filename, target)
247
+
248
+
249
+ def _replace_symlink(filename, newtarget):
250
+ tmpfn = "%s.new" % filename
251
+ os.symlink(newtarget, tmpfn)
252
+ os.rename(tmpfn, filename)
253
+
254
+
255
+ def fixup_syspath_items(syspath, old_dir, new_dir):
256
+ for path in syspath:
257
+ if not os.path.isdir(path):
258
+ continue
259
+ path = os.path.normcase(os.path.abspath(path))
260
+ if _dirmatch(path, old_dir):
261
+ path = path.replace(old_dir, new_dir, 1)
262
+ if not os.path.exists(path):
263
+ continue
264
+ elif not _dirmatch(path, new_dir):
265
+ continue
266
+ root, dirs, files = next(os.walk(path))
267
+ for file_ in files:
268
+ filename = os.path.join(root, file_)
269
+ if filename.endswith(".pth"):
270
+ fixup_pth_file(filename, old_dir, new_dir)
271
+ elif filename.endswith(".egg-link"):
272
+ fixup_egglink_file(filename, old_dir, new_dir)
273
+
274
+
275
+ def fixup_pth_file(filename, old_dir, new_dir):
276
+ logger.debug("fixup_pth_file %s" % filename)
277
+
278
+ with open(filename, "r") as f:
279
+ lines = f.readlines()
280
+
281
+ has_change = False
282
+
283
+ for num, line in enumerate(lines):
284
+ line = (line.decode("utf-8") if hasattr(line, "decode") else line).strip()
285
+
286
+ if not line or line.startswith("#") or line.startswith("import "):
287
+ continue
288
+ elif _dirmatch(line, old_dir):
289
+ lines[num] = line.replace(old_dir, new_dir, 1)
290
+ has_change = True
291
+
292
+ if has_change:
293
+ with open(filename, "w") as f:
294
+ payload = os.linesep.join([line.strip() for line in lines]) + os.linesep
295
+ f.write(payload)
296
+
297
+
298
+ def fixup_egglink_file(filename, old_dir, new_dir):
299
+ logger.debug("fixing %s" % filename)
300
+ with open(filename, "rb") as f:
301
+ link = f.read().decode("utf-8").strip()
302
+ if _dirmatch(link, old_dir):
303
+ link = link.replace(old_dir, new_dir, 1)
304
+ with open(filename, "wb") as f:
305
+ link = (link + "\n").encode("utf-8")
306
+ f.write(link)
307
+
308
+
309
+ def main():
310
+ parser = optparse.OptionParser(
311
+ "usage: %prog [options] /path/to/existing/venv /path/to/cloned/venv"
312
+ )
313
+ parser.add_option(
314
+ "-v", action="count", dest="verbose", default=False, help="verbosity"
315
+ )
316
+ options, args = parser.parse_args()
317
+ try:
318
+ old_dir, new_dir = args
319
+ except ValueError:
320
+ print("virtualenv-clone %s" % (__version__,))
321
+ parser.error("not enough arguments given.")
322
+ old_dir = os.path.realpath(old_dir)
323
+ new_dir = os.path.realpath(new_dir)
324
+ loglevel = (logging.WARNING, logging.INFO, logging.DEBUG)[min(2, options.verbose)]
325
+ logging.basicConfig(level=loglevel, format="%(message)s")
326
+ try:
327
+ clone_virtualenv(old_dir, new_dir)
328
+ except UserError:
329
+ e = sys.exc_info()[1]
330
+ parser.error(str(e))
331
+
332
+
333
+ if __name__ == "__main__":
334
+ main()
.venv/lib/python3.11/site-packages/ray/_private/runtime_env/conda.py ADDED
@@ -0,0 +1,407 @@
1
+ import hashlib
2
+ import json
3
+ import logging
4
+ import os
5
+ import platform
6
+ import runpy
7
+ import shutil
8
+ import subprocess
9
+ import sys
10
+ from pathlib import Path
11
+ from typing import Any, Dict, List, Optional
12
+
13
+ import yaml
14
+ from filelock import FileLock
15
+
16
+ import ray
17
+ from ray._private.runtime_env.conda_utils import (
18
+ create_conda_env_if_needed,
19
+ delete_conda_env,
20
+ get_conda_activate_commands,
21
+ get_conda_info_json,
22
+ get_conda_envs,
23
+ )
24
+ from ray._private.runtime_env.context import RuntimeEnvContext
25
+ from ray._private.runtime_env.packaging import Protocol, parse_uri
26
+ from ray._private.runtime_env.plugin import RuntimeEnvPlugin
27
+ from ray._private.runtime_env.validation import parse_and_validate_conda
28
+ from ray._private.utils import (
29
+ get_directory_size_bytes,
30
+ get_master_wheel_url,
31
+ get_or_create_event_loop,
32
+ get_release_wheel_url,
33
+ get_wheel_filename,
34
+ try_to_create_directory,
35
+ )
36
+
37
+ default_logger = logging.getLogger(__name__)
38
+
39
+ _WIN32 = os.name == "nt"
40
+
41
+
42
+ def _resolve_current_ray_path() -> str:
43
+ # When ray is built from source with pip install -e,
44
+ # ray.__file__ returns .../python/ray/__init__.py and this function returns
45
+ # ".../python".
46
+ # When ray is installed from a prebuilt binary, ray.__file__ returns
47
+ # .../site-packages/ray/__init__.py and this function returns
48
+ # ".../site-packages".
49
+ return os.path.split(os.path.split(ray.__file__)[0])[0]
50
+
51
+
52
+ def _get_ray_setup_spec():
53
+ """Find the Ray setup_spec from the currently running Ray.
54
+
55
+ This function works even when Ray is built from source with pip install -e.
56
+ """
57
+ ray_source_python_path = _resolve_current_ray_path()
58
+ setup_py_path = os.path.join(ray_source_python_path, "setup.py")
59
+ return runpy.run_path(setup_py_path)["setup_spec"]
60
+
61
+
62
+ def _resolve_install_from_source_ray_dependencies():
63
+ """Find the Ray dependencies when Ray is installed from source."""
64
+ deps = (
65
+ _get_ray_setup_spec().install_requires + _get_ray_setup_spec().extras["default"]
66
+ )
67
+ # Remove duplicates
68
+ return list(set(deps))
69
+
70
+
71
+ def _inject_ray_to_conda_site(
72
+ conda_path, logger: Optional[logging.Logger] = default_logger
73
+ ):
74
+ """Write the current Ray site package directory to a new site"""
75
+ if _WIN32:
76
+ python_binary = os.path.join(conda_path, "python")
77
+ else:
78
+ python_binary = os.path.join(conda_path, "bin/python")
79
+ site_packages_path = (
80
+ subprocess.check_output(
81
+ [
82
+ python_binary,
83
+ "-c",
84
+ "import sysconfig; print(sysconfig.get_paths()['purelib'])",
85
+ ]
86
+ )
87
+ .decode()
88
+ .strip()
89
+ )
90
+
91
+ ray_path = _resolve_current_ray_path()
92
+ logger.warning(
93
+ f"Injecting {ray_path} to environment site-packages {site_packages_path} "
94
+ "because _inject_current_ray flag is on."
95
+ )
96
+
97
+ maybe_ray_dir = os.path.join(site_packages_path, "ray")
98
+ if os.path.isdir(maybe_ray_dir):
99
+ logger.warning(f"Replacing existing ray installation with {ray_path}")
100
+ shutil.rmtree(maybe_ray_dir)
101
+
102
+ # See usage of *.pth file at
103
+ # https://docs.python.org/3/library/site.html
104
+ with open(os.path.join(site_packages_path, "ray_shared.pth"), "w") as f:
105
+ f.write(ray_path)
106
+
107
+
108
+ def _current_py_version():
109
+ return ".".join(map(str, sys.version_info[:3])) # like 3.6.10
110
+
111
+
112
+ def _is_m1_mac():
113
+ return sys.platform == "darwin" and platform.machine() == "arm64"
114
+
115
+
116
+ def current_ray_pip_specifier(
117
+ logger: Optional[logging.Logger] = default_logger,
118
+ ) -> Optional[str]:
119
+ """The pip requirement specifier for the running version of Ray.
120
+
121
+ Returns:
122
+ A string which can be passed to `pip install` to install the
123
+ currently running Ray version, or None if running on a version
124
+ built from source locally (likely if you are developing Ray).
125
+
126
+ Examples:
127
+ Returns "https://s3-us-west-2.amazonaws.com/ray-wheels/[..].whl"
128
+ if running a stable release, a nightly or a specific commit
129
+ """
130
+ if os.environ.get("RAY_CI_POST_WHEEL_TESTS"):
131
+ # Running in Buildkite CI after the wheel has been built.
132
+ # Wheels are at in the ray/.whl directory, but use relative path to
133
+ # allow for testing locally if needed.
134
+ return os.path.join(
135
+ Path(ray.__file__).resolve().parents[2], ".whl", get_wheel_filename()
136
+ )
137
+ elif ray.__commit__ == "{{RAY_COMMIT_SHA}}":
138
+ # Running on a version built from source locally.
139
+ if os.environ.get("RAY_RUNTIME_ENV_LOCAL_DEV_MODE") != "1":
140
+ logger.warning(
141
+ "Current Ray version could not be detected, most likely "
142
+ "because you have manually built Ray from source. To use "
143
+ "runtime_env in this case, set the environment variable "
144
+ "RAY_RUNTIME_ENV_LOCAL_DEV_MODE=1."
145
+ )
146
+ return None
147
+ elif "dev" in ray.__version__:
148
+ # Running on a nightly wheel.
149
+ if _is_m1_mac():
150
+ raise ValueError("Nightly wheels are not available for M1 Macs.")
151
+ return get_master_wheel_url()
152
+ else:
153
+ if _is_m1_mac():
154
+ # M1 Mac release wheels are currently not uploaded to AWS S3; they
155
+ # are only available on PyPI. So unfortunately, this codepath is
156
+ # not end-to-end testable prior to the release going live on PyPI.
157
+ return f"ray=={ray.__version__}"
158
+ else:
159
+ return get_release_wheel_url()
160
+
161
+
162
+ def inject_dependencies(
163
+ conda_dict: Dict[Any, Any],
164
+ py_version: str,
165
+ pip_dependencies: Optional[List[str]] = None,
166
+ ) -> Dict[Any, Any]:
167
+ """Add Ray, Python and (optionally) extra pip dependencies to a conda dict.
168
+
169
+ Args:
170
+ conda_dict: A dict representing the JSON-serialized conda
171
+ environment YAML file. This dict will be modified and returned.
172
+ py_version: A string representing a Python version to inject
173
+ into the conda dependencies, e.g. "3.7.7"
174
+ pip_dependencies (List[str]): A list of pip dependencies that
175
+ will be prepended to the list of pip dependencies in
176
+ the conda dict. If the conda dict does not already have a "pip"
177
+ field, one will be created.
178
+ Returns:
179
+ The modified dict. (Note: the input argument conda_dict is modified
180
+ and returned.)
181
+ """
182
+ if pip_dependencies is None:
183
+ pip_dependencies = []
184
+ if conda_dict.get("dependencies") is None:
185
+ conda_dict["dependencies"] = []
186
+
187
+ # Inject Python dependency.
188
+ deps = conda_dict["dependencies"]
189
+
190
+ # Add current python dependency. If the user has already included a
191
+ # python version dependency, conda will raise a readable error if the two
192
+ # are incompatible, e.g:
193
+ # ResolvePackageNotFound: - python[version='3.5.*,>=3.6']
194
+ deps.append(f"python={py_version}")
195
+
196
+ if "pip" not in deps:
197
+ deps.append("pip")
198
+
199
+ # Insert pip dependencies.
200
+ found_pip_dict = False
201
+ for dep in deps:
202
+ if isinstance(dep, dict) and dep.get("pip") and isinstance(dep["pip"], list):
203
+ dep["pip"] = pip_dependencies + dep["pip"]
204
+ found_pip_dict = True
205
+ break
206
+ if not found_pip_dict:
207
+ deps.append({"pip": pip_dependencies})
208
+
209
+ return conda_dict
210
+
211
+
212
+ def _get_conda_env_hash(conda_dict: Dict) -> str:
213
+ # Set `sort_keys=True` so that different orderings yield the same hash.
214
+ serialized_conda_spec = json.dumps(conda_dict, sort_keys=True)
215
+ hash = hashlib.sha1(serialized_conda_spec.encode("utf-8")).hexdigest()
216
+ return hash
217
+
218
+
219
+ def get_uri(runtime_env: Dict) -> Optional[str]:
220
+ """Return `"conda://<hashed_dependencies>"`, or None if no GC required."""
221
+ conda = runtime_env.get("conda")
222
+ if conda is not None:
223
+ if isinstance(conda, str):
224
+ # User-preinstalled conda env. We don't garbage collect these, so
225
+ # we don't track them with URIs.
226
+ uri = None
227
+ elif isinstance(conda, dict):
228
+ uri = f"conda://{_get_conda_env_hash(conda_dict=conda)}"
229
+ else:
230
+ raise TypeError(
231
+ "conda field received by RuntimeEnvAgent must be "
232
+ f"str or dict, not {type(conda).__name__}."
233
+ )
234
+ else:
235
+ uri = None
236
+ return uri
237
+
238
+
239
+ def _get_conda_dict_with_ray_inserted(
240
+ runtime_env: "RuntimeEnv", # noqa: F821
241
+ logger: Optional[logging.Logger] = default_logger,
242
+ ) -> Dict[str, Any]:
243
+ """Returns the conda spec with the Ray and `python` dependency inserted."""
244
+ conda_dict = json.loads(runtime_env.conda_config())
245
+ assert conda_dict is not None
246
+
247
+ ray_pip = current_ray_pip_specifier(logger=logger)
248
+ if ray_pip:
249
+ extra_pip_dependencies = [ray_pip, "ray[default]"]
250
+ elif runtime_env.get_extension("_inject_current_ray"):
251
+ extra_pip_dependencies = _resolve_install_from_source_ray_dependencies()
252
+ else:
253
+ extra_pip_dependencies = []
254
+ conda_dict = inject_dependencies(
255
+ conda_dict, _current_py_version(), extra_pip_dependencies
256
+ )
257
+ return conda_dict
258
+
259
+
260
+ class CondaPlugin(RuntimeEnvPlugin):
261
+
262
+ name = "conda"
263
+
264
+ def __init__(self, resources_dir: str):
265
+ self._resources_dir = os.path.join(resources_dir, "conda")
266
+ try_to_create_directory(self._resources_dir)
267
+
268
+ # It is not safe for multiple processes to install conda envs
269
+ # concurrently, even if the envs are different, so use a global
270
+ # lock for all conda installs and deletions.
271
+ # See https://github.com/ray-project/ray/issues/17086
272
+ self._installs_and_deletions_file_lock = os.path.join(
273
+ self._resources_dir, "ray-conda-installs-and-deletions.lock"
274
+ )
275
+ # A set of named conda environments (instead of yaml or dict)
276
+ # that are validated to exist.
277
+ # NOTE: It has to be only used within the same thread, which
278
+ # is an event loop.
279
+ # Also, we don't need to GC this field because it is pretty small.
280
+ self._validated_named_conda_env = set()
281
+
282
+ def _get_path_from_hash(self, hash: str) -> str:
283
+ """Generate a path from the hash of a conda or pip spec.
284
+
285
+ The output path also functions as the name of the conda environment
286
+ when using the `--prefix` option to `conda create` and `conda remove`.
287
+
288
+ Example output:
289
+ /tmp/ray/session_2021-11-03_16-33-59_356303_41018/runtime_resources
290
+ /conda/ray-9a7972c3a75f55e976e620484f58410c920db091
291
+ """
292
+ return os.path.join(self._resources_dir, hash)
293
+
294
+ def get_uris(self, runtime_env: "RuntimeEnv") -> List[str]: # noqa: F821
295
+ """Return the conda URI from the RuntimeEnv if it exists, else return []."""
296
+ conda_uri = runtime_env.conda_uri()
297
+ if conda_uri:
298
+ return [conda_uri]
299
+ return []
300
+
301
+ def delete_uri(
302
+ self, uri: str, logger: Optional[logging.Logger] = default_logger
303
+ ) -> int:
304
+ """Delete URI and return the number of bytes deleted."""
305
+ logger.info(f"Got request to delete URI {uri}")
306
+ protocol, hash = parse_uri(uri)
307
+ if protocol != Protocol.CONDA:
308
+ raise ValueError(
309
+ "CondaPlugin can only delete URIs with protocol "
310
+ f"conda. Received protocol {protocol}, URI {uri}"
311
+ )
312
+
313
+ conda_env_path = self._get_path_from_hash(hash)
314
+ local_dir_size = get_directory_size_bytes(conda_env_path)
315
+
316
+ with FileLock(self._installs_and_deletions_file_lock):
317
+ successful = delete_conda_env(prefix=conda_env_path, logger=logger)
318
+ if not successful:
319
+ logger.warning(f"Error when deleting conda env {conda_env_path}. ")
320
+ return 0
321
+
322
+ return local_dir_size
323
+
324
+ async def create(
325
+ self,
326
+ uri: Optional[str],
327
+ runtime_env: "RuntimeEnv", # noqa: F821
328
+ context: RuntimeEnvContext,
329
+ logger: logging.Logger = default_logger,
330
+ ) -> int:
331
+ if not runtime_env.has_conda():
332
+ return 0
333
+
334
+ def _create():
335
+ result = parse_and_validate_conda(runtime_env.get("conda"))
336
+
337
+ if isinstance(result, str):
338
+ # The conda env name is given.
339
+ # In this case, we only verify if the given
340
+ # conda env exists.
341
+
342
+ # If the env is already validated, do nothing.
343
+ if result in self._validated_named_conda_env:
344
+ return 0
345
+
346
+ conda_info = get_conda_info_json()
347
+ envs = get_conda_envs(conda_info)
348
+
349
+ # We accept `result` as a conda name or full path.
350
+ if not any(result == env[0] or result == env[1] for env in envs):
351
+ raise ValueError(
352
+ f"The given conda environment '{result}' "
353
+ f"from the runtime env {runtime_env} doesn't "
354
+ "exist from the output of `conda info --json`. "
355
+ "You can only specify an env that already exists. "
356
+ f"Please make sure to create an env {result} "
357
+ )
358
+ self._validated_named_conda_env.add(result)
359
+ return 0
360
+
361
+ logger.debug(
362
+ "Setting up conda for runtime_env: " f"{runtime_env.serialize()}"
363
+ )
364
+ protocol, hash = parse_uri(uri)
365
+ conda_env_name = self._get_path_from_hash(hash)
366
+
367
+ conda_dict = _get_conda_dict_with_ray_inserted(runtime_env, logger=logger)
368
+
369
+ logger.info(f"Setting up conda environment with {runtime_env}")
370
+ with FileLock(self._installs_and_deletions_file_lock):
371
+ try:
372
+ conda_yaml_file = os.path.join(
373
+ self._resources_dir, "environment.yml"
374
+ )
375
+ with open(conda_yaml_file, "w") as file:
376
+ yaml.dump(conda_dict, file)
377
+ create_conda_env_if_needed(
378
+ conda_yaml_file, prefix=conda_env_name, logger=logger
379
+ )
380
+ finally:
381
+ os.remove(conda_yaml_file)
382
+
383
+ if runtime_env.get_extension("_inject_current_ray"):
384
+ _inject_ray_to_conda_site(conda_path=conda_env_name, logger=logger)
385
+ logger.info(f"Finished creating conda environment at {conda_env_name}")
386
+ return get_directory_size_bytes(conda_env_name)
387
+
388
+ loop = get_or_create_event_loop()
389
+ return await loop.run_in_executor(None, _create)
390
+
391
+ def modify_context(
392
+ self,
393
+ uris: List[str],
394
+ runtime_env: "RuntimeEnv", # noqa: F821
395
+ context: RuntimeEnvContext,
396
+ logger: Optional[logging.Logger] = default_logger,
397
+ ):
398
+ if not runtime_env.has_conda():
399
+ return
400
+
401
+ if runtime_env.conda_env_name():
402
+ conda_env_name = runtime_env.conda_env_name()
403
+ else:
404
+ protocol, hash = parse_uri(runtime_env.conda_uri())
405
+ conda_env_name = self._get_path_from_hash(hash)
406
+ context.py_executable = "python"
407
+ context.command_prefix += get_conda_activate_commands(conda_env_name)
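
An illustrative sketch (not part of the commit above) of what inject_dependencies() does to a conda spec; the package names and versions are arbitrary examples:

from ray._private.runtime_env.conda import inject_dependencies

conda_dict = {"dependencies": ["numpy"]}
out = inject_dependencies(conda_dict, "3.11.4", ["ray==2.9.0"])
# The python pin, "pip", and a pip sub-list are appended in place:
# {'dependencies': ['numpy', 'python=3.11.4', 'pip', {'pip': ['ray==2.9.0']}]}
print(out)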
.venv/lib/python3.11/site-packages/ray/_private/runtime_env/conda_utils.py ADDED
@@ -0,0 +1,278 @@
1
+ import logging
2
+ import os
3
+ import shutil
4
+ import subprocess
5
+ import hashlib
6
+ import json
7
+ from typing import Optional, List, Union, Tuple
8
+
9
+ """Utilities for conda. Adapted from https://github.com/mlflow/mlflow."""
10
+
11
+ # Name of environment variable indicating a path to a conda installation. Ray
12
+ # will default to running "conda" if unset.
13
+ RAY_CONDA_HOME = "RAY_CONDA_HOME"
14
+
15
+ _WIN32 = os.name == "nt"
16
+
17
+
18
+ def get_conda_activate_commands(conda_env_name: str) -> List[str]:
19
+ """
20
+ Get a list of commands to run to silently activate the given conda env.
21
+ """
22
+ # Checking for newer conda versions
23
+ if not _WIN32 and ("CONDA_EXE" in os.environ or RAY_CONDA_HOME in os.environ):
24
+ conda_path = get_conda_bin_executable("conda")
25
+ activate_conda_env = [
26
+ ".",
27
+ f"{os.path.dirname(conda_path)}/../etc/profile.d/conda.sh",
28
+ "&&",
29
+ ]
30
+ activate_conda_env += ["conda", "activate", conda_env_name]
31
+
32
+ else:
33
+ activate_path = get_conda_bin_executable("activate")
34
+ if not _WIN32:
35
+ # Use bash command syntax
36
+ activate_conda_env = ["source", activate_path, conda_env_name]
37
+ else:
38
+ activate_conda_env = ["conda", "activate", conda_env_name]
39
+ return activate_conda_env + ["1>&2", "&&"]
40
+
41
+
42
+ def get_conda_bin_executable(executable_name: str) -> str:
43
+ """
44
+ Return path to the specified executable, assumed to be discoverable within
45
+ a conda installation.
46
+
47
+ The conda home directory (expected to contain a 'bin' subdirectory on
48
+ linux) is configurable via the ``RAY_CONDA_HOME`` environment variable. If
49
+ ``RAY_CONDA_HOME`` is unspecified, try the ``CONDA_EXE`` environment
50
+ variable set by activating conda. If neither is specified, this method
51
+ returns `executable_name`.
52
+ """
53
+ conda_home = os.environ.get(RAY_CONDA_HOME)
54
+ if conda_home:
55
+ if _WIN32:
56
+ candidate = os.path.join(conda_home, "%s.exe" % executable_name)
57
+ if os.path.exists(candidate):
58
+ return candidate
59
+ candidate = os.path.join(conda_home, "%s.bat" % executable_name)
60
+ if os.path.exists(candidate):
61
+ return candidate
62
+ else:
63
+ return os.path.join(conda_home, "bin/%s" % executable_name)
64
+ else:
65
+ conda_home = "."
66
+ # Use CONDA_EXE as per https://github.com/conda/conda/issues/7126
67
+ if "CONDA_EXE" in os.environ:
68
+ conda_bin_dir = os.path.dirname(os.environ["CONDA_EXE"])
69
+ if _WIN32:
70
+ candidate = os.path.join(conda_home, "%s.exe" % executable_name)
71
+ if os.path.exists(candidate):
72
+ return candidate
73
+ candidate = os.path.join(conda_home, "%s.bat" % executable_name)
74
+ if os.path.exists(candidate):
75
+ return candidate
76
+ else:
77
+ return os.path.join(conda_bin_dir, executable_name)
78
+ if _WIN32:
79
+ return executable_name + ".bat"
80
+ return executable_name
81
+
82
+
83
+ def _get_conda_env_name(conda_env_path: str) -> str:
84
+ conda_env_contents = open(conda_env_path).read()
85
+ return "ray-%s" % hashlib.sha1(conda_env_contents.encode("utf-8")).hexdigest()
86
+
87
+
88
+ def create_conda_env_if_needed(
89
+ conda_yaml_file: str, prefix: str, logger: Optional[logging.Logger] = None
90
+ ) -> None:
91
+ """
92
+ Given a conda YAML, creates a conda environment containing the required
93
+ dependencies if such a conda environment doesn't already exist.
94
+ Args:
95
+ conda_yaml_file: The path to a conda `environment.yml` file.
96
+ prefix: Directory to install the environment into via
97
+ the `--prefix` option to conda create. This also becomes the name
98
+ of the conda env; i.e. it can be passed into `conda activate` and
99
+ `conda remove`
100
+ """
101
+ if logger is None:
102
+ logger = logging.getLogger(__name__)
103
+
104
+ conda_path = get_conda_bin_executable("conda")
105
+ try:
106
+ exec_cmd([conda_path, "--help"], throw_on_error=False)
107
+ except (EnvironmentError, FileNotFoundError):
108
+ raise ValueError(
109
+ f"Could not find Conda executable at '{conda_path}'. "
110
+ "Ensure Conda is installed as per the instructions at "
111
+ "https://conda.io/projects/conda/en/latest/"
112
+ "user-guide/install/index.html. "
113
+ "You can also configure Ray to look for a specific "
114
+ f"Conda executable by setting the {RAY_CONDA_HOME} "
115
+ "environment variable to the path of the Conda executable."
116
+ )
117
+
118
+ _, stdout, _ = exec_cmd([conda_path, "env", "list", "--json"])
119
+ envs = json.loads(stdout[stdout.index("{") :])["envs"]
120
+
121
+ if prefix in envs:
122
+ logger.info(f"Conda environment {prefix} already exists.")
123
+ return
124
+
125
+ create_cmd = [
126
+ conda_path,
127
+ "env",
128
+ "create",
129
+ "--file",
130
+ conda_yaml_file,
131
+ "--prefix",
132
+ prefix,
133
+ ]
134
+
135
+ logger.info(f"Creating conda environment {prefix}")
136
+ exit_code, output = exec_cmd_stream_to_logger(create_cmd, logger)
137
+ if exit_code != 0:
138
+ if os.path.exists(prefix):
139
+ shutil.rmtree(prefix)
140
+ raise RuntimeError(
141
+ f"Failed to install conda environment {prefix}:\nOutput:\n{output}"
142
+ )
143
+
144
+
145
+ def delete_conda_env(prefix: str, logger: Optional[logging.Logger] = None) -> bool:
146
+ if logger is None:
147
+ logger = logging.getLogger(__name__)
148
+
149
+ logger.info(f"Deleting conda environment {prefix}")
150
+
151
+ conda_path = get_conda_bin_executable("conda")
152
+ delete_cmd = [conda_path, "remove", "-p", prefix, "--all", "-y"]
153
+ exit_code, output = exec_cmd_stream_to_logger(delete_cmd, logger)
154
+
155
+ if exit_code != 0:
156
+ logger.debug(f"Failed to delete conda environment {prefix}:\n{output}")
157
+ return False
158
+
159
+ return True
160
+
161
+
162
+ def get_conda_env_list() -> list:
163
+ """
164
+ Get conda env list in full paths.
165
+ """
166
+ conda_path = get_conda_bin_executable("conda")
167
+ try:
168
+ exec_cmd([conda_path, "--help"], throw_on_error=False)
169
+ except EnvironmentError:
170
+ raise ValueError(f"Could not find Conda executable at {conda_path}.")
171
+ _, stdout, _ = exec_cmd([conda_path, "env", "list", "--json"])
172
+ envs = json.loads(stdout)["envs"]
173
+ return envs
174
+
175
+
176
+ def get_conda_info_json() -> dict:
177
+ """
178
+ Get `conda info --json` output.
179
+
180
+ Returns dict of conda info. See [1] for more details. We mostly care about these
181
+ keys:
182
+
183
+ - `conda_prefix`: str The path to the conda installation.
184
+ - `envs`: List[str] absolute paths to conda environments.
185
+
186
+ [1] https://github.com/conda/conda/blob/main/conda/cli/main_info.py
187
+ """
188
+ conda_path = get_conda_bin_executable("conda")
189
+ try:
190
+ exec_cmd([conda_path, "--help"], throw_on_error=False)
191
+ except EnvironmentError:
192
+ raise ValueError(f"Could not find Conda executable at {conda_path}.")
193
+ _, stdout, _ = exec_cmd([conda_path, "info", "--json"])
194
+ return json.loads(stdout)
195
+
196
+
197
+ def get_conda_envs(conda_info: dict) -> List[Tuple[str, str]]:
198
+ """
199
+ Gets the conda environments, as a list of (name, path) tuples.
200
+ """
201
+ prefix = conda_info["conda_prefix"]
202
+ ret = []
203
+ for env in conda_info["envs"]:
204
+ if env == prefix:
205
+ ret.append(("base", env))
206
+ else:
207
+ ret.append((os.path.basename(env), env))
208
+ return ret
209
+
210
+
211
+ class ShellCommandException(Exception):
212
+ pass
213
+
214
+
215
+ def exec_cmd(
216
+ cmd: List[str], throw_on_error: bool = True, logger: Optional[logging.Logger] = None
217
+ ) -> Union[int, Tuple[int, str, str]]:
218
+ """
219
+ Runs a command as a child process.
220
+
221
+ A convenience wrapper for running a command from a Python script.
222
+
223
+ Note on the return value: A tuple of the exit code,
224
+ standard output and standard error is returned.
225
+
226
+ Args:
227
+ cmd: the command to run, as a list of strings
228
+ throw_on_error: if true, raises an Exception if the exit code of the
229
+ program is nonzero
230
+ """
231
+ child = subprocess.Popen(
232
+ cmd,
233
+ stdout=subprocess.PIPE,
234
+ stdin=subprocess.PIPE,
235
+ stderr=subprocess.PIPE,
236
+ universal_newlines=True,
237
+ )
238
+ (stdout, stderr) = child.communicate()
239
+ exit_code = child.wait()
240
+ if throw_on_error and exit_code != 0:
241
+ raise ShellCommandException(
242
+ "Non-zero exit code: %s\n\nSTDOUT:\n%s\n\nSTDERR:%s"
243
+ % (exit_code, stdout, stderr)
244
+ )
245
+ return exit_code, stdout, stderr
246
+
247
+
248
+ def exec_cmd_stream_to_logger(
249
+ cmd: List[str], logger: logging.Logger, n_lines: int = 50, **kwargs
250
+ ) -> Tuple[int, str]:
251
+ """Runs a command as a child process, streaming output to the logger.
252
+
253
+ The last n_lines lines of output are also returned (stdout and stderr).
254
+ """
255
+ if "env" in kwargs and _WIN32 and "PATH" not in [x.upper() for x in kwargs.keys]:
256
+ raise ValueError("On windows, Popen requires 'PATH' in 'env'")
257
+ child = subprocess.Popen(
258
+ cmd,
259
+ universal_newlines=True,
260
+ stdout=subprocess.PIPE,
261
+ stderr=subprocess.STDOUT,
262
+ **kwargs,
263
+ )
264
+ last_n_lines = []
265
+ with child.stdout:
266
+ for line in iter(child.stdout.readline, b""):
267
+ exit_code = child.poll()
268
+ if exit_code is not None:
269
+ break
270
+ line = line.strip()
271
+ if not line:
272
+ continue
273
+ last_n_lines.append(line.strip())
274
+ last_n_lines = last_n_lines[-n_lines:]
275
+ logger.info(line.strip())
276
+
277
+ exit_code = child.wait()
278
+ return exit_code, "\n".join(last_n_lines)
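A minimal, hedged usage sketch of the helpers above. It assumes a local conda installation that `get_conda_bin_executable` can locate and that this module is importable as `ray._private.runtime_env.conda_utils`; the command passed to `exec_cmd` is only an illustration.

    # Sketch: list conda environments and run a command with exec_cmd.
    from ray._private.runtime_env.conda_utils import (
        exec_cmd,
        get_conda_envs,
        get_conda_info_json,
    )

    info = get_conda_info_json()  # parsed `conda info --json`
    for name, path in get_conda_envs(info):
        print(f"{name}: {path}")

    # exec_cmd returns (exit_code, stdout, stderr) and raises
    # ShellCommandException on a nonzero exit code when throw_on_error=True.
    code, out, err = exec_cmd(["python", "--version"])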
.venv/lib/python3.11/site-packages/ray/_private/runtime_env/constants.py ADDED
@@ -0,0 +1,28 @@
+# Env var set by job manager to pass runtime env and metadata to subprocess
+RAY_JOB_CONFIG_JSON_ENV_VAR = "RAY_JOB_CONFIG_JSON_ENV_VAR"
+
+# The plugin config which should be loaded when ray cluster starts.
+# It is a json formatted config,
+# e.g. [{"class": "xxx.xxx.xxx_plugin", "priority": 10}].
+RAY_RUNTIME_ENV_PLUGINS_ENV_VAR = "RAY_RUNTIME_ENV_PLUGINS"
+
+# The field name of plugin class in the plugin config.
+RAY_RUNTIME_ENV_CLASS_FIELD_NAME = "class"
+
+# The field name of priority in the plugin config.
+RAY_RUNTIME_ENV_PRIORITY_FIELD_NAME = "priority"
+
+# The default priority of runtime env plugin.
+RAY_RUNTIME_ENV_PLUGIN_DEFAULT_PRIORITY = 10
+
+# The minimum priority of runtime env plugin.
+RAY_RUNTIME_ENV_PLUGIN_MIN_PRIORITY = 0
+
+# The maximum priority of runtime env plugin.
+RAY_RUNTIME_ENV_PLUGIN_MAX_PRIORITY = 100
+
+# The schema files or directories of plugins which should be loaded in workers.
+RAY_RUNTIME_ENV_PLUGIN_SCHEMAS_ENV_VAR = "RAY_RUNTIME_ENV_PLUGIN_SCHEMAS"
+
+# The file suffix of runtime env plugin schemas.
+RAY_RUNTIME_ENV_PLUGIN_SCHEMA_SUFFIX = ".json"
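A brief, hedged illustration of the plugin-config format these constants describe. The class path `my_pkg.MyPlugin` is a made-up placeholder, not a real Ray plugin.

    import json
    import os

    from ray._private.runtime_env.constants import (
        RAY_RUNTIME_ENV_CLASS_FIELD_NAME,
        RAY_RUNTIME_ENV_PLUGIN_DEFAULT_PRIORITY,
        RAY_RUNTIME_ENV_PLUGINS_ENV_VAR,
    )

    # Hypothetical plugin entry in the JSON format described above.
    plugin_config = [
        {
            RAY_RUNTIME_ENV_CLASS_FIELD_NAME: "my_pkg.MyPlugin",
            "priority": RAY_RUNTIME_ENV_PLUGIN_DEFAULT_PRIORITY,
        }
    ]
    os.environ[RAY_RUNTIME_ENV_PLUGINS_ENV_VAR] = json.dumps(plugin_config)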
.venv/lib/python3.11/site-packages/ray/_private/runtime_env/context.py ADDED
@@ -0,0 +1,108 @@
+import json
+import logging
+import os
+import subprocess
+import shlex
+import sys
+from typing import Dict, List, Optional
+
+from ray.util.annotations import DeveloperAPI
+from ray.core.generated.common_pb2 import Language
+from ray._private.services import get_ray_jars_dir
+from ray._private.utils import update_envs
+
+logger = logging.getLogger(__name__)
+
+
+@DeveloperAPI
+class RuntimeEnvContext:
+    """A context used to describe the created runtime env."""
+
+    def __init__(
+        self,
+        command_prefix: List[str] = None,
+        env_vars: Dict[str, str] = None,
+        py_executable: Optional[str] = None,
+        override_worker_entrypoint: Optional[str] = None,
+        java_jars: List[str] = None,
+    ):
+        self.command_prefix = command_prefix or []
+        self.env_vars = env_vars or {}
+        self.py_executable = py_executable or sys.executable
+        self.override_worker_entrypoint: Optional[str] = override_worker_entrypoint
+        self.java_jars = java_jars or []
+
+    def serialize(self) -> str:
+        return json.dumps(self.__dict__)
+
+    @staticmethod
+    def deserialize(json_string):
+        return RuntimeEnvContext(**json.loads(json_string))
+
+    def exec_worker(self, passthrough_args: List[str], language: Language):
+        update_envs(self.env_vars)
+
+        if language == Language.PYTHON and sys.platform == "win32":
+            executable = [self.py_executable]
+        elif language == Language.PYTHON:
+            executable = ["exec", self.py_executable]
+        elif language == Language.JAVA:
+            executable = ["java"]
+            ray_jars = os.path.join(get_ray_jars_dir(), "*")
+
+            local_java_jars = []
+            for java_jar in self.java_jars:
+                local_java_jars.append(f"{java_jar}/*")
+                local_java_jars.append(java_jar)
+
+            class_path_args = ["-cp", ray_jars + ":" + str(":".join(local_java_jars))]
+            passthrough_args = class_path_args + passthrough_args
+        elif sys.platform == "win32":
+            executable = []
+        else:
+            executable = ["exec"]
+
+        # By default, raylet uses the path to default_worker.py on host.
+        # However, the path to default_worker.py inside the container
+        # can be different. We need the user to specify the path to
+        # default_worker.py inside the container.
+        if self.override_worker_entrypoint:
+            logger.debug(
+                f"Changing the worker entrypoint from {passthrough_args[0]} to "
+                f"{self.override_worker_entrypoint}."
+            )
+            passthrough_args[0] = self.override_worker_entrypoint
+
+        if sys.platform == "win32":
+
+            def quote(s):
+                s = s.replace("&", "%26")
+                return s
+
+            passthrough_args = [quote(s) for s in passthrough_args]
+
+            cmd = [*self.command_prefix, *executable, *passthrough_args]
+            logger.debug(f"Exec'ing worker with command: {cmd}")
+            subprocess.Popen(cmd, shell=True).wait()
+        else:
+            # We use shlex to do the necessary shell escape
+            # of special characters in passthrough_args.
+            passthrough_args = [shlex.quote(s) for s in passthrough_args]
+            cmd = [*self.command_prefix, *executable, *passthrough_args]
+            # TODO(SongGuyang): We add this env to command for macOS because it doesn't
+            # work for the C++ process of `os.execvp`. We should find a better way to
+            # fix it.
+            MACOS_LIBRARY_PATH_ENV_NAME = "DYLD_LIBRARY_PATH"
+            if MACOS_LIBRARY_PATH_ENV_NAME in os.environ:
+                cmd.insert(
+                    0,
+                    f"{MACOS_LIBRARY_PATH_ENV_NAME}="
+                    f"{os.environ[MACOS_LIBRARY_PATH_ENV_NAME]}",
+                )
+            logger.debug(f"Exec'ing worker with command: {cmd}")
+            # PyCharm will monkey patch the os.execvp at
+            # .pycharm_helpers/pydev/_pydev_bundle/pydev_monkey.py
+            # The monkey patched os.execvp function has a different
+            # signature. So, we use os.execvp("executable", args=[])
+            # instead of os.execvp(file="executable", args=[])
+            os.execvp("bash", args=["bash", "-c", " ".join(cmd)])
.venv/lib/python3.11/site-packages/ray/_private/runtime_env/default_impl.py ADDED
@@ -0,0 +1,11 @@
+from ray._private.runtime_env.image_uri import ImageURIPlugin
+
+
+def get_image_uri_plugin_cls():
+    return ImageURIPlugin
+
+
+def get_protocols_provider():
+    from ray._private.runtime_env.protocol import ProtocolsProvider
+
+    return ProtocolsProvider
.venv/lib/python3.11/site-packages/ray/_private/runtime_env/dependency_utils.py ADDED
@@ -0,0 +1,113 @@
+"""Util functions to manage dependency requirements."""
+
+from typing import List, Tuple, Optional
+import os
+import tempfile
+import logging
+from contextlib import asynccontextmanager
+from ray._private.runtime_env import virtualenv_utils
+from ray._private.runtime_env.utils import check_output_cmd
+
+INTERNAL_PIP_FILENAME = "ray_runtime_env_internal_pip_requirements.txt"
+MAX_INTERNAL_PIP_FILENAME_TRIES = 100
+
+
+def gen_requirements_txt(requirements_file: str, pip_packages: List[str]):
+    """Dump [pip_packages] to the given [requirements_file] for later env setup."""
+    with open(requirements_file, "w") as file:
+        for line in pip_packages:
+            file.write(line + "\n")
+
+
+@asynccontextmanager
+async def check_ray(python: str, cwd: str, logger: logging.Logger):
+    """A context manager to check ray is not overwritten.
+
+    Currently, we only check ray version and path. It works for virtualenv,
+    - ray is in Python's site-packages.
+    - ray is overwritten during yield.
+    - ray is in virtualenv's site-packages.
+    """
+
+    async def _get_ray_version_and_path() -> Tuple[str, str]:
+        with tempfile.TemporaryDirectory(
+            prefix="check_ray_version_tempfile"
+        ) as tmp_dir:
+            ray_version_path = os.path.join(tmp_dir, "ray_version.txt")
+            check_ray_cmd = [
+                python,
+                "-c",
+                """
+import ray
+with open(r"{ray_version_path}", "wt") as f:
+    f.write(ray.__version__)
+    f.write(" ")
+    f.write(ray.__path__[0])
+            """.format(
+                ray_version_path=ray_version_path
+            ),
+            ]
+            if virtualenv_utils._WIN32:
+                env = os.environ.copy()
+            else:
+                env = {}
+            output = await check_output_cmd(
+                check_ray_cmd, logger=logger, cwd=cwd, env=env
+            )
+            logger.info(f"try to write ray version information in: {ray_version_path}")
+            with open(ray_version_path, "rt") as f:
+                output = f.read()
+        # print after import ray may have \r\n endings, so we strip them by *_
+        ray_version, ray_path, *_ = [s.strip() for s in output.split()]
+        return ray_version, ray_path
+
+    version, path = await _get_ray_version_and_path()
+    yield
+    actual_version, actual_path = await _get_ray_version_and_path()
+    if actual_version != version or actual_path != path:
+        raise RuntimeError(
+            "Changing the ray version is not allowed: \n"
+            f"  current version: {actual_version}, "
+            f"current path: {actual_path}\n"
+            f"  expect version: {version}, "
+            f"expect path: {path}\n"
+            "Please ensure the dependencies in the runtime_env pip field "
+            "do not install a different version of Ray."
+        )
+
+
+def get_requirements_file(target_dir: str, pip_list: Optional[List[str]]) -> str:
+    """Returns the path to the requirements file to use for this runtime env.
+
+    If pip_list is not None, we will check if the internal pip filename is in any of
+    the entries of pip_list. If so, we will append numbers to the end of the
+    filename until we find one that doesn't conflict. This prevents infinite
+    recursion if the user specifies the internal pip filename in their pip list.
+
+    Args:
+        target_dir: The directory to store the requirements file in.
+        pip_list: A list of pip requirements specified by the user.
+
+    Returns:
+        The path to the requirements file to use for this runtime env.
+    """
+
+    def filename_in_pip_list(filename: str) -> bool:
+        for pip_entry in pip_list:
+            if filename in pip_entry:
+                return True
+        return False
+
+    filename = INTERNAL_PIP_FILENAME
+    if pip_list is not None:
+        i = 1
+        while filename_in_pip_list(filename) and i < MAX_INTERNAL_PIP_FILENAME_TRIES:
+            filename = f"{INTERNAL_PIP_FILENAME}.{i}"
+            i += 1
+        if i == MAX_INTERNAL_PIP_FILENAME_TRIES:
+            raise RuntimeError(
+                "Could not find a valid filename for the internal "
+                "pip requirements file. Please specify a different "
+                "pip list in your runtime env."
+            )
+    return os.path.join(target_dir, filename)
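A hedged usage sketch for the two pip-requirements helpers above; the package pins are arbitrary examples, not requirements of any particular runtime env.

    import tempfile

    from ray._private.runtime_env.dependency_utils import (
        gen_requirements_txt,
        get_requirements_file,
    )

    pip_list = ["requests==2.31.0", "numpy"]  # arbitrary example packages
    with tempfile.TemporaryDirectory() as target_dir:
        req_file = get_requirements_file(target_dir, pip_list)
        gen_requirements_txt(req_file, pip_list)
        with open(req_file) as f:
            print(f.read())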
.venv/lib/python3.11/site-packages/ray/_private/runtime_env/image_uri.py ADDED
@@ -0,0 +1,195 @@
+import logging
+import os
+from typing import List, Optional
+
+from ray._private.runtime_env.context import RuntimeEnvContext
+from ray._private.runtime_env.plugin import RuntimeEnvPlugin
+from ray._private.runtime_env.utils import check_output_cmd
+
+default_logger = logging.getLogger(__name__)
+
+
+async def _create_impl(image_uri: str, logger: logging.Logger):
+    # Pull image if it doesn't exist
+    # Also get path to `default_worker.py` inside the image.
+    pull_image_cmd = [
+        "podman",
+        "run",
+        "--rm",
+        image_uri,
+        "python",
+        "-c",
+        (
+            "import ray._private.workers.default_worker as default_worker; "
+            "print(default_worker.__file__)"
+        ),
+    ]
+    logger.info("Pulling image %s", image_uri)
+    worker_path = await check_output_cmd(pull_image_cmd, logger=logger)
+    return worker_path.strip()
+
+
+def _modify_context_impl(
+    image_uri: str,
+    worker_path: str,
+    run_options: Optional[List[str]],
+    context: RuntimeEnvContext,
+    logger: logging.Logger,
+    ray_tmp_dir: str,
+):
+    context.override_worker_entrypoint = worker_path
+
+    container_driver = "podman"
+    container_command = [
+        container_driver,
+        "run",
+        "-v",
+        ray_tmp_dir + ":" + ray_tmp_dir,
+        "--cgroup-manager=cgroupfs",
+        "--network=host",
+        "--pid=host",
+        "--ipc=host",
+        # NOTE(zcin): Mounted volumes in rootless containers are
+        # owned by the user `root`. The user on host (which will
+        # usually be `ray` if this is being run in a ray docker
+        # image) who started the container is mapped using user
+        # namespaces to the user `root` in a rootless container. In
+        # order for the Ray Python worker to access the mounted ray
+        # tmp dir, we need to use keep-id mode which maps the user
+        # as itself (instead of as `root`) into the container.
+        # https://www.redhat.com/sysadmin/rootless-podman-user-namespace-modes
+        "--userns=keep-id",
+    ]
+
+    # Environment variables to set in container
+    env_vars = dict()
+
+    # Propagate all host environment variables that have the prefix "RAY_"
+    # This should include RAY_RAYLET_PID
+    for env_var_name, env_var_value in os.environ.items():
+        if env_var_name.startswith("RAY_"):
+            env_vars[env_var_name] = env_var_value
+
+    # Support for runtime_env['env_vars']
+    env_vars.update(context.env_vars)
+
+    # Set environment variables
+    for env_var_name, env_var_value in env_vars.items():
+        container_command.append("--env")
+        container_command.append(f"{env_var_name}='{env_var_value}'")
+
+    # The RAY_JOB_ID environment variable is needed for the default worker.
+    # It won't be set at the time setup() is called, but it will be set
+    # when worker command is executed, so we use RAY_JOB_ID=$RAY_JOB_ID
+    # for the container start command
+    container_command.append("--env")
+    container_command.append("RAY_JOB_ID=$RAY_JOB_ID")
+
+    if run_options:
+        container_command.extend(run_options)
+    # TODO(chenk008): add resource limit
+    container_command.append("--entrypoint")
+    container_command.append("python")
+    container_command.append(image_uri)
+
+    # Example:
+    # podman run -v /tmp/ray:/tmp/ray
+    # --cgroup-manager=cgroupfs --network=host --pid=host --ipc=host
+    # --userns=keep-id --env RAY_RAYLET_PID=23478 --env RAY_JOB_ID=$RAY_JOB_ID
+    # --entrypoint python rayproject/ray:nightly-py39
+    container_command_str = " ".join(container_command)
+    logger.info(f"Starting worker in container with prefix {container_command_str}")
+
+    context.py_executable = container_command_str
+
+
+class ImageURIPlugin(RuntimeEnvPlugin):
+    """Starts worker in a container of a custom image."""
+
+    name = "image_uri"
+
+    @staticmethod
+    def get_compatible_keys():
+        return {"image_uri", "config", "env_vars"}
+
+    def __init__(self, ray_tmp_dir: str):
+        self._ray_tmp_dir = ray_tmp_dir
+
+    async def create(
+        self,
+        uri: Optional[str],
+        runtime_env: "RuntimeEnv",  # noqa: F821
+        context: RuntimeEnvContext,
+        logger: logging.Logger,
+    ) -> float:
+        if not runtime_env.image_uri():
+            return
+
+        self.worker_path = await _create_impl(runtime_env.image_uri(), logger)
+
+    def modify_context(
+        self,
+        uris: List[str],
+        runtime_env: "RuntimeEnv",  # noqa: F821
+        context: RuntimeEnvContext,
+        logger: Optional[logging.Logger] = default_logger,
+    ):
+        if not runtime_env.image_uri():
+            return
+
+        _modify_context_impl(
+            runtime_env.image_uri(),
+            self.worker_path,
+            [],
+            context,
+            logger,
+            self._ray_tmp_dir,
+        )
+
+
+class ContainerPlugin(RuntimeEnvPlugin):
+    """Starts worker in container."""
+
+    name = "container"
+
+    def __init__(self, ray_tmp_dir: str):
+        self._ray_tmp_dir = ray_tmp_dir
+
+    async def create(
+        self,
+        uri: Optional[str],
+        runtime_env: "RuntimeEnv",  # noqa: F821
+        context: RuntimeEnvContext,
+        logger: logging.Logger,
+    ) -> float:
+        if not runtime_env.has_py_container() or not runtime_env.py_container_image():
+            return
+
+        self.worker_path = await _create_impl(runtime_env.py_container_image(), logger)
+
+    def modify_context(
+        self,
+        uris: List[str],
+        runtime_env: "RuntimeEnv",  # noqa: F821
+        context: RuntimeEnvContext,
+        logger: Optional[logging.Logger] = default_logger,
+    ):
+        if not runtime_env.has_py_container() or not runtime_env.py_container_image():
+            return
+
+        if runtime_env.py_container_worker_path():
+            logger.warning(
+                "You are using `container.worker_path`, but the path to "
+                "`default_worker.py` is now automatically detected from the image. "
+                "`container.worker_path` is deprecated and will be removed in future "
+                "versions."
+            )
+
+        _modify_context_impl(
+            runtime_env.py_container_image(),
+            runtime_env.py_container_worker_path() or self.worker_path,
+            runtime_env.py_container_run_options(),
+            context,
+            logger,
+            self._ray_tmp_dir,
+        )
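A hedged sketch of how the private helper above rewrites the context into a podman command prefix. The image name, worker path, and tmp dir are placeholders; this only builds the command string and does not start a container.

    import logging

    from ray._private.runtime_env.context import RuntimeEnvContext
    from ray._private.runtime_env.image_uri import _modify_context_impl

    ctx = RuntimeEnvContext(env_vars={"MY_FLAG": "1"})  # illustrative env var
    _modify_context_impl(
        image_uri="rayproject/ray:nightly-py39",                # placeholder image
        worker_path="/some/path/in/image/default_worker.py",    # placeholder path
        run_options=None,
        context=ctx,
        logger=logging.getLogger(__name__),
        ray_tmp_dir="/tmp/ray",
    )
    # ctx.py_executable now holds the "podman run ..." prefix used to start the worker.
    print(ctx.py_executable)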
.venv/lib/python3.11/site-packages/ray/_private/runtime_env/java_jars.py ADDED
@@ -0,0 +1,103 @@
+import logging
+import os
+from typing import Dict, List, Optional
+
+from ray._private.gcs_utils import GcsAioClient
+from ray._private.runtime_env.context import RuntimeEnvContext
+from ray._private.runtime_env.packaging import (
+    delete_package,
+    download_and_unpack_package,
+    get_local_dir_from_uri,
+    is_jar_uri,
+)
+from ray._private.runtime_env.plugin import RuntimeEnvPlugin
+from ray._private.utils import get_directory_size_bytes, try_to_create_directory
+from ray.exceptions import RuntimeEnvSetupError
+
+default_logger = logging.getLogger(__name__)
+
+
+class JavaJarsPlugin(RuntimeEnvPlugin):
+
+    name = "java_jars"
+
+    def __init__(self, resources_dir: str, gcs_aio_client: GcsAioClient):
+        self._resources_dir = os.path.join(resources_dir, "java_jars_files")
+        self._gcs_aio_client = gcs_aio_client
+        try_to_create_directory(self._resources_dir)
+
+    def _get_local_dir_from_uri(self, uri: str):
+        return get_local_dir_from_uri(uri, self._resources_dir)
+
+    def delete_uri(
+        self, uri: str, logger: Optional[logging.Logger] = default_logger
+    ) -> int:
+        """Delete URI and return the number of bytes deleted."""
+        local_dir = get_local_dir_from_uri(uri, self._resources_dir)
+        local_dir_size = get_directory_size_bytes(local_dir)
+
+        deleted = delete_package(uri, self._resources_dir)
+        if not deleted:
+            logger.warning(f"Tried to delete nonexistent URI: {uri}.")
+            return 0
+
+        return local_dir_size
+
+    def get_uris(self, runtime_env: dict) -> List[str]:
+        return runtime_env.java_jars()
+
+    async def _download_jars(
+        self, uri: str, logger: Optional[logging.Logger] = default_logger
+    ):
+        """Download a jar URI."""
+        try:
+            jar_file = await download_and_unpack_package(
+                uri, self._resources_dir, self._gcs_aio_client, logger=logger
+            )
+        except Exception as e:
+            raise RuntimeEnvSetupError(
+                "Failed to download jar file: {}".format(e)
+            ) from e
+        module_dir = self._get_local_dir_from_uri(uri)
+        logger.debug(f"Successfully downloaded jar file {jar_file}.")
+        return module_dir
+
+    async def create(
+        self,
+        uri: str,
+        runtime_env: "RuntimeEnv",  # noqa: F821
+        context: RuntimeEnvContext,
+        logger: Optional[logging.Logger] = default_logger,
+    ) -> int:
+        if not uri:
+            return 0
+        if is_jar_uri(uri):
+            module_dir = await self._download_jars(uri=uri, logger=logger)
+        else:
+            try:
+                module_dir = await download_and_unpack_package(
+                    uri, self._resources_dir, self._gcs_aio_client, logger=logger
+                )
+            except Exception as e:
+                raise RuntimeEnvSetupError(
+                    "Failed to download jar file: {}".format(e)
+                ) from e
+
+        return get_directory_size_bytes(module_dir)
+
+    def modify_context(
+        self,
+        uris: List[str],
+        runtime_env_dict: Dict,
+        context: RuntimeEnvContext,
+        logger: Optional[logging.Logger] = default_logger,
+    ):
+        for uri in uris:
+            module_dir = self._get_local_dir_from_uri(uri)
+            if not module_dir.exists():
+                raise ValueError(
+                    f"Local directory {module_dir} for URI {uri} does "
+                    "not exist on the cluster. Something may have gone wrong while "
+                    "downloading, unpacking or installing the java jar files."
+                )
+            context.java_jars.append(str(module_dir))
.venv/lib/python3.11/site-packages/ray/_private/runtime_env/mpi.py ADDED
@@ -0,0 +1,114 @@
+import logging
+import os
+from typing import List, Optional
+from ray._private.runtime_env.context import RuntimeEnvContext
+from ray._private.runtime_env.plugin import RuntimeEnvPlugin
+import subprocess
+
+default_logger = logging.getLogger(__name__)
+
+
+def mpi_init():
+    """Initialize the MPI cluster. When using an MPI cluster, this must be called first."""
+
+    if hasattr(mpi_init, "inited"):
+        assert mpi_init.inited is True
+        return
+
+    from mpi4py import MPI
+
+    comm = MPI.COMM_WORLD
+    rank = comm.Get_rank()
+    if rank == 0:
+        from ray._private.accelerators import get_all_accelerator_managers
+
+        device_vars = [
+            m.get_visible_accelerator_ids_env_var()
+            for m in get_all_accelerator_managers()
+        ]
+        visible_devices = {
+            n: os.environ.get(n) for n in device_vars if os.environ.get(n)
+        }
+        comm.bcast(visible_devices)
+        with open(f"/tmp/{os.getpid()}.{rank}", "w") as f:
+            f.write(str(visible_devices))
+    else:
+        visible_devices = comm.bcast(None)
+        os.environ.update(visible_devices)
+    mpi_init.inited = True
+
+
+class MPIPlugin(RuntimeEnvPlugin):
+    """This plugin enables an MPI cluster to run on top of Ray.
+
+    To use it, add "mpi" to the runtime env as follows:
+
+    @ray.remote(
+        runtime_env={
+            "mpi": {
+                "args": ["-n", "4"],
+                "worker_entry": worker_entry,
+            }
+        }
+    )
+    def calc_pi():
+        ...
+
+    Here worker_entry should be the function for the MPI workers to run.
+    For example, it should be `'py_module.worker_func'`. The module should be able to
+    be imported in the runtime.
+
+    In the MPI worker with rank == 0, it'll be the normal ray function or actor.
+    For workers with rank > 0, it'll just run `worker_func`.
+
+    ray.runtime_env.mpi_init must be called in the ray actors/tasks before any MPI
+    communication.
+    """
+
+    priority = 90
+    name = "mpi"
+
+    def modify_context(
+        self,
+        uris: List[str],  # noqa: ARG002
+        runtime_env: "RuntimeEnv",  # noqa: F821 ARG002
+        context: RuntimeEnvContext,
+        logger: Optional[logging.Logger] = default_logger,  # noqa: ARG002
+    ) -> None:
+        mpi_config = runtime_env.mpi()
+        if mpi_config is None:
+            return
+        try:
+            proc = subprocess.run(
+                ["mpirun", "--version"], capture_output=True, check=True
+            )
+        except subprocess.CalledProcessError:
+            logger.exception(
+                "Failed to run mpirun. Please make sure MPI has been installed."
+            )
+            # The worker will fail to run and exception will be thrown in runtime
+            # env agent.
+            raise
+
+        logger.info(f"Running MPI plugin\n {proc.stdout.decode()}")
+
+        # worker_entry should be a file either in the working dir
+        # or visible inside the cluster.
+        worker_entry = mpi_config.get("worker_entry")
+
+        assert (
+            worker_entry is not None
+        ), "`worker_entry` must be set up in the runtime env."
+
+        cmds = (
+            ["mpirun"]
+            + mpi_config.get("args", [])
+            + [
+                context.py_executable,
+                "-m",
+                "ray._private.runtime_env.mpi_runner",
+                worker_entry,
+            ]
+        )
+        # Construct the start cmd
+        context.py_executable = " ".join(cmds)
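To complement the docstring above, a hedged sketch of what a `worker_entry` target might look like. The module name `my_mpi_module` and the function are hypothetical; the module must be importable on every node and mpi4py must be installed.

    # my_mpi_module.py (hypothetical module referenced as "my_mpi_module.worker_func")
    from mpi4py import MPI

    from ray.runtime_env import mpi_init


    def worker_func():
        # Must be called before any MPI communication, per the plugin docstring.
        mpi_init()
        comm = MPI.COMM_WORLD
        print(f"MPI worker rank {comm.Get_rank()} of {comm.Get_size()} started")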
.venv/lib/python3.11/site-packages/ray/_private/runtime_env/mpi_runner.py ADDED
@@ -0,0 +1,32 @@
+import sys
+import argparse
+import importlib.util
+from mpi4py import MPI
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Setup MPI worker")
+    parser.add_argument("worker_entry")
+    parser.add_argument("main_entry")
+
+    args, remaining_args = parser.parse_known_args()
+
+    comm = MPI.COMM_WORLD
+
+    rank = comm.Get_rank()
+
+    if rank == 0:
+        entry_file = args.main_entry
+
+        sys.argv[1:] = remaining_args
+        spec = importlib.util.spec_from_file_location("__main__", entry_file)
+        mod = importlib.util.module_from_spec(spec)
+        spec.loader.exec_module(mod)
+    else:
+        from ray.runtime_env import mpi_init
+
+        mpi_init()
+        module, func = args.worker_entry.rsplit(".", 1)
+        m = importlib.import_module(module)
+        f = getattr(m, func)
+        f()