koichi12 committed on
Commit 6229f35 · verified · 1 Parent(s): 55ebfe8

Add files using upload-large-folder tool

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. .gitattributes +3 -0
  2. tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/FlowControl.cpython-311-x86_64-linux-gnu.so +3 -0
  3. tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Debugger/__pycache__/libpython.cpython-311.pyc +3 -0
  4. tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/rich/__pycache__/console.cpython-311.pyc +3 -0
  5. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_tensor_str.py +697 -0
  6. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/intrinsic/modules/fused.py +160 -0
  7. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/intrinsic/quantized/__init__.py +14 -0
  8. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/modules/__pycache__/linear_relu.cpython-311.pyc +0 -0
  9. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/intrinsic/quantized/modules/conv_relu.py +175 -0
  10. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/intrinsic/quantized/modules/linear_relu.py +177 -0
  11. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/qat/dynamic/modules/linear.py +25 -0
  12. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/qat/modules/__init__.py +14 -0
  13. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantizable/__pycache__/__init__.cpython-311.pyc +0 -0
  14. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantizable/modules/activation.py +465 -0
  15. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/__pycache__/functional.cpython-311.pyc +0 -0
  16. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/dynamic/modules/__pycache__/rnn.cpython-311.pyc +0 -0
  17. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/__pycache__/batchnorm.cpython-311.pyc +0 -0
  18. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/__pycache__/conv.cpython-311.pyc +0 -0
  19. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/__pycache__/embedding_ops.cpython-311.pyc +0 -0
  20. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/__pycache__/linear.cpython-311.pyc +0 -0
  21. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/__pycache__/rnn.cpython-311.pyc +0 -0
  22. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/__pycache__/utils.cpython-311.pyc +0 -0
  23. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/activation.py +302 -0
  24. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/conv.py +945 -0
  25. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/dropout.py +27 -0
  26. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/functional_modules.py +249 -0
  27. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/reference/modules/__init__.py +21 -0
  28. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/linear.cpython-311.pyc +0 -0
  29. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/sparse.cpython-311.pyc +0 -0
  30. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/reference/modules/rnn.py +614 -0
  31. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/reference/modules/utils.py +323 -0
  32. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/sparse/quantized/__pycache__/utils.cpython-311.pyc +0 -0
  33. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/sparse/quantized/dynamic/linear.py +139 -0
  34. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/_experimental/__pycache__/__init__.cpython-311.pyc +0 -0
  35. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/_experimental/data_scheduler/__init__.py +5 -0
  36. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/_experimental/data_sparsifier/__pycache__/quantization_utils.cpython-311.pyc +0 -0
  37. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/__init__.py +0 -0
  38. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/_experimental/data_sparsifier/quantization_utils.py +130 -0
  39. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/FPGM_pruner.cpython-311.pyc +0 -0
  40. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/__init__.cpython-311.pyc +0 -0
  41. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/_experimental/pruner/saliency_pruner.py +29 -0
  42. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/scheduler/__init__.py +0 -0
  43. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/scheduler/__pycache__/base_scheduler.cpython-311.pyc +0 -0
  44. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/scheduler/lambda_scheduler.py +47 -0
  45. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/sparsifier/__pycache__/__init__.cpython-311.pyc +0 -0
  46. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/sparsifier/__pycache__/base_sparsifier.cpython-311.pyc +0 -0
  47. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/sparsifier/__pycache__/weight_norm_sparsifier.cpython-311.pyc +0 -0
  48. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/_equalize.py +182 -0
  49. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/__init__.py +23 -0
  50. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/__pycache__/__init__.cpython-311.pyc +0 -0
.gitattributes CHANGED
@@ -66,3 +66,6 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/__pycach
66
  tuning-competition-baseline/.venv/lib/python3.11/site-packages/mpmath/tests/__pycache__/test_fp.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
67
  tuning-competition-baseline/.venv/lib/python3.11/site-packages/mpmath/__pycache__/function_docs.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
68
  tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/distlib/w64-arm.exe filter=lfs diff=lfs merge=lfs -text
69
+ tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/rich/__pycache__/console.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
70
+ tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Debugger/__pycache__/libpython.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
71
+ tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/FlowControl.cpython-311-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/FlowControl.cpython-311-x86_64-linux-gnu.so ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b2c7b91ab0731f5672d976d4408f3525891b8c4e1d4ed4d403f56d1c141c7f94
3
+ size 688080
tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Debugger/__pycache__/libpython.cpython-311.pyc ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e828bf211daa379b740684868a31081921397805bfc7ef4b41a8572d794eaafb
3
+ size 137864
tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/rich/__pycache__/console.cpython-311.pyc ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1a7da30d6865deaf94e2814884970e99b253843c23d4aa93b1107a23e61de6c1
3
+ size 123664
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_tensor_str.py ADDED
@@ -0,0 +1,697 @@
1
+ import contextlib
2
+ import dataclasses
3
+ import math
4
+ import textwrap
5
+ from typing import Any, Dict, Optional
6
+
7
+ import torch
8
+ from torch import inf
9
+
10
+
11
+ @dataclasses.dataclass
12
+ class __PrinterOptions:
13
+ precision: int = 4
14
+ threshold: float = 1000
15
+ edgeitems: int = 3
16
+ linewidth: int = 80
17
+ sci_mode: Optional[bool] = None
18
+
19
+
20
+ PRINT_OPTS = __PrinterOptions()
21
+
22
+
23
+ # We could use **kwargs, but this will give better docs
24
+ def set_printoptions(
25
+ precision=None,
26
+ threshold=None,
27
+ edgeitems=None,
28
+ linewidth=None,
29
+ profile=None,
30
+ sci_mode=None,
31
+ ):
32
+ r"""Set options for printing. Items shamelessly taken from NumPy
33
+
34
+ Args:
35
+ precision: Number of digits of precision for floating point output
36
+ (default = 4).
37
+ threshold: Total number of array elements which trigger summarization
38
+ rather than full `repr` (default = 1000).
39
+ edgeitems: Number of array items in summary at beginning and end of
40
+ each dimension (default = 3).
41
+ linewidth: The number of characters per line for the purpose of
42
+ inserting line breaks (default = 80). Thresholded matrices will
43
+ ignore this parameter.
44
+ profile: Sane defaults for pretty printing. Can override with any of
45
+ the above options. (any one of `default`, `short`, `full`)
46
+ sci_mode: Enable (True) or disable (False) scientific notation. If
47
+ None (default) is specified, the value is defined by
48
+ `torch._tensor_str._Formatter`. This value is automatically chosen
49
+ by the framework.
50
+
51
+ Example::
52
+
53
+ >>> # Limit the precision of elements
54
+ >>> torch.set_printoptions(precision=2)
55
+ >>> torch.tensor([1.12345])
56
+ tensor([1.12])
57
+ >>> # Limit the number of elements shown
58
+ >>> torch.set_printoptions(threshold=5)
59
+ >>> torch.arange(10)
60
+ tensor([0, 1, 2, ..., 7, 8, 9])
61
+ >>> # Restore defaults
62
+ >>> torch.set_printoptions(profile='default')
63
+ >>> torch.tensor([1.12345])
64
+ tensor([1.1235])
65
+ >>> torch.arange(10)
66
+ tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
67
+
68
+ """
69
+ if profile is not None:
70
+ if profile == "default":
71
+ PRINT_OPTS.precision = 4
72
+ PRINT_OPTS.threshold = 1000
73
+ PRINT_OPTS.edgeitems = 3
74
+ PRINT_OPTS.linewidth = 80
75
+ elif profile == "short":
76
+ PRINT_OPTS.precision = 2
77
+ PRINT_OPTS.threshold = 1000
78
+ PRINT_OPTS.edgeitems = 2
79
+ PRINT_OPTS.linewidth = 80
80
+ elif profile == "full":
81
+ PRINT_OPTS.precision = 4
82
+ PRINT_OPTS.threshold = inf
83
+ PRINT_OPTS.edgeitems = 3
84
+ PRINT_OPTS.linewidth = 80
85
+
86
+ if precision is not None:
87
+ PRINT_OPTS.precision = precision
88
+ if threshold is not None:
89
+ PRINT_OPTS.threshold = threshold
90
+ if edgeitems is not None:
91
+ PRINT_OPTS.edgeitems = edgeitems
92
+ if linewidth is not None:
93
+ PRINT_OPTS.linewidth = linewidth
94
+ PRINT_OPTS.sci_mode = sci_mode
95
+
96
+
97
+ def get_printoptions() -> Dict[str, Any]:
98
+ r"""Gets the current options for printing, as a dictionary that
99
+ can be passed as ``**kwargs`` to set_printoptions().
100
+ """
101
+ return dataclasses.asdict(PRINT_OPTS)
102
+
103
+
104
+ @contextlib.contextmanager
105
+ def printoptions(**kwargs):
106
+ r"""Context manager that temporarily changes the print options. Accepted
107
+ arguments are same as :func:`set_printoptions`."""
108
+ old_kwargs = get_printoptions()
109
+ set_printoptions(**kwargs)
110
+ try:
111
+ yield
112
+ finally:
113
+ set_printoptions(**old_kwargs)
114
+
115
+
116
+ def tensor_totype(t):
117
+ dtype = torch.float if t.is_mps else torch.double
118
+ return t.to(dtype=dtype)
119
+
120
+
121
+ class _Formatter:
122
+ def __init__(self, tensor):
123
+ self.floating_dtype = tensor.dtype.is_floating_point
124
+ self.int_mode = True
125
+ self.sci_mode = False
126
+ self.max_width = 1
127
+
128
+ with torch.no_grad():
129
+ tensor_view = tensor.reshape(-1)
130
+
131
+ if not self.floating_dtype:
132
+ for value in tensor_view:
133
+ value_str = f"{value}"
134
+ self.max_width = max(self.max_width, len(value_str))
135
+
136
+ else:
137
+ nonzero_finite_vals = torch.masked_select(
138
+ tensor_view, torch.isfinite(tensor_view) & tensor_view.ne(0)
139
+ )
140
+
141
+ if nonzero_finite_vals.numel() == 0:
142
+ # no valid number, do nothing
143
+ return
144
+
145
+ # Convert to double for easy calculation. HalfTensor overflows with 1e8, and there's no div() on CPU.
146
+ nonzero_finite_abs = tensor_totype(nonzero_finite_vals.abs())
147
+ nonzero_finite_min = tensor_totype(nonzero_finite_abs.min())
148
+ nonzero_finite_max = tensor_totype(nonzero_finite_abs.max())
149
+
150
+ for value in nonzero_finite_vals:
151
+ if value != torch.ceil(value):
152
+ self.int_mode = False
153
+ break
154
+
155
+ if self.int_mode:
156
+ # in int_mode for floats, all numbers are integers, and we append a decimal to nonfinites
157
+ # to indicate that the tensor is of floating type. add 1 to the len to account for this.
158
+ if (
159
+ nonzero_finite_max / nonzero_finite_min > 1000.0
160
+ or nonzero_finite_max > 1.0e8
161
+ ):
162
+ self.sci_mode = True
163
+ for value in nonzero_finite_vals:
164
+ value_str = f"{{:.{PRINT_OPTS.precision}e}}".format(value)
165
+ self.max_width = max(self.max_width, len(value_str))
166
+ else:
167
+ for value in nonzero_finite_vals:
168
+ value_str = f"{value:.0f}"
169
+ self.max_width = max(self.max_width, len(value_str) + 1)
170
+ else:
171
+ # Check if scientific representation should be used.
172
+ if (
173
+ nonzero_finite_max / nonzero_finite_min > 1000.0
174
+ or nonzero_finite_max > 1.0e8
175
+ or nonzero_finite_min < 1.0e-4
176
+ ):
177
+ self.sci_mode = True
178
+ for value in nonzero_finite_vals:
179
+ value_str = f"{{:.{PRINT_OPTS.precision}e}}".format(value)
180
+ self.max_width = max(self.max_width, len(value_str))
181
+ else:
182
+ for value in nonzero_finite_vals:
183
+ value_str = f"{{:.{PRINT_OPTS.precision}f}}".format(value)
184
+ self.max_width = max(self.max_width, len(value_str))
185
+
186
+ if PRINT_OPTS.sci_mode is not None:
187
+ self.sci_mode = PRINT_OPTS.sci_mode
188
+
189
+ def width(self):
190
+ return self.max_width
191
+
192
+ def format(self, value):
193
+ if self.floating_dtype:
194
+ if self.sci_mode:
195
+ ret = f"{{:{self.max_width}.{PRINT_OPTS.precision}e}}".format(value)
196
+ elif self.int_mode:
197
+ ret = f"{value:.0f}"
198
+ if not (math.isinf(value) or math.isnan(value)):
199
+ ret += "."
200
+ else:
201
+ ret = f"{{:.{PRINT_OPTS.precision}f}}".format(value)
202
+ else:
203
+ ret = f"{value}"
204
+ return (self.max_width - len(ret)) * " " + ret
205
+
206
+
207
+ def _scalar_str(self, formatter1, formatter2=None):
208
+ if formatter2 is not None:
209
+ real_str = _scalar_str(self.real, formatter1)
210
+ imag_str = (_scalar_str(self.imag, formatter2) + "j").lstrip()
211
+ # handles negative numbers, +0.0, -0.0
212
+ if imag_str[0] == "+" or imag_str[0] == "-":
213
+ return real_str + imag_str
214
+ else:
215
+ return real_str + "+" + imag_str
216
+ else:
217
+ return formatter1.format(self.item())
218
+
219
+
220
+ def _vector_str(self, indent, summarize, formatter1, formatter2=None):
221
+ # length includes spaces and comma between elements
222
+ element_length = formatter1.width() + 2
223
+ if formatter2 is not None:
224
+ # width for imag_formatter + an extra j for complex
225
+ element_length += formatter2.width() + 1
226
+
227
+ elements_per_line = max(
228
+ 1, int(math.floor((PRINT_OPTS.linewidth - indent) / (element_length)))
229
+ )
230
+
231
+ def _val_formatter(val, formatter1=formatter1, formatter2=formatter2):
232
+ if formatter2 is not None:
233
+ real_str = formatter1.format(val.real)
234
+ imag_str = (formatter2.format(val.imag) + "j").lstrip()
235
+ # handles negative numbers, +0.0, -0.0
236
+ if imag_str[0] == "+" or imag_str[0] == "-":
237
+ return real_str + imag_str
238
+ else:
239
+ return real_str + "+" + imag_str
240
+ else:
241
+ return formatter1.format(val)
242
+
243
+ if summarize and not PRINT_OPTS.edgeitems:
244
+ # Deal with edge case that negative zero is zero
245
+ data = ["..."]
246
+ elif summarize and self.size(0) > 2 * PRINT_OPTS.edgeitems:
247
+ data = (
248
+ [_val_formatter(val) for val in self[: PRINT_OPTS.edgeitems].tolist()]
249
+ + [" ..."]
250
+ + [_val_formatter(val) for val in self[-PRINT_OPTS.edgeitems :].tolist()]
251
+ )
252
+ else:
253
+ data = [_val_formatter(val) for val in self.tolist()]
254
+
255
+ data_lines = [
256
+ data[i : i + elements_per_line] for i in range(0, len(data), elements_per_line)
257
+ ]
258
+ lines = [", ".join(line) for line in data_lines]
259
+ return "[" + ("," + "\n" + " " * (indent + 1)).join(lines) + "]"
260
+
261
+
262
+ # formatter2 is only used for printing complex tensors.
263
+ # For complex tensors, formatter1 and formatter2 are the formatters for tensor.real
264
+ # and tensor.imag respectively
265
+ def _tensor_str_with_formatter(self, indent, summarize, formatter1, formatter2=None):
266
+ dim = self.dim()
267
+
268
+ if dim == 0:
269
+ return _scalar_str(self, formatter1, formatter2)
270
+
271
+ if dim == 1:
272
+ return _vector_str(self, indent, summarize, formatter1, formatter2)
273
+
274
+ if summarize and self.size(0) > 2 * PRINT_OPTS.edgeitems:
275
+ slices = (
276
+ [
277
+ _tensor_str_with_formatter(
278
+ self[i], indent + 1, summarize, formatter1, formatter2
279
+ )
280
+ for i in range(0, PRINT_OPTS.edgeitems)
281
+ ]
282
+ + ["..."]
283
+ + [
284
+ _tensor_str_with_formatter(
285
+ self[i], indent + 1, summarize, formatter1, formatter2
286
+ )
287
+ for i in range(len(self) - PRINT_OPTS.edgeitems, len(self))
288
+ ]
289
+ )
290
+ else:
291
+ slices = [
292
+ _tensor_str_with_formatter(
293
+ self[i], indent + 1, summarize, formatter1, formatter2
294
+ )
295
+ for i in range(0, self.size(0))
296
+ ]
297
+
298
+ tensor_str = ("," + "\n" * (dim - 1) + " " * (indent + 1)).join(slices)
299
+ return "[" + tensor_str + "]"
300
+
301
+
302
+ def _tensor_str(self, indent):
303
+ if self.numel() == 0:
304
+ return "[]"
305
+
306
+ if self.has_names():
307
+ # There are two main codepaths (possibly more) that tensor printing goes through:
308
+ # - tensor data can fit comfortably on screen
309
+ # - tensor data needs to be summarized
310
+ # Some of the codepaths don't fully support named tensors, so we send in
311
+ # an unnamed tensor to the formatting code as a workaround.
312
+ self = self.rename(None)
313
+
314
+ summarize = self.numel() > PRINT_OPTS.threshold
315
+
316
+ if self._is_zerotensor():
317
+ self = self.clone()
318
+
319
+ # handle the negative bit
320
+ if self.is_neg():
321
+ self = self.resolve_neg()
322
+
323
+ if self.dtype in [
324
+ torch.float16,
325
+ torch.bfloat16,
326
+ torch.float8_e5m2,
327
+ torch.float8_e5m2fnuz,
328
+ torch.float8_e4m3fn,
329
+ torch.float8_e4m3fnuz,
330
+ ]:
331
+ self = self.float()
332
+
333
+ if self.dtype is torch.complex32:
334
+ self = self.cfloat()
335
+
336
+ if self.dtype.is_complex:
337
+ # handle the conjugate bit
338
+ self = self.resolve_conj()
339
+ real_formatter = _Formatter(
340
+ get_summarized_data(self.real) if summarize else self.real
341
+ )
342
+ imag_formatter = _Formatter(
343
+ get_summarized_data(self.imag) if summarize else self.imag
344
+ )
345
+ return _tensor_str_with_formatter(
346
+ self, indent, summarize, real_formatter, imag_formatter
347
+ )
348
+ else:
349
+ formatter = _Formatter(get_summarized_data(self) if summarize else self)
350
+ return _tensor_str_with_formatter(self, indent, summarize, formatter)
351
+
352
+
353
+ def _add_suffixes(tensor_str, suffixes, indent, force_newline):
354
+ tensor_strs = [tensor_str]
355
+ last_line_len = len(tensor_str) - tensor_str.rfind("\n") + 1
356
+ for suffix in suffixes:
357
+ suffix_len = len(suffix)
358
+ if force_newline or last_line_len + suffix_len + 2 > PRINT_OPTS.linewidth:
359
+ tensor_strs.append(",\n" + " " * indent + suffix)
360
+ last_line_len = indent + suffix_len
361
+ force_newline = False
362
+ else:
363
+ tensor_strs.append(", " + suffix)
364
+ last_line_len += suffix_len + 2
365
+ tensor_strs.append(")")
366
+ return "".join(tensor_strs)
367
+
368
+
369
+ def get_summarized_data(self):
370
+ dim = self.dim()
371
+ if dim == 0:
372
+ return self
373
+ if dim == 1:
374
+ if self.size(0) > 2 * PRINT_OPTS.edgeitems:
375
+ return torch.cat(
376
+ (self[: PRINT_OPTS.edgeitems], self[-PRINT_OPTS.edgeitems :])
377
+ )
378
+ else:
379
+ return self
380
+ if not PRINT_OPTS.edgeitems:
381
+ return self.new_empty([0] * self.dim())
382
+ elif self.size(0) > 2 * PRINT_OPTS.edgeitems:
383
+ start = [self[i] for i in range(0, PRINT_OPTS.edgeitems)]
384
+ end = [self[i] for i in range(len(self) - PRINT_OPTS.edgeitems, len(self))]
385
+ return torch.stack([get_summarized_data(x) for x in (start + end)])
386
+ else:
387
+ return torch.stack([get_summarized_data(x) for x in self])
388
+
389
+
390
+ def _str_intern(inp, *, tensor_contents=None):
391
+ if torch._C._functorch.is_functorch_wrapped_tensor(inp):
392
+ return _functorch_wrapper_str_intern(inp, tensor_contents=tensor_contents)
393
+ is_plain_tensor = type(inp) is torch.Tensor or type(inp) is torch.nn.Parameter
394
+ if inp.is_nested:
395
+ prefix = "nested_tensor("
396
+ elif is_plain_tensor:
397
+ prefix = "tensor("
398
+ else:
399
+ prefix = f"{type(inp).__name__}("
400
+ indent = len(prefix)
401
+ suffixes = []
402
+ custom_contents_provided = tensor_contents is not None
403
+ if custom_contents_provided:
404
+ tensor_str = tensor_contents
405
+
406
+ # This is used to extract the primal value and thus disable the forward AD
407
+ # within this function.
408
+ # TODO(albanD) This needs to be updated when more than one level is supported
409
+ self, tangent = torch.autograd.forward_ad.unpack_dual(inp)
410
+
411
+ # Note [Print tensor device]:
412
+ # A general logic here is we only print device when it doesn't match
413
+ # the device specified in default tensor type.
414
+ # Currently torch.set_default_tensor_type() only supports CPU/CUDA, thus
415
+ # torch._C._get_default_device() only returns either cpu or cuda.
416
+ # In other cases, we don't have a way to set them as default yet,
417
+ # and we should always print out device for them.
418
+ if (
419
+ self.device.type != torch._C._get_default_device()
420
+ or (
421
+ self.device.type == "cuda"
422
+ and torch.cuda.current_device() != self.device.index
423
+ )
424
+ or (self.device.type == "mps")
425
+ ):
426
+ suffixes.append("device='" + str(self.device) + "'")
427
+
428
+ # Tensor printing performs tensor operations like slice, indexing, etc to make it in a
429
+ # representable format. These operations on ipu/xla/lazy/mtia tensor results in compilations. Hence,
430
+ # to avoid compilations, copying the tensor to cpu before printing.
431
+ if self.device.type in ["xla", "lazy", "ipu", "mtia"]:
432
+ self = self.to("cpu")
433
+
434
+ # TODO: add an API to map real -> complex dtypes
435
+ _default_complex_dtype = (
436
+ torch.cdouble if torch.get_default_dtype() == torch.double else torch.cfloat
437
+ )
438
+ has_default_dtype = self.dtype in (
439
+ torch.get_default_dtype(),
440
+ _default_complex_dtype,
441
+ torch.int64,
442
+ torch.bool,
443
+ )
444
+ if self.is_sparse:
445
+ suffixes.append("size=" + str(tuple(self.shape)))
446
+ from torch._subclasses.fake_tensor import FakeTensor
447
+
448
+ is_meta = self.is_meta or isinstance(self, FakeTensor)
449
+ if not is_meta:
450
+ suffixes.append("nnz=" + str(self._nnz()))
451
+ if not has_default_dtype:
452
+ suffixes.append("dtype=" + str(self.dtype))
453
+ if not custom_contents_provided:
454
+ indices_prefix = "indices=tensor("
455
+ indices = self._indices().detach()
456
+ if is_meta:
457
+ indices_str = "..."
458
+ else:
459
+ indices_str = _tensor_str(indices, indent + len(indices_prefix))
460
+ if indices.numel() == 0 or is_meta:
461
+ indices_str += ", size=" + str(tuple(indices.shape))
462
+ values_prefix = "values=tensor("
463
+ values = self._values().detach()
464
+ if is_meta:
465
+ values_str = "..."
466
+ else:
467
+ values_str = _tensor_str(values, indent + len(values_prefix))
468
+ if values.numel() == 0 or is_meta:
469
+ values_str += ", size=" + str(tuple(values.shape))
470
+ tensor_str = (
471
+ indices_prefix
472
+ + indices_str
473
+ + "),\n"
474
+ + " " * indent
475
+ + values_prefix
476
+ + values_str
477
+ + ")"
478
+ )
479
+ elif self.layout in {
480
+ torch.sparse_csr,
481
+ torch.sparse_csc,
482
+ torch.sparse_bsr,
483
+ torch.sparse_bsc,
484
+ }:
485
+ from torch._subclasses.fake_tensor import FakeTensor
486
+
487
+ suffixes.append("size=" + str(tuple(self.shape)))
488
+ is_meta = self.is_meta or isinstance(self, FakeTensor)
489
+ if not is_meta:
490
+ suffixes.append("nnz=" + str(self._nnz()))
491
+ if not has_default_dtype:
492
+ suffixes.append("dtype=" + str(self.dtype))
493
+ if not custom_contents_provided:
494
+ compressed_indices_method, plain_indices_method = {
495
+ torch.sparse_csr: (torch.Tensor.crow_indices, torch.Tensor.col_indices),
496
+ torch.sparse_csc: (torch.Tensor.ccol_indices, torch.Tensor.row_indices),
497
+ torch.sparse_bsr: (torch.Tensor.crow_indices, torch.Tensor.col_indices),
498
+ torch.sparse_bsc: (torch.Tensor.ccol_indices, torch.Tensor.row_indices),
499
+ }[self.layout]
500
+ if self.layout in {torch.sparse_csr, torch.sparse_bsr}:
501
+ cdimname, pdimname = "row", "column"
502
+ else:
503
+ cdimname, pdimname = "column", "row"
504
+ compressed_indices_prefix = f"c{cdimname[:3]}_indices=tensor("
505
+ compressed_indices = compressed_indices_method(self).detach()
506
+ if is_meta:
507
+ compressed_indices_str = "..."
508
+ else:
509
+ compressed_indices_str = _tensor_str(
510
+ compressed_indices, indent + len(compressed_indices_prefix)
511
+ )
512
+ if compressed_indices.numel() == 0 or is_meta:
513
+ compressed_indices_str += ", size=" + str(
514
+ tuple(compressed_indices.shape)
515
+ )
516
+ plain_indices_prefix = f"{pdimname[:3]}_indices=tensor("
517
+ plain_indices = plain_indices_method(self).detach()
518
+ if is_meta:
519
+ plain_indices_str = "..."
520
+ else:
521
+ plain_indices_str = _tensor_str(
522
+ plain_indices, indent + len(plain_indices_prefix)
523
+ )
524
+ if plain_indices.numel() == 0 or is_meta:
525
+ plain_indices_str += ", size=" + str(tuple(plain_indices.shape))
526
+ values_prefix = "values=tensor("
527
+ values = self.values().detach()
528
+ if is_meta:
529
+ values_str = "..."
530
+ else:
531
+ values_str = _tensor_str(values, indent + len(values_prefix))
532
+ if values.numel() == 0 or is_meta:
533
+ values_str += ", size=" + str(tuple(values.shape))
534
+ tensor_str = (
535
+ compressed_indices_prefix
536
+ + compressed_indices_str
537
+ + "),\n"
538
+ + " " * indent
539
+ + plain_indices_prefix
540
+ + plain_indices_str
541
+ + "),\n"
542
+ + " " * indent
543
+ + values_prefix
544
+ + values_str
545
+ + ")"
546
+ )
547
+ elif self.is_quantized:
548
+ suffixes.append("size=" + str(tuple(self.shape)))
549
+ if not has_default_dtype:
550
+ suffixes.append("dtype=" + str(self.dtype))
551
+ suffixes.append("quantization_scheme=" + str(self.qscheme()))
552
+ if (
553
+ self.qscheme() == torch.per_tensor_affine
554
+ or self.qscheme() == torch.per_tensor_symmetric
555
+ ):
556
+ suffixes.append("scale=" + str(self.q_scale()))
557
+ suffixes.append("zero_point=" + str(self.q_zero_point()))
558
+ elif (
559
+ self.qscheme() == torch.per_channel_affine
560
+ or self.qscheme() == torch.per_channel_symmetric
561
+ or self.qscheme() == torch.per_channel_affine_float_qparams
562
+ ):
563
+ suffixes.append("scale=" + str(self.q_per_channel_scales()))
564
+ suffixes.append("zero_point=" + str(self.q_per_channel_zero_points()))
565
+ suffixes.append("axis=" + str(self.q_per_channel_axis()))
566
+ if not custom_contents_provided:
567
+ tensor_str = _tensor_str(self.dequantize(), indent)
568
+ elif self.is_nested:
569
+ if not custom_contents_provided:
570
+
571
+ def indented_str(s, indent):
572
+ return "\n".join(f" {line}" for line in s.split("\n"))
573
+
574
+ strs = ",\n".join(
575
+ indented_str(str(t), indent + 1)
576
+ for t in torch.ops.aten.unbind.int(self, 0)
577
+ )
578
+ tensor_str = f"[\n{strs}\n]"
579
+ elif torch._is_functional_tensor(self):
580
+ prefix = "_to_functional_tensor("
581
+ tensor_str = repr(torch._from_functional_tensor(self))
582
+ else:
583
+ # Circular import problem, so we import it here
584
+ from torch._subclasses.fake_tensor import FakeTensor
585
+
586
+ if self.is_meta or isinstance(self, FakeTensor):
587
+ suffixes.append("size=" + str(tuple(self.shape)))
588
+ if self.dtype != torch.get_default_dtype():
589
+ suffixes.append("dtype=" + str(self.dtype))
590
+ # TODO: This implies that ellipses is valid syntax for allocating
591
+ # a meta tensor or FakeTensor, which it could be, but it isn't right now
592
+ if not custom_contents_provided:
593
+ tensor_str = "..."
594
+ else:
595
+ if self.numel() == 0 and not self.is_sparse:
596
+ # Explicitly print the shape if it is not (0,), to match NumPy behavior
597
+ if self.dim() != 1:
598
+ suffixes.append("size=" + str(tuple(self.shape)))
599
+
600
+ # In an empty tensor, there are no elements to infer if the dtype
601
+ # should be int64, so it must be shown explicitly.
602
+ if self.dtype != torch.get_default_dtype():
603
+ suffixes.append("dtype=" + str(self.dtype))
604
+ if not custom_contents_provided:
605
+ tensor_str = "[]"
606
+ else:
607
+ if not PRINT_OPTS.edgeitems:
608
+ suffixes.append("size=" + str(tuple(self.shape)))
609
+
610
+ if not has_default_dtype:
611
+ suffixes.append("dtype=" + str(self.dtype))
612
+
613
+ if not custom_contents_provided:
614
+ if self.layout != torch.strided:
615
+ tensor_str = _tensor_str(self.to_dense(), indent)
616
+ else:
617
+ tensor_str = _tensor_str(self, indent)
618
+
619
+ if self.layout != torch.strided:
620
+ suffixes.append("layout=" + str(self.layout))
621
+
622
+ # Use inp here to get the original grad_fn and not the one generated by the forward grad
623
+ # unpacking.
624
+ grad_fn_name = None
625
+ try:
626
+ grad_fn = inp.grad_fn
627
+ except RuntimeError:
628
+ # Accessing the grad_fn calls rebasing logic which would cause an error
629
+ # if that tensor is a view created in no-grad mode modified in-place in
630
+ # no-grad mode. See: https://github.com/pytorch/pytorch/issues/99968
631
+ grad_fn_name = "Invalid"
632
+
633
+ if grad_fn_name is None and grad_fn is not None: # type: ignore[possibly-undefined]
634
+ grad_fn_name = type(grad_fn).__name__
635
+ if grad_fn_name == "CppFunction":
636
+ grad_fn_name = grad_fn.name().rsplit("::", 1)[-1]
637
+
638
+ if grad_fn_name is not None:
639
+ suffixes.append(f"grad_fn=<{grad_fn_name}>")
640
+ elif inp.requires_grad:
641
+ suffixes.append("requires_grad=True")
642
+
643
+ if self.has_names():
644
+ suffixes.append(f"names={self.names}")
645
+
646
+ if tangent is not None:
647
+ suffixes.append(f"tangent={tangent}")
648
+
649
+ string_repr = _add_suffixes(
650
+ prefix + tensor_str, suffixes, indent, force_newline=self.is_sparse # type: ignore[possibly-undefined]
651
+ )
652
+
653
+ # Check if this instance is flagged as a parameter and change the repr accordingly.
654
+ # Unfortunately, this function has to be aware of this detail.
655
+ # NB: This is currently skipped for plain tensor parameters to maintain BC. In the future,
656
+ # this should be done for those as well to produce a valid repr.
657
+ if isinstance(self, torch.nn.Parameter) and not is_plain_tensor:
658
+ string_repr = f"Parameter({string_repr})"
659
+
660
+ return string_repr
661
+
662
+
663
+ def _functorch_wrapper_str_intern(tensor, *, tensor_contents=None):
664
+ level = torch._C._functorch.maybe_get_level(tensor)
665
+ assert level != -1
666
+
667
+ if torch._C._functorch.is_functionaltensor(tensor):
668
+ # Since we're unwrapping the FunctionalTensorWrapper, we need to make sure
669
+ # that it's up to date first
670
+ torch._sync(tensor)
671
+
672
+ value = torch._C._functorch.get_unwrapped(tensor)
673
+ value_repr = repr(value)
674
+
675
+ indented_value_repr = textwrap.indent(value_repr, " " * 4)
676
+ if torch._C._functorch.is_batchedtensor(tensor):
677
+ bdim = torch._C._functorch.maybe_get_bdim(tensor)
678
+ assert bdim != -1
679
+ return (
680
+ f"BatchedTensor(lvl={level}, bdim={bdim}, value=\n"
681
+ f"{indented_value_repr}\n"
682
+ f")"
683
+ )
684
+ if torch._C._functorch.is_gradtrackingtensor(tensor):
685
+ return (
686
+ f"GradTrackingTensor(lvl={level}, value=\n" f"{indented_value_repr}\n" f")"
687
+ )
688
+ if torch._C._functorch.is_functionaltensor(tensor):
689
+ return f"FunctionalTensor(lvl={level}, value=\\\n{value_repr})"
690
+
691
+ raise ValueError("We don't know how to print this, please file us an issue")
692
+
693
+
694
+ def _str(self, *, tensor_contents=None):
695
+ with torch.no_grad(), torch.utils._python_dispatch._disable_current_modes():
696
+ guard = torch._C._DisableFuncTorch()
697
+ return _str_intern(self, tensor_contents=tensor_contents)
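The file above implements tensor printing for torch.Tensor and the public print-option helpers (set_printoptions, get_printoptions, printoptions). A minimal usage sketch of those entry points (illustration only, not part of this commit; it assumes a PyTorch build where get_printoptions and the printoptions context manager are exposed on the torch namespace, as in this vendored copy):

    import torch

    # global precision limit, handled by set_printoptions() above
    torch.set_printoptions(precision=2)
    print(torch.tensor([1.12345]))           # tensor([1.12])

    # temporary override via the printoptions() context manager above
    with torch.printoptions(threshold=5):
        print(torch.arange(10))              # tensor([0, 1, 2, ..., 7, 8, 9])

    print(torch.get_printoptions())          # current options as a dict
    torch.set_printoptions(profile="default")  # restore defaults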
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/intrinsic/modules/fused.py ADDED
@@ -0,0 +1,160 @@
1
+ import torch
2
+ from torch.nn import Conv1d, Conv2d, Conv3d, ReLU, Linear, BatchNorm1d, BatchNorm2d, BatchNorm3d
3
+ from torch.nn.utils.parametrize import type_before_parametrizations
4
+
5
+ __all__ = ['ConvReLU1d', 'ConvReLU2d', 'ConvReLU3d', 'LinearReLU', 'ConvBn1d', 'ConvBn2d',
6
+ 'ConvBnReLU1d', 'ConvBnReLU2d', 'ConvBn3d', 'ConvBnReLU3d', 'BNReLU2d', 'BNReLU3d',
7
+ 'LinearBn1d', 'LinearLeakyReLU', 'LinearTanh', 'ConvAdd2d', 'ConvAddReLU2d']
8
+
9
+ # Used for identifying intrinsic modules used in quantization
10
+ class _FusedModule(torch.nn.Sequential):
11
+ pass
12
+
13
+ class ConvReLU1d(_FusedModule):
14
+ r"""This is a sequential container which calls the Conv1d and ReLU modules.
15
+ During quantization this will be replaced with the corresponding fused module."""
16
+ def __init__(self, conv, relu):
17
+ assert type_before_parametrizations(conv) == Conv1d and type_before_parametrizations(relu) == ReLU, \
18
+ f'Incorrect types for input modules{type_before_parametrizations(conv)}{type_before_parametrizations(relu)}'
19
+ super().__init__(conv, relu)
20
+
21
+ class ConvReLU2d(_FusedModule):
22
+ r"""This is a sequential container which calls the Conv2d and ReLU modules.
23
+ During quantization this will be replaced with the corresponding fused module."""
24
+ def __init__(self, conv, relu):
25
+ assert type_before_parametrizations(conv) == Conv2d and type_before_parametrizations(relu) == ReLU, \
26
+ f'Incorrect types for input modules{type_before_parametrizations(conv)}{type_before_parametrizations(relu)}'
27
+ super().__init__(conv, relu)
28
+
29
+ class ConvReLU3d(_FusedModule):
30
+ r"""This is a sequential container which calls the Conv3d and ReLU modules.
31
+ During quantization this will be replaced with the corresponding fused module."""
32
+ def __init__(self, conv, relu):
33
+ assert type_before_parametrizations(conv) == Conv3d and type_before_parametrizations(relu) == ReLU, \
34
+ f'Incorrect types for input modules{type_before_parametrizations(conv)}{type_before_parametrizations(relu)}'
35
+ super().__init__(conv, relu)
36
+
37
+ class LinearReLU(_FusedModule):
38
+ r"""This is a sequential container which calls the Linear and ReLU modules.
39
+ During quantization this will be replaced with the corresponding fused module."""
40
+ def __init__(self, linear, relu):
41
+ assert type_before_parametrizations(linear) == Linear and type_before_parametrizations(relu) == ReLU, \
42
+ 'Incorrect types for input modules{}{}'.format(
43
+ type_before_parametrizations(linear), type_before_parametrizations(relu))
44
+ super().__init__(linear, relu)
45
+
46
+ class ConvBn1d(_FusedModule):
47
+ r"""This is a sequential container which calls the Conv 1d and Batch Norm 1d modules.
48
+ During quantization this will be replaced with the corresponding fused module."""
49
+ def __init__(self, conv, bn):
50
+ assert type_before_parametrizations(conv) == Conv1d and type_before_parametrizations(bn) == BatchNorm1d, \
51
+ f'Incorrect types for input modules{type_before_parametrizations(conv)}{type_before_parametrizations(bn)}'
52
+ super().__init__(conv, bn)
53
+
54
+ class ConvBn2d(_FusedModule):
55
+ r"""This is a sequential container which calls the Conv 2d and Batch Norm 2d modules.
56
+ During quantization this will be replaced with the corresponding fused module."""
57
+ def __init__(self, conv, bn):
58
+ assert type_before_parametrizations(conv) == Conv2d and type_before_parametrizations(bn) == BatchNorm2d, \
59
+ f'Incorrect types for input modules{type_before_parametrizations(conv)}{type_before_parametrizations(bn)}'
60
+ super().__init__(conv, bn)
61
+
62
+ class ConvBnReLU1d(_FusedModule):
63
+ r"""This is a sequential container which calls the Conv 1d, Batch Norm 1d, and ReLU modules.
64
+ During quantization this will be replaced with the corresponding fused module."""
65
+ def __init__(self, conv, bn, relu):
66
+ assert type_before_parametrizations(conv) == Conv1d and type_before_parametrizations(bn) == BatchNorm1d and \
67
+ type_before_parametrizations(relu) == ReLU, 'Incorrect types for input modules{}{}{}' \
68
+ .format(type_before_parametrizations(conv), type_before_parametrizations(bn), type_before_parametrizations(relu))
69
+ super().__init__(conv, bn, relu)
70
+
71
+ class ConvBnReLU2d(_FusedModule):
72
+ r"""This is a sequential container which calls the Conv 2d, Batch Norm 2d, and ReLU modules.
73
+ During quantization this will be replaced with the corresponding fused module."""
74
+ def __init__(self, conv, bn, relu):
75
+ assert type_before_parametrizations(conv) == Conv2d and type_before_parametrizations(bn) == BatchNorm2d and \
76
+ type_before_parametrizations(relu) == ReLU, 'Incorrect types for input modules{}{}{}' \
77
+ .format(type_before_parametrizations(conv), type_before_parametrizations(bn), type_before_parametrizations(relu))
78
+ super().__init__(conv, bn, relu)
79
+
80
+ class ConvBn3d(_FusedModule):
81
+ r"""This is a sequential container which calls the Conv 3d and Batch Norm 3d modules.
82
+ During quantization this will be replaced with the corresponding fused module."""
83
+ def __init__(self, conv, bn):
84
+ assert type_before_parametrizations(conv) == Conv3d and type_before_parametrizations(bn) == BatchNorm3d, \
85
+ f'Incorrect types for input modules{type_before_parametrizations(conv)}{type_before_parametrizations(bn)}'
86
+ super().__init__(conv, bn)
87
+
88
+ class ConvBnReLU3d(_FusedModule):
89
+ r"""This is a sequential container which calls the Conv 3d, Batch Norm 3d, and ReLU modules.
90
+ During quantization this will be replaced with the corresponding fused module."""
91
+ def __init__(self, conv, bn, relu):
92
+ assert type_before_parametrizations(conv) == Conv3d and type_before_parametrizations(bn) == BatchNorm3d and \
93
+ type_before_parametrizations(relu) == ReLU, 'Incorrect types for input modules{}{}{}' \
94
+ .format(type_before_parametrizations(conv), type_before_parametrizations(bn), type_before_parametrizations(relu))
95
+ super().__init__(conv, bn, relu)
96
+
97
+
98
+ class BNReLU2d(_FusedModule):
99
+ r"""This is a sequential container which calls the BatchNorm 2d and ReLU modules.
100
+ During quantization this will be replaced with the corresponding fused module."""
101
+ def __init__(self, batch_norm, relu):
102
+ assert type_before_parametrizations(batch_norm) == BatchNorm2d and type_before_parametrizations(relu) == ReLU, \
103
+ 'Incorrect types for input modules{}{}'.format(
104
+ type_before_parametrizations(batch_norm), type_before_parametrizations(relu))
105
+ super().__init__(batch_norm, relu)
106
+
107
+ class BNReLU3d(_FusedModule):
108
+ r"""This is a sequential container which calls the BatchNorm 3d and ReLU modules.
109
+ During quantization this will be replaced with the corresponding fused module."""
110
+ def __init__(self, batch_norm, relu):
111
+ assert type_before_parametrizations(batch_norm) == BatchNorm3d and type_before_parametrizations(relu) == ReLU, \
112
+ 'Incorrect types for input modules{}{}'.format(
113
+ type_before_parametrizations(batch_norm), type_before_parametrizations(relu))
114
+ super().__init__(batch_norm, relu)
115
+
116
+
117
+ class LinearBn1d(_FusedModule):
118
+ r"""This is a sequential container which calls the Linear and BatchNorm1d modules.
119
+ During quantization this will be replaced with the corresponding fused module."""
120
+ def __init__(self, linear, bn):
121
+ assert type_before_parametrizations(linear) == Linear and type_before_parametrizations(bn) == BatchNorm1d, \
122
+ f'Incorrect types for input modules{type_before_parametrizations(linear)}{type_before_parametrizations(bn)}'
123
+ super().__init__(linear, bn)
124
+
125
+ class LinearLeakyReLU(_FusedModule):
126
+ r"""This is a sequential container which calls the Linear and LeakyReLU modules.
127
+ During quantization this will be replaced with the corresponding fused module."""
128
+ def __init__(self, linear, leaky_relu):
129
+ assert type(linear) == Linear and type(leaky_relu) == torch.nn.LeakyReLU, \
130
+ f'Incorrect types for input modules{type(linear)}{type(leaky_relu)}'
131
+ super().__init__(linear, leaky_relu)
132
+
133
+ class LinearTanh(_FusedModule):
134
+ r"""This is a sequential container which calls the Linear and Tanh modules.
135
+ During quantization this will be replaced with the corresponding fused module."""
136
+ def __init__(self, linear, tanh):
137
+ assert type(linear) == Linear and type(tanh) == torch.nn.Tanh, \
138
+ f'Incorrect types for input modules{type(linear)}{type(tanh)}'
139
+ super().__init__(linear, tanh)
140
+
141
+ class ConvAdd2d(_FusedModule):
142
+ r"""This is a sequential container which calls the Conv2d modules with extra Add.
143
+ During quantization this will be replaced with the corresponding fused module."""
144
+ def __init__(self, conv, add):
145
+ super().__init__(conv)
146
+ self.add = add
147
+
148
+ def forward(self, x1, x2):
149
+ return self.add(self[0](x1), x2)
150
+
151
+ class ConvAddReLU2d(_FusedModule):
152
+ r"""This is a sequential container which calls the Conv2d, add, Relu.
153
+ During quantization this will be replaced with the corresponding fused module."""
154
+ def __init__(self, conv, add, relu):
155
+ super().__init__(conv)
156
+ self.add = add
157
+ self.relu = relu
158
+
159
+ def forward(self, x1, x2):
160
+ return self.relu(self.add(self[0](x1), x2))
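fused.py above defines the float "intrinsic" containers that eager-mode quantization later replaces with fused quantized kernels. A small sketch of constructing them directly (illustration only, not part of this commit; the import path follows the torch.ao.nn.intrinsic package these classes are re-exported from):

    import torch
    from torch.ao.nn.intrinsic import ConvReLU2d, LinearReLU

    # the containers wrap existing float modules and behave like nn.Sequential
    fused_conv = ConvReLU2d(torch.nn.Conv2d(3, 16, 3, padding=1), torch.nn.ReLU())
    out = fused_conv(torch.randn(1, 3, 32, 32))      # same result as relu(conv(x))

    fused_linear = LinearReLU(torch.nn.Linear(20, 30), torch.nn.ReLU())
    print(fused_linear(torch.randn(128, 20)).shape)  # torch.Size([128, 30])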
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/intrinsic/quantized/__init__.py ADDED
@@ -0,0 +1,14 @@
1
+ from .modules import * # noqa: F403
2
+
3
+ __all__ = [
4
+ 'BNReLU2d',
5
+ 'BNReLU3d',
6
+ 'ConvReLU1d',
7
+ 'ConvReLU2d',
8
+ 'ConvReLU3d',
9
+ 'LinearReLU',
10
+ 'LinearLeakyReLU',
11
+ 'LinearTanh',
12
+ 'ConvAdd2d',
13
+ 'ConvAddReLU2d',
14
+ ]
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/modules/__pycache__/linear_relu.cpython-311.pyc ADDED
Binary file (3.56 kB).
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/intrinsic/quantized/modules/conv_relu.py ADDED
@@ -0,0 +1,175 @@
1
+
2
+ import torch
3
+ import torch.ao.nn.intrinsic
4
+ import torch.ao.nn.intrinsic.qat
5
+ import torch.nn.functional as F
6
+ import torch.ao.nn.quantized as nnq
7
+
8
+ from torch.nn.utils import fuse_conv_bn_weights
9
+
10
+ __all__ = [
11
+ "ConvReLU1d",
12
+ "ConvReLU2d",
13
+ "ConvReLU3d",
14
+ ]
15
+
16
+ _reverse_repeat_padding = nnq.modules.conv._reverse_repeat_padding
17
+
18
+ # TODO: factor out the common parts to ConvNd
19
+ class ConvReLU1d(nnq.Conv1d):
20
+ r"""
21
+ A ConvReLU1d module is a fused module of Conv1d and ReLU
22
+
23
+ We adopt the same interface as :class:`torch.ao.nn.quantized.Conv1d`.
24
+
25
+ Attributes:
26
+ Same as torch.ao.nn.quantized.Conv1d
27
+
28
+ """
29
+ _FLOAT_MODULE = torch.ao.nn.intrinsic.ConvReLU1d # type: ignore[assignment]
30
+
31
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1,
32
+ padding=0, dilation=1, groups=1, bias=True,
33
+ padding_mode='zeros', device=None, dtype=None):
34
+ super().__init__(
35
+ in_channels, out_channels, kernel_size, stride=stride,
36
+ padding=padding, dilation=dilation, groups=groups, bias=bias,
37
+ padding_mode=padding_mode, device=device, dtype=dtype)
38
+
39
+ def forward(self, input):
40
+ # Temporarily using len(shape) instead of ndim due to JIT issue
41
+ # https://github.com/pytorch/pytorch/issues/23890
42
+ if len(input.shape) != 3:
43
+ raise ValueError("Input shape must be `(N, C, L)`!")
44
+ if self.padding_mode != 'zeros':
45
+ # Padding in Conv1d is stored as (p, p), need to get (p,)
46
+ _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding[:1])
47
+ input = F.pad(input, _reversed_padding_repeated_twice,
48
+ mode=self.padding_mode)
49
+ return torch.ops.quantized.conv1d_relu(
50
+ input, self._packed_params, self.scale, self.zero_point)
51
+
52
+ def _get_name(self):
53
+ return 'QuantizedConvReLU1d'
54
+
55
+ @classmethod
56
+ def from_float(cls, mod):
57
+ if type(mod) == torch.ao.nn.intrinsic.qat.ConvBnReLU1d:
58
+ assert mod.bn.running_var is not None and mod.bn.running_mean is not None
59
+ mod.weight, mod.bias = fuse_conv_bn_weights(
60
+ mod.weight, mod.bias, mod.bn.running_mean, mod.bn.running_var,
61
+ mod.bn.eps, mod.bn.weight, mod.bn.bias)
62
+ return super().from_float(mod)
63
+
64
+ @classmethod
65
+ def from_reference(cls, ref_qconv, output_scale, output_zero_point):
66
+ assert type(ref_qconv) != torch.ao.nn.intrinsic.ConvBnReLU1d, \
67
+ "BatchNorm1d should be fused into Conv1d before converting to reference module"
68
+ return super().from_reference(ref_qconv[0], output_scale, output_zero_point)
69
+
70
+ class ConvReLU2d(nnq.Conv2d):
71
+ r"""
72
+ A ConvReLU2d module is a fused module of Conv2d and ReLU
73
+
74
+ We adopt the same interface as :class:`torch.ao.nn.quantized.Conv2d`.
75
+
76
+ Attributes:
77
+ Same as torch.ao.nn.quantized.Conv2d
78
+
79
+ """
80
+ _FLOAT_MODULE = torch.ao.nn.intrinsic.ConvReLU2d # type: ignore[assignment]
81
+
82
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1,
83
+ padding=0, dilation=1, groups=1, bias=True,
84
+ padding_mode='zeros', device=None, dtype=None):
85
+ super().__init__(
86
+ in_channels, out_channels, kernel_size, stride=stride,
87
+ padding=padding, dilation=dilation, groups=groups, bias=bias,
88
+ padding_mode=padding_mode, device=device, dtype=dtype)
89
+
90
+ def forward(self, input):
91
+ # Temporarily using len(shape) instead of ndim due to JIT issue
92
+ # https://github.com/pytorch/pytorch/issues/23890
93
+ if len(input.shape) != 4:
94
+ raise ValueError("Input shape must be `(N, C, H, W)`!")
95
+ if self.padding_mode != 'zeros':
96
+ _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding)
97
+ input = F.pad(input, _reversed_padding_repeated_twice,
98
+ mode=self.padding_mode)
99
+ return torch.ops.quantized.conv2d_relu(
100
+ input, self._packed_params, self.scale, self.zero_point)
101
+
102
+ def _get_name(self):
103
+ return 'QuantizedConvReLU2d'
104
+
105
+ @classmethod
106
+ def from_float(cls, mod):
107
+ if type(mod) == torch.ao.nn.intrinsic.qat.ConvBnReLU2d:
108
+ assert mod.bn.running_var is not None and mod.bn.running_mean is not None
109
+ mod.weight, mod.bias = fuse_conv_bn_weights(
110
+ mod.weight, mod.bias, mod.bn.running_mean, mod.bn.running_var,
111
+ mod.bn.eps, mod.bn.weight, mod.bn.bias)
112
+ return super().from_float(mod)
113
+
114
+ @classmethod
115
+ def from_reference(cls, ref_qconv, output_scale, output_zero_point):
116
+ assert type(ref_qconv) != torch.ao.nn.intrinsic.ConvBnReLU2d, \
117
+ "BatchNorm2d should be fused into Conv2d before converting to reference module"
118
+ return super().from_reference(ref_qconv[0], output_scale, output_zero_point)
119
+
120
+
121
+ class ConvReLU3d(nnq.Conv3d):
122
+ r"""
123
+ A ConvReLU3d module is a fused module of Conv3d and ReLU
124
+
125
+ We adopt the same interface as :class:`torch.ao.nn.quantized.Conv3d`.
126
+
127
+ Attributes: Same as torch.ao.nn.quantized.Conv3d
128
+
129
+ """
130
+ _FLOAT_MODULE = torch.ao.nn.intrinsic.ConvReLU3d # type: ignore[assignment]
131
+
132
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1,
133
+ padding=0, dilation=1, groups=1, bias=True,
134
+ padding_mode='zeros', device=None, dtype=None):
135
+ assert padding_mode != 'reflect', "Conv3d does not support reflection padding"
136
+ super().__init__(
137
+ in_channels, out_channels, kernel_size, stride=stride,
138
+ padding=padding, dilation=dilation, groups=groups, bias=bias,
139
+ padding_mode=padding_mode, device=device, dtype=dtype)
140
+
141
+ def forward(self, input):
142
+ # Temporarily using len(shape) instead of ndim due to JIT issue
143
+ # https://github.com/pytorch/pytorch/issues/23890
144
+ if len(input.shape) != 5:
145
+ raise ValueError("Input shape must be `(N, C, D, H, W)`!")
146
+ if self.padding_mode != 'zeros':
147
+ _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding)
148
+ input = F.pad(input, _reversed_padding_repeated_twice,
149
+ mode=self.padding_mode)
150
+ return torch.ops.quantized.conv3d_relu(
151
+ input, self._packed_params, self.scale, self.zero_point)
152
+
153
+ def _get_name(self):
154
+ return 'QuantizedConvReLU3d'
155
+
156
+ @classmethod
157
+ def from_float(cls, mod):
158
+ if type(mod) == torch.ao.nn.intrinsic.qat.ConvBnReLU3d:
159
+ assert mod.bn.running_var is not None and mod.bn.running_mean is not None
160
+ mod.weight, mod.bias = fuse_conv_bn_weights(
161
+ mod.weight,
162
+ mod.bias,
163
+ mod.bn.running_mean,
164
+ mod.bn.running_var,
165
+ mod.bn.eps,
166
+ mod.bn.weight,
167
+ mod.bn.bias,
168
+ )
169
+ return super().from_float(mod)
170
+
171
+ @classmethod
172
+ def from_reference(cls, ref_qconv, output_scale, output_zero_point):
173
+ assert type(ref_qconv) != torch.ao.nn.intrinsic.ConvBnReLU3d, \
174
+ "BatchNorm3d should be fused into Conv3d before converting to reference module"
175
+ return super().from_reference(ref_qconv[0], output_scale, output_zero_point)
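The quantized ConvReLU modules above are normally produced by eager-mode static quantization rather than constructed by hand. A hedged sketch of that flow (illustration only, not part of this commit; assumes the standard torch.ao.quantization APIs and the "fbgemm" backend):

    import torch
    from torch.ao.quantization import (
        QuantStub, DeQuantStub, fuse_modules, prepare, convert, get_default_qconfig)

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.quant = QuantStub()
            self.conv = torch.nn.Conv2d(3, 8, 3)
            self.relu = torch.nn.ReLU()
            self.dequant = DeQuantStub()

        def forward(self, x):
            return self.dequant(self.relu(self.conv(self.quant(x))))

    m = M().eval()
    m.qconfig = get_default_qconfig("fbgemm")
    m = fuse_modules(m, [["conv", "relu"]])  # float intrinsic ConvReLU2d
    m = prepare(m)
    m(torch.randn(1, 3, 32, 32))             # calibration pass for the observers
    m = convert(m)                           # swaps in the quantized ConvReLU2d above
    print(type(m.conv))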
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/intrinsic/quantized/modules/linear_relu.py ADDED
@@ -0,0 +1,177 @@
1
+ import torch
2
+ import torch.ao.nn.quantized as nnq
3
+ import torch.ao.nn.intrinsic as nni
4
+ from torch.ao.nn.quantized.modules.utils import _quantize_weight
5
+
6
+ __all__ = [
7
+ "LinearReLU",
8
+ "LinearLeakyReLU",
9
+ "LinearTanh",
10
+ ]
11
+
12
+ class LinearReLU(nnq.Linear):
13
+ r"""
14
+ A LinearReLU module fused from Linear and ReLU modules
15
+
16
+ We adopt the same interface as :class:`torch.ao.nn.quantized.Linear`.
17
+
18
+ Attributes:
19
+ Same as torch.ao.nn.quantized.Linear
20
+
21
+ Examples::
22
+
23
+ >>> # xdoctest: +SKIP
24
+ >>> m = nn.intrinsic.LinearReLU(20, 30)
25
+ >>> input = torch.randn(128, 20)
26
+ >>> output = m(input)
27
+ >>> print(output.size())
28
+ torch.Size([128, 30])
29
+ """
30
+ _FLOAT_MODULE = nni.LinearReLU # type: ignore[assignment]
31
+
32
+ def __init__(self, in_features, out_features, bias=True, dtype=torch.qint8):
33
+ super().__init__(in_features, out_features, bias, dtype)
34
+
35
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
36
+ return torch.ops.quantized.linear_relu(
37
+ x, self._packed_params._packed_params, self.scale, self.zero_point)
38
+
39
+ def _get_name(self):
40
+ return 'QuantizedLinearReLU'
41
+
42
+ @classmethod
43
+ def from_float(cls, mod):
44
+ return super().from_float(mod)
45
+
46
+ @classmethod
47
+ def from_reference(cls, ref_linear_relu, output_scale, output_zero_point):
48
+ return super().from_reference(ref_linear_relu[0], output_scale, output_zero_point)
49
+
50
+ class LinearLeakyReLU(nnq.Linear):
51
+ r"""
52
+ For onednn backend only
53
+ A LinearLeakyReLU module fused from Linear and LeakyReLU modules
54
+ We adopt the same interface as :class:`torch.ao.nn.quantized.Linear`.
55
+ Attributes:
56
+ Same as torch.ao.nn.quantized.Linear
57
+ + negative_slope
58
+ Examples::
59
+ >>> # xdoctest: +SKIP
60
+ >>> m = nn.intrinsic.LinearLeakyReLU(20, 30, 0.01)
61
+ >>> input = torch.randn(128, 20)
62
+ >>> output = m(input)
63
+ >>> print(output.size())
64
+ torch.Size([128, 30])
65
+ """
66
+ _FLOAT_MODULE = nni.LinearLeakyReLU # type: ignore[assignment]
67
+
68
+ def __init__(self, in_features, out_features, negative_slope, bias=True, dtype=torch.qint8):
69
+ super().__init__(in_features, out_features, bias, dtype)
70
+ self.negative_slope = negative_slope
71
+
72
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
73
+ return torch.ops.quantized.linear_leaky_relu(
74
+ x, self._packed_params._packed_params, self.scale, self.zero_point, self.negative_slope)
75
+
76
+ def _get_name(self):
77
+ return 'QuantizedLinearLeakyReLU'
78
+
79
+ @classmethod
80
+ def from_float(cls, mod):
81
+ assert type(mod) == nni.LinearLeakyReLU, 'Input float module should be LinearLeakyReLU'
82
+ assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
83
+ activation_post_process = mod.activation_post_process
84
+ leaky_relu = mod[1]
85
+ mod = mod[0]
86
+ weight_post_process = mod.qconfig.weight()
87
+ weight_post_process(mod.weight)
88
+ dtype = weight_post_process.dtype
89
+ act_scale, act_zp = activation_post_process.calculate_qparams() # type: ignore[union-attr,operator]
90
+ assert dtype == torch.qint8, 'Weight observer must have dtype torch.qint8'
91
+ qweight = _quantize_weight(mod.weight.float(), weight_post_process)
92
+ qlinear_leaky_relu = cls(
93
+ mod.in_features,
94
+ mod.out_features,
95
+ leaky_relu.negative_slope,
96
+ dtype=dtype)
97
+ qlinear_leaky_relu.set_weight_bias(qweight, mod.bias)
98
+ qlinear_leaky_relu.scale = float(act_scale)
99
+ qlinear_leaky_relu.zero_point = int(act_zp)
100
+ return qlinear_leaky_relu
101
+
102
+ @classmethod
103
+ def from_reference(cls, ref_mod, output_scale, output_zero_point):
104
+ linear = ref_mod[0]
105
+ leaky_relu = ref_mod[1]
106
+ qlinear_leaky_relu = cls(
107
+ linear.in_features,
108
+ linear.out_features,
109
+ leaky_relu.negative_slope)
110
+ qweight = linear.get_quantized_weight()
111
+ qlinear_leaky_relu.set_weight_bias(qweight, linear.bias)
112
+ qlinear_leaky_relu.scale = float(output_scale)
113
+ qlinear_leaky_relu.zero_point = int(output_zero_point)
114
+ return qlinear_leaky_relu
115
+
116
+ class LinearTanh(nnq.Linear):
117
+ r"""
118
+ A LinearTanh module fused from Linear and Tanh modules
119
+
120
+ We adopt the same interface as :class:`torch.ao.nn.quantized.Linear`.
121
+
122
+ Attributes:
123
+ Same as torch.ao.nn.quantized.Linear
124
+
125
+ Examples::
126
+
127
+ >>> # xdoctest: +SKIP
128
+ >>> m = nn.intrinsic.LinearTanh(20, 30)
129
+ >>> input = torch.randn(128, 20)
130
+ >>> output = m(input)
131
+ >>> print(output.size())
132
+ torch.Size([128, 30])
133
+ """
134
+ _FLOAT_MODULE = nni.LinearTanh # type: ignore[assignment]
135
+
136
+ def __init__(self, in_features, out_features, bias=True, dtype=torch.qint8):
137
+ super().__init__(in_features, out_features, bias, dtype)
138
+
139
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
140
+ return torch.ops.quantized.linear_tanh(
141
+ x, self._packed_params._packed_params, self.scale, self.zero_point)
142
+
143
+ def _get_name(self):
144
+ return 'QuantizedLinearTanh'
145
+
146
+ @classmethod
147
+ def from_float(cls, mod):
148
+ assert type(mod) == nni.LinearTanh, 'Input float module should be LinearTanh'
149
+ assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
150
+ activation_post_process = mod.activation_post_process
151
+ mod = mod[0]
152
+ weight_post_process = mod.qconfig.weight()
153
+ weight_post_process(mod.weight)
154
+ dtype = weight_post_process.dtype
155
+ act_scale, act_zp = activation_post_process.calculate_qparams() # type: ignore[union-attr,operator]
156
+ assert dtype == torch.qint8, 'Weight observer must have dtype torch.qint8'
157
+ qweight = _quantize_weight(mod.weight.float(), weight_post_process)
158
+ qlinear_tanh = cls(
159
+ mod.in_features,
160
+ mod.out_features,
161
+ dtype=dtype)
162
+ qlinear_tanh.set_weight_bias(qweight, mod.bias)
163
+ qlinear_tanh.scale = float(act_scale)
164
+ qlinear_tanh.zero_point = int(act_zp)
165
+ return qlinear_tanh
166
+
167
+ @classmethod
168
+ def from_reference(cls, ref_mod, output_scale, output_zero_point):
169
+ linear = ref_mod[0]
170
+ qlinear_tanh = cls(
171
+ linear.in_features,
172
+ linear.out_features)
173
+ qweight = linear.get_quantized_weight()
174
+ qlinear_tanh.set_weight_bias(qweight, linear.bias)
175
+ qlinear_tanh.scale = float(output_scale)
176
+ qlinear_tanh.zero_point = int(output_zero_point)
177
+ return qlinear_tanh
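For the fused quantized linear modules defined in this file, the usual way to obtain an instance is to convert a fused float model rather than to construct one by hand. The sketch below shows the eager-mode static-quantization path that ends in the QuantizedLinearReLU above; the backend name, layer sizes, and calibration input are assumptions for illustration.

import torch
import torch.nn as nn
import torch.ao.nn.intrinsic as nni

class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = torch.ao.quantization.QuantStub()
        self.fc = nni.LinearReLU(nn.Linear(20, 30), nn.ReLU())  # fused float pattern
        self.dequant = torch.ao.quantization.DeQuantStub()

    def forward(self, x):
        return self.dequant(self.fc(self.quant(x)))

m = M().eval()
m.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
prepared = torch.ao.quantization.prepare(m)
prepared(torch.randn(4, 20))                        # calibration pass
quantized = torch.ao.quantization.convert(prepared)
# quantized.fc is now an instance of the QuantizedLinearReLU class above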
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/qat/dynamic/modules/linear.py ADDED
@@ -0,0 +1,25 @@
1
+ import torch
2
+
3
+ __all__ = ["Linear"]
4
+
5
+ class Linear(torch.ao.nn.qat.Linear):
6
+ r"""
7
+ A linear module attached with FakeQuantize modules for weight,
8
+ used for dynamic quantization aware training.
9
+
10
+ We adopt the same interface as `torch.nn.Linear`, please see
11
+ https://pytorch.org/docs/stable/nn.html#torch.nn.Linear
12
+ for documentation.
13
+
14
+ Similar to `torch.nn.Linear`, with FakeQuantize modules initialized to
15
+ default.
16
+ """
17
+
18
+ def __init__(self, in_features, out_features, bias=True,
19
+ qconfig=None, device=None, dtype=None) -> None:
20
+ super().__init__(in_features, out_features, bias, qconfig, device, dtype)
21
+ if not torch.ao.quantization.qconfig._activation_is_memoryless(qconfig):
22
+ raise ValueError(
23
+ "Dynamic QAT requires a memoryless observer." +
24
+ "This means a MovingAverage observer with averaging constant equal to 1"
25
+ )
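The constructor above rejects any qconfig whose activation observer keeps a running history, so a caller has to supply a memoryless fake-quant. The sketch below builds one explicitly; the observer choice and layer sizes are assumptions, not the only valid configuration.

import torch
from torch.ao.quantization import (
    QConfig, FakeQuantize, MovingAverageMinMaxObserver, default_weight_fake_quant,
)

# averaging_constant=1 makes the moving-average observer keep only the latest
# batch statistics, which is what the "memoryless" check above expects.
memoryless_act = FakeQuantize.with_args(
    observer=MovingAverageMinMaxObserver, averaging_constant=1, dtype=torch.quint8
)
qconfig = QConfig(activation=memoryless_act, weight=default_weight_fake_quant)

qat_linear = torch.ao.nn.qat.dynamic.Linear(16, 8, qconfig=qconfig)
out = qat_linear(torch.randn(2, 16))  # float activations, fake-quantized weight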
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/qat/modules/__init__.py ADDED
@@ -0,0 +1,14 @@
1
+ from .linear import Linear
2
+ from .conv import Conv1d
3
+ from .conv import Conv2d
4
+ from .conv import Conv3d
5
+ from .embedding_ops import EmbeddingBag, Embedding
6
+
7
+ __all__ = [
8
+ "Linear",
9
+ "Conv1d",
10
+ "Conv2d",
11
+ "Conv3d",
12
+ "Embedding",
13
+ "EmbeddingBag",
14
+ ]
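Each entry re-exported here is a drop-in QAT counterpart of the corresponding torch.nn module. As a quick illustration (the qconfig choice and shapes are assumptions), a float layer with a qconfig attached can be swapped in place via from_float:

import torch
import torch.nn as nn

float_conv = nn.Conv2d(3, 8, kernel_size=3)
float_conv.qconfig = torch.ao.quantization.get_default_qat_qconfig("fbgemm")

# from_float returns the QAT module with fake-quant attached to the weight
qat_conv = torch.ao.nn.qat.Conv2d.from_float(float_conv)
out = qat_conv(torch.randn(1, 3, 16, 16))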
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantizable/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (256 Bytes). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantizable/modules/activation.py ADDED
@@ -0,0 +1,465 @@
1
+ import torch
2
+ import torch.jit # this is needed to avoid a circular import
3
+ from torch import nn
4
+ import torch.nn.functional as nnF
5
+
6
+ from torch import Tensor
7
+ from typing import Optional, Tuple
8
+
9
+ import warnings
10
+
11
+ __all__ = [
12
+ "MultiheadAttention"
13
+ ]
14
+
15
+ class MultiheadAttention(nn.MultiheadAttention):
16
+ _FLOAT_MODULE = nn.MultiheadAttention
17
+
18
+ r"""Quantizable implementation of the MultiheadAttention.
19
+
20
+ Note::
21
+ Please, refer to :class:`~torch.nn.MultiheadAttention` for more
22
+ information
23
+
24
+ Allows the model to jointly attend to information from different
25
+ representation subspaces.
26
+ See reference: Attention Is All You Need
27
+
28
+ The original MHA module is not quantizable.
29
+ This reimplements it by explicitly instantiating the linear layers.
30
+
31
+ .. math::
32
+ \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
33
+ \text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)
34
+
35
+ Args:
36
+ embed_dim: total dimension of the model.
37
+ num_heads: parallel attention heads.
38
+ dropout: a Dropout layer on attn_output_weights. Default: 0.0.
39
+ bias: add bias as module parameter. Default: True.
40
+ add_bias_kv: add bias to the key and value sequences at dim=0.
41
+ add_zero_attn: add a new batch of zeros to the key and
42
+ value sequences at dim=1.
43
+ kdim: total number of features in key. Default: None.
44
+ vdim: total number of features in value. Default: None.
45
+ batch_first: If ``True``, then the input and output tensors are provided
46
+ as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
47
+
48
+ Note that if :attr:`kdim` and :attr:`vdim` are None, they will be set
49
+ to :attr:`embed_dim` such that query, key, and value have the same
50
+ number of features.
51
+
52
+ Examples::
53
+
54
+ >>> import torch.ao.nn.quantizable as nnqa
55
+ >>> multihead_attn = nnqa.MultiheadAttention(embed_dim, num_heads)
56
+ >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
57
+
58
+ Note::
59
+ Please, follow the quantization flow to convert the quantizable MHA.
60
+ """
61
+ __constants__ = ['batch_first']
62
+
63
+ def __init__(self, embed_dim: int, num_heads: int,
64
+ dropout: float = 0., bias: bool = True,
65
+ add_bias_kv: bool = False, add_zero_attn: bool = False,
66
+ kdim: Optional[int] = None, vdim: Optional[int] = None, batch_first: bool = False,
67
+ device=None, dtype=None) -> None:
68
+ factory_kwargs = {'device': device, 'dtype': dtype}
69
+ super().__init__(embed_dim, num_heads, dropout,
70
+ bias, add_bias_kv,
71
+ add_zero_attn, kdim, vdim, batch_first,
72
+ **factory_kwargs)
73
+ self.linear_Q = nn.Linear(self.embed_dim, self.embed_dim, bias=bias, **factory_kwargs)
74
+ self.linear_K = nn.Linear(self.kdim, self.embed_dim, bias=bias, **factory_kwargs)
75
+ self.linear_V = nn.Linear(self.vdim, self.embed_dim, bias=bias, **factory_kwargs)
76
+ # for the type: ignore, see https://github.com/pytorch/pytorch/issues/58969
77
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=bias, **factory_kwargs) # type: ignore[assignment]
78
+
79
+ # Functionals
80
+ self.q_scaling_product = torch.ao.nn.quantized.FloatFunctional()
81
+ # note: importing torch.ao.nn.quantized at top creates a circular import
82
+
83
+ # Quant/Dequant
84
+ self.quant_attn_output = torch.ao.quantization.QuantStub()
85
+ self.quant_attn_output_weights = torch.ao.quantization.QuantStub()
86
+ self.dequant_q = torch.ao.quantization.DeQuantStub()
87
+ self.dequant_k = torch.ao.quantization.DeQuantStub()
88
+ self.dequant_v = torch.ao.quantization.DeQuantStub()
89
+
90
+ def _get_name(self):
91
+ return 'QuantizableMultiheadAttention'
92
+
93
+ @classmethod
94
+ def from_float(cls, other):
95
+ assert type(other) == cls._FLOAT_MODULE
96
+ assert hasattr(other, 'qconfig'), "The float module must have 'qconfig'"
97
+ # Setting the dropout to 0.0!
98
+ observed = cls(other.embed_dim, other.num_heads, other.dropout,
99
+ (other.in_proj_bias is not None),
100
+ (other.bias_k is not None),
101
+ other.add_zero_attn, other.kdim, other.vdim,
102
+ other.batch_first)
103
+ observed.bias_k = other.bias_k
104
+ observed.bias_v = other.bias_v
105
+ observed.qconfig = other.qconfig
106
+
107
+ # Set the linear weights
108
+ # for the type: ignores, see https://github.com/pytorch/pytorch/issues/58969
109
+ observed.out_proj.weight = other.out_proj.weight # type: ignore[has-type]
110
+ observed.out_proj.bias = other.out_proj.bias # type: ignore[has-type]
111
+ if other._qkv_same_embed_dim:
112
+ # Use separate params
113
+ bias = other.in_proj_bias
114
+ _start = 0
115
+ _end = _start + other.embed_dim
116
+ weight = other.in_proj_weight[_start:_end, :]
117
+ if bias is not None:
118
+ bias = torch.nn.Parameter(bias[_start:_end], bias.requires_grad)
119
+ observed.linear_Q.weight = torch.nn.Parameter(weight,
120
+ weight.requires_grad)
121
+ observed.linear_Q.bias = bias
122
+
123
+ bias = other.in_proj_bias
124
+ _start = _end
125
+ _end = _start + other.embed_dim
126
+ weight = other.in_proj_weight[_start:_end, :]
127
+ if bias is not None:
128
+ bias = torch.nn.Parameter(bias[_start:_end], bias.requires_grad)
129
+ observed.linear_K.weight = torch.nn.Parameter(weight,
130
+ weight.requires_grad)
131
+ observed.linear_K.bias = bias
132
+
133
+ bias = other.in_proj_bias
134
+ _start = _end
135
+ weight = other.in_proj_weight[_start:, :]
136
+ if bias is not None:
137
+ bias = torch.nn.Parameter(bias[_start:], bias.requires_grad)
138
+ observed.linear_V.weight = torch.nn.Parameter(weight,
139
+ weight.requires_grad)
140
+ observed.linear_V.bias = bias
141
+ else:
142
+ observed.linear_Q.weight = nn.Parameter(other.q_proj_weight)
143
+ observed.linear_K.weight = nn.Parameter(other.k_proj_weight)
144
+ observed.linear_V.weight = nn.Parameter(other.v_proj_weight)
145
+ if other.in_proj_bias is None:
146
+ observed.linear_Q.bias = None # type: ignore[assignment]
147
+ observed.linear_K.bias = None # type: ignore[assignment]
148
+ observed.linear_V.bias = None # type: ignore[assignment]
149
+ else:
150
+ observed.linear_Q.bias = nn.Parameter(other.in_proj_bias[0:other.embed_dim])
151
+ observed.linear_K.bias = nn.Parameter(other.in_proj_bias[other.embed_dim:(other.embed_dim * 2)])
152
+ observed.linear_V.bias = nn.Parameter(other.in_proj_bias[(other.embed_dim * 2):])
153
+ observed.eval()
154
+ # Explicit prepare
155
+ observed = torch.ao.quantization.prepare(observed, inplace=True)
156
+ return observed
157
+
158
+ @torch.jit.unused
159
+ def dequantize(self):
160
+ r"""Utility to convert the quantized MHA back to float.
161
+
162
+ The motivation for this is that it is not trivial to convert the weights
163
+ from the format that is used in the quantized version back to the
164
+ float.
165
+ """
166
+ fp = self._FLOAT_MODULE(self.embed_dim, self.num_heads, self.dropout,
167
+ (self.linear_Q._weight_bias()[1] is not None),
168
+ (self.bias_k is not None),
169
+ self.add_zero_attn, self.kdim, self.vdim, self.batch_first)
170
+ assert fp._qkv_same_embed_dim == self._qkv_same_embed_dim
171
+ if self.bias_k is not None:
172
+ fp.bias_k = nn.Parameter(self.bias_k.dequantize())
173
+ if self.bias_v is not None:
174
+ fp.bias_v = nn.Parameter(self.bias_v.dequantize())
175
+
176
+ # Set the linear weights
177
+ # Note: Because the linear layers are quantized, mypy does not know how
178
+ # to deal with them -- might need to ignore the typing checks.
179
+ # for the type: ignore[has-type], see https://github.com/pytorch/pytorch/issues/58969
180
+ w, b = self.out_proj._weight_bias() # type: ignore[operator, has-type]
181
+ fp.out_proj.weight = nn.Parameter(w.dequantize())
182
+ if b is not None:
183
+ fp.out_proj.bias = nn.Parameter(b)
184
+
185
+ wQ, bQ = self.linear_Q._weight_bias() # type: ignore[operator]
186
+ wQ = wQ.dequantize()
187
+ wK, bK = self.linear_K._weight_bias() # type: ignore[operator]
188
+ wK = wK.dequantize()
189
+ wV, bV = self.linear_V._weight_bias() # type: ignore[operator]
190
+ wV = wV.dequantize()
191
+ if fp._qkv_same_embed_dim:
192
+ # Use separate params
193
+ _start = 0
194
+ _end = _start + fp.embed_dim
195
+ fp.in_proj_weight[_start:_end, :] = wQ
196
+ if fp.in_proj_bias is not None:
197
+ assert all(bQ == 0)
198
+ fp.in_proj_bias[_start:_end] = bQ
199
+
200
+ _start = _end
201
+ _end = _start + fp.embed_dim
202
+ fp.in_proj_weight[_start:_end, :] = wK
203
+ if fp.in_proj_bias is not None:
204
+ assert all(bK == 0)
205
+ fp.in_proj_bias[_start:_end] = bK
206
+
207
+ _start = _end
208
+ fp.in_proj_weight[_start:, :] = wV
209
+ if fp.in_proj_bias is not None:
210
+ assert all(bV == 0)
211
+ fp.in_proj_bias[_start:] = bV
212
+ else:
213
+ fp.q_proj_weight = nn.Parameter(wQ)
214
+ fp.k_proj_weight = nn.Parameter(wK)
215
+ fp.v_proj_weight = nn.Parameter(wV)
216
+ if fp.in_proj_bias is None:
217
+ self.linear_Q.bias = None
218
+ self.linear_K.bias = None
219
+ self.linear_V.bias = None
220
+ else:
221
+ fp.in_proj_bias[0:fp.embed_dim] = bQ
222
+ fp.in_proj_bias[fp.embed_dim:(fp.embed_dim * 2)] = bK
223
+ fp.in_proj_bias[(fp.embed_dim * 2):] = bV
224
+
225
+ return fp
226
+
227
+
228
+ @classmethod
229
+ def from_observed(cls, other):
230
+ # The whole flow is float -> observed -> quantized
231
+ # This class does float -> observed only
232
+ # See nn.quantized.MultiheadAttention
233
+ raise NotImplementedError("It looks like you are trying to prepare an "
234
+ "MHA module. Please, see "
235
+ "the examples on quantizable MHAs.")
236
+
237
+ def forward(self,
238
+ query: Tensor,
239
+ key: Tensor,
240
+ value: Tensor,
241
+ key_padding_mask: Optional[Tensor] = None,
242
+ need_weights: bool = True,
243
+ attn_mask: Optional[Tensor] = None,
244
+ average_attn_weights: bool = True,
245
+ is_causal: bool = False) -> Tuple[Tensor, Optional[Tensor]]:
246
+ r"""
247
+ Note::
248
+ Please, refer to :func:`~torch.nn.MultiheadAttention.forward` for more
249
+ information
250
+
251
+ Args:
252
+ query, key, value: map a query and a set of key-value pairs to an output.
253
+ See "Attention Is All You Need" for more details.
254
+ key_padding_mask: if provided, specified padding elements in the key will
255
+ be ignored by the attention. When given a binary mask and a value is True,
256
+ the corresponding value on the attention layer will be ignored.
257
+ need_weights: output attn_output_weights.
258
+ attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
259
+ the batches while a 3D mask allows to specify a different mask for the entries of each batch.
260
+
261
+ Shape:
262
+ - Inputs:
263
+ - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
264
+ the embedding dimension. :math:`(N, L, E)` if ``batch_first`` is ``True``.
265
+ - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
266
+ the embedding dimension. :math:`(N, S, E)` if ``batch_first`` is ``True``.
267
+ - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
268
+ the embedding dimension. :math:`(N, S, E)` if ``batch_first`` is ``True``.
269
+ - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
270
+ If a BoolTensor is provided, the positions with the
271
+ value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
272
+ - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
273
+ 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
274
+ S is the source sequence length. attn_mask ensures that position i is allowed to attend to the unmasked
275
+ positions. If a BoolTensor is provided, positions with ``True``
276
+ is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
277
+ is provided, it will be added to the attention weight.
278
+ - is_causal: If specified, applies a causal mask as attention mask. Mutually exclusive with providing attn_mask.
279
+ Default: ``False``.
280
+ - average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across
281
+ heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an
282
+ effect when ``need_weights=True``. Default: True (i.e. average weights across heads)
283
+
284
+ - Outputs:
285
+ - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
286
+ E is the embedding dimension. :math:`(N, L, E)` if ``batch_first`` is ``True``.
287
+ - attn_output_weights: If ``average_attn_weights=True``, returns attention weights averaged
288
+ across heads of shape :math:`(N, L, S)`, where N is the batch size, L is the target sequence length,
289
+ S is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
290
+ head of shape :math:`(N, num_heads, L, S)`.
291
+ """
292
+ return self._forward_impl(query, key, value, key_padding_mask,
293
+ need_weights, attn_mask, average_attn_weights,
294
+ is_causal)
295
+
296
+ def _forward_impl(self,
297
+ query: Tensor,
298
+ key: Tensor,
299
+ value: Tensor,
300
+ key_padding_mask: Optional[Tensor] = None,
301
+ need_weights: bool = True,
302
+ attn_mask: Optional[Tensor] = None,
303
+ average_attn_weights: bool = True,
304
+ is_causal: bool = False) -> Tuple[Tensor, Optional[Tensor]]:
305
+ # This version will not deal with the static key/value pairs.
306
+ # Keeping it here for future changes.
307
+ #
308
+ # TODO: This method has some duplicate lines with the
309
+ # `torch.nn.functional.multi_head_attention`. Will need to refactor.
310
+ static_k = None
311
+ static_v = None
312
+
313
+ if attn_mask is not None and is_causal:
314
+ raise AssertionError("Only allow causal mask or attn_mask")
315
+
316
+ if is_causal:
317
+ raise AssertionError("causal mask not supported by AO MHA module")
318
+
319
+ if self.batch_first:
320
+ query, key, value = (x.transpose(0, 1) for x in (query, key, value))
321
+
322
+ tgt_len, bsz, embed_dim_to_check = query.size()
323
+ assert self.embed_dim == embed_dim_to_check
324
+ # allow MHA to have different sizes for the feature dimension
325
+ assert key.size(0) == value.size(0) and key.size(1) == value.size(1)
326
+
327
+ head_dim = self.embed_dim // self.num_heads
328
+ assert head_dim * self.num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
329
+ scaling = float(head_dim) ** -0.5
330
+
331
+ q = self.linear_Q(query)
332
+ k = self.linear_K(key)
333
+ v = self.linear_V(value)
334
+
335
+ q = self.q_scaling_product.mul_scalar(q, scaling)
336
+
337
+ if attn_mask is not None:
338
+ if attn_mask.dtype == torch.uint8:
339
+ warnings.warn("Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
340
+ attn_mask = attn_mask.to(torch.bool)
341
+ assert attn_mask.is_floating_point() or attn_mask.dtype == torch.bool, \
342
+ f'Only float and bool types are supported for attn_mask, not {attn_mask.dtype}'
343
+
344
+ if attn_mask.dim() == 2:
345
+ attn_mask = attn_mask.unsqueeze(0)
346
+ if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
347
+ raise RuntimeError('The size of the 2D attn_mask is not correct.')
348
+ elif attn_mask.dim() == 3:
349
+ if list(attn_mask.size()) != [bsz * self.num_heads, query.size(0), key.size(0)]:
350
+ raise RuntimeError('The size of the 3D attn_mask is not correct.')
351
+ else:
352
+ raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")
353
+ # attn_mask's dim is 3 now.
354
+
355
+ # convert ByteTensor key_padding_mask to bool
356
+ if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
357
+ warnings.warn("Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
358
+ key_padding_mask = key_padding_mask.to(torch.bool)
359
+ if self.bias_k is not None and self.bias_v is not None:
360
+ if static_k is None and static_v is None:
361
+
362
+ # Explicitly assert that bias_k and bias_v are not None
363
+ # in a way that TorchScript can understand.
364
+ bias_k = self.bias_k
365
+ assert bias_k is not None
366
+ bias_v = self.bias_v
367
+ assert bias_v is not None
368
+
369
+ k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
370
+ v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
371
+ if attn_mask is not None:
372
+ attn_mask = nnF.pad(attn_mask, (0, 1))
373
+ if key_padding_mask is not None:
374
+ key_padding_mask = nnF.pad(key_padding_mask, (0, 1))
375
+ else:
376
+ assert static_k is None, "bias cannot be added to static key."
377
+ assert static_v is None, "bias cannot be added to static value."
378
+ else:
379
+ assert self.bias_k is None
380
+ assert self.bias_v is None
381
+
382
+ q = q.contiguous().view(tgt_len, bsz * self.num_heads, head_dim).transpose(0, 1)
383
+ if k is not None:
384
+ k = k.contiguous().view(-1, bsz * self.num_heads, head_dim).transpose(0, 1)
385
+ if v is not None:
386
+ v = v.contiguous().view(-1, bsz * self.num_heads, head_dim).transpose(0, 1)
387
+
388
+ if static_k is not None:
389
+ assert static_k.size(0) == bsz * self.num_heads
390
+ assert static_k.size(2) == head_dim
391
+ k = static_k
392
+
393
+ if static_v is not None:
394
+ assert static_v.size(0) == bsz * self.num_heads
395
+ assert static_v.size(2) == head_dim
396
+ v = static_v
397
+
398
+ src_len = k.size(1)
399
+
400
+ if key_padding_mask is not None:
401
+ assert key_padding_mask.size(0) == bsz
402
+ assert key_padding_mask.size(1) == src_len
403
+
404
+ if self.add_zero_attn:
405
+ src_len += 1
406
+ k_zeros = torch.zeros((k.size(0), 1) + k.size()[2:])
407
+ if k.is_quantized:
408
+ k_zeros = torch.quantize_per_tensor(k_zeros, k.q_scale(), k.q_zero_point(), k.dtype)
409
+ k = torch.cat([k, k_zeros], dim=1)
410
+ v_zeros = torch.zeros((v.size(0), 1) + k.size()[2:])
411
+ if v.is_quantized:
412
+ v_zeros = torch.quantize_per_tensor(v_zeros, v.q_scale(), v.q_zero_point(), v.dtype)
413
+ v = torch.cat([v, v_zeros], dim=1)
414
+
415
+ if attn_mask is not None:
416
+ attn_mask = nnF.pad(attn_mask, (0, 1))
417
+ if key_padding_mask is not None:
418
+ key_padding_mask = nnF.pad(key_padding_mask, (0, 1))
419
+
420
+ # Leaving the quantized zone here
421
+ q = self.dequant_q(q)
422
+ k = self.dequant_k(k)
423
+ v = self.dequant_v(v)
424
+ attn_output_weights = torch.bmm(q, k.transpose(1, 2))
425
+ assert list(attn_output_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
426
+
427
+ if attn_mask is not None:
428
+ if attn_mask.dtype == torch.bool:
429
+ attn_output_weights.masked_fill_(attn_mask, float('-inf'))
430
+ else:
431
+ attn_output_weights += attn_mask
432
+
433
+ if key_padding_mask is not None:
434
+ attn_output_weights = attn_output_weights.view(bsz, self.num_heads, tgt_len, src_len)
435
+ attn_output_weights = attn_output_weights.masked_fill(
436
+ key_padding_mask.unsqueeze(1).unsqueeze(2),
437
+ float('-inf'),
438
+ )
439
+ attn_output_weights = attn_output_weights.view(bsz * self.num_heads, tgt_len, src_len)
440
+
441
+ attn_output_weights = nnF.softmax(
442
+ attn_output_weights, dim=-1)
443
+ attn_output_weights = nnF.dropout(attn_output_weights, p=self.dropout, training=self.training)
444
+
445
+ attn_output = torch.bmm(attn_output_weights, v)
446
+ assert list(attn_output.size()) == [bsz * self.num_heads, tgt_len, head_dim]
447
+ if self.batch_first:
448
+ attn_output = attn_output.view(bsz, tgt_len, self.embed_dim)
449
+ else:
450
+ attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, self.embed_dim)
451
+
452
+ # Reentering the quantized zone
453
+ attn_output = self.quant_attn_output(attn_output)
454
+ # for the type: ignore[has-type], see https://github.com/pytorch/pytorch/issues/58969
455
+ attn_output = self.out_proj(attn_output) # type: ignore[has-type]
456
+ attn_output_weights = self.quant_attn_output_weights(attn_output_weights)
457
+
458
+ if need_weights:
459
+ # average attention weights over heads
460
+ attn_output_weights = attn_output_weights.view(bsz, self.num_heads, tgt_len, src_len)
461
+ if average_attn_weights:
462
+ attn_output_weights = attn_output_weights.mean(dim=1)
463
+ return attn_output, attn_output_weights
464
+ else:
465
+ return attn_output, None
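As the docstring and the from_float/from_observed methods above indicate, this class only covers the float-to-observed half of the flow; the quantized counterpart lives in torch.ao.nn.quantized. A minimal sketch of the whole path follows, assuming eager-mode defaults, the fbgemm backend, and made-up sizes and calibration data.

import torch
import torch.nn as nn

float_mha = nn.MultiheadAttention(embed_dim=64, num_heads=4)
float_mha.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")

# float -> observed: from_float copies the projection weights and runs prepare()
observed = torch.ao.nn.quantizable.MultiheadAttention.from_float(float_mha)

# calibrate with representative (seq, batch, embed) tensors (batch_first=False)
q = k = v = torch.randn(10, 2, 64)
observed(q, k, v)

# observed -> quantized: handled by the nn.quantized counterpart of this class
quantized = torch.ao.nn.quantized.MultiheadAttention.from_observed(observed)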
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/__pycache__/functional.cpython-311.pyc ADDED
Binary file (32.6 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/dynamic/modules/__pycache__/rnn.cpython-311.pyc ADDED
Binary file (59 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/__pycache__/batchnorm.cpython-311.pyc ADDED
Binary file (6.44 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/__pycache__/conv.cpython-311.pyc ADDED
Binary file (49.1 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/__pycache__/embedding_ops.cpython-311.pyc ADDED
Binary file (18.1 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/__pycache__/linear.cpython-311.pyc ADDED
Binary file (16.3 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/__pycache__/rnn.cpython-311.pyc ADDED
Binary file (2.6 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/__pycache__/utils.cpython-311.pyc ADDED
Binary file (7.19 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/activation.py ADDED
@@ -0,0 +1,302 @@
1
+ import torch
2
+ from warnings import warn
3
+ __all__ = [
4
+ "ReLU6",
5
+ "Hardswish",
6
+ "ELU",
7
+ "LeakyReLU",
8
+ "Sigmoid",
9
+ "Softmax",
10
+ "MultiheadAttention",
11
+ "PReLU"
12
+ ]
13
+
14
+ class ReLU6(torch.nn.ReLU):
15
+ r"""Applies the element-wise function:
16
+
17
+ :math:`\text{ReLU6}(x) = \min(\max(x_0, x), q(6))`, where :math:`x_0` is the
18
+ zero_point, and :math:`q(6)` is the quantized representation of number 6.
19
+
20
+ Args:
21
+ inplace: can optionally do the operation in-place. Default: ``False``
22
+
23
+ Shape:
24
+ - Input: :math:`(N, *)` where `*` means, any number of additional
25
+ dimensions
26
+ - Output: :math:`(N, *)`, same shape as the input
27
+
28
+ .. image:: ../scripts/activation_images/ReLU6.png
29
+
30
+ Examples::
31
+
32
+ >>> m = nn.quantized.ReLU6()
33
+ >>> input = torch.randn(2)
34
+ >>> # xdoctest: +SKIP
35
+ >>> input = torch.quantize_per_tensor(input, 1.0, 0, dtype=torch.qint32)
36
+ >>> output = m(input)
37
+ """
38
+ def __init__(self, inplace=False):
39
+ super().__init__(inplace)
40
+ self.inplace = inplace
41
+
42
+ def forward(self, input):
43
+ return torch.ops.quantized.relu6(input, self.inplace)
44
+
45
+ def _get_name(self):
46
+ return 'QuantizedReLU6'
47
+
48
+ @staticmethod
49
+ def from_float(mod):
50
+ return ReLU6(mod.inplace)
51
+
52
+ class Hardswish(torch.nn.Hardswish):
53
+ r"""This is the quantized version of :class:`~torch.nn.Hardswish`.
54
+
55
+ Args:
56
+ scale: quantization scale of the output tensor
57
+ zero_point: quantization zero point of the output tensor
58
+ """
59
+ def __init__(self, scale, zero_point, device=None, dtype=None):
60
+ factory_kwargs = {'device': device, 'dtype': dtype}
61
+ super().__init__()
62
+ self.register_buffer('scale', torch.tensor(scale, **factory_kwargs))
63
+ self.register_buffer('zero_point', torch.tensor(zero_point, **factory_kwargs))
64
+
65
+ def forward(self, input):
66
+ return torch.ops.quantized.hardswish(input, self.scale, self.zero_point)
67
+
68
+ def _get_name(self):
69
+ return 'QuantizedHardswish'
70
+
71
+ @staticmethod
72
+ def from_float(mod):
73
+ scale, zero_point = mod.activation_post_process.calculate_qparams()
74
+ return Hardswish(float(scale), int(zero_point))
75
+
76
+ @classmethod
77
+ def from_reference(cls, mod, scale, zero_point):
78
+ return cls(float(scale), int(zero_point))
79
+
80
+ class ELU(torch.nn.ELU):
81
+ r"""This is the quantized equivalent of :class:`~torch.nn.ELU`.
82
+
83
+ Args:
84
+ scale: quantization scale of the output tensor
85
+ zero_point: quantization zero point of the output tensor
86
+ alpha: the alpha constant
87
+ """
88
+ def __init__(self, scale, zero_point, alpha=1.):
89
+ super().__init__(alpha)
90
+ self.scale = scale
91
+ self.zero_point = zero_point
92
+
93
+ def forward(self, input):
94
+ return torch.ao.nn.quantized.functional.elu(
95
+ input, self.scale, self.zero_point, self.alpha)
96
+
97
+ def _get_name(self):
98
+ return 'QuantizedELU'
99
+
100
+ @staticmethod
101
+ def from_float(mod):
102
+ scale, zero_point = mod.activation_post_process.calculate_qparams()
103
+ return ELU(float(scale), int(zero_point), mod.alpha)
104
+
105
+ @classmethod
106
+ def from_reference(cls, mod, scale, zero_point):
107
+ return cls(float(scale), int(zero_point), mod.alpha)
108
+
109
+ class LeakyReLU(torch.nn.LeakyReLU):
110
+ r"""This is the quantized equivalent of :class:`~torch.nn.LeakyReLU`.
111
+
112
+ Args:
113
+ scale: quantization scale of the output tensor
114
+ zero_point: quantization zero point of the output tensor
115
+ negative_slope: Controls the angle of the negative slope. Default: 1e-2
116
+ """
117
+ def __init__(self, scale: float, zero_point: int, negative_slope: float = 1e-2,
118
+ inplace: bool = False, device=None, dtype=None) -> None:
119
+ factory_kwargs = {'device': device, 'dtype': dtype}
120
+ super().__init__(negative_slope, inplace)
121
+ self.register_buffer('scale', torch.tensor(scale, **factory_kwargs))
122
+ self.register_buffer('zero_point', torch.tensor(zero_point, **factory_kwargs))
123
+
124
+ def forward(self, input):
125
+ return torch.ops.quantized.leaky_relu(
126
+ input, self.negative_slope, self.inplace, self.scale, self.zero_point)
127
+
128
+ def _get_name(self):
129
+ return 'QuantizedLeakyReLU'
130
+
131
+ @classmethod
132
+ def from_float(cls, mod):
133
+ scale, zero_point = mod.activation_post_process.calculate_qparams()
134
+ return cls(float(scale), int(zero_point), mod.negative_slope, mod.inplace)
135
+
136
+ @classmethod
137
+ def from_reference(cls, mod, scale, zero_point):
138
+ return cls(float(scale), int(zero_point), mod.negative_slope, mod.inplace)
139
+
140
+ class Sigmoid(torch.nn.Sigmoid):
141
+ r"""This is the quantized equivalent of :class:`~torch.nn.Sigmoid`.
142
+
143
+ Args:
144
+ scale: quantization scale of the output tensor
145
+ zero_point: quantization zero point of the output tensor
146
+ """
147
+
148
+ def __init__(self, output_scale: float, output_zero_point: int):
149
+ super().__init__()
150
+ self.output_scale = output_scale
151
+ self.output_zero_point = output_zero_point
152
+
153
+ def forward(self, input):
154
+ return torch.ops.quantized.sigmoid(input, self.output_scale, self.output_zero_point)
155
+
156
+ @classmethod
157
+ def from_float(cls, mod):
158
+ output_scale, output_zero_point = mod.activation_post_process.calculate_qparams()
159
+ return cls(float(output_scale), int(output_zero_point))
160
+
161
+ class Softmax(torch.nn.Softmax):
162
+ r"""This is the quantized version of :class:`~torch.nn.Softmax`.
163
+
164
+ Args:
165
+ dim: A dimension along which Softmax will be computed (so every slice along dim will sum to 1).
166
+ scale: quantization scale of the output tensor
167
+ zero_point: quantization zero point of the output tensor
168
+ """
169
+ def __init__(self, dim=None, scale=1.0, zero_point=0):
170
+ super().__init__()
171
+ self.dim = dim
172
+ self.scale = scale
173
+ self.zero_point = zero_point
174
+
175
+ def forward(self, input):
176
+ dim = self.dim
177
+ if dim is None:
178
+ stacklevel = 3
179
+ # Note: adding the mypy ignore on _get_softmax_dim seems less bad
180
+ # than making `_get_softmax_dim` an official API.
181
+ dim = torch.nn.functional._get_softmax_dim( # type: ignore[attr-defined]
182
+ "softmax", input.dim(), stacklevel)
183
+ return torch.ops.quantized.softmax(
184
+ input, dim, self.scale, self.zero_point)
185
+
186
+ def _get_name(self):
187
+ return 'QuantizedSoftmax'
188
+
189
+ @staticmethod
190
+ def from_float(mod):
191
+ scale, zero_point = mod.activation_post_process.calculate_qparams()
192
+ return Softmax(mod.dim, float(scale), int(zero_point))
193
+
194
+ @classmethod
195
+ def from_reference(cls, mod, scale, zero_point):
196
+ return cls(mod.dim, float(scale), int(zero_point))
197
+
198
+
199
+ class MultiheadAttention(torch.ao.nn.quantizable.MultiheadAttention):
200
+ _FLOAT_MODULE = torch.ao.nn.quantizable.MultiheadAttention
201
+
202
+ def _get_name(self):
203
+ return "QuantizedMultiheadAttention"
204
+
205
+ @classmethod
206
+ def from_float(cls, other):
207
+ # The whole flow is float -> observed -> quantized
208
+ # This class does observed -> quantized only
209
+ raise NotImplementedError("It looks like you are trying to convert a "
210
+ "non-observed MHA module. Please, see "
211
+ "the examples on quantizable MHAs.")
212
+
213
+ @classmethod
214
+ def from_observed(cls, other):
215
+ converted = torch.ao.quantization.convert(other, mapping=None,
216
+ inplace=False,
217
+ remove_qconfig=True,
218
+ convert_custom_config_dict=None)
219
+ converted.__class__ = cls
220
+ # Remove the parameters for the bias_k and bias_v to quantize them
221
+ # TODO: This is a potential source of accuracy drop.
222
+ # quantized cat takes the scale and zp of the first
223
+ # element, which might lose the precision in the bias_k
224
+ # and the bias_v (which are cat'ed with k/v being first).
225
+ if converted.bias_k is not None:
226
+ bias_k = converted._parameters.pop('bias_k')
227
+ sc, zp = torch._choose_qparams_per_tensor(bias_k,
228
+ reduce_range=False)
229
+ bias_k = torch.quantize_per_tensor(bias_k, sc, zp, torch.quint8)
230
+ setattr(converted, 'bias_k', bias_k) # noqa: B010
231
+
232
+ if converted.bias_v is not None:
233
+ bias_v = converted._parameters.pop('bias_v')
234
+ sc, zp = torch._choose_qparams_per_tensor(bias_k, # type: ignore[possibly-undefined]
235
+ reduce_range=False)
236
+ bias_v = torch.quantize_per_tensor(bias_v, sc, zp, torch.quint8)
237
+ setattr(converted, 'bias_v', bias_v) # noqa: B010
238
+
239
+ del converted.in_proj_weight
240
+ del converted.in_proj_bias
241
+
242
+ return converted
243
+
244
+ class PReLU(torch.nn.Module):
245
+ r"""This is the quantized equivalent of :class:`~torch.nn.PReLU`.
246
+
247
+ Args:
248
+ scale: quantization scale of the output tensor
249
+ zero_point: quantization zero point of the output tensor
250
+ num_parameters: number of parameters: 1, or the number of channels at input. Default: 1
251
+ """
252
+ def __init__(self, output_scale: float, output_zero_point: int,
253
+ num_parameters: int = 1) -> None:
254
+ super().__init__()
255
+ self.num_parameters = num_parameters
256
+ self.scale = output_scale
257
+ self.zero_point = output_zero_point
258
+ w = torch.randn(num_parameters, dtype=torch.float)
259
+ qw = torch.quantize_per_tensor(w, scale=1.0, zero_point=0, dtype=torch.quint8)
260
+ self.set_weight(qw)
261
+
262
+ def set_weight(self, w: torch.Tensor) -> None:
263
+ self.weight = w
264
+
265
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
266
+ return torch.ops.quantized.prelu(input, self.weight, self.scale, self.zero_point)
267
+
268
+ def _get_name(self):
269
+ return 'QuantizedPReLU'
270
+
271
+ @classmethod
272
+ def from_float(cls, mod):
273
+ scale, zero_point = mod.activation_post_process.calculate_qparams()
274
+ qprelu = cls(float(scale), int(zero_point), mod.num_parameters)
275
+ float_wt = mod.weight.float()
276
+ observer = mod.qconfig.weight()
277
+ observer(float_wt)
278
+ if observer.dtype != torch.quint8:
279
+ warn(
280
+ f"PReLU's weight observer should have dtype quint8 but got {observer.dtype}"
281
+ )
282
+ wt_scale, wt_zp = observer.calculate_qparams()
283
+ qweight = torch.quantize_per_tensor(
284
+ float_wt, float(wt_scale), int(wt_zp), torch.quint8)
285
+ qprelu.set_weight(qweight)
286
+ return qprelu
287
+
288
+ @classmethod
289
+ def from_reference(cls, mod, scale, zero_point):
290
+ qprelu = cls(float(scale), int(zero_point), mod.num_parameters)
291
+ float_wt = mod.weight.float()
292
+ observer = mod.qconfig.weight()
293
+ observer(float_wt)
294
+ if observer.dtype != torch.quint8:
295
+ warn(
296
+ f"PReLU's weight observer should have dtype quint8 but got {observer.dtype}"
297
+ )
298
+ wt_scale, wt_zp = observer.calculate_qparams()
299
+ qweight = torch.quantize_per_tensor(
300
+ float_wt, float(wt_scale), int(wt_zp), torch.quint8)
301
+ qprelu.set_weight(qweight)
302
+ return qprelu
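Unlike their float counterparts, these activation modules carry output scale/zero_point and operate directly on quantized tensors. A small usage sketch, with made-up quantization parameters:

import torch

x = torch.randn(4, 8)
qx = torch.quantize_per_tensor(x, scale=0.05, zero_point=128, dtype=torch.quint8)

# the output qparams are part of the module, mirroring the classes defined above
leaky = torch.ao.nn.quantized.LeakyReLU(scale=0.05, zero_point=128, negative_slope=0.01)
qy = leaky(qx)          # stays quantized end to end
y = qy.dequantize()     # back to float for inspection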
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/conv.py ADDED
@@ -0,0 +1,945 @@
1
+ r"""Quantized convolution modules."""
2
+
3
+ from typing import Optional, List, TypeVar
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ import torch.ao.nn.intrinsic as nni
9
+ import torch.ao.nn.intrinsic.qat as nniqat
10
+
11
+ from torch._ops import ops
12
+ from torch.nn.common_types import _size_1_t
13
+ from torch.nn.modules.utils import _single, _pair, _triple
14
+ from torch.nn.utils import fuse_conv_bn_weights
15
+
16
+ from .utils import _quantize_weight, WeightedQuantizedModule
17
+
18
+ __all__ = ['Conv1d', 'Conv2d', 'Conv3d', 'ConvTranspose1d', 'ConvTranspose2d', 'ConvTranspose3d']
19
+
20
+ _SUPPORTED_PADDING = {
21
+ 'zeros',
22
+ 'reflect'
23
+ }
24
+
25
+
26
+ def _reverse_repeat_padding(padding: List[int]) -> List[int]:
27
+ _reversed_padding_repeated_twice: List[int] = []
28
+ N = len(padding)
29
+ for idx in range(N):
30
+ for _ in range(2):
31
+ _reversed_padding_repeated_twice.append(padding[N - idx - 1])
32
+ return _reversed_padding_repeated_twice
33
+
34
+
35
+ class _ConvNd(WeightedQuantizedModule):
36
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1,
37
+ padding=0, dilation=1, groups=1, bias=True,
38
+ padding_mode='zeros', device=None, dtype=None):
39
+ # All subclasses have this signature - See PR #49702
40
+ raise NotImplementedError
41
+
42
+ def _init(self, in_channels, out_channels, kernel_size, stride,
43
+ padding, dilation,
44
+ transposed, output_padding,
45
+ groups, bias,
46
+ padding_mode='zeros',
47
+ device=None,
48
+ dtype=None) -> None:
49
+ factory_kwargs = {'device': device, 'dtype': dtype}
50
+ super().__init__()
51
+
52
+ if in_channels % groups != 0:
53
+ raise ValueError('in_channels must be divisible by groups')
54
+ if out_channels % groups != 0:
55
+ raise ValueError('out_channels must be divisible by groups')
56
+ self.in_channels = in_channels
57
+ self.out_channels = out_channels
58
+ self.kernel_size = kernel_size
59
+ self.stride = stride
60
+ self.padding = padding
61
+ self.dilation = dilation
62
+ self.transposed = transposed
63
+ self.output_padding = output_padding
64
+ self.groups = groups
65
+ if padding_mode not in _SUPPORTED_PADDING:
66
+ raise ValueError(f"'padding_mode' {padding_mode} is not supported by quantized convolution")
67
+ self.padding_mode = padding_mode
68
+ # Initialize as NCHW. set_weight will internally transpose to NHWC.
69
+ if self.transposed:
70
+ weight_shape = [in_channels, out_channels // self.groups]
71
+ else:
72
+ weight_shape = [out_channels, in_channels // self.groups]
73
+ qweight = torch._empty_affine_quantized(
74
+ weight_shape + list(kernel_size),
75
+ scale=1, zero_point=0, dtype=torch.qint8,
76
+ **{k: v for k, v in factory_kwargs.items() if k != 'dtype'})
77
+ bias_float = (
78
+ torch.zeros(out_channels, dtype=torch.float,
79
+ **{k: v for k, v in factory_kwargs.items() if k != 'dtype'}) if bias else None)
80
+
81
+ self.set_weight_bias(qweight, bias_float)
82
+ self.scale = 1.0
83
+ self.zero_point = 0
84
+
85
+ def set_weight_bias(self, qweight, bias_float):
86
+ raise NotImplementedError
87
+
88
+ def bias(self):
89
+ raise NotImplementedError
90
+
91
+ def _weight_bias(self):
92
+ raise NotImplementedError
93
+
94
+ def extra_repr(self):
95
+ s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}'
96
+ ', stride={stride}, scale={scale}, zero_point={zero_point}')
97
+ if self.padding != (0,) * len(self.padding):
98
+ s += ', padding={padding}'
99
+ if self.dilation != (1,) * len(self.dilation):
100
+ s += ', dilation={dilation}'
101
+ if self.output_padding != (0,) * len(self.output_padding):
102
+ s += ', output_padding={output_padding}'
103
+ if self.groups != 1:
104
+ s += ', groups={groups}'
105
+ if self.bias() is None:
106
+ s += ', bias=False'
107
+ return s.format(**self.__dict__)
108
+
109
+ # ===== Serialization methods =====
110
+ # The special consideration here is that we have to unpack the weights into
111
+ # their regular QTensor form for serialization. Packed weights should not
112
+ # live outside the process in which they were created, rather they should be
113
+ # derived from the QTensor weight.
114
+ # self
115
+ # |--- weight : Tensor
116
+ # |--- bias : Tensor
117
+ #
118
+ # TODO: maybe change to this when https://github.com/pytorch/pytorch/pull/32958 is landed
119
+ # self
120
+ # |--- _packed_params : Conv2dPackedParamsBase or Conv3dPackedParamsBase
121
+ def _save_to_state_dict(self, destination, prefix, keep_vars):
122
+ super()._save_to_state_dict(destination, prefix, keep_vars)
123
+ (w, b) = self._weight_bias()
124
+ destination[prefix + 'weight'] = w
125
+ destination[prefix + 'bias'] = b
126
+ destination[prefix + 'scale'] = torch.tensor(self.scale)
127
+ destination[prefix + 'zero_point'] = torch.tensor(self.zero_point)
128
+
129
+ @torch.jit.export
130
+ def __getstate__(self):
131
+ (w, b) = self._weight_bias()
132
+ return (
133
+ self.in_channels,
134
+ self.out_channels,
135
+ self.kernel_size,
136
+ self.stride,
137
+ self.padding,
138
+ self.dilation,
139
+ self.transposed,
140
+ self.output_padding,
141
+ self.groups,
142
+ self.padding_mode,
143
+ w,
144
+ b,
145
+ self.scale,
146
+ self.zero_point,
147
+ self.training
148
+ )
149
+
150
+ # ===== Deserialization methods =====
151
+ # Counterpart to the serialization methods, we must pack the serialized
152
+ # QTensor weight into its packed format for use by the FBGEMM ops.
153
+ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
154
+ missing_keys, unexpected_keys, error_msgs):
155
+ self.set_weight_bias(
156
+ state_dict[prefix + 'weight'], state_dict[prefix + 'bias'])
157
+ state_dict.pop(prefix + 'weight')
158
+ state_dict.pop(prefix + 'bias')
159
+ self.scale = float(state_dict[prefix + 'scale'])
160
+ state_dict.pop(prefix + 'scale')
161
+ self.zero_point = int(state_dict[prefix + 'zero_point'])
162
+ state_dict.pop(prefix + 'zero_point')
163
+ super()._load_from_state_dict(
164
+ state_dict, prefix, local_metadata, False, missing_keys,
165
+ unexpected_keys, error_msgs)
166
+
167
+ @torch.jit.export
168
+ def __setstate__(self, state):
169
+ self.in_channels = state[0]
170
+ self.out_channels = state[1]
171
+ self.kernel_size = state[2]
172
+ self.stride = state[3]
173
+ self.padding = state[4]
174
+ self.dilation = state[5]
175
+ self.transposed = state[6]
176
+ self.output_padding = state[7]
177
+ self.groups = state[8]
178
+ self.padding_mode = state[9]
179
+ self.set_weight_bias(state[10], state[11])
180
+ self.scale = state[12]
181
+ self.zero_point = state[13]
182
+ self.training = state[14]
183
+
184
+ def __deepcopy__(self, memo):
185
+ new_instance = type(self).__new__(type(self))
186
+ torch.nn.Module.__init__(new_instance)
187
+ state = self.__getstate__()
188
+ new_instance.__setstate__(state)
189
+ return new_instance
190
+
191
+ def __copy__(self):
192
+ return self.__deepcopy__({})
193
+
194
+ @classmethod
195
+ def get_qconv(cls, mod, activation_post_process, weight_post_process=None):
196
+ r"""Creates a qconv object and returns it.
197
+ """
198
+ if weight_post_process is None:
199
+ weight_post_process = mod.qconfig.weight()
200
+ weight_post_process(mod.weight)
201
+ assert weight_post_process.dtype == torch.qint8, \
202
+ 'Weight observer must have a dtype of qint8'
203
+ qweight = _quantize_weight(mod.weight.float(), weight_post_process)
204
+ # the __init__ call used is the one from derived classes and not the one from _ConvNd
205
+ qconv = cls(mod.in_channels, mod.out_channels, mod.kernel_size,
206
+ mod.stride, mod.padding, mod.dilation, mod.groups,
207
+ mod.bias is not None, mod.padding_mode)
208
+ qconv.set_weight_bias(qweight, mod.bias)
209
+ if activation_post_process is None or activation_post_process.dtype == torch.float:
210
+ return qconv # dynamic quantization doesn't need scale/zero_point
211
+ else:
212
+ act_scale, act_zp = activation_post_process.calculate_qparams()
213
+ qconv.scale = float(act_scale)
214
+ qconv.zero_point = int(act_zp)
215
+ return qconv
216
+
217
+ @staticmethod
218
+ def from_float(cls, mod):
219
+ if hasattr(mod, "weight_fake_quant"):
220
+ # assert type(mod) == cls.__QAT_MODULE, " nnq." + cls.__name__ + \
221
+ # ".from_float only works for " + cls.__QAT_MODULE.__name__
222
+ if type(mod) == cls._NNIQAT_CONV_BN_MODULE:
223
+ mod.weight, mod.bias = fuse_conv_bn_weights(
224
+ mod.weight, mod.bias, mod.bn.running_mean, mod.bn.running_var,
225
+ mod.bn.eps, mod.bn.weight, mod.bn.bias)
226
+ assert hasattr(mod, "activation_post_process"), \
227
+ "Input QAT module must have observer attached"
228
+ weight_post_process = mod.weight_fake_quant
229
+ activation_post_process = mod.activation_post_process
230
+ else:
231
+ assert type(mod) == cls._FLOAT_MODULE, \
232
+ " nnq." + cls.__name__ + ".from_float only works for " + \
233
+ cls._FLOAT_MODULE.__name__ + " but got:" + str(type(mod))
234
+ assert hasattr(mod, "qconfig"), \
235
+ "Input float module must have qconfig defined."
236
+ activation_post_process = None if not hasattr(
237
+ mod, "activation_post_process") else mod.activation_post_process
238
+ if type(mod) in [cls._NNI_CONV_RELU_MODULE, cls._NNI_CONV_ADD_MODULE, cls._NNI_CONV_ADD_RELU_MODULE]:
239
+ mod = mod[0]
240
+ weight_post_process = mod.qconfig.weight()
241
+ return cls.get_qconv(mod, activation_post_process, weight_post_process)
242
+
243
+ @classmethod
244
+ def from_reference(cls, ref_qconv, output_scale, output_zero_point):
245
+ r"""Create a (fbgemm/qnnpack) quantized module from a reference quantized module
246
+ Args:
247
+ ref_qconv (Module): a reference quantized module, either produced by torch.ao.quantization
248
+ utilities or provided by the user
249
+ output_scale (float): scale for output Tensor
250
+ output_zero_point (int): zero point for output Tensor
251
+ """
252
+ qconv = cls(
253
+ ref_qconv.in_channels,
254
+ ref_qconv.out_channels,
255
+ ref_qconv.kernel_size, # type: ignore[arg-type]
256
+ ref_qconv.stride, # type: ignore[arg-type]
257
+ ref_qconv.padding, # type: ignore[arg-type]
258
+ ref_qconv.dilation, # type: ignore[arg-type]
259
+ ref_qconv.groups,
260
+ ref_qconv.bias is not None, # type: ignore[arg-type]
261
+ ref_qconv.padding_mode,
262
+ device=ref_qconv.weight.device,
263
+ dtype=ref_qconv.weight.dtype)
264
+ qweight = ref_qconv.get_quantized_weight()
265
+ qconv.set_weight_bias(qweight, ref_qconv.bias)
266
+ qconv.scale = float(output_scale)
267
+ qconv.zero_point = int(output_zero_point)
268
+ return qconv
269
+
270
+
271
+ class Conv1d(_ConvNd):
272
+ r"""Applies a 1D convolution over a quantized input signal composed of
273
+ several quantized input planes.
274
+
275
+ For details on input arguments, parameters, and implementation see
276
+ :class:`~torch.nn.Conv1d`.
277
+
278
+ .. note::
279
+ Only `zeros` is supported for the :attr:`padding_mode` argument.
280
+
281
+ .. note::
282
+ Only `torch.quint8` is supported for the input data type.
283
+
284
+
285
+ Attributes:
286
+ weight (Tensor): packed tensor derived from the learnable weight
287
+ parameter.
288
+ scale (Tensor): scalar for the output scale
289
+ zero_point (Tensor): scalar for the output zero point
290
+
291
+ See :class:`~torch.nn.Conv1d` for other attributes.
292
+
293
+ Examples::
294
+
295
+ >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE)
296
+ >>> m = nn.quantized.Conv1d(16, 33, 3, stride=2)
297
+ >>> input = torch.randn(20, 16, 100)
298
+ >>> # quantize input to quint8
299
+ >>> # xdoctest: +SKIP
300
+ >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0,
301
+ ... dtype=torch.quint8)
302
+ >>> output = m(q_input)
303
+
304
+ """
305
+
306
+ _FLOAT_MODULE = nn.Conv1d
307
+ _NNIQAT_CONV_BN_MODULE = nniqat.ConvBn1d
308
+ _NNI_CONV_RELU_MODULE = nni.ConvReLU1d
309
+ _NNI_CONV_ADD_MODULE: None = None
310
+ _NNI_CONV_ADD_RELU_MODULE: None = None
311
+
312
+ def __init__(self,
313
+ in_channels: int,
314
+ out_channels: int,
315
+ kernel_size: _size_1_t,
316
+ stride: _size_1_t = 1,
317
+ padding: _size_1_t = 0,
318
+ dilation: _size_1_t = 1,
319
+ groups: int = 1,
320
+ bias: bool = True,
321
+ padding_mode: str = 'zeros',
322
+ device=None,
323
+ dtype=None):
324
+ factory_kwargs = {'device': device, 'dtype': dtype}
325
+ kernel_size = _single(kernel_size)
326
+ stride = _single(stride)
327
+ padding = padding if isinstance(padding, str) else _single(padding)
328
+ dilation = _single(dilation)
329
+
330
+ # Subclasses of _ConvNd need to call _init rather than __init__. See
331
+ # discussion on PR #49702
332
+ super()._init(
333
+ in_channels, out_channels, kernel_size, stride, padding, dilation,
334
+ False, _single(0), groups, bias, padding_mode, **factory_kwargs)
335
+
336
+ def _get_name(self):
337
+ return 'QuantizedConv1d'
338
+
339
+ def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None:
340
+ if self.padding_mode == 'zeros':
341
+ self._packed_params = torch.ops.quantized.conv1d_prepack(
342
+ w, b, self.stride, self.padding, self.dilation, self.groups)
343
+ else:
344
+ self._packed_params = torch.ops.quantized.conv1d_prepack(
345
+ w, b, self.stride, _pair(0), self.dilation,
346
+ self.groups)
347
+
348
+ def _weight_bias(self):
349
+ w, b = torch.ops.quantized.conv1d_unpack(self._packed_params)
350
+ return w, b
351
+
352
+ def weight(self):
353
+ return self._weight_bias()[0]
354
+
355
+ def bias(self):
356
+ return self._weight_bias()[1]
357
+
358
+ def forward(self, input):
359
+ # Temporarily using len(shape) instead of ndim due to JIT issue
360
+ # https://github.com/pytorch/pytorch/issues/23890
361
+ if len(input.shape) != 3:
362
+ raise ValueError("Input shape must be `(N, C, L)`!")
363
+ if self.padding_mode != 'zeros':
364
+ # Padding in Conv1d is stored as (p, p), need to get (p,)
365
+ _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding[:1])
366
+ input = F.pad(input, _reversed_padding_repeated_twice,
367
+ mode=self.padding_mode)
368
+ return ops.quantized.conv1d(input, self._packed_params, self.scale, self.zero_point)
369
+
370
+ @classmethod
371
+ def from_float(cls, mod):
372
+ r"""Creates a quantized module from a float module or qparams_dict.
373
+
374
+ Args:
375
+ mod (Module): a float module, either produced by torch.ao.quantization
376
+ utilities or provided by the user
377
+ """
378
+ return _ConvNd.from_float(cls, mod)
379
+
380
+
381
+ class Conv2d(_ConvNd):
382
+ r"""Applies a 2D convolution over a quantized input signal composed of
383
+ several quantized input planes.
384
+
385
+ For details on input arguments, parameters, and implementation see
386
+ :class:`~torch.nn.Conv2d`.
387
+
388
+ .. note::
389
+ Only `zeros` is supported for the :attr:`padding_mode` argument.
390
+
391
+ .. note::
392
+ Only `torch.quint8` is supported for the input data type.
393
+
394
+
395
+ Attributes:
396
+ weight (Tensor): packed tensor derived from the learnable weight
397
+ parameter.
398
+ scale (Tensor): scalar for the output scale
399
+ zero_point (Tensor): scalar for the output zero point
400
+
401
+ See :class:`~torch.nn.Conv2d` for other attributes.
402
+
403
+ Examples::
404
+
405
+ >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE)
406
+ >>> # With square kernels and equal stride
407
+ >>> m = nn.quantized.Conv2d(16, 33, 3, stride=2)
408
+ >>> # non-square kernels and unequal stride and with padding
409
+ >>> m = nn.quantized.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2))
410
+ >>> # non-square kernels and unequal stride and with padding and dilation
411
+ >>> m = nn.quantized.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1))
412
+ >>> input = torch.randn(20, 16, 50, 100)
413
+ >>> # quantize input to quint8
414
+ >>> # xdoctest: +SKIP
415
+ >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8)
416
+ >>> output = m(q_input)
417
+
418
+ """
419
+ _FLOAT_MODULE = nn.Conv2d
420
+ _NNIQAT_CONV_BN_MODULE = nniqat.ConvBn2d
421
+ _NNI_CONV_RELU_MODULE = nni.ConvReLU2d
422
+ _NNI_CONV_ADD_MODULE = nni.ConvAdd2d
423
+ _NNI_CONV_ADD_RELU_MODULE = nni.ConvAddReLU2d
424
+
425
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1,
426
+ padding=0, dilation=1, groups=1, bias=True,
427
+ padding_mode='zeros', device=None, dtype=None):
428
+ factory_kwargs = {'device': device, 'dtype': dtype}
429
+ kernel_size = _pair(kernel_size)
430
+ stride = _pair(stride)
431
+ padding = _pair(padding)
432
+ dilation = _pair(dilation)
433
+ # Subclasses of _ConvNd need to call _init rather than __init__. See
434
+ # discussion on PR #49702
435
+ super()._init(
436
+ in_channels, out_channels, kernel_size, stride, padding, dilation,
437
+ False, _pair(0), groups, bias, padding_mode, **factory_kwargs)
438
+
439
+ def _get_name(self):
440
+ return 'QuantizedConv2d'
441
+
442
+ def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None:
443
+ if self.padding_mode == 'zeros':
444
+ self._packed_params = torch.ops.quantized.conv2d_prepack(
445
+ w, b, self.stride, self.padding, self.dilation, self.groups)
446
+ else:
447
+ self._packed_params = torch.ops.quantized.conv2d_prepack(
448
+ w, b, self.stride, _pair(0), self.dilation, self.groups)
449
+
450
+ def _weight_bias(self):
451
+ return self._packed_params.unpack()
452
+
453
+ def weight(self):
454
+ return self._weight_bias()[0]
455
+
456
+ def bias(self):
457
+ return self._weight_bias()[1]
458
+
459
+ def forward(self, input):
460
+ # Temporarily using len(shape) instead of ndim due to JIT issue
461
+ # https://github.com/pytorch/pytorch/issues/23890
462
+ if len(input.shape) != 4:
463
+ raise ValueError("Input shape must be `(N, C, H, W)`!")
464
+ if self.padding_mode != 'zeros':
465
+ _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding)
466
+ input = F.pad(input, _reversed_padding_repeated_twice,
467
+ mode=self.padding_mode)
468
+ return ops.quantized.conv2d(
469
+ input, self._packed_params, self.scale, self.zero_point)
470
+
471
+ @classmethod
472
+ def from_float(cls, mod):
473
+ r"""Creates a quantized module from a float module or qparams_dict.
474
+
475
+ Args:
476
+ mod (Module): a float module, either produced by torch.ao.quantization
477
+ utilities or provided by the user
478
+ """
479
+ return _ConvNd.from_float(cls, mod)
480
+
481
+
482
+ class Conv3d(_ConvNd):
483
+ r"""Applies a 3D convolution over a quantized input signal composed of
484
+ several quantized input planes.
485
+
486
+ For details on input arguments, parameters, and implementation see
487
+ :class:`~torch.nn.Conv3d`.
488
+
489
+ .. note::
490
+ Only `zeros` is supported for the :attr:`padding_mode` argument.
491
+
492
+ .. note::
493
+ Only `torch.quint8` is supported for the input data type.
494
+
495
+
496
+ Attributes:
497
+ weight (Tensor): packed tensor derived from the learnable weight
498
+ parameter.
499
+ scale (Tensor): scalar for the output scale
500
+ zero_point (Tensor): scalar for the output zero point
501
+
502
+ See :class:`~torch.nn.Conv3d` for other attributes.
503
+
504
+ Examples::
505
+
506
+ >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE)
507
+ >>> # With square kernels and equal stride
508
+ >>> m = nn.quantized.Conv3d(16, 33, 3, stride=2)
509
+ >>> # non-square kernels and unequal stride and with padding
510
+ >>> m = nn.quantized.Conv3d(16, 33, (3, 5, 5), stride=(1, 2, 2), padding=(1, 2, 2))
511
+ >>> # non-square kernels and unequal stride and with padding and dilation
512
+ >>> m = nn.quantized.Conv3d(16, 33, (3, 5, 5), stride=(1, 2, 2), padding=(1, 2, 2), dilation=(1, 2, 2))
513
+ >>> input = torch.randn(20, 16, 56, 56, 56)
514
+ >>> # quantize input to quint8
515
+ >>> # xdoctest: +SKIP
516
+ >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8)
517
+ >>> output = m(q_input)
518
+
519
+ """
520
+ _FLOAT_MODULE = nn.Conv3d
521
+ _NNIQAT_CONV_BN_MODULE = nniqat.ConvBn3d
522
+ _NNI_CONV_RELU_MODULE = nni.ConvReLU3d
523
+ _NNI_CONV_ADD_MODULE: None = None
524
+ _NNI_CONV_ADD_RELU_MODULE: None = None
525
+
526
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1,
527
+ padding=0, dilation=1, groups=1, bias=True,
528
+ padding_mode='zeros', device=None, dtype=None):
529
+ assert padding_mode != 'reflect', "Conv3d does not support reflection padding"
530
+ factory_kwargs = {'device': device, 'dtype': dtype}
531
+ kernel_size = _triple(kernel_size)
532
+ stride = _triple(stride)
533
+ padding = _triple(padding)
534
+ dilation = _triple(dilation)
535
+ # Subclasses of _ConvNd need to call _init rather than __init__. See
536
+ # discussion on PR #49702
537
+ super()._init(
538
+ in_channels, out_channels, kernel_size, stride, padding, dilation,
539
+ False, _triple(0), groups, bias, padding_mode, **factory_kwargs)
540
+
541
+ def _get_name(self):
542
+ return 'QuantizedConv3d'
543
+
544
+ def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None:
545
+ if self.padding_mode == 'zeros':
546
+ self._packed_params = torch.ops.quantized.conv3d_prepack(
547
+ w, b, self.stride, self.padding, self.dilation, self.groups)
548
+ else:
549
+ self._packed_params = torch.ops.quantized.conv3d_prepack(
550
+ w, b, self.stride, _triple(0), self.dilation, self.groups)
551
+
552
+ def _weight_bias(self):
553
+ return self._packed_params.unpack()
554
+
555
+ def weight(self):
556
+ return self._weight_bias()[0]
557
+
558
+ def bias(self):
559
+ return self._weight_bias()[1]
560
+
561
+ def forward(self, input):
562
+ # Temporarily using len(shape) instead of ndim due to JIT issue
563
+ # https://github.com/pytorch/pytorch/issues/23890
564
+ if len(input.shape) != 5:
565
+ raise ValueError("Input shape must be `(N, C, D, H, W)`!")
566
+ if self.padding_mode != 'zeros':
567
+ _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding)
568
+ input = F.pad(input, _reversed_padding_repeated_twice,
569
+ mode=self.padding_mode)
570
+ return ops.quantized.conv3d(
571
+ input, self._packed_params, self.scale, self.zero_point)
572
+
573
+ @classmethod
574
+ def from_float(cls, mod):
575
+ r"""Creates a quantized module from a float module or qparams_dict.
576
+
577
+ Args:
578
+ mod (Module): a float module, either produced by torch.ao.quantization
579
+ utilities or provided by the user
580
+ """
581
+ return _ConvNd.from_float(cls, mod)
582
+
583
+ # === Transposed Convolutions ===
584
+ MOD = TypeVar('MOD', bound=nn.modules.conv._ConvNd)
585
+
586
+
587
+ class _ConvTransposeNd(_ConvNd):
588
+
589
+ _FLOAT_MODULE = MOD
590
+
591
+ def __init__(self, in_channels, out_channels, kernel_size, stride,
592
+ padding, dilation, transposed, output_padding,
593
+ groups, bias, padding_mode, device=None, dtype=None):
594
+ if padding_mode != 'zeros':
595
+ raise ValueError(f'Only "zeros" padding mode is supported for {self.__class__.__name__}')
596
+ factory_kwargs = {'device': device, 'dtype': dtype}
597
+ # Subclasses of _ConvNd need to call _init rather than __init__. See
598
+ # discussion on PR #49702
599
+ super()._init(
600
+ in_channels, out_channels, kernel_size, stride,
601
+ padding, dilation, transposed, output_padding,
602
+ groups, bias, padding_mode, **factory_kwargs)
603
+
604
+ def _input_padding(self, kernel_size: List[int], dilation: List[int], padding: List[int]) -> List[int]:
605
+ res = torch.jit.annotate(List[int], [])
606
+ for kdx in range(len(kernel_size)):
607
+ pad = (dilation[kdx] * (kernel_size[kdx] - 1) - padding[kdx])
608
+ res.append(pad)
609
+ return res
610
+
611
+ @classmethod
612
+ def from_float(cls, mod):
613
+ r"""Creates a quantized module from a float module or qparams_dict.
614
+ Args:
615
+ mod (Module): a float module, either produced by torch.ao.quantization
616
+ utilities or provided by the user
617
+ """
618
+ # derived classes override cls._FLOAT_MODULE attribute
619
+ msg = ' nnq.' + cls.__name__ + '.from_float only works for ' + \
620
+ cls._FLOAT_MODULE.__name__ # type: ignore[attr-defined]
621
+ assert type(mod) == cls._FLOAT_MODULE, msg
622
+ assert hasattr(mod, 'qconfig'), \
623
+ 'Input float module must have qconfig defined.'
624
+ weight_post_process = mod.qconfig.weight()
625
+ weight_post_process(mod.weight)
626
+ assert weight_post_process.dtype == torch.qint8, \
627
+ 'Weight observer must have a dtype of qint8'
628
+ qweight = _quantize_weight(mod.weight.float(), weight_post_process)
629
+ # the __init__ call used is the one from derived classes and not the one from _ConvTransposeNd
630
+ qconv = cls(mod.in_channels, mod.out_channels, mod.kernel_size, # type: ignore[call-arg]
631
+ mod.stride, mod.padding, mod.output_padding, mod.groups,
632
+ mod.bias is not None, mod.dilation, mod.padding_mode)
633
+ qconv.set_weight_bias(qweight, mod.bias)
634
+ if not hasattr(mod, "activation_post_process") or mod.activation_post_process.dtype == torch.float:
635
+ return qconv # dynamic quantization doesn't need scale/zero_point
636
+ else:
637
+ act_scale, act_zp = mod.activation_post_process.calculate_qparams()
638
+ qconv.scale = float(act_scale)
639
+ qconv.zero_point = int(act_zp)
640
+ return qconv
641
+
642
+ @staticmethod
643
+ def from_reference(cls, ref_qconvt, output_scale, output_zero_point):
644
+ r"""Create a (fbgemm/qnnpack) quantized module from a reference quantized module
645
+ Args:
646
+ ref_qconvt (Module): a reference quantized module, either produced by torch.ao.quantization
647
+ utilities or provided by the user
648
+ output_scale (float): scale for output Tensor
649
+ output_zero_point (int): zero point for output Tensor
650
+ """
651
+ qconv = cls(
652
+ ref_qconvt.in_channels,
653
+ ref_qconvt.out_channels,
654
+ ref_qconvt.kernel_size, # type: ignore[arg-type]
655
+ ref_qconvt.stride, # type: ignore[arg-type]
656
+ ref_qconvt.padding, # type: ignore[arg-type]
657
+ ref_qconvt.output_padding, # type: ignore[arg-type]
658
+ ref_qconvt.groups,
659
+ ref_qconvt.bias is not None, # type: ignore[arg-type]
660
+ ref_qconvt.dilation, # type: ignore[arg-type]
661
+ ref_qconvt.padding_mode,
662
+ device=ref_qconvt.weight.device,
663
+ dtype=ref_qconvt.weight.dtype)
664
+ qweight = ref_qconvt.get_quantized_weight()
665
+ qconv.set_weight_bias(qweight, ref_qconvt.bias)
666
+ qconv.scale = float(output_scale)
667
+ qconv.zero_point = int(output_zero_point)
668
+ return qconv
669
+
670
+
671
+ class ConvTranspose1d(_ConvTransposeNd):
672
+ r"""Applies a 1D transposed convolution operator over an input image
673
+ composed of several input planes.
674
+ For details on input arguments, parameters, and implementation see
675
+ :class:`~torch.nn.ConvTranspose1d`.
676
+
677
+ .. note:: Currently only the QNNPACK engine is implemented.
678
+ Please set `torch.backends.quantized.engine = 'qnnpack'`
679
+
680
+ For special notes, please see :class:`~torch.ao.nn.quantized.Conv1d`
681
+
682
+ Attributes:
683
+ weight (Tensor): packed tensor derived from the learnable weight
684
+ parameter.
685
+ scale (Tensor): scalar for the output scale
686
+ zero_point (Tensor): scalar for the output zero point
687
+ See :class:`~torch.nn.ConvTranspose1d` for other attributes.
688
+
689
+ Examples::
690
+
691
+ >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE)
692
+ >>> torch.backends.quantized.engine = 'qnnpack'
693
+ >>> from torch.ao.nn import quantized as nnq
694
+ >>> # With square kernels and equal stride
695
+ >>> m = nnq.ConvTranspose1d(16, 33, 3, stride=2)
696
+ >>> # 1D kernel with stride and padding
697
+ >>> m = nnq.ConvTranspose1d(16, 33, 3, stride=2, padding=4)
698
+ >>> input = torch.randn(20, 16, 50)
699
+ >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8)
700
+ >>> output = m(q_input)
701
+ >>> # exact output size can be also specified as an argument
702
+ >>> input = torch.randn(1, 16, 12)
703
+ >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8)
704
+ >>> downsample = nnq.Conv1d(16, 16, 3, stride=2, padding=1)
705
+ >>> upsample = nnq.ConvTranspose1d(16, 16, 3, stride=2, padding=1)
706
+ >>> h = downsample(q_input)
707
+ >>> h.size()
708
+ torch.Size([1, 16, 6])
709
+ >>> # xdoctest: +SKIP("FIXME: output_size is not a parameter")
710
+ >>> output = upsample(h, output_size=input.size())
711
+ >>> output.size()
712
+ torch.Size([1, 16, 12])
713
+ """
714
+
715
+ _FLOAT_MODULE = nn.ConvTranspose1d
716
+
717
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1,
718
+ padding=0, output_padding=0, groups=1, bias=True,
719
+ dilation=1, padding_mode='zeros', device=None, dtype=None):
720
+ factory_kwargs = {'device': device, 'dtype': dtype}
721
+ kernel_size = _single(kernel_size)
722
+ stride = _single(stride)
723
+ padding = _single(padding)
724
+ dilation = _single(dilation)
725
+ output_padding = _single(output_padding)
726
+
727
+ super().__init__(
728
+ in_channels, out_channels, kernel_size, stride, padding, dilation,
729
+ True, output_padding, groups, bias, padding_mode, **factory_kwargs)
730
+
731
+ def _get_name(self):
732
+ return 'QuantizedConvTranspose1d'
733
+
734
+ def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None:
735
+ self._packed_params = torch.ops.quantized.conv_transpose1d_prepack(
736
+ w, b, self.stride, self.padding, self.output_padding, self.dilation,
737
+ self.groups)
738
+
739
+ def _weight_bias(self):
740
+ w, b = torch.ops.quantized.conv_transpose1d_unpack(self._packed_params)
741
+ return w, b
742
+
743
+ def weight(self):
744
+ (w, _) = self._weight_bias()
745
+ return w
746
+
747
+ def bias(self):
748
+ (_, b) = self._weight_bias()
749
+ return b
750
+
751
+ def forward(self, input):
752
+ # Temporarily using len(shape) instead of ndim due to JIT issue
753
+ # https://github.com/pytorch/pytorch/issues/23890
754
+ if len(input.shape) != 3:
755
+ raise ValueError("Input shape must be `(N, C, L)`!")
756
+ return torch.ops.quantized.conv_transpose1d(
757
+ input, self._packed_params, self.scale, self.zero_point)
758
+
759
+ @classmethod
760
+ def from_reference(cls, ref_qconvt, output_scale, output_zero_point):
761
+ return _ConvTransposeNd.from_reference(cls, ref_qconvt, output_scale, output_zero_point)
762
+
763
+
764
+ class ConvTranspose2d(_ConvTransposeNd):
765
+ r"""Applies a 2D transposed convolution operator over an input image
766
+ composed of several input planes.
767
+ For details on input arguments, parameters, and implementation see
768
+ :class:`~torch.nn.ConvTranspose2d`.
769
+
770
+ For special notes, please see :class:`~torch.ao.nn.quantized.Conv2d`
771
+
772
+ Attributes:
773
+ weight (Tensor): packed tensor derived from the learnable weight
774
+ parameter.
775
+ scale (Tensor): scalar for the output scale
776
+ zero_point (Tensor): scalar for the output zero point
777
+ See :class:`~torch.nn.ConvTranspose2d` for other attributes.
778
+
779
+ Examples::
780
+
781
+ >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE)
782
+ >>> # QNNPACK or FBGEMM as backend
783
+ >>> torch.backends.quantized.engine = 'qnnpack'
784
+ >>> # With square kernels and equal stride
785
+ >>> import torch.ao.nn.quantized as nnq
786
+ >>> m = nnq.ConvTranspose2d(16, 33, 3, stride=2)
787
+ >>> # non-square kernels and unequal stride and with padding
788
+ >>> m = nnq.ConvTranspose2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2))
789
+ >>> input = torch.randn(20, 16, 50, 100)
790
+ >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8)
791
+ >>> output = m(q_input)
792
+ >>> # exact output size can be also specified as an argument
793
+ >>> input = torch.randn(1, 16, 12, 12)
794
+ >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8)
795
+ >>> downsample = nnq.Conv2d(16, 16, 3, stride=2, padding=1)
796
+ >>> upsample = nnq.ConvTranspose2d(16, 16, 3, stride=2, padding=1)
797
+ >>> h = downsample(q_input)
798
+ >>> h.size()
799
+ torch.Size([1, 16, 6, 6])
800
+ >>> # xdoctest: +SKIP("FIXME: output_size is not a parameter")
801
+ >>> output = upsample(h, output_size=input.size())
802
+ >>> output.size()
803
+ torch.Size([1, 16, 12, 12])
804
+ """
805
+
806
+ _FLOAT_MODULE = nn.ConvTranspose2d
807
+
808
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1,
809
+ padding=0, output_padding=0, groups=1, bias=True,
810
+ dilation=1, padding_mode='zeros', device=None, dtype=None):
811
+ factory_kwargs = {'device': device, 'dtype': dtype}
812
+ kernel_size = _pair(kernel_size)
813
+ stride = _pair(stride)
814
+ padding = _pair(padding)
815
+ dilation = _pair(dilation)
816
+ output_padding = _pair(output_padding)
817
+
818
+ super().__init__(
819
+ in_channels, out_channels, kernel_size, stride, padding, dilation,
820
+ True, output_padding, groups, bias, padding_mode, **factory_kwargs)
821
+
822
+ def _get_name(self):
823
+ return 'QuantizedConvTranspose2d'
824
+
825
+ def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None:
826
+ self._packed_params = torch.ops.quantized.conv_transpose2d_prepack(
827
+ w, b, self.stride, self.padding, self.output_padding, self.dilation,
828
+ self.groups)
829
+
830
+ def _weight_bias(self):
831
+ w, b = torch.ops.quantized.conv2d_unpack(self._packed_params)
832
+ return w, b
833
+
834
+ def weight(self):
835
+ (w, _) = self._weight_bias()
836
+ return w
837
+
838
+ def bias(self):
839
+ (_, b) = self._weight_bias()
840
+ return b
841
+
842
+ def forward(self, input):
843
+ # Temporarily using len(shape) instead of ndim due to JIT issue
844
+ # https://github.com/pytorch/pytorch/issues/23890
845
+ if len(input.shape) != 4:
846
+ raise ValueError("Input shape must be `(N, C, H, W)`!")
847
+ return ops.quantized.conv_transpose2d(
848
+ input, self._packed_params, self.scale, self.zero_point)
849
+
850
+ @classmethod
851
+ def from_reference(cls, ref_qconvt, output_scale, output_zero_point):
852
+ return _ConvTransposeNd.from_reference(cls, ref_qconvt, output_scale, output_zero_point)
853
+
854
+
855
+ class ConvTranspose3d(_ConvTransposeNd):
856
+ r"""Applies a 3D transposed convolution operator over an input image
857
+ composed of several input planes.
858
+ For details on input arguments, parameters, and implementation see
859
+ :class:`~torch.nn.ConvTranspose3d`.
860
+
861
+ .. note:: Currently only the FBGEMM engine is implemented.
862
+ Please set `torch.backends.quantized.engine = 'fbgemm'`
863
+
864
+ For special notes, please see :class:`~torch.ao.nn.quantized.Conv3d`
865
+
866
+ Attributes:
867
+ weight (Tensor): packed tensor derived from the learnable weight
868
+ parameter.
869
+ scale (Tensor): scalar for the output scale
870
+ zero_point (Tensor): scalar for the output zero point
871
+ See :class:`~torch.nn.ConvTranspose3d` for other attributes.
872
+
873
+ Examples::
874
+
875
+ >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE)
876
+ >>> torch.backends.quantized.engine = 'fbgemm'
877
+ >>> from torch.ao.nn import quantized as nnq
878
+ >>> # With cubic kernels and equal stride
879
+ >>> m = nnq.ConvTranspose3d(16, 33, 3, stride=2)
880
+ >>> # non-cubic kernels and unequal stride and with padding
881
+ >>> m = nnq.ConvTranspose3d(16, 33, (3, 3, 5), stride=(2, 1, 1), padding=(4, 2, 2))
882
+ >>> input = torch.randn(20, 16, 50, 100, 100)
883
+ >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8)
884
+ >>> output = m(q_input)
885
+ >>> # exact output size can be also specified as an argument
886
+ >>> input = torch.randn(1, 16, 12, 12, 12)
887
+ >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8)
888
+ >>> downsample = nnq.Conv3d(16, 16, 3, stride=2, padding=1)
889
+ >>> upsample = nnq.ConvTranspose3d(16, 16, 3, stride=2, padding=1)
890
+ >>> h = downsample(q_input)
891
+ >>> h.size()
892
+ torch.Size([1, 16, 6, 6, 6])
893
+ >>> # xdoctest: +SKIP("FIXME: output_size is not a parameter")
894
+ >>> output = upsample(h, output_size=input.size())
895
+ >>> output.size()
896
+ torch.Size([1, 16, 12, 12, 12])
897
+ """
898
+
899
+ _FLOAT_MODULE = nn.ConvTranspose3d
900
+
901
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1,
902
+ padding=0, output_padding=0, groups=1, bias=True,
903
+ dilation=1, padding_mode='zeros', device=None, dtype=None):
904
+ factory_kwargs = {'device': device, 'dtype': dtype}
905
+ kernel_size = _triple(kernel_size)
906
+ stride = _triple(stride)
907
+ padding = _triple(padding)
908
+ dilation = _triple(dilation)
909
+ output_padding = _triple(output_padding)
910
+
911
+ super().__init__(
912
+ in_channels, out_channels, kernel_size, stride, padding, dilation,
913
+ True, output_padding, groups, bias, padding_mode, **factory_kwargs)
914
+
915
+ def _get_name(self):
916
+ return 'QuantizedConvTranspose3d'
917
+
918
+ def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None:
919
+ self._packed_params = torch.ops.quantized.conv_transpose3d_prepack(
920
+ w, b, self.stride, self.padding, self.output_padding, self.dilation,
921
+ self.groups)
922
+
923
+ def _weight_bias(self):
924
+ w, b = torch.ops.quantized.conv3d_unpack(self._packed_params)
925
+ return w, b
926
+
927
+ def weight(self):
928
+ (w, _) = self._weight_bias()
929
+ return w
930
+
931
+ def bias(self):
932
+ (_, b) = self._weight_bias()
933
+ return b
934
+
935
+ def forward(self, input):
936
+ # Temporarily using len(shape) instead of ndim due to JIT issue
937
+ # https://github.com/pytorch/pytorch/issues/23890
938
+ if len(input.shape) != 5:
939
+ raise ValueError("Input shape must be `(N, C, T, H, W)`!")
940
+ return ops.quantized.conv_transpose3d(
941
+ input, self._packed_params, self.scale, self.zero_point)
942
+
943
+ @classmethod
944
+ def from_reference(cls, ref_qconvt, output_scale, output_zero_point):
945
+ return _ConvTransposeNd.from_reference(cls, ref_qconvt, output_scale, output_zero_point)
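
The `from_float` / `from_reference` constructors defined above are normally invoked for you by the quantization workflow rather than called by hand. A rough eager-mode sketch, assuming the stock `torch.ao.quantization` prepare/convert APIs and the fbgemm backend; the tiny model below is purely illustrative and not part of this file:

import torch
import torch.nn as nn
from torch.ao.quantization import QuantStub, DeQuantStub, get_default_qconfig, prepare, convert

class TinyConvNet(nn.Module):          # hypothetical example model
    def __init__(self):
        super().__init__()
        self.quant = QuantStub()       # marks where fp32 -> quint8 conversion happens
        self.conv = nn.Conv2d(3, 8, 3)
        self.dequant = DeQuantStub()   # marks where quint8 -> fp32 conversion happens

    def forward(self, x):
        return self.dequant(self.conv(self.quant(x)))

m = TinyConvNet().eval()
m.qconfig = get_default_qconfig("fbgemm")
prepare(m, inplace=True)               # attaches observers
m(torch.randn(1, 3, 32, 32))           # calibration pass records activation ranges
convert(m, inplace=True)               # calls Conv2d.from_float under the hood
print(type(m.conv))                    # the quantized Conv2d defined above
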
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/dropout.py ADDED
@@ -0,0 +1,27 @@
1
+ import torch
2
+
3
+ __all__ = ['Dropout']
4
+
5
+ class Dropout(torch.nn.Dropout):
6
+ r"""This is the quantized equivalent of :class:`~torch.nn.Dropout`.
7
+ It is a placeholder that allows models whose fp32 tensors went through
8
+ dropout to keep working with quantized tensors in both train and eval mode.
9
+
10
+ Args:
11
+ p: probability of an element to be zeroed
12
+ inplace: can optionally do the operation in-place. Default: ``False``
13
+ """
14
+
15
+ def forward(self, input):
16
+ return input
17
+
18
+ def _get_name(self):
19
+ return 'QuantizedDropout'
20
+
21
+ @classmethod
22
+ def from_float(cls, mod):
23
+ return cls(mod.p, mod.inplace)
24
+
25
+ @classmethod
26
+ def from_reference(cls, mod, scale, zero_point):
27
+ return cls(mod.p, mod.inplace)
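
Because the quantized `Dropout.forward` above simply returns its input, the module acts as an identity on quantized tensors. A minimal sketch, assuming the class is re-exported as `torch.ao.nn.quantized.Dropout` and using arbitrary qparams chosen only for illustration:

import torch
from torch.ao.nn.quantized import Dropout as QDropout   # assumed re-export of the class above

qdrop = QDropout(p=0.5)
x = torch.quantize_per_tensor(torch.randn(4, 8), scale=0.1, zero_point=0, dtype=torch.quint8)
y = qdrop(x)
assert torch.equal(y.int_repr(), x.int_repr())   # no elements are zeroed, even in train mode
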
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/functional_modules.py ADDED
@@ -0,0 +1,249 @@
1
+ from typing import List
2
+
3
+ import torch
4
+ from torch import Tensor
5
+ from torch._ops import ops
6
+
7
+ __all__ = ['FloatFunctional', 'FXFloatFunctional', 'QFunctional']
8
+
9
+ class FloatFunctional(torch.nn.Module):
10
+ r"""State collector class for float operations.
11
+
12
+ The instance of this class can be used instead of the ``torch.`` prefix for
13
+ some operations. See example usage below.
14
+
15
+ .. note::
16
+
17
+ This class does not provide a ``forward`` hook. Instead, you must use
18
+ one of the underlying functions (e.g. ``add``).
19
+
20
+ Examples::
21
+
22
+ >>> f_add = FloatFunctional()
23
+ >>> a = torch.tensor(3.0)
24
+ >>> b = torch.tensor(4.0)
25
+ >>> f_add.add(a, b) # Equivalent to ``torch.add(a, b)``
26
+
27
+ Valid operation names:
28
+ - add
29
+ - cat
30
+ - mul
31
+ - add_relu
32
+ - add_scalar
33
+ - mul_scalar
34
+ """
35
+ def __init__(self):
36
+ super().__init__()
37
+ self.activation_post_process = torch.nn.Identity()
38
+
39
+ def forward(self, x):
40
+ raise RuntimeError("FloatFunctional is not intended to use the " +
41
+ "'forward'. Please use the underlying operation")
42
+
43
+ r"""Operation equivalent to ``torch.add(Tensor, Tensor)``"""
44
+ def add(self, x: Tensor, y: Tensor) -> Tensor:
45
+ r = torch.add(x, y)
46
+ r = self.activation_post_process(r)
47
+ return r
48
+
49
+ r"""Operation equivalent to ``torch.add(Tensor, float)``"""
50
+ def add_scalar(self, x: Tensor, y: float) -> Tensor:
51
+ r = torch.add(x, y)
52
+ # Note: this operation is not observed because the observation is not
53
+ # needed for the quantized op.
54
+ return r
55
+
56
+ r"""Operation equivalent to ``torch.mul(Tensor, Tensor)``"""
57
+ def mul(self, x: Tensor, y: Tensor) -> Tensor:
58
+ r = torch.mul(x, y)
59
+ r = self.activation_post_process(r)
60
+ return r
61
+
62
+ r"""Operation equivalent to ``torch.mul(Tensor, float)``"""
63
+ def mul_scalar(self, x: Tensor, y: float) -> Tensor:
64
+ r = torch.mul(x, y)
65
+ # Note: this operation is not observed because the observation is not
66
+ # needed for the quantized op.
67
+ return r
68
+
69
+ r"""Operation equivalent to ``torch.cat``"""
70
+ def cat(self, x: List[Tensor], dim: int = 0) -> Tensor:
71
+ r = torch.cat(x, dim=dim)
72
+ r = self.activation_post_process(r)
73
+ return r
74
+
75
+ r"""Operation equivalent to ``relu(torch.add(x,y))``"""
76
+ def add_relu(self, x: Tensor, y: Tensor) -> Tensor:
77
+ r = torch.add(x, y)
78
+ r = torch.nn.functional.relu(r)
79
+ r = self.activation_post_process(r)
80
+ return r
81
+
82
+ r"""Operation equivalent to ``torch.matmul(Tensor, Tensor)``"""
83
+ def matmul(self, x: Tensor, y: Tensor) -> Tensor:
84
+ r = torch.matmul(x, y)
85
+ r = self.activation_post_process(r)
86
+ return r
87
+
88
+ class FXFloatFunctional(torch.nn.Module):
89
+ r"""Module used to replace FloatFunctional before FX graph mode quantization,
90
+ since activation_post_process will be inserted into the top level module directly
91
+
92
+ Valid operation names:
93
+ - add
94
+ - cat
95
+ - mul
96
+ - add_relu
97
+ - add_scalar
98
+ - mul_scalar
99
+ """
100
+ def forward(self, x):
101
+ raise RuntimeError("FloatFunctional is not intended to use the " +
102
+ "'forward'. Please use the underlying operation")
103
+
104
+ r"""Operation equivalent to ``torch.add(Tensor, Tensor)``"""
105
+ def add(self, x: Tensor, y: Tensor) -> Tensor:
106
+ r = torch.add(x, y)
107
+ return r
108
+
109
+ r"""Operation equivalent to ``torch.add(Tensor, float)``"""
110
+ def add_scalar(self, x: Tensor, y: float) -> Tensor:
111
+ r = torch.add(x, y)
112
+ return r
113
+
114
+ r"""Operation equivalent to ``torch.mul(Tensor, Tensor)``"""
115
+ def mul(self, x: Tensor, y: Tensor) -> Tensor:
116
+ r = torch.mul(x, y)
117
+ return r
118
+
119
+ r"""Operation equivalent to ``torch.mul(Tensor, float)``"""
120
+ def mul_scalar(self, x: Tensor, y: float) -> Tensor:
121
+ r = torch.mul(x, y)
122
+ return r
123
+
124
+ r"""Operation equivalent to ``torch.cat``"""
125
+ def cat(self, x: List[Tensor], dim: int = 0) -> Tensor:
126
+ r = torch.cat(x, dim=dim)
127
+ return r
128
+
129
+ r"""Operation equivalent to ``relu(torch.add(x,y))``"""
130
+ def add_relu(self, x: Tensor, y: Tensor) -> Tensor:
131
+ r = torch.add(x, y)
132
+ r = torch.nn.functional.relu(r)
133
+ return r
134
+
135
+ r"""Operation equivalent to ``torch.matmul(Tensor, Tensor)``"""
136
+ def matmul(self, x: Tensor, y: Tensor) -> Tensor:
137
+ r = torch.matmul(x, y)
138
+ return r
139
+
140
+ class QFunctional(torch.nn.Module):
141
+ r"""Wrapper class for quantized operations.
142
+
143
+ The instance of this class can be used instead of the
144
+ ``torch.ops.quantized`` prefix. See example usage below.
145
+
146
+ .. note::
147
+
148
+ This class does not provide a ``forward`` hook. Instead, you must use
149
+ one of the underlying functions (e.g. ``add``).
150
+
151
+ Examples::
152
+
153
+ >>> q_add = QFunctional()
154
+ >>> # xdoctest: +SKIP
155
+ >>> a = torch.quantize_per_tensor(torch.tensor(3.0), 1.0, 0, torch.qint32)
156
+ >>> b = torch.quantize_per_tensor(torch.tensor(4.0), 1.0, 0, torch.qint32)
157
+ >>> q_add.add(a, b) # Equivalent to ``torch.ops.quantized.add(a, b, 1.0, 0)``
158
+
159
+ Valid operation names:
160
+ - add
161
+ - cat
162
+ - mul
163
+ - add_relu
164
+ - add_scalar
165
+ - mul_scalar
166
+ """
167
+ def __init__(self):
168
+ super().__init__()
169
+ self.scale = 1.0
170
+ self.zero_point = 0
171
+ self.activation_post_process = torch.nn.Identity()
172
+
173
+ def _save_to_state_dict(self, destination, prefix, keep_vars):
174
+ super()._save_to_state_dict(destination, prefix, keep_vars)
175
+ destination[prefix + 'scale'] = torch.tensor(self.scale)
176
+ destination[prefix + 'zero_point'] = torch.tensor(self.zero_point)
177
+
178
+ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
179
+ missing_keys, unexpected_keys, error_msgs):
180
+
181
+ self.scale = float(state_dict.pop(prefix + 'scale'))
182
+ self.zero_point = int(state_dict.pop(prefix + 'zero_point'))
183
+ super()._load_from_state_dict(state_dict, prefix, local_metadata, False,
184
+ missing_keys, unexpected_keys, error_msgs)
185
+
186
+ def _get_name(self):
187
+ return 'QFunctional'
188
+
189
+ def extra_repr(self):
190
+ return f'scale={self.scale}, zero_point={self.zero_point}'
191
+
192
+ def forward(self, x):
193
+ raise RuntimeError("Functional is not intended to use the " +
194
+ "'forward'. Please use the underlying operation")
195
+
196
+ r"""Operation equivalent to ``torch.ops.quantized.add``"""
197
+ def add(self, x: Tensor, y: Tensor) -> Tensor:
198
+ r = ops.quantized.add(x, y, scale=self.scale, zero_point=self.zero_point)
199
+ r = self.activation_post_process(r)
200
+ return r
201
+
202
+ r"""Operation equivalent to ``torch.ops.quantized.add(Tensor, float)``"""
203
+ def add_scalar(self, x: Tensor, y: float) -> Tensor:
204
+ r = ops.quantized.add_scalar(x, y)
205
+ # Note: this operation is not observed because the observation is not
206
+ # needed for the quantized op.
207
+ return r
208
+
209
+ r"""Operation equivalent to ``torch.ops.quantized.mul(Tensor, Tensor)``"""
210
+ def mul(self, x: Tensor, y: Tensor) -> Tensor:
211
+ r = ops.quantized.mul(x, y, scale=self.scale, zero_point=self.zero_point)
212
+ r = self.activation_post_process(r)
213
+ return r
214
+
215
+ r"""Operation equivalent to ``torch.ops.quantized.mul(Tensor, float)``"""
216
+ def mul_scalar(self, x: Tensor, y: float) -> Tensor:
217
+ r = ops.quantized.mul_scalar(x, y)
218
+ # Note: this operation is not observed because the observation is not
219
+ # needed for the quantized op.
220
+ return r
221
+
222
+ r"""Operation equivalent to ``torch.ops.quantized.cat``"""
223
+ def cat(self, x: List[Tensor], dim: int = 0) -> Tensor:
224
+ r = ops.quantized.cat(x, scale=self.scale, zero_point=self.zero_point, dim=dim)
225
+ r = self.activation_post_process(r)
226
+ return r
227
+
228
+ r"""Operation equivalent to ``torch.ops.quantized.add_relu``"""
229
+ def add_relu(self, x: Tensor, y: Tensor) -> Tensor:
230
+ r = ops.quantized.add_relu(x, y, scale=self.scale, zero_point=self.zero_point)
231
+ r = self.activation_post_process(r)
232
+ return r
233
+
234
+ r"""Operation equivalent to ``torch.ops.quantized.matmul(Tensor, Tensor)``"""
235
+ def matmul(self, x: Tensor, y: Tensor) -> Tensor:
236
+ r = ops.quantized.matmul(x, y, scale=self.scale, zero_point=self.zero_point)
237
+ # Note: this operation is not observed because the observation is not
238
+ # needed for the quantized op.
239
+ return r
240
+
241
+ @classmethod
242
+ def from_float(cls, mod):
243
+ assert type(mod) == FloatFunctional, \
244
+ "QFunctional.from_float expects an instance of FloatFunctional"
245
+ scale, zero_point = mod.activation_post_process.calculate_qparams() # type: ignore[operator]
246
+ new_mod = QFunctional()
247
+ new_mod.scale = float(scale)
248
+ new_mod.zero_point = int(zero_point)
249
+ return new_mod
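
`QFunctional.from_float` expects the `FloatFunctional` instance to carry an observer as its `activation_post_process`; normally `torch.ao.quantization.prepare` attaches one. A hand-rolled sketch of that handoff, with the observer attached manually purely for illustration:

import torch
from torch.ao.nn.quantized import FloatFunctional, QFunctional
from torch.ao.quantization import MinMaxObserver

ff = FloatFunctional()
ff.activation_post_process = MinMaxObserver()    # prepare() would normally insert this
_ = ff.add(torch.randn(16), torch.randn(16))     # observer records the output range

qf = QFunctional.from_float(ff)                  # copies scale / zero_point from the observer
a = torch.quantize_per_tensor(torch.randn(16), qf.scale, qf.zero_point, torch.quint8)
b = torch.quantize_per_tensor(torch.randn(16), qf.scale, qf.zero_point, torch.quint8)
out = qf.add(a, b)                               # runs torch.ops.quantized.add
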
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/reference/modules/__init__.py ADDED
@@ -0,0 +1,21 @@
1
+ from .linear import Linear
2
+ from .conv import Conv1d, Conv2d, Conv3d, ConvTranspose1d, ConvTranspose2d, ConvTranspose3d
3
+ from .rnn import RNNCell, LSTMCell, GRUCell, LSTM, GRU
4
+ from .sparse import Embedding, EmbeddingBag
5
+
6
+ __all__ = [
7
+ 'Linear',
8
+ 'Conv1d',
9
+ 'Conv2d',
10
+ 'Conv3d',
11
+ 'ConvTranspose1d',
12
+ 'ConvTranspose2d',
13
+ 'ConvTranspose3d',
14
+ 'RNNCell',
15
+ 'LSTMCell',
16
+ 'GRUCell',
17
+ 'LSTM',
18
+ 'GRU',
19
+ 'Embedding',
20
+ 'EmbeddingBag',
21
+ ]
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/linear.cpython-311.pyc ADDED
Binary file (3.71 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/sparse.cpython-311.pyc ADDED
Binary file (6.01 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/reference/modules/rnn.py ADDED
@@ -0,0 +1,614 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch import Tensor
4
+ from .utils import _quantize_and_dequantize_weight
5
+ from .utils import _quantize_weight
6
+ from typing import Optional, Dict, Any, Tuple
7
+ from torch import _VF
8
+ from torch.nn.utils.rnn import PackedSequence
9
+
10
+ __all__ = ['RNNCellBase', 'RNNCell', 'LSTMCell', 'GRUCell', 'RNNBase', 'LSTM', 'GRU', 'get_quantized_weight']
11
+
12
+ def _apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor:
13
+ return tensor.index_select(dim, permutation)
14
+
15
+ def _get_weight_and_quantization_params(module, wn):
16
+ weight = getattr(module, wn)
17
+ params = [weight]
18
+ for param_name in [wn + n for n in ["_qscheme", "_dtype", "_scale", "_zero_point", "_axis_int"]]:
19
+ if hasattr(module, param_name):
20
+ param = getattr(module, param_name)
21
+ else:
22
+ param = None
23
+ params.append(param)
24
+ return params
25
+
26
+ def get_quantized_weight(module, wn):
27
+ if not hasattr(module, wn):
28
+ return None
29
+ params = _get_weight_and_quantization_params(module, wn)
30
+ weight = _quantize_weight(*params)
31
+ return weight
32
+
33
+ def _get_quantize_and_dequantized_weight(module, wn):
34
+ if not hasattr(module, wn):
35
+ return None
36
+ params = _get_weight_and_quantization_params(module, wn)
37
+ weight = _quantize_and_dequantize_weight(*params)
38
+ return weight
39
+
40
+ class RNNCellBase(nn.RNNCellBase):
41
+ def __init__(self, input_size: int, hidden_size: int, bias: bool, num_chunks: int,
42
+ device=None, dtype=None, weight_qparams_dict=None) -> None:
43
+ super().__init__(input_size, hidden_size, bias, num_chunks, device=device, dtype=dtype)
44
+ # TODO(jerryzh168): maybe make this arg a required arg
45
+ if weight_qparams_dict is None:
46
+ weight_qparams = {
47
+ "qscheme": torch.per_tensor_affine,
48
+ "dtype": torch.quint8,
49
+ "scale": 1.0,
50
+ "zero_point": 0
51
+ }
52
+ weight_qparams_dict = {
53
+ "weight_ih": weight_qparams,
54
+ "weight_hh": weight_qparams,
55
+ "is_decomposed": False,
56
+ }
57
+ assert len(weight_qparams_dict) == 3, "Expected length for weight_qparams_dict to be 3 for QuantizedRNNCellBase(Reference)"
58
+ self._init_weight_qparams_dict(weight_qparams_dict, device)
59
+
60
+ def _init_weight_qparams_dict(self, weight_qparams_dict, device):
61
+ assert weight_qparams_dict is not None
62
+ self.is_decomposed = weight_qparams_dict["is_decomposed"]
63
+ for key, weight_qparams in weight_qparams_dict.items():
64
+ if key == "is_decomposed":
65
+ continue
66
+ # TODO: refactor the duplicated code to utils.py
67
+ weight_qscheme = weight_qparams["qscheme"]
68
+ weight_dtype = weight_qparams["dtype"]
69
+ setattr(self, key + "_qscheme", weight_qscheme)
70
+ setattr(self, key + "_dtype", weight_dtype)
71
+ assert weight_qscheme in [None, torch.per_tensor_affine, torch.per_channel_affine], \
72
+ Exception(f"qscheme: {weight_qscheme} is not supported in {self._get_name()}")
73
+ if weight_qscheme is not None:
74
+ scale = weight_qparams["scale"]
75
+ scale_tensor = scale.clone().detach() \
76
+ if isinstance(scale, torch.Tensor) else \
77
+ torch.tensor(scale, dtype=torch.float, device=device)
78
+ self.register_buffer(key + "_scale", scale_tensor)
79
+ zp = weight_qparams["zero_point"]
80
+ zp_tensor = zp.clone().detach() \
81
+ if isinstance(zp, torch.Tensor) else \
82
+ torch.tensor(zp, dtype=torch.int, device=device)
83
+ self.register_buffer(key + "_zero_point", zp_tensor)
84
+ if weight_qscheme == torch.per_channel_affine:
85
+ axis = weight_qparams["axis"]
86
+ axis_tensor = axis.clone().detach() \
87
+ if isinstance(axis, torch.Tensor) else \
88
+ torch.tensor(axis, dtype=torch.int, device=device)
89
+ self.register_buffer(key + "_axis", axis_tensor)
90
+ else:
91
+ # added for TorchScriptability, not used
92
+ self.register_buffer(
93
+ key + "_axis", torch.tensor(0, dtype=torch.int, device=device))
94
+ setattr(self, key + "_axis_int", getattr(self, key + "_axis").item())
95
+
96
+ def _get_name(self):
97
+ return "QuantizedRNNCellBase(Reference)"
98
+
99
+ def get_quantized_weight_ih(self):
100
+ return get_quantized_weight(self, "weight_ih")
101
+
102
+ def get_quantized_weight_hh(self):
103
+ return get_quantized_weight(self, "weight_hh")
104
+
105
+ def get_weight_ih(self):
106
+ return _get_quantize_and_dequantized_weight(self, "weight_ih")
107
+
108
+ def get_weight_hh(self):
109
+ return _get_quantize_and_dequantized_weight(self, "weight_hh")
110
+
111
+ class RNNCell(RNNCellBase):
112
+ """
113
+ We'll store weight_qparams for all the weights (weight_ih and weight_hh),
114
+ we need to pass in a `weight_qparams_dict` that maps from weight name,
115
+ e.g. weight_ih, to the weight_qparams for that weight
116
+ """
117
+ def __init__(self, input_size: int, hidden_size: int, bias: bool = True, nonlinearity: str = "tanh",
118
+ device=None, dtype=None, weight_qparams_dict: Optional[Dict[str, Any]] = None) -> None:
119
+ factory_kwargs = {'device': device, 'dtype': dtype, 'weight_qparams_dict': weight_qparams_dict}
120
+ super().__init__(input_size, hidden_size, bias, num_chunks=1, **factory_kwargs)
121
+ self.nonlinearity = nonlinearity
122
+
123
+ def _get_name(self):
124
+ return "QuantizedRNNCell(Reference)"
125
+
126
+ # TODO: refactor nn.RNNCell to have a _forward that takes weight_ih and weight_hh as input
127
+ # and remove duplicated code, same for the other two Cell modules
128
+ def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor:
129
+ assert input.dim() in (1, 2), \
130
+ f"RNNCell: Expected input to be 1-D or 2-D but received {input.dim()}-D tensor"
131
+ is_batched = input.dim() == 2
132
+ if not is_batched:
133
+ input = input.unsqueeze(0)
134
+
135
+ if hx is None:
136
+ hx = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
137
+ else:
138
+ hx = hx.unsqueeze(0) if not is_batched else hx
139
+
140
+ if self.nonlinearity == "tanh":
141
+ ret = _VF.rnn_tanh_cell(
142
+ input, hx,
143
+ self.get_weight_ih(), self.get_weight_hh(),
144
+ self.bias_ih, self.bias_hh,
145
+ )
146
+ elif self.nonlinearity == "relu":
147
+ ret = _VF.rnn_relu_cell(
148
+ input, hx,
149
+ self.get_weight_ih(), self.get_weight_hh(),
150
+ self.bias_ih, self.bias_hh,
151
+ )
152
+ else:
153
+ ret = input # TODO: remove when jit supports exception flow
154
+ raise RuntimeError(
155
+ f"Unknown nonlinearity: {self.nonlinearity}")
156
+
157
+ if not is_batched:
158
+ ret = ret.squeeze(0)
159
+
160
+ return ret
161
+
162
+ @classmethod
163
+ def from_float(cls, mod, weight_qparams_dict):
164
+ ref_mod = cls(
165
+ mod.input_size,
166
+ mod.hidden_size,
167
+ mod.bias,
168
+ mod.nonlinearity,
169
+ mod.weight_ih.device,
170
+ mod.weight_ih.dtype,
171
+ weight_qparams_dict)
172
+ ref_mod.weight_ih = mod.weight_ih
173
+ ref_mod.weight_hh = mod.weight_hh
174
+ ref_mod.bias_ih = mod.bias_ih
175
+ ref_mod.bias_hh = mod.bias_hh
176
+ return ref_mod
177
+
178
+ class LSTMCell(RNNCellBase):
179
+ """
180
+ We'll store weight_qparams for all the weights (weight_ih and weight_hh),
181
+ we need to pass in a `weight_qparams_dict` that maps from weight name,
182
+ e.g. weight_ih, to the weight_qparams for that weight
183
+ """
184
+ def __init__(self, input_size: int, hidden_size: int, bias: bool = True,
185
+ device=None, dtype=None, weight_qparams_dict: Optional[Dict[str, Any]] = None) -> None:
186
+ factory_kwargs = {'device': device, 'dtype': dtype, 'weight_qparams_dict': weight_qparams_dict}
187
+ super().__init__(input_size, hidden_size, bias, num_chunks=4, **factory_kwargs)
188
+
189
+ def _get_name(self):
190
+ return "QuantizedLSTMCell(Reference)"
191
+
192
+ def forward(self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None) -> Tuple[Tensor, Tensor]:
193
+ assert input.dim() in (1, 2), \
194
+ f"LSTMCell: Expected input to be 1-D or 2-D but received {input.dim()}-D tensor"
195
+ is_batched = input.dim() == 2
196
+ if not is_batched:
197
+ input = input.unsqueeze(0)
198
+
199
+ if hx is None:
200
+ zeros = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
201
+ hx = (zeros, zeros)
202
+ else:
203
+ hx = (hx[0].unsqueeze(0), hx[1].unsqueeze(0)) if not is_batched else hx
204
+
205
+ ret = _VF.lstm_cell(
206
+ input, hx,
207
+ self.get_weight_ih(), self.get_weight_hh(),
208
+ self.bias_ih, self.bias_hh,
209
+ )
210
+
211
+ if not is_batched:
212
+ ret = (ret[0].squeeze(0), ret[1].squeeze(0))
213
+ return ret
214
+
215
+ @classmethod
216
+ def from_float(cls, mod, weight_qparams_dict):
217
+ ref_mod = cls(
218
+ mod.input_size,
219
+ mod.hidden_size,
220
+ mod.bias,
221
+ mod.weight_ih.device,
222
+ mod.weight_ih.dtype,
223
+ weight_qparams_dict)
224
+ ref_mod.weight_ih = mod.weight_ih
225
+ ref_mod.weight_hh = mod.weight_hh
226
+ ref_mod.bias_ih = mod.bias_ih
227
+ ref_mod.bias_hh = mod.bias_hh
228
+ return ref_mod
229
+
230
+ class GRUCell(RNNCellBase):
231
+ """
232
+ We'll store weight_qparams for all the weights (weight_ih and weight_hh),
233
+ we need to pass in a `weight_qparams_dict` that maps from weight name,
234
+ e.g. weight_ih, to the weight_qparams for that weight
235
+ """
236
+ def __init__(self, input_size: int, hidden_size: int, bias: bool = True,
237
+ device=None, dtype=None, weight_qparams_dict: Optional[Dict[str, Any]] = None) -> None:
238
+ factory_kwargs = {'device': device, 'dtype': dtype, 'weight_qparams_dict': weight_qparams_dict}
239
+ super().__init__(input_size, hidden_size, bias, num_chunks=3, **factory_kwargs)
240
+
241
+ def _get_name(self):
242
+ return "QuantizedGRUCell(Reference)"
243
+
244
+ def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor:
245
+ assert input.dim() in (1, 2), \
246
+ f"GRUCell: Expected input to be 1-D or 2-D but received {input.dim()}-D tensor"
247
+ is_batched = input.dim() == 2
248
+ if not is_batched:
249
+ input = input.unsqueeze(0)
250
+
251
+ if hx is None:
252
+ hx = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
253
+ else:
254
+ hx = hx.unsqueeze(0) if not is_batched else hx
255
+
256
+ ret = _VF.gru_cell(
257
+ input, hx,
258
+ self.get_weight_ih(), self.get_weight_hh(),
259
+ self.bias_ih, self.bias_hh,
260
+ )
261
+
262
+ if not is_batched:
263
+ ret = ret.squeeze(0)
264
+
265
+ return ret
266
+
267
+ @classmethod
268
+ def from_float(cls, mod, weight_qparams_dict):
269
+ ref_mod = cls(
270
+ mod.input_size,
271
+ mod.hidden_size,
272
+ mod.bias,
273
+ mod.weight_ih.device,
274
+ mod.weight_ih.dtype,
275
+ weight_qparams_dict)
276
+ ref_mod.weight_ih = mod.weight_ih
277
+ ref_mod.weight_hh = mod.weight_hh
278
+ ref_mod.bias_ih = mod.bias_ih
279
+ ref_mod.bias_hh = mod.bias_hh
280
+ return ref_mod
281
+
282
+ class RNNBase(nn.RNNBase):
283
+ def __init__(self, mode: str, input_size: int, hidden_size: int,
284
+ num_layers: int = 1, bias: bool = True, batch_first: bool = False,
285
+ dropout: float = 0., bidirectional: bool = False, proj_size: int = 0,
286
+ device=None, dtype=None,
287
+ weight_qparams_dict: Optional[Dict[str, Any]] = None) -> None:
288
+ super().__init__(
289
+ mode, input_size, hidden_size, num_layers, bias, batch_first, dropout,
290
+ bidirectional, proj_size, device, dtype
291
+ )
292
+ # TODO(jerryzh168): maybe make this arg a required arg
293
+ if weight_qparams_dict is None:
294
+ weight_qparams = {
295
+ 'qscheme': torch.per_tensor_affine,
296
+ 'dtype': torch.quint8,
297
+ 'scale': 1.0,
298
+ 'zero_point': 0
299
+ }
300
+ weight_qparams_dict = {"is_decomposed": False} # type: ignore[dict-item]
301
+ for wn in self._flat_weights_names:
302
+ if wn.startswith("weight"):
303
+ weight_qparams_dict[wn] = weight_qparams
304
+ self._init_weight_qparams_dict(weight_qparams_dict, device)
305
+
306
+ def _init_weight_qparams_dict(self, weight_qparams_dict, device):
307
+ self.is_decomposed = weight_qparams_dict["is_decomposed"]
308
+ for key, weight_qparams in weight_qparams_dict.items():
309
+ if key == "is_decomposed":
310
+ continue
311
+ weight_qscheme = weight_qparams["qscheme"]
312
+ weight_dtype = weight_qparams["dtype"]
313
+ setattr(self, key + "_qscheme", weight_qscheme)
314
+ setattr(self, key + "_dtype", weight_dtype)
315
+ assert weight_qscheme in [None, torch.per_tensor_affine, torch.per_channel_affine], \
316
+ Exception(f"qscheme: {weight_qscheme} is not supported in {self._get_name()}")
317
+ if weight_qscheme is not None:
318
+ self.register_buffer(
319
+ key + "_scale",
320
+ torch.tensor(weight_qparams["scale"], dtype=torch.float, device=device))
321
+ self.register_buffer(
322
+ key + "_zero_point",
323
+ torch.tensor(weight_qparams["zero_point"], dtype=torch.int, device=device))
324
+ if weight_qscheme == torch.per_channel_affine:
325
+ self.register_buffer(
326
+ key + "_axis",
327
+ torch.tensor(weight_qparams["axis"], dtype=torch.int, device=device))
328
+ else:
329
+ # added for TorchScriptability, not used
330
+ self.register_buffer(
331
+ key + "_axis", torch.tensor(0, dtype=torch.int, device=device))
332
+ setattr(self, key + "_axis_int", getattr(self, key + "_axis").item())
333
+
334
+ class LSTM(RNNBase):
335
+ """ Reference Quantized LSTM Module
336
+ We'll store weight_qparams for all the weights in _flat_weights, we need to pass in
337
+ a `weight_qparams_dict` that maps from weight name, e.g. weight_ih_l0,
338
+ to the weight_qparams for that weight
339
+ """
340
+ def __init__(self, *args, **kwargs):
341
+ super().__init__('LSTM', *args, **kwargs)
342
+
343
+ # Same as above, see torch/nn/modules/module.py::_forward_unimplemented
344
+ def permute_hidden(self, # type: ignore[override]
345
+ hx: Tuple[Tensor, Tensor],
346
+ permutation: Optional[Tensor]
347
+ ) -> Tuple[Tensor, Tensor]:
348
+ if permutation is None:
349
+ return hx
350
+ return _apply_permutation(hx[0], permutation), _apply_permutation(hx[1], permutation)
351
+
352
+ def get_expected_cell_size(self, input: Tensor, batch_sizes: Optional[Tensor]) -> Tuple[int, int, int]:
353
+ if batch_sizes is not None:
354
+ mini_batch = int(batch_sizes[0])
355
+ else:
356
+ mini_batch = input.size(0) if self.batch_first else input.size(1)
357
+ num_directions = 2 if self.bidirectional else 1
358
+ expected_hidden_size = (self.num_layers * num_directions,
359
+ mini_batch, self.hidden_size)
360
+ return expected_hidden_size
361
+
362
+ # In the future, we should prevent mypy from applying contravariance rules here.
363
+ # See torch/nn/modules/module.py::_forward_unimplemented
364
+ def check_forward_args(self, # type: ignore[override]
365
+ input: Tensor,
366
+ hidden: Tuple[Tensor, Tensor],
367
+ batch_sizes: Optional[Tensor],
368
+ ):
369
+ self.check_input(input, batch_sizes)
370
+ self.check_hidden_size(hidden[0], self.get_expected_hidden_size(input, batch_sizes),
371
+ 'Expected hidden[0] size {}, got {}')
372
+ self.check_hidden_size(hidden[1], self.get_expected_cell_size(input, batch_sizes),
373
+ 'Expected hidden[1] size {}, got {}')
374
+
375
+ def get_quantized_weight_bias_dict(self):
376
+ """ dictionary from flat_weight_name to quantized weight or (unquantized) bias
377
+ e.g.
378
+ {
379
+ "weight_ih_l0": quantized_weight,
380
+ "bias_ih_l0": unquantized_bias,
381
+ ...
382
+ }
383
+ """
384
+ quantized_weight_bias_dict = {}
385
+ for wn in self._flat_weights_names:
386
+ if hasattr(self, wn):
387
+ if wn.startswith("weight"):
388
+ weight_or_bias = get_quantized_weight(self, wn)
389
+ else:
390
+ weight_or_bias = getattr(self, wn)
391
+ else:
392
+ weight_or_bias = None
393
+ quantized_weight_bias_dict[wn] = weight_or_bias
394
+ return quantized_weight_bias_dict
395
+
396
+ def get_flat_weights(self):
397
+ flat_weights = []
398
+ for wn in self._flat_weights_names:
399
+ if hasattr(self, wn):
400
+ weight = getattr(self, wn)
401
+ if wn.startswith("weight"):
402
+ params = _get_weight_and_quantization_params(self, wn)
403
+ weight = _quantize_and_dequantize_weight(*params)
404
+ else:
405
+ weight = None
406
+ flat_weights.append(weight)
407
+ return flat_weights
408
+
409
+ def forward(self, input, hx=None): # noqa: F811
410
+ orig_input = input
411
+ # xxx: isinstance check needs to be in conditional for TorchScript to compile
412
+ batch_sizes = None
413
+ if isinstance(orig_input, PackedSequence):
414
+ input, batch_sizes, sorted_indices, unsorted_indices = input
415
+ max_batch_size = int(batch_sizes[0])
416
+ else:
417
+ batch_sizes = None
418
+ is_batched = input.dim() == 3
419
+ batch_dim = 0 if self.batch_first else 1
420
+ if not is_batched:
421
+ input = input.unsqueeze(batch_dim)
422
+ max_batch_size = input.size(0) if self.batch_first else input.size(1)
423
+ sorted_indices = None
424
+ unsorted_indices = None
425
+
426
+ if hx is None:
427
+ num_directions = 2 if self.bidirectional else 1
428
+ real_hidden_size = self.proj_size if self.proj_size > 0 else self.hidden_size
429
+ h_zeros = torch.zeros(self.num_layers * num_directions,
430
+ max_batch_size, real_hidden_size,
431
+ dtype=input.dtype, device=input.device)
432
+ c_zeros = torch.zeros(self.num_layers * num_directions,
433
+ max_batch_size, self.hidden_size,
434
+ dtype=input.dtype, device=input.device)
435
+ hx = (h_zeros, c_zeros)
436
+ else:
437
+ if batch_sizes is None: # If not PackedSequence input.
438
+ if is_batched: # type: ignore[possibly-undefined]
439
+ if (hx[0].dim() != 3 or hx[1].dim() != 3):
440
+ msg = ("For batched 3-D input, hx and cx should "
441
+ f"also be 3-D but got ({hx[0].dim()}-D, {hx[1].dim()}-D) tensors")
442
+ raise RuntimeError(msg)
443
+ else:
444
+ if hx[0].dim() != 2 or hx[1].dim() != 2:
445
+ msg = ("For unbatched 2-D input, hx and cx should "
446
+ f"also be 2-D but got ({hx[0].dim()}-D, {hx[1].dim()}-D) tensors")
447
+ raise RuntimeError(msg)
448
+ hx = (hx[0].unsqueeze(1), hx[1].unsqueeze(1))
449
+
450
+ # Each batch of the hidden state should match the input sequence that
451
+ # the user believes he/she is passing in.
452
+ hx = self.permute_hidden(hx, sorted_indices)
453
+
454
+ self.check_forward_args(input, hx, batch_sizes)
455
+ if batch_sizes is None:
456
+ result = _VF.lstm(input, hx, self.get_flat_weights(), self.bias, self.num_layers,
457
+ self.dropout, self.training, self.bidirectional, self.batch_first)
458
+ else:
459
+ result = _VF.lstm(input, batch_sizes, hx, self.get_flat_weights(), self.bias,
460
+ self.num_layers, self.dropout, self.training, self.bidirectional)
461
+ output = result[0]
462
+ hidden = result[1:]
463
+ # xxx: isinstance check needs to be in conditional for TorchScript to compile
464
+ if isinstance(orig_input, PackedSequence):
465
+ output_packed = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices)
466
+ return output_packed, self.permute_hidden(hidden, unsorted_indices)
467
+ else:
468
+ if not is_batched: # type: ignore[possibly-undefined]
469
+ output = output.squeeze(batch_dim) # type: ignore[possibly-undefined]
470
+ hidden = (hidden[0].squeeze(1), hidden[1].squeeze(1))
471
+ return output, self.permute_hidden(hidden, unsorted_indices)
472
+
473
+ def _get_name(self):
474
+ return "QuantizedLSTM(Reference)"
475
+
476
+ @classmethod
477
+ def from_float(cls, mod, weight_qparams_dict):
478
+ ref_mod = cls(
479
+ mod.input_size,
480
+ mod.hidden_size,
481
+ mod.num_layers,
482
+ mod.bias,
483
+ mod.batch_first,
484
+ mod.dropout,
485
+ mod.bidirectional,
486
+ weight_qparams_dict=weight_qparams_dict)
487
+ for wn in mod._flat_weights_names:
488
+ setattr(ref_mod, wn, getattr(mod, wn))
489
+ return ref_mod
490
+
491
+ class GRU(RNNBase):
492
+ """ Reference Quantized GRU Module
493
+ We'll store weight_qparams for all the weights in _flat_weights, we need to pass in
494
+ a `weight_qparams_dict` that maps from weight name, e.g. weight_ih_l0,
495
+ to the weight_qparams for that weight
496
+ """
497
+ def __init__(self, *args, **kwargs):
498
+ if 'proj_size' in kwargs:
499
+ raise ValueError("proj_size argument is only supported for LSTM, not RNN or GRU")
500
+ super().__init__('GRU', *args, **kwargs)
501
+
502
+ def get_quantized_weight_bias_dict(self):
503
+ """ dictionary from flat_weight_name to quantized weight or (unquantized) bias
504
+ e.g.
505
+ {
506
+ "weight_ih_l0": quantized_weight,
507
+ "bias_ih_l0": unquantized_bias,
508
+ ...
509
+ }
510
+ """
511
+ quantized_weight_bias_dict = {}
512
+ for wn in self._flat_weights_names:
513
+ if hasattr(self, wn):
514
+ if wn.startswith("weight"):
515
+ weight_or_bias = get_quantized_weight(self, wn)
516
+ else:
517
+ weight_or_bias = getattr(self, wn)
518
+ else:
519
+ weight_or_bias = None
520
+ quantized_weight_bias_dict[wn] = weight_or_bias
521
+ return quantized_weight_bias_dict
522
+
523
+ def get_flat_weights(self):
524
+ flat_weights = []
525
+ for wn in self._flat_weights_names:
526
+ if hasattr(self, wn):
527
+ weight = getattr(self, wn)
528
+ if wn.startswith("weight"):
529
+ params = _get_weight_and_quantization_params(self, wn)
530
+ weight = _quantize_and_dequantize_weight(*params)
531
+ else:
532
+ weight = None
533
+ flat_weights.append(weight)
534
+ return flat_weights
535
+
536
+ def forward(self, input, hx=None): # noqa: F811
537
+ # Note: this is copied from the forward of GRU in https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/rnn.py
538
+ # only changed self._flat_weights to self.get_flat_weights()
539
+ # TODO: maybe we can try inheriting from that class and define get_flat_weights
540
+ # as a @property? this might interfere with TorchScript, if we remove that
541
+ # requirement in the future we should be able to do this
542
+ orig_input = input
543
+ # xxx: isinstance check needs to be in conditional for TorchScript to compile
544
+ if isinstance(orig_input, PackedSequence):
545
+ input, batch_sizes, sorted_indices, unsorted_indices = input
546
+ max_batch_size = int(batch_sizes[0])
547
+ else:
548
+ batch_sizes = None
549
+ assert (input.dim() in (2, 3)), f"GRU: Expected input to be 2-D or 3-D but received {input.dim()}-D tensor"
550
+ is_batched = input.dim() == 3
551
+ batch_dim = 0 if self.batch_first else 1
552
+ if not is_batched:
553
+ input = input.unsqueeze(batch_dim)
554
+ if hx is not None:
555
+ if hx.dim() != 2:
556
+ raise RuntimeError(
557
+ f"For unbatched 2-D input, hx should also be 2-D but got {hx.dim()}-D tensor")
558
+ hx = hx.unsqueeze(1)
559
+ else:
560
+ if hx is not None and hx.dim() != 3:
561
+ raise RuntimeError(
562
+ f"For batched 3-D input, hx should also be 3-D but got {hx.dim()}-D tensor")
563
+ max_batch_size = input.size(0) if self.batch_first else input.size(1)
564
+ sorted_indices = None
565
+ unsorted_indices = None
566
+
567
+ if hx is None:
568
+ num_directions = 2 if self.bidirectional else 1
569
+ hx = torch.zeros(self.num_layers * num_directions,
570
+ max_batch_size, self.hidden_size,
571
+ dtype=input.dtype, device=input.device)
572
+ else:
573
+ # Each batch of the hidden state should match the input sequence that
574
+ # the user believes he/she is passing in.
575
+ hx = self.permute_hidden(hx, sorted_indices)
576
+
577
+ self.check_forward_args(input, hx, batch_sizes)
578
+ if batch_sizes is None:
579
+ result = _VF.gru(input, hx, self.get_flat_weights(), self.bias, self.num_layers,
580
+ self.dropout, self.training, self.bidirectional, self.batch_first)
581
+ else:
582
+ result = _VF.gru(input, batch_sizes, hx, self.get_flat_weights(), self.bias,
583
+ self.num_layers, self.dropout, self.training, self.bidirectional)
584
+ output = result[0]
585
+ hidden = result[1]
586
+
587
+ # xxx: isinstance check needs to be in conditional for TorchScript to compile
588
+ if isinstance(orig_input, PackedSequence):
589
+ output_packed = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices)
590
+ return output_packed, self.permute_hidden(hidden, unsorted_indices)
591
+ else:
592
+ if not is_batched: # type: ignore[possibly-undefined]
593
+ output = output.squeeze(batch_dim) # type: ignore[possibly-undefined]
594
+ hidden = hidden.squeeze(1)
595
+
596
+ return output, self.permute_hidden(hidden, unsorted_indices)
597
+
598
+ def _get_name(self):
599
+ return "QuantizedGRU(Reference)"
600
+
601
+ @classmethod
602
+ def from_float(cls, mod, weight_qparams_dict):
603
+ ref_mod = cls(
604
+ mod.input_size,
605
+ mod.hidden_size,
606
+ mod.num_layers,
607
+ mod.bias,
608
+ mod.batch_first,
609
+ mod.dropout,
610
+ mod.bidirectional,
611
+ weight_qparams_dict=weight_qparams_dict)
612
+ for wn in mod._flat_weights_names:
613
+ setattr(ref_mod, wn, getattr(mod, wn))
614
+ return ref_mod
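
The reference LSTM/GRU modules above take a `weight_qparams_dict` keyed by flat-weight name, as described in the GRU docstring. A minimal sketch of what such a mapping might look like for a single-layer, unidirectional module, assuming per-tensor affine qint8 quantization for every weight (the exact set of keys the reference modules accept can vary between PyTorch releases):

import torch

# Illustrative qparams shared by every "weight_*" entry; bias entries stay float.
per_tensor_qparams = {
    "qscheme": torch.per_tensor_affine,
    "dtype": torch.qint8,
    "scale": 0.02,
    "zero_point": 0,
}

# Maps flat-weight names (weight_ih_l0, weight_hh_l0, ...) to their qparams,
# which from_float() then attaches to the reference module.
weight_qparams_dict = {
    "weight_ih_l0": dict(per_tensor_qparams),
    "weight_hh_l0": dict(per_tensor_qparams),
}
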
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/reference/modules/utils.py ADDED
@@ -0,0 +1,323 @@
1
+ import torch
2
+ import typing
3
+
4
+ __all__ = [
5
+ "ReferenceQuantizedModule",
6
+ ]
7
+
8
+ class ReferenceQuantizedModule(torch.nn.Module):
9
+ def _init_weight_qparams(self, weight_qparams, device):
10
+ if weight_qparams is None:
11
+ weight_qparams = {
12
+ "qscheme": torch.per_tensor_affine,
13
+ "dtype": torch.quint8,
14
+ "scale": 1.0,
15
+ "zero_point": 0
16
+ }
17
+ self.weight_qscheme: torch.qscheme = weight_qparams["qscheme"]
18
+ self.weight_dtype = weight_qparams["dtype"]
19
+ assert self.weight_qscheme in [
20
+ None, torch.per_tensor_affine, torch.per_channel_affine,
21
+ torch.per_channel_affine_float_qparams], \
22
+ Exception(f"qscheme: {self.weight_qscheme} is not support in reference quantized {self._get_name()}")
23
+ if self.weight_dtype in [torch.quint8, torch.qint8, torch.quint4x2, torch.qint32]:
24
+ zero_point_dtype = weight_qparams["zero_point"].dtype if \
25
+ isinstance(weight_qparams["zero_point"], torch.Tensor) else \
26
+ torch.int
27
+ w_scale = weight_qparams["scale"]
28
+ w_scale_tensor = w_scale.clone().detach() \
29
+ if isinstance(w_scale, torch.Tensor) \
30
+ else torch.tensor(w_scale, dtype=torch.float, device=device)
31
+ self.register_buffer("weight_scale", w_scale_tensor)
32
+ w_zp = weight_qparams["zero_point"]
33
+ w_zp_tensor = w_zp.clone().detach() \
34
+ if isinstance(w_zp, torch.Tensor) \
35
+ else torch.tensor(w_zp, dtype=zero_point_dtype, device=device)
36
+ self.register_buffer("weight_zero_point", w_zp_tensor)
37
+ if self.weight_qscheme in [torch.per_channel_affine, torch.per_channel_affine_float_qparams]:
38
+ w_axis = weight_qparams["axis"]
39
+ w_axis_tensor = w_axis.clone().detach() \
40
+ if isinstance(w_axis, torch.Tensor) \
41
+ else torch.tensor(w_axis, dtype=torch.int, device=device)
42
+ self.register_buffer("weight_axis", w_axis_tensor)
43
+ else:
44
+ # added for TorchScriptability, not used
45
+ self.register_buffer(
46
+ "weight_axis", torch.tensor(0, dtype=torch.int, device=device))
47
+ else:
48
+ # added for TorchScriptability, and for torch.float
49
+ self.register_buffer("weight_scale", torch.tensor(1.0, dtype=torch.float, device=device))
50
+ self.register_buffer("weight_zero_point", torch.tensor(0, dtype=torch.int, device=device))
51
+ self.register_buffer(
52
+ "weight_axis", torch.tensor(0, dtype=torch.int, device=device))
53
+ self.is_decomposed: bool = weight_qparams.get("is_decomposed", False)
54
+ # store weight_axis as weight_axis_int due to some constraints of torchdynamo.export
55
+ # for capturing `.item` operations
56
+ self.weight_axis_int: int = self.weight_axis.item() # type: ignore[operator, assignment]
57
+ self.weight_quant_min: typing.Optional[int] = weight_qparams.get("quant_min", None)
58
+ self.weight_quant_max: typing.Optional[int] = weight_qparams.get("quant_max", None)
59
+
60
+ def get_weight(self):
61
+ """
62
+ Fake quantize (quantize and dequantize) the weight with
63
+ the quantization parameters for the weight; this is used to
64
+ simulate the numerics for the quantized weight in a quantized
65
+ model
66
+ """
67
+ # suppress mypy warning
68
+ assert isinstance(self.weight_scale, torch.Tensor)
69
+ assert isinstance(self.weight_zero_point, torch.Tensor)
70
+ if self.is_decomposed:
71
+ return _quantize_and_dequantize_weight_decomposed(
72
+ self.weight, # type: ignore[arg-type]
73
+ self.weight_qscheme,
74
+ self.weight_dtype,
75
+ self.weight_scale,
76
+ self.weight_zero_point,
77
+ self.weight_axis_int,
78
+ self.weight_quant_min,
79
+ self.weight_quant_max)
80
+ else:
81
+ return _quantize_and_dequantize_weight(
82
+ self.weight, # type: ignore[arg-type]
83
+ self.weight_qscheme,
84
+ self.weight_dtype,
85
+ self.weight_scale,
86
+ self.weight_zero_point,
87
+ self.weight_axis_int)
88
+
89
+ def get_quantized_weight(self):
90
+ # suppress mypy warning
91
+ assert isinstance(self.weight_scale, torch.Tensor)
92
+ assert isinstance(self.weight_zero_point, torch.Tensor)
93
+ # assert isinstance(self.weight_axis, torch.Tensor)
94
+ if self.is_decomposed:
95
+ return _quantize_weight_decomposed(
96
+ self.weight, # type: ignore[arg-type]
97
+ self.weight_qscheme,
98
+ self.weight_dtype,
99
+ self.weight_scale,
100
+ self.weight_zero_point,
101
+ self.weight_axis_int,
102
+ self.weight_quant_min,
103
+ self.weight_quant_max)
104
+ else:
105
+ return _quantize_weight(
106
+ self.weight, # type: ignore[arg-type]
107
+ self.weight_qscheme,
108
+ self.weight_dtype,
109
+ self.weight_scale,
110
+ self.weight_zero_point,
111
+ self.weight_axis_int)
112
+
113
+ def _save_to_state_dict(self, destination, prefix, keep_vars):
114
+ super()._save_to_state_dict(destination, prefix, keep_vars)
115
+ _save_weight_qparams(
116
+ destination, prefix, self.weight_qscheme, self.weight_dtype,
117
+ self.weight_scale, self.weight_zero_point, self.weight_axis)
118
+
119
+ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
120
+ missing_keys, unexpected_keys, error_msgs):
121
+ for key in _get_weight_qparam_keys(state_dict, prefix):
122
+ setattr(self, key, state_dict[prefix + key])
123
+ state_dict.pop(prefix + key)
124
+
125
+ super()._load_from_state_dict(
126
+ state_dict, prefix, local_metadata, False,
127
+ missing_keys, unexpected_keys, error_msgs)
128
+
129
+ def _quantize_weight_decomposed(
130
+ weight: torch.Tensor,
131
+ weight_qscheme: torch.qscheme,
132
+ weight_dtype: torch.dtype,
133
+ weight_scale: torch.Tensor,
134
+ weight_zero_point: torch.Tensor,
135
+ weight_axis: int,
136
+ weight_quant_min: typing.Optional[int],
137
+ weight_quant_max: typing.Optional[int],
138
+ ) -> torch.Tensor:
139
+ _DTYPE_TO_QVALUE_BOUNDS = {
140
+ torch.uint8: (0, 255),
141
+ torch.int8: (-128, 127),
142
+ torch.int32: (-(2**31), 2**31 - 1),
143
+ }
144
+ # TODO: add an util function for converting qdtype to dtype
145
+ _QDTYPE_TO_UNDERLYING_INT_REPR_DTYPE = {
146
+ torch.quint8: torch.uint8,
147
+ torch.qint8: torch.int8,
148
+ torch.qint32: torch.int32,
149
+ }
150
+ if weight_qscheme == torch.per_tensor_affine:
151
+ if weight_dtype in [torch.quint8, torch.qint8, torch.qint32]:
152
+ weight_dtype_ = _QDTYPE_TO_UNDERLYING_INT_REPR_DTYPE[weight_dtype]
153
+ if weight_quant_min is None or weight_quant_max is None:
154
+ weight_quant_min, weight_quant_max = _DTYPE_TO_QVALUE_BOUNDS[weight_dtype_]
155
+ weight = torch.ops.quantized_decomposed.quantize_per_tensor(
156
+ weight,
157
+ weight_scale,
158
+ weight_zero_point,
159
+ weight_quant_min,
160
+ weight_quant_max,
161
+ weight_dtype_
162
+ )
163
+ return weight
164
+ elif weight_qscheme in [torch.per_channel_affine, torch.per_channel_affine_float_qparams]:
165
+ # TODO: torch.quint4x2 is not supported
166
+ if weight_dtype in [torch.quint8, torch.qint8, torch.qint32]:
167
+ weight_dtype_ = _QDTYPE_TO_UNDERLYING_INT_REPR_DTYPE[weight_dtype]
168
+ if weight_quant_min is None or weight_quant_max is None:
169
+ weight_quant_min, weight_quant_max = _DTYPE_TO_QVALUE_BOUNDS[weight_dtype_]
170
+ weight = torch.ops.quantized_decomposed.quantize_per_channel(
171
+ weight,
172
+ weight_scale,
173
+ weight_zero_point,
174
+ weight_axis,
175
+ weight_quant_min,
176
+ weight_quant_max,
177
+ weight_dtype_) # type: ignore[arg-type]
178
+ return weight
179
+ raise Exception(f"Unsupported dtype and qscheme: {weight_dtype}, {weight_qscheme}")
180
+
181
+ def _dequantize_weight_decomposed(
182
+ weight: torch.Tensor,
183
+ weight_qscheme: torch.qscheme,
184
+ weight_dtype: torch.dtype,
185
+ weight_scale: torch.Tensor,
186
+ weight_zero_point: torch.Tensor,
187
+ weight_axis: int,
188
+ weight_quant_min: typing.Optional[int],
189
+ weight_quant_max: typing.Optional[int],
190
+ ) -> torch.Tensor:
191
+ # TODO: get the quant_min and quant_max from activation_post_process
192
+ _DTYPE_TO_QVALUE_BOUNDS = {
193
+ torch.uint8: (0, 255),
194
+ torch.int8: (-128, 127),
195
+ torch.int32: (-(2**31), 2**31 - 1),
196
+ }
197
+ # TODO: add an util function for converting qdtype to dtype
198
+ _QDTYPE_TO_UNDERLYING_INT_REPR_DTYPE = {
199
+ torch.quint8: torch.uint8,
200
+ torch.qint8: torch.int8,
201
+ torch.qint32: torch.int32,
202
+ }
203
+ weight_dtype_ = _QDTYPE_TO_UNDERLYING_INT_REPR_DTYPE[weight_dtype]
204
+ if weight_quant_min is None or weight_quant_max is None:
205
+ weight_quant_min, weight_quant_max = _DTYPE_TO_QVALUE_BOUNDS[weight_dtype_]
206
+ if weight_qscheme == torch.per_tensor_affine:
207
+ if weight_dtype in [torch.quint8, torch.qint8, torch.qint32]:
208
+ weight = torch.ops.quantized_decomposed.dequantize_per_tensor(
209
+ weight,
210
+ weight_scale,
211
+ weight_zero_point,
212
+ weight_quant_min,
213
+ weight_quant_max,
214
+ weight_dtype_
215
+ )
216
+ return weight
217
+ elif weight_qscheme in [torch.per_channel_affine, torch.per_channel_affine_float_qparams]:
218
+ # TODO: torch.quint4x2 is not supported
219
+ if weight_dtype in [torch.quint8, torch.qint8, torch.qint32]:
220
+ weight = torch.ops.quantized_decomposed.dequantize_per_channel(
221
+ weight,
222
+ weight_scale,
223
+ weight_zero_point,
224
+ weight_axis,
225
+ weight_quant_min,
226
+ weight_quant_max,
227
+ weight_dtype_) # type: ignore[arg-type]
228
+ return weight
229
+ raise Exception(f"Unsupported dtype and qscheme: {weight_dtype}, {weight_qscheme}")
230
+
231
+ def _quantize_weight(
232
+ weight: torch.Tensor,
233
+ weight_qscheme: torch.qscheme,
234
+ weight_dtype: torch.dtype,
235
+ weight_scale: torch.Tensor,
236
+ weight_zero_point: torch.Tensor,
237
+ weight_axis_int: int
238
+ ) -> torch.Tensor:
239
+ if weight_dtype == torch.float16:
240
+ weight = weight.to(weight_dtype)
241
+ return weight
242
+
243
+ if weight_qscheme == torch.per_tensor_affine:
244
+ if weight_dtype in [torch.quint8, torch.qint8, torch.qint32]:
245
+ weight = torch.quantize_per_tensor(weight, weight_scale, weight_zero_point, weight_dtype)
246
+ return weight
247
+ elif weight_qscheme in [torch.per_channel_affine, torch.per_channel_affine_float_qparams]:
248
+ if weight_dtype in [torch.quint8, torch.qint8, torch.quint4x2, torch.qint32]:
249
+ weight = torch.quantize_per_channel(
250
+ weight, weight_scale,
251
+ weight_zero_point, weight_axis_int, weight_dtype) # type: ignore[arg-type]
252
+ return weight
253
+ raise Exception(f"Unsupported dtype and qscheme: {weight_dtype}, {weight_qscheme}")
254
+
255
+ def _quantize_and_dequantize_weight_decomposed(
256
+ weight: torch.Tensor,
257
+ weight_qscheme: torch.qscheme,
258
+ weight_dtype: torch.dtype,
259
+ weight_scale: torch.Tensor,
260
+ weight_zero_point: torch.Tensor,
261
+ weight_axis_int: int,
262
+ weight_quant_min: typing.Optional[int],
263
+ weight_quant_max: typing.Optional[int],
264
+ ) -> torch.Tensor:
265
+ """ Quantize and then dequantize the weight based on
266
+ the quantization parameters
267
+ """
268
+ if weight_qscheme in [
269
+ torch.per_tensor_affine,
270
+ torch.per_channel_affine,
271
+ torch.per_channel_affine_float_qparams]:
272
+ weight_quant = _quantize_weight_decomposed(
273
+ weight, weight_qscheme, weight_dtype, weight_scale, weight_zero_point, weight_axis_int,
274
+ weight_quant_min, weight_quant_max)
275
+ weight_dequant = _dequantize_weight_decomposed(
276
+ weight_quant, weight_qscheme, weight_dtype, weight_scale, weight_zero_point,
277
+ weight_axis_int, weight_quant_min, weight_quant_max)
278
+ else:
279
+ weight_dequant = weight
280
+ return weight_dequant
281
+
282
+ def _quantize_and_dequantize_weight(
283
+ weight: torch.Tensor,
284
+ weight_qscheme: torch.qscheme,
285
+ weight_dtype: torch.dtype,
286
+ weight_scale: torch.Tensor,
287
+ weight_zero_point: torch.Tensor,
288
+ weight_axis_int: int
289
+ ) -> torch.Tensor:
290
+ """ Quantize and then dequantize the weight based on
291
+ the quantization parameters
292
+ """
293
+ if weight_qscheme in [
294
+ torch.per_tensor_affine,
295
+ torch.per_channel_affine,
296
+ torch.per_channel_affine_float_qparams]:
297
+ weight_quant = _quantize_weight(
298
+ weight, weight_qscheme, weight_dtype, weight_scale, weight_zero_point, weight_axis_int)
299
+ weight_dequant = weight_quant.dequantize()
300
+ else:
301
+ weight_dequant = weight
302
+ return weight_dequant
303
+
304
+ def _save_weight_qparams(destination, prefix, weight_qscheme, weight_dtype, weight_scale, weight_zero_point, weight_axis):
305
+ destination[prefix + "weight_qscheme"] = weight_qscheme
306
+ destination[prefix + "weight_dtype"] = weight_dtype
307
+ if weight_qscheme is not None:
308
+ destination[prefix + "weight_scale"] = weight_scale
309
+ destination[prefix + "weight_zero_point"] = weight_zero_point
310
+ if weight_qscheme == torch.per_channel_affine:
311
+ destination[prefix + "weight_axis"] = weight_axis
312
+
313
+ def _get_weight_qparam_keys(
314
+ state_dict: typing.Dict[str, typing.Any],
315
+ prefix: str):
316
+ keys = ["weight_qscheme", "weight_dtype"]
317
+ weight_qscheme = state_dict[prefix + "weight_qscheme"]
318
+ if weight_qscheme is not None:
319
+ keys.append("weight_scale")
320
+ keys.append("weight_zero_point")
321
+ if weight_qscheme == torch.quantize_per_channel:
322
+ keys.append("weight_axis")
323
+ return keys
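
For the non-decomposed, per-tensor-affine path, `_quantize_and_dequantize_weight` above boils down to a quantize/dequantize round trip through `torch.quantize_per_tensor`. A standalone sketch of that fake-quantization step, with illustrative scale and zero point:

import torch

weight = torch.randn(4, 3)
scale, zero_point = 0.05, 0

# Quantize to qint8 with the given qparams, then immediately dequantize;
# the float result carries the quantization error that the reference module
# uses in place of the raw weight to simulate quantized numerics.
q_weight = torch.quantize_per_tensor(weight, scale, zero_point, torch.qint8)
fake_quant_weight = q_weight.dequantize()

print((fake_quant_weight - weight).abs().max())
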
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/sparse/quantized/__pycache__/utils.cpython-311.pyc ADDED
Binary file (2.36 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/sparse/quantized/dynamic/linear.py ADDED
@@ -0,0 +1,139 @@
1
+ from typing import Optional
2
+
3
+ import torch
4
+ import torch.ao.nn.intrinsic as nni
5
+
6
+ from torch.ao.nn.sparse.quantized import linear
7
+ from torch.ao.nn.sparse.quantized.utils import LinearBlockSparsePattern
8
+ from torch.ao.nn.quantized.modules.utils import _quantize_weight, _hide_packed_params_repr
9
+
10
+ __all__ = ['Linear']
11
+
12
+ class Linear(torch.nn.Module):
13
+ r"""
14
+ A dynamically quantized sparse linear module with float tensor as inputs and outputs.
15
+ """
16
+ _version = 1
17
+ _op_type = "sparse_dynamic"
18
+ _FLOAT_MODULE = torch.nn.Linear
19
+
20
+ def __init__(self, in_features, out_features, row_block_size, col_block_size, bias=True, dtype=torch.qint8):
21
+ super().__init__()
22
+
23
+ if dtype != torch.qint8:
24
+ raise NotImplementedError("Only QINT8 is supported for Sparse Quantized Linear Dynamic")
25
+
26
+ self.in_features = in_features
27
+ self.out_features = out_features
28
+
29
+ if bias:
30
+ bias = torch.zeros(self.out_features, dtype=torch.float)
31
+ else:
32
+ bias = None
33
+
34
+ qweight = torch._empty_affine_quantized([out_features, in_features],
35
+ scale=1, zero_point=0, dtype=torch.qint8)
36
+ self._packed_params = linear.LinearPackedParams(row_block_size=row_block_size,
37
+ col_block_size=col_block_size,
38
+ dtype=dtype)
39
+ self._packed_params.set_weight_bias(qweight, bias, row_block_size, col_block_size)
40
+
41
+ def _get_name(self):
42
+ return 'SparseQuantizedDynamicLinear'
43
+
44
+ def extra_repr(self):
45
+ return f'in_features={self.in_features}, out_features={self.out_features}, qscheme={self.weight().qscheme()}'
46
+
47
+ def __repr__(self):
48
+ return _hide_packed_params_repr(self, linear.LinearPackedParams)
49
+
50
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
51
+ return torch.ops.sparse.qlinear_dynamic(x, self._packed_params._packed_params)
52
+
53
+ def _save_to_state_dict(self, destination, prefix, keep_vars):
54
+ super()._save_to_state_dict(destination, prefix, keep_vars)
55
+ destination[prefix + 'op_type'] = self._op_type
56
+
57
+ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
58
+ missing_keys, unexpected_keys, error_msgs):
59
+ op_type = int(state_dict[prefix + 'op_type'])
60
+ assert op_type == 'sparse', \
61
+ f"Cannot load from op_type [{op_type}], expecting [{self._op_type}]"
62
+ state_dict.pop(prefix + 'op_type')
63
+
64
+ version = local_metadata.get('version', None)
65
+ assert version <= self._version
66
+
67
+ # Is this code valid? In old quantization it seemed to be used to load
68
+ # older model
69
+ weight = state_dict.pop(prefix + 'weight')
70
+ bias = state_dict.pop(prefix + 'bias')
71
+ state_dict.update({prefix + '_packed_params.weight': weight,
72
+ prefix + '_packed_params.bias': bias})
73
+
74
+ super()._load_from_state_dict(
75
+ state_dict, prefix, local_metadata, False,
76
+ missing_keys, unexpected_keys, error_msgs)
77
+
78
+ def _weight_bias(self):
79
+ return self._packed_params._weight_bias()
80
+
81
+ def weight(self):
82
+ return self._weight_bias()[0]
83
+
84
+ def bias(self):
85
+ return self._weight_bias()[1]
86
+
87
+ def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor],
88
+ row_block_size: Optional[int], col_block_size: Optional[int]) -> None:
89
+ assert row_block_size is not None and col_block_size is not None
90
+ self.out_features = w.shape[0]
91
+ self.in_features = w.shape[1]
92
+ self._packed_params.set_weight_bias(w, b, row_block_size, col_block_size)
93
+
94
+ @classmethod
95
+ def from_float(cls, mod):
96
+ r"""Create a quantized sparse dynamic module from a float module.
97
+
98
+ We only care about the convert at this stage, no need for observers just yet.
99
+ """
100
+ assert type(mod) == cls._FLOAT_MODULE, ' nnq.' + cls.__name__ + '.from_float only works for ' + \
101
+ cls._FLOAT_MODULE.__name__
102
+ # TODO: Need to add options to qconfig to avoid the calibration.
103
+ # TODO: Add calibration for the sparsity
104
+ assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
105
+ if type(mod) == nni.LinearReLU:
106
+ mod = mod[0]
107
+ if mod.qconfig is not None and mod.qconfig.weight is not None:
108
+ weight_observer = mod.qconfig.weight()
109
+ else:
110
+ # We have the circular import issues if we import the qconfig in the beginning of this file:
111
+ # https://github.com/pytorch/pytorch/pull/24231. The current workaround is to postpone the
112
+ # import until we need it.
113
+ from torch.ao.quantization.qconfig import default_dynamic_qconfig
114
+ weight_observer = default_dynamic_qconfig.weight()
115
+
116
+ # It is important to multiply by the mask BEFORE calling the `weight_observer`
117
+ # TODO (zaf): Mask might not be part of the qconfig (T83295194)
118
+ weight = mod.weight
119
+ if getattr(mod.qconfig, 'mask', False):
120
+ weight = mod.qconfig.mask * mod.weight
121
+
122
+ weight_observer(weight)
123
+ dtype = weight_observer.dtype
124
+ assert dtype == torch.qint8, 'Weight observer must have dtype torch.qint8'
125
+ w_sc, w_zp = weight_observer.calculate_qparams()
126
+ if isinstance(w_zp, torch.Tensor):
127
+ assert not torch.any(w_zp.bool()), "All weight zero points must map to 0"
128
+ else:
129
+ assert w_zp == 0, 'Weight zero point must map to 0'
130
+ qweight = _quantize_weight(weight.float(), weight_observer)
131
+
132
+ row_block_size, col_block_size = LinearBlockSparsePattern.block_size()
133
+ qlinear = cls(mod.in_features,
134
+ mod.out_features,
135
+ row_block_size,
136
+ col_block_size,
137
+ dtype=dtype)
138
+ qlinear.set_weight_bias(qweight, mod.bias, row_block_size, col_block_size)
139
+ return qlinear
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/_experimental/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (231 Bytes). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/_experimental/data_scheduler/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ from .base_data_scheduler import BaseDataScheduler
2
+
3
+ __all__ = [
4
+ "BaseDataScheduler",
5
+ ]
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/_experimental/data_sparsifier/__pycache__/quantization_utils.cpython-311.pyc ADDED
Binary file (6.74 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/__init__.py ADDED
File without changes
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/_experimental/data_sparsifier/quantization_utils.py ADDED
@@ -0,0 +1,130 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.ao.pruning.sparsifier.utils import module_to_fqn, fqn_to_module
4
+ from typing import Dict, List, Optional
5
+
6
+ SUPPORTED_MODULES = {
7
+ nn.Embedding,
8
+ nn.EmbeddingBag
9
+ }
10
+
11
+
12
+ def _fetch_all_embeddings(model):
13
+ """Fetches Embedding and EmbeddingBag modules from the model
14
+ """
15
+ embedding_modules = []
16
+ stack = [model]
17
+ while stack:
18
+ module = stack.pop()
19
+ for _, child in module.named_children():
20
+ fqn_name = module_to_fqn(model, child)
21
+ if type(child) in SUPPORTED_MODULES:
22
+ embedding_modules.append((fqn_name, child))
23
+ else:
24
+ stack.append(child)
25
+ return embedding_modules
26
+
27
+
28
+ def post_training_sparse_quantize(model,
29
+ data_sparsifier_class,
30
+ sparsify_first=True,
31
+ select_embeddings: Optional[List[nn.Module]] = None,
32
+ **sparse_config):
33
+ """Takes in a model and applies sparsification and quantization to only embeddings & embeddingbags.
34
+ The quantization step can happen before or after sparsification depending on the `sparsify_first` argument.
35
+
36
+ Args:
37
+ - model (nn.Module)
38
+ model whose embeddings need to be sparsified
39
+ - data_sparsifier_class (type of data sparsifier)
40
+ Type of sparsification that needs to be applied to model
41
+ - sparsify_first (bool)
42
+ if true, sparsifies first and then quantizes
43
+ otherwise, quantizes first and then sparsifies.
44
+ - select_embeddings (List of Embedding modules)
45
+ List of embedding modules in the model to be sparsified & quantized.
46
+ If None, all embedding modules will be sparsified
47
+ - sparse_config (Dict)
48
+ config that will be passed to the constructor of data sparsifier object.
49
+
50
+ Note:
51
+ 1. When `sparsify_first=False`, quantization occurs first followed by sparsification.
52
+ - before sparsifying, the embedding layers are dequantized.
53
+ - scales and zero-points are saved
54
+ - embedding layers are sparsified and `squash_mask` is applied
55
+ - embedding weights are requantized using the saved scales and zero-points
56
+ 2. When `sparsify_first=True`, sparsification occurs first followed by quantization.
57
+ - embeddings are sparsified first
58
+ - quantization is applied on the sparsified embeddings
59
+ """
60
+ data_sparsifier = data_sparsifier_class(**sparse_config)
61
+
62
+ # if select_embeddings is None, perform it on all embeddings
63
+ if select_embeddings is None:
64
+ embedding_modules = _fetch_all_embeddings(model)
65
+
66
+ else:
67
+ embedding_modules = []
68
+ assert isinstance(select_embeddings, List), "the embedding_modules must be a list of embedding modules"
69
+ for emb in select_embeddings:
70
+ assert type(emb) in SUPPORTED_MODULES, "the embedding_modules list must be an embedding or embedding bags"
71
+ fqn_name = module_to_fqn(model, emb)
72
+ assert fqn_name is not None, "the embedding modules must be part of input model"
73
+ embedding_modules.append((fqn_name, emb))
74
+
75
+ if sparsify_first:
76
+ # sparsify
77
+ for name, emb_module in embedding_modules:
78
+ valid_name = name.replace('.', '_')
79
+ data_sparsifier.add_data(name=valid_name, data=emb_module)
80
+
81
+ data_sparsifier.step()
82
+ data_sparsifier.squash_mask()
83
+
84
+ # quantize
85
+ for _, emb_module in embedding_modules:
86
+ emb_module.qconfig = torch.ao.quantization.float_qparams_weight_only_qconfig
87
+
88
+ torch.ao.quantization.prepare(model, inplace=True)
89
+ torch.ao.quantization.convert(model, inplace=True)
90
+
91
+ else:
92
+ # quantize
93
+ for _, emb_module in embedding_modules:
94
+ emb_module.qconfig = torch.ao.quantization.float_qparams_weight_only_qconfig
95
+
96
+ torch.ao.quantization.prepare(model, inplace=True)
97
+ torch.ao.quantization.convert(model, inplace=True)
98
+
99
+ # retrieve scale & zero_points
100
+ quantize_params: Dict[str, Dict] = {'scales': {}, 'zero_points': {},
101
+ 'dequant_weights': {}, 'axis': {},
102
+ 'dtype': {}}
103
+
104
+ for name, _ in embedding_modules:
105
+ quantized_emb = fqn_to_module(model, name)
106
+ assert quantized_emb is not None # satisfy mypy
107
+
108
+ quantized_weight = quantized_emb.weight() # type: ignore[operator]
109
+ quantize_params['scales'][name] = quantized_weight.q_per_channel_scales()
110
+ quantize_params['zero_points'][name] = quantized_weight.q_per_channel_zero_points()
111
+ quantize_params['dequant_weights'][name] = torch.dequantize(quantized_weight)
112
+ quantize_params['axis'][name] = quantized_weight.q_per_channel_axis()
113
+ quantize_params['dtype'][name] = quantized_weight.dtype
114
+
115
+ # attach data to sparsifier
116
+ data_sparsifier.add_data(name=name.replace('.', '_'), data=quantize_params['dequant_weights'][name])
117
+
118
+ data_sparsifier.step()
119
+ data_sparsifier.squash_mask()
120
+
121
+ for name, _ in embedding_modules:
122
+ quantized_emb = fqn_to_module(model, name)
123
+ assert quantized_emb is not None # satisfy mypy
124
+ requantized_vector = torch.quantize_per_channel(quantize_params['dequant_weights'][name],
125
+ scales=quantize_params['scales'][name],
126
+ zero_points=quantize_params['zero_points'][name],
127
+ dtype=quantize_params['dtype'][name],
128
+ axis=quantize_params['axis'][name])
129
+
130
+ quantized_emb.set_weight(requantized_vector) # type: ignore[operator]
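
A rough usage sketch of `post_training_sparse_quantize` on a toy embedding model. The sparsifier class, its import path, and the `sparsity_level` keyword are assumptions here (they belong to the experimental data-sparsifier API and may differ between releases):

import torch.nn as nn
from torch.ao.pruning._experimental.data_sparsifier import DataNormSparsifier
from torch.ao.pruning._experimental.data_sparsifier.quantization_utils import (
    post_training_sparse_quantize,
)

class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.emb = nn.EmbeddingBag(1000, 64)

model = TinyModel()

# Sparsify the embedding weights first, then quantize them with the
# float-qparams weight-only qconfig, as described in the docstring above.
post_training_sparse_quantize(
    model,
    DataNormSparsifier,
    sparsify_first=True,
    sparsity_level=0.8,  # forwarded to the sparsifier's constructor
)
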
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/FPGM_pruner.cpython-311.pyc ADDED
Binary file (5.52 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (659 Bytes). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/_experimental/pruner/saliency_pruner.py ADDED
@@ -0,0 +1,29 @@
1
+ from .base_structured_sparsifier import BaseStructuredSparsifier
2
+
3
+
4
+ class SaliencyPruner(BaseStructuredSparsifier):
5
+ """
6
+ Prune rows based on the saliency (L1 norm) of each row.
7
+
8
+ This pruner works on N-Dimensional weight tensors.
9
+ For each row, we will calculate the saliency, which is the sum of the L1 norms of all weights in that row.
10
+ We expect that the resulting saliency vector has the same shape as our mask.
11
+ We then pick elements to remove until we reach the target sparsity_level.
12
+ """
13
+
14
+ def update_mask(self, module, tensor_name, **kwargs):
15
+ # tensor_name will give you the FQN, all other entries in sparse config is present in kwargs
16
+ weights = getattr(module, tensor_name)
17
+ mask = getattr(module.parametrizations, tensor_name)[0].mask
18
+
19
+ # use negative weights so we can use topk (we prune out the smallest)
20
+ if weights.dim() <= 1:
21
+ raise Exception("Structured pruning can only be applied to a 2+dim weight tensor!")
22
+ saliency = -weights.norm(dim=tuple(range(1, weights.dim())), p=1)
23
+ assert saliency.shape == mask.shape
24
+
25
+ num_to_pick = int(len(mask) * kwargs["sparsity_level"])
26
+ prune = saliency.topk(num_to_pick).indices
27
+
28
+ # Set the mask to be false for the rows we want to prune
29
+ mask.data[prune] = False
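
The row-saliency computation in `update_mask` above can be reproduced on a toy weight tensor; a minimal sketch with illustrative shapes:

import torch

weights = torch.randn(8, 16)  # 8 rows, each with 16 weights
sparsity_level = 0.5

# L1 norm of each row, negated so that topk picks the smallest-norm rows.
saliency = -weights.norm(dim=tuple(range(1, weights.dim())), p=1)

num_to_prune = int(len(weights) * sparsity_level)
rows_to_prune = saliency.topk(num_to_prune).indices

mask = torch.ones(8, dtype=torch.bool)
mask[rows_to_prune] = False  # False marks the rows that get pruned
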
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/scheduler/__init__.py ADDED
File without changes
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/scheduler/__pycache__/base_scheduler.cpython-311.pyc ADDED
Binary file (8.9 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/scheduler/lambda_scheduler.py ADDED
@@ -0,0 +1,47 @@
1
+ import warnings
2
+
3
+ from .base_scheduler import BaseScheduler
4
+
5
+ __all__ = ["LambdaSL"]
6
+
7
+ class LambdaSL(BaseScheduler):
8
+ """Sets the sparsity level of each parameter group to the final sl
9
+ times a given function. When last_epoch=-1, sets initial sl as zero.
10
+ Args:
11
+ sparsifier (BaseSparsifier): Wrapped sparsifier.
12
+ sl_lambda (function or list): A function which computes a multiplicative
13
+ factor given an integer parameter epoch, or a list of such
14
+ functions, one for each group in sparsifier.param_groups.
15
+ last_epoch (int): The index of last epoch. Default: -1.
16
+ verbose (bool): If ``True``, prints a message to stdout for
17
+ each update. Default: ``False``.
18
+ Example:
19
+ >>> # Assuming sparsifier has two groups.
20
+ >>> lambda1 = lambda epoch: epoch // 30
21
+ >>> lambda2 = lambda epoch: 0.95 ** epoch
22
+ >>> # xdoctest: +SKIP
23
+ >>> scheduler = LambdaSL(sparsifier, sl_lambda=[lambda1, lambda2])
24
+ >>> for epoch in range(100):
25
+ >>> train(...)
26
+ >>> validate(...)
27
+ >>> scheduler.step()
28
+ """
29
+
30
+ def __init__(self, sparsifier, sl_lambda, last_epoch=-1, verbose=False):
31
+ self.sparsifier = sparsifier
32
+
33
+ if not isinstance(sl_lambda, list) and not isinstance(sl_lambda, tuple):
34
+ self.sl_lambdas = [sl_lambda] * len(sparsifier.groups)
35
+ else:
36
+ if len(sl_lambda) != len(sparsifier.groups):
37
+ raise ValueError(f"Expected {len(sparsifier.groups)} lr_lambdas, but got {len(sl_lambda)}")
38
+ self.sl_lambdas = list(sl_lambda)
39
+ super().__init__(sparsifier, last_epoch, verbose)
40
+
41
+ def get_sl(self):
42
+ if not self._get_sl_called_within_step:
43
+ warnings.warn(
44
+ "To get the last sparsity level computed by the scheduler, "
45
+ "please use `get_last_sl()`.")
46
+ return [base_sl * lmbda(self.last_epoch)
47
+ for lmbda, base_sl in zip(self.sl_lambdas, self.base_sl)]
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/sparsifier/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (228 Bytes). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/sparsifier/__pycache__/base_sparsifier.cpython-311.pyc ADDED
Binary file (17.5 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/sparsifier/__pycache__/weight_norm_sparsifier.cpython-311.pyc ADDED
Binary file (11.2 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/_equalize.py ADDED
@@ -0,0 +1,182 @@
1
+ import torch
2
+ import copy
3
+ from typing import Dict, Any
4
+
5
+ __all__ = [
6
+ "set_module_weight",
7
+ "set_module_bias",
8
+ "get_module_weight",
9
+ "get_module_bias",
10
+ "max_over_ndim",
11
+ "min_over_ndim",
12
+ "channel_range",
13
+ "cross_layer_equalization",
14
+ "equalize",
15
+ "converged",
16
+ ]
17
+
18
+ _supported_types = {torch.nn.Conv2d, torch.nn.Linear}
19
+ _supported_intrinsic_types = {torch.ao.nn.intrinsic.ConvReLU2d, torch.ao.nn.intrinsic.LinearReLU}
20
+ _all_supported_types = _supported_types.union(_supported_intrinsic_types)
21
+
22
+ def set_module_weight(module, weight) -> None:
23
+ if type(module) in _supported_types:
24
+ module.weight = torch.nn.Parameter(weight)
25
+ else:
26
+ module[0].weight = torch.nn.Parameter(weight)
27
+
28
+ def set_module_bias(module, bias) -> None:
29
+ if type(module) in _supported_types:
30
+ module.bias = torch.nn.Parameter(bias)
31
+ else:
32
+ module[0].bias = torch.nn.Parameter(bias)
33
+
34
+ def get_module_weight(module):
35
+ if type(module) in _supported_types:
36
+ return module.weight
37
+ else:
38
+ return module[0].weight
39
+
40
+ def get_module_bias(module):
41
+ if type(module) in _supported_types:
42
+ return module.bias
43
+ else:
44
+ return module[0].bias
45
+
46
+ def max_over_ndim(input, axis_list, keepdim=False):
47
+ """Apply 'torch.max' over the given axes."""
48
+ axis_list.sort(reverse=True)
49
+ for axis in axis_list:
50
+ input, _ = input.max(axis, keepdim)
51
+ return input
52
+
53
+ def min_over_ndim(input, axis_list, keepdim=False):
54
+ """Apply 'torch.min' over the given axes."""
55
+ axis_list.sort(reverse=True)
56
+ for axis in axis_list:
57
+ input, _ = input.min(axis, keepdim)
58
+ return input
59
+
60
+ def channel_range(input, axis=0):
61
+ """Find the range of weights associated with a specific channel."""
62
+ size_of_tensor_dim = input.ndim
63
+ axis_list = list(range(size_of_tensor_dim))
64
+ axis_list.remove(axis)
65
+
66
+ mins = min_over_ndim(input, axis_list)
67
+ maxs = max_over_ndim(input, axis_list)
68
+
69
+ assert mins.size(0) == input.size(axis), "Dimensions of resultant channel range do not match size of requested axis"
70
+ return maxs - mins
71
+
72
+ def cross_layer_equalization(module1, module2, output_axis=0, input_axis=1):
73
+ """Scale the range of Tensor1.output to equal Tensor2.input.
74
+
75
+ Given two adjacent tensors', the weights are scaled such that
76
+ the ranges of the first tensors' output channel are equal to the
77
+ ranges of the second tensors' input channel
78
+ """
79
+ if type(module1) not in _all_supported_types or type(module2) not in _all_supported_types:
80
+ raise ValueError("module type not supported:", type(module1), " ", type(module2))
81
+
82
+ weight1 = get_module_weight(module1)
83
+ weight2 = get_module_weight(module2)
84
+
85
+ if weight1.size(output_axis) != weight2.size(input_axis):
86
+ raise TypeError("Number of output channels of first arg do not match \
87
+ number input channels of second arg")
88
+
89
+ bias = get_module_bias(module1)
90
+
91
+ weight1_range = channel_range(weight1, output_axis)
92
+ weight2_range = channel_range(weight2, input_axis)
93
+
94
+ # producing scaling factors to applied
95
+ weight2_range += 1e-9
96
+ scaling_factors = torch.sqrt(weight1_range / weight2_range)
97
+ inverse_scaling_factors = torch.reciprocal(scaling_factors)
98
+
99
+ bias = bias * inverse_scaling_factors
100
+
101
+ # formatting the scaling (1D) tensors to be applied on the given argument tensors
102
+ # pads axis to (1D) tensors to then be broadcasted
103
+ size1 = [1] * weight1.ndim
104
+ size1[output_axis] = weight1.size(output_axis)
105
+ size2 = [1] * weight2.ndim
106
+ size2[input_axis] = weight2.size(input_axis)
107
+
108
+ scaling_factors = torch.reshape(scaling_factors, size2)
109
+ inverse_scaling_factors = torch.reshape(inverse_scaling_factors, size1)
110
+
111
+ weight1 = weight1 * inverse_scaling_factors
112
+ weight2 = weight2 * scaling_factors
113
+
114
+ set_module_weight(module1, weight1)
115
+ set_module_bias(module1, bias)
116
+ set_module_weight(module2, weight2)
117
+
118
+ def equalize(model, paired_modules_list, threshold=1e-4, inplace=True):
119
+ """Equalize modules until convergence is achieved.
120
+
121
+ Given a list of adjacent modules within a model, equalization will
122
+ be applied between each pair, this will repeated until convergence is achieved
123
+
124
+ Keeps a copy of the changing modules from the previous iteration, if the copies
125
+ are not that different than the current modules (determined by converged_test),
126
+ then the modules have converged enough that further equalizing is not necessary
127
+
128
+ This implementation references section 4.1 of https://arxiv.org/pdf/1906.04721.pdf
129
+
130
+ Args:
131
+ model: a model (nn.module) that equalization is to be applied on
132
+ paired_modules_list: a list of lists where each sublist is a pair of two
133
+ submodules found in the model, for each pair the two submodules generally
134
+ have to be adjacent in the model to get expected/reasonable results
135
+ threshold: a number used by the converged function to determine what degree
136
+ similarity between models is necessary for them to be called equivalent
137
+ inplace: determines if function is inplace or not
138
+ """
139
+ if not inplace:
140
+ model = copy.deepcopy(model)
141
+
142
+ name_to_module : Dict[str, torch.nn.Module] = {}
143
+ previous_name_to_module: Dict[str, Any] = {}
144
+ name_set = {name for pair in paired_modules_list for name in pair}
145
+
146
+ for name, module in model.named_modules():
147
+ if name in name_set:
148
+ name_to_module[name] = module
149
+ previous_name_to_module[name] = None
150
+ while not converged(name_to_module, previous_name_to_module, threshold):
151
+ for pair in paired_modules_list:
152
+ previous_name_to_module[pair[0]] = copy.deepcopy(name_to_module[pair[0]])
153
+ previous_name_to_module[pair[1]] = copy.deepcopy(name_to_module[pair[1]])
154
+
155
+ cross_layer_equalization(name_to_module[pair[0]], name_to_module[pair[1]])
156
+
157
+ return model
158
+
159
+ def converged(curr_modules, prev_modules, threshold=1e-4):
160
+ """Test whether modules are converged to a specified threshold.
161
+
162
+ Tests for the summed norm of the differences between each set of modules
163
+ being less than the given threshold
164
+
165
+ Takes two dictionaries mapping names to modules, the set of names for each dictionary
166
+ should be the same, looping over the set of names, for each name take the difference
167
+ between the associated modules in each dictionary
168
+
169
+ """
170
+ if curr_modules.keys() != prev_modules.keys():
171
+ raise ValueError("The keys to the given mappings must have the same set of names of modules")
172
+
173
+ summed_norms = torch.tensor(0.)
174
+ if None in prev_modules.values():
175
+ return False
176
+ for name in curr_modules.keys():
177
+ curr_weight = get_module_weight(curr_modules[name])
178
+ prev_weight = get_module_weight(prev_modules[name])
179
+
180
+ difference = curr_weight.sub(prev_weight)
181
+ summed_norms += torch.norm(difference)
182
+ return bool(summed_norms < threshold)
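
A small sketch of how `equalize` above might be applied to two adjacent linear layers. The toy model and module names are illustrative, and `_equalize` is a private module whose interface may change:

import torch.nn as nn
from torch.ao.quantization._equalize import equalize

class TwoLayer(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(16, 32)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(32, 8)

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))

model = TwoLayer()

# Rescale fc1's output channels and fc2's input channels toward equal ranges,
# repeating until the weights change by less than the threshold.
equalized = equalize(model, [["fc1", "fc2"]], threshold=1e-4, inplace=False)
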
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/__init__.py ADDED
@@ -0,0 +1,23 @@
1
+ from .backend_config import BackendConfig, BackendPatternConfig, DTypeConfig, DTypeWithConstraints, ObservationType
2
+ from .fbgemm import get_fbgemm_backend_config
3
+ from .native import get_native_backend_config, get_native_backend_config_dict
4
+ from .qnnpack import get_qnnpack_backend_config
5
+ from .tensorrt import get_tensorrt_backend_config, get_tensorrt_backend_config_dict
6
+ from .executorch import get_executorch_backend_config
7
+ from .onednn import get_onednn_backend_config
8
+
9
+ __all__ = [
10
+ "get_fbgemm_backend_config",
11
+ "get_native_backend_config",
12
+ "get_native_backend_config_dict",
13
+ "get_qnnpack_backend_config",
14
+ "get_tensorrt_backend_config",
15
+ "get_tensorrt_backend_config_dict",
16
+ "get_executorch_backend_config",
17
+ "BackendConfig",
18
+ "BackendPatternConfig",
19
+ "DTypeConfig",
20
+ "DTypeWithConstraints",
21
+ "ObservationType",
22
+ "get_onednn_backend_config",
23
+ ]
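
These backend configs are typically handed to the FX graph-mode quantization entry points; a minimal sketch assuming the standard `prepare_fx`/`convert_fx` flow (the toy model and qconfig mapping are illustrative):

import torch
import torch.nn as nn
from torch.ao.quantization import get_default_qconfig_mapping
from torch.ao.quantization.backend_config import get_native_backend_config
from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU()).eval()
example_inputs = (torch.randn(1, 8),)

backend_config = get_native_backend_config()
qconfig_mapping = get_default_qconfig_mapping("fbgemm")

prepared = prepare_fx(model, qconfig_mapping, example_inputs,
                      backend_config=backend_config)
prepared(*example_inputs)  # calibration pass
quantized = convert_fx(prepared, backend_config=backend_config)
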
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (1.08 kB). View file