# Provenance (web-page residue, not part of the original module):
#   JSX_TTS / torch/onnx/symbolic_opset17.py
#   uploaded by UMMJ ("Upload 5875 files"), commit 9dd3461
"""This file exports ONNX ops for opset 17.
Note [ONNX Operators that are added/updated in opset 17]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-17-of-the-default-onnx-operator-set
New operators:
BlackmanWindow
DFT
HammingWindow
HannWindow
LayerNormalization
MelWeightMatrix
STFT
SequenceMap
"""
import functools
from typing import Sequence

import torch
from torch import _C
from torch.onnx import symbolic_helper
from torch.onnx._internal import jit_utils, registration
# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in README.md
# Names exported by ``from ... import *`` on this module.
__all__ = [
    "layer_norm",
]

# Registration decorator with the opset pinned to 17; the symbolic
# functions in this module use it to register themselves.
_onnx_symbolic = functools.partial(
    registration.onnx_symbolic,
    opset=17,
)
@_onnx_symbolic("aten::layer_norm")
@symbolic_helper.parse_args("v", "is", "v", "v", "f", "none")
def layer_norm(
    g: jit_utils.GraphContext,
    input: _C.Value,
    normalized_shape: Sequence[int],
    weight: _C.Value,
    bias: _C.Value,
    eps: float,
    cudnn_enable: bool,
):
    """Export ``aten::layer_norm`` as ONNX ``LayerNormalization`` (opset 17).

    Args:
        g: Graph context used to emit ONNX nodes.
        input: Value to normalize.
        normalized_shape: Trailing dimensions over which statistics are
            computed; its length ``D`` selects the last ``D`` axes of
            ``input``.
        weight: Per-element scale. May be a ``None`` value (e.g. when the
            module was built with ``elementwise_affine=False``).
        bias: Per-element shift. May be a ``None`` value.
        eps: Value added to the variance for numerical stability.
        cudnn_enable: Present only to match the aten signature; not used by
            this export.

    Returns:
        The output value of the emitted ``LayerNormalization`` node.
    """
    # LayerNormalization normalizes from ``axis`` through the last dimension,
    # so the first normalized axis is the D-th counted from the end, where D
    # is the size of normalized_shape.
    axis = -len(normalized_shape)

    # ONNX requires the Scale input, but aten allows weight=None. Substitute
    # an all-ones scale of shape ``normalized_shape`` and CastLike it to the
    # input so its element type matches the input's dtype at runtime.
    if symbolic_helper._is_none(weight):
        ones = g.op("Constant", value_t=torch.ones(normalized_shape))
        weight = g.op("CastLike", ones, input)

    node_inputs = [input, weight]
    # B is optional in ONNX: omit it rather than wiring a None value into the
    # node (omitting it is equivalent to a zero bias).
    if not symbolic_helper._is_none(bias):
        node_inputs.append(bias)

    return g.op(
        "LayerNormalization",
        *node_inputs,
        epsilon_f=eps,
        axis_i=axis,
    )