File size: 3,134 Bytes
9dd3461
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
"""ONNX exporter."""

from torch import _C
from torch._C import _onnx as _C_onnx
from torch._C._onnx import (
    _CAFFE2_ATEN_FALLBACK,
    OperatorExportTypes,
    TensorProtoDataType,
    TrainingMode,
)

from . import (  # usort:skip. Keep the order instead of sorting lexicographically
    _deprecation,
    errors,
    symbolic_caffe2,
    symbolic_helper,
    symbolic_opset7,
    symbolic_opset8,
    symbolic_opset9,
    symbolic_opset10,
    symbolic_opset11,
    symbolic_opset12,
    symbolic_opset13,
    symbolic_opset14,
    symbolic_opset15,
    symbolic_opset16,
    symbolic_opset17,
    utils,
)

# TODO(After 1.13 release): Remove the deprecated SymbolicContext
from ._exporter_states import ExportTypes, SymbolicContext
from ._type_utils import JitScalarType
from .errors import CheckerError  # Backwards compatibility
from .utils import (
    _optimize_graph,
    _run_symbolic_function,
    _run_symbolic_method,
    export,
    export_to_pretty_string,
    is_in_onnx_export,
    register_custom_op_symbolic,
    select_model_mode_for_export,
    unregister_custom_op_symbolic,
)

# Names re-exported by ``from torch.onnx import *``; this list is the
# package's declared public API surface.
# NOTE(review): ``is_onnx_log_enabled``, ``set_log_stream`` and ``log`` are
# bound at module level below without a leading underscore but are not listed
# here (nor are ``producer_name`` / ``producer_version``); confirm whether
# that omission is intentional before relying on star-imports for them.
__all__ = [
    # Modules
    "symbolic_helper",
    "utils",
    "errors",
    # All opsets
    "symbolic_caffe2",
    "symbolic_opset7",
    "symbolic_opset8",
    "symbolic_opset9",
    "symbolic_opset10",
    "symbolic_opset11",
    "symbolic_opset12",
    "symbolic_opset13",
    "symbolic_opset14",
    "symbolic_opset15",
    "symbolic_opset16",
    "symbolic_opset17",
    # Enums
    "ExportTypes",
    "OperatorExportTypes",
    "TrainingMode",
    "TensorProtoDataType",
    "JitScalarType",
    # Public functions
    "export",
    "export_to_pretty_string",
    "is_in_onnx_export",
    "select_model_mode_for_export",
    "register_custom_op_symbolic",
    "unregister_custom_op_symbolic",
    "disable_log",
    "enable_log",
    # Errors
    "CheckerError",  # Backwards compatibility
]

# Set namespace for exposed private names
# ``ExportTypes`` and ``JitScalarType`` are imported from the private
# submodules ``._exporter_states`` and ``._type_utils``; overwriting
# ``__module__`` makes introspection (e.g. the class repr) report them as
# living in ``torch.onnx`` rather than in those private modules.
ExportTypes.__module__ = "torch.onnx"
JitScalarType.__module__ = "torch.onnx"

# Producer metadata exposed as module attributes. The version string comes
# from the C extension (``torch._C._onnx.PRODUCER_VERSION``).
# NOTE(review): presumably stamped into exported ONNX models; confirm where
# these attributes are actually read.
producer_name = "pytorch"
producer_version = _C_onnx.PRODUCER_VERSION


@_deprecation.deprecated(
    since="1.12.0",
    removed_in="1.14",
    instructions="use `torch.onnx.export` instead",
)
def _export(*args, **kwargs):
    """Deprecated private alias that forwards everything to ``utils._export``.

    Wrapped by ``_deprecation.deprecated`` (deprecated since 1.12.0, slated
    for removal in 1.14); callers should switch to ``torch.onnx.export``.
    All positional and keyword arguments are passed through unchanged and
    the delegate's return value is returned as-is.
    """
    legacy_impl = utils._export
    return legacy_impl(*args, **kwargs)


# TODO(justinchuby): Deprecate these logging functions in favor of the new diagnostic module.

# Returns True iff ONNX logging is turned on.
# This is a direct alias of the C binding ``torch._C._jit_is_onnx_log_enabled``
# (no Python wrapper). It presumably reports the flag written by
# ``_C._jit_set_onnx_log_enabled`` in ``enable_log`` / ``disable_log`` —
# TODO confirm against the C++ side.
is_onnx_log_enabled = _C._jit_is_onnx_log_enabled


def enable_log() -> None:
    r"""Turn on logging for the ONNX exporter.

    Thin wrapper that sets the ONNX log switch exposed by the C extension
    (``torch._C._jit_set_onnx_log_enabled``) to ``True``. Returns nothing.
    """
    set_switch = _C._jit_set_onnx_log_enabled
    set_switch(True)


def disable_log() -> None:
    r"""Turn off logging for the ONNX exporter.

    Counterpart of :func:`enable_log`: sets the ONNX log switch exposed by
    the C extension (``torch._C._jit_set_onnx_log_enabled``) to ``False``.
    Returns nothing.
    """
    set_switch = _C._jit_set_onnx_log_enabled
    set_switch(False)


# NOTE: the string literal below is not the module's first statement, so
# Python does not attach it as a docstring to anything; it is a discarded
# expression that only documents ``set_log_stream`` at the source level
# (``set_log_stream.__doc__`` comes from the C binding instead).
"""Sets output stream for ONNX logging.

Args:
    stream_name (str, default "stdout"): Only 'stdout' and 'stderr' are supported
        as ``stream_name``.
"""
# Direct alias of the C binding ``torch._C._jit_set_onnx_log_output_stream``;
# there is no Python-side argument validation.
# NOTE(review): the text above advertises a default of "stdout"; confirm the
# binding actually accepts a call with no argument before relying on it.
set_log_stream = _C._jit_set_onnx_log_output_stream


# NOTE: this string literal is not in docstring position (it follows other
# module-level statements), so Python discards it at runtime; it only
# documents ``log`` at the source level. ``log.__doc__`` comes from the C
# binding.
"""A simple logging facility for ONNX exporter.

Args:
    args: Arguments are converted to string, concatenated together with a newline
        character appended to the end, and flushed to output stream.
"""
# Direct alias of the C binding ``torch._C._jit_onnx_log`` (no Python wrapper).
# NOTE(review): presumably a no-op unless logging was switched on via
# ``enable_log``, and it writes to the stream chosen by ``set_log_stream`` —
# TODO confirm against the C++ implementation.
log = _C._jit_onnx_log