JSX_TTS / torch /onnx /symbolic_opset10.py
UMMJ's picture
Upload 5875 files
9dd3461
import functools
import sys
import warnings
from typing import Callable, Sequence
import torch
import torch._C._onnx as _C_onnx
import torch.onnx
from torch import _C
# Monkey-patch graph manipulation methods on Graph, used for the ONNX symbolics
from torch.onnx import ( # noqa: F401
_patch_torch,
_type_utils,
errors,
symbolic_helper,
symbolic_opset9 as opset9,
)
from torch.onnx._globals import GLOBALS
from torch.onnx._internal import _beartype, jit_utils, registration
# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in README.md
# This file exports ONNX ops for opset 10
# Opset 10 is supported by ONNX release 1.5.0
# released on 04/24/19
# Public symbolic functions of this module. Private helpers (leading
# underscore) and `__interpolate` are not listed here.
__all__ = [
    "dequantize",
    "div",
    "embedding_bag",
    "fake_quantize_per_tensor_affine",
    "flip",
    "fmod",
    "isfinite",
    "isinf",
    "nan_to_num",
    "quantize_per_tensor",
    "quantized_add_relu",
    "quantized_add",
    "quantized_cat",
    "quantized_conv2d_relu",
    "quantized_conv2d",
    "quantized_group_norm",
    "quantized_hardswish",
    "quantized_instance_norm",
    "quantized_layer_norm",
    "quantized_leaky_relu",
    "quantized_linear",
    "quantized_mul",
    "quantized_sigmoid",
    "slice",
    "sort",
    "topk",
]
# `registration.onnx_symbolic` with `opset=10` pre-bound. It is used below as
# `@_onnx_symbolic("aten::op", decorate=[...])` to register symbolic functions
# for this opset.
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=10)
def _apply_params(*args, **kwargs):
"""Returns a decorator that calls the decorated (higher-order) function with the given parameters."""
def _apply(fn):
return fn(*args, **kwargs)
return _apply
@_onnx_symbolic("aten::div")
@_beartype.beartype
def div(g: jit_utils.GraphContext, self, other, *args):
    """Export ``aten::div``.

    Without extra arguments this is a true division. Any trailing arguments
    are forwarded to ``_div_rounding_mode``, which parses them as the
    ``rounding_mode`` string.
    """
    if args:
        return _div_rounding_mode(g, self, other, *args)
    return opset9.true_divide(g, self, other)
@symbolic_helper.parse_args("v", "v", "s")
@_beartype.beartype
def _div_rounding_mode(g: jit_utils.GraphContext, self, other, rounding_mode):
    """Export division with an explicit ``rounding_mode`` string.

    ``"floor"`` uses this opset's ``_floor_divide``. Every other mode is
    delegated to the opset 9 implementation.
    """
    if rounding_mode != "floor":
        return opset9._div_rounding_mode(g, self, other, rounding_mode)
    return _floor_divide(g, self, other)
@_onnx_symbolic("aten::_floor_divide")
@_beartype.beartype
def _floor_divide(g: jit_utils.GraphContext, self, other):
    """Export floor division.

    Floating-point operands are lowered to ``Floor(self / other)``.

    Integer ``Div`` rounds by truncation, so for integer operands the truncated
    quotient is decremented by one wherever the operand signs differ and the
    division is inexact. That turns truncation into flooring.
    """
    if not (symbolic_helper._is_fp(self) or symbolic_helper._is_fp(other)):
        truncated = g.op("Div", self, other)
        zero = g.op("Constant", value_t=torch.tensor(0, dtype=torch.int64))
        # The exact quotient is negative when exactly one operand is negative.
        self_negative = g.op("Less", self, zero)
        other_negative = g.op("Less", other, zero)
        signs_differ = g.op("Xor", self_negative, other_negative)
        # Only quotients with a non-zero remainder need to be rounded down.
        remainder = g.op("Mod", self, other, fmod_i=0)
        inexact = g.op("Not", g.op("Equal", remainder, zero))
        needs_fixup = g.op("And", signs_differ, inexact)
        one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64))
        return g.op("Where", needs_fixup, g.op("Sub", truncated, one), truncated)
    quotient = opset9.true_divide(g, self, other)
    return g.op("Floor", quotient)
@_onnx_symbolic("aten::sort")
@symbolic_helper.parse_args("v", "i", "i", "none")
@_beartype.beartype
def sort(g: jit_utils.GraphContext, self, dim, decending, out=None):
    """Export ``aten::sort`` through ``symbolic_helper._sort_helper``.

    ``decending`` keeps its historical (misspelled) name because it is part of
    both this signature and the helper's keyword interface.
    """
    return symbolic_helper._sort_helper(
        g,
        self,
        dim,
        out=out,
        decending=decending,
    )
@_onnx_symbolic("aten::topk")
@symbolic_helper.parse_args("v", "v", "i", "i", "i", "none")
@_beartype.beartype
def topk(g: jit_utils.GraphContext, self, k, dim, largest, sorted, out=None):
    """Export ``aten::topk`` through ``symbolic_helper._topk_helper``.

    ``k`` stays a graph value, while ``dim``, ``largest`` and ``sorted`` are
    parsed to Python ints.
    """
    return symbolic_helper._topk_helper(
        g,
        self,
        k,
        dim,
        out=out,
        sorted=sorted,
        largest=largest,
    )
@_onnx_symbolic(
    "aten::max_pool1d",
    decorate=[
        _apply_params(
            "max_pool1d", torch.nn.modules.utils._single, 1, return_indices=False
        )
    ],
)
@_onnx_symbolic(
    "aten::max_pool2d",
    decorate=[
        _apply_params(
            "max_pool2d", torch.nn.modules.utils._pair, 2, return_indices=False
        )
    ],
)
@_onnx_symbolic(
    "aten::max_pool3d",
    decorate=[
        _apply_params(
            "max_pool3d", torch.nn.modules.utils._triple, 3, return_indices=False
        )
    ],
)
@_onnx_symbolic(
    "aten::max_pool1d_with_indices",
    decorate=[
        _apply_params(
            "max_pool1d_with_indices",
            torch.nn.modules.utils._single,
            1,
            return_indices=True,
        )
    ],
)
@_onnx_symbolic(
    "aten::max_pool2d_with_indices",
    decorate=[
        _apply_params(
            "max_pool2d_with_indices",
            torch.nn.modules.utils._pair,
            2,
            return_indices=True,
        )
    ],
)
@_onnx_symbolic(
    "aten::max_pool3d_with_indices",
    decorate=[
        _apply_params(
            "max_pool3d_with_indices",
            torch.nn.modules.utils._triple,
            3,
            return_indices=True,
        )
    ],
)
@_beartype.beartype
def _max_pool(name: str, tuple_fn: Callable, ndims: int, return_indices: bool):
    """Build the symbolic for ``aten::max_pool{1,2,3}d[_with_indices]``.

    Args:
        name: ATen operator name (not referenced by the generated symbolic).
        tuple_fn: Expands an int/sequence to an ``ndims``-tuple
            (``_single``/``_pair``/``_triple`` in the registrations above).
        ndims: Number of spatial dimensions.
        return_indices: If True, the symbolic returns ``(values, indices)``,
            with indices converted to PyTorch's per-spatial-plane layout.
    """

    @symbolic_helper.quantized_args(True, False, False, False, False, False)
    @symbolic_helper.parse_args("v", "is", "is", "is", "is", "i")
    def symbolic_fn(g, input, kernel_size, stride, padding, dilation, ceil_mode):
        # An empty stride list means "use the kernel size".
        effective_stride = stride if stride else kernel_size
        pool_attrs = dict(
            kernel_shape_i=tuple_fn(kernel_size),
            pads_i=tuple_fn(padding) * 2,
            strides_i=tuple_fn(effective_stride),
            ceil_mode_i=ceil_mode,
        )
        dilations = tuple_fn(dilation)
        if set(dilations) != {1}:
            pool_attrs["dilations_i"] = dilations

        if not return_indices:
            return g.op("MaxPool", input, outputs=1, **pool_attrs)

        pooled, flat_indices = g.op("MaxPool", input, outputs=2, **pool_attrs)
        # ONNX reports indices flattened over the whole tensor, i.e. in
        # [0, N x C x D1 x ... x Dn), while PyTorch indexes within each (N, C)
        # plane. A 1x1 MaxPool over the same input yields every element's own
        # flattened index. Its first spatial entry per (N, C) plane is the
        # plane's offset, and subtracting that offset converts the pooled
        # indices to PyTorch's layout.
        # See https://github.com/pytorch/pytorch/pull/16455#issuecomment-460776407
        _, reference_indices = g.op(
            "MaxPool",
            input,
            outputs=2,
            kernel_shape_i=[1] * ndims,
            strides_i=[1] * ndims,
        )
        plane_offsets = symbolic_helper._slice_helper(
            g,
            reference_indices,
            axes=list(range(2, 2 + ndims)),
            starts=tuple_fn(0),
            ends=tuple_fn(1),
        )
        return pooled, opset9.sub(g, flat_indices, plane_offsets)

    return symbolic_fn
@_onnx_symbolic(
    "aten::avg_pool1d",
    decorate=[_apply_params("avg_pool1d", torch.nn.modules.utils._single)],
)
@_onnx_symbolic(
    "aten::avg_pool2d",
    decorate=[_apply_params("avg_pool2d", torch.nn.modules.utils._pair)],
)
@_onnx_symbolic(
    "aten::avg_pool3d",
    decorate=[_apply_params("avg_pool3d", torch.nn.modules.utils._triple)],
)
@_beartype.beartype
def _avg_pool(name, tuple_fn):
    """Build the symbolic for ``aten::avg_pool{1,2,3}d``.

    Args:
        name: ATen operator name, forwarded to ``_avgpool_helper``.
        tuple_fn: Expands an int/sequence to a per-spatial-dim tuple.

    When ``count_include_pad`` is set, padding is applied by an explicit
    zero-valued ``Pad`` op and ``AveragePool`` then runs without padding, so
    the padded zeros take part in every average.
    """

    @symbolic_helper.quantized_args(True, False, False, False, False, False, False)
    @symbolic_helper.parse_args("v", "is", "is", "is", "i", "i", "none")
    @_beartype.beartype
    def symbolic_fn(
        g,
        input: _C.Value,
        kernel_size: Sequence[int],
        stride: Sequence[int],
        padding: Sequence[int],
        ceil_mode: int,
        count_include_pad: int,
        divisor_override=None,
    ):
        # An empty stride list means "use the kernel size".
        stride = stride or kernel_size
        padding = symbolic_helper._avgpool_helper(
            tuple_fn, padding, kernel_size, stride, divisor_override, name
        )
        assert isinstance(padding, tuple)
        if count_include_pad:
            # Leading (0, 0) leaves the N and C dims unpadded. The tuple is
            # repeated for the begin and end pads.
            explicit_pads = ((0, 0) + padding) * 2
            input = opset9._op_with_optional_float_cast(
                g,
                "Pad",
                input,
                pads_i=explicit_pads,
                mode_s="constant",
                value_f=0.0,
                opset_before=11,
            )
            padding = (0,) * len(padding)
        return g.op(
            "AveragePool",
            input,
            kernel_shape_i=tuple_fn(kernel_size),
            strides_i=tuple_fn(stride),
            pads_i=padding * 2,
            ceil_mode_i=ceil_mode,
        )

    return symbolic_fn
@_onnx_symbolic(
    "aten::upsample_nearest1d",
    decorate=[_apply_params("upsample_nearest1d", 3, "nearest")],
)
@_onnx_symbolic(
    "aten::upsample_nearest2d",
    decorate=[_apply_params("upsample_nearest2d", 4, "nearest")],
)
@_onnx_symbolic(
    "aten::upsample_nearest3d",
    decorate=[_apply_params("upsample_nearest3d", 5, "nearest")],
)
@_onnx_symbolic(
    "aten::upsample_linear1d",
    decorate=[_apply_params("upsample_linear1d", 3, "linear")],
)
@_onnx_symbolic(
    "aten::upsample_bilinear2d",
    decorate=[_apply_params("upsample_bilinear2d", 4, "linear")],
)
@_onnx_symbolic(
    "aten::upsample_trilinear3d",
    decorate=[_apply_params("upsample_trilinear3d", 5, "linear")],
)
@_beartype.beartype
def _interpolate(name, dim, interpolate_mode):
    """Build the symbolic for an ``aten::upsample_*`` op as ONNX ``Resize``.

    Args:
        name: ATen operator name, used in the unimplemented-feature error.
        dim: Input rank (3, 4 or 5 in the registrations above).
        interpolate_mode: ONNX ``Resize`` mode (``"nearest"`` or ``"linear"``).

    ``align_corners=True`` is reported as unimplemented on this path. When no
    scales are given, they are derived from ``output_size``.
    """

    @symbolic_helper.quantized_args(True, False, False)
    @_beartype.beartype
    def symbolic_fn(g, input, output_size, *args):
        scales, align_corners = symbolic_helper._get_interpolate_attributes(
            g, interpolate_mode, args
        )
        symbolic_helper._interpolate_warning(interpolate_mode)
        if symbolic_helper._maybe_get_scalar(align_corners):
            return symbolic_helper._unimplemented(name, "align_corners == True", input)
        if scales is None:
            scales = symbolic_helper._interpolate_size_to_scales(
                g, input, output_size, dim
            )
        return g.op("Resize", input, scales, mode_s=interpolate_mode)

    return symbolic_fn
@_onnx_symbolic("aten::__interpolate")
@_beartype.beartype
def __interpolate(
    g: jit_utils.GraphContext,
    input,
    size,
    scale_factor,
    mode,
    align_corners,
    recompute_scale_factor,
    antialias,
):
    """Export ``aten::__interpolate`` as a single ONNX ``Resize``.

    Scales and the resize mode come from
    ``symbolic_helper._interpolate_get_scales_and_mode``.
    ``recompute_scale_factor`` and ``antialias`` are accepted for signature
    compatibility but are not used by this lowering.
    """
    resize_scales, resize_mode = symbolic_helper._interpolate_get_scales_and_mode(
        g, input, size, scale_factor, mode, align_corners
    )
    return g.op("Resize", input, resize_scales, mode_s=resize_mode)
@_beartype.beartype
def _slice(
    g: jit_utils.GraphContext,
    input,
    axes,
    starts,
    ends,
    steps=None,
    dynamic_slice=False,
):
    """Emit an ONNX ``Slice`` whose starts/ends/axes/steps are graph inputs.

    Args:
        dynamic_slice: If True, ``starts`` and ``ends`` are graph values that
            get unsqueezed along axis 0, and ``axes`` may be a Python int or a
            graph value. Otherwise ``axes``/``starts``/``ends`` (and ``steps``)
            are equal-length Python sequences that become constants.
        steps: Optional steps. When None, the ``steps`` input is omitted.

    A static single-axis slice covering the full range (start 0, end
    INT64_MAX, step 1 or None) returns ``input`` unchanged.
    """
    if not dynamic_slice:
        assert len(starts) == len(ends)
        assert len(starts) == len(axes)
        assert steps is None or len(starts) == len(steps)
        is_identity = (
            len(starts) == 1
            and starts[0] == 0
            and ends[0] == 9223372036854775807
            and (steps is None or (len(steps) == 1 and steps[0] == 1))
        )
        if is_identity:
            return input
        axes = g.op("Constant", value_t=torch.tensor(axes))
        starts = g.op("Constant", value_t=torch.tensor(starts))
        ends = g.op("Constant", value_t=torch.tensor(ends))
    else:
        starts = symbolic_helper._unsqueeze_helper(g, starts, [0])
        ends = symbolic_helper._unsqueeze_helper(g, ends, [0])
        if isinstance(axes, int):
            axes = g.op("Constant", value_t=torch.tensor(axes))
        axes = symbolic_helper._unsqueeze_helper(g, axes, [0])
    if steps is not None:
        steps = g.op("Constant", value_t=torch.tensor(steps))
        return g.op("Slice", input, starts, ends, axes, steps)
    return g.op("Slice", input, starts, ends, axes)
@_onnx_symbolic("aten::slice")
@_beartype.beartype
def slice(g: jit_utils.GraphContext, self, *args):
    """Export ``aten::slice`` for both of its overloads.

    * ``aten::slice(Tensor self, int dim, int? start=None, int? end=None, int step=1) -> Tensor``
    * ``aten::slice(t[] l, int? start=None, int? end=None, int step=1) -> t[]``
      (sliced along dim 0)

    If start, end and dim are all absent or ONNX constants, they are folded
    into Python lists for a static slice. Otherwise a dynamic slice is
    emitted, with missing bounds replaced by 0 / INT64_MAX constants.

    Raises:
        errors.SymbolicValueError: for any other number of arguments.
    """
    int64_max = 9223372036854775807

    if len(args) == 4:
        dim, start, end, step = args
    elif len(args) == 3:
        dim = 0
        start, end, step = args
    else:
        raise errors.SymbolicValueError("Unknown aten::slice signature", self)

    def _is_none_constant(value):
        return value.node().kind() == "prim::Constant" and isinstance(
            value.type(), _C.NoneType
        )

    def _is_onnx_constant(value):
        return value.node().kind() == "onnx::Constant"

    start_is_none = _is_none_constant(start)
    end_is_none = _is_none_constant(end)
    start_is_static = start_is_none or _is_onnx_constant(start)
    end_is_static = isinstance(end, int) or end_is_none or _is_onnx_constant(end)
    dim_is_static = isinstance(dim, int) or _is_onnx_constant(dim)
    step = symbolic_helper._parse_arg(step, "i")

    dynamic_slice = not (start_is_static and end_is_static and dim_is_static)
    if dynamic_slice:
        if start_is_none:
            start = g.op("Constant", value_t=torch.tensor(0))
        if end_is_none:
            end = g.op("Constant", value_t=torch.tensor(int64_max))
    else:
        start = [0 if start_is_none else symbolic_helper._parse_arg(start, "i")]
        end = [int64_max if end_is_none else symbolic_helper._parse_arg(end, "i")]
        dim = [symbolic_helper._parse_arg(dim, "i")]

    return symbolic_helper._slice_helper(
        g,
        self,
        axes=dim,
        starts=start,
        ends=end,
        steps=[step],
        dynamic_slice=dynamic_slice,
    )
@_onnx_symbolic("aten::flip")
@symbolic_helper.parse_args("v", "is")
@_beartype.beartype
def flip(g: jit_utils.GraphContext, input, dims):
    """Export ``aten::flip`` as one reversed ``Slice`` over every axis in ``dims``.

    Each axis is read from its last element (start -1) with step -1, toward
    an end of -(2**63 - 1), which covers the whole axis.
    """
    n_axes = len(dims)
    return symbolic_helper._slice_helper(
        g,
        input,
        axes=dims,
        starts=[-1] * n_axes,
        ends=[-9223372036854775807] * n_axes,
        steps=[-1] * n_axes,
    )
@_onnx_symbolic("aten::fmod")
@_beartype.beartype
def fmod(g: jit_utils.GraphContext, input, other):
    """Export ``aten::fmod`` as ONNX ``Mod`` with the ``fmod`` attribute set to 1.

    ``fmod=1`` selects fmod semantics. ``_floor_divide`` in this module uses
    ``fmod=0`` instead.
    """
    mod_attrs = {"fmod_i": 1}
    return g.op("Mod", input, other, **mod_attrs)
@_onnx_symbolic("aten::embedding_bag")
@symbolic_helper.parse_args("v", "v", "v", "i", "i", "i", "v", "i", "i")
@_beartype.beartype
def embedding_bag(
    g: jit_utils.GraphContext,
    embedding_matrix,
    indices,
    offsets,
    scale_grad_by_freq,
    mode,
    sparse,
    per_sample_weights,
    include_last_offset,
    padding_idx,
):
    """Export ``aten::embedding_bag`` by unrolling one sub-graph per bag.

    The loop over bags runs in Python at export time, so the length of
    ``offsets`` must be statically known. Bag ``i`` gathers the rows
    ``indices[offsets[i]:offsets[i + 1]]`` from ``embedding_matrix``. Those
    rows are optionally multiplied by the matching ``per_sample_weights`` and
    then reduced with sum (``mode == 0``), mean (``mode == 1``) or max (any
    other mode).

    Returns:
        ``(output, None, None, None)``. Only the first of ATen's four outputs
        (output, offset2bag, bag_size, max_indices) is produced.

    Raises:
        RuntimeError: if ``padding_idx`` is a non-negative index.
        Errors from ``symbolic_helper._onnx_unsupported`` for
        ``scale_grad_by_freq`` in training export, or when the size of
        ``offsets`` is unknown.
    """
    if scale_grad_by_freq and GLOBALS.export_training:
        return symbolic_helper._onnx_unsupported(
            "embedding_bag with scale_grad_by_freq for training mode"
        )
    if padding_idx is not None and padding_idx >= 0:
        raise RuntimeError("embedding_bag with padding_idx")
    warnings.warn(
        "Export of embedding_bag with dynamic input/offsets shape is not supported in opset 10. "
        "Please use opset 11 or higher to export model for dynamic input shape."
    )
    offsets_dim_0 = symbolic_helper._get_tensor_dim_size(offsets, 0)
    if offsets_dim_0 is None:
        return symbolic_helper._onnx_unsupported(
            "embedding_bag with unknown shape of offsets for opset 10 is not supported. "
            "please use opset 11 or higher."
        )

    if include_last_offset:
        # The final offset already marks the end of the last bag.
        num_bags = offsets_dim_0 - 1
        bag_bounds = offsets
    else:
        # Append a sentinel end so that bag i always spans
        # [bag_bounds[i], bag_bounds[i + 1]).
        num_bags = offsets_dim_0
        end_sentinel = g.op("Constant", value_t=torch.tensor([sys.maxsize]))
        bag_bounds = g.op("Concat", offsets, end_sentinel, axis_i=0)

    has_per_sample_weights = not symbolic_helper._is_none(per_sample_weights)
    bag_outputs = []
    for i in range(num_bags):
        start_ = symbolic_helper._unsqueeze_helper(
            g,
            opset9.select(g, bag_bounds, torch.tensor(0), torch.tensor(i)),
            [0],
        )
        end_ = symbolic_helper._unsqueeze_helper(
            g,
            opset9.select(g, bag_bounds, torch.tensor(0), torch.tensor(i + 1)),
            [0],
        )
        axes_ = g.op("Constant", value_t=torch.tensor([0]))
        indices_row = g.op("Slice", indices, start_, end_, axes_)
        embeddings = g.op("Gather", embedding_matrix, indices_row)
        if has_per_sample_weights:
            per_sample_weights_row = g.op(
                "Slice", per_sample_weights, start_, end_, axes_
            )
            # Add a trailing axis so the weights broadcast across the embedding dim.
            per_sample_weights_row = symbolic_helper._unsqueeze_helper(
                g, per_sample_weights_row, [1]
            )
            embeddings = g.op("Mul", embeddings, per_sample_weights_row)
        if mode == 0:
            embeddings = symbolic_helper._reducesum_helper(
                g, embeddings, axes_i=[0], keepdims_i=0
            )
        elif mode == 1:
            embeddings = g.op("ReduceMean", embeddings, axes_i=[0], keepdims_i=0)
        else:
            embeddings = g.op("ReduceMax", embeddings, axes_i=[0], keepdims_i=0)
        # Restore a leading bag axis so the per-bag results can be concatenated.
        bag_outputs.append(symbolic_helper._unsqueeze_helper(g, embeddings, [0]))

    output = g.op("Concat", *bag_outputs, axis_i=0)
    # The last three ATen outputs are not used by torch.nn.EmbeddingBag or
    # torch.nn.functional.embedding_bag.
    return output, None, None, None
@_onnx_symbolic("aten::fake_quantize_per_tensor_affine")
@symbolic_helper.parse_args("v", "v", "v", "i", "i")
@_beartype.beartype
def fake_quantize_per_tensor_affine(
    g: jit_utils.GraphContext,
    inputs,
    scale,
    zero_point,
    quant_min=-128,
    quant_max=127,
):
    """Export ``aten::fake_quantize_per_tensor_affine`` as
    ``QuantizeLinear`` followed by ``DequantizeLinear``.

    Opset 10 restrictions enforced here:
      * ``(quant_min, quant_max)`` must be ``(0, 255)`` (uint8 zero point) or
        ``(-128, 127)`` (int8 zero point). ``(0, 127)`` is reported as
        requiring opset 13.
      * ``scale`` must fold to a constant scalar tensor.

    Raises:
        errors.SymbolicValueError: for any other quantization range.
        Errors from ``symbolic_helper._onnx_opset_unsupported_detailed`` for
        the ``(0, 127)`` range or a scale that is not a constant scalar.
    """
    # NOTE: (0, 127) is a special case. PyTorch restricts activations to be in the range (0, 127).
    # https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422
    if (quant_min, quant_max) == (0, 127):
        symbolic_helper._onnx_opset_unsupported_detailed(
            "fake_quantize_per_tensor_affine",
            10,
            13,
            "Quantize range (0, 127) not supported, requires opset 13 Clip",
            inputs,
        )
    if (quant_min, quant_max) not in [(0, 255), (-128, 127)]:
        raise errors.SymbolicValueError(
            f"For (quant_min, quant_max), ONNX allows only (0, 255) and (-128, 127). "
            f"Got ({quant_min}, {quant_max})",
            inputs,
        )
    scale = symbolic_helper._maybe_get_scalar(scale)
    # `_maybe_get_scalar` hands back its argument unchanged when it cannot
    # fold it to a scalar tensor. It never returns None, so the old
    # `scale is None` check could not fire. A non-constant scale then failed
    # with an AttributeError on `.float()` below. Test for the tensor we
    # actually need instead.
    if not isinstance(scale, torch.Tensor):
        symbolic_helper._onnx_opset_unsupported_detailed(
            "fake_quantize_per_tensor_affine",
            10,
            13,
            "Non-constant scale not supported",
            inputs,
        )
    scale = scale.float().data  # Avoid exporter generating double type
    if quant_min == 0:
        zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.UINT8)
    else:
        zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.INT8)
    return g.op(
        "DequantizeLinear",
        g.op("QuantizeLinear", inputs, scale, zero_point),
        scale,
        zero_point,
    )
@_onnx_symbolic("aten::isinf")
@_beartype.beartype
def isinf(g: jit_utils.GraphContext, input):
    """Export ``aten::isinf`` as ONNX ``IsInf`` on the input cast to double.

    The cast goes through ``opset9._cast_Double``.
    """
    as_double = opset9._cast_Double(g, input, False)  # type: ignore[attr-defined]
    return g.op("IsInf", as_double)
@_onnx_symbolic("aten::isfinite")
@_beartype.beartype
def isfinite(g: jit_utils.GraphContext, input):
    """Export ``aten::isfinite`` as ``not (isinf(input) or isnan(input))``."""
    non_finite = opset9.__or_(g, isinf(g, input), opset9.isnan(g, input))
    return opset9.__not_(g, non_finite)
@_onnx_symbolic("aten::quantize_per_tensor")
@_beartype.beartype
def quantize_per_tensor(g: jit_utils.GraphContext, input, scale, zero_point, dtype):
    """Export ``aten::quantize_per_tensor`` via ``symbolic_helper.quantize_helper``.

    ``zero_point`` is cast to the ONNX type of the requested quantized
    ``dtype``, and ``scale`` is cast to float.
    """
    target_dtype = symbolic_helper._get_const(dtype, "i", "dtype")
    zero_point_onnx_type = _type_utils.JitScalarType(target_dtype).onnx_type()
    # TODO(justinchuby): Extract all the cast ops into a helper function.
    zero_point = g.op("Cast", zero_point, to_i=zero_point_onnx_type)
    scale = g.op("Cast", scale, to_i=_C_onnx.TensorProtoDataType.FLOAT)
    return symbolic_helper.quantize_helper(g, input, scale, zero_point)
@_onnx_symbolic("aten::dequantize")
@_beartype.beartype
def dequantize(g: jit_utils.GraphContext, input):
    """Export ``aten::dequantize``.

    Returns only the first item of ``dequantize_helper``'s result, which is
    the dequantized value.
    """
    dequantized, *_ = symbolic_helper.dequantize_helper(g, input)
    return dequantized
@_onnx_symbolic("aten::nan_to_num")
@symbolic_helper.parse_args("v", "f", "f", "f")
@_beartype.beartype
def nan_to_num(g: jit_utils.GraphContext, input, nan, posinf, neginf):
    """Export ``aten::nan_to_num`` as three chained ``Where`` ops.

    NaN becomes ``nan`` (default 0.0). +inf becomes ``posinf`` (default: the
    largest finite value of the input dtype). -inf becomes ``neginf``
    (default: the lowest finite value). Non-floating-point inputs cannot hold
    inf/NaN, so they are returned unchanged.
    """
    if not symbolic_helper._is_fp(input):
        return input
    input_dtype = _type_utils.JitScalarType.from_name(input.type().scalarType()).dtype()

    def _filled(value):
        return g.op("Constant", value_t=torch.tensor([value], dtype=input_dtype))

    def _int64_zero():
        # NOTE(review): a float tensor is compared with an int64 constant here;
        # presumably later type analysis reconciles the dtypes -- confirm.
        return g.op("Constant", value_t=torch.LongTensor([0]))

    nan_value = 0.0 if nan is None else nan
    result = g.op("Where", opset9.isnan(g, input), _filled(nan_value), input)

    finfo = torch.finfo(input_dtype)
    posinf_value = finfo.max if posinf is None else posinf
    is_pos_inf = opset9.logical_and(
        g, isinf(g, result), opset9.gt(g, result, _int64_zero())
    )
    result = g.op("Where", is_pos_inf, _filled(posinf_value), result)

    neginf_value = finfo.min if neginf is None else neginf
    is_neg_inf = opset9.logical_and(
        g, isinf(g, result), opset9.lt(g, result, _int64_zero())
    )
    return g.op("Where", is_neg_inf, _filled(neginf_value), result)
# Quantized symbolics ---------------------------------------------------------
# https://github.com/pytorch/pytorch/wiki/PyTorch-ONNX-exporter#quantized-model-export
# Support starts from opset 10 because `DequantizeLinear` and `QuantizeLinear` were
# introduced in opset version 10.
@_onnx_symbolic("quantized::linear")
@_beartype.beartype
def quantized_linear(
    g: jit_utils.GraphContext, q_input, q_weight, bias, op_scale, op_zero_point
):
    """Export ``quantized::linear`` as dequantize -> float linear -> quantize.

    The bias is first requantized from the input and weight scales
    (``requantize_bias_helper``) and then dequantized, so it enters the float
    ``linear`` in the matching precision. The result is quantized with
    ``op_scale`` / ``op_zero_point``.
    """
    input_fp, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
    weight_fp, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight)
    bias_q = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale)
    bias_fp = symbolic_helper.dequantize_helper(g, bias_q)[0]
    linear_out = opset9.linear(g, input_fp, weight_fp, bias_fp)
    return symbolic_helper.quantize_helper(g, linear_out, op_scale, op_zero_point)
@_onnx_symbolic("quantized::add")
@_beartype.beartype
def quantized_add(g: jit_utils.GraphContext, x, y, op_scale, op_zero_point):
    """Export ``quantized::add``.

    Both operands are dequantized and added in float, and the sum is
    quantized with ``op_scale`` / ``op_zero_point``.
    """
    lhs = symbolic_helper.dequantize_helper(g, x)[0]
    rhs = symbolic_helper.dequantize_helper(g, y)[0]
    total = opset9.add(g, lhs, rhs)
    return symbolic_helper.quantize_helper(g, total, op_scale, op_zero_point)
@_onnx_symbolic("quantized::add_relu")
@_beartype.beartype
def quantized_add_relu(g: jit_utils.GraphContext, x, y, op_scale, op_zero_point):
    """Export ``quantized::add_relu``.

    Both operands are dequantized, added in float and passed through ReLU,
    and the result is quantized with ``op_scale`` / ``op_zero_point``.
    """
    lhs = symbolic_helper.dequantize_helper(g, x)[0]
    rhs = symbolic_helper.dequantize_helper(g, y)[0]
    activated = opset9.relu(g, opset9.add(g, lhs, rhs))
    return symbolic_helper.quantize_helper(g, activated, op_scale, op_zero_point)
@_onnx_symbolic("quantized::mul")
@_beartype.beartype
def quantized_mul(g: jit_utils.GraphContext, x, y, op_scale, op_zero_point):
    """Export ``quantized::mul``.

    Both operands are dequantized and multiplied in float, and the product is
    quantized with ``op_scale`` / ``op_zero_point``.
    """
    lhs = symbolic_helper.dequantize_helper(g, x)[0]
    rhs = symbolic_helper.dequantize_helper(g, y)[0]
    product = opset9.mul(g, lhs, rhs)
    return symbolic_helper.quantize_helper(g, product, op_scale, op_zero_point)
@_onnx_symbolic("quantized::hardswish")
@_beartype.beartype
def quantized_hardswish(g: jit_utils.GraphContext, x, op_scale, op_zero_point):
    """Export ``quantized::hardswish``.

    Applies ``opset9.hardswish`` to the dequantized input, then quantizes the
    result with ``op_scale`` / ``op_zero_point``.
    """
    dequantized_x = symbolic_helper.dequantize_helper(g, x)[0]
    activated = opset9.hardswish(g, dequantized_x)
    return symbolic_helper.quantize_helper(g, activated, op_scale, op_zero_point)
@_onnx_symbolic("quantized::sigmoid")
@_beartype.beartype
def quantized_sigmoid(g: jit_utils.GraphContext, x, op_scale, op_zero_point):
    """Export ``quantized::sigmoid``.

    Applies ``opset9.sigmoid`` to the dequantized input, then quantizes the
    result with ``op_scale`` / ``op_zero_point``.
    """
    dequantized_x = symbolic_helper.dequantize_helper(g, x)[0]
    activated = opset9.sigmoid(g, dequantized_x)
    return symbolic_helper.quantize_helper(g, activated, op_scale, op_zero_point)
@_onnx_symbolic("quantized::leaky_relu")
@_beartype.beartype
def quantized_leaky_relu(
    g: jit_utils.GraphContext, x, negative_slope, inplace, op_scale, op_zero_point
):
    """Export ``quantized::leaky_relu``.

    Applies ``opset9.leaky_relu`` (with ``negative_slope`` and ``inplace``)
    to the dequantized input, then quantizes with ``op_scale`` /
    ``op_zero_point``.
    """
    dequantized_x = symbolic_helper.dequantize_helper(g, x)[0]
    activated = opset9.leaky_relu(g, dequantized_x, negative_slope, inplace)
    return symbolic_helper.quantize_helper(g, activated, op_scale, op_zero_point)
@_onnx_symbolic("quantized::layer_norm")
@_beartype.beartype
def quantized_layer_norm(
    g: jit_utils.GraphContext,
    x,
    normalized_shape,
    weight,
    bias,
    eps,
    op_scale,
    op_zero_point,
):
    """Export ``quantized::layer_norm``.

    Runs ``opset9.layer_norm`` on the dequantized input (``False`` is passed
    as its final positional argument), then quantizes the result with
    ``op_scale`` / ``op_zero_point``.
    """
    dequantized_x = symbolic_helper.dequantize_helper(g, x)[0]
    normalized = opset9.layer_norm(
        g, dequantized_x, normalized_shape, weight, bias, eps, False
    )
    return symbolic_helper.quantize_helper(g, normalized, op_scale, op_zero_point)
@_onnx_symbolic("quantized::group_norm")
@_beartype.beartype
def quantized_group_norm(
    g: jit_utils.GraphContext,
    x,
    num_groups,
    weight,
    bias,
    eps,
    op_scale,
    op_zero_point,
):
    """Export ``quantized::group_norm``.

    Runs ``opset9.group_norm`` on the dequantized input (``False`` is passed
    as its final positional argument), then quantizes the result with
    ``op_scale`` / ``op_zero_point``.
    """
    dequantized_x = symbolic_helper.dequantize_helper(g, x)[0]
    normalized = opset9.group_norm(
        g, dequantized_x, num_groups, weight, bias, eps, False
    )
    return symbolic_helper.quantize_helper(g, normalized, op_scale, op_zero_point)
@_onnx_symbolic("quantized::instance_norm")
@symbolic_helper.parse_args("v", "v", "v", "f", "v", "v")
@_beartype.beartype
def quantized_instance_norm(
    g: jit_utils.GraphContext,
    q_input,
    weight,
    bias,
    eps,
    op_scale,
    op_zero_point,
):
    """Export ``quantized::instance_norm``.

    Runs ``opset9.instance_norm`` on the dequantized input with no running
    statistics (``None, None``), ``False``, momentum ``0.0``, the parsed float
    ``eps`` and a final ``False``. The result is quantized with ``op_scale`` /
    ``op_zero_point``.
    """
    dequantized_input = symbolic_helper.dequantize_helper(g, q_input)[0]
    normalized = opset9.instance_norm(
        g, dequantized_input, weight, bias, None, None, False, 0.0, eps, False
    )
    return symbolic_helper.quantize_helper(g, normalized, op_scale, op_zero_point)
@_onnx_symbolic("quantized::conv2d_relu")
@_beartype.beartype
def quantized_conv2d_relu(
    g: jit_utils.GraphContext,
    q_input,
    q_weight,
    bias,
    stride,
    padding,
    dilation,
    groups,
    op_scale,
    op_zero_point,
):
    """Export ``quantized::conv2d_relu`` as dequantize -> conv2d -> ReLU -> quantize.

    The bias is requantized from the input and weight scales
    (``requantize_bias_helper``) and dequantized before the float convolution.
    """
    input_fp, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
    weight_fp, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight)
    bias_q = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale)
    bias_fp = symbolic_helper.dequantize_helper(g, bias_q)[0]
    conv_out = opset9.conv2d(
        g, input_fp, weight_fp, bias_fp, stride, padding, dilation, groups
    )
    activated = opset9.relu(g, conv_out)
    return symbolic_helper.quantize_helper(g, activated, op_scale, op_zero_point)
@_onnx_symbolic("quantized::conv2d")
@_beartype.beartype
def quantized_conv2d(
    g: jit_utils.GraphContext,
    q_input,
    q_weight,
    bias,
    stride,
    padding,
    dilation,
    groups,
    op_scale,
    op_zero_point,
):
    """Export ``quantized::conv2d`` as dequantize -> float conv2d -> quantize.

    The bias is requantized from the input and weight scales
    (``requantize_bias_helper``) and dequantized before the float convolution.
    """
    input_fp, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
    weight_fp, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight)
    bias_q = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale)
    bias_fp = symbolic_helper.dequantize_helper(g, bias_q)[0]
    conv_out = opset9.conv2d(
        g, input_fp, weight_fp, bias_fp, stride, padding, dilation, groups
    )
    return symbolic_helper.quantize_helper(g, conv_out, op_scale, op_zero_point)
@_onnx_symbolic("quantized::cat")
@symbolic_helper.parse_args("v", "i", "v", "v")
@_beartype.beartype
def quantized_cat(
    g: jit_utils.GraphContext,
    q_inputs: _C.Value,
    dim: int,
    op_scale: _C.Value,
    op_zero_point: _C.Value,
) -> _C.Value:
    """Export ``quantized::cat``.

    Each element of the packed input list is dequantized, the float tensors
    are concatenated along ``dim``, and the result is quantized with
    ``op_scale`` / ``op_zero_point``.
    """
    float_inputs = []
    for q_value in symbolic_helper._unpack_list(q_inputs):
        float_value, *_ = symbolic_helper.dequantize_helper(g, q_value)
        float_inputs.append(float_value)
    joined = g.op("Concat", *float_inputs, axis_i=dim)
    return symbolic_helper.quantize_helper(g, joined, op_scale, op_zero_point)