Add files using upload-large-folder tool
This view is limited to 50 files because it contains too many changes (see the raw diff for the full change set).
- .gitattributes +3 -0
- .venv/lib/python3.11/site-packages/grpc/_cython/cygrpc.cpython-311-x86_64-linux-gnu.so +3 -0
- .venv/lib/python3.11/site-packages/vllm/__pycache__/config.cpython-311.pyc +3 -0
- .venv/lib/python3.11/site-packages/vllm/__pycache__/utils.cpython-311.pyc +3 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/layers/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/layers/activation.py +360 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/layers/fused_moe/__init__.py +48 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/layers/fused_moe/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/layers/fused_moe/__pycache__/fused_marlin_moe.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/layers/fused_moe/__pycache__/fused_moe.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/layers/fused_moe/__pycache__/layer.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/layers/fused_moe/__pycache__/moe_pallas.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/layers/fused_moe/__pycache__/moe_torch_iterative.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/layers/fused_moe/configs/E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +218 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/layers/fused_moe/configs/E=16,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +146 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/layers/fused_moe/configs/E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +146 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +360 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/layers/fused_moe/fused_moe.py +1363 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/layers/fused_moe/layer.py +647 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/layers/fused_moe/moe_pallas.py +64 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/layers/fused_moe/moe_torch_iterative.py +53 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/layers/layernorm.py +213 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/layers/linear.py +1159 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/layers/logits_processor.py +193 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/layers/pooler.py +322 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/layers/rejection_sampler.py +400 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/layers/resampler.py +269 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/layers/rotary_embedding.py +1114 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/layers/sampler.py +1292 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/layers/spec_decode_base_sampler.py +256 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/layers/typical_acceptance_sampler.py +172 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/layers/utils.py +58 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/layers/vocab_parallel_embedding.py +484 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/adapters.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/arctic.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/bert.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/blip2.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/bloom.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/chameleon.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/chatglm.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/clip.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/decilm.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/deepseek.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/eagle.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/fairseq2_llama.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/falcon.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/glm.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/gpt2.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/gpt_j.cpython-311.pyc +0 -0
.gitattributes
CHANGED
@@ -201,3 +201,6 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/_
 .venv/lib/python3.11/site-packages/jinja2/__pycache__/compiler.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
 .venv/lib/python3.11/site-packages/grpc/__pycache__/_channel.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
 .venv/lib/python3.11/site-packages/msgpack/_cmsgpack.cpython-311-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
+.venv/lib/python3.11/site-packages/vllm/__pycache__/config.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
+.venv/lib/python3.11/site-packages/vllm/__pycache__/utils.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
+.venv/lib/python3.11/site-packages/grpc/_cython/cygrpc.cpython-311-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
.venv/lib/python3.11/site-packages/grpc/_cython/cygrpc.cpython-311-x86_64-linux-gnu.so
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5e5d08f09e30133cae4310d214c9357fca55ebd0e2db830c422465af821a6392
+size 13660664
.venv/lib/python3.11/site-packages/vllm/__pycache__/config.cpython-311.pyc
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7c332333bd8134456a6d80925bf608e61fc31c7df941a7862edcbfacf4b07e81
+size 148527
.venv/lib/python3.11/site-packages/vllm/__pycache__/utils.cpython-311.pyc
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:35e12ff57902027f5a6f938eca8e4cb4a91c51c331e59ff752edd1b635b6330f
+size 113860
.venv/lib/python3.11/site-packages/vllm/model_executor/layers/__init__.py
ADDED
File without changes
.venv/lib/python3.11/site-packages/vllm/model_executor/layers/activation.py
ADDED
@@ -0,0 +1,360 @@
# SPDX-License-Identifier: Apache-2.0
"""Custom activation functions."""
import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size)
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils import LazyDict


@CustomOp.register("fatrelu_and_mul")
class FatreluAndMul(CustomOp):
    """An activation function for FATReLU.

    The function computes x -> FATReLU(x[:d]) * x[d:] where
    d = x.shape[-1] // 2.
    This is used in openbmb/MiniCPM-S-1B-sft.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    def __init__(self, threshold: float = 0.):
        super().__init__()
        self.threshold = threshold
        if current_platform.is_cuda_alike():
            self.op = torch.ops._C.fatrelu_and_mul
        elif current_platform.is_cpu():
            self._forward_method = self.forward_native

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        x1 = x[..., :d]
        x2 = x[..., d:]
        x1 = F.threshold(x1, self.threshold, 0.0)
        return x1 * x2

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        output_shape = (x.shape[:-1] + (d, ))
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        self.op(out, x, self.threshold)
        return out


@CustomOp.register("silu_and_mul")
class SiluAndMul(CustomOp):
    """An activation function for SwiGLU.

    The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[-1] // 2.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    def __init__(self):
        super().__init__()
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            self.op = torch.ops._C.silu_and_mul
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops
            self.op = ipex_ops.silu_and_mul

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        d = x.shape[-1] // 2
        return F.silu(x[..., :d]) * x[..., d:]

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        output_shape = (x.shape[:-1] + (d, ))
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        self.op(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        output_shape = (x.shape[:-1] + (d, ))
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        self.op(out, x)
        return out


@CustomOp.register("mul_and_silu")
class MulAndSilu(CustomOp):
    """An activation function for SwiGLU.

    The function computes x -> x[:d] * silu(x[d:]) where d = x.shape[-1] // 2.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    def __init__(self):
        super().__init__()
        if current_platform.is_cuda_alike():
            self.op = torch.ops._C.mul_and_silu
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops
            self.op = ipex_ops.silu_and_mul
        elif current_platform.is_cpu():
            self._forward_method = self.forward_native

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        d = x.shape[-1] // 2
        return x[..., :d] * F.silu(x[..., d:])

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        output_shape = (x.shape[:-1] + (d, ))
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        self.op(out, x)
        return out

    # TODO implement forward_xpu for MulAndSilu
    # def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:


@CustomOp.register("gelu_and_mul")
class GeluAndMul(CustomOp):
    """An activation function for GeGLU.

    The function computes x -> GELU(x[:d]) * x[d:] where d = x.shape[-1] // 2.

    Shapes:
        x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
        return: (batch_size, seq_len, d) or (num_tokens, d)
    """

    def __init__(self, approximate: str = "none"):
        super().__init__()
        self.approximate = approximate
        if approximate not in ("none", "tanh"):
            raise ValueError(f"Unknown approximate mode: {approximate}")
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            if approximate == "none":
                self.op = torch.ops._C.gelu_and_mul
            elif approximate == "tanh":
                self.op = torch.ops._C.gelu_tanh_and_mul
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops
            if approximate == "none":
                self.op = ipex_ops.gelu_and_mul
            else:
                self.op = ipex_ops.gelu_tanh_and_mul

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        d = x.shape[-1] // 2
        return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        output_shape = (x.shape[:-1] + (d, ))
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        self.op(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        output_shape = (x.shape[:-1] + (d, ))
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        self.op(out, x)
        return out

    def extra_repr(self) -> str:
        return f'approximate={repr(self.approximate)}'


@CustomOp.register("gelu_new")
class NewGELU(CustomOp):

    def __init__(self):
        super().__init__()
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            self.op = torch.ops._C.gelu_new
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops
            self.op = ipex_ops.gelu_new

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        c = math.sqrt(2.0 / math.pi)
        return 0.5 * x * (1.0 + torch.tanh(c *
                                           (x + 0.044715 * torch.pow(x, 3.0))))

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        out = torch.empty_like(x)
        self.op(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        return self.op(x)


@CustomOp.register("gelu_fast")
class FastGELU(CustomOp):

    def __init__(self):
        super().__init__()
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            self.op = torch.ops._C.gelu_fast
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops
            self.op = ipex_ops.gelu_fast

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 *
                                           (1.0 + 0.044715 * x * x)))

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        out = torch.empty_like(x)
        self.op(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        return self.op(x)


@CustomOp.register("quick_gelu")
class QuickGELU(CustomOp):
    # https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L90
    def __init__(self):
        super().__init__()
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            self.op = torch.ops._C.gelu_quick
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops
            self.op = ipex_ops.gelu_quick

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        return x * torch.sigmoid(1.702 * x)

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        out = torch.empty_like(x)
        self.op(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        out = torch.empty_like(x)
        self.op(out, x)
        return out

    # TODO implement forward_xpu for QuickGELU
    # def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:


@CustomOp.register("relu2")
class ReLUSquaredActivation(CustomOp):
    """
    Applies the relu^2 activation introduced in https://arxiv.org/abs/2109.08668v2
    """

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        return torch.square(F.relu(x))

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        return self.forward_native(x)


class ScaledActivation(nn.Module):
    """An activation function with post-scale parameters.

    This is used for some quantization methods like AWQ.
    """

    def __init__(
        self,
        act_module: nn.Module,
        intermediate_size: int,
        input_is_parallel: bool = True,
        params_dtype: Optional[torch.dtype] = None,
    ):
        super().__init__()
        self.act = act_module
        self.input_is_parallel = input_is_parallel
        if input_is_parallel:
            tp_size = get_tensor_model_parallel_world_size()
            intermediate_size_per_partition = divide(intermediate_size,
                                                     tp_size)
        else:
            intermediate_size_per_partition = intermediate_size
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.scales = nn.Parameter(
            torch.empty(intermediate_size_per_partition, dtype=params_dtype))
        set_weight_attrs(self.scales, {"weight_loader": self.weight_loader})

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.act(x) / self.scales

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        param_data = param.data
        if self.input_is_parallel:
            tp_rank = get_tensor_model_parallel_rank()
            shard_size = param_data.shape[0]
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)


_ACTIVATION_REGISTRY = LazyDict({
    "gelu":
    lambda: nn.GELU(),
    "gelu_fast":
    lambda: FastGELU(),
    "gelu_new":
    lambda: NewGELU(),
    "gelu_pytorch_tanh":
    lambda: nn.GELU(approximate="tanh"),
    "relu":
    lambda: nn.ReLU(),
    "relu2":
    lambda: ReLUSquaredActivation(),
    "silu":
    lambda: nn.SiLU(),
    "quick_gelu":
    lambda: QuickGELU(),
})


def get_act_fn(act_fn_name: str) -> nn.Module:
    """Get an activation function by name."""
    act_fn_name = act_fn_name.lower()
    if act_fn_name not in _ACTIVATION_REGISTRY:
        raise ValueError(
            f"Activation function {act_fn_name!r} is not supported.")

    return _ACTIVATION_REGISTRY[act_fn_name]


_ACTIVATION_AND_MUL_REGISTRY = LazyDict({
    "gelu": lambda: GeluAndMul(),
    "silu": lambda: SiluAndMul(),
})


def get_act_and_mul_fn(act_fn_name: str) -> nn.Module:
    """Get an activation-and-mul (i.e. SiluAndMul) function by name."""
    act_fn_name = act_fn_name.lower()
    if act_fn_name not in _ACTIVATION_AND_MUL_REGISTRY:
        raise ValueError(
            f"Activation function {act_fn_name!r} is not supported.")

    return _ACTIVATION_AND_MUL_REGISTRY[act_fn_name]
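A minimal usage sketch for the activation module added above, assuming a working vLLM install with its compiled custom ops; the shapes below are illustrative only, and the example sticks to the PyTorch-native paths so it does not depend on any particular accelerator:

# Minimal sketch (assumed environment: vLLM installed with its compiled ops).
# Shapes and values here are illustrative only.
import torch

from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn

act = get_act_fn("gelu_new")       # plain activation resolved from the registry
gated = SiluAndMul()               # gated activation: silu(x[..., :d]) * x[..., d:]

x = torch.randn(4, 2 * 11008)      # (num_tokens, 2 * d)
y = gated.forward_native(x)        # PyTorch-native path, no custom kernel needed
assert y.shape == (4, 11008)
z = act.forward_native(y)          # NewGELU applied elementwise, shape preserved
assert z.shape == y.shape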
.venv/lib/python3.11/site-packages/vllm/model_executor/layers/fused_moe/__init__.py
ADDED
@@ -0,0 +1,48 @@
# SPDX-License-Identifier: Apache-2.0

from contextlib import contextmanager
from typing import Any, Dict, Optional

from vllm.model_executor.layers.fused_moe.layer import (
    FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
from vllm.triton_utils import HAS_TRITON

_config: Optional[Dict[str, Any]] = None


@contextmanager
def override_config(config):
    global _config
    old_config = _config
    _config = config
    yield
    _config = old_config


def get_config() -> Optional[Dict[str, Any]]:
    return _config


__all__ = [
    "FusedMoE",
    "FusedMoEMethodBase",
    "FusedMoeWeightScaleSupported",
    "override_config",
    "get_config",
]

if HAS_TRITON:
    # import to register the custom ops
    import vllm.model_executor.layers.fused_moe.fused_marlin_moe  # noqa
    import vllm.model_executor.layers.fused_moe.fused_moe  # noqa
    from vllm.model_executor.layers.fused_moe.fused_moe import (
        fused_experts, fused_moe, fused_topk, get_config_file_name,
        grouped_topk)

    __all__ += [
        "fused_moe",
        "fused_topk",
        "fused_experts",
        "get_config_file_name",
        "grouped_topk",
    ]
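The override_config/get_config pair added above keeps a module-level override that the fused MoE kernels consult instead of the auto-selected tuning. A minimal sketch of the intended usage; the block-size values are placeholders, not a tuned configuration:

# Minimal sketch of the override_config/get_config pair defined above.
# The block-size values are placeholders for illustration, not tuned settings.
from vllm.model_executor.layers.fused_moe import get_config, override_config

manual_config = {
    "BLOCK_SIZE_M": 64,
    "BLOCK_SIZE_N": 64,
    "BLOCK_SIZE_K": 32,
    "GROUP_SIZE_M": 8,
}

assert get_config() is None               # nothing overridden by default
with override_config(manual_config):
    assert get_config() is manual_config  # fused_moe kernels would pick this up
assert get_config() is None               # restored on exit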
.venv/lib/python3.11/site-packages/vllm/model_executor/layers/fused_moe/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (1.69 kB)
.venv/lib/python3.11/site-packages/vllm/model_executor/layers/fused_moe/__pycache__/fused_marlin_moe.cpython-311.pyc
ADDED
Binary file (13.7 kB)
.venv/lib/python3.11/site-packages/vllm/model_executor/layers/fused_moe/__pycache__/fused_moe.cpython-311.pyc
ADDED
Binary file (50.4 kB)
.venv/lib/python3.11/site-packages/vllm/model_executor/layers/fused_moe/__pycache__/layer.cpython-311.pyc
ADDED
Binary file (24.3 kB)
.venv/lib/python3.11/site-packages/vllm/model_executor/layers/fused_moe/__pycache__/moe_pallas.cpython-311.pyc
ADDED
Binary file (3.66 kB)
.venv/lib/python3.11/site-packages/vllm/model_executor/layers/fused_moe/__pycache__/moe_torch_iterative.cpython-311.pyc
ADDED
Binary file (2.66 kB)
.venv/lib/python3.11/site-packages/vllm/model_executor/layers/fused_moe/configs/E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json
ADDED
@@ -0,0 +1,218 @@
{
  "1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3},
  "2": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 4},
  "4": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
  "8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 4},
  "16": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3},
  "24": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3},
  "32": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3},
  "48": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 3},
  "64": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
  "96": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3},
  "128": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3},
  "256": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
  "512": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3},
  "1024": {"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3},
  "1536": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 4},
  "2048": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 4},
  "3072": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 4},
  "4096": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 4},
  "5120": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 4},
  "9216": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 4},
  "13312": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 4},
  "17408": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 4},
  "25600": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 4},
  "33792": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 4},
  "41984": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 4},
  "50176": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 4},
  "58368": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 4}
}
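The JSON added above is a tuning table: the filename encodes the expert count (E), intermediate size (N), device, and weight dtype; the keys are token counts (M); and each value is a set of Triton launch parameters. A rough sketch of how such a table can be consumed; the nearest-key selection below is an illustrative assumption, not necessarily the exact lookup rule vLLM's fused_moe code uses:

# Sketch of consuming a tuned MoE kernel config like the JSON above.
# Keys are token counts (M); values are Triton launch parameters. Picking the
# key closest to the actual batch size is an illustrative heuristic only.
import json


def load_moe_config(path: str) -> dict[int, dict]:
    with open(path) as f:
        raw = json.load(f)
    return {int(m): cfg for m, cfg in raw.items()}


def pick_config(configs: dict[int, dict], num_tokens: int) -> dict:
    nearest_m = min(configs, key=lambda m: abs(m - num_tokens))
    return configs[nearest_m]


# Path assumed to point at a local copy of the config file shown above.
configs = load_moe_config(
    "E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json")
print(pick_config(configs, num_tokens=100))  # -> the "96" entry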
.venv/lib/python3.11/site-packages/vllm/model_executor/layers/fused_moe/configs/E=16,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json
ADDED
@@ -0,0 +1,146 @@
{
  "1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3},
  "2": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3},
  "4": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 5},
  "8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 3},
  "16": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 3},
  "24": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 32, "num_warps": 8, "num_stages": 3},
  "32": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 4},
  "48": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 5},
  "64": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 5},
  "96": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 3},
  "128": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3},
  "256": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3},
  "512": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3},
  "1024": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 4},
  "1536": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 4},
  "2048": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 4},
  "3072": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3},
  "4096": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3}
}
.venv/lib/python3.11/site-packages/vllm/model_executor/layers/fused_moe/configs/E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json
ADDED
@@ -0,0 +1,146 @@
{
  "1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3},
  "2": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3},
  "4": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 2},
  "8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3},
  "16": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 4},
  "24": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3},
  "32": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 3},
  "48": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 3},
  "64": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 3},
  "96": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 3},
  "128": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3},
  "256": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3},
  "512": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 4},
  "1024": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 4},
  "1536": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 4},
  "2048": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 4},
  "3072": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 4},
  "4096": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 4}
}
.venv/lib/python3.11/site-packages/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
ADDED
|
@@ -0,0 +1,360 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
"""Fused MoE utilities for GPTQ."""
|
| 3 |
+
import functools
|
| 4 |
+
from typing import Optional
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
| 9 |
+
fused_topk, moe_align_block_size, try_get_optimal_moe_config)
|
| 10 |
+
from vllm.scalar_type import scalar_types
|
| 11 |
+
from vllm.utils import direct_register_custom_op
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def get_scalar_type(num_bits: int, has_zp: bool):
|
| 15 |
+
if has_zp:
|
| 16 |
+
assert num_bits == 4
|
| 17 |
+
return scalar_types.uint4
|
| 18 |
+
else:
|
| 19 |
+
return scalar_types.uint4b8 if num_bits == 4 else scalar_types.uint8b128
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def single_marlin_moe(
|
| 23 |
+
hidden_states: torch.Tensor,
|
| 24 |
+
w: torch.Tensor,
|
| 25 |
+
scales: torch.Tensor,
|
| 26 |
+
gating_output: torch.Tensor,
|
| 27 |
+
topk: int,
|
| 28 |
+
renormalize: bool,
|
| 29 |
+
g_idx: Optional[torch.Tensor] = None,
|
| 30 |
+
sort_indices: Optional[torch.Tensor] = None,
|
| 31 |
+
w_zeros: Optional[torch.Tensor] = None,
|
| 32 |
+
num_bits: int = 8,
|
| 33 |
+
is_k_full: bool = True,
|
| 34 |
+
) -> torch.Tensor:
|
| 35 |
+
"""
|
| 36 |
+
This function computes the multiplication of hidden_states with expert
|
| 37 |
+
weights used in Marlin MoE, using weights w and top-k gating mechanism.
|
| 38 |
+
Its purpose is testing and debugging the fused MoE kernel.
|
| 39 |
+
|
| 40 |
+
Parameters:
|
| 41 |
+
- hidden_states (torch.Tensor): The input tensor to the Marlin Mul.
|
| 42 |
+
- w (torch.Tensor): The set of expert weights.
|
| 43 |
+
- scales (torch.Tensor): The quantization scales.
|
| 44 |
+
- gating_output (torch.Tensor): The output of the gating operation
|
| 45 |
+
(before softmax).
|
| 46 |
+
- g_idx (Optional[torch.Tensor]): Optional act_order indices.
|
| 47 |
+
- sort_indices (Optional[torch.Tensor]): Optional act_order input
|
| 48 |
+
permutation.
|
| 49 |
+
- topk (int): The number of top-k experts to select.
|
| 50 |
+
- renormalize (bool): If True, renormalize the top-k weights to sum to 1.
|
| 51 |
+
- w_zeros (Optional[torch.Tensor]): Optional zero points to be used for w.
|
| 52 |
+
- num_bits (bool): The number of bits in expert weights quantization.
|
| 53 |
+
|
| 54 |
+
Returns:
|
| 55 |
+
- torch.Tensor: The output tensor after applying the MoE layer.
|
| 56 |
+
"""
|
| 57 |
+
# Check constraints.
|
| 58 |
+
assert hidden_states.shape[0] == gating_output.shape[0], (
|
| 59 |
+
"Number of tokens mismatch")
|
| 60 |
+
assert hidden_states.shape[1] == w.shape[1] * 16, "Hidden size mismatch"
|
| 61 |
+
assert gating_output.shape[1] == w.shape[0], "Number of experts mismatch"
|
| 62 |
+
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
|
| 63 |
+
assert w.is_contiguous(), "Expert weights must be contiguous"
|
| 64 |
+
assert hidden_states.dtype == torch.float16
|
| 65 |
+
assert num_bits in [4, 8]
|
| 66 |
+
|
| 67 |
+
M, K = hidden_states.shape
|
| 68 |
+
E = w.shape[0]
|
| 69 |
+
N = w.shape[2] // (num_bits // 2)
|
| 70 |
+
|
| 71 |
+
topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
|
| 72 |
+
renormalize)
|
| 73 |
+
|
| 74 |
+
# This might not be an optimal config for a single MMM
|
| 75 |
+
get_config_func = functools.partial(try_get_optimal_moe_config,
|
| 76 |
+
w.shape,
|
| 77 |
+
w.shape,
|
| 78 |
+
topk_ids.shape[1],
|
| 79 |
+
None,
|
| 80 |
+
is_marlin=True)
|
| 81 |
+
config = get_config_func(M)
|
| 82 |
+
|
| 83 |
+
block_size_m = config['BLOCK_SIZE_M']
|
| 84 |
+
|
| 85 |
+
sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E)
|
| 86 |
+
|
| 87 |
+
max_workspace_size = (N // 64) * 16
|
| 88 |
+
workspace = torch.zeros(max_workspace_size,
|
| 89 |
+
dtype=torch.int,
|
| 90 |
+
device=hidden_states.device,
|
| 91 |
+
requires_grad=False)
|
| 92 |
+
|
| 93 |
+
has_zero_point = w_zeros is not None
|
| 94 |
+
if w_zeros is None:
|
| 95 |
+
w_zeros = torch.empty((0, 0),
|
| 96 |
+
dtype=hidden_states.dtype,
|
| 97 |
+
device=hidden_states.device,
|
| 98 |
+
requires_grad=False)
|
| 99 |
+
|
| 100 |
+
if g_idx is None:
|
| 101 |
+
g_idx = torch.empty((0, 0),
|
| 102 |
+
dtype=torch.int32,
|
| 103 |
+
device=hidden_states.device,
|
| 104 |
+
requires_grad=False)
|
| 105 |
+
|
| 106 |
+
if sort_indices is None:
|
| 107 |
+
sort_indices = torch.empty((0),
|
| 108 |
+
dtype=torch.int32,
|
| 109 |
+
device=hidden_states.device,
|
| 110 |
+
requires_grad=False)
|
| 111 |
+
|
| 112 |
+
scalar_type = get_scalar_type(num_bits, has_zero_point)
|
| 113 |
+
|
| 114 |
+
intermediate_cache = torch.ops._moe_C.marlin_gemm_moe(
|
| 115 |
+
hidden_states, w, sorted_token_ids, topk_weights, topk_ids, scales,
|
| 116 |
+
w_zeros, g_idx, sort_indices, workspace, scalar_type.id, M, N, K,
|
| 117 |
+
is_k_full, E, topk, block_size_m, True, False)
|
| 118 |
+
|
| 119 |
+
return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1)
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def single_marlin_moe_fake(
|
| 123 |
+
hidden_states: torch.Tensor,
|
| 124 |
+
w: torch.Tensor,
|
| 125 |
+
scales: torch.Tensor,
|
| 126 |
+
gating_output: torch.Tensor,
|
| 127 |
+
topk: int,
|
| 128 |
+
renormalize: bool,
|
| 129 |
+
g_idx: Optional[torch.Tensor] = None,
|
| 130 |
+
sort_indices: Optional[torch.Tensor] = None,
|
| 131 |
+
w_zeros: Optional[torch.Tensor] = None,
|
| 132 |
+
num_bits: int = 8,
|
| 133 |
+
is_k_full: bool = True,
|
| 134 |
+
) -> torch.Tensor:
|
| 135 |
+
return torch.empty_like(hidden_states)
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
direct_register_custom_op(
|
| 139 |
+
op_name="single_marlin_moe",
|
| 140 |
+
op_func=single_marlin_moe,
|
| 141 |
+
mutates_args=[],
|
| 142 |
+
fake_impl=single_marlin_moe_fake,
|
| 143 |
+
)
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def fused_marlin_moe(
|
| 147 |
+
hidden_states: torch.Tensor,
|
| 148 |
+
w1: torch.Tensor,
|
| 149 |
+
w2: torch.Tensor,
|
| 150 |
+
w1_scale: torch.Tensor,
|
| 151 |
+
w2_scale: torch.Tensor,
|
| 152 |
+
gating_output: torch.Tensor,
|
| 153 |
+
topk_weights: torch.Tensor,
|
| 154 |
+
topk_ids: torch.Tensor,
|
| 155 |
+
g_idx1: Optional[torch.Tensor] = None,
|
| 156 |
+
g_idx2: Optional[torch.Tensor] = None,
|
| 157 |
+
sort_indices1: Optional[torch.Tensor] = None,
|
| 158 |
+
sort_indices2: Optional[torch.Tensor] = None,
|
| 159 |
+
w1_zeros: Optional[torch.Tensor] = None,
|
| 160 |
+
w2_zeros: Optional[torch.Tensor] = None,
|
| 161 |
+
num_bits: int = 8,
|
| 162 |
+
is_k_full: bool = True,
|
| 163 |
+
) -> torch.Tensor:
|
| 164 |
+
"""
|
| 165 |
+
This function computes a Mixture of Experts (MoE) layer using two sets of
|
| 166 |
+
weights, w1 and w2, and top-k gating mechanism.
|
| 167 |
+
|
| 168 |
+
Parameters:
|
| 169 |
+
- hidden_states (torch.Tensor): The input tensor to the MoE layer.
|
| 170 |
+
- w1 (torch.Tensor): The first set of expert weights.
|
| 171 |
+
- w2 (torch.Tensor): The second set of expert weights.
|
| 172 |
+
- w1_scale (torch.Tensor): Scale to be used for w1.
|
| 173 |
+
- w2_scale (torch.Tensor): Scale to be used for w2.
|
| 174 |
+
- gating_output (torch.Tensor): The output of the gating operation
|
| 175 |
+
(before softmax).
|
| 176 |
+
- g_idx1 (Optional[torch.Tensor]): The first set of act_order indices.
|
| 177 |
+
- g_idx2 (Optional[torch.Tensor]): The second set of act_order indices.
|
| 178 |
+
- sort_indices1 (Optional[torch.Tensor]): The first act_order input
|
| 179 |
+
permutation.
|
| 180 |
+
- sort_indices2 (Optional[torch.Tensor]): The second act_order input
|
| 181 |
+
permutation.
|
| 182 |
+
- topk_weights (torch.Tensor): Top-k weights.
|
| 183 |
+
- topk_ids (torch.Tensor): Indices of topk-k elements.
|
| 184 |
+
- w1_zeros (Optional[torch.Tensor]): Optional zero points to be used for w1.
|
| 185 |
+
- w2_zeros (Optional[torch.Tensor]): Optional zero points to be used for w2.
|
| 186 |
+
- num_bits (bool): The number of bits in expert weights quantization.
|
| 187 |
+
|
| 188 |
+
Returns:
|
| 189 |
+
- torch.Tensor: The output tensor after applying the MoE layer.
|
| 190 |
+
"""
|
| 191 |
+
    # Check constraints.
    assert hidden_states.shape[0] == gating_output.shape[
        0], "Number of tokens mismatch"
    assert hidden_states.shape[
        1] == w1.shape[1] * 16, "Hidden size mismatch w1"
    assert hidden_states.shape[1] == w2.shape[2] // (
        num_bits // 2), "Hidden size mismatch w2"
    assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
    assert w1.is_contiguous(), "Expert weights1 must be contiguous"
    assert w2.is_contiguous(), "Expert weights2 must be contiguous"
    assert hidden_states.dtype == torch.float16
    assert num_bits in [4, 8]

    has_no_act_order = (g_idx1 is None and g_idx2 is None
                        and sort_indices1 is None and sort_indices2 is None)
    has_all_act_order = (g_idx1 is not None and g_idx2 is not None
                         and sort_indices1 is not None
                         and sort_indices2 is not None)
    assert has_no_act_order or has_all_act_order, (
        "g_idx and sorted_indices "
        "must be all not None or must be all None")

    has_no_zp = w1_zeros is None and w2_zeros is None
    has_all_zp = w1_zeros is not None and w2_zeros is not None
    assert has_no_zp or has_all_zp, ("zero points must be both not None or "
                                     "must be both None")

    M, K = hidden_states.shape
    E = w1.shape[0]
    N = w2.shape[1] * 16
    topk = topk_ids.shape[1]

    get_config_func = functools.partial(
        try_get_optimal_moe_config,
        w1.shape,
        w2.shape,
        topk_ids.shape[1],
        None,
        is_marlin=True,
    )
    config = get_config_func(M)

    block_size_m = config["BLOCK_SIZE_M"]

    sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E)

    max_workspace_size = (max(2 * N, K) // 64) * 16
    workspace = torch.zeros(max_workspace_size,
                            dtype=torch.int,
                            device="cuda",
                            requires_grad=False)

    if has_no_zp:
        w1_zeros = torch.empty((0, 0),
                               dtype=hidden_states.dtype,
                               device=hidden_states.device,
                               requires_grad=False)
        w2_zeros = torch.empty((0, 0),
                               dtype=hidden_states.dtype,
                               device=hidden_states.device,
                               requires_grad=False)

    if has_no_act_order:
        g_idx1 = torch.empty((0, 0),
                             dtype=torch.int32,
                             device=hidden_states.device,
                             requires_grad=False)
        g_idx2 = torch.empty((0, 0),
                             dtype=torch.int32,
                             device=hidden_states.device,
                             requires_grad=False)
        sort_indices1 = torch.empty((0),
                                    dtype=torch.int32,
                                    device=hidden_states.device,
                                    requires_grad=False)
        sort_indices2 = torch.empty((0, 0),
                                    dtype=torch.int32,
                                    device=hidden_states.device,
                                    requires_grad=False)

    scalar_type1 = get_scalar_type(num_bits, has_all_zp)
    scalar_type2 = get_scalar_type(num_bits, has_all_zp)

    intermediate_cache2 = torch.empty(
        (M * topk_ids.shape[1], N),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )

    intermediate_cache1 = torch.ops._moe_C.marlin_gemm_moe(
        hidden_states,
        w1,
        sorted_token_ids,
        topk_weights,
        topk_ids,
        w1_scale,
        w1_zeros,
        g_idx1,
        sort_indices1,
        workspace,
        scalar_type1.id,
        M,
        2 * N,
        K,
        is_k_full,
        E,
        topk,
        block_size_m,
        True,
        False,
    )

    torch.ops._C.silu_and_mul(intermediate_cache2,
                              intermediate_cache1.view(-1, 2 * N))

    intermediate_cache3 = torch.ops._moe_C.marlin_gemm_moe(
        intermediate_cache2,
        w2,
        sorted_token_ids,
        topk_weights,
        topk_ids,
        w2_scale,
        w2_zeros,
        g_idx2,
        sort_indices2,
        workspace,
        scalar_type2.id,
        M,
        K,
        N,
        is_k_full,
        E,
        topk,
        block_size_m,
        False,
        True,
    )

    return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
                     dim=1)


def fused_marlin_moe_fake(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    gating_output: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    g_idx1: Optional[torch.Tensor] = None,
    g_idx2: Optional[torch.Tensor] = None,
    sort_indices1: Optional[torch.Tensor] = None,
    sort_indices2: Optional[torch.Tensor] = None,
    w1_zeros: Optional[torch.Tensor] = None,
    w2_zeros: Optional[torch.Tensor] = None,
    num_bits: int = 8,
    is_k_full: bool = True,
) -> torch.Tensor:
    return torch.empty_like(hidden_states)


direct_register_custom_op(
    op_name="fused_marlin_moe",
    op_func=fused_marlin_moe,
    mutates_args=[],
    fake_impl=fused_marlin_moe_fake,
)

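# A minimal, shape-only illustration (hypothetical sizes) of the constraints
# checked in fused_marlin_moe above. Running the real op additionally requires
# vLLM's compiled _moe_C Marlin kernels and a CUDA device; this snippet only
# verifies the documented shape relationships.
M, K, N, E, num_bits = 32, 4096, 14336, 8, 4
w1_dim1 = K // 16                      # so that w1.shape[1] * 16 == K
w2_dim1, w2_dim2 = N // 16, K * (num_bits // 2)
assert w1_dim1 * 16 == K
assert w2_dim2 // (num_bits // 2) == K
# gating_output has shape [M, E]; the returned tensor has shape [M, K].
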
.venv/lib/python3.11/site-packages/vllm/model_executor/layers/fused_moe/fused_moe.py
ADDED
@@ -0,0 +1,1363 @@
# SPDX-License-Identifier: Apache-2.0
"""Fused MoE kernel."""
import functools
import json
import os
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch
import triton
import triton.language as tl

import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    per_token_group_quant_fp8)
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op

logger = init_logger(__name__)


@triton.jit
def fused_moe_kernel_gptq_awq(
        # Pointers to matrices
        a_ptr,
        b_ptr,
        c_ptr,
        b_scale_ptr,
        b_zp_ptr,
        topk_weights_ptr,
        sorted_token_ids_ptr,
        expert_ids_ptr,
        num_tokens_post_padded_ptr,
        # Matrix dimensions
        N: tl.constexpr,
        K: tl.constexpr,
        EM,
        num_valid_tokens,
        # The stride variables represent how much to increase the ptr by when
        # moving by 1 element in a particular dimension. E.g. `stride_am` is
        # how much to increase `a_ptr` by to get the element one row down
        # (A has M rows).
        stride_am,
        stride_ak,
        stride_be,
        stride_bk,
        stride_bn,
        stride_cm,
        stride_cn,
        stride_bse,
        stride_bsk,
        stride_bsn,
        stride_bze,
        stride_bzk,
        stride_bzn,
        block_k_diviable: tl.constexpr,
        group_size: tl.constexpr,
        # Meta-parameters
        BLOCK_SIZE_M: tl.constexpr,
        BLOCK_SIZE_N: tl.constexpr,
        BLOCK_SIZE_K: tl.constexpr,
        GROUP_SIZE_M: tl.constexpr,
        MUL_ROUTED_WEIGHT: tl.constexpr,
        top_k: tl.constexpr,
        compute_type: tl.constexpr,
        has_zp: tl.constexpr,
        use_int4_w4a16: tl.constexpr,
        use_int8_w8a16: tl.constexpr):
    """
    Implements the fused computation for a Mixture of Experts (MOE) using
    token and expert matrices.

    Key Parameters:
    - A: The input tensor representing tokens with shape (*, K), where '*' can
        be any shape representing batches and K is the feature dimension of
        each token.
    - B: The stacked MOE weight tensor with shape (E, N, K), where E is
        the number of experts, K is the input feature dimension, and N is
        the output feature dimension.
    - C: The output cache tensor with shape (M, topk, N), where M is the
        total number of tokens post padding, topk is the number of times
        each token is repeated, and N is the output feature dimension.
    - sorted_token_ids: A tensor containing the sorted indices of tokens,
        repeated topk times and arranged by the expert index they are
        assigned to.
    - expert_ids: A tensor containing the indices of the expert for each
        block. It determines which expert matrix from B should be used for
        each block in A.
    This kernel performs the multiplication of a token by its corresponding
    expert matrix as determined by `expert_ids`. The sorting of
    `sorted_token_ids` by expert index and padding ensures divisibility by
    BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix
    multiplication across different blocks processed by the same expert.
    """
    # -----------------------------------------------------------
    # Map program ids `pid` to the block of C it should compute.
    # This is done in a grouped ordering to promote L2 data reuse.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # ----------------------------------------------------------
    # Create pointers for the first blocks of A and B.
    # We will advance this pointer as we move in the K direction
    # and accumulate
    # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
    # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
    num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
    if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
        return
    offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(
        tl.int64)
    offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
    token_mask = offs_token < num_valid_tokens

    offs_bn = (pid_n * BLOCK_SIZE_N +
               tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am +
                      offs_k[None, :] * stride_ak)

    off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)

    if use_int4_w4a16:
        b_ptrs = b_ptr + off_experts * stride_be + \
            (offs_k[:, None] // 2) * stride_bk + offs_bn[None, :] * stride_bn
        b_shifter = (offs_k[:, None] % 2) * 4
    elif use_int8_w8a16:
        b_ptrs = b_ptr + off_experts * stride_be + \
            offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn

    if not has_zp and use_int4_w4a16:
        b_zp_num = 8
    if not has_zp and use_int8_w8a16:
        b_zp_num = 128
    elif has_zp and use_int4_w4a16:
        b_zp_shifter = (offs_bn[None, :] % 2) * 4

    # -----------------------------------------------------------
    # Iterate to compute a block of the C matrix.
    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
    # of fp32 values for higher accuracy.
    # `accumulator` will be converted back to fp16 after the loop.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Load the next block of A and B, generate a mask by checking the
        # K dimension.

        if not block_k_diviable:
            k_mask = offs_k[:, None] < K - k * BLOCK_SIZE_K
            k_other = 0.0
        else:
            k_mask = None
            k_other = None

        a = tl.load(a_ptrs,
                    mask=token_mask[:, None] &
                    (offs_k[None, :] < K - k * BLOCK_SIZE_K),
                    other=0.0)
        b = tl.load(b_ptrs)
        if use_int4_w4a16:
            b = (b >> b_shifter) & 0xF

        b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + \
            offs_bn[None, :] * stride_bsn + \
            ((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * stride_bsk
        b_scale = tl.load(b_scale_ptrs, mask=k_mask, other=k_other)
        b_scale = b_scale.to(tl.float32)

        if has_zp and use_int4_w4a16:
            offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size
            b_zp_ptrs = b_zp_ptr + off_experts * stride_bze + \
                (offs_bn[None, :] // 2) * stride_bzn + \
                offs_k_true * stride_bzk
            b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other)
            b_zp = ((b_zp >> b_zp_shifter) & 0xF)
            b_zp = b_zp.to(tl.float32)
        elif has_zp and use_int8_w8a16:
            offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size
            b_zp_ptrs = b_zp_ptr + off_experts * stride_bze + \
                offs_bn[None, :] * stride_bzn + \
                offs_k_true * stride_bzk
            b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other)
            b_zp = b_zp.to(tl.float32)

        # We accumulate along the K dimension.
        if has_zp:
            b = ((b.to(tl.float32) - b_zp) * b_scale).to(compute_type)
        else:
            b = ((b.to(tl.float32) - b_zp_num) * b_scale).to(compute_type)
        accumulator = tl.dot(a, b, acc=accumulator)

        # Advance the ptrs to the next K block.
        a_ptrs += BLOCK_SIZE_K * stride_ak
        if use_int4_w4a16:
            b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk
        else:
            b_ptrs += BLOCK_SIZE_K * stride_bk

    if MUL_ROUTED_WEIGHT:
        moe_weight = tl.load(topk_weights_ptr + offs_token,
                             mask=token_mask,
                             other=0)
        accumulator = accumulator * moe_weight[:, None]

    accumulator = accumulator.to(compute_type)
    # -----------------------------------------------------------
    # Write back the block of the output
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[
        None, :]
    c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
    tl.store(c_ptrs, accumulator, mask=c_mask)

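# A pure-Python sketch of the grouped program-id mapping used above (and again
# in fused_moe_kernel below) to promote L2 data reuse. The arithmetic mirrors
# the kernel; the helper name and the sizes in the trailing comment are only
# illustrative.
def _grouped_pid_demo(pid, EM, N, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M):
    cdiv = lambda a, b: (a + b - 1) // b
    num_pid_m = cdiv(EM, BLOCK_SIZE_M)
    num_pid_n = cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m
    return pid_m, pid_n

# e.g. with EM = N = 256, 64x64 blocks and GROUP_SIZE_M = 2, consecutive pids
# visit GROUP_SIZE_M row-blocks of C before advancing to the next column-block,
# so the same B tile stays hot in L2.
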
@triton.jit
def fused_moe_kernel(
        # Pointers to matrices
        a_ptr,
        b_ptr,
        c_ptr,
        a_scale_ptr,
        b_scale_ptr,
        topk_weights_ptr,
        sorted_token_ids_ptr,
        expert_ids_ptr,
        num_tokens_post_padded_ptr,
        # Matrix dimensions
        N,
        K,
        EM,
        num_valid_tokens,
        # The stride variables represent how much to increase the ptr by when
        # moving by 1 element in a particular dimension. E.g. `stride_am` is
        # how much to increase `a_ptr` by to get the element one row down
        # (A has M rows).
        stride_am,
        stride_ak,
        stride_be,
        stride_bk,
        stride_bn,
        stride_cm,
        stride_cn,
        stride_asm,
        stride_ask,
        stride_bse,
        stride_bsk,
        stride_bsn,
        # Block size for block-wise quantization
        group_n: tl.constexpr,
        group_k: tl.constexpr,
        # Meta-parameters
        BLOCK_SIZE_M: tl.constexpr,
        BLOCK_SIZE_N: tl.constexpr,
        BLOCK_SIZE_K: tl.constexpr,
        GROUP_SIZE_M: tl.constexpr,
        MUL_ROUTED_WEIGHT: tl.constexpr,
        top_k: tl.constexpr,
        compute_type: tl.constexpr,
        use_fp8_w8a8: tl.constexpr,
        use_int8_w8a16: tl.constexpr):
    """
    Implements the fused computation for a Mixture of Experts (MOE) using
    token and expert matrices.

    Key Parameters:
    - A: The input tensor representing tokens with shape (*, K), where '*' can
        be any shape representing batches and K is the feature dimension of
        each token.
    - B: The stacked MOE weight tensor with shape (E, N, K), where E is
        the number of experts, K is the input feature dimension, and N is
        the output feature dimension.
    - C: The output cache tensor with shape (M, topk, N), where M is the
        total number of tokens post padding, topk is the number of times
        each token is repeated, and N is the output feature dimension.
    - sorted_token_ids: A tensor containing the sorted indices of tokens,
        repeated topk times and arranged by the expert index they are
        assigned to.
    - expert_ids: A tensor containing the indices of the expert for each
        block. It determines which expert matrix from B should be used for
        each block in A.
    This kernel performs the multiplication of a token by its corresponding
    expert matrix as determined by `expert_ids`. The sorting of
    `sorted_token_ids` by expert index and padding ensures divisibility by
    BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix
    multiplication across different blocks processed by the same expert.
    """
    # -----------------------------------------------------------
    # Map program ids `pid` to the block of C it should compute.
    # This is done in a grouped ordering to promote L2 data reuse.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # ----------------------------------------------------------
    # Create pointers for the first blocks of A and B.
    # We will advance this pointer as we move in the K direction
    # and accumulate
    # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
    # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
    num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
    if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
        return
    offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(
        tl.int64)
    offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
    token_mask = offs_token < num_valid_tokens

    offs_bn = (pid_n * BLOCK_SIZE_N +
               tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am +
                      offs_k[None, :] * stride_ak)

    off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
    b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk +
                                                offs_bn[None, :] * stride_bn)
    if use_int8_w8a16:
        b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + offs_bn[
            None, :] * stride_bsn
        b_scale = tl.load(b_scale_ptrs)

    if use_fp8_w8a8:
        if group_k > 0 and group_n > 0:
            a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
            offs_bsn = offs_bn // group_n
            b_scale_ptrs = (b_scale_ptr + off_experts * stride_bse +
                            offs_bsn * stride_bsn)
        else:
            a_scale = tl.load(a_scale_ptr)
            b_scale = tl.load(b_scale_ptr + off_experts)

    # -----------------------------------------------------------
    # Iterate to compute a block of the C matrix.
    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
    # of fp32 values for higher accuracy.
    # `accumulator` will be converted back to fp16 after the loop.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Load the next block of A and B, generate a mask by checking the
        # K dimension.
        a = tl.load(a_ptrs,
                    mask=token_mask[:, None] &
                    (offs_k[None, :] < K - k * BLOCK_SIZE_K),
                    other=0.0)
        b = tl.load(b_ptrs,
                    mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
                    other=0.0)
        # We accumulate along the K dimension.
        if use_int8_w8a16:
            accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
        elif use_fp8_w8a8:
            if group_k > 0 and group_n > 0:
                k_start = k * BLOCK_SIZE_K
                offs_ks = k_start // group_k
                a_scale = tl.load(a_scale_ptrs + offs_ks * stride_ask,
                                  mask=token_mask,
                                  other=0.0)
                b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk)

                accumulator += tl.dot(a, b) * a_scale[:,
                                                      None] * b_scale[None, :]
            else:
                accumulator = tl.dot(a, b, acc=accumulator)
        else:
            accumulator += tl.dot(a, b)
        # Advance the ptrs to the next K block.
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    if MUL_ROUTED_WEIGHT:
        moe_weight = tl.load(topk_weights_ptr + offs_token,
                             mask=token_mask,
                             other=0)
        accumulator = accumulator * moe_weight[:, None]
    if use_int8_w8a16:
        accumulator = (accumulator * b_scale).to(compute_type)
    elif use_fp8_w8a8:
        if group_k > 0 and group_n > 0:
            accumulator = accumulator.to(compute_type)
        else:
            accumulator = (accumulator * a_scale * b_scale).to(compute_type)
    else:
        accumulator = accumulator.to(compute_type)
    # -----------------------------------------------------------
    # Write back the block of the output
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[
        None, :]
    c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
    tl.store(c_ptrs, accumulator, mask=c_mask)

def ceil_div(a, b):
    return (a + b - 1) // b


@triton.jit
def moe_align_block_size_stage1(
    topk_ids_ptr,
    tokens_cnts_ptr,
    num_experts: tl.constexpr,
    numel: tl.constexpr,
    tokens_per_thread: tl.constexpr,
):
    pid = tl.program_id(0)

    start_idx = pid * tokens_per_thread

    off_c = (pid + 1) * num_experts

    for i in range(tokens_per_thread):
        if start_idx + i < numel:
            idx = tl.load(topk_ids_ptr + start_idx + i)
            token_cnt = tl.load(tokens_cnts_ptr + off_c + idx)
            tl.store(tokens_cnts_ptr + off_c + idx, token_cnt + 1)


@triton.jit
def moe_align_block_size_stage2(
    tokens_cnts_ptr,
    num_experts: tl.constexpr,
):
    pid = tl.program_id(0)

    last_cnt = 0
    for i in range(1, num_experts + 1):
        token_cnt = tl.load(tokens_cnts_ptr + i * num_experts + pid)
        last_cnt = last_cnt + token_cnt
        tl.store(tokens_cnts_ptr + i * num_experts + pid, last_cnt)


@triton.jit
def moe_align_block_size_stage3(
    total_tokens_post_pad_ptr,
    tokens_cnts_ptr,
    cumsum_ptr,
    num_experts: tl.constexpr,
    block_size: tl.constexpr,
):
    last_cumsum = 0
    off_cnt = num_experts * num_experts
    for i in range(1, num_experts + 1):
        token_cnt = tl.load(tokens_cnts_ptr + off_cnt + i - 1)
        last_cumsum = last_cumsum + tl.cdiv(token_cnt, block_size) * block_size
        tl.store(cumsum_ptr + i, last_cumsum)
    tl.store(total_tokens_post_pad_ptr, last_cumsum)


@triton.jit
def moe_align_block_size_stage4(
    topk_ids_ptr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    tokens_cnts_ptr,
    cumsum_ptr,
    num_experts: tl.constexpr,
    block_size: tl.constexpr,
    numel: tl.constexpr,
    tokens_per_thread: tl.constexpr,
):
    pid = tl.program_id(0)
    start_idx = tl.load(cumsum_ptr + pid)
    end_idx = tl.load(cumsum_ptr + pid + 1)

    for i in range(start_idx, end_idx, block_size):
        tl.store(expert_ids_ptr + i // block_size, pid)

    start_idx = pid * tokens_per_thread
    off_t = pid * num_experts

    for i in range(start_idx, tl.minimum(start_idx + tokens_per_thread,
                                         numel)):
        expert_id = tl.load(topk_ids_ptr + i)
        token_cnt = tl.load(tokens_cnts_ptr + off_t + expert_id)
        rank_post_pad = token_cnt + tl.load(cumsum_ptr + expert_id)
        tl.store(sorted_token_ids_ptr + rank_post_pad, i)
        tl.store(tokens_cnts_ptr + off_t + expert_id, token_cnt + 1)

# Triton implementation based on:
# https://github.com/sgl-project/sglang/commit/ba5112ff691d791a9e38c6c71f59324a5fcb49d0
def moe_align_block_size_triton(
    topk_ids: torch.Tensor,
    num_experts: int,
    block_size: int,
    sorted_token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    num_tokens_post_pad: torch.Tensor,
) -> None:
    numel = topk_ids.numel()
    grid = (num_experts, )
    tokens_cnts = torch.zeros((num_experts + 1, num_experts),
                              dtype=torch.int32,
                              device=topk_ids.device)
    cumsum = torch.zeros((num_experts + 1, ),
                         dtype=torch.int32,
                         device=topk_ids.device)
    tokens_per_thread = ceil_div(numel, num_experts)

    moe_align_block_size_stage1[grid](
        topk_ids,
        tokens_cnts,
        num_experts,
        numel,
        tokens_per_thread,
    )
    moe_align_block_size_stage2[grid](
        tokens_cnts,
        num_experts,
    )
    moe_align_block_size_stage3[(1, )](
        num_tokens_post_pad,
        tokens_cnts,
        cumsum,
        num_experts,
        block_size,
    )
    moe_align_block_size_stage4[grid](
        topk_ids,
        sorted_token_ids,
        expert_ids,
        tokens_cnts,
        cumsum,
        num_experts,
        block_size,
        numel,
        tokens_per_thread,
    )


def moe_align_block_size(
        topk_ids: torch.Tensor, block_size: int,
        num_experts: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Aligns the token distribution across experts to be compatible with block
    size for matrix multiplication.

    Parameters:
    - topk_ids: A tensor of shape [total_tokens, top_k] representing the
        top-k expert indices for each token.
    - block_size: The block size used in block matrix multiplication.
    - num_experts: The total number of experts.

    Returns:
    - sorted_token_ids: A tensor containing the sorted token indices according
        to their allocated expert.
    - expert_ids: A tensor indicating the assigned expert index for each block.
    - num_tokens_post_padded: The total number of tokens after padding,
        ensuring divisibility by block_size.

    This function pads the number of tokens that each expert needs to process
    so that it is divisible by block_size.
    Padding ensures that during block matrix multiplication, the dimensions
    align correctly.

    Example:
    Given topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]],
    block_size = 4, and num_experts = 4:
    - We initially have 12 tokens (after repeating 'top_k' times) and 4 experts,
        with each expert needing to process 3 tokens.
    - As block_size is 4, we pad 1 token for each expert.
    - First, flatten topk_ids to [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3].
    - Then append padding tokens [12, 12, 12, 12] for each block.
    - After sorting by expert index, we obtain token_ids
        [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12].
        Tokens 12 are non-existent (padding) and are ignored in
        the subsequent matrix multiplication.
    - The padding ensures that the total number of tokens is now divisible
        by block_size for proper block matrix operations.
    """
    max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
    sorted_ids = torch.empty((max_num_tokens_padded, ),
                             dtype=torch.int32,
                             device=topk_ids.device)
    sorted_ids.fill_(topk_ids.numel())
    max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
    expert_ids = torch.empty((max_num_m_blocks, ),
                             dtype=torch.int32,
                             device=topk_ids.device)
    num_tokens_post_pad = torch.empty((1),
                                      dtype=torch.int32,
                                      device=topk_ids.device)
    if num_experts >= 224:
        if envs.VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON:
            moe_align_block_size_triton(
                topk_ids,
                num_experts,
                block_size,
                sorted_ids,
                expert_ids,
                num_tokens_post_pad,
            )
        else:
            ops.sgl_moe_align_block_size(
                topk_ids,
                num_experts,
                block_size,
                sorted_ids,
                expert_ids,
                num_tokens_post_pad,
            )
    else:
        ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
                                 expert_ids, num_tokens_post_pad)
    return sorted_ids, expert_ids, num_tokens_post_pad

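# A pure-PyTorch reference of the alignment described in the docstring above,
# handy for checking the Triton/CUDA paths on CPU. The helper name and the
# exact-size (rather than fixed-size) output buffers are illustrative and not
# part of vLLM's API.
def _moe_align_block_size_ref(topk_ids: torch.Tensor, block_size: int,
                              num_experts: int):
    flat = topk_ids.flatten()
    numel = flat.numel()
    counts = torch.bincount(flat.long(), minlength=num_experts)
    padded = ((counts + block_size - 1) // block_size) * block_size
    total = int(padded.sum())
    sorted_ids = torch.full((total, ), numel, dtype=torch.int32)  # pad value
    expert_ids = torch.empty((total // block_size, ), dtype=torch.int32)
    starts = torch.cumsum(padded, 0) - padded  # start offset of each expert
    for e in range(num_experts):
        s = int(starts[e])
        idx = (flat == e).nonzero(as_tuple=True)[0]
        sorted_ids[s:s + idx.numel()] = idx.to(torch.int32)
        expert_ids[s // block_size:(s + int(padded[e])) // block_size] = e
    return sorted_ids, expert_ids, torch.tensor([total], dtype=torch.int32)
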
def invoke_fused_moe_kernel(A: torch.Tensor,
                            B: torch.Tensor,
                            C: torch.Tensor,
                            A_scale: Optional[torch.Tensor],
                            B_scale: Optional[torch.Tensor],
                            B_zp: Optional[torch.Tensor],
                            topk_weights: torch.Tensor,
                            topk_ids: torch.Tensor,
                            sorted_token_ids: torch.Tensor,
                            expert_ids: torch.Tensor,
                            num_tokens_post_padded: torch.Tensor,
                            mul_routed_weight: bool,
                            top_k: int,
                            config: Dict[str, Any],
                            compute_type: tl.dtype,
                            use_fp8_w8a8: bool,
                            use_int8_w8a16: bool,
                            use_int4_w4a16: bool,
                            block_shape: Optional[List[int]] = None) -> None:
    assert topk_weights.stride(1) == 1
    assert sorted_token_ids.stride(0) == 1

    if use_fp8_w8a8:
        assert B_scale is not None
        if block_shape is None:
            A, A_scale = ops.scaled_fp8_quant(A, A_scale)
        else:
            assert len(block_shape) == 2
            block_n, block_k = block_shape[0], block_shape[1]
            A, A_scale = per_token_group_quant_fp8(A, block_k)
            assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
            assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
            assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
    elif use_int8_w8a16 or use_int4_w4a16:
        assert B_scale is not None
        assert block_shape is None or block_shape[0] == 0
    else:
        assert A_scale is None
        assert B_scale is None

    EM = sorted_token_ids.shape[0]
    if A.shape[0] < config["BLOCK_SIZE_M"]:
        # optimize for small batch_size.
        # We assume that top_ids of each token is unique, so
        # so num_valid_experts <= batch_size <= BLOCK_SIZE_M,
        # and we can skip some invalid blocks.
        EM = min(sorted_token_ids.shape[0],
                 A.shape[0] * top_k * config['BLOCK_SIZE_M'])
    grid = lambda META: (triton.cdiv(EM, META['BLOCK_SIZE_M']) * triton.cdiv(
        B.shape[1], META['BLOCK_SIZE_N']), )

    if (use_int8_w8a16 or use_int4_w4a16) and \
            block_shape is not None and block_shape[1] > 0:
        assert B_scale is not None and B_scale.ndim == 3
        assert B_zp is None or B_zp.ndim == 3

        fused_moe_kernel_gptq_awq[grid](
            A,
            B,
            C,
            B_scale,
            B_zp,
            topk_weights,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            B.shape[1],
            A.shape[1],
            EM,
            topk_ids.numel(),
            A.stride(0),
            A.stride(1),
            B.stride(0),
            B.stride(2),
            B.stride(1),
            C.stride(1),
            C.stride(2),
            B_scale.stride(0),
            B_scale.stride(2),
            B_scale.stride(1),
            B_zp.stride(0) if B_zp is not None else 0,
            B_zp.stride(2) if B_zp is not None else 0,
            B_zp.stride(1) if B_zp is not None else 0,
            block_k_diviable=A.shape[1] % config["BLOCK_SIZE_K"] == 0,
            group_size=block_shape[1],
            MUL_ROUTED_WEIGHT=mul_routed_weight,
            top_k=top_k,
            compute_type=compute_type,
            has_zp=B_zp is not None,
            use_int4_w4a16=use_int4_w4a16,
            use_int8_w8a16=use_int8_w8a16,
            **config,
        )

    else:
        fused_moe_kernel[grid](
            A,
            B,
            C,
            A_scale,
            B_scale,
            topk_weights,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            B.shape[1],
            A.shape[1],
            EM,
            topk_ids.numel(),
            A.stride(0),
            A.stride(1),
            B.stride(0),
            B.stride(2),
            B.stride(1),
            C.stride(1),
            C.stride(2),
            A_scale.stride(0)
            if A_scale is not None and A_scale.ndim == 2 else 0,
            A_scale.stride(1)
            if A_scale is not None and A_scale.ndim == 2 else 0,
            B_scale.stride(0)
            if B_scale is not None and B_scale.ndim >= 2 else 0,
            B_scale.stride(2)
            if B_scale is not None and B_scale.ndim == 3 else 0,
            B_scale.stride(1)
            if B_scale is not None and B_scale.ndim >= 2 else 0,
            0 if block_shape is None else block_shape[0],
            0 if block_shape is None else block_shape[1],
            MUL_ROUTED_WEIGHT=mul_routed_weight,
            top_k=top_k,
            compute_type=compute_type,
            use_fp8_w8a8=use_fp8_w8a8,
            use_int8_w8a16=use_int8_w8a16,
            **config,
        )


# Adapted from: https://github.com/sgl-project/sglang/pull/2628
def get_config_file_name(E: int,
                         N: int,
                         dtype: Optional[str],
                         block_shape: Optional[List[int]] = None) -> str:
    device_name = current_platform.get_device_name().replace(" ", "_")
    dtype_selector = "" if not dtype else f",dtype={dtype}"
    block_shape_selector = ("" if not block_shape or not all(block_shape) else
                            f",block_shape={block_shape}").replace(" ", "")
    return f"E={E},N={N},device_name={device_name}{dtype_selector}{block_shape_selector}.json"  # noqa: E501

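# Illustration of the file name the helper above produces for a hypothetical
# layer with E=8 experts, N=14336, fp8 weights/activations and 128x128 quant
# blocks; the device string is an assumed example of what
# current_platform.get_device_name() might report.
_E, _N, _dtype, _block_shape = 8, 14336, "fp8_w8a8", [128, 128]
_device = "NVIDIA_H100_80GB_HBM3"
_name = (f"E={_E},N={_N},device_name={_device},dtype={_dtype}" +
         f",block_shape={_block_shape}".replace(" ", "") + ".json")
# _name == "E=8,N=14336,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json"
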
# Adapted from: https://github.com/sgl-project/sglang/pull/2628
@functools.lru_cache
def get_moe_configs(
    E: int,
    N: int,
    dtype: Optional[str],
    block_n: Optional[int] = None,
    block_k: Optional[int] = None,
) -> Optional[Dict[int, Any]]:
    """
    Return optimized configurations for the fused MoE kernel.

    The return value will be a dictionary that maps an irregular grid of
    batch sizes to configurations of the fused_moe kernel. To evaluate the
    kernel on a given batch size bs, the closest batch size in the grid should
    be picked and the associated configuration chosen to invoke the kernel.
    """

    # First look up if an optimized configuration is available in the configs
    # directory
    block_shape = [block_n, block_k] if block_n and block_k else None
    json_file_name = get_config_file_name(E, N, dtype, block_shape)

    config_file_path = os.path.join(
        os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name)
    if os.path.exists(config_file_path):
        with open(config_file_path) as f:
            logger.info("Using configuration from %s for MoE layer.",
                        config_file_path)
            # If a configuration has been found, return it
            return {int(key): val for key, val in json.load(f).items()}

    # If no optimized configuration is available, we will use the default
    # configuration
    logger.warning(
        ("Using default MoE config. Performance might be sub-optimal! "
         "Config file not found at %s"), config_file_path)
    return None


def get_default_config(
    M: int,
    E: int,
    N: int,
    K: int,
    topk: int,
    dtype: Optional[str],
    is_marlin: bool,
    block_shape: Optional[List[int]] = None,
) -> Dict[str, int]:
    if dtype == "fp8_w8a8" and block_shape is not None:
        # Block-wise quant: BLOCK_SIZE_N must be divisible by block_shape[0]
        # BLOCK_SIZE_K must be divisible by block_shape[1]
        config = {
            "BLOCK_SIZE_M": 64,
            "BLOCK_SIZE_N": block_shape[0],
            "BLOCK_SIZE_K": block_shape[1],
            "GROUP_SIZE_M": 32,
            "num_warps": 4,
            "num_stages": 3,
        }
    else:
        config = {
            "BLOCK_SIZE_M": 64,
            "BLOCK_SIZE_N": 64,
            "BLOCK_SIZE_K": 32,
            "GROUP_SIZE_M": 8,
        }
        # A heuristic: fused marlin works faster with this config for small M
        if M <= E or (is_marlin and M <= 32):
            config = {
                "BLOCK_SIZE_M": 16,
                "BLOCK_SIZE_N": 32,
                "BLOCK_SIZE_K": 64,
                "GROUP_SIZE_M": 1,
            }
    return config


def try_get_optimal_moe_config(
    w1_shape: Tuple[int, ...],
    w2_shape: Tuple[int, ...],
    top_k: int,
    dtype: Optional[str],
    M: int,
    is_marlin: bool = False,
    block_shape: Optional[List[int]] = None,
):
    from vllm.model_executor.layers.fused_moe import get_config
    override_config = get_config()
    if override_config:
        config = override_config
    else:
        # First try to load optimal config from the file
        E, _, N = w2_shape
        block_n = block_shape[0] if block_shape else 0
        block_k = block_shape[1] if block_shape else 0
        configs = get_moe_configs(E, N, dtype, block_n, block_k)

        if configs:
            # If an optimal configuration map has been found, look up the
            # optimal config
            config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
        else:
            # Else use the default config
            config = get_default_config(M, E, N, w1_shape[2], top_k, dtype,
                                        is_marlin, block_shape)
    return config


def fused_topk(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
):
    assert hidden_states.shape[0] == gating_output.shape[0], (
        "Number of tokens mismatch")

    M, _ = hidden_states.shape

    topk_weights = torch.empty(M,
                               topk,
                               dtype=torch.float32,
                               device=hidden_states.device)
    topk_ids = torch.empty(M,
                           topk,
                           dtype=torch.int32,
                           device=hidden_states.device)
    token_expert_indicies = torch.empty(M,
                                        topk,
                                        dtype=torch.int32,
                                        device=hidden_states.device)

    ops.topk_softmax(
        topk_weights,
        topk_ids,
        token_expert_indicies,
        gating_output.float(),  # TODO(woosuk): Optimize this.
    )
    del token_expert_indicies  # Not used. Will be used in the future.

    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

    return topk_weights, topk_ids

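# A reference (non-fused) equivalent of the routing math in fused_topk above,
# convenient for CPU tests: ops.topk_softmax fuses the same softmax + top-k
# selection. The helper name is illustrative, not part of vLLM's API.
def _fused_topk_ref(gating_output: torch.Tensor, topk: int,
                    renormalize: bool):
    scores = torch.softmax(gating_output.float(), dim=-1)
    topk_weights, topk_ids = torch.topk(scores, k=topk, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids.to(torch.int32)
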
# This is used by the Deepseek-V2 and Deepseek-V3 model
@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
def grouped_topk(hidden_states: torch.Tensor,
                 gating_output: torch.Tensor,
                 topk: int,
                 renormalize: bool,
                 num_expert_group: int = 0,
                 topk_group: int = 0,
                 scoring_func: str = "softmax",
                 e_score_correction_bias: Optional[torch.Tensor] = None):

    assert hidden_states.shape[0] == gating_output.shape[0], (
        "Number of tokens mismatch")

    if scoring_func == "softmax":
        scores = torch.softmax(gating_output, dim=-1)
    elif scoring_func == "sigmoid":
        scores = gating_output.sigmoid()
    else:
        raise ValueError(f"Unsupported scoring function: {scoring_func}")

    if e_score_correction_bias is not None:
        # Store original scores before applying correction bias. We use biased
        # scores for expert selection but original scores for routing weights
        original_scores = scores
        scores = scores + e_score_correction_bias.unsqueeze(0)

    num_token = scores.shape[0]
    group_scores = scores.view(num_token, num_expert_group,
                               -1).max(dim=-1).values  # [n, n_group]
    group_idx = torch.topk(group_scores, k=topk_group, dim=-1,
                           sorted=False)[1]  # [n, top_k_group]
    group_mask = torch.zeros_like(group_scores)  # [n, n_group]
    group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
    score_mask = group_mask.unsqueeze(-1).expand(
        num_token, num_expert_group,
        scores.shape[-1] // num_expert_group).reshape(num_token, -1)  # [n, e]
    tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)  # [n, e]

    if e_score_correction_bias is not None:
        topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1]
        # Use original unbiased scores for the routing weights
        topk_weights = original_scores.gather(1, topk_ids)
    else:
        topk_weights, topk_ids = torch.topk(tmp_scores,
                                            k=topk,
                                            dim=-1,
                                            sorted=False)

    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

    return topk_weights.to(torch.float32), topk_ids.to(torch.int32)


def get_config_dtype_str(dtype: torch.dtype,
                         use_int4_w4a16: Optional[bool] = False,
                         use_int8_w8a16: Optional[bool] = False,
                         use_fp8_w8a8: Optional[bool] = False):
    if use_fp8_w8a8:
        return "fp8_w8a8"
    elif use_int8_w8a16:
        return "int8_w8a16"
    elif use_int4_w4a16:
        return "int4_w8a16"
    elif dtype == torch.float:
        # avoiding cases where kernel fails when float32 MoE
        # use fp16/bfloat16 configs
        return "float32"
    return None


def inplace_fused_experts(hidden_states: torch.Tensor,
                          w1: torch.Tensor,
                          w2: torch.Tensor,
                          topk_weights: torch.Tensor,
                          topk_ids: torch.Tensor,
                          use_fp8_w8a8: bool = False,
                          use_int8_w8a16: bool = False,
                          use_int4_w4a16: bool = False,
                          w1_scale: Optional[torch.Tensor] = None,
                          w2_scale: Optional[torch.Tensor] = None,
                          w1_zp: Optional[torch.Tensor] = None,
                          w2_zp: Optional[torch.Tensor] = None,
                          a1_scale: Optional[torch.Tensor] = None,
                          a2_scale: Optional[torch.Tensor] = None,
                          block_shape: Optional[List[int]] = None) -> None:
    fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True,
                       use_fp8_w8a8, use_int8_w8a16, use_int4_w4a16, w1_scale,
                       w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, block_shape)


def inplace_fused_experts_fake(
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        use_fp8_w8a8: bool = False,
        use_int8_w8a16: bool = False,
        use_int4_w4a16: bool = False,
        w1_scale: Optional[torch.Tensor] = None,
        w2_scale: Optional[torch.Tensor] = None,
        w1_zp: Optional[torch.Tensor] = None,
        w2_zp: Optional[torch.Tensor] = None,
        a1_scale: Optional[torch.Tensor] = None,
        a2_scale: Optional[torch.Tensor] = None,
        block_shape: Optional[List[int]] = None) -> None:
    pass


direct_register_custom_op(
    op_name="inplace_fused_experts",
    op_func=inplace_fused_experts,
    mutates_args=["hidden_states"],
    fake_impl=inplace_fused_experts_fake,
)


def outplace_fused_experts(
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        use_fp8_w8a8: bool = False,
        use_int8_w8a16: bool = False,
        use_int4_w4a16: bool = False,
        w1_scale: Optional[torch.Tensor] = None,
        w2_scale: Optional[torch.Tensor] = None,
        w1_zp: Optional[torch.Tensor] = None,
        w2_zp: Optional[torch.Tensor] = None,
        a1_scale: Optional[torch.Tensor] = None,
        a2_scale: Optional[torch.Tensor] = None,
        block_shape: Optional[List[int]] = None) -> torch.Tensor:
    return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids,
                              False, use_fp8_w8a8, use_int8_w8a16,
                              use_int4_w4a16, w1_scale, w2_scale, w1_zp, w2_zp,
                              a1_scale, a2_scale, block_shape)


def outplace_fused_experts_fake(
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        use_fp8_w8a8: bool = False,
        use_int8_w8a16: bool = False,
        use_int4_w4a16: bool = False,
        w1_scale: Optional[torch.Tensor] = None,
        w2_scale: Optional[torch.Tensor] = None,
        w1_zp: Optional[torch.Tensor] = None,
        w2_zp: Optional[torch.Tensor] = None,
        a1_scale: Optional[torch.Tensor] = None,
        a2_scale: Optional[torch.Tensor] = None,
        block_shape: Optional[List[int]] = None) -> torch.Tensor:
    return torch.empty_like(hidden_states)


direct_register_custom_op(
    op_name="outplace_fused_experts",
    op_func=outplace_fused_experts,
    mutates_args=[],
    fake_impl=outplace_fused_experts_fake,
)


def fused_experts(hidden_states: torch.Tensor,
                  w1: torch.Tensor,
                  w2: torch.Tensor,
                  topk_weights: torch.Tensor,
                  topk_ids: torch.Tensor,
                  inplace: bool = False,
                  use_fp8_w8a8: bool = False,
                  use_int8_w8a16: bool = False,
                  use_int4_w4a16: bool = False,
                  w1_scale: Optional[torch.Tensor] = None,
                  w2_scale: Optional[torch.Tensor] = None,
                  w1_zp: Optional[torch.Tensor] = None,
                  w2_zp: Optional[torch.Tensor] = None,
                  a1_scale: Optional[torch.Tensor] = None,
                  a2_scale: Optional[torch.Tensor] = None,
                  block_shape: Optional[List[int]] = None):
    if inplace:
        torch.ops.vllm.inplace_fused_experts(hidden_states, w1, w2,
                                             topk_weights, topk_ids,
                                             use_fp8_w8a8, use_int8_w8a16,
                                             use_int4_w4a16, w1_scale,
                                             w2_scale, w1_zp, w2_zp, a1_scale,
                                             a2_scale, block_shape)
        return hidden_states
    else:
        return torch.ops.vllm.outplace_fused_experts(
            hidden_states, w1, w2, topk_weights, topk_ids, use_fp8_w8a8,
            use_int8_w8a16, use_int4_w4a16, w1_scale, w2_scale, w1_zp, w2_zp,
            a1_scale, a2_scale, block_shape)


def fused_experts_impl(hidden_states: torch.Tensor,
                       w1: torch.Tensor,
                       w2: torch.Tensor,
                       topk_weights: torch.Tensor,
                       topk_ids: torch.Tensor,
                       inplace: bool = False,
                       use_fp8_w8a8: bool = False,
                       use_int8_w8a16: bool = False,
                       use_int4_w4a16: bool = False,
                       w1_scale: Optional[torch.Tensor] = None,
                       w2_scale: Optional[torch.Tensor] = None,
                       w1_zp: Optional[torch.Tensor] = None,
                       w2_zp: Optional[torch.Tensor] = None,
                       a1_scale: Optional[torch.Tensor] = None,
                       a2_scale: Optional[torch.Tensor] = None,
                       block_shape: Optional[List[int]] = None):
    # Check constraints.
    if use_int4_w4a16:
        assert hidden_states.shape[1] // 2 == w1.shape[
            2], "Hidden size mismatch"
    else:
        assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"

    assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
    assert w1.is_contiguous(), "Expert weights1 must be contiguous"
    assert w2.is_contiguous(), "Expert weights2 must be contiguous"
    assert hidden_states.dtype in [
        torch.float32, torch.float16, torch.bfloat16
    ]

    num_tokens, _ = hidden_states.shape
    E, N, _ = w1.shape
    # We execute the fused_moe kernel in chunks to circumvent this issue:
    # https://github.com/vllm-project/vllm/issues/5938
    CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
    M = min(num_tokens, CHUNK_SIZE)
|
| 1156 |
+
config_dtype = get_config_dtype_str(use_fp8_w8a8=use_fp8_w8a8,
|
| 1157 |
+
use_int8_w8a16=use_int8_w8a16,
|
| 1158 |
+
use_int4_w4a16=use_int4_w4a16,
|
| 1159 |
+
dtype=hidden_states.dtype)
|
| 1160 |
+
|
| 1161 |
+
get_config_func = functools.partial(
|
| 1162 |
+
try_get_optimal_moe_config,
|
| 1163 |
+
w1.shape,
|
| 1164 |
+
w2.shape,
|
| 1165 |
+
topk_ids.shape[1],
|
| 1166 |
+
config_dtype,
|
| 1167 |
+
block_shape=block_shape,
|
| 1168 |
+
)
|
| 1169 |
+
|
| 1170 |
+
config = get_config_func(M)
|
| 1171 |
+
|
| 1172 |
+
intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N),
|
| 1173 |
+
device=hidden_states.device,
|
| 1174 |
+
dtype=hidden_states.dtype)
|
| 1175 |
+
intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N // 2),
|
| 1176 |
+
device=hidden_states.device,
|
| 1177 |
+
dtype=hidden_states.dtype)
|
| 1178 |
+
intermediate_cache3 = torch.empty((M, topk_ids.shape[1], w2.shape[1]),
|
| 1179 |
+
device=hidden_states.device,
|
| 1180 |
+
dtype=hidden_states.dtype)
|
| 1181 |
+
|
| 1182 |
+
if hidden_states.dtype == torch.bfloat16:
|
| 1183 |
+
compute_type = tl.bfloat16
|
| 1184 |
+
elif hidden_states.dtype == torch.float16:
|
| 1185 |
+
compute_type = tl.float16
|
| 1186 |
+
elif hidden_states.dtype == torch.float32:
|
| 1187 |
+
compute_type = tl.float32
|
| 1188 |
+
else:
|
| 1189 |
+
raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}")
|
| 1190 |
+
|
| 1191 |
+
if inplace:
|
| 1192 |
+
out_hidden_states = hidden_states
|
| 1193 |
+
else:
|
| 1194 |
+
out_hidden_states = torch.empty_like(hidden_states)
|
| 1195 |
+
|
| 1196 |
+
for chunk in range((num_tokens // CHUNK_SIZE) + 1):
|
| 1197 |
+
begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE,
|
| 1198 |
+
min((chunk + 1) * CHUNK_SIZE,
|
| 1199 |
+
num_tokens))
|
| 1200 |
+
curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx]
|
| 1201 |
+
tokens_in_chunk, _ = curr_hidden_states.shape
|
| 1202 |
+
|
| 1203 |
+
if tokens_in_chunk == 0:
|
| 1204 |
+
break
|
| 1205 |
+
|
| 1206 |
+
if tokens_in_chunk < CHUNK_SIZE and chunk > 0:
|
| 1207 |
+
# Adjust the intermediate cache size and config for the last
|
| 1208 |
+
# chunk. Note that in most cases we only have one chunk
|
| 1209 |
+
# so the cache size and config are already set correctly and
|
| 1210 |
+
# do not need to be adjusted.
|
| 1211 |
+
intermediate_cache1 = intermediate_cache1[:tokens_in_chunk]
|
| 1212 |
+
intermediate_cache2 = intermediate_cache2[:tokens_in_chunk]
|
| 1213 |
+
intermediate_cache3 = intermediate_cache3[:tokens_in_chunk]
|
| 1214 |
+
config = get_config_func(tokens_in_chunk)
|
| 1215 |
+
|
| 1216 |
+
curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
|
| 1217 |
+
curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
|
| 1218 |
+
|
| 1219 |
+
sorted_token_ids, expert_ids, num_tokens_post_padded = (
|
| 1220 |
+
moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'], E))
|
| 1221 |
+
|
| 1222 |
+
invoke_fused_moe_kernel(curr_hidden_states,
|
| 1223 |
+
w1,
|
| 1224 |
+
intermediate_cache1,
|
| 1225 |
+
a1_scale,
|
| 1226 |
+
w1_scale,
|
| 1227 |
+
w1_zp,
|
| 1228 |
+
curr_topk_weights,
|
| 1229 |
+
curr_topk_ids,
|
| 1230 |
+
sorted_token_ids,
|
| 1231 |
+
expert_ids,
|
| 1232 |
+
num_tokens_post_padded,
|
| 1233 |
+
False,
|
| 1234 |
+
topk_ids.shape[1],
|
| 1235 |
+
config,
|
| 1236 |
+
compute_type=compute_type,
|
| 1237 |
+
use_fp8_w8a8=use_fp8_w8a8,
|
| 1238 |
+
use_int8_w8a16=use_int8_w8a16,
|
| 1239 |
+
use_int4_w4a16=use_int4_w4a16,
|
| 1240 |
+
block_shape=block_shape)
|
| 1241 |
+
|
| 1242 |
+
torch.ops._C.silu_and_mul(intermediate_cache2,
|
| 1243 |
+
intermediate_cache1.view(-1, N))
|
| 1244 |
+
|
| 1245 |
+
invoke_fused_moe_kernel(intermediate_cache2,
|
| 1246 |
+
w2,
|
| 1247 |
+
intermediate_cache3,
|
| 1248 |
+
a2_scale,
|
| 1249 |
+
w2_scale,
|
| 1250 |
+
w2_zp,
|
| 1251 |
+
curr_topk_weights,
|
| 1252 |
+
curr_topk_ids,
|
| 1253 |
+
sorted_token_ids,
|
| 1254 |
+
expert_ids,
|
| 1255 |
+
num_tokens_post_padded,
|
| 1256 |
+
True,
|
| 1257 |
+
1,
|
| 1258 |
+
config,
|
| 1259 |
+
compute_type=compute_type,
|
| 1260 |
+
use_fp8_w8a8=use_fp8_w8a8,
|
| 1261 |
+
use_int8_w8a16=use_int8_w8a16,
|
| 1262 |
+
use_int4_w4a16=use_int4_w4a16,
|
| 1263 |
+
block_shape=block_shape)
|
| 1264 |
+
|
| 1265 |
+
ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.shape),
|
| 1266 |
+
out_hidden_states[begin_chunk_idx:end_chunk_idx])
|
| 1267 |
+
return out_hidden_states
|
| 1268 |
+
|
| 1269 |
+
|
| 1270 |
+
def fused_moe(
|
| 1271 |
+
hidden_states: torch.Tensor,
|
| 1272 |
+
w1: torch.Tensor,
|
| 1273 |
+
w2: torch.Tensor,
|
| 1274 |
+
gating_output: torch.Tensor,
|
| 1275 |
+
topk: int,
|
| 1276 |
+
renormalize: bool,
|
| 1277 |
+
inplace: bool = False,
|
| 1278 |
+
use_grouped_topk: bool = False,
|
| 1279 |
+
num_expert_group: Optional[int] = None,
|
| 1280 |
+
topk_group: Optional[int] = None,
|
| 1281 |
+
custom_routing_function: Optional[Callable] = None,
|
| 1282 |
+
use_fp8_w8a8: bool = False,
|
| 1283 |
+
use_int8_w8a16: bool = False,
|
| 1284 |
+
use_int4_w4a16: bool = False,
|
| 1285 |
+
w1_scale: Optional[torch.Tensor] = None,
|
| 1286 |
+
w2_scale: Optional[torch.Tensor] = None,
|
| 1287 |
+
w1_zp: Optional[torch.Tensor] = None,
|
| 1288 |
+
w2_zp: Optional[torch.Tensor] = None,
|
| 1289 |
+
a1_scale: Optional[torch.Tensor] = None,
|
| 1290 |
+
a2_scale: Optional[torch.Tensor] = None,
|
| 1291 |
+
block_shape: Optional[List[int]] = None,
|
| 1292 |
+
) -> torch.Tensor:
|
| 1293 |
+
"""
|
| 1294 |
+
This function computes a Mixture of Experts (MoE) layer using two sets of
|
| 1295 |
+
weights, w1 and w2, and top-k gating mechanism.
|
| 1296 |
+
|
| 1297 |
+
Parameters:
|
| 1298 |
+
- hidden_states (torch.Tensor): The input tensor to the MoE layer.
|
| 1299 |
+
- w1 (torch.Tensor): The first set of expert weights.
|
| 1300 |
+
- w2 (torch.Tensor): The second set of expert weights.
|
| 1301 |
+
- gating_output (torch.Tensor): The output of the gating operation
|
| 1302 |
+
(before softmax).
|
| 1303 |
+
- topk (int): The number of top-k experts to select.
|
| 1304 |
+
- renormalize (bool): If True, renormalize the top-k weights to sum to 1.
|
| 1305 |
+
- inplace (bool): If True, perform the operation in-place.
|
| 1306 |
+
Defaults to False.
|
| 1307 |
+
- num_expert_group: Optional[int]: additional parameter for grouped_topk
|
| 1308 |
+
- topk_group: Optional[int]: additional parameter for grouped_topk
|
| 1309 |
+
- use_grouped_topk: If True, use grouped_topk instead of fused_topk
|
| 1310 |
+
note: Deepseekv2 model uses grouped_topk
|
| 1311 |
+
- use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner
|
| 1312 |
+
products for w1 and w2. Defaults to False.
|
| 1313 |
+
- use_int8_w8a16 (bool): If True, use matmul of int8 weight and bf16/fp16
|
| 1314 |
+
activation to compute the inner products for w1 and w2.
|
| 1315 |
+
Defaults to False.
|
| 1316 |
+
- use_int4_w4a16 (bool): If True, use matmul of int4 weight and bf16/fp16
|
| 1317 |
+
activation to compute the inner products for w1 and w2.
|
| 1318 |
+
Defaults to False.
|
| 1319 |
+
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for
|
| 1320 |
+
w1.
|
| 1321 |
+
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for
|
| 1322 |
+
w2.
|
| 1323 |
+
- a1_scale (Optional[torch.Tensor]): Optional scale to be used for
|
| 1324 |
+
a1.
|
| 1325 |
+
- a2_scale (Optional[torch.Tensor]): Optional scale to be used for
|
| 1326 |
+
a2.
|
| 1327 |
+
- block_shape: (Optional[List[int]]): Optional block size for block-wise
|
| 1328 |
+
quantization.
|
| 1329 |
+
|
| 1330 |
+
Returns:
|
| 1331 |
+
- torch.Tensor: The output tensor after applying the MoE layer.
|
| 1332 |
+
"""
|
| 1333 |
+
# Check constraints.
|
| 1334 |
+
assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
|
| 1335 |
+
|
| 1336 |
+
if use_grouped_topk:
|
| 1337 |
+
assert num_expert_group is not None and topk_group is not None
|
| 1338 |
+
topk_weights, topk_ids = grouped_topk(hidden_states, gating_output,
|
| 1339 |
+
topk, renormalize,
|
| 1340 |
+
num_expert_group, topk_group)
|
| 1341 |
+
elif custom_routing_function is None:
|
| 1342 |
+
topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
|
| 1343 |
+
renormalize)
|
| 1344 |
+
else:
|
| 1345 |
+
topk_weights, topk_ids = custom_routing_function(
|
| 1346 |
+
hidden_states, gating_output, topk, renormalize)
|
| 1347 |
+
|
| 1348 |
+
return fused_experts(hidden_states,
|
| 1349 |
+
w1,
|
| 1350 |
+
w2,
|
| 1351 |
+
topk_weights,
|
| 1352 |
+
topk_ids,
|
| 1353 |
+
inplace=inplace,
|
| 1354 |
+
use_fp8_w8a8=use_fp8_w8a8,
|
| 1355 |
+
use_int8_w8a16=use_int8_w8a16,
|
| 1356 |
+
use_int4_w4a16=use_int4_w4a16,
|
| 1357 |
+
w1_scale=w1_scale,
|
| 1358 |
+
w2_scale=w2_scale,
|
| 1359 |
+
w1_zp=w1_zp,
|
| 1360 |
+
w2_zp=w2_zp,
|
| 1361 |
+
a1_scale=a1_scale,
|
| 1362 |
+
a2_scale=a2_scale,
|
| 1363 |
+
block_shape=block_shape)
|
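For orientation, here is a minimal plain-PyTorch sketch of the routing half of the file above (softmax over the gating logits, top-k selection, optional renormalization). It only mirrors the tensor shapes described in the fused_moe docstring; the example sizes and variable names are assumptions for illustration, and the kernel-backed fused_experts path (which needs a CUDA build of vLLM) is deliberately not reproduced.

import torch

# Example sizes (assumed for illustration only).
num_tokens, hidden_size, num_experts, topk = 4, 8, 4, 2
hidden_states = torch.randn(num_tokens, hidden_size)
gating_output = torch.randn(num_tokens, num_experts)

# Softmax the raw gating logits, keep the top-k experts per token.
scores = torch.softmax(gating_output, dim=-1, dtype=torch.float32)
topk_weights, topk_ids = scores.topk(topk, dim=-1)

# renormalize=True makes the selected weights sum to 1 per token,
# matching the behavior documented for fused_moe above.
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

print(topk_ids.shape, topk_weights.shape)  # both (num_tokens, topk)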
.venv/lib/python3.11/site-packages/vllm/model_executor/layers/fused_moe/layer.py
ADDED
@@ -0,0 +1,647 @@
# SPDX-License-Identifier: Apache-2.0

from abc import abstractmethod
from enum import Enum
from typing import Callable, List, Optional, Tuple

import torch

from vllm.distributed import (get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
                              tensor_model_parallel_all_reduce)
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum

if current_platform.is_cuda_alike():
    from .fused_moe import fused_experts
else:
    fused_experts = None  # type: ignore
if current_platform.is_tpu():
    # the iterative moe implementation is used until the moe_pallas is fixed
    from .moe_torch_iterative import fused_moe as fused_moe_pallas
else:
    fused_moe_pallas = None  # type: ignore
logger = init_logger(__name__)


class FusedMoeWeightScaleSupported(Enum):
    TENSOR = "tensor"
    CHANNEL = "channel"
    GROUP = "group"
    BLOCK = "block"


class FusedMoEMethodBase(QuantizeMethodBase):

    @abstractmethod
    def create_weights(self, layer: torch.nn.Module, num_experts: int,
                       hidden_size: int, intermediate_size_per_partition: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):
        raise NotImplementedError

    @abstractmethod
    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        raise NotImplementedError


@CustomOp.register("unquantized_fused_moe")
class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
    """MoE method without quantization."""

    def create_weights(self, layer: torch.nn.Module, num_experts: int,
                       hidden_size: int, intermediate_size_per_partition: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):
        # Fused gate_up_proj (column parallel)
        w13_weight = torch.nn.Parameter(torch.empty(
            num_experts,
            2 * intermediate_size_per_partition,
            hidden_size,
            dtype=params_dtype),
                                        requires_grad=False)
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        # down_proj (row parallel)
        w2_weight = torch.nn.Parameter(torch.empty(
            num_experts,
            hidden_size,
            intermediate_size_per_partition,
            dtype=params_dtype),
                                       requires_grad=False)
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        super().process_weights_after_loading(layer)

        if current_platform.is_cpu():
            if current_platform.get_cpu_architecture() == CpuArchEnum.X86:
                import intel_extension_for_pytorch as ipex
                layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE(
                    layer.w13_weight,
                    layer.w2_weight,
                    use_prepack=True,
                )
            else:
                raise NotImplementedError("CPU MOE only supports x86 arch.")

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        return self.forward(x=x,
                            layer=layer,
                            router_logits=router_logits,
                            top_k=top_k,
                            renormalize=renormalize,
                            use_grouped_topk=use_grouped_topk,
                            topk_group=topk_group,
                            num_expert_group=num_expert_group,
                            custom_routing_function=custom_routing_function,
                            scoring_func=scoring_func,
                            e_score_correction_bias=e_score_correction_bias)

    def forward_cuda(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        use_grouped_topk: bool,
        top_k: int,
        router_logits: torch.Tensor,
        renormalize: bool,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias)

        return fused_experts(hidden_states=x,
                             w1=layer.w13_weight,
                             w2=layer.w2_weight,
                             topk_weights=topk_weights,
                             topk_ids=topk_ids,
                             inplace=True)

    def forward_cpu(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        use_grouped_topk: bool,
        top_k: int,
        router_logits: torch.Tensor,
        renormalize: bool,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None,
        **kwargs,
    ):
        assert custom_routing_function is None
        return layer.ipex_fusion(
            x,
            use_grouped_topk,
            top_k,
            router_logits,
            renormalize,
            topk_group,
            num_expert_group,
        )

    def forward_tpu(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        use_grouped_topk: bool,
        top_k: int,
        router_logits: torch.Tensor,
        renormalize: bool,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        assert not use_grouped_topk
        assert num_expert_group is None
        assert topk_group is None
        assert custom_routing_function is None
        if scoring_func != "softmax":
            raise NotImplementedError(
                "Only softmax scoring function is supported for TPU.")
        if e_score_correction_bias is not None:
            raise NotImplementedError(
                "Expert score correction bias is not supported for TPU.")
        return fused_moe_pallas(hidden_states=x,
                                w1=layer.w13_weight,
                                w2=layer.w2_weight,
                                topk=top_k,
                                gating_output=router_logits,
                                renormalize=renormalize)

    forward_native = forward_cuda


class FusedMoE(torch.nn.Module):
    """FusedMoE layer for MoE models.

    This layer contains both MergedColumnParallel weights (gate_up_proj /
    w13) and RowParallelLinear weights (down_proj / w2).

    Note: Mixtral uses w1, w2, and w3 for gate, up, and down_proj. We
    copy that naming convention here and handle any remapping in the
    load_weights function in each model implementation.

    Args:
        num_experts: Number of experts in the model
        top_k: Number of experts selected for each token
        hidden_size: Input hidden state size of the transformer
        intermediate_size: Intermediate size of the experts
        params_dtype: Data type for the parameters.
        reduce_results: Whether to all_reduce the output of the layer
        renormalize: Whether to renormalize the logits in the fused_moe kernel
        quant_config: Quantization configuration.
    """

    def __init__(
        self,
        num_experts: int,
        top_k: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: Optional[torch.dtype] = None,
        reduce_results: bool = False,
        renormalize: bool = True,
        use_grouped_topk: bool = False,
        num_expert_group: Optional[int] = None,
        topk_group: Optional[int] = None,
        quant_config: Optional[QuantizationConfig] = None,
        tp_size: Optional[int] = None,
        prefix: str = "",
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
    ):
        super().__init__()

        if params_dtype is None:
            params_dtype = torch.get_default_dtype()

        self.tp_size = (tp_size if tp_size is not None else
                        get_tensor_model_parallel_world_size())
        self.top_k = top_k
        self.num_experts = num_experts
        assert intermediate_size % self.tp_size == 0
        self.intermediate_size_per_partition = intermediate_size // self.tp_size
        self.reduce_results = reduce_results
        self.renormalize = renormalize
        self.use_grouped_topk = use_grouped_topk
        if self.use_grouped_topk:
            assert num_expert_group is not None and topk_group is not None
        self.num_expert_group = num_expert_group
        self.topk_group = topk_group
        self.custom_routing_function = custom_routing_function
        self.scoring_func = scoring_func
        self.e_score_correction_bias = e_score_correction_bias

        if self.scoring_func != "softmax" and not self.use_grouped_topk:
            raise ValueError("Only softmax scoring function is supported for "
                             "non-grouped topk.")

        if quant_config is None:
            self.quant_method: Optional[QuantizeMethodBase] = (
                UnquantizedFusedMoEMethod())
        else:
            self.quant_method = quant_config.get_quant_method(self, prefix)
        assert self.quant_method is not None

        moe_quant_params = {
            "num_experts": num_experts,
            "hidden_size": hidden_size,
            "intermediate_size_per_partition":
            self.intermediate_size_per_partition,
            "params_dtype": params_dtype,
            "weight_loader": self.weight_loader,
        }
        # need full intermediate size pre-sharding for WNA16 act order
        if (self.quant_method.__class__.__name__ ==
                "CompressedTensorsWNA16MoEMethod"):
            moe_quant_params["intermediate_size_full"] = intermediate_size

        self.quant_method.create_weights(layer=self, **moe_quant_params)

    def _load_per_tensor_weight_scale(self, shard_id: str,
                                      param: torch.nn.Parameter,
                                      loaded_weight: torch.Tensor,
                                      expert_id: int):
        param_data = param.data
        # for per tensor weight quantization
        if shard_id in ("w1", "w3"):
            # We have to keep the weight scales of w1 and w3 because
            # we need to re-quantize w1/w3 weights after weight loading.
            idx = 0 if shard_id == "w1" else 1
            param_data[expert_id][idx] = loaded_weight
        # If we are in the row parallel case (down_proj)
        elif shard_id == "w2":
            param_data[expert_id] = loaded_weight

    def _load_model_weight_or_group_weight_scale(self,
                                                 shard_dim: int,
                                                 expert_data: torch.Tensor,
                                                 shard_id: str,
                                                 loaded_weight: torch.Tensor,
                                                 tp_rank: int,
                                                 load_full_w2: bool = False):
        """
        Load grouped weight scales for group quantization or model weights.
            :param shard_dim: dimension to shard
            :param expert_data: parameter for a particular expert
            :param shard_id: either w1, w2, or w3
            :param loaded_weight: checkpoint weight to load into the param
            :param tp_rank: tensor parallel rank
            :param load_full_w2: whether or not the w2 loaded should be sharded
        """
        if shard_id == "w2":
            # In the case where we have actorder/g_idx, we do not partition the
            # w2 scales, as indicated by `load_full` argument, for all tp cases
            self._load_w2(shard_dim=shard_dim,
                          loaded_weight=loaded_weight,
                          expert_data=expert_data,
                          tp_rank=tp_rank,
                          load_full=load_full_w2)
        elif shard_id in ("w1", "w3"):
            self._load_w13(shard_id=shard_id,
                           shard_dim=shard_dim,
                           loaded_weight=loaded_weight,
                           expert_data=expert_data,
                           tp_rank=tp_rank)

    def _load_per_channel_weight_scale(self, expert_data: torch.Tensor,
                                       shard_dim: int, shard_id: str,
                                       loaded_weight: torch.Tensor,
                                       tp_rank: int):
        # for per channel weight quantization
        if shard_id == "w2":
            expert_data.copy_(loaded_weight)
        elif shard_id in ("w1", "w3"):
            self._load_w13(shard_id=shard_id,
                           shard_dim=shard_dim,
                           loaded_weight=loaded_weight,
                           expert_data=expert_data,
                           tp_rank=tp_rank)

    def _load_w13(self, expert_data: torch.Tensor, shard_dim: int,
                  shard_id: str, loaded_weight: torch.Tensor, tp_rank: int):

        # Index the loaded weight for tp sharding.
        # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
        shard_size = expert_data.shape[shard_dim] // 2
        loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank,
                                             shard_size)
        # Narrow parameter and load.
        # w1, gate_proj: Load into first logical weight of w13.
        if shard_id == "w1":
            expert_data = expert_data.narrow(shard_dim, 0, shard_size)
        # w3, up_proj: Load into second logical weight of w13.
        else:
            assert shard_id == "w3"
            expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)
        expert_data.copy_(loaded_weight)

    def _load_w2(self,
                 expert_data: torch.Tensor,
                 shard_dim: int,
                 loaded_weight: torch.Tensor,
                 tp_rank: int,
                 load_full: bool = False):

        # Index the loaded weight for tp sharding.
        # down_proj: "RowParallel" so tp sharding on input_dim
        # Narrow parameter and load.
        shard_size = expert_data.shape[shard_dim]
        if not load_full:
            loaded_weight = loaded_weight.narrow(shard_dim,
                                                 shard_size * tp_rank,
                                                 shard_size)
        # w2, down_proj: Load into only logical weight of w2.
        expert_data.copy_(loaded_weight)

    def _load_single_value(self, param: torch.nn.Parameter,
                           loaded_weight: torch.Tensor, expert_id: int):
        param_data = param.data

        # Input scales can be loaded directly and should be equal.
        param_data[expert_id] = loaded_weight

    def _load_g_idx(self, shard_id: str, expert_data: torch.Tensor,
                    shard_dim: int, loaded_weight: torch.Tensor, tp_rank: int):

        if shard_id == "w2":
            self._load_w2(shard_dim=shard_dim,
                          loaded_weight=loaded_weight,
                          expert_data=expert_data,
                          tp_rank=tp_rank)
        else:
            assert shard_id in ("w1", "w3")
            expert_data.copy_(loaded_weight)

    def weight_loader(self, param: torch.nn.Parameter,
                      loaded_weight: torch.Tensor, weight_name: str,
                      shard_id: str, expert_id: int) -> None:

        # compressed-tensors checkpoints with packed weights are stored flipped
        # TODO (mgoin): check self.quant_method.quant_config.quant_format
        # against known CompressionFormat enum values that have this quality
        loaded_weight = loaded_weight.t().contiguous() if (
            self.quant_method.__class__.__name__
            == "CompressedTensorsWNA16MoEMethod") else loaded_weight

        if shard_id not in ("w1", "w2", "w3"):
            raise ValueError(f"shard_id must be ['w1','w2','w3'] but "
                             f"got {shard_id}.")

        WEIGHT_SCALE_SUPPORTED = [
            e.value for e in FusedMoeWeightScaleSupported
        ]
        # Fetch the dim to shard the parameter/loaded weight
        # based on the shard id. This will be whatever
        # dimension intermediate_size_per_partition is used.
        SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0}

        expert_data = param.data[expert_id]
        tp_rank = get_tensor_model_parallel_rank()

        # is_transposed: if the dim to shard the weight
        # should be flipped. Required by GPTQ, compressed-tensors
        # should be whatever dimension intermediate_size_per_partition is
        is_transposed = getattr(param, "is_transposed", False)
        shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id]
        if is_transposed:
            shard_dim = int(not shard_dim)

        # Case input scale: input_scale loading is only supported for fp8
        if "input_scale" in weight_name:
            # this is needed for compressed-tensors only
            loaded_weight = loaded_weight.to(param.data.device)

            if param.data[expert_id] != 1 and (param.data[expert_id] -
                                               loaded_weight).abs() > 1e-5:
                raise ValueError(
                    "input_scales of w1 and w3 of a layer "
                    f"must be equal. But got {param.data[expert_id]} "
                    f"vs. {loaded_weight}")

            self._load_single_value(param=param,
                                    loaded_weight=loaded_weight,
                                    expert_id=expert_id)
            return

        # Case g_idx
        if "g_idx" in weight_name:
            self._load_g_idx(shard_dim=0,
                             shard_id=shard_id,
                             loaded_weight=loaded_weight,
                             expert_data=expert_data,
                             tp_rank=tp_rank)
            return

        # Case weight scales and zero_points
        if ("scale" in weight_name or "zero" in weight_name):
            # load the weight scales and zp based on the quantization scheme
            # supported weight scales/zp can be found in
            # FusedMoeWeightScaleSupported
            # TODO @dsikka: once hardened, refactor to use vLLM Parameters
            # specific to each case
            quant_method = getattr(param, "quant_method", None)
            if quant_method == FusedMoeWeightScaleSupported.CHANNEL.value:
                self._load_per_channel_weight_scale(
                    shard_id=shard_id,
                    shard_dim=shard_dim,
                    loaded_weight=loaded_weight,
                    expert_data=expert_data,
                    tp_rank=tp_rank)
            elif quant_method in [
                    FusedMoeWeightScaleSupported.GROUP.value,
                    FusedMoeWeightScaleSupported.BLOCK.value,
            ]:
                self._load_model_weight_or_group_weight_scale(
                    shard_id=shard_id,
                    shard_dim=shard_dim,
                    loaded_weight=loaded_weight,
                    expert_data=expert_data,
                    tp_rank=tp_rank,
                    load_full_w2=getattr(param, "load_full_w2", False))
            elif quant_method == FusedMoeWeightScaleSupported.TENSOR.value:
                self._load_per_tensor_weight_scale(shard_id=shard_id,
                                                   param=param,
                                                   loaded_weight=loaded_weight,
                                                   expert_id=expert_id)
            else:
                raise ValueError(
                    f"quant method must be one of {WEIGHT_SCALE_SUPPORTED}")
            return

        # Case weight_shape
        if "weight_shape" in weight_name:
            # only required by compressed-tensors
            self._load_single_value(param=param,
                                    loaded_weight=loaded_weight,
                                    expert_id=expert_id)
            return

        # Case model weights
        if "weight" in weight_name:
            self._load_model_weight_or_group_weight_scale(
                shard_id=shard_id,
                shard_dim=shard_dim,
                loaded_weight=loaded_weight,
                expert_data=expert_data,
                tp_rank=tp_rank)
            return

    @staticmethod
    def select_experts(hidden_states: torch.Tensor,
                       router_logits: torch.Tensor,
                       top_k: int,
                       use_grouped_topk: bool,
                       renormalize: bool,
                       topk_group: Optional[int] = None,
                       num_expert_group: Optional[int] = None,
                       custom_routing_function: Optional[Callable] = None,
                       scoring_func: str = "softmax",
                       e_score_correction_bias: Optional[torch.Tensor] = None):
        from vllm.model_executor.layers.fused_moe.fused_moe import (
            fused_topk, grouped_topk)

        # DeepSeekV2 uses grouped_top_k
        if use_grouped_topk:
            assert topk_group is not None
            assert num_expert_group is not None
            topk_weights, topk_ids = grouped_topk(
                hidden_states=hidden_states,
                gating_output=router_logits,
                topk=top_k,
                renormalize=renormalize,
                num_expert_group=num_expert_group,
                topk_group=topk_group,
                scoring_func=scoring_func,
                e_score_correction_bias=e_score_correction_bias)
        elif custom_routing_function is None:
            topk_weights, topk_ids = fused_topk(hidden_states=hidden_states,
                                                gating_output=router_logits,
                                                topk=top_k,
                                                renormalize=renormalize)
        else:
            topk_weights, topk_ids = custom_routing_function(
                hidden_states=hidden_states,
                gating_output=router_logits,
                topk=top_k,
                renormalize=renormalize)

        return topk_weights, topk_ids

    def forward(self, hidden_states: torch.Tensor,
                router_logits: torch.Tensor):
        assert self.quant_method is not None

        # Matrix multiply.
        final_hidden_states = self.quant_method.apply(
            layer=self,
            x=hidden_states,
            router_logits=router_logits,
            top_k=self.top_k,
            renormalize=self.renormalize,
            use_grouped_topk=self.use_grouped_topk,
            topk_group=self.topk_group,
            num_expert_group=self.num_expert_group,
            custom_routing_function=self.custom_routing_function,
            scoring_func=self.scoring_func,
            e_score_correction_bias=self.e_score_correction_bias)

        if self.reduce_results and self.tp_size > 1:
            final_hidden_states = tensor_model_parallel_all_reduce(
                final_hidden_states)

        return final_hidden_states

    @classmethod
    def make_expert_params_mapping(
            cls, ckpt_gate_proj_name: str, ckpt_down_proj_name: str,
            ckpt_up_proj_name: str,
            num_experts: int) -> List[Tuple[str, str, int, str]]:

        return [
            # (param_name, weight_name, expert_id, shard_id)
            ("experts.w13_" if weight_name
             in [ckpt_gate_proj_name, ckpt_up_proj_name] else "experts.w2_",
             f"experts.{expert_id}.{weight_name}.", expert_id, shard_id)
            for expert_id in range(num_experts) for shard_id, weight_name in [
                ("w1", ckpt_gate_proj_name),
                ("w2", ckpt_down_proj_name),
                ("w3", ckpt_up_proj_name),
            ]
        ]

    def _load_fp8_scale(self, param: torch.nn.Parameter,
                        loaded_weight: torch.Tensor, weight_name: str,
                        shard_id: str, expert_id: int) -> None:
        param_data = param.data

        # Input scales can be loaded directly and should be equal.
        if "input_scale" in weight_name:
            if param_data[expert_id] != 1 and (param_data[expert_id] -
                                               loaded_weight).abs() > 1e-5:
                raise ValueError(
                    "input_scales of w1 and w3 of a layer "
                    f"must be equal. But got {param_data[expert_id]} "
                    f"vs. {loaded_weight}")
            param_data[expert_id] = loaded_weight
        # Weight scales
        elif "weight_scale" in weight_name:
            # If we are in merged column case (gate_up_proj)
            if shard_id in ("w1", "w3"):
                # We have to keep the weight scales of w1 and w3 because
                # we need to re-quantize w1/w3 weights after weight loading.
                idx = 0 if shard_id == "w1" else 1
                param_data[expert_id][idx] = loaded_weight
            # If we are in the row parallel case (down_proj)
            else:
                param_data[expert_id] = loaded_weight
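A small usage sketch of the checkpoint-name mapping that model load_weights implementations rely on (assumes a vLLM install; the checkpoint projection names passed in are arbitrary example values). It shows how make_expert_params_mapping pairs each per-expert checkpoint tensor with the fused w13/w2 parameter and a shard id.

from vllm.model_executor.layers.fused_moe.layer import FusedMoE

# Example checkpoint projection names (assumed for illustration).
mapping = FusedMoE.make_expert_params_mapping(
    ckpt_gate_proj_name="gate_proj",
    ckpt_down_proj_name="down_proj",
    ckpt_up_proj_name="up_proj",
    num_experts=2)

for param_prefix, weight_name, expert_id, shard_id in mapping:
    # e.g. ("experts.w13_", "experts.0.gate_proj.", 0, "w1")
    print(param_prefix, weight_name, expert_id, shard_id)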
.venv/lib/python3.11/site-packages/vllm/model_executor/layers/fused_moe/moe_pallas.py
ADDED
@@ -0,0 +1,64 @@
# SPDX-License-Identifier: Apache-2.0

import torch
import torch.nn.functional as F
from torch_xla.experimental.custom_kernel import _histogram


def fused_moe(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
) -> torch.Tensor:
    """
    Args:
        hidden_states: [*, hidden_size]
        w1: [num_experts, intermediate_size * 2, hidden_size]
        w2: [num_experts, hidden_size, intermediate_size]
        gating_output: [*, num_experts]
    """
    orig_shape = hidden_states.shape
    hidden_size = hidden_states.shape[-1]
    num_tokens = hidden_states.shape[:-1].numel()
    num_experts = w1.shape[0]
    intermediate_size = w2.shape[-1]
    device = hidden_states.device
    dtype = hidden_states.dtype
    assert (num_tokens * topk) % 16 == 0, (
        "The Pallas GMM kernel requires num_tokens * topk to be a multiple of "
        f"16 but got {num_tokens * topk}")

    hidden_states = hidden_states.view(num_tokens, hidden_size)
    gating_output = gating_output.view(num_tokens, num_experts)
    topk_weights = gating_output.softmax(dim=-1, dtype=torch.float)
    topk_weights, topk_indices = topk_weights.topk(topk, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    topk_weights = topk_weights.to(dtype)

    topk_indices = topk_indices.flatten()
    topk_argsort_indices = topk_indices.argsort()
    topk_argsort_revert_indices = topk_argsort_indices.argsort()
    token_indices = torch.arange(num_tokens,
                                 device=device).repeat_interleave(topk)
    token_indices = token_indices[topk_argsort_indices]
    group_sizes = _histogram(topk_indices.to(torch.int32), 0, num_experts - 1)

    # NOTE(woosuk): The GMM Pallas kernel requires a different weight layout
    # from HF Transformers.
    w1 = w1.transpose(1, 2)
    w2 = w2.transpose(1, 2)

    x = hidden_states[token_indices]
    x = torch.ops.xla.gmm(x, w1, group_sizes)
    x = F.silu(x[..., :intermediate_size]) * x[..., intermediate_size:]
    x = torch.ops.xla.gmm(x, w2, group_sizes)
    x = x[topk_argsort_revert_indices].reshape(-1, topk, hidden_size)

    x = x * topk_weights.unsqueeze_(dim=-1)
    x = x.sum(dim=-2)
    x = x.reshape(orig_shape)
    return x
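The Pallas path above groups all (token, expert) pairs by expert before the grouped matmuls. The following plain-PyTorch sketch mirrors just that regrouping step so it can be inspected without torch_xla; torch.bincount is used here as a stand-in for the _histogram kernel, and the example sizes are arbitrary assumptions.

import torch

num_tokens, num_experts, topk = 8, 4, 2
topk_indices = torch.randint(0, num_experts, (num_tokens, topk)).flatten()

# Sort the flattened expert assignments so tokens routed to the same
# expert become contiguous, and remember how to undo the sort.
topk_argsort_indices = topk_indices.argsort()
topk_argsort_revert_indices = topk_argsort_indices.argsort()

# Each token appears once per selected expert, reordered to match the sort.
token_indices = torch.arange(num_tokens).repeat_interleave(topk)
token_indices = token_indices[topk_argsort_indices]

# Per-expert group sizes (plays the role of _histogram in the file above).
group_sizes = torch.bincount(topk_indices, minlength=num_experts)
assert group_sizes.sum().item() == num_tokens * topk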
.venv/lib/python3.11/site-packages/vllm/model_executor/layers/fused_moe/moe_torch_iterative.py
ADDED
@@ -0,0 +1,53 @@
# SPDX-License-Identifier: Apache-2.0

import torch
import torch.nn.functional as F


def fused_moe(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
) -> torch.Tensor:
    """
    Args:
        hidden_states: [*, hidden_size]
        w1: [num_experts, intermediate_size * 2, hidden_size]
        w2: [num_experts, hidden_size, intermediate_size]
        gating_output: [*, num_experts]
    """
    orig_shape = hidden_states.shape
    hidden_size = hidden_states.shape[-1]
    num_tokens = hidden_states.shape[:-1].numel()
    num_experts = w1.shape[0]
    intermediate_size = w2.shape[-1]
    dtype = hidden_states.dtype

    hidden_states = hidden_states.view(num_tokens, hidden_size)
    gating_output = gating_output.view(num_tokens, num_experts)
    topk_weights = gating_output.softmax(dim=-1, dtype=torch.float)
    topk_weights, selected_experts = topk_weights.topk(topk, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    topk_weights = topk_weights.to(dtype)

    final_hidden_states = None
    for expert_idx in range(num_experts):
        expert_w1 = w1[expert_idx]
        expert_w2 = w2[expert_idx]
        expert_mask = (selected_experts == expert_idx)
        expert_weights = (topk_weights * expert_mask).sum(dim=-1, keepdim=True)
        x = F.linear(hidden_states, expert_w1)
        gate = F.silu(x[:, :intermediate_size])
        x = x[:, intermediate_size:] * gate
        x = F.linear(x, expert_w2)
        current_hidden_states = x * expert_weights
        if final_hidden_states is None:
            final_hidden_states = current_hidden_states
        else:
            final_hidden_states = final_hidden_states + current_hidden_states

    return final_hidden_states.view(orig_shape)  # type: ignore
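A minimal usage sketch of this iterative reference path (assumes a vLLM install; the module itself only depends on torch). The tensor shapes follow the docstring above; the concrete sizes are arbitrary example values.

import torch
from vllm.model_executor.layers.fused_moe.moe_torch_iterative import fused_moe

num_tokens, hidden_size, intermediate_size, num_experts = 4, 16, 32, 4
hidden_states = torch.randn(num_tokens, hidden_size)
w1 = torch.randn(num_experts, 2 * intermediate_size, hidden_size)
w2 = torch.randn(num_experts, hidden_size, intermediate_size)
gating_output = torch.randn(num_tokens, num_experts)

out = fused_moe(hidden_states, w1, w2, gating_output,
                topk=2, renormalize=True)
print(out.shape)  # torch.Size([4, 16])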
.venv/lib/python3.11/site-packages/vllm/model_executor/layers/layernorm.py
ADDED
@@ -0,0 +1,213 @@
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
"""Custom normalization layers."""
|
| 3 |
+
from typing import Optional, Tuple, Union
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
|
| 8 |
+
from vllm.model_executor.custom_op import CustomOp
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@CustomOp.register("rms_norm")
|
| 12 |
+
class RMSNorm(CustomOp):
|
| 13 |
+
"""Root mean square normalization.
|
| 14 |
+
|
| 15 |
+
Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight.
|
| 16 |
+
Refer to https://arxiv.org/abs/1910.07467
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
def __init__(
|
| 20 |
+
self,
|
| 21 |
+
hidden_size: int,
|
| 22 |
+
eps: float = 1e-6,
|
| 23 |
+
var_hidden_size: Optional[int] = None,
|
| 24 |
+
has_weight: bool = True,
|
| 25 |
+
) -> None:
|
| 26 |
+
super().__init__()
|
| 27 |
+
|
| 28 |
+
self.hidden_size = hidden_size
|
| 29 |
+
self.variance_epsilon = eps
|
| 30 |
+
self.variance_size_override = (None if var_hidden_size == hidden_size
|
| 31 |
+
else var_hidden_size)
|
| 32 |
+
self.has_weight = has_weight
|
| 33 |
+
|
| 34 |
+
self.weight = torch.ones(hidden_size)
|
| 35 |
+
if self.has_weight:
|
| 36 |
+
self.weight = nn.Parameter(self.weight)
|
| 37 |
+
|
| 38 |
+
def forward_native(
|
| 39 |
+
self,
|
| 40 |
+
x: torch.Tensor,
|
| 41 |
+
residual: Optional[torch.Tensor] = None,
|
| 42 |
+
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
| 43 |
+
"""PyTorch-native implementation equivalent to forward()."""
|
| 44 |
+
orig_dtype = x.dtype
|
| 45 |
+
x = x.to(torch.float32)
|
| 46 |
+
if residual is not None:
|
| 47 |
+
x = x + residual.to(torch.float32)
|
| 48 |
+
residual = x.to(orig_dtype)
|
| 49 |
+
|
| 50 |
+
hidden_size = x.shape[-1]
|
| 51 |
+
if hidden_size != self.hidden_size:
|
| 52 |
+
raise ValueError("Expected hidden_size to be "
|
| 53 |
+
f"{self.hidden_size}, but found: {hidden_size}")
|
| 54 |
+
|
| 55 |
+
if self.variance_size_override is None:
|
| 56 |
+
x_var = x
|
| 57 |
+
else:
|
| 58 |
+
if hidden_size < self.variance_size_override:
|
| 59 |
+
raise ValueError(
|
| 60 |
+
"Expected hidden_size to be at least "
|
| 61 |
+
f"{self.variance_size_override}, but found: {hidden_size}")
|
| 62 |
+
|
| 63 |
+
x_var = x[:, :, :self.variance_size_override]
|
| 64 |
+
|
| 65 |
+
variance = x_var.pow(2).mean(dim=-1, keepdim=True)
|
| 66 |
+
|
| 67 |
+
x = x * torch.rsqrt(variance + self.variance_epsilon)
|
| 68 |
+
x = x.to(orig_dtype)
|
| 69 |
+
if self.has_weight:
|
| 70 |
+
x = x * self.weight
|
| 71 |
+
if residual is None:
|
| 72 |
+
return x
|
| 73 |
+
else:
|
| 74 |
+
return x, residual
|
| 75 |
+
|
| 76 |
+
def forward_cuda(
|
| 77 |
+
self,
|
| 78 |
+
x: torch.Tensor,
|
| 79 |
+
residual: Optional[torch.Tensor] = None,
|
| 80 |
+
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
| 81 |
+
if self.variance_size_override is not None:
|
| 82 |
+
return self.forward_native(x, residual)
|
| 83 |
+
|
| 84 |
+
from vllm import _custom_ops as ops
|
| 85 |
+
|
| 86 |
+
if residual is not None:
|
| 87 |
+
ops.fused_add_rms_norm(
|
| 88 |
+
x,
|
| 89 |
+
residual,
|
| 90 |
+
self.weight.data,
|
| 91 |
+
self.variance_epsilon,
|
| 92 |
+
)
|
| 93 |
+
return x, residual
|
| 94 |
+
out = torch.empty_like(x)
|
| 95 |
+
ops.rms_norm(
|
| 96 |
+
out,
|
| 97 |
+
x,
|
| 98 |
+
self.weight.data,
|
| 99 |
+
self.variance_epsilon,
|
| 100 |
+
)
|
| 101 |
+
return out
|
| 102 |
+
|
| 103 |
+
def forward_hpu(
|
| 104 |
+
self,
|
| 105 |
+
x: torch.Tensor,
|
| 106 |
+
residual: Optional[torch.Tensor] = None,
|
| 107 |
+
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
| 108 |
+
from vllm_hpu_extension.ops import HPUFusedRMSNorm
|
| 109 |
+
if HPUFusedRMSNorm is None:
|
| 110 |
+
return self.forward_native(x, residual)
|
| 111 |
+
if residual is not None:
|
| 112 |
+
orig_shape = x.shape
|
| 113 |
+
residual += x.view(residual.shape)
|
| 114 |
+
# Note: HPUFusedRMSNorm requires 3D tensors as inputs
|
| 115 |
+
x = HPUFusedRMSNorm.apply(residual, self.weight,
|
| 116 |
+
self.variance_epsilon)
|
| 117 |
+
return x.view(orig_shape), residual
|
| 118 |
+
|
| 119 |
+
x = HPUFusedRMSNorm.apply(x, self.weight, self.variance_epsilon)
|
| 120 |
+
return x
|
| 121 |
+
|
| 122 |
+
def forward_xpu(
|
| 123 |
+
self,
|
| 124 |
+
x: torch.Tensor,
|
| 125 |
+
residual: Optional[torch.Tensor] = None,
|
| 126 |
+
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
| 127 |
+
if self.variance_size_override is not None:
|
| 128 |
+
return self.forward_native(x, residual)
|
| 129 |
+
|
| 130 |
+
from vllm._ipex_ops import ipex_ops as ops
|
| 131 |
+
|
| 132 |
+
if residual is not None:
|
| 133 |
+
ops.fused_add_rms_norm(
|
| 134 |
+
x,
|
| 135 |
+
residual,
|
| 136 |
+
self.weight.data,
|
| 137 |
+
self.variance_epsilon,
|
| 138 |
+
)
|
| 139 |
+
return x, residual
|
| 140 |
+
return ops.rms_norm(
|
| 141 |
+
x,
|
| 142 |
+
self.weight.data,
|
| 143 |
+
self.variance_epsilon,
|
| 144 |
+
)
|
| 145 |
+
|
| 146 |
+
def extra_repr(self) -> str:
|
| 147 |
+
s = f"hidden_size={self.weight.data.size(0)}"
|
| 148 |
+
s += f", eps={self.variance_epsilon}"
|
| 149 |
+
return s
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
@CustomOp.register("gemma_rms_norm")
|
| 153 |
+
class GemmaRMSNorm(CustomOp):
|
| 154 |
+
"""RMS normalization for Gemma.
|
| 155 |
+
|
| 156 |
+
Two differences from the above RMSNorm:
|
| 157 |
+
1. x * (1 + w) instead of x * w.
|
| 158 |
+
2. (x * w).to(orig_dtype) instead of x.to(orig_dtype) * w.
|
| 159 |
+
"""
|
| 160 |
+
|
| 161 |
+
def __init__(
|
| 162 |
+
self,
|
| 163 |
+
hidden_size: int,
|
| 164 |
+
eps: float = 1e-6,
|
| 165 |
+
) -> None:
|
| 166 |
+
super().__init__()
|
| 167 |
+
self.weight = nn.Parameter(torch.zeros(hidden_size))
|
| 168 |
+
self.variance_epsilon = eps
|
| 169 |
+
|
| 170 |
+
@staticmethod
|
| 171 |
+
def forward_static(
|
| 172 |
+
weight: torch.Tensor,
|
| 173 |
+
variance_epsilon: float,
|
| 174 |
+
x: torch.Tensor,
|
| 175 |
+
residual: Optional[torch.Tensor],
|
| 176 |
+
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
| 177 |
+
"""PyTorch-native implementation equivalent to forward()."""
|
| 178 |
+
orig_dtype = x.dtype
|
| 179 |
+
if residual is not None:
|
| 180 |
+
x = x + residual
|
| 181 |
+
residual = x
|
| 182 |
+
|
| 183 |
+
x = x.float()
|
| 184 |
+
variance = x.pow(2).mean(dim=-1, keepdim=True)
|
| 185 |
+
x = x * torch.rsqrt(variance + variance_epsilon)
|
| 186 |
+
# Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
|
| 187 |
+
# See https://github.com/huggingface/transformers/pull/29402
|
| 188 |
+
x = x * (1.0 + weight.float())
|
| 189 |
+
x = x.to(orig_dtype)
|
| 190 |
+
return x if residual is None else (x, residual)
|
| 191 |
+
|
| 192 |
+
def forward_native(
|
| 193 |
+
self,
|
| 194 |
+
x: torch.Tensor,
|
| 195 |
+
residual: Optional[torch.Tensor] = None,
|
| 196 |
+
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
| 197 |
+
"""PyTorch-native implementation equivalent to forward()."""
|
| 198 |
+
return self.forward_static(self.weight.data, self.variance_epsilon, x,
|
| 199 |
+
residual)
|
| 200 |
+
|
| 201 |
+
def forward_cuda(
|
| 202 |
+
self,
|
| 203 |
+
x: torch.Tensor,
|
| 204 |
+
residual: Optional[torch.Tensor] = None,
|
| 205 |
+
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
| 206 |
+
if torch.compiler.is_compiling():
|
| 207 |
+
return self.forward_native(x, residual)
|
| 208 |
+
|
| 209 |
+
if not getattr(self, "_is_compiled", False):
|
| 210 |
+
self.forward_static = torch.compile( # type: ignore
|
| 211 |
+
self.forward_static)
|
| 212 |
+
self._is_compiled = True
|
| 213 |
+
return self.forward_native(x, residual)
|
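A small sketch (editorial, not part of the file) of the two Gemma-specific differences listed in the docstring: the weight enters as (1 + w) and the cast back to the original dtype happens only after the multiply, so a zero-initialized weight leaves the normalized activations unchanged. Hidden size and dtype are arbitrary assumptions.

import torch

from vllm.model_executor.layers.layernorm import GemmaRMSNorm

norm = GemmaRMSNorm(hidden_size=8)         # weight starts at zeros
x = torch.randn(4, 8, dtype=torch.float16)
y = norm.forward_native(x)                 # x * (1 + 0), cast to float16 last
assert y.dtype == torch.float16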
.venv/lib/python3.11/site-packages/vllm/model_executor/layers/linear.py
ADDED
@@ -0,0 +1,1159 @@
# SPDX-License-Identifier: Apache-2.0

import itertools
from abc import abstractmethod
from typing import Optional

import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter, UninitializedParameter

from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
                              split_tensor_along_last_dim,
                              tensor_model_parallel_all_gather,
                              tensor_model_parallel_all_reduce)
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig, QuantizeMethodBase)
# yapf: disable
from vllm.model_executor.parameter import (BasevLLMParameter,
                                           BlockQuantScaleParameter,
                                           PackedColumnParameter,
                                           PackedvLLMParameter,
                                           PerTensorScaleParameter,
                                           RowvLLMParameter)
# yapf: enable
from vllm.model_executor.utils import set_weight_attrs

logger = init_logger(__name__)

WEIGHT_LOADER_V2_SUPPORTED = [
    "CompressedTensorsLinearMethod", "AWQMarlinLinearMethod",
    "AWQLinearMethod", "GPTQMarlinLinearMethod", "Fp8LinearMethod",
    "MarlinLinearMethod", "QQQLinearMethod", "GPTQMarlin24LinearMethod",
    "TPUInt8LinearMethod", "GPTQLinearMethod", "FBGEMMFp8LinearMethod",
    "ModelOptFp8LinearMethod", "IPEXAWQLinearMethod", "IPEXGPTQLinearMethod",
    "HQQMarlinMethod", "QuarkLinearMethod"
]


def adjust_marlin_shard(param, shard_size, shard_offset):
    marlin_tile_size = getattr(param, "marlin_tile_size", None)
    if marlin_tile_size is None:
        return shard_size, shard_offset

    return shard_size * marlin_tile_size, shard_offset * marlin_tile_size


def adjust_bitsandbytes_4bit_shard(param: Parameter,
                                   shard_offsets: dict[str, tuple[int, int]],
                                   loaded_shard_id: str) -> tuple[int, int]:
    """Adjust the quantization offsets and sizes for BitsAndBytes sharding."""

    total, _ = shard_offsets["total"]
    orig_offset, orig_size = shard_offsets[loaded_shard_id]

    quantized_total = param.data.shape[0]
    quantized_offset = orig_offset * quantized_total // total
    quantized_size = orig_size * quantized_total // total

    return quantized_size, quantized_offset

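A quick worked example (editorial) of the rescaling that adjust_bitsandbytes_4bit_shard performs: logical offsets measured against the unquantized "total" width are mapped proportionally onto the packed width actually stored in param.data. All numbers below are invented for illustration.

total = 12288                        # logical fused output width ("total")
quantized_total = 6144               # packed rows actually stored in param.data
orig_offset, orig_size = 4096, 4096  # logical slice for one shard

quantized_offset = orig_offset * quantized_total // total  # -> 2048
quantized_size = orig_size * quantized_total // total      # -> 2048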
def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
    """For fused modules (QKV and MLP) we have an array of length
    N that holds 1 scale for each "logical" matrix. So the param
    is an array of length N. The loaded_weight corresponds to
    one of the shards on disk. Here, we slice the param based on
    the shard_id for loading.
    """
    qkv_idxs = {"q": 0, "k": 1, "v": 2}

    if isinstance(shard_id, str):
        shard_id = qkv_idxs[shard_id]
    elif not isinstance(shard_id, int):
        raise ValueError(f"Unknown Shard Id {shard_id}")

    # AutoFP8 scales do not have a shape
    # compressed-tensors scales do have a shape
    if len(loaded_weight.shape) != 0:
        assert loaded_weight.shape[0] == 1
        loaded_weight = loaded_weight[0]

    return param[shard_id], loaded_weight

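A minimal sketch (editorial) of the slicing behaviour the docstring above describes, using toy stand-ins for real per-shard scales.

import torch

from vllm.model_executor.layers.linear import adjust_scalar_to_fused_array

param = torch.zeros(3)               # one scale slot per logical matrix (q, k, v)
loaded_weight = torch.tensor([0.5])  # scale for the "k" shard, shape (1,)

target, value = adjust_scalar_to_fused_array(param, loaded_weight, "k")
target.copy_(value)                  # param is now tensor([0.0, 0.5, 0.0])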
class LinearMethodBase(QuantizeMethodBase):
    """Base class for different (maybe quantized) linear methods."""

    @abstractmethod
    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: list[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Create weights for a linear layer.
        The weights will be set as attributes of the layer.

        Args:
            layer: The layer that is using the LinearMethodBase factory.
            input_size_per_partition: Size of the weight input dim on rank X.
            output_partition_sizes: Sizes of the output dim of each logical
                weight on rank X. E.g., output_partition_sizes for QKVLinear
                is a list containing the widths of Wq, Wk, Wv on rank X.
            input_size: Size of the input dim of the weight across all ranks.
            output_size: Size of the output dim of the weight across all ranks.
            params_dtype: Datatype of the parameters.
        """
        raise NotImplementedError

    @abstractmethod
    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Apply the weights in layer to the input tensor.
        Expects create_weights to have been called before on the layer."""
        raise NotImplementedError


class UnquantizedLinearMethod(LinearMethodBase):
    """Linear method without quantization."""

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: list[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        weight = Parameter(torch.empty(sum(output_partition_sizes),
                                       input_size_per_partition,
                                       dtype=params_dtype),
                           requires_grad=False)
        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, extra_weight_attrs)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:

        return F.linear(x, layer.weight, bias)

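A short sketch (editorial) of the create_weights/apply contract defined above, driving UnquantizedLinearMethod directly on a bare nn.Module; the sizes are arbitrary and no tensor parallelism is involved.

import torch

from vllm.model_executor.layers.linear import UnquantizedLinearMethod

layer = torch.nn.Module()
method = UnquantizedLinearMethod()
method.create_weights(layer,
                      input_size_per_partition=16,
                      output_partition_sizes=[32],
                      input_size=16,
                      output_size=32,
                      params_dtype=torch.float32)

x = torch.randn(4, 16)
y = method.apply(layer, x)           # equivalent to F.linear(x, layer.weight)
assert y.shape == (4, 32)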
class LinearBase(torch.nn.Module):
    """Base linear layer.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        bias: If true, add bias.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configuration.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()

        # Keep input parameters
        self.input_size = input_size
        self.output_size = output_size
        self.skip_bias_add = skip_bias_add
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype
        if quant_config is None:
            self.quant_method: Optional[
                QuantizeMethodBase] = UnquantizedLinearMethod()
        else:
            self.quant_method = quant_config.get_quant_method(self,
                                                              prefix=prefix)

    def forward(self,
                x: torch.Tensor) -> tuple[torch.Tensor, Optional[Parameter]]:
        raise NotImplementedError


class ReplicatedLinear(LinearBase):
    """Replicated linear layer.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        bias: If true, add bias.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configuration.
        prefix: The name of the layer in the state dict, including all parents
            (e.g. model.layers.0.qkv_proj)
    """

    def __init__(self,
                 input_size: int,
                 output_size: int,
                 bias: bool = True,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__(input_size,
                         output_size,
                         skip_bias_add,
                         params_dtype,
                         quant_config,
                         prefix=prefix)

        # All linear layers support a quant method.
        assert self.quant_method is not None
        self.quant_method.create_weights(self,
                                         self.input_size, [self.output_size],
                                         self.input_size,
                                         self.output_size,
                                         self.params_dtype,
                                         weight_loader=self.weight_loader)

        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=self.params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        # If the weight on disk does not have a shape, give it one
        # (such as scales for AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param.size() == loaded_weight.size()
        param.data.copy_(loaded_weight)

    def forward(self,
                x: torch.Tensor) -> tuple[torch.Tensor, Optional[Parameter]]:
        bias = self.bias if not self.skip_bias_add else None
        assert self.quant_method is not None
        output = self.quant_method.apply(self, x, bias)
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        return s

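A brief sketch (editorial) of the forward convention shared by these layers: forward returns an (output, output_bias) tuple, and the bias slot is only populated when skip_bias_add=True. ReplicatedLinear, as written above, does not touch the tensor-parallel groups, so this should run without distributed initialization; sizes are arbitrary.

import torch

from vllm.model_executor.layers.linear import ReplicatedLinear

layer = ReplicatedLinear(input_size=16, output_size=32, bias=True,
                         skip_bias_add=True)
x = torch.randn(4, 16)
out, deferred_bias = layer(x)        # bias is returned, not added
assert deferred_bias is layer.bias
out = out + deferred_bias            # the caller fuses the bias add itself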
class ColumnParallelLinear(LinearBase):
    """Linear layer with column parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its second dimension as A = [A_1, ..., A_p].

    Args:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make Y available
                       to all GPUs, otherwise, every GPU will have its output
                       which is Y_i = XA_i
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configuration.
        output_sizes: list of output sizes packed into one output, like for QKV
                      the list would be size 3.
        prefix: The name of the layer in the state dict, including all parents
            (e.g. model.layers.0.qkv_proj)
    """

    def __init__(self,
                 input_size: int,
                 output_size: int,
                 bias: bool = True,
                 gather_output: bool = False,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 output_sizes: Optional[list[int]] = None,
                 prefix: str = ""):
        super().__init__(input_size, output_size, skip_bias_add, params_dtype,
                         quant_config, prefix)

        self.gather_output = gather_output

        # Divide the weight matrix along the last dimension.
        tp_size = get_tensor_model_parallel_world_size()
        assert self.quant_method is not None
        self.output_size_per_partition = divide(self.output_size, tp_size)
        self.output_partition_sizes = [self.output_size_per_partition]
        # If QKV or MergedColumn, use output size of each partition.
        if hasattr(self, "output_sizes"):
            self.output_partition_sizes = [
                divide(output_size, tp_size)
                for output_size in self.output_sizes
            ]

        if output_sizes is None:
            output_sizes = [output_size]

        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=(
                self.weight_loader_v2 if self.quant_method.__class__.__name__
                in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size_per_partition,
                            dtype=params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        tp_rank = get_tensor_model_parallel_rank()
        output_dim = getattr(param, "output_dim", None)

        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)

        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
        is_sharded_weight = getattr(param, "is_sharded_weight", False)
        # bitsandbytes loads the weights of the specific portion
        # no need to narrow
        is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

        param_data = param.data
        if output_dim is not None and not is_sharded_weight:
            shard_size = param_data.shape[output_dim]
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                 shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def weight_loader_v2(self, param: Parameter, loaded_weight: torch.Tensor):
        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            assert loaded_weight.numel() == 1
            loaded_weight = loaded_weight.reshape(1)
        param.load_column_parallel_weight(loaded_weight=loaded_weight)

    def forward(self, input_) -> tuple[torch.Tensor, Optional[Parameter]]:
        bias = self.bias if not self.skip_bias_add else None

        # Matrix multiply.
        assert self.quant_method is not None
        output_parallel = self.quant_method.apply(self, input_, bias)
        if self.gather_output:
            # All-gather across the partitions.
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size_per_partition}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={get_tensor_model_parallel_world_size()}"
        s += f", gather_output={self.gather_output}"
        return s

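The column-parallel decomposition stated in the docstring (A split along its output dimension, per-rank results concatenated by the all-gather) can be checked with plain tensors; this editorial sketch stands in for the distributed runtime with a two-way chunk.

import torch

X = torch.randn(4, 16)
A = torch.randn(16, 32)
A_1, A_2 = A.chunk(2, dim=1)         # each "rank" holds one column block

Y_full = X @ A
Y_gathered = torch.cat([X @ A_1, X @ A_2], dim=1)  # what the all-gather yields
assert torch.allclose(Y_full, Y_gathered, atol=1e-5)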
class MergedColumnParallelLinear(ColumnParallelLinear):
    """Packed linear layers with column parallelism.

    Similar to ColumnParallelLinear, but the weight matrix is concatenated
    along the output dimension. When the weight matrix is loaded, the
    different partitions are sharded separately.

    Args:
        input_size: input dimension of the linear layer.
        output_sizes: list of output dimensions of the linear layer.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make the output
                       available to all GPUs, otherwise, every GPU will have
                       its own output.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configuration.
        prefix: The name of the layer in the state dict, including all parents
            (e.g. model.layers.0.qkv_proj)
    """

    def __init__(self,
                 input_size: int,
                 output_sizes: list[int],
                 bias: bool = True,
                 gather_output: bool = False,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        self.output_sizes = output_sizes
        tp_size = get_tensor_model_parallel_world_size()
        assert all(output_size % tp_size == 0 for output_size in output_sizes)
        super().__init__(input_size=input_size,
                         output_size=sum(output_sizes),
                         bias=bias,
                         gather_output=gather_output,
                         skip_bias_add=skip_bias_add,
                         params_dtype=params_dtype,
                         quant_config=quant_config,
                         prefix=prefix)

    def weight_loader(self,
                      param: Parameter,
                      loaded_weight: torch.Tensor,
                      loaded_shard_id: Optional[int] = None):

        # Special case for GGUF
        # initialize GGUF param after we know the quantize type
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            if loaded_shard_id is not None:
                param.data[loaded_shard_id].copy_(loaded_weight)
                param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
            else:
                param.shard_weight_type = {
                    i: loaded_weight.item()
                    for i, _ in enumerate(self.output_sizes)
                }
            return

        if is_gguf_weight:
            tp_size = get_tensor_model_parallel_world_size()
            tp_rank = get_tensor_model_parallel_rank()

            output_dim = getattr(param, "output_dim", None)
            shard_size = loaded_weight.size(output_dim) // tp_size
            start_idx = tp_rank * shard_size

            if loaded_shard_id is not None:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                     shard_size)
                param.shard_id.append(loaded_shard_id)
                param.shard_id_map[loaded_shard_id] = len(param.data_container)
                param.data_container.append(loaded_weight)
                if len(param.data_container) == 2:
                    self.qweight = param.materialize_nested()
                return

        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
        # Special case for AQLM codebooks.
        is_metadata = getattr(param, "is_metadata", False)
        # Special case for per-tensor scale to load scalar into fused array.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)

        if loaded_shard_id is None:
            # Loaded weight is already fused on disk (mlp).
            # (e.g., Phi-3's gate_up_proj).
            if output_dim is None:
                if needs_scalar_to_array:
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
                        param_data, loaded_weight, 0)

                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return
            current_shard_offset = 0
            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
                                            False)
            shard_offsets: list[tuple[int, int, int]] = []
            for i, output_size in enumerate(self.output_sizes):
                shard_offsets.append((i, current_shard_offset, output_size))
                current_shard_offset += output_size
            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in shard_offsets:
                # Special case for Quantization.
                # If quantized, we need to adjust the offset and size to
                # account for the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.pack_factor
                    shard_offset = shard_offset // param.pack_factor
                    # Special case for Marlin.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset)

                if use_bitsandbytes_4bit:
                    index = list(itertools.accumulate([0] + self.output_sizes))
                    orig_offsets = {
                        str(i): (index[i], size)
                        for i, size in enumerate(self.output_sizes)
                    }
                    orig_offsets["total"] = (self.output_size, 0)
                    shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                        param, orig_offsets, str(shard_id))

                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size)
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        assert loaded_shard_id < len(self.output_sizes)
        tp_rank = get_tensor_model_parallel_rank()
        tp_size = get_tensor_model_parallel_world_size()
        if output_dim is not None:
            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
            shard_size = self.output_sizes[loaded_shard_id] // tp_size
            # Special case for quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.pack_factor
                shard_offset = shard_offset // param.pack_factor
                # Special case for Marlin.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset)

            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
                                            False)
            is_sharded_weight = getattr(param, "is_sharded_weight", False)
            # bitsandbytes loads the weights of the specific portion
            # no need to narrow
            is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

            if use_bitsandbytes_4bit:
                shard_size = loaded_weight.shape[output_dim]
                shard_offset = loaded_weight.shape[output_dim] * \
                    loaded_shard_id

            param_data = param_data.narrow(output_dim, shard_offset,
                                           shard_size)
            start_idx = tp_rank * shard_size
            if not is_sharded_weight:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                     shard_size)
        # Special case for AQLM codebooks.
        elif is_metadata:
            # metadata indicates fixed size concatenated along dim 0
            shard_size = loaded_weight.shape[0]
            shard_offset = loaded_shard_id * shard_size
            param_data = param_data.narrow(0, shard_offset, shard_size)

        # Special case for per-tensor scales in fused case.
        elif needs_scalar_to_array:
            param_data, loaded_weight = adjust_scalar_to_fused_array(
                param_data, loaded_weight, loaded_shard_id)

        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "MergedColumnParallelLinear, assume the weight is "
                    "the same for all partitions.")

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def _load_fused_module_from_checkpoint(self, param: BasevLLMParameter,
                                           loaded_weight: torch.Tensor):
        """
        Handle special case for models where MLP layers are already
        fused on disk. In this case, we have no shard id. This function
        determines the shard id by splitting these layers and then calls
        the weight loader using the shard id.

        An example of a model with these fused layers:
        https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
        """

        current_shard_offset = 0
        shard_offsets: list[tuple[int, int, int]] = []
        for i, output_size in enumerate(self.output_sizes):
            shard_offsets.append((i, current_shard_offset, output_size))
            current_shard_offset += output_size

        for shard_id, shard_offset, shard_size in shard_offsets:
            # Special case for Quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            if isinstance(param, (PackedColumnParameter, PackedvLLMParameter
                                  )) and param.packed_dim == param.output_dim:
                shard_size, shard_offset = \
                    param.adjust_shard_indexes_for_packing(
                        shard_size=shard_size, shard_offset=shard_offset)

            loaded_weight_shard = loaded_weight.narrow(param.output_dim,
                                                       shard_offset,
                                                       shard_size)
            self.weight_loader_v2(param, loaded_weight_shard, shard_id)

    def weight_loader_v2(self,
                         param: BasevLLMParameter,
                         loaded_weight: torch.Tensor,
                         loaded_shard_id: Optional[int] = None):
        if loaded_shard_id is None:
            if isinstance(param, PerTensorScaleParameter):
                param.load_merged_column_weight(loaded_weight=loaded_weight,
                                                shard_id=0)
                return
            elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                param.load_merged_column_weight(loaded_weight=loaded_weight)
                return
            # TODO: @dsikka - move to parameter.py
            self._load_fused_module_from_checkpoint(param, loaded_weight)
            return

        assert loaded_shard_id < len(self.output_sizes)

        tp_size = get_tensor_model_parallel_world_size()

        if isinstance(param, BlockQuantScaleParameter):
            from vllm.model_executor.layers.quantization.fp8 import (
                Fp8LinearMethod, Fp8MoEMethod)
            assert self.quant_method is not None
            assert isinstance(self.quant_method,
                              (Fp8LinearMethod, Fp8MoEMethod))
            weight_block_size = self.quant_method.quant_config.weight_block_size
            assert weight_block_size is not None
            block_n, _ = weight_block_size[0], weight_block_size[1]
            shard_offset = (
                (sum(self.output_sizes[:loaded_shard_id]) + block_n - 1) //
                block_n) // tp_size
            shard_size = ((self.output_sizes[loaded_shard_id] + block_n - 1) //
                          block_n // tp_size)
        else:
            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
            shard_size = self.output_sizes[loaded_shard_id] // tp_size

        param.load_merged_column_weight(loaded_weight=loaded_weight,
                                        shard_id=loaded_shard_id,
                                        shard_offset=shard_offset,
                                        shard_size=shard_size)

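To make the per-shard arithmetic in the unquantized branch of weight_loader_v2 concrete, here is an editorial walk-through for a hypothetical gate_up_proj layout; the widths and tensor-parallel degree are invented for the example.

output_sizes = [11008, 11008]        # hypothetical gate_proj / up_proj widths
tp_size = 2

for loaded_shard_id in (0, 1):
    shard_offset = sum(output_sizes[:loaded_shard_id]) // tp_size
    shard_size = output_sizes[loaded_shard_id] // tp_size
    print(loaded_shard_id, shard_offset, shard_size)
# shard 0 -> offset 0,    size 5504
# shard 1 -> offset 5504, size 5504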
class QKVParallelLinear(ColumnParallelLinear):
    """Linear layers for the attention's QKV transformation.

    Linear layers for the linear transformation of the query, key, and value
    vectors in the attention layer. The weight matrix is concatenated along
    the output dimension. The layer is parallelized along the head dimension.
    When the number of key/value heads is smaller than the number of query
    heads (e.g., multi-query/grouped-query attention), the key/value head may
    be replicated while the query heads are partitioned.

    Args:
        hidden_size: input hidden state size of the transformer.
        head_size: size of each attention head.
        total_num_heads: total number of attention query heads.
        total_num_kv_heads: total number of attention key/value heads. If
                            None, assume total_num_kv_heads = total_num_heads.
        bias: If true, add bias.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configuration.
        prefix: The name of the layer in the state dict, including all parents
            (e.g. model.layers.0.qkv_proj)
    """

    def __init__(self,
                 hidden_size: int,
                 head_size: int,
                 total_num_heads: int,
                 total_num_kv_heads: Optional[int] = None,
                 bias: bool = True,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        self.hidden_size = hidden_size
        self.head_size = head_size
        self.total_num_heads = total_num_heads
        if total_num_kv_heads is None:
            total_num_kv_heads = total_num_heads
        self.total_num_kv_heads = total_num_kv_heads
        # Divide the weight matrix along the last dimension.
        tp_size = get_tensor_model_parallel_world_size()
        self.num_heads = divide(self.total_num_heads, tp_size)
        if tp_size >= self.total_num_kv_heads:
            self.num_kv_heads = 1
            self.num_kv_head_replicas = divide(tp_size,
                                               self.total_num_kv_heads)
        else:
            self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
            self.num_kv_head_replicas = 1
        input_size = self.hidden_size
        output_size = (self.num_heads +
                       2 * self.num_kv_heads) * tp_size * self.head_size
        self.output_sizes = [
            self.num_heads * self.head_size * tp_size,  # q_proj
            self.num_kv_heads * self.head_size * tp_size,  # k_proj
            self.num_kv_heads * self.head_size * tp_size,  # v_proj
        ]

        super().__init__(input_size=input_size,
                         output_size=output_size,
                         bias=bias,
                         gather_output=False,
                         skip_bias_add=skip_bias_add,
                         params_dtype=params_dtype,
                         quant_config=quant_config,
                         prefix=prefix)

    def _get_shard_offset_mapping(self, loaded_shard_id: str):
        shard_offset_mapping = {
            "q": 0,
            "k": self.num_heads * self.head_size,
            "v": (self.num_heads + self.num_kv_heads) * self.head_size,
            "total": (self.num_heads + 2 * self.num_kv_heads) * self.head_size
        }
        return shard_offset_mapping.get(loaded_shard_id)

    def _get_shard_size_mapping(self, loaded_shard_id: str):
        shard_size_mapping = {
            "q": self.num_heads * self.head_size,
            "k": self.num_kv_heads * self.head_size,
            "v": self.num_kv_heads * self.head_size,
        }
        return shard_size_mapping.get(loaded_shard_id)

    def _load_fused_module_from_checkpoint(self, param: BasevLLMParameter,
                                           loaded_weight: torch.Tensor):
        """
        Handle special case for models where QKV layers are already
        fused on disk. In this case, we have no shard id. This function
        determines the shard id by splitting these layers and then calls
        the weight loader using the shard id.

        An example of a model with these fused layers:
        https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
        """
        shard_offsets = [
            # (shard_id, shard_offset, shard_size)
            ("q", 0, self.total_num_heads * self.head_size),
            ("k", self.total_num_heads * self.head_size,
             self.total_num_kv_heads * self.head_size),
            ("v",
             (self.total_num_heads + self.total_num_kv_heads) * self.head_size,
             self.total_num_kv_heads * self.head_size),
        ]

        for shard_id, shard_offset, shard_size in shard_offsets:
            # Special case for Quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            if isinstance(param, (PackedColumnParameter, PackedvLLMParameter
                                  )) and param.packed_dim == param.output_dim:
                shard_size, shard_offset = \
                    param.adjust_shard_indexes_for_packing(
                        shard_size=shard_size, shard_offset=shard_offset)

            loaded_weight_shard = loaded_weight.narrow(param.output_dim,
                                                       shard_offset,
                                                       shard_size)
            self.weight_loader_v2(param, loaded_weight_shard, shard_id)

    def weight_loader_v2(self,
                         param: BasevLLMParameter,
                         loaded_weight: torch.Tensor,
                         loaded_shard_id: Optional[str] = None):
        if loaded_shard_id is None:  # special case for certain models
            if isinstance(param, PerTensorScaleParameter):
                param.load_qkv_weight(loaded_weight=loaded_weight, shard_id=0)
                return
            elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                param.load_qkv_weight(loaded_weight=loaded_weight)
                return
            # TODO: @dsikka - move to parameter.py
            self._load_fused_module_from_checkpoint(param, loaded_weight)
            return

        assert loaded_shard_id in ["q", "k", "v"]

        shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
        shard_size = self._get_shard_size_mapping(loaded_shard_id)

        param.load_qkv_weight(loaded_weight=loaded_weight,
                              num_heads=self.num_kv_head_replicas,
                              shard_id=loaded_shard_id,
                              shard_offset=shard_offset,
                              shard_size=shard_size)

    def weight_loader(self,
                      param: Parameter,
                      loaded_weight: torch.Tensor,
                      loaded_shard_id: Optional[str] = None):

        # Special case for GGUF
        # initialize GGUF param after we know the quantize type
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            idx_map = {"q": 0, "k": 1, "v": 2}
            if loaded_shard_id is not None:
                param.data[idx_map[loaded_shard_id]].copy_(loaded_weight)
                param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
            else:
                param.shard_weight_type = {
                    k: loaded_weight.item()
                    for k in idx_map
                }
            return

        if is_gguf_weight:
            tp_size = get_tensor_model_parallel_world_size()
            tp_rank = get_tensor_model_parallel_rank()

            output_dim = getattr(param, "output_dim", None)
            shard_size = loaded_weight.size(output_dim) // tp_size
            start_idx = tp_rank * shard_size

            if loaded_shard_id is not None:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                     shard_size)
                param.shard_id.append(loaded_shard_id)
                param.shard_id_map[loaded_shard_id] = len(param.data_container)
                param.data_container.append(loaded_weight)
                if len(param.data_container) == 3:
                    self.qweight = param.materialize_nested()
                return

        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
        # Special case for AQLM codebooks.
        is_metadata = getattr(param, "is_metadata", False)

        # Special case for per-tensor scales in fused case.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)

        if loaded_shard_id is None:
            # Loaded weight is already fused on disk (qkv).
            # (e.g., Phi-3's qkv_proj).
            if output_dim is None:
                if needs_scalar_to_array:
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
                        param_data, loaded_weight, 0)

                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return
            shard_offsets = [
                # (shard_id, shard_offset, shard_size)
                ("q", 0, self.total_num_heads * self.head_size),
                ("k", self.total_num_heads * self.head_size,
                 self.total_num_kv_heads * self.head_size),
                ("v", (self.total_num_heads + self.total_num_kv_heads) *
                 self.head_size, self.total_num_kv_heads * self.head_size),
            ]
            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
                                            False)

            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in shard_offsets:
                # Special case for Quantized Weights.
                # If quantized, we need to adjust the offset and size to
                # account for the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.pack_factor
                    shard_offset = shard_offset // param.pack_factor

                    # Special case for Marlin.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset)

                if use_bitsandbytes_4bit:
                    orig_qkv_offsets = {
                        "q": (0, self.total_num_heads * self.head_size),
                        "k": (self.total_num_heads * self.head_size,
                              self.total_num_kv_heads * self.head_size),
                        "v":
                        ((self.total_num_heads + self.total_num_kv_heads) *
                         self.head_size,
                         self.total_num_kv_heads * self.head_size),
                        "total":
                        ((self.total_num_heads + 2 * self.total_num_kv_heads) *
                         self.head_size, 0)
                    }

                    shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                        param, orig_qkv_offsets, shard_id)

                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size)
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        tp_rank = get_tensor_model_parallel_rank()
        assert loaded_shard_id in ["q", "k", "v"]

        # If output dim is defined, use the default loading process.
        if output_dim is not None:
            if loaded_shard_id == "q":
                shard_offset = 0
                shard_size = self.num_heads * self.head_size
            elif loaded_shard_id == "k":
                shard_offset = self.num_heads * self.head_size
                shard_size = self.num_kv_heads * self.head_size
            elif loaded_shard_id == "v":
                shard_offset = (self.num_heads +
                                self.num_kv_heads) * self.head_size
                shard_size = self.num_kv_heads * self.head_size
            # Special case for Quantized Weights.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.pack_factor
                shard_offset = shard_offset // param.pack_factor

                # Special case for Marlin.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset)

            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
                                            False)
            is_sharded_weight = getattr(param, "is_sharded_weight", False)
            # bitsandbytes loads the weights of the specific portion
            # no need to narrow
            is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

            if use_bitsandbytes_4bit:
                orig_qkv_offsets = {
                    "q": (0, self.num_heads * self.head_size),
                    "k": (self.num_heads * self.head_size,
                          self.num_kv_heads * self.head_size),
                    "v":
                    ((self.num_heads + self.num_kv_heads) * self.head_size,
                     self.num_kv_heads * self.head_size),
                    "total":
                    ((self.num_heads + 2 * self.num_kv_heads) * self.head_size,
                     0)
                }
                shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                    param, orig_qkv_offsets, loaded_shard_id)

            param_data = param_data.narrow(output_dim, shard_offset,
                                           shard_size)
            if loaded_shard_id == "q":
                shard_id = tp_rank
            else:
                shard_id = tp_rank // self.num_kv_head_replicas
            start_idx = shard_id * shard_size

            if not is_sharded_weight:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                     shard_size)

        # Special case for AQLM codebooks.
        elif is_metadata:
            # metadata indicates fixed size concatenated along dim 0
            shard_size = loaded_weight.shape[0]
            shard_index = ["q", "k", "v"].index(loaded_shard_id)
            param_data = param_data.narrow(0, shard_index * shard_size,
                                           shard_size)
        # Special case for per-tensor scales in fused case.
        elif needs_scalar_to_array:
            param_data, loaded_weight = adjust_scalar_to_fused_array(
                param_data, loaded_weight, loaded_shard_id)
        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "QKVParallelLinear, assume the weight is the same "
                    "for all partitions.")

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

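To illustrate the head partitioning described in the QKVParallelLinear docstring, here is an editorial sketch of the per-rank shard layout produced by _get_shard_offset_mapping and _get_shard_size_mapping for one hypothetical grouped-query configuration; every number is an assumption made for the example.

head_size = 128
num_heads = 32 // 4                  # query heads per rank (tp_size = 4)
num_kv_heads = 8 // 4                # kv heads per rank, no replication needed

q_offset, q_size = 0, num_heads * head_size                         # 0, 1024
k_offset, k_size = num_heads * head_size, num_kv_heads * head_size  # 1024, 256
v_offset = (num_heads + num_kv_heads) * head_size                   # 1280
v_size = num_kv_heads * head_size                                   # 256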
class RowParallelLinear(LinearBase):
    """Linear layer with row parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its first dimension and X along its second dimension as:
               -   -
              | A_1 |
              | .   |
          A = | .   |        X = [X_1, ..., X_p]
              | .   |
              | A_p |
               -   -
    Arguments:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias. Note that bias is not parallelized.
        input_is_parallel: If true, we assume that the input is already
                           split across the GPUs and we do not split
                           again.
        skip_bias_add: This was added to enable performance optimization where
                       bias can be fused with other element-wise operations.
                       We skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configuration.
    """

    def __init__(self,
                 input_size: int,
                 output_size: int,
                 bias: bool = True,
                 input_is_parallel: bool = True,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 reduce_results: bool = True,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__(input_size, output_size, skip_bias_add, params_dtype,
                         quant_config, prefix)

        self.input_is_parallel = input_is_parallel
        self.reduce_results = reduce_results

        # Divide the weight matrix along the last dimension.
        self.tp_rank = get_tensor_model_parallel_rank()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.input_size_per_partition = divide(input_size, self.tp_size)
        assert self.quant_method is not None

        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=[self.output_size],
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=(
                self.weight_loader_v2 if self.quant_method.__class__.__name__
                in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
        if not reduce_results and (bias and not skip_bias_add):
            raise ValueError("When not reduce the results, adding bias to the "
                             "results can lead to incorrect results")

        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        tp_rank = get_tensor_model_parallel_rank()
        tp_size = get_tensor_model_parallel_world_size()
        input_dim = getattr(param, "input_dim", None)
        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
        is_sharded_weight = getattr(param, "is_sharded_weight", False)
        # bitsandbytes loads the weights of the specific portion
        # no need to narrow
        is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            weight_shape = list(loaded_weight.shape)
            if input_dim:
                weight_shape[input_dim] = weight_shape[input_dim] // tp_size
            param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)

        param_data = param.data
        if input_dim is not None and not is_sharded_weight:
            shard_size = param_data.shape[input_dim]
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(input_dim, start_idx,
                                                 shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def weight_loader_v2(self, param: BasevLLMParameter,
                         loaded_weight: torch.Tensor):

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            assert loaded_weight.numel() == 1
            loaded_weight = loaded_weight.reshape(1)

        param.load_row_parallel_weight(loaded_weight=loaded_weight)

    def forward(self, input_) -> tuple[torch.Tensor, Optional[Parameter]]:
        if self.input_is_parallel:
            input_parallel = input_
        else:
            tp_rank = get_tensor_model_parallel_rank()
            splitted_input = split_tensor_along_last_dim(
                input_, num_partitions=self.tp_size)
            input_parallel = splitted_input[tp_rank].contiguous()

        # Matrix multiply.
        assert self.quant_method is not None
        # Only fuse bias add into GEMM for rank 0 (this ensures that
        # bias will not get added more than once in TP>1 case)
        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
        output_parallel = self.quant_method.apply(self,
                                                  input_parallel,
                                                  bias=bias_)
+
if self.reduce_results and self.tp_size > 1:
|
| 1145 |
+
output = tensor_model_parallel_all_reduce(output_parallel)
|
| 1146 |
+
else:
|
| 1147 |
+
output = output_parallel
|
| 1148 |
+
|
| 1149 |
+
output_bias = self.bias if self.skip_bias_add else None
|
| 1150 |
+
|
| 1151 |
+
return output, output_bias
|
| 1152 |
+
|
| 1153 |
+
def extra_repr(self) -> str:
|
| 1154 |
+
s = f"input_features={self.input_size_per_partition}"
|
| 1155 |
+
s += f", output_features={self.output_size}"
|
| 1156 |
+
s += f", bias={self.bias is not None}"
|
| 1157 |
+
s += f", tp_size={self.tp_size}"
|
| 1158 |
+
s += f", reduce_results={self.reduce_results}"
|
| 1159 |
+
return s
|
.venv/lib/python3.11/site-packages/vllm/model_executor/layers/logits_processor.py
ADDED
|
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
"""A layer that compute logits from hidden_stats."""
|
| 3 |
+
import inspect
|
| 4 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 5 |
+
from typing import Optional
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
|
| 10 |
+
import vllm.envs as envs
|
| 11 |
+
from vllm.config import get_current_vllm_config
|
| 12 |
+
from vllm.distributed import (tensor_model_parallel_all_gather,
|
| 13 |
+
tensor_model_parallel_gather)
|
| 14 |
+
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
| 15 |
+
VocabParallelEmbedding)
|
| 16 |
+
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
| 17 |
+
from vllm.platforms import current_platform
|
| 18 |
+
|
| 19 |
+
_logits_processor_threadpool: Optional[ThreadPoolExecutor] = None
|
| 20 |
+
if envs.VLLM_LOGITS_PROCESSOR_THREADS is not None:
|
| 21 |
+
_logits_processor_threadpool = ThreadPoolExecutor(
|
| 22 |
+
envs.VLLM_LOGITS_PROCESSOR_THREADS)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class LogitsProcessor(nn.Module):
|
| 26 |
+
"""Process logits and apply logits processors from sampling metadata.
|
| 27 |
+
|
| 28 |
+
This layer does the following:
|
| 29 |
+
1. Gather logits from model hidden_states.
|
| 30 |
+
2. Scale logits if needed.
|
| 31 |
+
3. Apply logits processors (if any).
|
| 32 |
+
"""
|
| 33 |
+
|
| 34 |
+
def __init__(self,
|
| 35 |
+
vocab_size: int,
|
| 36 |
+
org_vocab_size: Optional[int] = None,
|
| 37 |
+
scale: float = 1.0,
|
| 38 |
+
logits_as_input: bool = False,
|
| 39 |
+
soft_cap: Optional[float] = None) -> None:
|
| 40 |
+
"""
|
| 41 |
+
Args:
|
| 42 |
+
scale: A scaling factor to apply to the logits.
|
| 43 |
+
"""
|
| 44 |
+
super().__init__()
|
| 45 |
+
self.scale = scale
|
| 46 |
+
self.vocab_size = vocab_size
|
| 47 |
+
# Whether the input is logits (default is hidden states).
|
| 48 |
+
self.logits_as_input = logits_as_input
|
| 49 |
+
# original vocabulary size (without LoRA).
|
| 50 |
+
self.org_vocab_size = org_vocab_size or vocab_size
|
| 51 |
+
# Soft cap the logits. Used in Gemma 2.
|
| 52 |
+
self.soft_cap = soft_cap
|
| 53 |
+
# Whether to use gather or all-gather to gather the logits.
|
| 54 |
+
|
| 55 |
+
parallel_config = get_current_vllm_config().parallel_config
|
| 56 |
+
self.use_all_gather = current_platform.is_tpu() \
|
| 57 |
+
or envs.VLLM_USE_V1 \
|
| 58 |
+
or parallel_config.distributed_executor_backend == "external_launcher" # noqa
|
| 59 |
+
|
| 60 |
+
def forward(
|
| 61 |
+
self,
|
| 62 |
+
lm_head: VocabParallelEmbedding,
|
| 63 |
+
hidden_states: torch.Tensor,
|
| 64 |
+
sampling_metadata: Optional[SamplingMetadata] = None,
|
| 65 |
+
embedding_bias: Optional[torch.Tensor] = None,
|
| 66 |
+
) -> Optional[torch.Tensor]:
|
| 67 |
+
if self.logits_as_input:
|
| 68 |
+
logits = hidden_states
|
| 69 |
+
else:
|
| 70 |
+
if sampling_metadata is not None:
|
| 71 |
+
hidden_states = _prune_hidden_states(hidden_states,
|
| 72 |
+
sampling_metadata)
|
| 73 |
+
|
| 74 |
+
# Get the logits for the next tokens.
|
| 75 |
+
logits = self._get_logits(hidden_states, lm_head, embedding_bias)
|
| 76 |
+
if logits is not None:
|
| 77 |
+
if self.soft_cap is not None:
|
| 78 |
+
logits = logits / self.soft_cap
|
| 79 |
+
logits = torch.tanh(logits)
|
| 80 |
+
logits = logits * self.soft_cap
|
| 81 |
+
|
| 82 |
+
if self.scale != 1.0:
|
| 83 |
+
logits *= self.scale
|
| 84 |
+
|
| 85 |
+
# Apply logits processors (if any).
|
| 86 |
+
if sampling_metadata is not None:
|
| 87 |
+
logits = _apply_logits_processors(logits, sampling_metadata)
|
| 88 |
+
|
| 89 |
+
return logits
|
| 90 |
+
|
| 91 |
+
def _get_logits(
|
| 92 |
+
self,
|
| 93 |
+
hidden_states: torch.Tensor,
|
| 94 |
+
lm_head: VocabParallelEmbedding,
|
| 95 |
+
embedding_bias: Optional[torch.Tensor],
|
| 96 |
+
) -> Optional[torch.Tensor]:
|
| 97 |
+
# Get the logits for the next tokens.
|
| 98 |
+
logits = lm_head.linear_method.apply(lm_head,
|
| 99 |
+
hidden_states,
|
| 100 |
+
bias=embedding_bias)
|
| 101 |
+
|
| 102 |
+
if self.use_all_gather:
|
| 103 |
+
# Gather is not supported for some devices such as TPUs.
|
| 104 |
+
# Use all-gather instead.
|
| 105 |
+
# NOTE(woosuk): Here, the outputs of every device should not be None
|
| 106 |
+
# because XLA requires strict SPMD among all devices. Every device
|
| 107 |
+
# should execute the same operations after gathering the logits.
|
| 108 |
+
logits = tensor_model_parallel_all_gather(logits)
|
| 109 |
+
else:
|
| 110 |
+
# None may be returned for rank > 0
|
| 111 |
+
logits = tensor_model_parallel_gather(logits)
|
| 112 |
+
# Remove paddings in vocab (if any).
|
| 113 |
+
if logits is not None:
|
| 114 |
+
logits = logits[..., :self.org_vocab_size]
|
| 115 |
+
return logits
|
| 116 |
+
|
| 117 |
+
def extra_repr(self) -> str:
|
| 118 |
+
s = f"vocab_size={self.vocab_size}"
|
| 119 |
+
s += f", forg_vocab_size={self.org_vocab_size}"
|
| 120 |
+
s += f", scale={self.scale}, logits_as_input={self.logits_as_input}"
|
| 121 |
+
return s
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def _prune_hidden_states(
|
| 125 |
+
hidden_states: torch.Tensor,
|
| 126 |
+
sampling_metadata: SamplingMetadata,
|
| 127 |
+
) -> torch.Tensor:
|
| 128 |
+
# NOTE(kzawora): The if guard is needed for Gaudi - in some scenarios
|
| 129 |
+
# (warmup, profile_run) we might not have selected_token_indices,
|
| 130 |
+
# so we skip pruning.
|
| 131 |
+
if sampling_metadata.selected_token_indices is not None:
|
| 132 |
+
return hidden_states.index_select(
|
| 133 |
+
0, sampling_metadata.selected_token_indices)
|
| 134 |
+
else:
|
| 135 |
+
return hidden_states
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def _apply_logits_processors(
|
| 139 |
+
logits: torch.Tensor,
|
| 140 |
+
sampling_metadata: SamplingMetadata,
|
| 141 |
+
) -> torch.Tensor:
|
| 142 |
+
found_logits_processors = False
|
| 143 |
+
logits_processed = 0
|
| 144 |
+
logits_row_ids_and_logits_row_futures = []
|
| 145 |
+
for seq_group in sampling_metadata.seq_groups:
|
| 146 |
+
seq_ids = seq_group.seq_ids
|
| 147 |
+
sampling_params = seq_group.sampling_params
|
| 148 |
+
logits_processors = sampling_params.logits_processors
|
| 149 |
+
if logits_processors:
|
| 150 |
+
found_logits_processors = True
|
| 151 |
+
|
| 152 |
+
for seq_id, logits_row_idx in zip(seq_ids,
|
| 153 |
+
seq_group.sample_indices):
|
| 154 |
+
logits_row = logits[logits_row_idx]
|
| 155 |
+
past_tokens_ids = seq_group.seq_data[seq_id].output_token_ids
|
| 156 |
+
prompt_tokens_ids = seq_group.seq_data[seq_id].prompt_token_ids
|
| 157 |
+
|
| 158 |
+
if _logits_processor_threadpool is not None:
|
| 159 |
+
logits_row_ids_and_logits_row_futures.append(
|
| 160 |
+
(logits_row_idx,
|
| 161 |
+
_logits_processor_threadpool.submit(
|
| 162 |
+
_apply_logits_processors_single_seq, logits_row,
|
| 163 |
+
logits_processors, past_tokens_ids,
|
| 164 |
+
prompt_tokens_ids)))
|
| 165 |
+
else:
|
| 166 |
+
logits[logits_row_idx] = \
|
| 167 |
+
_apply_logits_processors_single_seq(
|
| 168 |
+
logits_row, logits_processors, past_tokens_ids,
|
| 169 |
+
prompt_tokens_ids)
|
| 170 |
+
|
| 171 |
+
logits_processed += len(seq_group.sample_indices) + len(
|
| 172 |
+
seq_group.prompt_logprob_indices)
|
| 173 |
+
|
| 174 |
+
for logits_row_idx, future in logits_row_ids_and_logits_row_futures:
|
| 175 |
+
logits[logits_row_idx] = future.result()
|
| 176 |
+
|
| 177 |
+
if found_logits_processors:
|
| 178 |
+
# verifies that no rows in logits were missed unexpectedly
|
| 179 |
+
assert logits_processed == logits.shape[0]
|
| 180 |
+
return logits
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
def _apply_logits_processors_single_seq(logits_row, logits_processors,
|
| 184 |
+
past_tokens_ids,
|
| 185 |
+
prompt_tokens_ids) -> torch.Tensor:
|
| 186 |
+
for logits_processor in logits_processors:
|
| 187 |
+
parameters = inspect.signature(logits_processor).parameters
|
| 188 |
+
if len(parameters) == 3:
|
| 189 |
+
logits_row = logits_processor(prompt_tokens_ids, past_tokens_ids,
|
| 190 |
+
logits_row)
|
| 191 |
+
else:
|
| 192 |
+
logits_row = logits_processor(past_tokens_ids, logits_row)
|
| 193 |
+
return logits_row
|
.venv/lib/python3.11/site-packages/vllm/model_executor/layers/pooler.py
ADDED
|
@@ -0,0 +1,322 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
from enum import IntEnum
|
| 4 |
+
from typing import List, Optional, Union
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
from transformers import PretrainedConfig
|
| 10 |
+
from typing_extensions import assert_never
|
| 11 |
+
|
| 12 |
+
from vllm.config import PoolerConfig
|
| 13 |
+
from vllm.model_executor.pooling_metadata import (PoolingMetadata,
|
| 14 |
+
PoolingTensors)
|
| 15 |
+
from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput
|
| 16 |
+
from vllm.transformers_utils.config import (
|
| 17 |
+
get_cross_encoder_activation_function)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class PoolingType(IntEnum):
|
| 21 |
+
"""Enumeration for different types of pooling methods."""
|
| 22 |
+
LAST = 0
|
| 23 |
+
ALL = 1
|
| 24 |
+
CLS = 2
|
| 25 |
+
STEP = 3
|
| 26 |
+
MEAN = 4
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class SimplePooler(nn.Module):
|
| 30 |
+
"""A layer that pools specific information from hidden states.
|
| 31 |
+
|
| 32 |
+
This layer does the following:
|
| 33 |
+
1. Extracts specific tokens or aggregates data based on pooling method.
|
| 34 |
+
2. Normalizes output if specified.
|
| 35 |
+
3. Returns structured results as `PoolerOutput`.
|
| 36 |
+
|
| 37 |
+
Attributes:
|
| 38 |
+
pooling_type: The type of pooling to use.
|
| 39 |
+
normalize: Whether to normalize the pooled data.
|
| 40 |
+
"""
|
| 41 |
+
|
| 42 |
+
@staticmethod
|
| 43 |
+
def from_pooling_type(
|
| 44 |
+
pooling_type: PoolingType,
|
| 45 |
+
*,
|
| 46 |
+
normalize: bool,
|
| 47 |
+
softmax: bool,
|
| 48 |
+
step_tag_id: Optional[int] = None,
|
| 49 |
+
returned_token_ids: Optional[List[int]] = None,
|
| 50 |
+
) -> "SimplePooler":
|
| 51 |
+
if pooling_type == PoolingType.LAST:
|
| 52 |
+
assert step_tag_id is None and returned_token_ids is None
|
| 53 |
+
return LastPool(normalize=normalize, softmax=softmax)
|
| 54 |
+
if pooling_type == PoolingType.ALL:
|
| 55 |
+
assert step_tag_id is None and returned_token_ids is None
|
| 56 |
+
return AllPool(normalize=normalize, softmax=softmax)
|
| 57 |
+
if pooling_type == PoolingType.CLS:
|
| 58 |
+
assert step_tag_id is None and returned_token_ids is None
|
| 59 |
+
return CLSPool(normalize=normalize, softmax=softmax)
|
| 60 |
+
if pooling_type == PoolingType.MEAN:
|
| 61 |
+
assert step_tag_id is None and returned_token_ids is None
|
| 62 |
+
return MeanPool(normalize=normalize, softmax=softmax)
|
| 63 |
+
if pooling_type == PoolingType.STEP:
|
| 64 |
+
return StepPool(normalize=normalize,
|
| 65 |
+
softmax=softmax,
|
| 66 |
+
step_tag_id=step_tag_id,
|
| 67 |
+
returned_token_ids=returned_token_ids)
|
| 68 |
+
|
| 69 |
+
assert_never(pooling_type)
|
| 70 |
+
|
| 71 |
+
def __init__(self, *, normalize: bool, softmax: bool) -> None:
|
| 72 |
+
super().__init__()
|
| 73 |
+
|
| 74 |
+
self.head = PoolerHead(normalize=normalize, softmax=softmax)
|
| 75 |
+
|
| 76 |
+
def get_prompt_lens(
|
| 77 |
+
self,
|
| 78 |
+
hidden_states: torch.Tensor,
|
| 79 |
+
pooling_metadata: PoolingMetadata,
|
| 80 |
+
) -> torch.Tensor:
|
| 81 |
+
return PoolingTensors.from_pooling_metadata(
|
| 82 |
+
pooling_metadata, hidden_states.device).prompt_lens
|
| 83 |
+
|
| 84 |
+
def extract_states(
|
| 85 |
+
self,
|
| 86 |
+
hidden_states: torch.Tensor,
|
| 87 |
+
pooling_metadata: PoolingMetadata,
|
| 88 |
+
) -> Union[list[torch.Tensor], torch.Tensor]:
|
| 89 |
+
raise NotImplementedError
|
| 90 |
+
|
| 91 |
+
def build_output(self, data: torch.Tensor) -> PoolingSequenceGroupOutput:
|
| 92 |
+
return PoolingSequenceGroupOutput(data)
|
| 93 |
+
|
| 94 |
+
def forward(
|
| 95 |
+
self,
|
| 96 |
+
hidden_states: torch.Tensor,
|
| 97 |
+
pooling_metadata: PoolingMetadata,
|
| 98 |
+
) -> PoolerOutput:
|
| 99 |
+
pooled_data = self.extract_states(hidden_states, pooling_metadata)
|
| 100 |
+
pooled_data = self.head(pooled_data)
|
| 101 |
+
pooled_outputs = [self.build_output(data) for data in pooled_data]
|
| 102 |
+
return PoolerOutput(outputs=pooled_outputs)
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
class CLSPool(SimplePooler):
|
| 106 |
+
|
| 107 |
+
def extract_states(
|
| 108 |
+
self,
|
| 109 |
+
hidden_states: torch.Tensor,
|
| 110 |
+
pooling_metadata: PoolingMetadata,
|
| 111 |
+
) -> Union[list[torch.Tensor], torch.Tensor]:
|
| 112 |
+
prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)
|
| 113 |
+
|
| 114 |
+
first_token_flat_indices = torch.zeros_like(prompt_lens)
|
| 115 |
+
first_token_flat_indices[1:] += torch.cumsum(prompt_lens, dim=0)[:-1]
|
| 116 |
+
return hidden_states[first_token_flat_indices]
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
class LastPool(SimplePooler):
|
| 120 |
+
|
| 121 |
+
def extract_states(
|
| 122 |
+
self,
|
| 123 |
+
hidden_states: torch.Tensor,
|
| 124 |
+
pooling_metadata: PoolingMetadata,
|
| 125 |
+
) -> Union[list[torch.Tensor], torch.Tensor]:
|
| 126 |
+
prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)
|
| 127 |
+
|
| 128 |
+
last_token_flat_indices = torch.cumsum(prompt_lens, dim=0) - 1
|
| 129 |
+
return hidden_states[last_token_flat_indices]
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
class AllPool(SimplePooler):
|
| 133 |
+
|
| 134 |
+
def extract_states(
|
| 135 |
+
self,
|
| 136 |
+
hidden_states: torch.Tensor,
|
| 137 |
+
pooling_metadata: PoolingMetadata,
|
| 138 |
+
) -> Union[list[torch.Tensor], torch.Tensor]:
|
| 139 |
+
prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)
|
| 140 |
+
|
| 141 |
+
offset = 0
|
| 142 |
+
pooled_data = list[torch.Tensor]()
|
| 143 |
+
for prompt_len in prompt_lens:
|
| 144 |
+
pooled_data.append(hidden_states[offset:offset + prompt_len])
|
| 145 |
+
offset += prompt_len
|
| 146 |
+
|
| 147 |
+
return pooled_data
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
class MeanPool(SimplePooler):
|
| 151 |
+
|
| 152 |
+
def extract_states(
|
| 153 |
+
self,
|
| 154 |
+
hidden_states: torch.Tensor,
|
| 155 |
+
pooling_metadata: PoolingMetadata,
|
| 156 |
+
) -> Union[list[torch.Tensor], torch.Tensor]:
|
| 157 |
+
prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)
|
| 158 |
+
|
| 159 |
+
cumsum = torch.cumsum(hidden_states, dim=0)
|
| 160 |
+
start_indices = torch.cat([
|
| 161 |
+
torch.tensor([0], device=hidden_states.device),
|
| 162 |
+
torch.cumsum(prompt_lens[:-1], dim=0)
|
| 163 |
+
])
|
| 164 |
+
end_indices = torch.cumsum(prompt_lens, dim=0)
|
| 165 |
+
return (cumsum[end_indices - 1] - cumsum[start_indices] +
|
| 166 |
+
hidden_states[start_indices]) / prompt_lens.unsqueeze(1)
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
class StepPool(SimplePooler):
|
| 170 |
+
|
| 171 |
+
def __init__(
|
| 172 |
+
self,
|
| 173 |
+
*,
|
| 174 |
+
normalize: bool,
|
| 175 |
+
softmax: bool,
|
| 176 |
+
step_tag_id: Optional[int] = None,
|
| 177 |
+
returned_token_ids: Optional[List[int]] = None,
|
| 178 |
+
):
|
| 179 |
+
super().__init__(normalize=normalize, softmax=softmax)
|
| 180 |
+
|
| 181 |
+
self.step_tag_id = step_tag_id
|
| 182 |
+
self.returned_token_ids = returned_token_ids
|
| 183 |
+
|
| 184 |
+
def extract_states(
|
| 185 |
+
self,
|
| 186 |
+
hidden_states: torch.Tensor,
|
| 187 |
+
pooling_metadata: PoolingMetadata,
|
| 188 |
+
) -> Union[list[torch.Tensor], torch.Tensor]:
|
| 189 |
+
prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)
|
| 190 |
+
|
| 191 |
+
returned_token_ids = self.returned_token_ids
|
| 192 |
+
if returned_token_ids is not None and len(returned_token_ids) > 0:
|
| 193 |
+
hidden_states = hidden_states[:, returned_token_ids]
|
| 194 |
+
|
| 195 |
+
step_tag_id = self.step_tag_id
|
| 196 |
+
|
| 197 |
+
offset = 0
|
| 198 |
+
pooled_data = list[torch.Tensor]()
|
| 199 |
+
for prompt_len, seq_data_i in zip(prompt_lens,
|
| 200 |
+
pooling_metadata.seq_data.values()):
|
| 201 |
+
pooled_data_i = hidden_states[offset:offset + prompt_len]
|
| 202 |
+
if step_tag_id is not None:
|
| 203 |
+
token_ids = torch.tensor(seq_data_i.prompt_token_ids)
|
| 204 |
+
pooled_data_i = pooled_data_i[token_ids == step_tag_id]
|
| 205 |
+
|
| 206 |
+
offset += prompt_len
|
| 207 |
+
pooled_data.append(pooled_data_i)
|
| 208 |
+
|
| 209 |
+
return pooled_data
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
class PoolerHead(nn.Module):
|
| 213 |
+
|
| 214 |
+
def __init__(self, *, normalize: bool, softmax: bool) -> None:
|
| 215 |
+
super().__init__()
|
| 216 |
+
|
| 217 |
+
self.normalize = normalize
|
| 218 |
+
self.softmax = softmax
|
| 219 |
+
|
| 220 |
+
def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor]):
|
| 221 |
+
if self.normalize:
|
| 222 |
+
if isinstance(pooled_data, list):
|
| 223 |
+
pooled_data = [
|
| 224 |
+
F.normalize(data, p=2, dim=1) for data in pooled_data
|
| 225 |
+
]
|
| 226 |
+
else:
|
| 227 |
+
pooled_data = F.normalize(pooled_data, p=2, dim=1)
|
| 228 |
+
|
| 229 |
+
if self.softmax:
|
| 230 |
+
if isinstance(pooled_data, list):
|
| 231 |
+
pooled_data = [F.softmax(data, dim=-1) for data in pooled_data]
|
| 232 |
+
else:
|
| 233 |
+
pooled_data = F.softmax(pooled_data, dim=-1)
|
| 234 |
+
|
| 235 |
+
return pooled_data
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
class Pooler(nn.Module):
|
| 239 |
+
|
| 240 |
+
@classmethod
|
| 241 |
+
def from_config_with_defaults(
|
| 242 |
+
cls,
|
| 243 |
+
pooler_config: PoolerConfig,
|
| 244 |
+
pooling_type: PoolingType,
|
| 245 |
+
normalize: bool,
|
| 246 |
+
softmax: bool,
|
| 247 |
+
step_tag_id: Optional[int] = None,
|
| 248 |
+
returned_token_ids: Optional[List[int]] = None,
|
| 249 |
+
) -> SimplePooler:
|
| 250 |
+
return SimplePooler.from_pooling_type(
|
| 251 |
+
pooling_type=PoolingType[pooler_config.pooling_type]
|
| 252 |
+
if pooler_config.pooling_type is not None else pooling_type,
|
| 253 |
+
normalize=pooler_config.normalize
|
| 254 |
+
if pooler_config.normalize is not None else normalize,
|
| 255 |
+
softmax=pooler_config.softmax
|
| 256 |
+
if pooler_config.softmax is not None else softmax,
|
| 257 |
+
step_tag_id=pooler_config.step_tag_id
|
| 258 |
+
if pooler_config.step_tag_id is not None else step_tag_id,
|
| 259 |
+
returned_token_ids=pooler_config.returned_token_ids
|
| 260 |
+
if pooler_config.returned_token_ids is not None else
|
| 261 |
+
returned_token_ids,
|
| 262 |
+
)
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
class CrossEncodingPooler(nn.Module):
|
| 266 |
+
"""A layer that pools specific information from hidden states.
|
| 267 |
+
|
| 268 |
+
This layer does the following:
|
| 269 |
+
1. Extracts specific tokens or aggregates data based on pooling method.
|
| 270 |
+
2. Normalizes output if specified.
|
| 271 |
+
3. Returns structured results as `PoolerOutput`.
|
| 272 |
+
|
| 273 |
+
Attributes:
|
| 274 |
+
pooling_type: The type of pooling to use.
|
| 275 |
+
normalize: Whether to normalize the pooled data.
|
| 276 |
+
"""
|
| 277 |
+
|
| 278 |
+
def __init__(
|
| 279 |
+
self,
|
| 280 |
+
config: PretrainedConfig,
|
| 281 |
+
classifier: nn.Module,
|
| 282 |
+
pooler: Optional[nn.Module] = None,
|
| 283 |
+
):
|
| 284 |
+
super().__init__()
|
| 285 |
+
self.classifier = classifier
|
| 286 |
+
self.pooler = pooler
|
| 287 |
+
self.default_activation_function = \
|
| 288 |
+
get_cross_encoder_activation_function(config)
|
| 289 |
+
|
| 290 |
+
def forward(
|
| 291 |
+
self,
|
| 292 |
+
hidden_states: torch.Tensor,
|
| 293 |
+
pooling_metadata: PoolingMetadata,
|
| 294 |
+
) -> PoolerOutput:
|
| 295 |
+
"""Pools sentence pair scores from the hidden_states."""
|
| 296 |
+
|
| 297 |
+
prompt_lens = PoolingTensors.from_pooling_metadata(
|
| 298 |
+
pooling_metadata, hidden_states.device).prompt_lens
|
| 299 |
+
|
| 300 |
+
offset = 0
|
| 301 |
+
pooled_data_lst = []
|
| 302 |
+
for prompt_len in prompt_lens:
|
| 303 |
+
pooled_data_i = hidden_states[offset:offset + prompt_len]
|
| 304 |
+
|
| 305 |
+
if self.pooler is not None:
|
| 306 |
+
final_shape_tensor = self.pooler(pooled_data_i)
|
| 307 |
+
else:
|
| 308 |
+
final_shape_tensor = self.classifier(pooled_data_i)
|
| 309 |
+
|
| 310 |
+
pooled_data_lst.append(final_shape_tensor)
|
| 311 |
+
offset += prompt_len
|
| 312 |
+
|
| 313 |
+
pooled_output = torch.stack(pooled_data_lst)
|
| 314 |
+
|
| 315 |
+
if self.pooler is not None:
|
| 316 |
+
# apply classifier once on the full batch if possible
|
| 317 |
+
pooled_output = self.classifier(pooled_output)
|
| 318 |
+
|
| 319 |
+
scores = self.default_activation_function(pooled_output).squeeze(-1)
|
| 320 |
+
|
| 321 |
+
pooled_outputs = [PoolingSequenceGroupOutput(data) for data in scores]
|
| 322 |
+
return PoolerOutput(outputs=pooled_outputs)
|
.venv/lib/python3.11/site-packages/vllm/model_executor/layers/rejection_sampler.py
ADDED
|
@@ -0,0 +1,400 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
from functools import cached_property
|
| 4 |
+
from importlib.util import find_spec
|
| 5 |
+
from typing import Dict, Optional, Tuple
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.jit
|
| 9 |
+
|
| 10 |
+
import vllm.envs as envs
|
| 11 |
+
from vllm.logger import init_logger
|
| 12 |
+
from vllm.model_executor.layers.spec_decode_base_sampler import (
|
| 13 |
+
SpecDecodeStochasticBaseSampler)
|
| 14 |
+
from vllm.platforms import current_platform
|
| 15 |
+
|
| 16 |
+
logger = init_logger(__name__)
|
| 17 |
+
|
| 18 |
+
if find_spec("flashinfer"):
|
| 19 |
+
"""
|
| 20 |
+
Consider utilizing the FlashInfer rejection sampling kernel initially,
|
| 21 |
+
as it employs a dedicated kernel rather than relying on
|
| 22 |
+
Torch tensor operations. This design choice helps to fuse operations,
|
| 23 |
+
reduce memory I/O, and consequently enhances performance.
|
| 24 |
+
"""
|
| 25 |
+
from flashinfer.sampling import chain_speculative_sampling
|
| 26 |
+
else:
|
| 27 |
+
chain_speculative_sampling = None
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class RejectionSampler(SpecDecodeStochasticBaseSampler):
|
| 31 |
+
"""Apply modified rejection sampling as described in "Accelerating Large
|
| 32 |
+
Language Model Decoding with Speculative Sampling"
|
| 33 |
+
https://arxiv.org/pdf/2302.01318.pdf.
|
| 34 |
+
"""
|
| 35 |
+
|
| 36 |
+
def __init__(self,
|
| 37 |
+
strict_mode: bool = False,
|
| 38 |
+
use_flashinfer: Optional[bool] = None):
|
| 39 |
+
"""Create a rejection sampler.
|
| 40 |
+
|
| 41 |
+
Args:
|
| 42 |
+
strict_mode: Whether or not to perform shape/device/dtype checks
|
| 43 |
+
during sampling. This catches correctness issues but adds
|
| 44 |
+
nontrivial latency.
|
| 45 |
+
use_flashinfer: We will use this parameter to determine whether
|
| 46 |
+
to use the FlashInfer rejection sampling kernel or not. If it's
|
| 47 |
+
None, we will use the default value from the environment variable.
|
| 48 |
+
This parameter is only used for testing purposes.
|
| 49 |
+
"""
|
| 50 |
+
super().__init__(strict_mode=strict_mode)
|
| 51 |
+
if use_flashinfer is None:
|
| 52 |
+
self.use_flashinfer = envs.VLLM_USE_FLASHINFER_SAMPLER and (
|
| 53 |
+
chain_speculative_sampling is not None)
|
| 54 |
+
else:
|
| 55 |
+
self.use_flashinfer = use_flashinfer
|
| 56 |
+
|
| 57 |
+
if self.use_flashinfer:
|
| 58 |
+
logger.info("Use flashinfer for rejection sampling.")
|
| 59 |
+
else:
|
| 60 |
+
logger.info("Use pytorch for rejection sampling.")
|
| 61 |
+
|
| 62 |
+
def forward(
|
| 63 |
+
self,
|
| 64 |
+
target_with_bonus_probs: torch.Tensor,
|
| 65 |
+
bonus_token_ids: torch.Tensor,
|
| 66 |
+
draft_probs: torch.Tensor,
|
| 67 |
+
draft_token_ids: torch.Tensor,
|
| 68 |
+
seeded_seqs: Optional[Dict[int, torch.Generator]] = None,
|
| 69 |
+
) -> torch.Tensor:
|
| 70 |
+
"""Sample token ids using rejection sampling. This accepts or rejects
|
| 71 |
+
tokens proposed by the draft model using the probability of each token
|
| 72 |
+
according to the draft and target models.
|
| 73 |
+
|
| 74 |
+
In the worst case where all draft tokens are rejected, it is guaranteed
|
| 75 |
+
one correct token will be emitted.
|
| 76 |
+
|
| 77 |
+
In the case where all draft tokens are accepted, a bonus token will be
|
| 78 |
+
accepted as its cheap to have the target model score this speculative
|
| 79 |
+
sequence.
|
| 80 |
+
|
| 81 |
+
Args:
|
| 82 |
+
target_with_bonus_probs: The probability distribution
|
| 83 |
+
over token ids given context according to the target model.
|
| 84 |
+
shape = [batch_size, num_speculative_tokens + 1, vocab_size]
|
| 85 |
+
|
| 86 |
+
bonus_token_ids: The "bonus" token ids that are accepted iff all
|
| 87 |
+
speculative tokens in a sequence are accepted.
|
| 88 |
+
shape = [batch_size, num_bonus_tokens]
|
| 89 |
+
|
| 90 |
+
draft_probs: The probability distribution over token ids given
|
| 91 |
+
context according to the draft model.
|
| 92 |
+
shape = [batch_size, num_speculative_tokens, vocab_size]
|
| 93 |
+
|
| 94 |
+
draft_token_ids: The token ids that were sampled from the draft
|
| 95 |
+
probabilities.
|
| 96 |
+
shape = [batch_size, num_speculative_tokens]
|
| 97 |
+
|
| 98 |
+
seeded_seqs: Dict of batch row index to torch generator, for
|
| 99 |
+
sequences using seeded generation.
|
| 100 |
+
|
| 101 |
+
Returns:
|
| 102 |
+
output_token_ids: The token ids sampled via rejection sampling,
|
| 103 |
+
or -1 if unable to sample a token because the previous token
|
| 104 |
+
was rejected.
|
| 105 |
+
shape = [batch_size, num_speculative_tokens + num_bonus_tokens]
|
| 106 |
+
"""
|
| 107 |
+
# Only perform shape/dtype/device checking in strict mode, as it adds
|
| 108 |
+
# overhead.
|
| 109 |
+
if self._strict_mode:
|
| 110 |
+
self._raise_if_incorrect_input(target_with_bonus_probs,
|
| 111 |
+
draft_token_ids, bonus_token_ids,
|
| 112 |
+
draft_probs)
|
| 113 |
+
|
| 114 |
+
batch_size, k, _ = draft_probs.shape
|
| 115 |
+
|
| 116 |
+
# batch_size = 0 when all requests in the batch are
|
| 117 |
+
# non_spec requests. In this case, output_token_ids is
|
| 118 |
+
# just an empty tensor.
|
| 119 |
+
if batch_size == 0:
|
| 120 |
+
return torch.empty(0, k + 1, device=draft_probs.device, dtype=int)
|
| 121 |
+
|
| 122 |
+
# If use Flashinfer chain_speculative_sampling kernel
|
| 123 |
+
# for rejection sampling
|
| 124 |
+
if self.use_flashinfer and chain_speculative_sampling is not None:
|
| 125 |
+
batch_size, k, _ = draft_probs.shape
|
| 126 |
+
uniform_samples = self._create_uniform_samples(
|
| 127 |
+
seeded_seqs, batch_size, k, draft_probs.device)
|
| 128 |
+
output_token_ids, accepted_token_num, emitted_token_num \
|
| 129 |
+
= chain_speculative_sampling(
|
| 130 |
+
draft_probs, draft_token_ids, uniform_samples,
|
| 131 |
+
target_with_bonus_probs)
|
| 132 |
+
|
| 133 |
+
# num_emitted_tokens returned by flashinfer
|
| 134 |
+
# does not include the bonus token
|
| 135 |
+
# Flashinfer stops at the first token that violates
|
| 136 |
+
# the condition p >= q and does not include recovery/bonus token.
|
| 137 |
+
# Therefore, we need to add batch_size here.
|
| 138 |
+
self.num_accepted_tokens += accepted_token_num.sum()
|
| 139 |
+
self.num_emitted_tokens += emitted_token_num.sum() + batch_size
|
| 140 |
+
self.num_draft_tokens += batch_size * k
|
| 141 |
+
else:
|
| 142 |
+
accepted, recovered_token_ids = (
|
| 143 |
+
self._batch_modified_rejection_sampling(
|
| 144 |
+
target_with_bonus_probs[:, :-1],
|
| 145 |
+
draft_probs,
|
| 146 |
+
draft_token_ids,
|
| 147 |
+
seeded_seqs,
|
| 148 |
+
))
|
| 149 |
+
|
| 150 |
+
output_token_ids = self._create_output(
|
| 151 |
+
accepted,
|
| 152 |
+
recovered_token_ids,
|
| 153 |
+
draft_token_ids,
|
| 154 |
+
bonus_token_ids,
|
| 155 |
+
)
|
| 156 |
+
|
| 157 |
+
return output_token_ids
|
| 158 |
+
|
| 159 |
+
def _batch_modified_rejection_sampling(
|
| 160 |
+
self,
|
| 161 |
+
target_probs: torch.Tensor, # [batch_size, k, vocab_size]
|
| 162 |
+
draft_probs: torch.Tensor, # [batch_size, k, vocab_size]
|
| 163 |
+
draft_token_ids: torch.Tensor, # [batch_size, k]
|
| 164 |
+
seeded_seqs: Optional[Dict[int, torch.Generator]],
|
| 165 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 166 |
+
"""Perform modified rejection sampling on each sequence.
|
| 167 |
+
|
| 168 |
+
Returns:
|
| 169 |
+
A tuple of two tensors:
|
| 170 |
+
0: A bool tensor of which tokens in each sequence is accepted.
|
| 171 |
+
shape = [batch_size, k]
|
| 172 |
+
1: Token ids sampled from a recovered distribution, to be used
|
| 173 |
+
when a token is rejected.
|
| 174 |
+
shape = [batch_size, k]
|
| 175 |
+
"""
|
| 176 |
+
|
| 177 |
+
batch_size, k, vocab_size = draft_probs.shape
|
| 178 |
+
|
| 179 |
+
# shape [batch_size, k]
|
| 180 |
+
accepted = self._get_accepted(target_probs, draft_probs,
|
| 181 |
+
draft_token_ids, seeded_seqs)
|
| 182 |
+
|
| 183 |
+
recovered_probs = self._get_recovered_probs(
|
| 184 |
+
target_probs, draft_probs).reshape(batch_size * k, vocab_size)
|
| 185 |
+
|
| 186 |
+
# NOTE: the recovered_probs are overwritten by this method.
|
| 187 |
+
recovered_token_ids = _multinomial(
|
| 188 |
+
recovered_probs,
|
| 189 |
+
num_samples=1,
|
| 190 |
+
k=k,
|
| 191 |
+
seeded_seqs=seeded_seqs or {},
|
| 192 |
+
).reshape(batch_size, k)
|
| 193 |
+
|
| 194 |
+
return accepted, recovered_token_ids
|
| 195 |
+
|
| 196 |
+
def _create_uniform_samples(self,
|
| 197 |
+
seeded_seqs: Optional[Dict[int,
|
| 198 |
+
torch.Generator]],
|
| 199 |
+
batch_size: int, k: int,
|
| 200 |
+
device: torch.device) -> torch.Tensor:
|
| 201 |
+
"""
|
| 202 |
+
Generates a batch of uniform random samples, with optional seeding
|
| 203 |
+
for specific sequences.
|
| 204 |
+
|
| 205 |
+
This method creates a tensor of shape `(batch_size, k + 1)` filled
|
| 206 |
+
with uniform random values in the range [0, 1). If `seeded_seqs`
|
| 207 |
+
is provided, the sequences corresponding to specific indices
|
| 208 |
+
will be generated using the provided `torch.Generator` for
|
| 209 |
+
reproducibility. The other sequences will be generated without
|
| 210 |
+
a seed.
|
| 211 |
+
|
| 212 |
+
Args:
|
| 213 |
+
seeded_seqs : Optional[Dict[int, torch.Generator]]
|
| 214 |
+
A dictionary mapping indices in the batch to
|
| 215 |
+
`torch.Generator` objects. If `None`, all samples are
|
| 216 |
+
generated without a seed.
|
| 217 |
+
batch_size : int
|
| 218 |
+
The number of sequences to generate.
|
| 219 |
+
k : int
|
| 220 |
+
The number of random samples per sequence.
|
| 221 |
+
device : torch.device
|
| 222 |
+
The device on which to allocate the tensor.
|
| 223 |
+
|
| 224 |
+
Returns:
|
| 225 |
+
uniform_rand : torch.Tensor
|
| 226 |
+
A tensor of shape `(batch_size, k + 1)` containing uniform
|
| 227 |
+
random values in the range [0, 1).
|
| 228 |
+
"""
|
| 229 |
+
if not seeded_seqs:
|
| 230 |
+
return torch.rand(batch_size, k + 1, device=device)
|
| 231 |
+
|
| 232 |
+
uniform_rand = torch.empty(batch_size, k + 1, device=device)
|
| 233 |
+
|
| 234 |
+
non_seeded_indices = []
|
| 235 |
+
for idx in range(batch_size):
|
| 236 |
+
generator = seeded_seqs.get(idx)
|
| 237 |
+
if generator is None:
|
| 238 |
+
non_seeded_indices.append(idx)
|
| 239 |
+
else:
|
| 240 |
+
uniform_rand[idx, :] = torch.rand(1,
|
| 241 |
+
k + 1,
|
| 242 |
+
dtype=self.probs_dtype,
|
| 243 |
+
device=device,
|
| 244 |
+
generator=generator)
|
| 245 |
+
if non_seeded_indices:
|
| 246 |
+
uniform_rand[non_seeded_indices, :] = torch.rand(
|
| 247 |
+
len(non_seeded_indices),
|
| 248 |
+
k + 1,
|
| 249 |
+
dtype=self.probs_dtype,
|
| 250 |
+
device=device)
|
| 251 |
+
return uniform_rand
|
| 252 |
+
|
| 253 |
+
def _get_accepted(
|
| 254 |
+
self,
|
| 255 |
+
target_probs: torch.Tensor, # [batch_size, k, vocab_size]
|
| 256 |
+
draft_probs: torch.Tensor, # [batch_size, k, vocab_size]
|
| 257 |
+
draft_token_ids: torch.Tensor, # [batch_size, k]
|
| 258 |
+
seeded_seqs: Optional[Dict[int, torch.Generator]],
|
| 259 |
+
) -> torch.Tensor:
|
| 260 |
+
r"""Create bool matrix over the proposed draft tokens. If
|
| 261 |
+
True, then a token can be accepted, else it should be
|
| 262 |
+
rejected.
|
| 263 |
+
|
| 264 |
+
Given :math:`q(\hat{x}_{n+1}|x_1, \dots, x_n)`, the probability of
|
| 265 |
+
:math:`\hat{x}_{n+1}` given context :math:`x_1, \dots, x_n` according
|
| 266 |
+
to the target model, and :math:`p(\hat{x}_{n+1}|x_1, \dots, x_n)`, the
|
| 267 |
+
same conditional probability according to the draft model, the token
|
| 268 |
+
is accepted with probability:
|
| 269 |
+
|
| 270 |
+
.. math::
|
| 271 |
+
\min\left(1, \frac{q(\hat{x}_{n+1}|x_1, \dots, x_n)}
|
| 272 |
+
{p(\hat{x}_{n+1}|x_1, \dots, x_n)}\right)
|
| 273 |
+
|
| 274 |
+
This implementation does not apply causality. When using the output,
|
| 275 |
+
if a token is rejected, subsequent tokens should not be used.
|
| 276 |
+
|
| 277 |
+
Returns a bool tensor of shape [batch_size, k] specifying which tokens
|
| 278 |
+
are accepted.
|
| 279 |
+
"""
|
| 280 |
+
batch_size, k, _ = draft_probs.shape
|
| 281 |
+
batch_indices = torch.arange(batch_size,
|
| 282 |
+
device=target_probs.device)[:, None]
|
| 283 |
+
probs_indicies = torch.arange(k, device=target_probs.device)
|
| 284 |
+
|
| 285 |
+
# shape [batch_size, k]
|
| 286 |
+
selected_draft_probs = draft_probs[batch_indices, probs_indicies,
|
| 287 |
+
draft_token_ids]
|
| 288 |
+
|
| 289 |
+
# shape [batch_size, k]
|
| 290 |
+
selected_target_probs = target_probs[batch_indices, probs_indicies,
|
| 291 |
+
draft_token_ids]
|
| 292 |
+
|
| 293 |
+
uniform_rand = self._create_uniform_samples(seeded_seqs, batch_size,
|
| 294 |
+
k - 1, target_probs.device)
|
| 295 |
+
|
| 296 |
+
capped_ratio = torch.minimum(
|
| 297 |
+
selected_target_probs / selected_draft_probs,
|
| 298 |
+
torch.full((1, ), 1, device=target_probs.device))
|
| 299 |
+
accepted = uniform_rand < capped_ratio
|
| 300 |
+
|
| 301 |
+
return accepted
|
| 302 |
+
|
| 303 |
+
def _get_recovered_probs(
|
| 304 |
+
self,
|
| 305 |
+
target_probs: torch.Tensor, # [k, vocab_size]
|
| 306 |
+
draft_probs: torch.Tensor, # [k, vocab_size]
|
| 307 |
+
) -> torch.Tensor:
|
| 308 |
+
r"""Create a probability distribution for each proposed token which can
|
| 309 |
+
be sampled if the proposed token is rejected.
|
| 310 |
+
|
| 311 |
+
When this routine is applied sequentially, the true distribution of the
|
| 312 |
+
target model is recovered (within hardware numerics).
|
| 313 |
+
|
| 314 |
+
The probability distribution used in this rejection case is constructed
|
| 315 |
+
as follows. Given :math:`q(x|x_1, \dots, x_n)`, the probability of
|
| 316 |
+
:math:`x` given context :math:`x_1, \dots, x_n` according to the target
|
| 317 |
+
model and :math:`p(x|x_1, \dots, x_n)`, the same conditional probability
|
| 318 |
+
according to the draft model:
|
| 319 |
+
|
| 320 |
+
.. math::
|
| 321 |
+
x_{n+1} \sim (q(x|x_1, \dots, x_n) - p(x|x_1, \dots, x_n))_+
|
| 322 |
+
|
| 323 |
+
where :math:`(f(x))_+` is defined as:
|
| 324 |
+
|
| 325 |
+
.. math::
|
| 326 |
+
(f(x))_+ = \frac{\max(0, f(x))}{\sum_x \max(0, f(x))}
|
| 327 |
+
|
| 328 |
+
See https://github.com/vllm-project/vllm/pull/2336 for a visualization
|
| 329 |
+
of the draft, target, and recovered probability distributions.
|
| 330 |
+
|
| 331 |
+
Returns a tensor of shape [batch_size, k, vocab_size].
|
| 332 |
+
|
| 333 |
+
Note: This batches operations on GPU and thus constructs the recovered
|
| 334 |
+
distribution for all tokens, even if they are accepted. This causes
|
| 335 |
+
division-by-zero errors, so we use self._smallest_positive_value to
|
| 336 |
+
avoid that. This introduces some drift to the distribution.
|
| 337 |
+
"""
|
| 338 |
+
_, k, _ = draft_probs.shape
|
| 339 |
+
|
| 340 |
+
# shape [batch_size, k, vocab_size]
|
| 341 |
+
difference = target_probs - draft_probs
|
| 342 |
+
|
| 343 |
+
# TODO(cade): Can we use logprobs instead of probs, and avoid the
|
| 344 |
+
# division-by-zero errors without introducing distribution drift?
|
| 345 |
+
|
| 346 |
+
# shape [batch_size, k, vocab_size]
|
| 347 |
+
f = torch.clamp(difference, min=self._smallest_positive_value)
|
| 348 |
+
|
| 349 |
+
# shape [batch_size, k, vocab_size]
|
| 350 |
+
recovered_probs = f / torch.sum(f, dim=-1).reshape(-1, k, 1)
|
| 351 |
+
|
| 352 |
+
return recovered_probs
|
| 353 |
+
|
| 354 |
+
@cached_property
|
| 355 |
+
def _smallest_positive_value(self) -> float:
|
| 356 |
+
"""Return the smallest positive value representable by the probs dtype.
|
| 357 |
+
This value is used when constructing a distribution from which to sample
|
| 358 |
+
recovered tokens in the first rejection case.
|
| 359 |
+
|
| 360 |
+
See _get_recovered_probs for more details
|
| 361 |
+
|
| 362 |
+
Note that this isn't actually the smallest positive value representable
|
| 363 |
+
by float32, but the smallest positive normal value.
|
| 364 |
+
See https://en.wikipedia.org/wiki/Subnormal_number for more information.
|
| 365 |
+
"""
|
| 366 |
+
return torch.finfo(self.probs_dtype).tiny
|
| 367 |
+
|
| 368 |
+
|
| 369 |
+
# torch.multinomial forces a GPU<->CPU sync.
|
| 370 |
+
# Therefore, we use an optimized implementation instead that skips the sync.
|
| 371 |
+
# Note that we always sample with replacement.
|
| 372 |
+
# probs will be modified in place, but this is fine, as we pass
|
| 373 |
+
# in a copy already.
|
| 374 |
+
@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
|
| 375 |
+
def _multinomial(
|
| 376 |
+
probs: torch.Tensor,
|
| 377 |
+
num_samples: int,
|
| 378 |
+
k: int,
|
| 379 |
+
seeded_seqs: Dict[int, torch.Generator],
|
| 380 |
+
) -> torch.Tensor:
|
| 381 |
+
|
| 382 |
+
if num_samples > 1:
|
| 383 |
+
# This is equivalent to torch.repeat_interleaved (which also
|
| 384 |
+
# forces a GPU<->CPU sync).
|
| 385 |
+
probs = probs[:, None, :].expand(probs.shape[0], num_samples,
|
| 386 |
+
probs.shape[1]).contiguous().view(
|
| 387 |
+
-1, probs.shape[1])
|
| 388 |
+
q = torch.empty_like(probs)
|
| 389 |
+
if not seeded_seqs:
|
| 390 |
+
q.exponential_(1.0)
|
| 391 |
+
else:
|
| 392 |
+
start = 0
|
| 393 |
+
for idx in range(len(q) // k):
|
| 394 |
+
end = start + k
|
| 395 |
+
generator = seeded_seqs.get(idx)
|
| 396 |
+
# Note: generator might be None for non seeded
|
| 397 |
+
q[start:end].exponential_(1.0, generator=generator)
|
| 398 |
+
start = end
|
| 399 |
+
|
| 400 |
+
return probs.div_(q).argmax(dim=1).view(-1, num_samples)
|
.venv/lib/python3.11/site-packages/vllm/model_executor/layers/resampler.py
ADDED
|
@@ -0,0 +1,269 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
# Adapted from
|
| 4 |
+
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
|
| 5 |
+
# https://huggingface.co/Qwen/Qwen-7B/blob/main/modeling_qwen.py
|
| 6 |
+
# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
|
| 7 |
+
#
|
| 8 |
+
# Copyright 2023 The Qwen team.
|
| 9 |
+
# Copyright 2023 The vLLM team.
|
| 10 |
+
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
| 11 |
+
#
|
| 12 |
+
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
| 13 |
+
# and OPT implementations in this library. It has been modified from its
|
| 14 |
+
# original forms to accommodate minor architectural differences compared
|
| 15 |
+
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
| 16 |
+
#
|
| 17 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 18 |
+
# you may not use this file except in compliance with the License.
|
| 19 |
+
# You may obtain a copy of the License at
|
| 20 |
+
#
|
| 21 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 22 |
+
#
|
| 23 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 24 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 25 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 26 |
+
# See the License for the specific language governing permissions and
|
| 27 |
+
# limitations under the License.
|
| 28 |
+
"""
|
| 29 |
+
Shared resampler perceiver network used in multimodal models and
|
| 30 |
+
related helpers for sincos positional embeddings.
|
| 31 |
+
|
| 32 |
+
Example models: Qwen (Qwen-VL), MiniCPM-V 2.0
|
| 33 |
+
"""
|
| 34 |
+
import math
|
| 35 |
+
from functools import partial
|
| 36 |
+
from typing import Callable, Optional, Tuple, Union
|
| 37 |
+
|
| 38 |
+
import numpy as np
|
| 39 |
+
import torch
|
| 40 |
+
import torch.nn.functional as F
|
| 41 |
+
from torch import nn
|
| 42 |
+
|
| 43 |
+
from vllm.model_executor.layers.linear import ReplicatedLinear
|
| 44 |
+
from vllm.model_executor.layers.quantization import QuantizationConfig
|
| 45 |
+
|
| 46 |
+
DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def get_abs_pos(abs_pos: torch.Tensor, tgt_size: Union[torch.Tensor,
|
| 50 |
+
int]) -> torch.Tensor:
|
| 51 |
+
# abs_pos: L, C
|
| 52 |
+
# tgt_size: (H, W)
|
| 53 |
+
# return: M, C
|
| 54 |
+
src_size = int(math.sqrt(abs_pos.size(0)))
|
| 55 |
+
dtype = abs_pos.dtype
|
| 56 |
+
if isinstance(tgt_size, int):
|
| 57 |
+
tgt_size = (tgt_size, tgt_size)
|
| 58 |
+
if (src_size == tgt_size[0] and src_size == tgt_size[1]):
|
| 59 |
+
return abs_pos
|
| 60 |
+
return (F.interpolate(
|
| 61 |
+
abs_pos.float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2),
|
| 62 |
+
size=(tgt_size[0], tgt_size[1]),
|
| 63 |
+
mode="bicubic",
|
| 64 |
+
align_corners=False,
|
| 65 |
+
).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype))
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
# sin/cos positional embedding helpers are adapted from:
|
| 69 |
+
# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
|
| 70 |
+
def get_1d_sincos_pos_embed_from_grid(
|
| 71 |
+
embed_dim: int, pos: np.ndarray,
|
| 72 |
+
version: Tuple[int, int] = (2, 0)) -> torch.Tensor:
|
| 73 |
+
"""
|
| 74 |
+
embed_dim: output dimension for each position
|
| 75 |
+
pos: a list of positions to be encoded: size (M,) / (H, W)
|
| 76 |
+
out: (M, D) / (H, W, D)
|
| 77 |
+
"""
|
| 78 |
+
assert embed_dim % 2 == 0
|
| 79 |
+
omega = np.arange(embed_dim // 2, dtype=np.float32)
|
| 80 |
+
omega /= embed_dim / 2.0
|
| 81 |
+
omega = 1.0 / 10000**omega # (D/2,)
|
| 82 |
+
|
| 83 |
+
if version == (2, 0):
|
| 84 |
+
pos = pos.reshape(-1) # (M,)
|
| 85 |
+
out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
|
| 86 |
+
emb_sin = np.sin(out) # (M, D/2)
|
| 87 |
+
emb_cos = np.cos(out) # (M, D/2)
|
| 88 |
+
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
|
| 89 |
+
else:
|
| 90 |
+
out = np.einsum("hw,d->hwd", pos, omega) # (H, W, D/2), outer product
|
| 91 |
+
emb_sin = np.sin(out) # (H, W, D/2)
|
| 92 |
+
emb_cos = np.cos(out) # (H, W, D/2)
|
| 93 |
+
emb = np.concatenate([emb_sin, emb_cos], axis=-1) # (H, W, D)
|
| 94 |
+
return emb
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def get_2d_sincos_pos_embed_from_grid(
|
| 98 |
+
embed_dim: int, grid: np.ndarray,
|
| 99 |
+
version: Tuple[int, int] = (2, 0)) -> torch.Tensor:
|
| 100 |
+
assert embed_dim % 2 == 0
|
| 101 |
+
|
| 102 |
+
# use half of dimensions to encode grid_h
|
| 103 |
+
emb_h = get_1d_sincos_pos_embed_from_grid(
|
| 104 |
+
embed_dim // 2, grid[0], version) # (H*W, D/2) or (H, W, D/2)
|
| 105 |
+
emb_w = get_1d_sincos_pos_embed_from_grid(
|
| 106 |
+
embed_dim // 2, grid[1], version) # (H*W, D/2) or (H, W, D/2)
|
| 107 |
+
|
| 108 |
+
if version == (2, 0):
|
| 109 |
+
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
|
| 110 |
+
else:
|
| 111 |
+
emb = np.concatenate([emb_h, emb_w], axis=-1) # (H, W, D)
|
| 112 |
+
return emb
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def get_2d_sincos_pos_embed(
|
| 116 |
+
embed_dim: int,
|
| 117 |
+
grid_size: Union[int, Tuple[int, int]],
|
| 118 |
+
cls_token: bool = False,
|
| 119 |
+
version: Tuple[int, int] = (2, 0),
|
| 120 |
+
) -> torch.Tensor:
|
| 121 |
+
"""
|
| 122 |
+
grid_size: int of the grid height and width
|
| 123 |
+
return:
|
| 124 |
+
pos_embed: [grid_size*grid_size, embed_dim] or
|
| 125 |
+
[1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
|
| 126 |
+
"""
|
| 127 |
+
if isinstance(grid_size, int):
|
| 128 |
+
grid_h_size, grid_w_size = grid_size, grid_size
|
| 129 |
+
else:
|
| 130 |
+
grid_h_size, grid_w_size = grid_size[0], grid_size[1]
|
| 131 |
+
|
| 132 |
+
grid_h = np.arange(grid_h_size, dtype=np.float32)
|
| 133 |
+
grid_w = np.arange(grid_w_size, dtype=np.float32)
|
| 134 |
+
grid = np.meshgrid(grid_w, grid_h) # here w goes first
|
| 135 |
+
grid = np.stack(grid, axis=0)
|
| 136 |
+
assert isinstance(grid, np.ndarray) and \
|
| 137 |
+
grid.shape == (2, grid_h_size, grid_w_size)
|
| 138 |
+
|
| 139 |
+
if version == (2, 0):
|
| 140 |
+
grid = grid.reshape([2, 1, grid_h_size, grid_w_size])
|
| 141 |
+
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version)
|
| 142 |
+
if cls_token:
|
| 143 |
+
pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed],
|
| 144 |
+
axis=0)
|
| 145 |
+
else:
|
| 146 |
+
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version)
|
| 147 |
+
return pos_embed

class BaseResampler(nn.Module):
    """
    A 2D perceiver-resampler network with one cross-attention layer, using
        (grid_size**2) learnable queries and 2D sincos pos_emb.
    Outputs:
        A tensor with the shape of (grid_size**2, embed_dim)
    """

    def __init__(self,
                 num_queries: int,
                 embed_dim: int,
                 num_heads: int,
                 kv_dim: Optional[int] = None,
                 norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
                 do_post_projection: bool = True,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = "") -> None:
        super().__init__()

        self.num_queries = num_queries
        self.embed_dim = embed_dim
        self.num_heads = num_heads

        self.query = nn.Parameter(torch.empty(self.num_queries, embed_dim))

        if kv_dim is not None and kv_dim != embed_dim:
            self.kv_proj = ReplicatedLinear(kv_dim,
                                            embed_dim,
                                            bias=False,
                                            quant_config=quant_config,
                                            prefix=f"{prefix}.kv_proj")
        else:
            # Maintain the same return value with ReplicatedLinear.forward
            self.kv_proj = lambda *args, **kwargs: (  # type: ignore # noqa
                nn.Identity()(*args, **kwargs),
                None,
            )
        self.attn = nn.MultiheadAttention(embed_dim, num_heads)
        self.ln_q = norm_layer(embed_dim)
        self.ln_kv = norm_layer(embed_dim)
        self.do_post_projection = do_post_projection
        self.ln_post = norm_layer(embed_dim) if do_post_projection else None
        self.proj = nn.Parameter(
            (embed_dim**-0.5) *
            torch.empty(embed_dim, embed_dim)) if do_post_projection else None

    def _repeat(self, query, N: int):
        return query.unsqueeze(1).repeat(1, N, 1)


class Resampler2(BaseResampler):
    """Resampler-perceiver network to be used for a variety of model types,
    e.g., Qwen-vl / Minicpmv 2.0. The main difference is the addition of the
    do_post_projection arg, which indicates whether or not there should be
    a post layer normalization and projector after the attention. This is
    present in minicpmv2.0, but not qwen-vl.
    """

    def __init__(self,
                 grid_size: int,
                 embed_dim: int,
                 num_heads: int,
                 kv_dim: Optional[int] = None,
                 norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
                 adaptive: bool = False,
                 do_post_projection: bool = True,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = "") -> None:
        super().__init__(grid_size**2,
                         embed_dim,
                         num_heads,
                         kv_dim,
                         norm_layer,
                         do_post_projection=do_post_projection,
                         quant_config=quant_config,
                         prefix=prefix)

        self.adaptive = adaptive
        pos_embed_arr = get_2d_sincos_pos_embed(embed_dim,
                                                grid_size,
                                                version=(2, 0))

        self.pos_embed = nn.Parameter(
            torch.from_numpy(pos_embed_arr).requires_grad_(False))

    def forward(
        self,
        x: torch.Tensor,
        tgt_sizes: Optional[torch.Tensor] = None,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if tgt_sizes is None:
            tgt_sizes = int(math.sqrt(x.size(1)))
        if self.adaptive:
            pos_embed_arr = get_2d_sincos_pos_embed(self.embed_dim,
                                                    tgt_sizes,
                                                    version=(2, 0))
            pos_embed = torch.from_numpy(pos_embed_arr).to(device=x.device,
                                                           dtype=x.dtype)
        else:
            pos_embed = get_abs_pos(self.pos_embed,
                                    tgt_sizes).to(device=x.device,
                                                  dtype=x.dtype)

        x, _ = self.kv_proj(x)
        x = self.ln_kv(x).permute(1, 0, 2)

        N = x.shape[1]
        q = self.ln_q(self.query)
        out = self.attn(
            self._repeat(q, N) + self.pos_embed.unsqueeze(1),
            x + pos_embed.unsqueeze(1),
            x,
            attn_mask=attn_mask,
        )[0]
        x = out.permute(1, 0, 2)
        if self.do_post_projection:
            x = self.ln_post(x)
            x = x @ self.proj
        return x
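For orientation, here is a minimal sketch (not part of the diff) of how Resampler2 is driven. kv_dim is left as None so the identity kv_proj path is used instead of ReplicatedLinear, and the parameters are freshly allocated rather than loaded from a checkpoint, so only the shapes are meaningful.

```python
# Minimal sketch under the assumptions stated above.
import torch

resampler = Resampler2(grid_size=12,
                       embed_dim=768,
                       num_heads=12,
                       kv_dim=None,            # identity kv_proj, no ReplicatedLinear
                       adaptive=False,
                       do_post_projection=True)
vision_feats = torch.randn(2, 12 * 12, 768)    # (batch, grid_size**2, embed_dim)
out = resampler(vision_feats)
print(out.shape)                               # torch.Size([2, 144, 768])
```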
.venv/lib/python3.11/site-packages/vllm/model_executor/layers/rotary_embedding.py
ADDED
|
@@ -0,0 +1,1114 @@
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
# Adapted from
|
| 4 |
+
# https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/llama/modeling_llama.py
|
| 5 |
+
# Copyright 2023 The vLLM team.
|
| 6 |
+
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
| 7 |
+
#
|
| 8 |
+
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
| 9 |
+
# and OPT implementations in this library. It has been modified from its
|
| 10 |
+
# original forms to accommodate minor architectural differences compared
|
| 11 |
+
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
| 12 |
+
#
|
| 13 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 14 |
+
# you may not use this file except in compliance with the License.
|
| 15 |
+
# You may obtain a copy of the License at
|
| 16 |
+
#
|
| 17 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 18 |
+
#
|
| 19 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 20 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 21 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 22 |
+
# See the License for the specific language governing permissions and
|
| 23 |
+
# limitations under the License.
|
| 24 |
+
"""Rotary Positional Embeddings."""
|
| 25 |
+
import math
|
| 26 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 27 |
+
|
| 28 |
+
import torch
|
| 29 |
+
import torch.nn as nn
|
| 30 |
+
from transformers import PretrainedConfig
|
| 31 |
+
|
| 32 |
+
from vllm.model_executor.custom_op import CustomOp
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
|
| 36 |
+
x1 = x[..., :x.shape[-1] // 2]
|
| 37 |
+
x2 = x[..., x.shape[-1] // 2:]
|
| 38 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def _rotate_gptj(x: torch.Tensor) -> torch.Tensor:
|
| 42 |
+
x1 = x[..., ::2]
|
| 43 |
+
x2 = x[..., 1::2]
|
| 44 |
+
x = torch.stack((-x2, x1), dim=-1)
|
| 45 |
+
return x.flatten(-2)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def _apply_rotary_emb(
|
| 49 |
+
x: torch.Tensor,
|
| 50 |
+
cos: torch.Tensor,
|
| 51 |
+
sin: torch.Tensor,
|
| 52 |
+
is_neox_style: bool,
|
| 53 |
+
) -> torch.Tensor:
|
| 54 |
+
"""
|
| 55 |
+
Args:
|
| 56 |
+
x: [num_tokens, num_heads, head_size]
|
| 57 |
+
cos: [num_tokens, head_size // 2]
|
| 58 |
+
sin: [num_tokens, head_size // 2]
|
| 59 |
+
is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
|
| 60 |
+
positional embeddings.
|
| 61 |
+
"""
|
| 62 |
+
cos = cos.unsqueeze(-2).to(x.dtype)
|
| 63 |
+
sin = sin.unsqueeze(-2).to(x.dtype)
|
| 64 |
+
if is_neox_style:
|
| 65 |
+
x1, x2 = torch.chunk(x, 2, dim=-1)
|
| 66 |
+
else:
|
| 67 |
+
x1 = x[..., ::2]
|
| 68 |
+
x2 = x[..., 1::2]
|
| 69 |
+
o1 = x1 * cos - x2 * sin
|
| 70 |
+
o2 = x2 * cos + x1 * sin
|
| 71 |
+
if is_neox_style:
|
| 72 |
+
return torch.cat((o1, o2), dim=-1)
|
| 73 |
+
else:
|
| 74 |
+
return torch.stack((o1, o2), dim=-1).flatten(-2)
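As a shape-level illustration (not part of the upstream file), the helper above can be exercised directly. NeoX style pairs channel i with channel i + head_size // 2, while GPT-J style pairs channel 2i with 2i + 1, and both layouts return a tensor of the input's shape.

```python
# Hedged sketch: arbitrary sizes, random angles.
import torch

num_tokens, num_heads, head_size = 4, 2, 8
x = torch.randn(num_tokens, num_heads, head_size)
angles = torch.rand(num_tokens, head_size // 2)
cos, sin = angles.cos(), angles.sin()

out_neox = _apply_rotary_emb(x, cos, sin, is_neox_style=True)
out_gptj = _apply_rotary_emb(x, cos, sin, is_neox_style=False)
assert out_neox.shape == out_gptj.shape == x.shape
```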
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
@CustomOp.register("rotary_embedding")
|
| 78 |
+
class RotaryEmbedding(CustomOp):
|
| 79 |
+
"""Original rotary positional embedding."""
|
| 80 |
+
|
| 81 |
+
def __init__(
|
| 82 |
+
self,
|
| 83 |
+
head_size: int,
|
| 84 |
+
rotary_dim: int,
|
| 85 |
+
max_position_embeddings: int,
|
| 86 |
+
base: int,
|
| 87 |
+
is_neox_style: bool,
|
| 88 |
+
dtype: torch.dtype,
|
| 89 |
+
) -> None:
|
| 90 |
+
super().__init__()
|
| 91 |
+
self.head_size = head_size
|
| 92 |
+
self.rotary_dim = rotary_dim
|
| 93 |
+
self.max_position_embeddings = max_position_embeddings
|
| 94 |
+
self.base = base
|
| 95 |
+
self.is_neox_style = is_neox_style
|
| 96 |
+
self.dtype = dtype
|
| 97 |
+
|
| 98 |
+
cache = self._compute_cos_sin_cache()
|
| 99 |
+
cache = cache.to(dtype)
|
| 100 |
+
self.cos_sin_cache: torch.Tensor
|
| 101 |
+
self.register_buffer("cos_sin_cache", cache, persistent=False)
|
| 102 |
+
|
| 103 |
+
def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
|
| 104 |
+
"""Compute the inverse frequency."""
|
| 105 |
+
# NOTE(woosuk): To exactly match the HF implementation, we need to
|
| 106 |
+
# use CPU to compute the cache and then move it to GPU. However, we
|
| 107 |
+
# create the cache on GPU for faster initialization. This may cause
|
| 108 |
+
# a slight numerical difference between the HF implementation and ours.
|
| 109 |
+
inv_freq = 1.0 / (base**(torch.arange(
|
| 110 |
+
0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim))
|
| 111 |
+
return inv_freq
|
| 112 |
+
|
| 113 |
+
def _compute_cos_sin_cache(self) -> torch.Tensor:
|
| 114 |
+
"""Compute the cos and sin cache."""
|
| 115 |
+
inv_freq = self._compute_inv_freq(self.base)
|
| 116 |
+
t = torch.arange(self.max_position_embeddings, dtype=torch.float)
|
| 117 |
+
|
| 118 |
+
freqs = torch.einsum("i,j -> ij", t, inv_freq)
|
| 119 |
+
cos = freqs.cos()
|
| 120 |
+
sin = freqs.sin()
|
| 121 |
+
cache = torch.cat((cos, sin), dim=-1)
|
| 122 |
+
return cache
|
| 123 |
+
|
| 124 |
+
def forward_native(
|
| 125 |
+
self,
|
| 126 |
+
positions: torch.Tensor,
|
| 127 |
+
query: torch.Tensor,
|
| 128 |
+
key: torch.Tensor,
|
| 129 |
+
offsets: Optional[torch.Tensor] = None,
|
| 130 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 131 |
+
"""A PyTorch-native implementation of forward()."""
|
| 132 |
+
if offsets is not None:
|
| 133 |
+
positions = positions + offsets
|
| 134 |
+
positions = positions.flatten()
|
| 135 |
+
num_tokens = positions.shape[0]
|
| 136 |
+
cos_sin = self.cos_sin_cache.index_select(0, positions)
|
| 137 |
+
cos, sin = cos_sin.chunk(2, dim=-1)
|
| 138 |
+
|
| 139 |
+
query_shape = query.shape
|
| 140 |
+
query = query.view(num_tokens, -1, self.head_size)
|
| 141 |
+
query_rot = query[..., :self.rotary_dim]
|
| 142 |
+
query_pass = query[..., self.rotary_dim:]
|
| 143 |
+
query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style)
|
| 144 |
+
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
|
| 145 |
+
|
| 146 |
+
key_shape = key.shape
|
| 147 |
+
key = key.view(num_tokens, -1, self.head_size)
|
| 148 |
+
key_rot = key[..., :self.rotary_dim]
|
| 149 |
+
key_pass = key[..., self.rotary_dim:]
|
| 150 |
+
key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style)
|
| 151 |
+
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
|
| 152 |
+
return query, key
|
| 153 |
+
|
| 154 |
+
def forward_cuda(
|
| 155 |
+
self,
|
| 156 |
+
positions: torch.Tensor,
|
| 157 |
+
query: torch.Tensor,
|
| 158 |
+
key: torch.Tensor,
|
| 159 |
+
offsets: Optional[torch.Tensor] = None,
|
| 160 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 161 |
+
from vllm import _custom_ops as ops
|
| 162 |
+
|
| 163 |
+
self.cos_sin_cache = self.cos_sin_cache.to(query.device,
|
| 164 |
+
dtype=query.dtype)
|
| 165 |
+
# ops.rotary_embedding()/batched_rotary_embedding()
|
| 166 |
+
# are in-place operations that update the query and key tensors.
|
| 167 |
+
if offsets is not None:
|
| 168 |
+
ops.batched_rotary_embedding(positions, query, key, self.head_size,
|
| 169 |
+
self.cos_sin_cache,
|
| 170 |
+
self.is_neox_style, self.rotary_dim,
|
| 171 |
+
offsets)
|
| 172 |
+
else:
|
| 173 |
+
ops.rotary_embedding(positions, query, key, self.head_size,
|
| 174 |
+
self.cos_sin_cache, self.is_neox_style)
|
| 175 |
+
return query, key
|
| 176 |
+
|
| 177 |
+
def forward_xpu(
|
| 178 |
+
self,
|
| 179 |
+
positions: torch.Tensor,
|
| 180 |
+
query: torch.Tensor,
|
| 181 |
+
key: torch.Tensor,
|
| 182 |
+
offsets: Optional[torch.Tensor] = None,
|
| 183 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 184 |
+
from vllm._ipex_ops import ipex_ops as ops
|
| 185 |
+
|
| 186 |
+
self.cos_sin_cache = self.cos_sin_cache.to(positions.device,
|
| 187 |
+
dtype=query.dtype)
|
| 188 |
+
# ops.rotary_embedding()/batched_rotary_embedding()
|
| 189 |
+
# are in-place operations that update the query and key tensors.
|
| 190 |
+
if offsets is not None:
|
| 191 |
+
ops.batched_rotary_embedding(positions, query, key, self.head_size,
|
| 192 |
+
self.cos_sin_cache,
|
| 193 |
+
self.is_neox_style, self.rotary_dim,
|
| 194 |
+
offsets)
|
| 195 |
+
else:
|
| 196 |
+
ops.rotary_embedding(positions, query, key, self.head_size,
|
| 197 |
+
self.cos_sin_cache, self.is_neox_style)
|
| 198 |
+
return query, key
|
| 199 |
+
|
| 200 |
+
def forward_hpu(
|
| 201 |
+
self,
|
| 202 |
+
positions: torch.Tensor,
|
| 203 |
+
query: torch.Tensor,
|
| 204 |
+
key: torch.Tensor,
|
| 205 |
+
offsets: Optional[torch.Tensor] = None,
|
| 206 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 207 |
+
from habana_frameworks.torch.hpex.kernels import (
|
| 208 |
+
RotaryPosEmbeddingMode, apply_rotary_pos_emb)
|
| 209 |
+
positions = positions.flatten()
|
| 210 |
+
if offsets is not None:
|
| 211 |
+
positions = positions + offsets
|
| 212 |
+
num_tokens = positions.shape[0]
|
| 213 |
+
cos_sin = self.cos_sin_cache.index_select(0, positions).view(
|
| 214 |
+
num_tokens, 1, -1)
|
| 215 |
+
cos, sin = cos_sin.chunk(2, dim=-1)
|
| 216 |
+
# HPU RoPE kernel requires hidden dimension for cos and sin to be equal
|
| 217 |
+
# to query hidden dimension, so the original tensors need to be
|
| 218 |
+
# expanded
|
| 219 |
+
# GPT-NeoX kernel requires position_ids = None, offset, mode = BLOCKWISE
|
| 220 |
+
# and expansion of cos/sin tensors via concatenation
|
| 221 |
+
# GPT-J kernel requires position_ids = None, offset = 0, mode = PAIRWISE
|
| 222 |
+
# and expansion of cos/sin tensors via repeat_interleave
|
| 223 |
+
rope_mode: RotaryPosEmbeddingMode
|
| 224 |
+
if self.is_neox_style:
|
| 225 |
+
rope_mode = RotaryPosEmbeddingMode.BLOCKWISE
|
| 226 |
+
cos = torch.cat((cos, cos), dim=-1)
|
| 227 |
+
sin = torch.cat((sin, sin), dim=-1)
|
| 228 |
+
else:
|
| 229 |
+
rope_mode = RotaryPosEmbeddingMode.PAIRWISE
|
| 230 |
+
sin = torch.repeat_interleave(sin,
|
| 231 |
+
2,
|
| 232 |
+
dim=-1,
|
| 233 |
+
output_size=cos_sin.shape[-1])
|
| 234 |
+
cos = torch.repeat_interleave(cos,
|
| 235 |
+
2,
|
| 236 |
+
dim=-1,
|
| 237 |
+
output_size=cos_sin.shape[-1])
|
| 238 |
+
|
| 239 |
+
query_shape = query.shape
|
| 240 |
+
query = query.view(num_tokens, -1, self.head_size)
|
| 241 |
+
query_rot = query[..., :self.rotary_dim]
|
| 242 |
+
query_pass = query[..., self.rotary_dim:]
|
| 243 |
+
query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0,
|
| 244 |
+
rope_mode)
|
| 245 |
+
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
|
| 246 |
+
|
| 247 |
+
key_shape = key.shape
|
| 248 |
+
key = key.view(num_tokens, -1, self.head_size)
|
| 249 |
+
key_rot = key[..., :self.rotary_dim]
|
| 250 |
+
key_pass = key[..., self.rotary_dim:]
|
| 251 |
+
key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode)
|
| 252 |
+
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
|
| 253 |
+
return query, key
|
| 254 |
+
|
| 255 |
+
def extra_repr(self) -> str:
|
| 256 |
+
s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}"
|
| 257 |
+
s += f", max_position_embeddings={self.max_position_embeddings}"
|
| 258 |
+
s += f", base={self.base}, is_neox_style={self.is_neox_style}"
|
| 259 |
+
return s
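A small sketch of the native path (not part of the upstream file; it assumes constructing the op outside a full vLLM runtime works, which may depend on CustomOp's platform dispatch). forward_native expects flattened query/key of shape [num_tokens, num_heads * head_size].

```python
# Hedged sketch under the assumptions stated above.
import torch

rope = RotaryEmbedding(head_size=64,
                       rotary_dim=64,
                       max_position_embeddings=2048,
                       base=10000,
                       is_neox_style=True,
                       dtype=torch.float32)
positions = torch.arange(16)
query = torch.randn(16, 8 * 64)   # 8 query heads
key = torch.randn(16, 2 * 64)     # 2 KV heads
query, key = rope.forward_native(positions, query, key)
print(query.shape, key.shape)     # torch.Size([16, 512]) torch.Size([16, 128])
```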
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
class LinearScalingRotaryEmbedding(RotaryEmbedding):
|
| 263 |
+
"""RotaryEmbedding extended with linear scaling.
|
| 264 |
+
|
| 265 |
+
It supports multiple scaling factors. Since multiple LoRA adapters may have
|
| 266 |
+
different scaling factors, we need multiple cos/sin caches. In this way,
|
| 267 |
+
instead of running rotary embedding kernel per lora, we can run multiple
|
| 268 |
+
lora in a batched way.
|
| 269 |
+
|
| 270 |
+
In addition to that, we also keep the cos/sin cache for the scaling factor
|
| 271 |
+
of 1 (default) at all times.
|
| 272 |
+
|
| 273 |
+
For example, for three scaling factors x=1, y, and z with embeddings
|
| 274 |
+
[[x11, x12, ... x1m], ..., [xn1, xn2, ..., xnm]] and
|
| 275 |
+
[[y11, y12, ... y1o], ..., [yn1, yn2, ..., yno]], and
|
| 276 |
+
[[z11, z12, ... z1p], ..., [zn1, zn2, ..., znp]],
|
| 277 |
+
|
| 278 |
+
we construct the cos/sin cache as follows:
|
| 279 |
+
[[x11, x12, ... x1m, y11, y12, ... y1o, z11, z12, ... z1p],
|
| 280 |
+
...
|
| 281 |
+
[xn1, xn2, ... xnm, yn1, yn2, ... yno, zn1, zn2, ... znp]]
|
| 282 |
+
|
| 283 |
+
We then use offsets to index into the cos/sin cache for
|
| 284 |
+
the respective scaling factors.
|
| 285 |
+
|
| 286 |
+
The offset to cache can be accessed via `scaling_factor_to_offset` API.
|
| 287 |
+
|
| 288 |
+
Credits to the Reddit user /u/kaiokendev
|
| 289 |
+
"""
|
| 290 |
+
|
| 291 |
+
def __init__(
|
| 292 |
+
self,
|
| 293 |
+
head_size: int,
|
| 294 |
+
rotary_dim: int,
|
| 295 |
+
max_position_embeddings: int,
|
| 296 |
+
base: int,
|
| 297 |
+
is_neox_style: bool,
|
| 298 |
+
scaling_factors: Union[List[float], float],
|
| 299 |
+
dtype: torch.dtype,
|
| 300 |
+
) -> None:
|
| 301 |
+
if isinstance(scaling_factors, float):
|
| 302 |
+
scaling_factors = [scaling_factors]
|
| 303 |
+
self.scaling_factors: List[float] = scaling_factors # noqa
|
| 304 |
+
super().__init__(head_size, rotary_dim, max_position_embeddings, base,
|
| 305 |
+
is_neox_style, dtype)
|
| 306 |
+
# Lazy initialized.
|
| 307 |
+
self._scaling_factor_to_offset: Dict[float, int]
|
| 308 |
+
|
| 309 |
+
def _compute_cos_sin_cache(self) -> torch.Tensor:
|
| 310 |
+
inv_freq = self._compute_inv_freq(self.base)
|
| 311 |
+
cache_list: List[torch.Tensor] = []
|
| 312 |
+
# offsets to the next cache in a tensor.
|
| 313 |
+
# Each offset corresponds to the same index in scaling_factors.
|
| 314 |
+
offsets: List[int] = []
|
| 315 |
+
for scaling_factor in self.scaling_factors:
|
| 316 |
+
# NOTE(woosuk): self.max_position_embeddings is the original
|
| 317 |
+
# maximum length before applying the rope scaling.
|
| 318 |
+
# Thus, the maximum length after applying the rope scaling is
|
| 319 |
+
# self.max_position_embeddings * self.scaling_factor.
|
| 320 |
+
max_len = self.max_position_embeddings * scaling_factor
|
| 321 |
+
t = torch.arange(max_len, dtype=torch.float)
|
| 322 |
+
t = t / scaling_factor
|
| 323 |
+
|
| 324 |
+
freqs = torch.einsum("i,j -> ij", t, inv_freq)
|
| 325 |
+
cos = freqs.cos()
|
| 326 |
+
sin = freqs.sin()
|
| 327 |
+
cache = torch.cat((cos, sin), dim=-1)
|
| 328 |
+
if not cache_list:
|
| 329 |
+
offset = 0
|
| 330 |
+
else:
|
| 331 |
+
last_offset = offsets[-1]
|
| 332 |
+
next_max_len = cache_list[-1].shape[0]
|
| 333 |
+
offset = last_offset + next_max_len
|
| 334 |
+
offsets.append(offset)
|
| 335 |
+
cache_list.append(cache)
|
| 336 |
+
self._scaling_factor_to_offset = {
|
| 337 |
+
float(scaling_factor): offsets[i]
|
| 338 |
+
for i, scaling_factor in enumerate(self.scaling_factors)
|
| 339 |
+
}
|
| 340 |
+
assert len(self.scaling_factors) == len(offsets)
|
| 341 |
+
return torch.cat(cache_list, dim=0)
|
| 342 |
+
|
| 343 |
+
@property
|
| 344 |
+
def scaling_factor_to_offset(self) -> Dict[float, int]:
|
| 345 |
+
return self._scaling_factor_to_offset
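As a concrete illustration of the docstring above (not part of the upstream file; sizes are arbitrary, and it again assumes the op can be built outside a full vLLM runtime), two scaling factors produce two concatenated caches, and scaling_factor_to_offset records where each one starts.

```python
# Hedged sketch: head_size == rotary_dim == 64, so each cache row is 64 wide.
import torch

rope = LinearScalingRotaryEmbedding(head_size=64,
                                    rotary_dim=64,
                                    max_position_embeddings=2048,
                                    base=10000,
                                    is_neox_style=True,
                                    scaling_factors=[1.0, 4.0],
                                    dtype=torch.float32)
print(rope.scaling_factor_to_offset)   # {1.0: 0, 4.0: 2048}
print(rope.cos_sin_cache.shape)        # torch.Size([10240, 64]) == (2048 + 4 * 2048) rows
```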
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
|
| 349 |
+
"""RotaryEmbedding extended with Dynamic NTK scaling.
|
| 350 |
+
|
| 351 |
+
Credits to the Reddit users /u/bloc97 and /u/emozilla
|
| 352 |
+
"""
|
| 353 |
+
|
| 354 |
+
def __init__(
|
| 355 |
+
self,
|
| 356 |
+
head_size: int,
|
| 357 |
+
rotary_dim: int,
|
| 358 |
+
max_position_embeddings: int,
|
| 359 |
+
base: int,
|
| 360 |
+
is_neox_style: bool,
|
| 361 |
+
scaling_factor: float,
|
| 362 |
+
dtype: torch.dtype,
|
| 363 |
+
) -> None:
|
| 364 |
+
self.scaling_factor = scaling_factor
|
| 365 |
+
super().__init__(head_size, rotary_dim, max_position_embeddings, base,
|
| 366 |
+
is_neox_style, dtype)
|
| 367 |
+
|
| 368 |
+
def _compute_cos_sin_cache(self) -> torch.Tensor:
|
| 369 |
+
# NOTE(woosuk): self.max_position_embeddings is the original
|
| 370 |
+
# maximum length before applying the rope scaling.
|
| 371 |
+
# Thus, the maximum length after applying the rope scaling is
|
| 372 |
+
# self.max_position_embeddings * self.scaling_factor.
|
| 373 |
+
max_len = self.max_position_embeddings * self.scaling_factor
|
| 374 |
+
base = self.base * (
|
| 375 |
+
(self.scaling_factor * max_len / self.max_position_embeddings) -
|
| 376 |
+
(self.scaling_factor - 1))**(self.rotary_dim /
|
| 377 |
+
(self.rotary_dim - 2))
|
| 378 |
+
inv_freq = self._compute_inv_freq(base)
|
| 379 |
+
t = torch.arange(max_len, dtype=torch.float)
|
| 380 |
+
|
| 381 |
+
freqs = torch.einsum("i,j -> ij", t, inv_freq)
|
| 382 |
+
cos = freqs.cos()
|
| 383 |
+
sin = freqs.sin()
|
| 384 |
+
cache = torch.cat((cos, sin), dim=-1)
|
| 385 |
+
return cache
|
| 386 |
+
|
| 387 |
+
|
| 388 |
+
# Inverse dim formula to find dim based on number of rotations
|
| 389 |
+
def _yarn_find_correction_dim(num_rotations: int,
|
| 390 |
+
dim: int,
|
| 391 |
+
base: float = 10000,
|
| 392 |
+
max_position_embeddings: int = 2048) -> float:
|
| 393 |
+
return (dim * math.log(max_position_embeddings /
|
| 394 |
+
(num_rotations * 2 * math.pi))) / (2 *
|
| 395 |
+
math.log(base))
|
| 396 |
+
|
| 397 |
+
|
| 398 |
+
# Find dim range bounds based on rotations
|
| 399 |
+
def _yarn_find_correction_range(
|
| 400 |
+
low_rot: int,
|
| 401 |
+
high_rot: int,
|
| 402 |
+
dim: int,
|
| 403 |
+
base: float = 10000,
|
| 404 |
+
max_position_embeddings: int = 2048) -> Tuple[int, int]:
|
| 405 |
+
low = math.floor(
|
| 406 |
+
_yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings))
|
| 407 |
+
high = math.ceil(
|
| 408 |
+
_yarn_find_correction_dim(high_rot, dim, base,
|
| 409 |
+
max_position_embeddings))
|
| 410 |
+
return max(low, 0), min(high, dim - 1) # Clamp values just in case
|
| 411 |
+
|
| 412 |
+
|
| 413 |
+
def _yarn_linear_ramp_mask(low: float, high: float, dim: int,
|
| 414 |
+
dtype: torch.dtype) -> torch.Tensor:
|
| 415 |
+
if low == high:
|
| 416 |
+
high += 0.001 # Prevent singularity
|
| 417 |
+
|
| 418 |
+
linear_func = (torch.arange(dim, dtype=dtype) - low) / (high - low)
|
| 419 |
+
ramp_func = torch.clamp(linear_func, 0, 1)
|
| 420 |
+
return ramp_func
|
| 421 |
+
|
| 422 |
+
|
| 423 |
+
def _yarn_get_mscale(scale: float = 1) -> float:
|
| 424 |
+
if scale <= 1:
|
| 425 |
+
return 1.0
|
| 426 |
+
return 0.1 * math.log(scale) + 1.0
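For intuition (illustration only, not part of the upstream file), the resulting attention-magnitude correction grows only logarithmically with the context-extension factor:

```python
# 0.1 * ln(scale) + 1.0 for a few extension factors.
for scale in (1.0, 2.0, 8.0, 16.0):
    print(scale, round(_yarn_get_mscale(scale), 4))
# 1.0 1.0
# 2.0 1.0693
# 8.0 1.2079
# 16.0 1.2773
```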
|
| 427 |
+
|
| 428 |
+
|
| 429 |
+
class YaRNScalingRotaryEmbedding(RotaryEmbedding):
|
| 430 |
+
"""RotaryEmbedding extended with YaRN method.
|
| 431 |
+
|
| 432 |
+
Credits to Peng et al. github.com/jquesnelle/yarn
|
| 433 |
+
"""
|
| 434 |
+
|
| 435 |
+
def __init__(
|
| 436 |
+
self,
|
| 437 |
+
head_size: int,
|
| 438 |
+
rotary_dim: int,
|
| 439 |
+
max_position_embeddings: int,
|
| 440 |
+
base: int,
|
| 441 |
+
is_neox_style: bool,
|
| 442 |
+
scaling_factor: float,
|
| 443 |
+
dtype: torch.dtype,
|
| 444 |
+
*,
|
| 445 |
+
extrapolation_factor: float = 1,
|
| 446 |
+
attn_factor: float = 1,
|
| 447 |
+
beta_fast: int = 32,
|
| 448 |
+
beta_slow: int = 1,
|
| 449 |
+
) -> None:
|
| 450 |
+
self.scaling_factor = scaling_factor
|
| 451 |
+
self.extrapolation_factor = extrapolation_factor
|
| 452 |
+
self.attn_factor = attn_factor
|
| 453 |
+
self.beta_fast = beta_fast
|
| 454 |
+
self.beta_slow = beta_slow
|
| 455 |
+
# Get n-d magnitude scaling corrected for interpolation
|
| 456 |
+
self.mscale = float(
|
| 457 |
+
_yarn_get_mscale(self.scaling_factor) * attn_factor)
|
| 458 |
+
super().__init__(head_size, rotary_dim, max_position_embeddings, base,
|
| 459 |
+
is_neox_style, dtype)
|
| 460 |
+
|
| 461 |
+
def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
|
| 462 |
+
pos_freqs = self.base**(
|
| 463 |
+
torch.arange(0, self.rotary_dim, 2, dtype=torch.float) /
|
| 464 |
+
self.rotary_dim)
|
| 465 |
+
inv_freq_extrapolation = 1.0 / pos_freqs
|
| 466 |
+
inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)
|
| 467 |
+
|
| 468 |
+
low, high = _yarn_find_correction_range(self.beta_fast, self.beta_slow,
|
| 469 |
+
self.rotary_dim, self.base,
|
| 470 |
+
self.max_position_embeddings)
|
| 471 |
+
# Get n-d rotational scaling corrected for extrapolation
|
| 472 |
+
inv_freq_mask = (1 - _yarn_linear_ramp_mask(
|
| 473 |
+
low, high, self.rotary_dim // 2,
|
| 474 |
+
dtype=torch.float)) * self.extrapolation_factor
|
| 475 |
+
inv_freq = inv_freq_interpolation * (
|
| 476 |
+
1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
|
| 477 |
+
return inv_freq
|
| 478 |
+
|
| 479 |
+
def _compute_cos_sin_cache(self) -> torch.Tensor:
|
| 480 |
+
inv_freq = self._compute_inv_freq(self.scaling_factor)
|
| 481 |
+
t = torch.arange(self.max_position_embeddings * self.scaling_factor,
|
| 482 |
+
dtype=torch.float32)
|
| 483 |
+
freqs = torch.einsum("i,j -> ij", t, inv_freq)
|
| 484 |
+
cos = (freqs.cos() * self.mscale)
|
| 485 |
+
sin = (freqs.sin() * self.mscale)
|
| 486 |
+
cache = torch.cat((cos, sin), dim=-1)
|
| 487 |
+
return cache
|
| 488 |
+
|
| 489 |
+
|
| 490 |
+
class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
|
| 491 |
+
"""Phi3 family of models scaled rotary embedding.
|
| 492 |
+
|
| 493 |
+
Based on the original RotaryEmbedding implementation.
|
| 494 |
+
"""
|
| 495 |
+
|
| 496 |
+
def __init__(
|
| 497 |
+
self,
|
| 498 |
+
head_size: int,
|
| 499 |
+
rotary_dim: int,
|
| 500 |
+
max_position_embeddings: int,
|
| 501 |
+
original_max_position_embeddings: int,
|
| 502 |
+
base: int,
|
| 503 |
+
is_neox_style: bool,
|
| 504 |
+
dtype: torch.dtype,
|
| 505 |
+
short_factor: List[float],
|
| 506 |
+
long_factor: List[float],
|
| 507 |
+
short_mscale: Optional[float] = None,
|
| 508 |
+
long_mscale: Optional[float] = None,
|
| 509 |
+
):
|
| 510 |
+
super().__init__()
|
| 511 |
+
|
| 512 |
+
if rotary_dim != head_size:
|
| 513 |
+
raise ValueError(
|
| 514 |
+
f"`Phi3LongRoPEScaledRotaryEmbedding` does not support \
|
| 515 |
+
rotary_dim != head_size ({rotary_dim}!={head_size}).")
|
| 516 |
+
if is_neox_style is False:
|
| 517 |
+
raise ValueError(
|
| 518 |
+
"`Phi3LongRoPEScaledRotaryEmbedding` only supports neox_style."
|
| 519 |
+
)
|
| 520 |
+
|
| 521 |
+
self.head_size = head_size
|
| 522 |
+
self.max_position_embeddings = max_position_embeddings
|
| 523 |
+
self.original_max_position_embeddings = original_max_position_embeddings
|
| 524 |
+
self.base = base
|
| 525 |
+
self.short_factor = short_factor
|
| 526 |
+
self.long_factor = long_factor
|
| 527 |
+
|
| 528 |
+
scale = self.max_position_embeddings / \
|
| 529 |
+
self.original_max_position_embeddings
|
| 530 |
+
if scale <= 1.0:
|
| 531 |
+
scaling_factor = 1.0
|
| 532 |
+
else:
|
| 533 |
+
scaling_factor = math.sqrt(
|
| 534 |
+
1 + math.log(scale) /
|
| 535 |
+
math.log(self.original_max_position_embeddings))
|
| 536 |
+
if short_mscale is None:
|
| 537 |
+
short_mscale = scaling_factor
|
| 538 |
+
if long_mscale is None:
|
| 539 |
+
long_mscale = scaling_factor
|
| 540 |
+
|
| 541 |
+
self.short_mscale = short_mscale
|
| 542 |
+
self.long_mscale = long_mscale
|
| 543 |
+
|
| 544 |
+
short_cache = self._compute_cos_sin_cache(
|
| 545 |
+
original_max_position_embeddings, short_factor, short_mscale)
|
| 546 |
+
short_cache = short_cache.to(dtype)
|
| 547 |
+
|
| 548 |
+
long_cache = self._compute_cos_sin_cache(max_position_embeddings,
|
| 549 |
+
long_factor, long_mscale)
|
| 550 |
+
long_cache = long_cache.to(dtype)
|
| 551 |
+
|
| 552 |
+
long_short_cache = torch.cat([short_cache, long_cache], dim=0)
|
| 553 |
+
self.register_buffer("long_short_cos_sin_cache",
|
| 554 |
+
long_short_cache,
|
| 555 |
+
persistent=False)
|
| 556 |
+
|
| 557 |
+
def _compute_inv_freq(self, rescale_factors: List[float]) -> torch.Tensor:
|
| 558 |
+
rescale_factors = torch.tensor(rescale_factors, dtype=torch.float32)
|
| 559 |
+
inv_freq = 1.0 / (rescale_factors * (self.base**(torch.arange(
|
| 560 |
+
0, self.head_size, 2, dtype=torch.float) / self.head_size)))
|
| 561 |
+
return inv_freq
|
| 562 |
+
|
| 563 |
+
def _compute_cos_sin_cache(
|
| 564 |
+
self,
|
| 565 |
+
max_position_embeddings: int,
|
| 566 |
+
rescale_factors: List[float],
|
| 567 |
+
mscale: float,
|
| 568 |
+
) -> torch.Tensor:
|
| 569 |
+
inv_freq = self._compute_inv_freq(rescale_factors)
|
| 570 |
+
t = torch.arange(max_position_embeddings, dtype=torch.float)
|
| 571 |
+
freqs = torch.einsum("i,j -> ij", t, inv_freq)
|
| 572 |
+
cos = freqs.cos() * mscale
|
| 573 |
+
sin = freqs.sin() * mscale
|
| 574 |
+
cache = torch.cat((cos, sin), dim=-1)
|
| 575 |
+
return cache
|
| 576 |
+
|
| 577 |
+
def forward(
|
| 578 |
+
self,
|
| 579 |
+
positions: torch.Tensor,
|
| 580 |
+
query: torch.Tensor,
|
| 581 |
+
key: torch.Tensor,
|
| 582 |
+
offsets: Optional[torch.Tensor] = None,
|
| 583 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 584 |
+
query = query.view(*query.shape[:-1], -1, self.head_size)
|
| 585 |
+
key = key.view(*key.shape[:-1], -1, self.head_size)
|
| 586 |
+
|
| 587 |
+
k = self.original_max_position_embeddings
|
| 588 |
+
long_prompt_offset = (torch.any(positions > k).float() *
|
| 589 |
+
torch.full_like(positions, k)).long()
|
| 590 |
+
idx = (torch.add(positions, long_prompt_offset)
|
| 591 |
+
if long_prompt_offset is not None else positions)
|
| 592 |
+
idx = torch.add(idx, offsets) if offsets is not None else idx
|
| 593 |
+
cos_sin = torch.index_select(self.long_short_cos_sin_cache, 0, idx)
|
| 594 |
+
|
| 595 |
+
cos, sin = cos_sin.chunk(2, dim=-1)
|
| 596 |
+
cos = cos.repeat(1, 2).unsqueeze(-2)
|
| 597 |
+
sin = sin.repeat(1, 2).unsqueeze(-2)
|
| 598 |
+
|
| 599 |
+
query = query * cos + _rotate_neox(query) * sin
|
| 600 |
+
key = key * cos + _rotate_neox(key) * sin
|
| 601 |
+
|
| 602 |
+
return query.flatten(-2), key.flatten(-2)
|
| 603 |
+
|
| 604 |
+
|
| 605 |
+
def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
|
| 606 |
+
if scale <= 1:
|
| 607 |
+
return 1.0
|
| 608 |
+
return 0.1 * mscale * math.log(scale) + 1.0
|
| 609 |
+
|
| 610 |
+
|
| 611 |
+
class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
|
| 612 |
+
"""RotaryEmbedding extended with YaRN method.
|
| 613 |
+
|
| 614 |
+
Credits to Peng et al. github.com/jquesnelle/yarn
|
| 615 |
+
"""
|
| 616 |
+
|
| 617 |
+
def __init__(
|
| 618 |
+
self,
|
| 619 |
+
head_size: int,
|
| 620 |
+
rotary_dim: int,
|
| 621 |
+
max_position_embeddings: int,
|
| 622 |
+
base: int,
|
| 623 |
+
is_neox_style: bool,
|
| 624 |
+
scaling_factor: float,
|
| 625 |
+
dtype: torch.dtype,
|
| 626 |
+
*,
|
| 627 |
+
extrapolation_factor: float = 1,
|
| 628 |
+
attn_factor: float = 1,
|
| 629 |
+
beta_fast: int = 32,
|
| 630 |
+
beta_slow: int = 1,
|
| 631 |
+
mscale: float = 1,
|
| 632 |
+
mscale_all_dim: float = 0,
|
| 633 |
+
) -> None:
|
| 634 |
+
self.scaling_factor = scaling_factor
|
| 635 |
+
self.extrapolation_factor = extrapolation_factor
|
| 636 |
+
self.attn_factor = attn_factor
|
| 637 |
+
self.beta_fast = beta_fast
|
| 638 |
+
self.beta_slow = beta_slow
|
| 639 |
+
# Get n-d magnitude scaling corrected for interpolation.
|
| 640 |
+
self.mscale = float(
|
| 641 |
+
yarn_get_mscale(self.scaling_factor, float(mscale)) /
|
| 642 |
+
yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) *
|
| 643 |
+
attn_factor)
|
| 644 |
+
super().__init__(head_size, rotary_dim, max_position_embeddings, base,
|
| 645 |
+
is_neox_style, dtype)
|
| 646 |
+
|
| 647 |
+
def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
|
| 648 |
+
pos_freqs = self.base**(torch.arange(
|
| 649 |
+
0, self.rotary_dim, 2, dtype=torch.float, device="cuda") /
|
| 650 |
+
self.rotary_dim)
|
| 651 |
+
inv_freq_extrapolation = 1.0 / pos_freqs
|
| 652 |
+
inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)
|
| 653 |
+
|
| 654 |
+
low, high = _yarn_find_correction_range(self.beta_fast, self.beta_slow,
|
| 655 |
+
self.rotary_dim, self.base,
|
| 656 |
+
self.max_position_embeddings)
|
| 657 |
+
# Get n-d rotational scaling corrected for extrapolation
|
| 658 |
+
inv_freq_mask = (1 - _yarn_linear_ramp_mask(
|
| 659 |
+
low, high, self.rotary_dim // 2,
|
| 660 |
+
dtype=torch.float)) * self.extrapolation_factor
|
| 661 |
+
inv_freq = inv_freq_interpolation * (
|
| 662 |
+
1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
|
| 663 |
+
return inv_freq
|
| 664 |
+
|
| 665 |
+
def _compute_cos_sin_cache(self) -> torch.Tensor:
|
| 666 |
+
inv_freq = self._compute_inv_freq(self.scaling_factor)
|
| 667 |
+
t = torch.arange(self.max_position_embeddings * self.scaling_factor,
|
| 668 |
+
device="cuda",
|
| 669 |
+
dtype=torch.float32)
|
| 670 |
+
freqs = torch.einsum("i,j -> ij", t, inv_freq)
|
| 671 |
+
cos = (freqs.cos() * self.mscale)
|
| 672 |
+
sin = (freqs.sin() * self.mscale)
|
| 673 |
+
cache = torch.cat((cos, sin), dim=-1)
|
| 674 |
+
return cache
|
| 675 |
+
|
| 676 |
+
def forward(
|
| 677 |
+
self,
|
| 678 |
+
positions: torch.Tensor,
|
| 679 |
+
query: torch.Tensor,
|
| 680 |
+
key: torch.Tensor,
|
| 681 |
+
offsets: Optional[torch.Tensor] = None,
|
| 682 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 683 |
+
"""PyTorch-native implementation equivalent to forward()."""
|
| 684 |
+
query_rot = query[..., :self.rotary_dim]
|
| 685 |
+
key_rot = key[..., :self.rotary_dim]
|
| 686 |
+
if self.rotary_dim < self.head_size:
|
| 687 |
+
query_pass = query[..., self.rotary_dim:]
|
| 688 |
+
key_pass = key[..., self.rotary_dim:]
|
| 689 |
+
|
| 690 |
+
self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(
|
| 691 |
+
positions.device)
|
| 692 |
+
cos_sin = self.cos_sin_cache[torch.add(positions, offsets)
|
| 693 |
+
if offsets is not None else positions]
|
| 694 |
+
cos, sin = cos_sin.chunk(2, dim=-1)
|
| 695 |
+
if self.is_neox_style:
|
| 696 |
+
# NOTE(woosuk): Here we assume that the positions tensor has the
|
| 697 |
+
# shape [batch_size, seq_len].
|
| 698 |
+
cos = cos.repeat(1, 1, 2).unsqueeze(-2)
|
| 699 |
+
sin = sin.repeat(1, 1, 2).unsqueeze(-2)
|
| 700 |
+
else:
|
| 701 |
+
cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
|
| 702 |
+
sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
|
| 703 |
+
|
| 704 |
+
rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj
|
| 705 |
+
query_rot = query_rot * cos + rotate_fn(query_rot) * sin
|
| 706 |
+
key_rot = key_rot * cos + rotate_fn(key_rot) * sin
|
| 707 |
+
|
| 708 |
+
if self.rotary_dim < self.head_size:
|
| 709 |
+
query = torch.cat((query_rot, query_pass), dim=-1)
|
| 710 |
+
key = torch.cat((key_rot, key_pass), dim=-1)
|
| 711 |
+
else:
|
| 712 |
+
query = query_rot
|
| 713 |
+
key = key_rot
|
| 714 |
+
return query, key
|
| 715 |
+
|
| 716 |
+
|
| 717 |
+
class Llama3RotaryEmbedding(RotaryEmbedding):
|
| 718 |
+
|
| 719 |
+
def __init__(
|
| 720 |
+
self,
|
| 721 |
+
head_size: int,
|
| 722 |
+
rotary_dim: int,
|
| 723 |
+
max_position_embeddings: int,
|
| 724 |
+
base: int,
|
| 725 |
+
is_neox_style: bool,
|
| 726 |
+
dtype: torch.dtype,
|
| 727 |
+
scaling_factor: float,
|
| 728 |
+
low_freq_factor: float,
|
| 729 |
+
high_freq_factor: float,
|
| 730 |
+
orig_max_position: int,
|
| 731 |
+
) -> None:
|
| 732 |
+
self.scaling_factor = scaling_factor
|
| 733 |
+
self.low_freq_factor = low_freq_factor
|
| 734 |
+
self.high_freq_factor = high_freq_factor
|
| 735 |
+
self.orig_max_position = orig_max_position
|
| 736 |
+
super().__init__(head_size, rotary_dim, max_position_embeddings, base,
|
| 737 |
+
is_neox_style, dtype)
|
| 738 |
+
|
| 739 |
+
def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
|
| 740 |
+
inv_freqs = super()._compute_inv_freq(base)
|
| 741 |
+
low_freq_wavelen = self.orig_max_position / self.low_freq_factor
|
| 742 |
+
high_freq_wavelen = self.orig_max_position / self.high_freq_factor
|
| 743 |
+
|
| 744 |
+
wave_len = 2 * math.pi / inv_freqs
|
| 745 |
+
if self.low_freq_factor != self.high_freq_factor:
|
| 746 |
+
smooth = (self.orig_max_position / wave_len - self.low_freq_factor
|
| 747 |
+
) / (self.high_freq_factor - self.low_freq_factor)
|
| 748 |
+
else:
|
| 749 |
+
smooth = 0
|
| 750 |
+
new_freqs = torch.where(
|
| 751 |
+
wave_len < high_freq_wavelen,
|
| 752 |
+
inv_freqs,
|
| 753 |
+
torch.where(
|
| 754 |
+
wave_len > low_freq_wavelen,
|
| 755 |
+
inv_freqs / self.scaling_factor,
|
| 756 |
+
(1 - smooth) * inv_freqs / self.scaling_factor +
|
| 757 |
+
smooth * inv_freqs,
|
| 758 |
+
),
|
| 759 |
+
)
|
| 760 |
+
return new_freqs
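A standalone illustration (not part of the upstream file; the numbers are made up) of the three bands this produces: wavelengths shorter than orig_max_position / high_freq_factor keep their original frequency, wavelengths longer than orig_max_position / low_freq_factor are divided by scaling_factor, and the middle band is linearly blended.

```python
# Hedged sketch reproducing the band split with arbitrary hyperparameters.
import math
import torch

rotary_dim = 128
orig_max_position = 8192
low_freq_factor, high_freq_factor = 1.0, 4.0

inv_freqs = 1.0 / (10000.0**(torch.arange(0, rotary_dim, 2, dtype=torch.float) /
                             rotary_dim))
wave_len = 2 * math.pi / inv_freqs
low_freq_wavelen = orig_max_position / low_freq_factor    # 8192.0
high_freq_wavelen = orig_max_position / high_freq_factor  # 2048.0

kept = (wave_len < high_freq_wavelen).sum().item()     # high freq: unchanged
scaled = (wave_len > low_freq_wavelen).sum().item()    # low freq: divided by factor
blended = inv_freqs.numel() - kept - scaled            # middle band: interpolated
print(kept, blended, scaled)
```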
|
| 761 |
+
|
| 762 |
+
|
| 763 |
+
class MRotaryEmbedding(RotaryEmbedding):
|
| 764 |
+
"""Rotary Embedding with Multimodal Sections."""
|
| 765 |
+
|
| 766 |
+
def __init__(
|
| 767 |
+
self,
|
| 768 |
+
head_size: int,
|
| 769 |
+
rotary_dim: int,
|
| 770 |
+
max_position_embeddings: int,
|
| 771 |
+
base: int,
|
| 772 |
+
is_neox_style: bool,
|
| 773 |
+
dtype: torch.dtype,
|
| 774 |
+
mrope_section: Optional[List[int]] = None,
|
| 775 |
+
) -> None:
|
| 776 |
+
# In Qwen2.5-VL, the maximum index value is related to the duration of
|
| 777 |
+
# the input video. We enlarge max_position_embeddings to 4 times to get
|
| 778 |
+
# a larger cos and sin cache.
|
| 779 |
+
self.cache_max_position_num = max_position_embeddings * 4
|
| 780 |
+
super().__init__(head_size, rotary_dim, self.cache_max_position_num,
|
| 781 |
+
base, is_neox_style, dtype)
|
| 782 |
+
|
| 783 |
+
self.mrope_section = mrope_section
|
| 784 |
+
if self.mrope_section:
|
| 785 |
+
assert sum(self.mrope_section) == rotary_dim // 2
|
| 786 |
+
|
| 787 |
+
def forward(
|
| 788 |
+
self,
|
| 789 |
+
positions: torch.Tensor,
|
| 790 |
+
query: torch.Tensor,
|
| 791 |
+
key: torch.Tensor,
|
| 792 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 793 |
+
"""PyTorch-native implementation equivalent to forward().
|
| 794 |
+
|
| 795 |
+
Args:
|
| 796 |
+
positions:
|
| 797 |
+
[num_tokens,] (text only) or
|
| 798 |
+
[3, num_tokens] (T/H/W positions with multimodal inputs)
|
| 799 |
+
query: [num_tokens, num_heads * head_size]
|
| 800 |
+
key: [num_tokens, num_kv_heads * head_size]
|
| 801 |
+
"""
|
| 802 |
+
assert positions.ndim == 1 or positions.ndim == 2
|
| 803 |
+
|
| 804 |
+
num_tokens = positions.shape[-1]
|
| 805 |
+
cos_sin = self.cos_sin_cache[positions]
|
| 806 |
+
cos, sin = cos_sin.chunk(2, dim=-1)
|
| 807 |
+
if positions.ndim == 2:
|
| 808 |
+
assert self.mrope_section
|
| 809 |
+
|
| 810 |
+
cos = torch.cat([
|
| 811 |
+
m[i]
|
| 812 |
+
for i, m in enumerate(cos.split(self.mrope_section, dim=-1))
|
| 813 |
+
],
|
| 814 |
+
dim=-1)
|
| 815 |
+
sin = torch.cat([
|
| 816 |
+
m[i]
|
| 817 |
+
for i, m in enumerate(sin.split(self.mrope_section, dim=-1))
|
| 818 |
+
],
|
| 819 |
+
dim=-1)
|
| 820 |
+
|
| 821 |
+
query_shape = query.shape
|
| 822 |
+
query = query.view(num_tokens, -1, self.head_size)
|
| 823 |
+
query_rot = query[..., :self.rotary_dim]
|
| 824 |
+
query_pass = query[..., self.rotary_dim:]
|
| 825 |
+
query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style)
|
| 826 |
+
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
|
| 827 |
+
|
| 828 |
+
key_shape = key.shape
|
| 829 |
+
key = key.view(num_tokens, -1, self.head_size)
|
| 830 |
+
key_rot = key[..., :self.rotary_dim]
|
| 831 |
+
key_pass = key[..., self.rotary_dim:]
|
| 832 |
+
key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style)
|
| 833 |
+
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
|
| 834 |
+
return query, key
|
| 835 |
+
|
| 836 |
+
@staticmethod
|
| 837 |
+
def get_input_positions(
|
| 838 |
+
input_tokens: List[int],
|
| 839 |
+
hf_config: PretrainedConfig,
|
| 840 |
+
image_grid_thw: Union[List[List[int]], torch.Tensor],
|
| 841 |
+
video_grid_thw: Union[List[List[int]], torch.Tensor],
|
| 842 |
+
second_per_grid_ts: Optional[List[float]] = None,
|
| 843 |
+
context_len: int = 0,
|
| 844 |
+
seq_len: Optional[int] = None,
|
| 845 |
+
) -> Tuple[List[List[int]], int]:
|
| 846 |
+
"""Get mrope input positions and delta value."""
|
| 847 |
+
|
| 848 |
+
llm_positions, mrope_position_delta = \
|
| 849 |
+
MRotaryEmbedding.get_input_positions_tensor(
|
| 850 |
+
input_tokens=input_tokens,
|
| 851 |
+
hf_config=hf_config,
|
| 852 |
+
image_grid_thw=image_grid_thw,
|
| 853 |
+
video_grid_thw=video_grid_thw,
|
| 854 |
+
second_per_grid_ts=second_per_grid_ts,
|
| 855 |
+
context_len=context_len,
|
| 856 |
+
seq_len=seq_len,
|
| 857 |
+
)
|
| 858 |
+
|
| 859 |
+
return llm_positions.tolist(), mrope_position_delta
|
| 860 |
+
|
| 861 |
+
@staticmethod
|
| 862 |
+
def get_input_positions_tensor(
|
| 863 |
+
input_tokens: List[int],
|
| 864 |
+
hf_config: PretrainedConfig,
|
| 865 |
+
image_grid_thw: Union[List[List[int]], torch.Tensor],
|
| 866 |
+
video_grid_thw: Union[List[List[int]], torch.Tensor],
|
| 867 |
+
second_per_grid_ts: Optional[List[float]] = None,
|
| 868 |
+
context_len: int = 0,
|
| 869 |
+
seq_len: Optional[int] = None,
|
| 870 |
+
) -> Tuple[torch.Tensor, int]:
|
| 871 |
+
"""Get mrope input positions and delta value."""
|
| 872 |
+
|
| 873 |
+
image_token_id = hf_config.image_token_id
|
| 874 |
+
video_token_id = hf_config.video_token_id
|
| 875 |
+
vision_start_token_id = hf_config.vision_start_token_id
|
| 876 |
+
spatial_merge_size = hf_config.vision_config.spatial_merge_size
|
| 877 |
+
tokens_per_second = getattr(hf_config.vision_config,
|
| 878 |
+
"tokens_per_second", 1.0)
|
| 879 |
+
|
| 880 |
+
if isinstance(image_grid_thw, torch.Tensor):
|
| 881 |
+
image_grid_thw = image_grid_thw.tolist()
|
| 882 |
+
if isinstance(video_grid_thw, torch.Tensor):
|
| 883 |
+
video_grid_thw = video_grid_thw.tolist()
|
| 884 |
+
|
| 885 |
+
input_tokens_tensor = torch.tensor(input_tokens)
|
| 886 |
+
vision_start_indices = torch.argwhere(
|
| 887 |
+
input_tokens_tensor == vision_start_token_id).squeeze(1)
|
| 888 |
+
vision_tokens = input_tokens_tensor[vision_start_indices + 1]
|
| 889 |
+
image_nums = (vision_tokens == image_token_id).sum()
|
| 890 |
+
video_nums = (vision_tokens == video_token_id).sum()
|
| 891 |
+
llm_pos_ids_list: list = []
|
| 892 |
+
|
| 893 |
+
st = 0
|
| 894 |
+
remain_images, remain_videos = image_nums, video_nums
|
| 895 |
+
|
| 896 |
+
image_index, video_index = 0, 0
|
| 897 |
+
for _ in range(image_nums + video_nums):
|
| 898 |
+
video_second_per_grid_t = 0.0
|
| 899 |
+
if image_token_id in input_tokens and remain_images > 0:
|
| 900 |
+
ed_image = input_tokens.index(image_token_id, st)
|
| 901 |
+
else:
|
| 902 |
+
ed_image = len(input_tokens) + 1
|
| 903 |
+
if video_token_id in input_tokens and remain_videos > 0:
|
| 904 |
+
ed_video = input_tokens.index(video_token_id, st)
|
| 905 |
+
else:
|
| 906 |
+
ed_video = len(input_tokens) + 1
|
| 907 |
+
if ed_image < ed_video:
|
| 908 |
+
t, h, w = (
|
| 909 |
+
image_grid_thw[image_index][0],
|
| 910 |
+
image_grid_thw[image_index][1],
|
| 911 |
+
image_grid_thw[image_index][2],
|
| 912 |
+
)
|
| 913 |
+
image_index += 1
|
| 914 |
+
remain_images -= 1
|
| 915 |
+
ed = ed_image
|
| 916 |
+
else:
|
| 917 |
+
t, h, w = (
|
| 918 |
+
video_grid_thw[video_index][0],
|
| 919 |
+
video_grid_thw[video_index][1],
|
| 920 |
+
video_grid_thw[video_index][2],
|
| 921 |
+
)
|
| 922 |
+
video_second_per_grid_t = 1.0
|
| 923 |
+
if second_per_grid_ts is not None:
|
| 924 |
+
video_second_per_grid_t = second_per_grid_ts[video_index]
|
| 925 |
+
video_index += 1
|
| 926 |
+
remain_videos -= 1
|
| 927 |
+
ed = ed_video
|
| 928 |
+
|
| 929 |
+
llm_grid_t, llm_grid_h, llm_grid_w = \
|
| 930 |
+
t, h // spatial_merge_size, w // spatial_merge_size
|
| 931 |
+
text_len = ed - st
|
| 932 |
+
|
| 933 |
+
st_idx = llm_pos_ids_list[-1].max() + 1 if len(
|
| 934 |
+
llm_pos_ids_list) > 0 else 0
|
| 935 |
+
llm_pos_ids_list.append(
|
| 936 |
+
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
|
| 937 |
+
|
| 938 |
+
t_index = (torch.arange(llm_grid_t).view(-1, 1).expand(
|
| 939 |
+
-1, llm_grid_h * llm_grid_w) * video_second_per_grid_t *
|
| 940 |
+
tokens_per_second).long().flatten()
|
| 941 |
+
|
| 942 |
+
h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(
|
| 943 |
+
llm_grid_t, -1, llm_grid_w).flatten()
|
| 944 |
+
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(
|
| 945 |
+
llm_grid_t, llm_grid_h, -1).flatten()
|
| 946 |
+
llm_pos_ids_list.append(
|
| 947 |
+
torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
|
| 948 |
+
st = ed + llm_grid_t * llm_grid_h * llm_grid_w
|
| 949 |
+
|
| 950 |
+
if st < len(input_tokens):
|
| 951 |
+
st_idx = llm_pos_ids_list[-1].max() + 1 if len(
|
| 952 |
+
llm_pos_ids_list) > 0 else 0
|
| 953 |
+
text_len = len(input_tokens) - st
|
| 954 |
+
llm_pos_ids_list.append(
|
| 955 |
+
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
|
| 956 |
+
|
| 957 |
+
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
|
| 958 |
+
mrope_position_delta = (llm_positions.max() + 1 -
|
| 959 |
+
len(input_tokens)).item()
|
| 960 |
+
llm_positions = llm_positions[:, context_len:seq_len]
|
| 961 |
+
|
| 962 |
+
return llm_positions, mrope_position_delta
|
| 963 |
+
|
| 964 |
+
@staticmethod
|
| 965 |
+
def get_next_input_positions(
|
| 966 |
+
mrope_position_delta: int,
|
| 967 |
+
context_len: int,
|
| 968 |
+
seq_len: int,
|
| 969 |
+
) -> List[List[int]]:
|
| 970 |
+
return [
|
| 971 |
+
list(
|
| 972 |
+
range(context_len + mrope_position_delta,
|
| 973 |
+
seq_len + mrope_position_delta)) for _ in range(3)
|
| 974 |
+
]
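A tiny illustration (not part of the upstream file; the values are made up): during decoding all three position streams (temporal/height/width) advance together, offset by the stored delta.

```python
print(MRotaryEmbedding.get_next_input_positions(
    mrope_position_delta=-5, context_len=10, seq_len=12))
# [[5, 6], [5, 6], [5, 6]]
```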
|
| 975 |
+
|
| 976 |
+
@staticmethod
|
| 977 |
+
def get_next_input_positions_tensor(
|
| 978 |
+
mrope_position_delta: int,
|
| 979 |
+
context_len: int,
|
| 980 |
+
seq_len: int,
|
| 981 |
+
) -> torch.Tensor:
|
| 982 |
+
return torch.arange(
|
| 983 |
+
mrope_position_delta + context_len,
|
| 984 |
+
mrope_position_delta + seq_len,
|
| 985 |
+
).expand(3, -1)
|
| 986 |
+
|
| 987 |
+
|
| 988 |
+
_ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}
|
| 989 |
+
|
| 990 |
+
|
| 991 |
+
def get_rope(
|
| 992 |
+
head_size: int,
|
| 993 |
+
rotary_dim: int,
|
| 994 |
+
max_position: int,
|
| 995 |
+
base: int,
|
| 996 |
+
is_neox_style: bool = True,
|
| 997 |
+
rope_scaling: Optional[Dict[str, Any]] = None,
|
| 998 |
+
dtype: Optional[torch.dtype] = None,
|
| 999 |
+
partial_rotary_factor: float = 1.0,
|
| 1000 |
+
) -> RotaryEmbedding:
|
| 1001 |
+
if dtype is None:
|
| 1002 |
+
    dtype = torch.get_default_dtype()
    if rope_scaling is not None:
        # Transforms every value that is a list into a tuple for caching calls
        rope_scaling_tuple = {
            k: tuple(v) if isinstance(v, list) else v
            for k, v in rope_scaling.items()
        }
        rope_scaling_args = tuple(rope_scaling_tuple.items())
    else:
        rope_scaling_args = None
    if partial_rotary_factor < 1.0:
        rotary_dim = int(rotary_dim * partial_rotary_factor)
    key = (head_size, rotary_dim, max_position, base, is_neox_style,
           rope_scaling_args, dtype)
    if key in _ROPE_DICT:
        return _ROPE_DICT[key]

    if rope_scaling is None:
        rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
                                     is_neox_style, dtype)
    else:
        scaling_type = rope_scaling["rope_type"]

        if scaling_type == "llama3":
            scaling_factor = rope_scaling["factor"]
            low_freq_factor = rope_scaling["low_freq_factor"]
            high_freq_factor = rope_scaling["high_freq_factor"]
            original_max_position = rope_scaling[
                "original_max_position_embeddings"]
            rotary_emb = Llama3RotaryEmbedding(head_size, rotary_dim,
                                               max_position, base,
                                               is_neox_style, dtype,
                                               scaling_factor, low_freq_factor,
                                               high_freq_factor,
                                               original_max_position)
        elif scaling_type == "default":
            if "mrope_section" in rope_scaling:
                rotary_emb = MRotaryEmbedding(
                    head_size,
                    rotary_dim,
                    max_position,
                    base,
                    is_neox_style,
                    dtype,
                    mrope_section=rope_scaling["mrope_section"],
                )
            else:
                rotary_emb = RotaryEmbedding(
                    head_size,
                    rotary_dim,
                    max_position,
                    base,
                    is_neox_style,
                    dtype,
                )
        elif scaling_type == "linear":
            scaling_factor = rope_scaling["factor"]
            rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim,
                                                      max_position, base,
                                                      is_neox_style,
                                                      scaling_factor, dtype)
        elif scaling_type == "dynamic":
            scaling_factor = rope_scaling["factor"]
            rotary_emb = DynamicNTKScalingRotaryEmbedding(
                head_size, rotary_dim, max_position, base, is_neox_style,
                scaling_factor, dtype)
        elif scaling_type == "yarn":
            scaling_factor = rope_scaling["factor"]
            original_max_position = rope_scaling[
                "original_max_position_embeddings"]
            extra_kwargs = {
                k: v
                for k, v in rope_scaling.items()
                if k in ("extrapolation_factor", "attn_factor", "beta_fast",
                         "beta_slow")
            }
            rotary_emb = YaRNScalingRotaryEmbedding(head_size, rotary_dim,
                                                    original_max_position,
                                                    base, is_neox_style,
                                                    scaling_factor, dtype,
                                                    **extra_kwargs)
        elif scaling_type == "deepseek_yarn":
            scaling_factor = rope_scaling["factor"]
            original_max_position = rope_scaling[
                "original_max_position_embeddings"]
            # assert max_position == original_max_position * scaling_factor
            extra_kwargs = {
                k: v
                for k, v in rope_scaling.items()
                if k in ("extrapolation_factor", "attn_factor", "beta_fast",
                         "beta_slow", "mscale", "mscale_all_dim")
            }
            rotary_emb = DeepseekScalingRotaryEmbedding(
                head_size, rotary_dim, original_max_position, base,
                is_neox_style, scaling_factor, dtype, **extra_kwargs)
        elif scaling_type == "longrope":
            short_factor = rope_scaling["short_factor"]
            long_factor = rope_scaling["long_factor"]
            original_max_position = rope_scaling[
                "original_max_position_embeddings"]
            extra_kwargs = {
                k: v
                for k, v in rope_scaling.items()
                if k in ("short_mscale", "long_mscale")
            }
            rotary_emb = Phi3LongRoPEScaledRotaryEmbedding(
                head_size, rotary_dim, max_position, original_max_position,
                base, is_neox_style, dtype, short_factor, long_factor,
                **extra_kwargs)
        else:
            raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
    _ROPE_DICT[key] = rotary_emb
    return rotary_emb
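
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the upstream file): how the cached factory
# above is typically exercised. It assumes the factory's parameters carry the
# names used in its body (head_size, rotary_dim, max_position, base,
# is_neox_style, rope_scaling); the sizes and the llama3-style rope_scaling
# values below are made-up example numbers, and only keys read by the code
# above are used.
def _example_get_rope_cache() -> None:
    scaling = {
        "rope_type": "llama3",
        "factor": 2.0,
        "low_freq_factor": 1.0,
        "high_freq_factor": 4.0,
        "original_max_position_embeddings": 2048,
    }
    rope_a = get_rope(head_size=64, rotary_dim=64, max_position=4096,
                      base=10000, is_neox_style=True, rope_scaling=scaling)
    rope_b = get_rope(head_size=64, rotary_dim=64, max_position=4096,
                      base=10000, is_neox_style=True,
                      rope_scaling=dict(scaling))
    # Identical arguments hash to the same _ROPE_DICT key (lists inside
    # rope_scaling are tuplized for hashing), so the module is reused.
    assert rope_a is rope_b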
.venv/lib/python3.11/site-packages/vllm/model_executor/layers/sampler.py
ADDED
@@ -0,0 +1,1292 @@
# SPDX-License-Identifier: Apache-2.0
"""A layer that samples the next tokens from the model's outputs."""
import itertools
import warnings
from dataclasses import dataclass
from importlib.util import find_spec
from math import inf
from typing import Dict, Iterator, List, Optional, Tuple, Union

import msgspec
import torch
import torch.nn as nn

import vllm.envs as envs
from vllm.model_executor.layers.utils import apply_penalties
from vllm.model_executor.sampling_metadata import (SamplingMetadata,
                                                   SamplingTensors,
                                                   SequenceGroupToSample)
from vllm.sampling_params import SamplingType
from vllm.sequence import (VLLM_INVALID_TOKEN_ID,
                           CompletionSequenceGroupOutput, Logprob,
                           PromptLogprobs, SampleLogprobs, SequenceOutput)
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics

if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"):
    import flashinfer.sampling
    # yapf: disable
    from flashinfer.sampling import (
        top_k_top_p_sampling_from_probs as flashinfer_top_k_top_p_sampling)

    # yapf: enable
else:
    flashinfer_top_k_top_p_sampling = None


def get_sampler() -> torch.nn.Module:
    if envs.VLLM_USE_V1:
        # Lazy import: the v1 package isn't distributed
        from vllm.v1.sample.sampler import Sampler as V1Sampler
        return V1Sampler()
    return Sampler()


# (num_token_ids, num_parent_ids) per sequence group.
SampleResultType = List[Tuple[List[int], List[int]]]

# Types of temporary data structures used for
# computing sample_result
SampleMetadataType = Dict[SamplingType, Tuple[List[int],
                                              List[SequenceGroupToSample]]]
MultinomialSamplesType = Dict[SamplingType, torch.Tensor]
SampleResultsDictType = Dict[int, Tuple[List[int], List[int]]]


# Encapsulates temporary data structures for computing
# sample_result.
#
# * For multi-step scheduling: must be returned
# by `Sampler.forward()` and used later to compute the pythonized
# sample_result
#
# * For single-step scheduling: consumed immediately
# inside `Sampler.forward()` to compute pythonized sample_result.
@dataclass
class SampleResultArgsType:
    sample_metadata: SampleMetadataType
    multinomial_samples: MultinomialSamplesType
    sample_results_dict: SampleResultsDictType
    sampling_metadata: SamplingMetadata
    greedy_samples: Optional[torch.Tensor]
    beam_search_logprobs: Optional[torch.Tensor]


# Union of non-deferred (single-step scheduling)
# vs deferred (multi-step scheduling)
# sample result types
MaybeDeferredSampleResultType = Union[SampleResultType, SampleResultArgsType]

# Abbreviation of the _sample() return type
SampleReturnType = Tuple[MaybeDeferredSampleResultType, Optional[torch.Tensor]]


class SamplerOutput(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        array_like=True):  # type: ignore[call-arg]
    """For each sequence group, we generate a list of SequenceOutput objects,
    each of which contains one possible candidate for the next token.

    This data structure implements methods, so it can be used like a list, but
    also has optional fields for device tensors.
    """

    outputs: List[CompletionSequenceGroupOutput]

    # On-device tensor containing probabilities of each token.
    sampled_token_probs: Optional[torch.Tensor] = None

    # On-device tensor containing the logprobs of each token.
    logprobs: Optional["torch.Tensor"] = None

    # Holds either (1) the pythonized sampler result (single-step scheduling)
    # or (2) what will be arguments for later deferred pythonization of the
    # sampler result (multi-step scheduling)
    deferred_sample_results_args: Optional[SampleResultArgsType] = None

    # On-device tensor containing the sampled token ids.
    sampled_token_ids: Optional[torch.Tensor] = None
    # CPU tensor containing the sampled token ids. Used during multi-step to
    # return the sampled token ids from last rank to AsyncLLMEngine to be
    # 'broadcasted' to all other PP ranks for next step.
    sampled_token_ids_cpu: Optional[torch.Tensor] = None

    # Spec decode metrics populated by workers.
    spec_decode_worker_metrics: Optional[SpecDecodeWorkerMetrics] = None

    # Optional last hidden states from the model.
    hidden_states: Optional[torch.Tensor] = None

    # Optional prefill hidden states from the model
    # (used for models like EAGLE).
    prefill_hidden_states: Optional[torch.Tensor] = None

    # Time taken in the forward pass for this across all workers
    model_forward_time: Optional[float] = None

    # Time taken in the model execute function. This will include model forward,
    # block/sync across workers, cpu-gpu sync time and sampling time.
    model_execute_time: Optional[float] = None

    def __getitem__(self, idx: int) -> CompletionSequenceGroupOutput:
        return self.outputs[idx]

    def __setitem__(self, idx: int, value):
        self.outputs[idx] = value

    def __iter__(self) -> Iterator[CompletionSequenceGroupOutput]:
        return iter(self.outputs)

    def __len__(self):
        return len(self.outputs)

    def __eq__(self, other: object):
        return isinstance(other,
                          self.__class__) and self.outputs == other.outputs

    def __repr__(self) -> str:
        """Show the shape of a tensor instead of its values to reduce noise.
        """
        sampled_token_probs_repr = ("None" if self.sampled_token_probs is None
                                    else self.sampled_token_probs.shape)
        sampled_token_ids_repr = ("None" if self.sampled_token_ids is None else
                                  self.sampled_token_ids.shape)
        return (
            f"SamplerOutput(outputs={self.outputs}, "
            f"sampled_token_probs={sampled_token_probs_repr}, "
            f"sampled_token_ids={sampled_token_ids_repr}, "
            f"spec_decode_worker_metrics={self.spec_decode_worker_metrics})")
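
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the upstream file): SamplerOutput behaves
# like a list over its `outputs` field while the tensor fields stay optional.
# An empty `outputs` list is used here purely so the example stays
# self-contained; real callers receive populated
# CompletionSequenceGroupOutput objects.
def _example_sampler_output_is_list_like() -> None:
    out = SamplerOutput(outputs=[])
    assert len(out) == 0
    assert list(iter(out)) == []
    # Tensor fields default to None and __repr__ prints shapes, not values.
    assert out.sampled_token_probs is None
    assert "sampled_token_probs=None" in repr(out)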


class Sampler(nn.Module):
    """Samples the next tokens from the model's outputs.

    This layer does the following:
    1. Discard the hidden states that are not used for sampling (i.e., all
        tokens except the final one in each prompt).
    2. Compute the logits for the next tokens.
    3. Apply presence, frequency and repetition penalties.
    4. Apply temperature scaling.
    5. Apply top-p and top-k truncation.
    6. Sample the next tokens.
    Here, each sequence group within the batch can have different sampling
    parameters (e.g., sampling method, temperature, top-p, top-k, etc.).

    The structure of the logits tensor is coupled with the seq_groups in
    sampling_metadata. Typically, each sequence in each seq_group has one row in
    logits for the next token to be sampled; however, for a seq_group with a
    prompt request with the prompt_logprobs sampling parameter, there are rows
    in logits for each token in the input prompt.
    """

    def __init__(self):
        super().__init__()

        # Whether or not the SamplerOutput should have on-device tensors
        # containing the sampled token ids and probabilities. This is used by
        # speculative decoding.
        self.include_gpu_probs_tensor = False
        self.should_modify_greedy_probs_inplace = False

    def _init_sampling_tensors(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ):
        """The goal here is to reuse sampling tensors between similar decode
        runs. This is possible because sampling logic does not change between
        decodes of the same sequences.
        """
        _, vocab_size = logits.shape

        # First free any existing stored sampling tensors.
        # This is necessary because some sampling tensors may
        # have pinned memory.
        self._sampling_tensors = None

        # Initialize new sampling tensors
        (sampling_tensors, do_penalties, do_top_p_top_k,
         do_min_p) = SamplingTensors.from_sampling_metadata(
             sampling_metadata, vocab_size, logits.device, logits.dtype)

        self._sampling_tensors = sampling_tensors
        self._do_penalties = do_penalties
        self._do_top_p_top_k = do_top_p_top_k
        self._do_min_p = do_min_p

    def forward(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """
        Single-step scheduling:
        * Perform GPU-side sampling computation & compute
          GPU-side logprobs tensor
        * Pythonize sampling result & logprobs tensor

        Multi-step scheduling:
        * Perform GPU-side sampling computation & compute
          GPU-side logprobs tensor
        * Defer Pythonization of sampling result & logprobs
          tensor
        * Encapsulate arguments required for deferred Pythonization
          in the :class:`SamplerOutput` structure

        Args:
            logits: (num_tokens, vocab_size).
            sampling_metadata: Metadata for sampling.
        """
        assert logits is not None
        _, vocab_size = logits.shape

        # Prepare sampling tensors with pinned memory to avoid blocking.
        if not sampling_metadata.reuse_sampling_tensors:
            self._init_sampling_tensors(logits, sampling_metadata)
        elif self._do_penalties:
            # In this case, the sampling tensors logic depends on
            # "output_tokens" of a sequence. As a result, we cannot
            # reuse sampling tensors, since "output_tokens" changes
            # between decode runs.
            self._init_sampling_tensors(logits, sampling_metadata)

        assert self._sampling_tensors is not None
        sampling_tensors = self._sampling_tensors
        do_penalties = self._do_penalties
        do_top_p_top_k = self._do_top_p_top_k
        do_min_p = self._do_min_p

        logits = _apply_min_tokens_penalty(logits, sampling_metadata)

        # Apply presence and frequency penalties.
        if do_penalties:
            logits = apply_penalties(logits, sampling_tensors.prompt_tokens,
                                     sampling_tensors.output_tokens,
                                     sampling_tensors.presence_penalties,
                                     sampling_tensors.frequency_penalties,
                                     sampling_tensors.repetition_penalties)

        # Use float32 to apply temperature scaling.
        # Use in-place division to avoid creating a new tensor.
        logits = logits.to(torch.float)
        logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1))

        if do_top_p_top_k and flashinfer_top_k_top_p_sampling is None:
            logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
                                        sampling_tensors.top_ks)

        if do_min_p:
            logits = _apply_min_p(logits, sampling_tensors.min_ps)

        # We use float32 for probabilities and log probabilities.
        # Compute the probabilities.
        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
        # Compute the log probabilities.
        logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)

        # Sample the next tokens.
        maybe_deferred_sample_results, maybe_sampled_tokens_tensor = _sample(
            probs,
            logprobs,
            sampling_metadata,
            sampling_tensors,
            include_gpu_probs_tensor=self.include_gpu_probs_tensor,
            modify_greedy_probs=self._should_modify_greedy_probs_inplace,
        )

        if self.include_gpu_probs_tensor:
            # Since we will defer sampler result Pythonization,
            # preserve GPU-side tensors in support of later
            # deferred pythonization of logprobs
            assert maybe_sampled_tokens_tensor is not None
            on_device_tensors = (probs, logprobs, maybe_sampled_tokens_tensor)
        else:
            # Since Pythonization has already happened, don't preserve
            # GPU-side tensors.
            on_device_tensors = None

        # Get the logprobs query results.
        prompt_logprobs = None
        sample_logprobs = None
        if not sampling_metadata.skip_sampler_cpu_output:
            # Pythonize logprobs now (GPU -> CPU); do not defer.
            assert not isinstance(maybe_deferred_sample_results,
                                  SampleResultArgsType)
            prompt_logprobs, sample_logprobs = get_logprobs(
                logprobs, sampling_metadata, maybe_deferred_sample_results)

        return _build_sampler_output(
            maybe_deferred_sample_results,
            sampling_metadata,
            prompt_logprobs,
            sample_logprobs,
            on_device_tensors=on_device_tensors,
            skip_sampler_cpu_output=sampling_metadata.skip_sampler_cpu_output)

    @property
    def _should_modify_greedy_probs_inplace(self) -> bool:
        """Whether or not the sampler should modify the probability distribution
        of greedily-sampled tokens such that multinomial sampling would sample
        the greedily-sampled token.

        In other words, if True then we set the probability of the greedily-
        sampled token to 1.

        This is used by speculative decoding, which requires that the sampling
        method be encoded into the probability distribution.
        """
        return self.should_modify_greedy_probs_inplace
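
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the upstream file): the tensor-level order
# that Sampler.forward applies after the penalty helpers, reproduced on a tiny
# hand-made batch with plain torch ops. The shapes and temperature values are
# arbitrary example numbers.
def _example_temperature_then_softmax() -> None:
    logits = torch.tensor([[2.0, 1.0, 0.0, -1.0],
                           [0.5, 0.5, 0.5, 0.5]])
    temperatures = torch.tensor([0.5, 1.0])

    # forward() casts to float32 and divides in place, one temperature per row.
    scaled = logits.to(torch.float)
    scaled.div_(temperatures.unsqueeze(dim=1))

    # Probabilities are then taken in float32. A lower temperature sharpens
    # the first row; the uniform second row is unchanged.
    probs = torch.softmax(scaled, dim=-1, dtype=torch.float)
    assert probs[0].argmax() == 0 and probs[0, 0] > 0.8
    assert torch.allclose(probs[1], torch.full((4, ), 0.25))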


def _apply_min_tokens_penalty(
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
    """Apply min_tokens penalty which sets stop tokens to -inf if min_tokens
    have not been generated yet
    """
    # list of indices in logits that will be set to -inf
    logits_to_penalize: List[Tuple[int, int]] = []
    logits_applied = 0
    for seq_group in sampling_metadata.seq_groups:
        seq_ids = seq_group.seq_ids
        sampling_params = seq_group.sampling_params

        sample_indices = seq_group.sample_indices
        logits_applied += len(sample_indices) + len(
            seq_group.prompt_logprob_indices)
        if not seq_group.do_sample:
            continue

        start_idx = sample_indices[0]
        min_tokens = sampling_params.min_tokens
        token_ids_to_penalize = sampling_params.all_stop_token_ids
        if min_tokens > 0 and token_ids_to_penalize:
            seqs_to_penalize: List[int] = []
            for j, seq_id in enumerate(seq_ids):
                seq_data = seq_group.seq_data[seq_id]
                if len(seq_data.output_token_ids_array) < min_tokens:
                    seqs_to_penalize.append(j)

            if seqs_to_penalize:
                # convert to the index into logits
                seqs_to_penalize = [start_idx + j for j in seqs_to_penalize]
                # itertools.product pairs each seq index with every token id
                logits_to_penalize.extend(
                    itertools.product(seqs_to_penalize, token_ids_to_penalize))

    if logits_to_penalize:
        # use zip and * to group indices along each dimension
        # eg. [ (1,2), (1,3), (5,6) ] -> ( (1,1,5), (2,3,6) )
        logits[tuple(zip(*logits_to_penalize))] = -float("inf")

    # verifies that no rows in logits were missed unexpectedly
    assert logits_applied == logits.shape[0]
    return logits


def _apply_top_k_top_p(
    logits: torch.Tensor,
    p: torch.Tensor,
    k: torch.Tensor,
) -> torch.Tensor:
    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)

    # Apply top-k.
    top_k_mask = logits_sort.size(1) - k.to(torch.long)
    # Get all the top_k values.
    top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
    top_k_mask = logits_sort < top_k_mask
    logits_sort.masked_fill_(top_k_mask, -float("inf"))

    # Apply top-p.
    probs_sort = logits_sort.softmax(dim=-1)
    probs_sum = probs_sort.cumsum(dim=-1)
    top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
    # at least one
    top_p_mask[:, -1] = False
    logits_sort.masked_fill_(top_p_mask, -float("inf"))

    # Re-sort the probabilities.
    logits = torch.empty_like(logits_sort).scatter_(dim=-1,
                                                    index=logits_idx,
                                                    src=logits_sort)
    return logits
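
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the upstream file): what _apply_top_k_top_p
# keeps on a tiny, hand-made row. With k=2 and p=1.0 exactly the two largest
# logits survive; everything else is pushed to -inf.
def _example_top_k_top_p() -> None:
    logits = torch.tensor([[0.1, 3.0, 2.0, -1.0]])
    out = _apply_top_k_top_p(logits.clone(),
                             p=torch.tensor([1.0]),
                             k=torch.tensor([2]))
    kept = torch.isfinite(out[0])
    assert kept.tolist() == [False, True, True, False]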


def _apply_min_p(
    logits: torch.Tensor,
    min_p: torch.Tensor,
) -> torch.Tensor:
    """
    Adapted from
    https://github.com/oobabooga/text-generation-webui/blob/3146124ec01f02c8fb1650a6517cf1b60b537aaf/modules/sampler_hijack.py#L16C17-L16C17
    """
    probs = torch.softmax(logits, dim=-1)
    top_probs, _ = probs.max(dim=-1, keepdim=True)
    scaled_min_p = min_p.unsqueeze_(dim=1) * top_probs
    tokens_to_remove = probs < scaled_min_p
    logits = logits.masked_fill_(tokens_to_remove, -float("inf"))

    return logits
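
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the upstream file): min-p keeps only tokens
# whose probability is at least `min_p` times the most likely token's
# probability. The numbers below are arbitrary example values.
def _example_min_p() -> None:
    logits = torch.log(torch.tensor([[0.6, 0.3, 0.05, 0.05]]))
    out = _apply_min_p(logits.clone(), torch.tensor([0.2]))
    # 0.3 >= 0.2 * 0.6 survives; 0.05 < 0.12 is filtered out.
    assert torch.isfinite(out[0]).tolist() == [True, True, False, False]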


def _greedy_sample(
    selected_seq_groups: List[SequenceGroupToSample],
    samples: torch.Tensor,
) -> SampleResultType:
    """Run greedy sampling on the given samples.

    Args:
        selected_seq_groups: A list of sequence groups batched.
        samples: (num_selected_samples,) A tensor of samples. The length of
            samples could be smaller than selected_seq_groups if
            seq_group.do_sample is False.
    Returns:
        Tuple of (next_token_ids, parent_ids). The length of returned list is
        same as the length of selected_seq_groups. If the corresponding
        seq_group has do_sample=False, tuple contains ([], [])
    """
    samples_lst = samples.tolist()
    sample_idx = 0
    results: SampleResultType = []
    for seq_group in selected_seq_groups:
        if not seq_group.do_sample:
            results.append(([], []))
            continue

        seq_ids = seq_group.seq_ids
        num_parent_seqs = len(seq_ids)
        assert num_parent_seqs == 1, (
            "Greedy sampling should have only one seq.")
        parent_ids = list(range(num_parent_seqs))
        next_token_ids = [samples_lst[sample_idx]]
        results.append((next_token_ids, parent_ids))
        sample_idx += num_parent_seqs
    return results


def _random_sample(
    selected_seq_groups: List[SequenceGroupToSample],
    random_samples: torch.Tensor,
) -> SampleResultType:
    """Run random sampling on the given samples.

    Args:
        selected_seq_groups: A list of sequence groups batched.
        random_samples: (num_selected_samples,) A tensor of samples. The
            length of samples could be smaller than selected_seq_groups if
            seq_group.do_sample is False.
    Returns:
        Tuple of (next_token_ids, parent_ids). The length of returned list is
        same as the length of selected_seq_groups. If the corresponding
        seq_group has do_sample=False, tuple contains ([], [])
    """
    # Find the maximum n value of the prompt phase requests.
    random_samples = random_samples.cpu()
    sample_idx = 0
    results: SampleResultType = []
    for seq_group in selected_seq_groups:
        if not seq_group.do_sample:
            results.append(([], []))
            continue

        seq_ids = seq_group.seq_ids
        sampling_params = seq_group.sampling_params
        is_prompt = seq_group.is_prompt
        num_parent_seqs = len(seq_ids)
        if is_prompt:
            # Prompt phase.
            parent_ids = [0] * sampling_params.n
            next_token_ids = random_samples[
                sample_idx, :sampling_params.n].tolist()
        else:
            # Generation phase.
            parent_ids = list(range(num_parent_seqs))
            next_token_ids = random_samples[sample_idx:sample_idx +
                                            num_parent_seqs, 0].tolist()
        results.append((next_token_ids, parent_ids))
        sample_idx += num_parent_seqs
    return results


def _beam_search_sample(
    selected_seq_groups: List[SequenceGroupToSample],
    logprobs: torch.Tensor,
) -> SampleResultType:
    """Run beam sampling on the given samples.

    Args:
        selected_seq_groups: A list of sequence groups batched.
        logprobs: (num_selected_samples, vocab_size,) A tensor of logprob
            on selected sample indices.
    Returns:
        Tuple of (next_token_ids, parent_ids). The length of returned list is
        same as the length of selected_seq_groups. If the corresponding
        seq_group has do_sample=False, tuple contains ([], [])
    """
    # We sample 2 * beam_width candidates to make sure that with high
    # probability we can get `beam_width` candidates in addition to
    # the finished sequences for the next iteration. See
    # https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/utils/beam_search.py#L557-L563
    # for details. See also HF reference:
    # https://github.com/huggingface/transformers/blob/a4dd53d88e4852f023332d284ff07a01afcd5681/src/transformers/generation/utils.py#L3063-L3065
    #
    # NOTE: Beam search is not vectorized, so its speed can be slower than
    # other sampling methods.
    sample_idx = 0
    results: SampleResultType = []
    for seq_group in selected_seq_groups:
        if not seq_group.do_sample:
            results.append(([], []))
            continue

        is_prompt = seq_group.is_prompt
        seq_ids, sampling_params = seq_group.seq_ids, seq_group.sampling_params
        num_parent_seqs = len(seq_ids)
        beam_width = sampling_params.n
        seq_group_logprobs = logprobs[sample_idx:sample_idx + num_parent_seqs]
        if is_prompt:
            # Prompt phase.
            assert num_parent_seqs == 1, (
                "Prompt input should have only one seq.")
            parent_ids = [0] * (2 * beam_width)
            _, next_token_ids = torch.topk(seq_group_logprobs[0],
                                           2 * beam_width)
            next_token_ids = next_token_ids.tolist()
        else:
            # Generation phase.
            cumulative_logprobs: List[float] = [
                seq_group.seq_data[seq_id].cumulative_logprob
                for seq_id in seq_ids
            ]
            cumulative_logprobs_tensor = torch.tensor(
                cumulative_logprobs,
                dtype=torch.float,
                device=seq_group_logprobs.device)
            seq_group_logprobs = (seq_group_logprobs +
                                  cumulative_logprobs_tensor.unsqueeze(dim=1))
            _, topk_ids = torch.topk(seq_group_logprobs.flatten(),
                                     2 * beam_width)
            topk_ids = topk_ids.tolist()
            vocab_size = seq_group_logprobs.size(-1)
            parent_ids = [i // vocab_size for i in topk_ids]
            next_token_ids = [i % vocab_size for i in topk_ids]
        results.append((next_token_ids, parent_ids))
        sample_idx += num_parent_seqs
    assert sample_idx == logprobs.size(0)
    return results


# torch.multinomial forces a GPU<->CPU sync.
# Therefore, we use an optimized implementation instead.
# Note that we always sample with replacement.
# probs will be modified in place, but this is fine, as we pass
# in a copy already.
def _multinomial(
    probs: torch.Tensor,
    num_samples: int,
    seq_groups: Optional[List[SequenceGroupToSample]] = None,
) -> torch.Tensor:
    if num_samples > 1:
        probs = probs.repeat_interleave(num_samples, dim=0)
    q = torch.empty_like(probs)
    if seq_groups is None:
        q.exponential_()
    else:
        sample_idx = 0
        for seq_group in seq_groups:
            seq_ids = seq_group.seq_ids
            stride = len(seq_ids) * num_samples
            assert seq_group.generator is not None
            q[sample_idx:sample_idx +
              stride].exponential_(generator=seq_group.generator)
            sample_idx += stride
    return probs.div_(q).argmax(dim=1).view(-1, num_samples)
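
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the upstream file): the exponential-noise
# trick used by _multinomial above. Dividing the probabilities by i.i.d.
# Exp(1) noise and taking the argmax draws from the categorical distribution
# (the Gumbel-max trick in disguise) without a GPU<->CPU sync. With a
# point-mass distribution the argmax is deterministic, which makes the
# behaviour easy to check; the tensor below is a made-up example.
def _example_multinomial_trick() -> None:
    point_mass = torch.tensor([[0.0, 0.0, 1.0, 0.0]])
    draws = _multinomial(point_mass.clone(), num_samples=3)
    assert draws.shape == (1, 3)
    assert (draws == 2).all()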


def _top_k_top_p_multinomial_with_flashinfer(
        probs: torch.Tensor, top_ks: torch.Tensor, top_ps: torch.Tensor,
        num_samples: int, seq_groups: Optional[List[SequenceGroupToSample]]):
    max_top_k_round = 32
    if num_samples > 1:
        probs = probs.repeat_interleave(num_samples, dim=0)
        top_ks = top_ks.repeat_interleave(num_samples)
        top_ps = top_ps.repeat_interleave(num_samples)
    batch_size = probs.shape[0]
    uniform_samples = torch.empty((max_top_k_round, batch_size),
                                  device=probs.device)
    if seq_groups is None:
        uniform_samples.uniform_()
    else:
        sample_idx = 0
        for seq_group in seq_groups:
            seq_ids = seq_group.seq_ids
            stride = len(seq_ids) * num_samples
            assert seq_group.generator is not None
            uniform_samples[:, sample_idx:sample_idx +
                            stride].uniform_(generator=seq_group.generator)
            sample_idx += stride
    batch_next_token_ids, success = flashinfer_top_k_top_p_sampling(
        probs,
        uniform_samples,
        top_ks,
        top_ps,
    )
    if not success.all():
        warnings.warn("FlashInfer rejection sampling failed, fallback.",
                      stacklevel=1)
        probs = flashinfer.sampling.top_k_renorm_prob(probs, top_ks)
        probs = flashinfer.sampling.top_p_renorm_prob(probs, top_ps)
        batch_next_token_ids = flashinfer.sampling.sampling_from_probs(
            probs, uniform_samples[0])
    return batch_next_token_ids.view(-1, num_samples)


def get_pythonized_sample_results(
        sample_result_args: SampleResultArgsType) -> SampleResultType:
    '''This function consumes GPU-side sampler results and computes
    Pythonized CPU-side sampler results (GPU -> CPU sync.)

    Single-step scheduling: this function is invoked at sampling-time
    for immediate Pythonization.

    Multi-step scheduling: Pythonization is deferred until after multiple
    GPU-side steps have been completed.

    Args:
      sample_result_args: GPU-side inputs to the Pythonization process

    Returns:
      Pythonized sampler results
    '''

    (
        sample_metadata,
        sampling_metadata,
        greedy_samples,
        multinomial_samples,
        beam_search_logprobs,
        sample_results_dict,
    ) = (
        sample_result_args.sample_metadata,
        sample_result_args.sampling_metadata,
        sample_result_args.greedy_samples,
        sample_result_args.multinomial_samples,
        sample_result_args.beam_search_logprobs,
        sample_result_args.sample_results_dict,
    )

    for sampling_type in SamplingType:
        if sampling_type not in sample_metadata:
            continue
        (seq_group_id, seq_groups) = sample_metadata[sampling_type]
        if sampling_type == SamplingType.GREEDY:
            sample_results = _greedy_sample(seq_groups, greedy_samples)
        elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
            sample_results = _random_sample(seq_groups,
                                            multinomial_samples[sampling_type])
        elif sampling_type == SamplingType.BEAM:
            sample_results = _beam_search_sample(seq_groups,
                                                 beam_search_logprobs)
        sample_results_dict.update(zip(seq_group_id, sample_results))

    return [
        sample_results_dict.get(i, ([], []))
        for i in range(len(sampling_metadata.seq_groups))
    ]


def _sample_with_torch(
    probs: torch.Tensor,
    logprobs: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    sampling_tensors: SamplingTensors,
    include_gpu_probs_tensor: bool,
    modify_greedy_probs: bool,
) -> SampleReturnType:
    '''Torch-oriented _sample() implementation.

    Single-step scheduling:
    * Perform GPU-side sampling computation
    * Immediately Pythonize sampling result

    Multi-step scheduling:
    * Perform GPU-side sampling computation
    * Defer Pythonization & preserve GPU-side
      tensors required for Pythonization
    '''

    categorized_seq_group_ids: Dict[SamplingType, List[int]] = {
        t: []
        for t in SamplingType
    }
    categorized_sample_indices = sampling_metadata.categorized_sample_indices
    for i, seq_group in enumerate(sampling_metadata.seq_groups):
        sampling_params = seq_group.sampling_params
        sampling_type = sampling_params.sampling_type
        categorized_seq_group_ids[sampling_type].append(i)

    sample_results_dict: SampleResultsDictType = {}
    sample_metadata: SampleMetadataType = {}
    multinomial_samples: MultinomialSamplesType = {}
    greedy_samples: Optional[torch.Tensor] = None
    beam_search_logprobs: Optional[torch.Tensor] = None

    # Create output tensor for sampled token ids.
    if include_gpu_probs_tensor:
        sampled_token_ids_tensor = torch.full((logprobs.shape[0], 1),
                                              VLLM_INVALID_TOKEN_ID,
                                              dtype=torch.long,
                                              device=logprobs.device)
    else:
        sampled_token_ids_tensor = None

    # Counterintuitively, having two loops here is actually faster.
    # The first loop can run without waiting on GPU<->CPU sync.
    for sampling_type in SamplingType:
        sample_indices = categorized_sample_indices[sampling_type]
        num_tokens = len(sample_indices)
        if num_tokens == 0:
            continue

        seq_group_id = categorized_seq_group_ids[sampling_type]
        seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_id]
        sample_metadata[sampling_type] = (seq_group_id, seq_groups)
        long_sample_indices = sample_indices.long()
        if sampling_type == SamplingType.GREEDY:
            greedy_samples = torch.argmax(logprobs[long_sample_indices],
                                          dim=-1)

            if sampled_token_ids_tensor is not None:
                # Store sampled tokens in output tensor.
                sampled_token_ids_tensor[
                    long_sample_indices] = greedy_samples.unsqueeze(-1)

            if modify_greedy_probs:
                # If required, modify the probabilities such that sampling from
                # the modified distribution would always sample the argmax
                # token id.
                _modify_greedy_probs_inplace(logprobs, probs,
                                             long_sample_indices,
                                             greedy_samples)

        elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
            max_n_in_batch = 1
            for seq_group in seq_groups:
                if seq_group.is_prompt:
                    sampling_params = seq_group.sampling_params
                    max_n_in_batch = max(max_n_in_batch, sampling_params.n)
            seq_groups_arg = (None if sampling_type == SamplingType.RANDOM else
                              seq_groups)

            if flashinfer_top_k_top_p_sampling is not None:
                multinomial_samples[
                    sampling_type] = _top_k_top_p_multinomial_with_flashinfer(
                        probs[long_sample_indices],
                        sampling_tensors.top_ks[long_sample_indices],
                        sampling_tensors.top_ps[long_sample_indices],
                        max_n_in_batch,
                        seq_groups_arg,
                    )
            else:
                multinomial_samples[sampling_type] = _multinomial(
                    probs[long_sample_indices],
                    max_n_in_batch,
                    seq_groups=seq_groups_arg)

            if sampled_token_ids_tensor is not None:
                # Store sampled tokens in output tensor.
                sampled_token_ids_tensor[long_sample_indices] = \
                    multinomial_samples[sampling_type].to(torch.long)

        elif sampling_type == SamplingType.BEAM:
            beam_search_logprobs = logprobs[sample_indices]
        else:
            raise ValueError(f"Unsupported sampling type: {sampling_type}")

    # Encapsulate arguments for computing Pythonized sampler
    # results, whether deferred or otherwise.
    maybe_deferred_args = SampleResultArgsType(
        sampling_metadata=sampling_metadata,
        sample_metadata=sample_metadata,
        multinomial_samples=multinomial_samples,
        greedy_samples=greedy_samples,
        beam_search_logprobs=beam_search_logprobs,
        sample_results_dict=sample_results_dict)

    if not sampling_metadata.skip_sampler_cpu_output:
        # GPU<->CPU sync happens here.
        # This also converts the sampler output to a Python object.
        # Return Pythonized sampler result & sampled token ids
        return get_pythonized_sample_results(
            maybe_deferred_args), sampled_token_ids_tensor
    else:
        # Defer sampler result Pythonization; return deferred
        # Pythonization args & sampled token ids
        return (
            maybe_deferred_args,
            sampled_token_ids_tensor,
        )


def _sample(
    probs: torch.Tensor,
    logprobs: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    sampling_tensors: SamplingTensors,
    include_gpu_probs_tensor: bool,
    modify_greedy_probs: bool,
) -> SampleReturnType:
    """
    Args:
        probs: (num_query_tokens_in_batch, num_vocab)
        logprobs: (num_query_tokens_in_batch, num_vocab)
        sampling_metadata: The metadata for a batch for sampling.
        sampling_tensors: Tensors that include sampling related metadata.

    Returns:
        (next_token_ids, parent_seq_ids) for each seq group in a batch.
            If sampling is skipped, it returns ([], [])
        sampled_token_ids_tensor: A tensor of sampled token ids.
    """
    return _sample_with_torch(
        probs,
        logprobs,
        sampling_metadata,
        sampling_tensors,
        include_gpu_probs_tensor=include_gpu_probs_tensor,
        modify_greedy_probs=modify_greedy_probs,
    )


def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    """
    This function calculates the ranks of the chosen tokens in a logprob tensor.

    Args:
        x (torch.Tensor): 2D logprob tensor of shape (N, M)
                        where N is the no. of tokens and M is the vocab dim.
        indices (torch.Tensor): List of chosen token indices.

    Returns:
        torch.Tensor: 1D tensor of shape (N,) where N is the no. of tokens.
                    Each element in the returned tensor represents the rank
                    of the chosen token in the input logprob tensor.
    """
    vals = x[torch.arange(0, len(x), device=x.device, dtype=indices.dtype),
             indices]
    result = (x > vals[:, None])
    del vals
    return result.sum(1).add_(1)
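
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the upstream file): _get_ranks on a tiny
# hand-made logprob matrix. Rank 1 means the chosen token had the highest
# logprob in its row; any strictly larger entry counts as a higher-ranked
# competitor.
def _example_get_ranks() -> None:
    x = torch.tensor([[0.0, -1.0, -2.0],
                      [-3.0, -0.5, -0.25]])
    chosen = torch.tensor([2, 1])
    ranks = _get_ranks(x, chosen)
    # Row 0: -2.0 is the smallest of three -> rank 3.
    # Row 1: -0.5 is beaten only by -0.25 -> rank 2.
    assert ranks.tolist() == [3, 2]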


def get_logprobs(
    logprobs: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    sample_results: SampleResultType,
) -> Tuple[List[Optional[PromptLogprobs]], List[SampleLogprobs]]:
    """Return sample logprobs and prompt logprobs.

    The logic consists of 3 parts.
    - Select indices to compute logprob from, ranks of token ids, and
        the top k token ids from logprobs.
    - Compute prompt logprobs if required.
    - Compute sample logprobs if required.

    Args:
        logprobs: (num_query_tokens_across_batch, num_vocab). Each query token's
            logprob per vocab. Sequence groups' query tokens are batched in a
            single flattened tensor. For example, assuming there are N
            seq groups, it is sorted by prefill tokens for seq_group_1 (if
            prompt logprob is enabled), decode tokens for seq_group_1 (if
            sampling is required), prefill tokens for seq_group_2, ...
        sampling_metadata: The sampling metadata.
        sample_results: (num_seq_groups) The tuple of (next_token_ids,
            parent_ids) for each sequence group. When beam search is enabled,
            sample_results can contain different number of seq_ids from
            sampling_metadata.seq_groups. It is because beam search creates
            2 * BEAM_WIDTH number of samples (whereas there are only up to
            BEAM_WIDTH number of seq_ids).

    Returns:
        A tuple of prompt and sample logprobs per sequence group in a batch.
    """
    # The index of query token to calculate logprobs. It includes both
    # prompt and sample logprob indices.
    query_indices: List[int] = []
    # The next token ids to get the logprob value from.
    next_token_ids: List[int] = []
    # The largest requested number of logprobs. We find logprobs as many as the
    # largest num logprobs in this API. If every logprobs is None, it will be
    # set to -1.
    largest_num_logprobs = -1

    # Select indices to compute logprob from, ranks of token ids, and the top
    # k token ids from logprobs.
    for (seq_group, sample_result) in zip(sampling_metadata.seq_groups,
                                          sample_results):
        sampling_params = seq_group.sampling_params

        # Update indices and tokens for prompt logprobs.
        if (seq_group.is_prompt
                and sampling_params.prompt_logprobs is not None):
            largest_num_logprobs = max(largest_num_logprobs,
                                       sampling_params.prompt_logprobs)
            next_prompt_tokens = _get_next_prompt_tokens(seq_group)
            query_indices.extend(seq_group.prompt_logprob_indices)
            next_token_ids.extend(next_prompt_tokens)

        # Update indices and next tokens for sample logprob.
        if seq_group.do_sample:
            token_ids, parent_seq_ids = sample_result
            # NOTE: We cannot directly use sample_indices because
            # sample_indices only contain parent seq_ids of a previous step.
            # The current step may have different number of seq_ids, and
            # we can obtain it from `sample_result[1]`.
            query_idx = seq_group.sample_indices[0]
            query_indices.extend(
                [query_idx + parent_id for parent_id in parent_seq_ids])
            next_token_ids.extend(token_ids)

            if sampling_params.logprobs is not None:
                largest_num_logprobs = max(largest_num_logprobs,
                                           sampling_params.logprobs)

    assert len(next_token_ids) == len(query_indices)

    if len(query_indices) == 0:
        empty_sampled_logprob: SampleLogprobs = []
        empty_prompt_logprob: Optional[PromptLogprobs] = None
        return [empty_prompt_logprob], [empty_sampled_logprob]

    selected_logprobs, ranks = None, None
    top_logprobs, top_token_ids = None, None

    # If largest_num_logprobs == -1, i.e. no logprobs are requested, we can
    # skip the whole logprob calculation.
    if largest_num_logprobs >= 0:
        query_indices_gpu = torch.tensor(query_indices, device=logprobs.device)
        next_token_ids_gpu = torch.tensor(next_token_ids,
                                          device=logprobs.device)

        # (num_selected_query_tokens, num_logprobs). Note that query_indices can
        # contain duplicates if beam search is enabled.
        selected_logprobs = logprobs[[
            query_indices_gpu,
            next_token_ids_gpu,
        ]]
        ranks = _get_ranks(
            logprobs[query_indices_gpu],
            next_token_ids_gpu,
        )
        assert selected_logprobs.shape[0] == ranks.shape[0]

        # We need to compute top k only if there exists logprobs > 0.
        if largest_num_logprobs > 0:
            # Logprobs of topk tokens for a batch of sequence groups.
            # (num_query_tokens_across_batch).
            top_logprobs, top_token_ids = torch.topk(logprobs,
                                                     largest_num_logprobs,
                                                     dim=-1)
            top_logprobs = top_logprobs.to('cpu')
            top_token_ids = top_token_ids.to('cpu')

        selected_logprobs = selected_logprobs.to('cpu')
        ranks = ranks.to('cpu')

    # Find prompt/sample logprobs.
    prompt_logprobs_per_seq_group: List[Optional[PromptLogprobs]] = []
    sample_logprobs_per_seq_group: List[SampleLogprobs] = []
    top_logprob_idx = 0
    selected_logprobs_idx = 0

    for seq_group, sample_result in zip(sampling_metadata.seq_groups,
                                        sample_results):
        (prompt_logprobs, top_logprob_idx,
         selected_logprobs_idx) = _get_prompt_logprob_if_needed(
             seq_group, selected_logprobs, ranks, top_token_ids, top_logprobs,
             selected_logprobs_idx, top_logprob_idx)
        prompt_logprobs_per_seq_group.append(prompt_logprobs)

        (sampled_logprobs, top_logprob_idx,
         selected_logprobs_idx) = _get_sampled_logprob_if_needed(
             seq_group, sample_result, selected_logprobs, ranks, top_token_ids,
             top_logprobs, selected_logprobs_idx, top_logprob_idx)
        sample_logprobs_per_seq_group.append(sampled_logprobs)

    return prompt_logprobs_per_seq_group, sample_logprobs_per_seq_group


def _get_prompt_logprob_if_needed(
    seq_group: SequenceGroupToSample,
    selected_logprobs: torch.Tensor,
    ranks: torch.Tensor,
    top_token_ids: torch.Tensor,
    top_logprobs: torch.Tensor,
    selected_logprobs_idx: int,
    top_logprob_idx: int,
):
    """Compute the prompt logprob from a sequence group if needed."""
    sampling_params = seq_group.sampling_params
    is_prompt = seq_group.is_prompt

    # Find prompt logprobs
    prompt_logprobs: Optional[PromptLogprobs] = None
    if is_prompt and sampling_params.prompt_logprobs is not None:
        prompt_logprobs = []
        num_logprobs = sampling_params.prompt_logprobs
        next_prompt_tokens = _get_next_prompt_tokens(seq_group)
        # Pre-select indexes and create a list. It is faster than calling .item
        # repetitively.
        selected_logprob_items = selected_logprobs[
            selected_logprobs_idx:selected_logprobs_idx +
            len(next_prompt_tokens)].tolist()
        rank_items = ranks[selected_logprobs_idx:selected_logprobs_idx +
                           len(next_prompt_tokens)].tolist()

        for idx, token_id in enumerate(next_prompt_tokens):
            # Calculate the prompt logprob of the real prompt tokens.
            # {token_id: (logprob, rank_from_vocab)}
            prompt_logprobs_dict: Dict[int, Tuple[float, int]] = {
                token_id: (selected_logprob_items[idx], rank_items[idx])
            }

            # Add top K prompt logprobs along with its rank.
            if num_logprobs > 0:
                top_ids = top_token_ids[
                    top_logprob_idx, :num_logprobs].tolist()
                top_probs = top_logprobs[
                    top_logprob_idx, :num_logprobs].tolist()
                # Top K is already sorted by rank, so we can use 1 ~
                # num_logprobs + 1 for rank.
                top_ranks = range(1, num_logprobs + 1)
                prompt_logprobs_dict.update({
                    top_id: (top_prob, rank)
                    for top_id, top_prob, rank in zip(top_ids, top_probs,
                                                      top_ranks)
                })
            prompt_logprobs.append({
                token_id: Logprob(*logprob_and_rank)
                for token_id, logprob_and_rank in prompt_logprobs_dict.items()
            })
            # + 1 to go to the next prompt token.
            top_logprob_idx += 1

        # + len(next_prompt_tokens) to go to the next prompt.
        selected_logprobs_idx += len(next_prompt_tokens)
    return prompt_logprobs, top_logprob_idx, selected_logprobs_idx


def _get_sampled_logprob_if_needed(
    seq_group: SequenceGroupToSample,
    sample_result: Tuple[List[int], List[int]],
    selected_logprobs: torch.Tensor,
    ranks: torch.Tensor,
    top_token_ids: torch.Tensor,
    top_logprobs: torch.Tensor,
    selected_logprobs_idx: int,
    top_logprob_idx: int,
):
    """Compute the sample logprob if needed."""
    seq_ids = seq_group.seq_ids
    num_logprobs = seq_group.sampling_params.logprobs
    sampled_logprobs: SampleLogprobs = []
    next_token_ids, parent_seq_ids = sample_result

    if seq_group.do_sample:
        assert len(next_token_ids) > 0
        if num_logprobs is None:
            for next_token_id in next_token_ids:
                # Use a dummy logprob
                sampled_logprobs.append({next_token_id: Logprob(inf)})
        else:
            # Pre-select items from tensor. tolist() is faster than repetitive
            # `.item()` calls.
            selected_logprob_items = selected_logprobs[
                selected_logprobs_idx:selected_logprobs_idx +
                len(next_token_ids)].tolist()
            rank_items = ranks[selected_logprobs_idx:selected_logprobs_idx +
                               len(next_token_ids)].tolist()
            for idx, (next_token_id, parent_id) in enumerate(
                    zip(next_token_ids, parent_seq_ids)):
                # Get the logprob of a sampled token.
                sampled_logprobs_dict = {
                    next_token_id:
                    (selected_logprob_items[idx], rank_items[idx])
                }
                if num_logprobs is not None and num_logprobs > 0:
                    # Get top K logprobs.
                    top_ids = top_token_ids[top_logprob_idx +
                                            parent_id, :num_logprobs].tolist()
                    top_probs = top_logprobs[
                        top_logprob_idx + parent_id, :num_logprobs].tolist()
                    # Top K is already sorted by rank, so we can use 1 ~
                    # num_logprobs + 1 for rank.
                    top_ranks = range(1, num_logprobs + 1)
                    sampled_logprobs_dict.update({
                        top_id: (top_prob, rank)
                        for top_id, top_prob, rank in zip(
                            top_ids, top_probs, top_ranks)
                    })

                sampled_logprobs.append({
                    token_id: Logprob(*logprob_and_rank)
                    for token_id, logprob_and_rank in
                    sampled_logprobs_dict.items()
                })

        # NOTE: This part of code is not intuitive. `selected_logprobs` include
        # logprobs for the current step, which has len(next_token_ids) tokens
        # per sequence group. `logprobs` includes logprobs from the previous
        # steps, which has len(seq_ids) tokens per sequence group.

        # Iterate to the next sequence group in a batch.
|
| 1145 |
+
selected_logprobs_idx += len(next_token_ids)
|
| 1146 |
+
# Iterate to the next sequence group in a batch.
|
| 1147 |
+
top_logprob_idx += len(seq_ids)
|
| 1148 |
+
return sampled_logprobs, top_logprob_idx, selected_logprobs_idx
|
| 1149 |
+
|
| 1150 |
+
|
| 1151 |
+
def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor,
|
| 1152 |
+
sample_indices: torch.Tensor,
|
| 1153 |
+
greedy_samples: torch.Tensor) -> None:
|
| 1154 |
+
"""Modify the probability distributions of the greedily-sampled tokens such
|
| 1155 |
+
that each sampled token has a "probability" of 1.0. This is required by
|
| 1156 |
+
speculative decoding, which depends on the sampling method being encoded
|
| 1157 |
+
within the probability distribution for correctness.
|
| 1158 |
+
|
| 1159 |
+
# Why do we only need to do this for greedy sampling?
|
| 1160 |
+
|
| 1161 |
+
vLLM's sampler performs the following steps for greedy or multinomial
|
| 1162 |
+
(random) sampling:
|
| 1163 |
+
1. Get logits from model.
|
| 1164 |
+
2. Modify logits according to per-sequence sampling parameters.
|
| 1165 |
+
- Multiply by temperature, top-k and top-p masking, penalize tokens
|
| 1166 |
+
according to their frequency, etc.
|
| 1167 |
+
3. Sample a token.
|
| 1168 |
+
- Random sampling simply samples from the modified probability
|
| 1169 |
+
distribution.
|
| 1170 |
+
- Greedy sampling performs `argmax` to obtain the token with the
|
| 1171 |
+
highest likelihood.
|
| 1172 |
+
|
| 1173 |
+
Ignoring greedy sampling for a moment, we find that the computed probability
|
| 1174 |
+
distribution has the following property: we can sample from it independently
|
| 1175 |
+
and find that the token sampled by the Sampler has a frequency corresponding
|
| 1176 |
+
to how often we see it in our sampling. In other words, for tokens sampled
|
| 1177 |
+
with vLLM's random SamplingType, the computed probability distribution
|
| 1178 |
+
encodes the sampling methodology completely.
|
| 1179 |
+
|
| 1180 |
+
Greedy sampling does not normally have this property. vLLM modifies logits
|
| 1181 |
+
according to sampling params, then performs `argmax`, then returns the
|
| 1182 |
+
sampled token and the computed probability distribution. If we sample from
|
| 1183 |
+
the distribution, we'll find the likelihood of the greedily-sampled token
|
| 1184 |
+
is not always 1.0.
|
| 1185 |
+
|
| 1186 |
+
Since lossless speculative decoding requires that the sampling methodology
|
| 1187 |
+
be encoded within the probability distribution, we are motivated to modify
|
| 1188 |
+
the probability distribution such that the sampled token has probability 1
|
| 1189 |
+
when speculative decoding is used.
|
| 1190 |
+
|
| 1191 |
+
NOTE: Alternatively, we could use an extremely low temperature to achieve
|
| 1192 |
+
greedy sampling using multinomial computation and unite the codepaths. This
|
| 1193 |
+
has implications on the overall design of the sampler, e.g. how to record
|
| 1194 |
+
accurate logprobs for the user, so this improvement is deferred to later.
|
| 1195 |
+
"""
|
| 1196 |
+
# NOTE: logprobs are not modified so they can be returned to the user.
|
| 1197 |
+
probs[sample_indices, :] = 0
|
| 1198 |
+
probs[sample_indices, greedy_samples] = 1.0
|
| 1199 |
+
|
| 1200 |
+
|
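The docstring above explains why greedy rows must end up as one-hot distributions. As a minimal sketch (toy shapes and values, not part of this diff), the two in-place assignments behave like this:

import torch

# Hypothetical toy batch: 4 rows of probabilities; rows 1 and 3 were sampled greedily.
probs = torch.softmax(torch.randn(4, 8), dim=-1)
sample_indices = torch.tensor([1, 3])
greedy_samples = probs[sample_indices].argmax(dim=-1)

# Same two statements as _modify_greedy_probs_inplace: zero the greedy rows,
# then put all probability mass on the argmax token of each row.
probs[sample_indices, :] = 0
probs[sample_indices, greedy_samples] = 1.0

# The modified rows are exact one-hot distributions that encode argmax sampling.
assert torch.allclose(probs[sample_indices].sum(dim=-1), torch.ones(2))
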
def _build_sampler_output(
    maybe_deferred_sample_results: MaybeDeferredSampleResultType,
    sampling_metadata: SamplingMetadata,
    prompt_logprobs: Optional[List[Optional[PromptLogprobs]]],
    sample_logprobs: Optional[List[SampleLogprobs]],
    on_device_tensors: Optional[Tuple[torch.Tensor, torch.Tensor,
                                      torch.Tensor]],
    skip_sampler_cpu_output: bool = False,
) -> SamplerOutput:
    """Construct Python objects with the output of sampling.

    Args:
        on_device_tensors: Tuple containing on-device tensors with the
            probabilities used in sampling and the sampled token ids. This
            allows post-processing without copies to CPU/serialization, e.g. in
            speculative decoding rejection sampling.
    """
    sampler_output: List[CompletionSequenceGroupOutput] = []

    if skip_sampler_cpu_output:
        assert isinstance(maybe_deferred_sample_results, SampleResultArgsType)
        deferred_sample_results_args = maybe_deferred_sample_results
    else:
        assert prompt_logprobs is not None
        assert sample_logprobs is not None
        assert not isinstance(maybe_deferred_sample_results,
                              SampleResultArgsType)
        deferred_sample_results_args = None

        for (seq_group, sample_result, group_prompt_logprobs,
             group_sample_logprobs) in zip(sampling_metadata.seq_groups,
                                           maybe_deferred_sample_results,
                                           prompt_logprobs, sample_logprobs):
            seq_ids = seq_group.seq_ids
            next_token_ids, parent_ids = sample_result
            seq_outputs: List[SequenceOutput] = []
            for parent_id, next_token_id, logprobs in zip(
                    parent_ids, next_token_ids, group_sample_logprobs):
                seq_outputs.append(
                    SequenceOutput(seq_ids[parent_id], next_token_id,
                                   logprobs))
            sampler_output.append(
                CompletionSequenceGroupOutput(seq_outputs,
                                              group_prompt_logprobs))

    # If not specified, store None values in SamplerOutput.
    if on_device_tensors is not None:
        (sampled_token_probs, logprobs_tensor,
         sampled_token_ids) = on_device_tensors
    else:
        sampled_token_probs, logprobs_tensor, sampled_token_ids = (None, None,
                                                                   None)

    return SamplerOutput(
        outputs=sampler_output,
        sampled_token_probs=sampled_token_probs,
        sampled_token_ids=sampled_token_ids,
        logprobs=logprobs_tensor,
        deferred_sample_results_args=deferred_sample_results_args)


def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[int]:
    """Get a list of next prompt tokens to compute logprob from a
    given sequence group.

    It is used to compute prompt logprob. Imagine you have logprob for each
    query token. Query token needs to know the next prompt token id to compute
    prompt logprob. This is a helper to obtain next prompt token ids.

    This API has to be used only when the caller knows seq_group is in prefill
    stage.

    Returns:
        A list of next prompt tokens to compute logprob.
    """
    assert seq_group.is_prompt, (
        "Caller should ensure the sequence group is in a prefill stage.")
    seq_ids = seq_group.seq_ids
    query_len = seq_group.query_len
    assert query_len is not None
    # prompt has only 1 seq id.
    assert len(seq_ids) == 1
    seq_data = seq_group.seq_data[seq_ids[0]]
    computed_len = seq_data.get_num_computed_tokens()
    prompt_tokens = seq_data.prompt_token_ids
    # +1 because we are looking for a next prompt token.
    next_token_index_start = computed_len + 1
    next_token_index_end = min(computed_len + query_len + 1,
                               len(prompt_tokens))
    next_prompt_tokens = prompt_tokens[
        next_token_index_start:next_token_index_end]
    return next_prompt_tokens

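As a hedged illustration of the index arithmetic in _get_next_prompt_tokens, the toy numbers below (made up, not taken from vLLM) walk an 8-token prompt through two prefill chunks of 4 query tokens each:

# Hypothetical chunked prefill: 8 prompt tokens, processed 4 query tokens at a time.
prompt_tokens = [11, 12, 13, 14, 15, 16, 17, 18]

for computed_len, query_len in [(0, 4), (4, 4)]:
    # +1 because the logprob of a query token targets the *next* prompt token.
    start = computed_len + 1
    end = min(computed_len + query_len + 1, len(prompt_tokens))
    print(prompt_tokens[start:end])
# first chunk  -> [12, 13, 14, 15]
# second chunk -> [16, 17, 18]  (the final query token has no next prompt token)
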
.venv/lib/python3.11/site-packages/vllm/model_executor/layers/spec_decode_base_sampler.py
ADDED
@@ -0,0 +1,256 @@
# SPDX-License-Identifier: Apache-2.0

from abc import abstractmethod
from typing import Dict, Optional, Union

import torch
import torch.jit
import torch.nn as nn


class SpecDecodeBaseSampler(nn.Module):
    """Base class for samplers used for Speculative Decoding verification
    step.
    """

    def __init__(self, strict_mode: bool = False):
        """Base class constructor.
        Args:
            strict_mode: Whether or not to perform shape/device/dtype checks
                during sampling. This catches correctness issues but adds
                nontrivial latency.
        """
        super().__init__()
        self._strict_mode = strict_mode

        # NOTE: A "bonus token" is accepted iff all proposal tokens are
        # accepted. There is always only one possible bonus token. We store this
        # value in a variable for readability.
        self._num_bonus_tokens = 1

        self.num_accepted_tokens: Optional[torch.Tensor] = None
        self.num_emitted_tokens: Optional[torch.Tensor] = None
        self.num_draft_tokens: int = 0

    def init_gpu_tensors(self, device: Union[int, str]) -> None:
        assert self.num_accepted_tokens is None
        if isinstance(device, int):
            device = f"cuda:{device}"
        elif not isinstance(device, str):
            raise ValueError(f"Device must be int or str, get {type(device)}")
        self.num_accepted_tokens = torch.tensor(0,
                                                dtype=torch.long,
                                                device=device)
        self.num_emitted_tokens = torch.tensor(0,
                                               dtype=torch.long,
                                               device=device)

    def init_tensors(self,
                     device: Union[int, str],
                     device_type: Union[torch.device, str] = 'cuda') -> None:
        assert self.num_accepted_tokens is None
        if isinstance(device_type, torch.device):
            device_type = device_type.type
        if isinstance(device, int):
            device = f"{device_type}:{device}"
        self.num_accepted_tokens = torch.tensor(0,
                                                dtype=torch.long,
                                                device=device)
        self.num_emitted_tokens = torch.tensor(0,
                                               dtype=torch.long,
                                               device=device)

    @property
    def probs_dtype(self):
        return torch.float32

    @property
    def token_id_dtype(self):
        return torch.int64

    def _create_output(
            self,
            accepted: torch.Tensor,  # [batch_size, k]
            substitute_token_ids: torch.Tensor,  # [batch_size, k]
            draft_token_ids: torch.Tensor,  # [batch_size, k]
            bonus_token_ids: torch.Tensor,  # [batch_size]
    ) -> torch.Tensor:
        """Format output. Returns a matrix of token ids. When
        a token is rejected via sampling, all subsequent token ids are
        set to -1 for the sequence.

        Args:
            accepted: A boolean tensor indicating if the corresponding
                draft token in draft_token_ids should be accepted or not.
            substitute_token_ids: A tensor of token_ids that can be used
                as substitutes for the draft token ids if the proposed token
                is rejected.
            draft_token_ids: A tensor of token ids speculated by the
                draft model.
            bonus_token_ids: Token ids to use as the bonus token if
                all the draft tokens are accepted.
        Returns:
            A tensor containing the accepted token ids. The shape of the
            tensor is [batch_size, k + num_bonus_tokens]
        """
        batch_size, k = substitute_token_ids.shape
        bonus_token_ids = bonus_token_ids.squeeze(-1)
        # Determine the index of the first False value for each row.
        limits = (accepted == 0).max(1).indices
        limits[~(accepted == 0).any(1)] = k

        # Create masks using the indices.
        indices = torch.arange(k, device=accepted.device).unsqueeze(0)
        accepted_mask = indices < limits.unsqueeze(1)
        after_false_mask = indices == limits.unsqueeze(1)

        # Create an extended output tensor
        output_with_bonus_tokens = -torch.ones(
            (batch_size, k + self._num_bonus_tokens),
            dtype=self.token_id_dtype,
            device=accepted.device)
        output = output_with_bonus_tokens[:, :k]

        # Fill in the first k columns of the output tensor using masks and data
        # tensors.
        output[:, :k] = torch.where(accepted_mask, draft_token_ids,
                                    -torch.ones_like(draft_token_ids))

        # Fill the last column.
        # We check output directly as accepted may have True values inconsistent
        # with causal acceptance.
        output_with_bonus_tokens[:, -1] = torch.where(output[:, -1] != -1,
                                                      bonus_token_ids, -1)

        # Fill the recovered token ids.
        output.mul_(~after_false_mask).add_(
            substitute_token_ids.mul(after_false_mask))

        self.num_accepted_tokens += accepted.sum()
        self.num_emitted_tokens += (output_with_bonus_tokens != -1).sum()
        self.num_draft_tokens += batch_size * k

        return output_with_bonus_tokens

    def _raise_if_incorrect_input(
        self,
        target_with_bonus_probs: torch.Tensor,
        draft_token_ids: torch.Tensor,
        bonus_token_ids: torch.Tensor,
        draft_probs: Optional[torch.Tensor] = None,
    ) -> None:
        self._raise_if_incorrect_shape(target_with_bonus_probs,
                                       draft_token_ids, bonus_token_ids,
                                       draft_probs)
        self._raise_if_incorrect_dtype(target_with_bonus_probs,
                                       draft_token_ids, bonus_token_ids,
                                       draft_probs)
        self._raise_if_inconsistent_device(target_with_bonus_probs,
                                           draft_token_ids, bonus_token_ids,
                                           draft_probs)
        self._raise_if_out_of_bounds_vocab(target_with_bonus_probs.shape[-1],
                                           draft_token_ids, bonus_token_ids)

    def _raise_if_incorrect_shape(
        self,
        target_with_bonus_probs: torch.Tensor,
        draft_token_ids: torch.Tensor,
        bonus_token_ids: torch.Tensor,
        draft_probs: Optional[torch.Tensor] = None,
    ) -> None:
        (target_batch_size, num_target_probs,
         target_vocab_size) = target_with_bonus_probs.shape

        # Does not count the extra token
        num_target_probs -= 1

        # validate the shape of draft token ids.
        draft_token_ids_batch_size, num_draft_token_ids = draft_token_ids.shape
        assert draft_token_ids_batch_size == target_batch_size
        assert num_draft_token_ids == num_target_probs

        # validate the shape of bonus token ids
        bonus_batch_size, num_bonus_tokens = bonus_token_ids.shape
        assert bonus_batch_size == target_batch_size
        assert num_bonus_tokens == self._num_bonus_tokens

        # validate the shape of draft probs if it is set
        if draft_probs is not None:
            (draft_batch_size, num_draft_probs,
             draft_vocab_size) = draft_probs.shape
            assert draft_batch_size == target_batch_size
            assert num_draft_probs == num_target_probs
            assert (draft_vocab_size == target_vocab_size
                    ), f"{draft_vocab_size=} {target_vocab_size=}"

    def _raise_if_incorrect_dtype(
        self,
        target_with_bonus_probs: torch.Tensor,
        draft_token_ids: torch.Tensor,
        bonus_token_ids: torch.Tensor,
        draft_probs: Optional[torch.Tensor] = None,
    ) -> None:
        assert target_with_bonus_probs.dtype == self.probs_dtype
        assert draft_token_ids.dtype == self.token_id_dtype
        assert bonus_token_ids.dtype == self.token_id_dtype
        if draft_probs is not None:
            assert draft_probs.dtype == self.probs_dtype

    def _raise_if_inconsistent_device(
        self,
        target_with_bonus_probs: torch.Tensor,
        draft_token_ids: torch.Tensor,
        bonus_token_ids: torch.Tensor,
        draft_probs: Optional[torch.Tensor] = None,
    ) -> None:
        devices = [
            t.device for t in [
                target_with_bonus_probs, bonus_token_ids, draft_probs,
                draft_token_ids
            ] if t is not None
        ]
        assert all([devices[0] == device for device in devices])

    def _raise_if_out_of_bounds_vocab(
        self,
        vocab_size: int,
        draft_token_ids: torch.Tensor,
        bonus_token_ids: torch.Tensor,
    ) -> None:
        assert torch.all(bonus_token_ids < vocab_size)
        assert torch.all(bonus_token_ids >= 0)
        assert torch.all(draft_token_ids < vocab_size)
        assert torch.all(draft_token_ids >= 0)


class SpecDecodeDeterministicBaseSampler(SpecDecodeBaseSampler):
    """Base class for samplers used for Speculative Decoding verification
    step which are deterministic.
    """

    @abstractmethod
    def forward(
        self,
        target_with_bonus_probs: torch.Tensor,
        bonus_token_ids: torch.Tensor,
        draft_probs: torch.Tensor,
        draft_token_ids: torch.Tensor,
    ) -> torch.Tensor:
        raise NotImplementedError


class SpecDecodeStochasticBaseSampler(SpecDecodeBaseSampler):
    """Base class for samplers used for Speculative Decoding verification
    step which are stochastic
    """

    @abstractmethod
    def forward(
        self,
        target_with_bonus_probs: torch.Tensor,
        bonus_token_ids: torch.Tensor,
        draft_probs: torch.Tensor,
        draft_token_ids: torch.Tensor,
        seeded_seqs: Optional[Dict[int, torch.Generator]] = None,
    ) -> torch.Tensor:
        raise NotImplementedError

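A small standalone sketch of the accept/substitute/bonus layout that _create_output produces. The tensors and token ids below are invented for illustration and only re-run the same masking arithmetic for a single sequence with k=3 draft tokens:

import torch

accepted = torch.tensor([[True, False, True]])        # second draft token rejected
draft_token_ids = torch.tensor([[101, 102, 103]])
substitute_token_ids = torch.tensor([[201, 202, 203]])

k = draft_token_ids.shape[1]
limits = (accepted == 0).max(1).indices                # first rejection per row
limits[~(accepted == 0).any(1)] = k                    # rows with no rejection keep all k drafts
indices = torch.arange(k).unsqueeze(0)
accepted_mask = indices < limits.unsqueeze(1)          # drafts before the first rejection
after_false_mask = indices == limits.unsqueeze(1)      # the position that gets the substitute

out = torch.where(accepted_mask, draft_token_ids,
                  torch.full_like(draft_token_ids, -1))
out = out * ~after_false_mask + substitute_token_ids * after_false_mask
print(out)  # tensor([[101, 202,  -1]]): accepted draft, recovered token, then -1 (no bonus token)
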
.venv/lib/python3.11/site-packages/vllm/model_executor/layers/typical_acceptance_sampler.py
ADDED
@@ -0,0 +1,172 @@
# SPDX-License-Identifier: Apache-2.0

import torch
import torch.jit

from vllm.model_executor.layers.spec_decode_base_sampler import (
    SpecDecodeDeterministicBaseSampler)


class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
    """Apply typical acceptance sampling as described in section 3.3.1 in
    "MEDUSA: Simple LLM Inference Acceleration Framework with
    Multiple Decoding Heads"
    https://arxiv.org/pdf/2401.10774
    """

    def __init__(
        self,
        posterior_threshold: float,
        posterior_alpha: float,
        strict_mode: bool = False,
    ):
        """Create a Typical Acceptance Sampler.

        Args:
            strict_mode: Whether or not to perform shape/device/dtype checks
                during sampling. This catches correctness issues but adds
                nontrivial latency.
            posterior_threshold : A threshold value that sets a lower bound
                on the posterior probability of a token in target model for it
                to be accepted.
            posterior_alpha : A scaling factor for the entropy-based
                threshold in typical acceptance sampling.
        """
        self._posterior_threshold = posterior_threshold
        self._posterior_alpha = posterior_alpha
        super().__init__(strict_mode=strict_mode)

    def forward(
        self,
        target_with_bonus_probs: torch.Tensor,
        bonus_token_ids: torch.Tensor,
        draft_probs: torch.Tensor,
        draft_token_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Sample token ids using typical acceptance sampling. This accepts
        or rejects tokens proposed by the draft model using the probability
        of each token according to the draft and target models.

        In the worst case where all draft tokens are rejected, it is guaranteed
        one token will be emitted.

        In the case where all draft tokens are accepted, the bonus token will be
        accepted.

        Args:
            target_probs: The probability distribution over token ids given
                context according to the target model.
                shape = [batch_size, num_speculative_tokens, vocab_size]

            bonus_token_ids: The "bonus" token ids that are accepted iff all
                speculative tokens in a sequence are accepted.
                shape = [batch_size, num_bonus_tokens]

            draft_probs: This parameter is unused by the acceptance sampler.

            draft_token_ids: The token ids that were sampled from the draft
                probabilities.
                shape = [batch_size, num_speculative_tokens]

        Returns:
            output_token_ids: The token ids sampled via rejection sampling,
                or -1 if unable to sample a token because the previous token
                was rejected.
                shape = [batch_size, num_speculative_tokens + num_bonus_tokens]
        """
        # Only perform shape/dtype/device checking in strict mode, as it adds
        # overhead.
        if self._strict_mode:
            self._raise_if_incorrect_input(target_with_bonus_probs,
                                           draft_token_ids, bonus_token_ids)
        target_probs = target_with_bonus_probs[:, :-1]
        accepted = self._evaluate_accepted_tokens(target_probs,
                                                  draft_token_ids)
        recovered_token_ids = self._get_recovered_token_ids(target_probs)
        output_token_ids = self._create_output(accepted, recovered_token_ids,
                                               draft_token_ids,
                                               bonus_token_ids)
        return output_token_ids

    def _evaluate_accepted_tokens(self, target_probs, draft_token_ids):
        r"""
        Evaluates and returns a mask of accepted tokens based on the
        posterior probabilities.

        Parameters:
        ----------
        target_probs : torch.Tensor
            A tensor of shape (batch_size, k, vocab_size) representing
            the probabilities of each token in the vocabulary for each
            position in the proposed sequence. This is the distribution
            generated by the target model.
        draft_token_ids : torch.Tensor
            A tensor of shape (batch_size, k) representing the proposed
            token ids.

        A draft token_id x_{n+k} is accepted if it satisfies the
        following condition

        .. math::
            p_{\text{original}}(x_{n+k} | x_1, x_2, \dots, x_{n+k-1}) >
            \min \left( \epsilon, \delta * \exp \left(
                -H(p_{\text{original}}(
                    \cdot | x_1, x_2, \ldots, x_{n+k-1})) \right) \right)

        where :math:`p_{\text{original}}` corresponds to target_probs
        and :math:`\epsilon` and :math:`\delta` correspond to hyperparameters
        specified using self._posterior_threshold and self._posterior_alpha

        This method computes the posterior probabilities for the given
        draft token ids based on the provided target probabilities. It
        calculates the entropy of the posterior distribution and determines
        a dynamic threshold for each token position using the provided
        posterior_threshold and posterior_alpha values. The method then
        returns a boolean mask indicating which tokens can be accepted.

        Returns:
        -------
        torch.Tensor
            A boolean tensor of shape (batch_size, k) where each element
            indicates whether the corresponding draft token has been accepted
            or rejected. True indicates acceptance and false indicates
            rejection.

        """
        device = target_probs.device
        candidates_prob = torch.gather(
            target_probs, dim=-1,
            index=draft_token_ids.unsqueeze(-1)).squeeze(-1)
        # A small constant added to prevent computing the logarithm of zero,
        # which can lead to undefined values.
        epsilon = 1e-5
        posterior_entropy = -torch.sum(
            target_probs * torch.log(target_probs + epsilon), dim=-1)
        threshold = torch.minimum(
            torch.ones_like(posterior_entropy, device=device) *
            self._posterior_threshold,
            torch.exp(-posterior_entropy) * self._posterior_alpha,
        )
        accepted_mask = candidates_prob > threshold
        return accepted_mask

    def _get_recovered_token_ids(self, target_probs):
        """
        The recovered token ids will fill the first unmatched token
        by the target token.

        Parameters
        ----------
        target_probs : torch.Tensor
            A tensor of shape (batch_size, k, vocab_size) containing
            the target probability distribution

        Returns
        -------
        torch.Tensor
            A tensor of shape (batch_size, k) with the recovered token
            ids which are selected from target probs.
        """
        max_indices = torch.argmax(target_probs, dim=-1)

        return max_indices

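The entropy-based threshold in _evaluate_accepted_tokens can be checked on toy numbers. The hyperparameter values below are assumptions for illustration, not defaults taken from this file:

import torch

posterior_threshold, posterior_alpha = 0.09, 0.3
target_probs = torch.tensor([[[0.97, 0.01, 0.01, 0.01],    # peaked distribution
                              [0.25, 0.25, 0.25, 0.25]]])  # flat distribution
entropy = -torch.sum(target_probs * torch.log(target_probs + 1e-5), dim=-1)
threshold = torch.minimum(
    torch.full_like(entropy, posterior_threshold),
    torch.exp(-entropy) * posterior_alpha)
print(threshold)  # approximately [[0.09, 0.075]]

A draft token is accepted when its target-model probability exceeds this per-position threshold: with these numbers, the low-entropy position is capped at posterior_threshold, while the high-entropy position gets the smaller exp(-H) * posterior_alpha bound.
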
.venv/lib/python3.11/site-packages/vllm/model_executor/layers/utils.py
ADDED
@@ -0,0 +1,58 @@
# SPDX-License-Identifier: Apache-2.0
"""Utility methods for model layers."""
from typing import Tuple

import torch


def get_token_bin_counts_and_mask(
    tokens: torch.Tensor,
    vocab_size: int,
    num_seqs: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Compute the bin counts for the tokens.
    # vocab_size + 1 for padding.
    bin_counts = torch.zeros((num_seqs, vocab_size + 1),
                             dtype=torch.long,
                             device=tokens.device)
    bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens))
    bin_counts = bin_counts[:, :vocab_size]
    mask = bin_counts > 0

    return bin_counts, mask


def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
                    output_tokens_tensor: torch.Tensor,
                    presence_penalties: torch.Tensor,
                    frequency_penalties: torch.Tensor,
                    repetition_penalties: torch.Tensor) -> torch.Tensor:
    """
    Applies penalties in place to the logits tensor
    logits : The input logits tensor of shape [num_seqs, vocab_size]
    prompt_tokens_tensor: A tensor containing the prompt tokens. The prompts
        are padded to the maximum prompt length within the batch using
        `vocab_size` as the padding value. The value `vocab_size` is used
        for padding because it does not correspond to any valid token ID
        in the vocabulary.
    output_tokens_tensor: The output tokens tensor.
    presence_penalties: The presence penalties of shape (num_seqs, )
    frequency_penalties: The frequency penalties of shape (num_seqs, )
    repetition_penalties: The repetition penalties of shape (num_seqs, )
    """
    num_seqs, vocab_size = logits.shape
    _, prompt_mask = get_token_bin_counts_and_mask(prompt_tokens_tensor,
                                                   vocab_size, num_seqs)
    output_bin_counts, output_mask = get_token_bin_counts_and_mask(
        output_tokens_tensor, vocab_size, num_seqs)
    repetition_penalties = repetition_penalties.unsqueeze_(dim=1).repeat(
        1, vocab_size)
    logits[logits > 0] /= torch.where(prompt_mask | output_mask,
                                      repetition_penalties, 1.0)[logits > 0]
    logits[logits <= 0] *= torch.where(prompt_mask | output_mask,
                                       repetition_penalties, 1.0)[logits <= 0]
    # We follow the definition in OpenAI API.
    # Refer to https://platform.openai.com/docs/api-reference/parameter-details
    logits -= frequency_penalties.unsqueeze_(dim=1) * output_bin_counts
    logits -= presence_penalties.unsqueeze_(dim=1) * output_mask
    return logits

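A possible usage sketch for apply_penalties, assuming vLLM is installed; the penalty values and token ids are made up, and the prompt rows are padded with vocab_size (6 here) as the docstring requires:

import torch

from vllm.model_executor.layers.utils import apply_penalties

# Two sequences, vocabulary of 6; token id 6 (== vocab_size) is the padding id.
logits = torch.zeros(2, 6)
logits[0, 2] = 2.0                                   # token 2 is favored for seq 0
prompt_tokens = torch.tensor([[2, 3, 6], [1, 6, 6]])
output_tokens = torch.tensor([[2, 2], [6, 6]])       # seq 0 already emitted token 2 twice
presence = torch.tensor([0.5, 0.0])
frequency = torch.tensor([0.1, 0.0])
repetition = torch.tensor([2.0, 1.0])

penalized = apply_penalties(logits, prompt_tokens, output_tokens,
                            presence, frequency, repetition)
print(penalized[0, 2])  # 2.0 / 2.0 (repetition) - 2 * 0.1 (frequency) - 0.5 (presence) = 0.3
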
.venv/lib/python3.11/site-packages/vllm/model_executor/layers/vocab_parallel_embedding.py
ADDED
@@ -0,0 +1,484 @@
# SPDX-License-Identifier: Apache-2.0

from dataclasses import dataclass
from typing import List, Optional, Sequence, Tuple

import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter, UninitializedParameter

from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
                              tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig, QuantizeMethodBase, method_has_implemented_embedding)
from vllm.model_executor.parameter import BasevLLMParameter
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform

DEFAULT_VOCAB_PADDING_SIZE = 64


class UnquantizedEmbeddingMethod(QuantizeMethodBase):
    """Unquantized method for embeddings."""

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Create weights for embedding layer."""
        weight = Parameter(torch.empty(sum(output_partition_sizes),
                                       input_size_per_partition,
                                       dtype=params_dtype),
                           requires_grad=False)
        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, extra_weight_attrs)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        return F.linear(x, layer.weight, bias)

    def embedding(self, layer: torch.nn.Module,
                  input_: torch.Tensor) -> torch.Tensor:
        return F.embedding(input_, layer.weight)


def pad_vocab_size(vocab_size: int,
                   pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int:
    """Pad the vocab size to the given value."""
    return ((vocab_size + pad_to - 1) // pad_to) * pad_to


def vocab_range_from_per_partition_vocab_size(
        per_partition_vocab_size: int,
        rank: int,
        offset: int = 0) -> Sequence[int]:
    index_f = rank * per_partition_vocab_size
    index_l = index_f + per_partition_vocab_size
    return index_f + offset, index_l + offset


def vocab_range_from_global_vocab_size(global_vocab_size: int,
                                       rank: int,
                                       world_size: int,
                                       offset: int = 0) -> Sequence[int]:
    per_partition_vocab_size = divide(global_vocab_size, world_size)
    return vocab_range_from_per_partition_vocab_size(per_partition_vocab_size,
                                                     rank,
                                                     offset=offset)


@dataclass
class VocabParallelEmbeddingShardIndices:
    """Indices for a shard of a vocab parallel embedding."""
    padded_org_vocab_start_index: int
    padded_org_vocab_end_index: int
    padded_added_vocab_start_index: int
    padded_added_vocab_end_index: int

    org_vocab_start_index: int
    org_vocab_end_index: int
    added_vocab_start_index: int
    added_vocab_end_index: int

    @property
    def num_org_elements(self) -> int:
        return self.org_vocab_end_index - self.org_vocab_start_index

    @property
    def num_added_elements(self) -> int:
        return self.added_vocab_end_index - self.added_vocab_start_index

    @property
    def num_org_elements_padded(self) -> int:
        return (self.padded_org_vocab_end_index -
                self.padded_org_vocab_start_index)

    @property
    def num_added_elements_padded(self) -> int:
        return (self.padded_added_vocab_end_index -
                self.padded_added_vocab_start_index)

    @property
    def num_org_vocab_padding(self) -> int:
        return self.num_org_elements_padded - self.num_org_elements

    @property
    def num_added_vocab_padding(self) -> int:
        return self.num_added_elements_padded - self.num_added_elements

    @property
    def num_elements_padded(self) -> int:
        return self.num_org_elements_padded + self.num_added_elements_padded

    def __post_init__(self):
        # sanity checks
        assert (self.padded_org_vocab_start_index
                <= self.padded_org_vocab_end_index)
        assert (self.padded_added_vocab_start_index
                <= self.padded_added_vocab_end_index)

        assert self.org_vocab_start_index <= self.org_vocab_end_index
        assert self.added_vocab_start_index <= self.added_vocab_end_index

        assert self.org_vocab_start_index <= self.padded_org_vocab_start_index
        assert (self.added_vocab_start_index
                <= self.padded_added_vocab_start_index)
        assert self.org_vocab_end_index <= self.padded_org_vocab_end_index
        assert self.added_vocab_end_index <= self.padded_added_vocab_end_index

        assert self.num_org_elements <= self.num_org_elements_padded
        assert self.num_added_elements <= self.num_added_elements_padded


@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
def get_masked_input_and_mask(
        input_: torch.Tensor, org_vocab_start_index: int,
        org_vocab_end_index: int, num_org_vocab_padding: int,
        added_vocab_start_index: int,
        added_vocab_end_index: int) -> Tuple[torch.Tensor, torch.Tensor]:
    # torch.compile will fuse all of the pointwise ops below
    # into a single kernel, making it very fast
    org_vocab_mask = (input_ >= org_vocab_start_index) & (
        input_ < org_vocab_end_index)
    added_vocab_mask = (input_ >= added_vocab_start_index) & (
        input_ < added_vocab_end_index)
    added_offset = added_vocab_start_index - (
        org_vocab_end_index - org_vocab_start_index) - num_org_vocab_padding
    valid_offset = (org_vocab_start_index *
                    org_vocab_mask) + (added_offset * added_vocab_mask)
    vocab_mask = org_vocab_mask | added_vocab_mask
    input_ = vocab_mask * (input_ - valid_offset)
    return input_, ~vocab_mask

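A hedged toy call of get_masked_input_and_mask, assuming vLLM is importable. The shard bounds roughly mirror the TP2/rank-1 layout described in the VocabParallelEmbedding docstring below: the original vocab of 1010 is padded to 1024, so rank 1 owns token ids [512, 1010) with 14 padding rows and no added vocabulary on this shard:

import torch

from vllm.model_executor.layers.vocab_parallel_embedding import (
    get_masked_input_and_mask)

input_ = torch.tensor([5, 512, 700, 1009])
masked, mask = get_masked_input_and_mask(input_,
                                         org_vocab_start_index=512,
                                         org_vocab_end_index=1010,
                                         num_org_vocab_padding=14,
                                         added_vocab_start_index=1026,
                                         added_vocab_end_index=1026)
print(masked)  # tensor([  0,   0, 188, 497]): in-shard ids are shifted to local row indices
print(mask)    # tensor([ True, False, False, False]): masked rows are zeroed after the lookup
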
+
class VocabParallelEmbedding(torch.nn.Module):
|
| 160 |
+
"""Embedding parallelized in the vocabulary dimension.
|
| 161 |
+
|
| 162 |
+
Adapted from torch.nn.Embedding, note that we pad the vocabulary size to
|
| 163 |
+
make sure it is divisible by the number of model parallel GPUs.
|
| 164 |
+
|
| 165 |
+
In order to support various loading methods, we ensure that LoRA-added
|
| 166 |
+
embeddings are always at the end of TP-sharded tensors. In other words,
|
| 167 |
+
we shard base embeddings and LoRA embeddings separately (both padded),
|
| 168 |
+
and place them in the same tensor.
|
| 169 |
+
In this example, we will have the original vocab size = 1010,
|
| 170 |
+
added vocab size = 16 and padding to 64. Therefore, the total
|
| 171 |
+
vocab size with padding will be 1088 (because we first pad 1010 to
|
| 172 |
+
1024, add 16, and then pad to 1088).
|
| 173 |
+
Therefore, the tensor format looks like the following:
|
| 174 |
+
TP1, rank 0 (no sharding):
|
| 175 |
+
|< --------BASE-------- >|< -BASE PADDING-- >|< -----LORA------ >|< -LORA PADDING-- >|
|
| 176 |
+
corresponding token_id: | 0 | 1 | ... | 1009 | -1 | ... | -1 | 1010 | ... | 1015 | -1 | ... | -1 |
|
| 177 |
+
index: | 0 | 1 | ... | 1009 | 1010 | ... | 1023 | 1024 | ... | 1039 | 1040 | ... | 1087 |
|
| 178 |
+
|
| 179 |
+
TP2, rank 0:
|
| 180 |
+
|< --------------------BASE--------------------- >|< -----LORA------ >|< -LORA PADDING- >|
|
| 181 |
+
corresponding token_id: | 0 | 1 | 2 | ... | 497 | 498 | ... | 511 | 1000 | ... | 1015 | -1 | ... | -1 |
|
| 182 |
+
index: | 0 | 1 | 2 | ... | 497 | 498 | ... | 511 | 512 | ... | 527 | 520 | ... | 543 |
|
| 183 |
+
TP2, rank 1:
|
| 184 |
+
|< -----------BASE----------- >|< -BASE PADDING- >|< -----------LORA PADDING----------- >|
|
| 185 |
+
corresponding token_id: | 512 | 513 | 514 | ... | 1009 | -1 | ... | -1 | -1 | ... | -1 | -1 | ... | -1 |
|
| 186 |
+
index: | 0 | 1 | 2 | ... | 497 | 498 | ... | 511 | 512 | ... | 519 | 520 | ... | 543 |
|
| 187 |
+
|
| 188 |
+
Args:
|
| 189 |
+
num_embeddings: vocabulary size.
|
| 190 |
+
embedding_dim: size of hidden state.
|
| 191 |
+
params_dtype: type of the parameters.
|
| 192 |
+
org_num_embeddings: original vocabulary size (without LoRA).
|
| 193 |
+
padding_size: padding size for the vocabulary.
|
| 194 |
+
quant_config: quant config for the layer
|
| 195 |
+
prefix: full name of the layer in the state dict
|
| 196 |
+
""" # noqa: E501
|
| 197 |
+
|
| 198 |
+
def __init__(self,
|
| 199 |
+
num_embeddings: int,
|
| 200 |
+
embedding_dim: int,
|
| 201 |
+
params_dtype: Optional[torch.dtype] = None,
|
| 202 |
+
org_num_embeddings: Optional[int] = None,
|
| 203 |
+
padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
|
| 204 |
+
quant_config: Optional[QuantizationConfig] = None,
|
| 205 |
+
prefix: str = ""):
|
| 206 |
+
super().__init__()
|
| 207 |
+
|
| 208 |
+
# Keep the input dimensions.
|
| 209 |
+
tp_rank = get_tensor_model_parallel_rank()
|
| 210 |
+
self.tp_size = get_tensor_model_parallel_world_size()
|
| 211 |
+
self.num_embeddings = num_embeddings
|
| 212 |
+
self.padding_size = padding_size
|
| 213 |
+
self.org_vocab_size = org_num_embeddings or num_embeddings
|
| 214 |
+
num_added_embeddings = num_embeddings - self.org_vocab_size
|
| 215 |
+
self.org_vocab_size_padded = pad_vocab_size(self.org_vocab_size,
|
| 216 |
+
self.padding_size)
|
| 217 |
+
self.num_embeddings_padded = pad_vocab_size(
|
| 218 |
+
self.org_vocab_size_padded + num_added_embeddings,
|
| 219 |
+
self.padding_size)
|
| 220 |
+
assert self.org_vocab_size_padded <= self.num_embeddings_padded
|
| 221 |
+
|
| 222 |
+
self.shard_indices = self._get_indices(self.num_embeddings_padded,
|
| 223 |
+
self.org_vocab_size_padded,
|
| 224 |
+
self.num_embeddings,
|
| 225 |
+
self.org_vocab_size, tp_rank,
|
| 226 |
+
self.tp_size)
|
| 227 |
+
self.embedding_dim = embedding_dim
|
| 228 |
+
|
| 229 |
+
linear_method = None
|
| 230 |
+
if quant_config is not None:
|
| 231 |
+
linear_method = quant_config.get_quant_method(self, prefix=prefix)
|
| 232 |
+
if linear_method is None:
|
| 233 |
+
linear_method = UnquantizedEmbeddingMethod()
|
| 234 |
+
|
| 235 |
+
# If we are making an embedding layer, then our quantization linear
|
| 236 |
+
# method must implement the embedding operation. If we are another
|
| 237 |
+
# layer type like ParallelLMHead, this is not important.
|
| 238 |
+
is_embedding_layer = type(self.__class__) is VocabParallelEmbedding
|
| 239 |
+
linear_method_implements_embedding = method_has_implemented_embedding(
|
| 240 |
+
type(linear_method))
|
| 241 |
+
if is_embedding_layer and not linear_method_implements_embedding:
|
| 242 |
+
raise NotImplementedError(
|
| 243 |
+
f"The class {type(linear_method).__name__} must implement "
|
| 244 |
+
"the 'embedding' method, see UnquantizedEmbeddingMethod.")
|
| 245 |
+
|
| 246 |
+
self.linear_method: QuantizeMethodBase = linear_method
|
| 247 |
+
|
| 248 |
+
if params_dtype is None:
|
| 249 |
+
params_dtype = torch.get_default_dtype()
|
| 250 |
+
# Divide the weight matrix along the vocaburaly dimension.
|
| 251 |
+
self.num_added_embeddings = self.num_embeddings - self.org_vocab_size
|
| 252 |
+
self.num_embeddings_per_partition = divide(self.num_embeddings_padded,
|
| 253 |
+
self.tp_size)
|
| 254 |
+
assert (self.shard_indices.num_elements_padded ==
|
| 255 |
+
self.num_embeddings_per_partition)
|
| 256 |
+
self.num_org_embeddings_per_partition = (
|
| 257 |
+
self.shard_indices.org_vocab_end_index -
|
| 258 |
+
self.shard_indices.org_vocab_start_index)
|
| 259 |
+
self.num_added_embeddings_per_partition = (
|
| 260 |
+
self.shard_indices.added_vocab_end_index -
|
| 261 |
+
self.shard_indices.added_vocab_start_index)
|
| 262 |
+
|
| 263 |
+
self.linear_method.create_weights(self,
|
| 264 |
+
self.embedding_dim,
|
| 265 |
+
[self.num_embeddings_per_partition],
|
| 266 |
+
self.embedding_dim,
|
| 267 |
+
self.num_embeddings_padded,
|
| 268 |
+
params_dtype=params_dtype,
|
| 269 |
+
weight_loader=self.weight_loader)
|
| 270 |
+
|
| 271 |
+
@classmethod
|
| 272 |
+
def _get_indices(cls, vocab_size_padded: int, org_vocab_size_padded: int,
|
| 273 |
+
vocab_size: int, org_vocab_size: int, tp_rank: int,
|
| 274 |
+
tp_size: int) -> VocabParallelEmbeddingShardIndices:
|
| 275 |
+
"""Get start and end indices for vocab parallel embedding, following the
|
| 276 |
+
layout outlined in the class docstring, based on the given tp_rank and
|
| 277 |
+
tp_size."""
|
| 278 |
+
num_added_embeddings_padded = vocab_size_padded - org_vocab_size_padded
|
| 279 |
+
padded_org_vocab_start_index, padded_org_vocab_end_index = (
|
| 280 |
+
vocab_range_from_global_vocab_size(org_vocab_size_padded, tp_rank,
|
| 281 |
+
tp_size))
|
| 282 |
+
padded_added_vocab_start_index, padded_added_vocab_end_index = (
|
| 283 |
+
vocab_range_from_global_vocab_size(num_added_embeddings_padded,
|
| 284 |
+
tp_rank,
|
| 285 |
+
tp_size,
|
| 286 |
+
offset=org_vocab_size))
|
| 287 |
+
# remove padding
|
| 288 |
+
org_vocab_start_index = min(padded_org_vocab_start_index,
|
| 289 |
+
org_vocab_size)
|
| 290 |
+
org_vocab_end_index = min(padded_org_vocab_end_index, org_vocab_size)
|
| 291 |
+
added_vocab_start_index = min(padded_added_vocab_start_index,
|
| 292 |
+
vocab_size)
|
| 293 |
+
added_vocab_end_index = min(padded_added_vocab_end_index, vocab_size)
|
| 294 |
+
return VocabParallelEmbeddingShardIndices(
|
| 295 |
+
padded_org_vocab_start_index, padded_org_vocab_end_index,
|
| 296 |
+
padded_added_vocab_start_index, padded_added_vocab_end_index,
|
| 297 |
+
org_vocab_start_index, org_vocab_end_index,
|
| 298 |
+
added_vocab_start_index, added_vocab_end_index)
|
| 299 |
+
|
| 300 |
+
def get_sharded_to_full_mapping(self) -> Optional[List[int]]:
|
| 301 |
+
"""Get a mapping that can be used to reindex the gathered
|
| 302 |
+
logits for sampling.
|
| 303 |
+
|
| 304 |
+
During sampling, we gather logits from all ranks. The relationship
|
| 305 |
+
of index->token_id will follow the same format as outlined in the class
|
| 306 |
+
docstring. However, after the gather, we want to reindex the final
|
| 307 |
+
logits tensor to map index->token_id one-to-one (the index is always
|
| 308 |
+
equal the token_id it corresponds to). The indices returned by this
|
| 309 |
+
method allow us to do that.
|
| 310 |
+
"""
|
| 311 |
+
if self.tp_size < 2:
|
| 312 |
+
return None
|
| 313 |
+
|
| 314 |
+
base_embeddings: List[int] = []
|
| 315 |
+
added_embeddings: List[int] = []
|
| 316 |
+
padding: List[int] = []
|
| 317 |
+
for tp_rank in range(self.tp_size):
|
| 318 |
+
shard_indices = self._get_indices(self.num_embeddings_padded,
|
| 319 |
+
self.org_vocab_size_padded,
|
| 320 |
+
self.num_embeddings,
|
| 321 |
+
self.org_vocab_size, tp_rank,
|
| 322 |
+
self.tp_size)
|
| 323 |
+
range_start = self.num_embeddings_per_partition * tp_rank
|
| 324 |
+
range_end = self.num_embeddings_per_partition * (tp_rank + 1)
|
| 325 |
+
base_embeddings.extend(
|
| 326 |
+
range(range_start,
|
| 327 |
+
range_start + shard_indices.num_org_elements))
|
| 328 |
+
padding.extend(
|
| 329 |
+
range(range_start + shard_indices.num_org_elements,
|
| 330 |
+
range_start + shard_indices.num_org_elements_padded))
|
| 331 |
+
added_embeddings.extend(
|
| 332 |
+
range(
|
| 333 |
+
range_start + shard_indices.num_org_elements_padded,
|
| 334 |
+
range_start + shard_indices.num_org_elements_padded +
|
| 335 |
+
shard_indices.num_added_elements))
|
| 336 |
+
padding.extend(
|
| 337 |
+
range(
|
| 338 |
+
range_start + shard_indices.num_org_elements_padded +
|
| 339 |
+
shard_indices.num_added_elements,
|
| 340 |
+
range_start + shard_indices.num_org_elements_padded +
|
| 341 |
+
shard_indices.num_added_elements_padded))
|
| 342 |
+
assert (range_start + shard_indices.num_org_elements_padded +
|
| 343 |
+
shard_indices.num_added_elements_padded == range_end)
|
| 344 |
+
ret = base_embeddings + added_embeddings + padding
|
| 345 |
+
assert len(ret) == self.num_embeddings_padded
|
| 346 |
+
return ret
|
| 347 |
+
|
| 348 |
+
    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        output_dim = getattr(param, "output_dim", None)
        packed_dim = getattr(param, "packed_dim", None)

        # If the parameter is a gguf weight, then load it directly.
        if getattr(param, "is_gguf_weight_type", None):
            param.data.copy_(loaded_weight)
            param.weight_type = loaded_weight.item()
            return
        elif isinstance(param, UninitializedParameter):
            shape = list(loaded_weight.shape)
            if output_dim is not None:
                shape[output_dim] = self.num_embeddings_per_partition
            param.materialize(tuple(shape), dtype=loaded_weight.dtype)

        # If the parameter does not have an output dim, then it should
        # be copied onto all gpus (e.g. g_idx for act_order gptq).
        if output_dim is None:
            assert param.data.shape == loaded_weight.shape
            param.data.copy_(loaded_weight)
            return

        # Shard indices for loading the weight.
        start_idx = self.shard_indices.org_vocab_start_index
        shard_size = self.shard_indices.org_vocab_end_index - start_idx

        # If the param is packed on the same dim we are sharding on, then
        # we need to adjust the offsets of the loaded weight by pack_factor.
        if packed_dim is not None and packed_dim == output_dim:
            packed_factor = param.packed_factor if isinstance(
                param, BasevLLMParameter) else param.pack_factor
            assert loaded_weight.shape[output_dim] == (self.org_vocab_size //
                                                       packed_factor)
            start_idx = start_idx // packed_factor
            shard_size = shard_size // packed_factor
        else:
            assert loaded_weight.shape[output_dim] == self.org_vocab_size

        # Copy the data. Select the chunk corresponding to the current shard.
        loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)

        if current_platform.is_hpu():
            # FIXME(kzawora): Weight copy with slicing bugs out on Gaudi here,
            # so we're using a workaround. Remove this when fixed in
            # HPU PT bridge.
            padded_weight = torch.cat([
                loaded_weight,
                torch.zeros(param.shape[0] - loaded_weight.shape[0],
                            *loaded_weight.shape[1:])
            ])
            param.data.copy_(padded_weight)
        else:
            param[:loaded_weight.shape[0]].data.copy_(loaded_weight)
            param[loaded_weight.shape[0]:].data.fill_(0)
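
    # Hypothetical packed-weight example for the adjustment above (numbers are
    # illustrative only): a 4-bit GPTQ checkpoint packs 8 values per int32, so
    # with org_vocab_size=32000 the loaded tensor has 32000 // 8 = 4000 rows
    # along the packed dim. A shard that owns token ids [16000, 32000) is then
    # read with start_idx = 16000 // 8 = 2000 and shard_size = 2000.
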
    def forward(self, input_):
        if self.tp_size > 1:
            # Build the mask.
            masked_input, input_mask = get_masked_input_and_mask(
                input_, self.shard_indices.org_vocab_start_index,
                self.shard_indices.org_vocab_end_index,
                self.shard_indices.num_org_vocab_padding,
                self.shard_indices.added_vocab_start_index,
                self.shard_indices.added_vocab_end_index)
        else:
            masked_input = input_
        # Get the embeddings.
        output_parallel = self.linear_method.embedding(self,
                                                       masked_input.long())
        # Mask the output embedding.
        if self.tp_size > 1:
            output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)
        # Reduce across all the model parallel GPUs.
        output = tensor_model_parallel_all_reduce(output_parallel)
        return output
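
    # Sketch of the tensor-parallel forward above with illustrative sizes:
    # for tp_size=2 and org_vocab_size=32000, rank 0 owns ids [0, 16000) and
    # rank 1 owns [16000, 32000). A token id of 20000 is masked on rank 0 (its
    # output row is zeroed) and mapped to local index 4000 on rank 1, so the
    # all-reduce sums exactly one real embedding with zeros from the other
    # ranks and every rank ends up with the full embedding.
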
    def extra_repr(self) -> str:
        s = f"num_embeddings={self.num_embeddings_per_partition}"
        s += f", embedding_dim={self.embedding_dim}"
        s += f", org_vocab_size={self.org_vocab_size}"
        s += f", num_embeddings_padded={self.num_embeddings_padded}"
        s += f", tp_size={self.tp_size}"
        return s


class ParallelLMHead(VocabParallelEmbedding):
    """Parallelized LM head.

    Output logits weight matrices used in the Sampler. The weight and bias
    tensors are padded to make sure they are divisible by the number of
    model parallel GPUs.

    Args:
        num_embeddings: vocabulary size.
        embedding_dim: size of the hidden state.
        bias: whether to use bias.
        params_dtype: type of the parameters.
        org_num_embeddings: original vocabulary size (without LoRA).
        padding_size: padding size for the vocabulary.
    """

    def __init__(self,
                 num_embeddings: int,
                 embedding_dim: int,
                 bias: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 org_num_embeddings: Optional[int] = None,
                 padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__(num_embeddings, embedding_dim, params_dtype,
                         org_num_embeddings, padding_size, quant_config,
                         prefix)
        self.quant_config = quant_config
        if bias:
            self.bias = Parameter(
                torch.empty(self.num_embeddings_per_partition,
                            dtype=params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)

    def tie_weights(self, embed_tokens: VocabParallelEmbedding):
        """Tie the weights with the word embeddings."""
        # GGUF quantized embed_tokens.
        if self.quant_config and self.quant_config.get_name() == "gguf":
            return embed_tokens
        else:
            self.weight = embed_tokens.weight
            return self

    def forward(self, input_):
        del input_
        raise RuntimeError("LMHead's weights should be used in the sampler.")
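
# Minimal usage sketch for the classes above (illustrative only; apart from
# VocabParallelEmbedding and ParallelLMHead, the attribute and config names
# are assumed, not taken from this diff). A decoder-only model typically
# builds the LM head with the same vocab/hidden sizes as its embedding table
# and, when word embeddings are tied, reuses the embedding weight:
#
#     self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
#                                                config.hidden_size)
#     self.lm_head = ParallelLMHead(config.vocab_size,
#                                   config.hidden_size,
#                                   quant_config=quant_config)
#     if config.tie_word_embeddings:
#         self.lm_head = self.lm_head.tie_weights(self.embed_tokens)
#
# Note that tie_weights() returns embed_tokens itself for GGUF-quantized
# embeddings (so the packed weight is shared directly) and returns self
# otherwise; lm_head.forward() is never called, since the logits are computed
# from lm_head.weight elsewhere, as the RuntimeError above indicates.
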
.venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/__init__.cpython-311.pyc
ADDED: Binary file (907 Bytes)

.venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/adapters.cpython-311.pyc
ADDED: Binary file (10.8 kB)

.venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/arctic.cpython-311.pyc
ADDED: Binary file (28.3 kB)

.venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/bert.cpython-311.pyc
ADDED: Binary file (28.3 kB)

.venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/blip2.cpython-311.pyc
ADDED: Binary file (34.8 kB)

.venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/bloom.cpython-311.pyc
ADDED: Binary file (18 kB)

.venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/chameleon.cpython-311.pyc
ADDED: Binary file (57.5 kB)

.venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/chatglm.cpython-311.pyc
ADDED: Binary file (34.6 kB)

.venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/clip.cpython-311.pyc
ADDED: Binary file (25.4 kB)

.venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/decilm.cpython-311.pyc
ADDED: Binary file (4.97 kB)

.venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/deepseek.cpython-311.pyc
ADDED: Binary file (23.4 kB)

.venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/eagle.cpython-311.pyc
ADDED: Binary file (11.1 kB)

.venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/fairseq2_llama.cpython-311.pyc
ADDED: Binary file (7.54 kB)

.venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/falcon.cpython-311.pyc
ADDED: Binary file (23 kB)

.venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/glm.cpython-311.pyc
ADDED: Binary file (1.62 kB)

.venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/gpt2.cpython-311.pyc
ADDED: Binary file (16.1 kB)

.venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/gpt_j.cpython-311.pyc
ADDED: Binary file (16.7 kB)