from typing import Optional, Tuple

import torch
from sgl_kernel.utils import _get_cache_buf


def awq_dequantize(
    qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor
) -> torch.Tensor:
    return torch.ops.sgl_kernel.awq_dequantize.default(qweight, scales, qzeros)


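# Example (illustrative sketch): unpack an AWQ 4-bit checkpoint weight back to
# fp16, e.g. for debugging. The shapes are assumptions for illustration (AWQ
# packs 8 x 4-bit values per int32 along the output dimension), not a contract
# enforced by the kernel.
#
#     group_size = 128
#     qweight = torch.randint(0, 2**31 - 1, (4096, 512), dtype=torch.int32, device="cuda")
#     qzeros = torch.randint(0, 2**31 - 1, (4096 // group_size, 512), dtype=torch.int32, device="cuda")
#     scales = torch.rand(4096 // group_size, 4096, dtype=torch.float16, device="cuda")
#     weight = awq_dequantize(qweight, scales, qzeros)  # -> (4096, 4096) fp16

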
def int8_scaled_mm(
    mat_a: torch.Tensor,
    mat_b: torch.Tensor,
    scales_a: torch.Tensor,
    scales_b: torch.Tensor,
    out_dtype: torch.dtype,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    return torch.ops.sgl_kernel.int8_scaled_mm.default(
        mat_a,
        mat_b,
        scales_a,
        scales_b,
        out_dtype,
        bias,
    )


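# Example (illustrative sketch): W8A8 matmul with dequantization scales. The
# scale layout here is an assumption for illustration: scales_a per row of
# mat_a (per token), scales_b per column of mat_b (per channel).
#
#     m, k, n = 16, 512, 1024
#     mat_a = torch.randint(-128, 127, (m, k), dtype=torch.int8, device="cuda")
#     mat_b = torch.randint(-128, 127, (k, n), dtype=torch.int8, device="cuda")
#     scales_a = torch.rand(m, 1, dtype=torch.float32, device="cuda")
#     scales_b = torch.rand(1, n, dtype=torch.float32, device="cuda")
#     out = int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, torch.float16)

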
def fp8_blockwise_scaled_mm(
    mat_a: torch.Tensor,
    mat_b: torch.Tensor,
    scales_a: torch.Tensor,
    scales_b: torch.Tensor,
    out_dtype: torch.dtype,
) -> torch.Tensor:
    return torch.ops.sgl_kernel.fp8_blockwise_scaled_mm.default(
        mat_a,
        mat_b,
        scales_a,
        scales_b,
        out_dtype,
    )


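# Example (illustrative sketch): fp8 matmul with blockwise scales in the
# DeepSeek-V3 style, i.e. 1x128 activation blocks and 128x128 weight blocks.
# The block granularity is an assumption for illustration.
#
#     m, k, n = 16, 512, 1024
#     mat_a = torch.randn(m, k, device="cuda").to(torch.float8_e4m3fn)
#     mat_b = torch.randn(k, n, device="cuda").to(torch.float8_e4m3fn)
#     scales_a = torch.rand(m, k // 128, dtype=torch.float32, device="cuda")
#     scales_b = torch.rand(k // 128, n // 128, dtype=torch.float32, device="cuda")
#     out = fp8_blockwise_scaled_mm(mat_a, mat_b, scales_a, scales_b, torch.bfloat16)

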
def fp8_scaled_mm(
    mat_a: torch.Tensor,
    mat_b: torch.Tensor,
    scales_a: torch.Tensor,
    scales_b: torch.Tensor,
    out_dtype: torch.dtype,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    return torch.ops.sgl_kernel.fp8_scaled_mm.default(
        mat_a,
        mat_b,
        scales_a,
        scales_b,
        out_dtype,
        bias,
    )


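# Example (illustrative sketch): scaled fp8 matmul with an optional bias. As
# with int8_scaled_mm above, the operand and scale shapes are assumptions.
#
#     out = fp8_scaled_mm(
#         mat_a_fp8,  # (m, k) float8_e4m3fn
#         mat_b_fp8,  # (k, n) float8_e4m3fn
#         scales_a,   # e.g. (m, 1) float32
#         scales_b,   # e.g. (1, n) float32
#         torch.bfloat16,
#         bias=torch.randn(n, dtype=torch.bfloat16, device="cuda"),
#     )

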
def _bmm_fp8_internal(
    workspace_buffer: torch.Tensor,
    A: torch.Tensor,
    B: torch.Tensor,
    D: torch.Tensor,
    A_scale: torch.Tensor,
    B_scale: torch.Tensor,
) -> None:
    cublas_handle = torch.cuda.current_blas_handle()
    torch.ops.sgl_kernel.bmm_fp8.default(
        A,
        B,
        D,
        A_scale,
        B_scale,
        workspace_buffer,
        cublas_handle,
    )


def bmm_fp8(
    A: torch.Tensor,
    B: torch.Tensor,
    A_scale: torch.Tensor,
    B_scale: torch.Tensor,
    dtype: torch.dtype,
    out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    if out is None:
        out = torch.empty(
            (A.shape[0], A.shape[1], B.shape[2]),
            device=A.device,
            dtype=dtype,
        )
    workspace_buffer = _get_cache_buf("bmm_fp8_workspace", 32 * 1024 * 1024, A.device)
    _bmm_fp8_internal(workspace_buffer, A, B, out, A_scale, B_scale)
    return out


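# Example (illustrative sketch): batched fp8 GEMM. Per the allocation above,
# A is (b, m, k) and B is (b, k, n); the fp8 flavor and scalar scale tensors
# are assumptions for illustration.
#
#     b, m, k, n = 8, 16, 256, 512
#     A = torch.randn(b, m, k, device="cuda").to(torch.float8_e4m3fn)
#     B = torch.randn(b, k, n, device="cuda").to(torch.float8_e4m3fn)
#     A_scale = torch.tensor(1.0, dtype=torch.float32, device="cuda")
#     B_scale = torch.tensor(1.0, dtype=torch.float32, device="cuda")
#     out = bmm_fp8(A, B, A_scale, B_scale, torch.bfloat16)  # (b, m, n)

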
def dsv3_fused_a_gemm(
    mat_a: torch.Tensor,
    mat_b: torch.Tensor,
    output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    if output is None:
        output = torch.empty(
            (mat_a.shape[0], mat_b.shape[1]),
            device=mat_a.device,
            dtype=mat_a.dtype,
        )
    torch.ops.sgl_kernel.dsv3_fused_a_gemm.default(output, mat_a, mat_b)
    return output


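# Example (illustrative sketch): the fused A-projection GEMM for the
# DeepSeek-V3 low-rank attention path. From the allocation above, mat_a is
# (num_tokens, k) and mat_b is (k, n); the concrete sizes below are
# assumptions for illustration.
#
#     mat_a = torch.randn(128, 7168, dtype=torch.bfloat16, device="cuda")
#     mat_b = torch.randn(7168, 2112, dtype=torch.bfloat16, device="cuda")
#     out = dsv3_fused_a_gemm(mat_a, mat_b)  # (128, 2112), same dtype as mat_a

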
def sgl_per_token_group_quant_8bit(
    input: torch.Tensor,
    output_q: torch.Tensor,
    output_s: torch.Tensor,
    group_size: int,
    eps: float,
    fp8_min: float,
    fp8_max: float,
    scale_ue8m0: bool = False,
    fuse_silu_and_mul: bool = False,
    masked_m: Optional[torch.Tensor] = None,
    enable_v2: Optional[bool] = None,
) -> None:
    if enable_v2 is None:
        from sglang.srt.utils import get_bool_env_var

        enable_v2 = get_bool_env_var("SGLANG_PER_TOKEN_GROUP_QUANT_8BIT_V2")

    if enable_v2:
        return torch.ops.sgl_kernel.sgl_per_token_group_quant_8bit_v2.default(
            input,
            output_q,
            output_s,
            group_size,
            eps,
            fp8_min,
            fp8_max,
            scale_ue8m0,
            fuse_silu_and_mul,
            masked_m,
        )

    assert not fuse_silu_and_mul, "only v2 supports fuse_silu_and_mul"
    assert masked_m is None, "only v2 supports masked_m"
    torch.ops.sgl_kernel.sgl_per_token_group_quant_8bit.default(
        input, output_q, output_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0
    )


# Both aliases share one implementation; fp8_min/fp8_max carry the target
# range of the chosen 8-bit format.
sgl_per_token_group_quant_fp8 = sgl_per_token_group_quant_8bit
sgl_per_token_group_quant_int8 = sgl_per_token_group_quant_8bit


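# Example (illustrative sketch): per-token-group fp8 quantization into
# caller-allocated buffers. The layout is an assumption for illustration:
# one scale per group of `group_size` contiguous elements along the last dim.
#
#     tokens, hidden, group_size = 64, 4096, 128
#     x = torch.randn(tokens, hidden, dtype=torch.bfloat16, device="cuda")
#     x_q = torch.empty(tokens, hidden, dtype=torch.float8_e4m3fn, device="cuda")
#     x_s = torch.empty(tokens, hidden // group_size, dtype=torch.float32, device="cuda")
#     fp8_info = torch.finfo(torch.float8_e4m3fn)
#     sgl_per_token_group_quant_fp8(
#         x, x_q, x_s, group_size, eps=1e-10, fp8_min=fp8_info.min, fp8_max=fp8_info.max
#     )

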
def sgl_per_tensor_quant_fp8(
    input: torch.Tensor,
    output_q: torch.Tensor,
    output_s: torch.Tensor,
    is_static: bool,
) -> None:
    torch.ops.sgl_kernel.sgl_per_tensor_quant_fp8.default(
        input, output_q, output_s, is_static
    )


def sgl_per_token_quant_fp8(
    input: torch.Tensor,
    output_q: torch.Tensor,
    output_s: torch.Tensor,
) -> None:
    torch.ops.sgl_kernel.sgl_per_token_quant_fp8.default(input, output_q, output_s)


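# Example (illustrative sketch): the per-tensor variant writes a single scale
# (is_static presumably selects whether output_s is taken as given or computed
# from the input), while the per-token variant writes one scale per row.
# Buffer shapes are assumptions for illustration.
#
#     x = torch.randn(64, 4096, dtype=torch.bfloat16, device="cuda")
#     x_q = torch.empty(64, 4096, dtype=torch.float8_e4m3fn, device="cuda")
#     s_tensor = torch.zeros(1, dtype=torch.float32, device="cuda")
#     sgl_per_tensor_quant_fp8(x, x_q, s_tensor, is_static=False)
#
#     s_token = torch.empty(64, 1, dtype=torch.float32, device="cuda")
#     sgl_per_token_quant_fp8(x, x_q, s_token)

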
def cutlass_scaled_fp4_mm(
    a: torch.Tensor,
    b: torch.Tensor,
    block_scale_a: torch.Tensor,
    block_scale_b: torch.Tensor,
    alpha: torch.Tensor,
    out_dtype: torch.dtype,
) -> torch.Tensor:
    from sglang.jit_kernel.nvfp4 import (
        cutlass_scaled_fp4_mm as jit_cutlass_scaled_fp4_mm,
    )

    return jit_cutlass_scaled_fp4_mm(
        a,
        b,
        block_scale_a,
        block_scale_b,
        alpha,
        out_dtype,
    )


def scaled_fp4_quant(
    input: torch.Tensor, input_global_scale: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize the input tensor to FP4 and return the quantized tensor and scales.

    This function quantizes the last dimension of the given tensor `input`. For
    every 16 consecutive elements, a single dynamically computed scaling factor
    is shared. This scaling factor is quantized using the `input_global_scale`
    and is stored in a swizzled layout (see
    https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-b-layout-4x).

    Args:
        input: The input tensor to be quantized to FP4.
        input_global_scale: A scalar scaling factor for the entire tensor.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP4, with every
        two values packed into a uint8, and the float8_e4m3 scaling factors in
        a swizzled layout.
    """
    from sglang.jit_kernel.nvfp4 import scaled_fp4_quant as jit_scaled_fp4_quant

    return jit_scaled_fp4_quant(input, input_global_scale)


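# Example (illustrative sketch): an NVFP4 pipeline combining scaled_fp4_quant
# with cutlass_scaled_fp4_mm above. The global-scale convention (fp8 max 448
# times fp4 max 6) and the alpha term that folds the global scales back out of
# the accumulated product are assumptions for illustration.
#
#     m, n, k = 128, 4096, 4096
#     a = torch.randn(m, k, dtype=torch.bfloat16, device="cuda")
#     b = torch.randn(n, k, dtype=torch.bfloat16, device="cuda")
#     a_gs = (448.0 * 6.0) / a.abs().max().float()
#     b_gs = (448.0 * 6.0) / b.abs().max().float()
#     a_fp4, a_blockscale = scaled_fp4_quant(a, a_gs)
#     b_fp4, b_blockscale = scaled_fp4_quant(b, b_gs)
#     alpha = 1.0 / (a_gs * b_gs)
#     out = cutlass_scaled_fp4_mm(
#         a_fp4, b_fp4, a_blockscale, b_blockscale, alpha, torch.bfloat16
#     )

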
def qserve_w4a8_per_chn_gemm(
    in_feats: torch.Tensor,
    kernel: torch.Tensor,
    wscales: torch.Tensor,
    ascales: torch.Tensor,
    w_szs: torch.Tensor,
    a_ssums: torch.Tensor,
    out_feats: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    if out_feats is None:
        # The QServe kernel writes fp16 outputs.
        out_feats = torch.empty(
            (in_feats.shape[0], kernel.shape[0]),
            device=in_feats.device,
            dtype=torch.float16,
        )
    torch.ops.sgl_kernel.qserve_w4a8_per_chn_gemm.default(
        in_feats, kernel, wscales, ascales, w_szs, a_ssums, out_feats
    )
    return out_feats


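# Example (illustrative sketch): QServe W4A8 GEMM with per-channel weight
# quantization. The operand layouts are assumptions for illustration: int8
# activations, 4-bit weights packed two per int8 byte, per-channel weight
# scales/zero-sums, and per-token activation scales/sums for dequantization.
# The per-group variant below additionally takes zeros and scales_i8.
#
#     m, k, n = 16, 4096, 4096
#     in_feats = torch.randint(-128, 127, (m, k), dtype=torch.int8, device="cuda")
#     kernel = torch.randint(-128, 127, (n, k // 2), dtype=torch.int8, device="cuda")
#     wscales = torch.rand(1, n, dtype=torch.float16, device="cuda")
#     w_szs = torch.rand(1, n, dtype=torch.float16, device="cuda")
#     ascales = torch.rand(1, m, dtype=torch.float16, device="cuda")
#     a_ssums = torch.rand(1, m, dtype=torch.float16, device="cuda")
#     out = qserve_w4a8_per_chn_gemm(in_feats, kernel, wscales, ascales, w_szs, a_ssums)

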
def qserve_w4a8_per_group_gemm(
    in_feats: torch.Tensor,
    kernel: torch.Tensor,
    zeros: torch.Tensor,
    scales_i8: torch.Tensor,
    wscales: torch.Tensor,
    ascales: torch.Tensor,
    out_feats: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    if out_feats is None:
        # The QServe kernel writes fp16 outputs.
        out_feats = torch.empty(
            (in_feats.shape[0], kernel.shape[0]),
            device=in_feats.device,
            dtype=torch.float16,
        )
    torch.ops.sgl_kernel.qserve_w4a8_per_group_gemm.default(
        in_feats, kernel, zeros, scales_i8, wscales, ascales, out_feats
    )
    return out_feats


def dsv3_router_gemm(
    hidden_states: torch.Tensor,
    router_weights: torch.Tensor,
    out_dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor:
    output = torch.empty(
        hidden_states.shape[0],
        router_weights.shape[0],
        device=hidden_states.device,
        dtype=out_dtype,
    )
    torch.ops.sgl_kernel.dsv3_router_gemm.default(
        output,
        hidden_states,
        router_weights,
    )
    return output


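# Example (illustrative sketch): router-logit GEMM for a DeepSeek-V3-style MoE
# gate. From the allocation above, router_weights is laid out as
# (num_experts, hidden); the concrete sizes are assumptions for illustration.
#
#     hidden_states = torch.randn(128, 7168, dtype=torch.bfloat16, device="cuda")
#     router_weights = torch.randn(256, 7168, dtype=torch.bfloat16, device="cuda")
#     logits = dsv3_router_gemm(hidden_states, router_weights)  # (128, 256)

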
def shuffle_rows(
    input_tensor: torch.Tensor,
    dst2src_map: torch.Tensor,
    output_tensor_shape: Tuple[int, ...],
) -> torch.Tensor:
    output_tensor = torch.empty(
        output_tensor_shape,
        device=input_tensor.device,
        dtype=input_tensor.dtype,
    )
    torch.ops.sgl_kernel.shuffle_rows.default(input_tensor, dst2src_map, output_tensor)
    return output_tensor


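# Example (illustrative sketch): gather rows by a destination-to-source index
# map, e.g. permuting tokens before a grouped expert GEMM. That row i of the
# output is row dst2src_map[i] of the input is an assumption based on the
# argument name.
#
#     x = torch.randn(4, 8, dtype=torch.float16, device="cuda")
#     dst2src = torch.tensor([2, 0, 3, 1], dtype=torch.int32, device="cuda")
#     y = shuffle_rows(x, dst2src, (4, 8))

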
def scaled_fp4_grouped_quant(
    input_tensor: torch.Tensor,
    input_global_scale: torch.Tensor,
    mask: torch.Tensor,
):
    """
    Quantize the input tensor to FP4 and return the quantized tensor and
    scales, for grouped gemm inputs (e.g., grouped_gemm_nt_masked for
    flashinfer).
    Args:
        input_tensor: The input tensor to be quantized to FP4, with shape
            (l, m, k), where l is the number of groups, m is the number of
            tokens per group, and k is the number of features.
        input_global_scale: Global scaling factors, one per group, with
            shape (l,).
        mask: The mask tensor, with shape (l,).
    Returns:
        output: The quantized tensor in FP4, with shape (m, k // 2, l) but
            physical layout (l, m, k // 2). `// 2` is because two fp4 values
            are packed into a uint8.
        output_scales: The blockscale tensor in FP8-E4M3, with shape
            (32, 4, rm, 4, rk, l) but physical layout (l, rm, rk, 32, 4, 4).
    Note:
        For the shape of output_scales, `32 * 4 * rm` is m padded to the
        nearest multiple of 128, and `4 * rk` is `k // 16` padded to the
        nearest multiple of 4. These layout constants are required by the
        NVIDIA Blackwell MMA operations.
    """
    from sglang.jit_kernel.nvfp4 import (
        scaled_fp4_grouped_quant as jit_scaled_fp4_grouped_quant,
    )

    return jit_scaled_fp4_grouped_quant(input_tensor, input_global_scale, mask)


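# Example (illustrative sketch): grouped NVFP4 quantization feeding a masked
# grouped GEMM. The mask semantics (number of valid tokens per group) and the
# global-scale convention are assumptions based on the flashinfer
# grouped_gemm_nt_masked usage mentioned in the docstring. The
# silu_and_mul variant below works the same way on (l, m, 2 * k) input.
#
#     l, m, k = 8, 128, 4096
#     x = torch.randn(l, m, k, dtype=torch.bfloat16, device="cuda")
#     global_scales = ((448.0 * 6.0) / x.abs().amax(dim=(1, 2))).float()  # (l,)
#     valid_tokens = torch.full((l,), m, dtype=torch.int32, device="cuda")
#     x_fp4, x_blockscales = scaled_fp4_grouped_quant(x, global_scales, valid_tokens)

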
def silu_and_mul_scaled_fp4_grouped_quant(
    input_tensor: torch.Tensor,
    input_global_scale: torch.Tensor,
    mask: torch.Tensor,
):
    """
    Apply silu_and_mul to the input tensor, quantize the result to FP4, and
    return the quantized tensor and scales, for grouped gemm inputs (e.g.,
    grouped_gemm_nt_masked for flashinfer).
    Args:
        input_tensor: The input tensor to be quantized to FP4, with shape
            (l, m, k * 2), where l is the number of groups, m is the number of
            tokens per group, and k is the number of features after
            silu_and_mul.
        input_global_scale: Global scaling factors, one per group, with
            shape (l,).
        mask: The mask tensor, with shape (l,).
    Returns:
        output: The quantized tensor in FP4, with shape (m, k // 2, l) but
            physical layout (l, m, k // 2). `// 2` is because two fp4 values
            are packed into a uint8.
        output_scales: The blockscale tensor in FP8-E4M3, with shape
            (32, 4, rm, 4, rk, l) but physical layout (l, rm, rk, 32, 4, 4).
    Note:
        For the shape of output_scales, `32 * 4 * rm` is m padded to the
        nearest multiple of 128, and `4 * rk` is `k // 16` padded to the
        nearest multiple of 4. These layout constants are required by the
        NVIDIA Blackwell MMA operations.
    """
    from sglang.jit_kernel.nvfp4 import (
        silu_and_mul_scaled_fp4_grouped_quant as jit_silu_and_mul_scaled_fp4_grouped_quant,
    )

    return jit_silu_and_mul_scaled_fp4_grouped_quant(
        input_tensor,
        input_global_scale,
        mask,
    )


def gptq_gemm(
    a: torch.Tensor,
    b_q_weight: torch.Tensor,
    b_gptq_qzeros: torch.Tensor,
    b_gptq_scales: torch.Tensor,
    b_g_idx: torch.Tensor,
    use_shuffle: bool,
    bit: int,
) -> torch.Tensor:
    return torch.ops.sgl_kernel.gptq_gemm.default(
        a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, use_shuffle, bit
    )


def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor, bit: int) -> None:
    torch.ops.sgl_kernel.gptq_shuffle.default(q_weight, q_perm, bit)


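# Example (illustrative sketch): GPTQ path combining gptq_shuffle and
# gptq_gemm. The shapes assume 4-bit packing (8 values per int32 along the
# input dimension) and are for illustration only.
#
#     m, k, n, group_size, bit = 16, 4096, 4096, 128, 4
#     a = torch.randn(m, k, dtype=torch.float16, device="cuda")
#     b_q_weight = torch.randint(0, 2**31 - 1, (k // 8, n), dtype=torch.int32, device="cuda")
#     b_gptq_qzeros = torch.randint(0, 2**31 - 1, (k // group_size, n // 8), dtype=torch.int32, device="cuda")
#     b_gptq_scales = torch.rand(k // group_size, n, dtype=torch.float16, device="cuda")
#     b_g_idx = torch.arange(k, dtype=torch.int32, device="cuda") // group_size
#     gptq_shuffle(b_q_weight, b_g_idx, bit)  # reorder packed weights once at load time
#     out = gptq_gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, True, bit)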