Add files using upload-large-folder tool
This view is limited to 50 files because it contains too many changes.
- .venv/lib/python3.11/site-packages/ray/dashboard/modules/data/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/ray/dashboard/modules/log/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/dashboard/modules/log/__pycache__/log_agent.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/dashboard/modules/log/__pycache__/log_consts.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/dashboard/modules/log/__pycache__/log_manager.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/dashboard/modules/node/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/ray/dashboard/modules/node/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/dashboard/modules/node/__pycache__/node_consts.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/dashboard/modules/node/node_consts.py +17 -0
- .venv/lib/python3.11/site-packages/ray/dashboard/modules/node/node_head.py +496 -0
- .venv/lib/python3.11/site-packages/ray/util/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/util/__pycache__/actor_group.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/util/__pycache__/actor_pool.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/util/__pycache__/annotations.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/util/__pycache__/check_open_ports.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/util/__pycache__/check_serialize.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/util/__pycache__/client_connect.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/util/__pycache__/debug.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/util/__pycache__/debugpy.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/util/__pycache__/iter.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/util/__pycache__/iter_metrics.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/util/__pycache__/metrics.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/util/__pycache__/placement_group.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/util/__pycache__/queue.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/util/__pycache__/rpdb.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/util/__pycache__/scheduling_strategies.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/util/__pycache__/serialization.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/util/__pycache__/serialization_addons.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/util/__pycache__/timer.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/util/accelerators/__init__.py +78 -0
- .venv/lib/python3.11/site-packages/ray/util/accelerators/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/util/accelerators/__pycache__/accelerators.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/util/accelerators/__pycache__/tpu.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/util/accelerators/accelerators.py +33 -0
- .venv/lib/python3.11/site-packages/ray/util/accelerators/tpu.py +39 -0
- .venv/lib/python3.11/site-packages/ray/util/annotations.py +268 -0
- .venv/lib/python3.11/site-packages/ray/util/client/api.py +406 -0
- .venv/lib/python3.11/site-packages/ray/util/client/client_app.py +90 -0
- .venv/lib/python3.11/site-packages/ray/util/client/common.py +956 -0
- .venv/lib/python3.11/site-packages/ray/util/client/dataclient.py +599 -0
- .venv/lib/python3.11/site-packages/ray/util/client/options.py +47 -0
- .venv/lib/python3.11/site-packages/ray/util/client/ray_client_helpers.py +115 -0
- .venv/lib/python3.11/site-packages/ray/util/client/runtime_context.py +65 -0
- .venv/lib/python3.11/site-packages/ray/util/client/worker.py +908 -0
- .venv/lib/python3.11/site-packages/ray/util/dask/__init__.py +63 -0
- .venv/lib/python3.11/site-packages/ray/util/dask/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/util/dask/__pycache__/callbacks.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/util/dask/__pycache__/common.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/util/dask/__pycache__/optimizations.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/util/dask/__pycache__/scheduler.cpython-311.pyc +0 -0
.venv/lib/python3.11/site-packages/ray/dashboard/modules/data/__init__.py
ADDED
File without changes

.venv/lib/python3.11/site-packages/ray/dashboard/modules/log/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (198 Bytes)

.venv/lib/python3.11/site-packages/ray/dashboard/modules/log/__pycache__/log_agent.cpython-311.pyc
ADDED
Binary file (16.9 kB)

.venv/lib/python3.11/site-packages/ray/dashboard/modules/log/__pycache__/log_consts.cpython-311.pyc
ADDED
Binary file (340 Bytes)

.venv/lib/python3.11/site-packages/ray/dashboard/modules/log/__pycache__/log_manager.cpython-311.pyc
ADDED
Binary file (19.9 kB)

.venv/lib/python3.11/site-packages/ray/dashboard/modules/node/__init__.py
ADDED
File without changes

.venv/lib/python3.11/site-packages/ray/dashboard/modules/node/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (199 Bytes)

.venv/lib/python3.11/site-packages/ray/dashboard/modules/node/__pycache__/node_consts.cpython-311.pyc
ADDED
Binary file (824 Bytes)
.venv/lib/python3.11/site-packages/ray/dashboard/modules/node/node_consts.py
ADDED
@@ -0,0 +1,17 @@
from ray._private.ray_constants import env_integer

NODE_STATS_UPDATE_INTERVAL_SECONDS = env_integer(
    "NODE_STATS_UPDATE_INTERVAL_SECONDS", 15
)
RAY_DASHBOARD_HEAD_NODE_REGISTRATION_TIMEOUT = env_integer(
    "RAY_DASHBOARD_HEAD_NODE_REGISTRATION_TIMEOUT", 10
)
MAX_COUNT_OF_GCS_RPC_ERROR = 10
# This is consistent with gcs_node_manager.cc
MAX_DEAD_NODES_TO_CACHE = env_integer("RAY_maximum_gcs_dead_node_cached_count", 1000)
RAY_DASHBOARD_NODE_SUBSCRIBER_POLL_SIZE = env_integer(
    "RAY_DASHBOARD_NODE_SUBSCRIBER_POLL_SIZE", 200
)
RAY_DASHBOARD_AGENT_POLL_INTERVAL_S = env_integer(
    "RAY_DASHBOARD_AGENT_POLL_INTERVAL_S", 1
)
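
Editor's note: each constant above is read through env_integer, so its value comes from an environment variable of the same name and falls back to the shown default when that variable is unset. A minimal sketch of overriding one of them follows; the value 30 is illustrative, not from the diff, and the variable must be set before any module that evaluates the constant is imported, because env_integer is read at import time.

import os

# Set before importing anything that evaluates the constant at import time.
os.environ["NODE_STATS_UPDATE_INTERVAL_SECONDS"] = "30"

from ray._private.ray_constants import env_integer

# Same pattern as node_consts.py: read an integer from the environment,
# falling back to the default (15) when the variable is unset.
interval = env_integer("NODE_STATS_UPDATE_INTERVAL_SECONDS", 15)
print(interval)  # prints 30 because of the override above
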
.venv/lib/python3.11/site-packages/ray/dashboard/modules/node/node_head.py
ADDED
@@ -0,0 +1,496 @@
import asyncio
import json
import logging
import time
from collections import deque
from concurrent.futures import ThreadPoolExecutor
from itertools import chain
from typing import AsyncGenerator, Iterable, List

import aiohttp.web
import grpc

import ray._private.utils
import ray.dashboard.consts as dashboard_consts
import ray.dashboard.optional_utils as dashboard_optional_utils
import ray.dashboard.utils as dashboard_utils
from ray._private import ray_constants
from ray._private.collections_utils import split
from ray._private.gcs_pubsub import GcsAioNodeInfoSubscriber
from ray._private.ray_constants import (
    DEBUG_AUTOSCALING_ERROR,
    DEBUG_AUTOSCALING_STATUS,
    env_integer,
)
from ray._private.gcs_pubsub import GcsAioResourceUsageSubscriber
from ray._private.utils import get_or_create_event_loop
from ray.autoscaler._private.util import (
    LoadMetricsSummary,
    get_per_node_breakdown_as_dict,
    parse_usage,
)
from ray.core.generated import gcs_pb2, node_manager_pb2, node_manager_pb2_grpc
from ray.dashboard.consts import GCS_RPC_TIMEOUT_SECONDS
from ray.dashboard.datacenter import DataOrganizer, DataSource
from ray.dashboard.modules.node import node_consts
from ray.dashboard.modules.node.node_consts import (
    RAY_DASHBOARD_HEAD_NODE_REGISTRATION_TIMEOUT,
)
from ray.dashboard.utils import async_loop_forever

logger = logging.getLogger(__name__)
routes = dashboard_optional_utils.DashboardHeadRouteTable


# NOTE: Executor in this head is intentionally constrained to just 1 thread by
#       default to limit its concurrency, therefore reducing potential for
#       GIL contention
RAY_DASHBOARD_NODE_HEAD_TPE_MAX_WORKERS = env_integer(
    "RAY_DASHBOARD_NODE_HEAD_TPE_MAX_WORKERS", 1
)


def _gcs_node_info_to_dict(message: gcs_pb2.GcsNodeInfo) -> dict:
    return dashboard_utils.message_to_dict(
        message, {"nodeId"}, always_print_fields_with_no_presence=True
    )


def node_stats_to_dict(message):
    decode_keys = {
        "actorId",
        "jobId",
        "taskId",
        "parentTaskId",
        "sourceActorId",
        "callerId",
        "rayletId",
        "workerId",
        "placementGroupId",
    }
    core_workers_stats = message.core_workers_stats
    message.ClearField("core_workers_stats")
    try:
        result = dashboard_utils.message_to_dict(message, decode_keys)
        result["coreWorkersStats"] = [
            dashboard_utils.message_to_dict(
                m, decode_keys, always_print_fields_with_no_presence=True
            )
            for m in core_workers_stats
        ]
        return result
    finally:
        message.core_workers_stats.extend(core_workers_stats)


class NodeHead(dashboard_utils.DashboardHeadModule):
    def __init__(self, config: dashboard_utils.DashboardHeadModuleConfig):
        super().__init__(config)

        self._stubs = {}
        self._collect_memory_info = False

        DataSource.nodes.signal.append(self._update_stubs)
        # The time where the module is started.
        self._module_start_time = time.time()
        # The time it takes until the head node is registered. None means
        # head node hasn't been registered.
        self._head_node_registration_time_s = None
        # Queue of dead nodes to be removed, up to MAX_DEAD_NODES_TO_CACHE
        self._dead_node_queue = deque()

        self._executor = ThreadPoolExecutor(
            max_workers=RAY_DASHBOARD_NODE_HEAD_TPE_MAX_WORKERS,
            thread_name_prefix="node_head_executor",
        )

    async def _update_stubs(self, change):
        if change.old:
            node_id, node_info = change.old
            self._stubs.pop(node_id, None)
        if change.new:
            # TODO(fyrestone): Handle exceptions.
            node_id, node_info = change.new
            address = "{}:{}".format(
                node_info["nodeManagerAddress"], int(node_info["nodeManagerPort"])
            )
            options = ray_constants.GLOBAL_GRPC_OPTIONS
            channel = ray._private.utils.init_grpc_channel(
                address, options, asynchronous=True
            )
            stub = node_manager_pb2_grpc.NodeManagerServiceStub(channel)
            self._stubs[node_id] = stub

    def get_internal_states(self):
        return {
            "head_node_registration_time_s": self._head_node_registration_time_s,
            "registered_nodes": len(DataSource.nodes),
            "registered_agents": len(DataSource.agents),
            "module_lifetime_s": time.time() - self._module_start_time,
        }

    async def _subscribe_for_node_updates(self) -> AsyncGenerator[dict, None]:
        """
        Yields the initial state of all nodes, then yields the updated state of nodes.

        It makes GetAllNodeInfo call only once after the subscription is done, to get
        the initial state of the nodes.
        """
        subscriber = GcsAioNodeInfoSubscriber(address=self.gcs_address)
        await subscriber.subscribe()

        # Get all node info from GCS. To prevent Time-of-check to time-of-use issue [1],
        # it happens after the subscription. That is, an update between
        # get-all-node-info and the subscription is not missed.
        # [1] https://en.wikipedia.org/wiki/Time-of-check_to_time-of-use
        all_node_info = await self.gcs_aio_client.get_all_node_info(timeout=None)

        def _convert_to_dict(messages: Iterable[gcs_pb2.GcsNodeInfo]) -> List[dict]:
            return [_gcs_node_info_to_dict(m) for m in messages]

        all_node_infos = await get_or_create_event_loop().run_in_executor(
            self._executor,
            _convert_to_dict,
            all_node_info.values(),
        )

        for node in all_node_infos:
            yield node

        while True:
            try:
                node_id_updated_info_tuples = await subscriber.poll(
                    batch_size=node_consts.RAY_DASHBOARD_NODE_SUBSCRIBER_POLL_SIZE
                )

                if node_id_updated_info_tuples:
                    _, updated_infos_proto = zip(*node_id_updated_info_tuples)
                else:
                    updated_infos_proto = []

                updated_infos = await get_or_create_event_loop().run_in_executor(
                    self._executor,
                    _convert_to_dict,
                    updated_infos_proto,
                )

                for node in updated_infos:
                    yield node
            except Exception:
                logger.exception("Failed handling updated nodes.")

    async def _update_node(self, node: dict):
        node_id = node["nodeId"]  # hex
        if node["isHeadNode"] and not self._head_node_registration_time_s:
            self._head_node_registration_time_s = time.time() - self._module_start_time
            # Put head node ID in the internal KV to be read by JobAgent.
            # TODO(architkulkarni): Remove once State API exposes which
            # node is the head node.
            await self.gcs_aio_client.internal_kv_put(
                ray_constants.KV_HEAD_NODE_ID_KEY,
                node_id.encode(),
                overwrite=True,
                namespace=ray_constants.KV_NAMESPACE_JOB,
                timeout=GCS_RPC_TIMEOUT_SECONDS,
            )
        assert node["state"] in ["ALIVE", "DEAD"]
        is_alive = node["state"] == "ALIVE"
        # Prepare agents for alive node, and pop agents for dead node.
        if is_alive:
            if node_id not in DataSource.agents:
                # Agent port is read from internal KV, which is only populated
                # upon Agent startup. In case this update received before agent
                # fully started up, we schedule a task to asynchronously update
                # DataSource with appropriate agent port.
                asyncio.create_task(self._update_agent(node_id))
        else:
            DataSource.agents.pop(node_id, None)
            self._dead_node_queue.append(node_id)
            if len(self._dead_node_queue) > node_consts.MAX_DEAD_NODES_TO_CACHE:
                DataSource.nodes.pop(self._dead_node_queue.popleft(), None)
        DataSource.nodes[node_id] = node

    async def _update_agent(self, node_id):
        """
        Given a node, update the agent_port in DataSource.agents. Problem is it's not
        present until agent.py starts, so we need to loop waiting for agent.py writes
        its port to internal kv.
        """
        key = (
            f"{dashboard_consts.DASHBOARD_AGENT_ADDR_NODE_ID_PREFIX}{node_id}".encode()
        )
        while True:
            try:
                agent_addr = await self.gcs_aio_client.internal_kv_get(
                    key,
                    namespace=ray_constants.KV_NAMESPACE_DASHBOARD,
                    timeout=None,
                )
                # The node may be dead already. Only update DataSource.agents if the
                # node is still alive.
                if DataSource.nodes.get(node_id, {}).get("state") != "ALIVE":
                    return
                if agent_addr:
                    DataSource.agents[node_id] = json.loads(agent_addr)
                    return
            except Exception:
                logger.exception(f"Error getting agent port for node {node_id}.")

            await asyncio.sleep(node_consts.RAY_DASHBOARD_AGENT_POLL_INTERVAL_S)

    async def _update_nodes(self):
        """
        Subscribe to node updates and update the internal states. If the head node is
        not registered after RAY_DASHBOARD_HEAD_NODE_REGISTRATION_TIMEOUT, it logs a
        warning only once.
        """
        warning_shown = False
        async for node in self._subscribe_for_node_updates():
            await self._update_node(node)
            if not self._head_node_registration_time_s:
                # head node is not registered yet
                if (
                    not warning_shown
                    and (time.time() - self._module_start_time)
                    > RAY_DASHBOARD_HEAD_NODE_REGISTRATION_TIMEOUT
                ):
                    logger.warning(
                        "Head node is not registered even after "
                        f"{RAY_DASHBOARD_HEAD_NODE_REGISTRATION_TIMEOUT} seconds. "
                        "The API server might not work correctly. Please "
                        "report a Github issue. Internal states :"
                        f"{self.get_internal_states()}"
                    )
                    warning_shown = True

    @routes.get("/internal/node_module")
    async def get_node_module_internal_state(self, req) -> aiohttp.web.Response:
        return dashboard_optional_utils.rest_response(
            success=True,
            message="",
            **self.get_internal_states(),
        )

    async def get_nodes_logical_resources(self) -> dict:

        from ray.autoscaler.v2.utils import is_autoscaler_v2

        if is_autoscaler_v2():
            from ray.autoscaler.v2.sdk import get_cluster_status

            try:
                cluster_status = get_cluster_status(self.gcs_address)
            except Exception:
                logger.exception("Error getting cluster status")
                return {}

            per_node_resources = {}
            # TODO(rickyx): we should just return structure data rather than strings.
            for node in chain(cluster_status.active_nodes, cluster_status.idle_nodes):
                if not node.resource_usage:
                    continue

                usage_dict = {
                    r.resource_name: (r.used, r.total)
                    for r in node.resource_usage.usage
                }
                per_node_resources[node.node_id] = "\n".join(
                    parse_usage(usage_dict, verbose=True)
                )

            return per_node_resources

        # Legacy autoscaler status code.
        (status_string, error) = await asyncio.gather(
            *[
                self.gcs_aio_client.internal_kv_get(
                    key.encode(), namespace=None, timeout=GCS_RPC_TIMEOUT_SECONDS
                )
                for key in [
                    DEBUG_AUTOSCALING_STATUS,
                    DEBUG_AUTOSCALING_ERROR,
                ]
            ]
        )
        if not status_string:
            return {}
        status_dict = json.loads(status_string)

        lm_summary_dict = status_dict.get("load_metrics_report")
        if lm_summary_dict:
            lm_summary = LoadMetricsSummary(**lm_summary_dict)

        node_logical_resources = get_per_node_breakdown_as_dict(lm_summary)
        return node_logical_resources if error is None else {}

    @routes.get("/nodes")
    @dashboard_optional_utils.aiohttp_cache
    async def get_all_nodes(self, req) -> aiohttp.web.Response:
        view = req.query.get("view")
        if view == "summary":
            all_node_summary_task = DataOrganizer.get_all_node_summary()
            nodes_logical_resource_task = self.get_nodes_logical_resources()

            all_node_summary, nodes_logical_resources = await asyncio.gather(
                all_node_summary_task, nodes_logical_resource_task
            )

            return dashboard_optional_utils.rest_response(
                success=True,
                message="Node summary fetched.",
                summary=all_node_summary,
                node_logical_resources=nodes_logical_resources,
            )
        elif view is not None and view.lower() == "hostNameList".lower():
            alive_hostnames = set()
            for node in DataSource.nodes.values():
                if node["state"] == "ALIVE":
                    alive_hostnames.add(node["nodeManagerHostname"])
            return dashboard_optional_utils.rest_response(
                success=True,
                message="Node hostname list fetched.",
                host_name_list=list(alive_hostnames),
            )
        else:
            return dashboard_optional_utils.rest_response(
                success=False, message=f"Unknown view {view}"
            )

    @routes.get("/nodes/{node_id}")
    @dashboard_optional_utils.aiohttp_cache
    async def get_node(self, req) -> aiohttp.web.Response:
        node_id = req.match_info.get("node_id")
        node_info = await DataOrganizer.get_node_info(node_id)
        return dashboard_optional_utils.rest_response(
            success=True, message="Node details fetched.", detail=node_info
        )

    @async_loop_forever(node_consts.NODE_STATS_UPDATE_INTERVAL_SECONDS)
    async def _update_node_stats(self):
        timeout = max(2, node_consts.NODE_STATS_UPDATE_INTERVAL_SECONDS - 1)

        # NOTE: We copy stubs to make sure
        #       it doesn't change during the iteration (since its being updated
        #       from another async task)
        current_stub_node_id_tuples = list(self._stubs.items())

        node_ids = []
        get_node_stats_tasks = []

        for _, (node_id, stub) in enumerate(current_stub_node_id_tuples):
            node_info = DataSource.nodes.get(node_id)
            if node_info["state"] != "ALIVE":
                continue

            node_ids.append(node_id)
            get_node_stats_tasks.append(
                stub.GetNodeStats(
                    node_manager_pb2.GetNodeStatsRequest(
                        include_memory_info=self._collect_memory_info
                    ),
                    timeout=timeout,
                )
            )

        responses = []

        # NOTE: We're chunking up fetching of the stats to run in batches of no more
        #       than 100 nodes at a time to avoid flooding the event-loop's queue
        #       with potentially a large, uninterrupted sequence of tasks updating
        #       the node stats for very large clusters.
        for get_node_stats_tasks_chunk in split(get_node_stats_tasks, 100):
            current_chunk_responses = await asyncio.gather(
                *get_node_stats_tasks_chunk,
                return_exceptions=True,
            )

            responses.extend(current_chunk_responses)

            # We're doing short (25ms) yield after every chunk to make sure
            # - We're not overloading the event-loop with excessive # of tasks
            # - Allowing 10k nodes stats fetches be sent out performed in 2.5s
            await asyncio.sleep(0.025)

        def postprocess(node_id_response_tuples):
            """Pure function reorganizing the data into {node_id: stats}."""
            new_node_stats = {}

            for node_id, response in node_id_response_tuples:
                if isinstance(response, asyncio.CancelledError):
                    pass
                elif isinstance(response, grpc.RpcError):
                    if response.code() == grpc.StatusCode.DEADLINE_EXCEEDED:
                        message = (
                            f"Cannot reach the node, {node_id}, after timeout "
                            f" {timeout}. This node may have been overloaded, "
                            "terminated, or the network is slow."
                        )
                    elif response.code() == grpc.StatusCode.UNAVAILABLE:
                        message = (
                            f"Cannot reach the node, {node_id}. "
                            "The node may have been terminated."
                        )
                    else:
                        message = f"Error updating node stats of {node_id}."

                    logger.error(message, exc_info=response)
                elif isinstance(response, Exception):
                    logger.error(
                        f"Error updating node stats of {node_id}.", exc_info=response
                    )
                else:
                    new_node_stats[node_id] = node_stats_to_dict(response)

            return new_node_stats

        # NOTE: Zip will silently truncate to shorter argument that potentially
        #       could lead to subtle hard to catch issues, hence the assertion
        assert len(node_ids) == len(responses)

        new_node_stats = await get_or_create_event_loop().run_in_executor(
            self._executor, postprocess, zip(node_ids, responses)
        )

        for node_id, new_stat in new_node_stats.items():
            DataSource.node_stats[node_id] = new_stat

    async def _update_node_physical_stats(self):
        """
        Update DataSource.node_physical_stats by subscribing to the GCS resource usage.
        """
        subscriber = GcsAioResourceUsageSubscriber(address=self.gcs_address)
        await subscriber.subscribe()

        loop = get_or_create_event_loop()

        while True:
            try:
                # The key is b'RAY_REPORTER:{node id hex}',
                # e.g. b'RAY_REPORTER:2b4fbd...'
                key, data = await subscriber.poll()
                if key is None:
                    continue

                # NOTE: Every iteration is executed inside the thread-pool executor
                #       (TPE) to avoid blocking the Dashboard's event-loop
                parsed_data = await loop.run_in_executor(
                    self._executor, json.loads, data
                )

                node_id = key.split(":")[-1]
                DataSource.node_physical_stats[node_id] = parsed_data
            except Exception:
                logger.exception(
                    "Error receiving node physical stats from _update_node_physical_stats."
                )

    async def run(self, server):
        await asyncio.gather(
            self._update_nodes(),
            self._update_node_stats(),
            self._update_node_physical_stats(),
        )

    @staticmethod
    def is_minimal_module():
        return False
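
Editor's note: NodeHead registers its HTTP handlers on the dashboard's route table, so the views above can be exercised with plain HTTP requests. A minimal sketch, assuming a running Ray dashboard at its default address (http://127.0.0.1:8265) and the third-party requests package; adjust the address for your cluster.

import requests

DASHBOARD = "http://127.0.0.1:8265"  # assumption: default dashboard address

# Summary view, served by NodeHead.get_all_nodes via routes.get("/nodes").
summary = requests.get(f"{DASHBOARD}/nodes", params={"view": "summary"})
print(summary.json())

# Hostname list of alive nodes, handled by the same endpoint.
hosts = requests.get(f"{DASHBOARD}/nodes", params={"view": "hostNameList"})
print(hosts.json())
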
.venv/lib/python3.11/site-packages/ray/util/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (3.33 kB)

.venv/lib/python3.11/site-packages/ray/util/__pycache__/actor_group.cpython-311.pyc
ADDED
Binary file (12.6 kB)

.venv/lib/python3.11/site-packages/ray/util/__pycache__/actor_pool.cpython-311.pyc
ADDED
Binary file (17.6 kB)

.venv/lib/python3.11/site-packages/ray/util/__pycache__/annotations.cpython-311.pyc
ADDED
Binary file (11.1 kB)

.venv/lib/python3.11/site-packages/ray/util/__pycache__/check_open_ports.cpython-311.pyc
ADDED
Binary file (9.19 kB)

.venv/lib/python3.11/site-packages/ray/util/__pycache__/check_serialize.cpython-311.pyc
ADDED
Binary file (11.6 kB)

.venv/lib/python3.11/site-packages/ray/util/__pycache__/client_connect.cpython-311.pyc
ADDED
Binary file (3.37 kB)

.venv/lib/python3.11/site-packages/ray/util/__pycache__/debug.cpython-311.pyc
ADDED
Binary file (10.1 kB)

.venv/lib/python3.11/site-packages/ray/util/__pycache__/debugpy.cpython-311.pyc
ADDED
Binary file (6.59 kB)

.venv/lib/python3.11/site-packages/ray/util/__pycache__/iter.cpython-311.pyc
ADDED
Binary file (65.9 kB)

.venv/lib/python3.11/site-packages/ray/util/__pycache__/iter_metrics.cpython-311.pyc
ADDED
Binary file (3.94 kB)

.venv/lib/python3.11/site-packages/ray/util/__pycache__/metrics.cpython-311.pyc
ADDED
Binary file (14.6 kB)

.venv/lib/python3.11/site-packages/ray/util/__pycache__/placement_group.cpython-311.pyc
ADDED
Binary file (24.2 kB)

.venv/lib/python3.11/site-packages/ray/util/__pycache__/queue.cpython-311.pyc
ADDED
Binary file (16.5 kB)

.venv/lib/python3.11/site-packages/ray/util/__pycache__/rpdb.cpython-311.pyc
ADDED
Binary file (19.5 kB)

.venv/lib/python3.11/site-packages/ray/util/__pycache__/scheduling_strategies.cpython-311.pyc
ADDED
Binary file (9.84 kB)

.venv/lib/python3.11/site-packages/ray/util/__pycache__/serialization.cpython-311.pyc
ADDED
Binary file (3.52 kB)

.venv/lib/python3.11/site-packages/ray/util/__pycache__/serialization_addons.cpython-311.pyc
ADDED
Binary file (1.96 kB)

.venv/lib/python3.11/site-packages/ray/util/__pycache__/timer.cpython-311.pyc
ADDED
Binary file (4.12 kB)
.venv/lib/python3.11/site-packages/ray/util/accelerators/__init__.py
ADDED
@@ -0,0 +1,78 @@
import warnings

from ray.util.accelerators import tpu
from ray.util.accelerators.accelerators import (
    NVIDIA_TESLA_V100,
    NVIDIA_TESLA_P100,
    NVIDIA_TESLA_T4,
    NVIDIA_TESLA_P4,
    NVIDIA_TESLA_K80,
    NVIDIA_TESLA_A10G,
    NVIDIA_L4,
    NVIDIA_A100,
    NVIDIA_H100,
    INTEL_MAX_1550,
    INTEL_MAX_1100,
    INTEL_GAUDI,
    AMD_INSTINCT_MI100,
    AMD_INSTINCT_MI210,
    AMD_INSTINCT_MI250,
    AMD_INSTINCT_MI250x,
    AMD_INSTINCT_MI300x,
    AMD_RADEON_R9_200_HD_7900,
    AMD_RADEON_HD_7900,
    AWS_NEURON_CORE,
    GOOGLE_TPU_V2,
    GOOGLE_TPU_V3,
    GOOGLE_TPU_V4,
    GOOGLE_TPU_V5P,
    GOOGLE_TPU_V5LITEPOD,
    GOOGLE_TPU_V6E,
)

__all__ = [
    "tpu",
    "NVIDIA_TESLA_V100",
    "NVIDIA_TESLA_P100",
    "NVIDIA_TESLA_T4",
    "NVIDIA_TESLA_P4",
    "NVIDIA_TESLA_K80",
    "NVIDIA_TESLA_A10G",
    "NVIDIA_L4",
    "NVIDIA_A100",
    "NVIDIA_A100_40G",
    "NVIDIA_A100_80G",
    "NVIDIA_H100",
    "INTEL_MAX_1550",
    "INTEL_MAX_1100",
    "INTEL_GAUDI",
    "AMD_INSTINCT_MI100",
    "AMD_INSTINCT_MI210",
    "AMD_INSTINCT_MI250",
    "AMD_INSTINCT_MI250x",
    "AMD_INSTINCT_MI300x",
    "AMD_RADEON_R9_200_HD_7900",
    "AMD_RADEON_HD_7900",
    "AWS_NEURON_CORE",
    "GOOGLE_TPU_V2",
    "GOOGLE_TPU_V3",
    "GOOGLE_TPU_V4",
    "GOOGLE_TPU_V5P",
    "GOOGLE_TPU_V5LITEPOD",
    "GOOGLE_TPU_V6E",
    # Deprecated
    "NVIDIA_TESLA_A100",
]


def __getattr__(name: str):
    if name == "NVIDIA_TESLA_A100":
        from ray.util.annotations import RayDeprecationWarning

        warnings.warn(
            "NVIDIA_TESLA_A100 is deprecated, use NVIDIA_A100 instead.",
            RayDeprecationWarning,
            stacklevel=2,
        )
        return NVIDIA_A100
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
.venv/lib/python3.11/site-packages/ray/util/accelerators/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (2.13 kB)

.venv/lib/python3.11/site-packages/ray/util/accelerators/__pycache__/accelerators.cpython-311.pyc
ADDED
Binary file (1.33 kB)

.venv/lib/python3.11/site-packages/ray/util/accelerators/__pycache__/tpu.cpython-311.pyc
ADDED
Binary file (1.95 kB)
.venv/lib/python3.11/site-packages/ray/util/accelerators/accelerators.py
ADDED
@@ -0,0 +1,33 @@
NVIDIA_TESLA_V100 = "V100"
NVIDIA_TESLA_P100 = "P100"
NVIDIA_TESLA_T4 = "T4"
NVIDIA_TESLA_P4 = "P4"
NVIDIA_TESLA_K80 = "K80"
NVIDIA_TESLA_A10G = "A10G"
NVIDIA_L4 = "L4"
NVIDIA_L40S = "L40S"
NVIDIA_A100 = "A100"
NVIDIA_H100 = "H100"
INTEL_MAX_1550 = "Intel-GPU-Max-1550"
INTEL_MAX_1100 = "Intel-GPU-Max-1100"
INTEL_GAUDI = "Intel-GAUDI"
AMD_INSTINCT_MI100 = "AMD-Instinct-MI100"
AMD_INSTINCT_MI250x = "AMD-Instinct-MI250X"
AMD_INSTINCT_MI250 = "AMD-Instinct-MI250X-MI250"
AMD_INSTINCT_MI210 = "AMD-Instinct-MI210"
AMD_INSTINCT_MI300x = "AMD-Instinct-MI300X-OAM"
AMD_RADEON_R9_200_HD_7900 = "AMD-Radeon-R9-200-HD-7900"
AMD_RADEON_HD_7900 = "AMD-Radeon-HD-7900"
AWS_NEURON_CORE = "aws-neuron-core"
GOOGLE_TPU_V2 = "TPU-V2"
GOOGLE_TPU_V3 = "TPU-V3"
GOOGLE_TPU_V4 = "TPU-V4"
GOOGLE_TPU_V5P = "TPU-V5P"
GOOGLE_TPU_V5LITEPOD = "TPU-V5LITEPOD"
GOOGLE_TPU_V6E = "TPU-V6E"

# Use these instead of NVIDIA_A100 if you need a specific accelerator size. Note that
# these labels are not auto-added to nodes, you'll have to add them manually in
# addition to the default A100 label if needed.
NVIDIA_A100_40G = "A100-40G"
NVIDIA_A100_80G = "A100-80G"
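
Editor's note: these constants are the string labels Ray's scheduler matches when a task or actor requests a specific accelerator. A short sketch of the usual way they are consumed; it uses the standard accelerator_type option of ray.remote and assumes the cluster actually has a node with a matching GPU, otherwise the task stays pending.

import ray
from ray.util.accelerators import NVIDIA_TESLA_V100

ray.init()

@ray.remote(num_gpus=1, accelerator_type=NVIDIA_TESLA_V100)
def train():
    # Scheduled only on a node whose GPU advertises the "V100" label.
    return "running on a V100 node"

print(ray.get(train.remote()))
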
.venv/lib/python3.11/site-packages/ray/util/accelerators/tpu.py
ADDED
@@ -0,0 +1,39 @@
from typing import Optional
from ray._private.accelerators import TPUAcceleratorManager
from ray.util.annotations import PublicAPI


@PublicAPI(stability="alpha")
def get_current_pod_name() -> Optional[str]:
    """
    Return the name of the TPU pod that the worker is a part of.

    Returns:
        The name of the TPU pod. Returns None if not part of a TPU pod.
    """
    tpu_name = TPUAcceleratorManager.get_current_node_tpu_name()
    if tpu_name == "":
        tpu_name = None
    return tpu_name


@PublicAPI(stability="alpha")
def get_current_pod_worker_count() -> Optional[int]:
    """
    Count the number of workers associated with the TPU pod that the worker belongs to.

    Returns:
        The total number of workers in the TPU pod. Returns None if the worker is not
        part of a TPU pod.
    """
    return TPUAcceleratorManager.get_num_workers_in_current_tpu_pod()


@PublicAPI(stability="alpha")
def get_num_tpu_chips_on_node() -> int:
    """
    Return the number of TPU chips on the node.

    Returns:
        The total number of chips on the TPU node. Returns 0 if none are found.
    """
    return TPUAcceleratorManager.get_current_node_num_accelerators()
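
Editor's note: a small sketch of calling the helpers above from inside a Ray task running on a TPU host. The resources={"TPU": 4} request is an assumption about how the cluster advertises TPU resources; on a non-TPU node get_current_pod_name() returns None and get_num_tpu_chips_on_node() returns 0, as documented above.

import ray
from ray.util.accelerators import tpu

@ray.remote(resources={"TPU": 4})  # assumption: TPU capacity exposed under the "TPU" resource key
def describe_tpu_host():
    return {
        "pod_name": tpu.get_current_pod_name(),
        "pod_worker_count": tpu.get_current_pod_worker_count(),
        "chips_on_node": tpu.get_num_tpu_chips_on_node(),
    }

ray.init()
print(ray.get(describe_tpu_host.remote()))
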
.venv/lib/python3.11/site-packages/ray/util/annotations.py
ADDED
@@ -0,0 +1,268 @@
from enum import Enum
from typing import Optional
import inspect
import sys
import warnings
from functools import wraps


class AnnotationType(Enum):
    PUBLIC_API = "PublicAPI"
    DEVELOPER_API = "DeveloperAPI"
    DEPRECATED = "Deprecated"
    UNKNOWN = "Unknown"


def PublicAPI(*args, **kwargs):
    """Annotation for documenting public APIs.

    Public APIs are classes and methods exposed to end users of Ray.

    If ``stability="alpha"``, the API can be used by advanced users who are
    tolerant to and expect breaking changes.

    If ``stability="beta"``, the API is still public and can be used by early
    users, but are subject to change.

    If ``stability="stable"``, the APIs will remain backwards compatible across
    minor Ray releases (e.g., Ray 1.4 -> 1.8).

    For a full definition of the stability levels, please refer to the
    :ref:`Ray API Stability definitions <api-stability>`.

    Args:
        stability: One of {"stable", "beta", "alpha"}.
        api_group: Optional. Used only for doc rendering purpose. APIs in the same group
            will be grouped together in the API doc pages.

    Examples:
        >>> from ray.util.annotations import PublicAPI
        >>> @PublicAPI
        ... def func(x):
        ...     return x

        >>> @PublicAPI(stability="beta")
        ... def func(y):
        ...     return y
    """
    if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
        return PublicAPI(stability="stable", api_group="Others")(args[0])

    if "stability" in kwargs:
        stability = kwargs["stability"]
        assert stability in ["stable", "beta", "alpha"], stability
    else:
        stability = "stable"
    api_group = kwargs.get("api_group", "Others")

    def wrap(obj):
        if stability in ["alpha", "beta"]:
            message = (
                f"**PublicAPI ({stability}):** This API is in {stability} "
                "and may change before becoming stable."
            )
            _append_doc(obj, message=message)

        _mark_annotated(obj, type=AnnotationType.PUBLIC_API, api_group=api_group)
        return obj

    return wrap


def DeveloperAPI(*args, **kwargs):
    """Annotation for documenting developer APIs.

    Developer APIs are lower-level methods explicitly exposed to advanced Ray
    users and library developers. Their interfaces may change across minor
    Ray releases.

    Examples:
        >>> from ray.util.annotations import DeveloperAPI
        >>> @DeveloperAPI
        ... def func(x):
        ...     return x
    """
    if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
        return DeveloperAPI()(args[0])

    def wrap(obj):
        _append_doc(
            obj,
            message="**DeveloperAPI:** This API may change across minor Ray releases.",
        )
        _mark_annotated(obj, type=AnnotationType.DEVELOPER_API)
        return obj

    return wrap


class RayDeprecationWarning(DeprecationWarning):
    """Specialized Deprecation Warning for fine grained filtering control"""

    pass


# By default, print the first occurrence of matching warnings for
# each module where the warning is issued (regardless of line number)
if not sys.warnoptions:
    warnings.filterwarnings("module", category=RayDeprecationWarning)


def Deprecated(*args, **kwargs):
    """Annotation for documenting a deprecated API.

    Deprecated APIs may be removed in future releases of Ray.

    Args:
        message: a message to help users understand the reason for the
            deprecation, and provide a migration path.

    Examples:
        >>> from ray.util.annotations import Deprecated
        >>> @Deprecated
        ... def func(x):
        ...     return x

        >>> @Deprecated(message="g() is deprecated because the API is error "
        ...     "prone. Please call h() instead.")
        ... def g(y):
        ...     return y
    """
    if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
        return Deprecated()(args[0])

    doc_message = (
        "**DEPRECATED**: This API is deprecated and may be removed "
        "in future Ray releases."
    )
    warning_message = (
        "This API is deprecated and may be removed in future Ray releases. "
        "You could suppress this warning by setting env variable "
        'PYTHONWARNINGS="ignore::DeprecationWarning"'
    )

    warning = kwargs.pop("warning", False)

    if "message" in kwargs:
        doc_message = doc_message + "\n" + kwargs["message"]
        warning_message = warning_message + "\n" + kwargs["message"]
        del kwargs["message"]

    if kwargs:
        raise ValueError("Unknown kwargs: {}".format(kwargs.keys()))

    def inner(obj):
        _append_doc(obj, message=doc_message, directive="warning")
        _mark_annotated(obj, type=AnnotationType.DEPRECATED)

        if not warning:
            return obj

        if inspect.isclass(obj):
            obj_init = obj.__init__

            def patched_init(*args, **kwargs):
                warnings.warn(warning_message, RayDeprecationWarning, stacklevel=2)
                return obj_init(*args, **kwargs)

            obj.__init__ = patched_init
            return obj
        else:
            # class method or function.
            @wraps(obj)
            def wrapper(*args, **kwargs):
                warnings.warn(warning_message, RayDeprecationWarning, stacklevel=2)
                return obj(*args, **kwargs)

            return wrapper

    return inner


def _append_doc(obj, *, message: str, directive: Optional[str] = None) -> str:
    if not obj.__doc__:
        obj.__doc__ = ""

    obj.__doc__ = obj.__doc__.rstrip()

    indent = _get_indent(obj.__doc__)
    obj.__doc__ += "\n\n"

    if directive is not None:
        obj.__doc__ += f"{' ' * indent}.. {directive}::\n\n"

        message = message.replace("\n", "\n" + " " * (indent + 4))
        obj.__doc__ += f"{' ' * (indent + 4)}{message}"
    else:
        message = message.replace("\n", "\n" + " " * (indent + 4))
        obj.__doc__ += f"{' ' * indent}{message}"
    obj.__doc__ += f"\n{' ' * indent}"


def _get_indent(docstring: str) -> int:
    """

    Example:
        >>> def f():
        ...     '''Docstring summary.'''
        >>> f.__doc__
        'Docstring summary.'
        >>> _get_indent(f.__doc__)
        0

        >>> def g(foo):
        ...     '''Docstring summary.
        ...
        ...     Args:
        ...         foo: Does bar.
        ...     '''
        >>> g.__doc__
        'Docstring summary.\\n\\n    Args:\\n        foo: Does bar.\\n    '
        >>> _get_indent(g.__doc__)
        4

        >>> class A:
        ...     def h():
        ...         '''Docstring summary.
        ...
        ...         Returns:
        ...             None.
        ...         '''
        >>> A.h.__doc__
        'Docstring summary.\\n\\n        Returns:\\n            None.\\n        '
        >>> _get_indent(A.h.__doc__)
        8
    """
    if not docstring:
        return 0

    non_empty_lines = list(filter(bool, docstring.splitlines()))
    if len(non_empty_lines) == 1:
        # Docstring contains summary only.
        return 0

    # The docstring summary isn't indented, so check the indentation of the second
    # non-empty line.
    return len(non_empty_lines[1]) - len(non_empty_lines[1].lstrip())


def _mark_annotated(
    obj, type: AnnotationType = AnnotationType.UNKNOWN, api_group="Others"
) -> None:
    # Set magic token for check_api_annotations linter.
    if hasattr(obj, "__name__"):
        obj._annotated = obj.__name__
    obj._annotated_type = type
    obj._annotated_api_group = api_group


def _is_annotated(obj) -> bool:
    # Check the magic token exists and applies to this class (not a subclass).
    return hasattr(obj, "_annotated") and obj._annotated == obj.__name__


def _get_annotation_type(obj) -> Optional[str]:
    if not _is_annotated(obj):
        return None

    return obj._annotated_type.value
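
Editor's note: a short sketch of the Deprecated decorator above used with warning=True, which appends a warning directive to the docstring and issues a RayDeprecationWarning each time the wrapped function is called. The function names here are hypothetical.

import warnings
from ray.util.annotations import Deprecated, RayDeprecationWarning

@Deprecated(message="Use new_api() instead.", warning=True)
def old_api(x):
    return x * 2

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    result = old_api(21)

print(result)  # 42: the wrapped function still runs normally
print(caught[0].category is RayDeprecationWarning)  # True: a warning was emitted on the call
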
.venv/lib/python3.11/site-packages/ray/util/client/api.py
ADDED
|
@@ -0,0 +1,406 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""This file defines the interface between the ray client worker
|
| 2 |
+
and the overall ray module API.
|
| 3 |
+
"""
|
| 4 |
+
import json
|
| 5 |
+
import logging
|
| 6 |
+
from concurrent.futures import Future
|
| 7 |
+
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Union
|
| 8 |
+
|
| 9 |
+
from ray._private import ray_option_utils
|
| 10 |
+
from ray.util.client.runtime_context import _ClientWorkerPropertyAPI
|
| 11 |
+
|
| 12 |
+
if TYPE_CHECKING:
|
| 13 |
+
from ray.actor import ActorClass
|
| 14 |
+
from ray.core.generated.ray_client_pb2 import DataResponse
|
| 15 |
+
from ray.remote_function import RemoteFunction
|
| 16 |
+
from ray.util.client.common import ClientActorHandle, ClientObjectRef, ClientStub
|
| 17 |
+
|
| 18 |
+
logger = logging.getLogger(__name__)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def _as_bytes(value):
|
| 22 |
+
if isinstance(value, str):
|
| 23 |
+
return value.encode("utf-8")
|
| 24 |
+
return value
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class _ClientAPI:
|
| 28 |
+
"""The Client-side methods corresponding to the ray API. Delegates
|
| 29 |
+
to the Client Worker that contains the connection to the ClientServer.
|
| 30 |
+
"""
|
| 31 |
+
|
| 32 |
+
def __init__(self, worker=None):
|
| 33 |
+
self.worker = worker
|
| 34 |
+
|
| 35 |
+
def get(self, vals, *, timeout=None):
|
| 36 |
+
"""get is the hook stub passed on to replace `ray.get`
|
| 37 |
+
|
| 38 |
+
Args:
|
| 39 |
+
vals: [Client]ObjectRef or list of these refs to retrieve.
|
| 40 |
+
timeout: Optional timeout in milliseconds
|
| 41 |
+
"""
|
| 42 |
+
return self.worker.get(vals, timeout=timeout)
|
| 43 |
+
|
| 44 |
+
def put(self, *args, **kwargs):
|
| 45 |
+
"""put is the hook stub passed on to replace `ray.put`
|
| 46 |
+
|
| 47 |
+
Args:
|
| 48 |
+
val: The value to `put`.
|
| 49 |
+
args: opaque arguments
|
| 50 |
+
kwargs: opaque keyword arguments
|
| 51 |
+
"""
|
| 52 |
+
return self.worker.put(*args, **kwargs)
|
| 53 |
+
|
| 54 |
+
def wait(self, *args, **kwargs):
|
| 55 |
+
"""wait is the hook stub passed on to replace `ray.wait`
|
| 56 |
+
|
| 57 |
+
Args:
|
| 58 |
+
args: opaque arguments
|
| 59 |
+
kwargs: opaque keyword arguments
|
| 60 |
+
"""
|
| 61 |
+
return self.worker.wait(*args, **kwargs)
|
| 62 |
+
|
| 63 |
+
def remote(self, *args, **kwargs):
|
| 64 |
+
"""remote is the hook stub passed on to replace `ray.remote`.
|
| 65 |
+
|
| 66 |
+
This sets up remote functions or actors, as the decorator,
|
| 67 |
+
but does not execute them.
|
| 68 |
+
|
| 69 |
+
Args:
|
| 70 |
+
args: opaque arguments
|
| 71 |
+
kwargs: opaque keyword arguments
|
| 72 |
+
"""
|
| 73 |
+
# Delayed import to avoid a cyclic import
|
| 74 |
+
from ray.util.client.common import remote_decorator
|
| 75 |
+
|
| 76 |
+
if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
|
| 77 |
+
# This is the case where the decorator is just @ray.remote.
|
| 78 |
+
return remote_decorator(options=None)(args[0])
|
| 79 |
+
assert (
|
| 80 |
+
len(args) == 0 and len(kwargs) > 0
|
| 81 |
+
), ray_option_utils.remote_args_error_string
|
| 82 |
+
return remote_decorator(options=kwargs)
|
| 83 |
+
|
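As an aside, a minimal sketch of the two call patterns this hook accepts (not part of the diff; it assumes `ray` is the connected client package from `ray.util.client`):

    @ray.remote                      # bare form: args == (func,), kwargs == {}
    def double(x):
        return 2 * x

    @ray.remote(num_cpus=2)          # options form: args == (), kwargs == {"num_cpus": 2}
    def triple(x):
        return 3 * x

    # Mixing positional and keyword arguments trips the assertion above.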
| 84 |
+
# TODO(mwtian): consider adding _internal_ prefix to call_remote /
|
| 85 |
+
# call_release / call_retain.
|
| 86 |
+
def call_remote(self, instance: "ClientStub", *args, **kwargs) -> List[Future]:
|
| 87 |
+
"""call_remote is called by stub objects to execute them remotely.
|
| 88 |
+
|
| 89 |
+
This is used by stub objects in situations where they're called
|
| 90 |
+
with .remote, eg, `f.remote()` or `actor_cls.remote()`.
|
| 91 |
+
This allows the client stub objects to delegate execution to be
|
| 92 |
+
implemented in the most effective way whether it's in the client,
|
| 93 |
+
clientserver, or raylet worker.
|
| 94 |
+
|
| 95 |
+
Args:
|
| 96 |
+
instance: The Client-side stub reference to a remote object
|
| 97 |
+
args: opaque arguments
|
| 98 |
+
kwargs: opaque keyword arguments
|
| 99 |
+
"""
|
| 100 |
+
return self.worker.call_remote(instance, *args, **kwargs)
|
| 101 |
+
|
| 102 |
+
def call_release(self, id: bytes) -> None:
|
| 103 |
+
"""Attempts to release an object reference.
|
| 104 |
+
|
| 105 |
+
When client references are destructed, they release their reference,
|
| 106 |
+
which can opportunistically send a notification through the datachannel
|
| 107 |
+
to release the reference being held for that object on the server.
|
| 108 |
+
|
| 109 |
+
Args:
|
| 110 |
+
id: The id of the reference to release on the server side.
|
| 111 |
+
"""
|
| 112 |
+
return self.worker.call_release(id)
|
| 113 |
+
|
| 114 |
+
def call_retain(self, id: bytes) -> None:
|
| 115 |
+
"""Attempts to retain a client object reference.
|
| 116 |
+
|
| 117 |
+
Increments the reference count on the client side, to prevent
|
| 118 |
+
the client worker from attempting to release the server reference.
|
| 119 |
+
|
| 120 |
+
Args:
|
| 121 |
+
id: The id of the reference to retain on the client side.
|
| 122 |
+
"""
|
| 123 |
+
return self.worker.call_retain(id)
|
| 124 |
+
|
| 125 |
+
def close(self) -> None:
|
| 126 |
+
"""close cleans up an API connection by closing any channels or
|
| 127 |
+
shutting down any servers gracefully.
|
| 128 |
+
"""
|
| 129 |
+
return self.worker.close()
|
| 130 |
+
|
| 131 |
+
def get_actor(
|
| 132 |
+
self, name: str, namespace: Optional[str] = None
|
| 133 |
+
) -> "ClientActorHandle":
|
| 134 |
+
"""Returns a handle to an actor by name.
|
| 135 |
+
|
| 136 |
+
Args:
|
| 137 |
+
name: The name passed to this actor by
|
| 138 |
+
Actor.options(name="name").remote()
|
| 139 |
+
"""
|
| 140 |
+
return self.worker.get_actor(name, namespace)
|
| 141 |
+
|
| 142 |
+
def list_named_actors(self, all_namespaces: bool = False) -> List[str]:
|
| 143 |
+
"""List all named actors in the system.
|
| 144 |
+
|
| 145 |
+
Actors must have been created with Actor.options(name="name").remote().
|
| 146 |
+
This works for both detached & non-detached actors.
|
| 147 |
+
|
| 148 |
+
By default, only actors in the current namespace will be returned
|
| 149 |
+
and the returned entries will simply be their name.
|
| 150 |
+
|
| 151 |
+
If `all_namespaces` is set to True, all actors in the cluster will be
|
| 152 |
+
returned regardless of namespace, and the returned entries will be of
|
| 153 |
+
the form '<namespace>/<name>'.
|
| 154 |
+
"""
|
| 155 |
+
return self.worker.list_named_actors(all_namespaces)
|
| 156 |
+
|
| 157 |
+
def kill(self, actor: "ClientActorHandle", *, no_restart=True):
|
| 158 |
+
"""kill forcibly stops an actor running in the cluster
|
| 159 |
+
|
| 160 |
+
Args:
|
| 161 |
+
no_restart: Whether this actor should be restarted if it's a
|
| 162 |
+
restartable actor.
|
| 163 |
+
"""
|
| 164 |
+
return self.worker.terminate_actor(actor, no_restart)
|
| 165 |
+
|
| 166 |
+
def cancel(self, obj: "ClientObjectRef", *, force=False, recursive=True):
|
| 167 |
+
"""Cancels a task on the cluster.
|
| 168 |
+
|
| 169 |
+
If the specified task is pending execution, it will not be executed. If
|
| 170 |
+
the task is currently executing, the behavior depends on the ``force``
|
| 171 |
+
flag, as per `ray.cancel()`
|
| 172 |
+
|
| 173 |
+
Only non-actor tasks can be canceled. Canceled tasks will not be
|
| 174 |
+
retried (max_retries will not be respected).
|
| 175 |
+
|
| 176 |
+
Args:
|
| 177 |
+
object_ref: ObjectRef returned by the task
|
| 178 |
+
that should be canceled.
|
| 179 |
+
force: Whether to force-kill a running task by killing
|
| 180 |
+
the worker that is running the task.
|
| 181 |
+
recursive: Whether to try to cancel tasks submitted by
|
| 182 |
+
the task specified.
|
| 183 |
+
"""
|
| 184 |
+
return self.worker.terminate_task(obj, force, recursive)
|
| 185 |
+
|
| 186 |
+
# Various metadata methods for the client that are defined in the protocol.
|
| 187 |
+
def is_initialized(self) -> bool:
|
| 188 |
+
"""True if our client is connected, and if the server is initialized.
|
| 189 |
+
Returns:
|
| 190 |
+
A boolean determining if the client is connected and
|
| 191 |
+
server initialized.
|
| 192 |
+
"""
|
| 193 |
+
return self.worker.is_initialized()
|
| 194 |
+
|
| 195 |
+
def nodes(self):
|
| 196 |
+
"""Get a list of the nodes in the cluster (for debugging only).
|
| 197 |
+
|
| 198 |
+
Returns:
|
| 199 |
+
Information about the Ray clients in the cluster.
|
| 200 |
+
"""
|
| 201 |
+
# Imported here rather than at module level; otherwise it breaks the doc build.
|
| 202 |
+
import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
| 203 |
+
|
| 204 |
+
return self.worker.get_cluster_info(ray_client_pb2.ClusterInfoType.NODES)
|
| 205 |
+
|
| 206 |
+
def method(self, *args, **kwargs):
|
| 207 |
+
"""Annotate an actor method
|
| 208 |
+
|
| 209 |
+
Args:
|
| 210 |
+
num_returns: The number of object refs that should be returned by
|
| 211 |
+
invocations of this actor method.
|
| 212 |
+
"""
|
| 213 |
+
|
| 214 |
+
# NOTE: So this follows the same logic as in ray/actor.py::method()
|
| 215 |
+
# The reason to duplicate it here is to simplify the client mode
|
| 216 |
+
# redirection logic. As the annotated method gets pickled and sent to
|
| 217 |
+
# the server from the client it carries this private variable, it
|
| 218 |
+
# activates the same logic on the server side; so there's no need to
|
| 219 |
+
# pass anything else. It's inside the class definition that becomes an
|
| 220 |
+
# actor. Similar annotations would follow the same way.
|
| 221 |
+
valid_kwargs = ["num_returns", "concurrency_group"]
|
| 222 |
+
error_string = (
|
| 223 |
+
"The @ray.method decorator must be applied using at least one of "
|
| 224 |
+
f"the arguments in the list {valid_kwargs}, for example "
|
| 225 |
+
"'@ray.method(num_returns=2)'."
|
| 226 |
+
)
|
| 227 |
+
assert len(args) == 0 and len(kwargs) > 0, error_string
|
| 228 |
+
for key in kwargs:
|
| 229 |
+
key_error_string = (
|
| 230 |
+
f'Unexpected keyword argument to @ray.method: "{key}". The '
|
| 231 |
+
f"supported keyword arguments are {valid_kwargs}"
|
| 232 |
+
)
|
| 233 |
+
assert key in valid_kwargs, key_error_string
|
| 234 |
+
|
| 235 |
+
def annotate_method(method):
|
| 236 |
+
if "num_returns" in kwargs:
|
| 237 |
+
method.__ray_num_returns__ = kwargs["num_returns"]
|
| 238 |
+
if "concurrency_group" in kwargs:
|
| 239 |
+
method.__ray_concurrency_group__ = kwargs["concurrency_group"]
|
| 240 |
+
return method
|
| 241 |
+
|
| 242 |
+
return annotate_method
|
| 243 |
+
|
| 244 |
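A brief hedged sketch of what the annotation above leaves behind (illustrative only; `Counter` is a made-up class):

    class Counter:
        @ray.method(num_returns=2)
        def split(self, pair):
            return pair[0], pair[1]

    # annotate_method tags the function in place:
    assert Counter.split.__ray_num_returns__ == 2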
+
def cluster_resources(self):
|
| 245 |
+
"""Get the current total cluster resources.
|
| 246 |
+
|
| 247 |
+
Note that this information can grow stale as nodes are added to or
|
| 248 |
+
removed from the cluster.
|
| 249 |
+
|
| 250 |
+
Returns:
|
| 251 |
+
A dictionary mapping resource name to the total quantity of that
|
| 252 |
+
resource in the cluster.
|
| 253 |
+
"""
|
| 254 |
+
# Imported here rather than at module level; otherwise it breaks the doc build.
|
| 255 |
+
import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
| 256 |
+
|
| 257 |
+
return self.worker.get_cluster_info(
|
| 258 |
+
ray_client_pb2.ClusterInfoType.CLUSTER_RESOURCES
|
| 259 |
+
)
|
| 260 |
+
|
| 261 |
+
def available_resources(self):
|
| 262 |
+
"""Get the current available cluster resources.
|
| 263 |
+
|
| 264 |
+
This is different from `cluster_resources` in that this will return
|
| 265 |
+
idle (available) resources rather than total resources.
|
| 266 |
+
|
| 267 |
+
Note that this information can grow stale as tasks start and finish.
|
| 268 |
+
|
| 269 |
+
Returns:
|
| 270 |
+
A dictionary mapping resource name to the total quantity of that
|
| 271 |
+
resource in the cluster.
|
| 272 |
+
"""
|
| 273 |
+
# Imported here rather than at module level; otherwise it breaks the doc build.
|
| 274 |
+
import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
| 275 |
+
|
| 276 |
+
return self.worker.get_cluster_info(
|
| 277 |
+
ray_client_pb2.ClusterInfoType.AVAILABLE_RESOURCES
|
| 278 |
+
)
|
| 279 |
+
|
| 280 |
+
def get_runtime_context(self):
|
| 281 |
+
"""Return a Ray RuntimeContext describing the state on the server
|
| 282 |
+
|
| 283 |
+
Returns:
|
| 284 |
+
A RuntimeContext wrapping a client making get_cluster_info calls.
|
| 285 |
+
"""
|
| 286 |
+
return _ClientWorkerPropertyAPI(self.worker).build_runtime_context()
|
| 287 |
+
|
| 288 |
+
# Client process isn't assigned any GPUs.
|
| 289 |
+
def get_gpu_ids(self) -> list:
|
| 290 |
+
return []
|
| 291 |
+
|
| 292 |
+
def timeline(self, filename: Optional[str] = None) -> Optional[List[Any]]:
|
| 293 |
+
logger.warning(
|
| 294 |
+
"Timeline will include events from other clients using this server."
|
| 295 |
+
)
|
| 296 |
+
# Imported here rather than at module level; otherwise it breaks the doc build.
|
| 297 |
+
import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
| 298 |
+
|
| 299 |
+
all_events = self.worker.get_cluster_info(
|
| 300 |
+
ray_client_pb2.ClusterInfoType.TIMELINE
|
| 301 |
+
)
|
| 302 |
+
if filename is not None:
|
| 303 |
+
with open(filename, "w") as outfile:
|
| 304 |
+
json.dump(all_events, outfile)
|
| 305 |
+
else:
|
| 306 |
+
return all_events
|
| 307 |
+
|
| 308 |
+
def _internal_kv_initialized(self) -> bool:
|
| 309 |
+
"""Hook for internal_kv._internal_kv_initialized."""
|
| 310 |
+
# NOTE(edoakes): the kv is always initialized because we initialize it
|
| 311 |
+
# manually in the proxier with a GCS client if Ray hasn't been
|
| 312 |
+
# initialized yet.
|
| 313 |
+
return True
|
| 314 |
+
|
| 315 |
+
def _internal_kv_exists(
|
| 316 |
+
self, key: Union[str, bytes], *, namespace: Optional[Union[str, bytes]] = None
|
| 317 |
+
) -> bool:
|
| 318 |
+
"""Hook for internal_kv._internal_kv_exists."""
|
| 319 |
+
return self.worker.internal_kv_exists(
|
| 320 |
+
_as_bytes(key), namespace=_as_bytes(namespace)
|
| 321 |
+
)
|
| 322 |
+
|
| 323 |
+
def _internal_kv_get(
|
| 324 |
+
self, key: Union[str, bytes], *, namespace: Optional[Union[str, bytes]] = None
|
| 325 |
+
) -> bytes:
|
| 326 |
+
"""Hook for internal_kv._internal_kv_get."""
|
| 327 |
+
return self.worker.internal_kv_get(
|
| 328 |
+
_as_bytes(key), namespace=_as_bytes(namespace)
|
| 329 |
+
)
|
| 330 |
+
|
| 331 |
+
def _internal_kv_put(
|
| 332 |
+
self,
|
| 333 |
+
key: Union[str, bytes],
|
| 334 |
+
value: Union[str, bytes],
|
| 335 |
+
overwrite: bool = True,
|
| 336 |
+
*,
|
| 337 |
+
namespace: Optional[Union[str, bytes]] = None,
|
| 338 |
+
) -> bool:
|
| 339 |
+
"""Hook for internal_kv._internal_kv_put."""
|
| 340 |
+
return self.worker.internal_kv_put(
|
| 341 |
+
_as_bytes(key), _as_bytes(value), overwrite, namespace=_as_bytes(namespace)
|
| 342 |
+
)
|
| 343 |
+
|
| 344 |
+
def _internal_kv_del(
|
| 345 |
+
self,
|
| 346 |
+
key: Union[str, bytes],
|
| 347 |
+
*,
|
| 348 |
+
del_by_prefix: bool = False,
|
| 349 |
+
namespace: Optional[Union[str, bytes]] = None,
|
| 350 |
+
) -> int:
|
| 351 |
+
"""Hook for internal_kv._internal_kv_del."""
|
| 352 |
+
return self.worker.internal_kv_del(
|
| 353 |
+
_as_bytes(key), del_by_prefix=del_by_prefix, namespace=_as_bytes(namespace)
|
| 354 |
+
)
|
| 355 |
+
|
| 356 |
+
def _internal_kv_list(
|
| 357 |
+
self,
|
| 358 |
+
prefix: Union[str, bytes],
|
| 359 |
+
*,
|
| 360 |
+
namespace: Optional[Union[str, bytes]] = None,
|
| 361 |
+
) -> List[bytes]:
|
| 362 |
+
"""Hook for internal_kv._internal_kv_list."""
|
| 363 |
+
return self.worker.internal_kv_list(
|
| 364 |
+
_as_bytes(prefix), namespace=_as_bytes(namespace)
|
| 365 |
+
)
|
| 366 |
+
|
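A hedged usage sketch of these KV hooks (assumes `worker` is an already-connected client worker; the key, value, and namespace below are made up):

    api = _ClientAPI(worker)
    api._internal_kv_put("job:42", "running", namespace="demo")
    assert api._internal_kv_exists("job:42", namespace="demo")
    print(api._internal_kv_get("job:42", namespace="demo"))    # expected b"running"
    api._internal_kv_del("job:42", namespace="demo")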
| 367 |
+
def _pin_runtime_env_uri(self, uri: str, expiration_s: int) -> None:
|
| 368 |
+
"""Hook for internal_kv._pin_runtime_env_uri."""
|
| 369 |
+
return self.worker.pin_runtime_env_uri(uri, expiration_s)
|
| 370 |
+
|
| 371 |
+
def _convert_actor(self, actor: "ActorClass") -> str:
|
| 372 |
+
"""Register a ClientActorClass for the ActorClass and return a UUID"""
|
| 373 |
+
return self.worker._convert_actor(actor)
|
| 374 |
+
|
| 375 |
+
def _convert_function(self, func: "RemoteFunction") -> str:
|
| 376 |
+
"""Register a ClientRemoteFunc for the ActorClass and return a UUID"""
|
| 377 |
+
return self.worker._convert_function(func)
|
| 378 |
+
|
| 379 |
+
def _get_converted(self, key: str) -> "ClientStub":
|
| 380 |
+
"""Given a UUID, return the converted object"""
|
| 381 |
+
return self.worker._get_converted(key)
|
| 382 |
+
|
| 383 |
+
def _converted_key_exists(self, key: str) -> bool:
|
| 384 |
+
"""Check if a key UUID is present in the store of converted objects."""
|
| 385 |
+
return self.worker._converted_key_exists(key)
|
| 386 |
+
|
| 387 |
+
def __getattr__(self, key: str):
|
| 388 |
+
if not key.startswith("_"):
|
| 389 |
+
raise NotImplementedError(
|
| 390 |
+
"Not available in Ray client: `ray.{}`. This method is only "
|
| 391 |
+
"available within Ray remote functions and is not yet "
|
| 392 |
+
"implemented in the client API.".format(key)
|
| 393 |
+
)
|
| 394 |
+
return self.__getattribute__(key)
|
| 395 |
+
|
| 396 |
+
def _register_callback(
|
| 397 |
+
self, ref: "ClientObjectRef", callback: Callable[["DataResponse"], None]
|
| 398 |
+
) -> None:
|
| 399 |
+
self.worker.register_callback(ref, callback)
|
| 400 |
+
|
| 401 |
+
def _get_dashboard_url(self) -> str:
|
| 402 |
+
import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
| 403 |
+
|
| 404 |
+
return self.worker.get_cluster_info(
|
| 405 |
+
ray_client_pb2.ClusterInfoType.DASHBOARD_URL
|
| 406 |
+
).get("dashboard_url", "")
|
.venv/lib/python3.11/site-packages/ray/util/client/client_app.py
ADDED
|
@@ -0,0 +1,90 @@
| 1 |
+
from ray.util.client import ray
|
| 2 |
+
from typing import Tuple
|
| 3 |
+
|
| 4 |
+
ray.connect("localhost:50051")
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
@ray.remote
|
| 8 |
+
class HelloActor:
|
| 9 |
+
def __init__(self):
|
| 10 |
+
self.count = 0
|
| 11 |
+
|
| 12 |
+
def say_hello(self, whom: str) -> Tuple[str, int]:
|
| 13 |
+
self.count += 1
|
| 14 |
+
return ("Hello " + whom, self.count)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
actor = HelloActor.remote()
|
| 18 |
+
s, count = ray.get(actor.say_hello.remote("you"))
|
| 19 |
+
print(s, count)
|
| 20 |
+
assert s == "Hello you"
|
| 21 |
+
assert count == 1
|
| 22 |
+
s, count = ray.get(actor.say_hello.remote("world"))
|
| 23 |
+
print(s, count)
|
| 24 |
+
assert s == "Hello world"
|
| 25 |
+
assert count == 2
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
@ray.remote
|
| 29 |
+
def plus2(x):
|
| 30 |
+
return x + 2
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
@ray.remote
|
| 34 |
+
def fact(x):
|
| 35 |
+
print(x, type(fact))
|
| 36 |
+
if x <= 0:
|
| 37 |
+
return 1
|
| 38 |
+
# This hits the "nested tasks" issue
|
| 39 |
+
# https://github.com/ray-project/ray/issues/3644
|
| 40 |
+
# So we're on the right track!
|
| 41 |
+
return ray.get(fact.remote(x - 1)) * x
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
@ray.remote
|
| 45 |
+
def get_nodes():
|
| 46 |
+
return ray.nodes() # Can access the full Ray API in remote methods.
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
print("Cluster nodes", ray.get(get_nodes.remote()))
|
| 50 |
+
print(ray.nodes())
|
| 51 |
+
|
| 52 |
+
objectref = ray.put("hello world")
|
| 53 |
+
|
| 54 |
+
# `ClientObjectRef(...)`
|
| 55 |
+
print(objectref)
|
| 56 |
+
|
| 57 |
+
# `hello world`
|
| 58 |
+
print(ray.get(objectref))
|
| 59 |
+
|
| 60 |
+
ref2 = plus2.remote(234)
|
| 61 |
+
# `ClientObjectRef(...)`
|
| 62 |
+
print(ref2)
|
| 63 |
+
# `236`
|
| 64 |
+
print(ray.get(ref2))
|
| 65 |
+
|
| 66 |
+
ref3 = fact.remote(20)
|
| 67 |
+
# `ClientObjectRef(...)`
|
| 68 |
+
print(ref3)
|
| 69 |
+
# `2432902008176640000`
|
| 70 |
+
print(ray.get(ref3))
|
| 71 |
+
|
| 72 |
+
# Reuse the cached ClientRemoteFunc object
|
| 73 |
+
ref4 = fact.remote(5)
|
| 74 |
+
# `120`
|
| 75 |
+
print(ray.get(ref4))
|
| 76 |
+
|
| 77 |
+
ref5 = fact.remote(10)
|
| 78 |
+
|
| 79 |
+
print([ref2, ref3, ref4, ref5])
|
| 80 |
+
# should return ref2, ref3, ref4
|
| 81 |
+
res = ray.wait([ref5, ref2, ref3, ref4], num_returns=3)
|
| 82 |
+
print(res)
|
| 83 |
+
assert [ref2, ref3, ref4] == res[0]
|
| 84 |
+
assert [ref5] == res[1]
|
| 85 |
+
|
| 86 |
+
# should return ref2, ref3, ref4, ref5
|
| 87 |
+
res = ray.wait([ref2, ref3, ref4, ref5], num_returns=4)
|
| 88 |
+
print(res)
|
| 89 |
+
assert [ref2, ref3, ref4, ref5] == res[0]
|
| 90 |
+
assert [] == res[1]
|
.venv/lib/python3.11/site-packages/ray/util/client/common.py
ADDED
|
@@ -0,0 +1,956 @@
| 1 |
+
import inspect
|
| 2 |
+
import logging
|
| 3 |
+
import os
|
| 4 |
+
import pickle
|
| 5 |
+
import threading
|
| 6 |
+
import uuid
|
| 7 |
+
from collections import OrderedDict
|
| 8 |
+
from concurrent.futures import Future
|
| 9 |
+
from dataclasses import dataclass
|
| 10 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
| 11 |
+
|
| 12 |
+
import grpc
|
| 13 |
+
|
| 14 |
+
import ray._raylet as raylet
|
| 15 |
+
import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
| 16 |
+
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
|
| 17 |
+
from ray._private import ray_constants
|
| 18 |
+
from ray._private.inspect_util import (
|
| 19 |
+
is_class_method,
|
| 20 |
+
is_cython,
|
| 21 |
+
is_function_or_method,
|
| 22 |
+
is_static_method,
|
| 23 |
+
)
|
| 24 |
+
from ray._private.signature import extract_signature, get_signature
|
| 25 |
+
from ray._private.utils import check_oversized_function
|
| 26 |
+
from ray.util.client import ray
|
| 27 |
+
from ray.util.client.options import validate_options
|
| 28 |
+
|
| 29 |
+
logger = logging.getLogger(__name__)
|
| 30 |
+
|
| 31 |
+
# The maximum field value for int32 ids -- which is also the maximum
|
| 32 |
+
# number of simultaneous in-flight requests.
|
| 33 |
+
INT32_MAX = (2**31) - 1
|
| 34 |
+
|
| 35 |
+
# gRPC status codes that the client shouldn't attempt to recover from
|
| 36 |
+
# Resource exhausted: Server is low on resources, or has hit the max number
|
| 37 |
+
# of client connections
|
| 38 |
+
# Invalid argument: Reserved for application errors
|
| 39 |
+
# Not found: Set if the client is attempting to reconnect to a session that
|
| 40 |
+
# does not exist
|
| 41 |
+
# Failed precondition: Reserved for application errors
|
| 42 |
+
# Aborted: Set when an error is serialized into the details of the context,
|
| 43 |
+
# signals that error should be deserialized on the client side
|
| 44 |
+
GRPC_UNRECOVERABLE_ERRORS = (
|
| 45 |
+
grpc.StatusCode.RESOURCE_EXHAUSTED,
|
| 46 |
+
grpc.StatusCode.INVALID_ARGUMENT,
|
| 47 |
+
grpc.StatusCode.NOT_FOUND,
|
| 48 |
+
grpc.StatusCode.FAILED_PRECONDITION,
|
| 49 |
+
grpc.StatusCode.ABORTED,
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
+
# TODO: Instead of just making the max message size large, the right thing to
|
| 53 |
+
# do is to split up the bytes representation of serialized data into multiple
|
| 54 |
+
# messages and reconstruct them on either end. That said, since clients are
|
| 55 |
+
# drivers and really just feed initial things in and final results out, (when
|
| 56 |
+
# not going to S3 or similar) then a large limit will suffice for many use
|
| 57 |
+
# cases.
|
| 58 |
+
#
|
| 59 |
+
# Currently, this is 2GiB, the max for a signed int.
|
| 60 |
+
GRPC_MAX_MESSAGE_SIZE = (2 * 1024 * 1024 * 1024) - 1
|
| 61 |
+
|
| 62 |
+
# 30 seconds because ELB timeout is 60 seconds
|
| 63 |
+
GRPC_KEEPALIVE_TIME_MS = 1000 * 30
|
| 64 |
+
|
| 65 |
+
# Long timeout because we do not want gRPC ending a connection.
|
| 66 |
+
GRPC_KEEPALIVE_TIMEOUT_MS = 1000 * 600
|
| 67 |
+
|
| 68 |
+
GRPC_OPTIONS = [
|
| 69 |
+
*ray_constants.GLOBAL_GRPC_OPTIONS,
|
| 70 |
+
("grpc.max_send_message_length", GRPC_MAX_MESSAGE_SIZE),
|
| 71 |
+
("grpc.max_receive_message_length", GRPC_MAX_MESSAGE_SIZE),
|
| 72 |
+
("grpc.keepalive_time_ms", GRPC_KEEPALIVE_TIME_MS),
|
| 73 |
+
("grpc.keepalive_timeout_ms", GRPC_KEEPALIVE_TIMEOUT_MS),
|
| 74 |
+
("grpc.keepalive_permit_without_calls", 1),
|
| 75 |
+
# Send an infinite number of pings
|
| 76 |
+
("grpc.http2.max_pings_without_data", 0),
|
| 77 |
+
("grpc.http2.min_ping_interval_without_data_ms", GRPC_KEEPALIVE_TIME_MS - 50),
|
| 78 |
+
# Allow many strikes
|
| 79 |
+
("grpc.http2.max_ping_strikes", 0),
|
| 80 |
+
]
|
| 81 |
+
|
| 82 |
+
CLIENT_SERVER_MAX_THREADS = float(os.getenv("RAY_CLIENT_SERVER_MAX_THREADS", 100))
|
| 83 |
+
|
| 84 |
+
# Large objects are chunked into 5 MiB messages, ref PR #35025
|
| 85 |
+
OBJECT_TRANSFER_CHUNK_SIZE = 5 * 2**20
|
| 86 |
+
|
| 87 |
+
# Warn the user if the object being transferred is larger than 2 GiB
|
| 88 |
+
OBJECT_TRANSFER_WARNING_SIZE = 2 * 2**30
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
class ClientObjectRef(raylet.ObjectRef):
|
| 92 |
+
def __init__(self, id: Union[bytes, Future]):
|
| 93 |
+
self._mutex = threading.Lock()
|
| 94 |
+
self._worker = ray.get_context().client_worker
|
| 95 |
+
self._id_future = None
|
| 96 |
+
if isinstance(id, bytes):
|
| 97 |
+
self._set_id(id)
|
| 98 |
+
elif isinstance(id, Future):
|
| 99 |
+
self._id_future = id
|
| 100 |
+
else:
|
| 101 |
+
raise TypeError("Unexpected type for id {}".format(id))
|
| 102 |
+
|
| 103 |
+
def __del__(self):
|
| 104 |
+
if self._worker is not None and self._worker.is_connected():
|
| 105 |
+
try:
|
| 106 |
+
if not self.is_nil():
|
| 107 |
+
self._worker.call_release(self.id)
|
| 108 |
+
except Exception:
|
| 109 |
+
logger.info(
|
| 110 |
+
"Exception in ObjectRef is ignored in destructor. "
|
| 111 |
+
"To receive this exception in application code, call "
|
| 112 |
+
"a method on the actor reference before its destructor "
|
| 113 |
+
"is run."
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
def binary(self):
|
| 117 |
+
self._wait_for_id()
|
| 118 |
+
return super().binary()
|
| 119 |
+
|
| 120 |
+
def hex(self):
|
| 121 |
+
self._wait_for_id()
|
| 122 |
+
return super().hex()
|
| 123 |
+
|
| 124 |
+
def is_nil(self):
|
| 125 |
+
self._wait_for_id()
|
| 126 |
+
return super().is_nil()
|
| 127 |
+
|
| 128 |
+
def __hash__(self):
|
| 129 |
+
self._wait_for_id()
|
| 130 |
+
return hash(self.id)
|
| 131 |
+
|
| 132 |
+
def task_id(self):
|
| 133 |
+
self._wait_for_id()
|
| 134 |
+
return super().task_id()
|
| 135 |
+
|
| 136 |
+
@property
|
| 137 |
+
def id(self):
|
| 138 |
+
return self.binary()
|
| 139 |
+
|
| 140 |
+
def future(self) -> Future:
|
| 141 |
+
fut = Future()
|
| 142 |
+
|
| 143 |
+
def set_future(data: Any) -> None:
|
| 144 |
+
"""Schedules a callback to set the exception or result
|
| 145 |
+
in the Future."""
|
| 146 |
+
|
| 147 |
+
if isinstance(data, Exception):
|
| 148 |
+
fut.set_exception(data)
|
| 149 |
+
else:
|
| 150 |
+
fut.set_result(data)
|
| 151 |
+
|
| 152 |
+
self._on_completed(set_future)
|
| 153 |
+
|
| 154 |
+
# Prevent this object ref from being released.
|
| 155 |
+
fut.object_ref = self
|
| 156 |
+
return fut
|
| 157 |
+
|
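A hedged usage sketch (assumes a connected client and an existing ClientObjectRef named `ref`): `future()` bridges the ref into the standard concurrent.futures machinery.

    from concurrent.futures import wait

    fut = ref.future()                 # wires set_future in via _on_completed
    done, _not_done = wait([fut], timeout=30)
    if done:
        print(fut.result())            # raises here if the task failed, per set_future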
| 158 |
+
def _on_completed(self, py_callback: Callable[[Any], None]) -> None:
|
| 159 |
+
"""Register a callback that will be called after Object is ready.
|
| 160 |
+
If the ObjectRef is already ready, the callback will be called soon.
|
| 161 |
+
The callback should take the result as the only argument. The result
|
| 162 |
+
can be an exception object in case of task error.
|
| 163 |
+
"""
|
| 164 |
+
|
| 165 |
+
def deserialize_obj(
|
| 166 |
+
resp: Union[ray_client_pb2.DataResponse, Exception]
|
| 167 |
+
) -> None:
|
| 168 |
+
from ray.util.client.client_pickler import loads_from_server
|
| 169 |
+
|
| 170 |
+
if isinstance(resp, Exception):
|
| 171 |
+
data = resp
|
| 172 |
+
elif isinstance(resp, bytearray):
|
| 173 |
+
data = loads_from_server(resp)
|
| 174 |
+
else:
|
| 175 |
+
obj = resp.get
|
| 176 |
+
data = None
|
| 177 |
+
if not obj.valid:
|
| 178 |
+
data = loads_from_server(resp.get.error)
|
| 179 |
+
else:
|
| 180 |
+
data = loads_from_server(resp.get.data)
|
| 181 |
+
|
| 182 |
+
py_callback(data)
|
| 183 |
+
|
| 184 |
+
self._worker.register_callback(self, deserialize_obj)
|
| 185 |
+
|
| 186 |
+
def _set_id(self, id):
|
| 187 |
+
super()._set_id(id)
|
| 188 |
+
self._worker.call_retain(id)
|
| 189 |
+
|
| 190 |
+
def _wait_for_id(self, timeout=None):
|
| 191 |
+
if self._id_future:
|
| 192 |
+
with self._mutex:
|
| 193 |
+
if self._id_future:
|
| 194 |
+
self._set_id(self._id_future.result(timeout=timeout))
|
| 195 |
+
self._id_future = None
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
class ClientActorRef(raylet.ActorID):
|
| 199 |
+
def __init__(
|
| 200 |
+
self,
|
| 201 |
+
id: Union[bytes, Future],
|
| 202 |
+
weak_ref: Optional[bool] = False,
|
| 203 |
+
):
|
| 204 |
+
self._weak_ref = weak_ref
|
| 205 |
+
self._mutex = threading.Lock()
|
| 206 |
+
self._worker = ray.get_context().client_worker
|
| 207 |
+
if isinstance(id, bytes):
|
| 208 |
+
self._set_id(id)
|
| 209 |
+
self._id_future = None
|
| 210 |
+
elif isinstance(id, Future):
|
| 211 |
+
self._id_future = id
|
| 212 |
+
else:
|
| 213 |
+
raise TypeError("Unexpected type for id {}".format(id))
|
| 214 |
+
|
| 215 |
+
def __del__(self):
|
| 216 |
+
if self._weak_ref:
|
| 217 |
+
return
|
| 218 |
+
|
| 219 |
+
if self._worker is not None and self._worker.is_connected():
|
| 220 |
+
try:
|
| 221 |
+
if not self.is_nil():
|
| 222 |
+
self._worker.call_release(self.id)
|
| 223 |
+
except Exception:
|
| 224 |
+
logger.debug(
|
| 225 |
+
"Exception from actor creation is ignored in destructor. "
|
| 226 |
+
"To receive this exception in application code, call "
|
| 227 |
+
"a method on the actor reference before its destructor "
|
| 228 |
+
"is run."
|
| 229 |
+
)
|
| 230 |
+
|
| 231 |
+
def binary(self):
|
| 232 |
+
self._wait_for_id()
|
| 233 |
+
return super().binary()
|
| 234 |
+
|
| 235 |
+
def hex(self):
|
| 236 |
+
self._wait_for_id()
|
| 237 |
+
return super().hex()
|
| 238 |
+
|
| 239 |
+
def is_nil(self):
|
| 240 |
+
self._wait_for_id()
|
| 241 |
+
return super().is_nil()
|
| 242 |
+
|
| 243 |
+
def __hash__(self):
|
| 244 |
+
self._wait_for_id()
|
| 245 |
+
return hash(self.id)
|
| 246 |
+
|
| 247 |
+
@property
|
| 248 |
+
def id(self):
|
| 249 |
+
return self.binary()
|
| 250 |
+
|
| 251 |
+
def _set_id(self, id):
|
| 252 |
+
super()._set_id(id)
|
| 253 |
+
self._worker.call_retain(id)
|
| 254 |
+
|
| 255 |
+
def _wait_for_id(self, timeout=None):
|
| 256 |
+
if self._id_future:
|
| 257 |
+
with self._mutex:
|
| 258 |
+
if self._id_future:
|
| 259 |
+
self._set_id(self._id_future.result(timeout=timeout))
|
| 260 |
+
self._id_future = None
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
class ClientStub:
|
| 264 |
+
pass
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
class ClientRemoteFunc(ClientStub):
|
| 268 |
+
"""A stub created on the Ray Client to represent a remote
|
| 269 |
+
function that can be executed on the cluster.
|
| 270 |
+
|
| 271 |
+
This class is allowed to be passed around between remote functions.
|
| 272 |
+
|
| 273 |
+
Args:
|
| 274 |
+
_func: The actual function to execute remotely
|
| 275 |
+
_name: The original name of the function
|
| 276 |
+
_ref: The ClientObjectRef of the pickled code of the function, _func
|
| 277 |
+
"""
|
| 278 |
+
|
| 279 |
+
def __init__(self, f, options=None):
|
| 280 |
+
self._lock = threading.Lock()
|
| 281 |
+
self._func = f
|
| 282 |
+
self._name = f.__name__
|
| 283 |
+
self._signature = get_signature(f)
|
| 284 |
+
self._ref = None
|
| 285 |
+
self._client_side_ref = ClientSideRefID.generate_id()
|
| 286 |
+
self._options = validate_options(options)
|
| 287 |
+
|
| 288 |
+
def __call__(self, *args, **kwargs):
|
| 289 |
+
raise TypeError(
|
| 290 |
+
"Remote function cannot be called directly. "
|
| 291 |
+
f"Use {self._name}.remote method instead"
|
| 292 |
+
)
|
| 293 |
+
|
| 294 |
+
def remote(self, *args, **kwargs):
|
| 295 |
+
# Check if supplied parameters match the function signature. Same case
|
| 296 |
+
# at the other callsites.
|
| 297 |
+
self._signature.bind(*args, **kwargs)
|
| 298 |
+
return return_refs(ray.call_remote(self, *args, **kwargs))
|
| 299 |
+
|
| 300 |
+
def options(self, **kwargs):
|
| 301 |
+
return OptionWrapper(self, kwargs)
|
| 302 |
+
|
| 303 |
+
def _remote(self, args=None, kwargs=None, **option_args):
|
| 304 |
+
if args is None:
|
| 305 |
+
args = []
|
| 306 |
+
if kwargs is None:
|
| 307 |
+
kwargs = {}
|
| 308 |
+
return self.options(**option_args).remote(*args, **kwargs)
|
| 309 |
+
|
| 310 |
+
def __repr__(self):
|
| 311 |
+
return "ClientRemoteFunc(%s, %s)" % (self._name, self._ref)
|
| 312 |
+
|
| 313 |
+
def _ensure_ref(self):
|
| 314 |
+
with self._lock:
|
| 315 |
+
if self._ref is None:
|
| 316 |
+
# While calling ray.put() on our function, if
|
| 317 |
+
# our function is recursive, it will attempt to
|
| 318 |
+
# encode the ClientRemoteFunc -- itself -- and
|
| 319 |
+
# infinitely recurse on _ensure_ref.
|
| 320 |
+
#
|
| 321 |
+
# So we set the state of the reference to be an
|
| 322 |
+
# in-progress self reference value, which
|
| 323 |
+
# the encoding can detect and handle correctly.
|
| 324 |
+
self._ref = InProgressSentinel()
|
| 325 |
+
data = ray.worker._dumps_from_client(self._func)
|
| 326 |
+
# Check pickled size before sending it to server, which is more
|
| 327 |
+
# efficient and can be done synchronously inside remote() call.
|
| 328 |
+
check_oversized_function(data, self._name, "remote function", None)
|
| 329 |
+
self._ref = ray.worker._put_pickled(
|
| 330 |
+
data, client_ref_id=self._client_side_ref.id
|
| 331 |
+
)
|
| 332 |
+
|
| 333 |
+
def _prepare_client_task(self) -> ray_client_pb2.ClientTask:
|
| 334 |
+
self._ensure_ref()
|
| 335 |
+
task = ray_client_pb2.ClientTask()
|
| 336 |
+
task.type = ray_client_pb2.ClientTask.FUNCTION
|
| 337 |
+
task.name = self._name
|
| 338 |
+
task.payload_id = self._ref.id
|
| 339 |
+
set_task_options(task, self._options, "baseline_options")
|
| 340 |
+
return task
|
| 341 |
+
|
| 342 |
+
def _num_returns(self) -> int:
|
| 343 |
+
if not self._options:
|
| 344 |
+
return None
|
| 345 |
+
return self._options.get("num_returns")
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
class ClientActorClass(ClientStub):
|
| 349 |
+
"""A stub created on the Ray Client to represent an actor class.
|
| 350 |
+
|
| 351 |
+
It is wrapped by ray.remote and can be executed on the cluster.
|
| 352 |
+
|
| 353 |
+
Args:
|
| 354 |
+
actor_cls: The actual class to execute remotely
|
| 355 |
+
_name: The original name of the class
|
| 356 |
+
_ref: The ClientObjectRef of the pickled `actor_cls`
|
| 357 |
+
"""
|
| 358 |
+
|
| 359 |
+
def __init__(self, actor_cls, options=None):
|
| 360 |
+
self.actor_cls = actor_cls
|
| 361 |
+
self._lock = threading.Lock()
|
| 362 |
+
self._name = actor_cls.__name__
|
| 363 |
+
self._init_signature = inspect.Signature(
|
| 364 |
+
parameters=extract_signature(actor_cls.__init__, ignore_first=True)
|
| 365 |
+
)
|
| 366 |
+
self._ref = None
|
| 367 |
+
self._client_side_ref = ClientSideRefID.generate_id()
|
| 368 |
+
self._options = validate_options(options)
|
| 369 |
+
|
| 370 |
+
def __call__(self, *args, **kwargs):
|
| 371 |
+
raise TypeError(
|
| 372 |
+
"Remote actor cannot be instantiated directly. "
|
| 373 |
+
f"Use {self._name}.remote() instead"
|
| 374 |
+
)
|
| 375 |
+
|
| 376 |
+
def _ensure_ref(self):
|
| 377 |
+
with self._lock:
|
| 378 |
+
if self._ref is None:
|
| 379 |
+
# As before, set the state of the reference to be an
|
| 380 |
+
# in-progress self reference value, which
|
| 381 |
+
# the encoding can detect and handle correctly.
|
| 382 |
+
self._ref = InProgressSentinel()
|
| 383 |
+
data = ray.worker._dumps_from_client(self.actor_cls)
|
| 384 |
+
# Check pickled size before sending it to server, which is more
|
| 385 |
+
# efficient and can be done synchronously inside remote() call.
|
| 386 |
+
check_oversized_function(data, self._name, "actor", None)
|
| 387 |
+
self._ref = ray.worker._put_pickled(
|
| 388 |
+
data, client_ref_id=self._client_side_ref.id
|
| 389 |
+
)
|
| 390 |
+
|
| 391 |
+
def remote(self, *args, **kwargs) -> "ClientActorHandle":
|
| 392 |
+
self._init_signature.bind(*args, **kwargs)
|
| 393 |
+
# Actually instantiate the actor
|
| 394 |
+
futures = ray.call_remote(self, *args, **kwargs)
|
| 395 |
+
assert len(futures) == 1
|
| 396 |
+
return ClientActorHandle(ClientActorRef(futures[0]), actor_class=self)
|
| 397 |
+
|
| 398 |
+
def options(self, **kwargs):
|
| 399 |
+
return ActorOptionWrapper(self, kwargs)
|
| 400 |
+
|
| 401 |
+
def _remote(self, args=None, kwargs=None, **option_args):
|
| 402 |
+
if args is None:
|
| 403 |
+
args = []
|
| 404 |
+
if kwargs is None:
|
| 405 |
+
kwargs = {}
|
| 406 |
+
return self.options(**option_args).remote(*args, **kwargs)
|
| 407 |
+
|
| 408 |
+
def __repr__(self):
|
| 409 |
+
return "ClientActorClass(%s, %s)" % (self._name, self._ref)
|
| 410 |
+
|
| 411 |
+
def __getattr__(self, key):
|
| 412 |
+
if key not in self.__dict__:
|
| 413 |
+
raise AttributeError("Not a class attribute")
|
| 414 |
+
raise NotImplementedError("static methods")
|
| 415 |
+
|
| 416 |
+
def _prepare_client_task(self) -> ray_client_pb2.ClientTask:
|
| 417 |
+
self._ensure_ref()
|
| 418 |
+
task = ray_client_pb2.ClientTask()
|
| 419 |
+
task.type = ray_client_pb2.ClientTask.ACTOR
|
| 420 |
+
task.name = self._name
|
| 421 |
+
task.payload_id = self._ref.id
|
| 422 |
+
set_task_options(task, self._options, "baseline_options")
|
| 423 |
+
return task
|
| 424 |
+
|
| 425 |
+
@staticmethod
|
| 426 |
+
def _num_returns() -> int:
|
| 427 |
+
return 1
|
| 428 |
+
|
| 429 |
+
|
| 430 |
+
class ClientActorHandle(ClientStub):
|
| 431 |
+
"""Client-side stub for instantiated actor.
|
| 432 |
+
|
| 433 |
+
A stub created on the Ray Client to represent a remote actor that
|
| 434 |
+
has been started on the cluster. This class is allowed to be passed
|
| 435 |
+
around between remote functions.
|
| 436 |
+
|
| 437 |
+
Args:
|
| 438 |
+
actor_ref: A reference to the running actor given to the client. This
|
| 439 |
+
is a serialized version of the actual handle as an opaque token.
|
| 440 |
+
"""
|
| 441 |
+
|
| 442 |
+
def __init__(
|
| 443 |
+
self,
|
| 444 |
+
actor_ref: ClientActorRef,
|
| 445 |
+
actor_class: Optional[ClientActorClass] = None,
|
| 446 |
+
):
|
| 447 |
+
self.actor_ref = actor_ref
|
| 448 |
+
self._dir: Optional[List[str]] = None
|
| 449 |
+
if actor_class is not None:
|
| 450 |
+
self._method_num_returns = {}
|
| 451 |
+
self._method_signatures = {}
|
| 452 |
+
for method_name, method_obj in inspect.getmembers(
|
| 453 |
+
actor_class.actor_cls, is_function_or_method
|
| 454 |
+
):
|
| 455 |
+
self._method_num_returns[method_name] = getattr(
|
| 456 |
+
method_obj, "__ray_num_returns__", None
|
| 457 |
+
)
|
| 458 |
+
self._method_signatures[method_name] = inspect.Signature(
|
| 459 |
+
parameters=extract_signature(
|
| 460 |
+
method_obj,
|
| 461 |
+
ignore_first=(
|
| 462 |
+
not (
|
| 463 |
+
is_class_method(method_obj)
|
| 464 |
+
or is_static_method(actor_class.actor_cls, method_name)
|
| 465 |
+
)
|
| 466 |
+
),
|
| 467 |
+
)
|
| 468 |
+
)
|
| 469 |
+
else:
|
| 470 |
+
self._method_num_returns = None
|
| 471 |
+
self._method_signatures = None
|
| 472 |
+
|
| 473 |
+
def __dir__(self) -> List[str]:
|
| 474 |
+
if self._method_num_returns is not None:
|
| 475 |
+
return self._method_num_returns.keys()
|
| 476 |
+
if ray.is_connected():
|
| 477 |
+
self._init_class_info()
|
| 478 |
+
return self._method_num_returns.keys()
|
| 479 |
+
return super().__dir__()
|
| 480 |
+
|
| 481 |
+
# For compatibility with core worker ActorHandle._actor_id which returns
|
| 482 |
+
# ActorID
|
| 483 |
+
@property
|
| 484 |
+
def _actor_id(self) -> ClientActorRef:
|
| 485 |
+
return self.actor_ref
|
| 486 |
+
|
| 487 |
+
def __hash__(self) -> int:
|
| 488 |
+
return hash(self._actor_id)
|
| 489 |
+
|
| 490 |
+
def __eq__(self, __value) -> bool:
|
| 491 |
+
return hash(self) == hash(__value)
|
| 492 |
+
|
| 493 |
+
def __getattr__(self, key):
|
| 494 |
+
if key == "_method_num_returns":
|
| 495 |
+
# We need to explicitly handle this value since it is used below,
|
| 496 |
+
# otherwise we may end up infinitely recursing when deserializing.
|
| 497 |
+
# This can happen after unpickling an object but before
|
| 498 |
+
# _method_num_returns is correctly populated.
|
| 499 |
+
raise AttributeError(f"ClientActorRef has no attribute '{key}'")
|
| 500 |
+
|
| 501 |
+
if self._method_num_returns is None:
|
| 502 |
+
self._init_class_info()
|
| 503 |
+
if key not in self._method_signatures:
|
| 504 |
+
raise AttributeError(f"ClientActorRef has no attribute '{key}'")
|
| 505 |
+
return ClientRemoteMethod(
|
| 506 |
+
self,
|
| 507 |
+
key,
|
| 508 |
+
self._method_num_returns.get(key),
|
| 509 |
+
self._method_signatures.get(key),
|
| 510 |
+
)
|
| 511 |
+
|
| 512 |
+
def __repr__(self):
|
| 513 |
+
return "ClientActorHandle(%s)" % (self.actor_ref.id.hex())
|
| 514 |
+
|
| 515 |
+
def _init_class_info(self):
|
| 516 |
+
# TODO: fetch Ray method decorators
|
| 517 |
+
@ray.remote(num_cpus=0)
|
| 518 |
+
def get_class_info(x):
|
| 519 |
+
return x._ray_method_num_returns, x._ray_method_signatures
|
| 520 |
+
|
| 521 |
+
self._method_num_returns, method_parameters = ray.get(
|
| 522 |
+
get_class_info.remote(self)
|
| 523 |
+
)
|
| 524 |
+
|
| 525 |
+
self._method_signatures = {}
|
| 526 |
+
for method, parameters in method_parameters.items():
|
| 527 |
+
self._method_signatures[method] = inspect.Signature(parameters=parameters)
|
| 528 |
+
|
| 529 |
+
|
| 530 |
+
class ClientRemoteMethod(ClientStub):
|
| 531 |
+
"""A stub for a method on a remote actor.
|
| 532 |
+
|
| 533 |
+
Can be annotated with execution options.
|
| 534 |
+
|
| 535 |
+
Args:
|
| 536 |
+
actor_handle: A reference to the ClientActorHandle that generated
|
| 537 |
+
this method and will have this method called upon it.
|
| 538 |
+
method_name: The name of this method
|
| 539 |
+
"""
|
| 540 |
+
|
| 541 |
+
def __init__(
|
| 542 |
+
self,
|
| 543 |
+
actor_handle: ClientActorHandle,
|
| 544 |
+
method_name: str,
|
| 545 |
+
num_returns: int,
|
| 546 |
+
signature: inspect.Signature,
|
| 547 |
+
):
|
| 548 |
+
self._actor_handle = actor_handle
|
| 549 |
+
self._method_name = method_name
|
| 550 |
+
self._method_num_returns = num_returns
|
| 551 |
+
self._signature = signature
|
| 552 |
+
|
| 553 |
+
def __call__(self, *args, **kwargs):
|
| 554 |
+
raise TypeError(
|
| 555 |
+
"Actor methods cannot be called directly. Instead "
|
| 556 |
+
f"of running 'object.{self._method_name}()', try "
|
| 557 |
+
f"'object.{self._method_name}.remote()'."
|
| 558 |
+
)
|
| 559 |
+
|
| 560 |
+
def remote(self, *args, **kwargs):
|
| 561 |
+
self._signature.bind(*args, **kwargs)
|
| 562 |
+
return return_refs(ray.call_remote(self, *args, **kwargs))
|
| 563 |
+
|
| 564 |
+
def __repr__(self):
|
| 565 |
+
return "ClientRemoteMethod(%s, %s, %s)" % (
|
| 566 |
+
self._method_name,
|
| 567 |
+
self._actor_handle,
|
| 568 |
+
self._method_num_returns,
|
| 569 |
+
)
|
| 570 |
+
|
| 571 |
+
def options(self, **kwargs):
|
| 572 |
+
return OptionWrapper(self, kwargs)
|
| 573 |
+
|
| 574 |
+
def _remote(self, args=None, kwargs=None, **option_args):
|
| 575 |
+
if args is None:
|
| 576 |
+
args = []
|
| 577 |
+
if kwargs is None:
|
| 578 |
+
kwargs = {}
|
| 579 |
+
return self.options(**option_args).remote(*args, **kwargs)
|
| 580 |
+
|
| 581 |
+
def _prepare_client_task(self) -> ray_client_pb2.ClientTask:
|
| 582 |
+
task = ray_client_pb2.ClientTask()
|
| 583 |
+
task.type = ray_client_pb2.ClientTask.METHOD
|
| 584 |
+
task.name = self._method_name
|
| 585 |
+
task.payload_id = self._actor_handle.actor_ref.id
|
| 586 |
+
return task
|
| 587 |
+
|
| 588 |
+
def _num_returns(self) -> int:
|
| 589 |
+
return self._method_num_returns
|
| 590 |
+
|
| 591 |
+
|
| 592 |
+
class OptionWrapper:
|
| 593 |
+
def __init__(self, stub: ClientStub, options: Optional[Dict[str, Any]]):
|
| 594 |
+
self._remote_stub = stub
|
| 595 |
+
self._options = validate_options(options)
|
| 596 |
+
|
| 597 |
+
def remote(self, *args, **kwargs):
|
| 598 |
+
self._remote_stub._signature.bind(*args, **kwargs)
|
| 599 |
+
return return_refs(ray.call_remote(self, *args, **kwargs))
|
| 600 |
+
|
| 601 |
+
def __getattr__(self, key):
|
| 602 |
+
return getattr(self._remote_stub, key)
|
| 603 |
+
|
| 604 |
+
def _prepare_client_task(self):
|
| 605 |
+
task = self._remote_stub._prepare_client_task()
|
| 606 |
+
set_task_options(task, self._options)
|
| 607 |
+
return task
|
| 608 |
+
|
| 609 |
+
def _num_returns(self) -> int:
|
| 610 |
+
if self._options:
|
| 611 |
+
num = self._options.get("num_returns")
|
| 612 |
+
if num is not None:
|
| 613 |
+
return num
|
| 614 |
+
return self._remote_stub._num_returns()
|
| 615 |
+
|
| 616 |
+
|
| 617 |
+
class ActorOptionWrapper(OptionWrapper):
|
| 618 |
+
def remote(self, *args, **kwargs):
|
| 619 |
+
self._remote_stub._init_signature.bind(*args, **kwargs)
|
| 620 |
+
futures = ray.call_remote(self, *args, **kwargs)
|
| 621 |
+
assert len(futures) == 1
|
| 622 |
+
actor_class = None
|
| 623 |
+
if isinstance(self._remote_stub, ClientActorClass):
|
| 624 |
+
actor_class = self._remote_stub
|
| 625 |
+
return ClientActorHandle(ClientActorRef(futures[0]), actor_class=actor_class)
|
| 626 |
+
|
| 627 |
+
|
| 628 |
+
def set_task_options(
|
| 629 |
+
task: ray_client_pb2.ClientTask,
|
| 630 |
+
options: Optional[Dict[str, Any]],
|
| 631 |
+
field: str = "options",
|
| 632 |
+
) -> None:
|
| 633 |
+
if options is None:
|
| 634 |
+
task.ClearField(field)
|
| 635 |
+
return
|
| 636 |
+
|
| 637 |
+
getattr(task, field).pickled_options = pickle.dumps(options)
|
| 638 |
+
|
| 639 |
+
|
| 640 |
+
def return_refs(
|
| 641 |
+
futures: List[Future],
|
| 642 |
+
) -> Union[None, ClientObjectRef, List[ClientObjectRef]]:
|
| 643 |
+
if not futures:
|
| 644 |
+
return None
|
| 645 |
+
if len(futures) == 1:
|
| 646 |
+
return ClientObjectRef(futures[0])
|
| 647 |
+
return [ClientObjectRef(fut) for fut in futures]
|
| 648 |
+
|
| 649 |
+
|
| 650 |
+
class InProgressSentinel:
|
| 651 |
+
def __repr__(self) -> str:
|
| 652 |
+
return self.__class__.__name__
|
| 653 |
+
|
| 654 |
+
|
| 655 |
+
class ClientSideRefID:
|
| 656 |
+
"""An ID generated by the client for objects not yet given an ObjectRef"""
|
| 657 |
+
|
| 658 |
+
def __init__(self, id: bytes):
|
| 659 |
+
assert len(id) != 0
|
| 660 |
+
self.id = id
|
| 661 |
+
|
| 662 |
+
@staticmethod
|
| 663 |
+
def generate_id() -> "ClientSideRefID":
|
| 664 |
+
tid = uuid.uuid4()
|
| 665 |
+
return ClientSideRefID(b"\xcc" + tid.bytes)
|
| 666 |
+
|
| 667 |
+
|
| 668 |
+
def remote_decorator(options: Optional[Dict[str, Any]]):
|
| 669 |
+
def decorator(function_or_class) -> ClientStub:
|
| 670 |
+
if inspect.isfunction(function_or_class) or is_cython(function_or_class):
|
| 671 |
+
return ClientRemoteFunc(function_or_class, options=options)
|
| 672 |
+
elif inspect.isclass(function_or_class):
|
| 673 |
+
return ClientActorClass(function_or_class, options=options)
|
| 674 |
+
else:
|
| 675 |
+
raise TypeError(
|
| 676 |
+
"The @ray.remote decorator must be applied to "
|
| 677 |
+
"either a function or to a class."
|
| 678 |
+
)
|
| 679 |
+
|
| 680 |
+
return decorator
|
| 681 |
+
|
| 682 |
+
|
| 683 |
+
@dataclass
|
| 684 |
+
class ClientServerHandle:
|
| 685 |
+
"""Holds the handles to the registered gRPC servicers and their server."""
|
| 686 |
+
|
| 687 |
+
task_servicer: ray_client_pb2_grpc.RayletDriverServicer
|
| 688 |
+
data_servicer: ray_client_pb2_grpc.RayletDataStreamerServicer
|
| 689 |
+
logs_servicer: ray_client_pb2_grpc.RayletLogStreamerServicer
|
| 690 |
+
grpc_server: grpc.Server
|
| 691 |
+
|
| 692 |
+
def stop(self, grace: int) -> None:
|
| 693 |
+
# The data servicer might be sleeping while waiting for clients to
|
| 694 |
+
# reconnect. Signal that they no longer have to sleep and can exit
|
| 695 |
+
# immediately, since the RPC server is stopped.
|
| 696 |
+
self.grpc_server.stop(grace)
|
| 697 |
+
self.data_servicer.stopped.set()
|
| 698 |
+
|
| 699 |
+
# Add a hook for all the cases that previously
|
| 700 |
+
# expected simply a gRPC server
|
| 701 |
+
def __getattr__(self, attr):
|
| 702 |
+
return getattr(self.grpc_server, attr)
|
| 703 |
+
|
| 704 |
+
|
| 705 |
+
def _get_client_id_from_context(context: Any) -> str:
|
| 706 |
+
"""
|
| 707 |
+
Get `client_id` from gRPC metadata. If the `client_id` is not present,
|
| 708 |
+
this function logs an error and sets the status_code.
|
| 709 |
+
"""
|
| 710 |
+
metadata = {k: v for k, v in context.invocation_metadata()}
|
| 711 |
+
client_id = metadata.get("client_id") or ""
|
| 712 |
+
if client_id == "":
|
| 713 |
+
logger.error("Client connecting with no client_id")
|
| 714 |
+
context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
|
| 715 |
+
return client_id
|
| 716 |
+
|
| 717 |
+
|
| 718 |
+
def _propagate_error_in_context(e: Exception, context: Any) -> bool:
|
| 719 |
+
"""
|
| 720 |
+
Encode an error into the context of an RPC response. Returns True
|
| 721 |
+
if the error can be recovered from, false otherwise
|
| 722 |
+
"""
|
| 723 |
+
try:
|
| 724 |
+
if isinstance(e, grpc.RpcError):
|
| 725 |
+
# RPC error, propagate directly by copying details into context
|
| 726 |
+
context.set_code(e.code())
|
| 727 |
+
context.set_details(e.details())
|
| 728 |
+
return e.code() not in GRPC_UNRECOVERABLE_ERRORS
|
| 729 |
+
except Exception:
|
| 730 |
+
# Extra precaution -- if encoding the RPC directly fails fallback
|
| 731 |
+
# to treating it as a regular error
|
| 732 |
+
pass
|
| 733 |
+
context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
|
| 734 |
+
context.set_details(str(e))
|
| 735 |
+
return False
|
| 736 |
+
|
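A hedged sketch of the intended call pattern inside a servicer method (the handler name and the empty reply are hypothetical stand-ins, not taken from this file):

    def SomeUnaryRpc(self, request, context):
        try:
            return self._real_handler(request)          # hypothetical real work
        except Exception as e:
            recoverable = _propagate_error_in_context(e, context)
            if not recoverable:
                logger.exception("Unrecoverable error in SomeUnaryRpc")
            # Status and details are already set on `context`; return an empty reply
            # (DataResponse used here only as a placeholder message type).
            return ray_client_pb2.DataResponse()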
| 737 |
+
|
| 738 |
+
def _id_is_newer(id1: int, id2: int) -> bool:
|
| 739 |
+
"""
|
| 740 |
+
We should only replace cache entries with the responses for newer IDs.
|
| 741 |
+
Most of the time newer IDs will be the ones with higher value, except when
|
| 742 |
+
the req_id counter rolls over. We check for this case by checking the
|
| 743 |
+
distance between the two IDs. If the distance is significant, then it's
|
| 744 |
+
likely that the req_id counter rolled over, and the smaller id should
|
| 745 |
+
still be used to replace the one in cache.
|
| 746 |
+
"""
|
| 747 |
+
diff = abs(id2 - id1)
|
| 748 |
+
if diff > (INT32_MAX // 2):
|
| 749 |
+
# Rollover likely occurred. In this case the smaller ID is newer
|
| 750 |
+
return id1 < id2
|
| 751 |
+
return id1 > id2
|
| 752 |
+
|
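A small worked example of the rollover check (a sketch with made-up ids, using the module-level INT32_MAX above):

    # Normal case: the larger value is newer.
    assert _id_is_newer(10, 9)

    # Rollover case: req_id wrapped from INT32_MAX back to 1, so the
    # numerically smaller id (1) is treated as newer than INT32_MAX.
    assert not _id_is_newer(INT32_MAX, 1)
    assert _id_is_newer(1, INT32_MAX)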
| 753 |
+
|
| 754 |
+
class ResponseCache:
|
| 755 |
+
"""
|
| 756 |
+
Cache for blocking method calls. Needed to prevent retried requests from
|
| 757 |
+
being applied multiple times on the server, for example when the client
|
| 758 |
+
disconnects. This is used to cache requests/responses sent through
|
| 759 |
+
unary-unary RPCs to the RayletServicer.
|
| 760 |
+
|
| 761 |
+
Note that no clean up logic is used, the last response for each thread
|
| 762 |
+
will always be remembered, so at most the cache will hold N entries,
|
| 763 |
+
where N is the number of threads on the client side. This relies on the
|
| 764 |
+
assumption that a thread will not make a new blocking request until it has
|
| 765 |
+
received a response for a previous one, at which point it's safe to
|
| 766 |
+
overwrite the old response.
|
| 767 |
+
|
| 768 |
+
The high level logic is:
|
| 769 |
+
|
| 770 |
+
1. Before making a call, check the cache for the current thread.
|
| 771 |
+
2. If present in the cache, check the request id of the cached
|
| 772 |
+
response.
|
| 773 |
+
a. If it matches the current request_id, then the request has been
|
| 774 |
+
received before and we shouldn't re-attempt the logic. Wait for
|
| 775 |
+
the response to become available in the cache, and then return it
|
| 776 |
+
b. If it doesn't match, then this is a new request and we can
|
| 777 |
+
proceed with calling the real stub. While the response is still
|
| 778 |
+
being generated, temporarily keep (req_id, None) in the cache.
|
| 779 |
+
Once the call is finished, update the cache entry with the
|
| 780 |
+
new (req_id, response) pair. Notify other threads that may
|
| 781 |
+
have been waiting for the response to be prepared.
|
| 782 |
+
"""
|
| 783 |
+
|
| 784 |
+
def __init__(self):
|
| 785 |
+
self.cv = threading.Condition()
|
| 786 |
+
self.cache: Dict[int, Tuple[int, Any]] = {}
|
| 787 |
+
|
| 788 |
+
def check_cache(self, thread_id: int, request_id: int) -> Optional[Any]:
|
| 789 |
+
"""
|
| 790 |
+
Check the cache for a given thread, and see if the entry in the cache
|
| 791 |
+
matches the current request_id. Returns None if the request_id has
|
| 792 |
+
not been seen yet, otherwise returns the cached result.
|
| 793 |
+
|
| 794 |
+
Throws an error if the placeholder in the cache doesn't match the
|
| 795 |
+
request_id -- this means that a new request evicted the old value in
|
| 796 |
+
the cache, and that the RPC for `request_id` is redundant and the
|
| 797 |
+
result can be discarded, i.e.:
|
| 798 |
+
|
| 799 |
+
1. Request A is sent (A1)
|
| 800 |
+
2. Channel disconnects
|
| 801 |
+
3. Request A is resent (A2)
|
| 802 |
+
4. A1 is received
|
| 803 |
+
5. A2 is received, waits for A1 to finish
|
| 804 |
+
6. A1 finishes and is sent back to client
|
| 805 |
+
7. Request B is sent
|
| 806 |
+
8. Request B overwrites cache entry
|
| 807 |
+
9. A2 wakes up extremely late, but cache is now invalid
|
| 808 |
+
|
| 809 |
+
In practice this is VERY unlikely to happen, but the error can at
|
| 810 |
+
least serve as a sanity check or catch invalid request ids.
|
| 811 |
+
"""
|
| 812 |
+
with self.cv:
|
| 813 |
+
if thread_id in self.cache:
|
| 814 |
+
cached_request_id, cached_resp = self.cache[thread_id]
|
| 815 |
+
if cached_request_id == request_id:
|
| 816 |
+
while cached_resp is None:
|
| 817 |
+
# The call was started, but the response hasn't yet
|
| 818 |
+
# been added to the cache. Let go of the lock and
|
| 819 |
+
# wait until the response is ready.
|
| 820 |
+
self.cv.wait()
|
| 821 |
+
cached_request_id, cached_resp = self.cache[thread_id]
|
| 822 |
+
if cached_request_id != request_id:
|
| 823 |
+
raise RuntimeError(
|
| 824 |
+
"Cached response doesn't match the id of the "
|
| 825 |
+
"original request. This might happen if this "
|
| 826 |
+
"request was received out of order. The "
|
| 827 |
+
"result of the caller is no longer needed. "
|
| 828 |
+
f"({request_id} != {cached_request_id})"
|
| 829 |
+
)
|
| 830 |
+
return cached_resp
|
| 831 |
+
if not _id_is_newer(request_id, cached_request_id):
|
| 832 |
+
raise RuntimeError(
|
| 833 |
+
"Attempting to replace newer cache entry with older "
|
| 834 |
+
"one. This might happen if this request was received "
|
| 835 |
+
"out of order. The result of the caller is no "
|
| 836 |
+
f"longer needed. ({request_id} != {cached_request_id}"
|
| 837 |
+
)
|
| 838 |
+
self.cache[thread_id] = (request_id, None)
|
| 839 |
+
return None
|
| 840 |
+
|
| 841 |
+
def update_cache(self, thread_id: int, request_id: int, response: Any) -> None:
|
| 842 |
+
"""
|
| 843 |
+
Inserts `response` into the cache for `request_id`.
|
| 844 |
+
"""
|
| 845 |
+
with self.cv:
|
| 846 |
+
cached_request_id, cached_resp = self.cache[thread_id]
|
| 847 |
+
if cached_request_id != request_id or cached_resp is not None:
|
| 848 |
+
# The cache was overwritten by a newer requester between
|
| 849 |
+
# our call to check_cache and our call to update it.
|
| 850 |
+
# This can't happen if the assumption that the cached requests
|
| 851 |
+
# are all blocking on the client side, so if you encounter
|
| 852 |
+
# this, check if any async requests are being cached.
|
| 853 |
+
raise RuntimeError(
|
| 854 |
+
"Attempting to update the cache, but placeholder's "
|
| 855 |
+
"do not match the current request_id. This might happen "
|
| 856 |
+
"if this request was received out of order. The result "
|
| 857 |
+
f"of the caller is no longer needed. ({request_id} != "
|
| 858 |
+
f"{cached_request_id})"
|
| 859 |
+
)
|
| 860 |
+
self.cache[thread_id] = (request_id, response)
|
| 861 |
+
self.cv.notify_all()
|
| 862 |
+
|
| 863 |
+
|
| 864 |
+
class OrderedResponseCache:
|
| 865 |
+
"""
|
| 866 |
+
Cache for streaming RPCs, i.e. the DataServicer. Relies on explicit
|
| 867 |
+
ack's from the client to determine when it can clean up cache entries.
|
| 868 |
+
"""
|
| 869 |
+
|
| 870 |
+
def __init__(self):
|
| 871 |
+
self.last_received = 0
|
| 872 |
+
self.cv = threading.Condition()
|
| 873 |
+
self.cache: Dict[int, Any] = OrderedDict()
|
| 874 |
+
|
| 875 |
+
def check_cache(self, req_id: int) -> Optional[Any]:
|
| 876 |
+
"""
|
| 877 |
+
Check the cache for a given thread, and see if the entry in the cache
|
| 878 |
+
matches the current request_id. Returns None if the request_id has
|
| 879 |
+
not been seen yet, otherwise returns the cached result.
|
| 880 |
+
"""
|
| 881 |
+
with self.cv:
|
| 882 |
+
if _id_is_newer(self.last_received, req_id) or self.last_received == req_id:
|
| 883 |
+
# Request is for an id that has already been cleared from
|
| 884 |
+
# cache/acknowledged.
|
| 885 |
+
raise RuntimeError(
|
| 886 |
+
"Attempting to accesss a cache entry that has already "
|
| 887 |
+
"cleaned up. The client has already acknowledged "
|
| 888 |
+
f"receiving this response. ({req_id}, "
|
| 889 |
+
f"{self.last_received})"
|
| 890 |
+
)
|
| 891 |
+
if req_id in self.cache:
|
| 892 |
+
cached_resp = self.cache[req_id]
|
| 893 |
+
while cached_resp is None:
|
| 894 |
+
# The call was started, but the response hasn't yet been
|
| 895 |
+
# added to the cache. Let go of the lock and wait until
|
| 896 |
+
# the response is ready
|
| 897 |
+
self.cv.wait()
|
| 898 |
+
if req_id not in self.cache:
|
| 899 |
+
raise RuntimeError(
|
| 900 |
+
"Cache entry was removed. This likely means that "
|
| 901 |
+
"the result of this call is no longer needed."
|
| 902 |
+
)
|
| 903 |
+
cached_resp = self.cache[req_id]
|
| 904 |
+
return cached_resp
|
| 905 |
+
self.cache[req_id] = None
|
| 906 |
+
return None
|
| 907 |
+
|
| 908 |
+
def update_cache(self, req_id: int, resp: Any) -> None:
|
| 909 |
+
"""
|
| 910 |
+
Inserts `response` into the cache for `request_id`.
|
| 911 |
+
"""
|
| 912 |
+
with self.cv:
|
| 913 |
+
self.cv.notify_all()
|
| 914 |
+
if req_id not in self.cache:
|
| 915 |
+
raise RuntimeError(
|
| 916 |
+
"Attempting to update the cache, but placeholder is "
|
| 917 |
+
"missing. This might happen on a redundant call to "
|
| 918 |
+
f"update_cache. ({req_id})"
|
| 919 |
+
)
|
| 920 |
+
self.cache[req_id] = resp
|
| 921 |
+
|
| 922 |
+
def invalidate(self, e: Exception) -> bool:
|
| 923 |
+
"""
|
| 924 |
+
Invalidate any partially populated cache entries, replacing their
|
| 925 |
+
placeholders with the passed in exception. Useful to prevent a thread
|
| 926 |
+
from waiting indefinitely on a failed call.
|
| 927 |
+
|
| 928 |
+
Returns True if the cache contains an error, False otherwise
|
| 929 |
+
"""
|
| 930 |
+
with self.cv:
|
| 931 |
+
invalid = False
|
| 932 |
+
for req_id in self.cache:
|
| 933 |
+
if self.cache[req_id] is None:
|
| 934 |
+
self.cache[req_id] = e
|
| 935 |
+
if isinstance(self.cache[req_id], Exception):
|
| 936 |
+
invalid = True
|
| 937 |
+
self.cv.notify_all()
|
| 938 |
+
return invalid
|
| 939 |
+
|
| 940 |
+
def cleanup(self, last_received: int) -> None:
|
| 941 |
+
"""
|
| 942 |
+
Cleanup all of the cached requests up to last_received. Assumes that
|
| 943 |
+
the cache entries were inserted in ascending order.
|
| 944 |
+
"""
|
| 945 |
+
with self.cv:
|
| 946 |
+
if _id_is_newer(last_received, self.last_received):
|
| 947 |
+
self.last_received = last_received
|
| 948 |
+
to_remove = []
|
| 949 |
+
for req_id in self.cache:
|
| 950 |
+
if _id_is_newer(last_received, req_id) or last_received == req_id:
|
| 951 |
+
to_remove.append(req_id)
|
| 952 |
+
else:
|
| 953 |
+
break
|
| 954 |
+
for req_id in to_remove:
|
| 955 |
+
del self.cache[req_id]
|
| 956 |
+
self.cv.notify_all()
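
# ----------------------------------------------------------------------------
# Editor's note: the snippet below is an illustrative sketch and is NOT part
# of the original file. It shows how the check_cache/update_cache protocol
# described in the ResponseCache docstring is meant to be driven by a caller;
# `do_work` is a hypothetical stand-in for the real servicer logic.
def _example_cached_call(cache, thread_id, request_id, do_work):
    # 1. See whether this (thread_id, request_id) pair has been seen before.
    cached = cache.check_cache(thread_id, request_id)
    if cached is not None:
        # A retried request whose response is already known: return it as-is
        # instead of re-running the server-side logic.
        return cached
    # 2. New request: (request_id, None) is now cached as a placeholder, so a
    #    concurrent retry will block in check_cache until the result is ready.
    response = do_work()
    # 3. Publish the response and wake up any retry waiting on the placeholder.
    cache.update_cache(thread_id, request_id, response)
    return response
# ----------------------------------------------------------------------------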
.venv/lib/python3.11/site-packages/ray/util/client/dataclient.py
ADDED
@@ -0,0 +1,599 @@
"""This file implements a threaded stream controller to abstract a data stream
back to the Ray client server.
"""
import math
import logging
import queue
import threading
import warnings
import grpc

from collections import OrderedDict
from typing import Any, Callable, Dict, TYPE_CHECKING, Optional, Union

import ray.core.generated.ray_client_pb2 as ray_client_pb2
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
from ray.util.client.common import (
    INT32_MAX,
    OBJECT_TRANSFER_CHUNK_SIZE,
    OBJECT_TRANSFER_WARNING_SIZE,
)
from ray.util.debug import log_once

if TYPE_CHECKING:
    from ray.util.client.worker import Worker

logger = logging.getLogger(__name__)

ResponseCallable = Callable[[Union[ray_client_pb2.DataResponse, Exception]], None]

# Send an acknowledge on every 32nd response received
ACKNOWLEDGE_BATCH_SIZE = 32


def chunk_put(req: ray_client_pb2.DataRequest):
    """
    Chunks a put request. Doing this lazily is important for large objects,
    since taking slices of bytes objects does a copy. This means that if we
    immediately materialized every chunk of a large object and inserted them
    into the result_queue, we would effectively double the memory needed
    on the client to handle the put.
    """
    # When accessing a protobuf field, deserialization is performed, which will
    # generate a copy. So we need to avoid accessing the `data` field multiple
    # times in the loop
    request_data = req.put.data
    total_size = len(request_data)
    assert total_size > 0, "Cannot chunk object with missing data"
    if total_size >= OBJECT_TRANSFER_WARNING_SIZE and log_once(
        "client_object_put_size_warning"
    ):
        size_gb = total_size / 2**30
        warnings.warn(
            "Ray Client is attempting to send a "
            f"{size_gb:.2f} GiB object over the network, which may "
            "be slow. Consider serializing the object and using a remote "
            "URI to transfer via S3 or Google Cloud Storage instead. "
            "Documentation for doing this can be found here: "
            "https://docs.ray.io/en/latest/handling-dependencies.html#remote-uris",
            UserWarning,
        )
    total_chunks = math.ceil(total_size / OBJECT_TRANSFER_CHUNK_SIZE)
    for chunk_id in range(0, total_chunks):
        start = chunk_id * OBJECT_TRANSFER_CHUNK_SIZE
        end = min(total_size, (chunk_id + 1) * OBJECT_TRANSFER_CHUNK_SIZE)
        chunk = ray_client_pb2.PutRequest(
            client_ref_id=req.put.client_ref_id,
            data=request_data[start:end],
            chunk_id=chunk_id,
            total_chunks=total_chunks,
            total_size=total_size,
            owner_id=req.put.owner_id,
        )
        yield ray_client_pb2.DataRequest(req_id=req.req_id, put=chunk)

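# ----------------------------------------------------------------------------
# Editor's note: illustrative sketch, not part of the original file. It shows
# the chunk arithmetic used by chunk_put/chunk_task in isolation: a payload of
# `total_size` bytes is split into ceil(total_size / chunk_size) slices, and
# the last slice is clamped to the end of the payload.
def _example_chunk_bounds(total_size: int, chunk_size: int):
    total_chunks = math.ceil(total_size / chunk_size)
    return [
        (chunk_id * chunk_size, min(total_size, (chunk_id + 1) * chunk_size))
        for chunk_id in range(total_chunks)
    ]
# For example, _example_chunk_bounds(10, 4) == [(0, 4), (4, 8), (8, 10)].
# ----------------------------------------------------------------------------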

def chunk_task(req: ray_client_pb2.DataRequest):
    """
    Chunks a client task. Doing this lazily is important with large arguments,
    since taking slices of bytes objects does a copy. This means that if we
    immediately materialized every chunk of a large argument and inserted them
    into the result_queue, we would effectively double the memory needed
    on the client to handle the task.
    """
    # When accessing a protobuf field, deserialization is performed, which will
    # generate a copy. So we need to avoid accessing the `data` field multiple
    # times in the loop
    request_data = req.task.data
    total_size = len(request_data)
    assert total_size > 0, "Cannot chunk object with missing data"
    total_chunks = math.ceil(total_size / OBJECT_TRANSFER_CHUNK_SIZE)
    for chunk_id in range(0, total_chunks):
        start = chunk_id * OBJECT_TRANSFER_CHUNK_SIZE
        end = min(total_size, (chunk_id + 1) * OBJECT_TRANSFER_CHUNK_SIZE)
        chunk = ray_client_pb2.ClientTask(
            type=req.task.type,
            name=req.task.name,
            payload_id=req.task.payload_id,
            client_id=req.task.client_id,
            options=req.task.options,
            baseline_options=req.task.baseline_options,
            namespace=req.task.namespace,
            data=request_data[start:end],
            chunk_id=chunk_id,
            total_chunks=total_chunks,
        )
        yield ray_client_pb2.DataRequest(req_id=req.req_id, task=chunk)


class ChunkCollector:
    """
    This object collects chunks from async get requests via __call__, and
    calls the underlying callback when the object is fully received, or if an
    exception occurs while retrieving the object.

    This is not used in synchronous gets (synchronous gets interact with the
    raylet servicer directly, not through the datapath).

    __call__ returns True once the underlying callback has been called.
    """

    def __init__(self, callback: ResponseCallable, request: ray_client_pb2.DataRequest):
        # Bytearray containing data received so far
        self.data = bytearray()
        # The callback that will be called once all data is received
        self.callback = callback
        # The id of the last chunk we've received, or -1 if we haven't seen any yet
        self.last_seen_chunk = -1
        # The GetRequest that initiated the transfer. start_chunk_id will be
        # updated as chunks are received to avoid re-requesting chunks that
        # we've already received.
        self.request = request

    def __call__(self, response: Union[ray_client_pb2.DataResponse, Exception]) -> bool:
        if isinstance(response, Exception):
            self.callback(response)
            return True
        get_resp = response.get
        if not get_resp.valid:
            self.callback(response)
            return True
        if get_resp.total_size > OBJECT_TRANSFER_WARNING_SIZE and log_once(
            "client_object_transfer_size_warning"
        ):
            size_gb = get_resp.total_size / 2**30
            warnings.warn(
                "Ray Client is attempting to retrieve a "
                f"{size_gb:.2f} GiB object over the network, which may "
                "be slow. Consider serializing the object to a file and "
                "using rsync or S3 instead.",
                UserWarning,
            )
        chunk_data = get_resp.data
        chunk_id = get_resp.chunk_id
        if chunk_id == self.last_seen_chunk + 1:
            self.data.extend(chunk_data)
            self.last_seen_chunk = chunk_id
            # If we disconnect partway through, restart the get request
            # at the first chunk we haven't seen
            self.request.get.start_chunk_id = self.last_seen_chunk + 1
        elif chunk_id > self.last_seen_chunk + 1:
            # A chunk was skipped. This shouldn't happen in practice since
            # grpc guarantees that chunks will arrive in order.
            msg = (
                f"Received chunk {chunk_id} when we expected "
                f"{self.last_seen_chunk + 1} for request {response.req_id}"
            )
            logger.warning(msg)
            self.callback(RuntimeError(msg))
            return True
        else:
            # We received a chunk that we've already seen before. Ignore, since
            # it should already be appended to self.data.
            logger.debug(
                f"Received a repeated chunk {chunk_id} "
                f"from request {response.req_id}."
            )

        if get_resp.chunk_id == get_resp.total_chunks - 1:
            self.callback(self.data)
            return True
        else:
            # Not done yet
            return False

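# ----------------------------------------------------------------------------
# Editor's note: illustrative sketch, not part of the original file. It mimics
# the in-order reassembly performed by ChunkCollector above: chunks are
# appended only when they carry the next expected id, duplicates are ignored,
# and the index of the first missing chunk is what a resumed GetRequest would
# use as its start_chunk_id.
def _example_reassemble(chunks):
    data = bytearray()
    last_seen = -1
    for chunk_id, payload in chunks:
        if chunk_id == last_seen + 1:
            data.extend(payload)
            last_seen = chunk_id
        # chunk_id <= last_seen means a duplicate that was already appended;
        # a gap (chunk_id > last_seen + 1) would be an error in the real code.
    next_needed = last_seen + 1  # what a resumed request would ask for
    return bytes(data), next_needed
# ----------------------------------------------------------------------------
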
class DataClient:
    def __init__(self, client_worker: "Worker", client_id: str, metadata: list):
        """Initializes a thread-safe datapath over a Ray Client gRPC channel.

        Args:
            client_worker: The Ray Client worker that manages this client
            client_id: the generated ID representing this client
            metadata: metadata to pass to gRPC requests
        """
        self.client_worker = client_worker
        self._client_id = client_id
        self._metadata = metadata
        self.data_thread = self._start_datathread()

        # Track outstanding requests to resend in case of disconnection
        self.outstanding_requests: Dict[int, Any] = OrderedDict()

        # Serialize access to all mutable internal states: self.request_queue,
        # self.ready_data, self.asyncio_waiting_data,
        # self._in_shutdown, self._req_id, self.outstanding_requests and
        # calling self._next_id()
        self.lock = threading.Lock()

        # Waiting for response or shutdown.
        self.cv = threading.Condition(lock=self.lock)

        self.request_queue = self._create_queue()
        self.ready_data: Dict[int, Any] = {}
        # NOTE: Dictionary insertion is guaranteed to complete before lookup
        # and/or removal because of synchronization via the request_queue.
        self.asyncio_waiting_data: Dict[int, ResponseCallable] = {}
        self._in_shutdown = False
        self._req_id = 0
        self._last_exception = None
        self._acknowledge_counter = 0

        self.data_thread.start()

    # Must hold self.lock when calling this function.
    def _next_id(self) -> int:
        assert self.lock.locked()
        self._req_id += 1
        if self._req_id > INT32_MAX:
            self._req_id = 1
        # Responses that aren't tracked (like opportunistic releases)
        # have req_id=0, so make sure we never mint such an id.
        assert self._req_id != 0
        return self._req_id

    def _start_datathread(self) -> threading.Thread:
        return threading.Thread(
            target=self._data_main,
            name="ray_client_streaming_rpc",
            args=(),
            daemon=True,
        )

    # A helper that takes requests from queue. If the request wraps a PutRequest,
    # lazily chunks and yields the request. Otherwise, yields the request directly.
    def _requests(self):
        while True:
            req = self.request_queue.get()
            if req is None:
                # Stop when client signals shutdown.
                return
            req_type = req.WhichOneof("type")
            if req_type == "put":
                yield from chunk_put(req)
            elif req_type == "task":
                yield from chunk_task(req)
            else:
                yield req

    def _data_main(self) -> None:
        reconnecting = False
        try:
            while not self.client_worker._in_shutdown:
                stub = ray_client_pb2_grpc.RayletDataStreamerStub(
                    self.client_worker.channel
                )
                metadata = self._metadata + [("reconnecting", str(reconnecting))]
                resp_stream = stub.Datapath(
                    self._requests(),
                    metadata=metadata,
                    wait_for_ready=True,
                )
                try:
                    for response in resp_stream:
                        self._process_response(response)
                    return
                except grpc.RpcError as e:
                    reconnecting = self._can_reconnect(e)
                    if not reconnecting:
                        self._last_exception = e
                        return
                    self._reconnect_channel()
        except Exception as e:
            self._last_exception = e
        finally:
            logger.debug("Shutting down data channel.")
            self._shutdown()

    def _process_response(self, response: Any) -> None:
        """
        Process responses from the data servicer.
        """
        if response.req_id == 0:
            # This is not being waited for.
            logger.debug(f"Got unawaited response {response}")
            return
        if response.req_id in self.asyncio_waiting_data:
            can_remove = True
            try:
                callback = self.asyncio_waiting_data[response.req_id]
                if isinstance(callback, ChunkCollector):
                    can_remove = callback(response)
                elif callback:
                    callback(response)
                if can_remove:
                    # NOTE: calling del self.asyncio_waiting_data results
                    # in the destructor of ClientObjectRef running, which
                    # calls ReleaseObject(). So self.asyncio_waiting_data
                    # is accessed without holding self.lock. Holding the
                    # lock shouldn't be necessary either.
                    del self.asyncio_waiting_data[response.req_id]
            except Exception:
                logger.exception("Callback error:")
            with self.lock:
                # Update outstanding requests
                if response.req_id in self.outstanding_requests and can_remove:
                    del self.outstanding_requests[response.req_id]
                    # Acknowledge response
                    self._acknowledge(response.req_id)
        else:
            with self.lock:
                self.ready_data[response.req_id] = response
                self.cv.notify_all()

    def _can_reconnect(self, e: grpc.RpcError) -> bool:
        """
        Processes RPC errors that occur while reading from the data stream.
        Returns True if the error can be recovered from, False otherwise.
        """
        if not self.client_worker._can_reconnect(e):
            logger.error("Unrecoverable error in data channel.")
            logger.debug(e)
            return False
        logger.debug("Recoverable error in data channel.")
        logger.debug(e)
        return True

    def _shutdown(self) -> None:
        """
        Shut down the data channel.
        """
        with self.lock:
            self._in_shutdown = True
            self.cv.notify_all()

        callbacks = self.asyncio_waiting_data.values()
        self.asyncio_waiting_data = {}

        if self._last_exception:
            # Abort async requests with the error.
            err = ConnectionError(
                "Failed during this or a previous request. Exception that "
                f"broke the connection: {self._last_exception}"
            )
        else:
            err = ConnectionError(
                "Request cannot be fulfilled because the data client has "
                "disconnected."
            )
        for callback in callbacks:
            if callback:
                callback(err)
        # Since self._in_shutdown is set to True, no new item
        # will be added to self.asyncio_waiting_data

    def _acknowledge(self, req_id: int) -> None:
        """
        Puts an acknowledge request on the request queue periodically.
        Lock should be held before calling this. Used when an async or
        blocking response is received.
        """
        if not self.client_worker._reconnect_enabled:
            # Skip ACKs if reconnect isn't enabled
            return
        assert self.lock.locked()
        self._acknowledge_counter += 1
        if self._acknowledge_counter % ACKNOWLEDGE_BATCH_SIZE == 0:
            self.request_queue.put(
                ray_client_pb2.DataRequest(
                    acknowledge=ray_client_pb2.AcknowledgeRequest(req_id=req_id)
                )
            )

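    # Editor's note (not in the original file): the acknowledgements produced
    # by _acknowledge() are what let the server-side OrderedResponseCache
    # (see its cleanup() method earlier in this diff) drop cached responses
    # the client has safely received. With ACKNOWLEDGE_BATCH_SIZE = 32, an
    # ack is enqueued after the 32nd, 64th, 96th, ... tracked response rather
    # than after every response, which keeps ack traffic on the datapath small.
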
    def _reconnect_channel(self) -> None:
        """
        Attempts to reconnect the gRPC channel and resend outstanding
        requests. First, the server is pinged to see if the current channel
        still works. If the ping fails, then the current channel is closed
        and replaced with a new one.

        Once a working channel is available, a new request queue is made
        and filled with any outstanding requests to be resent to the server.
        """
        try:
            # Ping the server to see if the current channel is reusable, for
            # example if gRPC reconnected the channel on its own or if the
            # RPC error was transient and the channel is still open
            ping_succeeded = self.client_worker.ping_server(timeout=5)
        except grpc.RpcError:
            ping_succeeded = False

        if not ping_succeeded:
            # Ping failed, try refreshing the data channel
            logger.warning(
                "Encountered connection issues in the data channel. "
                "Attempting to reconnect."
            )
            try:
                self.client_worker._connect_channel(reconnecting=True)
            except ConnectionError:
                logger.warning("Failed to reconnect the data channel")
                raise
            logger.debug("Reconnection succeeded!")

        # Recreate the request queue, and resend outstanding requests
        with self.lock:
            self.request_queue = self._create_queue()
            for request in self.outstanding_requests.values():
                # Resend outstanding requests
                self.request_queue.put(request)

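    # Editor's note (not in the original file): replaying
    # self.outstanding_requests after a reconnect is safe because requests
    # keep their original req_id and the server caches responses per req_id
    # (see ResponseCache / OrderedResponseCache earlier in this diff), so a
    # request that already executed is answered from the cache instead of
    # being applied twice. outstanding_requests is an OrderedDict, so the
    # requests are resent in the order they were first issued.
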
    # Use SimpleQueue to avoid deadlocks when appending to queue from __del__()
    @staticmethod
    def _create_queue():
        return queue.SimpleQueue()

    def close(self) -> None:
        thread = None
        with self.lock:
            self._in_shutdown = True
            # Notify blocking operations to fail.
            self.cv.notify_all()
            # Add sentinel to terminate streaming RPC.
            if self.request_queue is not None:
                # Intentional shutdown, tell server it can clean up the
                # connection immediately and ignore the reconnect grace period.
                cleanup_request = ray_client_pb2.DataRequest(
                    connection_cleanup=ray_client_pb2.ConnectionCleanupRequest()
                )
                self.request_queue.put(cleanup_request)
                self.request_queue.put(None)
            if self.data_thread is not None:
                thread = self.data_thread
        # Wait until streaming RPCs are done.
        if thread is not None:
            thread.join()

    def _blocking_send(
        self, req: ray_client_pb2.DataRequest
    ) -> ray_client_pb2.DataResponse:
        with self.lock:
            self._check_shutdown()
            req_id = self._next_id()
            req.req_id = req_id
            self.request_queue.put(req)
            self.outstanding_requests[req_id] = req

            self.cv.wait_for(lambda: req_id in self.ready_data or self._in_shutdown)
            self._check_shutdown()

            data = self.ready_data[req_id]
            del self.ready_data[req_id]
            del self.outstanding_requests[req_id]
            self._acknowledge(req_id)

        return data

    def _async_send(
        self,
        req: ray_client_pb2.DataRequest,
        callback: Optional[ResponseCallable] = None,
    ) -> None:
        with self.lock:
            self._check_shutdown()
            req_id = self._next_id()
            req.req_id = req_id
            self.asyncio_waiting_data[req_id] = callback
            self.outstanding_requests[req_id] = req
            self.request_queue.put(req)

    # Must hold self.lock when calling this function.
    def _check_shutdown(self):
        assert self.lock.locked()
        if not self._in_shutdown:
            return

        self.lock.release()

        # Do not try disconnect() or throw exceptions in self.data_thread.
        # Otherwise deadlock can occur.
        if threading.current_thread().ident == self.data_thread.ident:
            return

        from ray.util import disconnect

        disconnect()

        self.lock.acquire()

        if self._last_exception is not None:
            msg = (
                "Request can't be sent because the Ray client has already "
                "been disconnected due to an error. Last exception: "
                f"{self._last_exception}"
            )
        else:
            msg = (
                "Request can't be sent because the Ray client has already "
                "been disconnected."
            )

        raise ConnectionError(msg)

    def Init(
        self, request: ray_client_pb2.InitRequest, context=None
    ) -> ray_client_pb2.InitResponse:
        datareq = ray_client_pb2.DataRequest(
            init=request,
        )
        resp = self._blocking_send(datareq)
        return resp.init

    def PrepRuntimeEnv(
        self, request: ray_client_pb2.PrepRuntimeEnvRequest, context=None
    ) -> ray_client_pb2.PrepRuntimeEnvResponse:
        datareq = ray_client_pb2.DataRequest(
            prep_runtime_env=request,
        )
        resp = self._blocking_send(datareq)
        return resp.prep_runtime_env

    def ConnectionInfo(self, context=None) -> ray_client_pb2.ConnectionInfoResponse:
        datareq = ray_client_pb2.DataRequest(
            connection_info=ray_client_pb2.ConnectionInfoRequest()
        )
        resp = self._blocking_send(datareq)
        return resp.connection_info

    def GetObject(
        self, request: ray_client_pb2.GetRequest, context=None
    ) -> ray_client_pb2.GetResponse:
        datareq = ray_client_pb2.DataRequest(
            get=request,
        )
        resp = self._blocking_send(datareq)
        return resp.get

    def RegisterGetCallback(
        self, request: ray_client_pb2.GetRequest, callback: ResponseCallable
    ) -> None:
        if len(request.ids) != 1:
            raise ValueError(
                "RegisterGetCallback() must have exactly 1 Object ID. "
                f"Actual: {request}"
            )
        datareq = ray_client_pb2.DataRequest(
            get=request,
        )
        collector = ChunkCollector(callback=callback, request=datareq)
        self._async_send(datareq, collector)

    # TODO: convert PutObject to async
    def PutObject(
        self, request: ray_client_pb2.PutRequest, context=None
    ) -> ray_client_pb2.PutResponse:
        datareq = ray_client_pb2.DataRequest(
            put=request,
        )
        resp = self._blocking_send(datareq)
        return resp.put

    def ReleaseObject(
        self, request: ray_client_pb2.ReleaseRequest, context=None
    ) -> None:
        datareq = ray_client_pb2.DataRequest(
            release=request,
        )
        self._async_send(datareq)

    def Schedule(self, request: ray_client_pb2.ClientTask, callback: ResponseCallable):
        datareq = ray_client_pb2.DataRequest(task=request)
        self._async_send(datareq, callback)

    def Terminate(
        self, request: ray_client_pb2.TerminateRequest
    ) -> ray_client_pb2.TerminateResponse:
        req = ray_client_pb2.DataRequest(
            terminate=request,
        )
        resp = self._blocking_send(req)
        return resp.terminate

    def ListNamedActors(
        self, request: ray_client_pb2.ClientListNamedActorsRequest
    ) -> ray_client_pb2.ClientListNamedActorsResponse:
        req = ray_client_pb2.DataRequest(
            list_named_actors=request,
        )
        resp = self._blocking_send(req)
        return resp.list_named_actors
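
# ----------------------------------------------------------------------------
# Editor's note: illustrative sketch, not part of the original file. It models
# the _blocking_send() pattern above with plain threading primitives: the
# caller parks on a condition variable until a background reader thread
# publishes a response under the same request id. The names here are
# hypothetical stand-ins, not the real gRPC datapath.
if __name__ == "__main__":
    import threading as _threading

    _lock = _threading.Lock()
    _cv = _threading.Condition(_lock)
    _ready = {}

    def _reader(req_id, value):
        # Stands in for _data_main()/_process_response() publishing a reply.
        with _cv:
            _ready[req_id] = value
            _cv.notify_all()

    def _blocking(req_id):
        # Stands in for _blocking_send(): wait until the reply shows up.
        with _cv:
            _cv.wait_for(lambda: req_id in _ready)
            return _ready.pop(req_id)

    _threading.Timer(0.1, _reader, args=(1, "pong")).start()
    assert _blocking(1) == "pong"
# ----------------------------------------------------------------------------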
.venv/lib/python3.11/site-packages/ray/util/client/options.py
ADDED
@@ -0,0 +1,47 @@
from typing import Any
from typing import Dict
from typing import Optional

from ray._private import ray_option_utils
from ray.util.placement_group import PlacementGroup, check_placement_group_index
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy


def validate_options(kwargs_dict: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
    if kwargs_dict is None:
        return None
    if len(kwargs_dict) == 0:
        return None

    out = {}
    for k, v in kwargs_dict.items():
        if k not in ray_option_utils.valid_options:
            raise ValueError(
                f"Invalid option keyword: '{k}'. "
                f"{ray_option_utils.remote_args_error_string}"
            )
        ray_option_utils.valid_options[k].validate(k, v)
        out[k] = v

    # Validate placement setting similar to the logic in ray/actor.py and
    # ray/remote_function.py. The difference is that when placement_group is
    # "default" and placement_group_capture_child_tasks is specified, the
    # placement group cannot be resolved at the client, so this check skips
    # that case and relies on the server to enforce any condition.
    bundle_index = out.get("placement_group_bundle_index", None)
    pg = out.get("placement_group", None)
    scheduling_strategy = out.get("scheduling_strategy", None)
    if isinstance(scheduling_strategy, PlacementGroupSchedulingStrategy):
        pg = scheduling_strategy.placement_group
        bundle_index = scheduling_strategy.placement_group_bundle_index
    if bundle_index is not None:
        if pg is None:
            pg = PlacementGroup.empty()
        if pg == "default" and (
            out.get("placement_group_capture_child_tasks", None) is None
        ):
            pg = PlacementGroup.empty()
        if isinstance(pg, PlacementGroup):
            check_placement_group_index(pg, bundle_index)

    return out
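
# ----------------------------------------------------------------------------
# Editor's note: illustrative sketch, not part of the original file. Expected
# behaviour of validate_options(): valid @ray.remote-style keyword options are
# validated and passed through, while unknown keywords raise ValueError. The
# exact set of accepted names comes from ray_option_utils.valid_options.
#
#   validate_options(None)                  -> None
#   validate_options({})                    -> None
#   validate_options({"num_cpus": 1})       -> {"num_cpus": 1}
#   validate_options({"not_an_option": 1})  -> raises ValueError
# ----------------------------------------------------------------------------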
.venv/lib/python3.11/site-packages/ray/util/client/ray_client_helpers.py
ADDED
@@ -0,0 +1,115 @@
from contextlib import contextmanager
import time
from typing import Any, Dict

import ray as real_ray
from ray.job_config import JobConfig
import ray.util.client.server.server as ray_client_server
from ray.util.client import ray
from ray._private.client_mode_hook import enable_client_mode, disable_client_hook


@contextmanager
def ray_start_client_server(metadata=None, ray_connect_handler=None, **kwargs):
    with ray_start_client_server_pair(
        metadata=metadata, ray_connect_handler=ray_connect_handler, **kwargs
    ) as pair:
        client, server = pair
        yield client


@contextmanager
def ray_start_client_server_for_address(address):
    """
    Starts a Ray client server that initializes drivers at the specified address.
    """

    def connect_handler(
        job_config: JobConfig = None, **ray_init_kwargs: Dict[str, Any]
    ):
        import ray

        with disable_client_hook():
            if not ray.is_initialized():
                return ray.init(address, job_config=job_config, **ray_init_kwargs)

    with ray_start_client_server(ray_connect_handler=connect_handler) as ray:
        yield ray


@contextmanager
def ray_start_client_server_pair(metadata=None, ray_connect_handler=None, **kwargs):
    ray._inside_client_test = True
    with disable_client_hook():
        assert not ray.is_initialized()
    server = ray_client_server.serve(
        "127.0.0.1:50051", ray_connect_handler=ray_connect_handler
    )
    ray.connect("127.0.0.1:50051", metadata=metadata, **kwargs)
    try:
        yield ray, server
    finally:
        ray._inside_client_test = False
        ray.disconnect()
        server.stop(0)
        del server
        start = time.monotonic()
        with disable_client_hook():
            while ray.is_initialized():
                time.sleep(1)
                if time.monotonic() - start > 30:
                    raise RuntimeError("Failed to terminate Ray")
        # Allow windows to close processes before moving on
        time.sleep(3)


@contextmanager
def ray_start_cluster_client_server_pair(address):
    ray._inside_client_test = True

    def ray_connect_handler(job_config=None, **ray_init_kwargs):
        real_ray.init(address=address)

    server = ray_client_server.serve(
        "127.0.0.1:50051", ray_connect_handler=ray_connect_handler
    )
    ray.connect("127.0.0.1:50051")
    try:
        yield ray, server
    finally:
        ray._inside_client_test = False
        ray.disconnect()
        server.stop(0)


@contextmanager
def connect_to_client_or_not(connect_to_client: bool):
    """Utility for running test logic with and without a Ray client connection.

    If connect_to_client is True, will connect to Ray client in context.
    If connect_to_client is False, does nothing.

    How to use:
    Given a test of the following form:

    def test_<name>(args):
        <initialize a ray cluster>
        <use the ray cluster>

    Modify the test to

    @pytest.mark.parametrize("connect_to_client", [False, True])
    def test_<name>(args, connect_to_client):
        <initialize a ray cluster>
        with connect_to_client_or_not(connect_to_client):
            <use the ray cluster>

    Parameterize the argument connect_to_client over True and False to run the
    test with and without a Ray client connection.
    """

    if connect_to_client:
        with ray_start_client_server(namespace=""), enable_client_mode():
            yield
    else:
        yield
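
# ----------------------------------------------------------------------------
# Editor's note: illustrative sketch, not part of the original file. Typical
# use of the helpers above in a test; the remote function is hypothetical and
# the snippet assumes a local Ray installation able to start a client server.
#
#   def test_roundtrip():
#       with ray_start_client_server() as ray:
#           @ray.remote
#           def echo(x):
#               return x
#           assert ray.get(echo.remote(42)) == 42
# ----------------------------------------------------------------------------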
.venv/lib/python3.11/site-packages/ray/util/client/runtime_context.py
ADDED
@@ -0,0 +1,65 @@
from typing import TYPE_CHECKING
from types import SimpleNamespace

if TYPE_CHECKING:
    from ray import JobID, NodeID
    from ray.runtime_context import RuntimeContext


class _ClientWorkerPropertyAPI:
    """Emulates the properties of the ray._private.worker object for the client"""

    def __init__(self, worker):
        assert worker is not None
        self.worker = worker

    def build_runtime_context(self) -> "RuntimeContext":
        """Creates a RuntimeContext backed by the properties of this API"""
        # Defer the import of RuntimeContext until needed to avoid cycles
        from ray.runtime_context import RuntimeContext

        return RuntimeContext(self)

    def _fetch_runtime_context(self):
        import ray.core.generated.ray_client_pb2 as ray_client_pb2

        return self.worker.get_cluster_info(
            ray_client_pb2.ClusterInfoType.RUNTIME_CONTEXT
        )

    @property
    def mode(self):
        from ray._private.worker import SCRIPT_MODE

        return SCRIPT_MODE

    @property
    def current_job_id(self) -> "JobID":
        from ray import JobID

        return JobID(self._fetch_runtime_context().job_id)

    @property
    def current_node_id(self) -> "NodeID":
        from ray import NodeID

        return NodeID(self._fetch_runtime_context().node_id)

    @property
    def namespace(self) -> str:
        return self._fetch_runtime_context().namespace

    @property
    def should_capture_child_tasks_in_placement_group(self) -> bool:
        return self._fetch_runtime_context().capture_client_tasks

    @property
    def runtime_env(self) -> str:
        return self._fetch_runtime_context().runtime_env

    def check_connected(self) -> bool:
        return self.worker.ping_server()

    @property
    def gcs_client(self) -> str:
        return SimpleNamespace(address=self._fetch_runtime_context().gcs_address)
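
# ----------------------------------------------------------------------------
# Editor's note: illustrative sketch, not part of the original file. The
# property API above is consumed roughly like this on the client side, where
# `worker` is a connected ray.util.client.worker.Worker (hypothetical here):
#
#   api = _ClientWorkerPropertyAPI(worker)
#   ctx = api.build_runtime_context()
#   # ctx is a ray.runtime_context.RuntimeContext whose reads (job id,
#   # namespace, runtime_env, ...) are served by the properties above, each
#   # of which fetches the RUNTIME_CONTEXT cluster info over gRPC.
# ----------------------------------------------------------------------------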
.venv/lib/python3.11/site-packages/ray/util/client/worker.py
ADDED
@@ -0,0 +1,908 @@
"""This file includes the Worker class which sits on the client side.
It implements the Ray API functions that are forwarded through gRPC calls
to the server.
"""
import base64
import json
import logging
import os
import tempfile
import threading
import time
import uuid
import warnings
from collections import defaultdict
from concurrent.futures import Future
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

import grpc

import ray._private.tls_utils
import ray.cloudpickle as cloudpickle
import ray.core.generated.ray_client_pb2 as ray_client_pb2
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
from ray._private.ray_constants import DEFAULT_CLIENT_RECONNECT_GRACE_PERIOD
from ray._private.runtime_env.py_modules import upload_py_modules_if_needed
from ray._private.runtime_env.working_dir import upload_working_dir_if_needed

# Use cloudpickle's version of pickle for UnpicklingError
from ray.cloudpickle.compat import pickle
from ray.exceptions import GetTimeoutError
from ray.job_config import JobConfig
from ray.util.client.client_pickler import dumps_from_client, loads_from_server
from ray.util.client.common import (
    GRPC_OPTIONS,
    GRPC_UNRECOVERABLE_ERRORS,
    INT32_MAX,
    OBJECT_TRANSFER_WARNING_SIZE,
    ClientActorClass,
    ClientActorHandle,
    ClientActorRef,
    ClientObjectRef,
    ClientRemoteFunc,
    ClientStub,
)
from ray.util.client.dataclient import DataClient
from ray.util.client.logsclient import LogstreamClient
from ray.util.debug import log_once

if TYPE_CHECKING:
    from ray.actor import ActorClass
    from ray.remote_function import RemoteFunction

logger = logging.getLogger(__name__)

INITIAL_TIMEOUT_SEC = 5
MAX_TIMEOUT_SEC = 30

# The max amount of time an operation can run blocking in the server. This
# allows for Ctrl-C of the client to work without explicitly cancelling server
# operations.
MAX_BLOCKING_OPERATION_TIME_S: float = 2.0

# If the total size (bytes) of all outbound messages to schedule tasks since
# the connection began exceeds this value, a warning should be raised
MESSAGE_SIZE_THRESHOLD = 10 * 2**20  # 10 MB

# Links to the Ray Design Pattern doc to use in the task overhead warning
# message
DESIGN_PATTERN_FINE_GRAIN_TASKS_LINK = "https://docs.google.com/document/d/167rnnDFIVRhHhK4mznEIemOtj63IOhtIPvSYaPgI4Fg/edit#heading=h.f7ins22n6nyl"  # noqa E501

DESIGN_PATTERN_LARGE_OBJECTS_LINK = "https://docs.google.com/document/d/167rnnDFIVRhHhK4mznEIemOtj63IOhtIPvSYaPgI4Fg/edit#heading=h.1afmymq455wu"  # noqa E501


def backoff(timeout: int) -> int:
    timeout = timeout + 5
    if timeout > MAX_TIMEOUT_SEC:
        timeout = MAX_TIMEOUT_SEC
    return timeout

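# ----------------------------------------------------------------------------
# Editor's note: illustrative sketch, not part of the original file. backoff()
# grows the connection timeout linearly by 5s per retry and caps it at
# MAX_TIMEOUT_SEC, so starting from INITIAL_TIMEOUT_SEC = 5 the successive
# timeouts are 5, 10, 15, 20, 25, 30, 30, ... seconds.
def _example_backoff_schedule(retries: int):
    timeout, schedule = INITIAL_TIMEOUT_SEC, []
    for _ in range(retries):
        schedule.append(timeout)
        timeout = backoff(timeout)
    return schedule
# _example_backoff_schedule(7) == [5, 10, 15, 20, 25, 30, 30]
# ----------------------------------------------------------------------------
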

class Worker:
    def __init__(
        self,
        conn_str: str = "",
        secure: bool = False,
        metadata: List[Tuple[str, str]] = None,
        connection_retries: int = 3,
        _credentials: Optional[grpc.ChannelCredentials] = None,
    ):
        """Initializes the worker-side gRPC client.

        Args:
            conn_str: The host:port connection string for the ray server.
            secure: whether to use SSL secure channel or not.
            metadata: additional metadata passed in the grpc request headers.
            connection_retries: Number of times to attempt to reconnect to the
                ray server if it doesn't respond immediately. Setting to 0 tries
                at least once. For infinite retries, catch the ConnectionError
                exception.
            _credentials: gRPC channel credentials. Default ones will be used
                if None.
        """
        self._client_id = make_client_id()
        self.metadata = [("client_id", self._client_id)] + (
            metadata if metadata else []
        )
        self.channel = None
        self.server = None
        self._conn_state = grpc.ChannelConnectivity.IDLE
        self._converted: Dict[str, ClientStub] = {}
        self._secure = secure or os.environ.get("RAY_USE_TLS", "0").lower() in (
            "1",
            "true",
        )
        self._conn_str = conn_str
        self._connection_retries = connection_retries

        if _credentials is not None:
            self._credentials = _credentials
            self._secure = True
        else:
            self._credentials = None

        self._reconnect_grace_period = DEFAULT_CLIENT_RECONNECT_GRACE_PERIOD
        if "RAY_CLIENT_RECONNECT_GRACE_PERIOD" in os.environ:
            # Use value in environment variable if available
            self._reconnect_grace_period = int(
                os.environ["RAY_CLIENT_RECONNECT_GRACE_PERIOD"]
            )
        # Disable retries if grace period is set to 0
        self._reconnect_enabled = self._reconnect_grace_period != 0

        # Set to True when the connection cannot be recovered and reconnect
        # attempts should be stopped
        self._in_shutdown = False
        # Set to True after initial connection succeeds
        self._has_connected = False

        self._connect_channel()
        self._has_connected = True

        # Has Ray been initialized on the server?
        self._serverside_ray_initialized = False

        # Initialize the streams to finish protocol negotiation.
        self.data_client = DataClient(self, self._client_id, self.metadata)
        self.reference_count: Dict[bytes, int] = defaultdict(int)

        self.log_client = LogstreamClient(self, self.metadata)
        self.log_client.set_logstream_level(logging.INFO)

        self.closed = False

        # Track this value to raise a warning if a lot of data is transferred.
        self.total_outbound_message_size_bytes = 0

        # Used to create unique IDs for RPCs to the RayletServicer
        self._req_id_lock = threading.Lock()
        self._req_id = 0

    def _connect_channel(self, reconnecting=False) -> None:
        """
        Attempts to connect to the server specified by conn_str. If
        reconnecting after an RPC error, cleans up the old channel and
        continues to attempt to connect until the grace period is over.
        """
        if self.channel is not None:
            self.channel.unsubscribe(self._on_channel_state_change)
            self.channel.close()

        if self._secure:
            if self._credentials is not None:
                credentials = self._credentials
            elif os.environ.get("RAY_USE_TLS", "0").lower() in ("1", "true"):
                (
                    server_cert_chain,
                    private_key,
                    ca_cert,
                ) = ray._private.tls_utils.load_certs_from_env()
                credentials = grpc.ssl_channel_credentials(
                    certificate_chain=server_cert_chain,
                    private_key=private_key,
                    root_certificates=ca_cert,
                )
            else:
                credentials = grpc.ssl_channel_credentials()
            self.channel = grpc.secure_channel(
                self._conn_str, credentials, options=GRPC_OPTIONS
            )
        else:
            self.channel = grpc.insecure_channel(self._conn_str, options=GRPC_OPTIONS)

        self.channel.subscribe(self._on_channel_state_change)

        # Retry the connection until the channel responds to something
        # looking like a gRPC connection, though it may be a proxy.
        start_time = time.time()
        conn_attempts = 0
        timeout = INITIAL_TIMEOUT_SEC
        service_ready = False
        while conn_attempts < max(self._connection_retries, 1) or reconnecting:
            conn_attempts += 1
            if self._in_shutdown:
                # User manually closed the worker before connection finished
                break
            elapsed_time = time.time() - start_time
            if reconnecting and elapsed_time > self._reconnect_grace_period:
                self._in_shutdown = True
                raise ConnectionError(
                    "Failed to reconnect within the reconnection grace period "
                    f"({self._reconnect_grace_period}s)"
                )
            try:
                # Let gRPC wait for us to see if the channel becomes ready.
                # If it throws, we couldn't connect.
                grpc.channel_ready_future(self.channel).result(timeout=timeout)
                # The HTTP2 channel is ready. Wrap the channel with the
                # RayletDriverStub, allowing for unary requests.
                self.server = ray_client_pb2_grpc.RayletDriverStub(self.channel)
                service_ready = bool(self.ping_server())
                if service_ready:
                    break
                # Ray is not ready yet, wait a timeout
                time.sleep(timeout)
            except grpc.FutureTimeoutError:
                logger.debug(f"Couldn't connect channel in {timeout} seconds, retrying")
                # Note that channel_ready_future constitutes its own timeout,
                # which is why we do not sleep here.
            except grpc.RpcError as e:
                logger.debug(
|
| 231 |
+
"Ray client server unavailable, " f"retrying in {timeout}s..."
|
| 232 |
+
)
|
| 233 |
+
logger.debug(f"Received when checking init: {e.details()}")
|
| 234 |
+
# Ray is not ready yet, wait a timeout.
|
| 235 |
+
time.sleep(timeout)
|
| 236 |
+
# Fallthrough, backoff, and retry at the top of the loop
|
| 237 |
+
logger.debug(
|
| 238 |
+
"Waiting for Ray to become ready on the server, "
|
| 239 |
+
f"retry in {timeout}s..."
|
| 240 |
+
)
|
| 241 |
+
if not reconnecting:
|
| 242 |
+
# Don't increase backoff when trying to reconnect --
|
| 243 |
+
# we already know the server exists, attempt to reconnect
|
| 244 |
+
# as soon as we can
|
| 245 |
+
timeout = backoff(timeout)
|
| 246 |
+
|
| 247 |
+
# If we made it through the loop without service_ready
|
| 248 |
+
# it means we've used up our retries and
|
| 249 |
+
# should error back to the user.
|
| 250 |
+
if not service_ready:
|
| 251 |
+
self._in_shutdown = True
|
| 252 |
+
if log_once("ray_client_security_groups"):
|
| 253 |
+
warnings.warn(
|
| 254 |
+
"Ray Client connection timed out. Ensure that "
|
| 255 |
+
"the Ray Client port on the head node is reachable "
|
| 256 |
+
"from your local machine. See https://docs.ray.io/en"
|
| 257 |
+
"/latest/cluster/ray-client.html#step-2-check-ports for "
|
| 258 |
+
"more information."
|
| 259 |
+
)
|
| 260 |
+
raise ConnectionError("ray client connection timeout")
|
| 261 |
+
|
| 262 |
+
def _can_reconnect(self, e: grpc.RpcError) -> bool:
|
| 263 |
+
"""
|
| 264 |
+
Returns True if the RPC error can be recovered from and a retry is
|
| 265 |
+
appropriate, false otherwise.
|
| 266 |
+
"""
|
| 267 |
+
if not self._reconnect_enabled:
|
| 268 |
+
return False
|
| 269 |
+
if self._in_shutdown:
|
| 270 |
+
# Channel is being shutdown, don't try to reconnect
|
| 271 |
+
return False
|
| 272 |
+
if e.code() in GRPC_UNRECOVERABLE_ERRORS:
|
| 273 |
+
# Unrecoverable error -- These errors are specifically raised
|
| 274 |
+
# by the server's application logic
|
| 275 |
+
return False
|
| 276 |
+
if e.code() == grpc.StatusCode.INTERNAL:
|
| 277 |
+
details = e.details()
|
| 278 |
+
if details == "Exception serializing request!":
|
| 279 |
+
# The client failed tried to send a bad request (for example,
|
| 280 |
+
# passing "None" instead of a valid grpc message). Don't
|
| 281 |
+
# try to reconnect/retry.
|
| 282 |
+
return False
|
| 283 |
+
# All other errors can be treated as recoverable
|
| 284 |
+
return True
|
| 285 |
+
|
| 286 |
+
def _call_stub(self, stub_name: str, *args, **kwargs) -> Any:
|
| 287 |
+
"""
|
| 288 |
+
Calls the stub specified by stub_name (Schedule, WaitObject, etc...).
|
| 289 |
+
If a recoverable error occurrs while calling the stub, attempts to
|
| 290 |
+
retry the RPC.
|
| 291 |
+
"""
|
| 292 |
+
while not self._in_shutdown:
|
| 293 |
+
try:
|
| 294 |
+
return getattr(self.server, stub_name)(*args, **kwargs)
|
| 295 |
+
except grpc.RpcError as e:
|
| 296 |
+
if self._can_reconnect(e):
|
| 297 |
+
time.sleep(0.5)
|
| 298 |
+
continue
|
| 299 |
+
raise
|
| 300 |
+
except ValueError:
|
| 301 |
+
# Trying to use the stub on a cancelled channel will raise
|
| 302 |
+
# ValueError. This should only happen when the data client
|
| 303 |
+
# is attempting to reset the connection -- sleep and try
|
| 304 |
+
# again.
|
| 305 |
+
time.sleep(0.5)
|
| 306 |
+
continue
|
| 307 |
+
raise ConnectionError("Client is shutting down.")
|
| 308 |
+
|
| 309 |
+
def _get_object_iterator(
|
| 310 |
+
self, req: ray_client_pb2.GetRequest, *args, **kwargs
|
| 311 |
+
) -> Any:
|
| 312 |
+
"""
|
| 313 |
+
Calls the stub for GetObject on the underlying server stub. If a
|
| 314 |
+
recoverable error occurs while streaming the response, attempts
|
| 315 |
+
to retry the get starting from the first chunk that hasn't been
|
| 316 |
+
received.
|
| 317 |
+
"""
|
| 318 |
+
last_seen_chunk = -1
|
| 319 |
+
while not self._in_shutdown:
|
| 320 |
+
# If we disconnect partway through, restart the get request
|
| 321 |
+
# at the first chunk we haven't seen
|
| 322 |
+
req.start_chunk_id = last_seen_chunk + 1
|
| 323 |
+
try:
|
| 324 |
+
for chunk in self.server.GetObject(req, *args, **kwargs):
|
| 325 |
+
if chunk.chunk_id <= last_seen_chunk:
|
| 326 |
+
# Ignore repeat chunks
|
| 327 |
+
logger.debug(
|
| 328 |
+
f"Received a repeated chunk {chunk.chunk_id} "
|
| 329 |
+
f"from request {req.req_id}."
|
| 330 |
+
)
|
| 331 |
+
continue
|
| 332 |
+
if last_seen_chunk + 1 != chunk.chunk_id:
|
| 333 |
+
raise RuntimeError(
|
| 334 |
+
f"Received chunk {chunk.chunk_id} when we expected "
|
| 335 |
+
f"{self.last_seen_chunk + 1}"
|
| 336 |
+
)
|
| 337 |
+
last_seen_chunk = chunk.chunk_id
|
| 338 |
+
yield chunk
|
| 339 |
+
if last_seen_chunk == chunk.total_chunks - 1:
|
| 340 |
+
# We've yielded the last chunk, exit early
|
| 341 |
+
return
|
| 342 |
+
return
|
| 343 |
+
except grpc.RpcError as e:
|
| 344 |
+
if self._can_reconnect(e):
|
| 345 |
+
time.sleep(0.5)
|
| 346 |
+
continue
|
| 347 |
+
raise
|
| 348 |
+
except ValueError:
|
| 349 |
+
# Trying to use the stub on a cancelled channel will raise
|
| 350 |
+
# ValueError. This should only happen when the data client
|
| 351 |
+
# is attempting to reset the connection -- sleep and try
|
| 352 |
+
# again.
|
| 353 |
+
time.sleep(0.5)
|
| 354 |
+
continue
|
| 355 |
+
raise ConnectionError("Client is shutting down.")
|
| 356 |
+
|
| 357 |
+
def _add_ids_to_metadata(self, metadata: Any):
|
| 358 |
+
"""
|
| 359 |
+
Adds a unique req_id and the current thread's identifier to the
|
| 360 |
+
metadata. These values are useful for preventing mutating operations
|
| 361 |
+
from being replayed on the server side in the event that the client
|
| 362 |
+
must retry a requsest.
|
| 363 |
+
Args:
|
| 364 |
+
metadata - the gRPC metadata to append the IDs to
|
| 365 |
+
"""
|
| 366 |
+
if not self._reconnect_enabled:
|
| 367 |
+
# IDs not needed if the reconnects are disabled
|
| 368 |
+
return metadata
|
| 369 |
+
thread_id = str(threading.get_ident())
|
| 370 |
+
with self._req_id_lock:
|
| 371 |
+
self._req_id += 1
|
| 372 |
+
if self._req_id > INT32_MAX:
|
| 373 |
+
self._req_id = 1
|
| 374 |
+
req_id = str(self._req_id)
|
| 375 |
+
return metadata + [("thread_id", thread_id), ("req_id", req_id)]
|
| 376 |
+
|
| 377 |
+
def _on_channel_state_change(self, conn_state: grpc.ChannelConnectivity):
|
| 378 |
+
logger.debug(f"client gRPC channel state change: {conn_state}")
|
| 379 |
+
self._conn_state = conn_state
|
| 380 |
+
|
| 381 |
+
def connection_info(self):
|
| 382 |
+
try:
|
| 383 |
+
data = self.data_client.ConnectionInfo()
|
| 384 |
+
except grpc.RpcError as e:
|
| 385 |
+
raise decode_exception(e)
|
| 386 |
+
return {
|
| 387 |
+
"num_clients": data.num_clients,
|
| 388 |
+
"python_version": data.python_version,
|
| 389 |
+
"ray_version": data.ray_version,
|
| 390 |
+
"ray_commit": data.ray_commit,
|
| 391 |
+
}
|
| 392 |
+
|
| 393 |
+
def register_callback(
|
| 394 |
+
self,
|
| 395 |
+
ref: ClientObjectRef,
|
| 396 |
+
callback: Callable[[ray_client_pb2.DataResponse], None],
|
| 397 |
+
) -> None:
|
| 398 |
+
req = ray_client_pb2.GetRequest(ids=[ref.id], asynchronous=True)
|
| 399 |
+
self.data_client.RegisterGetCallback(req, callback)
|
| 400 |
+
|
| 401 |
+
def get(self, vals, *, timeout: Optional[float] = None) -> Any:
|
| 402 |
+
if isinstance(vals, list):
|
| 403 |
+
if not vals:
|
| 404 |
+
return []
|
| 405 |
+
to_get = vals
|
| 406 |
+
elif isinstance(vals, ClientObjectRef):
|
| 407 |
+
to_get = [vals]
|
| 408 |
+
else:
|
| 409 |
+
raise Exception(
|
| 410 |
+
"Can't get something that's not a "
|
| 411 |
+
"list of IDs or just an ID: %s" % type(vals)
|
| 412 |
+
)
|
| 413 |
+
|
| 414 |
+
if timeout is None:
|
| 415 |
+
deadline = None
|
| 416 |
+
else:
|
| 417 |
+
deadline = time.monotonic() + timeout
|
| 418 |
+
|
| 419 |
+
max_blocking_operation_time = MAX_BLOCKING_OPERATION_TIME_S
|
| 420 |
+
if "RAY_CLIENT_MAX_BLOCKING_OPERATION_TIME_S" in os.environ:
|
| 421 |
+
max_blocking_operation_time = float(
|
| 422 |
+
os.environ["RAY_CLIENT_MAX_BLOCKING_OPERATION_TIME_S"]
|
| 423 |
+
)
|
| 424 |
+
while True:
|
| 425 |
+
if deadline:
|
| 426 |
+
op_timeout = min(
|
| 427 |
+
max_blocking_operation_time,
|
| 428 |
+
max(deadline - time.monotonic(), 0.001),
|
| 429 |
+
)
|
| 430 |
+
else:
|
| 431 |
+
op_timeout = max_blocking_operation_time
|
| 432 |
+
try:
|
| 433 |
+
res = self._get(to_get, op_timeout)
|
| 434 |
+
break
|
| 435 |
+
except GetTimeoutError:
|
| 436 |
+
if deadline and time.monotonic() > deadline:
|
| 437 |
+
raise
|
| 438 |
+
logger.debug("Internal retry for get {}".format(to_get))
|
| 439 |
+
if len(to_get) != len(res):
|
| 440 |
+
raise Exception(
|
| 441 |
+
"Mismatched number of items in request ({}) and response ({})".format(
|
| 442 |
+
len(to_get), len(res)
|
| 443 |
+
)
|
| 444 |
+
)
|
| 445 |
+
if isinstance(vals, ClientObjectRef):
|
| 446 |
+
res = res[0]
|
| 447 |
+
return res
|
| 448 |
+
|
| 449 |
+
def _get(self, ref: List[ClientObjectRef], timeout: float):
|
| 450 |
+
req = ray_client_pb2.GetRequest(ids=[r.id for r in ref], timeout=timeout)
|
| 451 |
+
data = bytearray()
|
| 452 |
+
try:
|
| 453 |
+
resp = self._get_object_iterator(req, metadata=self.metadata)
|
| 454 |
+
for chunk in resp:
|
| 455 |
+
if not chunk.valid:
|
| 456 |
+
try:
|
| 457 |
+
err = cloudpickle.loads(chunk.error)
|
| 458 |
+
except (pickle.UnpicklingError, TypeError):
|
| 459 |
+
logger.exception("Failed to deserialize {}".format(chunk.error))
|
| 460 |
+
raise
|
| 461 |
+
raise err
|
| 462 |
+
if chunk.total_size > OBJECT_TRANSFER_WARNING_SIZE and log_once(
|
| 463 |
+
"client_object_transfer_size_warning"
|
| 464 |
+
):
|
| 465 |
+
size_gb = chunk.total_size / 2**30
|
| 466 |
+
warnings.warn(
|
| 467 |
+
"Ray Client is attempting to retrieve a "
|
| 468 |
+
f"{size_gb:.2f} GiB object over the network, which may "
|
| 469 |
+
"be slow. Consider serializing the object to a file "
|
| 470 |
+
"and using S3 or rsync instead.",
|
| 471 |
+
UserWarning,
|
| 472 |
+
stacklevel=5,
|
| 473 |
+
)
|
| 474 |
+
data.extend(chunk.data)
|
| 475 |
+
except grpc.RpcError as e:
|
| 476 |
+
raise decode_exception(e)
|
| 477 |
+
return loads_from_server(data)
|
| 478 |
+
|
| 479 |
+
def put(
|
| 480 |
+
self,
|
| 481 |
+
val,
|
| 482 |
+
*,
|
| 483 |
+
client_ref_id: bytes = None,
|
| 484 |
+
_owner: Optional[ClientActorHandle] = None,
|
| 485 |
+
):
|
| 486 |
+
if isinstance(val, ClientObjectRef):
|
| 487 |
+
raise TypeError(
|
| 488 |
+
"Calling 'put' on an ObjectRef is not allowed "
|
| 489 |
+
"(similarly, returning an ObjectRef from a remote "
|
| 490 |
+
"function is not allowed). If you really want to "
|
| 491 |
+
"do this, you can wrap the ObjectRef in a list and "
|
| 492 |
+
"call 'put' on it (or return it)."
|
| 493 |
+
)
|
| 494 |
+
data = dumps_from_client(val, self._client_id)
|
| 495 |
+
return self._put_pickled(data, client_ref_id, _owner)
|
| 496 |
+
|
| 497 |
+
def _put_pickled(
|
| 498 |
+
self, data, client_ref_id: bytes, owner: Optional[ClientActorHandle] = None
|
| 499 |
+
):
|
| 500 |
+
req = ray_client_pb2.PutRequest(data=data)
|
| 501 |
+
if client_ref_id is not None:
|
| 502 |
+
req.client_ref_id = client_ref_id
|
| 503 |
+
if owner is not None:
|
| 504 |
+
req.owner_id = owner.actor_ref.id
|
| 505 |
+
|
| 506 |
+
resp = self.data_client.PutObject(req)
|
| 507 |
+
if not resp.valid:
|
| 508 |
+
try:
|
| 509 |
+
raise cloudpickle.loads(resp.error)
|
| 510 |
+
except (pickle.UnpicklingError, TypeError):
|
| 511 |
+
logger.exception("Failed to deserialize {}".format(resp.error))
|
| 512 |
+
raise
|
| 513 |
+
return ClientObjectRef(resp.id)
|
| 514 |
+
|
| 515 |
+
# TODO(ekl) respect MAX_BLOCKING_OPERATION_TIME_S for wait too
|
| 516 |
+
def wait(
|
| 517 |
+
self,
|
| 518 |
+
object_refs: List[ClientObjectRef],
|
| 519 |
+
*,
|
| 520 |
+
num_returns: int = 1,
|
| 521 |
+
timeout: float = None,
|
| 522 |
+
fetch_local: bool = True,
|
| 523 |
+
) -> Tuple[List[ClientObjectRef], List[ClientObjectRef]]:
|
| 524 |
+
if not isinstance(object_refs, list):
|
| 525 |
+
raise TypeError(
|
| 526 |
+
"wait() expected a list of ClientObjectRef, " f"got {type(object_refs)}"
|
| 527 |
+
)
|
| 528 |
+
for ref in object_refs:
|
| 529 |
+
if not isinstance(ref, ClientObjectRef):
|
| 530 |
+
raise TypeError(
|
| 531 |
+
"wait() expected a list of ClientObjectRef, "
|
| 532 |
+
f"got list containing {type(ref)}"
|
| 533 |
+
)
|
| 534 |
+
data = {
|
| 535 |
+
"object_ids": [object_ref.id for object_ref in object_refs],
|
| 536 |
+
"num_returns": num_returns,
|
| 537 |
+
"timeout": timeout if (timeout is not None) else -1,
|
| 538 |
+
"client_id": self._client_id,
|
| 539 |
+
}
|
| 540 |
+
req = ray_client_pb2.WaitRequest(**data)
|
| 541 |
+
resp = self._call_stub("WaitObject", req, metadata=self.metadata)
|
| 542 |
+
if not resp.valid:
|
| 543 |
+
# TODO(ameer): improve error/exceptions messages.
|
| 544 |
+
raise Exception("Client Wait request failed. Reference invalid?")
|
| 545 |
+
client_ready_object_ids = [
|
| 546 |
+
ClientObjectRef(ref) for ref in resp.ready_object_ids
|
| 547 |
+
]
|
| 548 |
+
client_remaining_object_ids = [
|
| 549 |
+
ClientObjectRef(ref) for ref in resp.remaining_object_ids
|
| 550 |
+
]
|
| 551 |
+
|
| 552 |
+
return (client_ready_object_ids, client_remaining_object_ids)
|
| 553 |
+
|
| 554 |
+
def call_remote(self, instance, *args, **kwargs) -> List[Future]:
|
| 555 |
+
task = instance._prepare_client_task()
|
| 556 |
+
# data is serialized tuple of (args, kwargs)
|
| 557 |
+
task.data = dumps_from_client((args, kwargs), self._client_id)
|
| 558 |
+
num_returns = instance._num_returns()
|
| 559 |
+
if num_returns == "dynamic":
|
| 560 |
+
num_returns = -1
|
| 561 |
+
if num_returns == "streaming":
|
| 562 |
+
raise RuntimeError(
|
| 563 |
+
'Streaming actor methods (num_returns="streaming") '
|
| 564 |
+
"are not currently supported when using Ray Client."
|
| 565 |
+
)
|
| 566 |
+
|
| 567 |
+
return self._call_schedule_for_task(task, num_returns)
|
| 568 |
+
|
| 569 |
+
def _call_schedule_for_task(
|
| 570 |
+
self, task: ray_client_pb2.ClientTask, num_returns: Optional[int]
|
| 571 |
+
) -> List[Future]:
|
| 572 |
+
logger.debug(f"Scheduling task {task.name} {task.type} {task.payload_id}")
|
| 573 |
+
task.client_id = self._client_id
|
| 574 |
+
if num_returns is None:
|
| 575 |
+
num_returns = 1
|
| 576 |
+
|
| 577 |
+
num_return_refs = num_returns
|
| 578 |
+
if num_return_refs == -1:
|
| 579 |
+
num_return_refs = 1
|
| 580 |
+
id_futures = [Future() for _ in range(num_return_refs)]
|
| 581 |
+
|
| 582 |
+
def populate_ids(resp: Union[ray_client_pb2.DataResponse, Exception]) -> None:
|
| 583 |
+
if isinstance(resp, Exception):
|
| 584 |
+
if isinstance(resp, grpc.RpcError):
|
| 585 |
+
resp = decode_exception(resp)
|
| 586 |
+
for future in id_futures:
|
| 587 |
+
future.set_exception(resp)
|
| 588 |
+
return
|
| 589 |
+
|
| 590 |
+
ticket = resp.task_ticket
|
| 591 |
+
if not ticket.valid:
|
| 592 |
+
try:
|
| 593 |
+
ex = cloudpickle.loads(ticket.error)
|
| 594 |
+
except (pickle.UnpicklingError, TypeError) as e_new:
|
| 595 |
+
ex = e_new
|
| 596 |
+
for future in id_futures:
|
| 597 |
+
future.set_exception(ex)
|
| 598 |
+
return
|
| 599 |
+
|
| 600 |
+
if len(ticket.return_ids) != num_return_refs:
|
| 601 |
+
exc = ValueError(
|
| 602 |
+
f"Expected {num_return_refs} returns but received "
|
| 603 |
+
f"{len(ticket.return_ids)}"
|
| 604 |
+
)
|
| 605 |
+
for future, raw_id in zip(id_futures, ticket.return_ids):
|
| 606 |
+
future.set_exception(exc)
|
| 607 |
+
return
|
| 608 |
+
|
| 609 |
+
for future, raw_id in zip(id_futures, ticket.return_ids):
|
| 610 |
+
future.set_result(raw_id)
|
| 611 |
+
|
| 612 |
+
self.data_client.Schedule(task, populate_ids)
|
| 613 |
+
|
| 614 |
+
self.total_outbound_message_size_bytes += task.ByteSize()
|
| 615 |
+
if (
|
| 616 |
+
self.total_outbound_message_size_bytes > MESSAGE_SIZE_THRESHOLD
|
| 617 |
+
and log_once("client_communication_overhead_warning")
|
| 618 |
+
):
|
| 619 |
+
warnings.warn(
|
| 620 |
+
"More than 10MB of messages have been created to schedule "
|
| 621 |
+
"tasks on the server. This can be slow on Ray Client due to "
|
| 622 |
+
"communication overhead over the network. If you're running "
|
| 623 |
+
"many fine-grained tasks, consider running them inside a "
|
| 624 |
+
'single remote function. See the section on "Too '
|
| 625 |
+
'fine-grained tasks" in the Ray Design Patterns document for '
|
| 626 |
+
f"more details: {DESIGN_PATTERN_FINE_GRAIN_TASKS_LINK}. If "
|
| 627 |
+
"your functions frequently use large objects, consider "
|
| 628 |
+
"storing the objects remotely with ray.put. An example of "
|
| 629 |
+
'this is shown in the "Closure capture of large / '
|
| 630 |
+
'unserializable object" section of the Ray Design Patterns '
|
| 631 |
+
"document, available here: "
|
| 632 |
+
f"{DESIGN_PATTERN_LARGE_OBJECTS_LINK}",
|
| 633 |
+
UserWarning,
|
| 634 |
+
)
|
| 635 |
+
return id_futures
|
| 636 |
+
|
| 637 |
+
def call_release(self, id: bytes) -> None:
|
| 638 |
+
if self.closed:
|
| 639 |
+
return
|
| 640 |
+
self.reference_count[id] -= 1
|
| 641 |
+
if self.reference_count[id] == 0:
|
| 642 |
+
self._release_server(id)
|
| 643 |
+
del self.reference_count[id]
|
| 644 |
+
|
| 645 |
+
def _release_server(self, id: bytes) -> None:
|
| 646 |
+
if self.data_client is not None:
|
| 647 |
+
logger.debug(f"Releasing {id.hex()}")
|
| 648 |
+
self.data_client.ReleaseObject(ray_client_pb2.ReleaseRequest(ids=[id]))
|
| 649 |
+
|
| 650 |
+
def call_retain(self, id: bytes) -> None:
|
| 651 |
+
logger.debug(f"Retaining {id.hex()}")
|
| 652 |
+
self.reference_count[id] += 1
|
| 653 |
+
|
| 654 |
+
def close(self):
|
| 655 |
+
self._in_shutdown = True
|
| 656 |
+
self.closed = True
|
| 657 |
+
self.data_client.close()
|
| 658 |
+
self.log_client.close()
|
| 659 |
+
self.server = None
|
| 660 |
+
if self.channel:
|
| 661 |
+
self.channel.close()
|
| 662 |
+
self.channel = None
|
| 663 |
+
|
| 664 |
+
def get_actor(
|
| 665 |
+
self, name: str, namespace: Optional[str] = None
|
| 666 |
+
) -> ClientActorHandle:
|
| 667 |
+
task = ray_client_pb2.ClientTask()
|
| 668 |
+
task.type = ray_client_pb2.ClientTask.NAMED_ACTOR
|
| 669 |
+
task.name = name
|
| 670 |
+
task.namespace = namespace or ""
|
| 671 |
+
# Populate task.data with empty args and kwargs
|
| 672 |
+
task.data = dumps_from_client(([], {}), self._client_id)
|
| 673 |
+
futures = self._call_schedule_for_task(task, 1)
|
| 674 |
+
assert len(futures) == 1
|
| 675 |
+
handle = ClientActorHandle(ClientActorRef(futures[0], weak_ref=True))
|
| 676 |
+
# `actor_ref.is_nil()` waits until the underlying ID is resolved.
|
| 677 |
+
# This is needed because `get_actor` is often used to check the
|
| 678 |
+
# existence of an actor.
|
| 679 |
+
if handle.actor_ref.is_nil():
|
| 680 |
+
raise ValueError(f"ActorID for {name} is empty")
|
| 681 |
+
return handle
|
| 682 |
+
|
| 683 |
+
def terminate_actor(self, actor: ClientActorHandle, no_restart: bool) -> None:
|
| 684 |
+
if not isinstance(actor, ClientActorHandle):
|
| 685 |
+
raise ValueError(
|
| 686 |
+
"ray.kill() only supported for actors. Got: {}.".format(type(actor))
|
| 687 |
+
)
|
| 688 |
+
term_actor = ray_client_pb2.TerminateRequest.ActorTerminate()
|
| 689 |
+
term_actor.id = actor.actor_ref.id
|
| 690 |
+
term_actor.no_restart = no_restart
|
| 691 |
+
term = ray_client_pb2.TerminateRequest(actor=term_actor)
|
| 692 |
+
term.client_id = self._client_id
|
| 693 |
+
try:
|
| 694 |
+
self.data_client.Terminate(term)
|
| 695 |
+
except grpc.RpcError as e:
|
| 696 |
+
raise decode_exception(e)
|
| 697 |
+
|
| 698 |
+
def terminate_task(
|
| 699 |
+
self, obj: ClientObjectRef, force: bool, recursive: bool
|
| 700 |
+
) -> None:
|
| 701 |
+
if not isinstance(obj, ClientObjectRef):
|
| 702 |
+
raise TypeError(
|
| 703 |
+
"ray.cancel() only supported for non-actor object refs. "
|
| 704 |
+
f"Got: {type(obj)}."
|
| 705 |
+
)
|
| 706 |
+
term_object = ray_client_pb2.TerminateRequest.TaskObjectTerminate()
|
| 707 |
+
term_object.id = obj.id
|
| 708 |
+
term_object.force = force
|
| 709 |
+
term_object.recursive = recursive
|
| 710 |
+
term = ray_client_pb2.TerminateRequest(task_object=term_object)
|
| 711 |
+
term.client_id = self._client_id
|
| 712 |
+
try:
|
| 713 |
+
self.data_client.Terminate(term)
|
| 714 |
+
except grpc.RpcError as e:
|
| 715 |
+
raise decode_exception(e)
|
| 716 |
+
|
| 717 |
+
def get_cluster_info(
|
| 718 |
+
self,
|
| 719 |
+
req_type: ray_client_pb2.ClusterInfoType.TypeEnum,
|
| 720 |
+
timeout: Optional[float] = None,
|
| 721 |
+
):
|
| 722 |
+
req = ray_client_pb2.ClusterInfoRequest()
|
| 723 |
+
req.type = req_type
|
| 724 |
+
resp = self.server.ClusterInfo(req, timeout=timeout, metadata=self.metadata)
|
| 725 |
+
if resp.WhichOneof("response_type") == "resource_table":
|
| 726 |
+
# translate from a proto map to a python dict
|
| 727 |
+
output_dict = {k: v for k, v in resp.resource_table.table.items()}
|
| 728 |
+
return output_dict
|
| 729 |
+
elif resp.WhichOneof("response_type") == "runtime_context":
|
| 730 |
+
return resp.runtime_context
|
| 731 |
+
return json.loads(resp.json)
|
| 732 |
+
|
| 733 |
+
def internal_kv_get(self, key: bytes, namespace: Optional[bytes]) -> bytes:
|
| 734 |
+
req = ray_client_pb2.KVGetRequest(key=key, namespace=namespace)
|
| 735 |
+
try:
|
| 736 |
+
resp = self._call_stub("KVGet", req, metadata=self.metadata)
|
| 737 |
+
except grpc.RpcError as e:
|
| 738 |
+
raise decode_exception(e)
|
| 739 |
+
if resp.HasField("value"):
|
| 740 |
+
return resp.value
|
| 741 |
+
# Value is None when the key does not exist in the KV.
|
| 742 |
+
return None
|
| 743 |
+
|
| 744 |
+
def internal_kv_exists(self, key: bytes, namespace: Optional[bytes]) -> bool:
|
| 745 |
+
req = ray_client_pb2.KVExistsRequest(key=key, namespace=namespace)
|
| 746 |
+
try:
|
| 747 |
+
resp = self._call_stub("KVExists", req, metadata=self.metadata)
|
| 748 |
+
except grpc.RpcError as e:
|
| 749 |
+
raise decode_exception(e)
|
| 750 |
+
return resp.exists
|
| 751 |
+
|
| 752 |
+
def internal_kv_put(
|
| 753 |
+
self, key: bytes, value: bytes, overwrite: bool, namespace: Optional[bytes]
|
| 754 |
+
) -> bool:
|
| 755 |
+
req = ray_client_pb2.KVPutRequest(
|
| 756 |
+
key=key, value=value, overwrite=overwrite, namespace=namespace
|
| 757 |
+
)
|
| 758 |
+
metadata = self._add_ids_to_metadata(self.metadata)
|
| 759 |
+
try:
|
| 760 |
+
resp = self._call_stub("KVPut", req, metadata=metadata)
|
| 761 |
+
except grpc.RpcError as e:
|
| 762 |
+
raise decode_exception(e)
|
| 763 |
+
return resp.already_exists
|
| 764 |
+
|
| 765 |
+
def internal_kv_del(
|
| 766 |
+
self, key: bytes, del_by_prefix: bool, namespace: Optional[bytes]
|
| 767 |
+
) -> int:
|
| 768 |
+
req = ray_client_pb2.KVDelRequest(
|
| 769 |
+
key=key, del_by_prefix=del_by_prefix, namespace=namespace
|
| 770 |
+
)
|
| 771 |
+
metadata = self._add_ids_to_metadata(self.metadata)
|
| 772 |
+
try:
|
| 773 |
+
resp = self._call_stub("KVDel", req, metadata=metadata)
|
| 774 |
+
except grpc.RpcError as e:
|
| 775 |
+
raise decode_exception(e)
|
| 776 |
+
return resp.deleted_num
|
| 777 |
+
|
| 778 |
+
def internal_kv_list(
|
| 779 |
+
self, prefix: bytes, namespace: Optional[bytes]
|
| 780 |
+
) -> List[bytes]:
|
| 781 |
+
try:
|
| 782 |
+
req = ray_client_pb2.KVListRequest(prefix=prefix, namespace=namespace)
|
| 783 |
+
return self._call_stub("KVList", req, metadata=self.metadata).keys
|
| 784 |
+
except grpc.RpcError as e:
|
| 785 |
+
raise decode_exception(e)
|
| 786 |
+
|
| 787 |
+
def pin_runtime_env_uri(self, uri: str, expiration_s: int) -> None:
|
| 788 |
+
req = ray_client_pb2.ClientPinRuntimeEnvURIRequest(
|
| 789 |
+
uri=uri, expiration_s=expiration_s
|
| 790 |
+
)
|
| 791 |
+
self._call_stub("PinRuntimeEnvURI", req, metadata=self.metadata)
|
| 792 |
+
|
| 793 |
+
def list_named_actors(self, all_namespaces: bool) -> List[Dict[str, str]]:
|
| 794 |
+
req = ray_client_pb2.ClientListNamedActorsRequest(all_namespaces=all_namespaces)
|
| 795 |
+
return json.loads(self.data_client.ListNamedActors(req).actors_json)
|
| 796 |
+
|
| 797 |
+
def is_initialized(self) -> bool:
|
| 798 |
+
if not self.is_connected() or self.server is None:
|
| 799 |
+
return False
|
| 800 |
+
if not self._serverside_ray_initialized:
|
| 801 |
+
# We only check that Ray is initialized on the server once to
|
| 802 |
+
# avoid making an RPC every time this function is called. This is
|
| 803 |
+
# safe to do because Ray only 'un-initializes' on the server when
|
| 804 |
+
# the Client connection is torn down.
|
| 805 |
+
self._serverside_ray_initialized = self.get_cluster_info(
|
| 806 |
+
ray_client_pb2.ClusterInfoType.IS_INITIALIZED
|
| 807 |
+
)
|
| 808 |
+
|
| 809 |
+
return self._serverside_ray_initialized
|
| 810 |
+
|
| 811 |
+
def ping_server(self, timeout=None) -> bool:
|
| 812 |
+
"""Simple health check.
|
| 813 |
+
|
| 814 |
+
Piggybacks the IS_INITIALIZED call to check if the server provides
|
| 815 |
+
an actual response.
|
| 816 |
+
"""
|
| 817 |
+
if self.server is not None:
|
| 818 |
+
logger.debug("Pinging server.")
|
| 819 |
+
result = self.get_cluster_info(
|
| 820 |
+
ray_client_pb2.ClusterInfoType.PING, timeout=timeout
|
| 821 |
+
)
|
| 822 |
+
return result is not None
|
| 823 |
+
return False
|
| 824 |
+
|
| 825 |
+
def is_connected(self) -> bool:
|
| 826 |
+
return not self._in_shutdown and self._has_connected
|
| 827 |
+
|
| 828 |
+
def _server_init(
|
| 829 |
+
self, job_config: JobConfig, ray_init_kwargs: Optional[Dict[str, Any]] = None
|
| 830 |
+
):
|
| 831 |
+
"""Initialize the server"""
|
| 832 |
+
if ray_init_kwargs is None:
|
| 833 |
+
ray_init_kwargs = {}
|
| 834 |
+
try:
|
| 835 |
+
if job_config is None:
|
| 836 |
+
serialized_job_config = None
|
| 837 |
+
else:
|
| 838 |
+
with tempfile.TemporaryDirectory() as tmp_dir:
|
| 839 |
+
runtime_env = job_config.runtime_env or {}
|
| 840 |
+
runtime_env = upload_py_modules_if_needed(
|
| 841 |
+
runtime_env, tmp_dir, logger=logger
|
| 842 |
+
)
|
| 843 |
+
runtime_env = upload_working_dir_if_needed(
|
| 844 |
+
runtime_env, tmp_dir, logger=logger
|
| 845 |
+
)
|
| 846 |
+
# Remove excludes, it isn't relevant after the upload step.
|
| 847 |
+
runtime_env.pop("excludes", None)
|
| 848 |
+
job_config.set_runtime_env(runtime_env, validate=True)
|
| 849 |
+
|
| 850 |
+
serialized_job_config = pickle.dumps(job_config)
|
| 851 |
+
|
| 852 |
+
response = self.data_client.Init(
|
| 853 |
+
ray_client_pb2.InitRequest(
|
| 854 |
+
job_config=serialized_job_config,
|
| 855 |
+
ray_init_kwargs=json.dumps(ray_init_kwargs),
|
| 856 |
+
reconnect_grace_period=self._reconnect_grace_period,
|
| 857 |
+
)
|
| 858 |
+
)
|
| 859 |
+
if not response.ok:
|
| 860 |
+
raise ConnectionAbortedError(
|
| 861 |
+
f"Initialization failure from server:\n{response.msg}"
|
| 862 |
+
)
|
| 863 |
+
|
| 864 |
+
except grpc.RpcError as e:
|
| 865 |
+
raise decode_exception(e)
|
| 866 |
+
|
| 867 |
+
def _convert_actor(self, actor: "ActorClass") -> str:
|
| 868 |
+
"""Register a ClientActorClass for the ActorClass and return a UUID"""
|
| 869 |
+
key = uuid.uuid4().hex
|
| 870 |
+
cls = actor.__ray_metadata__.modified_class
|
| 871 |
+
self._converted[key] = ClientActorClass(cls, options=actor._default_options)
|
| 872 |
+
return key
|
| 873 |
+
|
| 874 |
+
def _convert_function(self, func: "RemoteFunction") -> str:
|
| 875 |
+
"""Register a ClientRemoteFunc for the ActorClass and return a UUID"""
|
| 876 |
+
key = uuid.uuid4().hex
|
| 877 |
+
self._converted[key] = ClientRemoteFunc(
|
| 878 |
+
func._function, options=func._default_options
|
| 879 |
+
)
|
| 880 |
+
return key
|
| 881 |
+
|
| 882 |
+
def _get_converted(self, key: str) -> "ClientStub":
|
| 883 |
+
"""Given a UUID, return the converted object"""
|
| 884 |
+
return self._converted[key]
|
| 885 |
+
|
| 886 |
+
def _converted_key_exists(self, key: str) -> bool:
|
| 887 |
+
"""Check if a key UUID is present in the store of converted objects."""
|
| 888 |
+
return key in self._converted
|
| 889 |
+
|
| 890 |
+
def _dumps_from_client(self, val) -> bytes:
|
| 891 |
+
return dumps_from_client(val, self._client_id)
|
| 892 |
+
|
| 893 |
+
|
| 894 |
+
def make_client_id() -> str:
|
| 895 |
+
id = uuid.uuid4()
|
| 896 |
+
return id.hex
|
| 897 |
+
|
| 898 |
+
|
| 899 |
+
def decode_exception(e: grpc.RpcError) -> Exception:
|
| 900 |
+
if e.code() != grpc.StatusCode.ABORTED:
|
| 901 |
+
# The ABORTED status code is used by the server when an application
|
| 902 |
+
# error is serialized into the the exception details. If the code
|
| 903 |
+
# isn't ABORTED, then return the original error since there's no
|
| 904 |
+
# serialized error to decode.
|
| 905 |
+
# See server.py::return_exception_in_context for details
|
| 906 |
+
return ConnectionError(f"GRPC connection failed: {e}")
|
| 907 |
+
data = base64.standard_b64decode(e.details())
|
| 908 |
+
return loads_from_server(data)
|
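For context, a minimal sketch of how this Worker is typically reached in practice: user code connects through the public ray.init() Ray Client URI, which constructs a Worker like the one above and paces its gRPC connection attempts with the backoff() schedule shown earlier. The address below is a placeholder, assuming a reachable Ray Client server.

import ray

# Connecting with a ray:// address routes through ray.util.client.
ray.init("ray://127.0.0.1:10001")

@ray.remote
def f(x):
    return x + 1

print(ray.get(f.remote(1)))  # -> 2
ray.shutdown()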
.venv/lib/python3.11/site-packages/ray/util/dask/__init__.py
ADDED
@@ -0,0 +1,63 @@
import dask
from .scheduler import (
    ray_dask_get,
    ray_dask_get_sync,
    enable_dask_on_ray,
    disable_dask_on_ray,
)
from .callbacks import (
    RayDaskCallback,
    local_ray_callbacks,
    unpack_ray_callbacks,
    ProgressBarCallback,
)
from .optimizations import dataframe_optimize

dask_persist = dask.persist


def ray_dask_persist(*args, **kwargs):
    kwargs["ray_persist"] = True
    return dask_persist(*args, **kwargs)


ray_dask_persist.__doc__ = dask_persist.__doc__

dask_persist_mixin = dask.base.DaskMethodsMixin.persist


def ray_dask_persist_mixin(self, **kwargs):
    kwargs["ray_persist"] = True
    return dask_persist_mixin(self, **kwargs)


ray_dask_persist_mixin.__doc__ = dask_persist_mixin.__doc__


# We patch dask in order to inject a kwarg into its `dask.persist()` calls,
# which the Dask-on-Ray scheduler needs.
# FIXME(Clark): Monkey patching is bad and we should try to avoid this.
def patch_dask(ray_dask_persist, ray_dask_persist_mixin):
    dask.persist = ray_dask_persist
    dask.base.DaskMethodsMixin.persist = ray_dask_persist_mixin


patch_dask(ray_dask_persist, ray_dask_persist_mixin)

__all__ = [
    # Config
    "enable_dask_on_ray",
    "disable_dask_on_ray",
    # Schedulers
    "ray_dask_get",
    "ray_dask_get_sync",
    # Helpers
    "ray_dask_persist",
    # Callbacks
    "RayDaskCallback",
    "local_ray_callbacks",
    "unpack_ray_callbacks",
    # Optimizations
    "dataframe_optimize",
    "ProgressBarCallback",
]
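A minimal usage sketch for the Dask-on-Ray scheduler exported above (assumes dask, dask[array], and ray are installed and a local Ray instance can be started):

import ray
import dask.array as da
from ray.util.dask import ray_dask_get

ray.init()
# Build a small Dask graph and execute it on Ray by passing the
# Dask-on-Ray scheduler to compute().
x = da.ones((1000, 1000), chunks=(100, 100))
total = x.sum().compute(scheduler=ray_dask_get)
ray.shutdown()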
.venv/lib/python3.11/site-packages/ray/util/dask/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (1.68 kB)

.venv/lib/python3.11/site-packages/ray/util/dask/__pycache__/callbacks.cpython-311.pyc
ADDED
Binary file (17.9 kB)

.venv/lib/python3.11/site-packages/ray/util/dask/__pycache__/common.cpython-311.pyc
ADDED
Binary file (4.98 kB)

.venv/lib/python3.11/site-packages/ray/util/dask/__pycache__/optimizations.cpython-311.pyc
ADDED
Binary file (7.55 kB)

.venv/lib/python3.11/site-packages/ray/util/dask/__pycache__/scheduler.cpython-311.pyc
ADDED
Binary file (25.1 kB)