import copy
import operator
import warnings
from typing import Any, Callable, Optional, Union

import torch
from torch.ao.quantization import CUSTOM_KEY, NUMERIC_DEBUG_HANDLE_KEY
from torch.ao.quantization.backend_config import (
    BackendConfig,
    get_native_backend_config,
)
from torch.ao.quantization.backend_config.utils import (
    get_fused_module_classes,
    get_pattern_to_dtype_configs,
    get_qat_module_classes,
    get_root_module_to_quantized_reference_module,
)
from torch.ao.quantization.observer import _is_activation_post_process
from torch.ao.quantization.qconfig import qconfig_equals, QConfigAny
from torch.ao.quantization.qconfig_mapping import QConfigMapping
from torch.ao.quantization.quant_type import QuantType
from torch.ao.quantization.quantize import _remove_qconfig
from torch.ao.quantization.stubs import DeQuantStub
from torch.ao.quantization.utils import (
    _parent_name,
    activation_is_statically_quantized,
    get_qparam_dict,
    get_swapped_custom_module_class,
    is_per_channel,
    to_underlying_dtype,
    weight_is_quantized,
)
from torch.fx import GraphModule
from torch.fx.graph import Argument, Graph, Node
from torch.nn.utils.parametrize import type_before_parametrizations

# imported for the side effect of registering the quantized_decomposed ops
from ._decomposed import quantized_decomposed_lib  # noqa: F401
from ._equalize import convert_eq_obs, update_obs_for_equalization
from .custom_config import ConvertCustomConfig, PrepareCustomConfig
from .graph_module import _is_observed_module, _is_observed_standalone_module
from .lower_to_fbgemm import lower_to_fbgemm
from .qconfig_mapping_utils import (
    _compare_prepare_convert_qconfig_mappings,
    _generate_node_name_to_qconfig,
    _is_qconfig_supported_by_dtype_configs,
    _update_qconfig_for_fusion,
    _update_qconfig_for_qat,
)
from .utils import (
    _get_module,
    _is_custom_module_lstm,
    _is_custom_module_mha,
    assert_and_get_unique_device,
    collect_producer_nodes,
    create_getattr_from_value,
    get_custom_module_class_keys,
    graph_module_from_producer_nodes,
    node_arg_is_weight,
)


__all__ = [
    "convert",
    "convert_custom_module",
    "convert_standalone_module",
    "convert_weighted_module",
]


# observer dtypes that can be converted to quantize/dequantize ops in the
# decomposed reference flow
SUPPORTED_QDTYPES = [
    torch.quint8,
    torch.qint8,
    torch.qint32,
    torch.uint8,
    torch.int8,
    torch.uint16,
    torch.int16,
    torch.int32,
    torch.float8_e5m2,
    torch.float8_e4m3fn,
]

# maps an observer qscheme to the decomposed op that computes scale/zero_point
# at runtime (used for dynamic quantization)
_QSCHEME_TO_CHOOSE_QPARAMS_OP = {
    torch.per_tensor_affine: torch.ops.quantized_decomposed.choose_qparams.tensor,
    torch.per_tensor_symmetric: torch.ops.quantized_decomposed.choose_qparams_symmetric.tensor,
}


def _replace_observer_with_quantize_dequantize_node_decomposed(
    model: torch.fx.GraphModule,
    node: Node,
    modules: dict[str, torch.nn.Module],
    node_name_to_scope: dict[str, tuple[str, type]],
    node_name_to_qconfig: dict[str, QConfigAny],
) -> None:
    """Replace an activation_post_process module call node with quantize and
    dequantize nodes working on a decomposed Tensor.

    Before:
    ... -> observer_0(x) -> ...
    After:
    ... -> torch.ops.quantized_decomposed.quantize_per_tensor(x, ...) ->
    torch.ops.quantized_decomposed.dequantize_per_tensor() -> ...

    or quantize_per_channel and dequantize_per_channel
    """
    graph = model.graph
    assert modules is not None
    assert isinstance(node.target, str)
    module_path, prefix = _get_module_path_and_prefix(
        node, node_name_to_scope, node_name_to_qconfig
    )
    activation_post_process = modules[node.target]
    if hasattr(activation_post_process, "convert"):
        activation_post_process.convert(model, node)
        return
    # skip replacing observers with quant/dequant nodes if the qconfigs of all
    # consumers and producers of this observer are None
    skip_replacement = all(
        _has_none_qconfig(n, node_name_to_qconfig)
        for n in list(node.args) + list(node.users.keys())
    )
    if skip_replacement or not _is_conversion_supported(activation_post_process):
        # didn't find a corresponding quantize op and info for this
        # activation_post_process, so just remove the observer
        with graph.inserting_before(node):
            node.replace_all_uses_with(node.args[0])
            graph.erase_node(node)
        return

    # otherwise, convert the activation_post_process module call to
    # quantize/dequantize nodes; first extract the quantization parameters
    # from the activation_post_process module
    dtype = activation_post_process.dtype

    is_dynamic = False
    if hasattr(activation_post_process, "is_dynamic"):
        is_dynamic = activation_post_process.is_dynamic

    def add_dequantize_op_kwargs(dequantize_op, input_node):
        dequantize_op_kwargs = {}
        if "val" in input_node.meta:
            dq_out_dtype = input_node.meta["val"].dtype
            if dq_out_dtype != torch.float32:
                dequantize_op_kwargs = {"out_dtype": dq_out_dtype}
        return dequantize_op_kwargs

    if dtype in SUPPORTED_QDTYPES and (not is_dynamic):
        # static quantization branch

        # 1. extract the information for inserting quantize/dequantize nodes
        # from the activation_post_process module
        node_type = "call_function"
        quantize_op: Optional[Callable] = None
        scale, zero_point = activation_post_process.calculate_qparams()
        if is_per_channel(activation_post_process.qscheme):
            ch_axis = int(activation_post_process.ch_axis)
            quantize_op = torch.ops.quantized_decomposed.quantize_per_channel.default
            dequantize_op = (
                torch.ops.quantized_decomposed.dequantize_per_channel.default
            )
            quant_min = activation_post_process.quant_min
            quant_max = activation_post_process.quant_max
            dtype_ = to_underlying_dtype(dtype)
            qparams = {
                "_scale_": scale,
                "_zero_point_": zero_point,
                "_axis_": ch_axis,
                "_quant_min_": quant_min,
                "_quant_max_": quant_max,
                "_dtype_": dtype_,
            }
        else:
            quantize_op = torch.ops.quantized_decomposed.quantize_per_tensor.default
            dequantize_op = torch.ops.quantized_decomposed.dequantize_per_tensor.default
            scale = float(scale)
            zero_point = int(zero_point)
            quant_min = activation_post_process.quant_min
            quant_max = activation_post_process.quant_max
            dtype_ = to_underlying_dtype(dtype)
            qparams = {
                "_scale_": scale,
                "_zero_point_": zero_point,
                "_quant_min_": quant_min,
                "_quant_max_": quant_max,
                "_dtype_": dtype_,
            }

        # 2. replace the activation_post_process node with quantize and dequantize nodes
        with graph.inserting_before(node):
            input_node = node.args[0]
            quantize_op_inputs = [input_node]
            for key, value_or_node in qparams.items():
                if key in ["_scale_", "_zero_point_"] and (
                    not isinstance(value_or_node, (float, int))
                ):
                    # tensor-valued scale and zero_point are registered as
                    # attributes on the root module and referenced via getattr
                    # nodes; scalar values are passed as literals instead so the
                    # default op overload is used
                    qparam_node = create_getattr_from_value(
                        model, graph, module_path + prefix + key, value_or_node
                    )
                    quantize_op_inputs.append(qparam_node)
                else:
                    # all other qparams (axis, quant_min/max, dtype) are stored
                    # as literal arguments in the graph
                    quantize_op_inputs.append(value_or_node)

            quantized_node = graph.create_node(
                node_type, quantize_op, tuple(quantize_op_inputs), {}
            )
            # the dequantize op reuses the same qparams as the quantize op
            dq_inputs = [quantized_node] + quantize_op_inputs[1:]
            dequantized_node = graph.call_function(
                dequantize_op,
                tuple(dq_inputs),
                add_dequantize_op_kwargs(dequantize_op, input_node),
            )

            node.replace_all_uses_with(dequantized_node)
            # propagate the numeric debug handle from the observer node to the
            # dequantize node
            if (
                CUSTOM_KEY in node.meta
                and NUMERIC_DEBUG_HANDLE_KEY in node.meta[CUSTOM_KEY]
            ):
                if CUSTOM_KEY not in dequantized_node.meta:
                    dequantized_node.meta[CUSTOM_KEY] = {}
                dequantized_node.meta[CUSTOM_KEY][NUMERIC_DEBUG_HANDLE_KEY] = node.meta[
                    CUSTOM_KEY
                ][NUMERIC_DEBUG_HANDLE_KEY]
            graph.erase_node(node)
    elif is_dynamic:
        # dynamic quantization branch: scale and zero_point are computed at
        # runtime from the input via a choose_qparams op

        node_type = "call_function"
        quantize_op = torch.ops.quantized_decomposed.quantize_per_tensor.tensor

        # 1. extract the information for inserting the choose_qparams and
        # quantize/dequantize nodes from the activation_post_process module
        dtype_ = to_underlying_dtype(dtype)
        assert dtype_ in [torch.uint8, torch.int8], (
            "only uint8 and int8 are supported in reference flow for "
            "dynamic quantization right now"
        )
        quant_min = activation_post_process.quant_min
        quant_max = activation_post_process.quant_max
        qscheme = getattr(activation_post_process, "qscheme", torch.per_tensor_affine)
        eps = getattr(activation_post_process, "eps", torch.finfo(torch.float32).eps)
        # scale and zero_point are not known statically here; they come from
        # the choose_qparams node inserted below
        qparams = {
            "_quant_min_": quant_min,
            "_quant_max_": quant_max,
            "_eps_": eps,
            "_dtype_": dtype_,
        }

        choose_qparams_op = _QSCHEME_TO_CHOOSE_QPARAMS_OP[qscheme]
        # 2. insert the choose_qparams op and rebuild the qparams dict with the
        # resulting scale/zero_point nodes
        with graph.inserting_before(node):
            input_node = node.args[0]
            choose_qparams_op_inputs = [node.args[0]]
            for key, value in qparams.items():
                # quant_min, quant_max, eps and dtype are all passed as literals
                choose_qparams_op_inputs.append(value)
            choose_qparams_node = graph.create_node(
                "call_function", choose_qparams_op, tuple(choose_qparams_op_inputs), {}
            )
            # choose_qparams returns (scale, zero_point)
            scale_node = graph.create_node(
                "call_function", operator.getitem, (choose_qparams_node, 0), {}
            )
            zero_point_node = graph.create_node(
                "call_function", operator.getitem, (choose_qparams_node, 1), {}
            )
            quant_min = qparams["_quant_min_"]
            quant_max = qparams["_quant_max_"]
            dtype = qparams["_dtype_"]
            qparams = {
                "_scale_": scale_node,
                "_zero_point_": zero_point_node,
                "_quant_min_": quant_min,
                "_quant_max_": quant_max,
                "_dtype_": dtype,
            }

        # 3. replace the activation_post_process node with quantize and dequantize nodes
        with graph.inserting_before(node):
            input_node = node.args[0]
            quantize_op_inputs = [input_node]
            for key, value_or_node in qparams.items():
                if key in ["_scale_", "_zero_point_"]:
                    # scale and zero_point are graph nodes computed by the
                    # choose_qparams op above
                    qparam_node = value_or_node
                    quantize_op_inputs.append(qparam_node)
                else:
                    # the remaining qparams are stored as literal arguments
                    quantize_op_inputs.append(value_or_node)

            quantized_node = graph.create_node(
                node_type, quantize_op, tuple(quantize_op_inputs), {}
            )
            # the dequantize op reuses the same qparams as the quantize op
            dq_inputs = [quantized_node] + quantize_op_inputs[1:]
            # use the .tensor variant since scale and zero_point are Tensors
            # produced by choose_qparams rather than float/int literals
            dequantize_op = torch.ops.quantized_decomposed.dequantize_per_tensor.tensor
            dequantized_node = graph.call_function(
                dequantize_op,
                tuple(dq_inputs),
                add_dequantize_op_kwargs(dequantize_op, input_node),
            )

            node.replace_all_uses_with(dequantized_node)
            # propagate the numeric debug handle from the observer node to the
            # dequantize node
            if NUMERIC_DEBUG_HANDLE_KEY in node.meta:
                dequantized_node.meta[NUMERIC_DEBUG_HANDLE_KEY] = node.meta[
                    NUMERIC_DEBUG_HANDLE_KEY
                ]
            graph.erase_node(node)
    elif dtype == torch.float16:
        # fp16 "quantization" is emulated by a cast to float16 followed by a
        # cast back to float32
        dtype_convert_op = torch.ops.quantized_decomposed.convert_element_type.no_fuse
        with graph.inserting_before(node):
            input_node = node.args[0]
            convert_fp16_node = graph.create_node(
                "call_function", dtype_convert_op, (input_node, torch.float16), {}
            )
            convert_fp32_node = graph.create_node(
                "call_function", dtype_convert_op, (convert_fp16_node, torch.float), {}
            )
            node.replace_all_uses_with(convert_fp32_node)
            graph.erase_node(node)


def _replace_observer_with_quantize_dequantize_node(
    model: torch.fx.GraphModule,
    node: Node,
    modules: dict[str, torch.nn.Module],
    node_name_to_scope: dict[str, tuple[str, type]],
    node_name_to_qconfig: dict[str, QConfigAny],
) -> None:
    """Replace an activation_post_process module call node with quantize and
    dequantize nodes.

    Before:
    ... -> observer_0(x) -> ...
    After:
    ... -> torch.quantize_per_tensor(x, ...) -> x.dequantize() -> ...
    """
    assert modules is not None
    assert isinstance(node.target, str)
    graph = model.graph
    module_path, prefix = _get_module_path_and_prefix(
        node, node_name_to_scope, node_name_to_qconfig
    )
    activation_post_process = modules[node.target]
    # skip replacing observers with quant/dequant nodes if the qconfigs of all
    # consumers and producers of this observer are None
    skip_replacement = all(
        _has_none_qconfig(n, node_name_to_qconfig)
        for n in list(node.args) + list(node.users.keys())
    )
    if skip_replacement or not _is_conversion_supported(activation_post_process):
        # didn't find a corresponding quantize op and info for this
        # activation_post_process, so just remove the observer
        with graph.inserting_before(node):
            node.replace_all_uses_with(node.args[0])
            graph.erase_node(node)
        return

    # otherwise, convert the activation_post_process module call to
    # quantize/dequantize nodes
    dtype = activation_post_process.dtype

    is_dynamic = False
    if hasattr(activation_post_process, "is_dynamic"):
        is_dynamic = activation_post_process.is_dynamic

    if dtype in [
        torch.quint8,
        torch.qint8,
        torch.qint32,
        torch.float8_e5m2,
        torch.float8_e4m3fn,
    ] and (not is_dynamic):
        # static quantization branch

        # 1. extract the information for inserting quantize/dequantize nodes
        # from the activation_post_process module
        node_type = "call_function"
        quantize_op: Optional[Callable] = None
        scale, zero_point = activation_post_process.calculate_qparams()
        if is_per_channel(activation_post_process.qscheme):
            ch_axis = int(activation_post_process.ch_axis)
            qparams = {
                "_scale_": scale,
                "_zero_point_": zero_point,
                "_axis_": ch_axis,
                "_dtype_": dtype,
            }
            quantize_op = torch.quantize_per_channel
        else:
            scale = float(scale)
            zero_point = int(zero_point)
            qparams = {"_scale_": scale, "_zero_point_": zero_point, "_dtype_": dtype}
            quantize_op = torch.quantize_per_tensor

        # 2. replace the activation_post_process node with quantize and dequantize nodes
        with graph.inserting_before(node):
            input_node = node.args[0]
            quantize_op_inputs = [input_node]
            for key, value_or_node in qparams.items():
                if key in ["_scale_", "_zero_point_"]:
                    # scale and zero_point are registered as attributes on the
                    # root module and referenced via getattr nodes
                    qparam_node = create_getattr_from_value(
                        model, graph, module_path + prefix + key, value_or_node
                    )
                    quantize_op_inputs.append(qparam_node)
                else:
                    # other qparams (axis, dtype) are stored as literals
                    quantize_op_inputs.append(value_or_node)

            quantized_node = graph.create_node(
                node_type, quantize_op, tuple(quantize_op_inputs), {}
            )
            dequantized_node = graph.call_method("dequantize", args=(quantized_node,))
            node.replace_all_uses_with(dequantized_node)
            graph.erase_node(node)
    elif is_dynamic:
        # dynamic quantization branch
        node_type = "call_function"
        quantize_op = torch.quantize_per_tensor_dynamic
        # fbgemm/x86 kernels use a reduced range for activations to avoid
        # potential overflow in the vectorized int8 implementation
        reduce_range = torch.backends.quantized.engine in ("fbgemm", "x86")
        qparams = {"_dtype_": dtype, "_reduce_range_": reduce_range}

        with graph.inserting_before(node):
            input_node = node.args[0]
            quantize_op_inputs = [input_node]
            for key, value in qparams.items():
                quantize_op_inputs.append(value)

            quantized_node = graph.create_node(
                node_type, quantize_op, tuple(quantize_op_inputs), {}
            )
            dequantized_node = graph.call_method("dequantize", args=(quantized_node,))
            node.replace_all_uses_with(dequantized_node)
            graph.erase_node(node)
    elif dtype == torch.float16:
        # fp16 "quantization": cast to float16 and dequantize back to float32
        node_type = "call_method"
        quantize_op = "to"
        qparams = {"_dtype_": dtype}
        with graph.inserting_before(node):
            input_node = node.args[0]
            quantize_op_inputs = [input_node]
            for key, value in qparams.items():
                # all qparams are stored as literal arguments
                quantize_op_inputs.append(value)

            quantized_node = graph.create_node(
                node_type, quantize_op, tuple(quantize_op_inputs), {}
            )
            dequantized_node = graph.call_method("dequantize", args=(quantized_node,))
            node.replace_all_uses_with(dequantized_node)
            graph.erase_node(node)


def _replace_observer_or_dequant_stub_with_dequantize_node(
    node: Node, graph: Graph
) -> None:
    call_custom_module_node = node.args[0]
    assert isinstance(call_custom_module_node, Node), (
        f"Expecting the custom module call node to be a Node, but got {call_custom_module_node}"
    )
    node.replace_all_uses_with(call_custom_module_node)
    graph.erase_node(node)
    _insert_dequantize_node(call_custom_module_node, graph)


def _is_conversion_supported(activation_post_process: torch.nn.Module) -> bool:
    dtype = activation_post_process.dtype

    is_dynamic = False
    if hasattr(activation_post_process, "is_dynamic"):
        is_dynamic = activation_post_process.is_dynamic

    return (
        (dtype in SUPPORTED_QDTYPES and (not is_dynamic))
        or is_dynamic
        or dtype == torch.float16
    )


def _has_none_qconfig(
    node: Argument, node_name_to_qconfig: dict[str, QConfigAny]
) -> bool:
    """Check if a node has a qconfig of None, i.e. the user requested not to
    quantize the node
    """
    return (
        isinstance(node, Node)
        and node.name in node_name_to_qconfig
        and node_name_to_qconfig[node.name] is None
    )


def _run_weight_observers(observed: GraphModule, backend_config: BackendConfig) -> None:
    """Extract the subgraph that produces the weight for a dynamic quant
    or weight only quant node and run the subgraph to observe the weight.
    Note that the observers of dynamic quant or weight only quant ops are
    run during the convert step.
    """
    for node in observed.graph.nodes:
        if node.op != "call_function":
            continue
        for node_arg in node.args:
            # find the weight argument of functional ops (e.g. F.linear)
            if node_arg and node_arg_is_weight(node, node_arg):
                weight_observer_nodes = collect_producer_nodes(node_arg)
                if weight_observer_nodes is None:
                    continue
                weight_observer_module = graph_module_from_producer_nodes(
                    observed, weight_observer_nodes
                )
                # run the weight observer
                weight_observer_module()


def _maybe_recursive_remove_dequantize(arg: Any, node: Node, graph: Graph) -> None:
    """If the arg is a dequantize Node, or a list/tuple/dict of dequantize Nodes,
    recursively remove the dequantize Node
    """
    if (
        isinstance(arg, Node)
        and arg.op == "call_method"
        and arg.target == "dequantize"
    ):
        quantize_node = arg.args[0]
        # connect the user directly to the quantized value
        node.replace_input_with(arg, quantize_node)
    elif isinstance(arg, (list, tuple)):
        for arg_element in arg:
            _maybe_recursive_remove_dequantize(arg_element, node, graph)
    elif isinstance(arg, dict):
        for arg_element in arg.values():
            _maybe_recursive_remove_dequantize(arg_element, node, graph)
    else:
        warnings.warn(
            f"Unsupported node type in recursive remove dequantize: {type(arg)}"
        )


def _get_module_path_and_prefix(
    obs_node: Node,
    node_name_to_scope: dict[str, tuple[str, type]],
    node_name_to_qconfig: dict[str, QConfigAny],
) -> tuple[str, str]:
    """Given an observer node, get the `Scope`, i.e. the fully qualified name of
    the submodule containing the observed node. Also return a prefix of "_input"
    when the observed node is an input of an F.linear op and not the output of
    another quantized op.
    TODO: this logic is hacky, we should think about how to remove it or make it more
    general
    """
    observed_node = obs_node.args[0]
    # an observer can be inserted both for the input of the next operator and
    # for the output of the previous operator (they refer to the same Tensor);
    # figure out which case applies here
    assert isinstance(observed_node, Node), (
        f"Expecting observed node to be a Node, but got {observed_node}"
    )
    is_input_observer_only = (
        node_name_to_qconfig[observed_node.name] is None
        if observed_node.name in node_name_to_qconfig
        else None
    )
    if is_input_observer_only:
        # the quantize op is for the input of some op; find a user of the
        # observer node to derive the path, preferring an F.linear user if
        # one exists
        users = list(obs_node.users)
        first_linear_use_or_first_use = users[0] if users else None
        linear_node = None
        for n in users:
            if n.op == "call_function" and n.target == torch.nn.functional.linear:
                linear_node = n
                break
        if linear_node:
            first_linear_use_or_first_use = linear_node
        prefix = "_input"
    else:
        # the quantize op is for the output of the op, so use the observer's
        # input node to derive the path
        first_linear_use_or_first_use = observed_node
        prefix = ""

    if (
        first_linear_use_or_first_use
        and first_linear_use_or_first_use.name in node_name_to_scope
    ):
        module_path, _ = node_name_to_scope[first_linear_use_or_first_use.name]
    else:
        # fall back to the root module path when the node has no recorded scope
        module_path = ""
    return module_path, prefix


def _insert_dequantize_node(node: Node, graph: Graph) -> None:
    """Inserts a dequantize node for `node` in `graph`"""
    with graph.inserting_after(node):
        dequantize_node = graph.call_method("dequantize", (node,))
        for user_node in dict(node.users):
            if user_node is not dequantize_node:
                user_node.replace_input_with(node, dequantize_node)


def _maybe_get_observer_for_node(
    node: Node, modules: dict[str, torch.nn.Module]
) -> Optional[torch.nn.Module]:
    """
    If the node is observed, return the observer
    instance. Otherwise, return None.
    """
    for maybe_obs_node in node.users.keys():
        if maybe_obs_node.op == "call_module":
            maybe_obs = modules[str(maybe_obs_node.target)]
            if _is_activation_post_process(maybe_obs):
                return maybe_obs
    return None


def convert_standalone_module(
    node: Node,
    modules: dict[str, torch.nn.Module],
    model: torch.fx.GraphModule,
    is_reference: bool,
    backend_config: Optional[BackendConfig],
) -> None:
    """Converts an observed standalone module to a quantized standalone module by calling
    the fx convert api. Currently this uses the same `is_reference` flag as the parent,
    but we may change this behavior in the future (e.g. separating quantization and
    lowering for standalone modules as well).

    Args:
        - node: The call_module node of the observed standalone module
        - modules: named_module of original model
        - model: original model
        - is_reference: a flag from the parent, provided by the user, to decide whether
          we want to produce a reference model or a fbgemm/qnnpack model
        - backend_config: backend configuration of the target backend of quantization
    """
    if is_reference:
        convert_fn = torch.ao.quantization.quantize_fx.convert_to_reference_fx
    else:
        convert_fn = torch.ao.quantization.quantize_fx.convert_fx

    observed_standalone_module: GraphModule = modules[str(node.target)]
    sm_input_quantized_idxs = observed_standalone_module.meta[
        "_observed_graph_module_attrs"
    ].standalone_module_input_quantized_idxs
    # remove the dequantize nodes for inputs that are configured to stay quantized,
    # since the standalone module expects quantized values at those positions
    args = list(node.args)
    for idx in range(len(args)):
        if idx in sm_input_quantized_idxs:
            arg = args[idx]
            if arg.op == "call_method" and arg.target == "dequantize":
                quantize_node = arg.args[0]
                node.replace_input_with(arg, quantize_node)
                if len(arg.users) == 0:
                    model.graph.erase_node(arg)

    sm_output_quantized_idxs = observed_standalone_module.meta[
        "_observed_graph_module_attrs"
    ].standalone_module_output_quantized_idxs
    if len(sm_output_quantized_idxs) > 0:
        assert sm_output_quantized_idxs[0] == 0, (
            "Currently only quantized output idxs = [0] is supported"
        )

        # the output of the standalone module is quantized, so insert a
        # dequantize node right after it for the parent graph
        _insert_dequantize_node(node, model.graph)

    # swap the observed standalone module with the converted one
    quantized_standalone_module = convert_fn(
        observed_standalone_module, backend_config=backend_config
    )
    parent_name, name = _parent_name(node.target)
    setattr(modules[parent_name], name, quantized_standalone_module)
    modules[str(node.target)] = quantized_standalone_module


def convert_weighted_module(
    node: Node,
    modules: dict[str, torch.nn.Module],
    observed_node_names: set[str],
    node_name_to_qconfig: dict[str, QConfigAny],
    backend_config: BackendConfig,
    is_decomposed: bool = False,
    is_reference: bool = False,
) -> None:
    """Convert a weighted module to a reference quantized module in the model.
    If the QConfig of a QAT module is not set, the module will still be converted to
    a float module.

    Args:
        - node: The call_module node of the observed standalone module
        - modules: named_module of original model
        - observed_node_names: names of the observed fx nodes; we can skip
          this conversion if the node is not observed
    """
    original_module = modules[str(node.target)]
    qconfig: QConfigAny = original_module.qconfig
    weight_post_process = None
    qat_module_classes = get_qat_module_classes(backend_config)

    if isinstance(original_module, qat_module_classes):
        # when converting a QAT module to a float module, keep a reference to the
        # weight fake_quant; it is assumed to have already run during training
        weight_post_process = original_module.weight_fake_quant
        original_module = original_module.to_float()
        # swap the QAT module with the float module
        parent_name, name = _parent_name(node.target)
        setattr(modules[parent_name], name, original_module)

    is_observed = node.name in observed_node_names
    # if no qconfig is defined for this node, or the node was not observed,
    # skip converting to a reference module
    if (
        qconfig is None
        or _has_none_qconfig(node, node_name_to_qconfig)
        or not is_observed
    ):
        return

    # skip converting to a reference quantized module if the qconfig is not
    # supported by the backend_config for this module type
    pattern_to_dtype_configs = get_pattern_to_dtype_configs(backend_config)
    dtype_configs = pattern_to_dtype_configs.get(type(original_module), [])
    if not _is_qconfig_supported_by_dtype_configs(qconfig, dtype_configs):
        return

    is_weight_quantized = weight_is_quantized(qconfig)

    # the module swap only applies when the weight is quantized
    if not is_weight_quantized:
        return

    fused_module = None
    float_module = original_module
    # for fused modules the root module (e.g. the conv/linear) is the first submodule
    if isinstance(original_module, torch.ao.nn.intrinsic._FusedModule):
        fused_module = float_module
        float_module = fused_module[0]

    # collect the weight quantization parameters; for multi-weight modules
    # (RNN cells, LSTM/GRU) this is a dict per weight name
    wq_or_wq_dict = {"is_decomposed": is_decomposed}
    if isinstance(float_module, torch.nn.RNNCellBase):
        weight_post_process_ih = qconfig.weight()
        weight_post_process_hh = qconfig.weight()
        weight_post_process_ih(float_module.weight_ih)
        weight_post_process_hh(float_module.weight_hh)
        weight_qparams_ih = get_qparam_dict(weight_post_process_ih)
        weight_qparams_hh = get_qparam_dict(weight_post_process_hh)
        wq_or_wq_dict.update(
            {
                "weight_ih": weight_qparams_ih,
                "weight_hh": weight_qparams_hh,
            }
        )
    elif isinstance(float_module, (torch.nn.LSTM, torch.nn.GRU)):
        # observe each flattened weight (weight_ih_l0, weight_hh_l0, ...) and
        # record its qparams keyed by the weight name
        for wn in float_module._flat_weights_names:
            if hasattr(float_module, wn) and wn.startswith("weight"):
                weight = getattr(float_module, wn)
                weight_post_process = qconfig.weight()
                if weight_post_process.dtype == torch.qint8:
                    weight_post_process(weight)
                wq_or_wq_dict[wn] = get_qparam_dict(weight_post_process)
    else:
        # weight_post_process is None means the original module is not a QAT
        # module, so get the weight observer from the qconfig instead
        is_ptq = weight_post_process is None
        if is_ptq:
            weight_post_process = qconfig.weight()
            device = assert_and_get_unique_device(float_module)
            if device:
                weight_post_process.to(device)

        # run the weight observer/fake_quant at least once so that the scale
        # and zero_point have the right shapes; in the decomposed reference QAT
        # flow this has already happened during training, so it is skipped there
        is_qat = not is_ptq
        if not (is_decomposed and is_reference and is_qat):
            weight_post_process(float_module.weight)

        wq_or_wq_dict.update(get_qparam_dict(weight_post_process))

    # swap the float module (or the root of the fused module) with the
    # corresponding reference quantized module from the backend_config
    root_module_to_quantized_reference_module = (
        get_root_module_to_quantized_reference_module(backend_config)
    )
    ref_qmodule_cls = root_module_to_quantized_reference_module.get(
        type_before_parametrizations(float_module), None
    )
    assert ref_qmodule_cls is not None, (
        f"No reference quantized module class configured for {type_before_parametrizations(float_module)}"
    )
    ref_qmodule = ref_qmodule_cls.from_float(float_module, wq_or_wq_dict)
    if fused_module is not None:
        fused_module[0] = ref_qmodule
    else:
        parent_name, name = _parent_name(node.target)
        setattr(modules[parent_name], name, ref_qmodule)


def _remove_previous_dequantize_in_custom_module(
    node: Node, prev_node: Node, graph: Graph
) -> None:
    """
    Given a custom module `node`, if the previous node is a dequantize, reroute
    the graph as follows:

    Before: quantize - dequantize - custom_module
    After: quantize - custom_module
                    \\ - dequantize
    """
    assert isinstance(prev_node, Node), (
        f"Expecting the argument for custom module node to be a Node, but got {prev_node}"
    )
    if prev_node.op == "call_method" and prev_node.target == "dequantize":
        node.replace_input_with(prev_node, prev_node.args[0])

        # remove the dequantize node if it no longer has any users
        if len(prev_node.users) == 0:
            graph.erase_node(prev_node)


def convert_custom_module(
    node: Node,
    graph: Graph,
    modules: dict[str, torch.nn.Module],
    custom_module_class_mapping: dict[QuantType, dict[type, type]],
    statically_quantized_custom_module_nodes: set[Node],
) -> None:
    """Converts an observed custom module to a quantized custom module based on
    `custom_module_class_mapping`.

    For static quantization, we'll also remove the previous `dequantize` node and
    attach the output observer node to the module; that observer node will later be
    converted to a dequantize node (instead of a quantize-dequantize pair) in the graph.
    In the end we have a quantized custom module with the same interface as a default
    quantized module in the nn.quantized namespace, i.e. quantized input and quantized
    output.

    Args:
        - node: The call_module node of the observed standalone module
        - graph: The graph containing the node
        - modules: named_module of original model
        - custom_module_class_mapping: mapping from observed custom module class to
          quantized custom module class, used to swap custom modules
        - statically_quantized_custom_module_nodes: we'll add the custom module node
          to this set if we find it is statically quantized. This is used later when
          converting observers to quant/dequant node pairs: if the observed node is a
          statically quantized custom module node, we convert the observer to a
          dequantize node, which keeps the interface the same as a default quantized
          module.
    TODO: maybe we want to redesign this part to align with reference model design
    as well, but there has been some discussions around the interface, so we can do
    it later.
    """
    observed_custom_module = modules[str(node.target)]
    qconfig = observed_custom_module.qconfig
    if activation_is_statically_quantized(qconfig):
        statically_quantized_custom_module_nodes.add(node)
        if _is_custom_module_lstm(node, modules):
            # the inputs are in the form (input, (hidden0, hidden1)), so we
            # expect two positional args, the second being a 2-tuple of Nodes
            assert (
                len(node.args) == 2
                and isinstance(node.args[1], tuple)
                and len(node.args[1]) == 2
            )
            (inputs, (hidden0, hidden1)) = node.args
            assert isinstance(inputs, Node)
            assert isinstance(hidden0, Node)
            assert isinstance(hidden1, Node)
            _remove_previous_dequantize_in_custom_module(node, inputs, graph)
            _remove_previous_dequantize_in_custom_module(node, hidden0, graph)
            _remove_previous_dequantize_in_custom_module(node, hidden1, graph)
        elif _is_custom_module_mha(node, modules):
            # the inputs are (query, key, value), so we expect three Node args
            assert len(node.args) == 3
            query, key, value = node.args
            assert isinstance(query, Node)
            assert isinstance(key, Node)
            assert isinstance(value, Node)
            _remove_previous_dequantize_in_custom_module(node, query, graph)
            _remove_previous_dequantize_in_custom_module(node, key, graph)
            _remove_previous_dequantize_in_custom_module(node, value, graph)
        else:
            # remove the previous dequantize node so the input stays quantized
            arg = node.args[0]
            assert isinstance(arg, Node)
            _remove_previous_dequantize_in_custom_module(node, arg, graph)

        # attach the output observer to the module; the observer node itself
        # will later be replaced with a dequantize node
        activation_post_process = _maybe_get_observer_for_node(node, modules)
        assert activation_post_process is not None
        observed_custom_module.activation_post_process = activation_post_process

    # swap the observed custom module with the quantized custom module
    quantized_custom_module_class = get_swapped_custom_module_class(
        observed_custom_module, custom_module_class_mapping, qconfig
    )
    quantized_custom_module = quantized_custom_module_class.from_observed(
        observed_custom_module
    )
    parent_name, name = _parent_name(node.target)
    setattr(modules[parent_name], name, quantized_custom_module)


def convert(
    model: GraphModule,
    is_reference: bool = False,
    convert_custom_config: Union[ConvertCustomConfig, dict[str, Any], None] = None,
    is_standalone_module: bool = False,
    _remove_qconfig_flag: bool = True,
    qconfig_mapping: Union[QConfigMapping, dict[str, Any], None] = None,
    backend_config: Union[BackendConfig, dict[str, Any], None] = None,
    is_decomposed: bool = False,
    keep_original_weights: bool = False,
) -> GraphModule:
    """
    Convert an observed model (a module with observer calls) to a reference
    quantized model. The rule is simple:
    1. for each observer module call in the graph, we convert it to calls to
       quantize and dequantize functions based on the observer instance
    2. for weighted operations like linear/conv, we convert them to reference
       quantized modules; this requires knowing whether the dtype configured for the
       weight is supported by the backend, which is determined in the prepare step and
       stored in observed_node_names, so we decide whether to swap the module based on
       this set

    Args:
       * `is_standalone_module`: when this flag is True, it means we are quantizing
         a submodule that is not inlined in the parent module, and it will be quantized
         separately as one unit.

       * `is_decomposed`: a boolean flag to indicate whether we want to use the
         quantize operator for decomposed quantized tensor
         (torch.ops.quantized_decomposed.quantize_per_tensor) or default/standalone
         quantized tensor (torch.quantize_per_tensor)

    Returns:
         a quantized standalone module; whether input/output is quantized is
         specified by prepare_custom_config, with
         input_quantized_idxs, output_quantized_idxs, please
         see docs for :func:`~torch.ao.quantization.prepare_fx` for details
    """
    if convert_custom_config is None:
        convert_custom_config = ConvertCustomConfig()

    if isinstance(convert_custom_config, dict):
        warnings.warn(
            "Passing a convert_custom_config_dict to convert is deprecated and will not be supported "
            "in a future version. Please pass in a ConvertCustomConfig instead.",
            FutureWarning,
            stacklevel=2,
        )
        convert_custom_config = ConvertCustomConfig.from_dict(convert_custom_config)

    if isinstance(qconfig_mapping, dict):
        warnings.warn(
            "Passing a QConfig dictionary to convert is deprecated and will not be supported "
            "in a future version. Please pass in a QConfigMapping instead.",
            FutureWarning,
            stacklevel=2,
        )
        qconfig_mapping = (
            QConfigMapping.from_dict(qconfig_mapping) if qconfig_mapping else None
        )
    qconfig_mapping = copy.deepcopy(qconfig_mapping)
    assert qconfig_mapping is None or isinstance(qconfig_mapping, QConfigMapping)

    if isinstance(backend_config, dict):
        warnings.warn(
            "Passing a backend_config_dict to convert is deprecated and will not be supported "
            "in a future version. Please pass in a BackendConfig instead.",
            FutureWarning,
            stacklevel=2,
        )
        backend_config = BackendConfig.from_dict(backend_config)

    if backend_config is None:
        backend_config = get_native_backend_config()

    assert _is_observed_module(model), "incoming model must be produced by prepare_fx"
    observed_graph_module_attrs = model.meta["_observed_graph_module_attrs"]
    node_name_to_scope: dict[str, tuple[str, type]] = (
        observed_graph_module_attrs.node_name_to_scope
    )
    prepare_custom_config: PrepareCustomConfig = (
        observed_graph_module_attrs.prepare_custom_config
    )
    observed_node_names: set[str] = observed_graph_module_attrs.observed_node_names
    node_name_to_qconfig: dict[str, QConfigAny] = (
        observed_graph_module_attrs.node_name_to_qconfig
    )

    # mapping from fully qualified module name to module instance
    modules = dict(model.named_modules(remove_duplicate=False))

    # if a qconfig_mapping is passed to convert, check that it is compatible
    # with the qconfig_mapping used during prepare
    if qconfig_mapping:
        prepare_qconfig_mapping: QConfigMapping = (
            observed_graph_module_attrs.qconfig_mapping
        )
        modules_copy = copy.deepcopy(modules)

        if observed_graph_module_attrs.is_qat:
            _update_qconfig_for_qat(qconfig_mapping, backend_config)
        _update_qconfig_for_fusion(model, qconfig_mapping)

        _compare_prepare_convert_qconfig_mappings(
            prepare_qconfig_mapping, qconfig_mapping
        )
        convert_node_name_to_qconfig = _generate_node_name_to_qconfig(
            model, modules_copy, model.graph, qconfig_mapping, node_name_to_scope
        )
        # every key in the prepare-time node_name_to_qconfig must exist in the
        # convert-time mapping, and each value must either be None or equal to
        # the prepare-time qconfig
        for k, v in node_name_to_qconfig.items():
            assert k in convert_node_name_to_qconfig, (
                f"Expected key {k} in convert node_name_to_qconfig"
            )
            if convert_node_name_to_qconfig[k] is not None:
                assert qconfig_equals(v, convert_node_name_to_qconfig[k]), (
                    f"Expected k {k} to have the same value in prepare and convert QConfigMappings, "
                    f"but {v} was updated to {convert_node_name_to_qconfig[k]}"
                )
        node_name_to_qconfig = convert_node_name_to_qconfig

    custom_module_classes = get_custom_module_class_keys(
        convert_custom_config.observed_to_quantized_mapping
    )
    custom_module_class_mapping = convert_custom_config.observed_to_quantized_mapping

    if observed_graph_module_attrs.equalization_node_name_to_qconfig is not None:
        # if equalization was requested: calculate the equalization scale,
        # update the observers with the scaled dequantized weights, and scale
        # the bias
        weight_eq_obs_dict = update_obs_for_equalization(model, modules)
        convert_eq_obs(model, modules, weight_eq_obs_dict)

    # always run weight observers in the top level forward method
    # for dynamic quant ops or weight only quant ops
    _run_weight_observers(model, backend_config)

    # additional state to override inputs/outputs to be quantized, if specified
    # by prepare_custom_config
    placeholder_node_seen_cnt = 0
    input_quantized_idxs: list[int] = prepare_custom_config.input_quantized_indexes
    output_quantized_idxs: list[int] = prepare_custom_config.output_quantized_indexes

    root_module_to_quantized_reference_module = (
        get_root_module_to_quantized_reference_module(backend_config)
    )
    # module classes that decide which call_module nodes need conversion below
    root_module_classes = tuple(root_module_to_quantized_reference_module.keys())
    qat_module_classes = get_qat_module_classes(backend_config)
    fused_module_classes = get_fused_module_classes(backend_config)
    statically_quantized_custom_module_nodes: set[Node] = set()

    for node in list(model.graph.nodes):
        if node.op == "placeholder":
            cur_placeholder_node_idx = placeholder_node_seen_cnt
            placeholder_node_seen_cnt += 1
            if cur_placeholder_node_idx in input_quantized_idxs:
                # this input is assumed to be quantized; since all operators in
                # a reference quantized model take fp32 inputs, insert a
                # dequantize node for it
                _insert_dequantize_node(node, model.graph)
        elif node.op == "output":
            # if no output is configured to stay quantized there is nothing to do
            if len(output_quantized_idxs) == 0:
                continue
            # outputs listed in output_quantized_idxs are kept quantized, so the
            # dequantize node feeding them is removed
            return_node = node
            output = node.args[0]
            # outputs can be a Node, list, tuple or dict; other cases are not
            # supported yet
            if isinstance(output, (list, tuple)):
                for idx in output_quantized_idxs:
                    _maybe_recursive_remove_dequantize(
                        output[idx], return_node, model.graph
                    )
            elif isinstance(output, (Node, dict)):
                # a single Node or dict output is treated as one output, kept
                # quantized only when index 0 is listed
                if 0 in output_quantized_idxs:
                    _maybe_recursive_remove_dequantize(output, return_node, model.graph)
            else:
                warnings.warn(
                    f"Unsupported node type for output_quantized_idxs: {type(output)}"
                )
        elif node.op == "call_module":
            mod = _get_module(node, modules)
            assert mod is not None
            if _is_activation_post_process(mod):
                observed_node = node.args[0]
                if observed_node in statically_quantized_custom_module_nodes:
                    _replace_observer_or_dequant_stub_with_dequantize_node(
                        node, model.graph
                    )
                else:
                    if is_decomposed:
                        _replace_observer_with_quantize_dequantize_node_decomposed(
                            model,
                            node,
                            modules,
                            node_name_to_scope,
                            node_name_to_qconfig,
                        )
                    else:
                        _replace_observer_with_quantize_dequantize_node(
                            model,
                            node,
                            modules,
                            node_name_to_scope,
                            node_name_to_qconfig,
                        )
            elif isinstance(mod, DeQuantStub):
                _replace_observer_or_dequant_stub_with_dequantize_node(
                    node, model.graph
                )
            elif _is_observed_standalone_module(mod):
                convert_standalone_module(
                    node, modules, model, is_reference, backend_config
                )
            # below this point `type_before_parametrizations` is used instead of
            # `type` to handle modules with parametrizations (e.g. fx quant + sparsity)
            elif type_before_parametrizations(mod) in set(root_module_classes).union(
                qat_module_classes
            ).union(fused_module_classes):
                # extra check for fused modules: only convert when the root of
                # the fused module is one of the configured root module classes
                if (
                    type_before_parametrizations(mod) in fused_module_classes
                    and type_before_parametrizations(mod[0]) not in root_module_classes
                ):
                    continue
                convert_weighted_module(
                    node,
                    modules,
                    observed_node_names,
                    node_name_to_qconfig,
                    backend_config,
                    is_decomposed,
                    is_reference,
                )
            elif type_before_parametrizations(mod) in custom_module_classes:
                convert_custom_module(
                    node,
                    model.graph,
                    modules,
                    custom_module_class_mapping,
                    statically_quantized_custom_module_nodes,
                )

    # remove nodes that became dead after the rewrites above and rebuild the module
    model.graph.eliminate_dead_code()
    model = GraphModule(model, model.graph)

    # lower the reference model to a fbgemm quantized model unless a reference
    # model was requested
    if not is_reference:
        model = lower_to_fbgemm(
            model, node_name_to_qconfig, node_name_to_scope, keep_original_weights
        )

    # remove qconfig attributes and activation_post_process modules
    if _remove_qconfig_flag:
        _remove_qconfig(model)
    model.delete_all_unused_submodules()
    model.meta.pop("_observed_graph_module_attrs", None)
    return model