|
import functools |
|
import sys |
|
import warnings |
|
from typing import Callable, Sequence |
|
|
|
import torch |
|
import torch._C._onnx as _C_onnx |
|
import torch.onnx |
|
from torch import _C |
|
|
|
|
|
from torch.onnx import ( |
|
_patch_torch, |
|
_type_utils, |
|
errors, |
|
symbolic_helper, |
|
symbolic_opset9 as opset9, |
|
) |
|
from torch.onnx._globals import GLOBALS |
|
from torch.onnx._internal import _beartype, jit_utils, registration |
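
# This file exports ONNX ops for opset 10.
# Opset 10 is supported by ONNX release 1.5.0 (released April 2019).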
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
__all__ = [ |
|
"dequantize", |
|
"div", |
|
"embedding_bag", |
|
"fake_quantize_per_tensor_affine", |
|
"flip", |
|
"fmod", |
|
"isfinite", |
|
"isinf", |
|
"nan_to_num", |
|
"quantize_per_tensor", |
|
"quantized_add_relu", |
|
"quantized_add", |
|
"quantized_cat", |
|
"quantized_conv2d_relu", |
|
"quantized_conv2d", |
|
"quantized_group_norm", |
|
"quantized_hardswish", |
|
"quantized_instance_norm", |
|
"quantized_layer_norm", |
|
"quantized_leaky_relu", |
|
"quantized_linear", |
|
"quantized_mul", |
|
"quantized_sigmoid", |
|
"slice", |
|
"sort", |
|
"topk", |
|
] |
|
|
|
|
|
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=10) |
|
|
|
|
|
def _apply_params(*args, **kwargs): |
|
"""Returns a decorator that calls the decorated (higher-order) function with the given parameters.""" |
|
|
|
def _apply(fn): |
|
return fn(*args, **kwargs) |
|
|
|
return _apply |
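
# Illustrative usage: `_apply_params("max_pool2d", torch.nn.modules.utils._pair, 2, return_indices=False)`
# produces a decorator that immediately calls the decorated factory (e.g. `_max_pool` below) with
# those arguments, so `@_onnx_symbolic(..., decorate=[...])` registers the resulting symbolic function.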
|
|
|
|
|
@_onnx_symbolic("aten::div") |
|
@_beartype.beartype |
|
def div(g: jit_utils.GraphContext, self, other, *args): |
|
if len(args) == 0: |
|
return opset9.true_divide(g, self, other) |
|
else: |
|
return _div_rounding_mode(g, self, other, *args) |
|
|
|
|
|
@symbolic_helper.parse_args("v", "v", "s") |
|
@_beartype.beartype |
|
def _div_rounding_mode(g: jit_utils.GraphContext, self, other, rounding_mode): |
|
if rounding_mode == "floor": |
|
return _floor_divide(g, self, other) |
|
else: |
|
return opset9._div_rounding_mode(g, self, other, rounding_mode) |
|
|
|
|
|
@_onnx_symbolic("aten::_floor_divide") |
|
@_beartype.beartype |
|
def _floor_divide(g: jit_utils.GraphContext, self, other): |
|
if symbolic_helper._is_fp(self) or symbolic_helper._is_fp(other): |
|
out = opset9.true_divide(g, self, other) |
|
return g.op("Floor", out) |
|
else: |
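        # Integer Div in ONNX truncates toward zero, so floor division needs a sign-aware fixup.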
|
|
|
div = g.op("Div", self, other) |
|
|
|
zero = g.op("Constant", value_t=torch.tensor(0, dtype=torch.int64)) |
|
negative = g.op("Xor", g.op("Less", self, zero), g.op("Less", other, zero)) |
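        # Subtract 1 from the truncated quotient when the operands have opposite signs
        # and the division is inexact (self % other != 0).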
|
|
|
|
|
mod = g.op("Mod", self, other, fmod_i=0) |
|
fixup_mask = g.op("And", negative, g.op("Not", g.op("Equal", mod, zero))) |
|
|
|
one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)) |
|
fixup = g.op("Sub", div, one) |
|
return g.op("Where", fixup_mask, fixup, div) |
|
|
|
|
|
@_onnx_symbolic("aten::sort") |
|
@symbolic_helper.parse_args("v", "i", "i", "none") |
|
@_beartype.beartype |
|
def sort(g: jit_utils.GraphContext, self, dim, decending, out=None): |
|
return symbolic_helper._sort_helper(g, self, dim, decending=decending, out=out) |
|
|
|
|
|
@_onnx_symbolic("aten::topk") |
|
@symbolic_helper.parse_args("v", "v", "i", "i", "i", "none") |
|
@_beartype.beartype |
|
def topk(g: jit_utils.GraphContext, self, k, dim, largest, sorted, out=None): |
|
return symbolic_helper._topk_helper( |
|
g, self, k, dim, largest=largest, sorted=sorted, out=out |
|
) |
|
|
|
|
|
@_onnx_symbolic( |
|
"aten::max_pool1d", |
|
decorate=[ |
|
_apply_params( |
|
"max_pool1d", torch.nn.modules.utils._single, 1, return_indices=False |
|
) |
|
], |
|
) |
|
@_onnx_symbolic( |
|
"aten::max_pool2d", |
|
decorate=[ |
|
_apply_params( |
|
"max_pool2d", torch.nn.modules.utils._pair, 2, return_indices=False |
|
) |
|
], |
|
) |
|
@_onnx_symbolic( |
|
"aten::max_pool3d", |
|
decorate=[ |
|
_apply_params( |
|
"max_pool3d", torch.nn.modules.utils._triple, 3, return_indices=False |
|
) |
|
], |
|
) |
|
@_onnx_symbolic( |
|
"aten::max_pool1d_with_indices", |
|
decorate=[ |
|
_apply_params( |
|
"max_pool1d_with_indices", |
|
torch.nn.modules.utils._single, |
|
1, |
|
return_indices=True, |
|
) |
|
], |
|
) |
|
@_onnx_symbolic( |
|
"aten::max_pool2d_with_indices", |
|
decorate=[ |
|
_apply_params( |
|
"max_pool2d_with_indices", |
|
torch.nn.modules.utils._pair, |
|
2, |
|
return_indices=True, |
|
) |
|
], |
|
) |
|
@_onnx_symbolic( |
|
"aten::max_pool3d_with_indices", |
|
decorate=[ |
|
_apply_params( |
|
"max_pool3d_with_indices", |
|
torch.nn.modules.utils._triple, |
|
3, |
|
return_indices=True, |
|
) |
|
], |
|
) |
|
@_beartype.beartype |
|
def _max_pool(name: str, tuple_fn: Callable, ndims: int, return_indices: bool): |
|
@symbolic_helper.quantized_args(True, False, False, False, False, False) |
|
@symbolic_helper.parse_args("v", "is", "is", "is", "is", "i") |
|
def symbolic_fn(g, input, kernel_size, stride, padding, dilation, ceil_mode): |
|
if not stride: |
|
stride = kernel_size |
|
kwargs = { |
|
"kernel_shape_i": tuple_fn(kernel_size), |
|
"pads_i": tuple_fn(padding) * 2, |
|
"strides_i": tuple_fn(stride), |
|
"ceil_mode_i": ceil_mode, |
|
} |
|
if set(tuple_fn(dilation)) != {1}: |
|
kwargs["dilations_i"] = tuple_fn(dilation) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if return_indices: |
|
r, indices = g.op("MaxPool", input, outputs=2, **kwargs) |
|
_, flattened_indices = g.op( |
|
"MaxPool", |
|
input, |
|
outputs=2, |
|
kernel_shape_i=[1 for _ in range(ndims)], |
|
strides_i=[1 for _ in range(ndims)], |
|
) |
|
|
|
s = symbolic_helper._slice_helper( |
|
g, |
|
flattened_indices, |
|
axes=[2 + i for i in range(ndims)], |
|
starts=tuple_fn(0), |
|
ends=tuple_fn(1), |
|
) |
|
indices = opset9.sub(g, indices, s) |
|
return r, indices |
|
else: |
|
r = g.op("MaxPool", input, outputs=1, **kwargs) |
|
return r |
|
|
|
return symbolic_fn |
|
|
|
|
|
@_onnx_symbolic( |
|
"aten::avg_pool1d", |
|
decorate=[_apply_params("avg_pool1d", torch.nn.modules.utils._single)], |
|
) |
|
@_onnx_symbolic( |
|
"aten::avg_pool2d", |
|
decorate=[_apply_params("avg_pool2d", torch.nn.modules.utils._pair)], |
|
) |
|
@_onnx_symbolic( |
|
"aten::avg_pool3d", |
|
decorate=[_apply_params("avg_pool3d", torch.nn.modules.utils._triple)], |
|
) |
|
@_beartype.beartype |
|
def _avg_pool(name, tuple_fn): |
|
@symbolic_helper.quantized_args(True, False, False, False, False, False, False) |
|
@symbolic_helper.parse_args("v", "is", "is", "is", "i", "i", "none") |
|
@_beartype.beartype |
|
def symbolic_fn( |
|
g, |
|
input: _C.Value, |
|
kernel_size: Sequence[int], |
|
stride: Sequence[int], |
|
padding: Sequence[int], |
|
ceil_mode: int, |
|
count_include_pad: int, |
|
divisor_override=None, |
|
): |
|
if not stride: |
|
stride = kernel_size |
|
padding = symbolic_helper._avgpool_helper( |
|
tuple_fn, padding, kernel_size, stride, divisor_override, name |
|
) |
|
assert isinstance(padding, tuple) |
|
if count_include_pad: |
|
input = opset9._op_with_optional_float_cast( |
|
g, |
|
"Pad", |
|
input, |
|
pads_i=((0,) * 2 + padding) * 2, |
|
mode_s="constant", |
|
value_f=0.0, |
|
opset_before=11, |
|
) |
|
padding = (0,) * len(padding) |
|
output = g.op( |
|
"AveragePool", |
|
input, |
|
kernel_shape_i=tuple_fn(kernel_size), |
|
strides_i=tuple_fn(stride), |
|
pads_i=padding * 2, |
|
ceil_mode_i=ceil_mode, |
|
) |
|
return output |
|
|
|
return symbolic_fn |
|
|
|
|
|
@_onnx_symbolic( |
|
"aten::upsample_nearest1d", |
|
decorate=[_apply_params("upsample_nearest1d", 3, "nearest")], |
|
) |
|
@_onnx_symbolic( |
|
"aten::upsample_nearest2d", |
|
decorate=[_apply_params("upsample_nearest2d", 4, "nearest")], |
|
) |
|
@_onnx_symbolic( |
|
"aten::upsample_nearest3d", |
|
decorate=[_apply_params("upsample_nearest3d", 5, "nearest")], |
|
) |
|
@_onnx_symbolic( |
|
"aten::upsample_linear1d", |
|
decorate=[_apply_params("upsample_linear1d", 3, "linear")], |
|
) |
|
@_onnx_symbolic( |
|
"aten::upsample_bilinear2d", |
|
decorate=[_apply_params("upsample_bilinear2d", 4, "linear")], |
|
) |
|
@_onnx_symbolic( |
|
"aten::upsample_trilinear3d", |
|
decorate=[_apply_params("upsample_trilinear3d", 5, "linear")], |
|
) |
|
@_beartype.beartype |
|
def _interpolate(name, dim, interpolate_mode): |
|
@symbolic_helper.quantized_args(True, False, False) |
|
@_beartype.beartype |
|
def symbolic_fn(g, input, output_size, *args): |
|
scales, align_corners = symbolic_helper._get_interpolate_attributes( |
|
g, interpolate_mode, args |
|
) |
|
symbolic_helper._interpolate_warning(interpolate_mode) |
|
align_corners = symbolic_helper._maybe_get_scalar(align_corners) |
|
if align_corners: |
|
return symbolic_helper._unimplemented(name, "align_corners == True", input) |
|
if scales is None: |
|
scales = symbolic_helper._interpolate_size_to_scales( |
|
g, input, output_size, dim |
|
) |
|
return g.op("Resize", input, scales, mode_s=interpolate_mode) |
|
|
|
return symbolic_fn |
|
|
|
|
|
@_onnx_symbolic("aten::__interpolate") |
|
@_beartype.beartype |
|
def __interpolate( |
|
g: jit_utils.GraphContext, |
|
input, |
|
size, |
|
scale_factor, |
|
mode, |
|
align_corners, |
|
recompute_scale_factor, |
|
antialias, |
|
): |
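    # recompute_scale_factor and antialias are accepted for signature compatibility
    # but are not used by this opset 10 lowering.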
|
scales, mode = symbolic_helper._interpolate_get_scales_and_mode( |
|
g, input, size, scale_factor, mode, align_corners |
|
) |
|
return g.op("Resize", input, scales, mode_s=mode) |
|
|
|
|
|
@_beartype.beartype |
|
def _slice( |
|
g: jit_utils.GraphContext, |
|
input, |
|
axes, |
|
starts, |
|
ends, |
|
steps=None, |
|
dynamic_slice=False, |
|
): |
|
if dynamic_slice: |
|
starts = symbolic_helper._unsqueeze_helper(g, starts, [0]) |
|
ends = symbolic_helper._unsqueeze_helper(g, ends, [0]) |
|
if isinstance(axes, int): |
|
axes = g.op("Constant", value_t=torch.tensor(axes)) |
|
axes = symbolic_helper._unsqueeze_helper(g, axes, [0]) |
|
else: |
|
assert len(starts) == len(ends) |
|
assert len(starts) == len(axes) |
|
assert steps is None or len(starts) == len(steps) |
|
if ( |
|
len(starts) == 1 |
|
and starts[0] == 0 |
|
and ends[0] == 9223372036854775807 |
|
and (steps is None or (len(steps) == 1 and steps[0] == 1)) |
|
): |
|
return input |
|
axes = g.op("Constant", value_t=torch.tensor(axes)) |
|
starts = g.op("Constant", value_t=torch.tensor(starts)) |
|
ends = g.op("Constant", value_t=torch.tensor(ends)) |
|
if steps is None: |
|
return g.op("Slice", input, starts, ends, axes) |
|
steps = g.op("Constant", value_t=torch.tensor(steps)) |
|
return g.op("Slice", input, starts, ends, axes, steps) |
|
|
|
|
|
@_onnx_symbolic("aten::slice") |
|
@_beartype.beartype |
|
def slice(g: jit_utils.GraphContext, self, *args): |
|
if len(args) == 4: |
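        # aten::slice(Tensor self, int dim, int start, int end, int step) -> Tensor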
|
|
|
dim, start, end, step = args |
|
elif len(args) == 3: |
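        # aten::slice(t[] l, int start, int end, int step) -> t[]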
|
|
|
start, end, step = args |
|
dim = 0 |
|
else: |
|
raise errors.SymbolicValueError("Unknown aten::slice signature", self) |
|
is_start_none = start.node().kind() == "prim::Constant" and isinstance( |
|
start.type(), _C.NoneType |
|
) |
|
is_end_none = end.node().kind() == "prim::Constant" and isinstance( |
|
end.type(), _C.NoneType |
|
) |
|
is_start_onnx_const = start.node().kind() == "onnx::Constant" |
|
is_end_onnx_const = end.node().kind() == "onnx::Constant" |
|
step = symbolic_helper._parse_arg(step, "i") |
|
if ( |
|
(not is_start_none and not is_start_onnx_const) |
|
or (not isinstance(end, int) and not is_end_none and not is_end_onnx_const) |
|
or (not isinstance(dim, int) and dim.node().kind() != "onnx::Constant") |
|
): |
|
dynamic_slice = True |
|
if is_start_none: |
|
start = g.op("Constant", value_t=torch.tensor(0)) |
|
if is_end_none: |
|
end = g.op("Constant", value_t=torch.tensor(9223372036854775807)) |
|
else: |
|
start = [0 if is_start_none else symbolic_helper._parse_arg(start, "i")] |
|
end = [ |
|
9223372036854775807 if is_end_none else symbolic_helper._parse_arg(end, "i") |
|
] |
|
dim = [symbolic_helper._parse_arg(dim, "i")] |
|
dynamic_slice = False |
|
return symbolic_helper._slice_helper( |
|
g, |
|
self, |
|
axes=dim, |
|
starts=start, |
|
ends=end, |
|
steps=[step], |
|
dynamic_slice=dynamic_slice, |
|
) |
|
|
|
|
|
@_onnx_symbolic("aten::flip") |
|
@symbolic_helper.parse_args("v", "is") |
|
@_beartype.beartype |
|
def flip(g: jit_utils.GraphContext, input, dims): |
|
return symbolic_helper._slice_helper( |
|
g, |
|
input, |
|
axes=dims, |
|
starts=[-1] * len(dims), |
|
ends=[-9223372036854775807] * len(dims), |
|
steps=[-1] * len(dims), |
|
) |
|
|
|
|
|
@_onnx_symbolic("aten::fmod") |
|
@_beartype.beartype |
|
def fmod(g: jit_utils.GraphContext, input, other): |
|
return g.op("Mod", input, other, fmod_i=1) |
|
|
|
|
|
@_onnx_symbolic("aten::embedding_bag") |
|
@symbolic_helper.parse_args("v", "v", "v", "i", "i", "i", "v", "i", "i") |
|
@_beartype.beartype |
|
def embedding_bag( |
|
g: jit_utils.GraphContext, |
|
embedding_matrix, |
|
indices, |
|
offsets, |
|
scale_grad_by_freq, |
|
mode, |
|
sparse, |
|
per_sample_weights, |
|
include_last_offset, |
|
padding_idx, |
|
): |
|
if scale_grad_by_freq and GLOBALS.export_training: |
|
return symbolic_helper._onnx_unsupported( |
|
"embedding_bag with scale_grad_by_freq for training mode" |
|
) |
|
if padding_idx is not None and padding_idx >= 0: |
|
raise RuntimeError("embedding_bag with padding_idx") |
|
|
|
warnings.warn( |
|
"Export of embedding_bag with dynamic input/offsets shape is not supported in opset 10. " |
|
"Please use opset 11 or higher to export model for dynamic input shape.'" |
|
) |
|
offsets_dim_0 = symbolic_helper._get_tensor_dim_size(offsets, 0) |
|
if offsets_dim_0 is not None: |
|
if include_last_offset: |
|
offset_len = offsets_dim_0 - 1 |
|
offsets_extended = offsets |
|
else: |
|
offset_len = offsets_dim_0 |
|
offsets_extended = [ |
|
offsets, |
|
g.op("Constant", value_t=torch.tensor([sys.maxsize])), |
|
] |
|
offsets_extended = g.op("Concat", *offsets_extended, axis_i=0) |
|
list_ = [] |
|
for i in range(offset_len): |
|
start_ = symbolic_helper._unsqueeze_helper( |
|
g, |
|
opset9.select(g, offsets_extended, torch.tensor(0), torch.tensor(i)), |
|
[0], |
|
) |
|
end_ = symbolic_helper._unsqueeze_helper( |
|
g, |
|
opset9.select( |
|
g, offsets_extended, torch.tensor(0), torch.tensor(i + 1) |
|
), |
|
[0], |
|
) |
|
axes_ = g.op("Constant", value_t=torch.tensor([0])) |
|
indices_row = g.op("Slice", indices, start_, end_, axes_) |
|
|
|
embeddings = g.op("Gather", embedding_matrix, indices_row) |
|
if not symbolic_helper._is_none(per_sample_weights): |
|
per_sample_weights_row = g.op( |
|
"Slice", per_sample_weights, start_, end_, axes_ |
|
) |
|
per_sample_weights_row = symbolic_helper._unsqueeze_helper( |
|
g, per_sample_weights_row, [1] |
|
) |
|
embeddings = g.op("Mul", embeddings, per_sample_weights_row) |
|
if mode == 0: |
|
embeddings = symbolic_helper._reducesum_helper( |
|
g, embeddings, axes_i=[0], keepdims_i=0 |
|
) |
|
elif mode == 1: |
|
embeddings = g.op("ReduceMean", embeddings, axes_i=[0], keepdims_i=0) |
|
else: |
|
embeddings = g.op("ReduceMax", embeddings, axes_i=[0], keepdims_i=0) |
|
|
|
embeddings = symbolic_helper._unsqueeze_helper(g, embeddings, [0]) |
|
list_.append(embeddings) |
|
|
|
output = g.op("Concat", *list_, axis_i=0) |
|
|
|
|
|
return output, None, None, None |
|
else: |
|
return symbolic_helper._onnx_unsupported( |
|
"embedding_bag with unknown shape of offsets for opset 10 is not supported. " |
|
"please use opset 11 or higher." |
|
) |
|
|
|
|
|
@_onnx_symbolic("aten::fake_quantize_per_tensor_affine") |
|
@symbolic_helper.parse_args("v", "v", "v", "i", "i") |
|
@_beartype.beartype |
|
def fake_quantize_per_tensor_affine( |
|
g: jit_utils.GraphContext, |
|
inputs, |
|
scale, |
|
zero_point, |
|
quant_min=-128, |
|
quant_max=127, |
|
): |
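    # (0, 127) is a special range PyTorch observers may use for activations; it cannot be
    # represented with opset 10 QuantizeLinear and needs the Clip-based lowering from opset 13.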
|
|
|
|
|
if (quant_min, quant_max) == (0, 127): |
|
symbolic_helper._onnx_opset_unsupported_detailed( |
|
"fake_quantize_per_tensor_affine", |
|
10, |
|
13, |
|
"Quantize range (0, 127) not supported, requires opset 13 Clip", |
|
inputs, |
|
) |
|
if (quant_min, quant_max) not in [(0, 255), (-128, 127)]: |
|
raise errors.SymbolicValueError( |
|
f"For (quant_min, quant_max), ONNX allows only (0, 255) and (-128, 127). " |
|
f"Got ({quant_min}, {quant_max})", |
|
inputs, |
|
) |
|
scale = symbolic_helper._maybe_get_scalar(scale) |
|
if scale is None: |
|
symbolic_helper._onnx_opset_unsupported_detailed( |
|
"fake_quantize_per_tensor_affine", |
|
10, |
|
13, |
|
"Non-constant scale not supported", |
|
inputs, |
|
) |
|
scale = scale.float().data |
|
if quant_min == 0: |
|
zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.UINT8) |
|
else: |
|
zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.INT8) |
|
return g.op( |
|
"DequantizeLinear", |
|
g.op("QuantizeLinear", inputs, scale, zero_point), |
|
scale, |
|
zero_point, |
|
) |
|
|
|
|
|
@_onnx_symbolic("aten::isinf") |
|
@_beartype.beartype |
|
def isinf(g: jit_utils.GraphContext, input): |
|
return g.op("IsInf", opset9._cast_Double(g, input, False)) |
|
|
|
|
|
@_onnx_symbolic("aten::isfinite") |
|
@_beartype.beartype |
|
def isfinite(g: jit_utils.GraphContext, input): |
|
inf_node = isinf(g, input) |
|
nan_node = opset9.isnan(g, input) |
|
return opset9.__not_(g, opset9.__or_(g, inf_node, nan_node)) |
|
|
|
|
|
@_onnx_symbolic("aten::quantize_per_tensor") |
|
@_beartype.beartype |
|
def quantize_per_tensor(g: jit_utils.GraphContext, input, scale, zero_point, dtype): |
|
dtype = symbolic_helper._get_const(dtype, "i", "dtype") |
|
|
|
zero_point = g.op( |
|
"Cast", zero_point, to_i=_type_utils.JitScalarType(dtype).onnx_type() |
|
) |
|
scale = g.op("Cast", scale, to_i=_C_onnx.TensorProtoDataType.FLOAT) |
|
return symbolic_helper.quantize_helper(g, input, scale, zero_point) |
|
|
|
|
|
@_onnx_symbolic("aten::dequantize") |
|
@_beartype.beartype |
|
def dequantize(g: jit_utils.GraphContext, input): |
|
return symbolic_helper.dequantize_helper(g, input)[0] |
|
|
|
|
|
@_onnx_symbolic("aten::nan_to_num") |
|
@symbolic_helper.parse_args("v", "f", "f", "f") |
|
@_beartype.beartype |
|
def nan_to_num(g: jit_utils.GraphContext, input, nan, posinf, neginf): |
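    # nan_to_num is only defined for floating-point inputs; integral tensors pass through unchanged.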
|
|
|
|
|
if not symbolic_helper._is_fp(input): |
|
return input |
|
input_dtype = _type_utils.JitScalarType.from_name(input.type().scalarType()).dtype() |
|
if nan is None: |
|
nan = 0.0 |
|
nan_cond = opset9.isnan(g, input) |
|
nan_result = g.op( |
|
"Where", |
|
nan_cond, |
|
g.op("Constant", value_t=torch.tensor([nan], dtype=input_dtype)), |
|
input, |
|
) |
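    # When posinf / neginf are not given, substitute the largest / smallest finite value
    # representable by the input dtype, matching torch.nan_to_num defaults.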
|
|
|
|
|
|
|
finfo = torch.finfo(input_dtype) |
|
if posinf is None: |
|
posinf = finfo.max |
|
posinf_cond = opset9.logical_and( |
|
g, |
|
isinf(g, nan_result), |
|
opset9.gt(g, nan_result, g.op("Constant", value_t=torch.LongTensor([0]))), |
|
) |
|
nan_posinf_result = g.op( |
|
"Where", |
|
posinf_cond, |
|
g.op("Constant", value_t=torch.tensor([posinf], dtype=input_dtype)), |
|
nan_result, |
|
) |
|
|
|
if neginf is None: |
|
neginf = finfo.min |
|
neginf_cond = opset9.logical_and( |
|
g, |
|
isinf(g, nan_posinf_result), |
|
opset9.lt( |
|
g, nan_posinf_result, g.op("Constant", value_t=torch.LongTensor([0])) |
|
), |
|
) |
|
return g.op( |
|
"Where", |
|
neginf_cond, |
|
g.op("Constant", value_t=torch.tensor([neginf], dtype=input_dtype)), |
|
nan_posinf_result, |
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
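# Quantized op symbolics ------------------------------------------------------
# These lowerings dequantize the inputs, run the fp32 opset-9 symbolic, and
# re-quantize the result with the requested output scale and zero point.
# Quantized export starts at opset 10 because QuantizeLinear/DequantizeLinear
# were introduced in ONNX opset 10.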
|
@_onnx_symbolic("quantized::linear") |
|
@_beartype.beartype |
|
def quantized_linear( |
|
g: jit_utils.GraphContext, q_input, q_weight, bias, op_scale, op_zero_point |
|
): |
|
input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) |
|
weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight) |
|
q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale) |
|
bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) |
|
|
|
output = opset9.linear(g, input, weight, bias) |
|
|
|
return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) |
|
|
|
|
|
@_onnx_symbolic("quantized::add") |
|
@_beartype.beartype |
|
def quantized_add(g: jit_utils.GraphContext, x, y, op_scale, op_zero_point): |
|
x, _, _, _ = symbolic_helper.dequantize_helper(g, x) |
|
y, _, _, _ = symbolic_helper.dequantize_helper(g, y) |
|
|
|
output = opset9.add(g, x, y) |
|
|
|
return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) |
|
|
|
|
|
@_onnx_symbolic("quantized::add_relu") |
|
@_beartype.beartype |
|
def quantized_add_relu(g: jit_utils.GraphContext, x, y, op_scale, op_zero_point): |
|
x, _, _, _ = symbolic_helper.dequantize_helper(g, x) |
|
y, _, _, _ = symbolic_helper.dequantize_helper(g, y) |
|
|
|
output = opset9.add(g, x, y) |
|
output = opset9.relu(g, output) |
|
|
|
return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) |
|
|
|
|
|
@_onnx_symbolic("quantized::mul") |
|
@_beartype.beartype |
|
def quantized_mul(g: jit_utils.GraphContext, x, y, op_scale, op_zero_point): |
|
x, _, _, _ = symbolic_helper.dequantize_helper(g, x) |
|
y, _, _, _ = symbolic_helper.dequantize_helper(g, y) |
|
|
|
output = opset9.mul(g, x, y) |
|
|
|
return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) |
|
|
|
|
|
@_onnx_symbolic("quantized::hardswish") |
|
@_beartype.beartype |
|
def quantized_hardswish(g: jit_utils.GraphContext, x, op_scale, op_zero_point): |
|
x, _, _, _ = symbolic_helper.dequantize_helper(g, x) |
|
|
|
output = opset9.hardswish(g, x) |
|
|
|
return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) |
|
|
|
|
|
@_onnx_symbolic("quantized::sigmoid") |
|
@_beartype.beartype |
|
def quantized_sigmoid(g: jit_utils.GraphContext, x, op_scale, op_zero_point): |
|
x, _, _, _ = symbolic_helper.dequantize_helper(g, x) |
|
|
|
output = opset9.sigmoid(g, x) |
|
|
|
return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) |
|
|
|
|
|
@_onnx_symbolic("quantized::leaky_relu") |
|
@_beartype.beartype |
|
def quantized_leaky_relu( |
|
g: jit_utils.GraphContext, x, negative_slope, inplace, op_scale, op_zero_point |
|
): |
|
x, _, _, _ = symbolic_helper.dequantize_helper(g, x) |
|
|
|
output = opset9.leaky_relu(g, x, negative_slope, inplace) |
|
|
|
return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) |
|
|
|
|
|
@_onnx_symbolic("quantized::layer_norm") |
|
@_beartype.beartype |
|
def quantized_layer_norm( |
|
g: jit_utils.GraphContext, |
|
x, |
|
normalized_shape, |
|
weight, |
|
bias, |
|
eps, |
|
op_scale, |
|
op_zero_point, |
|
): |
|
x, _, _, _ = symbolic_helper.dequantize_helper(g, x) |
|
|
|
output = opset9.layer_norm(g, x, normalized_shape, weight, bias, eps, False) |
|
|
|
return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) |
|
|
|
|
|
@_onnx_symbolic("quantized::group_norm") |
|
@_beartype.beartype |
|
def quantized_group_norm( |
|
g: jit_utils.GraphContext, |
|
x, |
|
num_groups, |
|
weight, |
|
bias, |
|
eps, |
|
op_scale, |
|
op_zero_point, |
|
): |
|
x, _, _, _ = symbolic_helper.dequantize_helper(g, x) |
|
|
|
output = opset9.group_norm(g, x, num_groups, weight, bias, eps, False) |
|
|
|
return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) |
|
|
|
|
|
@_onnx_symbolic("quantized::instance_norm") |
|
@symbolic_helper.parse_args("v", "v", "v", "f", "v", "v") |
|
@_beartype.beartype |
|
def quantized_instance_norm( |
|
g: jit_utils.GraphContext, |
|
q_input, |
|
weight, |
|
bias, |
|
eps, |
|
op_scale, |
|
op_zero_point, |
|
): |
|
input, _, _, _ = symbolic_helper.dequantize_helper(g, q_input) |
|
|
|
output = opset9.instance_norm( |
|
g, input, weight, bias, None, None, False, 0.0, eps, False |
|
) |
|
|
|
return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) |
|
|
|
|
|
@_onnx_symbolic("quantized::conv2d_relu") |
|
@_beartype.beartype |
|
def quantized_conv2d_relu( |
|
g: jit_utils.GraphContext, |
|
q_input, |
|
q_weight, |
|
bias, |
|
stride, |
|
padding, |
|
dilation, |
|
groups, |
|
op_scale, |
|
op_zero_point, |
|
): |
|
input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) |
|
weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight) |
|
q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale) |
|
bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) |
|
|
|
output = opset9.conv2d(g, input, weight, bias, stride, padding, dilation, groups) |
|
output = opset9.relu(g, output) |
|
|
|
return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) |
|
|
|
|
|
@_onnx_symbolic("quantized::conv2d") |
|
@_beartype.beartype |
|
def quantized_conv2d( |
|
g: jit_utils.GraphContext, |
|
q_input, |
|
q_weight, |
|
bias, |
|
stride, |
|
padding, |
|
dilation, |
|
groups, |
|
op_scale, |
|
op_zero_point, |
|
): |
|
input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) |
|
weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight) |
|
q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale) |
|
bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) |
|
|
|
output = opset9.conv2d(g, input, weight, bias, stride, padding, dilation, groups) |
|
|
|
return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) |
|
|
|
|
|
@_onnx_symbolic("quantized::cat") |
|
@symbolic_helper.parse_args("v", "i", "v", "v") |
|
@_beartype.beartype |
|
def quantized_cat( |
|
g: jit_utils.GraphContext, |
|
q_inputs: _C.Value, |
|
dim: int, |
|
op_scale: _C.Value, |
|
op_zero_point: _C.Value, |
|
) -> _C.Value: |
|
unpacked_inputs = symbolic_helper._unpack_list(q_inputs) |
|
dequantized = [ |
|
symbolic_helper.dequantize_helper(g, input)[0] for input in unpacked_inputs |
|
] |
|
concatenated = g.op("Concat", *dequantized, axis_i=dim) |
|
return symbolic_helper.quantize_helper(g, concatenated, op_scale, op_zero_point) |
|
|