|
|
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <fstream>
#include <iostream>
#include <numeric>
#include <string>
#include <vector>

#include <onnxruntime_cxx_api.h>

#include <Eigen/Dense>

#include <kiss_fft.h>
#include <kiss_fftr.h>
|
|
|
|
|
|
|
|
#ifndef M_PI |
|
|
#define M_PI 3.14159265358979323846 |
|
|
#endif |
|
|
|
|
|
|
|
|
// --- Feature-extraction parameters (speechlib-style front end) ---

// Pre-emphasis filter coefficient applied per frame before the FFT.
constexpr float PREEMPHASIS_COEFF = 0.97f;
// FFT size; frames are zero-padded from WIN_LENGTH up to N_FFT.
constexpr int N_FFT = 512;
// Analysis window length in samples (25 ms at 16 kHz).
constexpr int WIN_LENGTH = 400;
// Hop between successive frames in samples (10 ms at 16 kHz).
constexpr int HOP_LENGTH = 160;
// Number of Mel filterbank channels produced per frame.
constexpr int N_MELS = 80;
// Sample rate the pipeline assumes; no resampling is performed anywhere.
constexpr int TARGET_SAMPLE_RATE = 16000;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Pack the on-disk WAV structures byte-for-byte (no padding) so they can be
// read from / written to the file with a single read()/write().
#pragma pack(push, 1)

// Canonical RIFF/WAVE header as it appears at the start of a simple PCM WAV
// file: RIFF chunk preamble + "WAVE" tag + a 16-byte "fmt " chunk body.
struct WavHeader {
    char riff_id[4];        // "RIFF"
    uint32_t file_size;     // total file size minus the 8-byte RIFF preamble
    char wave_id[4];        // "WAVE"
    char fmt_id[4];         // "fmt "
    uint32_t fmt_size;      // size of the fmt chunk body (16 for plain PCM)
    uint16_t audio_format;  // 1 = PCM
    uint16_t num_channels;  // 1 = mono, 2 = stereo
    uint32_t sample_rate;   // frames per second
    uint32_t byte_rate;     // sample_rate * num_channels * bits_per_sample / 8
    uint16_t block_align;   // num_channels * bits_per_sample / 8
    uint16_t bits_per_sample;
};

// Generic RIFF sub-chunk header; used to scan the file for the "data" chunk.
struct WavDataChunk {
    char data_id[4];     // chunk tag, e.g. "data"
    uint32_t data_size;  // chunk payload size in bytes
};

#pragma pack(pop)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
std::vector<float> loadWavToFloatArray(const std::string& filename, int& actual_sample_rate) { |
|
|
std::ifstream file(filename, std::ios::binary); |
|
|
if (!file.is_open()) { |
|
|
std::cerr << "Error: Could not open WAV file: " << filename << std::endl; |
|
|
return {}; |
|
|
} |
|
|
|
|
|
WavHeader header; |
|
|
file.read(reinterpret_cast<char*>(&header), sizeof(WavHeader)); |
|
|
|
|
|
|
|
|
if (std::string(header.riff_id, 4) != "RIFF" || |
|
|
std::string(header.wave_id, 4) != "WAVE" || |
|
|
std::string(header.fmt_id, 4) != "fmt ") { |
|
|
std::cerr << "Error: Invalid WAV header (RIFF, WAVE, or fmt chunk missing/invalid)." << std::endl; |
|
|
file.close(); |
|
|
return {}; |
|
|
} |
|
|
|
|
|
if (header.audio_format != 1) { |
|
|
std::cerr << "Error: Only PCM audio format (1) is supported. Found: " << header.audio_format << std::endl; |
|
|
file.close(); |
|
|
return {}; |
|
|
} |
|
|
|
|
|
if (header.bits_per_sample != 16) { |
|
|
std::cerr << "Error: Only 16-bit PCM is supported. Found: " << header.bits_per_sample << " bits per sample." << std::endl; |
|
|
file.close(); |
|
|
return {}; |
|
|
} |
|
|
|
|
|
actual_sample_rate = header.sample_rate; |
|
|
std::cout << "WAV file info: Sample Rate=" << header.sample_rate |
|
|
<< ", Channels=" << header.num_channels |
|
|
<< ", Bit Depth=" << header.bits_per_sample << std::endl; |
|
|
|
|
|
|
|
|
WavDataChunk data_chunk; |
|
|
bool data_chunk_found = false; |
|
|
while (!file.eof()) { |
|
|
file.read(reinterpret_cast<char*>(&data_chunk.data_id), 4); |
|
|
file.read(reinterpret_cast<char*>(&data_chunk.data_size), 4); |
|
|
|
|
|
if (std::string(data_chunk.data_id, 4) == "data") { |
|
|
data_chunk_found = true; |
|
|
break; |
|
|
} else { |
|
|
|
|
|
file.seekg(data_chunk.data_size, std::ios::cur); |
|
|
} |
|
|
} |
|
|
|
|
|
if (!data_chunk_found) { |
|
|
std::cerr << "Error: 'data' chunk not found in WAV file." << std::endl; |
|
|
file.close(); |
|
|
return {}; |
|
|
} |
|
|
|
|
|
std::vector<float> audioData; |
|
|
int16_t sample_buffer; |
|
|
long num_samples_to_read = data_chunk.data_size / sizeof(int16_t); |
|
|
|
|
|
for (long i = 0; i < num_samples_to_read; ++i) { |
|
|
file.read(reinterpret_cast<char*>(&sample_buffer), sizeof(int16_t)); |
|
|
float normalized_sample = static_cast<float>(sample_buffer) / 32768.0f; |
|
|
|
|
|
if (header.num_channels == 1) { |
|
|
audioData.push_back(normalized_sample); |
|
|
} else if (header.num_channels == 2) { |
|
|
|
|
|
|
|
|
int16_t right_sample; |
|
|
if (file.read(reinterpret_cast<char*>(&right_sample), sizeof(int16_t))) { |
|
|
float normalized_right_sample = static_cast<float>(right_sample) / 32768.0f; |
|
|
audioData.push_back((normalized_sample + normalized_right_sample) / 2.0f); |
|
|
i++; |
|
|
} else { |
|
|
std::cerr << "Warning: Unexpected end of file while reading stereo data." << std::endl; |
|
|
break; |
|
|
} |
|
|
} else { |
|
|
std::cerr << "Error: Unsupported number of channels: " << header.num_channels << std::endl; |
|
|
file.close(); |
|
|
return {}; |
|
|
} |
|
|
} |
|
|
|
|
|
file.close(); |
|
|
return audioData; |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Builds a symmetric Hamming window of the given length.
 *
 * w[i] = 0.54 - 0.46 * cos(2*pi*i / (N - 1)); endpoints evaluate to 0.08.
 *
 * @param window_length Number of taps; negative values yield an empty window.
 * @return Window coefficients. Length 1 returns {1.0f} explicitly, avoiding
 *         the 0/0 (NaN) the closed-form expression would produce.
 */
std::vector<float> generateHammingWindow(int window_length) {
    std::vector<float> window(static_cast<size_t>(std::max(window_length, 0)));
    if (window_length == 1) {
        window[0] = 1.0f;  // degenerate case: the formula divides by N-1 == 0
        return window;
    }
    // Local constant instead of the M_PI macro (not guaranteed by <cmath>).
    constexpr double kTwoPi = 2.0 * 3.14159265358979323846;
    for (int i = 0; i < window_length; ++i) {
        window[i] = 0.54f - 0.46f * std::cos(kTwoPi * i / static_cast<float>(window_length - 1));
    }
    return window;
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Computes a magnitude spectrogram compatible with the speechlib front end.
 *
 * Frames the waveform (WIN_LENGTH window, HOP_LENGTH hop), applies per-frame
 * pre-emphasis on samples rescaled to int16 range, multiplies by a Hamming
 * window, zero-pads to N_FFT and takes a real FFT per frame.
 *
 * @param wav Mono waveform, samples normalized to [-1, 1).
 * @param fs  Sample rate in Hz (currently unused; kept for interface symmetry).
 * @return Matrix of shape [n_frames, N_FFT/2 + 1] of FFT magnitudes, or a
 *         0-row matrix when the input is too short or FFT setup fails.
 */
Eigen::MatrixXf extractSpectrogram(const std::vector<float>& wav, int fs) {
    // Guard BEFORE subtracting: wav.size() is unsigned, so for short inputs
    // wav.size() - WIN_LENGTH would wrap to a huge value and the old
    // "n_batch <= 0" check could never fire (out-of-bounds frame reads).
    if (wav.size() < static_cast<size_t>(WIN_LENGTH)) {
        std::cerr << "Warning: Input waveform too short for feature extraction. Returning empty spectrogram." << std::endl;
        return Eigen::MatrixXf(0, N_FFT / 2 + 1);
    }
    int n_batch = static_cast<int>((wav.size() - WIN_LENGTH) / HOP_LENGTH) + 1;

    std::vector<float> fft_window = generateHammingWindow(WIN_LENGTH);

    // 0 = forward real FFT configuration.
    kiss_fftr_cfg fft_cfg = kiss_fftr_alloc(N_FFT, 0, nullptr, nullptr);
    if (!fft_cfg) {
        std::cerr << "Error: Failed to allocate KissFFT configuration." << std::endl;
        return Eigen::MatrixXf(0, N_FFT / 2 + 1);
    }

    Eigen::MatrixXf spec_matrix(n_batch, N_FFT / 2 + 1);

    std::vector<float> frame_buffer(WIN_LENGTH);
    kiss_fft_scalar fft_input[N_FFT];
    kiss_fft_cpx fft_output[N_FFT / 2 + 1];

    for (int i = 0; i < n_batch; ++i) {
        int start_idx = i * HOP_LENGTH;

        for (int j = 0; j < WIN_LENGTH; ++j) {
            frame_buffer[j] = wav[start_idx + j];
        }

        // Pre-emphasis in int16 scale (x * 32768). The first sample uses the
        // NEXT sample as its "previous" value — the speechlib convention.
        // (WIN_LENGTH is a compile-time constant > 1, so this is always safe.)
        fft_input[0] = (frame_buffer[0] - PREEMPHASIS_COEFF * frame_buffer[1]) * 32768.0f;
        for (int j = 1; j < WIN_LENGTH; ++j) {
            fft_input[j] = (frame_buffer[j] - PREEMPHASIS_COEFF * frame_buffer[j - 1]) * 32768.0f;
        }

        // Zero-pad from WIN_LENGTH up to the FFT size.
        for (int j = WIN_LENGTH; j < N_FFT; ++j) {
            fft_input[j] = 0.0f;
        }

        // Apply the Hamming window to the pre-emphasized frame.
        for (int j = 0; j < WIN_LENGTH; ++j) {
            fft_input[j] *= fft_window[j];
        }

        kiss_fftr(fft_cfg, fft_input, fft_output);

        // Magnitude of each of the N_FFT/2 + 1 real-FFT bins.
        for (int j = 0; j <= N_FFT / 2; ++j) {
            spec_matrix(i, j) = std::sqrt(fft_output[j].r * fft_output[j].r + fft_output[j].i * fft_output[j].i);
        }
    }

    kiss_fftr_free(fft_cfg);
    return spec_matrix;
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Builds a speechlib-style Mel filterbank matrix.
 *
 * Triangular filters are placed with centers uniformly spaced on the Mel
 * scale (1127 * ln(1 + f/700)) between fmin and fmax.
 *
 * @param sample_rate Audio sample rate in Hz.
 * @param n_fft       FFT size; the bank spans n_fft/2 + 1 bins.
 * @param n_mels      Number of Mel channels (rows of the result).
 * @param fmin        Lower edge in Hz.
 * @param fmax        Upper edge in Hz; 0 means Nyquist (sample_rate / 2).
 * @return Matrix of shape [n_mels, n_fft/2 + 1] with triangle weights.
 */
Eigen::MatrixXf speechlibMel(int sample_rate, int n_fft, int n_mels, float fmin, float fmax) {
    int bank_width = n_fft / 2 + 1;
    if (fmax == 0.0f) fmax = sample_rate / 2.0f;

    // Hz -> Mel, FFT-bin -> Mel, and Hz -> nearest FFT bin.
    auto mel = [](float f) { return 1127.0f * std::log(1.0f + f / 700.0f); };
    auto bin2mel = [&](int fft_bin) { return 1127.0f * std::log(1.0f + static_cast<float>(fft_bin) * sample_rate / (static_cast<float>(n_fft) * 700.0f)); };
    auto f2bin = [&](float f) { return static_cast<int>((f * n_fft / sample_rate) + 0.5f); };

    // First FFT bin that can receive any weight (strictly above fmin).
    int klo = f2bin(fmin) + 1;

    float mlo = mel(fmin);
    float mhi = mel(fmax);

    // n_mels + 2 uniformly spaced Mel points: filter m spans the open
    // interval (m_centers[m], m_centers[m+2]) with its peak at m_centers[m+1].
    std::vector<float> m_centers(n_mels + 2);
    float ms = (mhi - mlo) / (n_mels + 1);
    for (int i = 0; i < n_mels + 2; ++i) {
        m_centers[i] = mlo + i * ms;
    }

    Eigen::MatrixXf matrix = Eigen::MatrixXf::Zero(n_mels, bank_width);

    for (int m = 0; m < n_mels; ++m) {
        float left = m_centers[m];
        float center = m_centers[m + 1];
        float right = m_centers[m + 2];
        for (int fft_bin = klo; fft_bin < bank_width; ++fft_bin) {
            float mbin = bin2mel(fft_bin);
            if (left < mbin && mbin < right) {
                // Triangle rising from 0 at `left` to 1 at `center`, back to 0 at `right`.
                matrix(m, fft_bin) = 1.0f - std::abs(center - mbin) / ms;
            }
        }
    }
    return matrix;
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Turns a waveform into log-Mel filterbank features.
 *
 * Pipeline: magnitude spectrogram -> power -> Mel filterbank projection ->
 * floor at 1.0 -> natural log.
 *
 * @param wav            Mono waveform, samples normalized to [-1, 1).
 * @param fs             Sample rate in Hz (forwarded to extractSpectrogram).
 * @param mel_filterbank Filterbank of shape [N_MELS, N_FFT/2 + 1].
 * @return Matrix of shape [n_frames, N_MELS]; 0 rows if the input was too short.
 */
Eigen::MatrixXf extractFeatures(const std::vector<float>& wav, int fs, const Eigen::MatrixXf& mel_filterbank) {
    const Eigen::MatrixXf magnitudes = extractSpectrogram(wav, fs);
    if (magnitudes.rows() == 0) {
        // Propagate the "too short" case as an empty feature matrix.
        return Eigen::MatrixXf(0, N_MELS);
    }

    // Power spectrum projected onto the Mel bank: [frames, bins] x [bins, mels].
    const Eigen::MatrixXf mel_energy =
        magnitudes.array().square().matrix() * mel_filterbank.transpose();

    // Clamp to >= 1 so the log is never negative or -inf, then take the log.
    return mel_energy.array().max(1.0f).log().matrix();
}
|
|
|
|
|
|
|
|
/**
 * Writes a small 16-bit PCM WAV file containing a 440 Hz sine tone.
 *
 * Used to synthesize test input when no WAV file is available.
 *
 * @param filename        Output path.
 * @param sampleRate      Sample rate in Hz; must be > 0.
 * @param numChannels     Channel count; the same tone is written to each.
 * @param bitsPerSample   Must be 16 — the sample data is emitted as int16.
 * @param durationSeconds Tone duration in seconds; must be >= 0.
 */
void createDummyWavFile(const std::string& filename, int sampleRate, int numChannels, int bitsPerSample, double durationSeconds) {
    // The generator below always emits int16 samples, so any other bit depth
    // would write a header that contradicts the data that follows it.
    if (bitsPerSample != 16) {
        std::cerr << "Error: createDummyWavFile only supports 16 bits per sample." << std::endl;
        return;
    }
    if (numChannels < 1 || sampleRate <= 0 || durationSeconds < 0.0) {
        std::cerr << "Error: Invalid parameters for dummy WAV file." << std::endl;
        return;
    }

    std::ofstream file(filename, std::ios::binary);
    if (!file.is_open()) {
        std::cerr << "Error: Could not create dummy WAV file: " << filename << std::endl;
        return;
    }

    WavHeader header;
    std::memcpy(header.riff_id, "RIFF", 4);
    std::memcpy(header.wave_id, "WAVE", 4);
    std::memcpy(header.fmt_id, "fmt ", 4);
    header.fmt_size = 16;     // plain PCM fmt chunk
    header.audio_format = 1;  // PCM
    header.num_channels = numChannels;
    header.sample_rate = sampleRate;
    header.bits_per_sample = bitsPerSample;
    header.byte_rate = (sampleRate * numChannels * bitsPerSample) / 8;
    header.block_align = (numChannels * bitsPerSample) / 8;

    WavDataChunk data_chunk;
    std::memcpy(data_chunk.data_id, "data", 4);
    uint32_t num_samples = static_cast<uint32_t>(sampleRate * durationSeconds);
    data_chunk.data_size = num_samples * numChannels * (bitsPerSample / 8);
    // 36 = 44-byte canonical header minus the 8-byte RIFF preamble.
    header.file_size = 36 + data_chunk.data_size;

    file.write(reinterpret_cast<const char*>(&header), sizeof(WavHeader));
    file.write(reinterpret_cast<const char*>(&data_chunk), sizeof(WavDataChunk));

    // 440 Hz sine at ~92% of full scale, duplicated across all channels.
    // Local constant instead of the M_PI macro (not guaranteed by <cmath>).
    constexpr double kPi = 3.14159265358979323846;
    for (uint32_t i = 0; i < num_samples; ++i) {
        int16_t sample = static_cast<int16_t>(30000 * std::sin(2 * kPi * 440 * i / static_cast<double>(sampleRate)));
        for (int c = 0; c < numChannels; ++c) {
            file.write(reinterpret_cast<const char*>(&sample), sizeof(int16_t));
        }
    }

    file.close();
    std::cout << "Dummy WAV file '" << filename << "' created successfully." << std::endl;
}
|
|
|
|
|
|
|
|
// Entry point: <model.onnx> <audio.wav> -> log-Mel features -> ONNX inference.
int main(int argc, char* argv[]) {

    // --- Command-line handling ---
    if (argc != 3) {
        std::cerr << "Usage: " << argv[0] << " <path_to_onnx_model> <path_to_wav_file>" << std::endl;
        std::cerr << "Example: " << argv[0] << " model.onnx audio.wav" << std::endl;
        return 1;
    }

    std::string onnxModelPath = argv[1];
    std::string wavFilename = argv[2];

    int actual_wav_sample_rate = 0;

    // If the WAV file is missing, synthesize a 2-second mono test tone so the
    // rest of the pipeline can still be demonstrated.
    std::ifstream wavCheck(wavFilename, std::ios::binary);
    if (!wavCheck.is_open()) {
        std::cerr << "WAV file '" << wavFilename << "' not found. Creating a dummy one for demonstration." << std::endl;
        createDummyWavFile(wavFilename, TARGET_SAMPLE_RATE, 1, 16, 2.0);
    } else {
        wavCheck.close();
    }

    // --- Load audio as normalized mono floats ---
    std::vector<float> audioWav = loadWavToFloatArray(wavFilename, actual_wav_sample_rate);

    if (audioWav.empty()) {
        std::cerr << "Failed to load audio data from " << wavFilename << ". Exiting." << std::endl;
        return 1;
    }

    std::cout << "Successfully loaded " << audioWav.size() << " samples from " << wavFilename << std::endl;

    // No resampling is implemented; only warn when the rates differ.
    if (actual_wav_sample_rate != TARGET_SAMPLE_RATE) {
        std::cerr << "Warning: WAV file sample rate (" << actual_wav_sample_rate
                  << " Hz) does not match the target sample rate for feature extraction ("
                  << TARGET_SAMPLE_RATE << " Hz)." << std::endl;
        std::cerr << "This example does NOT include resampling. Features will be extracted at "
                  << TARGET_SAMPLE_RATE << " Hz, which might lead to incorrect results if the WAV file's sample rate is different." << std::endl;
    }

    // --- Build the Mel filterbank ---
    // Upper edge kept below Nyquist; the 80/230 Hz offsets presumably mirror
    // the Python front end this port follows — TODO confirm against it.
    float mel_fmax = static_cast<float>(TARGET_SAMPLE_RATE) / 2.0f - 80.0f - 230.0f;
    Eigen::MatrixXf mel_filterbank = speechlibMel(TARGET_SAMPLE_RATE, N_FFT, N_MELS, 0.0f, mel_fmax);

    if (mel_filterbank.rows() == 0 || mel_filterbank.cols() == 0) {
        std::cerr << "Error: Failed to create Mel filterbank. Exiting." << std::endl;
        return 1;
    }
    std::cout << "Mel filterbank created with shape: [" << mel_filterbank.rows() << ", " << mel_filterbank.cols() << "]" << std::endl;

    // --- Feature extraction: [frames, N_MELS] log-Mel matrix ---
    std::cout << "Extracting features from audio..." << std::endl;
    Eigen::MatrixXf features = extractFeatures(audioWav, TARGET_SAMPLE_RATE, mel_filterbank);

    if (features.rows() == 0 || features.cols() == 0) {
        std::cerr << "Error: Feature extraction resulted in an empty matrix. Exiting." << std::endl;
        return 1;
    }
    std::cout << "Features extracted with shape: [" << features.rows() << ", " << features.cols() << "]" << std::endl;
    std::cout << "First few feature values (first frame): [";
    for (int i = 0; i < std::min((int)features.cols(), 5); ++i) {
        std::cout << features(0, i) << (i == std::min((int)features.cols(), 5) - 1 ? "" : ", ");
    }
    std::cout << "]" << std::endl;

    // --- Verify the ONNX model exists; if not, print a Python recipe for
    // generating a dummy one and exit. ---
    std::ifstream onnxModelCheck(onnxModelPath, std::ios::binary);
    if (!onnxModelCheck.is_open()) {
        std::cerr << "\nError: ONNX model file '" << onnxModelPath << "' not found." << std::endl;
        std::cerr << "Please provide a valid ONNX model file. If you need a simple dummy one for testing, "
                  << "you can create it using Python (e.g., with PyTorch) like this:" << std::endl;
        std::cerr << "```python" << std::endl;
        std::cerr << "import torch" << std::endl;
        std::cerr << "import torch.nn as nn" << std::endl;
        std::cerr << "" << std::endl;
        std::cerr << "class SimpleAudioModel(nn.Module):" << std::endl;
        std::cerr << " def __init__(self, input_frames, feature_size, output_size):" << std::endl;
        std::cerr << " super(SimpleAudioModel, self).__init__()" << std::endl;
        std::cerr << " # This model expects input of shape [batch_size, frames, feature_size]" << std::endl;
        std::cerr << " # Example: a simple linear layer that flattens input and processes it." << std::endl;
        std::cerr << " self.flatten = nn.Flatten()" << std::endl;
        std::cerr << " self.linear = nn.Linear(input_frames * feature_size, output_size)" << std::endl;
        std::cerr << "" << std::endl;
        std::cerr << " def forward(self, x):" << std::endl;
        std::cerr << " x = self.flatten(x)" << std::endl;
        std::cerr << " return self.linear(x)" << std::endl;
        std::cerr << "" << std::endl;
        std::cerr << "# --- IMPORTANT: Define model input and output sizes. Adjust these to match your actual model's requirements. ---" << std::endl;
        std::cerr << "# The C++ preprocessor will produce features of shape [frames, 80]." << std::endl;
        std::cerr << "# For a dummy model, we need to provide a fixed 'frames' value for ONNX export." << std::endl;
        std::cerr << "# A typical audio segment might be 2 seconds at 16kHz, which is 32000 samples." << std::endl;
        std::cerr << "# Frames = (32000 - 400) / 160 + 1 = 198.75 + 1 = 199 frames (approx)" << std::endl;
        std::cerr << "# Let's use a representative number of frames, e.g., 200 for a dummy input." << std::endl;
        std::cerr << "DUMMY_INPUT_FRAMES = 200 # This should be representative of your typical audio segment's frames" << std::endl;
        std::cerr << "DUMMY_FEATURE_SIZE = 80 # Fixed by the Mel filterbank (N_MELS)" << std::endl;
        std::cerr << "DUMMY_OUTPUT_SIZE = 10 # Example: 10 classification scores or features" << std::endl;
        std::cerr << "" << std::endl;
        std::cerr << "model = SimpleAudioModel(DUMMY_INPUT_FRAMES, DUMMY_FEATURE_SIZE, DUMMY_OUTPUT_SIZE)" << std::endl;
        std::cerr << "dummy_input_tensor = torch.randn(1, DUMMY_INPUT_FRAMES, DUMMY_FEATURE_SIZE) # Batch size 1" << std::endl;
        std::cerr << "" << std::endl;
        std::cerr << "torch.onnx.export(" << std::endl;
        std::cerr << " model," << std::endl;
        std::cerr << " dummy_input_tensor," << std::endl;
        std::cerr << " \"model.onnx\"," << std::endl;
        std::cerr << " verbose=True," << std::endl;
        std::cerr << " input_names=['input'], # Name of the input tensor in the ONNX graph" << std::endl;
        std::cerr << " output_names=['output'], # Name of the output tensor in the ONNX graph" << std::endl;
        std::cerr << " # Define dynamic axes for batch_size and frames" << std::endl;
        std::cerr << " dynamic_axes={'input': {0: 'batch_size', 1: 'frames'}, 'output': {0: 'batch_size'}}" << std::endl;
        std::cerr << ")" << std::endl;
        std::cerr << "print(\"Dummy model.onnx created successfully. Remember to adjust DUMMY_INPUT_FRAMES in this script to match the expected number of frames from your audio segments.\")" << std::endl;
        std::cerr << "```" << std::endl;
        return 1;
    }
    onnxModelCheck.close();
    std::cout << "ONNX model '" << onnxModelPath << "' found. Proceeding with inference." << std::endl;

    // --- ONNX Runtime inference ---
    try {
        Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "AudioInference");
        Ort::SessionOptions session_options;
        session_options.SetIntraOpNumThreads(1);
        session_options.SetGraphOptimizationLevel(ORT_ENABLE_EXTENDED);

        Ort::Session session(env, onnxModelPath.c_str(), session_options);
        Ort::AllocatorWithDefaultOptions allocator;

        // Input metadata. NOTE(review): the input name is hard-coded to
        // "audio_embeds" rather than queried from the session — verify it
        // matches the actual graph.
        size_t numInputNodes = session.GetInputCount();
        std::vector<const char*> inputNodeNames(numInputNodes);

        std::cout << "\n--- Model Input Information ---" << std::endl;
        if (numInputNodes == 0) {
            std::cerr << "Error: Model has no input nodes. Exiting." << std::endl;
            return 1;
        }

        inputNodeNames[0] = "audio_embeds";
        Ort::TypeInfo type_info = session.GetInputTypeInfo(0);
        auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
        std::vector<int64_t> actualInputShape = tensor_info.GetShape();

        std::cout << " Input 0 : Name='" << inputNodeNames[0] << "', Shape=[";
        for (size_t j = 0; j < actualInputShape.size(); ++j) {
            // -1 marks a dynamic dimension in the ONNX shape.
            if (actualInputShape[j] == -1) {
                std::cout << "-1";
            } else {
                std::cout << actualInputShape[j];
            }
            std::cout << (j == actualInputShape.size() - 1 ? "" : ", ");
        }
        std::cout << "]" << std::endl;

        // Actual input tensor shape: [batch=1, frames, feature_size].
        std::vector<int64_t> inputTensorShape = {1, features.rows(), features.cols()};
        std::cout << " Preparing input tensor with shape: [" << inputTensorShape[0] << ", "
                  << inputTensorShape[1] << ", " << inputTensorShape[2] << "]" << std::endl;

        // Copy the feature matrix into a contiguous row-major buffer
        // (Eigen's default storage is column-major).
        std::vector<float> inputTensorData(features.rows() * features.cols());
        for (int r = 0; r < features.rows(); ++r) {
            for (int c = 0; c < features.cols(); ++c) {
                inputTensorData[r * features.cols() + c] = features(r, c);
            }
        }

        // Wrap the buffer (no copy) in an ONNX Runtime tensor; the buffer
        // must outlive the tensor, which it does here.
        Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
        Ort::Value inputTensor = Ort::Value::CreateTensor<float>(memory_info, inputTensorData.data(), inputTensorData.size(),
                                                                 inputTensorShape.data(), inputTensorShape.size());

        if (!inputTensor.IsTensor()) {
            std::cerr << "Error: Created input tensor is not valid! Exiting." << std::endl;
            return 1;
        }

        // Output metadata. NOTE(review): every output slot is given the same
        // hard-coded name "audio_features" — verify against the model.
        size_t numOutputNodes = session.GetOutputCount();
        std::vector<const char*> outputNodeNames(numOutputNodes);

        std::cout << "\n--- Model Output Information ---" << std::endl;
        for (size_t k = 0; k < numOutputNodes; ++k) {
            outputNodeNames[k] = "audio_features";
            Ort::TypeInfo type_info_out = session.GetOutputTypeInfo(k);
            auto tensor_info_out = type_info_out.GetTensorTypeAndShapeInfo();
            std::vector<int64_t> outputShape = tensor_info_out.GetShape();
            std::cout << " Output " << k << " : Name='" << outputNodeNames[k] << "', Shape=[";
            for (size_t l = 0; l < outputShape.size(); ++l) {
                if (outputShape[l] == -1) {
                    std::cout << "-1";
                } else {
                    std::cout << outputShape[l];
                }
                std::cout << (l == outputShape.size() - 1 ? "" : ", ");
            }
            std::cout << "]" << std::endl;
        }

        // --- Run inference: one input tensor, all declared outputs ---
        std::cout << "\nRunning ONNX model inference..." << std::endl;
        std::vector<Ort::Value> outputTensors = session.Run(Ort::RunOptions{nullptr},
                                                            inputNodeNames.data(), &inputTensor, 1,
                                                            outputNodeNames.data(), numOutputNodes);

        if (outputTensors.empty()) {
            std::cerr << "Error: No output tensors received from the model." << std::endl;
            return 1;
        }

        // Print a short preview and the shape of the first output tensor.
        if (outputTensors[0].IsTensor()) {
            float* outputData = outputTensors[0].GetTensorMutableData<float>();
            Ort::TensorTypeAndShapeInfo outputShapeInfo = outputTensors[0].GetTensorTypeAndShapeInfo();
            std::vector<int64_t> outputShape = outputShapeInfo.GetShape();
            size_t outputSize = outputShapeInfo.GetElementCount();

            std::cout << "\n--- Model Inference Result (first few elements) ---" << std::endl;
            for (size_t k = 0; k < std::min((size_t)10, outputSize); ++k) {
                std::cout << outputData[k] << (k == std::min((size_t)10, outputSize) - 1 ? "" : ", ");
            }
            std::cout << std::endl;

            std::cout << "Full output tensor size: " << outputSize << " elements." << std::endl;
            std::cout << "Full output tensor shape: [";
            for (size_t k = 0; k < outputShape.size(); ++k) {
                std::cout << outputShape[k] << (k == outputShape.size() - 1 ? "" : ", ");
            }
            std::cout << "]" << std::endl;
        } else {
            std::cerr << "Error: First output tensor is not of the expected type (float tensor)." << std::endl;
        }

    } catch (const Ort::Exception& e) {
        std::cerr << "ONNX Runtime Exception: " << e.what() << std::endl;
        return 1;
    } catch (const std::exception& e) {
        std::cerr << "Standard Exception: " << e.what() << std::endl;
        return 1;
    }

    std::cout << "\nProgram finished successfully." << std::endl;
    return 0;
}