# FRN / inference_onnx.py
# vietanhnami: "first commit" (45916af), 2.4 kB
# [raw | history | blame]
import argparse
import glob
import os
import librosa
import numpy as np
import onnx
import onnxruntime
import soundfile as sf
import torch
import tqdm
from config import CONFIG
# Command-line interface. The single option, --onnx_path, is the location of
# the exported ONNX model file; it defaults to None when not given.
# NOTE: arguments are parsed at import time, so `args` is available to the
# script body below as a module-level name.
parser = argparse.ArgumentParser()
parser.add_argument(
    '--onnx_path',
    help='path to onnx',
    default=None,
)
args = parser.parse_args()
if __name__ == '__main__':
    # Frame-by-frame inference with an exported ONNX model.
    #
    # Every ``*.wav`` file found in ``CONFIG.TEST.in_dir`` is loaded at 48 kHz,
    # converted to an STFT, pushed through the ONNX session one time frame at
    # a time (carrying the recurrent state outputs back into the next call),
    # converted back to a waveform with the inverse STFT and written to
    # ``CONFIG.TEST.out_dir`` under the same base name as 16-bit PCM.
    sample_rate = 48000

    path = args.onnx_path
    if path is None:
        # Exit with a usage message instead of an opaque failure inside
        # onnx.load(None).
        parser.error('--onnx_path is required')

    window = CONFIG.DATA.window_size
    stride = CONFIG.DATA.stride

    # The ONNX graph is loaded separately from the runtime session only to
    # read the declared input shapes used for the zero-initialised inputs.
    onnx_model = onnx.load(path)

    options = onnxruntime.SessionOptions()
    options.intra_op_num_threads = 8
    options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
    session = onnxruntime.InferenceSession(path, options)

    input_names = [x.name for x in session.get_inputs()]
    output_names = [x.name for x in session.get_outputs()]
    print(input_names)
    print(output_names)

    # Declared shape of every graph input, paired by position with the session
    # input names. Computed once: it does not depend on the audio file.
    # (A dynamic dimension has dim_value == 0 and therefore yields an empty
    # axis here, exactly as before.)
    input_shapes = {
        input_names[i]: [d.dim_value for d in graph_input.type.tensor_type.shape.dim]
        for i, graph_input in enumerate(onnx_model.graph.input)
    }

    # Sorted so that files are processed in a deterministic order.
    audio_files = sorted(glob.glob(os.path.join(CONFIG.TEST.in_dir, '*.wav')))
    # Square-root Hann window, used for both the analysis and synthesis STFT.
    hann = torch.sqrt(torch.hann_window(window))
    os.makedirs(CONFIG.TEST.out_dir, exist_ok=True)

    for file in tqdm.tqdm(audio_files):
        sig, _ = librosa.load(file, sr=sample_rate)
        sig = torch.tensor(sig)

        # Complex STFT viewed as real: (freq, frames, 2). This is the
        # non-deprecated equivalent of ``return_complex=False``.
        spec = torch.view_as_real(
            torch.stft(sig, window, stride, window=hann, return_complex=True))
        # Rearranged to (1, 2, freq, frames) so the last axis indexes time.
        re_im = spec.permute(2, 0, 1).unsqueeze(0).numpy().astype(np.float32)

        # Fresh all-zero inputs (including the recurrent states) per file.
        inputs = {name: np.zeros(shape, dtype=np.float32)
                  for name, shape in input_shapes.items()}

        output_frames = []
        for t in range(re_im.shape[-1]):
            # Feed the current frame as the first model input. Previously the
            # frame was sliced out but never passed on, so the model only ever
            # saw the zero placeholder.
            # NOTE(review): assumes input 0 accepts a (1, 2, freq, 1) frame --
            # confirm against the exported model.
            inputs[input_names[0]] = np.ascontiguousarray(re_im[:, :, :, t:t + 1])
            out, prev_mag, predictor_state, mlp_state = session.run(output_names, inputs)
            # Carry the state outputs over to the next frame.
            inputs[input_names[1]] = prev_mag
            inputs[input_names[2]] = predictor_state
            inputs[input_names[3]] = mlp_state
            output_frames.append(out)

        # Stack the per-frame outputs along axis 0, then swap the first two
        # axes so the result can be viewed as a complex spectrogram.
        # NOTE(review): presumably each `out` is (1, freq, 2) -- confirm.
        output_audio = torch.tensor(np.concatenate(output_frames, 0))
        output_audio = output_audio.permute(1, 0, 2).contiguous()
        output_audio = torch.view_as_complex(output_audio)
        output_audio = torch.istft(output_audio, window, stride, window=hann)

        out_path = os.path.join(CONFIG.TEST.out_dir, os.path.basename(file))
        sf.write(out_path, output_audio.numpy(), samplerate=sample_rate,
                 subtype='PCM_16')