|
|
|
|
|
|
|
import argparse |
|
from pathlib import Path |
|
|
|
import kaldi_native_fbank as knf |
|
import librosa |
|
import numpy as np |
|
import onnxruntime as ort |
|
import soundfile as sf |
|
import torch |
|
|
|
|
|
def create_fbank():
    """Build the online fbank extractor used to compute encoder input features.

    Dithering, DC-offset removal and pre-emphasis are disabled, a Hann window
    is used, and 64 mel bins span 0-8000 Hz.

    Returns:
        A newly constructed ``knf.OnlineFbank`` with the options above.
    """
    options = knf.FbankOptions()

    # Frame-level options (edited in place through the sub-struct reference).
    frame = options.frame_opts
    frame.dither = 0
    frame.remove_dc_offset = False
    frame.preemph_coeff = 0
    frame.window_type = "hann"
    # Round the FFT length of each frame up to a power of two.
    frame.round_to_power_of_two = True

    # Mel filterbank options.
    mel = options.mel_opts
    mel.low_freq = 0
    mel.high_freq = 8000
    mel.num_bins = 64

    return knf.OnlineFbank(options)
|
|
|
|
|
def compute_features(audio, fbank, sample_rate: int = 16000):
    """Feed a mono waveform to ``fbank`` and collect every ready feature frame.

    Args:
      audio:
        1-D array of waveform samples.
      fbank:
        An online fbank extractor (e.g. from ``create_fbank()``) exposing
        ``accept_waveform``, ``num_frames_ready`` and ``get_frame``. Frames
        are read starting at index 0, so pass a freshly created extractor.
      sample_rate:
        Sampling rate of ``audio`` in Hz, forwarded to ``accept_waveform``.

    Returns:
      A 2-D array of shape ``(num_frames, feature_dim)``.

    Raises:
      ValueError: if ``audio`` is not 1-D, or is too short to yield a frame.
    """
    # Explicit check instead of ``assert`` so it survives ``python -O``.
    if len(audio.shape) != 1:
        raise ValueError(f"Expected 1-D audio, got shape {audio.shape}")

    fbank.accept_waveform(sample_rate, audio)

    num_frames = fbank.num_frames_ready
    if num_frames == 0:
        # np.stack([]) would fail here with a much less helpful message.
        raise ValueError(
            f"Audio with {len(audio)} samples is too short to produce "
            "a single feature frame"
        )

    return np.stack([np.array(fbank.get_frame(i)) for i in range(num_frames)])
|
|
|
|
|
def display(sess):
    """Print every input NodeArg of ``sess``, then every output NodeArg.

    Each group is preceded by a banner line (``==========Input==========`` /
    ``==========Output==========``).
    """
    sections = (
        ("Input", sess.get_inputs),
        ("Output", sess.get_outputs),
    )
    for title, list_nodes in sections:
        print(f"=========={title}==========")
        for node in list_nodes():
            print(node)
|
|
|
|
|
""" |
|
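Reference I/O signatures printed by display() for the encoder, decoder and
joiner sessions, in that order (see OnnxModel.__init__).

Encoder:
==========Input==========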
|
NodeArg(name='audio_signal', type='tensor(float)', shape=['audio_signal_dynamic_axes_1', 64, 'audio_signal_dynamic_axes_2']) |
|
NodeArg(name='length', type='tensor(int64)', shape=['length_dynamic_axes_1']) |
|
==========Output========== |
|
NodeArg(name='outputs', type='tensor(float)', shape=['outputs_dynamic_axes_1', 768, 'outputs_dynamic_axes_2']) |
|
NodeArg(name='encoded_lengths', type='tensor(int64)', shape=['encoded_lengths_dynamic_axes_1']) |
|
Decoder:
==========Input==========
|
NodeArg(name='targets', type='tensor(int32)', shape=['targets_dynamic_axes_1', 'targets_dynamic_axes_2']) |
|
NodeArg(name='target_length', type='tensor(int32)', shape=['target_length_dynamic_axes_1']) |
|
NodeArg(name='states.1', type='tensor(float)', shape=[1, 'states.1_dim_1', 320]) |
|
NodeArg(name='onnx::LSTM_3', type='tensor(float)', shape=[1, 1, 320]) |
|
==========Output========== |
|
NodeArg(name='outputs', type='tensor(float)', shape=['outputs_dynamic_axes_1', 320, 'outputs_dynamic_axes_2']) |
|
NodeArg(name='prednet_lengths', type='tensor(int32)', shape=['prednet_lengths_dynamic_axes_1']) |
|
NodeArg(name='states', type='tensor(float)', shape=[1, 'states_dynamic_axes_1', 320]) |
|
NodeArg(name='74', type='tensor(float)', shape=[1, 'states_dynamic_axes_1', 320]) |
|
Joiner:
==========Input==========
|
NodeArg(name='encoder_outputs', type='tensor(float)', shape=['encoder_outputs_dynamic_axes_1', 768, 'encoder_outputs_dynamic_axes_2']) |
|
NodeArg(name='decoder_outputs', type='tensor(float)', shape=['decoder_outputs_dynamic_axes_1', 320, 'decoder_outputs_dynamic_axes_2']) |
|
==========Output========== |
|
NodeArg(name='outputs', type='tensor(float)', shape=['outputs_dynamic_axes_1', 'outputs_dynamic_axes_2', 'outputs_dynamic_axes_3', 513]) |
|
""" |
|
|
|
|
|
class OnnxModel:
    """Wrapper around the encoder, decoder and joiner ONNX sessions of a transducer.

    Every sub-model is loaded into its own single-threaded CPU
    ``onnxruntime.InferenceSession``, and its I/O signature is printed on load.
    """

    def __init__(
        self,
        encoder: str,
        decoder: str,
        joiner: str,
    ):
        self.init_encoder(encoder)
        display(self.encoder)

        self.init_decoder(decoder)
        display(self.decoder)

        self.init_joiner(joiner)
        display(self.joiner)

    @staticmethod
    def _create_session(filename: str):
        """Load ``filename`` into a single-threaded CPU inference session."""
        session_opts = ort.SessionOptions()
        session_opts.inter_op_num_threads = 1
        session_opts.intra_op_num_threads = 1

        return ort.InferenceSession(
            filename,
            sess_options=session_opts,
            providers=["CPUExecutionProvider"],
        )

    def init_encoder(self, encoder: str):
        """Load the encoder and read model settings from its custom metadata.

        Sets ``normalize_type``, ``pred_rnn_layers`` and ``pred_hidden``;
        a missing metadata key raises ``KeyError``.
        """
        self.encoder = self._create_session(encoder)

        meta = self.encoder.get_modelmeta().custom_metadata_map
        self.normalize_type = meta["normalize_type"]
        print(meta)

        # Decoder (prediction network) state sizes, used by get_decoder_state().
        self.pred_rnn_layers = int(meta["pred_rnn_layers"])
        self.pred_hidden = int(meta["pred_hidden"])

    def init_decoder(self, decoder: str):
        """Load the decoder (prediction network) session."""
        self.decoder = self._create_session(decoder)

    def init_joiner(self, joiner: str):
        """Load the joiner session."""
        self.joiner = self._create_session(joiner)

    def get_decoder_state(self):
        """Return two zero-filled float32 decoder states for a batch of one.

        Each array has shape ``(pred_rnn_layers, 1, pred_hidden)``.
        """
        batch_size = 1
        shape = (self.pred_rnn_layers, batch_size, self.pred_hidden)
        return np.zeros(shape, dtype=np.float32), np.zeros(shape, dtype=np.float32)

    def run_encoder(self, x: np.ndarray):
        """Run the encoder over one utterance.

        Args:
          x: 2-D features of shape ``(num_frames, feature_dim)``.

        Returns:
          The encoder's first output; per the reference I/O listing its
          layout is ``(batch, encoder_dim, num_encoder_frames)``.
        """
        # (T, C) -> (1, C, T). Make it an explicit contiguous float32 buffer
        # rather than handing ONNX Runtime a transposed view.
        x = np.ascontiguousarray(x.T[np.newaxis], dtype=np.float32)
        x_lens = np.array([x.shape[-1]], dtype=np.int64)

        encoder_out, _encoder_out_lens = self.encoder.run(
            [
                self.encoder.get_outputs()[0].name,
                self.encoder.get_outputs()[1].name,
            ],
            {
                self.encoder.get_inputs()[0].name: x,
                self.encoder.get_inputs()[1].name: x_lens,
            },
        )
        return encoder_out

    def run_decoder(
        self,
        token: int,
        state0: np.ndarray,
        state1: np.ndarray,
    ):
        """Run the decoder on a single token.

        Args:
          token: the previous output token id.
          state0, state1: decoder states, e.g. from ``get_decoder_state()``.

        Returns:
          ``(decoder_out, state0_next, state1_next)``; the decoder's length
          output is discarded.
        """
        target = np.array([[token]], dtype=np.int32)
        target_len = np.array([1], dtype=np.int32)

        outputs = self.decoder.get_outputs()
        inputs = self.decoder.get_inputs()
        (
            decoder_out,
            _decoder_out_length,
            state0_next,
            state1_next,
        ) = self.decoder.run(
            [
                outputs[0].name,
                outputs[1].name,
                outputs[2].name,
                outputs[3].name,
            ],
            {
                inputs[0].name: target,
                inputs[1].name: target_len,
                inputs[2].name: state0,
                inputs[3].name: state1,
            },
        )
        return decoder_out, state0_next, state1_next

    def run_joiner(
        self,
        encoder_out: np.ndarray,
        decoder_out: np.ndarray,
    ):
        """Combine one encoder frame with a decoder output; return the joiner's first output."""
        return self.joiner.run(
            [
                self.joiner.get_outputs()[0].name,
            ],
            {
                self.joiner.get_inputs()[0].name: encoder_out,
                self.joiner.get_inputs()[1].name: decoder_out,
            },
        )[0]
|
|
|
|
|
def main():
    """Transcribe ./example.wav with greedy transducer decoding and print the text.

    Reads encoder.int8.onnx, decoder.onnx, joiner.onnx, tokens.txt and
    example.wav from the current working directory.
    """
    model = OnnxModel("encoder.int8.onnx", "decoder.onnx", "joiner.onnx")

    # Each line of tokens.txt is "<token> <integer id>".
    id2token = dict()
    with open("./tokens.txt", encoding="utf-8") as f:
        for line in f:
            t, idx = line.split()
            id2token[int(idx)] = t

    fbank = create_fbank()
    audio, sample_rate = sf.read("./example.wav", dtype="float32", always_2d=True)
    audio = audio[:, 0]  # only the first channel is used
    if sample_rate != 16000:
        audio = librosa.resample(
            audio,
            orig_sr=sample_rate,
            target_sr=16000,
        )
        sample_rate = 16000

    # Append 2 s of silence (presumably so the end of the utterance is not
    # cut off). float32 zeros keep the concatenated waveform float32 instead
    # of upcasting it to float64.
    tail_padding = np.zeros(sample_rate * 2, dtype=np.float32)
    audio = np.concatenate([audio, tail_padding])

    # NOTE(review): assumes token ids are 0..N-1 with blank as the last id.
    blank = len(id2token) - 1
    ans = [blank]
    state0, state1 = model.get_decoder_state()
    decoder_out, state0_next, state1_next = model.run_decoder(ans[-1], state0, state1)

    features = compute_features(audio, fbank)
    print("audio.shape", audio.shape)
    print("features.shape", features.shape)

    # Honor the feature normalization recorded in the encoder metadata.
    if model.normalize_type == "per_feature":
        # NeMo-style per-feature normalization: normalize each feature bin
        # over time, using the unbiased std plus 1e-5.
        mean = features.mean(axis=0, keepdims=True)
        stddev = features.std(axis=0, ddof=1, keepdims=True) + 1e-5
        features = (features - mean) / stddev
    elif model.normalize_type:
        print(
            f"Warning: unsupported normalize_type '{model.normalize_type}'; "
            "features are left unnormalized"
        )

    encoder_out = model.run_encoder(features)

    # Greedy search that emits at most one non-blank token per encoder frame.
    for t in range(encoder_out.shape[2]):
        encoder_out_t = encoder_out[:, :, t : t + 1]
        logits = model.run_joiner(encoder_out_t, decoder_out)
        idx = int(np.argmax(logits.squeeze(), axis=-1))
        if idx != blank:
            ans.append(idx)
            # Commit the decoder state only when a token is emitted, then
            # compute the decoder output used for the following frames.
            state0 = state0_next
            state1 = state1_next
            decoder_out, state0_next, state1_next = model.run_decoder(
                ans[-1], state0, state1
            )

    ans = ans[1:]  # drop the initial blank used to prime the decoder
    print(ans)
    tokens = [id2token[i] for i in ans]
    underline = "▁"

    text = "".join(tokens).replace(underline, " ").strip()
    print("./example.wav")
    print(text)


if __name__ == "__main__":
    main()
|
|