ayousanz committed on
Commit e8864b3 · verified
Parent: cc55e01

Add files using upload-large-folder tool

This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full list.
Files changed (50)
  1. .venv/Lib/site-packages/torch/nn/__pycache__/__init__.cpython-39.pyc +0 -0
  2. .venv/Lib/site-packages/torch/nn/__pycache__/_reduction.cpython-39.pyc +0 -0
  3. .venv/Lib/site-packages/torch/nn/__pycache__/common_types.cpython-39.pyc +0 -0
  4. .venv/Lib/site-packages/torch/nn/__pycache__/functional.cpython-39.pyc +0 -0
  5. .venv/Lib/site-packages/torch/nn/__pycache__/grad.cpython-39.pyc +0 -0
  6. .venv/Lib/site-packages/torch/nn/__pycache__/init.cpython-39.pyc +0 -0
  7. .venv/Lib/site-packages/torch/nn/__pycache__/parameter.cpython-39.pyc +0 -0
  8. .venv/Lib/site-packages/torch/nn/quantized/__pycache__/__init__.cpython-39.pyc +0 -0
  9. .venv/Lib/site-packages/torch/nn/quantized/__pycache__/functional.cpython-39.pyc +0 -0
  10. .venv/Lib/site-packages/torch/nn/quantized/_reference/__init__.py +1 -0
  11. .venv/Lib/site-packages/torch/nn/quantized/_reference/modules/__init__.py +39 -0
  12. .venv/Lib/site-packages/torch/nn/quantized/_reference/modules/conv.py +21 -0
  13. .venv/Lib/site-packages/torch/nn/quantized/_reference/modules/linear.py +12 -0
  14. .venv/Lib/site-packages/torch/nn/quantized/_reference/modules/rnn.py +19 -0
  15. .venv/Lib/site-packages/torch/nn/quantized/_reference/modules/sparse.py +12 -0
  16. .venv/Lib/site-packages/torch/nn/quantized/_reference/modules/utils.py +18 -0
  17. .venv/Lib/site-packages/torch/nn/utils/__pycache__/__init__.cpython-39.pyc +0 -0
  18. .venv/Lib/site-packages/torch/nn/utils/__pycache__/_named_member_accessor.cpython-39.pyc +0 -0
  19. .venv/Lib/site-packages/torch/nn/utils/__pycache__/clip_grad.cpython-39.pyc +0 -0
  20. .venv/Lib/site-packages/torch/nn/utils/__pycache__/convert_parameters.cpython-39.pyc +0 -0
  21. .venv/Lib/site-packages/torch/nn/utils/__pycache__/fusion.cpython-39.pyc +0 -0
  22. .venv/Lib/site-packages/torch/nn/utils/__pycache__/init.cpython-39.pyc +0 -0
  23. .venv/Lib/site-packages/torch/nn/utils/__pycache__/memory_format.cpython-39.pyc +0 -0
  24. .venv/Lib/site-packages/torch/nn/utils/__pycache__/parametrizations.cpython-39.pyc +0 -0
  25. .venv/Lib/site-packages/torch/nn/utils/__pycache__/parametrize.cpython-39.pyc +0 -0
  26. .venv/Lib/site-packages/torch/nn/utils/__pycache__/rnn.cpython-39.pyc +0 -0
  27. .venv/Lib/site-packages/torch/nn/utils/__pycache__/spectral_norm.cpython-39.pyc +0 -0
  28. .venv/Lib/site-packages/torch/nn/utils/__pycache__/stateless.cpython-39.pyc +0 -0
  29. .venv/Lib/site-packages/torch/nn/utils/__pycache__/weight_norm.cpython-39.pyc +0 -0
  30. .venv/Lib/site-packages/torch/nn/utils/_expanded_weights/__init__.py +10 -0
  31. .venv/Lib/site-packages/torch/nn/utils/clip_grad.py +189 -0
  32. .venv/Lib/site-packages/torch/onnx/__init__.py +553 -0
  33. .venv/Lib/site-packages/torch/onnx/_constants.py +25 -0
  34. .venv/Lib/site-packages/torch/onnx/_deprecation.py +72 -0
  35. .venv/Lib/site-packages/torch/onnx/_experimental.py +27 -0
  36. .venv/Lib/site-packages/torch/onnx/_exporter_states.py +12 -0
  37. .venv/Lib/site-packages/torch/onnx/_flags.py +49 -0
  38. .venv/Lib/site-packages/torch/onnx/_globals.py +87 -0
  39. .venv/Lib/site-packages/torch/onnx/_internal/__init__.py +0 -0
  40. .venv/Lib/site-packages/torch/onnx/_internal/_lazy_import.py +41 -0
  41. .venv/Lib/site-packages/torch/onnx/_internal/diagnostics/__init__.py +22 -0
  42. .venv/Lib/site-packages/torch/onnx/_internal/diagnostics/_diagnostic.py +211 -0
  43. .venv/Lib/site-packages/torch/onnx/_internal/diagnostics/_rules.py +636 -0
  44. .venv/Lib/site-packages/torch/onnx/_internal/diagnostics/infra/_infra.py +285 -0
  45. .venv/Lib/site-packages/torch/onnx/_internal/diagnostics/infra/context.py +404 -0
  46. .venv/Lib/site-packages/torch/onnx/_internal/diagnostics/infra/decorator.py +153 -0
  47. .venv/Lib/site-packages/torch/onnx/_internal/diagnostics/infra/formatter.py +106 -0
  48. .venv/Lib/site-packages/torch/onnx/_internal/diagnostics/infra/utils.py +69 -0
  49. .venv/Lib/site-packages/torch/onnx/_internal/io_adapter.py +641 -0
  50. .venv/Lib/site-packages/torch/onnx/_internal/jit_utils.py +373 -0
.venv/Lib/site-packages/torch/nn/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (2.23 kB)

.venv/Lib/site-packages/torch/nn/__pycache__/_reduction.cpython-39.pyc ADDED
Binary file (1.31 kB)

.venv/Lib/site-packages/torch/nn/__pycache__/common_types.cpython-39.pyc ADDED
Binary file (1.03 kB)

.venv/Lib/site-packages/torch/nn/__pycache__/functional.cpython-39.pyc ADDED
Binary file (182 kB)

.venv/Lib/site-packages/torch/nn/__pycache__/grad.cpython-39.pyc ADDED
Binary file (9.05 kB)

.venv/Lib/site-packages/torch/nn/__pycache__/init.cpython-39.pyc ADDED
Binary file (21.2 kB)

.venv/Lib/site-packages/torch/nn/__pycache__/parameter.cpython-39.pyc ADDED
Binary file (10.8 kB)

.venv/Lib/site-packages/torch/nn/quantized/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (744 Bytes)

.venv/Lib/site-packages/torch/nn/quantized/__pycache__/functional.cpython-39.pyc ADDED
Binary file (459 Bytes)

.venv/Lib/site-packages/torch/nn/quantized/_reference/__init__.py ADDED
@@ -0,0 +1 @@
+ from torch.nn.quantized._reference.modules import *  # noqa: F403
.venv/Lib/site-packages/torch/nn/quantized/_reference/modules/__init__.py ADDED
@@ -0,0 +1,39 @@
+ # flake8: noqa: F401
+ r"""Quantized Reference Modules.
+
+ This module is in the process of migration to
+ `torch/ao/nn/quantized/reference`, and is kept here for
+ compatibility while the migration process is ongoing.
+ If you are adding a new entry/functionality, please, add it to the
+ appropriate file under the `torch/ao/nn/quantized/reference`,
+ while adding an import statement here.
+ """
+
+ from torch.ao.nn.quantized.reference.modules.conv import (
+     Conv1d,
+     Conv2d,
+     Conv3d,
+     ConvTranspose1d,
+     ConvTranspose2d,
+     ConvTranspose3d,
+ )
+ from torch.ao.nn.quantized.reference.modules.linear import Linear
+ from torch.ao.nn.quantized.reference.modules.rnn import GRUCell, LSTM, LSTMCell, RNNCell
+ from torch.ao.nn.quantized.reference.modules.sparse import Embedding, EmbeddingBag
+
+
+ __all__ = [
+     "Linear",
+     "Conv1d",
+     "Conv2d",
+     "Conv3d",
+     "ConvTranspose1d",
+     "ConvTranspose2d",
+     "ConvTranspose3d",
+     "RNNCell",
+     "LSTMCell",
+     "GRUCell",
+     "LSTM",
+     "Embedding",
+     "EmbeddingBag",
+ ]
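
The file above is a pure re-export shim: the legacy `torch.nn.quantized._reference` namespace forwards to `torch.ao.nn.quantized.reference`. A minimal sketch of what that means for callers (assuming a torch build that still ships the legacy path):

    # Both import paths resolve to the same class while the migration
    # shim above is in place.
    from torch.nn.quantized._reference.modules import Linear as legacy_linear
    from torch.ao.nn.quantized.reference.modules.linear import Linear as ao_linear

    assert legacy_linear is ao_linear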
.venv/Lib/site-packages/torch/nn/quantized/_reference/modules/conv.py ADDED
@@ -0,0 +1,21 @@
+ # flake8: noqa: F401
+ r"""Quantized Reference Modules.
+
+ This module is in the process of migration to
+ `torch/ao/nn/quantized/reference`, and is kept here for
+ compatibility while the migration process is ongoing.
+ If you are adding a new entry/functionality, please, add it to the
+ appropriate file under the `torch/ao/nn/quantized/reference`,
+ while adding an import statement here.
+ """
+
+ from torch.ao.nn.quantized.reference.modules.conv import (
+     _ConvNd,
+     _ConvTransposeNd,
+     Conv1d,
+     Conv2d,
+     Conv3d,
+     ConvTranspose1d,
+     ConvTranspose2d,
+     ConvTranspose3d,
+ )
.venv/Lib/site-packages/torch/nn/quantized/_reference/modules/linear.py ADDED
@@ -0,0 +1,12 @@
+ # flake8: noqa: F401
+ r"""Quantized Reference Modules.
+
+ This module is in the process of migration to
+ `torch/ao/nn/quantized/reference`, and is kept here for
+ compatibility while the migration process is ongoing.
+ If you are adding a new entry/functionality, please, add it to the
+ appropriate file under the `torch/ao/nn/quantized/reference`,
+ while adding an import statement here.
+ """
+
+ from torch.ao.nn.quantized.reference.modules.linear import Linear
.venv/Lib/site-packages/torch/nn/quantized/_reference/modules/rnn.py ADDED
@@ -0,0 +1,19 @@
+ # flake8: noqa: F401
+ r"""Quantized Reference Modules.
+
+ This module is in the process of migration to
+ `torch/ao/nn/quantized/reference`, and is kept here for
+ compatibility while the migration process is ongoing.
+ If you are adding a new entry/functionality, please, add it to the
+ appropriate file under the `torch/ao/nn/quantized/reference`,
+ while adding an import statement here.
+ """
+
+ from torch.ao.nn.quantized.reference.modules.rnn import (
+     GRUCell,
+     LSTM,
+     LSTMCell,
+     RNNBase,
+     RNNCell,
+     RNNCellBase,
+ )
.venv/Lib/site-packages/torch/nn/quantized/_reference/modules/sparse.py ADDED
@@ -0,0 +1,12 @@
+ # flake8: noqa: F401
+ r"""Quantized Reference Modules.
+
+ This module is in the process of migration to
+ `torch/ao/nn/quantized/reference`, and is kept here for
+ compatibility while the migration process is ongoing.
+ If you are adding a new entry/functionality, please, add it to the
+ appropriate file under the `torch/ao/nn/quantized/reference`,
+ while adding an import statement here.
+ """
+
+ from torch.ao.nn.quantized.reference.modules.sparse import Embedding, EmbeddingBag
.venv/Lib/site-packages/torch/nn/quantized/_reference/modules/utils.py ADDED
@@ -0,0 +1,18 @@
+ # flake8: noqa: F401
+ r"""Quantized Reference Modules.
+
+ This module is in the process of migration to
+ `torch/ao/nn/quantized/reference`, and is kept here for
+ compatibility while the migration process is ongoing.
+ If you are adding a new entry/functionality, please, add it to the
+ appropriate file under the `torch/ao/nn/quantized/reference`,
+ while adding an import statement here.
+ """
+
+ from torch.ao.nn.quantized.reference.modules.utils import (
+     _get_weight_qparam_keys,
+     _quantize_and_dequantize_weight,
+     _quantize_weight,
+     _save_weight_qparams,
+     ReferenceQuantizedModule,
+ )
.venv/Lib/site-packages/torch/nn/utils/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (1 kB)

.venv/Lib/site-packages/torch/nn/utils/__pycache__/_named_member_accessor.cpython-39.pyc ADDED
Binary file (12.1 kB)

.venv/Lib/site-packages/torch/nn/utils/__pycache__/clip_grad.cpython-39.pyc ADDED
Binary file (6.26 kB)

.venv/Lib/site-packages/torch/nn/utils/__pycache__/convert_parameters.cpython-39.pyc ADDED
Binary file (2.51 kB)

.venv/Lib/site-packages/torch/nn/utils/__pycache__/fusion.cpython-39.pyc ADDED
Binary file (4.97 kB)

.venv/Lib/site-packages/torch/nn/utils/__pycache__/init.cpython-39.pyc ADDED
Binary file (2.38 kB)

.venv/Lib/site-packages/torch/nn/utils/__pycache__/memory_format.cpython-39.pyc ADDED
Binary file (7.57 kB)

.venv/Lib/site-packages/torch/nn/utils/__pycache__/parametrizations.cpython-39.pyc ADDED
Binary file (18 kB)

.venv/Lib/site-packages/torch/nn/utils/__pycache__/parametrize.cpython-39.pyc ADDED
Binary file (23.7 kB)

.venv/Lib/site-packages/torch/nn/utils/__pycache__/rnn.cpython-39.pyc ADDED
Binary file (20.9 kB)

.venv/Lib/site-packages/torch/nn/utils/__pycache__/spectral_norm.cpython-39.pyc ADDED
Binary file (9.94 kB)

.venv/Lib/site-packages/torch/nn/utils/__pycache__/stateless.cpython-39.pyc ADDED
Binary file (9.39 kB)

.venv/Lib/site-packages/torch/nn/utils/__pycache__/weight_norm.cpython-39.pyc ADDED
Binary file (5.9 kB)

.venv/Lib/site-packages/torch/nn/utils/_expanded_weights/__init__.py ADDED
@@ -0,0 +1,10 @@
+ from .conv_expanded_weights import ConvPerSampleGrad
+ from .embedding_expanded_weights import EmbeddingPerSampleGrad
+ from .expanded_weights_impl import ExpandedWeight
+ from .group_norm_expanded_weights import GroupNormPerSampleGrad
+ from .instance_norm_expanded_weights import InstanceNormPerSampleGrad
+ from .layer_norm_expanded_weights import LayerNormPerSampleGrad
+ from .linear_expanded_weights import LinearPerSampleGrad
+
+
+ __all__ = ["ExpandedWeight"]
.venv/Lib/site-packages/torch/nn/utils/clip_grad.py ADDED
@@ -0,0 +1,189 @@
+ # mypy: allow-untyped-decorators
+ # mypy: allow-untyped-defs
+ import functools
+ from typing import cast, Dict, Iterable, List, Optional, Tuple, Union
+ from typing_extensions import deprecated
+
+ import torch
+ from torch import Tensor
+ from torch.utils._foreach_utils import (
+     _device_has_foreach_support,
+     _group_tensors_by_device_and_dtype,
+     _has_foreach_support,
+ )
+
+
+ __all__ = ["clip_grad_norm_", "clip_grad_norm", "clip_grad_value_"]
+
+
+ _tensor_or_tensors = Union[torch.Tensor, Iterable[torch.Tensor]]
+
+
+ def _no_grad(func):
+     """
+     This wrapper is needed to avoid a circular import when using @torch.no_grad on the exposed functions
+     clip_grad_norm_ and clip_grad_value_ themselves.
+     """
+
+     def _no_grad_wrapper(*args, **kwargs):
+         with torch.no_grad():
+             return func(*args, **kwargs)
+
+     functools.update_wrapper(_no_grad_wrapper, func)
+     return _no_grad_wrapper
+
+
+ @_no_grad
+ def clip_grad_norm_(
+     parameters: _tensor_or_tensors,
+     max_norm: float,
+     norm_type: float = 2.0,
+     error_if_nonfinite: bool = False,
+     foreach: Optional[bool] = None,
+ ) -> torch.Tensor:
+     r"""Clip the gradient norm of an iterable of parameters.
+
+     The norm is computed over the norms of the individual gradients of all parameters,
+     as if the norms of the individual gradients were concatenated into a single vector.
+     Gradients are modified in-place.
+
+     Args:
+         parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
+             single Tensor that will have gradients normalized
+         max_norm (float): max norm of the gradients
+         norm_type (float): type of the used p-norm. Can be ``'inf'`` for
+             infinity norm.
+         error_if_nonfinite (bool): if True, an error is thrown if the total
+             norm of the gradients from :attr:`parameters` is ``nan``,
+             ``inf``, or ``-inf``. Default: False (will switch to True in the future)
+         foreach (bool): use the faster foreach-based implementation.
+             If ``None``, use the foreach implementation for CUDA and CPU native tensors and silently
+             fall back to the slow implementation for other device types.
+             Default: ``None``
+
+     Returns:
+         Total norm of the parameter gradients (viewed as a single vector).
+     """
+     if isinstance(parameters, torch.Tensor):
+         parameters = [parameters]
+     grads = [p.grad for p in parameters if p.grad is not None]
+     max_norm = float(max_norm)
+     norm_type = float(norm_type)
+     if len(grads) == 0:
+         return torch.tensor(0.0)
+     first_device = grads[0].device
+     grouped_grads: Dict[
+         Tuple[torch.device, torch.dtype], Tuple[List[List[Tensor]], List[int]]
+     ] = _group_tensors_by_device_and_dtype(
+         [grads]
+     )  # type: ignore[assignment]
+
+     norms: List[Tensor] = []
+     for (device, _), ([device_grads], _) in grouped_grads.items():  # type: ignore[assignment]
+         if (foreach is None and _has_foreach_support(device_grads, device)) or (
+             foreach and _device_has_foreach_support(device)
+         ):
+             norms.extend(torch._foreach_norm(device_grads, norm_type))
+         elif foreach:
+             raise RuntimeError(
+                 f"foreach=True was passed, but can't use the foreach API on {device.type} tensors"
+             )
+         else:
+             norms.extend([torch.linalg.vector_norm(g, norm_type) for g in device_grads])
+
+     total_norm = torch.linalg.vector_norm(
+         torch.stack([norm.to(first_device) for norm in norms]), norm_type
+     )
+
+     if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()):
+         raise RuntimeError(
+             f"The total norm of order {norm_type} for gradients from "
+             "`parameters` is non-finite, so it cannot be clipped. To disable "
+             "this error and scale the gradients by the non-finite norm anyway, "
+             "set `error_if_nonfinite=False`"
+         )
+     clip_coef = max_norm / (total_norm + 1e-6)
+     # Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so
+     # avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization
+     # when the gradients do not reside in CPU memory.
+     clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
+     for (device, _), ([device_grads], _) in grouped_grads.items():  # type: ignore[assignment]
+         if (foreach is None and _has_foreach_support(device_grads, device)) or (
+             foreach and _device_has_foreach_support(device)
+         ):
+             torch._foreach_mul_(device_grads, clip_coef_clamped.to(device))
+         elif foreach:
+             raise RuntimeError(
+                 f"foreach=True was passed, but can't use the foreach API on {device.type} tensors"
+             )
+         else:
+             clip_coef_clamped_device = clip_coef_clamped.to(device)
+             for g in device_grads:
+                 g.mul_(clip_coef_clamped_device)
+
+     return total_norm
+
+
+ @deprecated(
+     "`torch.nn.utils.clip_grad_norm` is now deprecated "
+     "in favor of `torch.nn.utils.clip_grad_norm_`.",
+     category=FutureWarning,
+ )
+ def clip_grad_norm(
+     parameters: _tensor_or_tensors,
+     max_norm: float,
+     norm_type: float = 2.0,
+     error_if_nonfinite: bool = False,
+     foreach: Optional[bool] = None,
+ ) -> torch.Tensor:
+     r"""Clip the gradient norm of an iterable of parameters.
+
+     .. warning::
+         This method is now deprecated in favor of
+         :func:`torch.nn.utils.clip_grad_norm_`.
+     """
+     return clip_grad_norm_(parameters, max_norm, norm_type, error_if_nonfinite, foreach)
+
+
+ @_no_grad
+ def clip_grad_value_(
+     parameters: _tensor_or_tensors,
+     clip_value: float,
+     foreach: Optional[bool] = None,
+ ) -> None:
+     r"""Clip the gradients of an iterable of parameters at specified value.
+
+     Gradients are modified in-place.
+
+     Args:
+         parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
+             single Tensor that will have gradients normalized
+         clip_value (float): maximum allowed value of the gradients.
+             The gradients are clipped in the range
+             :math:`\left[\text{-clip\_value}, \text{clip\_value}\right]`
+         foreach (bool): use the faster foreach-based implementation
+             If ``None``, use the foreach implementation for CUDA and CPU native tensors and
+             silently fall back to the slow implementation for other device types.
+             Default: ``None``
+     """
+     if isinstance(parameters, torch.Tensor):
+         parameters = [parameters]
+     clip_value = float(clip_value)
+
+     grads = [p.grad for p in parameters if p.grad is not None]
+     grouped_grads = _group_tensors_by_device_and_dtype([grads])
+
+     for (device, _), ([grads], _) in grouped_grads.items():  # type: ignore[assignment]
+         if (
+             foreach is None
+             and _has_foreach_support(cast(List[Tensor], grads), device=device)
+         ) or (foreach and _device_has_foreach_support(device)):
+             torch._foreach_clamp_min_(cast(List[Tensor], grads), -clip_value)
+             torch._foreach_clamp_max_(cast(List[Tensor], grads), clip_value)
+         elif foreach:
+             raise RuntimeError(
+                 f"foreach=True was passed, but can't use the foreach API on {device.type} tensors"
+             )
+         else:
+             for grad in grads:
+                 cast(Tensor, grad).clamp_(min=-clip_value, max=clip_value)
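
A minimal usage sketch of the two public entry points added above, `clip_grad_norm_` and `clip_grad_value_` (the model and hyperparameters are illustrative):

    import torch

    model = torch.nn.Linear(10, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    loss = model(torch.randn(4, 10)).sum()
    loss.backward()

    # Rescale all gradients in-place so their combined 2-norm is at most 1.0;
    # the pre-clipping total norm is returned.
    total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

    # Alternatively, clamp each gradient element into [-0.5, 0.5] in-place.
    torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=0.5)

    optimizer.step()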
.venv/Lib/site-packages/torch/onnx/__init__.py ADDED
@@ -0,0 +1,553 @@
+ # mypy: allow-untyped-defs
+ from __future__ import annotations
+
+
+ __all__ = [
+     # Modules
+     "symbolic_helper",
+     "utils",
+     "errors",
+     # All opsets
+     "symbolic_caffe2",
+     "symbolic_opset7",
+     "symbolic_opset8",
+     "symbolic_opset9",
+     "symbolic_opset10",
+     "symbolic_opset11",
+     "symbolic_opset12",
+     "symbolic_opset13",
+     "symbolic_opset14",
+     "symbolic_opset15",
+     "symbolic_opset16",
+     "symbolic_opset17",
+     "symbolic_opset18",
+     "symbolic_opset19",
+     "symbolic_opset20",
+     # Enums
+     "ExportTypes",
+     "OperatorExportTypes",
+     "TrainingMode",
+     "TensorProtoDataType",
+     "JitScalarType",
+     # Public functions
+     "export",
+     "export_to_pretty_string",
+     "is_in_onnx_export",
+     "select_model_mode_for_export",
+     "register_custom_op_symbolic",
+     "unregister_custom_op_symbolic",
+     "disable_log",
+     "enable_log",
+     # Base error
+     "OnnxExporterError",
+     # Dynamo Exporter
+     "DiagnosticOptions",
+     "ExportOptions",
+     "ONNXProgram",
+     "ONNXRuntimeOptions",
+     "OnnxRegistry",
+     "dynamo_export",
+     "enable_fake_mode",
+     # DORT / torch.compile
+     "is_onnxrt_backend_supported",
+ ]
+
+ from typing import Any, Callable, Collection, Mapping, Sequence, TYPE_CHECKING
+
+ import torch
+ from torch import _C
+ from torch._C import _onnx as _C_onnx
+ from torch._C._onnx import OperatorExportTypes, TensorProtoDataType, TrainingMode
+
+ from ._exporter_states import ExportTypes
+ from ._internal.onnxruntime import (
+     is_onnxrt_backend_supported,
+     OrtBackend as _OrtBackend,
+     OrtBackendOptions as _OrtBackendOptions,
+     OrtExecutionProvider as _OrtExecutionProvider,
+ )
+ from ._type_utils import JitScalarType
+ from .errors import OnnxExporterError
+ from .utils import (
+     _optimize_graph,
+     _run_symbolic_function,
+     _run_symbolic_method,
+     export_to_pretty_string,
+     is_in_onnx_export,
+     register_custom_op_symbolic,
+     select_model_mode_for_export,
+     unregister_custom_op_symbolic,
+ )
+
+
+ from . import (  # usort: skip. Keep the order instead of sorting lexicographically
+     errors,
+     symbolic_caffe2,
+     symbolic_helper,
+     symbolic_opset7,
+     symbolic_opset8,
+     symbolic_opset9,
+     symbolic_opset10,
+     symbolic_opset11,
+     symbolic_opset12,
+     symbolic_opset13,
+     symbolic_opset14,
+     symbolic_opset15,
+     symbolic_opset16,
+     symbolic_opset17,
+     symbolic_opset18,
+     symbolic_opset19,
+     symbolic_opset20,
+     utils,
+ )
+
+
+ from ._internal._exporter_legacy import (  # usort: skip. needs to be last to avoid circular import
+     DiagnosticOptions,
+     ExportOptions,
+     ONNXProgram,
+     ONNXRuntimeOptions,
+     OnnxRegistry,
+     enable_fake_mode,
+ )
+
+
+ if TYPE_CHECKING:
+     import os
+
+ # Set namespace for exposed private names
+ DiagnosticOptions.__module__ = "torch.onnx"
+ ExportOptions.__module__ = "torch.onnx"
+ ExportTypes.__module__ = "torch.onnx"
+ JitScalarType.__module__ = "torch.onnx"
+ ONNXProgram.__module__ = "torch.onnx"
+ ONNXRuntimeOptions.__module__ = "torch.onnx"
+ OnnxExporterError.__module__ = "torch.onnx"
+ OnnxRegistry.__module__ = "torch.onnx"
+ _OrtBackend.__module__ = "torch.onnx"
+ _OrtBackendOptions.__module__ = "torch.onnx"
+ _OrtExecutionProvider.__module__ = "torch.onnx"
+ enable_fake_mode.__module__ = "torch.onnx"
+ is_onnxrt_backend_supported.__module__ = "torch.onnx"
+
+ producer_name = "pytorch"
+ producer_version = _C_onnx.PRODUCER_VERSION
+
+
+ def export(
+     model: torch.nn.Module
+     | torch.export.ExportedProgram
+     | torch.jit.ScriptModule
+     | torch.jit.ScriptFunction,
+     args: tuple[Any, ...] = (),
+     f: str | os.PathLike | None = None,
+     *,
+     kwargs: dict[str, Any] | None = None,
+     export_params: bool = True,
+     verbose: bool | None = None,
+     input_names: Sequence[str] | None = None,
+     output_names: Sequence[str] | None = None,
+     opset_version: int | None = None,
+     dynamic_axes: Mapping[str, Mapping[int, str]]
+     | Mapping[str, Sequence[int]]
+     | None = None,
+     keep_initializers_as_inputs: bool = False,
+     dynamo: bool = False,
+     # Dynamo only options
+     external_data: bool = True,
+     dynamic_shapes: dict[str, Any] | tuple[Any, ...] | list[Any] | None = None,
+     report: bool = False,
+     verify: bool = False,
+     profile: bool = False,
+     dump_exported_program: bool = False,
+     artifacts_dir: str | os.PathLike = ".",
+     fallback: bool = False,
+     # Deprecated options
+     training: _C_onnx.TrainingMode = _C_onnx.TrainingMode.EVAL,
+     operator_export_type: _C_onnx.OperatorExportTypes = _C_onnx.OperatorExportTypes.ONNX,
+     do_constant_folding: bool = True,
+     custom_opsets: Mapping[str, int] | None = None,
+     export_modules_as_functions: bool | Collection[type[torch.nn.Module]] = False,
+     autograd_inlining: bool = True,
+     **_: Any,  # ignored options
+ ) -> Any | None:
+     r"""Exports a model into ONNX format.
+
+     Args:
+         model: The model to be exported.
+         args: Example positional inputs. Any non-Tensor arguments will be hard-coded into the
+             exported model; any Tensor arguments will become inputs of the exported model,
+             in the order they occur in the tuple.
+         f: Path to the output ONNX model file. E.g. "model.onnx".
+         kwargs: Optional example keyword inputs.
+         export_params: If false, parameters (weights) will not be exported.
+         verbose: Whether to enable verbose logging.
+         input_names: names to assign to the input nodes of the graph, in order.
+         output_names: names to assign to the output nodes of the graph, in order.
+         opset_version: The version of the
+             `default (ai.onnx) opset <https://github.com/onnx/onnx/blob/master/docs/Operators.md>`_
+             to target. Must be >= 7.
+         dynamic_axes:
+
+             By default the exported model will have the shapes of all input and output tensors
+             set to exactly match those given in ``args``. To specify axes of tensors as
+             dynamic (i.e. known only at run-time), set ``dynamic_axes`` to a dict with schema:
+
+             * KEY (str): an input or output name. Each name must also be provided in ``input_names`` or
+               ``output_names``.
+             * VALUE (dict or list): If a dict, keys are axis indices and values are axis names. If a
+               list, each element is an axis index.
+
+             For example::
+
+                 class SumModule(torch.nn.Module):
+                     def forward(self, x):
+                         return torch.sum(x, dim=1)
+
+
+                 torch.onnx.export(
+                     SumModule(),
+                     (torch.ones(2, 2),),
+                     "onnx.pb",
+                     input_names=["x"],
+                     output_names=["sum"],
+                 )
+
+             Produces::
+
+                 input {
+                   name: "x"
+                   ...
+                   shape {
+                     dim {
+                       dim_value: 2  # axis 0
+                     }
+                     dim {
+                       dim_value: 2  # axis 1
+                 ...
+                 output {
+                   name: "sum"
+                   ...
+                   shape {
+                     dim {
+                       dim_value: 2  # axis 0
+                 ...
+
+             While::
+
+                 torch.onnx.export(
+                     SumModule(),
+                     (torch.ones(2, 2),),
+                     "onnx.pb",
+                     input_names=["x"],
+                     output_names=["sum"],
+                     dynamic_axes={
+                         # dict value: manually named axes
+                         "x": {0: "my_custom_axis_name"},
+                         # list value: automatic names
+                         "sum": [0],
+                     },
+                 )
+
+             Produces::
+
+                 input {
+                   name: "x"
+                   ...
+                   shape {
+                     dim {
+                       dim_param: "my_custom_axis_name"  # axis 0
+                     }
+                     dim {
+                       dim_value: 2  # axis 1
+                 ...
+                 output {
+                   name: "sum"
+                   ...
+                   shape {
+                     dim {
+                       dim_param: "sum_dynamic_axes_1"  # axis 0
+                 ...
+
+         keep_initializers_as_inputs: If True, all the
+             initializers (typically corresponding to model weights) in the
+             exported graph will also be added as inputs to the graph. If False,
+             then initializers are not added as inputs to the graph, and only
+             the user inputs are added as inputs.
+
+             Set this to True if you intend to supply model weights at runtime.
+             Set it to False if the weights are static to allow for better optimizations
+             (e.g. constant folding) by backends/runtimes.
+
+         dynamo: Whether to export the model with ``torch.export`` ExportedProgram instead of TorchScript.
+         external_data: Whether to save the model weights as an external data file.
+             This is required for models with large weights that exceed the ONNX file size limit (2GB).
+             When False, the weights are saved in the ONNX file with the model architecture.
+         dynamic_shapes: A dictionary of dynamic shapes for the model inputs. Refer to
+             :func:`torch.export.export` for more details. This is only used (and preferred) when dynamo is True.
+             Only one parameter `dynamic_axes` or `dynamic_shapes` should be set
+             at the same time.
+         report: Whether to generate a markdown report for the export process.
+         verify: Whether to verify the exported model using ONNX Runtime.
+         profile: Whether to profile the export process.
+         dump_exported_program: Whether to dump the :class:`torch.export.ExportedProgram` to a file.
+             This is useful for debugging the exporter.
+         artifacts_dir: The directory to save the debugging artifacts like the report and the serialized
+             exported program.
+         fallback: Whether to fallback to the TorchScript exporter if the dynamo exporter fails.
+
+         training: Deprecated option. Instead, set the training mode of the model before exporting.
+         operator_export_type: Deprecated option. Only ONNX is supported.
+         do_constant_folding: Deprecated option. The exported graph is always optimized.
+         custom_opsets: Deprecated.
+             A dictionary:
+
+             * KEY (str): opset domain name
+             * VALUE (int): opset version
+
+             If a custom opset is referenced by ``model`` but not mentioned in this dictionary,
+             the opset version is set to 1. Only custom opset domain name and version should be
+             indicated through this argument.
+         export_modules_as_functions: Deprecated option.
+
+             Flag to enable
+             exporting all ``nn.Module`` forward calls as local functions in ONNX. Or a set to indicate the
+             particular types of modules to export as local functions in ONNX.
+             This feature requires ``opset_version`` >= 15, otherwise the export will fail. This is because
+             ``opset_version`` < 15 implies IR version < 8, which means no local function support.
+             Module variables will be exported as function attributes. There are two categories of function
+             attributes.
+
+             1. Annotated attributes: class variables that have type annotations via
+             `PEP 526-style <https://www.python.org/dev/peps/pep-0526/#class-and-instance-variable-annotations>`_
+             will be exported as attributes.
+             Annotated attributes are not used inside the subgraph of ONNX local function because
+             they are not created by PyTorch JIT tracing, but they may be used by consumers
+             to determine whether or not to replace the function with a particular fused kernel.
+
+             2. Inferred attributes: variables that are used by operators inside the module. Attribute names
+             will have prefix "inferred::". This is to differentiate from predefined attributes retrieved from
+             python module annotations. Inferred attributes are used inside the subgraph of ONNX local function.
+
+             * ``False`` (default): export ``nn.Module`` forward calls as fine grained nodes.
+             * ``True``: export all ``nn.Module`` forward calls as local function nodes.
+             * Set of type of nn.Module: export ``nn.Module`` forward calls as local function nodes,
+               only if the type of the ``nn.Module`` is found in the set.
+         autograd_inlining: Deprecated.
+             Flag used to control whether to inline autograd functions.
+             Refer to https://github.com/pytorch/pytorch/pull/74765 for more details.
+     """
+     if dynamo is True or isinstance(model, torch.export.ExportedProgram):
+         from torch.onnx._internal import exporter
+
+         if isinstance(args, torch.Tensor):
+             args = (args,)
+         return exporter.export_compat(
+             model,
+             args,
+             f,
+             kwargs=kwargs,
+             export_params=export_params,
+             verbose=verbose,
+             input_names=input_names,
+             output_names=output_names,
+             opset_version=opset_version,
+             dynamic_axes=dynamic_axes,
+             keep_initializers_as_inputs=keep_initializers_as_inputs,
+             external_data=external_data,
+             dynamic_shapes=dynamic_shapes,
+             report=report,
+             verify=verify,
+             profile=profile,
+             dump_exported_program=dump_exported_program,
+             artifacts_dir=artifacts_dir,
+             fallback=fallback,
+         )
+     else:
+         from torch.onnx.utils import export
+
+         if dynamic_shapes:
+             raise ValueError(
+                 "The exporter only supports dynamic shapes "
+                 "through parameter dynamic_axes when dynamo=False."
+             )
+
+         export(
+             model,
+             args,
+             f,  # type: ignore[arg-type]
+             kwargs=kwargs,
+             export_params=export_params,
+             verbose=verbose is True,
+             input_names=input_names,
+             output_names=output_names,
+             opset_version=opset_version,
+             dynamic_axes=dynamic_axes,
+             keep_initializers_as_inputs=keep_initializers_as_inputs,
+             training=training,
+             operator_export_type=operator_export_type,
+             do_constant_folding=do_constant_folding,
+             custom_opsets=custom_opsets,
+             export_modules_as_functions=export_modules_as_functions,
+             autograd_inlining=autograd_inlining,
+         )
+         return None
+
+
+ def dynamo_export(
+     model: torch.nn.Module | Callable | torch.export.ExportedProgram,  # type: ignore[name-defined]
+     /,
+     *model_args,
+     export_options: ExportOptions | None = None,
+     **model_kwargs,
+ ) -> ONNXProgram | Any:
+     """Export a torch.nn.Module to an ONNX graph.
+
+     Args:
+         model: The PyTorch model to be exported to ONNX.
+         model_args: Positional inputs to ``model``.
+         model_kwargs: Keyword inputs to ``model``.
+         export_options: Options to influence the export to ONNX.
+
+     Returns:
+         An in-memory representation of the exported ONNX model.
+
+     **Example 1 - Simplest export**
+     ::
+
+         class MyModel(torch.nn.Module):
+             def __init__(self) -> None:
+                 super().__init__()
+                 self.linear = torch.nn.Linear(2, 2)
+
+             def forward(self, x, bias=None):
+                 out = self.linear(x)
+                 out = out + bias
+                 return out
+
+
+         model = MyModel()
+         kwargs = {"bias": 3.0}
+         args = (torch.randn(2, 2, 2),)
+         onnx_program = torch.onnx.dynamo_export(model, *args, **kwargs).save(
+             "my_simple_model.onnx"
+         )
+
+     **Example 2 - Exporting with dynamic shapes**
+     ::
+
+         # The previous model can be exported with dynamic shapes
+         export_options = torch.onnx.ExportOptions(dynamic_shapes=True)
+         onnx_program = torch.onnx.dynamo_export(
+             model, *args, **kwargs, export_options=export_options
+         )
+         onnx_program.save("my_dynamic_model.onnx")
+     """
+
+     # NOTE: The new exporter is experimental and is not enabled by default.
+     import warnings
+
+     from torch.onnx import _flags
+     from torch.onnx._internal import exporter
+     from torch.utils import _pytree
+
+     if isinstance(model, torch.export.ExportedProgram):
+         return exporter.export_compat(
+             model,  # type: ignore[arg-type]
+             model_args,
+             f=None,
+             kwargs=model_kwargs,
+             opset_version=18,
+             external_data=True,
+             export_params=True,
+             fallback=True,
+         )
+     elif _flags.USE_EXPERIMENTAL_LOGIC:
+         if export_options is not None:
+             warnings.warn(
+                 "You are using an experimental ONNX export logic, which currently only supports dynamic shapes. "
+                 "For a more comprehensive set of export options, including advanced features, please consider using "
+                 "`torch.onnx.export(..., dynamo=True)`. ",
+                 category=FutureWarning,
+             )
+
+         if export_options is not None and export_options.dynamic_shapes:
+             # Make all shapes dynamic
+             def _to_dynamic_shapes_mapper():
+                 arg_order = 0
+
+                 def _to_dynamic_shape(x):
+                     nonlocal arg_order
+                     if isinstance(x, torch.Tensor):
+                         rank = len(x.shape)
+                         dynamic_shape = {}
+                         for i in range(rank):
+                             dynamic_shape[i] = torch.export.Dim(
+                                 f"arg_{arg_order}_dim_{i}"
+                             )
+                         arg_order += 1
+                         return dynamic_shape
+                     else:
+                         return None
+
+                 return _to_dynamic_shape
+
+             # model_args could be nested
+             dynamic_shapes = _pytree.tree_map(
+                 _to_dynamic_shapes_mapper(),
+                 model_args,
+             )
+         else:
+             dynamic_shapes = None
+
+         return exporter.export_compat(
+             model,  # type: ignore[arg-type]
+             model_args,
+             f=None,
+             kwargs=model_kwargs,
+             dynamic_shapes=dynamic_shapes,
+             opset_version=18,
+             external_data=True,
+             export_params=True,
+             fallback=True,
+         )
+     else:
+         from torch.onnx._internal._exporter_legacy import dynamo_export
+
+         return dynamo_export(
+             model, *model_args, export_options=export_options, **model_kwargs
+         )
+
+
+ # TODO(justinchuby): Deprecate these logging functions in favor of the new diagnostic module.
+
+ # Returns True iff ONNX logging is turned on.
+ is_onnx_log_enabled = _C._jit_is_onnx_log_enabled
+
+
+ def enable_log() -> None:
+     r"""Enables ONNX logging."""
+     _C._jit_set_onnx_log_enabled(True)
+
+
+ def disable_log() -> None:
+     r"""Disables ONNX logging."""
+     _C._jit_set_onnx_log_enabled(False)
+
+
+ """Sets output stream for ONNX logging.
+
+ Args:
+     stream_name (str, default "stdout"): Only 'stdout' and 'stderr' are supported
+         as ``stream_name``.
+ """
+ set_log_stream = _C._jit_set_onnx_log_output_stream
+
+
+ """A simple logging facility for ONNX exporter.
+
+ Args:
+     args: Arguments are converted to string, concatenated together with a newline
+         character appended to the end, and flushed to output stream.
+ """
+ log = _C._jit_onnx_log
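
A minimal export sketch using the `torch.onnx.export` signature defined above, reusing the `SumModule` from its own docstring; the output filename is illustrative:

    import torch

    class SumModule(torch.nn.Module):
        def forward(self, x):
            return torch.sum(x, dim=1)

    # TorchScript-based path (dynamo=False is the default), with the batch
    # axis of both the input and the output marked dynamic.
    torch.onnx.export(
        SumModule(),
        (torch.ones(2, 2),),
        "sum.onnx",
        input_names=["x"],
        output_names=["sum"],
        dynamic_axes={"x": {0: "batch"}, "sum": [0]},
    )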
.venv/Lib/site-packages/torch/onnx/_constants.py ADDED
@@ -0,0 +1,25 @@
+ """Constant values used in ONNX."""
+
+ ONNX_ARCHIVE_MODEL_PROTO_NAME = "__MODEL_PROTO"
+
+ ONNX_BASE_OPSET = 9
+ ONNX_MIN_OPSET = 7
+ ONNX_MAX_OPSET = 20
+ ONNX_TORCHSCRIPT_EXPORTER_MAX_OPSET = 20
+ # ONNX_DEFAULT_OPSET generated by tools/onnx/update_default_opset_version.py
+ ONNX_DEFAULT_OPSET = 17
+ ONNX_CONSTANT_FOLDING_MIN_OPSET = 9
+
+ PYTORCH_GITHUB_ISSUES_URL = "https://github.com/pytorch/pytorch/issues"
+
+ INT64_MAX = 9223372036854775807
+ INT32_MAX = 2147483647
+ INT16_MAX = 32767
+ INT8_MAX = 127
+ UINT8_MAX = 255
+
+ INT64_MIN = -9223372036854775808
+ INT32_MIN = -2147483648
+ INT16_MIN = -32768
+ INT8_MIN = -128
+ UINT8_MIN = 0
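
A small sketch of how these constants bound the exporter's accepted opset range; the helper name is ours, and `_constants` is a private module, so this mirrors rather than defines a public API:

    from torch.onnx import _constants

    def is_supported_opset(version: int) -> bool:
        # Same range check the exporter globals apply (see _globals.py below).
        return _constants.ONNX_MIN_OPSET <= version <= _constants.ONNX_MAX_OPSET

    assert is_supported_opset(_constants.ONNX_DEFAULT_OPSET)  # 17 is in [7, 20]
    assert not is_supported_opset(6)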
.venv/Lib/site-packages/torch/onnx/_deprecation.py ADDED
@@ -0,0 +1,72 @@
+ """Utility for deprecating functions."""
+
+ import functools
+ import textwrap
+ import warnings
+ from typing import Callable, TypeVar
+ from typing_extensions import ParamSpec
+
+
+ _T = TypeVar("_T")
+ _P = ParamSpec("_P")
+
+
+ def deprecated(
+     since: str, removed_in: str, instructions: str
+ ) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
+     """Marks functions as deprecated.
+
+     It will result in a warning when the function is called and a note in the
+     docstring.
+
+     Args:
+         since: The version when the function was first deprecated.
+         removed_in: The version when the function will be removed.
+         instructions: The action users should take.
+     """
+
+     def decorator(function: Callable[_P, _T]) -> Callable[_P, _T]:
+         @functools.wraps(function)
+         def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T:
+             warnings.warn(
+                 f"'{function.__module__}.{function.__name__}' "
+                 f"is deprecated in version {since} and will be "
+                 f"removed in {removed_in}. Please {instructions}.",
+                 category=FutureWarning,
+                 stacklevel=2,
+             )
+             return function(*args, **kwargs)
+
+         # Add a deprecation note to the docstring.
+         docstring = function.__doc__ or ""
+
+         # Add a note to the docstring.
+         deprecation_note = textwrap.dedent(
+             f"""\
+             .. deprecated:: {since}
+                 Deprecated and will be removed in version {removed_in}.
+                 Please {instructions}.
+             """
+         )
+
+         # Split docstring at first occurrence of newline
+         summary_and_body = docstring.split("\n\n", 1)
+
+         if len(summary_and_body) > 1:
+             summary, body = summary_and_body
+
+             # Dedent the body. We cannot do this with the presence of the summary because
+             # the body contains leading whitespaces when the summary does not.
+             body = textwrap.dedent(body)
+
+             new_docstring_parts = [deprecation_note, "\n\n", summary, body]
+         else:
+             summary = summary_and_body[0]
+
+             new_docstring_parts = [deprecation_note, "\n\n", summary]
+
+         wrapper.__doc__ = "".join(new_docstring_parts)
+
+         return wrapper
+
+     return decorator
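
A usage sketch of the decorator above; the function, version strings, and replacement name are all illustrative:

    from torch.onnx import _deprecation

    @_deprecation.deprecated(
        since="2.4", removed_in="2.6", instructions="use `new_helper` instead"
    )
    def old_helper(x):
        """Double the input."""
        return x * 2

    old_helper(3)  # emits a FutureWarning; a ".. deprecated::" note is
                   # also prepended to old_helper.__doc__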
.venv/Lib/site-packages/torch/onnx/_experimental.py ADDED
@@ -0,0 +1,27 @@
+ """Experimental classes and functions used by ONNX export."""
+
+ import dataclasses
+ from typing import Mapping, Optional, Sequence, Set, Type, Union
+
+ import torch
+ import torch._C._onnx as _C_onnx
+
+
+ @dataclasses.dataclass
+ class ExportOptions:
+     """Arguments used by :func:`torch.onnx.export`."""
+
+     # TODO(justinchuby): Deprecate and remove this class.
+
+     export_params: bool = True
+     verbose: bool = False
+     training: _C_onnx.TrainingMode = _C_onnx.TrainingMode.EVAL
+     input_names: Optional[Sequence[str]] = None
+     output_names: Optional[Sequence[str]] = None
+     operator_export_type: _C_onnx.OperatorExportTypes = _C_onnx.OperatorExportTypes.ONNX
+     opset_version: Optional[int] = None
+     do_constant_folding: bool = True
+     dynamic_axes: Optional[Mapping[str, Union[Mapping[int, str], Sequence[int]]]] = None
+     keep_initializers_as_inputs: Optional[bool] = None
+     custom_opsets: Optional[Mapping[str, int]] = None
+     export_modules_as_functions: Union[bool, Set[Type[torch.nn.Module]]] = False
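
This legacy `ExportOptions` dataclass simply bundles the TorchScript exporter's keyword arguments; it is distinct from the dynamo exporter's `torch.onnx.ExportOptions` re-exported in `__init__.py` above. A short sketch:

    from torch.onnx import _experimental

    # Bundle legacy torch.onnx.export keyword arguments into one object.
    opts = _experimental.ExportOptions(opset_version=17, input_names=["x"])
    assert opts.do_constant_folding       # dataclass default
    assert opts.custom_opsets is None     # dataclass default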
.venv/Lib/site-packages/torch/onnx/_exporter_states.py ADDED
@@ -0,0 +1,12 @@
+ from __future__ import annotations
+
+
+ class ExportTypes:
+     """Specifies how the ONNX model is stored."""
+
+     # TODO(justinchuby): Deprecate and remove this class.
+
+     PROTOBUF_FILE = "Saves model in the specified protobuf file."
+     ZIP_ARCHIVE = "Saves model in the specified ZIP file (uncompressed)."
+     COMPRESSED_ZIP_ARCHIVE = "Saves model in the specified ZIP file (compressed)."
+     DIRECTORY = "Saves model in the specified folder."
.venv/Lib/site-packages/torch/onnx/_flags.py ADDED
@@ -0,0 +1,49 @@
+ """Internal feature flags for torch.onnx.
+
+ NOTE: These flags are experimental only. Any flag here can be removed at any
+ time without notice.
+ """
+
+ import logging
+ import os
+
+
+ logger = logging.getLogger(__name__)
+
+
+ def _load_boolean_flag(
+     name: str,
+     *,
+     this_will: str,
+     deprecated: bool = False,
+     default: bool = False,
+ ) -> bool:
+     """Load a boolean flag from environment variable.
+
+     Args:
+         name: The name of the environment variable.
+         this_will: A string that describes what this flag will do.
+         deprecated: Whether this flag is deprecated.
+         default: The default value if envvar not defined.
+     """
+     undefined = os.getenv(name) is None
+     state = os.getenv(name) == "1"
+     if state:
+         if deprecated:
+             logger.error(
+                 "Experimental flag %s is deprecated. Please remove it from your environment.",
+                 name,
+             )
+         else:
+             logger.warning(
+                 "Experimental flag %s is enabled. This will %s.", name, this_will
+             )
+     if undefined:
+         state = default
+     return state
+
+
+ USE_EXPERIMENTAL_LOGIC: bool = _load_boolean_flag(
+     "TORCH_ONNX_USE_EXPERIMENTAL_LOGIC",
+     this_will="use ExportedProgram and the new torch.onnx export logic",
+ )
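
Because `USE_EXPERIMENTAL_LOGIC` is computed once at import time by `_load_boolean_flag`, the environment variable must be set before `torch.onnx` is first imported. A sketch:

    import os

    # Must happen before the first `import torch.onnx` in the process;
    # the flag is read only once, at module import time.
    os.environ["TORCH_ONNX_USE_EXPERIMENTAL_LOGIC"] = "1"

    import torch.onnx  # logs: "Experimental flag ... is enabled. This will ..."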
.venv/Lib/site-packages/torch/onnx/_globals.py ADDED
@@ -0,0 +1,87 @@
+ # mypy: allow-untyped-defs
+ """Globals used internally by the ONNX exporter.
+
+ Do not use this module outside of `torch.onnx` and its tests.
+
+ Be very judicious when adding any new global variables. Do not create new global
+ variables unless they are absolutely necessary.
+ """
+
+ import torch._C._onnx as _C_onnx
+
+ # This module should only depend on _constants and nothing else in torch.onnx to keep
+ # dependency direction clean.
+ from torch.onnx import _constants
+
+
+ class _InternalGlobals:
+     """Globals used internally by ONNX exporter.
+
+     NOTE: Be very judicious when adding any new variables. Do not create new
+     global variables unless they are absolutely necessary.
+     """
+
+     def __init__(self) -> None:
+         self._export_onnx_opset_version = _constants.ONNX_DEFAULT_OPSET
+         self._training_mode: _C_onnx.TrainingMode = _C_onnx.TrainingMode.EVAL
+         self._in_onnx_export: bool = False
+         # Whether the user's model is training during export
+         self.export_training: bool = False
+         self.operator_export_type: _C_onnx.OperatorExportTypes = (
+             _C_onnx.OperatorExportTypes.ONNX
+         )
+         self.onnx_shape_inference: bool = True
+         self._autograd_inlining: bool = True
+
+     @property
+     def training_mode(self):
+         """The training mode for the exporter."""
+         return self._training_mode
+
+     @training_mode.setter
+     def training_mode(self, training_mode: _C_onnx.TrainingMode):
+         if not isinstance(training_mode, _C_onnx.TrainingMode):
+             raise TypeError(
+                 "training_mode must be of type 'torch.onnx.TrainingMode'. This is "
+                 "likely a bug in torch.onnx."
+             )
+         self._training_mode = training_mode
+
+     @property
+     def export_onnx_opset_version(self) -> int:
+         """Opset version used during export."""
+         return self._export_onnx_opset_version
+
+     @export_onnx_opset_version.setter
+     def export_onnx_opset_version(self, value: int):
+         supported_versions = range(
+             _constants.ONNX_MIN_OPSET, _constants.ONNX_MAX_OPSET + 1
+         )
+         if value not in supported_versions:
+             raise ValueError(f"Unsupported ONNX opset version: {value}")
+         self._export_onnx_opset_version = value
+
+     @property
+     def in_onnx_export(self) -> bool:
+         """Whether it is in the middle of ONNX export."""
+         return self._in_onnx_export
+
+     @in_onnx_export.setter
+     def in_onnx_export(self, value: bool):
+         if type(value) is not bool:
+             raise TypeError("in_onnx_export must be a boolean")
+         self._in_onnx_export = value
+
+     @property
+     def autograd_inlining(self) -> bool:
+         """Whether Autograd must be inlined."""
+         return self._autograd_inlining
+
+     @autograd_inlining.setter
+     def autograd_inlining(self, value: bool):
+         if type(value) is not bool:
+             raise TypeError("autograd_inlining must be a boolean")
+         self._autograd_inlining = value
+
+
+ GLOBALS = _InternalGlobals()
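
A sketch of the `GLOBALS` singleton's validated setters; note the module docstring's warning that this is internal to `torch.onnx`:

    from torch.onnx._globals import GLOBALS

    print(GLOBALS.export_onnx_opset_version)  # ONNX_DEFAULT_OPSET, i.e. 17

    try:
        GLOBALS.export_onnx_opset_version = 99  # outside [ONNX_MIN_OPSET, ONNX_MAX_OPSET]
    except ValueError as e:
        print(e)  # Unsupported ONNX opset version: 99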
.venv/Lib/site-packages/torch/onnx/_internal/__init__.py ADDED
File without changes
.venv/Lib/site-packages/torch/onnx/_internal/_lazy_import.py ADDED
@@ -0,0 +1,41 @@
+ """Utility to lazily import modules."""
+
+ # mypy: allow-untyped-defs
+ from __future__ import annotations
+
+ import importlib
+ from typing import Any, TYPE_CHECKING
+
+
+ class _LazyModule:
+     """Lazily import a module."""
+
+     def __init__(self, module_name: str) -> None:
+         self._name = module_name
+         self._module: Any = None
+
+     def __repr__(self) -> str:
+         return f"<lazy module '{self._name}'>"
+
+     def __getattr__(self, attr):
+         if self._module is None:
+             self._module = importlib.import_module(".", self._name)
+         return getattr(self._module, attr)
+
+
+ # Import the following modules during type checking to enable code intelligence features,
+ # such as auto-completion in tools like pylance, even when these modules are not explicitly
+ # imported in user code.
+ # NOTE: Add additional used imports here.
+ if TYPE_CHECKING:
+     import onnx
+     import onnxscript
+     import onnxscript._framework_apis.torch_2_5 as onnxscript_apis
+
+     onnxscript_ir = onnxscript.ir
+
+ else:
+     onnx = _LazyModule("onnx")
+     onnxscript = _LazyModule("onnxscript")
+     onnxscript_ir = _LazyModule("onnxscript.ir")
+     onnxscript_apis = _LazyModule("onnxscript._framework_apis.torch_2_5")
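
A sketch of `_LazyModule`'s deferred-import behavior, using the standard-library `json` module so it runs without the optional onnx/onnxscript dependencies:

    from torch.onnx._internal._lazy_import import _LazyModule

    json_mod = _LazyModule("json")   # nothing imported yet
    print(json_mod)                  # <lazy module 'json'>; __repr__ does not import
    json_mod.dumps({"a": 1})         # first attribute access triggers the import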
.venv/Lib/site-packages/torch/onnx/_internal/diagnostics/__init__.py ADDED
@@ -0,0 +1,22 @@
+ from ._diagnostic import (
+     create_export_diagnostic_context,
+     diagnose,
+     engine,
+     export_context,
+     ExportDiagnosticEngine,
+     TorchScriptOnnxExportDiagnostic,
+ )
+ from ._rules import rules
+ from .infra import levels
+
+
+ __all__ = [
+     "TorchScriptOnnxExportDiagnostic",
+     "ExportDiagnosticEngine",
+     "rules",
+     "levels",
+     "engine",
+     "export_context",
+     "create_export_diagnostic_context",
+     "diagnose",
+ ]
.venv/Lib/site-packages/torch/onnx/_internal/diagnostics/_diagnostic.py ADDED
@@ -0,0 +1,211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ """Diagnostic components for TorchScript based ONNX export, i.e. `torch.onnx.export`."""
3
+
4
+ from __future__ import annotations
5
+
6
+ import contextlib
7
+ import gzip
8
+ from typing import TYPE_CHECKING
9
+
10
+ import torch
11
+ from torch.onnx._internal.diagnostics import infra
12
+ from torch.onnx._internal.diagnostics.infra import formatter, sarif
13
+ from torch.onnx._internal.diagnostics.infra.sarif import version as sarif_version
14
+ from torch.utils import cpp_backtrace
15
+
16
+
17
+ if TYPE_CHECKING:
18
+ from collections.abc import Generator
19
+
20
+
21
+ def _cpp_call_stack(frames_to_skip: int = 0, frames_to_log: int = 32) -> infra.Stack:
22
+ """Returns the current C++ call stack.
23
+
24
+ This function utilizes `torch.utils.cpp_backtrace` to get the current C++ call stack.
25
+ The returned C++ call stack is a concatenated string of the C++ call stack frames.
26
+ Each frame is separated by a newline character, in the same format of
27
+ r"frame #[0-9]+: (?P<frame_info>.*)". More info at `c10/util/Backtrace.cpp`.
28
+
29
+ """
30
+ frames = cpp_backtrace.get_cpp_backtrace(frames_to_skip, frames_to_log).split("\n")
31
+ frame_messages = []
32
+ for frame in frames:
33
+ segments = frame.split(":", 1)
34
+ if len(segments) == 2:
35
+ frame_messages.append(segments[1].strip())
36
+ else:
37
+ frame_messages.append("<unknown frame>")
38
+ return infra.Stack(
39
+ frames=[
40
+ infra.StackFrame(location=infra.Location(message=message))
41
+ for message in frame_messages
42
+ ]
43
+ )
44
+
45
+
46
+ class TorchScriptOnnxExportDiagnostic(infra.Diagnostic):
47
+ """Base class for all export diagnostics.
48
+
49
+ This class is used to represent all export diagnostics. It is a subclass of
50
+ infra.Diagnostic, and adds additional methods to add more information to the
51
+ diagnostic.
52
+ """
53
+
54
+ python_call_stack: infra.Stack | None = None
55
+ cpp_call_stack: infra.Stack | None = None
56
+
57
+ def __init__(
58
+ self,
59
+ *args,
60
+ frames_to_skip: int = 1,
61
+ cpp_stack: bool = False,
62
+ **kwargs,
63
+ ) -> None:
64
+ super().__init__(*args, **kwargs)
65
+ self.python_call_stack = self.record_python_call_stack(
66
+ frames_to_skip=frames_to_skip
67
+ )
68
+ if cpp_stack:
69
+ self.cpp_call_stack = self.record_cpp_call_stack(
70
+ frames_to_skip=frames_to_skip
71
+ )
72
+
73
+ def record_cpp_call_stack(self, frames_to_skip: int) -> infra.Stack:
74
+ """Records the current C++ call stack in the diagnostic."""
75
+ stack = _cpp_call_stack(frames_to_skip=frames_to_skip)
76
+ stack.message = "C++ call stack"
77
+ self.with_stack(stack)
78
+ return stack
79
+
80
+
81
+ class ExportDiagnosticEngine:
82
+ """PyTorch ONNX Export diagnostic engine.
83
+
84
+ The only purpose of creating this class instead of using `DiagnosticContext` directly
85
+ is to provide a background context for `diagnose` calls inside exporter.
86
+
87
+ By design, one `torch.onnx.export` call should initialize one diagnostic context.
88
+ All `diagnose` calls inside exporter should be made in the context of that export.
89
+ However, since diagnostic context is currently being accessed via a global variable,
90
+ there is no guarantee that the context is properly initialized. Therefore, we need
91
+ to provide a default background context to fallback to, otherwise any invocation of
92
+ exporter internals, e.g. unit tests, will fail due to missing diagnostic context.
93
+ This can be removed once the pipeline for context to flow through the exporter is
94
+ established.
95
+ """
96
+
97
+ contexts: list[infra.DiagnosticContext]
98
+ _background_context: infra.DiagnosticContext
99
+
100
+ def __init__(self) -> None:
101
+ self.contexts = []
102
+ self._background_context = infra.DiagnosticContext(
103
+ name="torch.onnx",
104
+ version=torch.__version__,
105
+ )
106
+
107
+ @property
108
+ def background_context(self) -> infra.DiagnosticContext:
109
+ return self._background_context
110
+
111
+ def create_diagnostic_context(
112
+ self,
113
+ name: str,
114
+ version: str,
115
+ options: infra.DiagnosticOptions | None = None,
116
+ ) -> infra.DiagnosticContext:
117
+ """Creates a new diagnostic context.
118
+
119
+ Args:
120
+ name: The subject name for the diagnostic context.
121
+ version: The subject version for the diagnostic context.
122
+ options: The options for the diagnostic context.
123
+
124
+ Returns:
125
+ A new diagnostic context.
126
+ """
127
+ if options is None:
128
+ options = infra.DiagnosticOptions()
129
+ context: infra.DiagnosticContext[infra.Diagnostic] = infra.DiagnosticContext(
130
+ name, version, options
131
+ )
132
+ self.contexts.append(context)
133
+ return context
134
+
135
+ def clear(self):
136
+ """Clears all diagnostic contexts."""
137
+ self.contexts.clear()
138
+ self._background_context.diagnostics.clear()
139
+
140
+ def to_json(self) -> str:
141
+ return formatter.sarif_to_json(self.sarif_log())
142
+
143
+ def dump(self, file_path: str, compress: bool = False) -> None:
144
+ """Dumps the SARIF log to a file."""
145
+ if compress:
146
+ with gzip.open(file_path, "wt") as f:
147
+ f.write(self.to_json())
148
+ else:
149
+ with open(file_path, "w") as f:
150
+ f.write(self.to_json())
151
+
152
+ def sarif_log(self):
153
+ log = sarif.SarifLog(
154
+ version=sarif_version.SARIF_VERSION,
155
+ schema_uri=sarif_version.SARIF_SCHEMA_LINK,
156
+ runs=[context.sarif() for context in self.contexts],
157
+ )
158
+
159
+ log.runs.append(self._background_context.sarif())
160
+ return log
161
+
162
+
163
+ engine = ExportDiagnosticEngine()
164
+ _context = engine.background_context
165
+
166
+
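A minimal sketch of driving the module-level engine above directly; during a normal export this is handled by the exporter, and "report.sarif" is an illustrative path:

    ctx = engine.create_diagnostic_context("torch.onnx.export", torch.__version__)
    engine.dump("report.sarif", compress=False)  # writes all contexts as SARIF JSON
    engine.clear()
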
167
+ @contextlib.contextmanager
168
+ def create_export_diagnostic_context() -> (
169
+ Generator[infra.DiagnosticContext, None, None]
170
+ ):
171
+ """Create a diagnostic context for export.
172
+
173
+ This is a robustness workaround: the diagnostic context is accessed by export
174
+ internals via a global variable. See `ExportDiagnosticEngine` for more details.
175
+ """
176
+ global _context
177
+ assert (
178
+ _context == engine.background_context
179
+ ), "Export context is already set. Nested export is not supported."
180
+ _context = engine.create_diagnostic_context(
181
+ "torch.onnx.export",
182
+ torch.__version__,
183
+ )
184
+ try:
185
+ yield _context
186
+ finally:
187
+ _context = engine.background_context
188
+
189
+
190
+ def diagnose(
191
+ rule: infra.Rule,
192
+ level: infra.Level,
193
+ message: str | None = None,
194
+ frames_to_skip: int = 2,
195
+ **kwargs,
196
+ ) -> TorchScriptOnnxExportDiagnostic:
197
+ """Creates a diagnostic and record it in the global diagnostic context.
198
+
199
+ This is a wrapper around `context.log` that uses the global diagnostic
200
+ context.
201
+ """
202
+ diagnostic = TorchScriptOnnxExportDiagnostic(
203
+ rule, level, message, frames_to_skip=frames_to_skip, **kwargs
204
+ )
205
+ export_context().log(diagnostic)
206
+ return diagnostic
207
+
208
+
209
+ def export_context() -> infra.DiagnosticContext:
210
+ global _context
211
+ return _context
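Putting this module together, a hedged usage sketch run from within this module's namespace; the op name is illustrative, and `rules` comes from the generated `_rules.py` shown next:

    from torch.onnx._internal.diagnostics._rules import rules

    with create_export_diagnostic_context():
        # Records a warning diagnostic in the active export context.
        diagnose(
            rules.node_missing_onnx_shape_inference,
            infra.Level.WARNING,
            message=rules.node_missing_onnx_shape_inference.format_message(
                op_name="prim::DemoOp"  # illustrative op name
            ),
        )
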
.venv/Lib/site-packages/torch/onnx/_internal/diagnostics/_rules.py ADDED
@@ -0,0 +1,636 @@
1
+ # mypy: allow-untyped-defs
2
+ """
3
+ GENERATED CODE - DO NOT EDIT DIRECTLY
4
+ This file is generated by gen_diagnostics.py.
5
+ See tools/onnx/gen_diagnostics.py for more information.
6
+
7
+ Diagnostic rules for PyTorch ONNX export.
8
+ """
9
+
10
+ import dataclasses
11
+ from typing import Tuple
12
+
13
+ # flake8: noqa
14
+ from torch.onnx._internal.diagnostics import infra
15
+
16
+
17
+ """
18
+ GENERATED CODE - DO NOT EDIT DIRECTLY
19
+ The purpose of generating a class for each rule is to override the `format_message`
20
+ method to provide more details in the signature about the format arguments.
21
+ """
22
+
23
+
24
+ class _NodeMissingOnnxShapeInference(infra.Rule):
25
+ """Node is missing ONNX shape inference."""
26
+
27
+ def format_message(self, op_name) -> str: # type: ignore[override]
28
+ """Returns the formatted default message of this Rule.
29
+
30
+ Message template: 'The shape inference of {op_name} type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.'
31
+ """
32
+ return self.message_default_template.format(op_name=op_name)
33
+
34
+ def format( # type: ignore[override]
35
+ self, level: infra.Level, op_name
36
+ ) -> Tuple[infra.Rule, infra.Level, str]:
37
+ """Returns a tuple of (Rule, Level, message) for this Rule.
38
+
39
+ Message template: 'The shape inference of {op_name} type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.'
40
+ """
41
+ return self, level, self.format_message(op_name=op_name)
42
+
43
+
44
+ class _MissingCustomSymbolicFunction(infra.Rule):
45
+ """Missing symbolic function for custom PyTorch operator, cannot translate node to ONNX."""
46
+
47
+ def format_message(self, op_name) -> str: # type: ignore[override]
48
+ """Returns the formatted default message of this Rule.
49
+
50
+ Message template: 'ONNX export failed on an operator with unrecognized namespace {op_name}. If you are trying to export a custom operator, make sure you registered it with the right domain and version.'
51
+ """
52
+ return self.message_default_template.format(op_name=op_name)
53
+
54
+ def format( # type: ignore[override]
55
+ self, level: infra.Level, op_name
56
+ ) -> Tuple[infra.Rule, infra.Level, str]:
57
+ """Returns a tuple of (Rule, Level, message) for this Rule.
58
+
59
+ Message template: 'ONNX export failed on an operator with unrecognized namespace {op_name}. If you are trying to export a custom operator, make sure you registered it with the right domain and version.'
60
+ """
61
+ return self, level, self.format_message(op_name=op_name)
62
+
63
+
64
+ class _MissingStandardSymbolicFunction(infra.Rule):
65
+ """Missing symbolic function for standard PyTorch operator, cannot translate node to ONNX."""
66
+
67
+ def format_message( # type: ignore[override]
68
+ self, op_name, opset_version, issue_url
69
+ ) -> str:
70
+ """Returns the formatted default message of this Rule.
71
+
72
+ Message template: "Exporting the operator '{op_name}' to ONNX opset version {opset_version} is not supported. Please feel free to request support or submit a pull request on PyTorch GitHub: {issue_url}."
73
+ """
74
+ return self.message_default_template.format(
75
+ op_name=op_name, opset_version=opset_version, issue_url=issue_url
76
+ )
77
+
78
+ def format( # type: ignore[override]
79
+ self, level: infra.Level, op_name, opset_version, issue_url
80
+ ) -> Tuple[infra.Rule, infra.Level, str]:
81
+ """Returns a tuple of (Rule, Level, message) for this Rule.
82
+
83
+ Message template: "Exporting the operator '{op_name}' to ONNX opset version {opset_version} is not supported. Please feel free to request support or submit a pull request on PyTorch GitHub: {issue_url}."
84
+ """
85
+ return (
86
+ self,
87
+ level,
88
+ self.format_message(
89
+ op_name=op_name, opset_version=opset_version, issue_url=issue_url
90
+ ),
91
+ )
92
+
93
+
94
+ class _OperatorSupportedInNewerOpsetVersion(infra.Rule):
95
+ """Operator is supported in newer opset version."""
96
+
97
+ def format_message( # type: ignore[override]
98
+ self, op_name, opset_version, supported_opset_version
99
+ ) -> str:
100
+ """Returns the formatted default message of this Rule.
101
+
102
+ Message template: "Exporting the operator '{op_name}' to ONNX opset version {opset_version} is not supported. Support for this operator was added in version {supported_opset_version}, try exporting with this version."
103
+ """
104
+ return self.message_default_template.format(
105
+ op_name=op_name,
106
+ opset_version=opset_version,
107
+ supported_opset_version=supported_opset_version,
108
+ )
109
+
110
+ def format( # type: ignore[override]
111
+ self, level: infra.Level, op_name, opset_version, supported_opset_version
112
+ ) -> Tuple[infra.Rule, infra.Level, str]:
113
+ """Returns a tuple of (Rule, Level, message) for this Rule.
114
+
115
+ Message template: "Exporting the operator '{op_name}' to ONNX opset version {opset_version} is not supported. Support for this operator was added in version {supported_opset_version}, try exporting with this version."
116
+ """
117
+ return (
118
+ self,
119
+ level,
120
+ self.format_message(
121
+ op_name=op_name,
122
+ opset_version=opset_version,
123
+ supported_opset_version=supported_opset_version,
124
+ ),
125
+ )
126
+
127
+
128
+ class _FxGraphToOnnx(infra.Rule):
129
+ """Transforms graph from FX IR to ONNX IR."""
130
+
131
+ def format_message(self, graph_name) -> str: # type: ignore[override]
132
+ """Returns the formatted default message of this Rule.
133
+
134
+ Message template: 'Transforming FX graph {graph_name} to ONNX graph.'
135
+ """
136
+ return self.message_default_template.format(graph_name=graph_name)
137
+
138
+ def format( # type: ignore[override]
139
+ self, level: infra.Level, graph_name
140
+ ) -> Tuple[infra.Rule, infra.Level, str]:
141
+ """Returns a tuple of (Rule, Level, message) for this Rule.
142
+
143
+ Message template: 'Transforming FX graph {graph_name} to ONNX graph.'
144
+ """
145
+ return self, level, self.format_message(graph_name=graph_name)
146
+
147
+
148
+ class _FxNodeToOnnx(infra.Rule):
149
+ """Transforms an FX node to an ONNX node."""
150
+
151
+ def format_message(self, node_repr) -> str: # type: ignore[override]
152
+ """Returns the formatted default message of this Rule.
153
+
154
+ Message template: 'Transforming FX node {node_repr} to ONNX node.'
155
+ """
156
+ return self.message_default_template.format(node_repr=node_repr)
157
+
158
+ def format( # type: ignore[override]
159
+ self, level: infra.Level, node_repr
160
+ ) -> Tuple[infra.Rule, infra.Level, str]:
161
+ """Returns a tuple of (Rule, Level, message) for this Rule.
162
+
163
+ Message template: 'Transforming FX node {node_repr} to ONNX node.'
164
+ """
165
+ return self, level, self.format_message(node_repr=node_repr)
166
+
167
+
168
+ class _FxPass(infra.Rule):
169
+ """FX graph transformation during ONNX export before converting from FX IR to ONNX IR."""
170
+
171
+ def format_message(self, pass_name) -> str: # type: ignore[override]
172
+ """Returns the formatted default message of this Rule.
173
+
174
+ Message template: 'Running {pass_name} pass.'
175
+ """
176
+ return self.message_default_template.format(pass_name=pass_name)
177
+
178
+ def format( # type: ignore[override]
179
+ self, level: infra.Level, pass_name
180
+ ) -> Tuple[infra.Rule, infra.Level, str]:
181
+ """Returns a tuple of (Rule, Level, message) for this Rule.
182
+
183
+ Message template: 'Running {pass_name} pass.'
184
+ """
185
+ return self, level, self.format_message(pass_name=pass_name)
186
+
187
+
188
+ class _NoSymbolicFunctionForCallFunction(infra.Rule):
189
+ """Cannot find symbolic function to convert the "call_function" FX node to ONNX."""
190
+
191
+ def format_message(self, target) -> str: # type: ignore[override]
192
+ """Returns the formatted default message of this Rule.
193
+
194
+ Message template: 'No symbolic function to convert the "call_function" node {target} to ONNX. '
195
+ """
196
+ return self.message_default_template.format(target=target)
197
+
198
+ def format( # type: ignore[override]
199
+ self, level: infra.Level, target
200
+ ) -> Tuple[infra.Rule, infra.Level, str]:
201
+ """Returns a tuple of (Rule, Level, message) for this Rule.
202
+
203
+ Message template: 'No symbolic function to convert the "call_function" node {target} to ONNX. '
204
+ """
205
+ return self, level, self.format_message(target=target)
206
+
207
+
208
+ class _UnsupportedFxNodeAnalysis(infra.Rule):
209
+ """Result from FX graph analysis to reveal unsupported FX nodes."""
210
+
211
+ def format_message( # type: ignore[override]
212
+ self, node_op_to_target_mapping
213
+ ) -> str:
214
+ """Returns the formatted default message of this Rule.
215
+
216
+ Message template: 'Unsupported FX nodes: {node_op_to_target_mapping}. '
217
+ """
218
+ return self.message_default_template.format(
219
+ node_op_to_target_mapping=node_op_to_target_mapping
220
+ )
221
+
222
+ def format( # type: ignore[override]
223
+ self, level: infra.Level, node_op_to_target_mapping
224
+ ) -> Tuple[infra.Rule, infra.Level, str]:
225
+ """Returns a tuple of (Rule, Level, message) for this Rule.
226
+
227
+ Message template: 'Unsupported FX nodes: {node_op_to_target_mapping}. '
228
+ """
229
+ return (
230
+ self,
231
+ level,
232
+ self.format_message(node_op_to_target_mapping=node_op_to_target_mapping),
233
+ )
234
+
235
+
236
+ class _OpLevelDebugging(infra.Rule):
237
+ """Report any op level validation failure in warnings."""
238
+
239
+ def format_message(self, node, symbolic_fn) -> str: # type: ignore[override]
240
+ """Returns the formatted default message of this Rule.
241
+
242
+ Message template: 'FX node: {node} and its onnx function: {symbolic_fn} fails on op level validation.'
243
+ """
244
+ return self.message_default_template.format(node=node, symbolic_fn=symbolic_fn)
245
+
246
+ def format( # type: ignore[override]
247
+ self, level: infra.Level, node, symbolic_fn
248
+ ) -> Tuple[infra.Rule, infra.Level, str]:
249
+ """Returns a tuple of (Rule, Level, message) for this Rule.
250
+
251
+ Message template: 'FX node: {node} and its onnx function: {symbolic_fn} fails on op level validation.'
252
+ """
253
+ return self, level, self.format_message(node=node, symbolic_fn=symbolic_fn)
254
+
255
+
256
+ class _FindOpschemaMatchedSymbolicFunction(infra.Rule):
257
+ """Find the OnnxFunction that matches the input/attribute dtypes by comparing them with their opschemas."""
258
+
259
+ def format_message(self, symbolic_fn, node) -> str: # type: ignore[override]
260
+ """Returns the formatted default message of this Rule.
261
+
262
+ Message template: 'The OnnxFunction: {symbolic_fn} is the nearest match of the node {node}.'
263
+ """
264
+ return self.message_default_template.format(symbolic_fn=symbolic_fn, node=node)
265
+
266
+ def format( # type: ignore[override]
267
+ self, level: infra.Level, symbolic_fn, node
268
+ ) -> Tuple[infra.Rule, infra.Level, str]:
269
+ """Returns a tuple of (Rule, Level, message) for this Rule.
270
+
271
+ Message template: 'The OnnxFunction: {symbolic_fn} is the nearest match of the node {node}.'
272
+ """
273
+ return self, level, self.format_message(symbolic_fn=symbolic_fn, node=node)
274
+
275
+
276
+ class _FxNodeInsertTypePromotion(infra.Rule):
277
+ """Determine if type promotion is required for the FX node. Insert cast nodes if needed."""
278
+
279
+ def format_message(self, target) -> str: # type: ignore[override]
280
+ """Returns the formatted default message of this Rule.
281
+
282
+ Message template: 'Performing explicit type promotion for node {target}. '
283
+ """
284
+ return self.message_default_template.format(target=target)
285
+
286
+ def format( # type: ignore[override]
287
+ self, level: infra.Level, target
288
+ ) -> Tuple[infra.Rule, infra.Level, str]:
289
+ """Returns a tuple of (Rule, Level, message) for this Rule.
290
+
291
+ Message template: 'Performing explicit type promotion for node {target}. '
292
+ """
293
+ return self, level, self.format_message(target=target)
294
+
295
+
296
+ class _FindOperatorOverloadsInOnnxRegistry(infra.Rule):
297
+ """Find the list of OnnxFunction of the PyTorch operator in onnx registry."""
298
+
299
+ def format_message(self, node) -> str: # type: ignore[override]
300
+ """Returns the formatted default message of this Rule.
301
+
302
+ Message template: 'Checking if the FX node: {node} is supported in onnx registry.'
303
+ """
304
+ return self.message_default_template.format(node=node)
305
+
306
+ def format( # type: ignore[override]
307
+ self, level: infra.Level, node
308
+ ) -> Tuple[infra.Rule, infra.Level, str]:
309
+ """Returns a tuple of (Rule, Level, message) for this Rule.
310
+
311
+ Message template: 'Checking if the FX node: {node} is supported in onnx registry.'
312
+ """
313
+ return self, level, self.format_message(node=node)
314
+
315
+
316
+ @dataclasses.dataclass
317
+ class _POERules(infra.RuleCollection):
318
+ node_missing_onnx_shape_inference: _NodeMissingOnnxShapeInference = dataclasses.field(
319
+ default=_NodeMissingOnnxShapeInference.from_sarif(
320
+ **{
321
+ "id": "POE0001",
322
+ "name": "node-missing-onnx-shape-inference",
323
+ "short_description": {"text": "Node is missing ONNX shape inference."},
324
+ "full_description": {
325
+ "text": "Node is missing ONNX shape inference. This usually happens when the node is not valid under standard ONNX operator spec.",
326
+ "markdown": "Node is missing ONNX shape inference.\nThis usually happens when the node is not valid under standard ONNX operator spec.\n",
327
+ },
328
+ "message_strings": {
329
+ "default": {
330
+ "text": "The shape inference of {op_name} type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function."
331
+ }
332
+ },
333
+ "help_uri": None,
334
+ "properties": {"deprecated": False, "tags": []},
335
+ }
336
+ ),
337
+ init=False,
338
+ )
339
+ """Node is missing ONNX shape inference."""
340
+
341
+ missing_custom_symbolic_function: _MissingCustomSymbolicFunction = dataclasses.field(
342
+ default=_MissingCustomSymbolicFunction.from_sarif(
343
+ **{
344
+ "id": "POE0002",
345
+ "name": "missing-custom-symbolic-function",
346
+ "short_description": {
347
+ "text": "Missing symbolic function for custom PyTorch operator, cannot translate node to ONNX."
348
+ },
349
+ "full_description": {
350
+ "text": "Missing symbolic function for custom PyTorch operator, cannot translate node to ONNX.",
351
+ "markdown": "Missing symbolic function for custom PyTorch operator, cannot translate node to ONNX.\n",
352
+ },
353
+ "message_strings": {
354
+ "default": {
355
+ "text": "ONNX export failed on an operator with unrecognized namespace {op_name}. If you are trying to export a custom operator, make sure you registered it with the right domain and version."
356
+ }
357
+ },
358
+ "help_uri": None,
359
+ "properties": {"deprecated": False, "tags": []},
360
+ }
361
+ ),
362
+ init=False,
363
+ )
364
+ """Missing symbolic function for custom PyTorch operator, cannot translate node to ONNX."""
365
+
366
+ missing_standard_symbolic_function: _MissingStandardSymbolicFunction = dataclasses.field(
367
+ default=_MissingStandardSymbolicFunction.from_sarif(
368
+ **{
369
+ "id": "POE0003",
370
+ "name": "missing-standard-symbolic-function",
371
+ "short_description": {
372
+ "text": "Missing symbolic function for standard PyTorch operator, cannot translate node to ONNX."
373
+ },
374
+ "full_description": {
375
+ "text": "Missing symbolic function for standard PyTorch operator, cannot translate node to ONNX.",
376
+ "markdown": "Missing symbolic function for standard PyTorch operator, cannot translate node to ONNX.\n",
377
+ },
378
+ "message_strings": {
379
+ "default": {
380
+ "text": "Exporting the operator '{op_name}' to ONNX opset version {opset_version} is not supported. Please feel free to request support or submit a pull request on PyTorch GitHub: {issue_url}."
381
+ }
382
+ },
383
+ "help_uri": None,
384
+ "properties": {"deprecated": False, "tags": []},
385
+ }
386
+ ),
387
+ init=False,
388
+ )
389
+ """Missing symbolic function for standard PyTorch operator, cannot translate node to ONNX."""
390
+
391
+ operator_supported_in_newer_opset_version: _OperatorSupportedInNewerOpsetVersion = dataclasses.field(
392
+ default=_OperatorSupportedInNewerOpsetVersion.from_sarif(
393
+ **{
394
+ "id": "POE0004",
395
+ "name": "operator-supported-in-newer-opset-version",
396
+ "short_description": {
397
+ "text": "Operator is supported in newer opset version."
398
+ },
399
+ "full_description": {
400
+ "text": "Operator is supported in newer opset version.",
401
+ "markdown": "Operator is supported in newer opset version.\n\nExample:\n```python\ntorch.onnx.export(model, args, ..., opset_version=9)\n```\n",
402
+ },
403
+ "message_strings": {
404
+ "default": {
405
+ "text": "Exporting the operator '{op_name}' to ONNX opset version {opset_version} is not supported. Support for this operator was added in version {supported_opset_version}, try exporting with this version."
406
+ }
407
+ },
408
+ "help_uri": None,
409
+ "properties": {"deprecated": False, "tags": []},
410
+ }
411
+ ),
412
+ init=False,
413
+ )
414
+ """Operator is supported in newer opset version."""
415
+
416
+ fx_graph_to_onnx: _FxGraphToOnnx = dataclasses.field(
417
+ default=_FxGraphToOnnx.from_sarif(
418
+ **{
419
+ "id": "FXE0007",
420
+ "name": "fx-graph-to-onnx",
421
+ "short_description": {
422
+ "text": "Transforms graph from FX IR to ONNX IR."
423
+ },
424
+ "full_description": {
425
+ "text": "Transforms graph from FX IR to ONNX IR.",
426
+ "markdown": "This diagnostic tracks the transformation process from an FX Graph (in FX IR) to an ONNX Graph (in ONNX IR).\n\n## Key Representations:\n\n- **FX Graph**: The graph in FX IR produced by dynamo or symbolic tracing.\n- **ONNX Graph**: The graph in ONNX IR and [operators](https://onnx.ai/onnx/operators/).\n\n## Additional Notes:\n\n- Prior to this transformation step, the FX graph undergoes preprocessing through multiple FX passes.\n To gain insight into these transformations, refer to diagnostic `FXE0010`.\n- To enable a detailed view of the graph transformation in progress within this diagnostic, switch to the DEBUG mode.\n\n - Set DiagnosticOptions.verbosity_level to logging.DEBUG.\n - Activate the environment variable TORCH_LOGS='onnx_diagnostics'.\n\n- For specific information related to node-level FX to ONNX transformations, explore the diagnostic `FXE0008`.\n",
427
+ },
428
+ "message_strings": {
429
+ "default": {
430
+ "text": "Transforming FX graph {graph_name} to ONNX graph."
431
+ }
432
+ },
433
+ "help_uri": None,
434
+ "properties": {"deprecated": False, "tags": []},
435
+ }
436
+ ),
437
+ init=False,
438
+ )
439
+ """Transforms graph from FX IR to ONNX IR."""
440
+
441
+ fx_node_to_onnx: _FxNodeToOnnx = dataclasses.field(
442
+ default=_FxNodeToOnnx.from_sarif(
443
+ **{
444
+ "id": "FXE0008",
445
+ "name": "fx-node-to-onnx",
446
+ "short_description": {"text": "Transforms an FX node to an ONNX node."},
447
+ "full_description": {
448
+ "text": "Transforms an FX node to an ONNX node.",
449
+ "markdown": "This diagnostic tracks the transformation process from an FX Node to ONNX [Operators](https://onnx.ai/onnx/operators/).\n\nThe process of converting FX Node to ONNX Node involves dealing with six distinct node types:\n 1. `placeholder`: Represents a module input, maps to an ONNX graph input.\n 2. `call_module`: Symbolizes a call to a submodule, maps to an ONNX\n 3. `call_method`: Symbolizes a method call. Not yet implemented.\n 4. `call_function`: Symbolizes a function call. [Core ATen](https://pytorch.org/docs/stable/ir.html#core-aten-ir) is expected\n as the function call target. The mapping from ATen to ONNX is implemented by [ONNXScript torchlib](https://github.com/microsoft/onnxscript/tree/main/onnxscript/function_libs/torch_lib/ops).\n This [guide](https://pytorch.org/docs/stable/onnx.html#onnx-script-functions) shows how to write and register a custom symbolic function for call_function FX node.\n 5. `get_attr`: Indicates an attribute access within the current module. Maps to an ONNX graph initializer.\n 6. `output`: Represents the module's output. Maps to an ONNX graph output.\n\nFor a granular understanding of how each node type is transformed, refer to the implementation details in `FxOnnxInterpreter`.\n",
450
+ },
451
+ "message_strings": {
452
+ "default": {
453
+ "text": "Transforming FX node {node_repr} to ONNX node."
454
+ }
455
+ },
456
+ "help_uri": None,
457
+ "properties": {"deprecated": False, "tags": []},
458
+ }
459
+ ),
460
+ init=False,
461
+ )
462
+ """Transforms an FX node to an ONNX node."""
463
+
464
+ fx_pass: _FxPass = dataclasses.field(
465
+ default=_FxPass.from_sarif(
466
+ **{
467
+ "id": "FXE0010",
468
+ "name": "fx-pass",
469
+ "short_description": {
470
+ "text": "FX graph transformation during ONNX export before converting from FX IR to ONNX IR."
471
+ },
472
+ "full_description": {
473
+ "text": "FX graph transformation during ONNX export before converting from FX IR to ONNX IR.",
474
+ "markdown": "This diagnostic tracks the FX passes executed during the ONNX export process prior\nto converting from FX IR (Intermediate Representation) to ONNX IR.\n\nUnder the scope of ONNX export, an FX pass refers to a specific transformation applied to the FX GraphModule.\nThe primary aim of these passes is to streamline the graph into a format that aligns more with the ONNX IR.\nMoreover, these passes work to substitute unsupported FX IR features with those recognized and endorsed by\nONNX IR. Common transformations include, but aren't limited to, decomposition, functionalization and\ntype promotion.\n\nFor those who are interested in a comprehensive log detailing the modifications made during these passes,\nthere are a couple of options:\n\n- Set DiagnosticOptions.verbosity_level to logging.DEBUG.\n- Activate the environment variable TORCH_LOGS='onnx_diagnostics'.\n\nHowever, it's noteworthy that by default, such detailed logging is turned off. The primary reason being\nits considerable impact on performance.\n\nFor an in-depth understanding of each specific pass, please refer to the directory: torch/onnx/_internal/fx/passes.\n",
475
+ },
476
+ "message_strings": {"default": {"text": "Running {pass_name} pass."}},
477
+ "help_uri": None,
478
+ "properties": {"deprecated": False, "tags": []},
479
+ }
480
+ ),
481
+ init=False,
482
+ )
483
+ """FX graph transformation during ONNX export before converting from FX IR to ONNX IR."""
484
+
485
+ no_symbolic_function_for_call_function: _NoSymbolicFunctionForCallFunction = dataclasses.field(
486
+ default=_NoSymbolicFunctionForCallFunction.from_sarif(
487
+ **{
488
+ "id": "FXE0011",
489
+ "name": "no-symbolic-function-for-call-function",
490
+ "short_description": {
491
+ "text": 'Cannot find symbolic function to convert the "call_function" FX node to ONNX.'
492
+ },
493
+ "full_description": {
494
+ "text": 'Cannot find symbolic function to convert the "call_function" FX node to ONNX. ',
495
+ "markdown": 'This error occurs when the ONNX converter is unable to find a corresponding symbolic function\nto convert a "call_function" node in the input graph to its equivalence in ONNX. The "call_function"\nnode represents a normalized function call in PyTorch, such as "torch.aten.ops.add".\n\nTo resolve this error, you can try one of the following:\n\n- If exists, apply the auto-fix suggested by the diagnostic. TODO: this part is not available yet.\n- Rewrite the model using only supported PyTorch operators or functions.\n- Follow this [guide](https://pytorch.org/tutorials/beginner/onnx/onnx_registry_tutorial.html#overview) to write and\n register a custom symbolic function for the unsupported call_function FX node.\n',
496
+ },
497
+ "message_strings": {
498
+ "default": {
499
+ "text": 'No symbolic function to convert the "call_function" node {target} to ONNX. '
500
+ }
501
+ },
502
+ "help_uri": None,
503
+ "properties": {"deprecated": False, "tags": []},
504
+ }
505
+ ),
506
+ init=False,
507
+ )
508
+ """Cannot find symbolic function to convert the "call_function" FX node to ONNX."""
509
+
510
+ unsupported_fx_node_analysis: _UnsupportedFxNodeAnalysis = dataclasses.field(
511
+ default=_UnsupportedFxNodeAnalysis.from_sarif(
512
+ **{
513
+ "id": "FXE0012",
514
+ "name": "unsupported-fx-node-analysis",
515
+ "short_description": {
516
+ "text": "Result from FX graph analysis to reveal unsupported FX nodes."
517
+ },
518
+ "full_description": {
519
+ "text": "Result from FX graph analysis to reveal unsupported FX nodes.",
520
+ "markdown": "This error indicates that an FX graph contains one or more unsupported nodes. The error message\nis typically accompanied by a list of the unsupported nodes found during analysis.\n\nTo resolve this error, you can try resolving each individual unsupported node error by following\nthe suggestions by its diagnostic. Typically, options include:\n\n- If exists, apply the auto-fix suggested by the diagnostic. TODO: this part is not available yet.\n- Rewrite the model using only supported PyTorch operators or functions.\n- Follow this [guide](https://pytorch.org/docs/stable/onnx.html#onnx-script-functions) to write and\n register a custom symbolic function for the unsupported call_function FX node.\n",
521
+ },
522
+ "message_strings": {
523
+ "default": {
524
+ "text": "Unsupported FX nodes: {node_op_to_target_mapping}. "
525
+ }
526
+ },
527
+ "help_uri": None,
528
+ "properties": {"deprecated": False, "tags": []},
529
+ }
530
+ ),
531
+ init=False,
532
+ )
533
+ """Result from FX graph analysis to reveal unsupported FX nodes."""
534
+
535
+ op_level_debugging: _OpLevelDebugging = dataclasses.field(
536
+ default=_OpLevelDebugging.from_sarif(
537
+ **{
538
+ "id": "FXE0013",
539
+ "name": "op-level-debugging",
540
+ "short_description": {
541
+ "text": "Report any op level validation failure in warnings."
542
+ },
543
+ "full_description": {
544
+ "text": "Report any op level validation failure in warnings.",
545
+ "markdown": "This warning message indicates that during op level debugging, certain symbolic functions\nhave failed to match the results of torch ops when using real tensors generated from fake\ntensors. It is important to note that the symbolic functions may not necessarily be\nincorrect, as the validation process is non-deterministic and should only be used as a\nreference.\n\nThere are two categories of warnings that can be triggered:\n\n1. Non-validated operators:\n If the warnings are caused by the following errors, they can be disregarded by users,\n as these errors occur due to the non-deterministic nature of the validation. However,\n it is important to be aware that the operators have not been validated.\n\n - IndexError: Unsupported input arguments of randomized dimensions/indices(INT64).\n - RuntimeError: Unsupported input arguments for torch ops are generated.\n - ValueError: Arguments/keyword arguments do not match the signature of the symbolic function.\n\n2. Potentially wrong torchlib operators:\n If the warnings are triggered by the following error, users should be aware that the symbolic functions\n may be incorrect in dispatching or implementation. In such cases, it is recommended to report\n the issue to the PyTorch-ONNX team, or create/register a custom symbolic function to replace the default one.\n\n - AssertionError: The symbolic function is potentially wrong as the results do not match the results of torch ops.\n - TypeError: The symbolic function is potentially wrong as the opschema doesn't match inputs.\n",
546
+ },
547
+ "message_strings": {
548
+ "default": {
549
+ "text": "FX node: {node} and its onnx function: {symbolic_fn} fails on op level validation."
550
+ }
551
+ },
552
+ "help_uri": None,
553
+ "properties": {"deprecated": False, "tags": []},
554
+ }
555
+ ),
556
+ init=False,
557
+ )
558
+ """Report any op level validation failure in warnings."""
559
+
560
+ find_opschema_matched_symbolic_function: _FindOpschemaMatchedSymbolicFunction = dataclasses.field(
561
+ default=_FindOpschemaMatchedSymbolicFunction.from_sarif(
562
+ **{
563
+ "id": "FXE0014",
564
+ "name": "find-opschema-matched-symbolic-function",
565
+ "short_description": {
566
+ "text": "Find the OnnxFunction that matches the input/attribute dtypes by comparing them with their opschemas."
567
+ },
568
+ "full_description": {
569
+ "text": "Find the OnnxFunction that matches the input dtypes by comparing them with their opschemas. A warning will be issued if the matched OnnxFunction is not an exact match.",
570
+ "markdown": "When an ATen/Custom operator is registered and needs to be dispatched to an OnnxFunction, the input/attribute\ndtypes of the ATen/Custom operator are compared with the input/attribute dtypes of the OnnxFunction opschemas\nto find a match. However, if a perfect/exact match is not found, the dispatcher will attempt to find\nthe nearest match with the highest number of input/attribute dtypes matching the OnnxFunction opschemas, while\nissuing a warning.\n\nThere are two types of level that can be triggered in this rule:\n\n1. NOTE: A perfect match is found, and no warning is issued.\n2. WARNING: The matched OnnxFunction is not a perfect/exact match.\n\nHere are some suggestions based on the WARNING situation:\n\n1. If there are NO errors or mismatches in the results, it is safe to disregard this warning,\n as the definition of OnnxFunction schema is usually more stringent.\n2. If there are errors or mismatches in the results, it is recommended to:\n (a) Enable op_level_debugging to determine if the OnnxFunction might be incorrect.\n (b) Report the issue to the PyTorch-ONNX team.\n (c) Create/register a custom symbolic function to replace the default one.\n",
571
+ },
572
+ "message_strings": {
573
+ "default": {
574
+ "text": "The OnnxFunction: {symbolic_fn} is the nearest match of the node {node}."
575
+ }
576
+ },
577
+ "help_uri": None,
578
+ "properties": {"deprecated": False, "tags": []},
579
+ }
580
+ ),
581
+ init=False,
582
+ )
583
+ """Find the OnnxFunction that matches the input/attribute dtypes by comparing them with their opschemas."""
584
+
585
+ fx_node_insert_type_promotion: _FxNodeInsertTypePromotion = dataclasses.field(
586
+ default=_FxNodeInsertTypePromotion.from_sarif(
587
+ **{
588
+ "id": "FXE0015",
589
+ "name": "fx-node-insert-type-promotion",
590
+ "short_description": {
591
+ "text": "Determine if type promotion is required for the FX node. Insert cast nodes if needed."
592
+ },
593
+ "full_description": {
594
+ "text": "Determine if type promotion is required for the FX node. Insert cast nodes if needed.",
595
+ "markdown": "This diagnostic monitors the node-level type promotion insertion process. In PyTorch, there is an automatic process called implicit type promotion,\nwhere the input types of an operator are promoted to a common type. The determination of the common type is based on the type promotion rule specific to each operator.\nTo learn more about PyTorch's type promotion rules, refer to the [elementwise_dtypes doc](https://github.com/pytorch/pytorch/blob/f044613f78df713fb57f70c608483c9f10ad332e/torch/_prims_common/__init__.py#L1252-L1335)\nand [torch._refs ops](https://github.com/pytorch/pytorch/blob/a475ea4542dfe961c9d097e33ab5041f61c8c17f/torch/_refs/__init__.py#L484).\n\nHowever, implicit type promotion is not supported in ONNX. Therefore, to replicate the PyTorch behavior, we need to explicitly insert cast nodes.\nThis diagnostic tracks the process of node-level type promotion insertion.\n\nThe type promotion rules used by this process can be found in `torch/onnx/_internal/fx/passes/type_promotion.py.`\nTo update or add new type promotion rules, please refer to the [Note: Update type promotion rule] section.\n",
596
+ },
597
+ "message_strings": {
598
+ "default": {
599
+ "text": "Performing explicit type promotion for node {target}. "
600
+ }
601
+ },
602
+ "help_uri": None,
603
+ "properties": {"deprecated": False, "tags": []},
604
+ }
605
+ ),
606
+ init=False,
607
+ )
608
+ """Determine if type promotion is required for the FX node. Insert cast nodes if needed."""
609
+
610
+ find_operator_overloads_in_onnx_registry: _FindOperatorOverloadsInOnnxRegistry = dataclasses.field(
611
+ default=_FindOperatorOverloadsInOnnxRegistry.from_sarif(
612
+ **{
613
+ "id": "FXE0016",
614
+ "name": "find-operator-overloads-in-onnx-registry",
615
+ "short_description": {
616
+ "text": "Find the list of OnnxFunction of the PyTorch operator in onnx registry."
617
+ },
618
+ "full_description": {
619
+ "text": "This rule involves finding the list of OnnxFunction for the PyTorch operator overload in the ONNX registry. If the operator overload is not supported but its default overload is, a warning will be issued. If both the operator overload and its default overload are not supported, an error will be issued.",
620
+ "markdown": "The operator overload name serves the purpose of verifying whether a PyTorch operator is registered in the ONNX registry.\nIf it's not found, the dispatcher takes a fallback approach and tries to locate the default overload of the PyTorch\noperator in the registry. If even the default overload is absent, it signifies that the operator is officially unsupported.\n\nThere are three types of level that can be triggered in this rule:\n\n1. NOTE: The op overload is supported.\n2. WARNING: The op overload is not supported, but it's default overload is supported.\n3. ERROR: The op overload is not supported, and it's default overload is also not supported.\n\nHere are some suggestions based on the WARNING situation:\n\n1. If there are NO errors or mismatches in the results, it is safe to disregard this warning.\n2. If there are errors or mismatches in the results, it is recommended to:\n (a) Enable op_level_debugging to determine if the OnnxFunction might be incorrect.\n (b) Report the unsupported overload to the PyTorch-ONNX team.\n (c) Create/register a custom symbolic function to replace the default one.\n\nHere are some suggestions based on the ERROR situation:\n\n1. Report the unsupported operator to the PyTorch-ONNX team.\n2. Create/register a custom symbolic function to replace the default one.\n",
621
+ },
622
+ "message_strings": {
623
+ "default": {
624
+ "text": "Checking if the FX node: {node} is supported in onnx registry."
625
+ }
626
+ },
627
+ "help_uri": None,
628
+ "properties": {"deprecated": False, "tags": []},
629
+ }
630
+ ),
631
+ init=False,
632
+ )
633
+ """Find the list of OnnxFunction of the PyTorch operator in onnx registry."""
634
+
635
+
636
+ rules = _POERules()
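A hedged sketch of how the generated collection above is typically consumed, producing the (rule, level, message) triple the diagnostic infrastructure expects; all argument values are illustrative:

    rule, level, message = rules.operator_supported_in_newer_opset_version.format(
        infra.Level.ERROR,
        op_name="aten::demo_op",     # illustrative operator
        opset_version=9,             # illustrative requested opset
        supported_opset_version=14,  # illustrative supported opset
    )
    assert rule.id == "POE0004"
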
.venv/Lib/site-packages/torch/onnx/_internal/diagnostics/infra/_infra.py ADDED
@@ -0,0 +1,285 @@
1
+ # mypy: allow-untyped-defs
2
+ """This file defines an additional layer of abstraction on top of the SARIF OM."""
3
+
4
+ from __future__ import annotations
5
+
6
+ import dataclasses
7
+ import enum
8
+ import logging
9
+ from typing import Mapping, Sequence
10
+
11
+ from torch.onnx._internal.diagnostics.infra import formatter, sarif
12
+
13
+
14
+ class Level(enum.IntEnum):
15
+ """The level of a diagnostic.
16
+
17
+ This class is used to represent the level of a diagnostic. The levels are defined
18
+ by the SARIF specification, and are not modifiable. For alternative categories,
19
+ please use infra.Tag instead. When selecting a level, please consider the following
20
+ guidelines:
21
+
22
+ - NONE: Informational result that does not indicate the presence of a problem.
23
+ - NOTE: An opportunity for improvement was found.
24
+ - WARNING: A potential problem was found.
25
+ - ERROR: A serious problem was found.
26
+
27
+ This class is a subclass of enum.IntEnum, and can be used as an integer. Its integer
28
+ value maps to the logging levels in Python's logging module. The mapping is as
29
+ follows:
30
+
31
+ Level.NONE = logging.DEBUG = 10
32
+ Level.NOTE = logging.INFO = 20
33
+ Level.WARNING = logging.WARNING = 30
34
+ Level.ERROR = logging.ERROR = 40
35
+ """
36
+
37
+ NONE = 10
38
+ NOTE = 20
39
+ WARNING = 30
40
+ ERROR = 40
41
+
42
+
43
+ levels = Level
44
+
45
+
46
+ class Tag(enum.Enum):
47
+ """The tag of a diagnostic. This class can be inherited to define custom tags."""
48
+
49
+
50
+ class PatchedPropertyBag(sarif.PropertyBag):
51
+ """Key/value pairs that provide additional information about the object.
52
+
53
+ The definition of PropertyBag via SARIF spec is "A property bag is an object (section 3.6)
54
+ containing an unordered set of properties with arbitrary names." However it is not
55
+ reflected in the json file, and therefore not captured by the python representation.
56
+ This patch adds additional **kwargs to the `__init__` method to allow recording
57
+ arbitrary key/value pairs.
58
+ """
59
+
60
+ def __init__(self, tags: list[str] | None = None, **kwargs):
61
+ super().__init__(tags=tags)
62
+ self.__dict__.update(kwargs)
63
+
64
+
65
+ @dataclasses.dataclass(frozen=True)
66
+ class Rule:
67
+ id: str
68
+ name: str
69
+ message_default_template: str
70
+ short_description: str | None = None
71
+ full_description: str | None = None
72
+ full_description_markdown: str | None = None
73
+ help_uri: str | None = None
74
+
75
+ @classmethod
76
+ def from_sarif(cls, **kwargs):
77
+ """Returns a rule from the SARIF reporting descriptor."""
78
+ short_description = kwargs.get("short_description", {}).get("text")
79
+ full_description = kwargs.get("full_description", {}).get("text")
80
+ full_description_markdown = kwargs.get("full_description", {}).get("markdown")
81
+ help_uri = kwargs.get("help_uri")
82
+
83
+ rule = cls(
84
+ id=kwargs["id"],
85
+ name=kwargs["name"],
86
+ message_default_template=kwargs["message_strings"]["default"]["text"],
87
+ short_description=short_description,
88
+ full_description=full_description,
89
+ full_description_markdown=full_description_markdown,
90
+ help_uri=help_uri,
91
+ )
92
+ return rule
93
+
94
+ def sarif(self) -> sarif.ReportingDescriptor:
95
+ """Returns a SARIF reporting descriptor of this Rule."""
96
+ short_description = (
97
+ sarif.MultiformatMessageString(text=self.short_description)
98
+ if self.short_description is not None
99
+ else None
100
+ )
101
+ full_description = (
102
+ sarif.MultiformatMessageString(
103
+ text=self.full_description, markdown=self.full_description_markdown
104
+ )
105
+ if self.full_description is not None
106
+ else None
107
+ )
108
+ return sarif.ReportingDescriptor(
109
+ id=self.id,
110
+ name=self.name,
111
+ short_description=short_description,
112
+ full_description=full_description,
113
+ help_uri=self.help_uri,
114
+ )
115
+
116
+ def format(self, level: Level, *args, **kwargs) -> tuple[Rule, Level, str]:
117
+ """Returns a tuple of (rule, level, message) for a diagnostic.
118
+
119
+ This method is used to format the message of a diagnostic. The message is
120
+ formatted using the default template of this rule, and the arguments passed in
121
+ as `*args` and `**kwargs`. The level is used to override the default level of
122
+ this rule.
123
+ """
124
+ return (self, level, self.format_message(*args, **kwargs))
125
+
126
+ def format_message(self, *args, **kwargs) -> str:
127
+ """Returns the formatted default message of this Rule.
128
+
129
+ This method should be overridden (with code generation) by subclasses to reflect
130
+ the exact arguments needed by the message template. This is a helper method to
131
+ create the default message for a diagnostic.
132
+ """
133
+ return self.message_default_template.format(*args, **kwargs)
134
+
135
+
136
+ @dataclasses.dataclass
137
+ class Location:
138
+ uri: str | None = None
139
+ line: int | None = None
140
+ message: str | None = None
141
+ start_column: int | None = None
142
+ end_column: int | None = None
143
+ snippet: str | None = None
144
+ function: str | None = None
145
+
146
+ def sarif(self) -> sarif.Location:
147
+ """Returns the SARIF representation of this location."""
148
+ return sarif.Location(
149
+ physical_location=sarif.PhysicalLocation(
150
+ artifact_location=sarif.ArtifactLocation(uri=self.uri),
151
+ region=sarif.Region(
152
+ start_line=self.line,
153
+ start_column=self.start_column,
154
+ end_column=self.end_column,
155
+ snippet=sarif.ArtifactContent(text=self.snippet),
156
+ ),
157
+ ),
158
+ message=sarif.Message(text=self.message)
159
+ if self.message is not None
160
+ else None,
161
+ )
162
+
163
+
164
+ @dataclasses.dataclass
165
+ class StackFrame:
166
+ location: Location
167
+
168
+ def sarif(self) -> sarif.StackFrame:
169
+ """Returns the SARIF representation of this stack frame."""
170
+ return sarif.StackFrame(location=self.location.sarif())
171
+
172
+
173
+ @dataclasses.dataclass
174
+ class Stack:
175
+ """Records a stack trace. The frames are in order from newest to oldest stack frame."""
176
+
177
+ frames: list[StackFrame] = dataclasses.field(default_factory=list)
178
+ message: str | None = None
179
+
180
+ def sarif(self) -> sarif.Stack:
181
+ """Returns the SARIF representation of this stack."""
182
+ return sarif.Stack(
183
+ frames=[frame.sarif() for frame in self.frames],
184
+ message=sarif.Message(text=self.message)
185
+ if self.message is not None
186
+ else None,
187
+ )
188
+
189
+
190
+ @dataclasses.dataclass
191
+ class ThreadFlowLocation:
192
+ """Records code location and the initial state."""
193
+
194
+ location: Location
195
+ state: Mapping[str, str]
196
+ index: int
197
+ stack: Stack | None = None
198
+
199
+ def sarif(self) -> sarif.ThreadFlowLocation:
200
+ """Returns the SARIF representation of this thread flow location."""
201
+ return sarif.ThreadFlowLocation(
202
+ location=self.location.sarif(),
203
+ state=self.state,
204
+ stack=self.stack.sarif() if self.stack is not None else None,
205
+ )
206
+
207
+
208
+ @dataclasses.dataclass
209
+ class Graph:
210
+ """A graph of diagnostics.
211
+
212
+ This class stores the string representation of a model graph.
213
+ The `nodes` and `edges` fields are unused in the current implementation.
214
+ """
215
+
216
+ graph: str
217
+ name: str
218
+ description: str | None = None
219
+
220
+ def sarif(self) -> sarif.Graph:
221
+ """Returns the SARIF representation of this graph."""
222
+ return sarif.Graph(
223
+ description=sarif.Message(text=self.graph),
224
+ properties=PatchedPropertyBag(name=self.name, description=self.description),
225
+ )
226
+
227
+
228
+ @dataclasses.dataclass
229
+ class RuleCollection:
230
+ _rule_id_name_set: frozenset[tuple[str, str]] = dataclasses.field(init=False)
231
+
232
+ def __post_init__(self) -> None:
233
+ self._rule_id_name_set = frozenset(
234
+ {
235
+ (field.default.id, field.default.name)
236
+ for field in dataclasses.fields(self)
237
+ if isinstance(field.default, Rule)
238
+ }
239
+ )
240
+
241
+ def __contains__(self, rule: Rule) -> bool:
242
+ """Checks if the rule is in the collection."""
243
+ return (rule.id, rule.name) in self._rule_id_name_set
244
+
245
+ @classmethod
246
+ def custom_collection_from_list(
247
+ cls, new_collection_class_name: str, rules: Sequence[Rule]
248
+ ) -> RuleCollection:
249
+ """Creates a custom class inherited from RuleCollection with the list of rules."""
250
+ return dataclasses.make_dataclass(
251
+ new_collection_class_name,
252
+ [
253
+ (
254
+ formatter.kebab_case_to_snake_case(rule.name),
255
+ type(rule),
256
+ dataclasses.field(default=rule),
257
+ )
258
+ for rule in rules
259
+ ],
260
+ bases=(cls,),
261
+ )()
262
+
263
+
264
+ class Invocation:
265
+ # TODO: Implement this.
266
+ # Tracks top level call arguments and diagnostic options.
267
+ def __init__(self) -> None:
268
+ raise NotImplementedError
269
+
270
+
271
+ @dataclasses.dataclass
272
+ class DiagnosticOptions:
273
+ """Options for diagnostic context.
274
+
275
+ Attributes:
276
+ verbosity_level: Set the amount of information logged for each diagnostics,
277
+ equivalent to the 'level' in Python logging module.
278
+ warnings_as_errors: When True, warning diagnostics are treated as error diagnostics.
279
+ """
280
+
281
+ verbosity_level: int = dataclasses.field(default=logging.INFO)
282
+ """Set the amount of information logged for each diagnostics, equivalent to the 'level' in Python logging module."""
283
+
284
+ warnings_as_errors: bool = dataclasses.field(default=False)
285
+ """If True, warning diagnostics are treated as error diagnostics."""
.venv/Lib/site-packages/torch/onnx/_internal/diagnostics/infra/context.py ADDED
@@ -0,0 +1,404 @@
1
+ # mypy: allow-untyped-defs
2
+ """A diagnostic context based on SARIF."""
3
+
4
+ from __future__ import annotations
5
+
6
+ import contextlib
7
+ import dataclasses
8
+ import gzip
9
+ import logging
10
+ from typing import Callable, Generator, Generic, Literal, Mapping, TypeVar
11
+ from typing_extensions import Self
12
+
13
+ from torch.onnx._internal.diagnostics import infra
14
+ from torch.onnx._internal.diagnostics.infra import formatter, sarif, utils
15
+ from torch.onnx._internal.diagnostics.infra.sarif import version as sarif_version
16
+
17
+
18
+ # This is a workaround for mypy not supporting Self from typing_extensions.
19
+ _Diagnostic = TypeVar("_Diagnostic", bound="Diagnostic")
20
+ diagnostic_logger: logging.Logger = logging.getLogger(__name__)
21
+
22
+
23
+ @dataclasses.dataclass
24
+ class Diagnostic:
25
+ rule: infra.Rule
26
+ level: infra.Level
27
+ message: str | None = None
28
+ locations: list[infra.Location] = dataclasses.field(default_factory=list)
29
+ stacks: list[infra.Stack] = dataclasses.field(default_factory=list)
30
+ graphs: list[infra.Graph] = dataclasses.field(default_factory=list)
31
+ thread_flow_locations: list[infra.ThreadFlowLocation] = dataclasses.field(
32
+ default_factory=list
33
+ )
34
+ additional_messages: list[str] = dataclasses.field(default_factory=list)
35
+ tags: list[infra.Tag] = dataclasses.field(default_factory=list)
36
+ source_exception: Exception | None = None
37
+ """The exception that caused this diagnostic to be created."""
38
+ logger: logging.Logger = dataclasses.field(init=False, default=diagnostic_logger)
39
+ """The logger for this diagnostic. Defaults to 'diagnostic_logger' which has the same
40
+ log level setting with `DiagnosticOptions.verbosity_level`."""
41
+ _current_log_section_depth: int = 0
42
+
43
+ def __post_init__(self) -> None:
44
+ pass
45
+
46
+ def sarif(self) -> sarif.Result:
47
+ """Returns the SARIF Result representation of this diagnostic."""
48
+ message = self.message or self.rule.message_default_template
49
+ if self.additional_messages:
50
+ additional_message = "\n".join(self.additional_messages)
51
+ message_markdown = (
52
+ f"{message}\n\n## Additional Message:\n\n{additional_message}"
53
+ )
54
+ else:
55
+ message_markdown = message
56
+
57
+ kind: Literal["informational", "fail"] = (
58
+ "informational" if self.level == infra.Level.NONE else "fail"
59
+ )
60
+
61
+ sarif_result = sarif.Result(
62
+ message=sarif.Message(text=message, markdown=message_markdown),
63
+ level=self.level.name.lower(), # type: ignore[arg-type]
64
+ rule_id=self.rule.id,
65
+ kind=kind,
66
+ )
67
+ sarif_result.locations = [location.sarif() for location in self.locations]
68
+ sarif_result.stacks = [stack.sarif() for stack in self.stacks]
69
+ sarif_result.graphs = [graph.sarif() for graph in self.graphs]
70
+ sarif_result.code_flows = [
71
+ sarif.CodeFlow(
72
+ thread_flows=[
73
+ sarif.ThreadFlow(
74
+ locations=[loc.sarif() for loc in self.thread_flow_locations]
75
+ )
76
+ ]
77
+ )
78
+ ]
79
+ sarif_result.properties = sarif.PropertyBag(
80
+ tags=[tag.value for tag in self.tags]
81
+ )
82
+ return sarif_result
83
+
84
+ def with_location(self: Self, location: infra.Location) -> Self:
85
+ """Adds a location to the diagnostic."""
86
+ self.locations.append(location)
87
+ return self
88
+
89
+ def with_thread_flow_location(
90
+ self: Self, location: infra.ThreadFlowLocation
91
+ ) -> Self:
92
+ """Adds a thread flow location to the diagnostic."""
93
+ self.thread_flow_locations.append(location)
94
+ return self
95
+
96
+ def with_stack(self: Self, stack: infra.Stack) -> Self:
97
+ """Adds a stack to the diagnostic."""
98
+ self.stacks.append(stack)
99
+ return self
100
+
101
+ def with_graph(self: Self, graph: infra.Graph) -> Self:
102
+ """Adds a graph to the diagnostic."""
103
+ self.graphs.append(graph)
104
+ return self
105
+
106
+ @contextlib.contextmanager
107
+ def log_section(
108
+ self, level: int, message: str, *args, **kwargs
109
+ ) -> Generator[None, None, None]:
110
+ """
111
+ Context manager for a section of log messages, denoted by a title message and increased indentation.
112
+
113
+ Same API as `logging.Logger.log`.
114
+
115
+ This context manager logs the given title at the specified log level, increases the current
116
+ section depth for subsequent log messages, and ensures that the section depth is decreased
117
+ again when exiting the context.
118
+
119
+ Args:
120
+ level: The log level.
121
+ message: The title message to log.
122
+ *args: The arguments to the message. Use `LazyString` to defer the
123
+ expensive evaluation of the arguments until the message is actually logged.
124
+ **kwargs: The keyword arguments for `logging.Logger.log`.
125
+
126
+ Yields:
127
+ None: This context manager does not yield any value.
128
+
129
+ Example:
130
+ >>> with DiagnosticContext("DummyContext", "1.0"):
131
+ ... rule = infra.Rule("RuleID", "DummyRule", "Rule message")
132
+ ... diagnostic = Diagnostic(rule, infra.Level.WARNING)
133
+ ... with diagnostic.log_section(logging.INFO, "My Section"):
134
+ ... diagnostic.log(logging.INFO, "My Message")
135
+ ... with diagnostic.log_section(logging.INFO, "My Subsection"):
136
+ ... diagnostic.log(logging.INFO, "My Submessage")
137
+ ... diagnostic.additional_messages
138
+ ['## My Section', 'My Message', '### My Subsection', 'My Submessage']
139
+ """
140
+ if self.logger.isEnabledFor(level):
141
+ indented_format_message = (
142
+ f"##{'#' * self._current_log_section_depth } {message}"
143
+ )
144
+ self.log(
145
+ level,
146
+ indented_format_message,
147
+ *args,
148
+ **kwargs,
149
+ )
150
+ self._current_log_section_depth += 1
151
+ try:
152
+ yield
153
+ finally:
154
+ self._current_log_section_depth -= 1
155
+
156
+ def log(self, level: int, message: str, *args, **kwargs) -> None:
157
+ """Logs a message within the diagnostic. Same api as `logging.Logger.log`.
158
+
159
+ If logger is not enabled for the given level, the message will not be logged.
160
+ Otherwise, the message will be logged and also added to the diagnostic's additional_messages.
161
+
162
+ The default setting for `DiagnosticOptions.verbosity_level` is `logging.INFO`. Based on this default,
163
+ the log level recommendations are as follows. If you've set a different default verbosity level in your
164
+ application, please adjust accordingly:
165
+
166
+ - logging.ERROR: Log any events leading to application failure.
167
+ - logging.WARNING: Log events that might result in application issues or failures, although not guaranteed.
168
+ - logging.INFO: Log general useful information, ensuring minimal performance overhead.
169
+ - logging.DEBUG: Log detailed debug information, which might affect performance when logged.
170
+
171
+ Args:
172
+ level: The log level.
173
+ message: The message to log.
174
+ *args: The arguments to the message. Use `LazyString` to defer the
175
+ expensive evaluation of the arguments until the message is actually logged.
176
+ **kwargs: The keyword arguments for `logging.Logger.log`.
177
+ """
178
+ if self.logger.isEnabledFor(level):
179
+ formatted_message = message % args
180
+ self.logger.log(level, formatted_message, **kwargs)
181
+ self.additional_messages.append(formatted_message)
182
+
183
+ def debug(self, message: str, *args, **kwargs) -> None:
184
+ """Logs a debug message within the diagnostic. Same api as logging.Logger.debug.
185
+
186
+ Checkout `log` for more details.
187
+ """
188
+ self.log(logging.DEBUG, message, *args, **kwargs)
189
+
190
+ def info(self, message: str, *args, **kwargs) -> None:
191
+ """Logs an info message within the diagnostic. Same api as logging.Logger.info.
192
+
193
+ Checkout `log` for more details.
194
+ """
195
+ self.log(logging.INFO, message, *args, **kwargs)
196
+
197
+ def warning(self, message: str, *args, **kwargs) -> None:
198
+ """Logs a warning message within the diagnostic. Same api as logging.Logger.warning.
199
+
200
+ Checkout `log` for more details.
201
+ """
202
+ self.log(logging.WARNING, message, *args, **kwargs)
203
+
204
+ def error(self, message: str, *args, **kwargs) -> None:
205
+ """Logs an error message within the diagnostic. Same api as logging.Logger.error.
206
+
207
+ Checkout `log` for more details.
208
+ """
209
+ self.log(logging.ERROR, message, *args, **kwargs)
210
+
211
+ def log_source_exception(self, level: int, exception: Exception) -> None:
212
+ """Logs a source exception within the diagnostic.
213
+
214
+ Invokes `log_section` and `log` to log the exception in markdown section format.
215
+ """
216
+ self.source_exception = exception
217
+ with self.log_section(level, "Exception log"):
218
+ self.log(level, "%s", formatter.lazy_format_exception(exception))
219
+
220
+ def record_python_call_stack(self, frames_to_skip: int) -> infra.Stack:
221
+ """Records the current Python call stack."""
222
+ frames_to_skip += 1 # Skip this function.
223
+ stack = utils.python_call_stack(frames_to_skip=frames_to_skip)
224
+ self.with_stack(stack)
225
+ if len(stack.frames) > 0:
226
+ self.with_location(stack.frames[0].location)
227
+ return stack
228
+
229
+ def record_python_call(
230
+ self,
231
+ fn: Callable,
232
+ state: Mapping[str, str],
233
+ message: str | None = None,
234
+ frames_to_skip: int = 0,
235
+ ) -> infra.ThreadFlowLocation:
236
+ """Records a python call as one thread flow step."""
237
+ frames_to_skip += 1 # Skip this function.
238
+ stack = utils.python_call_stack(frames_to_skip=frames_to_skip, frames_to_log=5)
239
+ location = utils.function_location(fn)
240
+ location.message = message
241
+ # Add function location to the top of the stack.
242
+ stack.frames.insert(0, infra.StackFrame(location=location))
243
+ thread_flow_location = infra.ThreadFlowLocation(
244
+ location=location,
245
+ state=state,
246
+ index=len(self.thread_flow_locations),
247
+ stack=stack,
248
+ )
249
+ self.with_thread_flow_location(thread_flow_location)
250
+ return thread_flow_location
251
+
252
+
253
+ class RuntimeErrorWithDiagnostic(RuntimeError):
254
+ """Runtime error with enclosed diagnostic information."""
255
+
256
+ def __init__(self, diagnostic: Diagnostic):
257
+ super().__init__(diagnostic.message)
258
+ self.diagnostic = diagnostic
259
+
260
+
261
+ @dataclasses.dataclass
262
+ class DiagnosticContext(Generic[_Diagnostic]):
263
+ name: str
264
+ version: str
265
+ options: infra.DiagnosticOptions = dataclasses.field(
266
+ default_factory=infra.DiagnosticOptions
267
+ )
268
+ diagnostics: list[_Diagnostic] = dataclasses.field(init=False, default_factory=list)
269
+ # TODO(bowbao): Implement this.
270
+ # _invocation: infra.Invocation = dataclasses.field(init=False)
271
+ _inflight_diagnostics: list[_Diagnostic] = dataclasses.field(
272
+ init=False, default_factory=list
273
+ )
274
+ _previous_log_level: int = dataclasses.field(init=False, default=logging.WARNING)
275
+ logger: logging.Logger = dataclasses.field(init=False, default=diagnostic_logger)
276
+ _bound_diagnostic_type: type = dataclasses.field(init=False, default=Diagnostic)
277
+
278
+ def __enter__(self):
279
+ self._previous_log_level = self.logger.level
280
+ self.logger.setLevel(self.options.verbosity_level)
281
+ return self
282
+
283
+ def __exit__(self, exc_type, exc_val, exc_tb):
284
+ self.logger.setLevel(self._previous_log_level)
285
+ return None
286
+
287
+ def sarif(self) -> sarif.Run:
288
+ """Returns the SARIF Run object."""
289
+ unique_rules = {diagnostic.rule for diagnostic in self.diagnostics}
290
+ return sarif.Run(
291
+ sarif.Tool(
292
+ driver=sarif.ToolComponent(
293
+ name=self.name,
294
+ version=self.version,
295
+ rules=[rule.sarif() for rule in unique_rules],
296
+ )
297
+ ),
298
+ results=[diagnostic.sarif() for diagnostic in self.diagnostics],
299
+ )
300
+
301
+ def sarif_log(self) -> sarif.SarifLog: # type: ignore[name-defined]
302
+ """Returns the SARIF Log object."""
303
+ return sarif.SarifLog(
304
+ version=sarif_version.SARIF_VERSION,
305
+ schema_uri=sarif_version.SARIF_SCHEMA_LINK,
306
+ runs=[self.sarif()],
307
+ )
308
+
309
+ def to_json(self) -> str:
310
+ return formatter.sarif_to_json(self.sarif_log())
311
+
312
+ def dump(self, file_path: str, compress: bool = False) -> None:
313
+ """Dumps the SARIF log to a file."""
314
+ if compress:
315
+ with gzip.open(file_path, "wt") as f:
316
+ f.write(self.to_json())
317
+ else:
318
+ with open(file_path, "w") as f:
319
+ f.write(self.to_json())
320
+
321
+ def log(self, diagnostic: _Diagnostic) -> None:
322
+ """Logs a diagnostic.
323
+
324
+ This method should be used only after all the necessary information for the diagnostic
325
+ has been collected.
326
+
327
+ Args:
328
+ diagnostic: The diagnostic to add.
329
+ """
330
+ if not isinstance(diagnostic, self._bound_diagnostic_type):
331
+ raise TypeError(
332
+ f"Expected diagnostic of type {self._bound_diagnostic_type}, got {type(diagnostic)}"
333
+ )
334
+ if self.options.warnings_as_errors and diagnostic.level == infra.Level.WARNING: # type: ignore[attr-defined]
335
+ diagnostic.level = infra.Level.ERROR # type: ignore[attr-defined]
336
+ self.diagnostics.append(diagnostic) # type: ignore[arg-type]
337
+
338
+ def log_and_raise_if_error(self, diagnostic: _Diagnostic) -> None:
339
+ """Logs a diagnostic and raises an exception if it is an error.
340
+
341
+ Use this method for logging non inflight diagnostics where diagnostic level is not known or
342
+ lower than ERROR. If it is always expected raise, use `log` and explicit
343
+ `raise` instead. Otherwise there is no way to convey the message that it always
344
+ raises to Python intellisense and type checking tools.
345
+
346
+ This method should be used only after all the necessary information for the diagnostic
347
+ has been collected.
348
+
349
+ Args:
350
+ diagnostic: The diagnostic to add.
351
+ """
352
+ self.log(diagnostic)
353
+ if diagnostic.level == infra.Level.ERROR:
354
+ if diagnostic.source_exception is not None:
355
+ raise diagnostic.source_exception
356
+ raise RuntimeErrorWithDiagnostic(diagnostic)
357
+
358
+ @contextlib.contextmanager
359
+ def add_inflight_diagnostic(
360
+ self, diagnostic: _Diagnostic
361
+ ) -> Generator[_Diagnostic, None, None]:
362
+ """Adds a diagnostic to the context.
363
+
364
+ Use this method to add diagnostics that are not created by the context.
365
+ Args:
366
+ diagnostic: The diagnostic to add.
367
+ """
368
+ self._inflight_diagnostics.append(diagnostic)
369
+ try:
370
+ yield diagnostic
371
+ finally:
372
+ self._inflight_diagnostics.pop()
373
+
374
+ def push_inflight_diagnostic(self, diagnostic: _Diagnostic) -> None:
375
+ """Pushes a diagnostic to the inflight diagnostics stack.
376
+
377
+ Args:
378
+ diagnostic: The diagnostic to push.
379
+
380
+ Raises:
381
+ ValueError: If the rule is not supported by the tool.
382
+ """
383
+ self._inflight_diagnostics.append(diagnostic)
384
+
385
+ def pop_inflight_diagnostic(self) -> _Diagnostic:
386
+ """Pops the last diagnostic from the inflight diagnostics stack.
387
+
388
+ Returns:
389
+ The popped diagnostic.
390
+ """
391
+ return self._inflight_diagnostics.pop()
392
+
393
+ def inflight_diagnostic(self, rule: infra.Rule | None = None) -> _Diagnostic:
394
+ if rule is None:
395
+ # TODO(bowbao): Create builtin-rules and create diagnostic using that.
396
+ if len(self._inflight_diagnostics) <= 0:
397
+ raise AssertionError("No inflight diagnostics")
398
+
399
+ return self._inflight_diagnostics[-1]
400
+ else:
401
+ for diagnostic in reversed(self._inflight_diagnostics):
402
+ if diagnostic.rule == rule:
403
+ return diagnostic
404
+ raise AssertionError(f"No inflight diagnostic for rule {rule.name}")
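Putting the pieces of this module together, here is a minimal usage sketch (reusing the hypothetical rule and names from the doctest above; the SARIF file name is illustrative):

    import logging

    rule = infra.Rule("RuleID", "DummyRule", "Rule message")
    with DiagnosticContext("DummyContext", "1.0") as context:
        diagnostic = Diagnostic(rule, infra.Level.WARNING)
        with diagnostic.log_section(logging.INFO, "My Section"):
            diagnostic.info("detail: %s", "value")
        context.log(diagnostic)  # or log_and_raise_if_error() to raise on ERROR
        context.dump("log.sarif")  # compress=True writes gzip instead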
.venv/Lib/site-packages/torch/onnx/_internal/diagnostics/infra/decorator.py ADDED
@@ -0,0 +1,153 @@
+ # mypy: allow-untyped-defs
+ from __future__ import annotations
+
+ import functools
+ import logging
+ import traceback
+ from typing import Any, Callable, Dict, Tuple
+
+ from torch.onnx._internal.diagnostics import infra
+ from torch.onnx._internal.diagnostics.infra import formatter, utils
+
+
+ MessageFormatterType = Callable[..., str]
+
+
+ def format_message_in_text(fn: Callable, *args: Any, **kwargs: Any) -> str:
+     return f"{formatter.display_name(fn)}. "
+
+
+ def format_exception_in_markdown(exception: Exception) -> str:
+     msg_list = ["### Exception log", "```"]
+     msg_list.extend(
+         traceback.format_exception(type(exception), exception, exception.__traceback__)
+     )
+     msg_list.append("```")
+     return "\n".join(msg_list)
+
+
+ def format_function_signature_in_markdown(
+     fn: Callable,
+     args: tuple[Any, ...],
+     kwargs: dict[str, Any],
+     format_argument: Callable[[Any], str] = formatter.format_argument,
+ ) -> str:
+     msg_list = [f"### Function Signature {formatter.display_name(fn)}"]
+
+     state = utils.function_state(fn, args, kwargs)
+
+     for k, v in state.items():
+         msg_list.append(f"- {k}: {format_argument(v)}")
+
+     return "\n".join(msg_list)
+
+
+ def format_return_values_in_markdown(
+     return_values: Any,
+     format_argument: Callable[[Any], str] = formatter.format_argument,
+ ) -> str:
+     return f"{format_argument(return_values)}"
+
+
+ ModifierCallableType = Callable[
+     [infra.Diagnostic, Callable, Tuple[Any, ...], Dict[str, Any], Any], None
+ ]
+
+
+ def diagnose_call(
+     rule: infra.Rule,
+     *,
+     level: infra.Level = infra.Level.NONE,
+     diagnostic_type: type[infra.Diagnostic] = infra.Diagnostic,
+     format_argument: Callable[[Any], str] = formatter.format_argument,
+     diagnostic_message_formatter: MessageFormatterType = format_message_in_text,
+ ) -> Callable:
+     def decorator(fn):
+         @functools.wraps(fn)
+         def wrapper(*args, **kwargs):
+             common_error_message = "diagnose_call can only be applied to callables"
+             if not callable(fn):
+                 raise AssertionError(
+                     f"{common_error_message}. Got {type(fn)} instead of callable."
+                 )
+             arg0 = args[0] if len(args) > 0 else None
+             if isinstance(ctx := arg0, infra.DiagnosticContext):
+                 pass
+             elif isinstance(
+                 ctx := getattr(arg0, "diagnostic_context", None),
+                 infra.DiagnosticContext,
+             ):
+                 pass
+             else:
+                 # NOTE: At decoration time, we can't tell whether a callable is a
+                 # function or a method; technically both are regarded as functions
+                 # at that point.
+                 raise AssertionError(
+                     f"{common_error_message}. For {fn}: "
+                     f"if it is a function, a DiagnosticContext instance must be present as "
+                     f"the first argument; "
+                     f"if it is a method, a DiagnosticContext instance must be present as "
+                     f"the attribute 'diagnostic_context' of the 'self' argument."
+                 )
+
+             diag = diagnostic_type(
+                 rule,
+                 level,
+                 diagnostic_message_formatter(fn, *args, **kwargs),
+             )
+
+             # Pop the decorator frame.
+             # TODO(bowbao): By default the diagnostic doesn't have a stack,
+             # so we need to check before doing this. Make the code cleaner.
+             # Option: do not capture stack by default in diagnostic initialization.
+             stack: infra.Stack | None = None
+             if len(diag.stacks) > 0:
+                 stack = diag.stacks[0]
+                 stack.frames.pop(0)
+
+             # Set the function location.
+             fn_location = utils.function_location(fn)
+             diag.locations.insert(0, fn_location)
+             # Add the function location to the top of the stack.
+             if stack is not None:
+                 stack.frames.insert(0, infra.StackFrame(location=fn_location))
+
+             with diag.log_section(logging.INFO, "Function Signature"):
+                 diag.log(
+                     logging.INFO,
+                     "%s",
+                     formatter.LazyString(
+                         format_function_signature_in_markdown,
+                         fn,
+                         args,
+                         kwargs,
+                         format_argument,
+                     ),
+                 )
+
+             return_values: Any = None
+             with ctx.add_inflight_diagnostic(diag) as diag:
+                 try:
+                     return_values = fn(*args, **kwargs)
+                     with diag.log_section(logging.INFO, "Return values"):
+                         diag.log(
+                             logging.INFO,
+                             "%s",
+                             formatter.LazyString(
+                                 format_return_values_in_markdown,
+                                 return_values,
+                                 format_argument,
+                             ),
+                         )
+                     return return_values
+                 except Exception as e:
+                     diag.log_source_exception(logging.ERROR, e)
+                     diag.level = infra.Level.ERROR
+                 finally:
+                     ctx.log_and_raise_if_error(diag)
+
+         return wrapper
+
+     return decorator
+
+
+ # TODO(bowbao): decorator to report only when failed.
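A hedged sketch of how `diagnose_call` is applied (the rule below is hypothetical): the decorated callable must receive a `DiagnosticContext` as its first argument, or carry one as the `diagnostic_context` attribute of `self`:

    _rule = infra.Rule("RuleID", "DummyRule", "Rule message")  # hypothetical rule

    @diagnose_call(_rule)
    def convert(ctx: infra.DiagnosticContext, x):
        return x + 1

    with infra.DiagnosticContext("DummyTool", "1.0") as ctx:
        convert(ctx, 41)  # logs the signature and return values at INFO level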
.venv/Lib/site-packages/torch/onnx/_internal/diagnostics/infra/formatter.py ADDED
@@ -0,0 +1,106 @@
+ from __future__ import annotations
+
+ import dataclasses
+ import json
+ import re
+ import traceback
+ from typing import Any, Callable, Union
+
+ from torch._logging import LazyString
+ from torch.onnx._internal.diagnostics.infra import sarif
+
+
+ # A list of types in the SARIF module to support pretty printing.
+ # This is solely for type annotation for the functions below.
+ _SarifClass = Union[
+     sarif.SarifLog,
+     sarif.Run,
+     sarif.ReportingDescriptor,
+     sarif.Result,
+ ]
+
+
+ def lazy_format_exception(exception: Exception) -> LazyString:
+     return LazyString(
+         lambda: "\n".join(
+             (
+                 "```",
+                 *traceback.format_exception(
+                     type(exception), exception, exception.__traceback__
+                 ),
+                 "```",
+             )
+         ),
+     )
+
+
+ def snake_case_to_camel_case(s: str) -> str:
+     splits = s.split("_")
+     if len(splits) <= 1:
+         return s
+     return "".join([splits[0], *map(str.capitalize, splits[1:])])
+
+
+ def camel_case_to_snake_case(s: str) -> str:
+     return re.sub(r"([A-Z])", r"_\1", s).lower()
+
+
+ def kebab_case_to_snake_case(s: str) -> str:
+     return s.replace("-", "_")
+
+
+ def _convert_key(
+     object: dict[str, Any] | Any, convert: Callable[[str], str]
+ ) -> dict[str, Any] | Any:
+     """Convert and update keys in a dictionary with "convert".
+
+     Any value that is a dictionary will be recursively updated.
+     Any value that is a list will be recursively searched.
+
+     Args:
+         object: The object to update.
+         convert: The function to convert the keys, e.g. `kebab_case_to_snake_case`.
+
+     Returns:
+         The updated object.
+     """
+     if not isinstance(object, dict):
+         return object
+     new_dict = {}
+     for k, v in object.items():
+         new_k = convert(k)
+         if isinstance(v, dict):
+             new_v = _convert_key(v, convert)
+         elif isinstance(v, list):
+             new_v = [_convert_key(elem, convert) for elem in v]
+         else:
+             new_v = v
+         if new_v is None:
+             # Skip None values; they would otherwise unnecessarily bloat the
+             # SARIF log with "null"s.
+             continue
+         if new_v == -1:
+             # Workaround: -1 is a default value and shouldn't be logged into the SARIF log.
+             continue
+
+         new_dict[new_k] = new_v
+
+     return new_dict
+
+
+ def sarif_to_json(attr_cls_obj: _SarifClass, indent: str | None = " ") -> str:
+     dict = dataclasses.asdict(attr_cls_obj)
+     dict = _convert_key(dict, snake_case_to_camel_case)
+     return json.dumps(dict, indent=indent, separators=(",", ":"))
+
+
+ def format_argument(obj: Any) -> str:
+     return f"{type(obj)}"
+
+
+ def display_name(fn: Callable) -> str:
+     if hasattr(fn, "__qualname__"):
+         return fn.__qualname__
+     elif hasattr(fn, "__name__"):
+         return fn.__name__
+     else:
+         return str(fn)
.venv/Lib/site-packages/torch/onnx/_internal/diagnostics/infra/utils.py ADDED
@@ -0,0 +1,69 @@
+ from __future__ import annotations
+
+ import functools
+ import inspect
+ import traceback
+ from typing import Any, Callable, Mapping, Sequence
+
+ from torch.onnx._internal.diagnostics.infra import _infra, formatter
+
+
+ def python_frame(frame: traceback.FrameSummary) -> _infra.StackFrame:
+     """Returns a StackFrame for the given traceback.FrameSummary."""
+     snippet = frame.line
+
+     return _infra.StackFrame(
+         location=_infra.Location(
+             uri=frame.filename,
+             line=frame.lineno,
+             snippet=snippet,
+             function=frame.name,
+             message=snippet,
+         )
+     )
+
+
+ def python_call_stack(frames_to_skip: int = 0, frames_to_log: int = 16) -> _infra.Stack:
+     """Returns the current Python call stack."""
+     if frames_to_skip < 0:
+         raise ValueError("frames_to_skip must be non-negative")
+     if frames_to_log < 0:
+         raise ValueError("frames_to_log must be non-negative")
+     frames_to_skip += 1  # Skip this function.
+     stack = _infra.Stack()
+     # Frames are returned in order of oldest to newest.
+     frames = traceback.extract_stack(limit=frames_to_skip + frames_to_log)
+     frames.reverse()
+     stack.frames = [python_frame(frame) for frame in frames[frames_to_skip:]]
+     stack.message = "Python call stack"
+     return stack
+
+
+ @functools.lru_cache
+ def _function_source_info(fn: Callable) -> tuple[Sequence[str], int, str | None]:
+     """Returns the source lines, line number, and source file path for the given function.
+
+     Essentially, inspect.getsourcelines() and inspect.getsourcefile() combined.
+     Caching is applied to reduce the performance impact of this function.
+     """
+     source_lines, lineno = inspect.getsourcelines(fn)
+     return source_lines, lineno, inspect.getsourcefile(fn)
+
+
+ def function_location(fn: Callable) -> _infra.Location:
+     """Returns a Location for the given function."""
+     source_lines, lineno, uri = _function_source_info(fn)
+     snippet = source_lines[0].strip() if len(source_lines) > 0 else "<unknown>"
+     return _infra.Location(
+         uri=uri,
+         line=lineno,
+         snippet=snippet,
+         message=formatter.display_name(fn),
+     )
+
+
+ def function_state(
+     fn: Callable, args: tuple[Any, ...], kwargs: dict[str, Any]
+ ) -> Mapping[str, Any]:
+     bind = inspect.signature(fn).bind(*args, **kwargs)
+     return bind.arguments
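A short sketch of how these helpers combine (`f` is an illustrative local function):

    def f(x, y=2):
        return x + y

    loc = function_location(f)           # Location pointing at `def f(x, y=2):`
    state = function_state(f, (1,), {})  # {"x": 1}; defaults are not applied
    stack = python_call_stack(frames_to_log=4)  # newest frame first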
.venv/Lib/site-packages/torch/onnx/_internal/io_adapter.py ADDED
@@ -0,0 +1,641 @@
+ # mypy: allow-untyped-defs
+ from __future__ import annotations
+
+ from typing import (
+     Any,
+     Callable,
+     Mapping,
+     Protocol,
+     runtime_checkable,
+     Sequence,
+     TYPE_CHECKING,
+ )
+
+ import torch
+ import torch.export as torch_export
+ from torch.utils import _pytree as pytree
+
+
+ if TYPE_CHECKING:
+     import inspect
+
+ # TODO(bowbao): Add diagnostics for IO adapters.
+
+
+ @runtime_checkable
+ class InputAdaptStep(Protocol):
+     """A protocol that defines a step in the input adapting process.
+
+     The input adapting process is a sequence of steps that are applied to the
+     PyTorch model inputs to transform them into the inputs format expected by the
+     exported ONNX model. Each step takes the PyTorch model inputs as arguments and
+     returns the transformed inputs.
+
+     This serves as a base formalized construct for the transformation done to the model
+     input signature by any individual component in the exporter.
+     """
+
+     def apply(
+         self,
+         model_args: Sequence[Any],
+         model_kwargs: Mapping[str, Any],
+         model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None,
+     ) -> tuple[Sequence[Any], Mapping[str, Any]]: ...
+
+
+ class InputAdapter:
+     """A class that adapts the PyTorch model inputs to the exported ONNX model inputs format."""
+
+     def __init__(self, steps: list[InputAdaptStep] | None = None):
+         self._steps = steps or []
+
+     def append_step(self, step: InputAdaptStep) -> None:
+         """Appends a step to the input adapt steps.
+
+         Args:
+             step: The step to append.
+         """
+         self._steps.append(step)
+
+     def apply(
+         self,
+         *model_args,
+         model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None,
+         **model_kwargs,
+     ) -> Sequence[int | float | bool | str | torch.Tensor | torch.dtype | None]:
+         """Converts the PyTorch model inputs to the exported ONNX model inputs format.
+
+         Args:
+             model_args: The PyTorch model inputs.
+             model: The PyTorch model.
+             model_kwargs: The PyTorch model keyword inputs.
+
+         Returns:
+             A sequence of tensors converted from PyTorch model inputs.
+         """
+         args: Sequence[Any] = model_args
+         kwargs: Mapping[str, Any] = model_kwargs
+         for step in self._steps:
+             args, kwargs = step.apply(args, kwargs, model=model)
+         assert not kwargs
+         return args
+
+
+ @runtime_checkable
+ class OutputAdaptStep(Protocol):
+     """A protocol that defines a step in the output adapting process.
+
+     The output adapting process is a sequence of steps that are applied to the
+     PyTorch model outputs to transform them into the outputs format produced by the
+     exported ONNX model. Each step takes the PyTorch model outputs as arguments and
+     returns the transformed outputs.
+
+     This serves as a base formalized construct for the transformation done to the model
+     output signature by any individual component in the exporter.
+     """
+
+     def apply(
+         self,
+         model_outputs: Any,
+         model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None,
+     ) -> Any: ...
+
+
+ class OutputAdapter:
+     """A class that adapts the PyTorch model outputs to the exported ONNX model outputs format."""
+
+     def __init__(self, steps: list[OutputAdaptStep] | None = None):
+         self._steps = steps or []
+
+     def append_step(self, step: OutputAdaptStep) -> None:
+         """Appends a step to the output format steps.
+
+         Args:
+             step: The step to append.
+         """
+         self._steps.append(step)
+
+     def apply(
+         self,
+         model_outputs: Any,
+         model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None,
+     ) -> Sequence[torch.Tensor | int | float | bool | str]:
+         """Converts the PyTorch model outputs to the exported ONNX model outputs format.
+
+         Args:
+             model_outputs: The PyTorch model outputs.
+             model: The PyTorch model.
+
+         Returns:
+             PyTorch model outputs in the exported ONNX model outputs format.
+         """
+         for step in self._steps:
+             model_outputs = step.apply(model_outputs, model=model)
+         return model_outputs
+
+
+ # TODO: make_fx loses stack info https://github.com/pytorch/pytorch/issues/90276
+
+
+ def _replace_tuple_with_list(spec: pytree.TreeSpec) -> pytree.TreeSpec:
+     _type = list if spec.type == tuple else spec.type
+     return pytree.TreeSpec(
+         _type, spec.context, list(map(_replace_tuple_with_list, spec.children_specs))
+     )
+
+
+ def _open_top_level_list_if_single_element(spec: pytree.TreeSpec) -> pytree.TreeSpec:
+     if spec.type == list and spec.num_children == 1:
+         return spec.children_specs[0]
+     return spec
+
+
+ def _assert_identical_pytree_spec(
+     spec1: pytree.TreeSpec, spec2: pytree.TreeSpec, error_message: str
+ ) -> None:
+     """Assert the two `TreeSpec` objects are identical.
+
+     Args:
+         spec1: The first `TreeSpec` object.
+         spec2: The second `TreeSpec` object.
+         error_message: The error message to raise if the two `TreeSpec` objects are not
+             identical.
+
+     Raises:
+         ValueError: If the two `TreeSpec` objects are not identical.
+     """
+     # TODO(bowbao): Turn this check into diagnostic. Consider warning instead of error.
+     pass_if_any_checks: Sequence[Callable[[], bool]] = [
+         lambda: spec1 == spec2,
+         # FIXME: Bug in `dynamo.export`. Sometimes outputs returned in 'list' instead of 'tuple'.
+         lambda: _replace_tuple_with_list(spec1) == _replace_tuple_with_list(spec2),
+         # FIXME: Bug in `dynamo.export`. Sometimes single function return is wrapped in list.
+         lambda: _open_top_level_list_if_single_element(spec1) == spec2,
+         lambda: spec1 == _open_top_level_list_if_single_element(spec2),
+     ]
+
+     if not any(check() for check in pass_if_any_checks):
+         raise ValueError(f"{error_message}\nExpect {spec1}.\nActual {spec2}.")
+
+
+ class BindInputStep(InputAdaptStep):
+     """Bind the input arguments to the model signature."""
+
+     def __init__(self, model_signature: inspect.Signature):
+         self._model_signature = model_signature
+
+     def apply(
+         self,
+         model_args: Sequence[Any],
+         model_kwargs: Mapping[str, Any],
+         model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None,
+     ) -> tuple[Sequence[Any], Mapping[str, Any]]:
+         """Bind the input arguments to the model signature.
+
+         After binding, all inputs are expected to land in `bound.arguments`; any
+         keyword-only arguments left over in `bound.kwargs` raise an error.
+
+         Args:
+             model_args: The model args.
+             model_kwargs: The model kwargs.
+             model: The PyTorch model.
+
+         Returns:
+             A tuple of the model args and kwargs. args is always empty.
+
+         Raises:
+             ValueError: If there are keyword-only arguments left after binding args and
+                 kwargs to the model signature.
+         """
+         bound = self._model_signature.bind(*model_args, **model_kwargs)
+         bound.apply_defaults()
+
+         # Keyword-only arguments are not handled:
+         # bound.kwargs only contains keyword-only arguments after calling
+         # bind & apply_defaults, so we raise if it's not empty.
+         if bound.kwargs:
+             raise ValueError("Keyword-only arguments are not supported.")
+         return (), bound.arguments
+
+
+ class MergeKwargsIntoArgsInputStep(InputAdaptStep):
+     """Merge the input kwargs into the input args."""
+
+     def apply(
+         self,
+         model_args: Sequence[Any],
+         model_kwargs: Mapping[str, Any],
+         model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None,
+     ) -> tuple[Sequence[Any], Mapping[str, Any]]:
+         """Merge the input kwargs into the input args.
+
+         Args:
+             model_args: The model args.
+             model_kwargs: The model kwargs.
+             model: The PyTorch model.
+
+         Returns:
+             A tuple of the model args and kwargs. kwargs is always empty.
+         """
+         return tuple(model_args) + tuple(model_kwargs.values()), {}
+
+
+ class LiftParametersAndBuffersIntoArgsInputStep(InputAdaptStep):
+     """Append parameters and buffers to the model's positional argument list."""
+
+     def __init__(self, inputs: tuple[torch.Tensor, ...]) -> None:
+         self.inputs = inputs
+
+     def apply(
+         self,
+         model_args: Sequence[Any],
+         model_kwargs: Mapping[str, Any],
+         model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None,
+     ) -> tuple[Sequence[Any], Mapping[str, Any]]:
+         """Append the model's parameters and buffers to its input.
+
+         Args:
+             model_args: The model args.
+             model_kwargs: The model kwargs.
+             model: The PyTorch model.
+
+         Returns:
+             A tuple of the model args + appended inputs and kwargs.
+         """
+         return (*model_args, *self.inputs), model_kwargs
+
+
+ class ConvertComplexToRealRepresentationInputStep(InputAdaptStep):
+     """Convert complex dtype tensors to real representation tensors.
+
+     ONNX does not support complex dtype tensors. Thus, we convert complex dtype tensors
+     to real representation tensors (i.e., float dtype tensors with an extra dimension
+     representing the real and imaginary parts of the complex number).
+     """
+
+     def apply(
+         self,
+         model_args: Sequence[Any],
+         model_kwargs: Mapping[str, Any],
+         model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None,
+     ) -> tuple[Sequence[Any], Mapping[str, Any]]:
+         """Convert complex tensors to float tensors.
+
+         Args:
+             model_args: The model args.
+             model_kwargs: The model kwargs.
+             model: The PyTorch model.
+
+         Returns:
+             A tuple of the model args and kwargs.
+         """
+         return (
+             tuple(
+                 torch.view_as_real(arg.resolve_conj())
+                 if isinstance(arg, torch.Tensor) and arg.is_complex()
+                 else arg
+                 for arg in model_args
+             ),
+             model_kwargs,
+         )
+
+
+ class RemoveNoneInputStep(InputAdaptStep):
+     """Remove `None` from arguments.
+
+     This adapt step assumes ``model_kwargs`` is empty. It also assumes ``model_args``
+     is flattened, i.e. it does not check for `None` inside nested collections.
+     """
+
+     def apply(
+         self,
+         model_args: Sequence[Any],
+         model_kwargs: Mapping[str, Any],
+         model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None,
+     ) -> tuple[Sequence[Any], Mapping[str, Any]]:
+         """Remove `None` from arguments.
+
+         Args:
+             model_args: The model args.
+             model_kwargs: The model kwargs.
+             model: The PyTorch model.
+
+         Returns:
+             A tuple of the model args and kwargs.
+
+         Raises:
+             ValueError: If `model_kwargs` is not empty.
+         """
+         assert not model_kwargs
+         return tuple(arg for arg in model_args if arg is not None), {}
+
+
+ class RemoveNonTensorInputStep(InputAdaptStep):
+     """Remove the non-tensor input arguments.
+
+     Dynamo does not support non-tensor input arguments (https://github.com/pytorch/pytorch/issues/99534).
+
+     Specifically, it does put the input into the graph as an empty node, but the node is
+     consumed by no one. The concrete value is instead embedded into the graph as a constant
+     arg of a target node. Meta's suggestion in this case is to rewrite the model code to
+     make the input a tensor if its value is supposed to change at runtime. We might need
+     to further investigate the feasibility of that suggestion.
+
+     For example,
+
+         def func(x, b=1.0):
+             y = x + b
+             z = y.relu()
+             return (y, z)
+
+         x = torch.randn(1, 1, 2, dtype=torch.float32)
+         gm_fun, _ = dynamo.export(func, x, b=8.0, aten_graph=True, tracing_mode="real")
+
+         # class GraphModule(torch.nn.Module):
+         #     def forward(self, x, b):
+         #         arg0: f32[1, 1, 2], arg1, = fx_pytree.tree_flatten_spec(([x, b], {}), self._in_spec)
+         #         # File: path/to/pytorch/test_constant_input.py:5, code: y = x + b
+         #         add_tensor: f32[1, 1, 2] = torch.ops.aten.add.Tensor(arg0, 8.0);  arg0 = None
+
+         #         # File: path/to/pytorch/test_constant_input.py:6, code: z = y.relu()
+         #         relu_default: f32[1, 1, 2] = torch.ops.aten.relu.default(add_tensor)
+         #         return pytree.tree_unflatten([add_tensor, relu_default], self._out_spec)
+
+     The empty torch.fx.Node input leads to a mismatched number of inputs relative to
+     PyTorch, as it is ignored in the ONNX graph. Thus, we delete the useless input here.
+     """
+
+     def apply(
+         self,
+         model_args: Sequence[Any],
+         model_kwargs: Mapping[str, Any],
+         model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None,
+     ) -> tuple[Sequence[Any], Mapping[str, Any]]:
+         """Remove constants from arguments.
+
+         Args:
+             model_args: The model args.
+             model_kwargs: The model kwargs.
+             model: The PyTorch model.
+
+         Returns:
+             A tuple of the model args and kwargs.
+
+         Raises:
+             ValueError: If `model_kwargs` is not empty.
+         """
+         assert not model_kwargs
+         return (
+             tuple(
+                 arg
+                 for arg in model_args
+                 if not isinstance(arg, (int, float, bool, str))
+             ),
+             {},
+         )
+
+
+ class FlattenInputWithTreeSpecValidationInputStep(InputAdaptStep):
+     """Flatten nested collection types and return a flat list of elements.
+
+     ONNX can't represent collection types (e.g., dictionary, tuple of tuple of tensor,
+     etc).
+
+     This class stores the `SpecTree` output produced when `adapt` was called the first
+     time. It then validates the `SpecTree` output produced from later `adapt` calls.
+     """
+
+     _spec: pytree.TreeSpec | None = None
+
+     def apply(
+         self,
+         model_args: Sequence[Any],
+         model_kwargs: Mapping[str, Any],
+         model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None,
+     ) -> tuple[Sequence[Any], Mapping[str, Any]]:
+         """Flatten the model args and kwargs and validate the `SpecTree` output.
+
+         Args:
+             model_args: The model args.
+             model_kwargs: The model kwargs.
+             model: The PyTorch model.
+
+         Returns:
+             A tuple of the flattened model args and kwargs. The kwargs is empty, because
+             they are flattened and merged into the args.
+
+         Raises:
+             ValueError: If the `SpecTree` output produced from the current `model_outputs`
+                 is not identical to the `SpecTree` output produced from the first
+                 `model_outputs` that was passed to this method.
+         """
+         flattened_args, spec = pytree.tree_flatten((model_args, model_kwargs))
+         if self._spec is None:
+             self._spec = spec
+         else:
+             _assert_identical_pytree_spec(
+                 self._spec,
+                 spec,
+                 error_message="Model inputs incompatible with the format that was exported. ",
+             )
+         return flattened_args, {}
+
+
+ class FlattenOutputStep(OutputAdaptStep):
+     """Flatten nested collection types and return a flat list of elements.
+
+     ONNX can't represent collection types (e.g., dictionary, tuple of tuple of tensor,
+     etc).
+
+     NOTE: Ideally we would want to use ``FlattenOutputWithTreeSpecValidationOutputStep``, such
+     that the `SpecTree` can be validated for new model outputs. However, this is not possible
+     currently because we never have access to real PyTorch model outputs during export.
+     Only traced outputs may be available, but they are not an accurate reflection of the
+     original PyTorch model outputs format, as they are typically in their own unique format,
+     depending on the tracing strategy.
+     """
+
+     def apply(
+         self,
+         model_outputs: Any,
+         model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None,
+     ) -> Sequence[Any]:
+         """Flatten the model outputs.
+
+         Args:
+             model_outputs: The model outputs to flatten.
+             model: The PyTorch model.
+
+         Returns:
+             A tuple of the flattened model outputs.
+         """
+         return pytree.tree_leaves(model_outputs)
+
+
+ class ConvertComplexToRealRepresentationOutputStep(OutputAdaptStep):
+     """Convert complex dtype tensors to real representation tensors.
+
+     ONNX does not support complex dtype tensors. Thus, we convert complex dtype tensors
+     to real representation tensors (i.e., float dtype tensors with an extra dimension
+     representing the real and imaginary parts of the complex number).
+     """
+
+     def apply(
+         self,
+         model_outputs: Any,
+         model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None,
+     ) -> Any:
+         """Convert complex output tensors to their real representation.
+
+         Args:
+             model_outputs: The model outputs.
+             model: The PyTorch model.
+
+         Returns:
+             The model outputs, with complex tensors replaced by their real representation.
+         """
+         return [
+             torch.view_as_real(output.resolve_conj())
+             if isinstance(output, torch.Tensor) and torch.is_complex(output)
+             else output
+             for output in model_outputs
+         ]
+
+
+ class FlattenOutputWithTreeSpecValidationOutputStep(OutputAdaptStep):
+     """Same as ``FlattenOutputStep``, with additional `TreeSpec` validation.
+
+     This class stores the `SpecTree` output produced when `adapt` was called the first
+     time. It then validates the `SpecTree` output produced from later `adapt` calls.
+     """
+
+     _spec: pytree.TreeSpec | None = None
+
+     def apply(
+         self,
+         model_outputs: Any,
+         model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None,
+     ) -> Sequence[Any]:
+         """Flatten the model outputs and validate the `SpecTree` output.
+
+         Args:
+             model_outputs: The model outputs to flatten.
+             model: The PyTorch model.
+
+         Returns:
+             flattened_outputs: The flattened model outputs.
+
+         Raises:
+             ValueError: If the `SpecTree` output produced from the current `model_outputs`
+                 is not identical to the `SpecTree` output produced from the first
+                 `model_outputs` that was passed to this method.
+         """
+         flattened_outputs, spec = pytree.tree_flatten(model_outputs)
+         if self._spec is None:
+             self._spec = spec
+         else:
+             _assert_identical_pytree_spec(
+                 self._spec,
+                 spec,
+                 error_message="Model outputs incompatible with the format that was exported. ",
+             )
+         return flattened_outputs
+
+
+ class PrependParamsBuffersConstantAotAutogradInputStep(InputAdaptStep):
+     """Prepend model parameters, buffers and constants to the user input.
+
+     :func:`torch.export.export` lifts model parameters, buffers and constants as model input, thus, they
+     must be added to the user input before the model is executed.
+
+     Args:
+         model: The PyTorch model with embedded parameters and buffers.
+     """
+
+     def apply(
+         self,
+         model_args: Sequence[Any],
+         model_kwargs: Mapping[str, Any],
+         model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None,
+     ) -> tuple[Sequence[Any], Mapping[str, Any]]:
+         """Prepend model parameters, buffers and constants to the user input.
+
+         Args:
+             model_args: The model args.
+             model_kwargs: The model kwargs.
+             model: The PyTorch model.
+
+         Returns:
+             A tuple of the model args and kwargs.
+         """
+         ordered_params = tuple(
+             model.state_dict[name]  # type: ignore[union-attr,index]
+             for name in model.graph_signature.parameters  # type: ignore[union-attr]
+         )
+         non_persistent_buffers = set(model.graph_signature.non_persistent_buffers)  # type: ignore[union-attr]
+         ordered_buffers = []
+         for name in model.graph_signature.buffers:  # type: ignore[union-attr]
+             if name in non_persistent_buffers:
+                 ordered_buffers.append(model.constants[name])  # type: ignore[union-attr]
+             else:
+                 ordered_buffers.append(model.state_dict[name])  # type: ignore[union-attr,index]
+         ordered_constant_tensors = tuple(
+             model.constants[fqn]  # type: ignore[union-attr,index]
+             for fqn in model.graph_signature.lifted_tensor_constants  # type: ignore[union-attr]
+         )
+
+         # NOTE: calling convention is first params, then buffers, then args as the user supplied them.
+         # See: torch/_functorch/aot_autograd.py#L1034
+         updated_args = (
+             *ordered_params,
+             *ordered_buffers,
+             *ordered_constant_tensors,
+             *model_args,
+         )
+         if model_kwargs:
+             return MergeKwargsIntoArgsInputStep().apply(
+                 updated_args, model_kwargs, model=model
+             )
+         return updated_args, {}
+
+
+ class PrependParamsAndBuffersAotAutogradOutputStep(OutputAdaptStep):
+     """Prepend the model's mutated buffers to the user output.
+
+     :func:`torch.export.export` lifts the model's mutated buffers as outputs, thus, they
+     must be added to the user output after the model is executed.
+
+     Args:
+         model: The PyTorch model with mutated buffers.
+     """
+
+     def apply(
+         self,
+         model_outputs: Any,
+         model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None,
+     ) -> Sequence[Any]:
+         """Prepend the model's mutated buffers to the model outputs.
+
+         Args:
+             model_outputs: The model outputs.
+             model: The PyTorch model.
+
+         Returns:
+             The model outputs, prepended with the mutated buffers.
+         """
+         assert isinstance(
+             model, torch_export.ExportedProgram
+         ), "'model' must be torch_export.ExportedProgram"
+         ordered_buffers = tuple(
+             model.state_dict[name]
+             if name in model.state_dict
+             else model.constants[name]
+             for name in model.graph_signature.buffers_to_mutate.values()
+         )
+
+         # NOTE: calling convention is first mutated buffers, then output args as the model returned them.
+         updated_outputs = (*ordered_buffers, *model_outputs)
+         return updated_outputs
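As a hedged illustration of how these steps compose (the exporter wires them together internally; this standalone pipeline is only a sketch):

    import inspect
    import torch

    def model(x, b=1.0):
        return x + b

    adapter = InputAdapter(
        steps=[
            BindInputStep(inspect.signature(model)),
            MergeKwargsIntoArgsInputStep(),
            RemoveNonTensorInputStep(),
        ]
    )
    x = torch.randn(2)
    args = adapter.apply(x, b=2.0)  # -> (x,): `b` is bound, merged into args, then dropped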
.venv/Lib/site-packages/torch/onnx/_internal/jit_utils.py ADDED
@@ -0,0 +1,373 @@
1
+ # mypy: allow-untyped-defs
2
+ """Utilities for manipulating the torch.Graph object and the torchscript."""
3
+
4
+ # TODO(justinchuby): Move more of the symbolic helper functions here and expose
5
+ # them to the user.
6
+
7
+ from __future__ import annotations
8
+
9
+ import dataclasses
10
+ import re
11
+ import typing
12
+ from typing import Any, Iterable, Sequence
13
+
14
+ import torch
15
+ from torch import _C
16
+ from torch.onnx._globals import GLOBALS
17
+ from torch.onnx._internal import registration
18
+
19
+
20
+ _ATTR_PATTERN = re.compile("^(.+)_(([ifstgz])|(ty))$")
21
+ _SKIP_NODE_ATTRIBUTES = {"inplace", "aten"}
22
+
23
+
24
+ @dataclasses.dataclass
25
+ class GraphContext:
26
+ """Extra context for symbolic functions with all methods from torch.Graph.
27
+
28
+ NOTE: This class is not meant for external consumption. Please do not depend on
29
+ it outside of torch.onnx as the interface may evolve.
30
+
31
+ Attributes:
32
+ graph: The _C.Graph being constructed.
33
+ block: The current _C.Block being constructed.
34
+ opset: The opset version.
35
+ original_node: Current node that is being converted from.
36
+ params_dict: Mapping from graph initializer name to IValue.
37
+ env: Mapping from Torch domain graph Value to ONNX domain graph Value.
38
+ values_in_env: Set of all values in env, for constant-time lookups.
39
+ new_nodes: List that tracks all new nodes that are added (used to make
40
+ sure metadata is propagated to all new nodes).
41
+ """
42
+
43
+ graph: _C.Graph
44
+ block: _C.Block
45
+ opset: int
46
+ original_node: _C.Node
47
+ params_dict: dict[str, _C.IValue]
48
+ env: dict[_C.Value, _C.Value]
49
+ values_in_env: set[_C.Value]
50
+ new_nodes: list[_C.Node] = dataclasses.field(default_factory=list)
51
+
52
+ # Relay methods from _C.Graph for compatibility with symbolic functions that expect
53
+ # a _C.Graph
54
+ def __getattr__(self, name: str) -> Any:
55
+ return getattr(self.graph, name)
56
+
57
+ def op(
58
+ self,
59
+ opname: str,
60
+ *raw_args: torch.Tensor | _C.Value,
61
+ outputs: int = 1,
62
+ **kwargs,
63
+ ):
64
+ """Creates an ONNX operator "opname", taking "raw_args" as inputs and "kwargs" as attributes.
65
+
66
+ The set of operators and the inputs/attributes they take
67
+ is documented at https://github.com/onnx/onnx/blob/master/docs/Operators.md
68
+
69
+ Args:
70
+ opname: The ONNX operator name, e.g., `Abs` or `Add`, or an operator qualified
71
+ with a namespace, e.g., `aten::add`.
72
+ raw_args: The inputs to the operator; usually provided
73
+ as arguments to the `symbolic` definition.
74
+ outputs: The number of outputs this operator returns.
75
+ By default an operator is assumed to return a single output.
76
+ If `outputs` is greater than one, this functions returns a tuple
77
+ of output `Value`, representing each output of the ONNX operator
78
+ in order.
79
+ kwargs: The attributes of the ONNX operator, whose keys are named
80
+ according to the following convention: `alpha_f` indicates
81
+ the `alpha` attribute with type `f`. The valid type specifiers are
82
+ `f` (float), `i` (int), `s` (string) or `t` (Tensor). An attribute
83
+ specified with type float accepts either a single float, or a
84
+ list of floats (e.g., you would say `dims_i` for a `dims` attribute
85
+ that takes a list of integers).
86
+
87
+ Returns:
88
+ The value representing the single output of this operator (see the `outputs`
89
+ keyword argument for multi-return nodes).
90
+ """
91
+ # FIXME(justinchuby): Add the return type back once we know how to handle mypy
92
+ return _add_op(self, opname, *raw_args, outputs=outputs, **kwargs)
93
+
94
+ def aten_op(self, operator: str, *args, overload_name: str = "", **kwargs):
95
+ """Generates an ONNX ATen op node.
96
+
97
+ This function is for backward compatibility with the old symbolic functions.
98
+ """
99
+ return self.op(
100
+ "aten::ATen",
101
+ *args,
102
+ operator_s=operator,
103
+ overload_name_s=overload_name,
104
+ **kwargs,
105
+ )
106
+
107
+ # NOTE: For backward compatibility with the old symbolic functions.
108
+ # We are probably going to remove this only after the fx exporter is established.
109
+ at = aten_op
110
+
111
+ def onnxscript_op(
112
+ self,
113
+ onnx_fn,
114
+ *raw_args: torch.Tensor | _C.Value,
115
+ outputs: int = 1,
116
+ **kwargs,
117
+ ):
118
+ """Creates an ONNX operator from onnx-script function, taking "raw_args" as inputs and "kwargs" as attributes.
119
+
120
+ onnx-script repository: https://github.com/microsoft/onnx-script
121
+
122
+ Args:
123
+ onnx_fn: ONNXFunction from onnx-script; An example can be found at
124
+ https://github.com/microsoft/onnx-script#example
125
+ raw_args: The inputs to the operator; usually provided
126
+ as arguments to the `symbolic` definition.
127
+ outputs: The number of outputs this operator returns.
128
+ By default an operator is assumed to return a single output.
129
+ If `outputs` is greater than one, this functions returns a tuple
130
+ of output `Value`, representing each output of the ONNX operator
131
+ in order.
132
+ kwargs: The attributes of the ONNX operator, whose keys are named
133
+ according to the following convention: `alpha_f` indicates
134
+ the `alpha` attribute with type `f`. The valid type specifiers are
135
+ `f` (float), `i` (int), `s` (string) or `t` (Tensor). An attribute
136
+ specified with type float accepts either a single float, or a
137
+ list of floats (e.g., you would say `dims_i` for a `dims` attribute
138
+ that takes a list of integers).
139
+
140
+ Returns:
141
+ The value representing the single output of this operator (see the `outputs`
142
+ keyword argument for multi-return nodes).
143
+ """
144
+ # NOTE(titaiwang): This is using class attributes, and it needs to be updated
145
+ # if onnx-script makes any change on these.
146
+ symbolic_name = f"{onnx_fn.opset.domain}::{onnx_fn.name}"
147
+ opset_version = onnx_fn.opset.version
148
+
149
+ registration.custom_onnx_symbolic(symbolic_name, opset_version)(onnx_fn)
150
+
151
+ return _add_op(self, symbolic_name, *raw_args, outputs=outputs, **kwargs)
152
+
153
+
154
+ def add_op_with_blocks(
155
+ graph_context: GraphContext,
156
+ opname: str,
157
+ *inputs: _C.Value,
158
+ outputs: int = 1,
159
+ n_blocks: int = 1,
160
+ **attributes,
161
+ ) -> tuple[Any, tuple[GraphContext, ...], _C.Node]:
162
+ """Creates an ONNX operator "opname", taking inputs and attributes.
163
+
164
+ Args:
165
+ graph_context: The context for the current graph.
166
+ opname: The ONNX operator name, e.g., `Abs` or `Add`, or an operator qualified
167
+ with a namespace, e.g., `aten::add`.
168
+ inputs: The inputs to the operator.
169
+ outputs: The number of outputs this operator returns.
170
+ By default an operator is assumed to return a single output.
171
+ If `outputs` is greater than one, this functions returns a tuple
172
+ of output `Value`, representing each output of the ONNX operator
173
+ in order.
174
+ n_blocks: The number of sub-blocks to create in the node.
175
+ attributes: The attributes of the ONNX operator.
176
+
177
+ Returns:
178
+ A tuple of (output_values, new_contexts, node) where:
179
+ output_values: One or more output value of this operator
180
+ (see the `outputs` keyword argument for multi-return nodes).
181
+ new_contexts: A tuple of new graph contexts for each sub-block.
182
+ node: The node representing the operator.
183
+ """
184
+
185
+ output_values = graph_context.op(opname, *inputs, outputs=outputs, **attributes)
186
+ if isinstance(output_values, Sequence):
187
+ node = output_values[0].node()
188
+ else:
189
+ node = output_values.node()
190
+
191
+ new_contexts = []
192
+ for _ in range(n_blocks):
193
+ new_block = node.addBlock()
194
+ # Create shallow copy of the graph context and update the block
195
+ new_context = dataclasses.replace(graph_context, block=new_block)
196
+ new_contexts.append(new_context)
197
+
198
+ return output_values, tuple(new_contexts), node
199
+
200
+
201
+ def _add_op(
202
+ graph_context: GraphContext,
203
+ opname: str,
204
+ *args: torch.Tensor | _C.Value,
205
+ outputs: int = 1,
206
+ **kwargs,
207
+ ):
208
+ """Creates an ONNX operator "opname", taking "args" as inputs and attributes "kwargs".
209
+
210
+ The set of operators and the inputs/attributes they take
211
+ is documented at https://github.com/onnx/onnx/blob/master/docs/Operators.md
212
+
213
+ This function is monkey-patched onto Graph.
214
+
215
+ Args:
216
+ graph_context: The Torch Graph or Block.
217
+ opname: The ONNX operator name, e.g., `Abs` or `Add`, or an operator qualified
218
+ with a namespace, e.g., `aten::add`.
219
+ args: The inputs to the operator; usually provided
220
+ as arguments to the `symbolic` definition.
221
+ outputs: The number of outputs this operator returns.
222
+ By default an operator is assumed to return a single output.
223
+ If `outputs` is greater than one, this functions returns a tuple
224
+ of output `Value`, representing each output of the ONNX operator
225
+ in order.
226
+ kwargs: The attributes of the ONNX operator, whose keys are named
227
+ according to the following convention: `alpha_f` indicates
228
+ the `alpha` attribute with type `f`. The valid type specifiers are
229
+ `f` (float), `i` (int), `s` (string) or `t` (Tensor). An attribute
230
+ specified with type float accepts either a single float, or a
231
+ list of floats (e.g., you would say `dims_i` for a `dims` attribute
232
+ that takes a list of integers).
233
+
234
+ Returns:
235
+ (Union[_C.Value, Tuple[_C.Value, ...]])
236
+ The value representing the single output of this operator (see the `outputs`
237
+ keyword argument for multi-return nodes).
238
+ """
239
+ inputs = [_const_if_tensor(graph_context, arg) for arg in args]
240
+ # Filter out None attributes, this can be convenient client side because
241
+ # now they can pass through None attributes, and have them not show up
242
+ attributes = {k: v for k, v in kwargs.items() if v is not None}
243
+
244
+ if "::" not in opname:
245
+ opname = "onnx::" + opname
246
+
247
+ node = _create_node(
248
+ graph_context.block,
249
+ opname,
250
+ inputs,
251
+ attributes,
252
+ params_dict=graph_context.params_dict,
253
+ opset_version=graph_context.opset,
254
+ n_outputs=outputs,
255
+ shape_inference=GLOBALS.onnx_shape_inference,
256
+ )
257
+ graph_context.new_nodes.append(node)
258
+
259
+ if outputs == 1:
260
+ return node.output()
261
+ return tuple(node.outputs())
262
+
263
+
264
+ def _const_if_tensor(graph_context: GraphContext, arg):
265
+ if arg is None:
266
+ return arg
267
+ if isinstance(arg, _C.Value):
268
+ return arg
269
+
270
+ return _add_op(graph_context, "onnx::Constant", value_z=arg)
271
+
272
+
273
+ def _create_node(
274
+ graph_or_block: _C.Graph | _C.Block,
275
+ domain_op: str,
276
+ inputs: Sequence,
277
+ attributes: dict,
278
+ params_dict: dict,
279
+ opset_version: int,
280
+ n_outputs: int,
281
+ shape_inference: bool = True,
282
+ ) -> _C.Node:
283
+ """Creates an node 'domain_op', taking inputs and attributes."""
284
+ if isinstance(graph_or_block, _C.Graph):
285
+ graph = graph_or_block
286
+ node = graph.create(domain_op, inputs, n_outputs)
287
+ node = graph.insertNode(node)
288
+ elif isinstance(graph_or_block, _C.Block):
289
+ block = graph_or_block
290
+ node = block.addNode(domain_op, inputs)
291
+
292
+ # Block does not have create defined, so we need to add outputs manually
293
+ if n_outputs > 1:
294
+ for _ in range(1, n_outputs):
295
+ node.addOutput()
296
+
297
+ node_outputs = tuple(node.outputs()) # type: ignore[possibly-undefined]
298
+ assert len(node_outputs) == n_outputs
299
+
300
+ aten = domain_op.startswith("aten::")
301
+
302
+ # Add all attributes
303
+ for key, value in sorted(attributes.items()):
304
+ if key in _SKIP_NODE_ATTRIBUTES:
305
+ continue
306
+ _add_attribute(node, key, value, aten=aten)
307
+ if shape_inference:
308
+ _C._jit_pass_onnx_node_shape_type_inference(node, params_dict, opset_version)
309
+ return node
310
+
311
+
312
+ def _is_onnx_list(value):
313
+ return isinstance(value, Iterable) and not isinstance(
314
+ value, (str, bytes, torch.Tensor)
315
+ )
316
+
317
+
318
+ def _scalar(x: torch.Tensor):
319
+ """Convert a scalar tensor into a Python value."""
320
+ assert x.numel() == 1
321
+ return x[0]
322
+
323
+
324
+ def _add_attribute(node: _C.Node, key: str, value: Any, aten: bool):
325
+ r"""Initializes the right attribute based on type of value."""
326
+ m = _ATTR_PATTERN.match(key)
327
+ if m is None:
328
+ raise ValueError(
329
+ f"Invalid attribute specifier '{key}' names "
330
+ "must be suffixed with type, e.g. 'dim_i' or 'dims_i'"
331
+ )
332
+ name, kind = m.group(1), m.group(2)
333
+ if _is_onnx_list(value):
334
+ kind += "s"
335
+
336
+ return getattr(node, f"{kind}_")(name, value)
337
+
338
+
339
+ # TODO: Expose this to user when migrating symbolic helper functions to here.
340
+ def _is_tensor(x: _C.Value) -> bool:
341
+ return x.type().isSubtypeOf(_C.TensorType.get())
342
+
343
+
344
+ def get_device_from_value(value: _C.Value) -> torch.device | None:
345
+ if not _is_tensor(value):
346
+ return None
347
+ tensor_type = typing.cast(_C.TensorType, value.type())
348
+ return tensor_type.device()
349
+
350
+
351
+ def parse_node_kind(kind: str) -> tuple[str, str]:
352
+ """Parse node kind into domain and Op name."""
353
+ if "::" not in kind:
354
+ raise ValueError(f"Node kind: {kind} is invalid. '::' is not in node kind.")
355
+ domain, opname = kind.split("::", 1)
356
+ if "::" in opname:
357
+ raise ValueError(f"Node kind: {kind} is invalid. '::' should only apear once.")
358
+ return domain, opname
359
+
360
+
361
+ def is_aten(domain: str) -> bool:
362
+ """Check if the domain is official."""
363
+ return domain == "aten"
364
+
365
+
366
+ def is_prim(domain: str) -> bool:
367
+ """Check if the domain is official."""
368
+ return domain == "prim"
369
+
370
+
371
+ def is_onnx(domain: str) -> bool:
372
+ """Check if the domain is official."""
373
+ return domain == "onnx"