| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| import warnings |
| from copy import deepcopy |
| from typing import List, Optional |
|
|
| import torch |
| from compressed_tensors.config import CompressionFormat |
| from compressed_tensors.quantization.quant_args import ( |
| FP8_E4M3_DATA, |
| DynamicType, |
| QuantizationArgs, |
| QuantizationStrategy, |
| QuantizationType, |
| ) |
| from pydantic import BaseModel, ConfigDict, model_validator |
|
|
|
|
# Public API of this module: the scheme model plus the preset lookup helpers.
__all__ = [
    "QuantizationScheme",
    "preset_name_to_scheme",
    "is_preset_scheme",
]
|
|
|
|
class QuantizationScheme(BaseModel):
    """
    Set of QuantizationArgs defining how the weights, inputs and outputs of target list
    of modules should be quantized

    :param targets: list of modules to apply the QuantizationArgs to, can be layer
        names, layer types or a regular expression, typically ["Linear"]
    :param weights: quantization config for layer weights
    :param input_activations: quantization config for layer inputs
    :param output_activations: quantization config for layer outputs
    :param format: CompressionFormat for the layer
    """

    targets: List[str]
    weights: Optional[QuantizationArgs] = None
    input_activations: Optional[QuantizationArgs] = None
    output_activations: Optional[QuantizationArgs] = None
    format: Optional[str] = None

    @model_validator(mode="after")
    def validate_model_after(model: "QuantizationScheme") -> "QuantizationScheme":
        """
        Cross-field validation run after individual field parsing.

        :raises NotImplementedError: if an activation quantization strategy has
            no runtime support
        :raises ValueError: if actorder is set on activations, or if the
            mixed-precision format is used at the scheme level
        """
        inputs = model.input_activations
        outputs = model.output_activations
        weights = model.weights
        format = model.format

        if inputs is not None:
            if inputs.strategy not in (
                QuantizationStrategy.TOKEN,
                QuantizationStrategy.TENSOR,
                QuantizationStrategy.GROUP,
                QuantizationStrategy.TENSOR_GROUP,
                QuantizationStrategy.ATTN_HEAD,
            ):
                raise NotImplementedError(
                    f"Using {inputs.strategy} strategy is not supported for "
                    "activation quantization"
                )

            # BUG FIX: this check was previously nested inside the
            # unsupported-strategy branch above, where GROUP can never occur
            # (GROUP is in the allowed tuple), making it dead code. It is now a
            # sibling check, and the condition matches the error message:
            # group-wise activation quantization is only supported when fully
            # dynamic (static and DynamicType.LOCAL are rejected). The
            # fully-dynamic MXFP4/FP8_BLOCK presets below remain valid.
            if (
                inputs.strategy == QuantizationStrategy.GROUP
                and inputs.dynamic is not True
            ):
                raise NotImplementedError(
                    "Static and local group-wise activation "
                    "quantization is not supported"
                )

            if inputs.actorder is not None:
                raise ValueError("Cannot apply actorder to input activations")

        if outputs is not None:
            if outputs.actorder is not None:
                raise ValueError("Cannot apply actorder to output activations")

        # mixed-precision is a model-level format, not a per-scheme one
        if format == CompressionFormat.mixed_precision.value:
            raise ValueError(
                "mixed-precision cannot be set as a format for a QuantizationScheme"
            )

        # Mismatched weight/activation group sizes are legal but complicate
        # fused kernels, so warn rather than fail.
        if (
            inputs
            and weights
            and weights.strategy == QuantizationStrategy.GROUP
            and inputs.strategy == QuantizationStrategy.GROUP
            and weights.group_size != inputs.group_size
        ):
            warnings.warn(
                "Using GROUP strategy for both weights and input_activations "
                f"with different group sizes ({weights.group_size} vs "
                f"{inputs.group_size}) may complicate fused kernel implementations. "
                "Consider using TENSOR_GROUP strategy for both or matching group"
                " sizes.",
                UserWarning,
                stacklevel=2,
            )

        return model

    model_config = ConfigDict(extra="forbid")
|
|
|
|
| """ |
| Pre-Set Quantization Scheme Args |
| """ |
|
|
|
|
def preset_name_to_scheme(name: str, targets: List[str]) -> QuantizationScheme:
    """
    Build a new QuantizationScheme from a named preset.

    :param name: preset quantization settings name. must exist in upper case in
        PRESET_SCHEMES
    :param targets: list of quantization targets to be passed to the Scheme
    :return: new QuantizationScheme for a given name with the given targets
    """
    upper_name = name.upper()

    if upper_name not in PRESET_SCHEMES:
        raise KeyError(
            f"Unknown preset scheme name {upper_name}, "
            f"available names: {list(PRESET_SCHEMES.keys())}"
        )

    # deepcopy so callers can mutate the resulting scheme without touching
    # the shared preset definitions
    return QuantizationScheme(targets=targets, **deepcopy(PRESET_SCHEMES[upper_name]))
|
|
|
|
def is_preset_scheme(name: str) -> bool:
    """
    :param name: preset quantization settings name
    :return: True if the name is a preset scheme name
    """
    normalized = name.upper()
    return normalized in PRESET_SCHEMES
|
|
|
|
| UNQUANTIZED = dict() |
|
|
# NVFP4A16: FP4 weight-only scheme (group size 16, FP8 scales/zero-points);
# activations are not quantized.
NVFP4A16 = {
    "weights": QuantizationArgs(
        num_bits=4,
        type=QuantizationType.FLOAT,
        strategy=QuantizationStrategy.TENSOR_GROUP,
        group_size=16,
        symmetric=True,
        dynamic=False,
        scale_dtype=FP8_E4M3_DATA.dtype,
        zp_dtype=FP8_E4M3_DATA.dtype,
    ),
}
|
|
|
|
# NVFP4: FP4 weights plus FP4 input activations. Weights are static; activations
# use DynamicType.LOCAL. Both use group size 16 with FP8 scales/zero-points and
# the "static_minmax" observer.
NVFP4 = {
    "weights": QuantizationArgs(
        num_bits=4,
        type=QuantizationType.FLOAT,
        strategy=QuantizationStrategy.TENSOR_GROUP,
        group_size=16,
        symmetric=True,
        dynamic=False,
        observer="static_minmax",
        scale_dtype=FP8_E4M3_DATA.dtype,
        zp_dtype=FP8_E4M3_DATA.dtype,
    ),
    "input_activations": QuantizationArgs(
        num_bits=4,
        type=QuantizationType.FLOAT,
        strategy=QuantizationStrategy.TENSOR_GROUP,
        group_size=16,
        symmetric=True,
        dynamic=DynamicType.LOCAL,
        observer="static_minmax",
        scale_dtype=FP8_E4M3_DATA.dtype,
        zp_dtype=FP8_E4M3_DATA.dtype,
    ),
}
|
|
# MXFP4A16: FP4 weight-only scheme, group size 32, uint8 scales/zero-points;
# activations are not quantized.
MXFP4A16 = {
    "weights": QuantizationArgs(
        num_bits=4,
        type=QuantizationType.FLOAT,
        strategy=QuantizationStrategy.GROUP,
        group_size=32,
        symmetric=True,
        dynamic=False,
        scale_dtype=torch.uint8,
        zp_dtype=torch.uint8,
    ),
}
|
|
# MXFP4: static FP4 weights plus fully-dynamic FP4 input activations, both with
# group size 32 and uint8 scales/zero-points.
MXFP4 = {
    "weights": QuantizationArgs(
        num_bits=4,
        type=QuantizationType.FLOAT,
        strategy=QuantizationStrategy.GROUP,
        group_size=32,
        symmetric=True,
        dynamic=False,
        scale_dtype=torch.uint8,
        zp_dtype=torch.uint8,
    ),
    "input_activations": QuantizationArgs(
        num_bits=4,
        type=QuantizationType.FLOAT,
        strategy=QuantizationStrategy.GROUP,
        group_size=32,
        symmetric=True,
        dynamic=True,
        scale_dtype=torch.uint8,
        zp_dtype=torch.uint8,
    ),
}
|
|
|
|
| |
# INT8_W8A8: static symmetric per-channel int8 weights plus dynamic per-token
# int8 input activations (observer=None since scales are computed at runtime).
INT8_W8A8 = {
    "weights": QuantizationArgs(
        num_bits=8,
        type=QuantizationType.INT,
        strategy=QuantizationStrategy.CHANNEL,
        symmetric=True,
        dynamic=False,
    ),
    "input_activations": QuantizationArgs(
        num_bits=8,
        type=QuantizationType.INT,
        strategy=QuantizationStrategy.TOKEN,
        symmetric=True,
        dynamic=True,
        observer=None,
    ),
}
|
|
| |
# W8A16: static symmetric per-channel int8 weight-only scheme.
W8A16 = {
    "weights": QuantizationArgs(
        num_bits=8,
        type=QuantizationType.INT,
        strategy=QuantizationStrategy.CHANNEL,
        symmetric=True,
        dynamic=False,
    ),
}
|
|
| |
# W4A16: static symmetric group-wise (group size 128) int4 weight-only scheme.
W4A16 = {
    "weights": QuantizationArgs(
        num_bits=4,
        type=QuantizationType.INT,
        strategy=QuantizationStrategy.GROUP,
        group_size=128,
        symmetric=True,
        dynamic=False,
    ),
}
|
|
| |
# W4A16_ASYM: same as W4A16 but asymmetric (symmetric=False).
W4A16_ASYM = {
    "weights": QuantizationArgs(
        num_bits=4,
        type=QuantizationType.INT,
        strategy=QuantizationStrategy.GROUP,
        group_size=128,
        symmetric=False,
        dynamic=False,
    ),
}
|
|
| |
# INT8_W4A8: static group-wise (group size 128) int4 weights plus dynamic
# per-token int8 input activations (observer=None: runtime scales).
INT8_W4A8 = {
    "weights": QuantizationArgs(
        num_bits=4,
        type=QuantizationType.INT,
        strategy=QuantizationStrategy.GROUP,
        group_size=128,
        symmetric=True,
        dynamic=False,
    ),
    "input_activations": QuantizationArgs(
        num_bits=8,
        type=QuantizationType.INT,
        strategy=QuantizationStrategy.TOKEN,
        symmetric=True,
        dynamic=True,
        observer=None,
    ),
}
|
|
| |
# FP8: fully static per-tensor float8 quantization for both weights and
# input activations.
FP8 = {
    "weights": QuantizationArgs(
        num_bits=8,
        type=QuantizationType.FLOAT,
        strategy=QuantizationStrategy.TENSOR,
        symmetric=True,
        dynamic=False,
    ),
    "input_activations": QuantizationArgs(
        num_bits=8,
        type=QuantizationType.FLOAT,
        strategy=QuantizationStrategy.TENSOR,
        symmetric=True,
        dynamic=False,
    ),
}
|
|
| |
# FP8_DYNAMIC: static per-channel float8 weights plus dynamic per-token float8
# input activations (observer=None: runtime scales).
FP8_DYNAMIC = {
    "weights": QuantizationArgs(
        num_bits=8,
        type=QuantizationType.FLOAT,
        strategy=QuantizationStrategy.CHANNEL,
        symmetric=True,
        dynamic=False,
    ),
    "input_activations": QuantizationArgs(
        num_bits=8,
        type=QuantizationType.FLOAT,
        strategy=QuantizationStrategy.TOKEN,
        symmetric=True,
        dynamic=True,
        observer=None,
    ),
}
|
|
| |
| |
| |
# FP8_BLOCK: static float8 weights in 128x128 blocks plus dynamic group-wise
# (group size 128) float8 input activations.
FP8_BLOCK = {
    "weights": QuantizationArgs(
        num_bits=8,
        type=QuantizationType.FLOAT,
        strategy=QuantizationStrategy.BLOCK,
        block_structure=[128, 128],
        symmetric=True,
        dynamic=False,
    ),
    "input_activations": QuantizationArgs(
        num_bits=8,
        type=QuantizationType.FLOAT,
        strategy=QuantizationStrategy.GROUP,
        group_size=128,
        symmetric=True,
        dynamic=True,
        observer=None,
    ),
}
|
|
# Registry of preset names (upper case) consumed by preset_name_to_scheme and
# is_preset_scheme. Values are kwargs dicts for QuantizationScheme.
PRESET_SCHEMES = {
    # Unquantized baseline
    "UNQUANTIZED": UNQUANTIZED,
    # Integer weight-only schemes
    "W8A16": W8A16,
    "W4A16": W4A16,
    "W4A16_ASYM": W4A16_ASYM,
    # Integer weight-and-activation schemes ("INT8" aliases "W8A8")
    "W8A8": INT8_W8A8,
    "INT8": INT8_W8A8,
    "W4A8": INT8_W4A8,
    # Float weight-and-activation schemes
    "FP8": FP8,
    "FP8_DYNAMIC": FP8_DYNAMIC,
    "FP8_BLOCK": FP8_BLOCK,
    "NVFP4A16": NVFP4A16,
    "NVFP4": NVFP4,
    "MXFP4A16": MXFP4A16,
    "MXFP4": MXFP4,
}
|
|