// vad_cpp / vad_onnx / vad_onnx.cpp
// Commit d21d362 by hzeng412: "Duplicate from MoYoYoTech/vad_cpp".
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstring>
#include <iostream>
#include <map>
#include <stdexcept>
#include <string>
#include <vector>
#include "vad_onnx.h"
// Queries the session for its input node names.
// input_names_str receives owning copies of the names; input_names_char receives
// C-string views into those copies (the form Ort::Session::Run takes). The views
// stay valid only while input_names_str is left unmodified by the caller.
static void get_input_names(Ort::Session* session, std::vector<std::string> &input_names_str,
                            std::vector<const char *> &input_names_char) {
    const size_t input_count = session->GetInputCount();
    input_names_str.resize(input_count);
    input_names_char.resize(input_count);

    Ort::AllocatorWithDefaultOptions name_alloc;
    // The allocated name is released at the end of the full-expression,
    // so it is copied into the owning std::string right away.
    for (size_t idx = 0; idx < input_count; ++idx) {
        input_names_str[idx] = session->GetInputNameAllocated(idx, name_alloc).get();
    }

    // All strings are in place and the vector is not resized again,
    // so the c_str() pointers taken here stay stable.
    for (size_t idx = 0; idx < input_count; ++idx) {
        input_names_char[idx] = input_names_str[idx].c_str();
    }
}
// Queries the session for its output node names.
// output_names_ receives owning copies; vad_out_names_ receives C-string views
// into them for Ort::Session::Run. The views remain valid only while
// output_names_ is not modified afterwards.
static void get_output_names(Ort::Session* session, std::vector<std::string> &output_names_,
                             std::vector<const char *> &vad_out_names_) {
    Ort::AllocatorWithDefaultOptions allocator;
    const size_t n_outputs = session->GetOutputCount();

    output_names_.clear();
    output_names_.reserve(n_outputs);
    for (size_t k = 0; k < n_outputs; ++k) {
        auto owned_name = session->GetOutputNameAllocated(k, allocator);
        output_names_.emplace_back(owned_name.get());
    }

    // Build the pointer table only after every std::string exists, so no
    // later insertion can relocate a buffer we already pointed at.
    vad_out_names_.clear();
    vad_out_names_.reserve(n_outputs);
    for (const std::string& name : output_names_) {
        vad_out_names_.push_back(name.c_str());
    }
}
/**
 * Loads the ONNX model and sets up the tensors/buffers used by forward_infer().
 *
 * @param model_path              path of the ONNX model file to load.
 * @param batch_size              batch dimension of the state tensor; when != 1 the
 *                                state is sized {2, batch_size, 128}.
 * @param thread_num              requested thread count (stored in thread_num_).
 * @param threshold               speech-probability threshold (stored in threshold_).
 * @param sampling_rate           input sample rate; stored in `sr` for the model and
 *                                used to choose the context length.
 * @param min_silence_duration_ms converted to a sample count (min_silence_samples_).
 * @param max_speech_duration_s   NOTE(review): not referenced by this constructor —
 *                                confirm whether it is meant to be enforced.
 * @param speech_pad_ms           converted to a sample count (speech_pad_samples_).
 */
VadOnnx::VadOnnx(const std::string& model_path,
                 int batch_size,
                 int thread_num,
                 float threshold,
                 int sampling_rate,
                 int min_silence_duration_ms,
                 float max_speech_duration_s,
                 int speech_pad_ms)
    : batch_size_(batch_size),
      thread_num_(thread_num),
      threshold_(threshold),
      sample_rates_(sampling_rate),
      // Millisecond durations converted to sample counts.
      min_silence_samples_(sampling_rate * min_silence_duration_ms / 1000.0),
      speech_pad_samples_(sampling_rate * speech_pad_ms / 1000.0),
      triggered_(false),
      temp_end_(0),
      current_sample_(0),
      start_(0),
      memory_info(Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU))
{
    // The session must exist before its input/output names can be queried.
    init_onnx_model(model_path);
    get_input_names(session.get(), input_names_, vad_in_names_);
    get_output_names(session.get(), output_names_, vad_out_names_);
    // Single-element sample-rate buffer wrapped as the `sr` tensor on every call.
    sr.resize(1);
    sr[0] = sample_rates_;
    // For batch 1 the header's default state_shape/state_size are kept
    // (presumably {2, 1, 128} / 256 — TODO confirm in vad_onnx.h).
    if (batch_size_ != 1) {
        state_shape = {2, batch_size_, 128};
        state_size = 2 * batch_size_ * 128;
    }
    state_.resize(state_size);
    // Number of trailing samples from the previous window that are prepended
    // to each new chunk: 64 at 16 kHz, 32 for any other rate.
    context_size = (sample_rates_ == 16000) ? 64 : 32;
    context_.resize(context_size);
    effective_window_size = window_size_samples + context_size;
    // Model input shape is [1, effective_window_size].
    // NOTE(review): the leading dimension stays 1 even when batch_size_ != 1,
    // while the state tensor uses batch_size_ — confirm this is intended.
    input_node_shape[0] = 1;
    input_node_shape[1] = effective_window_size;
    // Start from zeroed state/context and cleared streaming counters.
    reset_states();
}
// Defaulted destructor: every resource is released by its member's own
// destructor. The ONNX Runtime session is assigned from std::make_unique in
// init_onnx_model, so it is owned by a smart pointer.
VadOnnx::~VadOnnx() = default;
void VadOnnx::reset_states() {
std::memset(state_.data(), 0, state_.size() * sizeof(float));
std::fill(context_.begin(), context_.end(), 0.0f);
triggered_ = false;
temp_end_ = 0;
current_sample_ = 0;
start_ = 0;
last_sr_ = 0;
last_batch_size_ = 0;
}
/**
 * Runs the model on one audio window and returns its speech probability.
 *
 * The model input is [context_ | data_chunk], zero-padded to
 * effective_window_size samples. After the run, state_ is replaced by the
 * model's second output and context_ by the last context_size input samples,
 * so consecutive calls form a continuous stream.
 *
 * @param data_chunk at most window_size_samples samples (shorter chunks are
 *                   zero-padded); it is only read.
 * @return speech probability from the model's first output.
 * @throws std::invalid_argument if data_chunk is longer than window_size_samples.
 */
float VadOnnx::forward_infer(std::vector<float>& data_chunk) {
    // Concatenate context and the new chunk into one zero-initialised buffer.
    std::vector<float> x_with_context(effective_window_size, 0.0f);
    // Guard the copy below: an oversized chunk would write past the end of
    // x_with_context (heap buffer overflow).
    if (static_cast<std::size_t>(context_size) + data_chunk.size() > x_with_context.size()) {
        throw std::invalid_argument("VadOnnx::forward_infer: data_chunk is longer than window_size_samples");
    }
    std::copy(context_.begin(), context_.end(), x_with_context.begin());
    std::copy(data_chunk.begin(), data_chunk.end(), x_with_context.begin() + context_size);
    // The input tensor wraps this member buffer without copying, so it must stay
    // alive until session->Run() returns.
    input = x_with_context;

    // Wrap the inputs as tensors over the existing buffers.
    Ort::Value input_tensor = Ort::Value::CreateTensor<float>(
        memory_info, input.data(), input.size(), input_node_shape.data(), 2);
    Ort::Value state_tensor = Ort::Value::CreateTensor<float>(
        memory_info, state_.data(), state_.size(), state_shape.data(), 3);
    Ort::Value sr_tensor = Ort::Value::CreateTensor<int64_t>(
        memory_info, sr.data(), 1, sr_shape.data(), 1);

    // Values are paired with vad_in_names_ by position; this assumes the model
    // declares its inputs as (input, state, sr) — TODO confirm against the model.
    ort_inputs.clear();
    ort_inputs.emplace_back(std::move(input_tensor));
    ort_inputs.emplace_back(std::move(state_tensor));
    ort_inputs.emplace_back(std::move(sr_tensor));

    ort_outputs = session->Run(
        Ort::RunOptions{nullptr}, vad_in_names_.data(), ort_inputs.data(),
        ort_inputs.size(), vad_out_names_.data(), vad_out_names_.size());

    // Output 0: speech probability for this window.
    const float speech_prob = ort_outputs[0].GetTensorMutableData<float>()[0];
    // Output 1: updated state, fed back on the next call.
    const float* stateN = ort_outputs[1].GetTensorMutableData<float>();
    std::memcpy(state_.data(), stateN, state_size * sizeof(float));
    // The tail of this input becomes the context for the next window.
    std::copy(x_with_context.end() - context_size, x_with_context.end(), context_.begin());
    return speech_prob;
}
/**
 * Offline detection: splits `audio` into consecutive windows of
 * window_size_samples and returns the model's speech probability for each.
 *
 * The final partial window, if any, is zero-padded in a local buffer. The
 * caller's `audio` is left unmodified (it used to be padded in place).
 * The model state and context carry over from earlier calls; call
 * reset_states() first to score an independent clip.
 *
 * @param audio input samples (only read).
 * @return one probability per window, in order; empty when `audio` is empty.
 */
std::vector<float> VadOnnx::vad_dectect(std::vector<float>& audio) {
    const std::size_t window = static_cast<std::size_t>(window_size_samples);
    std::vector<float> result;
    result.reserve((audio.size() + window - 1) / window);

    // One reusable window buffer instead of a fresh vector per iteration.
    std::vector<float> chunk(window, 0.0f);
    for (std::size_t offset = 0; offset < audio.size(); offset += window) {
        const std::size_t take = std::min(window, audio.size() - offset);
        const auto first = audio.begin() + static_cast<std::ptrdiff_t>(offset);
        const auto copied_end = std::copy(first, first + static_cast<std::ptrdiff_t>(take), chunk.begin());
        // Only the last window can be short; pad its tail with silence.
        std::fill(copied_end, chunk.end(), 0.0f);
        result.push_back(forward_infer(chunk));
    }
    return result;
}
// Folds one window's detection event into the summary for the current call.
// `event` carries at most one of "start" / "end": a start needs
// prob >= threshold_ and an end needs prob < threshold_ - 0.15.
// A new "start" that follows an "end" already recorded in `summary` cancels
// that "end", so the two speech segments are reported as one continuous segment.
static void merge_vad_event(std::map<std::string, double>& summary,
                            const std::map<std::string, double>& event) {
    if (summary.empty()) {
        summary = event;
        return;
    }
    if (event.empty()) {
        return;
    }
    const auto end_it = event.find("end");
    if (end_it != event.end()) {
        summary["end"] = end_it->second;
    }
    if (event.count("start") != 0 && summary.count("end") != 0) {
        summary.erase("end");
    }
}

/**
 * Streaming detection: appends `audio` to the internal buffer and runs the
 * model on every complete window of window_size_samples.
 *
 * Samples that do not fill a whole window stay in buffer_ until the next call.
 * The old behaviour zero-padded and consumed them, which injected silence into
 * the stream and advanced current_sample_ by a full window for fewer real samples.
 *
 * @param audio          new samples (only read).
 * @param return_seconds report boundaries in seconds instead of samples.
 * @return map with optional "start" and/or "end" keys; empty when no boundary
 *         was crossed during this call.
 *
 * NOTE(review): speech_pad_samples_ is not applied to the reported boundaries —
 * confirm this is intended.
 */
std::map<std::string, double> VadOnnx::vad_dectect(std::vector<float>& audio, bool return_seconds) {
    std::map<std::string, double> result;
    // Append the new audio after any leftover samples from earlier calls.
    buffer_.insert(buffer_.end(), audio.begin(), audio.end());

    const std::size_t window = static_cast<std::size_t>(window_size_samples);
    std::vector<float> chunk(window);
    std::size_t consumed = 0;  // samples at the front of buffer_ already fed to the model
    // Erase processed samples once per call rather than shifting buffer_ after every window.
    auto drop_consumed = [&]() {
        buffer_.erase(buffer_.begin(), buffer_.begin() + static_cast<std::ptrdiff_t>(consumed));
    };

    try {
        while (buffer_.size() - consumed >= window) {
            const auto first = buffer_.begin() + static_cast<std::ptrdiff_t>(consumed);
            std::copy(first, first + static_cast<std::ptrdiff_t>(window), chunk.begin());

            current_sample_ += window_size_samples;
            const float speech_prob = forward_infer(chunk);
            consumed += window;

            std::map<std::string, double> event;
            // Speech came back before the silence was long enough: drop the pending end.
            if (speech_prob >= threshold_ && temp_end_ > 0) {
                temp_end_ = 0;
            }
            // Rising edge: the segment starts at the beginning of this window.
            if (speech_prob >= threshold_ && !triggered_) {
                triggered_ = true;
                start_ = std::max(0.0, current_sample_ - window_size_samples);
                event["start"] = return_seconds ? static_cast<double>(start_) / sample_rates_
                                                : static_cast<double>(start_);
            }
            // Falling edge uses a lower threshold (hysteresis) and must persist for
            // min_silence_samples_ before the segment is closed at temp_end_.
            if (speech_prob < (threshold_ - 0.15) && triggered_) {
                if (temp_end_ == 0) {
                    temp_end_ = current_sample_;
                }
                if (current_sample_ - temp_end_ >= min_silence_samples_) {
                    double speech_end = temp_end_;
                    event["end"] = return_seconds ? speech_end / sample_rates_ : speech_end;
                    temp_end_ = 0;
                    triggered_ = false;
                }
            }
            merge_vad_event(result, event);
        }
    } catch (...) {
        // Windows already fed to the model have advanced state_/context_; never replay them.
        drop_consumed();
        throw;
    }
    drop_consumed();
    return result;
}
void VadOnnx::init_onnx_model(const std::string& model_path) {
init_engine_threads(1, 1);
init_exec_provider();
// 初始化 ONNX Session
env_ = Ort::Env(ORT_LOGGING_LEVEL_WARNING, "VadOnnx");
session = std::make_unique<Ort::Session>(env_, ORTCHAR(model_path.c_str()), session_options);
}
/**
 * Applies threading and graph-optimization settings to session_options.
 *
 * @param inter_threads threads used to run independent graph nodes in parallel.
 * @param intra_threads threads used inside a single operator.
 *
 * Graph optimization is always set to ORT_ENABLE_ALL.
 */
void VadOnnx::init_engine_threads(int inter_threads, int intra_threads) {
    session_options.SetGraphOptimizationLevel(ORT_ENABLE_ALL);
    session_options.SetIntraOpNumThreads(intra_threads);
    session_options.SetInterOpNumThreads(inter_threads);
}
void VadOnnx::init_exec_provider() {
// 获取所有可用的 Execution Providers
std::vector<std::string> providers = Ort::GetAvailableProviders();
// 根据支持情况添加 Execution Provider
if (std::find(providers.begin(), providers.end(), "CUDAExecutionProvider") != providers.end()) {
OrtCUDAProviderOptions cuda_options{};
session_options.AppendExecutionProvider_CUDA(cuda_options);
}
// #if defined(__APPLE__)
// if (std::find(providers.begin(), providers.end(), "CoreMLExecutionProvider") != providers.end()) {
// session_options.AppendExecutionProvider_CoreML();
// }
// #endif
}