from dataclasses import dataclass
from fnmatch import fnmatch
from typing import Any, Dict, List, Optional, Union

import torch
from optimum.quanto.quantize import _quantize_submodule
from optimum.quanto.tensor import Optimizer, qtype, qtypes
from torchao.quantization.quant_api import (
    quantize_ as torchao_quantize_,
    Float8WeightOnlyConfig,
    UIntXWeightOnlyConfig,
)
# Re-implementation of quanto's `quantize`: the upstream function had a bug where it
# applied the `exclude` patterns in place of the `include` patterns.

# Class names of quanto's quantized modules; submodules of these types are already
# quantized and are skipped.
Q_MODULES = ['QLinear', 'QConv2d', 'QEmbedding', 'QBatchNorm2d', 'QLayerNorm', 'QConvTranspose2d', 'QEmbeddingBag']
torchao_qtypes = {
    # "int4": Int4WeightOnlyConfig(),
    "uint2": UIntXWeightOnlyConfig(torch.uint2),
    "uint3": UIntXWeightOnlyConfig(torch.uint3),
    "uint4": UIntXWeightOnlyConfig(torch.uint4),
    "uint5": UIntXWeightOnlyConfig(torch.uint5),
    "uint6": UIntXWeightOnlyConfig(torch.uint6),
    "uint7": UIntXWeightOnlyConfig(torch.uint7),
    "uint8": UIntXWeightOnlyConfig(torch.uint8),
    "float8": Float8WeightOnlyConfig(),
}

class aotype:
    """Thin wrapper marking a weight qtype that is handled by torchao instead of quanto."""

    def __init__(self, name: str):
        self.name = name
        self.config = torchao_qtypes[name]


def get_qtype(qtype: Union[str, qtype, aotype]) -> Union[qtype, aotype]:
    """Resolve a qtype name to a torchao `aotype` if torchao handles it, otherwise to a quanto `qtype`."""
    if isinstance(qtype, str) and qtype in torchao_qtypes:
        return aotype(qtype)
    if isinstance(qtype, str):
        return qtypes[qtype]
    return qtype

def quantize(
    model: torch.nn.Module,
    weights: Optional[Union[str, qtype, aotype]] = None,
    activations: Optional[Union[str, qtype]] = None,
    optimizer: Optional[Optimizer] = None,
    include: Optional[Union[str, List[str]]] = None,
    exclude: Optional[Union[str, List[str]]] = None,
):
    """Quantize the specified model submodules.

    Recursively quantize the submodules of the specified parent model.
    Only modules that have quantized counterparts will be quantized.

    If include patterns are specified, the submodule name must match one of them.
    If exclude patterns are specified, the submodule name must not match any of them.

    Include and exclude patterns are Unix shell-style wildcards, NOT regular expressions.
    See https://docs.python.org/3/library/fnmatch.html for details.

    Note: quantization happens in-place and modifies the original model and its descendants.

    Args:
        model (`torch.nn.Module`): the model whose submodules will be quantized.
        weights (`Optional[Union[str, qtype, aotype]]`): the qtype for weights quantization.
            An `aotype` is quantized with torchao (weight-only); anything else is handled by quanto.
        activations (`Optional[Union[str, qtype]]`): the qtype for activations quantization
            (quanto path only).
        optimizer (`Optional[Optimizer]`): the optimizer used to compute the weights
            quantization parameters (quanto path only).
        include (`Optional[Union[str, List[str]]]`):
            Patterns constituting the allowlist. If provided, module names must match at
            least one pattern from the allowlist.
        exclude (`Optional[Union[str, List[str]]]`):
            Patterns constituting the denylist. If provided, module names must not match
            any pattern from the denylist.
    """
    if include is not None:
        include = [include] if isinstance(include, str) else include
    if exclude is not None:
        exclude = [exclude] if isinstance(exclude, str) else exclude
    for name, m in model.named_modules():
        if include is not None and not any(fnmatch(name, pattern) for pattern in include):
            continue
        if exclude is not None and any(fnmatch(name, pattern) for pattern in exclude):
            continue
        try:
            # Skip modules that quanto has already replaced with a quantized counterpart.
            if m.__class__.__name__ in Q_MODULES:
                continue
            if isinstance(weights, aotype):
                # torchao path: weight-only quantization with the wrapped config.
                torchao_quantize_(m, weights.config)
            else:
                # quanto path: swap the submodule for its quantized counterpart.
                _quantize_submodule(model, name, m, weights=weights,
                                    activations=activations, optimizer=optimizer)
        except Exception as e:
            print(f"Failed to quantize {name}: {e}")
            raise
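

# Minimal usage sketch (illustrative only, not part of the module's API): the toy
# `TinyNet` model and the include/exclude patterns below are assumptions made for
# demonstration. It exercises both paths: a quanto name ("qint8") resolved by
# get_qtype to a quanto qtype and handled by _quantize_submodule, and a torchao
# name ("uint4") resolved to an aotype and handled by torchao's weight-only quantize_.
if __name__ == "__main__":
    import torch.nn as nn

    class TinyNet(nn.Module):
        def __init__(self):
            super().__init__()
            self.encoder = nn.Linear(256, 512)
            self.head = nn.Linear(512, 4)

        def forward(self, x):
            return self.head(torch.relu(self.encoder(x)))

    # quanto path: int8 weights everywhere except the output head.
    net = TinyNet()
    quantize(net, weights=get_qtype("qint8"), exclude=["head"])
    print(type(net.encoder).__name__)  # nn.Linear replaced by quanto's QLinear

    # torchao path: 4-bit weight-only quantization, restricted to the encoder.
    net2 = TinyNet()
    quantize(net2, weights=get_qtype("uint4"), include=["encoder"])
    print(type(net2.encoder.weight))  # encoder weight swapped for a torchao quantized tensor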