Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +1 -0
- .venv/lib/python3.11/site-packages/grpc/__pycache__/_channel.cpython-311.pyc +3 -0
- .venv/lib/python3.11/site-packages/vllm/core/block_manager.py +520 -0
- .venv/lib/python3.11/site-packages/vllm/core/evictor.py +156 -0
- .venv/lib/python3.11/site-packages/vllm/device_allocator/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/vllm/device_allocator/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/device_allocator/cumem.py +256 -0
- .venv/lib/python3.11/site-packages/vllm/distributed/__init__.py +5 -0
- .venv/lib/python3.11/site-packages/vllm/distributed/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/distributed/__pycache__/communication_op.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/distributed/__pycache__/parallel_state.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/distributed/__pycache__/utils.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/distributed/communication_op.py +34 -0
- .venv/lib/python3.11/site-packages/vllm/distributed/device_communicators/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/vllm/distributed/device_communicators/__pycache__/custom_all_reduce.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/distributed/device_communicators/__pycache__/pynccl.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/distributed/device_communicators/cuda_wrapper.py +173 -0
- .venv/lib/python3.11/site-packages/vllm/distributed/device_communicators/custom_all_reduce.py +305 -0
- .venv/lib/python3.11/site-packages/vllm/distributed/device_communicators/custom_all_reduce_utils.py +257 -0
- .venv/lib/python3.11/site-packages/vllm/distributed/device_communicators/hpu_communicator.py +50 -0
- .venv/lib/python3.11/site-packages/vllm/distributed/device_communicators/pynccl.py +217 -0
- .venv/lib/python3.11/site-packages/vllm/distributed/device_communicators/shm_broadcast.py +530 -0
- .venv/lib/python3.11/site-packages/vllm/distributed/device_communicators/xpu_communicator.py +49 -0
- .venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/__pycache__/kv_transfer_agent.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/kv_connector/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/kv_connector/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/kv_connector/__pycache__/base.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/kv_connector/__pycache__/factory.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/kv_connector/__pycache__/simple_connector.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/kv_connector/base.py +123 -0
- .venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/kv_connector/factory.py +50 -0
- .venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/kv_connector/simple_connector.py +314 -0
- .venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/kv_lookup_buffer/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/kv_lookup_buffer/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/kv_lookup_buffer/__pycache__/base.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/kv_lookup_buffer/__pycache__/simple_buffer.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/kv_lookup_buffer/base.py +109 -0
- .venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py +243 -0
- .venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/kv_pipe/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/kv_pipe/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/kv_pipe/__pycache__/base.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/kv_pipe/__pycache__/mooncake_pipe.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/kv_pipe/__pycache__/pynccl_pipe.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/kv_pipe/base.py +66 -0
- .venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/kv_pipe/mooncake_pipe.py +274 -0
- .venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py +277 -0
- .venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/kv_transfer_agent.py +76 -0
- .venv/lib/python3.11/site-packages/vllm/distributed/parallel_state.py +1285 -0
.gitattributes
CHANGED
|
@@ -199,3 +199,4 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/_
|
|
| 199 |
.venv/lib/python3.11/site-packages/google/_upb/_message.abi3.so filter=lfs diff=lfs merge=lfs -text
|
| 200 |
.venv/lib/python3.11/site-packages/google/protobuf/__pycache__/descriptor_pb2.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 201 |
.venv/lib/python3.11/site-packages/jinja2/__pycache__/compiler.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 199 |
.venv/lib/python3.11/site-packages/google/_upb/_message.abi3.so filter=lfs diff=lfs merge=lfs -text
|
| 200 |
.venv/lib/python3.11/site-packages/google/protobuf/__pycache__/descriptor_pb2.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 201 |
.venv/lib/python3.11/site-packages/jinja2/__pycache__/compiler.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 202 |
+
.venv/lib/python3.11/site-packages/grpc/__pycache__/_channel.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
.venv/lib/python3.11/site-packages/grpc/__pycache__/_channel.cpython-311.pyc
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f4a55c7e16388e486f345b7c775a758b4a05a398d378d7491610665c89805e0f
|
| 3 |
+
size 103674
|
.venv/lib/python3.11/site-packages/vllm/core/block_manager.py
ADDED
|
@@ -0,0 +1,520 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
"""A block manager that manages token blocks."""
|
| 3 |
+
from typing import Dict, List, Optional
|
| 4 |
+
from typing import Sequence as GenericSequence
|
| 5 |
+
from typing import Tuple
|
| 6 |
+
|
| 7 |
+
from vllm.core.block.block_table import BlockTable
|
| 8 |
+
from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator
|
| 9 |
+
from vllm.core.block.interfaces import Block
|
| 10 |
+
from vllm.core.block.prefix_caching_block import (ComputedBlocksTracker,
|
| 11 |
+
LastAccessBlocksTracker)
|
| 12 |
+
from vllm.core.block.utils import check_no_caching_or_swa_for_blockmgr_encdec
|
| 13 |
+
from vllm.core.interfaces import AllocStatus, BlockSpaceManager
|
| 14 |
+
from vllm.sequence import Sequence, SequenceGroup, SequenceStatus
|
| 15 |
+
from vllm.utils import Device
|
| 16 |
+
|
| 17 |
+
SeqId = int
|
| 18 |
+
EncoderSeqId = str
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class SelfAttnBlockSpaceManager(BlockSpaceManager):
|
| 22 |
+
"""BlockSpaceManager which manages the allocation of KV cache.
|
| 23 |
+
|
| 24 |
+
It owns responsibility for allocation, swapping, allocating memory for
|
| 25 |
+
autoregressively-generated tokens, and other advanced features such as
|
| 26 |
+
prefix caching, forking/copy-on-write, and sliding-window memory allocation.
|
| 27 |
+
|
| 28 |
+
This class implements the design described in
|
| 29 |
+
https://github.com/vllm-project/vllm/pull/3492.
|
| 30 |
+
|
| 31 |
+
Lookahead slots
|
| 32 |
+
The block manager has the notion of a "lookahead slot". These are slots
|
| 33 |
+
in the KV cache that are allocated for a sequence. Unlike the other
|
| 34 |
+
allocated slots, the content of these slots is undefined -- the worker
|
| 35 |
+
may use the memory allocations in any way.
|
| 36 |
+
|
| 37 |
+
In practice, a worker could use these lookahead slots to run multiple
|
| 38 |
+
forward passes for a single scheduler invocation. Each successive
|
| 39 |
+
forward pass would write KV activations to the corresponding lookahead
|
| 40 |
+
slot. This allows low inter-token latency use-cases, where the overhead
|
| 41 |
+
of continuous batching scheduling is amortized over >1 generated tokens.
|
| 42 |
+
|
| 43 |
+
Speculative decoding uses lookahead slots to store KV activations of
|
| 44 |
+
proposal tokens.
|
| 45 |
+
|
| 46 |
+
See https://github.com/vllm-project/vllm/pull/3250 for more information
|
| 47 |
+
on lookahead scheduling.
|
| 48 |
+
|
| 49 |
+
Args:
|
| 50 |
+
block_size (int): The size of each memory block.
|
| 51 |
+
num_gpu_blocks (int): The number of memory blocks allocated on GPU.
|
| 52 |
+
num_cpu_blocks (int): The number of memory blocks allocated on CPU.
|
| 53 |
+
watermark (float, optional): The threshold used for memory swapping.
|
| 54 |
+
Defaults to 0.01.
|
| 55 |
+
sliding_window (Optional[int], optional): The size of the sliding
|
| 56 |
+
window. Defaults to None.
|
| 57 |
+
enable_caching (bool, optional): Flag indicating whether caching is
|
| 58 |
+
enabled. Defaults to False.
|
| 59 |
+
"""
|
| 60 |
+
|
| 61 |
+
def __init__(
|
| 62 |
+
self,
|
| 63 |
+
block_size: int,
|
| 64 |
+
num_gpu_blocks: int,
|
| 65 |
+
num_cpu_blocks: int,
|
| 66 |
+
watermark: float = 0.01,
|
| 67 |
+
sliding_window: Optional[int] = None,
|
| 68 |
+
enable_caching: bool = False,
|
| 69 |
+
) -> None:
|
| 70 |
+
self.block_size = block_size
|
| 71 |
+
self.num_total_gpu_blocks = num_gpu_blocks
|
| 72 |
+
self.num_total_cpu_blocks = num_cpu_blocks
|
| 73 |
+
|
| 74 |
+
self.sliding_window = sliding_window
|
| 75 |
+
# max_block_sliding_window is the max number of blocks that need to be
|
| 76 |
+
# allocated
|
| 77 |
+
self.max_block_sliding_window = None
|
| 78 |
+
if sliding_window is not None:
|
| 79 |
+
# +1 here because // rounds down
|
| 80 |
+
num_blocks = sliding_window // block_size + 1
|
| 81 |
+
# +1 here because the last block may not be full,
|
| 82 |
+
# and so the sequence stretches one more block at the beginning
|
| 83 |
+
# For example, if sliding_window is 3 and block_size is 4,
|
| 84 |
+
# we may need 2 blocks when the second block only holds 1 token.
|
| 85 |
+
self.max_block_sliding_window = num_blocks + 1
|
| 86 |
+
|
| 87 |
+
self.watermark = watermark
|
| 88 |
+
assert watermark >= 0.0
|
| 89 |
+
|
| 90 |
+
self.enable_caching = enable_caching
|
| 91 |
+
|
| 92 |
+
self.watermark_blocks = int(watermark * num_gpu_blocks)
|
| 93 |
+
|
| 94 |
+
self.block_allocator = CpuGpuBlockAllocator.create(
|
| 95 |
+
allocator_type="prefix_caching" if enable_caching else "naive",
|
| 96 |
+
num_gpu_blocks=num_gpu_blocks,
|
| 97 |
+
num_cpu_blocks=num_cpu_blocks,
|
| 98 |
+
block_size=block_size,
|
| 99 |
+
)
|
| 100 |
+
|
| 101 |
+
self.block_tables: Dict[SeqId, BlockTable] = {}
|
| 102 |
+
self.cross_block_tables: Dict[EncoderSeqId, BlockTable] = {}
|
| 103 |
+
|
| 104 |
+
self._computed_blocks_tracker = ComputedBlocksTracker(
|
| 105 |
+
self.block_allocator, self.block_size, self.enable_caching)
|
| 106 |
+
self._last_access_blocks_tracker = LastAccessBlocksTracker(
|
| 107 |
+
self.block_allocator)
|
| 108 |
+
|
| 109 |
+
def can_allocate(self,
|
| 110 |
+
seq_group: SequenceGroup,
|
| 111 |
+
num_lookahead_slots: int = 0) -> AllocStatus:
|
| 112 |
+
# FIXME(woosuk): Here we assume that all sequences in the group share
|
| 113 |
+
# the same prompt. This may not be true for preempted sequences.
|
| 114 |
+
|
| 115 |
+
check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group)
|
| 116 |
+
|
| 117 |
+
seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0]
|
| 118 |
+
num_required_blocks = BlockTable.get_num_required_blocks(
|
| 119 |
+
seq.get_token_ids(),
|
| 120 |
+
block_size=self.block_size,
|
| 121 |
+
num_lookahead_slots=num_lookahead_slots,
|
| 122 |
+
)
|
| 123 |
+
|
| 124 |
+
if seq_group.is_encoder_decoder():
|
| 125 |
+
encoder_seq = seq_group.get_encoder_seq()
|
| 126 |
+
assert encoder_seq is not None
|
| 127 |
+
num_required_blocks += BlockTable.get_num_required_blocks(
|
| 128 |
+
encoder_seq.get_token_ids(),
|
| 129 |
+
block_size=self.block_size,
|
| 130 |
+
)
|
| 131 |
+
|
| 132 |
+
if self.max_block_sliding_window is not None:
|
| 133 |
+
num_required_blocks = min(num_required_blocks,
|
| 134 |
+
self.max_block_sliding_window)
|
| 135 |
+
|
| 136 |
+
num_free_gpu_blocks = self.block_allocator.get_num_free_blocks(
|
| 137 |
+
device=Device.GPU)
|
| 138 |
+
|
| 139 |
+
# Use watermark to avoid frequent cache eviction.
|
| 140 |
+
if (self.num_total_gpu_blocks - num_required_blocks
|
| 141 |
+
< self.watermark_blocks):
|
| 142 |
+
return AllocStatus.NEVER
|
| 143 |
+
if num_free_gpu_blocks - num_required_blocks >= self.watermark_blocks:
|
| 144 |
+
return AllocStatus.OK
|
| 145 |
+
else:
|
| 146 |
+
return AllocStatus.LATER
|
| 147 |
+
|
| 148 |
+
def _allocate_sequence(self, seq: Sequence) -> BlockTable:
|
| 149 |
+
block_table = BlockTable(
|
| 150 |
+
block_size=self.block_size,
|
| 151 |
+
block_allocator=self.block_allocator,
|
| 152 |
+
max_block_sliding_window=self.max_block_sliding_window,
|
| 153 |
+
)
|
| 154 |
+
if seq.get_token_ids():
|
| 155 |
+
# NOTE: If there are any factors affecting the block besides
|
| 156 |
+
# token_ids, they should be added as input to extra_hash.
|
| 157 |
+
extra_hash = seq.extra_hash()
|
| 158 |
+
|
| 159 |
+
# Add blocks to the block table only if the sequence is non empty.
|
| 160 |
+
block_table.allocate(token_ids=seq.get_token_ids(),
|
| 161 |
+
extra_hash=extra_hash)
|
| 162 |
+
|
| 163 |
+
return block_table
|
| 164 |
+
|
| 165 |
+
def allocate(self, seq_group: SequenceGroup) -> None:
|
| 166 |
+
|
| 167 |
+
# Allocate self-attention block tables for decoder sequences
|
| 168 |
+
waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING)
|
| 169 |
+
assert not (set(seq.seq_id for seq in waiting_seqs)
|
| 170 |
+
& self.block_tables.keys()), "block table already exists"
|
| 171 |
+
|
| 172 |
+
# NOTE: Here we assume that all sequences in the group have the same
|
| 173 |
+
# prompt.
|
| 174 |
+
seq = waiting_seqs[0]
|
| 175 |
+
block_table: BlockTable = self._allocate_sequence(seq)
|
| 176 |
+
self.block_tables[seq.seq_id] = block_table
|
| 177 |
+
|
| 178 |
+
# Track seq
|
| 179 |
+
self._last_access_blocks_tracker.add_seq(seq.seq_id)
|
| 180 |
+
|
| 181 |
+
# Assign the block table for each sequence.
|
| 182 |
+
for seq in waiting_seqs[1:]:
|
| 183 |
+
self.block_tables[seq.seq_id] = block_table.fork()
|
| 184 |
+
|
| 185 |
+
# Track seq
|
| 186 |
+
self._last_access_blocks_tracker.add_seq(seq.seq_id)
|
| 187 |
+
|
| 188 |
+
# Allocate cross-attention block table for encoder sequence
|
| 189 |
+
#
|
| 190 |
+
# NOTE: Here we assume that all sequences in the group have the same
|
| 191 |
+
# encoder prompt.
|
| 192 |
+
request_id = seq_group.request_id
|
| 193 |
+
|
| 194 |
+
assert (request_id
|
| 195 |
+
not in self.cross_block_tables), \
|
| 196 |
+
"block table already exists"
|
| 197 |
+
|
| 198 |
+
check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group)
|
| 199 |
+
|
| 200 |
+
if seq_group.is_encoder_decoder():
|
| 201 |
+
encoder_seq = seq_group.get_encoder_seq()
|
| 202 |
+
assert encoder_seq is not None
|
| 203 |
+
block_table = self._allocate_sequence(encoder_seq)
|
| 204 |
+
self.cross_block_tables[request_id] = block_table
|
| 205 |
+
|
| 206 |
+
def can_append_slots(self, seq_group: SequenceGroup,
|
| 207 |
+
num_lookahead_slots: int) -> bool:
|
| 208 |
+
"""Determine if there is enough space in the GPU KV cache to continue
|
| 209 |
+
generation of the specified sequence group.
|
| 210 |
+
|
| 211 |
+
We use a worst-case heuristic: assume each touched block will require a
|
| 212 |
+
new allocation (either via CoW or new block). We can append slots if the
|
| 213 |
+
number of touched blocks is less than the number of free blocks.
|
| 214 |
+
|
| 215 |
+
"Lookahead slots" are slots that are allocated in addition to the slots
|
| 216 |
+
for known tokens. The contents of the lookahead slots are not defined.
|
| 217 |
+
This is used by speculative decoding when speculating future tokens.
|
| 218 |
+
"""
|
| 219 |
+
|
| 220 |
+
num_touched_blocks = 0
|
| 221 |
+
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
|
| 222 |
+
block_table = self.block_tables[seq.seq_id]
|
| 223 |
+
|
| 224 |
+
num_touched_blocks += (
|
| 225 |
+
block_table.get_num_blocks_touched_by_append_slots(
|
| 226 |
+
token_ids=block_table.get_unseen_token_ids(
|
| 227 |
+
seq.get_token_ids()),
|
| 228 |
+
num_lookahead_slots=num_lookahead_slots,
|
| 229 |
+
))
|
| 230 |
+
|
| 231 |
+
num_free_gpu_blocks = self.block_allocator.get_num_free_blocks(
|
| 232 |
+
Device.GPU)
|
| 233 |
+
return num_touched_blocks <= num_free_gpu_blocks
|
| 234 |
+
|
| 235 |
+
def append_slots(
|
| 236 |
+
self,
|
| 237 |
+
seq: Sequence,
|
| 238 |
+
num_lookahead_slots: int,
|
| 239 |
+
) -> List[Tuple[int, int]]:
|
| 240 |
+
|
| 241 |
+
block_table = self.block_tables[seq.seq_id]
|
| 242 |
+
|
| 243 |
+
block_table.append_token_ids(
|
| 244 |
+
token_ids=block_table.get_unseen_token_ids(seq.get_token_ids()),
|
| 245 |
+
num_lookahead_slots=num_lookahead_slots,
|
| 246 |
+
num_computed_slots=seq.data.get_num_computed_tokens(),
|
| 247 |
+
extra_hash=seq.extra_hash(),
|
| 248 |
+
)
|
| 249 |
+
# Return any new copy-on-writes.
|
| 250 |
+
new_cows = self.block_allocator.clear_copy_on_writes()
|
| 251 |
+
return new_cows
|
| 252 |
+
|
| 253 |
+
def free(self, seq: Sequence) -> None:
|
| 254 |
+
seq_id = seq.seq_id
|
| 255 |
+
|
| 256 |
+
if seq_id not in self.block_tables:
|
| 257 |
+
# Already freed or haven't been scheduled yet.
|
| 258 |
+
return
|
| 259 |
+
|
| 260 |
+
# Update seq block ids with the latest access time
|
| 261 |
+
self._last_access_blocks_tracker.update_seq_blocks_last_access(
|
| 262 |
+
seq_id, self.block_tables[seq.seq_id].physical_block_ids)
|
| 263 |
+
|
| 264 |
+
# Untrack seq
|
| 265 |
+
self._last_access_blocks_tracker.remove_seq(seq_id)
|
| 266 |
+
self._computed_blocks_tracker.remove_seq(seq_id)
|
| 267 |
+
|
| 268 |
+
# Free table/blocks
|
| 269 |
+
self.block_tables[seq_id].free()
|
| 270 |
+
del self.block_tables[seq_id]
|
| 271 |
+
|
| 272 |
+
def free_cross(self, seq_group: SequenceGroup) -> None:
|
| 273 |
+
request_id = seq_group.request_id
|
| 274 |
+
if request_id not in self.cross_block_tables:
|
| 275 |
+
# Already freed or hasn't been scheduled yet.
|
| 276 |
+
return
|
| 277 |
+
self.cross_block_tables[request_id].free()
|
| 278 |
+
del self.cross_block_tables[request_id]
|
| 279 |
+
|
| 280 |
+
def get_block_table(self, seq: Sequence) -> List[int]:
|
| 281 |
+
block_ids = self.block_tables[seq.seq_id].physical_block_ids
|
| 282 |
+
return block_ids # type: ignore
|
| 283 |
+
|
| 284 |
+
def get_cross_block_table(self, seq_group: SequenceGroup) -> List[int]:
|
| 285 |
+
request_id = seq_group.request_id
|
| 286 |
+
assert request_id in self.cross_block_tables
|
| 287 |
+
block_ids = self.cross_block_tables[request_id].physical_block_ids
|
| 288 |
+
assert all(b is not None for b in block_ids)
|
| 289 |
+
return block_ids # type: ignore
|
| 290 |
+
|
| 291 |
+
def access_all_blocks_in_seq(self, seq: Sequence, now: float):
|
| 292 |
+
if self.enable_caching:
|
| 293 |
+
# Record the latest access time for the sequence. The actual update
|
| 294 |
+
# of the block ids is deferred to the sequence free(..) call, since
|
| 295 |
+
# only during freeing of block ids, the blocks are actually added to
|
| 296 |
+
# the evictor (which is when the most updated time is required)
|
| 297 |
+
# (This avoids expensive calls to mark_blocks_as_accessed(..))
|
| 298 |
+
self._last_access_blocks_tracker.update_last_access(
|
| 299 |
+
seq.seq_id, now)
|
| 300 |
+
|
| 301 |
+
def mark_blocks_as_computed(self, seq_group: SequenceGroup,
|
| 302 |
+
token_chunk_size: int):
|
| 303 |
+
# If prefix caching is enabled, mark immutable blocks as computed
|
| 304 |
+
# right after they have been scheduled (for prefill). This assumes
|
| 305 |
+
# the scheduler is synchronous so blocks are actually computed when
|
| 306 |
+
# scheduling the next batch.
|
| 307 |
+
self.block_allocator.mark_blocks_as_computed([])
|
| 308 |
+
|
| 309 |
+
def get_common_computed_block_ids(
|
| 310 |
+
self, seqs: List[Sequence]) -> GenericSequence[int]:
|
| 311 |
+
"""Determine which blocks for which we skip prefill.
|
| 312 |
+
|
| 313 |
+
With prefix caching we can skip prefill for previously-generated blocks.
|
| 314 |
+
Currently, the attention implementation only supports skipping cached
|
| 315 |
+
blocks if they are a contiguous prefix of cached blocks.
|
| 316 |
+
|
| 317 |
+
This method determines which blocks can be safely skipped for all
|
| 318 |
+
sequences in the sequence group.
|
| 319 |
+
"""
|
| 320 |
+
computed_seq_block_ids = []
|
| 321 |
+
for seq in seqs:
|
| 322 |
+
all_blocks = self.block_tables[seq.seq_id].physical_block_ids
|
| 323 |
+
num_cached_tokens = (
|
| 324 |
+
self._computed_blocks_tracker.get_num_cached_tokens(seq))
|
| 325 |
+
assert num_cached_tokens % self.block_size == 0
|
| 326 |
+
num_cached_blocks = num_cached_tokens // self.block_size
|
| 327 |
+
computed_block_ids = all_blocks[:num_cached_blocks]
|
| 328 |
+
computed_seq_block_ids.append(computed_block_ids)
|
| 329 |
+
|
| 330 |
+
# NOTE(sang): This assumes seq_block_ids doesn't contain any None.
|
| 331 |
+
return self.block_allocator.get_common_computed_block_ids(
|
| 332 |
+
computed_seq_block_ids) # type: ignore
|
| 333 |
+
|
| 334 |
+
def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:
|
| 335 |
+
if parent_seq.seq_id not in self.block_tables:
|
| 336 |
+
# Parent sequence has either been freed or never existed.
|
| 337 |
+
return
|
| 338 |
+
src_block_table = self.block_tables[parent_seq.seq_id]
|
| 339 |
+
self.block_tables[child_seq.seq_id] = src_block_table.fork()
|
| 340 |
+
|
| 341 |
+
# Track child seq
|
| 342 |
+
self._last_access_blocks_tracker.add_seq(child_seq.seq_id)
|
| 343 |
+
|
| 344 |
+
def can_swap_in(self, seq_group: SequenceGroup,
|
| 345 |
+
num_lookahead_slots: int) -> AllocStatus:
|
| 346 |
+
"""Returns the AllocStatus for the given sequence_group
|
| 347 |
+
with num_lookahead_slots.
|
| 348 |
+
|
| 349 |
+
Args:
|
| 350 |
+
sequence_group (SequenceGroup): The sequence group to swap in.
|
| 351 |
+
num_lookahead_slots (int): Number of lookahead slots used in
|
| 352 |
+
speculative decoding, default to 0.
|
| 353 |
+
|
| 354 |
+
Returns:
|
| 355 |
+
AllocStatus: The AllocStatus for the given sequence group.
|
| 356 |
+
"""
|
| 357 |
+
return self._can_swap(seq_group, Device.GPU, SequenceStatus.SWAPPED,
|
| 358 |
+
num_lookahead_slots)
|
| 359 |
+
|
| 360 |
+
def swap_in(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]:
|
| 361 |
+
"""Returns the block id mapping (from CPU to GPU) generated by
|
| 362 |
+
swapping in the given seq_group with num_lookahead_slots.
|
| 363 |
+
|
| 364 |
+
Args:
|
| 365 |
+
seq_group (SequenceGroup): The sequence group to swap in.
|
| 366 |
+
|
| 367 |
+
Returns:
|
| 368 |
+
List[Tuple[int, int]]: The mapping of swapping block from CPU
|
| 369 |
+
to GPU.
|
| 370 |
+
"""
|
| 371 |
+
physical_block_id_mapping = []
|
| 372 |
+
for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
|
| 373 |
+
blocks = self.block_tables[seq.seq_id].blocks
|
| 374 |
+
if len(blocks) == 0:
|
| 375 |
+
continue
|
| 376 |
+
|
| 377 |
+
seq_swap_mapping = self.block_allocator.swap(blocks=blocks,
|
| 378 |
+
src_device=Device.CPU,
|
| 379 |
+
dst_device=Device.GPU)
|
| 380 |
+
|
| 381 |
+
# Refresh the block ids of the table (post-swap)
|
| 382 |
+
self.block_tables[seq.seq_id].update(blocks)
|
| 383 |
+
|
| 384 |
+
seq_physical_block_id_mapping = {
|
| 385 |
+
self.block_allocator.get_physical_block_id(
|
| 386 |
+
Device.CPU, cpu_block_id):
|
| 387 |
+
self.block_allocator.get_physical_block_id(
|
| 388 |
+
Device.GPU, gpu_block_id)
|
| 389 |
+
for cpu_block_id, gpu_block_id in seq_swap_mapping.items()
|
| 390 |
+
}
|
| 391 |
+
|
| 392 |
+
physical_block_id_mapping.extend(
|
| 393 |
+
list(seq_physical_block_id_mapping.items()))
|
| 394 |
+
|
| 395 |
+
return physical_block_id_mapping
|
| 396 |
+
|
| 397 |
+
def can_swap_out(self, seq_group: SequenceGroup) -> bool:
|
| 398 |
+
"""Returns whether we can swap out the given sequence_group
|
| 399 |
+
with num_lookahead_slots.
|
| 400 |
+
|
| 401 |
+
Args:
|
| 402 |
+
seq_group (SequenceGroup): The sequence group to swap out.
|
| 403 |
+
num_lookahead_slots (int): Number of lookahead slots used in
|
| 404 |
+
speculative decoding, default to 0.
|
| 405 |
+
|
| 406 |
+
Returns:
|
| 407 |
+
bool: Whether it's possible to swap out current sequence group.
|
| 408 |
+
"""
|
| 409 |
+
alloc_status = self._can_swap(seq_group, Device.CPU,
|
| 410 |
+
SequenceStatus.RUNNING)
|
| 411 |
+
return alloc_status == AllocStatus.OK
|
| 412 |
+
|
| 413 |
+
def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]:
|
| 414 |
+
"""Returns the block id mapping (from GPU to CPU) generated by
|
| 415 |
+
swapping out the given sequence_group with num_lookahead_slots.
|
| 416 |
+
|
| 417 |
+
Args:
|
| 418 |
+
sequence_group (SequenceGroup): The sequence group to swap out.
|
| 419 |
+
|
| 420 |
+
Returns:
|
| 421 |
+
List[Tuple[int, int]]: The mapping of swapping block from
|
| 422 |
+
GPU to CPU.
|
| 423 |
+
"""
|
| 424 |
+
physical_block_id_mapping = []
|
| 425 |
+
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
|
| 426 |
+
blocks = self.block_tables[seq.seq_id].blocks
|
| 427 |
+
if len(blocks) == 0:
|
| 428 |
+
continue
|
| 429 |
+
|
| 430 |
+
seq_swap_mapping = self.block_allocator.swap(blocks=blocks,
|
| 431 |
+
src_device=Device.GPU,
|
| 432 |
+
dst_device=Device.CPU)
|
| 433 |
+
|
| 434 |
+
# Refresh the block ids of the table (post-swap)
|
| 435 |
+
self.block_tables[seq.seq_id].update(blocks)
|
| 436 |
+
|
| 437 |
+
seq_physical_block_id_mapping = {
|
| 438 |
+
self.block_allocator.get_physical_block_id(
|
| 439 |
+
Device.GPU, gpu_block_id):
|
| 440 |
+
self.block_allocator.get_physical_block_id(
|
| 441 |
+
Device.CPU, cpu_block_id)
|
| 442 |
+
for gpu_block_id, cpu_block_id in seq_swap_mapping.items()
|
| 443 |
+
}
|
| 444 |
+
|
| 445 |
+
physical_block_id_mapping.extend(
|
| 446 |
+
list(seq_physical_block_id_mapping.items()))
|
| 447 |
+
|
| 448 |
+
return physical_block_id_mapping
|
| 449 |
+
|
| 450 |
+
def get_num_free_gpu_blocks(self) -> int:
|
| 451 |
+
return self.block_allocator.get_num_free_blocks(Device.GPU)
|
| 452 |
+
|
| 453 |
+
def get_num_free_cpu_blocks(self) -> int:
|
| 454 |
+
return self.block_allocator.get_num_free_blocks(Device.CPU)
|
| 455 |
+
|
| 456 |
+
def get_prefix_cache_hit_rate(self, device: Device) -> float:
|
| 457 |
+
return self.block_allocator.get_prefix_cache_hit_rate(device)
|
| 458 |
+
|
| 459 |
+
def reset_prefix_cache(self) -> bool:
|
| 460 |
+
return self.block_allocator.reset_prefix_cache()
|
| 461 |
+
|
| 462 |
+
def _can_swap(self,
|
| 463 |
+
seq_group: SequenceGroup,
|
| 464 |
+
device: Device,
|
| 465 |
+
status: SequenceStatus,
|
| 466 |
+
num_lookahead_slots: int = 0) -> AllocStatus:
|
| 467 |
+
"""Returns the AllocStatus for swapping in/out the given sequence_group
|
| 468 |
+
on to the 'device'.
|
| 469 |
+
|
| 470 |
+
Args:
|
| 471 |
+
sequence_group (SequenceGroup): The sequence group to swap in/out.
|
| 472 |
+
device (Device): device to swap the 'seq_group' on.
|
| 473 |
+
status (SequenceStatus): The status of sequence which is needed
|
| 474 |
+
for action. RUNNING for swap out and SWAPPED for swap in
|
| 475 |
+
num_lookahead_slots (int): Number of lookahead slots used in
|
| 476 |
+
speculative decoding, default to 0.
|
| 477 |
+
|
| 478 |
+
Returns:
|
| 479 |
+
AllocStatus: The AllocStatus for swapping in/out the given
|
| 480 |
+
sequence_group on to the 'device'.
|
| 481 |
+
"""
|
| 482 |
+
# First determine the number of blocks that will be touched by this
|
| 483 |
+
# swap. Then verify if there are available blocks in the device
|
| 484 |
+
# to perform the swap.
|
| 485 |
+
num_blocks_touched = 0
|
| 486 |
+
blocks: List[Block] = []
|
| 487 |
+
for seq in seq_group.get_seqs(status=status):
|
| 488 |
+
block_table = self.block_tables[seq.seq_id]
|
| 489 |
+
if block_table.blocks is not None:
|
| 490 |
+
# Compute the number blocks to touch for the tokens to be
|
| 491 |
+
# appended. This does NOT include the full blocks that need
|
| 492 |
+
# to be touched for the swap.
|
| 493 |
+
num_blocks_touched += \
|
| 494 |
+
block_table.get_num_blocks_touched_by_append_slots(
|
| 495 |
+
block_table.get_unseen_token_ids(seq.get_token_ids()),
|
| 496 |
+
num_lookahead_slots=num_lookahead_slots)
|
| 497 |
+
blocks.extend(block_table.blocks)
|
| 498 |
+
# Compute the number of full blocks to touch and add it to the
|
| 499 |
+
# existing count of blocks to touch.
|
| 500 |
+
num_blocks_touched += self.block_allocator.get_num_full_blocks_touched(
|
| 501 |
+
blocks, device=device)
|
| 502 |
+
|
| 503 |
+
watermark_blocks = 0
|
| 504 |
+
if device == Device.GPU:
|
| 505 |
+
watermark_blocks = self.watermark_blocks
|
| 506 |
+
|
| 507 |
+
if self.block_allocator.get_num_total_blocks(
|
| 508 |
+
device) < num_blocks_touched:
|
| 509 |
+
return AllocStatus.NEVER
|
| 510 |
+
elif self.block_allocator.get_num_free_blocks(
|
| 511 |
+
device) - num_blocks_touched >= watermark_blocks:
|
| 512 |
+
return AllocStatus.OK
|
| 513 |
+
else:
|
| 514 |
+
return AllocStatus.LATER
|
| 515 |
+
|
| 516 |
+
def get_num_cached_tokens(self, seq: Sequence) -> int:
|
| 517 |
+
"""Get the number of tokens in blocks that are already computed and
|
| 518 |
+
cached in the block manager for the sequence.
|
| 519 |
+
"""
|
| 520 |
+
return self._computed_blocks_tracker.get_num_cached_tokens(seq)
|
.venv/lib/python3.11/site-packages/vllm/core/evictor.py
ADDED
|
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
import enum
|
| 4 |
+
import heapq
|
| 5 |
+
from abc import ABC, abstractmethod
|
| 6 |
+
from typing import Dict, List, Tuple
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class EvictionPolicy(enum.Enum):
|
| 10 |
+
"""Enum for eviction policy used by make_evictor to instantiate the correct
|
| 11 |
+
Evictor subclass.
|
| 12 |
+
"""
|
| 13 |
+
LRU = enum.auto()
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class Evictor(ABC):
|
| 17 |
+
"""The Evictor subclasses should be used by the BlockAllocator class to
|
| 18 |
+
handle eviction of freed Blocks.
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
@abstractmethod
|
| 22 |
+
def __init__(self):
|
| 23 |
+
pass
|
| 24 |
+
|
| 25 |
+
@abstractmethod
|
| 26 |
+
def __contains__(self, block_id: int) -> bool:
|
| 27 |
+
pass
|
| 28 |
+
|
| 29 |
+
@abstractmethod
|
| 30 |
+
def evict(self) -> Tuple[int, int]:
|
| 31 |
+
"""Runs the eviction algorithm and returns the evicted block's
|
| 32 |
+
content hash along with physical block id along with physical block id
|
| 33 |
+
"""
|
| 34 |
+
pass
|
| 35 |
+
|
| 36 |
+
@abstractmethod
|
| 37 |
+
def add(self, block_id: int, content_hash: int, num_hashed_tokens: int,
|
| 38 |
+
last_accessed: float):
|
| 39 |
+
"""Adds block to the evictor, making it a candidate for eviction"""
|
| 40 |
+
pass
|
| 41 |
+
|
| 42 |
+
@abstractmethod
|
| 43 |
+
def update(self, block_id: int, last_accessed: float):
|
| 44 |
+
"""Update corresponding block's access time in metadata"""
|
| 45 |
+
pass
|
| 46 |
+
|
| 47 |
+
@abstractmethod
|
| 48 |
+
def remove(self, block_id: int):
|
| 49 |
+
"""Remove a given block id from the cache."""
|
| 50 |
+
pass
|
| 51 |
+
|
| 52 |
+
@property
|
| 53 |
+
@abstractmethod
|
| 54 |
+
def num_blocks(self) -> int:
|
| 55 |
+
pass
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class BlockMetaData:
|
| 59 |
+
"""Data structure for storing key data describe cached block, so that
|
| 60 |
+
evitor could use to make its decision which one to choose for eviction
|
| 61 |
+
|
| 62 |
+
Here we use physical block id as the dict key, as there maybe several
|
| 63 |
+
blocks with the same content hash, but their physical id is unique.
|
| 64 |
+
"""
|
| 65 |
+
|
| 66 |
+
def __init__(self, content_hash: int, num_hashed_tokens: int,
|
| 67 |
+
last_accessed: float):
|
| 68 |
+
self.content_hash = content_hash
|
| 69 |
+
self.num_hashed_tokens = num_hashed_tokens
|
| 70 |
+
self.last_accessed = last_accessed
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class LRUEvictor(Evictor):
|
| 74 |
+
"""Evicts in a least-recently-used order using the last_accessed timestamp
|
| 75 |
+
that's recorded in the Block. If there are multiple blocks with
|
| 76 |
+
the same last_accessed time, then the one with the largest num_hashed_tokens
|
| 77 |
+
will be evicted. If two blocks each have the lowest last_accessed time and
|
| 78 |
+
highest num_hashed_tokens value, then one will be chose arbitrarily
|
| 79 |
+
"""
|
| 80 |
+
|
| 81 |
+
# CLEANUP_THRESHOLD determines the maximum allowable size of the priority
|
| 82 |
+
# queue relative to the free table size. When this threshold is exceeded,
|
| 83 |
+
# a cleanup operation is triggered to reduce memory usage.
|
| 84 |
+
CLEANUP_THRESHOLD = 50
|
| 85 |
+
|
| 86 |
+
def __init__(self):
|
| 87 |
+
self.free_table: Dict[int, BlockMetaData] = {}
|
| 88 |
+
self.priority_queue = []
|
| 89 |
+
|
| 90 |
+
def __contains__(self, block_id: int) -> bool:
|
| 91 |
+
return block_id in self.free_table
|
| 92 |
+
|
| 93 |
+
def evict(self) -> Tuple[int, int]:
|
| 94 |
+
if len(self.free_table) == 0:
|
| 95 |
+
raise ValueError("No usable cache memory left")
|
| 96 |
+
|
| 97 |
+
while self.priority_queue:
|
| 98 |
+
# We do not remove outdated entries from the priority queue at the
|
| 99 |
+
# time of updating the last_accessed timestamp. Instead, outdated
|
| 100 |
+
# entries are filtered out here during eviction. Outdated entries
|
| 101 |
+
# would either not in the free table, or have older last accessed
|
| 102 |
+
# time.
|
| 103 |
+
last_accessed, _, block_id, content_hash = heapq.heappop(
|
| 104 |
+
self.priority_queue)
|
| 105 |
+
if (block_id in self.free_table and
|
| 106 |
+
self.free_table[block_id].last_accessed == last_accessed):
|
| 107 |
+
self.free_table.pop(block_id)
|
| 108 |
+
return block_id, content_hash
|
| 109 |
+
|
| 110 |
+
raise ValueError("No usable cache memory left")
|
| 111 |
+
|
| 112 |
+
def add(self, block_id: int, content_hash: int, num_hashed_tokens: int,
|
| 113 |
+
last_accessed: float):
|
| 114 |
+
self.free_table[block_id] = BlockMetaData(content_hash,
|
| 115 |
+
num_hashed_tokens,
|
| 116 |
+
last_accessed)
|
| 117 |
+
heapq.heappush(
|
| 118 |
+
self.priority_queue,
|
| 119 |
+
(last_accessed, -num_hashed_tokens, block_id, content_hash))
|
| 120 |
+
self._cleanup_if_necessary()
|
| 121 |
+
|
| 122 |
+
def update(self, block_id: int, last_accessed: float):
|
| 123 |
+
self.free_table[block_id].last_accessed = last_accessed
|
| 124 |
+
|
| 125 |
+
def _cleanup_if_necessary(self):
|
| 126 |
+
if len(self.priority_queue) > LRUEvictor.CLEANUP_THRESHOLD * len(
|
| 127 |
+
self.free_table):
|
| 128 |
+
self._cleanup()
|
| 129 |
+
|
| 130 |
+
def _cleanup(self):
|
| 131 |
+
new_priority_queue: List[Tuple[float, int, int, int]] = []
|
| 132 |
+
|
| 133 |
+
for block_id, block in self.free_table.items():
|
| 134 |
+
new_priority_queue.append(
|
| 135 |
+
(block.last_accessed, -block.num_hashed_tokens, block_id,
|
| 136 |
+
block.content_hash))
|
| 137 |
+
heapq.heapify(new_priority_queue)
|
| 138 |
+
|
| 139 |
+
self.priority_queue = new_priority_queue
|
| 140 |
+
|
| 141 |
+
def remove(self, block_id: int):
|
| 142 |
+
if block_id not in self.free_table:
|
| 143 |
+
raise ValueError(
|
| 144 |
+
"Attempting to remove block that's not in the evictor")
|
| 145 |
+
self.free_table.pop(block_id)
|
| 146 |
+
|
| 147 |
+
@property
|
| 148 |
+
def num_blocks(self) -> int:
|
| 149 |
+
return len(self.free_table)
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def make_evictor(eviction_policy: EvictionPolicy) -> Evictor:
|
| 153 |
+
if eviction_policy == EvictionPolicy.LRU:
|
| 154 |
+
return LRUEvictor()
|
| 155 |
+
else:
|
| 156 |
+
raise ValueError(f"Unknown cache eviction policy: {eviction_policy}")
|
.venv/lib/python3.11/site-packages/vllm/device_allocator/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/vllm/device_allocator/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (194 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/device_allocator/cumem.py
ADDED
|
@@ -0,0 +1,256 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
# cumem-based pytorch pluggable allocator to implement sleep mode.
|
| 4 |
+
# other approaches tried but failed:
|
| 5 |
+
# - cuda-python package binding
|
| 6 |
+
# - custom libcuda driver ctypes wrapper
|
| 7 |
+
# both of them failed because of cuda context mismatch.
|
| 8 |
+
# not sure why, they are created from a different context.
|
| 9 |
+
# the only successful approach is to call cuda driver API in C.
|
| 10 |
+
import dataclasses
|
| 11 |
+
from contextlib import contextmanager
|
| 12 |
+
from typing import Callable, Dict, Optional, Tuple, Union
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
|
| 16 |
+
from vllm.utils import is_pin_memory_available
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def find_loaded_library(lib_name) -> Optional[str]:
|
| 20 |
+
"""
|
| 21 |
+
According to according to https://man7.org/linux/man-pages/man5/proc_pid_maps.5.html,
|
| 22 |
+
the file `/proc/self/maps` contains the memory maps of the process, which includes the
|
| 23 |
+
shared libraries loaded by the process. We can use this file to find the path of the
|
| 24 |
+
a loaded library.
|
| 25 |
+
""" # noqa
|
| 26 |
+
found_line = None
|
| 27 |
+
with open("/proc/self/maps") as f:
|
| 28 |
+
for line in f:
|
| 29 |
+
if lib_name in line:
|
| 30 |
+
found_line = line
|
| 31 |
+
break
|
| 32 |
+
if found_line is None:
|
| 33 |
+
# the library is not loaded in the current process
|
| 34 |
+
return None
|
| 35 |
+
# if lib_name is libcudart, we need to match a line with:
|
| 36 |
+
# address /path/to/libcudart-hash.so.11.0
|
| 37 |
+
start = found_line.index("/")
|
| 38 |
+
path = found_line[start:].strip()
|
| 39 |
+
filename = path.split("/")[-1]
|
| 40 |
+
assert filename.rpartition(".so")[0].startswith(lib_name), \
|
| 41 |
+
f"Unexpected filename: {filename} for library {lib_name}"
|
| 42 |
+
return path
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
cumem_available = False
|
| 46 |
+
try:
|
| 47 |
+
from vllm.cumem_allocator import (init_module, python_create_and_map,
|
| 48 |
+
python_unmap_and_release)
|
| 49 |
+
from vllm.distributed.device_communicators.cuda_wrapper import (
|
| 50 |
+
CudaRTLibrary)
|
| 51 |
+
lib_name = find_loaded_library("cumem_allocator")
|
| 52 |
+
libcudart = CudaRTLibrary()
|
| 53 |
+
cumem_available = True
|
| 54 |
+
except ModuleNotFoundError:
|
| 55 |
+
# rocm platform does not support cumem allocator
|
| 56 |
+
init_module = None
|
| 57 |
+
python_create_and_map = None
|
| 58 |
+
python_unmap_and_release = None
|
| 59 |
+
CudaRTLibrary = None
|
| 60 |
+
lib_name = None
|
| 61 |
+
libcudart = None
|
| 62 |
+
|
| 63 |
+
# py_device, py_alignedSize, py_d_mem, py_p_memHandle
|
| 64 |
+
HandleType = Tuple[int, int, int, int]
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
@dataclasses.dataclass
|
| 68 |
+
class AllocationData:
|
| 69 |
+
handle: HandleType
|
| 70 |
+
tag: str
|
| 71 |
+
cpu_backup_tensor: Optional[torch.Tensor] = None
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def create_and_map(allocation_handle: HandleType) -> None:
|
| 75 |
+
python_create_and_map(*allocation_handle)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def unmap_and_release(allocation_handle: HandleType) -> None:
|
| 79 |
+
python_unmap_and_release(*allocation_handle)
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def get_pluggable_allocator(
|
| 83 |
+
python_malloc_fn: Callable[[int],
|
| 84 |
+
int], python_free_func: Callable[[int, int],
|
| 85 |
+
None]
|
| 86 |
+
) -> torch.cuda.memory.CUDAPluggableAllocator:
|
| 87 |
+
init_module(python_malloc_fn, python_free_func)
|
| 88 |
+
new_alloc = torch.cuda.memory.CUDAPluggableAllocator(
|
| 89 |
+
lib_name, 'my_malloc', 'my_free')
|
| 90 |
+
return new_alloc
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
@contextmanager
|
| 94 |
+
def use_memory_pool_with_allocator(
|
| 95 |
+
python_malloc_fn: Callable[[int], int],
|
| 96 |
+
python_free_func: Callable[[int, int], None]) -> None:
|
| 97 |
+
new_alloc = get_pluggable_allocator(python_malloc_fn, python_free_func)
|
| 98 |
+
mem_pool = torch.cuda.memory.MemPool(new_alloc._allocator)
|
| 99 |
+
with torch.cuda.memory.use_mem_pool(mem_pool):
|
| 100 |
+
yield mem_pool
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
class CuMemAllocator:
|
| 104 |
+
"""
|
| 105 |
+
A singleton class that manages a memory pool for CUDA tensors.
|
| 106 |
+
The memory in this pool can be offloaded or discarded when the
|
| 107 |
+
allocator sleeps.
|
| 108 |
+
|
| 109 |
+
Inside the `use_memory_pool(tag)` context, all tensors created will
|
| 110 |
+
be allocated in the memory pool, and has the same tag as the
|
| 111 |
+
tag passed to the context.
|
| 112 |
+
|
| 113 |
+
When we call `sleep`, all tensors with the specified tag will be
|
| 114 |
+
offloaded to CPU memory, and the rest of the tensors will be discarded.
|
| 115 |
+
When we call `wake_up`, all tensors that are previously offloaded
|
| 116 |
+
will be loaded back to GPU memory, and the rest of the tensors will
|
| 117 |
+
have empty memory.
|
| 118 |
+
|
| 119 |
+
Why it needs to be a singleton?
|
| 120 |
+
When allocated tensors are garbage collected, PyTorch will call
|
| 121 |
+
the free callback, which will call the `python_free_callback` method.
|
| 122 |
+
The C-extension uses a global variable to store the function of an
|
| 123 |
+
instance of this class. If we create multiple instances of this class,
|
| 124 |
+
the global variable will be overwritten and the free callback will
|
| 125 |
+
not work as expected.
|
| 126 |
+
"""
|
| 127 |
+
instance: "CuMemAllocator" = None
|
| 128 |
+
default_tag: str = "default"
|
| 129 |
+
|
| 130 |
+
@staticmethod
|
| 131 |
+
def get_instance() -> "CuMemAllocator":
|
| 132 |
+
"""
|
| 133 |
+
CuMemAllocator is a singleton class.
|
| 134 |
+
We cannot call the constructor directly.
|
| 135 |
+
Call this method to get the instance.
|
| 136 |
+
"""
|
| 137 |
+
assert cumem_available, "cumem allocator is not available"
|
| 138 |
+
if CuMemAllocator.instance is None:
|
| 139 |
+
CuMemAllocator.instance = CuMemAllocator()
|
| 140 |
+
return CuMemAllocator.instance
|
| 141 |
+
|
| 142 |
+
def __init__(self):
|
| 143 |
+
self.pointer_to_data: Dict[int, AllocationData] = {}
|
| 144 |
+
self.current_tag: str = CuMemAllocator.default_tag
|
| 145 |
+
|
| 146 |
+
def python_malloc_callback(self, allocation_handle: HandleType) -> None:
|
| 147 |
+
"""
|
| 148 |
+
Internal method to store the allocation data
|
| 149 |
+
when memory is allocated in the memory pool."""
|
| 150 |
+
py_d_mem = allocation_handle[2]
|
| 151 |
+
self.pointer_to_data[py_d_mem] = AllocationData(
|
| 152 |
+
allocation_handle, self.current_tag)
|
| 153 |
+
return
|
| 154 |
+
|
| 155 |
+
def python_free_callback(self, ptr: int) -> HandleType:
|
| 156 |
+
"""
|
| 157 |
+
Internal method to look up the allocation data
|
| 158 |
+
when memory is freed in the memory pool."""
|
| 159 |
+
data = self.pointer_to_data.pop(ptr)
|
| 160 |
+
if data.cpu_backup_tensor is not None:
|
| 161 |
+
data.cpu_backup_tensor = None
|
| 162 |
+
return data.handle
|
| 163 |
+
|
| 164 |
+
def sleep(
|
| 165 |
+
self,
|
| 166 |
+
offload_tags: Optional[Union[Tuple[str, ...],
|
| 167 |
+
str]] = None) -> None:
|
| 168 |
+
"""
|
| 169 |
+
Put the allocator in sleep mode.
|
| 170 |
+
All data in the memory allocation with the specified tag will be
|
| 171 |
+
offloaded to CPU memory, and others will be discarded.
|
| 172 |
+
|
| 173 |
+
:param offload_tags: The tags of the memory allocation that will be
|
| 174 |
+
offloaded. The rest of the memory allocation will be discarded.
|
| 175 |
+
"""
|
| 176 |
+
if offload_tags is None:
|
| 177 |
+
# by default, allocated tensors are offloaded
|
| 178 |
+
# when the allocator sleeps
|
| 179 |
+
offload_tags = (CuMemAllocator.default_tag, )
|
| 180 |
+
elif isinstance(offload_tags, str):
|
| 181 |
+
offload_tags = (offload_tags, )
|
| 182 |
+
|
| 183 |
+
assert isinstance(offload_tags, tuple)
|
| 184 |
+
|
| 185 |
+
for ptr, data in self.pointer_to_data.items():
|
| 186 |
+
handle = data.handle
|
| 187 |
+
if data.tag in offload_tags:
|
| 188 |
+
size_in_bytes = handle[1]
|
| 189 |
+
cpu_backup_tensor = torch.empty(
|
| 190 |
+
size_in_bytes,
|
| 191 |
+
dtype=torch.uint8,
|
| 192 |
+
device='cpu',
|
| 193 |
+
pin_memory=is_pin_memory_available())
|
| 194 |
+
cpu_ptr = cpu_backup_tensor.data_ptr()
|
| 195 |
+
libcudart.cudaMemcpy(cpu_ptr, ptr, size_in_bytes)
|
| 196 |
+
data.cpu_backup_tensor = cpu_backup_tensor
|
| 197 |
+
unmap_and_release(handle)
|
| 198 |
+
|
| 199 |
+
def wake_up(self):
|
| 200 |
+
"""
|
| 201 |
+
Wake up the allocator from sleep mode.
|
| 202 |
+
All data that is previously offloaded will be loaded back to GPU
|
| 203 |
+
memory, and the rest of the data will have empty memory."""
|
| 204 |
+
for ptr, data in self.pointer_to_data.items():
|
| 205 |
+
handle = data.handle
|
| 206 |
+
create_and_map(handle)
|
| 207 |
+
if data.cpu_backup_tensor is not None:
|
| 208 |
+
cpu_backup_tensor = data.cpu_backup_tensor
|
| 209 |
+
if cpu_backup_tensor is not None:
|
| 210 |
+
size_in_bytes = cpu_backup_tensor.numel(
|
| 211 |
+
) * cpu_backup_tensor.element_size()
|
| 212 |
+
cpu_ptr = cpu_backup_tensor.data_ptr()
|
| 213 |
+
libcudart.cudaMemcpy(ptr, cpu_ptr, size_in_bytes)
|
| 214 |
+
data.cpu_backup_tensor = None
|
| 215 |
+
|
| 216 |
+
@contextmanager
|
| 217 |
+
def use_memory_pool(self, tag: Optional[str] = None):
|
| 218 |
+
"""
|
| 219 |
+
A context manager to use the memory pool.
|
| 220 |
+
All memory allocation created inside the context will be allocated
|
| 221 |
+
in the memory pool, and has the specified tag.
|
| 222 |
+
|
| 223 |
+
:param tag: The tag of the memory allocation. If None, the default tag
|
| 224 |
+
will be used.
|
| 225 |
+
"""
|
| 226 |
+
if tag is None:
|
| 227 |
+
tag = CuMemAllocator.default_tag
|
| 228 |
+
|
| 229 |
+
assert isinstance(tag, str)
|
| 230 |
+
|
| 231 |
+
old_tag = self.current_tag
|
| 232 |
+
self.current_tag = tag
|
| 233 |
+
with use_memory_pool_with_allocator(self.python_malloc_callback,
|
| 234 |
+
self.python_free_callback):
|
| 235 |
+
yield
|
| 236 |
+
# PyTorch's bug, calling torch.cuda.empty_cache() will error
|
| 237 |
+
# when using pluggable allocator, see
|
| 238 |
+
# https://github.com/pytorch/pytorch/issues/145168 .
|
| 239 |
+
# if we have some memory allocated and then freed,
|
| 240 |
+
# the memory will not be released.
|
| 241 |
+
# right now it is fine, because we only use this allocator
|
| 242 |
+
# during weight loading and kv cache creation, where we only
|
| 243 |
+
# allocate memory.
|
| 244 |
+
# TODO: we need to find a way to release the memory,
|
| 245 |
+
# i.e. calling torch.cuda.empty_cache()
|
| 246 |
+
self.current_tag = old_tag
|
| 247 |
+
|
| 248 |
+
def get_current_usage(self) -> int:
|
| 249 |
+
"""
|
| 250 |
+
Get the total number of bytes allocated in the memory pool.
|
| 251 |
+
"""
|
| 252 |
+
sum_bytes: int = 0
|
| 253 |
+
for ptr, data in self.pointer_to_data.items():
|
| 254 |
+
handle = data.handle
|
| 255 |
+
sum_bytes += handle[1]
|
| 256 |
+
return sum_bytes
|
.venv/lib/python3.11/site-packages/vllm/distributed/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
from .communication_op import *
|
| 4 |
+
from .parallel_state import *
|
| 5 |
+
from .utils import *
|
.venv/lib/python3.11/site-packages/vllm/distributed/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (295 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/distributed/__pycache__/communication_op.cpython-311.pyc
ADDED
|
Binary file (2.22 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/distributed/__pycache__/parallel_state.cpython-311.pyc
ADDED
|
Binary file (55.5 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/distributed/__pycache__/utils.cpython-311.pyc
ADDED
|
Binary file (13.1 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/distributed/communication_op.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
from typing import Any, Dict, Optional, Union
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.distributed
|
| 7 |
+
|
| 8 |
+
from .parallel_state import get_tp_group
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
|
| 12 |
+
"""All-reduce the input tensor across model parallel group."""
|
| 13 |
+
return get_tp_group().all_reduce(input_)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def tensor_model_parallel_all_gather(input_: torch.Tensor,
|
| 17 |
+
dim: int = -1) -> torch.Tensor:
|
| 18 |
+
"""All-gather the input tensor across model parallel group."""
|
| 19 |
+
return get_tp_group().all_gather(input_, dim)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def tensor_model_parallel_gather(input_: torch.Tensor,
|
| 23 |
+
dst: int = 0,
|
| 24 |
+
dim: int = -1) -> Optional[torch.Tensor]:
|
| 25 |
+
"""Gather the input tensor across model parallel group."""
|
| 26 |
+
return get_tp_group().gather(input_, dst, dim)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def broadcast_tensor_dict(tensor_dict: Optional[Dict[Any, Union[torch.Tensor,
|
| 30 |
+
Any]]] = None,
|
| 31 |
+
src: int = 0):
|
| 32 |
+
if not torch.distributed.is_initialized():
|
| 33 |
+
return tensor_dict
|
| 34 |
+
return get_tp_group().broadcast_tensor_dict(tensor_dict, src)
|
.venv/lib/python3.11/site-packages/vllm/distributed/device_communicators/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/vllm/distributed/device_communicators/__pycache__/custom_all_reduce.cpython-311.pyc
ADDED
|
Binary file (16 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/distributed/device_communicators/__pycache__/pynccl.cpython-311.pyc
ADDED
|
Binary file (11.6 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/distributed/device_communicators/cuda_wrapper.py
ADDED
|
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
"""This file is a pure Python wrapper for the cudart library.
|
| 3 |
+
It avoids the need to compile a separate shared library, and is
|
| 4 |
+
convenient for use when we just need to call a few functions.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import ctypes
|
| 8 |
+
from dataclasses import dataclass
|
| 9 |
+
from typing import Any, Dict, List, Optional
|
| 10 |
+
|
| 11 |
+
# this line makes it possible to directly load `libcudart.so` using `ctypes`
|
| 12 |
+
import torch # noqa
|
| 13 |
+
|
| 14 |
+
from vllm.logger import init_logger
|
| 15 |
+
|
| 16 |
+
logger = init_logger(__name__)
|
| 17 |
+
|
| 18 |
+
# === export types and functions from cudart to Python ===
|
| 19 |
+
# for the original cudart definition, please check
|
| 20 |
+
# https://docs.nvidia.com/cuda/cuda-runtime-api/index.html
|
| 21 |
+
|
| 22 |
+
cudaError_t = ctypes.c_int
|
| 23 |
+
cudaMemcpyKind = ctypes.c_int
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class cudaIpcMemHandle_t(ctypes.Structure):
|
| 27 |
+
_fields_ = [("internal", ctypes.c_byte * 128)]
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@dataclass
|
| 31 |
+
class Function:
|
| 32 |
+
name: str
|
| 33 |
+
restype: Any
|
| 34 |
+
argtypes: List[Any]
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def find_loaded_library(lib_name) -> Optional[str]:
|
| 38 |
+
"""
|
| 39 |
+
According to according to https://man7.org/linux/man-pages/man5/proc_pid_maps.5.html,
|
| 40 |
+
the file `/proc/self/maps` contains the memory maps of the process, which includes the
|
| 41 |
+
shared libraries loaded by the process. We can use this file to find the path of the
|
| 42 |
+
a loaded library.
|
| 43 |
+
""" # noqa
|
| 44 |
+
found = False
|
| 45 |
+
with open("/proc/self/maps") as f:
|
| 46 |
+
for line in f:
|
| 47 |
+
if lib_name in line:
|
| 48 |
+
found = True
|
| 49 |
+
break
|
| 50 |
+
if not found:
|
| 51 |
+
# the library is not loaded in the current process
|
| 52 |
+
return None
|
| 53 |
+
# if lib_name is libcudart, we need to match a line with:
|
| 54 |
+
# address /path/to/libcudart-hash.so.11.0
|
| 55 |
+
start = line.index("/")
|
| 56 |
+
path = line[start:].strip()
|
| 57 |
+
filename = path.split("/")[-1]
|
| 58 |
+
assert filename.rpartition(".so")[0].startswith(lib_name), \
|
| 59 |
+
f"Unexpected filename: {filename} for library {lib_name}"
|
| 60 |
+
return path
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class CudaRTLibrary:
|
| 64 |
+
exported_functions = [
|
| 65 |
+
# cudaError_t cudaSetDevice ( int device )
|
| 66 |
+
Function("cudaSetDevice", cudaError_t, [ctypes.c_int]),
|
| 67 |
+
# cudaError_t cudaDeviceSynchronize ( void )
|
| 68 |
+
Function("cudaDeviceSynchronize", cudaError_t, []),
|
| 69 |
+
# cudaError_t cudaDeviceReset ( void )
|
| 70 |
+
Function("cudaDeviceReset", cudaError_t, []),
|
| 71 |
+
|
| 72 |
+
# const char* cudaGetErrorString ( cudaError_t error )
|
| 73 |
+
Function("cudaGetErrorString", ctypes.c_char_p, [cudaError_t]),
|
| 74 |
+
|
| 75 |
+
# cudaError_t cudaMalloc ( void** devPtr, size_t size )
|
| 76 |
+
Function("cudaMalloc", cudaError_t,
|
| 77 |
+
[ctypes.POINTER(ctypes.c_void_p), ctypes.c_size_t]),
|
| 78 |
+
# cudaError_t cudaFree ( void* devPtr )
|
| 79 |
+
Function("cudaFree", cudaError_t, [ctypes.c_void_p]),
|
| 80 |
+
# cudaError_t cudaMemset ( void* devPtr, int value, size_t count )
|
| 81 |
+
Function("cudaMemset", cudaError_t,
|
| 82 |
+
[ctypes.c_void_p, ctypes.c_int, ctypes.c_size_t]),
|
| 83 |
+
# cudaError_t cudaMemcpy ( void* dst, const void* src, size_t count, cudaMemcpyKind kind ) # noqa
|
| 84 |
+
Function("cudaMemcpy", cudaError_t, [
|
| 85 |
+
ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, cudaMemcpyKind
|
| 86 |
+
]),
|
| 87 |
+
|
| 88 |
+
# cudaError_t cudaIpcGetMemHandle ( cudaIpcMemHandle_t* handle, void* devPtr ) # noqa
|
| 89 |
+
Function("cudaIpcGetMemHandle", cudaError_t,
|
| 90 |
+
[ctypes.POINTER(cudaIpcMemHandle_t), ctypes.c_void_p]),
|
| 91 |
+
# cudaError_t cudaIpcOpenMemHandle ( void** devPtr, cudaIpcMemHandle_t handle, unsigned int flags ) # noqa
|
| 92 |
+
Function("cudaIpcOpenMemHandle", cudaError_t, [
|
| 93 |
+
ctypes.POINTER(ctypes.c_void_p), cudaIpcMemHandle_t, ctypes.c_uint
|
| 94 |
+
]),
|
| 95 |
+
]
|
| 96 |
+
|
| 97 |
+
# class attribute to store the mapping from the path to the library
|
| 98 |
+
# to avoid loading the same library multiple times
|
| 99 |
+
path_to_library_cache: Dict[str, Any] = {}
|
| 100 |
+
|
| 101 |
+
# class attribute to store the mapping from library path
|
| 102 |
+
# to the corresponding dictionary
|
| 103 |
+
path_to_dict_mapping: Dict[str, Dict[str, Any]] = {}
|
| 104 |
+
|
| 105 |
+
def __init__(self, so_file: Optional[str] = None):
|
| 106 |
+
if so_file is None:
|
| 107 |
+
so_file = find_loaded_library("libcudart")
|
| 108 |
+
assert so_file is not None, \
|
| 109 |
+
"libcudart is not loaded in the current process"
|
| 110 |
+
if so_file not in CudaRTLibrary.path_to_library_cache:
|
| 111 |
+
lib = ctypes.CDLL(so_file)
|
| 112 |
+
CudaRTLibrary.path_to_library_cache[so_file] = lib
|
| 113 |
+
self.lib = CudaRTLibrary.path_to_library_cache[so_file]
|
| 114 |
+
|
| 115 |
+
if so_file not in CudaRTLibrary.path_to_dict_mapping:
|
| 116 |
+
_funcs = {}
|
| 117 |
+
for func in CudaRTLibrary.exported_functions:
|
| 118 |
+
f = getattr(self.lib, func.name)
|
| 119 |
+
f.restype = func.restype
|
| 120 |
+
f.argtypes = func.argtypes
|
| 121 |
+
_funcs[func.name] = f
|
| 122 |
+
CudaRTLibrary.path_to_dict_mapping[so_file] = _funcs
|
| 123 |
+
self.funcs = CudaRTLibrary.path_to_dict_mapping[so_file]
|
| 124 |
+
|
| 125 |
+
def CUDART_CHECK(self, result: cudaError_t) -> None:
|
| 126 |
+
if result != 0:
|
| 127 |
+
error_str = self.cudaGetErrorString(result)
|
| 128 |
+
raise RuntimeError(f"CUDART error: {error_str}")
|
| 129 |
+
|
| 130 |
+
def cudaGetErrorString(self, error: cudaError_t) -> str:
|
| 131 |
+
return self.funcs["cudaGetErrorString"](error).decode("utf-8")
|
| 132 |
+
|
| 133 |
+
def cudaSetDevice(self, device: int) -> None:
|
| 134 |
+
self.CUDART_CHECK(self.funcs["cudaSetDevice"](device))
|
| 135 |
+
|
| 136 |
+
def cudaDeviceSynchronize(self) -> None:
|
| 137 |
+
self.CUDART_CHECK(self.funcs["cudaDeviceSynchronize"]())
|
| 138 |
+
|
| 139 |
+
def cudaDeviceReset(self) -> None:
|
| 140 |
+
self.CUDART_CHECK(self.funcs["cudaDeviceReset"]())
|
| 141 |
+
|
| 142 |
+
def cudaMalloc(self, size: int) -> ctypes.c_void_p:
|
| 143 |
+
devPtr = ctypes.c_void_p()
|
| 144 |
+
self.CUDART_CHECK(self.funcs["cudaMalloc"](ctypes.byref(devPtr), size))
|
| 145 |
+
return devPtr
|
| 146 |
+
|
| 147 |
+
def cudaFree(self, devPtr: ctypes.c_void_p) -> None:
|
| 148 |
+
self.CUDART_CHECK(self.funcs["cudaFree"](devPtr))
|
| 149 |
+
|
| 150 |
+
def cudaMemset(self, devPtr: ctypes.c_void_p, value: int,
|
| 151 |
+
count: int) -> None:
|
| 152 |
+
self.CUDART_CHECK(self.funcs["cudaMemset"](devPtr, value, count))
|
| 153 |
+
|
| 154 |
+
def cudaMemcpy(self, dst: ctypes.c_void_p, src: ctypes.c_void_p,
|
| 155 |
+
count: int) -> None:
|
| 156 |
+
cudaMemcpyDefault = 4
|
| 157 |
+
kind = cudaMemcpyDefault
|
| 158 |
+
self.CUDART_CHECK(self.funcs["cudaMemcpy"](dst, src, count, kind))
|
| 159 |
+
|
| 160 |
+
def cudaIpcGetMemHandle(self,
|
| 161 |
+
devPtr: ctypes.c_void_p) -> cudaIpcMemHandle_t:
|
| 162 |
+
handle = cudaIpcMemHandle_t()
|
| 163 |
+
self.CUDART_CHECK(self.funcs["cudaIpcGetMemHandle"](
|
| 164 |
+
ctypes.byref(handle), devPtr))
|
| 165 |
+
return handle
|
| 166 |
+
|
| 167 |
+
def cudaIpcOpenMemHandle(self,
|
| 168 |
+
handle: cudaIpcMemHandle_t) -> ctypes.c_void_p:
|
| 169 |
+
cudaIpcMemLazyEnablePeerAccess = 1
|
| 170 |
+
devPtr = ctypes.c_void_p()
|
| 171 |
+
self.CUDART_CHECK(self.funcs["cudaIpcOpenMemHandle"](
|
| 172 |
+
ctypes.byref(devPtr), handle, cudaIpcMemLazyEnablePeerAccess))
|
| 173 |
+
return devPtr
|
.venv/lib/python3.11/site-packages/vllm/distributed/device_communicators/custom_all_reduce.py
ADDED
|
@@ -0,0 +1,305 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
import ctypes
|
| 4 |
+
from contextlib import contextmanager
|
| 5 |
+
from typing import List, Optional, Union
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.distributed as dist
|
| 9 |
+
from torch.distributed import ProcessGroup
|
| 10 |
+
|
| 11 |
+
import vllm.envs as envs
|
| 12 |
+
from vllm import _custom_ops as ops
|
| 13 |
+
from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
|
| 14 |
+
from vllm.distributed.device_communicators.custom_all_reduce_utils import (
|
| 15 |
+
gpu_p2p_access_check)
|
| 16 |
+
from vllm.distributed.parallel_state import in_the_same_node_as
|
| 17 |
+
from vllm.logger import init_logger
|
| 18 |
+
from vllm.platforms import current_platform
|
| 19 |
+
from vllm.utils import cuda_device_count_stateless
|
| 20 |
+
|
| 21 |
+
try:
|
| 22 |
+
ops.meta_size()
|
| 23 |
+
custom_ar = True
|
| 24 |
+
except Exception:
|
| 25 |
+
# For AMD GPUs and CPUs
|
| 26 |
+
custom_ar = False
|
| 27 |
+
|
| 28 |
+
logger = init_logger(__name__)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def _can_p2p(rank: int, world_size: int) -> bool:
|
| 32 |
+
for i in range(world_size):
|
| 33 |
+
if i == rank:
|
| 34 |
+
continue
|
| 35 |
+
if envs.VLLM_SKIP_P2P_CHECK:
|
| 36 |
+
logger.info(
|
| 37 |
+
"Skipping P2P check and trusting the driver's P2P report.")
|
| 38 |
+
return torch.cuda.can_device_access_peer(rank, i)
|
| 39 |
+
if not gpu_p2p_access_check(rank, i):
|
| 40 |
+
return False
|
| 41 |
+
return True
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def is_weak_contiguous(inp: torch.Tensor):
|
| 45 |
+
return inp.is_contiguous() or (inp.storage().nbytes() -
|
| 46 |
+
inp.storage_offset() * inp.element_size()
|
| 47 |
+
== inp.numel() * inp.element_size())
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class CustomAllreduce:
|
| 51 |
+
|
| 52 |
+
_SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
|
| 53 |
+
|
| 54 |
+
# max_size: max supported allreduce size
|
| 55 |
+
def __init__(self,
|
| 56 |
+
group: ProcessGroup,
|
| 57 |
+
device: Union[int, str, torch.device],
|
| 58 |
+
max_size=8192 * 1024) -> None:
|
| 59 |
+
"""
|
| 60 |
+
Args:
|
| 61 |
+
group: the process group to work on. If None, it will use the
|
| 62 |
+
default process group.
|
| 63 |
+
device: the device to bind the CustomAllreduce to. If None,
|
| 64 |
+
it will be bind to f"cuda:{local_rank}".
|
| 65 |
+
It is the caller's responsibility to make sure each communicator
|
| 66 |
+
is bind to a unique device, and all communicators in this group
|
| 67 |
+
are in the same node.
|
| 68 |
+
"""
|
| 69 |
+
self._IS_CAPTURING = False
|
| 70 |
+
self.disabled = True
|
| 71 |
+
|
| 72 |
+
if not custom_ar:
|
| 73 |
+
# disable because of missing custom allreduce library
|
| 74 |
+
# e.g. in a non-cuda environment
|
| 75 |
+
return
|
| 76 |
+
|
| 77 |
+
self.group = group
|
| 78 |
+
|
| 79 |
+
assert dist.get_backend(group) != dist.Backend.NCCL, (
|
| 80 |
+
"CustomAllreduce should be attached to a non-NCCL group.")
|
| 81 |
+
|
| 82 |
+
if not all(in_the_same_node_as(group, source_rank=0)):
|
| 83 |
+
# No need to initialize custom allreduce for multi-node case.
|
| 84 |
+
logger.warning(
|
| 85 |
+
"Custom allreduce is disabled because this process group"
|
| 86 |
+
" spans across nodes.")
|
| 87 |
+
return
|
| 88 |
+
|
| 89 |
+
rank = dist.get_rank(group=self.group)
|
| 90 |
+
world_size = dist.get_world_size(group=self.group)
|
| 91 |
+
if world_size == 1:
|
| 92 |
+
# No need to initialize custom allreduce for single GPU case.
|
| 93 |
+
return
|
| 94 |
+
|
| 95 |
+
if world_size not in CustomAllreduce._SUPPORTED_WORLD_SIZES:
|
| 96 |
+
logger.warning(
|
| 97 |
+
"Custom allreduce is disabled due to an unsupported world"
|
| 98 |
+
" size: %d. Supported world sizes: %s. To silence this "
|
| 99 |
+
"warning, specify disable_custom_all_reduce=True explicitly.",
|
| 100 |
+
world_size, str(CustomAllreduce._SUPPORTED_WORLD_SIZES))
|
| 101 |
+
return
|
| 102 |
+
|
| 103 |
+
if isinstance(device, int):
|
| 104 |
+
device = torch.device(f"cuda:{device}")
|
| 105 |
+
elif isinstance(device, str):
|
| 106 |
+
device = torch.device(device)
|
| 107 |
+
# now `device` is a `torch.device` object
|
| 108 |
+
assert isinstance(device, torch.device)
|
| 109 |
+
self.device = device
|
| 110 |
+
|
| 111 |
+
cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
|
| 112 |
+
if cuda_visible_devices:
|
| 113 |
+
device_ids = list(map(int, cuda_visible_devices.split(",")))
|
| 114 |
+
else:
|
| 115 |
+
device_ids = list(range(cuda_device_count_stateless()))
|
| 116 |
+
|
| 117 |
+
physical_device_id = device_ids[device.index]
|
| 118 |
+
tensor = torch.tensor([physical_device_id],
|
| 119 |
+
dtype=torch.int,
|
| 120 |
+
device="cpu")
|
| 121 |
+
gather_list = [
|
| 122 |
+
torch.tensor([0], dtype=torch.int, device="cpu")
|
| 123 |
+
for _ in range(world_size)
|
| 124 |
+
]
|
| 125 |
+
dist.all_gather(gather_list, tensor, group=self.group)
|
| 126 |
+
physical_device_ids = [t.item() for t in gather_list]
|
| 127 |
+
|
| 128 |
+
# test nvlink first, this will filter out most of the cases
|
| 129 |
+
# where custom allreduce is not supported
|
| 130 |
+
# this checks hardware and driver support for NVLink
|
| 131 |
+
assert current_platform.is_cuda()
|
| 132 |
+
from vllm.platforms.cuda import CudaPlatform
|
| 133 |
+
cuda_platform: CudaPlatform = current_platform
|
| 134 |
+
full_nvlink = cuda_platform.is_full_nvlink(physical_device_ids)
|
| 135 |
+
if world_size > 2 and not full_nvlink:
|
| 136 |
+
logger.warning(
|
| 137 |
+
"Custom allreduce is disabled because it's not supported on"
|
| 138 |
+
" more than two PCIe-only GPUs. To silence this warning, "
|
| 139 |
+
"specify disable_custom_all_reduce=True explicitly.")
|
| 140 |
+
return
|
| 141 |
+
# test P2P capability, this checks software/cudaruntime support
|
| 142 |
+
# this is expensive to compute at the first time
|
| 143 |
+
# then we cache the result
|
| 144 |
+
if not _can_p2p(rank, world_size):
|
| 145 |
+
logger.warning(
|
| 146 |
+
"Custom allreduce is disabled because your platform lacks "
|
| 147 |
+
"GPU P2P capability or P2P test failed. To silence this "
|
| 148 |
+
"warning, specify disable_custom_all_reduce=True explicitly.")
|
| 149 |
+
return
|
| 150 |
+
|
| 151 |
+
self.disabled = False
|
| 152 |
+
# Buffers memory are owned by this Python class and passed to C++.
|
| 153 |
+
# Meta data composes of two parts: meta data for synchronization and a
|
| 154 |
+
# temporary buffer for storing intermediate allreduce results.
|
| 155 |
+
self.meta_ptrs = self.create_shared_buffer(ops.meta_size() + max_size,
|
| 156 |
+
group=group)
|
| 157 |
+
# This is a pre-registered IPC buffer. In eager mode, input tensors
|
| 158 |
+
# are first copied into this buffer before allreduce is performed
|
| 159 |
+
self.buffer_ptrs = self.create_shared_buffer(max_size, group=group)
|
| 160 |
+
# This is a buffer for storing the tuples of pointers pointing to
|
| 161 |
+
# IPC buffers from all ranks. Each registered tuple has size of
|
| 162 |
+
# 8*world_size bytes where world_size is at most 8. Allocating 8MB
|
| 163 |
+
# is enough for 131072 such tuples. The largest model I've seen only
|
| 164 |
+
# needs less than 10000 of registered tuples.
|
| 165 |
+
self.rank_data = torch.empty(8 * 1024 * 1024,
|
| 166 |
+
dtype=torch.uint8,
|
| 167 |
+
device=self.device)
|
| 168 |
+
self.max_size = max_size
|
| 169 |
+
self.rank = rank
|
| 170 |
+
self.world_size = world_size
|
| 171 |
+
self.full_nvlink = full_nvlink
|
| 172 |
+
self._ptr = ops.init_custom_ar(self.meta_ptrs, self.rank_data, rank,
|
| 173 |
+
self.full_nvlink)
|
| 174 |
+
ops.register_buffer(self._ptr, self.buffer_ptrs)
|
| 175 |
+
|
| 176 |
+
@staticmethod
|
| 177 |
+
def create_shared_buffer(
|
| 178 |
+
size_in_bytes: int,
|
| 179 |
+
group: Optional[ProcessGroup] = None) -> List[int]:
|
| 180 |
+
"""
|
| 181 |
+
Creates a shared buffer and returns a list of pointers
|
| 182 |
+
representing the buffer on all processes in the group.
|
| 183 |
+
"""
|
| 184 |
+
lib = CudaRTLibrary()
|
| 185 |
+
pointer = lib.cudaMalloc(size_in_bytes)
|
| 186 |
+
handle = lib.cudaIpcGetMemHandle(pointer)
|
| 187 |
+
world_size = dist.get_world_size(group=group)
|
| 188 |
+
rank = dist.get_rank(group=group)
|
| 189 |
+
handles = [None] * world_size
|
| 190 |
+
dist.all_gather_object(handles, handle, group=group)
|
| 191 |
+
|
| 192 |
+
pointers: List[int] = []
|
| 193 |
+
for i, h in enumerate(handles):
|
| 194 |
+
if i == rank:
|
| 195 |
+
pointers.append(pointer.value) # type: ignore
|
| 196 |
+
else:
|
| 197 |
+
pointers.append(
|
| 198 |
+
lib.cudaIpcOpenMemHandle(h).value) # type: ignore
|
| 199 |
+
|
| 200 |
+
return pointers
|
| 201 |
+
|
| 202 |
+
@staticmethod
|
| 203 |
+
def free_shared_buffer(pointers: List[int],
|
| 204 |
+
group: Optional[ProcessGroup] = None) -> None:
|
| 205 |
+
rank = dist.get_rank(group=group)
|
| 206 |
+
lib = CudaRTLibrary()
|
| 207 |
+
lib.cudaFree(ctypes.c_void_p(pointers[rank]))
|
| 208 |
+
|
| 209 |
+
@contextmanager
|
| 210 |
+
def capture(self):
|
| 211 |
+
"""
|
| 212 |
+
The main responsibility of this context manager is the
|
| 213 |
+
`register_graph_buffers` call at the end of the context.
|
| 214 |
+
It records all the buffer addresses used in the CUDA graph.
|
| 215 |
+
"""
|
| 216 |
+
try:
|
| 217 |
+
self._IS_CAPTURING = True
|
| 218 |
+
yield
|
| 219 |
+
finally:
|
| 220 |
+
self._IS_CAPTURING = False
|
| 221 |
+
if not self.disabled:
|
| 222 |
+
self.register_graph_buffers()
|
| 223 |
+
|
| 224 |
+
def register_graph_buffers(self):
|
| 225 |
+
handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr)
|
| 226 |
+
logger.info("Registering %d cuda graph addresses", len(offset))
|
| 227 |
+
# We cannot directly use `dist.all_gather_object` here
|
| 228 |
+
# because it is incompatible with `gloo` backend under inference mode.
|
| 229 |
+
# see https://github.com/pytorch/pytorch/issues/126032 for details.
|
| 230 |
+
all_data = [[None, None]
|
| 231 |
+
for _ in range(dist.get_world_size(group=self.group))]
|
| 232 |
+
all_data[self.rank] = [handle, offset]
|
| 233 |
+
ranks = sorted(dist.get_process_group_ranks(group=self.group))
|
| 234 |
+
for i, rank in enumerate(ranks):
|
| 235 |
+
dist.broadcast_object_list(all_data[i],
|
| 236 |
+
src=rank,
|
| 237 |
+
group=self.group,
|
| 238 |
+
device="cpu")
|
| 239 |
+
# Unpack list of tuples to tuple of lists.
|
| 240 |
+
handles = [d[0] for d in all_data] # type: ignore
|
| 241 |
+
offsets = [d[1] for d in all_data] # type: ignore
|
| 242 |
+
ops.register_graph_buffers(self._ptr, handles, offsets)
|
| 243 |
+
|
| 244 |
+
def should_custom_ar(self, inp: torch.Tensor):
|
| 245 |
+
if self.disabled:
|
| 246 |
+
return False
|
| 247 |
+
inp_size = inp.numel() * inp.element_size()
|
| 248 |
+
# custom allreduce requires input byte size to be multiples of 16
|
| 249 |
+
if inp_size % 16 != 0:
|
| 250 |
+
return False
|
| 251 |
+
if not is_weak_contiguous(inp):
|
| 252 |
+
return False
|
| 253 |
+
# for 4 or more non NVLink-capable GPUs, custom allreduce provides
|
| 254 |
+
# little performance improvement over NCCL.
|
| 255 |
+
if self.world_size == 2 or self.full_nvlink:
|
| 256 |
+
return inp_size < self.max_size
|
| 257 |
+
return False
|
| 258 |
+
|
| 259 |
+
def all_reduce(self,
|
| 260 |
+
inp: torch.Tensor,
|
| 261 |
+
*,
|
| 262 |
+
out: torch.Tensor = None,
|
| 263 |
+
registered: bool = False):
|
| 264 |
+
"""Performs an out-of-place all reduce.
|
| 265 |
+
|
| 266 |
+
If registered is True, this assumes inp's pointer is already
|
| 267 |
+
IPC-registered. Otherwise, inp is first copied into a pre-registered
|
| 268 |
+
buffer.
|
| 269 |
+
"""
|
| 270 |
+
if out is None:
|
| 271 |
+
out = torch.empty_like(inp)
|
| 272 |
+
if registered:
|
| 273 |
+
ops.all_reduce(self._ptr, inp, out, 0, 0)
|
| 274 |
+
else:
|
| 275 |
+
ops.all_reduce(self._ptr, inp, out, self.buffer_ptrs[self.rank],
|
| 276 |
+
self.max_size)
|
| 277 |
+
return out
|
| 278 |
+
|
| 279 |
+
def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]:
|
| 280 |
+
"""The main allreduce API that provides support for cuda graph."""
|
| 281 |
+
# When custom allreduce is disabled, this will be None.
|
| 282 |
+
if self.disabled or not self.should_custom_ar(input):
|
| 283 |
+
return None
|
| 284 |
+
if self._IS_CAPTURING:
|
| 285 |
+
if torch.cuda.is_current_stream_capturing():
|
| 286 |
+
return self.all_reduce(input, registered=True)
|
| 287 |
+
else:
|
| 288 |
+
# If warm up, mimic the allocation pattern since custom
|
| 289 |
+
# allreduce is out-of-place.
|
| 290 |
+
return torch.empty_like(input)
|
| 291 |
+
else:
|
| 292 |
+
# Note: outside of cuda graph context, custom allreduce incurs a
|
| 293 |
+
# cost of cudaMemcpy, which should be small (<=1% of overall
|
| 294 |
+
# latency) compared to the performance gain of using custom kernels
|
| 295 |
+
return self.all_reduce(input, registered=False)
|
| 296 |
+
|
| 297 |
+
def close(self):
|
| 298 |
+
if not self.disabled and self._ptr:
|
| 299 |
+
ops.dispose(self._ptr)
|
| 300 |
+
self._ptr = 0
|
| 301 |
+
self.free_shared_buffer(self.meta_ptrs)
|
| 302 |
+
self.free_shared_buffer(self.buffer_ptrs)
|
| 303 |
+
|
| 304 |
+
def __del__(self):
|
| 305 |
+
self.close()
|
.venv/lib/python3.11/site-packages/vllm/distributed/device_communicators/custom_all_reduce_utils.py
ADDED
|
@@ -0,0 +1,257 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
import ctypes
|
| 4 |
+
import json
|
| 5 |
+
import os
|
| 6 |
+
import pickle
|
| 7 |
+
import subprocess
|
| 8 |
+
import sys
|
| 9 |
+
import tempfile
|
| 10 |
+
from itertools import product
|
| 11 |
+
from typing import Dict, List, Optional, Sequence
|
| 12 |
+
|
| 13 |
+
import torch.distributed as dist
|
| 14 |
+
import torch.multiprocessing as mp
|
| 15 |
+
|
| 16 |
+
import vllm.envs as envs
|
| 17 |
+
from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
|
| 18 |
+
from vllm.logger import init_logger
|
| 19 |
+
from vllm.utils import (cuda_device_count_stateless,
|
| 20 |
+
update_environment_variables)
|
| 21 |
+
|
| 22 |
+
logger = init_logger(__name__)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def producer(batch_src: Sequence[int],
|
| 26 |
+
producer_queue,
|
| 27 |
+
consumer_queue,
|
| 28 |
+
result_queue,
|
| 29 |
+
cuda_visible_devices: Optional[str] = None):
|
| 30 |
+
if cuda_visible_devices is not None:
|
| 31 |
+
update_environment_variables(
|
| 32 |
+
{"CUDA_VISIBLE_DEVICES": cuda_visible_devices})
|
| 33 |
+
|
| 34 |
+
lib = CudaRTLibrary()
|
| 35 |
+
for i in batch_src:
|
| 36 |
+
lib.cudaSetDevice(i)
|
| 37 |
+
pointer = lib.cudaMalloc(1024)
|
| 38 |
+
lib.cudaMemset(pointer, 1, 1024)
|
| 39 |
+
lib.cudaDeviceSynchronize()
|
| 40 |
+
handle = lib.cudaIpcGetMemHandle(pointer)
|
| 41 |
+
producer_queue.put(handle)
|
| 42 |
+
open_success = consumer_queue.get()
|
| 43 |
+
if open_success:
|
| 44 |
+
# use two queues to simulate barrier
|
| 45 |
+
producer_queue.put(0)
|
| 46 |
+
consumer_queue.get()
|
| 47 |
+
# check if the memory is modified
|
| 48 |
+
host_data = (ctypes.c_char * 1024)()
|
| 49 |
+
lib.cudaMemcpy(host_data, pointer, 1024) # type: ignore
|
| 50 |
+
for i in range(1024):
|
| 51 |
+
if ord(host_data[i]) != 2:
|
| 52 |
+
open_success = False
|
| 53 |
+
break
|
| 54 |
+
result_queue.put(open_success)
|
| 55 |
+
lib.cudaDeviceReset()
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def consumer(batch_tgt: Sequence[int],
|
| 59 |
+
producer_queue,
|
| 60 |
+
consumer_queue,
|
| 61 |
+
result_queue,
|
| 62 |
+
cuda_visible_devices: Optional[str] = None):
|
| 63 |
+
if cuda_visible_devices is not None:
|
| 64 |
+
update_environment_variables(
|
| 65 |
+
{"CUDA_VISIBLE_DEVICES": cuda_visible_devices})
|
| 66 |
+
|
| 67 |
+
lib = CudaRTLibrary()
|
| 68 |
+
for j in batch_tgt:
|
| 69 |
+
lib.cudaSetDevice(j)
|
| 70 |
+
handle = producer_queue.get()
|
| 71 |
+
open_success = False
|
| 72 |
+
try:
|
| 73 |
+
pointer = lib.cudaIpcOpenMemHandle(handle) # type: ignore
|
| 74 |
+
open_success = True
|
| 75 |
+
except RuntimeError:
|
| 76 |
+
# cannot error out here, because the producer process
|
| 77 |
+
# is still waiting for the response.
|
| 78 |
+
pass
|
| 79 |
+
consumer_queue.put(open_success)
|
| 80 |
+
if open_success:
|
| 81 |
+
# modify the memory
|
| 82 |
+
lib.cudaMemset(pointer, 2, 1024)
|
| 83 |
+
lib.cudaDeviceSynchronize()
|
| 84 |
+
# use two queues to simulate barrier
|
| 85 |
+
producer_queue.get()
|
| 86 |
+
consumer_queue.put(0)
|
| 87 |
+
# check if the memory is modified
|
| 88 |
+
host_data = (ctypes.c_char * 1024)()
|
| 89 |
+
lib.cudaMemcpy(host_data, pointer, 1024) # type: ignore
|
| 90 |
+
for i in range(1024):
|
| 91 |
+
if ord(host_data[i]) != 2:
|
| 92 |
+
open_success = False
|
| 93 |
+
break
|
| 94 |
+
result_queue.put(open_success)
|
| 95 |
+
lib.cudaDeviceReset()
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def can_actually_p2p(
|
| 99 |
+
batch_src: Sequence[int],
|
| 100 |
+
batch_tgt: Sequence[int],
|
| 101 |
+
) -> Sequence[bool]:
|
| 102 |
+
"""
|
| 103 |
+
Usually, checking if P2P access is enabled can be done by
|
| 104 |
+
`torch.cuda.can_device_access_peer(src, tgt)`. However, sometimes
|
| 105 |
+
the driver might be broken, and `torch.cuda.can_device_access_peer(src, tgt)`
|
| 106 |
+
returns `True` even if P2P access is not actually possible.
|
| 107 |
+
See https://github.com/vllm-project/vllm/issues/2728 and
|
| 108 |
+
https://forums.developer.nvidia.com/t/direct-gpu-gpu-communication-does-not-seem-to-work-properly/283264/10
|
| 109 |
+
Therefore, we have to perform a real P2P access to check if it is actually
|
| 110 |
+
possible.
|
| 111 |
+
|
| 112 |
+
Note on p2p and cuda IPC:
|
| 113 |
+
Usually, one process uses one GPU:
|
| 114 |
+
GPU src --> cuda context src --> tensor src --> process src
|
| 115 |
+
|
| 116 |
+
We need to combine p2p and cuda IPC, so that:
|
| 117 |
+
GPU src --> cuda context src --> tensor src --> process src
|
| 118 |
+
|shared|
|
| 119 |
+
GPU tgt --> cuda context tgt --> tensor tgt --> process tgt
|
| 120 |
+
That is to say, process src creates a tensor in GPU src, passes IPC handle to
|
| 121 |
+
process tgt, and process tgt accesses the tensor in GPU tgt. Any operation on the
|
| 122 |
+
tensor in process tgt will be reflected in the tensor in process src, because
|
| 123 |
+
they are the same memory segment.
|
| 124 |
+
It is important to note that process tgt accesses the tensor in GPU tgt, not
|
| 125 |
+
GPU src. That's why we need p2p access.
|
| 126 |
+
|
| 127 |
+
The most time-consuming part is the process creation. To avoid creating
|
| 128 |
+
processes for every pair of GPUs, we use batched testing. We create two
|
| 129 |
+
processes for testing all pairs of GPUs in batch. The trick is to reset
|
| 130 |
+
the device after each test (which is not available in PyTorch).
|
| 131 |
+
""" # noqa
|
| 132 |
+
cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
|
| 133 |
+
# pass the CUDA_VISIBLE_DEVICES to the child process
|
| 134 |
+
# to make sure they see the same set of GPUs
|
| 135 |
+
|
| 136 |
+
# make sure the processes are spawned
|
| 137 |
+
smp = mp.get_context("spawn")
|
| 138 |
+
producer_queue = smp.Queue()
|
| 139 |
+
consumer_queue = smp.Queue()
|
| 140 |
+
result_queue = smp.Queue()
|
| 141 |
+
p_src = smp.Process(target=producer,
|
| 142 |
+
args=(batch_src, producer_queue, consumer_queue,
|
| 143 |
+
result_queue, cuda_visible_devices))
|
| 144 |
+
p_tgt = smp.Process(target=consumer,
|
| 145 |
+
args=(batch_tgt, producer_queue, consumer_queue,
|
| 146 |
+
result_queue, cuda_visible_devices))
|
| 147 |
+
p_src.start()
|
| 148 |
+
p_tgt.start()
|
| 149 |
+
p_src.join()
|
| 150 |
+
p_tgt.join()
|
| 151 |
+
assert p_src.exitcode == 0 and p_tgt.exitcode == 0
|
| 152 |
+
result: List[bool] = []
|
| 153 |
+
for src, tgt in zip(batch_src, batch_tgt):
|
| 154 |
+
a = result_queue.get()
|
| 155 |
+
b = result_queue.get()
|
| 156 |
+
if a != b:
|
| 157 |
+
logger.warning(
|
| 158 |
+
"Two processes do not agree on the P2P access"
|
| 159 |
+
" status on %d -> %d, treat as disabled.", src, tgt)
|
| 160 |
+
result.append(False)
|
| 161 |
+
else:
|
| 162 |
+
result.append(a)
|
| 163 |
+
return result
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
# why do we need this cache?
|
| 167 |
+
# we are testing peer-to-peer (p2p) access between GPUs,across processes.
|
| 168 |
+
# if we test it every time, it will be very slow, because we need to create
|
| 169 |
+
# N * N * 2 processes, where N is the world size. This is very slow.
|
| 170 |
+
# to reduce the time, we use a cache file to store the p2p access status.
|
| 171 |
+
# the cache file is generated by the master process if it does not exist.
|
| 172 |
+
# then all the processes can read the cache file to check the p2p access status.
|
| 173 |
+
# Note that the cache file is suffixed by the CUDA_VISIBLE_DEVICES, so that we
|
| 174 |
+
# can have different cache files for different CUDA_VISIBLE_DEVICES settings,
|
| 175 |
+
# e.g. used by different vllm engines. The device id in the cache file is a
|
| 176 |
+
# **local** device id, i.e. from 0 to num_dev-1, where num_dev is the number
|
| 177 |
+
# of visible devices in the vllm engine.
|
| 178 |
+
_gpu_p2p_access_cache: Optional[Dict[str, bool]] = None
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def gpu_p2p_access_check(src: int, tgt: int) -> bool:
|
| 182 |
+
"""Check if GPU src can access GPU tgt."""
|
| 183 |
+
|
| 184 |
+
# if the cache variable is already calculated,
|
| 185 |
+
# read from the cache instead of checking it again
|
| 186 |
+
global _gpu_p2p_access_cache
|
| 187 |
+
if _gpu_p2p_access_cache is not None:
|
| 188 |
+
return _gpu_p2p_access_cache[f"{src}->{tgt}"]
|
| 189 |
+
|
| 190 |
+
is_distributed = dist.is_initialized()
|
| 191 |
+
|
| 192 |
+
num_dev = cuda_device_count_stateless()
|
| 193 |
+
cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
|
| 194 |
+
if cuda_visible_devices is None:
|
| 195 |
+
cuda_visible_devices = ",".join(str(i) for i in range(num_dev))
|
| 196 |
+
|
| 197 |
+
path = os.path.join(
|
| 198 |
+
envs.VLLM_CACHE_ROOT,
|
| 199 |
+
f"gpu_p2p_access_cache_for_{cuda_visible_devices}.json")
|
| 200 |
+
os.makedirs(os.path.dirname(path), exist_ok=True)
|
| 201 |
+
from vllm.distributed.parallel_state import get_world_group
|
| 202 |
+
if ((not is_distributed or get_world_group().local_rank == 0)
|
| 203 |
+
and (not os.path.exists(path))):
|
| 204 |
+
# only the local master process (with local_rank == 0) can
|
| 205 |
+
# enter this block to calculate the cache
|
| 206 |
+
logger.info("generating GPU P2P access cache in %s", path)
|
| 207 |
+
cache: Dict[str, bool] = {}
|
| 208 |
+
ids = list(range(num_dev))
|
| 209 |
+
# batch of all pairs of GPUs
|
| 210 |
+
batch_src, batch_tgt = zip(*list(product(ids, ids)))
|
| 211 |
+
# NOTE: we use `subprocess` rather than `multiprocessing` here
|
| 212 |
+
# because the caller might not have `if __name__ == "__main__":`,
|
| 213 |
+
# in that case we cannot use spawn method in multiprocessing.
|
| 214 |
+
# However, `can_actually_p2p` requires spawn method.
|
| 215 |
+
# The fix is, we use `subprocess` to call the function,
|
| 216 |
+
# where we have `if __name__ == "__main__":` in this file.
|
| 217 |
+
|
| 218 |
+
# use a temporary file to store the result
|
| 219 |
+
# we don't use the output of the subprocess directly,
|
| 220 |
+
# because the subprocess might produce logging output
|
| 221 |
+
with tempfile.NamedTemporaryFile() as output_file:
|
| 222 |
+
input_bytes = pickle.dumps(
|
| 223 |
+
(batch_src, batch_tgt, output_file.name))
|
| 224 |
+
returned = subprocess.run([sys.executable, __file__],
|
| 225 |
+
input=input_bytes,
|
| 226 |
+
capture_output=True)
|
| 227 |
+
# check if the subprocess is successful
|
| 228 |
+
try:
|
| 229 |
+
returned.check_returncode()
|
| 230 |
+
except Exception as e:
|
| 231 |
+
# wrap raised exception to provide more information
|
| 232 |
+
raise RuntimeError(
|
| 233 |
+
f"Error happened when batch testing "
|
| 234 |
+
f"peer-to-peer access from {batch_src} to {batch_tgt}:\n"
|
| 235 |
+
f"{returned.stderr.decode()}") from e
|
| 236 |
+
with open(output_file.name, "rb") as f:
|
| 237 |
+
result = pickle.load(f)
|
| 238 |
+
for _i, _j, r in zip(batch_src, batch_tgt, result):
|
| 239 |
+
cache[f"{_i}->{_j}"] = r
|
| 240 |
+
with open(path, "w") as f:
|
| 241 |
+
json.dump(cache, f, indent=4)
|
| 242 |
+
if is_distributed:
|
| 243 |
+
get_world_group().barrier()
|
| 244 |
+
logger.info("reading GPU P2P access cache from %s", path)
|
| 245 |
+
with open(path) as f:
|
| 246 |
+
cache = json.load(f)
|
| 247 |
+
_gpu_p2p_access_cache = cache
|
| 248 |
+
return _gpu_p2p_access_cache[f"{src}->{tgt}"]
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
__all__ = ["gpu_p2p_access_check"]
|
| 252 |
+
|
| 253 |
+
if __name__ == "__main__":
|
| 254 |
+
batch_src, batch_tgt, output_file = pickle.loads(sys.stdin.buffer.read())
|
| 255 |
+
result = can_actually_p2p(batch_src, batch_tgt)
|
| 256 |
+
with open(output_file, "wb") as f:
|
| 257 |
+
f.write(pickle.dumps(result))
|
.venv/lib/python3.11/site-packages/vllm/distributed/device_communicators/hpu_communicator.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.distributed as dist
|
| 5 |
+
from torch.distributed import ProcessGroup
|
| 6 |
+
|
| 7 |
+
from vllm.platforms import current_platform
|
| 8 |
+
|
| 9 |
+
if current_platform.is_hpu():
|
| 10 |
+
import habana_frameworks.torch as htorch # noqa: F401
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class HpuCommunicator:
|
| 14 |
+
|
| 15 |
+
def __init__(self, group: ProcessGroup):
|
| 16 |
+
if not current_platform.is_hpu():
|
| 17 |
+
self.disabled = True
|
| 18 |
+
return
|
| 19 |
+
self.disabled = False
|
| 20 |
+
self.group = group
|
| 21 |
+
self.world_size = dist.get_world_size(self.group)
|
| 22 |
+
|
| 23 |
+
def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
|
| 24 |
+
# FIXME(kzawora): this is a workaround for a bug in Habana PT bridge
|
| 25 |
+
# occurring when PT_HPU_ENABLE_LAZY_COLLECTIVES=true env var is used
|
| 26 |
+
# (which is required for tensor parallel HPUGraph inference)
|
| 27 |
+
htorch.core.mark_step()
|
| 28 |
+
dist.all_reduce(x, group=self.group)
|
| 29 |
+
return x
|
| 30 |
+
|
| 31 |
+
def all_gather(self, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
|
| 32 |
+
world_size = self.world_size
|
| 33 |
+
if dim < 0:
|
| 34 |
+
# Convert negative dim to positive.
|
| 35 |
+
dim += x.dim()
|
| 36 |
+
input_size = x.size()
|
| 37 |
+
# Allocate output tensor.
|
| 38 |
+
output_tensor = torch.empty((world_size, ) + input_size,
|
| 39 |
+
dtype=x.dtype,
|
| 40 |
+
device=x.device)
|
| 41 |
+
# All-gather.
|
| 42 |
+
htorch.core.mark_step()
|
| 43 |
+
dist.all_gather_into_tensor(output_tensor, x, group=self.group)
|
| 44 |
+
# Reshape
|
| 45 |
+
output_tensor = output_tensor.movedim(0, dim)
|
| 46 |
+
output_tensor = output_tensor.reshape(input_size[:dim] +
|
| 47 |
+
(world_size *
|
| 48 |
+
input_size[dim], ) +
|
| 49 |
+
input_size[dim + 1:])
|
| 50 |
+
return output_tensor
|
.venv/lib/python3.11/site-packages/vllm/distributed/device_communicators/pynccl.py
ADDED
|
@@ -0,0 +1,217 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
from typing import Optional, Union
|
| 4 |
+
|
| 5 |
+
# ===================== import region =====================
|
| 6 |
+
import torch
|
| 7 |
+
import torch.distributed as dist
|
| 8 |
+
from torch.distributed import ProcessGroup, ReduceOp
|
| 9 |
+
|
| 10 |
+
from vllm.distributed.device_communicators.pynccl_wrapper import (
|
| 11 |
+
NCCLLibrary, buffer_type, cudaStream_t, ncclComm_t, ncclDataTypeEnum,
|
| 12 |
+
ncclRedOpTypeEnum, ncclUniqueId)
|
| 13 |
+
from vllm.distributed.utils import StatelessProcessGroup
|
| 14 |
+
from vllm.logger import init_logger
|
| 15 |
+
from vllm.utils import current_stream
|
| 16 |
+
|
| 17 |
+
logger = init_logger(__name__)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class PyNcclCommunicator:
|
| 21 |
+
|
| 22 |
+
def __init__(
|
| 23 |
+
self,
|
| 24 |
+
group: Union[ProcessGroup, StatelessProcessGroup],
|
| 25 |
+
device: Union[int, str, torch.device],
|
| 26 |
+
library_path: Optional[str] = None,
|
| 27 |
+
):
|
| 28 |
+
"""
|
| 29 |
+
Args:
|
| 30 |
+
group: the process group to work on. If None, it will use the
|
| 31 |
+
default process group.
|
| 32 |
+
device: the device to bind the PyNcclCommunicator to. If None,
|
| 33 |
+
it will be bind to f"cuda:{local_rank}".
|
| 34 |
+
library_path: the path to the NCCL library. If None, it will
|
| 35 |
+
use the default library path.
|
| 36 |
+
It is the caller's responsibility to make sure each communicator
|
| 37 |
+
is bind to a unique device.
|
| 38 |
+
"""
|
| 39 |
+
if not isinstance(group, StatelessProcessGroup):
|
| 40 |
+
assert dist.is_initialized()
|
| 41 |
+
assert dist.get_backend(group) != dist.Backend.NCCL, (
|
| 42 |
+
"PyNcclCommunicator should be attached to a non-NCCL group.")
|
| 43 |
+
# note: this rank is the rank in the group
|
| 44 |
+
self.rank = dist.get_rank(group)
|
| 45 |
+
self.world_size = dist.get_world_size(group)
|
| 46 |
+
else:
|
| 47 |
+
self.rank = group.rank
|
| 48 |
+
self.world_size = group.world_size
|
| 49 |
+
|
| 50 |
+
self.group = group
|
| 51 |
+
|
| 52 |
+
# if world_size == 1, no need to create communicator
|
| 53 |
+
if self.world_size == 1:
|
| 54 |
+
self.available = False
|
| 55 |
+
self.disabled = True
|
| 56 |
+
return
|
| 57 |
+
try:
|
| 58 |
+
self.nccl = NCCLLibrary(library_path)
|
| 59 |
+
except Exception:
|
| 60 |
+
# disable because of missing NCCL library
|
| 61 |
+
# e.g. in a non-GPU environment
|
| 62 |
+
self.available = False
|
| 63 |
+
self.disabled = True
|
| 64 |
+
return
|
| 65 |
+
|
| 66 |
+
self.available = True
|
| 67 |
+
self.disabled = False
|
| 68 |
+
|
| 69 |
+
logger.info("vLLM is using nccl==%s", self.nccl.ncclGetVersion())
|
| 70 |
+
|
| 71 |
+
if self.rank == 0:
|
| 72 |
+
# get the unique id from NCCL
|
| 73 |
+
self.unique_id = self.nccl.ncclGetUniqueId()
|
| 74 |
+
else:
|
| 75 |
+
# construct an empty unique id
|
| 76 |
+
self.unique_id = ncclUniqueId()
|
| 77 |
+
|
| 78 |
+
if not isinstance(group, StatelessProcessGroup):
|
| 79 |
+
tensor = torch.ByteTensor(list(self.unique_id.internal))
|
| 80 |
+
ranks = dist.get_process_group_ranks(group)
|
| 81 |
+
# arg `src` in `broadcast` is the global rank
|
| 82 |
+
dist.broadcast(tensor, src=ranks[0], group=group)
|
| 83 |
+
byte_list = tensor.tolist()
|
| 84 |
+
for i, byte in enumerate(byte_list):
|
| 85 |
+
self.unique_id.internal[i] = byte
|
| 86 |
+
else:
|
| 87 |
+
self.unique_id = group.broadcast_obj(self.unique_id, src=0)
|
| 88 |
+
if isinstance(device, int):
|
| 89 |
+
device = torch.device(f"cuda:{device}")
|
| 90 |
+
elif isinstance(device, str):
|
| 91 |
+
device = torch.device(device)
|
| 92 |
+
# now `device` is a `torch.device` object
|
| 93 |
+
assert isinstance(device, torch.device)
|
| 94 |
+
self.device = device
|
| 95 |
+
# nccl communicator and stream will use this device
|
| 96 |
+
# `torch.cuda.device` is a context manager that changes the
|
| 97 |
+
# current cuda device to the specified one
|
| 98 |
+
with torch.cuda.device(device):
|
| 99 |
+
self.comm: ncclComm_t = self.nccl.ncclCommInitRank(
|
| 100 |
+
self.world_size, self.unique_id, self.rank)
|
| 101 |
+
|
| 102 |
+
stream = current_stream()
|
| 103 |
+
# A small all_reduce for warmup.
|
| 104 |
+
data = torch.zeros(1, device=device)
|
| 105 |
+
self.all_reduce(data)
|
| 106 |
+
stream.synchronize()
|
| 107 |
+
del data
|
| 108 |
+
|
| 109 |
+
def all_reduce(self,
|
| 110 |
+
in_tensor: torch.Tensor,
|
| 111 |
+
op: ReduceOp = ReduceOp.SUM,
|
| 112 |
+
stream=None) -> torch.Tensor:
|
| 113 |
+
if self.disabled:
|
| 114 |
+
return None
|
| 115 |
+
# nccl communicator created on a specific device
|
| 116 |
+
# will only work on tensors on the same device
|
| 117 |
+
# otherwise it will cause "illegal memory access"
|
| 118 |
+
assert in_tensor.device == self.device, (
|
| 119 |
+
f"this nccl communicator is created to work on {self.device}, "
|
| 120 |
+
f"but the input tensor is on {in_tensor.device}")
|
| 121 |
+
|
| 122 |
+
out_tensor = torch.empty_like(in_tensor)
|
| 123 |
+
|
| 124 |
+
if stream is None:
|
| 125 |
+
stream = current_stream()
|
| 126 |
+
self.nccl.ncclAllReduce(buffer_type(in_tensor.data_ptr()),
|
| 127 |
+
buffer_type(out_tensor.data_ptr()),
|
| 128 |
+
in_tensor.numel(),
|
| 129 |
+
ncclDataTypeEnum.from_torch(in_tensor.dtype),
|
| 130 |
+
ncclRedOpTypeEnum.from_torch(op), self.comm,
|
| 131 |
+
cudaStream_t(stream.cuda_stream))
|
| 132 |
+
return out_tensor
|
| 133 |
+
|
| 134 |
+
def all_gather(self,
|
| 135 |
+
output_tensor: torch.Tensor,
|
| 136 |
+
input_tensor: torch.Tensor,
|
| 137 |
+
stream=None):
|
| 138 |
+
if self.disabled:
|
| 139 |
+
return
|
| 140 |
+
# nccl communicator created on a specific device
|
| 141 |
+
# will only work on tensors on the same device
|
| 142 |
+
# otherwise it will cause "illegal memory access"
|
| 143 |
+
assert input_tensor.device == self.device, (
|
| 144 |
+
f"this nccl communicator is created to work on {self.device}, "
|
| 145 |
+
f"but the input tensor is on {input_tensor.device}")
|
| 146 |
+
if stream is None:
|
| 147 |
+
stream = current_stream()
|
| 148 |
+
self.nccl.ncclAllGather(
|
| 149 |
+
buffer_type(input_tensor.data_ptr()),
|
| 150 |
+
buffer_type(output_tensor.data_ptr()), input_tensor.numel(),
|
| 151 |
+
ncclDataTypeEnum.from_torch(input_tensor.dtype), self.comm,
|
| 152 |
+
cudaStream_t(stream.cuda_stream))
|
| 153 |
+
|
| 154 |
+
def reduce_scatter(self,
|
| 155 |
+
output_tensor: torch.Tensor,
|
| 156 |
+
input_tensor: torch.Tensor,
|
| 157 |
+
op: ReduceOp = ReduceOp.SUM,
|
| 158 |
+
stream=None):
|
| 159 |
+
if self.disabled:
|
| 160 |
+
return
|
| 161 |
+
# nccl communicator created on a specific device
|
| 162 |
+
# will only work on tensors on the same device
|
| 163 |
+
# otherwise it will cause "illegal memory access"
|
| 164 |
+
assert input_tensor.device == self.device, (
|
| 165 |
+
f"this nccl communicator is created to work on {self.device}, "
|
| 166 |
+
f"but the input tensor is on {input_tensor.device}")
|
| 167 |
+
if stream is None:
|
| 168 |
+
stream = current_stream()
|
| 169 |
+
self.nccl.ncclReduceScatter(
|
| 170 |
+
buffer_type(input_tensor.data_ptr()),
|
| 171 |
+
buffer_type(output_tensor.data_ptr()), output_tensor.numel(),
|
| 172 |
+
ncclDataTypeEnum.from_torch(input_tensor.dtype),
|
| 173 |
+
ncclRedOpTypeEnum.from_torch(op), self.comm,
|
| 174 |
+
cudaStream_t(stream.cuda_stream))
|
| 175 |
+
|
| 176 |
+
def send(self, tensor: torch.Tensor, dst: int, stream=None):
|
| 177 |
+
if self.disabled:
|
| 178 |
+
return
|
| 179 |
+
assert tensor.device == self.device, (
|
| 180 |
+
f"this nccl communicator is created to work on {self.device}, "
|
| 181 |
+
f"but the input tensor is on {tensor.device}")
|
| 182 |
+
if stream is None:
|
| 183 |
+
stream = current_stream()
|
| 184 |
+
self.nccl.ncclSend(buffer_type(tensor.data_ptr()), tensor.numel(),
|
| 185 |
+
ncclDataTypeEnum.from_torch(tensor.dtype), dst,
|
| 186 |
+
self.comm, cudaStream_t(stream.cuda_stream))
|
| 187 |
+
|
| 188 |
+
def recv(self, tensor: torch.Tensor, src: int, stream=None):
|
| 189 |
+
if self.disabled:
|
| 190 |
+
return
|
| 191 |
+
assert tensor.device == self.device, (
|
| 192 |
+
f"this nccl communicator is created to work on {self.device}, "
|
| 193 |
+
f"but the input tensor is on {tensor.device}")
|
| 194 |
+
if stream is None:
|
| 195 |
+
stream = current_stream()
|
| 196 |
+
self.nccl.ncclRecv(buffer_type(tensor.data_ptr()), tensor.numel(),
|
| 197 |
+
ncclDataTypeEnum.from_torch(tensor.dtype), src,
|
| 198 |
+
self.comm, cudaStream_t(stream.cuda_stream))
|
| 199 |
+
|
| 200 |
+
def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
|
| 201 |
+
if self.disabled:
|
| 202 |
+
return
|
| 203 |
+
assert tensor.device == self.device, (
|
| 204 |
+
f"this nccl communicator is created to work on {self.device}, "
|
| 205 |
+
f"but the input tensor is on {tensor.device}")
|
| 206 |
+
if stream is None:
|
| 207 |
+
stream = current_stream()
|
| 208 |
+
if src == self.rank:
|
| 209 |
+
sendbuff = buffer_type(tensor.data_ptr())
|
| 210 |
+
# NCCL requires the sender also to have a receive buffer
|
| 211 |
+
recvbuff = buffer_type(tensor.data_ptr())
|
| 212 |
+
else:
|
| 213 |
+
sendbuff = buffer_type()
|
| 214 |
+
recvbuff = buffer_type(tensor.data_ptr())
|
| 215 |
+
self.nccl.ncclBroadcast(sendbuff, recvbuff, tensor.numel(),
|
| 216 |
+
ncclDataTypeEnum.from_torch(tensor.dtype), src,
|
| 217 |
+
self.comm, cudaStream_t(stream.cuda_stream))
|
.venv/lib/python3.11/site-packages/vllm/distributed/device_communicators/shm_broadcast.py
ADDED
|
@@ -0,0 +1,530 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import pickle
|
| 5 |
+
import sys
|
| 6 |
+
import time
|
| 7 |
+
from contextlib import contextmanager
|
| 8 |
+
from dataclasses import dataclass, field
|
| 9 |
+
from multiprocessing import shared_memory
|
| 10 |
+
from typing import List, Optional, Tuple, Union
|
| 11 |
+
from unittest.mock import patch
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
import torch.distributed as dist
|
| 15 |
+
from torch.distributed import ProcessGroup
|
| 16 |
+
from zmq import IPV6 # type: ignore
|
| 17 |
+
from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context # type: ignore
|
| 18 |
+
|
| 19 |
+
import vllm.envs as envs
|
| 20 |
+
from vllm.distributed.utils import StatelessProcessGroup
|
| 21 |
+
from vllm.logger import init_logger
|
| 22 |
+
from vllm.utils import get_ip, get_open_port, is_valid_ipv6_address
|
| 23 |
+
|
| 24 |
+
VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL
|
| 25 |
+
|
| 26 |
+
logger = init_logger(__name__)
|
| 27 |
+
|
| 28 |
+
# We prefer to use os.sched_yield as it results in tighter polling loops,
|
| 29 |
+
# measured to be around 3e-7 seconds. However on earlier versions of Python
|
| 30 |
+
# os.sched_yield() does not release the GIL, so we fall back to time.sleep(0)
|
| 31 |
+
USE_SCHED_YIELD = ((sys.version_info[:3] >= (3, 11, 1))
|
| 32 |
+
or (sys.version_info[:2] == (3, 10)
|
| 33 |
+
and sys.version_info[2] >= 8))
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def sched_yield():
|
| 37 |
+
if USE_SCHED_YIELD:
|
| 38 |
+
os.sched_yield()
|
| 39 |
+
else:
|
| 40 |
+
time.sleep(0)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class ShmRingBuffer:
|
| 44 |
+
|
| 45 |
+
def __init__(self,
|
| 46 |
+
n_reader: int,
|
| 47 |
+
max_chunk_bytes: int,
|
| 48 |
+
max_chunks: int,
|
| 49 |
+
name: Optional[str] = None):
|
| 50 |
+
"""
|
| 51 |
+
A shared memory ring buffer implementation for broadcast communication.
|
| 52 |
+
Essentially, it is a queue where only one will `enqueue` and multiple
|
| 53 |
+
will `dequeue`. The max size of each item, together with the max number
|
| 54 |
+
of items that can be stored in the buffer are known in advance.
|
| 55 |
+
In this case, we don't need to synchronize the access to
|
| 56 |
+
the buffer.
|
| 57 |
+
|
| 58 |
+
Buffer memory layout:
|
| 59 |
+
data metadata
|
| 60 |
+
| |
|
| 61 |
+
| (current_idx) | (current_idx)
|
| 62 |
+
v v
|
| 63 |
+
+-------------------------------+----------------------------------------+
|
| 64 |
+
| chunk0 | chunk1 | ... | chunk | metadata0 | metadata1 | ... | metadata |
|
| 65 |
+
+-------------------------------+----------------------------------------+
|
| 66 |
+
| max_chunks x max_chunk_bytes | max_chunks x (1 + n_reader) bytes |
|
| 67 |
+
|
| 68 |
+
metadata memory layout: each byte is a flag, the first byte is the written
|
| 69 |
+
flag, and the rest are reader flags. The flags are set to 0 by default.
|
| 70 |
+
+--------------+--------------+--------------+-----+--------------+
|
| 71 |
+
| written_flag | reader0_flag | reader1_flag | ... | readerN_flag |
|
| 72 |
+
+--------------+--------------+--------------+-----+--------------+
|
| 73 |
+
|
| 74 |
+
The state of metadata is as follows:
|
| 75 |
+
|
| 76 |
+
(case 1) 0???...???: the block is not written yet, cannot read, can write
|
| 77 |
+
(case 2) 1000...000: the block is just written, can read, cannot write
|
| 78 |
+
(case 3) 1???...???: the block is written and read by some readers, can read if not read, cannot write
|
| 79 |
+
(case 4) 1111...111: the block is written and read by all readers, cannot read, can write
|
| 80 |
+
|
| 81 |
+
State transition for readers:
|
| 82 |
+
|
| 83 |
+
When a reader finds a block that it can read (case 2 or 3), it can yield the block for caller to read.
|
| 84 |
+
Only after the caller finishes reading the block, the reader can mark the block as read.
|
| 85 |
+
Readers only mark the block as read (from 0 to 1), the writer marks the block as ready to read (from 1 to 0).
|
| 86 |
+
|
| 87 |
+
State transition for writer:
|
| 88 |
+
|
| 89 |
+
When the writer writes to a block (case 1 or 4), it first resets the written flag to 0, converting either case
|
| 90 |
+
to case 1. Then it can yield the block for caller to write. After the caller finishes writing the block, the writer
|
| 91 |
+
can reset the reader flags to 0, and mark the block as written (from 0 to 1).
|
| 92 |
+
NOTE: the order is important here, first reset the reader flags (so that we are still in case 1), then mark the block as written. The state transition is atomic. If we do it in the reverse order, it will go through case 3 and then back to case 2, and readers might read the intermediate case 3, which is not correct.
|
| 93 |
+
|
| 94 |
+
During creation, `name` is None and the buffer is created. We can pass the
|
| 95 |
+
created object to other processes by pickling it. The other processes will
|
| 96 |
+
get the name of the shared memory and open it, so that they can access the
|
| 97 |
+
same shared memory buffer.
|
| 98 |
+
"""# noqa
|
| 99 |
+
self.n_reader = n_reader
|
| 100 |
+
self.metadata_size = 1 + n_reader
|
| 101 |
+
self.max_chunk_bytes = max_chunk_bytes
|
| 102 |
+
self.max_chunks = max_chunks
|
| 103 |
+
self.total_bytes_of_buffer = (self.max_chunk_bytes +
|
| 104 |
+
self.metadata_size) * self.max_chunks
|
| 105 |
+
self.data_offset = 0
|
| 106 |
+
self.metadata_offset = self.max_chunk_bytes * self.max_chunks
|
| 107 |
+
|
| 108 |
+
if name is None:
|
| 109 |
+
# we are creating a buffer
|
| 110 |
+
self.is_creator = True
|
| 111 |
+
self.shared_memory = shared_memory.SharedMemory(
|
| 112 |
+
create=True, size=self.total_bytes_of_buffer)
|
| 113 |
+
# initialize the metadata section to 0
|
| 114 |
+
with memoryview(self.shared_memory.buf[self.metadata_offset:]
|
| 115 |
+
) as metadata_buffer:
|
| 116 |
+
torch.frombuffer(metadata_buffer, dtype=torch.uint8).fill_(0)
|
| 117 |
+
else:
|
| 118 |
+
# we are opening an existing buffer
|
| 119 |
+
self.is_creator = False
|
| 120 |
+
# fix to https://stackoverflow.com/q/62748654/9191338
|
| 121 |
+
# Python incorrectly tracks shared memory even if it is not
|
| 122 |
+
# created by the process. The following patch is a workaround.
|
| 123 |
+
with patch("multiprocessing.resource_tracker.register",
|
| 124 |
+
lambda *args, **kwargs: None):
|
| 125 |
+
try:
|
| 126 |
+
self.shared_memory = shared_memory.SharedMemory(name=name)
|
| 127 |
+
assert (
|
| 128 |
+
self.shared_memory.size == self.total_bytes_of_buffer)
|
| 129 |
+
except FileNotFoundError:
|
| 130 |
+
# we might deserialize the object in a different node
|
| 131 |
+
# in this case, this object is not used,
|
| 132 |
+
# and we should suppress the error
|
| 133 |
+
pass
|
| 134 |
+
|
| 135 |
+
def handle(self):
|
| 136 |
+
return (self.n_reader, self.max_chunk_bytes, self.max_chunks,
|
| 137 |
+
self.shared_memory.name)
|
| 138 |
+
|
| 139 |
+
def __reduce__(self):
|
| 140 |
+
return (
|
| 141 |
+
self.__class__,
|
| 142 |
+
self.handle(),
|
| 143 |
+
)
|
| 144 |
+
|
| 145 |
+
def __del__(self):
|
| 146 |
+
if hasattr(self, "shared_memory"):
|
| 147 |
+
self.shared_memory.close()
|
| 148 |
+
if self.is_creator:
|
| 149 |
+
self.shared_memory.unlink()
|
| 150 |
+
|
| 151 |
+
@contextmanager
|
| 152 |
+
def get_data(self, current_idx: int):
|
| 153 |
+
start = self.data_offset + current_idx * self.max_chunk_bytes
|
| 154 |
+
end = start + self.max_chunk_bytes
|
| 155 |
+
with memoryview(self.shared_memory.buf[start:end]) as buf:
|
| 156 |
+
yield buf
|
| 157 |
+
|
| 158 |
+
@contextmanager
|
| 159 |
+
def get_metadata(self, current_idx: int):
|
| 160 |
+
start = self.metadata_offset + current_idx * self.metadata_size
|
| 161 |
+
end = start + self.metadata_size
|
| 162 |
+
with memoryview(self.shared_memory.buf[start:end]) as buf:
|
| 163 |
+
yield buf
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
@dataclass
|
| 167 |
+
class Handle:
|
| 168 |
+
connect_ip: str
|
| 169 |
+
local_reader_ranks: List[int] = field(default_factory=list)
|
| 170 |
+
|
| 171 |
+
buffer_handle: Optional[Tuple[int, int, int, str]] = None
|
| 172 |
+
local_subscribe_port: Optional[int] = None
|
| 173 |
+
remote_subscribe_port: Optional[int] = None
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
class MessageQueue:
|
| 177 |
+
|
| 178 |
+
def __init__(
|
| 179 |
+
self,
|
| 180 |
+
n_reader, # number of all readers
|
| 181 |
+
n_local_reader, # number of local readers through shared memory
|
| 182 |
+
local_reader_ranks: Optional[List[int]] = None,
|
| 183 |
+
max_chunk_bytes: int = 1024 * 1024 * 10,
|
| 184 |
+
max_chunks: int = 10,
|
| 185 |
+
connect_ip: Optional[str] = None,
|
| 186 |
+
):
|
| 187 |
+
if local_reader_ranks is None:
|
| 188 |
+
local_reader_ranks = list(range(n_local_reader))
|
| 189 |
+
else:
|
| 190 |
+
assert len(local_reader_ranks) == n_local_reader
|
| 191 |
+
self.n_local_reader = n_local_reader
|
| 192 |
+
n_remote_reader = n_reader - n_local_reader
|
| 193 |
+
self.n_remote_reader = n_remote_reader
|
| 194 |
+
|
| 195 |
+
if connect_ip is None:
|
| 196 |
+
connect_ip = get_ip() if n_remote_reader > 0 else "127.0.0.1"
|
| 197 |
+
|
| 198 |
+
context = Context()
|
| 199 |
+
|
| 200 |
+
if n_local_reader > 0:
|
| 201 |
+
# for local readers, we will:
|
| 202 |
+
# 1. create a shared memory ring buffer to communicate small data
|
| 203 |
+
# 2. create a publish-subscribe socket to communicate large data
|
| 204 |
+
self.buffer = ShmRingBuffer(n_local_reader, max_chunk_bytes,
|
| 205 |
+
max_chunks)
|
| 206 |
+
|
| 207 |
+
# XPUB is very similar to PUB,
|
| 208 |
+
# except that it can receive subscription messages
|
| 209 |
+
# to confirm the number of subscribers
|
| 210 |
+
self.local_socket = context.socket(XPUB)
|
| 211 |
+
# set the verbose option so that we can receive every subscription
|
| 212 |
+
# message. otherwise, we will only receive the first subscription
|
| 213 |
+
# see http://api.zeromq.org/3-3:zmq-setsockopt for more details
|
| 214 |
+
self.local_socket.setsockopt(XPUB_VERBOSE, True)
|
| 215 |
+
local_subscribe_port = get_open_port()
|
| 216 |
+
socket_addr = f"tcp://127.0.0.1:{local_subscribe_port}"
|
| 217 |
+
logger.debug("Binding to %s", socket_addr)
|
| 218 |
+
self.local_socket.bind(socket_addr)
|
| 219 |
+
|
| 220 |
+
self.current_idx = 0
|
| 221 |
+
|
| 222 |
+
else:
|
| 223 |
+
self.buffer = None # type: ignore
|
| 224 |
+
local_subscribe_port = None
|
| 225 |
+
self.local_socket = None
|
| 226 |
+
self.current_idx = -1
|
| 227 |
+
|
| 228 |
+
if n_remote_reader > 0:
|
| 229 |
+
# for remote readers, we will:
|
| 230 |
+
# create a publish-subscribe socket to communicate large data
|
| 231 |
+
self.remote_socket = context.socket(XPUB)
|
| 232 |
+
self.remote_socket.setsockopt(XPUB_VERBOSE, True)
|
| 233 |
+
remote_subscribe_port = get_open_port()
|
| 234 |
+
if is_valid_ipv6_address(connect_ip):
|
| 235 |
+
self.remote_socket.setsockopt(IPV6, 1)
|
| 236 |
+
socket_addr = f"tcp://*:{remote_subscribe_port}"
|
| 237 |
+
self.remote_socket.bind(socket_addr)
|
| 238 |
+
|
| 239 |
+
else:
|
| 240 |
+
remote_subscribe_port = None
|
| 241 |
+
self.remote_socket = None
|
| 242 |
+
|
| 243 |
+
self._is_writer = True
|
| 244 |
+
self._is_local_reader = False
|
| 245 |
+
self.local_reader_rank = -1
|
| 246 |
+
# rank does not matter for remote readers
|
| 247 |
+
self._is_remote_reader = False
|
| 248 |
+
|
| 249 |
+
self.handle = Handle(
|
| 250 |
+
connect_ip=connect_ip,
|
| 251 |
+
local_reader_ranks=local_reader_ranks,
|
| 252 |
+
buffer_handle=self.buffer.handle()
|
| 253 |
+
if self.buffer is not None else None,
|
| 254 |
+
local_subscribe_port=local_subscribe_port,
|
| 255 |
+
remote_subscribe_port=remote_subscribe_port,
|
| 256 |
+
)
|
| 257 |
+
|
| 258 |
+
logger.info("vLLM message queue communication handle: %s", self.handle)
|
| 259 |
+
|
| 260 |
+
def export_handle(self) -> Handle:
|
| 261 |
+
return self.handle
|
| 262 |
+
|
| 263 |
+
@staticmethod
|
| 264 |
+
def create_from_handle(handle: Handle, rank) -> "MessageQueue":
|
| 265 |
+
self = MessageQueue.__new__(MessageQueue)
|
| 266 |
+
self.handle = handle
|
| 267 |
+
self._is_writer = False
|
| 268 |
+
|
| 269 |
+
context = Context()
|
| 270 |
+
|
| 271 |
+
if rank in handle.local_reader_ranks:
|
| 272 |
+
assert handle.buffer_handle is not None
|
| 273 |
+
self.buffer = ShmRingBuffer(*handle.buffer_handle)
|
| 274 |
+
self.current_idx = 0
|
| 275 |
+
self.local_reader_rank = handle.local_reader_ranks.index(rank)
|
| 276 |
+
self._is_local_reader = True
|
| 277 |
+
self._is_remote_reader = False
|
| 278 |
+
|
| 279 |
+
self.local_socket = context.socket(SUB)
|
| 280 |
+
self.local_socket.setsockopt_string(SUBSCRIBE, "")
|
| 281 |
+
socket_addr = f"tcp://127.0.0.1:{handle.local_subscribe_port}"
|
| 282 |
+
logger.debug("Connecting to %s", socket_addr)
|
| 283 |
+
self.local_socket.connect(socket_addr)
|
| 284 |
+
|
| 285 |
+
self.remote_socket = None
|
| 286 |
+
else:
|
| 287 |
+
self.buffer = None # type: ignore
|
| 288 |
+
self.current_idx = -1
|
| 289 |
+
self.local_reader_rank = -1
|
| 290 |
+
self._is_local_reader = False
|
| 291 |
+
self._is_remote_reader = True
|
| 292 |
+
|
| 293 |
+
self.local_socket = None
|
| 294 |
+
|
| 295 |
+
self.remote_socket = context.socket(SUB)
|
| 296 |
+
self.remote_socket.setsockopt_string(SUBSCRIBE, "")
|
| 297 |
+
if is_valid_ipv6_address(handle.connect_ip):
|
| 298 |
+
self.remote_socket.setsockopt(IPV6, 1)
|
| 299 |
+
socket_addr = f"tcp://{handle.connect_ip}:{handle.remote_subscribe_port}"
|
| 300 |
+
logger.debug("Connecting to %s", socket_addr)
|
| 301 |
+
self.remote_socket.connect(socket_addr)
|
| 302 |
+
|
| 303 |
+
return self
|
| 304 |
+
|
| 305 |
+
def wait_until_ready(self):
|
| 306 |
+
"""This is a collective operation. All processes (including the
|
| 307 |
+
readers and the writer) should call this function.
|
| 308 |
+
"""
|
| 309 |
+
if self._is_writer:
|
| 310 |
+
# wait for all readers to connect
|
| 311 |
+
|
| 312 |
+
# local readers
|
| 313 |
+
for i in range(self.n_local_reader):
|
| 314 |
+
# wait for subscription messages from all local readers
|
| 315 |
+
self.local_socket.recv()
|
| 316 |
+
if self.n_local_reader > 0:
|
| 317 |
+
# send a message to all local readers
|
| 318 |
+
# to make sure the publish channel is working
|
| 319 |
+
self.local_socket.send(b"READY")
|
| 320 |
+
|
| 321 |
+
# remote readers
|
| 322 |
+
for i in range(self.n_remote_reader):
|
| 323 |
+
# wait for subscription messages from all remote readers
|
| 324 |
+
self.remote_socket.recv()
|
| 325 |
+
if self.n_remote_reader > 0:
|
| 326 |
+
# send a message to all remote readers
|
| 327 |
+
# to make sure the publish channel is working
|
| 328 |
+
self.remote_socket.send(b"READY")
|
| 329 |
+
elif self._is_local_reader:
|
| 330 |
+
# wait for the writer to send a message
|
| 331 |
+
recv = self.local_socket.recv()
|
| 332 |
+
assert recv == b"READY"
|
| 333 |
+
elif self._is_remote_reader:
|
| 334 |
+
# wait for the writer to send a message
|
| 335 |
+
recv = self.remote_socket.recv()
|
| 336 |
+
assert recv == b"READY"
|
| 337 |
+
|
| 338 |
+
@contextmanager
|
| 339 |
+
def acquire_write(self, timeout: Optional[float] = None):
|
| 340 |
+
assert self._is_writer, "Only writers can acquire write"
|
| 341 |
+
start_time = time.monotonic()
|
| 342 |
+
n_warning = 1
|
| 343 |
+
while True:
|
| 344 |
+
with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
|
| 345 |
+
read_count = sum(metadata_buffer[1:])
|
| 346 |
+
written_flag = metadata_buffer[0]
|
| 347 |
+
if written_flag and read_count != self.buffer.n_reader:
|
| 348 |
+
# this block is written and not read by all readers
|
| 349 |
+
# for writers, `self.current_idx` is the next block to write
|
| 350 |
+
# if this block is not ready to write,
|
| 351 |
+
# we need to wait until it is read by all readers
|
| 352 |
+
|
| 353 |
+
# Release the processor to other threads
|
| 354 |
+
sched_yield()
|
| 355 |
+
|
| 356 |
+
# if we wait for a long time, log a message
|
| 357 |
+
if (time.monotonic() - start_time
|
| 358 |
+
> VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning):
|
| 359 |
+
logger.debug("No available block found in %s second. ",
|
| 360 |
+
VLLM_RINGBUFFER_WARNING_INTERVAL)
|
| 361 |
+
n_warning += 1
|
| 362 |
+
|
| 363 |
+
# if we time out, raise an exception
|
| 364 |
+
if (timeout is not None
|
| 365 |
+
and time.monotonic() - start_time > timeout):
|
| 366 |
+
raise TimeoutError
|
| 367 |
+
|
| 368 |
+
continue
|
| 369 |
+
# found a block that is either
|
| 370 |
+
# (1) not written
|
| 371 |
+
# (2) read by all readers
|
| 372 |
+
|
| 373 |
+
# mark the block as not written
|
| 374 |
+
metadata_buffer[0] = 0
|
| 375 |
+
# let caller write to the buffer
|
| 376 |
+
with self.buffer.get_data(self.current_idx) as buf:
|
| 377 |
+
yield buf
|
| 378 |
+
|
| 379 |
+
# caller has written to the buffer
|
| 380 |
+
# NOTE: order is important here
|
| 381 |
+
# first set the read flags to 0
|
| 382 |
+
# then set the written flag to 1
|
| 383 |
+
# otherwise, the readers may think they already read the block
|
| 384 |
+
for i in range(1, self.buffer.n_reader + 1):
|
| 385 |
+
# set read flag to 0, meaning it is not read yet
|
| 386 |
+
metadata_buffer[i] = 0
|
| 387 |
+
# mark the block as written
|
| 388 |
+
metadata_buffer[0] = 1
|
| 389 |
+
self.current_idx = (self.current_idx +
|
| 390 |
+
1) % self.buffer.max_chunks
|
| 391 |
+
break
|
| 392 |
+
|
| 393 |
+
@contextmanager
|
| 394 |
+
def acquire_read(self, timeout: Optional[float] = None):
|
| 395 |
+
assert self._is_local_reader, "Only readers can acquire read"
|
| 396 |
+
start_time = time.monotonic()
|
| 397 |
+
n_warning = 1
|
| 398 |
+
while True:
|
| 399 |
+
with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
|
| 400 |
+
read_flag = metadata_buffer[self.local_reader_rank + 1]
|
| 401 |
+
written_flag = metadata_buffer[0]
|
| 402 |
+
if not written_flag or read_flag:
|
| 403 |
+
# this block is either
|
| 404 |
+
# (1) not written
|
| 405 |
+
# (2) already read by this reader
|
| 406 |
+
|
| 407 |
+
# for readers, `self.current_idx` is the next block to read
|
| 408 |
+
# if this block is not ready,
|
| 409 |
+
# we need to wait until it is written
|
| 410 |
+
|
| 411 |
+
# Release the processor to other threads
|
| 412 |
+
sched_yield()
|
| 413 |
+
|
| 414 |
+
# if we wait for a long time, log a message
|
| 415 |
+
if (time.monotonic() - start_time
|
| 416 |
+
> VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning):
|
| 417 |
+
logger.debug("No available block found in %s second. ",
|
| 418 |
+
VLLM_RINGBUFFER_WARNING_INTERVAL)
|
| 419 |
+
n_warning += 1
|
| 420 |
+
|
| 421 |
+
# if we time out, raise an exception
|
| 422 |
+
if (timeout is not None
|
| 423 |
+
and time.monotonic() - start_time > timeout):
|
| 424 |
+
raise TimeoutError
|
| 425 |
+
|
| 426 |
+
continue
|
| 427 |
+
# found a block that is not read by this reader
|
| 428 |
+
# let caller read from the buffer
|
| 429 |
+
with self.buffer.get_data(self.current_idx) as buf:
|
| 430 |
+
yield buf
|
| 431 |
+
|
| 432 |
+
# caller has read from the buffer
|
| 433 |
+
# set the read flag
|
| 434 |
+
metadata_buffer[self.local_reader_rank + 1] = 1
|
| 435 |
+
self.current_idx = (self.current_idx +
|
| 436 |
+
1) % self.buffer.max_chunks
|
| 437 |
+
break
|
| 438 |
+
|
| 439 |
+
def enqueue(self, obj, timeout: Optional[float] = None):
|
| 440 |
+
""" Write to message queue with optional timeout (in seconds) """
|
| 441 |
+
assert self._is_writer, "Only writers can enqueue"
|
| 442 |
+
serialized_obj = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
|
| 443 |
+
if self.n_local_reader > 0:
|
| 444 |
+
if len(serialized_obj) >= self.buffer.max_chunk_bytes:
|
| 445 |
+
with self.acquire_write(timeout) as buf:
|
| 446 |
+
buf[0] = 1 # overflow
|
| 447 |
+
self.local_socket.send(serialized_obj)
|
| 448 |
+
else:
|
| 449 |
+
with self.acquire_write(timeout) as buf:
|
| 450 |
+
buf[0] = 0 # not overflow
|
| 451 |
+
buf[1:len(serialized_obj) + 1] = serialized_obj
|
| 452 |
+
if self.n_remote_reader > 0:
|
| 453 |
+
self.remote_socket.send(serialized_obj)
|
| 454 |
+
|
| 455 |
+
def dequeue(self, timeout: Optional[float] = None):
|
| 456 |
+
""" Read from message queue with optional timeout (in seconds) """
|
| 457 |
+
if self._is_local_reader:
|
| 458 |
+
with self.acquire_read(timeout) as buf:
|
| 459 |
+
overflow = buf[0] == 1
|
| 460 |
+
if not overflow:
|
| 461 |
+
# no need to know the size of serialized object
|
| 462 |
+
# pickle format contains the size information internally
|
| 463 |
+
# see https://docs.python.org/3/library/pickle.html
|
| 464 |
+
obj = pickle.loads(buf[1:])
|
| 465 |
+
if overflow:
|
| 466 |
+
recv = self.local_socket.recv()
|
| 467 |
+
obj = pickle.loads(recv)
|
| 468 |
+
elif self._is_remote_reader:
|
| 469 |
+
recv = self.remote_socket.recv()
|
| 470 |
+
obj = pickle.loads(recv)
|
| 471 |
+
else:
|
| 472 |
+
raise RuntimeError("Only readers can dequeue")
|
| 473 |
+
return obj
|
| 474 |
+
|
| 475 |
+
def broadcast_object(self, obj=None):
|
| 476 |
+
if self._is_writer:
|
| 477 |
+
self.enqueue(obj)
|
| 478 |
+
return obj
|
| 479 |
+
else:
|
| 480 |
+
return self.dequeue()
|
| 481 |
+
|
| 482 |
+
@staticmethod
|
| 483 |
+
def create_from_process_group(pg: Union[ProcessGroup,
|
| 484 |
+
StatelessProcessGroup],
|
| 485 |
+
max_chunk_bytes,
|
| 486 |
+
max_chunks,
|
| 487 |
+
writer_rank=0) -> "MessageQueue":
|
| 488 |
+
if isinstance(pg, ProcessGroup):
|
| 489 |
+
group_rank = dist.get_rank(pg)
|
| 490 |
+
group_world_size = dist.get_world_size(pg)
|
| 491 |
+
global_ranks = dist.get_process_group_ranks(pg)
|
| 492 |
+
else:
|
| 493 |
+
group_rank = pg.rank
|
| 494 |
+
group_world_size = pg.world_size
|
| 495 |
+
global_ranks = list(range(pg.world_size))
|
| 496 |
+
|
| 497 |
+
from vllm.distributed.parallel_state import in_the_same_node_as
|
| 498 |
+
status = in_the_same_node_as(pg, source_rank=writer_rank)
|
| 499 |
+
same_node_ranks = [i for i, s in enumerate(status) if s]
|
| 500 |
+
n_reader = group_world_size - 1
|
| 501 |
+
n_local_reader = len(same_node_ranks) - 1
|
| 502 |
+
local_reader_ranks = [i for i in same_node_ranks if i != writer_rank]
|
| 503 |
+
buffer_io: MessageQueue
|
| 504 |
+
if group_rank == writer_rank:
|
| 505 |
+
buffer_io = MessageQueue(
|
| 506 |
+
n_reader=n_reader,
|
| 507 |
+
n_local_reader=n_local_reader,
|
| 508 |
+
local_reader_ranks=local_reader_ranks,
|
| 509 |
+
max_chunk_bytes=max_chunk_bytes,
|
| 510 |
+
max_chunks=max_chunks,
|
| 511 |
+
)
|
| 512 |
+
handle = buffer_io.export_handle()
|
| 513 |
+
if isinstance(pg, ProcessGroup):
|
| 514 |
+
dist.broadcast_object_list([handle],
|
| 515 |
+
src=global_ranks[writer_rank],
|
| 516 |
+
group=pg)
|
| 517 |
+
else:
|
| 518 |
+
pg.broadcast_obj(handle, writer_rank)
|
| 519 |
+
else:
|
| 520 |
+
if isinstance(pg, ProcessGroup):
|
| 521 |
+
recv = [None]
|
| 522 |
+
dist.broadcast_object_list(recv,
|
| 523 |
+
src=global_ranks[writer_rank],
|
| 524 |
+
group=pg)
|
| 525 |
+
handle = recv[0] # type: ignore
|
| 526 |
+
else:
|
| 527 |
+
handle = pg.broadcast_obj(None, writer_rank)
|
| 528 |
+
buffer_io = MessageQueue.create_from_handle(handle, group_rank)
|
| 529 |
+
buffer_io.wait_until_ready()
|
| 530 |
+
return buffer_io
|
.venv/lib/python3.11/site-packages/vllm/distributed/device_communicators/xpu_communicator.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.distributed as dist
|
| 5 |
+
from torch.distributed import ProcessGroup
|
| 6 |
+
|
| 7 |
+
from vllm.platforms import current_platform
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class XpuCommunicator:
|
| 11 |
+
|
| 12 |
+
def __init__(self, group: ProcessGroup):
|
| 13 |
+
if not current_platform.is_xpu():
|
| 14 |
+
self.disabled = True
|
| 15 |
+
return
|
| 16 |
+
self.disabled = False
|
| 17 |
+
self.group = group
|
| 18 |
+
self.world_size = dist.get_world_size(self.group)
|
| 19 |
+
|
| 20 |
+
def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
|
| 21 |
+
dist.all_reduce(x, group=self.group)
|
| 22 |
+
return x
|
| 23 |
+
|
| 24 |
+
def gather(self,
|
| 25 |
+
input_: torch.Tensor,
|
| 26 |
+
rank_in_group: int,
|
| 27 |
+
dst: int = 0,
|
| 28 |
+
dim: int = -1):
|
| 29 |
+
# For xpu path, gather doesn't work properly together with ray
|
| 30 |
+
# cluster so we use all_gather instead for now.
|
| 31 |
+
input_size = input_.size()
|
| 32 |
+
# Allocate output tensor.
|
| 33 |
+
output_tensor = torch.empty((self.world_size, ) + input_size,
|
| 34 |
+
dtype=input_.dtype,
|
| 35 |
+
device=input_.device)
|
| 36 |
+
# All-gather.
|
| 37 |
+
torch.distributed.all_gather_into_tensor(output_tensor,
|
| 38 |
+
input_,
|
| 39 |
+
group=self.group)
|
| 40 |
+
if rank_in_group == dst:
|
| 41 |
+
# Reshape
|
| 42 |
+
output_tensor = output_tensor.movedim(0, dim)
|
| 43 |
+
output_tensor = output_tensor.reshape(input_size[:dim] +
|
| 44 |
+
(self.world_size *
|
| 45 |
+
input_size[dim], ) +
|
| 46 |
+
input_size[dim + 1:])
|
| 47 |
+
else:
|
| 48 |
+
output_tensor = None
|
| 49 |
+
return output_tensor
|
.venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (201 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/__pycache__/kv_transfer_agent.cpython-311.pyc
ADDED
|
Binary file (3.6 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/kv_connector/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/kv_connector/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (214 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/kv_connector/__pycache__/base.cpython-311.pyc
ADDED
|
Binary file (5.36 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/kv_connector/__pycache__/factory.cpython-311.pyc
ADDED
|
Binary file (2.78 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/kv_connector/__pycache__/simple_connector.cpython-311.pyc
ADDED
|
Binary file (12.6 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/kv_connector/base.py
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
"""
|
| 3 |
+
KVConnectorBase Class for Distributed KV Cache & Hidden State communication
|
| 4 |
+
|
| 5 |
+
The class provides two primary abstract methods:
|
| 6 |
+
1. send_kv_caches_and_hidden_states(): Send KV caches and hidden states
|
| 7 |
+
2. recv_kv_caches_and_hidden_states(): Recv KV caches and hidden states
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from abc import ABC, abstractmethod
|
| 11 |
+
from typing import TYPE_CHECKING, List, Tuple, Union
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
|
| 15 |
+
from vllm.sequence import IntermediateTensors
|
| 16 |
+
|
| 17 |
+
if TYPE_CHECKING:
|
| 18 |
+
from vllm.config import VllmConfig
|
| 19 |
+
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class KVConnectorBase(ABC):
|
| 23 |
+
"""
|
| 24 |
+
Abstract base class for a KV connector.
|
| 25 |
+
|
| 26 |
+
The class provides two primary abstract methods:
|
| 27 |
+
1. send_kv_caches_and_hidden_states(): Send KV caches and hidden states
|
| 28 |
+
2. recv_kv_caches_and_hidden_states(): Recv KV caches and hidden states
|
| 29 |
+
"""
|
| 30 |
+
|
| 31 |
+
@abstractmethod
|
| 32 |
+
def __init__(
|
| 33 |
+
self,
|
| 34 |
+
rank: int,
|
| 35 |
+
local_rank: int,
|
| 36 |
+
config: "VllmConfig",
|
| 37 |
+
):
|
| 38 |
+
raise NotImplementedError
|
| 39 |
+
|
| 40 |
+
@abstractmethod
|
| 41 |
+
def close(self) -> None:
|
| 42 |
+
"""Close the buffer and release resources.
|
| 43 |
+
|
| 44 |
+
This method is responsible for cleaning up resources related to the
|
| 45 |
+
connector when it is no longer needed.
|
| 46 |
+
|
| 47 |
+
Raises:
|
| 48 |
+
NotImplementedError: This method must be implemented in subclasses.
|
| 49 |
+
"""
|
| 50 |
+
raise NotImplementedError
|
| 51 |
+
|
| 52 |
+
@abstractmethod
|
| 53 |
+
def send_kv_caches_and_hidden_states(
|
| 54 |
+
self,
|
| 55 |
+
model_executable: torch.nn.Module,
|
| 56 |
+
model_input: "ModelInputForGPUWithSamplingMetadata",
|
| 57 |
+
kv_caches: List[torch.Tensor],
|
| 58 |
+
hidden_or_intermediate_states: Union[torch.Tensor,
|
| 59 |
+
IntermediateTensors],
|
| 60 |
+
) -> None:
|
| 61 |
+
"""
|
| 62 |
+
Send KV caches and hidden states to the connector.
|
| 63 |
+
|
| 64 |
+
This method processes the input tokens, KV caches, and
|
| 65 |
+
hidden/intermediate states for a given model and sends the data to the
|
| 66 |
+
decode instance.
|
| 67 |
+
|
| 68 |
+
Args:
|
| 69 |
+
model_executable (torch.nn.Module): The model executable containing
|
| 70 |
+
start and end layer information.
|
| 71 |
+
model_input (ModelInputForGPUWithSamplingMetadata): The input
|
| 72 |
+
metadata from vLLM.
|
| 73 |
+
kv_caches (List[torch.Tensor]): List of KV caches (keys and values)
|
| 74 |
+
for each layer.
|
| 75 |
+
hidden_or_intermediate_states (Union[torch.Tensor,
|
| 76 |
+
IntermediateTensors]):
|
| 77 |
+
The hidden or intermediate states associated with the tokens.
|
| 78 |
+
|
| 79 |
+
Returns:
|
| 80 |
+
None
|
| 81 |
+
|
| 82 |
+
"""
|
| 83 |
+
|
| 84 |
+
raise NotImplementedError
|
| 85 |
+
|
| 86 |
+
@abstractmethod
|
| 87 |
+
def recv_kv_caches_and_hidden_states(
|
| 88 |
+
self, model_executable: torch.nn.Module,
|
| 89 |
+
model_input: "ModelInputForGPUWithSamplingMetadata",
|
| 90 |
+
kv_caches: List[torch.Tensor]
|
| 91 |
+
) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool,
|
| 92 |
+
"ModelInputForGPUWithSamplingMetadata"]:
|
| 93 |
+
"""
|
| 94 |
+
Receive KV caches and hidden states from the connector.
|
| 95 |
+
|
| 96 |
+
This method attempts to retrieve KV caches and hidden states for input
|
| 97 |
+
tokens. If all required KV caches and hidden states are received, it
|
| 98 |
+
will bypass model input, else it will fall back to normal vLLM model
|
| 99 |
+
forwarding.
|
| 100 |
+
|
| 101 |
+
Args:
|
| 102 |
+
model_executable (torch.nn.Module):
|
| 103 |
+
The model executable from vLLM modelrunner.
|
| 104 |
+
model_input (ModelInputForGPUWithSamplingMetadata):
|
| 105 |
+
The model input from vLLM modelrunner.
|
| 106 |
+
kv_caches (List[torch.Tensor]):
|
| 107 |
+
List of KV caches for each layer.
|
| 108 |
+
|
| 109 |
+
Returns:
|
| 110 |
+
- hidden_or_intermediate_states (torch.Tensor or
|
| 111 |
+
IntermediateTensors):
|
| 112 |
+
Concatenated hidden states if all required data is retrieved,
|
| 113 |
+
otherwise `None`.
|
| 114 |
+
- bypass_model_exec (bool):
|
| 115 |
+
Indicates whether the model execution can be skipped (True) or
|
| 116 |
+
needs to be redone (False).
|
| 117 |
+
- model_input (ModelInputForGPUWithSamplingMetadata):
|
| 118 |
+
Optionally adjusted input metadata for re-execution when
|
| 119 |
+
`bypass_model_exec=False`.
|
| 120 |
+
|
| 121 |
+
"""
|
| 122 |
+
|
| 123 |
+
raise NotImplementedError
|
.venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/kv_connector/factory.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
import importlib
|
| 4 |
+
from typing import TYPE_CHECKING, Callable, Dict, Type
|
| 5 |
+
|
| 6 |
+
from .base import KVConnectorBase
|
| 7 |
+
|
| 8 |
+
if TYPE_CHECKING:
|
| 9 |
+
from vllm.config import VllmConfig
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class KVConnectorFactory:
|
| 13 |
+
_registry: Dict[str, Callable[[], Type[KVConnectorBase]]] = {}
|
| 14 |
+
|
| 15 |
+
@classmethod
|
| 16 |
+
def register_connector(cls, name: str, module_path: str,
|
| 17 |
+
class_name: str) -> None:
|
| 18 |
+
"""Register a connector with a lazy-loading module and class name."""
|
| 19 |
+
if name in cls._registry:
|
| 20 |
+
raise ValueError(f"Connector '{name}' is already registered.")
|
| 21 |
+
|
| 22 |
+
def loader() -> Type[KVConnectorBase]:
|
| 23 |
+
module = importlib.import_module(module_path)
|
| 24 |
+
return getattr(module, class_name)
|
| 25 |
+
|
| 26 |
+
cls._registry[name] = loader
|
| 27 |
+
|
| 28 |
+
@classmethod
|
| 29 |
+
def create_connector(cls, rank: int, local_rank: int,
|
| 30 |
+
config: "VllmConfig") -> KVConnectorBase:
|
| 31 |
+
connector_name = config.kv_transfer_config.kv_connector
|
| 32 |
+
if connector_name not in cls._registry:
|
| 33 |
+
raise ValueError(f"Unsupported connector type: {connector_name}")
|
| 34 |
+
|
| 35 |
+
connector_cls = cls._registry[connector_name]()
|
| 36 |
+
return connector_cls(rank, local_rank, config)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
# Register various connectors here.
|
| 40 |
+
# The registration should not be done in each individual file, as we want to
|
| 41 |
+
# only load the files corresponding to the current connector.
|
| 42 |
+
KVConnectorFactory.register_connector(
|
| 43 |
+
"PyNcclConnector",
|
| 44 |
+
"vllm.distributed.kv_transfer.kv_connector.simple_connector",
|
| 45 |
+
"SimpleConnector")
|
| 46 |
+
|
| 47 |
+
KVConnectorFactory.register_connector(
|
| 48 |
+
"MooncakeConnector",
|
| 49 |
+
"vllm.distributed.kv_transfer.kv_connector.simple_connector",
|
| 50 |
+
"SimpleConnector")
|
.venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/kv_connector/simple_connector.py
ADDED
|
@@ -0,0 +1,314 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
"""
|
| 3 |
+
Simple KV Cache Connector for Distributed Machine Learning Inference
|
| 4 |
+
|
| 5 |
+
The SimpleConnector transfers KV caches between prefill vLLM worker (KV cache
|
| 6 |
+
producer) and decode vLLM worker (KV cache consumer) using PyNcclPipe or
|
| 7 |
+
MooncakePipe.
|
| 8 |
+
|
| 9 |
+
But the logic can be extended to support other pipe and lookup buffer.
|
| 10 |
+
"""
|
| 11 |
+
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
|
| 15 |
+
from vllm import _custom_ops as ops
|
| 16 |
+
from vllm.config import VllmConfig
|
| 17 |
+
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
|
| 18 |
+
from vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer import (
|
| 19 |
+
SimpleBuffer)
|
| 20 |
+
from vllm.logger import init_logger
|
| 21 |
+
from vllm.sequence import IntermediateTensors
|
| 22 |
+
|
| 23 |
+
if TYPE_CHECKING:
|
| 24 |
+
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
|
| 25 |
+
|
| 26 |
+
logger = init_logger(__name__)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class SimpleConnector(KVConnectorBase):
|
| 30 |
+
|
| 31 |
+
def __init__(
|
| 32 |
+
self,
|
| 33 |
+
rank: int,
|
| 34 |
+
local_rank: int,
|
| 35 |
+
config: VllmConfig,
|
| 36 |
+
):
|
| 37 |
+
|
| 38 |
+
self.config = config.kv_transfer_config
|
| 39 |
+
self.tp_size = config.parallel_config.tensor_parallel_size
|
| 40 |
+
|
| 41 |
+
if self.config.kv_connector == "PyNcclConnector":
|
| 42 |
+
from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import (
|
| 43 |
+
PyNcclPipe)
|
| 44 |
+
logger.info(
|
| 45 |
+
"Initializing PyNcclConfig under kv_transfer_config %s",
|
| 46 |
+
self.config)
|
| 47 |
+
elif self.config.kv_connector == "MooncakeConnector":
|
| 48 |
+
# Check if MOONCAKE_CONFIG_PATH is set
|
| 49 |
+
import os
|
| 50 |
+
use_mooncake_distributed_pipe = os.getenv(
|
| 51 |
+
'MOONCAKE_CONFIG_PATH') is not None
|
| 52 |
+
|
| 53 |
+
if not use_mooncake_distributed_pipe:
|
| 54 |
+
raise ValueError(
|
| 55 |
+
"To use MooncakeConnector, you need to pass the ENV: "
|
| 56 |
+
"'MOONCAKE_CONFIG_PATH=/path/to/mooncake_config.json'.")
|
| 57 |
+
else:
|
| 58 |
+
from vllm.distributed.kv_transfer.kv_pipe.mooncake_pipe import ( # noqa: E501
|
| 59 |
+
MooncakePipe)
|
| 60 |
+
logger.info(
|
| 61 |
+
"Initializing MooncakeConfig under kv_transfer_config %s",
|
| 62 |
+
self.config)
|
| 63 |
+
|
| 64 |
+
self.lookup_buffer_size = self.config.kv_buffer_size
|
| 65 |
+
|
| 66 |
+
self.producer_buffer: Optional[SimpleBuffer] = None
|
| 67 |
+
self.consumer_buffer: Optional[SimpleBuffer] = None
|
| 68 |
+
|
| 69 |
+
self.producer_data_pipe: Union[PyNcclPipe, MooncakePipe]
|
| 70 |
+
self.consumer_data_pipe: Union[PyNcclPipe, MooncakePipe]
|
| 71 |
+
self.producer_signal_pipe: Union[PyNcclPipe, MooncakePipe]
|
| 72 |
+
self.consumer_signal_pipe: Union[PyNcclPipe, MooncakePipe]
|
| 73 |
+
|
| 74 |
+
# 2 pipes for every rank in the world
|
| 75 |
+
port_offset_base = 2 * rank
|
| 76 |
+
|
| 77 |
+
# In disaggregated prefill, the prefill vLLM only uses send pipe
|
| 78 |
+
# and the decode vLLM only uses recv pipe
|
| 79 |
+
if self.config.is_kv_producer:
|
| 80 |
+
|
| 81 |
+
if self.config.kv_connector == "PyNcclConnector":
|
| 82 |
+
self.producer_data_pipe = PyNcclPipe(
|
| 83 |
+
local_rank=local_rank,
|
| 84 |
+
config=self.config,
|
| 85 |
+
port_offset=port_offset_base,
|
| 86 |
+
)
|
| 87 |
+
self.producer_signal_pipe = PyNcclPipe(
|
| 88 |
+
local_rank=local_rank,
|
| 89 |
+
config=self.config,
|
| 90 |
+
port_offset=port_offset_base + 1,
|
| 91 |
+
device="cpu",
|
| 92 |
+
)
|
| 93 |
+
elif self.config.kv_connector == "MooncakeConnector":
|
| 94 |
+
self.producer_data_pipe = MooncakePipe(
|
| 95 |
+
local_rank=local_rank,
|
| 96 |
+
config=self.config,
|
| 97 |
+
)
|
| 98 |
+
# We only need to initialize MooncakePipe once
|
| 99 |
+
self.producer_signal_pipe = self.producer_data_pipe
|
| 100 |
+
|
| 101 |
+
self.producer_buffer = SimpleBuffer(self.producer_signal_pipe,
|
| 102 |
+
self.producer_data_pipe,
|
| 103 |
+
self.config.kv_buffer_size)
|
| 104 |
+
|
| 105 |
+
else:
|
| 106 |
+
|
| 107 |
+
# the current vLLM instance is KV consumer, so it needs to connect
|
| 108 |
+
# its recv pipe to the send pipe of KV producder
|
| 109 |
+
if self.config.kv_connector == "PyNcclConnector":
|
| 110 |
+
self.consumer_data_pipe = PyNcclPipe(
|
| 111 |
+
local_rank=local_rank,
|
| 112 |
+
config=self.config,
|
| 113 |
+
port_offset=port_offset_base,
|
| 114 |
+
)
|
| 115 |
+
self.consumer_signal_pipe = PyNcclPipe(
|
| 116 |
+
local_rank=local_rank,
|
| 117 |
+
config=self.config,
|
| 118 |
+
port_offset=port_offset_base + 1,
|
| 119 |
+
device="cpu",
|
| 120 |
+
)
|
| 121 |
+
elif self.config.kv_connector == "MooncakeConnector":
|
| 122 |
+
self.consumer_data_pipe = MooncakePipe(
|
| 123 |
+
local_rank=local_rank,
|
| 124 |
+
config=self.config,
|
| 125 |
+
)
|
| 126 |
+
self.consumer_signal_pipe = self.consumer_data_pipe
|
| 127 |
+
|
| 128 |
+
self.consumer_buffer = SimpleBuffer(
|
| 129 |
+
self.consumer_signal_pipe,
|
| 130 |
+
self.consumer_data_pipe,
|
| 131 |
+
self.config.kv_buffer_size,
|
| 132 |
+
)
|
| 133 |
+
|
| 134 |
+
def select(self, input_tokens: Optional[torch.Tensor],
|
| 135 |
+
roi: Optional[torch.Tensor]) -> List[Optional[torch.Tensor]]:
|
| 136 |
+
|
| 137 |
+
assert self.consumer_buffer is not None, "Please initialize the "\
|
| 138 |
+
"consumer buffer before calling select."
|
| 139 |
+
return self.consumer_buffer.drop_select(input_tokens, roi)
|
| 140 |
+
|
| 141 |
+
def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor,
|
| 142 |
+
key: torch.Tensor, value: torch.Tensor,
|
| 143 |
+
hidden: torch.Tensor) -> None:
|
| 144 |
+
|
| 145 |
+
assert self.producer_buffer is not None, "Please initialize the "\
|
| 146 |
+
"producer buffer before calling insert."
|
| 147 |
+
|
| 148 |
+
self.producer_buffer.insert(input_tokens, roi, key, value, hidden)
|
| 149 |
+
|
| 150 |
+
def send_kv_caches_and_hidden_states(
|
| 151 |
+
self,
|
| 152 |
+
model_executable: torch.nn.Module,
|
| 153 |
+
model_input: "ModelInputForGPUWithSamplingMetadata",
|
| 154 |
+
kv_caches: List[torch.Tensor],
|
| 155 |
+
hidden_or_intermediate_states: Union[torch.Tensor,
|
| 156 |
+
IntermediateTensors],
|
| 157 |
+
) -> None:
|
| 158 |
+
|
| 159 |
+
input_tokens_tensor = model_input.input_tokens
|
| 160 |
+
seq_lens = model_input.attn_metadata.seq_lens
|
| 161 |
+
slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten()
|
| 162 |
+
start_layer = model_executable.model.start_layer
|
| 163 |
+
end_layer = model_executable.model.end_layer
|
| 164 |
+
|
| 165 |
+
model_config = model_executable.model.config
|
| 166 |
+
num_heads = int(model_config.num_key_value_heads / self.tp_size)
|
| 167 |
+
hidden_size = model_config.hidden_size
|
| 168 |
+
num_attention_heads = model_config.num_attention_heads
|
| 169 |
+
head_size = int(hidden_size / num_attention_heads)
|
| 170 |
+
|
| 171 |
+
# query_lens contains new KV caches that are added to vLLM.
|
| 172 |
+
# so we will send them to decode instance
|
| 173 |
+
# FIXME(Kuntai): This assume that all requests are prefill.
|
| 174 |
+
for idx, slen in enumerate(seq_lens):
|
| 175 |
+
start_pos = sum(seq_lens[:idx])
|
| 176 |
+
end_pos = start_pos + slen
|
| 177 |
+
current_tokens = input_tokens_tensor[start_pos:end_pos]
|
| 178 |
+
|
| 179 |
+
keys, values = [], []
|
| 180 |
+
|
| 181 |
+
for layer_id in range(start_layer, end_layer):
|
| 182 |
+
kv_cache = kv_caches[layer_id - start_layer]
|
| 183 |
+
|
| 184 |
+
key_cache = kv_cache[0].reshape(-1, num_heads, head_size)
|
| 185 |
+
value_cache = kv_cache[1].reshape(-1, num_heads, head_size)
|
| 186 |
+
|
| 187 |
+
current_slot_mapping = slot_mapping_flat[start_pos:end_pos]
|
| 188 |
+
|
| 189 |
+
keys.append(key_cache[current_slot_mapping].unsqueeze(0))
|
| 190 |
+
values.append(value_cache[current_slot_mapping].unsqueeze(0))
|
| 191 |
+
|
| 192 |
+
keys = torch.cat(keys, dim=0)
|
| 193 |
+
values = torch.cat(values, dim=0)
|
| 194 |
+
|
| 195 |
+
self.insert(current_tokens,
|
| 196 |
+
torch.ones_like(current_tokens,
|
| 197 |
+
dtype=bool), keys, values,
|
| 198 |
+
hidden_or_intermediate_states[start_pos:end_pos])
|
| 199 |
+
|
| 200 |
+
logger.debug("[rank%d]: KV send DONE.", torch.distributed.get_rank())
|
| 201 |
+
|
| 202 |
+
def recv_kv_caches_and_hidden_states(
|
| 203 |
+
self, model_executable: torch.nn.Module,
|
| 204 |
+
model_input: "ModelInputForGPUWithSamplingMetadata",
|
| 205 |
+
kv_caches: List[torch.Tensor]
|
| 206 |
+
) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool,
|
| 207 |
+
"ModelInputForGPUWithSamplingMetadata"]:
|
| 208 |
+
|
| 209 |
+
# When bypass_model_exec is set to False, it means that at least for one
|
| 210 |
+
# request its corresponding KV cache or hidden state is missing.
|
| 211 |
+
# In this case we need to do prefilling to recompute missing KV cache
|
| 212 |
+
# and hidden states.
|
| 213 |
+
bypass_model_exec = True
|
| 214 |
+
|
| 215 |
+
input_tokens_tensor = model_input.input_tokens
|
| 216 |
+
seq_lens = model_input.attn_metadata.seq_lens
|
| 217 |
+
slot_mapping = model_input.attn_metadata.slot_mapping.flatten()
|
| 218 |
+
|
| 219 |
+
hidden_or_intermediate_states_for_one_req = []
|
| 220 |
+
|
| 221 |
+
input_tokens_list = []
|
| 222 |
+
num_computed_tokens_list = []
|
| 223 |
+
start_pos_list = []
|
| 224 |
+
|
| 225 |
+
# enumerate different requests
|
| 226 |
+
# FIXME(Kuntai): This impl assumes that all requests are prefill.
|
| 227 |
+
for idx, slen in enumerate(seq_lens):
|
| 228 |
+
|
| 229 |
+
start_pos = sum(seq_lens[:idx])
|
| 230 |
+
end_pos = start_pos + slen
|
| 231 |
+
current_tokens = input_tokens_tensor[start_pos:end_pos]
|
| 232 |
+
num_tokens = slen
|
| 233 |
+
|
| 234 |
+
# collecting data for rebuilding the input
|
| 235 |
+
input_tokens_list.append(current_tokens)
|
| 236 |
+
start_pos_list.append(start_pos)
|
| 237 |
+
|
| 238 |
+
ret = self.select(current_tokens,
|
| 239 |
+
torch.ones_like(current_tokens, dtype=bool))
|
| 240 |
+
if ret[0] is None:
|
| 241 |
+
# didn't find any match.
|
| 242 |
+
bypass_model_exec = False
|
| 243 |
+
num_computed_tokens_list.append(0)
|
| 244 |
+
continue
|
| 245 |
+
|
| 246 |
+
roi: torch.Tensor = ret[1]
|
| 247 |
+
keys: torch.Tensor = ret[2]
|
| 248 |
+
values: torch.Tensor = ret[3]
|
| 249 |
+
hidden: torch.Tensor = ret[4]
|
| 250 |
+
|
| 251 |
+
num_computed_tokens = roi.shape[0]
|
| 252 |
+
num_computed_tokens_list.append(num_computed_tokens)
|
| 253 |
+
|
| 254 |
+
# check if both KV cache and the hidden states are received
|
| 255 |
+
# If not, need to redo the forwarding to compute missing states
|
| 256 |
+
if not all([(num_computed_tokens == num_tokens), hidden is not None
|
| 257 |
+
]):
|
| 258 |
+
bypass_model_exec = False
|
| 259 |
+
|
| 260 |
+
# update the end position based on how many tokens are cached.
|
| 261 |
+
end_pos = start_pos + num_computed_tokens
|
| 262 |
+
|
| 263 |
+
# put received KV caches into paged memory
|
| 264 |
+
for i in range(model_executable.model.start_layer,
|
| 265 |
+
model_executable.model.end_layer):
|
| 266 |
+
|
| 267 |
+
kv_cache = kv_caches[i - model_executable.model.start_layer]
|
| 268 |
+
layer = model_executable.model.layers[i]
|
| 269 |
+
|
| 270 |
+
key_cache, value_cache = kv_cache[0], kv_cache[1]
|
| 271 |
+
ops.reshape_and_cache_flash(
|
| 272 |
+
keys[i - model_executable.model.start_layer].to(
|
| 273 |
+
key_cache.device),
|
| 274 |
+
values[i - model_executable.model.start_layer].to(
|
| 275 |
+
value_cache.device),
|
| 276 |
+
key_cache,
|
| 277 |
+
value_cache,
|
| 278 |
+
slot_mapping[start_pos:end_pos],
|
| 279 |
+
layer.self_attn.attn.kv_cache_dtype,
|
| 280 |
+
layer.self_attn.attn._k_scale,
|
| 281 |
+
layer.self_attn.attn._v_scale,
|
| 282 |
+
)
|
| 283 |
+
|
| 284 |
+
hidden_or_intermediate_states_for_one_req.append(hidden)
|
| 285 |
+
|
| 286 |
+
if not bypass_model_exec:
|
| 287 |
+
# Some of the KV cache is not retrieved
|
| 288 |
+
# Here we will fall back to normal model forwarding
|
| 289 |
+
# But optionally you can adjust model_input so that you only do
|
| 290 |
+
# prefilling on those tokens that are missing KV caches.
|
| 291 |
+
logger.debug(
|
| 292 |
+
"[rank%d]: Failed to receive all KVs and hidden "
|
| 293 |
+
"states, redo model forwarding.", torch.distributed.get_rank())
|
| 294 |
+
hidden_or_intermediate_states = None
|
| 295 |
+
|
| 296 |
+
else:
|
| 297 |
+
logger.debug(
|
| 298 |
+
"[rank%d]: Successfully received all KVs and hidden "
|
| 299 |
+
"states, skip model forwarding.", torch.distributed.get_rank())
|
| 300 |
+
hidden_or_intermediate_states = torch.cat(
|
| 301 |
+
hidden_or_intermediate_states_for_one_req, dim=0)
|
| 302 |
+
|
| 303 |
+
return hidden_or_intermediate_states, bypass_model_exec, model_input
|
| 304 |
+
|
| 305 |
+
def close(self):
|
| 306 |
+
self.producer_data_pipe.close()
|
| 307 |
+
self.consumer_data_pipe.close()
|
| 308 |
+
if self.config.kv_connector == "PyNcclConnector":
|
| 309 |
+
self.producer_signal_pipe.close()
|
| 310 |
+
self.consumer_signal_pipe.close()
|
| 311 |
+
elif self.config.kv_connector == "MooncakeConnector":
|
| 312 |
+
# MooncakePipe reuses data_pipe for signal_pipe, so we only have to
|
| 313 |
+
# close the data_pipe.
|
| 314 |
+
pass
|
.venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/kv_lookup_buffer/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/kv_lookup_buffer/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (218 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/kv_lookup_buffer/__pycache__/base.cpython-311.pyc
ADDED
|
Binary file (5.13 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/kv_lookup_buffer/__pycache__/simple_buffer.cpython-311.pyc
ADDED
|
Binary file (11.6 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/kv_lookup_buffer/base.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
"""
|
| 3 |
+
This file contains a new class `KVLookupBufferBase` that allows developers to
|
| 4 |
+
think of KV cache operations as inserting new KV cache entries (`insert`)
|
| 5 |
+
into the lookup buffer and querying existing KV caches (`drop_select`)
|
| 6 |
+
from the lookup buffer.
|
| 7 |
+
|
| 8 |
+
All distributed communications are abstracted behind this class.
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
from abc import ABC, abstractmethod
|
| 12 |
+
from typing import List, Optional
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class KVLookupBufferBase(ABC):
|
| 18 |
+
"""
|
| 19 |
+
Abstract base class for a lookup buffer.
|
| 20 |
+
|
| 21 |
+
This class provides an abstraction for a key-value (KV) cache lookup buffer.
|
| 22 |
+
|
| 23 |
+
The key of the lookup buffer:
|
| 24 |
+
- input_tokens: token IDs of the request
|
| 25 |
+
- roi: a binary mask on top of input_tokens.
|
| 26 |
+
- Purpose of roi: Since KV cache may only be available for a subset of
|
| 27 |
+
tokens in the input (for example, when vLLM is connected to an external
|
| 28 |
+
KV cache service), roi specifies the subset of tokens that the KV cache
|
| 29 |
+
is associated with.
|
| 30 |
+
- NOTE: roi can be further extended to describe which part of KV the
|
| 31 |
+
current process is holding (each process may only hold a part of KV
|
| 32 |
+
due to TP and PP). This is not implemented for now.
|
| 33 |
+
|
| 34 |
+
The value of the lookup buffer:
|
| 35 |
+
- key: the key tensor in the KV cache
|
| 36 |
+
- value: the value tensor in the KV cache
|
| 37 |
+
- hidden: the final hidden state generated by model forwarding. This allows
|
| 38 |
+
vLLM to bypass further model forwarding by transmitting the hidden state.
|
| 39 |
+
"""
|
| 40 |
+
|
| 41 |
+
@abstractmethod
|
| 42 |
+
def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor,
|
| 43 |
+
key: torch.Tensor, value: torch.Tensor,
|
| 44 |
+
hidden: torch.Tensor) -> None:
|
| 45 |
+
"""Insert into the lookup buffer.
|
| 46 |
+
|
| 47 |
+
The functionality is similar to the following python statement
|
| 48 |
+
```
|
| 49 |
+
buffer[input_tokens, roi] = [key, value, hidden]
|
| 50 |
+
```
|
| 51 |
+
|
| 52 |
+
FIXME: in the future, we should only have two arguments, key and value,
|
| 53 |
+
where key is a tensor dict and value is a tensor dict.
|
| 54 |
+
|
| 55 |
+
FIXME: we should transmit both sampler outputs and the hidden states.
|
| 56 |
+
|
| 57 |
+
Args:
|
| 58 |
+
input_tokens (torch.Tensor): token IDs.
|
| 59 |
+
roi (torch.Tensor): A binary mask on top of the input tokens
|
| 60 |
+
key (torch.Tensor): The key tensor in the KV cache.
|
| 61 |
+
value (torch.Tensor): The value tensor in the KV cache.
|
| 62 |
+
hidden (torch.Tensor): The final hidden state tensor generated
|
| 63 |
+
during model forwarding to bypass model
|
| 64 |
+
forwarding.
|
| 65 |
+
|
| 66 |
+
Raises:
|
| 67 |
+
NotImplementedError: This method must be implemented in subclasses.
|
| 68 |
+
"""
|
| 69 |
+
raise NotImplementedError
|
| 70 |
+
|
| 71 |
+
@abstractmethod
|
| 72 |
+
def drop_select(
|
| 73 |
+
self, input_tokens: Optional[torch.Tensor],
|
| 74 |
+
roi: Optional[torch.Tensor]) -> List[Optional[torch.Tensor]]:
|
| 75 |
+
"""Select and *drop* KV cache entries from the lookup buffer.
|
| 76 |
+
|
| 77 |
+
The functionality is similar to the following python statements
|
| 78 |
+
```
|
| 79 |
+
ret = buffer.pop(input_tokens, roi)
|
| 80 |
+
return ret
|
| 81 |
+
```
|
| 82 |
+
|
| 83 |
+
If `input_tokens` and `roi` is `None`, it means selecting any of the
|
| 84 |
+
KV caches in the buffer, return, and remove it from the buffer, useful
|
| 85 |
+
when offloading KV cache to KV cache storage service.
|
| 86 |
+
|
| 87 |
+
Args:
|
| 88 |
+
input_tokens (torch.Tensor): token IDs.
|
| 89 |
+
roi (torch.Tensor): A binary mask on top of the input tokens
|
| 90 |
+
|
| 91 |
+
Returns:
|
| 92 |
+
List[Optional[torch.Tensor]]: A list of tensors. Can be None.
|
| 93 |
+
|
| 94 |
+
Raises:
|
| 95 |
+
NotImplementedError: This method must be implemented in subclasses.
|
| 96 |
+
"""
|
| 97 |
+
raise NotImplementedError
|
| 98 |
+
|
| 99 |
+
@abstractmethod
|
| 100 |
+
def close(self) -> None:
|
| 101 |
+
"""Close the buffer and release resources.
|
| 102 |
+
|
| 103 |
+
This method is responsible for cleaning up resources related to the
|
| 104 |
+
lookup buffer when it is no longer needed.
|
| 105 |
+
|
| 106 |
+
Raises:
|
| 107 |
+
NotImplementedError: This method must be implemented in subclasses.
|
| 108 |
+
"""
|
| 109 |
+
raise NotImplementedError
|
.venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py
ADDED
|
@@ -0,0 +1,243 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
"""
|
| 3 |
+
Implements a distributed key-value (KV) cache transfer mechanism.
|
| 4 |
+
|
| 5 |
+
Key Features:
|
| 6 |
+
- Distributed KV cache transmission using PyNccl pipes.
|
| 7 |
+
- Non-blocking `insert`, blocking `drop_select`.
|
| 8 |
+
- Use CPU signal pipe to avoid racing condition
|
| 9 |
+
- Handles buffer size constraints and provide backpressure mechanism to
|
| 10 |
+
stop the prefill instance when the decode instance is slow.
|
| 11 |
+
"""
|
| 12 |
+
import threading
|
| 13 |
+
import time
|
| 14 |
+
from collections import deque
|
| 15 |
+
from typing import Deque, List, Optional, Union
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
|
| 19 |
+
from vllm.distributed.kv_transfer.kv_lookup_buffer.base import (
|
| 20 |
+
KVLookupBufferBase)
|
| 21 |
+
from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase
|
| 22 |
+
from vllm.logger import init_logger
|
| 23 |
+
|
| 24 |
+
logger = init_logger(__name__)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class SimpleBuffer(KVLookupBufferBase):
|
| 28 |
+
|
| 29 |
+
def __init__(self, signal_pipe: KVPipeBase, data_pipe: KVPipeBase,
|
| 30 |
+
buffer_size_thresh: float):
|
| 31 |
+
"""
|
| 32 |
+
signal_pipe: on CPU
|
| 33 |
+
|
| 34 |
+
NOTE: on-device recv will block all threads in the process, making the
|
| 35 |
+
KV cache producer unable to listen to new request while transmitting
|
| 36 |
+
KV cache. Luckily CPU recv only blocks the current thread so we use
|
| 37 |
+
CPU recv to listen to new request.
|
| 38 |
+
|
| 39 |
+
data_pipe: on device (e.g. GPU)
|
| 40 |
+
"""
|
| 41 |
+
|
| 42 |
+
self.buffer: Deque[List[torch.Tensor]] = deque()
|
| 43 |
+
|
| 44 |
+
self.buffer_size = 0
|
| 45 |
+
self.buffer_size_threshold = buffer_size_thresh
|
| 46 |
+
self.buffer_lock = threading.Lock()
|
| 47 |
+
self.signal_pipe = signal_pipe
|
| 48 |
+
self.data_pipe = data_pipe
|
| 49 |
+
self.request_handling_thread: Optional[threading.Thread] = None
|
| 50 |
+
|
| 51 |
+
self.normal_signal = torch.tensor([0], device="cpu")
|
| 52 |
+
self.end_signal = None
|
| 53 |
+
|
| 54 |
+
def _matches(self, tokens_roi_sender: List[torch.Tensor],
|
| 55 |
+
tokens_roi_recver: List[torch.Tensor]):
|
| 56 |
+
|
| 57 |
+
# tokens_roi_sender: tokens and roi of the producer (in the buffer)
|
| 58 |
+
# tokens_roi_recver: tokens and roi of the consumer (query)
|
| 59 |
+
|
| 60 |
+
tokens_sender = tokens_roi_sender[0]
|
| 61 |
+
tokens_recver = tokens_roi_recver[0]
|
| 62 |
+
roi_sender = tokens_roi_sender[1]
|
| 63 |
+
roi_recver = tokens_roi_recver[1]
|
| 64 |
+
|
| 65 |
+
if tokens_recver is None:
|
| 66 |
+
# consumer sends an empty request
|
| 67 |
+
# semantics: DROP SELECT * LIMIT 1
|
| 68 |
+
# so any of the data in the buffer can be drop-selected
|
| 69 |
+
return True
|
| 70 |
+
|
| 71 |
+
# Assuming that roi is a binary mask on tokens
|
| 72 |
+
tokens_sender = tokens_sender[roi_sender]
|
| 73 |
+
tokens_recver = tokens_recver[roi_recver]
|
| 74 |
+
|
| 75 |
+
# simple common prefix matching
|
| 76 |
+
min_length = min(len(tokens_sender), len(tokens_recver))
|
| 77 |
+
if torch.allclose(tokens_sender[:min_length],
|
| 78 |
+
tokens_recver[:min_length]):
|
| 79 |
+
return min_length
|
| 80 |
+
|
| 81 |
+
return 0
|
| 82 |
+
|
| 83 |
+
def _send_tensor_and_dec_size(self,
|
| 84 |
+
tensor: Optional[torch.Tensor]) -> None:
|
| 85 |
+
|
| 86 |
+
assert tensor is not None, "Use self.data_pipe.send(None) instead"
|
| 87 |
+
self.buffer_size -= tensor.element_size() * tensor.numel()
|
| 88 |
+
if tensor.dtype == torch.bool:
|
| 89 |
+
tensor = tensor.float()
|
| 90 |
+
self.data_pipe.send_tensor(tensor)
|
| 91 |
+
|
| 92 |
+
def _get_element_size(self, data: Optional[Union[List, torch.Tensor]]):
|
| 93 |
+
|
| 94 |
+
if isinstance(data, torch.Tensor):
|
| 95 |
+
return data.element_size() * data.numel()
|
| 96 |
+
if not data:
|
| 97 |
+
# cannot perform `not data` on a tensor
|
| 98 |
+
# so this check needs to go after the check above
|
| 99 |
+
return 0
|
| 100 |
+
|
| 101 |
+
raise AssertionError(f"Unknown data type {type(data)}")
|
| 102 |
+
|
| 103 |
+
def _add_to_buffer(self, input_tokens: torch.Tensor, roi: torch.Tensor,
|
| 104 |
+
key: torch.Tensor, value: torch.Tensor,
|
| 105 |
+
hidden: torch.Tensor):
|
| 106 |
+
|
| 107 |
+
if isinstance(input_tokens, torch.Tensor):
|
| 108 |
+
input_tokens = input_tokens.clone()
|
| 109 |
+
if isinstance(roi, torch.Tensor):
|
| 110 |
+
roi = roi.clone()
|
| 111 |
+
if isinstance(key, torch.Tensor):
|
| 112 |
+
key = key.clone()
|
| 113 |
+
if isinstance(value, torch.Tensor):
|
| 114 |
+
value = value.clone()
|
| 115 |
+
if isinstance(hidden, torch.Tensor):
|
| 116 |
+
hidden = hidden.clone()
|
| 117 |
+
|
| 118 |
+
buffer_item = [input_tokens, roi, key, value, hidden]
|
| 119 |
+
|
| 120 |
+
with self.buffer_lock:
|
| 121 |
+
for data in buffer_item:
|
| 122 |
+
self.buffer_size += self._get_element_size(data)
|
| 123 |
+
self.buffer.append(buffer_item)
|
| 124 |
+
|
| 125 |
+
def _is_end_signal(self, signal):
|
| 126 |
+
return signal is None
|
| 127 |
+
|
| 128 |
+
def drop_select_handler(self):
|
| 129 |
+
|
| 130 |
+
try:
|
| 131 |
+
|
| 132 |
+
while True:
|
| 133 |
+
signal = self.signal_pipe.recv_tensor()
|
| 134 |
+
if self._is_end_signal(signal):
|
| 135 |
+
logger.info("Received end signal!")
|
| 136 |
+
break
|
| 137 |
+
|
| 138 |
+
input_tokens = self.data_pipe.recv_tensor()
|
| 139 |
+
|
| 140 |
+
roi = self.data_pipe.recv_tensor()
|
| 141 |
+
assert roi is not None, "Please provide the roi when sending "\
|
| 142 |
+
"drop-select request"
|
| 143 |
+
roi = (roi > 0.5)
|
| 144 |
+
tokens_roi_recver = [input_tokens, roi]
|
| 145 |
+
|
| 146 |
+
matched_length = 0
|
| 147 |
+
|
| 148 |
+
# perform input tokens and roi matching
|
| 149 |
+
# FIXME: this matching is O(n), ideally it should be O(1)
|
| 150 |
+
# but this buffer size won't (and shouldn't) be too large so
|
| 151 |
+
# the fix is not urgent.
|
| 152 |
+
with self.buffer_lock:
|
| 153 |
+
|
| 154 |
+
for _ in range(len(self.buffer)):
|
| 155 |
+
|
| 156 |
+
temp_length = self._matches(self.buffer[0],
|
| 157 |
+
tokens_roi_recver)
|
| 158 |
+
if temp_length > 0:
|
| 159 |
+
matched_length = temp_length
|
| 160 |
+
break
|
| 161 |
+
# rotate the element we just accessed to the end
|
| 162 |
+
self.buffer.rotate(-1)
|
| 163 |
+
|
| 164 |
+
if matched_length > 0:
|
| 165 |
+
# need to clone the tensor
|
| 166 |
+
# in case the tensor is freed before sending finishes
|
| 167 |
+
matched_item = self.buffer.popleft()
|
| 168 |
+
for tensor in matched_item:
|
| 169 |
+
self._send_tensor_and_dec_size(tensor)
|
| 170 |
+
|
| 171 |
+
else:
|
| 172 |
+
# no match, just send None
|
| 173 |
+
for _ in range(5):
|
| 174 |
+
self.data_pipe.send_tensor(None)
|
| 175 |
+
|
| 176 |
+
except RuntimeError as e:
|
| 177 |
+
if 'Connection closed by peer' not in str(e):
|
| 178 |
+
raise e
|
| 179 |
+
|
| 180 |
+
logger.debug("Closing drop_select_handler")
|
| 181 |
+
|
| 182 |
+
def drop_select(
|
| 183 |
+
self, input_tokens: Optional[torch.Tensor],
|
| 184 |
+
roi: Optional[torch.Tensor]) -> List[Optional[torch.Tensor]]:
|
| 185 |
+
|
| 186 |
+
assert self.request_handling_thread is None, \
|
| 187 |
+
"drop_select should be called by the KV cache consumer "\
|
| 188 |
+
"(e.g. the decode vLLM instance)"
|
| 189 |
+
|
| 190 |
+
if isinstance(input_tokens, torch.Tensor):
|
| 191 |
+
input_tokens = input_tokens.clone()
|
| 192 |
+
if isinstance(roi, torch.Tensor):
|
| 193 |
+
roi = roi.clone().float()
|
| 194 |
+
|
| 195 |
+
self.signal_pipe.send_tensor(self.normal_signal)
|
| 196 |
+
self.data_pipe.send_tensor(input_tokens)
|
| 197 |
+
self.data_pipe.send_tensor(roi)
|
| 198 |
+
|
| 199 |
+
input_tokens = self.data_pipe.recv_tensor()
|
| 200 |
+
roi = self.data_pipe.recv_tensor()
|
| 201 |
+
if roi is not None:
|
| 202 |
+
# convert from float tensor to bool tensor
|
| 203 |
+
# as PyNccl does not support sending bool tensor
|
| 204 |
+
roi = (roi > 0.5)
|
| 205 |
+
key = self.data_pipe.recv_tensor()
|
| 206 |
+
value = self.data_pipe.recv_tensor()
|
| 207 |
+
hidden = self.data_pipe.recv_tensor()
|
| 208 |
+
|
| 209 |
+
return [input_tokens, roi, key, value, hidden]
|
| 210 |
+
|
| 211 |
+
def full_handler(self):
|
| 212 |
+
time.sleep(0.001)
|
| 213 |
+
|
| 214 |
+
def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor,
|
| 215 |
+
key: torch.Tensor, value: torch.Tensor,
|
| 216 |
+
hidden: torch.Tensor) -> None:
|
| 217 |
+
|
| 218 |
+
if self.buffer_size > self.buffer_size_threshold:
|
| 219 |
+
# log outside the while loop to avoid this message being logged
|
| 220 |
+
# repeatedly.
|
| 221 |
+
logger.debug("KV transfer buffer is full. Handling...")
|
| 222 |
+
while self.buffer_size > self.buffer_size_threshold:
|
| 223 |
+
self.full_handler()
|
| 224 |
+
|
| 225 |
+
self._add_to_buffer(input_tokens, roi, key, value, hidden)
|
| 226 |
+
|
| 227 |
+
# when calling the insert, the current process is a sender
|
| 228 |
+
# need to launch the request handler and start listening to request.
|
| 229 |
+
if self.request_handling_thread is None:
|
| 230 |
+
self.request_handling_thread = threading.Thread(
|
| 231 |
+
target=self.drop_select_handler)
|
| 232 |
+
self.request_handling_thread.start()
|
| 233 |
+
|
| 234 |
+
def close(self):
|
| 235 |
+
|
| 236 |
+
if hasattr(self, "request_handling_thread"
|
| 237 |
+
) and self.request_handling_thread is not None:
|
| 238 |
+
self.request_handling_thread.join()
|
| 239 |
+
|
| 240 |
+
else:
|
| 241 |
+
# TODO: have a explicit close signal and have a explicit way to
|
| 242 |
+
# check if it's requester
|
| 243 |
+
self.signal_pipe.send_tensor(self.end_signal)
|
.venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/kv_pipe/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/kv_pipe/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (209 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/kv_pipe/__pycache__/base.cpython-311.pyc
ADDED
|
Binary file (2.95 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/kv_pipe/__pycache__/mooncake_pipe.cpython-311.pyc
ADDED
|
Binary file (18.3 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/kv_pipe/__pycache__/pynccl_pipe.cpython-311.pyc
ADDED
|
Binary file (13.8 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/kv_pipe/base.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
"""
|
| 3 |
+
This file defines an interface `KVPipeBase`
|
| 4 |
+
that provides an abstraction for sending and receiving tensors, or None, via
|
| 5 |
+
distributed communications.
|
| 6 |
+
|
| 7 |
+
All classes instantiated from this interface are assumed to be a FIFO pipe.
|
| 8 |
+
|
| 9 |
+
If your distributed communication platform already supports key-value lookup,
|
| 10 |
+
you can bypass this interface and directly start from `kv_lookup_buffer`.
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
from abc import ABC, abstractmethod
|
| 14 |
+
from typing import Optional
|
| 15 |
+
|
| 16 |
+
import torch
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class KVPipeBase(ABC):
|
| 20 |
+
"""
|
| 21 |
+
This class provides an interface for sending and receiving tensors, or
|
| 22 |
+
None, by distributed communications.
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
@abstractmethod
|
| 26 |
+
def send_tensor(self, tensor: Optional[torch.Tensor]) -> None:
|
| 27 |
+
"""Send a tensor, or None, via the pipe.
|
| 28 |
+
|
| 29 |
+
Need to support sending None -- important for error handling.
|
| 30 |
+
|
| 31 |
+
TODO: add a `key` argument so that we can use traditional
|
| 32 |
+
key-value database as the distributed communication mechanism behind
|
| 33 |
+
the pipe.
|
| 34 |
+
|
| 35 |
+
Args:
|
| 36 |
+
tensor (Optional[torch.Tensor]): The tensor to be sent. Can be None.
|
| 37 |
+
|
| 38 |
+
Raises:
|
| 39 |
+
NotImplementedError: This method must be implemented in subclasses.
|
| 40 |
+
"""
|
| 41 |
+
raise NotImplementedError
|
| 42 |
+
|
| 43 |
+
@abstractmethod
|
| 44 |
+
def recv_tensor(self) -> Optional[torch.Tensor]:
|
| 45 |
+
"""Receive a tensor (can be None) from the pipeline.
|
| 46 |
+
|
| 47 |
+
Returns:
|
| 48 |
+
Optional[torch.Tensor]: The tensor received from the pipeline. Can
|
| 49 |
+
be None.
|
| 50 |
+
|
| 51 |
+
Raises:
|
| 52 |
+
NotImplementedError: This method must be implemented in subclasses.
|
| 53 |
+
"""
|
| 54 |
+
raise NotImplementedError
|
| 55 |
+
|
| 56 |
+
@abstractmethod
|
| 57 |
+
def close(self) -> None:
|
| 58 |
+
"""Close the pipeline and release resources.
|
| 59 |
+
|
| 60 |
+
This method is responsible for closing the communication pipeline
|
| 61 |
+
and releasing any resources associated with it.
|
| 62 |
+
|
| 63 |
+
Raises:
|
| 64 |
+
NotImplementedError: This method must be implemented in subclasses.
|
| 65 |
+
"""
|
| 66 |
+
raise NotImplementedError
|
.venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/kv_pipe/mooncake_pipe.py
ADDED
|
@@ -0,0 +1,274 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import os
|
| 5 |
+
import pickle
|
| 6 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 7 |
+
from dataclasses import dataclass
|
| 8 |
+
from typing import Optional, Union
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import zmq
|
| 12 |
+
|
| 13 |
+
from vllm.config import KVTransferConfig
|
| 14 |
+
from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase
|
| 15 |
+
from vllm.logger import init_logger
|
| 16 |
+
|
| 17 |
+
logger = init_logger(__name__)
|
| 18 |
+
NONE_INT = -150886311
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
@dataclass
|
| 22 |
+
class MooncakeTransferEngineConfig:
|
| 23 |
+
prefill_url: str
|
| 24 |
+
decode_url: str
|
| 25 |
+
metadata_backend: Union[str, None]
|
| 26 |
+
metadata_server: str
|
| 27 |
+
protocol: str
|
| 28 |
+
device_name: str
|
| 29 |
+
|
| 30 |
+
@staticmethod
|
| 31 |
+
def from_file(file_path: str) -> 'MooncakeTransferEngineConfig':
|
| 32 |
+
"""Load the config from a JSON file."""
|
| 33 |
+
with open(file_path) as fin:
|
| 34 |
+
config = json.load(fin)
|
| 35 |
+
return MooncakeTransferEngineConfig(
|
| 36 |
+
prefill_url=config.get("prefill_url"),
|
| 37 |
+
decode_url=config.get("decode_url"),
|
| 38 |
+
metadata_backend=config.get("metadata_backend", None),
|
| 39 |
+
metadata_server=config.get("metadata_server"),
|
| 40 |
+
protocol=config.get("protocol", "tcp"),
|
| 41 |
+
device_name=config.get("device_name", ""),
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
@staticmethod
|
| 45 |
+
def load_from_env() -> 'MooncakeTransferEngineConfig':
|
| 46 |
+
"""Load config from a file specified in the environment variable."""
|
| 47 |
+
config_file_path = os.getenv('MOONCAKE_CONFIG_PATH')
|
| 48 |
+
if config_file_path is None:
|
| 49 |
+
raise ValueError(
|
| 50 |
+
"The environment variable 'MOONCAKE_CONFIG_PATH' is not set.")
|
| 51 |
+
return MooncakeTransferEngineConfig.from_file(config_file_path)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class MooncakeTransferEngine:
|
| 55 |
+
"""Handles the transfer of data using mooncake_vllm_adaptor and ZeroMQ."""
|
| 56 |
+
|
| 57 |
+
def __init__(self, kv_rank: int, local_rank: int):
|
| 58 |
+
try:
|
| 59 |
+
import mooncake_vllm_adaptor as mva
|
| 60 |
+
except ImportError as e:
|
| 61 |
+
raise ImportError(
|
| 62 |
+
"Please install mooncake by following the instructions at "
|
| 63 |
+
"https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md " # noqa: E501
|
| 64 |
+
"to run vLLM with MooncakeConnector.") from e
|
| 65 |
+
|
| 66 |
+
self.engine = mva.mooncake_vllm_adaptor()
|
| 67 |
+
self.local_rank = local_rank
|
| 68 |
+
|
| 69 |
+
try:
|
| 70 |
+
self.config = MooncakeTransferEngineConfig.load_from_env()
|
| 71 |
+
logger.info("Mooncake Configuration loaded successfully.")
|
| 72 |
+
except ValueError as e:
|
| 73 |
+
logger.error(e)
|
| 74 |
+
raise
|
| 75 |
+
except Exception as exc:
|
| 76 |
+
logger.error(
|
| 77 |
+
"An error occurred while loading the configuration: %s", exc)
|
| 78 |
+
raise
|
| 79 |
+
prefill_host, base_prefill_port = self.config.prefill_url.split(':')
|
| 80 |
+
decode_host, base_decode_port = self.config.decode_url.split(':')
|
| 81 |
+
|
| 82 |
+
# Avoid ports conflict when running prefill and decode on the same node
|
| 83 |
+
if prefill_host == decode_host and \
|
| 84 |
+
base_prefill_port == base_decode_port:
|
| 85 |
+
base_decode_port = str(int(base_decode_port) + 100)
|
| 86 |
+
|
| 87 |
+
prefill_port = int(base_prefill_port) + self.local_rank
|
| 88 |
+
decode_port = int(base_decode_port) + self.local_rank
|
| 89 |
+
self.prefill_url = ':'.join([prefill_host, str(prefill_port)])
|
| 90 |
+
self.decode_url = ':'.join([decode_host, str(decode_port)])
|
| 91 |
+
|
| 92 |
+
self.initialize(self.prefill_url if kv_rank == 0 else self.decode_url,
|
| 93 |
+
self.config.metadata_server, self.config.protocol,
|
| 94 |
+
self.config.device_name, self.config.metadata_backend)
|
| 95 |
+
|
| 96 |
+
self.remote_url = (self.decode_url
|
| 97 |
+
if kv_rank == 0 else self.prefill_url)
|
| 98 |
+
|
| 99 |
+
# Initialize ZeroMQ context and sockets
|
| 100 |
+
self.context = zmq.Context() # type: ignore[attr-defined]
|
| 101 |
+
self.sender_socket = self.context.socket(zmq.constants.PUSH)
|
| 102 |
+
self.receiver_socket = self.context.socket(zmq.constants.PULL)
|
| 103 |
+
self.sender_ack = self.context.socket(zmq.constants.PULL)
|
| 104 |
+
self.receiver_ack = self.context.socket(zmq.constants.PUSH)
|
| 105 |
+
|
| 106 |
+
self.buffer_cleaner = ThreadPoolExecutor(max_workers=1)
|
| 107 |
+
self._setup_metadata_sockets(kv_rank, prefill_host, base_prefill_port,
|
| 108 |
+
decode_host, base_decode_port)
|
| 109 |
+
|
| 110 |
+
def _setup_metadata_sockets(self, kv_rank: int, p_host: str, p_port: str,
|
| 111 |
+
d_host: str, d_port: str) -> None:
|
| 112 |
+
"""Set up ZeroMQ sockets for sending and receiving data."""
|
| 113 |
+
# Offsets < 8 are left for initialization in case tp and pp are enabled
|
| 114 |
+
p_rank_offset = int(p_port) + 8 + self.local_rank * 2
|
| 115 |
+
d_rank_offset = int(d_port) + 8 + self.local_rank * 2
|
| 116 |
+
if kv_rank == 0:
|
| 117 |
+
self.sender_socket.bind(f"tcp://*:{p_rank_offset + 1}")
|
| 118 |
+
self.receiver_socket.connect(f"tcp://{d_host}:{d_rank_offset + 1}")
|
| 119 |
+
self.sender_ack.connect(f"tcp://{d_host}:{d_rank_offset + 2}")
|
| 120 |
+
self.receiver_ack.bind(f"tcp://*:{p_rank_offset + 2}")
|
| 121 |
+
else:
|
| 122 |
+
self.receiver_socket.connect(f"tcp://{p_host}:{p_rank_offset + 1}")
|
| 123 |
+
self.sender_socket.bind(f"tcp://*:{d_rank_offset + 1}")
|
| 124 |
+
self.receiver_ack.bind(f"tcp://*:{d_rank_offset + 2}")
|
| 125 |
+
self.sender_ack.connect(f"tcp://{p_host}:{p_rank_offset + 2}")
|
| 126 |
+
|
| 127 |
+
def initialize(self, local_hostname: str, metadata_server: str,
|
| 128 |
+
protocol: str, device_name: str,
|
| 129 |
+
metadata_backend: Union[str, None]) -> None:
|
| 130 |
+
"""Initialize the mooncake instance."""
|
| 131 |
+
if metadata_backend is None:
|
| 132 |
+
self.engine.initialize(local_hostname, metadata_server, protocol,
|
| 133 |
+
device_name)
|
| 134 |
+
else:
|
| 135 |
+
supported_backend = ["etcd", "redis"]
|
| 136 |
+
metadata_backend = metadata_backend.lower()
|
| 137 |
+
if metadata_backend not in supported_backend:
|
| 138 |
+
raise ValueError(
|
| 139 |
+
"Mooncake Configuration error. `metadata_backend`"
|
| 140 |
+
f"should be one of {supported_backend}.")
|
| 141 |
+
|
| 142 |
+
self.engine.initializeExt(local_hostname, metadata_server,
|
| 143 |
+
protocol, device_name, metadata_backend)
|
| 144 |
+
|
| 145 |
+
def allocate_managed_buffer(self, length: int) -> int:
|
| 146 |
+
"""Allocate a managed buffer of the specified length."""
|
| 147 |
+
ret = self.engine.allocateManagedBuffer(length)
|
| 148 |
+
if ret <= 0:
|
| 149 |
+
logger.error("Allocation Return Error")
|
| 150 |
+
raise Exception("Allocation Return Error")
|
| 151 |
+
return ret
|
| 152 |
+
|
| 153 |
+
def free_managed_buffer(self, buffer: int, length: int) -> int:
|
| 154 |
+
"""Free a previously allocated managed buffer."""
|
| 155 |
+
return self.engine.freeManagedBuffer(buffer, length)
|
| 156 |
+
|
| 157 |
+
def transfer_sync(self, buffer: int, peer_buffer_address: int,
|
| 158 |
+
length: int) -> int:
|
| 159 |
+
"""Synchronously transfer data to the specified address."""
|
| 160 |
+
ret = self.engine.transferSync(self.remote_url, buffer,
|
| 161 |
+
peer_buffer_address, length)
|
| 162 |
+
if ret < 0:
|
| 163 |
+
logger.error("Transfer Return Error")
|
| 164 |
+
raise Exception("Transfer Return Error")
|
| 165 |
+
return ret
|
| 166 |
+
|
| 167 |
+
def write_bytes_to_buffer(self, buffer: int, user_data: bytes,
|
| 168 |
+
length: int) -> int:
|
| 169 |
+
"""Write bytes to the allocated buffer."""
|
| 170 |
+
return self.engine.writeBytesToBuffer(buffer, user_data, length)
|
| 171 |
+
|
| 172 |
+
def read_bytes_from_buffer(self, buffer: int, length: int) -> bytes:
|
| 173 |
+
"""Read bytes from the allocated buffer."""
|
| 174 |
+
return self.engine.readBytesFromBuffer(buffer, length)
|
| 175 |
+
|
| 176 |
+
def wait_for_ack(self, src_ptr: int, length: int) -> None:
|
| 177 |
+
"""Asynchronously wait for ACK from the receiver."""
|
| 178 |
+
ack = self.sender_ack.recv_pyobj()
|
| 179 |
+
if ack != b'ACK':
|
| 180 |
+
logger.error("Failed to receive ACK from the receiver")
|
| 181 |
+
|
| 182 |
+
self.free_managed_buffer(src_ptr, length)
|
| 183 |
+
|
| 184 |
+
def send_bytes(self, user_data: bytes) -> None:
|
| 185 |
+
"""Send bytes to the remote process."""
|
| 186 |
+
length = len(user_data)
|
| 187 |
+
src_ptr = self.allocate_managed_buffer(length)
|
| 188 |
+
self.write_bytes_to_buffer(src_ptr, user_data, length)
|
| 189 |
+
self.sender_socket.send_pyobj((src_ptr, length))
|
| 190 |
+
self.buffer_cleaner.submit(self.wait_for_ack, src_ptr, length)
|
| 191 |
+
|
| 192 |
+
def recv_bytes(self) -> bytes:
|
| 193 |
+
"""Receive bytes from the remote process."""
|
| 194 |
+
src_ptr, length = self.receiver_socket.recv_pyobj()
|
| 195 |
+
dst_ptr = self.allocate_managed_buffer(length)
|
| 196 |
+
self.transfer_sync(dst_ptr, src_ptr, length)
|
| 197 |
+
ret = self.read_bytes_from_buffer(dst_ptr, length)
|
| 198 |
+
|
| 199 |
+
# Buffer cleanup
|
| 200 |
+
self.receiver_ack.send_pyobj(b'ACK')
|
| 201 |
+
self.free_managed_buffer(dst_ptr, length)
|
| 202 |
+
|
| 203 |
+
return ret
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
class MooncakePipe(KVPipeBase):
|
| 207 |
+
"""MooncakeTransferEngine based Pipe implementation."""
|
| 208 |
+
|
| 209 |
+
def __init__(self,
|
| 210 |
+
local_rank: int,
|
| 211 |
+
config: KVTransferConfig,
|
| 212 |
+
device: Optional[str] = None):
|
| 213 |
+
"""Initialize the mooncake pipe and set related parameters."""
|
| 214 |
+
self.config = config
|
| 215 |
+
self.local_rank = local_rank
|
| 216 |
+
self.kv_rank = self.config.kv_rank
|
| 217 |
+
if device is None:
|
| 218 |
+
self.device = self._select_device(self.config.kv_buffer_device)
|
| 219 |
+
else:
|
| 220 |
+
self.device = self._select_device(device)
|
| 221 |
+
|
| 222 |
+
self.transfer_engine = MooncakeTransferEngine(self.kv_rank,
|
| 223 |
+
self.local_rank)
|
| 224 |
+
self.transport_thread: Optional[ThreadPoolExecutor] = None
|
| 225 |
+
self.none_tensor = torch.tensor([NONE_INT], device=self.device)
|
| 226 |
+
|
| 227 |
+
def _select_device(self, device: str) -> torch.device:
|
| 228 |
+
"""Select available device (CUDA or CPU)."""
|
| 229 |
+
logger.info("Selecting device: %s", device)
|
| 230 |
+
if device == "cuda":
|
| 231 |
+
return torch.device(f"cuda:{self.local_rank}")
|
| 232 |
+
else:
|
| 233 |
+
return torch.device("cpu")
|
| 234 |
+
|
| 235 |
+
def tensor_hash(self, tensor: torch.Tensor) -> int:
|
| 236 |
+
"""Calculate the hash value of the tensor."""
|
| 237 |
+
return hash(tensor.data_ptr())
|
| 238 |
+
|
| 239 |
+
def _send_impl(self, tensor: torch.Tensor) -> None:
|
| 240 |
+
"""Implement the tensor sending logic."""
|
| 241 |
+
value_bytes = pickle.dumps(tensor)
|
| 242 |
+
self.transfer_engine.send_bytes(value_bytes)
|
| 243 |
+
|
| 244 |
+
def _recv_impl(self) -> torch.Tensor:
|
| 245 |
+
"""Implement the tensor receiving logic."""
|
| 246 |
+
data = self.transfer_engine.recv_bytes()
|
| 247 |
+
return pickle.loads(data)
|
| 248 |
+
|
| 249 |
+
def send_tensor(self, tensor: Optional[torch.Tensor]) -> None:
|
| 250 |
+
"""Send tensor to the target process."""
|
| 251 |
+
if self.transport_thread is None:
|
| 252 |
+
self.transport_thread = ThreadPoolExecutor(max_workers=1)
|
| 253 |
+
tensor = tensor if tensor is not None else self.none_tensor
|
| 254 |
+
assert (len(tensor.shape) > 0)
|
| 255 |
+
self.transport_thread.submit(self._send_impl, tensor)
|
| 256 |
+
|
| 257 |
+
def recv_tensor(self) -> Optional[torch.Tensor]:
|
| 258 |
+
"""Receive tensor from other processes."""
|
| 259 |
+
if self.transport_thread is None:
|
| 260 |
+
self.transport_thread = ThreadPoolExecutor(max_workers=1)
|
| 261 |
+
tensor = self.transport_thread.submit(self._recv_impl).result()
|
| 262 |
+
if tensor.numel() == 1 and tensor.item() == NONE_INT:
|
| 263 |
+
return None
|
| 264 |
+
else:
|
| 265 |
+
return tensor
|
| 266 |
+
|
| 267 |
+
def close(self) -> None:
|
| 268 |
+
"""Cleanup logic when closing the pipe."""
|
| 269 |
+
self.transfer_engine.sender_socket.close()
|
| 270 |
+
self.transfer_engine.receiver_socket.close()
|
| 271 |
+
self.transfer_engine.sender_ack.close()
|
| 272 |
+
self.transfer_engine.receiver_ack.close()
|
| 273 |
+
self.transfer_engine.context.term() # Terminate the ZMQ context
|
| 274 |
+
logger.info("Closed the transfer engine and cleaned up resources.")
|
.venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py
ADDED
|
@@ -0,0 +1,277 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
"""
|
| 3 |
+
This module implements a PyNccl pipe for sending and receiving
|
| 4 |
+
Optional[torch.Tensor] between distributed ranks with advanced
|
| 5 |
+
communication features.
|
| 6 |
+
|
| 7 |
+
Key Features:
|
| 8 |
+
- Supports sending and receiving tensors with metadata
|
| 9 |
+
- Handles both CUDA and CPU device communications
|
| 10 |
+
- Implements a non-blocking tensor transfer mechanism
|
| 11 |
+
- Manages buffer size and provides backpressure control
|
| 12 |
+
- Supports distributed process groups with configurable parameters
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
import threading
|
| 16 |
+
import time
|
| 17 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 18 |
+
from typing import Callable, Dict, Optional, Tuple
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
|
| 22 |
+
from vllm.config import KVTransferConfig
|
| 23 |
+
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
|
| 24 |
+
from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase
|
| 25 |
+
from vllm.distributed.utils import StatelessProcessGroup
|
| 26 |
+
from vllm.logger import init_logger
|
| 27 |
+
|
| 28 |
+
logger = init_logger(__name__)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class BrokenPipeException(Exception):
|
| 32 |
+
|
| 33 |
+
def __init__(self, message):
|
| 34 |
+
self.message = message
|
| 35 |
+
super().__init__(self.message)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
Metadata = Dict[str, Optional[torch.Tensor]]
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class PyNcclPipe(KVPipeBase):
|
| 42 |
+
|
| 43 |
+
METADATA_LENGTH = 16
|
| 44 |
+
MAX_TENSOR_DIMENSIONS = 14
|
| 45 |
+
METADATA_DTYPE = torch.int64
|
| 46 |
+
|
| 47 |
+
def __init__(self,
|
| 48 |
+
local_rank: int,
|
| 49 |
+
config: KVTransferConfig,
|
| 50 |
+
device: Optional[str] = None,
|
| 51 |
+
port_offset: int = 0):
|
| 52 |
+
self.config = config
|
| 53 |
+
self.local_rank = local_rank
|
| 54 |
+
self.kv_rank = self.config.kv_rank
|
| 55 |
+
self.kv_parallel_size = self.config.kv_parallel_size
|
| 56 |
+
if device is None:
|
| 57 |
+
self.device = self._select_device(self.config.kv_buffer_device)
|
| 58 |
+
else:
|
| 59 |
+
self.device = self._select_device(device)
|
| 60 |
+
|
| 61 |
+
# build distributed connection and send/recv implementation
|
| 62 |
+
self.group = StatelessProcessGroup.create(
|
| 63 |
+
host=self.config.kv_ip,
|
| 64 |
+
port=self.config.kv_port + port_offset,
|
| 65 |
+
rank=self.kv_rank,
|
| 66 |
+
world_size=self.kv_parallel_size,
|
| 67 |
+
)
|
| 68 |
+
# add a barrier to make sure the connection is initiated properly
|
| 69 |
+
self.group.barrier()
|
| 70 |
+
impl = self._get_device_send_recv_impl(self.group)
|
| 71 |
+
self.device_send_func, self.device_recv_func = impl
|
| 72 |
+
# set target rank
|
| 73 |
+
self.target_rank_for_send = (self.kv_rank + 1) % self.kv_parallel_size
|
| 74 |
+
self.target_rank_for_recv = (self.kv_rank - 1) % self.kv_parallel_size
|
| 75 |
+
|
| 76 |
+
# transportation-related variables
|
| 77 |
+
self.transport_thread: Optional[ThreadPoolExecutor] = None
|
| 78 |
+
self.buffer_size = 0
|
| 79 |
+
self.buffer_size_lock = threading.Lock()
|
| 80 |
+
self.buffer_size_thresh = self.config.kv_buffer_size
|
| 81 |
+
|
| 82 |
+
def _get_device_send_recv_impl(
|
| 83 |
+
self, group: StatelessProcessGroup
|
| 84 |
+
) -> Tuple[Callable[[torch.Tensor, int], None], Callable[
|
| 85 |
+
[torch.Tensor, int], None]]:
|
| 86 |
+
|
| 87 |
+
send: Callable[[torch.Tensor, int], None]
|
| 88 |
+
recv: Callable[[torch.Tensor, int], None]
|
| 89 |
+
if self.device.type == "cuda":
|
| 90 |
+
# use PyNCCL for send / recv
|
| 91 |
+
comm = PyNcclCommunicator(group, device=self.local_rank)
|
| 92 |
+
comm.disabled = False
|
| 93 |
+
send, recv = comm.send, comm.recv # type: ignore
|
| 94 |
+
else:
|
| 95 |
+
# This send / recv implementation here is NOT intended to transfer
|
| 96 |
+
# KV caches (and should NOT be repurposed to transfer KV caches).
|
| 97 |
+
# Currently it is only used to transmit control-plane messages
|
| 98 |
+
# for PyNcclBuffer.
|
| 99 |
+
send = group.send_obj
|
| 100 |
+
|
| 101 |
+
def my_recv(x, src):
|
| 102 |
+
x[...] = group.recv_obj(src)
|
| 103 |
+
|
| 104 |
+
recv = my_recv
|
| 105 |
+
|
| 106 |
+
return send, recv
|
| 107 |
+
|
| 108 |
+
def _select_device(self, device: str):
|
| 109 |
+
logger.info("Selecting device: %s", device)
|
| 110 |
+
if device == "cuda":
|
| 111 |
+
return torch.device(f"cuda:{self.local_rank}")
|
| 112 |
+
else:
|
| 113 |
+
return torch.device("cpu")
|
| 114 |
+
|
| 115 |
+
def _make_metadata(self, tensor: Optional[torch.Tensor]) -> Metadata:
|
| 116 |
+
"""
|
| 117 |
+
Create the metadata as a dictionary based on the input tensor.
|
| 118 |
+
|
| 119 |
+
Parameters:
|
| 120 |
+
- tensor: The input tensor or None if no tensor is provided.
|
| 121 |
+
|
| 122 |
+
Returns:
|
| 123 |
+
- metadata: A dictionary with the following keys:
|
| 124 |
+
- "dtype": The data type of the tensor or None.
|
| 125 |
+
- "shape": The shape of the tensor or None.
|
| 126 |
+
"""
|
| 127 |
+
if tensor is None:
|
| 128 |
+
return {"dtype": None, "shape": None}
|
| 129 |
+
else:
|
| 130 |
+
return {"dtype": tensor.dtype, "shape": tensor.shape}
|
| 131 |
+
|
| 132 |
+
def _prepare_recv_buffer(self, metadata: Metadata) -> torch.Tensor:
|
| 133 |
+
"""
|
| 134 |
+
Create a buffer to receive the tensor based on the provided metadata.
|
| 135 |
+
|
| 136 |
+
Parameters:
|
| 137 |
+
- metadata: A dictionary with keys "dtype" and "shape", describing
|
| 138 |
+
the tensor's data type and shape.
|
| 139 |
+
|
| 140 |
+
Returns:
|
| 141 |
+
- buffer: A tensor of the specified type and shape, allocated on
|
| 142 |
+
self.device.
|
| 143 |
+
"""
|
| 144 |
+
return torch.empty(metadata["shape"],
|
| 145 |
+
dtype=metadata["dtype"],
|
| 146 |
+
device=self.device)
|
| 147 |
+
|
| 148 |
+
def _send_metadata(self, metadata: Metadata):
|
| 149 |
+
"""
|
| 150 |
+
Send the metadata dictionary to the target rank.
|
| 151 |
+
|
| 152 |
+
Parameters:
|
| 153 |
+
- metadata: A dictionary with keys "dtype" and "shape".
|
| 154 |
+
"""
|
| 155 |
+
self.group.send_obj(metadata, self.target_rank_for_send)
|
| 156 |
+
|
| 157 |
+
def _recv_metadata(self) -> Metadata:
|
| 158 |
+
"""
|
| 159 |
+
Receive the metadata dictionary from the target rank.
|
| 160 |
+
|
| 161 |
+
Returns:
|
| 162 |
+
- metadata: A dictionary with keys "dtype" and "shape" describing
|
| 163 |
+
the tensor.
|
| 164 |
+
"""
|
| 165 |
+
return self.group.recv_obj(self.target_rank_for_recv)
|
| 166 |
+
|
| 167 |
+
def _send_impl(self, tensor: Optional[torch.Tensor]) -> None:
|
| 168 |
+
"""
|
| 169 |
+
The actual implementation of sending the tensor and its metadata to the
|
| 170 |
+
target rank.
|
| 171 |
+
|
| 172 |
+
Parameters:
|
| 173 |
+
- tensor: The input tensor to be sent, or None if no tensor is
|
| 174 |
+
being sent.
|
| 175 |
+
"""
|
| 176 |
+
metadata = self._make_metadata(tensor)
|
| 177 |
+
self._send_metadata(metadata)
|
| 178 |
+
if tensor is not None:
|
| 179 |
+
self.device_send_func(tensor.to(self.device),
|
| 180 |
+
self.target_rank_for_send)
|
| 181 |
+
|
| 182 |
+
def _recv_impl(self) -> Optional[torch.Tensor]:
|
| 183 |
+
"""
|
| 184 |
+
The actual implementation of receiving a tensor and its metadata from
|
| 185 |
+
the target rank.
|
| 186 |
+
|
| 187 |
+
Returns:
|
| 188 |
+
- buffer: The received tensor, or None if no tensor is received.
|
| 189 |
+
"""
|
| 190 |
+
metadata = self._recv_metadata()
|
| 191 |
+
if metadata["dtype"] is None:
|
| 192 |
+
return None
|
| 193 |
+
buffer = self._prepare_recv_buffer(metadata)
|
| 194 |
+
self.device_recv_func(buffer, self.target_rank_for_recv)
|
| 195 |
+
|
| 196 |
+
return buffer
|
| 197 |
+
|
| 198 |
+
def send_tensor_wrapper(self, tensor: Optional[torch.Tensor],
|
| 199 |
+
tensor_size: int) -> None:
|
| 200 |
+
"""
|
| 201 |
+
Wrapper for _send_impl to handle exceptions and update buffer size.
|
| 202 |
+
"""
|
| 203 |
+
try:
|
| 204 |
+
self._send_impl(tensor)
|
| 205 |
+
|
| 206 |
+
with self.buffer_size_lock:
|
| 207 |
+
self.buffer_size -= tensor_size
|
| 208 |
+
except Exception as e:
|
| 209 |
+
logger.error("[rank%d]: Exception when trying to send %s, msg: %s",
|
| 210 |
+
torch.distributed.get_rank(), str(tensor), str(e))
|
| 211 |
+
import traceback
|
| 212 |
+
traceback.print_exc()
|
| 213 |
+
|
| 214 |
+
def block_if_full(self):
|
| 215 |
+
"""
|
| 216 |
+
Block the current thread if the buffer size is larger than the
|
| 217 |
+
threshold.
|
| 218 |
+
"""
|
| 219 |
+
while self.buffer_size > self.buffer_size_thresh:
|
| 220 |
+
logger.debug("KV cache transfer pipe is full. Waiting...")
|
| 221 |
+
time.sleep(0.05)
|
| 222 |
+
|
| 223 |
+
def send_tensor(self, tensor: Optional[torch.Tensor]) -> None:
|
| 224 |
+
"""
|
| 225 |
+
Sends a tensor and its metadata to the destination rank in a
|
| 226 |
+
non-blocking way.
|
| 227 |
+
|
| 228 |
+
Parameters:
|
| 229 |
+
- tensor: The tensor to send, or None if no tensor is being sent.
|
| 230 |
+
"""
|
| 231 |
+
if self.transport_thread is None:
|
| 232 |
+
self.transport_thread = ThreadPoolExecutor(max_workers=1)
|
| 233 |
+
|
| 234 |
+
if tensor is not None:
|
| 235 |
+
tensor_size = tensor.element_size() * tensor.numel()
|
| 236 |
+
else:
|
| 237 |
+
tensor_size = 0
|
| 238 |
+
|
| 239 |
+
self.block_if_full()
|
| 240 |
+
|
| 241 |
+
with self.buffer_size_lock:
|
| 242 |
+
self.buffer_size += tensor_size
|
| 243 |
+
|
| 244 |
+
self.transport_thread.submit(self.send_tensor_wrapper, tensor,
|
| 245 |
+
tensor_size)
|
| 246 |
+
|
| 247 |
+
def recv_tensor(self) -> Optional[torch.Tensor]:
|
| 248 |
+
"""
|
| 249 |
+
Receives a tensor and its metadata from the source rank. Blocking call.
|
| 250 |
+
|
| 251 |
+
Returns:
|
| 252 |
+
- tensor: The received tensor, or None if no tensor is received.
|
| 253 |
+
"""
|
| 254 |
+
if self.transport_thread is None:
|
| 255 |
+
self.transport_thread = ThreadPoolExecutor(max_workers=1)
|
| 256 |
+
|
| 257 |
+
future = self.transport_thread.submit(self._recv_impl)
|
| 258 |
+
|
| 259 |
+
try:
|
| 260 |
+
tensor = future.result()
|
| 261 |
+
except Exception as e:
|
| 262 |
+
logger.error("Encountering exception in KV receiving thread")
|
| 263 |
+
logger.error("%s", e)
|
| 264 |
+
logger.error("My device: %s", self.device)
|
| 265 |
+
import traceback
|
| 266 |
+
traceback.print_exc()
|
| 267 |
+
raise e
|
| 268 |
+
|
| 269 |
+
return tensor
|
| 270 |
+
|
| 271 |
+
def close(self):
|
| 272 |
+
"""
|
| 273 |
+
Close the pipe and release associated resources.
|
| 274 |
+
"""
|
| 275 |
+
if hasattr(self,
|
| 276 |
+
"transport_thread") and self.transport_thread is not None:
|
| 277 |
+
self.transport_thread.shutdown()
|
.venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/kv_transfer_agent.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
"""A centralized entrypoint to perform distributed KV cache transfer.
|
| 3 |
+
|
| 4 |
+
This implementation is a shim wrapper on two APIs exposed by `kv_connector`:
|
| 5 |
+
1. `send_kv_caches_and_hidden_states`
|
| 6 |
+
2. `recv_kv_caches_and_hidden_states
|
| 7 |
+
"""
|
| 8 |
+
from typing import TYPE_CHECKING, List, Tuple, Union
|
| 9 |
+
|
| 10 |
+
if TYPE_CHECKING:
|
| 11 |
+
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
|
| 12 |
+
from vllm.config import VllmConfig
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
|
| 16 |
+
from vllm.distributed.kv_transfer.kv_connector.factory import (
|
| 17 |
+
KVConnectorFactory)
|
| 18 |
+
from vllm.logger import init_logger
|
| 19 |
+
from vllm.sequence import IntermediateTensors
|
| 20 |
+
|
| 21 |
+
logger = init_logger(__name__)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class KVTransferAgent:
|
| 25 |
+
"""
|
| 26 |
+
A class designated for distributed KV transfer
|
| 27 |
+
|
| 28 |
+
Target use cases:
|
| 29 |
+
1. Disaggregated prefill
|
| 30 |
+
2. Remote KV cache storage
|
| 31 |
+
"""
|
| 32 |
+
|
| 33 |
+
def __init__(
|
| 34 |
+
self,
|
| 35 |
+
rank: int,
|
| 36 |
+
local_rank: int,
|
| 37 |
+
config: "VllmConfig",
|
| 38 |
+
):
|
| 39 |
+
|
| 40 |
+
self.config = config
|
| 41 |
+
|
| 42 |
+
if config.kv_transfer_config is None:
|
| 43 |
+
raise ValueError("KVTransferConfig is not set in the VllmConfig,"
|
| 44 |
+
" cannot initialize KVConnector.")
|
| 45 |
+
|
| 46 |
+
assert self.config.kv_transfer_config.is_kv_transfer_instance, "KV"\
|
| 47 |
+
"TransferAgent should only be used when kv_connector is set."
|
| 48 |
+
|
| 49 |
+
self.connector = KVConnectorFactory.create_connector(
|
| 50 |
+
rank, local_rank, config)
|
| 51 |
+
|
| 52 |
+
def send_kv_caches_and_hidden_states(
|
| 53 |
+
self,
|
| 54 |
+
model_executable: torch.nn.Module,
|
| 55 |
+
model_input: "ModelInputForGPUWithSamplingMetadata",
|
| 56 |
+
kv_caches: List[torch.Tensor],
|
| 57 |
+
hidden_or_intermediate_states: Union[torch.Tensor,
|
| 58 |
+
IntermediateTensors],
|
| 59 |
+
) -> None:
|
| 60 |
+
|
| 61 |
+
self.connector.send_kv_caches_and_hidden_states(
|
| 62 |
+
model_executable, model_input, kv_caches,
|
| 63 |
+
hidden_or_intermediate_states)
|
| 64 |
+
|
| 65 |
+
def close(self) -> None:
|
| 66 |
+
self.connector.close()
|
| 67 |
+
|
| 68 |
+
def recv_kv_caches_and_hidden_states(
|
| 69 |
+
self, model_executable: torch.nn.Module,
|
| 70 |
+
model_input: "ModelInputForGPUWithSamplingMetadata",
|
| 71 |
+
kv_caches: List[torch.Tensor]
|
| 72 |
+
) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool,
|
| 73 |
+
"ModelInputForGPUWithSamplingMetadata"]:
|
| 74 |
+
|
| 75 |
+
return self.connector.recv_kv_caches_and_hidden_states(
|
| 76 |
+
model_executable, model_input, kv_caches)
|
.venv/lib/python3.11/site-packages/vllm/distributed/parallel_state.py
ADDED
|
@@ -0,0 +1,1285 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
# Copyright 2023 The vLLM team.
|
| 4 |
+
# Adapted from
|
| 5 |
+
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
|
| 6 |
+
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
| 7 |
+
"""vLLM distributed state.
|
| 8 |
+
It takes over the control of the distributed environment from PyTorch.
|
| 9 |
+
The typical workflow is:
|
| 10 |
+
|
| 11 |
+
- call `init_distributed_environment` to initialize the distributed environment.
|
| 12 |
+
- call `initialize_model_parallel` or `ensure_model_parallel_initialized` to
|
| 13 |
+
initialize the model parallel groups.
|
| 14 |
+
|
| 15 |
+
- any code dealing with the distributed stuff
|
| 16 |
+
|
| 17 |
+
- call `destroy_model_parallel` to destroy the model parallel groups.
|
| 18 |
+
- call `destroy_distributed_environment` to destroy the distributed environment.
|
| 19 |
+
|
| 20 |
+
If you only need to use the distributed environment without model/pipeline
|
| 21 |
+
parallelism, you can skip the model parallel initialization and destruction
|
| 22 |
+
steps.
|
| 23 |
+
"""
|
| 24 |
+
import contextlib
|
| 25 |
+
import gc
|
| 26 |
+
import pickle
|
| 27 |
+
import weakref
|
| 28 |
+
from collections import namedtuple
|
| 29 |
+
from contextlib import contextmanager, nullcontext
|
| 30 |
+
from dataclasses import dataclass
|
| 31 |
+
from multiprocessing import shared_memory
|
| 32 |
+
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
|
| 33 |
+
Union)
|
| 34 |
+
from unittest.mock import patch
|
| 35 |
+
|
| 36 |
+
import torch
|
| 37 |
+
import torch.distributed
|
| 38 |
+
from torch.distributed import Backend, ProcessGroup
|
| 39 |
+
|
| 40 |
+
import vllm.distributed.kv_transfer.kv_transfer_agent as kv_transfer
|
| 41 |
+
import vllm.envs as envs
|
| 42 |
+
from vllm.distributed.utils import StatelessProcessGroup
|
| 43 |
+
from vllm.logger import init_logger
|
| 44 |
+
from vllm.utils import direct_register_custom_op, supports_custom_op
|
| 45 |
+
|
| 46 |
+
if TYPE_CHECKING:
|
| 47 |
+
from vllm.config import VllmConfig
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
@dataclass
|
| 51 |
+
class GraphCaptureContext:
|
| 52 |
+
stream: torch.cuda.Stream
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def _split_tensor_dict(
|
| 59 |
+
tensor_dict: Dict[str, Union[torch.Tensor, Any]]
|
| 60 |
+
) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]:
|
| 61 |
+
"""Split the tensor dictionary into two parts:
|
| 62 |
+
1. A list of (key, value) pairs. If the value is a tensor, it is replaced
|
| 63 |
+
by its metadata.
|
| 64 |
+
2. A list of tensors.
|
| 65 |
+
"""
|
| 66 |
+
metadata_list: List[Tuple[str, Any]] = []
|
| 67 |
+
tensor_list: List[torch.Tensor] = []
|
| 68 |
+
for key, value in tensor_dict.items():
|
| 69 |
+
if isinstance(value, torch.Tensor):
|
| 70 |
+
# Note: we cannot use `value.device` here,
|
| 71 |
+
# because it contains not only the device type but also the device
|
| 72 |
+
# index (e.g. "cuda:0"). We only need the device type.
|
| 73 |
+
# receiving side will set the device index.
|
| 74 |
+
device = value.device.type
|
| 75 |
+
metadata_list.append(
|
| 76 |
+
(key, TensorMetadata(device, value.dtype, value.size())))
|
| 77 |
+
tensor_list.append(value)
|
| 78 |
+
else:
|
| 79 |
+
metadata_list.append((key, value))
|
| 80 |
+
return metadata_list, tensor_list
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
_group_name_counter: Dict[str, int] = {}
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def _get_unique_name(name: str) -> str:
|
| 87 |
+
"""Get a unique name for the group.
|
| 88 |
+
Example:
|
| 89 |
+
_get_unique_name("tp") -> "tp:0"
|
| 90 |
+
_get_unique_name("tp") -> "tp:1"
|
| 91 |
+
"""
|
| 92 |
+
if name not in _group_name_counter:
|
| 93 |
+
_group_name_counter[name] = 0
|
| 94 |
+
newname = f"{name}:{_group_name_counter[name]}"
|
| 95 |
+
_group_name_counter[name] += 1
|
| 96 |
+
return newname
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
_groups: Dict[str, Callable[[], Optional["GroupCoordinator"]]] = {}
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def _register_group(group: "GroupCoordinator") -> None:
|
| 103 |
+
_groups[group.unique_name] = weakref.ref(group)
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def all_reduce(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
|
| 107 |
+
assert group_name in _groups, f"Group {group_name} is not found."
|
| 108 |
+
group = _groups[group_name]()
|
| 109 |
+
if group is None:
|
| 110 |
+
raise ValueError(f"Group {group_name} is destroyed.")
|
| 111 |
+
return group._all_reduce_out_place(tensor)
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def all_reduce_fake(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
|
| 115 |
+
return torch.empty_like(tensor)
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
if supports_custom_op():
|
| 119 |
+
direct_register_custom_op(
|
| 120 |
+
op_name="all_reduce",
|
| 121 |
+
op_func=all_reduce,
|
| 122 |
+
mutates_args=[],
|
| 123 |
+
fake_impl=all_reduce_fake,
|
| 124 |
+
)
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
class GroupCoordinator:
|
| 128 |
+
"""
|
| 129 |
+
PyTorch ProcessGroup wrapper for a group of processes.
|
| 130 |
+
PyTorch ProcessGroup is bound to one specific communication backend,
|
| 131 |
+
e.g. NCCL, Gloo, MPI, etc.
|
| 132 |
+
GroupCoordinator takes charge of all the communication operations among
|
| 133 |
+
the processes in the group. It can route the communication to
|
| 134 |
+
a specific implementation (e.g. switch allreduce implementation
|
| 135 |
+
based on the tensor size and cuda graph mode).
|
| 136 |
+
"""
|
| 137 |
+
|
| 138 |
+
# available attributes:
|
| 139 |
+
rank: int # global rank
|
| 140 |
+
ranks: List[int] # global ranks in the group
|
| 141 |
+
world_size: int # size of the group
|
| 142 |
+
# difference between `local_rank` and `rank_in_group`:
|
| 143 |
+
# if we have a group of size 4 across two nodes:
|
| 144 |
+
# Process | Node | Rank | Local Rank | Rank in Group
|
| 145 |
+
# 0 | 0 | 0 | 0 | 0
|
| 146 |
+
# 1 | 0 | 1 | 1 | 1
|
| 147 |
+
# 2 | 1 | 2 | 0 | 2
|
| 148 |
+
# 3 | 1 | 3 | 1 | 3
|
| 149 |
+
local_rank: int # local rank used to assign devices
|
| 150 |
+
rank_in_group: int # rank inside the group
|
| 151 |
+
cpu_group: ProcessGroup # group for CPU communication
|
| 152 |
+
device_group: ProcessGroup # group for device communication
|
| 153 |
+
use_pynccl: bool # a hint of whether to use PyNccl
|
| 154 |
+
use_custom_allreduce: bool # a hint of whether to use CustomAllreduce
|
| 155 |
+
# communicators are only created for world size > 1
|
| 156 |
+
pynccl_comm: Optional[Any] # PyNccl communicator
|
| 157 |
+
ca_comm: Optional[Any] # Custom allreduce communicator
|
| 158 |
+
mq_broadcaster: Optional[Any] # shared memory broadcaster
|
| 159 |
+
|
| 160 |
+
def __init__(
|
| 161 |
+
self,
|
| 162 |
+
group_ranks: List[List[int]],
|
| 163 |
+
local_rank: int,
|
| 164 |
+
torch_distributed_backend: Union[str, Backend],
|
| 165 |
+
use_pynccl: bool,
|
| 166 |
+
use_custom_allreduce: bool,
|
| 167 |
+
use_tpu_communicator: bool,
|
| 168 |
+
use_hpu_communicator: bool,
|
| 169 |
+
use_xpu_communicator: bool,
|
| 170 |
+
use_message_queue_broadcaster: bool = False,
|
| 171 |
+
group_name: Optional[str] = None,
|
| 172 |
+
):
|
| 173 |
+
group_name = group_name or "anonymous"
|
| 174 |
+
self.unique_name = _get_unique_name(group_name)
|
| 175 |
+
_register_group(self)
|
| 176 |
+
|
| 177 |
+
self.rank = torch.distributed.get_rank()
|
| 178 |
+
self.local_rank = local_rank
|
| 179 |
+
self.device_group = None
|
| 180 |
+
self.cpu_group = None
|
| 181 |
+
|
| 182 |
+
for ranks in group_ranks:
|
| 183 |
+
device_group = torch.distributed.new_group(
|
| 184 |
+
ranks, backend=torch_distributed_backend)
|
| 185 |
+
# a group with `gloo` backend, to allow direct coordination between
|
| 186 |
+
# processes through the CPU.
|
| 187 |
+
cpu_group = torch.distributed.new_group(ranks, backend="gloo")
|
| 188 |
+
if self.rank in ranks:
|
| 189 |
+
self.ranks = ranks
|
| 190 |
+
self.world_size = len(ranks)
|
| 191 |
+
self.rank_in_group = ranks.index(self.rank)
|
| 192 |
+
self.device_group = device_group
|
| 193 |
+
self.cpu_group = cpu_group
|
| 194 |
+
|
| 195 |
+
assert self.cpu_group is not None
|
| 196 |
+
assert self.device_group is not None
|
| 197 |
+
|
| 198 |
+
from vllm.platforms import current_platform
|
| 199 |
+
if current_platform.is_cuda_alike():
|
| 200 |
+
self.device = torch.device(f"cuda:{local_rank}")
|
| 201 |
+
else:
|
| 202 |
+
self.device = torch.device("cpu")
|
| 203 |
+
|
| 204 |
+
self.use_pynccl = use_pynccl
|
| 205 |
+
self.use_custom_allreduce = use_custom_allreduce
|
| 206 |
+
self.use_tpu_communicator = use_tpu_communicator
|
| 207 |
+
self.use_hpu_communicator = use_hpu_communicator
|
| 208 |
+
self.use_xpu_communicator = use_xpu_communicator
|
| 209 |
+
|
| 210 |
+
# lazy import to avoid documentation build error
|
| 211 |
+
from vllm.distributed.device_communicators.custom_all_reduce import (
|
| 212 |
+
CustomAllreduce)
|
| 213 |
+
from vllm.distributed.device_communicators.pynccl import (
|
| 214 |
+
PyNcclCommunicator)
|
| 215 |
+
|
| 216 |
+
self.pynccl_comm: Optional[PyNcclCommunicator] = None
|
| 217 |
+
if use_pynccl and self.world_size > 1:
|
| 218 |
+
self.pynccl_comm = PyNcclCommunicator(
|
| 219 |
+
group=self.cpu_group,
|
| 220 |
+
device=self.device,
|
| 221 |
+
)
|
| 222 |
+
|
| 223 |
+
self.ca_comm: Optional[CustomAllreduce] = None
|
| 224 |
+
if use_custom_allreduce and self.world_size > 1:
|
| 225 |
+
# Initialize a custom fast all-reduce implementation.
|
| 226 |
+
self.ca_comm = CustomAllreduce(
|
| 227 |
+
group=self.cpu_group,
|
| 228 |
+
device=self.device,
|
| 229 |
+
)
|
| 230 |
+
|
| 231 |
+
from vllm.distributed.device_communicators.tpu_communicator import (
|
| 232 |
+
TpuCommunicator)
|
| 233 |
+
self.tpu_communicator: Optional[TpuCommunicator] = None
|
| 234 |
+
if use_tpu_communicator and self.world_size > 1:
|
| 235 |
+
self.tpu_communicator = TpuCommunicator(group=self.cpu_group)
|
| 236 |
+
|
| 237 |
+
from vllm.distributed.device_communicators.hpu_communicator import (
|
| 238 |
+
HpuCommunicator)
|
| 239 |
+
self.hpu_communicator: Optional[HpuCommunicator]
|
| 240 |
+
if use_hpu_communicator and self.world_size > 1:
|
| 241 |
+
self.hpu_communicator = HpuCommunicator(group=self.device_group)
|
| 242 |
+
|
| 243 |
+
from vllm.distributed.device_communicators.xpu_communicator import (
|
| 244 |
+
XpuCommunicator)
|
| 245 |
+
self.xpu_communicator: Optional[XpuCommunicator]
|
| 246 |
+
if use_xpu_communicator and self.world_size > 1:
|
| 247 |
+
self.xpu_communicator = XpuCommunicator(group=self.device_group)
|
| 248 |
+
|
| 249 |
+
from vllm.distributed.device_communicators.shm_broadcast import (
|
| 250 |
+
MessageQueue)
|
| 251 |
+
self.mq_broadcaster: Optional[MessageQueue] = None
|
| 252 |
+
if use_message_queue_broadcaster and self.world_size > 1:
|
| 253 |
+
self.mq_broadcaster = MessageQueue.create_from_process_group(
|
| 254 |
+
self.cpu_group, 1 << 22, 6)
|
| 255 |
+
|
| 256 |
+
@property
|
| 257 |
+
def first_rank(self):
|
| 258 |
+
"""Return the global rank of the first process in the group"""
|
| 259 |
+
return self.ranks[0]
|
| 260 |
+
|
| 261 |
+
@property
|
| 262 |
+
def last_rank(self):
|
| 263 |
+
"""Return the global rank of the last process in the group"""
|
| 264 |
+
return self.ranks[-1]
|
| 265 |
+
|
| 266 |
+
@property
|
| 267 |
+
def is_first_rank(self):
|
| 268 |
+
"""Return whether the caller is the first process in the group"""
|
| 269 |
+
return self.rank == self.first_rank
|
| 270 |
+
|
| 271 |
+
@property
|
| 272 |
+
def is_last_rank(self):
|
| 273 |
+
"""Return whether the caller is the last process in the group"""
|
| 274 |
+
return self.rank == self.last_rank
|
| 275 |
+
|
| 276 |
+
@property
|
| 277 |
+
def next_rank(self):
|
| 278 |
+
"""Return the global rank of the process that follows the caller"""
|
| 279 |
+
rank_in_group = self.rank_in_group
|
| 280 |
+
world_size = self.world_size
|
| 281 |
+
return self.ranks[(rank_in_group + 1) % world_size]
|
| 282 |
+
|
| 283 |
+
@property
|
| 284 |
+
def prev_rank(self):
|
| 285 |
+
"""Return the global rank of the process that precedes the caller"""
|
| 286 |
+
rank_in_group = self.rank_in_group
|
| 287 |
+
world_size = self.world_size
|
| 288 |
+
return self.ranks[(rank_in_group - 1) % world_size]
|
| 289 |
+
|
| 290 |
+
@contextmanager
|
| 291 |
+
def graph_capture(
|
| 292 |
+
self, graph_capture_context: Optional[GraphCaptureContext] = None):
|
| 293 |
+
if graph_capture_context is None:
|
| 294 |
+
stream = torch.cuda.Stream()
|
| 295 |
+
graph_capture_context = GraphCaptureContext(stream)
|
| 296 |
+
else:
|
| 297 |
+
stream = graph_capture_context.stream
|
| 298 |
+
|
| 299 |
+
ca_comm = self.ca_comm
|
| 300 |
+
maybe_ca_context = nullcontext(
|
| 301 |
+
) if ca_comm is None else ca_comm.capture()
|
| 302 |
+
|
| 303 |
+
# ensure all initialization operations complete before attempting to
|
| 304 |
+
# capture the graph on another stream
|
| 305 |
+
curr_stream = torch.cuda.current_stream()
|
| 306 |
+
if curr_stream != stream:
|
| 307 |
+
stream.wait_stream(curr_stream)
|
| 308 |
+
|
| 309 |
+
with torch.cuda.stream(stream), maybe_ca_context:
|
| 310 |
+
yield graph_capture_context
|
| 311 |
+
|
| 312 |
+
def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
|
| 313 |
+
"""
|
| 314 |
+
User-facing all-reduce function before we actually call the
|
| 315 |
+
all-reduce operation.
|
| 316 |
+
|
| 317 |
+
We need this because Dynamo does not support passing an arbitrary
|
| 318 |
+
object (`self` in this case) to a custom op. We need to pass the
|
| 319 |
+
group name as a string, and then look up the group coordinator from
|
| 320 |
+
the group name, dispatch the all-reduce operation to the group
|
| 321 |
+
coordinator.
|
| 322 |
+
|
| 323 |
+
In addition, PyTorch custom ops do not support mutation or returning
|
| 324 |
+
a new tensor in the same op. So we always make the all-reduce operation
|
| 325 |
+
out-of-place.
|
| 326 |
+
"""
|
| 327 |
+
# Bypass the function if we are using only 1 GPU.
|
| 328 |
+
if self.world_size == 1:
|
| 329 |
+
return input_
|
| 330 |
+
|
| 331 |
+
if input_.is_cpu:
|
| 332 |
+
try:
|
| 333 |
+
import intel_extension_for_pytorch as ipex
|
| 334 |
+
ipex.distributed.all_reduce(input_, group=self.device_group)
|
| 335 |
+
return input_
|
| 336 |
+
except ImportError:
|
| 337 |
+
"""
|
| 338 |
+
Intel IPEX not found. Falling back to PyTorch native
|
| 339 |
+
all_reduce for CPU
|
| 340 |
+
"""
|
| 341 |
+
torch.distributed.all_reduce(input_, group=self.device_group)
|
| 342 |
+
return input_
|
| 343 |
+
|
| 344 |
+
if self.tpu_communicator is not None and \
|
| 345 |
+
not self.tpu_communicator.disabled:
|
| 346 |
+
# TPU handles Dynamo with its own logic.
|
| 347 |
+
return self.tpu_communicator.all_reduce(input_)
|
| 348 |
+
|
| 349 |
+
if self.hpu_communicator is not None and \
|
| 350 |
+
not self.hpu_communicator.disabled:
|
| 351 |
+
return self.hpu_communicator.all_reduce(input_)
|
| 352 |
+
|
| 353 |
+
if self.xpu_communicator is not None and \
|
| 354 |
+
not self.xpu_communicator.disabled:
|
| 355 |
+
return self.xpu_communicator.all_reduce(input_)
|
| 356 |
+
|
| 357 |
+
return torch.ops.vllm.all_reduce(input_, group_name=self.unique_name)
|
| 358 |
+
|
| 359 |
+
def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor:
|
| 360 |
+
# always try custom allreduce first,
|
| 361 |
+
# and then pynccl.
|
| 362 |
+
ca_comm = self.ca_comm
|
| 363 |
+
if ca_comm is not None and not ca_comm.disabled and \
|
| 364 |
+
ca_comm.should_custom_ar(input_):
|
| 365 |
+
out = ca_comm.custom_all_reduce(input_)
|
| 366 |
+
assert out is not None
|
| 367 |
+
return out
|
| 368 |
+
pynccl_comm = self.pynccl_comm
|
| 369 |
+
assert pynccl_comm is not None
|
| 370 |
+
out = pynccl_comm.all_reduce(input_)
|
| 371 |
+
if out is None:
|
| 372 |
+
# fall back to the default all-reduce using PyTorch.
|
| 373 |
+
# this usually happens during testing.
|
| 374 |
+
# when we run the model, allreduce only happens for the TP
|
| 375 |
+
# group, where we always have either custom allreduce or pynccl.
|
| 376 |
+
out = input_.clone()
|
| 377 |
+
torch.distributed.all_reduce(out, group=self.device_group)
|
| 378 |
+
return out
|
| 379 |
+
|
| 380 |
+
def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
|
| 381 |
+
world_size = self.world_size
|
| 382 |
+
# Bypass the function if we are using only 1 GPU.
|
| 383 |
+
if world_size == 1:
|
| 384 |
+
return input_
|
| 385 |
+
assert -input_.dim() <= dim < input_.dim(), (
|
| 386 |
+
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
|
| 387 |
+
|
| 388 |
+
# For TPUs, use TPU communicator.
|
| 389 |
+
tpu_comm = self.tpu_communicator
|
| 390 |
+
if tpu_comm is not None and not tpu_comm.disabled:
|
| 391 |
+
return tpu_comm.all_gather(input_, dim)
|
| 392 |
+
|
| 393 |
+
# For HPUs, use HPU communicator.
|
| 394 |
+
hpu_comm = self.hpu_communicator
|
| 395 |
+
if hpu_comm is not None and not hpu_comm.disabled:
|
| 396 |
+
return hpu_comm.all_gather(input_, dim)
|
| 397 |
+
|
| 398 |
+
if dim < 0:
|
| 399 |
+
# Convert negative dim to positive.
|
| 400 |
+
dim += input_.dim()
|
| 401 |
+
input_size = input_.size()
|
| 402 |
+
# NOTE: we have to use concat-style all-gather here,
|
| 403 |
+
# stack-style all-gather has compatibility issues with
|
| 404 |
+
# torch.compile . see https://github.com/pytorch/pytorch/issues/138795
|
| 405 |
+
output_size = (input_size[0] * world_size, ) + input_size[1:]
|
| 406 |
+
# Allocate output tensor.
|
| 407 |
+
output_tensor = torch.empty(output_size,
|
| 408 |
+
dtype=input_.dtype,
|
| 409 |
+
device=input_.device)
|
| 410 |
+
# All-gather.
|
| 411 |
+
torch.distributed.all_gather_into_tensor(output_tensor,
|
| 412 |
+
input_,
|
| 413 |
+
group=self.device_group)
|
| 414 |
+
# Reshape
|
| 415 |
+
output_tensor = output_tensor.reshape((world_size, ) + input_size)
|
| 416 |
+
output_tensor = output_tensor.movedim(0, dim)
|
| 417 |
+
output_tensor = output_tensor.reshape(input_size[:dim] +
|
| 418 |
+
(world_size *
|
| 419 |
+
input_size[dim], ) +
|
| 420 |
+
input_size[dim + 1:])
|
| 421 |
+
return output_tensor
|
| 422 |
+
|
| 423 |
+
def gather(self,
|
| 424 |
+
input_: torch.Tensor,
|
| 425 |
+
dst: int = 0,
|
| 426 |
+
dim: int = -1) -> Optional[torch.Tensor]:
|
| 427 |
+
"""
|
| 428 |
+
NOTE: We assume that the input tensor is on the same device across
|
| 429 |
+
all the ranks.
|
| 430 |
+
NOTE: `dst` is the local rank of the destination rank.
|
| 431 |
+
"""
|
| 432 |
+
world_size = self.world_size
|
| 433 |
+
# Bypass the function if we are using only 1 GPU.
|
| 434 |
+
if world_size == 1:
|
| 435 |
+
return input_
|
| 436 |
+
assert -input_.dim() <= dim < input_.dim(), (
|
| 437 |
+
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
|
| 438 |
+
if dim < 0:
|
| 439 |
+
# Convert negative dim to positive.
|
| 440 |
+
dim += input_.dim()
|
| 441 |
+
if self.xpu_communicator is not None and \
|
| 442 |
+
not self.xpu_communicator.disabled:
|
| 443 |
+
return self.xpu_communicator.gather(input_, self.rank_in_group,
|
| 444 |
+
dst, dim)
|
| 445 |
+
# Allocate output tensor.
|
| 446 |
+
if self.rank_in_group == dst:
|
| 447 |
+
gather_list = [torch.empty_like(input_) for _ in range(world_size)]
|
| 448 |
+
else:
|
| 449 |
+
gather_list = None
|
| 450 |
+
# Gather.
|
| 451 |
+
torch.distributed.gather(input_,
|
| 452 |
+
gather_list,
|
| 453 |
+
dst=self.ranks[dst],
|
| 454 |
+
group=self.device_group)
|
| 455 |
+
if self.rank_in_group == dst:
|
| 456 |
+
output_tensor = torch.cat(gather_list, dim=dim)
|
| 457 |
+
else:
|
| 458 |
+
output_tensor = None
|
| 459 |
+
return output_tensor
|
| 460 |
+
|
| 461 |
+
def broadcast(self, input_: torch.Tensor, src: int = 0):
|
| 462 |
+
"""Broadcast the input tensor.
|
| 463 |
+
NOTE: `src` is the local rank of the source rank.
|
| 464 |
+
"""
|
| 465 |
+
assert src < self.world_size, f"Invalid src rank ({src})"
|
| 466 |
+
|
| 467 |
+
# Bypass the function if we are using only 1 GPU.
|
| 468 |
+
if self.world_size == 1:
|
| 469 |
+
return input_
|
| 470 |
+
# Broadcast.
|
| 471 |
+
torch.distributed.broadcast(input_,
|
| 472 |
+
src=self.ranks[src],
|
| 473 |
+
group=self.device_group)
|
| 474 |
+
return input_
|
| 475 |
+
|
| 476 |
+
def broadcast_object(self, obj: Optional[Any] = None, src: int = 0):
|
| 477 |
+
"""Broadcast the input object.
|
| 478 |
+
NOTE: `src` is the local rank of the source rank.
|
| 479 |
+
"""
|
| 480 |
+
assert src < self.world_size, f"Invalid src rank ({src})"
|
| 481 |
+
|
| 482 |
+
# Bypass the function if we are using only 1 GPU.
|
| 483 |
+
if self.world_size == 1:
|
| 484 |
+
return obj
|
| 485 |
+
if self.mq_broadcaster is not None:
|
| 486 |
+
assert src == 0, "Message queue broadcaster only supports src=0"
|
| 487 |
+
return self.mq_broadcaster.broadcast_object(obj)
|
| 488 |
+
if self.rank_in_group == src:
|
| 489 |
+
torch.distributed.broadcast_object_list([obj],
|
| 490 |
+
src=self.ranks[src],
|
| 491 |
+
group=self.cpu_group)
|
| 492 |
+
return obj
|
| 493 |
+
else:
|
| 494 |
+
recv = [None]
|
| 495 |
+
torch.distributed.broadcast_object_list(recv,
|
| 496 |
+
src=self.ranks[src],
|
| 497 |
+
group=self.cpu_group)
|
| 498 |
+
return recv[0]
|
| 499 |
+
|
| 500 |
+
def broadcast_object_list(self,
|
| 501 |
+
obj_list: List[Any],
|
| 502 |
+
src: int = 0,
|
| 503 |
+
group: Optional[ProcessGroup] = None):
|
| 504 |
+
"""Broadcast the input object list.
|
| 505 |
+
NOTE: `src` is the local rank of the source rank.
|
| 506 |
+
"""
|
| 507 |
+
assert src < self.world_size, f"Invalid src rank ({src})"
|
| 508 |
+
|
| 509 |
+
# Bypass the function if we are using only 1 GPU.
|
| 510 |
+
if self.world_size == 1:
|
| 511 |
+
return obj_list
|
| 512 |
+
# Broadcast.
|
| 513 |
+
torch.distributed.broadcast_object_list(obj_list,
|
| 514 |
+
src=self.ranks[src],
|
| 515 |
+
group=self.device_group)
|
| 516 |
+
return obj_list
|
| 517 |
+
|
| 518 |
+
def send_object(self, obj: Any, dst: int) -> None:
|
| 519 |
+
"""Send the input object list to the destination rank."""
|
| 520 |
+
"""NOTE: `dst` is the local rank of the destination rank."""
|
| 521 |
+
|
| 522 |
+
assert dst < self.world_size, f"Invalid dst rank ({dst})"
|
| 523 |
+
|
| 524 |
+
assert dst != self.rank_in_group, (
|
| 525 |
+
"Invalid destination rank. Destination rank is the same "
|
| 526 |
+
"as the current rank.")
|
| 527 |
+
|
| 528 |
+
# Serialize object to tensor and get the size as well
|
| 529 |
+
object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8)
|
| 530 |
+
|
| 531 |
+
size_tensor = torch.tensor([object_tensor.numel()],
|
| 532 |
+
dtype=torch.long,
|
| 533 |
+
device="cpu")
|
| 534 |
+
|
| 535 |
+
# Send object size
|
| 536 |
+
|
| 537 |
+
torch.distributed.send(size_tensor,
|
| 538 |
+
dst=self.ranks[dst],
|
| 539 |
+
group=self.cpu_group)
|
| 540 |
+
|
| 541 |
+
# Send object
|
| 542 |
+
torch.distributed.send(object_tensor,
|
| 543 |
+
dst=self.ranks[dst],
|
| 544 |
+
group=self.cpu_group)
|
| 545 |
+
|
| 546 |
+
return None
|
| 547 |
+
|
| 548 |
+
def recv_object(self, src: int) -> Any:
|
| 549 |
+
"""Receive the input object list from the source rank."""
|
| 550 |
+
"""NOTE: `src` is the local rank of the source rank."""
|
| 551 |
+
|
| 552 |
+
assert src < self.world_size, f"Invalid src rank ({src})"
|
| 553 |
+
|
| 554 |
+
assert src != self.rank_in_group, (
|
| 555 |
+
"Invalid source rank. Source rank is the same as the current rank."
|
| 556 |
+
)
|
| 557 |
+
|
| 558 |
+
size_tensor = torch.empty(1, dtype=torch.long, device="cpu")
|
| 559 |
+
|
| 560 |
+
# Receive object size
|
| 561 |
+
rank_size = torch.distributed.recv(size_tensor,
|
| 562 |
+
src=self.ranks[src],
|
| 563 |
+
group=self.cpu_group)
|
| 564 |
+
|
| 565 |
+
# Tensor to receive serialized objects into.
|
| 566 |
+
object_tensor = torch.empty( # type: ignore[call-overload]
|
| 567 |
+
size_tensor.item(), # type: ignore[arg-type]
|
| 568 |
+
dtype=torch.uint8,
|
| 569 |
+
device="cpu")
|
| 570 |
+
|
| 571 |
+
rank_object = torch.distributed.recv(object_tensor,
|
| 572 |
+
src=self.ranks[src],
|
| 573 |
+
group=self.cpu_group)
|
| 574 |
+
|
| 575 |
+
assert rank_object == rank_size, (
|
| 576 |
+
"Received object sender rank does not match the size sender rank.")
|
| 577 |
+
|
| 578 |
+
obj = pickle.loads(object_tensor.numpy().tobytes())
|
| 579 |
+
|
| 580 |
+
return obj
|
| 581 |
+
|
| 582 |
+
def broadcast_tensor_dict(
|
| 583 |
+
self,
|
| 584 |
+
tensor_dict: Optional[Dict[str, Union[torch.Tensor, Any]]] = None,
|
| 585 |
+
src: int = 0,
|
| 586 |
+
group: Optional[ProcessGroup] = None,
|
| 587 |
+
metadata_group: Optional[ProcessGroup] = None
|
| 588 |
+
) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
|
| 589 |
+
"""Broadcast the input tensor dictionary.
|
| 590 |
+
NOTE: `src` is the local rank of the source rank.
|
| 591 |
+
"""
|
| 592 |
+
# Bypass the function if we are using only 1 GPU.
|
| 593 |
+
if (not torch.distributed.is_initialized() or self.world_size == 1):
|
| 594 |
+
return tensor_dict
|
| 595 |
+
|
| 596 |
+
group = self.device_group
|
| 597 |
+
metadata_group = self.cpu_group
|
| 598 |
+
assert src < self.world_size, f"Invalid src rank ({src})"
|
| 599 |
+
|
| 600 |
+
rank_in_group = self.rank_in_group
|
| 601 |
+
if rank_in_group == src:
|
| 602 |
+
metadata_list: List[Tuple[Any, Any]] = []
|
| 603 |
+
assert isinstance(
|
| 604 |
+
tensor_dict,
|
| 605 |
+
dict), (f"Expecting a dictionary, got {type(tensor_dict)}")
|
| 606 |
+
metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
|
| 607 |
+
# `metadata_list` lives in CPU memory.
|
| 608 |
+
# `broadcast_object_list` has serialization & deserialization,
|
| 609 |
+
# all happening on CPU. Therefore, we can use the CPU group.
|
| 610 |
+
self.broadcast_object(metadata_list, src=src)
|
| 611 |
+
async_handles = []
|
| 612 |
+
for tensor in tensor_list:
|
| 613 |
+
if tensor.numel() == 0:
|
| 614 |
+
# Skip broadcasting empty tensors.
|
| 615 |
+
continue
|
| 616 |
+
if tensor.is_cpu:
|
| 617 |
+
# use metadata_group for CPU tensors
|
| 618 |
+
handle = torch.distributed.broadcast(tensor,
|
| 619 |
+
src=self.ranks[src],
|
| 620 |
+
group=metadata_group,
|
| 621 |
+
async_op=True)
|
| 622 |
+
else:
|
| 623 |
+
# use group for GPU tensors
|
| 624 |
+
handle = torch.distributed.broadcast(tensor,
|
| 625 |
+
src=self.ranks[src],
|
| 626 |
+
group=group,
|
| 627 |
+
async_op=True)
|
| 628 |
+
async_handles.append(handle)
|
| 629 |
+
for async_handle in async_handles:
|
| 630 |
+
async_handle.wait()
|
| 631 |
+
|
| 632 |
+
else:
|
| 633 |
+
metadata_list = self.broadcast_object(None, src=src)
|
| 634 |
+
tensor_dict = {}
|
| 635 |
+
async_handles = []
|
| 636 |
+
for key, value in metadata_list:
|
| 637 |
+
if isinstance(value, TensorMetadata):
|
| 638 |
+
tensor = torch.empty(value.size,
|
| 639 |
+
dtype=value.dtype,
|
| 640 |
+
device=value.device)
|
| 641 |
+
if tensor.numel() == 0:
|
| 642 |
+
# Skip broadcasting empty tensors.
|
| 643 |
+
tensor_dict[key] = tensor
|
| 644 |
+
continue
|
| 645 |
+
if tensor.is_cpu:
|
| 646 |
+
# use metadata_group for CPU tensors
|
| 647 |
+
handle = torch.distributed.broadcast(
|
| 648 |
+
tensor,
|
| 649 |
+
src=self.ranks[src],
|
| 650 |
+
group=metadata_group,
|
| 651 |
+
async_op=True)
|
| 652 |
+
else:
|
| 653 |
+
# use group for GPU tensors
|
| 654 |
+
handle = torch.distributed.broadcast(
|
| 655 |
+
tensor,
|
| 656 |
+
src=self.ranks[src],
|
| 657 |
+
group=group,
|
| 658 |
+
async_op=True)
|
| 659 |
+
async_handles.append(handle)
|
| 660 |
+
tensor_dict[key] = tensor
|
| 661 |
+
else:
|
| 662 |
+
tensor_dict[key] = value
|
| 663 |
+
for async_handle in async_handles:
|
| 664 |
+
async_handle.wait()
|
| 665 |
+
return tensor_dict
|
| 666 |
+
|
| 667 |
+
def send_tensor_dict(
|
| 668 |
+
self,
|
| 669 |
+
tensor_dict: Dict[str, Union[torch.Tensor, Any]],
|
| 670 |
+
dst: Optional[int] = None,
|
| 671 |
+
all_gather_group: Optional["GroupCoordinator"] = None,
|
| 672 |
+
) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
|
| 673 |
+
"""Send the input tensor dictionary.
|
| 674 |
+
NOTE: `dst` is the local rank of the source rank.
|
| 675 |
+
"""
|
| 676 |
+
# Bypass the function if we are using only 1 GPU.
|
| 677 |
+
if not torch.distributed.is_initialized() or self.world_size == 1:
|
| 678 |
+
return tensor_dict
|
| 679 |
+
|
| 680 |
+
all_gather_size = (1 if all_gather_group is None else
|
| 681 |
+
all_gather_group.world_size)
|
| 682 |
+
all_gather_rank = (0 if all_gather_group is None else
|
| 683 |
+
all_gather_group.rank_in_group)
|
| 684 |
+
|
| 685 |
+
group = self.device_group
|
| 686 |
+
metadata_group = self.cpu_group
|
| 687 |
+
|
| 688 |
+
if dst is None:
|
| 689 |
+
dst = (self.rank_in_group + 1) % self.world_size
|
| 690 |
+
assert dst < self.world_size, f"Invalid dst rank ({dst})"
|
| 691 |
+
|
| 692 |
+
metadata_list: List[Tuple[Any, Any]] = []
|
| 693 |
+
assert isinstance(
|
| 694 |
+
tensor_dict,
|
| 695 |
+
dict), f"Expecting a dictionary, got {type(tensor_dict)}"
|
| 696 |
+
metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
|
| 697 |
+
# `metadata_list` lives in CPU memory.
|
| 698 |
+
# `send_object_list` has serialization & deserialization,
|
| 699 |
+
# all happening on CPU. Therefore, we can use the CPU group.
|
| 700 |
+
self.send_object(metadata_list, dst=dst)
|
| 701 |
+
for tensor in tensor_list:
|
| 702 |
+
if tensor.numel() == 0:
|
| 703 |
+
# Skip sending empty tensors.
|
| 704 |
+
continue
|
| 705 |
+
|
| 706 |
+
# send-allgather: send only a slice, then do allgather.
|
| 707 |
+
if (all_gather_group is not None
|
| 708 |
+
and tensor.numel() % all_gather_size == 0):
|
| 709 |
+
tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]
|
| 710 |
+
|
| 711 |
+
if tensor.is_cpu:
|
| 712 |
+
# use metadata_group for CPU tensors
|
| 713 |
+
torch.distributed.send(tensor,
|
| 714 |
+
dst=self.ranks[dst],
|
| 715 |
+
group=metadata_group)
|
| 716 |
+
else:
|
| 717 |
+
# use group for GPU tensors
|
| 718 |
+
torch.distributed.send(tensor,
|
| 719 |
+
dst=self.ranks[dst],
|
| 720 |
+
group=group)
|
| 721 |
+
return None
|
| 722 |
+
|
| 723 |
+
def recv_tensor_dict(
|
| 724 |
+
self,
|
| 725 |
+
src: Optional[int] = None,
|
| 726 |
+
all_gather_group: Optional["GroupCoordinator"] = None,
|
| 727 |
+
) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
|
| 728 |
+
"""Recv the input tensor dictionary.
|
| 729 |
+
NOTE: `src` is the local rank of the source rank.
|
| 730 |
+
"""
|
| 731 |
+
# Bypass the function if we are using only 1 GPU.
|
| 732 |
+
if not torch.distributed.is_initialized() or self.world_size == 1:
|
| 733 |
+
return None
|
| 734 |
+
|
| 735 |
+
all_gather_size = (1 if all_gather_group is None else
|
| 736 |
+
all_gather_group.world_size)
|
| 737 |
+
all_gather_rank = (0 if all_gather_group is None else
|
| 738 |
+
all_gather_group.rank_in_group)
|
| 739 |
+
|
| 740 |
+
group = self.device_group
|
| 741 |
+
metadata_group = self.cpu_group
|
| 742 |
+
|
| 743 |
+
if src is None:
|
| 744 |
+
src = (self.rank_in_group - 1) % self.world_size
|
| 745 |
+
assert src < self.world_size, f"Invalid src rank ({src})"
|
| 746 |
+
|
| 747 |
+
recv_metadata_list = self.recv_object(src=src)
|
| 748 |
+
tensor_dict: Dict[str, Any] = {}
|
| 749 |
+
for key, value in recv_metadata_list:
|
| 750 |
+
if isinstance(value, TensorMetadata):
|
| 751 |
+
tensor = torch.empty(value.size,
|
| 752 |
+
dtype=value.dtype,
|
| 753 |
+
device=value.device)
|
| 754 |
+
if tensor.numel() == 0:
|
| 755 |
+
# Skip broadcasting empty tensors.
|
| 756 |
+
tensor_dict[key] = tensor
|
| 757 |
+
continue
|
| 758 |
+
|
| 759 |
+
# send-allgather: send only a slice, then do allgather.
|
| 760 |
+
use_all_gather = (all_gather_group is not None
|
| 761 |
+
and tensor.numel() % all_gather_size == 0)
|
| 762 |
+
|
| 763 |
+
if use_all_gather:
|
| 764 |
+
orig_shape = tensor.shape
|
| 765 |
+
tensor = tensor.reshape(all_gather_size,
|
| 766 |
+
-1)[all_gather_rank]
|
| 767 |
+
|
| 768 |
+
if tensor.is_cpu:
|
| 769 |
+
# use metadata_group for CPU tensors
|
| 770 |
+
torch.distributed.recv(tensor,
|
| 771 |
+
src=self.ranks[src],
|
| 772 |
+
group=metadata_group)
|
| 773 |
+
else:
|
| 774 |
+
# use group for GPU tensors
|
| 775 |
+
torch.distributed.recv(tensor,
|
| 776 |
+
src=self.ranks[src],
|
| 777 |
+
group=group)
|
| 778 |
+
if use_all_gather:
|
| 779 |
+
# do the allgather
|
| 780 |
+
tensor = all_gather_group.all_gather( # type: ignore
|
| 781 |
+
tensor, dim=0)
|
| 782 |
+
tensor = tensor.reshape(orig_shape)
|
| 783 |
+
|
| 784 |
+
tensor_dict[key] = tensor
|
| 785 |
+
else:
|
| 786 |
+
tensor_dict[key] = value
|
| 787 |
+
return tensor_dict
|
| 788 |
+
|
| 789 |
+
def barrier(self):
|
| 790 |
+
"""Barrier synchronization among the group.
|
| 791 |
+
NOTE: don't use `device_group` here! `barrier` in NCCL is
|
| 792 |
+
terrible because it is internally a broadcast operation with
|
| 793 |
+
secretly created GPU tensors. It is easy to mess up the current
|
| 794 |
+
device. Use the CPU group instead.
|
| 795 |
+
"""
|
| 796 |
+
torch.distributed.barrier(group=self.cpu_group)
|
| 797 |
+
|
| 798 |
+
def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
|
| 799 |
+
"""Sends a tensor to the destination rank in a non-blocking way"""
|
| 800 |
+
"""NOTE: `dst` is the local rank of the destination rank."""
|
| 801 |
+
if dst is None:
|
| 802 |
+
dst = (self.rank_in_group + 1) % self.world_size
|
| 803 |
+
|
| 804 |
+
pynccl_comm = self.pynccl_comm
|
| 805 |
+
if pynccl_comm is not None and not pynccl_comm.disabled:
|
| 806 |
+
pynccl_comm.send(tensor, dst)
|
| 807 |
+
else:
|
| 808 |
+
torch.distributed.send(tensor, self.ranks[dst], self.device_group)
|
| 809 |
+
|
| 810 |
+
def recv(self,
|
| 811 |
+
size: torch.Size,
|
| 812 |
+
dtype: torch.dtype,
|
| 813 |
+
src: Optional[int] = None) -> torch.Tensor:
|
| 814 |
+
"""Receives a tensor from the source rank."""
|
| 815 |
+
"""NOTE: `src` is the local rank of the source rank."""
|
| 816 |
+
if src is None:
|
| 817 |
+
src = (self.rank_in_group - 1) % self.world_size
|
| 818 |
+
|
| 819 |
+
tensor = torch.empty(size, dtype=dtype, device=self.device)
|
| 820 |
+
pynccl_comm = self.pynccl_comm
|
| 821 |
+
if pynccl_comm is not None and not pynccl_comm.disabled:
|
| 822 |
+
pynccl_comm.recv(tensor, src)
|
| 823 |
+
else:
|
| 824 |
+
torch.distributed.recv(tensor, self.ranks[src], self.device_group)
|
| 825 |
+
return tensor
|
| 826 |
+
|
| 827 |
+
def destroy(self):
|
| 828 |
+
if self.device_group is not None:
|
| 829 |
+
torch.distributed.destroy_process_group(self.device_group)
|
| 830 |
+
self.device_group = None
|
| 831 |
+
if self.cpu_group is not None:
|
| 832 |
+
torch.distributed.destroy_process_group(self.cpu_group)
|
| 833 |
+
self.cpu_group = None
|
| 834 |
+
if self.pynccl_comm is not None:
|
| 835 |
+
self.pynccl_comm = None
|
| 836 |
+
if self.ca_comm is not None:
|
| 837 |
+
self.ca_comm = None
|
| 838 |
+
if self.mq_broadcaster is not None:
|
| 839 |
+
self.mq_broadcaster = None
|
| 840 |
+
|
| 841 |
+
|
| 842 |
+
_WORLD: Optional[GroupCoordinator] = None
|
| 843 |
+
|
| 844 |
+
|
| 845 |
+
def get_world_group() -> GroupCoordinator:
|
| 846 |
+
assert _WORLD is not None, ("world group is not initialized")
|
| 847 |
+
return _WORLD
|
| 848 |
+
|
| 849 |
+
|
| 850 |
+
def init_world_group(ranks: List[int], local_rank: int,
|
| 851 |
+
backend: str) -> GroupCoordinator:
|
| 852 |
+
return GroupCoordinator(
|
| 853 |
+
group_ranks=[ranks],
|
| 854 |
+
local_rank=local_rank,
|
| 855 |
+
torch_distributed_backend=backend,
|
| 856 |
+
use_pynccl=False,
|
| 857 |
+
use_custom_allreduce=False,
|
| 858 |
+
use_tpu_communicator=False,
|
| 859 |
+
use_hpu_communicator=False,
|
| 860 |
+
use_xpu_communicator=False,
|
| 861 |
+
group_name="world",
|
| 862 |
+
)
|
| 863 |
+
|
| 864 |
+
|
| 865 |
+
def init_model_parallel_group(
|
| 866 |
+
group_ranks: List[List[int]],
|
| 867 |
+
local_rank: int,
|
| 868 |
+
backend: str,
|
| 869 |
+
use_custom_allreduce: Optional[bool] = None,
|
| 870 |
+
use_message_queue_broadcaster: bool = False,
|
| 871 |
+
group_name: Optional[str] = None,
|
| 872 |
+
) -> GroupCoordinator:
|
| 873 |
+
if use_custom_allreduce is None:
|
| 874 |
+
use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
|
| 875 |
+
from vllm.platforms import current_platform
|
| 876 |
+
return GroupCoordinator(
|
| 877 |
+
group_ranks=group_ranks,
|
| 878 |
+
local_rank=local_rank,
|
| 879 |
+
torch_distributed_backend=backend,
|
| 880 |
+
use_pynccl=current_platform.is_cuda_alike(),
|
| 881 |
+
use_custom_allreduce=current_platform.is_cuda_alike()
|
| 882 |
+
and use_custom_allreduce,
|
| 883 |
+
use_tpu_communicator=True,
|
| 884 |
+
use_hpu_communicator=True,
|
| 885 |
+
use_xpu_communicator=True,
|
| 886 |
+
use_message_queue_broadcaster=use_message_queue_broadcaster,
|
| 887 |
+
group_name=group_name,
|
| 888 |
+
)
|
| 889 |
+
|
| 890 |
+
|
| 891 |
+
_TP: Optional[GroupCoordinator] = None
|
| 892 |
+
|
| 893 |
+
|
| 894 |
+
def get_tp_group() -> GroupCoordinator:
|
| 895 |
+
assert _TP is not None, ("tensor model parallel group is not initialized")
|
| 896 |
+
return _TP
|
| 897 |
+
|
| 898 |
+
|
| 899 |
+
# kept for backward compatibility
|
| 900 |
+
get_tensor_model_parallel_group = get_tp_group
|
| 901 |
+
|
| 902 |
+
_PP: Optional[GroupCoordinator] = None
|
| 903 |
+
|
| 904 |
+
|
| 905 |
+
def get_pp_group() -> GroupCoordinator:
|
| 906 |
+
assert _PP is not None, (
|
| 907 |
+
"pipeline model parallel group is not initialized")
|
| 908 |
+
return _PP
|
| 909 |
+
|
| 910 |
+
|
| 911 |
+
# kept for backward compatibility
|
| 912 |
+
get_pipeline_model_parallel_group = get_pp_group
|
| 913 |
+
|
| 914 |
+
_KV_TRANSFER: Optional[kv_transfer.KVTransferAgent] = None
|
| 915 |
+
|
| 916 |
+
|
| 917 |
+
def get_kv_transfer_group() -> kv_transfer.KVTransferAgent:
|
| 918 |
+
assert _KV_TRANSFER is not None, (
|
| 919 |
+
"disaggregated KV cache transfer parallel group is not initialized")
|
| 920 |
+
return _KV_TRANSFER
|
| 921 |
+
|
| 922 |
+
|
| 923 |
+
@contextmanager
|
| 924 |
+
def graph_capture(device: torch.device):
|
| 925 |
+
"""
|
| 926 |
+
`graph_capture` is a context manager which should surround the code that
|
| 927 |
+
is capturing the CUDA graph. Its main purpose is to ensure that the
|
| 928 |
+
some operations will be run after the graph is captured, before the graph
|
| 929 |
+
is replayed. It returns a `GraphCaptureContext` object which contains the
|
| 930 |
+
necessary data for the graph capture. Currently, it only contains the
|
| 931 |
+
stream that the graph capture is running on. This stream is set to the
|
| 932 |
+
current CUDA stream when the context manager is entered and reset to the
|
| 933 |
+
default stream when the context manager is exited. This is to ensure that
|
| 934 |
+
the graph capture is running on a separate stream from the default stream,
|
| 935 |
+
in order to explicitly distinguish the kernels to capture
|
| 936 |
+
from other kernels possibly launched on background in the default stream.
|
| 937 |
+
"""
|
| 938 |
+
context = GraphCaptureContext(torch.cuda.Stream(device=device))
|
| 939 |
+
with get_tp_group().graph_capture(context), get_pp_group().graph_capture(
|
| 940 |
+
context):
|
| 941 |
+
yield context
|
| 942 |
+
|
| 943 |
+
|
| 944 |
+
logger = init_logger(__name__)
|
| 945 |
+
|
| 946 |
+
_ENABLE_CUSTOM_ALL_REDUCE = True
|
| 947 |
+
|
| 948 |
+
|
| 949 |
+
def set_custom_all_reduce(enable: bool):
|
| 950 |
+
global _ENABLE_CUSTOM_ALL_REDUCE
|
| 951 |
+
_ENABLE_CUSTOM_ALL_REDUCE = enable
|
| 952 |
+
|
| 953 |
+
|
| 954 |
+
def init_distributed_environment(
|
| 955 |
+
world_size: int = -1,
|
| 956 |
+
rank: int = -1,
|
| 957 |
+
distributed_init_method: str = "env://",
|
| 958 |
+
local_rank: int = -1,
|
| 959 |
+
backend: str = "nccl",
|
| 960 |
+
):
|
| 961 |
+
logger.debug(
|
| 962 |
+
"world_size=%d rank=%d local_rank=%d "
|
| 963 |
+
"distributed_init_method=%s backend=%s", world_size, rank, local_rank,
|
| 964 |
+
distributed_init_method, backend)
|
| 965 |
+
if not torch.distributed.is_initialized():
|
| 966 |
+
assert distributed_init_method is not None, (
|
| 967 |
+
"distributed_init_method must be provided when initializing "
|
| 968 |
+
"distributed environment")
|
| 969 |
+
# this backend is used for WORLD
|
| 970 |
+
torch.distributed.init_process_group(
|
| 971 |
+
backend=backend,
|
| 972 |
+
init_method=distributed_init_method,
|
| 973 |
+
world_size=world_size,
|
| 974 |
+
rank=rank)
|
| 975 |
+
# set the local rank
|
| 976 |
+
# local_rank is not available in torch ProcessGroup,
|
| 977 |
+
# see https://github.com/pytorch/pytorch/issues/122816
|
| 978 |
+
if local_rank == -1:
|
| 979 |
+
# local rank not set, this usually happens in single-node
|
| 980 |
+
# setting, where we can use rank as local rank
|
| 981 |
+
if distributed_init_method == "env://":
|
| 982 |
+
local_rank = envs.LOCAL_RANK
|
| 983 |
+
else:
|
| 984 |
+
local_rank = rank
|
| 985 |
+
global _WORLD
|
| 986 |
+
if _WORLD is None:
|
| 987 |
+
ranks = list(range(torch.distributed.get_world_size()))
|
| 988 |
+
_WORLD = init_world_group(ranks, local_rank, backend)
|
| 989 |
+
else:
|
| 990 |
+
assert _WORLD.world_size == torch.distributed.get_world_size(), (
|
| 991 |
+
"world group already initialized with a different world size")
|
| 992 |
+
|
| 993 |
+
|
| 994 |
+
def initialize_model_parallel(
|
| 995 |
+
tensor_model_parallel_size: int = 1,
|
| 996 |
+
pipeline_model_parallel_size: int = 1,
|
| 997 |
+
backend: Optional[str] = None,
|
| 998 |
+
) -> None:
|
| 999 |
+
"""
|
| 1000 |
+
Initialize model parallel groups.
|
| 1001 |
+
|
| 1002 |
+
Arguments:
|
| 1003 |
+
tensor_model_parallel_size: number of GPUs used for tensor model
|
| 1004 |
+
parallelism.
|
| 1005 |
+
pipeline_model_parallel_size: number of GPUs used for pipeline model
|
| 1006 |
+
parallelism.
|
| 1007 |
+
|
| 1008 |
+
Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we
|
| 1009 |
+
use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
|
| 1010 |
+
the model pipeline. The present function will
|
| 1011 |
+
create 4 tensor model-parallel groups and 2 pipeline model-parallel groups:
|
| 1012 |
+
4 tensor model-parallel groups:
|
| 1013 |
+
[g0, g1], [g2, g3], [g4, g5], [g6, g7]
|
| 1014 |
+
2 pipeline model-parallel groups:
|
| 1015 |
+
[g0, g2, g4, g6], [g1, g3, g5, g7]
|
| 1016 |
+
Note that for efficiency, the caller should make sure adjacent ranks
|
| 1017 |
+
are on the same DGX box. For example if we are using 2 DGX-1 boxes
|
| 1018 |
+
with a total of 16 GPUs, rank 0 to 7 belong to the first box and
|
| 1019 |
+
ranks 8 to 15 belong to the second box.
|
| 1020 |
+
"""
|
| 1021 |
+
# Get world size and rank. Ensure some consistencies.
|
| 1022 |
+
assert torch.distributed.is_initialized()
|
| 1023 |
+
world_size: int = torch.distributed.get_world_size()
|
| 1024 |
+
backend = backend or torch.distributed.get_backend(
|
| 1025 |
+
get_world_group().device_group)
|
| 1026 |
+
|
| 1027 |
+
if (world_size
|
| 1028 |
+
!= tensor_model_parallel_size * pipeline_model_parallel_size):
|
| 1029 |
+
raise RuntimeError(
|
| 1030 |
+
f"world_size ({world_size}) is not equal to "
|
| 1031 |
+
f"tensor_model_parallel_size ({tensor_model_parallel_size}) x "
|
| 1032 |
+
f"pipeline_model_parallel_size ({pipeline_model_parallel_size})")
|
| 1033 |
+
|
| 1034 |
+
# Build the tensor model-parallel groups.
|
| 1035 |
+
num_tensor_model_parallel_groups: int = (world_size //
|
| 1036 |
+
tensor_model_parallel_size)
|
| 1037 |
+
global _TP
|
| 1038 |
+
assert _TP is None, ("tensor model parallel group is already initialized")
|
| 1039 |
+
group_ranks = []
|
| 1040 |
+
for i in range(num_tensor_model_parallel_groups):
|
| 1041 |
+
ranks = list(
|
| 1042 |
+
range(i * tensor_model_parallel_size,
|
| 1043 |
+
(i + 1) * tensor_model_parallel_size))
|
| 1044 |
+
group_ranks.append(ranks)
|
| 1045 |
+
|
| 1046 |
+
# message queue broadcaster is only used in tensor model parallel group
|
| 1047 |
+
_TP = init_model_parallel_group(group_ranks,
|
| 1048 |
+
get_world_group().local_rank,
|
| 1049 |
+
backend,
|
| 1050 |
+
use_message_queue_broadcaster=True,
|
| 1051 |
+
group_name="tp")
|
| 1052 |
+
|
| 1053 |
+
# Build the pipeline model-parallel groups.
|
| 1054 |
+
num_pipeline_model_parallel_groups: int = (world_size //
|
| 1055 |
+
pipeline_model_parallel_size)
|
| 1056 |
+
global _PP
|
| 1057 |
+
assert _PP is None, (
|
| 1058 |
+
"pipeline model parallel group is already initialized")
|
| 1059 |
+
group_ranks = []
|
| 1060 |
+
for i in range(num_pipeline_model_parallel_groups):
|
| 1061 |
+
ranks = list(range(i, world_size, num_pipeline_model_parallel_groups))
|
| 1062 |
+
group_ranks.append(ranks)
|
| 1063 |
+
# pipeline parallel does not need custom allreduce
|
| 1064 |
+
_PP = init_model_parallel_group(group_ranks,
|
| 1065 |
+
get_world_group().local_rank,
|
| 1066 |
+
backend,
|
| 1067 |
+
use_custom_allreduce=False,
|
| 1068 |
+
group_name="pp")
|
| 1069 |
+
|
| 1070 |
+
|
| 1071 |
+
def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None:
|
| 1072 |
+
"""
|
| 1073 |
+
Initialize KV cache transfer parallel group.
|
| 1074 |
+
"""
|
| 1075 |
+
|
| 1076 |
+
global _KV_TRANSFER
|
| 1077 |
+
|
| 1078 |
+
if vllm_config.kv_transfer_config is None:
|
| 1079 |
+
return
|
| 1080 |
+
|
| 1081 |
+
if all([
|
| 1082 |
+
vllm_config.kv_transfer_config.need_kv_parallel_group, _KV_TRANSFER
|
| 1083 |
+
is None
|
| 1084 |
+
]):
|
| 1085 |
+
_KV_TRANSFER = kv_transfer.KVTransferAgent(
|
| 1086 |
+
rank=get_world_group().rank,
|
| 1087 |
+
local_rank=get_world_group().local_rank,
|
| 1088 |
+
config=vllm_config)
|
| 1089 |
+
|
| 1090 |
+
|
| 1091 |
+
def ensure_model_parallel_initialized(
|
| 1092 |
+
tensor_model_parallel_size: int,
|
| 1093 |
+
pipeline_model_parallel_size: int,
|
| 1094 |
+
backend: Optional[str] = None,
|
| 1095 |
+
) -> None:
|
| 1096 |
+
"""Helper to initialize model parallel groups if they are not initialized,
|
| 1097 |
+
or ensure tensor-parallel and pipeline-parallel sizes are equal to expected
|
| 1098 |
+
values if the model parallel groups are initialized.
|
| 1099 |
+
"""
|
| 1100 |
+
backend = backend or torch.distributed.get_backend(
|
| 1101 |
+
get_world_group().device_group)
|
| 1102 |
+
if not model_parallel_is_initialized():
|
| 1103 |
+
initialize_model_parallel(tensor_model_parallel_size,
|
| 1104 |
+
pipeline_model_parallel_size, backend)
|
| 1105 |
+
return
|
| 1106 |
+
|
| 1107 |
+
assert (
|
| 1108 |
+
get_tensor_model_parallel_world_size() == tensor_model_parallel_size
|
| 1109 |
+
), ("tensor parallel group already initialized, but of unexpected size: "
|
| 1110 |
+
f"{get_tensor_model_parallel_world_size()=} vs. "
|
| 1111 |
+
f"{tensor_model_parallel_size=}")
|
| 1112 |
+
pp_world_size = get_pp_group().world_size
|
| 1113 |
+
assert (pp_world_size == pipeline_model_parallel_size), (
|
| 1114 |
+
"pipeline parallel group already initialized, but of unexpected size: "
|
| 1115 |
+
f"{pp_world_size=} vs. "
|
| 1116 |
+
f"{pipeline_model_parallel_size=}")
|
| 1117 |
+
|
| 1118 |
+
|
| 1119 |
+
def model_parallel_is_initialized():
|
| 1120 |
+
"""Check if tensor and pipeline parallel groups are initialized."""
|
| 1121 |
+
return (_TP is not None and _PP is not None)
|
| 1122 |
+
|
| 1123 |
+
|
| 1124 |
+
_TP_STATE_PATCHED = False
|
| 1125 |
+
|
| 1126 |
+
|
| 1127 |
+
@contextmanager
|
| 1128 |
+
def patch_tensor_parallel_group(tp_group: GroupCoordinator):
|
| 1129 |
+
"""Patch the tp group temporarily until this function ends.
|
| 1130 |
+
|
| 1131 |
+
This method is for draft workers of speculative decoding to run draft model
|
| 1132 |
+
with different tp degree from that of target model workers.
|
| 1133 |
+
|
| 1134 |
+
Args:
|
| 1135 |
+
tp_group (GroupCoordinator): the tp group coordinator
|
| 1136 |
+
"""
|
| 1137 |
+
global _TP_STATE_PATCHED
|
| 1138 |
+
assert not _TP_STATE_PATCHED, "Should not call when it's already patched"
|
| 1139 |
+
|
| 1140 |
+
_TP_STATE_PATCHED = True
|
| 1141 |
+
old_tp_group = get_tp_group()
|
| 1142 |
+
global _TP
|
| 1143 |
+
_TP = tp_group
|
| 1144 |
+
try:
|
| 1145 |
+
yield
|
| 1146 |
+
finally:
|
| 1147 |
+
# restore the original state
|
| 1148 |
+
_TP_STATE_PATCHED = False
|
| 1149 |
+
_TP = old_tp_group
|
| 1150 |
+
|
| 1151 |
+
|
| 1152 |
+
def get_tensor_model_parallel_world_size():
|
| 1153 |
+
"""Return world size for the tensor model parallel group."""
|
| 1154 |
+
return get_tp_group().world_size
|
| 1155 |
+
|
| 1156 |
+
|
| 1157 |
+
def get_tensor_model_parallel_rank():
|
| 1158 |
+
"""Return my rank for the tensor model parallel group."""
|
| 1159 |
+
return get_tp_group().rank_in_group
|
| 1160 |
+
|
| 1161 |
+
|
| 1162 |
+
def destroy_model_parallel():
|
| 1163 |
+
"""Set the groups to none and destroy them."""
|
| 1164 |
+
global _TP
|
| 1165 |
+
if _TP:
|
| 1166 |
+
_TP.destroy()
|
| 1167 |
+
_TP = None
|
| 1168 |
+
|
| 1169 |
+
global _PP
|
| 1170 |
+
if _PP:
|
| 1171 |
+
_PP.destroy()
|
| 1172 |
+
_PP = None
|
| 1173 |
+
|
| 1174 |
+
|
| 1175 |
+
def destroy_distributed_environment():
|
| 1176 |
+
global _WORLD
|
| 1177 |
+
if _WORLD:
|
| 1178 |
+
_WORLD.destroy()
|
| 1179 |
+
_WORLD = None
|
| 1180 |
+
if torch.distributed.is_initialized():
|
| 1181 |
+
torch.distributed.destroy_process_group()
|
| 1182 |
+
|
| 1183 |
+
|
| 1184 |
+
def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
|
| 1185 |
+
destroy_model_parallel()
|
| 1186 |
+
destroy_distributed_environment()
|
| 1187 |
+
with contextlib.suppress(AssertionError):
|
| 1188 |
+
torch.distributed.destroy_process_group()
|
| 1189 |
+
if shutdown_ray:
|
| 1190 |
+
import ray # Lazy import Ray
|
| 1191 |
+
ray.shutdown()
|
| 1192 |
+
gc.collect()
|
| 1193 |
+
from vllm.platforms import current_platform
|
| 1194 |
+
if not current_platform.is_cpu():
|
| 1195 |
+
torch.cuda.empty_cache()
|
| 1196 |
+
try:
|
| 1197 |
+
torch._C._host_emptyCache()
|
| 1198 |
+
except AttributeError:
|
| 1199 |
+
logger.warning(
|
| 1200 |
+
"torch._C._host_emptyCache() only available in Pytorch >=2.5")
|
| 1201 |
+
|
| 1202 |
+
|
| 1203 |
+
def in_the_same_node_as(pg: Union[ProcessGroup, StatelessProcessGroup],
|
| 1204 |
+
source_rank: int = 0) -> List[bool]:
|
| 1205 |
+
"""
|
| 1206 |
+
This is a collective operation that returns if each rank is in the same node
|
| 1207 |
+
as the source rank. It tests if processes are attached to the same
|
| 1208 |
+
memory system (shared access to shared memory).
|
| 1209 |
+
"""
|
| 1210 |
+
if isinstance(pg, ProcessGroup):
|
| 1211 |
+
assert torch.distributed.get_backend(
|
| 1212 |
+
pg) != torch.distributed.Backend.NCCL, (
|
| 1213 |
+
"in_the_same_node_as should be tested with a non-NCCL group.")
|
| 1214 |
+
# local rank inside the group
|
| 1215 |
+
rank = torch.distributed.get_rank(group=pg)
|
| 1216 |
+
world_size = torch.distributed.get_world_size(group=pg)
|
| 1217 |
+
|
| 1218 |
+
# global ranks of the processes in the group
|
| 1219 |
+
ranks = torch.distributed.get_process_group_ranks(pg)
|
| 1220 |
+
else:
|
| 1221 |
+
rank = pg.rank
|
| 1222 |
+
world_size = pg.world_size
|
| 1223 |
+
ranks = list(range(world_size))
|
| 1224 |
+
|
| 1225 |
+
# local tensor in each process to store the result
|
| 1226 |
+
is_in_the_same_node = torch.tensor([0] * world_size, dtype=torch.int32)
|
| 1227 |
+
|
| 1228 |
+
magic_message = b"magic_message"
|
| 1229 |
+
shm = None
|
| 1230 |
+
|
| 1231 |
+
try:
|
| 1232 |
+
with contextlib.suppress(OSError):
|
| 1233 |
+
if rank == source_rank:
|
| 1234 |
+
# create a shared memory segment
|
| 1235 |
+
shm = shared_memory.SharedMemory(create=True, size=128)
|
| 1236 |
+
shm.buf[:len(magic_message)] = magic_message
|
| 1237 |
+
if isinstance(pg, ProcessGroup):
|
| 1238 |
+
torch.distributed.broadcast_object_list(
|
| 1239 |
+
[shm.name], src=ranks[source_rank], group=pg)
|
| 1240 |
+
else:
|
| 1241 |
+
pg.broadcast_obj(shm.name, src=source_rank)
|
| 1242 |
+
is_in_the_same_node[rank] = 1
|
| 1243 |
+
else:
|
| 1244 |
+
# try to open the shared memory segment
|
| 1245 |
+
if isinstance(pg, ProcessGroup):
|
| 1246 |
+
recv = [None]
|
| 1247 |
+
torch.distributed.broadcast_object_list(
|
| 1248 |
+
recv, src=ranks[source_rank], group=pg)
|
| 1249 |
+
name = recv[0]
|
| 1250 |
+
else:
|
| 1251 |
+
name = pg.broadcast_obj(None, src=source_rank)
|
| 1252 |
+
# fix to https://stackoverflow.com/q/62748654/9191338
|
| 1253 |
+
# Python incorrectly tracks shared memory even if it is not
|
| 1254 |
+
# created by the process. The following patch is a workaround.
|
| 1255 |
+
with patch("multiprocessing.resource_tracker.register",
|
| 1256 |
+
lambda *args, **kwargs: None):
|
| 1257 |
+
shm = shared_memory.SharedMemory(name=name)
|
| 1258 |
+
if shm.buf[:len(magic_message)] == magic_message:
|
| 1259 |
+
is_in_the_same_node[rank] = 1
|
| 1260 |
+
except Exception as e:
|
| 1261 |
+
logger.error("Error ignored in is_in_the_same_node: %s", e)
|
| 1262 |
+
finally:
|
| 1263 |
+
if shm:
|
| 1264 |
+
shm.close()
|
| 1265 |
+
|
| 1266 |
+
if isinstance(pg, ProcessGroup):
|
| 1267 |
+
torch.distributed.barrier(group=pg)
|
| 1268 |
+
else:
|
| 1269 |
+
pg.barrier()
|
| 1270 |
+
|
| 1271 |
+
# clean up the shared memory segment
|
| 1272 |
+
with contextlib.suppress(OSError):
|
| 1273 |
+
if rank == source_rank and shm:
|
| 1274 |
+
shm.unlink()
|
| 1275 |
+
|
| 1276 |
+
if isinstance(pg, ProcessGroup):
|
| 1277 |
+
torch.distributed.all_reduce(is_in_the_same_node, group=pg)
|
| 1278 |
+
aggregated_data = is_in_the_same_node
|
| 1279 |
+
else:
|
| 1280 |
+
aggregated_data = torch.zeros_like(is_in_the_same_node)
|
| 1281 |
+
for i in range(world_size):
|
| 1282 |
+
rank_data = pg.broadcast_obj(is_in_the_same_node, src=i)
|
| 1283 |
+
aggregated_data += rank_data
|
| 1284 |
+
|
| 1285 |
+
return [x == 1 for x in aggregated_data.tolist()]
|