import copy
import inspect
import itertools
import typing_extensions
import warnings

import torch
import torch.ao.nn.quantized as nnq
import torch.nn as nn
from torch.ao.nn.intrinsic import _FusedModule
from torch.ao.quantization.observer import _is_activation_post_process
from torch.ao.quantization.qconfig import (
    _activation_is_memoryless,
    _add_module_to_qconfig_obs_ctr,
    default_dynamic_qconfig,
    float16_dynamic_qconfig,
    float_qparams_weight_only_qconfig,
    float_qparams_weight_only_qconfig_4bit,
)
from torch.ao.quantization.quantization_mappings import (
    _get_special_act_post_process,
    _has_special_act_post_process,
    get_default_dynamic_quant_module_mappings,
    get_default_qat_module_mappings,
    get_default_qconfig_propagation_list,
    get_default_static_quant_module_mappings,
    get_default_static_quant_reference_module_mappings,
    no_observer_set,
)
from torch.ao.quantization.stubs import DeQuantStub, QuantWrapper
from torch.nn.utils.parametrize import type_before_parametrizations

from .utils import (
    DEPRECATION_WARNING,
    get_qparam_dict,
    has_no_children_ignoring_parametrizations,
)


__all__ = [
    "get_default_custom_config_dict",
    "propagate_qconfig_",
    "add_quant_dequant",
    "prepare",
    "quantize",
    "quantize_dynamic",
    "prepare_qat",
    "quantize_qat",
    "convert",
    "swap_module",
]


# Backward-compatibility alias for the private helper
is_activation_post_process = _is_activation_post_process


_DEFAULT_CUSTOM_CONFIG_DICT = {
    "float_to_observed_custom_module_class": {
        nn.LSTM: nn.quantizable.LSTM,
        nn.MultiheadAttention: nn.quantizable.MultiheadAttention,
    },
    "observed_to_quantized_custom_module_class": {
        nn.quantizable.LSTM: nn.quantized.LSTM,
        nn.quantizable.MultiheadAttention: nn.quantized.MultiheadAttention,
    },
}


def get_default_custom_config_dict():
    r"""Defines the default custom config dict."""
    return _DEFAULT_CUSTOM_CONFIG_DICT


def _propagate_qconfig_helper(
    module,
    qconfig_dict,
    qconfig_parent=None,
    prefix="",
    prepare_custom_config_dict=None,
):
    r"""This is a helper function for `propagate_qconfig_`

    Args:
        module: input module
        qconfig_dict: dictionary that maps from name of submodule to quantization
            configuration
        qconfig_parent: quantization config of the parent module; we fall back to
            this config when there is no config specified for the current module
        prefix: corresponding prefix of the current module, used as key in
            qconfig_dict
        prepare_custom_config_dict: dictionary for custom handling of modules;
            see docs for :func:`~torch.ao.quantization.prepare_fx`

    Return:
        None, module is modified inplace with qconfig attached
    """
    module_qconfig = qconfig_dict.get(
        type_before_parametrizations(module), qconfig_parent
    )
    module_qconfig = qconfig_dict.get(prefix, module_qconfig)
    module_qconfig = getattr(module, "qconfig", module_qconfig)

    torch.ao.quantization.qconfig._assert_valid_qconfig(module_qconfig, module)

    qconfig_with_device_check = _add_module_to_qconfig_obs_ctr(module_qconfig, module)
    module.qconfig = qconfig_with_device_check

    for name, child in module.named_children():
        module_prefix = prefix + "." + name if prefix else name
        # do not propagate qconfig to children that are marked non-traceable
        if prepare_custom_config_dict is None or not (
            name in prepare_custom_config_dict.get("non_traceable_module_name", [])
            or type(child)
            in prepare_custom_config_dict.get("non_traceable_module_class", [])
        ):
            _propagate_qconfig_helper(
                child, qconfig_dict, qconfig_with_device_check, module_prefix
            )


def propagate_qconfig_(module, qconfig_dict=None, prepare_custom_config_dict=None):
    r"""Propagate qconfig through the module hierarchy and assign the `qconfig`
    attribute on each leaf module

    Args:
        module: input module
        qconfig_dict: dictionary that maps from name or type of submodule to
            quantization configuration; a qconfig applies to all submodules of a
            given module unless a qconfig for a submodule is specified (i.e. the
            submodule already has a qconfig attribute)
        prepare_custom_config_dict: dictionary for custom handling of modules;
            see docs for :func:`~torch.ao.quantization.prepare_fx`

    Return:
        None, module is modified inplace with qconfig attached
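
    Example (illustrative sketch; assumes ``model`` is a float ``nn.Module``):

    .. code-block:: python

        import torch.ao.quantization as tq

        # assign a global qconfig at the top level and propagate it down;
        # every quantizable leaf module then carries a `.qconfig` attribute
        model.qconfig = tq.get_default_qconfig("fbgemm")
        tq.propagate_qconfig_(model)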
| | """ |
| | if qconfig_dict is None: |
| | qconfig_dict = {} |
| | if prepare_custom_config_dict is None: |
| | prepare_custom_config_dict = {} |
| | _propagate_qconfig_helper( |
| | module, qconfig_dict, prepare_custom_config_dict=prepare_custom_config_dict |
| | ) |
| |
|
| |
|
| | def _observer_forward_hook(self, input, output): |
| | r"""Forward hook that calls observer on the output""" |
| | return self.activation_post_process(output) |
| |
|
| |
|
| | def _observer_forward_pre_hook(self, input): |
| | r"""Forward pre hook that calls observer on the output""" |
| | return self.activation_post_process(input[0]) |
| |
|
| |
|
| | def _register_activation_post_process_hook(module, pre_hook=False): |
| | assert hasattr(module, "activation_post_process"), ( |
| | "Expect activation_post_process attribute already attached to the module" |
| | ) |
| | if pre_hook: |
| | module.register_forward_pre_hook(_observer_forward_pre_hook, prepend=True) |
| | else: |
| | module.register_forward_hook(_observer_forward_hook, prepend=True) |


def _add_observer_(
    module,
    qconfig_propagation_list=None,
    non_leaf_module_list=None,
    device=None,
    custom_module_class_mapping=None,
):
    r"""Add observer for the leaf child of the module.

    This function inserts an observer module for every leaf child module
    that has a valid qconfig attribute.

    Args:
        module: input module with qconfig attributes for all the leaf modules that we want to quantize
        qconfig_propagation_list: a list of quantizable modules that will have observers added to them
            if they are leaf nodes
        device: parent device, if any
        non_leaf_module_list: list of non-leaf modules we want to add observers to

    Return:
        None, module is modified inplace with added observer modules and forward_hooks
    """
    if qconfig_propagation_list is None:
        qconfig_propagation_list = get_default_qconfig_propagation_list()

    if custom_module_class_mapping is None:
        custom_module_class_mapping = {}

    # respect device affinity when adding observers
    if device is None:
        devices = _get_unique_devices_(module)
        assert len(devices) <= 1, (
            f"_add_observer_ only works with cpu or single-device CUDA modules, but got devices {devices}"
        )
        device = next(iter(devices)) if len(devices) > 0 else None

    def get_activation_post_process(qconfig, device, special_act_post_process=None):
        activation = (
            qconfig.activation()
            if special_act_post_process is None
            else special_act_post_process()
        )
        if device is not None:
            activation.to(device)
        return activation

    def needs_observation(m):
        return hasattr(m, "qconfig") and m.qconfig is not None

    def insert_activation_post_process(m, special_act_post_process=None):
        """Adds an activation post process module and registers
        a pre or post hook that calls the module
        """
        # we don't insert an observer for a DeQuantStub, its output is float
        if needs_observation(m) and not isinstance(m, DeQuantStub):
            # observer and hook will be gone after we swap the module
            m.add_module(
                "activation_post_process",
                get_activation_post_process(
                    m.qconfig, device, special_act_post_process
                ),
            )
            # register the observer as the first entry in the hook list;
            # all other forward hooks will run after the observer
            _register_activation_post_process_hook(
                m, pre_hook=_activation_is_memoryless(m.qconfig)
            )

    for name, child in module.named_children():
        # skip Dropout: it is an identity in eval mode and needs no observer
        if type_before_parametrizations(child) in [nn.Dropout]:
            continue
        elif issubclass(
            type_before_parametrizations(child), (nnq.FloatFunctional, nnq.QFunctional)
        ):
            if needs_observation(child):
                assert hasattr(child, "activation_post_process"), (
                    f"functional class {type_before_parametrizations(child)} has no pre-defined `activation_post_process`"
                )
                child.activation_post_process = get_activation_post_process(
                    child.qconfig, device
                )
        elif isinstance(child, _FusedModule):
            # activation_post_process is added directly to the fused module
            if needs_observation(child):
                insert_activation_post_process(child)
        elif (
            non_leaf_module_list is not None
            and type_before_parametrizations(child) in non_leaf_module_list
        ):
            if needs_observation(child):
                insert_activation_post_process(child)
        elif _has_special_act_post_process(child):
            special_act_post_process = _get_special_act_post_process(child)
            insert_activation_post_process(child, special_act_post_process)
        elif (
            needs_observation(child)
            and type_before_parametrizations(child) in custom_module_class_mapping
        ):
            observed_class = custom_module_class_mapping[
                type_before_parametrizations(child)
            ]
            observed_child = observed_class.from_float(child)
            setattr(module, name, observed_child)
            # some custom modules define their own observers and should not
            # get an extra activation_post_process attached
            if not issubclass(observed_class, tuple(no_observer_set())):
                insert_activation_post_process(observed_child)
        else:
            _add_observer_(
                child,
                qconfig_propagation_list,
                non_leaf_module_list,
                device,
                custom_module_class_mapping,
            )

    # insert observers only for leaf nodes; this observer observes the output
    # of the module, inputs are observed by the preceding QuantStub
    if (
        has_no_children_ignoring_parametrizations(module)
        and not isinstance(module, torch.nn.Sequential)
        and type_before_parametrizations(module) in qconfig_propagation_list
    ):
        insert_activation_post_process(module)

    # special case: modules carrying a `weight_fake_quant` attribute are not
    # leaves, but still need an observer for their output
    if (
        hasattr(module, "weight_fake_quant")
        and not isinstance(module, torch.nn.Sequential)
        and type_before_parametrizations(module) in qconfig_propagation_list
    ):
        insert_activation_post_process(module)


def _get_unique_devices_(module):
    return {p.device for p in module.parameters() if p.device.type != "meta"} | {
        p.device for p in module.buffers() if p.device.type != "meta"
    }


def add_quant_dequant(module):
    r"""Wrap the leaf child module in QuantWrapper if it has a valid qconfig.
    Note that this function will modify the children of the module inplace, and
    it can also return a new module which wraps the input module.

    Args:
        module: input module with qconfig attributes for all the leaf modules
            that we want to quantize

    Return:
        Either the inplace-modified module with submodules wrapped in
        `QuantWrapper` based on qconfig, or a new `QuantWrapper` module which
        wraps the input module; the latter case only happens when the input
        module is a leaf module and we want to quantize it.
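
    Example (illustrative sketch; assumes ``model`` already has `qconfig`
    attributes attached, e.g. via :func:`propagate_qconfig_`):

    .. code-block:: python

        import torch.ao.quantization as tq

        model.qconfig = tq.default_qconfig
        tq.propagate_qconfig_(model)
        # leaf submodules with a qconfig are now wrapped in QuantWrapper
        model = tq.add_quant_dequant(model)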
| | """ |
| | if ( |
| | has_no_children_ignoring_parametrizations(module) |
| | and hasattr(module, "qconfig") |
| | and module.qconfig |
| | ): |
| | return QuantWrapper(module) |
| |
|
| | for name, child in module.named_children(): |
| | module._modules[name] = add_quant_dequant(child) |
| | return module |
| |
|
| |
|
| | @typing_extensions.deprecated(DEPRECATION_WARNING) |
| | def prepare( |
| | model, |
| | inplace=False, |
| | allow_list=None, |
| | observer_non_leaf_module_list=None, |
| | prepare_custom_config_dict=None, |
| | ): |
| | r"""Prepares a copy of the model for quantization calibration or quantization-aware training. |
| | |
| | Quantization configuration should be assigned preemptively |
| | to individual submodules in `.qconfig` attribute. |
| | |
| | The model will be attached with observer or fake quant modules, and qconfig |
| | will be propagated. |
| | |
| | Args: |
| | `model`: input model to be modified in-place |
| | `inplace`: carry out model transformations in-place, the original module is mutated |
| | `allow_list`: list of quantizable modules |
| | `observer_non_leaf_module_list`: list of non-leaf modules we want to add observer |
| | `prepare_custom_config_dict`: customization configuration dictionary for prepare function |
| | |
| | .. code-block:: python |
| | |
| | # Example of prepare_custom_config_dict: |
| | prepare_custom_config_dict = { |
| | # user will manually define the corresponding observed |
| | # module class which has a from_float class method that converts |
| | # float custom module to observed custom module |
| | "float_to_observed_custom_module_class": {CustomModule: ObservedCustomModule} |
| | } |
| | |
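    Example usage (illustrative sketch; assumes ``model`` is a float model in
    eval mode with a `qconfig` assigned):

    .. code-block:: python

        model.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
        model = torch.ao.quantization.prepare(model)
        # run representative calibration data through `model`,
        # then call torch.ao.quantization.convert(model)
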
| | """ |
| | torch._C._log_api_usage_once("quantization_api.quantize.prepare") |
| | if prepare_custom_config_dict is None: |
| | prepare_custom_config_dict = get_default_custom_config_dict() |
| | custom_module_class_mapping = prepare_custom_config_dict.get( |
| | "float_to_observed_custom_module_class", {} |
| | ) |
| |
|
| | if not inplace: |
| | model = copy.deepcopy(model) |
| |
|
| | |
| | qconfig_propagation_list = allow_list |
| | if allow_list is None: |
| | qconfig_propagation_list = get_default_qconfig_propagation_list() |
| | propagate_qconfig_(model, qconfig_dict=None) |
| |
|
| | |
| | if not any(hasattr(m, "qconfig") and m.qconfig for m in model.modules()): |
| | warnings.warn( |
| | "None of the submodule got qconfig applied. Make sure you " |
| | "passed correct configuration through `qconfig_dict` or " |
| | "by assigning the `.qconfig` attribute directly on submodules" |
| | ) |
| |
|
| | _add_observer_( |
| | model, |
| | qconfig_propagation_list, |
| | observer_non_leaf_module_list, |
| | custom_module_class_mapping=custom_module_class_mapping, |
| | ) |
| | return model |
| |
|
| |
|
| | def _remove_activation_post_process(module): |
| | |
| | |
| | if hasattr(module, "activation_post_process") and _is_activation_post_process( |
| | module.activation_post_process |
| | ): |
| | delattr(module, "activation_post_process") |
| |
|
| | |
| | def remove_hooks(pre_hook=False): |
| | hook_map = module._forward_pre_hooks if pre_hook else module._forward_hooks |
| | observer_hook = ( |
| | _observer_forward_pre_hook if pre_hook else _observer_forward_hook |
| | ) |
| | handle_ids_to_remove = set() |
| | for handle_id, hook_fn in hook_map.items(): |
| | if hook_fn is observer_hook: |
| | handle_ids_to_remove.add(handle_id) |
| | for handle_id in handle_ids_to_remove: |
| | hook_map.pop(handle_id) |
| |
|
| | remove_hooks(pre_hook=True) |
| | remove_hooks(pre_hook=False) |
| |
|
| |
|
| | |
| | def _remove_qconfig(module): |
| | r"""Clean up the qconfig left in the module so that new qconfig can be |
| | propagated. |
| | |
| | Args: |
| | module: module to be cleaned up |
| | """ |
| | for child in module.children(): |
| | _remove_qconfig(child) |
| |
|
| | if hasattr(module, "qconfig"): |
| | del module.qconfig |
| |
|
| | _remove_activation_post_process(module) |
| |
|
| |
|
| | @typing_extensions.deprecated(DEPRECATION_WARNING) |
| | def quantize(model, run_fn, run_args, mapping=None, inplace=False): |
| | r"""Quantize the input float model with post training static quantization. |
| | |
| | First it will prepare the model for calibration, then it calls |
| | `run_fn` which will run the calibration step, after that we will |
| | convert the model to a quantized model. |
| | |
| | Args: |
| | model: input float model |
| | run_fn: a calibration function for calibrating the prepared model |
| | run_args: positional arguments for `run_fn` |
| | inplace: carry out model transformations in-place, the original module is mutated |
| | mapping: correspondence between original module types and quantized counterparts |
| | |
| | Return: |
| | Quantized model. |
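
    Example (illustrative sketch; ``calibrate`` and ``data_loader`` are
    hypothetical user-supplied placeholders):

    .. code-block:: python

        def calibrate(model, data_loader):
            with torch.no_grad():
                for batch in data_loader:
                    model(batch)

        model.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
        quantized_model = torch.ao.quantization.quantize(
            model, calibrate, [data_loader]
        )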
| | """ |
| | torch._C._log_api_usage_once("quantization_api.quantize.quantize") |
| | if mapping is None: |
| | mapping = get_default_static_quant_module_mappings() |
| | if not inplace: |
| | model = copy.deepcopy(model) |
| | model.eval() |
| | prepare(model, inplace=True) |
| | run_fn(model, *run_args) |
| | convert(model, mapping, inplace=True) |
| | return model |
| |
|
| |
|
| | @typing_extensions.deprecated(DEPRECATION_WARNING) |
| | def quantize_dynamic( |
| | model, qconfig_spec=None, dtype=torch.qint8, mapping=None, inplace=False |
| | ): |
| | r"""Converts a float model to dynamic (i.e. weights-only) quantized model. |
| | |
| | Replaces specified modules with dynamic weight-only quantized versions and output the quantized model. |
| | |
| | For simplest usage provide `dtype` argument that can be float16 or qint8. Weight-only quantization |
| | by default is performed for layers with large weights size - i.e. Linear and RNN variants. |
| | |
| | Fine grained control is possible with `qconfig` and `mapping` that act similarly to `quantize()`. |
| | If `qconfig` is provided, the `dtype` argument is ignored. |
| | |
| | Args: |
| | model: input model |
| | qconfig_spec: Either: |
| | |
| | - A dictionary that maps from name or type of submodule to quantization |
| | configuration, qconfig applies to all submodules of a given |
| | module unless qconfig for the submodules are specified (when the |
| | submodule already has qconfig attribute). Entries in the dictionary |
| | need to be QConfig instances. |
| | |
| | - A set of types and/or submodule names to apply dynamic quantization to, |
| | in which case the `dtype` argument is used to specify the bit-width |
| | |
| | inplace: carry out model transformations in-place, the original module is mutated |
| | mapping: maps type of a submodule to a type of corresponding dynamically quantized version |
| | with which the submodule needs to be replaced |
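
    Example (illustrative sketch; dynamically quantizes the ``nn.Linear``
    layers of a float model to ``torch.qint8``):

    .. code-block:: python

        import torch.nn as nn
        import torch.ao.quantization as tq

        quantized_model = tq.quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)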

    """
    torch._C._log_api_usage_once("quantization_api.quantize.quantize_dynamic")
    if qconfig_spec is None:
        if dtype == torch.qint8:
            qconfig_spec = {
                nn.Linear: default_dynamic_qconfig,
                nn.LSTM: default_dynamic_qconfig,
                nn.GRU: default_dynamic_qconfig,
                nn.LSTMCell: default_dynamic_qconfig,
                nn.RNNCell: default_dynamic_qconfig,
                nn.GRUCell: default_dynamic_qconfig,
            }
        elif dtype == torch.float16:
            qconfig_spec = {
                nn.Linear: float16_dynamic_qconfig,
                nn.LSTM: float16_dynamic_qconfig,
                nn.GRU: float16_dynamic_qconfig,
                nn.LSTMCell: float16_dynamic_qconfig,
                nn.RNNCell: float16_dynamic_qconfig,
                nn.GRUCell: float16_dynamic_qconfig,
            }
        elif dtype == torch.quint8:
            qconfig_spec = {
                nn.EmbeddingBag: float_qparams_weight_only_qconfig,
                nn.Embedding: float_qparams_weight_only_qconfig,
            }
        elif dtype == torch.quint4x2:
            qconfig_spec = {
                nn.EmbeddingBag: float_qparams_weight_only_qconfig_4bit,
            }
        else:
            raise ValueError(
                f"Don't know how to quantize with default settings for {dtype}. Provide full qconfig please"
            )
    elif isinstance(qconfig_spec, set):
        if dtype is torch.qint8:
            default_qconfig = default_dynamic_qconfig
        elif dtype is torch.float16:
            default_qconfig = float16_dynamic_qconfig
        elif dtype is torch.quint8:
            default_qconfig = float_qparams_weight_only_qconfig
        elif dtype is torch.quint4x2:
            default_qconfig = float_qparams_weight_only_qconfig_4bit
        else:
            raise RuntimeError(
                "Unknown dtype specified for quantize_dynamic: ", str(dtype)
            )
        qconfig_spec = dict(zip(qconfig_spec, itertools.repeat(default_qconfig)))

    if mapping is None:
        mapping = get_default_dynamic_quant_module_mappings()

    if not inplace:
        model = copy.deepcopy(model)
    model.eval()
    propagate_qconfig_(model, qconfig_spec)
    convert(model, mapping, inplace=True)
    return model


@typing_extensions.deprecated(DEPRECATION_WARNING)
def prepare_qat(model, mapping=None, inplace=False):
    r"""
    Prepares a copy of the model for quantization-aware training and converts
    its submodules to their quantization-aware counterparts according to
    `mapping`.

    Quantization configuration should be assigned preemptively
    to individual submodules in the `.qconfig` attribute.

    Args:
        model: input model to be modified in-place
        mapping: dictionary that maps float modules to the quantized modules
            they are to be replaced with
        inplace: carry out model transformations in-place, the original module
            is mutated
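
    Example (illustrative sketch; assumes ``model`` is a float model in
    training mode):

    .. code-block:: python

        model.train()
        model.qconfig = torch.ao.quantization.get_default_qat_qconfig("fbgemm")
        model = torch.ao.quantization.prepare_qat(model)
        # run the training loop, then:
        # quantized_model = torch.ao.quantization.convert(model.eval())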
| | """ |
| | torch._C._log_api_usage_once("quantization_api.quantize.prepare_qat") |
| | assert model.training, "prepare_qat only works on models in training mode" |
| | if mapping is None: |
| | mapping = get_default_qat_module_mappings() |
| |
|
| | if not inplace: |
| | model = copy.deepcopy(model) |
| |
|
| | propagate_qconfig_(model, qconfig_dict=None) |
| | convert(model, mapping=mapping, inplace=True, remove_qconfig=False) |
| | prepare(model, observer_non_leaf_module_list=set(mapping.values()), inplace=True) |
| | return model |
| |
|
| |
|
| | @typing_extensions.deprecated(DEPRECATION_WARNING) |
| | def quantize_qat(model, run_fn, run_args, inplace=False): |
| | r"""Do quantization aware training and output a quantized model |
| | |
| | Args: |
| | model: input model |
| | run_fn: a function for evaluating the prepared model, can be a |
| | function that simply runs the prepared model or a training |
| | loop |
| | run_args: positional arguments for `run_fn` |
| | |
| | Return: |
| | Quantized model. |
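
    Example (illustrative sketch; ``train_fn`` and ``data_loader`` are
    hypothetical user-supplied placeholders; a real loop would also step
    an optimizer):

    .. code-block:: python

        def train_fn(model, data_loader):
            for batch, target in data_loader:
                loss = torch.nn.functional.cross_entropy(model(batch), target)
                loss.backward()

        model.qconfig = torch.ao.quantization.get_default_qat_qconfig("fbgemm")
        quantized_model = torch.ao.quantization.quantize_qat(
            model, train_fn, [data_loader]
        )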
| | """ |
| | torch._C._log_api_usage_once("quantization_api.quantize.quantize_qat") |
| | if not inplace: |
| | model = copy.deepcopy(model) |
| | model.train() |
| | prepare_qat(model, inplace=True) |
| | run_fn(model, *run_args) |
| | convert(model, inplace=True) |
| | return model |
| |
|
| |
|
| | @typing_extensions.deprecated(DEPRECATION_WARNING) |
| | def convert( |
| | module, |
| | mapping=None, |
| | inplace=False, |
| | remove_qconfig=True, |
| | is_reference=False, |
| | convert_custom_config_dict=None, |
| | use_precomputed_fake_quant=False, |
| | ): |
| | r"""Converts submodules in input module to a different module according to `mapping` |
| | by calling `from_float` method on the target module class. And remove qconfig at the |
| | end if remove_qconfig is set to True. |
| | |
| | Args: |
| | `module`: prepared and calibrated module |
| | `mapping`: a dictionary that maps from source module type to target |
| | module type, can be overwritten to allow swapping user defined |
| | Modules |
| | `inplace`: carry out model transformations in-place, the original module |
| | is mutated |
| | `convert_custom_config_dict`: custom configuration dictionary for convert function |
| | `use_precomputed_fake_quant`: a flag to enable use of precomputed fake quant |
| | |
| | .. code-block:: python |
| | |
| | # Example of convert_custom_config_dict: |
| | convert_custom_config_dict = { |
| | # user will manually define the corresponding quantized |
| | # module class which has a from_observed class method that converts |
| | # observed custom module to quantized custom module |
| | "observed_to_quantized_custom_module_class": { |
| | ObservedCustomModule: QuantizedCustomModule |
| | } |
| | } |
| | |
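    Example usage (illustrative sketch; assumes ``prepared_model`` has been
    prepared and calibrated):

    .. code-block:: python

        # swap observed modules for their quantized counterparts
        quantized_model = torch.ao.quantization.convert(prepared_model)
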
| | """ |
| | torch._C._log_api_usage_once("quantization_api.quantize.convert") |
| | if not inplace: |
| | module = copy.deepcopy(module) |
| | _convert( |
| | module, |
| | mapping, |
| | inplace=True, |
| | is_reference=is_reference, |
| | convert_custom_config_dict=convert_custom_config_dict, |
| | use_precomputed_fake_quant=use_precomputed_fake_quant, |
| | ) |
| | if remove_qconfig: |
| | _remove_qconfig(module) |
| | return module |
| |
|
| |
|
| | def _convert( |
| | module, |
| | mapping=None, |
| | inplace=False, |
| | is_reference=False, |
| | convert_custom_config_dict=None, |
| | use_precomputed_fake_quant=False, |
| | ): |
| | r"""Converts submodules in input module to a different module according to `mapping` |
| | by calling `from_float` method on the target module class |
| | |
| | Args: |
| | module: input module |
| | mapping: a dictionary that maps from source module type to target |
| | module type, can be overwritten to allow swapping user defined |
| | Modules |
| | inplace: carry out model transformations in-place, the original module |
| | is mutated |
| | is_reference: a flag to enable quantized reference module |
| | use_precomputed_fake_quant: a flag to enable use of precomputed fake quant |
| | |
| | """ |
| | if mapping is None: |
| | mapping = ( |
| | get_default_static_quant_reference_module_mappings() |
| | if is_reference |
| | else get_default_static_quant_module_mappings() |
| | ) |
| | if convert_custom_config_dict is None: |
| | convert_custom_config_dict = get_default_custom_config_dict() |
| | custom_module_class_mapping = convert_custom_config_dict.get( |
| | "observed_to_quantized_custom_module_class", {} |
| | ) |
| |
|
| | if not inplace: |
| | module = copy.deepcopy(module) |
| | reassign = {} |
| | for name, mod in module.named_children(): |
| | |
| | |
| | if ( |
| | not isinstance(mod, _FusedModule) |
| | and type_before_parametrizations(mod) not in custom_module_class_mapping |
| | ): |
| | _convert( |
| | mod, |
| | mapping, |
| | True, |
| | is_reference, |
| | convert_custom_config_dict, |
| | use_precomputed_fake_quant=use_precomputed_fake_quant, |
| | ) |
| | reassign[name] = swap_module( |
| | mod, mapping, custom_module_class_mapping, use_precomputed_fake_quant |
| | ) |
| |
|
| | for key, value in reassign.items(): |
| | module._modules[key] = value |
| |
|
| | return module |
| |
|
| |
|
| | def swap_module( |
| | mod, mapping, custom_module_class_mapping, use_precomputed_fake_quant=False |
| | ): |
| | r"""Swaps the module if it has a quantized counterpart and it has an |
| | `observer` attached. |
| | |
| | Args: |
| | mod: input module |
| | mapping: a dictionary that maps from nn module to nnq module |
| | |
| | Return: |
| | The corresponding quantized module of `mod` |
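
    Example (illustrative sketch; assumes ``mod`` was prepared and calibrated,
    i.e. it has a `qconfig` and an attached observer):

    .. code-block:: python

        from torch.ao.quantization.quantization_mappings import (
            get_default_static_quant_module_mappings,
        )

        new_mod = swap_module(mod, get_default_static_quant_module_mappings(), {})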
| | """ |
| | new_mod = mod |
| | if hasattr(mod, "qconfig") and mod.qconfig is not None: |
| | swapped = False |
| | if type_before_parametrizations(mod) in custom_module_class_mapping: |
| | new_mod = custom_module_class_mapping[ |
| | type_before_parametrizations(mod) |
| | ].from_observed(mod) |
| | swapped = True |
| | elif type_before_parametrizations(mod) in mapping: |
| | qmod = mapping[type_before_parametrizations(mod)] |
| | if hasattr(qmod, "_IS_REFERENCE") and qmod._IS_REFERENCE: |
| | assert mod.qconfig is not None |
| | weight_post_process = mod.qconfig.weight() |
| | weight_post_process(mod.weight) |
| | weight_qparams = get_qparam_dict(weight_post_process) |
| | new_mod = qmod.from_float(mod, weight_qparams) |
| | else: |
| | sig = inspect.signature(qmod.from_float) |
| | if "use_precomputed_fake_quant" in sig.parameters: |
| | new_mod = qmod.from_float( |
| | mod, use_precomputed_fake_quant=use_precomputed_fake_quant |
| | ) |
| | else: |
| | new_mod = qmod.from_float(mod) |
| | swapped = True |
| |
|
| | if swapped: |
| | |
| | for pre_hook_fn in mod._forward_pre_hooks.values(): |
| | new_mod.register_forward_pre_hook(pre_hook_fn) |
| | |
| | |
| | for hook_fn in mod._forward_hooks.values(): |
| | if hook_fn is not _observer_forward_hook: |
| | new_mod.register_forward_hook(hook_fn) |
| |
|
| | |
| | devices = _get_unique_devices_(mod) |
| | assert len(devices) <= 1 or ( |
| | len(devices) == 2 and torch.device("meta") in devices |
| | ), ( |
| | f"swap_module only works with cpu or single-device CUDA modules, but got devices {devices}" |
| | ) |
| | device = next(iter(devices)) if len(devices) > 0 else None |
| | if device: |
| | new_mod.to(device) |
| | return new_mod |
| |
|
| |
|
| | def _get_observer_dict(mod, target_dict, prefix=""): |
| | r"""Traverse the modules and save all observers into dict. |
| | This is mainly used for quantization accuracy debug |
| | Args: |
| | mod: the top module we want to save all observers |
| | prefix: the prefix for the current module |
| | target_dict: the dictionary used to save all the observers |
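
    Example (illustrative sketch; collects the observers of a prepared model):

    .. code-block:: python

        observer_dict = {}
        _get_observer_dict(prepared_model, observer_dict)
        for name, observer in observer_dict.items():
            print(name, observer)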
| | """ |
| |
|
| | def get_prefix(prefix): |
| | return prefix if prefix == "" else prefix + "." |
| |
|
| | if hasattr(mod, "activation_post_process"): |
| | target_dict[get_prefix(prefix) + "activation_post_process"] = ( |
| | mod.activation_post_process |
| | ) |
| | for name, child in mod.named_children(): |
| | module_prefix = get_prefix(prefix) + name if prefix else name |
| | _get_observer_dict(child, target_dict, module_prefix) |
| |
|