ayousanz committed (verified)
Commit ddd9ed8 · 1 Parent(s): 56405a9

Add files using upload-large-folder tool

This view is limited to 50 files because it contains too many changes. See the raw diff for the remaining changes.
Files changed (50)
  1. .venv/Lib/site-packages/torch/ao/nn/__pycache__/__init__.cpython-39.pyc +0 -0
  2. .venv/Lib/site-packages/torch/ao/nn/intrinsic/__init__.py +40 -0
  3. .venv/Lib/site-packages/torch/ao/nn/intrinsic/modules/__init__.py +41 -0
  4. .venv/Lib/site-packages/torch/ao/nn/intrinsic/modules/__pycache__/__init__.cpython-39.pyc +0 -0
  5. .venv/Lib/site-packages/torch/ao/nn/intrinsic/modules/__pycache__/fused.cpython-39.pyc +0 -0
  6. .venv/Lib/site-packages/torch/ao/nn/intrinsic/modules/fused.py +245 -0
  7. .venv/Lib/site-packages/torch/ao/nn/intrinsic/qat/__init__.py +1 -0
  8. .venv/Lib/site-packages/torch/ao/nn/intrinsic/qat/modules/__init__.py +32 -0
  9. .venv/Lib/site-packages/torch/ao/nn/intrinsic/qat/modules/conv_fused.py +1050 -0
  10. .venv/Lib/site-packages/torch/ao/nn/intrinsic/qat/modules/linear_fused.py +193 -0
  11. .venv/Lib/site-packages/torch/ao/nn/intrinsic/qat/modules/linear_relu.py +51 -0
  12. .venv/Lib/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/__pycache__/__init__.cpython-39.pyc +0 -0
  13. .venv/Lib/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/modules/__pycache__/__init__.cpython-39.pyc +0 -0
  14. .venv/Lib/site-packages/torch/ao/nn/intrinsic/quantized/modules/__init__.py +18 -0
  15. .venv/Lib/site-packages/torch/ao/nn/intrinsic/quantized/modules/__pycache__/bn_relu.cpython-39.pyc +0 -0
  16. .venv/Lib/site-packages/torch/ao/nn/intrinsic/quantized/modules/__pycache__/conv_add.cpython-39.pyc +0 -0
  17. .venv/Lib/site-packages/torch/ao/nn/intrinsic/quantized/modules/__pycache__/conv_relu.cpython-39.pyc +0 -0
  18. .venv/Lib/site-packages/torch/ao/nn/intrinsic/quantized/modules/bn_relu.py +105 -0
  19. .venv/Lib/site-packages/torch/ao/nn/intrinsic/quantized/modules/conv_add.py +145 -0
  20. .venv/Lib/site-packages/torch/ao/nn/intrinsic/quantized/modules/conv_relu.py +263 -0
  21. .venv/Lib/site-packages/torch/ao/nn/intrinsic/quantized/modules/linear_relu.py +187 -0
  22. .venv/Lib/site-packages/torch/ao/nn/quantized/__pycache__/__init__.cpython-39.pyc +0 -0
  23. .venv/Lib/site-packages/torch/ao/nn/quantized/__pycache__/functional.cpython-39.pyc +0 -0
  24. .venv/Lib/site-packages/torch/ao/nn/quantized/reference/__pycache__/__init__.cpython-39.pyc +0 -0
  25. .venv/Lib/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/__init__.cpython-39.pyc +0 -0
  26. .venv/Lib/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/rnn.cpython-39.pyc +0 -0
  27. .venv/Lib/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/sparse.cpython-39.pyc +0 -0
  28. .venv/Lib/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/utils.cpython-39.pyc +0 -0
  29. .venv/Lib/site-packages/torch/ao/nn/sparse/__init__.py +1 -0
  30. .venv/Lib/site-packages/torch/ao/nn/sparse/__pycache__/__init__.cpython-39.pyc +0 -0
  31. .venv/Lib/site-packages/torch/ao/nn/sparse/quantized/__init__.py +10 -0
  32. .venv/Lib/site-packages/torch/ao/nn/sparse/quantized/__pycache__/__init__.cpython-39.pyc +0 -0
  33. .venv/Lib/site-packages/torch/ao/nn/sparse/quantized/__pycache__/linear.cpython-39.pyc +0 -0
  34. .venv/Lib/site-packages/torch/ao/nn/sparse/quantized/__pycache__/utils.cpython-39.pyc +0 -0
  35. .venv/Lib/site-packages/torch/ao/nn/sparse/quantized/dynamic/__init__.py +6 -0
  36. .venv/Lib/site-packages/torch/ao/nn/sparse/quantized/dynamic/__pycache__/__init__.cpython-39.pyc +0 -0
  37. .venv/Lib/site-packages/torch/ao/nn/sparse/quantized/dynamic/__pycache__/linear.cpython-39.pyc +0 -0
  38. .venv/Lib/site-packages/torch/ao/nn/sparse/quantized/dynamic/linear.py +188 -0
  39. .venv/Lib/site-packages/torch/ao/nn/sparse/quantized/linear.py +273 -0
  40. .venv/Lib/site-packages/torch/ao/nn/sparse/quantized/utils.py +56 -0
  41. .venv/Lib/site-packages/torch/ao/ns/__init__.py +0 -0
  42. .venv/Lib/site-packages/torch/ao/ns/__pycache__/__init__.cpython-39.pyc +0 -0
  43. .venv/Lib/site-packages/torch/ao/ns/_numeric_suite.py +563 -0
  44. .venv/Lib/site-packages/torch/ao/ns/_numeric_suite_fx.py +1130 -0
  45. .venv/Lib/site-packages/torch/ao/ns/fx/__init__.py +0 -0
  46. .venv/Lib/site-packages/torch/ao/ns/fx/__pycache__/__init__.cpython-39.pyc +0 -0
  47. .venv/Lib/site-packages/torch/ao/ns/fx/__pycache__/ns_types.cpython-39.pyc +0 -0
  48. .venv/Lib/site-packages/torch/ao/ns/fx/__pycache__/utils.cpython-39.pyc +0 -0
  49. .venv/Lib/site-packages/torch/ao/ns/fx/graph_matcher.py +470 -0
  50. .venv/Lib/site-packages/torch/ao/ns/fx/graph_passes.py +1131 -0
.venv/Lib/site-packages/torch/ao/nn/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (510 Bytes).
 
.venv/Lib/site-packages/torch/ao/nn/intrinsic/__init__.py ADDED
@@ -0,0 +1,40 @@
+ # mypy: allow-untyped-defs
+ from .modules import *  # noqa: F403
+ from .modules.fused import _FusedModule  # noqa: F403
+
+
+ # # Subpackages
+ # from . import qat  # noqa: F403
+ # from . import quantized  # noqa: F403
+
+ __all__ = [
+     "ConvBn1d",
+     "ConvBn2d",
+     "ConvBn3d",
+     "ConvBnReLU1d",
+     "ConvBnReLU2d",
+     "ConvBnReLU3d",
+     "ConvReLU1d",
+     "ConvReLU2d",
+     "ConvReLU3d",
+     "LinearReLU",
+     "BNReLU2d",
+     "BNReLU3d",
+     "LinearBn1d",
+     "LinearLeakyReLU",
+     "LinearTanh",
+     "ConvAdd2d",
+     "ConvAddReLU2d",
+ ]
+
+
+ # We are exposing all subpackages to the end-user.
+ # Because of possible inter-dependency, we want to avoid
+ # the cyclic imports, thus implementing lazy version
+ # as per https://peps.python.org/pep-0562/
+ def __getattr__(name):
+     if name in __all__:
+         import importlib
+
+         return importlib.import_module("." + name, __name__)
+     raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
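The `__getattr__` hook above follows PEP 562: names are resolved to submodule imports only when first accessed, which avoids import cycles between the intrinsic, qat and quantized subpackages. A minimal standalone sketch of the same pattern, where the package and submodule names are illustrative and not part of the torch code above:

# mypkg/__init__.py -- hypothetical package demonstrating PEP 562 lazy imports
import importlib

_lazy = {"heavy_submodule"}  # submodules we do not want to import eagerly

def __getattr__(name):
    if name in _lazy:
        mod = importlib.import_module("." + name, __name__)
        globals()[name] = mod   # cache so later lookups skip __getattr__
        return mod
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")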
.venv/Lib/site-packages/torch/ao/nn/intrinsic/modules/__init__.py ADDED
@@ -0,0 +1,41 @@
+ from .fused import (  # noqa: F401
+     _FusedModule,
+     BNReLU2d,
+     BNReLU3d,
+     ConvAdd2d,
+     ConvAddReLU2d,
+     ConvBn1d,
+     ConvBn2d,
+     ConvBn3d,
+     ConvBnReLU1d,
+     ConvBnReLU2d,
+     ConvBnReLU3d,
+     ConvReLU1d,
+     ConvReLU2d,
+     ConvReLU3d,
+     LinearBn1d,
+     LinearLeakyReLU,
+     LinearReLU,
+     LinearTanh,
+ )
+
+
+ __all__ = [
+     "ConvBn1d",
+     "ConvBn2d",
+     "ConvBn3d",
+     "ConvBnReLU1d",
+     "ConvBnReLU2d",
+     "ConvBnReLU3d",
+     "ConvReLU1d",
+     "ConvReLU2d",
+     "ConvReLU3d",
+     "LinearReLU",
+     "BNReLU2d",
+     "BNReLU3d",
+     "LinearBn1d",
+     "LinearLeakyReLU",
+     "LinearTanh",
+     "ConvAdd2d",
+     "ConvAddReLU2d",
+ ]
.venv/Lib/site-packages/torch/ao/nn/intrinsic/modules/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (709 Bytes).
 
.venv/Lib/site-packages/torch/ao/nn/intrinsic/modules/__pycache__/fused.cpython-39.pyc ADDED
Binary file (9.96 kB).
 
.venv/Lib/site-packages/torch/ao/nn/intrinsic/modules/fused.py ADDED
@@ -0,0 +1,245 @@
+ # mypy: allow-untyped-defs
+ import torch
+ from torch.nn import (
+     BatchNorm1d,
+     BatchNorm2d,
+     BatchNorm3d,
+     Conv1d,
+     Conv2d,
+     Conv3d,
+     Linear,
+     ReLU,
+ )
+ from torch.nn.utils.parametrize import type_before_parametrizations
+
+
+ __all__ = [
+     "ConvReLU1d",
+     "ConvReLU2d",
+     "ConvReLU3d",
+     "LinearReLU",
+     "ConvBn1d",
+     "ConvBn2d",
+     "ConvBnReLU1d",
+     "ConvBnReLU2d",
+     "ConvBn3d",
+     "ConvBnReLU3d",
+     "BNReLU2d",
+     "BNReLU3d",
+     "LinearBn1d",
+     "LinearLeakyReLU",
+     "LinearTanh",
+     "ConvAdd2d",
+     "ConvAddReLU2d",
+ ]
+
+
+ # Used for identifying intrinsic modules used in quantization
+ class _FusedModule(torch.nn.Sequential):
+     pass
+
+
+ class ConvReLU1d(_FusedModule):
+     r"""This is a sequential container which calls the Conv1d and ReLU modules.
+     During quantization this will be replaced with the corresponding fused module."""
+
+     def __init__(self, conv, relu):
+         assert (
+             type_before_parametrizations(conv) == Conv1d
+             and type_before_parametrizations(relu) == ReLU
+         ), f"Incorrect types for input modules{type_before_parametrizations(conv)}{type_before_parametrizations(relu)}"
+         super().__init__(conv, relu)
+
+
+ class ConvReLU2d(_FusedModule):
+     r"""This is a sequential container which calls the Conv2d and ReLU modules.
+     During quantization this will be replaced with the corresponding fused module."""
+
+     def __init__(self, conv, relu):
+         assert (
+             type_before_parametrizations(conv) == Conv2d
+             and type_before_parametrizations(relu) == ReLU
+         ), f"Incorrect types for input modules{type_before_parametrizations(conv)}{type_before_parametrizations(relu)}"
+         super().__init__(conv, relu)
+
+
+ class ConvReLU3d(_FusedModule):
+     r"""This is a sequential container which calls the Conv3d and ReLU modules.
+     During quantization this will be replaced with the corresponding fused module."""
+
+     def __init__(self, conv, relu):
+         assert (
+             type_before_parametrizations(conv) == Conv3d
+             and type_before_parametrizations(relu) == ReLU
+         ), f"Incorrect types for input modules{type_before_parametrizations(conv)}{type_before_parametrizations(relu)}"
+         super().__init__(conv, relu)
+
+
+ class LinearReLU(_FusedModule):
+     r"""This is a sequential container which calls the Linear and ReLU modules.
+     During quantization this will be replaced with the corresponding fused module."""
+
+     def __init__(self, linear, relu):
+         assert (
+             type_before_parametrizations(linear) == Linear
+             and type_before_parametrizations(relu) == ReLU
+         ), f"Incorrect types for input modules{type_before_parametrizations(linear)}{type_before_parametrizations(relu)}"
+         super().__init__(linear, relu)
+
+
+ class ConvBn1d(_FusedModule):
+     r"""This is a sequential container which calls the Conv 1d and Batch Norm 1d modules.
+     During quantization this will be replaced with the corresponding fused module."""
+
+     def __init__(self, conv, bn):
+         assert (
+             type_before_parametrizations(conv) == Conv1d
+             and type_before_parametrizations(bn) == BatchNorm1d
+         ), f"Incorrect types for input modules{type_before_parametrizations(conv)}{type_before_parametrizations(bn)}"
+         super().__init__(conv, bn)
+
+
+ class ConvBn2d(_FusedModule):
+     r"""This is a sequential container which calls the Conv 2d and Batch Norm 2d modules.
+     During quantization this will be replaced with the corresponding fused module."""
+
+     def __init__(self, conv, bn):
+         assert (
+             type_before_parametrizations(conv) == Conv2d
+             and type_before_parametrizations(bn) == BatchNorm2d
+         ), f"Incorrect types for input modules{type_before_parametrizations(conv)}{type_before_parametrizations(bn)}"
+         super().__init__(conv, bn)
+
+
+ class ConvBnReLU1d(_FusedModule):
+     r"""This is a sequential container which calls the Conv 1d, Batch Norm 1d, and ReLU modules.
+     During quantization this will be replaced with the corresponding fused module."""
+
+     def __init__(self, conv, bn, relu):
+         assert (
+             type_before_parametrizations(conv) == Conv1d
+             and type_before_parametrizations(bn) == BatchNorm1d
+             and type_before_parametrizations(relu) == ReLU
+         ), f"Incorrect types for input modules{type_before_parametrizations(conv)}{type_before_parametrizations(bn)}{type_before_parametrizations(relu)}"  # noqa: B950
+         super().__init__(conv, bn, relu)
+
+
+ class ConvBnReLU2d(_FusedModule):
+     r"""This is a sequential container which calls the Conv 2d, Batch Norm 2d, and ReLU modules.
+     During quantization this will be replaced with the corresponding fused module."""
+
+     def __init__(self, conv, bn, relu):
+         assert (
+             type_before_parametrizations(conv) == Conv2d
+             and type_before_parametrizations(bn) == BatchNorm2d
+             and type_before_parametrizations(relu) == ReLU
+         ), f"Incorrect types for input modules{type_before_parametrizations(conv)}{type_before_parametrizations(bn)}{type_before_parametrizations(relu)}"  # noqa: B950
+         super().__init__(conv, bn, relu)
+
+
+ class ConvBn3d(_FusedModule):
+     r"""This is a sequential container which calls the Conv 3d and Batch Norm 3d modules.
+     During quantization this will be replaced with the corresponding fused module."""
+
+     def __init__(self, conv, bn):
+         assert (
+             type_before_parametrizations(conv) == Conv3d
+             and type_before_parametrizations(bn) == BatchNorm3d
+         ), f"Incorrect types for input modules{type_before_parametrizations(conv)}{type_before_parametrizations(bn)}"
+         super().__init__(conv, bn)
+
+
+ class ConvBnReLU3d(_FusedModule):
+     r"""This is a sequential container which calls the Conv 3d, Batch Norm 3d, and ReLU modules.
+     During quantization this will be replaced with the corresponding fused module."""
+
+     def __init__(self, conv, bn, relu):
+         assert (
+             type_before_parametrizations(conv) == Conv3d
+             and type_before_parametrizations(bn) == BatchNorm3d
+             and type_before_parametrizations(relu) == ReLU
+         ), f"Incorrect types for input modules{type_before_parametrizations(conv)}{type_before_parametrizations(bn)}{type_before_parametrizations(relu)}"  # noqa: B950
+         super().__init__(conv, bn, relu)
+
+
+ class BNReLU2d(_FusedModule):
+     r"""This is a sequential container which calls the BatchNorm 2d and ReLU modules.
+     During quantization this will be replaced with the corresponding fused module."""
+
+     def __init__(self, batch_norm, relu):
+         assert (
+             type_before_parametrizations(batch_norm) == BatchNorm2d
+             and type_before_parametrizations(relu) == ReLU
+         ), f"Incorrect types for input modules{type_before_parametrizations(batch_norm)}{type_before_parametrizations(relu)}"
+         super().__init__(batch_norm, relu)
+
+
+ class BNReLU3d(_FusedModule):
+     r"""This is a sequential container which calls the BatchNorm 3d and ReLU modules.
+     During quantization this will be replaced with the corresponding fused module."""
+
+     def __init__(self, batch_norm, relu):
+         assert (
+             type_before_parametrizations(batch_norm) == BatchNorm3d
+             and type_before_parametrizations(relu) == ReLU
+         ), f"Incorrect types for input modules{type_before_parametrizations(batch_norm)}{type_before_parametrizations(relu)}"
+         super().__init__(batch_norm, relu)
+
+
+ class LinearBn1d(_FusedModule):
+     r"""This is a sequential container which calls the Linear and BatchNorm1d modules.
+     During quantization this will be replaced with the corresponding fused module."""
+
+     def __init__(self, linear, bn):
+         assert (
+             type_before_parametrizations(linear) == Linear
+             and type_before_parametrizations(bn) == BatchNorm1d
+         ), f"Incorrect types for input modules{type_before_parametrizations(linear)}{type_before_parametrizations(bn)}"
+         super().__init__(linear, bn)
+
+
+ class LinearLeakyReLU(_FusedModule):
+     r"""This is a sequential container which calls the Linear and LeakyReLU modules.
+     During quantization this will be replaced with the corresponding fused module."""
+
+     def __init__(self, linear, leaky_relu):
+         assert (
+             type(linear) == Linear and type(leaky_relu) == torch.nn.LeakyReLU
+         ), f"Incorrect types for input modules{type(linear)}{type(leaky_relu)}"
+         super().__init__(linear, leaky_relu)
+
+
+ class LinearTanh(_FusedModule):
+     r"""This is a sequential container which calls the Linear and Tanh modules.
+     During quantization this will be replaced with the corresponding fused module."""
+
+     def __init__(self, linear, tanh):
+         assert (
+             type(linear) == Linear and type(tanh) == torch.nn.Tanh
+         ), f"Incorrect types for input modules{type(linear)}{type(tanh)}"
+         super().__init__(linear, tanh)
+
+
+ class ConvAdd2d(_FusedModule):
+     r"""This is a sequential container which calls the Conv2d modules with extra Add.
+     During quantization this will be replaced with the corresponding fused module."""
+
+     def __init__(self, conv, add):
+         super().__init__(conv)
+         self.add = add
+
+     def forward(self, x1, x2):
+         return self.add(self[0](x1), x2)
+
+
+ class ConvAddReLU2d(_FusedModule):
+     r"""This is a sequential container which calls the Conv2d, add, Relu.
+     During quantization this will be replaced with the corresponding fused module."""
+
+     def __init__(self, conv, add, relu):
+         super().__init__(conv)
+         self.add = add
+         self.relu = relu
+
+     def forward(self, x1, x2):
+         return self.relu(self.add(self[0](x1), x2))
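In eager-mode quantization these containers are normally produced by `torch.ao.quantization.fuse_modules` rather than constructed by hand. A short hedged sketch of that flow; the toy model and the child indices are illustrative, and the exact fused type can vary by PyTorch version:

import torch.nn as nn
import torch.ao.nn.intrinsic as nni
from torch.ao.quantization import fuse_modules

# Toy float model; "0", "1", "2" are the names of its children.
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU()).eval()

# In eval mode the BatchNorm is folded into the Conv and the Conv+ReLU pair
# is wrapped in the intrinsic container defined above.
fused = fuse_modules(model, [["0", "1", "2"]])
print(type(fused[0]))        # expected: torch.ao.nn.intrinsic.ConvReLU2d

# The containers can also be built directly, e.g. in a custom fusion pass:
pair = nni.ConvReLU2d(nn.Conv2d(3, 8, 3), nn.ReLU())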
.venv/Lib/site-packages/torch/ao/nn/intrinsic/qat/__init__.py ADDED
@@ -0,0 +1 @@
+ from .modules import *  # noqa: F403
.venv/Lib/site-packages/torch/ao/nn/intrinsic/qat/modules/__init__.py ADDED
@@ -0,0 +1,32 @@
+ from .conv_fused import (
+     ConvBn1d,
+     ConvBn2d,
+     ConvBn3d,
+     ConvBnReLU1d,
+     ConvBnReLU2d,
+     ConvBnReLU3d,
+     ConvReLU1d,
+     ConvReLU2d,
+     ConvReLU3d,
+     freeze_bn_stats,
+     update_bn_stats,
+ )
+ from .linear_fused import LinearBn1d
+ from .linear_relu import LinearReLU
+
+
+ __all__ = [
+     "LinearReLU",
+     "LinearBn1d",
+     "ConvReLU1d",
+     "ConvReLU2d",
+     "ConvReLU3d",
+     "ConvBn1d",
+     "ConvBn2d",
+     "ConvBn3d",
+     "ConvBnReLU1d",
+     "ConvBnReLU2d",
+     "ConvBnReLU3d",
+     "update_bn_stats",
+     "freeze_bn_stats",
+ ]
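`prepare_qat` swaps the intrinsic float containers for the QAT versions listed above, so fake quantization is applied to the weights during fine-tuning. A hedged sketch of that swap; the model layout and the backend string are illustrative:

import torch.nn as nn
from torch.ao.quantization import fuse_modules_qat, get_default_qat_qconfig, prepare_qat

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU()).train()
model = fuse_modules_qat(model, [["0", "1", "2"]])     # -> intrinsic ConvBnReLU2d
model.qconfig = get_default_qat_qconfig("fbgemm")

qat_model = prepare_qat(model)
print(type(qat_model[0]))   # expected: torch.ao.nn.intrinsic.qat.ConvBnReLU2d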
.venv/Lib/site-packages/torch/ao/nn/intrinsic/qat/modules/conv_fused.py ADDED
@@ -0,0 +1,1050 @@
1
+ # mypy: allow-untyped-defs
2
+ import math
3
+ from typing import TypeVar
4
+
5
+ import torch
6
+ import torch.ao.nn.intrinsic as nni
7
+ import torch.ao.nn.qat as nnqat
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from torch.nn import init
11
+ from torch.nn.modules.utils import _pair, _single, _triple
12
+ from torch.nn.parameter import Parameter
13
+ from torch.nn.utils import fuse_conv_bn_weights
14
+
15
+
16
+ __all__ = [
17
+ "ConvBn1d",
18
+ "ConvBnReLU1d",
19
+ "ConvReLU1d",
20
+ "ConvBn2d",
21
+ "ConvBnReLU2d",
22
+ "ConvReLU2d",
23
+ "ConvBn3d",
24
+ "ConvBnReLU3d",
25
+ "ConvReLU3d",
26
+ "update_bn_stats",
27
+ "freeze_bn_stats",
28
+ ]
29
+ _BN_CLASS_MAP = {
30
+ 1: nn.BatchNorm1d,
31
+ 2: nn.BatchNorm2d,
32
+ 3: nn.BatchNorm3d,
33
+ }
34
+
35
+
36
+ MOD = TypeVar("MOD", bound=nn.modules.conv._ConvNd)
37
+
38
+
39
+ class _ConvBnNd(nn.modules.conv._ConvNd, nni._FusedModule):
40
+ _version = 2
41
+ _FLOAT_MODULE = MOD
42
+
43
+ def __init__(
44
+ self,
45
+ # ConvNd args
46
+ in_channels,
47
+ out_channels,
48
+ kernel_size,
49
+ stride,
50
+ padding,
51
+ dilation,
52
+ transposed,
53
+ output_padding,
54
+ groups,
55
+ bias,
56
+ padding_mode,
57
+ # BatchNormNd args
58
+ # num_features: out_channels
59
+ eps=1e-05,
60
+ momentum=0.1,
61
+ # affine: True
62
+ # track_running_stats: True
63
+ # Args for this module
64
+ freeze_bn=False,
65
+ qconfig=None,
66
+ dim=2,
67
+ ):
68
+ nn.modules.conv._ConvNd.__init__(
69
+ self,
70
+ in_channels,
71
+ out_channels,
72
+ kernel_size,
73
+ stride,
74
+ padding,
75
+ dilation,
76
+ transposed,
77
+ output_padding,
78
+ groups,
79
+ False,
80
+ padding_mode,
81
+ )
82
+ assert qconfig, "qconfig must be provided for QAT module"
83
+ self.qconfig = qconfig
84
+ self.freeze_bn = freeze_bn if self.training else True
85
+ self.bn = _BN_CLASS_MAP[dim](out_channels, eps, momentum, True, True)
86
+ self.weight_fake_quant = self.qconfig.weight()
87
+ if bias:
88
+ self.bias = Parameter(torch.empty(out_channels))
89
+ else:
90
+ self.register_parameter("bias", None)
91
+ self.reset_bn_parameters()
92
+
93
+ # this needs to be called after reset_bn_parameters,
94
+ # as they modify the same state
95
+ if self.training:
96
+ if freeze_bn:
97
+ self.freeze_bn_stats()
98
+ else:
99
+ self.update_bn_stats()
100
+ else:
101
+ self.freeze_bn_stats()
102
+
103
+ self._enable_slow_path_for_better_numerical_stability = False
104
+
105
+ def reset_running_stats(self):
106
+ self.bn.reset_running_stats()
107
+
108
+ def reset_bn_parameters(self):
109
+ self.bn.reset_running_stats()
110
+ init.uniform_(self.bn.weight)
111
+ init.zeros_(self.bn.bias)
112
+ # note: below is actually for conv, not BN
113
+ if self.bias is not None:
114
+ fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
115
+ bound = 1 / math.sqrt(fan_in)
116
+ init.uniform_(self.bias, -bound, bound)
117
+
118
+ def reset_parameters(self):
119
+ super().reset_parameters()
120
+
121
+ def update_bn_stats(self):
122
+ self.freeze_bn = False
123
+ self.bn.training = True
124
+ return self
125
+
126
+ def freeze_bn_stats(self):
127
+ self.freeze_bn = True
128
+ self.bn.training = False
129
+ return self
130
+
131
+ def _forward(self, input):
132
+ if self._enable_slow_path_for_better_numerical_stability:
133
+ return self._forward_slow(input)
134
+ return self._forward_approximate(input)
135
+
136
+ def _forward_approximate(self, input):
137
+ """Approximated method to fuse conv and bn. It requires only one forward pass.
138
+ conv_orig = conv / scale_factor where scale_factor = bn.weight / running_std
139
+ """
140
+ assert self.bn.running_var is not None
141
+ running_std = torch.sqrt(self.bn.running_var + self.bn.eps)
142
+ scale_factor = self.bn.weight / running_std
143
+ weight_shape = [1] * len(self.weight.shape)
144
+ weight_shape[0] = -1
145
+ bias_shape = [1] * len(self.weight.shape)
146
+ bias_shape[1] = -1
147
+ scaled_weight = self.weight_fake_quant(
148
+ self.weight * scale_factor.reshape(weight_shape)
149
+ )
150
+ # using zero bias here since the bias for original conv
151
+ # will be added later
152
+ if self.bias is not None:
153
+ zero_bias = torch.zeros_like(self.bias, dtype=input.dtype)
154
+ else:
155
+ zero_bias = torch.zeros(
156
+ self.out_channels, device=scaled_weight.device, dtype=input.dtype
157
+ )
158
+ conv = self._conv_forward(input, scaled_weight, zero_bias)
159
+ conv_orig = conv / scale_factor.reshape(bias_shape)
160
+ if self.bias is not None:
161
+ conv_orig = conv_orig + self.bias.reshape(bias_shape)
162
+ conv = self.bn(conv_orig)
163
+ return conv
164
+
165
+ def _forward_slow(self, input):
166
+ """
167
+ A more accurate but slow method to compute conv bn fusion, following https://arxiv.org/pdf/1806.08342.pdf
168
+ It requires two forward passes but handles the case bn.weight == 0
169
+
170
+ Conv: Y = WX + B_c
171
+ Conv without bias: Y0 = WX = Y - B_c, Y = Y0 + B_c
172
+
173
+ Batch statistics:
174
+ mean_Y = Y.mean()
175
+ = Y0.mean() + B_c
176
+ var_Y = (Y - mean_Y)^2.mean()
177
+ = (Y0 - Y0.mean())^2.mean()
178
+ BN (r: bn.weight, beta: bn.bias):
179
+ Z = r * (Y - mean_Y) / sqrt(var_Y + eps) + beta
180
+ = r * (Y0 - Y0.mean()) / sqrt(var_Y + eps) + beta
181
+
182
+ Fused Conv BN training (std_Y = sqrt(var_Y + eps)):
183
+ Z = (r * W / std_Y) * X + r * (B_c - mean_Y) / std_Y + beta
184
+ = (r * W / std_Y) * X - r * Y0.mean() / std_Y + beta
185
+
186
+ Fused Conv BN inference (running_std = sqrt(running_var + eps)):
187
+ Z = (r * W / running_std) * X - r * (running_mean - B_c) / running_std + beta
188
+
189
+ QAT with fused conv bn:
190
+ Z_train = fake_quant(r * W / running_std) * X * (running_std / std_Y) - r * Y0.mean() / std_Y + beta
191
+ = conv(X, fake_quant(r * W / running_std)) * (running_std / std_Y) - r * Y0.mean() / std_Y + beta
192
+ Z_inference = conv(X, fake_quant(r * W / running_std)) - r * (running_mean - B_c) / running_std + beta
193
+ """
194
+
195
+ assert self.bn.running_var is not None
196
+ assert self.bn.running_mean is not None
197
+
198
+ # using zero bias here since the bias for original conv
199
+ # will be added later
200
+ zero_bias = torch.zeros(
201
+ self.out_channels, device=self.weight.device, dtype=input.dtype
202
+ )
203
+
204
+ weight_shape = [1] * len(self.weight.shape)
205
+ weight_shape[0] = -1
206
+ bias_shape = [1] * len(self.weight.shape)
207
+ bias_shape[1] = -1
208
+
209
+ if self.bn.training:
210
+ # needed to compute batch mean/std
211
+ conv_out = self._conv_forward(input, self.weight, zero_bias)
212
+ # update bn statistics
213
+ with torch.no_grad():
214
+ conv_out_bias = (
215
+ conv_out
216
+ if self.bias is None
217
+ else conv_out + self.bias.reshape(bias_shape)
218
+ )
219
+ self.bn(conv_out_bias)
220
+
221
+ # fused conv + bn without bias using bn running statistics
222
+ running_std = torch.sqrt(self.bn.running_var + self.bn.eps)
223
+ scale_factor = self.bn.weight / running_std
224
+ scaled_weight = self.weight_fake_quant(
225
+ self.weight * scale_factor.reshape(weight_shape)
226
+ )
227
+ # fused conv without bias for inference: (r * W / running_std) * X
228
+ conv_bn = self._conv_forward(input, scaled_weight, zero_bias)
229
+
230
+ if self.bn.training:
231
+ avg_dims = [0] + list(range(2, len(self.weight.shape)))
232
+ batch_mean = conv_out.mean(avg_dims) # type: ignore[possibly-undefined]
233
+ batch_var = torch.square(conv_out - batch_mean.reshape(bias_shape)).mean(
234
+ avg_dims
235
+ )
236
+ batch_std = torch.sqrt(batch_var + self.bn.eps)
237
+
238
+ # scale to use batch std in training mode
239
+ # conv(X, r * W / std_Y) = conv(X, r * W / running_std) * (running_std / std_Y)
240
+ unscale_factor = running_std / batch_std
241
+ conv_bn *= unscale_factor.reshape(bias_shape)
242
+
243
+ fused_mean = batch_mean
244
+ fused_std = batch_std
245
+ else:
246
+ fused_mean = self.bn.running_mean - (
247
+ self.bias if self.bias is not None else 0
248
+ )
249
+ fused_std = running_std
250
+
251
+ # fused bias = beta - r * mean / std
252
+ fused_bias = self.bn.bias - self.bn.weight * fused_mean / fused_std
253
+ conv_bn += fused_bias.reshape(bias_shape)
254
+
255
+ # HACK to let conv bias participate in loss to avoid DDP error (parameters
256
+ # were not used in producing loss)
257
+ if self.bias is not None:
258
+ conv_bn += (self.bias - self.bias).reshape(bias_shape)
259
+
260
+ return conv_bn
261
+
262
+ def extra_repr(self):
263
+ # TODO(jerryzh): extend
264
+ return super().extra_repr()
265
+
266
+ def forward(self, input):
267
+ return self._forward(input)
268
+
269
+ def train(self, mode=True):
270
+ """
271
+ Batchnorm's training behavior is using the self.training flag. Prevent
272
+ changing it if BN is frozen. This makes sure that calling `model.train()`
273
+ on a model with a frozen BN will behave properly.
274
+ """
275
+ self.training = mode
276
+ if not self.freeze_bn:
277
+ for module in self.children():
278
+ module.train(mode)
279
+ return self
280
+
281
+ # ===== Serialization version history =====
282
+ #
283
+ # Version 1/None
284
+ # self
285
+ # |--- weight : Tensor
286
+ # |--- bias : Tensor
287
+ # |--- gamma : Tensor
288
+ # |--- beta : Tensor
289
+ # |--- running_mean : Tensor
290
+ # |--- running_var : Tensor
291
+ # |--- num_batches_tracked : Tensor
292
+ #
293
+ # Version 2
294
+ # self
295
+ # |--- weight : Tensor
296
+ # |--- bias : Tensor
297
+ # |--- bn : Module
298
+ # |--- weight : Tensor (moved from v1.self.gamma)
299
+ # |--- bias : Tensor (moved from v1.self.beta)
300
+ # |--- running_mean : Tensor (moved from v1.self.running_mean)
301
+ # |--- running_var : Tensor (moved from v1.self.running_var)
302
+ # |--- num_batches_tracked : Tensor (moved from v1.self.num_batches_tracked)
303
+ def _load_from_state_dict(
304
+ self,
305
+ state_dict,
306
+ prefix,
307
+ local_metadata,
308
+ strict,
309
+ missing_keys,
310
+ unexpected_keys,
311
+ error_msgs,
312
+ ):
313
+ version = local_metadata.get("version", None)
314
+ if version is None or version == 1:
315
+ # BN related parameters and buffers were moved into the BN module for v2
316
+ v2_to_v1_names = {
317
+ "bn.weight": "gamma",
318
+ "bn.bias": "beta",
319
+ "bn.running_mean": "running_mean",
320
+ "bn.running_var": "running_var",
321
+ "bn.num_batches_tracked": "num_batches_tracked",
322
+ }
323
+ for v2_name, v1_name in v2_to_v1_names.items():
324
+ if prefix + v1_name in state_dict:
325
+ state_dict[prefix + v2_name] = state_dict[prefix + v1_name]
326
+ state_dict.pop(prefix + v1_name)
327
+ elif prefix + v2_name in state_dict:
328
+ # there was a brief period where forward compatibility
329
+ # for this module was broken (between
330
+ # https://github.com/pytorch/pytorch/pull/38478
331
+ # and https://github.com/pytorch/pytorch/pull/38820)
332
+ # and modules emitted the v2 state_dict format while
333
+ # specifying that version == 1. This patches the forward
334
+ # compatibility issue by allowing the v2 style entries to
335
+ # be used.
336
+ pass
337
+ elif strict:
338
+ missing_keys.append(prefix + v2_name)
339
+
340
+ super()._load_from_state_dict(
341
+ state_dict,
342
+ prefix,
343
+ local_metadata,
344
+ strict,
345
+ missing_keys,
346
+ unexpected_keys,
347
+ error_msgs,
348
+ )
349
+
350
+ @classmethod
351
+ def from_float(cls, mod, use_precomputed_fake_quant=False):
352
+ r"""Create a qat module from a float module or qparams_dict
353
+
354
+ Args: `mod` a float module, either produced by torch.ao.quantization utilities
355
+ or directly from user
356
+ """
357
+ # The ignore is because _FLOAT_MODULE is a TypeVar here where the bound
358
+ # has no __name__ (code is fine though)
359
+ assert type(mod) == cls._FLOAT_MODULE, (
360
+ "qat."
361
+ + cls.__name__
362
+ + ".from_float only works for "
363
+ + cls._FLOAT_MODULE.__name__ # type: ignore[attr-defined]
364
+ )
365
+ assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"
366
+ assert mod.qconfig, "Input float module must have a valid qconfig"
367
+ qconfig = mod.qconfig
368
+ conv, bn = mod[0], mod[1]
369
+ qat_convbn = cls(
370
+ conv.in_channels,
371
+ conv.out_channels,
372
+ conv.kernel_size,
373
+ conv.stride,
374
+ conv.padding,
375
+ conv.dilation,
376
+ conv.groups,
377
+ conv.bias is not None,
378
+ conv.padding_mode,
379
+ bn.eps,
380
+ bn.momentum,
381
+ False,
382
+ qconfig,
383
+ )
384
+ qat_convbn.weight = conv.weight
385
+ qat_convbn.bias = conv.bias
386
+ qat_convbn.bn.weight = bn.weight
387
+ qat_convbn.bn.bias = bn.bias
388
+ qat_convbn.bn.running_mean = bn.running_mean
389
+ qat_convbn.bn.running_var = bn.running_var
390
+ # mypy error: Cannot determine type of 'num_batches_tracked'
391
+ qat_convbn.bn.num_batches_tracked = bn.num_batches_tracked # type: ignore[has-type]
392
+ return qat_convbn
393
+
394
+ def to_float(self):
395
+ cls = type(self)
396
+ conv = cls._FLOAT_CONV_MODULE( # type: ignore[attr-defined]
397
+ self.in_channels,
398
+ self.out_channels,
399
+ self.kernel_size,
400
+ self.stride,
401
+ self.padding,
402
+ self.dilation,
403
+ self.groups,
404
+ self.bias is not None,
405
+ self.padding_mode,
406
+ )
407
+ conv.weight = torch.nn.Parameter(self.weight.detach())
408
+ if self.bias is not None:
409
+ conv.bias = torch.nn.Parameter(self.bias.detach())
410
+
411
+ if cls._FLOAT_BN_MODULE: # type: ignore[attr-defined]
412
+ # fuse bn into conv
413
+ assert self.bn.running_var is not None and self.bn.running_mean is not None
414
+ conv.weight, conv.bias = fuse_conv_bn_weights(
415
+ conv.weight,
416
+ conv.bias,
417
+ self.bn.running_mean,
418
+ self.bn.running_var,
419
+ self.bn.eps,
420
+ self.bn.weight,
421
+ self.bn.bias,
422
+ )
423
+
424
+ if cls._FLOAT_RELU_MODULE: # type: ignore[attr-defined]
425
+ modules = []
426
+ modules.append(conv)
427
+ relu = cls._FLOAT_RELU_MODULE() # type: ignore[attr-defined]
428
+ modules.append(relu)
429
+ conv_relu = cls._FUSED_FLOAT_MODULE(*modules) # type: ignore[attr-defined]
430
+ conv_relu.train(self.training)
431
+ return conv_relu
432
+ else:
433
+ conv.train(self.training)
434
+ return conv
435
+
436
+
437
+ class ConvBn1d(_ConvBnNd, nn.Conv1d):
438
+ r"""
439
+ A ConvBn1d module is a module fused from Conv1d and BatchNorm1d,
440
+ attached with FakeQuantize modules for weight,
441
+ used in quantization aware training.
442
+
443
+ We combined the interface of :class:`torch.nn.Conv1d` and
444
+ :class:`torch.nn.BatchNorm1d`.
445
+
446
+ Similar to :class:`torch.nn.Conv1d`, with FakeQuantize modules initialized
447
+ to default.
448
+
449
+ Attributes:
450
+ freeze_bn:
451
+ weight_fake_quant: fake quant module for weight
452
+
453
+ """
454
+ _FLOAT_BN_MODULE = nn.BatchNorm1d
455
+ _FLOAT_RELU_MODULE: None = None
456
+ _FLOAT_MODULE = nni.ConvBn1d
457
+ _FLOAT_CONV_MODULE = nn.Conv1d
458
+
459
+ def __init__(
460
+ self,
461
+ # Conv1d args
462
+ in_channels,
463
+ out_channels,
464
+ kernel_size,
465
+ stride=1,
466
+ padding=0,
467
+ dilation=1,
468
+ groups=1,
469
+ bias=None,
470
+ padding_mode="zeros",
471
+ # BatchNorm1d args
472
+ # num_features: out_channels
473
+ eps=1e-05,
474
+ momentum=0.1,
475
+ # affine: True
476
+ # track_running_stats: True
477
+ # Args for this module
478
+ freeze_bn=False,
479
+ qconfig=None,
480
+ ):
481
+ kernel_size = _single(kernel_size)
482
+ stride = _single(stride)
483
+ padding = _single(padding)
484
+ dilation = _single(dilation)
485
+ _ConvBnNd.__init__(
486
+ self,
487
+ in_channels,
488
+ out_channels,
489
+ kernel_size,
490
+ stride,
491
+ padding,
492
+ dilation,
493
+ False,
494
+ _single(0),
495
+ groups,
496
+ bias,
497
+ padding_mode,
498
+ eps,
499
+ momentum,
500
+ freeze_bn,
501
+ qconfig,
502
+ dim=1,
503
+ )
504
+
505
+
506
+ class ConvBnReLU1d(ConvBn1d):
507
+ r"""
508
+ A ConvBnReLU1d module is a module fused from Conv1d, BatchNorm1d and ReLU,
509
+ attached with FakeQuantize modules for weight,
510
+ used in quantization aware training.
511
+
512
+ We combined the interface of :class:`torch.nn.Conv1d` and
513
+ :class:`torch.nn.BatchNorm1d` and :class:`torch.nn.ReLU`.
514
+
515
+ Similar to `torch.nn.Conv1d`, with FakeQuantize modules initialized to
516
+ default.
517
+
518
+ Attributes:
519
+ weight_fake_quant: fake quant module for weight
520
+
521
+ """
522
+ # base class defines _FLOAT_MODULE as "ConvBn1d"
523
+ _FLOAT_MODULE = nni.ConvBnReLU1d # type: ignore[assignment]
524
+ _FLOAT_CONV_MODULE = nn.Conv1d
525
+ _FLOAT_BN_MODULE = nn.BatchNorm1d
526
+ _FLOAT_RELU_MODULE = nn.ReLU # type: ignore[assignment]
527
+ # module class after fusing bn into conv
528
+ _FUSED_FLOAT_MODULE = nni.ConvReLU1d
529
+
530
+ def __init__(
531
+ self,
532
+ # Conv1d args
533
+ in_channels,
534
+ out_channels,
535
+ kernel_size,
536
+ stride=1,
537
+ padding=0,
538
+ dilation=1,
539
+ groups=1,
540
+ bias=None,
541
+ padding_mode="zeros",
542
+ # BatchNorm1d args
543
+ # num_features: out_channels
544
+ eps=1e-05,
545
+ momentum=0.1,
546
+ # affine: True
547
+ # track_running_stats: True
548
+ # Args for this module
549
+ freeze_bn=False,
550
+ qconfig=None,
551
+ ):
552
+ super().__init__(
553
+ in_channels,
554
+ out_channels,
555
+ kernel_size,
556
+ stride,
557
+ padding,
558
+ dilation,
559
+ groups,
560
+ bias,
561
+ padding_mode,
562
+ eps,
563
+ momentum,
564
+ freeze_bn,
565
+ qconfig,
566
+ )
567
+
568
+ def forward(self, input):
569
+ return F.relu(ConvBn1d._forward(self, input))
570
+
571
+ @classmethod
572
+ def from_float(cls, mod, use_precomputed_fake_quant=False):
573
+ return super().from_float(mod, use_precomputed_fake_quant)
574
+
575
+
576
+ class ConvReLU1d(nnqat.Conv1d, nni._FusedModule):
577
+ r"""A ConvReLU1d module is a fused module of Conv1d and ReLU, attached with
578
+ FakeQuantize modules for weight for
579
+ quantization aware training.
580
+
581
+ We combined the interface of :class:`~torch.nn.Conv1d` and
582
+ :class:`~torch.nn.BatchNorm1d`.
583
+
584
+ Attributes:
585
+ weight_fake_quant: fake quant module for weight
586
+
587
+ """
588
+ _FLOAT_MODULE = nni.ConvReLU1d # type: ignore[assignment]
589
+ _FLOAT_CONV_MODULE = nn.Conv1d
590
+ _FLOAT_BN_MODULE: None = None
591
+ _FLOAT_RELU_MODULE = nn.ReLU
592
+
593
+ def __init__(
594
+ self,
595
+ in_channels,
596
+ out_channels,
597
+ kernel_size,
598
+ stride=1,
599
+ padding=0,
600
+ dilation=1,
601
+ groups=1,
602
+ bias=True,
603
+ padding_mode="zeros",
604
+ qconfig=None,
605
+ ):
606
+ super().__init__(
607
+ in_channels,
608
+ out_channels,
609
+ kernel_size,
610
+ stride=stride,
611
+ padding=padding,
612
+ dilation=dilation,
613
+ groups=groups,
614
+ bias=bias,
615
+ padding_mode=padding_mode,
616
+ qconfig=qconfig,
617
+ )
618
+ assert qconfig, "qconfig must be provided for QAT module"
619
+ self.qconfig = qconfig
620
+ self.weight_fake_quant = self.qconfig.weight()
621
+
622
+ def forward(self, input):
623
+ return F.relu(
624
+ self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias)
625
+ )
626
+
627
+ @classmethod
628
+ def from_float(cls, mod, use_precomputed_fake_quant=False):
629
+ return super().from_float(
630
+ mod, use_precomputed_fake_quant=use_precomputed_fake_quant
631
+ )
632
+
633
+
634
+ class ConvBn2d(_ConvBnNd, nn.Conv2d):
635
+ r"""
636
+ A ConvBn2d module is a module fused from Conv2d and BatchNorm2d,
637
+ attached with FakeQuantize modules for weight,
638
+ used in quantization aware training.
639
+
640
+ We combined the interface of :class:`torch.nn.Conv2d` and
641
+ :class:`torch.nn.BatchNorm2d`.
642
+
643
+ Similar to :class:`torch.nn.Conv2d`, with FakeQuantize modules initialized
644
+ to default.
645
+
646
+ Attributes:
647
+ freeze_bn:
648
+ weight_fake_quant: fake quant module for weight
649
+
650
+ """
651
+ _FLOAT_MODULE = nni.ConvBn2d
652
+ _FLOAT_CONV_MODULE = nn.Conv2d
653
+ _FLOAT_BN_MODULE = nn.BatchNorm2d
654
+ _FLOAT_RELU_MODULE: None = None
655
+
656
+ def __init__(
657
+ self,
658
+ # ConvNd args
659
+ in_channels,
660
+ out_channels,
661
+ kernel_size,
662
+ stride=1,
663
+ padding=0,
664
+ dilation=1,
665
+ groups=1,
666
+ bias=None,
667
+ padding_mode="zeros",
668
+ # BatchNorm2d args
669
+ # num_features: out_channels
670
+ eps=1e-05,
671
+ momentum=0.1,
672
+ # affine: True
673
+ # track_running_stats: True
674
+ # Args for this module
675
+ freeze_bn=False,
676
+ qconfig=None,
677
+ ):
678
+ kernel_size = _pair(kernel_size)
679
+ stride = _pair(stride)
680
+ padding = _pair(padding)
681
+ dilation = _pair(dilation)
682
+ _ConvBnNd.__init__(
683
+ self,
684
+ in_channels,
685
+ out_channels,
686
+ kernel_size,
687
+ stride,
688
+ padding,
689
+ dilation,
690
+ False,
691
+ _pair(0),
692
+ groups,
693
+ bias,
694
+ padding_mode,
695
+ eps,
696
+ momentum,
697
+ freeze_bn,
698
+ qconfig,
699
+ dim=2,
700
+ )
701
+
702
+
703
+ class ConvBnReLU2d(ConvBn2d):
704
+ r"""
705
+ A ConvBnReLU2d module is a module fused from Conv2d, BatchNorm2d and ReLU,
706
+ attached with FakeQuantize modules for weight,
707
+ used in quantization aware training.
708
+
709
+ We combined the interface of :class:`torch.nn.Conv2d` and
710
+ :class:`torch.nn.BatchNorm2d` and :class:`torch.nn.ReLU`.
711
+
712
+ Similar to `torch.nn.Conv2d`, with FakeQuantize modules initialized to
713
+ default.
714
+
715
+ Attributes:
716
+ weight_fake_quant: fake quant module for weight
717
+
718
+ """
719
+ # base class defines _FLOAT_MODULE as "ConvBn2d"
720
+ _FLOAT_MODULE = nni.ConvBnReLU2d # type: ignore[assignment]
721
+ _FLOAT_CONV_MODULE = nn.Conv2d
722
+ _FLOAT_BN_MODULE = nn.BatchNorm2d
723
+ _FLOAT_RELU_MODULE = nn.ReLU # type: ignore[assignment]
724
+ # module class after fusing bn into conv
725
+ _FUSED_FLOAT_MODULE = nni.ConvReLU2d
726
+
727
+ def __init__(
728
+ self,
729
+ # Conv2d args
730
+ in_channels,
731
+ out_channels,
732
+ kernel_size,
733
+ stride=1,
734
+ padding=0,
735
+ dilation=1,
736
+ groups=1,
737
+ bias=None,
738
+ padding_mode="zeros",
739
+ # BatchNorm2d args
740
+ # num_features: out_channels
741
+ eps=1e-05,
742
+ momentum=0.1,
743
+ # affine: True
744
+ # track_running_stats: True
745
+ # Args for this module
746
+ freeze_bn=False,
747
+ qconfig=None,
748
+ ):
749
+ super().__init__(
750
+ in_channels,
751
+ out_channels,
752
+ kernel_size,
753
+ stride,
754
+ padding,
755
+ dilation,
756
+ groups,
757
+ bias,
758
+ padding_mode,
759
+ eps,
760
+ momentum,
761
+ freeze_bn,
762
+ qconfig,
763
+ )
764
+
765
+ def forward(self, input):
766
+ return F.relu(ConvBn2d._forward(self, input))
767
+
768
+ @classmethod
769
+ def from_float(cls, mod, use_precomputed_fake_quant=False):
770
+ return super().from_float(mod, use_precomputed_fake_quant)
771
+
772
+
773
+ class ConvReLU2d(nnqat.Conv2d, nni._FusedModule):
774
+ r"""A ConvReLU2d module is a fused module of Conv2d and ReLU, attached with
775
+ FakeQuantize modules for weight for
776
+ quantization aware training.
777
+
778
+ We combined the interface of :class:`~torch.nn.Conv2d` and
779
+ :class:`~torch.nn.BatchNorm2d`.
780
+
781
+ Attributes:
782
+ weight_fake_quant: fake quant module for weight
783
+
784
+ """
785
+ _FLOAT_MODULE = nni.ConvReLU2d # type: ignore[assignment]
786
+ _FLOAT_CONV_MODULE = nn.Conv2d
787
+ _FLOAT_BN_MODULE: None = None
788
+ _FLOAT_RELU_MODULE = nn.ReLU
789
+
790
+ def __init__(
791
+ self,
792
+ in_channels,
793
+ out_channels,
794
+ kernel_size,
795
+ stride=1,
796
+ padding=0,
797
+ dilation=1,
798
+ groups=1,
799
+ bias=True,
800
+ padding_mode="zeros",
801
+ qconfig=None,
802
+ ):
803
+ super().__init__(
804
+ in_channels,
805
+ out_channels,
806
+ kernel_size,
807
+ stride=stride,
808
+ padding=padding,
809
+ dilation=dilation,
810
+ groups=groups,
811
+ bias=bias,
812
+ padding_mode=padding_mode,
813
+ qconfig=qconfig,
814
+ )
815
+ assert qconfig, "qconfig must be provided for QAT module"
816
+ self.qconfig = qconfig
817
+ self.weight_fake_quant = self.qconfig.weight()
818
+
819
+ def forward(self, input):
820
+ return F.relu(
821
+ self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias)
822
+ )
823
+
824
+ @classmethod
825
+ def from_float(cls, mod, use_precomputed_fake_quant=False):
826
+ return super().from_float(
827
+ mod, use_precomputed_fake_quant=use_precomputed_fake_quant
828
+ )
829
+
830
+
831
+ class ConvBn3d(_ConvBnNd, nn.Conv3d):
832
+ r"""
833
+ A ConvBn3d module is a module fused from Conv3d and BatchNorm3d,
834
+ attached with FakeQuantize modules for weight,
835
+ used in quantization aware training.
836
+
837
+ We combined the interface of :class:`torch.nn.Conv3d` and
838
+ :class:`torch.nn.BatchNorm3d`.
839
+
840
+ Similar to :class:`torch.nn.Conv3d`, with FakeQuantize modules initialized
841
+ to default.
842
+
843
+ Attributes:
844
+ freeze_bn:
845
+ weight_fake_quant: fake quant module for weight
846
+
847
+ """
848
+ _FLOAT_MODULE = nni.ConvBn3d
849
+ _FLOAT_CONV_MODULE = nn.Conv3d
850
+ _FLOAT_BN_MODULE = nn.BatchNorm3d
851
+ _FLOAT_RELU_MODULE: None = None
852
+
853
+ def __init__(
854
+ self,
855
+ # ConvNd args
856
+ in_channels,
857
+ out_channels,
858
+ kernel_size,
859
+ stride=1,
860
+ padding=0,
861
+ dilation=1,
862
+ groups=1,
863
+ bias=None,
864
+ padding_mode="zeros",
865
+ # BatchNorm3d args
866
+ # num_features: out_channels
867
+ eps=1e-05,
868
+ momentum=0.1,
869
+ # affine: True
870
+ # track_running_stats: True
871
+ # Args for this module
872
+ freeze_bn=False,
873
+ qconfig=None,
874
+ ):
875
+ kernel_size = _triple(kernel_size)
876
+ stride = _triple(stride)
877
+ padding = _triple(padding)
878
+ dilation = _triple(dilation)
879
+ _ConvBnNd.__init__(
880
+ self,
881
+ in_channels,
882
+ out_channels,
883
+ kernel_size,
884
+ stride,
885
+ padding,
886
+ dilation,
887
+ False,
888
+ _triple(0),
889
+ groups,
890
+ bias,
891
+ padding_mode,
892
+ eps,
893
+ momentum,
894
+ freeze_bn,
895
+ qconfig,
896
+ dim=3,
897
+ )
898
+
899
+
900
+ class ConvBnReLU3d(ConvBn3d):
901
+ r"""
902
+ A ConvBnReLU3d module is a module fused from Conv3d, BatchNorm3d and ReLU,
903
+ attached with FakeQuantize modules for weight,
904
+ used in quantization aware training.
905
+
906
+ We combined the interface of :class:`torch.nn.Conv3d` and
907
+ :class:`torch.nn.BatchNorm3d` and :class:`torch.nn.ReLU`.
908
+
909
+ Similar to `torch.nn.Conv3d`, with FakeQuantize modules initialized to
910
+ default.
911
+
912
+ Attributes:
913
+ weight_fake_quant: fake quant module for weight
914
+
915
+ """
916
+ _FLOAT_MODULE = nni.ConvBnReLU3d # type: ignore[assignment]
917
+ _FLOAT_CONV_MODULE = nn.Conv3d
918
+ _FLOAT_BN_MODULE = nn.BatchNorm3d
919
+ _FLOAT_RELU_MODULE = nn.ReLU # type: ignore[assignment]
920
+ # module class after fusing bn into conv
921
+ _FUSED_FLOAT_MODULE = nni.ConvReLU3d
922
+
923
+ def __init__(
924
+ self,
925
+ # Conv3d args
926
+ in_channels,
927
+ out_channels,
928
+ kernel_size,
929
+ stride=1,
930
+ padding=0,
931
+ dilation=1,
932
+ groups=1,
933
+ bias=None,
934
+ padding_mode="zeros",
935
+ # BatchNorm3d args
936
+ # num_features: out_channels
937
+ eps=1e-05,
938
+ momentum=0.1,
939
+ # affine: True
940
+ # track_running_stats: True
941
+ # Args for this module
942
+ freeze_bn=False,
943
+ qconfig=None,
944
+ ):
945
+ super().__init__(
946
+ in_channels,
947
+ out_channels,
948
+ kernel_size,
949
+ stride,
950
+ padding,
951
+ dilation,
952
+ groups,
953
+ bias,
954
+ padding_mode,
955
+ eps,
956
+ momentum,
957
+ freeze_bn,
958
+ qconfig,
959
+ )
960
+
961
+ def forward(self, input):
962
+ return F.relu(ConvBn3d._forward(self, input))
963
+
964
+ @classmethod
965
+ def from_float(cls, mod, use_precomputed_fake_quant=False):
966
+ return super().from_float(
967
+ mod, use_precomputed_fake_quant=use_precomputed_fake_quant
968
+ )
969
+
970
+
971
+ class ConvReLU3d(nnqat.Conv3d, nni._FusedModule):
972
+ r"""A ConvReLU3d module is a fused module of Conv3d and ReLU, attached with
973
+ FakeQuantize modules for weight for
974
+ quantization aware training.
975
+
976
+ We combined the interface of :class:`~torch.nn.Conv3d` and
977
+ :class:`~torch.nn.BatchNorm3d`.
978
+
979
+ Attributes:
980
+ weight_fake_quant: fake quant module for weight
981
+
982
+ """
983
+ _FLOAT_MODULE = nni.ConvReLU3d # type: ignore[assignment]
984
+ _FLOAT_CONV_MODULE = nn.Conv3d
985
+ _FLOAT_BN_MODULE: None = None
986
+ _FLOAT_RELU_MODULE = nn.ReLU
987
+
988
+ def __init__(
989
+ self,
990
+ in_channels,
991
+ out_channels,
992
+ kernel_size,
993
+ stride=1,
994
+ padding=0,
995
+ dilation=1,
996
+ groups=1,
997
+ bias=True,
998
+ padding_mode="zeros",
999
+ qconfig=None,
1000
+ ):
1001
+ super().__init__(
1002
+ in_channels,
1003
+ out_channels,
1004
+ kernel_size,
1005
+ stride=stride,
1006
+ padding=padding,
1007
+ dilation=dilation,
1008
+ groups=groups,
1009
+ bias=bias,
1010
+ padding_mode=padding_mode,
1011
+ qconfig=qconfig,
1012
+ )
1013
+ assert qconfig, "qconfig must be provided for QAT module"
1014
+ self.qconfig = qconfig
1015
+ self.weight_fake_quant = self.qconfig.weight()
1016
+
1017
+ def forward(self, input):
1018
+ return F.relu(
1019
+ self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias)
1020
+ )
1021
+
1022
+ @classmethod
1023
+ def from_float(cls, mod, use_precomputed_fake_quant=False):
1024
+ return super().from_float(
1025
+ mod, use_precomputed_fake_quant=use_precomputed_fake_quant
1026
+ )
1027
+
1028
+
1029
+ def update_bn_stats(mod):
1030
+ if type(mod) in {
1031
+ ConvBnReLU1d,
1032
+ ConvBnReLU2d,
1033
+ ConvBnReLU3d,
1034
+ ConvBn1d,
1035
+ ConvBn2d,
1036
+ ConvBn3d,
1037
+ }:
1038
+ mod.update_bn_stats()
1039
+
1040
+
1041
+ def freeze_bn_stats(mod):
1042
+ if type(mod) in {
1043
+ ConvBnReLU1d,
1044
+ ConvBnReLU2d,
1045
+ ConvBnReLU3d,
1046
+ ConvBn1d,
1047
+ ConvBn2d,
1048
+ ConvBn3d,
1049
+ }:
1050
+ mod.freeze_bn_stats()
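The approximate fused forward in `_forward_approximate` above leans on a per-output-channel identity: with `scale_factor = bn.weight / running_std`, a convolution computed with the scaled weight can be unscaled channel-wise to recover the unscaled convolution, so the weight fake-quant observes the folded weights while batch norm still runs on the original activations. A small numerical check of that identity; the shapes are arbitrary and this sketch is not part of the module above:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(2, 3, 8, 8)                 # input
w = torch.randn(4, 3, 3, 3)                 # conv weight, 4 output channels
s = torch.rand(4) + 0.5                     # stands in for bn.weight / running_std

scaled = F.conv2d(x, w * s.reshape(-1, 1, 1, 1))   # conv with scaled weight
recovered = scaled / s.reshape(1, -1, 1, 1)        # undo the scaling per channel
print(torch.allclose(recovered, F.conv2d(x, w), atol=1e-5))   # True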
.venv/Lib/site-packages/torch/ao/nn/intrinsic/qat/modules/linear_fused.py ADDED
@@ -0,0 +1,193 @@
1
+ # mypy: allow-untyped-defs
2
+ import torch
3
+ import torch.ao.nn.intrinsic as nni
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from torch.nn import init
7
+ from torch.nn.parameter import Parameter
8
+ from torch.nn.utils.fusion import fuse_linear_bn_weights
9
+
10
+
11
+ __all__ = [
12
+ "LinearBn1d",
13
+ ]
14
+
15
+
16
+ class LinearBn1d(nn.modules.linear.Linear, nni._FusedModule):
17
+ r"""
18
+ A LinearBn1d module is a module fused from Linear and BatchNorm1d, attached
19
+ with FakeQuantize modules for weight, used in quantization aware training.
20
+
21
+ We combined the interface of :class:`torch.nn.Linear` and
22
+ :class:torch.nn.BatchNorm1d`.
23
+
24
+ Similar to :class:`torch.nn.Linear`, with FakeQuantize modules initialized
25
+ to default.
26
+
27
+ Attributes:
28
+ freeze_bn:
29
+ weight_fake_quant: fake quant module for weight
30
+
31
+ """
32
+
33
+ def __init__(
34
+ self,
35
+ # Linear args
36
+ in_features,
37
+ out_features,
38
+ bias=True,
39
+ # BatchNorm1d args
40
+ # num_features: out_features
41
+ eps=1e-05,
42
+ momentum=0.1,
43
+ # affine: True
44
+ # track_running_stats: True
45
+ # Args for this module
46
+ freeze_bn=False,
47
+ qconfig=None,
48
+ ):
49
+ nn.modules.linear.Linear.__init__(self, in_features, out_features, bias)
50
+ assert qconfig, "qconfig must be provided for QAT module"
51
+ self.qconfig = qconfig
52
+ self.freeze_bn = freeze_bn if self.training else True
53
+ self.bn = nn.BatchNorm1d(out_features, eps, momentum, True, True)
54
+ self.weight_fake_quant = self.qconfig.weight()
55
+ if bias:
56
+ self.bias = Parameter(torch.empty(out_features))
57
+ else:
58
+ self.register_parameter("bias", None)
59
+ self.reset_bn_parameters()
60
+
61
+ # this needs to be called after reset_bn_parameters,
62
+ # as they modify the same state
63
+ if self.training:
64
+ if freeze_bn:
65
+ self.freeze_bn_stats()
66
+ else:
67
+ self.update_bn_stats()
68
+ else:
69
+ self.freeze_bn_stats()
70
+
71
+ def reset_running_stats(self):
72
+ self.bn.reset_running_stats()
73
+
74
+ def reset_bn_parameters(self):
75
+ self.bn.reset_running_stats()
76
+ init.uniform_(self.bn.weight)
77
+ init.zeros_(self.bn.bias)
78
+
79
+ def reset_parameters(self):
80
+ super().reset_parameters()
81
+
82
+ def update_bn_stats(self):
83
+ self.freeze_bn = False
84
+ self.bn.training = True
85
+ return self
86
+
87
+ def freeze_bn_stats(self):
88
+ self.freeze_bn = True
89
+ self.bn.training = False
90
+ return self
91
+
92
+ def forward(self, input):
93
+ assert self.bn.running_var is not None
94
+
95
+ # Scale the linear weights by BN's running statistics to reduce
96
+ # weight jitter, see https://arxiv.org/pdf/1806.08342.pdf, page 18
97
+ # for motivation.
98
+ #
99
+ # Instead of
100
+ #
101
+ # x1 = F.linear(x0, fq(w), b)
102
+ # x2 = self.bn(x1)
103
+ #
104
+ # We have
105
+ #
106
+ # # scale the weight by previous batch's running statistics
107
+ # scale_factor = bn.w / bn.running_std_from_prev_batch
108
+ # # do the linear transformation without bias
109
+ # x1_scaled = F.linear(x0, fq(w * scale_factor), 0)
110
+ # # reverse the scaling and add original bias
111
+ # x1_orig = x1_scaled / scale_factor + b
112
+ # x2 = self.bn(x1_orig)
113
+
114
+ running_std = torch.sqrt(self.bn.running_var + self.bn.eps)
115
+ scale_factor = self.bn.weight / running_std
116
+ weight_shape = [1] * len(self.weight.shape)
117
+ weight_shape[0] = -1
118
+ bias_shape = [1] * len(self.weight.shape)
119
+ bias_shape[1] = -1
120
+ scaled_weight = self.weight_fake_quant(
121
+ self.weight * scale_factor.reshape(weight_shape)
122
+ )
123
+ if self.bias is not None:
124
+ zero_bias = torch.zeros_like(self.bias)
125
+ else:
126
+ zero_bias = torch.zeros(self.out_features, device=scaled_weight.device)
127
+ linear_out = F.linear(input, scaled_weight, zero_bias)
128
+ linear_out_orig = linear_out / scale_factor.reshape(bias_shape)
129
+ if self.bias is not None:
130
+ linear_out_orig = linear_out_orig + self.bias.reshape(bias_shape)
131
+ bn_out = self.bn(linear_out_orig)
132
+ return bn_out
133
+
134
+ def train(self, mode=True):
135
+ """
136
+ Batchnorm's training behavior is using the self.training flag. Prevent
137
+ changing it if BN is frozen. This makes sure that calling `model.train()`
138
+ on a model with a frozen BN will behave properly.
139
+ """
140
+ self.training = mode
141
+ if not self.freeze_bn:
142
+ for module in self.children():
143
+ module.train(mode)
144
+ return self
145
+
146
+ @classmethod
147
+ def from_float(cls, mod, use_precomputed_fake_quant=False):
148
+ r"""Create a qat module from a float module or qparams_dict
149
+
150
+ Args: `mod' a float module, either produced by torch.ao.quantization
151
+ utilities or directly from user
152
+ """
153
+ assert type(mod) == nni.LinearBn1d, (
154
+ "qat."
155
+ + cls.__name__
156
+ + ".from_float only works for "
157
+ + nni.LinearBn1d.__name__
158
+ )
159
+ assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"
160
+ assert mod.qconfig, "Input float module must have a valid config"
161
+ qconfig = mod.qconfig
162
+ linear, bn = mod[0], mod[1]
163
+ qat_linearbn = cls(
164
+ linear.in_features,
165
+ linear.out_features,
166
+ linear.bias is not None,
167
+ bn.eps,
168
+ bn.momentum,
169
+ False,
170
+ qconfig,
171
+ )
172
+ qat_linearbn.weight = linear.weight
173
+ qat_linearbn.bias = linear.bias
174
+ qat_linearbn.bn.weight = bn.weight
175
+ qat_linearbn.bn.bias = bn.bias
176
+ qat_linearbn.bn.running_mean = bn.running_mean
177
+ qat_linearbn.bn.running_var = bn.running_var
178
+ qat_linearbn.bn.num_batches_tracked = bn.num_batches_tracked
179
+ return qat_linearbn
180
+
181
+ def to_float(self):
182
+ linear = torch.nn.Linear(self.in_features, self.out_features)
183
+ assert self.bn.running_var is not None and self.bn.running_mean is not None
184
+ linear.weight, linear.bias = fuse_linear_bn_weights(
185
+ self.weight,
186
+ self.bias,
187
+ self.bn.running_mean,
188
+ self.bn.running_var,
189
+ self.bn.eps,
190
+ self.bn.weight,
191
+ self.bn.bias,
192
+ )
193
+ return linear
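As a sanity check on the rescaling trick described in the forward comment above, here is a minimal sketch (hypothetical shapes, fake quantization and the zero-bias handling omitted) showing that scaling the weight by bn.weight / running_std and undoing the scale after the matmul reproduces a plain Linear followed by BatchNorm1d:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
x, w, b = torch.randn(8, 4), torch.randn(3, 4), torch.randn(3)
bn = torch.nn.BatchNorm1d(3).eval()
# Give the BN non-trivial statistics so the check is not vacuous.
bn.running_mean.copy_(torch.randn(3))
bn.running_var.copy_(torch.rand(3) + 0.5)
bn.weight.data.copy_(torch.rand(3) + 0.5)
bn.bias.data.copy_(torch.randn(3))

ref = bn(F.linear(x, w, b))                      # x1 = linear, x2 = bn(x1)

scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
out = F.linear(x, w * scale.reshape(-1, 1))      # linear with scaled weight, no bias
out = out / scale.reshape(1, -1) + b             # undo the scaling, re-add original bias
out = bn(out)
print(torch.allclose(ref, out, atol=1e-6))       # True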
.venv/Lib/site-packages/torch/ao/nn/intrinsic/qat/modules/linear_relu.py ADDED
@@ -0,0 +1,51 @@
1
+ # mypy: allow-untyped-defs
2
+ import torch
3
+ import torch.ao.nn.intrinsic as nni
4
+ import torch.ao.nn.qat as nnqat
5
+ import torch.nn.functional as F
6
+
7
+
8
+ class LinearReLU(nnqat.Linear, nni._FusedModule):
9
+ r"""
10
+ A LinearReLU module fused from Linear and ReLU modules, attached with
11
+ FakeQuantize modules for weight, used in
12
+ quantization aware training.
13
+
14
+ We adopt the same interface as :class:`torch.nn.Linear`.
15
+
16
+ Similar to `torch.ao.nn.intrinsic.LinearReLU`, with FakeQuantize modules initialized to
17
+ default.
18
+
19
+ Attributes:
20
+ weight: fake quant module for weight
21
+
22
+ Examples::
23
+
24
+ >>> # xdoctest: +SKIP
25
+ >>> m = nn.qat.LinearReLU(20, 30)
26
+ >>> input = torch.randn(128, 20)
27
+ >>> output = m(input)
28
+ >>> print(output.size())
29
+ torch.Size([128, 30])
30
+ """
31
+ _FLOAT_MODULE = nni.LinearReLU # type: ignore[assignment]
32
+
33
+ def __init__(self, in_features, out_features, bias=True, qconfig=None):
34
+ super().__init__(in_features, out_features, bias, qconfig)
35
+
36
+ def forward(self, input):
37
+ return F.relu(F.linear(input, self.weight_fake_quant(self.weight), self.bias))
38
+
39
+ @classmethod
40
+ def from_float(cls, mod, use_precomputed_fake_quant=False):
41
+ return super().from_float(mod, use_precomputed_fake_quant)
42
+
43
+ def to_float(self):
44
+ linear = torch.nn.Linear(
45
+ self.in_features, self.out_features, self.bias is not None
46
+ )
47
+ linear.weight = torch.nn.Parameter(self.weight.detach())
48
+ if self.bias is not None:
49
+ linear.bias = torch.nn.Parameter(self.bias.detach())
50
+ relu = torch.nn.ReLU()
51
+ return torch.ao.nn.intrinsic.LinearReLU(linear, relu)
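A hedged usage sketch for the QAT module above, assuming a standard QAT qconfig such as get_default_qat_qconfig("fbgemm") is available in this build:

import torch
import torch.ao.nn.intrinsic.qat as nniqat
from torch.ao.quantization import get_default_qat_qconfig

qconfig = get_default_qat_qconfig("fbgemm")
m = nniqat.LinearReLU(20, 30, qconfig=qconfig)
y = m(torch.randn(128, 20))          # fake-quantized weight, ReLU fused in forward
float_m = m.to_float()               # back to torch.ao.nn.intrinsic.LinearReLU
print(y.shape, type(float_m).__name__)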
.venv/Lib/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (235 Bytes). View file
 
.venv/Lib/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/modules/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (289 Bytes). View file
 
.venv/Lib/site-packages/torch/ao/nn/intrinsic/quantized/modules/__init__.py ADDED
@@ -0,0 +1,18 @@
1
+ from .bn_relu import BNReLU2d, BNReLU3d
2
+ from .conv_add import ConvAdd2d, ConvAddReLU2d
3
+ from .conv_relu import ConvReLU1d, ConvReLU2d, ConvReLU3d
4
+ from .linear_relu import LinearLeakyReLU, LinearReLU, LinearTanh
5
+
6
+
7
+ __all__ = [
8
+ "LinearReLU",
9
+ "ConvReLU1d",
10
+ "ConvReLU2d",
11
+ "ConvReLU3d",
12
+ "BNReLU2d",
13
+ "BNReLU3d",
14
+ "LinearLeakyReLU",
15
+ "LinearTanh",
16
+ "ConvAdd2d",
17
+ "ConvAddReLU2d",
18
+ ]
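The __all__ list above is what the eager-mode convert step targets; as a quick illustration, the same classes are reachable through the conventional package alias:

import torch.ao.nn.intrinsic.quantized as nniq
print(nniq.LinearReLU, nniq.ConvAdd2d, nniq.BNReLU2d)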
.venv/Lib/site-packages/torch/ao/nn/intrinsic/quantized/modules/__pycache__/bn_relu.cpython-39.pyc ADDED
Binary file (3.43 kB). View file
 
.venv/Lib/site-packages/torch/ao/nn/intrinsic/quantized/modules/__pycache__/conv_add.cpython-39.pyc ADDED
Binary file (3.79 kB). View file
 
.venv/Lib/site-packages/torch/ao/nn/intrinsic/quantized/modules/__pycache__/conv_relu.cpython-39.pyc ADDED
Binary file (6.51 kB). View file
 
.venv/Lib/site-packages/torch/ao/nn/intrinsic/quantized/modules/bn_relu.py ADDED
@@ -0,0 +1,105 @@
1
+ # mypy: allow-untyped-defs
2
+
3
+ import torch
4
+ import torch.ao.nn.intrinsic
5
+ import torch.ao.nn.intrinsic.qat
6
+ import torch.ao.nn.quantized as nnq
7
+
8
+
9
+ __all__ = ["BNReLU2d", "BNReLU3d"]
10
+
11
+
12
+ class BNReLU2d(nnq.BatchNorm2d):
13
+ r"""
14
+ A BNReLU2d module is a fused module of BatchNorm2d and ReLU
15
+
16
+ We adopt the same interface as :class:`torch.ao.nn.quantized.BatchNorm2d`.
17
+
18
+ Attributes:
19
+ Same as torch.ao.nn.quantized.BatchNorm2d
20
+
21
+ """
22
+ _FLOAT_MODULE = torch.ao.nn.intrinsic.BNReLU2d
23
+
24
+ def __init__(self, num_features, eps=1e-5, momentum=0.1, device=None, dtype=None):
25
+ super().__init__(
26
+ num_features, eps=eps, momentum=momentum, device=device, dtype=dtype
27
+ )
28
+
29
+ def forward(self, input):
30
+ # Temporarily using len(shape) instead of ndim due to JIT issue
31
+ # https://github.com/pytorch/pytorch/issues/23890
32
+ if len(input.shape) != 4:
33
+ raise ValueError("Input shape must be `(N, C, H, W)`!")
34
+ return torch.ops.quantized.batch_norm2d_relu(
35
+ input,
36
+ self.weight,
37
+ self.bias,
38
+ self.running_mean,
39
+ self.running_var,
40
+ self.eps,
41
+ self.scale,
42
+ self.zero_point,
43
+ )
44
+
45
+ def _get_name(self):
46
+ return "QuantizedBNReLU2d"
47
+
48
+ @classmethod
49
+ def from_float(cls, mod, use_precomputed_fake_quant=False):
50
+ # TODO: Add qat support for BNReLU2d
51
+ return super().from_float(
52
+ mod, use_precomputed_fake_quant=use_precomputed_fake_quant
53
+ )
54
+
55
+ @classmethod
56
+ def from_reference(cls, bn_relu, output_scale, output_zero_point):
57
+ return super().from_reference(bn_relu[0], output_scale, output_zero_point)
58
+
59
+
60
+ class BNReLU3d(nnq.BatchNorm3d):
61
+ r"""
62
+ A BNReLU3d module is a fused module of BatchNorm3d and ReLU
63
+
64
+ We adopt the same interface as :class:`torch.ao.nn.quantized.BatchNorm3d`.
65
+
66
+ Attributes:
67
+ Same as torch.ao.nn.quantized.BatchNorm3d
68
+
69
+ """
70
+ _FLOAT_MODULE = torch.ao.nn.intrinsic.BNReLU3d
71
+
72
+ def __init__(self, num_features, eps=1e-5, momentum=0.1, device=None, dtype=None):
73
+ super().__init__(
74
+ num_features, eps=eps, momentum=momentum, device=device, dtype=dtype
75
+ )
76
+
77
+ def forward(self, input):
78
+ # Temporarily using len(shape) instead of ndim due to JIT issue
79
+ # https://github.com/pytorch/pytorch/issues/23890
80
+ if len(input.shape) != 5:
81
+ raise ValueError("Input shape must be `(N, C, D, H, W)`!")
82
+ return torch.ops.quantized.batch_norm3d_relu(
83
+ input,
84
+ self.weight,
85
+ self.bias,
86
+ self.running_mean,
87
+ self.running_var,
88
+ self.eps,
89
+ self.scale,
90
+ self.zero_point,
91
+ )
92
+
93
+ def _get_name(self):
94
+ return "QuantizedBNReLU3d"
95
+
96
+ @classmethod
97
+ def from_float(cls, mod, use_precomputed_fake_quant=False):
98
+ # TODO: Add qat support for BNReLU3d
99
+ return super().from_float(
100
+ mod, use_precomputed_fake_quant=use_precomputed_fake_quant
101
+ )
102
+
103
+ @classmethod
104
+ def from_reference(cls, bn_relu, output_scale, output_zero_point):
105
+ return super().from_reference(bn_relu[0], output_scale, output_zero_point)
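A minimal usage sketch for the fused quantized BN+ReLU above (assuming a CPU build with a quantized engine such as fbgemm; output qparams are left at their defaults of scale=1.0, zero_point=0):

import torch
import torch.ao.nn.intrinsic.quantized as nniq

m = nniq.BNReLU2d(3)                 # default running stats: mean=0, var=1
xq = torch.quantize_per_tensor(torch.randn(2, 3, 8, 8), scale=0.1, zero_point=64, dtype=torch.quint8)
yq = m(xq)                           # single fused batch_norm2d_relu kernel
print(yq.dtype, bool((yq.dequantize() >= 0).all()))   # torch.quint8 True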
.venv/Lib/site-packages/torch/ao/nn/intrinsic/quantized/modules/conv_add.py ADDED
@@ -0,0 +1,145 @@
1
+ # mypy: allow-untyped-defs
2
+ import torch
3
+ import torch.ao.nn.intrinsic
4
+ import torch.ao.nn.intrinsic.qat
5
+ import torch.ao.nn.quantized as nnq
6
+ import torch.nn.functional as F
7
+
8
+
9
+ _reverse_repeat_padding = nnq.modules.conv._reverse_repeat_padding
10
+
11
+
12
+ class ConvAdd2d(nnq.Conv2d):
13
+ r"""
14
+ A ConvAdd2d module is a fused module of Conv2d and Add
15
+
16
+ We adopt the same interface as :class:`torch.ao.nn.quantized.Conv2d`.
17
+
18
+ Attributes:
19
+ Same as torch.ao.nn.quantized.Conv2d
20
+
21
+ """
22
+ _FLOAT_MODULE = torch.ao.nn.intrinsic.ConvAdd2d # type: ignore[assignment]
23
+
24
+ def __init__(
25
+ self,
26
+ in_channels,
27
+ out_channels,
28
+ kernel_size,
29
+ stride=1,
30
+ padding=0,
31
+ dilation=1,
32
+ groups=1,
33
+ bias=True,
34
+ padding_mode="zeros",
35
+ device=None,
36
+ dtype=None,
37
+ ):
38
+ super().__init__(
39
+ in_channels,
40
+ out_channels,
41
+ kernel_size,
42
+ stride=stride,
43
+ padding=padding,
44
+ dilation=dilation,
45
+ groups=groups,
46
+ bias=bias,
47
+ padding_mode=padding_mode,
48
+ device=device,
49
+ dtype=dtype,
50
+ )
51
+
52
+ def forward(self, input, extra_input):
53
+ # Temporarily using len(shape) instead of ndim due to JIT issue
54
+ # https://github.com/pytorch/pytorch/issues/23890
55
+ if len(input.shape) != 4:
56
+ raise ValueError("Input shape must be `(N, C, H, W)`!")
57
+ if self.padding_mode != "zeros":
58
+ _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding)
59
+ input = F.pad(
60
+ input, _reversed_padding_repeated_twice, mode=self.padding_mode
61
+ )
62
+ return torch.ops.quantized.conv2d_add(
63
+ input, extra_input, self._packed_params, self.scale, self.zero_point
64
+ )
65
+
66
+ def _get_name(self):
67
+ return "QuantizedConvAdd2d"
68
+
69
+ @classmethod
70
+ def from_float(cls, mod, use_precomputed_fake_quant=False):
71
+ return super().from_float(
72
+ mod, use_precomputed_fake_quant=use_precomputed_fake_quant
73
+ )
74
+
75
+ @classmethod
76
+ def from_reference(cls, ref_qconv, output_scale, output_zero_point):
77
+ return super().from_reference(ref_qconv[0], output_scale, output_zero_point)
78
+
79
+
80
+ class ConvAddReLU2d(nnq.Conv2d):
81
+ r"""
82
+ A ConvAddReLU2d module is a fused module of Conv2d, Add and Relu
83
+
84
+ We adopt the same interface as :class:`torch.ao.nn.quantized.Conv2d`.
85
+
86
+ Attributes:
87
+ Same as torch.ao.nn.quantized.Conv2d
88
+
89
+ """
90
+ _FLOAT_MODULE = torch.ao.nn.intrinsic.ConvAddReLU2d # type: ignore[assignment]
91
+
92
+ def __init__(
93
+ self,
94
+ in_channels,
95
+ out_channels,
96
+ kernel_size,
97
+ stride=1,
98
+ padding=0,
99
+ dilation=1,
100
+ groups=1,
101
+ bias=True,
102
+ padding_mode="zeros",
103
+ device=None,
104
+ dtype=None,
105
+ ):
106
+ super().__init__(
107
+ in_channels,
108
+ out_channels,
109
+ kernel_size,
110
+ stride=stride,
111
+ padding=padding,
112
+ dilation=dilation,
113
+ groups=groups,
114
+ bias=bias,
115
+ padding_mode=padding_mode,
116
+ device=device,
117
+ dtype=dtype,
118
+ )
119
+
120
+ def forward(self, input, extra_input):
121
+ # Temporarily using len(shape) instead of ndim due to JIT issue
122
+ # https://github.com/pytorch/pytorch/issues/23890
123
+ if len(input.shape) != 4:
124
+ raise ValueError("Input shape must be `(N, C, H, W)`!")
125
+ if self.padding_mode != "zeros":
126
+ _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding)
127
+ input = F.pad(
128
+ input, _reversed_padding_repeated_twice, mode=self.padding_mode
129
+ )
130
+ return torch.ops.quantized.conv2d_add_relu(
131
+ input, extra_input, self._packed_params, self.scale, self.zero_point
132
+ )
133
+
134
+ def _get_name(self):
135
+ return "QuantizedConvAddReLU2d"
136
+
137
+ @classmethod
138
+ def from_float(cls, mod, use_precomputed_fake_quant=False):
139
+ return super().from_float(
140
+ mod, use_precomputed_fake_quant=use_precomputed_fake_quant
141
+ )
142
+
143
+ @classmethod
144
+ def from_reference(cls, ref_qconv, output_scale, output_zero_point):
145
+ return super().from_reference(ref_qconv[0], output_scale, output_zero_point)
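A hedged sketch of the two-tensor forward above. To my understanding the quantized::conv2d_add kernels come from the onednn backend, so this assumes an x86 build where that engine is available; the packed weight is whatever the default constructor produces, so only the shapes are meaningful:

import torch
import torch.ao.nn.intrinsic.quantized as nniq

torch.backends.quantized.engine = "onednn"           # assumed available in this build
m = nniq.ConvAdd2d(3, 3, kernel_size=3, padding=1)
xq = torch.quantize_per_tensor(torch.randn(1, 3, 8, 8), 0.1, 64, torch.quint8)
extra = torch.quantize_per_tensor(torch.randn(1, 3, 8, 8), 0.1, 64, torch.quint8)
yq = m(xq, extra)                                    # quantize(conv(x) + extra)
print(yq.shape)                                      # torch.Size([1, 3, 8, 8])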
.venv/Lib/site-packages/torch/ao/nn/intrinsic/quantized/modules/conv_relu.py ADDED
@@ -0,0 +1,263 @@
1
+ # mypy: allow-untyped-defs
2
+
3
+ import torch
4
+ import torch.ao.nn.intrinsic
5
+ import torch.ao.nn.intrinsic.qat
6
+ import torch.ao.nn.quantized as nnq
7
+ import torch.nn.functional as F
8
+ from torch.nn.utils import fuse_conv_bn_weights
9
+
10
+
11
+ __all__ = [
12
+ "ConvReLU1d",
13
+ "ConvReLU2d",
14
+ "ConvReLU3d",
15
+ ]
16
+
17
+ _reverse_repeat_padding = nnq.modules.conv._reverse_repeat_padding
18
+
19
+
20
+ # TODO: factor out the common parts to ConvNd
21
+ class ConvReLU1d(nnq.Conv1d):
22
+ r"""
23
+ A ConvReLU1d module is a fused module of Conv1d and ReLU
24
+
25
+ We adopt the same interface as :class:`torch.ao.nn.quantized.Conv1d`.
26
+
27
+ Attributes:
28
+ Same as torch.ao.nn.quantized.Conv1d
29
+
30
+ """
31
+ _FLOAT_MODULE = torch.ao.nn.intrinsic.ConvReLU1d # type: ignore[assignment]
32
+
33
+ def __init__(
34
+ self,
35
+ in_channels,
36
+ out_channels,
37
+ kernel_size,
38
+ stride=1,
39
+ padding=0,
40
+ dilation=1,
41
+ groups=1,
42
+ bias=True,
43
+ padding_mode="zeros",
44
+ device=None,
45
+ dtype=None,
46
+ ):
47
+ super().__init__(
48
+ in_channels,
49
+ out_channels,
50
+ kernel_size,
51
+ stride=stride,
52
+ padding=padding,
53
+ dilation=dilation,
54
+ groups=groups,
55
+ bias=bias,
56
+ padding_mode=padding_mode,
57
+ device=device,
58
+ dtype=dtype,
59
+ )
60
+
61
+ def forward(self, input):
62
+ # Temporarily using len(shape) instead of ndim due to JIT issue
63
+ # https://github.com/pytorch/pytorch/issues/23890
64
+ if len(input.shape) != 3:
65
+ raise ValueError("Input shape must be `(N, C, L)`!")
66
+ if self.padding_mode != "zeros":
67
+ # Padding in Conv1d is stored as (p, p), need to get (p,)
68
+ _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding[:1])
69
+ input = F.pad(
70
+ input, _reversed_padding_repeated_twice, mode=self.padding_mode
71
+ )
72
+ return torch.ops.quantized.conv1d_relu(
73
+ input, self._packed_params, self.scale, self.zero_point
74
+ )
75
+
76
+ def _get_name(self):
77
+ return "QuantizedConvReLU1d"
78
+
79
+ @classmethod
80
+ def from_float(cls, mod, use_precomputed_fake_quant=False):
81
+ if type(mod) == torch.ao.nn.intrinsic.qat.ConvBnReLU1d:
82
+ assert mod.bn.running_var is not None and mod.bn.running_mean is not None
83
+ mod.weight, mod.bias = fuse_conv_bn_weights(
84
+ mod.weight,
85
+ mod.bias,
86
+ mod.bn.running_mean,
87
+ mod.bn.running_var,
88
+ mod.bn.eps,
89
+ mod.bn.weight,
90
+ mod.bn.bias,
91
+ )
92
+ return super().from_float(mod, use_precomputed_fake_quant)
93
+
94
+ @classmethod
95
+ def from_reference(cls, ref_qconv, output_scale, output_zero_point):
96
+ assert (
97
+ type(ref_qconv) != torch.ao.nn.intrinsic.ConvBnReLU1d
98
+ ), "BatchNorm1d should be fused into Conv1d before converting to reference module"
99
+ return super().from_reference(ref_qconv[0], output_scale, output_zero_point)
100
+
101
+
102
+ class ConvReLU2d(nnq.Conv2d):
103
+ r"""
104
+ A ConvReLU2d module is a fused module of Conv2d and ReLU
105
+
106
+ We adopt the same interface as :class:`torch.ao.nn.quantized.Conv2d`.
107
+
108
+ Attributes:
109
+ Same as torch.ao.nn.quantized.Conv2d
110
+
111
+ """
112
+ _FLOAT_MODULE = torch.ao.nn.intrinsic.ConvReLU2d # type: ignore[assignment]
113
+
114
+ def __init__(
115
+ self,
116
+ in_channels,
117
+ out_channels,
118
+ kernel_size,
119
+ stride=1,
120
+ padding=0,
121
+ dilation=1,
122
+ groups=1,
123
+ bias=True,
124
+ padding_mode="zeros",
125
+ device=None,
126
+ dtype=None,
127
+ ):
128
+ super().__init__(
129
+ in_channels,
130
+ out_channels,
131
+ kernel_size,
132
+ stride=stride,
133
+ padding=padding,
134
+ dilation=dilation,
135
+ groups=groups,
136
+ bias=bias,
137
+ padding_mode=padding_mode,
138
+ device=device,
139
+ dtype=dtype,
140
+ )
141
+
142
+ def forward(self, input):
143
+ # Temporarily using len(shape) instead of ndim due to JIT issue
144
+ # https://github.com/pytorch/pytorch/issues/23890
145
+ if len(input.shape) != 4:
146
+ raise ValueError("Input shape must be `(N, C, H, W)`!")
147
+ if self.padding_mode != "zeros":
148
+ _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding)
149
+ input = F.pad(
150
+ input, _reversed_padding_repeated_twice, mode=self.padding_mode
151
+ )
152
+ return torch.ops.quantized.conv2d_relu(
153
+ input, self._packed_params, self.scale, self.zero_point
154
+ )
155
+
156
+ def _get_name(self):
157
+ return "QuantizedConvReLU2d"
158
+
159
+ @classmethod
160
+ def from_float(cls, mod, use_precomputed_fake_quant=False):
161
+ if type(mod) == torch.ao.nn.intrinsic.qat.ConvBnReLU2d:
162
+ assert mod.bn.running_var is not None and mod.bn.running_mean is not None
163
+ mod.weight, mod.bias = fuse_conv_bn_weights(
164
+ mod.weight,
165
+ mod.bias,
166
+ mod.bn.running_mean,
167
+ mod.bn.running_var,
168
+ mod.bn.eps,
169
+ mod.bn.weight,
170
+ mod.bn.bias,
171
+ )
172
+ return super().from_float(
173
+ mod, use_precomputed_fake_quant=use_precomputed_fake_quant
174
+ )
175
+
176
+ @classmethod
177
+ def from_reference(cls, ref_qconv, output_scale, output_zero_point):
178
+ assert (
179
+ type(ref_qconv) != torch.ao.nn.intrinsic.ConvBnReLU2d
180
+ ), "BatchNorm2d should be fused into Conv2d before converting to reference module"
181
+ return super().from_reference(ref_qconv[0], output_scale, output_zero_point)
182
+
183
+
184
+ class ConvReLU3d(nnq.Conv3d):
185
+ r"""
186
+ A ConvReLU3d module is a fused module of Conv3d and ReLU
187
+
188
+ We adopt the same interface as :class:`torch.ao.nn.quantized.Conv3d`.
189
+
190
+ Attributes: Same as torch.ao.nn.quantized.Conv3d
191
+
192
+ """
193
+ _FLOAT_MODULE = torch.ao.nn.intrinsic.ConvReLU3d # type: ignore[assignment]
194
+
195
+ def __init__(
196
+ self,
197
+ in_channels,
198
+ out_channels,
199
+ kernel_size,
200
+ stride=1,
201
+ padding=0,
202
+ dilation=1,
203
+ groups=1,
204
+ bias=True,
205
+ padding_mode="zeros",
206
+ device=None,
207
+ dtype=None,
208
+ ):
209
+ assert padding_mode != "reflect", "Conv3d does not support reflection padding"
210
+ super().__init__(
211
+ in_channels,
212
+ out_channels,
213
+ kernel_size,
214
+ stride=stride,
215
+ padding=padding,
216
+ dilation=dilation,
217
+ groups=groups,
218
+ bias=bias,
219
+ padding_mode=padding_mode,
220
+ device=device,
221
+ dtype=dtype,
222
+ )
223
+
224
+ def forward(self, input):
225
+ # Temporarily using len(shape) instead of ndim due to JIT issue
226
+ # https://github.com/pytorch/pytorch/issues/23890
227
+ if len(input.shape) != 5:
228
+ raise ValueError("Input shape must be `(N, C, D, H, W)`!")
229
+ if self.padding_mode != "zeros":
230
+ _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding)
231
+ input = F.pad(
232
+ input, _reversed_padding_repeated_twice, mode=self.padding_mode
233
+ )
234
+ return torch.ops.quantized.conv3d_relu(
235
+ input, self._packed_params, self.scale, self.zero_point
236
+ )
237
+
238
+ def _get_name(self):
239
+ return "QuantizedConvReLU3d"
240
+
241
+ @classmethod
242
+ def from_float(cls, mod, use_precomputed_fake_quant=False):
243
+ if type(mod) == torch.ao.nn.intrinsic.qat.ConvBnReLU3d:
244
+ assert mod.bn.running_var is not None and mod.bn.running_mean is not None
245
+ mod.weight, mod.bias = fuse_conv_bn_weights(
246
+ mod.weight,
247
+ mod.bias,
248
+ mod.bn.running_mean,
249
+ mod.bn.running_var,
250
+ mod.bn.eps,
251
+ mod.bn.weight,
252
+ mod.bn.bias,
253
+ )
254
+ return super().from_float(
255
+ mod, use_precomputed_fake_quant=use_precomputed_fake_quant
256
+ )
257
+
258
+ @classmethod
259
+ def from_reference(cls, ref_qconv, output_scale, output_zero_point):
260
+ assert (
261
+ type(ref_qconv) != torch.ao.nn.intrinsic.ConvBnReLU3d
262
+ ), "BatchNorm3d should be fused into Conv3d before converting to reference module"
263
+ return super().from_reference(ref_qconv[0], output_scale, output_zero_point)
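A minimal sketch of the fused quantized Conv+ReLU in isolation (default packed weights, so the values are arbitrary; assumes a CPU quantized engine such as fbgemm):

import torch
import torch.ao.nn.intrinsic.quantized as nniq

m = nniq.ConvReLU2d(3, 8, kernel_size=3)
xq = torch.quantize_per_tensor(torch.randn(1, 3, 32, 32), 0.05, 64, torch.quint8)
yq = m(xq)                                            # conv2d_relu in one quantized kernel
print(yq.shape, bool((yq.dequantize() >= 0).all()))   # torch.Size([1, 8, 30, 30]) True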
.venv/Lib/site-packages/torch/ao/nn/intrinsic/quantized/modules/linear_relu.py ADDED
@@ -0,0 +1,187 @@
1
+ # mypy: allow-untyped-defs
2
+ import torch
3
+ import torch.ao.nn.intrinsic as nni
4
+ import torch.ao.nn.quantized as nnq
5
+ from torch.ao.nn.quantized.modules.utils import _quantize_weight
6
+
7
+
8
+ __all__ = [
9
+ "LinearReLU",
10
+ "LinearLeakyReLU",
11
+ "LinearTanh",
12
+ ]
13
+
14
+
15
+ class LinearReLU(nnq.Linear):
16
+ r"""
17
+ A LinearReLU module fused from Linear and ReLU modules
18
+
19
+ We adopt the same interface as :class:`torch.ao.nn.quantized.Linear`.
20
+
21
+ Attributes:
22
+ Same as torch.ao.nn.quantized.Linear
23
+
24
+ Examples::
25
+
26
+ >>> # xdoctest: +SKIP
27
+ >>> m = nn.intrinsic.LinearReLU(20, 30)
28
+ >>> input = torch.randn(128, 20)
29
+ >>> output = m(input)
30
+ >>> print(output.size())
31
+ torch.Size([128, 30])
32
+ """
33
+ _FLOAT_MODULE = nni.LinearReLU # type: ignore[assignment]
34
+
35
+ def __init__(self, in_features, out_features, bias=True, dtype=torch.qint8):
36
+ super().__init__(in_features, out_features, bias, dtype)
37
+
38
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
39
+ return torch.ops.quantized.linear_relu(
40
+ x, self._packed_params._packed_params, self.scale, self.zero_point
41
+ )
42
+
43
+ def _get_name(self):
44
+ return "QuantizedLinearReLU"
45
+
46
+ @classmethod
47
+ def from_float(cls, mod, use_precomputed_fake_quant=False):
48
+ return super().from_float(mod, use_precomputed_fake_quant)
49
+
50
+ @classmethod
51
+ def from_reference(cls, ref_linear_relu, output_scale, output_zero_point):
52
+ return super().from_reference(
53
+ ref_linear_relu[0], output_scale, output_zero_point
54
+ )
55
+
56
+
57
+ class LinearLeakyReLU(nnq.Linear):
58
+ r"""
59
+ For onednn backend only
60
+ A LinearLeakyReLU module fused from Linear and LeakyReLU modules
61
+ We adopt the same interface as :class:`torch.ao.nn.quantized.Linear`.
62
+ Attributes:
63
+ Same as torch.ao.nn.quantized.Linear
64
+ + negative_slope
65
+ Examples::
66
+ >>> # xdoctest: +SKIP
67
+ >>> m = nn.intrinsic.LinearLeakyReLU(20, 30, 0.01)
68
+ >>> input = torch.randn(128, 20)
69
+ >>> output = m(input)
70
+ >>> print(output.size())
71
+ torch.Size([128, 30])
72
+ """
73
+ _FLOAT_MODULE = nni.LinearLeakyReLU # type: ignore[assignment]
74
+
75
+ def __init__(
76
+ self, in_features, out_features, negative_slope, bias=True, dtype=torch.qint8
77
+ ):
78
+ super().__init__(in_features, out_features, bias, dtype)
79
+ self.negative_slope = negative_slope
80
+
81
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
82
+ return torch.ops.quantized.linear_leaky_relu(
83
+ x,
84
+ self._packed_params._packed_params,
85
+ self.scale,
86
+ self.zero_point,
87
+ self.negative_slope,
88
+ )
89
+
90
+ def _get_name(self):
91
+ return "QuantizedLinearLeakyReLU"
92
+
93
+ @classmethod
94
+ def from_float(cls, mod, use_precomputed_fake_quant=False):
95
+ assert (
96
+ type(mod) == nni.LinearLeakyReLU
97
+ ), "Input float module should be LinearLeakyReLU"
98
+ assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"
99
+ activation_post_process = mod.activation_post_process
100
+ leaky_relu = mod[1]
101
+ mod = mod[0]
102
+ weight_post_process = mod.qconfig.weight()
103
+ weight_post_process(mod.weight)
104
+ dtype = weight_post_process.dtype
105
+ act_scale, act_zp = activation_post_process.calculate_qparams() # type: ignore[union-attr,operator]
106
+ assert dtype == torch.qint8, "Weight observer must have dtype torch.qint8"
107
+ qweight = _quantize_weight(mod.weight.float(), weight_post_process)
108
+ qlinear_leaky_relu = cls(
109
+ mod.in_features, mod.out_features, leaky_relu.negative_slope, dtype=dtype
110
+ )
111
+ qlinear_leaky_relu.set_weight_bias(qweight, mod.bias)
112
+ qlinear_leaky_relu.scale = float(act_scale)
113
+ qlinear_leaky_relu.zero_point = int(act_zp)
114
+ return qlinear_leaky_relu
115
+
116
+ @classmethod
117
+ def from_reference(cls, ref_mod, output_scale, output_zero_point):
118
+ linear = ref_mod[0]
119
+ leaky_relu = ref_mod[1]
120
+ qlinear_leaky_relu = cls(
121
+ linear.in_features, linear.out_features, leaky_relu.negative_slope
122
+ )
123
+ qweight = linear.get_quantized_weight()
124
+ qlinear_leaky_relu.set_weight_bias(qweight, linear.bias)
125
+ qlinear_leaky_relu.scale = float(output_scale)
126
+ qlinear_leaky_relu.zero_point = int(output_zero_point)
127
+ return qlinear_leaky_relu
128
+
129
+
130
+ class LinearTanh(nnq.Linear):
131
+ r"""
132
+ A LinearTanh module fused from Linear and Tanh modules
133
+
134
+ We adopt the same interface as :class:`torch.ao.nn.quantized.Linear`.
135
+
136
+ Attributes:
137
+ Same as torch.ao.nn.quantized.Linear
138
+
139
+ Examples::
140
+
141
+ >>> # xdoctest: +SKIP
142
+ >>> m = nn.intrinsic.LinearTanh(20, 30)
143
+ >>> input = torch.randn(128, 20)
144
+ >>> output = m(input)
145
+ >>> print(output.size())
146
+ torch.Size([128, 30])
147
+ """
148
+ _FLOAT_MODULE = nni.LinearTanh # type: ignore[assignment]
149
+
150
+ def __init__(self, in_features, out_features, bias=True, dtype=torch.qint8):
151
+ super().__init__(in_features, out_features, bias, dtype)
152
+
153
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
154
+ return torch.ops.quantized.linear_tanh(
155
+ x, self._packed_params._packed_params, self.scale, self.zero_point
156
+ )
157
+
158
+ def _get_name(self):
159
+ return "QuantizedLinearTanh"
160
+
161
+ @classmethod
162
+ def from_float(cls, mod, use_precomputed_fake_quant=False):
163
+ assert type(mod) == nni.LinearTanh, "Input float module should be LinearTanh"
164
+ assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"
165
+ activation_post_process = mod.activation_post_process
166
+ mod = mod[0]
167
+ weight_post_process = mod.qconfig.weight()
168
+ weight_post_process(mod.weight)
169
+ dtype = weight_post_process.dtype
170
+ act_scale, act_zp = activation_post_process.calculate_qparams() # type: ignore[union-attr,operator]
171
+ assert dtype == torch.qint8, "Weight observer must have dtype torch.qint8"
172
+ qweight = _quantize_weight(mod.weight.float(), weight_post_process)
173
+ qlinear_tanh = cls(mod.in_features, mod.out_features, dtype=dtype)
174
+ qlinear_tanh.set_weight_bias(qweight, mod.bias)
175
+ qlinear_tanh.scale = float(act_scale)
176
+ qlinear_tanh.zero_point = int(act_zp)
177
+ return qlinear_tanh
178
+
179
+ @classmethod
180
+ def from_reference(cls, ref_mod, output_scale, output_zero_point):
181
+ linear = ref_mod[0]
182
+ qlinear_tanh = cls(linear.in_features, linear.out_features)
183
+ qweight = linear.get_quantized_weight()
184
+ qlinear_tanh.set_weight_bias(qweight, linear.bias)
185
+ qlinear_tanh.scale = float(output_scale)
186
+ qlinear_tanh.zero_point = int(output_zero_point)
187
+ return qlinear_tanh
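In practice these fused modules are produced by the eager-mode flow rather than constructed by hand. A hedged end-to-end sketch, assuming the fbgemm engine and the standard torch.ao.quantization eager APIs:

import torch
from torch import nn
import torch.ao.quantization as tq

class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = tq.QuantStub()
        self.fc = nn.Linear(20, 30)
        self.relu = nn.ReLU()
        self.dequant = tq.DeQuantStub()

    def forward(self, x):
        return self.dequant(self.relu(self.fc(self.quant(x))))

m = M().eval()
m.qconfig = tq.get_default_qconfig("fbgemm")
m = tq.fuse_modules(m, [["fc", "relu"]])   # fc becomes an intrinsic LinearReLU
m = tq.prepare(m)
m(torch.randn(4, 20))                      # calibration pass
m = tq.convert(m)
print(m.fc)                                # QuantizedLinearReLU(...)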
.venv/Lib/site-packages/torch/ao/nn/quantized/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (674 Bytes). View file
 
.venv/Lib/site-packages/torch/ao/nn/quantized/__pycache__/functional.cpython-39.pyc ADDED
Binary file (26.7 kB). View file
 
.venv/Lib/site-packages/torch/ao/nn/quantized/reference/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (395 Bytes). View file
 
.venv/Lib/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (650 Bytes). View file
 
.venv/Lib/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/rnn.cpython-39.pyc ADDED
Binary file (17.6 kB). View file
 
.venv/Lib/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/sparse.cpython-39.pyc ADDED
Binary file (4.1 kB). View file
 
.venv/Lib/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/utils.cpython-39.pyc ADDED
Binary file (7.02 kB). View file
 
.venv/Lib/site-packages/torch/ao/nn/sparse/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from . import quantized
.venv/Lib/site-packages/torch/ao/nn/sparse/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (224 Bytes). View file
 
.venv/Lib/site-packages/torch/ao/nn/sparse/quantized/__init__.py ADDED
@@ -0,0 +1,10 @@
1
+ from torch.ao.nn.sparse.quantized import dynamic
2
+
3
+ from .linear import Linear, LinearPackedParams
4
+
5
+
6
+ __all__ = [
7
+ "dynamic",
8
+ "Linear",
9
+ "LinearPackedParams",
10
+ ]
.venv/Lib/site-packages/torch/ao/nn/sparse/quantized/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (367 Bytes). View file
 
.venv/Lib/site-packages/torch/ao/nn/sparse/quantized/__pycache__/linear.cpython-39.pyc ADDED
Binary file (7.6 kB). View file
 
.venv/Lib/site-packages/torch/ao/nn/sparse/quantized/__pycache__/utils.cpython-39.pyc ADDED
Binary file (1.56 kB). View file
 
.venv/Lib/site-packages/torch/ao/nn/sparse/quantized/dynamic/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ from .linear import Linear
2
+
3
+
4
+ __all__ = [
5
+ "Linear",
6
+ ]
.venv/Lib/site-packages/torch/ao/nn/sparse/quantized/dynamic/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (269 Bytes). View file
 
.venv/Lib/site-packages/torch/ao/nn/sparse/quantized/dynamic/__pycache__/linear.cpython-39.pyc ADDED
Binary file (5.18 kB). View file
 
.venv/Lib/site-packages/torch/ao/nn/sparse/quantized/dynamic/linear.py ADDED
@@ -0,0 +1,188 @@
1
+ # mypy: allow-untyped-defs
2
+ from typing import Optional
3
+
4
+ import torch
5
+ import torch.ao.nn.intrinsic as nni
6
+ from torch.ao.nn.quantized.modules.utils import (
7
+ _hide_packed_params_repr,
8
+ _quantize_weight,
9
+ )
10
+ from torch.ao.nn.sparse.quantized import linear
11
+ from torch.ao.nn.sparse.quantized.utils import LinearBlockSparsePattern
12
+
13
+
14
+ __all__ = ["Linear"]
15
+
16
+
17
+ class Linear(torch.nn.Module):
18
+ r"""
19
+ A dynamically quantized sparse linear module with float tensor as inputs and outputs.
20
+ """
21
+ _version = 1
22
+ _op_type = "sparse_dynamic"
23
+ _FLOAT_MODULE = torch.nn.Linear
24
+
25
+ def __init__(
26
+ self,
27
+ in_features,
28
+ out_features,
29
+ row_block_size,
30
+ col_block_size,
31
+ bias=True,
32
+ dtype=torch.qint8,
33
+ ):
34
+ super().__init__()
35
+
36
+ if dtype != torch.qint8:
37
+ raise NotImplementedError(
38
+ "Only QINT8 is supported for Sparse Quantized Linear Dynamic"
39
+ )
40
+
41
+ self.in_features = in_features
42
+ self.out_features = out_features
43
+
44
+ if bias:
45
+ bias = torch.zeros(self.out_features, dtype=torch.float)
46
+ else:
47
+ bias = None
48
+
49
+ qweight = torch._empty_affine_quantized(
50
+ [out_features, in_features], scale=1, zero_point=0, dtype=torch.qint8
51
+ )
52
+ self._packed_params = linear.LinearPackedParams(
53
+ row_block_size=row_block_size, col_block_size=col_block_size, dtype=dtype
54
+ )
55
+ self._packed_params.set_weight_bias(
56
+ qweight, bias, row_block_size, col_block_size
57
+ )
58
+
59
+ def _get_name(self):
60
+ return "SparseQuantizedDynamicLinear"
61
+
62
+ def extra_repr(self):
63
+ return f"in_features={self.in_features}, out_features={self.out_features}, qscheme={self.weight().qscheme()}"
64
+
65
+ def __repr__(self):
66
+ return _hide_packed_params_repr(self, linear.LinearPackedParams)
67
+
68
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
69
+ return torch.ops.sparse.qlinear_dynamic(x, self._packed_params._packed_params)
70
+
71
+ def _save_to_state_dict(self, destination, prefix, keep_vars):
72
+ super()._save_to_state_dict(destination, prefix, keep_vars)
73
+ destination[prefix + "op_type"] = self._op_type
74
+
75
+ def _load_from_state_dict(
76
+ self,
77
+ state_dict,
78
+ prefix,
79
+ local_metadata,
80
+ strict,
81
+ missing_keys,
82
+ unexpected_keys,
83
+ error_msgs,
84
+ ):
85
+ op_type = state_dict[prefix + "op_type"]
86
+ assert (
87
+ op_type == self._op_type
88
+ ), f"Cannot load from op_type [{op_type}], expecting [{self._op_type}]"
89
+ state_dict.pop(prefix + "op_type")
90
+
91
+ version = local_metadata.get("version", None)
92
+ assert version <= self._version
93
+
94
+ # Is this code valid? In old quantization it seemed to be used to load
95
+ # older model
96
+ weight = state_dict.pop(prefix + "weight")
97
+ bias = state_dict.pop(prefix + "bias")
98
+ state_dict.update(
99
+ {
100
+ prefix + "_packed_params.weight": weight,
101
+ prefix + "_packed_params.bias": bias,
102
+ }
103
+ )
104
+
105
+ super()._load_from_state_dict(
106
+ state_dict,
107
+ prefix,
108
+ local_metadata,
109
+ False,
110
+ missing_keys,
111
+ unexpected_keys,
112
+ error_msgs,
113
+ )
114
+
115
+ def _weight_bias(self):
116
+ return self._packed_params._weight_bias()
117
+
118
+ def weight(self):
119
+ return self._weight_bias()[0]
120
+
121
+ def bias(self):
122
+ return self._weight_bias()[1]
123
+
124
+ def set_weight_bias(
125
+ self,
126
+ w: torch.Tensor,
127
+ b: Optional[torch.Tensor],
128
+ row_block_size: Optional[int],
129
+ col_block_size: Optional[int],
130
+ ) -> None:
131
+ assert row_block_size is not None and col_block_size is not None
132
+ self.out_features = w.shape[0]
133
+ self.in_features = w.shape[1]
134
+ self._packed_params.set_weight_bias(w, b, row_block_size, col_block_size)
135
+
136
+ @classmethod
137
+ def from_float(cls, mod, use_precomputed_fake_quant=False):
138
+ r"""Create a quantized sparse dynamic module from a float module.
139
+
140
+ We only care about the convert at this stage, no need for observers just yet.
141
+ """
142
+ assert type(mod) == cls._FLOAT_MODULE, (
143
+ " nnq."
144
+ + cls.__name__
145
+ + ".from_float only works for "
146
+ + cls._FLOAT_MODULE.__name__
147
+ )
148
+ # TODO: Need to add options to qconfig to avoid the calibration.
149
+ # TODO: Add calibration for the sparsity
150
+ assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"
151
+ if type(mod) == nni.LinearReLU:
152
+ mod = mod[0]
153
+ if mod.qconfig is not None and mod.qconfig.weight is not None:
154
+ weight_observer = mod.qconfig.weight()
155
+ else:
156
+ # We have the circular import issues if we import the qconfig in the beginning of this file:
157
+ # https://github.com/pytorch/pytorch/pull/24231. The current workaround is to postpone the
158
+ # import until we need it.
159
+ from torch.ao.quantization.qconfig import default_dynamic_qconfig
160
+
161
+ weight_observer = default_dynamic_qconfig.weight()
162
+
163
+ # It is important to multiply by the mask BEFORE calling the `weight_observer`
164
+ # TODO (zaf): Mask might not be part of the qconfig (T83295194)
165
+ weight = mod.weight
166
+ if getattr(mod.qconfig, "mask", False):
167
+ weight = mod.qconfig.mask * mod.weight
168
+
169
+ weight_observer(weight)
170
+ dtype = weight_observer.dtype
171
+ assert dtype == torch.qint8, "Weight observer must have dtype torch.qint8"
172
+ w_sc, w_zp = weight_observer.calculate_qparams()
173
+ if isinstance(w_zp, torch.Tensor):
174
+ assert not torch.any(w_zp.bool()), "All weight zero points must map to 0"
175
+ else:
176
+ assert w_zp == 0, "Weight zero point must map to 0"
177
+ qweight = _quantize_weight(weight.float(), weight_observer)
178
+
179
+ row_block_size, col_block_size = LinearBlockSparsePattern.block_size()
180
+ qlinear = cls(
181
+ mod.in_features,
182
+ mod.out_features,
183
+ row_block_size,
184
+ col_block_size,
185
+ dtype=dtype,
186
+ )
187
+ qlinear.set_weight_bias(qweight, mod.bias, row_block_size, col_block_size)
188
+ return qlinear
.venv/Lib/site-packages/torch/ao/nn/sparse/quantized/linear.py ADDED
@@ -0,0 +1,273 @@
1
+ # mypy: allow-untyped-decorators
2
+ # mypy: allow-untyped-defs
3
+ from typing import Optional
4
+
5
+ import torch
6
+ from torch.ao.nn.quantized.modules.utils import (
7
+ _hide_packed_params_repr,
8
+ _quantize_weight,
9
+ )
10
+
11
+
12
+ __all__ = ["LinearPackedParams", "Linear"]
13
+
14
+
15
+ # TODO (zaf): Inherit from `quantized.LinearPackedParams` (T83294430)
16
+ class LinearPackedParams(torch.nn.Module):
17
+ _version = 1
18
+
19
+ def __init__(self, row_block_size=1, col_block_size=4, dtype=torch.qint8):
20
+ super().__init__()
21
+
22
+ if dtype != torch.qint8:
23
+ raise NotImplementedError("Linear prepacking only supports QINT8")
24
+ self.dtype = dtype
25
+ wq = torch._empty_affine_quantized(
26
+ [1, 1], scale=1.0, zero_point=0, dtype=torch.qint8
27
+ )
28
+ self.set_weight_bias(wq, None, row_block_size, col_block_size)
29
+
30
+ def _get_name(self):
31
+ return "SparseQuantizedLinearPackedParams"
32
+
33
+ @torch.jit.export
34
+ def set_weight_bias(
35
+ self,
36
+ weight: torch.Tensor,
37
+ bias: Optional[torch.Tensor],
38
+ row_block_size: Optional[int],
39
+ col_block_size: Optional[int],
40
+ ) -> None:
41
+ assert row_block_size is not None and col_block_size is not None
42
+ self._packed_params = torch.ops.sparse.qlinear_prepack(
43
+ weight, bias, row_block_size, col_block_size
44
+ )
45
+
46
+ @torch.jit.export
47
+ def _weight_bias(self):
48
+ (weight, bias, block_sizes) = torch.ops.sparse.qlinear_unpack(
49
+ self._packed_params
50
+ )
51
+ return (weight, bias, block_sizes[0], block_sizes[1])
52
+
53
+ def forward(self, x):
54
+ return x
55
+
56
+ def _save_to_state_dict(self, destination, prefix, keep_vars):
57
+ super()._save_to_state_dict(destination, prefix, keep_vars)
58
+ destination[prefix + "dtype"] = self.dtype
59
+ destination[prefix + "_packed_params"] = self._weight_bias()
60
+
61
+ def _load_from_state_dict(
62
+ self,
63
+ state_dict,
64
+ prefix,
65
+ local_metadata,
66
+ strict,
67
+ missing_keys,
68
+ unexpected_keys,
69
+ error_msgs,
70
+ ):
71
+ version = local_metadata.get("version", None)
72
+ assert version <= self._version
73
+
74
+ self.dtype = state_dict.pop(prefix + "dtype")
75
+ weight, bias, row_block_size, col_block_size = state_dict.pop(
76
+ prefix + "_packed_params"
77
+ )
78
+ self.set_weight_bias(weight, bias, row_block_size, col_block_size)
79
+
80
+ super()._load_from_state_dict(
81
+ state_dict,
82
+ prefix,
83
+ local_metadata,
84
+ False,
85
+ missing_keys,
86
+ unexpected_keys,
87
+ error_msgs,
88
+ )
89
+
90
+ @torch.jit.export
91
+ def __getstate__(self):
92
+ return self._packed_params, self.training, self.dtype
93
+
94
+ @torch.jit.export
95
+ def __setstate__(self, state):
96
+ (self._packed_params, self.training, self.dtype) = state
97
+
98
+ def __repr__(self):
99
+ return self._weight_bias().__repr__()
100
+
101
+
102
+ # TODO (zaf): Inherit from `quantized.Linear` (T83294430)
103
+ class Linear(torch.nn.Module):
104
+ r"""
105
+ A quantized sparse linear module with quantized tensor as inputs and outputs.
106
+ """
107
+ _version = 1
108
+ _FLOAT_MODULE = torch.nn.Linear
109
+
110
+ def __init__(
111
+ self,
112
+ in_features,
113
+ out_features,
114
+ row_block_size,
115
+ col_block_size,
116
+ bias=True,
117
+ dtype=torch.qint8,
118
+ ):
119
+ super().__init__()
120
+
121
+ if dtype != torch.qint8:
122
+ raise NotImplementedError(
123
+ "Only QINT8 is supported for Sparse Quantized Linear"
124
+ )
125
+
126
+ self.in_features = in_features
127
+ self.out_features = out_features
128
+
129
+ if bias:
130
+ bias = torch.zeros(self.out_features, dtype=torch.float)
131
+ else:
132
+ bias = None
133
+
134
+ qweight = torch._empty_affine_quantized(
135
+ [out_features, in_features], scale=1, zero_point=0, dtype=torch.qint8
136
+ )
137
+ self._packed_params = LinearPackedParams(
138
+ row_block_size=row_block_size, col_block_size=col_block_size, dtype=dtype
139
+ )
140
+ self._packed_params.set_weight_bias(
141
+ qweight, bias, row_block_size, col_block_size
142
+ )
143
+ self.scale = 1.0
144
+ self.zero_point = 0
145
+
146
+ @classmethod
147
+ def _get_name(cls):
148
+ return "SparseQuantizedLinear"
149
+
150
+ def extra_repr(self):
151
+ return (
152
+ f"in_features={self.in_features}, out_features={self.out_features}, scale={self.scale}, "
153
+ f"zero_point={self.zero_point}, qscheme={self.weight().qscheme()}"
154
+ )
155
+
156
+ def __repr__(self):
157
+ return _hide_packed_params_repr(self, LinearPackedParams)
158
+
159
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
160
+ return torch.ops.sparse.qlinear(
161
+ x, self._packed_params._packed_params, self.scale, self.zero_point
162
+ )
163
+
164
+ def _save_to_state_dict(self, destination, prefix, keep_vars):
165
+ super()._save_to_state_dict(destination, prefix, keep_vars)
166
+ destination[prefix + "scale"] = torch.tensor(self.scale)
167
+ destination[prefix + "zero_point"] = torch.tensor(self.zero_point)
168
+
169
+ def _load_from_state_dict(
170
+ self,
171
+ state_dict,
172
+ prefix,
173
+ local_metadata,
174
+ strict,
175
+ missing_keys,
176
+ unexpected_keys,
177
+ error_msgs,
178
+ ):
179
+ self.scale = float(state_dict[prefix + "scale"])
180
+ state_dict.pop(prefix + "scale")
181
+
182
+ self.zero_point = int(state_dict[prefix + "zero_point"])
183
+ state_dict.pop(prefix + "zero_point")
184
+
185
+ op_type = int(state_dict[prefix + "op_type"])
186
+ state_dict.pop(prefix + "op_type")
187
+
188
+ version = local_metadata.get("version", None)
189
+ assert version <= self._version
190
+
191
+ super()._load_from_state_dict(
192
+ state_dict,
193
+ prefix,
194
+ local_metadata,
195
+ False,
196
+ missing_keys,
197
+ unexpected_keys,
198
+ error_msgs,
199
+ )
200
+
201
+ def _weight_bias(self):
202
+ return self._packed_params._weight_bias()
203
+
204
+ def weight(self):
205
+ return self._weight_bias()[0]
206
+
207
+ def bias(self):
208
+ return self._weight_bias()[1]
209
+
210
+ def set_weight_bias(
211
+ self,
212
+ w: torch.Tensor,
213
+ b: Optional[torch.Tensor],
214
+ row_block_size: Optional[int],
215
+ col_block_size: Optional[int],
216
+ ) -> None:
217
+ assert row_block_size is not None and col_block_size is not None
218
+ self._packed_params.set_weight_bias(w, b, row_block_size, col_block_size)
219
+
220
+ @classmethod
221
+ def from_float(cls, mod, use_precomputed_fake_quant=False):
222
+ r"""Create a quantized sparse module from a float module.
223
+
224
+ We only care about the convert at this stage, no need for observers just yet.
225
+
226
+ TODO(zaf): Need to add the sparse params to the qconfig
227
+ """
228
+ assert type(mod) == cls._FLOAT_MODULE, (
229
+ cls._get_name() + ".from_float only works for " + cls._FLOAT_MODULE.__name__
230
+ )
231
+ assert hasattr(mod, "sparse_params"), (
232
+ "Expecting the Linear to have `sparse_params`. Make sure you have provided arguments "
233
+ 'in the `sparsifier.squash_mask(params_to_save=("sparse_block_shape",))` method.'
234
+ )
235
+ sparse_block_shape = mod.sparse_params.get("sparse_block_shape", None) # type: ignore[operator, union-attr]
236
+ assert isinstance(sparse_block_shape, (tuple, list))
237
+ assert len(sparse_block_shape) == 2
238
+ # TODO: Need to add options to qconfig to avoid the calibration.
239
+ # TODO: Add calibration for the sparsity
240
+ assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"
241
+ activation_post_process = mod.activation_post_process
242
+ weight_post_process = mod.qconfig.weight() # type: ignore[operator, union-attr]
243
+
244
+ # Assumption is that the weight is already sparsified by the
245
+ # `sparsifier.convert`
246
+ weight = mod.weight
247
+
248
+ weight_post_process(weight)
249
+ dtype = weight_post_process.dtype
250
+ act_scale, act_zp = activation_post_process.calculate_qparams() # type: ignore[operator, union-attr]
251
+ assert dtype == torch.qint8, "Weight observer must have dtype torch.qint8"
252
+ w_sc, w_zp = weight_post_process.calculate_qparams()
253
+ if isinstance(w_zp, torch.Tensor):
254
+ assert not torch.any(w_zp.bool()), "All weight zero points must map to 0"
255
+ else:
256
+ assert w_zp == 0, "Weight zero point must map to 0"
257
+ qweight = _quantize_weight(weight.float(), weight_post_process)
258
+
259
+ row_block_size = mod.sparse_params["sparse_block_shape"][0] # type: ignore[index]
260
+ col_block_size = mod.sparse_params["sparse_block_shape"][1] # type: ignore[index]
261
+ qlinear = cls(
262
+ mod.in_features,
263
+ mod.out_features,
264
+ row_block_size,
265
+ col_block_size,
266
+ dtype=dtype,
267
+ )
268
+ qlinear.set_weight_bias(
269
+ qweight, mod.bias, row_block_size, col_block_size
270
+ ) # type: ignore[arg-type]
271
+ qlinear.scale = float(act_scale)
272
+ qlinear.zero_point = int(act_zp)
273
+ return qlinear
.venv/Lib/site-packages/torch/ao/nn/sparse/quantized/utils.py ADDED
@@ -0,0 +1,56 @@
1
+ # mypy: allow-untyped-defs
2
+ import threading
3
+
4
+
5
+ __all__ = ["LinearBlockSparsePattern"]
6
+
7
+
8
+ def _is_valid_linear_block_sparse_pattern(row_block_size, col_block_size):
9
+ return (row_block_size == 1 and col_block_size == 4) or (
10
+ row_block_size == 8 and col_block_size == 1
11
+ )
12
+
13
+
14
+ # This is a stop-gap measure as current flow does not allow module
15
+ # specific block sparse pattern.
16
+ # In fact there is no way to convey sparse pattern via module config
17
+ # of quantization flow. Thus using the global context to convey
18
+ # sparsity pattern.
19
+ # Once the flow supports it, this should be removed.
20
+ class LinearBlockSparsePattern:
21
+ rlock = threading.RLock()
22
+ row_block_size = 1
23
+ col_block_size = 4
24
+ prev_row_block_size = 1
25
+ prev_col_block_size = 4
26
+
27
+ def __init__(self, row_block_size=1, col_block_size=4):
28
+ assert _is_valid_linear_block_sparse_pattern(row_block_size, col_block_size)
29
+ LinearBlockSparsePattern.rlock.acquire()
30
+ LinearBlockSparsePattern.prev_row_block_size = (
31
+ LinearBlockSparsePattern.row_block_size
32
+ )
33
+ LinearBlockSparsePattern.prev_col_block_size = (
34
+ LinearBlockSparsePattern.col_block_size
35
+ )
36
+ LinearBlockSparsePattern.row_block_size = row_block_size
37
+ LinearBlockSparsePattern.col_block_size = col_block_size
38
+
39
+ def __enter__(self):
40
+ pass
41
+
42
+ def __exit__(self, exc_type, exc_value, backtrace):
43
+ LinearBlockSparsePattern.row_block_size = (
44
+ LinearBlockSparsePattern.prev_row_block_size
45
+ )
46
+ LinearBlockSparsePattern.col_block_size = (
47
+ LinearBlockSparsePattern.prev_col_block_size
48
+ )
49
+ LinearBlockSparsePattern.rlock.release()
50
+
51
+ @staticmethod
52
+ def block_size():
53
+ return (
54
+ LinearBlockSparsePattern.row_block_size,
55
+ LinearBlockSparsePattern.col_block_size,
56
+ )
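LinearBlockSparsePattern above is used as a context manager to set a process-wide block-sparse pattern; a minimal sketch:

from torch.ao.nn.sparse.quantized.utils import LinearBlockSparsePattern

print(LinearBlockSparsePattern.block_size())       # (1, 4), the default
with LinearBlockSparsePattern(8, 1):               # only (1, 4) and (8, 1) are accepted
    print(LinearBlockSparsePattern.block_size())   # (8, 1) while inside the block
print(LinearBlockSparsePattern.block_size())       # restored to (1, 4) on exit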
.venv/Lib/site-packages/torch/ao/ns/__init__.py ADDED
File without changes
.venv/Lib/site-packages/torch/ao/ns/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (183 Bytes). View file
 
.venv/Lib/site-packages/torch/ao/ns/_numeric_suite.py ADDED
@@ -0,0 +1,563 @@
1
+ # mypy: allow-untyped-defs
2
+ from typing import Any, Callable, Dict, List, Optional, Set, Union
3
+
4
+ import torch
5
+ import torch.ao.nn.quantized as nnq
6
+ import torch.ao.nn.quantized.dynamic as nnqd
7
+ import torch.nn as nn
8
+ from torch.ao.quantization import prepare
9
+ from torch.ao.quantization.quantization_mappings import (
10
+ get_default_compare_output_module_list,
11
+ )
12
+
13
+
14
+ NON_LEAF_MODULE_TO_ADD_OBSERVER_ALLOW_LIST = {
15
+ nnqd.Linear,
16
+ nnq.Linear,
17
+ nnqd.LSTM,
18
+ nn.LSTM,
19
+ }
20
+
21
+
22
+ def _find_match(
23
+ str_list: Union[Dict[str, Any], List[str]],
24
+ key_str: str,
25
+ postfix: str,
26
+ ) -> Optional[str]:
27
+ split_str = key_str.split(".")
28
+ if split_str[-1] == postfix:
29
+ match_string = "".join(key_str.split(".")[0:-1])
30
+ for s2 in str_list:
31
+ pattern1 = "".join(s2.split(".")[0:-1])
32
+ pattern2 = "".join(s2.split(".")[0:-2])
33
+ if match_string == pattern1:
34
+ return s2
35
+ if match_string == pattern2:
36
+ return s2
37
+
38
+ # For matching "fc.weight" and "fc._packed_params._packed_params"
39
+ if postfix == "_packed_params":
40
+ match_string = "".join(key_str.split(".")[0:-2])
41
+ if len(match_string) == 0:
42
+ return None
43
+ for s2 in str_list:
44
+ pattern1 = "".join(s2.split(".")[0:-1])
45
+ pattern2 = "".join(s2.split(".")[0:-2])
46
+ if match_string == pattern1:
47
+ return s2
48
+ if match_string == pattern2:
49
+ return s2
50
+ return None
51
+ else:
52
+ return None
53
+
54
+
55
+ def compare_weights(
56
+ float_dict: Dict[str, Any], quantized_dict: Dict[str, Any]
57
+ ) -> Dict[str, Dict[str, torch.Tensor]]:
58
+ r"""Compare the weights of the float module with its corresponding quantized
59
+ module. Return a dict with key corresponding to module names and each entry being
60
+ a dictionary with two keys 'float' and 'quantized', containing the float and
61
+ quantized weights. This dict can be used to compare and compute the quantization
62
+ error of the weights of float and quantized models.
63
+
64
+ Example usage::
65
+
66
+ wt_compare_dict = compare_weights(
67
+ float_model.state_dict(), qmodel.state_dict())
68
+ for key in wt_compare_dict:
69
+ print(
70
+ key,
71
+ compute_error(
72
+ wt_compare_dict[key]['float'],
73
+ wt_compare_dict[key]['quantized'].dequantize()
74
+ )
75
+ )
76
+
77
+ Args:
78
+ float_dict: state dict of the float model
79
+ quantized_dict: state dict of the quantized model
80
+
81
+ Return:
82
+ weight_dict: dict with key corresponding to module names and each entry being
83
+ a dictionary with two keys 'float' and 'quantized', containing the float and
84
+ quantized weights
85
+ """
86
+ torch._C._log_api_usage_once("quantization_api._numeric_suite.compare_weights")
87
+ weight_dict: Dict[str, Dict] = {}
88
+ for key in quantized_dict:
89
+ match_key = _find_match(float_dict, key, "weight")
90
+ if match_key is not None:
91
+ weight_dict[key] = {}
92
+ weight_dict[key]["float"] = float_dict[match_key]
93
+ weight_dict[key]["quantized"] = quantized_dict[key]
94
+ continue
95
+
96
+ # For matching "fc.weight" and "fc._packed_params._packed_params"
97
+ match_key = _find_match(float_dict, key, "_packed_params")
98
+ if match_key is not None:
99
+ weight_dict[key] = {}
100
+ weight_dict[key]["float"] = float_dict[match_key]
101
+ weight_dict[key]["quantized"] = quantized_dict[key][0]
102
+
103
+ # For LSTM
104
+ split_str = key.split(".")
105
+ if split_str[-1] == "param" and split_str[-3] == "_all_weight_values":
106
+ layer = split_str[-2]
107
+ module_name = ".".join(split_str[:-3])
108
+ float_weight_ih_key = module_name + ".weight_ih_l" + layer
109
+ float_weight_hh_key = module_name + ".weight_hh_l" + layer
110
+ if float_weight_ih_key in float_dict and float_weight_hh_key in float_dict:
111
+ weight_dict[key] = {}
112
+ weight_dict[key]["float"] = float_dict[float_weight_ih_key]
113
+ weight_dict[key]["quantized"] = (
114
+ quantized_dict[key].__getstate__()[0][4][0].__getstate__()[0][0]
115
+ )
116
+ weight_dict[key]["float"] = float_dict[float_weight_hh_key]
117
+ weight_dict[key]["quantized"] = (
118
+ quantized_dict[key].__getstate__()[0][4][1].__getstate__()[0][0]
119
+ )
120
+
121
+ return weight_dict
122
+
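The docstring example above relies on a compute_error helper that this excerpt does not reach. A standalone SQNR-style sketch of such a helper, for context only (the name compute_error_sketch is hypothetical):

import torch

def compute_error_sketch(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Signal-to-quantization-noise ratio (dB) between a float tensor and its
    # dequantized quantized counterpart; higher means smaller quantization error.
    return 20 * torch.log10(torch.norm(x) / torch.norm(x - y))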
123
+
124
+ def _get_logger_dict_helper(
125
+ mod: nn.Module,
126
+ target_dict: Dict[str, Any],
127
+ prefix: str = "",
128
+ ) -> None:
129
+ r"""This is the helper function for get_logger_dict
130
+
131
+ Args:
132
+ mod: module we want to save all logger stats
133
+ prefix: prefix for the current module
134
+ target_dict: the dictionary used to save all logger stats
135
+ """
136
+
137
+ def get_prefix(prefix):
138
+ return prefix if prefix == "" else prefix + "."
139
+
140
+ for name, child in mod.named_children():
141
+ if isinstance(child, Logger):
142
+ target_dict[get_prefix(prefix) + "stats"] = child.stats
143
+ break
144
+
145
+ for name, child in mod.named_children():
146
+ module_prefix = get_prefix(prefix) + name if prefix else name
147
+ _get_logger_dict_helper(child, target_dict, module_prefix)
148
+
149
+
150
+ def get_logger_dict(mod: nn.Module, prefix: str = "") -> Dict[str, Dict]:
151
+ r"""Traverse the modules and save all logger stats into target dict.
152
+ This is mainly used for quantization accuracy debug.
153
+
154
+ Type of loggers supported:
155
+ ShadowLogger: used to log the outputs of the quantized module and its matching float shadow module,
156
+ OutputLogger: used to log the outputs of the modules
157
+
158
+ Args:
159
+ mod: module we want to save all logger stats
160
+ prefix: prefix for the current module
161
+
162
+ Return:
163
+ target_dict: the dictionary used to save all logger stats
164
+
165
+ """
166
+ torch._C._log_api_usage_once("quantization_api._numeric_suite.get_logger_dict")
167
+
168
+ target_dict: Dict[str, Dict] = {}
169
+ _get_logger_dict_helper(mod, target_dict, prefix)
170
+ return target_dict
171
+
172
+
173
+ class Logger(nn.Module):
174
+ r"""Base class for stats logging"""
175
+
176
+ def __init__(self):
177
+ super().__init__()
178
+ self.stats = {}
179
+ # We only insert observer if the op is quantized with static quantization,
180
+ # which is identified by activation_observer.dtype == quint8. This is needed
181
+ # when attaching Logger as observer for FX mode
182
+ self.dtype = torch.quint8
183
+
184
+ def forward(self, x):
185
+ # fmt: off
186
+ """
187
+ """ # blank docblock to make autodoc happy
188
+ # fmt: on
189
+
190
+
191
+ class ShadowLogger(Logger):
192
+ r"""Class used in Shadow module to record the outputs of the original and
193
+ shadow modules.
194
+ """
195
+
196
+ def __init__(self):
197
+ super().__init__()
198
+ self.stats["float"] = []
199
+ self.stats["quantized"] = []
200
+
201
+ def forward(self, x, y):
202
+ # fmt: off
203
+ """
204
+ """ # blank docblock to make autodoc happy
205
+ # fmt: on
206
+ if len(x) > 1:
207
+ x = x[0]
208
+ if len(y) > 1:
209
+ y = y[0]
210
+ self.stats["quantized"].append(x.detach())
211
+ self.stats["float"].append(y.detach())
212
+
213
+
214
+ class OutputLogger(Logger):
215
+ r"""Class used to log the outputs of the module"""
216
+
217
+ def __init__(self):
218
+ super().__init__()
219
+ self.stats["tensor_val"] = []
220
+
221
+ def forward(self, x):
222
+ # fmt: off
223
+ """
224
+ """ # blank docblock to make autodoc happy
225
+ # fmt: on
226
+ self.stats["tensor_val"].append(x)
227
+ return x
228
+
229
+
230
+ def _convert_tuple_to_list(t: Any) -> Any:
231
+ return [_convert_tuple_to_list(x) for x in t] if type(t) is tuple else t
232
+
233
+
234
+ def _dequantize_tensor_list(t: Any) -> Any:
235
+ return (
236
+ [_dequantize_tensor_list(x) for x in t]
237
+ if type(t) is list
238
+ else t.dequantize()
239
+ if t.is_quantized
240
+ else t
241
+ )
242
+
243
+
244
+ class Shadow(nn.Module):
245
+ r"""Shadow module attaches the float module to its matching quantized module
246
+ as the shadow. Then it uses Logger module to process the outputs of both
247
+ modules.
248
+
249
+ Args:
250
+ q_module: module quantized from float_module that we want to shadow
251
+ float_module: float module used to shadow q_module
252
+ logger_cls: type of logger used to process the outputs of q_module and
253
+ float_module. ShadowLogger or custom loggers can be used.
254
+ """
255
+
256
+ def __init__(self, q_module, float_module, logger_cls):
257
+ super().__init__()
258
+ self.orig_module = q_module
259
+ self.shadow_module = float_module
260
+ self.dequant = nnq.DeQuantize()
261
+ self.logger = logger_cls()
262
+
263
+ def forward(self, *x) -> torch.Tensor:
264
+ # fmt: off
265
+ """
266
+ """ # blank docblock to make autodoc happy
267
+ # fmt: on
268
+ xl = _convert_tuple_to_list(x)
269
+ output = self.orig_module(*xl)
270
+ xl_float = _dequantize_tensor_list(xl)
271
+ shadow_output = self.shadow_module(*xl_float)
272
+ self.logger(output, shadow_output)
273
+ return output
274
+
275
+ def add(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
276
+ # fmt: off
277
+ """
278
+ """ # blank docblock to make autodoc happy
279
+ # fmt: on
280
+ output = self.orig_module.add(x, y)
281
+ x = x.dequantize()
282
+ y = y.dequantize()
283
+ shadow_output = self.shadow_module.add(x, y)
284
+ self.logger(output, shadow_output)
285
+ return output
286
+
287
+ def add_scalar(self, x: torch.Tensor, y: float) -> torch.Tensor:
288
+ # fmt: off
289
+ """
290
+ """ # blank docblock to make autodoc happy
291
+ # fmt: on
292
+ output = self.orig_module.add_scalar(x, y)
293
+ x = x.dequantize()
294
+ shadow_output = self.shadow_module.add_scalar(x, y)
295
+ self.logger(output, shadow_output)
296
+ return output
297
+
298
+ def mul(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
299
+ # fmt: off
300
+ """
301
+ """ # blank docblock to make autodoc happy
302
+ # fmt: on
303
+ output = self.orig_module.mul(x, y)
304
+ x = x.dequantize()
305
+ y = y.dequantize()
306
+ shadow_output = self.shadow_module.mul(x, y)
307
+ self.logger(output, shadow_output)
308
+ return output
309
+
310
+ def mul_scalar(self, x: torch.Tensor, y: float) -> torch.Tensor:
311
+ # fmt: off
312
+ """
313
+ """ # blank docblock to make autodoc happy
314
+ # fmt: on
315
+ output = self.orig_module.mul_scalar(x, y)
316
+ x = x.dequantize()
317
+ shadow_output = self.shadow_module.mul_scalar(x, y)
318
+ self.logger(output, shadow_output)
319
+ return output
320
+
321
+ def cat(self, x: List[torch.Tensor], dim: int = 0) -> torch.Tensor:
322
+ # fmt: off
323
+ """
324
+ """ # blank docblock to make autodoc happy
325
+ # fmt: on
326
+ output = self.orig_module.cat(x, dim)
327
+ x = [y.dequantize() for y in x]
328
+ shadow_output = self.shadow_module.cat(x, dim)
329
+ self.logger(output, shadow_output)
330
+ return output
331
+
332
+ def add_relu(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
333
+ # fmt: off
334
+ """
335
+ """ # blank docblock to make autodoc happy
336
+ # fmt: on
337
+ output = self.orig_module.add_relu(x, y)
338
+ x = x.dequantize()
339
+ y = y.dequantize()
340
+ shadow_output = self.shadow_module.add_relu(x, y)
341
+ self.logger(output, shadow_output)
342
+ return output
343
+
344
+
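A hedged sketch of wrapping a single quantized module with Shadow by hand (prepare_model_with_stubs below automates this for a whole model); the modules, shapes and quantization parameters are illustrative:

    import torch
    import torch.ao.nn.quantized as nnq

    float_linear = torch.nn.Linear(4, 4).eval()
    q_linear = nnq.Linear(4, 4)        # stand-in quantized module for illustration
    shadow = Shadow(q_linear, float_linear, ShadowLogger)

    qx = torch.quantize_per_tensor(
        torch.randn(2, 4), scale=0.1, zero_point=0, dtype=torch.quint8)
    shadow(qx)                         # runs both modules on the same input
    # one pair of outputs recorded:
    print(len(shadow.logger.stats["quantized"]), len(shadow.logger.stats["float"]))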
345
+ def prepare_model_with_stubs(
346
+ float_module: nn.Module,
347
+ q_module: nn.Module,
348
+ module_swap_list: Set[type],
349
+ logger_cls: Callable,
350
+ ) -> None:
351
+ r"""Prepare the model by attaching the float module to its matching quantized
352
+ module as the shadow if the float module type is in module_swap_list.
353
+
354
+ Example usage::
355
+
356
+ prepare_model_with_stubs(float_model, q_model, module_swap_list, Logger)
357
+ q_model(data)
358
+ ob_dict = get_logger_dict(q_model)
359
+
360
+ Args:
361
+ float_module: float module used to generate the q_module
362
+ q_module: module quantized from float_module
363
+ module_swap_list: list of float module types to attach the shadow
364
+ logger_cls: type of logger to be used in shadow module to process the outputs of
365
+ quantized module and its float shadow module
366
+ """
367
+ torch._C._log_api_usage_once(
368
+ "quantization_api._numeric_suite.prepare_model_with_stubs"
369
+ )
370
+
371
+ float_module_children = {}
372
+ for name, mod in float_module.named_children():
373
+ float_module_children[name] = mod
374
+
375
+ reassign = {}
376
+ for name, mod in q_module.named_children():
377
+ if name not in float_module_children:
378
+ continue
379
+
380
+ float_mod = float_module_children[name]
381
+
382
+ if type(float_mod) not in module_swap_list:
383
+ prepare_model_with_stubs(float_mod, mod, module_swap_list, logger_cls)
384
+
385
+ # Insert shadow module only if the module is not of the same type as
386
+ # the floating point module
387
+ if type(float_mod) in module_swap_list and not _is_identical_module_type(
388
+ mod, float_mod
389
+ ):
390
+ reassign[name] = Shadow(mod, float_mod, logger_cls)
391
+
392
+ for key, value in reassign.items():
393
+ q_module._modules[key] = value
394
+
395
+
396
+ def _is_identical_module_type(mod1, mod2):
397
+ # Compare whether the two modules are composed of the same submodule types
398
+ mod1_module_types = [type(mod) for mod in mod1.modules()]
399
+ mod2_module_types = [type(mod) for mod in mod2.modules()]
400
+ return mod1_module_types == mod2_module_types
401
+
402
+
403
+ def compare_model_stub(
404
+ float_model: nn.Module,
405
+ q_model: nn.Module,
406
+ module_swap_list: Set[type],
407
+ *data,
408
+ logger_cls=ShadowLogger,
409
+ ) -> Dict[str, Dict]:
410
+ r"""Compare quantized module in a model with its floating point counterpart,
411
+ feeding both of them the same input. Return a dict with key corresponding to
412
+ module names and each entry being a dictionary with two keys 'float' and
413
+ 'quantized', containing the output tensors of the quantized module and its matching
414
+ float shadow module. This dict can be used to compare and compute the module
415
+ level quantization error.
416
+
417
+ This function first calls prepare_model_with_stubs() to swap the quantized
419
+ modules that we want to compare with Shadow modules, each of which takes the
420
+ quantized module, the corresponding float module and a logger as input, and
421
+ creates a forward path inside it so that the float module shadows the quantized
422
+ module on the same input. The logger is customizable; the default logger is
423
+ ShadowLogger, which saves the outputs of the quantized module and the float
424
+ module so that they can be used to compute the module level quantization error.
424
+
425
+ Example usage::
426
+
427
+ module_swap_list = [torchvision.models.quantization.resnet.QuantizableBasicBlock]
428
+ ob_dict = compare_model_stub(float_model,qmodel,module_swap_list, data)
429
+ for key in ob_dict:
430
+ print(key, compute_error(ob_dict[key]['float'], ob_dict[key]['quantized'].dequantize()))
431
+
432
+ Args:
433
+ float_model: float model used to generate the q_model
434
+ q_model: model quantized from float_model
435
+ module_swap_list: list of float module types at which shadow modules will
436
+ be attached.
437
+ data: input data used to run the prepared q_model
438
+ logger_cls: type of logger to be used in shadow module to process the outputs of
439
+ quantized module and its float shadow module
440
+ """
441
+ torch._C._log_api_usage_once("quantization_api._numeric_suite.compare_model_stub")
442
+ prepare_model_with_stubs(float_model, q_model, module_swap_list, logger_cls)
443
+ q_model(*data)
444
+ ob_dict = get_logger_dict(q_model)
445
+ return ob_dict
446
+
447
+
448
+ def get_matching_activations(
449
+ float_module: nn.Module,
450
+ q_module: nn.Module,
451
+ ) -> Dict[str, Dict[str, torch.Tensor]]:
452
+ r"""Find the matching activation between float and quantized modules.
453
+
454
+ Args:
455
+ float_module: float module used to generate the q_module
456
+ q_module: module quantized from float_module
457
+
458
+ Return:
459
+ act_dict: dict with key corresponding to quantized module names and each
460
+ entry being a dictionary with two keys 'float' and 'quantized', containing
461
+ the matching float and quantized activations
462
+ """
463
+ torch._C._log_api_usage_once(
464
+ "quantization_api._numeric_suite.get_matching_activations"
465
+ )
466
+ float_dict = get_logger_dict(float_module)
467
+ quantized_dict = get_logger_dict(q_module)
468
+ act_dict: Dict[str, Dict] = {}
469
+ for key in quantized_dict:
470
+ if len(quantized_dict[key]["tensor_val"]) == 0:
471
+ continue
472
+ match_key = _find_match(sorted(float_dict, reverse=True), key, "stats")
473
+ if match_key is not None:
474
+ act_dict[key] = {}
475
+ act_dict[key]["float"] = float_dict[match_key]["tensor_val"]
476
+ act_dict[key]["quantized"] = quantized_dict[key]["tensor_val"]
477
+ return act_dict
478
+
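A hedged post-processing sketch over the returned act_dict, assuming float_model and q_model were prepared with prepare_model_outputs (below) and run on the same data, and using compute_error (defined earlier in this file and referenced in the docstrings above) as the metric:

    act_dict = get_matching_activations(float_model, q_model)
    for layer_name, entry in act_dict.items():
        for f, q in zip(entry["float"], entry["quantized"]):
            q = q.dequantize() if q.is_quantized else q
            print(layer_name, compute_error(f, q))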
479
+
480
+ def prepare_model_outputs(
481
+ float_module: nn.Module,
482
+ q_module: nn.Module,
483
+ logger_cls=OutputLogger,
484
+ allow_list=None,
485
+ ) -> None:
486
+ r"""Prepare the model by attaching the logger to both float module
487
+ and quantized module if they are in the allow_list.
488
+
489
+ Args:
490
+ float_module: float module used to generate the q_module
491
+ q_module: module quantized from float_module
492
+ logger_cls: type of logger to be attached to float_module and q_module
493
+ allow_list: list of module types to attach logger
494
+ """
495
+ torch._C._log_api_usage_once(
496
+ "quantization_api._numeric_suite.prepare_model_outputs"
497
+ )
498
+ if allow_list is None:
499
+ allow_list = get_default_compare_output_module_list()
500
+
501
+ qconfig_debug = torch.ao.quantization.QConfig(activation=logger_cls, weight=None)
502
+ float_module.qconfig = qconfig_debug # type: ignore[assignment]
503
+ prepare(
504
+ float_module, inplace=True, allow_list=allow_list, prepare_custom_config_dict={}
505
+ )
506
+ q_module.qconfig = qconfig_debug # type: ignore[assignment]
507
+ prepare(
508
+ q_module,
509
+ inplace=True,
510
+ allow_list=allow_list,
511
+ observer_non_leaf_module_list=NON_LEAF_MODULE_TO_ADD_OBSERVER_ALLOW_LIST,
512
+ prepare_custom_config_dict={},
513
+ )
514
+
515
+
516
+ def compare_model_outputs(
517
+ float_model: nn.Module,
518
+ q_model: nn.Module,
519
+ *data,
520
+ logger_cls=OutputLogger,
521
+ allow_list=None,
522
+ ) -> Dict[str, Dict[str, torch.Tensor]]:
523
+ r"""Compare output activations between float and quantized models at
524
+ corresponding locations for the same input. Return a dict with key corresponding
525
+ to quantized module names and each entry being a dictionary with two keys
526
+ 'float' and 'quantized', containing the activations of quantized model and
527
+ float model at matching locations. This dict can be used to compare and
528
+ compute the propagation quantization error.
529
+
530
+ Example usage::
531
+
532
+ act_compare_dict = compare_model_outputs(float_model, qmodel, data)
533
+ for key in act_compare_dict:
534
+ print(
535
+ key,
536
+ compute_error(
537
+ act_compare_dict[key]['float'],
538
+ act_compare_dict[key]['quantized'].dequantize()
539
+ )
540
+ )
541
+
542
+ Args:
543
+ float_model: float model used to generate the q_model
544
+ q_model: model quantized from float_model
545
+ data: input data used to run the prepared float_model and q_model
546
+ logger_cls: type of logger to be attached to float_module and q_module
547
+ allow_list: list of module types to attach logger
548
+
549
+ Return:
550
+ act_compare_dict: dict with key corresponding to quantized module names
551
+ and each entry being a dictionary with two keys 'float' and 'quantized',
552
+ containing the matching float and quantized activations
553
+ """
554
+ torch._C._log_api_usage_once(
555
+ "quantization_api._numeric_suite.compare_model_outputs"
556
+ )
557
+ if allow_list is None:
558
+ allow_list = get_default_compare_output_module_list()
559
+ prepare_model_outputs(float_model, q_model, logger_cls, allow_list)
560
+ float_model(*data)
561
+ q_model(*data)
562
+ act_compare_dict = get_matching_activations(float_model, q_model)
563
+ return act_compare_dict
.venv/Lib/site-packages/torch/ao/ns/_numeric_suite_fx.py ADDED
@@ -0,0 +1,1130 @@
1
+ # mypy: allow-untyped-defs
2
+ """
3
+ This module contains tooling to compare weights and activations
4
+ across models. Example usage::
5
+
6
+ import copy
7
+ import torch
8
+ import torch.ao.quantization.quantize_fx as quantize_fx
9
+ import torch.ao.ns._numeric_suite_fx as ns
10
+
11
+ m = torch.nn.Sequential(torch.nn.Conv2d(1, 1, 1)).eval()
12
+ mp = quantize_fx.prepare_fx(m, {'': torch.ao.quantization.default_qconfig})
13
+ # We convert a copy because we need the original prepared model
14
+ # to be available for comparisons, and `quantize_fx.convert_fx` is inplace.
15
+ mq = quantize_fx.convert_fx(copy.deepcopy(mp))
16
+
17
+ #
18
+ # Comparing weights
19
+ #
20
+
21
+ # extract weight pairs
22
+ weight_comparison = ns.extract_weights('a', mp, 'b', mq)
23
+
24
+ # add SQNR for each comparison, inplace
25
+ ns.extend_logger_results_with_comparison(
26
+ weight_comparison, 'a', 'b', torch.ao.ns.fx.utils.compute_sqnr,
27
+ 'sqnr')
28
+
29
+ # weight_comparison contains the weights from `mp` and `mq` stored
30
+ # in pairs, and can be used for further analysis.
31
+
32
+
33
+ #
34
+ # Comparing activations, with error propagation
35
+ #
36
+
37
+ # add loggers
38
+ mp_ns, mq_ns = ns.add_loggers(
39
+ 'a', copy.deepcopy(mp),
40
+ 'b', copy.deepcopy(mq),
41
+ ns.OutputLogger)
42
+
43
+ # send an example datum to capture intermediate activations
44
+ datum = torch.randn(1, 1, 1, 1)
45
+ mp_ns(datum)
46
+ mq_ns(datum)
47
+
48
+ # extract intermediate activations
49
+ act_comparison = ns.extract_logger_info(
50
+ mp_ns, mq_ns, ns.OutputLogger, 'b')
51
+
52
+ # add SQNR for each comparison, inplace
53
+ ns.extend_logger_results_with_comparison(
54
+ act_comparison, 'a', 'b', torch.ao.ns.fx.utils.compute_sqnr,
55
+ 'sqnr')
56
+
57
+ # act_comparison contains the activations from `mp_ns` and `mq_ns` stored
58
+ # in pairs, and can be used for further analysis.
59
+
60
+ #
61
+ # Comparing activations, without error propagation
62
+ #
63
+
64
+ # create shadow model
65
+ mp_shadows_mq = ns.add_shadow_loggers(
66
+ 'a', copy.deepcopy(mp),
67
+ 'b', copy.deepcopy(mq),
68
+ ns.OutputLogger)
69
+
70
+ # send an example datum to capture intermediate activations
71
+ datum = torch.randn(1, 1, 1, 1)
72
+ mp_shadows_mq(datum)
73
+
74
+ # extract intermediate activations
75
+ shadow_act_comparison = ns.extract_shadow_logger_info(
76
+ mp_shadows_mq, ns.OutputLogger, 'b')
77
+
78
+ # add SQNR for each comparison, inplace
79
+ ns.extend_logger_results_with_comparison(
80
+ shadow_act_comparison, 'a', 'b', torch.ao.ns.fx.utils.compute_sqnr,
81
+ 'sqnr')
82
+
83
+ # shadow_act_comparison contains the activations from `mp_shadows_mq` stored
84
+ # in pairs, and can be used for further analysis.
85
+
86
+ """
87
+
88
+ import collections
89
+ from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TYPE_CHECKING
90
+
91
+ import torch
92
+ import torch.ao.quantization.quantize_fx as quantize_fx
93
+ import torch.nn as nn
94
+ from torch.ao.ns.fx.graph_matcher import (
95
+ get_matching_subgraph_pairs,
96
+ get_type_a_related_to_b,
97
+ )
98
+ from torch.ao.ns.fx.mappings import get_base_name_to_sets_of_related_ops
99
+ from torch.ao.ns.fx.n_shadows_utils import (
100
+ _get_dedup_subgraphs,
101
+ create_add_loggers_graph,
102
+ create_n_transformed_and_logged_copies_of_subgraph,
103
+ create_results_comparison,
104
+ extract_weight_comparison,
105
+ group_results_by_subgraph,
106
+ OutputProp,
107
+ print_n_shadows_summary,
108
+ SHADOW_WRAPPER_NODE_NAME_PREFIX,
109
+ )
110
+ from torch.ao.ns.fx.qconfig_multi_mapping import QConfigMultiMapping
111
+ from torch.ao.quantization import QConfigMapping
112
+ from torch.ao.quantization.backend_config import BackendConfig
113
+ from torch.ao.quantization.backend_config.utils import (
114
+ get_fusion_pattern_to_root_node_getter,
115
+ )
116
+ from torch.ao.quantization.fx.graph_module import _get_observed_graph_module_attr
117
+ from torch.ao.quantization.fx.match_utils import _find_matches
118
+ from torch.ao.quantization.fx.qconfig_mapping_utils import (
119
+ _generate_node_name_to_qconfig,
120
+ )
121
+ from torch.ao.quantization.fx.quantize_handler import _get_pattern_to_quantize_handlers
122
+ from torch.fx import GraphModule
123
+ from torch.fx.graph import Node
124
+
125
+ from .fx.graph_passes import add_loggers_to_model, create_a_shadows_b
126
+ from .fx.ns_types import NSNodeTargetType, NSResultsType, NSSingleResultValuesType
127
+ from .fx.utils import (
128
+ get_target_type_str,
129
+ maybe_add_missing_fqns,
130
+ rekey_logger_info_on_node_name_of_model,
131
+ )
132
+ from .fx.weight_utils import extract_weight_from_node
133
+
134
+
135
+ if TYPE_CHECKING:
136
+ from torch.ao.quantization.qconfig import QConfigAny
137
+
138
+ RNNReturnType = Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
139
+
140
+
141
+ class OutputLogger(nn.Module):
142
+ """
143
+ Base class for capturing intermediate values.
144
+ """
145
+
146
+ stats: List[torch.Tensor]
147
+ stats_rnn: List[RNNReturnType]
148
+
149
+ # Mark as impure so that calls to it will not be removed during DCE.
150
+ _is_impure = True
151
+
152
+ def __init__(
153
+ self,
154
+ ref_node_name: str,
155
+ prev_node_name: str,
156
+ model_name: str,
157
+ ref_name: str,
158
+ prev_node_target_type: str,
159
+ ref_node_target_type: str,
160
+ results_type: str,
161
+ index_within_arg: int,
162
+ index_of_arg: int,
163
+ fqn: Optional[str],
164
+ qconfig_str: Optional[str] = "",
165
+ ):
166
+ super().__init__()
167
+ self.stats: List[torch.Tensor] = []
168
+ self.stats_rnn: List[RNNReturnType] = []
169
+
170
+ # name of the node which was responsible for adding this logger
171
+ # Note:
172
+ # - if we are logging node outputs, this is the same as prev_node_name
173
+ # - if we are logging node inputs, this is the name of the node
174
+ # whose input this logger is logging.
175
+ #
176
+ # example, where logger1 is logging input of op1 and logger2 is logging
177
+ # the output of op1:
178
+ #
179
+ # x1 -> logger1 -> op1 -> logger2 -> x2
180
+ #
181
+ # in this example,
182
+ # - logger1's prev_node_name is x1 and ref_node_name is op1
183
+ # - logger2's prev_node_name is op1 and ref_node_name is op1
184
+ self.ref_node_name = ref_node_name
185
+ # name of the node whose output this Logger is capturing
186
+ self.prev_node_name = prev_node_name
187
+
188
+ # name of the model from which the node originated from
189
+ self.model_name = model_name
190
+ # reference name, used to match loggers from separate models
191
+ # to each other
192
+ self.ref_name = ref_name
193
+ # type of the target of the node whose output this logger is logging
194
+ self.prev_node_target_type = prev_node_target_type
195
+ # type of the target of the node which was responsible for adding this
196
+ # logger
197
+ self.ref_node_target_type = ref_node_target_type
198
+ # what kind of values are inside of stats
199
+ self.results_type = results_type
200
+ # index of this node within the arg of the input/output node
201
+ # for example, in cat([x1, x2, x3], dim=0), x2 would have index_within_arg == 1
202
+ self.index_within_arg = index_within_arg
203
+ # index of this node within the args of the input/output node
204
+ # for example, in add(x1, x2), x2 would have index_of_arg == 1
205
+ self.index_of_arg = index_of_arg
206
+ # fully qualified name
207
+ self.fqn = fqn
208
+ # if loggers are added before prepare_fx, but we do not want
209
+ # collect results of calibration, only results after convert_fx
210
+ # so, we add a flag to control whether this logger collects data
211
+ self.enabled = True
212
+ # string representation of qconfig
213
+ self.qconfig_str = qconfig_str
214
+ # this can be turned off to reduce memory usage during calibration
215
+ self.save_activations = True
216
+
217
+ # Note: cannot annotate the type of x because TorchScript does not support
218
+ # the Union type.
219
+ def forward(self, x):
220
+ # fmt: off
221
+ """
222
+ """ # blank docblock to make autodoc happy
223
+ # fmt: on
224
+ # TODO(future PR): consider designing this better, as the difference
225
+ # between these two flags is subtle and not obvious.
226
+ if not self.enabled:
227
+ return x
228
+ if not self.save_activations:
229
+ return x
230
+ # TODO(future PR): consider refactoring this to better reuse the parent
231
+ # class
232
+ if isinstance(x, torch.Tensor):
233
+ self.stats.append(x.detach())
234
+ elif isinstance(x, tuple) and len(x) == 2 and len(x[1]) == 2:
235
+ new_res = (x[0].detach(), (x[1][0].detach(), x[1][1].detach()))
236
+ self.stats_rnn.append(new_res)
237
+ return x
238
+
239
+ def __repr__(self):
240
+ clean_dict = {
241
+ k: v
242
+ for k, v in self.__dict__.items()
243
+ # skip nn.Module keys
244
+ if (k != "training") and not k.startswith("_")
245
+ }
246
+ return f"OutputLogger({clean_dict})"
247
+
248
+
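The constructor arguments below are normally filled in automatically by add_loggers / add_shadow_loggers; a hedged sketch with placeholder metadata just to show what the logger records:

    import torch

    logger = OutputLogger(
        ref_node_name="linear1", prev_node_name="linear1",
        model_name="a", ref_name="subgraph_0",
        prev_node_target_type="torch.nn.Linear",
        ref_node_target_type="torch.nn.Linear",
        results_type="node_output",   # placeholder string for illustration
        index_within_arg=0, index_of_arg=0, fqn="linear1",
    )
    y = logger(torch.randn(2, 2))     # returns its input unchanged
    assert len(logger.stats) == 1     # one captured tensor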
249
+ class OutputComparisonLogger(OutputLogger):
250
+ """
251
+ Same as OutputLogger, but also requires the original activation
252
+ in order to calculate the comparison at calibration time
253
+ """
254
+
255
+ def __init__(self, *args, **kwargs):
256
+ super().__init__(*args, **kwargs)
257
+ # TODO(future PR): make the comparison function configurable
258
+ self.comparison_fn = torch.ao.ns.fx.utils.compute_sqnr
259
+ self.comparison_fn_name = "sqnr"
260
+ # precalculated comparisons of logger output versus reference
261
+ self.comparisons = []
262
+ # precalculated comparisons function
263
+
264
+ def forward(self, x, x_ref):
265
+ # fmt: off
266
+ """
267
+ """ # blank docblock to make autodoc happy
268
+ # fmt: on
269
+ if not self.enabled:
270
+ return x
271
+ assert isinstance(x, torch.Tensor), "non-tensor inputs not yet supported"
272
+ if self.save_activations:
273
+ # save the activation, for debugging
274
+ self.stats.append(x.detach())
275
+ # save the comparison
276
+ self.comparisons.append(self.comparison_fn(x, x_ref))
277
+ return x
278
+
279
+ def __repr__(self):
280
+ clean_dict = {
281
+ k: v
282
+ for k, v in self.__dict__.items()
283
+ # skip nn.Module keys
284
+ if (k != "training") and not k.startswith("_")
285
+ }
286
+ return f"OutputComparisonLogger({clean_dict})"
287
+
288
+
289
+ class NSTracer(quantize_fx.QuantizationTracer):
290
+ """
291
+ Just like a regular FX quantization tracer, but treats observers and fake_quantize
292
+ modules as leaf modules.
293
+ """
294
+
295
+ def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
296
+ # fmt: off
297
+ """
298
+ """ # blank docblock to make autodoc happy
299
+ # fmt: on
300
+ if isinstance(m, torch.ao.quantization.ObserverBase):
301
+ return True
302
+ elif isinstance(m, torch.ao.quantization.FakeQuantizeBase):
303
+ return True
304
+ return super().is_leaf_module(m, module_qualified_name)
305
+
306
+
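A short sketch of the tracing pattern that extract_weights, add_loggers and add_shadow_loggers below all follow internally; the toy model is illustrative:

    import torch

    model = torch.nn.Sequential(torch.nn.Conv2d(1, 1, 1)).eval()
    tracer = NSTracer([], [])   # no skipped module names or classes
    gm = torch.fx.GraphModule(model, tracer.trace(model))
    print([n.op for n in gm.graph.nodes])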
307
+ def _extract_weights_one_model(
308
+ model_name: str,
309
+ model: GraphModule,
310
+ nodes_and_names_to_instrument: List[Tuple[Node, str]],
311
+ results: NSResultsType,
312
+ op_to_type_to_weight_extraction_fn: Optional[
313
+ Dict[str, Dict[Callable, Callable]]
314
+ ] = None,
315
+ ) -> None:
316
+ torch._C._log_api_usage_once(
317
+ "quantization_api._numeric_suite_fx._extract_weights_one_model"
318
+ )
319
+ for node, ref_name in nodes_and_names_to_instrument:
320
+ res_type = NSSingleResultValuesType.WEIGHT.value
321
+ extracted_weight = extract_weight_from_node(
322
+ node, model, op_to_type_to_weight_extraction_fn
323
+ )
324
+ if extracted_weight:
325
+ if ref_name not in results:
326
+ results[ref_name] = {res_type: {}}
327
+ results[ref_name][res_type][model_name] = [extracted_weight]
328
+
329
+
330
+ def _extract_weights_impl(
331
+ model_name_a: str,
332
+ gm_a: GraphModule,
333
+ model_name_b: str,
334
+ gm_b: GraphModule,
335
+ base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
336
+ unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
337
+ op_to_type_to_weight_extraction_fn: Optional[
338
+ Dict[str, Dict[Callable, Callable]]
339
+ ] = None,
340
+ ) -> NSResultsType:
341
+ torch._C._log_api_usage_once(
342
+ "quantization_api._numeric_suite_fx._extract_weights_impl"
343
+ )
344
+ matched_subgraph_pairs = get_matching_subgraph_pairs(
345
+ gm_a, gm_b, base_name_to_sets_of_related_ops, unmatchable_types_map
346
+ )
347
+
348
+ # split the subgraph pairs into one data structure for each model
349
+ nodes_and_names_to_instrument_a: List[Tuple[Node, str]] = []
350
+ nodes_and_names_to_instrument_b: List[Tuple[Node, str]] = []
351
+ for match_name, match in matched_subgraph_pairs.items():
352
+ subgraph_a, subgraph_b = match
353
+ nodes_and_names_to_instrument_a.append((subgraph_a.base_op_node, match_name))
354
+ nodes_and_names_to_instrument_b.append((subgraph_b.base_op_node, match_name))
355
+
356
+ # populate the results, one model at a time
357
+ results: NSResultsType = {}
358
+ _extract_weights_one_model(
359
+ model_name_a,
360
+ gm_a,
361
+ nodes_and_names_to_instrument_a,
362
+ results,
363
+ op_to_type_to_weight_extraction_fn,
364
+ )
365
+ _extract_weights_one_model(
366
+ model_name_b,
367
+ gm_b,
368
+ nodes_and_names_to_instrument_b,
369
+ results,
370
+ op_to_type_to_weight_extraction_fn,
371
+ )
372
+
373
+ # fill in missing fqn entries
374
+ maybe_add_missing_fqns(results)
375
+
376
+ # rekey on names of nodes in gm_b
377
+ results = rekey_logger_info_on_node_name_of_model(results, model_name_b)
378
+
379
+ return results
380
+
381
+
382
+ def extract_weights(
383
+ model_name_a: str,
384
+ model_a: nn.Module,
385
+ model_name_b: str,
386
+ model_b: nn.Module,
387
+ base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
388
+ unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
389
+ op_to_type_to_weight_extraction_fn: Optional[
390
+ Dict[str, Dict[Callable, Callable]]
391
+ ] = None,
392
+ ) -> NSResultsType:
393
+ """
394
+ Extract weights from model A and model B, and return a comparison.
395
+
396
+ Args:
397
+ model_name_a: string name of model A to use in results
398
+ model_a: model A
399
+ model_name_b: string name of model B to use in results
400
+ model_b: model B
401
+ base_name_to_sets_of_related_ops: optional override of subgraph base nodes, subject to change
402
+ unmatchable_types_map: optional override of unmatchable types, subject to change
403
+ op_to_type_to_weight_extraction_fn: optional override of function which extracts weight
404
+ from a type, subject to change
405
+
406
+ Return:
407
+ NSResultsType, containing the weight comparisons
408
+ """
409
+
410
+ torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.extract_weights")
411
+ if base_name_to_sets_of_related_ops is None:
412
+ base_name_to_sets_of_related_ops = get_base_name_to_sets_of_related_ops()
413
+ type_a_related_to_b = get_type_a_related_to_b(base_name_to_sets_of_related_ops)
414
+
415
+ # TODO(future PR): expose these
416
+ skipped_module_names: List[str] = []
417
+ skipped_module_classes: List[Callable] = []
418
+ tracer_a = NSTracer(skipped_module_names, skipped_module_classes)
419
+ tracer_b = NSTracer(skipped_module_names, skipped_module_classes)
420
+ gm_a = GraphModule(model_a, tracer_a.trace(model_a))
421
+ maybe_model_a_node_name_to_scope = _get_observed_graph_module_attr(
422
+ model_a, "node_name_to_scope"
423
+ )
424
+ if maybe_model_a_node_name_to_scope is not None:
425
+ gm_a._node_name_to_scope = maybe_model_a_node_name_to_scope
426
+ gm_b = GraphModule(model_b, tracer_b.trace(model_b))
427
+ maybe_model_b_node_name_to_scope = _get_observed_graph_module_attr(
428
+ model_b, "node_name_to_scope"
429
+ )
430
+ if maybe_model_b_node_name_to_scope is not None:
431
+ gm_b._node_name_to_scope = maybe_model_b_node_name_to_scope
432
+ return _extract_weights_impl(
433
+ model_name_a,
434
+ gm_a,
435
+ model_name_b,
436
+ gm_b,
437
+ base_name_to_sets_of_related_ops,
438
+ unmatchable_types_map,
439
+ op_to_type_to_weight_extraction_fn,
440
+ )
441
+
442
+
443
+ def _add_loggers_one_model(
444
+ model_name: str,
445
+ model: GraphModule,
446
+ nodes_and_names_to_instrument_inputs: List[Tuple[Node, str, str]],
447
+ nodes_and_names_to_instrument_outputs: List[Tuple[Node, str, str]],
448
+ logger_cls: Callable,
449
+ ) -> nn.Module:
450
+ torch._C._log_api_usage_once(
451
+ "quantization_api._numeric_suite_fx._add_loggers_one_model"
452
+ )
453
+
454
+ # TODO(future PR): do not observe nodes we do not care
455
+ # about (both fp32, denylist, etc)
456
+ node_to_instrument_inputs_to_ref_name: Dict[Node, Tuple[str, str]] = {}
457
+ node_to_instrument_outputs_to_ref_name: Dict[Node, Tuple[str, str]] = {}
458
+ for node, ref_name, ref_node_type in nodes_and_names_to_instrument_inputs:
459
+ node_to_instrument_inputs_to_ref_name[node] = (ref_name, ref_node_type)
460
+ for node, ref_name, ref_node_type in nodes_and_names_to_instrument_outputs:
461
+ node_to_instrument_outputs_to_ref_name[node] = (ref_name, ref_node_type)
462
+
463
+ model = add_loggers_to_model(
464
+ model,
465
+ node_to_instrument_inputs_to_ref_name,
466
+ node_to_instrument_outputs_to_ref_name,
467
+ logger_cls,
468
+ model_name,
469
+ )
470
+ return model
471
+
472
+
473
+ def _add_loggers_impl(
474
+ name_a: str,
475
+ gm_a: GraphModule,
476
+ name_b: str,
477
+ gm_b: GraphModule,
478
+ logger_cls: Callable,
479
+ should_log_inputs: bool,
480
+ base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
481
+ unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
482
+ ) -> Tuple[nn.Module, nn.Module]:
483
+ torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._add_loggers_impl")
484
+ matched_subgraph_pairs = get_matching_subgraph_pairs(
485
+ gm_a, gm_b, base_name_to_sets_of_related_ops, unmatchable_types_map
486
+ )
487
+ nodes_and_names_to_instrument_inputs_a = []
488
+ nodes_and_names_to_instrument_inputs_b = []
489
+ nodes_and_names_to_instrument_outputs_a = []
490
+ nodes_and_names_to_instrument_outputs_b = []
491
+ for match_name, (subgraph_a, subgraph_b) in matched_subgraph_pairs.items():
492
+ ref_node_type_a = get_target_type_str(subgraph_a.base_op_node, gm_a)
493
+ ref_node_type_b = get_target_type_str(subgraph_b.base_op_node, gm_b)
494
+ # Note: for matching inputs we use start_node, such as observing
495
+ # the input of linear in linear-relu
496
+ if should_log_inputs:
497
+ nodes_and_names_to_instrument_inputs_a.append(
498
+ (subgraph_a.start_node, match_name, ref_node_type_a)
499
+ )
500
+ nodes_and_names_to_instrument_inputs_b.append(
501
+ (subgraph_b.start_node, match_name, ref_node_type_b)
502
+ )
503
+ # Note: for matching activations we always use end_node,
504
+ # such as observing the output of relu in linear-relu
505
+ nodes_and_names_to_instrument_outputs_a.append(
506
+ (subgraph_a.end_node, match_name, ref_node_type_a)
507
+ )
508
+ nodes_and_names_to_instrument_outputs_b.append(
509
+ (subgraph_b.end_node, match_name, ref_node_type_b)
510
+ )
511
+
512
+ new_model_a = _add_loggers_one_model(
513
+ name_a,
514
+ gm_a,
515
+ nodes_and_names_to_instrument_inputs_a,
516
+ nodes_and_names_to_instrument_outputs_a,
517
+ logger_cls,
518
+ )
519
+ new_model_b = _add_loggers_one_model(
520
+ name_b,
521
+ gm_b,
522
+ nodes_and_names_to_instrument_inputs_b,
523
+ nodes_and_names_to_instrument_outputs_b,
524
+ logger_cls,
525
+ )
526
+ return (new_model_a, new_model_b)
527
+
528
+
529
+ def add_loggers(
530
+ name_a: str,
531
+ model_a: nn.Module,
532
+ name_b: str,
533
+ model_b: nn.Module,
534
+ logger_cls: Callable,
535
+ should_log_inputs: bool = False,
536
+ base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
537
+ unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
538
+ ) -> Tuple[nn.Module, nn.Module]:
539
+ """
540
+ Instrument model A and model B with loggers.
541
+
542
+ Args:
543
+ name_a: string name of model A to use in results
544
+ model_a: model A
545
+ name_b: string name of model B to use in results
546
+ model_b: model B
547
+ logger_cls: class of Logger to use
548
+ base_name_to_sets_of_related_ops: optional override of subgraph base nodes, subject to change
549
+ unmatchable_types_map: optional override of unmatchable types, subject to change
550
+
551
+ Return:
552
+ Returns a tuple of (model_a_with_loggers, model_b_with_loggers). Modifies both models inplace.
553
+ """
554
+
555
+ torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.add_loggers")
556
+ # TODO(future PR): expose these
557
+ skipped_module_names: List[str] = []
558
+ skipped_module_classes: List[Callable] = []
559
+ tracer_a = NSTracer(skipped_module_names, skipped_module_classes)
560
+ tracer_b = NSTracer(skipped_module_names, skipped_module_classes)
561
+ gm_a = GraphModule(model_a, tracer_a.trace(model_a))
562
+ maybe_model_a_node_name_to_scope = _get_observed_graph_module_attr(
563
+ model_a, "node_name_to_scope"
564
+ )
565
+ if maybe_model_a_node_name_to_scope is not None:
566
+ gm_a._node_name_to_scope = maybe_model_a_node_name_to_scope
567
+ gm_b = GraphModule(model_b, tracer_b.trace(model_b))
568
+ maybe_model_b_node_name_to_scope = _get_observed_graph_module_attr(
569
+ model_b, "node_name_to_scope"
570
+ )
571
+ if maybe_model_b_node_name_to_scope is not None:
572
+ gm_b._node_name_to_scope = maybe_model_b_node_name_to_scope
573
+ return _add_loggers_impl(
574
+ name_a,
575
+ gm_a,
576
+ name_b,
577
+ gm_b,
578
+ logger_cls,
579
+ should_log_inputs=should_log_inputs,
580
+ base_name_to_sets_of_related_ops=base_name_to_sets_of_related_ops,
581
+ unmatchable_types_map=unmatchable_types_map,
582
+ )
583
+
584
+
585
+ def _extract_logger_info_one_model(
586
+ model: nn.Module,
587
+ results: NSResultsType,
588
+ logger_cls: Callable,
589
+ ) -> None:
590
+ torch._C._log_api_usage_once(
591
+ "quantization_api._numeric_suite_fx._extract_logger_info_one_model"
592
+ )
593
+ for gm_name, mod in model.named_modules():
594
+ # TODO(future PR): better check when scripted
595
+ is_logger = isinstance(mod, logger_cls) or ( # type: ignore[arg-type]
596
+ isinstance(mod, torch.jit.RecursiveScriptModule)
597
+ and mod.original_name == "OutputLogger"
598
+ )
599
+ if is_logger:
600
+ key = mod.ref_name
601
+ if key not in results:
602
+ results[key] = {}
603
+ assert (
604
+ mod.model_name not in results[key]
605
+ ), f"{mod.model_name} is already present in results"
606
+ if mod.results_type not in results[key]:
607
+ results[key][mod.results_type] = {}
608
+ if mod.model_name not in results[key][mod.results_type]:
609
+ results[key][mod.results_type][mod.model_name] = []
610
+ stats_to_use = mod.stats
611
+ if len(mod.stats_rnn) > 0:
612
+ stats_to_use = mod.stats_rnn
613
+ data = {
614
+ "type": mod.results_type,
615
+ "values": stats_to_use,
616
+ "ref_node_name": mod.ref_node_name,
617
+ "ref_node_target_type": mod.ref_node_target_type,
618
+ "prev_node_name": mod.prev_node_name,
619
+ "prev_node_target_type": mod.prev_node_target_type,
620
+ "index_within_arg": mod.index_within_arg,
621
+ "index_of_arg": mod.index_of_arg,
622
+ "fqn": mod.fqn,
623
+ "qconfig_str": mod.qconfig_str,
624
+ }
625
+ if hasattr(mod, "comparisons"):
626
+ data["comparisons"] = mod.comparisons
627
+ data["comparison_fn_name"] = mod.comparison_fn_name
628
+ else:
629
+ data["comparisons"] = []
630
+ data["comparison_fn_name"] = ""
631
+ results[key][mod.results_type][mod.model_name].append(data)
632
+ # ensure the list stays sorted
633
+ results[key][mod.results_type][mod.model_name].sort(
634
+ key=lambda res: f"{res['index_of_arg']}:{res['index_within_arg']}"
635
+ )
636
+
637
+
638
+ # TODO(future PR): align on naming
639
+ # this is equivalent of just the comparison extraction part of `ns.compare_model_outputs`
640
+ def extract_logger_info(
641
+ model_a: nn.Module,
642
+ model_b: nn.Module,
643
+ logger_cls: Callable,
644
+ model_name_to_use_for_layer_names: str,
645
+ ) -> NSResultsType:
646
+ """
647
+ Traverse all loggers in `model_a` and `model_b`, and extract the logged
648
+ information.
649
+
650
+ Args:
651
+ model_a: model A
652
+ model_b: model B
653
+ logger_cls: class of Logger to use
654
+ model_name_to_use_for_layer_names: string name of model to use for
655
+ layer names in the output
656
+
657
+ Return:
658
+ NSResultsType, containing the logged comparisons
659
+ """
660
+ torch._C._log_api_usage_once(
661
+ "quantization_api._numeric_suite_fx.extract_logger_info"
662
+ )
663
+ results: NSResultsType = {}
664
+ for model in (model_a, model_b):
665
+ _extract_logger_info_one_model(model, results, logger_cls)
666
+ # fill in missing fqn entries
667
+ maybe_add_missing_fqns(results)
668
+ # rekey on the name of model b
669
+ results = rekey_logger_info_on_node_name_of_model(
670
+ results, model_name_to_use_for_layer_names
671
+ )
672
+ return results
673
+
674
+
675
+ def _add_shadow_loggers_impl(
676
+ name_a: str,
677
+ gm_a: GraphModule,
678
+ name_b: str,
679
+ gm_b: GraphModule,
680
+ logger_cls: Callable,
681
+ should_log_inputs: bool,
682
+ base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
683
+ node_type_to_io_type_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
684
+ unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
685
+ ) -> nn.Module:
686
+ torch._C._log_api_usage_once(
687
+ "quantization_api._numeric_suite_fx._add_shadow_loggers_impl"
688
+ )
689
+ matched_subgraph_pairs = get_matching_subgraph_pairs(
690
+ gm_a, gm_b, base_name_to_sets_of_related_ops, unmatchable_types_map
691
+ )
692
+ gm_a_shadows_b = create_a_shadows_b(
693
+ name_a,
694
+ gm_a,
695
+ name_b,
696
+ gm_b,
697
+ matched_subgraph_pairs,
698
+ logger_cls,
699
+ should_log_inputs=should_log_inputs,
700
+ node_type_to_io_type_map=node_type_to_io_type_map,
701
+ )
702
+ return gm_a_shadows_b
703
+
704
+
705
+ def add_shadow_loggers(
706
+ name_a: str,
707
+ model_a: nn.Module,
708
+ name_b: str,
709
+ model_b: nn.Module,
710
+ logger_cls: Callable,
711
+ should_log_inputs: bool = False,
712
+ base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
713
+ node_type_to_io_type_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
714
+ unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
715
+ ) -> nn.Module:
716
+ """
717
+ Instrument model A and model B with shadow loggers.
718
+
719
+ Args:
720
+ name_a: string name of model A to use in results
721
+ model_a: model A
722
+ name_b: string name of model B to use in results
723
+ model_b: model B
724
+ logger_cls: class of Logger to use
725
+ should_log_inputs: whether to log inputs
726
+ base_name_to_sets_of_related_ops: optional override of subgraph base nodes, subject to change
727
+ unmatchable_types_map: optional override of unmatchable types, subject to change
728
+ """
729
+ torch._C._log_api_usage_once(
730
+ "quantization_api._numeric_suite_fx.add_shadow_loggers"
731
+ )
732
+ # TODO(future PR): expose these
733
+ skipped_module_names: List[str] = []
734
+ skipped_module_classes: List[Callable] = []
735
+ tracer_a = NSTracer(skipped_module_names, skipped_module_classes)
736
+ tracer_b = NSTracer(skipped_module_names, skipped_module_classes)
737
+ gm_a = GraphModule(model_a, tracer_a.trace(model_a))
738
+ maybe_model_a_node_name_to_scope = _get_observed_graph_module_attr(
739
+ model_a, "node_name_to_scope"
740
+ )
741
+ if maybe_model_a_node_name_to_scope is not None:
742
+ gm_a._node_name_to_scope = maybe_model_a_node_name_to_scope
743
+ gm_b = GraphModule(model_b, tracer_b.trace(model_b))
744
+ maybe_model_b_node_name_to_scope = _get_observed_graph_module_attr(
745
+ model_b, "node_name_to_scope"
746
+ )
747
+ if maybe_model_b_node_name_to_scope is not None:
748
+ gm_b._node_name_to_scope = maybe_model_b_node_name_to_scope
749
+ return _add_shadow_loggers_impl(
750
+ name_a,
751
+ gm_a,
752
+ name_b,
753
+ gm_b,
754
+ logger_cls,
755
+ should_log_inputs=should_log_inputs,
756
+ base_name_to_sets_of_related_ops=base_name_to_sets_of_related_ops,
757
+ node_type_to_io_type_map=node_type_to_io_type_map,
758
+ unmatchable_types_map=unmatchable_types_map,
759
+ )
760
+
761
+
762
+ def extract_shadow_logger_info(
763
+ model_a_shadows_b: nn.Module,
764
+ logger_cls: Callable,
765
+ model_name_to_use_for_layer_names: str,
766
+ ) -> NSResultsType:
767
+ """
768
+ Traverse all loggers in a shadow model, and extract the logged
769
+ information.
770
+
771
+ Args:
772
+ model_a_shadows_b: shadow model
773
+ logger_cls: class of Logger to use
774
+ model_name_to_use_for_layer_names: string name of model to use for
775
+ layer names in the output
776
+
777
+ Return:
778
+ NSResultsType, containing the logged comparisons
779
+ """
780
+ torch._C._log_api_usage_once(
781
+ "quantization_api._numeric_suite_fx.extract_shadow_logger_info"
782
+ )
783
+ results: NSResultsType = collections.defaultdict(dict)
784
+ _extract_logger_info_one_model(model_a_shadows_b, results, logger_cls)
785
+ # fill in missing fqn entries
786
+ maybe_add_missing_fqns(results)
787
+ # rekey on the name of model b
788
+ results = rekey_logger_info_on_node_name_of_model(
789
+ results, model_name_to_use_for_layer_names
790
+ )
791
+ return dict(results)
792
+
793
+
794
+ def extend_logger_results_with_comparison(
795
+ results: NSResultsType,
796
+ model_name_1: str,
797
+ model_name_2: str,
798
+ comparison_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
799
+ comparison_name: str,
800
+ ) -> None:
801
+ """
802
+ Compares the logged values from `model_name_2` against the corresponding
803
+ values in `model_name_1`, using `comparison_fn`. Records the result
804
+ in `model_name_2`'s results under `comparison_name`. Modifies `results` inplace.
805
+
806
+ Args:
807
+ results: the result data structure from `extract_logger_info` or
808
+ `extract_shadow_logger_info`.
809
+ model_name_1: string name of model 1
810
+ model_name_2: string name of model 2
811
+ comparison_fn: function to compare two Tensors
812
+ comparison_name: string name of model to use for
813
+ layer names in the output
814
+ """
815
+ for results_type_to_results in results.values():
816
+ for model_name_to_results in results_type_to_results.values():
817
+ assert (
818
+ model_name_1 in model_name_to_results
819
+ ), f"{model_name_1} not found in results"
820
+ assert (
821
+ model_name_2 in model_name_to_results
822
+ ), f"{model_name_2} not found in results"
823
+
824
+ results_1 = model_name_to_results[model_name_1]
825
+ results_2 = model_name_to_results[model_name_2]
826
+
827
+ for result_2 in results_2:
828
+ index_within_arg_2 = result_2["index_within_arg"]
829
+ index_of_arg_2 = result_2["index_of_arg"]
830
+ # find corresponding result_1
831
+ result_1 = None
832
+ for cur_result_1 in results_1:
833
+ index_within_arg_1 = cur_result_1["index_within_arg"]
834
+ index_of_arg_1 = cur_result_1["index_of_arg"]
835
+ if (index_within_arg_1 == index_within_arg_2) and (
836
+ index_of_arg_1 == index_of_arg_2
837
+ ):
838
+ result_1 = cur_result_1
839
+ break
840
+ assert result_1 is not None
841
+
842
+ values_1 = result_1["values"]
843
+ values_2 = result_2["values"]
844
+ result_2[comparison_name] = []
845
+ for value_1, value_2 in zip(values_1, values_2):
846
+ comparison_result = comparison_fn(value_1, value_2)
847
+ result_2[comparison_name].append(comparison_result)
848
+
849
+
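Any callable that takes two tensors and returns a tensor can serve as comparison_fn. A hedged sketch with a max-absolute-difference metric, assuming act_comparison was produced as in the module docstring at the top of this file:

    import torch

    def max_abs_diff(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # dequantize if needed, then reduce to a single scalar tensor
        x = x.dequantize() if x.is_quantized else x
        y = y.dequantize() if y.is_quantized else y
        return (x - y).abs().max()

    extend_logger_results_with_comparison(
        act_comparison, 'a', 'b', max_abs_diff, 'max_abs_diff')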
850
+ def prepare_n_shadows_model(
851
+ model: torch.nn.Module,
852
+ example_inputs: Any,
853
+ qconfig_multi_mapping: QConfigMultiMapping,
854
+ backend_config: BackendConfig,
855
+ custom_prepare_fn: Optional[Callable] = None,
856
+ custom_prepare_kwargs: Optional[Dict[str, Any]] = None,
857
+ custom_tracer: Any = None,
858
+ ) -> GraphModule:
859
+ """
860
+ Given a model with a graph with M ops such as
861
+
862
+
863
+ args_kwargs_m -> op_m -> output_m
864
+
865
+
866
+ And a set of N qconfigs for each op, creates a new model, with
867
+ each subgraph of `op_m` transformed into
868
+
869
+ .. code::
870
+
871
+ |---------> op_m_n -> log_m_n
872
+ | /
873
+ args_kwargs_m ---------> op_m -> log_m_0
874
+
875
+ Where op_m_n is op_m wrapped in a submodule and transformed with
876
+ qconfig_n, and its inner graph looks like
877
+
878
+ .. code::
879
+
880
+ args_m -------- op_m_prepared_with_qconfig_n -> out_m_n
881
+ /
882
+ kwargs_m ---
883
+
884
+ This is useful for testing different quantization configurations for multiple layers in
885
+ a single pass through the model.
886
+
887
+ High level TODOs for future PRs:
888
+ * figure out a better way to name the output structure
889
+ * return a results data structure instead of printing it out
890
+ * add examples to docblocks
891
+ """
892
+
893
+ if custom_tracer is None:
894
+ tracer = quantize_fx.QuantizationTracer([], [])
895
+ else:
896
+ tracer = custom_tracer
897
+ mt = torch.fx.GraphModule(model, tracer.trace(model))
898
+ # this is necessary to ensure logger FQNs get populated
899
+ mt._node_name_to_scope = tracer.node_name_to_scope # type: ignore[assignment]
900
+
901
+ # run example input propagation, we need this to call prepare_fx on
902
+ # individual subgraphs
903
+ output_prop = OutputProp(mt)
904
+ output_prop.propagate(*example_inputs)
905
+
906
+ # Find the set of subgraphs in the original graph which we need to
907
+ # consider.
908
+ modules = dict(mt.named_modules(remove_duplicate=False))
909
+ patterns = _get_pattern_to_quantize_handlers(backend_config)
910
+ root_node_getter_mapping = get_fusion_pattern_to_root_node_getter(backend_config)
911
+ standalone_module_names: List[str] = []
912
+ standalone_module_classes: List[Type] = []
913
+ custom_module_classes: List[Type] = []
914
+ matches = _find_matches(
915
+ mt.graph,
916
+ modules,
917
+ patterns,
918
+ root_node_getter_mapping,
919
+ standalone_module_names,
920
+ standalone_module_classes,
921
+ custom_module_classes,
922
+ )
923
+ subgraphs_dedup: Dict[str, List[Node]] = _get_dedup_subgraphs(matches)
924
+
925
+ # generate node to qconfig for each subgraph
926
+ # TODO(future PR): deduplicate repeating entries
927
+ list_of_node_name_to_qconfig: List[Dict[str, QConfigAny]] = []
928
+ for qconfig_mapping in qconfig_multi_mapping.qconfig_mappings_list:
929
+ node_name_to_qconfig = _generate_node_name_to_qconfig(
930
+ mt, modules, mt.graph, qconfig_mapping, tracer.node_name_to_scope
931
+ )
932
+ list_of_node_name_to_qconfig.append(node_name_to_qconfig)
933
+
934
+ # For each region in the model, do the following:
935
+ # For each qconfig for that region, do the following:
936
+ # 1. create a copy of the region wrapped in a module
937
+ # 2. pass original args, original kwargs, and expected output to module
938
+ # 3. add an output comparison logger and hook it up to compare
939
+ # actual output to expected output
940
+ # 4. run `prepare_fx` on the module
941
+ for subgraph_idx, (match_name, nodes_in_this_subgraph) in enumerate(
942
+ subgraphs_dedup.items()
943
+ ):
944
+ create_n_transformed_and_logged_copies_of_subgraph(
945
+ mt,
946
+ subgraph_idx,
947
+ match_name,
948
+ nodes_in_this_subgraph,
949
+ qconfig_multi_mapping.qconfig_mappings_list,
950
+ list_of_node_name_to_qconfig,
951
+ custom_prepare_fn,
952
+ custom_prepare_kwargs, # type: ignore[arg-type]
953
+ )
954
+
955
+ return mt
956
+
957
+
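A hedged end-to-end sketch of the n-shadows workflow using the helpers defined later in this file (convert_n_shadows_model, extract_results_n_shadows_model, print_comparisons_n_shadows_model); the toy model and single qconfig are illustrative:

    import torch
    from torch.ao.ns.fx.qconfig_multi_mapping import QConfigMultiMapping
    from torch.ao.quantization import QConfigMapping, default_qconfig
    from torch.ao.quantization.backend_config import get_native_backend_config

    m = torch.nn.Sequential(torch.nn.Conv2d(1, 1, 1)).eval()
    example_inputs = (torch.randn(1, 1, 4, 4),)
    qconfig_multi_mapping = QConfigMultiMapping.from_list_qconfig_mapping(
        [QConfigMapping().set_global(default_qconfig)])

    mp = prepare_n_shadows_model(
        m, example_inputs, qconfig_multi_mapping, get_native_backend_config())
    mp(*example_inputs)                 # calibrate the shadow copies
    mq = convert_n_shadows_model(mp)
    mq(*example_inputs)                 # capture logger results
    results = extract_results_n_shadows_model(mq)
    print_comparisons_n_shadows_model(results)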
958
+ # TODO(future PR): we should rethink the names of all the PNP APIs
959
+ def _prepare_n_shadows_add_loggers_model(
960
+ model: torch.nn.Module,
961
+ example_inputs: Any,
962
+ qconfig_mapping: QConfigMapping,
963
+ backend_config: BackendConfig,
964
+ ) -> torch.nn.Module:
965
+ r"""
966
+ Note: this API is not recommended for wide usage; it is only
967
+ provided for customers who need to migrate from the `add_loggers`
968
+ API.
969
+
970
+ This creates a model which provides logging for the following
971
+ problem: if we quantize `model` with `qconfig_mapping` and feed
972
+ the same input through both models, log the comparisons of
973
+ corresponding intermediate layers.
974
+
975
+ The problem is solved with a single model. Specifically, we
976
+ partition `model` into N subgraphs, create a copy of each relevant
977
+ subgraph, wrap it in a module, apply the quantization API to that
978
+ module, and hook up loggers to measure the comparisons.
979
+
980
+ Example starting graph:
981
+
982
+ x0 -> op0 -> x1 -> op1 -> x2
983
+
984
+ Example config: quantize op0 to int8, do nothing to op1.
985
+ The following graph will be created:
986
+
987
+ .. code::
988
+
989
+ x0_0 -> op0_0 -> x1_0 -> log -----> op1_0 -> x2_0 -> log
990
+ \ \ \ # noqa: W605
991
+ ---> op0_1 -> x1_1 ----> clog -> op1_0 -> x2_1 ----> clog
992
+
993
+ Where op0_0 is op0, op0_1 is op0 wrapped in a submodule and quantized
994
+ to int8, op1_0 is op1 (appearing in the graph twice), log is a logger,
995
+ and clog is a comparison logger.
996
+ """
997
+
998
+ tracer = quantize_fx.QuantizationTracer([], [])
999
+ mt = torch.fx.GraphModule(model, tracer.trace(model))
1000
+ # this is necessary to ensure logger FQNs get populated
1001
+ mt._node_name_to_scope = tracer.node_name_to_scope # type: ignore[assignment]
1002
+
1003
+ # run example input propagation, we need this to call prepare_fx on
1004
+ # individual subgraphs
1005
+ output_prop = OutputProp(mt)
1006
+ output_prop.propagate(*example_inputs)
1007
+
1008
+ # Find the set of subgraphs in the original graph which we need to
1009
+ # consider.
1010
+ modules = dict(mt.named_modules(remove_duplicate=False))
1011
+ patterns = _get_pattern_to_quantize_handlers(backend_config)
1012
+ root_node_getter_mapping = get_fusion_pattern_to_root_node_getter(backend_config)
1013
+ standalone_module_names: List[str] = []
1014
+ standalone_module_classes: List[Type] = []
1015
+ custom_module_classes: List[Type] = []
1016
+ matches = _find_matches(
1017
+ mt.graph,
1018
+ modules,
1019
+ patterns,
1020
+ root_node_getter_mapping,
1021
+ standalone_module_names,
1022
+ standalone_module_classes,
1023
+ custom_module_classes,
1024
+ )
1025
+ subgraphs_dedup: Dict[str, List[Node]] = _get_dedup_subgraphs(matches)
1026
+
1027
+ # generate node to qconfig for each subgraph
1028
+ node_name_to_qconfig = _generate_node_name_to_qconfig(
1029
+ mt, modules, mt.graph, qconfig_mapping, tracer.node_name_to_scope
1030
+ )
1031
+
1032
+ # Now, mutate the graph to be the add_loggers graph with propagation
1033
+ # error.
1034
+ create_add_loggers_graph(mt, subgraphs_dedup, qconfig_mapping, node_name_to_qconfig)
1035
+
1036
+ return mt
1037
+
1038
+
1039
+ # TODO(future PR): we should rethink the names of all the PNP APIs
1040
+ def _n_shadows_compare_weights(
1041
+ model: torch.nn.Module,
1042
+ example_inputs: Any,
1043
+ qconfig_mapping: QConfigMapping,
1044
+ backend_config: BackendConfig,
1045
+ ) -> NSResultsType:
1046
+ """
1047
+ Note: this API is not recommended for wide usage; it is only
1048
+ provided for customers who need to migrate from the `add_loggers`
1049
+ API.
1050
+ """
1051
+ qconfig_multi_mapping = QConfigMultiMapping.from_list_qconfig_mapping(
1052
+ [qconfig_mapping]
1053
+ )
1054
+ mp = prepare_n_shadows_model(
1055
+ model, example_inputs, qconfig_multi_mapping, backend_config
1056
+ )
1057
+ # passing inputs through the model is necessary to populate
1058
+ # observers which observe weights with real values
1059
+ mp(*example_inputs)
1060
+ mq = convert_n_shadows_model(mp)
1061
+ weight_comparison = extract_weight_comparison(mq)
1062
+ return weight_comparison
1063
+
1064
+
1065
+ # TODO(future PR): consider aligning API signature with other similar quantization
1066
+ # functions (enable_fake_quant, etc)
1067
+ def loggers_set_enabled(model: torch.nn.Module, enabled: bool) -> None:
1068
+ """
1069
+ Sets the `enabled` setting on a `model`'s loggers
1070
+ """
1071
+ for name, child in model.named_modules():
1072
+ if isinstance(child, OutputLogger):
1073
+ child.enabled = enabled
1074
+
1075
+
1076
+ # TODO(future PR): consider aligning API signature with other similar quantization
1077
+ # functions (enable_fake_quant, etc)
1078
+ def loggers_set_save_activations(
1079
+ model: torch.nn.Module,
1080
+ save_activations: bool,
1081
+ ) -> None:
1082
+ """
1083
+ Sets the `save_activations` setting on a `model`'s loggers
1084
+ """
1085
+ for name, child in model.named_modules():
1086
+ if isinstance(child, OutputLogger):
1087
+ child.save_activations = save_activations
1088
+
1089
+
1090
+ def convert_n_shadows_model(
1091
+ model: GraphModule,
1092
+ custom_convert_fn: Optional[Callable] = None,
1093
+ custom_convert_kwargs: Optional[Dict[str, Any]] = None,
1094
+ ) -> GraphModule:
1095
+ """
1096
+ Given a model from `prepare_n_shadows_model`, runs `convert_fx`
1097
+ on each shadow submodule.
1098
+ """
1099
+ for node in model.graph.nodes:
1100
+ # TODO(future PR): consider matching in a safer way than
1101
+ # node name string match
1102
+ if node.name.startswith(SHADOW_WRAPPER_NODE_NAME_PREFIX):
1103
+ orig_mod = getattr(model, node.name)
1104
+ if custom_convert_fn is None:
1105
+ converted_mod = torch.ao.quantization.quantize_fx.convert_fx(orig_mod)
1106
+ else:
1107
+ if custom_convert_kwargs is None:
1108
+ custom_convert_kwargs = {}
1109
+ converted_mod = custom_convert_fn(orig_mod, **custom_convert_kwargs)
1110
+ setattr(model, node.name, converted_mod)
1111
+
1112
+ return model
1113
+
1114
+
1115
+ def extract_results_n_shadows_model(model: torch.nn.Module) -> NSResultsType:
1116
+ """
1117
+ Extracts logger results from `model`.
1118
+ """
1119
+ results: NSResultsType = {}
1120
+ _extract_logger_info_one_model(model, results, OutputLogger)
1121
+ return results
1122
+
1123
+
1124
+ def print_comparisons_n_shadows_model(results: NSResultsType) -> None:
1125
+ """
1126
+ Prints a summary of extracted `results`.
1127
+ """
1128
+ results_grouped = group_results_by_subgraph(results)
1129
+ results_comparison = create_results_comparison(results_grouped)
1130
+ print_n_shadows_summary(results_comparison)
.venv/Lib/site-packages/torch/ao/ns/fx/__init__.py ADDED
File without changes
.venv/Lib/site-packages/torch/ao/ns/fx/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (186 Bytes).

.venv/Lib/site-packages/torch/ao/ns/fx/__pycache__/ns_types.cpython-39.pyc ADDED
Binary file (976 Bytes).

.venv/Lib/site-packages/torch/ao/ns/fx/__pycache__/utils.cpython-39.pyc ADDED
Binary file (12.7 kB).

.venv/Lib/site-packages/torch/ao/ns/fx/graph_matcher.py ADDED
@@ -0,0 +1,470 @@
1
+ # mypy: allow-untyped-defs
2
+ import collections
3
+ import enum
4
+ from typing import Any, Dict, List, Optional, Set, Tuple
5
+
6
+ import torch
7
+ from torch.ao.quantization import FakeQuantizeBase, ObserverBase
8
+ from torch.ao.quantization.utils import getattr_from_fqn
9
+ from torch.fx import GraphModule
10
+ from torch.fx.graph import Graph, Node
11
+
12
+ from .mappings import get_base_name_to_sets_of_related_ops, get_unmatchable_types_map
13
+ from .ns_types import NSNodeTargetType, NSSubgraph
14
+ from .pattern_utils import (
15
+ end_node_matches_reversed_fusion,
16
+ get_reversed_fusions,
17
+ get_type_a_related_to_b,
18
+ )
19
+
20
+
21
+ toq = torch.ops.quantized
22
+
23
+
24
+ def _get_output_nodes(g: Graph) -> List[Node]:
25
+ return [n for n in g.nodes if n.op == "output"]
26
+
27
+
28
+ class _NSGraphMatchableSubgraphsIterator:
29
+ """
30
+ Iterates through the graph of gm, starting with the output nodes
31
+ and continuing backwards.
32
+ 1. Returns matchable subgraphs, in order. A subgraph is defined by
33
+ (start_node, end_node).
34
+ 2. Skips over non-matchable subgraphs
35
+ """
36
+
37
+ def __init__(
38
+ self,
39
+ gm: GraphModule,
40
+ non_matchable_functions: Set[NSNodeTargetType],
41
+ non_matchable_modules: Set[NSNodeTargetType],
42
+ non_matchable_methods: Set[NSNodeTargetType],
43
+ ):
44
+ self.gm: GraphModule = gm
45
+ self.non_matchable_functions: Set[NSNodeTargetType] = non_matchable_functions
46
+ self.non_matchable_modules: Set[NSNodeTargetType] = non_matchable_modules
47
+ self.non_matchable_methods: Set[NSNodeTargetType] = non_matchable_methods
48
+ self.seen_nodes: Set[Node] = set()
49
+ self.stack: List[Node] = []
50
+ for start_node in _get_output_nodes(self.gm.graph):
51
+ self.stack.append(start_node)
52
+
53
+ def __iter__(self):
54
+ return self
55
+
56
+ def __next__(self) -> NSSubgraph:
57
+ """
58
+ Returns the next matchable subgraph.
59
+ """
60
+ while len(self.stack) > 0:
61
+ cur_end_node = self.stack.pop()
62
+ if cur_end_node in self.seen_nodes:
63
+ continue
64
+
65
+ # for subgraphs which are single nodes, start_node == end_node
66
+ # for subgraphs with more than one node, start node != end_node
67
+ cur_start_node = cur_end_node
68
+ # Subgraphs like linear-relu have the base node as the start node.
69
+ # Subgraphs like dequantize-linear-relu-to(torch.float16) have the
70
+ # base node as the second node.
71
+ # The cur_base_op_node var will move to the actual node during
72
+ # the fusion matching later in this code block.
73
+ cur_base_op_node = cur_end_node
74
+
75
+ # Check for potential fusions. For now, we are greedy
76
+ # and always skip all non-base nodes of a fusion. For example,
77
+ # if we match linear-relu backwards, we will always skip the
78
+ # relu node and attempt to match the linear node. This can
79
+ # be made configurable later if needed.
80
+ for _reverse_fusion_ops, base_op_idx in get_reversed_fusions():
81
+ is_match = end_node_matches_reversed_fusion(
82
+ cur_end_node, _reverse_fusion_ops, self.gm, self.seen_nodes
83
+ )
84
+ if is_match:
85
+ # navigate to the base node
86
+ for rev_fusion_idx in range(len(_reverse_fusion_ops) - 1):
87
+ self.seen_nodes.add(cur_start_node)
88
+ # for now, assume that there are no other nodes
89
+ # which need to be added to the stack
90
+ cur_start_node = cur_start_node.args[0] # type: ignore[assignment]
91
+ # if the base op index matches the current node, set it
92
+ rev_base_op_idx = len(_reverse_fusion_ops) - 2 - base_op_idx
93
+ if rev_fusion_idx == rev_base_op_idx:
94
+ cur_base_op_node = cur_start_node
95
+ break
96
+
97
+ self.seen_nodes.add(cur_start_node)
98
+ # add args of previous nodes to stack
99
+ for arg in cur_start_node.all_input_nodes:
100
+ self._recursively_add_node_arg_to_stack(arg)
101
+
102
+ # skip unmatchable nodes
103
+ # note: this check is done on the start_node, i.e.
104
+ # if we are matching linear-relu in reverse, this would do the matchable
105
+ # check on the linear
106
+ if not self._is_matchable(cur_base_op_node):
107
+ continue
108
+
109
+ # If an observer or a fake_quant was not matched as a part of
110
+ # a pattern of multiple nodes, ignore it. One case where this is
111
+ # relevant is an observer on a graph input, which was added because
112
+ # it is necessary for the next node.
113
+ if cur_end_node.op == "call_module" and cur_start_node is cur_end_node:
114
+ maybe_obs = getattr_from_fqn(self.gm, cur_end_node.target) # type: ignore[arg-type]
115
+ if isinstance(maybe_obs, (ObserverBase, FakeQuantizeBase)):
116
+ continue
117
+
118
+ return NSSubgraph(
119
+ start_node=cur_start_node,
120
+ end_node=cur_end_node,
121
+ base_op_node=cur_base_op_node,
122
+ )
123
+
124
+ raise StopIteration
125
+
126
+ def _recursively_add_node_arg_to_stack(self, arg: Any) -> None:
127
+ """
128
+ Adds all of the nodes in this arg to the stack, properly navigating
129
+ through list, dicts and tuples.
130
+ """
131
+ if isinstance(arg, Node):
132
+ self.stack.append(arg)
133
+ elif (
134
+ isinstance(arg, torch.fx.immutable_collections.immutable_list)
135
+ or type(arg) is tuple
136
+ ):
137
+ for inner_arg in arg:
138
+ self._recursively_add_node_arg_to_stack(inner_arg)
139
+ elif isinstance(arg, torch.fx.immutable_collections.immutable_dict):
140
+ for value in arg.values():
141
+ self._recursively_add_node_arg_to_stack(value)
142
+
143
+ def _is_matchable(self, node: Node) -> bool:
144
+ if node.op == "call_function":
145
+ return node.target not in self.non_matchable_functions
146
+ elif node.op == "call_module":
147
+ assert isinstance(node.target, str)
148
+ target_mod = getattr_from_fqn(self.gm, node.target)
149
+ return not any(
150
+ isinstance(target_mod, t) # type: ignore[arg-type]
151
+ for t in self.non_matchable_modules
152
+ )
153
+ elif node.op == "call_method":
154
+ return node.target not in self.non_matchable_methods
155
+ else:
156
+ return False
157
+
158
+
159
+ class GraphMatchingException(Exception):
160
+ """
161
+ Exception raised when two graphs cannot be matched.
162
+ """
163
+
164
+
165
+ class SubgraphTypeRelationship(enum.Enum):
166
+ # same type, known
167
+ # example: F.linear and F.linear, or nn.Conv2d and nn.Conv2d
168
+ EQUAL = enum.auto()
169
+ # same type, but the type is not known to Numerical Suite
170
+ # (user defined type, etc).
171
+ EQUAL_BUT_UKNOWN = enum.auto()
172
+ # known, same subgraph_relationship set, but not the same type
173
+ # example: F.linear and toq.linear
174
+ RELATED_BUT_NOT_EQUAL = enum.auto()
175
+ # not related
176
+ NOT_RELATED = enum.auto()
177
+
178
+
179
+ def _get_subgraph_relationship_type(
180
+ subgraph_a: NSSubgraph,
181
+ subgraph_b: NSSubgraph,
182
+ gm_a: GraphModule,
183
+ gm_b: GraphModule,
184
+ type_a_related_to_b: Set[Tuple[NSNodeTargetType, NSNodeTargetType]],
185
+ ) -> SubgraphTypeRelationship:
186
+ node_a = subgraph_a.base_op_node
187
+ node_b = subgraph_b.base_op_node
188
+
189
+ # TODO(next): make this code handle matching by what is before the base op
190
+ if node_a.op != node_b.op:
191
+ if not (
192
+ node_a.op in ("call_function", "call_method")
193
+ and node_b.op in ("call_function", "call_method")
194
+ ):
195
+ return SubgraphTypeRelationship.NOT_RELATED
196
+
197
+ if node_a.op in ("call_function", "call_method"):
198
+ key = (node_a.target, node_b.target)
199
+
200
+ if key not in type_a_related_to_b:
201
+ if node_a.target == node_b.target:
202
+ return SubgraphTypeRelationship.EQUAL_BUT_UKNOWN
203
+ else:
204
+ return SubgraphTypeRelationship.NOT_RELATED
205
+ # after this point, we are dealing with known types
206
+
207
+ if node_a.target == node_b.target:
208
+ node_a_has_prev = subgraph_a.base_op_node == subgraph_a.start_node
209
+ node_b_has_prev = subgraph_b.base_op_node == subgraph_b.start_node
210
+ if node_a_has_prev and (not node_b_has_prev):
211
+ return SubgraphTypeRelationship.RELATED_BUT_NOT_EQUAL
212
+ elif (not node_a_has_prev) and node_b_has_prev:
213
+ return SubgraphTypeRelationship.RELATED_BUT_NOT_EQUAL
214
+ elif (not node_a_has_prev) and (not node_b_has_prev):
215
+ return SubgraphTypeRelationship.EQUAL
216
+ else:
217
+ # TODO(future PR): check for matches start_op_node and base_op_node
218
+ return SubgraphTypeRelationship.EQUAL
219
+
220
+ if key in type_a_related_to_b:
221
+ return SubgraphTypeRelationship.RELATED_BUT_NOT_EQUAL
222
+ else:
223
+ return SubgraphTypeRelationship.NOT_RELATED
224
+ elif node_a.op == "call_module":
225
+ assert (
226
+ subgraph_a.base_op_node == subgraph_a.start_node
227
+ and subgraph_b.base_op_node == subgraph_b.start_node
228
+ ), "Matching call_module patterns where base_op_node != start_node is not supported yet"
229
+ # for call_module, we need to look up the modules to do the type check
230
+ assert isinstance(node_a.target, str)
231
+ mod_a = getattr_from_fqn(gm_a, node_a.target)
232
+ assert isinstance(node_b.target, str)
233
+ mod_b = getattr_from_fqn(gm_b, node_b.target)
234
+
235
+ key = (type(mod_a), type(mod_b))
236
+
237
+ if key not in type_a_related_to_b:
238
+ if type(mod_a) == type(mod_b):
239
+ return SubgraphTypeRelationship.EQUAL_BUT_UKNOWN
240
+ else:
241
+ return SubgraphTypeRelationship.NOT_RELATED
242
+ elif type(mod_a) == type(mod_b):
243
+ return SubgraphTypeRelationship.EQUAL
244
+ else:
245
+ return SubgraphTypeRelationship.RELATED_BUT_NOT_EQUAL
246
+
247
+ return SubgraphTypeRelationship.NOT_RELATED
248
+
249
+
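+ # Illustrative examples of the classification performed above (hypothetical node
+ # target pairs, following the examples given in SubgraphTypeRelationship):
+ #   nn.Conv2d vs nn.Conv2d  -> EQUAL
+ #   F.linear  vs toq.linear -> RELATED_BUT_NOT_EQUAL  (same related-ops set)
+ #   MyModule  vs MyModule   -> EQUAL_BUT_UKNOWN       (type unknown to Numeric Suite)
+ #   F.linear  vs nn.Conv2d  -> NOT_RELATED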
250
+ def _get_name_for_subgraph(
251
+ subgraph_a: NSSubgraph,
252
+ gm_a: GraphModule,
253
+ base_name_to_sets_of_related_ops: Dict[str, Set[NSNodeTargetType]],
254
+ existing_names: Set[str],
255
+ ) -> str:
256
+ """
257
+ Returns a unique name for a subgraph. This name is based on two things:
258
+ 1. the name of the set containing the underlying type of the base op in the
259
+ subgraph (i.e. 'torch.nn.functional.linear' if this is related to a linear op)
260
+ 2. the number of previous subgraphs with related underlying type of the base op
261
+
262
+ For example, in the graph
263
+
264
+ linear0 -> relu0 -> linear1 -> relu1
265
+
266
+ The subgraphs are (linear0, relu0) and (linear1, relu1). If we iterate
267
+ from the output node backwards, the name given to (linear1, relu1) will be
268
+ `base_op_torch.nn.functional.linear_0`, and the name given to (linear0, relu0)
269
+ will be `base_op_torch.nn.functional.linear_1`.
270
+
271
+ Why are we not just using the node name? Answer: because of two requirements:
272
+ A. fusions must be supported
273
+ B. some Numeric Suite APIs can be called without having all of the models in memory
274
+
275
+ For example, let's say we need to match nodes of
276
+
277
+ (1) ... -> linear0 -> relu0 -> ...
278
+
279
+ And
280
+
281
+ (2) ... -> linear_relu0 -> ...
282
+
283
+ Without being able to inspect them together. With the current naming scheme, if
284
+ we iterate through both of these graphs in the same order, and assuming the rest
285
+ of the graphs match, both of these subgraphs will get the same name without
286
+ (1) and (2) knowing anything about each other.
287
+ """
288
+ target_type = _get_node_target_type(subgraph_a.base_op_node, gm_a)
289
+ target_base_type = None
290
+ for base_name, sets_of_related_ops in base_name_to_sets_of_related_ops.items():
291
+ if target_type in sets_of_related_ops:
292
+ target_base_type = base_name
293
+ target_base_name = "base_op_" + str(target_base_type)
294
+ counter = 0
295
+ proposed_name = target_base_name + "_" + str(counter)
296
+ while proposed_name in existing_names:
297
+ counter += 1
298
+ proposed_name = target_base_name + "_" + str(counter)
299
+ existing_names.add(proposed_name)
300
+ return proposed_name
301
+
302
+
303
+ def _get_node_target_type(node: Node, gm: GraphModule) -> Optional[NSNodeTargetType]:
304
+ if node.op in ("call_function", "call_method"):
305
+ return node.target
306
+ elif node.op == "call_module":
307
+ assert isinstance(node.target, str)
308
+ mod = getattr_from_fqn(gm, node.target)
309
+ return type(mod)
310
+ return None
311
+
312
+
313
+ def get_matching_subgraph_pairs(
314
+ gm_a: GraphModule,
315
+ gm_b: GraphModule,
316
+ base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
317
+ unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
318
+ ) -> Dict[str, Tuple[NSSubgraph, NSSubgraph]]:
319
+ """
320
+ Matches matchable subgraphs of graph_a to graph_b.
321
+
322
+ For a node, "matchable" is defined as a node which is not an observer,
323
+ fake_quants, quant or dequant.
324
+
325
+ A subgraph can contain one or more nodes. A subgraph is matchable if
326
+ at least one node inside of it is matchable. Currently, all nodes in
327
+ a subgraph must be matchable (because we assume no observers will be
328
+ inserted in the middle of a fusion).
329
+
330
+ A subgraph is defined by (start_node, end_node). We assume that only
331
+ start_node and end_node are linked with the surrounding graph, all other
332
+ nodes in a subgraph are self-contained.
333
+
334
+ A pair of nodes is "related" if both nodes represent the same mathematical
335
+ operation across different quantization flavors. For example,
336
+ `F.linear` and `torch.ops.quantized.linear` are related, and
337
+ `F.linear` and `torch.nn.Conv` are not related.
338
+
339
+ For each matchable pair of nodes node_a and node_b, they will match
340
+ if node_a and node_b are related.
341
+
342
+ For graphs A and B, they will match iff:
343
+ 1. the number of matchable subgraphs in A and B is equivalent
344
+ 2. when iterating through the matchable subgraphs of A and B in the same order, each
345
+ corresponding pair of base nodes is related.
346
+
347
+ This enables us to find the corresponding subgraphs between
348
+ graphs of related models. For example, if we had two graphs such as:
349
+
350
+ graph_a: x0 -> conv_0 (type: nn.Conv2d) -> obs_0 -> x1
351
+ w -/
352
+ b -/
353
+
354
+ graph_b: x0 -> quant_0 -> qconv_0 (type: nnq.Conv2d) -> dequant_0 -> x1
355
+ packed_params_0 -/
356
+
357
+ This function will return the following result:
358
+ {
359
+ 'conv_0': ( # the name of the node in graph_b
360
+ (conv_0, conv_0), # (start_node_a, end_node_a)
361
+ (qconv_0, qconv_0), # (start_node_b, end_node_b)
362
+ ),
363
+ }
364
+
365
+ Or, if we have a fusion pattern,
366
+
367
+ graph_a: x0 -> linear_0 -> relu_0 -> obs_0 -> x1
368
+ w -/
369
+ b -/
370
+
371
+ graph_b: x0 -> quant_0 -> linear_relu_0 -> dequant_0 -> x1
372
+ packed_params_0 -/
373
+
374
+ This function will return the following result:
375
+ {
376
+ 'linear_relu_0': ( # the name of the node in graph_b
377
+ (linear_0, relu_0), # (start_node_a, end_node_a)
378
+ (linear_relu_0, linear_relu_0), # (start_node_b, end_node_b)
379
+ ),
380
+ }
381
+ """
382
+ if unmatchable_types_map is None:
383
+ unmatchable_types_map = get_unmatchable_types_map()
384
+ non_matchable_functions = unmatchable_types_map["funs_unmatchable"]
385
+ non_matchable_modules = unmatchable_types_map["mods_unmatchable"]
386
+ non_matchable_methods = unmatchable_types_map["meths_unmatchable"]
387
+
388
+ graph_a_iterator = _NSGraphMatchableSubgraphsIterator(
389
+ gm_a, non_matchable_functions, non_matchable_modules, non_matchable_methods
390
+ )
391
+ graph_b_iterator = _NSGraphMatchableSubgraphsIterator(
392
+ gm_b, non_matchable_functions, non_matchable_modules, non_matchable_methods
393
+ )
394
+ results = collections.OrderedDict()
395
+ if base_name_to_sets_of_related_ops is None:
396
+ base_name_to_sets_of_related_ops = get_base_name_to_sets_of_related_ops()
397
+ type_a_related_to_b = get_type_a_related_to_b(base_name_to_sets_of_related_ops)
398
+
399
+ existing_names_a: Set[str] = set()
400
+ existing_names_b: Set[str] = set()
401
+
402
+ while True:
403
+ # fetch the next subgraphs from a and b
404
+ cur_subgraph_a, cur_subgraph_b = None, None
405
+ try:
406
+ cur_subgraph_a = next(graph_a_iterator)
407
+ except StopIteration:
408
+ pass
409
+ try:
410
+ cur_subgraph_b = next(graph_b_iterator)
411
+ except StopIteration:
412
+ pass
413
+
414
+ # look up types of a and b for useful error messages
415
+ type_start_a, type_start_b = None, None
416
+ if cur_subgraph_a is not None:
417
+ type_start_a = _get_node_target_type(cur_subgraph_a.start_node, gm_a)
418
+ if cur_subgraph_b is not None:
419
+ type_start_b = _get_node_target_type(cur_subgraph_b.start_node, gm_b)
420
+
421
+ # check for results and determine what to do next
422
+ if cur_subgraph_a is not None and cur_subgraph_b is not None:
423
+ # both nodes were fetched, check for subgraph_relationship
424
+ # note: subgraph_relationship is checked on the start node, i.e.
425
+ # if a linear-relu pattern is checked, we would check for subgraph_relationship
426
+ # of the linear
427
+ subgraph_relationship = _get_subgraph_relationship_type(
428
+ cur_subgraph_a, cur_subgraph_b, gm_a, gm_b, type_a_related_to_b
429
+ )
430
+ if subgraph_relationship == SubgraphTypeRelationship.NOT_RELATED:
431
+ msg = f"""
432
+ The subgraphs
433
+ ({cur_subgraph_a}, {type_start_a}) and
434
+ ({cur_subgraph_b}, {type_start_b})
435
+ are not related. Please ensure that the two models you pass in have the same number
436
+ of subgraphs, and each pair of subgraphs is related to each other."""
437
+ raise GraphMatchingException(msg)
438
+ elif subgraph_relationship == SubgraphTypeRelationship.EQUAL_BUT_UKNOWN:
439
+ # skip matching but unknown types
440
+ continue
441
+ key_name_a = _get_name_for_subgraph(
442
+ cur_subgraph_a, gm_a, base_name_to_sets_of_related_ops, existing_names_a
443
+ )
444
+ key_name_b = _get_name_for_subgraph(
445
+ cur_subgraph_b, gm_b, base_name_to_sets_of_related_ops, existing_names_b
446
+ )
447
+ assert (
448
+ key_name_a == key_name_b
449
+ ), f"Subgraph names {key_name_a} and {key_name_b} do not match"
450
+ results[key_name_a] = (cur_subgraph_a, cur_subgraph_b)
451
+ continue
452
+ elif cur_subgraph_a is None and cur_subgraph_b is None:
453
+ # we reached the end of both graphs
454
+ break
455
+ else:
456
+ # only one node was fetched, no match possible, throw error
457
+ msg = f"""
458
+ Attempting to match
459
+ ({cur_subgraph_a}, {type_start_a}) and
460
+ ({cur_subgraph_b}, {type_start_b}),
461
+ one of which is empty. Please ensure that the two models you pass in have the same number
462
+ of subgraphs."""
463
+ raise GraphMatchingException(msg)
464
+
465
+ # The subgraph pairs are originally created by traversing the two graphs
466
+ # from the outputs to the inputs. Reverse the results to return the
467
+ # subgraphs in their order of execution.
468
+ results = collections.OrderedDict(reversed(list(results.items())))
469
+
470
+ return results
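As a usage sketch (assuming `gm_fp32` and `gm_int8` are `torch.fx.GraphModule` instances of the float and quantized variants of the same model; the names are illustrative):

    matched = get_matching_subgraph_pairs(gm_fp32, gm_int8)
    for subgraph_name, (subgraph_a, subgraph_b) in matched.items():
        # each value is a pair of NSSubgraph(start_node, end_node, base_op_node)
        print(subgraph_name, subgraph_a.base_op_node, subgraph_b.base_op_node)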
.venv/Lib/site-packages/torch/ao/ns/fx/graph_passes.py ADDED
@@ -0,0 +1,1131 @@
1
+ # mypy: allow-untyped-defs
2
+ from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
3
+
4
+ import torch
5
+ from torch.ao.ns.fx.mappings import get_node_type_to_io_type_map
6
+ from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix
7
+ from torch.ao.quantization.observer import _is_activation_post_process
8
+ from torch.fx import GraphModule, map_arg
9
+ from torch.fx.graph import Graph, Node
10
+
11
+ from .ns_types import NSNodeTargetType, NSSingleResultValuesType, NSSubgraph
12
+ from .utils import (
13
+ get_arg_indices_of_inputs_to_log,
14
+ get_node_first_input_and_output_type,
15
+ get_node_input_qparams,
16
+ get_normalized_nth_input,
17
+ get_number_of_non_param_args,
18
+ get_target_type_str,
19
+ getattr_from_fqn,
20
+ NodeInputOrOutputType,
21
+ op_type_supports_shadowing,
22
+ return_first_non_observer_node,
23
+ )
24
+
25
+
26
+ def _maybe_get_fqn(node: Node, gm: GraphModule) -> Optional[str]:
27
+ fqn = None
28
+ if hasattr(gm, "_node_name_to_scope"):
29
+ # fqn on observers is not present, because they do not
30
+ # exist when the fqns are created during tracing. If this is
31
+ # an observer, get the fqn of the node being observed.
32
+ node_to_use_for_fqn = node
33
+ if node.op == "call_module":
34
+ assert isinstance(node.target, str)
35
+ module = getattr_from_fqn(gm, node.target)
36
+ if _is_activation_post_process(module):
37
+ node_to_use_for_fqn = get_normalized_nth_input(node, gm, 0)
38
+ fqn = gm._node_name_to_scope[node_to_use_for_fqn.name][0] # type: ignore[index]
39
+ return fqn # type: ignore[return-value]
40
+
41
+
42
+ def _insert_logger_after_node(
43
+ node: Node,
44
+ gm: GraphModule,
45
+ logger_cls: Callable,
46
+ logger_node_name_suffix: str,
47
+ ref_node_name: str,
48
+ model_name: str,
49
+ ref_name: str,
50
+ ref_node_target_type: str,
51
+ results_type: str,
52
+ index_within_arg: int,
53
+ index_of_arg: int,
54
+ fqn: Optional[str],
55
+ ) -> Node:
56
+ """
57
+ Given a starting graph of
58
+
59
+ prev_node -> node -> next_node
60
+
61
+ This function creates a new logger_cls obj and adds it
62
+ after node, resulting in
63
+
64
+ prev_node -> node -> logger_obj -> next_node
65
+ """
66
+ # create new name
67
+ logger_node_name = get_new_attr_name_with_prefix(
68
+ node.name + logger_node_name_suffix
69
+ )(gm)
70
+ target_type = get_target_type_str(node, gm)
71
+ # create the logger object
72
+ logger_obj = logger_cls(
73
+ ref_node_name,
74
+ node.name,
75
+ model_name,
76
+ ref_name,
77
+ target_type,
78
+ ref_node_target_type,
79
+ results_type,
80
+ index_within_arg,
81
+ index_of_arg,
82
+ fqn,
83
+ )
84
+ # attach the logger object to the parent module
85
+ setattr(gm, logger_node_name, logger_obj)
86
+ logger_node = node.graph.create_node("call_module", logger_node_name, (node,), {})
87
+ return logger_node
88
+
89
+
90
+ def add_loggers_to_model(
91
+ gm: GraphModule,
92
+ node_to_instrument_inputs_to_ref_node_name: Dict[Node, Tuple[str, str]],
93
+ node_to_instrument_outputs_to_ref_node_name: Dict[Node, Tuple[str, str]],
94
+ logger_cls: Callable,
95
+ model_name: str,
96
+ ) -> GraphModule:
97
+ """
98
+ Takes the graph of gm, adds loggers to the output
99
+ of each node in nodes_to_instrument. Returns a GraphModule with the new
100
+ graph.
101
+ """
102
+
103
+ new_graph = Graph()
104
+ env: Dict[str, Any] = {}
105
+ modules = dict(gm.named_modules())
106
+
107
+ def load_arg(a):
108
+ return map_arg(a, lambda node: env[node.name])
109
+
110
+ for node in gm.graph.nodes:
111
+ if node.op == "output":
112
+ new_graph.output(map_arg(get_normalized_nth_input(node, gm, 0), load_arg))
113
+ continue
114
+
115
+ if (node in node_to_instrument_inputs_to_ref_node_name) or (
116
+ node in node_to_instrument_outputs_to_ref_node_name
117
+ ):
118
+ fqn = _maybe_get_fqn(node, gm)
119
+
120
+ if node in node_to_instrument_inputs_to_ref_node_name:
121
+ ref_name, ref_node_type = node_to_instrument_inputs_to_ref_node_name[
122
+ node
123
+ ]
124
+ # Ops such add and mul are special because either
125
+ # one or two of the first two arguments can be tensors,
126
+ # and if one argument is a tensor it can be first or
127
+ # second (x + 1 versus 1 + x).
128
+ arg_indices_to_log = get_arg_indices_of_inputs_to_log(node)
129
+ for node_arg_idx in arg_indices_to_log:
130
+ node_arg = get_normalized_nth_input(node, gm, node_arg_idx)
131
+ if type(node_arg) == Node:
132
+ # create a single input logger
133
+ prev_node = env[node_arg.name]
134
+ env[node_arg.name] = _insert_logger_after_node(
135
+ prev_node,
136
+ gm,
137
+ logger_cls,
138
+ "_ns_logger_",
139
+ node.name,
140
+ model_name,
141
+ ref_name,
142
+ ref_node_type,
143
+ NSSingleResultValuesType.NODE_INPUT.value,
144
+ index_within_arg=0,
145
+ index_of_arg=node_arg_idx,
146
+ fqn=fqn,
147
+ )
148
+ elif (
149
+ type(node_arg) == torch.fx.immutable_collections.immutable_list
150
+ ):
151
+ # create N input loggers, one for each node
152
+ for arg_idx, arg in enumerate(node_arg): # type: ignore[var-annotated, arg-type]
153
+ prev_node = env[arg.name]
154
+ env[prev_node.name] = _insert_logger_after_node(
155
+ prev_node,
156
+ gm,
157
+ logger_cls,
158
+ "_ns_logger_",
159
+ node.name,
160
+ model_name,
161
+ ref_name,
162
+ ref_node_type,
163
+ NSSingleResultValuesType.NODE_INPUT.value,
164
+ index_within_arg=arg_idx,
165
+ index_of_arg=node_arg_idx,
166
+ fqn=fqn,
167
+ )
168
+ else:
169
+ pass
170
+
171
+ # ensure env is populated with base node
172
+ # Note: runs for both inputs and outputs
173
+ env[node.name] = new_graph.node_copy(node, load_arg)
174
+
175
+ if node in node_to_instrument_outputs_to_ref_node_name:
176
+ ref_name, ref_node_type = node_to_instrument_outputs_to_ref_node_name[
177
+ node
178
+ ]
179
+ # add the logger after the base node
180
+ env[node.name] = _insert_logger_after_node(
181
+ env[node.name],
182
+ gm,
183
+ logger_cls,
184
+ "_ns_logger_",
185
+ node.name,
186
+ model_name,
187
+ ref_name,
188
+ ref_node_type,
189
+ NSSingleResultValuesType.NODE_OUTPUT.value,
190
+ index_within_arg=0,
191
+ index_of_arg=0,
192
+ fqn=fqn,
193
+ )
194
+
195
+ else:
196
+ env[node.name] = new_graph.node_copy(node, load_arg)
197
+
198
+ new_gm = GraphModule(gm, new_graph)
199
+ return new_gm
200
+
201
+
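+ # Usage sketch for add_loggers_to_model above (illustrative only; `gm` is assumed to
+ # be a traced GraphModule, and an OutputLogger-style class is passed as logger_cls).
+ # The instrumentation dicts map Node -> (ref_name, ref_node_target_type string):
+ #
+ #   outputs_to_ref = {
+ #       n: (n.name, get_target_type_str(n, gm))
+ #       for n in gm.graph.nodes
+ #       if n.op == "call_function"
+ #   }
+ #   gm_logged = add_loggers_to_model(gm, {}, outputs_to_ref, OutputLogger, "model_a")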
202
+ def _insert_quantize_per_tensor_node(
203
+ prev_node_c: Node,
204
+ node_a: Node,
205
+ gm_b: GraphModule,
206
+ graph_c: Graph,
207
+ scale: Union[torch.Tensor, float],
208
+ zero_point: Union[torch.Tensor, int],
209
+ dtype_cast_name: str,
210
+ ) -> Node:
211
+ # copy scale
212
+ scale_node_name = get_new_attr_name_with_prefix(node_a.name + "_input_scale_")(gm_b)
213
+ setattr(gm_b, scale_node_name, scale)
214
+ scale_node = graph_c.create_node(
215
+ "get_attr", scale_node_name, (), {}, scale_node_name
216
+ )
217
+ # copy zero_point
218
+ zero_point_node_name = get_new_attr_name_with_prefix(
219
+ node_a.name + "_input_zero_point_"
220
+ )(gm_b)
221
+ setattr(gm_b, zero_point_node_name, zero_point)
222
+ zero_point_node = graph_c.create_node(
223
+ "get_attr", zero_point_node_name, (), {}, zero_point_node_name
224
+ )
225
+ # create the quantize_per_tensor call
226
+ return graph_c.create_node(
227
+ "call_function",
228
+ torch.quantize_per_tensor,
229
+ (prev_node_c, scale_node, zero_point_node, torch.quint8),
230
+ {},
231
+ dtype_cast_name,
232
+ )
233
+
234
+
235
+ def _insert_dtype_cast_after_node(
236
+ node_a: Node,
237
+ node_c: Node,
238
+ prev_node_c: Union[Node, List[Node]],
239
+ gm_a: GraphModule,
240
+ gm_b: GraphModule,
241
+ graph_c: Graph,
242
+ node_name_prefix: str,
243
+ logger_cls: Callable,
244
+ node_type_to_io_type_map: Dict[str, Set[NSNodeTargetType]],
245
+ ) -> Union[Node, List[Node]]:
246
+ """
247
+ Given a starting graph C (derived from graph B) of
248
+
249
+ ... -> prev_node_c -> node_c -> ...
250
+
251
+ And a corresponding related node_a, inserts the correct dtype
252
+ cast node after prev_node_c to cast into the dtype expected
253
+ by node_a, resulting in:
254
+
255
+ dtype_cast
256
+ /
257
+ ... -> prev_node_c -> node_c -> ...
258
+
259
+ For example, if node_c is an int8 op and node_a is an fp32 op, this function
260
+ will insert a dequant.
261
+ """
262
+ dtype_cast_op = None
263
+ dtype_cast_mod_cls = None
264
+ dtype_cast_method = None
265
+ dtype_cast_method_dtype = None
266
+ dtype_cast_scale = None
267
+ dtype_cast_zero_point = None
268
+ node_input_type_a, _node_output_type_a = get_node_first_input_and_output_type(
269
+ node_a, gm_a, logger_cls, node_type_to_io_type_map
270
+ )
271
+ node_input_type_c, _node_output_type_c = get_node_first_input_and_output_type(
272
+ node_c, gm_b, logger_cls, node_type_to_io_type_map
273
+ )
274
+
275
+ if (
276
+ (
277
+ node_input_type_a == NodeInputOrOutputType.FP32
278
+ and node_input_type_c == NodeInputOrOutputType.INT8
279
+ )
280
+ or (
281
+ node_input_type_a == NodeInputOrOutputType.FP32
282
+ and node_input_type_c == NodeInputOrOutputType.FP16
283
+ )
284
+ or
285
+ # TODO(future PR): determine the actual dtype of node_c,
286
+ # the current code only works because dequantize works with
287
+ # multiple input dtypes.
288
+ (
289
+ node_input_type_a == NodeInputOrOutputType.FP32
290
+ and node_input_type_c == NodeInputOrOutputType.FP32_OR_INT8
291
+ )
292
+ ):
293
+ dtype_cast_op = torch.dequantize
294
+ elif (
295
+ node_input_type_a == node_input_type_c
296
+ and node_input_type_a != NodeInputOrOutputType.UNKNOWN
297
+ ):
298
+ dtype_cast_mod_cls = torch.nn.Identity
299
+ elif (
300
+ node_input_type_a == NodeInputOrOutputType.INT8
301
+ and node_input_type_c == NodeInputOrOutputType.FP32
302
+ ):
303
+ # int8 shadows fp32, the dtype cast needs to quantize to int8
304
+ # with the right qparams.
305
+ node_a_input_qparams = get_node_input_qparams(
306
+ node_a, gm_a, node_type_to_io_type_map
307
+ )
308
+ if node_a_input_qparams is not None:
309
+ dtype_cast_op = torch.quantize_per_tensor # type: ignore[assignment]
310
+ dtype_cast_scale, dtype_cast_zero_point = node_a_input_qparams
311
+ elif (
312
+ node_input_type_a == NodeInputOrOutputType.FP16
313
+ and node_input_type_c == NodeInputOrOutputType.FP32
314
+ ):
315
+ dtype_cast_method = "to"
316
+ dtype_cast_method_dtype = torch.float16
317
+ else:
318
+ raise AssertionError(
319
+ f"dtype cast from {node_input_type_c} {node_c.format_node()} to "
320
+ + f"{node_input_type_a} {node_a.format_node()} needs to be implemented"
321
+ )
322
+
323
+ if isinstance(prev_node_c, Node):
324
+ new_dtype_cast_name = get_new_attr_name_with_prefix(node_name_prefix)(gm_b)
325
+ if dtype_cast_op:
326
+ if dtype_cast_scale is not None and dtype_cast_zero_point is not None:
327
+ return _insert_quantize_per_tensor_node(
328
+ prev_node_c,
329
+ node_a,
330
+ gm_b,
331
+ graph_c,
332
+ dtype_cast_scale,
333
+ dtype_cast_zero_point,
334
+ new_dtype_cast_name,
335
+ )
336
+ else:
337
+ return graph_c.create_node(
338
+ "call_function",
339
+ dtype_cast_op,
340
+ (prev_node_c,),
341
+ {},
342
+ new_dtype_cast_name,
343
+ )
344
+ elif dtype_cast_method:
345
+ return graph_c.create_node(
346
+ "call_method",
347
+ dtype_cast_method,
348
+ (prev_node_c, dtype_cast_method_dtype),
349
+ {},
350
+ new_dtype_cast_name,
351
+ )
352
+ else:
353
+ assert dtype_cast_mod_cls
354
+ dtype_cast_mod = dtype_cast_mod_cls()
355
+ setattr(gm_b, new_dtype_cast_name, dtype_cast_mod)
356
+ return graph_c.create_node(
357
+ "call_module",
358
+ new_dtype_cast_name,
359
+ (prev_node_c,),
360
+ {},
361
+ new_dtype_cast_name,
362
+ )
363
+ elif isinstance(prev_node_c, list):
364
+ results = []
365
+ for prev_node_c_inner in prev_node_c:
366
+ new_dtype_cast_name = get_new_attr_name_with_prefix(node_name_prefix)(gm_b)
367
+ if dtype_cast_op:
368
+ # TODO(future PR): add handling for quantize_per_tensor
369
+ new_dtype_cast_node = graph_c.create_node(
370
+ "call_function",
371
+ dtype_cast_op,
372
+ (prev_node_c_inner,),
373
+ {},
374
+ new_dtype_cast_name,
375
+ )
376
+ results.append(new_dtype_cast_node)
377
+ else:
378
+ assert dtype_cast_mod_cls
379
+ dtype_cast_mod = dtype_cast_mod_cls()
380
+ setattr(gm_b, new_dtype_cast_name, dtype_cast_mod)
381
+ new_dtype_cast_node = graph_c.create_node(
382
+ "call_module",
383
+ new_dtype_cast_name,
384
+ (prev_node_c_inner,),
385
+ {},
386
+ new_dtype_cast_name,
387
+ )
388
+ results.append(new_dtype_cast_node)
389
+ return results
390
+ else:
391
+ raise AssertionError(f"type f{type(prev_node_c)} is not handled")
392
+
393
+
394
+ # TODO(future PR): look into using copy_node API instead
395
+ def _copy_node_from_a_to_c(
396
+ node_a: Node,
397
+ gm_a: GraphModule,
398
+ gm_b: GraphModule,
399
+ graph_c: Graph,
400
+ ) -> Node:
401
+ """
402
+ Simple copy of node_a to graph_c.
403
+ """
404
+ if node_a.op == "get_attr":
405
+ node_a_copy_name = get_new_attr_name_with_prefix(node_a.name + "_shadow_copy_")(
406
+ gm_b
407
+ )
408
+ node_a_obj = getattr_from_fqn(gm_a, node_a.target) # type: ignore[arg-type]
409
+ if torch.is_tensor(node_a_obj):
410
+ node_a_obj = node_a_obj.detach()
411
+ setattr(gm_b, node_a_copy_name, node_a_obj)
412
+ node_a_copy = graph_c.create_node(
413
+ node_a.op, node_a_copy_name, (), {}, node_a_copy_name
414
+ )
415
+ return node_a_copy
416
+ elif node_a.op == "call_method":
417
+ assert node_a.target in (
418
+ "dequantize",
419
+ "to",
420
+ ), f"target {node_a.target} is not implemented"
421
+ if node_a.target == "dequantize":
422
+ arg_copy = _copy_node_from_a_to_c(
423
+ get_normalized_nth_input(node_a, gm_a, 0), gm_a, gm_b, graph_c
424
+ ) # type: ignore[arg-type]
425
+ node_a_copy_name = get_new_attr_name_with_prefix(
426
+ node_a.name + "_shadow_copy_"
427
+ )(gm_b)
428
+ node_a_copy = graph_c.create_node(
429
+ node_a.op, node_a.target, (arg_copy,), {}, node_a_copy_name
430
+ )
431
+ return node_a_copy
432
+ else: # to
433
+ arg_copy = _copy_node_from_a_to_c(
434
+ get_normalized_nth_input(node_a, gm_a, 0), gm_a, gm_b, graph_c
435
+ ) # type: ignore[arg-type]
436
+ node_a_copy_name = get_new_attr_name_with_prefix(
437
+ node_a.name + "_shadow_copy_"
438
+ )(gm_b)
439
+ node_a_copy = graph_c.create_node(
440
+ node_a.op,
441
+ node_a.target,
442
+ (arg_copy, get_normalized_nth_input(node_a, gm_a, 1)),
443
+ {},
444
+ node_a_copy_name,
445
+ )
446
+ return node_a_copy
447
+
448
+ else:
449
+ raise AssertionError(
450
+ f"handling of node {node_a.format_node()} with op {node_a.op} is not implemented"
451
+ )
452
+
453
+
454
+ def _can_insert_copy_of_subgraph_a(
455
+ subgraph_a: NSSubgraph,
456
+ gm_a: GraphModule,
457
+ num_non_param_args_node_a: int,
458
+ ) -> bool:
459
+ """
460
+ This function returns `False` if the input subgraph cannot be copied by
461
+ `_insert_copy_of_subgraph_a_after_input_node_c`. This usually means
462
+ that there is a corner case logic for which copy is not yet implemented.
463
+ """
464
+ # populate the list of nodes we need to check
465
+ nodes = []
466
+ cur_node = subgraph_a.end_node
467
+ while cur_node != subgraph_a.start_node:
468
+ nodes.append(cur_node)
469
+ cur_node = get_normalized_nth_input(cur_node, gm_a, 0) # type: ignore[assignment]
470
+ nodes.append(cur_node)
471
+ nodes.reverse()
472
+
473
+ def _can_insert(node_a_arg, gm_a):
474
+ if isinstance(node_a_arg, Node):
475
+ arg_a = return_first_non_observer_node(node_a_arg, gm_a)
476
+ if arg_a.op == "call_method":
477
+ return arg_a.target in ("dequantize", "to")
478
+ elif arg_a.op == "get_attr":
479
+ return True
480
+ else:
481
+ return False
482
+ elif isinstance(node_a_arg, (list, tuple)):
483
+ for el in node_a_arg:
484
+ if not isinstance(el, Node):
485
+ return False
486
+ return True
487
+
488
+ # For each node, check if we handle the copy behavior. This follows the
489
+ # logic in `_insert_copy_of_subgraph_a_after_input_node_c`.
490
+ for node_a in nodes:
491
+ local_num_non_param_args_node_a = (
492
+ num_non_param_args_node_a if node_a is nodes[0] else 1
493
+ )
494
+
495
+ norm_args_kwargs = node_a.normalized_arguments(
496
+ gm_a, normalize_to_only_use_kwargs=True
497
+ )
498
+ if norm_args_kwargs is not None:
499
+ norm_args, norm_kwargs = norm_args_kwargs
500
+ else:
501
+ norm_args, norm_kwargs = node_a.args, node_a.kwargs
502
+
503
+ cur_idx = 0
504
+
505
+ while cur_idx < len(norm_args):
506
+ if cur_idx == 0:
507
+ pass
508
+ elif cur_idx == 1 and local_num_non_param_args_node_a == 2:
509
+ pass
510
+ else:
511
+ if not _can_insert(norm_args[cur_idx], gm_a):
512
+ return False
513
+ cur_idx += 1
514
+
515
+ for kwarg_val in norm_kwargs.values():
516
+ # stitch the inputs from base graph
517
+ if cur_idx == 0:
518
+ pass
519
+ elif cur_idx == 1 and local_num_non_param_args_node_a == 2:
520
+ pass
521
+ else:
522
+ if not _can_insert(kwarg_val, gm_a):
523
+ return False
524
+ cur_idx += 1
525
+
526
+ return True
527
+
528
+
529
+ def _insert_copy_of_subgraph_a_after_input_node_c(
530
+ input_node_c: Union[Node, List[Node]],
531
+ input_node_c_2: Optional[Union[Node, List[Node]]],
532
+ subgraph_a: NSSubgraph,
533
+ gm_a: GraphModule,
534
+ gm_b: GraphModule,
535
+ node_name_prefix: str,
536
+ ) -> Node:
537
+ """
538
+ TODO(before land): real docblock
539
+ """
540
+ if isinstance(input_node_c, Node):
541
+ graph_c = input_node_c.graph
542
+ else:
543
+ assert isinstance(input_node_c, list)
544
+ graph_c = input_node_c[0].graph
545
+
546
+ # create a sequential list of the subgraphs' nodes from start to end,
547
+ # because we need to add the nodes to graph C in non-reverse order
548
+ nodes_of_a = [subgraph_a.end_node]
549
+ cur_node = subgraph_a.end_node
550
+ while cur_node != subgraph_a.start_node:
551
+ cur_node = get_normalized_nth_input(cur_node, gm_a, 0) # type: ignore[assignment]
552
+ nodes_of_a.insert(0, cur_node)
553
+
554
+ # go through nodes of a in order, and insert them into the graph of c
555
+ # sequentially
556
+ cur_node_a = nodes_of_a[0]
557
+ cur_node_c = _insert_copy_of_node_a_after_input_node_c(
558
+ input_node_c, input_node_c_2, cur_node_a, gm_a, gm_b, node_name_prefix
559
+ )
560
+ for cur_idx_a in range(1, len(nodes_of_a)):
561
+ cur_node_a = nodes_of_a[cur_idx_a]
562
+ prev_node_c = cur_node_c # previous added node is the input to next node
563
+ cur_node_c = _insert_copy_of_node_a_after_input_node_c(
564
+ prev_node_c,
565
+ # TODO(future PR): enable multiple inputs for nodes which are not at start of subgraph
566
+ None,
567
+ cur_node_a,
568
+ gm_a,
569
+ gm_b,
570
+ node_name_prefix,
571
+ )
572
+ # return the last inserted node
573
+ return cur_node_c
574
+
575
+
576
+ def _insert_copy_of_node_a_after_input_node_c(
577
+ input_node_c: Union[Node, List[Node]],
578
+ input_node_c_2: Optional[Union[Node, List[Node]]],
579
+ node_a: Node,
580
+ gm_a: GraphModule,
581
+ gm_b: GraphModule,
582
+ node_name_prefix: str,
583
+ ) -> Node:
584
+ """
585
+ Assume that node_a from graph_a has
586
+ args (input, (input2)?, arg1, ...), and
587
+ kwargs {kw0: kwarg0, ...}
588
+
589
+ Note: input2 is optional. If it equals to None, we assume that the op
590
+ has a single non-param input. If it is specified, we assume that the op
591
+ has two non-param inputs.
592
+
593
+ Copies the underlying values of arg1..argn and kwarg0..kwargn into gm_b,
594
+ and creates the corresponding nodes in graph_c. Note: observers are ignored,
595
+ so if an arg is an observer we navigate up until we find a non-observer parent.
596
+
597
+ If node_a is a call_module, points the module pointed to by node_a to gm_b.
598
+
599
+ Creates the copy of node_a in graph_c, with input as the first arg,
600
+ and all other args and kwargs pointing to the copies of the objects
601
+ in gm_b created above.
602
+
603
+ An example in pictures:
604
+
605
+ graph A:
606
+ ========
607
+
608
+ input -------------> node_a
609
+ / / /
610
+ (input_2)?----------/ / /
611
+ / /
612
+ weight -> weight_obs /
613
+ /
614
+ bias ----------------
615
+
616
+ graph C (derived from B):
617
+ =========================
618
+
619
+ input_node_c --> node_a_copy
620
+ / / /
621
+ (input_node_c_2)? / /
622
+ / /
623
+ weight_copy ----/ /
624
+ /
625
+ bias_copy ------/
626
+ """
627
+ if isinstance(input_node_c, Node):
628
+ graph_c = input_node_c.graph
629
+ else:
630
+ assert isinstance(input_node_c, list)
631
+ graph_c = input_node_c[0].graph
632
+
633
+ norm_args_kwargs = node_a.normalized_arguments(
634
+ gm_a, normalize_to_only_use_kwargs=True
635
+ )
636
+ if norm_args_kwargs is not None:
637
+ norm_args, norm_kwargs = norm_args_kwargs
638
+ else:
639
+ norm_args, norm_kwargs = node_a.args, node_a.kwargs
640
+
641
+ new_args = []
642
+ new_kwargs = {}
643
+
644
+ def _copy_arg(arg):
645
+ # copy the other inputs from the other graph
646
+ if isinstance(arg, Node):
647
+ arg = return_first_non_observer_node(arg, gm_a)
648
+ arg = _copy_node_from_a_to_c(arg, gm_a, gm_b, graph_c)
649
+ return arg
650
+ elif isinstance(arg, (int, float, torch.dtype)):
651
+ return arg
652
+ elif isinstance(kwarg_val, (list, tuple)):
653
+ for el in kwarg_val:
654
+ assert not isinstance(
655
+ el, Node
656
+ ), "handling of Node inside list is not implemented"
657
+ return arg
658
+ else:
659
+ raise AssertionError(
660
+ f"handling for kwarg of type {type(kwarg_val)} is not implemented"
661
+ )
662
+
663
+ cur_idx = 0
664
+
665
+ while cur_idx < len(norm_args):
666
+ if cur_idx == 0:
667
+ new_arg = input_node_c
668
+ elif cur_idx == 1 and input_node_c_2 is not None:
669
+ new_arg = input_node_c_2
670
+ else:
671
+ new_arg = _copy_arg(norm_args[cur_idx])
672
+ new_args.append(new_arg)
673
+ cur_idx += 1
674
+
675
+ for kwarg_name, kwarg_val in norm_kwargs.items():
676
+ # stitch the inputs from base graph
677
+ if cur_idx == 0:
678
+ new_kwargs[kwarg_name] = input_node_c
679
+ elif cur_idx == 1 and input_node_c_2 is not None:
680
+ new_kwargs[kwarg_name] = input_node_c_2
681
+ else:
682
+ new_kwargs[kwarg_name] = _copy_arg(kwarg_val)
683
+ cur_idx += 1
684
+
685
+ new_args = tuple(new_args) # type: ignore[assignment]
686
+
687
+ node_a_shadows_c_name = get_new_attr_name_with_prefix(node_name_prefix)(gm_b)
688
+
689
+ if node_a.op == "call_module":
690
+ # if target is a module, we point to the module from gm_b
691
+ new_mod_copy_name = get_new_attr_name_with_prefix(node_name_prefix)(gm_b)
692
+ # fetch the corresponding module from gm_a
693
+ assert isinstance(node_a.target, str)
694
+ mod_a = getattr_from_fqn(gm_a, node_a.target)
695
+ setattr(gm_b, new_mod_copy_name, mod_a)
696
+ node_a_shadows_c = graph_c.create_node(
697
+ node_a.op, new_mod_copy_name, new_args, new_kwargs, node_a_shadows_c_name # type: ignore[arg-type]
698
+ )
699
+ return node_a_shadows_c
700
+ else:
701
+ assert node_a.op in ("call_function", "call_method")
702
+ node_a_shadows_c = graph_c.create_node(
703
+ node_a.op, node_a.target, new_args, new_kwargs, node_a_shadows_c_name # type: ignore[arg-type]
704
+ )
705
+ return node_a_shadows_c
706
+
707
+
708
+ def create_a_shadows_b(
709
+ name_a: str,
710
+ gm_a: GraphModule,
711
+ name_b: str,
712
+ gm_b: GraphModule,
713
+ matched_subgraph_pairs: Dict[str, Tuple[NSSubgraph, NSSubgraph]],
714
+ logger_cls: Callable,
715
+ should_log_inputs: bool,
716
+ node_type_to_io_type_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
717
+ ) -> GraphModule:
718
+ """
719
+ Creates a new GraphModule consisting of the graph of C, with the meaningful
720
+ nodes of A shadowing the corresponding nodes of B. For example,
721
+
722
+ Graph A:
723
+ a0 -> op0_fp32 -> a1 -> op1_fp32 -> a2
724
+
725
+ Graph B:
726
+ b0 -> op0_int8 -> b1 -> op1_int8 -> b2
727
+
728
+ matched_node_pairs: {'op0': (op0_fp32, op0_int8), 'op1': (op1_fp32, op1_int8)}
729
+
730
+ Graph C (A shadows B):
731
+
732
+ / dequant0 -> op0_fp32 -> logger_a_0 / dequant_1 -> op1_fp32 -> logger_a_1
733
+ / /
734
+ b0 -------------> op0_int8 -> logger_b_0 --------------> op1_int8 -> logger_b_1
735
+
736
+ In a nutshell, this function does the following for each node pair:
737
+ * copies the necessary attributes and modules from gm_a to gm_b,
738
+ keeping names unique
739
+ * adds a dtype cast op (dequant, quant, etc)
740
+ * adds a copy of node_a in gm_b's graph
741
+ * adds loggers to the outputs of node_a and node_b
742
+ """
743
+
744
+ if node_type_to_io_type_map is None:
745
+ node_type_to_io_type_map = get_node_type_to_io_type_map()
746
+
747
+ # graph_c is the graph created from copying the nodes of graph_b and inserting
748
+ # the shadows with the nodes copied from graph_a
749
+ graph_c = Graph()
750
+ env_c: Dict[str, Any] = {}
751
+ modules = dict(gm_b.named_modules())
752
+
753
+ def load_arg(a):
754
+ return map_arg(a, lambda node: env_c[node.name])
755
+
756
+ start_node_b_to_matched_subgraph_a_and_name = {}
757
+ end_node_b_to_matched_subgraph_a_and_name = {}
758
+ for match_name, match in matched_subgraph_pairs.items():
759
+ subgraph_a, subgraph_b = match
760
+ ref_node_type_a = get_target_type_str(subgraph_a.base_op_node, gm_a)
761
+ ref_node_type_b = get_target_type_str(subgraph_b.base_op_node, gm_b)
762
+ start_node_b_to_matched_subgraph_a_and_name[subgraph_b.start_node] = (
763
+ subgraph_a,
764
+ match_name,
765
+ ref_node_type_a,
766
+ ref_node_type_b,
767
+ )
768
+ end_node_b_to_matched_subgraph_a_and_name[subgraph_b.end_node] = (
769
+ subgraph_a,
770
+ match_name,
771
+ ref_node_type_a,
772
+ ref_node_type_b,
773
+ )
774
+
775
+ for node_b in gm_b.graph.nodes:
776
+ if node_b.op == "output":
777
+ graph_c.output(map_arg(node_b.args[0], load_arg))
778
+ continue
779
+
780
+ # calculate the flags to determine what to do with this node
781
+ node_b_is_start_node = node_b in start_node_b_to_matched_subgraph_a_and_name
782
+ node_b_is_end_node = node_b in end_node_b_to_matched_subgraph_a_and_name
783
+
784
+ if node_b_is_start_node or node_b_is_end_node:
785
+ if node_b_is_start_node:
786
+ (
787
+ subgraph_a,
788
+ ref_name,
789
+ ref_node_type_a,
790
+ ref_node_type_b,
791
+ ) = start_node_b_to_matched_subgraph_a_and_name[node_b]
792
+ else:
793
+ assert node_b_is_end_node
794
+ (
795
+ subgraph_a,
796
+ ref_name,
797
+ ref_node_type_a,
798
+ ref_node_type_b,
799
+ ) = end_node_b_to_matched_subgraph_a_and_name[node_b]
800
+
801
+ all_op_types_support_shadowing = op_type_supports_shadowing(
802
+ subgraph_a.start_node
803
+ ) and op_type_supports_shadowing(node_b)
804
+ if not all_op_types_support_shadowing:
805
+ print(
806
+ f"skipping shadow loggers for node_b: {get_target_type_str(node_b, gm_b)}"
807
+ + f", start_node_a: {get_target_type_str(subgraph_a.start_node, gm_a)}"
808
+ + ", unsupported"
809
+ )
810
+ env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
811
+ continue
812
+
813
+ # For both start_node and end_node verify that we know how to do
814
+ # the dtype cast. If we do not, skip.
815
+ (
816
+ node_input_type_a,
817
+ node_output_type_a,
818
+ ) = get_node_first_input_and_output_type(
819
+ subgraph_a.start_node, gm_a, logger_cls, node_type_to_io_type_map
820
+ )
821
+ (
822
+ node_input_type_b,
823
+ node_output_type_b,
824
+ ) = get_node_first_input_and_output_type(
825
+ node_b, gm_b, logger_cls, node_type_to_io_type_map
826
+ )
827
+ node_io_types_known_a_and_b = (
828
+ node_input_type_a != NodeInputOrOutputType.UNKNOWN
829
+ and node_output_type_a != NodeInputOrOutputType.UNKNOWN
830
+ and node_input_type_b != NodeInputOrOutputType.UNKNOWN
831
+ and node_output_type_b != NodeInputOrOutputType.UNKNOWN
832
+ )
833
+ if not node_io_types_known_a_and_b:
834
+ print(
835
+ f"skipping shadow loggers for node_b: {get_target_type_str(node_b, gm_b)}"
836
+ + f", start_node_a: {get_target_type_str(subgraph_a.start_node, gm_a)}"
837
+ + ", unknown dtype cast"
838
+ )
839
+ env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
840
+ continue
841
+
842
+ # If we are shadowing from fp32 to int8, we need to insert
843
+ # quantize_per_tensor call with qparams from the previous node.
844
+ # Only do this if we are able to infer these qparams from the graph.
845
+ if (
846
+ node_input_type_a == NodeInputOrOutputType.INT8
847
+ and node_input_type_b == NodeInputOrOutputType.FP32
848
+ ):
849
+ node_a_input_qparams = get_node_input_qparams(
850
+ subgraph_a.start_node, gm_a, node_type_to_io_type_map
851
+ )
852
+ if not node_a_input_qparams:
853
+ print(
854
+ f"skipping shadow loggers for node_b: {get_target_type_str(node_b, gm_b)}"
855
+ + f", start_node_a: {get_target_type_str(subgraph_a.start_node, gm_a)}"
856
+ + ", unknown input qparams"
857
+ )
858
+ env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
859
+ continue
860
+
861
+ num_non_param_args_node_a = get_number_of_non_param_args(
862
+ subgraph_a.start_node, gm_a
863
+ )
864
+ if not _can_insert_copy_of_subgraph_a(
865
+ subgraph_a, gm_a, num_non_param_args_node_a
866
+ ):
867
+ print(
868
+ f"skipping shadow loggers for node_b: {get_target_type_str(node_b, gm_b)}"
869
+ + f", start_node_a: {get_target_type_str(subgraph_a.start_node, gm_a)}"
870
+ + ", unhandled logic in subgraph copy"
871
+ )
872
+ env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
873
+ continue
874
+
875
+ fqn_base_a = _maybe_get_fqn(subgraph_a.base_op_node, gm_a)
876
+ fqn_base_b = _maybe_get_fqn(subgraph_b.base_op_node, gm_b) # type: ignore[possibly-undefined]
877
+
878
+ if node_b_is_start_node:
879
+ # if necessary, log the input of node_c
880
+ if should_log_inputs:
881
+ prev_node_b = get_normalized_nth_input(node_b, gm_b, 0)
882
+ if isinstance(prev_node_b, Node):
883
+ prev_node_c = env_c[prev_node_b.name]
884
+ env_c[prev_node_c.name] = _insert_logger_after_node(
885
+ prev_node_c,
886
+ gm_b,
887
+ logger_cls,
888
+ "_ns_logger_b_inp_",
889
+ node_b.name,
890
+ name_b,
891
+ ref_name,
892
+ ref_node_type_b,
893
+ NSSingleResultValuesType.NODE_INPUT.value,
894
+ index_within_arg=0,
895
+ index_of_arg=0,
896
+ fqn=fqn_base_b,
897
+ )
898
+ elif isinstance(prev_node_b, list):
899
+ # first, save the prev_node instances, because they
900
+ # will be overwritten in the env after the first logger
901
+ # is added
902
+ prev_node_c_list = [env_c[arg.name] for arg in prev_node_b]
903
+
904
+ for arg_idx, arg in enumerate(prev_node_b):
905
+ prev_node_c = prev_node_c_list[arg_idx]
906
+ env_c[prev_node_c.name] = _insert_logger_after_node(
907
+ prev_node_c,
908
+ gm_b,
909
+ logger_cls,
910
+ "_ns_logger_b_inp_",
911
+ node_b.name,
912
+ name_b,
913
+ ref_name,
914
+ ref_node_type_b,
915
+ NSSingleResultValuesType.NODE_INPUT.value,
916
+ index_within_arg=arg_idx,
917
+ index_of_arg=0,
918
+ fqn=fqn_base_b,
919
+ )
920
+ else:
921
+ # logging of inputs which are not lists is not supported yet
922
+ raise AssertionError(
923
+ f"type {type(prev_node_b)} is not handled yet"
924
+ )
925
+ # subgraph so far:
926
+ #
927
+ # (prev_node_c)+ -> (logger_c_input)?
928
+
929
+ # Note: this if statement is always True, spelling it out to clarify code
930
+ # intent.
931
+ if node_b_is_start_node or node_b_is_end_node:
932
+ # ensure env_c is populated with base node
933
+ env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
934
+ node_c = env_c[node_b.name]
935
+
936
+ # after this point,
937
+ #
938
+ # node_a is the original node from graph_a, with parent module gm_a
939
+ # node_b is the original node from graph_b, with parent module gm_b
940
+ # node_c is the copy of node_b in graph_c
941
+ #
942
+ # subgraph so far:
943
+ #
944
+ # (prev_node_c)+ -> (logger_c_input)? -> node_start_c
945
+
946
+ if node_b_is_start_node:
947
+ # cast dtype from the dtype of node_c's input to the dtype of
948
+ # node_a's input (dequant, etc)
949
+ # prev_node_c = node_c.args[0]
950
+ prev_node_c = get_normalized_nth_input(node_c, gm_b, 0) # type: ignore[possibly-undefined]
951
+ if should_log_inputs:
952
+ # skip the input logger when inserting a dtype cast
953
+ if isinstance(prev_node_c, Node):
954
+ prev_node_c = get_normalized_nth_input(node_c, gm_b, 0)
955
+ elif isinstance(prev_node_c, list):
956
+ prev_node_c = [
957
+ get_normalized_nth_input(arg, gm_b, 0)
958
+ for arg in prev_node_c
959
+ ]
960
+ dtype_cast_node = _insert_dtype_cast_after_node(
961
+ subgraph_a.start_node,
962
+ node_c,
963
+ prev_node_c,
964
+ gm_a,
965
+ gm_b,
966
+ graph_c,
967
+ node_b.name + "_dtype_cast_",
968
+ logger_cls,
969
+ node_type_to_io_type_map,
970
+ )
971
+ # note: not inserting to env_c because all nodes which use the dtype
972
+ # casts are copied from graph_a
973
+ #
974
+ # subgraph so far:
975
+ #
976
+ # (dtype_cast_node)+
977
+ # /
978
+ # (prev_node_c)+ -> (logger_c_input)? -> node_start_c
979
+
980
+ # if input logging is enabled, log the input to the subgraph
981
+ if should_log_inputs:
982
+ # TODO: explain this
983
+ ref_node_name = ""
984
+ if isinstance(dtype_cast_node, Node):
985
+ dtype_cast_node = _insert_logger_after_node(
986
+ dtype_cast_node,
987
+ gm_b,
988
+ logger_cls,
989
+ "_ns_logger_a_inp_",
990
+ ref_node_name,
991
+ name_a,
992
+ ref_name,
993
+ ref_node_type_a,
994
+ NSSingleResultValuesType.NODE_INPUT.value,
995
+ index_within_arg=0,
996
+ index_of_arg=0,
997
+ fqn=fqn_base_a,
998
+ )
999
+ input_logger: Union[Node, List[Node]] = dtype_cast_node
1000
+ else:
1001
+ assert isinstance(dtype_cast_node, list)
1002
+ new_loggers = []
1003
+ for dtype_cast_idx, dtype_cast_node_inner in enumerate(
1004
+ dtype_cast_node
1005
+ ):
1006
+ dtype_cast_logger = _insert_logger_after_node(
1007
+ dtype_cast_node_inner,
1008
+ gm_b,
1009
+ logger_cls,
1010
+ "_ns_logger_a_inp_",
1011
+ ref_node_name,
1012
+ name_a,
1013
+ ref_name,
1014
+ ref_node_type_a,
1015
+ NSSingleResultValuesType.NODE_INPUT.value,
1016
+ index_within_arg=dtype_cast_idx,
1017
+ index_of_arg=0,
1018
+ fqn=fqn_base_a,
1019
+ )
1020
+ new_loggers.append(dtype_cast_logger)
1021
+ dtype_cast_node = new_loggers
1022
+ input_logger = dtype_cast_node
1023
+ # subgraph so far:
1024
+ #
1025
+ # (dtype_cast_node)+ -> (logger_a_input)?
1026
+ # /
1027
+ # prev_node_c -> (logger_c_input)? -> node_start_c
1028
+
1029
+ # hook up the new mod_a copy to be in the graph, receiving the
1030
+ # same inputs as mod_b does, with dtype cast to match a
1031
+ # Some ops, such as LSTMs, have two non-param inputs. If we have
1032
+ # such an op, pass the second param as well. Note: dtype casting
1033
+ # for the second param is not implemented yet, it can be added
1034
+ # later if there is a use case.
1035
+ node_c_second_non_param_arg = None
1036
+ num_non_param_args_node_a = get_number_of_non_param_args(
1037
+ subgraph_a.start_node, gm_a
1038
+ )
1039
+ if num_non_param_args_node_a == 2:
1040
+ # node_c_second_non_param_arg = node_c.args[1]
1041
+ node_c_second_non_param_arg = get_normalized_nth_input(
1042
+ node_c, gm_b, 1
1043
+ )
1044
+ node_a_shadows_c = _insert_copy_of_subgraph_a_after_input_node_c(
1045
+ dtype_cast_node,
1046
+ node_c_second_non_param_arg,
1047
+ subgraph_a,
1048
+ gm_a,
1049
+ gm_b,
1050
+ node_c.name + "_shadow_copy_",
1051
+ )
1052
+ env_c[node_a_shadows_c.name] = node_a_shadows_c
1053
+ # subgraph so far:
1054
+ #
1055
+ # dtype_cast_node -> (logger_a_input)? -> subgraph_a_copy(args/kwargs not shown)
1056
+ # /
1057
+ # (prev_node_c)+ -> (logger_c_input)? -> node_start_c
1058
+
1059
+ if should_log_inputs:
1060
+ # When we created the input logger, we left the ref_node_name
1061
+ # as an empty string, because the subgraph copy did not exist
1062
+ # yet. Now that the subgraph copy exists, we modify this name
1063
+ # to its true value.
1064
+ # Note: the alternative to this is to create the input logger
1065
+ # after creating the subgraph, which is slightly more
1066
+ # complicated. This is the lesser of two evils.
1067
+ # input_logger = env_c[dtype_cast_node.name]
1068
+ # Find the first node in the subgraph
1069
+ cur_node = node_a_shadows_c
1070
+ while get_normalized_nth_input(cur_node, gm_b, 0) != input_logger: # type: ignore[possibly-undefined]
1071
+ cur_node = get_normalized_nth_input(cur_node, gm_b, 0) # type: ignore[assignment]
1072
+ if isinstance(input_logger, Node):
1073
+ input_logger_mod = getattr(gm_b, input_logger.name)
1074
+ input_logger_mod.ref_node_name = cur_node.name
1075
+ else:
1076
+ assert isinstance(input_logger, list)
1077
+ for input_logger_inner in input_logger:
1078
+ input_logger_mod = getattr(gm_b, input_logger_inner.name)
1079
+ input_logger_mod.ref_node_name = cur_node.name
1080
+
1081
+ # hook up a logger to the mod_a copy
1082
+ env_c[node_a_shadows_c.name] = _insert_logger_after_node(
1083
+ env_c[node_a_shadows_c.name],
1084
+ gm_b,
1085
+ logger_cls,
1086
+ "_ns_logger_a_",
1087
+ node_a_shadows_c.name,
1088
+ name_a,
1089
+ ref_name,
1090
+ ref_node_type_a,
1091
+ NSSingleResultValuesType.NODE_OUTPUT.value,
1092
+ index_within_arg=0,
1093
+ index_of_arg=0,
1094
+ fqn=fqn_base_a,
1095
+ )
1096
+ # subgraph so far:
1097
+ #
1098
+ # dtype_cast_node -> (logger_a_input)? -> subgraph_a_copy -> logger_a
1099
+ # /
1100
+ # (prev_node_c)+ -> (logger_c_input)? -> node_start_c
1101
+
1102
+ if node_b_is_end_node:
1103
+ # hook up a logger to the mod_b copy
1104
+ env_c[node_b.name] = _insert_logger_after_node(
1105
+ env_c[node_b.name],
1106
+ gm_b,
1107
+ logger_cls,
1108
+ "_ns_logger_b_",
1109
+ node_b.name,
1110
+ name_b,
1111
+ ref_name,
1112
+ ref_node_type_b,
1113
+ NSSingleResultValuesType.NODE_OUTPUT.value,
1114
+ index_within_arg=0,
1115
+ index_of_arg=0,
1116
+ fqn=fqn_base_b,
1117
+ )
1118
+ # subgraph so far:
1119
+ #
1120
+ # dtype_cast_node -> (logger_a_input)? -> subgraph_a_copy -> logger_a
1121
+ # /
1122
+ # (prev_node_c+) -> (logger_c_input)? -> node_start_c -> ... -> node_end_c -> logger_c
1123
+ #
1124
+ # Note: node_start_c may be the same node as node_end_c, or they
1125
+ # may have nodes inbetween.
1126
+
1127
+ else:
1128
+ env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
1129
+
1130
+ gm_c = GraphModule(gm_b, graph_c)
1131
+ return gm_c
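Taken together with `get_matching_subgraph_pairs` from graph_matcher.py, a hedged end-to-end sketch of the shadow flow (the model names, `OutputLogger`, and `example_input` are illustrative assumptions):

    matched_pairs = get_matching_subgraph_pairs(gm_fp32, gm_int8)
    gm_shadow = create_a_shadows_b(
        "fp32", gm_fp32,
        "int8", gm_int8,
        matched_pairs,
        logger_cls=OutputLogger,     # any logger class with the expected constructor
        should_log_inputs=False,
    )
    gm_shadow(example_input)         # int8 nodes and their fp32 shadows both log outputs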