Add files using upload-large-folder tool
This view is limited to 50 files because the commit contains too many changes.
- .venv/Lib/site-packages/torch/ao/nn/__pycache__/__init__.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/ao/nn/intrinsic/__init__.py +40 -0
- .venv/Lib/site-packages/torch/ao/nn/intrinsic/modules/__init__.py +41 -0
- .venv/Lib/site-packages/torch/ao/nn/intrinsic/modules/__pycache__/__init__.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/ao/nn/intrinsic/modules/__pycache__/fused.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/ao/nn/intrinsic/modules/fused.py +245 -0
- .venv/Lib/site-packages/torch/ao/nn/intrinsic/qat/__init__.py +1 -0
- .venv/Lib/site-packages/torch/ao/nn/intrinsic/qat/modules/__init__.py +32 -0
- .venv/Lib/site-packages/torch/ao/nn/intrinsic/qat/modules/conv_fused.py +1050 -0
- .venv/Lib/site-packages/torch/ao/nn/intrinsic/qat/modules/linear_fused.py +193 -0
- .venv/Lib/site-packages/torch/ao/nn/intrinsic/qat/modules/linear_relu.py +51 -0
- .venv/Lib/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/__pycache__/__init__.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/modules/__pycache__/__init__.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/ao/nn/intrinsic/quantized/modules/__init__.py +18 -0
- .venv/Lib/site-packages/torch/ao/nn/intrinsic/quantized/modules/__pycache__/bn_relu.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/ao/nn/intrinsic/quantized/modules/__pycache__/conv_add.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/ao/nn/intrinsic/quantized/modules/__pycache__/conv_relu.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/ao/nn/intrinsic/quantized/modules/bn_relu.py +105 -0
- .venv/Lib/site-packages/torch/ao/nn/intrinsic/quantized/modules/conv_add.py +145 -0
- .venv/Lib/site-packages/torch/ao/nn/intrinsic/quantized/modules/conv_relu.py +263 -0
- .venv/Lib/site-packages/torch/ao/nn/intrinsic/quantized/modules/linear_relu.py +187 -0
- .venv/Lib/site-packages/torch/ao/nn/quantized/__pycache__/__init__.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/ao/nn/quantized/__pycache__/functional.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/ao/nn/quantized/reference/__pycache__/__init__.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/__init__.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/rnn.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/sparse.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/utils.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/ao/nn/sparse/__init__.py +1 -0
- .venv/Lib/site-packages/torch/ao/nn/sparse/__pycache__/__init__.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/ao/nn/sparse/quantized/__init__.py +10 -0
- .venv/Lib/site-packages/torch/ao/nn/sparse/quantized/__pycache__/__init__.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/ao/nn/sparse/quantized/__pycache__/linear.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/ao/nn/sparse/quantized/__pycache__/utils.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/ao/nn/sparse/quantized/dynamic/__init__.py +6 -0
- .venv/Lib/site-packages/torch/ao/nn/sparse/quantized/dynamic/__pycache__/__init__.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/ao/nn/sparse/quantized/dynamic/__pycache__/linear.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/ao/nn/sparse/quantized/dynamic/linear.py +188 -0
- .venv/Lib/site-packages/torch/ao/nn/sparse/quantized/linear.py +273 -0
- .venv/Lib/site-packages/torch/ao/nn/sparse/quantized/utils.py +56 -0
- .venv/Lib/site-packages/torch/ao/ns/__init__.py +0 -0
- .venv/Lib/site-packages/torch/ao/ns/__pycache__/__init__.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/ao/ns/_numeric_suite.py +563 -0
- .venv/Lib/site-packages/torch/ao/ns/_numeric_suite_fx.py +1130 -0
- .venv/Lib/site-packages/torch/ao/ns/fx/__init__.py +0 -0
- .venv/Lib/site-packages/torch/ao/ns/fx/__pycache__/__init__.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/ao/ns/fx/__pycache__/ns_types.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/ao/ns/fx/__pycache__/utils.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/ao/ns/fx/graph_matcher.py +470 -0
- .venv/Lib/site-packages/torch/ao/ns/fx/graph_passes.py +1131 -0
.venv/Lib/site-packages/torch/ao/nn/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (510 Bytes)
.venv/Lib/site-packages/torch/ao/nn/intrinsic/__init__.py
ADDED
@@ -0,0 +1,40 @@
+# mypy: allow-untyped-defs
+from .modules import *  # noqa: F403
+from .modules.fused import _FusedModule  # noqa: F403
+
+
+# # Subpackages
+# from . import qat  # noqa: F403
+# from . import quantized  # noqa: F403
+
+__all__ = [
+    "ConvBn1d",
+    "ConvBn2d",
+    "ConvBn3d",
+    "ConvBnReLU1d",
+    "ConvBnReLU2d",
+    "ConvBnReLU3d",
+    "ConvReLU1d",
+    "ConvReLU2d",
+    "ConvReLU3d",
+    "LinearReLU",
+    "BNReLU2d",
+    "BNReLU3d",
+    "LinearBn1d",
+    "LinearLeakyReLU",
+    "LinearTanh",
+    "ConvAdd2d",
+    "ConvAddReLU2d",
+]
+
+
+# We are exposing all subpackages to the end-user.
+# Because of possible inter-dependency, we want to avoid
+# the cyclic imports, thus implementing lazy version
+# as per https://peps.python.org/pep-0562/
+def __getattr__(name):
+    if name in __all__:
+        import importlib
+
+        return importlib.import_module("." + name, __name__)
+    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
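
The `__getattr__` hook at the end of this file is the PEP 562 module-level lazy-import pattern. A minimal standalone sketch of the same mechanism, assuming a package `pkg` with a hypothetical submodule `heavy` (neither is part of this diff):

# pkg/__init__.py (illustrative sketch; "pkg" and "heavy" are hypothetical)
__all__ = ["heavy"]


def __getattr__(name):
    # Python calls a module-level __getattr__ only after normal attribute
    # lookup fails, so pkg.heavy is imported on first access rather than
    # when pkg itself is imported, which avoids import cycles.
    if name in __all__:
        import importlib

        return importlib.import_module("." + name, __name__)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")


# Usage:
#   import pkg    # does not import pkg.heavy
#   pkg.heavy     # triggers the import above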
.venv/Lib/site-packages/torch/ao/nn/intrinsic/modules/__init__.py
ADDED
@@ -0,0 +1,41 @@
+from .fused import (  # noqa: F401
+    _FusedModule,
+    BNReLU2d,
+    BNReLU3d,
+    ConvAdd2d,
+    ConvAddReLU2d,
+    ConvBn1d,
+    ConvBn2d,
+    ConvBn3d,
+    ConvBnReLU1d,
+    ConvBnReLU2d,
+    ConvBnReLU3d,
+    ConvReLU1d,
+    ConvReLU2d,
+    ConvReLU3d,
+    LinearBn1d,
+    LinearLeakyReLU,
+    LinearReLU,
+    LinearTanh,
+)
+
+
+__all__ = [
+    "ConvBn1d",
+    "ConvBn2d",
+    "ConvBn3d",
+    "ConvBnReLU1d",
+    "ConvBnReLU2d",
+    "ConvBnReLU3d",
+    "ConvReLU1d",
+    "ConvReLU2d",
+    "ConvReLU3d",
+    "LinearReLU",
+    "BNReLU2d",
+    "BNReLU3d",
+    "LinearBn1d",
+    "LinearLeakyReLU",
+    "LinearTanh",
+    "ConvAdd2d",
+    "ConvAddReLU2d",
+]
.venv/Lib/site-packages/torch/ao/nn/intrinsic/modules/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (709 Bytes)
.venv/Lib/site-packages/torch/ao/nn/intrinsic/modules/__pycache__/fused.cpython-39.pyc
ADDED
Binary file (9.96 kB)
.venv/Lib/site-packages/torch/ao/nn/intrinsic/modules/fused.py
ADDED
@@ -0,0 +1,245 @@
+# mypy: allow-untyped-defs
+import torch
+from torch.nn import (
+    BatchNorm1d,
+    BatchNorm2d,
+    BatchNorm3d,
+    Conv1d,
+    Conv2d,
+    Conv3d,
+    Linear,
+    ReLU,
+)
+from torch.nn.utils.parametrize import type_before_parametrizations
+
+
+__all__ = [
+    "ConvReLU1d",
+    "ConvReLU2d",
+    "ConvReLU3d",
+    "LinearReLU",
+    "ConvBn1d",
+    "ConvBn2d",
+    "ConvBnReLU1d",
+    "ConvBnReLU2d",
+    "ConvBn3d",
+    "ConvBnReLU3d",
+    "BNReLU2d",
+    "BNReLU3d",
+    "LinearBn1d",
+    "LinearLeakyReLU",
+    "LinearTanh",
+    "ConvAdd2d",
+    "ConvAddReLU2d",
+]
+
+
+# Used for identifying intrinsic modules used in quantization
+class _FusedModule(torch.nn.Sequential):
+    pass
+
+
+class ConvReLU1d(_FusedModule):
+    r"""This is a sequential container which calls the Conv1d and ReLU modules.
+    During quantization this will be replaced with the corresponding fused module."""
+
+    def __init__(self, conv, relu):
+        assert (
+            type_before_parametrizations(conv) == Conv1d
+            and type_before_parametrizations(relu) == ReLU
+        ), f"Incorrect types for input modules{type_before_parametrizations(conv)}{type_before_parametrizations(relu)}"
+        super().__init__(conv, relu)
+
+
+class ConvReLU2d(_FusedModule):
+    r"""This is a sequential container which calls the Conv2d and ReLU modules.
+    During quantization this will be replaced with the corresponding fused module."""
+
+    def __init__(self, conv, relu):
+        assert (
+            type_before_parametrizations(conv) == Conv2d
+            and type_before_parametrizations(relu) == ReLU
+        ), f"Incorrect types for input modules{type_before_parametrizations(conv)}{type_before_parametrizations(relu)}"
+        super().__init__(conv, relu)
+
+
+class ConvReLU3d(_FusedModule):
+    r"""This is a sequential container which calls the Conv3d and ReLU modules.
+    During quantization this will be replaced with the corresponding fused module."""
+
+    def __init__(self, conv, relu):
+        assert (
+            type_before_parametrizations(conv) == Conv3d
+            and type_before_parametrizations(relu) == ReLU
+        ), f"Incorrect types for input modules{type_before_parametrizations(conv)}{type_before_parametrizations(relu)}"
+        super().__init__(conv, relu)
+
+
+class LinearReLU(_FusedModule):
+    r"""This is a sequential container which calls the Linear and ReLU modules.
+    During quantization this will be replaced with the corresponding fused module."""
+
+    def __init__(self, linear, relu):
+        assert (
+            type_before_parametrizations(linear) == Linear
+            and type_before_parametrizations(relu) == ReLU
+        ), f"Incorrect types for input modules{type_before_parametrizations(linear)}{type_before_parametrizations(relu)}"
+        super().__init__(linear, relu)
+
+
+class ConvBn1d(_FusedModule):
+    r"""This is a sequential container which calls the Conv 1d and Batch Norm 1d modules.
+    During quantization this will be replaced with the corresponding fused module."""
+
+    def __init__(self, conv, bn):
+        assert (
+            type_before_parametrizations(conv) == Conv1d
+            and type_before_parametrizations(bn) == BatchNorm1d
+        ), f"Incorrect types for input modules{type_before_parametrizations(conv)}{type_before_parametrizations(bn)}"
+        super().__init__(conv, bn)
+
+
+class ConvBn2d(_FusedModule):
+    r"""This is a sequential container which calls the Conv 2d and Batch Norm 2d modules.
+    During quantization this will be replaced with the corresponding fused module."""
+
+    def __init__(self, conv, bn):
+        assert (
+            type_before_parametrizations(conv) == Conv2d
+            and type_before_parametrizations(bn) == BatchNorm2d
+        ), f"Incorrect types for input modules{type_before_parametrizations(conv)}{type_before_parametrizations(bn)}"
+        super().__init__(conv, bn)
+
+
+class ConvBnReLU1d(_FusedModule):
+    r"""This is a sequential container which calls the Conv 1d, Batch Norm 1d, and ReLU modules.
+    During quantization this will be replaced with the corresponding fused module."""
+
+    def __init__(self, conv, bn, relu):
+        assert (
+            type_before_parametrizations(conv) == Conv1d
+            and type_before_parametrizations(bn) == BatchNorm1d
+            and type_before_parametrizations(relu) == ReLU
+        ), f"Incorrect types for input modules{type_before_parametrizations(conv)}{type_before_parametrizations(bn)}{type_before_parametrizations(relu)}"  # noqa: B950
+        super().__init__(conv, bn, relu)
+
+
+class ConvBnReLU2d(_FusedModule):
+    r"""This is a sequential container which calls the Conv 2d, Batch Norm 2d, and ReLU modules.
+    During quantization this will be replaced with the corresponding fused module."""
+
+    def __init__(self, conv, bn, relu):
+        assert (
+            type_before_parametrizations(conv) == Conv2d
+            and type_before_parametrizations(bn) == BatchNorm2d
+            and type_before_parametrizations(relu) == ReLU
+        ), f"Incorrect types for input modules{type_before_parametrizations(conv)}{type_before_parametrizations(bn)}{type_before_parametrizations(relu)}"  # noqa: B950
+        super().__init__(conv, bn, relu)
+
+
+class ConvBn3d(_FusedModule):
+    r"""This is a sequential container which calls the Conv 3d and Batch Norm 3d modules.
+    During quantization this will be replaced with the corresponding fused module."""
+
+    def __init__(self, conv, bn):
+        assert (
+            type_before_parametrizations(conv) == Conv3d
+            and type_before_parametrizations(bn) == BatchNorm3d
+        ), f"Incorrect types for input modules{type_before_parametrizations(conv)}{type_before_parametrizations(bn)}"
+        super().__init__(conv, bn)
+
+
+class ConvBnReLU3d(_FusedModule):
+    r"""This is a sequential container which calls the Conv 3d, Batch Norm 3d, and ReLU modules.
+    During quantization this will be replaced with the corresponding fused module."""
+
+    def __init__(self, conv, bn, relu):
+        assert (
+            type_before_parametrizations(conv) == Conv3d
+            and type_before_parametrizations(bn) == BatchNorm3d
+            and type_before_parametrizations(relu) == ReLU
+        ), f"Incorrect types for input modules{type_before_parametrizations(conv)}{type_before_parametrizations(bn)}{type_before_parametrizations(relu)}"  # noqa: B950
+        super().__init__(conv, bn, relu)
+
+
+class BNReLU2d(_FusedModule):
+    r"""This is a sequential container which calls the BatchNorm 2d and ReLU modules.
+    During quantization this will be replaced with the corresponding fused module."""
+
+    def __init__(self, batch_norm, relu):
+        assert (
+            type_before_parametrizations(batch_norm) == BatchNorm2d
+            and type_before_parametrizations(relu) == ReLU
+        ), f"Incorrect types for input modules{type_before_parametrizations(batch_norm)}{type_before_parametrizations(relu)}"
+        super().__init__(batch_norm, relu)
+
+
+class BNReLU3d(_FusedModule):
+    r"""This is a sequential container which calls the BatchNorm 3d and ReLU modules.
+    During quantization this will be replaced with the corresponding fused module."""
+
+    def __init__(self, batch_norm, relu):
+        assert (
+            type_before_parametrizations(batch_norm) == BatchNorm3d
+            and type_before_parametrizations(relu) == ReLU
+        ), f"Incorrect types for input modules{type_before_parametrizations(batch_norm)}{type_before_parametrizations(relu)}"
+        super().__init__(batch_norm, relu)
+
+
+class LinearBn1d(_FusedModule):
+    r"""This is a sequential container which calls the Linear and BatchNorm1d modules.
+    During quantization this will be replaced with the corresponding fused module."""
+
+    def __init__(self, linear, bn):
+        assert (
+            type_before_parametrizations(linear) == Linear
+            and type_before_parametrizations(bn) == BatchNorm1d
+        ), f"Incorrect types for input modules{type_before_parametrizations(linear)}{type_before_parametrizations(bn)}"
+        super().__init__(linear, bn)
+
+
+class LinearLeakyReLU(_FusedModule):
+    r"""This is a sequential container which calls the Linear and LeakyReLU modules.
+    During quantization this will be replaced with the corresponding fused module."""
+
+    def __init__(self, linear, leaky_relu):
+        assert (
+            type(linear) == Linear and type(leaky_relu) == torch.nn.LeakyReLU
+        ), f"Incorrect types for input modules{type(linear)}{type(leaky_relu)}"
+        super().__init__(linear, leaky_relu)
+
+
+class LinearTanh(_FusedModule):
+    r"""This is a sequential container which calls the Linear and Tanh modules.
+    During quantization this will be replaced with the corresponding fused module."""
+
+    def __init__(self, linear, tanh):
+        assert (
+            type(linear) == Linear and type(tanh) == torch.nn.Tanh
+        ), f"Incorrect types for input modules{type(linear)}{type(tanh)}"
+        super().__init__(linear, tanh)
+
+
+class ConvAdd2d(_FusedModule):
+    r"""This is a sequential container which calls the Conv2d modules with extra Add.
+    During quantization this will be replaced with the corresponding fused module."""
+
+    def __init__(self, conv, add):
+        super().__init__(conv)
+        self.add = add
+
+    def forward(self, x1, x2):
+        return self.add(self[0](x1), x2)
+
+
+class ConvAddReLU2d(_FusedModule):
+    r"""This is a sequential container which calls the Conv2d, add, Relu.
+    During quantization this will be replaced with the corresponding fused module."""
+
+    def __init__(self, conv, add, relu):
+        super().__init__(conv)
+        self.add = add
+        self.relu = relu
+
+    def forward(self, x1, x2):
+        return self.relu(self.add(self[0](x1), x2))
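
As a usage sketch (not part of the diff): the containers above are thin `torch.nn.Sequential` subclasses whose constructors type-check their children, so they can be built directly from eager modules; in a typical workflow they are instead produced by `torch.ao.quantization.fuse_modules`.

import torch
import torch.nn as nn
from torch.ao.nn.intrinsic import ConvReLU2d, LinearReLU

# Build fused containers directly from eager modules.
fused_conv = ConvReLU2d(nn.Conv2d(3, 8, 3), nn.ReLU())
fused_linear = LinearReLU(nn.Linear(8, 4), nn.ReLU())

out = fused_conv(torch.randn(1, 3, 16, 16))  # behaves like Sequential

# The asserts in __init__ reject mismatched children, e.g. this raises
# AssertionError because the first child must be exactly nn.Conv2d:
#   ConvReLU2d(nn.Linear(3, 8), nn.ReLU())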
.venv/Lib/site-packages/torch/ao/nn/intrinsic/qat/__init__.py
ADDED
@@ -0,0 +1 @@
+from .modules import *  # noqa: F403
.venv/Lib/site-packages/torch/ao/nn/intrinsic/qat/modules/__init__.py
ADDED
@@ -0,0 +1,32 @@
+from .conv_fused import (
+    ConvBn1d,
+    ConvBn2d,
+    ConvBn3d,
+    ConvBnReLU1d,
+    ConvBnReLU2d,
+    ConvBnReLU3d,
+    ConvReLU1d,
+    ConvReLU2d,
+    ConvReLU3d,
+    freeze_bn_stats,
+    update_bn_stats,
+)
+from .linear_fused import LinearBn1d
+from .linear_relu import LinearReLU
+
+
+__all__ = [
+    "LinearReLU",
+    "LinearBn1d",
+    "ConvReLU1d",
+    "ConvReLU2d",
+    "ConvReLU3d",
+    "ConvBn1d",
+    "ConvBn2d",
+    "ConvBn3d",
+    "ConvBnReLU1d",
+    "ConvBnReLU2d",
+    "ConvBnReLU3d",
+    "update_bn_stats",
+    "freeze_bn_stats",
+]
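
A sketch of how the re-exported `update_bn_stats` / `freeze_bn_stats` helpers are conventionally used. The `qat_model` below is a placeholder for a model already prepared for QAT (e.g. via `torch.ao.quantization.prepare_qat`), so that it contains the fused `ConvBn*d` modules defined in `conv_fused.py`:

import torch
import torch.ao.nn.intrinsic.qat as nniqat

def freeze_bn_after_warmup(qat_model: torch.nn.Module) -> None:
    # Stop updating BatchNorm running statistics in every fused
    # ConvBn/ConvBnReLU module; the helper leaves other module types alone.
    qat_model.apply(nniqat.freeze_bn_stats)

# To resume collecting statistics:
#   qat_model.apply(nniqat.update_bn_stats)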
.venv/Lib/site-packages/torch/ao/nn/intrinsic/qat/modules/conv_fused.py
ADDED
|
@@ -0,0 +1,1050 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import math
|
| 3 |
+
from typing import TypeVar
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.ao.nn.intrinsic as nni
|
| 7 |
+
import torch.ao.nn.qat as nnqat
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
from torch.nn import init
|
| 11 |
+
from torch.nn.modules.utils import _pair, _single, _triple
|
| 12 |
+
from torch.nn.parameter import Parameter
|
| 13 |
+
from torch.nn.utils import fuse_conv_bn_weights
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
__all__ = [
|
| 17 |
+
"ConvBn1d",
|
| 18 |
+
"ConvBnReLU1d",
|
| 19 |
+
"ConvReLU1d",
|
| 20 |
+
"ConvBn2d",
|
| 21 |
+
"ConvBnReLU2d",
|
| 22 |
+
"ConvReLU2d",
|
| 23 |
+
"ConvBn3d",
|
| 24 |
+
"ConvBnReLU3d",
|
| 25 |
+
"ConvReLU3d",
|
| 26 |
+
"update_bn_stats",
|
| 27 |
+
"freeze_bn_stats",
|
| 28 |
+
]
|
| 29 |
+
_BN_CLASS_MAP = {
|
| 30 |
+
1: nn.BatchNorm1d,
|
| 31 |
+
2: nn.BatchNorm2d,
|
| 32 |
+
3: nn.BatchNorm3d,
|
| 33 |
+
}
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
MOD = TypeVar("MOD", bound=nn.modules.conv._ConvNd)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class _ConvBnNd(nn.modules.conv._ConvNd, nni._FusedModule):
|
| 40 |
+
_version = 2
|
| 41 |
+
_FLOAT_MODULE = MOD
|
| 42 |
+
|
| 43 |
+
def __init__(
|
| 44 |
+
self,
|
| 45 |
+
# ConvNd args
|
| 46 |
+
in_channels,
|
| 47 |
+
out_channels,
|
| 48 |
+
kernel_size,
|
| 49 |
+
stride,
|
| 50 |
+
padding,
|
| 51 |
+
dilation,
|
| 52 |
+
transposed,
|
| 53 |
+
output_padding,
|
| 54 |
+
groups,
|
| 55 |
+
bias,
|
| 56 |
+
padding_mode,
|
| 57 |
+
# BatchNormNd args
|
| 58 |
+
# num_features: out_channels
|
| 59 |
+
eps=1e-05,
|
| 60 |
+
momentum=0.1,
|
| 61 |
+
# affine: True
|
| 62 |
+
# track_running_stats: True
|
| 63 |
+
# Args for this module
|
| 64 |
+
freeze_bn=False,
|
| 65 |
+
qconfig=None,
|
| 66 |
+
dim=2,
|
| 67 |
+
):
|
| 68 |
+
nn.modules.conv._ConvNd.__init__(
|
| 69 |
+
self,
|
| 70 |
+
in_channels,
|
| 71 |
+
out_channels,
|
| 72 |
+
kernel_size,
|
| 73 |
+
stride,
|
| 74 |
+
padding,
|
| 75 |
+
dilation,
|
| 76 |
+
transposed,
|
| 77 |
+
output_padding,
|
| 78 |
+
groups,
|
| 79 |
+
False,
|
| 80 |
+
padding_mode,
|
| 81 |
+
)
|
| 82 |
+
assert qconfig, "qconfig must be provided for QAT module"
|
| 83 |
+
self.qconfig = qconfig
|
| 84 |
+
self.freeze_bn = freeze_bn if self.training else True
|
| 85 |
+
self.bn = _BN_CLASS_MAP[dim](out_channels, eps, momentum, True, True)
|
| 86 |
+
self.weight_fake_quant = self.qconfig.weight()
|
| 87 |
+
if bias:
|
| 88 |
+
self.bias = Parameter(torch.empty(out_channels))
|
| 89 |
+
else:
|
| 90 |
+
self.register_parameter("bias", None)
|
| 91 |
+
self.reset_bn_parameters()
|
| 92 |
+
|
| 93 |
+
# this needs to be called after reset_bn_parameters,
|
| 94 |
+
# as they modify the same state
|
| 95 |
+
if self.training:
|
| 96 |
+
if freeze_bn:
|
| 97 |
+
self.freeze_bn_stats()
|
| 98 |
+
else:
|
| 99 |
+
self.update_bn_stats()
|
| 100 |
+
else:
|
| 101 |
+
self.freeze_bn_stats()
|
| 102 |
+
|
| 103 |
+
self._enable_slow_path_for_better_numerical_stability = False
|
| 104 |
+
|
| 105 |
+
def reset_running_stats(self):
|
| 106 |
+
self.bn.reset_running_stats()
|
| 107 |
+
|
| 108 |
+
def reset_bn_parameters(self):
|
| 109 |
+
self.bn.reset_running_stats()
|
| 110 |
+
init.uniform_(self.bn.weight)
|
| 111 |
+
init.zeros_(self.bn.bias)
|
| 112 |
+
# note: below is actually for conv, not BN
|
| 113 |
+
if self.bias is not None:
|
| 114 |
+
fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
|
| 115 |
+
bound = 1 / math.sqrt(fan_in)
|
| 116 |
+
init.uniform_(self.bias, -bound, bound)
|
| 117 |
+
|
| 118 |
+
def reset_parameters(self):
|
| 119 |
+
super().reset_parameters()
|
| 120 |
+
|
| 121 |
+
def update_bn_stats(self):
|
| 122 |
+
self.freeze_bn = False
|
| 123 |
+
self.bn.training = True
|
| 124 |
+
return self
|
| 125 |
+
|
| 126 |
+
def freeze_bn_stats(self):
|
| 127 |
+
self.freeze_bn = True
|
| 128 |
+
self.bn.training = False
|
| 129 |
+
return self
|
| 130 |
+
|
| 131 |
+
def _forward(self, input):
|
| 132 |
+
if self._enable_slow_path_for_better_numerical_stability:
|
| 133 |
+
return self._forward_slow(input)
|
| 134 |
+
return self._forward_approximate(input)
|
| 135 |
+
|
| 136 |
+
def _forward_approximate(self, input):
|
| 137 |
+
"""Approximated method to fuse conv and bn. It requires only one forward pass.
|
| 138 |
+
conv_orig = conv / scale_factor where scale_factor = bn.weight / running_std
|
| 139 |
+
"""
|
| 140 |
+
assert self.bn.running_var is not None
|
| 141 |
+
running_std = torch.sqrt(self.bn.running_var + self.bn.eps)
|
| 142 |
+
scale_factor = self.bn.weight / running_std
|
| 143 |
+
weight_shape = [1] * len(self.weight.shape)
|
| 144 |
+
weight_shape[0] = -1
|
| 145 |
+
bias_shape = [1] * len(self.weight.shape)
|
| 146 |
+
bias_shape[1] = -1
|
| 147 |
+
scaled_weight = self.weight_fake_quant(
|
| 148 |
+
self.weight * scale_factor.reshape(weight_shape)
|
| 149 |
+
)
|
| 150 |
+
# using zero bias here since the bias for original conv
|
| 151 |
+
# will be added later
|
| 152 |
+
if self.bias is not None:
|
| 153 |
+
zero_bias = torch.zeros_like(self.bias, dtype=input.dtype)
|
| 154 |
+
else:
|
| 155 |
+
zero_bias = torch.zeros(
|
| 156 |
+
self.out_channels, device=scaled_weight.device, dtype=input.dtype
|
| 157 |
+
)
|
| 158 |
+
conv = self._conv_forward(input, scaled_weight, zero_bias)
|
| 159 |
+
conv_orig = conv / scale_factor.reshape(bias_shape)
|
| 160 |
+
if self.bias is not None:
|
| 161 |
+
conv_orig = conv_orig + self.bias.reshape(bias_shape)
|
| 162 |
+
conv = self.bn(conv_orig)
|
| 163 |
+
return conv
|
| 164 |
+
|
| 165 |
+
def _forward_slow(self, input):
|
| 166 |
+
"""
|
| 167 |
+
A more accurate but slow method to compute conv bn fusion, following https://arxiv.org/pdf/1806.08342.pdf
|
| 168 |
+
It requires two forward passes but handles the case bn.weight == 0
|
| 169 |
+
|
| 170 |
+
Conv: Y = WX + B_c
|
| 171 |
+
Conv without bias: Y0 = WX = Y - B_c, Y = Y0 + B_c
|
| 172 |
+
|
| 173 |
+
Batch statistics:
|
| 174 |
+
mean_Y = Y.mean()
|
| 175 |
+
= Y0.mean() + B_c
|
| 176 |
+
var_Y = (Y - mean_Y)^2.mean()
|
| 177 |
+
= (Y0 - Y0.mean())^2.mean()
|
| 178 |
+
BN (r: bn.weight, beta: bn.bias):
|
| 179 |
+
Z = r * (Y - mean_Y) / sqrt(var_Y + eps) + beta
|
| 180 |
+
= r * (Y0 - Y0.mean()) / sqrt(var_Y + eps) + beta
|
| 181 |
+
|
| 182 |
+
Fused Conv BN training (std_Y = sqrt(var_Y + eps)):
|
| 183 |
+
Z = (r * W / std_Y) * X + r * (B_c - mean_Y) / std_Y + beta
|
| 184 |
+
= (r * W / std_Y) * X - r * Y0.mean() / std_Y + beta
|
| 185 |
+
|
| 186 |
+
Fused Conv BN inference (running_std = sqrt(running_var + eps)):
|
| 187 |
+
Z = (r * W / running_std) * X - r * (running_mean - B_c) / running_std + beta
|
| 188 |
+
|
| 189 |
+
QAT with fused conv bn:
|
| 190 |
+
Z_train = fake_quant(r * W / running_std) * X * (running_std / std_Y) - r * Y0.mean() / std_Y + beta
|
| 191 |
+
= conv(X, fake_quant(r * W / running_std)) * (running_std / std_Y) - r * Y0.mean() / std_Y + beta
|
| 192 |
+
Z_inference = conv(X, fake_quant(r * W / running_std)) - r * (running_mean - B_c) / running_std + beta
|
| 193 |
+
"""
|
| 194 |
+
|
| 195 |
+
assert self.bn.running_var is not None
|
| 196 |
+
assert self.bn.running_mean is not None
|
| 197 |
+
|
| 198 |
+
# using zero bias here since the bias for original conv
|
| 199 |
+
# will be added later
|
| 200 |
+
zero_bias = torch.zeros(
|
| 201 |
+
self.out_channels, device=self.weight.device, dtype=input.dtype
|
| 202 |
+
)
|
| 203 |
+
|
| 204 |
+
weight_shape = [1] * len(self.weight.shape)
|
| 205 |
+
weight_shape[0] = -1
|
| 206 |
+
bias_shape = [1] * len(self.weight.shape)
|
| 207 |
+
bias_shape[1] = -1
|
| 208 |
+
|
| 209 |
+
if self.bn.training:
|
| 210 |
+
# needed to compute batch mean/std
|
| 211 |
+
conv_out = self._conv_forward(input, self.weight, zero_bias)
|
| 212 |
+
# update bn statistics
|
| 213 |
+
with torch.no_grad():
|
| 214 |
+
conv_out_bias = (
|
| 215 |
+
conv_out
|
| 216 |
+
if self.bias is None
|
| 217 |
+
else conv_out + self.bias.reshape(bias_shape)
|
| 218 |
+
)
|
| 219 |
+
self.bn(conv_out_bias)
|
| 220 |
+
|
| 221 |
+
# fused conv + bn without bias using bn running statistics
|
| 222 |
+
running_std = torch.sqrt(self.bn.running_var + self.bn.eps)
|
| 223 |
+
scale_factor = self.bn.weight / running_std
|
| 224 |
+
scaled_weight = self.weight_fake_quant(
|
| 225 |
+
self.weight * scale_factor.reshape(weight_shape)
|
| 226 |
+
)
|
| 227 |
+
# fused conv without bias for inference: (r * W / running_std) * X
|
| 228 |
+
conv_bn = self._conv_forward(input, scaled_weight, zero_bias)
|
| 229 |
+
|
| 230 |
+
if self.bn.training:
|
| 231 |
+
avg_dims = [0] + list(range(2, len(self.weight.shape)))
|
| 232 |
+
batch_mean = conv_out.mean(avg_dims) # type: ignore[possibly-undefined]
|
| 233 |
+
batch_var = torch.square(conv_out - batch_mean.reshape(bias_shape)).mean(
|
| 234 |
+
avg_dims
|
| 235 |
+
)
|
| 236 |
+
batch_std = torch.sqrt(batch_var + self.bn.eps)
|
| 237 |
+
|
| 238 |
+
# scale to use batch std in training mode
|
| 239 |
+
# conv(X, r * W / std_Y) = conv(X, r * W / running_std) * (running_std / std_Y)
|
| 240 |
+
unscale_factor = running_std / batch_std
|
| 241 |
+
conv_bn *= unscale_factor.reshape(bias_shape)
|
| 242 |
+
|
| 243 |
+
fused_mean = batch_mean
|
| 244 |
+
fused_std = batch_std
|
| 245 |
+
else:
|
| 246 |
+
fused_mean = self.bn.running_mean - (
|
| 247 |
+
self.bias if self.bias is not None else 0
|
| 248 |
+
)
|
| 249 |
+
fused_std = running_std
|
| 250 |
+
|
| 251 |
+
# fused bias = beta - r * mean / std
|
| 252 |
+
fused_bias = self.bn.bias - self.bn.weight * fused_mean / fused_std
|
| 253 |
+
conv_bn += fused_bias.reshape(bias_shape)
|
| 254 |
+
|
| 255 |
+
# HACK to let conv bias participate in loss to avoid DDP error (parameters
|
| 256 |
+
# were not used in producing loss)
|
| 257 |
+
if self.bias is not None:
|
| 258 |
+
conv_bn += (self.bias - self.bias).reshape(bias_shape)
|
| 259 |
+
|
| 260 |
+
return conv_bn
|
| 261 |
+
|
| 262 |
+
def extra_repr(self):
|
| 263 |
+
# TODO(jerryzh): extend
|
| 264 |
+
return super().extra_repr()
|
| 265 |
+
|
| 266 |
+
def forward(self, input):
|
| 267 |
+
return self._forward(input)
|
| 268 |
+
|
| 269 |
+
def train(self, mode=True):
|
| 270 |
+
"""
|
| 271 |
+
Batchnorm's training behavior is using the self.training flag. Prevent
|
| 272 |
+
changing it if BN is frozen. This makes sure that calling `model.train()`
|
| 273 |
+
on a model with a frozen BN will behave properly.
|
| 274 |
+
"""
|
| 275 |
+
self.training = mode
|
| 276 |
+
if not self.freeze_bn:
|
| 277 |
+
for module in self.children():
|
| 278 |
+
module.train(mode)
|
| 279 |
+
return self
|
| 280 |
+
|
| 281 |
+
# ===== Serialization version history =====
|
| 282 |
+
#
|
| 283 |
+
# Version 1/None
|
| 284 |
+
# self
|
| 285 |
+
# |--- weight : Tensor
|
| 286 |
+
# |--- bias : Tensor
|
| 287 |
+
# |--- gamma : Tensor
|
| 288 |
+
# |--- beta : Tensor
|
| 289 |
+
# |--- running_mean : Tensor
|
| 290 |
+
# |--- running_var : Tensor
|
| 291 |
+
# |--- num_batches_tracked : Tensor
|
| 292 |
+
#
|
| 293 |
+
# Version 2
|
| 294 |
+
# self
|
| 295 |
+
# |--- weight : Tensor
|
| 296 |
+
# |--- bias : Tensor
|
| 297 |
+
# |--- bn : Module
|
| 298 |
+
# |--- weight : Tensor (moved from v1.self.gamma)
|
| 299 |
+
# |--- bias : Tensor (moved from v1.self.beta)
|
| 300 |
+
# |--- running_mean : Tensor (moved from v1.self.running_mean)
|
| 301 |
+
# |--- running_var : Tensor (moved from v1.self.running_var)
|
| 302 |
+
# |--- num_batches_tracked : Tensor (moved from v1.self.num_batches_tracked)
|
| 303 |
+
def _load_from_state_dict(
|
| 304 |
+
self,
|
| 305 |
+
state_dict,
|
| 306 |
+
prefix,
|
| 307 |
+
local_metadata,
|
| 308 |
+
strict,
|
| 309 |
+
missing_keys,
|
| 310 |
+
unexpected_keys,
|
| 311 |
+
error_msgs,
|
| 312 |
+
):
|
| 313 |
+
version = local_metadata.get("version", None)
|
| 314 |
+
if version is None or version == 1:
|
| 315 |
+
# BN related parameters and buffers were moved into the BN module for v2
|
| 316 |
+
v2_to_v1_names = {
|
| 317 |
+
"bn.weight": "gamma",
|
| 318 |
+
"bn.bias": "beta",
|
| 319 |
+
"bn.running_mean": "running_mean",
|
| 320 |
+
"bn.running_var": "running_var",
|
| 321 |
+
"bn.num_batches_tracked": "num_batches_tracked",
|
| 322 |
+
}
|
| 323 |
+
for v2_name, v1_name in v2_to_v1_names.items():
|
| 324 |
+
if prefix + v1_name in state_dict:
|
| 325 |
+
state_dict[prefix + v2_name] = state_dict[prefix + v1_name]
|
| 326 |
+
state_dict.pop(prefix + v1_name)
|
| 327 |
+
elif prefix + v2_name in state_dict:
|
| 328 |
+
# there was a brief period where forward compatibility
|
| 329 |
+
# for this module was broken (between
|
| 330 |
+
# https://github.com/pytorch/pytorch/pull/38478
|
| 331 |
+
# and https://github.com/pytorch/pytorch/pull/38820)
|
| 332 |
+
# and modules emitted the v2 state_dict format while
|
| 333 |
+
# specifying that version == 1. This patches the forward
|
| 334 |
+
# compatibility issue by allowing the v2 style entries to
|
| 335 |
+
# be used.
|
| 336 |
+
pass
|
| 337 |
+
elif strict:
|
| 338 |
+
missing_keys.append(prefix + v2_name)
|
| 339 |
+
|
| 340 |
+
super()._load_from_state_dict(
|
| 341 |
+
state_dict,
|
| 342 |
+
prefix,
|
| 343 |
+
local_metadata,
|
| 344 |
+
strict,
|
| 345 |
+
missing_keys,
|
| 346 |
+
unexpected_keys,
|
| 347 |
+
error_msgs,
|
| 348 |
+
)
|
| 349 |
+
|
| 350 |
+
@classmethod
|
| 351 |
+
def from_float(cls, mod, use_precomputed_fake_quant=False):
|
| 352 |
+
r"""Create a qat module from a float module or qparams_dict
|
| 353 |
+
|
| 354 |
+
Args: `mod` a float module, either produced by torch.ao.quantization utilities
|
| 355 |
+
or directly from user
|
| 356 |
+
"""
|
| 357 |
+
# The ignore is because _FLOAT_MODULE is a TypeVar here where the bound
|
| 358 |
+
# has no __name__ (code is fine though)
|
| 359 |
+
assert type(mod) == cls._FLOAT_MODULE, (
|
| 360 |
+
"qat."
|
| 361 |
+
+ cls.__name__
|
| 362 |
+
+ ".from_float only works for "
|
| 363 |
+
+ cls._FLOAT_MODULE.__name__ # type: ignore[attr-defined]
|
| 364 |
+
)
|
| 365 |
+
assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"
|
| 366 |
+
assert mod.qconfig, "Input float module must have a valid qconfig"
|
| 367 |
+
qconfig = mod.qconfig
|
| 368 |
+
conv, bn = mod[0], mod[1]
|
| 369 |
+
qat_convbn = cls(
|
| 370 |
+
conv.in_channels,
|
| 371 |
+
conv.out_channels,
|
| 372 |
+
conv.kernel_size,
|
| 373 |
+
conv.stride,
|
| 374 |
+
conv.padding,
|
| 375 |
+
conv.dilation,
|
| 376 |
+
conv.groups,
|
| 377 |
+
conv.bias is not None,
|
| 378 |
+
conv.padding_mode,
|
| 379 |
+
bn.eps,
|
| 380 |
+
bn.momentum,
|
| 381 |
+
False,
|
| 382 |
+
qconfig,
|
| 383 |
+
)
|
| 384 |
+
qat_convbn.weight = conv.weight
|
| 385 |
+
qat_convbn.bias = conv.bias
|
| 386 |
+
qat_convbn.bn.weight = bn.weight
|
| 387 |
+
qat_convbn.bn.bias = bn.bias
|
| 388 |
+
qat_convbn.bn.running_mean = bn.running_mean
|
| 389 |
+
qat_convbn.bn.running_var = bn.running_var
|
| 390 |
+
# mypy error: Cannot determine type of 'num_batches_tracked'
|
| 391 |
+
qat_convbn.bn.num_batches_tracked = bn.num_batches_tracked # type: ignore[has-type]
|
| 392 |
+
return qat_convbn
|
| 393 |
+
|
| 394 |
+
def to_float(self):
|
| 395 |
+
cls = type(self)
|
| 396 |
+
conv = cls._FLOAT_CONV_MODULE( # type: ignore[attr-defined]
|
| 397 |
+
self.in_channels,
|
| 398 |
+
self.out_channels,
|
| 399 |
+
self.kernel_size,
|
| 400 |
+
self.stride,
|
| 401 |
+
self.padding,
|
| 402 |
+
self.dilation,
|
| 403 |
+
self.groups,
|
| 404 |
+
self.bias is not None,
|
| 405 |
+
self.padding_mode,
|
| 406 |
+
)
|
| 407 |
+
conv.weight = torch.nn.Parameter(self.weight.detach())
|
| 408 |
+
if self.bias is not None:
|
| 409 |
+
conv.bias = torch.nn.Parameter(self.bias.detach())
|
| 410 |
+
|
| 411 |
+
if cls._FLOAT_BN_MODULE: # type: ignore[attr-defined]
|
| 412 |
+
# fuse bn into conv
|
| 413 |
+
assert self.bn.running_var is not None and self.bn.running_mean is not None
|
| 414 |
+
conv.weight, conv.bias = fuse_conv_bn_weights(
|
| 415 |
+
conv.weight,
|
| 416 |
+
conv.bias,
|
| 417 |
+
self.bn.running_mean,
|
| 418 |
+
self.bn.running_var,
|
| 419 |
+
self.bn.eps,
|
| 420 |
+
self.bn.weight,
|
| 421 |
+
self.bn.bias,
|
| 422 |
+
)
|
| 423 |
+
|
| 424 |
+
if cls._FLOAT_RELU_MODULE: # type: ignore[attr-defined]
|
| 425 |
+
modules = []
|
| 426 |
+
modules.append(conv)
|
| 427 |
+
relu = cls._FLOAT_RELU_MODULE() # type: ignore[attr-defined]
|
| 428 |
+
modules.append(relu)
|
| 429 |
+
conv_relu = cls._FUSED_FLOAT_MODULE(*modules) # type: ignore[attr-defined]
|
| 430 |
+
conv_relu.train(self.training)
|
| 431 |
+
return conv_relu
|
| 432 |
+
else:
|
| 433 |
+
conv.train(self.training)
|
| 434 |
+
return conv
|
| 435 |
+
|
| 436 |
+
|
| 437 |
+
class ConvBn1d(_ConvBnNd, nn.Conv1d):
|
| 438 |
+
r"""
|
| 439 |
+
A ConvBn1d module is a module fused from Conv1d and BatchNorm1d,
|
| 440 |
+
attached with FakeQuantize modules for weight,
|
| 441 |
+
used in quantization aware training.
|
| 442 |
+
|
| 443 |
+
We combined the interface of :class:`torch.nn.Conv1d` and
|
| 444 |
+
:class:`torch.nn.BatchNorm1d`.
|
| 445 |
+
|
| 446 |
+
Similar to :class:`torch.nn.Conv1d`, with FakeQuantize modules initialized
|
| 447 |
+
to default.
|
| 448 |
+
|
| 449 |
+
Attributes:
|
| 450 |
+
freeze_bn:
|
| 451 |
+
weight_fake_quant: fake quant module for weight
|
| 452 |
+
|
| 453 |
+
"""
|
| 454 |
+
_FLOAT_BN_MODULE = nn.BatchNorm1d
|
| 455 |
+
_FLOAT_RELU_MODULE: None = None
|
| 456 |
+
_FLOAT_MODULE = nni.ConvBn1d
|
| 457 |
+
_FLOAT_CONV_MODULE = nn.Conv1d
|
| 458 |
+
|
| 459 |
+
def __init__(
|
| 460 |
+
self,
|
| 461 |
+
# Conv1d args
|
| 462 |
+
in_channels,
|
| 463 |
+
out_channels,
|
| 464 |
+
kernel_size,
|
| 465 |
+
stride=1,
|
| 466 |
+
padding=0,
|
| 467 |
+
dilation=1,
|
| 468 |
+
groups=1,
|
| 469 |
+
bias=None,
|
| 470 |
+
padding_mode="zeros",
|
| 471 |
+
# BatchNorm1d args
|
| 472 |
+
# num_features: out_channels
|
| 473 |
+
eps=1e-05,
|
| 474 |
+
momentum=0.1,
|
| 475 |
+
# affine: True
|
| 476 |
+
# track_running_stats: True
|
| 477 |
+
# Args for this module
|
| 478 |
+
freeze_bn=False,
|
| 479 |
+
qconfig=None,
|
| 480 |
+
):
|
| 481 |
+
kernel_size = _single(kernel_size)
|
| 482 |
+
stride = _single(stride)
|
| 483 |
+
padding = _single(padding)
|
| 484 |
+
dilation = _single(dilation)
|
| 485 |
+
_ConvBnNd.__init__(
|
| 486 |
+
self,
|
| 487 |
+
in_channels,
|
| 488 |
+
out_channels,
|
| 489 |
+
kernel_size,
|
| 490 |
+
stride,
|
| 491 |
+
padding,
|
| 492 |
+
dilation,
|
| 493 |
+
False,
|
| 494 |
+
_single(0),
|
| 495 |
+
groups,
|
| 496 |
+
bias,
|
| 497 |
+
padding_mode,
|
| 498 |
+
eps,
|
| 499 |
+
momentum,
|
| 500 |
+
freeze_bn,
|
| 501 |
+
qconfig,
|
| 502 |
+
dim=1,
|
| 503 |
+
)
|
| 504 |
+
|
| 505 |
+
|
| 506 |
+
class ConvBnReLU1d(ConvBn1d):
|
| 507 |
+
r"""
|
| 508 |
+
A ConvBnReLU1d module is a module fused from Conv1d, BatchNorm1d and ReLU,
|
| 509 |
+
attached with FakeQuantize modules for weight,
|
| 510 |
+
used in quantization aware training.
|
| 511 |
+
|
| 512 |
+
We combined the interface of :class:`torch.nn.Conv1d` and
|
| 513 |
+
:class:`torch.nn.BatchNorm1d` and :class:`torch.nn.ReLU`.
|
| 514 |
+
|
| 515 |
+
Similar to `torch.nn.Conv1d`, with FakeQuantize modules initialized to
|
| 516 |
+
default.
|
| 517 |
+
|
| 518 |
+
Attributes:
|
| 519 |
+
weight_fake_quant: fake quant module for weight
|
| 520 |
+
|
| 521 |
+
"""
|
| 522 |
+
# base class defines _FLOAT_MODULE as "ConvBn1d"
|
| 523 |
+
_FLOAT_MODULE = nni.ConvBnReLU1d # type: ignore[assignment]
|
| 524 |
+
_FLOAT_CONV_MODULE = nn.Conv1d
|
| 525 |
+
_FLOAT_BN_MODULE = nn.BatchNorm1d
|
| 526 |
+
_FLOAT_RELU_MODULE = nn.ReLU # type: ignore[assignment]
|
| 527 |
+
# module class after fusing bn into conv
|
| 528 |
+
_FUSED_FLOAT_MODULE = nni.ConvReLU1d
|
| 529 |
+
|
| 530 |
+
def __init__(
|
| 531 |
+
self,
|
| 532 |
+
# Conv1d args
|
| 533 |
+
in_channels,
|
| 534 |
+
out_channels,
|
| 535 |
+
kernel_size,
|
| 536 |
+
stride=1,
|
| 537 |
+
padding=0,
|
| 538 |
+
dilation=1,
|
| 539 |
+
groups=1,
|
| 540 |
+
bias=None,
|
| 541 |
+
padding_mode="zeros",
|
| 542 |
+
# BatchNorm1d args
|
| 543 |
+
# num_features: out_channels
|
| 544 |
+
eps=1e-05,
|
| 545 |
+
momentum=0.1,
|
| 546 |
+
# affine: True
|
| 547 |
+
# track_running_stats: True
|
| 548 |
+
# Args for this module
|
| 549 |
+
freeze_bn=False,
|
| 550 |
+
qconfig=None,
|
| 551 |
+
):
|
| 552 |
+
super().__init__(
|
| 553 |
+
in_channels,
|
| 554 |
+
out_channels,
|
| 555 |
+
kernel_size,
|
| 556 |
+
stride,
|
| 557 |
+
padding,
|
| 558 |
+
dilation,
|
| 559 |
+
groups,
|
| 560 |
+
bias,
|
| 561 |
+
padding_mode,
|
| 562 |
+
eps,
|
| 563 |
+
momentum,
|
| 564 |
+
freeze_bn,
|
| 565 |
+
qconfig,
|
| 566 |
+
)
|
| 567 |
+
|
| 568 |
+
def forward(self, input):
|
| 569 |
+
return F.relu(ConvBn1d._forward(self, input))
|
| 570 |
+
|
| 571 |
+
@classmethod
|
| 572 |
+
def from_float(cls, mod, use_precomputed_fake_quant=False):
|
| 573 |
+
return super().from_float(mod, use_precomputed_fake_quant)
|
| 574 |
+
|
| 575 |
+
|
| 576 |
+
class ConvReLU1d(nnqat.Conv1d, nni._FusedModule):
|
| 577 |
+
r"""A ConvReLU1d module is a fused module of Conv1d and ReLU, attached with
|
| 578 |
+
FakeQuantize modules for weight for
|
| 579 |
+
quantization aware training.
|
| 580 |
+
|
| 581 |
+
We combined the interface of :class:`~torch.nn.Conv1d` and
|
| 582 |
+
:class:`~torch.nn.BatchNorm1d`.
|
| 583 |
+
|
| 584 |
+
Attributes:
|
| 585 |
+
weight_fake_quant: fake quant module for weight
|
| 586 |
+
|
| 587 |
+
"""
|
| 588 |
+
_FLOAT_MODULE = nni.ConvReLU1d # type: ignore[assignment]
|
| 589 |
+
_FLOAT_CONV_MODULE = nn.Conv1d
|
| 590 |
+
_FLOAT_BN_MODULE: None = None
|
| 591 |
+
_FLOAT_RELU_MODULE = nn.ReLU
|
| 592 |
+
|
| 593 |
+
def __init__(
|
| 594 |
+
self,
|
| 595 |
+
in_channels,
|
| 596 |
+
out_channels,
|
| 597 |
+
kernel_size,
|
| 598 |
+
stride=1,
|
| 599 |
+
padding=0,
|
| 600 |
+
dilation=1,
|
| 601 |
+
groups=1,
|
| 602 |
+
bias=True,
|
| 603 |
+
padding_mode="zeros",
|
| 604 |
+
qconfig=None,
|
| 605 |
+
):
|
| 606 |
+
super().__init__(
|
| 607 |
+
in_channels,
|
| 608 |
+
out_channels,
|
| 609 |
+
kernel_size,
|
| 610 |
+
stride=stride,
|
| 611 |
+
padding=padding,
|
| 612 |
+
dilation=dilation,
|
| 613 |
+
groups=groups,
|
| 614 |
+
bias=bias,
|
| 615 |
+
padding_mode=padding_mode,
|
| 616 |
+
qconfig=qconfig,
|
| 617 |
+
)
|
| 618 |
+
assert qconfig, "qconfig must be provided for QAT module"
|
| 619 |
+
self.qconfig = qconfig
|
| 620 |
+
self.weight_fake_quant = self.qconfig.weight()
|
| 621 |
+
|
| 622 |
+
def forward(self, input):
|
| 623 |
+
return F.relu(
|
| 624 |
+
self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias)
|
| 625 |
+
)
|
| 626 |
+
|
| 627 |
+
@classmethod
|
| 628 |
+
def from_float(cls, mod, use_precomputed_fake_quant=False):
|
| 629 |
+
return super().from_float(
|
| 630 |
+
mod, use_precomputed_fake_quant=use_precomputed_fake_quant
|
| 631 |
+
)
|
| 632 |
+
|
| 633 |
+
|
| 634 |
+
class ConvBn2d(_ConvBnNd, nn.Conv2d):
|
| 635 |
+
r"""
|
| 636 |
+
A ConvBn2d module is a module fused from Conv2d and BatchNorm2d,
|
| 637 |
+
attached with FakeQuantize modules for weight,
|
| 638 |
+
used in quantization aware training.
|
| 639 |
+
|
| 640 |
+
We combined the interface of :class:`torch.nn.Conv2d` and
|
| 641 |
+
:class:`torch.nn.BatchNorm2d`.
|
| 642 |
+
|
| 643 |
+
Similar to :class:`torch.nn.Conv2d`, with FakeQuantize modules initialized
|
| 644 |
+
to default.
|
| 645 |
+
|
| 646 |
+
Attributes:
|
| 647 |
+
freeze_bn:
|
| 648 |
+
weight_fake_quant: fake quant module for weight
|
| 649 |
+
|
| 650 |
+
"""
|
| 651 |
+
_FLOAT_MODULE = nni.ConvBn2d
|
| 652 |
+
_FLOAT_CONV_MODULE = nn.Conv2d
|
| 653 |
+
_FLOAT_BN_MODULE = nn.BatchNorm2d
|
| 654 |
+
_FLOAT_RELU_MODULE: None = None
|
| 655 |
+
|
| 656 |
+
def __init__(
|
| 657 |
+
self,
|
| 658 |
+
# ConvNd args
|
| 659 |
+
in_channels,
|
| 660 |
+
out_channels,
|
| 661 |
+
kernel_size,
|
| 662 |
+
stride=1,
|
| 663 |
+
padding=0,
|
| 664 |
+
dilation=1,
|
| 665 |
+
groups=1,
|
| 666 |
+
bias=None,
|
| 667 |
+
padding_mode="zeros",
|
| 668 |
+
# BatchNorm2d args
|
| 669 |
+
# num_features: out_channels
|
| 670 |
+
eps=1e-05,
|
| 671 |
+
momentum=0.1,
|
| 672 |
+
# affine: True
|
| 673 |
+
# track_running_stats: True
|
| 674 |
+
# Args for this module
|
| 675 |
+
freeze_bn=False,
|
| 676 |
+
qconfig=None,
|
| 677 |
+
):
|
| 678 |
+
kernel_size = _pair(kernel_size)
|
| 679 |
+
stride = _pair(stride)
|
| 680 |
+
padding = _pair(padding)
|
| 681 |
+
dilation = _pair(dilation)
|
| 682 |
+
_ConvBnNd.__init__(
|
| 683 |
+
self,
|
| 684 |
+
in_channels,
|
| 685 |
+
out_channels,
|
| 686 |
+
kernel_size,
|
| 687 |
+
stride,
|
| 688 |
+
padding,
|
| 689 |
+
dilation,
|
| 690 |
+
False,
|
| 691 |
+
_pair(0),
|
| 692 |
+
groups,
|
| 693 |
+
bias,
|
| 694 |
+
padding_mode,
|
| 695 |
+
eps,
|
| 696 |
+
momentum,
|
| 697 |
+
freeze_bn,
|
| 698 |
+
qconfig,
|
| 699 |
+
dim=2,
|
| 700 |
+
)
|
| 701 |
+
|
| 702 |
+
|
| 703 |
+
class ConvBnReLU2d(ConvBn2d):
|
| 704 |
+
r"""
|
| 705 |
+
A ConvBnReLU2d module is a module fused from Conv2d, BatchNorm2d and ReLU,
|
| 706 |
+
attached with FakeQuantize modules for weight,
|
| 707 |
+
used in quantization aware training.
|
| 708 |
+
|
| 709 |
+
We combined the interface of :class:`torch.nn.Conv2d` and
|
| 710 |
+
:class:`torch.nn.BatchNorm2d` and :class:`torch.nn.ReLU`.
|
| 711 |
+
|
| 712 |
+
Similar to `torch.nn.Conv2d`, with FakeQuantize modules initialized to
|
| 713 |
+
default.
|
| 714 |
+
|
| 715 |
+
Attributes:
|
| 716 |
+
weight_fake_quant: fake quant module for weight
|
| 717 |
+
|
| 718 |
+
"""
|
| 719 |
+
# base class defines _FLOAT_MODULE as "ConvBn2d"
|
| 720 |
+
_FLOAT_MODULE = nni.ConvBnReLU2d # type: ignore[assignment]
|
| 721 |
+
_FLOAT_CONV_MODULE = nn.Conv2d
|
| 722 |
+
_FLOAT_BN_MODULE = nn.BatchNorm2d
|
| 723 |
+
_FLOAT_RELU_MODULE = nn.ReLU # type: ignore[assignment]
|
| 724 |
+
# module class after fusing bn into conv
|
| 725 |
+
_FUSED_FLOAT_MODULE = nni.ConvReLU2d
|
| 726 |
+
|
| 727 |
+
def __init__(
|
| 728 |
+
self,
|
| 729 |
+
# Conv2d args
|
| 730 |
+
in_channels,
|
| 731 |
+
out_channels,
|
| 732 |
+
kernel_size,
|
| 733 |
+
stride=1,
|
| 734 |
+
padding=0,
|
| 735 |
+
dilation=1,
|
| 736 |
+
groups=1,
|
| 737 |
+
bias=None,
|
| 738 |
+
padding_mode="zeros",
|
| 739 |
+
# BatchNorm2d args
|
| 740 |
+
# num_features: out_channels
|
| 741 |
+
eps=1e-05,
|
| 742 |
+
momentum=0.1,
|
| 743 |
+
# affine: True
|
| 744 |
+
# track_running_stats: True
|
| 745 |
+
# Args for this module
|
| 746 |
+
freeze_bn=False,
|
| 747 |
+
qconfig=None,
|
| 748 |
+
):
|
| 749 |
+
super().__init__(
|
| 750 |
+
in_channels,
|
| 751 |
+
out_channels,
|
| 752 |
+
kernel_size,
|
| 753 |
+
stride,
|
| 754 |
+
padding,
|
| 755 |
+
dilation,
|
| 756 |
+
groups,
|
| 757 |
+
bias,
|
| 758 |
+
padding_mode,
|
| 759 |
+
eps,
|
| 760 |
+
momentum,
|
| 761 |
+
freeze_bn,
|
| 762 |
+
qconfig,
|
| 763 |
+
)
|
| 764 |
+
|
| 765 |
+
def forward(self, input):
|
| 766 |
+
return F.relu(ConvBn2d._forward(self, input))
|
| 767 |
+
|
| 768 |
+
@classmethod
|
| 769 |
+
def from_float(cls, mod, use_precomputed_fake_quant=False):
|
| 770 |
+
return super().from_float(mod, use_precomputed_fake_quant)
|
| 771 |
+
|
| 772 |
+
|
| 773 |
+
class ConvReLU2d(nnqat.Conv2d, nni._FusedModule):
|
| 774 |
+
r"""A ConvReLU2d module is a fused module of Conv2d and ReLU, attached with
|
| 775 |
+
FakeQuantize modules for weight for
|
| 776 |
+
quantization aware training.
|
| 777 |
+
|
| 778 |
+
We combined the interface of :class:`~torch.nn.Conv2d` and
|
| 779 |
+
:class:`~torch.nn.BatchNorm2d`.
|
    Attributes:
        weight_fake_quant: fake quant module for weight

    """
    _FLOAT_MODULE = nni.ConvReLU2d  # type: ignore[assignment]
    _FLOAT_CONV_MODULE = nn.Conv2d
    _FLOAT_BN_MODULE: None = None
    _FLOAT_RELU_MODULE = nn.ReLU

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
        padding_mode="zeros",
        qconfig=None,
    ):
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
            qconfig=qconfig,
        )
        assert qconfig, "qconfig must be provided for QAT module"
        self.qconfig = qconfig
        self.weight_fake_quant = self.qconfig.weight()

    def forward(self, input):
        return F.relu(
            self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias)
        )

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        return super().from_float(
            mod, use_precomputed_fake_quant=use_precomputed_fake_quant
        )


class ConvBn3d(_ConvBnNd, nn.Conv3d):
    r"""
    A ConvBn3d module is a module fused from Conv3d and BatchNorm3d,
    attached with FakeQuantize modules for weight,
    used in quantization aware training.

    We combined the interface of :class:`torch.nn.Conv3d` and
    :class:`torch.nn.BatchNorm3d`.

    Similar to :class:`torch.nn.Conv3d`, with FakeQuantize modules initialized
    to default.

    Attributes:
        freeze_bn:
        weight_fake_quant: fake quant module for weight

    """
    _FLOAT_MODULE = nni.ConvBn3d
    _FLOAT_CONV_MODULE = nn.Conv3d
    _FLOAT_BN_MODULE = nn.BatchNorm3d
    _FLOAT_RELU_MODULE: None = None

    def __init__(
        self,
        # ConvNd args
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=None,
        padding_mode="zeros",
        # BatchNorm3d args
        # num_features: out_channels
        eps=1e-05,
        momentum=0.1,
        # affine: True
        # track_running_stats: True
        # Args for this module
        freeze_bn=False,
        qconfig=None,
    ):
        kernel_size = _triple(kernel_size)
        stride = _triple(stride)
        padding = _triple(padding)
        dilation = _triple(dilation)
        _ConvBnNd.__init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            False,
            _triple(0),
            groups,
            bias,
            padding_mode,
            eps,
            momentum,
            freeze_bn,
            qconfig,
            dim=3,
        )


class ConvBnReLU3d(ConvBn3d):
    r"""
    A ConvBnReLU3d module is a module fused from Conv3d, BatchNorm3d and ReLU,
    attached with FakeQuantize modules for weight,
    used in quantization aware training.

    We combined the interface of :class:`torch.nn.Conv3d` and
    :class:`torch.nn.BatchNorm3d` and :class:`torch.nn.ReLU`.

    Similar to :class:`torch.nn.Conv3d`, with FakeQuantize modules initialized to
    default.

    Attributes:
        weight_fake_quant: fake quant module for weight

    """
    _FLOAT_MODULE = nni.ConvBnReLU3d  # type: ignore[assignment]
    _FLOAT_CONV_MODULE = nn.Conv3d
    _FLOAT_BN_MODULE = nn.BatchNorm3d
    _FLOAT_RELU_MODULE = nn.ReLU  # type: ignore[assignment]
    # module class after fusing bn into conv
    _FUSED_FLOAT_MODULE = nni.ConvReLU3d

    def __init__(
        self,
        # Conv3d args
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=None,
        padding_mode="zeros",
        # BatchNorm3d args
        # num_features: out_channels
        eps=1e-05,
        momentum=0.1,
        # affine: True
        # track_running_stats: True
        # Args for this module
        freeze_bn=False,
        qconfig=None,
    ):
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
            padding_mode,
            eps,
            momentum,
            freeze_bn,
            qconfig,
        )

    def forward(self, input):
        return F.relu(ConvBn3d._forward(self, input))

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        return super().from_float(
            mod, use_precomputed_fake_quant=use_precomputed_fake_quant
        )


class ConvReLU3d(nnqat.Conv3d, nni._FusedModule):
    r"""A ConvReLU3d module is a fused module of Conv3d and ReLU, attached with
    FakeQuantize modules for weight for
    quantization aware training.
    We combined the interface of :class:`~torch.nn.Conv3d` and
    :class:`~torch.nn.ReLU`.
    Attributes:
        weight_fake_quant: fake quant module for weight

    """
    _FLOAT_MODULE = nni.ConvReLU3d  # type: ignore[assignment]
    _FLOAT_CONV_MODULE = nn.Conv3d
    _FLOAT_BN_MODULE: None = None
    _FLOAT_RELU_MODULE = nn.ReLU

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
        padding_mode="zeros",
        qconfig=None,
    ):
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
            qconfig=qconfig,
        )
        assert qconfig, "qconfig must be provided for QAT module"
        self.qconfig = qconfig
        self.weight_fake_quant = self.qconfig.weight()

    def forward(self, input):
        return F.relu(
            self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias)
        )

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        return super().from_float(
            mod, use_precomputed_fake_quant=use_precomputed_fake_quant
        )


def update_bn_stats(mod):
    if type(mod) in {
        ConvBnReLU1d,
        ConvBnReLU2d,
        ConvBnReLU3d,
        ConvBn1d,
        ConvBn2d,
        ConvBn3d,
    }:
        mod.update_bn_stats()


def freeze_bn_stats(mod):
    if type(mod) in {
        ConvBnReLU1d,
        ConvBnReLU2d,
        ConvBnReLU3d,
        ConvBn1d,
        ConvBn2d,
        ConvBn3d,
    }:
        mod.freeze_bn_stats()
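A minimal usage sketch for the fused QAT modules above (illustrative only, not part of the uploaded file; assumes the stock get_default_qat_qconfig API and that freeze_bn_stats is re-exported from torch.ao.nn.intrinsic.qat, as its modules package does):

# Illustrative sketch: build the fused QAT module directly and exercise the
# BN-stats helpers defined in conv_fused.py via Module.apply.
import torch
import torch.ao.nn.intrinsic.qat as nniqat
from torch.ao.quantization import get_default_qat_qconfig

qconfig = get_default_qat_qconfig("fbgemm")
m = nniqat.ConvBnReLU2d(3, 16, 3, qconfig=qconfig)

x = torch.randn(8, 3, 32, 32)
y = m(x)  # fake-quantized conv -> batchnorm -> relu
assert y.shape == (8, 16, 30, 30)

# Late in QAT, stop updating BN running statistics while weights keep training:
m.apply(nniqat.freeze_bn_stats)
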
.venv/Lib/site-packages/torch/ao/nn/intrinsic/qat/modules/linear_fused.py
ADDED
@@ -0,0 +1,193 @@
# mypy: allow-untyped-defs
import torch
import torch.ao.nn.intrinsic as nni
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
from torch.nn.parameter import Parameter
from torch.nn.utils.fusion import fuse_linear_bn_weights


__all__ = [
    "LinearBn1d",
]


class LinearBn1d(nn.modules.linear.Linear, nni._FusedModule):
    r"""
    A LinearBn1d module is a module fused from Linear and BatchNorm1d, attached
    with FakeQuantize modules for weight, used in quantization aware training.

    We combined the interface of :class:`torch.nn.Linear` and
    :class:`torch.nn.BatchNorm1d`.

    Similar to :class:`torch.nn.Linear`, with FakeQuantize modules initialized
    to default.

    Attributes:
        freeze_bn:
        weight_fake_quant: fake quant module for weight

    """

    def __init__(
        self,
        # Linear args
        in_features,
        out_features,
        bias=True,
        # BatchNorm1d args
        # num_features: out_features
        eps=1e-05,
        momentum=0.1,
        # affine: True
        # track_running_stats: True
        # Args for this module
        freeze_bn=False,
        qconfig=None,
    ):
        nn.modules.linear.Linear.__init__(self, in_features, out_features, bias)
        assert qconfig, "qconfig must be provided for QAT module"
        self.qconfig = qconfig
        self.freeze_bn = freeze_bn if self.training else True
        self.bn = nn.BatchNorm1d(out_features, eps, momentum, True, True)
        self.weight_fake_quant = self.qconfig.weight()
        if bias:
            self.bias = Parameter(torch.empty(out_features))
        else:
            self.register_parameter("bias", None)
        self.reset_bn_parameters()

        # this needs to be called after reset_bn_parameters,
        # as they modify the same state
        if self.training:
            if freeze_bn:
                self.freeze_bn_stats()
            else:
                self.update_bn_stats()
        else:
            self.freeze_bn_stats()

    def reset_running_stats(self):
        self.bn.reset_running_stats()

    def reset_bn_parameters(self):
        self.bn.reset_running_stats()
        init.uniform_(self.bn.weight)
        init.zeros_(self.bn.bias)

    def reset_parameters(self):
        super().reset_parameters()

    def update_bn_stats(self):
        self.freeze_bn = False
        self.bn.training = True
        return self

    def freeze_bn_stats(self):
        self.freeze_bn = True
        self.bn.training = False
        return self

    def forward(self, input):
        assert self.bn.running_var is not None

        # Scale the linear weights by BN's running statistics to reduce
        # weight jitter, see https://arxiv.org/pdf/1806.08342.pdf, page 18
        # for motivation.
        #
        # Instead of
        #
        #   x1 = F.linear(x0, fq(w), b)
        #   x2 = self.bn(x1)
        #
        # We have
        #
        #   # scale the weight by previous batch's running statistics
        #   scale_factor = bn.w / bn.running_std_from_prev_batch
        #   # do the linear transformation without bias
        #   x1_scaled = F.linear(x0, fq(w * scale_factor), 0)
        #   # reverse the scaling and add original bias
        #   x1_orig = x1_scaled / scale_factor + b
        #   x2 = self.bn(x1_orig)

        running_std = torch.sqrt(self.bn.running_var + self.bn.eps)
        scale_factor = self.bn.weight / running_std
        weight_shape = [1] * len(self.weight.shape)
        weight_shape[0] = -1
        bias_shape = [1] * len(self.weight.shape)
        bias_shape[1] = -1
        scaled_weight = self.weight_fake_quant(
            self.weight * scale_factor.reshape(weight_shape)
        )
        if self.bias is not None:
            zero_bias = torch.zeros_like(self.bias)
        else:
            zero_bias = torch.zeros(self.out_features, device=scaled_weight.device)
        linear_out = F.linear(input, scaled_weight, zero_bias)
        linear_out_orig = linear_out / scale_factor.reshape(bias_shape)
        if self.bias is not None:
            linear_out_orig = linear_out_orig + self.bias.reshape(bias_shape)
        bn_out = self.bn(linear_out_orig)
        return bn_out

    def train(self, mode=True):
        """
        Batchnorm's training behavior is using the self.training flag. Prevent
        changing it if BN is frozen. This makes sure that calling `model.train()`
        on a model with a frozen BN will behave properly.
        """
        self.training = mode
        if not self.freeze_bn:
            for module in self.children():
                module.train(mode)
        return self

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        r"""Create a qat module from a float module or qparams_dict

        Args: `mod` a float module, either produced by torch.ao.quantization
        utilities or directly from user
        """
        assert type(mod) == nni.LinearBn1d, (
            "qat."
            + cls.__name__
            + ".from_float only works for "
            + nni.LinearBn1d.__name__
        )
        assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"
        assert mod.qconfig, "Input float module must have a valid config"
        qconfig = mod.qconfig
        linear, bn = mod[0], mod[1]
        qat_linearbn = cls(
            linear.in_features,
            linear.out_features,
            linear.bias is not None,
            bn.eps,
            bn.momentum,
            False,
            qconfig,
        )
        qat_linearbn.weight = linear.weight
        qat_linearbn.bias = linear.bias
        qat_linearbn.bn.weight = bn.weight
        qat_linearbn.bn.bias = bn.bias
        qat_linearbn.bn.running_mean = bn.running_mean
        qat_linearbn.bn.running_var = bn.running_var
        qat_linearbn.bn.num_batches_tracked = bn.num_batches_tracked
        return qat_linearbn

    def to_float(self):
        linear = torch.nn.Linear(self.in_features, self.out_features)
        assert self.bn.running_var is not None and self.bn.running_mean is not None
        linear.weight, linear.bias = fuse_linear_bn_weights(
            self.weight,
            self.bias,
            self.bn.running_mean,
            self.bn.running_var,
            self.bn.eps,
            self.bn.weight,
            self.bn.bias,
        )
        return linear
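LinearBn1d.forward relies on the fact that scaling each weight row by a factor and dividing the output by the same factor is a no-op once fake quantization is removed. A minimal sketch of that identity (illustrative, not part of the uploaded file; c stands in for bn.weight / running_std, and fake quant is omitted so the equality is exact):

# Illustrative sketch: the scale/unscale identity behind LinearBn1d.forward.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
w = torch.randn(4, 3)      # (out_features, in_features)
b = torch.randn(4)
x = torch.randn(2, 3)
c = torch.rand(4) + 0.5    # per-output-channel scale factor

# Scale weight rows by c, run the linear without bias, then undo the scaling
# on the output and add the original bias back.
out_scaled = F.linear(x, w * c.reshape(-1, 1), torch.zeros(4))
out = out_scaled / c.reshape(1, -1) + b.reshape(1, -1)
assert torch.allclose(out, F.linear(x, w, b), atol=1e-6)
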
.venv/Lib/site-packages/torch/ao/nn/intrinsic/qat/modules/linear_relu.py
ADDED
@@ -0,0 +1,51 @@
# mypy: allow-untyped-defs
import torch
import torch.ao.nn.intrinsic as nni
import torch.ao.nn.qat as nnqat
import torch.nn.functional as F


class LinearReLU(nnqat.Linear, nni._FusedModule):
    r"""
    A LinearReLU module fused from Linear and ReLU modules, attached with
    FakeQuantize modules for weight, used in
    quantization aware training.

    We adopt the same interface as :class:`torch.nn.Linear`.

    Similar to `torch.ao.nn.intrinsic.LinearReLU`, with FakeQuantize modules initialized to
    default.

    Attributes:
        weight: fake quant module for weight

    Examples::

        >>> # xdoctest: +SKIP
        >>> m = nn.qat.LinearReLU(20, 30)
        >>> input = torch.randn(128, 20)
        >>> output = m(input)
        >>> print(output.size())
        torch.Size([128, 30])
    """
    _FLOAT_MODULE = nni.LinearReLU  # type: ignore[assignment]

    def __init__(self, in_features, out_features, bias=True, qconfig=None):
        super().__init__(in_features, out_features, bias, qconfig)

    def forward(self, input):
        return F.relu(F.linear(input, self.weight_fake_quant(self.weight), self.bias))

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        return super().from_float(mod, use_precomputed_fake_quant)

    def to_float(self):
        linear = torch.nn.Linear(
            self.in_features, self.out_features, self.bias is not None
        )
        linear.weight = torch.nn.Parameter(self.weight.detach())
        if self.bias is not None:
            linear.bias = torch.nn.Parameter(self.bias.detach())
        relu = torch.nn.ReLU()
        return torch.ao.nn.intrinsic.LinearReLU(linear, relu)
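A quick sketch of the to_float() round trip above (illustrative, not part of the uploaded file; the qconfig name is the stock QAT default and assumed available):

# Illustrative sketch: convert the QAT module back to a float fused module.
import torch
import torch.ao.nn.intrinsic.qat as nniqat
from torch.ao.quantization import get_default_qat_qconfig

m = nniqat.LinearReLU(20, 30, qconfig=get_default_qat_qconfig("fbgemm"))
f = m.to_float()  # intrinsic float LinearReLU: nn.Linear followed by nn.ReLU
x = torch.randn(4, 20)
# f applies no fake quantization, so f(x) and m(x) differ only by quantization error.
assert f(x).shape == (4, 30)
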
.venv/Lib/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (235 Bytes)

.venv/Lib/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/modules/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (289 Bytes)
.venv/Lib/site-packages/torch/ao/nn/intrinsic/quantized/modules/__init__.py
ADDED
@@ -0,0 +1,18 @@
from .bn_relu import BNReLU2d, BNReLU3d
from .conv_add import ConvAdd2d, ConvAddReLU2d
from .conv_relu import ConvReLU1d, ConvReLU2d, ConvReLU3d
from .linear_relu import LinearLeakyReLU, LinearReLU, LinearTanh


__all__ = [
    "LinearReLU",
    "ConvReLU1d",
    "ConvReLU2d",
    "ConvReLU3d",
    "BNReLU2d",
    "BNReLU3d",
    "LinearLeakyReLU",
    "LinearTanh",
    "ConvAdd2d",
    "ConvAddReLU2d",
]
.venv/Lib/site-packages/torch/ao/nn/intrinsic/quantized/modules/__pycache__/bn_relu.cpython-39.pyc
ADDED
Binary file (3.43 kB)

.venv/Lib/site-packages/torch/ao/nn/intrinsic/quantized/modules/__pycache__/conv_add.cpython-39.pyc
ADDED
Binary file (3.79 kB)

.venv/Lib/site-packages/torch/ao/nn/intrinsic/quantized/modules/__pycache__/conv_relu.cpython-39.pyc
ADDED
Binary file (6.51 kB)
.venv/Lib/site-packages/torch/ao/nn/intrinsic/quantized/modules/bn_relu.py
ADDED
@@ -0,0 +1,105 @@
# mypy: allow-untyped-defs

import torch
import torch.ao.nn.intrinsic
import torch.ao.nn.intrinsic.qat
import torch.ao.nn.quantized as nnq


__all__ = ["BNReLU2d", "BNReLU3d"]


class BNReLU2d(nnq.BatchNorm2d):
    r"""
    A BNReLU2d module is a fused module of BatchNorm2d and ReLU

    We adopt the same interface as :class:`torch.ao.nn.quantized.BatchNorm2d`.

    Attributes:
        Same as torch.ao.nn.quantized.BatchNorm2d

    """
    _FLOAT_MODULE = torch.ao.nn.intrinsic.BNReLU2d

    def __init__(self, num_features, eps=1e-5, momentum=0.1, device=None, dtype=None):
        super().__init__(
            num_features, eps=eps, momentum=momentum, device=device, dtype=dtype
        )

    def forward(self, input):
        # Temporarily using len(shape) instead of ndim due to JIT issue
        # https://github.com/pytorch/pytorch/issues/23890
        if len(input.shape) != 4:
            raise ValueError("Input shape must be `(N, C, H, W)`!")
        return torch.ops.quantized.batch_norm2d_relu(
            input,
            self.weight,
            self.bias,
            self.running_mean,
            self.running_var,
            self.eps,
            self.scale,
            self.zero_point,
        )

    def _get_name(self):
        return "QuantizedBNReLU2d"

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        # TODO: Add qat support for BNReLU2d
        return super().from_float(
            mod, use_precomputed_fake_quant=use_precomputed_fake_quant
        )

    @classmethod
    def from_reference(cls, bn_relu, output_scale, output_zero_point):
        return super().from_reference(bn_relu[0], output_scale, output_zero_point)


class BNReLU3d(nnq.BatchNorm3d):
    r"""
    A BNReLU3d module is a fused module of BatchNorm3d and ReLU

    We adopt the same interface as :class:`torch.ao.nn.quantized.BatchNorm3d`.

    Attributes:
        Same as torch.ao.nn.quantized.BatchNorm3d

    """
    _FLOAT_MODULE = torch.ao.nn.intrinsic.BNReLU3d

    def __init__(self, num_features, eps=1e-5, momentum=0.1, device=None, dtype=None):
        super().__init__(
            num_features, eps=eps, momentum=momentum, device=device, dtype=dtype
        )

    def forward(self, input):
        # Temporarily using len(shape) instead of ndim due to JIT issue
        # https://github.com/pytorch/pytorch/issues/23890
        if len(input.shape) != 5:
            raise ValueError("Input shape must be `(N, C, D, H, W)`!")
        return torch.ops.quantized.batch_norm3d_relu(
            input,
            self.weight,
            self.bias,
            self.running_mean,
            self.running_var,
            self.eps,
            self.scale,
            self.zero_point,
        )

    def _get_name(self):
        return "QuantizedBNReLU3d"

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        # TODO: Add qat support for BNReLU3d
        return super().from_float(
            mod, use_precomputed_fake_quant=use_precomputed_fake_quant
        )

    @classmethod
    def from_reference(cls, bn_relu, output_scale, output_zero_point):
        return super().from_reference(bn_relu[0], output_scale, output_zero_point)
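A usage sketch for the fused module above (illustrative, not part of the uploaded file; scale and zero_point keep their constructor defaults here rather than coming from a calibrated conversion, so the output quantization is coarse):

# Illustrative sketch: run the fused quantized BN+ReLU on a quantized tensor.
import torch
from torch.ao.nn.intrinsic.quantized import BNReLU2d

m = BNReLU2d(4)  # output scale/zero_point stay at their defaults until set
xq = torch.quantize_per_tensor(
    torch.randn(2, 4, 8, 8), scale=0.1, zero_point=128, dtype=torch.quint8
)
yq = m(xq)
# The fused ReLU guarantees no dequantized output value is negative.
assert yq.dequantize().min() >= 0
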
.venv/Lib/site-packages/torch/ao/nn/intrinsic/quantized/modules/conv_add.py
ADDED
@@ -0,0 +1,145 @@
# mypy: allow-untyped-defs
import torch
import torch.ao.nn.intrinsic
import torch.ao.nn.intrinsic.qat
import torch.ao.nn.quantized as nnq
import torch.nn.functional as F


_reverse_repeat_padding = nnq.modules.conv._reverse_repeat_padding


class ConvAdd2d(nnq.Conv2d):
    r"""
    A ConvAdd2d module is a fused module of Conv2d and Add

    We adopt the same interface as :class:`torch.ao.nn.quantized.Conv2d`.

    Attributes:
        Same as torch.ao.nn.quantized.Conv2d

    """
    _FLOAT_MODULE = torch.ao.nn.intrinsic.ConvAdd2d  # type: ignore[assignment]

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
        padding_mode="zeros",
        device=None,
        dtype=None,
    ):
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
            device=device,
            dtype=dtype,
        )

    def forward(self, input, extra_input):
        # Temporarily using len(shape) instead of ndim due to JIT issue
        # https://github.com/pytorch/pytorch/issues/23890
        if len(input.shape) != 4:
            raise ValueError("Input shape must be `(N, C, H, W)`!")
        if self.padding_mode != "zeros":
            _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding)
            input = F.pad(
                input, _reversed_padding_repeated_twice, mode=self.padding_mode
            )
        return torch.ops.quantized.conv2d_add(
            input, extra_input, self._packed_params, self.scale, self.zero_point
        )

    def _get_name(self):
        return "QuantizedConvAdd2d"

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        return super().from_float(
            mod, use_precomputed_fake_quant=use_precomputed_fake_quant
        )

    @classmethod
    def from_reference(cls, ref_qconv, output_scale, output_zero_point):
        return super().from_reference(ref_qconv[0], output_scale, output_zero_point)


class ConvAddReLU2d(nnq.Conv2d):
    r"""
    A ConvAddReLU2d module is a fused module of Conv2d, Add and ReLU

    We adopt the same interface as :class:`torch.ao.nn.quantized.Conv2d`.

    Attributes:
        Same as torch.ao.nn.quantized.Conv2d

    """
    _FLOAT_MODULE = torch.ao.nn.intrinsic.ConvAddReLU2d  # type: ignore[assignment]

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
        padding_mode="zeros",
        device=None,
        dtype=None,
    ):
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
            device=device,
            dtype=dtype,
        )

    def forward(self, input, extra_input):
        # Temporarily using len(shape) instead of ndim due to JIT issue
        # https://github.com/pytorch/pytorch/issues/23890
        if len(input.shape) != 4:
            raise ValueError("Input shape must be `(N, C, H, W)`!")
        if self.padding_mode != "zeros":
            _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding)
            input = F.pad(
                input, _reversed_padding_repeated_twice, mode=self.padding_mode
            )
        return torch.ops.quantized.conv2d_add_relu(
            input, extra_input, self._packed_params, self.scale, self.zero_point
        )

    def _get_name(self):
        return "QuantizedConvAddReLU2d"

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        return super().from_float(
            mod, use_precomputed_fake_quant=use_precomputed_fake_quant
        )

    @classmethod
    def from_reference(cls, ref_qconv, output_scale, output_zero_point):
        return super().from_reference(ref_qconv[0], output_scale, output_zero_point)
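For reference, the residual pattern these two classes fuse, written out in float (illustrative sketch, not part of the uploaded file):

# Illustrative sketch: the float pattern that ConvAddReLU2d collapses into a
# single int8 kernel (quantized.conv2d_add_relu); ConvAdd2d is the same
# pattern without the relu.
import torch
import torch.nn as nn
import torch.nn.functional as F

conv = nn.Conv2d(4, 4, 3, padding=1)
x = torch.randn(1, 4, 8, 8)
skip = torch.randn(1, 4, 8, 8)
out = F.relu(conv(x) + skip)
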
.venv/Lib/site-packages/torch/ao/nn/intrinsic/quantized/modules/conv_relu.py
ADDED
@@ -0,0 +1,263 @@
# mypy: allow-untyped-defs

import torch
import torch.ao.nn.intrinsic
import torch.ao.nn.intrinsic.qat
import torch.ao.nn.quantized as nnq
import torch.nn.functional as F
from torch.nn.utils import fuse_conv_bn_weights


__all__ = [
    "ConvReLU1d",
    "ConvReLU2d",
    "ConvReLU3d",
]

_reverse_repeat_padding = nnq.modules.conv._reverse_repeat_padding


# TODO: factor out the common parts to ConvNd
class ConvReLU1d(nnq.Conv1d):
    r"""
    A ConvReLU1d module is a fused module of Conv1d and ReLU

    We adopt the same interface as :class:`torch.ao.nn.quantized.Conv1d`.

    Attributes:
        Same as torch.ao.nn.quantized.Conv1d

    """
    _FLOAT_MODULE = torch.ao.nn.intrinsic.ConvReLU1d  # type: ignore[assignment]

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
        padding_mode="zeros",
        device=None,
        dtype=None,
    ):
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
            device=device,
            dtype=dtype,
        )

    def forward(self, input):
        # Temporarily using len(shape) instead of ndim due to JIT issue
        # https://github.com/pytorch/pytorch/issues/23890
        if len(input.shape) != 3:
            raise ValueError("Input shape must be `(N, C, L)`!")
        if self.padding_mode != "zeros":
            # Padding in Conv1d is stored as (p, p), need to get (p,)
            _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding[:1])
            input = F.pad(
                input, _reversed_padding_repeated_twice, mode=self.padding_mode
            )
        return torch.ops.quantized.conv1d_relu(
            input, self._packed_params, self.scale, self.zero_point
        )

    def _get_name(self):
        return "QuantizedConvReLU1d"

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        if type(mod) == torch.ao.nn.intrinsic.qat.ConvBnReLU1d:
            assert mod.bn.running_var is not None and mod.bn.running_mean is not None
            mod.weight, mod.bias = fuse_conv_bn_weights(
                mod.weight,
                mod.bias,
                mod.bn.running_mean,
                mod.bn.running_var,
                mod.bn.eps,
                mod.bn.weight,
                mod.bn.bias,
            )
        return super().from_float(mod, use_precomputed_fake_quant)

    @classmethod
    def from_reference(cls, ref_qconv, output_scale, output_zero_point):
        assert (
            type(ref_qconv) != torch.ao.nn.intrinsic.ConvBnReLU1d
        ), "BatchNorm1d should be fused into Conv1d before converting to reference module"
        return super().from_reference(ref_qconv[0], output_scale, output_zero_point)


class ConvReLU2d(nnq.Conv2d):
    r"""
    A ConvReLU2d module is a fused module of Conv2d and ReLU

    We adopt the same interface as :class:`torch.ao.nn.quantized.Conv2d`.

    Attributes:
        Same as torch.ao.nn.quantized.Conv2d

    """
    _FLOAT_MODULE = torch.ao.nn.intrinsic.ConvReLU2d  # type: ignore[assignment]

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
        padding_mode="zeros",
        device=None,
        dtype=None,
    ):
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
            device=device,
            dtype=dtype,
        )

    def forward(self, input):
        # Temporarily using len(shape) instead of ndim due to JIT issue
        # https://github.com/pytorch/pytorch/issues/23890
        if len(input.shape) != 4:
            raise ValueError("Input shape must be `(N, C, H, W)`!")
        if self.padding_mode != "zeros":
            _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding)
            input = F.pad(
                input, _reversed_padding_repeated_twice, mode=self.padding_mode
            )
        return torch.ops.quantized.conv2d_relu(
            input, self._packed_params, self.scale, self.zero_point
        )

    def _get_name(self):
        return "QuantizedConvReLU2d"

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        if type(mod) == torch.ao.nn.intrinsic.qat.ConvBnReLU2d:
            assert mod.bn.running_var is not None and mod.bn.running_mean is not None
            mod.weight, mod.bias = fuse_conv_bn_weights(
                mod.weight,
                mod.bias,
                mod.bn.running_mean,
                mod.bn.running_var,
                mod.bn.eps,
                mod.bn.weight,
                mod.bn.bias,
            )
        return super().from_float(
            mod, use_precomputed_fake_quant=use_precomputed_fake_quant
        )

    @classmethod
    def from_reference(cls, ref_qconv, output_scale, output_zero_point):
        assert (
            type(ref_qconv) != torch.ao.nn.intrinsic.ConvBnReLU2d
        ), "BatchNorm2d should be fused into Conv2d before converting to reference module"
        return super().from_reference(ref_qconv[0], output_scale, output_zero_point)


class ConvReLU3d(nnq.Conv3d):
    r"""
    A ConvReLU3d module is a fused module of Conv3d and ReLU

    We adopt the same interface as :class:`torch.ao.nn.quantized.Conv3d`.

    Attributes:
        Same as torch.ao.nn.quantized.Conv3d

    """
    _FLOAT_MODULE = torch.ao.nn.intrinsic.ConvReLU3d  # type: ignore[assignment]

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
        padding_mode="zeros",
        device=None,
        dtype=None,
    ):
        assert padding_mode != "reflect", "Conv3d does not support reflection padding"
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
            device=device,
            dtype=dtype,
        )

    def forward(self, input):
        # Temporarily using len(shape) instead of ndim due to JIT issue
        # https://github.com/pytorch/pytorch/issues/23890
        if len(input.shape) != 5:
            raise ValueError("Input shape must be `(N, C, D, H, W)`!")
        if self.padding_mode != "zeros":
            _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding)
            input = F.pad(
                input, _reversed_padding_repeated_twice, mode=self.padding_mode
            )
        return torch.ops.quantized.conv3d_relu(
            input, self._packed_params, self.scale, self.zero_point
        )

    def _get_name(self):
        return "QuantizedConvReLU3d"

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        if type(mod) == torch.ao.nn.intrinsic.qat.ConvBnReLU3d:
            assert mod.bn.running_var is not None and mod.bn.running_mean is not None
            mod.weight, mod.bias = fuse_conv_bn_weights(
                mod.weight,
                mod.bias,
                mod.bn.running_mean,
                mod.bn.running_var,
                mod.bn.eps,
                mod.bn.weight,
                mod.bn.bias,
            )
        return super().from_float(
            mod, use_precomputed_fake_quant=use_precomputed_fake_quant
        )

    @classmethod
    def from_reference(cls, ref_qconv, output_scale, output_zero_point):
        assert (
            type(ref_qconv) != torch.ao.nn.intrinsic.ConvBnReLU3d
        ), "BatchNorm3d should be fused into Conv3d before converting to reference module"
        return super().from_reference(ref_qconv[0], output_scale, output_zero_point)
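The from_float paths above fold BatchNorm into the convolution weights with fuse_conv_bn_weights. A float-side sketch of that folding and the equivalence it preserves (illustrative, not part of the uploaded file; tolerance chosen loosely):

# Illustrative sketch: BN folding leaves the function computed unchanged.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import fuse_conv_bn_weights

conv = nn.Conv2d(3, 8, 3)
bn = nn.BatchNorm2d(8).eval()
bn.running_mean.uniform_(-1, 1)
bn.running_var.uniform_(0.5, 2.0)

w, b = fuse_conv_bn_weights(
    conv.weight, conv.bias, bn.running_mean, bn.running_var, bn.eps, bn.weight, bn.bias
)
x = torch.randn(1, 3, 16, 16)
assert torch.allclose(F.conv2d(x, w, b), bn(conv(x)), atol=1e-4)
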
.venv/Lib/site-packages/torch/ao/nn/intrinsic/quantized/modules/linear_relu.py
ADDED
@@ -0,0 +1,187 @@
# mypy: allow-untyped-defs
import torch
import torch.ao.nn.intrinsic as nni
import torch.ao.nn.quantized as nnq
from torch.ao.nn.quantized.modules.utils import _quantize_weight


__all__ = [
    "LinearReLU",
    "LinearLeakyReLU",
    "LinearTanh",
]


class LinearReLU(nnq.Linear):
    r"""
    A LinearReLU module fused from Linear and ReLU modules

    We adopt the same interface as :class:`torch.ao.nn.quantized.Linear`.

    Attributes:
        Same as torch.ao.nn.quantized.Linear

    Examples::

        >>> # xdoctest: +SKIP
        >>> m = nn.intrinsic.LinearReLU(20, 30)
        >>> input = torch.randn(128, 20)
        >>> output = m(input)
        >>> print(output.size())
        torch.Size([128, 30])
    """
    _FLOAT_MODULE = nni.LinearReLU  # type: ignore[assignment]

    def __init__(self, in_features, out_features, bias=True, dtype=torch.qint8):
        super().__init__(in_features, out_features, bias, dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.ops.quantized.linear_relu(
            x, self._packed_params._packed_params, self.scale, self.zero_point
        )

    def _get_name(self):
        return "QuantizedLinearReLU"

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        return super().from_float(mod, use_precomputed_fake_quant)

    @classmethod
    def from_reference(cls, ref_linear_relu, output_scale, output_zero_point):
        return super().from_reference(
            ref_linear_relu[0], output_scale, output_zero_point
        )


class LinearLeakyReLU(nnq.Linear):
    r"""
    For onednn backend only
    A LinearLeakyReLU module fused from Linear and LeakyReLU modules
    We adopt the same interface as :class:`torch.ao.nn.quantized.Linear`.
    Attributes:
        Same as torch.ao.nn.quantized.Linear
        + negative_slope
    Examples::
        >>> # xdoctest: +SKIP
        >>> m = nn.intrinsic.LinearLeakyReLU(20, 30, 0.01)
        >>> input = torch.randn(128, 20)
        >>> output = m(input)
        >>> print(output.size())
        torch.Size([128, 30])
    """
    _FLOAT_MODULE = nni.LinearLeakyReLU  # type: ignore[assignment]

    def __init__(
        self, in_features, out_features, negative_slope, bias=True, dtype=torch.qint8
    ):
        super().__init__(in_features, out_features, bias, dtype)
        self.negative_slope = negative_slope

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.ops.quantized.linear_leaky_relu(
            x,
            self._packed_params._packed_params,
            self.scale,
            self.zero_point,
            self.negative_slope,
        )

    def _get_name(self):
        return "QuantizedLinearLeakyReLU"

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        assert (
            type(mod) == nni.LinearLeakyReLU
        ), "Input float module should be LinearLeakyReLU"
        assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"
        activation_post_process = mod.activation_post_process
        leaky_relu = mod[1]
        mod = mod[0]
        weight_post_process = mod.qconfig.weight()
        weight_post_process(mod.weight)
        dtype = weight_post_process.dtype
        act_scale, act_zp = activation_post_process.calculate_qparams()  # type: ignore[union-attr,operator]
        assert dtype == torch.qint8, "Weight observer must have dtype torch.qint8"
        qweight = _quantize_weight(mod.weight.float(), weight_post_process)
        qlinear_leaky_relu = cls(
            mod.in_features, mod.out_features, leaky_relu.negative_slope, dtype=dtype
        )
        qlinear_leaky_relu.set_weight_bias(qweight, mod.bias)
        qlinear_leaky_relu.scale = float(act_scale)
        qlinear_leaky_relu.zero_point = int(act_zp)
        return qlinear_leaky_relu

    @classmethod
    def from_reference(cls, ref_mod, output_scale, output_zero_point):
        linear = ref_mod[0]
        leaky_relu = ref_mod[1]
        qlinear_leaky_relu = cls(
            linear.in_features, linear.out_features, leaky_relu.negative_slope
        )
        qweight = linear.get_quantized_weight()
        qlinear_leaky_relu.set_weight_bias(qweight, linear.bias)
        qlinear_leaky_relu.scale = float(output_scale)
        qlinear_leaky_relu.zero_point = int(output_zero_point)
        return qlinear_leaky_relu


class LinearTanh(nnq.Linear):
    r"""
    A LinearTanh module fused from Linear and Tanh modules

    We adopt the same interface as :class:`torch.ao.nn.quantized.Linear`.

    Attributes:
        Same as torch.ao.nn.quantized.Linear

    Examples::

        >>> # xdoctest: +SKIP
        >>> m = nn.intrinsic.LinearTanh(20, 30)
        >>> input = torch.randn(128, 20)
        >>> output = m(input)
        >>> print(output.size())
        torch.Size([128, 30])
    """
    _FLOAT_MODULE = nni.LinearTanh  # type: ignore[assignment]

    def __init__(self, in_features, out_features, bias=True, dtype=torch.qint8):
        super().__init__(in_features, out_features, bias, dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.ops.quantized.linear_tanh(
            x, self._packed_params._packed_params, self.scale, self.zero_point
        )

    def _get_name(self):
        return "QuantizedLinearTanh"

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        assert type(mod) == nni.LinearTanh, "Input float module should be LinearTanh"
        assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"
        activation_post_process = mod.activation_post_process
        mod = mod[0]
        weight_post_process = mod.qconfig.weight()
        weight_post_process(mod.weight)
        dtype = weight_post_process.dtype
        act_scale, act_zp = activation_post_process.calculate_qparams()  # type: ignore[union-attr,operator]
        assert dtype == torch.qint8, "Weight observer must have dtype torch.qint8"
        qweight = _quantize_weight(mod.weight.float(), weight_post_process)
        qlinear_tanh = cls(mod.in_features, mod.out_features, dtype=dtype)
        qlinear_tanh.set_weight_bias(qweight, mod.bias)
        qlinear_tanh.scale = float(act_scale)
        qlinear_tanh.zero_point = int(act_zp)
        return qlinear_tanh

    @classmethod
    def from_reference(cls, ref_mod, output_scale, output_zero_point):
        linear = ref_mod[0]
        qlinear_tanh = cls(linear.in_features, linear.out_features)
        qweight = linear.get_quantized_weight()
        qlinear_tanh.set_weight_bias(qweight, linear.bias)
        qlinear_tanh.scale = float(output_scale)
        qlinear_tanh.zero_point = int(output_zero_point)
        return qlinear_tanh
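A sketch of the eager-mode flow that produces these quantized fused modules (illustrative, not part of the uploaded file; the module names "0" and "1" are simply nn.Sequential defaults):

# Illustrative sketch: fuse -> prepare (observe) -> convert, ending in the
# quantized LinearReLU defined above.
import torch
import torch.nn as nn
from torch.ao.quantization import convert, fuse_modules, get_default_qconfig, prepare

model = nn.Sequential(nn.Linear(20, 30), nn.ReLU()).eval()
model = fuse_modules(model, [["0", "1"]])   # creates an intrinsic LinearReLU
model.qconfig = get_default_qconfig("fbgemm")
prepared = prepare(model)
prepared(torch.randn(8, 20))                # calibration pass feeds the observers
quantized = convert(prepared)
print(type(quantized[0]).__name__)          # LinearReLU (the quantized fused class)
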
.venv/Lib/site-packages/torch/ao/nn/quantized/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (674 Bytes)

.venv/Lib/site-packages/torch/ao/nn/quantized/__pycache__/functional.cpython-39.pyc
ADDED
Binary file (26.7 kB)

.venv/Lib/site-packages/torch/ao/nn/quantized/reference/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (395 Bytes)

.venv/Lib/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (650 Bytes)

.venv/Lib/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/rnn.cpython-39.pyc
ADDED
Binary file (17.6 kB)

.venv/Lib/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/sparse.cpython-39.pyc
ADDED
Binary file (4.1 kB)

.venv/Lib/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/utils.cpython-39.pyc
ADDED
Binary file (7.02 kB)
.venv/Lib/site-packages/torch/ao/nn/sparse/__init__.py
ADDED
@@ -0,0 +1 @@
from . import quantized
.venv/Lib/site-packages/torch/ao/nn/sparse/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (224 Bytes)
.venv/Lib/site-packages/torch/ao/nn/sparse/quantized/__init__.py
ADDED
@@ -0,0 +1,10 @@
from torch.ao.nn.sparse.quantized import dynamic

from .linear import Linear, LinearPackedParams


__all__ = [
    "dynamic",
    "Linear",
    "LinearPackedParams",
]
.venv/Lib/site-packages/torch/ao/nn/sparse/quantized/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (367 Bytes)

.venv/Lib/site-packages/torch/ao/nn/sparse/quantized/__pycache__/linear.cpython-39.pyc
ADDED
Binary file (7.6 kB)

.venv/Lib/site-packages/torch/ao/nn/sparse/quantized/__pycache__/utils.cpython-39.pyc
ADDED
Binary file (1.56 kB)
.venv/Lib/site-packages/torch/ao/nn/sparse/quantized/dynamic/__init__.py
ADDED
@@ -0,0 +1,6 @@
from .linear import Linear


__all__ = [
    "Linear",
]
.venv/Lib/site-packages/torch/ao/nn/sparse/quantized/dynamic/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (269 Bytes)

.venv/Lib/site-packages/torch/ao/nn/sparse/quantized/dynamic/__pycache__/linear.cpython-39.pyc
ADDED
Binary file (5.18 kB)
.venv/Lib/site-packages/torch/ao/nn/sparse/quantized/dynamic/linear.py
ADDED
@@ -0,0 +1,188 @@
# mypy: allow-untyped-defs
from typing import Optional

import torch
import torch.ao.nn.intrinsic as nni
from torch.ao.nn.quantized.modules.utils import (
    _hide_packed_params_repr,
    _quantize_weight,
)
from torch.ao.nn.sparse.quantized import linear
from torch.ao.nn.sparse.quantized.utils import LinearBlockSparsePattern


__all__ = ["Linear"]


class Linear(torch.nn.Module):
    r"""
    A dynamically quantized sparse linear module with float tensor as inputs and outputs.
    """
    _version = 1
    _op_type = "sparse_dynamic"
    _FLOAT_MODULE = torch.nn.Linear

    def __init__(
        self,
        in_features,
        out_features,
        row_block_size,
        col_block_size,
        bias=True,
        dtype=torch.qint8,
    ):
        super().__init__()

        if dtype != torch.qint8:
            raise NotImplementedError(
                "Only QINT8 is supported for Sparse Quantized Linear Dynamic"
            )

        self.in_features = in_features
        self.out_features = out_features

        if bias:
            bias = torch.zeros(self.out_features, dtype=torch.float)
        else:
            bias = None

        qweight = torch._empty_affine_quantized(
            [out_features, in_features], scale=1, zero_point=0, dtype=torch.qint8
        )
        self._packed_params = linear.LinearPackedParams(
            row_block_size=row_block_size, col_block_size=col_block_size, dtype=dtype
        )
        self._packed_params.set_weight_bias(
            qweight, bias, row_block_size, col_block_size
        )

    def _get_name(self):
        return "SparseQuantizedDynamicLinear"

    def extra_repr(self):
        return f"in_features={self.in_features}, out_features={self.out_features}, qscheme={self.weight().qscheme()}"

    def __repr__(self):
        return _hide_packed_params_repr(self, linear.LinearPackedParams)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.ops.sparse.qlinear_dynamic(x, self._packed_params._packed_params)

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        super()._save_to_state_dict(destination, prefix, keep_vars)
        destination[prefix + "op_type"] = self._op_type

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        op_type = int(state_dict[prefix + "op_type"])
        assert (
            op_type == "sparse"
        ), f"Cannot load from op_type [{op_type}], expecting [{self._op_type}]"
        state_dict.pop(prefix + "op_type")

        version = local_metadata.get("version", None)
        assert version <= self._version

        # Is this code valid? In old quantization it seemed to be used to load
        # older model
        weight = state_dict.pop(prefix + "weight")
        bias = state_dict.pop(prefix + "bias")
        state_dict.update(
            {
                prefix + "_packed_params.weight": weight,
                prefix + "_packed_params.bias": bias,
            }
        )

        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            False,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

    def _weight_bias(self):
        return self._packed_params._weight_bias()

    def weight(self):
        return self._weight_bias()[0]

    def bias(self):
        return self._weight_bias()[1]

    def set_weight_bias(
        self,
        w: torch.Tensor,
        b: Optional[torch.Tensor],
        row_block_size: Optional[int],
        col_block_size: Optional[int],
    ) -> None:
        assert row_block_size is not None and col_block_size is not None
        self.out_features = w.shape[0]
        self.in_features = w.shape[1]
        self._packed_params.set_weight_bias(w, b, row_block_size, col_block_size)

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        r"""Create a quantized sparse dynamic module from a float module.

        We only care about the convert at this stage, no need for observers just yet.
        """
        assert type(mod) == cls._FLOAT_MODULE, (
            " nnq."
            + cls.__name__
            + ".from_float only works for "
            + cls._FLOAT_MODULE.__name__
        )
        # TODO: Need to add options to qconfig to avoid the calibration.
        # TODO: Add calibration for the sparsity
        assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"
        if type(mod) == nni.LinearReLU:
            mod = mod[0]
        if mod.qconfig is not None and mod.qconfig.weight is not None:
            weight_observer = mod.qconfig.weight()
        else:
            # We have the circular import issues if we import the qconfig in the beginning of this file:
            # https://github.com/pytorch/pytorch/pull/24231. The current workaround is to postpone the
            # import until we need it.
            from torch.ao.quantization.qconfig import default_dynamic_qconfig

            weight_observer = default_dynamic_qconfig.weight()

        # It is important to multiply by the mask BEFORE calling the `weight_observer`
        # TODO (zaf): Mask might not be part of the qconfig (T83295194)
# TODO (zaf): Mask might not be part of the qconfig (T83295194)
|
| 165 |
+
weight = mod.weight
|
| 166 |
+
if getattr(mod.qconfig, "mask", False):
|
| 167 |
+
weight = mod.qconfig.mask * mod.weight
|
| 168 |
+
|
| 169 |
+
weight_observer(weight)
|
| 170 |
+
dtype = weight_observer.dtype
|
| 171 |
+
assert dtype == torch.qint8, "Weight observer must have dtype torch.qint8"
|
| 172 |
+
w_sc, w_zp = weight_observer.calculate_qparams()
|
| 173 |
+
if isinstance(w_zp, torch.Tensor):
|
| 174 |
+
assert not torch.any(w_zp.bool()), "All weight zero points must map to 0"
|
| 175 |
+
else:
|
| 176 |
+
assert w_zp == 0, "Weight zero point must map to 0"
|
| 177 |
+
qweight = _quantize_weight(weight.float(), weight_observer)
|
| 178 |
+
|
| 179 |
+
row_block_size, col_block_size = LinearBlockSparsePattern.block_size()
|
| 180 |
+
qlinear = cls(
|
| 181 |
+
mod.in_features,
|
| 182 |
+
mod.out_features,
|
| 183 |
+
row_block_size,
|
| 184 |
+
col_block_size,
|
| 185 |
+
dtype=dtype,
|
| 186 |
+
)
|
| 187 |
+
qlinear.set_weight_bias(qweight, mod.bias, row_block_size, col_block_size)
|
| 188 |
+
return qlinear
|
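A minimal usage sketch for the module above (not part of the upstream file). It assumes a PyTorch build where the block-sparse kernels behind `torch.ops.sparse` are registered, and `float_linear` is a made-up example module::

    import torch
    from torch.ao.nn.sparse.quantized.dynamic import Linear as SparseDynamicLinear
    from torch.ao.nn.sparse.quantized.utils import LinearBlockSparsePattern
    from torch.ao.quantization.qconfig import default_dynamic_qconfig

    float_linear = torch.nn.Linear(16, 8)  # hypothetical, already sparsified
    float_linear.qconfig = default_dynamic_qconfig

    # from_float() reads the block shape from the global pattern context.
    with LinearBlockSparsePattern(row_block_size=1, col_block_size=4):
        sparse_qlinear = SparseDynamicLinear.from_float(float_linear)

    y = sparse_qlinear(torch.randn(2, 16))  # float tensor in, float tensor out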
.venv/Lib/site-packages/torch/ao/nn/sparse/quantized/linear.py
ADDED
@@ -0,0 +1,273 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
from typing import Optional

import torch
from torch.ao.nn.quantized.modules.utils import (
    _hide_packed_params_repr,
    _quantize_weight,
)


__all__ = ["LinearPackedParams", "Linear"]


# TODO (zaf): Inherit from `quantized.LinearPackedParams` (T83294430)
class LinearPackedParams(torch.nn.Module):
    _version = 1

    def __init__(self, row_block_size=1, col_block_size=4, dtype=torch.qint8):
        super().__init__()

        if dtype != torch.qint8:
            raise NotImplementedError("Linear prepacking only supports QINT8")
        self.dtype = dtype
        wq = torch._empty_affine_quantized(
            [1, 1], scale=1.0, zero_point=0, dtype=torch.qint8
        )
        self.set_weight_bias(wq, None, row_block_size, col_block_size)

    def _get_name(self):
        return "SparseQuantizedLinearPackedParams"

    @torch.jit.export
    def set_weight_bias(
        self,
        weight: torch.Tensor,
        bias: Optional[torch.Tensor],
        row_block_size: Optional[int],
        col_block_size: Optional[int],
    ) -> None:
        assert row_block_size is not None and col_block_size is not None
        self._packed_params = torch.ops.sparse.qlinear_prepack(
            weight, bias, row_block_size, col_block_size
        )

    @torch.jit.export
    def _weight_bias(self):
        (weight, bias, block_sizes) = torch.ops.sparse.qlinear_unpack(
            self._packed_params
        )
        return (weight, bias, block_sizes[0], block_sizes[1])

    def forward(self, x):
        return x

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        super()._save_to_state_dict(destination, prefix, keep_vars)
        destination[prefix + "dtype"] = self.dtype
        destination[prefix + "_packed_params"] = self._weight_bias()

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        version = local_metadata.get("version", None)
        assert version <= self._version

        self.dtype = state_dict.pop(prefix + "dtype")
        weight, bias, row_block_size, col_block_size = state_dict.pop(
            prefix + "_packed_params"
        )
        self.set_weight_bias(weight, bias, row_block_size, col_block_size)

        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            False,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

    @torch.jit.export
    def __getstate__(self):
        return self._packed_params, self.training, self.dtype

    @torch.jit.export
    def __setstate__(self, state):
        (self._packed_params, self.training, self.dtype) = state

    def __repr__(self):
        return self._weight_bias().__repr__()


# TODO (zaf): Inherit from `quantized.Linear` (T83294430)
class Linear(torch.nn.Module):
    r"""
    A quantized sparse linear module with quantized tensor as inputs and outputs.
    """
    _version = 1
    _FLOAT_MODULE = torch.nn.Linear

    def __init__(
        self,
        in_features,
        out_features,
        row_block_size,
        col_block_size,
        bias=True,
        dtype=torch.qint8,
    ):
        super().__init__()

        if dtype != torch.qint8:
            raise NotImplementedError(
                "Only QINT8 is supported for Sparse Quantized Linear"
            )

        self.in_features = in_features
        self.out_features = out_features

        if bias:
            bias = torch.zeros(self.out_features, dtype=torch.float)
        else:
            bias = None

        qweight = torch._empty_affine_quantized(
            [out_features, in_features], scale=1, zero_point=0, dtype=torch.qint8
        )
        self._packed_params = LinearPackedParams(
            row_block_size=row_block_size, col_block_size=col_block_size, dtype=dtype
        )
        self._packed_params.set_weight_bias(
            qweight, bias, row_block_size, col_block_size
        )
        self.scale = 1.0
        self.zero_point = 0

    @classmethod
    def _get_name(cls):
        return "SparseQuantizedLinear"

    def extra_repr(self):
        return (
            f"in_features={self.in_features}, out_features={self.out_features}, scale={self.scale}, "
            f"zero_point={self.zero_point}, qscheme={self.weight().qscheme()}"
        )

    def __repr__(self):
        return _hide_packed_params_repr(self, LinearPackedParams)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.ops.sparse.qlinear(
            x, self._packed_params._packed_params, self.scale, self.zero_point
        )

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        super()._save_to_state_dict(destination, prefix, keep_vars)
        destination[prefix + "scale"] = torch.tensor(self.scale)
        destination[prefix + "zero_point"] = torch.tensor(self.zero_point)

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        self.scale = float(state_dict[prefix + "scale"])
        state_dict.pop(prefix + "scale")

        self.zero_point = int(state_dict[prefix + "zero_point"])
        state_dict.pop(prefix + "zero_point")

        op_type = int(state_dict[prefix + "op_type"])
        state_dict.pop(prefix + "op_type")

        version = local_metadata.get("version", None)
        assert version <= self._version

        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            False,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

    def _weight_bias(self):
        return self._packed_params._weight_bias()

    def weight(self):
        return self._weight_bias()[0]

    def bias(self):
        return self._weight_bias()[1]

    def set_weight_bias(
        self,
        w: torch.Tensor,
        b: Optional[torch.Tensor],
        row_block_size: Optional[int],
        col_block_size: Optional[int],
    ) -> None:
        assert row_block_size is not None and col_block_size is not None
        self._packed_params.set_weight_bias(w, b, row_block_size, col_block_size)

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        r"""Create a quantized sparse module from a float module.

        We only care about the convert at this stage, no need for observers just yet.

        TODO(zaf): Need to add the sparse params to the qconfig
        """
        assert type(mod) == cls._FLOAT_MODULE, (
            cls._get_name() + ".from_float only works for " + cls._FLOAT_MODULE.__name__
        )
        assert hasattr(mod, "sparse_params"), (
            "Expecting the Linear to have `sparse_params`. Make sure you have provided arguments "
            'in the `sparsifier.squash_mask(params_to_save=("sparse_block_shape",))` method.'
        )
        sparse_block_shape = mod.sparse_params.get("sparse_block_shape", None)  # type: ignore[operator, union-attr]
        assert isinstance(sparse_block_shape, (tuple, list))
        assert len(sparse_block_shape) == 2
        # TODO: Need to add options to qconfig to avoid the calibration.
        # TODO: Add calibration for the sparsity
        assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"
        activation_post_process = mod.activation_post_process
        weight_post_process = mod.qconfig.weight()  # type: ignore[operator, union-attr]

        # Assumption is that the weight is already sparsified by the
        # `sparsifier.convert`
        weight = mod.weight

        weight_post_process(weight)
        dtype = weight_post_process.dtype
        act_scale, act_zp = activation_post_process.calculate_qparams()  # type: ignore[operator, union-attr]
        assert dtype == torch.qint8, "Weight observer must have dtype torch.qint8"
        w_sc, w_zp = weight_post_process.calculate_qparams()
        if isinstance(w_zp, torch.Tensor):
            assert not torch.any(w_zp.bool()), "All weight zero points must map to 0"
        else:
            assert w_zp == 0, "Weight zero point must map to 0"
        qweight = _quantize_weight(weight.float(), weight_post_process)

        row_block_size = mod.sparse_params["sparse_block_shape"][0]  # type: ignore[index]
        col_block_size = mod.sparse_params["sparse_block_shape"][1]  # type: ignore[index]
        qlinear = cls(
            mod.in_features,
            mod.out_features,
            row_block_size,
            col_block_size,
            dtype=dtype,
        )
        qlinear.set_weight_bias(
            qweight, mod.bias, row_block_size, col_block_size
        )  # type: ignore[arg-type]
        qlinear.scale = float(act_scale)
        qlinear.zero_point = int(act_zp)
        return qlinear
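`LinearPackedParams` above is a thin wrapper over the `qlinear_prepack`/`qlinear_unpack` op pair: the packed handle carries the weight, bias, and both block sizes together. A sketch of the round trip, again assuming a build where the `torch.ops.sparse` kernels exist::

    import torch
    from torch.ao.nn.sparse.quantized import LinearPackedParams

    # Pack a (1, 4) block-sparse qint8 weight, then read it back.
    qweight = torch.quantize_per_tensor(
        torch.randn(8, 16), scale=0.1, zero_point=0, dtype=torch.qint8
    )
    params = LinearPackedParams(row_block_size=1, col_block_size=4)
    params.set_weight_bias(qweight, None, 1, 4)

    # _weight_bias() unpacks and also returns the stored block sizes.
    w, b, row_bs, col_bs = params._weight_bias()
    assert (row_bs, col_bs) == (1, 4)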
.venv/Lib/site-packages/torch/ao/nn/sparse/quantized/utils.py
ADDED
@@ -0,0 +1,56 @@
# mypy: allow-untyped-defs
import threading


__all__ = ["LinearBlockSparsePattern"]


def _is_valid_linear_block_sparse_pattern(row_block_size, col_block_size):
    return (row_block_size == 1 and col_block_size == 4) or (
        row_block_size == 8 and col_block_size == 1
    )


# This is a stop-gap measure as current flow does not allow module
# specific block sparse pattern.
# Infact there is no way to convey sparse pattern via module config
# of quantization flow. Thus using the global context to convey
# sparsity pattern.
# Once the flow supports it, this should be removed.
class LinearBlockSparsePattern:
    rlock = threading.RLock()
    row_block_size = 1
    col_block_size = 4
    prev_row_block_size = 1
    prev_col_block_size = 4

    def __init__(self, row_block_size=1, col_block_size=4):
        assert _is_valid_linear_block_sparse_pattern(row_block_size, col_block_size)
        LinearBlockSparsePattern.rlock.acquire()
        LinearBlockSparsePattern.prev_row_block_size = (
            LinearBlockSparsePattern.row_block_size
        )
        LinearBlockSparsePattern.prev_col_block_size = (
            LinearBlockSparsePattern.col_block_size
        )
        LinearBlockSparsePattern.row_block_size = row_block_size
        LinearBlockSparsePattern.col_block_size = col_block_size

    def __enter__(self):
        pass

    def __exit__(self, exc_type, exc_value, backtrace):
        LinearBlockSparsePattern.row_block_size = (
            LinearBlockSparsePattern.prev_row_block_size
        )
        LinearBlockSparsePattern.col_block_size = (
            LinearBlockSparsePattern.prev_col_block_size
        )
        LinearBlockSparsePattern.rlock.release()

    @staticmethod
    def block_size():
        return (
            LinearBlockSparsePattern.row_block_size,
            LinearBlockSparsePattern.col_block_size,
        )
.venv/Lib/site-packages/torch/ao/ns/__init__.py
ADDED
File without changes
.venv/Lib/site-packages/torch/ao/ns/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (183 Bytes).
.venv/Lib/site-packages/torch/ao/ns/_numeric_suite.py
ADDED
@@ -0,0 +1,563 @@
# mypy: allow-untyped-defs
from typing import Any, Callable, Dict, List, Optional, Set, Union

import torch
import torch.ao.nn.quantized as nnq
import torch.ao.nn.quantized.dynamic as nnqd
import torch.nn as nn
from torch.ao.quantization import prepare
from torch.ao.quantization.quantization_mappings import (
    get_default_compare_output_module_list,
)


NON_LEAF_MODULE_TO_ADD_OBSERVER_ALLOW_LIST = {
    nnqd.Linear,
    nnq.Linear,
    nnqd.LSTM,
    nn.LSTM,
}


def _find_match(
    str_list: Union[Dict[str, Any], List[str]],
    key_str: str,
    postfix: str,
) -> Optional[str]:
    split_str = key_str.split(".")
    if split_str[-1] == postfix:
        match_string = "".join(key_str.split(".")[0:-1])
        for s2 in str_list:
            pattern1 = "".join(s2.split(".")[0:-1])
            pattern2 = "".join(s2.split(".")[0:-2])
            if match_string == pattern1:
                return s2
            if match_string == pattern2:
                return s2

        # For matching "fc.weight" and "fc._packed_params._packed_params"
        if postfix == "_packed_params":
            match_string = "".join(key_str.split(".")[0:-2])
            if len(match_string) == 0:
                return None
            for s2 in str_list:
                pattern1 = "".join(s2.split(".")[0:-1])
                pattern2 = "".join(s2.split(".")[0:-2])
                if match_string == pattern1:
                    return s2
                if match_string == pattern2:
                    return s2
        return None
    else:
        return None


def compare_weights(
    float_dict: Dict[str, Any], quantized_dict: Dict[str, Any]
) -> Dict[str, Dict[str, torch.Tensor]]:
    r"""Compare the weights of the float module with its corresponding quantized
    module. Return a dict with key corresponding to module names and each entry being
    a dictionary with two keys 'float' and 'quantized', containing the float and
    quantized weights. This dict can be used to compare and compute the quantization
    error of the weights of float and quantized models.

    Example usage::

        wt_compare_dict = compare_weights(
            float_model.state_dict(), qmodel.state_dict())
        for key in wt_compare_dict:
            print(
                key,
                compute_error(
                    wt_compare_dict[key]['float'],
                    wt_compare_dict[key]['quantized'].dequantize()
                )
            )

    Args:
        float_dict: state dict of the float model
        quantized_dict: state dict of the quantized model

    Return:
        weight_dict: dict with key corresponding to module names and each entry being
        a dictionary with two keys 'float' and 'quantized', containing the float and
        quantized weights
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite.compare_weights")
    weight_dict: Dict[str, Dict] = {}
    for key in quantized_dict:
        match_key = _find_match(float_dict, key, "weight")
        if match_key is not None:
            weight_dict[key] = {}
            weight_dict[key]["float"] = float_dict[match_key]
            weight_dict[key]["quantized"] = quantized_dict[key]
            continue

        # For matching "fc.weight" and "fc._packed_params._packed_params"
        match_key = _find_match(float_dict, key, "_packed_params")
        if match_key is not None:
            weight_dict[key] = {}
            weight_dict[key]["float"] = float_dict[match_key]
            weight_dict[key]["quantized"] = quantized_dict[key][0]

        # For LSTM
        split_str = key.split(".")
        if split_str[-1] == "param" and split_str[-3] == "_all_weight_values":
            layer = split_str[-2]
            module_name = ".".join(split_str[:-3])
            float_weight_ih_key = module_name + ".weight_ih_l" + layer
            float_weight_hh_key = module_name + ".weight_hh_l" + layer
            if float_weight_ih_key in float_dict and float_weight_hh_key in float_dict:
                weight_dict[key] = {}
                weight_dict[key]["float"] = float_dict[float_weight_ih_key]
                weight_dict[key]["quantized"] = (
                    quantized_dict[key].__getstate__()[0][4][0].__getstate__()[0][0]
                )
                weight_dict[key]["float"] = float_dict[float_weight_hh_key]
                weight_dict[key]["quantized"] = (
                    quantized_dict[key].__getstate__()[0][4][1].__getstate__()[0][0]
                )

    return weight_dict


def _get_logger_dict_helper(
    mod: nn.Module,
    target_dict: Dict[str, Any],
    prefix: str = "",
) -> None:
    r"""This is the helper function for get_logger_dict

    Args:
        mod: module we want to save all logger stats
        prefix: prefix for the current module
        target_dict: the dictionary used to save all logger stats
    """

    def get_prefix(prefix):
        return prefix if prefix == "" else prefix + "."

    for name, child in mod.named_children():
        if isinstance(child, Logger):
            target_dict[get_prefix(prefix) + "stats"] = child.stats
            break

    for name, child in mod.named_children():
        module_prefix = get_prefix(prefix) + name if prefix else name
        _get_logger_dict_helper(child, target_dict, module_prefix)


def get_logger_dict(mod: nn.Module, prefix: str = "") -> Dict[str, Dict]:
    r"""Traverse the modules and save all logger stats into target dict.
    This is mainly used for quantization accuracy debug.

    Type of loggers supported:
        ShadowLogger: used to log the outputs of the quantized module and its matching float shadow module,
        OutputLogger: used to log the outputs of the modules

    Args:
        mod: module we want to save all logger stats
        prefix: prefix for the current module

    Return:
        target_dict: the dictionary used to save all logger stats

    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite.get_logger_dict")

    target_dict: Dict[str, Dict] = {}
    _get_logger_dict_helper(mod, target_dict, prefix)
    return target_dict


class Logger(nn.Module):
    r"""Base class for stats logging"""

    def __init__(self):
        super().__init__()
        self.stats = {}
        # We only insert observer if the op is quantized with static quantization,
        # which is identified by activation_observer.dtype == quint8. This is needed
        # when attaching Logger as observer for FX mode
        self.dtype = torch.quint8

    def forward(self, x):
        # fmt: off
        """
        """  # blank docblock to make autodoc happy
        # fmt: on


class ShadowLogger(Logger):
    r"""Class used in Shadow module to record the outputs of the original and
    shadow modules.
    """

    def __init__(self):
        super().__init__()
        self.stats["float"] = []
        self.stats["quantized"] = []

    def forward(self, x, y):
        # fmt: off
        """
        """  # blank docblock to make autodoc happy
        # fmt: on
        if len(x) > 1:
            x = x[0]
        if len(y) > 1:
            y = y[0]
        self.stats["quantized"].append(x.detach())
        self.stats["float"].append(y.detach())


class OutputLogger(Logger):
    r"""Class used to log the outputs of the module"""

    def __init__(self):
        super().__init__()
        self.stats["tensor_val"] = []

    def forward(self, x):
        # fmt: off
        """
        """  # blank docblock to make autodoc happy
        # fmt: on
        self.stats["tensor_val"].append(x)
        return x


def _convert_tuple_to_list(t: Any) -> Any:
    return [_convert_tuple_to_list(x) for x in t] if type(t) is tuple else t


def _dequantize_tensor_list(t: Any) -> Any:
    return (
        [_dequantize_tensor_list(x) for x in t]
        if type(t) is list
        else t.dequantize()
        if t.is_quantized
        else t
    )


class Shadow(nn.Module):
    r"""Shadow module attaches the float module to its matching quantized module
    as the shadow. Then it uses Logger module to process the outputs of both
    modules.

    Args:
        q_module: module quantized from float_module that we want to shadow
        float_module: float module used to shadow q_module
        logger_cls: type of logger used to process the outputs of q_module and
            float_module. ShadowLogger or custom loggers can be used.
    """

    def __init__(self, q_module, float_module, logger_cls):
        super().__init__()
        self.orig_module = q_module
        self.shadow_module = float_module
        self.dequant = nnq.DeQuantize()
        self.logger = logger_cls()

    def forward(self, *x) -> torch.Tensor:
        # fmt: off
        """
        """  # blank docblock to make autodoc happy
        # fmt: on
        xl = _convert_tuple_to_list(x)
        output = self.orig_module(*xl)
        xl_float = _dequantize_tensor_list(xl)
        shadow_output = self.shadow_module(*xl_float)
        self.logger(output, shadow_output)
        return output

    def add(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # fmt: off
        """
        """  # blank docblock to make autodoc happy
        # fmt: on
        output = self.orig_module.add(x, y)
        x = x.dequantize()
        y = y.dequantize()
        shadow_output = self.shadow_module.add(x, y)
        self.logger(output, shadow_output)
        return output

    def add_scalar(self, x: torch.Tensor, y: float) -> torch.Tensor:
        # fmt: off
        """
        """  # blank docblock to make autodoc happy
        # fmt: on
        output = self.orig_module.add_scalar(x, y)
        x = x.dequantize()
        shadow_output = self.shadow_module.add_scalar(x, y)
        self.logger(output, shadow_output)
        return output

    def mul(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # fmt: off
        """
        """  # blank docblock to make autodoc happy
        # fmt: on
        output = self.orig_module.mul(x, y)
        x = x.dequantize()
        y = y.dequantize()
        shadow_output = self.shadow_module.mul(x, y)
        self.logger(output, shadow_output)
        return output

    def mul_scalar(self, x: torch.Tensor, y: float) -> torch.Tensor:
        # fmt: off
        """
        """  # blank docblock to make autodoc happy
        # fmt: on
        output = self.orig_module.mul_scalar(x, y)
        x = x.dequantize()
        shadow_output = self.shadow_module.mul_scalar(x, y)
        self.logger(output, shadow_output)
        return output

    def cat(self, x: List[torch.Tensor], dim: int = 0) -> torch.Tensor:
        # fmt: off
        """
        """  # blank docblock to make autodoc happy
        # fmt: on
        output = self.orig_module.cat(x, dim)
        x = [y.dequantize() for y in x]
        shadow_output = self.shadow_module.cat(x, dim)
        self.logger(output, shadow_output)
        return output

    def add_relu(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # fmt: off
        """
        """  # blank docblock to make autodoc happy
        # fmt: on
        output = self.orig_module.add_relu(x, y)
        x = x.dequantize()
        y = y.dequantize()
        shadow_output = self.shadow_module.add_relu(x, y)
        self.logger(output, shadow_output)
        return output


def prepare_model_with_stubs(
    float_module: nn.Module,
    q_module: nn.Module,
    module_swap_list: Set[type],
    logger_cls: Callable,
) -> None:
    r"""Prepare the model by attaching the float module to its matching quantized
    module as the shadow if the float module type is in module_swap_list.

    Example usage::

        prepare_model_with_stubs(float_model, q_model, module_swap_list, Logger)
        q_model(data)
        ob_dict = get_logger_dict(q_model)

    Args:
        float_module: float module used to generate the q_module
        q_module: module quantized from float_module
        module_swap_list: list of float module types to attach the shadow
        logger_cls: type of logger to be used in shadow module to process the outputs of
            quantized module and its float shadow module
    """
    torch._C._log_api_usage_once(
        "quantization_api._numeric_suite.prepare_model_with_stubs"
    )

    float_module_children = {}
    for name, mod in float_module.named_children():
        float_module_children[name] = mod

    reassign = {}
    for name, mod in q_module.named_children():
        if name not in float_module_children:
            continue

        float_mod = float_module_children[name]

        if type(float_mod) not in module_swap_list:
            prepare_model_with_stubs(float_mod, mod, module_swap_list, logger_cls)

        # Insert shadow module only if the module is not of the same type as
        # the floating point module
        if type(float_mod) in module_swap_list and not _is_identical_module_type(
            mod, float_mod
        ):
            reassign[name] = Shadow(mod, float_mod, logger_cls)

    for key, value in reassign.items():
        q_module._modules[key] = value


def _is_identical_module_type(mod1, mod2):
    # Compare if two modules have the same dtype
    mod1_module_types = [type(mod) for mod in mod1.modules()]
    mod2_module_types = [type(mod) for mod in mod2.modules()]
    return mod1_module_types == mod2_module_types


def compare_model_stub(
    float_model: nn.Module,
    q_model: nn.Module,
    module_swap_list: Set[type],
    *data,
    logger_cls=ShadowLogger,
) -> Dict[str, Dict]:
    r"""Compare quantized module in a model with its floating point counterpart,
    feeding both of them the same input. Return a dict with key corresponding to
    module names and each entry being a dictionary with two keys 'float' and
    'quantized', containing the output tensors of quantized and its matching
    float shadow module. This dict can be used to compare and compute the module
    level quantization error.

    This function first call prepare_model_with_stubs() to swap the quantized
    module that we want to compare with the Shadow module, which takes quantized
    module, corresponding float module and logger as input, and creates a forward
    path inside to make the float module to shadow quantized module sharing the
    same input. The logger can be customizable, default logger is ShadowLogger
    and it will save the outputs of the quantized module and float module that
    can be used to compute the module level quantization error.

    Example usage::

        module_swap_list = [torchvision.models.quantization.resnet.QuantizableBasicBlock]
        ob_dict = compare_model_stub(float_model,qmodel,module_swap_list, data)
        for key in ob_dict:
            print(key, compute_error(ob_dict[key]['float'], ob_dict[key]['quantized'].dequantize()))

    Args:
        float_model: float model used to generate the q_model
        q_model: model quantized from float_model
        module_swap_list: list of float module types at which shadow modules will
            be attached.
        data: input data used to run the prepared q_model
        logger_cls: type of logger to be used in shadow module to process the outputs of
            quantized module and its float shadow module
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite.compare_model_stub")
    prepare_model_with_stubs(float_model, q_model, module_swap_list, logger_cls)
    q_model(*data)
    ob_dict = get_logger_dict(q_model)
    return ob_dict


def get_matching_activations(
    float_module: nn.Module,
    q_module: nn.Module,
) -> Dict[str, Dict[str, torch.Tensor]]:
    r"""Find the matching activation between float and quantized modules.

    Args:
        float_module: float module used to generate the q_module
        q_module: module quantized from float_module

    Return:
        act_dict: dict with key corresponding to quantized module names and each
        entry being a dictionary with two keys 'float' and 'quantized', containing
        the matching float and quantized activations
    """
    torch._C._log_api_usage_once(
        "quantization_api._numeric_suite.get_matching_activations"
    )
    float_dict = get_logger_dict(float_module)
    quantized_dict = get_logger_dict(q_module)
    act_dict: Dict[str, Dict] = {}
    for key in quantized_dict:
        if len(quantized_dict[key]["tensor_val"]) == 0:
            continue
        match_key = _find_match(sorted(float_dict, reverse=True), key, "stats")
        if match_key is not None:
            act_dict[key] = {}
            act_dict[key]["float"] = float_dict[match_key]["tensor_val"]
            act_dict[key]["quantized"] = quantized_dict[key]["tensor_val"]
    return act_dict


def prepare_model_outputs(
    float_module: nn.Module,
    q_module: nn.Module,
    logger_cls=OutputLogger,
    allow_list=None,
) -> None:
    r"""Prepare the model by attaching the logger to both float module
    and quantized module if they are in the allow_list.

    Args:
        float_module: float module used to generate the q_module
        q_module: module quantized from float_module
        logger_cls: type of logger to be attached to float_module and q_module
        allow_list: list of module types to attach logger
    """
    torch._C._log_api_usage_once(
        "quantization_api._numeric_suite.prepare_model_outputs"
    )
    if allow_list is None:
        allow_list = get_default_compare_output_module_list()

    qconfig_debug = torch.ao.quantization.QConfig(activation=logger_cls, weight=None)
    float_module.qconfig = qconfig_debug  # type: ignore[assignment]
    prepare(
        float_module, inplace=True, allow_list=allow_list, prepare_custom_config_dict={}
    )
    q_module.qconfig = qconfig_debug  # type: ignore[assignment]
    prepare(
        q_module,
        inplace=True,
        allow_list=allow_list,
        observer_non_leaf_module_list=NON_LEAF_MODULE_TO_ADD_OBSERVER_ALLOW_LIST,
        prepare_custom_config_dict={},
    )


def compare_model_outputs(
    float_model: nn.Module,
    q_model: nn.Module,
    *data,
    logger_cls=OutputLogger,
    allow_list=None,
) -> Dict[str, Dict[str, torch.Tensor]]:
    r"""Compare output activations between float and quantized models at
    corresponding locations for the same input. Return a dict with key corresponding
    to quantized module names and each entry being a dictionary with two keys
    'float' and 'quantized', containing the activations of quantized model and
    float model at matching locations. This dict can be used to compare and
    compute the propagation quantization error.

    Example usage::

        act_compare_dict = compare_model_outputs(float_model, qmodel, data)
        for key in act_compare_dict:
            print(
                key,
                compute_error(
                    act_compare_dict[key]['float'],
                    act_compare_dict[key]['quantized'].dequantize()
                )
            )

    Args:
        float_model: float model used to generate the q_model
        q_model: model quantized from float_model
        data: input data used to run the prepared float_model and q_model
        logger_cls: type of logger to be attached to float_module and q_module
        allow_list: list of module types to attach logger

    Return:
        act_compare_dict: dict with key corresponding to quantized module names
        and each entry being a dictionary with two keys 'float' and 'quantized',
        containing the matching float and quantized activations
    """
    torch._C._log_api_usage_once(
        "quantization_api._numeric_suite.compare_model_outputs"
    )
    if allow_list is None:
        allow_list = get_default_compare_output_module_list()
    prepare_model_outputs(float_model, q_model, logger_cls, allow_list)
    float_model(*data)
    q_model(*data)
    act_compare_dict = get_matching_activations(float_model, q_model)
    return act_compare_dict
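The docstrings in this file call a `compute_error` helper that the module itself does not define. A common SQNR-style definition used alongside these APIs (an illustrative assumption, not part of the diff)::

    import torch

    def compute_error(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # Signal-to-quantization-noise ratio in dB between a float tensor
        # and its dequantized counterpart; higher means closer agreement.
        Ps = torch.norm(x)
        Pn = torch.norm(x - y)
        return 20 * torch.log10(Ps / Pn)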
.venv/Lib/site-packages/torch/ao/ns/_numeric_suite_fx.py
ADDED
@@ -0,0 +1,1130 @@
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
"""
|
| 3 |
+
This module contains tooling to compare weights and activations
|
| 4 |
+
across models. Example usage::
|
| 5 |
+
|
| 6 |
+
import copy
|
| 7 |
+
import torch
|
| 8 |
+
import torch.ao.quantization.quantize_fx as quantize_fx
|
| 9 |
+
import torch.ao.ns._numeric_suite_fx as ns
|
| 10 |
+
|
| 11 |
+
m = torch.nn.Sequential(torch.nn.Conv2d(1, 1, 1)).eval()
|
| 12 |
+
mp = quantize_fx.prepare_fx(m, {'': torch.ao.quantization.default_qconfig})
|
| 13 |
+
# We convert a copy because we need the original prepared model
|
| 14 |
+
# to be available for comparisons, and `quantize_fx.convert_fx` is inplace.
|
| 15 |
+
mq = quantize_fx.convert_fx(copy.deepcopy(mp))
|
| 16 |
+
|
| 17 |
+
#
|
| 18 |
+
# Comparing weights
|
| 19 |
+
#
|
| 20 |
+
|
| 21 |
+
# extract weight pairs
|
| 22 |
+
weight_comparison = ns.extract_weights('a', mp, 'b', mq)
|
| 23 |
+
|
| 24 |
+
# add SQNR for each comparison, inplace
|
| 25 |
+
ns.extend_logger_results_with_comparison(
|
| 26 |
+
weight_comparison, 'a', 'b', torch.ao.ns.fx.utils.compute_sqnr,
|
| 27 |
+
'sqnr')
|
| 28 |
+
|
| 29 |
+
# weight_comparison contains the weights from `mp` and `mq` stored
|
| 30 |
+
# in pairs, and can be used for further analysis.
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
#
|
| 34 |
+
# Comparing activations, with error propagation
|
| 35 |
+
#
|
| 36 |
+
|
| 37 |
+
# add loggers
|
| 38 |
+
mp_ns, mq_ns = ns.add_loggers(
|
| 39 |
+
'a', copy.deepcopy(mp),
|
| 40 |
+
'b', copy.deepcopy(mq),
|
| 41 |
+
ns.OutputLogger)
|
| 42 |
+
|
| 43 |
+
# send an example datum to capture intermediate activations
|
| 44 |
+
datum = torch.randn(1, 1, 1, 1)
|
| 45 |
+
mp_ns(datum)
|
| 46 |
+
mq_ns(datum)
|
| 47 |
+
|
| 48 |
+
# extract intermediate activations
|
| 49 |
+
act_comparison = ns.extract_logger_info(
|
| 50 |
+
mp_ns, mq_ns, ns.OutputLogger, 'b')
|
| 51 |
+
|
| 52 |
+
# add SQNR for each comparison, inplace
|
| 53 |
+
ns.extend_logger_results_with_comparison(
|
| 54 |
+
act_comparison, 'a', 'b', torch.ao.ns.fx.utils.compute_sqnr,
|
| 55 |
+
'sqnr')
|
| 56 |
+
|
| 57 |
+
# act_comparison contains the activations from `mp_ns` and `mq_ns` stored
|
| 58 |
+
# in pairs, and can be used for further analysis.
|
| 59 |
+
|
| 60 |
+
#
|
| 61 |
+
# Comparing activations, without error propagation
|
| 62 |
+
#
|
| 63 |
+
|
| 64 |
+
# create shadow model
|
| 65 |
+
mp_shadows_mq = ns.add_shadow_loggers(
|
| 66 |
+
'a', copy.deepcopy(mp),
|
| 67 |
+
'b', copy.deepcopy(mq),
|
| 68 |
+
ns.OutputLogger)
|
| 69 |
+
|
| 70 |
+
# send an example datum to capture intermediate activations
|
| 71 |
+
datum = torch.randn(1, 1, 1, 1)
|
| 72 |
+
mp_shadows_mq(datum)
|
| 73 |
+
|
| 74 |
+
# extract intermediate activations
|
| 75 |
+
shadow_act_comparison = ns.extract_shadow_logger_info(
|
| 76 |
+
mp_shadows_mq, ns.OutputLogger, 'b')
|
| 77 |
+
|
| 78 |
+
# add SQNR for each comparison, inplace
|
| 79 |
+
ns.extend_logger_results_with_comparison(
|
| 80 |
+
shadow_act_comparison, 'a', 'b', torch.ao.ns.fx.utils.compute_sqnr,
|
| 81 |
+
'sqnr')
|
| 82 |
+
|
| 83 |
+
# shadow_act_comparison contains the activations from `mp_ns` and `mq_ns` stored
|
| 84 |
+
# in pairs, and can be used for further analysis.
|
| 85 |
+
|
| 86 |
+
"""
|
| 87 |
+
|
| 88 |
+
import collections
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TYPE_CHECKING

import torch
import torch.ao.quantization.quantize_fx as quantize_fx
import torch.nn as nn
from torch.ao.ns.fx.graph_matcher import (
    get_matching_subgraph_pairs,
    get_type_a_related_to_b,
)
from torch.ao.ns.fx.mappings import get_base_name_to_sets_of_related_ops
from torch.ao.ns.fx.n_shadows_utils import (
    _get_dedup_subgraphs,
    create_add_loggers_graph,
    create_n_transformed_and_logged_copies_of_subgraph,
    create_results_comparison,
    extract_weight_comparison,
    group_results_by_subgraph,
    OutputProp,
    print_n_shadows_summary,
    SHADOW_WRAPPER_NODE_NAME_PREFIX,
)
from torch.ao.ns.fx.qconfig_multi_mapping import QConfigMultiMapping
from torch.ao.quantization import QConfigMapping
from torch.ao.quantization.backend_config import BackendConfig
from torch.ao.quantization.backend_config.utils import (
    get_fusion_pattern_to_root_node_getter,
)
from torch.ao.quantization.fx.graph_module import _get_observed_graph_module_attr
from torch.ao.quantization.fx.match_utils import _find_matches
from torch.ao.quantization.fx.qconfig_mapping_utils import (
    _generate_node_name_to_qconfig,
)
from torch.ao.quantization.fx.quantize_handler import _get_pattern_to_quantize_handlers
from torch.fx import GraphModule
from torch.fx.graph import Node

from .fx.graph_passes import add_loggers_to_model, create_a_shadows_b
from .fx.ns_types import NSNodeTargetType, NSResultsType, NSSingleResultValuesType
from .fx.utils import (
    get_target_type_str,
    maybe_add_missing_fqns,
    rekey_logger_info_on_node_name_of_model,
)
from .fx.weight_utils import extract_weight_from_node


if TYPE_CHECKING:
    from torch.ao.quantization.qconfig import QConfigAny

RNNReturnType = Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]

class OutputLogger(nn.Module):
    """
    Base class for capturing intermediate values.
    """

    stats: List[torch.Tensor]
    stats_rnn: List[RNNReturnType]

    # Mark as impure so that calls to it will not be removed during DCE.
    _is_impure = True

    def __init__(
        self,
        ref_node_name: str,
        prev_node_name: str,
        model_name: str,
        ref_name: str,
        prev_node_target_type: str,
        ref_node_target_type: str,
        results_type: str,
        index_within_arg: int,
        index_of_arg: int,
        fqn: Optional[str],
        qconfig_str: Optional[str] = "",
    ):
        super().__init__()
        self.stats: List[torch.Tensor] = []
        self.stats_rnn: List[RNNReturnType] = []

        # name of the node which was responsible for adding this logger
        # Note:
        # - if we are logging node outputs, this is the same as prev_node_name
        # - if we are logging node inputs, this is the name of the node
        #   whose input this logger is logging.
        #
        # example, where logger1 is logging the input of op1 and logger2 is
        # logging the output of op1:
        #
        #   x1 -> logger1 -> op1 -> logger2 -> x2
        #
        # in this example,
        #   - logger1's prev_node_name is x1 and ref_node_name is op1
        #   - logger2's prev_node_name is op1 and ref_node_name is op1
        self.ref_node_name = ref_node_name
        # name of the node whose output this Logger is capturing
        self.prev_node_name = prev_node_name

        # name of the model from which the node originated
        self.model_name = model_name
        # reference name, used to match loggers from separate models
        # to each other
        self.ref_name = ref_name
        # type of the target of the node whose output this logger is logging
        self.prev_node_target_type = prev_node_target_type
        # type of the target of the node which was responsible for adding this
        # logger
        self.ref_node_target_type = ref_node_target_type
        # what kind of values are inside of stats
        self.results_type = results_type
        # index of this node within the arg of the input/output node
        # for example, in cat([x1, x2, x3], dim=0), x2 would have index_within_arg == 1
        self.index_within_arg = index_within_arg
        # index of this node within the args of the input/output node
        # for example, in add(x1, x2), x2 would have index_of_arg == 1
        self.index_of_arg = index_of_arg
        # fully qualified name
        self.fqn = fqn
        # if loggers are added before prepare_fx, we do not want to collect
        # results during calibration, only results after convert_fx.
        # so, we add a flag to control whether this logger collects data
        self.enabled = True
        # string representation of qconfig
        self.qconfig_str = qconfig_str
        # this can be turned off to reduce memory usage during calibration
        self.save_activations = True

    # Note: cannot annotate the type of x because TorchScript does not support
    # the Union type.
    def forward(self, x):
        # fmt: off
        """
        """  # blank docblock to make autodoc happy
        # fmt: on
        # TODO(future PR): consider designing this better, as the difference
        # between these two flags is subtle and not obvious.
        if not self.enabled:
            return x
        if not self.save_activations:
            return x
        # TODO(future PR): consider refactoring this to better reuse the parent
        # class
        if isinstance(x, torch.Tensor):
            self.stats.append(x.detach())
        elif isinstance(x, tuple) and len(x) == 2 and len(x[1]) == 2:
            new_res = (x[0].detach(), (x[1][0].detach(), x[1][1].detach()))
            self.stats_rnn.append(new_res)
        return x

    def __repr__(self):
        clean_dict = {
            k: v
            for k, v in self.__dict__.items()
            # skip nn.Module keys
            if (k != "training") and not k.startswith("_")
        }
        return f"OutputLogger({clean_dict})"

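# A minimal sketch of how captured stats can be read back after calibration.
# It assumes `mp_ns` is a model instrumented via `add_loggers` (defined below)
# that has already seen data; the variable name is illustrative only.
#
#     for name, mod in mp_ns.named_modules():
#         if isinstance(mod, OutputLogger):
#             # one entry in `stats` per forward pass through the model
#             print(name, mod.ref_name, [t.shape for t in mod.stats])
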
class OutputComparisonLogger(OutputLogger):
    """
    Same as OutputLogger, but also requires the original activation
    in order to calculate the comparison at calibration time.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # TODO(future PR): make the comparison function configurable
        self.comparison_fn = torch.ao.ns.fx.utils.compute_sqnr
        self.comparison_fn_name = "sqnr"
        # precalculated comparisons of logger output versus reference
        self.comparisons = []

    def forward(self, x, x_ref):
        # fmt: off
        """
        """  # blank docblock to make autodoc happy
        # fmt: on
        if not self.enabled:
            return x
        assert isinstance(x, torch.Tensor), "non-tensor inputs not yet supported"
        if self.save_activations:
            # save the activation, for debugging
            self.stats.append(x.detach())
        # save the comparison
        self.comparisons.append(self.comparison_fn(x, x_ref))
        return x

    def __repr__(self):
        clean_dict = {
            k: v
            for k, v in self.__dict__.items()
            # skip nn.Module keys
            if (k != "training") and not k.startswith("_")
        }
        return f"OutputComparisonLogger({clean_dict})"

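# For context, the `compute_sqnr` comparison used above compares a tensor
# against its reference. A hand-rolled equivalent, assuming the usual SQNR
# definition in dB (this is a sketch, not the library implementation):
#
#     def sqnr(x: torch.Tensor, x_ref: torch.Tensor) -> torch.Tensor:
#         # ratio of signal power to quantization-error power, in dB
#         return 20 * torch.log10(x.norm() / (x - x_ref).norm())
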
class NSTracer(quantize_fx.QuantizationTracer):
    """
    Just like a regular FX quantization tracer, but treats observers and
    fake_quantize modules as leaf modules.
    """

    def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
        # fmt: off
        """
        """  # blank docblock to make autodoc happy
        # fmt: on
        if isinstance(m, torch.ao.quantization.ObserverBase):
            return True
        elif isinstance(m, torch.ao.quantization.FakeQuantizeBase):
            return True
        return super().is_leaf_module(m, module_qualified_name)

def _extract_weights_one_model(
    model_name: str,
    model: GraphModule,
    nodes_and_names_to_instrument: List[Tuple[Node, str]],
    results: NSResultsType,
    op_to_type_to_weight_extraction_fn: Optional[
        Dict[str, Dict[Callable, Callable]]
    ] = None,
) -> None:
    torch._C._log_api_usage_once(
        "quantization_api._numeric_suite_fx._extract_weights_one_model"
    )
    for node, ref_name in nodes_and_names_to_instrument:
        res_type = NSSingleResultValuesType.WEIGHT.value
        extracted_weight = extract_weight_from_node(
            node, model, op_to_type_to_weight_extraction_fn
        )
        if extracted_weight:
            if ref_name not in results:
                results[ref_name] = {res_type: {}}
            results[ref_name][res_type][model_name] = [extracted_weight]

def _extract_weights_impl(
    model_name_a: str,
    gm_a: GraphModule,
    model_name_b: str,
    gm_b: GraphModule,
    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    op_to_type_to_weight_extraction_fn: Optional[
        Dict[str, Dict[Callable, Callable]]
    ] = None,
) -> NSResultsType:
    torch._C._log_api_usage_once(
        "quantization_api._numeric_suite_fx._extract_weights_impl"
    )
    matched_subgraph_pairs = get_matching_subgraph_pairs(
        gm_a, gm_b, base_name_to_sets_of_related_ops, unmatchable_types_map
    )

    # split the subgraph pairs into one data structure for each model
    nodes_and_names_to_instrument_a: List[Tuple[Node, str]] = []
    nodes_and_names_to_instrument_b: List[Tuple[Node, str]] = []
    for match_name, match in matched_subgraph_pairs.items():
        subgraph_a, subgraph_b = match
        nodes_and_names_to_instrument_a.append((subgraph_a.base_op_node, match_name))
        nodes_and_names_to_instrument_b.append((subgraph_b.base_op_node, match_name))

    # populate the results, one model at a time
    results: NSResultsType = {}
    _extract_weights_one_model(
        model_name_a,
        gm_a,
        nodes_and_names_to_instrument_a,
        results,
        op_to_type_to_weight_extraction_fn,
    )
    _extract_weights_one_model(
        model_name_b,
        gm_b,
        nodes_and_names_to_instrument_b,
        results,
        op_to_type_to_weight_extraction_fn,
    )

    # fill in missing fqn entries
    maybe_add_missing_fqns(results)

    # rekey on names of nodes in gm_b
    results = rekey_logger_info_on_node_name_of_model(results, model_name_b)

    return results

def extract_weights(
    model_name_a: str,
    model_a: nn.Module,
    model_name_b: str,
    model_b: nn.Module,
    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    op_to_type_to_weight_extraction_fn: Optional[
        Dict[str, Dict[Callable, Callable]]
    ] = None,
) -> NSResultsType:
    """
    Extract weights from model A and model B, and return a comparison.

    Args:
        model_name_a: string name of model A to use in results
        model_a: model A
        model_name_b: string name of model B to use in results
        model_b: model B
        base_name_to_sets_of_related_ops: optional override of subgraph base nodes, subject to change
        unmatchable_types_map: optional override of unmatchable types, subject to change
        op_to_type_to_weight_extraction_fn: optional override of function which extracts weight
            from a type, subject to change

    Return:
        NSResultsType, containing the weight comparisons
    """

    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.extract_weights")
    if base_name_to_sets_of_related_ops is None:
        base_name_to_sets_of_related_ops = get_base_name_to_sets_of_related_ops()
    type_a_related_to_b = get_type_a_related_to_b(base_name_to_sets_of_related_ops)

    # TODO(future PR): expose these
    skipped_module_names: List[str] = []
    skipped_module_classes: List[Callable] = []
    tracer_a = NSTracer(skipped_module_names, skipped_module_classes)
    tracer_b = NSTracer(skipped_module_names, skipped_module_classes)
    gm_a = GraphModule(model_a, tracer_a.trace(model_a))
    maybe_model_a_node_name_to_scope = _get_observed_graph_module_attr(
        model_a, "node_name_to_scope"
    )
    if maybe_model_a_node_name_to_scope is not None:
        gm_a._node_name_to_scope = maybe_model_a_node_name_to_scope
    gm_b = GraphModule(model_b, tracer_b.trace(model_b))
    maybe_model_b_node_name_to_scope = _get_observed_graph_module_attr(
        model_b, "node_name_to_scope"
    )
    if maybe_model_b_node_name_to_scope is not None:
        gm_b._node_name_to_scope = maybe_model_b_node_name_to_scope
    return _extract_weights_impl(
        model_name_a,
        gm_a,
        model_name_b,
        gm_b,
        base_name_to_sets_of_related_ops,
        unmatchable_types_map,
        op_to_type_to_weight_extraction_fn,
    )

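# Usage sketch, mirroring the module docstring (`mp` and `mq` are assumed to
# be a prepared model and its converted copy):
#
#     weight_comparison = extract_weights('a', mp, 'b', mq)
#     extend_logger_results_with_comparison(
#         weight_comparison, 'a', 'b',
#         torch.ao.ns.fx.utils.compute_sqnr, 'sqnr')
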
def _add_loggers_one_model(
    model_name: str,
    model: GraphModule,
    nodes_and_names_to_instrument_inputs: List[Tuple[Node, str, str]],
    nodes_and_names_to_instrument_outputs: List[Tuple[Node, str, str]],
    logger_cls: Callable,
) -> nn.Module:
    torch._C._log_api_usage_once(
        "quantization_api._numeric_suite_fx._add_loggers_one_model"
    )

    # TODO(future PR): do not observe nodes we do not care
    #   about (both fp32, denylist, etc)
    node_to_instrument_inputs_to_ref_name: Dict[Node, Tuple[str, str]] = {}
    node_to_instrument_outputs_to_ref_name: Dict[Node, Tuple[str, str]] = {}
    for node, ref_name, ref_node_type in nodes_and_names_to_instrument_inputs:
        node_to_instrument_inputs_to_ref_name[node] = (ref_name, ref_node_type)
    for node, ref_name, ref_node_type in nodes_and_names_to_instrument_outputs:
        node_to_instrument_outputs_to_ref_name[node] = (ref_name, ref_node_type)

    model = add_loggers_to_model(
        model,
        node_to_instrument_inputs_to_ref_name,
        node_to_instrument_outputs_to_ref_name,
        logger_cls,
        model_name,
    )
    return model

def _add_loggers_impl(
    name_a: str,
    gm_a: GraphModule,
    name_b: str,
    gm_b: GraphModule,
    logger_cls: Callable,
    should_log_inputs: bool,
    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
) -> Tuple[nn.Module, nn.Module]:
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._add_loggers_impl")
    matched_subgraph_pairs = get_matching_subgraph_pairs(
        gm_a, gm_b, base_name_to_sets_of_related_ops, unmatchable_types_map
    )
    nodes_and_names_to_instrument_inputs_a = []
    nodes_and_names_to_instrument_inputs_b = []
    nodes_and_names_to_instrument_outputs_a = []
    nodes_and_names_to_instrument_outputs_b = []
    for match_name, (subgraph_a, subgraph_b) in matched_subgraph_pairs.items():
        ref_node_type_a = get_target_type_str(subgraph_a.base_op_node, gm_a)
        ref_node_type_b = get_target_type_str(subgraph_b.base_op_node, gm_b)
        # Note: for matching inputs we use start_node, such as observing
        # the input of linear in linear-relu
        if should_log_inputs:
            nodes_and_names_to_instrument_inputs_a.append(
                (subgraph_a.start_node, match_name, ref_node_type_a)
            )
            nodes_and_names_to_instrument_inputs_b.append(
                (subgraph_b.start_node, match_name, ref_node_type_b)
            )
        # Note: for matching activations we always use end_node,
        # such as observing the output of relu in linear-relu
        nodes_and_names_to_instrument_outputs_a.append(
            (subgraph_a.end_node, match_name, ref_node_type_a)
        )
        nodes_and_names_to_instrument_outputs_b.append(
            (subgraph_b.end_node, match_name, ref_node_type_b)
        )

    new_model_a = _add_loggers_one_model(
        name_a,
        gm_a,
        nodes_and_names_to_instrument_inputs_a,
        nodes_and_names_to_instrument_outputs_a,
        logger_cls,
    )
    new_model_b = _add_loggers_one_model(
        name_b,
        gm_b,
        nodes_and_names_to_instrument_inputs_b,
        nodes_and_names_to_instrument_outputs_b,
        logger_cls,
    )
    return (new_model_a, new_model_b)

def add_loggers(
    name_a: str,
    model_a: nn.Module,
    name_b: str,
    model_b: nn.Module,
    logger_cls: Callable,
    should_log_inputs: bool = False,
    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
) -> Tuple[nn.Module, nn.Module]:
    """
    Instrument model A and model B with loggers.

    Args:
        name_a: string name of model A to use in results
        model_a: model A
        name_b: string name of model B to use in results
        model_b: model B
        logger_cls: class of Logger to use
        should_log_inputs: whether to log inputs
        base_name_to_sets_of_related_ops: optional override of subgraph base nodes, subject to change
        unmatchable_types_map: optional override of unmatchable types, subject to change

    Return:
        Returns a tuple of (model_a_with_loggers, model_b_with_loggers). Modifies both models inplace.
    """

    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.add_loggers")
    # TODO(future PR): expose these
    skipped_module_names: List[str] = []
    skipped_module_classes: List[Callable] = []
    tracer_a = NSTracer(skipped_module_names, skipped_module_classes)
    tracer_b = NSTracer(skipped_module_names, skipped_module_classes)
    gm_a = GraphModule(model_a, tracer_a.trace(model_a))
    maybe_model_a_node_name_to_scope = _get_observed_graph_module_attr(
        model_a, "node_name_to_scope"
    )
    if maybe_model_a_node_name_to_scope is not None:
        gm_a._node_name_to_scope = maybe_model_a_node_name_to_scope
    gm_b = GraphModule(model_b, tracer_b.trace(model_b))
    maybe_model_b_node_name_to_scope = _get_observed_graph_module_attr(
        model_b, "node_name_to_scope"
    )
    if maybe_model_b_node_name_to_scope is not None:
        gm_b._node_name_to_scope = maybe_model_b_node_name_to_scope
    return _add_loggers_impl(
        name_a,
        gm_a,
        name_b,
        gm_b,
        logger_cls,
        should_log_inputs=should_log_inputs,
        base_name_to_sets_of_related_ops=base_name_to_sets_of_related_ops,
        unmatchable_types_map=unmatchable_types_map,
    )

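# Usage sketch (`mp`/`mq` as in the module docstring; the datum shape is
# illustrative):
#
#     mp_ns, mq_ns = add_loggers('a', copy.deepcopy(mp),
#                                'b', copy.deepcopy(mq), OutputLogger)
#     datum = torch.randn(1, 1, 1, 1)
#     mp_ns(datum)  # populates the loggers in model a
#     mq_ns(datum)  # populates the loggers in model b
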
def _extract_logger_info_one_model(
    model: nn.Module,
    results: NSResultsType,
    logger_cls: Callable,
) -> None:
    torch._C._log_api_usage_once(
        "quantization_api._numeric_suite_fx._extract_logger_info_one_model"
    )
    for gm_name, mod in model.named_modules():
        # TODO(future PR): better check when scripted
        is_logger = isinstance(mod, logger_cls) or (  # type: ignore[arg-type]
            isinstance(mod, torch.jit.RecursiveScriptModule)
            and mod.original_name == "OutputLogger"
        )
        if is_logger:
            key = mod.ref_name
            if key not in results:
                results[key] = {}
            assert (
                mod.model_name not in results[key]
            ), f"{mod.model_name} is already present in results"
            if mod.results_type not in results[key]:
                results[key][mod.results_type] = {}
            if mod.model_name not in results[key][mod.results_type]:
                results[key][mod.results_type][mod.model_name] = []
            stats_to_use = mod.stats
            if len(mod.stats_rnn) > 0:
                stats_to_use = mod.stats_rnn
            data = {
                "type": mod.results_type,
                "values": stats_to_use,
                "ref_node_name": mod.ref_node_name,
                "ref_node_target_type": mod.ref_node_target_type,
                "prev_node_name": mod.prev_node_name,
                "prev_node_target_type": mod.prev_node_target_type,
                "index_within_arg": mod.index_within_arg,
                "index_of_arg": mod.index_of_arg,
                "fqn": mod.fqn,
                "qconfig_str": mod.qconfig_str,
            }
            if hasattr(mod, "comparisons"):
                data["comparisons"] = mod.comparisons
                data["comparison_fn_name"] = mod.comparison_fn_name
            else:
                data["comparisons"] = []
                data["comparison_fn_name"] = ""
            results[key][mod.results_type][mod.model_name].append(data)
            # ensure the list stays sorted
            results[key][mod.results_type][mod.model_name].sort(
                key=lambda res: f"{res['index_of_arg']}:{res['index_within_arg']}"
            )

# TODO(future PR): align on naming
# this is equivalent to just the comparison extraction part of `ns.compare_model_outputs`
def extract_logger_info(
    model_a: nn.Module,
    model_b: nn.Module,
    logger_cls: Callable,
    model_name_to_use_for_layer_names: str,
) -> NSResultsType:
    """
    Traverse all loggers in `model_a` and `model_b`, and extract the logged
    information.

    Args:
        model_a: model A
        model_b: model B
        logger_cls: class of Logger to use
        model_name_to_use_for_layer_names: string name of model to use for
            layer names in the output

    Return:
        NSResultsType, containing the logged comparisons
    """
    torch._C._log_api_usage_once(
        "quantization_api._numeric_suite_fx.extract_logger_info"
    )
    results: NSResultsType = {}
    for model in (model_a, model_b):
        _extract_logger_info_one_model(model, results, logger_cls)
    # fill in missing fqn entries
    maybe_add_missing_fqns(results)
    # rekey on the name of model b
    results = rekey_logger_info_on_node_name_of_model(
        results, model_name_to_use_for_layer_names
    )
    return results

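# The returned NSResultsType is a nested dict keyed by layer name, then
# results type, then model name. A rough sketch of its shape (the key and
# value names shown here are illustrative):
#
#     {
#         'conv1_0': {
#             'node_output': {
#                 'a': [{'values': [...], 'ref_node_name': ..., ...}],
#                 'b': [{'values': [...], 'ref_node_name': ..., ...}],
#             },
#         },
#     }
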
def _add_shadow_loggers_impl(
    name_a: str,
    gm_a: GraphModule,
    name_b: str,
    gm_b: GraphModule,
    logger_cls: Callable,
    should_log_inputs: bool,
    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    node_type_to_io_type_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
) -> nn.Module:
    torch._C._log_api_usage_once(
        "quantization_api._numeric_suite_fx._add_shadow_loggers_impl"
    )
    matched_subgraph_pairs = get_matching_subgraph_pairs(
        gm_a, gm_b, base_name_to_sets_of_related_ops, unmatchable_types_map
    )
    gm_a_shadows_b = create_a_shadows_b(
        name_a,
        gm_a,
        name_b,
        gm_b,
        matched_subgraph_pairs,
        logger_cls,
        should_log_inputs=should_log_inputs,
        node_type_to_io_type_map=node_type_to_io_type_map,
    )
    return gm_a_shadows_b

def add_shadow_loggers(
    name_a: str,
    model_a: nn.Module,
    name_b: str,
    model_b: nn.Module,
    logger_cls: Callable,
    should_log_inputs: bool = False,
    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    node_type_to_io_type_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
) -> nn.Module:
    """
    Instrument model A and model B with shadow loggers.

    Args:
        name_a: string name of model A to use in results
        model_a: model A
        name_b: string name of model B to use in results
        model_b: model B
        logger_cls: class of Logger to use
        should_log_inputs: whether to log inputs
        base_name_to_sets_of_related_ops: optional override of subgraph base nodes, subject to change
        node_type_to_io_type_map: optional override of node I/O types, subject to change
        unmatchable_types_map: optional override of unmatchable types, subject to change
    """
    torch._C._log_api_usage_once(
        "quantization_api._numeric_suite_fx.add_shadow_loggers"
    )
    # TODO(future PR): expose these
    skipped_module_names: List[str] = []
    skipped_module_classes: List[Callable] = []
    tracer_a = NSTracer(skipped_module_names, skipped_module_classes)
    tracer_b = NSTracer(skipped_module_names, skipped_module_classes)
    gm_a = GraphModule(model_a, tracer_a.trace(model_a))
    maybe_model_a_node_name_to_scope = _get_observed_graph_module_attr(
        model_a, "node_name_to_scope"
    )
    if maybe_model_a_node_name_to_scope is not None:
        gm_a._node_name_to_scope = maybe_model_a_node_name_to_scope
    gm_b = GraphModule(model_b, tracer_b.trace(model_b))
    maybe_model_b_node_name_to_scope = _get_observed_graph_module_attr(
        model_b, "node_name_to_scope"
    )
    if maybe_model_b_node_name_to_scope is not None:
        gm_b._node_name_to_scope = maybe_model_b_node_name_to_scope
    return _add_shadow_loggers_impl(
        name_a,
        gm_a,
        name_b,
        gm_b,
        logger_cls,
        should_log_inputs=should_log_inputs,
        base_name_to_sets_of_related_ops=base_name_to_sets_of_related_ops,
        node_type_to_io_type_map=node_type_to_io_type_map,
        unmatchable_types_map=unmatchable_types_map,
    )

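# Usage sketch (`mp`/`mq` as in the module docstring); a single forward pass
# populates the loggers on both sides of the shadow model:
#
#     mp_shadows_mq = add_shadow_loggers(
#         'a', copy.deepcopy(mp), 'b', copy.deepcopy(mq), OutputLogger)
#     mp_shadows_mq(torch.randn(1, 1, 1, 1))
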
def extract_shadow_logger_info(
    model_a_shadows_b: nn.Module,
    logger_cls: Callable,
    model_name_to_use_for_layer_names: str,
) -> NSResultsType:
    """
    Traverse all loggers in a shadow model, and extract the logged
    information.

    Args:
        model_a_shadows_b: shadow model
        logger_cls: class of Logger to use
        model_name_to_use_for_layer_names: string name of model to use for
            layer names in the output

    Return:
        NSResultsType, containing the logged comparisons
    """
    torch._C._log_api_usage_once(
        "quantization_api._numeric_suite_fx.extract_shadow_logger_info"
    )
    results: NSResultsType = collections.defaultdict(dict)
    _extract_logger_info_one_model(model_a_shadows_b, results, logger_cls)
    # fill in missing fqn entries
    maybe_add_missing_fqns(results)
    # rekey on the name of model b
    results = rekey_logger_info_on_node_name_of_model(
        results, model_name_to_use_for_layer_names
    )
    return dict(results)

def extend_logger_results_with_comparison(
    results: NSResultsType,
    model_name_1: str,
    model_name_2: str,
    comparison_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    comparison_name: str,
) -> None:
    """
    Compares the logged values from `model_name_2` against the corresponding
    values in `model_name_1`, using `comparison_fn`. Records the result
    in `model_name_2`'s results under `comparison_name`. Modifies `results` inplace.

    Args:
        results: the result data structure from `extract_logger_info` or
            `extract_shadow_logger_info`.
        model_name_1: string name of model 1
        model_name_2: string name of model 2
        comparison_fn: function to compare two Tensors
        comparison_name: string name under which to store the comparison
            results in the output
    """
    for results_type_to_results in results.values():
        for model_name_to_results in results_type_to_results.values():
            assert (
                model_name_1 in model_name_to_results
            ), f"{model_name_1} not found in results"
            assert (
                model_name_2 in model_name_to_results
            ), f"{model_name_2} not found in results"

            results_1 = model_name_to_results[model_name_1]
            results_2 = model_name_to_results[model_name_2]

            for result_2 in results_2:
                index_within_arg_2 = result_2["index_within_arg"]
                index_of_arg_2 = result_2["index_of_arg"]
                # find the corresponding result_1
                result_1 = None
                for cur_result_1 in results_1:
                    index_within_arg_1 = cur_result_1["index_within_arg"]
                    index_of_arg_1 = cur_result_1["index_of_arg"]
                    if (index_within_arg_1 == index_within_arg_2) and (
                        index_of_arg_1 == index_of_arg_2
                    ):
                        result_1 = cur_result_1
                        break
                assert result_1 is not None

                values_1 = result_1["values"]
                values_2 = result_2["values"]
                result_2[comparison_name] = []
                for value_1, value_2 in zip(values_1, values_2):
                    comparison_result = comparison_fn(value_1, value_2)
                    result_2[comparison_name].append(comparison_result)

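# Any `Callable[[torch.Tensor, torch.Tensor], torch.Tensor]` can serve as
# `comparison_fn`. For example, a hypothetical cosine-similarity comparison
# (not part of this module) could be used in place of compute_sqnr:
#
#     def cosine_comparison(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
#         # flatten both tensors and compare their directions
#         return torch.nn.functional.cosine_similarity(
#             x.flatten(), y.flatten(), dim=0)
#
#     extend_logger_results_with_comparison(
#         act_comparison, 'a', 'b', cosine_comparison, 'cos_sim')
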
def prepare_n_shadows_model(
    model: torch.nn.Module,
    example_inputs: Any,
    qconfig_multi_mapping: QConfigMultiMapping,
    backend_config: BackendConfig,
    custom_prepare_fn: Optional[Callable] = None,
    custom_prepare_kwargs: Optional[Dict[str, Any]] = None,
    custom_tracer: Any = None,
) -> GraphModule:
    """
    Given a model with a graph with M ops such as


      args_kwargs_m -> op_m -> output_m


    And a set of N qconfigs for each op, creates a new model, with
    each subgraph of `op_m` transformed into

    .. code::

           |---------> op_m_n -> log_m_n
           |        /
      args_kwargs_m ---------> op_m -> log_m_0

    Where op_m_n is op_m wrapped in a submodule and transformed with
    qconfig_n, and its inner graph looks like

    .. code::

      args_m -------- op_m_prepared_with_qconfig_n -> out_m_n
                    /
      kwargs_m ---

    This is useful for testing different quantizations of multiple layers in
    a single pass through the model.

    High level TODOs for future PRs:
    * figure out a better way to name the output structure
    * return a results data structure instead of printing it out
    * add examples to docblocks
    """

    if custom_tracer is None:
        tracer = quantize_fx.QuantizationTracer([], [])
    else:
        tracer = custom_tracer
    mt = torch.fx.GraphModule(model, tracer.trace(model))
    # this is necessary to ensure logger FQNs get populated
    mt._node_name_to_scope = tracer.node_name_to_scope  # type: ignore[assignment]

    # run example input propagation, we need this to call prepare_fx on
    # individual subgraphs
    output_prop = OutputProp(mt)
    output_prop.propagate(*example_inputs)

    # Find the set of subgraphs in the original graph which we need to
    # consider.
    modules = dict(mt.named_modules(remove_duplicate=False))
    patterns = _get_pattern_to_quantize_handlers(backend_config)
    root_node_getter_mapping = get_fusion_pattern_to_root_node_getter(backend_config)
    standalone_module_names: List[str] = []
    standalone_module_classes: List[Type] = []
    custom_module_classes: List[Type] = []
    matches = _find_matches(
        mt.graph,
        modules,
        patterns,
        root_node_getter_mapping,
        standalone_module_names,
        standalone_module_classes,
        custom_module_classes,
    )
    subgraphs_dedup: Dict[str, List[Node]] = _get_dedup_subgraphs(matches)

    # generate node to qconfig for each subgraph
    # TODO(future PR): deduplicate repeating entries
    list_of_node_name_to_qconfig: List[Dict[str, QConfigAny]] = []
    for qconfig_mapping in qconfig_multi_mapping.qconfig_mappings_list:
        node_name_to_qconfig = _generate_node_name_to_qconfig(
            mt, modules, mt.graph, qconfig_mapping, tracer.node_name_to_scope
        )
        list_of_node_name_to_qconfig.append(node_name_to_qconfig)

    # For each region in the model, do the following:
    #   For each qconfig for that region, do the following:
    #     1. create a copy of the region wrapped in a module
    #     2. pass original args, original kwargs, and expected output to module
    #     3. add an output comparison logger and hook it up to compare
    #        actual output to expected output
    #     4. run `prepare_fx` on the module
    for subgraph_idx, (match_name, nodes_in_this_subgraph) in enumerate(
        subgraphs_dedup.items()
    ):
        create_n_transformed_and_logged_copies_of_subgraph(
            mt,
            subgraph_idx,
            match_name,
            nodes_in_this_subgraph,
            qconfig_multi_mapping.qconfig_mappings_list,
            list_of_node_name_to_qconfig,
            custom_prepare_fn,
            custom_prepare_kwargs,  # type: ignore[arg-type]
        )

    return mt

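# Sketch of building the `qconfig_multi_mapping` argument; the specific
# qconfigs chosen here are placeholders:
#
#     qconfig_multi_mapping = QConfigMultiMapping.from_list_qconfig_mapping([
#         QConfigMapping().set_global(torch.ao.quantization.default_qconfig),
#         QConfigMapping().set_global(
#             torch.ao.quantization.default_dynamic_qconfig),
#     ])
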
# TODO(future PR): we should rethink the names of all the PNP APIs
def _prepare_n_shadows_add_loggers_model(
    model: torch.nn.Module,
    example_inputs: Any,
    qconfig_mapping: QConfigMapping,
    backend_config: BackendConfig,
) -> torch.nn.Module:
    r"""
    Note: this API is not recommended for wide usage, it is only
    provided for customers who need to migrate from the `add_loggers`
    API.

    This creates a model which provides logging for the following
    problem: if we quantize `model` with `qconfig_mapping` and feed
    the same input through both models, log the comparisons of
    corresponding intermediate layers.

    The problem is solved with a single model. Specifically, we
    partition `model` into N subgraphs, create a copy of each relevant
    subgraph, wrap it in a module, apply the quantization API to that
    module, and hook up loggers to measure the comparisons.

    Example starting graph:

      x0 -> op0 -> x1 -> op1 -> x2

    Example config: quantize op0 to int8, do nothing to op1.
    The following graph will be created:

    .. code::

      x0_0 -> op0_0 -> x1_0 -> log -----> op1_0 -> x2_0 -> log
       \                        \                           \       # noqa: W605
        ---> op0_1 -> x1_1 ----> clog -> op1_0 -> x2_1 ----> clog

    Where op0_0 is op0, op0_1 is op0 wrapped in a submodule and quantized
    to int8, op1_0 is op1 (appearing in the graph twice), log is a logger,
    and clog is a comparison logger.
    """

    tracer = quantize_fx.QuantizationTracer([], [])
    mt = torch.fx.GraphModule(model, tracer.trace(model))
    # this is necessary to ensure logger FQNs get populated
    mt._node_name_to_scope = tracer.node_name_to_scope  # type: ignore[assignment]

    # run example input propagation, we need this to call prepare_fx on
    # individual subgraphs
    output_prop = OutputProp(mt)
    output_prop.propagate(*example_inputs)

    # Find the set of subgraphs in the original graph which we need to
    # consider.
    modules = dict(mt.named_modules(remove_duplicate=False))
    patterns = _get_pattern_to_quantize_handlers(backend_config)
    root_node_getter_mapping = get_fusion_pattern_to_root_node_getter(backend_config)
    standalone_module_names: List[str] = []
    standalone_module_classes: List[Type] = []
    custom_module_classes: List[Type] = []
    matches = _find_matches(
        mt.graph,
        modules,
        patterns,
        root_node_getter_mapping,
        standalone_module_names,
        standalone_module_classes,
        custom_module_classes,
    )
    subgraphs_dedup: Dict[str, List[Node]] = _get_dedup_subgraphs(matches)

    # generate node to qconfig for each subgraph
    node_name_to_qconfig = _generate_node_name_to_qconfig(
        mt, modules, mt.graph, qconfig_mapping, tracer.node_name_to_scope
    )

    # Now, mutate the graph to be the add_loggers graph, with error
    # propagation.
    create_add_loggers_graph(mt, subgraphs_dedup, qconfig_mapping, node_name_to_qconfig)

    return mt

# TODO(future PR): we should rethink the names of all the PNP APIs
def _n_shadows_compare_weights(
    model: torch.nn.Module,
    example_inputs: Any,
    qconfig_mapping: QConfigMapping,
    backend_config: BackendConfig,
) -> NSResultsType:
    """
    Note: this API is not recommended for wide usage, it is only
    provided for customers who need to migrate from the `add_loggers`
    API.
    """
    qconfig_multi_mapping = QConfigMultiMapping.from_list_qconfig_mapping(
        [qconfig_mapping]
    )
    mp = prepare_n_shadows_model(
        model, example_inputs, qconfig_multi_mapping, backend_config
    )
    # passing inputs through the model is necessary to populate
    # observers which observe weights with real values
    mp(*example_inputs)
    mq = convert_n_shadows_model(mp)
    weight_comparison = extract_weight_comparison(mq)
    return weight_comparison

+
# TODO(future PR): consider aligning API signature with other similar quantization
|
| 1066 |
+
# functions (enable_fake_quant, etc)
|
| 1067 |
+
def loggers_set_enabled(model: torch.nn.Module, enabled: bool) -> None:
|
| 1068 |
+
"""
|
| 1069 |
+
Sets the `enabled` setting on a `model`'s loggers
|
| 1070 |
+
"""
|
| 1071 |
+
for name, child in model.named_modules():
|
| 1072 |
+
if isinstance(child, OutputLogger):
|
| 1073 |
+
child.enabled = enabled
|
| 1074 |
+
|
| 1075 |
+
|
| 1076 |
+
# TODO(future PR): consider aligning API signature with other similar quantization
|
| 1077 |
+
# functions (enable_fake_quant, etc)
|
| 1078 |
+
def loggers_set_save_activations(
|
| 1079 |
+
model: torch.nn.Module,
|
| 1080 |
+
save_activations: bool,
|
| 1081 |
+
) -> None:
|
| 1082 |
+
"""
|
| 1083 |
+
Sets the `save_activations` setting on a `model`'s loggers
|
| 1084 |
+
"""
|
| 1085 |
+
for name, child in model.named_modules():
|
| 1086 |
+
if isinstance(child, OutputLogger):
|
| 1087 |
+
child.save_activations = save_activations
|
| 1088 |
+
|
| 1089 |
+
|
| 1090 |
+
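# Usage sketch: bound memory during a long calibration run. With
# `save_activations` off, OutputComparisonLogger still accumulates its
# comparisons, it just skips storing raw activations (`mp_ns` and
# `calibration_batches` are placeholders):
#
#     loggers_set_save_activations(mp_ns, False)
#     for batch in calibration_batches:
#         mp_ns(batch)
#     loggers_set_save_activations(mp_ns, True)
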
def convert_n_shadows_model(
    model: GraphModule,
    custom_convert_fn: Optional[Callable] = None,
    custom_convert_kwargs: Optional[Dict[str, Any]] = None,
) -> GraphModule:
    """
    Given a model from `prepare_n_shadows_model`, runs `convert_fx`
    on each shadow submodule.
    """
    for node in model.graph.nodes:
        # TODO(future PR): consider matching in a safer way than
        # node name string match
        if node.name.startswith(SHADOW_WRAPPER_NODE_NAME_PREFIX):
            orig_mod = getattr(model, node.name)
            if custom_convert_fn is None:
                converted_mod = torch.ao.quantization.quantize_fx.convert_fx(orig_mod)
            else:
                if custom_convert_kwargs is None:
                    custom_convert_kwargs = {}
                converted_mod = custom_convert_fn(orig_mod, **custom_convert_kwargs)
            setattr(model, node.name, converted_mod)

    return model

def extract_results_n_shadows_model(model: torch.nn.Module) -> NSResultsType:
    """
    Extracts logger results from `model`.
    """
    results: NSResultsType = {}
    _extract_logger_info_one_model(model, results, OutputLogger)
    return results


def print_comparisons_n_shadows_model(results: NSResultsType) -> None:
    """
    Prints a summary of extracted `results`.
    """
    results_grouped = group_results_by_subgraph(results)
    results_comparison = create_results_comparison(results_grouped)
    print_n_shadows_summary(results_comparison)

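# A plausible end-to-end n-shadows flow, combining the APIs above
# (`backend_config` and `example_inputs` are placeholders):
#
#     mp = prepare_n_shadows_model(
#         m, example_inputs, qconfig_multi_mapping, backend_config)
#     mp(*example_inputs)                        # calibrate
#     mq = convert_n_shadows_model(mp)
#     mq(*example_inputs)                        # populate comparison loggers
#     results = extract_results_n_shadows_model(mq)
#     print_comparisons_n_shadows_model(results)
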
.venv/Lib/site-packages/torch/ao/ns/fx/__init__.py
ADDED
File without changes

.venv/Lib/site-packages/torch/ao/ns/fx/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (186 Bytes)

.venv/Lib/site-packages/torch/ao/ns/fx/__pycache__/ns_types.cpython-39.pyc
ADDED
Binary file (976 Bytes)

.venv/Lib/site-packages/torch/ao/ns/fx/__pycache__/utils.cpython-39.pyc
ADDED
Binary file (12.7 kB)

.venv/Lib/site-packages/torch/ao/ns/fx/graph_matcher.py
ADDED
@@ -0,0 +1,470 @@
# mypy: allow-untyped-defs
import collections
import enum
from typing import Any, Dict, List, Optional, Set, Tuple

import torch
from torch.ao.quantization import FakeQuantizeBase, ObserverBase
from torch.ao.quantization.utils import getattr_from_fqn
from torch.fx import GraphModule
from torch.fx.graph import Graph, Node

from .mappings import get_base_name_to_sets_of_related_ops, get_unmatchable_types_map
from .ns_types import NSNodeTargetType, NSSubgraph
from .pattern_utils import (
    end_node_matches_reversed_fusion,
    get_reversed_fusions,
    get_type_a_related_to_b,
)


toq = torch.ops.quantized


def _get_output_nodes(g: Graph) -> List[Node]:
    return [n for n in g.nodes if n.op == "output"]


class _NSGraphMatchableSubgraphsIterator:
    """
    Iterates through the graph of gm, starting with the output nodes
    and continuing backwards.
    1. Returns matchable subgraphs, in order. A subgraph is defined by
       (start_node, end_node).
    2. Skips over non-matchable subgraphs.
    """

    def __init__(
        self,
        gm: GraphModule,
        non_matchable_functions: Set[NSNodeTargetType],
        non_matchable_modules: Set[NSNodeTargetType],
        non_matchable_methods: Set[NSNodeTargetType],
    ):
        self.gm: GraphModule = gm
        self.non_matchable_functions: Set[NSNodeTargetType] = non_matchable_functions
        self.non_matchable_modules: Set[NSNodeTargetType] = non_matchable_modules
        self.non_matchable_methods: Set[NSNodeTargetType] = non_matchable_methods
        self.seen_nodes: Set[Node] = set()
        self.stack: List[Node] = []
        for start_node in _get_output_nodes(self.gm.graph):
            self.stack.append(start_node)

    def __iter__(self):
        return self

    def __next__(self) -> NSSubgraph:
        """
        Returns the next matchable subgraph.
        """
        while len(self.stack) > 0:
            cur_end_node = self.stack.pop()
            if cur_end_node in self.seen_nodes:
                continue

            # for subgraphs which are single nodes, start_node == end_node
            # for subgraphs with more than one node, start_node != end_node
            cur_start_node = cur_end_node
            # Subgraphs like linear-relu have the base node as the start node.
            # Subgraphs like dequantize-linear-relu-to(torch.float16) have the
            # base node as the second node.
            # The cur_base_op_node var will move to the actual node during
            # the fusion matching later in this code block.
            cur_base_op_node = cur_end_node

            # Check for potential fusions. For now, we are greedy
            # and always skip all non-base nodes of a fusion. For example,
            # if we match linear-relu backwards, we will always skip the
            # relu node and attempt to match the linear node. This can
            # be made configurable later if needed.
            for _reverse_fusion_ops, base_op_idx in get_reversed_fusions():
                is_match = end_node_matches_reversed_fusion(
                    cur_end_node, _reverse_fusion_ops, self.gm, self.seen_nodes
                )
                if is_match:
                    # navigate to the base node
                    for rev_fusion_idx in range(len(_reverse_fusion_ops) - 1):
                        self.seen_nodes.add(cur_start_node)
                        # for now, assume that there are no other nodes
                        # which need to be added to the stack
                        cur_start_node = cur_start_node.args[0]  # type: ignore[assignment]
                        # if the base op index matches the current node, set it
                        rev_base_op_idx = len(_reverse_fusion_ops) - 2 - base_op_idx
                        if rev_fusion_idx == rev_base_op_idx:
                            cur_base_op_node = cur_start_node
                    break

            self.seen_nodes.add(cur_start_node)
            # add args of previous nodes to stack
            for arg in cur_start_node.all_input_nodes:
                self._recursively_add_node_arg_to_stack(arg)

            # skip unmatchable nodes
            # note: this check is done on the start_node, i.e.
            # if we are matching linear-relu in reverse, this would do the matchable
            # check on the linear
            if not self._is_matchable(cur_base_op_node):
                continue

            # If an observer or a fake_quant was not matched as a part of
            # a pattern of multiple nodes, ignore it. One case where this is
            # relevant is an observer on a graph input, which was added because
            # it is necessary for the next node.
            if cur_end_node.op == "call_module" and cur_start_node is cur_end_node:
                maybe_obs = getattr_from_fqn(self.gm, cur_end_node.target)  # type: ignore[arg-type]
                if isinstance(maybe_obs, (ObserverBase, FakeQuantizeBase)):
                    continue

            return NSSubgraph(
                start_node=cur_start_node,
                end_node=cur_end_node,
                base_op_node=cur_base_op_node,
            )

        raise StopIteration

    def _recursively_add_node_arg_to_stack(self, arg: Any) -> None:
        """
        Adds all of the nodes in this arg to the stack, properly navigating
        through lists, dicts and tuples.
        """
        if isinstance(arg, Node):
            self.stack.append(arg)
        elif (
            isinstance(arg, torch.fx.immutable_collections.immutable_list)
            or type(arg) is tuple
        ):
            for inner_arg in arg:
                self._recursively_add_node_arg_to_stack(inner_arg)
        elif isinstance(arg, torch.fx.immutable_collections.immutable_dict):
            for value in arg.values():
                self._recursively_add_node_arg_to_stack(value)

    def _is_matchable(self, node: Node) -> bool:
        if node.op == "call_function":
            return node.target not in self.non_matchable_functions
        elif node.op == "call_module":
            assert isinstance(node.target, str)
            target_mod = getattr_from_fqn(self.gm, node.target)
            return not any(
                isinstance(target_mod, t)  # type: ignore[arg-type]
                for t in self.non_matchable_modules
            )
        elif node.op == "call_method":
            return node.target not in self.non_matchable_methods
        else:
            return False

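# Internal usage sketch: iterating the matchable subgraphs of a traced module
# (the empty sets mean "nothing is unmatchable"; `gm` is a placeholder
# GraphModule):
#
#     it = _NSGraphMatchableSubgraphsIterator(gm, set(), set(), set())
#     for subgraph in it:
#         # each NSSubgraph carries start_node, end_node and base_op_node
#         print(subgraph.start_node.name, subgraph.end_node.name)
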
class GraphMatchingException(Exception):
    """
    Exception raised when two graphs cannot be matched.
    """


class SubgraphTypeRelationship(enum.Enum):
    # same type, known
    # example: F.linear and F.linear, or nn.Conv2d and nn.Conv2d
    EQUAL = enum.auto()
    # same type, but the type is not known to Numerical Suite
    # (user defined type, etc).
    EQUAL_BUT_UKNOWN = enum.auto()
    # known, same subgraph_relationship set, but not the same type
    # example: F.linear and toq.linear
    RELATED_BUT_NOT_EQUAL = enum.auto()
    # not related
    NOT_RELATED = enum.auto()

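# Illustrative outcomes, following the comments on the enum values above
# (`my_custom_op` stands for a user-defined op unknown to Numerical Suite):
#
#     F.linear     vs F.linear     -> EQUAL
#     F.linear     vs toq.linear   -> RELATED_BUT_NOT_EQUAL
#     my_custom_op vs my_custom_op -> EQUAL_BUT_UKNOWN
#     F.linear     vs F.conv2d     -> NOT_RELATED
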


def _get_subgraph_relationship_type(
    subgraph_a: NSSubgraph,
    subgraph_b: NSSubgraph,
    gm_a: GraphModule,
    gm_b: GraphModule,
    type_a_related_to_b: Set[Tuple[NSNodeTargetType, NSNodeTargetType]],
) -> SubgraphTypeRelationship:
    node_a = subgraph_a.base_op_node
    node_b = subgraph_b.base_op_node

    # TODO(next): make this code handle matching by what is before the base op
    if node_a.op != node_b.op:
        if not (
            node_a.op in ("call_function", "call_method")
            and node_b.op in ("call_function", "call_method")
        ):
            return SubgraphTypeRelationship.NOT_RELATED

    if node_a.op in ("call_function", "call_method"):
        key = (node_a.target, node_b.target)

        if key not in type_a_related_to_b:
            if node_a.target == node_b.target:
                return SubgraphTypeRelationship.EQUAL_BUT_UKNOWN
            else:
                return SubgraphTypeRelationship.NOT_RELATED
        # after this point, we are dealing with known types

        if node_a.target == node_b.target:
            node_a_has_prev = subgraph_a.base_op_node == subgraph_a.start_node
            node_b_has_prev = subgraph_b.base_op_node == subgraph_b.start_node
            if node_a_has_prev and (not node_b_has_prev):
                return SubgraphTypeRelationship.RELATED_BUT_NOT_EQUAL
            elif (not node_a_has_prev) and node_b_has_prev:
                return SubgraphTypeRelationship.RELATED_BUT_NOT_EQUAL
            elif (not node_a_has_prev) and (not node_b_has_prev):
                return SubgraphTypeRelationship.EQUAL
            else:
                # TODO(future PR): check for matches start_op_node and base_op_node
                return SubgraphTypeRelationship.EQUAL

        if key in type_a_related_to_b:
            return SubgraphTypeRelationship.RELATED_BUT_NOT_EQUAL
        else:
            return SubgraphTypeRelationship.NOT_RELATED
    elif node_a.op == "call_module":
        assert (
            subgraph_a.base_op_node == subgraph_a.start_node
            and subgraph_b.base_op_node == subgraph_b.start_node
        ), "Matching call_module patterns where base_op_node != start_node is not supported yet"
        # for call_module, we need to look up the modules to do the type check
        assert isinstance(node_a.target, str)
        mod_a = getattr_from_fqn(gm_a, node_a.target)
        assert isinstance(node_b.target, str)
        mod_b = getattr_from_fqn(gm_b, node_b.target)

        key = (type(mod_a), type(mod_b))

        if key not in type_a_related_to_b:
            if type(mod_a) == type(mod_b):
                return SubgraphTypeRelationship.EQUAL_BUT_UKNOWN
            else:
                return SubgraphTypeRelationship.NOT_RELATED
        elif type(mod_a) == type(mod_b):
            return SubgraphTypeRelationship.EQUAL
        else:
            return SubgraphTypeRelationship.RELATED_BUT_NOT_EQUAL

    return SubgraphTypeRelationship.NOT_RELATED
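

# A condensed sketch of the call_function branch above, for intuition only: it
# classifies a pair of op targets against the related-types set, but ignores
# the start_node/base_op_node fusion checks the real function also performs.
# `target_a`, `target_b` and `related` are hypothetical inputs; real callers
# pass NSSubgraph objects and the mapping from get_type_a_related_to_b.
def _example_classify_targets(target_a, target_b, related):
    if (target_a, target_b) not in related:
        if target_a == target_b:
            # same op, but unknown to the related-ops mapping
            return SubgraphTypeRelationship.EQUAL_BUT_UKNOWN
        return SubgraphTypeRelationship.NOT_RELATED
    if target_a == target_b:
        return SubgraphTypeRelationship.EQUAL
    # e.g. (F.linear, toq.linear): related quantization flavors of one op
    return SubgraphTypeRelationship.RELATED_BUT_NOT_EQUAL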


def _get_name_for_subgraph(
    subgraph_a: NSSubgraph,
    gm_a: GraphModule,
    base_name_to_sets_of_related_ops: Dict[str, Set[NSNodeTargetType]],
    existing_names: Set[str],
) -> str:
    """
    Returns a unique name for a subgraph. This name is based on two things:
    1. the name of the set containing the underlying type of the base op in the
       subgraph (i.e. 'torch.nn.functional.linear' if this is related to a linear op)
    2. the number of previous subgraphs with related underlying type of the base op

    For example, in the graph

    linear0 -> relu0 -> linear1 -> relu1

    The subgraphs are (linear0, relu0) and (linear1, relu1). If we iterate
    from the output node backwards, the name given to (linear1, relu1) will be
    `base_op_torch.nn.functional.linear_0`, and the name given to (linear0, relu0)
    will be `base_op_torch.nn.functional.linear_1`.

    Why are we not just using the node name? Answer: because of two requirements:
    A. fusions must be supported
    B. some Numeric Suite APIs can be called without having all of the models in memory

    For example, let's say we need to match nodes of

    (1) ... -> linear0 -> relu0 -> ...

    and

    (2) ... -> linear_relu0 -> ...

    without being able to inspect them together. With the current naming scheme, if
    we iterate through both of these graphs in the same order, and assuming the rest
    of the graphs match, both of these subgraphs will get the same name without
    (1) and (2) knowing anything about each other.
    """
    target_type = _get_node_target_type(subgraph_a.base_op_node, gm_a)
    target_base_type = None
    for base_name, sets_of_related_ops in base_name_to_sets_of_related_ops.items():
        if target_type in sets_of_related_ops:
            target_base_type = base_name
    target_base_name = "base_op_" + str(target_base_type)
    counter = 0
    proposed_name = target_base_name + "_" + str(counter)
    while proposed_name in existing_names:
        counter += 1
        proposed_name = target_base_name + "_" + str(counter)
    existing_names.add(proposed_name)
    return proposed_name
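

# Minimal sketch of the naming scheme described in the docstring above: each
# repeated base type gets suffixes _0, _1, ... The base name below is the one
# used for linear ops; the loop mirrors the counter logic of
# `_get_name_for_subgraph` without needing real subgraphs.
def _example_subgraph_names():
    existing_names: Set[str] = set()
    names = []
    for _ in range(2):
        target_base_name = "base_op_torch.nn.functional.linear"
        counter = 0
        proposed_name = target_base_name + "_" + str(counter)
        while proposed_name in existing_names:
            counter += 1
            proposed_name = target_base_name + "_" + str(counter)
        existing_names.add(proposed_name)
        names.append(proposed_name)
    # ['base_op_torch.nn.functional.linear_0', 'base_op_torch.nn.functional.linear_1']
    return names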


def _get_node_target_type(node: Node, gm: GraphModule) -> Optional[NSNodeTargetType]:
    if node.op in ("call_function", "call_method"):
        return node.target
    elif node.op == "call_module":
        assert isinstance(node.target, str)
        mod = getattr_from_fqn(gm, node.target)
        return type(mod)
    return None
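

# Tiny usage sketch: for call_function and call_method nodes this returns the
# target itself, for call_module nodes it returns the module's class, and for
# everything else (placeholders, outputs) it returns None. `gm` is any traced
# GraphModule supplied by the caller.
def _example_print_node_target_types(gm: GraphModule) -> None:
    for node in gm.graph.nodes:
        print(node.name, _get_node_target_type(node, gm))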


def get_matching_subgraph_pairs(
    gm_a: GraphModule,
    gm_b: GraphModule,
    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
) -> Dict[str, Tuple[NSSubgraph, NSSubgraph]]:
    """
    Matches matchable subgraphs of graph_a to graph_b.

    For a node, "matchable" is defined as a node which is not an observer,
    fake_quants, quant or dequant.

    A subgraph can contain one or more nodes. A subgraph is matchable if
    at least one node inside of it is matchable. Currently, all nodes in
    a subgraph must be matchable (because we assume no observers will be
    inserted in the middle of a fusion).

    A subgraph is defined by (start_node, end_node). We assume that only
    start_node and end_node are linked with the surrounding graph, all other
    nodes in a subgraph are self-contained.

    A pair of nodes is "related" if both nodes represent the same mathematical
    operation across different quantization flavors. For example,
    `F.linear` and `torch.ops.quantized.linear` are related, and
    `F.linear` and `torch.nn.Conv` are not related.

    For each matchable pair of nodes node_a and node_b, they will match
    if node_a and node_b are related.

    For graphs A and B, they will match iff:
    1. the number of matchable subgraphs in A and B is equivalent
    2. when iterating through the matchable subgraphs of A and B in the same order, each
       corresponding pair of base nodes is related.

    This enables us to find the corresponding subgraphs between
    graphs of related models. For example, if we had two graphs such as:

    graph_a: x0 -> conv_0 (type: nn.Conv2d) -> obs_0 -> x1
              w -/
              b -/

    graph_b: x0 -> quant_0 -> qconv_0 (type: nnq.Conv2d) -> dequant_0 -> x1
              packed_params_0 -/

    This function will return the following result:
    {
        'conv_0': (  # the name of the node in graph_b
          (conv_0, conv_0),  # (start_node_a, end_node_a)
          (qconv_0, qconv_0),  # (start_node_b, end_node_b)
        ),
    }

    Or, if we have a fusion pattern,

    graph_a: x0 -> linear_0 -> relu_0 -> obs_0 -> x1
              w -/
              b -/

    graph_b: x0 -> quant_0 -> linear_relu_0 -> dequant_0 -> x1
              packed_params_0 -/

    This function will return the following result:
    {
        'linear_relu_0': (  # the name of the node in graph_b
          (linear_0, relu_0),  # (start_node_a, end_node_a)
          (linear_relu_0, linear_relu_0),  # (start_node_b, end_node_b)
        ),
    }
    """
    if unmatchable_types_map is None:
        unmatchable_types_map = get_unmatchable_types_map()
    non_matchable_functions = unmatchable_types_map["funs_unmatchable"]
    non_matchable_modules = unmatchable_types_map["mods_unmatchable"]
    non_matchable_methods = unmatchable_types_map["meths_unmatchable"]

    graph_a_iterator = _NSGraphMatchableSubgraphsIterator(
        gm_a, non_matchable_functions, non_matchable_modules, non_matchable_methods
    )
    graph_b_iterator = _NSGraphMatchableSubgraphsIterator(
        gm_b, non_matchable_functions, non_matchable_modules, non_matchable_methods
    )
    results = collections.OrderedDict()
    if base_name_to_sets_of_related_ops is None:
        base_name_to_sets_of_related_ops = get_base_name_to_sets_of_related_ops()
    type_a_related_to_b = get_type_a_related_to_b(base_name_to_sets_of_related_ops)

    existing_names_a: Set[str] = set()
    existing_names_b: Set[str] = set()

    while True:
        # fetch the next subgraphs from a and b
        cur_subgraph_a, cur_subgraph_b = None, None
        try:
            cur_subgraph_a = next(graph_a_iterator)
        except StopIteration:
            pass
        try:
            cur_subgraph_b = next(graph_b_iterator)
        except StopIteration:
            pass

        # look up types of a and b for useful error messages
        type_start_a, type_start_b = None, None
        if cur_subgraph_a is not None:
            type_start_a = _get_node_target_type(cur_subgraph_a.start_node, gm_a)
        if cur_subgraph_b is not None:
            type_start_b = _get_node_target_type(cur_subgraph_b.start_node, gm_b)

        # check for results and determine what to do next
        if cur_subgraph_a is not None and cur_subgraph_b is not None:
            # both nodes were fetched, check for subgraph_relationship
            # note: subgraph_relationship is checked on the start node, i.e.
            # if a linear-relu pattern is checked, we would check for subgraph_relationship
            # of the linear
            subgraph_relationship = _get_subgraph_relationship_type(
                cur_subgraph_a, cur_subgraph_b, gm_a, gm_b, type_a_related_to_b
            )
            if subgraph_relationship == SubgraphTypeRelationship.NOT_RELATED:
                msg = f"""
The subgraphs
({cur_subgraph_a}, {type_start_a}) and
({cur_subgraph_b}, {type_start_b})
are not related. Please ensure that the two models you pass in have the same number
of subgraphs, and each pair of subgraphs is related to each other."""
                raise GraphMatchingException(msg)
            elif subgraph_relationship == SubgraphTypeRelationship.EQUAL_BUT_UKNOWN:
                # skip matching but unknown types
                continue
            key_name_a = _get_name_for_subgraph(
                cur_subgraph_a, gm_a, base_name_to_sets_of_related_ops, existing_names_a
            )
            key_name_b = _get_name_for_subgraph(
                cur_subgraph_b, gm_b, base_name_to_sets_of_related_ops, existing_names_b
            )
            assert (
                key_name_a == key_name_b
            ), f"Subgraph names {key_name_a} and {key_name_b} do not match"
            results[key_name_a] = (cur_subgraph_a, cur_subgraph_b)
            continue
        elif cur_subgraph_a is None and cur_subgraph_b is None:
            # we reached the end of both graphs
            break
        else:
            # only one node was fetched, no match possible, throw error
            msg = f"""
Attempting to match
({cur_subgraph_a}, {type_start_a}) and
({cur_subgraph_b}, {type_start_b}),
one of which is empty. Please ensure that the two models you pass in have the same number
of subgraphs."""
            raise GraphMatchingException(msg)

    # The subgraph pairs are originally created by traversing the two graphs
    # from the outputs to the inputs. Reverse the results to return the
    # subgraphs in their order of execution.
    results = collections.OrderedDict(reversed(list(results.items())))

    return results
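

# Minimal usage sketch: match two structurally identical traced models. The
# toy module below is hypothetical and exists only for illustration; with
# identical graphs, each returned value pairs the corresponding NSSubgraph
# from each model under a shared name.
def _example_get_matching_subgraph_pairs():
    import torch.nn as nn
    from torch.fx import symbolic_trace

    class _ToyModel(nn.Module):  # hypothetical
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(1, 1, 1)

        def forward(self, x):
            return torch.relu(self.conv(x))

    gm_a = symbolic_trace(_ToyModel())
    gm_b = symbolic_trace(_ToyModel())
    for name, (subgraph_a, subgraph_b) in get_matching_subgraph_pairs(gm_a, gm_b).items():
        print(name, subgraph_a.start_node.name, "<->", subgraph_b.start_node.name)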

.venv/Lib/site-packages/torch/ao/ns/fx/graph_passes.py
ADDED
@@ -0,0 +1,1131 @@
# mypy: allow-untyped-defs
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union

import torch
from torch.ao.ns.fx.mappings import get_node_type_to_io_type_map
from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix
from torch.ao.quantization.observer import _is_activation_post_process
from torch.fx import GraphModule, map_arg
from torch.fx.graph import Graph, Node

from .ns_types import NSNodeTargetType, NSSingleResultValuesType, NSSubgraph
from .utils import (
    get_arg_indices_of_inputs_to_log,
    get_node_first_input_and_output_type,
    get_node_input_qparams,
    get_normalized_nth_input,
    get_number_of_non_param_args,
    get_target_type_str,
    getattr_from_fqn,
    NodeInputOrOutputType,
    op_type_supports_shadowing,
    return_first_non_observer_node,
)


def _maybe_get_fqn(node: Node, gm: GraphModule) -> Optional[str]:
    fqn = None
    if hasattr(gm, "_node_name_to_scope"):
        # fqn on observers is not present, because they do not
        # exist when the fqns are created during tracing. If this is
        # an observer, get the fqn of the node being observed.
        node_to_use_for_fqn = node
        if node.op == "call_module":
            assert isinstance(node.target, str)
            module = getattr_from_fqn(gm, node.target)
            if _is_activation_post_process(module):
                node_to_use_for_fqn = get_normalized_nth_input(node, gm, 0)
        fqn = gm._node_name_to_scope[node_to_use_for_fqn.name][0]  # type: ignore[index]
    return fqn  # type: ignore[return-value]


def _insert_logger_after_node(
    node: Node,
    gm: GraphModule,
    logger_cls: Callable,
    logger_node_name_suffix: str,
    ref_node_name: str,
    model_name: str,
    ref_name: str,
    ref_node_target_type: str,
    results_type: str,
    index_within_arg: int,
    index_of_arg: int,
    fqn: Optional[str],
) -> Node:
    """
    Given a starting graph of

    prev_node -> node -> next_node

    This function creates a new logger_cls obj and adds it
    after node, resulting in

    prev_node -> node -> logger_obj -> next_node
    """
    # create new name
    logger_node_name = get_new_attr_name_with_prefix(
        node.name + logger_node_name_suffix
    )(gm)
    target_type = get_target_type_str(node, gm)
    # create the logger object
    logger_obj = logger_cls(
        ref_node_name,
        node.name,
        model_name,
        ref_name,
        target_type,
        ref_node_target_type,
        results_type,
        index_within_arg,
        index_of_arg,
        fqn,
    )
    # attach the logger object to the parent module
    setattr(gm, logger_node_name, logger_obj)
    logger_node = node.graph.create_node("call_module", logger_node_name, (node,), {})
    return logger_node
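

# Minimal sketch of a logger_cls compatible with the positional constructor
# call in `_insert_logger_after_node`. It is illustrative only and not part of
# the Numeric Suite API; the real tooling supplies its own logger classes, and
# the attributes kept below are a subset chosen for the example.
class _ExampleLogger(torch.nn.Module):
    def __init__(
        self,
        ref_node_name,
        prev_node_name,
        model_name,
        ref_name,
        prev_node_target_type,
        ref_node_target_type,
        results_type,
        index_within_arg,
        index_of_arg,
        fqn,
    ):
        super().__init__()
        self.ref_node_name = ref_node_name
        self.prev_node_name = prev_node_name
        self.model_name = model_name
        self.ref_name = ref_name
        self.results_type = results_type
        self.stats: List[torch.Tensor] = []

    def forward(self, x):
        # record a detached copy of the activation, then pass it through
        self.stats.append(x.detach())
        return x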


def add_loggers_to_model(
    gm: GraphModule,
    node_to_instrument_inputs_to_ref_node_name: Dict[Node, Tuple[str, str]],
    node_to_instrument_outputs_to_ref_node_name: Dict[Node, Tuple[str, str]],
    logger_cls: Callable,
    model_name: str,
) -> GraphModule:
    """
    Takes the graph of gm, adds loggers to the output
    of each node in nodes_to_instrument. Returns a GraphModule with the new
    graph.
    """

    new_graph = Graph()
    env: Dict[str, Any] = {}
    modules = dict(gm.named_modules())

    def load_arg(a):
        return map_arg(a, lambda node: env[node.name])

    for node in gm.graph.nodes:
        if node.op == "output":
            new_graph.output(map_arg(get_normalized_nth_input(node, gm, 0), load_arg))
            continue

        if (node in node_to_instrument_inputs_to_ref_node_name) or (
            node in node_to_instrument_outputs_to_ref_node_name
        ):
            fqn = _maybe_get_fqn(node, gm)

            if node in node_to_instrument_inputs_to_ref_node_name:
                ref_name, ref_node_type = node_to_instrument_inputs_to_ref_node_name[
                    node
                ]
                # Ops such as add and mul are special because either
                # one or two of the first two arguments can be tensors,
                # and if one argument is a tensor it can be first or
                # second (x + 1 versus 1 + x).
                arg_indices_to_log = get_arg_indices_of_inputs_to_log(node)
                for node_arg_idx in arg_indices_to_log:
                    node_arg = get_normalized_nth_input(node, gm, node_arg_idx)
                    if type(node_arg) == Node:
                        # create a single input logger
                        prev_node = env[node_arg.name]
                        env[node_arg.name] = _insert_logger_after_node(
                            prev_node,
                            gm,
                            logger_cls,
                            "_ns_logger_",
                            node.name,
                            model_name,
                            ref_name,
                            ref_node_type,
                            NSSingleResultValuesType.NODE_INPUT.value,
                            index_within_arg=0,
                            index_of_arg=node_arg_idx,
                            fqn=fqn,
                        )
                    elif (
                        type(node_arg) == torch.fx.immutable_collections.immutable_list
                    ):
                        # create N input loggers, one for each node
                        for arg_idx, arg in enumerate(node_arg):  # type: ignore[var-annotated, arg-type]
                            prev_node = env[arg.name]
                            env[prev_node.name] = _insert_logger_after_node(
                                prev_node,
                                gm,
                                logger_cls,
                                "_ns_logger_",
                                node.name,
                                model_name,
                                ref_name,
                                ref_node_type,
                                NSSingleResultValuesType.NODE_INPUT.value,
                                index_within_arg=arg_idx,
                                index_of_arg=node_arg_idx,
                                fqn=fqn,
                            )
                    else:
                        pass

            # ensure env is populated with base node
            # Note: runs for both inputs and outputs
            env[node.name] = new_graph.node_copy(node, load_arg)

            if node in node_to_instrument_outputs_to_ref_node_name:
                ref_name, ref_node_type = node_to_instrument_outputs_to_ref_node_name[
                    node
                ]
                # add the logger after the base node
                env[node.name] = _insert_logger_after_node(
                    env[node.name],
                    gm,
                    logger_cls,
                    "_ns_logger_",
                    node.name,
                    model_name,
                    ref_name,
                    ref_node_type,
                    NSSingleResultValuesType.NODE_OUTPUT.value,
                    index_within_arg=0,
                    index_of_arg=0,
                    fqn=fqn,
                )

        else:
            env[node.name] = new_graph.node_copy(node, load_arg)

    new_gm = GraphModule(gm, new_graph)
    return new_gm
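

# Minimal usage sketch (hypothetical model and names): instrument the output
# of a single node with the example logger above, then run data through the
# instrumented module to collect activations.
def _example_add_loggers():
    import torch.nn as nn
    from torch.fx import symbolic_trace

    class _ToyModel(nn.Module):  # hypothetical
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(4, 4)

        def forward(self, x):
            return self.fc(x)

    gm = symbolic_trace(_ToyModel())
    fc_node = next(n for n in gm.graph.nodes if n.op == "call_module")
    outputs_to_log = {fc_node: ("fc_subgraph", "nn.Linear")}
    instrumented = add_loggers_to_model(
        gm, {}, outputs_to_log, _ExampleLogger, "model_a"
    )
    instrumented(torch.randn(2, 4))
    return instrumented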


def _insert_quantize_per_tensor_node(
    prev_node_c: Node,
    node_a: Node,
    gm_b: GraphModule,
    graph_c: Graph,
    scale: Union[torch.Tensor, float],
    zero_point: Union[torch.Tensor, int],
    dtype_cast_name: str,
) -> Node:
    # copy scale
    scale_node_name = get_new_attr_name_with_prefix(node_a.name + "_input_scale_")(gm_b)
    setattr(gm_b, scale_node_name, scale)
    scale_node = graph_c.create_node(
        "get_attr", scale_node_name, (), {}, scale_node_name
    )
    # copy zero_point
    zero_point_node_name = get_new_attr_name_with_prefix(
        node_a.name + "_input_zero_point_"
    )(gm_b)
    setattr(gm_b, zero_point_node_name, zero_point)
    zero_point_node = graph_c.create_node(
        "get_attr", zero_point_node_name, (), {}, zero_point_node_name
    )
    # create the quantize_per_tensor call
    return graph_c.create_node(
        "call_function",
        torch.quantize_per_tensor,
        (prev_node_c, scale_node, zero_point_node, torch.quint8),
        {},
        dtype_cast_name,
    )
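

# For intuition, the node created above computes the eager-mode equivalent of
# torch.quantize_per_tensor(x, scale, zero_point, torch.quint8). A quick eager
# sketch with made-up qparams:
def _example_quantize_per_tensor():
    x = torch.randn(4)
    xq = torch.quantize_per_tensor(x, scale=0.1, zero_point=0, dtype=torch.quint8)
    # round-trips back to approximately x, up to quantization error
    return xq.dequantize()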


def _insert_dtype_cast_after_node(
    node_a: Node,
    node_c: Node,
    prev_node_c: Union[Node, List[Node]],
    gm_a: GraphModule,
    gm_b: GraphModule,
    graph_c: Graph,
    node_name_prefix: str,
    logger_cls: Callable,
    node_type_to_io_type_map: Dict[str, Set[NSNodeTargetType]],
) -> Union[Node, List[Node]]:
    """
    Given a starting graph C (derived from graph B) of

    ... -> prev_node_c -> node_c -> ...

    And a corresponding related node_a, inserts the correct dtype
    cast node after prev_node_c to cast into the dtype expected
    by node_a, resulting in:

            dtype_cast
           /
    ... -> prev_node_c -> node_c -> ...

    For example, if node_c is an int8 op and node_a is an fp32 op, this function
    will insert a dequant.
    """
    dtype_cast_op = None
    dtype_cast_mod_cls = None
    dtype_cast_method = None
    dtype_cast_method_dtype = None
    dtype_cast_scale = None
    dtype_cast_zero_point = None
    node_input_type_a, _node_output_type_a = get_node_first_input_and_output_type(
        node_a, gm_a, logger_cls, node_type_to_io_type_map
    )
    node_input_type_c, _node_output_type_c = get_node_first_input_and_output_type(
        node_c, gm_b, logger_cls, node_type_to_io_type_map
    )

    if (
        (
            node_input_type_a == NodeInputOrOutputType.FP32
            and node_input_type_c == NodeInputOrOutputType.INT8
        )
        or (
            node_input_type_a == NodeInputOrOutputType.FP32
            and node_input_type_c == NodeInputOrOutputType.FP16
        )
        or
        # TODO(future PR): determine the actual dtype of node_c,
        # the current code only works because dequantize works with
        # multiple input dtypes.
        (
            node_input_type_a == NodeInputOrOutputType.FP32
            and node_input_type_c == NodeInputOrOutputType.FP32_OR_INT8
        )
    ):
        dtype_cast_op = torch.dequantize
    elif (
        node_input_type_a == node_input_type_c
        and node_input_type_a != NodeInputOrOutputType.UNKNOWN
    ):
        dtype_cast_mod_cls = torch.nn.Identity
    elif (
        node_input_type_a == NodeInputOrOutputType.INT8
        and node_input_type_c == NodeInputOrOutputType.FP32
    ):
        # int8 shadows fp32, the dtype cast needs to quantize to int8
        # with the right qparams.
        node_a_input_qparams = get_node_input_qparams(
            node_a, gm_a, node_type_to_io_type_map
        )
        if node_a_input_qparams is not None:
            dtype_cast_op = torch.quantize_per_tensor  # type: ignore[assignment]
            dtype_cast_scale, dtype_cast_zero_point = node_a_input_qparams
    elif (
        node_input_type_a == NodeInputOrOutputType.FP16
        and node_input_type_c == NodeInputOrOutputType.FP32
    ):
        dtype_cast_method = "to"
        dtype_cast_method_dtype = torch.float16
    else:
        raise AssertionError(
            f"dtype cast from {node_input_type_c} {node_c.format_node()} to "
            + f"{node_input_type_a} {node_a.format_node()} needs to be implemented"
        )

    if isinstance(prev_node_c, Node):
        new_dtype_cast_name = get_new_attr_name_with_prefix(node_name_prefix)(gm_b)
        if dtype_cast_op:
            if dtype_cast_scale is not None and dtype_cast_zero_point is not None:
                return _insert_quantize_per_tensor_node(
                    prev_node_c,
                    node_a,
                    gm_b,
                    graph_c,
                    dtype_cast_scale,
                    dtype_cast_zero_point,
                    new_dtype_cast_name,
                )
            else:
                return graph_c.create_node(
                    "call_function",
                    dtype_cast_op,
                    (prev_node_c,),
                    {},
                    new_dtype_cast_name,
                )
        elif dtype_cast_method:
            return graph_c.create_node(
                "call_method",
                dtype_cast_method,
                (prev_node_c, dtype_cast_method_dtype),
                {},
                new_dtype_cast_name,
            )
        else:
            assert dtype_cast_mod_cls
            dtype_cast_mod = dtype_cast_mod_cls()
            setattr(gm_b, new_dtype_cast_name, dtype_cast_mod)
            return graph_c.create_node(
                "call_module",
                new_dtype_cast_name,
                (prev_node_c,),
                {},
                new_dtype_cast_name,
            )
    elif isinstance(prev_node_c, list):
        results = []
        for prev_node_c_inner in prev_node_c:
            new_dtype_cast_name = get_new_attr_name_with_prefix(node_name_prefix)(gm_b)
            if dtype_cast_op:
                # TODO(future PR): add handling for quantize_per_tensor
                new_dtype_cast_node = graph_c.create_node(
                    "call_function",
                    dtype_cast_op,
                    (prev_node_c_inner,),
                    {},
                    new_dtype_cast_name,
                )
                results.append(new_dtype_cast_node)
            else:
                assert dtype_cast_mod_cls
                dtype_cast_mod = dtype_cast_mod_cls()
                setattr(gm_b, new_dtype_cast_name, dtype_cast_mod)
                new_dtype_cast_node = graph_c.create_node(
                    "call_module",
                    new_dtype_cast_name,
                    (prev_node_c_inner,),
                    {},
                    new_dtype_cast_name,
                )
                results.append(new_dtype_cast_node)
        return results
    else:
        raise AssertionError(f"type {type(prev_node_c)} is not handled")
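

# Eager-mode intuition for the cast table above, with made-up qparams: an int8
# value feeding an fp32 shadow is dequantized, an fp32 value feeding an int8
# shadow is quantized with known qparams, and an fp32 value feeding an fp16
# shadow is cast with Tensor.to.
def _example_dtype_casts():
    xq = torch.quantize_per_tensor(torch.randn(4), 0.1, 0, torch.quint8)
    x_fp32 = torch.dequantize(xq)
    x_int8 = torch.quantize_per_tensor(x_fp32, 0.1, 0, torch.quint8)
    x_fp16 = x_fp32.to(torch.float16)
    return x_fp32, x_int8, x_fp16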


# TODO(future PR): look into using copy_node API instead
def _copy_node_from_a_to_c(
    node_a: Node,
    gm_a: GraphModule,
    gm_b: GraphModule,
    graph_c: Graph,
) -> Node:
    """
    Simple copy of node_a to graph_c.
    """
    if node_a.op == "get_attr":
        node_a_copy_name = get_new_attr_name_with_prefix(node_a.name + "_shadow_copy_")(
            gm_b
        )
        node_a_obj = getattr_from_fqn(gm_a, node_a.target)  # type: ignore[arg-type]
        if torch.is_tensor(node_a_obj):
            node_a_obj = node_a_obj.detach()
        setattr(gm_b, node_a_copy_name, node_a_obj)
        node_a_copy = graph_c.create_node(
            node_a.op, node_a_copy_name, (), {}, node_a_copy_name
        )
        return node_a_copy
    elif node_a.op == "call_method":
        assert node_a.target in (
            "dequantize",
            "to",
        ), f"target {node_a.target} is not implemented"
        if node_a.target == "dequantize":
            arg_copy = _copy_node_from_a_to_c(
                get_normalized_nth_input(node_a, gm_a, 0), gm_a, gm_b, graph_c
            )  # type: ignore[arg-type]
            node_a_copy_name = get_new_attr_name_with_prefix(
                node_a.name + "_shadow_copy_"
            )(gm_b)
            node_a_copy = graph_c.create_node(
                node_a.op, node_a.target, (arg_copy,), {}, node_a_copy_name
            )
            return node_a_copy
        else:  # to
            arg_copy = _copy_node_from_a_to_c(
                get_normalized_nth_input(node_a, gm_a, 0), gm_a, gm_b, graph_c
            )  # type: ignore[arg-type]
            node_a_copy_name = get_new_attr_name_with_prefix(
                node_a.name + "_shadow_copy_"
            )(gm_b)
            node_a_copy = graph_c.create_node(
                node_a.op,
                node_a.target,
                (arg_copy, get_normalized_nth_input(node_a, gm_a, 1)),
                {},
                node_a_copy_name,
            )
            return node_a_copy

    else:
        raise AssertionError(
            f"handling of node {node_a.format_node()} with op {node_a.op} is not implemented"
        )
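

# Quick sketch of the get_attr branch above: copy a tensor attribute from one
# module onto another under a fresh, collision-free name. The modules and the
# attribute are hypothetical.
def _example_copy_attr():
    src, dst = torch.nn.Module(), torch.nn.Module()
    src.weight = torch.nn.Parameter(torch.randn(2, 2))
    copy_name = get_new_attr_name_with_prefix("weight_shadow_copy_")(dst)
    setattr(dst, copy_name, src.weight.detach())
    return copy_name  # e.g. 'weight_shadow_copy_0'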


def _can_insert_copy_of_subgraph_a(
    subgraph_a: NSSubgraph,
    gm_a: GraphModule,
    num_non_param_args_node_a: int,
) -> bool:
    """
    This function returns `False` if the input subgraph cannot be copied by
    `_insert_copy_of_subgraph_a_after_input_node_c`. This usually means
    that there is a corner case logic for which copy is not yet implemented.
    """
    # populate the list of nodes we need to check
    nodes = []
    cur_node = subgraph_a.end_node
    while cur_node != subgraph_a.start_node:
        nodes.append(cur_node)
        cur_node = get_normalized_nth_input(cur_node, gm_a, 0)  # type: ignore[assignment]
    nodes.append(cur_node)
    nodes.reverse()

    def _can_insert(node_a_arg, gm_a):
        if isinstance(node_a_arg, Node):
            arg_a = return_first_non_observer_node(node_a_arg, gm_a)
            if arg_a.op == "call_method":
                return arg_a.target in ("dequantize", "to")
            elif arg_a.op == "get_attr":
                return True
            else:
                return False
        elif isinstance(node_a_arg, (list, tuple)):
            for el in node_a_arg:
                if not isinstance(el, Node):
                    return False
            return True

    # For each node, check if we handle the copy behavior. This follows the
    # logic in `_insert_copy_of_subgraph_a_after_input_node_c`.
    for node_a in nodes:
        local_num_non_param_args_node_a = (
            num_non_param_args_node_a if node_a is nodes[0] else 1
        )

        norm_args_kwargs = node_a.normalized_arguments(
            gm_a, normalize_to_only_use_kwargs=True
        )
        if norm_args_kwargs is not None:
            norm_args, norm_kwargs = norm_args_kwargs
        else:
            norm_args, norm_kwargs = node_a.args, node_a.kwargs

        cur_idx = 0

        while cur_idx < len(norm_args):
            if cur_idx == 0:
                pass
            elif cur_idx == 1 and local_num_non_param_args_node_a == 2:
                pass
            else:
                if not _can_insert(norm_args[cur_idx], gm_a):
                    return False
            cur_idx += 1

        for kwarg_val in norm_kwargs.values():
            # stitch the inputs from base graph
            if cur_idx == 0:
                pass
            elif cur_idx == 1 and local_num_non_param_args_node_a == 2:
                pass
            else:
                if not _can_insert(kwarg_val, gm_a):
                    return False
            cur_idx += 1

    return True


def _insert_copy_of_subgraph_a_after_input_node_c(
    input_node_c: Union[Node, List[Node]],
    input_node_c_2: Optional[Union[Node, List[Node]]],
    subgraph_a: NSSubgraph,
    gm_a: GraphModule,
    gm_b: GraphModule,
    node_name_prefix: str,
) -> Node:
    """
    TODO(before land): real docblock
    """
    if isinstance(input_node_c, Node):
        graph_c = input_node_c.graph
    else:
        assert isinstance(input_node_c, list)
        graph_c = input_node_c[0].graph

    # create a sequential list of the subgraphs' nodes from start to end,
    # because we need to add the nodes to graph C in non-reverse order
    nodes_of_a = [subgraph_a.end_node]
    cur_node = subgraph_a.end_node
    while cur_node != subgraph_a.start_node:
        cur_node = get_normalized_nth_input(cur_node, gm_a, 0)  # type: ignore[assignment]
        nodes_of_a.insert(0, cur_node)

    # go through nodes of a in order, and insert them into the graph of c
    # sequentially
    cur_node_a = nodes_of_a[0]
    cur_node_c = _insert_copy_of_node_a_after_input_node_c(
        input_node_c, input_node_c_2, cur_node_a, gm_a, gm_b, node_name_prefix
    )
    for cur_idx_a in range(1, len(nodes_of_a)):
        cur_node_a = nodes_of_a[cur_idx_a]
        prev_node_c = cur_node_c  # previous added node is the input to next node
        cur_node_c = _insert_copy_of_node_a_after_input_node_c(
            prev_node_c,
            # TODO(future PR): enable multiple inputs for nodes which are not at start of subgraph
            None,
            cur_node_a,
            gm_a,
            gm_b,
            node_name_prefix,
        )
    # return the last inserted node
    return cur_node_c


def _insert_copy_of_node_a_after_input_node_c(
    input_node_c: Union[Node, List[Node]],
    input_node_c_2: Optional[Union[Node, List[Node]]],
    node_a: Node,
    gm_a: GraphModule,
    gm_b: GraphModule,
    node_name_prefix: str,
) -> Node:
    """
    Assume that node_a from graph_a has
      args (input, (input2)?, arg1, ...), and
      kwargs {kw0: kwarg0, ...}

    Note: input2 is optional. If it is None, we assume that the op
    has a single non-param input. If it is specified, we assume that the op
    has two non-param inputs.

    Copies the underlying values of arg1..argn and kwarg0..kwargn into gm_b,
    and creates the corresponding nodes in graph_c. Note: observers are ignored,
    so if an arg is an observer we navigate up until we find a non-observer parent.

    If node_a is a call_module, points the module pointed to by node_a to gm_b.

    Creates the copy of node_a in graph_c, with input as the first arg,
    and all other args and kwargs pointing to the copies of the objects
    in gm_b created above.

    An example in pictures:

    graph A:
    ========

    input -------------> node_a
                        / / /
    (input_2)?---------/ / /
                        / /
    weight -> weight_obs /
                        /
    bias ----------------

    graph C (derived from B):
    =========================

    input_node_c --> node_a_copy
                    / / /
    (input_node_c_2)? / /
                      / /
    weight_copy -----/ /
                      /
    bias_copy -------/
    """
    if isinstance(input_node_c, Node):
        graph_c = input_node_c.graph
    else:
        assert isinstance(input_node_c, list)
        graph_c = input_node_c[0].graph

    norm_args_kwargs = node_a.normalized_arguments(
        gm_a, normalize_to_only_use_kwargs=True
    )
    if norm_args_kwargs is not None:
        norm_args, norm_kwargs = norm_args_kwargs
    else:
        norm_args, norm_kwargs = node_a.args, node_a.kwargs

    new_args = []
    new_kwargs = {}

    def _copy_arg(arg):
        # copy the other inputs from the other graph
        if isinstance(arg, Node):
            arg = return_first_non_observer_node(arg, gm_a)
            arg = _copy_node_from_a_to_c(arg, gm_a, gm_b, graph_c)
            return arg
        elif isinstance(arg, (int, float, torch.dtype)):
            return arg
        elif isinstance(arg, (list, tuple)):
            for el in arg:
                assert not isinstance(
                    el, Node
                ), "handling of Node inside list is not implemented"
            return arg
        else:
            raise AssertionError(
                f"handling for arg of type {type(arg)} is not implemented"
            )

    cur_idx = 0

    while cur_idx < len(norm_args):
        if cur_idx == 0:
            new_arg = input_node_c
        elif cur_idx == 1 and input_node_c_2 is not None:
            new_arg = input_node_c_2
        else:
            new_arg = _copy_arg(norm_args[cur_idx])
        new_args.append(new_arg)
        cur_idx += 1

    for kwarg_name, kwarg_val in norm_kwargs.items():
        # stitch the inputs from base graph
        if cur_idx == 0:
            new_kwargs[kwarg_name] = input_node_c
        elif cur_idx == 1 and input_node_c_2 is not None:
            new_kwargs[kwarg_name] = input_node_c_2
        else:
            new_kwargs[kwarg_name] = _copy_arg(kwarg_val)
        cur_idx += 1

    new_args = tuple(new_args)  # type: ignore[assignment]

    node_a_shadows_c_name = get_new_attr_name_with_prefix(node_name_prefix)(gm_b)

    if node_a.op == "call_module":
        # if target is a module, we point to the module from gm_b
        new_mod_copy_name = get_new_attr_name_with_prefix(node_name_prefix)(gm_b)
        # fetch the corresponding module from gm_a
        assert isinstance(node_a.target, str)
        mod_a = getattr_from_fqn(gm_a, node_a.target)
        setattr(gm_b, new_mod_copy_name, mod_a)
        node_a_shadows_c = graph_c.create_node(
            node_a.op, new_mod_copy_name, new_args, new_kwargs, node_a_shadows_c_name  # type: ignore[arg-type]
        )
        return node_a_shadows_c
    else:
        assert node_a.op in ("call_function", "call_method")
        node_a_shadows_c = graph_c.create_node(
            node_a.op, node_a.target, new_args, new_kwargs, node_a_shadows_c_name  # type: ignore[arg-type]
        )
        return node_a_shadows_c
|
| 707 |
+
|
| 708 |
+
def create_a_shadows_b(
|
| 709 |
+
name_a: str,
|
| 710 |
+
gm_a: GraphModule,
|
| 711 |
+
name_b: str,
|
| 712 |
+
gm_b: GraphModule,
|
| 713 |
+
matched_subgraph_pairs: Dict[str, Tuple[NSSubgraph, NSSubgraph]],
|
| 714 |
+
logger_cls: Callable,
|
| 715 |
+
should_log_inputs: bool,
|
| 716 |
+
node_type_to_io_type_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
|
| 717 |
+
) -> GraphModule:
|
| 718 |
+
"""
|
| 719 |
+
Creates a new GraphModule consisting of the graph of C, with the meaningful
|
| 720 |
+
nodes of A shadowing the corresponding nodes of B. For example,
|
| 721 |
+
|
| 722 |
+
Graph A:
|
| 723 |
+
a0 -> op0_fp32 -> a1 -> op1_fp32 -> a2
|
| 724 |
+
|
| 725 |
+
Graph B:
|
| 726 |
+
b0 -> op0_int8 -> b1 -> op1_int8 -> b2
|
| 727 |
+
|
| 728 |
+
matched_node_pairs: {'op0': (op0_fp32, op0_int8), 'op1': (op1_fp32, op1_int8)}
|
| 729 |
+
|
| 730 |
+
Graph C (A shadows B):
|
| 731 |
+
|
| 732 |
+
/ dequant0 -> op0_fp32 -> logger_a_0 / dequant_1 -> op1_fp32 -> logger_a_1
|
| 733 |
+
/ /
|
| 734 |
+
b0 -------------> op0_int8 -> logger_b_0 --------------> op1_int8 -> logger_b_1
|
| 735 |
+
|
| 736 |
+
In a nutshell, this function does the following for each node pair:
|
| 737 |
+
* copies the necessary attributes and modules from gm_a to gm_b,
|
| 738 |
+
keeping names unique
|
| 739 |
+
* adds a dtype cast op (dequant, quant, etc)
|
| 740 |
+
* adds a copy of node_a in gm_b's graph
|
| 741 |
+
* adds loggers to the outputs of node_a and node_b
|
| 742 |
+
"""
|
| 743 |
+
|
| 744 |
+
if node_type_to_io_type_map is None:
|
| 745 |
+
node_type_to_io_type_map = get_node_type_to_io_type_map()
|
| 746 |
+
|
| 747 |
+
# graph_c is the graph created from copying the nodes of graph_b and inserting
|
| 748 |
+
# the shadows with the nodes copied from graph_a
|
| 749 |
+
graph_c = Graph()
|
| 750 |
+
env_c: Dict[str, Any] = {}
|
| 751 |
+
modules = dict(gm_b.named_modules())
|
| 752 |
+
|
| 753 |
+
def load_arg(a):
|
| 754 |
+
return map_arg(a, lambda node: env_c[node.name])
|
| 755 |
+
|
| 756 |
+
start_node_b_to_matched_subgraph_a_and_name = {}
|
| 757 |
+
end_node_b_to_matched_subgraph_a_and_name = {}
|
| 758 |
+
for match_name, match in matched_subgraph_pairs.items():
|
| 759 |
+
subgraph_a, subgraph_b = match
|
| 760 |
+
ref_node_type_a = get_target_type_str(subgraph_a.base_op_node, gm_a)
|
| 761 |
+
ref_node_type_b = get_target_type_str(subgraph_b.base_op_node, gm_b)
|
| 762 |
+
start_node_b_to_matched_subgraph_a_and_name[subgraph_b.start_node] = (
|
| 763 |
+
subgraph_a,
|
| 764 |
+
match_name,
|
| 765 |
+
ref_node_type_a,
|
| 766 |
+
ref_node_type_b,
|
| 767 |
+
)
|
| 768 |
+
end_node_b_to_matched_subgraph_a_and_name[subgraph_b.end_node] = (
|
| 769 |
+
subgraph_a,
|
| 770 |
+
match_name,
|
| 771 |
+
ref_node_type_a,
|
| 772 |
+
ref_node_type_b,
|
| 773 |
+
)
|
| 774 |
+
|
| 775 |
+
for node_b in gm_b.graph.nodes:
|
| 776 |
+
if node_b.op == "output":
|
| 777 |
+
graph_c.output(map_arg(node_b.args[0], load_arg))
|
| 778 |
+
continue
|
| 779 |
+
|
| 780 |
+
# calculate the flags to determine what to do with this node
|
| 781 |
+
node_b_is_start_node = node_b in start_node_b_to_matched_subgraph_a_and_name
|
| 782 |
+
node_b_is_end_node = node_b in end_node_b_to_matched_subgraph_a_and_name
|
| 783 |
+
|
| 784 |
+
if node_b_is_start_node or node_b_is_end_node:
|
| 785 |
+
if node_b_is_start_node:
|
| 786 |
+
(
|
| 787 |
+
subgraph_a,
|
| 788 |
+
ref_name,
|
| 789 |
+
ref_node_type_a,
|
| 790 |
+
ref_node_type_b,
|
| 791 |
+
) = start_node_b_to_matched_subgraph_a_and_name[node_b]
|
| 792 |
+
else:
|
| 793 |
+
assert node_b_is_end_node
|
| 794 |
+
(
|
| 795 |
+
subgraph_a,
|
| 796 |
+
ref_name,
|
| 797 |
+
ref_node_type_a,
|
| 798 |
+
ref_node_type_b,
|
| 799 |
+
) = end_node_b_to_matched_subgraph_a_and_name[node_b]
|
| 800 |
+
|
| 801 |
+
all_op_types_support_shadowing = op_type_supports_shadowing(
|
| 802 |
+
subgraph_a.start_node
|
| 803 |
+
) and op_type_supports_shadowing(node_b)
|
| 804 |
+
if not all_op_types_support_shadowing:
|
| 805 |
+
print(
|
| 806 |
+
f"skipping shadow loggers for node_b: {get_target_type_str(node_b, gm_b)}"
|
| 807 |
+
+ f", start_node_a: {get_target_type_str(subgraph_a.start_node, gm_a)}"
|
| 808 |
+
+ ", unsupported"
|
| 809 |
+
)
|
| 810 |
+
env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
|
| 811 |
+
continue
|
| 812 |
+
|
| 813 |
+
# For both start_node and end_node verify that we know how to do
|
| 814 |
+
# the dtype cast. If we do not, skip.
|
| 815 |
+
(
|
| 816 |
+
node_input_type_a,
|
| 817 |
+
node_output_type_a,
|
| 818 |
+
) = get_node_first_input_and_output_type(
|
| 819 |
+
subgraph_a.start_node, gm_a, logger_cls, node_type_to_io_type_map
|
| 820 |
+
)
|
| 821 |
+
(
|
| 822 |
+
node_input_type_b,
|
| 823 |
+
node_output_type_b,
|
| 824 |
+
) = get_node_first_input_and_output_type(
|
| 825 |
+
node_b, gm_b, logger_cls, node_type_to_io_type_map
|
| 826 |
+
)
|
| 827 |
+
node_io_types_known_a_and_b = (
|
| 828 |
+
node_input_type_a != NodeInputOrOutputType.UNKNOWN
|
| 829 |
+
and node_output_type_a != NodeInputOrOutputType.UNKNOWN
|
| 830 |
+
and node_input_type_b != NodeInputOrOutputType.UNKNOWN
|
| 831 |
+
and node_output_type_b != NodeInputOrOutputType.UNKNOWN
|
| 832 |
+
)
|
| 833 |
+
if not node_io_types_known_a_and_b:
|
| 834 |
+
print(
|
| 835 |
+
f"skipping shadow loggers for node_b: {get_target_type_str(node_b, gm_b)}"
|
| 836 |
+
+ f", start_node_a: {get_target_type_str(subgraph_a.start_node, gm_a)}"
|
| 837 |
+
+ ", unknown dtype cast"
|
| 838 |
+
)
|
| 839 |
+
env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
|
| 840 |
+
continue
|
| 841 |
+
|
| 842 |
+
# If we are shadowing from fp32 to int8, we need to insert
|
| 843 |
+
# quantize_per_tensor call with qparams from the previous node.
|
| 844 |
+
# Only do this if we are able to infer these qparams from the graph.
|
| 845 |
+
if (
|
| 846 |
+
node_input_type_a == NodeInputOrOutputType.INT8
|
| 847 |
+
and node_input_type_b == NodeInputOrOutputType.FP32
|
| 848 |
+
):
|
| 849 |
+
node_a_input_qparams = get_node_input_qparams(
|
| 850 |
+
subgraph_a.start_node, gm_a, node_type_to_io_type_map
|
| 851 |
+
)
|
| 852 |
+
if not node_a_input_qparams:
|
| 853 |
+
print(
|
| 854 |
+
f"skipping shadow loggers for node_b: {get_target_type_str(node_b, gm_b)}"
|
| 855 |
+
+ f", start_node_a: {get_target_type_str(subgraph_a.start_node, gm_a)}"
|
| 856 |
+
+ ", unknown input qparams"
|
| 857 |
+
)
|
| 858 |
+
env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
|
| 859 |
+
continue
|
| 860 |
+
|
| 861 |
+
num_non_param_args_node_a = get_number_of_non_param_args(
|
| 862 |
+
subgraph_a.start_node, gm_a
|
| 863 |
+
)
|
| 864 |
+
if not _can_insert_copy_of_subgraph_a(
|
| 865 |
+
subgraph_a, gm_a, num_non_param_args_node_a
|
| 866 |
+
):
|
| 867 |
+
print(
|
| 868 |
+
f"skipping shadow loggers for node_b: {get_target_type_str(node_b, gm_b)}"
|
| 869 |
+
+ f", start_node_a: {get_target_type_str(subgraph_a.start_node, gm_a)}"
|
| 870 |
+
+ ", unhandled logic in subgraph copy"
|
| 871 |
+
)
|
| 872 |
+
env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
|
| 873 |
+
continue
|
| 874 |
+
|
| 875 |
+
fqn_base_a = _maybe_get_fqn(subgraph_a.base_op_node, gm_a)
|
| 876 |
+
fqn_base_b = _maybe_get_fqn(subgraph_b.base_op_node, gm_b) # type: ignore[possibly-undefined]
|
| 877 |
+
|
| 878 |
+
if node_b_is_start_node:
|
| 879 |
+
# if necessary, log the input of node_c
|
| 880 |
+
if should_log_inputs:
|
| 881 |
+
prev_node_b = get_normalized_nth_input(node_b, gm_b, 0)
|
| 882 |
+
if isinstance(prev_node_b, Node):
|
| 883 |
+
prev_node_c = env_c[prev_node_b.name]
|
| 884 |
+
env_c[prev_node_c.name] = _insert_logger_after_node(
|
| 885 |
+
prev_node_c,
|
| 886 |
+
gm_b,
|
| 887 |
+
logger_cls,
|
| 888 |
+
"_ns_logger_b_inp_",
|
| 889 |
+
node_b.name,
|
| 890 |
+
name_b,
|
| 891 |
+
ref_name,
|
| 892 |
+
ref_node_type_b,
|
| 893 |
+
NSSingleResultValuesType.NODE_INPUT.value,
|
| 894 |
+
index_within_arg=0,
|
| 895 |
+
index_of_arg=0,
|
| 896 |
+
fqn=fqn_base_b,
|
| 897 |
+
)
|
| 898 |
+
elif isinstance(prev_node_b, list):
|
| 899 |
+
# first, save the prev_node instances, because they
|
| 900 |
+
# will be overwritten in the env after the first logger
|
| 901 |
+
# is added
|
| 902 |
+
prev_node_c_list = [env_c[arg.name] for arg in prev_node_b]
|
| 903 |
+
|
| 904 |
+
for arg_idx, arg in enumerate(prev_node_b):
|
| 905 |
+
prev_node_c = prev_node_c_list[arg_idx]
|
| 906 |
+
env_c[prev_node_c.name] = _insert_logger_after_node(
|
| 907 |
+
prev_node_c,
|
| 908 |
+
gm_b,
|
| 909 |
+
logger_cls,
|
| 910 |
+
"_ns_logger_b_inp_",
|
| 911 |
+
node_b.name,
|
| 912 |
+
name_b,
|
| 913 |
+
ref_name,
|
| 914 |
+
ref_node_type_b,
|
| 915 |
+
NSSingleResultValuesType.NODE_INPUT.value,
|
| 916 |
+
index_within_arg=arg_idx,
|
| 917 |
+
index_of_arg=0,
|
| 918 |
+
fqn=fqn_base_b,
|
| 919 |
+
)
|
| 920 |
+
else:
|
| 921 |
+
# logging of inputs which are not lists is not supported yet
|
| 922 |
+
raise AssertionError(
|
| 923 |
+
f"type {type(prev_node_b)} is not handled yet"
|
| 924 |
+
)
|
| 925 |
+
# subgraph so far:
|
| 926 |
+
#
|
| 927 |
+
# (prev_node_c)+ -> (logger_c_input)?
|
| 928 |
+
|
| 929 |
+
            # Note: this if statement is always True, spelling it out to clarify code
            # intent.
            if node_b_is_start_node or node_b_is_end_node:
                # ensure env_c is populated with base node
                env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
                node_c = env_c[node_b.name]

                # after this point,
                #
                # node_a is the original node from graph_a, with parent module gm_a
                # node_b is the original node from graph_b, with parent module gm_b
                # node_c is the copy of node_b in graph_c
                #
                # subgraph so far:
                #
                # (prev_node_c)+ -> (logger_c_input)? -> node_start_c

                if node_b_is_start_node:
                    # cast dtype from the dtype of node_c's input to the dtype of
                    # node_a's input (dequant, etc)
                    # prev_node_c = node_c.args[0]
                    prev_node_c = get_normalized_nth_input(node_c, gm_b, 0)  # type: ignore[possibly-undefined]
                    if should_log_inputs:
                        # skip the input logger when inserting a dtype cast
                        if isinstance(prev_node_c, Node):
                            prev_node_c = get_normalized_nth_input(prev_node_c, gm_b, 0)
                        elif isinstance(prev_node_c, list):
                            prev_node_c = [
                                get_normalized_nth_input(arg, gm_b, 0)
                                for arg in prev_node_c
                            ]
                    dtype_cast_node = _insert_dtype_cast_after_node(
                        subgraph_a.start_node,
                        node_c,
                        prev_node_c,
                        gm_a,
                        gm_b,
                        graph_c,
                        node_b.name + "_dtype_cast_",
                        logger_cls,
                        node_type_to_io_type_map,
                    )
                    # note: not inserting to env_c because all nodes which use the dtype
                    #   casts are copied from graph_a
                    #
                    # subgraph so far:
                    #
                    #           (dtype_cast_node)+
                    #                  /
                    # (prev_node_c)+ -> (logger_c_input)? -> node_start_c

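                    # Note: the dtype cast is needed because the two subgraphs may
                    # operate in different numeric domains (for example, an fp32
                    # shadow of an int8 op); the inserted cast (dequantize, etc.)
                    # bridges node_c's input into what subgraph_a's copy expects.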
                    # if input logging is enabled, log the input to the subgraph
                    if should_log_inputs:
                        # ref_node_name is intentionally left blank here; it is
                        # filled in with the first node of the subgraph copy once
                        # that copy exists (see below)
                        ref_node_name = ""
                        if isinstance(dtype_cast_node, Node):
                            dtype_cast_node = _insert_logger_after_node(
                                dtype_cast_node,
                                gm_b,
                                logger_cls,
                                "_ns_logger_a_inp_",
                                ref_node_name,
                                name_a,
                                ref_name,
                                ref_node_type_a,
                                NSSingleResultValuesType.NODE_INPUT.value,
                                index_within_arg=0,
                                index_of_arg=0,
                                fqn=fqn_base_a,
                            )
                            input_logger: Union[Node, List[Node]] = dtype_cast_node
                        else:
                            assert isinstance(dtype_cast_node, list)
                            new_loggers = []
                            for dtype_cast_idx, dtype_cast_node_inner in enumerate(
                                dtype_cast_node
                            ):
                                dtype_cast_logger = _insert_logger_after_node(
                                    dtype_cast_node_inner,
                                    gm_b,
                                    logger_cls,
                                    "_ns_logger_a_inp_",
                                    ref_node_name,
                                    name_a,
                                    ref_name,
                                    ref_node_type_a,
                                    NSSingleResultValuesType.NODE_INPUT.value,
                                    index_within_arg=dtype_cast_idx,
                                    index_of_arg=0,
                                    fqn=fqn_base_a,
                                )
                                new_loggers.append(dtype_cast_logger)
                            dtype_cast_node = new_loggers
                            input_logger = dtype_cast_node
                        # subgraph so far:
                        #
                        #       (dtype_cast_node)+ -> (logger_a_input)?
                        #                  /
                        # prev_node_c -> (logger_c_input)? -> node_start_c

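                    # Note: when input logging is on, dtype_cast_node now refers
                    # to the input logger node(s), so the subgraph copy inserted
                    # below consumes the logged (post-cast) values.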
                    # hook up the new mod_a copy to be in the graph, receiving the
                    # same inputs as mod_b does, with dtype cast to match a
                    # Some ops, such as LSTMs, have two non-param inputs. If we have
                    # such an op, pass the second param as well. Note: dtype casting
                    # for the second param is not implemented yet, it can be added
                    # later if there is a use case.
                    node_c_second_non_param_arg = None
                    num_non_param_args_node_a = get_number_of_non_param_args(
                        subgraph_a.start_node, gm_a
                    )
                    if num_non_param_args_node_a == 2:
                        # node_c_second_non_param_arg = node_c.args[1]
                        node_c_second_non_param_arg = get_normalized_nth_input(
                            node_c, gm_b, 1
                        )
                    node_a_shadows_c = _insert_copy_of_subgraph_a_after_input_node_c(
                        dtype_cast_node,
                        node_c_second_non_param_arg,
                        subgraph_a,
                        gm_a,
                        gm_b,
                        node_c.name + "_shadow_copy_",
                    )
                    env_c[node_a_shadows_c.name] = node_a_shadows_c
                    # subgraph so far:
                    #
                    #       dtype_cast_node -> (logger_a_input)? -> subgraph_a_copy(args/kwargs not shown)
                    #                  /
                    # (prev_node_c)+ -> (logger_c_input)? -> node_start_c

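                    # Note: _insert_copy_of_subgraph_a_after_input_node_c (defined
                    # earlier in this file) roughly copies each node of subgraph_a
                    # into graph_c, duplicating the modules and parameters it
                    # references from gm_a onto gm_b, and returns the last node of
                    # the copied chain.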
                    if should_log_inputs:
                        # When we created the input logger, we left the ref_node_name
                        # as an empty string, because the subgraph copy did not exist
                        # yet. Now that the subgraph copy exists, we modify this name
                        # to its true value.
                        # Note: the alternative to this is to create the input logger
                        # after creating the subgraph, which is slightly more
                        # complicated. This is the lesser of two evils.
                        # input_logger = env_c[dtype_cast_node.name]
                        # Find the first node in the subgraph
                        cur_node = node_a_shadows_c
                        while get_normalized_nth_input(cur_node, gm_b, 0) != input_logger:  # type: ignore[possibly-undefined]
                            cur_node = get_normalized_nth_input(cur_node, gm_b, 0)  # type: ignore[assignment]
                        if isinstance(input_logger, Node):
                            input_logger_mod = getattr(gm_b, input_logger.name)
                            input_logger_mod.ref_node_name = cur_node.name
                        else:
                            assert isinstance(input_logger, list)
                            for input_logger_inner in input_logger:
                                input_logger_mod = getattr(gm_b, input_logger_inner.name)
                                input_logger_mod.ref_node_name = cur_node.name

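                    # Note: the output loggers attached below for the a-copy and
                    # for node_b share the same ref_name, which is what lets the
                    # results of the two models be paired up when logger data is
                    # extracted.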
                    # hook up a logger to the mod_a copy
                    env_c[node_a_shadows_c.name] = _insert_logger_after_node(
                        env_c[node_a_shadows_c.name],
                        gm_b,
                        logger_cls,
                        "_ns_logger_a_",
                        node_a_shadows_c.name,
                        name_a,
                        ref_name,
                        ref_node_type_a,
                        NSSingleResultValuesType.NODE_OUTPUT.value,
                        index_within_arg=0,
                        index_of_arg=0,
                        fqn=fqn_base_a,
                    )
                    # subgraph so far:
                    #
                    #       dtype_cast_node -> (logger_a_input)? -> subgraph_a_copy -> logger_a
                    #                  /
                    # (prev_node_c)+ -> (logger_c_input)? -> node_start_c

                if node_b_is_end_node:
                    # hook up a logger to the mod_b copy
                    env_c[node_b.name] = _insert_logger_after_node(
                        env_c[node_b.name],
                        gm_b,
                        logger_cls,
                        "_ns_logger_b_",
                        node_b.name,
                        name_b,
                        ref_name,
                        ref_node_type_b,
                        NSSingleResultValuesType.NODE_OUTPUT.value,
                        index_within_arg=0,
                        index_of_arg=0,
                        fqn=fqn_base_b,
                    )
                    # subgraph so far:
                    #
                    #       dtype_cast_node -> (logger_a_input)? -> subgraph_a_copy -> logger_a
                    #                  /
                    # (prev_node_c+) -> (logger_c_input)? -> node_start_c -> ... -> node_end_c -> logger_c
                    #
                    # Note: node_start_c may be the same node as node_end_c, or they
                    #   may have nodes in between.

        else:
            env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)

    gm_c = GraphModule(gm_b, graph_c)
    return gm_c
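For context, `create_a_shadows_b` is the engine behind the public `add_shadow_loggers` and `extract_shadow_logger_info` APIs in `torch.ao.ns._numeric_suite_fx`. A minimal usage sketch follows; the toy model, input shapes, and qconfig choice are illustrative assumptions, not taken from this diff:

```python
import copy

import torch
import torch.ao.ns._numeric_suite_fx as ns
import torch.ao.quantization.quantize_fx as quantize_fx
from torch.ao.quantization import get_default_qconfig_mapping

# toy fp32 model and example input (illustrative assumptions)
m = torch.nn.Sequential(torch.nn.Conv2d(1, 1, 1), torch.nn.ReLU()).eval()
example_inputs = (torch.randn(1, 1, 4, 4),)

# FX graph mode quantization: prepare, calibrate, convert a copy
mp = quantize_fx.prepare_fx(m, get_default_qconfig_mapping(), example_inputs)
mp(*example_inputs)  # calibration pass
mq = quantize_fx.convert_fx(copy.deepcopy(mp))

# build the shadow model: each matched subgraph of the int8 model gets a
# shadow copy of the corresponding fp32 subgraph, with loggers on both paths
m_shadow = ns.add_shadow_loggers(
    "fp32", copy.deepcopy(mp), "int8", copy.deepcopy(mq), ns.OutputLogger
)

# one forward pass populates the loggers for the int8 path and its fp32 shadow
m_shadow(*example_inputs)

# results are grouped by ref_name, keyed to the int8 model's layer names
results = ns.extract_shadow_logger_info(m_shadow, ns.OutputLogger, "int8")
```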