Add files using upload-large-folder tool
- .gitattributes +3 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/FlowControl.cpython-311-x86_64-linux-gnu.so +3 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Debugger/__pycache__/libpython.cpython-311.pyc +3 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/rich/__pycache__/console.cpython-311.pyc +3 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_tensor_str.py +697 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/intrinsic/modules/fused.py +160 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/intrinsic/quantized/__init__.py +14 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/modules/__pycache__/linear_relu.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/intrinsic/quantized/modules/conv_relu.py +175 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/intrinsic/quantized/modules/linear_relu.py +177 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/qat/dynamic/modules/linear.py +25 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/qat/modules/__init__.py +14 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantizable/__pycache__/__init__.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantizable/modules/activation.py +465 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/__pycache__/functional.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/dynamic/modules/__pycache__/rnn.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/__pycache__/batchnorm.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/__pycache__/conv.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/__pycache__/embedding_ops.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/__pycache__/linear.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/__pycache__/rnn.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/__pycache__/utils.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/activation.py +302 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/conv.py +945 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/dropout.py +27 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/functional_modules.py +249 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/reference/modules/__init__.py +21 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/linear.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/sparse.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/reference/modules/rnn.py +614 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/reference/modules/utils.py +323 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/sparse/quantized/__pycache__/utils.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/sparse/quantized/dynamic/linear.py +139 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/_experimental/__pycache__/__init__.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/_experimental/data_scheduler/__init__.py +5 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/_experimental/data_sparsifier/__pycache__/quantization_utils.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/__init__.py +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/_experimental/data_sparsifier/quantization_utils.py +130 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/FPGM_pruner.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/__init__.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/_experimental/pruner/saliency_pruner.py +29 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/scheduler/__init__.py +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/scheduler/__pycache__/base_scheduler.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/scheduler/lambda_scheduler.py +47 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/sparsifier/__pycache__/__init__.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/sparsifier/__pycache__/base_sparsifier.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/sparsifier/__pycache__/weight_norm_sparsifier.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/_equalize.py +182 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/__init__.py +23 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/__pycache__/__init__.cpython-311.pyc +0 -0
.gitattributes
CHANGED
@@ -66,3 +66,6 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/__pycach
 tuning-competition-baseline/.venv/lib/python3.11/site-packages/mpmath/tests/__pycache__/test_fp.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
 tuning-competition-baseline/.venv/lib/python3.11/site-packages/mpmath/__pycache__/function_docs.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
 tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/distlib/w64-arm.exe filter=lfs diff=lfs merge=lfs -text
+tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/rich/__pycache__/console.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
+tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Debugger/__pycache__/libpython.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
+tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/FlowControl.cpython-311-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text

tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/FlowControl.cpython-311-x86_64-linux-gnu.so
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b2c7b91ab0731f5672d976d4408f3525891b8c4e1d4ed4d403f56d1c141c7f94
+size 688080

tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Debugger/__pycache__/libpython.cpython-311.pyc
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e828bf211daa379b740684868a31081921397805bfc7ef4b41a8572d794eaafb
+size 137864

tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/rich/__pycache__/console.cpython-311.pyc
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1a7da30d6865deaf94e2814884970e99b253843c23d4aa93b1107a23e61de6c1
+size 123664

tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_tensor_str.py
ADDED
@@ -0,0 +1,697 @@
import contextlib
import dataclasses
import math
import textwrap
from typing import Any, Dict, Optional

import torch
from torch import inf


@dataclasses.dataclass
class __PrinterOptions:
    precision: int = 4
    threshold: float = 1000
    edgeitems: int = 3
    linewidth: int = 80
    sci_mode: Optional[bool] = None


PRINT_OPTS = __PrinterOptions()


# We could use **kwargs, but this will give better docs
def set_printoptions(
    precision=None,
    threshold=None,
    edgeitems=None,
    linewidth=None,
    profile=None,
    sci_mode=None,
):
    r"""Set options for printing. Items shamelessly taken from NumPy

    Args:
        precision: Number of digits of precision for floating point output
            (default = 4).
        threshold: Total number of array elements which trigger summarization
            rather than full `repr` (default = 1000).
        edgeitems: Number of array items in summary at beginning and end of
            each dimension (default = 3).
        linewidth: The number of characters per line for the purpose of
            inserting line breaks (default = 80). Thresholded matrices will
            ignore this parameter.
        profile: Sane defaults for pretty printing. Can override with any of
            the above options. (any one of `default`, `short`, `full`)
        sci_mode: Enable (True) or disable (False) scientific notation. If
            None (default) is specified, the value is defined by
            `torch._tensor_str._Formatter`. This value is automatically chosen
            by the framework.

    Example::

        >>> # Limit the precision of elements
        >>> torch.set_printoptions(precision=2)
        >>> torch.tensor([1.12345])
        tensor([1.12])
        >>> # Limit the number of elements shown
        >>> torch.set_printoptions(threshold=5)
        >>> torch.arange(10)
        tensor([0, 1, 2, ..., 7, 8, 9])
        >>> # Restore defaults
        >>> torch.set_printoptions(profile='default')
        >>> torch.tensor([1.12345])
        tensor([1.1235])
        >>> torch.arange(10)
        tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

    """
    if profile is not None:
        if profile == "default":
            PRINT_OPTS.precision = 4
            PRINT_OPTS.threshold = 1000
            PRINT_OPTS.edgeitems = 3
            PRINT_OPTS.linewidth = 80
        elif profile == "short":
            PRINT_OPTS.precision = 2
            PRINT_OPTS.threshold = 1000
            PRINT_OPTS.edgeitems = 2
            PRINT_OPTS.linewidth = 80
        elif profile == "full":
            PRINT_OPTS.precision = 4
            PRINT_OPTS.threshold = inf
            PRINT_OPTS.edgeitems = 3
            PRINT_OPTS.linewidth = 80

    if precision is not None:
        PRINT_OPTS.precision = precision
    if threshold is not None:
        PRINT_OPTS.threshold = threshold
    if edgeitems is not None:
        PRINT_OPTS.edgeitems = edgeitems
    if linewidth is not None:
        PRINT_OPTS.linewidth = linewidth
    PRINT_OPTS.sci_mode = sci_mode


def get_printoptions() -> Dict[str, Any]:
    r"""Gets the current options for printing, as a dictionary that
    can be passed as ``**kwargs`` to set_printoptions().
    """
    return dataclasses.asdict(PRINT_OPTS)


@contextlib.contextmanager
def printoptions(**kwargs):
    r"""Context manager that temporarily changes the print options. Accepted
    arguments are same as :func:`set_printoptions`."""
    old_kwargs = get_printoptions()
    set_printoptions(**kwargs)
    try:
        yield
    finally:
        set_printoptions(**old_kwargs)


def tensor_totype(t):
    dtype = torch.float if t.is_mps else torch.double
    return t.to(dtype=dtype)


class _Formatter:
    def __init__(self, tensor):
        self.floating_dtype = tensor.dtype.is_floating_point
        self.int_mode = True
        self.sci_mode = False
        self.max_width = 1

        with torch.no_grad():
            tensor_view = tensor.reshape(-1)

        if not self.floating_dtype:
            for value in tensor_view:
                value_str = f"{value}"
                self.max_width = max(self.max_width, len(value_str))

        else:
            nonzero_finite_vals = torch.masked_select(
                tensor_view, torch.isfinite(tensor_view) & tensor_view.ne(0)
            )

            if nonzero_finite_vals.numel() == 0:
                # no valid number, do nothing
                return

            # Convert to double for easy calculation. HalfTensor overflows with 1e8, and there's no div() on CPU.
            nonzero_finite_abs = tensor_totype(nonzero_finite_vals.abs())
            nonzero_finite_min = tensor_totype(nonzero_finite_abs.min())
            nonzero_finite_max = tensor_totype(nonzero_finite_abs.max())

            for value in nonzero_finite_vals:
                if value != torch.ceil(value):
                    self.int_mode = False
                    break

            if self.int_mode:
                # in int_mode for floats, all numbers are integers, and we append a decimal to nonfinites
                # to indicate that the tensor is of floating type. add 1 to the len to account for this.
                if (
                    nonzero_finite_max / nonzero_finite_min > 1000.0
                    or nonzero_finite_max > 1.0e8
                ):
                    self.sci_mode = True
                    for value in nonzero_finite_vals:
                        value_str = f"{{:.{PRINT_OPTS.precision}e}}".format(value)
                        self.max_width = max(self.max_width, len(value_str))
                else:
                    for value in nonzero_finite_vals:
                        value_str = f"{value:.0f}"
                        self.max_width = max(self.max_width, len(value_str) + 1)
            else:
                # Check if scientific representation should be used.
                if (
                    nonzero_finite_max / nonzero_finite_min > 1000.0
                    or nonzero_finite_max > 1.0e8
                    or nonzero_finite_min < 1.0e-4
                ):
                    self.sci_mode = True
                    for value in nonzero_finite_vals:
                        value_str = f"{{:.{PRINT_OPTS.precision}e}}".format(value)
                        self.max_width = max(self.max_width, len(value_str))
                else:
                    for value in nonzero_finite_vals:
                        value_str = f"{{:.{PRINT_OPTS.precision}f}}".format(value)
                        self.max_width = max(self.max_width, len(value_str))

        if PRINT_OPTS.sci_mode is not None:
            self.sci_mode = PRINT_OPTS.sci_mode

    def width(self):
        return self.max_width

    def format(self, value):
        if self.floating_dtype:
            if self.sci_mode:
                ret = f"{{:{self.max_width}.{PRINT_OPTS.precision}e}}".format(value)
            elif self.int_mode:
                ret = f"{value:.0f}"
                if not (math.isinf(value) or math.isnan(value)):
                    ret += "."
            else:
                ret = f"{{:.{PRINT_OPTS.precision}f}}".format(value)
        else:
            ret = f"{value}"
        return (self.max_width - len(ret)) * " " + ret


def _scalar_str(self, formatter1, formatter2=None):
    if formatter2 is not None:
        real_str = _scalar_str(self.real, formatter1)
        imag_str = (_scalar_str(self.imag, formatter2) + "j").lstrip()
        # handles negative numbers, +0.0, -0.0
        if imag_str[0] == "+" or imag_str[0] == "-":
            return real_str + imag_str
        else:
            return real_str + "+" + imag_str
    else:
        return formatter1.format(self.item())


def _vector_str(self, indent, summarize, formatter1, formatter2=None):
    # length includes spaces and comma between elements
    element_length = formatter1.width() + 2
    if formatter2 is not None:
        # width for imag_formatter + an extra j for complex
        element_length += formatter2.width() + 1

    elements_per_line = max(
        1, int(math.floor((PRINT_OPTS.linewidth - indent) / (element_length)))
    )

    def _val_formatter(val, formatter1=formatter1, formatter2=formatter2):
        if formatter2 is not None:
            real_str = formatter1.format(val.real)
            imag_str = (formatter2.format(val.imag) + "j").lstrip()
            # handles negative numbers, +0.0, -0.0
            if imag_str[0] == "+" or imag_str[0] == "-":
                return real_str + imag_str
            else:
                return real_str + "+" + imag_str
        else:
            return formatter1.format(val)

    if summarize and not PRINT_OPTS.edgeitems:
        # Deal with edge case that negative zero is zero
        data = ["..."]
    elif summarize and self.size(0) > 2 * PRINT_OPTS.edgeitems:
        data = (
            [_val_formatter(val) for val in self[: PRINT_OPTS.edgeitems].tolist()]
            + [" ..."]
            + [_val_formatter(val) for val in self[-PRINT_OPTS.edgeitems :].tolist()]
        )
    else:
        data = [_val_formatter(val) for val in self.tolist()]

    data_lines = [
        data[i : i + elements_per_line] for i in range(0, len(data), elements_per_line)
    ]
    lines = [", ".join(line) for line in data_lines]
    return "[" + ("," + "\n" + " " * (indent + 1)).join(lines) + "]"


# formatter2 is only used for printing complex tensors.
# For complex tensors, formatter1 and formatter2 are the formatters for tensor.real
# and tensor.imag respesectively
def _tensor_str_with_formatter(self, indent, summarize, formatter1, formatter2=None):
    dim = self.dim()

    if dim == 0:
        return _scalar_str(self, formatter1, formatter2)

    if dim == 1:
        return _vector_str(self, indent, summarize, formatter1, formatter2)

    if summarize and self.size(0) > 2 * PRINT_OPTS.edgeitems:
        slices = (
            [
                _tensor_str_with_formatter(
                    self[i], indent + 1, summarize, formatter1, formatter2
                )
                for i in range(0, PRINT_OPTS.edgeitems)
            ]
            + ["..."]
            + [
                _tensor_str_with_formatter(
                    self[i], indent + 1, summarize, formatter1, formatter2
                )
                for i in range(len(self) - PRINT_OPTS.edgeitems, len(self))
            ]
        )
    else:
        slices = [
            _tensor_str_with_formatter(
                self[i], indent + 1, summarize, formatter1, formatter2
            )
            for i in range(0, self.size(0))
        ]

    tensor_str = ("," + "\n" * (dim - 1) + " " * (indent + 1)).join(slices)
    return "[" + tensor_str + "]"


def _tensor_str(self, indent):
    if self.numel() == 0:
        return "[]"

    if self.has_names():
        # There are two main codepaths (possibly more) that tensor printing goes through:
        # - tensor data can fit comfortably on screen
        # - tensor data needs to be summarized
        # Some of the codepaths don't fully support named tensors, so we send in
        # an unnamed tensor to the formatting code as a workaround.
        self = self.rename(None)

    summarize = self.numel() > PRINT_OPTS.threshold

    if self._is_zerotensor():
        self = self.clone()

    # handle the negative bit
    if self.is_neg():
        self = self.resolve_neg()

    if self.dtype in [
        torch.float16,
        torch.bfloat16,
        torch.float8_e5m2,
        torch.float8_e5m2fnuz,
        torch.float8_e4m3fn,
        torch.float8_e4m3fnuz,
    ]:
        self = self.float()

    if self.dtype is torch.complex32:
        self = self.cfloat()

    if self.dtype.is_complex:
        # handle the conjugate bit
        self = self.resolve_conj()
        real_formatter = _Formatter(
            get_summarized_data(self.real) if summarize else self.real
        )
        imag_formatter = _Formatter(
            get_summarized_data(self.imag) if summarize else self.imag
        )
        return _tensor_str_with_formatter(
            self, indent, summarize, real_formatter, imag_formatter
        )
    else:
        formatter = _Formatter(get_summarized_data(self) if summarize else self)
        return _tensor_str_with_formatter(self, indent, summarize, formatter)


def _add_suffixes(tensor_str, suffixes, indent, force_newline):
    tensor_strs = [tensor_str]
    last_line_len = len(tensor_str) - tensor_str.rfind("\n") + 1
    for suffix in suffixes:
        suffix_len = len(suffix)
        if force_newline or last_line_len + suffix_len + 2 > PRINT_OPTS.linewidth:
            tensor_strs.append(",\n" + " " * indent + suffix)
            last_line_len = indent + suffix_len
            force_newline = False
        else:
            tensor_strs.append(", " + suffix)
            last_line_len += suffix_len + 2
    tensor_strs.append(")")
    return "".join(tensor_strs)


def get_summarized_data(self):
    dim = self.dim()
    if dim == 0:
        return self
    if dim == 1:
        if self.size(0) > 2 * PRINT_OPTS.edgeitems:
            return torch.cat(
                (self[: PRINT_OPTS.edgeitems], self[-PRINT_OPTS.edgeitems :])
            )
        else:
            return self
    if not PRINT_OPTS.edgeitems:
        return self.new_empty([0] * self.dim())
    elif self.size(0) > 2 * PRINT_OPTS.edgeitems:
        start = [self[i] for i in range(0, PRINT_OPTS.edgeitems)]
        end = [self[i] for i in range(len(self) - PRINT_OPTS.edgeitems, len(self))]
        return torch.stack([get_summarized_data(x) for x in (start + end)])
    else:
        return torch.stack([get_summarized_data(x) for x in self])


def _str_intern(inp, *, tensor_contents=None):
    if torch._C._functorch.is_functorch_wrapped_tensor(inp):
        return _functorch_wrapper_str_intern(inp, tensor_contents=tensor_contents)
    is_plain_tensor = type(inp) is torch.Tensor or type(inp) is torch.nn.Parameter
    if inp.is_nested:
        prefix = "nested_tensor("
    elif is_plain_tensor:
        prefix = "tensor("
    else:
        prefix = f"{type(inp).__name__}("
    indent = len(prefix)
    suffixes = []
    custom_contents_provided = tensor_contents is not None
    if custom_contents_provided:
        tensor_str = tensor_contents

    # This is used to extract the primal value and thus disable the forward AD
    # within this function.
    # TODO(albanD) This needs to be updated when more than one level is supported
    self, tangent = torch.autograd.forward_ad.unpack_dual(inp)

    # Note [Print tensor device]:
    # A general logic here is we only print device when it doesn't match
    # the device specified in default tensor type.
    # Currently torch.set_default_tensor_type() only supports CPU/CUDA, thus
    # torch._C._get_default_device() only returns either cpu or cuda.
    # In other cases, we don't have a way to set them as default yet,
    # and we should always print out device for them.
    if (
        self.device.type != torch._C._get_default_device()
        or (
            self.device.type == "cuda"
            and torch.cuda.current_device() != self.device.index
        )
        or (self.device.type == "mps")
    ):
        suffixes.append("device='" + str(self.device) + "'")

    # Tensor printing performs tensor operations like slice, indexing, etc to make it in a
    # representable format. These operations on ipu/xla/lazy/mtia tensor results in compilations. Hence,
    # to avoid compilations, copying the tensor to cpu before printing.
    if self.device.type in ["xla", "lazy", "ipu", "mtia"]:
        self = self.to("cpu")

    # TODO: add an API to map real -> complex dtypes
    _default_complex_dtype = (
        torch.cdouble if torch.get_default_dtype() == torch.double else torch.cfloat
    )
    has_default_dtype = self.dtype in (
        torch.get_default_dtype(),
        _default_complex_dtype,
        torch.int64,
        torch.bool,
    )
    if self.is_sparse:
        suffixes.append("size=" + str(tuple(self.shape)))
        from torch._subclasses.fake_tensor import FakeTensor

        is_meta = self.is_meta or isinstance(self, FakeTensor)
        if not is_meta:
            suffixes.append("nnz=" + str(self._nnz()))
        if not has_default_dtype:
            suffixes.append("dtype=" + str(self.dtype))
        if not custom_contents_provided:
            indices_prefix = "indices=tensor("
            indices = self._indices().detach()
            if is_meta:
                indices_str = "..."
            else:
                indices_str = _tensor_str(indices, indent + len(indices_prefix))
            if indices.numel() == 0 or is_meta:
                indices_str += ", size=" + str(tuple(indices.shape))
            values_prefix = "values=tensor("
            values = self._values().detach()
            if is_meta:
                values_str = "..."
            else:
                values_str = _tensor_str(values, indent + len(values_prefix))
            if values.numel() == 0 or is_meta:
                values_str += ", size=" + str(tuple(values.shape))
            tensor_str = (
                indices_prefix
                + indices_str
                + "),\n"
                + " " * indent
                + values_prefix
                + values_str
                + ")"
            )
    elif self.layout in {
        torch.sparse_csr,
        torch.sparse_csc,
        torch.sparse_bsr,
        torch.sparse_bsc,
    }:
        from torch._subclasses.fake_tensor import FakeTensor

        suffixes.append("size=" + str(tuple(self.shape)))
        is_meta = self.is_meta or isinstance(self, FakeTensor)
        if not is_meta:
            suffixes.append("nnz=" + str(self._nnz()))
        if not has_default_dtype:
            suffixes.append("dtype=" + str(self.dtype))
        if not custom_contents_provided:
            compressed_indices_method, plain_indices_method = {
                torch.sparse_csr: (torch.Tensor.crow_indices, torch.Tensor.col_indices),
                torch.sparse_csc: (torch.Tensor.ccol_indices, torch.Tensor.row_indices),
                torch.sparse_bsr: (torch.Tensor.crow_indices, torch.Tensor.col_indices),
                torch.sparse_bsc: (torch.Tensor.ccol_indices, torch.Tensor.row_indices),
            }[self.layout]
            if self.layout in {torch.sparse_csr, torch.sparse_bsr}:
                cdimname, pdimname = "row", "column"
            else:
                cdimname, pdimname = "column", "row"
            compressed_indices_prefix = f"c{cdimname[:3]}_indices=tensor("
            compressed_indices = compressed_indices_method(self).detach()
            if is_meta:
                compressed_indices_str = "..."
            else:
                compressed_indices_str = _tensor_str(
                    compressed_indices, indent + len(compressed_indices_prefix)
                )
            if compressed_indices.numel() == 0 or is_meta:
                compressed_indices_str += ", size=" + str(
                    tuple(compressed_indices.shape)
                )
            plain_indices_prefix = f"{pdimname[:3]}_indices=tensor("
            plain_indices = plain_indices_method(self).detach()
            if is_meta:
                plain_indices_str = "..."
            else:
                plain_indices_str = _tensor_str(
                    plain_indices, indent + len(plain_indices_prefix)
                )
            if plain_indices.numel() == 0 or is_meta:
                plain_indices_str += ", size=" + str(tuple(plain_indices.shape))
            values_prefix = "values=tensor("
            values = self.values().detach()
            if is_meta:
                values_str = "..."
            else:
                values_str = _tensor_str(values, indent + len(values_prefix))
            if values.numel() == 0 or is_meta:
                values_str += ", size=" + str(tuple(values.shape))
            tensor_str = (
                compressed_indices_prefix
                + compressed_indices_str
                + "),\n"
                + " " * indent
                + plain_indices_prefix
                + plain_indices_str
                + "),\n"
                + " " * indent
                + values_prefix
                + values_str
                + ")"
            )
    elif self.is_quantized:
        suffixes.append("size=" + str(tuple(self.shape)))
        if not has_default_dtype:
            suffixes.append("dtype=" + str(self.dtype))
        suffixes.append("quantization_scheme=" + str(self.qscheme()))
        if (
            self.qscheme() == torch.per_tensor_affine
            or self.qscheme() == torch.per_tensor_symmetric
        ):
            suffixes.append("scale=" + str(self.q_scale()))
            suffixes.append("zero_point=" + str(self.q_zero_point()))
        elif (
            self.qscheme() == torch.per_channel_affine
            or self.qscheme() == torch.per_channel_symmetric
            or self.qscheme() == torch.per_channel_affine_float_qparams
        ):
            suffixes.append("scale=" + str(self.q_per_channel_scales()))
            suffixes.append("zero_point=" + str(self.q_per_channel_zero_points()))
            suffixes.append("axis=" + str(self.q_per_channel_axis()))
        if not custom_contents_provided:
            tensor_str = _tensor_str(self.dequantize(), indent)
    elif self.is_nested:
        if not custom_contents_provided:

            def indented_str(s, indent):
                return "\n".join(f"  {line}" for line in s.split("\n"))

            strs = ",\n".join(
                indented_str(str(t), indent + 1)
                for t in torch.ops.aten.unbind.int(self, 0)
            )
            tensor_str = f"[\n{strs}\n]"
    elif torch._is_functional_tensor(self):
        prefix = "_to_functional_tensor("
        tensor_str = repr(torch._from_functional_tensor(self))
    else:
        # Circular import problem, so we import it here
        from torch._subclasses.fake_tensor import FakeTensor

        if self.is_meta or isinstance(self, FakeTensor):
            suffixes.append("size=" + str(tuple(self.shape)))
            if self.dtype != torch.get_default_dtype():
                suffixes.append("dtype=" + str(self.dtype))
            # TODO: This implies that ellipses is valid syntax for allocating
            # a meta tensor or FakeTensor, which it could be, but it isn't right now
            if not custom_contents_provided:
                tensor_str = "..."
        else:
            if self.numel() == 0 and not self.is_sparse:
                # Explicitly print the shape if it is not (0,), to match NumPy behavior
                if self.dim() != 1:
                    suffixes.append("size=" + str(tuple(self.shape)))

                # In an empty tensor, there are no elements to infer if the dtype
                # should be int64, so it must be shown explicitly.
                if self.dtype != torch.get_default_dtype():
                    suffixes.append("dtype=" + str(self.dtype))
                if not custom_contents_provided:
                    tensor_str = "[]"
            else:
                if not PRINT_OPTS.edgeitems:
                    suffixes.append("size=" + str(tuple(self.shape)))

                if not has_default_dtype:
                    suffixes.append("dtype=" + str(self.dtype))

                if not custom_contents_provided:
                    if self.layout != torch.strided:
                        tensor_str = _tensor_str(self.to_dense(), indent)
                    else:
                        tensor_str = _tensor_str(self, indent)

    if self.layout != torch.strided:
        suffixes.append("layout=" + str(self.layout))

    # Use inp here to get the original grad_fn and not the one generated by the forward grad
    # unpacking.
    grad_fn_name = None
    try:
        grad_fn = inp.grad_fn
    except RuntimeError:
        # Accessing the grad_fn calls rebasing logic which would cause an error
        # if that tensor is a view created in no-grad mode modified in-place in
        # no-grad mode. See: https://github.com/pytorch/pytorch/issues/99968
        grad_fn_name = "Invalid"

    if grad_fn_name is None and grad_fn is not None:  # type: ignore[possibly-undefined]
        grad_fn_name = type(grad_fn).__name__
        if grad_fn_name == "CppFunction":
            grad_fn_name = grad_fn.name().rsplit("::", 1)[-1]

    if grad_fn_name is not None:
        suffixes.append(f"grad_fn=<{grad_fn_name}>")
    elif inp.requires_grad:
        suffixes.append("requires_grad=True")

    if self.has_names():
        suffixes.append(f"names={self.names}")

    if tangent is not None:
        suffixes.append(f"tangent={tangent}")

    string_repr = _add_suffixes(
        prefix + tensor_str, suffixes, indent, force_newline=self.is_sparse  # type: ignore[possibly-undefined]
    )

    # Check if this instance is flagged as a parameter and change the repr accordingly.
    # Unfortunately, this function has to be aware of this detail.
    # NB: This is currently skipped for plain tensor parameters to maintain BC. In the future,
    # this should be done for those as well to produce a valid repr.
    if isinstance(self, torch.nn.Parameter) and not is_plain_tensor:
        string_repr = f"Parameter({string_repr})"

    return string_repr


def _functorch_wrapper_str_intern(tensor, *, tensor_contents=None):
    level = torch._C._functorch.maybe_get_level(tensor)
    assert level != -1

    if torch._C._functorch.is_functionaltensor(tensor):
        # Since we're unwrapping the FunctionalTensorWrapper, we need to make sure
        # that it's up to date first
        torch._sync(tensor)

    value = torch._C._functorch.get_unwrapped(tensor)
    value_repr = repr(value)

    indented_value_repr = textwrap.indent(value_repr, " " * 4)
    if torch._C._functorch.is_batchedtensor(tensor):
        bdim = torch._C._functorch.maybe_get_bdim(tensor)
        assert bdim != -1
        return (
            f"BatchedTensor(lvl={level}, bdim={bdim}, value=\n"
            f"{indented_value_repr}\n"
            f")"
        )
    if torch._C._functorch.is_gradtrackingtensor(tensor):
        return (
            f"GradTrackingTensor(lvl={level}, value=\n" f"{indented_value_repr}\n" f")"
        )
    if torch._C._functorch.is_functionaltensor(tensor):
        return f"FunctionalTensor(lvl={level}, value=\\\n{value_repr})"

    raise ValueError("We don't know how to print this, please file us an issue")


def _str(self, *, tensor_contents=None):
    with torch.no_grad(), torch.utils._python_dispatch._disable_current_modes():
        guard = torch._C._DisableFuncTorch()
        return _str_intern(self, tensor_contents=tensor_contents)

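The options defined above are driven through the public torch.set_printoptions API, as the docstring's own doctest shows. A minimal usage sketch (standard PyTorch install assumed; the printed values in the comment are approximate):

import torch

x = torch.arange(12, dtype=torch.float32) / 7

# Two decimal places, and summarize anything longer than 5 elements
# (edgeitems=3 keeps the first and last three entries).
torch.set_printoptions(precision=2, threshold=5)
print(x)  # e.g. tensor([0.00, 0.14, 0.29,  ..., 1.29, 1.43, 1.57])

# Restore the defaults described under profile='default' above.
torch.set_printoptions(profile="default")
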
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/intrinsic/modules/fused.py
ADDED
@@ -0,0 +1,160 @@
import torch
from torch.nn import Conv1d, Conv2d, Conv3d, ReLU, Linear, BatchNorm1d, BatchNorm2d, BatchNorm3d
from torch.nn.utils.parametrize import type_before_parametrizations

__all__ = ['ConvReLU1d', 'ConvReLU2d', 'ConvReLU3d', 'LinearReLU', 'ConvBn1d', 'ConvBn2d',
           'ConvBnReLU1d', 'ConvBnReLU2d', 'ConvBn3d', 'ConvBnReLU3d', 'BNReLU2d', 'BNReLU3d',
           'LinearBn1d', 'LinearLeakyReLU', 'LinearTanh', 'ConvAdd2d', 'ConvAddReLU2d']

# Used for identifying intrinsic modules used in quantization
class _FusedModule(torch.nn.Sequential):
    pass

class ConvReLU1d(_FusedModule):
    r"""This is a sequential container which calls the Conv1d and ReLU modules.
    During quantization this will be replaced with the corresponding fused module."""
    def __init__(self, conv, relu):
        assert type_before_parametrizations(conv) == Conv1d and type_before_parametrizations(relu) == ReLU, \
            f'Incorrect types for input modules{type_before_parametrizations(conv)}{type_before_parametrizations(relu)}'
        super().__init__(conv, relu)

class ConvReLU2d(_FusedModule):
    r"""This is a sequential container which calls the Conv2d and ReLU modules.
    During quantization this will be replaced with the corresponding fused module."""
    def __init__(self, conv, relu):
        assert type_before_parametrizations(conv) == Conv2d and type_before_parametrizations(relu) == ReLU, \
            f'Incorrect types for input modules{type_before_parametrizations(conv)}{type_before_parametrizations(relu)}'
        super().__init__(conv, relu)

class ConvReLU3d(_FusedModule):
    r"""This is a sequential container which calls the Conv3d and ReLU modules.
    During quantization this will be replaced with the corresponding fused module."""
    def __init__(self, conv, relu):
        assert type_before_parametrizations(conv) == Conv3d and type_before_parametrizations(relu) == ReLU, \
            f'Incorrect types for input modules{type_before_parametrizations(conv)}{type_before_parametrizations(relu)}'
        super().__init__(conv, relu)

class LinearReLU(_FusedModule):
    r"""This is a sequential container which calls the Linear and ReLU modules.
    During quantization this will be replaced with the corresponding fused module."""
    def __init__(self, linear, relu):
        assert type_before_parametrizations(linear) == Linear and type_before_parametrizations(relu) == ReLU, \
            'Incorrect types for input modules{}{}'.format(
                type_before_parametrizations(linear), type_before_parametrizations(relu))
        super().__init__(linear, relu)

class ConvBn1d(_FusedModule):
    r"""This is a sequential container which calls the Conv 1d and Batch Norm 1d modules.
    During quantization this will be replaced with the corresponding fused module."""
    def __init__(self, conv, bn):
        assert type_before_parametrizations(conv) == Conv1d and type_before_parametrizations(bn) == BatchNorm1d, \
            f'Incorrect types for input modules{type_before_parametrizations(conv)}{type_before_parametrizations(bn)}'
        super().__init__(conv, bn)

class ConvBn2d(_FusedModule):
    r"""This is a sequential container which calls the Conv 2d and Batch Norm 2d modules.
    During quantization this will be replaced with the corresponding fused module."""
    def __init__(self, conv, bn):
        assert type_before_parametrizations(conv) == Conv2d and type_before_parametrizations(bn) == BatchNorm2d, \
            f'Incorrect types for input modules{type_before_parametrizations(conv)}{type_before_parametrizations(bn)}'
        super().__init__(conv, bn)

class ConvBnReLU1d(_FusedModule):
    r"""This is a sequential container which calls the Conv 1d, Batch Norm 1d, and ReLU modules.
    During quantization this will be replaced with the corresponding fused module."""
    def __init__(self, conv, bn, relu):
        assert type_before_parametrizations(conv) == Conv1d and type_before_parametrizations(bn) == BatchNorm1d and \
            type_before_parametrizations(relu) == ReLU, 'Incorrect types for input modules{}{}{}' \
            .format(type_before_parametrizations(conv), type_before_parametrizations(bn), type_before_parametrizations(relu))
        super().__init__(conv, bn, relu)

class ConvBnReLU2d(_FusedModule):
    r"""This is a sequential container which calls the Conv 2d, Batch Norm 2d, and ReLU modules.
    During quantization this will be replaced with the corresponding fused module."""
    def __init__(self, conv, bn, relu):
        assert type_before_parametrizations(conv) == Conv2d and type_before_parametrizations(bn) == BatchNorm2d and \
            type_before_parametrizations(relu) == ReLU, 'Incorrect types for input modules{}{}{}' \
            .format(type_before_parametrizations(conv), type_before_parametrizations(bn), type_before_parametrizations(relu))
        super().__init__(conv, bn, relu)

class ConvBn3d(_FusedModule):
    r"""This is a sequential container which calls the Conv 3d and Batch Norm 3d modules.
    During quantization this will be replaced with the corresponding fused module."""
    def __init__(self, conv, bn):
        assert type_before_parametrizations(conv) == Conv3d and type_before_parametrizations(bn) == BatchNorm3d, \
            f'Incorrect types for input modules{type_before_parametrizations(conv)}{type_before_parametrizations(bn)}'
        super().__init__(conv, bn)

class ConvBnReLU3d(_FusedModule):
    r"""This is a sequential container which calls the Conv 3d, Batch Norm 3d, and ReLU modules.
    During quantization this will be replaced with the corresponding fused module."""
    def __init__(self, conv, bn, relu):
        assert type_before_parametrizations(conv) == Conv3d and type_before_parametrizations(bn) == BatchNorm3d and \
            type_before_parametrizations(relu) == ReLU, 'Incorrect types for input modules{}{}{}' \
            .format(type_before_parametrizations(conv), type_before_parametrizations(bn), type_before_parametrizations(relu))
        super().__init__(conv, bn, relu)


class BNReLU2d(_FusedModule):
    r"""This is a sequential container which calls the BatchNorm 2d and ReLU modules.
    During quantization this will be replaced with the corresponding fused module."""
    def __init__(self, batch_norm, relu):
        assert type_before_parametrizations(batch_norm) == BatchNorm2d and type_before_parametrizations(relu) == ReLU, \
            'Incorrect types for input modules{}{}'.format(
                type_before_parametrizations(batch_norm), type_before_parametrizations(relu))
        super().__init__(batch_norm, relu)

class BNReLU3d(_FusedModule):
    r"""This is a sequential container which calls the BatchNorm 3d and ReLU modules.
    During quantization this will be replaced with the corresponding fused module."""
    def __init__(self, batch_norm, relu):
        assert type_before_parametrizations(batch_norm) == BatchNorm3d and type_before_parametrizations(relu) == ReLU, \
            'Incorrect types for input modules{}{}'.format(
                type_before_parametrizations(batch_norm), type_before_parametrizations(relu))
        super().__init__(batch_norm, relu)


class LinearBn1d(_FusedModule):
    r"""This is a sequential container which calls the Linear and BatchNorm1d modules.
    During quantization this will be replaced with the corresponding fused module."""
    def __init__(self, linear, bn):
        assert type_before_parametrizations(linear) == Linear and type_before_parametrizations(bn) == BatchNorm1d, \
            f'Incorrect types for input modules{type_before_parametrizations(linear)}{type_before_parametrizations(bn)}'
        super().__init__(linear, bn)

class LinearLeakyReLU(_FusedModule):
    r"""This is a sequential container which calls the Linear and LeakyReLU modules.
    During quantization this will be replaced with the corresponding fused module."""
    def __init__(self, linear, leaky_relu):
        assert type(linear) == Linear and type(leaky_relu) == torch.nn.LeakyReLU, \
            f'Incorrect types for input modules{type(linear)}{type(leaky_relu)}'
        super().__init__(linear, leaky_relu)

class LinearTanh(_FusedModule):
    r"""This is a sequential container which calls the Linear and Tanh modules.
    During quantization this will be replaced with the corresponding fused module."""
    def __init__(self, linear, tanh):
        assert type(linear) == Linear and type(tanh) == torch.nn.Tanh, \
            f'Incorrect types for input modules{type(linear)}{type(tanh)}'
        super().__init__(linear, tanh)

class ConvAdd2d(_FusedModule):
    r"""This is a sequential container which calls the Conv2d modules with extra Add.
    During quantization this will be replaced with the corresponding fused module."""
    def __init__(self, conv, add):
        super().__init__(conv)
        self.add = add

    def forward(self, x1, x2):
        return self.add(self[0](x1), x2)

class ConvAddReLU2d(_FusedModule):
    r"""This is a sequential container which calls the Conv2d, add, Relu.
    During quantization this will be replaced with the corresponding fused module."""
    def __init__(self, conv, add, relu):
        super().__init__(conv)
        self.add = add
        self.relu = relu

    def forward(self, x1, x2):
        return self.relu(self.add(self[0](x1), x2))

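These containers are plain torch.nn.Sequential wrappers, so they can be constructed and run eagerly before any quantization pass rewrites them. A minimal sketch using the LinearReLU container defined above (the layer sizes are illustrative):

import torch
from torch import nn
from torch.ao.nn.intrinsic.modules.fused import LinearReLU

# The constructor asserts that the children are exactly nn.Linear and nn.ReLU.
fused = LinearReLU(nn.Linear(16, 8), nn.ReLU())

x = torch.randn(4, 16)
y = fused(x)      # runs Linear then ReLU, like any nn.Sequential
print(y.shape)    # torch.Size([4, 8])
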
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/intrinsic/quantized/__init__.py
ADDED
@@ -0,0 +1,14 @@
from .modules import *  # noqa: F403

__all__ = [
    'BNReLU2d',
    'BNReLU3d',
    'ConvReLU1d',
    'ConvReLU2d',
    'ConvReLU3d',
    'LinearReLU',
    'LinearLeakyReLU',
    'LinearTanh',
    'ConvAdd2d',
    'ConvAddReLU2d',
]

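Since this __init__.py simply re-exports everything from .modules, the fused quantized wrappers can be referenced through the package namespace. A small sanity-check sketch (assuming the surrounding PyTorch installation is importable):

import torch.ao.nn.intrinsic.quantized as nniq

# Names listed in __all__ above should resolve directly on the package.
print(nniq.ConvReLU2d)
print(nniq.LinearReLU)
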
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/modules/__pycache__/linear_relu.cpython-311.pyc
ADDED
Binary file (3.56 kB)

tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/intrinsic/quantized/modules/conv_relu.py
ADDED
@@ -0,0 +1,175 @@

import torch
import torch.ao.nn.intrinsic
import torch.ao.nn.intrinsic.qat
import torch.nn.functional as F
import torch.ao.nn.quantized as nnq

from torch.nn.utils import fuse_conv_bn_weights

__all__ = [
    "ConvReLU1d",
    "ConvReLU2d",
    "ConvReLU3d",
]

_reverse_repeat_padding = nnq.modules.conv._reverse_repeat_padding

# TODO: factor out the common parts to ConvNd
class ConvReLU1d(nnq.Conv1d):
    r"""
    A ConvReLU1d module is a fused module of Conv1d and ReLU

    We adopt the same interface as :class:`torch.ao.nn.quantized.Conv1d`.

    Attributes:
        Same as torch.ao.nn.quantized.Conv1d

    """
    _FLOAT_MODULE = torch.ao.nn.intrinsic.ConvReLU1d  # type: ignore[assignment]

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True,
                 padding_mode='zeros', device=None, dtype=None):
        super().__init__(
            in_channels, out_channels, kernel_size, stride=stride,
            padding=padding, dilation=dilation, groups=groups, bias=bias,
            padding_mode=padding_mode, device=device, dtype=dtype)

    def forward(self, input):
        # Temporarily using len(shape) instead of ndim due to JIT issue
        # https://github.com/pytorch/pytorch/issues/23890
        if len(input.shape) != 3:
            raise ValueError("Input shape must be `(N, C, L)`!")
        if self.padding_mode != 'zeros':
            # Padding in Conv1d is stored as (p, p), need to get (p,)
            _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding[:1])
            input = F.pad(input, _reversed_padding_repeated_twice,
                          mode=self.padding_mode)
        return torch.ops.quantized.conv1d_relu(
            input, self._packed_params, self.scale, self.zero_point)

    def _get_name(self):
        return 'QuantizedConvReLU1d'

    @classmethod
    def from_float(cls, mod):
        if type(mod) == torch.ao.nn.intrinsic.qat.ConvBnReLU1d:
            assert mod.bn.running_var is not None and mod.bn.running_mean is not None
            mod.weight, mod.bias = fuse_conv_bn_weights(
                mod.weight, mod.bias, mod.bn.running_mean, mod.bn.running_var,
                mod.bn.eps, mod.bn.weight, mod.bn.bias)
        return super().from_float(mod)

    @classmethod
    def from_reference(cls, ref_qconv, output_scale, output_zero_point):
        assert type(ref_qconv) != torch.ao.nn.intrinsic.ConvBnReLU1d, \
            "BatchNorm1d should be fused into Conv1d before converting to reference module"
        return super().from_reference(ref_qconv[0], output_scale, output_zero_point)

class ConvReLU2d(nnq.Conv2d):
    r"""
    A ConvReLU2d module is a fused module of Conv2d and ReLU

    We adopt the same interface as :class:`torch.ao.nn.quantized.Conv2d`.

    Attributes:
        Same as torch.ao.nn.quantized.Conv2d

    """
    _FLOAT_MODULE = torch.ao.nn.intrinsic.ConvReLU2d  # type: ignore[assignment]

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True,
                 padding_mode='zeros', device=None, dtype=None):
        super().__init__(
            in_channels, out_channels, kernel_size, stride=stride,
            padding=padding, dilation=dilation, groups=groups, bias=bias,
            padding_mode=padding_mode, device=device, dtype=dtype)

    def forward(self, input):
        # Temporarily using len(shape) instead of ndim due to JIT issue
        # https://github.com/pytorch/pytorch/issues/23890
        if len(input.shape) != 4:
            raise ValueError("Input shape must be `(N, C, H, W)`!")
        if self.padding_mode != 'zeros':
            _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding)
            input = F.pad(input, _reversed_padding_repeated_twice,
                          mode=self.padding_mode)
        return torch.ops.quantized.conv2d_relu(
+
input, self._packed_params, self.scale, self.zero_point)
|
| 101 |
+
|
| 102 |
+
def _get_name(self):
|
| 103 |
+
return 'QuantizedConvReLU2d'
|
| 104 |
+
|
| 105 |
+
@classmethod
|
| 106 |
+
def from_float(cls, mod):
|
| 107 |
+
if type(mod) == torch.ao.nn.intrinsic.qat.ConvBnReLU2d:
|
| 108 |
+
assert mod.bn.running_var is not None and mod.bn.running_mean is not None
|
| 109 |
+
mod.weight, mod.bias = fuse_conv_bn_weights(
|
| 110 |
+
mod.weight, mod.bias, mod.bn.running_mean, mod.bn.running_var,
|
| 111 |
+
mod.bn.eps, mod.bn.weight, mod.bn.bias)
|
| 112 |
+
return super().from_float(mod)
|
| 113 |
+
|
| 114 |
+
@classmethod
|
| 115 |
+
def from_reference(cls, ref_qconv, output_scale, output_zero_point):
|
| 116 |
+
assert type(ref_qconv) != torch.ao.nn.intrinsic.ConvBnReLU2d, \
|
| 117 |
+
"BatchNorm2d should be fused into Conv2d before converting to reference module"
|
| 118 |
+
return super().from_reference(ref_qconv[0], output_scale, output_zero_point)
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
class ConvReLU3d(nnq.Conv3d):
|
| 122 |
+
r"""
|
| 123 |
+
A ConvReLU3d module is a fused module of Conv3d and ReLU
|
| 124 |
+
|
| 125 |
+
We adopt the same interface as :class:`torch.ao.nn.quantized.Conv3d`.
|
| 126 |
+
|
| 127 |
+
Attributes: Same as torch.ao.nn.quantized.Conv3d
|
| 128 |
+
|
| 129 |
+
"""
|
| 130 |
+
_FLOAT_MODULE = torch.ao.nn.intrinsic.ConvReLU3d # type: ignore[assignment]
|
| 131 |
+
|
| 132 |
+
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
|
| 133 |
+
padding=0, dilation=1, groups=1, bias=True,
|
| 134 |
+
padding_mode='zeros', device=None, dtype=None):
|
| 135 |
+
assert padding_mode != 'reflect', "Conv3d does not support reflection padding"
|
| 136 |
+
super().__init__(
|
| 137 |
+
in_channels, out_channels, kernel_size, stride=stride,
|
| 138 |
+
padding=padding, dilation=dilation, groups=groups, bias=bias,
|
| 139 |
+
padding_mode=padding_mode, device=device, dtype=dtype)
|
| 140 |
+
|
| 141 |
+
def forward(self, input):
|
| 142 |
+
# Temporarily using len(shape) instead of ndim due to JIT issue
|
| 143 |
+
# https://github.com/pytorch/pytorch/issues/23890
|
| 144 |
+
if len(input.shape) != 5:
|
| 145 |
+
raise ValueError("Input shape must be `(N, C, D, H, W)`!")
|
| 146 |
+
if self.padding_mode != 'zeros':
|
| 147 |
+
_reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding)
|
| 148 |
+
input = F.pad(input, _reversed_padding_repeated_twice,
|
| 149 |
+
mode=self.padding_mode)
|
| 150 |
+
return torch.ops.quantized.conv3d_relu(
|
| 151 |
+
input, self._packed_params, self.scale, self.zero_point)
|
| 152 |
+
|
| 153 |
+
def _get_name(self):
|
| 154 |
+
return 'QuantizedConvReLU3d'
|
| 155 |
+
|
| 156 |
+
@classmethod
|
| 157 |
+
def from_float(cls, mod):
|
| 158 |
+
if type(mod) == torch.ao.nn.intrinsic.qat.ConvBnReLU3d:
|
| 159 |
+
assert mod.bn.running_var is not None and mod.bn.running_mean is not None
|
| 160 |
+
mod.weight, mod.bias = fuse_conv_bn_weights(
|
| 161 |
+
mod.weight,
|
| 162 |
+
mod.bias,
|
| 163 |
+
mod.bn.running_mean,
|
| 164 |
+
mod.bn.running_var,
|
| 165 |
+
mod.bn.eps,
|
| 166 |
+
mod.bn.weight,
|
| 167 |
+
mod.bn.bias,
|
| 168 |
+
)
|
| 169 |
+
return super().from_float(mod)
|
| 170 |
+
|
| 171 |
+
@classmethod
|
| 172 |
+
def from_reference(cls, ref_qconv, output_scale, output_zero_point):
|
| 173 |
+
assert type(ref_qconv) != torch.ao.nn.intrinsic.ConvBnReLU3d, \
|
| 174 |
+
"BatchNorm3d should be fused into Conv3d before converting to reference module"
|
| 175 |
+
return super().from_reference(ref_qconv[0], output_scale, output_zero_point)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/intrinsic/quantized/modules/linear_relu.py
ADDED
|
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.ao.nn.quantized as nnq
|
| 3 |
+
import torch.ao.nn.intrinsic as nni
|
| 4 |
+
from torch.ao.nn.quantized.modules.utils import _quantize_weight
|
| 5 |
+
|
| 6 |
+
__all__ = [
|
| 7 |
+
"LinearReLU",
|
| 8 |
+
"LinearLeakyReLU",
|
| 9 |
+
"LinearTanh",
|
| 10 |
+
]
|
| 11 |
+
|
| 12 |
+
class LinearReLU(nnq.Linear):
|
| 13 |
+
r"""
|
| 14 |
+
A LinearReLU module fused from Linear and ReLU modules
|
| 15 |
+
|
| 16 |
+
We adopt the same interface as :class:`torch.ao.nn.quantized.Linear`.
|
| 17 |
+
|
| 18 |
+
Attributes:
|
| 19 |
+
Same as torch.ao.nn.quantized.Linear
|
| 20 |
+
|
| 21 |
+
Examples::
|
| 22 |
+
|
| 23 |
+
>>> # xdoctest: +SKIP
|
| 24 |
+
>>> m = nn.intrinsic.LinearReLU(20, 30)
|
| 25 |
+
>>> input = torch.randn(128, 20)
|
| 26 |
+
>>> output = m(input)
|
| 27 |
+
>>> print(output.size())
|
| 28 |
+
torch.Size([128, 30])
|
| 29 |
+
"""
|
| 30 |
+
_FLOAT_MODULE = nni.LinearReLU # type: ignore[assignment]
|
| 31 |
+
|
| 32 |
+
def __init__(self, in_features, out_features, bias=True, dtype=torch.qint8):
|
| 33 |
+
super().__init__(in_features, out_features, bias, dtype)
|
| 34 |
+
|
| 35 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 36 |
+
return torch.ops.quantized.linear_relu(
|
| 37 |
+
x, self._packed_params._packed_params, self.scale, self.zero_point)
|
| 38 |
+
|
| 39 |
+
def _get_name(self):
|
| 40 |
+
return 'QuantizedLinearReLU'
|
| 41 |
+
|
| 42 |
+
@classmethod
|
| 43 |
+
def from_float(cls, mod):
|
| 44 |
+
return super().from_float(mod)
|
| 45 |
+
|
| 46 |
+
@classmethod
|
| 47 |
+
def from_reference(cls, ref_linear_relu, output_scale, output_zero_point):
|
| 48 |
+
return super().from_reference(ref_linear_relu[0], output_scale, output_zero_point)
|
| 49 |
+
|
| 50 |
+
class LinearLeakyReLU(nnq.Linear):
|
| 51 |
+
r"""
|
| 52 |
+
For onednn backend only
|
| 53 |
+
A LinearLeakyReLU module fused from Linear and LeakyReLU modules
|
| 54 |
+
We adopt the same interface as :class:`torch.ao.nn.quantized.Linear`.
|
| 55 |
+
Attributes:
|
| 56 |
+
Same as torch.ao.nn.quantized.Linear
|
| 57 |
+
+ negative_slope
|
| 58 |
+
Examples::
|
| 59 |
+
>>> # xdoctest: +SKIP
|
| 60 |
+
>>> m = nn.intrinsic.LinearLeakyReLU(20, 30, 0.01)
|
| 61 |
+
>>> input = torch.randn(128, 20)
|
| 62 |
+
>>> output = m(input)
|
| 63 |
+
>>> print(output.size())
|
| 64 |
+
torch.Size([128, 30])
|
| 65 |
+
"""
|
| 66 |
+
_FLOAT_MODULE = nni.LinearLeakyReLU # type: ignore[assignment]
|
| 67 |
+
|
| 68 |
+
def __init__(self, in_features, out_features, negative_slope, bias=True, dtype=torch.qint8):
|
| 69 |
+
super().__init__(in_features, out_features, bias, dtype)
|
| 70 |
+
self.negative_slope = negative_slope
|
| 71 |
+
|
| 72 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 73 |
+
return torch.ops.quantized.linear_leaky_relu(
|
| 74 |
+
x, self._packed_params._packed_params, self.scale, self.zero_point, self.negative_slope)
|
| 75 |
+
|
| 76 |
+
def _get_name(self):
|
| 77 |
+
return 'QuantizedLinearLeakyReLU'
|
| 78 |
+
|
| 79 |
+
@classmethod
|
| 80 |
+
def from_float(cls, mod):
|
| 81 |
+
assert type(mod) == nni.LinearLeakyReLU, 'Input float module should be LinearLeakyReLU'
|
| 82 |
+
assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
|
| 83 |
+
activation_post_process = mod.activation_post_process
|
| 84 |
+
leaky_relu = mod[1]
|
| 85 |
+
mod = mod[0]
|
| 86 |
+
weight_post_process = mod.qconfig.weight()
|
| 87 |
+
weight_post_process(mod.weight)
|
| 88 |
+
dtype = weight_post_process.dtype
|
| 89 |
+
act_scale, act_zp = activation_post_process.calculate_qparams() # type: ignore[union-attr,operator]
|
| 90 |
+
assert dtype == torch.qint8, 'Weight observer must have dtype torch.qint8'
|
| 91 |
+
qweight = _quantize_weight(mod.weight.float(), weight_post_process)
|
| 92 |
+
qlinear_leaky_relu = cls(
|
| 93 |
+
mod.in_features,
|
| 94 |
+
mod.out_features,
|
| 95 |
+
leaky_relu.negative_slope,
|
| 96 |
+
dtype=dtype)
|
| 97 |
+
qlinear_leaky_relu.set_weight_bias(qweight, mod.bias)
|
| 98 |
+
qlinear_leaky_relu.scale = float(act_scale)
|
| 99 |
+
qlinear_leaky_relu.zero_point = int(act_zp)
|
| 100 |
+
return qlinear_leaky_relu
|
| 101 |
+
|
| 102 |
+
@classmethod
|
| 103 |
+
def from_reference(cls, ref_mod, output_scale, output_zero_point):
|
| 104 |
+
linear = ref_mod[0]
|
| 105 |
+
leaky_relu = ref_mod[1]
|
| 106 |
+
qlinear_leaky_relu = cls(
|
| 107 |
+
linear.in_features,
|
| 108 |
+
linear.out_features,
|
| 109 |
+
leaky_relu.negative_slope)
|
| 110 |
+
qweight = linear.get_quantized_weight()
|
| 111 |
+
qlinear_leaky_relu.set_weight_bias(qweight, linear.bias)
|
| 112 |
+
qlinear_leaky_relu.scale = float(output_scale)
|
| 113 |
+
qlinear_leaky_relu.zero_point = int(output_zero_point)
|
| 114 |
+
return qlinear_leaky_relu
|
| 115 |
+
|
| 116 |
+
class LinearTanh(nnq.Linear):
|
| 117 |
+
r"""
|
| 118 |
+
A LinearTanh module fused from Linear and Tanh modules
|
| 119 |
+
|
| 120 |
+
We adopt the same interface as :class:`torch.ao.nn.quantized.Linear`.
|
| 121 |
+
|
| 122 |
+
Attributes:
|
| 123 |
+
Same as torch.ao.nn.quantized.Linear
|
| 124 |
+
|
| 125 |
+
Examples::
|
| 126 |
+
|
| 127 |
+
>>> # xdoctest: +SKIP
|
| 128 |
+
>>> m = nn.intrinsic.LinearTanh(20, 30)
|
| 129 |
+
>>> input = torch.randn(128, 20)
|
| 130 |
+
>>> output = m(input)
|
| 131 |
+
>>> print(output.size())
|
| 132 |
+
torch.Size([128, 30])
|
| 133 |
+
"""
|
| 134 |
+
_FLOAT_MODULE = nni.LinearTanh # type: ignore[assignment]
|
| 135 |
+
|
| 136 |
+
def __init__(self, in_features, out_features, bias=True, dtype=torch.qint8):
|
| 137 |
+
super().__init__(in_features, out_features, bias, dtype)
|
| 138 |
+
|
| 139 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 140 |
+
return torch.ops.quantized.linear_tanh(
|
| 141 |
+
x, self._packed_params._packed_params, self.scale, self.zero_point)
|
| 142 |
+
|
| 143 |
+
def _get_name(self):
|
| 144 |
+
return 'QuantizedLinearTanh'
|
| 145 |
+
|
| 146 |
+
@classmethod
|
| 147 |
+
def from_float(cls, mod):
|
| 148 |
+
assert type(mod) == nni.LinearTanh, 'Input float module should be LinearTanh'
|
| 149 |
+
assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
|
| 150 |
+
activation_post_process = mod.activation_post_process
|
| 151 |
+
mod = mod[0]
|
| 152 |
+
weight_post_process = mod.qconfig.weight()
|
| 153 |
+
weight_post_process(mod.weight)
|
| 154 |
+
dtype = weight_post_process.dtype
|
| 155 |
+
act_scale, act_zp = activation_post_process.calculate_qparams() # type: ignore[union-attr,operator]
|
| 156 |
+
assert dtype == torch.qint8, 'Weight observer must have dtype torch.qint8'
|
| 157 |
+
qweight = _quantize_weight(mod.weight.float(), weight_post_process)
|
| 158 |
+
qlinear_tanh = cls(
|
| 159 |
+
mod.in_features,
|
| 160 |
+
mod.out_features,
|
| 161 |
+
dtype=dtype)
|
| 162 |
+
qlinear_tanh.set_weight_bias(qweight, mod.bias)
|
| 163 |
+
qlinear_tanh.scale = float(act_scale)
|
| 164 |
+
qlinear_tanh.zero_point = int(act_zp)
|
| 165 |
+
return qlinear_tanh
|
| 166 |
+
|
| 167 |
+
@classmethod
|
| 168 |
+
def from_reference(cls, ref_mod, output_scale, output_zero_point):
|
| 169 |
+
linear = ref_mod[0]
|
| 170 |
+
qlinear_tanh = cls(
|
| 171 |
+
linear.in_features,
|
| 172 |
+
linear.out_features)
|
| 173 |
+
qweight = linear.get_quantized_weight()
|
| 174 |
+
qlinear_tanh.set_weight_bias(qweight, linear.bias)
|
| 175 |
+
qlinear_tanh.scale = float(output_scale)
|
| 176 |
+
qlinear_tanh.zero_point = int(output_zero_point)
|
| 177 |
+
return qlinear_tanh
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/qat/dynamic/modules/linear.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
__all__ = ["Linear"]
|
| 4 |
+
|
| 5 |
+
class Linear(torch.ao.nn.qat.Linear):
|
| 6 |
+
r"""
|
| 7 |
+
A linear module attached with FakeQuantize modules for weight,
|
| 8 |
+
used for dynamic quantization aware training.
|
| 9 |
+
|
| 10 |
+
We adopt the same interface as `torch.nn.Linear`, please see
|
| 11 |
+
https://pytorch.org/docs/stable/nn.html#torch.nn.Linear
|
| 12 |
+
for documentation.
|
| 13 |
+
|
| 14 |
+
Similar to `torch.nn.Linear`, with FakeQuantize modules initialized to
|
| 15 |
+
default.
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
def __init__(self, in_features, out_features, bias=True,
|
| 19 |
+
qconfig=None, device=None, dtype=None) -> None:
|
| 20 |
+
super().__init__(in_features, out_features, bias, qconfig, device, dtype)
|
| 21 |
+
if not torch.ao.quantization.qconfig._activation_is_memoryless(qconfig):
|
| 22 |
+
raise ValueError(
|
| 23 |
+
"Dynamic QAT requires a memoryless observer." +
|
| 24 |
+
"This means a MovingAverage observer with averaging constant equal to 1"
|
| 25 |
+
)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/qat/modules/__init__.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .linear import Linear
|
| 2 |
+
from .conv import Conv1d
|
| 3 |
+
from .conv import Conv2d
|
| 4 |
+
from .conv import Conv3d
|
| 5 |
+
from .embedding_ops import EmbeddingBag, Embedding
|
| 6 |
+
|
| 7 |
+
__all__ = [
|
| 8 |
+
"Linear",
|
| 9 |
+
"Conv1d",
|
| 10 |
+
"Conv2d",
|
| 11 |
+
"Conv3d",
|
| 12 |
+
"Embedding",
|
| 13 |
+
"EmbeddingBag",
|
| 14 |
+
]
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantizable/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (256 Bytes). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantizable/modules/activation.py
ADDED
|
@@ -0,0 +1,465 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.jit # this is needed to avoid a circular import
|
| 3 |
+
from torch import nn
|
| 4 |
+
import torch.nn.functional as nnF
|
| 5 |
+
|
| 6 |
+
from torch import Tensor
|
| 7 |
+
from typing import Optional, Tuple
|
| 8 |
+
|
| 9 |
+
import warnings
|
| 10 |
+
|
| 11 |
+
__all__ = [
|
| 12 |
+
"MultiheadAttention"
|
| 13 |
+
]
|
| 14 |
+
|
| 15 |
+
class MultiheadAttention(nn.MultiheadAttention):
|
| 16 |
+
_FLOAT_MODULE = nn.MultiheadAttention
|
| 17 |
+
|
| 18 |
+
r"""Quantizable implementation of the MultiheadAttention.
|
| 19 |
+
|
| 20 |
+
Note::
|
| 21 |
+
Please, refer to :class:`~torch.nn.MultiheadAttention` for more
|
| 22 |
+
information
|
| 23 |
+
|
| 24 |
+
Allows the model to jointly attend to information from different
|
| 25 |
+
representation subspaces.
|
| 26 |
+
See reference: Attention Is All You Need
|
| 27 |
+
|
| 28 |
+
The original MHA module is not quantizable.
|
| 29 |
+
This reimplements it by explicitly instantiating the linear layers.
|
| 30 |
+
|
| 31 |
+
.. math::
|
| 32 |
+
\text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
|
| 33 |
+
\text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)
|
| 34 |
+
|
| 35 |
+
Args:
|
| 36 |
+
embed_dim: total dimension of the model.
|
| 37 |
+
num_heads: parallel attention heads.
|
| 38 |
+
dropout: a Dropout layer on attn_output_weights. Default: 0.0.
|
| 39 |
+
bias: add bias as module parameter. Default: True.
|
| 40 |
+
add_bias_kv: add bias to the key and value sequences at dim=0.
|
| 41 |
+
add_zero_attn: add a new batch of zeros to the key and
|
| 42 |
+
value sequences at dim=1.
|
| 43 |
+
kdim: total number of features in key. Default: None.
|
| 44 |
+
vdim: total number of features in value. Default: None.
|
| 45 |
+
batch_first: If ``True``, then the input and output tensors are provided
|
| 46 |
+
as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
|
| 47 |
+
|
| 48 |
+
Note that if :attr:`kdim` and :attr:`vdim` are None, they will be set
|
| 49 |
+
to :attr:`embed_dim` such that query, key, and value have the same
|
| 50 |
+
number of features.
|
| 51 |
+
|
| 52 |
+
Examples::
|
| 53 |
+
|
| 54 |
+
>>> import torch.ao.nn.quantizable as nnqa
|
| 55 |
+
>>> multihead_attn = nnqa.MultiheadAttention(embed_dim, num_heads)
|
| 56 |
+
>>> attn_output, attn_output_weights = multihead_attn(query, key, value)
|
| 57 |
+
|
| 58 |
+
Note::
|
| 59 |
+
Please, follow the quantization flow to convert the quantizable MHA.
|
| 60 |
+
"""
|
| 61 |
+
__constants__ = ['batch_first']
|
| 62 |
+
|
| 63 |
+
def __init__(self, embed_dim: int, num_heads: int,
|
| 64 |
+
dropout: float = 0., bias: bool = True,
|
| 65 |
+
add_bias_kv: bool = False, add_zero_attn: bool = False,
|
| 66 |
+
kdim: Optional[int] = None, vdim: Optional[int] = None, batch_first: bool = False,
|
| 67 |
+
device=None, dtype=None) -> None:
|
| 68 |
+
factory_kwargs = {'device': device, 'dtype': dtype}
|
| 69 |
+
super().__init__(embed_dim, num_heads, dropout,
|
| 70 |
+
bias, add_bias_kv,
|
| 71 |
+
add_zero_attn, kdim, vdim, batch_first,
|
| 72 |
+
**factory_kwargs)
|
| 73 |
+
self.linear_Q = nn.Linear(self.embed_dim, self.embed_dim, bias=bias, **factory_kwargs)
|
| 74 |
+
self.linear_K = nn.Linear(self.kdim, self.embed_dim, bias=bias, **factory_kwargs)
|
| 75 |
+
self.linear_V = nn.Linear(self.vdim, self.embed_dim, bias=bias, **factory_kwargs)
|
| 76 |
+
# for the type: ignore, see https://github.com/pytorch/pytorch/issues/58969
|
| 77 |
+
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=bias, **factory_kwargs) # type: ignore[assignment]
|
| 78 |
+
|
| 79 |
+
# Functionals
|
| 80 |
+
self.q_scaling_product = torch.ao.nn.quantized.FloatFunctional()
|
| 81 |
+
# note: importing torch.ao.nn.quantized at top creates a circular import
|
| 82 |
+
|
| 83 |
+
# Quant/Dequant
|
| 84 |
+
self.quant_attn_output = torch.ao.quantization.QuantStub()
|
| 85 |
+
self.quant_attn_output_weights = torch.ao.quantization.QuantStub()
|
| 86 |
+
self.dequant_q = torch.ao.quantization.DeQuantStub()
|
| 87 |
+
self.dequant_k = torch.ao.quantization.DeQuantStub()
|
| 88 |
+
self.dequant_v = torch.ao.quantization.DeQuantStub()
|
| 89 |
+
|
| 90 |
+
def _get_name(self):
|
| 91 |
+
return 'QuantizableMultiheadAttention'
|
| 92 |
+
|
| 93 |
+
@classmethod
|
| 94 |
+
def from_float(cls, other):
|
| 95 |
+
assert type(other) == cls._FLOAT_MODULE
|
| 96 |
+
assert hasattr(other, 'qconfig'), "The float module must have 'qconfig'"
|
| 97 |
+
# Setting the dropout to 0.0!
|
| 98 |
+
observed = cls(other.embed_dim, other.num_heads, other.dropout,
|
| 99 |
+
(other.in_proj_bias is not None),
|
| 100 |
+
(other.bias_k is not None),
|
| 101 |
+
other.add_zero_attn, other.kdim, other.vdim,
|
| 102 |
+
other.batch_first)
|
| 103 |
+
observed.bias_k = other.bias_k
|
| 104 |
+
observed.bias_v = other.bias_v
|
| 105 |
+
observed.qconfig = other.qconfig
|
| 106 |
+
|
| 107 |
+
# Set the linear weights
|
| 108 |
+
# for the type: ignores, see https://github.com/pytorch/pytorch/issues/58969
|
| 109 |
+
observed.out_proj.weight = other.out_proj.weight # type: ignore[has-type]
|
| 110 |
+
observed.out_proj.bias = other.out_proj.bias # type: ignore[has-type]
|
| 111 |
+
if other._qkv_same_embed_dim:
|
| 112 |
+
# Use separate params
|
| 113 |
+
bias = other.in_proj_bias
|
| 114 |
+
_start = 0
|
| 115 |
+
_end = _start + other.embed_dim
|
| 116 |
+
weight = other.in_proj_weight[_start:_end, :]
|
| 117 |
+
if bias is not None:
|
| 118 |
+
bias = torch.nn.Parameter(bias[_start:_end], bias.requires_grad)
|
| 119 |
+
observed.linear_Q.weight = torch.nn.Parameter(weight,
|
| 120 |
+
weight.requires_grad)
|
| 121 |
+
observed.linear_Q.bias = bias
|
| 122 |
+
|
| 123 |
+
bias = other.in_proj_bias
|
| 124 |
+
_start = _end
|
| 125 |
+
_end = _start + other.embed_dim
|
| 126 |
+
weight = other.in_proj_weight[_start:_end, :]
|
| 127 |
+
if bias is not None:
|
| 128 |
+
bias = torch.nn.Parameter(bias[_start:_end], bias.requires_grad)
|
| 129 |
+
observed.linear_K.weight = torch.nn.Parameter(weight,
|
| 130 |
+
weight.requires_grad)
|
| 131 |
+
observed.linear_K.bias = bias
|
| 132 |
+
|
| 133 |
+
bias = other.in_proj_bias
|
| 134 |
+
_start = _end
|
| 135 |
+
weight = other.in_proj_weight[_start:, :]
|
| 136 |
+
if bias is not None:
|
| 137 |
+
bias = torch.nn.Parameter(bias[_start:], bias.requires_grad)
|
| 138 |
+
observed.linear_V.weight = torch.nn.Parameter(weight,
|
| 139 |
+
weight.requires_grad)
|
| 140 |
+
observed.linear_V.bias = bias
|
| 141 |
+
else:
|
| 142 |
+
observed.linear_Q.weight = nn.Parameter(other.q_proj_weight)
|
| 143 |
+
observed.linear_K.weight = nn.Parameter(other.k_proj_weight)
|
| 144 |
+
observed.linear_V.weight = nn.Parameter(other.v_proj_weight)
|
| 145 |
+
if other.in_proj_bias is None:
|
| 146 |
+
observed.linear_Q.bias = None # type: ignore[assignment]
|
| 147 |
+
observed.linear_K.bias = None # type: ignore[assignment]
|
| 148 |
+
observed.linear_V.bias = None # type: ignore[assignment]
|
| 149 |
+
else:
|
| 150 |
+
observed.linear_Q.bias = nn.Parameter(other.in_proj_bias[0:other.embed_dim])
|
| 151 |
+
observed.linear_K.bias = nn.Parameter(other.in_proj_bias[other.embed_dim:(other.embed_dim * 2)])
|
| 152 |
+
observed.linear_V.bias = nn.Parameter(other.in_proj_bias[(other.embed_dim * 2):])
|
| 153 |
+
observed.eval()
|
| 154 |
+
# Explicit prepare
|
| 155 |
+
observed = torch.ao.quantization.prepare(observed, inplace=True)
|
| 156 |
+
return observed
|
| 157 |
+
|
| 158 |
+
@torch.jit.unused
|
| 159 |
+
def dequantize(self):
|
| 160 |
+
r"""Utility to convert the quantized MHA back to float.
|
| 161 |
+
|
| 162 |
+
The motivation for this is that it is not trivial to conver the weights
|
| 163 |
+
from the format that is used in the quantized version back to the
|
| 164 |
+
float.
|
| 165 |
+
"""
|
| 166 |
+
fp = self._FLOAT_MODULE(self.embed_dim, self.num_heads, self.dropout,
|
| 167 |
+
(self.linear_Q._weight_bias()[1] is not None),
|
| 168 |
+
(self.bias_k is not None),
|
| 169 |
+
self.add_zero_attn, self.kdim, self.vdim, self.batch_first)
|
| 170 |
+
assert fp._qkv_same_embed_dim == self._qkv_same_embed_dim
|
| 171 |
+
if self.bias_k is not None:
|
| 172 |
+
fp.bias_k = nn.Parameter(self.bias_k.dequantize())
|
| 173 |
+
if self.bias_v is not None:
|
| 174 |
+
fp.bias_v = nn.Parameter(self.bias_v.dequantize())
|
| 175 |
+
|
| 176 |
+
# Set the linear weights
|
| 177 |
+
# Note: Because the linear layers are quantized, mypy does not nkow how
|
| 178 |
+
# to deal with them -- might need to ignore the typing checks.
|
| 179 |
+
# for the type: ignore[has-type], see https://github.com/pytorch/pytorch/issues/58969
|
| 180 |
+
w, b = self.out_proj._weight_bias() # type: ignore[operator, has-type]
|
| 181 |
+
fp.out_proj.weight = nn.Parameter(w.dequantize())
|
| 182 |
+
if b is not None:
|
| 183 |
+
fp.out_proj.bias = nn.Parameter(b)
|
| 184 |
+
|
| 185 |
+
wQ, bQ = self.linear_Q._weight_bias() # type: ignore[operator]
|
| 186 |
+
wQ = wQ.dequantize()
|
| 187 |
+
wK, bK = self.linear_K._weight_bias() # type: ignore[operator]
|
| 188 |
+
wK = wK.dequantize()
|
| 189 |
+
wV, bV = self.linear_V._weight_bias() # type: ignore[operator]
|
| 190 |
+
wV = wV.dequantize()
|
| 191 |
+
if fp._qkv_same_embed_dim:
|
| 192 |
+
# Use separate params
|
| 193 |
+
_start = 0
|
| 194 |
+
_end = _start + fp.embed_dim
|
| 195 |
+
fp.in_proj_weight[_start:_end, :] = wQ
|
| 196 |
+
if fp.in_proj_bias is not None:
|
| 197 |
+
assert all(bQ == 0)
|
| 198 |
+
fp.in_proj_bias[_start:_end] = bQ
|
| 199 |
+
|
| 200 |
+
_start = _end
|
| 201 |
+
_end = _start + fp.embed_dim
|
| 202 |
+
fp.in_proj_weight[_start:_end, :] = wK
|
| 203 |
+
if fp.in_proj_bias is not None:
|
| 204 |
+
assert all(bK == 0)
|
| 205 |
+
fp.in_proj_bias[_start:_end] = bK
|
| 206 |
+
|
| 207 |
+
_start = _end
|
| 208 |
+
fp.in_proj_weight[_start:, :] = wV
|
| 209 |
+
if fp.in_proj_bias is not None:
|
| 210 |
+
assert all(bV == 0)
|
| 211 |
+
fp.in_proj_bias[_start:] = bV
|
| 212 |
+
else:
|
| 213 |
+
fp.q_proj_weight = nn.Parameter(wQ)
|
| 214 |
+
fp.k_proj_weight = nn.Parameter(wK)
|
| 215 |
+
fp.v_proj_weight = nn.Parameter(wV)
|
| 216 |
+
if fp.in_proj_bias is None:
|
| 217 |
+
self.linear_Q.bias = None
|
| 218 |
+
self.linear_K.bias = None
|
| 219 |
+
self.linear_V.bias = None
|
| 220 |
+
else:
|
| 221 |
+
fp.in_proj_bias[0:fp.embed_dim] = bQ
|
| 222 |
+
fp.in_proj_bias[fp.embed_dim:(fp.embed_dim * 2)] = bK
|
| 223 |
+
fp.in_proj_bias[(fp.embed_dim * 2):] = bV
|
| 224 |
+
|
| 225 |
+
return fp
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
@classmethod
|
| 229 |
+
def from_observed(cls, other):
|
| 230 |
+
# The whole flow is float -> observed -> quantized
|
| 231 |
+
# This class does float -> observed only
|
| 232 |
+
# See nn.quantized.MultiheadAttention
|
| 233 |
+
raise NotImplementedError("It looks like you are trying to prepare an "
|
| 234 |
+
"MHA module. Please, see "
|
| 235 |
+
"the examples on quantizable MHAs.")
|
| 236 |
+
|
| 237 |
+
def forward(self,
|
| 238 |
+
query: Tensor,
|
| 239 |
+
key: Tensor,
|
| 240 |
+
value: Tensor,
|
| 241 |
+
key_padding_mask: Optional[Tensor] = None,
|
| 242 |
+
need_weights: bool = True,
|
| 243 |
+
attn_mask: Optional[Tensor] = None,
|
| 244 |
+
average_attn_weights: bool = True,
|
| 245 |
+
is_causal: bool = False) -> Tuple[Tensor, Optional[Tensor]]:
|
| 246 |
+
r"""
|
| 247 |
+
Note::
|
| 248 |
+
Please, refer to :func:`~torch.nn.MultiheadAttention.forward` for more
|
| 249 |
+
information
|
| 250 |
+
|
| 251 |
+
Args:
|
| 252 |
+
query, key, value: map a query and a set of key-value pairs to an output.
|
| 253 |
+
See "Attention Is All You Need" for more details.
|
| 254 |
+
key_padding_mask: if provided, specified padding elements in the key will
|
| 255 |
+
be ignored by the attention. When given a binary mask and a value is True,
|
| 256 |
+
the corresponding value on the attention layer will be ignored.
|
| 257 |
+
need_weights: output attn_output_weights.
|
| 258 |
+
attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
|
| 259 |
+
the batches while a 3D mask allows to specify a different mask for the entries of each batch.
|
| 260 |
+
|
| 261 |
+
Shape:
|
| 262 |
+
- Inputs:
|
| 263 |
+
- query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
|
| 264 |
+
the embedding dimension. :math:`(N, L, E)` if ``batch_first`` is ``True``.
|
| 265 |
+
- key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
|
| 266 |
+
the embedding dimension. :math:`(N, S, E)` if ``batch_first`` is ``True``.
|
| 267 |
+
- value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
|
| 268 |
+
the embedding dimension. :math:`(N, S, E)` if ``batch_first`` is ``True``.
|
| 269 |
+
- key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
|
| 270 |
+
If a BoolTensor is provided, the positions with the
|
| 271 |
+
value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
|
| 272 |
+
- attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
|
| 273 |
+
3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
|
| 274 |
+
S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked
|
| 275 |
+
positions. If a BoolTensor is provided, positions with ``True``
|
| 276 |
+
is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
|
| 277 |
+
is provided, it will be added to the attention weight.
|
| 278 |
+
- is_causal: If specified, applies a causal mask as attention mask. Mutually exclusive with providing attn_mask.
|
| 279 |
+
Default: ``False``.
|
| 280 |
+
- average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across
|
| 281 |
+
heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an
|
| 282 |
+
effect when ``need_weights=True.``. Default: True (i.e. average weights across heads)
|
| 283 |
+
|
| 284 |
+
- Outputs:
|
| 285 |
+
- attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
|
| 286 |
+
E is the embedding dimension. :math:`(N, L, E)` if ``batch_first`` is ``True``.
|
| 287 |
+
- attn_output_weights: If ``average_attn_weights=True``, returns attention weights averaged
|
| 288 |
+
across heads of shape :math:`(N, L, S)`, where N is the batch size, L is the target sequence length,
|
| 289 |
+
S is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
|
| 290 |
+
head of shape :math:`(N, num_heads, L, S)`.
|
| 291 |
+
"""
|
| 292 |
+
return self._forward_impl(query, key, value, key_padding_mask,
|
| 293 |
+
need_weights, attn_mask, average_attn_weights,
|
| 294 |
+
is_causal)
|
| 295 |
+
|
| 296 |
+
def _forward_impl(self,
|
| 297 |
+
query: Tensor,
|
| 298 |
+
key: Tensor,
|
| 299 |
+
value: Tensor,
|
| 300 |
+
key_padding_mask: Optional[Tensor] = None,
|
| 301 |
+
need_weights: bool = True,
|
| 302 |
+
attn_mask: Optional[Tensor] = None,
|
| 303 |
+
average_attn_weights: bool = True,
|
| 304 |
+
is_causal: bool = False) -> Tuple[Tensor, Optional[Tensor]]:
|
| 305 |
+
# This version will not deal with the static key/value pairs.
|
| 306 |
+
# Keeping it here for future changes.
|
| 307 |
+
#
|
| 308 |
+
# TODO: This method has some duplicate lines with the
|
| 309 |
+
# `torch.nn.functional.multi_head_attention`. Will need to refactor.
|
| 310 |
+
static_k = None
|
| 311 |
+
static_v = None
|
| 312 |
+
|
| 313 |
+
if attn_mask is not None and is_causal:
|
| 314 |
+
raise AssertionError("Only allow causal mask or attn_mask")
|
| 315 |
+
|
| 316 |
+
if is_causal:
|
| 317 |
+
raise AssertionError("causal mask not supported by AO MHA module")
|
| 318 |
+
|
| 319 |
+
if self.batch_first:
|
| 320 |
+
query, key, value = (x.transpose(0, 1) for x in (query, key, value))
|
| 321 |
+
|
| 322 |
+
tgt_len, bsz, embed_dim_to_check = query.size()
|
| 323 |
+
assert self.embed_dim == embed_dim_to_check
|
| 324 |
+
# allow MHA to have different sizes for the feature dimension
|
| 325 |
+
assert key.size(0) == value.size(0) and key.size(1) == value.size(1)
|
| 326 |
+
|
| 327 |
+
head_dim = self.embed_dim // self.num_heads
|
| 328 |
+
assert head_dim * self.num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
|
| 329 |
+
scaling = float(head_dim) ** -0.5
|
| 330 |
+
|
| 331 |
+
q = self.linear_Q(query)
|
| 332 |
+
k = self.linear_K(key)
|
| 333 |
+
v = self.linear_V(value)
|
| 334 |
+
|
| 335 |
+
q = self.q_scaling_product.mul_scalar(q, scaling)
|
| 336 |
+
|
| 337 |
+
if attn_mask is not None:
|
| 338 |
+
if attn_mask.dtype == torch.uint8:
|
| 339 |
+
warnings.warn("Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
|
| 340 |
+
attn_mask = attn_mask.to(torch.bool)
|
| 341 |
+
assert attn_mask.is_floating_point() or attn_mask.dtype == torch.bool, \
|
| 342 |
+
f'Only float and bool types are supported for attn_mask, not {attn_mask.dtype}'
|
| 343 |
+
|
| 344 |
+
if attn_mask.dim() == 2:
|
| 345 |
+
attn_mask = attn_mask.unsqueeze(0)
|
| 346 |
+
if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
|
| 347 |
+
raise RuntimeError('The size of the 2D attn_mask is not correct.')
|
| 348 |
+
elif attn_mask.dim() == 3:
|
| 349 |
+
if list(attn_mask.size()) != [bsz * self.num_heads, query.size(0), key.size(0)]:
|
| 350 |
+
raise RuntimeError('The size of the 3D attn_mask is not correct.')
|
| 351 |
+
else:
|
| 352 |
+
raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")
|
| 353 |
+
# attn_mask's dim is 3 now.
|
| 354 |
+
|
| 355 |
+
# convert ByteTensor key_padding_mask to bool
|
| 356 |
+
if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
|
| 357 |
+
warnings.warn("Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
|
| 358 |
+
key_padding_mask = key_padding_mask.to(torch.bool)
|
| 359 |
+
if self.bias_k is not None and self.bias_v is not None:
|
| 360 |
+
if static_k is None and static_v is None:
|
| 361 |
+
|
| 362 |
+
# Explicitly assert that bias_k and bias_v are not None
|
| 363 |
+
# in a way that TorchScript can understand.
|
| 364 |
+
bias_k = self.bias_k
|
| 365 |
+
assert bias_k is not None
|
| 366 |
+
bias_v = self.bias_v
|
| 367 |
+
assert bias_v is not None
|
| 368 |
+
|
| 369 |
+
k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
|
| 370 |
+
v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
|
| 371 |
+
if attn_mask is not None:
|
| 372 |
+
attn_mask = nnF.pad(attn_mask, (0, 1))
|
| 373 |
+
if key_padding_mask is not None:
|
| 374 |
+
key_padding_mask = nnF.pad(key_padding_mask, (0, 1))
|
| 375 |
+
else:
|
| 376 |
+
assert static_k is None, "bias cannot be added to static key."
|
| 377 |
+
assert static_v is None, "bias cannot be added to static value."
|
| 378 |
+
else:
|
| 379 |
+
assert self.bias_k is None
|
| 380 |
+
assert self.bias_v is None
|
| 381 |
+
|
| 382 |
+
q = q.contiguous().view(tgt_len, bsz * self.num_heads, head_dim).transpose(0, 1)
|
| 383 |
+
if k is not None:
|
| 384 |
+
k = k.contiguous().view(-1, bsz * self.num_heads, head_dim).transpose(0, 1)
|
| 385 |
+
if v is not None:
|
| 386 |
+
v = v.contiguous().view(-1, bsz * self.num_heads, head_dim).transpose(0, 1)
|
| 387 |
+
|
| 388 |
+
if static_k is not None:
|
| 389 |
+
assert static_k.size(0) == bsz * self.num_heads
|
| 390 |
+
assert static_k.size(2) == head_dim
|
| 391 |
+
k = static_k
|
| 392 |
+
|
| 393 |
+
if static_v is not None:
|
| 394 |
+
assert static_v.size(0) == bsz * self.num_heads
|
| 395 |
+
assert static_v.size(2) == head_dim
|
| 396 |
+
v = static_v
|
| 397 |
+
|
| 398 |
+
src_len = k.size(1)
|
| 399 |
+
|
| 400 |
+
if key_padding_mask is not None:
|
| 401 |
+
assert key_padding_mask.size(0) == bsz
|
| 402 |
+
assert key_padding_mask.size(1) == src_len
|
| 403 |
+
|
| 404 |
+
if self.add_zero_attn:
|
| 405 |
+
src_len += 1
|
| 406 |
+
k_zeros = torch.zeros((k.size(0), 1) + k.size()[2:])
|
| 407 |
+
if k.is_quantized:
|
| 408 |
+
k_zeros = torch.quantize_per_tensor(k_zeros, k.q_scale(), k.q_zero_point(), k.dtype)
|
| 409 |
+
k = torch.cat([k, k_zeros], dim=1)
|
| 410 |
+
v_zeros = torch.zeros((v.size(0), 1) + k.size()[2:])
|
| 411 |
+
if v.is_quantized:
|
| 412 |
+
v_zeros = torch.quantize_per_tensor(v_zeros, v.q_scale(), v.q_zero_point(), v.dtype)
|
| 413 |
+
v = torch.cat([v, v_zeros], dim=1)
|
| 414 |
+
|
| 415 |
+
if attn_mask is not None:
|
| 416 |
+
attn_mask = nnF.pad(attn_mask, (0, 1))
|
| 417 |
+
if key_padding_mask is not None:
|
| 418 |
+
key_padding_mask = nnF.pad(key_padding_mask, (0, 1))
|
| 419 |
+
|
| 420 |
+
# Leaving the quantized zone here
|
| 421 |
+
q = self.dequant_q(q)
|
| 422 |
+
k = self.dequant_k(k)
|
| 423 |
+
v = self.dequant_v(v)
|
| 424 |
+
attn_output_weights = torch.bmm(q, k.transpose(1, 2))
|
| 425 |
+
assert list(attn_output_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
|
| 426 |
+
|
| 427 |
+
if attn_mask is not None:
|
| 428 |
+
if attn_mask.dtype == torch.bool:
|
| 429 |
+
attn_output_weights.masked_fill_(attn_mask, float('-inf'))
|
| 430 |
+
else:
|
| 431 |
+
attn_output_weights += attn_mask
|
| 432 |
+
|
| 433 |
+
if key_padding_mask is not None:
|
| 434 |
+
attn_output_weights = attn_output_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
| 435 |
+
attn_output_weights = attn_output_weights.masked_fill(
|
| 436 |
+
key_padding_mask.unsqueeze(1).unsqueeze(2),
|
| 437 |
+
float('-inf'),
|
| 438 |
+
)
|
| 439 |
+
attn_output_weights = attn_output_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
| 440 |
+
|
| 441 |
+
attn_output_weights = nnF.softmax(
|
| 442 |
+
attn_output_weights, dim=-1)
|
| 443 |
+
attn_output_weights = nnF.dropout(attn_output_weights, p=self.dropout, training=self.training)
|
| 444 |
+
|
| 445 |
+
attn_output = torch.bmm(attn_output_weights, v)
|
| 446 |
+
assert list(attn_output.size()) == [bsz * self.num_heads, tgt_len, head_dim]
|
| 447 |
+
if self.batch_first:
|
| 448 |
+
attn_output = attn_output.view(bsz, tgt_len, self.embed_dim)
|
| 449 |
+
else:
|
| 450 |
+
attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, self.embed_dim)
|
| 451 |
+
|
| 452 |
+
# Reentering the quantized zone
|
| 453 |
+
attn_output = self.quant_attn_output(attn_output)
|
| 454 |
+
# for the type: ignore[has-type], see https://github.com/pytorch/pytorch/issues/58969
|
| 455 |
+
attn_output = self.out_proj(attn_output) # type: ignore[has-type]
|
| 456 |
+
attn_output_weights = self.quant_attn_output_weights(attn_output_weights)
|
| 457 |
+
|
| 458 |
+
if need_weights:
|
| 459 |
+
# average attention weights over heads
|
| 460 |
+
attn_output_weights = attn_output_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
| 461 |
+
if average_attn_weights:
|
| 462 |
+
attn_output_weights = attn_output_weights.mean(dim=1)
|
| 463 |
+
return attn_output, attn_output_weights
|
| 464 |
+
else:
|
| 465 |
+
return attn_output, None
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/__pycache__/functional.cpython-311.pyc
ADDED
|
Binary file (32.6 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/dynamic/modules/__pycache__/rnn.cpython-311.pyc
ADDED
|
Binary file (59 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/__pycache__/batchnorm.cpython-311.pyc
ADDED
|
Binary file (6.44 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/__pycache__/conv.cpython-311.pyc
ADDED
|
Binary file (49.1 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/__pycache__/embedding_ops.cpython-311.pyc
ADDED
|
Binary file (18.1 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/__pycache__/linear.cpython-311.pyc
ADDED
|
Binary file (16.3 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/__pycache__/rnn.cpython-311.pyc
ADDED
|
Binary file (2.6 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/__pycache__/utils.cpython-311.pyc
ADDED
|
Binary file (7.19 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/activation.py
ADDED
|
@@ -0,0 +1,302 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from warnings import warn
|
| 3 |
+
__all__ = [
|
| 4 |
+
"ReLU6",
|
| 5 |
+
"Hardswish",
|
| 6 |
+
"ELU",
|
| 7 |
+
"LeakyReLU",
|
| 8 |
+
"Sigmoid",
|
| 9 |
+
"Softmax",
|
| 10 |
+
"MultiheadAttention",
|
| 11 |
+
"PReLU"
|
| 12 |
+
]
|
| 13 |
+
|
| 14 |
+
class ReLU6(torch.nn.ReLU):
|
| 15 |
+
r"""Applies the element-wise function:
|
| 16 |
+
|
| 17 |
+
:math:`\text{ReLU6}(x) = \min(\max(x_0, x), q(6))`, where :math:`x_0` is the
|
| 18 |
+
zero_point, and :math:`q(6)` is the quantized representation of number 6.
|
| 19 |
+
|
| 20 |
+
Args:
|
| 21 |
+
inplace: can optionally do the operation in-place. Default: ``False``
|
| 22 |
+
|
| 23 |
+
Shape:
|
| 24 |
+
- Input: :math:`(N, *)` where `*` means, any number of additional
|
| 25 |
+
dimensions
|
| 26 |
+
- Output: :math:`(N, *)`, same shape as the input
|
| 27 |
+
|
| 28 |
+
.. image:: ../scripts/activation_images/ReLU6.png
|
| 29 |
+
|
| 30 |
+
Examples::
|
| 31 |
+
|
| 32 |
+
>>> m = nn.quantized.ReLU6()
|
| 33 |
+
>>> input = torch.randn(2)
|
| 34 |
+
>>> # xdoctest: +SKIP
|
| 35 |
+
>>> input = torch.quantize_per_tensor(input, 1.0, 0, dtype=torch.qint32)
|
| 36 |
+
>>> output = m(input)
|
| 37 |
+
"""
|
| 38 |
+
def __init__(self, inplace=False):
|
| 39 |
+
super().__init__(inplace)
|
| 40 |
+
self.inplace = inplace
|
| 41 |
+
|
| 42 |
+
def forward(self, input):
|
| 43 |
+
return torch.ops.quantized.relu6(input, self.inplace)
|
| 44 |
+
|
| 45 |
+
def _get_name(self):
|
| 46 |
+
return 'QuantizedReLU6'
|
| 47 |
+
|
| 48 |
+
@staticmethod
|
| 49 |
+
def from_float(mod):
|
| 50 |
+
return ReLU6(mod.inplace)
|
| 51 |
+
|
| 52 |
+
class Hardswish(torch.nn.Hardswish):
|
| 53 |
+
r"""This is the quantized version of :class:`~torch.nn.Hardswish`.
|
| 54 |
+
|
| 55 |
+
Args:
|
| 56 |
+
scale: quantization scale of the output tensor
|
| 57 |
+
zero_point: quantization zero point of the output tensor
|
| 58 |
+
"""
|
| 59 |
+
def __init__(self, scale, zero_point, device=None, dtype=None):
|
| 60 |
+
factory_kwargs = {'device': device, 'dtype': dtype}
|
| 61 |
+
super().__init__()
|
| 62 |
+
self.register_buffer('scale', torch.tensor(scale, **factory_kwargs))
|
| 63 |
+
self.register_buffer('zero_point', torch.tensor(zero_point, **factory_kwargs))
|
| 64 |
+
|
| 65 |
+
def forward(self, input):
|
| 66 |
+
return torch.ops.quantized.hardswish(input, self.scale, self.zero_point)
|
| 67 |
+
|
| 68 |
+
def _get_name(self):
|
| 69 |
+
return 'QuantizedHardswish'
|
| 70 |
+
|
| 71 |
+
@staticmethod
|
| 72 |
+
def from_float(mod):
|
| 73 |
+
scale, zero_point = mod.activation_post_process.calculate_qparams()
|
| 74 |
+
return Hardswish(float(scale), int(zero_point))
|
| 75 |
+
|
| 76 |
+
@classmethod
|
| 77 |
+
def from_reference(cls, mod, scale, zero_point):
|
| 78 |
+
return cls(float(scale), int(zero_point))
|
| 79 |
+
|
| 80 |
+
class ELU(torch.nn.ELU):
|
| 81 |
+
r"""This is the quantized equivalent of :class:`~torch.nn.ELU`.
|
| 82 |
+
|
| 83 |
+
Args:
|
| 84 |
+
scale: quantization scale of the output tensor
|
| 85 |
+
zero_point: quantization zero point of the output tensor
|
| 86 |
+
alpha: the alpha constant
|
| 87 |
+
"""
|
| 88 |
+
def __init__(self, scale, zero_point, alpha=1.):
|
| 89 |
+
super().__init__(alpha)
|
| 90 |
+
self.scale = scale
|
| 91 |
+
self.zero_point = zero_point
|
| 92 |
+
|
| 93 |
+
def forward(self, input):
|
| 94 |
+
return torch.ao.nn.quantized.functional.elu(
|
| 95 |
+
input, self.scale, self.zero_point, self.alpha)
|
| 96 |
+
|
| 97 |
+
def _get_name(self):
|
| 98 |
+
return 'QuantizedELU'
|
| 99 |
+
|
| 100 |
+
@staticmethod
|
| 101 |
+
def from_float(mod):
|
| 102 |
+
scale, zero_point = mod.activation_post_process.calculate_qparams()
|
| 103 |
+
return ELU(float(scale), int(zero_point), mod.alpha)
|
| 104 |
+
|
| 105 |
+
@classmethod
|
| 106 |
+
def from_reference(cls, mod, scale, zero_point):
|
| 107 |
+
return cls(float(scale), int(zero_point), mod.alpha)
|
| 108 |
+
|
| 109 |
+
class LeakyReLU(torch.nn.LeakyReLU):
|
| 110 |
+
r"""This is the quantized equivalent of :class:`~torch.nn.LeakyReLU`.
|
| 111 |
+
|
| 112 |
+
Args:
|
| 113 |
+
scale: quantization scale of the output tensor
|
| 114 |
+
zero_point: quantization zero point of the output tensor
|
| 115 |
+
negative_slope: Controls the angle of the negative slope. Default: 1e-2
|
| 116 |
+
"""
|
| 117 |
+
def __init__(self, scale: float, zero_point: int, negative_slope: float = 1e-2,
|
| 118 |
+
inplace: bool = False, device=None, dtype=None) -> None:
|
| 119 |
+
factory_kwargs = {'device': device, 'dtype': dtype}
|
| 120 |
+
super().__init__(negative_slope, inplace)
|
| 121 |
+
self.register_buffer('scale', torch.tensor(scale, **factory_kwargs))
|
| 122 |
+
self.register_buffer('zero_point', torch.tensor(zero_point, **factory_kwargs))
|
| 123 |
+
|
| 124 |
+
def forward(self, input):
|
| 125 |
+
return torch.ops.quantized.leaky_relu(
|
| 126 |
+
input, self.negative_slope, self.inplace, self.scale, self.zero_point)
|
| 127 |
+
|
| 128 |
+
def _get_name(self):
|
| 129 |
+
return 'QuantizedLeakyReLU'
|
| 130 |
+
|
| 131 |
+
@classmethod
|
| 132 |
+
def from_float(cls, mod):
|
| 133 |
+
scale, zero_point = mod.activation_post_process.calculate_qparams()
|
| 134 |
+
return cls(float(scale), int(zero_point), mod.negative_slope, mod.inplace)
|
| 135 |
+
|
| 136 |
+
@classmethod
|
| 137 |
+
def from_reference(cls, mod, scale, zero_point):
|
| 138 |
+
return cls(float(scale), int(zero_point), mod.negative_slope, mod.inplace)
|
| 139 |
+
|
| 140 |
+
class Sigmoid(torch.nn.Sigmoid):
|
| 141 |
+
r"""This is the quantized equivalent of :class:`~torch.nn.Sigmoid`.
|
| 142 |
+
|
| 143 |
+
Args:
|
| 144 |
+
scale: quantization scale of the output tensor
|
| 145 |
+
zero_point: quantization zero point of the output tensor
|
| 146 |
+
"""
|
| 147 |
+
|
| 148 |
+
    def __init__(self, output_scale: float, output_zero_point: int):
        super().__init__()
        self.output_scale = output_scale
        self.output_zero_point = output_zero_point

    def forward(self, input):
        return torch.ops.quantized.sigmoid(input, self.output_scale, self.output_zero_point)

    @classmethod
    def from_float(cls, mod):
        output_scale, output_zero_point = mod.activation_post_process.calculate_qparams()
        return cls(float(output_scale), int(output_zero_point))


class Softmax(torch.nn.Softmax):
    r"""This is the quantized version of :class:`~torch.nn.Softmax`.

    Args:
        dim: A dimension along which Softmax will be computed (so every slice along dim will sum to 1).
        scale: quantization scale of the output tensor
        zero_point: quantization zero point of the output tensor
    """
    def __init__(self, dim=None, scale=1.0, zero_point=0):
        super().__init__()
        self.dim = dim
        self.scale = scale
        self.zero_point = zero_point

    def forward(self, input):
        dim = self.dim
        if dim is None:
            stacklevel = 3
            # Note: adding the mypy ignore on _get_softmax_dim seems less bad
            # than making `_get_softmax_dim` an official API.
            dim = torch.nn.functional._get_softmax_dim(  # type: ignore[attr-defined]
                "softmax", input.dim(), stacklevel)
        return torch.ops.quantized.softmax(
            input, dim, self.scale, self.zero_point)

    def _get_name(self):
        return 'QuantizedSoftmax'

    @staticmethod
    def from_float(mod):
        scale, zero_point = mod.activation_post_process.calculate_qparams()
        return Softmax(mod.dim, float(scale), int(zero_point))

    @classmethod
    def from_reference(cls, mod, scale, zero_point):
        return cls(mod.dim, float(scale), int(zero_point))


class MultiheadAttention(torch.ao.nn.quantizable.MultiheadAttention):
    _FLOAT_MODULE = torch.ao.nn.quantizable.MultiheadAttention

    def _get_name(self):
        return "QuantizedMultiheadAttention"

    @classmethod
    def from_float(cls, other):
        # The whole flow is float -> observed -> quantized
        # This class does observed -> quantized only
        raise NotImplementedError("It looks like you are trying to convert a "
                                  "non-observed MHA module. Please, see "
                                  "the examples on quantizable MHAs.")

    @classmethod
    def from_observed(cls, other):
        converted = torch.ao.quantization.convert(other, mapping=None,
                                                  inplace=False,
                                                  remove_qconfig=True,
                                                  convert_custom_config_dict=None)
        converted.__class__ = cls
        # Remove the parameters for the bias_k and bias_v to quantize them
        # TODO: This is a potential source of accuracy drop.
        #       quantized cat takes the scale and zp of the first
        #       element, which might lose the precision in the bias_k
        #       and the bias_v (which are cat'ed with k/v being first).
        if converted.bias_k is not None:
            bias_k = converted._parameters.pop('bias_k')
            sc, zp = torch._choose_qparams_per_tensor(bias_k,
                                                      reduce_range=False)
            bias_k = torch.quantize_per_tensor(bias_k, sc, zp, torch.quint8)
            setattr(converted, 'bias_k', bias_k)  # noqa: B010

        if converted.bias_v is not None:
            bias_v = converted._parameters.pop('bias_v')
            sc, zp = torch._choose_qparams_per_tensor(bias_k,  # type: ignore[possibly-undefined]
                                                      reduce_range=False)
            bias_v = torch.quantize_per_tensor(bias_v, sc, zp, torch.quint8)
            setattr(converted, 'bias_v', bias_v)  # noqa: B010

        del converted.in_proj_weight
        del converted.in_proj_bias

        return converted

class PReLU(torch.nn.Module):
    r"""This is the quantized equivalent of :class:`~torch.nn.PReLU`.

    Args:
        scale: quantization scale of the output tensor
        zero_point: quantization zero point of the output tensor
        num_parameters: number of parameters: 1, or the number of channels at input. Default: 1
    """
    def __init__(self, output_scale: float, output_zero_point: int,
                 num_parameters: int = 1) -> None:
        super().__init__()
        self.num_parameters = num_parameters
        self.scale = output_scale
        self.zero_point = output_zero_point
        w = torch.randn(num_parameters, dtype=torch.float)
        qw = torch.quantize_per_tensor(w, scale=1.0, zero_point=0, dtype=torch.quint8)
        self.set_weight(qw)

    def set_weight(self, w: torch.Tensor) -> None:
        self.weight = w

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return torch.ops.quantized.prelu(input, self.weight, self.scale, self.zero_point)

    def _get_name(self):
        return 'QuantizedPReLU'

    @classmethod
    def from_float(cls, mod):
        scale, zero_point = mod.activation_post_process.calculate_qparams()
        qprelu = cls(float(scale), int(zero_point), mod.num_parameters)
        float_wt = mod.weight.float()
        observer = mod.qconfig.weight()
        observer(float_wt)
        if observer.dtype != torch.quint8:
            warn(
                f"PReLU's weight observer should have dtype quint8 but got {observer.dtype}"
            )
        wt_scale, wt_zp = observer.calculate_qparams()
        qweight = torch.quantize_per_tensor(
            float_wt, float(wt_scale), int(wt_zp), torch.quint8)
        qprelu.set_weight(qweight)
        return qprelu

    @classmethod
    def from_reference(cls, mod, scale, zero_point):
        qprelu = cls(float(scale), int(zero_point), mod.num_parameters)
        float_wt = mod.weight.float()
        observer = mod.qconfig.weight()
        observer(float_wt)
        if observer.dtype != torch.quint8:
            warn(
                f"PReLU's weight observer should have dtype quint8 but got {observer.dtype}"
            )
        wt_scale, wt_zp = observer.calculate_qparams()
        qweight = torch.quantize_per_tensor(
            float_wt, float(wt_scale), int(wt_zp), torch.quint8)
        qprelu.set_weight(qweight)
        return qprelu

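A minimal usage sketch (not part of the diff above) for the quantized activation modules defined in this file; the scale/zero_point values and the `torch.ao.nn.quantized` import path are illustrative assumptions, and running it needs a quantized engine such as fbgemm or qnnpack:

    import torch
    import torch.ao.nn.quantized as nnq

    x = torch.randn(2, 5)
    qx = torch.quantize_per_tensor(x, scale=0.1, zero_point=0, dtype=torch.quint8)
    softmax = nnq.Softmax(dim=1, scale=1.0 / 256.0, zero_point=0)  # output qparams chosen by hand here
    out = softmax(qx)  # quantized tensor carrying the configured output scale/zero_point
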
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/conv.py
ADDED
@@ -0,0 +1,945 @@
r"""Quantized convolution modules."""

from typing import Optional, List, TypeVar

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.ao.nn.intrinsic as nni
import torch.ao.nn.intrinsic.qat as nniqat

from torch._ops import ops
from torch.nn.common_types import _size_1_t
from torch.nn.modules.utils import _single, _pair, _triple
from torch.nn.utils import fuse_conv_bn_weights

from .utils import _quantize_weight, WeightedQuantizedModule

__all__ = ['Conv1d', 'Conv2d', 'Conv3d', 'ConvTranspose1d', 'ConvTranspose2d', 'ConvTranspose3d']

_SUPPORTED_PADDING = {
    'zeros',
    'reflect'
}


def _reverse_repeat_padding(padding: List[int]) -> List[int]:
    _reversed_padding_repeated_twice: List[int] = []
    N = len(padding)
    for idx in range(N):
        for _ in range(2):
            _reversed_padding_repeated_twice.append(padding[N - idx - 1])
    return _reversed_padding_repeated_twice


class _ConvNd(WeightedQuantizedModule):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True,
                 padding_mode='zeros', device=None, dtype=None):
        # All subclasses have this signature - See PR #49702s
        raise NotImplementedError

    def _init(self, in_channels, out_channels, kernel_size, stride,
              padding, dilation,
              transposed, output_padding,
              groups, bias,
              padding_mode='zeros',
              device=None,
              dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()

        if in_channels % groups != 0:
            raise ValueError('in_channels must be divisible by groups')
        if out_channels % groups != 0:
            raise ValueError('out_channels must be divisible by groups')
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.transposed = transposed
        self.output_padding = output_padding
        self.groups = groups
        if padding_mode not in _SUPPORTED_PADDING:
            raise ValueError(f"'padding_mode' {padding_mode} is not supported by quantized convolution")
        self.padding_mode = padding_mode
        # Initialize as NCHW. set_weight will internally transpose to NHWC.
        if self.transposed:
            weight_shape = [in_channels, out_channels // self.groups]
        else:
            weight_shape = [out_channels, in_channels // self.groups]
        qweight = torch._empty_affine_quantized(
            weight_shape + list(kernel_size),
            scale=1, zero_point=0, dtype=torch.qint8,
            **{k: v for k, v in factory_kwargs.items() if k != 'dtype'})
        bias_float = (
            torch.zeros(out_channels, dtype=torch.float,
                        **{k: v for k, v in factory_kwargs.items() if k != 'dtype'}) if bias else None)

        self.set_weight_bias(qweight, bias_float)
        self.scale = 1.0
        self.zero_point = 0

    def set_weight_bias(self, qweight, bias_float):
        raise NotImplementedError

    def bias(self):
        raise NotImplementedError

    def _weight_bias(self):
        raise NotImplementedError

    def extra_repr(self):
        s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}'
             ', stride={stride}, scale={scale}, zero_point={zero_point}')
        if self.padding != (0,) * len(self.padding):
            s += ', padding={padding}'
        if self.dilation != (1,) * len(self.dilation):
            s += ', dilation={dilation}'
        if self.output_padding != (0,) * len(self.output_padding):
            s += ', output_padding={output_padding}'
        if self.groups != 1:
            s += ', groups={groups}'
        if self.bias() is None:
            s += ', bias=False'
        return s.format(**self.__dict__)

    # ===== Serialization methods =====
    # The special consideration here is that we have to unpack the weights into
    # their regular QTensor form for serialization. Packed weights should not
    # live outside the process in which they were created, rather they should be
    # derived from the QTensor weight.
    #   self
    #   |--- weight : Tensor
    #   |--- bias : Tensor
    #
    # TODO: maybe change to this when https://github.com/pytorch/pytorch/pull/32958 is landed
    #   self
    #   |--- _packed_params : Conv2dPackedParamsBase or Conv3dPackedParamsBase
    def _save_to_state_dict(self, destination, prefix, keep_vars):
        super()._save_to_state_dict(destination, prefix, keep_vars)
        (w, b) = self._weight_bias()
        destination[prefix + 'weight'] = w
        destination[prefix + 'bias'] = b
        destination[prefix + 'scale'] = torch.tensor(self.scale)
        destination[prefix + 'zero_point'] = torch.tensor(self.zero_point)

    @torch.jit.export
    def __getstate__(self):
        (w, b) = self._weight_bias()
        return (
            self.in_channels,
            self.out_channels,
            self.kernel_size,
            self.stride,
            self.padding,
            self.dilation,
            self.transposed,
            self.output_padding,
            self.groups,
            self.padding_mode,
            w,
            b,
            self.scale,
            self.zero_point,
            self.training
        )

    # ===== Deserialization methods =====
    # Counterpart to the serialization methods, we must pack the serialized
    # QTensor weight into its packed format for use by the FBGEMM ops.
    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        self.set_weight_bias(
            state_dict[prefix + 'weight'], state_dict[prefix + 'bias'])
        state_dict.pop(prefix + 'weight')
        state_dict.pop(prefix + 'bias')
        self.scale = float(state_dict[prefix + 'scale'])
        state_dict.pop(prefix + 'scale')
        self.zero_point = int(state_dict[prefix + 'zero_point'])
        state_dict.pop(prefix + 'zero_point')
        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, False, missing_keys,
            unexpected_keys, error_msgs)

    @torch.jit.export
    def __setstate__(self, state):
        self.in_channels = state[0]
        self.out_channels = state[1]
        self.kernel_size = state[2]
        self.stride = state[3]
        self.padding = state[4]
        self.dilation = state[5]
        self.transposed = state[6]
        self.output_padding = state[7]
        self.groups = state[8]
        self.padding_mode = state[9]
        self.set_weight_bias(state[10], state[11])
        self.scale = state[12]
        self.zero_point = state[13]
        self.training = state[14]

    def __deepcopy__(self, memo):
        new_instance = type(self).__new__(type(self))
        torch.nn.Module.__init__(new_instance)
        state = self.__getstate__()
        new_instance.__setstate__(state)
        return new_instance

    def __copy__(self):
        return self.__deepcopy__({})

    @classmethod
    def get_qconv(cls, mod, activation_post_process, weight_post_process=None):
        r"""Creates a qconv object and returns it.
        """
        if weight_post_process is None:
            weight_post_process = mod.qconfig.weight()
        weight_post_process(mod.weight)
        assert weight_post_process.dtype == torch.qint8, \
            'Weight observer must have a dtype of qint8'
        qweight = _quantize_weight(mod.weight.float(), weight_post_process)
        # the __init__ call used is the one from derived classes and not the one from _ConvNd
        qconv = cls(mod.in_channels, mod.out_channels, mod.kernel_size,
                    mod.stride, mod.padding, mod.dilation, mod.groups,
                    mod.bias is not None, mod.padding_mode)
        qconv.set_weight_bias(qweight, mod.bias)
        if activation_post_process is None or activation_post_process.dtype == torch.float:
            return qconv  # dynamic quantization doesn't need scale/zero_point
        else:
            act_scale, act_zp = activation_post_process.calculate_qparams()
            qconv.scale = float(act_scale)
            qconv.zero_point = int(act_zp)
            return qconv

    @staticmethod
    def from_float(cls, mod):
        if hasattr(mod, "weight_fake_quant"):
            # assert type(mod) == cls.__QAT_MODULE, " nnq." + cls.__name__ + \
            #     ".from_float only works for " + cls.__QAT_MODULE.__name__
            if type(mod) == cls._NNIQAT_CONV_BN_MODULE:
                mod.weight, mod.bias = fuse_conv_bn_weights(
                    mod.weight, mod.bias, mod.bn.running_mean, mod.bn.running_var,
                    mod.bn.eps, mod.bn.weight, mod.bn.bias)
            assert hasattr(mod, "activation_post_process"), \
                "Input QAT module must have observer attached"
            weight_post_process = mod.weight_fake_quant
            activation_post_process = mod.activation_post_process
        else:
            assert type(mod) == cls._FLOAT_MODULE, \
                " nnq." + cls.__name__ + ".from_float only works for " + \
                cls._FLOAT_MODULE.__name__ + " but got:" + str(type(mod))
            assert hasattr(mod, "qconfig"), \
                "Input float module must have qconfig defined."
            activation_post_process = None if not hasattr(
                mod, "activation_post_process") else mod.activation_post_process
            if type(mod) in [cls._NNI_CONV_RELU_MODULE, cls._NNI_CONV_ADD_MODULE, cls._NNI_CONV_ADD_RELU_MODULE]:
                mod = mod[0]
            weight_post_process = mod.qconfig.weight()
        return cls.get_qconv(mod, activation_post_process, weight_post_process)

    @classmethod
    def from_reference(cls, ref_qconv, output_scale, output_zero_point):
        r"""Create a (fbgemm/qnnpack) quantized module from a reference quantized module
        Args:
            ref_qconv (Module): a reference quantized module, either produced by torch.ao.quantization
                                utilities or provided by the user
            output_scale (float): scale for output Tensor
            output_zero_point (int): zero point for output Tensor
        """
        qconv = cls(
            ref_qconv.in_channels,
            ref_qconv.out_channels,
            ref_qconv.kernel_size,  # type: ignore[arg-type]
            ref_qconv.stride,  # type: ignore[arg-type]
            ref_qconv.padding,  # type: ignore[arg-type]
            ref_qconv.dilation,  # type: ignore[arg-type]
            ref_qconv.groups,
            ref_qconv.bias is not None,  # type: ignore[arg-type]
            ref_qconv.padding_mode,
            device=ref_qconv.weight.device,
            dtype=ref_qconv.weight.dtype)
        qweight = ref_qconv.get_quantized_weight()
        qconv.set_weight_bias(qweight, ref_qconv.bias)
        qconv.scale = float(output_scale)
        qconv.zero_point = int(output_zero_point)
        return qconv


class Conv1d(_ConvNd):
    r"""Applies a 1D convolution over a quantized input signal composed of
    several quantized input planes.

    For details on input arguments, parameters, and implementation see
    :class:`~torch.nn.Conv1d`.

    .. note::
        Only `zeros` is supported for the :attr:`padding_mode` argument.

    .. note::
        Only `torch.quint8` is supported for the input data type.


    Attributes:
        weight (Tensor): packed tensor derived from the learnable weight
                         parameter.
        scale (Tensor): scalar for the output scale
        zero_point (Tensor): scalar for the output zero point

    See :class:`~torch.nn.Conv1d` for other attributes.

    Examples::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE)
        >>> m = nn.quantized.Conv1d(16, 33, 3, stride=2)
        >>> input = torch.randn(20, 16, 100)
        >>> # quantize input to quint8
        >>> # xdoctest: +SKIP
        >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0,
        ...                                     dtype=torch.quint8)
        >>> output = m(q_input)

    """

    _FLOAT_MODULE = nn.Conv1d
    _NNIQAT_CONV_BN_MODULE = nniqat.ConvBn1d
    _NNI_CONV_RELU_MODULE = nni.ConvReLU1d
    _NNI_CONV_ADD_MODULE: None = None
    _NNI_CONV_ADD_RELU_MODULE: None = None

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: _size_1_t,
                 stride: _size_1_t = 1,
                 padding: _size_1_t = 0,
                 dilation: _size_1_t = 1,
                 groups: int = 1,
                 bias: bool = True,
                 padding_mode: str = 'zeros',
                 device=None,
                 dtype=None):
        factory_kwargs = {'device': device, 'dtype': dtype}
        kernel_size = _single(kernel_size)
        stride = _single(stride)
        padding = padding if isinstance(padding, str) else _single(padding)
        dilation = _single(dilation)

        # Subclasses of _ConvNd needs to call _init rather than __init__. See
        # discussion on PR #49702
        super()._init(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            False, _single(0), groups, bias, padding_mode, **factory_kwargs)

    def _get_name(self):
        return 'QuantizedConv1d'

    def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None:
        if self.padding_mode == 'zeros':
            self._packed_params = torch.ops.quantized.conv1d_prepack(
                w, b, self.stride, self.padding, self.dilation, self.groups)
        else:
            self._packed_params = torch.ops.quantized.conv1d_prepack(
                w, b, self.stride, _pair(0), self.dilation,
                self.groups)

    def _weight_bias(self):
        w, b = torch.ops.quantized.conv1d_unpack(self._packed_params)
        return w, b

    def weight(self):
        return self._weight_bias()[0]

    def bias(self):
        return self._weight_bias()[1]

    def forward(self, input):
        # Temporarily using len(shape) instead of ndim due to JIT issue
        # https://github.com/pytorch/pytorch/issues/23890
        if len(input.shape) != 3:
            raise ValueError("Input shape must be `(N, C, L)`!")
        if self.padding_mode != 'zeros':
            # Padding in Conv1d is stored as (p, p), need to get (p,)
            _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding[:1])
            input = F.pad(input, _reversed_padding_repeated_twice,
                          mode=self.padding_mode)
        return ops.quantized.conv1d(input, self._packed_params, self.scale, self.zero_point)

    @classmethod
    def from_float(cls, mod):
        r"""Creates a quantized module from a float module or qparams_dict.

        Args:
            mod (Module): a float module, either produced by torch.ao.quantization
                          utilities or provided by the user
        """
        return _ConvNd.from_float(cls, mod)


class Conv2d(_ConvNd):
    r"""Applies a 2D convolution over a quantized input signal composed of
    several quantized input planes.

    For details on input arguments, parameters, and implementation see
    :class:`~torch.nn.Conv2d`.

    .. note::
        Only `zeros` is supported for the :attr:`padding_mode` argument.

    .. note::
        Only `torch.quint8` is supported for the input data type.


    Attributes:
        weight (Tensor): packed tensor derived from the learnable weight
                         parameter.
        scale (Tensor): scalar for the output scale
        zero_point (Tensor): scalar for the output zero point

    See :class:`~torch.nn.Conv2d` for other attributes.

    Examples::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE)
        >>> # With square kernels and equal stride
        >>> m = nn.quantized.Conv2d(16, 33, 3, stride=2)
        >>> # non-square kernels and unequal stride and with padding
        >>> m = nn.quantized.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2))
        >>> # non-square kernels and unequal stride and with padding and dilation
        >>> m = nn.quantized.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1))
        >>> input = torch.randn(20, 16, 50, 100)
        >>> # quantize input to quint8
        >>> # xdoctest: +SKIP
        >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8)
        >>> output = m(q_input)

    """
    _FLOAT_MODULE = nn.Conv2d
    _NNIQAT_CONV_BN_MODULE = nniqat.ConvBn2d
    _NNI_CONV_RELU_MODULE = nni.ConvReLU2d
    _NNI_CONV_ADD_MODULE = nni.ConvAdd2d
    _NNI_CONV_ADD_RELU_MODULE = nni.ConvAddReLU2d

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True,
                 padding_mode='zeros', device=None, dtype=None):
        factory_kwargs = {'device': device, 'dtype': dtype}
        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)
        # Subclasses of _ConvNd need to call _init rather than __init__. See
        # discussion on PR #49702
        super()._init(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            False, _pair(0), groups, bias, padding_mode, **factory_kwargs)

    def _get_name(self):
        return 'QuantizedConv2d'

    def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None:
        if self.padding_mode == 'zeros':
            self._packed_params = torch.ops.quantized.conv2d_prepack(
                w, b, self.stride, self.padding, self.dilation, self.groups)
        else:
            self._packed_params = torch.ops.quantized.conv2d_prepack(
                w, b, self.stride, _pair(0), self.dilation, self.groups)

    def _weight_bias(self):
        return self._packed_params.unpack()

    def weight(self):
        return self._weight_bias()[0]

    def bias(self):
        return self._weight_bias()[1]

    def forward(self, input):
        # Temporarily using len(shape) instead of ndim due to JIT issue
        # https://github.com/pytorch/pytorch/issues/23890
        if len(input.shape) != 4:
            raise ValueError("Input shape must be `(N, C, H, W)`!")
        if self.padding_mode != 'zeros':
            _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding)
            input = F.pad(input, _reversed_padding_repeated_twice,
                          mode=self.padding_mode)
        return ops.quantized.conv2d(
            input, self._packed_params, self.scale, self.zero_point)

    @classmethod
    def from_float(cls, mod):
        r"""Creates a quantized module from a float module or qparams_dict.

        Args:
            mod (Module): a float module, either produced by torch.ao.quantization
                          utilities or provided by the user
        """
        return _ConvNd.from_float(cls, mod)


class Conv3d(_ConvNd):
    r"""Applies a 3D convolution over a quantized input signal composed of
    several quantized input planes.

    For details on input arguments, parameters, and implementation see
    :class:`~torch.nn.Conv3d`.

    .. note::
        Only `zeros` is supported for the :attr:`padding_mode` argument.

    .. note::
        Only `torch.quint8` is supported for the input data type.


    Attributes:
        weight (Tensor): packed tensor derived from the learnable weight
                         parameter.
        scale (Tensor): scalar for the output scale
        zero_point (Tensor): scalar for the output zero point

    See :class:`~torch.nn.Conv3d` for other attributes.

    Examples::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE)
        >>> # With square kernels and equal stride
        >>> m = nn.quantized.Conv3d(16, 33, 3, stride=2)
        >>> # non-square kernels and unequal stride and with padding
        >>> m = nn.quantized.Conv3d(16, 33, (3, 5, 5), stride=(1, 2, 2), padding=(1, 2, 2))
        >>> # non-square kernels and unequal stride and with padding and dilation
        >>> m = nn.quantized.Conv3d(16, 33, (3, 5, 5), stride=(1, 2, 2), padding=(1, 2, 2), dilation=(1, 2, 2))
        >>> input = torch.randn(20, 16, 56, 56, 56)
        >>> # quantize input to quint8
        >>> # xdoctest: +SKIP
        >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8)
        >>> output = m(q_input)

    """
    _FLOAT_MODULE = nn.Conv3d
    _NNIQAT_CONV_BN_MODULE = nniqat.ConvBn3d
    _NNI_CONV_RELU_MODULE = nni.ConvReLU3d
    _NNI_CONV_ADD_MODULE: None = None
    _NNI_CONV_ADD_RELU_MODULE: None = None

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True,
                 padding_mode='zeros', device=None, dtype=None):
        assert padding_mode != 'reflect', "Conv3d does not support reflection padding"
        factory_kwargs = {'device': device, 'dtype': dtype}
        kernel_size = _triple(kernel_size)
        stride = _triple(stride)
        padding = _triple(padding)
        dilation = _triple(dilation)
        # Subclasses of _ConvNd need to call _init rather than __init__. See
        # discussion on PR #49702
        super()._init(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            False, _triple(0), groups, bias, padding_mode, **factory_kwargs)

    def _get_name(self):
        return 'QuantizedConv3d'

    def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None:
        if self.padding_mode == 'zeros':
            self._packed_params = torch.ops.quantized.conv3d_prepack(
                w, b, self.stride, self.padding, self.dilation, self.groups)
        else:
            self._packed_params = torch.ops.quantized.conv3d_prepack(
                w, b, self.stride, _triple(0), self.dilation, self.groups)

    def _weight_bias(self):
        return self._packed_params.unpack()

    def weight(self):
        return self._weight_bias()[0]

    def bias(self):
        return self._weight_bias()[1]

    def forward(self, input):
        # Temporarily using len(shape) instead of ndim due to JIT issue
        # https://github.com/pytorch/pytorch/issues/23890
        if len(input.shape) != 5:
            raise ValueError("Input shape must be `(N, C, D, H, W)`!")
        if self.padding_mode != 'zeros':
            _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding)
            input = F.pad(input, _reversed_padding_repeated_twice,
                          mode=self.padding_mode)
        return ops.quantized.conv3d(
            input, self._packed_params, self.scale, self.zero_point)

    @classmethod
    def from_float(cls, mod):
        r"""Creates a quantized module from a float module or qparams_dict.

        Args:
            mod (Module): a float module, either produced by torch.ao.quantization
                          utilities or provided by the user
        """
        return _ConvNd.from_float(cls, mod)

# === Transposed Convolutions ===
MOD = TypeVar('MOD', bound=nn.modules.conv._ConvNd)


class _ConvTransposeNd(_ConvNd):

    _FLOAT_MODULE = MOD

    def __init__(self, in_channels, out_channels, kernel_size, stride,
                 padding, dilation, transposed, output_padding,
                 groups, bias, padding_mode, device=None, dtype=None):
        if padding_mode != 'zeros':
            raise ValueError(f'Only "zeros" padding mode is supported for {self.__class__.__name__}')
        factory_kwargs = {'device': device, 'dtype': dtype}
        # Subclasses of _ConvNd need to call _init rather than __init__. See
        # discussion on PR #49702
        super()._init(
            in_channels, out_channels, kernel_size, stride,
            padding, dilation, transposed, output_padding,
            groups, bias, padding_mode, **factory_kwargs)

    def _input_padding(self, kernel_size: List[int], dilation: List[int], padding: List[int]) -> List[int]:
        res = torch.jit.annotate(List[int], [])
        for kdx in range(len(kernel_size)):
            pad = (dilation[kdx] * (kernel_size[kdx] - 1) - padding[kdx])
            res.append(pad)
        return res

    @classmethod
    def from_float(cls, mod):
        r"""Creates a quantized module from a float module or qparams_dict.
        Args:
            mod (Module): a float module, either produced by torch.ao.quantization
                          utilities or provided by the user
        """
        # derived classes override cls._FLOAT_MODULE attribute
        msg = ' nnq.' + cls.__name__ + '.from_float only works for ' + \
              cls._FLOAT_MODULE.__name__  # type: ignore[attr-defined]
        assert type(mod) == cls._FLOAT_MODULE, msg
        assert hasattr(mod, 'qconfig'), \
            'Input float module must have qconfig defined.'
        weight_post_process = mod.qconfig.weight()
        weight_post_process(mod.weight)
        assert weight_post_process.dtype == torch.qint8, \
            'Weight observer must have a dtype of qint8'
        qweight = _quantize_weight(mod.weight.float(), weight_post_process)
        # the __init__ call used is the one from derived classes and not the one from _ConvTransposeNd
        qconv = cls(mod.in_channels, mod.out_channels, mod.kernel_size,  # type: ignore[call-arg]
                    mod.stride, mod.padding, mod.output_padding, mod.groups,
                    mod.bias is not None, mod.dilation, mod.padding_mode)
        qconv.set_weight_bias(qweight, mod.bias)
        if not hasattr(mod, "activation_post_process") or mod.activation_post_process.dtype == torch.float:
            return qconv  # dynamic quantization doesn't need scale/zero_point
        else:
            act_scale, act_zp = mod.activation_post_process.calculate_qparams()
            qconv.scale = float(act_scale)
            qconv.zero_point = int(act_zp)
            return qconv

    @staticmethod
    def from_reference(cls, ref_qconvt, output_scale, output_zero_point):
        r"""Create a (fbgemm/qnnpack) quantized module from a reference quantized module
        Args:
            ref_qconvt (Module): a reference quantized module, either produced by torch.ao.quantization
                                 utilities or provided by the user
            output_scale (float): scale for output Tensor
            output_zero_point (int): zero point for output Tensor
        """
        qconv = cls(
            ref_qconvt.in_channels,
            ref_qconvt.out_channels,
            ref_qconvt.kernel_size,  # type: ignore[arg-type]
            ref_qconvt.stride,  # type: ignore[arg-type]
            ref_qconvt.padding,  # type: ignore[arg-type]
            ref_qconvt.output_padding,  # type: ignore[arg-type]
            ref_qconvt.groups,
            ref_qconvt.bias is not None,  # type: ignore[arg-type]
            ref_qconvt.dilation,  # type: ignore[arg-type]
            ref_qconvt.padding_mode,
            device=ref_qconvt.weight.device,
            dtype=ref_qconvt.weight.dtype)
        qweight = ref_qconvt.get_quantized_weight()
        qconv.set_weight_bias(qweight, ref_qconvt.bias)
        qconv.scale = float(output_scale)
        qconv.zero_point = int(output_zero_point)
        return qconv


class ConvTranspose1d(_ConvTransposeNd):
    r"""Applies a 1D transposed convolution operator over an input image
    composed of several input planes.
    For details on input arguments, parameters, and implementation see
    :class:`~torch.nn.ConvTranspose1d`.

    .. note:: Currently only the QNNPACK engine is implemented.
        Please, set the `torch.backends.quantized.engine = 'qnnpack'`

    For special notes, please, see :class:`~torch.ao.nn.quantized.Conv1d`

    Attributes:
        weight (Tensor): packed tensor derived from the learnable weight
                         parameter.
        scale (Tensor): scalar for the output scale
        zero_point (Tensor): scalar for the output zero point
    See :class:`~torch.nn.ConvTranspose2d` for other attributes.

    Examples::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE)
        >>> torch.backends.quantized.engine = 'qnnpack'
        >>> from torch.ao.nn import quantized as nnq
        >>> # With square kernels and equal stride
        >>> m = nnq.ConvTranspose1d(16, 33, 3, stride=2)
        >>> # non-square kernels and unequal stride and with padding
        >>> m = nnq.ConvTranspose1d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2))
        >>> input = torch.randn(20, 16, 50)
        >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8)
        >>> output = m(q_input)
        >>> # exact output size can be also specified as an argument
        >>> input = torch.randn(1, 16, 12)
        >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8)
        >>> downsample = nnq.Conv1d(16, 16, 3, stride=2, padding=1)
        >>> upsample = nnq.ConvTranspose1d(16, 16, 3, stride=2, padding=1)
        >>> h = downsample(q_input)
        >>> h.size()
        torch.Size([1, 16, 6])
        >>> # xdoctest: +SKIP("FIXME: output_size is not a parameter)
        >>> output = upsample(h, output_size=input.size())
        >>> output.size()
        torch.Size([1, 16, 12])
    """

    _FLOAT_MODULE = nn.ConvTranspose1d

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, output_padding=0, groups=1, bias=True,
                 dilation=1, padding_mode='zeros', device=None, dtype=None):
        factory_kwargs = {'device': device, 'dtype': dtype}
        kernel_size = _single(kernel_size)
        stride = _single(stride)
        padding = _single(padding)
        dilation = _single(dilation)
        output_padding = _single(output_padding)

        super().__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            True, output_padding, groups, bias, padding_mode, **factory_kwargs)

    def _get_name(self):
        return 'QuantizedConvTranspose1d'

    def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None:
        self._packed_params = torch.ops.quantized.conv_transpose1d_prepack(
            w, b, self.stride, self.padding, self.output_padding, self.dilation,
            self.groups)

    def _weight_bias(self):
        w, b = torch.ops.quantized.conv_transpose1d_unpack(self._packed_params)
        return w, b

    def weight(self):
        (w, _) = self._weight_bias()
        return w

    def bias(self):
        (_, b) = self._weight_bias()
        return b

    def forward(self, input):
        # Temporarily using len(shape) instead of ndim due to JIT issue
        # https://github.com/pytorch/pytorch/issues/23890
        if len(input.shape) != 3:
            raise ValueError("Input shape must be `(N, C, L)`!")
        return torch.ops.quantized.conv_transpose1d(
            input, self._packed_params, self.scale, self.zero_point)

    @classmethod
    def from_reference(cls, ref_qconvt, output_scale, output_zero_point):
        return _ConvTransposeNd.from_reference(cls, ref_qconvt, output_scale, output_zero_point)


class ConvTranspose2d(_ConvTransposeNd):
    r"""Applies a 2D transposed convolution operator over an input image
    composed of several input planes.
    For details on input arguments, parameters, and implementation see
    :class:`~torch.nn.ConvTranspose2d`.

    For special notes, please, see :class:`~torch.ao.nn.quantized.Conv2d`

    Attributes:
        weight (Tensor): packed tensor derived from the learnable weight
                         parameter.
        scale (Tensor): scalar for the output scale
        zero_point (Tensor): scalar for the output zero point
    See :class:`~torch.nn.ConvTranspose2d` for other attributes.

    Examples::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE)
        >>> # QNNPACK or FBGEMM as backend
        >>> torch.backends.quantized.engine = 'qnnpack'
        >>> # With square kernels and equal stride
        >>> import torch.ao.nn.quantized as nnq
        >>> m = nnq.ConvTranspose2d(16, 33, 3, stride=2)
        >>> # non-square kernels and unequal stride and with padding
        >>> m = nnq.ConvTranspose2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2))
        >>> input = torch.randn(20, 16, 50, 100)
        >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8)
        >>> output = m(q_input)
        >>> # exact output size can be also specified as an argument
        >>> input = torch.randn(1, 16, 12, 12)
        >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8)
        >>> downsample = nnq.Conv2d(16, 16, 3, stride=2, padding=1)
        >>> upsample = nnq.ConvTranspose2d(16, 16, 3, stride=2, padding=1)
        >>> h = downsample(q_input)
        >>> h.size()
        torch.Size([1, 16, 6, 6])
        >>> # xdoctest: +SKIP("FIXME: output_size is not a parameter)
        >>> output = upsample(h, output_size=input.size())
        >>> output.size()
        torch.Size([1, 16, 12, 12])
    """

    _FLOAT_MODULE = nn.ConvTranspose2d

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, output_padding=0, groups=1, bias=True,
                 dilation=1, padding_mode='zeros', device=None, dtype=None):
        factory_kwargs = {'device': device, 'dtype': dtype}
        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)
        output_padding = _pair(output_padding)

        super().__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            True, output_padding, groups, bias, padding_mode, **factory_kwargs)

    def _get_name(self):
        return 'QuantizedConvTranspose2d'

    def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None:
        self._packed_params = torch.ops.quantized.conv_transpose2d_prepack(
            w, b, self.stride, self.padding, self.output_padding, self.dilation,
            self.groups)

    def _weight_bias(self):
        w, b = torch.ops.quantized.conv2d_unpack(self._packed_params)
        return w, b

    def weight(self):
        (w, _) = self._weight_bias()
        return w

    def bias(self):
        (_, b) = self._weight_bias()
        return b

    def forward(self, input):
        # Temporarily using len(shape) instead of ndim due to JIT issue
        # https://github.com/pytorch/pytorch/issues/23890
        if len(input.shape) != 4:
            raise ValueError("Input shape must be `(N, C, H, W)`!")
        return ops.quantized.conv_transpose2d(
            input, self._packed_params, self.scale, self.zero_point)

    @classmethod
    def from_reference(cls, ref_qconvt, output_scale, output_zero_point):
        return _ConvTransposeNd.from_reference(cls, ref_qconvt, output_scale, output_zero_point)


class ConvTranspose3d(_ConvTransposeNd):
    r"""Applies a 3D transposed convolution operator over an input image
    composed of several input planes.
    For details on input arguments, parameters, and implementation see
    :class:`~torch.nn.ConvTranspose3d`.

    .. note:: Currently only the FBGEMM engine is implemented.
        Please, set the `torch.backends.quantized.engine = 'fbgemm'`

    For special notes, please, see :class:`~torch.ao.nn.quantized.Conv3d`

    Attributes:
        weight (Tensor): packed tensor derived from the learnable weight
                         parameter.
        scale (Tensor): scalar for the output scale
        zero_point (Tensor): scalar for the output zero point
    See :class:`~torch.nn.ConvTranspose3d` for other attributes.

    Examples::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE)
        >>> torch.backends.quantized.engine = 'fbgemm'
        >>> from torch.ao.nn import quantized as nnq
        >>> # With cubic kernels and equal stride
        >>> m = nnq.ConvTranspose3d(16, 33, 3, stride=2)
        >>> # non-cubic kernels and unequal stride and with padding
        >>> m = nnq.ConvTranspose3d(16, 33, (3, 3, 5), stride=(2, 1, 1), padding=(4, 2, 2))
        >>> input = torch.randn(20, 16, 50, 100, 100)
        >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8)
        >>> output = m(q_input)
        >>> # exact output size can be also specified as an argument
        >>> input = torch.randn(1, 16, 12, 12, 12)
        >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8)
        >>> downsample = nnq.Conv3d(16, 16, 3, stride=2, padding=1)
        >>> upsample = nnq.ConvTranspose3d(16, 16, 3, stride=2, padding=1)
        >>> h = downsample(q_input)
        >>> h.size()
        torch.Size([1, 16, 6, 6, 6])
        >>> # xdoctest: +SKIP("FIXME: output_size is not a parameter)
        >>> output = upsample(h, output_size=input.size())
        >>> output.size()
        torch.Size([1, 16, 12, 12, 12])
    """

    _FLOAT_MODULE = nn.ConvTranspose3d

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, output_padding=0, groups=1, bias=True,
                 dilation=1, padding_mode='zeros', device=None, dtype=None):
        factory_kwargs = {'device': device, 'dtype': dtype}
        kernel_size = _triple(kernel_size)
        stride = _triple(stride)
        padding = _triple(padding)
        dilation = _triple(dilation)
        output_padding = _triple(output_padding)

        super().__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            True, output_padding, groups, bias, padding_mode, **factory_kwargs)

    def _get_name(self):
        return 'QuantizedConvTranspose3d'

    def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None:
        self._packed_params = torch.ops.quantized.conv_transpose3d_prepack(
            w, b, self.stride, self.padding, self.output_padding, self.dilation,
            self.groups)

    def _weight_bias(self):
        w, b = torch.ops.quantized.conv3d_unpack(self._packed_params)
        return w, b

    def weight(self):
        (w, _) = self._weight_bias()
        return w

    def bias(self):
        (_, b) = self._weight_bias()
        return b

    def forward(self, input):
        # Temporarily using len(shape) instead of ndim due to JIT issue
        # https://github.com/pytorch/pytorch/issues/23890
        if len(input.shape) != 5:
            raise ValueError("Input shape must be `(N, C, T, H, W)`!")
        return ops.quantized.conv_transpose3d(
            input, self._packed_params, self.scale, self.zero_point)

    @classmethod
    def from_reference(cls, ref_qconvt, output_scale, output_zero_point):
        return _ConvTransposeNd.from_reference(cls, ref_qconvt, output_scale, output_zero_point)

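The `from_float`/`get_qconv` paths above are normally driven by the eager-mode prepare/convert workflow rather than called directly. A rough sketch of that flow follows; the qconfig choice and the tiny wrapper model are assumptions for illustration only:

    import torch
    import torch.ao.quantization as tq

    model = torch.nn.Sequential(tq.QuantStub(), torch.nn.Conv2d(3, 8, 3), tq.DeQuantStub())
    model.eval()
    model.qconfig = tq.get_default_qconfig('fbgemm')  # assumed backend
    prepared = tq.prepare(model)                      # attaches observers (activation_post_process)
    prepared(torch.randn(1, 3, 32, 32))               # calibration pass
    quantized = tq.convert(prepared)                  # invokes Conv2d.from_float for the conv layer
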
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/dropout.py
ADDED
@@ -0,0 +1,27 @@
import torch

__all__ = ['Dropout']

class Dropout(torch.nn.Dropout):
    r"""This is the quantized equivalent of :class:`~torch.nn.Dropout`.
    And this is a placeholder to enable models where fp32 tensors
    had dropout to work with quantized tensors in train and eval mode.

    Args:
        p: probability of an element to be zeroed
        inplace: can optionally do the operation in-place. Default: ``False``
    """

    def forward(self, input):
        return input

    def _get_name(self):
        return 'QuantizedDropout'

    @classmethod
    def from_float(cls, mod):
        return cls(mod.p, mod.inplace)

    @classmethod
    def from_reference(cls, mod, scale, zero_point):
        return cls(mod.p, mod.inplace)

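Since this quantized Dropout simply returns its input, a converted model keeps its dropout layers without touching the quantized tensors. A tiny illustrative check, assuming the module is reachable as torch.ao.nn.quantized.Dropout:

    import torch
    import torch.ao.nn.quantized as nnq

    qdrop = nnq.Dropout(p=0.5)
    qx = torch.quantize_per_tensor(torch.randn(4), scale=0.1, zero_point=0, dtype=torch.quint8)
    assert qdrop(qx) is qx  # forward is an identity pass-through
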
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/functional_modules.py
ADDED
@@ -0,0 +1,249 @@
from typing import List

import torch
from torch import Tensor
from torch._ops import ops

__all__ = ['FloatFunctional', 'FXFloatFunctional', 'QFunctional']

class FloatFunctional(torch.nn.Module):
    r"""State collector class for float operations.

    The instance of this class can be used instead of the ``torch.`` prefix for
    some operations. See example usage below.

    .. note::

        This class does not provide a ``forward`` hook. Instead, you must use
        one of the underlying functions (e.g. ``add``).

    Examples::

        >>> f_add = FloatFunctional()
        >>> a = torch.tensor(3.0)
        >>> b = torch.tensor(4.0)
        >>> f_add.add(a, b)  # Equivalent to ``torch.add(a, b)``

    Valid operation names:
        - add
        - cat
        - mul
        - add_relu
        - add_scalar
        - mul_scalar
    """
    def __init__(self):
        super().__init__()
        self.activation_post_process = torch.nn.Identity()

    def forward(self, x):
        raise RuntimeError("FloatFunctional is not intended to use the " +
                           "'forward'. Please use the underlying operation")

    r"""Operation equivalent to ``torch.add(Tensor, Tensor)``"""
    def add(self, x: Tensor, y: Tensor) -> Tensor:
        r = torch.add(x, y)
        r = self.activation_post_process(r)
        return r

    r"""Operation equivalent to ``torch.add(Tensor, float)``"""
    def add_scalar(self, x: Tensor, y: float) -> Tensor:
        r = torch.add(x, y)
        # Note: this operation is not observed because the observation is not
        # needed for the quantized op.
        return r

    r"""Operation equivalent to ``torch.mul(Tensor, Tensor)``"""
    def mul(self, x: Tensor, y: Tensor) -> Tensor:
        r = torch.mul(x, y)
        r = self.activation_post_process(r)
        return r

    r"""Operation equivalent to ``torch.mul(Tensor, float)``"""
    def mul_scalar(self, x: Tensor, y: float) -> Tensor:
        r = torch.mul(x, y)
        # Note: this operation is not observed because the observation is not
        # needed for the quantized op.
        return r

    r"""Operation equivalent to ``torch.cat``"""
    def cat(self, x: List[Tensor], dim: int = 0) -> Tensor:
        r = torch.cat(x, dim=dim)
        r = self.activation_post_process(r)
        return r

    r"""Operation equivalent to ``relu(torch.add(x,y))``"""
    def add_relu(self, x: Tensor, y: Tensor) -> Tensor:
        r = torch.add(x, y)
        r = torch.nn.functional.relu(r)
        r = self.activation_post_process(r)
        return r

    r"""Operation equivalent to ``torch.matmul(Tensor, Tensor)``"""
    def matmul(self, x: Tensor, y: Tensor) -> Tensor:
        r = torch.matmul(x, y)
        r = self.activation_post_process(r)
        return r

class FXFloatFunctional(torch.nn.Module):
    r""" module to replace FloatFunctional module before FX graph mode quantization,
    since activation_post_process will be inserted in top level module directly

    Valid operation names:
        - add
        - cat
        - mul
        - add_relu
        - add_scalar
        - mul_scalar
    """
    def forward(self, x):
        raise RuntimeError("FloatFunctional is not intended to use the " +
                           "'forward'. Please use the underlying operation")

    r"""Operation equivalent to ``torch.add(Tensor, Tensor)``"""
    def add(self, x: Tensor, y: Tensor) -> Tensor:
        r = torch.add(x, y)
        return r

    r"""Operation equivalent to ``torch.add(Tensor, float)``"""
    def add_scalar(self, x: Tensor, y: float) -> Tensor:
        r = torch.add(x, y)
        return r

    r"""Operation equivalent to ``torch.mul(Tensor, Tensor)``"""
    def mul(self, x: Tensor, y: Tensor) -> Tensor:
        r = torch.mul(x, y)
        return r

    r"""Operation equivalent to ``torch.mul(Tensor, float)``"""
    def mul_scalar(self, x: Tensor, y: float) -> Tensor:
        r = torch.mul(x, y)
        return r

    r"""Operation equivalent to ``torch.cat``"""
    def cat(self, x: List[Tensor], dim: int = 0) -> Tensor:
|
| 126 |
+
r = torch.cat(x, dim=dim)
|
| 127 |
+
return r
|
| 128 |
+
|
| 129 |
+
r"""Operation equivalent to ``relu(torch.add(x,y))``"""
|
| 130 |
+
def add_relu(self, x: Tensor, y: Tensor) -> Tensor:
|
| 131 |
+
r = torch.add(x, y)
|
| 132 |
+
r = torch.nn.functional.relu(r)
|
| 133 |
+
return r
|
| 134 |
+
|
| 135 |
+
r"""Operation equivalent to ``torch.matmul(Tensor, Tensor)``"""
|
| 136 |
+
def matmul(self, x: Tensor, y: Tensor) -> Tensor:
|
| 137 |
+
r = torch.matmul(x, y)
|
| 138 |
+
return r
|
| 139 |
+
|
| 140 |
+
class QFunctional(torch.nn.Module):
|
| 141 |
+
r"""Wrapper class for quantized operations.
|
| 142 |
+
|
| 143 |
+
The instance of this class can be used instead of the
|
| 144 |
+
``torch.ops.quantized`` prefix. See example usage below.
|
| 145 |
+
|
| 146 |
+
.. note::
|
| 147 |
+
|
| 148 |
+
This class does not provide a ``forward`` hook. Instead, you must use
|
| 149 |
+
one of the underlying functions (e.g. ``add``).
|
| 150 |
+
|
| 151 |
+
Examples::
|
| 152 |
+
|
| 153 |
+
>>> q_add = QFunctional()
|
| 154 |
+
>>> # xdoctest: +SKIP
|
| 155 |
+
>>> a = torch.quantize_per_tensor(torch.tensor(3.0), 1.0, 0, torch.qint32)
|
| 156 |
+
>>> b = torch.quantize_per_tensor(torch.tensor(4.0), 1.0, 0, torch.qint32)
|
| 157 |
+
>>> q_add.add(a, b) # Equivalent to ``torch.ops.quantized.add(a, b, 1.0, 0)``
|
| 158 |
+
|
| 159 |
+
Valid operation names:
|
| 160 |
+
- add
|
| 161 |
+
- cat
|
| 162 |
+
- mul
|
| 163 |
+
- add_relu
|
| 164 |
+
- add_scalar
|
| 165 |
+
- mul_scalar
|
| 166 |
+
"""
|
| 167 |
+
def __init__(self):
|
| 168 |
+
super().__init__()
|
| 169 |
+
self.scale = 1.0
|
| 170 |
+
self.zero_point = 0
|
| 171 |
+
self.activation_post_process = torch.nn.Identity()
|
| 172 |
+
|
| 173 |
+
def _save_to_state_dict(self, destination, prefix, keep_vars):
|
| 174 |
+
super()._save_to_state_dict(destination, prefix, keep_vars)
|
| 175 |
+
destination[prefix + 'scale'] = torch.tensor(self.scale)
|
| 176 |
+
destination[prefix + 'zero_point'] = torch.tensor(self.zero_point)
|
| 177 |
+
|
| 178 |
+
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
|
| 179 |
+
missing_keys, unexpected_keys, error_msgs):
|
| 180 |
+
|
| 181 |
+
self.scale = float(state_dict.pop(prefix + 'scale'))
|
| 182 |
+
self.zero_point = int(state_dict.pop(prefix + 'zero_point'))
|
| 183 |
+
super()._load_from_state_dict(state_dict, prefix, local_metadata, False,
|
| 184 |
+
missing_keys, unexpected_keys, error_msgs)
|
| 185 |
+
|
| 186 |
+
def _get_name(self):
|
| 187 |
+
return 'QFunctional'
|
| 188 |
+
|
| 189 |
+
def extra_repr(self):
|
| 190 |
+
return f'scale={self.scale}, zero_point={self.zero_point}'
|
| 191 |
+
|
| 192 |
+
def forward(self, x):
|
| 193 |
+
raise RuntimeError("Functional is not intended to use the " +
|
| 194 |
+
"'forward'. Please use the underlying operation")
|
| 195 |
+
|
| 196 |
+
r"""Operation equivalent to ``torch.ops.quantized.add``"""
|
| 197 |
+
def add(self, x: Tensor, y: Tensor) -> Tensor:
|
| 198 |
+
r = ops.quantized.add(x, y, scale=self.scale, zero_point=self.zero_point)
|
| 199 |
+
r = self.activation_post_process(r)
|
| 200 |
+
return r
|
| 201 |
+
|
| 202 |
+
r"""Operation equivalent to ``torch.ops.quantized.add(Tensor, float)``"""
|
| 203 |
+
def add_scalar(self, x: Tensor, y: float) -> Tensor:
|
| 204 |
+
r = ops.quantized.add_scalar(x, y)
|
| 205 |
+
# Note: this operation is not observed because the observation is not
|
| 206 |
+
# needed for the quantized op.
|
| 207 |
+
return r
|
| 208 |
+
|
| 209 |
+
r"""Operation equivalent to ``torch.ops.quantized.mul(Tensor, Tensor)``"""
|
| 210 |
+
def mul(self, x: Tensor, y: Tensor) -> Tensor:
|
| 211 |
+
r = ops.quantized.mul(x, y, scale=self.scale, zero_point=self.zero_point)
|
| 212 |
+
r = self.activation_post_process(r)
|
| 213 |
+
return r
|
| 214 |
+
|
| 215 |
+
r"""Operation equivalent to ``torch.ops.quantized.mul(Tensor, float)``"""
|
| 216 |
+
def mul_scalar(self, x: Tensor, y: float) -> Tensor:
|
| 217 |
+
r = ops.quantized.mul_scalar(x, y)
|
| 218 |
+
# Note: this operation is not observed because the observation is not
|
| 219 |
+
# needed for the quantized op.
|
| 220 |
+
return r
|
| 221 |
+
|
| 222 |
+
r"""Operation equivalent to ``torch.ops.quantized.cat``"""
|
| 223 |
+
def cat(self, x: List[Tensor], dim: int = 0) -> Tensor:
|
| 224 |
+
r = ops.quantized.cat(x, scale=self.scale, zero_point=self.zero_point, dim=dim)
|
| 225 |
+
r = self.activation_post_process(r)
|
| 226 |
+
return r
|
| 227 |
+
|
| 228 |
+
r"""Operation equivalent to ``torch.ops.quantized.add_relu``"""
|
| 229 |
+
def add_relu(self, x: Tensor, y: Tensor) -> Tensor:
|
| 230 |
+
r = ops.quantized.add_relu(x, y, scale=self.scale, zero_point=self.zero_point)
|
| 231 |
+
r = self.activation_post_process(r)
|
| 232 |
+
return r
|
| 233 |
+
|
| 234 |
+
r"""Operation equivalent to ``torch.ops.quantized.matmul(Tensor, Tensor)``"""
|
| 235 |
+
def matmul(self, x: Tensor, y: Tensor) -> Tensor:
|
| 236 |
+
r = ops.quantized.matmul(x, y, scale=self.scale, zero_point=self.zero_point)
|
| 237 |
+
# Note: this operation is not observed because the observation is not
|
| 238 |
+
# needed for the quantized op.
|
| 239 |
+
return r
|
| 240 |
+
|
| 241 |
+
@classmethod
|
| 242 |
+
def from_float(cls, mod):
|
| 243 |
+
assert type(mod) == FloatFunctional, \
|
| 244 |
+
"QFunctional.from_float expects an instance of FloatFunctional"
|
| 245 |
+
scale, zero_point = mod.activation_post_process.calculate_qparams() # type: ignore[operator]
|
| 246 |
+
new_mod = QFunctional()
|
| 247 |
+
new_mod.scale = float(scale)
|
| 248 |
+
new_mod.zero_point = int(zero_point)
|
| 249 |
+
return new_mod
|
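A short sketch of how FloatFunctional and QFunctional from the file above fit into eager-mode quantization (the scale/zero_point values below are placeholders and a CPU quantized engine such as fbgemm is assumed): FloatFunctional sits in the float model so a functional op like add can be observed through activation_post_process, and after convert it is swapped for a QFunctional that calls the corresponding torch.ops.quantized kernel with the collected output qparams.

import torch
from torch.ao.nn.quantized import FloatFunctional, QFunctional

f_add = FloatFunctional()  # used in the float/observed model
out = f_add.add(torch.tensor(3.0), torch.tensor(4.0))  # routed through activation_post_process

q_add = QFunctional()               # what convert swaps in
q_add.scale, q_add.zero_point = 0.1, 0   # placeholder output qparams
a = torch.quantize_per_tensor(torch.tensor(3.0), 0.1, 0, torch.quint8)
b = torch.quantize_per_tensor(torch.tensor(4.0), 0.1, 0, torch.quint8)
qout = q_add.add(a, b)              # dispatches to torch.ops.quantized.add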
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/reference/modules/__init__.py
ADDED
@@ -0,0 +1,21 @@
+from .linear import Linear
+from .conv import Conv1d, Conv2d, Conv3d, ConvTranspose1d, ConvTranspose2d, ConvTranspose3d
+from .rnn import RNNCell, LSTMCell, GRUCell, LSTM, GRU
+from .sparse import Embedding, EmbeddingBag
+
+__all__ = [
+    'Linear',
+    'Conv1d',
+    'Conv2d',
+    'Conv3d',
+    'ConvTranspose1d',
+    'ConvTranspose2d',
+    'ConvTranspose3d',
+    'RNNCell',
+    'LSTMCell',
+    'GRUCell',
+    'LSTM',
+    'GRU',
+    'Embedding',
+    'EmbeddingBag',
+]
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/linear.cpython-311.pyc
ADDED
Binary file (3.71 kB).
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/sparse.cpython-311.pyc
ADDED
Binary file (6.01 kB).
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/reference/modules/rnn.py
ADDED
@@ -0,0 +1,614 @@
+import torch
+import torch.nn as nn
+from torch import Tensor
+from .utils import _quantize_and_dequantize_weight
+from .utils import _quantize_weight
+from typing import Optional, Dict, Any, Tuple
+from torch import _VF
+from torch.nn.utils.rnn import PackedSequence
+
+__all__ = ['RNNCellBase', 'RNNCell', 'LSTMCell', 'GRUCell', 'RNNBase', 'LSTM', 'GRU', 'get_quantized_weight']
+
+def _apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor:
+    return tensor.index_select(dim, permutation)
+
+def _get_weight_and_quantization_params(module, wn):
+    weight = getattr(module, wn)
+    params = [weight]
+    for param_name in [wn + n for n in ["_qscheme", "_dtype", "_scale", "_zero_point", "_axis_int"]]:
+        if hasattr(module, param_name):
+            param = getattr(module, param_name)
+        else:
+            param = None
+        params.append(param)
+    return params
+
+def get_quantized_weight(module, wn):
+    if not hasattr(module, wn):
+        return None
+    params = _get_weight_and_quantization_params(module, wn)
+    weight = _quantize_weight(*params)
+    return weight
+
+def _get_quantize_and_dequantized_weight(module, wn):
+    if not hasattr(module, wn):
+        return None
+    params = _get_weight_and_quantization_params(module, wn)
+    weight = _quantize_and_dequantize_weight(*params)
+    return weight
+
+class RNNCellBase(nn.RNNCellBase):
+    def __init__(self, input_size: int, hidden_size: int, bias: bool, num_chunks: int,
+                 device=None, dtype=None, weight_qparams_dict=None) -> None:
+        super().__init__(input_size, hidden_size, bias, num_chunks, device=device, dtype=dtype)
+        # TODO(jerryzh168): maybe make this arg a required arg
+        if weight_qparams_dict is None:
+            weight_qparams = {
+                "qscheme": torch.per_tensor_affine,
+                "dtype": torch.quint8,
+                "scale": 1.0,
+                "zero_point": 0
+            }
+            weight_qparams_dict = {
+                "weight_ih": weight_qparams,
+                "weight_hh": weight_qparams,
+                "is_decomposed": False,
+            }
+        assert len(weight_qparams_dict) == 3, "Expected length for weight_qparams_dict to be 3 for QuantizedRNNCellBase(Reference)"
+        self._init_weight_qparams_dict(weight_qparams_dict, device)
+
+    def _init_weight_qparams_dict(self, weight_qparams_dict, device):
+        assert weight_qparams_dict is not None
+        self.is_decomposed = weight_qparams_dict["is_decomposed"]
+        for key, weight_qparams in weight_qparams_dict.items():
+            if key == "is_decomposed":
+                continue
+            # TODO: refactor the duplicated code to utils.py
+            weight_qscheme = weight_qparams["qscheme"]
+            weight_dtype = weight_qparams["dtype"]
+            setattr(self, key + "_qscheme", weight_qscheme)
+            setattr(self, key + "_dtype", weight_dtype)
+            assert weight_qscheme in [None, torch.per_tensor_affine, torch.per_channel_affine], \
+                Exception(f"qscheme: {weight_qscheme} is not support in {self._get_name()}")
+            if weight_qscheme is not None:
+                scale = weight_qparams["scale"]
+                scale_tensor = scale.clone().detach() \
+                    if isinstance(scale, torch.Tensor) else \
+                    torch.tensor(scale, dtype=torch.float, device=device)
+                self.register_buffer(key + "_scale", scale_tensor)
+                zp = weight_qparams["zero_point"]
+                zp_tensor = zp.clone().detach() \
+                    if isinstance(zp, torch.Tensor) else \
+                    torch.tensor(zp, dtype=torch.int, device=device)
+                self.register_buffer(key + "_zero_point", zp_tensor)
+                if weight_qscheme == torch.per_channel_affine:
+                    axis = weight_qparams["axis"]
+                    axis_tensor = axis.clone().detach() \
+                        if isinstance(axis, torch.Tensor) else \
+                        torch.tensor(axis, dtype=torch.int, device=device)
+                    self.register_buffer(key + "_axis", axis_tensor)
+                else:
+                    # added for TorchScriptability, not used
+                    self.register_buffer(
+                        key + "_axis", torch.tensor(0, dtype=torch.int, device=device))
+                setattr(self, key + "_axis_int", getattr(self, key + "_axis").item())
+
+    def _get_name(self):
+        return "QuantizedRNNCellBase(Reference)"
+
+    def get_quantized_weight_ih(self):
+        return get_quantized_weight(self, "weight_ih")
+
+    def get_quantized_weight_hh(self):
+        return get_quantized_weight(self, "weight_hh")
+
+    def get_weight_ih(self):
+        return _get_quantize_and_dequantized_weight(self, "weight_ih")
+
+    def get_weight_hh(self):
+        return _get_quantize_and_dequantized_weight(self, "weight_hh")
+
+class RNNCell(RNNCellBase):
+    """
+    We'll store weight_qparams for all the weights (weight_ih and weight_hh),
+    we need to pass in a `weight_qparams_dict` that maps from weight name,
+    e.g. weight_ih, to the weight_qparams for that weight
+    """
+    def __init__(self, input_size: int, hidden_size: int, bias: bool = True, nonlinearity: str = "tanh",
+                 device=None, dtype=None, weight_qparams_dict: Optional[Dict[str, Any]] = None) -> None:
+        factory_kwargs = {'device': device, 'dtype': dtype, 'weight_qparams_dict': weight_qparams_dict}
+        super().__init__(input_size, hidden_size, bias, num_chunks=1, **factory_kwargs)
+        self.nonlinearity = nonlinearity
+
+    def _get_name(self):
+        return "QuantizedRNNCell(Reference)"
+
+    # TODO: refactor nn.RNNCell to have a _forward that takes weight_ih and weight_hh as input
+    # and remove duplicated code, same for the other two Cell modules
+    def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor:
+        assert input.dim() in (1, 2), \
+            f"RNNCell: Expected input to be 1-D or 2-D but received {input.dim()}-D tensor"
+        is_batched = input.dim() == 2
+        if not is_batched:
+            input = input.unsqueeze(0)
+
+        if hx is None:
+            hx = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
+        else:
+            hx = hx.unsqueeze(0) if not is_batched else hx
+
+        if self.nonlinearity == "tanh":
+            ret = _VF.rnn_tanh_cell(
+                input, hx,
+                self.get_weight_ih(), self.get_weight_hh(),
+                self.bias_ih, self.bias_hh,
+            )
+        elif self.nonlinearity == "relu":
+            ret = _VF.rnn_relu_cell(
+                input, hx,
+                self.get_weight_ih(), self.get_weight_hh(),
+                self.bias_ih, self.bias_hh,
+            )
+        else:
+            ret = input  # TODO: remove when jit supports exception flow
+            raise RuntimeError(
+                f"Unknown nonlinearity: {self.nonlinearity}")
+
+        if not is_batched:
+            ret = ret.squeeze(0)
+
+        return ret
+
+    @classmethod
+    def from_float(cls, mod, weight_qparams_dict):
+        ref_mod = cls(
+            mod.input_size,
+            mod.hidden_size,
+            mod.bias,
+            mod.nonlinearity,
+            mod.weight_ih.device,
+            mod.weight_ih.dtype,
+            weight_qparams_dict)
+        ref_mod.weight_ih = mod.weight_ih
+        ref_mod.weight_hh = mod.weight_hh
+        ref_mod.bias_ih = mod.bias_ih
+        ref_mod.bias_hh = mod.bias_hh
+        return ref_mod
+
+class LSTMCell(RNNCellBase):
+    """
+    We'll store weight_qparams for all the weights (weight_ih and weight_hh),
+    we need to pass in a `weight_qparams_dict` that maps from weight name,
+    e.g. weight_ih, to the weight_qparams for that weight
+    """
+    def __init__(self, input_size: int, hidden_size: int, bias: bool = True,
+                 device=None, dtype=None, weight_qparams_dict: Optional[Dict[str, Any]] = None) -> None:
+        factory_kwargs = {'device': device, 'dtype': dtype, 'weight_qparams_dict': weight_qparams_dict}
+        super().__init__(input_size, hidden_size, bias, num_chunks=4, **factory_kwargs)
+
+    def _get_name(self):
+        return "QuantizedLSTMCell(Reference)"
+
+    def forward(self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None) -> Tuple[Tensor, Tensor]:
+        assert input.dim() in (1, 2), \
+            f"LSTMCell: Expected input to be 1-D or 2-D but received {input.dim()}-D tensor"
+        is_batched = input.dim() == 2
+        if not is_batched:
+            input = input.unsqueeze(0)
+
+        if hx is None:
+            zeros = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
+            hx = (zeros, zeros)
+        else:
+            hx = (hx[0].unsqueeze(0), hx[1].unsqueeze(0)) if not is_batched else hx
+
+        ret = _VF.lstm_cell(
+            input, hx,
+            self.get_weight_ih(), self.get_weight_hh(),
+            self.bias_ih, self.bias_hh,
+        )
+
+        if not is_batched:
+            ret = (ret[0].squeeze(0), ret[1].squeeze(0))
+        return ret
+
+    @classmethod
+    def from_float(cls, mod, weight_qparams_dict):
+        ref_mod = cls(
+            mod.input_size,
+            mod.hidden_size,
+            mod.bias,
+            mod.weight_ih.device,
+            mod.weight_ih.dtype,
+            weight_qparams_dict)
+        ref_mod.weight_ih = mod.weight_ih
+        ref_mod.weight_hh = mod.weight_hh
+        ref_mod.bias_ih = mod.bias_ih
+        ref_mod.bias_hh = mod.bias_hh
+        return ref_mod
+
+class GRUCell(RNNCellBase):
+    """
+    We'll store weight_qparams for all the weights (weight_ih and weight_hh),
+    we need to pass in a `weight_qparams_dict` that maps from weight name,
+    e.g. weight_ih, to the weight_qparams for that weight
+    """
+    def __init__(self, input_size: int, hidden_size: int, bias: bool = True,
+                 device=None, dtype=None, weight_qparams_dict: Optional[Dict[str, Any]] = None) -> None:
+        factory_kwargs = {'device': device, 'dtype': dtype, 'weight_qparams_dict': weight_qparams_dict}
+        super().__init__(input_size, hidden_size, bias, num_chunks=3, **factory_kwargs)
+
+    def _get_name(self):
+        return "QuantizedGRUCell(Reference)"
+
+    def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor:
+        assert input.dim() in (1, 2), \
+            f"GRUCell: Expected input to be 1-D or 2-D but received {input.dim()}-D tensor"
+        is_batched = input.dim() == 2
+        if not is_batched:
+            input = input.unsqueeze(0)
+
+        if hx is None:
+            hx = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
+        else:
+            hx = hx.unsqueeze(0) if not is_batched else hx
+
+        ret = _VF.gru_cell(
+            input, hx,
+            self.get_weight_ih(), self.get_weight_hh(),
+            self.bias_ih, self.bias_hh,
+        )
+
+        if not is_batched:
+            ret = ret.squeeze(0)
+
+        return ret
+
+    @classmethod
+    def from_float(cls, mod, weight_qparams_dict):
+        ref_mod = cls(
+            mod.input_size,
+            mod.hidden_size,
+            mod.bias,
+            mod.weight_ih.device,
+            mod.weight_ih.dtype,
+            weight_qparams_dict)
+        ref_mod.weight_ih = mod.weight_ih
+        ref_mod.weight_hh = mod.weight_hh
+        ref_mod.bias_ih = mod.bias_ih
+        ref_mod.bias_hh = mod.bias_hh
+        return ref_mod
+
+class RNNBase(nn.RNNBase):
+    def __init__(self, mode: str, input_size: int, hidden_size: int,
+                 num_layers: int = 1, bias: bool = True, batch_first: bool = False,
+                 dropout: float = 0., bidirectional: bool = False, proj_size: int = 0,
+                 device=None, dtype=None,
+                 weight_qparams_dict: Optional[Dict[str, Any]] = None) -> None:
+        super().__init__(
+            mode, input_size, hidden_size, num_layers, bias, batch_first, dropout,
+            bidirectional, proj_size, device, dtype
+        )
+        # TODO(jerryzh168): maybe make this arg a required arg
+        if weight_qparams_dict is None:
+            weight_qparams = {
+                'qscheme': torch.per_tensor_affine,
+                'dtype': torch.quint8,
+                'scale': 1.0,
+                'zero_point': 0
+            }
+            weight_qparams_dict = {"is_decomposed": False}  # type: ignore[dict-item]
+            for wn in self._flat_weights_names:
+                if wn.startswith("weight"):
+                    weight_qparams_dict[wn] = weight_qparams
+        self._init_weight_qparams_dict(weight_qparams_dict, device)
+
+    def _init_weight_qparams_dict(self, weight_qparams_dict, device):
+        self.is_decomposed = weight_qparams_dict["is_decomposed"]
+        for key, weight_qparams in weight_qparams_dict.items():
+            if key == "is_decomposed":
+                continue
+            weight_qscheme = weight_qparams["qscheme"]
+            weight_dtype = weight_qparams["dtype"]
+            setattr(self, key + "_qscheme", weight_qscheme)
+            setattr(self, key + "_dtype", weight_dtype)
+            assert weight_qscheme in [None, torch.per_tensor_affine, torch.per_channel_affine], \
+                Exception(f"qscheme: {weight_qscheme} is not support in {self._get_name()}")
+            if weight_qscheme is not None:
+                self.register_buffer(
+                    key + "_scale",
+                    torch.tensor(weight_qparams["scale"], dtype=torch.float, device=device))
+                self.register_buffer(
+                    key + "_zero_point",
+                    torch.tensor(weight_qparams["zero_point"], dtype=torch.int, device=device))
+                if weight_qscheme == torch.per_channel_affine:
+                    self.register_buffer(
+                        key + "_axis",
+                        torch.tensor(weight_qparams["axis"], dtype=torch.int, device=device))
+                else:
+                    # added for TorchScriptability, not used
+                    self.register_buffer(
+                        key + "_axis", torch.tensor(0, dtype=torch.int, device=device))
+                setattr(self, key + "_axis_int", getattr(self, key + "_axis").item())
+
+class LSTM(RNNBase):
+    """ Reference Quantized LSTM Module
+    We'll store weight_qparams for all the weights in _flat_weights, we need to pass in
+    a `weight_qparams_dict` that maps from weight name, e.g. weight_ih_l0,
+    to the weight_qparams for that weight
+    """
+    def __init__(self, *args, **kwargs):
+        super().__init__('LSTM', *args, **kwargs)
+
+    # Same as above, see torch/nn/modules/module.py::_forward_unimplemented
+    def permute_hidden(self,  # type: ignore[override]
+                       hx: Tuple[Tensor, Tensor],
+                       permutation: Optional[Tensor]
+                       ) -> Tuple[Tensor, Tensor]:
+        if permutation is None:
+            return hx
+        return _apply_permutation(hx[0], permutation), _apply_permutation(hx[1], permutation)
+
+    def get_expected_cell_size(self, input: Tensor, batch_sizes: Optional[Tensor]) -> Tuple[int, int, int]:
+        if batch_sizes is not None:
+            mini_batch = int(batch_sizes[0])
+        else:
+            mini_batch = input.size(0) if self.batch_first else input.size(1)
+        num_directions = 2 if self.bidirectional else 1
+        expected_hidden_size = (self.num_layers * num_directions,
+                                mini_batch, self.hidden_size)
+        return expected_hidden_size
+
+    # In the future, we should prevent mypy from applying contravariance rules here.
+    # See torch/nn/modules/module.py::_forward_unimplemented
+    def check_forward_args(self,  # type: ignore[override]
+                           input: Tensor,
+                           hidden: Tuple[Tensor, Tensor],
+                           batch_sizes: Optional[Tensor],
+                           ):
+        self.check_input(input, batch_sizes)
+        self.check_hidden_size(hidden[0], self.get_expected_hidden_size(input, batch_sizes),
+                               'Expected hidden[0] size {}, got {}')
+        self.check_hidden_size(hidden[1], self.get_expected_cell_size(input, batch_sizes),
+                               'Expected hidden[1] size {}, got {}')
+
+    def get_quantized_weight_bias_dict(self):
+        """ dictionary from flat_weight_name to quantized weight or (unquantized) bias
+        e.g.
+        {
+            "weight_ih_l0": quantized_weight,
+            "bias_ih_l0": unquantized_bias,
+            ...
+        }
+        """
+        quantized_weight_bias_dict = {}
+        for wn in self._flat_weights_names:
+            if hasattr(self, wn):
+                if wn.startswith("weight"):
+                    weight_or_bias = get_quantized_weight(self, wn)
+                else:
+                    weight_or_bias = getattr(self, wn)
+            else:
+                weight_or_bias = None
+            quantized_weight_bias_dict[wn] = weight_or_bias
+        return quantized_weight_bias_dict
+
+    def get_flat_weights(self):
+        flat_weights = []
+        for wn in self._flat_weights_names:
+            if hasattr(self, wn):
+                weight = getattr(self, wn)
+                if wn.startswith("weight"):
+                    params = _get_weight_and_quantization_params(self, wn)
+                    weight = _quantize_and_dequantize_weight(*params)
+            else:
+                weight = None
+            flat_weights.append(weight)
+        return flat_weights
+
+    def forward(self, input, hx=None):  # noqa: F811
+        orig_input = input
+        # xxx: isinstance check needs to be in conditional for TorchScript to compile
+        batch_sizes = None
+        if isinstance(orig_input, PackedSequence):
+            input, batch_sizes, sorted_indices, unsorted_indices = input
+            max_batch_size = int(batch_sizes[0])
+        else:
+            batch_sizes = None
+            is_batched = input.dim() == 3
+            batch_dim = 0 if self.batch_first else 1
+            if not is_batched:
+                input = input.unsqueeze(batch_dim)
+            max_batch_size = input.size(0) if self.batch_first else input.size(1)
+            sorted_indices = None
+            unsorted_indices = None
+
+        if hx is None:
+            num_directions = 2 if self.bidirectional else 1
+            real_hidden_size = self.proj_size if self.proj_size > 0 else self.hidden_size
+            h_zeros = torch.zeros(self.num_layers * num_directions,
+                                  max_batch_size, real_hidden_size,
+                                  dtype=input.dtype, device=input.device)
+            c_zeros = torch.zeros(self.num_layers * num_directions,
+                                  max_batch_size, self.hidden_size,
+                                  dtype=input.dtype, device=input.device)
+            hx = (h_zeros, c_zeros)
+        else:
+            if batch_sizes is None:  # If not PackedSequence input.
+                if is_batched:  # type: ignore[possibly-undefined]
+                    if (hx[0].dim() != 3 or hx[1].dim() != 3):
+                        msg = ("For batched 3-D input, hx and cx should "
+                               f"also be 3-D but got ({hx[0].dim()}-D, {hx[1].dim()}-D) tensors")
+                        raise RuntimeError(msg)
+                else:
+                    if hx[0].dim() != 2 or hx[1].dim() != 2:
+                        msg = ("For unbatched 2-D input, hx and cx should "
+                               f"also be 2-D but got ({hx[0].dim()}-D, {hx[1].dim()}-D) tensors")
+                        raise RuntimeError(msg)
+                    hx = (hx[0].unsqueeze(1), hx[1].unsqueeze(1))
+
+            # Each batch of the hidden state should match the input sequence that
+            # the user believes he/she is passing in.
+            hx = self.permute_hidden(hx, sorted_indices)
+
+        self.check_forward_args(input, hx, batch_sizes)
+        if batch_sizes is None:
+            result = _VF.lstm(input, hx, self.get_flat_weights(), self.bias, self.num_layers,
+                              self.dropout, self.training, self.bidirectional, self.batch_first)
+        else:
+            result = _VF.lstm(input, batch_sizes, hx, self.get_flat_weights(), self.bias,
+                              self.num_layers, self.dropout, self.training, self.bidirectional)
+        output = result[0]
+        hidden = result[1:]
+        # xxx: isinstance check needs to be in conditional for TorchScript to compile
+        if isinstance(orig_input, PackedSequence):
+            output_packed = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices)
+            return output_packed, self.permute_hidden(hidden, unsorted_indices)
+        else:
+            if not is_batched:  # type: ignore[possibly-undefined]
+                output = output.squeeze(batch_dim)  # type: ignore[possibly-undefined]
+                hidden = (hidden[0].squeeze(1), hidden[1].squeeze(1))
+            return output, self.permute_hidden(hidden, unsorted_indices)
+
+    def _get_name(self):
+        return "QuantizedLSTM(Reference)"
+
+    @classmethod
+    def from_float(cls, mod, weight_qparams_dict):
+        ref_mod = cls(
+            mod.input_size,
+            mod.hidden_size,
+            mod.num_layers,
+            mod.bias,
+            mod.batch_first,
+            mod.dropout,
+            mod.bidirectional,
+            weight_qparams_dict=weight_qparams_dict)
+        for wn in mod._flat_weights_names:
+            setattr(ref_mod, wn, getattr(mod, wn))
+        return ref_mod
+
+class GRU(RNNBase):
+    """ Reference Quantized GRU Module
+    We'll store weight_qparams for all the weights in _flat_weights, we need to pass in
+    a `weight_qparams_dict` that maps from weight name, e.g. weight_ih_l0,
+    to the weight_qparams for that weight
+    """
+    def __init__(self, *args, **kwargs):
+        if 'proj_size' in kwargs:
+            raise ValueError("proj_size argument is only supported for LSTM, not RNN or GRU")
+        super().__init__('GRU', *args, **kwargs)
+
+    def get_quantized_weight_bias_dict(self):
+        """ dictionary from flat_weight_name to quantized weight or (unquantized) bias
+        e.g.
+        {
+            "weight_ih_l0": quantized_weight,
+            "bias_ih_l0": unquantized_bias,
+            ...
+        }
+        """
+        quantized_weight_bias_dict = {}
+        for wn in self._flat_weights_names:
+            if hasattr(self, wn):
+                if wn.startswith("weight"):
+                    weight_or_bias = get_quantized_weight(self, wn)
+                else:
+                    weight_or_bias = getattr(self, wn)
+            else:
+                weight_or_bias = None
+            quantized_weight_bias_dict[wn] = weight_or_bias
+        return quantized_weight_bias_dict
+
+    def get_flat_weights(self):
+        flat_weights = []
+        for wn in self._flat_weights_names:
+            if hasattr(self, wn):
+                weight = getattr(self, wn)
+                if wn.startswith("weight"):
+                    params = _get_weight_and_quantization_params(self, wn)
+                    weight = _quantize_and_dequantize_weight(*params)
+            else:
+                weight = None
+            flat_weights.append(weight)
+        return flat_weights
+
+    def forward(self, input, hx=None):  # noqa: F811
+        # Note: this is copied from the forward of GRU in https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/rnn.py
+        # only changed self._flat_weights to self.get_flat_weights()
+        # TODO: maybe we can try inheriting from that class and define get_flat_weights
+        # as a @property? this might interfere with TorchScript, if we remove that
+        # requirement in the future we should be able to do this
+        orig_input = input
+        # xxx: isinstance check needs to be in conditional for TorchScript to compile
+        if isinstance(orig_input, PackedSequence):
+            input, batch_sizes, sorted_indices, unsorted_indices = input
+            max_batch_size = int(batch_sizes[0])
+        else:
+            batch_sizes = None
+            assert (input.dim() in (2, 3)), f"GRU: Expected input to be 2-D or 3-D but received {input.dim()}-D tensor"
+            is_batched = input.dim() == 3
+            batch_dim = 0 if self.batch_first else 1
+            if not is_batched:
+                input = input.unsqueeze(batch_dim)
+                if hx is not None:
+                    if hx.dim() != 2:
+                        raise RuntimeError(
+                            f"For unbatched 2-D input, hx should also be 2-D but got {hx.dim()}-D tensor")
+                    hx = hx.unsqueeze(1)
+            else:
+                if hx is not None and hx.dim() != 3:
+                    raise RuntimeError(
+                        f"For batched 3-D input, hx should also be 3-D but got {hx.dim()}-D tensor")
+            max_batch_size = input.size(0) if self.batch_first else input.size(1)
+            sorted_indices = None
+            unsorted_indices = None
+
+        if hx is None:
+            num_directions = 2 if self.bidirectional else 1
+            hx = torch.zeros(self.num_layers * num_directions,
+                             max_batch_size, self.hidden_size,
+                             dtype=input.dtype, device=input.device)
+        else:
+            # Each batch of the hidden state should match the input sequence that
+            # the user believes he/she is passing in.
+            hx = self.permute_hidden(hx, sorted_indices)
+
+        self.check_forward_args(input, hx, batch_sizes)
+        if batch_sizes is None:
+            result = _VF.gru(input, hx, self.get_flat_weights(), self.bias, self.num_layers,
+                             self.dropout, self.training, self.bidirectional, self.batch_first)
+        else:
+            result = _VF.gru(input, batch_sizes, hx, self.get_flat_weights(), self.bias,
+                             self.num_layers, self.dropout, self.training, self.bidirectional)
+        output = result[0]
+        hidden = result[1]
+
+        # xxx: isinstance check needs to be in conditional for TorchScript to compile
+        if isinstance(orig_input, PackedSequence):
+            output_packed = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices)
+            return output_packed, self.permute_hidden(hidden, unsorted_indices)
+        else:
+            if not is_batched:  # type: ignore[possibly-undefined]
+                output = output.squeeze(batch_dim)  # type: ignore[possibly-undefined]
+                hidden = hidden.squeeze(1)
+
+            return output, self.permute_hidden(hidden, unsorted_indices)
+
+    def _get_name(self):
+        return "QuantizedGRU(Reference)"
+
+    @classmethod
+    def from_float(cls, mod, weight_qparams_dict):
+        ref_mod = cls(
+            mod.input_size,
+            mod.hidden_size,
+            mod.num_layers,
+            mod.bias,
+            mod.batch_first,
+            mod.dropout,
+            mod.bidirectional,
+            weight_qparams_dict=weight_qparams_dict)
+        for wn in mod._flat_weights_names:
+            setattr(ref_mod, wn, getattr(mod, wn))
+        return ref_mod
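A sketch of building the reference quantized LSTM defined above from a float LSTM (the qparams values are placeholders; in a real convert flow they come from weight observers). The reference module keeps fp32 weights plus per-weight quantization parameters and fake-quantizes the weights inside forward via get_flat_weights().

import torch
from torch.ao.nn.quantized.reference.modules.rnn import LSTM as RefLSTM

float_lstm = torch.nn.LSTM(input_size=8, hidden_size=16, num_layers=1)
weight_qparams = {"qscheme": torch.per_tensor_affine, "dtype": torch.quint8,
                  "scale": 0.02, "zero_point": 0}  # placeholder qparams
weight_qparams_dict = {"is_decomposed": False}
for name in float_lstm._flat_weights_names:
    if name.startswith("weight"):
        weight_qparams_dict[name] = weight_qparams

ref_lstm = RefLSTM.from_float(float_lstm, weight_qparams_dict)
x = torch.randn(5, 3, 8)          # (seq_len, batch, input_size)
out, (h, c) = ref_lstm(x)         # runs with fake-quantized copies of the weights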
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/reference/modules/utils.py
ADDED
|
@@ -0,0 +1,323 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import typing
|
| 3 |
+
|
| 4 |
+
__all__ = [
|
| 5 |
+
"ReferenceQuantizedModule",
|
| 6 |
+
]
|
| 7 |
+
|
| 8 |
+
class ReferenceQuantizedModule(torch.nn.Module):
|
| 9 |
+
def _init_weight_qparams(self, weight_qparams, device):
|
| 10 |
+
if weight_qparams is None:
|
| 11 |
+
weight_qparams = {
|
| 12 |
+
"qscheme": torch.per_tensor_affine,
|
| 13 |
+
"dtype": torch.quint8,
|
| 14 |
+
"scale": 1.0,
|
| 15 |
+
"zero_point": 0
|
| 16 |
+
}
|
| 17 |
+
self.weight_qscheme: torch.qscheme = weight_qparams["qscheme"]
|
| 18 |
+
self.weight_dtype = weight_qparams["dtype"]
|
| 19 |
+
assert self.weight_qscheme in [
|
| 20 |
+
None, torch.per_tensor_affine, torch.per_channel_affine,
|
| 21 |
+
torch.per_channel_affine_float_qparams], \
|
| 22 |
+
Exception(f"qscheme: {self.weight_qscheme} is not support in reference quantized {self._get_name()}")
|
| 23 |
+
if self.weight_dtype in [torch.quint8, torch.qint8, torch.quint4x2, torch.qint32]:
|
| 24 |
+
zero_point_dtype = weight_qparams["zero_point"].dtype if \
|
| 25 |
+
isinstance(weight_qparams["zero_point"], torch.Tensor) else \
|
| 26 |
+
torch.int
|
| 27 |
+
w_scale = weight_qparams["scale"]
|
| 28 |
+
w_scale_tensor = w_scale.clone().detach() \
|
| 29 |
+
if isinstance(w_scale, torch.Tensor) \
|
| 30 |
+
else torch.tensor(w_scale, dtype=torch.float, device=device)
|
| 31 |
+
self.register_buffer("weight_scale", w_scale_tensor)
|
| 32 |
+
w_zp = weight_qparams["zero_point"]
|
| 33 |
+
w_zp_tensor = w_zp.clone().detach() \
|
| 34 |
+
if isinstance(w_zp, torch.Tensor) \
|
| 35 |
+
else torch.tensor(w_zp, dtype=zero_point_dtype, device=device)
|
| 36 |
+
self.register_buffer("weight_zero_point", w_zp_tensor)
|
| 37 |
+
if self.weight_qscheme in [torch.per_channel_affine, torch.per_channel_affine_float_qparams]:
|
| 38 |
+
w_axis = weight_qparams["axis"]
|
| 39 |
+
w_axis_tensor = w_axis.clone().detach() \
|
| 40 |
+
if isinstance(w_axis, torch.Tensor) \
|
| 41 |
+
else torch.tensor(w_axis, dtype=torch.int, device=device)
|
| 42 |
+
self.register_buffer("weight_axis", w_axis_tensor)
|
| 43 |
+
else:
|
| 44 |
+
# added for TorchScriptability, not used
|
| 45 |
+
self.register_buffer(
|
| 46 |
+
"weight_axis", torch.tensor(0, dtype=torch.int, device=device))
|
| 47 |
+
else:
|
| 48 |
+
# added for TorchScriptability, and for torch.float
|
| 49 |
+
self.register_buffer("weight_scale", torch.tensor(1.0, dtype=torch.float, device=device))
|
| 50 |
+
self.register_buffer("weight_zero_point", torch.tensor(0, dtype=torch.int, device=device))
|
| 51 |
+
self.register_buffer(
|
| 52 |
+
"weight_axis", torch.tensor(0, dtype=torch.int, device=device))
|
| 53 |
+
self.is_decomposed: bool = weight_qparams.get("is_decomposed", False)
|
| 54 |
+
# store weight_axis as weight_axis_int due to some constraints of torchdynamo.export
|
| 55 |
+
# for capturing `.item` operations
|
| 56 |
+
self.weight_axis_int: int = self.weight_axis.item() # type: ignore[operator, assignment]
|
| 57 |
+
self.weight_quant_min: typing.Optional[int] = weight_qparams.get("quant_min", None)
|
| 58 |
+
self.weight_quant_max: typing.Optional[int] = weight_qparams.get("quant_max", None)
|
| 59 |
+
|
| 60 |
+
def get_weight(self):
|
| 61 |
+
"""
|
| 62 |
+
Fake quantize (quantize and dequantize) the weight with
|
| 63 |
+
the quantization parameters for weight, this is used to
|
| 64 |
+
simulate the numerics for the quantized weight in a quantized
|
| 65 |
+
model
|
| 66 |
+
"""
|
| 67 |
+
# suppress mypy warning
|
| 68 |
+
assert isinstance(self.weight_scale, torch.Tensor)
|
| 69 |
+
assert isinstance(self.weight_zero_point, torch.Tensor)
|
| 70 |
+
if self.is_decomposed:
|
| 71 |
+
return _quantize_and_dequantize_weight_decomposed(
|
| 72 |
+
self.weight, # type: ignore[arg-type]
|
| 73 |
+
self.weight_qscheme,
|
| 74 |
+
self.weight_dtype,
|
| 75 |
+
self.weight_scale,
|
| 76 |
+
self.weight_zero_point,
|
| 77 |
+
self.weight_axis_int,
|
| 78 |
+
self.weight_quant_min,
|
| 79 |
+
self.weight_quant_max)
|
| 80 |
+
else:
|
| 81 |
+
return _quantize_and_dequantize_weight(
|
| 82 |
+
self.weight, # type: ignore[arg-type]
|
| 83 |
+
self.weight_qscheme,
|
| 84 |
+
self.weight_dtype,
|
| 85 |
+
self.weight_scale,
|
| 86 |
+
self.weight_zero_point,
|
| 87 |
+
self.weight_axis_int)
|
| 88 |
+
|
| 89 |
+
def get_quantized_weight(self):
|
| 90 |
+
# suppress mypy warning
|
| 91 |
+
assert isinstance(self.weight_scale, torch.Tensor)
|
| 92 |
+
assert isinstance(self.weight_zero_point, torch.Tensor)
|
| 93 |
+
# assert isinstance(self.weight_axis, torch.Tensor)
|
| 94 |
+
if self.is_decomposed:
|
| 95 |
+
return _quantize_weight_decomposed(
|
| 96 |
+
self.weight, # type: ignore[arg-type]
|
| 97 |
+
self.weight_qscheme,
|
| 98 |
+
self.weight_dtype,
|
| 99 |
+
self.weight_scale,
|
| 100 |
+
self.weight_zero_point,
|
| 101 |
+
self.weight_axis_int,
|
| 102 |
+
self.weight_quant_min,
|
| 103 |
+
self.weight_quant_max)
|
| 104 |
+
else:
|
| 105 |
+
return _quantize_weight(
|
| 106 |
+
self.weight, # type: ignore[arg-type]
|
| 107 |
+
self.weight_qscheme,
|
| 108 |
+
self.weight_dtype,
|
| 109 |
+
self.weight_scale,
|
| 110 |
+
self.weight_zero_point,
|
| 111 |
+
self.weight_axis_int)
|
| 112 |
+
|
| 113 |
+
def _save_to_state_dict(self, destination, prefix, keep_vars):
|
| 114 |
+
super()._save_to_state_dict(destination, prefix, keep_vars)
|
| 115 |
+
_save_weight_qparams(
|
| 116 |
+
destination, prefix, self.weight_qscheme, self.weight_dtype,
|
| 117 |
+
self.weight_scale, self.weight_zero_point, self.weight_axis)
|
| 118 |
+
|
| 119 |
+
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
|
| 120 |
+
missing_keys, unexpected_keys, error_msgs):
|
| 121 |
+
for key in _get_weight_qparam_keys(state_dict, prefix):
|
| 122 |
+
setattr(self, key, state_dict[prefix + key])
|
| 123 |
+
state_dict.pop(prefix + key)
|
| 124 |
+
|
| 125 |
+
super()._load_from_state_dict(
|
| 126 |
+
state_dict, prefix, local_metadata, False,
|
| 127 |
+
missing_keys, unexpected_keys, error_msgs)
|
| 128 |
+
|
| 129 |
+
def _quantize_weight_decomposed(
|
| 130 |
+
weight: torch.Tensor,
|
| 131 |
+
weight_qscheme: torch.qscheme,
|
| 132 |
+
weight_dtype: torch.dtype,
|
| 133 |
+
weight_scale: torch.Tensor,
|
| 134 |
+
weight_zero_point: torch.Tensor,
|
| 135 |
+
weight_axis: int,
|
| 136 |
+
weight_quant_min: typing.Optional[int],
|
| 137 |
+
weight_quant_max: typing.Optional[int],
|
| 138 |
+
) -> torch.Tensor:
|
| 139 |
+
_DTYPE_TO_QVALUE_BOUNDS = {
|
| 140 |
+
torch.uint8: (0, 255),
|
| 141 |
+
torch.int8: (-128, 127),
|
| 142 |
+
torch.int32: (-(2**31), 2**31 - 1),
|
| 143 |
+
}
|
| 144 |
+
# TODO: add an util function for converting qdtype to dtype
|
| 145 |
+
_QDTYPE_TO_UNDERLYING_INT_REPR_DTYPE = {
|
| 146 |
+
torch.quint8: torch.uint8,
|
| 147 |
+
torch.qint8: torch.int8,
|
| 148 |
+
torch.qint32: torch.int32,
|
| 149 |
+
}
|
| 150 |
+
if weight_qscheme == torch.per_tensor_affine:
|
| 151 |
+
if weight_dtype in [torch.quint8, torch.qint8, torch.qint32]:
|
| 152 |
+
weight_dtype_ = _QDTYPE_TO_UNDERLYING_INT_REPR_DTYPE[weight_dtype]
|
| 153 |
+
if weight_quant_min is None or weight_quant_max is None:
|
| 154 |
+
weight_quant_min, weight_quant_max = _DTYPE_TO_QVALUE_BOUNDS[weight_dtype_]
|
| 155 |
+
weight = torch.ops.quantized_decomposed.quantize_per_tensor(
|
| 156 |
+
weight,
|
| 157 |
+
weight_scale,
|
| 158 |
+
weight_zero_point,
|
| 159 |
+
weight_quant_min,
|
| 160 |
+
weight_quant_max,
|
| 161 |
+
weight_dtype_
|
| 162 |
+
)
|
| 163 |
+
return weight
|
| 164 |
+
elif weight_qscheme in [torch.per_channel_affine, torch.per_channel_affine_float_qparams]:
|
| 165 |
+
# TODO: torch.quint4x2 is not supported
|
| 166 |
+
if weight_dtype in [torch.quint8, torch.qint8, torch.qint32]:
|
| 167 |
+
weight_dtype_ = _QDTYPE_TO_UNDERLYING_INT_REPR_DTYPE[weight_dtype]
|
| 168 |
+
if weight_quant_min is None or weight_quant_max is None:
|
| 169 |
+
weight_quant_min, weight_quant_max = _DTYPE_TO_QVALUE_BOUNDS[weight_dtype_]
|
| 170 |
+
weight = torch.ops.quantized_decomposed.quantize_per_channel(
|
| 171 |
+
weight,
|
| 172 |
+
weight_scale,
|
| 173 |
+
weight_zero_point,
|
| 174 |
+
weight_axis,
|
| 175 |
+
weight_quant_min,
|
| 176 |
+
weight_quant_max,
|
| 177 |
+
weight_dtype_) # type: ignore[arg-type]
|
| 178 |
+
return weight
|
| 179 |
+
raise Exception(f"Unsupported dtype and qscheme: {weight_dtype}, {weight_qscheme}")
|
| 180 |
+
|
| 181 |
+
def _dequantize_weight_decomposed(
|
| 182 |
+
weight: torch.Tensor,
|
| 183 |
+
weight_qscheme: torch.qscheme,
|
| 184 |
+
weight_dtype: torch.dtype,
|
| 185 |
+
weight_scale: torch.Tensor,
|
| 186 |
+
weight_zero_point: torch.Tensor,
|
| 187 |
+
weight_axis: int,
|
        weight_quant_min: typing.Optional[int],
        weight_quant_max: typing.Optional[int],
) -> torch.Tensor:
    # TODO: get the quant_min and quant_max from activation_post_process
    _DTYPE_TO_QVALUE_BOUNDS = {
        torch.uint8: (0, 255),
        torch.int8: (-128, 127),
        torch.int32: (-(2**31), 2**31 - 1),
    }
    # TODO: add an util function for converting qdtype to dtype
    _QDTYPE_TO_UNDERLYING_INT_REPR_DTYPE = {
        torch.quint8: torch.uint8,
        torch.qint8: torch.int8,
        torch.qint32: torch.int32,
    }
    weight_dtype_ = _QDTYPE_TO_UNDERLYING_INT_REPR_DTYPE[weight_dtype]
    if weight_quant_min is None or weight_quant_max is None:
        weight_quant_min, weight_quant_max = _DTYPE_TO_QVALUE_BOUNDS[weight_dtype_]
    if weight_qscheme == torch.per_tensor_affine:
        if weight_dtype in [torch.quint8, torch.qint8, torch.qint32]:
            weight = torch.ops.quantized_decomposed.dequantize_per_tensor(
                weight,
                weight_scale,
                weight_zero_point,
                weight_quant_min,
                weight_quant_max,
                weight_dtype_
            )
            return weight
    elif weight_qscheme in [torch.per_channel_affine, torch.per_channel_affine_float_qparams]:
        # TODO: torch.quint4x2 is not supported
        if weight_dtype in [torch.quint8, torch.qint8, torch.qint32]:
            weight = torch.ops.quantized_decomposed.dequantize_per_channel(
                weight,
                weight_scale,
                weight_zero_point,
                weight_axis,
                weight_quant_min,
                weight_quant_max,
                weight_dtype_)  # type: ignore[arg-type]
            return weight
    raise Exception(f"Unsupported dtype and qscheme: {weight_dtype}, {weight_qscheme}")

def _quantize_weight(
        weight: torch.Tensor,
        weight_qscheme: torch.qscheme,
        weight_dtype: torch.dtype,
        weight_scale: torch.Tensor,
        weight_zero_point: torch.Tensor,
        weight_axis_int: int
) -> torch.Tensor:
    if weight_dtype == torch.float16:
        weight = weight.to(weight_dtype)
        return weight

    if weight_qscheme == torch.per_tensor_affine:
        if weight_dtype in [torch.quint8, torch.qint8, torch.qint32]:
            weight = torch.quantize_per_tensor(weight, weight_scale, weight_zero_point, weight_dtype)
            return weight
    elif weight_qscheme in [torch.per_channel_affine, torch.per_channel_affine_float_qparams]:
        if weight_dtype in [torch.quint8, torch.qint8, torch.quint4x2, torch.qint32]:
            weight = torch.quantize_per_channel(
                weight, weight_scale,
                weight_zero_point, weight_axis_int, weight_dtype)  # type: ignore[arg-type]
            return weight
    raise Exception(f"Unsupported dtype and qscheme: {weight_dtype}, {weight_qscheme}")

def _quantize_and_dequantize_weight_decomposed(
        weight: torch.Tensor,
        weight_qscheme: torch.qscheme,
        weight_dtype: torch.dtype,
        weight_scale: torch.Tensor,
        weight_zero_point: torch.Tensor,
        weight_axis_int: int,
        weight_quant_min: typing.Optional[int],
        weight_quant_max: typing.Optional[int],
) -> torch.Tensor:
    """ Quantize and then dequantize the weight based on
    the quantization parameters
    """
    if weight_qscheme in [
            torch.per_tensor_affine,
            torch.per_channel_affine,
            torch.per_channel_affine_float_qparams]:
        weight_quant = _quantize_weight_decomposed(
            weight, weight_qscheme, weight_dtype, weight_scale, weight_zero_point, weight_axis_int,
            weight_quant_min, weight_quant_max)
        weight_dequant = _dequantize_weight_decomposed(
            weight_quant, weight_qscheme, weight_dtype, weight_scale, weight_zero_point,
            weight_axis_int, weight_quant_min, weight_quant_max)
    else:
        weight_dequant = weight
    return weight_dequant

def _quantize_and_dequantize_weight(
        weight: torch.Tensor,
        weight_qscheme: torch.qscheme,
        weight_dtype: torch.dtype,
        weight_scale: torch.Tensor,
        weight_zero_point: torch.Tensor,
        weight_axis_int: int
) -> torch.Tensor:
    """ Quantize and then dequantize the weight based on
    the quantization parameters
    """
    if weight_qscheme in [
            torch.per_tensor_affine,
            torch.per_channel_affine,
            torch.per_channel_affine_float_qparams]:
        weight_quant = _quantize_weight(
            weight, weight_qscheme, weight_dtype, weight_scale, weight_zero_point, weight_axis_int)
        weight_dequant = weight_quant.dequantize()
    else:
        weight_dequant = weight
    return weight_dequant

def _save_weight_qparams(destination, prefix, weight_qscheme, weight_dtype, weight_scale, weight_zero_point, weight_axis):
    destination[prefix + "weight_qscheme"] = weight_qscheme
    destination[prefix + "weight_dtype"] = weight_dtype
    if weight_qscheme is not None:
        destination[prefix + "weight_scale"] = weight_scale
        destination[prefix + "weight_zero_point"] = weight_zero_point
        if weight_qscheme == torch.per_channel_affine:
            destination[prefix + "weight_axis"] = weight_axis

def _get_weight_qparam_keys(
        state_dict: typing.Dict[str, typing.Any],
        prefix: str):
    keys = ["weight_qscheme", "weight_dtype"]
    weight_qscheme = state_dict[prefix + "weight_qscheme"]
    if weight_qscheme is not None:
        keys.append("weight_scale")
        keys.append("weight_zero_point")
        if weight_qscheme == torch.quantize_per_channel:
            keys.append("weight_axis")
    return keys
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/sparse/quantized/__pycache__/utils.cpython-311.pyc
ADDED
Binary file (2.36 kB)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/sparse/quantized/dynamic/linear.py
ADDED
@@ -0,0 +1,139 @@
from typing import Optional

import torch
import torch.ao.nn.intrinsic as nni

from torch.ao.nn.sparse.quantized import linear
from torch.ao.nn.sparse.quantized.utils import LinearBlockSparsePattern
from torch.ao.nn.quantized.modules.utils import _quantize_weight, _hide_packed_params_repr

__all__ = ['Linear']

class Linear(torch.nn.Module):
    r"""
    A dynamically quantized sparse linear module with float tensor as inputs and outputs.
    """
    _version = 1
    _op_type = "sparse_dynamic"
    _FLOAT_MODULE = torch.nn.Linear

    def __init__(self, in_features, out_features, row_block_size, col_block_size, bias=True, dtype=torch.qint8):
        super().__init__()

        if dtype != torch.qint8:
            raise NotImplementedError("Only QINT8 is supported for Sparse Quantized Linear Dynamic")

        self.in_features = in_features
        self.out_features = out_features

        if bias:
            bias = torch.zeros(self.out_features, dtype=torch.float)
        else:
            bias = None

        qweight = torch._empty_affine_quantized([out_features, in_features],
                                                scale=1, zero_point=0, dtype=torch.qint8)
        self._packed_params = linear.LinearPackedParams(row_block_size=row_block_size,
                                                        col_block_size=col_block_size,
                                                        dtype=dtype)
        self._packed_params.set_weight_bias(qweight, bias, row_block_size, col_block_size)

    def _get_name(self):
        return 'SparseQuantizedDynamicLinear'

    def extra_repr(self):
        return f'in_features={self.in_features}, out_features={self.out_features}, qscheme={self.weight().qscheme()}'

    def __repr__(self):
        return _hide_packed_params_repr(self, linear.LinearPackedParams)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.ops.sparse.qlinear_dynamic(x, self._packed_params._packed_params)

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        super()._save_to_state_dict(destination, prefix, keep_vars)
        destination[prefix + 'op_type'] = self._op_type

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        op_type = int(state_dict[prefix + 'op_type'])
        assert op_type == 'sparse', \
            f"Cannot load from op_type [{op_type}], expecting [{self._op_type}]"
        state_dict.pop(prefix + 'op_type')

        version = local_metadata.get('version', None)
        assert version <= self._version

        # Is this code valid? In old quantization it seemed to be used to load
        # older model
        weight = state_dict.pop(prefix + 'weight')
        bias = state_dict.pop(prefix + 'bias')
        state_dict.update({prefix + '_packed_params.weight': weight,
                           prefix + '_packed_params.bias': bias})

        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, False,
            missing_keys, unexpected_keys, error_msgs)

    def _weight_bias(self):
        return self._packed_params._weight_bias()

    def weight(self):
        return self._weight_bias()[0]

    def bias(self):
        return self._weight_bias()[1]

    def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor],
                        row_block_size: Optional[int], col_block_size: Optional[int]) -> None:
        assert row_block_size is not None and col_block_size is not None
        self.out_features = w.shape[0]
        self.in_features = w.shape[1]
        self._packed_params.set_weight_bias(w, b, row_block_size, col_block_size)

    @classmethod
    def from_float(cls, mod):
        r"""Create a quantized sparse dynamic module from a float module.

        We only care about the convert at this stage, no need for observers just yet.
        """
        assert type(mod) == cls._FLOAT_MODULE, ' nnq.' + cls.__name__ + '.from_float only works for ' + \
            cls._FLOAT_MODULE.__name__
        # TODO: Need to add options to qconfig to avoid the calibration.
        # TODO: Add calibration for the sparsity
        assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
        if type(mod) == nni.LinearReLU:
            mod = mod[0]
        if mod.qconfig is not None and mod.qconfig.weight is not None:
            weight_observer = mod.qconfig.weight()
        else:
            # We have the circular import issues if we import the qconfig in the beginning of this file:
            # https://github.com/pytorch/pytorch/pull/24231. The current workaround is to postpone the
            # import until we need it.
            from torch.ao.quantization.qconfig import default_dynamic_qconfig
            weight_observer = default_dynamic_qconfig.weight()

        # It is important to multiply by the mask BEFORE calling the `weight_observer`
        # TODO (zaf): Mask might not be part of the qconfig (T83295194)
        weight = mod.weight
        if getattr(mod.qconfig, 'mask', False):
            weight = mod.qconfig.mask * mod.weight

        weight_observer(weight)
        dtype = weight_observer.dtype
        assert dtype == torch.qint8, 'Weight observer must have dtype torch.qint8'
        w_sc, w_zp = weight_observer.calculate_qparams()
        if isinstance(w_zp, torch.Tensor):
            assert not torch.any(w_zp.bool()), "All weight zero points must map to 0"
        else:
            assert w_zp == 0, 'Weight zero point must map to 0'
        qweight = _quantize_weight(weight.float(), weight_observer)

        row_block_size, col_block_size = LinearBlockSparsePattern.block_size()
        qlinear = cls(mod.in_features,
                      mod.out_features,
                      row_block_size,
                      col_block_size,
                      dtype=dtype)
        qlinear.set_weight_bias(qweight, mod.bias, row_block_size, col_block_size)
        return qlinear
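A rough usage sketch for the module above, under a few assumptions: default_dynamic_qconfig supplies the required qint8 weight observer, LinearBlockSparsePattern is usable as a context manager to set the block size that from_float reads via block_size(), and the sparse qlinear_dynamic kernel is available in the current build.

import torch
from torch.ao.quantization import default_dynamic_qconfig
from torch.ao.nn.sparse.quantized.utils import LinearBlockSparsePattern
from torch.ao.nn.sparse.quantized.dynamic import Linear as SparseDynamicLinear

float_linear = torch.nn.Linear(16, 8)
float_linear.qconfig = default_dynamic_qconfig      # from_float asserts a qconfig is present

with LinearBlockSparsePattern(1, 4):                # assumed context-manager usage for a 1x4 block pattern
    sparse_qlinear = SparseDynamicLinear.from_float(float_linear)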
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/_experimental/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (231 Bytes)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/_experimental/data_scheduler/__init__.py
ADDED
@@ -0,0 +1,5 @@
from .base_data_scheduler import BaseDataScheduler

__all__ = [
    "BaseDataScheduler",
]
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/_experimental/data_sparsifier/__pycache__/quantization_utils.cpython-311.pyc
ADDED
Binary file (6.74 kB)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/__init__.py
ADDED
File without changes
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/_experimental/data_sparsifier/quantization_utils.py
ADDED
@@ -0,0 +1,130 @@
import torch
import torch.nn as nn
from torch.ao.pruning.sparsifier.utils import module_to_fqn, fqn_to_module
from typing import Dict, List, Optional

SUPPORTED_MODULES = {
    nn.Embedding,
    nn.EmbeddingBag
}


def _fetch_all_embeddings(model):
    """Fetches Embedding and EmbeddingBag modules from the model
    """
    embedding_modules = []
    stack = [model]
    while stack:
        module = stack.pop()
        for _, child in module.named_children():
            fqn_name = module_to_fqn(model, child)
            if type(child) in SUPPORTED_MODULES:
                embedding_modules.append((fqn_name, child))
            else:
                stack.append(child)
    return embedding_modules


def post_training_sparse_quantize(model,
                                  data_sparsifier_class,
                                  sparsify_first=True,
                                  select_embeddings: Optional[List[nn.Module]] = None,
                                  **sparse_config):
    """Takes in a model and applies sparsification and quantization to only embeddings & embeddingbags.
    The quantization step can happen before or after sparsification depending on the `sparsify_first` argument.

    Args:
        - model (nn.Module)
            model whose embeddings need to be sparsified
        - data_sparsifier_class (type of data sparsifier)
            Type of sparsification that needs to be applied to model
        - sparsify_first (bool)
            if true, sparsifies first and then quantizes
            otherwise, quantizes first and then sparsifies.
        - select_embeddings (List of Embedding modules)
            List of embedding modules in the model to be sparsified & quantized.
            If None, all embedding modules will be sparsified
        - sparse_config (Dict)
            config that will be passed to the constructor of data sparsifier object.

    Note:
        1. When `sparsify_first=False`, quantization occurs first followed by sparsification.
            - before sparsifying, the embedding layers are dequantized.
            - scales and zero-points are saved
            - embedding layers are sparsified and `squash_mask` is applied
            - embedding weights are requantized using the saved scales and zero-points
        2. When `sparsify_first=True`, sparsification occurs first followed by quantization.
            - embeddings are sparsified first
            - quantization is applied on the sparsified embeddings
    """
    data_sparsifier = data_sparsifier_class(**sparse_config)

    # if select_embeddings is None, perform it on all embeddings
    if select_embeddings is None:
        embedding_modules = _fetch_all_embeddings(model)

    else:
        embedding_modules = []
        assert isinstance(select_embeddings, List), "the embedding_modules must be a list of embedding modules"
        for emb in select_embeddings:
            assert type(emb) in SUPPORTED_MODULES, "the embedding_modules list must be an embedding or embedding bags"
            fqn_name = module_to_fqn(model, emb)
            assert fqn_name is not None, "the embedding modules must be part of input model"
            embedding_modules.append((fqn_name, emb))

    if sparsify_first:
        # sparsify
        for name, emb_module in embedding_modules:
            valid_name = name.replace('.', '_')
            data_sparsifier.add_data(name=valid_name, data=emb_module)

        data_sparsifier.step()
        data_sparsifier.squash_mask()

        # quantize
        for _, emb_module in embedding_modules:
            emb_module.qconfig = torch.ao.quantization.float_qparams_weight_only_qconfig

        torch.ao.quantization.prepare(model, inplace=True)
        torch.ao.quantization.convert(model, inplace=True)

    else:
        # quantize
        for _, emb_module in embedding_modules:
            emb_module.qconfig = torch.ao.quantization.float_qparams_weight_only_qconfig

        torch.ao.quantization.prepare(model, inplace=True)
        torch.ao.quantization.convert(model, inplace=True)

        # retrieve scale & zero_points
        quantize_params: Dict[str, Dict] = {'scales': {}, 'zero_points': {},
                                            'dequant_weights': {}, 'axis': {},
                                            'dtype': {}}

        for name, _ in embedding_modules:
            quantized_emb = fqn_to_module(model, name)
            assert quantized_emb is not None  # satisfy mypy

            quantized_weight = quantized_emb.weight()  # type: ignore[operator]
            quantize_params['scales'][name] = quantized_weight.q_per_channel_scales()
            quantize_params['zero_points'][name] = quantized_weight.q_per_channel_zero_points()
            quantize_params['dequant_weights'][name] = torch.dequantize(quantized_weight)
            quantize_params['axis'][name] = quantized_weight.q_per_channel_axis()
            quantize_params['dtype'][name] = quantized_weight.dtype

            # attach data to sparsifier
            data_sparsifier.add_data(name=name.replace('.', '_'), data=quantize_params['dequant_weights'][name])

        data_sparsifier.step()
        data_sparsifier.squash_mask()

        for name, _ in embedding_modules:
            quantized_emb = fqn_to_module(model, name)
            assert quantized_emb is not None  # satisfy mypy
            requantized_vector = torch.quantize_per_channel(quantize_params['dequant_weights'][name],
                                                            scales=quantize_params['scales'][name],
                                                            zero_points=quantize_params['zero_points'][name],
                                                            dtype=quantize_params['dtype'][name],
                                                            axis=quantize_params['axis'][name])

            quantized_emb.set_weight(requantized_vector)  # type: ignore[operator]
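A hedged usage sketch for post_training_sparse_quantize above, sparsifying and then quantizing the embeddings of a toy model. The DataNormSparsifier import path and its sparsity_level keyword are assumptions about this experimental package, not guarantees.

import torch.nn as nn
from torch.ao.pruning._experimental.data_sparsifier.data_norm_sparsifier import DataNormSparsifier
from torch.ao.pruning._experimental.data_sparsifier.quantization_utils import post_training_sparse_quantize

model = nn.Sequential(nn.EmbeddingBag(1000, 64), nn.Linear(64, 10))
# Embeddings are sparsified first, then quantized with the weight-only float-qparams qconfig.
post_training_sparse_quantize(model, DataNormSparsifier, sparsify_first=True, sparsity_level=0.8)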
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/FPGM_pruner.cpython-311.pyc
ADDED
Binary file (5.52 kB)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (659 Bytes)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/_experimental/pruner/saliency_pruner.py
ADDED
@@ -0,0 +1,29 @@
from .base_structured_sparsifier import BaseStructuredSparsifier


class SaliencyPruner(BaseStructuredSparsifier):
    """
    Prune rows based on the saliency (L1 norm) of each row.

    This pruner works on N-Dimensional weight tensors.
    For each row, we calculate the saliency, which is the sum of the L1 norm of all weights in that row.
    We expect that the resulting saliency vector has the same shape as our mask.
    We then pick elements to remove until we reach the target sparsity_level.
    """

    def update_mask(self, module, tensor_name, **kwargs):
        # tensor_name will give you the FQN; all other entries in the sparse config are present in kwargs
        weights = getattr(module, tensor_name)
        mask = getattr(module.parametrizations, tensor_name)[0].mask

        # use negative weights so we can use topk (we prune out the smallest)
        if weights.dim() <= 1:
            raise Exception("Structured pruning can only be applied to a 2+dim weight tensor!")
        saliency = -weights.norm(dim=tuple(range(1, weights.dim())), p=1)
        assert saliency.shape == mask.shape

        num_to_pick = int(len(mask) * kwargs["sparsity_level"])
        prune = saliency.topk(num_to_pick).indices

        # Set the mask to be false for the rows we want to prune
        mask.data[prune] = False
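A hedged sketch of driving SaliencyPruner above. The prepare/step flow and the {"tensor_fqn": ...} config format follow the base structured-sparsifier API as understood here, and the import path is an assumption; treat the names as illustrative.

import torch.nn as nn
from torch.ao.pruning._experimental.pruner import SaliencyPruner

model = nn.Sequential(nn.Linear(16, 8), nn.ReLU(), nn.Linear(8, 4))
pruner = SaliencyPruner({"sparsity_level": 0.5})        # default config shared by all groups
pruner.prepare(model, [{"tensor_fqn": "0.weight"}])     # parametrize the first Linear's weight
pruner.step()                                           # update_mask zeroes the least-salient rows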
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/scheduler/__init__.py
ADDED
File without changes
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/scheduler/__pycache__/base_scheduler.cpython-311.pyc
ADDED
Binary file (8.9 kB)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/scheduler/lambda_scheduler.py
ADDED
@@ -0,0 +1,47 @@
import warnings

from .base_scheduler import BaseScheduler

__all__ = ["LambdaSL"]

class LambdaSL(BaseScheduler):
    """Sets the sparsity level of each parameter group to the final sl
    times a given function. When last_epoch=-1, sets initial sl as zero.
    Args:
        sparsifier (BaseSparsifier): Wrapped sparsifier.
        sl_lambda (function or list): A function which computes a multiplicative
            factor given an integer parameter epoch, or a list of such
            functions, one for each group in sparsifier.param_groups.
        last_epoch (int): The index of last epoch. Default: -1.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.
    Example:
        >>> # Assuming sparsifier has two groups.
        >>> lambda1 = lambda epoch: epoch // 30
        >>> lambda2 = lambda epoch: 0.95 ** epoch
        >>> # xdoctest: +SKIP
        >>> scheduler = LambdaSL(sparsifier, sl_lambda=[lambda1, lambda2])
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(self, sparsifier, sl_lambda, last_epoch=-1, verbose=False):
        self.sparsifier = sparsifier

        if not isinstance(sl_lambda, list) and not isinstance(sl_lambda, tuple):
            self.sl_lambdas = [sl_lambda] * len(sparsifier.groups)
        else:
            if len(sl_lambda) != len(sparsifier.groups):
                raise ValueError(f"Expected {len(sparsifier.groups)} lr_lambdas, but got {len(sl_lambda)}")
            self.sl_lambdas = list(sl_lambda)
        super().__init__(sparsifier, last_epoch, verbose)

    def get_sl(self):
        if not self._get_sl_called_within_step:
            warnings.warn(
                "To get the last sparsity level computed by the scheduler, "
                "please use `get_last_sl()`.")
        return [base_sl * lmbda(self.last_epoch)
                for lmbda, base_sl in zip(self.sl_lambdas, self.base_sl)]
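The class docstring above already shows the training-loop usage; for reference, the arithmetic performed by get_sl is simply the group's base sparsity level multiplied by the lambda evaluated at the current epoch:

base_sl = 0.9
sl_lambda = lambda epoch: 0.95 ** epoch
schedule = [base_sl * sl_lambda(t) for t in range(5)]
# [0.9, 0.855, 0.81225, 0.7716375, 0.733055625]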
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/sparsifier/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (228 Bytes)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/sparsifier/__pycache__/base_sparsifier.cpython-311.pyc
ADDED
Binary file (17.5 kB)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/sparsifier/__pycache__/weight_norm_sparsifier.cpython-311.pyc
ADDED
Binary file (11.2 kB)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/_equalize.py
ADDED
@@ -0,0 +1,182 @@
import torch
import copy
from typing import Dict, Any

__all__ = [
    "set_module_weight",
    "set_module_bias",
    "get_module_weight",
    "get_module_bias",
    "max_over_ndim",
    "min_over_ndim",
    "channel_range",
    "cross_layer_equalization",
    "equalize",
    "converged",
]

_supported_types = {torch.nn.Conv2d, torch.nn.Linear}
_supported_intrinsic_types = {torch.ao.nn.intrinsic.ConvReLU2d, torch.ao.nn.intrinsic.LinearReLU}
_all_supported_types = _supported_types.union(_supported_intrinsic_types)

def set_module_weight(module, weight) -> None:
    if type(module) in _supported_types:
        module.weight = torch.nn.Parameter(weight)
    else:
        module[0].weight = torch.nn.Parameter(weight)

def set_module_bias(module, bias) -> None:
    if type(module) in _supported_types:
        module.bias = torch.nn.Parameter(bias)
    else:
        module[0].bias = torch.nn.Parameter(bias)

def get_module_weight(module):
    if type(module) in _supported_types:
        return module.weight
    else:
        return module[0].weight

def get_module_bias(module):
    if type(module) in _supported_types:
        return module.bias
    else:
        return module[0].bias

def max_over_ndim(input, axis_list, keepdim=False):
    """Apply 'torch.max' over the given axes."""
    axis_list.sort(reverse=True)
    for axis in axis_list:
        input, _ = input.max(axis, keepdim)
    return input

def min_over_ndim(input, axis_list, keepdim=False):
    """Apply 'torch.min' over the given axes."""
    axis_list.sort(reverse=True)
    for axis in axis_list:
        input, _ = input.min(axis, keepdim)
    return input

def channel_range(input, axis=0):
    """Find the range of weights associated with a specific channel."""
    size_of_tensor_dim = input.ndim
    axis_list = list(range(size_of_tensor_dim))
    axis_list.remove(axis)

    mins = min_over_ndim(input, axis_list)
    maxs = max_over_ndim(input, axis_list)

    assert mins.size(0) == input.size(axis), "Dimensions of resultant channel range do not match size of requested axis"
    return maxs - mins

def cross_layer_equalization(module1, module2, output_axis=0, input_axis=1):
    """Scale the range of Tensor1.output to equal Tensor2.input.

    Given two adjacent modules, the weights are scaled such that
    the ranges of the first module's output channels are equal to the
    ranges of the second module's input channels.
    """
    if type(module1) not in _all_supported_types or type(module2) not in _all_supported_types:
        raise ValueError("module type not supported:", type(module1), " ", type(module2))

    weight1 = get_module_weight(module1)
    weight2 = get_module_weight(module2)

    if weight1.size(output_axis) != weight2.size(input_axis):
        raise TypeError("Number of output channels of first arg does not match "
                        "number of input channels of second arg")

    bias = get_module_bias(module1)

    weight1_range = channel_range(weight1, output_axis)
    weight2_range = channel_range(weight2, input_axis)

    # producing scaling factors to be applied
    weight2_range += 1e-9
    scaling_factors = torch.sqrt(weight1_range / weight2_range)
    inverse_scaling_factors = torch.reciprocal(scaling_factors)

    bias = bias * inverse_scaling_factors

    # formatting the scaling (1D) tensors to be applied on the given argument tensors
    # pads axis to (1D) tensors to then be broadcasted
    size1 = [1] * weight1.ndim
    size1[output_axis] = weight1.size(output_axis)
    size2 = [1] * weight2.ndim
    size2[input_axis] = weight2.size(input_axis)

    scaling_factors = torch.reshape(scaling_factors, size2)
    inverse_scaling_factors = torch.reshape(inverse_scaling_factors, size1)

    weight1 = weight1 * inverse_scaling_factors
    weight2 = weight2 * scaling_factors

    set_module_weight(module1, weight1)
    set_module_bias(module1, bias)
    set_module_weight(module2, weight2)

def equalize(model, paired_modules_list, threshold=1e-4, inplace=True):
    """Equalize modules until convergence is achieved.

    Given a list of adjacent modules within a model, equalization will
    be applied between each pair; this is repeated until convergence is achieved.

    Keeps a copy of the changing modules from the previous iteration; if the copies
    are not that different from the current modules (as determined by `converged`),
    then the modules have converged enough that further equalizing is not necessary.

    The implementation references section 4.1 of this paper: https://arxiv.org/pdf/1906.04721.pdf

    Args:
        model: a model (nn.Module) that equalization is to be applied on
        paired_modules_list: a list of lists where each sublist is a pair of two
            submodules found in the model; for each pair the two submodules generally
            have to be adjacent in the model to get expected/reasonable results
        threshold: a number used by the converged function to determine what degree of
            similarity between models is necessary for them to be called equivalent
        inplace: determines if function is inplace or not
    """
    if not inplace:
        model = copy.deepcopy(model)

    name_to_module: Dict[str, torch.nn.Module] = {}
    previous_name_to_module: Dict[str, Any] = {}
    name_set = {name for pair in paired_modules_list for name in pair}

    for name, module in model.named_modules():
        if name in name_set:
            name_to_module[name] = module
            previous_name_to_module[name] = None
    while not converged(name_to_module, previous_name_to_module, threshold):
        for pair in paired_modules_list:
            previous_name_to_module[pair[0]] = copy.deepcopy(name_to_module[pair[0]])
            previous_name_to_module[pair[1]] = copy.deepcopy(name_to_module[pair[1]])

            cross_layer_equalization(name_to_module[pair[0]], name_to_module[pair[1]])

    return model

def converged(curr_modules, prev_modules, threshold=1e-4):
    """Test whether modules are converged to a specified threshold.

    Tests for the summed norm of the differences between each set of modules
    being less than the given threshold.

    Takes two dictionaries mapping names to modules; the set of names for each dictionary
    should be the same. Looping over the set of names, for each name take the difference
    between the associated modules in each dictionary.
    """
    if curr_modules.keys() != prev_modules.keys():
        raise ValueError("The keys to the given mappings must have the same set of names of modules")

    summed_norms = torch.tensor(0.)
    if None in prev_modules.values():
        return False
    for name in curr_modules.keys():
        curr_weight = get_module_weight(curr_modules[name])
        prev_weight = get_module_weight(prev_modules[name])

        difference = curr_weight.sub(prev_weight)
        summed_norms += torch.norm(difference)
    return bool(summed_norms < threshold)
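A minimal usage sketch for equalize above, pairing two adjacent Linear layers. The names "0" and "2" are nn.Sequential's default child names, and the intervening ReLU preserves the pair's overall function because the scaling factors are positive.

import torch.nn as nn
from torch.ao.quantization._equalize import equalize

model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
equalized = equalize(model, [["0", "2"]], threshold=1e-4, inplace=False)
# Per-channel weight ranges of "0" (outputs) and "2" (inputs) are now balanced,
# which typically makes per-tensor quantization of both layers less lossy.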
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/__init__.py
ADDED
@@ -0,0 +1,23 @@
from .backend_config import BackendConfig, BackendPatternConfig, DTypeConfig, DTypeWithConstraints, ObservationType
from .fbgemm import get_fbgemm_backend_config
from .native import get_native_backend_config, get_native_backend_config_dict
from .qnnpack import get_qnnpack_backend_config
from .tensorrt import get_tensorrt_backend_config, get_tensorrt_backend_config_dict
from .executorch import get_executorch_backend_config
from .onednn import get_onednn_backend_config

__all__ = [
    "get_fbgemm_backend_config",
    "get_native_backend_config",
    "get_native_backend_config_dict",
    "get_qnnpack_backend_config",
    "get_tensorrt_backend_config",
    "get_tensorrt_backend_config_dict",
    "get_executorch_backend_config",
    "BackendConfig",
    "BackendPatternConfig",
    "DTypeConfig",
    "DTypeWithConstraints",
    "ObservationType",
    "get_onednn_backend_config",
]
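A usage sketch for the exported helpers: passing a BackendConfig explicitly into FX graph-mode quantization. The backend_config keyword on prepare_fx/convert_fx is assumed to match this torch version.

import torch
from torch.ao.quantization import get_default_qconfig_mapping
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx
from torch.ao.quantization.backend_config import get_native_backend_config

model = torch.nn.Sequential(torch.nn.Linear(8, 4)).eval()
example_inputs = (torch.randn(1, 8),)
backend_config = get_native_backend_config()

prepared = prepare_fx(model, get_default_qconfig_mapping(), example_inputs,
                      backend_config=backend_config)
prepared(*example_inputs)                                 # calibration pass
quantized = convert_fx(prepared, backend_config=backend_config)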
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (1.08 kB)