/*
 * SPDX-License-Identifier: Apache-2.0
 */

#pragma once

#include "onnx/defs/shape_inference.h"
#include "onnx/defs/tensor_proto_util.h"
#include "onnx/onnx_pb.h"

namespace ONNX_NAMESPACE {
namespace defs {
namespace math {
namespace utils {

// Reads the first element of tensor `t` and converts it to T.
// Returns a default-constructed T when the tensor is absent.
template <typename T>
T GetScalarValueFromTensor(const ONNX_NAMESPACE::TensorProto* t) {
  if (t == nullptr) {
    return T{};
  }

  auto data_type = t->data_type();
  switch (data_type) {
    case ONNX_NAMESPACE::TensorProto::FLOAT:
      return static_cast<T>(ONNX_NAMESPACE::ParseData<float>(t).at(0));
    case ONNX_NAMESPACE::TensorProto::DOUBLE:
      return static_cast<T>(ONNX_NAMESPACE::ParseData<double>(t).at(0));
    case ONNX_NAMESPACE::TensorProto::INT32:
      return static_cast<T>(ONNX_NAMESPACE::ParseData<int32_t>(t).at(0));
    case ONNX_NAMESPACE::TensorProto::INT64:
      return static_cast<T>(ONNX_NAMESPACE::ParseData<int64_t>(t).at(0));
    default:
      fail_shape_inference("Unsupported input data type of ", data_type);
  }
}

void MatMulShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, int input1Idx, int input2Idx);

void QLinearMatMulShapeInference(ONNX_NAMESPACE::InferenceContext& ctx);

const char* QLinearMatMulDoc();

} // namespace utils
} // namespace math
} // namespace defs
} // namespace ONNX_NAMESPACE
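
// Usage sketch (illustrative only, not part of this header): how an operator's
// shape inference function might call GetScalarValueFromTensor to read a scalar
// input during inference. The function name and input index below are
// hypothetical; getInputData is the standard InferenceContext accessor and may
// return nullptr when the input is not a constant.
//
//   void ExampleScalarInputInference(ONNX_NAMESPACE::InferenceContext& ctx) {
//     const ONNX_NAMESPACE::TensorProto* scalar_tensor = ctx.getInputData(1);
//     if (scalar_tensor != nullptr) {
//       int64_t value =
//           ONNX_NAMESPACE::defs::math::utils::GetScalarValueFromTensor<int64_t>(scalar_tensor);
//       // ... use `value` to compute the output shape ...
//     }
//   }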