"""Run a streaming ONNX enhancement model over a directory of WAV files.

Every ``*.wav`` in ``CONFIG.TEST.in_dir`` is loaded (resampled to 48 kHz by
librosa), transformed with an STFT, fed to the ONNX model one STFT frame at a
time while carrying the model's state outputs back in as inputs, transformed
back with an inverse STFT and written as 16-bit PCM to ``CONFIG.TEST.out_dir``
under the same file name.
"""
import argparse
import glob
import os

import librosa
import numpy as np
import onnx
import onnxruntime
import soundfile as sf
import torch
import tqdm

from config import CONFIG

# Rate used both for loading (librosa resamples to it) and for writing output.
SAMPLE_RATE = 48000

parser = argparse.ArgumentParser()
parser.add_argument('--onnx_path', required=True, help='path to onnx')
parser.add_argument('--num_threads', type=int, default=8,
                    help='onnxruntime intra-op thread count')


def _create_session(path, num_threads):
    """Create a CPU onnxruntime session for *path* with all graph optimisations on."""
    options = onnxruntime.SessionOptions()
    options.intra_op_num_threads = num_threads
    options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
    # Since ORT 1.9, builds exposing more than one execution provider require
    # an explicit providers list; pin CPU, matching the thread setting above.
    return onnxruntime.InferenceSession(path, options, providers=['CPUExecutionProvider'])


def _zero_inputs(onnx_model, input_names):
    """Return ``{input_name: float32 zeros}`` for every session input.

    Shapes come from the graph's declared dims, looked up by name (a graph
    may also list initializers as inputs, so positions cannot be trusted).
    A symbolic/dynamic dim has ``dim_value == 0`` and yields a zero-length axis.
    """
    shapes = {
        graph_input.name: [d.dim_value for d in graph_input.type.tensor_type.shape.dim]
        for graph_input in onnx_model.graph.input
    }
    return {name: np.zeros(shapes[name], dtype=np.float32) for name in input_names}


def _enhance_file(session, zero_inputs, input_names, output_names, sig, window, stride, hann):
    """Run the model frame-by-frame over 1-D signal *sig*; return the enhanced signal.

    Assumes the model takes (frame, prev_mag, predictor_state, mlp_state) as
    inputs 0..3 and returns (out, prev_mag, predictor_state, mlp_state) in that
    order — NOTE(review): confirm against the exported graph.
    """
    # (freq, frames, 2) real/imag pairs -> (frames, 1, freq, 2) so that each
    # model step receives one frame with a leading axis of 1.
    spec = torch.view_as_real(torch.stft(sig, window, stride, window=hann, return_complex=True))
    frames = spec.permute(1, 0, 2).unsqueeze(1).contiguous().numpy().astype(np.float32)

    # Shallow copy: entries are rebound, never mutated, so every file starts
    # from the shared zero state.
    feed = dict(zero_inputs)
    enhanced = []
    for frame in frames:
        feed[input_names[0]] = frame
        out, prev_mag, predictor_state, mlp_state = session.run(output_names, feed)
        feed[input_names[1]] = prev_mag
        feed[input_names[2]] = predictor_state
        feed[input_names[3]] = mlp_state
        enhanced.append(out)

    # Back to (freq, frames, 2); view_as_complex needs a contiguous last dim of 2.
    spec_out = torch.from_numpy(np.concatenate(enhanced, 0)).permute(1, 0, 2).contiguous()
    # length= keeps the output exactly as long as the input; without it the
    # centered istft drops up to `stride - 1` trailing samples.
    return torch.istft(torch.view_as_complex(spec_out), window, stride,
                       window=hann, length=sig.shape[-1])


def main():
    """Parse CLI arguments and enhance every WAV file in ``CONFIG.TEST.in_dir``."""
    args = parser.parse_args()
    path = args.onnx_path
    window = CONFIG.DATA.window_size
    stride = CONFIG.DATA.stride

    onnx_model = onnx.load(path)
    session = _create_session(path, args.num_threads)
    input_names = [x.name for x in session.get_inputs()]
    output_names = [x.name for x in session.get_outputs()]
    print(input_names)
    print(output_names)
    zero_inputs = _zero_inputs(onnx_model, input_names)

    hann = torch.sqrt(torch.hann_window(window))
    audio_files = sorted(glob.glob(os.path.join(CONFIG.TEST.in_dir, '*.wav')))
    os.makedirs(CONFIG.TEST.out_dir, exist_ok=True)
    for wav_path in tqdm.tqdm(audio_files, total=len(audio_files)):
        sig, _ = librosa.load(wav_path, sr=SAMPLE_RATE)
        enhanced = _enhance_file(session, zero_inputs, input_names, output_names,
                                 torch.tensor(sig), window, stride, hann)
        sf.write(os.path.join(CONFIG.TEST.out_dir, os.path.basename(wav_path)),
                 enhanced.numpy(), samplerate=SAMPLE_RATE, subtype='PCM_16')


if __name__ == '__main__':
    main()