| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from enum import Enum |
|
|
| from compressed_tensors.utils import delete_offload_parameter |
| from torch.nn import Module |
|
|
|
|
# Public API of this module.
__all__ = [
    "QuantizationMetadata",
    "KVCacheScaleType",
]
|
|
|
|
class KVCacheScaleType(Enum):
    """
    Names of the key and value scale parameters used for the KV cache
    """

    # scale associated with the key states
    KEY = "k_scale"
    # scale associated with the value states
    VALUE = "v_scale"
|
|
|
|
class QuantizationMetadata:
    """
    Namespace of helpers describing quantization-related module state
    """

    @staticmethod
    def all_qparam_names():
        """
        List every quantization parameter name that may be registered
        onto a module during lifecycle (excluding serialized parameters)

        :return: list of names, starting with the kv cache scale names,
            followed by every `{target}_{kind}` combination for the
            input, weight and output targets
        """
        targets = ("input", "weight", "output")
        kinds = ("global_scale", "scale", "zero_point", "g_idx")

        # kv cache scales come first, then target-major / kind-minor ordering
        names = [KVCacheScaleType.KEY.value, KVCacheScaleType.VALUE.value]
        for target in targets:
            names.extend(f"{target}_{kind}" for kind in kinds)

        return names

    @classmethod
    def clear_all_qparams(cls, module: Module):
        """
        Strip from a module every quantization parameter that an earlier
        lifecycle stage might have registered onto it (excluding
        serialized parameters)

        :param module: Module to clear
        """
        for qparam_name in cls.all_qparam_names():
            # only names actually present on the module are deleted
            if not hasattr(module, qparam_name):
                continue

            delete_offload_parameter(module, qparam_name)
|
|