#!/usr/bin/env python3
# Copyright (c) 2023 Xiaomi Corporation
# Author: Fangjun Kuang

import kaldi_native_fbank as knf
import librosa
import numpy as np
import onnxruntime


def load_cmvn():
    """Read the normalization vectors stored in ``am.mvn``.

    Only lines starting with ``<LearnRateCoef>`` carry data. On such a line
    the first three whitespace-separated fields and the last field are
    dropped and the remaining fields are parsed as floats.

    Returns:
      A tuple ``(neg_mean, inv_std)`` of 1-D float32 arrays. ``neg_mean``
      comes from the first data line, ``inv_std`` from the last one.

    Raises:
      ValueError: If ``am.mvn`` contains fewer than two data lines.
    """
    vectors = []
    with open("am.mvn", encoding="utf-8") as f:
        for line in f:
            # An empty prefix would match every line (headers, closing tags,
            # ...), so the marker has to be spelled out.
            if not line.startswith("<LearnRateCoef>"):
                continue

            values = [float(x) for x in line.split()[3:-1]]
            vectors.append(np.array(values, dtype=np.float32))

    if len(vectors) < 2:
        raise ValueError(
            "am.mvn must contain at least two <LearnRateCoef> lines, "
            f"found {len(vectors)}"
        )

    return vectors[0], vectors[-1]


def compute_feat():
    """Compute the normalized, frame-stacked fbank features of ``2.wav``.

    Steps:
      1. Load ``2.wav`` resampled to 16 kHz.
      2. Compute 80-bin fbank features with kaldi-native-fbank.
      3. Concatenate every 7 consecutive frames (lfr_m), advancing by
         6 frames (lfr_n) between windows.
      4. Apply ``(features + neg_mean) * inv_std`` using the vectors
         returned by ``load_cmvn()``.

    Returns:
      A 2-D float32 array of shape ``(T, 7 * 80)`` where
      ``T = (num_frames - 7) // 6 + 1``.
    """
    sample_rate = 16000
    samples, _ = librosa.load("2.wav", sr=sample_rate)

    opts = knf.FbankOptions()
    opts.frame_opts.dither = 0
    opts.frame_opts.snip_edges = False
    opts.frame_opts.samp_freq = sample_rate
    opts.mel_opts.num_bins = 80

    online_fbank = knf.OnlineFbank(opts)
    # librosa yields floating point samples; scale them to the 16-bit range.
    online_fbank.accept_waveform(sample_rate, (samples * 32768).tolist())
    online_fbank.input_finished()

    features = np.stack(
        [online_fbank.get_frame(i) for i in range(online_fbank.num_frames_ready)]
    )
    # as_strided() below reads raw memory, so the layout must be guaranteed
    # even when assertions are disabled (python -O).
    features = np.ascontiguousarray(features, dtype=np.float32)

    window_size = 7  # lfr_m
    window_shift = 6  # lfr_n

    num_frames, feat_dim = features.shape
    row_stride, col_stride = features.strides

    T = (num_frames - window_size) // window_shift + 1
    # Row t of the view is frames [t * window_shift, t * window_shift +
    # window_size) laid end to end. Rows overlap in memory, hence read-only.
    features = np.lib.stride_tricks.as_strided(
        features,
        shape=(T, feat_dim * window_size),
        strides=(window_shift * row_stride, col_stride),
        writeable=False,
    )

    neg_mean, inv_std = load_cmvn()
    # The arithmetic allocates a fresh array, so the overlapping view does
    # not escape this function.
    features = (features + neg_mean) * inv_std
    return features


# tokens.txt in paraformer has only one column
# while it has two columns in sherpa-onnx.
# This function can handle tokens.txt from both paraformer and sherpa-onnx
def load_tokens():
    """Load the id -> token table from ``tokens.txt``.

    The id of a token is its zero-based line number. Only the first
    whitespace-separated field of each line is used, so the file may have
    one column (token) or two columns (token and id).

    Returns:
      A dict mapping the integer line index to the token string. A blank
      line maps to the empty string so that the ids of the following lines
      stay aligned.
    """
    ans = {}
    with open("tokens.txt", encoding="utf-8") as f:
        for i, line in enumerate(f):
            fields = line.split()
            ans[i] = fields[0] if fields else ""
    return ans


def main():
    """Run ``model.onnx`` on the features of the test wave and print the text.

    The features come from ``compute_feat()``. The per-frame argmax of the
    ``logits`` output is mapped through ``load_tokens()`` and the tokens are
    joined without a separator. Token ids 0 and 2 are skipped.
    """
    import sys  # local import: only needed for the error report below

    features = compute_feat()
    # Add the batch dimension: (T, C) -> (1, T, C)
    features = np.expand_dims(features, axis=0)
    features_length = np.array([features.shape[1]], dtype=np.int32)

    session_opts = onnxruntime.SessionOptions()
    session_opts.log_severity_level = 3  # error level
    sess = onnxruntime.InferenceSession("model.onnx", session_opts)

    inputs = {
        "speech": features,
        "speech_lengths": features_length,
    }
    output_names = ["logits"]

    try:
        outputs = sess.run(output_names, input_feed=inputs)
    except Exception as ex:
        # onnxruntime has no exception named ONNXRuntimeError; referring to
        # that name here raised NameError instead of printing the message.
        # Its exception types do not share a dedicated base class, so catch
        # broadly at this script boundary and keep the details on stderr.
        print("Input wav is silence or noise")
        print(f"onnxruntime error: {ex}", file=sys.stderr)
        return

    log_probs = outputs[0].squeeze(0)
    y = log_probs.argmax(axis=-1)

    tokens = load_tokens()
    # NOTE(review): ids 0 and 2 are presumably <blank> and </s> - confirm
    # against tokens.txt.
    text = "".join(tokens[int(i)] for i in y if i not in (0, 2))
    print(text)


if __name__ == "__main__":
    main()