Added Reduced onnxruntime.xcframework
Browse files- .gitattributes +1 -0
- 1.15.1/onnxruntime.xcframework/Headers/coreml_provider_factory.h +41 -0
- 1.15.1/onnxruntime.xcframework/Headers/cpu_provider_factory.h +19 -0
- 1.15.1/onnxruntime.xcframework/Headers/onnxruntime_c_api.h +0 -0
- 1.15.1/onnxruntime.xcframework/Headers/onnxruntime_cxx_api.h +1878 -0
- 1.15.1/onnxruntime.xcframework/Headers/onnxruntime_cxx_inline.h +1888 -0
- 1.15.1/onnxruntime.xcframework/Info.plist +40 -0
- 1.15.1/onnxruntime.xcframework/ios-arm64/onnxruntime.a +3 -0
- 1.15.1/onnxruntime.xcframework/ios-arm64_x86_64-simulator/onnxruntime.a +3 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
*.a filter=lfs diff=lfs merge=lfs -text
|
1.15.1/onnxruntime.xcframework/Headers/coreml_provider_factory.h
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Copyright (c) Microsoft Corporation. All rights reserved.
|
2 |
+
// Licensed under the MIT License.
|
3 |
+
#pragma once
|
4 |
+
|
5 |
+
#include "onnxruntime_c_api.h"
|
6 |
+
|
7 |
+
// COREMLFlags are bool options we want to set for CoreML EP
|
8 |
+
// This enum is defined as bit flags, and cannot have negative value
|
9 |
+
// To generate an uint32_t coreml_flags for using with OrtSessionOptionsAppendExecutionProvider_CoreML below,
|
10 |
+
// uint32_t coreml_flags = 0;
|
11 |
+
// coreml_flags |= COREML_FLAG_USE_CPU_ONLY;
|
12 |
+
enum COREMLFlags {
|
13 |
+
COREML_FLAG_USE_NONE = 0x000,
|
14 |
+
|
15 |
+
// Using CPU only in CoreML EP, this may decrease the perf but will provide
|
16 |
+
// reference output value without precision loss, which is useful for validation
|
17 |
+
COREML_FLAG_USE_CPU_ONLY = 0x001,
|
18 |
+
|
19 |
+
// Enable CoreML EP on subgraph
|
20 |
+
COREML_FLAG_ENABLE_ON_SUBGRAPH = 0x002,
|
21 |
+
|
22 |
+
// By default CoreML Execution provider will be enabled for all compatible Apple devices
|
23 |
+
// Enable this option will only enable CoreML EP for Apple devices with ANE (Apple Neural Engine)
|
24 |
+
// Please note, enable this option does not guarantee the entire model to be executed using ANE only
|
25 |
+
COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE = 0x004,
|
26 |
+
|
27 |
+
// Keep COREML_FLAG_MAX at the end of the enum definition
|
28 |
+
// And assign the last COREMLFlag to it
|
29 |
+
COREML_FLAG_LAST = COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE,
|
30 |
+
};
|
31 |
+
|
32 |
+
#ifdef __cplusplus
|
33 |
+
extern "C" {
|
34 |
+
#endif
|
35 |
+
|
36 |
+
ORT_EXPORT ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_CoreML,
|
37 |
+
_In_ OrtSessionOptions* options, uint32_t coreml_flags);
|
38 |
+
|
39 |
+
#ifdef __cplusplus
|
40 |
+
}
|
41 |
+
#endif
|
1.15.1/onnxruntime.xcframework/Headers/cpu_provider_factory.h
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Copyright (c) Microsoft Corporation. All rights reserved.
|
2 |
+
// Licensed under the MIT License.
|
3 |
+
|
4 |
+
#include "onnxruntime_c_api.h"
|
5 |
+
|
6 |
+
#ifdef __cplusplus
|
7 |
+
extern "C" {
|
8 |
+
#endif
|
9 |
+
|
10 |
+
/**
|
11 |
+
* \param use_arena zero: false. non-zero: true.
|
12 |
+
*/
|
13 |
+
ORT_EXPORT
|
14 |
+
ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_CPU, _In_ OrtSessionOptions* options, int use_arena)
|
15 |
+
ORT_ALL_ARGS_NONNULL;
|
16 |
+
|
17 |
+
#ifdef __cplusplus
|
18 |
+
}
|
19 |
+
#endif
|
1.15.1/onnxruntime.xcframework/Headers/onnxruntime_c_api.h
ADDED
The diff for this file is too large to render.
See raw diff
|
|
1.15.1/onnxruntime.xcframework/Headers/onnxruntime_cxx_api.h
ADDED
@@ -0,0 +1,1878 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Copyright (c) Microsoft Corporation. All rights reserved.
|
2 |
+
// Licensed under the MIT License.
|
3 |
+
|
4 |
+
// Summary: The Ort C++ API is a header only wrapper around the Ort C API.
|
5 |
+
//
|
6 |
+
// The C++ API simplifies usage by returning values directly instead of error codes, throwing exceptions on errors
|
7 |
+
// and automatically releasing resources in the destructors. The primary purpose of C++ API is exception safety so
|
8 |
+
// all the resources follow RAII and do not leak memory.
|
9 |
+
//
|
10 |
+
// Each of the C++ wrapper classes holds only a pointer to the C internal object. Treat them like smart pointers.
|
11 |
+
// To create an empty object, pass 'nullptr' to the constructor (for example, Env e{nullptr};). However, you can't use them
|
12 |
+
// until you assign an instance that actually holds an underlying object.
|
13 |
+
//
|
14 |
+
// For Ort objects only move assignment between objects is allowed, there are no copy constructors.
|
15 |
+
// Some objects have explicit 'Clone' methods for this purpose.
|
16 |
+
//
|
17 |
+
// ConstXXXX types are copyable since they do not own the underlying C object, so you can pass them to functions as arguments
|
18 |
+
// by value or by reference. ConstXXXX types are restricted to const only interfaces.
|
19 |
+
//
|
20 |
+
// UnownedXXXX are similar to ConstXXXX but also allow non-const interfaces.
|
21 |
+
//
|
22 |
+
// The lifetime of the corresponding owning object must eclipse the lifetimes of the ConstXXXX/UnownedXXXX types. They exists so you do not
|
23 |
+
// have to fallback to C types and the API with the usual pitfalls. In general, do not use C API from your C++ code.
|
24 |
+
|
25 |
+
#pragma once
|
26 |
+
#include "onnxruntime_c_api.h"
|
27 |
+
#include <cstddef>
|
28 |
+
#include <array>
|
29 |
+
#include <memory>
|
30 |
+
#include <stdexcept>
|
31 |
+
#include <string>
|
32 |
+
#include <vector>
|
33 |
+
#include <unordered_map>
|
34 |
+
#include <utility>
|
35 |
+
#include <type_traits>
|
36 |
+
|
37 |
+
#ifdef ORT_NO_EXCEPTIONS
|
38 |
+
#include <iostream>
|
39 |
+
#endif
|
40 |
+
|
41 |
+
/** \brief All C++ Onnxruntime APIs are defined inside this namespace
|
42 |
+
*
|
43 |
+
*/
|
44 |
+
namespace Ort {
|
45 |
+
|
46 |
+
/** \brief All C++ methods that can fail will throw an exception of this type
|
47 |
+
*
|
48 |
+
* If <tt>ORT_NO_EXCEPTIONS</tt> is defined, then any error will result in a call to abort()
|
49 |
+
*/
|
50 |
+
struct Exception : std::exception {
|
51 |
+
Exception(std::string&& string, OrtErrorCode code) : message_{std::move(string)}, code_{code} {}
|
52 |
+
|
53 |
+
OrtErrorCode GetOrtErrorCode() const { return code_; }
|
54 |
+
const char* what() const noexcept override { return message_.c_str(); }
|
55 |
+
|
56 |
+
private:
|
57 |
+
std::string message_;
|
58 |
+
OrtErrorCode code_;
|
59 |
+
};
|
60 |
+
|
61 |
+
#ifdef ORT_NO_EXCEPTIONS
|
62 |
+
// The #ifndef is for the very special case where the user of this library wants to define their own way of handling errors.
|
63 |
+
// NOTE: This header expects control flow to not continue after calling ORT_CXX_API_THROW
|
64 |
+
#ifndef ORT_CXX_API_THROW
|
65 |
+
#define ORT_CXX_API_THROW(string, code) \
|
66 |
+
do { \
|
67 |
+
std::cerr << Ort::Exception(string, code) \
|
68 |
+
.what() \
|
69 |
+
<< std::endl; \
|
70 |
+
abort(); \
|
71 |
+
} while (false)
|
72 |
+
#endif
|
73 |
+
#else
|
74 |
+
#define ORT_CXX_API_THROW(string, code) \
|
75 |
+
throw Ort::Exception(string, code)
|
76 |
+
#endif
|
77 |
+
|
78 |
+
// This is used internally by the C++ API. This class holds the global variable that points to the OrtApi,
|
79 |
+
// it's in a template so that we can define a global variable in a header and make
|
80 |
+
// it transparent to the users of the API.
|
81 |
+
template <typename T>
|
82 |
+
struct Global {
|
83 |
+
static const OrtApi* api_;
|
84 |
+
};
|
85 |
+
|
86 |
+
// If macro ORT_API_MANUAL_INIT is defined, no static initialization will be performed. Instead, user must call InitApi() before using it.
|
87 |
+
template <typename T>
|
88 |
+
#ifdef ORT_API_MANUAL_INIT
|
89 |
+
const OrtApi* Global<T>::api_{};
|
90 |
+
inline void InitApi() noexcept { Global<void>::api_ = OrtGetApiBase()->GetApi(ORT_API_VERSION); }
|
91 |
+
|
92 |
+
// Used by custom operator libraries that are not linked to onnxruntime. Sets the global API object, which is
|
93 |
+
// required by C++ APIs.
|
94 |
+
//
|
95 |
+
// Example mycustomop.cc:
|
96 |
+
//
|
97 |
+
// #define ORT_API_MANUAL_INIT
|
98 |
+
// #include <onnxruntime_cxx_api.h>
|
99 |
+
// #undef ORT_API_MANUAL_INIT
|
100 |
+
//
|
101 |
+
// OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, const OrtApiBase* api_base) {
|
102 |
+
// Ort::InitApi(api_base->GetApi(ORT_API_VERSION));
|
103 |
+
// // ...
|
104 |
+
// }
|
105 |
+
//
|
106 |
+
inline void InitApi(const OrtApi* api) noexcept { Global<void>::api_ = api; }
|
107 |
+
#else
|
108 |
+
#if defined(_MSC_VER) && !defined(__clang__)
|
109 |
+
#pragma warning(push)
|
110 |
+
// "Global initializer calls a non-constexpr function." Therefore you can't use ORT APIs in the other global initializers.
|
111 |
+
// Please define ORT_API_MANUAL_INIT if it conerns you.
|
112 |
+
#pragma warning(disable : 26426)
|
113 |
+
#endif
|
114 |
+
const OrtApi* Global<T>::api_ = OrtGetApiBase()->GetApi(ORT_API_VERSION);
|
115 |
+
#if defined(_MSC_VER) && !defined(__clang__)
|
116 |
+
#pragma warning(pop)
|
117 |
+
#endif
|
118 |
+
#endif
|
119 |
+
|
120 |
+
/// This returns a reference to the OrtApi interface in use
|
121 |
+
inline const OrtApi& GetApi() noexcept { return *Global<void>::api_; }
|
122 |
+
|
123 |
+
/// <summary>
|
124 |
+
/// This is a C++ wrapper for OrtApi::GetAvailableProviders() and
|
125 |
+
/// returns a vector of strings representing the available execution providers.
|
126 |
+
/// </summary>
|
127 |
+
/// <returns>vector of strings</returns>
|
128 |
+
std::vector<std::string> GetAvailableProviders();
|
129 |
+
|
130 |
+
/** \brief IEEE 754 half-precision floating point data type
|
131 |
+
* \details It is necessary for type dispatching to make use of C++ API
|
132 |
+
* The type is implicitly convertible to/from uint16_t.
|
133 |
+
* The size of the structure should align with uint16_t and one can freely cast
|
134 |
+
* uint16_t buffers to/from Ort::Float16_t to feed and retrieve data.
|
135 |
+
*
|
136 |
+
* Generally, you can feed any of your types as float16/blfoat16 data to create a tensor
|
137 |
+
* on top of it, providing it can form a continuous buffer with 16-bit elements with no padding.
|
138 |
+
* And you can also feed a array of uint16_t elements directly. For example,
|
139 |
+
*
|
140 |
+
* \code{.unparsed}
|
141 |
+
* uint16_t values[] = { 15360, 16384, 16896, 17408, 17664};
|
142 |
+
* constexpr size_t values_length = sizeof(values) / sizeof(values[0]);
|
143 |
+
* std::vector<int64_t> dims = {values_length}; // one dimensional example
|
144 |
+
* Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault);
|
145 |
+
* // Note we are passing bytes count in this api, not number of elements -> sizeof(values)
|
146 |
+
* auto float16_tensor = Ort::Value::CreateTensor(info, values, sizeof(values),
|
147 |
+
* dims.data(), dims.size(), ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16);
|
148 |
+
* \endcode
|
149 |
+
*
|
150 |
+
* Here is another example, a little bit more elaborate. Let's assume that you use your own float16 type and you want to use
|
151 |
+
* a templated version of the API above so the type is automatically set based on your type. You will need to supply an extra
|
152 |
+
* template specialization.
|
153 |
+
*
|
154 |
+
* \code{.unparsed}
|
155 |
+
* namespace yours { struct half {}; } // assume this is your type, define this:
|
156 |
+
* namespace Ort {
|
157 |
+
* template<>
|
158 |
+
* struct TypeToTensorType<yours::half> { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16; };
|
159 |
+
* } //namespace Ort
|
160 |
+
*
|
161 |
+
* std::vector<yours::half> values;
|
162 |
+
* std::vector<int64_t> dims = {values.size()}; // one dimensional example
|
163 |
+
* Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault);
|
164 |
+
* // Here we are passing element count -> values.size()
|
165 |
+
* auto float16_tensor = Ort::Value::CreateTensor<yours::half>(info, values.data(), values.size(), dims.data(), dims.size());
|
166 |
+
*
|
167 |
+
* \endcode
|
168 |
+
*/
|
169 |
+
struct Float16_t {
|
170 |
+
uint16_t value;
|
171 |
+
constexpr Float16_t() noexcept : value(0) {}
|
172 |
+
constexpr Float16_t(uint16_t v) noexcept : value(v) {}
|
173 |
+
constexpr operator uint16_t() const noexcept { return value; }
|
174 |
+
constexpr bool operator==(const Float16_t& rhs) const noexcept { return value == rhs.value; };
|
175 |
+
constexpr bool operator!=(const Float16_t& rhs) const noexcept { return value != rhs.value; };
|
176 |
+
};
|
177 |
+
|
178 |
+
static_assert(sizeof(Float16_t) == sizeof(uint16_t), "Sizes must match");
|
179 |
+
|
180 |
+
/** \brief bfloat16 (Brain Floating Point) data type
|
181 |
+
* \details It is necessary for type dispatching to make use of C++ API
|
182 |
+
* The type is implicitly convertible to/from uint16_t.
|
183 |
+
* The size of the structure should align with uint16_t and one can freely cast
|
184 |
+
* uint16_t buffers to/from Ort::BFloat16_t to feed and retrieve data.
|
185 |
+
*
|
186 |
+
* See also code examples for Float16_t above.
|
187 |
+
*/
|
188 |
+
struct BFloat16_t {
|
189 |
+
uint16_t value;
|
190 |
+
constexpr BFloat16_t() noexcept : value(0) {}
|
191 |
+
constexpr BFloat16_t(uint16_t v) noexcept : value(v) {}
|
192 |
+
constexpr operator uint16_t() const noexcept { return value; }
|
193 |
+
constexpr bool operator==(const BFloat16_t& rhs) const noexcept { return value == rhs.value; };
|
194 |
+
constexpr bool operator!=(const BFloat16_t& rhs) const noexcept { return value != rhs.value; };
|
195 |
+
};
|
196 |
+
|
197 |
+
static_assert(sizeof(BFloat16_t) == sizeof(uint16_t), "Sizes must match");
|
198 |
+
|
199 |
+
namespace detail {
|
200 |
+
// This is used internally by the C++ API. This macro is to make it easy to generate overloaded methods for all of the various OrtRelease* functions for every Ort* type
|
201 |
+
// This can't be done in the C API since C doesn't have function overloading.
|
202 |
+
#define ORT_DEFINE_RELEASE(NAME) \
|
203 |
+
inline void OrtRelease(Ort##NAME* ptr) { GetApi().Release##NAME(ptr); }
|
204 |
+
|
205 |
+
ORT_DEFINE_RELEASE(Allocator);
|
206 |
+
ORT_DEFINE_RELEASE(MemoryInfo);
|
207 |
+
ORT_DEFINE_RELEASE(CustomOpDomain);
|
208 |
+
ORT_DEFINE_RELEASE(ThreadingOptions);
|
209 |
+
ORT_DEFINE_RELEASE(Env);
|
210 |
+
ORT_DEFINE_RELEASE(RunOptions);
|
211 |
+
ORT_DEFINE_RELEASE(Session);
|
212 |
+
ORT_DEFINE_RELEASE(SessionOptions);
|
213 |
+
ORT_DEFINE_RELEASE(TensorTypeAndShapeInfo);
|
214 |
+
ORT_DEFINE_RELEASE(SequenceTypeInfo);
|
215 |
+
ORT_DEFINE_RELEASE(MapTypeInfo);
|
216 |
+
ORT_DEFINE_RELEASE(TypeInfo);
|
217 |
+
ORT_DEFINE_RELEASE(Value);
|
218 |
+
ORT_DEFINE_RELEASE(ModelMetadata);
|
219 |
+
ORT_DEFINE_RELEASE(IoBinding);
|
220 |
+
ORT_DEFINE_RELEASE(ArenaCfg);
|
221 |
+
ORT_DEFINE_RELEASE(Status);
|
222 |
+
ORT_DEFINE_RELEASE(OpAttr);
|
223 |
+
ORT_DEFINE_RELEASE(Op);
|
224 |
+
ORT_DEFINE_RELEASE(KernelInfo);
|
225 |
+
|
226 |
+
#undef ORT_DEFINE_RELEASE
|
227 |
+
|
228 |
+
/** \brief This is a tagging template type. Use it with Base<T> to indicate that the C++ interface object
|
229 |
+
* has no ownership of the underlying C object.
|
230 |
+
*/
|
231 |
+
template <typename T>
|
232 |
+
struct Unowned {
|
233 |
+
using Type = T;
|
234 |
+
};
|
235 |
+
|
236 |
+
/** \brief Used internally by the C++ API. C++ wrapper types inherit from this.
|
237 |
+
* This is a zero cost abstraction to wrap the C API objects and delete them on destruction.
|
238 |
+
*
|
239 |
+
* All of the C++ classes
|
240 |
+
* a) serve as containers for pointers to objects that are created by the underlying C API.
|
241 |
+
* Their size is just a pointer size, no need to dynamically allocate them. Use them by value.
|
242 |
+
* b) Each of struct XXXX, XXX instances function as smart pointers to the underlying C API objects.
|
243 |
+
* they would release objects owned automatically when going out of scope, they are move-only.
|
244 |
+
* c) ConstXXXX and UnownedXXX structs function as non-owning, copyable containers for the above pointers.
|
245 |
+
* ConstXXXX allow calling const interfaces only. They give access to objects that are owned by somebody else
|
246 |
+
* such as Onnxruntime or instances of XXXX classes.
|
247 |
+
* d) serve convenient interfaces that return C++ objects and further enhance exception and type safety so they can be used
|
248 |
+
* in C++ code.
|
249 |
+
*
|
250 |
+
*/
|
251 |
+
|
252 |
+
/// <summary>
|
253 |
+
/// This is a non-const pointer holder that is move-only. Disposes of the pointer on destruction.
|
254 |
+
/// </summary>
|
255 |
+
template <typename T>
|
256 |
+
struct Base {
|
257 |
+
using contained_type = T;
|
258 |
+
|
259 |
+
constexpr Base() = default;
|
260 |
+
constexpr explicit Base(contained_type* p) noexcept : p_{p} {}
|
261 |
+
~Base() { OrtRelease(p_); }
|
262 |
+
|
263 |
+
Base(const Base&) = delete;
|
264 |
+
Base& operator=(const Base&) = delete;
|
265 |
+
|
266 |
+
Base(Base&& v) noexcept : p_{v.p_} { v.p_ = nullptr; }
|
267 |
+
Base& operator=(Base&& v) noexcept {
|
268 |
+
OrtRelease(p_);
|
269 |
+
p_ = v.release();
|
270 |
+
return *this;
|
271 |
+
}
|
272 |
+
|
273 |
+
constexpr operator contained_type*() const noexcept { return p_; }
|
274 |
+
|
275 |
+
/// \brief Relinquishes ownership of the contained C object pointer
|
276 |
+
/// The underlying object is not destroyed
|
277 |
+
contained_type* release() {
|
278 |
+
T* p = p_;
|
279 |
+
p_ = nullptr;
|
280 |
+
return p;
|
281 |
+
}
|
282 |
+
|
283 |
+
protected:
|
284 |
+
contained_type* p_{};
|
285 |
+
};
|
286 |
+
|
287 |
+
// Undefined. For const types use Base<Unowned<const T>>
|
288 |
+
template <typename T>
|
289 |
+
struct Base<const T>;
|
290 |
+
|
291 |
+
/// <summary>
|
292 |
+
/// Covers unowned pointers owned by either the ORT
|
293 |
+
/// or some other instance of CPP wrappers.
|
294 |
+
/// Used for ConstXXX and UnownedXXXX types that are copyable.
|
295 |
+
/// Also convenient to wrap raw OrtXX pointers .
|
296 |
+
/// </summary>
|
297 |
+
/// <typeparam name="T"></typeparam>
|
298 |
+
template <typename T>
|
299 |
+
struct Base<Unowned<T>> {
|
300 |
+
using contained_type = typename Unowned<T>::Type;
|
301 |
+
|
302 |
+
constexpr Base() = default;
|
303 |
+
constexpr explicit Base(contained_type* p) noexcept : p_{p} {}
|
304 |
+
|
305 |
+
~Base() = default;
|
306 |
+
|
307 |
+
Base(const Base&) = default;
|
308 |
+
Base& operator=(const Base&) = default;
|
309 |
+
|
310 |
+
Base(Base&& v) noexcept : p_{v.p_} { v.p_ = nullptr; }
|
311 |
+
Base& operator=(Base&& v) noexcept {
|
312 |
+
p_ = nullptr;
|
313 |
+
std::swap(p_, v.p_);
|
314 |
+
return *this;
|
315 |
+
}
|
316 |
+
|
317 |
+
constexpr operator contained_type*() const noexcept { return p_; }
|
318 |
+
|
319 |
+
protected:
|
320 |
+
contained_type* p_{};
|
321 |
+
};
|
322 |
+
|
323 |
+
// Light functor to release memory with OrtAllocator
|
324 |
+
struct AllocatedFree {
|
325 |
+
OrtAllocator* allocator_;
|
326 |
+
explicit AllocatedFree(OrtAllocator* allocator)
|
327 |
+
: allocator_(allocator) {}
|
328 |
+
void operator()(void* ptr) const {
|
329 |
+
if (ptr) allocator_->Free(allocator_, ptr);
|
330 |
+
}
|
331 |
+
};
|
332 |
+
|
333 |
+
} // namespace detail
|
334 |
+
|
335 |
+
struct AllocatorWithDefaultOptions;
|
336 |
+
struct Env;
|
337 |
+
struct TypeInfo;
|
338 |
+
struct Value;
|
339 |
+
struct ModelMetadata;
|
340 |
+
|
341 |
+
/** \brief unique_ptr typedef used to own strings allocated by OrtAllocators
|
342 |
+
* and release them at the end of the scope. The lifespan of the given allocator
|
343 |
+
* must eclipse the lifespan of AllocatedStringPtr instance
|
344 |
+
*/
|
345 |
+
using AllocatedStringPtr = std::unique_ptr<char, detail::AllocatedFree>;
|
346 |
+
|
347 |
+
/** \brief The Status that holds ownership of OrtStatus received from C API
|
348 |
+
* Use it to safely destroy OrtStatus* returned from the C API. Use appropriate
|
349 |
+
* constructors to construct an instance of a Status object from exceptions.
|
350 |
+
*/
|
351 |
+
struct Status : detail::Base<OrtStatus> {
|
352 |
+
explicit Status(std::nullptr_t) {} ///< Create an empty object, must be assigned a valid one to be used
|
353 |
+
explicit Status(OrtStatus* status); ///< Takes ownership of OrtStatus instance returned from the C API. Must be non-null
|
354 |
+
explicit Status(const Exception&); ///< Creates status instance out of exception
|
355 |
+
explicit Status(const std::exception&); ///< Creates status instance out of exception
|
356 |
+
std::string GetErrorMessage() const;
|
357 |
+
OrtErrorCode GetErrorCode() const;
|
358 |
+
};
|
359 |
+
|
360 |
+
/** \brief The ThreadingOptions
|
361 |
+
*
|
362 |
+
* The ThreadingOptions used for set global threadpools' options of The Env.
|
363 |
+
*/
|
364 |
+
struct ThreadingOptions : detail::Base<OrtThreadingOptions> {
|
365 |
+
/// \brief Wraps OrtApi::CreateThreadingOptions
|
366 |
+
ThreadingOptions();
|
367 |
+
|
368 |
+
/// \brief Wraps OrtApi::SetGlobalIntraOpNumThreads
|
369 |
+
ThreadingOptions& SetGlobalIntraOpNumThreads(int intra_op_num_threads);
|
370 |
+
|
371 |
+
/// \brief Wraps OrtApi::SetGlobalInterOpNumThreads
|
372 |
+
ThreadingOptions& SetGlobalInterOpNumThreads(int inter_op_num_threads);
|
373 |
+
|
374 |
+
/// \brief Wraps OrtApi::SetGlobalSpinControl
|
375 |
+
ThreadingOptions& SetGlobalSpinControl(int allow_spinning);
|
376 |
+
|
377 |
+
/// \brief Wraps OrtApi::SetGlobalDenormalAsZero
|
378 |
+
ThreadingOptions& SetGlobalDenormalAsZero();
|
379 |
+
|
380 |
+
/// \brief Wraps OrtApi::SetGlobalCustomCreateThreadFn
|
381 |
+
ThreadingOptions& SetGlobalCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn);
|
382 |
+
|
383 |
+
/// \brief Wraps OrtApi::SetGlobalCustomThreadCreationOptions
|
384 |
+
ThreadingOptions& SetGlobalCustomThreadCreationOptions(void* ort_custom_thread_creation_options);
|
385 |
+
|
386 |
+
/// \brief Wraps OrtApi::SetGlobalCustomJoinThreadFn
|
387 |
+
ThreadingOptions& SetGlobalCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn);
|
388 |
+
};
|
389 |
+
|
390 |
+
/** \brief The Env (Environment)
|
391 |
+
*
|
392 |
+
* The Env holds the logging state used by all other objects.
|
393 |
+
* <b>Note:</b> One Env must be created before using any other Onnxruntime functionality
|
394 |
+
*/
|
395 |
+
struct Env : detail::Base<OrtEnv> {
|
396 |
+
explicit Env(std::nullptr_t) {} ///< Create an empty Env object, must be assigned a valid one to be used
|
397 |
+
|
398 |
+
/// \brief Wraps OrtApi::CreateEnv
|
399 |
+
Env(OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = "");
|
400 |
+
|
401 |
+
/// \brief Wraps OrtApi::CreateEnvWithCustomLogger
|
402 |
+
Env(OrtLoggingLevel logging_level, const char* logid, OrtLoggingFunction logging_function, void* logger_param);
|
403 |
+
|
404 |
+
/// \brief Wraps OrtApi::CreateEnvWithGlobalThreadPools
|
405 |
+
Env(const OrtThreadingOptions* tp_options, OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = "");
|
406 |
+
|
407 |
+
/// \brief Wraps OrtApi::CreateEnvWithCustomLoggerAndGlobalThreadPools
|
408 |
+
Env(const OrtThreadingOptions* tp_options, OrtLoggingFunction logging_function, void* logger_param,
|
409 |
+
OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = "");
|
410 |
+
|
411 |
+
/// \brief C Interop Helper
|
412 |
+
explicit Env(OrtEnv* p) : Base<OrtEnv>{p} {}
|
413 |
+
|
414 |
+
Env& EnableTelemetryEvents(); ///< Wraps OrtApi::EnableTelemetryEvents
|
415 |
+
Env& DisableTelemetryEvents(); ///< Wraps OrtApi::DisableTelemetryEvents
|
416 |
+
|
417 |
+
Env& UpdateEnvWithCustomLogLevel(OrtLoggingLevel log_severity_level); ///< Wraps OrtApi::UpdateEnvWithCustomLogLevel
|
418 |
+
|
419 |
+
Env& CreateAndRegisterAllocator(const OrtMemoryInfo* mem_info, const OrtArenaCfg* arena_cfg); ///< Wraps OrtApi::CreateAndRegisterAllocator
|
420 |
+
};
|
421 |
+
|
422 |
+
/** \brief Custom Op Domain
|
423 |
+
*
|
424 |
+
*/
|
425 |
+
struct CustomOpDomain : detail::Base<OrtCustomOpDomain> {
|
426 |
+
explicit CustomOpDomain(std::nullptr_t) {} ///< Create an empty CustomOpDomain object, must be assigned a valid one to be used
|
427 |
+
|
428 |
+
/// \brief Wraps OrtApi::CreateCustomOpDomain
|
429 |
+
explicit CustomOpDomain(const char* domain);
|
430 |
+
|
431 |
+
// This does not take ownership of the op, simply registers it.
|
432 |
+
void Add(const OrtCustomOp* op); ///< Wraps CustomOpDomain_Add
|
433 |
+
};
|
434 |
+
|
435 |
+
/** \brief RunOptions
|
436 |
+
*
|
437 |
+
*/
|
438 |
+
struct RunOptions : detail::Base<OrtRunOptions> {
|
439 |
+
explicit RunOptions(std::nullptr_t) {} ///< Create an empty RunOptions object, must be assigned a valid one to be used
|
440 |
+
RunOptions(); ///< Wraps OrtApi::CreateRunOptions
|
441 |
+
|
442 |
+
RunOptions& SetRunLogVerbosityLevel(int); ///< Wraps OrtApi::RunOptionsSetRunLogVerbosityLevel
|
443 |
+
int GetRunLogVerbosityLevel() const; ///< Wraps OrtApi::RunOptionsGetRunLogVerbosityLevel
|
444 |
+
|
445 |
+
RunOptions& SetRunLogSeverityLevel(int); ///< Wraps OrtApi::RunOptionsSetRunLogSeverityLevel
|
446 |
+
int GetRunLogSeverityLevel() const; ///< Wraps OrtApi::RunOptionsGetRunLogSeverityLevel
|
447 |
+
|
448 |
+
RunOptions& SetRunTag(const char* run_tag); ///< wraps OrtApi::RunOptionsSetRunTag
|
449 |
+
const char* GetRunTag() const; ///< Wraps OrtApi::RunOptionsGetRunTag
|
450 |
+
|
451 |
+
RunOptions& AddConfigEntry(const char* config_key, const char* config_value); ///< Wraps OrtApi::AddRunConfigEntry
|
452 |
+
|
453 |
+
/** \brief Terminates all currently executing Session::Run calls that were made using this RunOptions instance
|
454 |
+
*
|
455 |
+
* If a currently executing session needs to be force terminated, this can be called from another thread to force it to fail with an error
|
456 |
+
* Wraps OrtApi::RunOptionsSetTerminate
|
457 |
+
*/
|
458 |
+
RunOptions& SetTerminate();
|
459 |
+
|
460 |
+
/** \brief Clears the terminate flag so this RunOptions instance can be used in a new Session::Run call without it instantly terminating
|
461 |
+
*
|
462 |
+
* Wraps OrtApi::RunOptionsUnsetTerminate
|
463 |
+
*/
|
464 |
+
RunOptions& UnsetTerminate();
|
465 |
+
};
|
466 |
+
|
467 |
+
|
468 |
+
namespace detail {
|
469 |
+
// Utility function that returns a SessionOption config entry key for a specific custom operator.
|
470 |
+
// Ex: custom_op.[custom_op_name].[config]
|
471 |
+
std::string MakeCustomOpConfigEntryKey(const char* custom_op_name, const char* config);
|
472 |
+
} // namespace detail
|
473 |
+
|
474 |
+
/// <summary>
|
475 |
+
/// Class that represents session configuration entries for one or more custom operators.
|
476 |
+
///
|
477 |
+
/// Example:
|
478 |
+
/// Ort::CustomOpConfigs op_configs;
|
479 |
+
/// op_configs.AddConfig("my_custom_op", "device_type", "CPU");
|
480 |
+
///
|
481 |
+
/// Passed to Ort::SessionOptions::RegisterCustomOpsLibrary.
|
482 |
+
/// </summary>
|
483 |
+
struct CustomOpConfigs {
|
484 |
+
CustomOpConfigs() = default;
|
485 |
+
~CustomOpConfigs() = default;
|
486 |
+
CustomOpConfigs(const CustomOpConfigs&) = default;
|
487 |
+
CustomOpConfigs& operator=(const CustomOpConfigs&) = default;
|
488 |
+
CustomOpConfigs(CustomOpConfigs&& o) = default;
|
489 |
+
CustomOpConfigs& operator=(CustomOpConfigs&& o) = default;
|
490 |
+
|
491 |
+
/** \brief Adds a session configuration entry/value for a specific custom operator.
|
492 |
+
*
|
493 |
+
* \param custom_op_name The name of the custom operator for which to add a configuration entry.
|
494 |
+
* Must match the name returned by the CustomOp's GetName() method.
|
495 |
+
* \param config_key The name of the configuration entry.
|
496 |
+
* \param config_value The value of the configuration entry.
|
497 |
+
* \return A reference to this object to enable call chaining.
|
498 |
+
*/
|
499 |
+
CustomOpConfigs& AddConfig(const char* custom_op_name, const char* config_key, const char* config_value);
|
500 |
+
|
501 |
+
/** \brief Returns a flattened map of custom operator configuration entries and their values.
|
502 |
+
*
|
503 |
+
* The keys has been flattened to include both the custom operator name and the configuration entry key name.
|
504 |
+
* For example, a prior call to AddConfig("my_op", "key", "value") corresponds to the flattened key/value pair
|
505 |
+
* {"my_op.key", "value"}.
|
506 |
+
*
|
507 |
+
* \return An unordered map of flattened configurations.
|
508 |
+
*/
|
509 |
+
const std::unordered_map<std::string, std::string>& GetFlattenedConfigs() const;
|
510 |
+
|
511 |
+
private:
|
512 |
+
std::unordered_map<std::string, std::string> flat_configs_;
|
513 |
+
};
|
514 |
+
|
515 |
+
/** \brief Options object used when creating a new Session object
|
516 |
+
*
|
517 |
+
* Wraps ::OrtSessionOptions object and methods
|
518 |
+
*/
|
519 |
+
|
520 |
+
struct SessionOptions;
|
521 |
+
|
522 |
+
namespace detail {
|
523 |
+
// we separate const-only methods because passing const ptr to non-const methods
|
524 |
+
// is only discovered when inline methods are compiled which is counter-intuitive
|
525 |
+
template <typename T>
|
526 |
+
struct ConstSessionOptionsImpl : Base<T> {
|
527 |
+
using B = Base<T>;
|
528 |
+
using B::B;
|
529 |
+
|
530 |
+
SessionOptions Clone() const; ///< Creates and returns a copy of this SessionOptions object. Wraps OrtApi::CloneSessionOptions
|
531 |
+
|
532 |
+
std::string GetConfigEntry(const char* config_key) const; ///< Wraps OrtApi::GetSessionConfigEntry
|
533 |
+
bool HasConfigEntry(const char* config_key) const; ///< Wraps OrtApi::HasSessionConfigEntry
|
534 |
+
std::string GetConfigEntryOrDefault(const char* config_key, const std::string& def);
|
535 |
+
};
|
536 |
+
|
537 |
+
template <typename T>
|
538 |
+
struct SessionOptionsImpl : ConstSessionOptionsImpl<T> {
|
539 |
+
using B = ConstSessionOptionsImpl<T>;
|
540 |
+
using B::B;
|
541 |
+
|
542 |
+
SessionOptionsImpl& SetIntraOpNumThreads(int intra_op_num_threads); ///< Wraps OrtApi::SetIntraOpNumThreads
|
543 |
+
SessionOptionsImpl& SetInterOpNumThreads(int inter_op_num_threads); ///< Wraps OrtApi::SetInterOpNumThreads
|
544 |
+
SessionOptionsImpl& SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level); ///< Wraps OrtApi::SetSessionGraphOptimizationLevel
|
545 |
+
|
546 |
+
SessionOptionsImpl& EnableCpuMemArena(); ///< Wraps OrtApi::EnableCpuMemArena
|
547 |
+
SessionOptionsImpl& DisableCpuMemArena(); ///< Wraps OrtApi::DisableCpuMemArena
|
548 |
+
|
549 |
+
SessionOptionsImpl& SetOptimizedModelFilePath(const ORTCHAR_T* optimized_model_file); ///< Wraps OrtApi::SetOptimizedModelFilePath
|
550 |
+
|
551 |
+
SessionOptionsImpl& EnableProfiling(const ORTCHAR_T* profile_file_prefix); ///< Wraps OrtApi::EnableProfiling
|
552 |
+
SessionOptionsImpl& DisableProfiling(); ///< Wraps OrtApi::DisableProfiling
|
553 |
+
|
554 |
+
SessionOptionsImpl& EnableOrtCustomOps(); ///< Wraps OrtApi::EnableOrtCustomOps
|
555 |
+
|
556 |
+
SessionOptionsImpl& EnableMemPattern(); ///< Wraps OrtApi::EnableMemPattern
|
557 |
+
SessionOptionsImpl& DisableMemPattern(); ///< Wraps OrtApi::DisableMemPattern
|
558 |
+
|
559 |
+
SessionOptionsImpl& SetExecutionMode(ExecutionMode execution_mode); ///< Wraps OrtApi::SetSessionExecutionMode
|
560 |
+
|
561 |
+
SessionOptionsImpl& SetLogId(const char* logid); ///< Wraps OrtApi::SetSessionLogId
|
562 |
+
SessionOptionsImpl& SetLogSeverityLevel(int level); ///< Wraps OrtApi::SetSessionLogSeverityLevel
|
563 |
+
|
564 |
+
SessionOptionsImpl& Add(OrtCustomOpDomain* custom_op_domain); ///< Wraps OrtApi::AddCustomOpDomain
|
565 |
+
|
566 |
+
SessionOptionsImpl& DisablePerSessionThreads(); ///< Wraps OrtApi::DisablePerSessionThreads
|
567 |
+
|
568 |
+
SessionOptionsImpl& AddConfigEntry(const char* config_key, const char* config_value); ///< Wraps OrtApi::AddSessionConfigEntry
|
569 |
+
|
570 |
+
SessionOptionsImpl& AddInitializer(const char* name, const OrtValue* ort_val); ///< Wraps OrtApi::AddInitializer
|
571 |
+
SessionOptionsImpl& AddExternalInitializers(const std::vector<std::string>& names, const std::vector<Value>& ort_values); ///< Wraps OrtApi::AddExternalInitializers
|
572 |
+
|
573 |
+
SessionOptionsImpl& AppendExecutionProvider_CUDA(const OrtCUDAProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_CUDA
|
574 |
+
SessionOptionsImpl& AppendExecutionProvider_CUDA_V2(const OrtCUDAProviderOptionsV2& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_CUDA_V2
|
575 |
+
SessionOptionsImpl& AppendExecutionProvider_ROCM(const OrtROCMProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_ROCM
|
576 |
+
SessionOptionsImpl& AppendExecutionProvider_OpenVINO(const OrtOpenVINOProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_OpenVINO
|
577 |
+
SessionOptionsImpl& AppendExecutionProvider_TensorRT(const OrtTensorRTProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_TensorRT
|
578 |
+
SessionOptionsImpl& AppendExecutionProvider_TensorRT_V2(const OrtTensorRTProviderOptionsV2& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_TensorRT
|
579 |
+
SessionOptionsImpl& AppendExecutionProvider_MIGraphX(const OrtMIGraphXProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_MIGraphX
|
580 |
+
///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_CANN
|
581 |
+
SessionOptionsImpl& AppendExecutionProvider_CANN(const OrtCANNProviderOptions& provider_options);
|
582 |
+
///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_Dnnl
|
583 |
+
SessionOptionsImpl& AppendExecutionProvider_Dnnl(const OrtDnnlProviderOptions& provider_options);
|
584 |
+
/// Wraps OrtApi::SessionOptionsAppendExecutionProvider. Currently supports SNPE and XNNPACK.
|
585 |
+
SessionOptionsImpl& AppendExecutionProvider(const std::string& provider_name,
|
586 |
+
const std::unordered_map<std::string, std::string>& provider_options = {});
|
587 |
+
|
588 |
+
SessionOptionsImpl& SetCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn); ///< Wraps OrtApi::SessionOptionsSetCustomCreateThreadFn
|
589 |
+
SessionOptionsImpl& SetCustomThreadCreationOptions(void* ort_custom_thread_creation_options); ///< Wraps OrtApi::SessionOptionsSetCustomThreadCreationOptions
|
590 |
+
SessionOptionsImpl& SetCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn); ///< Wraps OrtApi::SessionOptionsSetCustomJoinThreadFn
|
591 |
+
|
592 |
+
///< Registers the custom operator from the specified shared library via OrtApi::RegisterCustomOpsLibrary_V2.
|
593 |
+
///< The custom operator configurations are optional. If provided, custom operator configs are set via
|
594 |
+
///< OrtApi::AddSessionConfigEntry.
|
595 |
+
SessionOptionsImpl& RegisterCustomOpsLibrary(const ORTCHAR_T* library_name, const CustomOpConfigs& custom_op_configs = {});
|
596 |
+
|
597 |
+
SessionOptionsImpl& RegisterCustomOpsUsingFunction(const char* function_name); ///< Wraps OrtApi::RegisterCustomOpsUsingFunction
|
598 |
+
};
|
599 |
+
} // namespace detail
|
600 |
+
|
601 |
+
using UnownedSessionOptions = detail::SessionOptionsImpl<detail::Unowned<OrtSessionOptions>>;
|
602 |
+
using ConstSessionOptions = detail::ConstSessionOptionsImpl<detail::Unowned<const OrtSessionOptions>>;
|
603 |
+
|
604 |
+
/** \brief Wrapper around ::OrtSessionOptions
|
605 |
+
*
|
606 |
+
*/
|
607 |
+
struct SessionOptions : detail::SessionOptionsImpl<OrtSessionOptions> {
|
608 |
+
explicit SessionOptions(std::nullptr_t) {} ///< Create an empty SessionOptions object, must be assigned a valid one to be used
|
609 |
+
SessionOptions(); ///< Wraps OrtApi::CreateSessionOptions
|
610 |
+
explicit SessionOptions(OrtSessionOptions* p) : SessionOptionsImpl<OrtSessionOptions>{p} {} ///< Used for interop with the C API
|
611 |
+
UnownedSessionOptions GetUnowned() const { return UnownedSessionOptions{this->p_}; }
|
612 |
+
ConstSessionOptions GetConst() const { return ConstSessionOptions{this->p_}; }
|
613 |
+
};
|
614 |
+
|
615 |
+
/** \brief Wrapper around ::OrtModelMetadata
|
616 |
+
*
|
617 |
+
*/
|
618 |
+
struct ModelMetadata : detail::Base<OrtModelMetadata> {
|
619 |
+
explicit ModelMetadata(std::nullptr_t) {} ///< Create an empty ModelMetadata object, must be assigned a valid one to be used
|
620 |
+
explicit ModelMetadata(OrtModelMetadata* p) : Base<OrtModelMetadata>{p} {} ///< Used for interop with the C API
|
621 |
+
|
622 |
+
/** \brief Returns a copy of the producer name.
|
623 |
+
*
|
624 |
+
* \param allocator to allocate memory for the copy of the name returned
|
625 |
+
* \return a instance of smart pointer that would deallocate the buffer when out of scope.
|
626 |
+
* The OrtAllocator instances must be valid at the point of memory release.
|
627 |
+
*/
|
628 |
+
AllocatedStringPtr GetProducerNameAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetProducerName
|
629 |
+
|
630 |
+
/** \brief Returns a copy of the graph name.
|
631 |
+
*
|
632 |
+
* \param allocator to allocate memory for the copy of the name returned
|
633 |
+
* \return a instance of smart pointer that would deallocate the buffer when out of scope.
|
634 |
+
* The OrtAllocator instances must be valid at the point of memory release.
|
635 |
+
*/
|
636 |
+
AllocatedStringPtr GetGraphNameAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetGraphName
|
637 |
+
|
638 |
+
/** \brief Returns a copy of the domain name.
|
639 |
+
*
|
640 |
+
* \param allocator to allocate memory for the copy of the name returned
|
641 |
+
* \return a instance of smart pointer that would deallocate the buffer when out of scope.
|
642 |
+
* The OrtAllocator instances must be valid at the point of memory release.
|
643 |
+
*/
|
644 |
+
AllocatedStringPtr GetDomainAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetDomain
|
645 |
+
|
646 |
+
/** \brief Returns a copy of the description.
|
647 |
+
*
|
648 |
+
* \param allocator to allocate memory for the copy of the string returned
|
649 |
+
* \return a instance of smart pointer that would deallocate the buffer when out of scope.
|
650 |
+
* The OrtAllocator instances must be valid at the point of memory release.
|
651 |
+
*/
|
652 |
+
AllocatedStringPtr GetDescriptionAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetDescription
|
653 |
+
|
654 |
+
/** \brief Returns a copy of the graph description.
|
655 |
+
*
|
656 |
+
* \param allocator to allocate memory for the copy of the string returned
|
657 |
+
* \return a instance of smart pointer that would deallocate the buffer when out of scope.
|
658 |
+
* The OrtAllocator instances must be valid at the point of memory release.
|
659 |
+
*/
|
660 |
+
AllocatedStringPtr GetGraphDescriptionAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetGraphDescription
|
661 |
+
|
662 |
+
/** \brief Returns a vector of copies of the custom metadata keys.
|
663 |
+
*
|
664 |
+
* \param allocator to allocate memory for the copy of the string returned
|
665 |
+
* \return a instance std::vector of smart pointers that would deallocate the buffers when out of scope.
|
666 |
+
* The OrtAllocator instance must be valid at the point of memory release.
|
667 |
+
*/
|
668 |
+
std::vector<AllocatedStringPtr> GetCustomMetadataMapKeysAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetCustomMetadataMapKeys
|
669 |
+
|
670 |
+
/** \brief Looks up a value by a key in the Custom Metadata map
|
671 |
+
*
|
672 |
+
* \param key zero terminated string key to lookup
|
673 |
+
* \param allocator to allocate memory for the copy of the string returned
|
674 |
+
* \return a instance of smart pointer that would deallocate the buffer when out of scope.
|
675 |
+
* maybe nullptr if key is not found.
|
676 |
+
*
|
677 |
+
* The OrtAllocator instances must be valid at the point of memory release.
|
678 |
+
*/
|
679 |
+
AllocatedStringPtr LookupCustomMetadataMapAllocated(const char* key, OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataLookupCustomMetadataMap
|
680 |
+
|
681 |
+
int64_t GetVersion() const; ///< Wraps OrtApi::ModelMetadataGetVersion
|
682 |
+
};
|
683 |
+
|
684 |
+
struct IoBinding;
|
685 |
+
|
686 |
+
namespace detail {
|
687 |
+
|
688 |
+
// we separate const-only methods because passing const ptr to non-const methods
|
689 |
+
// is only discovered when inline methods are compiled which is counter-intuitive
|
690 |
+
template <typename T>
|
691 |
+
struct ConstSessionImpl : Base<T> {
|
692 |
+
using B = Base<T>;
|
693 |
+
using B::B;
|
694 |
+
|
695 |
+
size_t GetInputCount() const; ///< Returns the number of model inputs
|
696 |
+
size_t GetOutputCount() const; ///< Returns the number of model outputs
|
697 |
+
size_t GetOverridableInitializerCount() const; ///< Returns the number of inputs that have defaults that can be overridden
|
698 |
+
|
699 |
+
/** \brief Returns a copy of input name at the specified index.
|
700 |
+
*
|
701 |
+
* \param index must less than the value returned by GetInputCount()
|
702 |
+
* \param allocator to allocate memory for the copy of the name returned
|
703 |
+
* \return a instance of smart pointer that would deallocate the buffer when out of scope.
|
704 |
+
* The OrtAllocator instances must be valid at the point of memory release.
|
705 |
+
*/
|
706 |
+
AllocatedStringPtr GetInputNameAllocated(size_t index, OrtAllocator* allocator) const;
|
707 |
+
|
708 |
+
/** \brief Returns a copy of output name at then specified index.
|
709 |
+
*
|
710 |
+
* \param index must less than the value returned by GetOutputCount()
|
711 |
+
* \param allocator to allocate memory for the copy of the name returned
|
712 |
+
* \return a instance of smart pointer that would deallocate the buffer when out of scope.
|
713 |
+
* The OrtAllocator instances must be valid at the point of memory release.
|
714 |
+
*/
|
715 |
+
AllocatedStringPtr GetOutputNameAllocated(size_t index, OrtAllocator* allocator) const;
|
716 |
+
|
717 |
+
/** \brief Returns a copy of the overridable initializer name at then specified index.
|
718 |
+
*
|
719 |
+
* \param index must less than the value returned by GetOverridableInitializerCount()
|
720 |
+
* \param allocator to allocate memory for the copy of the name returned
|
721 |
+
* \return a instance of smart pointer that would deallocate the buffer when out of scope.
|
722 |
+
* The OrtAllocator instances must be valid at the point of memory release.
|
723 |
+
*/
|
724 |
+
AllocatedStringPtr GetOverridableInitializerNameAllocated(size_t index, OrtAllocator* allocator) const; ///< Wraps OrtApi::SessionGetOverridableInitializerName
|
725 |
+
|
726 |
+
uint64_t GetProfilingStartTimeNs() const; ///< Wraps OrtApi::SessionGetProfilingStartTimeNs
|
727 |
+
ModelMetadata GetModelMetadata() const; ///< Wraps OrtApi::SessionGetModelMetadata
|
728 |
+
|
729 |
+
TypeInfo GetInputTypeInfo(size_t index) const; ///< Wraps OrtApi::SessionGetInputTypeInfo
|
730 |
+
TypeInfo GetOutputTypeInfo(size_t index) const; ///< Wraps OrtApi::SessionGetOutputTypeInfo
|
731 |
+
TypeInfo GetOverridableInitializerTypeInfo(size_t index) const; ///< Wraps OrtApi::SessionGetOverridableInitializerTypeInfo
|
732 |
+
};
|
733 |
+
|
734 |
+
template <typename T>
|
735 |
+
struct SessionImpl : ConstSessionImpl<T> {
|
736 |
+
using B = ConstSessionImpl<T>;
|
737 |
+
using B::B;
|
738 |
+
|
739 |
+
/** \brief Run the model returning results in an Ort allocated vector.
|
740 |
+
*
|
741 |
+
* Wraps OrtApi::Run
|
742 |
+
*
|
743 |
+
* The caller provides a list of inputs and a list of the desired outputs to return.
|
744 |
+
*
|
745 |
+
* See the output logs for more information on warnings/errors that occur while processing the model.
|
746 |
+
* Common errors are.. (TODO)
|
747 |
+
*
|
748 |
+
* \param[in] run_options
|
749 |
+
* \param[in] input_names Array of null terminated strings of length input_count that is the list of input names
|
750 |
+
* \param[in] input_values Array of Value objects of length input_count that is the list of input values
|
751 |
+
* \param[in] input_count Number of inputs (the size of the input_names & input_values arrays)
|
752 |
+
* \param[in] output_names Array of C style strings of length output_count that is the list of output names
|
753 |
+
* \param[in] output_count Number of outputs (the size of the output_names array)
|
754 |
+
* \return A std::vector of Value objects that directly maps to the output_names array (eg. output_name[0] is the first entry of the returned vector)
|
755 |
+
*/
|
756 |
+
std::vector<Value> Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
|
757 |
+
const char* const* output_names, size_t output_count);
|
758 |
+
|
759 |
+
/** \brief Run the model returning results in user provided outputs
|
760 |
+
* Same as Run(const RunOptions&, const char* const*, const Value*, size_t,const char* const*, size_t)
|
761 |
+
*/
|
762 |
+
void Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
|
763 |
+
const char* const* output_names, Value* output_values, size_t output_count);
|
764 |
+
|
765 |
+
void Run(const RunOptions& run_options, const IoBinding&); ///< Wraps OrtApi::RunWithBinding
|
766 |
+
|
767 |
+
/** \brief End profiling and return a copy of the profiling file name.
|
768 |
+
*
|
769 |
+
* \param allocator to allocate memory for the copy of the string returned
|
770 |
+
* \return a instance of smart pointer that would deallocate the buffer when out of scope.
|
771 |
+
* The OrtAllocator instances must be valid at the point of memory release.
|
772 |
+
*/
|
773 |
+
AllocatedStringPtr EndProfilingAllocated(OrtAllocator* allocator); ///< Wraps OrtApi::SessionEndProfiling
|
774 |
+
};
|
775 |
+
|
776 |
+
} // namespace detail
|
777 |
+
|
778 |
+
using ConstSession = detail::ConstSessionImpl<detail::Unowned<const OrtSession>>;
|
779 |
+
using UnownedSession = detail::SessionImpl<detail::Unowned<OrtSession>>;
|
780 |
+
|
781 |
+
/** \brief Wrapper around ::OrtSession
|
782 |
+
*
|
783 |
+
*/
|
784 |
+
struct Session : detail::SessionImpl<OrtSession> {
|
785 |
+
explicit Session(std::nullptr_t) {} ///< Create an empty Session object, must be assigned a valid one to be used
|
786 |
+
Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options); ///< Wraps OrtApi::CreateSession
|
787 |
+
Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options,
|
788 |
+
OrtPrepackedWeightsContainer* prepacked_weights_container); ///< Wraps OrtApi::CreateSessionWithPrepackedWeightsContainer
|
789 |
+
Session(const Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options); ///< Wraps OrtApi::CreateSessionFromArray
|
790 |
+
Session(const Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options,
|
791 |
+
OrtPrepackedWeightsContainer* prepacked_weights_container); ///< Wraps OrtApi::CreateSessionFromArrayWithPrepackedWeightsContainer
|
792 |
+
|
793 |
+
ConstSession GetConst() const { return ConstSession{this->p_}; }
|
794 |
+
UnownedSession GetUnowned() const { return UnownedSession{this->p_}; }
|
795 |
+
};
|
796 |
+
|
797 |
+
namespace detail {
|
798 |
+
template <typename T>
|
799 |
+
struct MemoryInfoImpl : Base<T> {
|
800 |
+
using B = Base<T>;
|
801 |
+
using B::B;
|
802 |
+
|
803 |
+
std::string GetAllocatorName() const;
|
804 |
+
OrtAllocatorType GetAllocatorType() const;
|
805 |
+
int GetDeviceId() const;
|
806 |
+
OrtMemoryInfoDeviceType GetDeviceType() const;
|
807 |
+
OrtMemType GetMemoryType() const;
|
808 |
+
|
809 |
+
template <typename U>
|
810 |
+
bool operator==(const MemoryInfoImpl<U>& o) const;
|
811 |
+
};
|
812 |
+
} // namespace detail
|
813 |
+
|
814 |
+
// Const object holder that does not own the underlying object
|
815 |
+
using ConstMemoryInfo = detail::MemoryInfoImpl<detail::Unowned<const OrtMemoryInfo>>;
|
816 |
+
|
817 |
+
/** \brief Wrapper around ::OrtMemoryInfo
|
818 |
+
*
|
819 |
+
*/
|
820 |
+
struct MemoryInfo : detail::MemoryInfoImpl<OrtMemoryInfo> {
|
821 |
+
static MemoryInfo CreateCpu(OrtAllocatorType type, OrtMemType mem_type1);
|
822 |
+
explicit MemoryInfo(std::nullptr_t) {} ///< No instance is created
|
823 |
+
explicit MemoryInfo(OrtMemoryInfo* p) : MemoryInfoImpl<OrtMemoryInfo>{p} {} ///< Take ownership of a pointer created by C Api
|
824 |
+
MemoryInfo(const char* name, OrtAllocatorType type, int id, OrtMemType mem_type);
|
825 |
+
ConstMemoryInfo GetConst() const { return ConstMemoryInfo{this->p_}; }
|
826 |
+
};
|
827 |
+
|
828 |
+
namespace detail {
|
829 |
+
template <typename T>
|
830 |
+
struct TensorTypeAndShapeInfoImpl : Base<T> {
|
831 |
+
using B = Base<T>;
|
832 |
+
using B::B;
|
833 |
+
|
834 |
+
ONNXTensorElementDataType GetElementType() const; ///< Wraps OrtApi::GetTensorElementType
|
835 |
+
size_t GetElementCount() const; ///< Wraps OrtApi::GetTensorShapeElementCount
|
836 |
+
|
837 |
+
size_t GetDimensionsCount() const; ///< Wraps OrtApi::GetDimensionsCount
|
838 |
+
|
839 |
+
/** \deprecated use GetShape() returning std::vector
|
840 |
+
* [[deprecated]]
|
841 |
+
* This interface is unsafe to use
|
842 |
+
*/
|
843 |
+
[[deprecated("use GetShape()")]] void GetDimensions(int64_t* values, size_t values_count) const; ///< Wraps OrtApi::GetDimensions
|
844 |
+
|
845 |
+
void GetSymbolicDimensions(const char** values, size_t values_count) const; ///< Wraps OrtApi::GetSymbolicDimensions
|
846 |
+
|
847 |
+
std::vector<int64_t> GetShape() const; ///< Uses GetDimensionsCount & GetDimensions to return a std::vector of the shape
|
848 |
+
};
|
849 |
+
|
850 |
+
} // namespace detail
|
851 |
+
|
852 |
+
using ConstTensorTypeAndShapeInfo = detail::TensorTypeAndShapeInfoImpl<detail::Unowned<const OrtTensorTypeAndShapeInfo>>;
|
853 |
+
|
854 |
+
/** \brief Wrapper around ::OrtTensorTypeAndShapeInfo
|
855 |
+
*
|
856 |
+
*/
|
857 |
+
struct TensorTypeAndShapeInfo : detail::TensorTypeAndShapeInfoImpl<OrtTensorTypeAndShapeInfo> {
|
858 |
+
explicit TensorTypeAndShapeInfo(std::nullptr_t) {} ///< Create an empty TensorTypeAndShapeInfo object, must be assigned a valid one to be used
|
859 |
+
explicit TensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* p) : TensorTypeAndShapeInfoImpl{p} {} ///< Used for interop with the C API
|
860 |
+
ConstTensorTypeAndShapeInfo GetConst() const { return ConstTensorTypeAndShapeInfo{this->p_}; }
|
861 |
+
};
|
862 |
+
|
863 |
+
namespace detail {
|
864 |
+
template <typename T>
|
865 |
+
struct SequenceTypeInfoImpl : Base<T> {
|
866 |
+
using B = Base<T>;
|
867 |
+
using B::B;
|
868 |
+
TypeInfo GetSequenceElementType() const; ///< Wraps OrtApi::GetSequenceElementType
|
869 |
+
};
|
870 |
+
|
871 |
+
} // namespace detail
|
872 |
+
|
873 |
+
using ConstSequenceTypeInfo = detail::SequenceTypeInfoImpl<detail::Unowned<const OrtSequenceTypeInfo>>;
|
874 |
+
|
875 |
+
/** \brief Wrapper around ::OrtSequenceTypeInfo
|
876 |
+
*
|
877 |
+
*/
|
878 |
+
struct SequenceTypeInfo : detail::SequenceTypeInfoImpl<OrtSequenceTypeInfo> {
|
879 |
+
explicit SequenceTypeInfo(std::nullptr_t) {} ///< Create an empty SequenceTypeInfo object, must be assigned a valid one to be used
|
880 |
+
explicit SequenceTypeInfo(OrtSequenceTypeInfo* p) : SequenceTypeInfoImpl<OrtSequenceTypeInfo>{p} {} ///< Used for interop with the C API
|
881 |
+
ConstSequenceTypeInfo GetConst() const { return ConstSequenceTypeInfo{this->p_}; }
|
882 |
+
};
|
883 |
+
|
884 |
+
namespace detail {
|
885 |
+
template <typename T>
|
886 |
+
struct MapTypeInfoImpl : detail::Base<T> {
|
887 |
+
using B = Base<T>;
|
888 |
+
using B::B;
|
889 |
+
ONNXTensorElementDataType GetMapKeyType() const; ///< Wraps OrtApi::GetMapKeyType
|
890 |
+
TypeInfo GetMapValueType() const; ///< Wraps OrtApi::GetMapValueType
|
891 |
+
};
|
892 |
+
|
893 |
+
} // namespace detail
|
894 |
+
|
895 |
+
using ConstMapTypeInfo = detail::MapTypeInfoImpl<detail::Unowned<const OrtMapTypeInfo>>;
|
896 |
+
|
897 |
+
/** \brief Wrapper around ::OrtMapTypeInfo
|
898 |
+
*
|
899 |
+
*/
|
900 |
+
struct MapTypeInfo : detail::MapTypeInfoImpl<OrtMapTypeInfo> {
|
901 |
+
explicit MapTypeInfo(std::nullptr_t) {} ///< Create an empty MapTypeInfo object, must be assigned a valid one to be used
|
902 |
+
explicit MapTypeInfo(OrtMapTypeInfo* p) : MapTypeInfoImpl<OrtMapTypeInfo>{p} {} ///< Used for interop with the C API
|
903 |
+
ConstMapTypeInfo GetConst() const { return ConstMapTypeInfo{this->p_}; }
|
904 |
+
};
|
905 |
+
|
906 |
+
namespace detail {
|
907 |
+
template <typename T>
|
908 |
+
struct TypeInfoImpl : detail::Base<T> {
|
909 |
+
using B = Base<T>;
|
910 |
+
using B::B;
|
911 |
+
|
912 |
+
ConstTensorTypeAndShapeInfo GetTensorTypeAndShapeInfo() const; ///< Wraps OrtApi::CastTypeInfoToTensorInfo
|
913 |
+
ConstSequenceTypeInfo GetSequenceTypeInfo() const; ///< Wraps OrtApi::CastTypeInfoToSequenceTypeInfo
|
914 |
+
ConstMapTypeInfo GetMapTypeInfo() const; ///< Wraps OrtApi::CastTypeInfoToMapTypeInfo
|
915 |
+
|
916 |
+
ONNXType GetONNXType() const;
|
917 |
+
};
|
918 |
+
} // namespace detail
|
919 |
+
|
920 |
+
/// <summary>
|
921 |
+
/// Contains a constant, unowned OrtTypeInfo that can be copied and passed around by value.
|
922 |
+
/// Provides access to const OrtTypeInfo APIs.
|
923 |
+
/// </summary>
|
924 |
+
using ConstTypeInfo = detail::TypeInfoImpl<detail::Unowned<const OrtTypeInfo>>;
|
925 |
+
|
926 |
+
/// <summary>
|
927 |
+
/// Type information that may contain either TensorTypeAndShapeInfo or
|
928 |
+
/// the information about contained sequence or map depending on the ONNXType.
|
929 |
+
/// </summary>
|
930 |
+
struct TypeInfo : detail::TypeInfoImpl<OrtTypeInfo> {
|
931 |
+
explicit TypeInfo(std::nullptr_t) {} ///< Create an empty TypeInfo object, must be assigned a valid one to be used
|
932 |
+
explicit TypeInfo(OrtTypeInfo* p) : TypeInfoImpl<OrtTypeInfo>{p} {} ///< C API Interop
|
933 |
+
|
934 |
+
ConstTypeInfo GetConst() const { return ConstTypeInfo{this->p_}; }
|
935 |
+
};
|
936 |
+
|
937 |
+
namespace detail {
|
938 |
+
// This structure is used to feed sparse tensor values
|
939 |
+
// information for use with FillSparseTensor<Format>() API
|
940 |
+
// if the data type for the sparse tensor values is numeric
|
941 |
+
// use data.p_data, otherwise, use data.str pointer to feed
|
942 |
+
// values. data.str is an array of const char* that are zero terminated.
|
943 |
+
// number of strings in the array must match shape size.
|
944 |
+
// For fully sparse tensors use shape {0} and set p_data/str
|
945 |
+
// to nullptr.
|
946 |
+
struct OrtSparseValuesParam {
|
947 |
+
const int64_t* values_shape;
|
948 |
+
size_t values_shape_len;
|
949 |
+
union {
|
950 |
+
const void* p_data;
|
951 |
+
const char** str;
|
952 |
+
} data;
|
953 |
+
};
|
954 |
+
|
955 |
+
// Provides a way to pass shape in a single
|
956 |
+
// argument
|
957 |
+
struct Shape {
|
958 |
+
const int64_t* shape;
|
959 |
+
size_t shape_len;
|
960 |
+
};
|
961 |
+
|
962 |
+
template <typename T>
|
963 |
+
struct ConstValueImpl : Base<T> {
|
964 |
+
using B = Base<T>;
|
965 |
+
using B::B;
|
966 |
+
|
967 |
+
/// <summary>
|
968 |
+
/// Obtains a pointer to a user defined data for experimental purposes
|
969 |
+
/// </summary>
|
970 |
+
template <typename R>
|
971 |
+
void GetOpaqueData(const char* domain, const char* type_name, R&) const; ///< Wraps OrtApi::GetOpaqueValue
|
972 |
+
|
973 |
+
bool IsTensor() const; ///< Returns true if Value is a tensor, false for other types like map/sequence/etc
|
974 |
+
bool HasValue() const; /// < Return true if OrtValue contains data and returns false if the OrtValue is a None
|
975 |
+
|
976 |
+
size_t GetCount() const; // If a non tensor, returns 2 for map and N for sequence, where N is the number of elements
|
977 |
+
Value GetValue(int index, OrtAllocator* allocator) const;
|
978 |
+
|
979 |
+
/// <summary>
|
980 |
+
/// This API returns a full length of string data contained within either a tensor or a sparse Tensor.
|
981 |
+
/// For sparse tensor it returns a full length of stored non-empty strings (values). The API is useful
|
982 |
+
/// for allocating necessary memory and calling GetStringTensorContent().
|
983 |
+
/// </summary>
|
984 |
+
/// <returns>total length of UTF-8 encoded bytes contained. No zero terminators counted.</returns>
|
985 |
+
size_t GetStringTensorDataLength() const;
|
986 |
+
|
987 |
+
/// <summary>
|
988 |
+
/// The API copies all of the UTF-8 encoded string data contained within a tensor or a sparse tensor
|
989 |
+
/// into a supplied buffer. Use GetStringTensorDataLength() to find out the length of the buffer to allocate.
|
990 |
+
/// The user must also allocate offsets buffer with the number of entries equal to that of the contained
|
991 |
+
/// strings.
|
992 |
+
///
|
993 |
+
/// Strings are always assumed to be on CPU, no X-device copy.
|
994 |
+
/// </summary>
|
995 |
+
/// <param name="buffer">user allocated buffer</param>
|
996 |
+
/// <param name="buffer_length">length in bytes of the allocated buffer</param>
|
997 |
+
/// <param name="offsets">a pointer to the offsets user allocated buffer</param>
|
998 |
+
/// <param name="offsets_count">count of offsets, must be equal to the number of strings contained.
|
999 |
+
/// that can be obtained from the shape of the tensor or from GetSparseTensorValuesTypeAndShapeInfo()
|
1000 |
+
/// for sparse tensors</param>
|
1001 |
+
void GetStringTensorContent(void* buffer, size_t buffer_length, size_t* offsets, size_t offsets_count) const;
|
1002 |
+
|
1003 |
+
/// <summary>
|
1004 |
+
/// Returns a const typed pointer to the tensor contained data.
|
1005 |
+
/// No type checking is performed, the caller must ensure the type matches the tensor type.
|
1006 |
+
/// </summary>
|
1007 |
+
/// <typeparam name="T"></typeparam>
|
1008 |
+
/// <returns>const pointer to data, no copies made</returns>
|
1009 |
+
template <typename R>
|
1010 |
+
const R* GetTensorData() const; ///< Wraps OrtApi::GetTensorMutableData /// <summary>
|
1011 |
+
|
1012 |
+
/// <summary>
|
1013 |
+
/// Returns a non-typed pointer to a tensor contained data.
|
1014 |
+
/// </summary>
|
1015 |
+
/// <returns>const pointer to data, no copies made</returns>
|
1016 |
+
const void* GetTensorRawData() const;
|
1017 |
+
|
1018 |
+
/// <summary>
|
1019 |
+
/// The API returns type information for data contained in a tensor. For sparse
|
1020 |
+
/// tensors it returns type information for contained non-zero values.
|
1021 |
+
/// It returns dense shape for sparse tensors.
|
1022 |
+
/// </summary>
|
1023 |
+
/// <returns>TypeInfo</returns>
|
1024 |
+
TypeInfo GetTypeInfo() const;
|
1025 |
+
|
1026 |
+
/// <summary>
|
1027 |
+
/// The API returns type information for data contained in a tensor. For sparse
|
1028 |
+
/// tensors it returns type information for contained non-zero values.
|
1029 |
+
/// It returns dense shape for sparse tensors.
|
1030 |
+
/// </summary>
|
1031 |
+
/// <returns>TensorTypeAndShapeInfo</returns>
|
1032 |
+
TensorTypeAndShapeInfo GetTensorTypeAndShapeInfo() const;
|
1033 |
+
|
1034 |
+
/// <summary>
|
1035 |
+
/// This API returns information about the memory allocation used to hold data.
|
1036 |
+
/// </summary>
|
1037 |
+
/// <returns>Non owning instance of MemoryInfo</returns>
|
1038 |
+
ConstMemoryInfo GetTensorMemoryInfo() const;
|
1039 |
+
|
1040 |
+
/// <summary>
|
1041 |
+
/// The API copies UTF-8 encoded bytes for the requested string element
|
1042 |
+
/// contained within a tensor or a sparse tensor into a provided buffer.
|
1043 |
+
/// Use GetStringTensorElementLength() to obtain the length of the buffer to allocate.
|
1044 |
+
/// </summary>
|
1045 |
+
/// <param name="buffer_length"></param>
|
1046 |
+
/// <param name="element_index"></param>
|
1047 |
+
/// <param name="buffer"></param>
|
1048 |
+
void GetStringTensorElement(size_t buffer_length, size_t element_index, void* buffer) const;
|
1049 |
+
|
1050 |
+
/// <summary>
|
1051 |
+
/// The API returns a byte length of UTF-8 encoded string element
|
1052 |
+
/// contained in either a tensor or a spare tensor values.
|
1053 |
+
/// </summary>
|
1054 |
+
/// <param name="element_index"></param>
|
1055 |
+
/// <returns>byte length for the specified string element</returns>
|
1056 |
+
size_t GetStringTensorElementLength(size_t element_index) const;
|
1057 |
+
|
1058 |
+
#if !defined(DISABLE_SPARSE_TENSORS)
|
1059 |
+
/// <summary>
|
1060 |
+
/// The API returns the sparse data format this OrtValue holds in a sparse tensor.
|
1061 |
+
/// If the sparse tensor was not fully constructed, i.e. Use*() or Fill*() API were not used
|
1062 |
+
/// the value returned is ORT_SPARSE_UNDEFINED.
|
1063 |
+
/// </summary>
|
1064 |
+
/// <returns>Format enum</returns>
|
1065 |
+
OrtSparseFormat GetSparseFormat() const;
|
1066 |
+
|
1067 |
+
/// <summary>
|
1068 |
+
/// The API returns type and shape information for stored non-zero values of the
|
1069 |
+
/// sparse tensor. Use GetSparseTensorValues() to obtain values buffer pointer.
|
1070 |
+
/// </summary>
|
1071 |
+
/// <returns>TensorTypeAndShapeInfo values information</returns>
|
1072 |
+
TensorTypeAndShapeInfo GetSparseTensorValuesTypeAndShapeInfo() const;
|
1073 |
+
|
1074 |
+
/// <summary>
|
1075 |
+
/// The API returns type and shape information for the specified indices. Each supported
|
1076 |
+
/// indices have their own enum values even if a give format has more than one kind of indices.
|
1077 |
+
/// Use GetSparseTensorIndicesData() to obtain pointer to indices buffer.
|
1078 |
+
/// </summary>
|
1079 |
+
/// <param name="format">enum requested</param>
|
1080 |
+
/// <returns>type and shape information</returns>
|
1081 |
+
TensorTypeAndShapeInfo GetSparseTensorIndicesTypeShapeInfo(OrtSparseIndicesFormat format) const;
|
1082 |
+
|
1083 |
+
/// <summary>
|
1084 |
+
/// The API retrieves a pointer to the internal indices buffer. The API merely performs
|
1085 |
+
/// a convenience data type casting on the return type pointer. Make sure you are requesting
|
1086 |
+
/// the right type, use GetSparseTensorIndicesTypeShapeInfo();
|
1087 |
+
/// </summary>
|
1088 |
+
/// <typeparam name="T">type to cast to</typeparam>
|
1089 |
+
/// <param name="indices_format">requested indices kind</param>
|
1090 |
+
/// <param name="num_indices">number of indices entries</param>
|
1091 |
+
/// <returns>Pinter to the internal sparse tensor buffer containing indices. Do not free this pointer.</returns>
|
1092 |
+
template <typename R>
|
1093 |
+
const R* GetSparseTensorIndicesData(OrtSparseIndicesFormat indices_format, size_t& num_indices) const;
|
1094 |
+
|
1095 |
+
/// <summary>
|
1096 |
+
/// Returns true if the OrtValue contains a sparse tensor
|
1097 |
+
/// </summary>
|
1098 |
+
/// <returns></returns>
|
1099 |
+
bool IsSparseTensor() const;
|
1100 |
+
|
1101 |
+
/// <summary>
|
1102 |
+
/// The API returns a pointer to an internal buffer of the sparse tensor
|
1103 |
+
/// containing non-zero values. The API merely does casting. Make sure you
|
1104 |
+
/// are requesting the right data type by calling GetSparseTensorValuesTypeAndShapeInfo()
|
1105 |
+
/// first.
|
1106 |
+
/// </summary>
|
1107 |
+
/// <typeparam name="T">numeric data types only. Use GetStringTensor*() to retrieve strings.</typeparam>
|
1108 |
+
/// <returns>a pointer to the internal values buffer. Do not free this pointer.</returns>
|
1109 |
+
template <typename R>
|
1110 |
+
const R* GetSparseTensorValues() const;
|
1111 |
+
|
1112 |
+
#endif
|
1113 |
+
};
|
1114 |
+
|
1115 |
+
template <typename T>
|
1116 |
+
struct ValueImpl : ConstValueImpl<T> {
|
1117 |
+
using B = ConstValueImpl<T>;
|
1118 |
+
using B::B;
|
1119 |
+
|
1120 |
+
/// <summary>
|
1121 |
+
/// Returns a non-const typed pointer to an OrtValue/Tensor contained buffer
|
1122 |
+
/// No type checking is performed, the caller must ensure the type matches the tensor type.
|
1123 |
+
/// </summary>
|
1124 |
+
/// <returns>non-const pointer to data, no copies made</returns>
|
1125 |
+
template <typename R>
|
1126 |
+
R* GetTensorMutableData();
|
1127 |
+
|
1128 |
+
/// <summary>
|
1129 |
+
/// Returns a non-typed non-const pointer to a tensor contained data.
|
1130 |
+
/// </summary>
|
1131 |
+
/// <returns>pointer to data, no copies made</returns>
|
1132 |
+
void* GetTensorMutableRawData();
|
1133 |
+
|
1134 |
+
/// <summary>
|
1135 |
+
// Obtain a reference to an element of data at the location specified
|
1136 |
+
/// by the vector of dims.
|
1137 |
+
/// </summary>
|
1138 |
+
/// <typeparam name="R"></typeparam>
|
1139 |
+
/// <param name="location">[in] expressed by a vecotr of dimensions offsets</param>
|
1140 |
+
/// <returns></returns>
|
1141 |
+
template <typename R>
|
1142 |
+
R& At(const std::vector<int64_t>& location);
|
1143 |
+
|
1144 |
+
/// <summary>
|
1145 |
+
/// Set all strings at once in a string tensor
|
1146 |
+
/// </summary>
|
1147 |
+
/// <param name="s">[in] An array of strings. Each string in this array must be null terminated.</param>
|
1148 |
+
/// <param name="s_len">[in] Count of strings in s (Must match the size of \p value's tensor shape)</param>
|
1149 |
+
void FillStringTensor(const char* const* s, size_t s_len);
|
1150 |
+
|
1151 |
+
/// <summary>
|
1152 |
+
/// Set a single string in a string tensor
|
1153 |
+
/// </summary>
|
1154 |
+
/// <param name="s">[in] A null terminated UTF-8 encoded string</param>
|
1155 |
+
/// <param name="index">[in] Index of the string in the tensor to set</param>
|
1156 |
+
void FillStringTensorElement(const char* s, size_t index);
|
1157 |
+
|
1158 |
+
#if !defined(DISABLE_SPARSE_TENSORS)
|
1159 |
+
/// <summary>
|
1160 |
+
/// Supplies COO format specific indices and marks the contained sparse tensor as being a COO format tensor.
|
1161 |
+
/// Values are supplied with a CreateSparseTensor() API. The supplied indices are not copied and the user
|
1162 |
+
/// allocated buffers lifespan must eclipse that of the OrtValue.
|
1163 |
+
/// The location of the indices is assumed to be the same as specified by OrtMemoryInfo argument at the creation time.
|
1164 |
+
/// </summary>
|
1165 |
+
/// <param name="indices_data">pointer to the user allocated buffer with indices. Use nullptr for fully sparse tensors.</param>
|
1166 |
+
/// <param name="indices_num">number of indices entries. Use 0 for fully sparse tensors</param>
|
1167 |
+
void UseCooIndices(int64_t* indices_data, size_t indices_num);
|
1168 |
+
|
1169 |
+
/// <summary>
|
1170 |
+
/// Supplies CSR format specific indices and marks the contained sparse tensor as being a CSR format tensor.
|
1171 |
+
/// Values are supplied with a CreateSparseTensor() API. The supplied indices are not copied and the user
|
1172 |
+
/// allocated buffers lifespan must eclipse that of the OrtValue.
|
1173 |
+
/// The location of the indices is assumed to be the same as specified by OrtMemoryInfo argument at the creation time.
|
1174 |
+
/// </summary>
|
1175 |
+
/// <param name="inner_data">pointer to the user allocated buffer with inner indices or nullptr for fully sparse tensors</param>
|
1176 |
+
/// <param name="inner_num">number of csr inner indices or 0 for fully sparse tensors</param>
|
1177 |
+
/// <param name="outer_data">pointer to the user allocated buffer with outer indices or nullptr for fully sparse tensors</param>
|
1178 |
+
/// <param name="outer_num">number of csr outer indices or 0 for fully sparse tensors</param>
|
1179 |
+
void UseCsrIndices(int64_t* inner_data, size_t inner_num, int64_t* outer_data, size_t outer_num);
|
1180 |
+
|
1181 |
+
/// <summary>
|
1182 |
+
/// Supplies BlockSparse format specific indices and marks the contained sparse tensor as being a BlockSparse format tensor.
|
1183 |
+
/// Values are supplied with a CreateSparseTensor() API. The supplied indices are not copied and the user
|
1184 |
+
/// allocated buffers lifespan must eclipse that of the OrtValue.
|
1185 |
+
/// The location of the indices is assumed to be the same as specified by OrtMemoryInfo argument at the creation time.
|
1186 |
+
/// </summary>
|
1187 |
+
/// <param name="indices_shape">indices shape or a {0} for fully sparse</param>
|
1188 |
+
/// <param name="indices_data">user allocated buffer with indices or nullptr for fully spare tensors</param>
|
1189 |
+
void UseBlockSparseIndices(const Shape& indices_shape, int32_t* indices_data);
|
1190 |
+
|
1191 |
+
/// <summary>
|
1192 |
+
/// The API will allocate memory using the allocator instance supplied to the CreateSparseTensor() API
|
1193 |
+
/// and copy the values and COO indices into it. If data_mem_info specifies that the data is located
|
1194 |
+
/// at difference device than the allocator, a X-device copy will be performed if possible.
|
1195 |
+
/// </summary>
|
1196 |
+
/// <param name="data_mem_info">specified buffer memory description</param>
|
1197 |
+
/// <param name="values_param">values buffer information.</param>
|
1198 |
+
/// <param name="indices_data">coo indices buffer or nullptr for fully sparse data</param>
|
1199 |
+
/// <param name="indices_num">number of COO indices or 0 for fully sparse data</param>
|
1200 |
+
void FillSparseTensorCoo(const OrtMemoryInfo* data_mem_info, const OrtSparseValuesParam& values_param,
|
1201 |
+
const int64_t* indices_data, size_t indices_num);
|
1202 |
+
|
1203 |
+
/// <summary>
|
1204 |
+
/// The API will allocate memory using the allocator instance supplied to the CreateSparseTensor() API
|
1205 |
+
/// and copy the values and CSR indices into it. If data_mem_info specifies that the data is located
|
1206 |
+
/// at difference device than the allocator, a X-device copy will be performed if possible.
|
1207 |
+
/// </summary>
|
1208 |
+
/// <param name="data_mem_info">specified buffer memory description</param>
|
1209 |
+
/// <param name="values">values buffer information</param>
|
1210 |
+
/// <param name="inner_indices_data">csr inner indices pointer or nullptr for fully sparse tensors</param>
|
1211 |
+
/// <param name="inner_indices_num">number of csr inner indices or 0 for fully sparse tensors</param>
|
1212 |
+
/// <param name="outer_indices_data">pointer to csr indices data or nullptr for fully sparse tensors</param>
|
1213 |
+
/// <param name="outer_indices_num">number of csr outer indices or 0</param>
|
1214 |
+
void FillSparseTensorCsr(const OrtMemoryInfo* data_mem_info,
|
1215 |
+
const OrtSparseValuesParam& values,
|
1216 |
+
const int64_t* inner_indices_data, size_t inner_indices_num,
|
1217 |
+
const int64_t* outer_indices_data, size_t outer_indices_num);
|
1218 |
+
|
1219 |
+
/// <summary>
|
1220 |
+
/// The API will allocate memory using the allocator instance supplied to the CreateSparseTensor() API
|
1221 |
+
/// and copy the values and BlockSparse indices into it. If data_mem_info specifies that the data is located
|
1222 |
+
/// at difference device than the allocator, a X-device copy will be performed if possible.
|
1223 |
+
/// </summary>
|
1224 |
+
/// <param name="data_mem_info">specified buffer memory description</param>
|
1225 |
+
/// <param name="values">values buffer information</param>
|
1226 |
+
/// <param name="indices_shape">indices shape. use {0} for fully sparse tensors</param>
|
1227 |
+
/// <param name="indices_data">pointer to indices data or nullptr for fully sparse tensors</param>
|
1228 |
+
void FillSparseTensorBlockSparse(const OrtMemoryInfo* data_mem_info,
|
1229 |
+
const OrtSparseValuesParam& values,
|
1230 |
+
const Shape& indices_shape,
|
1231 |
+
const int32_t* indices_data);
|
1232 |
+
|
1233 |
+
#endif
|
1234 |
+
};
|
1235 |
+
|
1236 |
+
} // namespace detail
|
1237 |
+
|
1238 |
+
using ConstValue = detail::ConstValueImpl<detail::Unowned<const OrtValue>>;
|
1239 |
+
using UnownedValue = detail::ValueImpl<detail::Unowned<OrtValue>>;
|
1240 |
+
|
1241 |
+
/** \brief Wrapper around ::OrtValue
|
1242 |
+
*
|
1243 |
+
*/
|
1244 |
+
struct Value : detail::ValueImpl<OrtValue> {
|
1245 |
+
using Base = detail::ValueImpl<OrtValue>;
|
1246 |
+
using OrtSparseValuesParam = detail::OrtSparseValuesParam;
|
1247 |
+
using Shape = detail::Shape;
|
1248 |
+
|
1249 |
+
explicit Value(std::nullptr_t) {} ///< Create an empty Value object, must be assigned a valid one to be used
|
1250 |
+
explicit Value(OrtValue* p) : Base{p} {} ///< Used for interop with the C API
|
1251 |
+
Value(Value&&) = default;
|
1252 |
+
Value& operator=(Value&&) = default;
|
1253 |
+
|
1254 |
+
ConstValue GetConst() const { return ConstValue{this->p_}; }
|
1255 |
+
UnownedValue GetUnowned() const { return UnownedValue{this->p_}; }
|
1256 |
+
|
1257 |
+
/** \brief Creates a tensor with a user supplied buffer. Wraps OrtApi::CreateTensorWithDataAsOrtValue.
|
1258 |
+
* \tparam T The numeric datatype. This API is not suitable for strings.
|
1259 |
+
* \param info Memory description of where the p_data buffer resides (CPU vs GPU etc).
|
1260 |
+
* \param p_data Pointer to the data buffer.
|
1261 |
+
* \param p_data_element_count The number of elements in the data buffer.
|
1262 |
+
* \param shape Pointer to the tensor shape dimensions.
|
1263 |
+
* \param shape_len The number of tensor shape dimensions.
|
1264 |
+
*/
|
1265 |
+
template <typename T>
|
1266 |
+
static Value CreateTensor(const OrtMemoryInfo* info, T* p_data, size_t p_data_element_count, const int64_t* shape, size_t shape_len);
|
1267 |
+
|
1268 |
+
/** \brief Creates a tensor with a user supplied buffer. Wraps OrtApi::CreateTensorWithDataAsOrtValue.
|
1269 |
+
* \param info Memory description of where the p_data buffer resides (CPU vs GPU etc).
|
1270 |
+
* \param p_data Pointer to the data buffer.
|
1271 |
+
* \param p_data_byte_count The number of bytes in the data buffer.
|
1272 |
+
* \param shape Pointer to the tensor shape dimensions.
|
1273 |
+
* \param shape_len The number of tensor shape dimensions.
|
1274 |
+
* \param type The data type.
|
1275 |
+
*/
|
1276 |
+
static Value CreateTensor(const OrtMemoryInfo* info, void* p_data, size_t p_data_byte_count, const int64_t* shape, size_t shape_len,
|
1277 |
+
ONNXTensorElementDataType type);
|
1278 |
+
|
1279 |
+
/** \brief Creates a tensor using a supplied OrtAllocator. Wraps OrtApi::CreateTensorAsOrtValue.
|
1280 |
+
* \tparam T The numeric datatype. This API is not suitable for strings.
|
1281 |
+
* \param allocator The allocator to use.
|
1282 |
+
* \param shape Pointer to the tensor shape dimensions.
|
1283 |
+
* \param shape_len The number of tensor shape dimensions.
|
1284 |
+
*/
|
1285 |
+
template <typename T>
|
1286 |
+
static Value CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len);
|
1287 |
+
|
1288 |
+
/** \brief Creates a tensor using a supplied OrtAllocator. Wraps OrtApi::CreateTensorAsOrtValue.
|
1289 |
+
* \param allocator The allocator to use.
|
1290 |
+
* \param shape Pointer to the tensor shape dimensions.
|
1291 |
+
* \param shape_len The number of tensor shape dimensions.
|
1292 |
+
* \param type The data type.
|
1293 |
+
*/
|
1294 |
+
static Value CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type);
|
1295 |
+
|
1296 |
+
static Value CreateMap(Value& keys, Value& values); ///< Wraps OrtApi::CreateValue
|
1297 |
+
static Value CreateSequence(std::vector<Value>& values); ///< Wraps OrtApi::CreateValue
|
1298 |
+
|
1299 |
+
template <typename T>
|
1300 |
+
static Value CreateOpaque(const char* domain, const char* type_name, const T&); ///< Wraps OrtApi::CreateOpaqueValue
|
1301 |
+
|
1302 |
+
#if !defined(DISABLE_SPARSE_TENSORS)
|
1303 |
+
/// <summary>
|
1304 |
+
/// This is a simple forwarding method to the other overload that helps deducing
|
1305 |
+
/// data type enum value from the type of the buffer.
|
1306 |
+
/// </summary>
|
1307 |
+
/// <typeparam name="T">numeric datatype. This API is not suitable for strings.</typeparam>
|
1308 |
+
/// <param name="info">Memory description where the user buffers reside (CPU vs GPU etc)</param>
|
1309 |
+
/// <param name="p_data">pointer to the user supplied buffer, use nullptr for fully sparse tensors</param>
|
1310 |
+
/// <param name="dense_shape">a would be dense shape of the tensor</param>
|
1311 |
+
/// <param name="values_shape">non zero values shape. Use a single 0 shape for fully sparse tensors.</param>
|
1312 |
+
/// <returns></returns>
|
1313 |
+
template <typename T>
|
1314 |
+
static Value CreateSparseTensor(const OrtMemoryInfo* info, T* p_data, const Shape& dense_shape,
|
1315 |
+
const Shape& values_shape);
|
1316 |
+
|
1317 |
+
/// <summary>
|
1318 |
+
/// Creates an OrtValue instance containing SparseTensor. This constructs
|
1319 |
+
/// a sparse tensor that makes use of user allocated buffers. It does not make copies
|
1320 |
+
/// of the user provided data and does not modify it. The lifespan of user provided buffers should
|
1321 |
+
/// eclipse the life span of the resulting OrtValue. This call constructs an instance that only contain
|
1322 |
+
/// a pointer to non-zero values. To fully populate the sparse tensor call Use<Format>Indices() API below
|
1323 |
+
/// to supply a sparse format specific indices.
|
1324 |
+
/// This API is not suitable for string data. Use CreateSparseTensor() with allocator specified so strings
|
1325 |
+
/// can be properly copied into the allocated buffer.
|
1326 |
+
/// </summary>
|
1327 |
+
/// <param name="info">Memory description where the user buffers reside (CPU vs GPU etc)</param>
|
1328 |
+
/// <param name="p_data">pointer to the user supplied buffer, use nullptr for fully sparse tensors</param>
|
1329 |
+
/// <param name="dense_shape">a would be dense shape of the tensor</param>
|
1330 |
+
/// <param name="values_shape">non zero values shape. Use a single 0 shape for fully sparse tensors.</param>
|
1331 |
+
/// <param name="type">data type</param>
|
1332 |
+
/// <returns>Ort::Value instance containing SparseTensor</returns>
|
1333 |
+
static Value CreateSparseTensor(const OrtMemoryInfo* info, void* p_data, const Shape& dense_shape,
|
1334 |
+
const Shape& values_shape, ONNXTensorElementDataType type);
|
1335 |
+
|
1336 |
+
/// <summary>
|
1337 |
+
/// This is a simple forwarding method to the below CreateSparseTensor.
|
1338 |
+
/// This helps to specify data type enum in terms of C++ data type.
|
1339 |
+
/// Use CreateSparseTensor<T>
|
1340 |
+
/// </summary>
|
1341 |
+
/// <typeparam name="T">numeric data type only. String data enum must be specified explicitly.</typeparam>
|
1342 |
+
/// <param name="allocator">allocator to use</param>
|
1343 |
+
/// <param name="dense_shape">a would be dense shape of the tensor</param>
|
1344 |
+
/// <returns>Ort::Value</returns>
|
1345 |
+
template <typename T>
|
1346 |
+
static Value CreateSparseTensor(OrtAllocator* allocator, const Shape& dense_shape);
|
1347 |
+
|
1348 |
+
/// <summary>
|
1349 |
+
/// Creates an instance of OrtValue containing sparse tensor. The created instance has no data.
|
1350 |
+
/// The data must be supplied by on of the FillSparseTensor<Format>() methods that take both non-zero values
|
1351 |
+
/// and indices. The data will be copied into a buffer that would be allocated using the supplied allocator.
|
1352 |
+
/// Use this API to create OrtValues that contain sparse tensors with all supported data types including
|
1353 |
+
/// strings.
|
1354 |
+
/// </summary>
|
1355 |
+
/// <param name="allocator">allocator to use. The allocator lifespan must eclipse that of the resulting OrtValue</param>
|
1356 |
+
/// <param name="dense_shape">a would be dense shape of the tensor</param>
|
1357 |
+
/// <param name="type">data type</param>
|
1358 |
+
/// <returns>an instance of Ort::Value</returns>
|
1359 |
+
static Value CreateSparseTensor(OrtAllocator* allocator, const Shape& dense_shape, ONNXTensorElementDataType type);
|
1360 |
+
|
1361 |
+
#endif // !defined(DISABLE_SPARSE_TENSORS)
|
1362 |
+
};
|
1363 |
+
|
1364 |
+
/// <summary>
|
1365 |
+
/// Represents native memory allocation coming from one of the
|
1366 |
+
/// OrtAllocators registered with OnnxRuntime.
|
1367 |
+
/// Use it to wrap an allocation made by an allocator
|
1368 |
+
/// so it can be automatically released when no longer needed.
|
1369 |
+
/// </summary>
|
1370 |
+
struct MemoryAllocation {
|
1371 |
+
MemoryAllocation(OrtAllocator* allocator, void* p, size_t size);
|
1372 |
+
~MemoryAllocation();
|
1373 |
+
MemoryAllocation(const MemoryAllocation&) = delete;
|
1374 |
+
MemoryAllocation& operator=(const MemoryAllocation&) = delete;
|
1375 |
+
MemoryAllocation(MemoryAllocation&&) noexcept;
|
1376 |
+
MemoryAllocation& operator=(MemoryAllocation&&) noexcept;
|
1377 |
+
|
1378 |
+
void* get() { return p_; }
|
1379 |
+
size_t size() const { return size_; }
|
1380 |
+
|
1381 |
+
private:
|
1382 |
+
OrtAllocator* allocator_;
|
1383 |
+
void* p_;
|
1384 |
+
size_t size_;
|
1385 |
+
};
|
1386 |
+
|
1387 |
+
namespace detail {
|
1388 |
+
template <typename T>
|
1389 |
+
struct AllocatorImpl : Base<T> {
|
1390 |
+
using B = Base<T>;
|
1391 |
+
using B::B;
|
1392 |
+
|
1393 |
+
void* Alloc(size_t size);
|
1394 |
+
MemoryAllocation GetAllocation(size_t size);
|
1395 |
+
void Free(void* p);
|
1396 |
+
ConstMemoryInfo GetInfo() const;
|
1397 |
+
};
|
1398 |
+
|
1399 |
+
} // namespace detail
|
1400 |
+
|
1401 |
+
/** \brief Wrapper around ::OrtAllocator default instance that is owned by Onnxruntime
|
1402 |
+
*
|
1403 |
+
*/
|
1404 |
+
struct AllocatorWithDefaultOptions : detail::AllocatorImpl<detail::Unowned<OrtAllocator>> {
|
1405 |
+
explicit AllocatorWithDefaultOptions(std::nullptr_t) {} ///< Convenience to create a class member and then replace with an instance
|
1406 |
+
AllocatorWithDefaultOptions();
|
1407 |
+
};
|
1408 |
+
|
1409 |
+
/** \brief Wrapper around ::OrtAllocator
|
1410 |
+
*
|
1411 |
+
*/
|
1412 |
+
struct Allocator : detail::AllocatorImpl<OrtAllocator> {
|
1413 |
+
explicit Allocator(std::nullptr_t) {} ///< Convenience to create a class member and then replace with an instance
|
1414 |
+
Allocator(const Session& session, const OrtMemoryInfo*);
|
1415 |
+
};
|
1416 |
+
|
1417 |
+
using UnownedAllocator = detail::AllocatorImpl<detail::Unowned<OrtAllocator>>;
|
1418 |
+
|
1419 |
+
namespace detail {
|
1420 |
+
namespace binding_utils {
|
1421 |
+
// Bring these out of template
|
1422 |
+
std::vector<std::string> GetOutputNamesHelper(const OrtIoBinding* binding, OrtAllocator*);
|
1423 |
+
std::vector<Value> GetOutputValuesHelper(const OrtIoBinding* binding, OrtAllocator*);
|
1424 |
+
} // namespace binding_utils
|
1425 |
+
|
1426 |
+
template <typename T>
|
1427 |
+
struct ConstIoBindingImpl : Base<T> {
|
1428 |
+
using B = Base<T>;
|
1429 |
+
using B::B;
|
1430 |
+
|
1431 |
+
std::vector<std::string> GetOutputNames() const;
|
1432 |
+
std::vector<std::string> GetOutputNames(OrtAllocator*) const;
|
1433 |
+
std::vector<Value> GetOutputValues() const;
|
1434 |
+
std::vector<Value> GetOutputValues(OrtAllocator*) const;
|
1435 |
+
};
|
1436 |
+
|
1437 |
+
template <typename T>
|
1438 |
+
struct IoBindingImpl : ConstIoBindingImpl<T> {
|
1439 |
+
using B = ConstIoBindingImpl<T>;
|
1440 |
+
using B::B;
|
1441 |
+
|
1442 |
+
void BindInput(const char* name, const Value&);
|
1443 |
+
void BindOutput(const char* name, const Value&);
|
1444 |
+
void BindOutput(const char* name, const OrtMemoryInfo*);
|
1445 |
+
void ClearBoundInputs();
|
1446 |
+
void ClearBoundOutputs();
|
1447 |
+
void SynchronizeInputs();
|
1448 |
+
void SynchronizeOutputs();
|
1449 |
+
};
|
1450 |
+
|
1451 |
+
} // namespace detail
|
1452 |
+
|
1453 |
+
using ConstIoBinding = detail::ConstIoBindingImpl<detail::Unowned<const OrtIoBinding>>;
|
1454 |
+
using UnownedIoBinding = detail::IoBindingImpl<detail::Unowned<OrtIoBinding>>;
|
1455 |
+
|
1456 |
+
/** \brief Wrapper around ::OrtIoBinding
|
1457 |
+
*
|
1458 |
+
*/
|
1459 |
+
struct IoBinding : detail::IoBindingImpl<OrtIoBinding> {
|
1460 |
+
explicit IoBinding(std::nullptr_t) {} ///< Create an empty object for convenience. Sometimes, we want to initialize members later.
|
1461 |
+
explicit IoBinding(Session& session);
|
1462 |
+
ConstIoBinding GetConst() const { return ConstIoBinding{this->p_}; }
|
1463 |
+
UnownedIoBinding GetUnowned() const { return UnownedIoBinding{this->p_}; }
|
1464 |
+
};
|
1465 |
+
|
1466 |
+
/*! \struct Ort::ArenaCfg
|
1467 |
+
* \brief it is a structure that represents the configuration of an arena based allocator
|
1468 |
+
* \details Please see docs/C_API.md for details
|
1469 |
+
*/
|
1470 |
+
struct ArenaCfg : detail::Base<OrtArenaCfg> {
|
1471 |
+
explicit ArenaCfg(std::nullptr_t) {} ///< Create an empty ArenaCfg object, must be assigned a valid one to be used
|
1472 |
+
/**
|
1473 |
+
* Wraps OrtApi::CreateArenaCfg
|
1474 |
+
* \param max_mem - use 0 to allow ORT to choose the default
|
1475 |
+
* \param arena_extend_strategy - use -1 to allow ORT to choose the default, 0 = kNextPowerOfTwo, 1 = kSameAsRequested
|
1476 |
+
* \param initial_chunk_size_bytes - use -1 to allow ORT to choose the default
|
1477 |
+
* \param max_dead_bytes_per_chunk - use -1 to allow ORT to choose the default
|
1478 |
+
* See docs/C_API.md for details on what the following parameters mean and how to choose these values
|
1479 |
+
*/
|
1480 |
+
ArenaCfg(size_t max_mem, int arena_extend_strategy, int initial_chunk_size_bytes, int max_dead_bytes_per_chunk);
|
1481 |
+
};
|
1482 |
+
|
1483 |
+
//
|
1484 |
+
// Custom OPs (only needed to implement custom OPs)
|
1485 |
+
//
|
1486 |
+
|
1487 |
+
/// <summary>
|
1488 |
+
/// This struct provides life time management for custom op attribute
|
1489 |
+
/// </summary>
|
1490 |
+
struct OpAttr : detail::Base<OrtOpAttr> {
|
1491 |
+
OpAttr(const char* name, const void* data, int len, OrtOpAttrType type);
|
1492 |
+
};
|
1493 |
+
|
1494 |
+
/// <summary>
|
1495 |
+
/// This class wraps a raw pointer OrtKernelContext* that is being passed
|
1496 |
+
/// to the custom kernel Compute() method. Use it to safely access context
|
1497 |
+
/// attributes, input and output parameters with exception safety guarantees.
|
1498 |
+
/// See usage example in onnxruntime/test/testdata/custom_op_library/custom_op_library.cc
|
1499 |
+
/// </summary>
|
1500 |
+
struct KernelContext {
|
1501 |
+
explicit KernelContext(OrtKernelContext* context);
|
1502 |
+
size_t GetInputCount() const;
|
1503 |
+
size_t GetOutputCount() const;
|
1504 |
+
ConstValue GetInput(size_t index) const;
|
1505 |
+
UnownedValue GetOutput(size_t index, const int64_t* dim_values, size_t dim_count) const;
|
1506 |
+
UnownedValue GetOutput(size_t index, const std::vector<int64_t>& dims) const;
|
1507 |
+
void* GetGPUComputeStream() const;
|
1508 |
+
|
1509 |
+
private:
|
1510 |
+
OrtKernelContext* ctx_;
|
1511 |
+
};
|
1512 |
+
|
1513 |
+
struct KernelInfo;
|
1514 |
+
|
1515 |
+
namespace detail {
|
1516 |
+
namespace attr_utils {
|
1517 |
+
void GetAttr(const OrtKernelInfo* p, const char* name, float&);
|
1518 |
+
void GetAttr(const OrtKernelInfo* p, const char* name, int64_t&);
|
1519 |
+
void GetAttr(const OrtKernelInfo* p, const char* name, std::string&);
|
1520 |
+
void GetAttrs(const OrtKernelInfo* p, const char* name, std::vector<float>&);
|
1521 |
+
void GetAttrs(const OrtKernelInfo* p, const char* name, std::vector<int64_t>&);
|
1522 |
+
} // namespace attr_utils
|
1523 |
+
|
1524 |
+
template <typename T>
|
1525 |
+
struct KernelInfoImpl : Base<T> {
|
1526 |
+
using B = Base<T>;
|
1527 |
+
using B::B;
|
1528 |
+
|
1529 |
+
KernelInfo Copy() const;
|
1530 |
+
|
1531 |
+
template <typename R> // R is only implemented for float, int64_t, and string
|
1532 |
+
R GetAttribute(const char* name) const {
|
1533 |
+
R val;
|
1534 |
+
attr_utils::GetAttr(this->p_, name, val);
|
1535 |
+
return val;
|
1536 |
+
}
|
1537 |
+
|
1538 |
+
template <typename R> // R is only implemented for std::vector<float>, std::vector<int64_t>
|
1539 |
+
std::vector<R> GetAttributes(const char* name) const {
|
1540 |
+
std::vector<R> result;
|
1541 |
+
attr_utils::GetAttrs(this->p_, name, result);
|
1542 |
+
return result;
|
1543 |
+
}
|
1544 |
+
|
1545 |
+
Value GetTensorAttribute(const char* name, OrtAllocator* allocator) const;
|
1546 |
+
|
1547 |
+
size_t GetInputCount() const;
|
1548 |
+
size_t GetOutputCount() const;
|
1549 |
+
|
1550 |
+
std::string GetInputName(size_t index) const;
|
1551 |
+
std::string GetOutputName(size_t index) const;
|
1552 |
+
|
1553 |
+
TypeInfo GetInputTypeInfo(size_t index) const;
|
1554 |
+
TypeInfo GetOutputTypeInfo(size_t index) const;
|
1555 |
+
};
|
1556 |
+
|
1557 |
+
} // namespace detail
|
1558 |
+
|
1559 |
+
using ConstKernelInfo = detail::KernelInfoImpl<detail::Unowned<const OrtKernelInfo>>;
|
1560 |
+
|
1561 |
+
/// <summary>
|
1562 |
+
/// This struct owns the OrtKernInfo* pointer when a copy is made.
|
1563 |
+
/// For convenient wrapping of OrtKernelInfo* passed to kernel constructor
|
1564 |
+
/// and query attributes, warp the pointer with Ort::Unowned<KernelInfo> instance
|
1565 |
+
/// so it does not destroy the pointer the kernel does not own.
|
1566 |
+
/// </summary>
|
1567 |
+
struct KernelInfo : detail::KernelInfoImpl<OrtKernelInfo> {
|
1568 |
+
explicit KernelInfo(std::nullptr_t) {} ///< Create an empty instance to initialize later
|
1569 |
+
explicit KernelInfo(OrtKernelInfo* info); ///< Take ownership of the instance
|
1570 |
+
ConstKernelInfo GetConst() const { return ConstKernelInfo{this->p_}; }
|
1571 |
+
};
|
1572 |
+
|
1573 |
+
/// <summary>
|
1574 |
+
/// Create and own custom defined operation.
|
1575 |
+
/// </summary>
|
1576 |
+
struct Op : detail::Base<OrtOp> {
|
1577 |
+
explicit Op(std::nullptr_t) {} ///< Create an empty Operator object, must be assigned a valid one to be used
|
1578 |
+
|
1579 |
+
explicit Op(OrtOp*); ///< Take ownership of the OrtOp
|
1580 |
+
|
1581 |
+
static Op Create(const OrtKernelInfo* info, const char* op_name, const char* domain,
|
1582 |
+
int version, const char** type_constraint_names,
|
1583 |
+
const ONNXTensorElementDataType* type_constraint_values,
|
1584 |
+
size_t type_constraint_count,
|
1585 |
+
const OpAttr* attr_values,
|
1586 |
+
size_t attr_count,
|
1587 |
+
size_t input_count, size_t output_count);
|
1588 |
+
|
1589 |
+
void Invoke(const OrtKernelContext* context,
|
1590 |
+
const Value* input_values,
|
1591 |
+
size_t input_count,
|
1592 |
+
Value* output_values,
|
1593 |
+
size_t output_count);
|
1594 |
+
|
1595 |
+
// For easier refactoring
|
1596 |
+
void Invoke(const OrtKernelContext* context,
|
1597 |
+
const OrtValue* const* input_values,
|
1598 |
+
size_t input_count,
|
1599 |
+
OrtValue* const* output_values,
|
1600 |
+
size_t output_count);
|
1601 |
+
};
|
1602 |
+
|
1603 |
+
/// <summary>
|
1604 |
+
/// This entire structure is deprecated, but we not marking
|
1605 |
+
/// it as a whole yet since we want to preserve for the next release.
|
1606 |
+
/// </summary>
|
1607 |
+
struct CustomOpApi {
|
1608 |
+
CustomOpApi(const OrtApi& api) : api_(api) {}
|
1609 |
+
|
1610 |
+
/** \deprecated use Ort::Value::GetTensorTypeAndShape()
|
1611 |
+
* [[deprecated]]
|
1612 |
+
* This interface produces a pointer that must be released. Not exception safe.
|
1613 |
+
*/
|
1614 |
+
[[deprecated("use Ort::Value::GetTensorTypeAndShape()")]] OrtTensorTypeAndShapeInfo* GetTensorTypeAndShape(_In_ const OrtValue* value);
|
1615 |
+
|
1616 |
+
/** \deprecated use Ort::TensorTypeAndShapeInfo::GetElementCount()
|
1617 |
+
* [[deprecated]]
|
1618 |
+
* This interface is redundant.
|
1619 |
+
*/
|
1620 |
+
[[deprecated("use Ort::TensorTypeAndShapeInfo::GetElementCount()")]] size_t GetTensorShapeElementCount(_In_ const OrtTensorTypeAndShapeInfo* info);
|
1621 |
+
|
1622 |
+
/** \deprecated use Ort::TensorTypeAndShapeInfo::GetElementType()
|
1623 |
+
* [[deprecated]]
|
1624 |
+
* This interface is redundant.
|
1625 |
+
*/
|
1626 |
+
[[deprecated("use Ort::TensorTypeAndShapeInfo::GetElementType()")]] ONNXTensorElementDataType GetTensorElementType(const OrtTensorTypeAndShapeInfo* info);
|
1627 |
+
|
1628 |
+
/** \deprecated use Ort::TensorTypeAndShapeInfo::GetDimensionsCount()
|
1629 |
+
* [[deprecated]]
|
1630 |
+
* This interface is redundant.
|
1631 |
+
*/
|
1632 |
+
[[deprecated("use Ort::TensorTypeAndShapeInfo::GetDimensionsCount()")]] size_t GetDimensionsCount(_In_ const OrtTensorTypeAndShapeInfo* info);
|
1633 |
+
|
1634 |
+
/** \deprecated use Ort::TensorTypeAndShapeInfo::GetShape()
|
1635 |
+
* [[deprecated]]
|
1636 |
+
* This interface is redundant.
|
1637 |
+
*/
|
1638 |
+
[[deprecated("use Ort::TensorTypeAndShapeInfo::GetShape()")]] void GetDimensions(_In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values, size_t dim_values_length);
|
1639 |
+
|
1640 |
+
/** \deprecated
|
1641 |
+
* [[deprecated]]
|
1642 |
+
* This interface sets dimensions to TensorTypeAndShapeInfo, but has no effect on the OrtValue.
|
1643 |
+
*/
|
1644 |
+
[[deprecated("Do not use")]] void SetDimensions(OrtTensorTypeAndShapeInfo* info, _In_ const int64_t* dim_values, size_t dim_count);
|
1645 |
+
|
1646 |
+
/** \deprecated use Ort::Value::GetTensorMutableData()
|
1647 |
+
* [[deprecated]]
|
1648 |
+
* This interface is redundant.
|
1649 |
+
*/
|
1650 |
+
template <typename T>
|
1651 |
+
[[deprecated("use Ort::Value::GetTensorMutableData()")]] T* GetTensorMutableData(_Inout_ OrtValue* value);
|
1652 |
+
|
1653 |
+
/** \deprecated use Ort::Value::GetTensorData()
|
1654 |
+
* [[deprecated]]
|
1655 |
+
* This interface is redundant.
|
1656 |
+
*/
|
1657 |
+
template <typename T>
|
1658 |
+
[[deprecated("use Ort::Value::GetTensorData()")]] const T* GetTensorData(_Inout_ const OrtValue* value);
|
1659 |
+
|
1660 |
+
/** \deprecated use Ort::Value::GetTensorMemoryInfo()
|
1661 |
+
* [[deprecated]]
|
1662 |
+
* This interface is redundant.
|
1663 |
+
*/
|
1664 |
+
[[deprecated("use Ort::Value::GetTensorMemoryInfo()")]] const OrtMemoryInfo* GetTensorMemoryInfo(_In_ const OrtValue* value);
|
1665 |
+
|
1666 |
+
/** \deprecated use Ort::TensorTypeAndShapeInfo::GetShape()
|
1667 |
+
* [[deprecated]]
|
1668 |
+
* This interface is redundant.
|
1669 |
+
*/
|
1670 |
+
[[deprecated("use Ort::TensorTypeAndShapeInfo::GetShape()")]] std::vector<int64_t> GetTensorShape(const OrtTensorTypeAndShapeInfo* info);
|
1671 |
+
|
1672 |
+
/** \deprecated use TensorTypeAndShapeInfo instances for automatic ownership.
|
1673 |
+
* [[deprecated]]
|
1674 |
+
* This interface is not exception safe.
|
1675 |
+
*/
|
1676 |
+
[[deprecated("use TensorTypeAndShapeInfo")]] void ReleaseTensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* input);
|
1677 |
+
|
1678 |
+
/** \deprecated use Ort::KernelContext::GetInputCount
|
1679 |
+
* [[deprecated]]
|
1680 |
+
* This interface is redundant.
|
1681 |
+
*/
|
1682 |
+
[[deprecated("use Ort::KernelContext::GetInputCount")]] size_t KernelContext_GetInputCount(const OrtKernelContext* context);
|
1683 |
+
|
1684 |
+
/** \deprecated use Ort::KernelContext::GetInput
|
1685 |
+
* [[deprecated]]
|
1686 |
+
* This interface is redundant.
|
1687 |
+
*/
|
1688 |
+
[[deprecated("use Ort::KernelContext::GetInput")]] const OrtValue* KernelContext_GetInput(const OrtKernelContext* context, _In_ size_t index);
|
1689 |
+
|
1690 |
+
/** \deprecated use Ort::KernelContext::GetOutputCount
|
1691 |
+
* [[deprecated]]
|
1692 |
+
* This interface is redundant.
|
1693 |
+
*/
|
1694 |
+
[[deprecated("use Ort::KernelContext::GetOutputCount")]] size_t KernelContext_GetOutputCount(const OrtKernelContext* context);
|
1695 |
+
|
1696 |
+
/** \deprecated use Ort::KernelContext::GetOutput
|
1697 |
+
* [[deprecated]]
|
1698 |
+
* This interface is redundant.
|
1699 |
+
*/
|
1700 |
+
[[deprecated("use Ort::KernelContext::GetOutput")]] OrtValue* KernelContext_GetOutput(OrtKernelContext* context, _In_ size_t index, _In_ const int64_t* dim_values, size_t dim_count);
|
1701 |
+
|
1702 |
+
/** \deprecated use Ort::KernelContext::GetGPUComputeStream
|
1703 |
+
* [[deprecated]]
|
1704 |
+
* This interface is redundant.
|
1705 |
+
*/
|
1706 |
+
[[deprecated("use Ort::KernelContext::GetGPUComputeStream")]] void* KernelContext_GetGPUComputeStream(const OrtKernelContext* context);
|
1707 |
+
|
1708 |
+
/** \deprecated use Ort::ThrowOnError()
|
1709 |
+
* [[deprecated]]
|
1710 |
+
* This interface is redundant.
|
1711 |
+
*/
|
1712 |
+
[[deprecated("use Ort::ThrowOnError()")]] void ThrowOnError(OrtStatus* result);
|
1713 |
+
|
1714 |
+
/** \deprecated use Ort::OpAttr
|
1715 |
+
* [[deprecated]]
|
1716 |
+
* This interface is not exception safe.
|
1717 |
+
*/
|
1718 |
+
[[deprecated("use Ort::OpAttr")]] OrtOpAttr* CreateOpAttr(_In_ const char* name,
|
1719 |
+
_In_ const void* data,
|
1720 |
+
_In_ int len,
|
1721 |
+
_In_ OrtOpAttrType type);
|
1722 |
+
|
1723 |
+
/** \deprecated use Ort::OpAttr
|
1724 |
+
* [[deprecated]]
|
1725 |
+
* This interface is not exception safe.
|
1726 |
+
*/
|
1727 |
+
[[deprecated("use Ort::OpAttr")]] void ReleaseOpAttr(_Frees_ptr_opt_ OrtOpAttr* op_attr);
|
1728 |
+
|
1729 |
+
/** \deprecated use Ort::Op
|
1730 |
+
* [[deprecated]]
|
1731 |
+
* This interface is not exception safe.
|
1732 |
+
*/
|
1733 |
+
[[deprecated("use Ort::Op")]] OrtOp* CreateOp(_In_ const OrtKernelInfo* info,
|
1734 |
+
_In_ const char* op_name,
|
1735 |
+
_In_ const char* domain,
|
1736 |
+
_In_ int version,
|
1737 |
+
_In_opt_ const char** type_constraint_names,
|
1738 |
+
_In_opt_ const ONNXTensorElementDataType* type_constraint_values,
|
1739 |
+
_In_opt_ int type_constraint_count,
|
1740 |
+
_In_opt_ const OrtOpAttr* const* attr_values,
|
1741 |
+
_In_opt_ int attr_count,
|
1742 |
+
_In_ int input_count,
|
1743 |
+
_In_ int output_count);
|
1744 |
+
|
1745 |
+
/** \deprecated use Ort::Op::Invoke
|
1746 |
+
* [[deprecated]]
|
1747 |
+
* This interface is redundant
|
1748 |
+
*/
|
1749 |
+
[[deprecated("use Ort::Op::Invoke")]] void InvokeOp(_In_ const OrtKernelContext* context,
|
1750 |
+
_In_ const OrtOp* ort_op,
|
1751 |
+
_In_ const OrtValue* const* input_values,
|
1752 |
+
_In_ int input_count,
|
1753 |
+
_Inout_ OrtValue* const* output_values,
|
1754 |
+
_In_ int output_count);
|
1755 |
+
|
1756 |
+
/** \deprecated use Ort::Op for automatic lifespan management.
|
1757 |
+
* [[deprecated]]
|
1758 |
+
* This interface is not exception safe.
|
1759 |
+
*/
|
1760 |
+
[[deprecated("use Ort::Op")]] void ReleaseOp(_Frees_ptr_opt_ OrtOp* ort_op);
|
1761 |
+
|
1762 |
+
/** \deprecated use Ort::KernelInfo for automatic lifespan management or for
|
1763 |
+
* querying attributes
|
1764 |
+
* [[deprecated]]
|
1765 |
+
* This interface is redundant
|
1766 |
+
*/
|
1767 |
+
template <typename T> // T is only implemented for std::vector<float>, std::vector<int64_t>, float, int64_t, and string
|
1768 |
+
[[deprecated("use Ort::KernelInfo::GetAttribute")]] T KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name);
|
1769 |
+
|
1770 |
+
/** \deprecated use Ort::KernelInfo::Copy
|
1771 |
+
* querying attributes
|
1772 |
+
* [[deprecated]]
|
1773 |
+
* This interface is not exception safe
|
1774 |
+
*/
|
1775 |
+
[[deprecated("use Ort::KernelInfo::Copy")]] OrtKernelInfo* CopyKernelInfo(_In_ const OrtKernelInfo* info);
|
1776 |
+
|
1777 |
+
/** \deprecated use Ort::KernelInfo for lifespan management
|
1778 |
+
* querying attributes
|
1779 |
+
* [[deprecated]]
|
1780 |
+
* This interface is not exception safe
|
1781 |
+
*/
|
1782 |
+
[[deprecated("use Ort::KernelInfo")]] void ReleaseKernelInfo(_Frees_ptr_opt_ OrtKernelInfo* info_copy);
|
1783 |
+
|
1784 |
+
private:
|
1785 |
+
const OrtApi& api_;
|
1786 |
+
};
|
1787 |
+
|
1788 |
+
template <typename TOp, typename TKernel>
|
1789 |
+
struct CustomOpBase : OrtCustomOp {
|
1790 |
+
CustomOpBase() {
|
1791 |
+
OrtCustomOp::version = ORT_API_VERSION;
|
1792 |
+
OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* api, const OrtKernelInfo* info) { return static_cast<const TOp*>(this_)->CreateKernel(*api, info); };
|
1793 |
+
OrtCustomOp::GetName = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetName(); };
|
1794 |
+
|
1795 |
+
OrtCustomOp::GetExecutionProviderType = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetExecutionProviderType(); };
|
1796 |
+
|
1797 |
+
OrtCustomOp::GetInputTypeCount = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetInputTypeCount(); };
|
1798 |
+
OrtCustomOp::GetInputType = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetInputType(index); };
|
1799 |
+
OrtCustomOp::GetInputMemoryType = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetInputMemoryType(index); };
|
1800 |
+
|
1801 |
+
OrtCustomOp::GetOutputTypeCount = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetOutputTypeCount(); };
|
1802 |
+
OrtCustomOp::GetOutputType = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetOutputType(index); };
|
1803 |
+
|
1804 |
+
OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) { static_cast<TKernel*>(op_kernel)->Compute(context); };
|
1805 |
+
#if defined(_MSC_VER) && !defined(__clang__)
|
1806 |
+
#pragma warning(push)
|
1807 |
+
#pragma warning(disable : 26409)
|
1808 |
+
#endif
|
1809 |
+
OrtCustomOp::KernelDestroy = [](void* op_kernel) { delete static_cast<TKernel*>(op_kernel); };
|
1810 |
+
#if defined(_MSC_VER) && !defined(__clang__)
|
1811 |
+
#pragma warning(pop)
|
1812 |
+
#endif
|
1813 |
+
OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetInputCharacteristic(index); };
|
1814 |
+
OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetOutputCharacteristic(index); };
|
1815 |
+
|
1816 |
+
OrtCustomOp::GetVariadicInputMinArity = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetVariadicInputMinArity(); };
|
1817 |
+
OrtCustomOp::GetVariadicInputHomogeneity = [](const OrtCustomOp* this_) { return static_cast<int>(static_cast<const TOp*>(this_)->GetVariadicInputHomogeneity()); };
|
1818 |
+
OrtCustomOp::GetVariadicOutputMinArity = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetVariadicOutputMinArity(); };
|
1819 |
+
OrtCustomOp::GetVariadicOutputHomogeneity = [](const OrtCustomOp* this_) { return static_cast<int>(static_cast<const TOp*>(this_)->GetVariadicOutputHomogeneity()); };
|
1820 |
+
}
|
1821 |
+
|
1822 |
+
// Default implementation of GetExecutionProviderType that returns nullptr to default to the CPU provider
|
1823 |
+
const char* GetExecutionProviderType() const { return nullptr; }
|
1824 |
+
|
1825 |
+
// Default implementations of GetInputCharacteristic() and GetOutputCharacteristic() below
|
1826 |
+
// (inputs and outputs are required by default)
|
1827 |
+
OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(size_t /*index*/) const {
|
1828 |
+
return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
|
1829 |
+
}
|
1830 |
+
|
1831 |
+
OrtCustomOpInputOutputCharacteristic GetOutputCharacteristic(size_t /*index*/) const {
|
1832 |
+
return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
|
1833 |
+
}
|
1834 |
+
|
1835 |
+
// Default implemention of GetInputMemoryType() that returns OrtMemTypeDefault
|
1836 |
+
OrtMemType GetInputMemoryType(size_t /*index*/) const {
|
1837 |
+
return OrtMemTypeDefault;
|
1838 |
+
}
|
1839 |
+
|
1840 |
+
// Default implementation of GetVariadicInputMinArity() returns 1 to specify that a variadic input
|
1841 |
+
// should expect at least 1 argument.
|
1842 |
+
int GetVariadicInputMinArity() const {
|
1843 |
+
return 1;
|
1844 |
+
}
|
1845 |
+
|
1846 |
+
// Default implementation of GetVariadicInputHomegeneity() returns true to specify that all arguments
|
1847 |
+
// to a variadic input should be of the same type.
|
1848 |
+
bool GetVariadicInputHomogeneity() const {
|
1849 |
+
return true;
|
1850 |
+
}
|
1851 |
+
|
1852 |
+
// Default implementation of GetVariadicOutputMinArity() returns 1 to specify that a variadic output
|
1853 |
+
// should produce at least 1 output value.
|
1854 |
+
int GetVariadicOutputMinArity() const {
|
1855 |
+
return 1;
|
1856 |
+
}
|
1857 |
+
|
1858 |
+
// Default implementation of GetVariadicOutputHomegeneity() returns true to specify that all output values
|
1859 |
+
// produced by a variadic output should be of the same type.
|
1860 |
+
bool GetVariadicOutputHomogeneity() const {
|
1861 |
+
return true;
|
1862 |
+
}
|
1863 |
+
|
1864 |
+
// Declare list of session config entries used by this Custom Op.
|
1865 |
+
// Implement this function in order to get configs from CustomOpBase::GetSessionConfigs().
|
1866 |
+
// This default implementation returns an empty vector of config entries.
|
1867 |
+
std::vector<std::string> GetSessionConfigKeys() const {
|
1868 |
+
return std::vector<std::string>{};
|
1869 |
+
}
|
1870 |
+
|
1871 |
+
protected:
|
1872 |
+
// Helper function that returns a map of session config entries specified by CustomOpBase::GetSessionConfigKeys.
|
1873 |
+
void GetSessionConfigs(std::unordered_map<std::string, std::string>& out, ConstSessionOptions options) const;
|
1874 |
+
};
|
1875 |
+
|
1876 |
+
} // namespace Ort
|
1877 |
+
|
1878 |
+
#include "onnxruntime_cxx_inline.h"
|
1.15.1/onnxruntime.xcframework/Headers/onnxruntime_cxx_inline.h
ADDED
@@ -0,0 +1,1888 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Copyright (c) Microsoft Corporation. All rights reserved.
|
2 |
+
// Licensed under the MIT License.
|
3 |
+
|
4 |
+
// Do not include this file directly. Please include "onnxruntime_cxx_api.h" instead.
|
5 |
+
// If interested in trying out features of the new experimental C++ API, include "experimental_onnxruntime_cxx_api.h" instead.
|
6 |
+
//
|
7 |
+
// These are the inline implementations of the C++ header APIs. They're in this separate file as to not clutter
|
8 |
+
// the main C++ file with implementation details.
|
9 |
+
|
10 |
+
namespace Ort {
|
11 |
+
|
12 |
+
namespace detail {
|
13 |
+
inline void ThrowStatus(const Status& st) {
|
14 |
+
std::string error_message = st.GetErrorMessage();
|
15 |
+
OrtErrorCode error_code = st.GetErrorCode();
|
16 |
+
ORT_CXX_API_THROW(std::move(error_message), error_code);
|
17 |
+
}
|
18 |
+
} // namespace detail
|
19 |
+
|
20 |
+
inline void ThrowOnError(OrtStatus* ort_status) {
|
21 |
+
if (ort_status) {
|
22 |
+
Ort::Status st(ort_status);
|
23 |
+
detail::ThrowStatus(st);
|
24 |
+
}
|
25 |
+
}
|
26 |
+
|
27 |
+
inline void ThrowOnError(const Status& st) {
|
28 |
+
if (st) {
|
29 |
+
detail::ThrowStatus(st);
|
30 |
+
}
|
31 |
+
}
|
32 |
+
|
33 |
+
inline Status::Status(OrtStatus* status) : Base<OrtStatus>{status} {
|
34 |
+
}
|
35 |
+
|
36 |
+
inline Status::Status(const std::exception& e) {
|
37 |
+
p_ = GetApi().CreateStatus(ORT_FAIL, e.what());
|
38 |
+
}
|
39 |
+
|
40 |
+
inline Status::Status(const Exception& e) {
|
41 |
+
p_ = GetApi().CreateStatus(e.GetOrtErrorCode(), e.what());
|
42 |
+
}
|
43 |
+
|
44 |
+
inline std::string Status::GetErrorMessage() const {
|
45 |
+
std::string message(GetApi().GetErrorMessage(p_));
|
46 |
+
return message;
|
47 |
+
}
|
48 |
+
|
49 |
+
inline OrtErrorCode Status::GetErrorCode() const {
|
50 |
+
return GetApi().GetErrorCode(p_);
|
51 |
+
}
|
52 |
+
|
53 |
+
// This template converts a C++ type into it's ONNXTensorElementDataType
|
54 |
+
template <typename T>
|
55 |
+
struct TypeToTensorType;
|
56 |
+
template <>
|
57 |
+
struct TypeToTensorType<float> {
|
58 |
+
static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
|
59 |
+
};
|
60 |
+
template <>
|
61 |
+
struct TypeToTensorType<Float16_t> {
|
62 |
+
static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16;
|
63 |
+
};
|
64 |
+
template <>
|
65 |
+
struct TypeToTensorType<BFloat16_t> {
|
66 |
+
static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16;
|
67 |
+
};
|
68 |
+
template <>
|
69 |
+
struct TypeToTensorType<double> {
|
70 |
+
static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE;
|
71 |
+
};
|
72 |
+
template <>
|
73 |
+
struct TypeToTensorType<int8_t> {
|
74 |
+
static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8;
|
75 |
+
};
|
76 |
+
template <>
|
77 |
+
struct TypeToTensorType<int16_t> {
|
78 |
+
static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16;
|
79 |
+
};
|
80 |
+
template <>
|
81 |
+
struct TypeToTensorType<int32_t> {
|
82 |
+
static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32;
|
83 |
+
};
|
84 |
+
template <>
|
85 |
+
struct TypeToTensorType<int64_t> {
|
86 |
+
static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
87 |
+
};
|
88 |
+
template <>
|
89 |
+
struct TypeToTensorType<uint8_t> {
|
90 |
+
static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
|
91 |
+
};
|
92 |
+
template <>
|
93 |
+
struct TypeToTensorType<uint16_t> {
|
94 |
+
static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16;
|
95 |
+
};
|
96 |
+
template <>
|
97 |
+
struct TypeToTensorType<uint32_t> {
|
98 |
+
static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32;
|
99 |
+
};
|
100 |
+
template <>
|
101 |
+
struct TypeToTensorType<uint64_t> {
|
102 |
+
static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64;
|
103 |
+
};
|
104 |
+
template <>
|
105 |
+
struct TypeToTensorType<bool> {
|
106 |
+
static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL;
|
107 |
+
};
|
108 |
+
|
109 |
+
inline MemoryAllocation::MemoryAllocation(OrtAllocator* allocator, void* p, size_t size)
|
110 |
+
: allocator_(allocator), p_(p), size_(size) {
|
111 |
+
}
|
112 |
+
|
113 |
+
inline MemoryAllocation::~MemoryAllocation() {
|
114 |
+
if (p_ != nullptr) {
|
115 |
+
// We do not throw out of destructor
|
116 |
+
auto ret = GetApi().AllocatorFree(allocator_, p_);
|
117 |
+
static_cast<void>(ret);
|
118 |
+
}
|
119 |
+
}
|
120 |
+
|
121 |
+
inline MemoryAllocation::MemoryAllocation(MemoryAllocation&& o) noexcept : allocator_(nullptr), p_(nullptr), size_(0) {
|
122 |
+
*this = std::move(o);
|
123 |
+
}
|
124 |
+
|
125 |
+
inline MemoryAllocation& MemoryAllocation::operator=(MemoryAllocation&& o) noexcept {
|
126 |
+
OrtAllocator* alloc = nullptr;
|
127 |
+
void* p = nullptr;
|
128 |
+
size_t sz = 0;
|
129 |
+
|
130 |
+
// Swap out this
|
131 |
+
std::swap(alloc, allocator_);
|
132 |
+
std::swap(p, p_);
|
133 |
+
std::swap(sz, size_);
|
134 |
+
|
135 |
+
// Swap with incoming
|
136 |
+
std::swap(allocator_, o.allocator_);
|
137 |
+
std::swap(p_, o.p_);
|
138 |
+
std::swap(size_, o.size_);
|
139 |
+
|
140 |
+
// Destroy this instance if needed
|
141 |
+
MemoryAllocation this_alloc(alloc, p, sz);
|
142 |
+
return *this;
|
143 |
+
}
|
144 |
+
|
145 |
+
namespace detail {
|
146 |
+
|
147 |
+
template <typename T>
|
148 |
+
inline void* AllocatorImpl<T>::Alloc(size_t size) {
|
149 |
+
void* out;
|
150 |
+
ThrowOnError(GetApi().AllocatorAlloc(this->p_, size, &out));
|
151 |
+
return out;
|
152 |
+
}
|
153 |
+
|
154 |
+
template <typename T>
|
155 |
+
inline MemoryAllocation AllocatorImpl<T>::GetAllocation(size_t size) {
|
156 |
+
void* out;
|
157 |
+
ThrowOnError(GetApi().AllocatorAlloc(this->p_, size, &out));
|
158 |
+
MemoryAllocation result(this->p_, out, size);
|
159 |
+
return result;
|
160 |
+
}
|
161 |
+
|
162 |
+
template <typename T>
|
163 |
+
inline void AllocatorImpl<T>::Free(void* p) {
|
164 |
+
ThrowOnError(GetApi().AllocatorFree(this->p_, p));
|
165 |
+
}
|
166 |
+
|
167 |
+
template <typename T>
|
168 |
+
inline ConstMemoryInfo AllocatorImpl<T>::GetInfo() const {
|
169 |
+
const OrtMemoryInfo* out;
|
170 |
+
ThrowOnError(GetApi().AllocatorGetInfo(this->p_, &out));
|
171 |
+
return ConstMemoryInfo{out};
|
172 |
+
}
|
173 |
+
|
174 |
+
} // namespace detail
|
175 |
+
|
176 |
+
inline AllocatorWithDefaultOptions::AllocatorWithDefaultOptions() {
|
177 |
+
ThrowOnError(GetApi().GetAllocatorWithDefaultOptions(&this->p_));
|
178 |
+
}
|
179 |
+
|
180 |
+
inline Allocator::Allocator(const Session& sess, const OrtMemoryInfo* mem_info) {
|
181 |
+
ThrowOnError(GetApi().CreateAllocator(sess, mem_info, &this->p_));
|
182 |
+
}
|
183 |
+
|
184 |
+
namespace detail {
|
185 |
+
|
186 |
+
template <typename T>
|
187 |
+
inline std::string MemoryInfoImpl<T>::GetAllocatorName() const {
|
188 |
+
const char* name = nullptr;
|
189 |
+
ThrowOnError(GetApi().MemoryInfoGetName(this->p_, &name));
|
190 |
+
return std::string(name);
|
191 |
+
}
|
192 |
+
|
193 |
+
template <typename T>
|
194 |
+
inline OrtAllocatorType MemoryInfoImpl<T>::GetAllocatorType() const {
|
195 |
+
OrtAllocatorType type;
|
196 |
+
ThrowOnError(GetApi().MemoryInfoGetType(this->p_, &type));
|
197 |
+
return type;
|
198 |
+
}
|
199 |
+
|
200 |
+
template <typename T>
|
201 |
+
inline int MemoryInfoImpl<T>::GetDeviceId() const {
|
202 |
+
int id = 0;
|
203 |
+
ThrowOnError(GetApi().MemoryInfoGetId(this->p_, &id));
|
204 |
+
return id;
|
205 |
+
}
|
206 |
+
|
207 |
+
template <typename T>
|
208 |
+
inline OrtMemoryInfoDeviceType MemoryInfoImpl<T>::GetDeviceType() const {
|
209 |
+
OrtMemoryInfoDeviceType type;
|
210 |
+
GetApi().MemoryInfoGetDeviceType(this->p_, &type);
|
211 |
+
return type;
|
212 |
+
}
|
213 |
+
|
214 |
+
template <typename T>
|
215 |
+
inline OrtMemType MemoryInfoImpl<T>::GetMemoryType() const {
|
216 |
+
OrtMemType type;
|
217 |
+
ThrowOnError(GetApi().MemoryInfoGetMemType(this->p_, &type));
|
218 |
+
return type;
|
219 |
+
}
|
220 |
+
|
221 |
+
template <typename T>
|
222 |
+
template <typename U>
|
223 |
+
inline bool MemoryInfoImpl<T>::operator==(const MemoryInfoImpl<U>& o) const {
|
224 |
+
int comp_result = 0;
|
225 |
+
ThrowOnError(Ort::GetApi().CompareMemoryInfo(this->p_, o, &comp_result));
|
226 |
+
return comp_result == 0;
|
227 |
+
}
|
228 |
+
|
229 |
+
} // namespace detail
|
230 |
+
|
231 |
+
inline MemoryInfo MemoryInfo::CreateCpu(OrtAllocatorType type, OrtMemType mem_type) {
|
232 |
+
OrtMemoryInfo* p;
|
233 |
+
ThrowOnError(GetApi().CreateCpuMemoryInfo(type, mem_type, &p));
|
234 |
+
return MemoryInfo(p);
|
235 |
+
}
|
236 |
+
|
237 |
+
inline MemoryInfo::MemoryInfo(const char* name, OrtAllocatorType type, int id, OrtMemType mem_type) {
|
238 |
+
ThrowOnError(GetApi().CreateMemoryInfo(name, type, id, mem_type, &this->p_));
|
239 |
+
}
|
240 |
+
|
241 |
+
namespace detail {
|
242 |
+
template <typename T>
|
243 |
+
inline std::vector<std::string> ConstIoBindingImpl<T>::GetOutputNames() const {
|
244 |
+
AllocatorWithDefaultOptions allocator;
|
245 |
+
return binding_utils::GetOutputNamesHelper(this->p_, allocator);
|
246 |
+
}
|
247 |
+
|
248 |
+
template <typename T>
|
249 |
+
inline std::vector<std::string> ConstIoBindingImpl<T>::GetOutputNames(OrtAllocator* allocator) const {
|
250 |
+
return binding_utils::GetOutputNamesHelper(this->p_, allocator);
|
251 |
+
}
|
252 |
+
|
253 |
+
template <typename T>
|
254 |
+
inline std::vector<Value> ConstIoBindingImpl<T>::GetOutputValues() const {
|
255 |
+
AllocatorWithDefaultOptions allocator;
|
256 |
+
return binding_utils::GetOutputValuesHelper(this->p_, allocator);
|
257 |
+
}
|
258 |
+
|
259 |
+
template <typename T>
|
260 |
+
inline std::vector<Value> ConstIoBindingImpl<T>::GetOutputValues(OrtAllocator* allocator) const {
|
261 |
+
return binding_utils::GetOutputValuesHelper(this->p_, allocator);
|
262 |
+
}
|
263 |
+
|
264 |
+
template <typename T>
|
265 |
+
inline void IoBindingImpl<T>::BindInput(const char* name, const Value& value) {
|
266 |
+
ThrowOnError(GetApi().BindInput(this->p_, name, value));
|
267 |
+
}
|
268 |
+
|
269 |
+
template <typename T>
|
270 |
+
inline void IoBindingImpl<T>::BindOutput(const char* name, const Value& value) {
|
271 |
+
ThrowOnError(GetApi().BindOutput(this->p_, name, value));
|
272 |
+
}
|
273 |
+
|
274 |
+
template <typename T>
|
275 |
+
inline void IoBindingImpl<T>::BindOutput(const char* name, const OrtMemoryInfo* mem_info) {
|
276 |
+
ThrowOnError(GetApi().BindOutputToDevice(this->p_, name, mem_info));
|
277 |
+
}
|
278 |
+
|
279 |
+
template <typename T>
|
280 |
+
inline void IoBindingImpl<T>::ClearBoundInputs() {
|
281 |
+
GetApi().ClearBoundInputs(this->p_);
|
282 |
+
}
|
283 |
+
|
284 |
+
template <typename T>
|
285 |
+
inline void IoBindingImpl<T>::ClearBoundOutputs() {
|
286 |
+
GetApi().ClearBoundOutputs(this->p_);
|
287 |
+
}
|
288 |
+
|
289 |
+
template <typename T>
|
290 |
+
inline void IoBindingImpl<T>::SynchronizeInputs() {
|
291 |
+
ThrowOnError(GetApi().SynchronizeBoundInputs(this->p_));
|
292 |
+
}
|
293 |
+
|
294 |
+
template <typename T>
|
295 |
+
inline void IoBindingImpl<T>::SynchronizeOutputs() {
|
296 |
+
ThrowOnError(GetApi().SynchronizeBoundOutputs(this->p_));
|
297 |
+
}
|
298 |
+
|
299 |
+
namespace binding_utils {
|
300 |
+
inline std::vector<std::string> GetOutputNamesHelper(const OrtIoBinding* binding, OrtAllocator* allocator) {
|
301 |
+
std::vector<std::string> result;
|
302 |
+
auto free_fn = detail::AllocatedFree(allocator);
|
303 |
+
using Ptr = std::unique_ptr<void, decltype(free_fn)>;
|
304 |
+
|
305 |
+
char* buffer = nullptr;
|
306 |
+
size_t* lengths = nullptr;
|
307 |
+
size_t count = 0;
|
308 |
+
ThrowOnError(GetApi().GetBoundOutputNames(binding, allocator, &buffer, &lengths, &count));
|
309 |
+
|
310 |
+
if (count == 0) {
|
311 |
+
return result;
|
312 |
+
}
|
313 |
+
|
314 |
+
Ptr buffer_g(buffer, free_fn);
|
315 |
+
Ptr lengths_g(lengths, free_fn);
|
316 |
+
|
317 |
+
result.reserve(count);
|
318 |
+
for (size_t i = 0; i < count; ++i) {
|
319 |
+
auto sz = *lengths;
|
320 |
+
result.emplace_back(buffer, sz);
|
321 |
+
buffer += sz;
|
322 |
+
++lengths;
|
323 |
+
}
|
324 |
+
return result;
|
325 |
+
}
|
326 |
+
|
327 |
+
inline std::vector<Value> GetOutputValuesHelper(const OrtIoBinding* binding, OrtAllocator* allocator) {
|
328 |
+
std::vector<Value> result;
|
329 |
+
size_t owned = 0;
|
330 |
+
size_t output_count = 0;
|
331 |
+
// Lambda to release the buffer when no longer needed and
|
332 |
+
// make sure that we destroy all instances on exception
|
333 |
+
auto free_fn = [&owned, &output_count, allocator](OrtValue** buffer) {
|
334 |
+
if (buffer) {
|
335 |
+
while (owned < output_count) {
|
336 |
+
auto* p = buffer + owned++;
|
337 |
+
GetApi().ReleaseValue(*p);
|
338 |
+
}
|
339 |
+
allocator->Free(allocator, buffer);
|
340 |
+
}
|
341 |
+
};
|
342 |
+
using Ptr = std::unique_ptr<OrtValue*, decltype(free_fn)>;
|
343 |
+
|
344 |
+
OrtValue** output_buffer = nullptr;
|
345 |
+
ThrowOnError(GetApi().GetBoundOutputValues(binding, allocator, &output_buffer, &output_count));
|
346 |
+
if (output_count == 0) {
|
347 |
+
return result;
|
348 |
+
}
|
349 |
+
|
350 |
+
Ptr buffer_g(output_buffer, free_fn);
|
351 |
+
|
352 |
+
result.reserve(output_count);
|
353 |
+
for (size_t i = 0; i < output_count; ++i) {
|
354 |
+
result.emplace_back(output_buffer[i]);
|
355 |
+
++owned;
|
356 |
+
}
|
357 |
+
return result;
|
358 |
+
}
|
359 |
+
|
360 |
+
} // namespace binding_utils
|
361 |
+
} // namespace detail
|
362 |
+
|
363 |
+
inline IoBinding::IoBinding(Session& session) {
|
364 |
+
ThrowOnError(GetApi().CreateIoBinding(session, &this->p_));
|
365 |
+
}
|
366 |
+
|
367 |
+
inline ArenaCfg::ArenaCfg(size_t max_mem, int arena_extend_strategy, int initial_chunk_size_bytes, int max_dead_bytes_per_chunk) {
|
368 |
+
ThrowOnError(GetApi().CreateArenaCfg(max_mem, arena_extend_strategy, initial_chunk_size_bytes, max_dead_bytes_per_chunk, &p_));
|
369 |
+
}
|
370 |
+
|
371 |
+
inline ThreadingOptions::ThreadingOptions() {
|
372 |
+
ThrowOnError(GetApi().CreateThreadingOptions(&p_));
|
373 |
+
}
|
374 |
+
|
375 |
+
inline ThreadingOptions& ThreadingOptions::SetGlobalIntraOpNumThreads(int intra_op_num_threads) {
|
376 |
+
ThrowOnError(GetApi().SetGlobalIntraOpNumThreads(p_, intra_op_num_threads));
|
377 |
+
return *this;
|
378 |
+
}
|
379 |
+
|
380 |
+
inline ThreadingOptions& ThreadingOptions::SetGlobalInterOpNumThreads(int inter_op_num_threads) {
|
381 |
+
ThrowOnError(GetApi().SetGlobalInterOpNumThreads(p_, inter_op_num_threads));
|
382 |
+
return *this;
|
383 |
+
}
|
384 |
+
|
385 |
+
inline ThreadingOptions& ThreadingOptions::SetGlobalSpinControl(int allow_spinning) {
|
386 |
+
ThrowOnError(GetApi().SetGlobalSpinControl(p_, allow_spinning));
|
387 |
+
return *this;
|
388 |
+
}
|
389 |
+
|
390 |
+
inline ThreadingOptions& ThreadingOptions::SetGlobalDenormalAsZero() {
|
391 |
+
ThrowOnError(GetApi().SetGlobalDenormalAsZero(p_));
|
392 |
+
return *this;
|
393 |
+
}
|
394 |
+
|
395 |
+
inline ThreadingOptions& ThreadingOptions::SetGlobalCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn) {
|
396 |
+
ThrowOnError(GetApi().SetGlobalCustomCreateThreadFn(p_, ort_custom_create_thread_fn));
|
397 |
+
return *this;
|
398 |
+
}
|
399 |
+
|
400 |
+
inline ThreadingOptions& ThreadingOptions::SetGlobalCustomThreadCreationOptions(void* ort_custom_thread_creation_options) {
|
401 |
+
ThrowOnError(GetApi().SetGlobalCustomThreadCreationOptions(p_, ort_custom_thread_creation_options));
|
402 |
+
return *this;
|
403 |
+
}
|
404 |
+
|
405 |
+
inline ThreadingOptions& ThreadingOptions::SetGlobalCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn) {
|
406 |
+
ThrowOnError(GetApi().SetGlobalCustomJoinThreadFn(p_, ort_custom_join_thread_fn));
|
407 |
+
return *this;
|
408 |
+
}
|
409 |
+
|
410 |
+
inline Env::Env(OrtLoggingLevel logging_level, _In_ const char* logid) {
|
411 |
+
ThrowOnError(GetApi().CreateEnv(logging_level, logid, &p_));
|
412 |
+
if (strcmp(logid, "onnxruntime-node") == 0) {
|
413 |
+
ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_NODEJS));
|
414 |
+
} else {
|
415 |
+
ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_CPLUSPLUS));
|
416 |
+
}
|
417 |
+
}
|
418 |
+
|
419 |
+
inline Env::Env(OrtLoggingLevel logging_level, const char* logid, OrtLoggingFunction logging_function, void* logger_param) {
|
420 |
+
ThrowOnError(GetApi().CreateEnvWithCustomLogger(logging_function, logger_param, logging_level, logid, &p_));
|
421 |
+
if (strcmp(logid, "onnxruntime-node") == 0) {
|
422 |
+
ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_NODEJS));
|
423 |
+
} else {
|
424 |
+
ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_CPLUSPLUS));
|
425 |
+
}
|
426 |
+
}
|
427 |
+
|
428 |
+
inline Env::Env(const OrtThreadingOptions* tp_options, OrtLoggingLevel logging_level, _In_ const char* logid) {
|
429 |
+
ThrowOnError(GetApi().CreateEnvWithGlobalThreadPools(logging_level, logid, tp_options, &p_));
|
430 |
+
if (strcmp(logid, "onnxruntime-node") == 0) {
|
431 |
+
ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_NODEJS));
|
432 |
+
} else {
|
433 |
+
ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_CPLUSPLUS));
|
434 |
+
}
|
435 |
+
}
|
436 |
+
|
437 |
+
inline Env::Env(const OrtThreadingOptions* tp_options, OrtLoggingFunction logging_function, void* logger_param,
|
438 |
+
OrtLoggingLevel logging_level, _In_ const char* logid) {
|
439 |
+
ThrowOnError(GetApi().CreateEnvWithCustomLoggerAndGlobalThreadPools(logging_function, logger_param, logging_level, logid, tp_options, &p_));
|
440 |
+
if (strcmp(logid, "onnxruntime-node") == 0) {
|
441 |
+
ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_NODEJS));
|
442 |
+
} else {
|
443 |
+
ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_CPLUSPLUS));
|
444 |
+
}
|
445 |
+
}
|
446 |
+
|
447 |
+
inline Env& Env::EnableTelemetryEvents() {
|
448 |
+
ThrowOnError(GetApi().EnableTelemetryEvents(p_));
|
449 |
+
return *this;
|
450 |
+
}
|
451 |
+
|
452 |
+
inline Env& Env::DisableTelemetryEvents() {
|
453 |
+
ThrowOnError(GetApi().DisableTelemetryEvents(p_));
|
454 |
+
return *this;
|
455 |
+
}
|
456 |
+
|
457 |
+
inline Env& Env::UpdateEnvWithCustomLogLevel(OrtLoggingLevel log_severity_level) {
|
458 |
+
ThrowOnError(GetApi().UpdateEnvWithCustomLogLevel(p_, log_severity_level));
|
459 |
+
return *this;
|
460 |
+
}
|
461 |
+
|
462 |
+
inline Env& Env::CreateAndRegisterAllocator(const OrtMemoryInfo* mem_info, const OrtArenaCfg* arena_cfg) {
|
463 |
+
ThrowOnError(GetApi().CreateAndRegisterAllocator(p_, mem_info, arena_cfg));
|
464 |
+
return *this;
|
465 |
+
}
|
466 |
+
|
467 |
+
inline CustomOpDomain::CustomOpDomain(const char* domain) {
|
468 |
+
ThrowOnError(GetApi().CreateCustomOpDomain(domain, &p_));
|
469 |
+
}
|
470 |
+
|
471 |
+
inline void CustomOpDomain::Add(const OrtCustomOp* op) {
|
472 |
+
ThrowOnError(GetApi().CustomOpDomain_Add(p_, op));
|
473 |
+
}
|
474 |
+
|
475 |
+
inline RunOptions::RunOptions() {
|
476 |
+
ThrowOnError(GetApi().CreateRunOptions(&p_));
|
477 |
+
}
|
478 |
+
|
479 |
+
inline RunOptions& RunOptions::SetRunLogVerbosityLevel(int level) {
|
480 |
+
ThrowOnError(GetApi().RunOptionsSetRunLogVerbosityLevel(p_, level));
|
481 |
+
return *this;
|
482 |
+
}
|
483 |
+
|
484 |
+
inline RunOptions& RunOptions::SetRunLogSeverityLevel(int level) {
|
485 |
+
ThrowOnError(GetApi().RunOptionsSetRunLogSeverityLevel(p_, level));
|
486 |
+
return *this;
|
487 |
+
}
|
488 |
+
|
489 |
+
inline int RunOptions::GetRunLogVerbosityLevel() const {
|
490 |
+
int out;
|
491 |
+
ThrowOnError(GetApi().RunOptionsGetRunLogVerbosityLevel(p_, &out));
|
492 |
+
return out;
|
493 |
+
}
|
494 |
+
|
495 |
+
inline int RunOptions::GetRunLogSeverityLevel() const {
|
496 |
+
int out;
|
497 |
+
ThrowOnError(GetApi().RunOptionsGetRunLogSeverityLevel(p_, &out));
|
498 |
+
return out;
|
499 |
+
}
|
500 |
+
|
501 |
+
inline RunOptions& RunOptions::SetRunTag(const char* run_tag) {
|
502 |
+
ThrowOnError(GetApi().RunOptionsSetRunTag(p_, run_tag));
|
503 |
+
return *this;
|
504 |
+
}
|
505 |
+
|
506 |
+
inline const char* RunOptions::GetRunTag() const {
|
507 |
+
const char* out;
|
508 |
+
ThrowOnError(GetApi().RunOptionsGetRunTag(p_, &out));
|
509 |
+
return out;
|
510 |
+
}
|
511 |
+
|
512 |
+
inline RunOptions& RunOptions::AddConfigEntry(const char* config_key, const char* config_value) {
|
513 |
+
ThrowOnError(GetApi().AddRunConfigEntry(p_, config_key, config_value));
|
514 |
+
return *this;
|
515 |
+
}
|
516 |
+
|
517 |
+
inline RunOptions& RunOptions::SetTerminate() {
|
518 |
+
ThrowOnError(GetApi().RunOptionsSetTerminate(p_));
|
519 |
+
return *this;
|
520 |
+
}
|
521 |
+
|
522 |
+
inline RunOptions& RunOptions::UnsetTerminate() {
|
523 |
+
ThrowOnError(GetApi().RunOptionsUnsetTerminate(p_));
|
524 |
+
return *this;
|
525 |
+
}
|
526 |
+
|
527 |
+
namespace detail {
|
528 |
+
|
529 |
+
template <typename T>
|
530 |
+
inline Ort::SessionOptions ConstSessionOptionsImpl<T>::Clone() const {
|
531 |
+
OrtSessionOptions* out;
|
532 |
+
ThrowOnError(GetApi().CloneSessionOptions(this->p_, &out));
|
533 |
+
return SessionOptions{out};
|
534 |
+
}
|
535 |
+
|
536 |
+
template <typename T>
|
537 |
+
inline std::string ConstSessionOptionsImpl<T>::GetConfigEntry(const char* config_key) const {
|
538 |
+
size_t size = 0;
|
539 |
+
// Feed nullptr for the data buffer to query the true size of the string value
|
540 |
+
Ort::ThrowOnError(GetApi().GetSessionConfigEntry(this->p_, config_key, nullptr, &size));
|
541 |
+
|
542 |
+
std::string out;
|
543 |
+
out.resize(size);
|
544 |
+
Ort::ThrowOnError(GetApi().GetSessionConfigEntry(this->p_, config_key, &out[0], &size));
|
545 |
+
out.resize(size - 1); // remove the terminating character '\0'
|
546 |
+
|
547 |
+
return out;
|
548 |
+
}
|
549 |
+
|
550 |
+
template <typename T>
|
551 |
+
inline bool ConstSessionOptionsImpl<T>::HasConfigEntry(const char* config_key) const {
|
552 |
+
int out = 0;
|
553 |
+
Ort::ThrowOnError(GetApi().HasSessionConfigEntry(this->p_, config_key, &out));
|
554 |
+
return static_cast<bool>(out);
|
555 |
+
}
|
556 |
+
|
557 |
+
template <typename T>
|
558 |
+
inline std::string ConstSessionOptionsImpl<T>::GetConfigEntryOrDefault(const char* config_key, const std::string& def) {
|
559 |
+
if (!this->HasConfigEntry(config_key)) {
|
560 |
+
return def;
|
561 |
+
}
|
562 |
+
|
563 |
+
return this->GetConfigEntry(config_key);
|
564 |
+
}
|
565 |
+
|
566 |
+
template <typename T>
|
567 |
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetIntraOpNumThreads(int intra_op_num_threads) {
|
568 |
+
ThrowOnError(GetApi().SetIntraOpNumThreads(this->p_, intra_op_num_threads));
|
569 |
+
return *this;
|
570 |
+
}
|
571 |
+
|
572 |
+
template <typename T>
|
573 |
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetInterOpNumThreads(int inter_op_num_threads) {
|
574 |
+
ThrowOnError(GetApi().SetInterOpNumThreads(this->p_, inter_op_num_threads));
|
575 |
+
return *this;
|
576 |
+
}
|
577 |
+
|
578 |
+
template <typename T>
|
579 |
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level) {
|
580 |
+
ThrowOnError(GetApi().SetSessionGraphOptimizationLevel(this->p_, graph_optimization_level));
|
581 |
+
return *this;
|
582 |
+
}
|
583 |
+
|
584 |
+
template <typename T>
|
585 |
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetOptimizedModelFilePath(const ORTCHAR_T* optimized_model_filepath) {
|
586 |
+
ThrowOnError(GetApi().SetOptimizedModelFilePath(this->p_, optimized_model_filepath));
|
587 |
+
return *this;
|
588 |
+
}
|
589 |
+
|
590 |
+
template <typename T>
|
591 |
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::EnableProfiling(const ORTCHAR_T* profile_file_prefix) {
|
592 |
+
ThrowOnError(GetApi().EnableProfiling(this->p_, profile_file_prefix));
|
593 |
+
return *this;
|
594 |
+
}
|
595 |
+
|
596 |
+
template <typename T>
|
597 |
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::DisableProfiling() {
|
598 |
+
ThrowOnError(GetApi().DisableProfiling(this->p_));
|
599 |
+
return *this;
|
600 |
+
}
|
601 |
+
|
602 |
+
template <typename T>
|
603 |
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::EnableOrtCustomOps() {
|
604 |
+
ThrowOnError(GetApi().EnableOrtCustomOps(this->p_));
|
605 |
+
return *this;
|
606 |
+
}
|
607 |
+
|
608 |
+
template <typename T>
|
609 |
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::EnableMemPattern() {
|
610 |
+
ThrowOnError(GetApi().EnableMemPattern(this->p_));
|
611 |
+
return *this;
|
612 |
+
}
|
613 |
+
|
614 |
+
template <typename T>
|
615 |
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::DisableMemPattern() {
|
616 |
+
ThrowOnError(GetApi().DisableMemPattern(this->p_));
|
617 |
+
return *this;
|
618 |
+
}
|
619 |
+
|
620 |
+
template <typename T>
|
621 |
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::EnableCpuMemArena() {
|
622 |
+
ThrowOnError(GetApi().EnableCpuMemArena(this->p_));
|
623 |
+
return *this;
|
624 |
+
}
|
625 |
+
|
626 |
+
template <typename T>
|
627 |
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::DisableCpuMemArena() {
|
628 |
+
ThrowOnError(GetApi().DisableCpuMemArena(this->p_));
|
629 |
+
return *this;
|
630 |
+
}
|
631 |
+
|
632 |
+
template <typename T>
|
633 |
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetExecutionMode(ExecutionMode execution_mode) {
|
634 |
+
ThrowOnError(GetApi().SetSessionExecutionMode(this->p_, execution_mode));
|
635 |
+
return *this;
|
636 |
+
}
|
637 |
+
|
638 |
+
template <typename T>
|
639 |
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetLogId(const char* logid) {
|
640 |
+
ThrowOnError(GetApi().SetSessionLogId(this->p_, logid));
|
641 |
+
return *this;
|
642 |
+
}
|
643 |
+
|
644 |
+
template <typename T>
|
645 |
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetLogSeverityLevel(int level) {
|
646 |
+
ThrowOnError(GetApi().SetSessionLogSeverityLevel(this->p_, level));
|
647 |
+
return *this;
|
648 |
+
}
|
649 |
+
|
650 |
+
template <typename T>
|
651 |
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::Add(OrtCustomOpDomain* custom_op_domain) {
|
652 |
+
ThrowOnError(GetApi().AddCustomOpDomain(this->p_, custom_op_domain));
|
653 |
+
return *this;
|
654 |
+
}
|
655 |
+
|
656 |
+
template <typename T>
|
657 |
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AddConfigEntry(const char* config_key, const char* config_value) {
|
658 |
+
ThrowOnError(GetApi().AddSessionConfigEntry(this->p_, config_key, config_value));
|
659 |
+
return *this;
|
660 |
+
}
|
661 |
+
|
662 |
+
template <typename T>
|
663 |
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AddInitializer(const char* name, const OrtValue* ort_val) {
|
664 |
+
ThrowOnError(GetApi().AddInitializer(this->p_, name, ort_val));
|
665 |
+
return *this;
|
666 |
+
}
|
667 |
+
|
668 |
+
template <typename T>
|
669 |
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::DisablePerSessionThreads() {
|
670 |
+
ThrowOnError(GetApi().DisablePerSessionThreads(this->p_));
|
671 |
+
return *this;
|
672 |
+
}
|
673 |
+
|
674 |
+
template <typename T>
|
675 |
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AddExternalInitializers(const std::vector<std::string>& names,
|
676 |
+
const std::vector<Value>& ort_values) {
|
677 |
+
const size_t inputs_num = names.size();
|
678 |
+
if (inputs_num != ort_values.size()) {
|
679 |
+
ORT_CXX_API_THROW("Expecting names and ort_values to have the same length", ORT_INVALID_ARGUMENT);
|
680 |
+
}
|
681 |
+
std::vector<const char*> names_ptr;
|
682 |
+
std::vector<const OrtValue*> ort_values_ptrs;
|
683 |
+
names_ptr.reserve(inputs_num);
|
684 |
+
ort_values_ptrs.reserve(inputs_num);
|
685 |
+
for (size_t i = 0; i < inputs_num; ++i) {
|
686 |
+
names_ptr.push_back(names[i].c_str());
|
687 |
+
ort_values_ptrs.push_back(ort_values[i]);
|
688 |
+
}
|
689 |
+
ThrowOnError(GetApi().AddExternalInitializers(this->p_, names_ptr.data(), ort_values_ptrs.data(), inputs_num));
|
690 |
+
return *this;
|
691 |
+
}
|
692 |
+
|
693 |
+
template <typename T>
|
694 |
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_CUDA(const OrtCUDAProviderOptions& provider_options) {
|
695 |
+
ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_CUDA(this->p_, &provider_options));
|
696 |
+
return *this;
|
697 |
+
}
|
698 |
+
|
699 |
+
template <typename T>
|
700 |
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_CUDA_V2(const OrtCUDAProviderOptionsV2& provider_options) {
|
701 |
+
ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_CUDA_V2(this->p_, &provider_options));
|
702 |
+
return *this;
|
703 |
+
}
|
704 |
+
|
705 |
+
template <typename T>
|
706 |
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_ROCM(const OrtROCMProviderOptions& provider_options) {
|
707 |
+
ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_ROCM(this->p_, &provider_options));
|
708 |
+
return *this;
|
709 |
+
}
|
710 |
+
|
711 |
+
template <typename T>
|
712 |
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_TensorRT(const OrtTensorRTProviderOptions& provider_options) {
|
713 |
+
ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_TensorRT(this->p_, &provider_options));
|
714 |
+
return *this;
|
715 |
+
}
|
716 |
+
|
717 |
+
template <typename T>
|
718 |
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_TensorRT_V2(const OrtTensorRTProviderOptionsV2& provider_options) {
|
719 |
+
ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_TensorRT_V2(this->p_, &provider_options));
|
720 |
+
return *this;
|
721 |
+
}
|
722 |
+
|
723 |
+
template <typename T>
|
724 |
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_MIGraphX(const OrtMIGraphXProviderOptions& provider_options) {
|
725 |
+
ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_MIGraphX(this->p_, &provider_options));
|
726 |
+
return *this;
|
727 |
+
}
|
728 |
+
|
729 |
+
template <typename T>
|
730 |
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_CANN(const OrtCANNProviderOptions& provider_options) {
|
731 |
+
ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_CANN(this->p_, &provider_options));
|
732 |
+
return *this;
|
733 |
+
}
|
734 |
+
|
735 |
+
template <typename T>
|
736 |
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_Dnnl(const OrtDnnlProviderOptions& provider_options) {
|
737 |
+
ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_Dnnl(this->p_, &provider_options));
|
738 |
+
return *this;
|
739 |
+
}
|
740 |
+
|
741 |
+
template <typename T>
|
742 |
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider(
|
743 |
+
const std::string& provider_name,
|
744 |
+
const std::unordered_map<std::string, std::string>& provider_options) {
|
745 |
+
auto num_entries = provider_options.size();
|
746 |
+
std::vector<const char*> keys, values;
|
747 |
+
if (num_entries > 0) {
|
748 |
+
keys.reserve(num_entries);
|
749 |
+
values.reserve(num_entries);
|
750 |
+
|
751 |
+
for (const auto& entry : provider_options) {
|
752 |
+
keys.push_back(entry.first.c_str());
|
753 |
+
values.push_back(entry.second.c_str());
|
754 |
+
}
|
755 |
+
}
|
756 |
+
|
757 |
+
ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider(this->p_, provider_name.c_str(),
|
758 |
+
keys.data(), values.data(), num_entries));
|
759 |
+
|
760 |
+
return *this;
|
761 |
+
}
|
762 |
+
|
763 |
+
template <typename T>
|
764 |
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn) {
|
765 |
+
ThrowOnError(GetApi().SessionOptionsSetCustomCreateThreadFn(this->p_, ort_custom_create_thread_fn));
|
766 |
+
return *this;
|
767 |
+
}
|
768 |
+
|
769 |
+
template <typename T>
|
770 |
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetCustomThreadCreationOptions(void* ort_custom_thread_creation_options) {
|
771 |
+
ThrowOnError(GetApi().SessionOptionsSetCustomThreadCreationOptions(this->p_, ort_custom_thread_creation_options));
|
772 |
+
return *this;
|
773 |
+
}
|
774 |
+
|
775 |
+
template <typename T>
|
776 |
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn) {
|
777 |
+
ThrowOnError(GetApi().SessionOptionsSetCustomJoinThreadFn(this->p_, ort_custom_join_thread_fn));
|
778 |
+
return *this;
|
779 |
+
}
|
780 |
+
|
781 |
+
template <typename T>
|
782 |
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_OpenVINO(const OrtOpenVINOProviderOptions& provider_options) {
|
783 |
+
ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_OpenVINO(this->p_, &provider_options));
|
784 |
+
return *this;
|
785 |
+
}
|
786 |
+
|
787 |
+
template <typename T>
|
788 |
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::RegisterCustomOpsLibrary(const ORTCHAR_T* library_name,
|
789 |
+
const CustomOpConfigs& custom_op_configs) {
|
790 |
+
// Add custom op config entries before registering the custom op library. Otherwise, the config entries _may_ be ignored by
|
791 |
+
// the custom op library.
|
792 |
+
for (const auto& config_iter : custom_op_configs.GetFlattenedConfigs()) {
|
793 |
+
AddConfigEntry(config_iter.first.c_str(), config_iter.second.c_str());
|
794 |
+
}
|
795 |
+
|
796 |
+
ThrowOnError(GetApi().RegisterCustomOpsLibrary_V2(this->p_, library_name));
|
797 |
+
return *this;
|
798 |
+
}
|
799 |
+
|
800 |
+
template <typename T>
|
801 |
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::RegisterCustomOpsUsingFunction(const char* registration_function_name) {
|
802 |
+
ThrowOnError(GetApi().RegisterCustomOpsUsingFunction(this->p_, registration_function_name));
|
803 |
+
return *this;
|
804 |
+
}
|
805 |
+
|
806 |
+
/// Session
|
807 |
+
template <typename T>
|
808 |
+
inline size_t ConstSessionImpl<T>::GetInputCount() const {
|
809 |
+
size_t out;
|
810 |
+
ThrowOnError(GetApi().SessionGetInputCount(this->p_, &out));
|
811 |
+
return out;
|
812 |
+
}
|
813 |
+
|
814 |
+
template <typename T>
|
815 |
+
inline size_t ConstSessionImpl<T>::GetOutputCount() const {
|
816 |
+
size_t out;
|
817 |
+
ThrowOnError(GetApi().SessionGetOutputCount(this->p_, &out));
|
818 |
+
return out;
|
819 |
+
}
|
820 |
+
|
821 |
+
template <typename T>
|
822 |
+
inline size_t ConstSessionImpl<T>::GetOverridableInitializerCount() const {
|
823 |
+
size_t out;
|
824 |
+
ThrowOnError(GetApi().SessionGetOverridableInitializerCount(this->p_, &out));
|
825 |
+
return out;
|
826 |
+
}
|
827 |
+
|
828 |
+
template <typename T>
|
829 |
+
inline AllocatedStringPtr ConstSessionImpl<T>::GetInputNameAllocated(size_t index, OrtAllocator* allocator) const {
|
830 |
+
char* out;
|
831 |
+
ThrowOnError(GetApi().SessionGetInputName(this->p_, index, allocator, &out));
|
832 |
+
return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
|
833 |
+
}
|
834 |
+
|
835 |
+
template <typename T>
|
836 |
+
inline AllocatedStringPtr ConstSessionImpl<T>::GetOutputNameAllocated(size_t index, OrtAllocator* allocator) const {
|
837 |
+
char* out;
|
838 |
+
ThrowOnError(GetApi().SessionGetOutputName(this->p_, index, allocator, &out));
|
839 |
+
return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
|
840 |
+
}
|
841 |
+
|
842 |
+
template <typename T>
|
843 |
+
inline AllocatedStringPtr ConstSessionImpl<T>::GetOverridableInitializerNameAllocated(size_t index, OrtAllocator* allocator) const {
|
844 |
+
char* out;
|
845 |
+
ThrowOnError(GetApi().SessionGetOverridableInitializerName(this->p_, index, allocator, &out));
|
846 |
+
return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
|
847 |
+
}
|
848 |
+
|
849 |
+
template <typename T>
|
850 |
+
inline uint64_t ConstSessionImpl<T>::GetProfilingStartTimeNs() const {
|
851 |
+
uint64_t out;
|
852 |
+
ThrowOnError(GetApi().SessionGetProfilingStartTimeNs(this->p_, &out));
|
853 |
+
return out;
|
854 |
+
}
|
855 |
+
|
856 |
+
template <typename T>
|
857 |
+
inline ModelMetadata ConstSessionImpl<T>::GetModelMetadata() const {
|
858 |
+
OrtModelMetadata* out;
|
859 |
+
ThrowOnError(GetApi().SessionGetModelMetadata(this->p_, &out));
|
860 |
+
return ModelMetadata{out};
|
861 |
+
}
|
862 |
+
|
863 |
+
template <typename T>
|
864 |
+
inline TypeInfo ConstSessionImpl<T>::GetInputTypeInfo(size_t index) const {
|
865 |
+
OrtTypeInfo* out;
|
866 |
+
ThrowOnError(GetApi().SessionGetInputTypeInfo(this->p_, index, &out));
|
867 |
+
return TypeInfo{out};
|
868 |
+
}
|
869 |
+
|
870 |
+
template <typename T>
|
871 |
+
inline TypeInfo ConstSessionImpl<T>::GetOutputTypeInfo(size_t index) const {
|
872 |
+
OrtTypeInfo* out;
|
873 |
+
ThrowOnError(GetApi().SessionGetOutputTypeInfo(this->p_, index, &out));
|
874 |
+
return TypeInfo{out};
|
875 |
+
}
|
876 |
+
|
877 |
+
template <typename T>
|
878 |
+
inline TypeInfo ConstSessionImpl<T>::GetOverridableInitializerTypeInfo(size_t index) const {
|
879 |
+
OrtTypeInfo* out;
|
880 |
+
ThrowOnError(GetApi().SessionGetOverridableInitializerTypeInfo(this->p_, index, &out));
|
881 |
+
return TypeInfo{out};
|
882 |
+
}
|
883 |
+
|
884 |
+
template <typename T>
|
885 |
+
inline std::vector<Value> SessionImpl<T>::Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
|
886 |
+
const char* const* output_names, size_t output_count) {
|
887 |
+
std::vector<Value> output_values;
|
888 |
+
output_values.reserve(output_count);
|
889 |
+
for (size_t i = 0; i < output_count; i++)
|
890 |
+
output_values.emplace_back(nullptr);
|
891 |
+
Run(run_options, input_names, input_values, input_count, output_names, output_values.data(), output_count);
|
892 |
+
return output_values;
|
893 |
+
}
|
894 |
+
|
895 |
+
template <typename T>
|
896 |
+
inline void SessionImpl<T>::Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
|
897 |
+
const char* const* output_names, Value* output_values, size_t output_count) {
|
898 |
+
static_assert(sizeof(Value) == sizeof(OrtValue*), "Value is really just an array of OrtValue* in memory, so we can reinterpret_cast safely");
|
899 |
+
auto ort_input_values = reinterpret_cast<const OrtValue* const*>(input_values);
|
900 |
+
auto ort_output_values = reinterpret_cast<OrtValue**>(output_values);
|
901 |
+
ThrowOnError(GetApi().Run(this->p_, run_options, input_names, ort_input_values, input_count, output_names, output_count, ort_output_values));
|
902 |
+
}
|
903 |
+
|
904 |
+
template <typename T>
|
905 |
+
inline void SessionImpl<T>::Run(const RunOptions& run_options, const IoBinding& io_binding) {
|
906 |
+
ThrowOnError(GetApi().RunWithBinding(this->p_, run_options, io_binding));
|
907 |
+
}
|
908 |
+
|
909 |
+
template <typename T>
|
910 |
+
inline AllocatedStringPtr SessionImpl<T>::EndProfilingAllocated(OrtAllocator* allocator) {
|
911 |
+
char* out = nullptr;
|
912 |
+
ThrowOnError(GetApi().SessionEndProfiling(this->p_, allocator, &out));
|
913 |
+
return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
|
914 |
+
}
|
915 |
+
|
916 |
+
} // namespace detail
|
917 |
+
|
918 |
+
inline SessionOptions::SessionOptions() {
|
919 |
+
ThrowOnError(GetApi().CreateSessionOptions(&this->p_));
|
920 |
+
}
|
921 |
+
|
922 |
+
/// CustomOpConfigs
|
923 |
+
inline std::string detail::MakeCustomOpConfigEntryKey(const char* custom_op_name, const char* config) {
|
924 |
+
std::string config_key = "custom_op.";
|
925 |
+
|
926 |
+
config_key += custom_op_name;
|
927 |
+
config_key += ".";
|
928 |
+
config_key += config;
|
929 |
+
|
930 |
+
return config_key;
|
931 |
+
}
|
932 |
+
|
933 |
+
inline CustomOpConfigs& CustomOpConfigs::AddConfig(const char* custom_op_name, const char* config_key, const char* config_value) {
|
934 |
+
const std::string full_flat_key = detail::MakeCustomOpConfigEntryKey(custom_op_name, config_key);
|
935 |
+
flat_configs_[full_flat_key] = config_value;
|
936 |
+
return *this;
|
937 |
+
}
|
938 |
+
|
939 |
+
inline const std::unordered_map<std::string, std::string>& CustomOpConfigs::GetFlattenedConfigs() const {
|
940 |
+
return flat_configs_;
|
941 |
+
}
|
942 |
+
|
943 |
+
inline Session::Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options) {
|
944 |
+
ThrowOnError(GetApi().CreateSession(env, model_path, options, &this->p_));
|
945 |
+
}
|
946 |
+
|
947 |
+
inline Session::Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options,
|
948 |
+
OrtPrepackedWeightsContainer* prepacked_weights_container) {
|
949 |
+
ThrowOnError(GetApi().CreateSessionWithPrepackedWeightsContainer(env, model_path, options, prepacked_weights_container, &this->p_));
|
950 |
+
}
|
951 |
+
|
952 |
+
inline Session::Session(const Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options) {
|
953 |
+
ThrowOnError(GetApi().CreateSessionFromArray(env, model_data, model_data_length, options, &this->p_));
|
954 |
+
}
|
955 |
+
|
956 |
+
inline Session::Session(const Env& env, const void* model_data, size_t model_data_length,
|
957 |
+
const SessionOptions& options, OrtPrepackedWeightsContainer* prepacked_weights_container) {
|
958 |
+
ThrowOnError(GetApi().CreateSessionFromArrayWithPrepackedWeightsContainer(env, model_data, model_data_length, options,
|
959 |
+
prepacked_weights_container, &this->p_));
|
960 |
+
}
|
961 |
+
|
962 |
+
inline AllocatedStringPtr ModelMetadata::GetProducerNameAllocated(OrtAllocator* allocator) const {
|
963 |
+
char* out;
|
964 |
+
ThrowOnError(GetApi().ModelMetadataGetProducerName(p_, allocator, &out));
|
965 |
+
return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
|
966 |
+
}
|
967 |
+
|
968 |
+
inline AllocatedStringPtr ModelMetadata::GetGraphNameAllocated(OrtAllocator* allocator) const {
|
969 |
+
char* out;
|
970 |
+
ThrowOnError(GetApi().ModelMetadataGetGraphName(p_, allocator, &out));
|
971 |
+
return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
|
972 |
+
}
|
973 |
+
|
974 |
+
inline AllocatedStringPtr ModelMetadata::GetDomainAllocated(OrtAllocator* allocator) const {
|
975 |
+
char* out;
|
976 |
+
ThrowOnError(GetApi().ModelMetadataGetDomain(p_, allocator, &out));
|
977 |
+
return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
|
978 |
+
}
|
979 |
+
|
980 |
+
inline AllocatedStringPtr Ort::ModelMetadata::GetDescriptionAllocated(OrtAllocator* allocator) const {
|
981 |
+
char* out;
|
982 |
+
ThrowOnError(GetApi().ModelMetadataGetDescription(p_, allocator, &out));
|
983 |
+
return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
|
984 |
+
}
|
985 |
+
|
986 |
+
inline AllocatedStringPtr ModelMetadata::GetGraphDescriptionAllocated(OrtAllocator* allocator) const {
|
987 |
+
char* out;
|
988 |
+
ThrowOnError(GetApi().ModelMetadataGetGraphDescription(p_, allocator, &out));
|
989 |
+
return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
|
990 |
+
}
|
991 |
+
|
992 |
+
inline AllocatedStringPtr ModelMetadata::LookupCustomMetadataMapAllocated(const char* key, OrtAllocator* allocator) const {
|
993 |
+
char* out;
|
994 |
+
ThrowOnError(GetApi().ModelMetadataLookupCustomMetadataMap(p_, allocator, key, &out));
|
995 |
+
return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
|
996 |
+
}
|
997 |
+
|
998 |
+
inline std::vector<AllocatedStringPtr> ModelMetadata::GetCustomMetadataMapKeysAllocated(OrtAllocator* allocator) const {
|
999 |
+
auto deletor = detail::AllocatedFree(allocator);
|
1000 |
+
std::vector<AllocatedStringPtr> result;
|
1001 |
+
|
1002 |
+
char** out = nullptr;
|
1003 |
+
int64_t num_keys = 0;
|
1004 |
+
ThrowOnError(GetApi().ModelMetadataGetCustomMetadataMapKeys(p_, allocator, &out, &num_keys));
|
1005 |
+
if (num_keys <= 0) {
|
1006 |
+
return result;
|
1007 |
+
}
|
1008 |
+
|
1009 |
+
// array of pointers will be freed
|
1010 |
+
std::unique_ptr<void, decltype(deletor)> array_guard(out, deletor);
|
1011 |
+
// reserve may throw
|
1012 |
+
auto strings_deletor = [&deletor, num_keys](char** out) { for(int64_t i = 0; i < num_keys; ++i) deletor(out[i]); };
|
1013 |
+
std::unique_ptr<char*, decltype(strings_deletor)> strings_guard(out, strings_deletor);
|
1014 |
+
result.reserve(static_cast<size_t>(num_keys));
|
1015 |
+
strings_guard.release();
|
1016 |
+
for (int64_t i = 0; i < num_keys; ++i) {
|
1017 |
+
result.push_back(AllocatedStringPtr(out[i], deletor));
|
1018 |
+
}
|
1019 |
+
|
1020 |
+
return result;
|
1021 |
+
}
|
1022 |
+
|
1023 |
+
inline int64_t ModelMetadata::GetVersion() const {
|
1024 |
+
int64_t out;
|
1025 |
+
ThrowOnError(GetApi().ModelMetadataGetVersion(p_, &out));
|
1026 |
+
return out;
|
1027 |
+
}
|
1028 |
+
|
1029 |
+
namespace detail {
|
1030 |
+
|
1031 |
+
template <typename T>
|
1032 |
+
inline ONNXTensorElementDataType TensorTypeAndShapeInfoImpl<T>::GetElementType() const {
|
1033 |
+
ONNXTensorElementDataType out;
|
1034 |
+
ThrowOnError(GetApi().GetTensorElementType(this->p_, &out));
|
1035 |
+
return out;
|
1036 |
+
}
|
1037 |
+
|
1038 |
+
template <typename T>
|
1039 |
+
inline size_t TensorTypeAndShapeInfoImpl<T>::GetElementCount() const {
|
1040 |
+
size_t out;
|
1041 |
+
ThrowOnError(GetApi().GetTensorShapeElementCount(this->p_, &out));
|
1042 |
+
return static_cast<size_t>(out);
|
1043 |
+
}
|
1044 |
+
|
1045 |
+
template <typename T>
|
1046 |
+
inline size_t TensorTypeAndShapeInfoImpl<T>::GetDimensionsCount() const {
|
1047 |
+
size_t out;
|
1048 |
+
ThrowOnError(GetApi().GetDimensionsCount(this->p_, &out));
|
1049 |
+
return out;
|
1050 |
+
}
|
1051 |
+
|
1052 |
+
template <typename T>
|
1053 |
+
inline void TensorTypeAndShapeInfoImpl<T>::GetDimensions(int64_t* values, size_t values_count) const {
|
1054 |
+
ThrowOnError(GetApi().GetDimensions(this->p_, values, values_count));
|
1055 |
+
}
|
1056 |
+
|
1057 |
+
template <typename T>
|
1058 |
+
inline void TensorTypeAndShapeInfoImpl<T>::GetSymbolicDimensions(const char** values, size_t values_count) const {
|
1059 |
+
ThrowOnError(GetApi().GetSymbolicDimensions(this->p_, values, values_count));
|
1060 |
+
}
|
1061 |
+
|
1062 |
+
template <typename T>
|
1063 |
+
inline std::vector<int64_t> TensorTypeAndShapeInfoImpl<T>::GetShape() const {
|
1064 |
+
std::vector<int64_t> out(GetDimensionsCount(), 0);
|
1065 |
+
ThrowOnError(GetApi().GetDimensions(this->p_, out.data(), out.size()));
|
1066 |
+
return out;
|
1067 |
+
}
|
1068 |
+
|
1069 |
+
} // namespace detail
|
1070 |
+
|
1071 |
+
namespace detail {
|
1072 |
+
template <typename T>
|
1073 |
+
inline ConstTensorTypeAndShapeInfo TypeInfoImpl<T>::GetTensorTypeAndShapeInfo() const {
|
1074 |
+
const OrtTensorTypeAndShapeInfo* out;
|
1075 |
+
ThrowOnError(GetApi().CastTypeInfoToTensorInfo(this->p_, &out));
|
1076 |
+
return ConstTensorTypeAndShapeInfo{out};
|
1077 |
+
}
|
1078 |
+
|
1079 |
+
template <typename T>
|
1080 |
+
inline ConstSequenceTypeInfo TypeInfoImpl<T>::GetSequenceTypeInfo() const {
|
1081 |
+
const OrtSequenceTypeInfo* out;
|
1082 |
+
ThrowOnError(GetApi().CastTypeInfoToSequenceTypeInfo(this->p_, &out));
|
1083 |
+
return ConstSequenceTypeInfo{out};
|
1084 |
+
}
|
1085 |
+
|
1086 |
+
template <typename T>
|
1087 |
+
inline ConstMapTypeInfo TypeInfoImpl<T>::GetMapTypeInfo() const {
|
1088 |
+
const OrtMapTypeInfo* out;
|
1089 |
+
ThrowOnError(GetApi().CastTypeInfoToMapTypeInfo(this->p_, &out));
|
1090 |
+
return ConstMapTypeInfo{out};
|
1091 |
+
}
|
1092 |
+
|
1093 |
+
template <typename T>
|
1094 |
+
inline ONNXType TypeInfoImpl<T>::GetONNXType() const {
|
1095 |
+
ONNXType out;
|
1096 |
+
ThrowOnError(GetApi().GetOnnxTypeFromTypeInfo(this->p_, &out));
|
1097 |
+
return out;
|
1098 |
+
}
|
1099 |
+
|
1100 |
+
} // namespace detail
|
1101 |
+
|
1102 |
+
namespace detail {
|
1103 |
+
template <typename T>
|
1104 |
+
inline TypeInfo SequenceTypeInfoImpl<T>::GetSequenceElementType() const {
|
1105 |
+
OrtTypeInfo* output;
|
1106 |
+
ThrowOnError(GetApi().GetSequenceElementType(this->p_, &output));
|
1107 |
+
return TypeInfo{output};
|
1108 |
+
}
|
1109 |
+
|
1110 |
+
} // namespace detail
|
1111 |
+
|
1112 |
+
namespace detail {
|
1113 |
+
template <typename T>
|
1114 |
+
inline ONNXTensorElementDataType MapTypeInfoImpl<T>::GetMapKeyType() const {
|
1115 |
+
ONNXTensorElementDataType out;
|
1116 |
+
ThrowOnError(GetApi().GetMapKeyType(this->p_, &out));
|
1117 |
+
return out;
|
1118 |
+
}
|
1119 |
+
|
1120 |
+
template <typename T>
|
1121 |
+
inline TypeInfo MapTypeInfoImpl<T>::GetMapValueType() const {
|
1122 |
+
OrtTypeInfo* output;
|
1123 |
+
ThrowOnError(GetApi().GetMapValueType(this->p_, &output));
|
1124 |
+
return TypeInfo{output};
|
1125 |
+
}
|
1126 |
+
} // namespace detail
|
1127 |
+
|
1128 |
+
namespace detail {
|
1129 |
+
|
1130 |
+
template <typename T>
|
1131 |
+
template <typename R>
|
1132 |
+
inline void ConstValueImpl<T>::GetOpaqueData(const char* domain, const char* type_name, R& out) const {
|
1133 |
+
ThrowOnError(GetApi().GetOpaqueValue(domain, type_name, this->p_, &out, sizeof(R)));
|
1134 |
+
}
|
1135 |
+
|
1136 |
+
template <typename T>
|
1137 |
+
inline bool ConstValueImpl<T>::IsTensor() const {
|
1138 |
+
int out;
|
1139 |
+
ThrowOnError(GetApi().IsTensor(this->p_, &out));
|
1140 |
+
return out != 0;
|
1141 |
+
}
|
1142 |
+
|
1143 |
+
template <typename T>
|
1144 |
+
inline bool ConstValueImpl<T>::HasValue() const {
|
1145 |
+
int out;
|
1146 |
+
ThrowOnError(GetApi().HasValue(this->p_, &out));
|
1147 |
+
return out != 0;
|
1148 |
+
}
|
1149 |
+
|
1150 |
+
template <typename T>
|
1151 |
+
inline size_t ConstValueImpl<T>::GetCount() const {
|
1152 |
+
size_t out;
|
1153 |
+
ThrowOnError(GetApi().GetValueCount(this->p_, &out));
|
1154 |
+
return out;
|
1155 |
+
}
|
1156 |
+
|
1157 |
+
template <typename T>
|
1158 |
+
inline Value ConstValueImpl<T>::GetValue(int index, OrtAllocator* allocator) const {
|
1159 |
+
OrtValue* out;
|
1160 |
+
ThrowOnError(GetApi().GetValue(this->p_, index, allocator, &out));
|
1161 |
+
return Value{out};
|
1162 |
+
}
|
1163 |
+
|
1164 |
+
template <typename T>
|
1165 |
+
inline size_t ConstValueImpl<T>::GetStringTensorDataLength() const {
|
1166 |
+
size_t out;
|
1167 |
+
ThrowOnError(GetApi().GetStringTensorDataLength(this->p_, &out));
|
1168 |
+
return out;
|
1169 |
+
}
|
1170 |
+
|
1171 |
+
template <typename T>
|
1172 |
+
inline size_t ConstValueImpl<T>::GetStringTensorElementLength(size_t element_index) const {
|
1173 |
+
size_t out;
|
1174 |
+
ThrowOnError(GetApi().GetStringTensorElementLength(this->p_, element_index, &out));
|
1175 |
+
return out;
|
1176 |
+
}
|
1177 |
+
|
1178 |
+
template <typename T>
|
1179 |
+
template <typename R>
|
1180 |
+
inline const R* ConstValueImpl<T>::GetTensorData() const {
|
1181 |
+
R* out;
|
1182 |
+
ThrowOnError(GetApi().GetTensorMutableData(const_cast<OrtValue*>(this->p_), (void**)&out));
|
1183 |
+
return out;
|
1184 |
+
}
|
1185 |
+
|
1186 |
+
template <typename T>
|
1187 |
+
inline const void* ConstValueImpl<T>::GetTensorRawData() const {
|
1188 |
+
void* out;
|
1189 |
+
ThrowOnError(GetApi().GetTensorMutableData(const_cast<OrtValue*>(this->p_), &out));
|
1190 |
+
return out;
|
1191 |
+
}
|
1192 |
+
|
1193 |
+
template <typename T>
|
1194 |
+
inline TypeInfo ConstValueImpl<T>::GetTypeInfo() const {
|
1195 |
+
OrtTypeInfo* output;
|
1196 |
+
ThrowOnError(GetApi().GetTypeInfo(this->p_, &output));
|
1197 |
+
return TypeInfo{output};
|
1198 |
+
}
|
1199 |
+
|
1200 |
+
template <typename T>
|
1201 |
+
inline TensorTypeAndShapeInfo ConstValueImpl<T>::GetTensorTypeAndShapeInfo() const {
|
1202 |
+
OrtTensorTypeAndShapeInfo* output;
|
1203 |
+
ThrowOnError(GetApi().GetTensorTypeAndShape(this->p_, &output));
|
1204 |
+
return TensorTypeAndShapeInfo{output};
|
1205 |
+
}
|
1206 |
+
|
1207 |
+
template <typename T>
|
1208 |
+
inline ConstMemoryInfo ConstValueImpl<T>::GetTensorMemoryInfo() const {
|
1209 |
+
const OrtMemoryInfo* mem_info;
|
1210 |
+
ThrowOnError(GetApi().GetTensorMemoryInfo(this->p_, &mem_info));
|
1211 |
+
return ConstMemoryInfo(mem_info);
|
1212 |
+
}
|
1213 |
+
|
1214 |
+
template <typename T>
|
1215 |
+
inline void ConstValueImpl<T>::GetStringTensorElement(size_t buffer_length, size_t element_index, void* buffer) const {
|
1216 |
+
ThrowOnError(GetApi().GetStringTensorElement(this->p_, buffer_length, element_index, buffer));
|
1217 |
+
}
|
1218 |
+
|
1219 |
+
template <typename T>
|
1220 |
+
inline void ConstValueImpl<T>::GetStringTensorContent(void* buffer, size_t buffer_length, size_t* offsets, size_t offsets_count) const {
|
1221 |
+
ThrowOnError(GetApi().GetStringTensorContent(this->p_, buffer, buffer_length, offsets, offsets_count));
|
1222 |
+
}
|
1223 |
+
|
1224 |
+
#if !defined(DISABLE_SPARSE_TENSORS)
|
1225 |
+
template <typename T>
|
1226 |
+
inline OrtSparseFormat ConstValueImpl<T>::GetSparseFormat() const {
|
1227 |
+
OrtSparseFormat format;
|
1228 |
+
ThrowOnError(GetApi().GetSparseTensorFormat(this->p_, &format));
|
1229 |
+
return format;
|
1230 |
+
}
|
1231 |
+
|
1232 |
+
template <typename T>
|
1233 |
+
inline TensorTypeAndShapeInfo ConstValueImpl<T>::GetSparseTensorValuesTypeAndShapeInfo() const {
|
1234 |
+
OrtTensorTypeAndShapeInfo* output;
|
1235 |
+
ThrowOnError(GetApi().GetSparseTensorValuesTypeAndShape(this->p_, &output));
|
1236 |
+
return TensorTypeAndShapeInfo{output};
|
1237 |
+
}
|
1238 |
+
|
1239 |
+
template <typename T>
|
1240 |
+
inline TensorTypeAndShapeInfo ConstValueImpl<T>::GetSparseTensorIndicesTypeShapeInfo(OrtSparseIndicesFormat indices_format) const {
|
1241 |
+
OrtTensorTypeAndShapeInfo* output;
|
1242 |
+
ThrowOnError(GetApi().GetSparseTensorIndicesTypeShape(this->p_, indices_format, &output));
|
1243 |
+
return TensorTypeAndShapeInfo{output};
|
1244 |
+
}
|
1245 |
+
|
1246 |
+
template <typename T>
|
1247 |
+
template <typename R>
|
1248 |
+
inline const R* ConstValueImpl<T>::GetSparseTensorIndicesData(OrtSparseIndicesFormat indices_format, size_t& num_indices) const {
|
1249 |
+
const void* out;
|
1250 |
+
ThrowOnError(GetApi().GetSparseTensorIndices(this->p_, indices_format, &num_indices, &out));
|
1251 |
+
return reinterpret_cast<const R*>(out);
|
1252 |
+
}
|
1253 |
+
|
1254 |
+
template <typename T>
|
1255 |
+
inline bool ConstValueImpl<T>::IsSparseTensor() const {
|
1256 |
+
int out;
|
1257 |
+
ThrowOnError(GetApi().IsSparseTensor(this->p_, &out));
|
1258 |
+
return out != 0;
|
1259 |
+
}
|
1260 |
+
|
1261 |
+
template <typename T>
|
1262 |
+
template <typename R>
|
1263 |
+
inline const R* ConstValueImpl<T>::GetSparseTensorValues() const {
|
1264 |
+
const void* out;
|
1265 |
+
ThrowOnError(GetApi().GetSparseTensorValues(this->p_, &out));
|
1266 |
+
return reinterpret_cast<const R*>(out);
|
1267 |
+
}
|
1268 |
+
|
1269 |
+
#endif
|
1270 |
+
|
1271 |
+
template <typename T>
|
1272 |
+
void ValueImpl<T>::FillStringTensor(const char* const* s, size_t s_len) {
|
1273 |
+
ThrowOnError(GetApi().FillStringTensor(this->p_, s, s_len));
|
1274 |
+
}
|
1275 |
+
|
1276 |
+
template <typename T>
|
1277 |
+
void ValueImpl<T>::FillStringTensorElement(const char* s, size_t index) {
|
1278 |
+
ThrowOnError(GetApi().FillStringTensorElement(this->p_, s, index));
|
1279 |
+
}
|
1280 |
+
|
1281 |
+
template <typename T>
|
1282 |
+
void* ValueImpl<T>::GetTensorMutableRawData() {
|
1283 |
+
void* out;
|
1284 |
+
ThrowOnError(GetApi().GetTensorMutableData(this->p_, &out));
|
1285 |
+
return out;
|
1286 |
+
}
|
1287 |
+
|
1288 |
+
template <typename T>
|
1289 |
+
template <typename R>
|
1290 |
+
R* ValueImpl<T>::GetTensorMutableData() {
|
1291 |
+
R* out;
|
1292 |
+
ThrowOnError(GetApi().GetTensorMutableData(this->p_, (void**)&out));
|
1293 |
+
return out;
|
1294 |
+
}
|
1295 |
+
|
1296 |
+
template <typename T>
|
1297 |
+
template <typename R>
|
1298 |
+
R& ValueImpl<T>::At(const std::vector<int64_t>& location) {
|
1299 |
+
static_assert(!std::is_same<T, std::string>::value, "this api does not support std::string");
|
1300 |
+
R* out;
|
1301 |
+
ThrowOnError(GetApi().TensorAt(this->p_, location.data(), location.size(), (void**)&out));
|
1302 |
+
return *out;
|
1303 |
+
}
|
1304 |
+
|
1305 |
+
#if !defined(DISABLE_SPARSE_TENSORS)
|
1306 |
+
template <typename T>
|
1307 |
+
void ValueImpl<T>::UseCooIndices(int64_t* indices_data, size_t indices_num) {
|
1308 |
+
ThrowOnError(GetApi().UseCooIndices(this->p_, indices_data, indices_num));
|
1309 |
+
}
|
1310 |
+
|
1311 |
+
template <typename T>
|
1312 |
+
void ValueImpl<T>::UseCsrIndices(int64_t* inner_data, size_t inner_num, int64_t* outer_data, size_t outer_num) {
|
1313 |
+
ThrowOnError(GetApi().UseCsrIndices(this->p_, inner_data, inner_num, outer_data, outer_num));
|
1314 |
+
}
|
1315 |
+
|
1316 |
+
template <typename T>
|
1317 |
+
void ValueImpl<T>::UseBlockSparseIndices(const Shape& indices_shape, int32_t* indices_data) {
|
1318 |
+
ThrowOnError(GetApi().UseBlockSparseIndices(this->p_, indices_shape.shape, indices_shape.shape_len, indices_data));
|
1319 |
+
}
|
1320 |
+
|
1321 |
+
template <typename T>
|
1322 |
+
void ValueImpl<T>::FillSparseTensorCoo(const OrtMemoryInfo* mem_info, const OrtSparseValuesParam& values_param,
|
1323 |
+
const int64_t* indices_data, size_t indices_num) {
|
1324 |
+
ThrowOnError(GetApi().FillSparseTensorCoo(this->p_, mem_info, values_param.values_shape,
|
1325 |
+
values_param.values_shape_len, values_param.data.p_data,
|
1326 |
+
indices_data, indices_num));
|
1327 |
+
}
|
1328 |
+
|
1329 |
+
template <typename T>
|
1330 |
+
void ValueImpl<T>::FillSparseTensorCsr(const OrtMemoryInfo* data_mem_info,
|
1331 |
+
const OrtSparseValuesParam& values,
|
1332 |
+
const int64_t* inner_indices_data, size_t inner_indices_num,
|
1333 |
+
const int64_t* outer_indices_data, size_t outer_indices_num) {
|
1334 |
+
ThrowOnError(GetApi().FillSparseTensorCsr(this->p_, data_mem_info, values.values_shape, values.values_shape_len, values.data.p_data,
|
1335 |
+
inner_indices_data, inner_indices_num,
|
1336 |
+
outer_indices_data, outer_indices_num));
|
1337 |
+
}
|
1338 |
+
|
1339 |
+
template <typename T>
|
1340 |
+
void ValueImpl<T>::FillSparseTensorBlockSparse(const OrtMemoryInfo* data_mem_info,
|
1341 |
+
const OrtSparseValuesParam& values,
|
1342 |
+
const Shape& indices_shape,
|
1343 |
+
const int32_t* indices_data) {
|
1344 |
+
ThrowOnError(GetApi().FillSparseTensorBlockSparse(this->p_, data_mem_info, values.values_shape, values.values_shape_len, values.data.p_data,
|
1345 |
+
indices_shape.shape, indices_shape.shape_len,
|
1346 |
+
indices_data));
|
1347 |
+
}
|
1348 |
+
|
1349 |
+
#endif // !defined(DISABLE_SPARSE_TENSORS)
|
1350 |
+
|
1351 |
+
} // namespace detail
|
1352 |
+
|
1353 |
+
template <typename T>
|
1354 |
+
inline Value Value::CreateTensor(const OrtMemoryInfo* info, T* p_data, size_t p_data_element_count, const int64_t* shape, size_t shape_len) {
|
1355 |
+
return CreateTensor(info, p_data, p_data_element_count * sizeof(T), shape, shape_len, TypeToTensorType<T>::type);
|
1356 |
+
}
|
1357 |
+
|
1358 |
+
inline Value Value::CreateTensor(const OrtMemoryInfo* info, void* p_data, size_t p_data_byte_count, const int64_t* shape, size_t shape_len,
|
1359 |
+
ONNXTensorElementDataType type) {
|
1360 |
+
OrtValue* out;
|
1361 |
+
ThrowOnError(GetApi().CreateTensorWithDataAsOrtValue(info, p_data, p_data_byte_count, shape, shape_len, type, &out));
|
1362 |
+
return Value{out};
|
1363 |
+
}
|
1364 |
+
|
1365 |
+
template <typename T>
|
1366 |
+
inline Value Value::CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len) {
|
1367 |
+
return CreateTensor(allocator, shape, shape_len, TypeToTensorType<T>::type);
|
1368 |
+
}
|
1369 |
+
|
1370 |
+
inline Value Value::CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type) {
|
1371 |
+
OrtValue* out;
|
1372 |
+
ThrowOnError(GetApi().CreateTensorAsOrtValue(allocator, shape, shape_len, type, &out));
|
1373 |
+
return Value{out};
|
1374 |
+
}
|
1375 |
+
|
1376 |
+
#if !defined(DISABLE_SPARSE_TENSORS)
|
1377 |
+
|
1378 |
+
template <typename T>
|
1379 |
+
inline Value Value::CreateSparseTensor(const OrtMemoryInfo* info, T* p_data, const Shape& dense_shape,
|
1380 |
+
const Shape& values_shape) {
|
1381 |
+
return CreateSparseTensor(info, p_data, dense_shape, values_shape, TypeToTensorType<T>::type);
|
1382 |
+
}
|
1383 |
+
|
1384 |
+
inline Value Value::CreateSparseTensor(const OrtMemoryInfo* info, void* p_data, const Shape& dense_shape,
|
1385 |
+
const Shape& values_shape, ONNXTensorElementDataType type) {
|
1386 |
+
OrtValue* out;
|
1387 |
+
ThrowOnError(GetApi().CreateSparseTensorWithValuesAsOrtValue(info, p_data, dense_shape.shape, dense_shape.shape_len,
|
1388 |
+
values_shape.shape, values_shape.shape_len, type, &out));
|
1389 |
+
return Value{out};
|
1390 |
+
}
|
1391 |
+
|
1392 |
+
template <typename T>
|
1393 |
+
inline Value Value::CreateSparseTensor(OrtAllocator* allocator, const Shape& dense_shape) {
|
1394 |
+
return CreateSparseTensor(allocator, dense_shape, TypeToTensorType<T>::type);
|
1395 |
+
}
|
1396 |
+
|
1397 |
+
inline Value Value::CreateSparseTensor(OrtAllocator* allocator, const Shape& dense_shape,
|
1398 |
+
ONNXTensorElementDataType type) {
|
1399 |
+
OrtValue* out;
|
1400 |
+
ThrowOnError(GetApi().CreateSparseTensorAsOrtValue(allocator, dense_shape.shape, dense_shape.shape_len, type, &out));
|
1401 |
+
return Value{out};
|
1402 |
+
}
|
1403 |
+
#endif // !defined(DISABLE_SPARSE_TENSORS)
|
1404 |
+
|
1405 |
+
inline Value Value::CreateMap(Value& keys, Value& values) {
|
1406 |
+
OrtValue* out;
|
1407 |
+
OrtValue* inputs[2] = {keys, values};
|
1408 |
+
ThrowOnError(GetApi().CreateValue(inputs, 2, ONNX_TYPE_MAP, &out));
|
1409 |
+
return Value{out};
|
1410 |
+
}
|
1411 |
+
|
1412 |
+
inline Value Value::CreateSequence(std::vector<Value>& values) {
|
1413 |
+
OrtValue* out;
|
1414 |
+
std::vector<OrtValue*> values_ort{values.data(), values.data() + values.size()};
|
1415 |
+
ThrowOnError(GetApi().CreateValue(values_ort.data(), values_ort.size(), ONNX_TYPE_SEQUENCE, &out));
|
1416 |
+
return Value{out};
|
1417 |
+
}
|
1418 |
+
|
1419 |
+
template <typename T>
|
1420 |
+
inline Value Value::CreateOpaque(const char* domain, const char* type_name, const T& data_container) {
|
1421 |
+
OrtValue* out;
|
1422 |
+
ThrowOnError(GetApi().CreateOpaqueValue(domain, type_name, &data_container, sizeof(T), &out));
|
1423 |
+
return Value{out};
|
1424 |
+
}
|
1425 |
+
|
1426 |
+
//
|
1427 |
+
// Custom OP Inlines
|
1428 |
+
//
|
1429 |
+
inline KernelContext::KernelContext(OrtKernelContext* context) : ctx_(context) {
|
1430 |
+
}
|
1431 |
+
|
1432 |
+
inline size_t KernelContext::GetInputCount() const {
|
1433 |
+
size_t out = 0;
|
1434 |
+
Ort::ThrowOnError(GetApi().KernelContext_GetInputCount(ctx_, &out));
|
1435 |
+
return out;
|
1436 |
+
}
|
1437 |
+
|
1438 |
+
inline size_t KernelContext::GetOutputCount() const {
|
1439 |
+
size_t out = 0;
|
1440 |
+
Ort::ThrowOnError(GetApi().KernelContext_GetOutputCount(ctx_, &out));
|
1441 |
+
return out;
|
1442 |
+
}
|
1443 |
+
|
1444 |
+
inline ConstValue KernelContext::GetInput(size_t index) const {
|
1445 |
+
const OrtValue* out = nullptr;
|
1446 |
+
Ort::ThrowOnError(GetApi().KernelContext_GetInput(ctx_, index, &out));
|
1447 |
+
return ConstValue{out};
|
1448 |
+
}
|
1449 |
+
|
1450 |
+
inline UnownedValue KernelContext::GetOutput(size_t index, const int64_t* dim_values, size_t dim_count) const {
|
1451 |
+
OrtValue* out = nullptr;
|
1452 |
+
Ort::ThrowOnError(GetApi().KernelContext_GetOutput(ctx_, index, dim_values, dim_count, &out));
|
1453 |
+
return UnownedValue(out);
|
1454 |
+
}
|
1455 |
+
|
1456 |
+
inline UnownedValue KernelContext::GetOutput(size_t index, const std::vector<int64_t>& dims) const {
|
1457 |
+
OrtValue* out = nullptr;
|
1458 |
+
Ort::ThrowOnError(GetApi().KernelContext_GetOutput(ctx_, index, dims.data(), dims.size(), &out));
|
1459 |
+
return UnownedValue(out);
|
1460 |
+
}
|
1461 |
+
|
1462 |
+
inline void* KernelContext::GetGPUComputeStream() const {
|
1463 |
+
void* out = nullptr;
|
1464 |
+
Ort::ThrowOnError(GetApi().KernelContext_GetGPUComputeStream(ctx_, &out));
|
1465 |
+
return out;
|
1466 |
+
}
|
1467 |
+
|
1468 |
+
inline OpAttr::OpAttr(const char* name, const void* data, int len, OrtOpAttrType type) {
|
1469 |
+
Ort::ThrowOnError(GetApi().CreateOpAttr(name, data, len, type, &p_));
|
1470 |
+
}
|
1471 |
+
|
1472 |
+
namespace detail {
|
1473 |
+
template <typename T>
|
1474 |
+
inline KernelInfo KernelInfoImpl<T>::Copy() const {
|
1475 |
+
OrtKernelInfo* info_copy = nullptr;
|
1476 |
+
Ort::ThrowOnError(GetApi().CopyKernelInfo(this->p_, &info_copy));
|
1477 |
+
return KernelInfo{info_copy};
|
1478 |
+
}
|
1479 |
+
|
1480 |
+
template <typename T>
|
1481 |
+
inline size_t KernelInfoImpl<T>::GetInputCount() const {
|
1482 |
+
size_t out = 0;
|
1483 |
+
ThrowOnError(GetApi().KernelInfo_GetInputCount(this->p_, &out));
|
1484 |
+
return out;
|
1485 |
+
}
|
1486 |
+
|
1487 |
+
template <typename T>
|
1488 |
+
inline size_t KernelInfoImpl<T>::GetOutputCount() const {
|
1489 |
+
size_t out = 0;
|
1490 |
+
ThrowOnError(GetApi().KernelInfo_GetOutputCount(this->p_, &out));
|
1491 |
+
return out;
|
1492 |
+
}
|
1493 |
+
|
1494 |
+
template <typename T>
|
1495 |
+
inline std::string KernelInfoImpl<T>::GetInputName(size_t index) const {
|
1496 |
+
size_t size = 0;
|
1497 |
+
|
1498 |
+
// Feed nullptr for the data buffer to query the true size of the string value
|
1499 |
+
Ort::ThrowOnError(GetApi().KernelInfo_GetInputName(this->p_, index, nullptr, &size));
|
1500 |
+
|
1501 |
+
std::string out;
|
1502 |
+
out.resize(size);
|
1503 |
+
Ort::ThrowOnError(GetApi().KernelInfo_GetInputName(this->p_, index, &out[0], &size));
|
1504 |
+
out.resize(size - 1); // remove the terminating character '\0'
|
1505 |
+
|
1506 |
+
return out;
|
1507 |
+
}
|
1508 |
+
|
1509 |
+
template <typename T>
|
1510 |
+
inline std::string KernelInfoImpl<T>::GetOutputName(size_t index) const {
|
1511 |
+
size_t size = 0;
|
1512 |
+
|
1513 |
+
// Feed nullptr for the data buffer to query the true size of the string value
|
1514 |
+
Ort::ThrowOnError(GetApi().KernelInfo_GetOutputName(this->p_, index, nullptr, &size));
|
1515 |
+
|
1516 |
+
std::string out;
|
1517 |
+
out.resize(size);
|
1518 |
+
Ort::ThrowOnError(GetApi().KernelInfo_GetOutputName(this->p_, index, &out[0], &size));
|
1519 |
+
out.resize(size - 1); // remove the terminating character '\0'
|
1520 |
+
|
1521 |
+
return out;
|
1522 |
+
}
|
1523 |
+
|
1524 |
+
template <typename T>
|
1525 |
+
inline TypeInfo KernelInfoImpl<T>::GetInputTypeInfo(size_t index) const {
|
1526 |
+
OrtTypeInfo* out = nullptr;
|
1527 |
+
ThrowOnError(GetApi().KernelInfo_GetInputTypeInfo(this->p_, index, &out));
|
1528 |
+
return TypeInfo{out};
|
1529 |
+
}
|
1530 |
+
|
1531 |
+
template <typename T>
|
1532 |
+
inline TypeInfo KernelInfoImpl<T>::GetOutputTypeInfo(size_t index) const {
|
1533 |
+
OrtTypeInfo* out = nullptr;
|
1534 |
+
ThrowOnError(GetApi().KernelInfo_GetOutputTypeInfo(this->p_, index, &out));
|
1535 |
+
return TypeInfo{out};
|
1536 |
+
}
|
1537 |
+
|
1538 |
+
template <typename T>
|
1539 |
+
inline Value KernelInfoImpl<T>::GetTensorAttribute(const char* name, OrtAllocator* allocator) const {
|
1540 |
+
OrtValue* out = nullptr;
|
1541 |
+
ThrowOnError(GetApi().KernelInfoGetAttribute_tensor(this->p_, name, allocator, &out));
|
1542 |
+
return Value{out};
|
1543 |
+
}
|
1544 |
+
|
1545 |
+
inline void attr_utils::GetAttr(const OrtKernelInfo* p, const char* name, float& out) {
|
1546 |
+
Ort::ThrowOnError(GetApi().KernelInfoGetAttribute_float(p, name, &out));
|
1547 |
+
}
|
1548 |
+
|
1549 |
+
inline void attr_utils::GetAttr(const OrtKernelInfo* p, const char* name, int64_t& out) {
|
1550 |
+
Ort::ThrowOnError(GetApi().KernelInfoGetAttribute_int64(p, name, &out));
|
1551 |
+
}
|
1552 |
+
|
1553 |
+
inline void attr_utils::GetAttr(const OrtKernelInfo* p, const char* name, std::string& result) {
|
1554 |
+
size_t size = 0;
|
1555 |
+
// Feed nullptr for the data buffer to query the true size of the string attribute
|
1556 |
+
Ort::ThrowOnError(GetApi().KernelInfoGetAttribute_string(p, name, nullptr, &size));
|
1557 |
+
|
1558 |
+
std::string out;
|
1559 |
+
out.resize(size);
|
1560 |
+
Ort::ThrowOnError(GetApi().KernelInfoGetAttribute_string(p, name, &out[0], &size));
|
1561 |
+
out.resize(size - 1); // remove the terminating character '\0'
|
1562 |
+
out.swap(result);
|
1563 |
+
}
|
1564 |
+
|
1565 |
+
inline void attr_utils::GetAttrs(const OrtKernelInfo* p, const char* name, std::vector<float>& result) {
|
1566 |
+
size_t size = 0;
|
1567 |
+
// Feed nullptr for the data buffer to query the true size of the attribute
|
1568 |
+
Ort::ThrowOnError(GetApi().KernelInfoGetAttributeArray_float(p, name, nullptr, &size));
|
1569 |
+
|
1570 |
+
std::vector<float> out;
|
1571 |
+
out.resize(size);
|
1572 |
+
Ort::ThrowOnError(GetApi().KernelInfoGetAttributeArray_float(p, name, out.data(), &size));
|
1573 |
+
out.swap(result);
|
1574 |
+
}
|
1575 |
+
|
1576 |
+
inline void attr_utils::GetAttrs(const OrtKernelInfo* p, const char* name, std::vector<int64_t>& result) {
|
1577 |
+
size_t size = 0;
|
1578 |
+
|
1579 |
+
// Feed nullptr for the data buffer to query the true size of the attribute
|
1580 |
+
Ort::ThrowOnError(GetApi().KernelInfoGetAttributeArray_int64(p, name, nullptr, &size));
|
1581 |
+
|
1582 |
+
std::vector<int64_t> out;
|
1583 |
+
out.resize(size);
|
1584 |
+
Ort::ThrowOnError(GetApi().KernelInfoGetAttributeArray_int64(p, name, out.data(), &size));
|
1585 |
+
out.swap(result);
|
1586 |
+
}
|
1587 |
+
} // namespace detail
|
1588 |
+
|
1589 |
+
inline KernelInfo::KernelInfo(OrtKernelInfo* info) : detail::KernelInfoImpl<OrtKernelInfo>{info} {}
|
1590 |
+
|
1591 |
+
inline Op::Op(OrtOp* p) : Base<OrtOp>(p) {}
|
1592 |
+
|
1593 |
+
inline Op Op::Create(const OrtKernelInfo* info, const char* op_name, const char* domain, int version,
|
1594 |
+
const char** type_constraint_names,
|
1595 |
+
const ONNXTensorElementDataType* type_constraint_values,
|
1596 |
+
size_t type_constraint_count,
|
1597 |
+
const OpAttr* attr_values, size_t attr_count,
|
1598 |
+
size_t input_count, size_t output_count) {
|
1599 |
+
static_assert(sizeof(OpAttr) == sizeof(OrtOpAttr*),
|
1600 |
+
"OpAttr's is expected to be just an array of OrtOpAttr in memory so we can reinterpret safely");
|
1601 |
+
auto attr_input_values = reinterpret_cast<const OrtOpAttr* const*>(attr_values);
|
1602 |
+
OrtOp* op;
|
1603 |
+
Ort::ThrowOnError(GetApi().CreateOp(info, op_name, domain, version, type_constraint_names, type_constraint_values,
|
1604 |
+
static_cast<int>(type_constraint_count),
|
1605 |
+
attr_input_values,
|
1606 |
+
static_cast<int>(attr_count),
|
1607 |
+
static_cast<int>(input_count),
|
1608 |
+
static_cast<int>(output_count), &op));
|
1609 |
+
return Op{op};
|
1610 |
+
}
|
1611 |
+
|
1612 |
+
inline void Op::Invoke(const OrtKernelContext* context,
|
1613 |
+
const Value* input_values,
|
1614 |
+
size_t input_count,
|
1615 |
+
Value* output_values,
|
1616 |
+
size_t output_count) {
|
1617 |
+
static_assert(sizeof(Value) == sizeof(OrtValue*),
|
1618 |
+
"Value is really just an array of OrtValue* in memory, so we can reinterpret_cast safely");
|
1619 |
+
auto ort_input_values = reinterpret_cast<const OrtValue* const*>(input_values);
|
1620 |
+
auto ort_output_values = reinterpret_cast<OrtValue**>(output_values);
|
1621 |
+
Ort::ThrowOnError(GetApi().InvokeOp(context, p_, ort_input_values, static_cast<int>(input_count),
|
1622 |
+
ort_output_values, static_cast<int>(output_count)));
|
1623 |
+
}
|
1624 |
+
|
1625 |
+
inline void Op::Invoke(const OrtKernelContext* context,
|
1626 |
+
const OrtValue* const* input_values,
|
1627 |
+
size_t input_count,
|
1628 |
+
OrtValue* const* output_values,
|
1629 |
+
size_t output_count) {
|
1630 |
+
Ort::ThrowOnError(GetApi().InvokeOp(context, p_, input_values, static_cast<int>(input_count),
|
1631 |
+
output_values, static_cast<int>(output_count)));
|
1632 |
+
}
|
1633 |
+
|
1634 |
+
inline void CustomOpApi::ThrowOnError(OrtStatus* status) {
|
1635 |
+
Ort::ThrowOnError(status);
|
1636 |
+
}
|
1637 |
+
|
1638 |
+
template <>
|
1639 |
+
inline float CustomOpApi::KernelInfoGetAttribute<float>(_In_ const OrtKernelInfo* info, _In_ const char* name) {
|
1640 |
+
float out;
|
1641 |
+
Ort::ThrowOnError(api_.KernelInfoGetAttribute_float(info, name, &out));
|
1642 |
+
return out;
|
1643 |
+
}
|
1644 |
+
|
1645 |
+
template <>
|
1646 |
+
inline int64_t CustomOpApi::KernelInfoGetAttribute<int64_t>(_In_ const OrtKernelInfo* info, _In_ const char* name) {
|
1647 |
+
int64_t out;
|
1648 |
+
Ort::ThrowOnError(api_.KernelInfoGetAttribute_int64(info, name, &out));
|
1649 |
+
return out;
|
1650 |
+
}
|
1651 |
+
|
1652 |
+
template <>
|
1653 |
+
inline std::string CustomOpApi::KernelInfoGetAttribute<std::string>(_In_ const OrtKernelInfo* info, _In_ const char* name) {
|
1654 |
+
size_t size = 0;
|
1655 |
+
std::string out;
|
1656 |
+
|
1657 |
+
// Feed nullptr for the data buffer to query the true size of the string attribute
|
1658 |
+
OrtStatus* status = api_.KernelInfoGetAttribute_string(info, name, nullptr, &size);
|
1659 |
+
|
1660 |
+
if (status == nullptr) {
|
1661 |
+
out.resize(size);
|
1662 |
+
Ort::ThrowOnError(api_.KernelInfoGetAttribute_string(info, name, &out[0], &size));
|
1663 |
+
out.resize(size - 1); // remove the terminating character '\0'
|
1664 |
+
} else {
|
1665 |
+
Ort::ThrowOnError(status);
|
1666 |
+
}
|
1667 |
+
return out;
|
1668 |
+
}
|
1669 |
+
|
1670 |
+
template <>
|
1671 |
+
inline std::vector<float> CustomOpApi::KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name) {
|
1672 |
+
size_t size = 0;
|
1673 |
+
std::vector<float> out;
|
1674 |
+
|
1675 |
+
// Feed nullptr for the data buffer to query the true size of the attribute
|
1676 |
+
OrtStatus* status = api_.KernelInfoGetAttributeArray_float(info, name, nullptr, &size);
|
1677 |
+
|
1678 |
+
if (status == nullptr) {
|
1679 |
+
out.resize(size);
|
1680 |
+
Ort::ThrowOnError(api_.KernelInfoGetAttributeArray_float(info, name, out.data(), &size));
|
1681 |
+
} else {
|
1682 |
+
Ort::ThrowOnError(status);
|
1683 |
+
}
|
1684 |
+
return out;
|
1685 |
+
}
|
1686 |
+
|
1687 |
+
template <>
|
1688 |
+
inline std::vector<int64_t> CustomOpApi::KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name) {
|
1689 |
+
size_t size = 0;
|
1690 |
+
std::vector<int64_t> out;
|
1691 |
+
|
1692 |
+
// Feed nullptr for the data buffer to query the true size of the attribute
|
1693 |
+
OrtStatus* status = api_.KernelInfoGetAttributeArray_int64(info, name, nullptr, &size);
|
1694 |
+
|
1695 |
+
if (status == nullptr) {
|
1696 |
+
out.resize(size);
|
1697 |
+
Ort::ThrowOnError(api_.KernelInfoGetAttributeArray_int64(info, name, out.data(), &size));
|
1698 |
+
} else {
|
1699 |
+
Ort::ThrowOnError(status);
|
1700 |
+
}
|
1701 |
+
return out;
|
1702 |
+
}
|
1703 |
+
inline OrtTensorTypeAndShapeInfo* CustomOpApi::GetTensorTypeAndShape(_In_ const OrtValue* value) {
|
1704 |
+
OrtTensorTypeAndShapeInfo* out;
|
1705 |
+
Ort::ThrowOnError(api_.GetTensorTypeAndShape(value, &out));
|
1706 |
+
return out;
|
1707 |
+
}
|
1708 |
+
|
1709 |
+
inline size_t CustomOpApi::GetTensorShapeElementCount(_In_ const OrtTensorTypeAndShapeInfo* info) {
|
1710 |
+
size_t out;
|
1711 |
+
Ort::ThrowOnError(api_.GetTensorShapeElementCount(info, &out));
|
1712 |
+
return out;
|
1713 |
+
}
|
1714 |
+
|
1715 |
+
inline ONNXTensorElementDataType CustomOpApi::GetTensorElementType(const OrtTensorTypeAndShapeInfo* info) {
|
1716 |
+
ONNXTensorElementDataType out;
|
1717 |
+
Ort::ThrowOnError(api_.GetTensorElementType(info, &out));
|
1718 |
+
return out;
|
1719 |
+
}
|
1720 |
+
|
1721 |
+
inline size_t CustomOpApi::GetDimensionsCount(_In_ const OrtTensorTypeAndShapeInfo* info) {
|
1722 |
+
size_t out;
|
1723 |
+
Ort::ThrowOnError(api_.GetDimensionsCount(info, &out));
|
1724 |
+
return out;
|
1725 |
+
}
|
1726 |
+
|
1727 |
+
inline void CustomOpApi::GetDimensions(_In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values, size_t dim_values_length) {
|
1728 |
+
Ort::ThrowOnError(api_.GetDimensions(info, dim_values, dim_values_length));
|
1729 |
+
}
|
1730 |
+
|
1731 |
+
inline void CustomOpApi::SetDimensions(OrtTensorTypeAndShapeInfo* info, _In_ const int64_t* dim_values, size_t dim_count) {
|
1732 |
+
Ort::ThrowOnError(api_.SetDimensions(info, dim_values, dim_count));
|
1733 |
+
}
|
1734 |
+
|
1735 |
+
template <typename T>
|
1736 |
+
inline T* CustomOpApi::GetTensorMutableData(_Inout_ OrtValue* value) {
|
1737 |
+
T* data;
|
1738 |
+
Ort::ThrowOnError(api_.GetTensorMutableData(value, reinterpret_cast<void**>(&data)));
|
1739 |
+
return data;
|
1740 |
+
}
|
1741 |
+
|
1742 |
+
inline const OrtMemoryInfo* CustomOpApi::GetTensorMemoryInfo(_In_ const OrtValue* value) {
|
1743 |
+
const OrtMemoryInfo* mem_info;
|
1744 |
+
Ort::ThrowOnError(api_.GetTensorMemoryInfo(value, &mem_info));
|
1745 |
+
return mem_info;
|
1746 |
+
}
|
1747 |
+
|
1748 |
+
template <typename T>
|
1749 |
+
inline const T* CustomOpApi::GetTensorData(_Inout_ const OrtValue* value) {
|
1750 |
+
T* data = nullptr;
|
1751 |
+
Ort::ThrowOnError(api_.GetTensorMutableData(const_cast<OrtValue*>(value), reinterpret_cast<void**>(&data)));
|
1752 |
+
return data;
|
1753 |
+
}
|
1754 |
+
|
1755 |
+
inline std::vector<int64_t> CustomOpApi::GetTensorShape(const OrtTensorTypeAndShapeInfo* info) {
|
1756 |
+
size_t out;
|
1757 |
+
Ort::ThrowOnError(api_.GetDimensionsCount(info, &out));
|
1758 |
+
std::vector<int64_t> output(out);
|
1759 |
+
Ort::ThrowOnError(api_.GetDimensions(info, output.data(), out));
|
1760 |
+
return output;
|
1761 |
+
}
|
1762 |
+
|
1763 |
+
inline void CustomOpApi::ReleaseTensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* input) {
|
1764 |
+
api_.ReleaseTensorTypeAndShapeInfo(input);
|
1765 |
+
}
|
1766 |
+
|
1767 |
+
inline size_t CustomOpApi::KernelContext_GetInputCount(const OrtKernelContext* context) {
|
1768 |
+
size_t out;
|
1769 |
+
Ort::ThrowOnError(api_.KernelContext_GetInputCount(context, &out));
|
1770 |
+
return out;
|
1771 |
+
}
|
1772 |
+
|
1773 |
+
inline const OrtValue* CustomOpApi::KernelContext_GetInput(const OrtKernelContext* context, _In_ size_t index) {
|
1774 |
+
const OrtValue* out;
|
1775 |
+
Ort::ThrowOnError(api_.KernelContext_GetInput(context, index, &out));
|
1776 |
+
return out;
|
1777 |
+
}
|
1778 |
+
|
1779 |
+
inline size_t CustomOpApi::KernelContext_GetOutputCount(const OrtKernelContext* context) {
|
1780 |
+
size_t out;
|
1781 |
+
Ort::ThrowOnError(api_.KernelContext_GetOutputCount(context, &out));
|
1782 |
+
return out;
|
1783 |
+
}
|
1784 |
+
|
1785 |
+
inline OrtValue* CustomOpApi::KernelContext_GetOutput(OrtKernelContext* context, _In_ size_t index,
|
1786 |
+
_In_ const int64_t* dim_values, size_t dim_count) {
|
1787 |
+
OrtValue* out;
|
1788 |
+
Ort::ThrowOnError(api_.KernelContext_GetOutput(context, index, dim_values, dim_count, &out));
|
1789 |
+
return out;
|
1790 |
+
}
|
1791 |
+
|
1792 |
+
inline void* CustomOpApi::KernelContext_GetGPUComputeStream(const OrtKernelContext* context) {
|
1793 |
+
void* out;
|
1794 |
+
Ort::ThrowOnError(api_.KernelContext_GetGPUComputeStream(context, &out));
|
1795 |
+
return out;
|
1796 |
+
}
|
1797 |
+
|
1798 |
+
inline OrtOpAttr* CustomOpApi::CreateOpAttr(_In_ const char* name,
|
1799 |
+
_In_ const void* data,
|
1800 |
+
_In_ int len,
|
1801 |
+
_In_ OrtOpAttrType type) {
|
1802 |
+
OrtOpAttr* op_attr{};
|
1803 |
+
Ort::ThrowOnError(api_.CreateOpAttr(name, data, len, type, &op_attr));
|
1804 |
+
return op_attr;
|
1805 |
+
}
|
1806 |
+
|
1807 |
+
inline void CustomOpApi::ReleaseOpAttr(_Frees_ptr_opt_ OrtOpAttr* op_attr) {
|
1808 |
+
api_.ReleaseOpAttr(op_attr);
|
1809 |
+
}
|
1810 |
+
|
1811 |
+
inline OrtOp* CustomOpApi::CreateOp(_In_ const OrtKernelInfo* info,
|
1812 |
+
_In_ const char* op_name,
|
1813 |
+
_In_ const char* domain,
|
1814 |
+
_In_ int version,
|
1815 |
+
_In_opt_ const char** type_constraint_names,
|
1816 |
+
_In_opt_ const ONNXTensorElementDataType* type_constraint_values,
|
1817 |
+
_In_opt_ int type_constraint_count,
|
1818 |
+
_In_opt_ const OrtOpAttr* const* attr_values,
|
1819 |
+
_In_opt_ int attr_count,
|
1820 |
+
_In_ int input_count,
|
1821 |
+
_In_ int output_count) {
|
1822 |
+
OrtOp* ort_op{};
|
1823 |
+
Ort::ThrowOnError(api_.CreateOp(info, op_name, domain, version, type_constraint_names, type_constraint_values,
|
1824 |
+
type_constraint_count, attr_values, attr_count, input_count, output_count, &ort_op));
|
1825 |
+
return ort_op;
|
1826 |
+
}
|
1827 |
+
|
1828 |
+
inline void CustomOpApi::InvokeOp(_In_ const OrtKernelContext* context,
|
1829 |
+
_In_ const OrtOp* ort_op,
|
1830 |
+
_In_ const OrtValue* const* input_values,
|
1831 |
+
_In_ int input_count,
|
1832 |
+
_Inout_ OrtValue* const* output_values,
|
1833 |
+
_In_ int output_count) {
|
1834 |
+
Ort::ThrowOnError(api_.InvokeOp(context, ort_op, input_values, input_count, output_values, output_count));
|
1835 |
+
}
|
1836 |
+
|
1837 |
+
inline void CustomOpApi::ReleaseOp(_Frees_ptr_opt_ OrtOp* ort_op) {
|
1838 |
+
api_.ReleaseOp(ort_op);
|
1839 |
+
}
|
1840 |
+
|
1841 |
+
inline OrtKernelInfo* CustomOpApi::CopyKernelInfo(_In_ const OrtKernelInfo* info) {
|
1842 |
+
OrtKernelInfo* info_copy{};
|
1843 |
+
Ort::ThrowOnError(api_.CopyKernelInfo(info, &info_copy));
|
1844 |
+
return info_copy;
|
1845 |
+
}
|
1846 |
+
|
1847 |
+
inline void CustomOpApi::ReleaseKernelInfo(_Frees_ptr_opt_ OrtKernelInfo* info_copy) {
|
1848 |
+
api_.ReleaseKernelInfo(info_copy);
|
1849 |
+
}
|
1850 |
+
|
1851 |
+
inline std::vector<std::string> GetAvailableProviders() {
|
1852 |
+
char** providers;
|
1853 |
+
int len;
|
1854 |
+
|
1855 |
+
auto release_fn = [&len](char** providers) {
|
1856 |
+
// This should always return nullptr.
|
1857 |
+
ThrowOnError(GetApi().ReleaseAvailableProviders(providers, len));
|
1858 |
+
};
|
1859 |
+
|
1860 |
+
ThrowOnError(GetApi().GetAvailableProviders(&providers, &len));
|
1861 |
+
std::unique_ptr<char*, decltype(release_fn)> guard(providers, release_fn);
|
1862 |
+
std::vector<std::string> available_providers;
|
1863 |
+
available_providers.reserve(static_cast<size_t>(len));
|
1864 |
+
for (int i = 0; i < len; ++i) {
|
1865 |
+
available_providers.emplace_back(providers[i]);
|
1866 |
+
}
|
1867 |
+
return available_providers;
|
1868 |
+
}
|
1869 |
+
|
1870 |
+
template <typename TOp, typename TKernel>
|
1871 |
+
void CustomOpBase<TOp, TKernel>::GetSessionConfigs(std::unordered_map<std::string, std::string>& out,
|
1872 |
+
ConstSessionOptions options) const {
|
1873 |
+
const TOp* derived = static_cast<const TOp*>(this);
|
1874 |
+
std::vector<std::string> keys = derived->GetSessionConfigKeys();
|
1875 |
+
|
1876 |
+
out.reserve(keys.size());
|
1877 |
+
|
1878 |
+
std::string config_entry_key = detail::MakeCustomOpConfigEntryKey(derived->GetName(), "");
|
1879 |
+
const size_t prefix_size = config_entry_key.length();
|
1880 |
+
|
1881 |
+
for (const auto& key : keys) {
|
1882 |
+
config_entry_key.resize(prefix_size);
|
1883 |
+
config_entry_key.append(key);
|
1884 |
+
out[key] = options.GetConfigEntryOrDefault(config_entry_key.c_str(), "");
|
1885 |
+
}
|
1886 |
+
}
|
1887 |
+
|
1888 |
+
} // namespace Ort
|
1.15.1/onnxruntime.xcframework/Info.plist
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
2 |
+
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
|
3 |
+
<plist version="1.0">
|
4 |
+
<dict>
|
5 |
+
<key>AvailableLibraries</key>
|
6 |
+
<array>
|
7 |
+
<dict>
|
8 |
+
<key>LibraryIdentifier</key>
|
9 |
+
<string>ios-arm64_x86_64-simulator</string>
|
10 |
+
<key>LibraryPath</key>
|
11 |
+
<string>onnxruntime.a</string>
|
12 |
+
<key>SupportedArchitectures</key>
|
13 |
+
<array>
|
14 |
+
<string>arm64</string>
|
15 |
+
<string>x86_64</string>
|
16 |
+
</array>
|
17 |
+
<key>SupportedPlatform</key>
|
18 |
+
<string>ios</string>
|
19 |
+
<key>SupportedPlatformVariant</key>
|
20 |
+
<string>simulator</string>
|
21 |
+
</dict>
|
22 |
+
<dict>
|
23 |
+
<key>LibraryIdentifier</key>
|
24 |
+
<string>ios-arm64</string>
|
25 |
+
<key>LibraryPath</key>
|
26 |
+
<string>onnxruntime.a</string>
|
27 |
+
<key>SupportedArchitectures</key>
|
28 |
+
<array>
|
29 |
+
<string>arm64</string>
|
30 |
+
</array>
|
31 |
+
<key>SupportedPlatform</key>
|
32 |
+
<string>ios</string>
|
33 |
+
</dict>
|
34 |
+
</array>
|
35 |
+
<key>CFBundlePackageType</key>
|
36 |
+
<string>XFWK</string>
|
37 |
+
<key>XCFrameworkFormatVersion</key>
|
38 |
+
<string>1.0</string>
|
39 |
+
</dict>
|
40 |
+
</plist>
|
1.15.1/onnxruntime.xcframework/ios-arm64/onnxruntime.a
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:bd716418bdd0b9b7df2d65701d075c7698aa54bcdfe811c360c79fa61e7f3e3b
|
3 |
+
size 57978208
|
1.15.1/onnxruntime.xcframework/ios-arm64_x86_64-simulator/onnxruntime.a
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ae51a98ef737bcfade20f599fb589141ed1e98239f41cd70c59658ef828dfd14
|
3 |
+
size 118264080
|