csukuangfj's picture
add models
f53860a
#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
import argparse
from pathlib import Path
import kaldi_native_fbank as knf
import librosa
import numpy as np
import onnxruntime as ort
import soundfile as sf
import torch
def create_fbank():
    """Build the kaldi-native-fbank online extractor used for the encoder features.

    The extractor is configured with no dither, no DC-offset removal, no
    pre-emphasis and a Hann window, and produces 64 mel bins spanning
    0-8000 Hz.

    Returns:
      A new ``knf.OnlineFbank`` instance.
    """
    opts = knf.FbankOptions()

    frame_opts = opts.frame_opts
    frame_opts.dither = 0
    frame_opts.remove_dc_offset = False
    frame_opts.preemph_coeff = 0
    frame_opts.window_type = "hann"
    # GigaAM itself uses a 400-point FFT, but kaldi-native-fbank only
    # supports power-of-two FFT sizes, so the frame is rounded up (to 512).
    frame_opts.round_to_power_of_two = True

    mel_opts = opts.mel_opts
    mel_opts.low_freq = 0
    mel_opts.high_freq = 8000
    mel_opts.num_bins = 64

    return knf.OnlineFbank(opts)
def compute_features(audio, fbank, sample_rate: int = 16000):
    """Compute fbank features for a complete 1-D waveform.

    Args:
      audio: 1-D array of samples at ``sample_rate`` Hz.
      fbank: An online fbank extractor, e.g. from create_fbank(). Frames are
        read starting from index 0, so pass a fresh extractor per waveform.
      sample_rate: Sample rate of ``audio``. It should match the rate the
        extractor was configured for. Defaults to 16000, the rate the rest
        of this script resamples to.

    Returns:
      A float32 array of shape (num_frames, feature_dim).

    Raises:
      ValueError: if the extractor produced no frames (waveform too short).
    """
    assert len(audio.shape) == 1, audio.shape
    fbank.accept_waveform(sample_rate, audio)

    num_frames = fbank.num_frames_ready
    if num_frames == 0:
        # np.stack([]) would fail with an opaque message; be explicit.
        raise ValueError(
            f"No feature frames were produced from {audio.shape[0]} samples; "
            "the waveform is too short"
        )

    # Copy each frame out of the extractor and force float32, the dtype the
    # encoder's ONNX input (tensor(float)) expects.
    return np.stack(
        [np.array(fbank.get_frame(i), dtype=np.float32) for i in range(num_frames)]
    )
def display(sess):
    """Print the input and output node descriptions of an onnxruntime session.

    Args:
      sess: An object exposing ``get_inputs()`` and ``get_outputs()``, such
        as ``ort.InferenceSession``. Each returned node is printed on its
        own line under an "Input" or "Output" banner.
    """
    sections = (
        ("==========Input==========", sess.get_inputs),
        ("==========Output==========", sess.get_outputs),
    )
    for banner, list_nodes in sections:
        print(banner)
        for node in list_nodes():
            print(node)
"""
Output of display() for the encoder, decoder and joiner models (in that order):
==========Input==========
NodeArg(name='audio_signal', type='tensor(float)', shape=['audio_signal_dynamic_axes_1', 64, 'audio_signal_dynamic_axes_2'])
NodeArg(name='length', type='tensor(int64)', shape=['length_dynamic_axes_1'])
==========Output==========
NodeArg(name='outputs', type='tensor(float)', shape=['outputs_dynamic_axes_1', 768, 'outputs_dynamic_axes_2'])
NodeArg(name='encoded_lengths', type='tensor(int64)', shape=['encoded_lengths_dynamic_axes_1'])
==========Input==========
NodeArg(name='targets', type='tensor(int32)', shape=['targets_dynamic_axes_1', 'targets_dynamic_axes_2'])
NodeArg(name='target_length', type='tensor(int32)', shape=['target_length_dynamic_axes_1'])
NodeArg(name='states.1', type='tensor(float)', shape=[1, 'states.1_dim_1', 320])
NodeArg(name='onnx::LSTM_3', type='tensor(float)', shape=[1, 1, 320])
==========Output==========
NodeArg(name='outputs', type='tensor(float)', shape=['outputs_dynamic_axes_1', 320, 'outputs_dynamic_axes_2'])
NodeArg(name='prednet_lengths', type='tensor(int32)', shape=['prednet_lengths_dynamic_axes_1'])
NodeArg(name='states', type='tensor(float)', shape=[1, 'states_dynamic_axes_1', 320])
NodeArg(name='74', type='tensor(float)', shape=[1, 'states_dynamic_axes_1', 320])
==========Input==========
NodeArg(name='encoder_outputs', type='tensor(float)', shape=['encoder_outputs_dynamic_axes_1', 768, 'encoder_outputs_dynamic_axes_2'])
NodeArg(name='decoder_outputs', type='tensor(float)', shape=['decoder_outputs_dynamic_axes_1', 320, 'decoder_outputs_dynamic_axes_2'])
==========Output==========
NodeArg(name='outputs', type='tensor(float)', shape=['outputs_dynamic_axes_1', 'outputs_dynamic_axes_2', 'outputs_dynamic_axes_3', 513])
"""
class OnnxModel:
    """Wrapper around the encoder, decoder and joiner ONNX sessions.

    Each model is loaded into its own single-threaded CPU onnxruntime
    session, and its input/output signature is printed via display().
    """

    def __init__(
        self,
        encoder: str,
        decoder: str,
        joiner: str,
    ):
        """Load the three models.

        Args:
          encoder: Path to the encoder ONNX model. Its custom metadata must
            contain ``normalize_type``, ``pred_rnn_layers`` and ``pred_hidden``.
          decoder: Path to the decoder (prediction network) ONNX model.
          joiner: Path to the joiner ONNX model.
        """
        self.init_encoder(encoder)
        display(self.encoder)
        self.init_decoder(decoder)
        display(self.decoder)
        self.init_joiner(joiner)
        display(self.joiner)

    @staticmethod
    def _create_session(filename: str) -> "ort.InferenceSession":
        """Create a single-threaded CPU onnxruntime session for ``filename``."""
        session_opts = ort.SessionOptions()
        session_opts.inter_op_num_threads = 1
        session_opts.intra_op_num_threads = 1
        return ort.InferenceSession(
            filename,
            sess_options=session_opts,
            providers=["CPUExecutionProvider"],
        )

    def init_encoder(self, encoder: str):
        """Load the encoder and read the metadata needed for decoding."""
        self.encoder = self._create_session(encoder)

        meta = self.encoder.get_modelmeta().custom_metadata_map
        # NOTE(review): normalize_type is stored but never applied to the
        # features anywhere in this class -- confirm the model expects none.
        self.normalize_type = meta["normalize_type"]
        print(meta)
        # Size of the decoder states built in get_decoder_state().
        self.pred_rnn_layers = int(meta["pred_rnn_layers"])
        self.pred_hidden = int(meta["pred_hidden"])

    def init_decoder(self, decoder: str):
        """Load the decoder (prediction network) session."""
        self.decoder = self._create_session(decoder)

    def init_joiner(self, joiner: str):
        """Load the joiner session."""
        self.joiner = self._create_session(joiner)

    def get_decoder_state(self):
        """Return zero-initialised ``(state0, state1)`` for the decoder.

        Both are float32 arrays of shape ``(pred_rnn_layers, 1, pred_hidden)``
        (batch size 1).
        """
        shape = (self.pred_rnn_layers, 1, self.pred_hidden)
        state0 = np.zeros(shape, dtype=np.float32)
        state1 = np.zeros(shape, dtype=np.float32)
        return state0, state1

    @staticmethod
    def _prepare_encoder_input(features: np.ndarray):
        """Turn (T, C) features into the encoder's ``(x, x_lens)`` inputs.

        Returns:
          x: C-contiguous float32 array of shape (1, C, T).
          x_lens: int64 array of shape (1,) holding T.
        """
        # The encoder input is declared as tensor(float); cast so that
        # float64 features do not make onnxruntime reject the input.
        x = np.ascontiguousarray(features.T[np.newaxis], dtype=np.float32)
        x_lens = np.array([x.shape[-1]], dtype=np.int64)
        return x, x_lens

    def run_encoder(self, x: np.ndarray):
        """Run the encoder on one utterance.

        Args:
          x: Features of shape (T, C), e.g. from compute_features().

        Returns:
          Encoder output of shape [batch_size, dim, T_out] (batch_size is 1).
        """
        x, x_lens = self._prepare_encoder_input(x)
        inputs = self.encoder.get_inputs()
        encoder_out, _encoded_lengths = self.encoder.run(
            [o.name for o in self.encoder.get_outputs()[:2]],
            {
                inputs[0].name: x,
                inputs[1].name: x_lens,
            },
        )
        return encoder_out

    def run_decoder(
        self,
        token: int,
        state0: np.ndarray,
        state1: np.ndarray,
    ):
        """Advance the decoder (prediction network) by a single token.

        Args:
          token: Id of the token to feed.
          state0, state1: Current decoder states, e.g. from
            get_decoder_state() or a previous call.

        Returns:
          ``(decoder_out, state0_next, state1_next)``. ``decoder_out`` is what
          run_joiner() consumes, [batch_size, dim, 1]. The next states are to
          be passed to the following call.
        """
        target = np.array([[token]], dtype=np.int32)
        target_len = np.array([1], dtype=np.int32)

        input_names = [i.name for i in self.decoder.get_inputs()[:4]]
        output_names = [o.name for o in self.decoder.get_outputs()[:4]]
        decoder_out, _decoder_out_length, state0_next, state1_next = self.decoder.run(
            output_names,
            dict(zip(input_names, (target, target_len, state0, state1))),
        )
        return decoder_out, state0_next, state1_next

    def run_joiner(
        self,
        encoder_out: np.ndarray,
        decoder_out: np.ndarray,
    ):
        """Compute output logits for one encoder frame and decoder output.

        Args:
          encoder_out: [batch_size, dim, 1] slice of the encoder output.
          decoder_out: [batch_size, dim, 1] output of run_decoder().

        Returns:
          Logits of shape [batch_size, 1, 1, vocab_size].
        """
        inputs = self.joiner.get_inputs()
        return self.joiner.run(
            [self.joiner.get_outputs()[0].name],
            {
                inputs[0].name: encoder_out,
                inputs[1].name: decoder_out,
            },
        )[0]
def main(argv=None):
    """Transcribe one wave file with greedy transducer search and print the text.

    Loads the encoder/decoder/joiner ONNX models and the tokens file. It then
    computes fbank features for the first channel of the wave file
    (resampled to 16 kHz if needed), decodes greedily and prints the result.

    Args:
      argv: Optional list of command-line arguments for argparse (``None``
        means ``sys.argv[1:]``). Every option defaults to the path this
        script previously hard-coded, so running it without arguments
        behaves as before.
    """
    parser = argparse.ArgumentParser(
        description="Greedy-search decoding of a wave file with ONNX transducer models."
    )
    parser.add_argument("--encoder", default="encoder.int8.onnx")
    parser.add_argument("--decoder", default="decoder.onnx")
    parser.add_argument("--joiner", default="joiner.onnx")
    parser.add_argument("--tokens", default="./tokens.txt")
    parser.add_argument("--wav", default="./example.wav")
    args = parser.parse_args(argv)

    model = OnnxModel(args.encoder, args.decoder, args.joiner)

    # Each line of the tokens file is "<token> <id>".
    id2token = {}
    with open(args.tokens, encoding="utf-8") as f:
        for line in f:
            fields = line.split()
            if not fields:
                continue  # skip blank lines instead of failing to unpack
            t, idx = fields
            id2token[int(idx)] = t

    fbank = create_fbank()

    audio, sample_rate = sf.read(args.wav, dtype="float32", always_2d=True)
    audio = audio[:, 0]  # only use the first channel
    if sample_rate != 16000:
        audio = librosa.resample(
            audio,
            orig_sr=sample_rate,
            target_sr=16000,
        )
        sample_rate = 16000

    # Append 2 seconds of trailing silence. Keep everything float32:
    # np.zeros() defaults to float64, which would silently upcast the whole
    # waveform on concatenation.
    tail_padding = np.zeros(sample_rate * 2, dtype=np.float32)
    audio = np.concatenate([audio.astype(np.float32, copy=False), tail_padding])

    # NOTE(review): assumes token ids are 0..N-1 and blank is the last one.
    blank = len(id2token) - 1
    # Prime the decoder with blank; it is removed from the result below.
    ans = [blank]
    state0, state1 = model.get_decoder_state()
    decoder_out, state0_next, state1_next = model.run_decoder(ans[-1], state0, state1)

    features = compute_features(audio, fbank)
    print("audio.shape", audio.shape)
    print("features.shape", features.shape)

    encoder_out = model.run_encoder(features)
    # encoder_out: [batch_size, dim, T]
    # Greedy search: at most one non-blank token is emitted per encoder frame.
    for t in range(encoder_out.shape[2]):
        encoder_out_t = encoder_out[:, :, t : t + 1]
        logits = model.run_joiner(encoder_out_t, decoder_out)
        idx = int(np.argmax(logits.squeeze(), axis=-1))
        if idx != blank:
            ans.append(idx)
            # Commit the state produced by the previous token, then advance
            # the decoder with the token just emitted. On blank, the decoder
            # output and states are left unchanged.
            state0 = state0_next
            state1 = state1_next
            decoder_out, state0_next, state1_next = model.run_decoder(
                ans[-1], state0, state1
            )

    ans = ans[1:]  # remove the priming blank
    print(ans)
    tokens = [id2token[i] for i in ans]
    # Tokens mark word boundaries with "▁" (U+2581); turn them into spaces.
    underline = "▁"
    text = "".join(tokens).replace(underline, " ").strip()
    print(args.wav)
    print(text)
if __name__ == "__main__":
    # Run the demo only when executed as a script, not when imported.
    main()