Upload 4 files
Browse files- README.md +0 -3
- convert.py +23 -0
- main.cpp +79 -0
- xg_runtime_api.h +138 -0
README.md
CHANGED
@@ -1,3 +0,0 @@
|
|
1 |
-
---
|
2 |
-
license: mit
|
3 |
-
---
|
|
|
|
|
|
|
|
convert.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from onnx import numpy_helper
|
2 |
+
import numpy as np
|
3 |
+
import onnx
|
4 |
+
import ffmpeg
|
5 |
+
import argparse
|
6 |
+
|
7 |
+
# Parameter settings
|
8 |
+
parser = argparse.ArgumentParser(description='Whisper format converter')
|
9 |
+
parser.add_argument('--ipath', metavar='S', help='path to the input file')
|
10 |
+
parser.add_argument('--opath', metavar='S', help='path to the output file (.pb extension)')
|
11 |
+
args = parser.parse_args()
|
12 |
+
|
13 |
+
if __name__ == '__main__':
|
14 |
+
|
15 |
+
out, _ = (
|
16 |
+
ffmpeg.input(args.ipath, threads=0)
|
17 |
+
.output("-", format="s16le", acodec="pcm_s16le", ac=1, ar=16000)
|
18 |
+
.run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True)
|
19 |
+
)
|
20 |
+
audio = np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
|
21 |
+
|
22 |
+
onnx_tp = numpy_helper.from_array(audio, 'raw_audio')
|
23 |
+
onnx.save_tensor(onnx_tp, args.opath)
|
main.cpp
ADDED
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#include <iostream>
|
2 |
+
#include "xg_runtime_api.h"
|
3 |
+
|
4 |
+
|
5 |
+
void test_whisper(const std::string& weight_path, const std::string& input_path);
|
6 |
+
|
7 |
+
int main(int argc, char** argv) {
|
8 |
+
|
9 |
+
if (argc == 3)
|
10 |
+
{
|
11 |
+
std::string weight_path = argv[1];
|
12 |
+
std::string input_path = argv[2];
|
13 |
+
test_whisper(weight_path, input_path);
|
14 |
+
}
|
15 |
+
|
16 |
+
return 0;
|
17 |
+
}
|
18 |
+
|
19 |
+
void test_whisper(const std::string& weight_path, const std::string& input_path)
|
20 |
+
{
|
21 |
+
XgModelInfo minfo = {};
|
22 |
+
xg_get_model_info(&minfo);
|
23 |
+
std::cout << minfo.model_name << " " << minfo.model_version << std::endl;
|
24 |
+
|
25 |
+
std::cout << "initing graph" << std::endl;
|
26 |
+
XgGraph* graph = nullptr;
|
27 |
+
if (xg_init_graph(weight_path, XGWeightSource::XG_ONNX, &graph) != XGResult::XG_SUCCESS)
|
28 |
+
{
|
29 |
+
std::cout << "Graph init error" << std::endl;
|
30 |
+
return;
|
31 |
+
}
|
32 |
+
else
|
33 |
+
{
|
34 |
+
std::cout << "Graph init: successful" << std::endl;
|
35 |
+
}
|
36 |
+
|
37 |
+
XgData* input_data = nullptr;
|
38 |
+
if (xg_allocate_input_compatible_data(0, &input_data) != XGResult::XG_SUCCESS)
|
39 |
+
{
|
40 |
+
std::cout << "Input allocation error" << std::endl;
|
41 |
+
return;
|
42 |
+
}
|
43 |
+
else
|
44 |
+
{
|
45 |
+
std::cout << "Input allocation: successful" << std::endl;
|
46 |
+
}
|
47 |
+
|
48 |
+
// load the data into XgData
|
49 |
+
reinterpret_cast<std::string*>(input_data->raw_data)[0] = input_path;
|
50 |
+
|
51 |
+
if (xg_set_input_data(graph, 0, input_data) != XGResult::XG_SUCCESS)
|
52 |
+
{
|
53 |
+
std::cout << "Input data set error" << std::endl;
|
54 |
+
return;
|
55 |
+
}
|
56 |
+
else
|
57 |
+
{
|
58 |
+
std::cout << "Input data set: successful" << std::endl;
|
59 |
+
}
|
60 |
+
|
61 |
+
// execute the graph
|
62 |
+
xg_execute_graph(graph);
|
63 |
+
|
64 |
+
// write output
|
65 |
+
XgData* output_data = nullptr;
|
66 |
+
if (xg_get_output_data(graph, 0, &output_data) != XGResult::XG_SUCCESS)
|
67 |
+
{
|
68 |
+
std::cout << "Getting output error" << std::endl;
|
69 |
+
return;
|
70 |
+
}
|
71 |
+
else
|
72 |
+
{
|
73 |
+
std::cout << "Getting output: successful" << std::endl;
|
74 |
+
}
|
75 |
+
|
76 |
+
// print output
|
77 |
+
std::string* o1 = reinterpret_cast<std::string*>(output_data->raw_data);
|
78 |
+
std::cout << o1[0] << std::endl;
|
79 |
+
}
|
xg_runtime_api.h
ADDED
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#ifndef __XG_RUNTIME_API__
|
2 |
+
#define __XG_RUNTIME_API__
|
3 |
+
|
4 |
+
#include <vector>
|
5 |
+
#include <string>
|
6 |
+
|
7 |
+
#if _WIN32
|
8 |
+
#define XG_API extern "C" __declspec(dllexport)
|
9 |
+
#elif __unix__ || __linux__
|
10 |
+
#define XG_API extern "C"
|
11 |
+
#endif
|
12 |
+
|
13 |
+
// XG type definitions
|
14 |
+
|
15 |
+
enum class XGResult
|
16 |
+
{
|
17 |
+
XG_SUCCESS,
|
18 |
+
XG_INPUT_SIZE_MISSMATCH,
|
19 |
+
XG_INPUT_TYPE_MISSMATCH,
|
20 |
+
XG_WRONG_INPUT_INDEX,
|
21 |
+
XG_WRONG_OUTPUT_INDEX,
|
22 |
+
// device related
|
23 |
+
XG_DEVICE_NOT_SUPPORTED,
|
24 |
+
XG_MEMORY_ALLOCATION_FAILED,
|
25 |
+
// weight file access
|
26 |
+
XG_FILE_NOT_FOUND,
|
27 |
+
XG_EXECUTION_FAILED
|
28 |
+
};
|
29 |
+
|
30 |
+
enum class XGWeightSource
|
31 |
+
{
|
32 |
+
XG_ONNX,
|
33 |
+
XG_XGDB
|
34 |
+
};
|
35 |
+
|
36 |
+
enum class XGDataType
|
37 |
+
{
|
38 |
+
XG_BOOL,
|
39 |
+
XG_TOKEN,
|
40 |
+
XG_STRING,
|
41 |
+
XG_UINT8,
|
42 |
+
XG_UINT16,
|
43 |
+
XG_UINT32,
|
44 |
+
XG_UINT64,
|
45 |
+
XG_INT8,
|
46 |
+
XG_INT16,
|
47 |
+
XG_INT32,
|
48 |
+
XG_INT64,
|
49 |
+
XG_BFLOAT16,
|
50 |
+
XG_FLOAT16,
|
51 |
+
XG_FLOAT32,
|
52 |
+
XG_FLOAT64
|
53 |
+
};
|
54 |
+
|
55 |
+
// access information about the contained model
|
56 |
+
struct XgModelInfo
|
57 |
+
{
|
58 |
+
std::string model_name;
|
59 |
+
std::string model_version;
|
60 |
+
std::string device; // cpu, gpu, tpu etc.
|
61 |
+
std::string hardware; // e.g. intel i7 9th gen
|
62 |
+
unsigned int num_inputs;
|
63 |
+
unsigned int num_outputs;
|
64 |
+
};
|
65 |
+
|
66 |
+
XG_API void xg_get_model_info(
|
67 |
+
XgModelInfo* model_info
|
68 |
+
);
|
69 |
+
|
70 |
+
XG_API bool is_current_device_supported(); // may be list the supported devices on this machine
|
71 |
+
|
72 |
+
// create graph
|
73 |
+
struct XgGraph;
|
74 |
+
|
75 |
+
XG_API XGResult xg_init_graph(
|
76 |
+
const std::string& weight_path,
|
77 |
+
const XGWeightSource weight_source,
|
78 |
+
XgGraph** graph
|
79 |
+
);
|
80 |
+
XG_API XGResult xg_execute_graph(
|
81 |
+
XgGraph* graph
|
82 |
+
);
|
83 |
+
XG_API XGResult xg_destroy_graph(
|
84 |
+
XgGraph** graph
|
85 |
+
);
|
86 |
+
|
87 |
+
// set the input to the graph,
|
88 |
+
// query the output
|
89 |
+
|
90 |
+
struct XgData
|
91 |
+
{
|
92 |
+
XGDataType dtype;
|
93 |
+
unsigned int size_in_bytes;
|
94 |
+
unsigned int dimension;
|
95 |
+
unsigned int length;
|
96 |
+
unsigned int* shape;
|
97 |
+
char* raw_data;
|
98 |
+
};
|
99 |
+
|
100 |
+
XG_API unsigned int xg_calculate_tensor_size_in_bytes(
|
101 |
+
const XGDataType dtype,
|
102 |
+
const unsigned int* shape,
|
103 |
+
const unsigned int dimension
|
104 |
+
);
|
105 |
+
XG_API XGResult xg_allocate_input_compatible_data(
|
106 |
+
const unsigned int input_idx,
|
107 |
+
XgData** data
|
108 |
+
);
|
109 |
+
XG_API XGResult xg_destroy_data(
|
110 |
+
XgData** data
|
111 |
+
);
|
112 |
+
XG_API XGResult xg_get_output_data(
|
113 |
+
const XgGraph* graph,
|
114 |
+
const unsigned int output_idx,
|
115 |
+
XgData** data
|
116 |
+
);
|
117 |
+
XG_API XGResult xg_set_input_data(
|
118 |
+
const XgGraph* graph,
|
119 |
+
const unsigned int input_idx,
|
120 |
+
const XgData* data
|
121 |
+
);
|
122 |
+
|
123 |
+
// helper functions
|
124 |
+
XG_API bool xg_is_data_bool(const XgData* data);
|
125 |
+
XG_API bool xg_is_data_uint8(const XgData* data);
|
126 |
+
XG_API bool xg_is_data_uint16(const XgData* data);
|
127 |
+
XG_API bool xg_is_data_uint32(const XgData* data);
|
128 |
+
XG_API bool xg_is_data_uint64(const XgData* data);
|
129 |
+
XG_API bool xg_is_data_int8(const XgData* data);
|
130 |
+
XG_API bool xg_is_data_int16(const XgData* data);
|
131 |
+
XG_API bool xg_is_data_int32(const XgData* data);
|
132 |
+
XG_API bool xg_is_data_int64(const XgData* data);
|
133 |
+
XG_API bool xg_is_data_bfloat16(const XgData* data);
|
134 |
+
XG_API bool xg_is_data_float16(const XgData* data);
|
135 |
+
XG_API bool xg_is_data_float32(const XgData* data);
|
136 |
+
XG_API bool xg_is_data_float64(const XgData* data);
|
137 |
+
|
138 |
+
#endif // __XG_RUNTIME_API__
|