"""This file exports ONNX ops for opset 17.

Note [ONNX Operators that are added/updated in opset 17]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-17-of-the-default-onnx-operator-set

New operators:
    BlackmanWindow
    DFT
    HammingWindow
    HannWindow
    LayerNormalization
    MelWeightMatrix
    STFT
    SequenceMap
"""
|
|
|
import functools
from typing import Sequence

import torch
from torch import _C
from torch.onnx import symbolic_helper
from torch.onnx._internal import jit_utils, registration
|
|
|
|
|
|
|
|
|
# Public API of this module: only the symbolic functions defined here.
__all__ = ["layer_norm"]

# Decorator factory that pre-binds ``opset=17`` on
# ``registration.onnx_symbolic``, so every symbolic function in this file is
# declared for opset 17 without repeating the opset number at each use.
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=17)
|
|
|
|
|
@_onnx_symbolic("aten::layer_norm")
@symbolic_helper.parse_args("v", "is", "v", "v", "f", "none")
def layer_norm(
    g: jit_utils.GraphContext,
    input: _C.Value,
    normalized_shape: Sequence[int],
    weight: _C.Value,
    bias: _C.Value,
    eps: float,
    cudnn_enable: bool,
):
    """Emit an ONNX ``LayerNormalization`` node for ``aten::layer_norm``.

    Args:
        g: Graph context used to create the ONNX nodes.
        input: Value to normalize.
        normalized_shape: Sizes of the trailing dimensions that are
            normalized; only its length is used to derive ``axis``, and its
            contents are used to shape the default scale/bias constants.
        weight: Scale value. May be a None value (layer without affine
            parameters), in which case a constant of ones is used.
        bias: Bias value. May be a None value, in which case a constant of
            zeros is used.
        eps: Forwarded as the ``epsilon`` attribute.
        cudnn_enable: Accepted to match the ATen signature; not used.

    Returns:
        The result of ``g.op("LayerNormalization", ...)``.
    """
    # The normalization covers the last ``len(normalized_shape)`` dimensions,
    # so the first normalized dimension, counted from the end, is the axis.
    axis = -len(normalized_shape)

    # ONNX ``LayerNormalization`` requires its ``Scale`` input, so a None
    # weight cannot be forwarded as-is. Substitute neutral constants and let
    # ``CastLike`` match them to the dtype of ``input`` (the operator expects
    # X, Scale and B to share one type), which avoids having to know the
    # input dtype at export time.
    if symbolic_helper._is_none(weight):
        ones = g.op("Constant", value_t=torch.ones(normalized_shape))
        weight = g.op("CastLike", ones, input)
    if symbolic_helper._is_none(bias):
        zeros = g.op("Constant", value_t=torch.zeros(normalized_shape))
        bias = g.op("CastLike", zeros, input)

    return g.op(
        "LayerNormalization",
        input,
        weight,
        bias,
        epsilon_f=eps,
        axis_i=axis,
    )
|
|