koichi12 committed
Commit 0fbd155 · verified · 1 Parent(s): 0d7d10f

Add files using upload-large-folder tool

This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. .gitattributes +3 -0
  2. .venv/lib/python3.11/site-packages/grpc/_cython/cygrpc.cpython-311-x86_64-linux-gnu.so +3 -0
  3. .venv/lib/python3.11/site-packages/vllm/__pycache__/config.cpython-311.pyc +3 -0
  4. .venv/lib/python3.11/site-packages/vllm/__pycache__/utils.cpython-311.pyc +3 -0
  5. .venv/lib/python3.11/site-packages/vllm/model_executor/layers/__init__.py +0 -0
  6. .venv/lib/python3.11/site-packages/vllm/model_executor/layers/activation.py +360 -0
  7. .venv/lib/python3.11/site-packages/vllm/model_executor/layers/fused_moe/__init__.py +48 -0
  8. .venv/lib/python3.11/site-packages/vllm/model_executor/layers/fused_moe/__pycache__/__init__.cpython-311.pyc +0 -0
  9. .venv/lib/python3.11/site-packages/vllm/model_executor/layers/fused_moe/__pycache__/fused_marlin_moe.cpython-311.pyc +0 -0
  10. .venv/lib/python3.11/site-packages/vllm/model_executor/layers/fused_moe/__pycache__/fused_moe.cpython-311.pyc +0 -0
  11. .venv/lib/python3.11/site-packages/vllm/model_executor/layers/fused_moe/__pycache__/layer.cpython-311.pyc +0 -0
  12. .venv/lib/python3.11/site-packages/vllm/model_executor/layers/fused_moe/__pycache__/moe_pallas.cpython-311.pyc +0 -0
  13. .venv/lib/python3.11/site-packages/vllm/model_executor/layers/fused_moe/__pycache__/moe_torch_iterative.cpython-311.pyc +0 -0
  14. .venv/lib/python3.11/site-packages/vllm/model_executor/layers/fused_moe/configs/E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +218 -0
  15. .venv/lib/python3.11/site-packages/vllm/model_executor/layers/fused_moe/configs/E=16,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +146 -0
  16. .venv/lib/python3.11/site-packages/vllm/model_executor/layers/fused_moe/configs/E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +146 -0
  17. .venv/lib/python3.11/site-packages/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +360 -0
  18. .venv/lib/python3.11/site-packages/vllm/model_executor/layers/fused_moe/fused_moe.py +1363 -0
  19. .venv/lib/python3.11/site-packages/vllm/model_executor/layers/fused_moe/layer.py +647 -0
  20. .venv/lib/python3.11/site-packages/vllm/model_executor/layers/fused_moe/moe_pallas.py +64 -0
  21. .venv/lib/python3.11/site-packages/vllm/model_executor/layers/fused_moe/moe_torch_iterative.py +53 -0
  22. .venv/lib/python3.11/site-packages/vllm/model_executor/layers/layernorm.py +213 -0
  23. .venv/lib/python3.11/site-packages/vllm/model_executor/layers/linear.py +1159 -0
  24. .venv/lib/python3.11/site-packages/vllm/model_executor/layers/logits_processor.py +193 -0
  25. .venv/lib/python3.11/site-packages/vllm/model_executor/layers/pooler.py +322 -0
  26. .venv/lib/python3.11/site-packages/vllm/model_executor/layers/rejection_sampler.py +400 -0
  27. .venv/lib/python3.11/site-packages/vllm/model_executor/layers/resampler.py +269 -0
  28. .venv/lib/python3.11/site-packages/vllm/model_executor/layers/rotary_embedding.py +1114 -0
  29. .venv/lib/python3.11/site-packages/vllm/model_executor/layers/sampler.py +1292 -0
  30. .venv/lib/python3.11/site-packages/vllm/model_executor/layers/spec_decode_base_sampler.py +256 -0
  31. .venv/lib/python3.11/site-packages/vllm/model_executor/layers/typical_acceptance_sampler.py +172 -0
  32. .venv/lib/python3.11/site-packages/vllm/model_executor/layers/utils.py +58 -0
  33. .venv/lib/python3.11/site-packages/vllm/model_executor/layers/vocab_parallel_embedding.py +484 -0
  34. .venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/__init__.cpython-311.pyc +0 -0
  35. .venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/adapters.cpython-311.pyc +0 -0
  36. .venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/arctic.cpython-311.pyc +0 -0
  37. .venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/bert.cpython-311.pyc +0 -0
  38. .venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/blip2.cpython-311.pyc +0 -0
  39. .venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/bloom.cpython-311.pyc +0 -0
  40. .venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/chameleon.cpython-311.pyc +0 -0
  41. .venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/chatglm.cpython-311.pyc +0 -0
  42. .venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/clip.cpython-311.pyc +0 -0
  43. .venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/decilm.cpython-311.pyc +0 -0
  44. .venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/deepseek.cpython-311.pyc +0 -0
  45. .venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/eagle.cpython-311.pyc +0 -0
  46. .venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/fairseq2_llama.cpython-311.pyc +0 -0
  47. .venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/falcon.cpython-311.pyc +0 -0
  48. .venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/glm.cpython-311.pyc +0 -0
  49. .venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/gpt2.cpython-311.pyc +0 -0
  50. .venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/gpt_j.cpython-311.pyc +0 -0
.gitattributes CHANGED
@@ -201,3 +201,6 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/_
  .venv/lib/python3.11/site-packages/jinja2/__pycache__/compiler.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
  .venv/lib/python3.11/site-packages/grpc/__pycache__/_channel.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
  .venv/lib/python3.11/site-packages/msgpack/_cmsgpack.cpython-311-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
+ .venv/lib/python3.11/site-packages/vllm/__pycache__/config.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
+ .venv/lib/python3.11/site-packages/vllm/__pycache__/utils.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
+ .venv/lib/python3.11/site-packages/grpc/_cython/cygrpc.cpython-311-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
.venv/lib/python3.11/site-packages/grpc/_cython/cygrpc.cpython-311-x86_64-linux-gnu.so ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5e5d08f09e30133cae4310d214c9357fca55ebd0e2db830c422465af821a6392
+ size 13660664
.venv/lib/python3.11/site-packages/vllm/__pycache__/config.cpython-311.pyc ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7c332333bd8134456a6d80925bf608e61fc31c7df941a7862edcbfacf4b07e81
+ size 148527
.venv/lib/python3.11/site-packages/vllm/__pycache__/utils.cpython-311.pyc ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:35e12ff57902027f5a6f938eca8e4cb4a91c51c331e59ff752edd1b635b6330f
+ size 113860
.venv/lib/python3.11/site-packages/vllm/model_executor/layers/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/vllm/model_executor/layers/activation.py ADDED
@@ -0,0 +1,360 @@
+ # SPDX-License-Identifier: Apache-2.0
+ """Custom activation functions."""
+ import math
+ from typing import Optional
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
+                               get_tensor_model_parallel_world_size)
+ from vllm.model_executor.custom_op import CustomOp
+ from vllm.model_executor.utils import set_weight_attrs
+ from vllm.platforms import current_platform
+ from vllm.utils import LazyDict
+
+
+ @CustomOp.register("fatrelu_and_mul")
+ class FatreluAndMul(CustomOp):
+     """An activation function for FATReLU.
+
+     The function computes x -> FATReLU(x[:d]) * x[d:] where
+     d = x.shape[-1] // 2.
+     This is used in openbmb/MiniCPM-S-1B-sft.
+
+     Shapes:
+         x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
+         return: (num_tokens, d) or (batch_size, seq_len, d)
+     """
+
+     def __init__(self, threshold: float = 0.):
+         super().__init__()
+         self.threshold = threshold
+         if current_platform.is_cuda_alike():
+             self.op = torch.ops._C.fatrelu_and_mul
+         elif current_platform.is_cpu():
+             self._forward_method = self.forward_native
+
+     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
+         d = x.shape[-1] // 2
+         x1 = x[..., :d]
+         x2 = x[..., d:]
+         x1 = F.threshold(x1, self.threshold, 0.0)
+         return x1 * x2
+
+     def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
+         d = x.shape[-1] // 2
+         output_shape = (x.shape[:-1] + (d, ))
+         out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
+         self.op(out, x, self.threshold)
+         return out
+
+
+ @CustomOp.register("silu_and_mul")
+ class SiluAndMul(CustomOp):
+     """An activation function for SwiGLU.
+
+     The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[-1] // 2.
+
+     Shapes:
+         x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
+         return: (num_tokens, d) or (batch_size, seq_len, d)
+     """
+
+     def __init__(self):
+         super().__init__()
+         if current_platform.is_cuda_alike() or current_platform.is_cpu():
+             self.op = torch.ops._C.silu_and_mul
+         elif current_platform.is_xpu():
+             from vllm._ipex_ops import ipex_ops
+             self.op = ipex_ops.silu_and_mul
+
+     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
+         """PyTorch-native implementation equivalent to forward()."""
+         d = x.shape[-1] // 2
+         return F.silu(x[..., :d]) * x[..., d:]
+
+     def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
+         d = x.shape[-1] // 2
+         output_shape = (x.shape[:-1] + (d, ))
+         out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
+         self.op(out, x)
+         return out
+
+     def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
+         d = x.shape[-1] // 2
+         output_shape = (x.shape[:-1] + (d, ))
+         out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
+         self.op(out, x)
+         return out
+
+
+ @CustomOp.register("mul_and_silu")
+ class MulAndSilu(CustomOp):
+     """An activation function for SwiGLU.
+
+     The function computes x -> x[:d] * silu(x[d:]) where d = x.shape[-1] // 2.
+
+     Shapes:
+         x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
+         return: (num_tokens, d) or (batch_size, seq_len, d)
+     """
+
+     def __init__(self):
+         super().__init__()
+         if current_platform.is_cuda_alike():
+             self.op = torch.ops._C.mul_and_silu
+         elif current_platform.is_xpu():
+             from vllm._ipex_ops import ipex_ops
+             self.op = ipex_ops.silu_and_mul
+         elif current_platform.is_cpu():
+             self._forward_method = self.forward_native
+
+     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
+         """PyTorch-native implementation equivalent to forward()."""
+         d = x.shape[-1] // 2
+         return x[..., :d] * F.silu(x[..., d:])
+
+     def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
+         d = x.shape[-1] // 2
+         output_shape = (x.shape[:-1] + (d, ))
+         out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
+         self.op(out, x)
+         return out
+
+     # TODO implement forward_xpu for MulAndSilu
+     # def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
+
+
+ @CustomOp.register("gelu_and_mul")
+ class GeluAndMul(CustomOp):
+     """An activation function for GeGLU.
+
+     The function computes x -> GELU(x[:d]) * x[d:] where d = x.shape[-1] // 2.
+
+     Shapes:
+         x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
+         return: (batch_size, seq_len, d) or (num_tokens, d)
+     """
+
+     def __init__(self, approximate: str = "none"):
+         super().__init__()
+         self.approximate = approximate
+         if approximate not in ("none", "tanh"):
+             raise ValueError(f"Unknown approximate mode: {approximate}")
+         if current_platform.is_cuda_alike() or current_platform.is_cpu():
+             if approximate == "none":
+                 self.op = torch.ops._C.gelu_and_mul
+             elif approximate == "tanh":
+                 self.op = torch.ops._C.gelu_tanh_and_mul
+         elif current_platform.is_xpu():
+             from vllm._ipex_ops import ipex_ops
+             if approximate == "none":
+                 self.op = ipex_ops.gelu_and_mul
+             else:
+                 self.op = ipex_ops.gelu_tanh_and_mul
+
+     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
+         """PyTorch-native implementation equivalent to forward()."""
+         d = x.shape[-1] // 2
+         return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]
+
+     def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
+         d = x.shape[-1] // 2
+         output_shape = (x.shape[:-1] + (d, ))
+         out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
+         self.op(out, x)
+         return out
+
+     def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
+         d = x.shape[-1] // 2
+         output_shape = (x.shape[:-1] + (d, ))
+         out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
+         self.op(out, x)
+         return out
+
+     def extra_repr(self) -> str:
+         return f'approximate={repr(self.approximate)}'
+
+
+ @CustomOp.register("gelu_new")
+ class NewGELU(CustomOp):
+
+     def __init__(self):
+         super().__init__()
+         if current_platform.is_cuda_alike() or current_platform.is_cpu():
+             self.op = torch.ops._C.gelu_new
+         elif current_platform.is_xpu():
+             from vllm._ipex_ops import ipex_ops
+             self.op = ipex_ops.gelu_new
+
+     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
+         """PyTorch-native implementation equivalent to forward()."""
+         c = math.sqrt(2.0 / math.pi)
+         return 0.5 * x * (1.0 + torch.tanh(c *
+                                            (x + 0.044715 * torch.pow(x, 3.0))))
+
+     def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
+         out = torch.empty_like(x)
+         self.op(out, x)
+         return out
+
+     def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
+         return self.op(x)
+
+
+ @CustomOp.register("gelu_fast")
+ class FastGELU(CustomOp):
+
+     def __init__(self):
+         super().__init__()
+         if current_platform.is_cuda_alike() or current_platform.is_cpu():
+             self.op = torch.ops._C.gelu_fast
+         elif current_platform.is_xpu():
+             from vllm._ipex_ops import ipex_ops
+             self.op = ipex_ops.gelu_fast
+
+     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
+         """PyTorch-native implementation equivalent to forward()."""
+         return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 *
+                                            (1.0 + 0.044715 * x * x)))
+
+     def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
+         out = torch.empty_like(x)
+         self.op(out, x)
+         return out
+
+     def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
+         return self.op(x)
+
+
+ @CustomOp.register("quick_gelu")
+ class QuickGELU(CustomOp):
+     # https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L90
+     def __init__(self):
+         super().__init__()
+         if current_platform.is_cuda_alike() or current_platform.is_cpu():
+             self.op = torch.ops._C.gelu_quick
+         elif current_platform.is_xpu():
+             from vllm._ipex_ops import ipex_ops
+             self.op = ipex_ops.gelu_quick
+
+     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
+         """PyTorch-native implementation equivalent to forward()."""
+         return x * torch.sigmoid(1.702 * x)
+
+     def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
+         out = torch.empty_like(x)
+         self.op(out, x)
+         return out
+
+     def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
+         out = torch.empty_like(x)
+         self.op(out, x)
+         return out
+
+     # TODO implement forward_xpu for QuickGELU
+     # def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
+
+
+ @CustomOp.register("relu2")
+ class ReLUSquaredActivation(CustomOp):
+     """
+     Applies the relu^2 activation introduced in https://arxiv.org/abs/2109.08668v2
+     """
+
+     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
+         """PyTorch-native implementation equivalent to forward()."""
+         return torch.square(F.relu(x))
+
+     def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
+         return self.forward_native(x)
+
+
+ class ScaledActivation(nn.Module):
+     """An activation function with post-scale parameters.
+
+     This is used for some quantization methods like AWQ.
+     """
+
+     def __init__(
+         self,
+         act_module: nn.Module,
+         intermediate_size: int,
+         input_is_parallel: bool = True,
+         params_dtype: Optional[torch.dtype] = None,
+     ):
+         super().__init__()
+         self.act = act_module
+         self.input_is_parallel = input_is_parallel
+         if input_is_parallel:
+             tp_size = get_tensor_model_parallel_world_size()
+             intermediate_size_per_partition = divide(intermediate_size,
+                                                      tp_size)
+         else:
+             intermediate_size_per_partition = intermediate_size
+         if params_dtype is None:
+             params_dtype = torch.get_default_dtype()
+         self.scales = nn.Parameter(
+             torch.empty(intermediate_size_per_partition, dtype=params_dtype))
+         set_weight_attrs(self.scales, {"weight_loader": self.weight_loader})
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return self.act(x) / self.scales
+
+     def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
+         param_data = param.data
+         if self.input_is_parallel:
+             tp_rank = get_tensor_model_parallel_rank()
+             shard_size = param_data.shape[0]
+             start_idx = tp_rank * shard_size
+             loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
+         assert param_data.shape == loaded_weight.shape
+         param_data.copy_(loaded_weight)
+
+
+ _ACTIVATION_REGISTRY = LazyDict({
+     "gelu":
+     lambda: nn.GELU(),
+     "gelu_fast":
+     lambda: FastGELU(),
+     "gelu_new":
+     lambda: NewGELU(),
+     "gelu_pytorch_tanh":
+     lambda: nn.GELU(approximate="tanh"),
+     "relu":
+     lambda: nn.ReLU(),
+     "relu2":
+     lambda: ReLUSquaredActivation(),
+     "silu":
+     lambda: nn.SiLU(),
+     "quick_gelu":
+     lambda: QuickGELU(),
+ })
+
+
+ def get_act_fn(act_fn_name: str) -> nn.Module:
+     """Get an activation function by name."""
+     act_fn_name = act_fn_name.lower()
+     if act_fn_name not in _ACTIVATION_REGISTRY:
+         raise ValueError(
+             f"Activation function {act_fn_name!r} is not supported.")
+
+     return _ACTIVATION_REGISTRY[act_fn_name]
+
+
+ _ACTIVATION_AND_MUL_REGISTRY = LazyDict({
+     "gelu": lambda: GeluAndMul(),
+     "silu": lambda: SiluAndMul(),
+ })
+
+
+ def get_act_and_mul_fn(act_fn_name: str) -> nn.Module:
+     """Get an activation-and-mul (i.e. SiluAndMul) function by name."""
+     act_fn_name = act_fn_name.lower()
+     if act_fn_name not in _ACTIVATION_AND_MUL_REGISTRY:
+         raise ValueError(
+             f"Activation function {act_fn_name!r} is not supported.")
+
+     return _ACTIVATION_AND_MUL_REGISTRY[act_fn_name]
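Note: activation.py exposes these ops both directly and through a name registry. A minimal sketch of how the file is typically consumed (illustrative only, not part of the diff; it assumes a working vLLM build, and the tensor shapes are arbitrary):

    import torch
    from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn

    # Look up a simple activation by its lower-cased name.
    gelu = get_act_fn("gelu_new")

    # SiluAndMul consumes the concatenated [gate, up] projection of a SwiGLU MLP:
    # input (num_tokens, 2 * d) -> output (num_tokens, d).
    x = torch.randn(4, 256)
    act = SiluAndMul()
    y = act.forward_native(x)  # forward() dispatches to the CUDA kernel when available
    assert y.shape == (4, 128)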
.venv/lib/python3.11/site-packages/vllm/model_executor/layers/fused_moe/__init__.py ADDED
@@ -0,0 +1,48 @@
+ # SPDX-License-Identifier: Apache-2.0
+
+ from contextlib import contextmanager
+ from typing import Any, Dict, Optional
+
+ from vllm.model_executor.layers.fused_moe.layer import (
+     FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
+ from vllm.triton_utils import HAS_TRITON
+
+ _config: Optional[Dict[str, Any]] = None
+
+
+ @contextmanager
+ def override_config(config):
+     global _config
+     old_config = _config
+     _config = config
+     yield
+     _config = old_config
+
+
+ def get_config() -> Optional[Dict[str, Any]]:
+     return _config
+
+
+ __all__ = [
+     "FusedMoE",
+     "FusedMoEMethodBase",
+     "FusedMoeWeightScaleSupported",
+     "override_config",
+     "get_config",
+ ]
+
+ if HAS_TRITON:
+     # import to register the custom ops
+     import vllm.model_executor.layers.fused_moe.fused_marlin_moe  # noqa
+     import vllm.model_executor.layers.fused_moe.fused_moe  # noqa
+     from vllm.model_executor.layers.fused_moe.fused_moe import (
+         fused_experts, fused_moe, fused_topk, get_config_file_name,
+         grouped_topk)
+
+     __all__ += [
+         "fused_moe",
+         "fused_topk",
+         "fused_experts",
+         "get_config_file_name",
+         "grouped_topk",
+     ]
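Note: override_config above temporarily swaps the module-level kernel config for the duration of a with-block (e.g. while benchmarking tuned launch parameters). A small illustrative sketch, not part of the commit:

    from vllm.model_executor.layers.fused_moe import get_config, override_config

    # Hypothetical hand-tuned launch parameters; any dict is accepted.
    tuned = {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64,
             "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 4}

    assert get_config() is None          # no override installed by default
    with override_config(tuned):
        assert get_config() == tuned     # kernels see the override inside the block
    assert get_config() is None          # restored on exit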
.venv/lib/python3.11/site-packages/vllm/model_executor/layers/fused_moe/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (1.69 kB).
.venv/lib/python3.11/site-packages/vllm/model_executor/layers/fused_moe/__pycache__/fused_marlin_moe.cpython-311.pyc ADDED
Binary file (13.7 kB).
.venv/lib/python3.11/site-packages/vllm/model_executor/layers/fused_moe/__pycache__/fused_moe.cpython-311.pyc ADDED
Binary file (50.4 kB).
.venv/lib/python3.11/site-packages/vllm/model_executor/layers/fused_moe/__pycache__/layer.cpython-311.pyc ADDED
Binary file (24.3 kB).
.venv/lib/python3.11/site-packages/vllm/model_executor/layers/fused_moe/__pycache__/moe_pallas.cpython-311.pyc ADDED
Binary file (3.66 kB).
.venv/lib/python3.11/site-packages/vllm/model_executor/layers/fused_moe/__pycache__/moe_torch_iterative.cpython-311.pyc ADDED
Binary file (2.66 kB).
.venv/lib/python3.11/site-packages/vllm/model_executor/layers/fused_moe/configs/E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json ADDED
@@ -0,0 +1,218 @@
1
+ {
2
+ "1": {
3
+ "BLOCK_SIZE_M": 16,
4
+ "BLOCK_SIZE_N": 32,
5
+ "BLOCK_SIZE_K": 256,
6
+ "GROUP_SIZE_M": 16,
7
+ "num_warps": 4,
8
+ "num_stages": 3
9
+ },
10
+ "2": {
11
+ "BLOCK_SIZE_M": 16,
12
+ "BLOCK_SIZE_N": 32,
13
+ "BLOCK_SIZE_K": 256,
14
+ "GROUP_SIZE_M": 16,
15
+ "num_warps": 4,
16
+ "num_stages": 4
17
+ },
18
+ "4": {
19
+ "BLOCK_SIZE_M": 16,
20
+ "BLOCK_SIZE_N": 32,
21
+ "BLOCK_SIZE_K": 256,
22
+ "GROUP_SIZE_M": 1,
23
+ "num_warps": 4,
24
+ "num_stages": 4
25
+ },
26
+ "8": {
27
+ "BLOCK_SIZE_M": 16,
28
+ "BLOCK_SIZE_N": 32,
29
+ "BLOCK_SIZE_K": 256,
30
+ "GROUP_SIZE_M": 16,
31
+ "num_warps": 4,
32
+ "num_stages": 4
33
+ },
34
+ "16": {
35
+ "BLOCK_SIZE_M": 16,
36
+ "BLOCK_SIZE_N": 32,
37
+ "BLOCK_SIZE_K": 256,
38
+ "GROUP_SIZE_M": 16,
39
+ "num_warps": 4,
40
+ "num_stages": 3
41
+ },
42
+ "24": {
43
+ "BLOCK_SIZE_M": 16,
44
+ "BLOCK_SIZE_N": 32,
45
+ "BLOCK_SIZE_K": 256,
46
+ "GROUP_SIZE_M": 1,
47
+ "num_warps": 4,
48
+ "num_stages": 3
49
+ },
50
+ "32": {
51
+ "BLOCK_SIZE_M": 16,
52
+ "BLOCK_SIZE_N": 32,
53
+ "BLOCK_SIZE_K": 256,
54
+ "GROUP_SIZE_M": 1,
55
+ "num_warps": 4,
56
+ "num_stages": 3
57
+ },
58
+ "48": {
59
+ "BLOCK_SIZE_M": 16,
60
+ "BLOCK_SIZE_N": 128,
61
+ "BLOCK_SIZE_K": 128,
62
+ "GROUP_SIZE_M": 1,
63
+ "num_warps": 8,
64
+ "num_stages": 3
65
+ },
66
+ "64": {
67
+ "BLOCK_SIZE_M": 64,
68
+ "BLOCK_SIZE_N": 64,
69
+ "BLOCK_SIZE_K": 64,
70
+ "GROUP_SIZE_M": 1,
71
+ "num_warps": 4,
72
+ "num_stages": 4
73
+ },
74
+ "96": {
75
+ "BLOCK_SIZE_M": 32,
76
+ "BLOCK_SIZE_N": 128,
77
+ "BLOCK_SIZE_K": 128,
78
+ "GROUP_SIZE_M": 1,
79
+ "num_warps": 4,
80
+ "num_stages": 3
81
+ },
82
+ "128": {
83
+ "BLOCK_SIZE_M": 64,
84
+ "BLOCK_SIZE_N": 64,
85
+ "BLOCK_SIZE_K": 64,
86
+ "GROUP_SIZE_M": 1,
87
+ "num_warps": 4,
88
+ "num_stages": 3
89
+ },
90
+ "256": {
91
+ "BLOCK_SIZE_M": 64,
92
+ "BLOCK_SIZE_N": 64,
93
+ "BLOCK_SIZE_K": 64,
94
+ "GROUP_SIZE_M": 1,
95
+ "num_warps": 4,
96
+ "num_stages": 4
97
+ },
98
+ "512": {
99
+ "BLOCK_SIZE_M": 64,
100
+ "BLOCK_SIZE_N": 64,
101
+ "BLOCK_SIZE_K": 64,
102
+ "GROUP_SIZE_M": 32,
103
+ "num_warps": 4,
104
+ "num_stages": 3
105
+ },
106
+ "1024": {
107
+ "BLOCK_SIZE_M": 256,
108
+ "BLOCK_SIZE_N": 32,
109
+ "BLOCK_SIZE_K": 64,
110
+ "GROUP_SIZE_M": 32,
111
+ "num_warps": 4,
112
+ "num_stages": 3
113
+ },
114
+ "1536": {
115
+ "BLOCK_SIZE_M": 64,
116
+ "BLOCK_SIZE_N": 256,
117
+ "BLOCK_SIZE_K": 64,
118
+ "GROUP_SIZE_M": 64,
119
+ "num_warps": 4,
120
+ "num_stages": 4
121
+ },
122
+ "2048": {
123
+ "BLOCK_SIZE_M": 64,
124
+ "BLOCK_SIZE_N": 256,
125
+ "BLOCK_SIZE_K": 64,
126
+ "GROUP_SIZE_M": 64,
127
+ "num_warps": 4,
128
+ "num_stages": 4
129
+ },
130
+ "3072": {
131
+ "BLOCK_SIZE_M": 64,
132
+ "BLOCK_SIZE_N": 256,
133
+ "BLOCK_SIZE_K": 64,
134
+ "GROUP_SIZE_M": 32,
135
+ "num_warps": 4,
136
+ "num_stages": 4
137
+ },
138
+ "4096": {
139
+ "BLOCK_SIZE_M": 64,
140
+ "BLOCK_SIZE_N": 256,
141
+ "BLOCK_SIZE_K": 64,
142
+ "GROUP_SIZE_M": 64,
143
+ "num_warps": 4,
144
+ "num_stages": 4
145
+ },
146
+ "5120": {
147
+ "BLOCK_SIZE_M": 64,
148
+ "BLOCK_SIZE_N": 256,
149
+ "BLOCK_SIZE_K": 64,
150
+ "GROUP_SIZE_M": 32,
151
+ "num_warps": 4,
152
+ "num_stages": 4
153
+ },
154
+ "9216": {
155
+ "BLOCK_SIZE_M": 64,
156
+ "BLOCK_SIZE_N": 256,
157
+ "BLOCK_SIZE_K": 64,
158
+ "GROUP_SIZE_M": 32,
159
+ "num_warps": 4,
160
+ "num_stages": 4
161
+ },
162
+ "13312": {
163
+ "BLOCK_SIZE_M": 64,
164
+ "BLOCK_SIZE_N": 256,
165
+ "BLOCK_SIZE_K": 64,
166
+ "GROUP_SIZE_M": 16,
167
+ "num_warps": 4,
168
+ "num_stages": 4
169
+ },
170
+ "17408": {
171
+ "BLOCK_SIZE_M": 64,
172
+ "BLOCK_SIZE_N": 256,
173
+ "BLOCK_SIZE_K": 64,
174
+ "GROUP_SIZE_M": 16,
175
+ "num_warps": 4,
176
+ "num_stages": 4
177
+ },
178
+ "25600": {
179
+ "BLOCK_SIZE_M": 64,
180
+ "BLOCK_SIZE_N": 256,
181
+ "BLOCK_SIZE_K": 64,
182
+ "GROUP_SIZE_M": 16,
183
+ "num_warps": 4,
184
+ "num_stages": 4
185
+ },
186
+ "33792": {
187
+ "BLOCK_SIZE_M": 64,
188
+ "BLOCK_SIZE_N": 256,
189
+ "BLOCK_SIZE_K": 64,
190
+ "GROUP_SIZE_M": 16,
191
+ "num_warps": 4,
192
+ "num_stages": 4
193
+ },
194
+ "41984": {
195
+ "BLOCK_SIZE_M": 64,
196
+ "BLOCK_SIZE_N": 256,
197
+ "BLOCK_SIZE_K": 64,
198
+ "GROUP_SIZE_M": 16,
199
+ "num_warps": 4,
200
+ "num_stages": 4
201
+ },
202
+ "50176": {
203
+ "BLOCK_SIZE_M": 64,
204
+ "BLOCK_SIZE_N": 256,
205
+ "BLOCK_SIZE_K": 64,
206
+ "GROUP_SIZE_M": 16,
207
+ "num_warps": 4,
208
+ "num_stages": 4
209
+ },
210
+ "58368": {
211
+ "BLOCK_SIZE_M": 64,
212
+ "BLOCK_SIZE_N": 256,
213
+ "BLOCK_SIZE_K": 64,
214
+ "GROUP_SIZE_M": 16,
215
+ "num_warps": 4,
216
+ "num_stages": 4
217
+ }
218
+ }
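Note: these config files map a token count M (the keys "1" through "58368") to Triton launch parameters for the fused MoE kernel; the expert count E, intermediate size N, device name and dtype are encoded in the filename. A minimal sketch of how such a file could be read and the nearest-M entry picked (hypothetical helper, not the lookup code shipped in this commit):

    import json

    def load_moe_config(path: str, num_tokens: int) -> dict:
        """Return the tuned kernel config whose key is closest to num_tokens."""
        with open(path) as f:
            configs = {int(m): cfg for m, cfg in json.load(f).items()}
        best_m = min(configs, key=lambda m: abs(m - num_tokens))
        return configs[best_m]

    # e.g. load_moe_config(<this file>, num_tokens=100) would return the "96" entry.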
.venv/lib/python3.11/site-packages/vllm/model_executor/layers/fused_moe/configs/E=16,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json ADDED
@@ -0,0 +1,146 @@
1
+ {
2
+ "1": {
3
+ "BLOCK_SIZE_M": 16,
4
+ "BLOCK_SIZE_N": 32,
5
+ "BLOCK_SIZE_K": 128,
6
+ "GROUP_SIZE_M": 1,
7
+ "num_warps": 4,
8
+ "num_stages": 3
9
+ },
10
+ "2": {
11
+ "BLOCK_SIZE_M": 16,
12
+ "BLOCK_SIZE_N": 32,
13
+ "BLOCK_SIZE_K": 128,
14
+ "GROUP_SIZE_M": 1,
15
+ "num_warps": 4,
16
+ "num_stages": 3
17
+ },
18
+ "4": {
19
+ "BLOCK_SIZE_M": 16,
20
+ "BLOCK_SIZE_N": 64,
21
+ "BLOCK_SIZE_K": 128,
22
+ "GROUP_SIZE_M": 1,
23
+ "num_warps": 4,
24
+ "num_stages": 5
25
+ },
26
+ "8": {
27
+ "BLOCK_SIZE_M": 16,
28
+ "BLOCK_SIZE_N": 64,
29
+ "BLOCK_SIZE_K": 256,
30
+ "GROUP_SIZE_M": 1,
31
+ "num_warps": 8,
32
+ "num_stages": 3
33
+ },
34
+ "16": {
35
+ "BLOCK_SIZE_M": 16,
36
+ "BLOCK_SIZE_N": 64,
37
+ "BLOCK_SIZE_K": 256,
38
+ "GROUP_SIZE_M": 1,
39
+ "num_warps": 8,
40
+ "num_stages": 3
41
+ },
42
+ "24": {
43
+ "BLOCK_SIZE_M": 16,
44
+ "BLOCK_SIZE_N": 64,
45
+ "BLOCK_SIZE_K": 256,
46
+ "GROUP_SIZE_M": 32,
47
+ "num_warps": 8,
48
+ "num_stages": 3
49
+ },
50
+ "32": {
51
+ "BLOCK_SIZE_M": 16,
52
+ "BLOCK_SIZE_N": 64,
53
+ "BLOCK_SIZE_K": 128,
54
+ "GROUP_SIZE_M": 16,
55
+ "num_warps": 4,
56
+ "num_stages": 4
57
+ },
58
+ "48": {
59
+ "BLOCK_SIZE_M": 16,
60
+ "BLOCK_SIZE_N": 64,
61
+ "BLOCK_SIZE_K": 128,
62
+ "GROUP_SIZE_M": 1,
63
+ "num_warps": 4,
64
+ "num_stages": 5
65
+ },
66
+ "64": {
67
+ "BLOCK_SIZE_M": 16,
68
+ "BLOCK_SIZE_N": 64,
69
+ "BLOCK_SIZE_K": 128,
70
+ "GROUP_SIZE_M": 1,
71
+ "num_warps": 4,
72
+ "num_stages": 5
73
+ },
74
+ "96": {
75
+ "BLOCK_SIZE_M": 16,
76
+ "BLOCK_SIZE_N": 64,
77
+ "BLOCK_SIZE_K": 256,
78
+ "GROUP_SIZE_M": 64,
79
+ "num_warps": 4,
80
+ "num_stages": 3
81
+ },
82
+ "128": {
83
+ "BLOCK_SIZE_M": 32,
84
+ "BLOCK_SIZE_N": 128,
85
+ "BLOCK_SIZE_K": 128,
86
+ "GROUP_SIZE_M": 1,
87
+ "num_warps": 4,
88
+ "num_stages": 3
89
+ },
90
+ "256": {
91
+ "BLOCK_SIZE_M": 32,
92
+ "BLOCK_SIZE_N": 128,
93
+ "BLOCK_SIZE_K": 128,
94
+ "GROUP_SIZE_M": 16,
95
+ "num_warps": 4,
96
+ "num_stages": 3
97
+ },
98
+ "512": {
99
+ "BLOCK_SIZE_M": 64,
100
+ "BLOCK_SIZE_N": 256,
101
+ "BLOCK_SIZE_K": 64,
102
+ "GROUP_SIZE_M": 1,
103
+ "num_warps": 4,
104
+ "num_stages": 3
105
+ },
106
+ "1024": {
107
+ "BLOCK_SIZE_M": 64,
108
+ "BLOCK_SIZE_N": 256,
109
+ "BLOCK_SIZE_K": 64,
110
+ "GROUP_SIZE_M": 16,
111
+ "num_warps": 4,
112
+ "num_stages": 4
113
+ },
114
+ "1536": {
115
+ "BLOCK_SIZE_M": 64,
116
+ "BLOCK_SIZE_N": 256,
117
+ "BLOCK_SIZE_K": 64,
118
+ "GROUP_SIZE_M": 16,
119
+ "num_warps": 4,
120
+ "num_stages": 4
121
+ },
122
+ "2048": {
123
+ "BLOCK_SIZE_M": 64,
124
+ "BLOCK_SIZE_N": 256,
125
+ "BLOCK_SIZE_K": 64,
126
+ "GROUP_SIZE_M": 32,
127
+ "num_warps": 4,
128
+ "num_stages": 4
129
+ },
130
+ "3072": {
131
+ "BLOCK_SIZE_M": 64,
132
+ "BLOCK_SIZE_N": 256,
133
+ "BLOCK_SIZE_K": 64,
134
+ "GROUP_SIZE_M": 32,
135
+ "num_warps": 4,
136
+ "num_stages": 3
137
+ },
138
+ "4096": {
139
+ "BLOCK_SIZE_M": 64,
140
+ "BLOCK_SIZE_N": 256,
141
+ "BLOCK_SIZE_K": 64,
142
+ "GROUP_SIZE_M": 16,
143
+ "num_warps": 4,
144
+ "num_stages": 3
145
+ }
146
+ }
.venv/lib/python3.11/site-packages/vllm/model_executor/layers/fused_moe/configs/E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json ADDED
@@ -0,0 +1,146 @@
1
+ {
2
+ "1": {
3
+ "BLOCK_SIZE_M": 16,
4
+ "BLOCK_SIZE_N": 32,
5
+ "BLOCK_SIZE_K": 128,
6
+ "GROUP_SIZE_M": 1,
7
+ "num_warps": 4,
8
+ "num_stages": 3
9
+ },
10
+ "2": {
11
+ "BLOCK_SIZE_M": 16,
12
+ "BLOCK_SIZE_N": 32,
13
+ "BLOCK_SIZE_K": 128,
14
+ "GROUP_SIZE_M": 1,
15
+ "num_warps": 4,
16
+ "num_stages": 3
17
+ },
18
+ "4": {
19
+ "BLOCK_SIZE_M": 16,
20
+ "BLOCK_SIZE_N": 32,
21
+ "BLOCK_SIZE_K": 256,
22
+ "GROUP_SIZE_M": 16,
23
+ "num_warps": 4,
24
+ "num_stages": 2
25
+ },
26
+ "8": {
27
+ "BLOCK_SIZE_M": 16,
28
+ "BLOCK_SIZE_N": 64,
29
+ "BLOCK_SIZE_K": 256,
30
+ "GROUP_SIZE_M": 16,
31
+ "num_warps": 4,
32
+ "num_stages": 3
33
+ },
34
+ "16": {
35
+ "BLOCK_SIZE_M": 16,
36
+ "BLOCK_SIZE_N": 64,
37
+ "BLOCK_SIZE_K": 128,
38
+ "GROUP_SIZE_M": 32,
39
+ "num_warps": 4,
40
+ "num_stages": 4
41
+ },
42
+ "24": {
43
+ "BLOCK_SIZE_M": 16,
44
+ "BLOCK_SIZE_N": 64,
45
+ "BLOCK_SIZE_K": 256,
46
+ "GROUP_SIZE_M": 16,
47
+ "num_warps": 4,
48
+ "num_stages": 3
49
+ },
50
+ "32": {
51
+ "BLOCK_SIZE_M": 16,
52
+ "BLOCK_SIZE_N": 128,
53
+ "BLOCK_SIZE_K": 256,
54
+ "GROUP_SIZE_M": 64,
55
+ "num_warps": 8,
56
+ "num_stages": 3
57
+ },
58
+ "48": {
59
+ "BLOCK_SIZE_M": 16,
60
+ "BLOCK_SIZE_N": 128,
61
+ "BLOCK_SIZE_K": 256,
62
+ "GROUP_SIZE_M": 1,
63
+ "num_warps": 8,
64
+ "num_stages": 3
65
+ },
66
+ "64": {
67
+ "BLOCK_SIZE_M": 16,
68
+ "BLOCK_SIZE_N": 128,
69
+ "BLOCK_SIZE_K": 256,
70
+ "GROUP_SIZE_M": 1,
71
+ "num_warps": 8,
72
+ "num_stages": 3
73
+ },
74
+ "96": {
75
+ "BLOCK_SIZE_M": 16,
76
+ "BLOCK_SIZE_N": 64,
77
+ "BLOCK_SIZE_K": 256,
78
+ "GROUP_SIZE_M": 64,
79
+ "num_warps": 4,
80
+ "num_stages": 3
81
+ },
82
+ "128": {
83
+ "BLOCK_SIZE_M": 32,
84
+ "BLOCK_SIZE_N": 128,
85
+ "BLOCK_SIZE_K": 128,
86
+ "GROUP_SIZE_M": 1,
87
+ "num_warps": 4,
88
+ "num_stages": 3
89
+ },
90
+ "256": {
91
+ "BLOCK_SIZE_M": 32,
92
+ "BLOCK_SIZE_N": 128,
93
+ "BLOCK_SIZE_K": 128,
94
+ "GROUP_SIZE_M": 16,
95
+ "num_warps": 4,
96
+ "num_stages": 3
97
+ },
98
+ "512": {
99
+ "BLOCK_SIZE_M": 64,
100
+ "BLOCK_SIZE_N": 256,
101
+ "BLOCK_SIZE_K": 64,
102
+ "GROUP_SIZE_M": 64,
103
+ "num_warps": 4,
104
+ "num_stages": 4
105
+ },
106
+ "1024": {
107
+ "BLOCK_SIZE_M": 64,
108
+ "BLOCK_SIZE_N": 256,
109
+ "BLOCK_SIZE_K": 64,
110
+ "GROUP_SIZE_M": 16,
111
+ "num_warps": 4,
112
+ "num_stages": 4
113
+ },
114
+ "1536": {
115
+ "BLOCK_SIZE_M": 64,
116
+ "BLOCK_SIZE_N": 256,
117
+ "BLOCK_SIZE_K": 64,
118
+ "GROUP_SIZE_M": 16,
119
+ "num_warps": 4,
120
+ "num_stages": 4
121
+ },
122
+ "2048": {
123
+ "BLOCK_SIZE_M": 64,
124
+ "BLOCK_SIZE_N": 256,
125
+ "BLOCK_SIZE_K": 64,
126
+ "GROUP_SIZE_M": 16,
127
+ "num_warps": 4,
128
+ "num_stages": 4
129
+ },
130
+ "3072": {
131
+ "BLOCK_SIZE_M": 64,
132
+ "BLOCK_SIZE_N": 256,
133
+ "BLOCK_SIZE_K": 64,
134
+ "GROUP_SIZE_M": 32,
135
+ "num_warps": 4,
136
+ "num_stages": 4
137
+ },
138
+ "4096": {
139
+ "BLOCK_SIZE_M": 64,
140
+ "BLOCK_SIZE_N": 256,
141
+ "BLOCK_SIZE_K": 64,
142
+ "GROUP_SIZE_M": 16,
143
+ "num_warps": 4,
144
+ "num_stages": 4
145
+ }
146
+ }
.venv/lib/python3.11/site-packages/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py ADDED
@@ -0,0 +1,360 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ """Fused MoE utilities for GPTQ."""
3
+ import functools
4
+ from typing import Optional
5
+
6
+ import torch
7
+
8
+ from vllm.model_executor.layers.fused_moe.fused_moe import (
9
+ fused_topk, moe_align_block_size, try_get_optimal_moe_config)
10
+ from vllm.scalar_type import scalar_types
11
+ from vllm.utils import direct_register_custom_op
12
+
13
+
14
+ def get_scalar_type(num_bits: int, has_zp: bool):
15
+ if has_zp:
16
+ assert num_bits == 4
17
+ return scalar_types.uint4
18
+ else:
19
+ return scalar_types.uint4b8 if num_bits == 4 else scalar_types.uint8b128
20
+
21
+
22
+ def single_marlin_moe(
23
+ hidden_states: torch.Tensor,
24
+ w: torch.Tensor,
25
+ scales: torch.Tensor,
26
+ gating_output: torch.Tensor,
27
+ topk: int,
28
+ renormalize: bool,
29
+ g_idx: Optional[torch.Tensor] = None,
30
+ sort_indices: Optional[torch.Tensor] = None,
31
+ w_zeros: Optional[torch.Tensor] = None,
32
+ num_bits: int = 8,
33
+ is_k_full: bool = True,
34
+ ) -> torch.Tensor:
35
+ """
36
+ This function computes the multiplication of hidden_states with expert
37
+ weights used in Marlin MoE, using weights w and top-k gating mechanism.
38
+ Its purpose is testing and debugging the fused MoE kernel.
39
+
40
+ Parameters:
41
+ - hidden_states (torch.Tensor): The input tensor to the Marlin Mul.
42
+ - w (torch.Tensor): The set of expert weights.
43
+ - scales (torch.Tensor): The quantization scales.
44
+ - gating_output (torch.Tensor): The output of the gating operation
45
+ (before softmax).
46
+ - g_idx (Optional[torch.Tensor]): Optional act_order indices.
47
+ - sort_indices (Optional[torch.Tensor]): Optional act_order input
48
+ permutation.
49
+ - topk (int): The number of top-k experts to select.
50
+ - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
51
+ - w_zeros (Optional[torch.Tensor]): Optional zero points to be used for w.
52
+ - num_bits (bool): The number of bits in expert weights quantization.
53
+
54
+ Returns:
55
+ - torch.Tensor: The output tensor after applying the MoE layer.
56
+ """
57
+ # Check constraints.
58
+ assert hidden_states.shape[0] == gating_output.shape[0], (
59
+ "Number of tokens mismatch")
60
+ assert hidden_states.shape[1] == w.shape[1] * 16, "Hidden size mismatch"
61
+ assert gating_output.shape[1] == w.shape[0], "Number of experts mismatch"
62
+ assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
63
+ assert w.is_contiguous(), "Expert weights must be contiguous"
64
+ assert hidden_states.dtype == torch.float16
65
+ assert num_bits in [4, 8]
66
+
67
+ M, K = hidden_states.shape
68
+ E = w.shape[0]
69
+ N = w.shape[2] // (num_bits // 2)
70
+
71
+ topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
72
+ renormalize)
73
+
74
+ # This might not be an optimal config for a single MMM
75
+ get_config_func = functools.partial(try_get_optimal_moe_config,
76
+ w.shape,
77
+ w.shape,
78
+ topk_ids.shape[1],
79
+ None,
80
+ is_marlin=True)
81
+ config = get_config_func(M)
82
+
83
+ block_size_m = config['BLOCK_SIZE_M']
84
+
85
+ sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E)
86
+
87
+ max_workspace_size = (N // 64) * 16
88
+ workspace = torch.zeros(max_workspace_size,
89
+ dtype=torch.int,
90
+ device=hidden_states.device,
91
+ requires_grad=False)
92
+
93
+ has_zero_point = w_zeros is not None
94
+ if w_zeros is None:
95
+ w_zeros = torch.empty((0, 0),
96
+ dtype=hidden_states.dtype,
97
+ device=hidden_states.device,
98
+ requires_grad=False)
99
+
100
+ if g_idx is None:
101
+ g_idx = torch.empty((0, 0),
102
+ dtype=torch.int32,
103
+ device=hidden_states.device,
104
+ requires_grad=False)
105
+
106
+ if sort_indices is None:
107
+ sort_indices = torch.empty((0),
108
+ dtype=torch.int32,
109
+ device=hidden_states.device,
110
+ requires_grad=False)
111
+
112
+ scalar_type = get_scalar_type(num_bits, has_zero_point)
113
+
114
+ intermediate_cache = torch.ops._moe_C.marlin_gemm_moe(
115
+ hidden_states, w, sorted_token_ids, topk_weights, topk_ids, scales,
116
+ w_zeros, g_idx, sort_indices, workspace, scalar_type.id, M, N, K,
117
+ is_k_full, E, topk, block_size_m, True, False)
118
+
119
+ return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1)
120
+
121
+
122
+ def single_marlin_moe_fake(
123
+ hidden_states: torch.Tensor,
124
+ w: torch.Tensor,
125
+ scales: torch.Tensor,
126
+ gating_output: torch.Tensor,
127
+ topk: int,
128
+ renormalize: bool,
129
+ g_idx: Optional[torch.Tensor] = None,
130
+ sort_indices: Optional[torch.Tensor] = None,
131
+ w_zeros: Optional[torch.Tensor] = None,
132
+ num_bits: int = 8,
133
+ is_k_full: bool = True,
134
+ ) -> torch.Tensor:
135
+ return torch.empty_like(hidden_states)
136
+
137
+
138
+ direct_register_custom_op(
139
+ op_name="single_marlin_moe",
140
+ op_func=single_marlin_moe,
141
+ mutates_args=[],
142
+ fake_impl=single_marlin_moe_fake,
143
+ )
144
+
145
+
146
+ def fused_marlin_moe(
147
+ hidden_states: torch.Tensor,
148
+ w1: torch.Tensor,
149
+ w2: torch.Tensor,
150
+ w1_scale: torch.Tensor,
151
+ w2_scale: torch.Tensor,
152
+ gating_output: torch.Tensor,
153
+ topk_weights: torch.Tensor,
154
+ topk_ids: torch.Tensor,
155
+ g_idx1: Optional[torch.Tensor] = None,
156
+ g_idx2: Optional[torch.Tensor] = None,
157
+ sort_indices1: Optional[torch.Tensor] = None,
158
+ sort_indices2: Optional[torch.Tensor] = None,
159
+ w1_zeros: Optional[torch.Tensor] = None,
160
+ w2_zeros: Optional[torch.Tensor] = None,
161
+ num_bits: int = 8,
162
+ is_k_full: bool = True,
163
+ ) -> torch.Tensor:
164
+ """
165
+ This function computes a Mixture of Experts (MoE) layer using two sets of
166
+ weights, w1 and w2, and top-k gating mechanism.
167
+
168
+ Parameters:
169
+ - hidden_states (torch.Tensor): The input tensor to the MoE layer.
170
+ - w1 (torch.Tensor): The first set of expert weights.
171
+ - w2 (torch.Tensor): The second set of expert weights.
172
+ - w1_scale (torch.Tensor): Scale to be used for w1.
173
+ - w2_scale (torch.Tensor): Scale to be used for w2.
174
+ - gating_output (torch.Tensor): The output of the gating operation
175
+ (before softmax).
176
+ - g_idx1 (Optional[torch.Tensor]): The first set of act_order indices.
177
+ - g_idx2 (Optional[torch.Tensor]): The second set of act_order indices.
178
+ - sort_indices1 (Optional[torch.Tensor]): The first act_order input
179
+ permutation.
180
+ - sort_indices2 (Optional[torch.Tensor]): The second act_order input
181
+ permutation.
182
+ - topk_weights (torch.Tensor): Top-k weights.
183
+ - topk_ids (torch.Tensor): Indices of topk-k elements.
184
+ - w1_zeros (Optional[torch.Tensor]): Optional zero points to be used for w1.
185
+ - w2_zeros (Optional[torch.Tensor]): Optional zero points to be used for w2.
186
+ - num_bits (bool): The number of bits in expert weights quantization.
187
+
188
+ Returns:
189
+ - torch.Tensor: The output tensor after applying the MoE layer.
190
+ """
191
+ # Check constraints.
192
+ assert hidden_states.shape[0] == gating_output.shape[
193
+ 0], "Number of tokens mismatch"
194
+ assert hidden_states.shape[
195
+ 1] == w1.shape[1] * 16, "Hidden size mismatch w1"
196
+ assert hidden_states.shape[1] == w2.shape[2] // (
197
+ num_bits // 2), "Hidden size mismatch w2"
198
+ assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
199
+ assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
200
+ assert w1.is_contiguous(), "Expert weights1 must be contiguous"
201
+ assert w2.is_contiguous(), "Expert weights2 must be contiguous"
202
+ assert hidden_states.dtype == torch.float16
203
+ assert num_bits in [4, 8]
204
+
205
+ has_no_act_order = (g_idx1 is None and g_idx2 is None
206
+ and sort_indices1 is None and sort_indices2 is None)
207
+ has_all_act_order = (g_idx1 is not None and g_idx2 is not None
208
+ and sort_indices1 is not None
209
+ and sort_indices2 is not None)
210
+ assert has_no_act_order or has_all_act_order, (
211
+ "g_idx and sorted_indices "
212
+ "must be all not None or must be all None")
213
+
214
+ has_no_zp = w1_zeros is None and w2_zeros is None
215
+ has_all_zp = w1_zeros is not None and w2_zeros is not None
216
+ assert has_no_zp or has_all_zp, ("zero points must be both not None or "
217
+ "must be both None")
218
+
219
+ M, K = hidden_states.shape
220
+ E = w1.shape[0]
221
+ N = w2.shape[1] * 16
222
+ topk = topk_ids.shape[1]
223
+
224
+ get_config_func = functools.partial(
225
+ try_get_optimal_moe_config,
226
+ w1.shape,
227
+ w2.shape,
228
+ topk_ids.shape[1],
229
+ None,
230
+ is_marlin=True,
231
+ )
232
+ config = get_config_func(M)
233
+
234
+ block_size_m = config["BLOCK_SIZE_M"]
235
+
236
+ sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E)
237
+
238
+ max_workspace_size = (max(2 * N, K) // 64) * 16
239
+ workspace = torch.zeros(max_workspace_size,
240
+ dtype=torch.int,
241
+ device="cuda",
242
+ requires_grad=False)
243
+
244
+ if has_no_zp:
245
+ w1_zeros = torch.empty((0, 0),
246
+ dtype=hidden_states.dtype,
247
+ device=hidden_states.device,
248
+ requires_grad=False)
249
+ w2_zeros = torch.empty((0, 0),
250
+ dtype=hidden_states.dtype,
251
+ device=hidden_states.device,
252
+ requires_grad=False)
253
+
254
+ if has_no_act_order:
255
+ g_idx1 = torch.empty((0, 0),
256
+ dtype=torch.int32,
257
+ device=hidden_states.device,
258
+ requires_grad=False)
259
+ g_idx2 = torch.empty((0, 0),
260
+ dtype=torch.int32,
261
+ device=hidden_states.device,
262
+ requires_grad=False)
263
+ sort_indices1 = torch.empty((0),
264
+ dtype=torch.int32,
265
+ device=hidden_states.device,
266
+ requires_grad=False)
267
+ sort_indices2 = torch.empty((0, 0),
268
+ dtype=torch.int32,
269
+ device=hidden_states.device,
270
+ requires_grad=False)
271
+
272
+ scalar_type1 = get_scalar_type(num_bits, has_all_zp)
273
+ scalar_type2 = get_scalar_type(num_bits, has_all_zp)
274
+
275
+ intermediate_cache2 = torch.empty(
276
+ (M * topk_ids.shape[1], N),
277
+ device=hidden_states.device,
278
+ dtype=hidden_states.dtype,
279
+ )
280
+
281
+ intermediate_cache1 = torch.ops._moe_C.marlin_gemm_moe(
282
+ hidden_states,
283
+ w1,
284
+ sorted_token_ids,
285
+ topk_weights,
286
+ topk_ids,
287
+ w1_scale,
288
+ w1_zeros,
289
+ g_idx1,
290
+ sort_indices1,
291
+ workspace,
292
+ scalar_type1.id,
293
+ M,
294
+ 2 * N,
295
+ K,
296
+ is_k_full,
297
+ E,
298
+ topk,
299
+ block_size_m,
300
+ True,
301
+ False,
302
+ )
303
+
304
+ torch.ops._C.silu_and_mul(intermediate_cache2,
305
+ intermediate_cache1.view(-1, 2 * N))
306
+
307
+ intermediate_cache3 = torch.ops._moe_C.marlin_gemm_moe(
308
+ intermediate_cache2,
309
+ w2,
310
+ sorted_token_ids,
311
+ topk_weights,
312
+ topk_ids,
313
+ w2_scale,
314
+ w2_zeros,
315
+ g_idx2,
316
+ sort_indices2,
317
+ workspace,
318
+ scalar_type2.id,
319
+ M,
320
+ K,
321
+ N,
322
+ is_k_full,
323
+ E,
324
+ topk,
325
+ block_size_m,
326
+ False,
327
+ True,
328
+ )
329
+
330
+ return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
331
+ dim=1)
332
+
333
+
334
+ def fused_marlin_moe_fake(
335
+ hidden_states: torch.Tensor,
336
+ w1: torch.Tensor,
337
+ w2: torch.Tensor,
338
+ w1_scale: torch.Tensor,
339
+ w2_scale: torch.Tensor,
340
+ gating_output: torch.Tensor,
341
+ topk_weights: torch.Tensor,
342
+ topk_ids: torch.Tensor,
343
+ g_idx1: Optional[torch.Tensor] = None,
344
+ g_idx2: Optional[torch.Tensor] = None,
345
+ sort_indices1: Optional[torch.Tensor] = None,
346
+ sort_indices2: Optional[torch.Tensor] = None,
347
+ w1_zeros: Optional[torch.Tensor] = None,
348
+ w2_zeros: Optional[torch.Tensor] = None,
349
+ num_bits: int = 8,
350
+ is_k_full: bool = True,
351
+ ) -> torch.Tensor:
352
+ return torch.empty_like(hidden_states)
353
+
354
+
355
+ direct_register_custom_op(
356
+ op_name="fused_marlin_moe",
357
+ op_func=fused_marlin_moe,
358
+ mutates_args=[],
359
+ fake_impl=fused_marlin_moe_fake,
360
+ )
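Note: fused_marlin_moe above chains two marlin_gemm_moe launches with a silu_and_mul between them. As a reference for what that fused path computes, here is an unfused plain-PyTorch sketch of the same top-k MoE forward pass (illustrative only; it ignores quantization, scales, zero points and act_order handling):

    import torch
    import torch.nn.functional as F

    def moe_reference(x, w1, w2, gating_output, topk):
        """x: (M, K); w1: (E, 2N, K); w2: (E, K, N); gating_output: (M, E)."""
        probs = F.softmax(gating_output, dim=-1)
        topk_weights, topk_ids = probs.topk(topk, dim=-1)                  # (M, topk)
        topk_weights = topk_weights / topk_weights.sum(-1, keepdim=True)   # renormalize
        out = torch.zeros_like(x)
        for e in range(w1.shape[0]):
            hit = topk_ids == e                                            # (M, topk)
            rows = hit.any(-1)
            if not rows.any():
                continue
            h = x[rows] @ w1[e].t()                                        # gate/up proj: (m, 2N)
            n = h.shape[-1] // 2
            h = F.silu(h[..., :n]) * h[..., n:]                            # silu_and_mul
            y = h @ w2[e].t()                                              # down proj: (m, K)
            out[rows] += y * (topk_weights * hit)[rows].sum(-1, keepdim=True)
        return out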
.venv/lib/python3.11/site-packages/vllm/model_executor/layers/fused_moe/fused_moe.py ADDED
@@ -0,0 +1,1363 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ """Fused MoE kernel."""
3
+ import functools
4
+ import json
5
+ import os
6
+ from typing import Any, Callable, Dict, List, Optional, Tuple
7
+
8
+ import torch
9
+ import triton
10
+ import triton.language as tl
11
+
12
+ import vllm.envs as envs
13
+ from vllm import _custom_ops as ops
14
+ from vllm.logger import init_logger
15
+ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
16
+ per_token_group_quant_fp8)
17
+ from vllm.platforms import current_platform
18
+ from vllm.utils import direct_register_custom_op
19
+
20
+ logger = init_logger(__name__)
21
+
22
+
23
+ @triton.jit
24
+ def fused_moe_kernel_gptq_awq(
25
+ # Pointers to matrices
26
+ a_ptr,
27
+ b_ptr,
28
+ c_ptr,
29
+ b_scale_ptr,
30
+ b_zp_ptr,
31
+ topk_weights_ptr,
32
+ sorted_token_ids_ptr,
33
+ expert_ids_ptr,
34
+ num_tokens_post_padded_ptr,
35
+ # Matrix dimensions
36
+ N: tl.constexpr,
37
+ K: tl.constexpr,
38
+ EM,
39
+ num_valid_tokens,
40
+ # The stride variables represent how much to increase the ptr by when
41
+ # moving by 1 element in a particular dimension. E.g. `stride_am` is
42
+ # how much to increase `a_ptr` by to get the element one row down
43
+ # (A has M rows).
44
+ stride_am,
45
+ stride_ak,
46
+ stride_be,
47
+ stride_bk,
48
+ stride_bn,
49
+ stride_cm,
50
+ stride_cn,
51
+ stride_bse,
52
+ stride_bsk,
53
+ stride_bsn,
54
+ stride_bze,
55
+ stride_bzk,
56
+ stride_bzn,
57
+ block_k_diviable: tl.constexpr,
58
+ group_size: tl.constexpr,
59
+ # Meta-parameters
60
+ BLOCK_SIZE_M: tl.constexpr,
61
+ BLOCK_SIZE_N: tl.constexpr,
62
+ BLOCK_SIZE_K: tl.constexpr,
63
+ GROUP_SIZE_M: tl.constexpr,
64
+ MUL_ROUTED_WEIGHT: tl.constexpr,
65
+ top_k: tl.constexpr,
66
+ compute_type: tl.constexpr,
67
+ has_zp: tl.constexpr,
68
+ use_int4_w4a16: tl.constexpr,
69
+ use_int8_w8a16: tl.constexpr):
70
+ """
71
+ Implements the fused computation for a Mixture of Experts (MOE) using
72
+ token and expert matrices.
73
+
74
+ Key Parameters:
75
+ - A: The input tensor representing tokens with shape (*, K), where '*' can
76
+ be any shape representing batches and K is the feature dimension of
77
+ each token.
78
+ - B: The stacked MOE weight tensor with shape (E, N, K), where E is
79
+ the number of experts, K is the input feature dimension, and N is
80
+ the output feature dimension.
81
+ - C: The output cache tensor with shape (M, topk, N), where M is the
82
+ total number of tokens post padding, topk is the number of times
83
+ each token is repeated, and N is the output feature dimension.
84
+ - sorted_token_ids: A tensor containing the sorted indices of tokens,
85
+ repeated topk times and arranged by the expert index they are
86
+ assigned to.
87
+ - expert_ids: A tensor containing the indices of the expert for each
88
+ block. It determines which expert matrix from B should be used for
89
+ each block in A.
90
+ This kernel performs the multiplication of a token by its corresponding
91
+ expert matrix as determined by `expert_ids`. The sorting of
92
+ `sorted_token_ids` by expert index and padding ensures divisibility by
93
+ BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix
94
+ multiplication across different blocks processed by the same expert.
95
+ """
96
+ # -----------------------------------------------------------
97
+ # Map program ids `pid` to the block of C it should compute.
98
+ # This is done in a grouped ordering to promote L2 data reuse.
99
+ pid = tl.program_id(axis=0)
100
+ num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
101
+ num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
102
+ num_pid_in_group = GROUP_SIZE_M * num_pid_n
103
+ group_id = pid // num_pid_in_group
104
+ first_pid_m = group_id * GROUP_SIZE_M
105
+ group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
106
+ pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
107
+ pid_n = (pid % num_pid_in_group) // group_size_m
108
+
109
+ # ----------------------------------------------------------
110
+ # Create pointers for the first blocks of A and B.
111
+ # We will advance this pointer as we move in the K direction
112
+ # and accumulate
113
+ # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
114
+ # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
115
+ num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
116
+ if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
117
+ return
118
+ offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(
119
+ tl.int64)
120
+ offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
121
+ token_mask = offs_token < num_valid_tokens
122
+
123
+ offs_bn = (pid_n * BLOCK_SIZE_N +
124
+ tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
125
+ offs_k = tl.arange(0, BLOCK_SIZE_K)
126
+ a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am +
127
+ offs_k[None, :] * stride_ak)
128
+
129
+ off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
130
+
131
+ if use_int4_w4a16:
132
+ b_ptrs = b_ptr + off_experts * stride_be + \
133
+ (offs_k[:, None] // 2) * stride_bk + offs_bn[None, :] * stride_bn
134
+ b_shifter = (offs_k[:, None] % 2) * 4
135
+ elif use_int8_w8a16:
136
+ b_ptrs = b_ptr + off_experts * stride_be + \
137
+ offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn
138
+
139
+ if not has_zp and use_int4_w4a16:
140
+ b_zp_num = 8
141
+ if not has_zp and use_int8_w8a16:
142
+ b_zp_num = 128
143
+ elif has_zp and use_int4_w4a16:
144
+ b_zp_shifter = (offs_bn[None, :] % 2) * 4
145
+
146
+ # -----------------------------------------------------------
147
+ # Iterate to compute a block of the C matrix.
148
+ # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
149
+ # of fp32 values for higher accuracy.
150
+ # `accumulator` will be converted back to fp16 after the loop.
151
+ accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
152
+ for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
153
+ # Load the next block of A and B, generate a mask by checking the
154
+ # K dimension.
155
+
156
+ if not block_k_diviable:
157
+ k_mask = offs_k[:, None] < K - k * BLOCK_SIZE_K
158
+ k_other = 0.0
159
+ else:
160
+ k_mask = None
161
+ k_other = None
162
+
163
+ a = tl.load(a_ptrs,
164
+ mask=token_mask[:, None] &
165
+ (offs_k[None, :] < K - k * BLOCK_SIZE_K),
166
+ other=0.0)
167
+ b = tl.load(b_ptrs)
168
+ if use_int4_w4a16:
169
+ b = (b >> b_shifter) & 0xF
170
+
171
+ b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + \
172
+ offs_bn[None, :] * stride_bsn + \
173
+ ((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * stride_bsk
174
+ b_scale = tl.load(b_scale_ptrs, mask=k_mask, other=k_other)
175
+ b_scale = b_scale.to(tl.float32)
176
+
177
+ if has_zp and use_int4_w4a16:
178
+ offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size
179
+ b_zp_ptrs = b_zp_ptr + off_experts * stride_bze + \
180
+ (offs_bn[None, :] // 2) * stride_bzn + \
181
+ offs_k_true * stride_bzk
182
+ b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other)
183
+ b_zp = ((b_zp >> b_zp_shifter) & 0xF)
184
+ b_zp = b_zp.to(tl.float32)
185
+ elif has_zp and use_int8_w8a16:
186
+ offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size
187
+ b_zp_ptrs = b_zp_ptr + off_experts * stride_bze + \
188
+ offs_bn[None, :] * stride_bzn + \
189
+ offs_k_true * stride_bzk
190
+ b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other)
191
+ b_zp = b_zp.to(tl.float32)
192
+
193
+ # We accumulate along the K dimension.
194
+ if has_zp:
195
+ b = ((b.to(tl.float32) - b_zp) * b_scale).to(compute_type)
196
+ else:
197
+ b = ((b.to(tl.float32) - b_zp_num) * b_scale).to(compute_type)
198
+ accumulator = tl.dot(a, b, acc=accumulator)
199
+
200
+ # Advance the ptrs to the next K block.
201
+ a_ptrs += BLOCK_SIZE_K * stride_ak
202
+ if use_int4_w4a16:
203
+ b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk
204
+ else:
205
+ b_ptrs += BLOCK_SIZE_K * stride_bk
206
+
207
+ if MUL_ROUTED_WEIGHT:
208
+ moe_weight = tl.load(topk_weights_ptr + offs_token,
209
+ mask=token_mask,
210
+ other=0)
211
+ accumulator = accumulator * moe_weight[:, None]
212
+
213
+ accumulator = accumulator.to(compute_type)
214
+ # -----------------------------------------------------------
215
+ # Write back the block of the output
216
+ offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
217
+ c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[
218
+ None, :]
219
+ c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
220
+ tl.store(c_ptrs, accumulator, mask=c_mask)
221
+
222
+
223
+ @triton.jit
224
+ def fused_moe_kernel(
225
+ # Pointers to matrices
226
+ a_ptr,
227
+ b_ptr,
228
+ c_ptr,
229
+ a_scale_ptr,
230
+ b_scale_ptr,
231
+ topk_weights_ptr,
232
+ sorted_token_ids_ptr,
233
+ expert_ids_ptr,
234
+ num_tokens_post_padded_ptr,
235
+ # Matrix dimensions
236
+ N,
237
+ K,
238
+ EM,
239
+ num_valid_tokens,
240
+ # The stride variables represent how much to increase the ptr by when
241
+ # moving by 1 element in a particular dimension. E.g. `stride_am` is
242
+ # how much to increase `a_ptr` by to get the element one row down
243
+ # (A has M rows).
244
+ stride_am,
245
+ stride_ak,
246
+ stride_be,
247
+ stride_bk,
248
+ stride_bn,
249
+ stride_cm,
250
+ stride_cn,
251
+ stride_asm,
252
+ stride_ask,
253
+ stride_bse,
254
+ stride_bsk,
255
+ stride_bsn,
256
+ # Block size for block-wise quantization
257
+ group_n: tl.constexpr,
258
+ group_k: tl.constexpr,
259
+ # Meta-parameters
260
+ BLOCK_SIZE_M: tl.constexpr,
261
+ BLOCK_SIZE_N: tl.constexpr,
262
+ BLOCK_SIZE_K: tl.constexpr,
263
+ GROUP_SIZE_M: tl.constexpr,
264
+ MUL_ROUTED_WEIGHT: tl.constexpr,
265
+ top_k: tl.constexpr,
266
+ compute_type: tl.constexpr,
267
+ use_fp8_w8a8: tl.constexpr,
268
+ use_int8_w8a16: tl.constexpr):
269
+ """
270
+ Implements the fused computation for a Mixture of Experts (MOE) using
271
+ token and expert matrices.
272
+
273
+ Key Parameters:
274
+ - A: The input tensor representing tokens with shape (*, K), where '*' can
275
+ be any shape representing batches and K is the feature dimension of
276
+ each token.
277
+ - B: The stacked MOE weight tensor with shape (E, N, K), where E is
278
+ the number of experts, K is the input feature dimension, and N is
279
+ the output feature dimension.
280
+ - C: The output cache tensor with shape (M, topk, N), where M is the
281
+ total number of tokens post padding, topk is the number of times
282
+ each token is repeated, and N is the output feature dimension.
283
+ - sorted_token_ids: A tensor containing the sorted indices of tokens,
284
+ repeated topk times and arranged by the expert index they are
285
+ assigned to.
286
+ - expert_ids: A tensor containing the indices of the expert for each
287
+ block. It determines which expert matrix from B should be used for
288
+ each block in A.
289
+ This kernel performs the multiplication of a token by its corresponding
290
+ expert matrix as determined by `expert_ids`. The sorting of
291
+ `sorted_token_ids` by expert index and padding ensures divisibility by
292
+ BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix
293
+ multiplication across different blocks processed by the same expert.
294
+ """
295
+ # -----------------------------------------------------------
296
+ # Map program ids `pid` to the block of C it should compute.
297
+ # This is done in a grouped ordering to promote L2 data reuse.
298
+ pid = tl.program_id(axis=0)
299
+ num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
300
+ num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
301
+ num_pid_in_group = GROUP_SIZE_M * num_pid_n
302
+ group_id = pid // num_pid_in_group
303
+ first_pid_m = group_id * GROUP_SIZE_M
304
+ group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
305
+ pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
306
+ pid_n = (pid % num_pid_in_group) // group_size_m
307
+
308
+ # ----------------------------------------------------------
309
+ # Create pointers for the first blocks of A and B.
310
+ # We will advance this pointer as we move in the K direction
311
+ # and accumulate
312
+ # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
313
+ # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
314
+ num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
315
+ if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
316
+ return
317
+ offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(
318
+ tl.int64)
319
+ offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
320
+ token_mask = offs_token < num_valid_tokens
321
+
322
+ offs_bn = (pid_n * BLOCK_SIZE_N +
323
+ tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
324
+ offs_k = tl.arange(0, BLOCK_SIZE_K)
325
+ a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am +
326
+ offs_k[None, :] * stride_ak)
327
+
328
+ off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
329
+ b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk +
330
+ offs_bn[None, :] * stride_bn)
331
+ if use_int8_w8a16:
332
+ b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + offs_bn[
333
+ None, :] * stride_bsn
334
+ b_scale = tl.load(b_scale_ptrs)
335
+
336
+ if use_fp8_w8a8:
337
+ if group_k > 0 and group_n > 0:
338
+ a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
339
+ offs_bsn = offs_bn // group_n
340
+ b_scale_ptrs = (b_scale_ptr + off_experts * stride_bse +
341
+ offs_bsn * stride_bsn)
342
+ else:
343
+ a_scale = tl.load(a_scale_ptr)
344
+ b_scale = tl.load(b_scale_ptr + off_experts)
345
+
346
+ # -----------------------------------------------------------
347
+ # Iterate to compute a block of the C matrix.
348
+ # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
349
+ # of fp32 values for higher accuracy.
350
+ # `accumulator` will be converted back to fp16 after the loop.
351
+ accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
352
+
353
+ for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
354
+ # Load the next block of A and B, generate a mask by checking the
355
+ # K dimension.
356
+ a = tl.load(a_ptrs,
357
+ mask=token_mask[:, None] &
358
+ (offs_k[None, :] < K - k * BLOCK_SIZE_K),
359
+ other=0.0)
360
+ b = tl.load(b_ptrs,
361
+ mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
362
+ other=0.0)
363
+ # We accumulate along the K dimension.
364
+ if use_int8_w8a16:
365
+ accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
366
+ elif use_fp8_w8a8:
367
+ if group_k > 0 and group_n > 0:
368
+ k_start = k * BLOCK_SIZE_K
369
+ offs_ks = k_start // group_k
370
+ a_scale = tl.load(a_scale_ptrs + offs_ks * stride_ask,
371
+ mask=token_mask,
372
+ other=0.0)
373
+ b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk)
374
+
375
+ accumulator += tl.dot(a, b) * a_scale[:,
376
+ None] * b_scale[None, :]
377
+ else:
378
+ accumulator = tl.dot(a, b, acc=accumulator)
379
+ else:
380
+ accumulator += tl.dot(a, b)
381
+ # Advance the ptrs to the next K block.
382
+ a_ptrs += BLOCK_SIZE_K * stride_ak
383
+ b_ptrs += BLOCK_SIZE_K * stride_bk
384
+
385
+ if MUL_ROUTED_WEIGHT:
386
+ moe_weight = tl.load(topk_weights_ptr + offs_token,
387
+ mask=token_mask,
388
+ other=0)
389
+ accumulator = accumulator * moe_weight[:, None]
390
+ if use_int8_w8a16:
391
+ accumulator = (accumulator * b_scale).to(compute_type)
392
+ elif use_fp8_w8a8:
393
+ if group_k > 0 and group_n > 0:
394
+ accumulator = accumulator.to(compute_type)
395
+ else:
396
+ accumulator = (accumulator * a_scale * b_scale).to(compute_type)
397
+ else:
398
+ accumulator = accumulator.to(compute_type)
399
+ # -----------------------------------------------------------
400
+ # Write back the block of the output
401
+ offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
402
+ c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[
403
+ None, :]
404
+ c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
405
+ tl.store(c_ptrs, accumulator, mask=c_mask)
406
+
407
+
408
+ def ceil_div(a, b):
409
+ return (a + b - 1) // b
410
+
411
+
412
+ @triton.jit
413
+ def moe_align_block_size_stage1(
414
+ topk_ids_ptr,
415
+ tokens_cnts_ptr,
416
+ num_experts: tl.constexpr,
417
+ numel: tl.constexpr,
418
+ tokens_per_thread: tl.constexpr,
419
+ ):
420
+ pid = tl.program_id(0)
421
+
422
+ start_idx = pid * tokens_per_thread
423
+
424
+ off_c = (pid + 1) * num_experts
425
+
426
+ for i in range(tokens_per_thread):
427
+ if start_idx + i < numel:
428
+ idx = tl.load(topk_ids_ptr + start_idx + i)
429
+ token_cnt = tl.load(tokens_cnts_ptr + off_c + idx)
430
+ tl.store(tokens_cnts_ptr + off_c + idx, token_cnt + 1)
431
+
432
+
433
+ @triton.jit
434
+ def moe_align_block_size_stage2(
435
+ tokens_cnts_ptr,
436
+ num_experts: tl.constexpr,
437
+ ):
438
+ pid = tl.program_id(0)
439
+
440
+ last_cnt = 0
441
+ for i in range(1, num_experts + 1):
442
+ token_cnt = tl.load(tokens_cnts_ptr + i * num_experts + pid)
443
+ last_cnt = last_cnt + token_cnt
444
+ tl.store(tokens_cnts_ptr + i * num_experts + pid, last_cnt)
445
+
446
+
447
+ @triton.jit
448
+ def moe_align_block_size_stage3(
449
+ total_tokens_post_pad_ptr,
450
+ tokens_cnts_ptr,
451
+ cumsum_ptr,
452
+ num_experts: tl.constexpr,
453
+ block_size: tl.constexpr,
454
+ ):
455
+ last_cumsum = 0
456
+ off_cnt = num_experts * num_experts
457
+ for i in range(1, num_experts + 1):
458
+ token_cnt = tl.load(tokens_cnts_ptr + off_cnt + i - 1)
459
+ last_cumsum = last_cumsum + tl.cdiv(token_cnt, block_size) * block_size
460
+ tl.store(cumsum_ptr + i, last_cumsum)
461
+ tl.store(total_tokens_post_pad_ptr, last_cumsum)
462
+
463
+
464
+ @triton.jit
465
+ def moe_align_block_size_stage4(
466
+ topk_ids_ptr,
467
+ sorted_token_ids_ptr,
468
+ expert_ids_ptr,
469
+ tokens_cnts_ptr,
470
+ cumsum_ptr,
471
+ num_experts: tl.constexpr,
472
+ block_size: tl.constexpr,
473
+ numel: tl.constexpr,
474
+ tokens_per_thread: tl.constexpr,
475
+ ):
476
+ pid = tl.program_id(0)
477
+ start_idx = tl.load(cumsum_ptr + pid)
478
+ end_idx = tl.load(cumsum_ptr + pid + 1)
479
+
480
+ for i in range(start_idx, end_idx, block_size):
481
+ tl.store(expert_ids_ptr + i // block_size, pid)
482
+
483
+ start_idx = pid * tokens_per_thread
484
+ off_t = pid * num_experts
485
+
486
+ for i in range(start_idx, tl.minimum(start_idx + tokens_per_thread,
487
+ numel)):
488
+ expert_id = tl.load(topk_ids_ptr + i)
489
+ token_cnt = tl.load(tokens_cnts_ptr + off_t + expert_id)
490
+ rank_post_pad = token_cnt + tl.load(cumsum_ptr + expert_id)
491
+ tl.store(sorted_token_ids_ptr + rank_post_pad, i)
492
+ tl.store(tokens_cnts_ptr + off_t + expert_id, token_cnt + 1)
493
+
494
+
495
+ # Triton implementation based on:
496
+ # https://github.com/sgl-project/sglang/commit/ba5112ff691d791a9e38c6c71f59324a5fcb49d0
497
+ def moe_align_block_size_triton(
498
+ topk_ids: torch.Tensor,
499
+ num_experts: int,
500
+ block_size: int,
501
+ sorted_token_ids: torch.Tensor,
502
+ expert_ids: torch.Tensor,
503
+ num_tokens_post_pad: torch.Tensor,
504
+ ) -> None:
505
+ numel = topk_ids.numel()
506
+ grid = (num_experts, )
507
+ tokens_cnts = torch.zeros((num_experts + 1, num_experts),
508
+ dtype=torch.int32,
509
+ device=topk_ids.device)
510
+ cumsum = torch.zeros((num_experts + 1, ),
511
+ dtype=torch.int32,
512
+ device=topk_ids.device)
513
+ tokens_per_thread = ceil_div(numel, num_experts)
514
+
515
+ moe_align_block_size_stage1[grid](
516
+ topk_ids,
517
+ tokens_cnts,
518
+ num_experts,
519
+ numel,
520
+ tokens_per_thread,
521
+ )
522
+ moe_align_block_size_stage2[grid](
523
+ tokens_cnts,
524
+ num_experts,
525
+ )
526
+ moe_align_block_size_stage3[(1, )](
527
+ num_tokens_post_pad,
528
+ tokens_cnts,
529
+ cumsum,
530
+ num_experts,
531
+ block_size,
532
+ )
533
+ moe_align_block_size_stage4[grid](
534
+ topk_ids,
535
+ sorted_token_ids,
536
+ expert_ids,
537
+ tokens_cnts,
538
+ cumsum,
539
+ num_experts,
540
+ block_size,
541
+ numel,
542
+ tokens_per_thread,
543
+ )
544
+
545
+
546
+ def moe_align_block_size(
547
+ topk_ids: torch.Tensor, block_size: int,
548
+ num_experts: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
549
+ """
550
+ Aligns the token distribution across experts to be compatible with block
551
+ size for matrix multiplication.
552
+
553
+ Parameters:
554
+ - topk_ids: A tensor of shape [total_tokens, top_k] representing the
555
+ top-k expert indices for each token.
556
+ - block_size: The block size used in block matrix multiplication.
557
+ - num_experts: The total number of experts.
558
+
559
+ Returns:
560
+ - sorted_token_ids: A tensor containing the sorted token indices according
561
+ to their allocated expert.
562
+ - expert_ids: A tensor indicating the assigned expert index for each block.
563
+ - num_tokens_post_padded: The total number of tokens after padding,
564
+ ensuring divisibility by block_size.
565
+
566
+ This function pads the number of tokens that each expert needs to process
567
+ so that it is divisible by block_size.
568
+ Padding ensures that during block matrix multiplication, the dimensions
569
+ align correctly.
570
+
571
+ Example:
572
+ Given topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]],
573
+ block_size = 4, and num_experts = 4:
574
+ - We initially have 12 tokens (after repeating 'top_k' times) and 4 experts,
575
+ with each expert needing to process 3 tokens.
576
+ - As block_size is 4, we pad 1 token for each expert.
577
+ - First, flatten topk_ids to [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3].
578
+ - Then append one padding token (12) for each expert, i.e. [12, 12, 12, 12].
579
+ - After sorting by expert index, we obtain token_ids
580
+ [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12].
581
+ Tokens 12 are non-existent (padding) and are ignored in
582
+ the subsequent matrix multiplication.
583
+ - The padding ensures that the total number of tokens is now divisible
584
+ by block_size for proper block matrix operations.
585
+ """
586
+ max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
587
+ sorted_ids = torch.empty((max_num_tokens_padded, ),
588
+ dtype=torch.int32,
589
+ device=topk_ids.device)
590
+ sorted_ids.fill_(topk_ids.numel())
591
+ max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
592
+ expert_ids = torch.empty((max_num_m_blocks, ),
593
+ dtype=torch.int32,
594
+ device=topk_ids.device)
595
+ num_tokens_post_pad = torch.empty((1),
596
+ dtype=torch.int32,
597
+ device=topk_ids.device)
598
+ if num_experts >= 224:
599
+ if envs.VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON:
600
+ moe_align_block_size_triton(
601
+ topk_ids,
602
+ num_experts,
603
+ block_size,
604
+ sorted_ids,
605
+ expert_ids,
606
+ num_tokens_post_pad,
607
+ )
608
+ else:
609
+ ops.sgl_moe_align_block_size(
610
+ topk_ids,
611
+ num_experts,
612
+ block_size,
613
+ sorted_ids,
614
+ expert_ids,
615
+ num_tokens_post_pad,
616
+ )
617
+ else:
618
+ ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
619
+ expert_ids, num_tokens_post_pad)
620
+ return sorted_ids, expert_ids, num_tokens_post_pad
621
+
622
+
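+ # Illustrative sketch (not part of the original file): a pure-PyTorch
+ # reference of the alignment performed above, assuming 0-indexed expert ids
+ # in [0, num_experts). The helper name is hypothetical; padding slots are
+ # filled with topk_ids.numel() so they fail the num_valid_tokens mask in the
+ # kernels.
+ def _reference_moe_align_block_size(topk_ids: torch.Tensor, block_size: int,
+                                     num_experts: int):
+     flat = topk_ids.flatten().to(torch.int64)
+     numel = flat.numel()
+     # Count tokens per expert and round each count up to a multiple of
+     # block_size, mirroring the padding done by the fused kernels.
+     counts = torch.bincount(flat, minlength=num_experts)
+     padded = ((counts + block_size - 1) // block_size) * block_size
+     sorted_ids = torch.full((int(padded.sum()), ), numel, dtype=torch.int32)
+     offsets = torch.cumsum(padded, dim=0) - padded
+     fill = offsets.clone()
+     for i in range(numel):
+         e = int(flat[i])
+         sorted_ids[fill[e]] = i
+         fill[e] += 1
+     # One expert id per block_size-sized block of sorted token ids.
+     expert_ids = torch.repeat_interleave(
+         torch.arange(num_experts, dtype=torch.int32), padded // block_size)
+     return sorted_ids, expert_ids, int(padded.sum())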
623
+ def invoke_fused_moe_kernel(A: torch.Tensor,
624
+ B: torch.Tensor,
625
+ C: torch.Tensor,
626
+ A_scale: Optional[torch.Tensor],
627
+ B_scale: Optional[torch.Tensor],
628
+ B_zp: Optional[torch.Tensor],
629
+ topk_weights: torch.Tensor,
630
+ topk_ids: torch.Tensor,
631
+ sorted_token_ids: torch.Tensor,
632
+ expert_ids: torch.Tensor,
633
+ num_tokens_post_padded: torch.Tensor,
634
+ mul_routed_weight: bool,
635
+ top_k: int,
636
+ config: Dict[str, Any],
637
+ compute_type: tl.dtype,
638
+ use_fp8_w8a8: bool,
639
+ use_int8_w8a16: bool,
640
+ use_int4_w4a16: bool,
641
+ block_shape: Optional[List[int]] = None) -> None:
642
+ assert topk_weights.stride(1) == 1
643
+ assert sorted_token_ids.stride(0) == 1
644
+
645
+ if use_fp8_w8a8:
646
+ assert B_scale is not None
647
+ if block_shape is None:
648
+ A, A_scale = ops.scaled_fp8_quant(A, A_scale)
649
+ else:
650
+ assert len(block_shape) == 2
651
+ block_n, block_k = block_shape[0], block_shape[1]
652
+ A, A_scale = per_token_group_quant_fp8(A, block_k)
653
+ assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
654
+ assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
655
+ assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
656
+ elif use_int8_w8a16 or use_int4_w4a16:
657
+ assert B_scale is not None
658
+ assert block_shape is None or block_shape[0] == 0
659
+ else:
660
+ assert A_scale is None
661
+ assert B_scale is None
662
+
663
+ EM = sorted_token_ids.shape[0]
664
+ if A.shape[0] < config["BLOCK_SIZE_M"]:
665
+ # Optimize for small batch sizes. We assume the topk expert ids of each
+ # token are unique, so num_valid_experts <= batch_size <= BLOCK_SIZE_M,
+ # and we can skip some invalid blocks.
669
+ EM = min(sorted_token_ids.shape[0],
670
+ A.shape[0] * top_k * config['BLOCK_SIZE_M'])
671
+ grid = lambda META: (triton.cdiv(EM, META['BLOCK_SIZE_M']) * triton.cdiv(
672
+ B.shape[1], META['BLOCK_SIZE_N']), )
673
+
674
+ if (use_int8_w8a16 or use_int4_w4a16) and \
675
+ block_shape is not None and block_shape[1] > 0:
676
+ assert B_scale is not None and B_scale.ndim == 3
677
+ assert B_zp is None or B_zp.ndim == 3
678
+
679
+ fused_moe_kernel_gptq_awq[grid](
680
+ A,
681
+ B,
682
+ C,
683
+ B_scale,
684
+ B_zp,
685
+ topk_weights,
686
+ sorted_token_ids,
687
+ expert_ids,
688
+ num_tokens_post_padded,
689
+ B.shape[1],
690
+ A.shape[1],
691
+ EM,
692
+ topk_ids.numel(),
693
+ A.stride(0),
694
+ A.stride(1),
695
+ B.stride(0),
696
+ B.stride(2),
697
+ B.stride(1),
698
+ C.stride(1),
699
+ C.stride(2),
700
+ B_scale.stride(0),
701
+ B_scale.stride(2),
702
+ B_scale.stride(1),
703
+ B_zp.stride(0) if B_zp is not None else 0,
704
+ B_zp.stride(2) if B_zp is not None else 0,
705
+ B_zp.stride(1) if B_zp is not None else 0,
706
+ block_k_diviable=A.shape[1] % config["BLOCK_SIZE_K"] == 0,
707
+ group_size=block_shape[1],
708
+ MUL_ROUTED_WEIGHT=mul_routed_weight,
709
+ top_k=top_k,
710
+ compute_type=compute_type,
711
+ has_zp=B_zp is not None,
712
+ use_int4_w4a16=use_int4_w4a16,
713
+ use_int8_w8a16=use_int8_w8a16,
714
+ **config,
715
+ )
716
+
717
+ else:
718
+ fused_moe_kernel[grid](
719
+ A,
720
+ B,
721
+ C,
722
+ A_scale,
723
+ B_scale,
724
+ topk_weights,
725
+ sorted_token_ids,
726
+ expert_ids,
727
+ num_tokens_post_padded,
728
+ B.shape[1],
729
+ A.shape[1],
730
+ EM,
731
+ topk_ids.numel(),
732
+ A.stride(0),
733
+ A.stride(1),
734
+ B.stride(0),
735
+ B.stride(2),
736
+ B.stride(1),
737
+ C.stride(1),
738
+ C.stride(2),
739
+ A_scale.stride(0)
740
+ if A_scale is not None and A_scale.ndim == 2 else 0,
741
+ A_scale.stride(1)
742
+ if A_scale is not None and A_scale.ndim == 2 else 0,
743
+ B_scale.stride(0)
744
+ if B_scale is not None and B_scale.ndim >= 2 else 0,
745
+ B_scale.stride(2)
746
+ if B_scale is not None and B_scale.ndim == 3 else 0,
747
+ B_scale.stride(1)
748
+ if B_scale is not None and B_scale.ndim >= 2 else 0,
749
+ 0 if block_shape is None else block_shape[0],
750
+ 0 if block_shape is None else block_shape[1],
751
+ MUL_ROUTED_WEIGHT=mul_routed_weight,
752
+ top_k=top_k,
753
+ compute_type=compute_type,
754
+ use_fp8_w8a8=use_fp8_w8a8,
755
+ use_int8_w8a16=use_int8_w8a16,
756
+ **config,
757
+ )
758
+
759
+
760
+ # Adapted from: https://github.com/sgl-project/sglang/pull/2628
761
+ def get_config_file_name(E: int,
762
+ N: int,
763
+ dtype: Optional[str],
764
+ block_shape: Optional[List[int]] = None) -> str:
765
+ device_name = current_platform.get_device_name().replace(" ", "_")
766
+ dtype_selector = "" if not dtype else f",dtype={dtype}"
767
+ block_shape_selector = ("" if not block_shape or not all(block_shape) else
768
+ f",block_shape={block_shape}").replace(" ", "")
769
+ return f"E={E},N={N},device_name={device_name}{dtype_selector}{block_shape_selector}.json" # noqa: E501
770
+
771
+
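+ # Example (assumed device name): with E=8, N=7168, dtype="fp8_w8a8" and
+ # block_shape=[128, 128] on a device reporting "NVIDIA H100 80GB HBM3",
+ # this returns
+ # "E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json",
+ # which is the naming scheme used by the JSON files in the configs/ directory.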
772
+ # Adapted from: https://github.com/sgl-project/sglang/pull/2628
773
+ @functools.lru_cache
774
+ def get_moe_configs(
775
+ E: int,
776
+ N: int,
777
+ dtype: Optional[str],
778
+ block_n: Optional[int] = None,
779
+ block_k: Optional[int] = None,
780
+ ) -> Optional[Dict[int, Any]]:
781
+ """
782
+ Return optimized configurations for the fused MoE kernel.
783
+
784
+ The return value will be a dictionary that maps an irregular grid of
785
+ batch sizes to configurations of the fused_moe kernel. To evaluate the
786
+ kernel on a given batch size bs, the closest batch size in the grid should
787
+ be picked and the associated configuration chosen to invoke the kernel.
788
+ """
789
+
790
+ # First look up if an optimized configuration is available in the configs
791
+ # directory
792
+ block_shape = [block_n, block_k] if block_n and block_k else None
793
+ json_file_name = get_config_file_name(E, N, dtype, block_shape)
794
+
795
+ config_file_path = os.path.join(
796
+ os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name)
797
+ if os.path.exists(config_file_path):
798
+ with open(config_file_path) as f:
799
+ logger.info("Using configuration from %s for MoE layer.",
800
+ config_file_path)
801
+ # If a configuration has been found, return it
802
+ return {int(key): val for key, val in json.load(f).items()}
803
+
804
+ # If no optimized configuration is available, we will use the default
805
+ # configuration
806
+ logger.warning(
807
+ ("Using default MoE config. Performance might be sub-optimal! "
808
+ "Config file not found at %s"), config_file_path)
809
+ return None
810
+
811
+
812
+ def get_default_config(
813
+ M: int,
814
+ E: int,
815
+ N: int,
816
+ K: int,
817
+ topk: int,
818
+ dtype: Optional[str],
819
+ is_marlin: bool,
820
+ block_shape: Optional[List[int]] = None,
821
+ ) -> Dict[str, int]:
822
+ if dtype == "fp8_w8a8" and block_shape is not None:
823
+ # Block-wise quant: BLOCK_SIZE_N must be divisible by block_shape[0]
824
+ # BLOCK_SIZE_K must be divisible by block_shape[1]
825
+ config = {
826
+ "BLOCK_SIZE_M": 64,
827
+ "BLOCK_SIZE_N": block_shape[0],
828
+ "BLOCK_SIZE_K": block_shape[1],
829
+ "GROUP_SIZE_M": 32,
830
+ "num_warps": 4,
831
+ "num_stages": 3,
832
+ }
833
+ else:
834
+ config = {
835
+ "BLOCK_SIZE_M": 64,
836
+ "BLOCK_SIZE_N": 64,
837
+ "BLOCK_SIZE_K": 32,
838
+ "GROUP_SIZE_M": 8,
839
+ }
840
+ # A heuristic: fused marlin works faster with this config for small M
841
+ if M <= E or (is_marlin and M <= 32):
842
+ config = {
843
+ "BLOCK_SIZE_M": 16,
844
+ "BLOCK_SIZE_N": 32,
845
+ "BLOCK_SIZE_K": 64,
846
+ "GROUP_SIZE_M": 1,
847
+ }
848
+ return config
849
+
850
+
851
+ def try_get_optimal_moe_config(
852
+ w1_shape: Tuple[int, ...],
853
+ w2_shape: Tuple[int, ...],
854
+ top_k: int,
855
+ dtype: Optional[str],
856
+ M: int,
857
+ is_marlin: bool = False,
858
+ block_shape: Optional[List[int]] = None,
859
+ ):
860
+ from vllm.model_executor.layers.fused_moe import get_config
861
+ override_config = get_config()
862
+ if override_config:
863
+ config = override_config
864
+ else:
865
+ # First try to load optimal config from the file
866
+ E, _, N = w2_shape
867
+ block_n = block_shape[0] if block_shape else 0
868
+ block_k = block_shape[1] if block_shape else 0
869
+ configs = get_moe_configs(E, N, dtype, block_n, block_k)
870
+
871
+ if configs:
872
+ # If an optimal configuration map has been found, look up the
873
+ # optimal config
874
+ config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
875
+ else:
876
+ # Else use the default config
877
+ config = get_default_config(M, E, N, w1_shape[2], top_k, dtype,
878
+ is_marlin, block_shape)
879
+ return config
880
+
881
+
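+ # Illustrative lookup (assumed config map): if the loaded JSON maps batch
+ # sizes {1, 16, 64, 256} to kernel configs, a call with M=48 selects the
+ # entry for 64, since min(configs.keys(), key=lambda x: abs(x - M)) picks
+ # the closest available batch size.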
882
+ def fused_topk(
883
+ hidden_states: torch.Tensor,
884
+ gating_output: torch.Tensor,
885
+ topk: int,
886
+ renormalize: bool,
887
+ ):
888
+ assert hidden_states.shape[0] == gating_output.shape[0], (
889
+ "Number of tokens mismatch")
890
+
891
+ M, _ = hidden_states.shape
892
+
893
+ topk_weights = torch.empty(M,
894
+ topk,
895
+ dtype=torch.float32,
896
+ device=hidden_states.device)
897
+ topk_ids = torch.empty(M,
898
+ topk,
899
+ dtype=torch.int32,
900
+ device=hidden_states.device)
901
+ token_expert_indicies = torch.empty(M,
902
+ topk,
903
+ dtype=torch.int32,
904
+ device=hidden_states.device)
905
+
906
+ ops.topk_softmax(
907
+ topk_weights,
908
+ topk_ids,
909
+ token_expert_indicies,
910
+ gating_output.float(), # TODO(woosuk): Optimize this.
911
+ )
912
+ del token_expert_indicies # Not used. Will be used in the future.
913
+
914
+ if renormalize:
915
+ topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
916
+
917
+ return topk_weights, topk_ids
918
+
919
+
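+ # Illustrative sketch (not part of the original file): a plain-PyTorch
+ # reference of the routing that fused_topk computes via ops.topk_softmax.
+ # The helper name is hypothetical and only documents the expected semantics.
+ def _reference_fused_topk(hidden_states: torch.Tensor,
+                           gating_output: torch.Tensor, topk: int,
+                           renormalize: bool):
+     assert hidden_states.shape[0] == gating_output.shape[0]
+     # Softmax over experts, then keep the top-k scores per token.
+     scores = torch.softmax(gating_output.float(), dim=-1)
+     topk_weights, topk_ids = torch.topk(scores, k=topk, dim=-1)
+     if renormalize:
+         topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+     return topk_weights, topk_ids.to(torch.int32)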
920
+ # This is used by the Deepseek-V2 and Deepseek-V3 model
921
+ @torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
922
+ def grouped_topk(hidden_states: torch.Tensor,
923
+ gating_output: torch.Tensor,
924
+ topk: int,
925
+ renormalize: bool,
926
+ num_expert_group: int = 0,
927
+ topk_group: int = 0,
928
+ scoring_func: str = "softmax",
929
+ e_score_correction_bias: Optional[torch.Tensor] = None):
930
+
931
+ assert hidden_states.shape[0] == gating_output.shape[0], (
932
+ "Number of tokens mismatch")
933
+
934
+ if scoring_func == "softmax":
935
+ scores = torch.softmax(gating_output, dim=-1)
936
+ elif scoring_func == "sigmoid":
937
+ scores = gating_output.sigmoid()
938
+ else:
939
+ raise ValueError(f"Unsupported scoring function: {scoring_func}")
940
+
941
+ if e_score_correction_bias is not None:
942
+ # Store original scores before applying correction bias. We use biased
943
+ # scores for expert selection but original scores for routing weights
944
+ original_scores = scores
945
+ scores = scores + e_score_correction_bias.unsqueeze(0)
946
+
947
+ num_token = scores.shape[0]
948
+ group_scores = scores.view(num_token, num_expert_group,
949
+ -1).max(dim=-1).values # [n, n_group]
950
+ group_idx = torch.topk(group_scores, k=topk_group, dim=-1,
951
+ sorted=False)[1] # [n, top_k_group]
952
+ group_mask = torch.zeros_like(group_scores) # [n, n_group]
953
+ group_mask.scatter_(1, group_idx, 1) # [n, n_group]
954
+ score_mask = group_mask.unsqueeze(-1).expand(
955
+ num_token, num_expert_group,
956
+ scores.shape[-1] // num_expert_group).reshape(num_token, -1) # [n, e]
957
+ tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e]
958
+
959
+ if e_score_correction_bias is not None:
960
+ topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1]
961
+ # Use original unbiased scores for the routing weights
962
+ topk_weights = original_scores.gather(1, topk_ids)
963
+ else:
964
+ topk_weights, topk_ids = torch.topk(tmp_scores,
965
+ k=topk,
966
+ dim=-1,
967
+ sorted=False)
968
+
969
+ if renormalize:
970
+ topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
971
+
972
+ return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
973
+
974
+
975
+ def get_config_dtype_str(dtype: torch.dtype,
976
+ use_int4_w4a16: Optional[bool] = False,
977
+ use_int8_w8a16: Optional[bool] = False,
978
+ use_fp8_w8a8: Optional[bool] = False):
979
+ if use_fp8_w8a8:
980
+ return "fp8_w8a8"
981
+ elif use_int8_w8a16:
982
+ return "int8_w8a16"
983
+ elif use_int4_w4a16:
984
+ return "int4_w8a16"
985
+ elif dtype == torch.float:
986
+ # Avoid cases where the kernel fails when a float32 MoE
+ # uses fp16/bfloat16 configs.
988
+ return "float32"
989
+ return None
990
+
991
+
992
+ def inplace_fused_experts(hidden_states: torch.Tensor,
993
+ w1: torch.Tensor,
994
+ w2: torch.Tensor,
995
+ topk_weights: torch.Tensor,
996
+ topk_ids: torch.Tensor,
997
+ use_fp8_w8a8: bool = False,
998
+ use_int8_w8a16: bool = False,
999
+ use_int4_w4a16: bool = False,
1000
+ w1_scale: Optional[torch.Tensor] = None,
1001
+ w2_scale: Optional[torch.Tensor] = None,
1002
+ w1_zp: Optional[torch.Tensor] = None,
1003
+ w2_zp: Optional[torch.Tensor] = None,
1004
+ a1_scale: Optional[torch.Tensor] = None,
1005
+ a2_scale: Optional[torch.Tensor] = None,
1006
+ block_shape: Optional[List[int]] = None) -> None:
1007
+ fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True,
1008
+ use_fp8_w8a8, use_int8_w8a16, use_int4_w4a16, w1_scale,
1009
+ w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, block_shape)
1010
+
1011
+
1012
+ def inplace_fused_experts_fake(
1013
+ hidden_states: torch.Tensor,
1014
+ w1: torch.Tensor,
1015
+ w2: torch.Tensor,
1016
+ topk_weights: torch.Tensor,
1017
+ topk_ids: torch.Tensor,
1018
+ use_fp8_w8a8: bool = False,
1019
+ use_int8_w8a16: bool = False,
1020
+ use_int4_w4a16: bool = False,
1021
+ w1_scale: Optional[torch.Tensor] = None,
1022
+ w2_scale: Optional[torch.Tensor] = None,
1023
+ w1_zp: Optional[torch.Tensor] = None,
1024
+ w2_zp: Optional[torch.Tensor] = None,
1025
+ a1_scale: Optional[torch.Tensor] = None,
1026
+ a2_scale: Optional[torch.Tensor] = None,
1027
+ block_shape: Optional[List[int]] = None) -> None:
1028
+ pass
1029
+
1030
+
1031
+ direct_register_custom_op(
1032
+ op_name="inplace_fused_experts",
1033
+ op_func=inplace_fused_experts,
1034
+ mutates_args=["hidden_states"],
1035
+ fake_impl=inplace_fused_experts_fake,
1036
+ )
1037
+
1038
+
1039
+ def outplace_fused_experts(
1040
+ hidden_states: torch.Tensor,
1041
+ w1: torch.Tensor,
1042
+ w2: torch.Tensor,
1043
+ topk_weights: torch.Tensor,
1044
+ topk_ids: torch.Tensor,
1045
+ use_fp8_w8a8: bool = False,
1046
+ use_int8_w8a16: bool = False,
1047
+ use_int4_w4a16: bool = False,
1048
+ w1_scale: Optional[torch.Tensor] = None,
1049
+ w2_scale: Optional[torch.Tensor] = None,
1050
+ w1_zp: Optional[torch.Tensor] = None,
1051
+ w2_zp: Optional[torch.Tensor] = None,
1052
+ a1_scale: Optional[torch.Tensor] = None,
1053
+ a2_scale: Optional[torch.Tensor] = None,
1054
+ block_shape: Optional[List[int]] = None) -> torch.Tensor:
1055
+ return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids,
1056
+ False, use_fp8_w8a8, use_int8_w8a16,
1057
+ use_int4_w4a16, w1_scale, w2_scale, w1_zp, w2_zp,
1058
+ a1_scale, a2_scale, block_shape)
1059
+
1060
+
1061
+ def outplace_fused_experts_fake(
1062
+ hidden_states: torch.Tensor,
1063
+ w1: torch.Tensor,
1064
+ w2: torch.Tensor,
1065
+ topk_weights: torch.Tensor,
1066
+ topk_ids: torch.Tensor,
1067
+ use_fp8_w8a8: bool = False,
1068
+ use_int8_w8a16: bool = False,
1069
+ use_int4_w4a16: bool = False,
1070
+ w1_scale: Optional[torch.Tensor] = None,
1071
+ w2_scale: Optional[torch.Tensor] = None,
1072
+ w1_zp: Optional[torch.Tensor] = None,
1073
+ w2_zp: Optional[torch.Tensor] = None,
1074
+ a1_scale: Optional[torch.Tensor] = None,
1075
+ a2_scale: Optional[torch.Tensor] = None,
1076
+ block_shape: Optional[List[int]] = None) -> torch.Tensor:
1077
+ return torch.empty_like(hidden_states)
1078
+
1079
+
1080
+ direct_register_custom_op(
1081
+ op_name="outplace_fused_experts",
1082
+ op_func=outplace_fused_experts,
1083
+ mutates_args=[],
1084
+ fake_impl=outplace_fused_experts_fake,
1085
+ )
1086
+
1087
+
1088
+ def fused_experts(hidden_states: torch.Tensor,
1089
+ w1: torch.Tensor,
1090
+ w2: torch.Tensor,
1091
+ topk_weights: torch.Tensor,
1092
+ topk_ids: torch.Tensor,
1093
+ inplace: bool = False,
1094
+ use_fp8_w8a8: bool = False,
1095
+ use_int8_w8a16: bool = False,
1096
+ use_int4_w4a16: bool = False,
1097
+ w1_scale: Optional[torch.Tensor] = None,
1098
+ w2_scale: Optional[torch.Tensor] = None,
1099
+ w1_zp: Optional[torch.Tensor] = None,
1100
+ w2_zp: Optional[torch.Tensor] = None,
1101
+ a1_scale: Optional[torch.Tensor] = None,
1102
+ a2_scale: Optional[torch.Tensor] = None,
1103
+ block_shape: Optional[List[int]] = None):
1104
+ if inplace:
1105
+ torch.ops.vllm.inplace_fused_experts(hidden_states, w1, w2,
1106
+ topk_weights, topk_ids,
1107
+ use_fp8_w8a8, use_int8_w8a16,
1108
+ use_int4_w4a16, w1_scale,
1109
+ w2_scale, w1_zp, w2_zp, a1_scale,
1110
+ a2_scale, block_shape)
1111
+ return hidden_states
1112
+ else:
1113
+ return torch.ops.vllm.outplace_fused_experts(
1114
+ hidden_states, w1, w2, topk_weights, topk_ids, use_fp8_w8a8,
1115
+ use_int8_w8a16, use_int4_w4a16, w1_scale, w2_scale, w1_zp, w2_zp,
1116
+ a1_scale, a2_scale, block_shape)
1117
+
1118
+
1119
+ def fused_experts_impl(hidden_states: torch.Tensor,
1120
+ w1: torch.Tensor,
1121
+ w2: torch.Tensor,
1122
+ topk_weights: torch.Tensor,
1123
+ topk_ids: torch.Tensor,
1124
+ inplace: bool = False,
1125
+ use_fp8_w8a8: bool = False,
1126
+ use_int8_w8a16: bool = False,
1127
+ use_int4_w4a16: bool = False,
1128
+ w1_scale: Optional[torch.Tensor] = None,
1129
+ w2_scale: Optional[torch.Tensor] = None,
1130
+ w1_zp: Optional[torch.Tensor] = None,
1131
+ w2_zp: Optional[torch.Tensor] = None,
1132
+ a1_scale: Optional[torch.Tensor] = None,
1133
+ a2_scale: Optional[torch.Tensor] = None,
1134
+ block_shape: Optional[List[int]] = None):
1135
+ # Check constraints.
1136
+ if use_int4_w4a16:
1137
+ assert hidden_states.shape[1] // 2 == w1.shape[
1138
+ 2], "Hidden size mismatch"
1139
+ else:
1140
+ assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
1141
+
1142
+ assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
1143
+ assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
1144
+ assert w1.is_contiguous(), "Expert weights1 must be contiguous"
1145
+ assert w2.is_contiguous(), "Expert weights2 must be contiguous"
1146
+ assert hidden_states.dtype in [
1147
+ torch.float32, torch.float16, torch.bfloat16
1148
+ ]
1149
+
1150
+ num_tokens, _ = hidden_states.shape
1151
+ E, N, _ = w1.shape
1152
+ # We execute the fused_moe kernel in chunks to circumvent this issue:
1153
+ # https://github.com/vllm-project/vllm/issues/5938
1154
+ CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
1155
+ M = min(num_tokens, CHUNK_SIZE)
1156
+ config_dtype = get_config_dtype_str(use_fp8_w8a8=use_fp8_w8a8,
1157
+ use_int8_w8a16=use_int8_w8a16,
1158
+ use_int4_w4a16=use_int4_w4a16,
1159
+ dtype=hidden_states.dtype)
1160
+
1161
+ get_config_func = functools.partial(
1162
+ try_get_optimal_moe_config,
1163
+ w1.shape,
1164
+ w2.shape,
1165
+ topk_ids.shape[1],
1166
+ config_dtype,
1167
+ block_shape=block_shape,
1168
+ )
1169
+
1170
+ config = get_config_func(M)
1171
+
1172
+ intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N),
1173
+ device=hidden_states.device,
1174
+ dtype=hidden_states.dtype)
1175
+ intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N // 2),
1176
+ device=hidden_states.device,
1177
+ dtype=hidden_states.dtype)
1178
+ intermediate_cache3 = torch.empty((M, topk_ids.shape[1], w2.shape[1]),
1179
+ device=hidden_states.device,
1180
+ dtype=hidden_states.dtype)
1181
+
1182
+ if hidden_states.dtype == torch.bfloat16:
1183
+ compute_type = tl.bfloat16
1184
+ elif hidden_states.dtype == torch.float16:
1185
+ compute_type = tl.float16
1186
+ elif hidden_states.dtype == torch.float32:
1187
+ compute_type = tl.float32
1188
+ else:
1189
+ raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}")
1190
+
1191
+ if inplace:
1192
+ out_hidden_states = hidden_states
1193
+ else:
1194
+ out_hidden_states = torch.empty_like(hidden_states)
1195
+
1196
+ for chunk in range((num_tokens // CHUNK_SIZE) + 1):
1197
+ begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE,
1198
+ min((chunk + 1) * CHUNK_SIZE,
1199
+ num_tokens))
1200
+ curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx]
1201
+ tokens_in_chunk, _ = curr_hidden_states.shape
1202
+
1203
+ if tokens_in_chunk == 0:
1204
+ break
1205
+
1206
+ if tokens_in_chunk < CHUNK_SIZE and chunk > 0:
1207
+ # Adjust the intermediate cache size and config for the last
1208
+ # chunk. Note that in most cases we only have one chunk
1209
+ # so the cache size and config are already set correctly and
1210
+ # do not need to be adjusted.
1211
+ intermediate_cache1 = intermediate_cache1[:tokens_in_chunk]
1212
+ intermediate_cache2 = intermediate_cache2[:tokens_in_chunk]
1213
+ intermediate_cache3 = intermediate_cache3[:tokens_in_chunk]
1214
+ config = get_config_func(tokens_in_chunk)
1215
+
1216
+ curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
1217
+ curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
1218
+
1219
+ sorted_token_ids, expert_ids, num_tokens_post_padded = (
1220
+ moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'], E))
1221
+
1222
+ invoke_fused_moe_kernel(curr_hidden_states,
1223
+ w1,
1224
+ intermediate_cache1,
1225
+ a1_scale,
1226
+ w1_scale,
1227
+ w1_zp,
1228
+ curr_topk_weights,
1229
+ curr_topk_ids,
1230
+ sorted_token_ids,
1231
+ expert_ids,
1232
+ num_tokens_post_padded,
1233
+ False,
1234
+ topk_ids.shape[1],
1235
+ config,
1236
+ compute_type=compute_type,
1237
+ use_fp8_w8a8=use_fp8_w8a8,
1238
+ use_int8_w8a16=use_int8_w8a16,
1239
+ use_int4_w4a16=use_int4_w4a16,
1240
+ block_shape=block_shape)
1241
+
1242
+ torch.ops._C.silu_and_mul(intermediate_cache2,
1243
+ intermediate_cache1.view(-1, N))
1244
+
1245
+ invoke_fused_moe_kernel(intermediate_cache2,
1246
+ w2,
1247
+ intermediate_cache3,
1248
+ a2_scale,
1249
+ w2_scale,
1250
+ w2_zp,
1251
+ curr_topk_weights,
1252
+ curr_topk_ids,
1253
+ sorted_token_ids,
1254
+ expert_ids,
1255
+ num_tokens_post_padded,
1256
+ True,
1257
+ 1,
1258
+ config,
1259
+ compute_type=compute_type,
1260
+ use_fp8_w8a8=use_fp8_w8a8,
1261
+ use_int8_w8a16=use_int8_w8a16,
1262
+ use_int4_w4a16=use_int4_w4a16,
1263
+ block_shape=block_shape)
1264
+
1265
+ ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.shape),
1266
+ out_hidden_states[begin_chunk_idx:end_chunk_idx])
1267
+ return out_hidden_states
1268
+
1269
+
1270
+ def fused_moe(
1271
+ hidden_states: torch.Tensor,
1272
+ w1: torch.Tensor,
1273
+ w2: torch.Tensor,
1274
+ gating_output: torch.Tensor,
1275
+ topk: int,
1276
+ renormalize: bool,
1277
+ inplace: bool = False,
1278
+ use_grouped_topk: bool = False,
1279
+ num_expert_group: Optional[int] = None,
1280
+ topk_group: Optional[int] = None,
1281
+ custom_routing_function: Optional[Callable] = None,
1282
+ use_fp8_w8a8: bool = False,
1283
+ use_int8_w8a16: bool = False,
1284
+ use_int4_w4a16: bool = False,
1285
+ w1_scale: Optional[torch.Tensor] = None,
1286
+ w2_scale: Optional[torch.Tensor] = None,
1287
+ w1_zp: Optional[torch.Tensor] = None,
1288
+ w2_zp: Optional[torch.Tensor] = None,
1289
+ a1_scale: Optional[torch.Tensor] = None,
1290
+ a2_scale: Optional[torch.Tensor] = None,
1291
+ block_shape: Optional[List[int]] = None,
1292
+ ) -> torch.Tensor:
1293
+ """
1294
+ This function computes a Mixture of Experts (MoE) layer using two sets of
1295
+ weights, w1 and w2, and top-k gating mechanism.
1296
+
1297
+ Parameters:
1298
+ - hidden_states (torch.Tensor): The input tensor to the MoE layer.
1299
+ - w1 (torch.Tensor): The first set of expert weights.
1300
+ - w2 (torch.Tensor): The second set of expert weights.
1301
+ - gating_output (torch.Tensor): The output of the gating operation
1302
+ (before softmax).
1303
+ - topk (int): The number of top-k experts to select.
1304
+ - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
1305
+ - inplace (bool): If True, perform the operation in-place.
1306
+ Defaults to False.
1307
+ - num_expert_group: Optional[int]: additional parameter for grouped_topk
1308
+ - topk_group: Optional[int]: additional parameter for grouped_topk
1309
+ - use_grouped_topk: If True, use grouped_topk instead of fused_topk
1310
+ note: Deepseekv2 model uses grouped_topk
1311
+ - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner
1312
+ products for w1 and w2. Defaults to False.
1313
+ - use_int8_w8a16 (bool): If True, use matmul of int8 weight and bf16/fp16
1314
+ activation to compute the inner products for w1 and w2.
1315
+ Defaults to False.
1316
+ - use_int4_w4a16 (bool): If True, use matmul of int4 weight and bf16/fp16
1317
+ activation to compute the inner products for w1 and w2.
1318
+ Defaults to False.
1319
+ - w1_scale (Optional[torch.Tensor]): Optional scale to be used for
1320
+ w1.
1321
+ - w2_scale (Optional[torch.Tensor]): Optional scale to be used for
1322
+ w2.
1323
+ - a1_scale (Optional[torch.Tensor]): Optional scale to be used for
1324
+ a1.
1325
+ - a2_scale (Optional[torch.Tensor]): Optional scale to be used for
1326
+ a2.
1327
+ - block_shape: (Optional[List[int]]): Optional block size for block-wise
1328
+ quantization.
1329
+
1330
+ Returns:
1331
+ - torch.Tensor: The output tensor after applying the MoE layer.
1332
+ """
1333
+ # Check constraints.
1334
+ assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
1335
+
1336
+ if use_grouped_topk:
1337
+ assert num_expert_group is not None and topk_group is not None
1338
+ topk_weights, topk_ids = grouped_topk(hidden_states, gating_output,
1339
+ topk, renormalize,
1340
+ num_expert_group, topk_group)
1341
+ elif custom_routing_function is None:
1342
+ topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
1343
+ renormalize)
1344
+ else:
1345
+ topk_weights, topk_ids = custom_routing_function(
1346
+ hidden_states, gating_output, topk, renormalize)
1347
+
1348
+ return fused_experts(hidden_states,
1349
+ w1,
1350
+ w2,
1351
+ topk_weights,
1352
+ topk_ids,
1353
+ inplace=inplace,
1354
+ use_fp8_w8a8=use_fp8_w8a8,
1355
+ use_int8_w8a16=use_int8_w8a16,
1356
+ use_int4_w4a16=use_int4_w4a16,
1357
+ w1_scale=w1_scale,
1358
+ w2_scale=w2_scale,
1359
+ w1_zp=w1_zp,
1360
+ w2_zp=w2_zp,
1361
+ a1_scale=a1_scale,
1362
+ a2_scale=a2_scale,
1363
+ block_shape=block_shape)
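+
+
+ # Illustrative shapes for calling fused_moe (assumed sizes, unquantized
+ # bf16 weights; a sketch rather than a test):
+ #   hidden_states: (num_tokens, H)
+ #   w1:            (E, 2 * I, H)   # fused gate/up projection
+ #   w2:            (E, H, I)       # down projection
+ #   gating_output: (num_tokens, E)
+ #   out = fused_moe(hidden_states, w1, w2, gating_output, topk=2,
+ #                   renormalize=True)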
.venv/lib/python3.11/site-packages/vllm/model_executor/layers/fused_moe/layer.py ADDED
@@ -0,0 +1,647 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ from abc import abstractmethod
4
+ from enum import Enum
5
+ from typing import Callable, List, Optional, Tuple
6
+
7
+ import torch
8
+
9
+ from vllm.distributed import (get_tensor_model_parallel_rank,
10
+ get_tensor_model_parallel_world_size,
11
+ tensor_model_parallel_all_reduce)
12
+ from vllm.logger import init_logger
13
+ from vllm.model_executor.custom_op import CustomOp
14
+ from vllm.model_executor.layers.quantization.base_config import (
15
+ QuantizationConfig, QuantizeMethodBase)
16
+ from vllm.model_executor.utils import set_weight_attrs
17
+ from vllm.platforms import current_platform
18
+ from vllm.platforms.interface import CpuArchEnum
19
+
20
+ if current_platform.is_cuda_alike():
21
+ from .fused_moe import fused_experts
22
+ else:
23
+ fused_experts = None # type: ignore
24
+ if current_platform.is_tpu():
25
+ # the iterative moe implementation is used until the moe_pallas is fixed
26
+ from .moe_torch_iterative import fused_moe as fused_moe_pallas
27
+ else:
28
+ fused_moe_pallas = None # type: ignore
29
+ logger = init_logger(__name__)
30
+
31
+
32
+ class FusedMoeWeightScaleSupported(Enum):
33
+ TENSOR = "tensor"
34
+ CHANNEL = "channel"
35
+ GROUP = "group"
36
+ BLOCK = "block"
37
+
38
+
39
+ class FusedMoEMethodBase(QuantizeMethodBase):
40
+
41
+ @abstractmethod
42
+ def create_weights(self, layer: torch.nn.Module, num_experts: int,
43
+ hidden_size: int, intermediate_size_per_partition: int,
44
+ params_dtype: torch.dtype, **extra_weight_attrs):
45
+ raise NotImplementedError
46
+
47
+ @abstractmethod
48
+ def apply(
49
+ self,
50
+ layer: torch.nn.Module,
51
+ x: torch.Tensor,
52
+ router_logits: torch.Tensor,
53
+ top_k: int,
54
+ renormalize: bool,
55
+ use_grouped_topk: bool = False,
56
+ topk_group: Optional[int] = None,
57
+ num_expert_group: Optional[int] = None,
58
+ custom_routing_function: Optional[Callable] = None,
59
+ scoring_func: str = "softmax",
60
+ e_score_correction_bias: Optional[torch.Tensor] = None
61
+ ) -> torch.Tensor:
62
+ raise NotImplementedError
63
+
64
+
65
+ @CustomOp.register("unquantized_fused_moe")
66
+ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
67
+ """MoE method without quantization."""
68
+
69
+ def create_weights(self, layer: torch.nn.Module, num_experts: int,
70
+ hidden_size: int, intermediate_size_per_partition: int,
71
+ params_dtype: torch.dtype, **extra_weight_attrs):
72
+ # Fused gate_up_proj (column parallel)
73
+ w13_weight = torch.nn.Parameter(torch.empty(
74
+ num_experts,
75
+ 2 * intermediate_size_per_partition,
76
+ hidden_size,
77
+ dtype=params_dtype),
78
+ requires_grad=False)
79
+ layer.register_parameter("w13_weight", w13_weight)
80
+ set_weight_attrs(w13_weight, extra_weight_attrs)
81
+
82
+ # down_proj (row parallel)
83
+ w2_weight = torch.nn.Parameter(torch.empty(
84
+ num_experts,
85
+ hidden_size,
86
+ intermediate_size_per_partition,
87
+ dtype=params_dtype),
88
+ requires_grad=False)
89
+ layer.register_parameter("w2_weight", w2_weight)
90
+ set_weight_attrs(w2_weight, extra_weight_attrs)
91
+
92
+ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
93
+ super().process_weights_after_loading(layer)
94
+
95
+ if current_platform.is_cpu():
96
+ if current_platform.get_cpu_architecture() == CpuArchEnum.X86:
97
+ import intel_extension_for_pytorch as ipex
98
+ layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE(
99
+ layer.w13_weight,
100
+ layer.w2_weight,
101
+ use_prepack=True,
102
+ )
103
+ else:
104
+ raise NotImplementedError("CPU MOE only supports x86 arch.")
105
+
106
+ def apply(
107
+ self,
108
+ layer: torch.nn.Module,
109
+ x: torch.Tensor,
110
+ router_logits: torch.Tensor,
111
+ top_k: int,
112
+ renormalize: bool,
113
+ use_grouped_topk: bool = False,
114
+ topk_group: Optional[int] = None,
115
+ num_expert_group: Optional[int] = None,
116
+ custom_routing_function: Optional[Callable] = None,
117
+ scoring_func: str = "softmax",
118
+ e_score_correction_bias: Optional[torch.Tensor] = None
119
+ ) -> torch.Tensor:
120
+ return self.forward(x=x,
121
+ layer=layer,
122
+ router_logits=router_logits,
123
+ top_k=top_k,
124
+ renormalize=renormalize,
125
+ use_grouped_topk=use_grouped_topk,
126
+ topk_group=topk_group,
127
+ num_expert_group=num_expert_group,
128
+ custom_routing_function=custom_routing_function,
129
+ scoring_func=scoring_func,
130
+ e_score_correction_bias=e_score_correction_bias)
131
+
132
+ def forward_cuda(
133
+ self,
134
+ layer: torch.nn.Module,
135
+ x: torch.Tensor,
136
+ use_grouped_topk: bool,
137
+ top_k: int,
138
+ router_logits: torch.Tensor,
139
+ renormalize: bool,
140
+ topk_group: Optional[int] = None,
141
+ num_expert_group: Optional[int] = None,
142
+ custom_routing_function: Optional[Callable] = None,
143
+ scoring_func: str = "softmax",
144
+ e_score_correction_bias: Optional[torch.Tensor] = None
145
+ ) -> torch.Tensor:
146
+ topk_weights, topk_ids = FusedMoE.select_experts(
147
+ hidden_states=x,
148
+ router_logits=router_logits,
149
+ use_grouped_topk=use_grouped_topk,
150
+ top_k=top_k,
151
+ renormalize=renormalize,
152
+ topk_group=topk_group,
153
+ num_expert_group=num_expert_group,
154
+ custom_routing_function=custom_routing_function,
155
+ scoring_func=scoring_func,
156
+ e_score_correction_bias=e_score_correction_bias)
157
+
158
+ return fused_experts(hidden_states=x,
159
+ w1=layer.w13_weight,
160
+ w2=layer.w2_weight,
161
+ topk_weights=topk_weights,
162
+ topk_ids=topk_ids,
163
+ inplace=True)
164
+
165
+ def forward_cpu(
166
+ self,
167
+ layer: torch.nn.Module,
168
+ x: torch.Tensor,
169
+ use_grouped_topk: bool,
170
+ top_k: int,
171
+ router_logits: torch.Tensor,
172
+ renormalize: bool,
173
+ topk_group: Optional[int] = None,
174
+ num_expert_group: Optional[int] = None,
175
+ custom_routing_function: Optional[Callable] = None,
176
+ **kwargs,
177
+ ):
178
+ assert custom_routing_function is None
179
+ return layer.ipex_fusion(
180
+ x,
181
+ use_grouped_topk,
182
+ top_k,
183
+ router_logits,
184
+ renormalize,
185
+ topk_group,
186
+ num_expert_group,
187
+ )
188
+
189
+ def forward_tpu(
190
+ self,
191
+ layer: torch.nn.Module,
192
+ x: torch.Tensor,
193
+ use_grouped_topk: bool,
194
+ top_k: int,
195
+ router_logits: torch.Tensor,
196
+ renormalize: bool,
197
+ topk_group: Optional[int] = None,
198
+ num_expert_group: Optional[int] = None,
199
+ custom_routing_function: Optional[Callable] = None,
200
+ scoring_func: str = "softmax",
201
+ e_score_correction_bias: Optional[torch.Tensor] = None
202
+ ) -> torch.Tensor:
203
+ assert not use_grouped_topk
204
+ assert num_expert_group is None
205
+ assert topk_group is None
206
+ assert custom_routing_function is None
207
+ if scoring_func != "softmax":
208
+ raise NotImplementedError(
209
+ "Only softmax scoring function is supported for TPU.")
210
+ if e_score_correction_bias is not None:
211
+ raise NotImplementedError(
212
+ "Expert score correction bias is not supported for TPU.")
213
+ return fused_moe_pallas(hidden_states=x,
214
+ w1=layer.w13_weight,
215
+ w2=layer.w2_weight,
216
+ topk=top_k,
217
+ gating_output=router_logits,
218
+ renormalize=renormalize)
219
+
220
+ forward_native = forward_cuda
221
+
222
+
223
+ class FusedMoE(torch.nn.Module):
224
+ """FusedMoE layer for MoE models.
225
+
226
+ This layer contains both MergedColumnParallel weights (gate_up_proj /
227
+ w13) and RowParallelLinear weights (down_proj/ w2).
228
+
229
+ Note: Mixtral uses w1, w2, and w3 for gate, up, and down_proj. We
230
+ copy that naming convention here and handle any remapping in the
231
+ load_weights function in each model implementation.
232
+
233
+ Args:
234
+ num_experts: Number of experts in the model
235
+ top_k: Number of experts selected for each token
236
+ hidden_size: Input hidden state size of the transformer
237
+ intermediate_size: Intermediate size of the experts
238
+ params_dtype: Data type for the parameters.
239
+ reduce_results: Whether to all-reduce on the output of the layer
+ renormalize: Whether to renormalize the logits in the fused_moe kernel
241
+ quant_config: Quantization configure.
242
+ """
243
+
244
+ def __init__(
245
+ self,
246
+ num_experts: int,
247
+ top_k: int,
248
+ hidden_size: int,
249
+ intermediate_size: int,
250
+ params_dtype: Optional[torch.dtype] = None,
251
+ reduce_results: bool = False,
252
+ renormalize: bool = True,
253
+ use_grouped_topk: bool = False,
254
+ num_expert_group: Optional[int] = None,
255
+ topk_group: Optional[int] = None,
256
+ quant_config: Optional[QuantizationConfig] = None,
257
+ tp_size: Optional[int] = None,
258
+ prefix: str = "",
259
+ custom_routing_function: Optional[Callable] = None,
260
+ scoring_func: str = "softmax",
261
+ e_score_correction_bias: Optional[torch.Tensor] = None,
262
+ ):
263
+ super().__init__()
264
+
265
+ if params_dtype is None:
266
+ params_dtype = torch.get_default_dtype()
267
+
268
+ self.tp_size = (tp_size if tp_size is not None else
269
+ get_tensor_model_parallel_world_size())
270
+ self.top_k = top_k
271
+ self.num_experts = num_experts
272
+ assert intermediate_size % self.tp_size == 0
273
+ self.intermediate_size_per_partition = intermediate_size // self.tp_size
274
+ self.reduce_results = reduce_results
275
+ self.renormalize = renormalize
276
+ self.use_grouped_topk = use_grouped_topk
277
+ if self.use_grouped_topk:
278
+ assert num_expert_group is not None and topk_group is not None
279
+ self.num_expert_group = num_expert_group
280
+ self.topk_group = topk_group
281
+ self.custom_routing_function = custom_routing_function
282
+ self.scoring_func = scoring_func
283
+ self.e_score_correction_bias = e_score_correction_bias
284
+
285
+ if self.scoring_func != "softmax" and not self.use_grouped_topk:
286
+ raise ValueError("Only softmax scoring function is supported for "
287
+ "non-grouped topk.")
288
+
289
+ if quant_config is None:
290
+ self.quant_method: Optional[QuantizeMethodBase] = (
291
+ UnquantizedFusedMoEMethod())
292
+ else:
293
+ self.quant_method = quant_config.get_quant_method(self, prefix)
294
+ assert self.quant_method is not None
295
+
296
+ moe_quant_params = {
297
+ "num_experts": num_experts,
298
+ "hidden_size": hidden_size,
299
+ "intermediate_size_per_partition":
300
+ self.intermediate_size_per_partition,
301
+ "params_dtype": params_dtype,
302
+ "weight_loader": self.weight_loader,
303
+ }
304
+ # need full intermediate size pre-sharding for WNA16 act order
305
+ if (self.quant_method.__class__.__name__ ==
306
+ "CompressedTensorsWNA16MoEMethod"):
307
+ moe_quant_params["intermediate_size_full"] = intermediate_size
308
+
309
+ self.quant_method.create_weights(layer=self, **moe_quant_params)
310
+
311
+ def _load_per_tensor_weight_scale(self, shard_id: str,
312
+ param: torch.nn.Parameter,
313
+ loaded_weight: torch.Tensor,
314
+ expert_id: int):
315
+ param_data = param.data
316
+ # for per tensor weight quantization
317
+ if shard_id in ("w1", "w3"):
318
+ # We have to keep the weight scales of w1 and w3 because
319
+ # we need to re-quantize w1/w3 weights after weight loading.
320
+ idx = 0 if shard_id == "w1" else 1
321
+ param_data[expert_id][idx] = loaded_weight
322
+ # If we are in the row parallel case (down_proj)
323
+ elif shard_id == "w2":
324
+ param_data[expert_id] = loaded_weight
325
+
326
+ def _load_model_weight_or_group_weight_scale(self,
327
+ shard_dim: int,
328
+ expert_data: torch.Tensor,
329
+ shard_id: str,
330
+ loaded_weight: torch.Tensor,
331
+ tp_rank: int,
332
+ load_full_w2: bool = False):
333
+ """
334
+ Load grouped weight scales for group quantization or model weights
335
+ :param shard_dim: dimension to shard
336
+ :param expert_data: parameter for a particular expert
337
+ :param shard_id: either w1, w2, or w3
338
+ :param loaded_weight: checkpoint weight to load into the param
339
+ :param tp_rank: tensor parallel rank
340
+ :param load_full_w2: whether or not the w2 loaded should be sharded.
341
+ """
342
+ if shard_id == "w2":
343
+ # In the actorder/g_idx case we do not partition the w2 scales
+ # (as indicated by the `load_full` argument), for any tp case.
345
+ self._load_w2(shard_dim=shard_dim,
346
+ loaded_weight=loaded_weight,
347
+ expert_data=expert_data,
348
+ tp_rank=tp_rank,
349
+ load_full=load_full_w2)
350
+ elif shard_id in ("w1", "w3"):
351
+ self._load_w13(shard_id=shard_id,
352
+ shard_dim=shard_dim,
353
+ loaded_weight=loaded_weight,
354
+ expert_data=expert_data,
355
+ tp_rank=tp_rank)
356
+
357
+ def _load_per_channel_weight_scale(self, expert_data: torch.Tensor,
358
+ shard_dim: int, shard_id: str,
359
+ loaded_weight: torch.Tensor,
360
+ tp_rank: int):
361
+ # for per channel weight quantization
362
+ if shard_id == "w2":
363
+ expert_data.copy_(loaded_weight)
364
+ elif shard_id in ("w1", "w3"):
365
+ self._load_w13(shard_id=shard_id,
366
+ shard_dim=shard_dim,
367
+ loaded_weight=loaded_weight,
368
+ expert_data=expert_data,
369
+ tp_rank=tp_rank)
370
+
371
+ def _load_w13(self, expert_data: torch.Tensor, shard_dim: int,
372
+ shard_id: str, loaded_weight: torch.Tensor, tp_rank: int):
373
+
374
+ # Index the loaded weight for tp sharding.
375
+ # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
376
+ shard_size = expert_data.shape[shard_dim] // 2
377
+ loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank,
378
+ shard_size)
379
+ # Narrow parameter and load.
380
+ # w1, gate_proj: Load into first logical weight of w13.
381
+ if shard_id == "w1":
382
+ expert_data = expert_data.narrow(shard_dim, 0, shard_size)
383
+ # w3, up_proj: Load into second logical weight of w13.
384
+ else:
385
+ assert shard_id == "w3"
386
+ expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)
387
+ expert_data.copy_(loaded_weight)
388
+
389
+ def _load_w2(self,
390
+ expert_data: torch.Tensor,
391
+ shard_dim: int,
392
+ loaded_weight: torch.Tensor,
393
+ tp_rank: int,
394
+ load_full: bool = False):
395
+
396
+ # Index the loaded weight for tp sharding.
397
+ # down_proj: "RowParallel" so tp sharding on input_dim
398
+ # Narrow parameter and load.
399
+ shard_size = expert_data.shape[shard_dim]
400
+ if not load_full:
401
+ loaded_weight = loaded_weight.narrow(shard_dim,
402
+ shard_size * tp_rank,
403
+ shard_size)
404
+ # w2, down_proj: Load into only logical weight of w2.
405
+ expert_data.copy_(loaded_weight)
406
+
407
+ def _load_single_value(self, param: torch.nn.Parameter,
408
+ loaded_weight: torch.Tensor, expert_id: int):
409
+ param_data = param.data
410
+
411
+ # Input scales can be loaded directly and should be equal.
412
+ param_data[expert_id] = loaded_weight
413
+
414
+ def _load_g_idx(self, shard_id: str, expert_data: torch.Tensor,
415
+ shard_dim: int, loaded_weight: torch.Tensor, tp_rank: int):
416
+
417
+ if shard_id == "w2":
418
+ self._load_w2(shard_dim=shard_dim,
419
+ loaded_weight=loaded_weight,
420
+ expert_data=expert_data,
421
+ tp_rank=tp_rank)
422
+ else:
423
+ assert shard_id in ("w1", "w3")
424
+ expert_data.copy_(loaded_weight)
425
+
426
+ def weight_loader(self, param: torch.nn.Parameter,
427
+ loaded_weight: torch.Tensor, weight_name: str,
428
+ shard_id: str, expert_id: int) -> None:
429
+
430
+ # compressed-tensors checkpoints with packed weights are stored flipped
431
+ # TODO (mgoin): check self.quant_method.quant_config.quant_format
432
+ # against known CompressionFormat enum values that have this quality
433
+ loaded_weight = loaded_weight.t().contiguous() if (
434
+ self.quant_method.__class__.__name__
435
+ == "CompressedTensorsWNA16MoEMethod") else loaded_weight
436
+
437
+ if shard_id not in ("w1", "w2", "w3"):
438
+ raise ValueError(f"shard_id must be ['w1','w2','w3'] but "
439
+ f"got {shard_id}.")
440
+
441
+ WEIGHT_SCALE_SUPPORTED = [
442
+ e.value for e in FusedMoeWeightScaleSupported
443
+ ]
444
+ # Fetch the dim to shard the parameter/loaded weight
445
+ # based on the shard id. This will be whichever
446
+ # dimension intermediate_size_per_partition lies on.
447
+ SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0}
448
+
449
+ expert_data = param.data[expert_id]
450
+ tp_rank = get_tensor_model_parallel_rank()
451
+
452
+ # is_transposed: if the dim to shard the weight
453
+ # should be flipped. Required by GPTQ and compressed-tensors;
453
+ # it is whichever dimension intermediate_size_per_partition is on.
455
+ is_transposed = getattr(param, "is_transposed", False)
456
+ shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id]
457
+ if is_transposed:
458
+ shard_dim = int(not shard_dim)
459
+
460
+ # Case input scale: input_scale loading is only supported for fp8
461
+ if "input_scale" in weight_name:
462
+ # this is needed for compressed-tensors only
463
+ loaded_weight = loaded_weight.to(param.data.device)
464
+
465
+ if param.data[expert_id] != 1 and (param.data[expert_id] -
466
+ loaded_weight).abs() > 1e-5:
467
+ raise ValueError(
468
+ "input_scales of w1 and w3 of a layer "
469
+ f"must be equal. But got {param.data[expert_id]} "
470
+ f"vs. {loaded_weight}")
471
+
472
+ self._load_single_value(param=param,
473
+ loaded_weight=loaded_weight,
474
+ expert_id=expert_id)
475
+ return
476
+
477
+ # Case g_idx
478
+ if "g_idx" in weight_name:
479
+ self._load_g_idx(shard_dim=0,
480
+ shard_id=shard_id,
481
+ loaded_weight=loaded_weight,
482
+ expert_data=expert_data,
483
+ tp_rank=tp_rank)
484
+ return
485
+
486
+ # Case weight scales and zero_points
487
+ if ("scale" in weight_name or "zero" in weight_name):
488
+ # load the weight scales and zp based on the quantization scheme
489
+ # supported weight scales/zp can be found in
490
+ # FusedMoeWeightScaleSupported
491
+ # TODO @dsikka: once hardened, refactor to use vLLM Parameters
492
+ # specific to each case
493
+ quant_method = getattr(param, "quant_method", None)
494
+ if quant_method == FusedMoeWeightScaleSupported.CHANNEL.value:
495
+ self._load_per_channel_weight_scale(
496
+ shard_id=shard_id,
497
+ shard_dim=shard_dim,
498
+ loaded_weight=loaded_weight,
499
+ expert_data=expert_data,
500
+ tp_rank=tp_rank)
501
+ elif quant_method in [
502
+ FusedMoeWeightScaleSupported.GROUP.value,
503
+ FusedMoeWeightScaleSupported.BLOCK.value,
504
+ ]:
505
+ self._load_model_weight_or_group_weight_scale(
506
+ shard_id=shard_id,
507
+ shard_dim=shard_dim,
508
+ loaded_weight=loaded_weight,
509
+ expert_data=expert_data,
510
+ tp_rank=tp_rank,
511
+ load_full_w2=getattr(param, "load_full_w2", False))
512
+ elif quant_method == FusedMoeWeightScaleSupported.TENSOR.value:
513
+ self._load_per_tensor_weight_scale(shard_id=shard_id,
514
+ param=param,
515
+ loaded_weight=loaded_weight,
516
+ expert_id=expert_id)
517
+ else:
518
+ raise ValueError(
519
+ f"quant method must be one of {WEIGHT_SCALE_SUPPORTED}")
520
+ return
521
+
522
+ # Case weight_shape
523
+ if "weight_shape" in weight_name:
524
+ # only required by compressed-tensors
525
+ self._load_single_value(param=param,
526
+ loaded_weight=loaded_weight,
527
+ expert_id=expert_id)
528
+ return
529
+
530
+ # Case model weights
531
+ if "weight" in weight_name:
532
+ self._load_model_weight_or_group_weight_scale(
533
+ shard_id=shard_id,
534
+ shard_dim=shard_dim,
535
+ loaded_weight=loaded_weight,
536
+ expert_data=expert_data,
537
+ tp_rank=tp_rank)
538
+ return
539
+
540
+ @staticmethod
541
+ def select_experts(hidden_states: torch.Tensor,
542
+ router_logits: torch.Tensor,
543
+ top_k: int,
544
+ use_grouped_topk: bool,
545
+ renormalize: bool,
546
+ topk_group: Optional[int] = None,
547
+ num_expert_group: Optional[int] = None,
548
+ custom_routing_function: Optional[Callable] = None,
549
+ scoring_func: str = "softmax",
550
+ e_score_correction_bias: Optional[torch.Tensor] = None):
551
+ from vllm.model_executor.layers.fused_moe.fused_moe import (
552
+ fused_topk, grouped_topk)
553
+
554
+ # DeepSeek-V2 uses grouped_top_k
555
+ if use_grouped_topk:
556
+ assert topk_group is not None
557
+ assert num_expert_group is not None
558
+ topk_weights, topk_ids = grouped_topk(
559
+ hidden_states=hidden_states,
560
+ gating_output=router_logits,
561
+ topk=top_k,
562
+ renormalize=renormalize,
563
+ num_expert_group=num_expert_group,
564
+ topk_group=topk_group,
565
+ scoring_func=scoring_func,
566
+ e_score_correction_bias=e_score_correction_bias)
567
+ elif custom_routing_function is None:
568
+ topk_weights, topk_ids = fused_topk(hidden_states=hidden_states,
569
+ gating_output=router_logits,
570
+ topk=top_k,
571
+ renormalize=renormalize)
572
+ else:
573
+ topk_weights, topk_ids = custom_routing_function(
574
+ hidden_states=hidden_states,
575
+ gating_output=router_logits,
576
+ topk=top_k,
577
+ renormalize=renormalize)
578
+
579
+ return topk_weights, topk_ids
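For reference, the non-grouped routing branch above reduces to a softmax-and-top-k over the router logits. A minimal plain-torch sketch of that math (an illustration only; the actual kernel lives in fused_moe.fused_topk):

import torch

router_logits = torch.randn(4, 8)                 # [num_tokens, num_experts]
scores = router_logits.softmax(dim=-1, dtype=torch.float)
topk_weights, topk_ids = scores.topk(2, dim=-1)   # top_k = 2
# renormalize=True rescales the selected weights to sum to 1 per token.
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)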
580
+
581
+ def forward(self, hidden_states: torch.Tensor,
582
+ router_logits: torch.Tensor):
583
+ assert self.quant_method is not None
584
+
585
+ # Matrix multiply.
586
+ final_hidden_states = self.quant_method.apply(
587
+ layer=self,
588
+ x=hidden_states,
589
+ router_logits=router_logits,
590
+ top_k=self.top_k,
591
+ renormalize=self.renormalize,
592
+ use_grouped_topk=self.use_grouped_topk,
593
+ topk_group=self.topk_group,
594
+ num_expert_group=self.num_expert_group,
595
+ custom_routing_function=self.custom_routing_function,
596
+ scoring_func=self.scoring_func,
597
+ e_score_correction_bias=self.e_score_correction_bias)
598
+
599
+ if self.reduce_results and self.tp_size > 1:
600
+ final_hidden_states = tensor_model_parallel_all_reduce(
601
+ final_hidden_states)
602
+
603
+ return final_hidden_states
604
+
605
+ @classmethod
606
+ def make_expert_params_mapping(
607
+ cls, ckpt_gate_proj_name: str, ckpt_down_proj_name: str,
608
+ ckpt_up_proj_name: str,
609
+ num_experts: int) -> List[Tuple[str, str, int, str]]:
610
+
611
+ return [
612
+ # (param_name, weight_name, expert_id, shard_id)
613
+ ("experts.w13_" if weight_name
614
+ in [ckpt_gate_proj_name, ckpt_up_proj_name] else "experts.w2_",
615
+ f"experts.{expert_id}.{weight_name}.", expert_id, shard_id)
616
+ for expert_id in range(num_experts) for shard_id, weight_name in [
617
+ ("w1", ckpt_gate_proj_name),
618
+ ("w2", ckpt_down_proj_name),
619
+ ("w3", ckpt_up_proj_name),
620
+ ]
621
+ ]
622
+
623
+ def _load_fp8_scale(self, param: torch.nn.Parameter,
624
+ loaded_weight: torch.Tensor, weight_name: str,
625
+ shard_id: str, expert_id: int) -> None:
626
+ param_data = param.data
627
+
628
+ # Input scales can be loaded directly and should be equal.
629
+ if "input_scale" in weight_name:
630
+ if param_data[expert_id] != 1 and (param_data[expert_id] -
631
+ loaded_weight).abs() > 1e-5:
632
+ raise ValueError(
633
+ "input_scales of w1 and w3 of a layer "
634
+ f"must be equal. But got {param_data[expert_id]} "
635
+ f"vs. {loaded_weight}")
636
+ param_data[expert_id] = loaded_weight
637
+ # Weight scales
638
+ elif "weight_scale" in weight_name:
639
+ # If we are in merged column case (gate_up_proj)
640
+ if shard_id in ("w1", "w3"):
641
+ # We have to keep the weight scales of w1 and w3 because
642
+ # we need to re-quantize w1/w3 weights after weight loading.
643
+ idx = 0 if shard_id == "w1" else 1
644
+ param_data[expert_id][idx] = loaded_weight
645
+ # If we are in the row parallel case (down_proj)
646
+ else:
647
+ param_data[expert_id] = loaded_weight
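To make the mapping produced by make_expert_params_mapping above concrete, here is a small sketch assuming a checkpoint whose expert projections are named gate_proj, down_proj and up_proj (illustrative names, not mandated by this file):

# Hypothetical illustration of the (param_name, weight_name, expert_id, shard_id)
# tuples returned by FusedMoE.make_expert_params_mapping for num_experts=2.
mapping = []
for expert_id in range(2):
    for shard_id, weight_name in [("w1", "gate_proj"),
                                  ("w2", "down_proj"),
                                  ("w3", "up_proj")]:
        prefix = ("experts.w13_" if weight_name in ("gate_proj", "up_proj")
                  else "experts.w2_")
        mapping.append((prefix, f"experts.{expert_id}.{weight_name}.",
                        expert_id, shard_id))
# mapping[0] == ("experts.w13_", "experts.0.gate_proj.", 0, "w1")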
.venv/lib/python3.11/site-packages/vllm/model_executor/layers/fused_moe/moe_pallas.py ADDED
@@ -0,0 +1,64 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch_xla.experimental.custom_kernel import _histogram
6
+
7
+
8
+ def fused_moe(
9
+ hidden_states: torch.Tensor,
10
+ w1: torch.Tensor,
11
+ w2: torch.Tensor,
12
+ gating_output: torch.Tensor,
13
+ topk: int,
14
+ renormalize: bool,
15
+ ) -> torch.Tensor:
16
+ """
17
+ Args:
18
+ hidden_states: [*, hidden_size]
19
+ w1: [num_experts, intermediate_size * 2, hidden_size]
20
+ w2: [num_experts, hidden_size, intermediate_size]
21
+ gating_output: [*, num_experts]
22
+ """
23
+ orig_shape = hidden_states.shape
24
+ hidden_size = hidden_states.shape[-1]
25
+ num_tokens = hidden_states.shape[:-1].numel()
26
+ num_experts = w1.shape[0]
27
+ intermediate_size = w2.shape[-1]
28
+ device = hidden_states.device
29
+ dtype = hidden_states.dtype
30
+ assert (num_tokens * topk) % 16 == 0, (
31
+ "The Pallas GMM kernel requires num_tokens * topk to be a multiple of "
32
+ f"16 but got {num_tokens * topk}")
33
+
34
+ hidden_states = hidden_states.view(num_tokens, hidden_size)
35
+ gating_output = gating_output.view(num_tokens, num_experts)
36
+ topk_weights = gating_output.softmax(dim=-1, dtype=torch.float)
37
+ topk_weights, topk_indices = topk_weights.topk(topk, dim=-1)
38
+ if renormalize:
39
+ topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
40
+ topk_weights = topk_weights.to(dtype)
41
+
42
+ topk_indices = topk_indices.flatten()
43
+ topk_argsort_indices = topk_indices.argsort()
44
+ topk_argsort_revert_indices = topk_argsort_indices.argsort()
45
+ token_indices = torch.arange(num_tokens,
46
+ device=device).repeat_interleave(topk)
47
+ token_indices = token_indices[topk_argsort_indices]
48
+ group_sizes = _histogram(topk_indices.to(torch.int32), 0, num_experts - 1)
49
+
50
+ # NOTE(woosuk): The GMM Pallas kernel requires a different weight layout
51
+ # from HF Transformers.
52
+ w1 = w1.transpose(1, 2)
53
+ w2 = w2.transpose(1, 2)
54
+
55
+ x = hidden_states[token_indices]
56
+ x = torch.ops.xla.gmm(x, w1, group_sizes)
57
+ x = F.silu(x[..., :intermediate_size]) * x[..., intermediate_size:]
58
+ x = torch.ops.xla.gmm(x, w2, group_sizes)
59
+ x = x[topk_argsort_revert_indices].reshape(-1, topk, hidden_size)
60
+
61
+ x = x * topk_weights.unsqueeze_(dim=-1)
62
+ x = x.sum(dim=-2)
63
+ x = x.reshape(orig_shape)
64
+ return x
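The Pallas path above relies on torch_xla's grouped matmul, but the token-permutation bookkeeping it performs can be sketched with plain torch (assumption: torch.bincount standing in for the XLA _histogram helper):

import torch

num_tokens, topk, num_experts = 4, 2, 3
topk_indices = torch.randint(0, num_experts, (num_tokens, topk)).flatten()
order = topk_indices.argsort()            # token copies grouped by expert id
revert = order.argsort()                  # inverse permutation
token_indices = torch.arange(num_tokens).repeat_interleave(topk)[order]
group_sizes = torch.bincount(topk_indices, minlength=num_experts)
# After the grouped expert computation, indexing with `revert` restores the
# original (token, top-k slot) order, as done above before the weighted sum.
assert torch.equal(topk_indices[order][revert], topk_indices)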
.venv/lib/python3.11/site-packages/vllm/model_executor/layers/fused_moe/moe_torch_iterative.py ADDED
@@ -0,0 +1,53 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+
6
+
7
+ def fused_moe(
8
+ hidden_states: torch.Tensor,
9
+ w1: torch.Tensor,
10
+ w2: torch.Tensor,
11
+ gating_output: torch.Tensor,
12
+ topk: int,
13
+ renormalize: bool,
14
+ ) -> torch.Tensor:
15
+ """
16
+ Args:
17
+ hidden_states: [*, hidden_size]
18
+ w1: [num_experts, intermediate_size * 2, hidden_size]
19
+ w2: [num_experts, hidden_size, intermediate_size]
20
+ gating_output: [*, num_experts]
21
+ """
22
+ orig_shape = hidden_states.shape
23
+ hidden_size = hidden_states.shape[-1]
24
+ num_tokens = hidden_states.shape[:-1].numel()
25
+ num_experts = w1.shape[0]
26
+ intermediate_size = w2.shape[-1]
27
+ dtype = hidden_states.dtype
28
+
29
+ hidden_states = hidden_states.view(num_tokens, hidden_size)
30
+ gating_output = gating_output.view(num_tokens, num_experts)
31
+ topk_weights = gating_output.softmax(dim=-1, dtype=torch.float)
32
+ topk_weights, selected_experts = topk_weights.topk(topk, dim=-1)
33
+ if renormalize:
34
+ topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
35
+ topk_weights = topk_weights.to(dtype)
36
+
37
+ final_hidden_states = None
38
+ for expert_idx in range(num_experts):
39
+ expert_w1 = w1[expert_idx]
40
+ expert_w2 = w2[expert_idx]
41
+ expert_mask = (selected_experts == expert_idx)
42
+ expert_weights = (topk_weights * expert_mask).sum(dim=-1, keepdim=True)
43
+ x = F.linear(hidden_states, expert_w1)
44
+ gate = F.silu(x[:, :intermediate_size])
45
+ x = x[:, intermediate_size:] * gate
46
+ x = F.linear(x, expert_w2)
47
+ current_hidden_states = x * expert_weights
48
+ if final_hidden_states is None:
49
+ final_hidden_states = current_hidden_states
50
+ else:
51
+ final_hidden_states = final_hidden_states + current_hidden_states
52
+
53
+ return final_hidden_states.view(orig_shape) # type: ignore
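A quick shape check for this reference implementation (a sketch; it assumes the module above is importable from the installed vllm package, otherwise the function body can be copied verbatim):

import torch
from vllm.model_executor.layers.fused_moe.moe_torch_iterative import fused_moe

num_tokens, hidden, inter, experts, topk = 8, 16, 32, 4, 2
x = torch.randn(num_tokens, hidden)
w1 = torch.randn(experts, 2 * inter, hidden)   # fused gate/up projections
w2 = torch.randn(experts, hidden, inter)       # down projection
gating = torch.randn(num_tokens, experts)
out = fused_moe(x, w1, w2, gating, topk, renormalize=True)
assert out.shape == x.shape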
.venv/lib/python3.11/site-packages/vllm/model_executor/layers/layernorm.py ADDED
@@ -0,0 +1,213 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ """Custom normalization layers."""
3
+ from typing import Optional, Tuple, Union
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ from vllm.model_executor.custom_op import CustomOp
9
+
10
+
11
+ @CustomOp.register("rms_norm")
12
+ class RMSNorm(CustomOp):
13
+ """Root mean square normalization.
14
+
15
+ Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight.
16
+ Refer to https://arxiv.org/abs/1910.07467
17
+ """
18
+
19
+ def __init__(
20
+ self,
21
+ hidden_size: int,
22
+ eps: float = 1e-6,
23
+ var_hidden_size: Optional[int] = None,
24
+ has_weight: bool = True,
25
+ ) -> None:
26
+ super().__init__()
27
+
28
+ self.hidden_size = hidden_size
29
+ self.variance_epsilon = eps
30
+ self.variance_size_override = (None if var_hidden_size == hidden_size
31
+ else var_hidden_size)
32
+ self.has_weight = has_weight
33
+
34
+ self.weight = torch.ones(hidden_size)
35
+ if self.has_weight:
36
+ self.weight = nn.Parameter(self.weight)
37
+
38
+ def forward_native(
39
+ self,
40
+ x: torch.Tensor,
41
+ residual: Optional[torch.Tensor] = None,
42
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
43
+ """PyTorch-native implementation equivalent to forward()."""
44
+ orig_dtype = x.dtype
45
+ x = x.to(torch.float32)
46
+ if residual is not None:
47
+ x = x + residual.to(torch.float32)
48
+ residual = x.to(orig_dtype)
49
+
50
+ hidden_size = x.shape[-1]
51
+ if hidden_size != self.hidden_size:
52
+ raise ValueError("Expected hidden_size to be "
53
+ f"{self.hidden_size}, but found: {hidden_size}")
54
+
55
+ if self.variance_size_override is None:
56
+ x_var = x
57
+ else:
58
+ if hidden_size < self.variance_size_override:
59
+ raise ValueError(
60
+ "Expected hidden_size to be at least "
61
+ f"{self.variance_size_override}, but found: {hidden_size}")
62
+
63
+ x_var = x[:, :, :self.variance_size_override]
64
+
65
+ variance = x_var.pow(2).mean(dim=-1, keepdim=True)
66
+
67
+ x = x * torch.rsqrt(variance + self.variance_epsilon)
68
+ x = x.to(orig_dtype)
69
+ if self.has_weight:
70
+ x = x * self.weight
71
+ if residual is None:
72
+ return x
73
+ else:
74
+ return x, residual
75
+
76
+ def forward_cuda(
77
+ self,
78
+ x: torch.Tensor,
79
+ residual: Optional[torch.Tensor] = None,
80
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
81
+ if self.variance_size_override is not None:
82
+ return self.forward_native(x, residual)
83
+
84
+ from vllm import _custom_ops as ops
85
+
86
+ if residual is not None:
87
+ ops.fused_add_rms_norm(
88
+ x,
89
+ residual,
90
+ self.weight.data,
91
+ self.variance_epsilon,
92
+ )
93
+ return x, residual
94
+ out = torch.empty_like(x)
95
+ ops.rms_norm(
96
+ out,
97
+ x,
98
+ self.weight.data,
99
+ self.variance_epsilon,
100
+ )
101
+ return out
102
+
103
+ def forward_hpu(
104
+ self,
105
+ x: torch.Tensor,
106
+ residual: Optional[torch.Tensor] = None,
107
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
108
+ from vllm_hpu_extension.ops import HPUFusedRMSNorm
109
+ if HPUFusedRMSNorm is None:
110
+ return self.forward_native(x, residual)
111
+ if residual is not None:
112
+ orig_shape = x.shape
113
+ residual += x.view(residual.shape)
114
+ # Note: HPUFusedRMSNorm requires 3D tensors as inputs
115
+ x = HPUFusedRMSNorm.apply(residual, self.weight,
116
+ self.variance_epsilon)
117
+ return x.view(orig_shape), residual
118
+
119
+ x = HPUFusedRMSNorm.apply(x, self.weight, self.variance_epsilon)
120
+ return x
121
+
122
+ def forward_xpu(
123
+ self,
124
+ x: torch.Tensor,
125
+ residual: Optional[torch.Tensor] = None,
126
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
127
+ if self.variance_size_override is not None:
128
+ return self.forward_native(x, residual)
129
+
130
+ from vllm._ipex_ops import ipex_ops as ops
131
+
132
+ if residual is not None:
133
+ ops.fused_add_rms_norm(
134
+ x,
135
+ residual,
136
+ self.weight.data,
137
+ self.variance_epsilon,
138
+ )
139
+ return x, residual
140
+ return ops.rms_norm(
141
+ x,
142
+ self.weight.data,
143
+ self.variance_epsilon,
144
+ )
145
+
146
+ def extra_repr(self) -> str:
147
+ s = f"hidden_size={self.weight.data.size(0)}"
148
+ s += f", eps={self.variance_epsilon}"
149
+ return s
150
+
151
+
152
+ @CustomOp.register("gemma_rms_norm")
153
+ class GemmaRMSNorm(CustomOp):
154
+ """RMS normalization for Gemma.
155
+
156
+ Two differences from the above RMSNorm:
157
+ 1. x * (1 + w) instead of x * w.
158
+ 2. (x * w).to(orig_dtype) instead of x.to(orig_dtype) * w.
159
+ """
160
+
161
+ def __init__(
162
+ self,
163
+ hidden_size: int,
164
+ eps: float = 1e-6,
165
+ ) -> None:
166
+ super().__init__()
167
+ self.weight = nn.Parameter(torch.zeros(hidden_size))
168
+ self.variance_epsilon = eps
169
+
170
+ @staticmethod
171
+ def forward_static(
172
+ weight: torch.Tensor,
173
+ variance_epsilon: float,
174
+ x: torch.Tensor,
175
+ residual: Optional[torch.Tensor],
176
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
177
+ """PyTorch-native implementation equivalent to forward()."""
178
+ orig_dtype = x.dtype
179
+ if residual is not None:
180
+ x = x + residual
181
+ residual = x
182
+
183
+ x = x.float()
184
+ variance = x.pow(2).mean(dim=-1, keepdim=True)
185
+ x = x * torch.rsqrt(variance + variance_epsilon)
186
+ # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
187
+ # See https://github.com/huggingface/transformers/pull/29402
188
+ x = x * (1.0 + weight.float())
189
+ x = x.to(orig_dtype)
190
+ return x if residual is None else (x, residual)
191
+
192
+ def forward_native(
193
+ self,
194
+ x: torch.Tensor,
195
+ residual: Optional[torch.Tensor] = None,
196
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
197
+ """PyTorch-native implementation equivalent to forward()."""
198
+ return self.forward_static(self.weight.data, self.variance_epsilon, x,
199
+ residual)
200
+
201
+ def forward_cuda(
202
+ self,
203
+ x: torch.Tensor,
204
+ residual: Optional[torch.Tensor] = None,
205
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
206
+ if torch.compiler.is_compiling():
207
+ return self.forward_native(x, residual)
208
+
209
+ if not getattr(self, "_is_compiled", False):
210
+ self.forward_static = torch.compile( # type: ignore
211
+ self.forward_static)
212
+ self._is_compiled = True
213
+ return self.forward_native(x, residual)
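The forward_native math of RMSNorm above can be restated without any vllm imports; this dependency-free sketch mirrors the float32 upcast and the w * x / sqrt(E[x^2] + eps) formula (it is an illustration, not the module's API):

import torch

def rms_norm_reference(x: torch.Tensor, weight: torch.Tensor,
                       eps: float = 1e-6) -> torch.Tensor:
    orig_dtype = x.dtype
    x = x.float()                                   # compute in float32
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    return x.to(orig_dtype) * weight                # Llama-style: cast, then scale

x = torch.randn(2, 8, dtype=torch.float16)
w = torch.ones(8, dtype=torch.float16)
assert rms_norm_reference(x, w).shape == x.shape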
.venv/lib/python3.11/site-packages/vllm/model_executor/layers/linear.py ADDED
@@ -0,0 +1,1159 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ import itertools
4
+ from abc import abstractmethod
5
+ from typing import Optional
6
+
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from torch.nn.parameter import Parameter, UninitializedParameter
10
+
11
+ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
12
+ get_tensor_model_parallel_world_size,
13
+ split_tensor_along_last_dim,
14
+ tensor_model_parallel_all_gather,
15
+ tensor_model_parallel_all_reduce)
16
+ from vllm.logger import init_logger
17
+ from vllm.model_executor.layers.quantization.base_config import (
18
+ QuantizationConfig, QuantizeMethodBase)
19
+ # yapf: disable
20
+ from vllm.model_executor.parameter import (BasevLLMParameter,
21
+ BlockQuantScaleParameter,
22
+ PackedColumnParameter,
23
+ PackedvLLMParameter,
24
+ PerTensorScaleParameter,
25
+ RowvLLMParameter)
26
+ # yapf: enable
27
+ from vllm.model_executor.utils import set_weight_attrs
28
+
29
+ logger = init_logger(__name__)
30
+
31
+ WEIGHT_LOADER_V2_SUPPORTED = [
32
+ "CompressedTensorsLinearMethod", "AWQMarlinLinearMethod",
33
+ "AWQLinearMethod", "GPTQMarlinLinearMethod", "Fp8LinearMethod",
34
+ "MarlinLinearMethod", "QQQLinearMethod", "GPTQMarlin24LinearMethod",
35
+ "TPUInt8LinearMethod", "GPTQLinearMethod", "FBGEMMFp8LinearMethod",
36
+ "ModelOptFp8LinearMethod", "IPEXAWQLinearMethod", "IPEXGPTQLinearMethod",
37
+ "HQQMarlinMethod", "QuarkLinearMethod"
38
+ ]
39
+
40
+
41
+ def adjust_marlin_shard(param, shard_size, shard_offset):
42
+ marlin_tile_size = getattr(param, "marlin_tile_size", None)
43
+ if marlin_tile_size is None:
44
+ return shard_size, shard_offset
45
+
46
+ return shard_size * marlin_tile_size, shard_offset * marlin_tile_size
47
+
48
+
49
+ def adjust_bitsandbytes_4bit_shard(param: Parameter,
50
+ shard_offsets: dict[str, tuple[int, int]],
51
+ loaded_shard_id: str) -> tuple[int, int]:
52
+ """Adjust the quantization offsets and sizes for BitsAndBytes sharding."""
53
+
54
+ total, _ = shard_offsets["total"]
55
+ orig_offset, orig_size = shard_offsets[loaded_shard_id]
56
+
57
+ quantized_total = param.data.shape[0]
58
+ quantized_offset = orig_offset * quantized_total // total
59
+ quantized_size = orig_size * quantized_total // total
60
+
61
+ return quantized_size, quantized_offset
62
+
63
+
64
+ def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
65
+ """For fused modules (QKV and MLP) we have an array of length
66
+ N that holds 1 scale for each "logical" matrix. So the param
67
+ is an array of length N. The loaded_weight corresponds to
68
+ one of the shards on disk. Here, we slice the param based on
69
+ the shard_id for loading.
70
+ """
71
+ qkv_idxs = {"q": 0, "k": 1, "v": 2}
72
+
73
+ if isinstance(shard_id, str):
74
+ shard_id = qkv_idxs[shard_id]
75
+ elif not isinstance(shard_id, int):
76
+ raise ValueError(f"Unknown Shard Id {shard_id}")
77
+
78
+ # AutoFP8 scales do not have a shape
79
+ # compressed-tensors scales do have a shape
80
+ if len(loaded_weight.shape) != 0:
81
+ assert loaded_weight.shape[0] == 1
82
+ loaded_weight = loaded_weight[0]
83
+
84
+ return param[shard_id], loaded_weight
85
+
86
+
87
+ class LinearMethodBase(QuantizeMethodBase):
88
+ """Base class for different (maybe quantized) linear methods."""
89
+
90
+ @abstractmethod
91
+ def create_weights(self, layer: torch.nn.Module,
92
+ input_size_per_partition: int,
93
+ output_partition_sizes: list[int], input_size: int,
94
+ output_size: int, params_dtype: torch.dtype,
95
+ **extra_weight_attrs):
96
+ """Create weights for a linear layer.
97
+ The weights will be set as attributes of the layer.
98
+
99
+ Args:
100
+ layer: The layer that is using the LinearMethodBase factory.
101
+ input_size_per_partition: Size of the weight input dim on rank X.
102
+ output_partition_sizes: Sizes of the output dim of each logical
103
+ weight on rank X. E.g., output_partition_sizes for QKVLinear
104
+ is a list containing the widths of Wq, Wk, Wv on rank X.
105
+ input_size: Size of the input dim of the weight across all ranks.
106
+ output_size: Size of the output dim of the weight across all ranks.
107
+ params_dtype: Datatype of the parameters.
108
+ """
109
+ raise NotImplementedError
110
+
111
+ @abstractmethod
112
+ def apply(self,
113
+ layer: torch.nn.Module,
114
+ x: torch.Tensor,
115
+ bias: Optional[torch.Tensor] = None) -> torch.Tensor:
116
+ """Apply the weights in layer to the input tensor.
117
+ Expects create_weights to have been called before on the layer."""
118
+ raise NotImplementedError
119
+
120
+
121
+ class UnquantizedLinearMethod(LinearMethodBase):
122
+ """Linear method without quantization."""
123
+
124
+ def create_weights(self, layer: torch.nn.Module,
125
+ input_size_per_partition: int,
126
+ output_partition_sizes: list[int], input_size: int,
127
+ output_size: int, params_dtype: torch.dtype,
128
+ **extra_weight_attrs):
129
+ weight = Parameter(torch.empty(sum(output_partition_sizes),
130
+ input_size_per_partition,
131
+ dtype=params_dtype),
132
+ requires_grad=False)
133
+ set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
134
+ layer.register_parameter("weight", weight)
135
+ set_weight_attrs(weight, extra_weight_attrs)
136
+
137
+ def apply(self,
138
+ layer: torch.nn.Module,
139
+ x: torch.Tensor,
140
+ bias: Optional[torch.Tensor] = None) -> torch.Tensor:
141
+
142
+ return F.linear(x, layer.weight, bias)
143
+
144
+
145
+ class LinearBase(torch.nn.Module):
146
+ """Base linear layer.
147
+
148
+ Args:
149
+ input_size: input dimension of the linear layer.
150
+ output_size: output dimension of the linear layer.
151
+ bias: If true, add bias.
152
+ skip_bias_add: If true, skip adding bias but instead return it.
153
+ params_dtype: Data type for the parameters.
154
+ quant_config: Quantization configuration.
155
+ """
156
+
157
+ def __init__(
158
+ self,
159
+ input_size: int,
160
+ output_size: int,
161
+ skip_bias_add: bool = False,
162
+ params_dtype: Optional[torch.dtype] = None,
163
+ quant_config: Optional[QuantizationConfig] = None,
164
+ prefix: str = "",
165
+ ):
166
+ super().__init__()
167
+
168
+ # Keep input parameters
169
+ self.input_size = input_size
170
+ self.output_size = output_size
171
+ self.skip_bias_add = skip_bias_add
172
+ if params_dtype is None:
173
+ params_dtype = torch.get_default_dtype()
174
+ self.params_dtype = params_dtype
175
+ if quant_config is None:
176
+ self.quant_method: Optional[
177
+ QuantizeMethodBase] = UnquantizedLinearMethod()
178
+ else:
179
+ self.quant_method = quant_config.get_quant_method(self,
180
+ prefix=prefix)
181
+
182
+ def forward(self,
183
+ x: torch.Tensor) -> tuple[torch.Tensor, Optional[Parameter]]:
184
+ raise NotImplementedError
185
+
186
+
187
+ class ReplicatedLinear(LinearBase):
188
+ """Replicated linear layer.
189
+
190
+ Args:
191
+ input_size: input dimension of the linear layer.
192
+ output_size: output dimension of the linear layer.
193
+ bias: If true, add bias.
194
+ skip_bias_add: If true, skip adding bias but instead return it.
195
+ params_dtype: Data type for the parameters.
196
+ quant_config: Quantization configuration.
197
+ prefix: The name of the layer in the state dict, including all parents
198
+ (e.g. model.layers.0.qkv_proj)
199
+ """
200
+
201
+ def __init__(self,
202
+ input_size: int,
203
+ output_size: int,
204
+ bias: bool = True,
205
+ skip_bias_add: bool = False,
206
+ params_dtype: Optional[torch.dtype] = None,
207
+ quant_config: Optional[QuantizationConfig] = None,
208
+ prefix: str = ""):
209
+ super().__init__(input_size,
210
+ output_size,
211
+ skip_bias_add,
212
+ params_dtype,
213
+ quant_config,
214
+ prefix=prefix)
215
+
216
+ # All the linear layer supports quant method.
217
+ assert self.quant_method is not None
218
+ self.quant_method.create_weights(self,
219
+ self.input_size, [self.output_size],
220
+ self.input_size,
221
+ self.output_size,
222
+ self.params_dtype,
223
+ weight_loader=self.weight_loader)
224
+
225
+ if bias:
226
+ self.bias = Parameter(
227
+ torch.empty(self.output_size, dtype=self.params_dtype))
228
+ set_weight_attrs(self.bias, {
229
+ "output_dim": 0,
230
+ "weight_loader": self.weight_loader,
231
+ })
232
+ else:
233
+ self.register_parameter("bias", None)
234
+
235
+ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
236
+ # If the weight on disk does not have a shape, give it one
237
+ # (such as scales for AutoFP8).
238
+ if len(loaded_weight.shape) == 0:
239
+ loaded_weight = loaded_weight.reshape(1)
240
+
241
+ assert param.size() == loaded_weight.size()
242
+ param.data.copy_(loaded_weight)
243
+
244
+ def forward(self,
245
+ x: torch.Tensor) -> tuple[torch.Tensor, Optional[Parameter]]:
246
+ bias = self.bias if not self.skip_bias_add else None
247
+ assert self.quant_method is not None
248
+ output = self.quant_method.apply(self, x, bias)
249
+ output_bias = self.bias if self.skip_bias_add else None
250
+ return output, output_bias
251
+
252
+ def extra_repr(self) -> str:
253
+ s = f"in_features={self.input_size}"
254
+ s += f", output_features={self.output_size}"
255
+ s += f", bias={self.bias is not None}"
256
+ return s
257
+
258
+
259
+ class ColumnParallelLinear(LinearBase):
260
+ """Linear layer with column parallelism.
261
+
262
+ The linear layer is defined as Y = XA + b. A is parallelized along
263
+ its second dimension as A = [A_1, ..., A_p].
264
+
265
+ Args:
266
+ input_size: first dimension of matrix A.
267
+ output_size: second dimension of matrix A.
268
+ bias: If true, add bias.
269
+ gather_output: If true, call all-gather on output and make Y available
270
+ to all GPUs, otherwise, every GPU will have its output
271
+ which is Y_i = XA_i
272
+ skip_bias_add: This was added to enable performance optimizations where
273
+ bias can be fused with other element-wise operations. we
274
+ skip adding bias but instead return it.
275
+ params_dtype: Data type for the parameters.
276
+ quant_config: Quantization configuration.
277
+ output_sizes: list of output sizes packed into one output, like for QKV
278
+ the list would be size 3.
279
+ prefix: The name of the layer in the state dict, including all parents
280
+ (e.g. model.layers.0.qkv_proj)
281
+ """
282
+
283
+ def __init__(self,
284
+ input_size: int,
285
+ output_size: int,
286
+ bias: bool = True,
287
+ gather_output: bool = False,
288
+ skip_bias_add: bool = False,
289
+ params_dtype: Optional[torch.dtype] = None,
290
+ quant_config: Optional[QuantizationConfig] = None,
291
+ output_sizes: Optional[list[int]] = None,
292
+ prefix: str = ""):
293
+ super().__init__(input_size, output_size, skip_bias_add, params_dtype,
294
+ quant_config, prefix)
295
+
296
+ self.gather_output = gather_output
297
+
298
+ # Divide the weight matrix along the last dimension.
299
+ tp_size = get_tensor_model_parallel_world_size()
300
+ assert self.quant_method is not None
301
+ self.output_size_per_partition = divide(self.output_size, tp_size)
302
+ self.output_partition_sizes = [self.output_size_per_partition]
303
+ # If QKV or MergedColumn, use output size of each partition.
304
+ if hasattr(self, "output_sizes"):
305
+ self.output_partition_sizes = [
306
+ divide(output_size, tp_size)
307
+ for output_size in self.output_sizes
308
+ ]
309
+
310
+ if output_sizes is None:
311
+ output_sizes = [output_size]
312
+
313
+ self.quant_method.create_weights(
314
+ layer=self,
315
+ input_size_per_partition=self.input_size,
316
+ output_partition_sizes=self.output_partition_sizes,
317
+ input_size=self.input_size,
318
+ output_size=self.output_size,
319
+ params_dtype=self.params_dtype,
320
+ weight_loader=(
321
+ self.weight_loader_v2 if self.quant_method.__class__.__name__
322
+ in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
323
+ if bias:
324
+ self.bias = Parameter(
325
+ torch.empty(self.output_size_per_partition,
326
+ dtype=params_dtype))
327
+ set_weight_attrs(self.bias, {
328
+ "output_dim": 0,
329
+ "weight_loader": self.weight_loader,
330
+ })
331
+ else:
332
+ self.register_parameter("bias", None)
333
+
334
+ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
335
+ tp_rank = get_tensor_model_parallel_rank()
336
+ output_dim = getattr(param, "output_dim", None)
337
+
338
+ # Special case for GGUF
339
+ is_gguf_weight = getattr(param, "is_gguf_weight", False)
340
+ is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
341
+ if is_gguf_weight_type:
342
+ param.weight_type = loaded_weight.item()
343
+
344
+ # Materialize GGUF UninitializedParameter
345
+ if is_gguf_weight and isinstance(param, UninitializedParameter):
346
+ param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)
347
+
348
+ use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
349
+ is_sharded_weight = getattr(param, "is_sharded_weight", False)
350
+ # bitsandbytes loads the weights of the specific portion
351
+ # no need to narrow
352
+ is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit
353
+
354
+ param_data = param.data
355
+ if output_dim is not None and not is_sharded_weight:
356
+ shard_size = param_data.shape[output_dim]
357
+ start_idx = tp_rank * shard_size
358
+ loaded_weight = loaded_weight.narrow(output_dim, start_idx,
359
+ shard_size)
360
+
361
+ # Special case for loading scales off disk, which often do not
362
+ # have a shape (such as in the case of AutoFP8).
363
+ if len(loaded_weight.shape) == 0:
364
+ loaded_weight = loaded_weight.reshape(1)
365
+
366
+ assert param_data.shape == loaded_weight.shape
367
+ param_data.copy_(loaded_weight)
368
+
369
+ def weight_loader_v2(self, param: Parameter, loaded_weight: torch.Tensor):
370
+ # Special case for loading scales off disk, which often do not
371
+ # have a shape (such as in the case of AutoFP8).
372
+ if len(loaded_weight.shape) == 0:
373
+ assert loaded_weight.numel() == 1
374
+ loaded_weight = loaded_weight.reshape(1)
375
+ param.load_column_parallel_weight(loaded_weight=loaded_weight)
376
+
377
+ def forward(self, input_) -> tuple[torch.Tensor, Optional[Parameter]]:
378
+ bias = self.bias if not self.skip_bias_add else None
379
+
380
+ # Matrix multiply.
381
+ assert self.quant_method is not None
382
+ output_parallel = self.quant_method.apply(self, input_, bias)
383
+ if self.gather_output:
384
+ # All-gather across the partitions.
385
+ output = tensor_model_parallel_all_gather(output_parallel)
386
+ else:
387
+ output = output_parallel
388
+ output_bias = self.bias if self.skip_bias_add else None
389
+ return output, output_bias
390
+
391
+ def extra_repr(self) -> str:
392
+ s = f"in_features={self.input_size}"
393
+ s += f", output_features={self.output_size_per_partition}"
394
+ s += f", bias={self.bias is not None}"
395
+ s += f", tp_size={get_tensor_model_parallel_world_size()}"
396
+ s += f", gather_output={self.gather_output}"
397
+ return s
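The column-parallel identity described in the ColumnParallelLinear docstring (Y = XA with A split along its output dimension) can be checked directly with plain torch, no distributed setup required; the shapes below are arbitrary:

import torch

x = torch.randn(4, 8)                    # activations
a = torch.randn(8, 6)                    # full weight, output dim = 6
a_shards = a.chunk(2, dim=1)             # tp_size = 2, each rank holds A_i
y_parallel = torch.cat([x @ a_i for a_i in a_shards], dim=1)
assert torch.allclose(y_parallel, x @ a, atol=1e-5)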
398
+
399
+
400
+ class MergedColumnParallelLinear(ColumnParallelLinear):
401
+ """Packed linear layers with column parallelism.
402
+
403
+ Similar to ColumnParallelLinear, but the weight matrix is concatenated
404
+ along the output dimension. When the weight matrix is loaded, the
405
+ different partitions are sharded separately.
406
+
407
+ Args:
408
+ input_size: input dimension of the linear layer.
409
+ output_sizes: list of output dimensions of the linear layer.
410
+ bias: If true, add bias.
411
+ gather_output: If true, call all-gather on output and make the output
412
+ available to all GPUs, otherwise, every GPU will have
413
+ its own output.
414
+ skip_bias_add: This was added to enable performance optimizations where
415
+ bias can be fused with other element-wise operations. We
416
+ skip adding bias but instead return it.
417
+ params_dtype: Data type for the parameters.
418
+ quant_config: Quantization configuration.
419
+ prefix: The name of the layer in the state dict, including all parents
420
+ (e.g. model.layers.0.qkv_proj)
421
+ """
422
+
423
+ def __init__(self,
424
+ input_size: int,
425
+ output_sizes: list[int],
426
+ bias: bool = True,
427
+ gather_output: bool = False,
428
+ skip_bias_add: bool = False,
429
+ params_dtype: Optional[torch.dtype] = None,
430
+ quant_config: Optional[QuantizationConfig] = None,
431
+ prefix: str = ""):
432
+ self.output_sizes = output_sizes
433
+ tp_size = get_tensor_model_parallel_world_size()
434
+ assert all(output_size % tp_size == 0 for output_size in output_sizes)
435
+ super().__init__(input_size=input_size,
436
+ output_size=sum(output_sizes),
437
+ bias=bias,
438
+ gather_output=gather_output,
439
+ skip_bias_add=skip_bias_add,
440
+ params_dtype=params_dtype,
441
+ quant_config=quant_config,
442
+ prefix=prefix)
443
+
444
+ def weight_loader(self,
445
+ param: Parameter,
446
+ loaded_weight: torch.Tensor,
447
+ loaded_shard_id: Optional[int] = None):
448
+
449
+ # Special case for GGUF
450
+ # initialize GGUF param after we know the quantize type
451
+ is_gguf_weight = getattr(param, "is_gguf_weight", False)
452
+ is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
453
+ if is_gguf_weight_type:
454
+ if loaded_shard_id is not None:
455
+ param.data[loaded_shard_id].copy_(loaded_weight)
456
+ param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
457
+ else:
458
+ param.shard_weight_type = {
459
+ i: loaded_weight.item()
460
+ for i, _ in enumerate(self.output_sizes)
461
+ }
462
+ return
463
+
464
+ if is_gguf_weight:
465
+ tp_size = get_tensor_model_parallel_world_size()
466
+ tp_rank = get_tensor_model_parallel_rank()
467
+
468
+ output_dim = getattr(param, "output_dim", None)
469
+ shard_size = loaded_weight.size(output_dim) // tp_size
470
+ start_idx = tp_rank * shard_size
471
+
472
+ if loaded_shard_id is not None:
473
+ loaded_weight = loaded_weight.narrow(output_dim, start_idx,
474
+ shard_size)
475
+ param.shard_id.append(loaded_shard_id)
476
+ param.shard_id_map[loaded_shard_id] = len(param.data_container)
477
+ param.data_container.append(loaded_weight)
478
+ if len(param.data_container) == 2:
479
+ self.qweight = param.materialize_nested()
480
+ return
481
+
482
+ param_data = param.data
483
+ output_dim = getattr(param, "output_dim", None)
484
+ # Special case for AQLM codebooks.
485
+ is_metadata = getattr(param, "is_metadata", False)
486
+ # Special case for per-tensor scale to load scalar into fused array.
487
+ needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)
488
+
489
+ if loaded_shard_id is None:
490
+ # Loaded weight is already fused on disk (mlp).
491
+ # (e.g., Phi-3's gate_up_proj).
492
+ if output_dim is None:
493
+ if needs_scalar_to_array:
494
+ param_data, loaded_weight = adjust_scalar_to_fused_array(
495
+ param_data, loaded_weight, 0)
496
+
497
+ assert param_data.shape == loaded_weight.shape
498
+ param_data.copy_(loaded_weight)
499
+ return
500
+ current_shard_offset = 0
501
+ use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
502
+ False)
503
+ shard_offsets: list[tuple[int, int, int]] = []
504
+ for i, output_size in enumerate(self.output_sizes):
505
+ shard_offsets.append((i, current_shard_offset, output_size))
506
+ current_shard_offset += output_size
507
+ packed_dim = getattr(param, "packed_dim", None)
508
+ for shard_id, shard_offset, shard_size in shard_offsets:
509
+ # Special case for Quantization.
510
+ # If quantized, we need to adjust the offset and size to account
511
+ # for the packing.
512
+ if packed_dim == output_dim:
513
+ shard_size = shard_size // param.pack_factor
514
+ shard_offset = shard_offset // param.pack_factor
515
+ # Special case for Marlin.
516
+ shard_size, shard_offset = adjust_marlin_shard(
517
+ param, shard_size, shard_offset)
518
+
519
+ if use_bitsandbytes_4bit:
520
+ index = list(itertools.accumulate([0] + self.output_sizes))
521
+ orig_offsets = {
522
+ str(i): (index[i], size)
523
+ for i, size in enumerate(self.output_sizes)
524
+ }
525
+ orig_offsets["total"] = (self.output_size, 0)
526
+ shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
527
+ param, orig_offsets, str(shard_id))
528
+
529
+ loaded_weight_shard = loaded_weight.narrow(
530
+ output_dim, shard_offset, shard_size)
531
+ self.weight_loader(param, loaded_weight_shard, shard_id)
532
+ return
533
+
534
+ assert loaded_shard_id < len(self.output_sizes)
535
+ tp_rank = get_tensor_model_parallel_rank()
536
+ tp_size = get_tensor_model_parallel_world_size()
537
+ if output_dim is not None:
538
+ shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
539
+ shard_size = self.output_sizes[loaded_shard_id] // tp_size
540
+ # Special case for quantization.
541
+ # If quantized, we need to adjust the offset and size to account
542
+ # for the packing.
543
+ packed_dim = getattr(param, "packed_dim", None)
544
+ if packed_dim == output_dim:
545
+ shard_size = shard_size // param.pack_factor
546
+ shard_offset = shard_offset // param.pack_factor
547
+ # Special case for Marlin.
548
+ shard_size, shard_offset = adjust_marlin_shard(
549
+ param, shard_size, shard_offset)
550
+
551
+ use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
552
+ False)
553
+ is_sharded_weight = getattr(param, "is_sharded_weight", False)
554
+ # bitsandbytes loads the weights of the specific portion
555
+ # no need to narrow
556
+ is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit
557
+
558
+ if use_bitsandbytes_4bit:
559
+ shard_size = loaded_weight.shape[output_dim]
560
+ shard_offset = loaded_weight.shape[output_dim] * \
561
+ loaded_shard_id
562
+
563
+ param_data = param_data.narrow(output_dim, shard_offset,
564
+ shard_size)
565
+ start_idx = tp_rank * shard_size
566
+ if not is_sharded_weight:
567
+ loaded_weight = loaded_weight.narrow(output_dim, start_idx,
568
+ shard_size)
569
+ # Special case for AQLM codebooks.
570
+ elif is_metadata:
571
+ # metadata indicates fixed size concatenated along dim 0
572
+ shard_size = loaded_weight.shape[0]
573
+ shard_offset = loaded_shard_id * shard_size
574
+ param_data = param_data.narrow(0, shard_offset, shard_size)
575
+
576
+ # Special case for per-tensor scales in fused case.
577
+ elif needs_scalar_to_array:
578
+ param_data, loaded_weight = adjust_scalar_to_fused_array(
579
+ param_data, loaded_weight, loaded_shard_id)
580
+
581
+ else:
582
+ ignore_warning = getattr(param, "ignore_warning", False)
583
+ if not ignore_warning:
584
+ logger.warning(
585
+ "Loading a weight without `output_dim` attribute in "
586
+ "MergedColumnParallelLinear, assume the weight is "
587
+ "the same for all partitions.")
588
+
589
+ assert param_data.shape == loaded_weight.shape
590
+ param_data.copy_(loaded_weight)
591
+
592
+ def _load_fused_module_from_checkpoint(self, param: BasevLLMParameter,
593
+ loaded_weight: torch.Tensor):
594
+ """
595
+ Handle special case for models where MLP layers are already
596
+ fused on disk. In this case, we have no shard id. This function
597
+ determines the shard id by splitting these layers and then calls
598
+ the weight loader using the shard id.
599
+
600
+ An example of a model with these fused layers:
601
+ https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
602
+ """
603
+
604
+ current_shard_offset = 0
605
+ shard_offsets: list[tuple[int, int, int]] = []
606
+ for i, output_size in enumerate(self.output_sizes):
607
+ shard_offsets.append((i, current_shard_offset, output_size))
608
+ current_shard_offset += output_size
609
+
610
+ for shard_id, shard_offset, shard_size in shard_offsets:
611
+ # Special case for Quantization.
612
+ # If quantized, we need to adjust the offset and size to account
613
+ # for the packing.
614
+ if isinstance(param, (PackedColumnParameter, PackedvLLMParameter
615
+ )) and param.packed_dim == param.output_dim:
616
+ shard_size, shard_offset = \
617
+ param.adjust_shard_indexes_for_packing(
618
+ shard_size=shard_size, shard_offset=shard_offset)
619
+
620
+ loaded_weight_shard = loaded_weight.narrow(param.output_dim,
621
+ shard_offset,
622
+ shard_size)
623
+ self.weight_loader_v2(param, loaded_weight_shard, shard_id)
624
+
625
+ def weight_loader_v2(self,
626
+ param: BasevLLMParameter,
627
+ loaded_weight: torch.Tensor,
628
+ loaded_shard_id: Optional[int] = None):
629
+ if loaded_shard_id is None:
630
+ if isinstance(param, PerTensorScaleParameter):
631
+ param.load_merged_column_weight(loaded_weight=loaded_weight,
632
+ shard_id=0)
633
+ return
634
+ elif type(param) in (RowvLLMParameter, BasevLLMParameter):
635
+ param.load_merged_column_weight(loaded_weight=loaded_weight)
636
+ return
637
+ # TODO: @dsikka - move to parameter.py
638
+ self._load_fused_module_from_checkpoint(param, loaded_weight)
639
+ return
640
+
641
+ assert loaded_shard_id < len(self.output_sizes)
642
+
643
+ tp_size = get_tensor_model_parallel_world_size()
644
+
645
+ if isinstance(param, BlockQuantScaleParameter):
646
+ from vllm.model_executor.layers.quantization.fp8 import (
647
+ Fp8LinearMethod, Fp8MoEMethod)
648
+ assert self.quant_method is not None
649
+ assert isinstance(self.quant_method,
650
+ (Fp8LinearMethod, Fp8MoEMethod))
651
+ weight_block_size = self.quant_method.quant_config.weight_block_size
652
+ assert weight_block_size is not None
653
+ block_n, _ = weight_block_size[0], weight_block_size[1]
654
+ shard_offset = (
655
+ (sum(self.output_sizes[:loaded_shard_id]) + block_n - 1) //
656
+ block_n) // tp_size
657
+ shard_size = ((self.output_sizes[loaded_shard_id] + block_n - 1) //
658
+ block_n // tp_size)
659
+ else:
660
+ shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
661
+ shard_size = self.output_sizes[loaded_shard_id] // tp_size
662
+
663
+ param.load_merged_column_weight(loaded_weight=loaded_weight,
664
+ shard_id=loaded_shard_id,
665
+ shard_offset=shard_offset,
666
+ shard_size=shard_size)
667
+
668
+
669
+ class QKVParallelLinear(ColumnParallelLinear):
670
+ """Linear layers for the attention's QKV transformation.
671
+
672
+ Linear layers for the linear transformation of the query, key, and value
673
+ vectors in the attention layer. The weight matrix is concatenated along
674
+ the output dimension. The layer is parallelized along the head dimension.
675
+ When the number of key/value heads is smaller than the number of query
676
+ heads (e.g., multi-query/grouped-query attention), the key/value head may
677
+ be replicated while the query heads are partitioned.
678
+
679
+ Args:
680
+ hidden_size: input hidden state size of the transformer.
681
+ head_size: size of each attention head.
682
+ total_num_heads: total number of attention query heads.
683
+ total_num_kv_heads: total number of attention key/value heads. If
684
+ None, assume total_num_kv_heads = total_num_heads.
685
+ bias: If true, add bias.
686
+ skip_bias_add: This was added to enable performance optimizations where
687
+ bias can be fused with other element-wise operations. We
688
+ skip adding bias but instead return it.
689
+ params_dtype: Data type for the parameters.
690
+ quant_config: Quantization configuration.
691
+ prefix: The name of the layer in the state dict, including all parents
692
+ (e.g. model.layers.0.qkv_proj)
693
+ """
694
+
695
+ def __init__(self,
696
+ hidden_size: int,
697
+ head_size: int,
698
+ total_num_heads: int,
699
+ total_num_kv_heads: Optional[int] = None,
700
+ bias: bool = True,
701
+ skip_bias_add: bool = False,
702
+ params_dtype: Optional[torch.dtype] = None,
703
+ quant_config: Optional[QuantizationConfig] = None,
704
+ prefix: str = ""):
705
+ self.hidden_size = hidden_size
706
+ self.head_size = head_size
707
+ self.total_num_heads = total_num_heads
708
+ if total_num_kv_heads is None:
709
+ total_num_kv_heads = total_num_heads
710
+ self.total_num_kv_heads = total_num_kv_heads
711
+ # Divide the weight matrix along the last dimension.
712
+ tp_size = get_tensor_model_parallel_world_size()
713
+ self.num_heads = divide(self.total_num_heads, tp_size)
714
+ if tp_size >= self.total_num_kv_heads:
715
+ self.num_kv_heads = 1
716
+ self.num_kv_head_replicas = divide(tp_size,
717
+ self.total_num_kv_heads)
718
+ else:
719
+ self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
720
+ self.num_kv_head_replicas = 1
721
+ input_size = self.hidden_size
722
+ output_size = (self.num_heads +
723
+ 2 * self.num_kv_heads) * tp_size * self.head_size
724
+ self.output_sizes = [
725
+ self.num_heads * self.head_size * tp_size, # q_proj
726
+ self.num_kv_heads * self.head_size * tp_size, # k_proj
727
+ self.num_kv_heads * self.head_size * tp_size, # v_proj
728
+ ]
729
+
730
+ super().__init__(input_size=input_size,
731
+ output_size=output_size,
732
+ bias=bias,
733
+ gather_output=False,
734
+ skip_bias_add=skip_bias_add,
735
+ params_dtype=params_dtype,
736
+ quant_config=quant_config,
737
+ prefix=prefix)
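Worked numbers for the output partitioning computed in __init__ above, assuming a grouped-query model with 32 query heads, 8 KV heads, head_size 128 and tp_size 4 (the values are illustrative only):

head_size, total_num_heads, total_num_kv_heads, tp_size = 128, 32, 8, 4
num_heads = total_num_heads // tp_size          # 8 query heads per rank
# tp_size < total_num_kv_heads, so KV heads are partitioned, not replicated.
num_kv_heads = total_num_kv_heads // tp_size    # 2 KV heads per rank
output_sizes = [
    num_heads * head_size * tp_size,       # q_proj -> 4096
    num_kv_heads * head_size * tp_size,    # k_proj -> 1024
    num_kv_heads * head_size * tp_size,    # v_proj -> 1024
]
assert sum(output_sizes) == (num_heads + 2 * num_kv_heads) * tp_size * head_size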
738
+
739
+ def _get_shard_offset_mapping(self, loaded_shard_id: str):
740
+ shard_offset_mapping = {
741
+ "q": 0,
742
+ "k": self.num_heads * self.head_size,
743
+ "v": (self.num_heads + self.num_kv_heads) * self.head_size,
744
+ "total": (self.num_heads + 2 * self.num_kv_heads) * self.head_size
745
+ }
746
+ return shard_offset_mapping.get(loaded_shard_id)
747
+
748
+ def _get_shard_size_mapping(self, loaded_shard_id: str):
749
+ shard_size_mapping = {
750
+ "q": self.num_heads * self.head_size,
751
+ "k": self.num_kv_heads * self.head_size,
752
+ "v": self.num_kv_heads * self.head_size,
753
+ }
754
+ return shard_size_mapping.get(loaded_shard_id)
755
+
756
+ def _load_fused_module_from_checkpoint(self, param: BasevLLMParameter,
757
+ loaded_weight: torch.Tensor):
758
+ """
759
+ Handle special case for models where QKV layers are already
760
+ fused on disk. In this case, we have no shard id. This function
761
+ determines the shard id by splitting these layers and then calls
762
+ the weight loader using the shard id.
763
+
764
+ An example of a model with these fused layers:
765
+ https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
766
+ """
767
+ shard_offsets = [
768
+ # (shard_id, shard_offset, shard_size)
769
+ ("q", 0, self.total_num_heads * self.head_size),
770
+ ("k", self.total_num_heads * self.head_size,
771
+ self.total_num_kv_heads * self.head_size),
772
+ ("v",
773
+ (self.total_num_heads + self.total_num_kv_heads) * self.head_size,
774
+ self.total_num_kv_heads * self.head_size),
775
+ ]
776
+
777
+ for shard_id, shard_offset, shard_size in shard_offsets:
778
+ # Special case for Quantization.
779
+ # If quantized, we need to adjust the offset and size to account
780
+ # for the packing.
781
+ if isinstance(param, (PackedColumnParameter, PackedvLLMParameter
782
+ )) and param.packed_dim == param.output_dim:
783
+ shard_size, shard_offset = \
784
+ param.adjust_shard_indexes_for_packing(
785
+ shard_size=shard_size, shard_offset=shard_offset)
786
+
787
+ loaded_weight_shard = loaded_weight.narrow(param.output_dim,
788
+ shard_offset,
789
+ shard_size)
790
+ self.weight_loader_v2(param, loaded_weight_shard, shard_id)
791
+
792
+ def weight_loader_v2(self,
793
+ param: BasevLLMParameter,
794
+ loaded_weight: torch.Tensor,
795
+ loaded_shard_id: Optional[str] = None):
796
+ if loaded_shard_id is None: # special case for certain models
797
+ if isinstance(param, PerTensorScaleParameter):
798
+ param.load_qkv_weight(loaded_weight=loaded_weight, shard_id=0)
799
+ return
800
+ elif type(param) in (RowvLLMParameter, BasevLLMParameter):
801
+ param.load_qkv_weight(loaded_weight=loaded_weight)
802
+ return
803
+ # TODO: @dsikka - move to parameter.py
804
+ self._load_fused_module_from_checkpoint(param, loaded_weight)
805
+ return
806
+
807
+ assert loaded_shard_id in ["q", "k", "v"]
808
+
809
+ shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
810
+ shard_size = self._get_shard_size_mapping(loaded_shard_id)
811
+
812
+ param.load_qkv_weight(loaded_weight=loaded_weight,
813
+ num_heads=self.num_kv_head_replicas,
814
+ shard_id=loaded_shard_id,
815
+ shard_offset=shard_offset,
816
+ shard_size=shard_size)
817
+
818
+ def weight_loader(self,
819
+ param: Parameter,
820
+ loaded_weight: torch.Tensor,
821
+ loaded_shard_id: Optional[str] = None):
822
+
823
+ # Special case for GGUF
824
+ # initialize GGUF param after we know the quantize type
825
+ is_gguf_weight = getattr(param, "is_gguf_weight", False)
826
+ is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
827
+ if is_gguf_weight_type:
828
+ idx_map = {"q": 0, "k": 1, "v": 2}
829
+ if loaded_shard_id is not None:
830
+ param.data[idx_map[loaded_shard_id]].copy_(loaded_weight)
831
+ param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
832
+ else:
833
+ param.shard_weight_type = {
834
+ k: loaded_weight.item()
835
+ for k in idx_map
836
+ }
837
+ return
838
+
839
+ if is_gguf_weight:
840
+ tp_size = get_tensor_model_parallel_world_size()
841
+ tp_rank = get_tensor_model_parallel_rank()
842
+
843
+ output_dim = getattr(param, "output_dim", None)
844
+ shard_size = loaded_weight.size(output_dim) // tp_size
845
+ start_idx = tp_rank * shard_size
846
+
847
+ if loaded_shard_id is not None:
848
+ loaded_weight = loaded_weight.narrow(output_dim, start_idx,
849
+ shard_size)
850
+ param.shard_id.append(loaded_shard_id)
851
+ param.shard_id_map[loaded_shard_id] = len(param.data_container)
852
+ param.data_container.append(loaded_weight)
853
+ if len(param.data_container) == 3:
854
+ self.qweight = param.materialize_nested()
855
+ return
856
+
857
+ param_data = param.data
858
+ output_dim = getattr(param, "output_dim", None)
859
+ # Special case for AQLM codebooks.
860
+ is_metadata = getattr(param, "is_metadata", False)
861
+
862
+ # Special case for per-tensor scales in fused case.
863
+ needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)
864
+
865
+ if loaded_shard_id is None:
866
+ # Loaded weight is already fused on disk (qkv).
867
+ # (e.g., Phi-3's qkv_proj).
868
+ if output_dim is None:
869
+ if needs_scalar_to_array:
870
+ param_data, loaded_weight = adjust_scalar_to_fused_array(
871
+ param_data, loaded_weight, 0)
872
+
873
+ assert param_data.shape == loaded_weight.shape
874
+ param_data.copy_(loaded_weight)
875
+ return
876
+ shard_offsets = [
877
+ # (shard_id, shard_offset, shard_size)
878
+ ("q", 0, self.total_num_heads * self.head_size),
879
+ ("k", self.total_num_heads * self.head_size,
880
+ self.total_num_kv_heads * self.head_size),
881
+ ("v", (self.total_num_heads + self.total_num_kv_heads) *
882
+ self.head_size, self.total_num_kv_heads * self.head_size),
883
+ ]
884
+ use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
885
+ False)
886
+
887
+ packed_dim = getattr(param, "packed_dim", None)
888
+ for shard_id, shard_offset, shard_size in shard_offsets:
889
+ # Special case for Quantized Weights.
890
+ # If quantized, we need to adjust the offset and size to account
891
+ # for the packing.
892
+ if packed_dim == output_dim:
893
+ shard_size = shard_size // param.pack_factor
894
+ shard_offset = shard_offset // param.pack_factor
895
+
896
+ # Special case for Marlin.
897
+ shard_size, shard_offset = adjust_marlin_shard(
898
+ param, shard_size, shard_offset)
899
+
900
+ if use_bitsandbytes_4bit:
901
+ orig_qkv_offsets = {
902
+ "q": (0, self.total_num_heads * self.head_size),
903
+ "k": (self.total_num_heads * self.head_size,
904
+ self.total_num_kv_heads * self.head_size),
905
+ "v":
906
+ ((self.total_num_heads + self.total_num_kv_heads) *
907
+ self.head_size,
908
+ self.total_num_kv_heads * self.head_size),
909
+ "total":
910
+ ((self.total_num_heads + 2 * self.total_num_kv_heads) *
911
+ self.head_size, 0)
912
+ }
913
+
914
+ shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
915
+ param, orig_qkv_offsets, shard_id)
916
+
917
+ loaded_weight_shard = loaded_weight.narrow(
918
+ output_dim, shard_offset, shard_size)
919
+ self.weight_loader(param, loaded_weight_shard, shard_id)
920
+ return
921
+
922
+ tp_rank = get_tensor_model_parallel_rank()
923
+ assert loaded_shard_id in ["q", "k", "v"]
924
+
925
+ # If output dim is defined, use the default loading process.
926
+ if output_dim is not None:
927
+ if loaded_shard_id == "q":
928
+ shard_offset = 0
929
+ shard_size = self.num_heads * self.head_size
930
+ elif loaded_shard_id == "k":
931
+ shard_offset = self.num_heads * self.head_size
932
+ shard_size = self.num_kv_heads * self.head_size
933
+ elif loaded_shard_id == "v":
934
+ shard_offset = (self.num_heads +
935
+ self.num_kv_heads) * self.head_size
936
+ shard_size = self.num_kv_heads * self.head_size
937
+ # Special case for Quantized Weights.
938
+ # If quantized, we need to adjust the offset and size to account
939
+ # for the packing.
940
+ packed_dim = getattr(param, "packed_dim", None)
941
+ if packed_dim == output_dim:
942
+ shard_size = shard_size // param.pack_factor
943
+ shard_offset = shard_offset // param.pack_factor
944
+
945
+ # Special case for Marlin.
946
+ shard_size, shard_offset = adjust_marlin_shard(
947
+ param, shard_size, shard_offset)
948
+
949
+ use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
950
+ False)
951
+ is_sharded_weight = getattr(param, "is_sharded_weight", False)
952
+ # bitsandbytes loads the weights of the specific portion
953
+ # no need to narrow
954
+ is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit
955
+
956
+ if use_bitsandbytes_4bit:
957
+ orig_qkv_offsets = {
958
+ "q": (0, self.num_heads * self.head_size),
959
+ "k": (self.num_heads * self.head_size,
960
+ self.num_kv_heads * self.head_size),
961
+ "v":
962
+ ((self.num_heads + self.num_kv_heads) * self.head_size,
963
+ self.num_kv_heads * self.head_size),
964
+ "total":
965
+ ((self.num_heads + 2 * self.num_kv_heads) * self.head_size,
966
+ 0)
967
+ }
968
+ shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
969
+ param, orig_qkv_offsets, loaded_shard_id)
970
+
971
+ param_data = param_data.narrow(output_dim, shard_offset,
972
+ shard_size)
973
+ if loaded_shard_id == "q":
974
+ shard_id = tp_rank
975
+ else:
976
+ shard_id = tp_rank // self.num_kv_head_replicas
977
+ start_idx = shard_id * shard_size
978
+
979
+ if not is_sharded_weight:
980
+ loaded_weight = loaded_weight.narrow(output_dim, start_idx,
981
+ shard_size)
982
+
983
+ # Special case for AQLM codebooks.
984
+ elif is_metadata:
985
+ # metadata indicates fixed size concatenated along dim 0
986
+ shard_size = loaded_weight.shape[0]
987
+ shard_index = ["q", "k", "v"].index(loaded_shard_id)
988
+ param_data = param_data.narrow(0, shard_index * shard_size,
989
+ shard_size)
990
+ # Special case for per-tensor scales in fused case.
991
+ elif needs_scalar_to_array:
992
+ param_data, loaded_weight = adjust_scalar_to_fused_array(
993
+ param_data, loaded_weight, loaded_shard_id)
994
+ else:
995
+ ignore_warning = getattr(param, "ignore_warning", False)
996
+ if not ignore_warning:
997
+ logger.warning(
998
+ "Loading a weight without `output_dim` attribute in "
999
+ "QKVParallelLinear, assume the weight is the same "
1000
+ "for all partitions.")
1001
+
1002
+ assert param_data.shape == loaded_weight.shape
1003
+ param_data.copy_(loaded_weight)
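For reference, a minimal sketch of the per-rank q/k/v shard bookkeeping that the loaders above rely on (toy numbers, plain Python, not part of this diff; real values come from the model config):

# Hypothetical per-rank head counts for illustration only.
head_size = 128
num_heads = 8        # query heads on this rank
num_kv_heads = 2     # kv heads on this rank

shard_offsets = {
    # shard_id: (offset, size) into the fused qkv output dimension
    "q": (0, num_heads * head_size),
    "k": (num_heads * head_size, num_kv_heads * head_size),
    "v": ((num_heads + num_kv_heads) * head_size, num_kv_heads * head_size),
}
total = (num_heads + 2 * num_kv_heads) * head_size  # 1536 for these numbers
assert sum(size for _, size in shard_offsets.values()) == total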
1004
+
1005
+
1006
+ class RowParallelLinear(LinearBase):
1007
+ """Linear layer with row parallelism.
1008
+
1009
+ The linear layer is defined as Y = XA + b. A is parallelized along
1010
+ its first dimension and X along its second dimension as:
1011
+ - -
1012
+ | A_1 |
1013
+ | . |
1014
+ A = | . | X = [X_1, ..., X_p]
1015
+ | . |
1016
+ | A_p |
1017
+ - -
1018
+ Arguments:
1019
+ input_size: first dimension of matrix A.
1020
+ output_size: second dimension of matrix A.
1021
+ bias: If true, add bias. Note that bias is not parallelized.
1022
+ input_is_parallel: If true, we assume that the input is already
1023
+ split across the GPUs and we do not split
1024
+ again.
1025
+ skip_bias_add: This was added to enable performance optimization where
1026
+ bias can be fused with other element-wise operations.
1027
+ We skip adding bias but instead return it.
1028
+ params_dtype: Data type for the parameters.
1029
+ quant_config: Quantization configuration.
1030
+ """
1031
+
1032
+ def __init__(self,
1033
+ input_size: int,
1034
+ output_size: int,
1035
+ bias: bool = True,
1036
+ input_is_parallel: bool = True,
1037
+ skip_bias_add: bool = False,
1038
+ params_dtype: Optional[torch.dtype] = None,
1039
+ reduce_results: bool = True,
1040
+ quant_config: Optional[QuantizationConfig] = None,
1041
+ prefix: str = ""):
1042
+ super().__init__(input_size, output_size, skip_bias_add, params_dtype,
1043
+ quant_config, prefix)
1044
+
1045
+ self.input_is_parallel = input_is_parallel
1046
+ self.reduce_results = reduce_results
1047
+
1048
+ # Divide the weight matrix along the last dimension.
1049
+ self.tp_rank = get_tensor_model_parallel_rank()
1050
+ self.tp_size = get_tensor_model_parallel_world_size()
1051
+ self.input_size_per_partition = divide(input_size, self.tp_size)
1052
+ assert self.quant_method is not None
1053
+
1054
+ self.quant_method.create_weights(
1055
+ layer=self,
1056
+ input_size_per_partition=self.input_size_per_partition,
1057
+ output_partition_sizes=[self.output_size],
1058
+ input_size=self.input_size,
1059
+ output_size=self.output_size,
1060
+ params_dtype=self.params_dtype,
1061
+ weight_loader=(
1062
+ self.weight_loader_v2 if self.quant_method.__class__.__name__
1063
+ in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
1064
+ if not reduce_results and (bias and not skip_bias_add):
1065
+ raise ValueError("When reduce_results=False, adding bias to the "
1066
+ "results can lead to incorrect results")
1067
+
1068
+ if bias:
1069
+ self.bias = Parameter(
1070
+ torch.empty(self.output_size, dtype=params_dtype))
1071
+ set_weight_attrs(self.bias, {
1072
+ "output_dim": 0,
1073
+ "weight_loader": self.weight_loader,
1074
+ })
1075
+ else:
1076
+ self.register_parameter("bias", None)
1077
+
1078
+ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
1079
+ tp_rank = get_tensor_model_parallel_rank()
1080
+ tp_size = get_tensor_model_parallel_world_size()
1081
+ input_dim = getattr(param, "input_dim", None)
1082
+ use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
1083
+ is_sharded_weight = getattr(param, "is_sharded_weight", False)
1084
+ # bitsandbytes loads the weights of the specific portion
1085
+ # no need to narrow
1086
+ is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit
1087
+
1088
+ # Special case for GGUF
1089
+ is_gguf_weight = getattr(param, "is_gguf_weight", False)
1090
+ is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
1091
+ if is_gguf_weight_type:
1092
+ param.weight_type = loaded_weight.item()
1093
+
1094
+ # Materialize GGUF UninitializedParameter
1095
+ if is_gguf_weight and isinstance(param, UninitializedParameter):
1096
+ weight_shape = list(loaded_weight.shape)
1097
+ if input_dim:
1098
+ weight_shape[input_dim] = weight_shape[input_dim] // tp_size
1099
+ param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)
1100
+
1101
+ param_data = param.data
1102
+ if input_dim is not None and not is_sharded_weight:
1103
+ shard_size = param_data.shape[input_dim]
1104
+ start_idx = tp_rank * shard_size
1105
+ loaded_weight = loaded_weight.narrow(input_dim, start_idx,
1106
+ shard_size)
1107
+
1108
+ # Special case for loading scales off disk, which often do not
1109
+ # have a shape (such as in the case of AutoFP8).
1110
+ if len(loaded_weight.shape) == 0:
1111
+ loaded_weight = loaded_weight.reshape(1)
1112
+
1113
+ assert param_data.shape == loaded_weight.shape
1114
+ param_data.copy_(loaded_weight)
1115
+
1116
+ def weight_loader_v2(self, param: BasevLLMParameter,
1117
+ loaded_weight: torch.Tensor):
1118
+
1119
+ # Special case for loading scales off disk, which often do not
1120
+ # have a shape (such as in the case of AutoFP8).
1121
+ if len(loaded_weight.shape) == 0:
1122
+ assert loaded_weight.numel() == 1
1123
+ loaded_weight = loaded_weight.reshape(1)
1124
+
1125
+ param.load_row_parallel_weight(loaded_weight=loaded_weight)
1126
+
1127
+ def forward(self, input_) -> tuple[torch.Tensor, Optional[Parameter]]:
1128
+ if self.input_is_parallel:
1129
+ input_parallel = input_
1130
+ else:
1131
+ tp_rank = get_tensor_model_parallel_rank()
1132
+ splitted_input = split_tensor_along_last_dim(
1133
+ input_, num_partitions=self.tp_size)
1134
+ input_parallel = splitted_input[tp_rank].contiguous()
1135
+
1136
+ # Matrix multiply.
1137
+ assert self.quant_method is not None
1138
+ # Only fuse bias add into GEMM for rank 0 (this ensures that
1139
+ # bias will not get added more than once in TP>1 case)
1140
+ bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
1141
+ output_parallel = self.quant_method.apply(self,
1142
+ input_parallel,
1143
+ bias=bias_)
1144
+ if self.reduce_results and self.tp_size > 1:
1145
+ output = tensor_model_parallel_all_reduce(output_parallel)
1146
+ else:
1147
+ output = output_parallel
1148
+
1149
+ output_bias = self.bias if self.skip_bias_add else None
1150
+
1151
+ return output, output_bias
1152
+
1153
+ def extra_repr(self) -> str:
1154
+ s = f"input_features={self.input_size_per_partition}"
1155
+ s += f", output_features={self.output_size}"
1156
+ s += f", bias={self.bias is not None}"
1157
+ s += f", tp_size={self.tp_size}"
1158
+ s += f", reduce_results={self.reduce_results}"
1159
+ return s
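A minimal single-process sketch (plain torch, no distributed setup, not part of this diff) of the row-parallel matmul that RowParallelLinear.forward performs: each rank multiplies its slice of X by its shard of A, and the all-reduce sums the partial products.

import torch

torch.manual_seed(0)
x = torch.randn(4, 8)        # [batch, input_size]
weight = torch.randn(6, 8)   # [output_size, input_size], nn.Linear layout

tp_size = 2
x_shards = x.chunk(tp_size, dim=-1)        # X = [X_1, ..., X_p]
w_shards = weight.chunk(tp_size, dim=-1)   # A split along its input dim

# Each "rank" computes a partial product; summing them plays the role of
# tensor_model_parallel_all_reduce.
partials = [xs @ ws.t() for xs, ws in zip(x_shards, w_shards)]
output = sum(partials)

assert torch.allclose(output, x @ weight.t(), atol=1e-5)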
.venv/lib/python3.11/site-packages/vllm/model_executor/layers/logits_processor.py ADDED
@@ -0,0 +1,193 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ """A layer that computes logits from hidden_states."""
3
+ import inspect
4
+ from concurrent.futures import ThreadPoolExecutor
5
+ from typing import Optional
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+ import vllm.envs as envs
11
+ from vllm.config import get_current_vllm_config
12
+ from vllm.distributed import (tensor_model_parallel_all_gather,
13
+ tensor_model_parallel_gather)
14
+ from vllm.model_executor.layers.vocab_parallel_embedding import (
15
+ VocabParallelEmbedding)
16
+ from vllm.model_executor.sampling_metadata import SamplingMetadata
17
+ from vllm.platforms import current_platform
18
+
19
+ _logits_processor_threadpool: Optional[ThreadPoolExecutor] = None
20
+ if envs.VLLM_LOGITS_PROCESSOR_THREADS is not None:
21
+ _logits_processor_threadpool = ThreadPoolExecutor(
22
+ envs.VLLM_LOGITS_PROCESSOR_THREADS)
23
+
24
+
25
+ class LogitsProcessor(nn.Module):
26
+ """Process logits and apply logits processors from sampling metadata.
27
+
28
+ This layer does the following:
29
+ 1. Gather logits from model hidden_states.
30
+ 2. Scale logits if needed.
31
+ 3. Apply logits processors (if any).
32
+ """
33
+
34
+ def __init__(self,
35
+ vocab_size: int,
36
+ org_vocab_size: Optional[int] = None,
37
+ scale: float = 1.0,
38
+ logits_as_input: bool = False,
39
+ soft_cap: Optional[float] = None) -> None:
40
+ """
41
+ Args:
42
+ scale: A scaling factor to apply to the logits.
43
+ """
44
+ super().__init__()
45
+ self.scale = scale
46
+ self.vocab_size = vocab_size
47
+ # Whether the input is logits (default is hidden states).
48
+ self.logits_as_input = logits_as_input
49
+ # original vocabulary size (without LoRA).
50
+ self.org_vocab_size = org_vocab_size or vocab_size
51
+ # Soft cap the logits. Used in Gemma 2.
52
+ self.soft_cap = soft_cap
53
+ # Whether to use gather or all-gather to gather the logits.
54
+
55
+ parallel_config = get_current_vllm_config().parallel_config
56
+ self.use_all_gather = current_platform.is_tpu() \
57
+ or envs.VLLM_USE_V1 \
58
+ or parallel_config.distributed_executor_backend == "external_launcher" # noqa
59
+
60
+ def forward(
61
+ self,
62
+ lm_head: VocabParallelEmbedding,
63
+ hidden_states: torch.Tensor,
64
+ sampling_metadata: Optional[SamplingMetadata] = None,
65
+ embedding_bias: Optional[torch.Tensor] = None,
66
+ ) -> Optional[torch.Tensor]:
67
+ if self.logits_as_input:
68
+ logits = hidden_states
69
+ else:
70
+ if sampling_metadata is not None:
71
+ hidden_states = _prune_hidden_states(hidden_states,
72
+ sampling_metadata)
73
+
74
+ # Get the logits for the next tokens.
75
+ logits = self._get_logits(hidden_states, lm_head, embedding_bias)
76
+ if logits is not None:
77
+ if self.soft_cap is not None:
78
+ logits = logits / self.soft_cap
79
+ logits = torch.tanh(logits)
80
+ logits = logits * self.soft_cap
81
+
82
+ if self.scale != 1.0:
83
+ logits *= self.scale
84
+
85
+ # Apply logits processors (if any).
86
+ if sampling_metadata is not None:
87
+ logits = _apply_logits_processors(logits, sampling_metadata)
88
+
89
+ return logits
90
+
91
+ def _get_logits(
92
+ self,
93
+ hidden_states: torch.Tensor,
94
+ lm_head: VocabParallelEmbedding,
95
+ embedding_bias: Optional[torch.Tensor],
96
+ ) -> Optional[torch.Tensor]:
97
+ # Get the logits for the next tokens.
98
+ logits = lm_head.linear_method.apply(lm_head,
99
+ hidden_states,
100
+ bias=embedding_bias)
101
+
102
+ if self.use_all_gather:
103
+ # Gather is not supported for some devices such as TPUs.
104
+ # Use all-gather instead.
105
+ # NOTE(woosuk): Here, the outputs of every device should not be None
106
+ # because XLA requires strict SPMD among all devices. Every device
107
+ # should execute the same operations after gathering the logits.
108
+ logits = tensor_model_parallel_all_gather(logits)
109
+ else:
110
+ # None may be returned for rank > 0
111
+ logits = tensor_model_parallel_gather(logits)
112
+ # Remove paddings in vocab (if any).
113
+ if logits is not None:
114
+ logits = logits[..., :self.org_vocab_size]
115
+ return logits
116
+
117
+ def extra_repr(self) -> str:
118
+ s = f"vocab_size={self.vocab_size}"
119
+ s += f", org_vocab_size={self.org_vocab_size}"
120
+ s += f", scale={self.scale}, logits_as_input={self.logits_as_input}"
121
+ return s
122
+
123
+
124
+ def _prune_hidden_states(
125
+ hidden_states: torch.Tensor,
126
+ sampling_metadata: SamplingMetadata,
127
+ ) -> torch.Tensor:
128
+ # NOTE(kzawora): The if guard is needed for Gaudi - in some scenarios
129
+ # (warmup, profile_run) we might not have selected_token_indices,
130
+ # so we skip pruning.
131
+ if sampling_metadata.selected_token_indices is not None:
132
+ return hidden_states.index_select(
133
+ 0, sampling_metadata.selected_token_indices)
134
+ else:
135
+ return hidden_states
136
+
137
+
138
+ def _apply_logits_processors(
139
+ logits: torch.Tensor,
140
+ sampling_metadata: SamplingMetadata,
141
+ ) -> torch.Tensor:
142
+ found_logits_processors = False
143
+ logits_processed = 0
144
+ logits_row_ids_and_logits_row_futures = []
145
+ for seq_group in sampling_metadata.seq_groups:
146
+ seq_ids = seq_group.seq_ids
147
+ sampling_params = seq_group.sampling_params
148
+ logits_processors = sampling_params.logits_processors
149
+ if logits_processors:
150
+ found_logits_processors = True
151
+
152
+ for seq_id, logits_row_idx in zip(seq_ids,
153
+ seq_group.sample_indices):
154
+ logits_row = logits[logits_row_idx]
155
+ past_tokens_ids = seq_group.seq_data[seq_id].output_token_ids
156
+ prompt_tokens_ids = seq_group.seq_data[seq_id].prompt_token_ids
157
+
158
+ if _logits_processor_threadpool is not None:
159
+ logits_row_ids_and_logits_row_futures.append(
160
+ (logits_row_idx,
161
+ _logits_processor_threadpool.submit(
162
+ _apply_logits_processors_single_seq, logits_row,
163
+ logits_processors, past_tokens_ids,
164
+ prompt_tokens_ids)))
165
+ else:
166
+ logits[logits_row_idx] = \
167
+ _apply_logits_processors_single_seq(
168
+ logits_row, logits_processors, past_tokens_ids,
169
+ prompt_tokens_ids)
170
+
171
+ logits_processed += len(seq_group.sample_indices) + len(
172
+ seq_group.prompt_logprob_indices)
173
+
174
+ for logits_row_idx, future in logits_row_ids_and_logits_row_futures:
175
+ logits[logits_row_idx] = future.result()
176
+
177
+ if found_logits_processors:
178
+ # verifies that no rows in logits were missed unexpectedly
179
+ assert logits_processed == logits.shape[0]
180
+ return logits
181
+
182
+
183
+ def _apply_logits_processors_single_seq(logits_row, logits_processors,
184
+ past_tokens_ids,
185
+ prompt_tokens_ids) -> torch.Tensor:
186
+ for logits_processor in logits_processors:
187
+ parameters = inspect.signature(logits_processor).parameters
188
+ if len(parameters) == 3:
189
+ logits_row = logits_processor(prompt_tokens_ids, past_tokens_ids,
190
+ logits_row)
191
+ else:
192
+ logits_row = logits_processor(past_tokens_ids, logits_row)
193
+ return logits_row
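As an illustration of the two call signatures that _apply_logits_processors_single_seq dispatches on via inspect.signature, here are two hypothetical user-supplied processors (not from vLLM):

import torch

def ban_token_42(past_token_ids, logits):
    # 2-argument form: (generated token ids, logits row) -> logits row
    logits[42] = float("-inf")
    return logits

def penalize_seen_tokens(prompt_token_ids, past_token_ids, logits):
    # 3-argument form additionally receives the prompt token ids.
    for tok in set(prompt_token_ids) | set(past_token_ids):
        logits[tok] -= 1.0
    return logits

logits_row = torch.zeros(128)
logits_row = ban_token_42([1, 2, 3], logits_row)
logits_row = penalize_seen_tokens([5, 6], [1, 2, 3], logits_row)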
.venv/lib/python3.11/site-packages/vllm/model_executor/layers/pooler.py ADDED
@@ -0,0 +1,322 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ from enum import IntEnum
4
+ from typing import List, Optional, Union
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from transformers import PretrainedConfig
10
+ from typing_extensions import assert_never
11
+
12
+ from vllm.config import PoolerConfig
13
+ from vllm.model_executor.pooling_metadata import (PoolingMetadata,
14
+ PoolingTensors)
15
+ from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput
16
+ from vllm.transformers_utils.config import (
17
+ get_cross_encoder_activation_function)
18
+
19
+
20
+ class PoolingType(IntEnum):
21
+ """Enumeration for different types of pooling methods."""
22
+ LAST = 0
23
+ ALL = 1
24
+ CLS = 2
25
+ STEP = 3
26
+ MEAN = 4
27
+
28
+
29
+ class SimplePooler(nn.Module):
30
+ """A layer that pools specific information from hidden states.
31
+
32
+ This layer does the following:
33
+ 1. Extracts specific tokens or aggregates data based on pooling method.
34
+ 2. Normalizes output if specified.
35
+ 3. Returns structured results as `PoolerOutput`.
36
+
37
+ Attributes:
38
+ pooling_type: The type of pooling to use.
39
+ normalize: Whether to normalize the pooled data.
40
+ """
41
+
42
+ @staticmethod
43
+ def from_pooling_type(
44
+ pooling_type: PoolingType,
45
+ *,
46
+ normalize: bool,
47
+ softmax: bool,
48
+ step_tag_id: Optional[int] = None,
49
+ returned_token_ids: Optional[List[int]] = None,
50
+ ) -> "SimplePooler":
51
+ if pooling_type == PoolingType.LAST:
52
+ assert step_tag_id is None and returned_token_ids is None
53
+ return LastPool(normalize=normalize, softmax=softmax)
54
+ if pooling_type == PoolingType.ALL:
55
+ assert step_tag_id is None and returned_token_ids is None
56
+ return AllPool(normalize=normalize, softmax=softmax)
57
+ if pooling_type == PoolingType.CLS:
58
+ assert step_tag_id is None and returned_token_ids is None
59
+ return CLSPool(normalize=normalize, softmax=softmax)
60
+ if pooling_type == PoolingType.MEAN:
61
+ assert step_tag_id is None and returned_token_ids is None
62
+ return MeanPool(normalize=normalize, softmax=softmax)
63
+ if pooling_type == PoolingType.STEP:
64
+ return StepPool(normalize=normalize,
65
+ softmax=softmax,
66
+ step_tag_id=step_tag_id,
67
+ returned_token_ids=returned_token_ids)
68
+
69
+ assert_never(pooling_type)
70
+
71
+ def __init__(self, *, normalize: bool, softmax: bool) -> None:
72
+ super().__init__()
73
+
74
+ self.head = PoolerHead(normalize=normalize, softmax=softmax)
75
+
76
+ def get_prompt_lens(
77
+ self,
78
+ hidden_states: torch.Tensor,
79
+ pooling_metadata: PoolingMetadata,
80
+ ) -> torch.Tensor:
81
+ return PoolingTensors.from_pooling_metadata(
82
+ pooling_metadata, hidden_states.device).prompt_lens
83
+
84
+ def extract_states(
85
+ self,
86
+ hidden_states: torch.Tensor,
87
+ pooling_metadata: PoolingMetadata,
88
+ ) -> Union[list[torch.Tensor], torch.Tensor]:
89
+ raise NotImplementedError
90
+
91
+ def build_output(self, data: torch.Tensor) -> PoolingSequenceGroupOutput:
92
+ return PoolingSequenceGroupOutput(data)
93
+
94
+ def forward(
95
+ self,
96
+ hidden_states: torch.Tensor,
97
+ pooling_metadata: PoolingMetadata,
98
+ ) -> PoolerOutput:
99
+ pooled_data = self.extract_states(hidden_states, pooling_metadata)
100
+ pooled_data = self.head(pooled_data)
101
+ pooled_outputs = [self.build_output(data) for data in pooled_data]
102
+ return PoolerOutput(outputs=pooled_outputs)
103
+
104
+
105
+ class CLSPool(SimplePooler):
106
+
107
+ def extract_states(
108
+ self,
109
+ hidden_states: torch.Tensor,
110
+ pooling_metadata: PoolingMetadata,
111
+ ) -> Union[list[torch.Tensor], torch.Tensor]:
112
+ prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)
113
+
114
+ first_token_flat_indices = torch.zeros_like(prompt_lens)
115
+ first_token_flat_indices[1:] += torch.cumsum(prompt_lens, dim=0)[:-1]
116
+ return hidden_states[first_token_flat_indices]
117
+
118
+
119
+ class LastPool(SimplePooler):
120
+
121
+ def extract_states(
122
+ self,
123
+ hidden_states: torch.Tensor,
124
+ pooling_metadata: PoolingMetadata,
125
+ ) -> Union[list[torch.Tensor], torch.Tensor]:
126
+ prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)
127
+
128
+ last_token_flat_indices = torch.cumsum(prompt_lens, dim=0) - 1
129
+ return hidden_states[last_token_flat_indices]
130
+
131
+
132
+ class AllPool(SimplePooler):
133
+
134
+ def extract_states(
135
+ self,
136
+ hidden_states: torch.Tensor,
137
+ pooling_metadata: PoolingMetadata,
138
+ ) -> Union[list[torch.Tensor], torch.Tensor]:
139
+ prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)
140
+
141
+ offset = 0
142
+ pooled_data = list[torch.Tensor]()
143
+ for prompt_len in prompt_lens:
144
+ pooled_data.append(hidden_states[offset:offset + prompt_len])
145
+ offset += prompt_len
146
+
147
+ return pooled_data
148
+
149
+
150
+ class MeanPool(SimplePooler):
151
+
152
+ def extract_states(
153
+ self,
154
+ hidden_states: torch.Tensor,
155
+ pooling_metadata: PoolingMetadata,
156
+ ) -> Union[list[torch.Tensor], torch.Tensor]:
157
+ prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)
158
+
159
+ cumsum = torch.cumsum(hidden_states, dim=0)
160
+ start_indices = torch.cat([
161
+ torch.tensor([0], device=hidden_states.device),
162
+ torch.cumsum(prompt_lens[:-1], dim=0)
163
+ ])
164
+ end_indices = torch.cumsum(prompt_lens, dim=0)
165
+ return (cumsum[end_indices - 1] - cumsum[start_indices] +
166
+ hidden_states[start_indices]) / prompt_lens.unsqueeze(1)
167
+
168
+
169
+ class StepPool(SimplePooler):
170
+
171
+ def __init__(
172
+ self,
173
+ *,
174
+ normalize: bool,
175
+ softmax: bool,
176
+ step_tag_id: Optional[int] = None,
177
+ returned_token_ids: Optional[List[int]] = None,
178
+ ):
179
+ super().__init__(normalize=normalize, softmax=softmax)
180
+
181
+ self.step_tag_id = step_tag_id
182
+ self.returned_token_ids = returned_token_ids
183
+
184
+ def extract_states(
185
+ self,
186
+ hidden_states: torch.Tensor,
187
+ pooling_metadata: PoolingMetadata,
188
+ ) -> Union[list[torch.Tensor], torch.Tensor]:
189
+ prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)
190
+
191
+ returned_token_ids = self.returned_token_ids
192
+ if returned_token_ids is not None and len(returned_token_ids) > 0:
193
+ hidden_states = hidden_states[:, returned_token_ids]
194
+
195
+ step_tag_id = self.step_tag_id
196
+
197
+ offset = 0
198
+ pooled_data = list[torch.Tensor]()
199
+ for prompt_len, seq_data_i in zip(prompt_lens,
200
+ pooling_metadata.seq_data.values()):
201
+ pooled_data_i = hidden_states[offset:offset + prompt_len]
202
+ if step_tag_id is not None:
203
+ token_ids = torch.tensor(seq_data_i.prompt_token_ids)
204
+ pooled_data_i = pooled_data_i[token_ids == step_tag_id]
205
+
206
+ offset += prompt_len
207
+ pooled_data.append(pooled_data_i)
208
+
209
+ return pooled_data
210
+
211
+
212
+ class PoolerHead(nn.Module):
213
+
214
+ def __init__(self, *, normalize: bool, softmax: bool) -> None:
215
+ super().__init__()
216
+
217
+ self.normalize = normalize
218
+ self.softmax = softmax
219
+
220
+ def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor]):
221
+ if self.normalize:
222
+ if isinstance(pooled_data, list):
223
+ pooled_data = [
224
+ F.normalize(data, p=2, dim=1) for data in pooled_data
225
+ ]
226
+ else:
227
+ pooled_data = F.normalize(pooled_data, p=2, dim=1)
228
+
229
+ if self.softmax:
230
+ if isinstance(pooled_data, list):
231
+ pooled_data = [F.softmax(data, dim=-1) for data in pooled_data]
232
+ else:
233
+ pooled_data = F.softmax(pooled_data, dim=-1)
234
+
235
+ return pooled_data
236
+
237
+
238
+ class Pooler(nn.Module):
239
+
240
+ @classmethod
241
+ def from_config_with_defaults(
242
+ cls,
243
+ pooler_config: PoolerConfig,
244
+ pooling_type: PoolingType,
245
+ normalize: bool,
246
+ softmax: bool,
247
+ step_tag_id: Optional[int] = None,
248
+ returned_token_ids: Optional[List[int]] = None,
249
+ ) -> SimplePooler:
250
+ return SimplePooler.from_pooling_type(
251
+ pooling_type=PoolingType[pooler_config.pooling_type]
252
+ if pooler_config.pooling_type is not None else pooling_type,
253
+ normalize=pooler_config.normalize
254
+ if pooler_config.normalize is not None else normalize,
255
+ softmax=pooler_config.softmax
256
+ if pooler_config.softmax is not None else softmax,
257
+ step_tag_id=pooler_config.step_tag_id
258
+ if pooler_config.step_tag_id is not None else step_tag_id,
259
+ returned_token_ids=pooler_config.returned_token_ids
260
+ if pooler_config.returned_token_ids is not None else
261
+ returned_token_ids,
262
+ )
263
+
264
+
265
+ class CrossEncodingPooler(nn.Module):
266
+ """A layer that pools specific information from hidden states.
267
+
268
+ This layer does the following:
269
+ 1. Extracts specific tokens or aggregates data based on pooling method.
270
+ 2. Normalizes output if specified.
271
+ 3. Returns structured results as `PoolerOutput`.
272
+
273
+ Attributes:
274
+ pooling_type: The type of pooling to use.
275
+ normalize: Whether to normalize the pooled data.
276
+ """
277
+
278
+ def __init__(
279
+ self,
280
+ config: PretrainedConfig,
281
+ classifier: nn.Module,
282
+ pooler: Optional[nn.Module] = None,
283
+ ):
284
+ super().__init__()
285
+ self.classifier = classifier
286
+ self.pooler = pooler
287
+ self.default_activation_function = \
288
+ get_cross_encoder_activation_function(config)
289
+
290
+ def forward(
291
+ self,
292
+ hidden_states: torch.Tensor,
293
+ pooling_metadata: PoolingMetadata,
294
+ ) -> PoolerOutput:
295
+ """Pools sentence pair scores from the hidden_states."""
296
+
297
+ prompt_lens = PoolingTensors.from_pooling_metadata(
298
+ pooling_metadata, hidden_states.device).prompt_lens
299
+
300
+ offset = 0
301
+ pooled_data_lst = []
302
+ for prompt_len in prompt_lens:
303
+ pooled_data_i = hidden_states[offset:offset + prompt_len]
304
+
305
+ if self.pooler is not None:
306
+ final_shape_tensor = self.pooler(pooled_data_i)
307
+ else:
308
+ final_shape_tensor = self.classifier(pooled_data_i)
309
+
310
+ pooled_data_lst.append(final_shape_tensor)
311
+ offset += prompt_len
312
+
313
+ pooled_output = torch.stack(pooled_data_lst)
314
+
315
+ if self.pooler is not None:
316
+ # apply classifier once on the full batch if possible
317
+ pooled_output = self.classifier(pooled_output)
318
+
319
+ scores = self.default_activation_function(pooled_output).squeeze(-1)
320
+
321
+ pooled_outputs = [PoolingSequenceGroupOutput(data) for data in scores]
322
+ return PoolerOutput(outputs=pooled_outputs)
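A small self-contained check (toy tensors, not part of this diff) of the prefix-sum trick MeanPool.extract_states uses to mean-pool each prompt's rows out of the flattened [total_tokens, hidden] tensor:

import torch

hidden_states = torch.randn(7, 4)    # 7 tokens total, hidden size 4
prompt_lens = torch.tensor([3, 4])   # two prompts of length 3 and 4

cumsum = torch.cumsum(hidden_states, dim=0)
start_indices = torch.cat(
    [torch.tensor([0]), torch.cumsum(prompt_lens[:-1], dim=0)])
end_indices = torch.cumsum(prompt_lens, dim=0)
means = (cumsum[end_indices - 1] - cumsum[start_indices] +
         hidden_states[start_indices]) / prompt_lens.unsqueeze(1)

expected = torch.stack(
    [hidden_states[:3].mean(dim=0), hidden_states[3:].mean(dim=0)])
assert torch.allclose(means, expected, atol=1e-5)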
.venv/lib/python3.11/site-packages/vllm/model_executor/layers/rejection_sampler.py ADDED
@@ -0,0 +1,400 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ from functools import cached_property
4
+ from importlib.util import find_spec
5
+ from typing import Dict, Optional, Tuple
6
+
7
+ import torch
8
+ import torch.jit
9
+
10
+ import vllm.envs as envs
11
+ from vllm.logger import init_logger
12
+ from vllm.model_executor.layers.spec_decode_base_sampler import (
13
+ SpecDecodeStochasticBaseSampler)
14
+ from vllm.platforms import current_platform
15
+
16
+ logger = init_logger(__name__)
17
+
18
+ if find_spec("flashinfer"):
19
+ """
20
+ Consider utilizing the FlashInfer rejection sampling kernel initially,
21
+ as it employs a dedicated kernel rather than relying on
22
+ Torch tensor operations. This design choice helps to fuse operations,
23
+ reduce memory I/O, and consequently enhances performance.
24
+ """
25
+ from flashinfer.sampling import chain_speculative_sampling
26
+ else:
27
+ chain_speculative_sampling = None
28
+
29
+
30
+ class RejectionSampler(SpecDecodeStochasticBaseSampler):
31
+ """Apply modified rejection sampling as described in "Accelerating Large
32
+ Language Model Decoding with Speculative Sampling"
33
+ https://arxiv.org/pdf/2302.01318.pdf.
34
+ """
35
+
36
+ def __init__(self,
37
+ strict_mode: bool = False,
38
+ use_flashinfer: Optional[bool] = None):
39
+ """Create a rejection sampler.
40
+
41
+ Args:
42
+ strict_mode: Whether or not to perform shape/device/dtype checks
43
+ during sampling. This catches correctness issues but adds
44
+ nontrivial latency.
45
+ use_flashinfer: We will use this parameter to determine whether
46
+ to use the FlashInfer rejection sampling kernel or not. If it's
47
+ None, we will use the default value from the environment variable.
48
+ This parameter is only used for testing purposes.
49
+ """
50
+ super().__init__(strict_mode=strict_mode)
51
+ if use_flashinfer is None:
52
+ self.use_flashinfer = envs.VLLM_USE_FLASHINFER_SAMPLER and (
53
+ chain_speculative_sampling is not None)
54
+ else:
55
+ self.use_flashinfer = use_flashinfer
56
+
57
+ if self.use_flashinfer:
58
+ logger.info("Use flashinfer for rejection sampling.")
59
+ else:
60
+ logger.info("Use pytorch for rejection sampling.")
61
+
62
+ def forward(
63
+ self,
64
+ target_with_bonus_probs: torch.Tensor,
65
+ bonus_token_ids: torch.Tensor,
66
+ draft_probs: torch.Tensor,
67
+ draft_token_ids: torch.Tensor,
68
+ seeded_seqs: Optional[Dict[int, torch.Generator]] = None,
69
+ ) -> torch.Tensor:
70
+ """Sample token ids using rejection sampling. This accepts or rejects
71
+ tokens proposed by the draft model using the probability of each token
72
+ according to the draft and target models.
73
+
74
+ In the worst case where all draft tokens are rejected, it is guaranteed
75
+ one correct token will be emitted.
76
+
77
+ In the case where all draft tokens are accepted, a bonus token will be
78
+ accepted, as it's cheap to have the target model score this speculative
79
+ sequence.
80
+
81
+ Args:
82
+ target_with_bonus_probs: The probability distribution
83
+ over token ids given context according to the target model.
84
+ shape = [batch_size, num_speculative_tokens + 1, vocab_size]
85
+
86
+ bonus_token_ids: The "bonus" token ids that are accepted iff all
87
+ speculative tokens in a sequence are accepted.
88
+ shape = [batch_size, num_bonus_tokens]
89
+
90
+ draft_probs: The probability distribution over token ids given
91
+ context according to the draft model.
92
+ shape = [batch_size, num_speculative_tokens, vocab_size]
93
+
94
+ draft_token_ids: The token ids that were sampled from the draft
95
+ probabilities.
96
+ shape = [batch_size, num_speculative_tokens]
97
+
98
+ seeded_seqs: Dict of batch row index to torch generator, for
99
+ sequences using seeded generation.
100
+
101
+ Returns:
102
+ output_token_ids: The token ids sampled via rejection sampling,
103
+ or -1 if unable to sample a token because the previous token
104
+ was rejected.
105
+ shape = [batch_size, num_speculative_tokens + num_bonus_tokens]
106
+ """
107
+ # Only perform shape/dtype/device checking in strict mode, as it adds
108
+ # overhead.
109
+ if self._strict_mode:
110
+ self._raise_if_incorrect_input(target_with_bonus_probs,
111
+ draft_token_ids, bonus_token_ids,
112
+ draft_probs)
113
+
114
+ batch_size, k, _ = draft_probs.shape
115
+
116
+ # batch_size = 0 when all requests in the batch are
117
+ # non_spec requests. In this case, output_token_ids is
118
+ # just an empty tensor.
119
+ if batch_size == 0:
120
+ return torch.empty(0, k + 1, device=draft_probs.device, dtype=int)
121
+
122
+ # If use Flashinfer chain_speculative_sampling kernel
123
+ # for rejection sampling
124
+ if self.use_flashinfer and chain_speculative_sampling is not None:
125
+ batch_size, k, _ = draft_probs.shape
126
+ uniform_samples = self._create_uniform_samples(
127
+ seeded_seqs, batch_size, k, draft_probs.device)
128
+ output_token_ids, accepted_token_num, emitted_token_num \
129
+ = chain_speculative_sampling(
130
+ draft_probs, draft_token_ids, uniform_samples,
131
+ target_with_bonus_probs)
132
+
133
+ # num_emitted_tokens returned by flashinfer
134
+ # does not include the bonus token
135
+ # Flashinfer stops at the first token that violates
136
+ # the condition p >= q and does not include recovery/bonus token.
137
+ # Therefore, we need to add batch_size here.
138
+ self.num_accepted_tokens += accepted_token_num.sum()
139
+ self.num_emitted_tokens += emitted_token_num.sum() + batch_size
140
+ self.num_draft_tokens += batch_size * k
141
+ else:
142
+ accepted, recovered_token_ids = (
143
+ self._batch_modified_rejection_sampling(
144
+ target_with_bonus_probs[:, :-1],
145
+ draft_probs,
146
+ draft_token_ids,
147
+ seeded_seqs,
148
+ ))
149
+
150
+ output_token_ids = self._create_output(
151
+ accepted,
152
+ recovered_token_ids,
153
+ draft_token_ids,
154
+ bonus_token_ids,
155
+ )
156
+
157
+ return output_token_ids
158
+
159
+ def _batch_modified_rejection_sampling(
160
+ self,
161
+ target_probs: torch.Tensor, # [batch_size, k, vocab_size]
162
+ draft_probs: torch.Tensor, # [batch_size, k, vocab_size]
163
+ draft_token_ids: torch.Tensor, # [batch_size, k]
164
+ seeded_seqs: Optional[Dict[int, torch.Generator]],
165
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
166
+ """Perform modified rejection sampling on each sequence.
167
+
168
+ Returns:
169
+ A tuple of two tensors:
170
+ 0: A bool tensor of which tokens in each sequence is accepted.
171
+ shape = [batch_size, k]
172
+ 1: Token ids sampled from a recovered distribution, to be used
173
+ when a token is rejected.
174
+ shape = [batch_size, k]
175
+ """
176
+
177
+ batch_size, k, vocab_size = draft_probs.shape
178
+
179
+ # shape [batch_size, k]
180
+ accepted = self._get_accepted(target_probs, draft_probs,
181
+ draft_token_ids, seeded_seqs)
182
+
183
+ recovered_probs = self._get_recovered_probs(
184
+ target_probs, draft_probs).reshape(batch_size * k, vocab_size)
185
+
186
+ # NOTE: the recovered_probs are overwritten by this method.
187
+ recovered_token_ids = _multinomial(
188
+ recovered_probs,
189
+ num_samples=1,
190
+ k=k,
191
+ seeded_seqs=seeded_seqs or {},
192
+ ).reshape(batch_size, k)
193
+
194
+ return accepted, recovered_token_ids
195
+
196
+ def _create_uniform_samples(self,
197
+ seeded_seqs: Optional[Dict[int,
198
+ torch.Generator]],
199
+ batch_size: int, k: int,
200
+ device: torch.device) -> torch.Tensor:
201
+ """
202
+ Generates a batch of uniform random samples, with optional seeding
203
+ for specific sequences.
204
+
205
+ This method creates a tensor of shape `(batch_size, k + 1)` filled
206
+ with uniform random values in the range [0, 1). If `seeded_seqs`
207
+ is provided, the sequences corresponding to specific indices
208
+ will be generated using the provided `torch.Generator` for
209
+ reproducibility. The other sequences will be generated without
210
+ a seed.
211
+
212
+ Args:
213
+ seeded_seqs : Optional[Dict[int, torch.Generator]]
214
+ A dictionary mapping indices in the batch to
215
+ `torch.Generator` objects. If `None`, all samples are
216
+ generated without a seed.
217
+ batch_size : int
218
+ The number of sequences to generate.
219
+ k : int
220
+ The number of random samples per sequence.
221
+ device : torch.device
222
+ The device on which to allocate the tensor.
223
+
224
+ Returns:
225
+ uniform_rand : torch.Tensor
226
+ A tensor of shape `(batch_size, k + 1)` containing uniform
227
+ random values in the range [0, 1).
228
+ """
229
+ if not seeded_seqs:
230
+ return torch.rand(batch_size, k + 1, device=device)
231
+
232
+ uniform_rand = torch.empty(batch_size, k + 1, device=device)
233
+
234
+ non_seeded_indices = []
235
+ for idx in range(batch_size):
236
+ generator = seeded_seqs.get(idx)
237
+ if generator is None:
238
+ non_seeded_indices.append(idx)
239
+ else:
240
+ uniform_rand[idx, :] = torch.rand(1,
241
+ k + 1,
242
+ dtype=self.probs_dtype,
243
+ device=device,
244
+ generator=generator)
245
+ if non_seeded_indices:
246
+ uniform_rand[non_seeded_indices, :] = torch.rand(
247
+ len(non_seeded_indices),
248
+ k + 1,
249
+ dtype=self.probs_dtype,
250
+ device=device)
251
+ return uniform_rand
252
+
253
+ def _get_accepted(
254
+ self,
255
+ target_probs: torch.Tensor, # [batch_size, k, vocab_size]
256
+ draft_probs: torch.Tensor, # [batch_size, k, vocab_size]
257
+ draft_token_ids: torch.Tensor, # [batch_size, k]
258
+ seeded_seqs: Optional[Dict[int, torch.Generator]],
259
+ ) -> torch.Tensor:
260
+ r"""Create bool matrix over the proposed draft tokens. If
261
+ True, then a token can be accepted, else it should be
262
+ rejected.
263
+
264
+ Given :math:`q(\hat{x}_{n+1}|x_1, \dots, x_n)`, the probability of
265
+ :math:`\hat{x}_{n+1}` given context :math:`x_1, \dots, x_n` according
266
+ to the target model, and :math:`p(\hat{x}_{n+1}|x_1, \dots, x_n)`, the
267
+ same conditional probability according to the draft model, the token
268
+ is accepted with probability:
269
+
270
+ .. math::
271
+ \min\left(1, \frac{q(\hat{x}_{n+1}|x_1, \dots, x_n)}
272
+ {p(\hat{x}_{n+1}|x_1, \dots, x_n)}\right)
273
+
274
+ This implementation does not apply causality. When using the output,
275
+ if a token is rejected, subsequent tokens should not be used.
276
+
277
+ Returns a bool tensor of shape [batch_size, k] specifying which tokens
278
+ are accepted.
279
+ """
280
+ batch_size, k, _ = draft_probs.shape
281
+ batch_indices = torch.arange(batch_size,
282
+ device=target_probs.device)[:, None]
283
+ probs_indices = torch.arange(k, device=target_probs.device)
284
+
285
+ # shape [batch_size, k]
286
+ selected_draft_probs = draft_probs[batch_indices, probs_indices,
287
+ draft_token_ids]
288
+
289
+ # shape [batch_size, k]
290
+ selected_target_probs = target_probs[batch_indices, probs_indices,
291
+ draft_token_ids]
292
+
293
+ uniform_rand = self._create_uniform_samples(seeded_seqs, batch_size,
294
+ k - 1, target_probs.device)
295
+
296
+ capped_ratio = torch.minimum(
297
+ selected_target_probs / selected_draft_probs,
298
+ torch.full((1, ), 1, device=target_probs.device))
299
+ accepted = uniform_rand < capped_ratio
300
+
301
+ return accepted
302
+
303
+ def _get_recovered_probs(
304
+ self,
305
+ target_probs: torch.Tensor, # [k, vocab_size]
306
+ draft_probs: torch.Tensor, # [k, vocab_size]
307
+ ) -> torch.Tensor:
308
+ r"""Create a probability distribution for each proposed token which can
309
+ be sampled if the proposed token is rejected.
310
+
311
+ When this routine is applied sequentially, the true distribution of the
312
+ target model is recovered (within hardware numerics).
313
+
314
+ The probability distribution used in this rejection case is constructed
315
+ as follows. Given :math:`q(x|x_1, \dots, x_n)`, the probability of
316
+ :math:`x` given context :math:`x_1, \dots, x_n` according to the target
317
+ model and :math:`p(x|x_1, \dots, x_n)`, the same conditional probability
318
+ according to the draft model:
319
+
320
+ .. math::
321
+ x_{n+1} \sim (q(x|x_1, \dots, x_n) - p(x|x_1, \dots, x_n))_+
322
+
323
+ where :math:`(f(x))_+` is defined as:
324
+
325
+ .. math::
326
+ (f(x))_+ = \frac{\max(0, f(x))}{\sum_x \max(0, f(x))}
327
+
328
+ See https://github.com/vllm-project/vllm/pull/2336 for a visualization
329
+ of the draft, target, and recovered probability distributions.
330
+
331
+ Returns a tensor of shape [batch_size, k, vocab_size].
332
+
333
+ Note: This batches operations on GPU and thus constructs the recovered
334
+ distribution for all tokens, even if they are accepted. This causes
335
+ division-by-zero errors, so we use self._smallest_positive_value to
336
+ avoid that. This introduces some drift to the distribution.
337
+ """
338
+ _, k, _ = draft_probs.shape
339
+
340
+ # shape [batch_size, k, vocab_size]
341
+ difference = target_probs - draft_probs
342
+
343
+ # TODO(cade): Can we use logprobs instead of probs, and avoid the
344
+ # division-by-zero errors without introducing distribution drift?
345
+
346
+ # shape [batch_size, k, vocab_size]
347
+ f = torch.clamp(difference, min=self._smallest_positive_value)
348
+
349
+ # shape [batch_size, k, vocab_size]
350
+ recovered_probs = f / torch.sum(f, dim=-1).reshape(-1, k, 1)
351
+
352
+ return recovered_probs
353
+
354
+ @cached_property
355
+ def _smallest_positive_value(self) -> float:
356
+ """Return the smallest positive value representable by the probs dtype.
357
+ This value is used when constructing a distribution from which to sample
358
+ recovered tokens in the first rejection case.
359
+
360
+ See _get_recovered_probs for more details
361
+
362
+ Note that this isn't actually the smallest positive value representable
363
+ by float32, but the smallest positive normal value.
364
+ See https://en.wikipedia.org/wiki/Subnormal_number for more information.
365
+ """
366
+ return torch.finfo(self.probs_dtype).tiny
367
+
368
+
369
+ # torch.multinomial forces a GPU<->CPU sync.
370
+ # Therefore, we use an optimized implementation instead that skips the sync.
371
+ # Note that we always sample with replacement.
372
+ # probs will be modified in place, but this is fine, as we pass
373
+ # in a copy already.
374
+ @torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
375
+ def _multinomial(
376
+ probs: torch.Tensor,
377
+ num_samples: int,
378
+ k: int,
379
+ seeded_seqs: Dict[int, torch.Generator],
380
+ ) -> torch.Tensor:
381
+
382
+ if num_samples > 1:
383
+ # This is equivalent to torch.repeat_interleaved (which also
384
+ # forces a GPU<->CPU sync).
385
+ probs = probs[:, None, :].expand(probs.shape[0], num_samples,
386
+ probs.shape[1]).contiguous().view(
387
+ -1, probs.shape[1])
388
+ q = torch.empty_like(probs)
389
+ if not seeded_seqs:
390
+ q.exponential_(1.0)
391
+ else:
392
+ start = 0
393
+ for idx in range(len(q) // k):
394
+ end = start + k
395
+ generator = seeded_seqs.get(idx)
396
+ # Note: generator might be None for non seeded
397
+ q[start:end].exponential_(1.0, generator=generator)
398
+ start = end
399
+
400
+ return probs.div_(q).argmax(dim=1).view(-1, num_samples)
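A toy single-token illustration (not vLLM code) of the acceptance rule used in _get_accepted and the recovered distribution built in _get_recovered_probs:

import torch

target = torch.tensor([0.5, 0.3, 0.2])   # q(x | context), target model
draft = torch.tensor([0.2, 0.6, 0.2])    # p(x | context), draft model
draft_token = 1                          # token proposed by the draft model

# Accept the proposed token with probability min(1, q/p).
accept_prob = torch.clamp(target[draft_token] / draft[draft_token], max=1.0)
accepted = torch.rand(()) < accept_prob

# If rejected, resample from the normalized positive part of (q - p).
recovered = torch.clamp(target - draft, min=0)
recovered = recovered / recovered.sum()
print(accept_prob.item(), accepted.item(), recovered)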
.venv/lib/python3.11/site-packages/vllm/model_executor/layers/resampler.py ADDED
@@ -0,0 +1,269 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ # Adapted from
4
+ # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
5
+ # https://huggingface.co/Qwen/Qwen-7B/blob/main/modeling_qwen.py
6
+ # https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
7
+ #
8
+ # Copyright 2023 The Qwen team.
9
+ # Copyright 2023 The vLLM team.
10
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
11
+ #
12
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
13
+ # and OPT implementations in this library. It has been modified from its
14
+ # original forms to accommodate minor architectural differences compared
15
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
16
+ #
17
+ # Licensed under the Apache License, Version 2.0 (the "License");
18
+ # you may not use this file except in compliance with the License.
19
+ # You may obtain a copy of the License at
20
+ #
21
+ # http://www.apache.org/licenses/LICENSE-2.0
22
+ #
23
+ # Unless required by applicable law or agreed to in writing, software
24
+ # distributed under the License is distributed on an "AS IS" BASIS,
25
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
26
+ # See the License for the specific language governing permissions and
27
+ # limitations under the License.
28
+ """
29
+ Shared resampler perceiver network used in multimodal models and
30
+ related helpers for sincos positional embeddings.
31
+
32
+ Example models: Qwen (Qwen-VL), MiniCPM-V 2.0
33
+ """
34
+ import math
35
+ from functools import partial
36
+ from typing import Callable, Optional, Tuple, Union
37
+
38
+ import numpy as np
39
+ import torch
40
+ import torch.nn.functional as F
41
+ from torch import nn
42
+
43
+ from vllm.model_executor.layers.linear import ReplicatedLinear
44
+ from vllm.model_executor.layers.quantization import QuantizationConfig
45
+
46
+ DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6)
47
+
48
+
49
+ def get_abs_pos(abs_pos: torch.Tensor, tgt_size: Union[torch.Tensor,
50
+ int]) -> torch.Tensor:
51
+ # abs_pos: L, C
52
+ # tgt_size: (H, W)
53
+ # return: M, C
54
+ src_size = int(math.sqrt(abs_pos.size(0)))
55
+ dtype = abs_pos.dtype
56
+ if isinstance(tgt_size, int):
57
+ tgt_size = (tgt_size, tgt_size)
58
+ if (src_size == tgt_size[0] and src_size == tgt_size[1]):
59
+ return abs_pos
60
+ return (F.interpolate(
61
+ abs_pos.float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2),
62
+ size=(tgt_size[0], tgt_size[1]),
63
+ mode="bicubic",
64
+ align_corners=False,
65
+ ).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype))
66
+
67
+
68
+ # sin/cos positional embedding helpers are adapted from:
69
+ # https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
70
+ def get_1d_sincos_pos_embed_from_grid(
71
+ embed_dim: int, pos: np.ndarray,
72
+ version: Tuple[int, int] = (2, 0)) -> torch.Tensor:
73
+ """
74
+ embed_dim: output dimension for each position
75
+ pos: a list of positions to be encoded: size (M,) / (H, W)
76
+ out: (M, D) / (H, W, D)
77
+ """
78
+ assert embed_dim % 2 == 0
79
+ omega = np.arange(embed_dim // 2, dtype=np.float32)
80
+ omega /= embed_dim / 2.0
81
+ omega = 1.0 / 10000**omega # (D/2,)
82
+
83
+ if version == (2, 0):
84
+ pos = pos.reshape(-1) # (M,)
85
+ out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
86
+ emb_sin = np.sin(out) # (M, D/2)
87
+ emb_cos = np.cos(out) # (M, D/2)
88
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
89
+ else:
90
+ out = np.einsum("hw,d->hwd", pos, omega) # (H, W, D/2), outer product
91
+ emb_sin = np.sin(out) # (H, W, D/2)
92
+ emb_cos = np.cos(out) # (H, W, D/2)
93
+ emb = np.concatenate([emb_sin, emb_cos], axis=-1) # (H, W, D)
94
+ return emb
95
+
96
+
97
+ def get_2d_sincos_pos_embed_from_grid(
98
+ embed_dim: int, grid: np.ndarray,
99
+ version: Tuple[int, int] = (2, 0)) -> torch.Tensor:
100
+ assert embed_dim % 2 == 0
101
+
102
+ # use half of dimensions to encode grid_h
103
+ emb_h = get_1d_sincos_pos_embed_from_grid(
104
+ embed_dim // 2, grid[0], version) # (H*W, D/2) or (H, W, D/2)
105
+ emb_w = get_1d_sincos_pos_embed_from_grid(
106
+ embed_dim // 2, grid[1], version) # (H*W, D/2) or (H, W, D/2)
107
+
108
+ if version == (2, 0):
109
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
110
+ else:
111
+ emb = np.concatenate([emb_h, emb_w], axis=-1) # (H, W, D)
112
+ return emb
113
+
114
+
115
+ def get_2d_sincos_pos_embed(
116
+ embed_dim: int,
117
+ grid_size: Union[int, Tuple[int, int]],
118
+ cls_token: bool = False,
119
+ version: Tuple[int, int] = (2, 0),
120
+ ) -> torch.Tensor:
121
+ """
122
+ grid_size: int of the grid height and width
123
+ return:
124
+ pos_embed: [grid_size*grid_size, embed_dim] or
125
+ [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
126
+ """
127
+ if isinstance(grid_size, int):
128
+ grid_h_size, grid_w_size = grid_size, grid_size
129
+ else:
130
+ grid_h_size, grid_w_size = grid_size[0], grid_size[1]
131
+
132
+ grid_h = np.arange(grid_h_size, dtype=np.float32)
133
+ grid_w = np.arange(grid_w_size, dtype=np.float32)
134
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
135
+ grid = np.stack(grid, axis=0)
136
+ assert isinstance(grid, np.ndarray) and \
137
+ grid.shape == (2, grid_h_size, grid_w_size)
138
+
139
+ if version == (2, 0):
140
+ grid = grid.reshape([2, 1, grid_h_size, grid_w_size])
141
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version)
142
+ if cls_token:
143
+ pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed],
144
+ axis=0)
145
+ else:
146
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version)
147
+ return pos_embed
148
+
149
+
150
+ class BaseResampler(nn.Module):
151
+ """
152
+ A 2D perceiver-resampler network with one cross-attention layer, using
153
+ (grid_size**2) learnable queries and 2d sincos pos_emb.
154
+ Outputs:
155
+ A tensor with the shape of (grid_size**2, embed_dim)
156
+ """
157
+
158
+ def __init__(self,
159
+ num_queries: int,
160
+ embed_dim: int,
161
+ num_heads: int,
162
+ kv_dim: Optional[int] = None,
163
+ norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
164
+ do_post_projection: bool = True,
165
+ quant_config: Optional[QuantizationConfig] = None,
166
+ prefix: str = "") -> None:
167
+ super().__init__()
168
+
169
+ self.num_queries = num_queries
170
+ self.embed_dim = embed_dim
171
+ self.num_heads = num_heads
172
+
173
+ self.query = nn.Parameter(torch.empty(self.num_queries, embed_dim))
174
+
175
+ if kv_dim is not None and kv_dim != embed_dim:
176
+ self.kv_proj = ReplicatedLinear(kv_dim,
177
+ embed_dim,
178
+ bias=False,
179
+ quant_config=quant_config,
180
+ prefix=f"{prefix}.kv_proj")
181
+ else:
182
+ # Maintain the same return value with ReplicatedLinear.forward
183
+ self.kv_proj = lambda *args, **kwargs: ( # type: ignore # noqa
184
+ nn.Identity()(*args, **kwargs),
185
+ None,
186
+ )
187
+ self.attn = nn.MultiheadAttention(embed_dim, num_heads)
188
+ self.ln_q = norm_layer(embed_dim)
189
+ self.ln_kv = norm_layer(embed_dim)
190
+ self.do_post_projection = do_post_projection
191
+ self.ln_post = norm_layer(embed_dim) if do_post_projection else None
192
+ self.proj = nn.Parameter(
193
+ (embed_dim**-0.5) *
194
+ torch.empty(embed_dim, embed_dim)) if do_post_projection else None
195
+
196
+ def _repeat(self, query, N: int):
197
+ return query.unsqueeze(1).repeat(1, N, 1)
198
+
199
+
200
+ class Resampler2(BaseResampler):
201
+ """Resampler-perceiver network to be used for a variety of model types,
202
+ e.g., Qwen-vl / Minicpmv 2.0. The main difference is the addition of the
203
+ do_post_projection arg, which indicates whether or not there should be
204
+ a post layer normalization and projector after the attention. This is
205
+ present in minicpmv2.0, but not qwen-vl.
206
+ """
207
+
208
+ def __init__(self,
209
+ grid_size: int,
210
+ embed_dim: int,
211
+ num_heads: int,
212
+ kv_dim: Optional[int] = None,
213
+ norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
214
+ adaptive: bool = False,
215
+ do_post_projection: bool = True,
216
+ quant_config: Optional[QuantizationConfig] = None,
217
+ prefix: str = "") -> None:
218
+ super().__init__(grid_size**2,
219
+ embed_dim,
220
+ num_heads,
221
+ kv_dim,
222
+ norm_layer,
223
+ do_post_projection=do_post_projection,
224
+ quant_config=quant_config,
225
+ prefix=prefix)
226
+
227
+ self.adaptive = adaptive
228
+ pos_embed_arr = get_2d_sincos_pos_embed(embed_dim,
229
+ grid_size,
230
+ version=(2, 0))
231
+
232
+ self.pos_embed = nn.Parameter(
233
+ torch.from_numpy(pos_embed_arr).requires_grad_(False))
234
+
235
+ def forward(
236
+ self,
237
+ x: torch.Tensor,
238
+ tgt_sizes: Optional[torch.Tensor] = None,
239
+ attn_mask: Optional[torch.Tensor] = None,
240
+ ) -> torch.Tensor:
241
+ if tgt_sizes is None:
242
+ tgt_sizes = int(math.sqrt(x.size(1)))
243
+ if self.adaptive:
244
+ pos_embed_arr = get_2d_sincos_pos_embed(self.embed_dim,
245
+ tgt_sizes,
246
+ version=(2, 0))
247
+ pos_embed = torch.from_numpy(pos_embed_arr).to(device=x.device,
248
+ dtype=x.dtype)
249
+ else:
250
+ pos_embed = get_abs_pos(self.pos_embed,
251
+ tgt_sizes).to(device=x.device,
252
+ dtype=x.dtype)
253
+
254
+ x, _ = self.kv_proj(x)
255
+ x = self.ln_kv(x).permute(1, 0, 2)
256
+
257
+ N = x.shape[1]
258
+ q = self.ln_q(self.query)
259
+ out = self.attn(
260
+ self._repeat(q, N) + self.pos_embed.unsqueeze(1),
261
+ x + pos_embed.unsqueeze(1),
262
+ x,
263
+ attn_mask=attn_mask,
264
+ )[0]
265
+ x = out.permute(1, 0, 2)
266
+ if self.do_post_projection:
267
+ x = self.ln_post(x)
268
+ x = x @ self.proj
269
+ return x
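
For orientation, the helpers above all funnel into one shape contract: get_2d_sincos_pos_embed returns an (H*W, embed_dim) array whose first half of channels encodes the H axis and whose second half encodes the W axis. The snippet below is a minimal, self-contained sketch of that contract; the helper names are invented for the sketch and are not part of the file above.

```python
import numpy as np


def _sincos_1d(dim: int, pos: np.ndarray) -> np.ndarray:
    """Map positions (any shape) to (num_pos, dim): half sin, half cos."""
    omega = 1.0 / 10000 ** (np.arange(dim // 2, dtype=np.float32) / (dim / 2.0))
    out = np.einsum("m,d->md", pos.reshape(-1).astype(np.float32), omega)
    return np.concatenate([np.sin(out), np.cos(out)], axis=-1)


def sincos_2d(embed_dim: int, h: int, w: int) -> np.ndarray:
    """(H*W, embed_dim): first half of channels for H, second half for W."""
    grid_w, grid_h = np.meshgrid(np.arange(w, dtype=np.float32),
                                 np.arange(h, dtype=np.float32))
    emb_h = _sincos_1d(embed_dim // 2, grid_h)   # (H*W, embed_dim // 2)
    emb_w = _sincos_1d(embed_dim // 2, grid_w)   # (H*W, embed_dim // 2)
    return np.concatenate([emb_h, emb_w], axis=-1)


print(sincos_2d(64, 12, 16).shape)  # (192, 64)
```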
.venv/lib/python3.11/site-packages/vllm/model_executor/layers/rotary_embedding.py ADDED
@@ -0,0 +1,1114 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ # Adapted from
4
+ # https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/llama/modeling_llama.py
5
+ # Copyright 2023 The vLLM team.
6
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
7
+ #
8
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
9
+ # and OPT implementations in this library. It has been modified from its
10
+ # original forms to accommodate minor architectural differences compared
11
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
12
+ #
13
+ # Licensed under the Apache License, Version 2.0 (the "License");
14
+ # you may not use this file except in compliance with the License.
15
+ # You may obtain a copy of the License at
16
+ #
17
+ # http://www.apache.org/licenses/LICENSE-2.0
18
+ #
19
+ # Unless required by applicable law or agreed to in writing, software
20
+ # distributed under the License is distributed on an "AS IS" BASIS,
21
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
22
+ # See the License for the specific language governing permissions and
23
+ # limitations under the License.
24
+ """Rotary Positional Embeddings."""
25
+ import math
26
+ from typing import Any, Dict, List, Optional, Tuple, Union
27
+
28
+ import torch
29
+ import torch.nn as nn
30
+ from transformers import PretrainedConfig
31
+
32
+ from vllm.model_executor.custom_op import CustomOp
33
+
34
+
35
+ def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
36
+ x1 = x[..., :x.shape[-1] // 2]
37
+ x2 = x[..., x.shape[-1] // 2:]
38
+ return torch.cat((-x2, x1), dim=-1)
39
+
40
+
41
+ def _rotate_gptj(x: torch.Tensor) -> torch.Tensor:
42
+ x1 = x[..., ::2]
43
+ x2 = x[..., 1::2]
44
+ x = torch.stack((-x2, x1), dim=-1)
45
+ return x.flatten(-2)
46
+
47
+
48
+ def _apply_rotary_emb(
49
+ x: torch.Tensor,
50
+ cos: torch.Tensor,
51
+ sin: torch.Tensor,
52
+ is_neox_style: bool,
53
+ ) -> torch.Tensor:
54
+ """
55
+ Args:
56
+ x: [num_tokens, num_heads, head_size]
57
+ cos: [num_tokens, head_size // 2]
58
+ sin: [num_tokens, head_size // 2]
59
+ is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
60
+ positional embeddings.
61
+ """
62
+ cos = cos.unsqueeze(-2).to(x.dtype)
63
+ sin = sin.unsqueeze(-2).to(x.dtype)
64
+ if is_neox_style:
65
+ x1, x2 = torch.chunk(x, 2, dim=-1)
66
+ else:
67
+ x1 = x[..., ::2]
68
+ x2 = x[..., 1::2]
69
+ o1 = x1 * cos - x2 * sin
70
+ o2 = x2 * cos + x1 * sin
71
+ if is_neox_style:
72
+ return torch.cat((o1, o2), dim=-1)
73
+ else:
74
+ return torch.stack((o1, o2), dim=-1).flatten(-2)
75
+
76
+
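
Purely as an illustration of the helper above (not an addition to the file), _apply_rotary_emb can be exercised standalone by pairing it with the same inverse-frequency recipe that RotaryEmbedding uses further down; the shapes and base value here are arbitrary.

```python
import torch

from vllm.model_executor.layers.rotary_embedding import _apply_rotary_emb

num_tokens, num_heads, head_size, base = 4, 8, 64, 10000.0
positions = torch.arange(num_tokens)

# Same inverse-frequency construction as RotaryEmbedding._compute_inv_freq.
inv_freq = 1.0 / base ** (torch.arange(0, head_size, 2).float() / head_size)
freqs = torch.einsum("i,j -> ij", positions.float(), inv_freq)  # [tokens, head_size // 2]

x = torch.randn(num_tokens, num_heads, head_size)
out = _apply_rotary_emb(x, freqs.cos(), freqs.sin(), is_neox_style=True)
print(out.shape)  # torch.Size([4, 8, 64])
```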
77
+ @CustomOp.register("rotary_embedding")
78
+ class RotaryEmbedding(CustomOp):
79
+ """Original rotary positional embedding."""
80
+
81
+ def __init__(
82
+ self,
83
+ head_size: int,
84
+ rotary_dim: int,
85
+ max_position_embeddings: int,
86
+ base: int,
87
+ is_neox_style: bool,
88
+ dtype: torch.dtype,
89
+ ) -> None:
90
+ super().__init__()
91
+ self.head_size = head_size
92
+ self.rotary_dim = rotary_dim
93
+ self.max_position_embeddings = max_position_embeddings
94
+ self.base = base
95
+ self.is_neox_style = is_neox_style
96
+ self.dtype = dtype
97
+
98
+ cache = self._compute_cos_sin_cache()
99
+ cache = cache.to(dtype)
100
+ self.cos_sin_cache: torch.Tensor
101
+ self.register_buffer("cos_sin_cache", cache, persistent=False)
102
+
103
+ def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
104
+ """Compute the inverse frequency."""
105
+ # NOTE(woosuk): To exactly match the HF implementation, we need to
106
+ # use CPU to compute the cache and then move it to GPU. However, we
107
+ # create the cache on GPU for faster initialization. This may cause
108
+ # a slight numerical difference between the HF implementation and ours.
109
+ inv_freq = 1.0 / (base**(torch.arange(
110
+ 0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim))
111
+ return inv_freq
112
+
113
+ def _compute_cos_sin_cache(self) -> torch.Tensor:
114
+ """Compute the cos and sin cache."""
115
+ inv_freq = self._compute_inv_freq(self.base)
116
+ t = torch.arange(self.max_position_embeddings, dtype=torch.float)
117
+
118
+ freqs = torch.einsum("i,j -> ij", t, inv_freq)
119
+ cos = freqs.cos()
120
+ sin = freqs.sin()
121
+ cache = torch.cat((cos, sin), dim=-1)
122
+ return cache
123
+
124
+ def forward_native(
125
+ self,
126
+ positions: torch.Tensor,
127
+ query: torch.Tensor,
128
+ key: torch.Tensor,
129
+ offsets: Optional[torch.Tensor] = None,
130
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
131
+ """A PyTorch-native implementation of forward()."""
132
+ if offsets is not None:
133
+ positions = positions + offsets
134
+ positions = positions.flatten()
135
+ num_tokens = positions.shape[0]
136
+ cos_sin = self.cos_sin_cache.index_select(0, positions)
137
+ cos, sin = cos_sin.chunk(2, dim=-1)
138
+
139
+ query_shape = query.shape
140
+ query = query.view(num_tokens, -1, self.head_size)
141
+ query_rot = query[..., :self.rotary_dim]
142
+ query_pass = query[..., self.rotary_dim:]
143
+ query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style)
144
+ query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
145
+
146
+ key_shape = key.shape
147
+ key = key.view(num_tokens, -1, self.head_size)
148
+ key_rot = key[..., :self.rotary_dim]
149
+ key_pass = key[..., self.rotary_dim:]
150
+ key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style)
151
+ key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
152
+ return query, key
153
+
154
+ def forward_cuda(
155
+ self,
156
+ positions: torch.Tensor,
157
+ query: torch.Tensor,
158
+ key: torch.Tensor,
159
+ offsets: Optional[torch.Tensor] = None,
160
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
161
+ from vllm import _custom_ops as ops
162
+
163
+ self.cos_sin_cache = self.cos_sin_cache.to(query.device,
164
+ dtype=query.dtype)
165
+ # ops.rotary_embedding()/batched_rotary_embedding()
166
+ # are in-place operations that update the query and key tensors.
167
+ if offsets is not None:
168
+ ops.batched_rotary_embedding(positions, query, key, self.head_size,
169
+ self.cos_sin_cache,
170
+ self.is_neox_style, self.rotary_dim,
171
+ offsets)
172
+ else:
173
+ ops.rotary_embedding(positions, query, key, self.head_size,
174
+ self.cos_sin_cache, self.is_neox_style)
175
+ return query, key
176
+
177
+ def forward_xpu(
178
+ self,
179
+ positions: torch.Tensor,
180
+ query: torch.Tensor,
181
+ key: torch.Tensor,
182
+ offsets: Optional[torch.Tensor] = None,
183
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
184
+ from vllm._ipex_ops import ipex_ops as ops
185
+
186
+ self.cos_sin_cache = self.cos_sin_cache.to(positions.device,
187
+ dtype=query.dtype)
188
+ # ops.rotary_embedding()/batched_rotary_embedding()
189
+ # are in-place operations that update the query and key tensors.
190
+ if offsets is not None:
191
+ ops.batched_rotary_embedding(positions, query, key, self.head_size,
192
+ self.cos_sin_cache,
193
+ self.is_neox_style, self.rotary_dim,
194
+ offsets)
195
+ else:
196
+ ops.rotary_embedding(positions, query, key, self.head_size,
197
+ self.cos_sin_cache, self.is_neox_style)
198
+ return query, key
199
+
200
+ def forward_hpu(
201
+ self,
202
+ positions: torch.Tensor,
203
+ query: torch.Tensor,
204
+ key: torch.Tensor,
205
+ offsets: Optional[torch.Tensor] = None,
206
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
207
+ from habana_frameworks.torch.hpex.kernels import (
208
+ RotaryPosEmbeddingMode, apply_rotary_pos_emb)
209
+ positions = positions.flatten()
210
+ if offsets is not None:
211
+ positions = positions + offsets
212
+ num_tokens = positions.shape[0]
213
+ cos_sin = self.cos_sin_cache.index_select(0, positions).view(
214
+ num_tokens, 1, -1)
215
+ cos, sin = cos_sin.chunk(2, dim=-1)
216
+ # HPU RoPE kernel requires hidden dimension for cos and sin to be equal
217
+ # to query hidden dimension, so the original tensors need to be
218
+ # expanded
219
+ # GPT-NeoX kernel requires position_ids = None, offset, mode = BLOCKWISE
220
+ # and expansion of cos/sin tensors via concatenation
221
+ # GPT-J kernel requires position_ids = None, offset = 0, mode = PAIRWISE
222
+ # and expansion of cos/sin tensors via repeat_interleave
223
+ rope_mode: RotaryPosEmbeddingMode
224
+ if self.is_neox_style:
225
+ rope_mode = RotaryPosEmbeddingMode.BLOCKWISE
226
+ cos = torch.cat((cos, cos), dim=-1)
227
+ sin = torch.cat((sin, sin), dim=-1)
228
+ else:
229
+ rope_mode = RotaryPosEmbeddingMode.PAIRWISE
230
+ sin = torch.repeat_interleave(sin,
231
+ 2,
232
+ dim=-1,
233
+ output_size=cos_sin.shape[-1])
234
+ cos = torch.repeat_interleave(cos,
235
+ 2,
236
+ dim=-1,
237
+ output_size=cos_sin.shape[-1])
238
+
239
+ query_shape = query.shape
240
+ query = query.view(num_tokens, -1, self.head_size)
241
+ query_rot = query[..., :self.rotary_dim]
242
+ query_pass = query[..., self.rotary_dim:]
243
+ query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0,
244
+ rope_mode)
245
+ query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
246
+
247
+ key_shape = key.shape
248
+ key = key.view(num_tokens, -1, self.head_size)
249
+ key_rot = key[..., :self.rotary_dim]
250
+ key_pass = key[..., self.rotary_dim:]
251
+ key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode)
252
+ key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
253
+ return query, key
254
+
255
+ def extra_repr(self) -> str:
256
+ s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}"
257
+ s += f", max_position_embeddings={self.max_position_embeddings}"
258
+ s += f", base={self.base}, is_neox_style={self.is_neox_style}"
259
+ return s
260
+
261
+
262
+ class LinearScalingRotaryEmbedding(RotaryEmbedding):
263
+ """RotaryEmbedding extended with linear scaling.
264
+
265
+ It supports multiple scaling factors. Since multiple LoRA adapters may have
266
+ different scaling factors, we need multiple cos/sin caches. In this way,
267
+ instead of running rotary embedding kernel per lora, we can run multiple
268
+ lora in a batched way.
269
+
270
+ In addition to that, we also keep the cos/sin cache for the scaling factor
271
+ of 1 (default) at all times.
272
+
273
+ Exemplary for two scaling factors x=1, y and z with embeddings
274
+ [[x11, x12, ... x1m], ..., [xn1, xn2, ..., xnm]] and
275
+ [[y11, y12, ... y1o], ..., [yn1, yn2, ..., yno]], and
276
+ [[z11, z12, ... z1p], ..., [zn1, zn2, ..., znp]],
277
+
278
+ we construct the cos/sin cache as follows:
279
+ [[x11, x12, ... x1m, y11, y12, ... y1o, z11, z12, ... z1p],
280
+ ...
281
+ [xn1, xn2, ... xnm, yn1, yn2, ... yno, zn1, zn2, ... znp]]
282
+
283
+ We then use offsets to index into the cos/sin cache for
284
+ the respective scaling factors.
285
+
286
+ The offset to cache can be accessed via `scaling_factor_to_offset` API.
287
+
288
+ Credits to the Reddit user /u/kaiokendev
289
+ """
290
+
291
+ def __init__(
292
+ self,
293
+ head_size: int,
294
+ rotary_dim: int,
295
+ max_position_embeddings: int,
296
+ base: int,
297
+ is_neox_style: bool,
298
+ scaling_factors: Union[List[float], float],
299
+ dtype: torch.dtype,
300
+ ) -> None:
301
+ if isinstance(scaling_factors, float):
302
+ scaling_factors = [scaling_factors]
303
+ self.scaling_factors: List[float] = scaling_factors # noqa
304
+ super().__init__(head_size, rotary_dim, max_position_embeddings, base,
305
+ is_neox_style, dtype)
306
+ # Lazy initialized.
307
+ self._scaling_factor_to_offset: Dict[float, int]
308
+
309
+ def _compute_cos_sin_cache(self) -> torch.Tensor:
310
+ inv_freq = self._compute_inv_freq(self.base)
311
+ cache_list: List[torch.Tensor] = []
312
+ # offsets to the next cache in a tensor.
313
+ # Each offset corresponds to the same index in scaling_factors.
314
+ offsets: List[int] = []
315
+ for scaling_factor in self.scaling_factors:
316
+ # NOTE(woosuk): self.max_position_embeddings is the original
317
+ # maximum length before applying the rope scaling.
318
+ # Thus, the maximum length after applying the rope scaling is
319
+ # self.max_position_embeddings * self.scaling_factor.
320
+ max_len = self.max_position_embeddings * scaling_factor
321
+ t = torch.arange(max_len, dtype=torch.float)
322
+ t = t / scaling_factor
323
+
324
+ freqs = torch.einsum("i,j -> ij", t, inv_freq)
325
+ cos = freqs.cos()
326
+ sin = freqs.sin()
327
+ cache = torch.cat((cos, sin), dim=-1)
328
+ if not cache_list:
329
+ offset = 0
330
+ else:
331
+ last_offset = offsets[-1]
332
+ next_max_len = cache_list[-1].shape[0]
333
+ offset = last_offset + next_max_len
334
+ offsets.append(offset)
335
+ cache_list.append(cache)
336
+ self._scaling_factor_to_offset = {
337
+ float(scaling_factor): offsets[i]
338
+ for i, scaling_factor in enumerate(self.scaling_factors)
339
+ }
340
+ assert len(self.scaling_factors) == len(offsets)
341
+ return torch.cat(cache_list, dim=0)
342
+
343
+ @property
344
+ def scaling_factor_to_offset(self) -> Dict[float, int]:
345
+ return self._scaling_factor_to_offset
346
+
347
+
348
+ class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
349
+ """RotaryEmbedding extended with Dynamic NTK scaling.
350
+
351
+ Credits to the Reddit users /u/bloc97 and /u/emozilla
352
+ """
353
+
354
+ def __init__(
355
+ self,
356
+ head_size: int,
357
+ rotary_dim: int,
358
+ max_position_embeddings: int,
359
+ base: int,
360
+ is_neox_style: bool,
361
+ scaling_factor: float,
362
+ dtype: torch.dtype,
363
+ ) -> None:
364
+ self.scaling_factor = scaling_factor
365
+ super().__init__(head_size, rotary_dim, max_position_embeddings, base,
366
+ is_neox_style, dtype)
367
+
368
+ def _compute_cos_sin_cache(self) -> torch.Tensor:
369
+ # NOTE(woosuk): self.max_position_embeddings is the original
370
+ # maximum length before applying the rope scaling.
371
+ # Thus, the maximum length after applying the rope scaling is
372
+ # self.max_position_embeddings * self.scaling_factor.
373
+ max_len = self.max_position_embeddings * self.scaling_factor
374
+ base = self.base * (
375
+ (self.scaling_factor * max_len / self.max_position_embeddings) -
376
+ (self.scaling_factor - 1))**(self.rotary_dim /
377
+ (self.rotary_dim - 2))
378
+ inv_freq = self._compute_inv_freq(base)
379
+ t = torch.arange(max_len, dtype=torch.float)
380
+
381
+ freqs = torch.einsum("i,j -> ij", t, inv_freq)
382
+ cos = freqs.cos()
383
+ sin = freqs.sin()
384
+ cache = torch.cat((cos, sin), dim=-1)
385
+ return cache
386
+
387
+
388
+ # Inverse dim formula to find dim based on number of rotations
389
+ def _yarn_find_correction_dim(num_rotations: int,
390
+ dim: int,
391
+ base: float = 10000,
392
+ max_position_embeddings: int = 2048) -> float:
393
+ return (dim * math.log(max_position_embeddings /
394
+ (num_rotations * 2 * math.pi))) / (2 *
395
+ math.log(base))
396
+
397
+
398
+ # Find dim range bounds based on rotations
399
+ def _yarn_find_correction_range(
400
+ low_rot: int,
401
+ high_rot: int,
402
+ dim: int,
403
+ base: float = 10000,
404
+ max_position_embeddings: int = 2048) -> Tuple[int, int]:
405
+ low = math.floor(
406
+ _yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings))
407
+ high = math.ceil(
408
+ _yarn_find_correction_dim(high_rot, dim, base,
409
+ max_position_embeddings))
410
+ return max(low, 0), min(high, dim - 1) # Clamp values just in case
411
+
412
+
413
+ def _yarn_linear_ramp_mask(low: float, high: float, dim: int,
414
+ dtype: torch.dtype) -> torch.Tensor:
415
+ if low == high:
416
+ high += 0.001 # Prevent singularity
417
+
418
+ linear_func = (torch.arange(dim, dtype=dtype) - low) / (high - low)
419
+ ramp_func = torch.clamp(linear_func, 0, 1)
420
+ return ramp_func
421
+
422
+
423
+ def _yarn_get_mscale(scale: float = 1) -> float:
424
+ if scale <= 1:
425
+ return 1.0
426
+ return 0.1 * math.log(scale) + 1.0
427
+
428
+
429
+ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
430
+ """RotaryEmbedding extended with YaRN method.
431
+
432
+ Credits to Peng et al. github.com/jquesnelle/yarn
433
+ """
434
+
435
+ def __init__(
436
+ self,
437
+ head_size: int,
438
+ rotary_dim: int,
439
+ max_position_embeddings: int,
440
+ base: int,
441
+ is_neox_style: bool,
442
+ scaling_factor: float,
443
+ dtype: torch.dtype,
444
+ *,
445
+ extrapolation_factor: float = 1,
446
+ attn_factor: float = 1,
447
+ beta_fast: int = 32,
448
+ beta_slow: int = 1,
449
+ ) -> None:
450
+ self.scaling_factor = scaling_factor
451
+ self.extrapolation_factor = extrapolation_factor
452
+ self.attn_factor = attn_factor
453
+ self.beta_fast = beta_fast
454
+ self.beta_slow = beta_slow
455
+ # Get n-d magnitude scaling corrected for interpolation
456
+ self.mscale = float(
457
+ _yarn_get_mscale(self.scaling_factor) * attn_factor)
458
+ super().__init__(head_size, rotary_dim, max_position_embeddings, base,
459
+ is_neox_style, dtype)
460
+
461
+ def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
462
+ pos_freqs = self.base**(
463
+ torch.arange(0, self.rotary_dim, 2, dtype=torch.float) /
464
+ self.rotary_dim)
465
+ inv_freq_extrapolation = 1.0 / pos_freqs
466
+ inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)
467
+
468
+ low, high = _yarn_find_correction_range(self.beta_fast, self.beta_slow,
469
+ self.rotary_dim, self.base,
470
+ self.max_position_embeddings)
471
+ # Get n-d rotational scaling corrected for extrapolation
472
+ inv_freq_mask = (1 - _yarn_linear_ramp_mask(
473
+ low, high, self.rotary_dim // 2,
474
+ dtype=torch.float)) * self.extrapolation_factor
475
+ inv_freq = inv_freq_interpolation * (
476
+ 1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
477
+ return inv_freq
478
+
479
+ def _compute_cos_sin_cache(self) -> torch.Tensor:
480
+ inv_freq = self._compute_inv_freq(self.scaling_factor)
481
+ t = torch.arange(self.max_position_embeddings * self.scaling_factor,
482
+ dtype=torch.float32)
483
+ freqs = torch.einsum("i,j -> ij", t, inv_freq)
484
+ cos = (freqs.cos() * self.mscale)
485
+ sin = (freqs.sin() * self.mscale)
486
+ cache = torch.cat((cos, sin), dim=-1)
487
+ return cache
488
+
489
+
490
+ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
491
+ """Phi3 family of models scaled rotary embedding.
492
+
493
+ Based on the original RotaryEmbedding implementation.
494
+ """
495
+
496
+ def __init__(
497
+ self,
498
+ head_size: int,
499
+ rotary_dim: int,
500
+ max_position_embeddings: int,
501
+ original_max_position_embeddings: int,
502
+ base: int,
503
+ is_neox_style: bool,
504
+ dtype: torch.dtype,
505
+ short_factor: List[float],
506
+ long_factor: List[float],
507
+ short_mscale: Optional[float] = None,
508
+ long_mscale: Optional[float] = None,
509
+ ):
510
+ super().__init__()
511
+
512
+ if rotary_dim != head_size:
513
+ raise ValueError(
514
+ f"`Phi3LongRoPEScaledRotaryEmbedding` does not support \
515
+ rotary_dim != head_size ({rotary_dim}!={head_size}).")
516
+ if is_neox_style is False:
517
+ raise ValueError(
518
+ "`Phi3LongRoPEScaledRotaryEmbedding` only supports neox_style."
519
+ )
520
+
521
+ self.head_size = head_size
522
+ self.max_position_embeddings = max_position_embeddings
523
+ self.original_max_position_embeddings = original_max_position_embeddings
524
+ self.base = base
525
+ self.short_factor = short_factor
526
+ self.long_factor = long_factor
527
+
528
+ scale = self.max_position_embeddings / \
529
+ self.original_max_position_embeddings
530
+ if scale <= 1.0:
531
+ scaling_factor = 1.0
532
+ else:
533
+ scaling_factor = math.sqrt(
534
+ 1 + math.log(scale) /
535
+ math.log(self.original_max_position_embeddings))
536
+ if short_mscale is None:
537
+ short_mscale = scaling_factor
538
+ if long_mscale is None:
539
+ long_mscale = scaling_factor
540
+
541
+ self.short_mscale = short_mscale
542
+ self.long_mscale = long_mscale
543
+
544
+ short_cache = self._compute_cos_sin_cache(
545
+ original_max_position_embeddings, short_factor, short_mscale)
546
+ short_cache = short_cache.to(dtype)
547
+
548
+ long_cache = self._compute_cos_sin_cache(max_position_embeddings,
549
+ long_factor, long_mscale)
550
+ long_cache = long_cache.to(dtype)
551
+
552
+ long_short_cache = torch.cat([short_cache, long_cache], dim=0)
553
+ self.register_buffer("long_short_cos_sin_cache",
554
+ long_short_cache,
555
+ persistent=False)
556
+
557
+ def _compute_inv_freq(self, rescale_factors: List[float]) -> torch.Tensor:
558
+ rescale_factors = torch.tensor(rescale_factors, dtype=torch.float32)
559
+ inv_freq = 1.0 / (rescale_factors * (self.base**(torch.arange(
560
+ 0, self.head_size, 2, dtype=torch.float) / self.head_size)))
561
+ return inv_freq
562
+
563
+ def _compute_cos_sin_cache(
564
+ self,
565
+ max_position_embeddings: int,
566
+ rescale_factors: List[float],
567
+ mscale: float,
568
+ ) -> torch.Tensor:
569
+ inv_freq = self._compute_inv_freq(rescale_factors)
570
+ t = torch.arange(max_position_embeddings, dtype=torch.float)
571
+ freqs = torch.einsum("i,j -> ij", t, inv_freq)
572
+ cos = freqs.cos() * mscale
573
+ sin = freqs.sin() * mscale
574
+ cache = torch.cat((cos, sin), dim=-1)
575
+ return cache
576
+
577
+ def forward(
578
+ self,
579
+ positions: torch.Tensor,
580
+ query: torch.Tensor,
581
+ key: torch.Tensor,
582
+ offsets: Optional[torch.Tensor] = None,
583
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
584
+ query = query.view(*query.shape[:-1], -1, self.head_size)
585
+ key = key.view(*key.shape[:-1], -1, self.head_size)
586
+
587
+ k = self.original_max_position_embeddings
588
+ long_prompt_offset = (torch.any(positions > k).float() *
589
+ torch.full_like(positions, k)).long()
590
+ idx = (torch.add(positions, long_prompt_offset)
591
+ if long_prompt_offset is not None else positions)
592
+ idx = torch.add(idx, offsets) if offsets is not None else idx
593
+ cos_sin = torch.index_select(self.long_short_cos_sin_cache, 0, idx)
594
+
595
+ cos, sin = cos_sin.chunk(2, dim=-1)
596
+ cos = cos.repeat(1, 2).unsqueeze(-2)
597
+ sin = sin.repeat(1, 2).unsqueeze(-2)
598
+
599
+ query = query * cos + _rotate_neox(query) * sin
600
+ key = key * cos + _rotate_neox(key) * sin
601
+
602
+ return query.flatten(-2), key.flatten(-2)
603
+
604
+
605
+ def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
606
+ if scale <= 1:
607
+ return 1.0
608
+ return 0.1 * mscale * math.log(scale) + 1.0
609
+
610
+
611
+ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
612
+ """RotaryEmbedding extended with YaRN method.
613
+
614
+ Credits to Peng et al. github.com/jquesnelle/yarn
615
+ """
616
+
617
+ def __init__(
618
+ self,
619
+ head_size: int,
620
+ rotary_dim: int,
621
+ max_position_embeddings: int,
622
+ base: int,
623
+ is_neox_style: bool,
624
+ scaling_factor: float,
625
+ dtype: torch.dtype,
626
+ *,
627
+ extrapolation_factor: float = 1,
628
+ attn_factor: float = 1,
629
+ beta_fast: int = 32,
630
+ beta_slow: int = 1,
631
+ mscale: float = 1,
632
+ mscale_all_dim: float = 0,
633
+ ) -> None:
634
+ self.scaling_factor = scaling_factor
635
+ self.extrapolation_factor = extrapolation_factor
636
+ self.attn_factor = attn_factor
637
+ self.beta_fast = beta_fast
638
+ self.beta_slow = beta_slow
639
+ # Get n-d magnitude scaling corrected for interpolation.
640
+ self.mscale = float(
641
+ yarn_get_mscale(self.scaling_factor, float(mscale)) /
642
+ yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) *
643
+ attn_factor)
644
+ super().__init__(head_size, rotary_dim, max_position_embeddings, base,
645
+ is_neox_style, dtype)
646
+
647
+ def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
648
+ pos_freqs = self.base**(torch.arange(
649
+ 0, self.rotary_dim, 2, dtype=torch.float, device="cuda") /
650
+ self.rotary_dim)
651
+ inv_freq_extrapolation = 1.0 / pos_freqs
652
+ inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)
653
+
654
+ low, high = _yarn_find_correction_range(self.beta_fast, self.beta_slow,
655
+ self.rotary_dim, self.base,
656
+ self.max_position_embeddings)
657
+ # Get n-d rotational scaling corrected for extrapolation
658
+ inv_freq_mask = (1 - _yarn_linear_ramp_mask(
659
+ low, high, self.rotary_dim // 2,
660
+ dtype=torch.float)) * self.extrapolation_factor
661
+ inv_freq = inv_freq_interpolation * (
662
+ 1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
663
+ return inv_freq
664
+
665
+ def _compute_cos_sin_cache(self) -> torch.Tensor:
666
+ inv_freq = self._compute_inv_freq(self.scaling_factor)
667
+ t = torch.arange(self.max_position_embeddings * self.scaling_factor,
668
+ device="cuda",
669
+ dtype=torch.float32)
670
+ freqs = torch.einsum("i,j -> ij", t, inv_freq)
671
+ cos = (freqs.cos() * self.mscale)
672
+ sin = (freqs.sin() * self.mscale)
673
+ cache = torch.cat((cos, sin), dim=-1)
674
+ return cache
675
+
676
+ def forward(
677
+ self,
678
+ positions: torch.Tensor,
679
+ query: torch.Tensor,
680
+ key: torch.Tensor,
681
+ offsets: Optional[torch.Tensor] = None,
682
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
683
+ """PyTorch-native implementation equivalent to forward()."""
684
+ query_rot = query[..., :self.rotary_dim]
685
+ key_rot = key[..., :self.rotary_dim]
686
+ if self.rotary_dim < self.head_size:
687
+ query_pass = query[..., self.rotary_dim:]
688
+ key_pass = key[..., self.rotary_dim:]
689
+
690
+ self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(
691
+ positions.device)
692
+ cos_sin = self.cos_sin_cache[torch.add(positions, offsets)
693
+ if offsets is not None else positions]
694
+ cos, sin = cos_sin.chunk(2, dim=-1)
695
+ if self.is_neox_style:
696
+ # NOTE(woosuk): Here we assume that the positions tensor has the
697
+ # shape [batch_size, seq_len].
698
+ cos = cos.repeat(1, 1, 2).unsqueeze(-2)
699
+ sin = sin.repeat(1, 1, 2).unsqueeze(-2)
700
+ else:
701
+ cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
702
+ sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
703
+
704
+ rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj
705
+ query_rot = query_rot * cos + rotate_fn(query_rot) * sin
706
+ key_rot = key_rot * cos + rotate_fn(key_rot) * sin
707
+
708
+ if self.rotary_dim < self.head_size:
709
+ query = torch.cat((query_rot, query_pass), dim=-1)
710
+ key = torch.cat((key_rot, key_pass), dim=-1)
711
+ else:
712
+ query = query_rot
713
+ key = key_rot
714
+ return query, key
715
+
716
+
717
+ class Llama3RotaryEmbedding(RotaryEmbedding):
718
+
719
+ def __init__(
720
+ self,
721
+ head_size: int,
722
+ rotary_dim: int,
723
+ max_position_embeddings: int,
724
+ base: int,
725
+ is_neox_style: bool,
726
+ dtype: torch.dtype,
727
+ scaling_factor: float,
728
+ low_freq_factor: float,
729
+ high_freq_factor: float,
730
+ orig_max_position: int,
731
+ ) -> None:
732
+ self.scaling_factor = scaling_factor
733
+ self.low_freq_factor = low_freq_factor
734
+ self.high_freq_factor = high_freq_factor
735
+ self.orig_max_position = orig_max_position
736
+ super().__init__(head_size, rotary_dim, max_position_embeddings, base,
737
+ is_neox_style, dtype)
738
+
739
+ def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
740
+ inv_freqs = super()._compute_inv_freq(base)
741
+ low_freq_wavelen = self.orig_max_position / self.low_freq_factor
742
+ high_freq_wavelen = self.orig_max_position / self.high_freq_factor
743
+
744
+ wave_len = 2 * math.pi / inv_freqs
745
+ if self.low_freq_factor != self.high_freq_factor:
746
+ smooth = (self.orig_max_position / wave_len - self.low_freq_factor
747
+ ) / (self.high_freq_factor - self.low_freq_factor)
748
+ else:
749
+ smooth = 0
750
+ new_freqs = torch.where(
751
+ wave_len < high_freq_wavelen,
752
+ inv_freqs,
753
+ torch.where(
754
+ wave_len > low_freq_wavelen,
755
+ inv_freqs / self.scaling_factor,
756
+ (1 - smooth) * inv_freqs / self.scaling_factor +
757
+ smooth * inv_freqs,
758
+ ),
759
+ )
760
+ return new_freqs
761
+
762
+
763
+ class MRotaryEmbedding(RotaryEmbedding):
764
+ """Rotary Embedding with Multimodal Sections."""
765
+
766
+ def __init__(
767
+ self,
768
+ head_size: int,
769
+ rotary_dim: int,
770
+ max_position_embeddings: int,
771
+ base: int,
772
+ is_neox_style: bool,
773
+ dtype: torch.dtype,
774
+ mrope_section: Optional[List[int]] = None,
775
+ ) -> None:
776
+ # In Qwen2.5-VL, the maximum index value is related to the duration of
777
+ # the input video. We enlarge max_position_embeddings to 4 times to get
778
+ # a larger the cos and sin cache.
779
+ self.cache_max_position_num = max_position_embeddings * 4
780
+ super().__init__(head_size, rotary_dim, self.cache_max_position_num,
781
+ base, is_neox_style, dtype)
782
+
783
+ self.mrope_section = mrope_section
784
+ if self.mrope_section:
785
+ assert sum(self.mrope_section) == rotary_dim // 2
786
+
787
+ def forward(
788
+ self,
789
+ positions: torch.Tensor,
790
+ query: torch.Tensor,
791
+ key: torch.Tensor,
792
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
793
+ """PyTorch-native implementation equivalent to forward().
794
+
795
+ Args:
796
+ positions:
797
+ [num_tokens,] (text only) or
798
+ [3, num_tokens] (T/H/W positions with multimodal inputs)
799
+ query: [num_tokens, num_heads * head_size]
800
+ key: [num_tokens, num_kv_heads * head_size]
801
+ """
802
+ assert positions.ndim == 1 or positions.ndim == 2
803
+
804
+ num_tokens = positions.shape[-1]
805
+ cos_sin = self.cos_sin_cache[positions]
806
+ cos, sin = cos_sin.chunk(2, dim=-1)
807
+ if positions.ndim == 2:
808
+ assert self.mrope_section
809
+
810
+ cos = torch.cat([
811
+ m[i]
812
+ for i, m in enumerate(cos.split(self.mrope_section, dim=-1))
813
+ ],
814
+ dim=-1)
815
+ sin = torch.cat([
816
+ m[i]
817
+ for i, m in enumerate(sin.split(self.mrope_section, dim=-1))
818
+ ],
819
+ dim=-1)
820
+
821
+ query_shape = query.shape
822
+ query = query.view(num_tokens, -1, self.head_size)
823
+ query_rot = query[..., :self.rotary_dim]
824
+ query_pass = query[..., self.rotary_dim:]
825
+ query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style)
826
+ query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
827
+
828
+ key_shape = key.shape
829
+ key = key.view(num_tokens, -1, self.head_size)
830
+ key_rot = key[..., :self.rotary_dim]
831
+ key_pass = key[..., self.rotary_dim:]
832
+ key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style)
833
+ key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
834
+ return query, key
835
+
836
+ @staticmethod
837
+ def get_input_positions(
838
+ input_tokens: List[int],
839
+ hf_config: PretrainedConfig,
840
+ image_grid_thw: Union[List[List[int]], torch.Tensor],
841
+ video_grid_thw: Union[List[List[int]], torch.Tensor],
842
+ second_per_grid_ts: Optional[List[float]] = None,
843
+ context_len: int = 0,
844
+ seq_len: Optional[int] = None,
845
+ ) -> Tuple[List[List[int]], int]:
846
+ """Get mrope input positions and delta value."""
847
+
848
+ llm_positions, mrope_position_delta = \
849
+ MRotaryEmbedding.get_input_positions_tensor(
850
+ input_tokens=input_tokens,
851
+ hf_config=hf_config,
852
+ image_grid_thw=image_grid_thw,
853
+ video_grid_thw=video_grid_thw,
854
+ second_per_grid_ts=second_per_grid_ts,
855
+ context_len=context_len,
856
+ seq_len=seq_len,
857
+ )
858
+
859
+ return llm_positions.tolist(), mrope_position_delta
860
+
861
+ @staticmethod
862
+ def get_input_positions_tensor(
863
+ input_tokens: List[int],
864
+ hf_config: PretrainedConfig,
865
+ image_grid_thw: Union[List[List[int]], torch.Tensor],
866
+ video_grid_thw: Union[List[List[int]], torch.Tensor],
867
+ second_per_grid_ts: Optional[List[float]] = None,
868
+ context_len: int = 0,
869
+ seq_len: Optional[int] = None,
870
+ ) -> Tuple[torch.Tensor, int]:
871
+ """Get mrope input positions and delta value."""
872
+
873
+ image_token_id = hf_config.image_token_id
874
+ video_token_id = hf_config.video_token_id
875
+ vision_start_token_id = hf_config.vision_start_token_id
876
+ spatial_merge_size = hf_config.vision_config.spatial_merge_size
877
+ tokens_per_second = getattr(hf_config.vision_config,
878
+ "tokens_per_second", 1.0)
879
+
880
+ if isinstance(image_grid_thw, torch.Tensor):
881
+ image_grid_thw = image_grid_thw.tolist()
882
+ if isinstance(video_grid_thw, torch.Tensor):
883
+ video_grid_thw = video_grid_thw.tolist()
884
+
885
+ input_tokens_tensor = torch.tensor(input_tokens)
886
+ vision_start_indices = torch.argwhere(
887
+ input_tokens_tensor == vision_start_token_id).squeeze(1)
888
+ vision_tokens = input_tokens_tensor[vision_start_indices + 1]
889
+ image_nums = (vision_tokens == image_token_id).sum()
890
+ video_nums = (vision_tokens == video_token_id).sum()
891
+ llm_pos_ids_list: list = []
892
+
893
+ st = 0
894
+ remain_images, remain_videos = image_nums, video_nums
895
+
896
+ image_index, video_index = 0, 0
897
+ for _ in range(image_nums + video_nums):
898
+ video_second_per_grid_t = 0.0
899
+ if image_token_id in input_tokens and remain_images > 0:
900
+ ed_image = input_tokens.index(image_token_id, st)
901
+ else:
902
+ ed_image = len(input_tokens) + 1
903
+ if video_token_id in input_tokens and remain_videos > 0:
904
+ ed_video = input_tokens.index(video_token_id, st)
905
+ else:
906
+ ed_video = len(input_tokens) + 1
907
+ if ed_image < ed_video:
908
+ t, h, w = (
909
+ image_grid_thw[image_index][0],
910
+ image_grid_thw[image_index][1],
911
+ image_grid_thw[image_index][2],
912
+ )
913
+ image_index += 1
914
+ remain_images -= 1
915
+ ed = ed_image
916
+ else:
917
+ t, h, w = (
918
+ video_grid_thw[video_index][0],
919
+ video_grid_thw[video_index][1],
920
+ video_grid_thw[video_index][2],
921
+ )
922
+ video_second_per_grid_t = 1.0
923
+ if second_per_grid_ts is not None:
924
+ video_second_per_grid_t = second_per_grid_ts[video_index]
925
+ video_index += 1
926
+ remain_videos -= 1
927
+ ed = ed_video
928
+
929
+ llm_grid_t, llm_grid_h, llm_grid_w = \
930
+ t, h // spatial_merge_size, w // spatial_merge_size
931
+ text_len = ed - st
932
+
933
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(
934
+ llm_pos_ids_list) > 0 else 0
935
+ llm_pos_ids_list.append(
936
+ torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
937
+
938
+ t_index = (torch.arange(llm_grid_t).view(-1, 1).expand(
939
+ -1, llm_grid_h * llm_grid_w) * video_second_per_grid_t *
940
+ tokens_per_second).long().flatten()
941
+
942
+ h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(
943
+ llm_grid_t, -1, llm_grid_w).flatten()
944
+ w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(
945
+ llm_grid_t, llm_grid_h, -1).flatten()
946
+ llm_pos_ids_list.append(
947
+ torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
948
+ st = ed + llm_grid_t * llm_grid_h * llm_grid_w
949
+
950
+ if st < len(input_tokens):
951
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(
952
+ llm_pos_ids_list) > 0 else 0
953
+ text_len = len(input_tokens) - st
954
+ llm_pos_ids_list.append(
955
+ torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
956
+
957
+ llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
958
+ mrope_position_delta = (llm_positions.max() + 1 -
959
+ len(input_tokens)).item()
960
+ llm_positions = llm_positions[:, context_len:seq_len]
961
+
962
+ return llm_positions, mrope_position_delta
963
+
964
+ @staticmethod
965
+ def get_next_input_positions(
966
+ mrope_position_delta: int,
967
+ context_len: int,
968
+ seq_len: int,
969
+ ) -> List[List[int]]:
970
+ return [
971
+ list(
972
+ range(context_len + mrope_position_delta,
973
+ seq_len + mrope_position_delta)) for _ in range(3)
974
+ ]
975
+
976
+ @staticmethod
977
+ def get_next_input_positions_tensor(
978
+ mrope_position_delta: int,
979
+ context_len: int,
980
+ seq_len: int,
981
+ ) -> torch.Tensor:
982
+ return torch.arange(
983
+ mrope_position_delta + context_len,
984
+ mrope_position_delta + seq_len,
985
+ ).expand(3, -1)
986
+
987
+
988
+ _ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}
989
+
990
+
991
+ def get_rope(
992
+ head_size: int,
993
+ rotary_dim: int,
994
+ max_position: int,
995
+ base: int,
996
+ is_neox_style: bool = True,
997
+ rope_scaling: Optional[Dict[str, Any]] = None,
998
+ dtype: Optional[torch.dtype] = None,
999
+ partial_rotary_factor: float = 1.0,
1000
+ ) -> RotaryEmbedding:
1001
+ if dtype is None:
1002
+ dtype = torch.get_default_dtype()
1003
+ if rope_scaling is not None:
1004
+ # Transforms every value that is a list into a tuple for caching calls
1005
+ rope_scaling_tuple = {
1006
+ k: tuple(v) if isinstance(v, list) else v
1007
+ for k, v in rope_scaling.items()
1008
+ }
1009
+ rope_scaling_args = tuple(rope_scaling_tuple.items())
1010
+ else:
1011
+ rope_scaling_args = None
1012
+ if partial_rotary_factor < 1.0:
1013
+ rotary_dim = int(rotary_dim * partial_rotary_factor)
1014
+ key = (head_size, rotary_dim, max_position, base, is_neox_style,
1015
+ rope_scaling_args, dtype)
1016
+ if key in _ROPE_DICT:
1017
+ return _ROPE_DICT[key]
1018
+
1019
+ if rope_scaling is None:
1020
+ rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
1021
+ is_neox_style, dtype)
1022
+ else:
1023
+ scaling_type = rope_scaling["rope_type"]
1024
+
1025
+ if scaling_type == "llama3":
1026
+ scaling_factor = rope_scaling["factor"]
1027
+ low_freq_factor = rope_scaling["low_freq_factor"]
1028
+ high_freq_factor = rope_scaling["high_freq_factor"]
1029
+ original_max_position = rope_scaling[
1030
+ "original_max_position_embeddings"]
1031
+ rotary_emb = Llama3RotaryEmbedding(head_size, rotary_dim,
1032
+ max_position, base,
1033
+ is_neox_style, dtype,
1034
+ scaling_factor, low_freq_factor,
1035
+ high_freq_factor,
1036
+ original_max_position)
1037
+ elif scaling_type == "default":
1038
+ if "mrope_section" in rope_scaling:
1039
+ rotary_emb = MRotaryEmbedding(
1040
+ head_size,
1041
+ rotary_dim,
1042
+ max_position,
1043
+ base,
1044
+ is_neox_style,
1045
+ dtype,
1046
+ mrope_section=rope_scaling["mrope_section"],
1047
+ )
1048
+ else:
1049
+ rotary_emb = RotaryEmbedding(
1050
+ head_size,
1051
+ rotary_dim,
1052
+ max_position,
1053
+ base,
1054
+ is_neox_style,
1055
+ dtype,
1056
+ )
1057
+ elif scaling_type == "linear":
1058
+ scaling_factor = rope_scaling["factor"]
1059
+ rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim,
1060
+ max_position, base,
1061
+ is_neox_style,
1062
+ scaling_factor, dtype)
1063
+ elif scaling_type == "dynamic":
1064
+ scaling_factor = rope_scaling["factor"]
1065
+ rotary_emb = DynamicNTKScalingRotaryEmbedding(
1066
+ head_size, rotary_dim, max_position, base, is_neox_style,
1067
+ scaling_factor, dtype)
1068
+ elif scaling_type == "yarn":
1069
+ scaling_factor = rope_scaling["factor"]
1070
+ original_max_position = rope_scaling[
1071
+ "original_max_position_embeddings"]
1072
+ extra_kwargs = {
1073
+ k: v
1074
+ for k, v in rope_scaling.items()
1075
+ if k in ("extrapolation_factor", "attn_factor", "beta_fast",
1076
+ "beta_slow")
1077
+ }
1078
+ rotary_emb = YaRNScalingRotaryEmbedding(head_size, rotary_dim,
1079
+ original_max_position,
1080
+ base, is_neox_style,
1081
+ scaling_factor, dtype,
1082
+ **extra_kwargs)
1083
+ elif scaling_type == "deepseek_yarn":
1084
+ scaling_factor = rope_scaling["factor"]
1085
+ original_max_position = rope_scaling[
1086
+ "original_max_position_embeddings"]
1087
+ # assert max_position == original_max_position * scaling_factor
1088
+ extra_kwargs = {
1089
+ k: v
1090
+ for k, v in rope_scaling.items()
1091
+ if k in ("extrapolation_factor", "attn_factor", "beta_fast",
1092
+ "beta_slow", "mscale", "mscale_all_dim")
1093
+ }
1094
+ rotary_emb = DeepseekScalingRotaryEmbedding(
1095
+ head_size, rotary_dim, original_max_position, base,
1096
+ is_neox_style, scaling_factor, dtype, **extra_kwargs)
1097
+ elif scaling_type == "longrope":
1098
+ short_factor = rope_scaling["short_factor"]
1099
+ long_factor = rope_scaling["long_factor"]
1100
+ original_max_position = rope_scaling[
1101
+ "original_max_position_embeddings"]
1102
+ extra_kwargs = {
1103
+ k: v
1104
+ for k, v in rope_scaling.items()
1105
+ if k in ("short_mscale", "long_mscale")
1106
+ }
1107
+ rotary_emb = Phi3LongRoPEScaledRotaryEmbedding(
1108
+ head_size, rotary_dim, max_position, original_max_position,
1109
+ base, is_neox_style, dtype, short_factor, long_factor,
1110
+ **extra_kwargs)
1111
+ else:
1112
+ raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
1113
+ _ROPE_DICT[key] = rotary_emb
1114
+ return rotary_emb
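
As a quick orientation for the factory above, the sketch below (made-up config values, assuming a working vLLM install) asks get_rope for a linearly scaled RoPE and applies it through the PyTorch-native path:

```python
import torch

from vllm.model_executor.layers.rotary_embedding import get_rope

rope = get_rope(
    head_size=128,
    rotary_dim=128,
    max_position=4096,
    base=10000,
    is_neox_style=True,
    rope_scaling={"rope_type": "linear", "factor": 2.0},
    dtype=torch.float32,
)
positions = torch.arange(8)        # 8 flattened token positions
query = torch.randn(8, 32 * 128)   # [num_tokens, num_heads * head_size]
key = torch.randn(8, 8 * 128)      # [num_tokens, num_kv_heads * head_size]
query, key = rope.forward_native(positions, query, key)
print(type(rope).__name__, query.shape, key.shape)
# LinearScalingRotaryEmbedding torch.Size([8, 4096]) torch.Size([8, 1024])
```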
.venv/lib/python3.11/site-packages/vllm/model_executor/layers/sampler.py ADDED
@@ -0,0 +1,1292 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ """A layer that samples the next tokens from the model's outputs."""
3
+ import itertools
4
+ import warnings
5
+ from dataclasses import dataclass
6
+ from importlib.util import find_spec
7
+ from math import inf
8
+ from typing import Dict, Iterator, List, Optional, Tuple, Union
9
+
10
+ import msgspec
11
+ import torch
12
+ import torch.nn as nn
13
+
14
+ import vllm.envs as envs
15
+ from vllm.model_executor.layers.utils import apply_penalties
16
+ from vllm.model_executor.sampling_metadata import (SamplingMetadata,
17
+ SamplingTensors,
18
+ SequenceGroupToSample)
19
+ from vllm.sampling_params import SamplingType
20
+ from vllm.sequence import (VLLM_INVALID_TOKEN_ID,
21
+ CompletionSequenceGroupOutput, Logprob,
22
+ PromptLogprobs, SampleLogprobs, SequenceOutput)
23
+ from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
24
+
25
+ if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"):
26
+ import flashinfer.sampling
27
+ # yapf: disable
28
+ from flashinfer.sampling import (
29
+ top_k_top_p_sampling_from_probs as flashinfer_top_k_top_p_sampling)
30
+
31
+ # yapf: enable
32
+ else:
33
+ flashinfer_top_k_top_p_sampling = None
34
+
35
+
36
+ def get_sampler() -> torch.nn.Module:
37
+ if envs.VLLM_USE_V1:
38
+ # Lazy import: the v1 package isn't distributed
39
+ from vllm.v1.sample.sampler import Sampler as V1Sampler
40
+ return V1Sampler()
41
+ return Sampler()
42
+
43
+
44
+ # (num_token_ids, num_parent_ids) per sequence group.
45
+ SampleResultType = List[Tuple[List[int], List[int]]]
46
+
47
+ # Types of temporary data structures used for
48
+ # computing sample_result
49
+ SampleMetadataType = Dict[SamplingType, Tuple[List[int],
50
+ List[SequenceGroupToSample]]]
51
+ MultinomialSamplesType = Dict[SamplingType, torch.Tensor]
52
+ SampleResultsDictType = Dict[int, Tuple[List[int], List[int]]]
53
+
54
+
55
+ # Encapsulates temporary data structures for computing
56
+ # sample_result.
57
+ #
58
+ # * For multi-step scheduling: must be returned
59
+ # by `Sampler.forward()` and used later to compute the pythonized
60
+ # sample_result
61
+ #
62
+ # * For single-step scheduling: consumed immediately
63
+ # inside `Sampler.forward()` to compute pythonized sample_result.
64
+ @dataclass
65
+ class SampleResultArgsType:
66
+ sample_metadata: SampleMetadataType
67
+ multinomial_samples: MultinomialSamplesType
68
+ sample_results_dict: SampleResultsDictType
69
+ sampling_metadata: SamplingMetadata
70
+ greedy_samples: Optional[torch.Tensor]
71
+ beam_search_logprobs: Optional[torch.Tensor]
72
+
73
+
74
+ # Union of non-deferred (single-step scheduling)
75
+ # vs deferred (multi-step scheduling)
76
+ # sample result types
77
+ MaybeDeferredSampleResultType = Union[SampleResultType, SampleResultArgsType]
78
+
79
+ # Abbreviation of the _sample() return type
80
+ SampleReturnType = Tuple[MaybeDeferredSampleResultType, Optional[torch.Tensor]]
81
+
82
+
83
+ class SamplerOutput(
84
+ msgspec.Struct,
85
+ omit_defaults=True, # type: ignore[call-arg]
86
+ array_like=True): # type: ignore[call-arg]
87
+ """For each sequence group, we generate a list of SequenceOutput objects,
88
+ each of which contains one possible candidate for the next token.
89
+
90
+ This data structure implements methods, so it can be used like a list, but
91
+ also has optional fields for device tensors.
92
+ """
93
+
94
+ outputs: List[CompletionSequenceGroupOutput]
95
+
96
+ # On-device tensor containing probabilities of each token.
97
+ sampled_token_probs: Optional[torch.Tensor] = None
98
+
99
+ # On-device tensor containing the logprobs of each token.
100
+ logprobs: Optional["torch.Tensor"] = None
101
+
102
+ # Holds either (1) the pythonized sampler result (single-step scheduling)
103
+ # or (2) what will be arguments for later deferred pythonization of the
104
+ # sampler result (multi-step scheduling)
105
+ deferred_sample_results_args: Optional[SampleResultArgsType] = None
106
+
107
+ # On-device tensor containing the sampled token ids.
108
+ sampled_token_ids: Optional[torch.Tensor] = None
109
+ # CPU tensor containing the sampled token ids. Used during multi-step to
110
+ # return the sampled token ids from last rank to AsyncLLMEngine to be
111
+ # 'broadcasted' to all other PP ranks for next step.
112
+ sampled_token_ids_cpu: Optional[torch.Tensor] = None
113
+
114
+ # Spec decode metrics populated by workers.
115
+ spec_decode_worker_metrics: Optional[SpecDecodeWorkerMetrics] = None
116
+
117
+ # Optional last hidden states from the model.
118
+ hidden_states: Optional[torch.Tensor] = None
119
+
120
+ # Optional prefill hidden states from the model
121
+ # (used for models like EAGLE).
122
+ prefill_hidden_states: Optional[torch.Tensor] = None
123
+
124
+ # Time taken in the forward pass for this output, across all workers.
125
+ model_forward_time: Optional[float] = None
126
+
127
+ # Time taken in the model execute function. This will include model forward,
128
+ # block/sync across workers, cpu-gpu sync time and sampling time.
129
+ model_execute_time: Optional[float] = None
130
+
131
+ def __getitem__(self, idx: int) -> CompletionSequenceGroupOutput:
132
+ return self.outputs[idx]
133
+
134
+ def __setitem__(self, idx: int, value):
135
+ self.outputs[idx] = value
136
+
137
+ def __iter__(self) -> Iterator[CompletionSequenceGroupOutput]:
138
+ return iter(self.outputs)
139
+
140
+ def __len__(self):
141
+ return len(self.outputs)
142
+
143
+ def __eq__(self, other: object):
144
+ return isinstance(other,
145
+ self.__class__) and self.outputs == other.outputs
146
+
147
+ def __repr__(self) -> str:
148
+ """Show the shape of a tensor instead of its values to reduce noise.
149
+ """
150
+ sampled_token_probs_repr = ("None" if self.sampled_token_probs is None
151
+ else self.sampled_token_probs.shape)
152
+ sampled_token_ids_repr = ("None" if self.sampled_token_ids is None else
153
+ self.sampled_token_ids.shape)
154
+ return (
155
+ f"SamplerOutput(outputs={self.outputs}, "
156
+ f"sampled_token_probs={sampled_token_probs_repr}, "
157
+ f"sampled_token_ids={sampled_token_ids_repr}, "
158
+ f"spec_decode_worker_metrics={self.spec_decode_worker_metrics})")
159
+
160
+
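
SamplerOutput is deliberately list-like; a trivial usage sketch (empty outputs, assuming a working vLLM install) is:

```python
from vllm.model_executor.layers.sampler import SamplerOutput

out = SamplerOutput(outputs=[])  # real entries are CompletionSequenceGroupOutput
print(len(out), list(out), out == SamplerOutput(outputs=[]))  # 0 [] True
```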
161
+ class Sampler(nn.Module):
162
+ """Samples the next tokens from the model's outputs.
163
+
164
+ This layer does the following:
165
+ 1. Discard the hidden states that are not used for sampling (i.e., all
166
+ tokens except the final one in each prompt).
167
+ 2. Compute the logits for the next tokens.
168
+ 3. Apply presence, frequency and repetition penalties.
169
+ 4. Apply temperature scaling.
170
+ 5. Apply top-p and top-k truncation.
171
+ 6. Sample the next tokens.
172
+ Here, each sequence group within the batch can have different sampling
173
+ parameters (e.g., sampling method, temperature, top-p, top-k, etc.).
174
+
175
+ The structure of the logits tensor is coupled with the seq_groups in
176
+ sampling_metadata. Typically, each sequence in each seq_group has one row in
177
+ logits for the next token to be sampled; however, for a seq_group with a
178
+ prompt request with the prompt_logprobs sampling parameter, there are rows
179
+ in logits for each token in the input prompt.
180
+ """
181
+
182
+ def __init__(self):
183
+ super().__init__()
184
+
185
+ # Whether or not the SamplerOutput should have on-device tensors
186
+ # containing the sampled token ids and probabilities. This is used by
187
+ # speculative decoding.
188
+ self.include_gpu_probs_tensor = False
189
+ self.should_modify_greedy_probs_inplace = False
190
+
191
+ def _init_sampling_tensors(
192
+ self,
193
+ logits: torch.Tensor,
194
+ sampling_metadata: SamplingMetadata,
195
+ ):
196
+ """The goal here is to reuse sampling tensors between similar decode
197
+ runs. This is possible because sampling logic does not change between
198
+ decodes of the same sequences.
199
+ """
200
+ _, vocab_size = logits.shape
201
+
202
+ # First free any existing stored sampling tensors.
203
+ # This is necessary because some sampling tensors may
204
+ # have pinned memory.
205
+ self._sampling_tensors = None
206
+
207
+ # Initialize new sampling tensors
208
+ (sampling_tensors, do_penalties, do_top_p_top_k,
209
+ do_min_p) = SamplingTensors.from_sampling_metadata(
210
+ sampling_metadata, vocab_size, logits.device, logits.dtype)
211
+
212
+ self._sampling_tensors = sampling_tensors
213
+ self._do_penalties = do_penalties
214
+ self._do_top_p_top_k = do_top_p_top_k
215
+ self._do_min_p = do_min_p
216
+
217
+ def forward(
218
+ self,
219
+ logits: torch.Tensor,
220
+ sampling_metadata: SamplingMetadata,
221
+ ) -> Optional[SamplerOutput]:
222
+ """
223
+ Single-step scheduling:
224
+ * Perform GPU-side sampling computation & compute
225
+ GPU-side logprobs tensor
226
+ * Pythonize sampling result & logprobs tensor
227
+
228
+ Multi-step scheduling:
229
+ * Perform GPU-side sampling computation & compute
230
+ GPU-side logprobs tensor
231
+ * Defer Pythonization of sampling result & logprobs
232
+ tensor
233
+ * Encapsulate arguments required for deferred Pythonization
234
+ in the :class:`SamplerOutput` structure
235
+
236
+ Args:
237
+ logits: (num_tokens, vocab_size).
238
+ sampling_metadata: Metadata for sampling.
239
+ """
240
+ assert logits is not None
241
+ _, vocab_size = logits.shape
242
+
243
+ # Prepare sampling tensors with pinned memory to avoid blocking.
244
+ if not sampling_metadata.reuse_sampling_tensors:
245
+ self._init_sampling_tensors(logits, sampling_metadata)
246
+ elif self._do_penalties:
247
+ # In this case, the sampling tensors logic depends on
248
+ # "output_tokens" of a sequence. As a result, we cannot
249
+ # reuse sampling tensors, since "output_tokens" changes
250
+ # between decode runs.
251
+ self._init_sampling_tensors(logits, sampling_metadata)
252
+
253
+ assert self._sampling_tensors is not None
254
+ sampling_tensors = self._sampling_tensors
255
+ do_penalties = self._do_penalties
256
+ do_top_p_top_k = self._do_top_p_top_k
257
+ do_min_p = self._do_min_p
258
+
259
+ logits = _apply_min_tokens_penalty(logits, sampling_metadata)
260
+
261
+ # Apply presence, frequency and repetition penalties.
262
+ if do_penalties:
263
+ logits = apply_penalties(logits, sampling_tensors.prompt_tokens,
264
+ sampling_tensors.output_tokens,
265
+ sampling_tensors.presence_penalties,
266
+ sampling_tensors.frequency_penalties,
267
+ sampling_tensors.repetition_penalties)
268
+
269
+ # Use float32 to apply temperature scaling.
270
+ # Use in-place division to avoid creating a new tensor.
271
+ logits = logits.to(torch.float)
272
+ logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1))
273
+
274
+ if do_top_p_top_k and flashinfer_top_k_top_p_sampling is None:
275
+ logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
276
+ sampling_tensors.top_ks)
277
+
278
+ if do_min_p:
279
+ logits = _apply_min_p(logits, sampling_tensors.min_ps)
280
+
281
+ # We use float32 for probabilities and log probabilities.
282
+ # Compute the probabilities.
283
+ probs = torch.softmax(logits, dim=-1, dtype=torch.float)
284
+ # Compute the log probabilities.
285
+ logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
286
+
287
+ # Sample the next tokens.
288
+ maybe_deferred_sample_results, maybe_sampled_tokens_tensor = _sample(
289
+ probs,
290
+ logprobs,
291
+ sampling_metadata,
292
+ sampling_tensors,
293
+ include_gpu_probs_tensor=self.include_gpu_probs_tensor,
294
+ modify_greedy_probs=self._should_modify_greedy_probs_inplace,
295
+ )
296
+
297
+ if self.include_gpu_probs_tensor:
298
+ # Since we will defer sampler result Pythonization,
299
+ # preserve GPU-side tensors in support of later
300
+ # deferred pythonization of logprobs
301
+ assert maybe_sampled_tokens_tensor is not None
302
+ on_device_tensors = (probs, logprobs, maybe_sampled_tokens_tensor)
303
+ else:
304
+ # Since Pythonization has already happened, don't preserve
305
+ # GPU-side tensors.
306
+ on_device_tensors = None
307
+
308
+ # Get the logprobs query results.
309
+ prompt_logprobs = None
310
+ sample_logprobs = None
311
+ if not sampling_metadata.skip_sampler_cpu_output:
312
+ # Pythonize logprobs now (GPU -> CPU); do not defer.
313
+ assert not isinstance(maybe_deferred_sample_results,
314
+ SampleResultArgsType)
315
+ prompt_logprobs, sample_logprobs = get_logprobs(
316
+ logprobs, sampling_metadata, maybe_deferred_sample_results)
317
+
318
+ return _build_sampler_output(
319
+ maybe_deferred_sample_results,
320
+ sampling_metadata,
321
+ prompt_logprobs,
322
+ sample_logprobs,
323
+ on_device_tensors=on_device_tensors,
324
+ skip_sampler_cpu_output=sampling_metadata.skip_sampler_cpu_output)
325
+
326
+ @property
327
+ def _should_modify_greedy_probs_inplace(self) -> bool:
328
+ """Whether or not the sampler should modify the probability distribution
329
+ of greedily-sampled tokens such that multinomial sampling would sample
330
+ the greedily-sampled token.
331
+
332
+ In other words, if True then we set the probability of the greedily-
333
+ sampled token to 1.
334
+
335
+ This is used by speculative decoding, which requires that the sampling
336
+ method be encoded into the probability distribution.
337
+ """
338
+ return self.should_modify_greedy_probs_inplace
339
+
340
+
341
+ def _apply_min_tokens_penalty(
342
+ logits: torch.Tensor,
343
+ sampling_metadata: SamplingMetadata,
344
+ ) -> torch.Tensor:
345
+ """Apply min_tokens penalty which sets stop tokens to -inf if min_tokens
346
+ have not been generated yet
347
+ """
348
+ # list of indices in logits that will be set to -inf
349
+ logits_to_penalize: List[Tuple[int, int]] = []
350
+ logits_applied = 0
351
+ for seq_group in sampling_metadata.seq_groups:
352
+ seq_ids = seq_group.seq_ids
353
+ sampling_params = seq_group.sampling_params
354
+
355
+ sample_indices = seq_group.sample_indices
356
+ logits_applied += len(sample_indices) + len(
357
+ seq_group.prompt_logprob_indices)
358
+ if not seq_group.do_sample:
359
+ continue
360
+
361
+ start_idx = sample_indices[0]
362
+ min_tokens = sampling_params.min_tokens
363
+ token_ids_to_penalize = sampling_params.all_stop_token_ids
364
+ if min_tokens > 0 and token_ids_to_penalize:
365
+ seqs_to_penalize: List[int] = []
366
+ for j, seq_id in enumerate(seq_ids):
367
+ seq_data = seq_group.seq_data[seq_id]
368
+ if len(seq_data.output_token_ids_array) < min_tokens:
369
+ seqs_to_penalize.append(j)
370
+
371
+ if seqs_to_penalize:
372
+ # convert to the index into logits
373
+ seqs_to_penalize = [start_idx + j for j in seqs_to_penalize]
374
+ # itertools.product pairs each seq index with every token id
375
+ logits_to_penalize.extend(
376
+ itertools.product(seqs_to_penalize, token_ids_to_penalize))
377
+
378
+ if logits_to_penalize:
379
+ # use zip and * to group indices along each dimension
380
+ # eg. [ (1,2), (1,3), (5,6) ] -> ( (1,1,5), (2,3,6) )
381
+ logits[tuple(zip(*logits_to_penalize))] = -float("inf")
382
+
383
+ # verifies that no rows in logits were missed unexpectedly
384
+ assert logits_applied == logits.shape[0]
385
+ return logits
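The `tuple(zip(*logits_to_penalize))` transpose above turns a list of `(row, token_id)` pairs into separate row/column index tuples suitable for advanced indexing. A tiny standalone example of the same trick (not vLLM code):

```python
import torch

logits = torch.zeros(3, 5)
pairs = [(0, 2), (0, 4), (2, 1)]      # (sequence row, stop-token id) pairs
rows, cols = zip(*pairs)              # (0, 0, 2) and (2, 4, 1)
logits[rows, cols] = float("-inf")    # equivalent to logits[tuple(zip(*pairs))]
```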
386
+
387
+
388
+ def _apply_top_k_top_p(
389
+ logits: torch.Tensor,
390
+ p: torch.Tensor,
391
+ k: torch.Tensor,
392
+ ) -> torch.Tensor:
393
+ logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
394
+
395
+ # Apply top-k.
396
+ top_k_mask = logits_sort.size(1) - k.to(torch.long)
397
+ # Gather the k-th largest logit of each row as the top-k threshold.
398
+ top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
399
+ top_k_mask = logits_sort < top_k_mask
400
+ logits_sort.masked_fill_(top_k_mask, -float("inf"))
401
+
402
+ # Apply top-p.
403
+ probs_sort = logits_sort.softmax(dim=-1)
404
+ probs_sum = probs_sort.cumsum(dim=-1)
405
+ top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
406
+ # Always keep at least one token.
407
+ top_p_mask[:, -1] = False
408
+ logits_sort.masked_fill_(top_p_mask, -float("inf"))
409
+
410
+ # Scatter the masked logits back to their original token order.
411
+ logits = torch.empty_like(logits_sort).scatter_(dim=-1,
412
+ index=logits_idx,
413
+ src=logits_sort)
414
+ return logits
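Because the sort above is ascending, the k-th largest logit sits at index `size - k`; the gathered value is a per-row top-k threshold, and the cumulative-sum mask then keeps only the smallest suffix of the sorted row whose probability mass exceeds `p`. A rough standalone trace of the same steps (single row, hypothetical values, not vLLM code):

```python
import torch

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
k, p = torch.tensor([2]), torch.tensor([0.9])

logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
kth_value = logits_sort.gather(1, (logits_sort.size(1) - k).unsqueeze(1))
logits_sort = logits_sort.masked_fill(logits_sort < kth_value, float("-inf"))

probs_sort = logits_sort.softmax(dim=-1)
top_p_mask = probs_sort.cumsum(dim=-1) <= 1 - p.unsqueeze(1)
top_p_mask[:, -1] = False  # always keep the most likely token
logits_sort = logits_sort.masked_fill(top_p_mask, float("-inf"))

# Scatter back to the original token order.
restored = torch.empty_like(logits_sort).scatter_(-1, logits_idx, logits_sort)
# With k=2 only the 2.0 and 1.0 logits survive; with p=0.9 both stay, while
# a smaller p (e.g. 0.5) would additionally drop the 1.0 logit.
```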
415
+
416
+
417
+ def _apply_min_p(
418
+ logits: torch.Tensor,
419
+ min_p: torch.Tensor,
420
+ ) -> torch.Tensor:
421
+ """
422
+ Adapted from
423
+ https://github.com/oobabooga/text-generation-webui/blob/3146124ec01f02c8fb1650a6517cf1b60b537aaf/modules/sampler_hijack.py#L16C17-L16C17
424
+ """
425
+ probs = torch.softmax(logits, dim=-1)
426
+ top_probs, _ = probs.max(dim=-1, keepdim=True)
427
+ scaled_min_p = min_p.unsqueeze_(dim=1) * top_probs
428
+ tokens_to_remove = probs < scaled_min_p
429
+ logits = logits.masked_fill_(tokens_to_remove, -float("inf"))
430
+
431
+ return logits
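Min-p keeps only tokens whose probability is at least `min_p` times that of the most likely token; a quick numeric check of the rule used above (illustration only):

```python
import torch

probs = torch.tensor([[0.60, 0.25, 0.10, 0.05]])
min_p = 0.2
# Threshold = 0.2 * 0.60 = 0.12, so only the 0.60 and 0.25 tokens survive.
keep = probs >= min_p * probs.max(dim=-1, keepdim=True).values
print(keep)  # tensor([[ True,  True, False, False]])
```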
432
+
433
+
434
+ def _greedy_sample(
435
+ selected_seq_groups: List[SequenceGroupToSample],
436
+ samples: torch.Tensor,
437
+ ) -> SampleResultType:
438
+ """Run greedy sampling on the given samples.
439
+
440
+ Args:
441
+ selected_seq_groups: A list of sequence groups batched.
442
+ samples: (num_selected_samples,) A tensor of samples. The length of
443
+ samples could be smaller than selected_seq_groups if
444
+ seq_group.do_sample is False.
445
+ Returns:
446
+ A list of (next_token_ids, parent_ids) tuples, with the same length as
447
+ selected_seq_groups. If the corresponding seq_group has
448
+ do_sample=False, its tuple is ([], []).
449
+ """
450
+ samples_lst = samples.tolist()
451
+ sample_idx = 0
452
+ results: SampleResultType = []
453
+ for seq_group in selected_seq_groups:
454
+ if not seq_group.do_sample:
455
+ results.append(([], []))
456
+ continue
457
+
458
+ seq_ids = seq_group.seq_ids
459
+ num_parent_seqs = len(seq_ids)
460
+ assert num_parent_seqs == 1, (
461
+ "Greedy sampling should have only one seq.")
462
+ parent_ids = list(range(num_parent_seqs))
463
+ next_token_ids = [samples_lst[sample_idx]]
464
+ results.append((next_token_ids, parent_ids))
465
+ sample_idx += num_parent_seqs
466
+ return results
467
+
468
+
469
+ def _random_sample(
470
+ selected_seq_groups: List[SequenceGroupToSample],
471
+ random_samples: torch.Tensor,
472
+ ) -> SampleResultType:
473
+ """Run random sampling on the given samples.
474
+
475
+ Args:
476
+ selected_seq_groups: A list of sequence groups batched.
477
+ random_samples: (num_selected_samples,) A tensor of samples. The
478
+ length of samples could be smaller than selected_seq_groups if
479
+ seq_group.do_sample is False.
480
+ Returns:
481
+ A list of (next_token_ids, parent_ids) tuples, with the same length as
481
+ selected_seq_groups. If the corresponding seq_group has
482
+ do_sample=False, its tuple is ([], []).
484
+ """
485
+ # Move the sampled tokens to the CPU before iterating over sequence groups.
486
+ random_samples = random_samples.cpu()
487
+ sample_idx = 0
488
+ results: SampleResultType = []
489
+ for seq_group in selected_seq_groups:
490
+ if not seq_group.do_sample:
491
+ results.append(([], []))
492
+ continue
493
+
494
+ seq_ids = seq_group.seq_ids
495
+ sampling_params = seq_group.sampling_params
496
+ is_prompt = seq_group.is_prompt
497
+ num_parent_seqs = len(seq_ids)
498
+ if is_prompt:
499
+ # Prompt phase.
500
+ parent_ids = [0] * sampling_params.n
501
+ next_token_ids = random_samples[
502
+ sample_idx, :sampling_params.n].tolist()
503
+ else:
504
+ # Generation phase.
505
+ parent_ids = list(range(num_parent_seqs))
506
+ next_token_ids = random_samples[sample_idx:sample_idx +
507
+ num_parent_seqs, 0].tolist()
508
+ results.append((next_token_ids, parent_ids))
509
+ sample_idx += num_parent_seqs
510
+ return results
511
+
512
+
513
+ def _beam_search_sample(
514
+ selected_seq_groups: List[SequenceGroupToSample],
515
+ logprobs: torch.Tensor,
516
+ ) -> SampleResultType:
517
+ """Run beam search sampling on the given samples.
518
+
519
+ Args:
520
+ selected_seq_groups: A list of sequence groups batched.
521
+ logprobs: (num_selected_samples, vocab_size,) A tensor of logprob
522
+ on selected sample indices.
523
+ Returns:
524
+ Tuple of (next_token_ids, parent_ids). The length of returned list is
525
+ same as the length of selected_seq_groups. If the corresponding
526
+ seq_group has do_sample=False, tuple contains ([], [])
527
+ """
528
+ # We sample 2 * beam_width candidates to make sure that with high
529
+ # probability we can get `beam_width` candidates in addition to
530
+ # the finished sequences for the next iteration. See
531
+ # https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/utils/beam_search.py#L557-L563
532
+ # for details. See also HF reference:
533
+ # https://github.com/huggingface/transformers/blob/a4dd53d88e4852f023332d284ff07a01afcd5681/src/transformers/generation/utils.py#L3063-L3065
534
+ #
535
+ # NOTE: Beam search is not vectorized, so its speed can be slower than
536
+ # other sampling methods.
537
+ sample_idx = 0
538
+ results: SampleResultType = []
539
+ for seq_group in selected_seq_groups:
540
+ if not seq_group.do_sample:
541
+ results.append(([], []))
542
+ continue
543
+
544
+ is_prompt = seq_group.is_prompt
545
+ seq_ids, sampling_params = seq_group.seq_ids, seq_group.sampling_params
546
+ num_parent_seqs = len(seq_ids)
547
+ beam_width = sampling_params.n
548
+ seq_group_logprobs = logprobs[sample_idx:sample_idx + num_parent_seqs]
549
+ if is_prompt:
550
+ # Prompt phase.
551
+ assert num_parent_seqs == 1, (
552
+ "Prompt input should have only one seq.")
553
+ parent_ids = [0] * (2 * beam_width)
554
+ _, next_token_ids = torch.topk(seq_group_logprobs[0],
555
+ 2 * beam_width)
556
+ next_token_ids = next_token_ids.tolist()
557
+ else:
558
+ # Generation phase.
559
+ cumulative_logprobs: List[float] = [
560
+ seq_group.seq_data[seq_id].cumulative_logprob
561
+ for seq_id in seq_ids
562
+ ]
563
+ cumulative_logprobs_tensor = torch.tensor(
564
+ cumulative_logprobs,
565
+ dtype=torch.float,
566
+ device=seq_group_logprobs.device)
567
+ seq_group_logprobs = (seq_group_logprobs +
568
+ cumulative_logprobs_tensor.unsqueeze(dim=1))
569
+ _, topk_ids = torch.topk(seq_group_logprobs.flatten(),
570
+ 2 * beam_width)
571
+ topk_ids = topk_ids.tolist()
572
+ vocab_size = seq_group_logprobs.size(-1)
573
+ parent_ids = [i // vocab_size for i in topk_ids]
574
+ next_token_ids = [i % vocab_size for i in topk_ids]
575
+ results.append((next_token_ids, parent_ids))
576
+ sample_idx += num_parent_seqs
577
+ assert sample_idx == logprobs.size(0)
578
+ return results
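In the generation phase above, taking top-k over the flattened `(num_beams * vocab_size)` scores and decoding `parent = idx // vocab_size`, `token = idx % vocab_size` selects the best (beam, token) pairs jointly. A small standalone illustration with toy numbers (not vLLM code):

```python
import torch

vocab_size = 5
# Cumulative log-prob of each (beam, next-token) pair for 2 beams.
scores = torch.tensor([[-1.0, -2.0, -0.5, -3.0, -4.0],
                       [-0.2, -5.0, -6.0, -0.7, -8.0]])
_, flat_idx = torch.topk(scores.flatten(), k=4)   # 2 * beam_width candidates
parent_ids = [i // vocab_size for i in flat_idx.tolist()]
token_ids = [i % vocab_size for i in flat_idx.tolist()]
print(parent_ids, token_ids)  # [1, 0, 1, 0] [0, 2, 3, 0]
```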
579
+
580
+
581
+ # torch.multinomial forces a GPU<->CPU sync.
582
+ # Therefore, we use an optimized implementation instead.
583
+ # Note that we always sample with replacement.
584
+ # probs will be modified in place, but this is fine, as we pass
585
+ # in a copy already.
586
+ def _multinomial(
587
+ probs: torch.Tensor,
588
+ num_samples: int,
589
+ seq_groups: Optional[List[SequenceGroupToSample]] = None,
590
+ ) -> torch.Tensor:
591
+ if num_samples > 1:
592
+ probs = probs.repeat_interleave(num_samples, dim=0)
593
+ q = torch.empty_like(probs)
594
+ if seq_groups is None:
595
+ q.exponential_()
596
+ else:
597
+ sample_idx = 0
598
+ for seq_group in seq_groups:
599
+ seq_ids = seq_group.seq_ids
600
+ stride = len(seq_ids) * num_samples
601
+ assert seq_group.generator is not None
602
+ q[sample_idx:sample_idx +
603
+ stride].exponential_(generator=seq_group.generator)
604
+ sample_idx += stride
605
+ return probs.div_(q).argmax(dim=1).view(-1, num_samples)
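The `probs.div_(q).argmax()` trick above is the standard "exponential race": with independent `E_i ~ Exp(1)`, `argmax_i p_i / E_i = argmin_i E_i / p_i`, and since `E_i / p_i ~ Exp(rate=p_i)`, index `i` wins with probability `p_i / sum_j p_j`. That is, it samples from the categorical distribution without a CPU sync. A small empirical check (illustration only):

```python
import torch

torch.manual_seed(0)
probs = torch.tensor([0.1, 0.2, 0.7])
draws = 200_000

q = torch.empty(draws, probs.numel()).exponential_()
samples = (probs / q).argmax(dim=1)
freq = torch.bincount(samples, minlength=3).float() / draws
print(freq)  # roughly tensor([0.1, 0.2, 0.7])
```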
606
+
607
+
608
+ def _top_k_top_p_multinomial_with_flashinfer(
609
+ probs: torch.Tensor, top_ks: torch.Tensor, top_ps: torch.Tensor,
610
+ num_samples: int, seq_groups: Optional[List[SequenceGroupToSample]]):
611
+ max_top_k_round = 32
612
+ if num_samples > 1:
613
+ probs = probs.repeat_interleave(num_samples, dim=0)
614
+ top_ks = top_ks.repeat_interleave(num_samples)
615
+ top_ps = top_ps.repeat_interleave(num_samples)
616
+ batch_size = probs.shape[0]
617
+ uniform_samples = torch.empty((max_top_k_round, batch_size),
618
+ device=probs.device)
619
+ if seq_groups is None:
620
+ uniform_samples.uniform_()
621
+ else:
622
+ sample_idx = 0
623
+ for seq_group in seq_groups:
624
+ seq_ids = seq_group.seq_ids
625
+ stride = len(seq_ids) * num_samples
626
+ assert seq_group.generator is not None
627
+ uniform_samples[:, sample_idx:sample_idx +
628
+ stride].uniform_(generator=seq_group.generator)
629
+ sample_idx += stride
630
+ batch_next_token_ids, success = flashinfer_top_k_top_p_sampling(
631
+ probs,
632
+ uniform_samples,
633
+ top_ks,
634
+ top_ps,
635
+ )
636
+ if not success.all():
637
+ warnings.warn("FlashInfer rejection sampling failed, fallback.",
638
+ stacklevel=1)
639
+ probs = flashinfer.sampling.top_k_renorm_prob(probs, top_ks)
640
+ probs = flashinfer.sampling.top_p_renorm_prob(probs, top_ps)
641
+ batch_next_token_ids = flashinfer.sampling.sampling_from_probs(
642
+ probs, uniform_samples[0])
643
+ return batch_next_token_ids.view(-1, num_samples)
644
+
645
+
646
+ def get_pythonized_sample_results(
647
+ sample_result_args: SampleResultArgsType) -> SampleResultType:
648
+ '''This function consumes GPU-side sampler results and computes
649
+ Pythonized CPU-side sampler results (GPU -> CPU sync.)
650
+
651
+ Single-step scheduling: this function is invoked at sampling-time
652
+ for immediate Pythonization.
653
+
654
+ Multi-step scheduling: Pythonization is deferred until after multiple
655
+ GPU-side steps have been completed.
656
+
657
+ Args:
658
+ sample_result_args: GPU-side inputs to the Pythonization process
659
+
660
+ Returns:
661
+ Pythonized sampler results
662
+ '''
663
+
664
+ (
665
+ sample_metadata,
666
+ sampling_metadata,
667
+ greedy_samples,
668
+ multinomial_samples,
669
+ beam_search_logprobs,
670
+ sample_results_dict,
671
+ ) = (
672
+ sample_result_args.sample_metadata,
673
+ sample_result_args.sampling_metadata,
674
+ sample_result_args.greedy_samples,
675
+ sample_result_args.multinomial_samples,
676
+ sample_result_args.beam_search_logprobs,
677
+ sample_result_args.sample_results_dict,
678
+ )
679
+
680
+ for sampling_type in SamplingType:
681
+ if sampling_type not in sample_metadata:
682
+ continue
683
+ (seq_group_id, seq_groups) = sample_metadata[sampling_type]
684
+ if sampling_type == SamplingType.GREEDY:
685
+ sample_results = _greedy_sample(seq_groups, greedy_samples)
686
+ elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
687
+ sample_results = _random_sample(seq_groups,
688
+ multinomial_samples[sampling_type])
689
+ elif sampling_type == SamplingType.BEAM:
690
+ sample_results = _beam_search_sample(seq_groups,
691
+ beam_search_logprobs)
692
+ sample_results_dict.update(zip(seq_group_id, sample_results))
693
+
694
+ return [
695
+ sample_results_dict.get(i, ([], []))
696
+ for i in range(len(sampling_metadata.seq_groups))
697
+ ]
698
+
699
+
700
+ def _sample_with_torch(
701
+ probs: torch.Tensor,
702
+ logprobs: torch.Tensor,
703
+ sampling_metadata: SamplingMetadata,
704
+ sampling_tensors: SamplingTensors,
705
+ include_gpu_probs_tensor: bool,
706
+ modify_greedy_probs: bool,
707
+ ) -> SampleReturnType:
708
+ '''Torch-oriented _sample() implementation.
709
+
710
+ Single-step scheduling:
711
+ * Perform GPU-side sampling computation
712
+ * Immediately Pythonize sampling result
713
+
714
+ Multi-step scheduling:
715
+ * Perform GPU-side sampling computation
716
+ * Defer Pythonization & preserve GPU-side
717
+ tensors required for Pythonization
718
+ '''
719
+
720
+ categorized_seq_group_ids: Dict[SamplingType, List[int]] = {
721
+ t: []
722
+ for t in SamplingType
723
+ }
724
+ categorized_sample_indices = sampling_metadata.categorized_sample_indices
725
+ for i, seq_group in enumerate(sampling_metadata.seq_groups):
726
+ sampling_params = seq_group.sampling_params
727
+ sampling_type = sampling_params.sampling_type
728
+ categorized_seq_group_ids[sampling_type].append(i)
729
+
730
+ sample_results_dict: SampleResultsDictType = {}
731
+ sample_metadata: SampleMetadataType = {}
732
+ multinomial_samples: MultinomialSamplesType = {}
733
+ greedy_samples: Optional[torch.Tensor] = None
734
+ beam_search_logprobs: Optional[torch.Tensor] = None
735
+
736
+ # Create output tensor for sampled token ids.
737
+ if include_gpu_probs_tensor:
738
+ sampled_token_ids_tensor = torch.full((logprobs.shape[0], 1),
739
+ VLLM_INVALID_TOKEN_ID,
740
+ dtype=torch.long,
741
+ device=logprobs.device)
742
+ else:
743
+ sampled_token_ids_tensor = None
744
+
745
+ # Counterintuitively, having two loops here is actually faster.
746
+ # The first loop can run without waiting on GPU<->CPU sync.
747
+ for sampling_type in SamplingType:
748
+ sample_indices = categorized_sample_indices[sampling_type]
749
+ num_tokens = len(sample_indices)
750
+ if num_tokens == 0:
751
+ continue
752
+
753
+ seq_group_id = categorized_seq_group_ids[sampling_type]
754
+ seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_id]
755
+ sample_metadata[sampling_type] = (seq_group_id, seq_groups)
756
+ long_sample_indices = sample_indices.long()
757
+ if sampling_type == SamplingType.GREEDY:
758
+ greedy_samples = torch.argmax(logprobs[long_sample_indices],
759
+ dim=-1)
760
+
761
+ if sampled_token_ids_tensor is not None:
762
+ # Store sampled tokens in output tensor.
763
+ sampled_token_ids_tensor[
764
+ long_sample_indices] = greedy_samples.unsqueeze(-1)
765
+
766
+ if modify_greedy_probs:
767
+ # If required, modify the probabilities such that sampling from
768
+ # the modified distribution would always sample the argmax
769
+ # token id.
770
+ _modify_greedy_probs_inplace(logprobs, probs,
771
+ long_sample_indices,
772
+ greedy_samples)
773
+
774
+ elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
775
+ max_n_in_batch = 1
776
+ for seq_group in seq_groups:
777
+ if seq_group.is_prompt:
778
+ sampling_params = seq_group.sampling_params
779
+ max_n_in_batch = max(max_n_in_batch, sampling_params.n)
780
+ seq_groups_arg = (None if sampling_type == SamplingType.RANDOM else
781
+ seq_groups)
782
+
783
+ if flashinfer_top_k_top_p_sampling is not None:
784
+ multinomial_samples[
785
+ sampling_type] = _top_k_top_p_multinomial_with_flashinfer(
786
+ probs[long_sample_indices],
787
+ sampling_tensors.top_ks[long_sample_indices],
788
+ sampling_tensors.top_ps[long_sample_indices],
789
+ max_n_in_batch,
790
+ seq_groups_arg,
791
+ )
792
+ else:
793
+ multinomial_samples[sampling_type] = _multinomial(
794
+ probs[long_sample_indices],
795
+ max_n_in_batch,
796
+ seq_groups=seq_groups_arg)
797
+
798
+ if sampled_token_ids_tensor is not None:
799
+ # Store sampled tokens in output tensor.
800
+ sampled_token_ids_tensor[long_sample_indices] = \
801
+ multinomial_samples[sampling_type].to(torch.long)
802
+
803
+ elif sampling_type == SamplingType.BEAM:
804
+ beam_search_logprobs = logprobs[sample_indices]
805
+ else:
806
+ raise ValueError(f"Unsupported sampling type: {sampling_type}")
807
+
808
+ # Encapsulate arguments for computing Pythonized sampler
809
+ # results, whether deferred or otherwise.
810
+ maybe_deferred_args = SampleResultArgsType(
811
+ sampling_metadata=sampling_metadata,
812
+ sample_metadata=sample_metadata,
813
+ multinomial_samples=multinomial_samples,
814
+ greedy_samples=greedy_samples,
815
+ beam_search_logprobs=beam_search_logprobs,
816
+ sample_results_dict=sample_results_dict)
817
+
818
+ if not sampling_metadata.skip_sampler_cpu_output:
819
+ # GPU<->CPU sync happens here.
820
+ # This also converts the sampler output to a Python object.
821
+ # Return Pythonized sampler result & sampled token ids
822
+ return get_pythonized_sample_results(
823
+ maybe_deferred_args), sampled_token_ids_tensor
824
+ else:
825
+ # Defer sampler result Pythonization; return deferred
826
+ # Pythonization args & sampled token ids
827
+ return (
828
+ maybe_deferred_args,
829
+ sampled_token_ids_tensor,
830
+ )
831
+
832
+
833
+ def _sample(
834
+ probs: torch.Tensor,
835
+ logprobs: torch.Tensor,
836
+ sampling_metadata: SamplingMetadata,
837
+ sampling_tensors: SamplingTensors,
838
+ include_gpu_probs_tensor: bool,
839
+ modify_greedy_probs: bool,
840
+ ) -> SampleReturnType:
841
+ """
842
+ Args:
843
+ probs: (num_query_tokens_in_batch, num_vocab)
844
+ logprobs: (num_query_tokens_in_batch, num_vocab)
845
+ sampling_metadata: The metadata for a batch for sampling.
846
+ sampling_tensors: Tensors that include sampling related metadata.
847
+
848
+ Returns:
849
+ (next_token_ids, parent_seq_ids) for each seq group in a batch.
850
+ If sampling is skipped, it returns ([], [])
851
+ sampled_token_ids_tensor: A tensor of sampled token ids.
852
+ """
853
+ return _sample_with_torch(
854
+ probs,
855
+ logprobs,
856
+ sampling_metadata,
857
+ sampling_tensors,
858
+ include_gpu_probs_tensor=include_gpu_probs_tensor,
859
+ modify_greedy_probs=modify_greedy_probs,
860
+ )
861
+
862
+
863
+ def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
864
+ """
865
+ This function calculates the ranks of the chosen tokens in a logprob tensor.
866
+
867
+ Args:
868
+ x (torch.Tensor): 2D logprob tensor of shape (N, M)
869
+ where N is the no. of tokens and M is the vocab dim.
870
+ indices (torch.Tensor): List of chosen token indices.
871
+
872
+ Returns:
873
+ torch.Tensor: 1D tensor of shape (N,) where N is the no. of tokens.
874
+ Each element in the returned tensor represents the rank
875
+ of the chosen token in the input logprob tensor.
876
+ """
877
+ vals = x[torch.arange(0, len(x), device=x.device, dtype=indices.dtype),
878
+ indices]
879
+ result = (x > vals[:, None])
880
+ del vals
881
+ return result.sum(1).add_(1)
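The rank of a chosen token is 1 plus the number of vocabulary entries with a strictly larger logprob, which is exactly what the comparison-and-sum above computes. A tiny check (illustration only):

```python
import torch

logprobs = torch.tensor([[-0.5, -1.0, -0.1],
                         [-2.0, -0.3, -0.9]])
chosen = torch.tensor([0, 2])
vals = logprobs[torch.arange(2), chosen]
ranks = (logprobs > vals[:, None]).sum(1) + 1
print(ranks)  # tensor([2, 2]): -0.5 is 2nd best in row 0, -0.9 is 2nd best in row 1
```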
882
+
883
+
884
+ def get_logprobs(
885
+ logprobs: torch.Tensor,
886
+ sampling_metadata: SamplingMetadata,
887
+ sample_results: SampleResultType,
888
+ ) -> Tuple[List[Optional[PromptLogprobs]], List[SampleLogprobs]]:
889
+ """Return sample logprobs and prompt logprobs.
890
+
891
+ The logic consists of 3 parts.
892
+ - Select indices to compute logprob from, ranks of token ids, and
893
+ the top k token ids from logprobs.
894
+ - Compute prompt logprobs if required.
895
+ - Compute sample logprobs if required.
896
+
897
+ Args:
898
+ logprobs: (num_query_tokens_across_batch, num_vocab). Each query token's
899
+ logprob per vocab. Sequence groups' query tokens are batched in a
900
+ single flattened tensor. For example, assuming there are N
901
+ seq groups, it is sorted by prefill tokens for seq_group_1 (if
902
+ prompt logprob is enabled), decode tokens for seq_group_1 (if
903
+ sampling is required), prefill tokens for seq_group_2, ...
904
+ sampling_metadata: The sampling metadata.
905
+ sample_results: (num_seq_groups) The tuple of (next_token_ids,
906
+ parent_ids) for each sequence group. When beam search is enabled,
907
+ sample_results can contain different number of seq_ids from
908
+ sampling_metadata.seq_groups. It is because beam search creates
909
+ 2 * BEAM_WIDTH number of samples (whereas there are only up to
910
+ BEAM_WIDTH number of seq_ids).
911
+
912
+ Returns:
913
+ A tuple of prompt and sample logprobs per sequence group in a batch.
914
+ """
915
+ # The index of query token to calculate logprobs. It includes both
916
+ # prompt and sample logprob indices.
917
+ query_indices: List[int] = []
918
+ # The next token ids to get the logprob value from.
919
+ next_token_ids: List[int] = []
920
+ # The largest requested number of logprobs. We compute as many logprobs as
921
+ # the largest number requested across the batch. If no sequence requests
922
+ # logprobs, this stays -1.
923
+ largest_num_logprobs = -1
924
+
925
+ # Select indices to compute logprob from, ranks of token ids, and the top
926
+ # k token ids from logprobs.
927
+ for (seq_group, sample_result) in zip(sampling_metadata.seq_groups,
928
+ sample_results):
929
+ sampling_params = seq_group.sampling_params
930
+
931
+ # Update indices and tokens for prompt logprobs.
932
+ if (seq_group.is_prompt
933
+ and sampling_params.prompt_logprobs is not None):
934
+ largest_num_logprobs = max(largest_num_logprobs,
935
+ sampling_params.prompt_logprobs)
936
+ next_prompt_tokens = _get_next_prompt_tokens(seq_group)
937
+ query_indices.extend(seq_group.prompt_logprob_indices)
938
+ next_token_ids.extend(next_prompt_tokens)
939
+
940
+ # Update indices and next tokens for sample logprob.
941
+ if seq_group.do_sample:
942
+ token_ids, parent_seq_ids = sample_result
943
+ # NOTE: We cannot directly use sample_indices because
944
+ # sample_indices only contain parent seq_ids of a previous step.
945
+ # The current step may have a different number of seq_ids, and
946
+ # we can obtain it from `sample_result[1]`.
947
+ query_idx = seq_group.sample_indices[0]
948
+ query_indices.extend(
949
+ [query_idx + parent_id for parent_id in parent_seq_ids])
950
+ next_token_ids.extend(token_ids)
951
+
952
+ if sampling_params.logprobs is not None:
953
+ largest_num_logprobs = max(largest_num_logprobs,
954
+ sampling_params.logprobs)
955
+
956
+ assert len(next_token_ids) == len(query_indices)
957
+
958
+ if len(query_indices) == 0:
959
+ empty_sampled_logprob: SampleLogprobs = []
960
+ empty_prompt_logprob: Optional[PromptLogprobs] = None
961
+ return [empty_prompt_logprob], [empty_sampled_logprob]
962
+
963
+ selected_logprobs, ranks = None, None
964
+ top_logprobs, top_token_ids = None, None
965
+
966
+ # If largest_num_logprobs == -1, i.e. no logprobs are requested, we can
967
+ # skip the whole logprob calculation.
968
+ if largest_num_logprobs >= 0:
969
+ query_indices_gpu = torch.tensor(query_indices, device=logprobs.device)
970
+ next_token_ids_gpu = torch.tensor(next_token_ids,
971
+ device=logprobs.device)
972
+
973
+ # (num_selected_query_tokens, num_logprobs). Note that query_indices can
974
+ # contain duplicates if beam search is enabled.
975
+ selected_logprobs = logprobs[[
976
+ query_indices_gpu,
977
+ next_token_ids_gpu,
978
+ ]]
979
+ ranks = _get_ranks(
980
+ logprobs[query_indices_gpu],
981
+ next_token_ids_gpu,
982
+ )
983
+ assert selected_logprobs.shape[0] == ranks.shape[0]
984
+
985
+ # We only need the top-k logprobs if some request asks for more than 0.
986
+ if largest_num_logprobs > 0:
987
+ # Logprobs of topk tokens for a batch of sequence groups.
988
+ # (num_query_tokens_across_batch).
989
+ top_logprobs, top_token_ids = torch.topk(logprobs,
990
+ largest_num_logprobs,
991
+ dim=-1)
992
+ top_logprobs = top_logprobs.to('cpu')
993
+ top_token_ids = top_token_ids.to('cpu')
994
+
995
+ selected_logprobs = selected_logprobs.to('cpu')
996
+ ranks = ranks.to('cpu')
997
+
998
+ # Find prompt/sample logprobs.
999
+ prompt_logprobs_per_seq_group: List[Optional[PromptLogprobs]] = []
1000
+ sample_logprobs_per_seq_group: List[SampleLogprobs] = []
1001
+ top_logprob_idx = 0
1002
+ selected_logprobs_idx = 0
1003
+
1004
+ for seq_group, sample_result in zip(sampling_metadata.seq_groups,
1005
+ sample_results):
1006
+ (prompt_logprobs, top_logprob_idx,
1007
+ selected_logprobs_idx) = _get_prompt_logprob_if_needed(
1008
+ seq_group, selected_logprobs, ranks, top_token_ids, top_logprobs,
1009
+ selected_logprobs_idx, top_logprob_idx)
1010
+ prompt_logprobs_per_seq_group.append(prompt_logprobs)
1011
+
1012
+ (sampled_logprobs, top_logprob_idx,
1013
+ selected_logprobs_idx) = _get_sampled_logprob_if_needed(
1014
+ seq_group, sample_result, selected_logprobs, ranks, top_token_ids,
1015
+ top_logprobs, selected_logprobs_idx, top_logprob_idx)
1016
+ sample_logprobs_per_seq_group.append(sampled_logprobs)
1017
+
1018
+ return prompt_logprobs_per_seq_group, sample_logprobs_per_seq_group
1019
+
1020
+
1021
+ def _get_prompt_logprob_if_needed(
1022
+ seq_group: SequenceGroupToSample,
1023
+ selected_logprobs: torch.Tensor,
1024
+ ranks: torch.Tensor,
1025
+ top_token_ids: torch.Tensor,
1026
+ top_logprobs: torch.Tensor,
1027
+ selected_logprobs_idx: int,
1028
+ top_logprob_idx: int,
1029
+ ):
1030
+ """Compute the prompt logprob from a sequence group if needed."""
1031
+ sampling_params = seq_group.sampling_params
1032
+ is_prompt = seq_group.is_prompt
1033
+
1034
+ # Find prompt logprobs
1035
+ prompt_logprobs: Optional[PromptLogprobs] = None
1036
+ if is_prompt and sampling_params.prompt_logprobs is not None:
1037
+ prompt_logprobs = []
1038
+ num_logprobs = sampling_params.prompt_logprobs
1039
+ next_prompt_tokens = _get_next_prompt_tokens(seq_group)
1040
+ # Pre-select indices and create a list. It is faster than calling .item()
1041
+ # repeatedly.
1042
+ selected_logprob_items = selected_logprobs[
1043
+ selected_logprobs_idx:selected_logprobs_idx +
1044
+ len(next_prompt_tokens)].tolist()
1045
+ rank_items = ranks[selected_logprobs_idx:selected_logprobs_idx +
1046
+ len(next_prompt_tokens)].tolist()
1047
+
1048
+ for idx, token_id in enumerate(next_prompt_tokens):
1049
+ # Calculate the prompt logprob of the real prompt tokens.
1050
+ # {token_id: (logprob, rank_from_vocab)}
1051
+ prompt_logprobs_dict: Dict[int, Tuple[float, int]] = {
1052
+ token_id: (selected_logprob_items[idx], rank_items[idx])
1053
+ }
1054
+
1055
+ # Add top K prompt logprobs along with its rank.
1056
+ if num_logprobs > 0:
1057
+ top_ids = top_token_ids[
1058
+ top_logprob_idx, :num_logprobs].tolist()
1059
+ top_probs = top_logprobs[
1060
+ top_logprob_idx, :num_logprobs].tolist()
1061
+ # Top K is already sorted by rank, so we can use 1 ~
1062
+ # num_logprobs + 1 for rank.
1063
+ top_ranks = range(1, num_logprobs + 1)
1064
+ prompt_logprobs_dict.update({
1065
+ top_id: (top_prob, rank)
1066
+ for top_id, top_prob, rank in zip(top_ids, top_probs,
1067
+ top_ranks)
1068
+ })
1069
+ prompt_logprobs.append({
1070
+ token_id: Logprob(*logprob_and_rank)
1071
+ for token_id, logprob_and_rank in prompt_logprobs_dict.items()
1072
+ })
1073
+ # + 1 to go to the next prompt token.
1074
+ top_logprob_idx += 1
1075
+
1076
+ # + len(next_prompt_tokens) to go to the next prompt.
1077
+ selected_logprobs_idx += len(next_prompt_tokens)
1078
+ return prompt_logprobs, top_logprob_idx, selected_logprobs_idx
1079
+
1080
+
1081
+ def _get_sampled_logprob_if_needed(
1082
+ seq_group: SequenceGroupToSample,
1083
+ sample_result: Tuple[List[int], List[int]],
1084
+ selected_logprobs: torch.Tensor,
1085
+ ranks: torch.Tensor,
1086
+ top_token_ids: torch.Tensor,
1087
+ top_logprobs: torch.Tensor,
1088
+ selected_logprobs_idx: int,
1089
+ top_logprob_idx: int,
1090
+ ):
1091
+ """Compute the sample logprob if needed."""
1092
+ seq_ids = seq_group.seq_ids
1093
+ num_logprobs = seq_group.sampling_params.logprobs
1094
+ sampled_logprobs: SampleLogprobs = []
1095
+ next_token_ids, parent_seq_ids = sample_result
1096
+
1097
+ if seq_group.do_sample:
1098
+ assert len(next_token_ids) > 0
1099
+ if num_logprobs is None:
1100
+ for next_token_id in next_token_ids:
1101
+ # Use a dummy logprob
1102
+ sampled_logprobs.append({next_token_id: Logprob(inf)})
1103
+ else:
1104
+ # Pre-select items from the tensor. tolist() is faster than repeated
1105
+ # `.item()` calls.
1106
+ selected_logprob_items = selected_logprobs[
1107
+ selected_logprobs_idx:selected_logprobs_idx +
1108
+ len(next_token_ids)].tolist()
1109
+ rank_items = ranks[selected_logprobs_idx:selected_logprobs_idx +
1110
+ len(next_token_ids)].tolist()
1111
+ for idx, (next_token_id, parent_id) in enumerate(
1112
+ zip(next_token_ids, parent_seq_ids)):
1113
+ # Get the logprob of a sampled token.
1114
+ sampled_logprobs_dict = {
1115
+ next_token_id:
1116
+ (selected_logprob_items[idx], rank_items[idx])
1117
+ }
1118
+ if num_logprobs is not None and num_logprobs > 0:
1119
+ # Get top K logprobs.
1120
+ top_ids = top_token_ids[top_logprob_idx +
1121
+ parent_id, :num_logprobs].tolist()
1122
+ top_probs = top_logprobs[
1123
+ top_logprob_idx + parent_id, :num_logprobs].tolist()
1124
+ # Top K is already sorted by rank, so we can use 1 ~
1125
+ # num_logprobs + 1 for rank.
1126
+ top_ranks = range(1, num_logprobs + 1)
1127
+ sampled_logprobs_dict.update({
1128
+ top_id: (top_prob, rank)
1129
+ for top_id, top_prob, rank in zip(
1130
+ top_ids, top_probs, top_ranks)
1131
+ })
1132
+
1133
+ sampled_logprobs.append({
1134
+ token_id: Logprob(*logprob_and_rank)
1135
+ for token_id, logprob_and_rank in
1136
+ sampled_logprobs_dict.items()
1137
+ })
1138
+
1139
+ # NOTE: This part of the code is not intuitive. `selected_logprobs` includes
1140
+ # logprobs for the current step, which has len(next_token_ids) tokens
1141
+ # per sequence group. `logprobs` includes logprobs from the previous
1142
+ # steps, which has len(seq_ids) tokens per sequence group.
1143
+
1144
+ # Iterate to the next sequence group in a batch.
1145
+ selected_logprobs_idx += len(next_token_ids)
1146
+ # Iterate to the next sequence group in a batch.
1147
+ top_logprob_idx += len(seq_ids)
1148
+ return sampled_logprobs, top_logprob_idx, selected_logprobs_idx
1149
+
1150
+
1151
+ def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor,
1152
+ sample_indices: torch.Tensor,
1153
+ greedy_samples: torch.Tensor) -> None:
1154
+ """Modify the probability distributions of the greedily-sampled tokens such
1155
+ that each sampled token has a "probability" of 1.0. This is required by
1156
+ speculative decoding, which depends on the sampling method being encoded
1157
+ within the probability distribution for correctness.
1158
+
1159
+ # Why do we only need to do this for greedy sampling?
1160
+
1161
+ vLLM's sampler performs the following steps for greedy or multinomial
1162
+ (random) sampling:
1163
+ 1. Get logits from model.
1164
+ 2. Modify logits according to per-sequence sampling parameters.
1165
+ - Scale by temperature, apply top-k and top-p masking, penalize tokens
1166
+ according to their frequency, etc.
1167
+ 3. Sample a token.
1168
+ - Random sampling simply samples from the modified probability
1169
+ distribution.
1170
+ - Greedy sampling performs `argmax` to obtain the token with the
1171
+ highest likelihood.
1172
+
1173
+ Ignoring greedy sampling for a moment, the computed probability
1174
+ distribution has the following property: if we sample from it ourselves,
1175
+ each token appears with the same frequency with which the Sampler
1176
+ would produce it. In other words, for tokens sampled
1177
+ with vLLM's random SamplingType, the computed probability distribution
1178
+ encodes the sampling methodology completely.
1179
+
1180
+ Greedy sampling does not normally have this property. vLLM modifies logits
1181
+ according to sampling params, then performs `argmax`, then returns the
1182
+ sampled token and the computed probability distribution. If we sample from
1183
+ the distribution, we'll find the likelihood of the greedily-sampled token
1184
+ is not always 1.0.
1185
+
1186
+ Since lossless speculative decoding requires that the sampling methodology
1187
+ be encoded within the probability distribution, we are motivated to modify
1188
+ the probability distribution such that the sampled token has probability 1
1189
+ when speculative decoding is used.
1190
+
1191
+ NOTE: Alternatively, we could use an extremely low temperature to achieve
1192
+ greedy sampling using multinomial computation and unite the codepaths. This
1193
+ has implications on the overall design of the sampler, e.g. how to record
1194
+ accurate logprobs for the user, so this improvement is deferred to later.
1195
+ """
1196
+ # NOTE: logprobs are not modified so they can be returned to the user.
1197
+ probs[sample_indices, :] = 0
1198
+ probs[sample_indices, greedy_samples] = 1.0
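Concretely, the two assignments above overwrite each selected row of `probs` with a one-hot distribution on the greedily chosen token, so that a later multinomial draw (e.g., in the speculative-decoding rejection sampler) reproduces the greedy choice with probability 1. A minimal illustration (not vLLM code):

```python
import torch

probs = torch.full((2, 4), 0.25)
sample_indices = torch.tensor([0, 1])
greedy_samples = torch.tensor([3, 1])
probs[sample_indices, :] = 0
probs[sample_indices, greedy_samples] = 1.0
print(probs)  # rows become one-hot: [0, 0, 0, 1] and [0, 1, 0, 0]
```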
1199
+
1200
+
1201
+ def _build_sampler_output(
1202
+ maybe_deferred_sample_results: MaybeDeferredSampleResultType,
1203
+ sampling_metadata: SamplingMetadata,
1204
+ prompt_logprobs: Optional[List[Optional[PromptLogprobs]]],
1205
+ sample_logprobs: Optional[List[SampleLogprobs]],
1206
+ on_device_tensors: Optional[Tuple[torch.Tensor, torch.Tensor,
1207
+ torch.Tensor]],
1208
+ skip_sampler_cpu_output: bool = False,
1209
+ ) -> SamplerOutput:
1210
+ """Construct Python objects with the output of sampling.
1211
+
1212
+ Args:
1213
+ on_device_tensors: Tuple containing on-device tensors with the
1214
+ probabilities used in sampling and the sampled token ids. This
1215
+ allows post-processing without copies to CPU/serialization, e.g. in
1216
+ speculative decoding rejection sampling.
1217
+ """
1218
+ sampler_output: List[CompletionSequenceGroupOutput] = []
1219
+
1220
+ if skip_sampler_cpu_output:
1221
+ assert isinstance(maybe_deferred_sample_results, SampleResultArgsType)
1222
+ deferred_sample_results_args = maybe_deferred_sample_results
1223
+ else:
1224
+ assert prompt_logprobs is not None
1225
+ assert sample_logprobs is not None
1226
+ assert not isinstance(maybe_deferred_sample_results,
1227
+ SampleResultArgsType)
1228
+ deferred_sample_results_args = None
1229
+
1230
+ for (seq_group, sample_result, group_prompt_logprobs,
1231
+ group_sample_logprobs) in zip(sampling_metadata.seq_groups,
1232
+ maybe_deferred_sample_results,
1233
+ prompt_logprobs, sample_logprobs):
1234
+ seq_ids = seq_group.seq_ids
1235
+ next_token_ids, parent_ids = sample_result
1236
+ seq_outputs: List[SequenceOutput] = []
1237
+ for parent_id, next_token_id, logprobs in zip(
1238
+ parent_ids, next_token_ids, group_sample_logprobs):
1239
+ seq_outputs.append(
1240
+ SequenceOutput(seq_ids[parent_id], next_token_id,
1241
+ logprobs))
1242
+ sampler_output.append(
1243
+ CompletionSequenceGroupOutput(seq_outputs,
1244
+ group_prompt_logprobs))
1245
+
1246
+ # If not specified, store None values in SamplerOutput.
1247
+ if on_device_tensors is not None:
1248
+ (sampled_token_probs, logprobs_tensor,
1249
+ sampled_token_ids) = on_device_tensors
1250
+ else:
1251
+ sampled_token_probs, logprobs_tensor, sampled_token_ids = (None, None,
1252
+ None)
1253
+
1254
+ return SamplerOutput(
1255
+ outputs=sampler_output,
1256
+ sampled_token_probs=sampled_token_probs,
1257
+ sampled_token_ids=sampled_token_ids,
1258
+ logprobs=logprobs_tensor,
1259
+ deferred_sample_results_args=deferred_sample_results_args)
1260
+
1261
+
1262
+ def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[int]:
1263
+ """Get a list of next prompt tokens to compute logprob from a
1264
+ given sequence group.
1265
+
1266
+ It is used to compute prompt logprobs. Each query token needs the id of
1267
+ the next prompt token in order to compute its prompt logprob; this helper
1268
+ returns those next prompt token ids.
1269
+
1270
+ This API must only be used when the caller knows that the seq_group is in
1271
+ the prefill stage.
1272
+
1273
+ Returns:
1274
+ A list of next prompt tokens to compute logprob.
1275
+ """
1276
+ assert seq_group.is_prompt, (
1277
+ "Caller should ensure the sequence group is in a prefill stage.")
1278
+ seq_ids = seq_group.seq_ids
1279
+ query_len = seq_group.query_len
1280
+ assert query_len is not None
1281
+ # prompt has only 1 seq id.
1282
+ assert len(seq_ids) == 1
1283
+ seq_data = seq_group.seq_data[seq_ids[0]]
1284
+ computed_len = seq_data.get_num_computed_tokens()
1285
+ prompt_tokens = seq_data.prompt_token_ids
1286
+ # +1 because we are looking for a next prompt token.
1287
+ next_token_index_start = computed_len + 1
1288
+ next_token_index_end = min(computed_len + query_len + 1,
1289
+ len(prompt_tokens))
1290
+ next_prompt_tokens = prompt_tokens[
1291
+ next_token_index_start:next_token_index_end]
1292
+ return next_prompt_tokens
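For example, with a 6-token prompt being prefilled in chunks, if 2 tokens are already computed and the current chunk covers 3 query tokens, the slice above returns prompt tokens 3..5, i.e., the "next" token for each query position. A sketch of just the index arithmetic (hypothetical numbers, not vLLM code):

```python
prompt_tokens = [11, 22, 33, 44, 55, 66]
computed_len, query_len = 2, 3

next_token_index_start = computed_len + 1                                      # 3
next_token_index_end = min(computed_len + query_len + 1, len(prompt_tokens))   # 6
print(prompt_tokens[next_token_index_start:next_token_index_end])              # [44, 55, 66]
```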
.venv/lib/python3.11/site-packages/vllm/model_executor/layers/spec_decode_base_sampler.py ADDED
@@ -0,0 +1,256 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ from abc import abstractmethod
4
+ from typing import Dict, Optional, Union
5
+
6
+ import torch
7
+ import torch.jit
8
+ import torch.nn as nn
9
+
10
+
11
+ class SpecDecodeBaseSampler(nn.Module):
12
+ """Base class for samplers used for Speculative Decoding verification
13
+ step.
14
+ """
15
+
16
+ def __init__(self, strict_mode: bool = False):
17
+ """Base class constructor.
18
+ Args:
19
+ strict_mode: Whether or not to perform shape/device/dtype checks
20
+ during sampling. This catches correctness issues but adds
21
+ nontrivial latency.
22
+ """
23
+ super().__init__()
24
+ self._strict_mode = strict_mode
25
+
26
+ # NOTE: A "bonus token" is accepted iff all proposal tokens are
27
+ # accepted. There is always only one possible bonus token. We store this
28
+ # value in a variable for readability.
29
+ self._num_bonus_tokens = 1
30
+
31
+ self.num_accepted_tokens: Optional[torch.Tensor] = None
32
+ self.num_emitted_tokens: Optional[torch.Tensor] = None
33
+ self.num_draft_tokens: int = 0
34
+
35
+ def init_gpu_tensors(self, device: Union[int, str]) -> None:
36
+ assert self.num_accepted_tokens is None
37
+ if isinstance(device, int):
38
+ device = f"cuda:{device}"
39
+ elif not isinstance(device, str):
40
+ raise ValueError(f"Device must be int or str, get {type(device)}")
41
+ self.num_accepted_tokens = torch.tensor(0,
42
+ dtype=torch.long,
43
+ device=device)
44
+ self.num_emitted_tokens = torch.tensor(0,
45
+ dtype=torch.long,
46
+ device=device)
47
+
48
+ def init_tensors(self,
49
+ device: Union[int, str],
50
+ device_type: Union[torch.device, str] = 'cuda') -> None:
51
+ assert self.num_accepted_tokens is None
52
+ if isinstance(device_type, torch.device):
53
+ device_type = device_type.type
54
+ if isinstance(device, int):
55
+ device = f"{device_type}:{device}"
56
+ self.num_accepted_tokens = torch.tensor(0,
57
+ dtype=torch.long,
58
+ device=device)
59
+ self.num_emitted_tokens = torch.tensor(0,
60
+ dtype=torch.long,
61
+ device=device)
62
+
63
+ @property
64
+ def probs_dtype(self):
65
+ return torch.float32
66
+
67
+ @property
68
+ def token_id_dtype(self):
69
+ return torch.int64
70
+
71
+ def _create_output(
72
+ self,
73
+ accepted: torch.Tensor, # [batch_size, k]
74
+ substitute_token_ids: torch.Tensor, # [batch_size, k]
75
+ draft_token_ids: torch.Tensor, # [batch_size, k]
76
+ bonus_token_ids: torch.Tensor, # [batch_size]
77
+ ) -> torch.Tensor:
78
+ """Format output. Returns a matrix of token ids. When
79
+ a token is rejected via sampling, all subsequent token ids are
80
+ set to -1 for the sequence.
81
+
82
+ Args:
83
+ accepted: A boolean tensor indicating if the corresponding
84
+ draft token in draft_token_ids should be accepted or not.
85
+ substitute_token_ids: A tensor of token_ids that can be used
86
+ as substitutes for the draft token ids if the proposed token
87
+ is rejected.
88
+ draft_token_ids: A tensor of token ids speculated by the
89
+ draft model.
90
+ bonus_token_ids: Token ids to use as the bonus token if
91
+ all the draft tokens are accepted.
92
+ Returns:
93
+ A tensor containing the accepted token ids. The shape of the
94
+ tensor is [batch_size, k + num_bonus_tokens]
95
+ """
96
+ batch_size, k = substitute_token_ids.shape
97
+ bonus_token_ids = bonus_token_ids.squeeze(-1)
98
+ # Determine the index of the first False value for each row.
99
+ limits = (accepted == 0).max(1).indices
100
+ limits[~(accepted == 0).any(1)] = k
101
+
102
+ # Create masks using the indices.
103
+ indices = torch.arange(k, device=accepted.device).unsqueeze(0)
104
+ accepted_mask = indices < limits.unsqueeze(1)
105
+ after_false_mask = indices == limits.unsqueeze(1)
106
+
107
+ # Create an extended output tensor
108
+ output_with_bonus_tokens = -torch.ones(
109
+ (batch_size, k + self._num_bonus_tokens),
110
+ dtype=self.token_id_dtype,
111
+ device=accepted.device)
112
+ output = output_with_bonus_tokens[:, :k]
113
+
114
+ # Fill in the first k columns of the output tensor using masks and data
115
+ # tensors.
116
+ output[:, :k] = torch.where(accepted_mask, draft_token_ids,
117
+ -torch.ones_like(draft_token_ids))
118
+
119
+ # Fill the last column.
120
+ # We check output directly as accepted may have True values inconsistent
121
+ # with causal acceptance.
122
+ output_with_bonus_tokens[:, -1] = torch.where(output[:, -1] != -1,
123
+ bonus_token_ids, -1)
124
+
125
+ # Fill the recovered token ids.
126
+ output.mul_(~after_false_mask).add_(
127
+ substitute_token_ids.mul(after_false_mask))
128
+
129
+ self.num_accepted_tokens += accepted.sum()
130
+ self.num_emitted_tokens += (output_with_bonus_tokens != -1).sum()
131
+ self.num_draft_tokens += batch_size * k
132
+
133
+ return output_with_bonus_tokens
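The masking logic above finds, per row, the first rejected position (`limits`), keeps draft tokens before it, writes the substitute token at it, fills everything after with -1, and appends the bonus token only if the whole row was accepted (checked before substitution). A compact worked example of the same masks with `k = 3` draft tokens (illustration only, not vLLM code):

```python
import torch

accepted = torch.tensor([[True, True, False],
                         [True, True, True]])
draft = torch.tensor([[10, 11, 12],
                      [20, 21, 22]])
substitute = torch.tensor([[90, 91, 92],
                           [95, 96, 97]])
bonus = torch.tensor([30, 31])
k = draft.shape[1]

# Index of the first rejected position per row; fully accepted rows get k.
limits = (accepted == 0).int().max(1).indices
limits[~(accepted == 0).any(1)] = k

idx = torch.arange(k).unsqueeze(0)
accepted_mask = idx < limits.unsqueeze(1)      # positions before the first rejection
after_false_mask = idx == limits.unsqueeze(1)  # the first rejected position itself

out = torch.where(accepted_mask, draft, torch.full_like(draft, -1))
# Bonus token only if the whole row was accepted (checked before substitution).
bonus_col = torch.where(out[:, -1] != -1, bonus, torch.full_like(bonus, -1))
out = torch.where(after_false_mask, substitute, out)
print(torch.cat([out, bonus_col.unsqueeze(1)], dim=1))
# tensor([[10, 11, 92, -1],     third draft rejected: substitute token, no bonus
#         [20, 21, 22, 31]])    all accepted: bonus token appended
```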
134
+
135
+ def _raise_if_incorrect_input(
136
+ self,
137
+ target_with_bonus_probs: torch.Tensor,
138
+ draft_token_ids: torch.Tensor,
139
+ bonus_token_ids: torch.Tensor,
140
+ draft_probs: Optional[torch.Tensor] = None,
141
+ ) -> None:
142
+ self._raise_if_incorrect_shape(target_with_bonus_probs,
143
+ draft_token_ids, bonus_token_ids,
144
+ draft_probs)
145
+ self._raise_if_incorrect_dtype(target_with_bonus_probs,
146
+ draft_token_ids, bonus_token_ids,
147
+ draft_probs)
148
+ self._raise_if_inconsistent_device(target_with_bonus_probs,
149
+ draft_token_ids, bonus_token_ids,
150
+ draft_probs)
151
+ self._raise_if_out_of_bounds_vocab(target_with_bonus_probs.shape[-1],
152
+ draft_token_ids, bonus_token_ids)
153
+
154
+ def _raise_if_incorrect_shape(
155
+ self,
156
+ target_with_bonus_probs: torch.Tensor,
157
+ draft_token_ids: torch.Tensor,
158
+ bonus_token_ids: torch.Tensor,
159
+ draft_probs: Optional[torch.Tensor] = None,
160
+ ) -> None:
161
+ (target_batch_size, num_target_probs,
162
+ target_vocab_size) = target_with_bonus_probs.shape
163
+
164
+ # Does not count the extra token
165
+ num_target_probs -= 1
166
+
167
+ # validate the shape of draft token ids.
168
+ draft_token_ids_batch_size, num_draft_token_ids = draft_token_ids.shape
169
+ assert draft_token_ids_batch_size == target_batch_size
170
+ assert num_draft_token_ids == num_target_probs
171
+
172
+ # validate the shape of bonus token ids
173
+ bonus_batch_size, num_bonus_tokens = bonus_token_ids.shape
174
+ assert bonus_batch_size == target_batch_size
175
+ assert num_bonus_tokens == self._num_bonus_tokens
176
+
177
+ # validate the shape of draft probs if it is set
178
+ if draft_probs is not None:
179
+ (draft_batch_size, num_draft_probs,
180
+ draft_vocab_size) = draft_probs.shape
181
+ assert draft_batch_size == target_batch_size
182
+ assert num_draft_probs == num_target_probs
183
+ assert (draft_vocab_size == target_vocab_size
184
+ ), f"{draft_vocab_size=} {target_vocab_size=}"
185
+
186
+ def _raise_if_incorrect_dtype(
187
+ self,
188
+ target_with_bonus_probs: torch.Tensor,
189
+ draft_token_ids: torch.Tensor,
190
+ bonus_token_ids: torch.Tensor,
191
+ draft_probs: Optional[torch.Tensor] = None,
192
+ ) -> None:
193
+ assert target_with_bonus_probs.dtype == self.probs_dtype
194
+ assert draft_token_ids.dtype == self.token_id_dtype
195
+ assert bonus_token_ids.dtype == self.token_id_dtype
196
+ if draft_probs is not None:
197
+ assert draft_probs.dtype == self.probs_dtype
198
+
199
+ def _raise_if_inconsistent_device(
200
+ self,
201
+ target_with_bonus_probs: torch.Tensor,
202
+ draft_token_ids: torch.Tensor,
203
+ bonus_token_ids: torch.Tensor,
204
+ draft_probs: Optional[torch.Tensor] = None,
205
+ ) -> None:
206
+ devices = [
207
+ t.device for t in [
208
+ target_with_bonus_probs, bonus_token_ids, draft_probs,
209
+ draft_token_ids
210
+ ] if t is not None
211
+ ]
212
+ assert all([devices[0] == device for device in devices])
213
+
214
+ def _raise_if_out_of_bounds_vocab(
215
+ self,
216
+ vocab_size: int,
217
+ draft_token_ids: torch.Tensor,
218
+ bonus_token_ids: torch.Tensor,
219
+ ) -> None:
220
+ assert torch.all(bonus_token_ids < vocab_size)
221
+ assert torch.all(bonus_token_ids >= 0)
222
+ assert torch.all(draft_token_ids < vocab_size)
223
+ assert torch.all(draft_token_ids >= 0)
224
+
225
+
226
+ class SpecDecodeDeterministicBaseSampler(SpecDecodeBaseSampler):
227
+ """Base class for samplers used for Speculative Decoding verification
228
+ step which are deterministic.
229
+ """
230
+
231
+ @abstractmethod
232
+ def forward(
233
+ self,
234
+ target_with_bonus_probs: torch.Tensor,
235
+ bonus_token_ids: torch.Tensor,
236
+ draft_probs: torch.Tensor,
237
+ draft_token_ids: torch.Tensor,
238
+ ) -> torch.Tensor:
239
+ raise NotImplementedError
240
+
241
+
242
+ class SpecDecodeStochasticBaseSampler(SpecDecodeBaseSampler):
243
+ """Base class for samplers used for Speculative Decoding verification
244
+ step which are stochastic
245
+ """
246
+
247
+ @abstractmethod
248
+ def forward(
249
+ self,
250
+ target_with_bonus_probs: torch.Tensor,
251
+ bonus_token_ids: torch.Tensor,
252
+ draft_probs: torch.Tensor,
253
+ draft_token_ids: torch.Tensor,
254
+ seeded_seqs: Optional[Dict[int, torch.Generator]] = None,
255
+ ) -> torch.Tensor:
256
+ raise NotImplementedError
.venv/lib/python3.11/site-packages/vllm/model_executor/layers/typical_acceptance_sampler.py ADDED
@@ -0,0 +1,172 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ import torch
4
+ import torch.jit
5
+
6
+ from vllm.model_executor.layers.spec_decode_base_sampler import (
7
+ SpecDecodeDeterministicBaseSampler)
8
+
9
+
10
+ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
11
+ """Apply typical acceptance sampling as described in section 3.3.1 in
12
+ "MEDUSA: Simple LLM Inference Acceleration Framework with
13
+ Multiple Decoding Heads"
14
+ https://arxiv.org/pdf/2401.10774
15
+ """
16
+
17
+ def __init__(
18
+ self,
19
+ posterior_threshold: float,
20
+ posterior_alpha: float,
21
+ strict_mode: bool = False,
22
+ ):
23
+ """Create a Typical Acceptance Sampler.
24
+
25
+ Args:
26
+ strict_mode: Whether or not to perform shape/device/dtype checks
27
+ during sampling. This catches correctness issues but adds
28
+ nontrivial latency.
29
+ posterior_threshold : A threshold value that sets a lower bound
30
+ on the posterior probability of a token in the target model for it
31
+ to be accepted.
32
+ posterior_alpha : A scaling factor for the entropy-based
33
+ threshold in typical acceptance sampling.
34
+ """
35
+ self._posterior_threshold = posterior_threshold
36
+ self._posterior_alpha = posterior_alpha
37
+ super().__init__(strict_mode=strict_mode)
38
+
39
+ def forward(
40
+ self,
41
+ target_with_bonus_probs: torch.Tensor,
42
+ bonus_token_ids: torch.Tensor,
43
+ draft_probs: torch.Tensor,
44
+ draft_token_ids: torch.Tensor,
45
+ ) -> torch.Tensor:
46
+ """Sample token ids using typical acceptance sampling. This accepts
47
+ or rejects tokens proposed by the draft model using the probability
48
+ of each token according to the draft and target models.
49
+
50
+ In the worst case where all draft tokens are rejected, it is guaranteed
51
+ one token will be emitted.
52
+
53
+ In the case where all draft tokens are accepted, the bonus token will be
54
+ accepted.
55
+
56
+ Args:
57
+ target_with_bonus_probs: The probability distribution over token ids
58
+ given context according to the target model, including the bonus
59
+ position. shape = [batch_size, num_speculative_tokens + 1, vocab_size]
60
+
61
+ bonus_token_ids: The "bonus" token ids that are accepted iff all
62
+ speculative tokens in a sequence are accepted.
63
+ shape = [batch_size, num_bonus_tokens]
64
+
65
+ draft_probs: This parameter is unused by the acceptance sampler.
66
+
67
+ draft_token_ids: The token ids that were sampled from the draft
68
+ probabilities.
69
+ shape = [batch_size, num_speculative_tokens]
70
+
71
+ Returns:
72
+ output_token_ids: The token ids sampled via rejection sampling,
73
+ or -1 if unable to sample a token because the previous token
74
+ was rejected.
75
+ shape = [batch_size, num_speculative_tokens + num_bonus_tokens]
76
+ """
77
+ # Only perform shape/dtype/device checking in strict mode, as it adds
78
+ # overhead.
79
+ if self._strict_mode:
80
+ self._raise_if_incorrect_input(target_with_bonus_probs,
81
+ draft_token_ids, bonus_token_ids)
82
+ target_probs = target_with_bonus_probs[:, :-1]
83
+ accepted = self._evaluate_accepted_tokens(target_probs,
84
+ draft_token_ids)
85
+ recovered_token_ids = self._get_recovered_token_ids(target_probs)
86
+ output_token_ids = self._create_output(accepted, recovered_token_ids,
87
+ draft_token_ids,
88
+ bonus_token_ids)
89
+ return output_token_ids
90
+
91
+ def _evaluate_accepted_tokens(self, target_probs, draft_token_ids):
92
+ r"""
93
+ Evaluates and returns a mask of accepted tokens based on the
94
+ posterior probabilities.
95
+
96
+ Parameters:
97
+ ----------
98
+ target_probs : torch.Tensor
99
+ A tensor of shape (batch_size, k, vocab_size) representing
100
+ the probabilities of each token in the vocabulary for each
101
+ position in the proposed sequence. This is the distribution
102
+ generated by the target model.
103
+ draft_token_ids : torch.Tensor
104
+ A tensor of shape (batch_size, k) representing the proposed
105
+ token ids.
106
+
107
+ A draft token_id x_{n+k} is accepted if it satisfies the
108
+ following condition
109
+
110
+ .. math::
111
+ p_{\text{original}}(x_{n+k} | x_1, x_2, \dots, x_{n+k-1}) >
112
+ \min \left( \epsilon, \delta * \exp \left(
113
+ -H(p_{\text{original}}(
114
+ \cdot | x_1, x_2, \ldots, x_{n+k-1})) \right) \right)
115
+
116
+ where :math:`p_{\text{original}}` corresponds to target_probs
117
+ and :math:`\epsilon` and :math:`\delta` correspond to hyperparameters
118
+ specified using self._posterior_threshold and self._posterior_alpha
119
+
120
+ This method computes the posterior probabilities for the given
121
+ draft token ids based on the provided target probabilities. It
122
+ calculates the entropy of the posterior distribution and determines
123
+ a dynamic threshold for each token position using the provided
124
+ posterior_threshold and posterior_alpha values. The method then
125
+ returns a boolean mask indicating which tokens can be accepted.
126
+
127
+ Returns:
128
+ -------
129
+ torch.Tensor
130
+ A boolean tensor of shape (batch_size, k) where each element
131
+ indicates whether the corresponding draft token has been accepted
132
+ or rejected. True indicates acceptance and false indicates
133
+ rejection.
134
+
135
+ """
136
+ device = target_probs.device
137
+ candidates_prob = torch.gather(
138
+ target_probs, dim=-1,
139
+ index=draft_token_ids.unsqueeze(-1)).squeeze(-1)
140
+ # A small constant added to prevent computing the logarithm of zero,
141
+ # which can lead to undefined values.
142
+ epsilon = 1e-5
143
+ posterior_entropy = -torch.sum(
144
+ target_probs * torch.log(target_probs + epsilon), dim=-1)
145
+ threshold = torch.minimum(
146
+ torch.ones_like(posterior_entropy, device=device) *
147
+ self._posterior_threshold,
148
+ torch.exp(-posterior_entropy) * self._posterior_alpha,
149
+ )
150
+ accepted_mask = candidates_prob > threshold
151
+ return accepted_mask
152
+
153
+ def _get_recovered_token_ids(self, target_probs):
154
+ """
155
+ The recovered token ids are the argmax tokens of the target model;
156
+ the first rejected draft position is replaced with its recovered token.
157
+
158
+ Parameters
159
+ ----------
160
+ target_probs : torch.Tensor
161
+ A tensor of shape (batch_size, k, vocab_size) containing
162
+ the target probability distribution
163
+
164
+ Returns
165
+ -------
166
+ torch.Tensor
167
+ A tensor of shape (batch_size, k) with the recovered token
168
+ ids which are selected from target probs.
169
+ """
170
+ max_indices = torch.argmax(target_probs, dim=-1)
171
+
172
+ return max_indices
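As an illustration of the acceptance rule documented above, the following standalone sketch reproduces the entropy-based threshold on a toy batch. The posterior_threshold and posterior_alpha values are hypothetical stand-ins for the sampler's configured hyperparameters:

    import torch

    # Toy target distribution: batch_size=1, k=2 speculative positions, vocab_size=4.
    target_probs = torch.tensor([[[0.70, 0.10, 0.10, 0.10],
                                  [0.90, 0.04, 0.03, 0.03]]])
    draft_token_ids = torch.tensor([[0, 3]])

    posterior_threshold = 0.09  # hypothetical stand-in for self._posterior_threshold
    posterior_alpha = 0.3       # hypothetical stand-in for self._posterior_alpha

    # Probability the target model assigns to each drafted token.
    candidates_prob = torch.gather(
        target_probs, dim=-1, index=draft_token_ids.unsqueeze(-1)).squeeze(-1)
    # Entropy of the target distribution at each position (same 1e-5 guard as above).
    posterior_entropy = -torch.sum(
        target_probs * torch.log(target_probs + 1e-5), dim=-1)
    # Dynamic threshold: min(epsilon, delta * exp(-H)).
    threshold = torch.minimum(
        torch.full_like(posterior_entropy, posterior_threshold),
        posterior_alpha * torch.exp(-posterior_entropy))
    accepted = candidates_prob > threshold          # tensor([[ True, False]])
    recovered = torch.argmax(target_probs, dim=-1)  # tensor([[0, 0]])

Position 0 is accepted because the drafted token carries most of the target mass; position 1 falls below the threshold and would be replaced by its recovered (argmax) token.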
.venv/lib/python3.11/site-packages/vllm/model_executor/layers/utils.py ADDED
@@ -0,0 +1,58 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ """Utility methods for model layers."""
3
+ from typing import Tuple
4
+
5
+ import torch
6
+
7
+
8
+ def get_token_bin_counts_and_mask(
9
+ tokens: torch.Tensor,
10
+ vocab_size: int,
11
+ num_seqs: int,
12
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
13
+ # Compute the bin counts for the tokens.
14
+ # vocab_size + 1 for padding.
15
+ bin_counts = torch.zeros((num_seqs, vocab_size + 1),
16
+ dtype=torch.long,
17
+ device=tokens.device)
18
+ bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens))
19
+ bin_counts = bin_counts[:, :vocab_size]
20
+ mask = bin_counts > 0
21
+
22
+ return bin_counts, mask
23
+
24
+
25
+ def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
26
+ output_tokens_tensor: torch.Tensor,
27
+ presence_penalties: torch.Tensor,
28
+ frequency_penalties: torch.Tensor,
29
+ repetition_penalties: torch.Tensor) -> torch.Tensor:
30
+ """
31
+ Applies penalties in place to the logits tensor
32
+ logits : The input logits tensor of shape [num_seqs, vocab_size]
33
+ prompt_tokens_tensor: A tensor containing the prompt tokens. The prompts
34
+ are padded to the maximum prompt length within the batch using
35
+ `vocab_size` as the padding value. The value `vocab_size` is used
36
+ for padding because it does not correspond to any valid token ID
37
+ in the vocabulary.
38
+ output_tokens_tensor: The output tokens tensor.
39
+ presence_penalties: The presence penalties of shape (num_seqs, )
40
+ frequency_penalties: The frequency penalties of shape (num_seqs, )
41
+ repetition_penalties: The repetition penalties of shape (num_seqs, )
42
+ """
43
+ num_seqs, vocab_size = logits.shape
44
+ _, prompt_mask = get_token_bin_counts_and_mask(prompt_tokens_tensor,
45
+ vocab_size, num_seqs)
46
+ output_bin_counts, output_mask = get_token_bin_counts_and_mask(
47
+ output_tokens_tensor, vocab_size, num_seqs)
48
+ repetition_penalties = repetition_penalties.unsqueeze_(dim=1).repeat(
49
+ 1, vocab_size)
50
+ logits[logits > 0] /= torch.where(prompt_mask | output_mask,
51
+ repetition_penalties, 1.0)[logits > 0]
52
+ logits[logits <= 0] *= torch.where(prompt_mask | output_mask,
53
+ repetition_penalties, 1.0)[logits <= 0]
54
+ # We follow the definition in OpenAI API.
55
+ # Refer to https://platform.openai.com/docs/api-reference/parameter-details
56
+ logits -= frequency_penalties.unsqueeze_(dim=1) * output_bin_counts
57
+ logits -= presence_penalties.unsqueeze_(dim=1) * output_mask
58
+ return logits
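A minimal usage sketch of apply_penalties; the import path follows the file location above, and all tensor values are made up for illustration. Note that the helper modifies the logits (and reshapes the penalty tensors) in place, and also returns the logits:

    import torch
    from vllm.model_executor.layers.utils import apply_penalties

    num_seqs, vocab_size = 2, 6
    logits = torch.randn(num_seqs, vocab_size)

    # Prompt and output tokens, padded with `vocab_size` (an id outside the vocab).
    prompt_tokens = torch.tensor([[0, 1, vocab_size],
                                  [2, 2, 3]])
    output_tokens = torch.tensor([[4, vocab_size],
                                  [5, 5]])

    presence_penalties = torch.tensor([0.5, 0.0])
    frequency_penalties = torch.tensor([0.1, 0.2])
    repetition_penalties = torch.tensor([1.2, 1.0])

    # Token ids seen in the prompt or output get the multiplicative repetition
    # penalty; output tokens additionally get count-scaled frequency and flat
    # presence penalties, following the OpenAI parameter definitions.
    logits = apply_penalties(logits, prompt_tokens, output_tokens,
                             presence_penalties, frequency_penalties,
                             repetition_penalties)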
.venv/lib/python3.11/site-packages/vllm/model_executor/layers/vocab_parallel_embedding.py ADDED
@@ -0,0 +1,484 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ from dataclasses import dataclass
4
+ from typing import List, Optional, Sequence, Tuple
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from torch.nn.parameter import Parameter, UninitializedParameter
9
+
10
+ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
11
+ get_tensor_model_parallel_world_size,
12
+ tensor_model_parallel_all_reduce)
13
+ from vllm.model_executor.layers.quantization.base_config import (
14
+ QuantizationConfig, QuantizeMethodBase, method_has_implemented_embedding)
15
+ from vllm.model_executor.parameter import BasevLLMParameter
16
+ from vllm.model_executor.utils import set_weight_attrs
17
+ from vllm.platforms import current_platform
18
+
19
+ DEFAULT_VOCAB_PADDING_SIZE = 64
20
+
21
+
22
+ class UnquantizedEmbeddingMethod(QuantizeMethodBase):
23
+ """Unquantized method for embeddings."""
24
+
25
+ def create_weights(self, layer: torch.nn.Module,
26
+ input_size_per_partition: int,
27
+ output_partition_sizes: List[int], input_size: int,
28
+ output_size: int, params_dtype: torch.dtype,
29
+ **extra_weight_attrs):
30
+ """Create weights for embedding layer."""
31
+ weight = Parameter(torch.empty(sum(output_partition_sizes),
32
+ input_size_per_partition,
33
+ dtype=params_dtype),
34
+ requires_grad=False)
35
+ set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
36
+ layer.register_parameter("weight", weight)
37
+ set_weight_attrs(weight, extra_weight_attrs)
38
+
39
+ def apply(self,
40
+ layer: torch.nn.Module,
41
+ x: torch.Tensor,
42
+ bias: Optional[torch.Tensor] = None) -> torch.Tensor:
43
+ return F.linear(x, layer.weight, bias)
44
+
45
+ def embedding(self, layer: torch.nn.Module,
46
+ input_: torch.Tensor) -> torch.Tensor:
47
+ return F.embedding(input_, layer.weight)
48
+
49
+
50
+ def pad_vocab_size(vocab_size: int,
51
+ pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int:
52
+ """Pad the vocab size to the given value."""
53
+ return ((vocab_size + pad_to - 1) // pad_to) * pad_to
54
+
55
+
56
+ def vocab_range_from_per_partition_vocab_size(
57
+ per_partition_vocab_size: int,
58
+ rank: int,
59
+ offset: int = 0) -> Sequence[int]:
60
+ index_f = rank * per_partition_vocab_size
61
+ index_l = index_f + per_partition_vocab_size
62
+ return index_f + offset, index_l + offset
63
+
64
+
65
+ def vocab_range_from_global_vocab_size(global_vocab_size: int,
66
+ rank: int,
67
+ world_size: int,
68
+ offset: int = 0) -> Sequence[int]:
69
+ per_partition_vocab_size = divide(global_vocab_size, world_size)
70
+ return vocab_range_from_per_partition_vocab_size(per_partition_vocab_size,
71
+ rank,
72
+ offset=offset)
73
+
74
+
75
+ @dataclass
76
+ class VocabParallelEmbeddingShardIndices:
77
+ """Indices for a shard of a vocab parallel embedding."""
78
+ padded_org_vocab_start_index: int
79
+ padded_org_vocab_end_index: int
80
+ padded_added_vocab_start_index: int
81
+ padded_added_vocab_end_index: int
82
+
83
+ org_vocab_start_index: int
84
+ org_vocab_end_index: int
85
+ added_vocab_start_index: int
86
+ added_vocab_end_index: int
87
+
88
+ @property
89
+ def num_org_elements(self) -> int:
90
+ return self.org_vocab_end_index - self.org_vocab_start_index
91
+
92
+ @property
93
+ def num_added_elements(self) -> int:
94
+ return self.added_vocab_end_index - self.added_vocab_start_index
95
+
96
+ @property
97
+ def num_org_elements_padded(self) -> int:
98
+ return (self.padded_org_vocab_end_index -
99
+ self.padded_org_vocab_start_index)
100
+
101
+ @property
102
+ def num_added_elements_padded(self) -> int:
103
+ return (self.padded_added_vocab_end_index -
104
+ self.padded_added_vocab_start_index)
105
+
106
+ @property
107
+ def num_org_vocab_padding(self) -> int:
108
+ return self.num_org_elements_padded - self.num_org_elements
109
+
110
+ @property
111
+ def num_added_vocab_padding(self) -> int:
112
+ return self.num_added_elements_padded - self.num_added_elements
113
+
114
+ @property
115
+ def num_elements_padded(self) -> int:
116
+ return self.num_org_elements_padded + self.num_added_elements_padded
117
+
118
+ def __post_init__(self):
119
+ # sanity checks
120
+ assert (self.padded_org_vocab_start_index
121
+ <= self.padded_org_vocab_end_index)
122
+ assert (self.padded_added_vocab_start_index
123
+ <= self.padded_added_vocab_end_index)
124
+
125
+ assert self.org_vocab_start_index <= self.org_vocab_end_index
126
+ assert self.added_vocab_start_index <= self.added_vocab_end_index
127
+
128
+ assert self.org_vocab_start_index <= self.padded_org_vocab_start_index
129
+ assert (self.added_vocab_start_index
130
+ <= self.padded_added_vocab_start_index)
131
+ assert self.org_vocab_end_index <= self.padded_org_vocab_end_index
132
+ assert self.added_vocab_end_index <= self.padded_added_vocab_end_index
133
+
134
+ assert self.num_org_elements <= self.num_org_elements_padded
135
+ assert self.num_added_elements <= self.num_added_elements_padded
136
+
137
+
138
+ @torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
139
+ def get_masked_input_and_mask(
140
+ input_: torch.Tensor, org_vocab_start_index: int,
141
+ org_vocab_end_index: int, num_org_vocab_padding: int,
142
+ added_vocab_start_index: int,
143
+ added_vocab_end_index: int) -> Tuple[torch.Tensor, torch.Tensor]:
144
+ # torch.compile will fuse all of the pointwise ops below
145
+ # into a single kernel, making it very fast
146
+ org_vocab_mask = (input_ >= org_vocab_start_index) & (
147
+ input_ < org_vocab_end_index)
148
+ added_vocab_mask = (input_ >= added_vocab_start_index) & (
149
+ input_ < added_vocab_end_index)
150
+ added_offset = added_vocab_start_index - (
151
+ org_vocab_end_index - org_vocab_start_index) - num_org_vocab_padding
152
+ valid_offset = (org_vocab_start_index *
153
+ org_vocab_mask) + (added_offset * added_vocab_mask)
154
+ vocab_mask = org_vocab_mask | added_vocab_mask
155
+ input_ = vocab_mask * (input_ - valid_offset)
156
+ return input_, ~vocab_mask
157
+
158
+
159
+ class VocabParallelEmbedding(torch.nn.Module):
160
+ """Embedding parallelized in the vocabulary dimension.
161
+
162
+ Adapted from torch.nn.Embedding; note that we pad the vocabulary size to
163
+ make sure it is divisible by the number of model parallel GPUs.
164
+
165
+ In order to support various loading methods, we ensure that LoRA-added
166
+ embeddings are always at the end of TP-sharded tensors. In other words,
167
+ we shard base embeddings and LoRA embeddings separately (both padded),
168
+ and place them in the same tensor.
169
+ In this example, we will have the original vocab size = 1010,
170
+ added vocab size = 16 and padding to 64. Therefore, the total
171
+ vocab size with padding will be 1088 (because we first pad 1010 to
172
+ 1024, add 16, and then pad to 1088).
173
+ Therefore, the tensor format looks like the following:
174
+ TP1, rank 0 (no sharding):
175
+ |< --------BASE-------- >|< -BASE PADDING-- >|< -----LORA------ >|< -LORA PADDING-- >|
176
+ corresponding token_id: | 0 | 1 | ... | 1009 | -1 | ... | -1 | 1010 | ... | 1025 | -1 | ... | -1 |
177
+ index: | 0 | 1 | ... | 1009 | 1010 | ... | 1023 | 1024 | ... | 1039 | 1040 | ... | 1087 |
178
+
179
+ TP2, rank 0:
180
+ |< --------------------BASE--------------------- >|< -----LORA------ >|< -LORA PADDING- >|
181
+ corresponding token_id: | 0 | 1 | 2 | ... | 497 | 498 | ... | 511 | 1010 | ... | 1025 | -1 | ... | -1 |
182
+ index: | 0 | 1 | 2 | ... | 497 | 498 | ... | 511 | 512 | ... | 527 | 528 | ... | 543 |
183
+ TP2, rank 1:
184
+ |< -----------BASE----------- >|< -BASE PADDING- >|< -----------LORA PADDING----------- >|
185
+ corresponding token_id: | 512 | 513 | 514 | ... | 1009 | -1 | ... | -1 | -1 | ... | -1 | -1 | ... | -1 |
186
+ index: | 0 | 1 | 2 | ... | 497 | 498 | ... | 511 | 512 | ... | 519 | 520 | ... | 543 |
187
+
188
+ Args:
189
+ num_embeddings: vocabulary size.
190
+ embedding_dim: size of hidden state.
191
+ params_dtype: type of the parameters.
192
+ org_num_embeddings: original vocabulary size (without LoRA).
193
+ padding_size: padding size for the vocabulary.
194
+ quant_config: quant config for the layer
195
+ prefix: full name of the layer in the state dict
196
+ """ # noqa: E501
197
+
198
+ def __init__(self,
199
+ num_embeddings: int,
200
+ embedding_dim: int,
201
+ params_dtype: Optional[torch.dtype] = None,
202
+ org_num_embeddings: Optional[int] = None,
203
+ padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
204
+ quant_config: Optional[QuantizationConfig] = None,
205
+ prefix: str = ""):
206
+ super().__init__()
207
+
208
+ # Keep the input dimensions.
209
+ tp_rank = get_tensor_model_parallel_rank()
210
+ self.tp_size = get_tensor_model_parallel_world_size()
211
+ self.num_embeddings = num_embeddings
212
+ self.padding_size = padding_size
213
+ self.org_vocab_size = org_num_embeddings or num_embeddings
214
+ num_added_embeddings = num_embeddings - self.org_vocab_size
215
+ self.org_vocab_size_padded = pad_vocab_size(self.org_vocab_size,
216
+ self.padding_size)
217
+ self.num_embeddings_padded = pad_vocab_size(
218
+ self.org_vocab_size_padded + num_added_embeddings,
219
+ self.padding_size)
220
+ assert self.org_vocab_size_padded <= self.num_embeddings_padded
221
+
222
+ self.shard_indices = self._get_indices(self.num_embeddings_padded,
223
+ self.org_vocab_size_padded,
224
+ self.num_embeddings,
225
+ self.org_vocab_size, tp_rank,
226
+ self.tp_size)
227
+ self.embedding_dim = embedding_dim
228
+
229
+ linear_method = None
230
+ if quant_config is not None:
231
+ linear_method = quant_config.get_quant_method(self, prefix=prefix)
232
+ if linear_method is None:
233
+ linear_method = UnquantizedEmbeddingMethod()
234
+
235
+ # If we are making an embedding layer, then our quantization linear
236
+ # method must implement the embedding operation. If we are another
237
+ # layer type like ParallelLMHead, this is not important.
238
+ is_embedding_layer = type(self) is VocabParallelEmbedding
239
+ linear_method_implements_embedding = method_has_implemented_embedding(
240
+ type(linear_method))
241
+ if is_embedding_layer and not linear_method_implements_embedding:
242
+ raise NotImplementedError(
243
+ f"The class {type(linear_method).__name__} must implement "
244
+ "the 'embedding' method, see UnquantizedEmbeddingMethod.")
245
+
246
+ self.linear_method: QuantizeMethodBase = linear_method
247
+
248
+ if params_dtype is None:
249
+ params_dtype = torch.get_default_dtype()
250
+ # Divide the weight matrix along the vocabulary dimension.
251
+ self.num_added_embeddings = self.num_embeddings - self.org_vocab_size
252
+ self.num_embeddings_per_partition = divide(self.num_embeddings_padded,
253
+ self.tp_size)
254
+ assert (self.shard_indices.num_elements_padded ==
255
+ self.num_embeddings_per_partition)
256
+ self.num_org_embeddings_per_partition = (
257
+ self.shard_indices.org_vocab_end_index -
258
+ self.shard_indices.org_vocab_start_index)
259
+ self.num_added_embeddings_per_partition = (
260
+ self.shard_indices.added_vocab_end_index -
261
+ self.shard_indices.added_vocab_start_index)
262
+
263
+ self.linear_method.create_weights(self,
264
+ self.embedding_dim,
265
+ [self.num_embeddings_per_partition],
266
+ self.embedding_dim,
267
+ self.num_embeddings_padded,
268
+ params_dtype=params_dtype,
269
+ weight_loader=self.weight_loader)
270
+
271
+ @classmethod
272
+ def _get_indices(cls, vocab_size_padded: int, org_vocab_size_padded: int,
273
+ vocab_size: int, org_vocab_size: int, tp_rank: int,
274
+ tp_size: int) -> VocabParallelEmbeddingShardIndices:
275
+ """Get start and end indices for vocab parallel embedding, following the
276
+ layout outlined in the class docstring, based on the given tp_rank and
277
+ tp_size."""
278
+ num_added_embeddings_padded = vocab_size_padded - org_vocab_size_padded
279
+ padded_org_vocab_start_index, padded_org_vocab_end_index = (
280
+ vocab_range_from_global_vocab_size(org_vocab_size_padded, tp_rank,
281
+ tp_size))
282
+ padded_added_vocab_start_index, padded_added_vocab_end_index = (
283
+ vocab_range_from_global_vocab_size(num_added_embeddings_padded,
284
+ tp_rank,
285
+ tp_size,
286
+ offset=org_vocab_size))
287
+ # remove padding
288
+ org_vocab_start_index = min(padded_org_vocab_start_index,
289
+ org_vocab_size)
290
+ org_vocab_end_index = min(padded_org_vocab_end_index, org_vocab_size)
291
+ added_vocab_start_index = min(padded_added_vocab_start_index,
292
+ vocab_size)
293
+ added_vocab_end_index = min(padded_added_vocab_end_index, vocab_size)
294
+ return VocabParallelEmbeddingShardIndices(
295
+ padded_org_vocab_start_index, padded_org_vocab_end_index,
296
+ padded_added_vocab_start_index, padded_added_vocab_end_index,
297
+ org_vocab_start_index, org_vocab_end_index,
298
+ added_vocab_start_index, added_vocab_end_index)
299
+
300
+ def get_sharded_to_full_mapping(self) -> Optional[List[int]]:
301
+ """Get a mapping that can be used to reindex the gathered
302
+ logits for sampling.
303
+
304
+ During sampling, we gather logits from all ranks. The relationship
305
+ of index->token_id will follow the same format as outlined in the class
306
+ docstring. However, after the gather, we want to reindex the final
307
+ logits tensor to map index->token_id one-to-one (the index is always
308
+ equal to the token_id it corresponds to). The indices returned by this
309
+ method allow us to do that.
310
+ """
311
+ if self.tp_size < 2:
312
+ return None
313
+
314
+ base_embeddings: List[int] = []
315
+ added_embeddings: List[int] = []
316
+ padding: List[int] = []
317
+ for tp_rank in range(self.tp_size):
318
+ shard_indices = self._get_indices(self.num_embeddings_padded,
319
+ self.org_vocab_size_padded,
320
+ self.num_embeddings,
321
+ self.org_vocab_size, tp_rank,
322
+ self.tp_size)
323
+ range_start = self.num_embeddings_per_partition * tp_rank
324
+ range_end = self.num_embeddings_per_partition * (tp_rank + 1)
325
+ base_embeddings.extend(
326
+ range(range_start,
327
+ range_start + shard_indices.num_org_elements))
328
+ padding.extend(
329
+ range(range_start + shard_indices.num_org_elements,
330
+ range_start + shard_indices.num_org_elements_padded))
331
+ added_embeddings.extend(
332
+ range(
333
+ range_start + shard_indices.num_org_elements_padded,
334
+ range_start + shard_indices.num_org_elements_padded +
335
+ shard_indices.num_added_elements))
336
+ padding.extend(
337
+ range(
338
+ range_start + shard_indices.num_org_elements_padded +
339
+ shard_indices.num_added_elements,
340
+ range_start + shard_indices.num_org_elements_padded +
341
+ shard_indices.num_added_elements_padded))
342
+ assert (range_start + shard_indices.num_org_elements_padded +
343
+ shard_indices.num_added_elements_padded == range_end)
344
+ ret = base_embeddings + added_embeddings + padding
345
+ assert len(ret) == self.num_embeddings_padded
346
+ return ret
347
+
348
+ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
349
+ output_dim = getattr(param, "output_dim", None)
350
+ packed_dim = getattr(param, "packed_dim", None)
351
+
352
+ # If the parameter is a gguf weight, then load it directly.
353
+ if getattr(param, "is_gguf_weight_type", None):
354
+ param.data.copy_(loaded_weight)
355
+ param.weight_type = loaded_weight.item()
356
+ return
357
+ elif isinstance(param, UninitializedParameter):
358
+ shape = list(loaded_weight.shape)
359
+ if output_dim is not None:
360
+ shape[output_dim] = self.num_embeddings_per_partition
361
+ param.materialize(tuple(shape), dtype=loaded_weight.dtype)
362
+
363
+ # If parameter does not have output dim, then it should
364
+ # be copied onto all gpus (e.g. g_idx for act_order gptq).
365
+ if output_dim is None:
366
+ assert param.data.shape == loaded_weight.shape
367
+ param.data.copy_(loaded_weight)
368
+ return
369
+
370
+ # Shard indexes for loading the weight
371
+ start_idx = self.shard_indices.org_vocab_start_index
372
+ shard_size = self.shard_indices.org_vocab_end_index - start_idx
373
+
374
+ # If the param is packed on the same dim we are sharding on, then we
376
+ # need to adjust offsets of the loaded weight by the pack factor.
376
+ if packed_dim is not None and packed_dim == output_dim:
377
+ packed_factor = param.packed_factor if isinstance(
378
+ param, BasevLLMParameter) else param.pack_factor
379
+ assert loaded_weight.shape[output_dim] == (self.org_vocab_size //
380
+ packed_factor)
381
+ start_idx = start_idx // packed_factor
382
+ shard_size = shard_size // packed_factor
383
+ else:
384
+ assert loaded_weight.shape[output_dim] == self.org_vocab_size
385
+
386
+ # Copy the data. Select chunk corresponding to current shard.
387
+ loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
388
+
389
+ if current_platform.is_hpu():
390
+ # FIXME(kzawora): Weight copy with slicing bugs out on Gaudi here,
391
+ # so we're using a workaround. Remove this when fixed in
392
+ # HPU PT bridge.
393
+ padded_weight = torch.cat([
394
+ loaded_weight,
395
+ torch.zeros(param.shape[0] - loaded_weight.shape[0],
396
+ *loaded_weight.shape[1:])
397
+ ])
398
+ param.data.copy_(padded_weight)
399
+ else:
400
+ param[:loaded_weight.shape[0]].data.copy_(loaded_weight)
401
+ param[loaded_weight.shape[0]:].data.fill_(0)
402
+
403
+ def forward(self, input_):
404
+ if self.tp_size > 1:
405
+ # Build the mask.
406
+ masked_input, input_mask = get_masked_input_and_mask(
407
+ input_, self.shard_indices.org_vocab_start_index,
408
+ self.shard_indices.org_vocab_end_index,
409
+ self.shard_indices.num_org_vocab_padding,
410
+ self.shard_indices.added_vocab_start_index,
411
+ self.shard_indices.added_vocab_end_index)
412
+ else:
413
+ masked_input = input_
414
+ # Get the embeddings.
415
+ output_parallel = self.linear_method.embedding(self,
416
+ masked_input.long())
417
+ # Mask the output embedding.
418
+ if self.tp_size > 1:
419
+ output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)
420
+ # Reduce across all the model parallel GPUs.
421
+ output = tensor_model_parallel_all_reduce(output_parallel)
422
+ return output
423
+
424
+ def extra_repr(self) -> str:
425
+ s = f"num_embeddings={self.num_embeddings_per_partition}"
426
+ s += f", embedding_dim={self.embedding_dim}"
427
+ s += f", org_vocab_size={self.org_vocab_size}"
428
+ s += f', num_embeddings_padded={self.num_embeddings_padded}'
429
+ s += f', tp_size={self.tp_size}'
430
+ return s
431
+
432
+
433
+ class ParallelLMHead(VocabParallelEmbedding):
434
+ """Parallelized LM head.
435
+
436
+ Output logits weight matrices used in the Sampler. The weight and bias
437
+ tensors are padded to make sure they are divisible by the number of
438
+ model parallel GPUs.
439
+
440
+ Args:
441
+ num_embeddings: vocabulary size.
442
+ embedding_dim: size of hidden state.
443
+ bias: whether to use bias.
444
+ params_dtype: type of the parameters.
445
+ org_num_embeddings: original vocabulary size (without LoRA).
446
+ padding_size: padding size for the vocabulary.
447
+ """
448
+
449
+ def __init__(self,
450
+ num_embeddings: int,
451
+ embedding_dim: int,
452
+ bias: bool = False,
453
+ params_dtype: Optional[torch.dtype] = None,
454
+ org_num_embeddings: Optional[int] = None,
455
+ padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
456
+ quant_config: Optional[QuantizationConfig] = None,
457
+ prefix: str = ""):
458
+ super().__init__(num_embeddings, embedding_dim, params_dtype,
459
+ org_num_embeddings, padding_size, quant_config,
460
+ prefix)
461
+ self.quant_config = quant_config
462
+ if bias:
463
+ self.bias = Parameter(
464
+ torch.empty(self.num_embeddings_per_partition,
465
+ dtype=params_dtype))
466
+ set_weight_attrs(self.bias, {
467
+ "output_dim": 0,
468
+ "weight_loader": self.weight_loader,
469
+ })
470
+ else:
471
+ self.register_parameter("bias", None)
472
+
473
+ def tie_weights(self, embed_tokens: VocabParallelEmbedding):
474
+ """Tie the weights with word embeddings."""
475
+ # GGUF quantized embed_tokens.
476
+ if self.quant_config and self.quant_config.get_name() == "gguf":
477
+ return embed_tokens
478
+ else:
479
+ self.weight = embed_tokens.weight
480
+ return self
481
+
482
+ def forward(self, input_):
483
+ del input_
484
+ raise RuntimeError("LMHead's weights should be used in the sampler.")
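To make the padding and sharding layout described in the VocabParallelEmbedding docstring concrete, here is a standalone sketch of the range arithmetic. It reimplements pad_vocab_size and the per-rank range split locally so it runs without the vLLM runtime; the clipping comments mirror what _get_indices does with min():

    def pad_vocab_size(vocab_size: int, pad_to: int = 64) -> int:
        # Round vocab_size up to the next multiple of pad_to.
        return ((vocab_size + pad_to - 1) // pad_to) * pad_to

    def vocab_range(global_vocab_size: int, rank: int, world_size: int,
                    offset: int = 0) -> tuple:
        # Contiguous slice of the (padded) vocab owned by this rank.
        per_partition = global_vocab_size // world_size
        start = rank * per_partition
        return start + offset, start + per_partition + offset

    org_vocab_size, added_vocab_size, tp_size = 1010, 16, 2

    org_padded = pad_vocab_size(org_vocab_size)                   # 1024
    total_padded = pad_vocab_size(org_padded + added_vocab_size)  # 1088

    for rank in range(tp_size):
        base = vocab_range(org_padded, rank, tp_size)
        lora = vocab_range(total_padded - org_padded, rank, tp_size,
                           offset=org_vocab_size)
        print(rank, base, lora)
    # rank 0 -> base ids [0, 512),    LoRA ids [1010, 1042) -> clipped to [1010, 1026)
    # rank 1 -> base ids [512, 1024), LoRA ids [1042, 1074) -> clipped to [1026, 1026)

Each rank therefore owns 544 padded slots (1088 / 2): 512 base slots followed by 32 added slots, with everything past the clipped ranges filled by padding, matching the index tables in the docstring above.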
.venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (907 Bytes). View file
 
.venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/adapters.cpython-311.pyc ADDED
Binary file (10.8 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/arctic.cpython-311.pyc ADDED
Binary file (28.3 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/bert.cpython-311.pyc ADDED
Binary file (28.3 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/blip2.cpython-311.pyc ADDED
Binary file (34.8 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/bloom.cpython-311.pyc ADDED
Binary file (18 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/chameleon.cpython-311.pyc ADDED
Binary file (57.5 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/chatglm.cpython-311.pyc ADDED
Binary file (34.6 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/clip.cpython-311.pyc ADDED
Binary file (25.4 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/decilm.cpython-311.pyc ADDED
Binary file (4.97 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/deepseek.cpython-311.pyc ADDED
Binary file (23.4 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/eagle.cpython-311.pyc ADDED
Binary file (11.1 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/fairseq2_llama.cpython-311.pyc ADDED
Binary file (7.54 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/falcon.cpython-311.pyc ADDED
Binary file (23 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/glm.cpython-311.pyc ADDED
Binary file (1.62 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/gpt2.cpython-311.pyc ADDED
Binary file (16.1 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/model_executor/models/__pycache__/gpt_j.cpython-311.pyc ADDED
Binary file (16.7 kB). View file