Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .venv/Lib/site-packages/torch/nn/__pycache__/__init__.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/nn/__pycache__/_reduction.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/nn/__pycache__/common_types.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/nn/__pycache__/functional.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/nn/__pycache__/grad.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/nn/__pycache__/init.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/nn/__pycache__/parameter.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/nn/quantized/__pycache__/__init__.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/nn/quantized/__pycache__/functional.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/nn/quantized/_reference/__init__.py +1 -0
- .venv/Lib/site-packages/torch/nn/quantized/_reference/modules/__init__.py +39 -0
- .venv/Lib/site-packages/torch/nn/quantized/_reference/modules/conv.py +21 -0
- .venv/Lib/site-packages/torch/nn/quantized/_reference/modules/linear.py +12 -0
- .venv/Lib/site-packages/torch/nn/quantized/_reference/modules/rnn.py +19 -0
- .venv/Lib/site-packages/torch/nn/quantized/_reference/modules/sparse.py +12 -0
- .venv/Lib/site-packages/torch/nn/quantized/_reference/modules/utils.py +18 -0
- .venv/Lib/site-packages/torch/nn/utils/__pycache__/__init__.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/nn/utils/__pycache__/_named_member_accessor.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/nn/utils/__pycache__/clip_grad.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/nn/utils/__pycache__/convert_parameters.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/nn/utils/__pycache__/fusion.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/nn/utils/__pycache__/init.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/nn/utils/__pycache__/memory_format.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/nn/utils/__pycache__/parametrizations.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/nn/utils/__pycache__/parametrize.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/nn/utils/__pycache__/rnn.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/nn/utils/__pycache__/spectral_norm.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/nn/utils/__pycache__/stateless.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/nn/utils/__pycache__/weight_norm.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/nn/utils/_expanded_weights/__init__.py +10 -0
- .venv/Lib/site-packages/torch/nn/utils/clip_grad.py +189 -0
- .venv/Lib/site-packages/torch/onnx/__init__.py +553 -0
- .venv/Lib/site-packages/torch/onnx/_constants.py +25 -0
- .venv/Lib/site-packages/torch/onnx/_deprecation.py +72 -0
- .venv/Lib/site-packages/torch/onnx/_experimental.py +27 -0
- .venv/Lib/site-packages/torch/onnx/_exporter_states.py +12 -0
- .venv/Lib/site-packages/torch/onnx/_flags.py +49 -0
- .venv/Lib/site-packages/torch/onnx/_globals.py +87 -0
- .venv/Lib/site-packages/torch/onnx/_internal/__init__.py +0 -0
- .venv/Lib/site-packages/torch/onnx/_internal/_lazy_import.py +41 -0
- .venv/Lib/site-packages/torch/onnx/_internal/diagnostics/__init__.py +22 -0
- .venv/Lib/site-packages/torch/onnx/_internal/diagnostics/_diagnostic.py +211 -0
- .venv/Lib/site-packages/torch/onnx/_internal/diagnostics/_rules.py +636 -0
- .venv/Lib/site-packages/torch/onnx/_internal/diagnostics/infra/_infra.py +285 -0
- .venv/Lib/site-packages/torch/onnx/_internal/diagnostics/infra/context.py +404 -0
- .venv/Lib/site-packages/torch/onnx/_internal/diagnostics/infra/decorator.py +153 -0
- .venv/Lib/site-packages/torch/onnx/_internal/diagnostics/infra/formatter.py +106 -0
- .venv/Lib/site-packages/torch/onnx/_internal/diagnostics/infra/utils.py +69 -0
- .venv/Lib/site-packages/torch/onnx/_internal/io_adapter.py +641 -0
- .venv/Lib/site-packages/torch/onnx/_internal/jit_utils.py +373 -0
.venv/Lib/site-packages/torch/nn/__pycache__/__init__.cpython-39.pyc
ADDED
|
Binary file (2.23 kB). View file
|
|
|
.venv/Lib/site-packages/torch/nn/__pycache__/_reduction.cpython-39.pyc
ADDED
|
Binary file (1.31 kB). View file
|
|
|
.venv/Lib/site-packages/torch/nn/__pycache__/common_types.cpython-39.pyc
ADDED
|
Binary file (1.03 kB). View file
|
|
|
.venv/Lib/site-packages/torch/nn/__pycache__/functional.cpython-39.pyc
ADDED
|
Binary file (182 kB). View file
|
|
|
.venv/Lib/site-packages/torch/nn/__pycache__/grad.cpython-39.pyc
ADDED
|
Binary file (9.05 kB). View file
|
|
|
.venv/Lib/site-packages/torch/nn/__pycache__/init.cpython-39.pyc
ADDED
|
Binary file (21.2 kB). View file
|
|
|
.venv/Lib/site-packages/torch/nn/__pycache__/parameter.cpython-39.pyc
ADDED
|
Binary file (10.8 kB). View file
|
|
|
.venv/Lib/site-packages/torch/nn/quantized/__pycache__/__init__.cpython-39.pyc
ADDED
|
Binary file (744 Bytes). View file
|
|
|
.venv/Lib/site-packages/torch/nn/quantized/__pycache__/functional.cpython-39.pyc
ADDED
|
Binary file (459 Bytes). View file
|
|
|
.venv/Lib/site-packages/torch/nn/quantized/_reference/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from torch.nn.quantized._reference.modules import * # noqa: F403
|
.venv/Lib/site-packages/torch/nn/quantized/_reference/modules/__init__.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# flake8: noqa: F401
|
| 2 |
+
r"""Quantized Reference Modules.
|
| 3 |
+
|
| 4 |
+
This module is in the process of migration to
|
| 5 |
+
`torch/ao/nn/quantized/reference`, and is kept here for
|
| 6 |
+
compatibility while the migration process is ongoing.
|
| 7 |
+
If you are adding a new entry/functionality, please, add it to the
|
| 8 |
+
appropriate file under the `torch/ao/nn/quantized/reference`,
|
| 9 |
+
while adding an import statement here.
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
from torch.ao.nn.quantized.reference.modules.conv import (
|
| 13 |
+
Conv1d,
|
| 14 |
+
Conv2d,
|
| 15 |
+
Conv3d,
|
| 16 |
+
ConvTranspose1d,
|
| 17 |
+
ConvTranspose2d,
|
| 18 |
+
ConvTranspose3d,
|
| 19 |
+
)
|
| 20 |
+
from torch.ao.nn.quantized.reference.modules.linear import Linear
|
| 21 |
+
from torch.ao.nn.quantized.reference.modules.rnn import GRUCell, LSTM, LSTMCell, RNNCell
|
| 22 |
+
from torch.ao.nn.quantized.reference.modules.sparse import Embedding, EmbeddingBag
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
__all__ = [
|
| 26 |
+
"Linear",
|
| 27 |
+
"Conv1d",
|
| 28 |
+
"Conv2d",
|
| 29 |
+
"Conv3d",
|
| 30 |
+
"ConvTranspose1d",
|
| 31 |
+
"ConvTranspose2d",
|
| 32 |
+
"ConvTranspose3d",
|
| 33 |
+
"RNNCell",
|
| 34 |
+
"LSTMCell",
|
| 35 |
+
"GRUCell",
|
| 36 |
+
"LSTM",
|
| 37 |
+
"Embedding",
|
| 38 |
+
"EmbeddingBag",
|
| 39 |
+
]
|
.venv/Lib/site-packages/torch/nn/quantized/_reference/modules/conv.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# flake8: noqa: F401
|
| 2 |
+
r"""Quantized Reference Modules.
|
| 3 |
+
|
| 4 |
+
This module is in the process of migration to
|
| 5 |
+
`torch/ao/nn/quantized/reference`, and is kept here for
|
| 6 |
+
compatibility while the migration process is ongoing.
|
| 7 |
+
If you are adding a new entry/functionality, please, add it to the
|
| 8 |
+
appropriate file under the `torch/ao/nn/quantized/reference`,
|
| 9 |
+
while adding an import statement here.
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
from torch.ao.nn.quantized.reference.modules.conv import (
|
| 13 |
+
_ConvNd,
|
| 14 |
+
_ConvTransposeNd,
|
| 15 |
+
Conv1d,
|
| 16 |
+
Conv2d,
|
| 17 |
+
Conv3d,
|
| 18 |
+
ConvTranspose1d,
|
| 19 |
+
ConvTranspose2d,
|
| 20 |
+
ConvTranspose3d,
|
| 21 |
+
)
|
.venv/Lib/site-packages/torch/nn/quantized/_reference/modules/linear.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# flake8: noqa: F401
|
| 2 |
+
r"""Quantized Reference Modules.
|
| 3 |
+
|
| 4 |
+
This module is in the process of migration to
|
| 5 |
+
`torch/ao/nn/quantized/reference`, and is kept here for
|
| 6 |
+
compatibility while the migration process is ongoing.
|
| 7 |
+
If you are adding a new entry/functionality, please, add it to the
|
| 8 |
+
appropriate file under the `torch/ao/nn/quantized/reference`,
|
| 9 |
+
while adding an import statement here.
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
from torch.ao.nn.quantized.reference.modules.linear import Linear
|
.venv/Lib/site-packages/torch/nn/quantized/_reference/modules/rnn.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# flake8: noqa: F401
|
| 2 |
+
r"""Quantized Reference Modules.
|
| 3 |
+
|
| 4 |
+
This module is in the process of migration to
|
| 5 |
+
`torch/ao/nn/quantized/reference`, and is kept here for
|
| 6 |
+
compatibility while the migration process is ongoing.
|
| 7 |
+
If you are adding a new entry/functionality, please, add it to the
|
| 8 |
+
appropriate file under the `torch/ao/nn/quantized/reference`,
|
| 9 |
+
while adding an import statement here.
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
from torch.ao.nn.quantized.reference.modules.rnn import (
|
| 13 |
+
GRUCell,
|
| 14 |
+
LSTM,
|
| 15 |
+
LSTMCell,
|
| 16 |
+
RNNBase,
|
| 17 |
+
RNNCell,
|
| 18 |
+
RNNCellBase,
|
| 19 |
+
)
|
.venv/Lib/site-packages/torch/nn/quantized/_reference/modules/sparse.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# flake8: noqa: F401
|
| 2 |
+
r"""Quantized Reference Modules.
|
| 3 |
+
|
| 4 |
+
This module is in the process of migration to
|
| 5 |
+
`torch/ao/nn/quantized/reference`, and is kept here for
|
| 6 |
+
compatibility while the migration process is ongoing.
|
| 7 |
+
If you are adding a new entry/functionality, please, add it to the
|
| 8 |
+
appropriate file under the `torch/ao/nn/quantized/reference`,
|
| 9 |
+
while adding an import statement here.
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
from torch.ao.nn.quantized.reference.modules.sparse import Embedding, EmbeddingBag
|
.venv/Lib/site-packages/torch/nn/quantized/_reference/modules/utils.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# flake8: noqa: F401
|
| 2 |
+
r"""Quantized Reference Modules.
|
| 3 |
+
|
| 4 |
+
This module is in the process of migration to
|
| 5 |
+
`torch/ao/nn/quantized/reference`, and is kept here for
|
| 6 |
+
compatibility while the migration process is ongoing.
|
| 7 |
+
If you are adding a new entry/functionality, please, add it to the
|
| 8 |
+
appropriate file under the `torch/ao/nn/quantized/reference`,
|
| 9 |
+
while adding an import statement here.
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
from torch.ao.nn.quantized.reference.modules.utils import (
|
| 13 |
+
_get_weight_qparam_keys,
|
| 14 |
+
_quantize_and_dequantize_weight,
|
| 15 |
+
_quantize_weight,
|
| 16 |
+
_save_weight_qparams,
|
| 17 |
+
ReferenceQuantizedModule,
|
| 18 |
+
)
|
.venv/Lib/site-packages/torch/nn/utils/__pycache__/__init__.cpython-39.pyc
ADDED
|
Binary file (1 kB). View file
|
|
|
.venv/Lib/site-packages/torch/nn/utils/__pycache__/_named_member_accessor.cpython-39.pyc
ADDED
|
Binary file (12.1 kB). View file
|
|
|
.venv/Lib/site-packages/torch/nn/utils/__pycache__/clip_grad.cpython-39.pyc
ADDED
|
Binary file (6.26 kB). View file
|
|
|
.venv/Lib/site-packages/torch/nn/utils/__pycache__/convert_parameters.cpython-39.pyc
ADDED
|
Binary file (2.51 kB). View file
|
|
|
.venv/Lib/site-packages/torch/nn/utils/__pycache__/fusion.cpython-39.pyc
ADDED
|
Binary file (4.97 kB). View file
|
|
|
.venv/Lib/site-packages/torch/nn/utils/__pycache__/init.cpython-39.pyc
ADDED
|
Binary file (2.38 kB). View file
|
|
|
.venv/Lib/site-packages/torch/nn/utils/__pycache__/memory_format.cpython-39.pyc
ADDED
|
Binary file (7.57 kB). View file
|
|
|
.venv/Lib/site-packages/torch/nn/utils/__pycache__/parametrizations.cpython-39.pyc
ADDED
|
Binary file (18 kB). View file
|
|
|
.venv/Lib/site-packages/torch/nn/utils/__pycache__/parametrize.cpython-39.pyc
ADDED
|
Binary file (23.7 kB). View file
|
|
|
.venv/Lib/site-packages/torch/nn/utils/__pycache__/rnn.cpython-39.pyc
ADDED
|
Binary file (20.9 kB). View file
|
|
|
.venv/Lib/site-packages/torch/nn/utils/__pycache__/spectral_norm.cpython-39.pyc
ADDED
|
Binary file (9.94 kB). View file
|
|
|
.venv/Lib/site-packages/torch/nn/utils/__pycache__/stateless.cpython-39.pyc
ADDED
|
Binary file (9.39 kB). View file
|
|
|
.venv/Lib/site-packages/torch/nn/utils/__pycache__/weight_norm.cpython-39.pyc
ADDED
|
Binary file (5.9 kB). View file
|
|
|
.venv/Lib/site-packages/torch/nn/utils/_expanded_weights/__init__.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .conv_expanded_weights import ConvPerSampleGrad
|
| 2 |
+
from .embedding_expanded_weights import EmbeddingPerSampleGrad
|
| 3 |
+
from .expanded_weights_impl import ExpandedWeight
|
| 4 |
+
from .group_norm_expanded_weights import GroupNormPerSampleGrad
|
| 5 |
+
from .instance_norm_expanded_weights import InstanceNormPerSampleGrad
|
| 6 |
+
from .layer_norm_expanded_weights import LayerNormPerSampleGrad
|
| 7 |
+
from .linear_expanded_weights import LinearPerSampleGrad
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
__all__ = ["ExpandedWeight"]
|
.venv/Lib/site-packages/torch/nn/utils/clip_grad.py
ADDED
|
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-decorators
|
| 2 |
+
# mypy: allow-untyped-defs
|
| 3 |
+
import functools
|
| 4 |
+
from typing import cast, Dict, Iterable, List, Optional, Tuple, Union
|
| 5 |
+
from typing_extensions import deprecated
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from torch import Tensor
|
| 9 |
+
from torch.utils._foreach_utils import (
|
| 10 |
+
_device_has_foreach_support,
|
| 11 |
+
_group_tensors_by_device_and_dtype,
|
| 12 |
+
_has_foreach_support,
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
__all__ = ["clip_grad_norm_", "clip_grad_norm", "clip_grad_value_"]
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
_tensor_or_tensors = Union[torch.Tensor, Iterable[torch.Tensor]]
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def _no_grad(func):
|
| 23 |
+
"""
|
| 24 |
+
This wrapper is needed to avoid a circular import when using @torch.no_grad on the exposed functions
|
| 25 |
+
clip_grad_norm_ and clip_grad_value_ themselves.
|
| 26 |
+
"""
|
| 27 |
+
|
| 28 |
+
def _no_grad_wrapper(*args, **kwargs):
|
| 29 |
+
with torch.no_grad():
|
| 30 |
+
return func(*args, **kwargs)
|
| 31 |
+
|
| 32 |
+
functools.update_wrapper(_no_grad_wrapper, func)
|
| 33 |
+
return _no_grad_wrapper
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
@_no_grad
|
| 37 |
+
def clip_grad_norm_(
|
| 38 |
+
parameters: _tensor_or_tensors,
|
| 39 |
+
max_norm: float,
|
| 40 |
+
norm_type: float = 2.0,
|
| 41 |
+
error_if_nonfinite: bool = False,
|
| 42 |
+
foreach: Optional[bool] = None,
|
| 43 |
+
) -> torch.Tensor:
|
| 44 |
+
r"""Clip the gradient norm of an iterable of parameters.
|
| 45 |
+
|
| 46 |
+
The norm is computed over the norms of the individual gradients of all parameters,
|
| 47 |
+
as if the norms of the individual gradients were concatenated into a single vector.
|
| 48 |
+
Gradients are modified in-place.
|
| 49 |
+
|
| 50 |
+
Args:
|
| 51 |
+
parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
|
| 52 |
+
single Tensor that will have gradients normalized
|
| 53 |
+
max_norm (float): max norm of the gradients
|
| 54 |
+
norm_type (float): type of the used p-norm. Can be ``'inf'`` for
|
| 55 |
+
infinity norm.
|
| 56 |
+
error_if_nonfinite (bool): if True, an error is thrown if the total
|
| 57 |
+
norm of the gradients from :attr:`parameters` is ``nan``,
|
| 58 |
+
``inf``, or ``-inf``. Default: False (will switch to True in the future)
|
| 59 |
+
foreach (bool): use the faster foreach-based implementation.
|
| 60 |
+
If ``None``, use the foreach implementation for CUDA and CPU native tensors and silently
|
| 61 |
+
fall back to the slow implementation for other device types.
|
| 62 |
+
Default: ``None``
|
| 63 |
+
|
| 64 |
+
Returns:
|
| 65 |
+
Total norm of the parameter gradients (viewed as a single vector).
|
| 66 |
+
"""
|
| 67 |
+
if isinstance(parameters, torch.Tensor):
|
| 68 |
+
parameters = [parameters]
|
| 69 |
+
grads = [p.grad for p in parameters if p.grad is not None]
|
| 70 |
+
max_norm = float(max_norm)
|
| 71 |
+
norm_type = float(norm_type)
|
| 72 |
+
if len(grads) == 0:
|
| 73 |
+
return torch.tensor(0.0)
|
| 74 |
+
first_device = grads[0].device
|
| 75 |
+
grouped_grads: Dict[
|
| 76 |
+
Tuple[torch.device, torch.dtype], Tuple[List[List[Tensor]], List[int]]
|
| 77 |
+
] = _group_tensors_by_device_and_dtype(
|
| 78 |
+
[grads]
|
| 79 |
+
) # type: ignore[assignment]
|
| 80 |
+
|
| 81 |
+
norms: List[Tensor] = []
|
| 82 |
+
for (device, _), ([device_grads], _) in grouped_grads.items(): # type: ignore[assignment]
|
| 83 |
+
if (foreach is None and _has_foreach_support(device_grads, device)) or (
|
| 84 |
+
foreach and _device_has_foreach_support(device)
|
| 85 |
+
):
|
| 86 |
+
norms.extend(torch._foreach_norm(device_grads, norm_type))
|
| 87 |
+
elif foreach:
|
| 88 |
+
raise RuntimeError(
|
| 89 |
+
f"foreach=True was passed, but can't use the foreach API on {device.type} tensors"
|
| 90 |
+
)
|
| 91 |
+
else:
|
| 92 |
+
norms.extend([torch.linalg.vector_norm(g, norm_type) for g in device_grads])
|
| 93 |
+
|
| 94 |
+
total_norm = torch.linalg.vector_norm(
|
| 95 |
+
torch.stack([norm.to(first_device) for norm in norms]), norm_type
|
| 96 |
+
)
|
| 97 |
+
|
| 98 |
+
if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()):
|
| 99 |
+
raise RuntimeError(
|
| 100 |
+
f"The total norm of order {norm_type} for gradients from "
|
| 101 |
+
"`parameters` is non-finite, so it cannot be clipped. To disable "
|
| 102 |
+
"this error and scale the gradients by the non-finite norm anyway, "
|
| 103 |
+
"set `error_if_nonfinite=False`"
|
| 104 |
+
)
|
| 105 |
+
clip_coef = max_norm / (total_norm + 1e-6)
|
| 106 |
+
# Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so
|
| 107 |
+
# avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization
|
| 108 |
+
# when the gradients do not reside in CPU memory.
|
| 109 |
+
clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
|
| 110 |
+
for (device, _), ([device_grads], _) in grouped_grads.items(): # type: ignore[assignment]
|
| 111 |
+
if (foreach is None and _has_foreach_support(device_grads, device)) or (
|
| 112 |
+
foreach and _device_has_foreach_support(device)
|
| 113 |
+
):
|
| 114 |
+
torch._foreach_mul_(device_grads, clip_coef_clamped.to(device))
|
| 115 |
+
elif foreach:
|
| 116 |
+
raise RuntimeError(
|
| 117 |
+
f"foreach=True was passed, but can't use the foreach API on {device.type} tensors"
|
| 118 |
+
)
|
| 119 |
+
else:
|
| 120 |
+
clip_coef_clamped_device = clip_coef_clamped.to(device)
|
| 121 |
+
for g in device_grads:
|
| 122 |
+
g.mul_(clip_coef_clamped_device)
|
| 123 |
+
|
| 124 |
+
return total_norm
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
@deprecated(
|
| 128 |
+
"`torch.nn.utils.clip_grad_norm` is now deprecated "
|
| 129 |
+
"in favor of `torch.nn.utils.clip_grad_norm_`.",
|
| 130 |
+
category=FutureWarning,
|
| 131 |
+
)
|
| 132 |
+
def clip_grad_norm(
|
| 133 |
+
parameters: _tensor_or_tensors,
|
| 134 |
+
max_norm: float,
|
| 135 |
+
norm_type: float = 2.0,
|
| 136 |
+
error_if_nonfinite: bool = False,
|
| 137 |
+
foreach: Optional[bool] = None,
|
| 138 |
+
) -> torch.Tensor:
|
| 139 |
+
r"""Clip the gradient norm of an iterable of parameters.
|
| 140 |
+
|
| 141 |
+
.. warning::
|
| 142 |
+
This method is now deprecated in favor of
|
| 143 |
+
:func:`torch.nn.utils.clip_grad_norm_`.
|
| 144 |
+
"""
|
| 145 |
+
return clip_grad_norm_(parameters, max_norm, norm_type, error_if_nonfinite, foreach)
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
@_no_grad
|
| 149 |
+
def clip_grad_value_(
|
| 150 |
+
parameters: _tensor_or_tensors,
|
| 151 |
+
clip_value: float,
|
| 152 |
+
foreach: Optional[bool] = None,
|
| 153 |
+
) -> None:
|
| 154 |
+
r"""Clip the gradients of an iterable of parameters at specified value.
|
| 155 |
+
|
| 156 |
+
Gradients are modified in-place.
|
| 157 |
+
|
| 158 |
+
Args:
|
| 159 |
+
parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
|
| 160 |
+
single Tensor that will have gradients normalized
|
| 161 |
+
clip_value (float): maximum allowed value of the gradients.
|
| 162 |
+
The gradients are clipped in the range
|
| 163 |
+
:math:`\left[\text{-clip\_value}, \text{clip\_value}\right]`
|
| 164 |
+
foreach (bool): use the faster foreach-based implementation
|
| 165 |
+
If ``None``, use the foreach implementation for CUDA and CPU native tensors and
|
| 166 |
+
silently fall back to the slow implementation for other device types.
|
| 167 |
+
Default: ``None``
|
| 168 |
+
"""
|
| 169 |
+
if isinstance(parameters, torch.Tensor):
|
| 170 |
+
parameters = [parameters]
|
| 171 |
+
clip_value = float(clip_value)
|
| 172 |
+
|
| 173 |
+
grads = [p.grad for p in parameters if p.grad is not None]
|
| 174 |
+
grouped_grads = _group_tensors_by_device_and_dtype([grads])
|
| 175 |
+
|
| 176 |
+
for (device, _), ([grads], _) in grouped_grads.items(): # type: ignore[assignment]
|
| 177 |
+
if (
|
| 178 |
+
foreach is None
|
| 179 |
+
and _has_foreach_support(cast(List[Tensor], grads), device=device)
|
| 180 |
+
) or (foreach and _device_has_foreach_support(device)):
|
| 181 |
+
torch._foreach_clamp_min_(cast(List[Tensor], grads), -clip_value)
|
| 182 |
+
torch._foreach_clamp_max_(cast(List[Tensor], grads), clip_value)
|
| 183 |
+
elif foreach:
|
| 184 |
+
raise RuntimeError(
|
| 185 |
+
f"foreach=True was passed, but can't use the foreach API on {device.type} tensors"
|
| 186 |
+
)
|
| 187 |
+
else:
|
| 188 |
+
for grad in grads:
|
| 189 |
+
cast(Tensor, grad).clamp_(min=-clip_value, max=clip_value)
|
.venv/Lib/site-packages/torch/onnx/__init__.py
ADDED
|
@@ -0,0 +1,553 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
__all__ = [
|
| 6 |
+
# Modules
|
| 7 |
+
"symbolic_helper",
|
| 8 |
+
"utils",
|
| 9 |
+
"errors",
|
| 10 |
+
# All opsets
|
| 11 |
+
"symbolic_caffe2",
|
| 12 |
+
"symbolic_opset7",
|
| 13 |
+
"symbolic_opset8",
|
| 14 |
+
"symbolic_opset9",
|
| 15 |
+
"symbolic_opset10",
|
| 16 |
+
"symbolic_opset11",
|
| 17 |
+
"symbolic_opset12",
|
| 18 |
+
"symbolic_opset13",
|
| 19 |
+
"symbolic_opset14",
|
| 20 |
+
"symbolic_opset15",
|
| 21 |
+
"symbolic_opset16",
|
| 22 |
+
"symbolic_opset17",
|
| 23 |
+
"symbolic_opset18",
|
| 24 |
+
"symbolic_opset19",
|
| 25 |
+
"symbolic_opset20",
|
| 26 |
+
# Enums
|
| 27 |
+
"ExportTypes",
|
| 28 |
+
"OperatorExportTypes",
|
| 29 |
+
"TrainingMode",
|
| 30 |
+
"TensorProtoDataType",
|
| 31 |
+
"JitScalarType",
|
| 32 |
+
# Public functions
|
| 33 |
+
"export",
|
| 34 |
+
"export_to_pretty_string",
|
| 35 |
+
"is_in_onnx_export",
|
| 36 |
+
"select_model_mode_for_export",
|
| 37 |
+
"register_custom_op_symbolic",
|
| 38 |
+
"unregister_custom_op_symbolic",
|
| 39 |
+
"disable_log",
|
| 40 |
+
"enable_log",
|
| 41 |
+
# Base error
|
| 42 |
+
"OnnxExporterError",
|
| 43 |
+
# Dynamo Exporter
|
| 44 |
+
"DiagnosticOptions",
|
| 45 |
+
"ExportOptions",
|
| 46 |
+
"ONNXProgram",
|
| 47 |
+
"ONNXRuntimeOptions",
|
| 48 |
+
"OnnxRegistry",
|
| 49 |
+
"dynamo_export",
|
| 50 |
+
"enable_fake_mode",
|
| 51 |
+
# DORT / torch.compile
|
| 52 |
+
"is_onnxrt_backend_supported",
|
| 53 |
+
]
|
| 54 |
+
|
| 55 |
+
from typing import Any, Callable, Collection, Mapping, Sequence, TYPE_CHECKING
|
| 56 |
+
|
| 57 |
+
import torch
|
| 58 |
+
from torch import _C
|
| 59 |
+
from torch._C import _onnx as _C_onnx
|
| 60 |
+
from torch._C._onnx import OperatorExportTypes, TensorProtoDataType, TrainingMode
|
| 61 |
+
|
| 62 |
+
from ._exporter_states import ExportTypes
|
| 63 |
+
from ._internal.onnxruntime import (
|
| 64 |
+
is_onnxrt_backend_supported,
|
| 65 |
+
OrtBackend as _OrtBackend,
|
| 66 |
+
OrtBackendOptions as _OrtBackendOptions,
|
| 67 |
+
OrtExecutionProvider as _OrtExecutionProvider,
|
| 68 |
+
)
|
| 69 |
+
from ._type_utils import JitScalarType
|
| 70 |
+
from .errors import OnnxExporterError
|
| 71 |
+
from .utils import (
|
| 72 |
+
_optimize_graph,
|
| 73 |
+
_run_symbolic_function,
|
| 74 |
+
_run_symbolic_method,
|
| 75 |
+
export_to_pretty_string,
|
| 76 |
+
is_in_onnx_export,
|
| 77 |
+
register_custom_op_symbolic,
|
| 78 |
+
select_model_mode_for_export,
|
| 79 |
+
unregister_custom_op_symbolic,
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
from . import ( # usort: skip. Keep the order instead of sorting lexicographically
|
| 84 |
+
errors,
|
| 85 |
+
symbolic_caffe2,
|
| 86 |
+
symbolic_helper,
|
| 87 |
+
symbolic_opset7,
|
| 88 |
+
symbolic_opset8,
|
| 89 |
+
symbolic_opset9,
|
| 90 |
+
symbolic_opset10,
|
| 91 |
+
symbolic_opset11,
|
| 92 |
+
symbolic_opset12,
|
| 93 |
+
symbolic_opset13,
|
| 94 |
+
symbolic_opset14,
|
| 95 |
+
symbolic_opset15,
|
| 96 |
+
symbolic_opset16,
|
| 97 |
+
symbolic_opset17,
|
| 98 |
+
symbolic_opset18,
|
| 99 |
+
symbolic_opset19,
|
| 100 |
+
symbolic_opset20,
|
| 101 |
+
utils,
|
| 102 |
+
)
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
from ._internal._exporter_legacy import ( # usort: skip. needs to be last to avoid circular import
|
| 106 |
+
DiagnosticOptions,
|
| 107 |
+
ExportOptions,
|
| 108 |
+
ONNXProgram,
|
| 109 |
+
ONNXRuntimeOptions,
|
| 110 |
+
OnnxRegistry,
|
| 111 |
+
enable_fake_mode,
|
| 112 |
+
)
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
if TYPE_CHECKING:
|
| 116 |
+
import os
|
| 117 |
+
|
| 118 |
+
# Set namespace for exposed private names
|
| 119 |
+
DiagnosticOptions.__module__ = "torch.onnx"
|
| 120 |
+
ExportOptions.__module__ = "torch.onnx"
|
| 121 |
+
ExportTypes.__module__ = "torch.onnx"
|
| 122 |
+
JitScalarType.__module__ = "torch.onnx"
|
| 123 |
+
ONNXProgram.__module__ = "torch.onnx"
|
| 124 |
+
ONNXRuntimeOptions.__module__ = "torch.onnx"
|
| 125 |
+
OnnxExporterError.__module__ = "torch.onnx"
|
| 126 |
+
OnnxRegistry.__module__ = "torch.onnx"
|
| 127 |
+
_OrtBackend.__module__ = "torch.onnx"
|
| 128 |
+
_OrtBackendOptions.__module__ = "torch.onnx"
|
| 129 |
+
_OrtExecutionProvider.__module__ = "torch.onnx"
|
| 130 |
+
enable_fake_mode.__module__ = "torch.onnx"
|
| 131 |
+
is_onnxrt_backend_supported.__module__ = "torch.onnx"
|
| 132 |
+
|
| 133 |
+
producer_name = "pytorch"
|
| 134 |
+
producer_version = _C_onnx.PRODUCER_VERSION
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def export(
|
| 138 |
+
model: torch.nn.Module
|
| 139 |
+
| torch.export.ExportedProgram
|
| 140 |
+
| torch.jit.ScriptModule
|
| 141 |
+
| torch.jit.ScriptFunction,
|
| 142 |
+
args: tuple[Any, ...] = (),
|
| 143 |
+
f: str | os.PathLike | None = None,
|
| 144 |
+
*,
|
| 145 |
+
kwargs: dict[str, Any] | None = None,
|
| 146 |
+
export_params: bool = True,
|
| 147 |
+
verbose: bool | None = None,
|
| 148 |
+
input_names: Sequence[str] | None = None,
|
| 149 |
+
output_names: Sequence[str] | None = None,
|
| 150 |
+
opset_version: int | None = None,
|
| 151 |
+
dynamic_axes: Mapping[str, Mapping[int, str]]
|
| 152 |
+
| Mapping[str, Sequence[int]]
|
| 153 |
+
| None = None,
|
| 154 |
+
keep_initializers_as_inputs: bool = False,
|
| 155 |
+
dynamo: bool = False,
|
| 156 |
+
# Dynamo only options
|
| 157 |
+
external_data: bool = True,
|
| 158 |
+
dynamic_shapes: dict[str, Any] | tuple[Any, ...] | list[Any] | None = None,
|
| 159 |
+
report: bool = False,
|
| 160 |
+
verify: bool = False,
|
| 161 |
+
profile: bool = False,
|
| 162 |
+
dump_exported_program: bool = False,
|
| 163 |
+
artifacts_dir: str | os.PathLike = ".",
|
| 164 |
+
fallback: bool = False,
|
| 165 |
+
# Deprecated options
|
| 166 |
+
training: _C_onnx.TrainingMode = _C_onnx.TrainingMode.EVAL,
|
| 167 |
+
operator_export_type: _C_onnx.OperatorExportTypes = _C_onnx.OperatorExportTypes.ONNX,
|
| 168 |
+
do_constant_folding: bool = True,
|
| 169 |
+
custom_opsets: Mapping[str, int] | None = None,
|
| 170 |
+
export_modules_as_functions: bool | Collection[type[torch.nn.Module]] = False,
|
| 171 |
+
autograd_inlining: bool = True,
|
| 172 |
+
**_: Any, # ignored options
|
| 173 |
+
) -> Any | None:
|
| 174 |
+
r"""Exports a model into ONNX format.
|
| 175 |
+
|
| 176 |
+
Args:
|
| 177 |
+
model: The model to be exported.
|
| 178 |
+
args: Example positional inputs. Any non-Tensor arguments will be hard-coded into the
|
| 179 |
+
exported model; any Tensor arguments will become inputs of the exported model,
|
| 180 |
+
in the order they occur in the tuple.
|
| 181 |
+
f: Path to the output ONNX model file. E.g. "model.onnx".
|
| 182 |
+
kwargs: Optional example keyword inputs.
|
| 183 |
+
export_params: If false, parameters (weights) will not be exported.
|
| 184 |
+
verbose: Whether to enable verbose logging.
|
| 185 |
+
input_names: names to assign to the input nodes of the graph, in order.
|
| 186 |
+
output_names: names to assign to the output nodes of the graph, in order.
|
| 187 |
+
opset_version: The version of the
|
| 188 |
+
`default (ai.onnx) opset <https://github.com/onnx/onnx/blob/master/docs/Operators.md>`_
|
| 189 |
+
to target. Must be >= 7.
|
| 190 |
+
dynamic_axes:
|
| 191 |
+
|
| 192 |
+
By default the exported model will have the shapes of all input and output tensors
|
| 193 |
+
set to exactly match those given in ``args``. To specify axes of tensors as
|
| 194 |
+
dynamic (i.e. known only at run-time), set ``dynamic_axes`` to a dict with schema:
|
| 195 |
+
|
| 196 |
+
* KEY (str): an input or output name. Each name must also be provided in ``input_names`` or
|
| 197 |
+
``output_names``.
|
| 198 |
+
* VALUE (dict or list): If a dict, keys are axis indices and values are axis names. If a
|
| 199 |
+
list, each element is an axis index.
|
| 200 |
+
|
| 201 |
+
For example::
|
| 202 |
+
|
| 203 |
+
class SumModule(torch.nn.Module):
|
| 204 |
+
def forward(self, x):
|
| 205 |
+
return torch.sum(x, dim=1)
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
torch.onnx.export(
|
| 209 |
+
SumModule(),
|
| 210 |
+
(torch.ones(2, 2),),
|
| 211 |
+
"onnx.pb",
|
| 212 |
+
input_names=["x"],
|
| 213 |
+
output_names=["sum"],
|
| 214 |
+
)
|
| 215 |
+
|
| 216 |
+
Produces::
|
| 217 |
+
|
| 218 |
+
input {
|
| 219 |
+
name: "x"
|
| 220 |
+
...
|
| 221 |
+
shape {
|
| 222 |
+
dim {
|
| 223 |
+
dim_value: 2 # axis 0
|
| 224 |
+
}
|
| 225 |
+
dim {
|
| 226 |
+
dim_value: 2 # axis 1
|
| 227 |
+
...
|
| 228 |
+
output {
|
| 229 |
+
name: "sum"
|
| 230 |
+
...
|
| 231 |
+
shape {
|
| 232 |
+
dim {
|
| 233 |
+
dim_value: 2 # axis 0
|
| 234 |
+
...
|
| 235 |
+
|
| 236 |
+
While::
|
| 237 |
+
|
| 238 |
+
torch.onnx.export(
|
| 239 |
+
SumModule(),
|
| 240 |
+
(torch.ones(2, 2),),
|
| 241 |
+
"onnx.pb",
|
| 242 |
+
input_names=["x"],
|
| 243 |
+
output_names=["sum"],
|
| 244 |
+
dynamic_axes={
|
| 245 |
+
# dict value: manually named axes
|
| 246 |
+
"x": {0: "my_custom_axis_name"},
|
| 247 |
+
# list value: automatic names
|
| 248 |
+
"sum": [0],
|
| 249 |
+
},
|
| 250 |
+
)
|
| 251 |
+
|
| 252 |
+
Produces::
|
| 253 |
+
|
| 254 |
+
input {
|
| 255 |
+
name: "x"
|
| 256 |
+
...
|
| 257 |
+
shape {
|
| 258 |
+
dim {
|
| 259 |
+
dim_param: "my_custom_axis_name" # axis 0
|
| 260 |
+
}
|
| 261 |
+
dim {
|
| 262 |
+
dim_value: 2 # axis 1
|
| 263 |
+
...
|
| 264 |
+
output {
|
| 265 |
+
name: "sum"
|
| 266 |
+
...
|
| 267 |
+
shape {
|
| 268 |
+
dim {
|
| 269 |
+
dim_param: "sum_dynamic_axes_1" # axis 0
|
| 270 |
+
...
|
| 271 |
+
|
| 272 |
+
keep_initializers_as_inputs: If True, all the
|
| 273 |
+
initializers (typically corresponding to model weights) in the
|
| 274 |
+
exported graph will also be added as inputs to the graph. If False,
|
| 275 |
+
then initializers are not added as inputs to the graph, and only
|
| 276 |
+
the user inputs are added as inputs.
|
| 277 |
+
|
| 278 |
+
Set this to True if you intend to supply model weights at runtime.
|
| 279 |
+
Set it to False if the weights are static to allow for better optimizations
|
| 280 |
+
(e.g. constant folding) by backends/runtimes.
|
| 281 |
+
|
| 282 |
+
dynamo: Whether to export the model with ``torch.export`` ExportedProgram instead of TorchScript.
|
| 283 |
+
external_data: Whether to save the model weights as an external data file.
|
| 284 |
+
This is required for models with large weights that exceed the ONNX file size limit (2GB).
|
| 285 |
+
When False, the weights are saved in the ONNX file with the model architecture.
|
| 286 |
+
dynamic_shapes: A dictionary of dynamic shapes for the model inputs. Refer to
|
| 287 |
+
:func:`torch.export.export` for more details. This is only used (and preferred) when dynamo is True.
|
| 288 |
+
Only one parameter `dynamic_axes` or `dynamic_shapes` should be set
|
| 289 |
+
at the same time.
|
| 290 |
+
report: Whether to generate a markdown report for the export process.
|
| 291 |
+
verify: Whether to verify the exported model using ONNX Runtime.
|
| 292 |
+
profile: Whether to profile the export process.
|
| 293 |
+
dump_exported_program: Whether to dump the :class:`torch.export.ExportedProgram` to a file.
|
| 294 |
+
This is useful for debugging the exporter.
|
| 295 |
+
artifacts_dir: The directory to save the debugging artifacts like the report and the serialized
|
| 296 |
+
exported program.
|
| 297 |
+
fallback: Whether to fallback to the TorchScript exporter if the dynamo exporter fails.
|
| 298 |
+
|
| 299 |
+
training: Deprecated option. Instead, set the training mode of the model before exporting.
|
| 300 |
+
operator_export_type: Deprecated option. Only ONNX is supported.
|
| 301 |
+
do_constant_folding: Deprecated option. The exported graph is always optimized.
|
| 302 |
+
custom_opsets: Deprecated.
|
| 303 |
+
A dictionary:
|
| 304 |
+
|
| 305 |
+
* KEY (str): opset domain name
|
| 306 |
+
* VALUE (int): opset version
|
| 307 |
+
|
| 308 |
+
If a custom opset is referenced by ``model`` but not mentioned in this dictionary,
|
| 309 |
+
the opset version is set to 1. Only custom opset domain name and version should be
|
| 310 |
+
indicated through this argument.
|
| 311 |
+
export_modules_as_functions: Deprecated option.
|
| 312 |
+
|
| 313 |
+
Flag to enable
|
| 314 |
+
exporting all ``nn.Module`` forward calls as local functions in ONNX. Or a set to indicate the
|
| 315 |
+
particular types of modules to export as local functions in ONNX.
|
| 316 |
+
This feature requires ``opset_version`` >= 15, otherwise the export will fail. This is because
|
| 317 |
+
``opset_version`` < 15 implies IR version < 8, which means no local function support.
|
| 318 |
+
Module variables will be exported as function attributes. There are two categories of function
|
| 319 |
+
attributes.
|
| 320 |
+
|
| 321 |
+
1. Annotated attributes: class variables that have type annotations via
|
| 322 |
+
`PEP 526-style <https://www.python.org/dev/peps/pep-0526/#class-and-instance-variable-annotations>`_
|
| 323 |
+
will be exported as attributes.
|
| 324 |
+
Annotated attributes are not used inside the subgraph of ONNX local function because
|
| 325 |
+
they are not created by PyTorch JIT tracing, but they may be used by consumers
|
| 326 |
+
to determine whether or not to replace the function with a particular fused kernel.
|
| 327 |
+
|
| 328 |
+
2. Inferred attributes: variables that are used by operators inside the module. Attribute names
|
| 329 |
+
will have prefix "inferred::". This is to differentiate from predefined attributes retrieved from
|
| 330 |
+
python module annotations. Inferred attributes are used inside the subgraph of ONNX local function.
|
| 331 |
+
|
| 332 |
+
* ``False`` (default): export ``nn.Module`` forward calls as fine grained nodes.
|
| 333 |
+
* ``True``: export all ``nn.Module`` forward calls as local function nodes.
|
| 334 |
+
* Set of type of nn.Module: export ``nn.Module`` forward calls as local function nodes,
|
| 335 |
+
only if the type of the ``nn.Module`` is found in the set.
|
| 336 |
+
autograd_inlining: Deprecated.
|
| 337 |
+
Flag used to control whether to inline autograd functions.
|
| 338 |
+
Refer to https://github.com/pytorch/pytorch/pull/74765 for more details.
|
| 339 |
+
"""
|
| 340 |
+
if dynamo is True or isinstance(model, torch.export.ExportedProgram):
|
| 341 |
+
from torch.onnx._internal import exporter
|
| 342 |
+
|
| 343 |
+
if isinstance(args, torch.Tensor):
|
| 344 |
+
args = (args,)
|
| 345 |
+
return exporter.export_compat(
|
| 346 |
+
model,
|
| 347 |
+
args,
|
| 348 |
+
f,
|
| 349 |
+
kwargs=kwargs,
|
| 350 |
+
export_params=export_params,
|
| 351 |
+
verbose=verbose,
|
| 352 |
+
input_names=input_names,
|
| 353 |
+
output_names=output_names,
|
| 354 |
+
opset_version=opset_version,
|
| 355 |
+
dynamic_axes=dynamic_axes,
|
| 356 |
+
keep_initializers_as_inputs=keep_initializers_as_inputs,
|
| 357 |
+
external_data=external_data,
|
| 358 |
+
dynamic_shapes=dynamic_shapes,
|
| 359 |
+
report=report,
|
| 360 |
+
verify=verify,
|
| 361 |
+
profile=profile,
|
| 362 |
+
dump_exported_program=dump_exported_program,
|
| 363 |
+
artifacts_dir=artifacts_dir,
|
| 364 |
+
fallback=fallback,
|
| 365 |
+
)
|
| 366 |
+
else:
|
| 367 |
+
from torch.onnx.utils import export
|
| 368 |
+
|
| 369 |
+
if dynamic_shapes:
|
| 370 |
+
raise ValueError(
|
| 371 |
+
"The exporter only supports dynamic shapes "
|
| 372 |
+
"through parameter dynamic_axes when dynamo=False."
|
| 373 |
+
)
|
| 374 |
+
|
| 375 |
+
export(
|
| 376 |
+
model,
|
| 377 |
+
args,
|
| 378 |
+
f, # type: ignore[arg-type]
|
| 379 |
+
kwargs=kwargs,
|
| 380 |
+
export_params=export_params,
|
| 381 |
+
verbose=verbose is True,
|
| 382 |
+
input_names=input_names,
|
| 383 |
+
output_names=output_names,
|
| 384 |
+
opset_version=opset_version,
|
| 385 |
+
dynamic_axes=dynamic_axes,
|
| 386 |
+
keep_initializers_as_inputs=keep_initializers_as_inputs,
|
| 387 |
+
training=training,
|
| 388 |
+
operator_export_type=operator_export_type,
|
| 389 |
+
do_constant_folding=do_constant_folding,
|
| 390 |
+
custom_opsets=custom_opsets,
|
| 391 |
+
export_modules_as_functions=export_modules_as_functions,
|
| 392 |
+
autograd_inlining=autograd_inlining,
|
| 393 |
+
)
|
| 394 |
+
return None
|
| 395 |
+
|
| 396 |
+
|
| 397 |
+
def dynamo_export(
|
| 398 |
+
model: torch.nn.Module | Callable | torch.export.ExportedProgram, # type: ignore[name-defined]
|
| 399 |
+
/,
|
| 400 |
+
*model_args,
|
| 401 |
+
export_options: ExportOptions | None = None,
|
| 402 |
+
**model_kwargs,
|
| 403 |
+
) -> ONNXProgram | Any:
|
| 404 |
+
"""Export a torch.nn.Module to an ONNX graph.
|
| 405 |
+
|
| 406 |
+
Args:
|
| 407 |
+
model: The PyTorch model to be exported to ONNX.
|
| 408 |
+
model_args: Positional inputs to ``model``.
|
| 409 |
+
model_kwargs: Keyword inputs to ``model``.
|
| 410 |
+
export_options: Options to influence the export to ONNX.
|
| 411 |
+
|
| 412 |
+
Returns:
|
| 413 |
+
An in-memory representation of the exported ONNX model.
|
| 414 |
+
|
| 415 |
+
**Example 1 - Simplest export**
|
| 416 |
+
::
|
| 417 |
+
|
| 418 |
+
class MyModel(torch.nn.Module):
|
| 419 |
+
def __init__(self) -> None:
|
| 420 |
+
super().__init__()
|
| 421 |
+
self.linear = torch.nn.Linear(2, 2)
|
| 422 |
+
|
| 423 |
+
def forward(self, x, bias=None):
|
| 424 |
+
out = self.linear(x)
|
| 425 |
+
out = out + bias
|
| 426 |
+
return out
|
| 427 |
+
|
| 428 |
+
|
| 429 |
+
model = MyModel()
|
| 430 |
+
kwargs = {"bias": 3.0}
|
| 431 |
+
args = (torch.randn(2, 2, 2),)
|
| 432 |
+
onnx_program = torch.onnx.dynamo_export(model, *args, **kwargs).save(
|
| 433 |
+
"my_simple_model.onnx"
|
| 434 |
+
)
|
| 435 |
+
|
| 436 |
+
**Example 2 - Exporting with dynamic shapes**
|
| 437 |
+
::
|
| 438 |
+
|
| 439 |
+
# The previous model can be exported with dynamic shapes
|
| 440 |
+
export_options = torch.onnx.ExportOptions(dynamic_shapes=True)
|
| 441 |
+
onnx_program = torch.onnx.dynamo_export(
|
| 442 |
+
model, *args, **kwargs, export_options=export_options
|
| 443 |
+
)
|
| 444 |
+
onnx_program.save("my_dynamic_model.onnx")
|
| 445 |
+
"""
|
| 446 |
+
|
| 447 |
+
# NOTE: The new exporter is experimental and is not enabled by default.
|
| 448 |
+
import warnings
|
| 449 |
+
|
| 450 |
+
from torch.onnx import _flags
|
| 451 |
+
from torch.onnx._internal import exporter
|
| 452 |
+
from torch.utils import _pytree
|
| 453 |
+
|
| 454 |
+
if isinstance(model, torch.export.ExportedProgram):
|
| 455 |
+
return exporter.export_compat(
|
| 456 |
+
model, # type: ignore[arg-type]
|
| 457 |
+
model_args,
|
| 458 |
+
f=None,
|
| 459 |
+
kwargs=model_kwargs,
|
| 460 |
+
opset_version=18,
|
| 461 |
+
external_data=True,
|
| 462 |
+
export_params=True,
|
| 463 |
+
fallback=True,
|
| 464 |
+
)
|
| 465 |
+
elif _flags.USE_EXPERIMENTAL_LOGIC:
|
| 466 |
+
if export_options is not None:
|
| 467 |
+
warnings.warn(
|
| 468 |
+
"You are using an experimental ONNX export logic, which currently only supports dynamic shapes. "
|
| 469 |
+
"For a more comprehensive set of export options, including advanced features, please consider using "
|
| 470 |
+
"`torch.onnx.export(..., dynamo=True)`. ",
|
| 471 |
+
category=FutureWarning,
|
| 472 |
+
)
|
| 473 |
+
|
| 474 |
+
if export_options is not None and export_options.dynamic_shapes:
|
| 475 |
+
# Make all shapes dynamic
|
| 476 |
+
def _to_dynamic_shapes_mapper():
|
| 477 |
+
arg_order = 0
|
| 478 |
+
|
| 479 |
+
def _to_dynamic_shape(x):
|
| 480 |
+
nonlocal arg_order
|
| 481 |
+
if isinstance(x, torch.Tensor):
|
| 482 |
+
rank = len(x.shape)
|
| 483 |
+
dynamic_shape = {}
|
| 484 |
+
for i in range(rank):
|
| 485 |
+
dynamic_shape[i] = torch.export.Dim(
|
| 486 |
+
f"arg_{arg_order}_dim_{i}"
|
| 487 |
+
)
|
| 488 |
+
arg_order += 1
|
| 489 |
+
return dynamic_shape
|
| 490 |
+
else:
|
| 491 |
+
return None
|
| 492 |
+
|
| 493 |
+
return _to_dynamic_shape
|
| 494 |
+
|
| 495 |
+
# model_args could be nested
|
| 496 |
+
dynamic_shapes = _pytree.tree_map(
|
| 497 |
+
_to_dynamic_shapes_mapper(),
|
| 498 |
+
model_args,
|
| 499 |
+
)
|
| 500 |
+
else:
|
| 501 |
+
dynamic_shapes = None
|
| 502 |
+
|
| 503 |
+
return exporter.export_compat(
|
| 504 |
+
model, # type: ignore[arg-type]
|
| 505 |
+
model_args,
|
| 506 |
+
f=None,
|
| 507 |
+
kwargs=model_kwargs,
|
| 508 |
+
dynamic_shapes=dynamic_shapes,
|
| 509 |
+
opset_version=18,
|
| 510 |
+
external_data=True,
|
| 511 |
+
export_params=True,
|
| 512 |
+
fallback=True,
|
| 513 |
+
)
|
| 514 |
+
else:
|
| 515 |
+
from torch.onnx._internal._exporter_legacy import dynamo_export
|
| 516 |
+
|
| 517 |
+
return dynamo_export(
|
| 518 |
+
model, *model_args, export_options=export_options, **model_kwargs
|
| 519 |
+
)
|
| 520 |
+
|
| 521 |
+
|
| 522 |
+
# TODO(justinchuby): Deprecate these logging functions in favor of the new diagnostic module.
|
| 523 |
+
|
| 524 |
+
# Returns True iff ONNX logging is turned on.
|
| 525 |
+
is_onnx_log_enabled = _C._jit_is_onnx_log_enabled
|
| 526 |
+
|
| 527 |
+
|
| 528 |
+
def enable_log() -> None:
|
| 529 |
+
r"""Enables ONNX logging."""
|
| 530 |
+
_C._jit_set_onnx_log_enabled(True)
|
| 531 |
+
|
| 532 |
+
|
| 533 |
+
def disable_log() -> None:
|
| 534 |
+
r"""Disables ONNX logging."""
|
| 535 |
+
_C._jit_set_onnx_log_enabled(False)
|
| 536 |
+
|
| 537 |
+
|
| 538 |
+
"""Sets output stream for ONNX logging.
|
| 539 |
+
|
| 540 |
+
Args:
|
| 541 |
+
stream_name (str, default "stdout"): Only 'stdout' and 'stderr' are supported
|
| 542 |
+
as ``stream_name``.
|
| 543 |
+
"""
|
| 544 |
+
set_log_stream = _C._jit_set_onnx_log_output_stream
|
| 545 |
+
|
| 546 |
+
|
| 547 |
+
"""A simple logging facility for ONNX exporter.
|
| 548 |
+
|
| 549 |
+
Args:
|
| 550 |
+
args: Arguments are converted to string, concatenated together with a newline
|
| 551 |
+
character appended to the end, and flushed to output stream.
|
| 552 |
+
"""
|
| 553 |
+
log = _C._jit_onnx_log
|
.venv/Lib/site-packages/torch/onnx/_constants.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Constant values used in ONNX."""
|
| 2 |
+
|
| 3 |
+
ONNX_ARCHIVE_MODEL_PROTO_NAME = "__MODEL_PROTO"
|
| 4 |
+
|
| 5 |
+
ONNX_BASE_OPSET = 9
|
| 6 |
+
ONNX_MIN_OPSET = 7
|
| 7 |
+
ONNX_MAX_OPSET = 20
|
| 8 |
+
ONNX_TORCHSCRIPT_EXPORTER_MAX_OPSET = 20
|
| 9 |
+
# ONNX_DEFAULT_OPSET generated by tools/onnx/update_default_opset_version.py
|
| 10 |
+
ONNX_DEFAULT_OPSET = 17
|
| 11 |
+
ONNX_CONSTANT_FOLDING_MIN_OPSET = 9
|
| 12 |
+
|
| 13 |
+
PYTORCH_GITHUB_ISSUES_URL = "https://github.com/pytorch/pytorch/issues"
|
| 14 |
+
|
| 15 |
+
INT64_MAX = 9223372036854775807
|
| 16 |
+
INT32_MAX = 2147483647
|
| 17 |
+
INT16_MAX = 32767
|
| 18 |
+
INT8_MAX = 127
|
| 19 |
+
UINT8_MAX = 255
|
| 20 |
+
|
| 21 |
+
INT64_MIN = -9223372036854775808
|
| 22 |
+
INT32_MIN = -2147483648
|
| 23 |
+
INT16_MIN = -32768
|
| 24 |
+
INT8_MIN = -128
|
| 25 |
+
UINT8_MIN = 0
|
.venv/Lib/site-packages/torch/onnx/_deprecation.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Utility for deprecating functions."""
|
| 2 |
+
|
| 3 |
+
import functools
|
| 4 |
+
import textwrap
|
| 5 |
+
import warnings
|
| 6 |
+
from typing import Callable, TypeVar
|
| 7 |
+
from typing_extensions import ParamSpec
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
_T = TypeVar("_T")
|
| 11 |
+
_P = ParamSpec("_P")
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def deprecated(
|
| 15 |
+
since: str, removed_in: str, instructions: str
|
| 16 |
+
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
|
| 17 |
+
"""Marks functions as deprecated.
|
| 18 |
+
|
| 19 |
+
It will result in a warning when the function is called and a note in the
|
| 20 |
+
docstring.
|
| 21 |
+
|
| 22 |
+
Args:
|
| 23 |
+
since: The version when the function was first deprecated.
|
| 24 |
+
removed_in: The version when the function will be removed.
|
| 25 |
+
instructions: The action users should take.
|
| 26 |
+
"""
|
| 27 |
+
|
| 28 |
+
def decorator(function: Callable[_P, _T]) -> Callable[_P, _T]:
|
| 29 |
+
@functools.wraps(function)
|
| 30 |
+
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T:
|
| 31 |
+
warnings.warn(
|
| 32 |
+
f"'{function.__module__}.{function.__name__}' "
|
| 33 |
+
f"is deprecated in version {since} and will be "
|
| 34 |
+
f"removed in {removed_in}. Please {instructions}.",
|
| 35 |
+
category=FutureWarning,
|
| 36 |
+
stacklevel=2,
|
| 37 |
+
)
|
| 38 |
+
return function(*args, **kwargs)
|
| 39 |
+
|
| 40 |
+
# Add a deprecation note to the docstring.
|
| 41 |
+
docstring = function.__doc__ or ""
|
| 42 |
+
|
| 43 |
+
# Add a note to the docstring.
|
| 44 |
+
deprecation_note = textwrap.dedent(
|
| 45 |
+
f"""\
|
| 46 |
+
.. deprecated:: {since}
|
| 47 |
+
Deprecated and will be removed in version {removed_in}.
|
| 48 |
+
Please {instructions}.
|
| 49 |
+
"""
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
+
# Split docstring at first occurrence of newline
|
| 53 |
+
summary_and_body = docstring.split("\n\n", 1)
|
| 54 |
+
|
| 55 |
+
if len(summary_and_body) > 1:
|
| 56 |
+
summary, body = summary_and_body
|
| 57 |
+
|
| 58 |
+
# Dedent the body. We cannot do this with the presence of the summary because
|
| 59 |
+
# the body contains leading whitespaces when the summary does not.
|
| 60 |
+
body = textwrap.dedent(body)
|
| 61 |
+
|
| 62 |
+
new_docstring_parts = [deprecation_note, "\n\n", summary, body]
|
| 63 |
+
else:
|
| 64 |
+
summary = summary_and_body[0]
|
| 65 |
+
|
| 66 |
+
new_docstring_parts = [deprecation_note, "\n\n", summary]
|
| 67 |
+
|
| 68 |
+
wrapper.__doc__ = "".join(new_docstring_parts)
|
| 69 |
+
|
| 70 |
+
return wrapper
|
| 71 |
+
|
| 72 |
+
return decorator
|
.venv/Lib/site-packages/torch/onnx/_experimental.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Experimental classes and functions used by ONNX export."""
|
| 2 |
+
|
| 3 |
+
import dataclasses
|
| 4 |
+
from typing import Mapping, Optional, Sequence, Set, Type, Union
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch._C._onnx as _C_onnx
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@dataclasses.dataclass
|
| 11 |
+
class ExportOptions:
|
| 12 |
+
"""Arguments used by :func:`torch.onnx.export`."""
|
| 13 |
+
|
| 14 |
+
# TODO(justinchuby): Deprecate and remove this class.
|
| 15 |
+
|
| 16 |
+
export_params: bool = True
|
| 17 |
+
verbose: bool = False
|
| 18 |
+
training: _C_onnx.TrainingMode = _C_onnx.TrainingMode.EVAL
|
| 19 |
+
input_names: Optional[Sequence[str]] = None
|
| 20 |
+
output_names: Optional[Sequence[str]] = None
|
| 21 |
+
operator_export_type: _C_onnx.OperatorExportTypes = _C_onnx.OperatorExportTypes.ONNX
|
| 22 |
+
opset_version: Optional[int] = None
|
| 23 |
+
do_constant_folding: bool = True
|
| 24 |
+
dynamic_axes: Optional[Mapping[str, Union[Mapping[int, str], Sequence[int]]]] = None
|
| 25 |
+
keep_initializers_as_inputs: Optional[bool] = None
|
| 26 |
+
custom_opsets: Optional[Mapping[str, int]] = None
|
| 27 |
+
export_modules_as_functions: Union[bool, Set[Type[torch.nn.Module]]] = False
|
.venv/Lib/site-packages/torch/onnx/_exporter_states.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class ExportTypes:
|
| 5 |
+
"""Specifies how the ONNX model is stored."""
|
| 6 |
+
|
| 7 |
+
# TODO(justinchuby): Deprecate and remove this class.
|
| 8 |
+
|
| 9 |
+
PROTOBUF_FILE = "Saves model in the specified protobuf file."
|
| 10 |
+
ZIP_ARCHIVE = "Saves model in the specified ZIP file (uncompressed)."
|
| 11 |
+
COMPRESSED_ZIP_ARCHIVE = "Saves model in the specified ZIP file (compressed)."
|
| 12 |
+
DIRECTORY = "Saves model in the specified folder."
|
.venv/Lib/site-packages/torch/onnx/_flags.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Internal feature flags for torch.onnx.
|
| 2 |
+
|
| 3 |
+
NOTE: These flags are experimental only. Any flag here can be removed at any
|
| 4 |
+
time without notice.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import logging
|
| 8 |
+
import os
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
logger = logging.getLogger(__name__)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def _load_boolean_flag(
|
| 15 |
+
name: str,
|
| 16 |
+
*,
|
| 17 |
+
this_will: str,
|
| 18 |
+
deprecated: bool = False,
|
| 19 |
+
default: bool = False,
|
| 20 |
+
) -> bool:
|
| 21 |
+
"""Load a boolean flag from environment variable.
|
| 22 |
+
|
| 23 |
+
Args:
|
| 24 |
+
name: The name of the environment variable.
|
| 25 |
+
this_will: A string that describes what this flag will do.
|
| 26 |
+
deprecated: Whether this flag is deprecated.
|
| 27 |
+
default: The default value if envvar not defined.
|
| 28 |
+
"""
|
| 29 |
+
undefined = os.getenv(name) is None
|
| 30 |
+
state = os.getenv(name) == "1"
|
| 31 |
+
if state:
|
| 32 |
+
if deprecated:
|
| 33 |
+
logger.error(
|
| 34 |
+
"Experimental flag %s is deprecated. Please remove it from your environment.",
|
| 35 |
+
name,
|
| 36 |
+
)
|
| 37 |
+
else:
|
| 38 |
+
logger.warning(
|
| 39 |
+
"Experimental flag %s is enabled. This will %s.", name, this_will
|
| 40 |
+
)
|
| 41 |
+
if undefined:
|
| 42 |
+
state = default
|
| 43 |
+
return state
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
USE_EXPERIMENTAL_LOGIC: bool = _load_boolean_flag(
|
| 47 |
+
"TORCH_ONNX_USE_EXPERIMENTAL_LOGIC",
|
| 48 |
+
this_will="use ExportedProgram and the new torch.onnx export logic",
|
| 49 |
+
)
|
.venv/Lib/site-packages/torch/onnx/_globals.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
"""Globals used internally by the ONNX exporter.
|
| 3 |
+
|
| 4 |
+
Do not use this module outside of `torch.onnx` and its tests.
|
| 5 |
+
|
| 6 |
+
Be very judicious when adding any new global variables. Do not create new global
|
| 7 |
+
variables unless they are absolutely necessary.
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import torch._C._onnx as _C_onnx
|
| 11 |
+
|
| 12 |
+
# This module should only depend on _constants and nothing else in torch.onnx to keep
|
| 13 |
+
# dependency direction clean.
|
| 14 |
+
from torch.onnx import _constants
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class _InternalGlobals:
|
| 18 |
+
"""Globals used internally by ONNX exporter.
|
| 19 |
+
|
| 20 |
+
NOTE: Be very judicious when adding any new variables. Do not create new
|
| 21 |
+
global variables unless they are absolutely necessary.
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
def __init__(self) -> None:
|
| 25 |
+
self._export_onnx_opset_version = _constants.ONNX_DEFAULT_OPSET
|
| 26 |
+
self._training_mode: _C_onnx.TrainingMode = _C_onnx.TrainingMode.EVAL
|
| 27 |
+
self._in_onnx_export: bool = False
|
| 28 |
+
# Whether the user's model is training during export
|
| 29 |
+
self.export_training: bool = False
|
| 30 |
+
self.operator_export_type: _C_onnx.OperatorExportTypes = (
|
| 31 |
+
_C_onnx.OperatorExportTypes.ONNX
|
| 32 |
+
)
|
| 33 |
+
self.onnx_shape_inference: bool = True
|
| 34 |
+
self._autograd_inlining: bool = True
|
| 35 |
+
|
| 36 |
+
@property
|
| 37 |
+
def training_mode(self):
|
| 38 |
+
"""The training mode for the exporter."""
|
| 39 |
+
return self._training_mode
|
| 40 |
+
|
| 41 |
+
@training_mode.setter
|
| 42 |
+
def training_mode(self, training_mode: _C_onnx.TrainingMode):
|
| 43 |
+
if not isinstance(training_mode, _C_onnx.TrainingMode):
|
| 44 |
+
raise TypeError(
|
| 45 |
+
"training_mode must be of type 'torch.onnx.TrainingMode'. This is "
|
| 46 |
+
"likely a bug in torch.onnx."
|
| 47 |
+
)
|
| 48 |
+
self._training_mode = training_mode
|
| 49 |
+
|
| 50 |
+
@property
|
| 51 |
+
def export_onnx_opset_version(self) -> int:
|
| 52 |
+
"""Opset version used during export."""
|
| 53 |
+
return self._export_onnx_opset_version
|
| 54 |
+
|
| 55 |
+
@export_onnx_opset_version.setter
|
| 56 |
+
def export_onnx_opset_version(self, value: int):
|
| 57 |
+
supported_versions = range(
|
| 58 |
+
_constants.ONNX_MIN_OPSET, _constants.ONNX_MAX_OPSET + 1
|
| 59 |
+
)
|
| 60 |
+
if value not in supported_versions:
|
| 61 |
+
raise ValueError(f"Unsupported ONNX opset version: {value}")
|
| 62 |
+
self._export_onnx_opset_version = value
|
| 63 |
+
|
| 64 |
+
@property
|
| 65 |
+
def in_onnx_export(self) -> bool:
|
| 66 |
+
"""Whether it is in the middle of ONNX export."""
|
| 67 |
+
return self._in_onnx_export
|
| 68 |
+
|
| 69 |
+
@in_onnx_export.setter
|
| 70 |
+
def in_onnx_export(self, value: bool):
|
| 71 |
+
if type(value) is not bool:
|
| 72 |
+
raise TypeError("in_onnx_export must be a boolean")
|
| 73 |
+
self._in_onnx_export = value
|
| 74 |
+
|
| 75 |
+
@property
|
| 76 |
+
def autograd_inlining(self) -> bool:
|
| 77 |
+
"""Whether Autograd must be inlined."""
|
| 78 |
+
return self._autograd_inlining
|
| 79 |
+
|
| 80 |
+
@autograd_inlining.setter
|
| 81 |
+
def autograd_inlining(self, value: bool):
|
| 82 |
+
if type(value) is not bool:
|
| 83 |
+
raise TypeError("autograd_inlining must be a boolean")
|
| 84 |
+
self._autograd_inlining = value
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
GLOBALS = _InternalGlobals()
|
.venv/Lib/site-packages/torch/onnx/_internal/__init__.py
ADDED
|
File without changes
|
.venv/Lib/site-packages/torch/onnx/_internal/_lazy_import.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Utility to lazily import modules."""
|
| 2 |
+
|
| 3 |
+
# mypy: allow-untyped-defs
|
| 4 |
+
from __future__ import annotations
|
| 5 |
+
|
| 6 |
+
import importlib
|
| 7 |
+
from typing import Any, TYPE_CHECKING
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class _LazyModule:
|
| 11 |
+
"""Lazily import a module."""
|
| 12 |
+
|
| 13 |
+
def __init__(self, module_name: str) -> None:
|
| 14 |
+
self._name = module_name
|
| 15 |
+
self._module: Any = None
|
| 16 |
+
|
| 17 |
+
def __repr__(self) -> str:
|
| 18 |
+
return f"<lazy module '{self._name}'>"
|
| 19 |
+
|
| 20 |
+
def __getattr__(self, attr):
|
| 21 |
+
if self._module is None:
|
| 22 |
+
self._module = importlib.import_module(".", self._name)
|
| 23 |
+
return getattr(self._module, attr)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
# Import the following modules during type checking to enable code intelligence features,
|
| 27 |
+
# such as auto-completion in tools like pylance, even when these modules are not explicitly
|
| 28 |
+
# imported in user code.
|
| 29 |
+
# NOTE: Add additional used imports here.
|
| 30 |
+
if TYPE_CHECKING:
|
| 31 |
+
import onnx
|
| 32 |
+
import onnxscript
|
| 33 |
+
import onnxscript._framework_apis.torch_2_5 as onnxscript_apis
|
| 34 |
+
|
| 35 |
+
onnxscript_ir = onnxscript.ir
|
| 36 |
+
|
| 37 |
+
else:
|
| 38 |
+
onnx = _LazyModule("onnx")
|
| 39 |
+
onnxscript = _LazyModule("onnxscript")
|
| 40 |
+
onnxscript_ir = _LazyModule("onnxscript.ir")
|
| 41 |
+
onnxscript_apis = _LazyModule("onnxscript._framework_apis.torch_2_5")
|
.venv/Lib/site-packages/torch/onnx/_internal/diagnostics/__init__.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ._diagnostic import (
|
| 2 |
+
create_export_diagnostic_context,
|
| 3 |
+
diagnose,
|
| 4 |
+
engine,
|
| 5 |
+
export_context,
|
| 6 |
+
ExportDiagnosticEngine,
|
| 7 |
+
TorchScriptOnnxExportDiagnostic,
|
| 8 |
+
)
|
| 9 |
+
from ._rules import rules
|
| 10 |
+
from .infra import levels
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
__all__ = [
|
| 14 |
+
"TorchScriptOnnxExportDiagnostic",
|
| 15 |
+
"ExportDiagnosticEngine",
|
| 16 |
+
"rules",
|
| 17 |
+
"levels",
|
| 18 |
+
"engine",
|
| 19 |
+
"export_context",
|
| 20 |
+
"create_export_diagnostic_context",
|
| 21 |
+
"diagnose",
|
| 22 |
+
]
|
.venv/Lib/site-packages/torch/onnx/_internal/diagnostics/_diagnostic.py
ADDED
|
@@ -0,0 +1,211 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
"""Diagnostic components for TorchScript based ONNX export, i.e. `torch.onnx.export`."""
|
| 3 |
+
|
| 4 |
+
from __future__ import annotations
|
| 5 |
+
|
| 6 |
+
import contextlib
|
| 7 |
+
import gzip
|
| 8 |
+
from typing import TYPE_CHECKING
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
from torch.onnx._internal.diagnostics import infra
|
| 12 |
+
from torch.onnx._internal.diagnostics.infra import formatter, sarif
|
| 13 |
+
from torch.onnx._internal.diagnostics.infra.sarif import version as sarif_version
|
| 14 |
+
from torch.utils import cpp_backtrace
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
if TYPE_CHECKING:
|
| 18 |
+
from collections.abc import Generator
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def _cpp_call_stack(frames_to_skip: int = 0, frames_to_log: int = 32) -> infra.Stack:
|
| 22 |
+
"""Returns the current C++ call stack.
|
| 23 |
+
|
| 24 |
+
This function utilizes `torch.utils.cpp_backtrace` to get the current C++ call stack.
|
| 25 |
+
The returned C++ call stack is a concatenated string of the C++ call stack frames.
|
| 26 |
+
Each frame is separated by a newline character, in the same format of
|
| 27 |
+
r"frame #[0-9]+: (?P<frame_info>.*)". More info at `c10/util/Backtrace.cpp`.
|
| 28 |
+
|
| 29 |
+
"""
|
| 30 |
+
frames = cpp_backtrace.get_cpp_backtrace(frames_to_skip, frames_to_log).split("\n")
|
| 31 |
+
frame_messages = []
|
| 32 |
+
for frame in frames:
|
| 33 |
+
segments = frame.split(":", 1)
|
| 34 |
+
if len(segments) == 2:
|
| 35 |
+
frame_messages.append(segments[1].strip())
|
| 36 |
+
else:
|
| 37 |
+
frame_messages.append("<unknown frame>")
|
| 38 |
+
return infra.Stack(
|
| 39 |
+
frames=[
|
| 40 |
+
infra.StackFrame(location=infra.Location(message=message))
|
| 41 |
+
for message in frame_messages
|
| 42 |
+
]
|
| 43 |
+
)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class TorchScriptOnnxExportDiagnostic(infra.Diagnostic):
|
| 47 |
+
"""Base class for all export diagnostics.
|
| 48 |
+
|
| 49 |
+
This class is used to represent all export diagnostics. It is a subclass of
|
| 50 |
+
infra.Diagnostic, and adds additional methods to add more information to the
|
| 51 |
+
diagnostic.
|
| 52 |
+
"""
|
| 53 |
+
|
| 54 |
+
python_call_stack: infra.Stack | None = None
|
| 55 |
+
cpp_call_stack: infra.Stack | None = None
|
| 56 |
+
|
| 57 |
+
def __init__(
|
| 58 |
+
self,
|
| 59 |
+
*args,
|
| 60 |
+
frames_to_skip: int = 1,
|
| 61 |
+
cpp_stack: bool = False,
|
| 62 |
+
**kwargs,
|
| 63 |
+
) -> None:
|
| 64 |
+
super().__init__(*args, **kwargs)
|
| 65 |
+
self.python_call_stack = self.record_python_call_stack(
|
| 66 |
+
frames_to_skip=frames_to_skip
|
| 67 |
+
)
|
| 68 |
+
if cpp_stack:
|
| 69 |
+
self.cpp_call_stack = self.record_cpp_call_stack(
|
| 70 |
+
frames_to_skip=frames_to_skip
|
| 71 |
+
)
|
| 72 |
+
|
| 73 |
+
def record_cpp_call_stack(self, frames_to_skip: int) -> infra.Stack:
|
| 74 |
+
"""Records the current C++ call stack in the diagnostic."""
|
| 75 |
+
stack = _cpp_call_stack(frames_to_skip=frames_to_skip)
|
| 76 |
+
stack.message = "C++ call stack"
|
| 77 |
+
self.with_stack(stack)
|
| 78 |
+
return stack
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
class ExportDiagnosticEngine:
|
| 82 |
+
"""PyTorch ONNX Export diagnostic engine.
|
| 83 |
+
|
| 84 |
+
The only purpose of creating this class instead of using `DiagnosticContext` directly
|
| 85 |
+
is to provide a background context for `diagnose` calls inside exporter.
|
| 86 |
+
|
| 87 |
+
By design, one `torch.onnx.export` call should initialize one diagnostic context.
|
| 88 |
+
All `diagnose` calls inside exporter should be made in the context of that export.
|
| 89 |
+
However, since diagnostic context is currently being accessed via a global variable,
|
| 90 |
+
there is no guarantee that the context is properly initialized. Therefore, we need
|
| 91 |
+
to provide a default background context to fallback to, otherwise any invocation of
|
| 92 |
+
exporter internals, e.g. unit tests, will fail due to missing diagnostic context.
|
| 93 |
+
This can be removed once the pipeline for context to flow through the exporter is
|
| 94 |
+
established.
|
| 95 |
+
"""
|
| 96 |
+
|
| 97 |
+
contexts: list[infra.DiagnosticContext]
|
| 98 |
+
_background_context: infra.DiagnosticContext
|
| 99 |
+
|
| 100 |
+
def __init__(self) -> None:
|
| 101 |
+
self.contexts = []
|
| 102 |
+
self._background_context = infra.DiagnosticContext(
|
| 103 |
+
name="torch.onnx",
|
| 104 |
+
version=torch.__version__,
|
| 105 |
+
)
|
| 106 |
+
|
| 107 |
+
@property
|
| 108 |
+
def background_context(self) -> infra.DiagnosticContext:
|
| 109 |
+
return self._background_context
|
| 110 |
+
|
| 111 |
+
def create_diagnostic_context(
|
| 112 |
+
self,
|
| 113 |
+
name: str,
|
| 114 |
+
version: str,
|
| 115 |
+
options: infra.DiagnosticOptions | None = None,
|
| 116 |
+
) -> infra.DiagnosticContext:
|
| 117 |
+
"""Creates a new diagnostic context.
|
| 118 |
+
|
| 119 |
+
Args:
|
| 120 |
+
name: The subject name for the diagnostic context.
|
| 121 |
+
version: The subject version for the diagnostic context.
|
| 122 |
+
options: The options for the diagnostic context.
|
| 123 |
+
|
| 124 |
+
Returns:
|
| 125 |
+
A new diagnostic context.
|
| 126 |
+
"""
|
| 127 |
+
if options is None:
|
| 128 |
+
options = infra.DiagnosticOptions()
|
| 129 |
+
context: infra.DiagnosticContext[infra.Diagnostic] = infra.DiagnosticContext(
|
| 130 |
+
name, version, options
|
| 131 |
+
)
|
| 132 |
+
self.contexts.append(context)
|
| 133 |
+
return context
|
| 134 |
+
|
| 135 |
+
def clear(self):
|
| 136 |
+
"""Clears all diagnostic contexts."""
|
| 137 |
+
self.contexts.clear()
|
| 138 |
+
self._background_context.diagnostics.clear()
|
| 139 |
+
|
| 140 |
+
def to_json(self) -> str:
|
| 141 |
+
return formatter.sarif_to_json(self.sarif_log())
|
| 142 |
+
|
| 143 |
+
def dump(self, file_path: str, compress: bool = False) -> None:
|
| 144 |
+
"""Dumps the SARIF log to a file."""
|
| 145 |
+
if compress:
|
| 146 |
+
with gzip.open(file_path, "wt") as f:
|
| 147 |
+
f.write(self.to_json())
|
| 148 |
+
else:
|
| 149 |
+
with open(file_path, "w") as f:
|
| 150 |
+
f.write(self.to_json())
|
| 151 |
+
|
| 152 |
+
def sarif_log(self):
|
| 153 |
+
log = sarif.SarifLog(
|
| 154 |
+
version=sarif_version.SARIF_VERSION,
|
| 155 |
+
schema_uri=sarif_version.SARIF_SCHEMA_LINK,
|
| 156 |
+
runs=[context.sarif() for context in self.contexts],
|
| 157 |
+
)
|
| 158 |
+
|
| 159 |
+
log.runs.append(self._background_context.sarif())
|
| 160 |
+
return log
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
engine = ExportDiagnosticEngine()
|
| 164 |
+
_context = engine.background_context
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
@contextlib.contextmanager
|
| 168 |
+
def create_export_diagnostic_context() -> (
|
| 169 |
+
Generator[infra.DiagnosticContext, None, None]
|
| 170 |
+
):
|
| 171 |
+
"""Create a diagnostic context for export.
|
| 172 |
+
|
| 173 |
+
This is a workaround for code robustness since diagnostic context is accessed by
|
| 174 |
+
export internals via global variable. See `ExportDiagnosticEngine` for more details.
|
| 175 |
+
"""
|
| 176 |
+
global _context
|
| 177 |
+
assert (
|
| 178 |
+
_context == engine.background_context
|
| 179 |
+
), "Export context is already set. Nested export is not supported."
|
| 180 |
+
_context = engine.create_diagnostic_context(
|
| 181 |
+
"torch.onnx.export",
|
| 182 |
+
torch.__version__,
|
| 183 |
+
)
|
| 184 |
+
try:
|
| 185 |
+
yield _context
|
| 186 |
+
finally:
|
| 187 |
+
_context = engine.background_context
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def diagnose(
|
| 191 |
+
rule: infra.Rule,
|
| 192 |
+
level: infra.Level,
|
| 193 |
+
message: str | None = None,
|
| 194 |
+
frames_to_skip: int = 2,
|
| 195 |
+
**kwargs,
|
| 196 |
+
) -> TorchScriptOnnxExportDiagnostic:
|
| 197 |
+
"""Creates a diagnostic and record it in the global diagnostic context.
|
| 198 |
+
|
| 199 |
+
This is a wrapper around `context.log` that uses the global diagnostic
|
| 200 |
+
context.
|
| 201 |
+
"""
|
| 202 |
+
diagnostic = TorchScriptOnnxExportDiagnostic(
|
| 203 |
+
rule, level, message, frames_to_skip=frames_to_skip, **kwargs
|
| 204 |
+
)
|
| 205 |
+
export_context().log(diagnostic)
|
| 206 |
+
return diagnostic
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
def export_context() -> infra.DiagnosticContext:
|
| 210 |
+
global _context
|
| 211 |
+
return _context
|
.venv/Lib/site-packages/torch/onnx/_internal/diagnostics/_rules.py
ADDED
|
@@ -0,0 +1,636 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
"""
|
| 3 |
+
GENERATED CODE - DO NOT EDIT DIRECTLY
|
| 4 |
+
This file is generated by gen_diagnostics.py.
|
| 5 |
+
See tools/onnx/gen_diagnostics.py for more information.
|
| 6 |
+
|
| 7 |
+
Diagnostic rules for PyTorch ONNX export.
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import dataclasses
|
| 11 |
+
from typing import Tuple
|
| 12 |
+
|
| 13 |
+
# flake8: noqa
|
| 14 |
+
from torch.onnx._internal.diagnostics import infra
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
"""
|
| 18 |
+
GENERATED CODE - DO NOT EDIT DIRECTLY
|
| 19 |
+
The purpose of generating a class for each rule is to override the `format_message`
|
| 20 |
+
method to provide more details in the signature about the format arguments.
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class _NodeMissingOnnxShapeInference(infra.Rule):
|
| 25 |
+
"""Node is missing ONNX shape inference."""
|
| 26 |
+
|
| 27 |
+
def format_message(self, op_name) -> str: # type: ignore[override]
|
| 28 |
+
"""Returns the formatted default message of this Rule.
|
| 29 |
+
|
| 30 |
+
Message template: 'The shape inference of {op_name} type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.'
|
| 31 |
+
"""
|
| 32 |
+
return self.message_default_template.format(op_name=op_name)
|
| 33 |
+
|
| 34 |
+
def format( # type: ignore[override]
|
| 35 |
+
self, level: infra.Level, op_name
|
| 36 |
+
) -> Tuple[infra.Rule, infra.Level, str]:
|
| 37 |
+
"""Returns a tuple of (Rule, Level, message) for this Rule.
|
| 38 |
+
|
| 39 |
+
Message template: 'The shape inference of {op_name} type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.'
|
| 40 |
+
"""
|
| 41 |
+
return self, level, self.format_message(op_name=op_name)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class _MissingCustomSymbolicFunction(infra.Rule):
|
| 45 |
+
"""Missing symbolic function for custom PyTorch operator, cannot translate node to ONNX."""
|
| 46 |
+
|
| 47 |
+
def format_message(self, op_name) -> str: # type: ignore[override]
|
| 48 |
+
"""Returns the formatted default message of this Rule.
|
| 49 |
+
|
| 50 |
+
Message template: 'ONNX export failed on an operator with unrecognized namespace {op_name}. If you are trying to export a custom operator, make sure you registered it with the right domain and version.'
|
| 51 |
+
"""
|
| 52 |
+
return self.message_default_template.format(op_name=op_name)
|
| 53 |
+
|
| 54 |
+
def format( # type: ignore[override]
|
| 55 |
+
self, level: infra.Level, op_name
|
| 56 |
+
) -> Tuple[infra.Rule, infra.Level, str]:
|
| 57 |
+
"""Returns a tuple of (Rule, Level, message) for this Rule.
|
| 58 |
+
|
| 59 |
+
Message template: 'ONNX export failed on an operator with unrecognized namespace {op_name}. If you are trying to export a custom operator, make sure you registered it with the right domain and version.'
|
| 60 |
+
"""
|
| 61 |
+
return self, level, self.format_message(op_name=op_name)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class _MissingStandardSymbolicFunction(infra.Rule):
|
| 65 |
+
"""Missing symbolic function for standard PyTorch operator, cannot translate node to ONNX."""
|
| 66 |
+
|
| 67 |
+
def format_message( # type: ignore[override]
|
| 68 |
+
self, op_name, opset_version, issue_url
|
| 69 |
+
) -> str:
|
| 70 |
+
"""Returns the formatted default message of this Rule.
|
| 71 |
+
|
| 72 |
+
Message template: "Exporting the operator '{op_name}' to ONNX opset version {opset_version} is not supported. Please feel free to request support or submit a pull request on PyTorch GitHub: {issue_url}."
|
| 73 |
+
"""
|
| 74 |
+
return self.message_default_template.format(
|
| 75 |
+
op_name=op_name, opset_version=opset_version, issue_url=issue_url
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
def format( # type: ignore[override]
|
| 79 |
+
self, level: infra.Level, op_name, opset_version, issue_url
|
| 80 |
+
) -> Tuple[infra.Rule, infra.Level, str]:
|
| 81 |
+
"""Returns a tuple of (Rule, Level, message) for this Rule.
|
| 82 |
+
|
| 83 |
+
Message template: "Exporting the operator '{op_name}' to ONNX opset version {opset_version} is not supported. Please feel free to request support or submit a pull request on PyTorch GitHub: {issue_url}."
|
| 84 |
+
"""
|
| 85 |
+
return (
|
| 86 |
+
self,
|
| 87 |
+
level,
|
| 88 |
+
self.format_message(
|
| 89 |
+
op_name=op_name, opset_version=opset_version, issue_url=issue_url
|
| 90 |
+
),
|
| 91 |
+
)
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
class _OperatorSupportedInNewerOpsetVersion(infra.Rule):
|
| 95 |
+
"""Operator is supported in newer opset version."""
|
| 96 |
+
|
| 97 |
+
def format_message( # type: ignore[override]
|
| 98 |
+
self, op_name, opset_version, supported_opset_version
|
| 99 |
+
) -> str:
|
| 100 |
+
"""Returns the formatted default message of this Rule.
|
| 101 |
+
|
| 102 |
+
Message template: "Exporting the operator '{op_name}' to ONNX opset version {opset_version} is not supported. Support for this operator was added in version {supported_opset_version}, try exporting with this version."
|
| 103 |
+
"""
|
| 104 |
+
return self.message_default_template.format(
|
| 105 |
+
op_name=op_name,
|
| 106 |
+
opset_version=opset_version,
|
| 107 |
+
supported_opset_version=supported_opset_version,
|
| 108 |
+
)
|
| 109 |
+
|
| 110 |
+
def format( # type: ignore[override]
|
| 111 |
+
self, level: infra.Level, op_name, opset_version, supported_opset_version
|
| 112 |
+
) -> Tuple[infra.Rule, infra.Level, str]:
|
| 113 |
+
"""Returns a tuple of (Rule, Level, message) for this Rule.
|
| 114 |
+
|
| 115 |
+
Message template: "Exporting the operator '{op_name}' to ONNX opset version {opset_version} is not supported. Support for this operator was added in version {supported_opset_version}, try exporting with this version."
|
| 116 |
+
"""
|
| 117 |
+
return (
|
| 118 |
+
self,
|
| 119 |
+
level,
|
| 120 |
+
self.format_message(
|
| 121 |
+
op_name=op_name,
|
| 122 |
+
opset_version=opset_version,
|
| 123 |
+
supported_opset_version=supported_opset_version,
|
| 124 |
+
),
|
| 125 |
+
)
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
class _FxGraphToOnnx(infra.Rule):
|
| 129 |
+
"""Transforms graph from FX IR to ONNX IR."""
|
| 130 |
+
|
| 131 |
+
def format_message(self, graph_name) -> str: # type: ignore[override]
|
| 132 |
+
"""Returns the formatted default message of this Rule.
|
| 133 |
+
|
| 134 |
+
Message template: 'Transforming FX graph {graph_name} to ONNX graph.'
|
| 135 |
+
"""
|
| 136 |
+
return self.message_default_template.format(graph_name=graph_name)
|
| 137 |
+
|
| 138 |
+
def format( # type: ignore[override]
|
| 139 |
+
self, level: infra.Level, graph_name
|
| 140 |
+
) -> Tuple[infra.Rule, infra.Level, str]:
|
| 141 |
+
"""Returns a tuple of (Rule, Level, message) for this Rule.
|
| 142 |
+
|
| 143 |
+
Message template: 'Transforming FX graph {graph_name} to ONNX graph.'
|
| 144 |
+
"""
|
| 145 |
+
return self, level, self.format_message(graph_name=graph_name)
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
class _FxNodeToOnnx(infra.Rule):
|
| 149 |
+
"""Transforms an FX node to an ONNX node."""
|
| 150 |
+
|
| 151 |
+
def format_message(self, node_repr) -> str: # type: ignore[override]
|
| 152 |
+
"""Returns the formatted default message of this Rule.
|
| 153 |
+
|
| 154 |
+
Message template: 'Transforming FX node {node_repr} to ONNX node.'
|
| 155 |
+
"""
|
| 156 |
+
return self.message_default_template.format(node_repr=node_repr)
|
| 157 |
+
|
| 158 |
+
def format( # type: ignore[override]
|
| 159 |
+
self, level: infra.Level, node_repr
|
| 160 |
+
) -> Tuple[infra.Rule, infra.Level, str]:
|
| 161 |
+
"""Returns a tuple of (Rule, Level, message) for this Rule.
|
| 162 |
+
|
| 163 |
+
Message template: 'Transforming FX node {node_repr} to ONNX node.'
|
| 164 |
+
"""
|
| 165 |
+
return self, level, self.format_message(node_repr=node_repr)
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
class _FxPass(infra.Rule):
|
| 169 |
+
"""FX graph transformation during ONNX export before converting from FX IR to ONNX IR."""
|
| 170 |
+
|
| 171 |
+
def format_message(self, pass_name) -> str: # type: ignore[override]
|
| 172 |
+
"""Returns the formatted default message of this Rule.
|
| 173 |
+
|
| 174 |
+
Message template: 'Running {pass_name} pass.'
|
| 175 |
+
"""
|
| 176 |
+
return self.message_default_template.format(pass_name=pass_name)
|
| 177 |
+
|
| 178 |
+
def format( # type: ignore[override]
|
| 179 |
+
self, level: infra.Level, pass_name
|
| 180 |
+
) -> Tuple[infra.Rule, infra.Level, str]:
|
| 181 |
+
"""Returns a tuple of (Rule, Level, message) for this Rule.
|
| 182 |
+
|
| 183 |
+
Message template: 'Running {pass_name} pass.'
|
| 184 |
+
"""
|
| 185 |
+
return self, level, self.format_message(pass_name=pass_name)
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
class _NoSymbolicFunctionForCallFunction(infra.Rule):
|
| 189 |
+
"""Cannot find symbolic function to convert the "call_function" FX node to ONNX."""
|
| 190 |
+
|
| 191 |
+
def format_message(self, target) -> str: # type: ignore[override]
|
| 192 |
+
"""Returns the formatted default message of this Rule.
|
| 193 |
+
|
| 194 |
+
Message template: 'No symbolic function to convert the "call_function" node {target} to ONNX. '
|
| 195 |
+
"""
|
| 196 |
+
return self.message_default_template.format(target=target)
|
| 197 |
+
|
| 198 |
+
def format( # type: ignore[override]
|
| 199 |
+
self, level: infra.Level, target
|
| 200 |
+
) -> Tuple[infra.Rule, infra.Level, str]:
|
| 201 |
+
"""Returns a tuple of (Rule, Level, message) for this Rule.
|
| 202 |
+
|
| 203 |
+
Message template: 'No symbolic function to convert the "call_function" node {target} to ONNX. '
|
| 204 |
+
"""
|
| 205 |
+
return self, level, self.format_message(target=target)
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
class _UnsupportedFxNodeAnalysis(infra.Rule):
|
| 209 |
+
"""Result from FX graph analysis to reveal unsupported FX nodes."""
|
| 210 |
+
|
| 211 |
+
def format_message( # type: ignore[override]
|
| 212 |
+
self, node_op_to_target_mapping
|
| 213 |
+
) -> str:
|
| 214 |
+
"""Returns the formatted default message of this Rule.
|
| 215 |
+
|
| 216 |
+
Message template: 'Unsupported FX nodes: {node_op_to_target_mapping}. '
|
| 217 |
+
"""
|
| 218 |
+
return self.message_default_template.format(
|
| 219 |
+
node_op_to_target_mapping=node_op_to_target_mapping
|
| 220 |
+
)
|
| 221 |
+
|
| 222 |
+
def format( # type: ignore[override]
|
| 223 |
+
self, level: infra.Level, node_op_to_target_mapping
|
| 224 |
+
) -> Tuple[infra.Rule, infra.Level, str]:
|
| 225 |
+
"""Returns a tuple of (Rule, Level, message) for this Rule.
|
| 226 |
+
|
| 227 |
+
Message template: 'Unsupported FX nodes: {node_op_to_target_mapping}. '
|
| 228 |
+
"""
|
| 229 |
+
return (
|
| 230 |
+
self,
|
| 231 |
+
level,
|
| 232 |
+
self.format_message(node_op_to_target_mapping=node_op_to_target_mapping),
|
| 233 |
+
)
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
class _OpLevelDebugging(infra.Rule):
|
| 237 |
+
"""Report any op level validation failure in warnings."""
|
| 238 |
+
|
| 239 |
+
def format_message(self, node, symbolic_fn) -> str: # type: ignore[override]
|
| 240 |
+
"""Returns the formatted default message of this Rule.
|
| 241 |
+
|
| 242 |
+
Message template: 'FX node: {node} and its onnx function: {symbolic_fn} fails on op level validation.'
|
| 243 |
+
"""
|
| 244 |
+
return self.message_default_template.format(node=node, symbolic_fn=symbolic_fn)
|
| 245 |
+
|
| 246 |
+
def format( # type: ignore[override]
|
| 247 |
+
self, level: infra.Level, node, symbolic_fn
|
| 248 |
+
) -> Tuple[infra.Rule, infra.Level, str]:
|
| 249 |
+
"""Returns a tuple of (Rule, Level, message) for this Rule.
|
| 250 |
+
|
| 251 |
+
Message template: 'FX node: {node} and its onnx function: {symbolic_fn} fails on op level validation.'
|
| 252 |
+
"""
|
| 253 |
+
return self, level, self.format_message(node=node, symbolic_fn=symbolic_fn)
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
class _FindOpschemaMatchedSymbolicFunction(infra.Rule):
|
| 257 |
+
"""Find the OnnxFunction that matches the input/attribute dtypes by comparing them with their opschemas."""
|
| 258 |
+
|
| 259 |
+
def format_message(self, symbolic_fn, node) -> str: # type: ignore[override]
|
| 260 |
+
"""Returns the formatted default message of this Rule.
|
| 261 |
+
|
| 262 |
+
Message template: 'The OnnxFunction: {symbolic_fn} is the nearest match of the node {node}.'
|
| 263 |
+
"""
|
| 264 |
+
return self.message_default_template.format(symbolic_fn=symbolic_fn, node=node)
|
| 265 |
+
|
| 266 |
+
def format( # type: ignore[override]
|
| 267 |
+
self, level: infra.Level, symbolic_fn, node
|
| 268 |
+
) -> Tuple[infra.Rule, infra.Level, str]:
|
| 269 |
+
"""Returns a tuple of (Rule, Level, message) for this Rule.
|
| 270 |
+
|
| 271 |
+
Message template: 'The OnnxFunction: {symbolic_fn} is the nearest match of the node {node}.'
|
| 272 |
+
"""
|
| 273 |
+
return self, level, self.format_message(symbolic_fn=symbolic_fn, node=node)
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
class _FxNodeInsertTypePromotion(infra.Rule):
|
| 277 |
+
"""Determine if type promotion is required for the FX node. Insert cast nodes if needed."""
|
| 278 |
+
|
| 279 |
+
def format_message(self, target) -> str: # type: ignore[override]
|
| 280 |
+
"""Returns the formatted default message of this Rule.
|
| 281 |
+
|
| 282 |
+
Message template: 'Performing explicit type promotion for node {target}. '
|
| 283 |
+
"""
|
| 284 |
+
return self.message_default_template.format(target=target)
|
| 285 |
+
|
| 286 |
+
def format( # type: ignore[override]
|
| 287 |
+
self, level: infra.Level, target
|
| 288 |
+
) -> Tuple[infra.Rule, infra.Level, str]:
|
| 289 |
+
"""Returns a tuple of (Rule, Level, message) for this Rule.
|
| 290 |
+
|
| 291 |
+
Message template: 'Performing explicit type promotion for node {target}. '
|
| 292 |
+
"""
|
| 293 |
+
return self, level, self.format_message(target=target)
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
class _FindOperatorOverloadsInOnnxRegistry(infra.Rule):
|
| 297 |
+
"""Find the list of OnnxFunction of the PyTorch operator in onnx registry."""
|
| 298 |
+
|
| 299 |
+
def format_message(self, node) -> str: # type: ignore[override]
|
| 300 |
+
"""Returns the formatted default message of this Rule.
|
| 301 |
+
|
| 302 |
+
Message template: 'Checking if the FX node: {node} is supported in onnx registry.'
|
| 303 |
+
"""
|
| 304 |
+
return self.message_default_template.format(node=node)
|
| 305 |
+
|
| 306 |
+
def format( # type: ignore[override]
|
| 307 |
+
self, level: infra.Level, node
|
| 308 |
+
) -> Tuple[infra.Rule, infra.Level, str]:
|
| 309 |
+
"""Returns a tuple of (Rule, Level, message) for this Rule.
|
| 310 |
+
|
| 311 |
+
Message template: 'Checking if the FX node: {node} is supported in onnx registry.'
|
| 312 |
+
"""
|
| 313 |
+
return self, level, self.format_message(node=node)
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
@dataclasses.dataclass
|
| 317 |
+
class _POERules(infra.RuleCollection):
|
| 318 |
+
node_missing_onnx_shape_inference: _NodeMissingOnnxShapeInference = dataclasses.field(
|
| 319 |
+
default=_NodeMissingOnnxShapeInference.from_sarif(
|
| 320 |
+
**{
|
| 321 |
+
"id": "POE0001",
|
| 322 |
+
"name": "node-missing-onnx-shape-inference",
|
| 323 |
+
"short_description": {"text": "Node is missing ONNX shape inference."},
|
| 324 |
+
"full_description": {
|
| 325 |
+
"text": "Node is missing ONNX shape inference. This usually happens when the node is not valid under standard ONNX operator spec.",
|
| 326 |
+
"markdown": "Node is missing ONNX shape inference.\nThis usually happens when the node is not valid under standard ONNX operator spec.\n",
|
| 327 |
+
},
|
| 328 |
+
"message_strings": {
|
| 329 |
+
"default": {
|
| 330 |
+
"text": "The shape inference of {op_name} type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function."
|
| 331 |
+
}
|
| 332 |
+
},
|
| 333 |
+
"help_uri": None,
|
| 334 |
+
"properties": {"deprecated": False, "tags": []},
|
| 335 |
+
}
|
| 336 |
+
),
|
| 337 |
+
init=False,
|
| 338 |
+
)
|
| 339 |
+
"""Node is missing ONNX shape inference."""
|
| 340 |
+
|
| 341 |
+
missing_custom_symbolic_function: _MissingCustomSymbolicFunction = dataclasses.field(
|
| 342 |
+
default=_MissingCustomSymbolicFunction.from_sarif(
|
| 343 |
+
**{
|
| 344 |
+
"id": "POE0002",
|
| 345 |
+
"name": "missing-custom-symbolic-function",
|
| 346 |
+
"short_description": {
|
| 347 |
+
"text": "Missing symbolic function for custom PyTorch operator, cannot translate node to ONNX."
|
| 348 |
+
},
|
| 349 |
+
"full_description": {
|
| 350 |
+
"text": "Missing symbolic function for custom PyTorch operator, cannot translate node to ONNX.",
|
| 351 |
+
"markdown": "Missing symbolic function for custom PyTorch operator, cannot translate node to ONNX.\n",
|
| 352 |
+
},
|
| 353 |
+
"message_strings": {
|
| 354 |
+
"default": {
|
| 355 |
+
"text": "ONNX export failed on an operator with unrecognized namespace {op_name}. If you are trying to export a custom operator, make sure you registered it with the right domain and version."
|
| 356 |
+
}
|
| 357 |
+
},
|
| 358 |
+
"help_uri": None,
|
| 359 |
+
"properties": {"deprecated": False, "tags": []},
|
| 360 |
+
}
|
| 361 |
+
),
|
| 362 |
+
init=False,
|
| 363 |
+
)
|
| 364 |
+
"""Missing symbolic function for custom PyTorch operator, cannot translate node to ONNX."""
|
| 365 |
+
|
| 366 |
+
missing_standard_symbolic_function: _MissingStandardSymbolicFunction = dataclasses.field(
|
| 367 |
+
default=_MissingStandardSymbolicFunction.from_sarif(
|
| 368 |
+
**{
|
| 369 |
+
"id": "POE0003",
|
| 370 |
+
"name": "missing-standard-symbolic-function",
|
| 371 |
+
"short_description": {
|
| 372 |
+
"text": "Missing symbolic function for standard PyTorch operator, cannot translate node to ONNX."
|
| 373 |
+
},
|
| 374 |
+
"full_description": {
|
| 375 |
+
"text": "Missing symbolic function for standard PyTorch operator, cannot translate node to ONNX.",
|
| 376 |
+
"markdown": "Missing symbolic function for standard PyTorch operator, cannot translate node to ONNX.\n",
|
| 377 |
+
},
|
| 378 |
+
"message_strings": {
|
| 379 |
+
"default": {
|
| 380 |
+
"text": "Exporting the operator '{op_name}' to ONNX opset version {opset_version} is not supported. Please feel free to request support or submit a pull request on PyTorch GitHub: {issue_url}."
|
| 381 |
+
}
|
| 382 |
+
},
|
| 383 |
+
"help_uri": None,
|
| 384 |
+
"properties": {"deprecated": False, "tags": []},
|
| 385 |
+
}
|
| 386 |
+
),
|
| 387 |
+
init=False,
|
| 388 |
+
)
|
| 389 |
+
"""Missing symbolic function for standard PyTorch operator, cannot translate node to ONNX."""
|
| 390 |
+
|
| 391 |
+
operator_supported_in_newer_opset_version: _OperatorSupportedInNewerOpsetVersion = dataclasses.field(
|
| 392 |
+
default=_OperatorSupportedInNewerOpsetVersion.from_sarif(
|
| 393 |
+
**{
|
| 394 |
+
"id": "POE0004",
|
| 395 |
+
"name": "operator-supported-in-newer-opset-version",
|
| 396 |
+
"short_description": {
|
| 397 |
+
"text": "Operator is supported in newer opset version."
|
| 398 |
+
},
|
| 399 |
+
"full_description": {
|
| 400 |
+
"text": "Operator is supported in newer opset version.",
|
| 401 |
+
"markdown": "Operator is supported in newer opset version.\n\nExample:\n```python\ntorch.onnx.export(model, args, ..., opset_version=9)\n```\n",
|
| 402 |
+
},
|
| 403 |
+
"message_strings": {
|
| 404 |
+
"default": {
|
| 405 |
+
"text": "Exporting the operator '{op_name}' to ONNX opset version {opset_version} is not supported. Support for this operator was added in version {supported_opset_version}, try exporting with this version."
|
| 406 |
+
}
|
| 407 |
+
},
|
| 408 |
+
"help_uri": None,
|
| 409 |
+
"properties": {"deprecated": False, "tags": []},
|
| 410 |
+
}
|
| 411 |
+
),
|
| 412 |
+
init=False,
|
| 413 |
+
)
|
| 414 |
+
"""Operator is supported in newer opset version."""
|
| 415 |
+
|
| 416 |
+
fx_graph_to_onnx: _FxGraphToOnnx = dataclasses.field(
|
| 417 |
+
default=_FxGraphToOnnx.from_sarif(
|
| 418 |
+
**{
|
| 419 |
+
"id": "FXE0007",
|
| 420 |
+
"name": "fx-graph-to-onnx",
|
| 421 |
+
"short_description": {
|
| 422 |
+
"text": "Transforms graph from FX IR to ONNX IR."
|
| 423 |
+
},
|
| 424 |
+
"full_description": {
|
| 425 |
+
"text": "Transforms graph from FX IR to ONNX IR.",
|
| 426 |
+
"markdown": "This diagnostic tracks the transformation process from an FX Graph (in FX IR) to an ONNX Graph (in ONNX IR).\n\n## Key Representations:\n\n- **FX Graph**: The graph in FX IR produced by dynamo or symbolic tracing.\n- **ONNX Graph**: The graph in ONNX IR and [operators](https://onnx.ai/onnx/operators/).\n\n## Additional Notes:\n\n- Prior to this transformation step, the FX graph undergoes preprocessing through multiple FX passes.\n To gain insight into these transformations, refer to diagnostic `FXE0010`.\n- To enable a detailed view of the graph transformation in progress within this diagnostic, switch to the DEBUG mode.\n\n - Set DiagnosticOptions.verbosity_level to logging.DEBUG.\n - Activate the environment variable TORCH_LOGS='onnx_diagnostics'.\n\n- For specific information related to node-level FX to ONNX transformations, explore the diagnostic `FXE0008`.\n",
|
| 427 |
+
},
|
| 428 |
+
"message_strings": {
|
| 429 |
+
"default": {
|
| 430 |
+
"text": "Transforming FX graph {graph_name} to ONNX graph."
|
| 431 |
+
}
|
| 432 |
+
},
|
| 433 |
+
"help_uri": None,
|
| 434 |
+
"properties": {"deprecated": False, "tags": []},
|
| 435 |
+
}
|
| 436 |
+
),
|
| 437 |
+
init=False,
|
| 438 |
+
)
|
| 439 |
+
"""Transforms graph from FX IR to ONNX IR."""
|
| 440 |
+
|
| 441 |
+
fx_node_to_onnx: _FxNodeToOnnx = dataclasses.field(
|
| 442 |
+
default=_FxNodeToOnnx.from_sarif(
|
| 443 |
+
**{
|
| 444 |
+
"id": "FXE0008",
|
| 445 |
+
"name": "fx-node-to-onnx",
|
| 446 |
+
"short_description": {"text": "Transforms an FX node to an ONNX node."},
|
| 447 |
+
"full_description": {
|
| 448 |
+
"text": "Transforms an FX node to an ONNX node.",
|
| 449 |
+
"markdown": "This diagnostic tracks the transformation process from an FX Node to ONNX [Operators](https://onnx.ai/onnx/operators/).\n\nThe process of converting FX Node to ONNX Node involves dealing with six distinct node types:\n 1. `placeholder`: Represents a module input, maps to an ONNX graph input.\n 2. `call_module`: Symbolizes a call to a submodule, maps to an ONNX\n 3. `call_method`: Symbolizes a method call. Not yet implemented.\n 4. `call_function`: Symbolizes a function call. [Core ATen](https://pytorch.org/docs/stable/ir.html#core-aten-ir) is expected\n as the function call target. The mapping from ATen to ONNX is implemented by [ONNXScript torchlib](https://github.com/microsoft/onnxscript/tree/main/onnxscript/function_libs/torch_lib/ops).\n This [guide](https://pytorch.org/docs/stable/onnx.html#onnx-script-functions) shows how to write and register a custom symbolic function for call_function FX node.\n 5. `get_attr`: Indicates an attribute access within the current module. Maps to an ONNX graph initializer.\n 6. `output`: Represents the module's output. Maps to an ONNX graph output.\n\nFor a granular understanding of how each node type is transformed, refer to the implementation details in `FxOnnxInterpreter`.\n",
|
| 450 |
+
},
|
| 451 |
+
"message_strings": {
|
| 452 |
+
"default": {
|
| 453 |
+
"text": "Transforming FX node {node_repr} to ONNX node."
|
| 454 |
+
}
|
| 455 |
+
},
|
| 456 |
+
"help_uri": None,
|
| 457 |
+
"properties": {"deprecated": False, "tags": []},
|
| 458 |
+
}
|
| 459 |
+
),
|
| 460 |
+
init=False,
|
| 461 |
+
)
|
| 462 |
+
"""Transforms an FX node to an ONNX node."""
|
| 463 |
+
|
| 464 |
+
fx_pass: _FxPass = dataclasses.field(
|
| 465 |
+
default=_FxPass.from_sarif(
|
| 466 |
+
**{
|
| 467 |
+
"id": "FXE0010",
|
| 468 |
+
"name": "fx-pass",
|
| 469 |
+
"short_description": {
|
| 470 |
+
"text": "FX graph transformation during ONNX export before converting from FX IR to ONNX IR."
|
| 471 |
+
},
|
| 472 |
+
"full_description": {
|
| 473 |
+
"text": "FX graph transformation during ONNX export before converting from FX IR to ONNX IR.",
|
| 474 |
+
"markdown": "This diagnostic tracks the FX passes executed during the ONNX export process prior\nto converting from FX IR (Intermediate Representation) to ONNX IR.\n\nUnder the scope of ONNX export, an FX pass refers to a specific transformation applied to the FX GraphModule.\nThe primary aim of these passes is to streamline the graph into a format that aligns more with the ONNX IR.\nMoreover, these passes work to substitute unsupported FX IR features with those recognized and endorsed by\nONNX IR. Common transformations include, but aren't limited to, decomposition, functionalization and\ntype promotion.\n\nFor those who are interested in a comprehensive log detailing the modifications made during these passes,\nthere are a couple of options:\n\n- Set DiagnosticOptions.verbosity_level to logging.DEBUG.\n- Activate the environment variable TORCH_LOGS='onnx_diagnostics'.\n\nHowever, it's noteworthy that by default, such detailed logging is turned off. The primary reason being\nits considerable impact on performance.\n\nFor an in-depth understanding of each specific pass, please refer to the directory: torch/onnx/_internal/fx/passes.\n",
|
| 475 |
+
},
|
| 476 |
+
"message_strings": {"default": {"text": "Running {pass_name} pass."}},
|
| 477 |
+
"help_uri": None,
|
| 478 |
+
"properties": {"deprecated": False, "tags": []},
|
| 479 |
+
}
|
| 480 |
+
),
|
| 481 |
+
init=False,
|
| 482 |
+
)
|
| 483 |
+
"""FX graph transformation during ONNX export before converting from FX IR to ONNX IR."""
|
| 484 |
+
|
| 485 |
+
no_symbolic_function_for_call_function: _NoSymbolicFunctionForCallFunction = dataclasses.field(
|
| 486 |
+
default=_NoSymbolicFunctionForCallFunction.from_sarif(
|
| 487 |
+
**{
|
| 488 |
+
"id": "FXE0011",
|
| 489 |
+
"name": "no-symbolic-function-for-call-function",
|
| 490 |
+
"short_description": {
|
| 491 |
+
"text": 'Cannot find symbolic function to convert the "call_function" FX node to ONNX.'
|
| 492 |
+
},
|
| 493 |
+
"full_description": {
|
| 494 |
+
"text": 'Cannot find symbolic function to convert the "call_function" FX node to ONNX. ',
|
| 495 |
+
"markdown": 'This error occurs when the ONNX converter is unable to find a corresponding symbolic function\nto convert a "call_function" node in the input graph to its equivalence in ONNX. The "call_function"\nnode represents a normalized function call in PyTorch, such as "torch.aten.ops.add".\n\nTo resolve this error, you can try one of the following:\n\n- If exists, apply the auto-fix suggested by the diagnostic. TODO: this part is not available yet.\n- Rewrite the model using only supported PyTorch operators or functions.\n- Follow this [guide](https://pytorch.org/tutorials/beginner/onnx/onnx_registry_tutorial.html#overview) to write and\n register a custom symbolic function for the unsupported call_function FX node.\n',
|
| 496 |
+
},
|
| 497 |
+
"message_strings": {
|
| 498 |
+
"default": {
|
| 499 |
+
"text": 'No symbolic function to convert the "call_function" node {target} to ONNX. '
|
| 500 |
+
}
|
| 501 |
+
},
|
| 502 |
+
"help_uri": None,
|
| 503 |
+
"properties": {"deprecated": False, "tags": []},
|
| 504 |
+
}
|
| 505 |
+
),
|
| 506 |
+
init=False,
|
| 507 |
+
)
|
| 508 |
+
"""Cannot find symbolic function to convert the "call_function" FX node to ONNX."""
|
| 509 |
+
|
| 510 |
+
unsupported_fx_node_analysis: _UnsupportedFxNodeAnalysis = dataclasses.field(
|
| 511 |
+
default=_UnsupportedFxNodeAnalysis.from_sarif(
|
| 512 |
+
**{
|
| 513 |
+
"id": "FXE0012",
|
| 514 |
+
"name": "unsupported-fx-node-analysis",
|
| 515 |
+
"short_description": {
|
| 516 |
+
"text": "Result from FX graph analysis to reveal unsupported FX nodes."
|
| 517 |
+
},
|
| 518 |
+
"full_description": {
|
| 519 |
+
"text": "Result from FX graph analysis to reveal unsupported FX nodes.",
|
| 520 |
+
"markdown": "This error indicates that an FX graph contains one or more unsupported nodes. The error message\nis typically accompanied by a list of the unsupported nodes found during analysis.\n\nTo resolve this error, you can try resolving each individual unsupported node error by following\nthe suggestions by its diagnostic. Typically, options include:\n\n- If exists, apply the auto-fix suggested by the diagnostic. TODO: this part is not available yet.\n- Rewrite the model using only supported PyTorch operators or functions.\n- Follow this [guide](https://pytorch.org/docs/stable/onnx.html#onnx-script-functions) to write and\n register a custom symbolic function for the unsupported call_function FX node.\n",
|
| 521 |
+
},
|
| 522 |
+
"message_strings": {
|
| 523 |
+
"default": {
|
| 524 |
+
"text": "Unsupported FX nodes: {node_op_to_target_mapping}. "
|
| 525 |
+
}
|
| 526 |
+
},
|
| 527 |
+
"help_uri": None,
|
| 528 |
+
"properties": {"deprecated": False, "tags": []},
|
| 529 |
+
}
|
| 530 |
+
),
|
| 531 |
+
init=False,
|
| 532 |
+
)
|
| 533 |
+
"""Result from FX graph analysis to reveal unsupported FX nodes."""
|
| 534 |
+
|
| 535 |
+
op_level_debugging: _OpLevelDebugging = dataclasses.field(
|
| 536 |
+
default=_OpLevelDebugging.from_sarif(
|
| 537 |
+
**{
|
| 538 |
+
"id": "FXE0013",
|
| 539 |
+
"name": "op-level-debugging",
|
| 540 |
+
"short_description": {
|
| 541 |
+
"text": "Report any op level validation failure in warnings."
|
| 542 |
+
},
|
| 543 |
+
"full_description": {
|
| 544 |
+
"text": "Report any op level validation failure in warnings.",
|
| 545 |
+
"markdown": "This warning message indicates that during op level debugging, certain symbolic functions\nhave failed to match the results of torch ops when using real tensors generated from fake\ntensors. It is important to note that the symbolic functions may not necessarily be\nincorrect, as the validation process is non-deterministic and should only be used as a\nreference.\n\nThere are two categories of warnings that can be triggered:\n\n1. Non-validated operators:\n If the warnings are caused by the following errors, they can be disregarded by users,\n as these errors occur due to the non-deterministic nature of the validation. However,\n it is important to be aware that the operators have not been validated.\n\n - IndexError: Unsupported input arguments of randomized dimensions/indices(INT64).\n - RuntimeError: Unsupported input arguments for torch ops are generated.\n - ValueError: Arguments/keyword arguments do not match the signature of the symbolic function.\n\n2. Potentially wrong torchlib operators:\n If the warnings are triggered by the following error, users should be aware that the symbolic functions\n may be incorrect in dispatching or implementation. In such cases, it is recommended to report\n the issue to the PyTorch-ONNX team, or create/register a custom symbolic function to replace the default one.\n\n - AssertionError: The symbolic function is potentially wrong as the results do not match the results of torch ops.\n - TypeError: The symbolic function is potentially wrong as the opschema doesn't match inputs.\n",
|
| 546 |
+
},
|
| 547 |
+
"message_strings": {
|
| 548 |
+
"default": {
|
| 549 |
+
"text": "FX node: {node} and its onnx function: {symbolic_fn} fails on op level validation."
|
| 550 |
+
}
|
| 551 |
+
},
|
| 552 |
+
"help_uri": None,
|
| 553 |
+
"properties": {"deprecated": False, "tags": []},
|
| 554 |
+
}
|
| 555 |
+
),
|
| 556 |
+
init=False,
|
| 557 |
+
)
|
| 558 |
+
"""Report any op level validation failure in warnings."""
|
| 559 |
+
|
| 560 |
+
find_opschema_matched_symbolic_function: _FindOpschemaMatchedSymbolicFunction = dataclasses.field(
|
| 561 |
+
default=_FindOpschemaMatchedSymbolicFunction.from_sarif(
|
| 562 |
+
**{
|
| 563 |
+
"id": "FXE0014",
|
| 564 |
+
"name": "find-opschema-matched-symbolic-function",
|
| 565 |
+
"short_description": {
|
| 566 |
+
"text": "Find the OnnxFunction that matches the input/attribute dtypes by comparing them with their opschemas."
|
| 567 |
+
},
|
| 568 |
+
"full_description": {
|
| 569 |
+
"text": "Find the OnnxFunction that matches the input dtypes by comparing them with their opschemas. A warning will be issued if the matched OnnxFunction is not an exact match.",
|
| 570 |
+
"markdown": "When an ATen/Custom operator is registered and needs to be dispatched to an OnnxFunction, the input/attribute\ndtypes of the ATen/Custom operator are compared with the input/attribute dtypes of the OnnxFunction opschemas\nto find a match. However, if a perfect/exact match is not found, the dispatcher will attempt to find\nthe nearest match with the highest number of input/attribute dtypes matching the OnnxFunction opschemas, while\nissuing a warning.\n\nThere are two types of level that can be triggered in this rule:\n\n1. NOTE: A perfect match is found, and no warning is issued.\n2. WARNING: The matched OnnxFunction is not a perfect/exact match.\n\nHere are some suggestions based on the WARNING situation:\n\n1. If there are NO errors or mismatches in the results, it is safe to disregard this warning,\n as the definition of OnnxFunction schema is usually more stringent.\n2. If there are errors or mismatches in the results, it is recommended to:\n (a) Enable op_level_debugging to determine if the OnnxFunction might be incorrect.\n (b) Report the issue to the PyTorch-ONNX team.\n (c) Create/register a custom symbolic function to replace the default one.\n",
|
| 571 |
+
},
|
| 572 |
+
"message_strings": {
|
| 573 |
+
"default": {
|
| 574 |
+
"text": "The OnnxFunction: {symbolic_fn} is the nearest match of the node {node}."
|
| 575 |
+
}
|
| 576 |
+
},
|
| 577 |
+
"help_uri": None,
|
| 578 |
+
"properties": {"deprecated": False, "tags": []},
|
| 579 |
+
}
|
| 580 |
+
),
|
| 581 |
+
init=False,
|
| 582 |
+
)
|
| 583 |
+
"""Find the OnnxFunction that matches the input/attribute dtypes by comparing them with their opschemas."""
|
| 584 |
+
|
| 585 |
+
fx_node_insert_type_promotion: _FxNodeInsertTypePromotion = dataclasses.field(
|
| 586 |
+
default=_FxNodeInsertTypePromotion.from_sarif(
|
| 587 |
+
**{
|
| 588 |
+
"id": "FXE0015",
|
| 589 |
+
"name": "fx-node-insert-type-promotion",
|
| 590 |
+
"short_description": {
|
| 591 |
+
"text": "Determine if type promotion is required for the FX node. Insert cast nodes if needed."
|
| 592 |
+
},
|
| 593 |
+
"full_description": {
|
| 594 |
+
"text": "Determine if type promotion is required for the FX node. Insert cast nodes if needed.",
|
| 595 |
+
"markdown": "This diagnostic monitors the node-level type promotion insertion process. In PyTorch, there is an automatic process called implicit type promotion,\nwhere the input types of an operator are promoted to a common type. The determination of the common type is based on the type promotion rule specific to each operator.\nTo learn more about PyTorch's type promotion rules, refer to the [elementwise_dtypes doc](https://github.com/pytorch/pytorch/blob/f044613f78df713fb57f70c608483c9f10ad332e/torch/_prims_common/__init__.py#L1252-L1335)\nand [torch._refs ops](https://github.com/pytorch/pytorch/blob/a475ea4542dfe961c9d097e33ab5041f61c8c17f/torch/_refs/__init__.py#L484).\n\nHowever, implicit type promotion is not supported in ONNX. Therefore, to replicate the PyTorch behavior, we need to explicitly insert cast nodes.\nThis diagnostic tracks the process of node-level type promotion insertion.\n\nThe type promotion rules used by this process can be found in `torch/onnx/_internal/fx/passes/type_promotion.py.`\nTo update or add new type promotion rules, please refer to the [Note: Update type promotion rule] section.\n",
|
| 596 |
+
},
|
| 597 |
+
"message_strings": {
|
| 598 |
+
"default": {
|
| 599 |
+
"text": "Performing explicit type promotion for node {target}. "
|
| 600 |
+
}
|
| 601 |
+
},
|
| 602 |
+
"help_uri": None,
|
| 603 |
+
"properties": {"deprecated": False, "tags": []},
|
| 604 |
+
}
|
| 605 |
+
),
|
| 606 |
+
init=False,
|
| 607 |
+
)
|
| 608 |
+
"""Determine if type promotion is required for the FX node. Insert cast nodes if needed."""
|
| 609 |
+
|
| 610 |
+
find_operator_overloads_in_onnx_registry: _FindOperatorOverloadsInOnnxRegistry = dataclasses.field(
|
| 611 |
+
default=_FindOperatorOverloadsInOnnxRegistry.from_sarif(
|
| 612 |
+
**{
|
| 613 |
+
"id": "FXE0016",
|
| 614 |
+
"name": "find-operator-overloads-in-onnx-registry",
|
| 615 |
+
"short_description": {
|
| 616 |
+
"text": "Find the list of OnnxFunction of the PyTorch operator in onnx registry."
|
| 617 |
+
},
|
| 618 |
+
"full_description": {
|
| 619 |
+
"text": "This rule involves finding the list of OnnxFunction for the PyTorch operator overload in the ONNX registry. If the operator overload is not supported but its default overload is, a warning will be issued. If both the operator overload and its default overload are not supported, an error will be issued.",
|
| 620 |
+
"markdown": "The operator overload name serves the purpose of verifying whether a PyTorch operator is registered in the ONNX registry.\nIf it's not found, the dispatcher takes a fallback approach and tries to locate the default overload of the PyTorch\noperator in the registry. If even the default overload is absent, it signifies that the operator is officially unsupported.\n\nThere are three types of level that can be triggered in this rule:\n\n1. NOTE: The op overload is supported.\n2. WARNING: The op overload is not supported, but it's default overload is supported.\n3. ERROR: The op overload is not supported, and it's default overload is also not supported.\n\nHere are some suggestions based on the WARNING situation:\n\n1. If there are NO errors or mismatches in the results, it is safe to disregard this warning.\n2. If there are errors or mismatches in the results, it is recommended to:\n (a) Enable op_level_debugging to determine if the OnnxFunction might be incorrect.\n (b) Report the unsupported overload to the PyTorch-ONNX team.\n (c) Create/register a custom symbolic function to replace the default one.\n\nHere are some suggestions based on the ERROR situation:\n\n1. Report the unsupported operator to the PyTorch-ONNX team.\n2. Create/register a custom symbolic function to replace the default one.\n",
|
| 621 |
+
},
|
| 622 |
+
"message_strings": {
|
| 623 |
+
"default": {
|
| 624 |
+
"text": "Checking if the FX node: {node} is supported in onnx registry."
|
| 625 |
+
}
|
| 626 |
+
},
|
| 627 |
+
"help_uri": None,
|
| 628 |
+
"properties": {"deprecated": False, "tags": []},
|
| 629 |
+
}
|
| 630 |
+
),
|
| 631 |
+
init=False,
|
| 632 |
+
)
|
| 633 |
+
"""Find the list of OnnxFunction of the PyTorch operator in onnx registry."""
|
| 634 |
+
|
| 635 |
+
|
| 636 |
+
rules = _POERules()
|
.venv/Lib/site-packages/torch/onnx/_internal/diagnostics/infra/_infra.py
ADDED
|
@@ -0,0 +1,285 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
"""This file defines an additional layer of abstraction on top of the SARIF OM."""
|
| 3 |
+
|
| 4 |
+
from __future__ import annotations
|
| 5 |
+
|
| 6 |
+
import dataclasses
|
| 7 |
+
import enum
|
| 8 |
+
import logging
|
| 9 |
+
from typing import Mapping, Sequence
|
| 10 |
+
|
| 11 |
+
from torch.onnx._internal.diagnostics.infra import formatter, sarif
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class Level(enum.IntEnum):
|
| 15 |
+
"""The level of a diagnostic.
|
| 16 |
+
|
| 17 |
+
This class is used to represent the level of a diagnostic. The levels are defined
|
| 18 |
+
by the SARIF specification, and are not modifiable. For alternative categories,
|
| 19 |
+
please use infra.Tag instead. When selecting a level, please consider the following
|
| 20 |
+
guidelines:
|
| 21 |
+
|
| 22 |
+
- NONE: Informational result that does not indicate the presence of a problem.
|
| 23 |
+
- NOTE: An opportunity for improvement was found.
|
| 24 |
+
- WARNING: A potential problem was found.
|
| 25 |
+
- ERROR: A serious problem was found.
|
| 26 |
+
|
| 27 |
+
This level is a subclass of enum.IntEnum, and can be used as an integer. Its integer
|
| 28 |
+
value maps to the logging levels in Python's logging module. The mapping is as
|
| 29 |
+
follows:
|
| 30 |
+
|
| 31 |
+
Level.NONE = logging.DEBUG = 10
|
| 32 |
+
Level.NOTE = logging.INFO = 20
|
| 33 |
+
Level.WARNING = logging.WARNING = 30
|
| 34 |
+
Level.ERROR = logging.ERROR = 40
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
NONE = 10
|
| 38 |
+
NOTE = 20
|
| 39 |
+
WARNING = 30
|
| 40 |
+
ERROR = 40
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
levels = Level
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class Tag(enum.Enum):
|
| 47 |
+
"""The tag of a diagnostic. This class can be inherited to define custom tags."""
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class PatchedPropertyBag(sarif.PropertyBag):
|
| 51 |
+
"""Key/value pairs that provide additional information about the object.
|
| 52 |
+
|
| 53 |
+
The definition of PropertyBag via SARIF spec is "A property bag is an object (section 3.6)
|
| 54 |
+
containing an unordered set of properties with arbitrary names." However it is not
|
| 55 |
+
reflected in the json file, and therefore not captured by the python representation.
|
| 56 |
+
This patch adds additional **kwargs to the `__init__` method to allow recording
|
| 57 |
+
arbitrary key/value pairs.
|
| 58 |
+
"""
|
| 59 |
+
|
| 60 |
+
def __init__(self, tags: list[str] | None = None, **kwargs):
|
| 61 |
+
super().__init__(tags=tags)
|
| 62 |
+
self.__dict__.update(kwargs)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
@dataclasses.dataclass(frozen=True)
|
| 66 |
+
class Rule:
|
| 67 |
+
id: str
|
| 68 |
+
name: str
|
| 69 |
+
message_default_template: str
|
| 70 |
+
short_description: str | None = None
|
| 71 |
+
full_description: str | None = None
|
| 72 |
+
full_description_markdown: str | None = None
|
| 73 |
+
help_uri: str | None = None
|
| 74 |
+
|
| 75 |
+
@classmethod
|
| 76 |
+
def from_sarif(cls, **kwargs):
|
| 77 |
+
"""Returns a rule from the SARIF reporting descriptor."""
|
| 78 |
+
short_description = kwargs.get("short_description", {}).get("text")
|
| 79 |
+
full_description = kwargs.get("full_description", {}).get("text")
|
| 80 |
+
full_description_markdown = kwargs.get("full_description", {}).get("markdown")
|
| 81 |
+
help_uri = kwargs.get("help_uri")
|
| 82 |
+
|
| 83 |
+
rule = cls(
|
| 84 |
+
id=kwargs["id"],
|
| 85 |
+
name=kwargs["name"],
|
| 86 |
+
message_default_template=kwargs["message_strings"]["default"]["text"],
|
| 87 |
+
short_description=short_description,
|
| 88 |
+
full_description=full_description,
|
| 89 |
+
full_description_markdown=full_description_markdown,
|
| 90 |
+
help_uri=help_uri,
|
| 91 |
+
)
|
| 92 |
+
return rule
|
| 93 |
+
|
| 94 |
+
def sarif(self) -> sarif.ReportingDescriptor:
|
| 95 |
+
"""Returns a SARIF reporting descriptor of this Rule."""
|
| 96 |
+
short_description = (
|
| 97 |
+
sarif.MultiformatMessageString(text=self.short_description)
|
| 98 |
+
if self.short_description is not None
|
| 99 |
+
else None
|
| 100 |
+
)
|
| 101 |
+
full_description = (
|
| 102 |
+
sarif.MultiformatMessageString(
|
| 103 |
+
text=self.full_description, markdown=self.full_description_markdown
|
| 104 |
+
)
|
| 105 |
+
if self.full_description is not None
|
| 106 |
+
else None
|
| 107 |
+
)
|
| 108 |
+
return sarif.ReportingDescriptor(
|
| 109 |
+
id=self.id,
|
| 110 |
+
name=self.name,
|
| 111 |
+
short_description=short_description,
|
| 112 |
+
full_description=full_description,
|
| 113 |
+
help_uri=self.help_uri,
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
def format(self, level: Level, *args, **kwargs) -> tuple[Rule, Level, str]:
|
| 117 |
+
"""Returns a tuple of (rule, level, message) for a diagnostic.
|
| 118 |
+
|
| 119 |
+
This method is used to format the message of a diagnostic. The message is
|
| 120 |
+
formatted using the default template of this rule, and the arguments passed in
|
| 121 |
+
as `*args` and `**kwargs`. The level is used to override the default level of
|
| 122 |
+
this rule.
|
| 123 |
+
"""
|
| 124 |
+
return (self, level, self.format_message(*args, **kwargs))
|
| 125 |
+
|
| 126 |
+
def format_message(self, *args, **kwargs) -> str:
|
| 127 |
+
"""Returns the formatted default message of this Rule.
|
| 128 |
+
|
| 129 |
+
This method should be overridden (with code generation) by subclasses to reflect
|
| 130 |
+
the exact arguments needed by the message template. This is a helper method to
|
| 131 |
+
create the default message for a diagnostic.
|
| 132 |
+
"""
|
| 133 |
+
return self.message_default_template.format(*args, **kwargs)
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
@dataclasses.dataclass
|
| 137 |
+
class Location:
|
| 138 |
+
uri: str | None = None
|
| 139 |
+
line: int | None = None
|
| 140 |
+
message: str | None = None
|
| 141 |
+
start_column: int | None = None
|
| 142 |
+
end_column: int | None = None
|
| 143 |
+
snippet: str | None = None
|
| 144 |
+
function: str | None = None
|
| 145 |
+
|
| 146 |
+
def sarif(self) -> sarif.Location:
|
| 147 |
+
"""Returns the SARIF representation of this location."""
|
| 148 |
+
return sarif.Location(
|
| 149 |
+
physical_location=sarif.PhysicalLocation(
|
| 150 |
+
artifact_location=sarif.ArtifactLocation(uri=self.uri),
|
| 151 |
+
region=sarif.Region(
|
| 152 |
+
start_line=self.line,
|
| 153 |
+
start_column=self.start_column,
|
| 154 |
+
end_column=self.end_column,
|
| 155 |
+
snippet=sarif.ArtifactContent(text=self.snippet),
|
| 156 |
+
),
|
| 157 |
+
),
|
| 158 |
+
message=sarif.Message(text=self.message)
|
| 159 |
+
if self.message is not None
|
| 160 |
+
else None,
|
| 161 |
+
)
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
@dataclasses.dataclass
|
| 165 |
+
class StackFrame:
|
| 166 |
+
location: Location
|
| 167 |
+
|
| 168 |
+
def sarif(self) -> sarif.StackFrame:
|
| 169 |
+
"""Returns the SARIF representation of this stack frame."""
|
| 170 |
+
return sarif.StackFrame(location=self.location.sarif())
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
@dataclasses.dataclass
|
| 174 |
+
class Stack:
|
| 175 |
+
"""Records a stack trace. The frames are in order from newest to oldest stack frame."""
|
| 176 |
+
|
| 177 |
+
frames: list[StackFrame] = dataclasses.field(default_factory=list)
|
| 178 |
+
message: str | None = None
|
| 179 |
+
|
| 180 |
+
def sarif(self) -> sarif.Stack:
|
| 181 |
+
"""Returns the SARIF representation of this stack."""
|
| 182 |
+
return sarif.Stack(
|
| 183 |
+
frames=[frame.sarif() for frame in self.frames],
|
| 184 |
+
message=sarif.Message(text=self.message)
|
| 185 |
+
if self.message is not None
|
| 186 |
+
else None,
|
| 187 |
+
)
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
@dataclasses.dataclass
|
| 191 |
+
class ThreadFlowLocation:
|
| 192 |
+
"""Records code location and the initial state."""
|
| 193 |
+
|
| 194 |
+
location: Location
|
| 195 |
+
state: Mapping[str, str]
|
| 196 |
+
index: int
|
| 197 |
+
stack: Stack | None = None
|
| 198 |
+
|
| 199 |
+
def sarif(self) -> sarif.ThreadFlowLocation:
|
| 200 |
+
"""Returns the SARIF representation of this thread flow location."""
|
| 201 |
+
return sarif.ThreadFlowLocation(
|
| 202 |
+
location=self.location.sarif(),
|
| 203 |
+
state=self.state,
|
| 204 |
+
stack=self.stack.sarif() if self.stack is not None else None,
|
| 205 |
+
)
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
@dataclasses.dataclass
|
| 209 |
+
class Graph:
|
| 210 |
+
"""A graph of diagnostics.
|
| 211 |
+
|
| 212 |
+
This class stores the string representation of a model graph.
|
| 213 |
+
The `nodes` and `edges` fields are unused in the current implementation.
|
| 214 |
+
"""
|
| 215 |
+
|
| 216 |
+
graph: str
|
| 217 |
+
name: str
|
| 218 |
+
description: str | None = None
|
| 219 |
+
|
| 220 |
+
def sarif(self) -> sarif.Graph:
|
| 221 |
+
"""Returns the SARIF representation of this graph."""
|
| 222 |
+
return sarif.Graph(
|
| 223 |
+
description=sarif.Message(text=self.graph),
|
| 224 |
+
properties=PatchedPropertyBag(name=self.name, description=self.description),
|
| 225 |
+
)
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
@dataclasses.dataclass
|
| 229 |
+
class RuleCollection:
|
| 230 |
+
_rule_id_name_set: frozenset[tuple[str, str]] = dataclasses.field(init=False)
|
| 231 |
+
|
| 232 |
+
def __post_init__(self) -> None:
|
| 233 |
+
self._rule_id_name_set = frozenset(
|
| 234 |
+
{
|
| 235 |
+
(field.default.id, field.default.name)
|
| 236 |
+
for field in dataclasses.fields(self)
|
| 237 |
+
if isinstance(field.default, Rule)
|
| 238 |
+
}
|
| 239 |
+
)
|
| 240 |
+
|
| 241 |
+
def __contains__(self, rule: Rule) -> bool:
|
| 242 |
+
"""Checks if the rule is in the collection."""
|
| 243 |
+
return (rule.id, rule.name) in self._rule_id_name_set
|
| 244 |
+
|
| 245 |
+
@classmethod
|
| 246 |
+
def custom_collection_from_list(
|
| 247 |
+
cls, new_collection_class_name: str, rules: Sequence[Rule]
|
| 248 |
+
) -> RuleCollection:
|
| 249 |
+
"""Creates a custom class inherited from RuleCollection with the list of rules."""
|
| 250 |
+
return dataclasses.make_dataclass(
|
| 251 |
+
new_collection_class_name,
|
| 252 |
+
[
|
| 253 |
+
(
|
| 254 |
+
formatter.kebab_case_to_snake_case(rule.name),
|
| 255 |
+
type(rule),
|
| 256 |
+
dataclasses.field(default=rule),
|
| 257 |
+
)
|
| 258 |
+
for rule in rules
|
| 259 |
+
],
|
| 260 |
+
bases=(cls,),
|
| 261 |
+
)()
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
class Invocation:
|
| 265 |
+
# TODO: Implement this.
|
| 266 |
+
# Tracks top level call arguments and diagnostic options.
|
| 267 |
+
def __init__(self) -> None:
|
| 268 |
+
raise NotImplementedError
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
@dataclasses.dataclass
|
| 272 |
+
class DiagnosticOptions:
|
| 273 |
+
"""Options for diagnostic context.
|
| 274 |
+
|
| 275 |
+
Attributes:
|
| 276 |
+
verbosity_level: Set the amount of information logged for each diagnostics,
|
| 277 |
+
equivalent to the 'level' in Python logging module.
|
| 278 |
+
warnings_as_errors: When True, warning diagnostics are treated as error diagnostics.
|
| 279 |
+
"""
|
| 280 |
+
|
| 281 |
+
verbosity_level: int = dataclasses.field(default=logging.INFO)
|
| 282 |
+
"""Set the amount of information logged for each diagnostics, equivalent to the 'level' in Python logging module."""
|
| 283 |
+
|
| 284 |
+
warnings_as_errors: bool = dataclasses.field(default=False)
|
| 285 |
+
"""If True, warning diagnostics are treated as error diagnostics."""
|
.venv/Lib/site-packages/torch/onnx/_internal/diagnostics/infra/context.py
ADDED
|
@@ -0,0 +1,404 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
"""A diagnostic context based on SARIF."""
|
| 3 |
+
|
| 4 |
+
from __future__ import annotations
|
| 5 |
+
|
| 6 |
+
import contextlib
|
| 7 |
+
import dataclasses
|
| 8 |
+
import gzip
|
| 9 |
+
import logging
|
| 10 |
+
from typing import Callable, Generator, Generic, Literal, Mapping, TypeVar
|
| 11 |
+
from typing_extensions import Self
|
| 12 |
+
|
| 13 |
+
from torch.onnx._internal.diagnostics import infra
|
| 14 |
+
from torch.onnx._internal.diagnostics.infra import formatter, sarif, utils
|
| 15 |
+
from torch.onnx._internal.diagnostics.infra.sarif import version as sarif_version
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
# This is a workaround for mypy not supporting Self from typing_extensions.
|
| 19 |
+
_Diagnostic = TypeVar("_Diagnostic", bound="Diagnostic")
|
| 20 |
+
diagnostic_logger: logging.Logger = logging.getLogger(__name__)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
@dataclasses.dataclass
|
| 24 |
+
class Diagnostic:
|
| 25 |
+
rule: infra.Rule
|
| 26 |
+
level: infra.Level
|
| 27 |
+
message: str | None = None
|
| 28 |
+
locations: list[infra.Location] = dataclasses.field(default_factory=list)
|
| 29 |
+
stacks: list[infra.Stack] = dataclasses.field(default_factory=list)
|
| 30 |
+
graphs: list[infra.Graph] = dataclasses.field(default_factory=list)
|
| 31 |
+
thread_flow_locations: list[infra.ThreadFlowLocation] = dataclasses.field(
|
| 32 |
+
default_factory=list
|
| 33 |
+
)
|
| 34 |
+
additional_messages: list[str] = dataclasses.field(default_factory=list)
|
| 35 |
+
tags: list[infra.Tag] = dataclasses.field(default_factory=list)
|
| 36 |
+
source_exception: Exception | None = None
|
| 37 |
+
"""The exception that caused this diagnostic to be created."""
|
| 38 |
+
logger: logging.Logger = dataclasses.field(init=False, default=diagnostic_logger)
|
| 39 |
+
"""The logger for this diagnostic. Defaults to 'diagnostic_logger' which has the same
|
| 40 |
+
log level setting with `DiagnosticOptions.verbosity_level`."""
|
| 41 |
+
_current_log_section_depth: int = 0
|
| 42 |
+
|
| 43 |
+
def __post_init__(self) -> None:
|
| 44 |
+
pass
|
| 45 |
+
|
| 46 |
+
def sarif(self) -> sarif.Result:
|
| 47 |
+
"""Returns the SARIF Result representation of this diagnostic."""
|
| 48 |
+
message = self.message or self.rule.message_default_template
|
| 49 |
+
if self.additional_messages:
|
| 50 |
+
additional_message = "\n".join(self.additional_messages)
|
| 51 |
+
message_markdown = (
|
| 52 |
+
f"{message}\n\n## Additional Message:\n\n{additional_message}"
|
| 53 |
+
)
|
| 54 |
+
else:
|
| 55 |
+
message_markdown = message
|
| 56 |
+
|
| 57 |
+
kind: Literal["informational", "fail"] = (
|
| 58 |
+
"informational" if self.level == infra.Level.NONE else "fail"
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
sarif_result = sarif.Result(
|
| 62 |
+
message=sarif.Message(text=message, markdown=message_markdown),
|
| 63 |
+
level=self.level.name.lower(), # type: ignore[arg-type]
|
| 64 |
+
rule_id=self.rule.id,
|
| 65 |
+
kind=kind,
|
| 66 |
+
)
|
| 67 |
+
sarif_result.locations = [location.sarif() for location in self.locations]
|
| 68 |
+
sarif_result.stacks = [stack.sarif() for stack in self.stacks]
|
| 69 |
+
sarif_result.graphs = [graph.sarif() for graph in self.graphs]
|
| 70 |
+
sarif_result.code_flows = [
|
| 71 |
+
sarif.CodeFlow(
|
| 72 |
+
thread_flows=[
|
| 73 |
+
sarif.ThreadFlow(
|
| 74 |
+
locations=[loc.sarif() for loc in self.thread_flow_locations]
|
| 75 |
+
)
|
| 76 |
+
]
|
| 77 |
+
)
|
| 78 |
+
]
|
| 79 |
+
sarif_result.properties = sarif.PropertyBag(
|
| 80 |
+
tags=[tag.value for tag in self.tags]
|
| 81 |
+
)
|
| 82 |
+
return sarif_result
|
| 83 |
+
|
| 84 |
+
def with_location(self: Self, location: infra.Location) -> Self:
|
| 85 |
+
"""Adds a location to the diagnostic."""
|
| 86 |
+
self.locations.append(location)
|
| 87 |
+
return self
|
| 88 |
+
|
| 89 |
+
def with_thread_flow_location(
|
| 90 |
+
self: Self, location: infra.ThreadFlowLocation
|
| 91 |
+
) -> Self:
|
| 92 |
+
"""Adds a thread flow location to the diagnostic."""
|
| 93 |
+
self.thread_flow_locations.append(location)
|
| 94 |
+
return self
|
| 95 |
+
|
| 96 |
+
def with_stack(self: Self, stack: infra.Stack) -> Self:
|
| 97 |
+
"""Adds a stack to the diagnostic."""
|
| 98 |
+
self.stacks.append(stack)
|
| 99 |
+
return self
|
| 100 |
+
|
| 101 |
+
def with_graph(self: Self, graph: infra.Graph) -> Self:
|
| 102 |
+
"""Adds a graph to the diagnostic."""
|
| 103 |
+
self.graphs.append(graph)
|
| 104 |
+
return self
|
| 105 |
+
|
| 106 |
+
@contextlib.contextmanager
|
| 107 |
+
def log_section(
|
| 108 |
+
self, level: int, message: str, *args, **kwargs
|
| 109 |
+
) -> Generator[None, None, None]:
|
| 110 |
+
"""
|
| 111 |
+
Context manager for a section of log messages, denoted by a title message and increased indentation.
|
| 112 |
+
|
| 113 |
+
Same api as `logging.Logger.log`.
|
| 114 |
+
|
| 115 |
+
This context manager logs the given title at the specified log level, increases the current
|
| 116 |
+
section depth for subsequent log messages, and ensures that the section depth is decreased
|
| 117 |
+
again when exiting the context.
|
| 118 |
+
|
| 119 |
+
Args:
|
| 120 |
+
level: The log level.
|
| 121 |
+
message: The title message to log.
|
| 122 |
+
*args: The arguments to the message. Use `LazyString` to defer the
|
| 123 |
+
expensive evaluation of the arguments until the message is actually logged.
|
| 124 |
+
**kwargs: The keyword arguments for `logging.Logger.log`.
|
| 125 |
+
|
| 126 |
+
Yields:
|
| 127 |
+
None: This context manager does not yield any value.
|
| 128 |
+
|
| 129 |
+
Example:
|
| 130 |
+
>>> with DiagnosticContext("DummyContext", "1.0"):
|
| 131 |
+
... rule = infra.Rule("RuleID", "DummyRule", "Rule message")
|
| 132 |
+
... diagnostic = Diagnostic(rule, infra.Level.WARNING)
|
| 133 |
+
... with diagnostic.log_section(logging.INFO, "My Section"):
|
| 134 |
+
... diagnostic.log(logging.INFO, "My Message")
|
| 135 |
+
... with diagnostic.log_section(logging.INFO, "My Subsection"):
|
| 136 |
+
... diagnostic.log(logging.INFO, "My Submessage")
|
| 137 |
+
... diagnostic.additional_messages
|
| 138 |
+
['## My Section', 'My Message', '### My Subsection', 'My Submessage']
|
| 139 |
+
"""
|
| 140 |
+
if self.logger.isEnabledFor(level):
|
| 141 |
+
indented_format_message = (
|
| 142 |
+
f"##{'#' * self._current_log_section_depth } {message}"
|
| 143 |
+
)
|
| 144 |
+
self.log(
|
| 145 |
+
level,
|
| 146 |
+
indented_format_message,
|
| 147 |
+
*args,
|
| 148 |
+
**kwargs,
|
| 149 |
+
)
|
| 150 |
+
self._current_log_section_depth += 1
|
| 151 |
+
try:
|
| 152 |
+
yield
|
| 153 |
+
finally:
|
| 154 |
+
self._current_log_section_depth -= 1
|
| 155 |
+
|
| 156 |
+
def log(self, level: int, message: str, *args, **kwargs) -> None:
|
| 157 |
+
"""Logs a message within the diagnostic. Same api as `logging.Logger.log`.
|
| 158 |
+
|
| 159 |
+
If logger is not enabled for the given level, the message will not be logged.
|
| 160 |
+
Otherwise, the message will be logged and also added to the diagnostic's additional_messages.
|
| 161 |
+
|
| 162 |
+
The default setting for `DiagnosticOptions.verbosity_level` is `logging.INFO`. Based on this default,
|
| 163 |
+
the log level recommendations are as follows. If you've set a different default verbosity level in your
|
| 164 |
+
application, please adjust accordingly:
|
| 165 |
+
|
| 166 |
+
- logging.ERROR: Log any events leading to application failure.
|
| 167 |
+
- logging.WARNING: Log events that might result in application issues or failures, although not guaranteed.
|
| 168 |
+
- logging.INFO: Log general useful information, ensuring minimal performance overhead.
|
| 169 |
+
- logging.DEBUG: Log detailed debug information, which might affect performance when logged.
|
| 170 |
+
|
| 171 |
+
Args:
|
| 172 |
+
level: The log level.
|
| 173 |
+
message: The message to log.
|
| 174 |
+
*args: The arguments to the message. Use `LazyString` to defer the
|
| 175 |
+
expensive evaluation of the arguments until the message is actually logged.
|
| 176 |
+
**kwargs: The keyword arguments for `logging.Logger.log`.
|
| 177 |
+
"""
|
| 178 |
+
if self.logger.isEnabledFor(level):
|
| 179 |
+
formatted_message = message % args
|
| 180 |
+
self.logger.log(level, formatted_message, **kwargs)
|
| 181 |
+
self.additional_messages.append(formatted_message)
|
| 182 |
+
|
| 183 |
+
def debug(self, message: str, *args, **kwargs) -> None:
|
| 184 |
+
"""Logs a debug message within the diagnostic. Same api as logging.Logger.debug.
|
| 185 |
+
|
| 186 |
+
Checkout `log` for more details.
|
| 187 |
+
"""
|
| 188 |
+
self.log(logging.DEBUG, message, *args, **kwargs)
|
| 189 |
+
|
| 190 |
+
def info(self, message: str, *args, **kwargs) -> None:
|
| 191 |
+
"""Logs an info message within the diagnostic. Same api as logging.Logger.info.
|
| 192 |
+
|
| 193 |
+
Checkout `log` for more details.
|
| 194 |
+
"""
|
| 195 |
+
self.log(logging.INFO, message, *args, **kwargs)
|
| 196 |
+
|
| 197 |
+
def warning(self, message: str, *args, **kwargs) -> None:
|
| 198 |
+
"""Logs a warning message within the diagnostic. Same api as logging.Logger.warning.
|
| 199 |
+
|
| 200 |
+
Checkout `log` for more details.
|
| 201 |
+
"""
|
| 202 |
+
self.log(logging.WARNING, message, *args, **kwargs)
|
| 203 |
+
|
| 204 |
+
def error(self, message: str, *args, **kwargs) -> None:
|
| 205 |
+
"""Logs an error message within the diagnostic. Same api as logging.Logger.error.
|
| 206 |
+
|
| 207 |
+
Checkout `log` for more details.
|
| 208 |
+
"""
|
| 209 |
+
self.log(logging.ERROR, message, *args, **kwargs)
|
| 210 |
+
|
| 211 |
+
def log_source_exception(self, level: int, exception: Exception) -> None:
|
| 212 |
+
"""Logs a source exception within the diagnostic.
|
| 213 |
+
|
| 214 |
+
Invokes `log_section` and `log` to log the exception in markdown section format.
|
| 215 |
+
"""
|
| 216 |
+
self.source_exception = exception
|
| 217 |
+
with self.log_section(level, "Exception log"):
|
| 218 |
+
self.log(level, "%s", formatter.lazy_format_exception(exception))
|
| 219 |
+
|
| 220 |
+
def record_python_call_stack(self, frames_to_skip: int) -> infra.Stack:
|
| 221 |
+
"""Records the current Python call stack."""
|
| 222 |
+
frames_to_skip += 1 # Skip this function.
|
| 223 |
+
stack = utils.python_call_stack(frames_to_skip=frames_to_skip)
|
| 224 |
+
self.with_stack(stack)
|
| 225 |
+
if len(stack.frames) > 0:
|
| 226 |
+
self.with_location(stack.frames[0].location)
|
| 227 |
+
return stack
|
| 228 |
+
|
| 229 |
+
def record_python_call(
|
| 230 |
+
self,
|
| 231 |
+
fn: Callable,
|
| 232 |
+
state: Mapping[str, str],
|
| 233 |
+
message: str | None = None,
|
| 234 |
+
frames_to_skip: int = 0,
|
| 235 |
+
) -> infra.ThreadFlowLocation:
|
| 236 |
+
"""Records a python call as one thread flow step."""
|
| 237 |
+
frames_to_skip += 1 # Skip this function.
|
| 238 |
+
stack = utils.python_call_stack(frames_to_skip=frames_to_skip, frames_to_log=5)
|
| 239 |
+
location = utils.function_location(fn)
|
| 240 |
+
location.message = message
|
| 241 |
+
# Add function location to the top of the stack.
|
| 242 |
+
stack.frames.insert(0, infra.StackFrame(location=location))
|
| 243 |
+
thread_flow_location = infra.ThreadFlowLocation(
|
| 244 |
+
location=location,
|
| 245 |
+
state=state,
|
| 246 |
+
index=len(self.thread_flow_locations),
|
| 247 |
+
stack=stack,
|
| 248 |
+
)
|
| 249 |
+
self.with_thread_flow_location(thread_flow_location)
|
| 250 |
+
return thread_flow_location
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
class RuntimeErrorWithDiagnostic(RuntimeError):
|
| 254 |
+
"""Runtime error with enclosed diagnostic information."""
|
| 255 |
+
|
| 256 |
+
def __init__(self, diagnostic: Diagnostic):
|
| 257 |
+
super().__init__(diagnostic.message)
|
| 258 |
+
self.diagnostic = diagnostic
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
@dataclasses.dataclass
|
| 262 |
+
class DiagnosticContext(Generic[_Diagnostic]):
|
| 263 |
+
name: str
|
| 264 |
+
version: str
|
| 265 |
+
options: infra.DiagnosticOptions = dataclasses.field(
|
| 266 |
+
default_factory=infra.DiagnosticOptions
|
| 267 |
+
)
|
| 268 |
+
diagnostics: list[_Diagnostic] = dataclasses.field(init=False, default_factory=list)
|
| 269 |
+
# TODO(bowbao): Implement this.
|
| 270 |
+
# _invocation: infra.Invocation = dataclasses.field(init=False)
|
| 271 |
+
_inflight_diagnostics: list[_Diagnostic] = dataclasses.field(
|
| 272 |
+
init=False, default_factory=list
|
| 273 |
+
)
|
| 274 |
+
_previous_log_level: int = dataclasses.field(init=False, default=logging.WARNING)
|
| 275 |
+
logger: logging.Logger = dataclasses.field(init=False, default=diagnostic_logger)
|
| 276 |
+
_bound_diagnostic_type: type = dataclasses.field(init=False, default=Diagnostic)
|
| 277 |
+
|
| 278 |
+
def __enter__(self):
|
| 279 |
+
self._previous_log_level = self.logger.level
|
| 280 |
+
self.logger.setLevel(self.options.verbosity_level)
|
| 281 |
+
return self
|
| 282 |
+
|
| 283 |
+
def __exit__(self, exc_type, exc_val, exc_tb):
|
| 284 |
+
self.logger.setLevel(self._previous_log_level)
|
| 285 |
+
return None
|
| 286 |
+
|
| 287 |
+
def sarif(self) -> sarif.Run:
|
| 288 |
+
"""Returns the SARIF Run object."""
|
| 289 |
+
unique_rules = {diagnostic.rule for diagnostic in self.diagnostics}
|
| 290 |
+
return sarif.Run(
|
| 291 |
+
sarif.Tool(
|
| 292 |
+
driver=sarif.ToolComponent(
|
| 293 |
+
name=self.name,
|
| 294 |
+
version=self.version,
|
| 295 |
+
rules=[rule.sarif() for rule in unique_rules],
|
| 296 |
+
)
|
| 297 |
+
),
|
| 298 |
+
results=[diagnostic.sarif() for diagnostic in self.diagnostics],
|
| 299 |
+
)
|
| 300 |
+
|
| 301 |
+
def sarif_log(self) -> sarif.SarifLog: # type: ignore[name-defined]
|
| 302 |
+
"""Returns the SARIF Log object."""
|
| 303 |
+
return sarif.SarifLog(
|
| 304 |
+
version=sarif_version.SARIF_VERSION,
|
| 305 |
+
schema_uri=sarif_version.SARIF_SCHEMA_LINK,
|
| 306 |
+
runs=[self.sarif()],
|
| 307 |
+
)
|
| 308 |
+
|
| 309 |
+
def to_json(self) -> str:
|
| 310 |
+
return formatter.sarif_to_json(self.sarif_log())
|
| 311 |
+
|
| 312 |
+
def dump(self, file_path: str, compress: bool = False) -> None:
|
| 313 |
+
"""Dumps the SARIF log to a file."""
|
| 314 |
+
if compress:
|
| 315 |
+
with gzip.open(file_path, "wt") as f:
|
| 316 |
+
f.write(self.to_json())
|
| 317 |
+
else:
|
| 318 |
+
with open(file_path, "w") as f:
|
| 319 |
+
f.write(self.to_json())
|
| 320 |
+
|
| 321 |
+
def log(self, diagnostic: _Diagnostic) -> None:
|
| 322 |
+
"""Logs a diagnostic.
|
| 323 |
+
|
| 324 |
+
This method should be used only after all the necessary information for the diagnostic
|
| 325 |
+
has been collected.
|
| 326 |
+
|
| 327 |
+
Args:
|
| 328 |
+
diagnostic: The diagnostic to add.
|
| 329 |
+
"""
|
| 330 |
+
if not isinstance(diagnostic, self._bound_diagnostic_type):
|
| 331 |
+
raise TypeError(
|
| 332 |
+
f"Expected diagnostic of type {self._bound_diagnostic_type}, got {type(diagnostic)}"
|
| 333 |
+
)
|
| 334 |
+
if self.options.warnings_as_errors and diagnostic.level == infra.Level.WARNING: # type: ignore[attr-defined]
|
| 335 |
+
diagnostic.level = infra.Level.ERROR # type: ignore[attr-defined]
|
| 336 |
+
self.diagnostics.append(diagnostic) # type: ignore[arg-type]
|
| 337 |
+
|
| 338 |
+
def log_and_raise_if_error(self, diagnostic: _Diagnostic) -> None:
|
| 339 |
+
"""Logs a diagnostic and raises an exception if it is an error.
|
| 340 |
+
|
| 341 |
+
Use this method for logging non inflight diagnostics where diagnostic level is not known or
|
| 342 |
+
lower than ERROR. If it is always expected raise, use `log` and explicit
|
| 343 |
+
`raise` instead. Otherwise there is no way to convey the message that it always
|
| 344 |
+
raises to Python intellisense and type checking tools.
|
| 345 |
+
|
| 346 |
+
This method should be used only after all the necessary information for the diagnostic
|
| 347 |
+
has been collected.
|
| 348 |
+
|
| 349 |
+
Args:
|
| 350 |
+
diagnostic: The diagnostic to add.
|
| 351 |
+
"""
|
| 352 |
+
self.log(diagnostic)
|
| 353 |
+
if diagnostic.level == infra.Level.ERROR:
|
| 354 |
+
if diagnostic.source_exception is not None:
|
| 355 |
+
raise diagnostic.source_exception
|
| 356 |
+
raise RuntimeErrorWithDiagnostic(diagnostic)
|
| 357 |
+
|
| 358 |
+
@contextlib.contextmanager
|
| 359 |
+
def add_inflight_diagnostic(
|
| 360 |
+
self, diagnostic: _Diagnostic
|
| 361 |
+
) -> Generator[_Diagnostic, None, None]:
|
| 362 |
+
"""Adds a diagnostic to the context.
|
| 363 |
+
|
| 364 |
+
Use this method to add diagnostics that are not created by the context.
|
| 365 |
+
Args:
|
| 366 |
+
diagnostic: The diagnostic to add.
|
| 367 |
+
"""
|
| 368 |
+
self._inflight_diagnostics.append(diagnostic)
|
| 369 |
+
try:
|
| 370 |
+
yield diagnostic
|
| 371 |
+
finally:
|
| 372 |
+
self._inflight_diagnostics.pop()
|
| 373 |
+
|
| 374 |
+
def push_inflight_diagnostic(self, diagnostic: _Diagnostic) -> None:
|
| 375 |
+
"""Pushes a diagnostic to the inflight diagnostics stack.
|
| 376 |
+
|
| 377 |
+
Args:
|
| 378 |
+
diagnostic: The diagnostic to push.
|
| 379 |
+
|
| 380 |
+
Raises:
|
| 381 |
+
ValueError: If the rule is not supported by the tool.
|
| 382 |
+
"""
|
| 383 |
+
self._inflight_diagnostics.append(diagnostic)
|
| 384 |
+
|
| 385 |
+
def pop_inflight_diagnostic(self) -> _Diagnostic:
|
| 386 |
+
"""Pops the last diagnostic from the inflight diagnostics stack.
|
| 387 |
+
|
| 388 |
+
Returns:
|
| 389 |
+
The popped diagnostic.
|
| 390 |
+
"""
|
| 391 |
+
return self._inflight_diagnostics.pop()
|
| 392 |
+
|
| 393 |
+
def inflight_diagnostic(self, rule: infra.Rule | None = None) -> _Diagnostic:
|
| 394 |
+
if rule is None:
|
| 395 |
+
# TODO(bowbao): Create builtin-rules and create diagnostic using that.
|
| 396 |
+
if len(self._inflight_diagnostics) <= 0:
|
| 397 |
+
raise AssertionError("No inflight diagnostics")
|
| 398 |
+
|
| 399 |
+
return self._inflight_diagnostics[-1]
|
| 400 |
+
else:
|
| 401 |
+
for diagnostic in reversed(self._inflight_diagnostics):
|
| 402 |
+
if diagnostic.rule == rule:
|
| 403 |
+
return diagnostic
|
| 404 |
+
raise AssertionError(f"No inflight diagnostic for rule {rule.name}")
|
.venv/Lib/site-packages/torch/onnx/_internal/diagnostics/infra/decorator.py
ADDED
|
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import functools
|
| 5 |
+
import logging
|
| 6 |
+
import traceback
|
| 7 |
+
from typing import Any, Callable, Dict, Tuple
|
| 8 |
+
|
| 9 |
+
from torch.onnx._internal.diagnostics import infra
|
| 10 |
+
from torch.onnx._internal.diagnostics.infra import formatter, utils
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
MessageFormatterType = Callable[..., str]
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def format_message_in_text(fn: Callable, *args: Any, **kwargs: Any) -> str:
|
| 17 |
+
return f"{formatter.display_name(fn)}. "
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def format_exception_in_markdown(exception: Exception) -> str:
|
| 21 |
+
msg_list = ["### Exception log", "```"]
|
| 22 |
+
msg_list.extend(
|
| 23 |
+
traceback.format_exception(type(exception), exception, exception.__traceback__)
|
| 24 |
+
)
|
| 25 |
+
msg_list.append("```")
|
| 26 |
+
return "\n".join(msg_list)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def format_function_signature_in_markdown(
|
| 30 |
+
fn: Callable,
|
| 31 |
+
args: tuple[Any, ...],
|
| 32 |
+
kwargs: dict[str, Any],
|
| 33 |
+
format_argument: Callable[[Any], str] = formatter.format_argument,
|
| 34 |
+
) -> str:
|
| 35 |
+
msg_list = [f"### Function Signature {formatter.display_name(fn)}"]
|
| 36 |
+
|
| 37 |
+
state = utils.function_state(fn, args, kwargs)
|
| 38 |
+
|
| 39 |
+
for k, v in state.items():
|
| 40 |
+
msg_list.append(f"- {k}: {format_argument(v)}")
|
| 41 |
+
|
| 42 |
+
return "\n".join(msg_list)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def format_return_values_in_markdown(
|
| 46 |
+
return_values: Any,
|
| 47 |
+
format_argument: Callable[[Any], str] = formatter.format_argument,
|
| 48 |
+
) -> str:
|
| 49 |
+
return f"{format_argument(return_values)}"
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
ModifierCallableType = Callable[
|
| 53 |
+
[infra.Diagnostic, Callable, Tuple[Any, ...], Dict[str, Any], Any], None
|
| 54 |
+
]
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def diagnose_call(
|
| 58 |
+
rule: infra.Rule,
|
| 59 |
+
*,
|
| 60 |
+
level: infra.Level = infra.Level.NONE,
|
| 61 |
+
diagnostic_type: type[infra.Diagnostic] = infra.Diagnostic,
|
| 62 |
+
format_argument: Callable[[Any], str] = formatter.format_argument,
|
| 63 |
+
diagnostic_message_formatter: MessageFormatterType = format_message_in_text,
|
| 64 |
+
) -> Callable:
|
| 65 |
+
def decorator(fn):
|
| 66 |
+
@functools.wraps(fn)
|
| 67 |
+
def wrapper(*args, **kwargs):
|
| 68 |
+
common_error_message = "diagnose_call can only be applied to callables"
|
| 69 |
+
if not callable(fn):
|
| 70 |
+
raise AssertionError(
|
| 71 |
+
f"{common_error_message}. Got {type(fn)} instead of callable."
|
| 72 |
+
)
|
| 73 |
+
arg0 = args[0] if len(args) > 0 else None
|
| 74 |
+
if isinstance(ctx := arg0, infra.DiagnosticContext):
|
| 75 |
+
pass
|
| 76 |
+
elif isinstance(
|
| 77 |
+
ctx := getattr(arg0, "diagnostic_context", None),
|
| 78 |
+
infra.DiagnosticContext,
|
| 79 |
+
):
|
| 80 |
+
pass
|
| 81 |
+
else:
|
| 82 |
+
# NOTE: At decorate time, it can't tell if a callable is function or method.
|
| 83 |
+
# Technically both are regarded as function at that time.
|
| 84 |
+
raise AssertionError(
|
| 85 |
+
f"{common_error_message}. For {fn}, "
|
| 86 |
+
f"If it is a function, a DiagnosticContext instance must be present as "
|
| 87 |
+
f"the first argument. "
|
| 88 |
+
f"If it is a method, a DiagnosticContext instance must be present as "
|
| 89 |
+
f"the attribute 'diagnostic_context' of the 'self' argument."
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
diag = diagnostic_type(
|
| 93 |
+
rule,
|
| 94 |
+
level,
|
| 95 |
+
diagnostic_message_formatter(fn, *args, **kwargs),
|
| 96 |
+
)
|
| 97 |
+
|
| 98 |
+
# pop the decorator frame
|
| 99 |
+
# TODO(bowbao): by default diagnostic doesn't have stack.
|
| 100 |
+
# So need to check before doing this. Make the code cleaner.
|
| 101 |
+
# Option: do not capture stack by default in diagnostic initialization.
|
| 102 |
+
stack: infra.Stack | None = None
|
| 103 |
+
if len(diag.stacks) > 0:
|
| 104 |
+
stack = diag.stacks[0]
|
| 105 |
+
stack.frames.pop(0)
|
| 106 |
+
|
| 107 |
+
# set function location
|
| 108 |
+
fn_location = utils.function_location(fn)
|
| 109 |
+
diag.locations.insert(0, fn_location)
|
| 110 |
+
# Add function location to the top of the stack.
|
| 111 |
+
if stack is not None:
|
| 112 |
+
stack.frames.insert(0, infra.StackFrame(location=fn_location))
|
| 113 |
+
|
| 114 |
+
with diag.log_section(logging.INFO, "Function Signature"):
|
| 115 |
+
diag.log(
|
| 116 |
+
logging.INFO,
|
| 117 |
+
"%s",
|
| 118 |
+
formatter.LazyString(
|
| 119 |
+
format_function_signature_in_markdown,
|
| 120 |
+
fn,
|
| 121 |
+
args,
|
| 122 |
+
kwargs,
|
| 123 |
+
format_argument,
|
| 124 |
+
),
|
| 125 |
+
)
|
| 126 |
+
|
| 127 |
+
return_values: Any = None
|
| 128 |
+
with ctx.add_inflight_diagnostic(diag) as diag:
|
| 129 |
+
try:
|
| 130 |
+
return_values = fn(*args, **kwargs)
|
| 131 |
+
with diag.log_section(logging.INFO, "Return values"):
|
| 132 |
+
diag.log(
|
| 133 |
+
logging.INFO,
|
| 134 |
+
"%s",
|
| 135 |
+
formatter.LazyString(
|
| 136 |
+
format_return_values_in_markdown,
|
| 137 |
+
return_values,
|
| 138 |
+
format_argument,
|
| 139 |
+
),
|
| 140 |
+
)
|
| 141 |
+
return return_values
|
| 142 |
+
except Exception as e:
|
| 143 |
+
diag.log_source_exception(logging.ERROR, e)
|
| 144 |
+
diag.level = infra.Level.ERROR
|
| 145 |
+
finally:
|
| 146 |
+
ctx.log_and_raise_if_error(diag)
|
| 147 |
+
|
| 148 |
+
return wrapper
|
| 149 |
+
|
| 150 |
+
return decorator
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
# TODO(bowbao): decorator to report only when failed.
|
.venv/Lib/site-packages/torch/onnx/_internal/diagnostics/infra/formatter.py
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import dataclasses
|
| 4 |
+
import json
|
| 5 |
+
import re
|
| 6 |
+
import traceback
|
| 7 |
+
from typing import Any, Callable, Union
|
| 8 |
+
|
| 9 |
+
from torch._logging import LazyString
|
| 10 |
+
from torch.onnx._internal.diagnostics.infra import sarif
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
# A list of types in the SARIF module to support pretty printing.
|
| 14 |
+
# This is solely for type annotation for the functions below.
|
| 15 |
+
_SarifClass = Union[
|
| 16 |
+
sarif.SarifLog,
|
| 17 |
+
sarif.Run,
|
| 18 |
+
sarif.ReportingDescriptor,
|
| 19 |
+
sarif.Result,
|
| 20 |
+
]
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def lazy_format_exception(exception: Exception) -> LazyString:
|
| 24 |
+
return LazyString(
|
| 25 |
+
lambda: "\n".join(
|
| 26 |
+
(
|
| 27 |
+
"```",
|
| 28 |
+
*traceback.format_exception(
|
| 29 |
+
type(exception), exception, exception.__traceback__
|
| 30 |
+
),
|
| 31 |
+
"```",
|
| 32 |
+
)
|
| 33 |
+
),
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def snake_case_to_camel_case(s: str) -> str:
|
| 38 |
+
splits = s.split("_")
|
| 39 |
+
if len(splits) <= 1:
|
| 40 |
+
return s
|
| 41 |
+
return "".join([splits[0], *map(str.capitalize, splits[1:])])
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def camel_case_to_snake_case(s: str) -> str:
|
| 45 |
+
return re.sub(r"([A-Z])", r"_\1", s).lower()
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def kebab_case_to_snake_case(s: str) -> str:
|
| 49 |
+
return s.replace("-", "_")
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def _convert_key(
|
| 53 |
+
object: dict[str, Any] | Any, convert: Callable[[str], str]
|
| 54 |
+
) -> dict[str, Any] | Any:
|
| 55 |
+
"""Convert and update keys in a dictionary with "convert".
|
| 56 |
+
|
| 57 |
+
Any value that is a dictionary will be recursively updated.
|
| 58 |
+
Any value that is a list will be recursively searched.
|
| 59 |
+
|
| 60 |
+
Args:
|
| 61 |
+
object: The object to update.
|
| 62 |
+
convert: The function to convert the keys, e.g. `kebab_case_to_snake_case`.
|
| 63 |
+
|
| 64 |
+
Returns:
|
| 65 |
+
The updated object.
|
| 66 |
+
"""
|
| 67 |
+
if not isinstance(object, dict):
|
| 68 |
+
return object
|
| 69 |
+
new_dict = {}
|
| 70 |
+
for k, v in object.items():
|
| 71 |
+
new_k = convert(k)
|
| 72 |
+
if isinstance(v, dict):
|
| 73 |
+
new_v = _convert_key(v, convert)
|
| 74 |
+
elif isinstance(v, list):
|
| 75 |
+
new_v = [_convert_key(elem, convert) for elem in v]
|
| 76 |
+
else:
|
| 77 |
+
new_v = v
|
| 78 |
+
if new_v is None:
|
| 79 |
+
# Otherwise unnecessarily bloated sarif log with "null"s.
|
| 80 |
+
continue
|
| 81 |
+
if new_v == -1:
|
| 82 |
+
# WAR: -1 as default value shouldn't be logged into sarif.
|
| 83 |
+
continue
|
| 84 |
+
|
| 85 |
+
new_dict[new_k] = new_v
|
| 86 |
+
|
| 87 |
+
return new_dict
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def sarif_to_json(attr_cls_obj: _SarifClass, indent: str | None = " ") -> str:
|
| 91 |
+
dict = dataclasses.asdict(attr_cls_obj)
|
| 92 |
+
dict = _convert_key(dict, snake_case_to_camel_case)
|
| 93 |
+
return json.dumps(dict, indent=indent, separators=(",", ":"))
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def format_argument(obj: Any) -> str:
|
| 97 |
+
return f"{type(obj)}"
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def display_name(fn: Callable) -> str:
|
| 101 |
+
if hasattr(fn, "__qualname__"):
|
| 102 |
+
return fn.__qualname__
|
| 103 |
+
elif hasattr(fn, "__name__"):
|
| 104 |
+
return fn.__name__
|
| 105 |
+
else:
|
| 106 |
+
return str(fn)
|
.venv/Lib/site-packages/torch/onnx/_internal/diagnostics/infra/utils.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import functools
|
| 4 |
+
import inspect
|
| 5 |
+
import traceback
|
| 6 |
+
from typing import Any, Callable, Mapping, Sequence
|
| 7 |
+
|
| 8 |
+
from torch.onnx._internal.diagnostics.infra import _infra, formatter
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def python_frame(frame: traceback.FrameSummary) -> _infra.StackFrame:
|
| 12 |
+
"""Returns a StackFrame for the given traceback.FrameSummary."""
|
| 13 |
+
snippet = frame.line
|
| 14 |
+
|
| 15 |
+
return _infra.StackFrame(
|
| 16 |
+
location=_infra.Location(
|
| 17 |
+
uri=frame.filename,
|
| 18 |
+
line=frame.lineno,
|
| 19 |
+
snippet=snippet,
|
| 20 |
+
function=frame.name,
|
| 21 |
+
message=snippet,
|
| 22 |
+
)
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def python_call_stack(frames_to_skip: int = 0, frames_to_log: int = 16) -> _infra.Stack:
|
| 27 |
+
"""Returns the current Python call stack."""
|
| 28 |
+
if frames_to_skip < 0:
|
| 29 |
+
raise ValueError("frames_to_skip must be non-negative")
|
| 30 |
+
if frames_to_log < 0:
|
| 31 |
+
raise ValueError("frames_to_log must be non-negative")
|
| 32 |
+
frames_to_skip += 1 # Skip this function.
|
| 33 |
+
stack = _infra.Stack()
|
| 34 |
+
# Frames are returned in order of oldest to newest.
|
| 35 |
+
frames = traceback.extract_stack(limit=frames_to_skip + frames_to_log)
|
| 36 |
+
frames.reverse()
|
| 37 |
+
stack.frames = [python_frame(frame) for frame in frames[frames_to_skip:]]
|
| 38 |
+
stack.message = "Python call stack"
|
| 39 |
+
return stack
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
@functools.lru_cache
|
| 43 |
+
def _function_source_info(fn: Callable) -> tuple[Sequence[str], int, str | None]:
|
| 44 |
+
"""Returns the source lines, line number, and source file path for the given function.
|
| 45 |
+
|
| 46 |
+
Essentially, inspect.getsourcelines() and inspect.getsourcefile() combined.
|
| 47 |
+
Caching is applied to reduce the performance impact of this function.
|
| 48 |
+
"""
|
| 49 |
+
source_lines, lineno = inspect.getsourcelines(fn)
|
| 50 |
+
return source_lines, lineno, inspect.getsourcefile(fn)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def function_location(fn: Callable) -> _infra.Location:
|
| 54 |
+
"""Returns a Location for the given function."""
|
| 55 |
+
source_lines, lineno, uri = _function_source_info(fn)
|
| 56 |
+
snippet = source_lines[0].strip() if len(source_lines) > 0 else "<unknown>"
|
| 57 |
+
return _infra.Location(
|
| 58 |
+
uri=uri,
|
| 59 |
+
line=lineno,
|
| 60 |
+
snippet=snippet,
|
| 61 |
+
message=formatter.display_name(fn),
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def function_state(
|
| 66 |
+
fn: Callable, args: tuple[Any, ...], kwargs: dict[str, Any]
|
| 67 |
+
) -> Mapping[str, Any]:
|
| 68 |
+
bind = inspect.signature(fn).bind(*args, **kwargs)
|
| 69 |
+
return bind.arguments
|
.venv/Lib/site-packages/torch/onnx/_internal/io_adapter.py
ADDED
|
@@ -0,0 +1,641 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
from typing import (
|
| 5 |
+
Any,
|
| 6 |
+
Callable,
|
| 7 |
+
Mapping,
|
| 8 |
+
Protocol,
|
| 9 |
+
runtime_checkable,
|
| 10 |
+
Sequence,
|
| 11 |
+
TYPE_CHECKING,
|
| 12 |
+
)
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
import torch.export as torch_export
|
| 16 |
+
from torch.utils import _pytree as pytree
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
if TYPE_CHECKING:
|
| 20 |
+
import inspect
|
| 21 |
+
|
| 22 |
+
# TODO(bowbao): Add diagnostics for IO adapters.
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@runtime_checkable
|
| 26 |
+
class InputAdaptStep(Protocol):
|
| 27 |
+
"""A protocol that defines a step in the input adapting process.
|
| 28 |
+
|
| 29 |
+
The input adapting process is a sequence of steps that are applied to the
|
| 30 |
+
PyTorch model inputs to transform them into the inputs format expected by the
|
| 31 |
+
exported ONNX model. Each step takes the PyTorch model inputs as arguments and
|
| 32 |
+
returns the transformed inputs.
|
| 33 |
+
|
| 34 |
+
This serves as a base formalized construct for the transformation done to model
|
| 35 |
+
input signature by any individual component in the exporter.
|
| 36 |
+
"""
|
| 37 |
+
|
| 38 |
+
def apply(
|
| 39 |
+
self,
|
| 40 |
+
model_args: Sequence[Any],
|
| 41 |
+
model_kwargs: Mapping[str, Any],
|
| 42 |
+
model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None,
|
| 43 |
+
) -> tuple[Sequence[Any], Mapping[str, Any]]: ...
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class InputAdapter:
|
| 47 |
+
"""A class that adapts the PyTorch model inputs to exported ONNX model inputs format."""
|
| 48 |
+
|
| 49 |
+
def __init__(self, steps: list[InputAdaptStep] | None = None):
|
| 50 |
+
self._steps = steps or []
|
| 51 |
+
|
| 52 |
+
def append_step(self, step: InputAdaptStep) -> None:
|
| 53 |
+
"""Appends a step to the input adapt steps.
|
| 54 |
+
|
| 55 |
+
Args:
|
| 56 |
+
step: The step to append.
|
| 57 |
+
"""
|
| 58 |
+
self._steps.append(step)
|
| 59 |
+
|
| 60 |
+
def apply(
|
| 61 |
+
self,
|
| 62 |
+
*model_args,
|
| 63 |
+
model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None,
|
| 64 |
+
**model_kwargs,
|
| 65 |
+
) -> Sequence[int | float | bool | str | torch.Tensor | torch.dtype | None]:
|
| 66 |
+
"""Converts the PyTorch model inputs to exported ONNX model inputs format.
|
| 67 |
+
|
| 68 |
+
Args:
|
| 69 |
+
model_args: The PyTorch model inputs.
|
| 70 |
+
model: The PyTorch model.
|
| 71 |
+
model_kwargs: The PyTorch model keyword inputs.
|
| 72 |
+
Returns:
|
| 73 |
+
A sequence of tensors converted from PyTorch model inputs.
|
| 74 |
+
"""
|
| 75 |
+
args: Sequence[Any] = model_args
|
| 76 |
+
kwargs: Mapping[str, Any] = model_kwargs
|
| 77 |
+
for step in self._steps:
|
| 78 |
+
args, kwargs = step.apply(args, kwargs, model=model)
|
| 79 |
+
assert not kwargs
|
| 80 |
+
return args
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
@runtime_checkable
|
| 84 |
+
class OutputAdaptStep(Protocol):
|
| 85 |
+
"""A protocol that defines a step in the output adapting process.
|
| 86 |
+
|
| 87 |
+
The output adapting process is a sequence of steps that are applied to the
|
| 88 |
+
PyTorch model outputs to transform them into the outputs format produced by the
|
| 89 |
+
exported ONNX model. Each step takes the PyTorch model outputs as arguments and
|
| 90 |
+
returns the transformed outputs.
|
| 91 |
+
|
| 92 |
+
This serves as a base formalized construct for the transformation done to model
|
| 93 |
+
output signature by any individual component in the exporter.
|
| 94 |
+
"""
|
| 95 |
+
|
| 96 |
+
def apply(
|
| 97 |
+
self,
|
| 98 |
+
model_outputs: Any,
|
| 99 |
+
model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None,
|
| 100 |
+
) -> Any: ...
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
class OutputAdapter:
|
| 104 |
+
"""A class that adapts the PyTorch model outputs to exported ONNX model outputs format."""
|
| 105 |
+
|
| 106 |
+
def __init__(self, steps: list[OutputAdaptStep] | None = None):
|
| 107 |
+
self._steps = steps or []
|
| 108 |
+
|
| 109 |
+
def append_step(self, step: OutputAdaptStep) -> None:
|
| 110 |
+
"""Appends a step to the output format steps.
|
| 111 |
+
|
| 112 |
+
Args:
|
| 113 |
+
step: The step to append.
|
| 114 |
+
"""
|
| 115 |
+
self._steps.append(step)
|
| 116 |
+
|
| 117 |
+
def apply(
|
| 118 |
+
self,
|
| 119 |
+
model_outputs: Any,
|
| 120 |
+
model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None,
|
| 121 |
+
) -> Sequence[torch.Tensor | int | float | bool | str]:
|
| 122 |
+
"""Converts the PyTorch model outputs to exported ONNX model outputs format.
|
| 123 |
+
|
| 124 |
+
Args:
|
| 125 |
+
model_outputs: The PyTorch model outputs.
|
| 126 |
+
model: The PyTorch model.
|
| 127 |
+
|
| 128 |
+
Returns:
|
| 129 |
+
PyTorch model outputs in exported ONNX model outputs format.
|
| 130 |
+
"""
|
| 131 |
+
for step in self._steps:
|
| 132 |
+
model_outputs = step.apply(model_outputs, model=model)
|
| 133 |
+
return model_outputs
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
# TODO: make_fx lose stack info https://github.com/pytorch/pytorch/issues/90276
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def _replace_tuple_with_list(spec: pytree.TreeSpec) -> pytree.TreeSpec:
|
| 140 |
+
_type = list if spec.type == tuple else spec.type
|
| 141 |
+
return pytree.TreeSpec(
|
| 142 |
+
_type, spec.context, list(map(_replace_tuple_with_list, spec.children_specs))
|
| 143 |
+
)
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def _open_top_level_list_if_single_element(spec: pytree.TreeSpec) -> pytree.TreeSpec:
|
| 147 |
+
if spec.type == list and spec.num_children == 1:
|
| 148 |
+
return spec.children_specs[0]
|
| 149 |
+
return spec
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def _assert_identical_pytree_spec(
|
| 153 |
+
spec1: pytree.TreeSpec, spec2: pytree.TreeSpec, error_message: str
|
| 154 |
+
) -> None:
|
| 155 |
+
"""Assert the two `TreeSpec` objects are identical.
|
| 156 |
+
|
| 157 |
+
Args:
|
| 158 |
+
spec1: The first `TreeSpec` object.
|
| 159 |
+
spec2: The second `TreeSpec` object.
|
| 160 |
+
error_message: The error message to raise if the two `TreeSpec` objects are not
|
| 161 |
+
identical.
|
| 162 |
+
|
| 163 |
+
Raises:
|
| 164 |
+
ValueError: If the two `TreeSpec` objects are not identical.
|
| 165 |
+
"""
|
| 166 |
+
# TODO(bowbao): Turn this check into diagnostic. Consider warning instead of error.
|
| 167 |
+
pass_if_any_checks: Sequence[Callable[[], bool]] = [
|
| 168 |
+
lambda: spec1 == spec2,
|
| 169 |
+
# FIXME: Bug in `dynamo.export`. Sometimes outputs returned in 'list' instead of 'tuple'.
|
| 170 |
+
lambda: _replace_tuple_with_list(spec1) == _replace_tuple_with_list(spec2),
|
| 171 |
+
# FIXME: Bug in `dynamo.export`. Sometimes single function return is wrapped in list.
|
| 172 |
+
lambda: _open_top_level_list_if_single_element(spec1) == spec2,
|
| 173 |
+
lambda: spec1 == _open_top_level_list_if_single_element(spec2),
|
| 174 |
+
]
|
| 175 |
+
|
| 176 |
+
if not any(check() for check in pass_if_any_checks):
|
| 177 |
+
raise ValueError(f"{error_message}\nExpect {spec1}.\nActual {spec2}.")
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
class BindInputStep(InputAdaptStep):
|
| 181 |
+
"""Bind the input arguments to the model signature."""
|
| 182 |
+
|
| 183 |
+
def __init__(self, model_signature: inspect.Signature):
|
| 184 |
+
self._model_signature = model_signature
|
| 185 |
+
|
| 186 |
+
def apply(
|
| 187 |
+
self,
|
| 188 |
+
model_args: Sequence[Any],
|
| 189 |
+
model_kwargs: Mapping[str, Any],
|
| 190 |
+
model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None,
|
| 191 |
+
) -> tuple[Sequence[Any], Mapping[str, Any]]:
|
| 192 |
+
"""Bind the input arguments to the model signature.
|
| 193 |
+
|
| 194 |
+
We hope the input kwargs will be mapped to bound.args after binding.
|
| 195 |
+
If not, we will raise an error.
|
| 196 |
+
|
| 197 |
+
Args:
|
| 198 |
+
model_args: The model args.
|
| 199 |
+
model_kwargs: The model kwargs.
|
| 200 |
+
model: The PyTorch model.
|
| 201 |
+
|
| 202 |
+
Returns:
|
| 203 |
+
A tuple of the model args and kwargs. args is always empty.
|
| 204 |
+
|
| 205 |
+
Raises:
|
| 206 |
+
ValueError: If there are keyword-only arguments left after binding args and
|
| 207 |
+
kwargs to model signature.
|
| 208 |
+
"""
|
| 209 |
+
bound = self._model_signature.bind(*model_args, **model_kwargs)
|
| 210 |
+
bound.apply_defaults()
|
| 211 |
+
|
| 212 |
+
# keyword-only arguments are not handled.
|
| 213 |
+
# bound.kwargs only contains keyword-only arguments after calling
|
| 214 |
+
# bind & apply_defaults, so we raise if it's not empty.
|
| 215 |
+
if bound.kwargs:
|
| 216 |
+
raise ValueError("Keyword-only arguments are not supported.")
|
| 217 |
+
return (), bound.arguments
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
class MergeKwargsIntoArgsInputStep(InputAdaptStep):
|
| 221 |
+
"""Merge the input kwargs into the input args."""
|
| 222 |
+
|
| 223 |
+
def apply(
|
| 224 |
+
self,
|
| 225 |
+
model_args: Sequence[Any],
|
| 226 |
+
model_kwargs: Mapping[str, Any],
|
| 227 |
+
model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None,
|
| 228 |
+
) -> tuple[Sequence[Any], Mapping[str, Any]]:
|
| 229 |
+
"""Merge the input kwargs into the input args.
|
| 230 |
+
|
| 231 |
+
Args:
|
| 232 |
+
model_args: The model args.
|
| 233 |
+
model_kwargs: The model kwargs.
|
| 234 |
+
model: The PyTorch model.
|
| 235 |
+
|
| 236 |
+
Returns:
|
| 237 |
+
A tuple of the model args and kwargs. kwargs is always empty.
|
| 238 |
+
"""
|
| 239 |
+
return tuple(model_args) + tuple(model_kwargs.values()), {}
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
class LiftParametersAndBuffersIntoArgsInputStep(InputAdaptStep):
|
| 243 |
+
"""Append parameters and buffers to model's positional argument list."""
|
| 244 |
+
|
| 245 |
+
def __init__(self, inputs: tuple[torch.Tensor, ...]) -> None:
|
| 246 |
+
self.inputs = inputs
|
| 247 |
+
|
| 248 |
+
def apply(
|
| 249 |
+
self,
|
| 250 |
+
model_args: Sequence[Any],
|
| 251 |
+
model_kwargs: Mapping[str, Any],
|
| 252 |
+
model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None,
|
| 253 |
+
) -> tuple[Sequence[Any], Mapping[str, Any]]:
|
| 254 |
+
"""Append model's parameters and buffers into its input.
|
| 255 |
+
|
| 256 |
+
Args:
|
| 257 |
+
model_args: The model args.
|
| 258 |
+
model_kwargs: The model kwargs.
|
| 259 |
+
model: The PyTorch model.
|
| 260 |
+
|
| 261 |
+
Returns:
|
| 262 |
+
A tuple of the model args + appended inputs and kwargs.
|
| 263 |
+
"""
|
| 264 |
+
return (*model_args, *self.inputs), model_kwargs
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
class ConvertComplexToRealRepresentationInputStep(InputAdaptStep):
|
| 268 |
+
"""Convert complex dtype tensors to real representation tensors.
|
| 269 |
+
|
| 270 |
+
ONNX does not support complex dtype tensors. Thus, we convert complex dtype tensors
|
| 271 |
+
to real representation tensors (i.e., float dtype tensors with an extra dimension
|
| 272 |
+
representing the real and imaginary parts of the complex number).
|
| 273 |
+
|
| 274 |
+
"""
|
| 275 |
+
|
| 276 |
+
def apply(
|
| 277 |
+
self,
|
| 278 |
+
model_args: Sequence[Any],
|
| 279 |
+
model_kwargs: Mapping[str, Any],
|
| 280 |
+
model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None,
|
| 281 |
+
) -> tuple[Sequence[Any], Mapping[str, Any]]:
|
| 282 |
+
"""Convert complex tensors to float tensors.
|
| 283 |
+
|
| 284 |
+
Args:
|
| 285 |
+
model_args: The model args.
|
| 286 |
+
model_kwargs: The model kwargs.
|
| 287 |
+
model: The PyTorch model.
|
| 288 |
+
|
| 289 |
+
Returns:
|
| 290 |
+
A tuple of the model args and kwargs.
|
| 291 |
+
"""
|
| 292 |
+
return (
|
| 293 |
+
tuple(
|
| 294 |
+
torch.view_as_real(arg.resolve_conj())
|
| 295 |
+
if isinstance(arg, torch.Tensor) and arg.is_complex()
|
| 296 |
+
else arg
|
| 297 |
+
for arg in model_args
|
| 298 |
+
),
|
| 299 |
+
model_kwargs,
|
| 300 |
+
)
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
class RemoveNoneInputStep(InputAdaptStep):
|
| 304 |
+
"""Remove `None` from arguments.
|
| 305 |
+
|
| 306 |
+
This adapt step assumes ``model_kwargs`` is empty. It also assumes ``model_args``
|
| 307 |
+
is flattened, i.e. it does not check `None` inside nested collections.
|
| 308 |
+
"""
|
| 309 |
+
|
| 310 |
+
def apply(
|
| 311 |
+
self,
|
| 312 |
+
model_args: Sequence[Any],
|
| 313 |
+
model_kwargs: Mapping[str, Any],
|
| 314 |
+
model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None,
|
| 315 |
+
) -> tuple[Sequence[Any], Mapping[str, Any]]:
|
| 316 |
+
"""Remove `None` from arguments.
|
| 317 |
+
|
| 318 |
+
Args:
|
| 319 |
+
model_args: The model args.
|
| 320 |
+
model_kwargs: The model kwargs.
|
| 321 |
+
model: The PyTorch model.
|
| 322 |
+
|
| 323 |
+
Returns:
|
| 324 |
+
A tuple of the model args and kwargs.
|
| 325 |
+
|
| 326 |
+
Raises:
|
| 327 |
+
ValueError: If `model_kwargs` is not empty.
|
| 328 |
+
"""
|
| 329 |
+
assert not model_kwargs
|
| 330 |
+
return tuple(arg for arg in model_args if arg is not None), {}
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
class RemoveNonTensorInputStep(InputAdaptStep):
|
| 334 |
+
"""Remove the non-tensor input arguments.
|
| 335 |
+
|
| 336 |
+
Dynamo does not support non-tensor input arguments (https://github.com/pytorch/pytorch/issues/99534).
|
| 337 |
+
|
| 338 |
+
Specifically, it does put the input into graph with an empty node, but consumed by no ones.
|
| 339 |
+
The concrete value is embedded into the graph as a constant arg of a target node. Meta
|
| 340 |
+
suggests in this case that one should rewrite the model code to make it tensor if the
|
| 341 |
+
input value is supposed to change at runtime. We might need to further investigate
|
| 342 |
+
the feasibility of that suggestion.
|
| 343 |
+
|
| 344 |
+
For example,
|
| 345 |
+
|
| 346 |
+
def func(x, b=1.0):
|
| 347 |
+
y = x + b
|
| 348 |
+
z = y.relu()
|
| 349 |
+
return (y, z)
|
| 350 |
+
|
| 351 |
+
x = torch.randn(1, 1, 2, dtype=torch.float32)
|
| 352 |
+
gm_fun, _ = dynamo.export(func, x, b=8.0, aten_graph=True, tracing_mode="real")
|
| 353 |
+
|
| 354 |
+
# class GraphModule(torch.nn.Module):
|
| 355 |
+
# def forward(self, x, b):
|
| 356 |
+
# arg0: f32[1, 1, 2], arg1, = fx_pytree.tree_flatten_spec(([x, b], {}), self._in_spec)
|
| 357 |
+
# # File: path/to/pytorch/test_constant_input.py:5, code: y = x + b
|
| 358 |
+
# add_tensor: f32[1, 1, 2] = torch.ops.aten.add.Tensor(arg0, 8.0); arg0 = None
|
| 359 |
+
|
| 360 |
+
# # File: path/to/pytorch/test_constant_input.py:6, code: z = y.relu()
|
| 361 |
+
# relu_default: f32[1, 1, 2] = torch.ops.aten.relu.default(add_tensor)
|
| 362 |
+
# return pytree.tree_unflatten([add_tensor, relu_default], self._out_spec)
|
| 363 |
+
|
| 364 |
+
Empty torch.fx.Node input leading to a mismatched number of input with PyTorch, as
|
| 365 |
+
it's ignored in ONNX graph. Thus, we delete the useless input here.
|
| 366 |
+
|
| 367 |
+
"""
|
| 368 |
+
|
| 369 |
+
def apply(
|
| 370 |
+
self,
|
| 371 |
+
model_args: Sequence[Any],
|
| 372 |
+
model_kwargs: Mapping[str, Any],
|
| 373 |
+
model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None,
|
| 374 |
+
) -> tuple[Sequence[Any], Mapping[str, Any]]:
|
| 375 |
+
"""Remove Constant from arguments.
|
| 376 |
+
|
| 377 |
+
Args:
|
| 378 |
+
model_args: The model args.
|
| 379 |
+
model_kwargs: The model kwargs.
|
| 380 |
+
model: The PyTorch model.
|
| 381 |
+
|
| 382 |
+
Returns:
|
| 383 |
+
A tuple of the model args and kwargs.
|
| 384 |
+
|
| 385 |
+
Raises:
|
| 386 |
+
ValueError: If `model_kwargs` is not empty.
|
| 387 |
+
"""
|
| 388 |
+
assert not model_kwargs
|
| 389 |
+
return (
|
| 390 |
+
tuple(
|
| 391 |
+
arg
|
| 392 |
+
for arg in model_args
|
| 393 |
+
if not isinstance(arg, (int, float, bool, str))
|
| 394 |
+
),
|
| 395 |
+
{},
|
| 396 |
+
)
|
| 397 |
+
|
| 398 |
+
|
| 399 |
+
class FlattenInputWithTreeSpecValidationInputStep(InputAdaptStep):
|
| 400 |
+
"""Flatten nested collection types and return a flat list of elements.
|
| 401 |
+
|
| 402 |
+
ONNX can't represent collection types (e.g., dictionary, tuple of tuple of tensor,
|
| 403 |
+
etc).
|
| 404 |
+
|
| 405 |
+
This class stores the `SpecTree` output produced when `adapt` was called the first
|
| 406 |
+
time. It then validates the `SpecTree` output produced from later `adapt` calls.
|
| 407 |
+
"""
|
| 408 |
+
|
| 409 |
+
_spec: pytree.TreeSpec | None = None
|
| 410 |
+
|
| 411 |
+
def apply(
|
| 412 |
+
self,
|
| 413 |
+
model_args: Sequence[Any],
|
| 414 |
+
model_kwargs: Mapping[str, Any],
|
| 415 |
+
model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None,
|
| 416 |
+
) -> tuple[Sequence[Any], Mapping[str, Any]]:
|
| 417 |
+
"""Flatten the model args and kwargs and validate the `SpecTree` output.
|
| 418 |
+
|
| 419 |
+
Args:
|
| 420 |
+
model_args: The model args.
|
| 421 |
+
model_kwargs: The model kwargs.
|
| 422 |
+
model: The PyTorch model.
|
| 423 |
+
|
| 424 |
+
Returns:
|
| 425 |
+
A tuple of the flattened model args and kwargs. The kwargs is empty, because
|
| 426 |
+
they are flattened and merged into the args.
|
| 427 |
+
|
| 428 |
+
Raises:
|
| 429 |
+
ValueError: If the `SpecTree` output produced from the current `model_outputs`
|
| 430 |
+
is not identical to the `SpecTree` output produced from the first
|
| 431 |
+
`model_outputs` that was passed to this method.
|
| 432 |
+
"""
|
| 433 |
+
flattened_args, spec = pytree.tree_flatten((model_args, model_kwargs))
|
| 434 |
+
if self._spec is None:
|
| 435 |
+
self._spec = spec
|
| 436 |
+
else:
|
| 437 |
+
_assert_identical_pytree_spec(
|
| 438 |
+
self._spec,
|
| 439 |
+
spec,
|
| 440 |
+
error_message="Model inputs incompatible with the format that was exported. ",
|
| 441 |
+
)
|
| 442 |
+
return flattened_args, {}
|
| 443 |
+
|
| 444 |
+
|
| 445 |
+
class FlattenOutputStep(OutputAdaptStep):
|
| 446 |
+
"""Flatten nested collection types and return a flat list of elements.
|
| 447 |
+
|
| 448 |
+
ONNX can't represent collection types (e.g., dictionary, tuple of tuple of tensor,
|
| 449 |
+
etc).
|
| 450 |
+
|
| 451 |
+
NOTE: Ideally we would want to use ``FlattenOutputWithTreeSpecValidationOutputStep``, such
|
| 452 |
+
that `SpecTree` can be validate for new model outputs. However, this is not possible
|
| 453 |
+
currently because we never have access to real PyTorch model outputs during export.
|
| 454 |
+
Only traced outputs may be available, but they are not an accurate reflection of the
|
| 455 |
+
original PyTorch model outputs format as they are typically in their own unique format,
|
| 456 |
+
depending on the tracing strategy.
|
| 457 |
+
"""
|
| 458 |
+
|
| 459 |
+
def apply(
|
| 460 |
+
self,
|
| 461 |
+
model_outputs: Any,
|
| 462 |
+
model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None,
|
| 463 |
+
) -> Sequence[Any]:
|
| 464 |
+
"""Flatten the model outputs.
|
| 465 |
+
|
| 466 |
+
Args:
|
| 467 |
+
model_outputs: The model outputs to flatten.
|
| 468 |
+
model: The PyTorch model.
|
| 469 |
+
|
| 470 |
+
Returns:
|
| 471 |
+
A tuple of the flattened model outputs.
|
| 472 |
+
"""
|
| 473 |
+
return pytree.tree_leaves(model_outputs)
|
| 474 |
+
|
| 475 |
+
|
| 476 |
+
class ConvertComplexToRealRepresentationOutputStep(OutputAdaptStep):
|
| 477 |
+
"""Convert complex dtype tensors to real representation tensors.
|
| 478 |
+
|
| 479 |
+
ONNX does not support complex dtype tensors. Thus, we convert complex dtype tensors
|
| 480 |
+
to real representation tensors (i.e., float dtype tensors with an extra dimension
|
| 481 |
+
representing the real and imaginary parts of the complex number).
|
| 482 |
+
|
| 483 |
+
"""
|
| 484 |
+
|
| 485 |
+
def apply(
|
| 486 |
+
self,
|
| 487 |
+
model_outputs: Any,
|
| 488 |
+
model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None,
|
| 489 |
+
) -> Any:
|
| 490 |
+
"""Convert float tensors to complex tensors.
|
| 491 |
+
|
| 492 |
+
Args:
|
| 493 |
+
model_output: The model output.
|
| 494 |
+
model: The PyTorch model.
|
| 495 |
+
|
| 496 |
+
Returns:
|
| 497 |
+
A tuple of the model output.
|
| 498 |
+
"""
|
| 499 |
+
return [
|
| 500 |
+
torch.view_as_real(output.resolve_conj())
|
| 501 |
+
if isinstance(output, torch.Tensor) and torch.is_complex(output)
|
| 502 |
+
else output
|
| 503 |
+
for output in model_outputs
|
| 504 |
+
]
|
| 505 |
+
|
| 506 |
+
|
| 507 |
+
class FlattenOutputWithTreeSpecValidationOutputStep(OutputAdaptStep):
|
| 508 |
+
"""Same as ``FlattenOutputStep``, with additional `TreeSpec` validation.
|
| 509 |
+
|
| 510 |
+
This class stores the `SpecTree` output produced when `adapt` was called the first
|
| 511 |
+
time. It then validates the `SpecTree` output produced from later `adapt` calls.
|
| 512 |
+
"""
|
| 513 |
+
|
| 514 |
+
_spec: pytree.TreeSpec | None = None
|
| 515 |
+
|
| 516 |
+
def apply(
|
| 517 |
+
self,
|
| 518 |
+
model_outputs: Any,
|
| 519 |
+
model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None,
|
| 520 |
+
) -> Sequence[Any]:
|
| 521 |
+
"""Flatten the model outputs and validate the `SpecTree` output.
|
| 522 |
+
|
| 523 |
+
Args:
|
| 524 |
+
model_outputs: The model outputs to flatten.
|
| 525 |
+
model: The PyTorch model.
|
| 526 |
+
|
| 527 |
+
Returns:
|
| 528 |
+
flattened_outputs: The flattened model outputs.
|
| 529 |
+
|
| 530 |
+
Raises:
|
| 531 |
+
ValueError: If the `SpecTree` output produced from the current `model_outputs`
|
| 532 |
+
is not identical to the `SpecTree` output produced from the first
|
| 533 |
+
`model_outputs` that was passed to this method.
|
| 534 |
+
"""
|
| 535 |
+
flattened_outputs, spec = pytree.tree_flatten(model_outputs)
|
| 536 |
+
if self._spec is None:
|
| 537 |
+
self._spec = spec
|
| 538 |
+
else:
|
| 539 |
+
_assert_identical_pytree_spec(
|
| 540 |
+
self._spec,
|
| 541 |
+
spec,
|
| 542 |
+
error_message="Model outputs incompatible with the format that was exported. ",
|
| 543 |
+
)
|
| 544 |
+
return flattened_outputs
|
| 545 |
+
|
| 546 |
+
|
| 547 |
+
class PrependParamsBuffersConstantAotAutogradInputStep(InputAdaptStep):
|
| 548 |
+
"""Prepend model parameters, buffers and constants to the user input.
|
| 549 |
+
|
| 550 |
+
:func:`torch.export.export` lifts model parameters, buffers and constants as model input, thus, they
|
| 551 |
+
must be added to the user input before the model is executed.
|
| 552 |
+
|
| 553 |
+
Args:
|
| 554 |
+
model: The PyTorch model with embedded parameters and buffers.
|
| 555 |
+
"""
|
| 556 |
+
|
| 557 |
+
def apply(
|
| 558 |
+
self,
|
| 559 |
+
model_args: Sequence[Any],
|
| 560 |
+
model_kwargs: Mapping[str, Any],
|
| 561 |
+
model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None,
|
| 562 |
+
) -> tuple[Sequence[Any], Mapping[str, Any]]:
|
| 563 |
+
"""Convert complex tensors to float tensors.
|
| 564 |
+
|
| 565 |
+
Args:
|
| 566 |
+
model_args: The model args.
|
| 567 |
+
model_kwargs: The model kwargs.
|
| 568 |
+
model: The PyTorch model.
|
| 569 |
+
|
| 570 |
+
Returns:
|
| 571 |
+
A tuple of the model args and kwargs.
|
| 572 |
+
"""
|
| 573 |
+
ordered_params = tuple(
|
| 574 |
+
model.state_dict[name] # type: ignore[union-attr,index]
|
| 575 |
+
for name in model.graph_signature.parameters # type: ignore[union-attr]
|
| 576 |
+
)
|
| 577 |
+
non_persistent_buffers = set(model.graph_signature.non_persistent_buffers) # type: ignore[union-attr]
|
| 578 |
+
ordered_buffers = []
|
| 579 |
+
for name in model.graph_signature.buffers: # type: ignore[union-attr]
|
| 580 |
+
if name in non_persistent_buffers:
|
| 581 |
+
ordered_buffers.append(model.constants[name]) # type: ignore[union-attr]
|
| 582 |
+
else:
|
| 583 |
+
ordered_buffers.append(model.state_dict[name]) # type: ignore[union-attr,index]
|
| 584 |
+
ordered_constant_tensors = tuple(
|
| 585 |
+
model.constants[fqn] # type: ignore[union-attr,index]
|
| 586 |
+
for fqn in model.graph_signature.lifted_tensor_constants # type: ignore[union-attr]
|
| 587 |
+
)
|
| 588 |
+
|
| 589 |
+
# NOTE: calling convention is first params, then buffers, then args as user supplied them.
|
| 590 |
+
# See: torch/_functorch/aot_autograd.py#L1034
|
| 591 |
+
updated_args = (
|
| 592 |
+
*ordered_params,
|
| 593 |
+
*ordered_buffers,
|
| 594 |
+
*ordered_constant_tensors,
|
| 595 |
+
*model_args,
|
| 596 |
+
)
|
| 597 |
+
if model_kwargs:
|
| 598 |
+
return MergeKwargsIntoArgsInputStep().apply(
|
| 599 |
+
updated_args, model_kwargs, model=model
|
| 600 |
+
)
|
| 601 |
+
return updated_args, {}
|
| 602 |
+
|
| 603 |
+
|
| 604 |
+
class PrependParamsAndBuffersAotAutogradOutputStep(OutputAdaptStep):
|
| 605 |
+
"""Prepend model's mutated buffers to the user output.
|
| 606 |
+
|
| 607 |
+
:func:`torch.export.export` lifts model's mutated buffers as outputs, thus, they
|
| 608 |
+
must be added to the user output after the model is executed.
|
| 609 |
+
|
| 610 |
+
Args:
|
| 611 |
+
model: The PyTorch model with mutated buffers.
|
| 612 |
+
"""
|
| 613 |
+
|
| 614 |
+
def apply(
|
| 615 |
+
self,
|
| 616 |
+
model_outputs: Any,
|
| 617 |
+
model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None,
|
| 618 |
+
) -> Sequence[Any]:
|
| 619 |
+
"""Flatten the model outputs and validate the `SpecTree` output.
|
| 620 |
+
|
| 621 |
+
Args:
|
| 622 |
+
model_outputs: The model outputs to flatten.
|
| 623 |
+
model: The PyTorch model.
|
| 624 |
+
|
| 625 |
+
Returns:
|
| 626 |
+
flattened_outputs: The flattened model outputs.
|
| 627 |
+
"""
|
| 628 |
+
|
| 629 |
+
assert isinstance(
|
| 630 |
+
model, torch_export.ExportedProgram
|
| 631 |
+
), "'model' must be torch_export.ExportedProgram"
|
| 632 |
+
ordered_buffers = tuple(
|
| 633 |
+
model.state_dict[name]
|
| 634 |
+
if name in model.state_dict
|
| 635 |
+
else model.constants[name]
|
| 636 |
+
for name in model.graph_signature.buffers_to_mutate.values()
|
| 637 |
+
)
|
| 638 |
+
|
| 639 |
+
# NOTE: calling convention is first mutated buffers, then outputs args as model returned them.
|
| 640 |
+
updated_outputs = (*ordered_buffers, *model_outputs)
|
| 641 |
+
return updated_outputs
|
.venv/Lib/site-packages/torch/onnx/_internal/jit_utils.py
ADDED
|
@@ -0,0 +1,373 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
"""Utilities for manipulating the torch.Graph object and the torchscript."""
|
| 3 |
+
|
| 4 |
+
# TODO(justinchuby): Move more of the symbolic helper functions here and expose
|
| 5 |
+
# them to the user.
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
import dataclasses
|
| 10 |
+
import re
|
| 11 |
+
import typing
|
| 12 |
+
from typing import Any, Iterable, Sequence
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
from torch import _C
|
| 16 |
+
from torch.onnx._globals import GLOBALS
|
| 17 |
+
from torch.onnx._internal import registration
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
_ATTR_PATTERN = re.compile("^(.+)_(([ifstgz])|(ty))$")
|
| 21 |
+
_SKIP_NODE_ATTRIBUTES = {"inplace", "aten"}
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@dataclasses.dataclass
|
| 25 |
+
class GraphContext:
|
| 26 |
+
"""Extra context for symbolic functions with all methods from torch.Graph.
|
| 27 |
+
|
| 28 |
+
NOTE: This class is not meant for external consumption. Please do not depend on
|
| 29 |
+
it outside of torch.onnx as the interface may evolve.
|
| 30 |
+
|
| 31 |
+
Attributes:
|
| 32 |
+
graph: The _C.Graph being constructed.
|
| 33 |
+
block: The current _C.Block being constructed.
|
| 34 |
+
opset: The opset version.
|
| 35 |
+
original_node: Current node that is being converted from.
|
| 36 |
+
params_dict: Mapping from graph initializer name to IValue.
|
| 37 |
+
env: Mapping from Torch domain graph Value to ONNX domain graph Value.
|
| 38 |
+
values_in_env: Set of all values in env, for constant-time lookups.
|
| 39 |
+
new_nodes: List that tracks all new nodes that are added (used to make
|
| 40 |
+
sure metadata is propagated to all new nodes).
|
| 41 |
+
"""
|
| 42 |
+
|
| 43 |
+
graph: _C.Graph
|
| 44 |
+
block: _C.Block
|
| 45 |
+
opset: int
|
| 46 |
+
original_node: _C.Node
|
| 47 |
+
params_dict: dict[str, _C.IValue]
|
| 48 |
+
env: dict[_C.Value, _C.Value]
|
| 49 |
+
values_in_env: set[_C.Value]
|
| 50 |
+
new_nodes: list[_C.Node] = dataclasses.field(default_factory=list)
|
| 51 |
+
|
| 52 |
+
# Relay methods from _C.Graph for compatibility with symbolic functions that expect
|
| 53 |
+
# a _C.Graph
|
| 54 |
+
def __getattr__(self, name: str) -> Any:
|
| 55 |
+
return getattr(self.graph, name)
|
| 56 |
+
|
| 57 |
+
def op(
|
| 58 |
+
self,
|
| 59 |
+
opname: str,
|
| 60 |
+
*raw_args: torch.Tensor | _C.Value,
|
| 61 |
+
outputs: int = 1,
|
| 62 |
+
**kwargs,
|
| 63 |
+
):
|
| 64 |
+
"""Creates an ONNX operator "opname", taking "raw_args" as inputs and "kwargs" as attributes.
|
| 65 |
+
|
| 66 |
+
The set of operators and the inputs/attributes they take
|
| 67 |
+
is documented at https://github.com/onnx/onnx/blob/master/docs/Operators.md
|
| 68 |
+
|
| 69 |
+
Args:
|
| 70 |
+
opname: The ONNX operator name, e.g., `Abs` or `Add`, or an operator qualified
|
| 71 |
+
with a namespace, e.g., `aten::add`.
|
| 72 |
+
raw_args: The inputs to the operator; usually provided
|
| 73 |
+
as arguments to the `symbolic` definition.
|
| 74 |
+
outputs: The number of outputs this operator returns.
|
| 75 |
+
By default an operator is assumed to return a single output.
|
| 76 |
+
If `outputs` is greater than one, this functions returns a tuple
|
| 77 |
+
of output `Value`, representing each output of the ONNX operator
|
| 78 |
+
in order.
|
| 79 |
+
kwargs: The attributes of the ONNX operator, whose keys are named
|
| 80 |
+
according to the following convention: `alpha_f` indicates
|
| 81 |
+
the `alpha` attribute with type `f`. The valid type specifiers are
|
| 82 |
+
`f` (float), `i` (int), `s` (string) or `t` (Tensor). An attribute
|
| 83 |
+
specified with type float accepts either a single float, or a
|
| 84 |
+
list of floats (e.g., you would say `dims_i` for a `dims` attribute
|
| 85 |
+
that takes a list of integers).
|
| 86 |
+
|
| 87 |
+
Returns:
|
| 88 |
+
The value representing the single output of this operator (see the `outputs`
|
| 89 |
+
keyword argument for multi-return nodes).
|
| 90 |
+
"""
|
| 91 |
+
# FIXME(justinchuby): Add the return type back once we know how to handle mypy
|
| 92 |
+
return _add_op(self, opname, *raw_args, outputs=outputs, **kwargs)
|
| 93 |
+
|
| 94 |
+
def aten_op(self, operator: str, *args, overload_name: str = "", **kwargs):
|
| 95 |
+
"""Generates an ONNX ATen op node.
|
| 96 |
+
|
| 97 |
+
This function is for backward compatibility with the old symbolic functions.
|
| 98 |
+
"""
|
| 99 |
+
return self.op(
|
| 100 |
+
"aten::ATen",
|
| 101 |
+
*args,
|
| 102 |
+
operator_s=operator,
|
| 103 |
+
overload_name_s=overload_name,
|
| 104 |
+
**kwargs,
|
| 105 |
+
)
|
| 106 |
+
|
| 107 |
+
# NOTE: For backward compatibility with the old symbolic functions.
|
| 108 |
+
# We are probably going to remove this only after the fx exporter is established.
|
| 109 |
+
at = aten_op
|
| 110 |
+
|
| 111 |
+
def onnxscript_op(
|
| 112 |
+
self,
|
| 113 |
+
onnx_fn,
|
| 114 |
+
*raw_args: torch.Tensor | _C.Value,
|
| 115 |
+
outputs: int = 1,
|
| 116 |
+
**kwargs,
|
| 117 |
+
):
|
| 118 |
+
"""Creates an ONNX operator from onnx-script function, taking "raw_args" as inputs and "kwargs" as attributes.
|
| 119 |
+
|
| 120 |
+
onnx-script repository: https://github.com/microsoft/onnx-script
|
| 121 |
+
|
| 122 |
+
Args:
|
| 123 |
+
onnx_fn: ONNXFunction from onnx-script; An example can be found at
|
| 124 |
+
https://github.com/microsoft/onnx-script#example
|
| 125 |
+
raw_args: The inputs to the operator; usually provided
|
| 126 |
+
as arguments to the `symbolic` definition.
|
| 127 |
+
outputs: The number of outputs this operator returns.
|
| 128 |
+
By default an operator is assumed to return a single output.
|
| 129 |
+
If `outputs` is greater than one, this functions returns a tuple
|
| 130 |
+
of output `Value`, representing each output of the ONNX operator
|
| 131 |
+
in order.
|
| 132 |
+
kwargs: The attributes of the ONNX operator, whose keys are named
|
| 133 |
+
according to the following convention: `alpha_f` indicates
|
| 134 |
+
the `alpha` attribute with type `f`. The valid type specifiers are
|
| 135 |
+
`f` (float), `i` (int), `s` (string) or `t` (Tensor). An attribute
|
| 136 |
+
specified with type float accepts either a single float, or a
|
| 137 |
+
list of floats (e.g., you would say `dims_i` for a `dims` attribute
|
| 138 |
+
that takes a list of integers).
|
| 139 |
+
|
| 140 |
+
Returns:
|
| 141 |
+
The value representing the single output of this operator (see the `outputs`
|
| 142 |
+
keyword argument for multi-return nodes).
|
| 143 |
+
"""
|
| 144 |
+
# NOTE(titaiwang): This is using class attributes, and it needs to be updated
|
| 145 |
+
# if onnx-script makes any change on these.
|
| 146 |
+
symbolic_name = f"{onnx_fn.opset.domain}::{onnx_fn.name}"
|
| 147 |
+
opset_version = onnx_fn.opset.version
|
| 148 |
+
|
| 149 |
+
registration.custom_onnx_symbolic(symbolic_name, opset_version)(onnx_fn)
|
| 150 |
+
|
| 151 |
+
return _add_op(self, symbolic_name, *raw_args, outputs=outputs, **kwargs)
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def add_op_with_blocks(
|
| 155 |
+
graph_context: GraphContext,
|
| 156 |
+
opname: str,
|
| 157 |
+
*inputs: _C.Value,
|
| 158 |
+
outputs: int = 1,
|
| 159 |
+
n_blocks: int = 1,
|
| 160 |
+
**attributes,
|
| 161 |
+
) -> tuple[Any, tuple[GraphContext, ...], _C.Node]:
|
| 162 |
+
"""Creates an ONNX operator "opname", taking inputs and attributes.
|
| 163 |
+
|
| 164 |
+
Args:
|
| 165 |
+
graph_context: The context for the current graph.
|
| 166 |
+
opname: The ONNX operator name, e.g., `Abs` or `Add`, or an operator qualified
|
| 167 |
+
with a namespace, e.g., `aten::add`.
|
| 168 |
+
inputs: The inputs to the operator.
|
| 169 |
+
outputs: The number of outputs this operator returns.
|
| 170 |
+
By default an operator is assumed to return a single output.
|
| 171 |
+
If `outputs` is greater than one, this functions returns a tuple
|
| 172 |
+
of output `Value`, representing each output of the ONNX operator
|
| 173 |
+
in order.
|
| 174 |
+
n_blocks: The number of sub-blocks to create in the node.
|
| 175 |
+
attributes: The attributes of the ONNX operator.
|
| 176 |
+
|
| 177 |
+
Returns:
|
| 178 |
+
A tuple of (output_values, new_contexts, node) where:
|
| 179 |
+
output_values: One or more output value of this operator
|
| 180 |
+
(see the `outputs` keyword argument for multi-return nodes).
|
| 181 |
+
new_contexts: A tuple of new graph contexts for each sub-block.
|
| 182 |
+
node: The node representing the operator.
|
| 183 |
+
"""
|
| 184 |
+
|
| 185 |
+
output_values = graph_context.op(opname, *inputs, outputs=outputs, **attributes)
|
| 186 |
+
if isinstance(output_values, Sequence):
|
| 187 |
+
node = output_values[0].node()
|
| 188 |
+
else:
|
| 189 |
+
node = output_values.node()
|
| 190 |
+
|
| 191 |
+
new_contexts = []
|
| 192 |
+
for _ in range(n_blocks):
|
| 193 |
+
new_block = node.addBlock()
|
| 194 |
+
# Create shallow copy of the graph context and update the block
|
| 195 |
+
new_context = dataclasses.replace(graph_context, block=new_block)
|
| 196 |
+
new_contexts.append(new_context)
|
| 197 |
+
|
| 198 |
+
return output_values, tuple(new_contexts), node
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
def _add_op(
|
| 202 |
+
graph_context: GraphContext,
|
| 203 |
+
opname: str,
|
| 204 |
+
*args: torch.Tensor | _C.Value,
|
| 205 |
+
outputs: int = 1,
|
| 206 |
+
**kwargs,
|
| 207 |
+
):
|
| 208 |
+
"""Creates an ONNX operator "opname", taking "args" as inputs and attributes "kwargs".
|
| 209 |
+
|
| 210 |
+
The set of operators and the inputs/attributes they take
|
| 211 |
+
is documented at https://github.com/onnx/onnx/blob/master/docs/Operators.md
|
| 212 |
+
|
| 213 |
+
This function is monkey-patched onto Graph.
|
| 214 |
+
|
| 215 |
+
Args:
|
| 216 |
+
graph_context: The Torch Graph or Block.
|
| 217 |
+
opname: The ONNX operator name, e.g., `Abs` or `Add`, or an operator qualified
|
| 218 |
+
with a namespace, e.g., `aten::add`.
|
| 219 |
+
args: The inputs to the operator; usually provided
|
| 220 |
+
as arguments to the `symbolic` definition.
|
| 221 |
+
outputs: The number of outputs this operator returns.
|
| 222 |
+
By default an operator is assumed to return a single output.
|
| 223 |
+
If `outputs` is greater than one, this functions returns a tuple
|
| 224 |
+
of output `Value`, representing each output of the ONNX operator
|
| 225 |
+
in order.
|
| 226 |
+
kwargs: The attributes of the ONNX operator, whose keys are named
|
| 227 |
+
according to the following convention: `alpha_f` indicates
|
| 228 |
+
the `alpha` attribute with type `f`. The valid type specifiers are
|
| 229 |
+
`f` (float), `i` (int), `s` (string) or `t` (Tensor). An attribute
|
| 230 |
+
specified with type float accepts either a single float, or a
|
| 231 |
+
list of floats (e.g., you would say `dims_i` for a `dims` attribute
|
| 232 |
+
that takes a list of integers).
|
| 233 |
+
|
| 234 |
+
Returns:
|
| 235 |
+
(Union[_C.Value, Tuple[_C.Value, ...]])
|
| 236 |
+
The value representing the single output of this operator (see the `outputs`
|
| 237 |
+
keyword argument for multi-return nodes).
|
| 238 |
+
"""
|
| 239 |
+
inputs = [_const_if_tensor(graph_context, arg) for arg in args]
|
| 240 |
+
# Filter out None attributes, this can be convenient client side because
|
| 241 |
+
# now they can pass through None attributes, and have them not show up
|
| 242 |
+
attributes = {k: v for k, v in kwargs.items() if v is not None}
|
| 243 |
+
|
| 244 |
+
if "::" not in opname:
|
| 245 |
+
opname = "onnx::" + opname
|
| 246 |
+
|
| 247 |
+
node = _create_node(
|
| 248 |
+
graph_context.block,
|
| 249 |
+
opname,
|
| 250 |
+
inputs,
|
| 251 |
+
attributes,
|
| 252 |
+
params_dict=graph_context.params_dict,
|
| 253 |
+
opset_version=graph_context.opset,
|
| 254 |
+
n_outputs=outputs,
|
| 255 |
+
shape_inference=GLOBALS.onnx_shape_inference,
|
| 256 |
+
)
|
| 257 |
+
graph_context.new_nodes.append(node)
|
| 258 |
+
|
| 259 |
+
if outputs == 1:
|
| 260 |
+
return node.output()
|
| 261 |
+
return tuple(node.outputs())
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
def _const_if_tensor(graph_context: GraphContext, arg):
|
| 265 |
+
if arg is None:
|
| 266 |
+
return arg
|
| 267 |
+
if isinstance(arg, _C.Value):
|
| 268 |
+
return arg
|
| 269 |
+
|
| 270 |
+
return _add_op(graph_context, "onnx::Constant", value_z=arg)
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
def _create_node(
|
| 274 |
+
graph_or_block: _C.Graph | _C.Block,
|
| 275 |
+
domain_op: str,
|
| 276 |
+
inputs: Sequence,
|
| 277 |
+
attributes: dict,
|
| 278 |
+
params_dict: dict,
|
| 279 |
+
opset_version: int,
|
| 280 |
+
n_outputs: int,
|
| 281 |
+
shape_inference: bool = True,
|
| 282 |
+
) -> _C.Node:
|
| 283 |
+
"""Creates an node 'domain_op', taking inputs and attributes."""
|
| 284 |
+
if isinstance(graph_or_block, _C.Graph):
|
| 285 |
+
graph = graph_or_block
|
| 286 |
+
node = graph.create(domain_op, inputs, n_outputs)
|
| 287 |
+
node = graph.insertNode(node)
|
| 288 |
+
elif isinstance(graph_or_block, _C.Block):
|
| 289 |
+
block = graph_or_block
|
| 290 |
+
node = block.addNode(domain_op, inputs)
|
| 291 |
+
|
| 292 |
+
# Block does not have create defined, so we need to add outputs manually
|
| 293 |
+
if n_outputs > 1:
|
| 294 |
+
for _ in range(1, n_outputs):
|
| 295 |
+
node.addOutput()
|
| 296 |
+
|
| 297 |
+
node_outputs = tuple(node.outputs()) # type: ignore[possibly-undefined]
|
| 298 |
+
assert len(node_outputs) == n_outputs
|
| 299 |
+
|
| 300 |
+
aten = domain_op.startswith("aten::")
|
| 301 |
+
|
| 302 |
+
# Add all attributes
|
| 303 |
+
for key, value in sorted(attributes.items()):
|
| 304 |
+
if key in _SKIP_NODE_ATTRIBUTES:
|
| 305 |
+
continue
|
| 306 |
+
_add_attribute(node, key, value, aten=aten)
|
| 307 |
+
if shape_inference:
|
| 308 |
+
_C._jit_pass_onnx_node_shape_type_inference(node, params_dict, opset_version)
|
| 309 |
+
return node
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
def _is_onnx_list(value):
|
| 313 |
+
return isinstance(value, Iterable) and not isinstance(
|
| 314 |
+
value, (str, bytes, torch.Tensor)
|
| 315 |
+
)
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
def _scalar(x: torch.Tensor):
|
| 319 |
+
"""Convert a scalar tensor into a Python value."""
|
| 320 |
+
assert x.numel() == 1
|
| 321 |
+
return x[0]
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
def _add_attribute(node: _C.Node, key: str, value: Any, aten: bool):
|
| 325 |
+
r"""Initializes the right attribute based on type of value."""
|
| 326 |
+
m = _ATTR_PATTERN.match(key)
|
| 327 |
+
if m is None:
|
| 328 |
+
raise ValueError(
|
| 329 |
+
f"Invalid attribute specifier '{key}' names "
|
| 330 |
+
"must be suffixed with type, e.g. 'dim_i' or 'dims_i'"
|
| 331 |
+
)
|
| 332 |
+
name, kind = m.group(1), m.group(2)
|
| 333 |
+
if _is_onnx_list(value):
|
| 334 |
+
kind += "s"
|
| 335 |
+
|
| 336 |
+
return getattr(node, f"{kind}_")(name, value)
|
| 337 |
+
|
| 338 |
+
|
| 339 |
+
# TODO: Expose this to user when migrating symbolic helper functions to here.
|
| 340 |
+
def _is_tensor(x: _C.Value) -> bool:
|
| 341 |
+
return x.type().isSubtypeOf(_C.TensorType.get())
|
| 342 |
+
|
| 343 |
+
|
| 344 |
+
def get_device_from_value(value: _C.Value) -> torch.device | None:
|
| 345 |
+
if not _is_tensor(value):
|
| 346 |
+
return None
|
| 347 |
+
tensor_type = typing.cast(_C.TensorType, value.type())
|
| 348 |
+
return tensor_type.device()
|
| 349 |
+
|
| 350 |
+
|
| 351 |
+
def parse_node_kind(kind: str) -> tuple[str, str]:
|
| 352 |
+
"""Parse node kind into domain and Op name."""
|
| 353 |
+
if "::" not in kind:
|
| 354 |
+
raise ValueError(f"Node kind: {kind} is invalid. '::' is not in node kind.")
|
| 355 |
+
domain, opname = kind.split("::", 1)
|
| 356 |
+
if "::" in opname:
|
| 357 |
+
raise ValueError(f"Node kind: {kind} is invalid. '::' should only apear once.")
|
| 358 |
+
return domain, opname
|
| 359 |
+
|
| 360 |
+
|
| 361 |
+
def is_aten(domain: str) -> bool:
|
| 362 |
+
"""Check if the domain is official."""
|
| 363 |
+
return domain == "aten"
|
| 364 |
+
|
| 365 |
+
|
| 366 |
+
def is_prim(domain: str) -> bool:
|
| 367 |
+
"""Check if the domain is official."""
|
| 368 |
+
return domain == "prim"
|
| 369 |
+
|
| 370 |
+
|
| 371 |
+
def is_onnx(domain: str) -> bool:
|
| 372 |
+
"""Check if the domain is official."""
|
| 373 |
+
return domain == "onnx"
|