File size: 2,399 Bytes
45916af |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 |
import argparse
import glob
import os
import librosa
import numpy as np
import onnx
import onnxruntime
import soundfile as sf
import torch
import tqdm
from config import CONFIG
# Command-line interface: a single optional flag giving the location of the
# ONNX model file. Arguments are parsed at module level, so `args` is
# available to the script body below.
parser = argparse.ArgumentParser()
parser.add_argument(
    '--onnx_path',
    help='path to onnx',
    default=None,
)
args = parser.parse_args()
if __name__ == '__main__':
    # Run the ONNX model frame by frame over every .wav file found in
    # CONFIG.TEST.in_dir and write each result, under the same file name,
    # to CONFIG.TEST.out_dir.
    path = args.onnx_path
    if path is None:
        # --onnx_path defaults to None; stop here with a clear message
        # rather than handing None to the model loaders below.
        raise SystemExit('--onnx_path is required')

    sample_rate = 48000  # used both when loading and when writing audio
    window = CONFIG.DATA.window_size
    stride = CONFIG.DATA.stride

    # The graph is loaded with `onnx` only to read the declared input
    # shapes; the actual inference goes through the onnxruntime session.
    onnx_model = onnx.load(path)
    options = onnxruntime.SessionOptions()
    options.intra_op_num_threads = 8
    options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
    session = onnxruntime.InferenceSession(path, options)
    input_names = [x.name for x in session.get_inputs()]
    output_names = [x.name for x in session.get_outputs()]
    print(input_names)
    print(output_names)

    audio_files = glob.glob(os.path.join(CONFIG.TEST.in_dir, '*.wav'))
    # The same square-root Hann window is used for both the STFT and the
    # inverse STFT.
    hann = torch.sqrt(torch.hann_window(window))
    os.makedirs(CONFIG.TEST.out_dir, exist_ok=True)

    for wav_path in tqdm.tqdm(audio_files, total=len(audio_files)):
        sig, _ = librosa.load(wav_path, sr=sample_rate)
        sig = torch.tensor(sig)

        # Complex STFT viewed as real/imag pairs, then rearranged from
        # (freq, frames, 2) to (1, 2, freq, frames).
        spec = torch.stft(sig, window, stride, window=hann, return_complex=True)
        re_im = torch.view_as_real(spec).permute(2, 0, 1).unsqueeze(0).numpy().astype(np.float32)

        # Start every file from zero-filled tensors sized from the dims
        # declared in the graph, so no state leaks from the previous file.
        inputs = {
            input_names[i]: np.zeros(
                [d.dim_value for d in _input.type.tensor_type.shape.dim],
                dtype=np.float32,
            )
            for i, _input in enumerate(onnx_model.graph.input)
        }

        # NOTE(review): the index-based wiring below assumes the model's
        # inputs are ordered (frame, prev_mag, predictor_state, mlp_state)
        # and that its outputs are (out, prev_mag, predictor_state,
        # mlp_state) - confirm against the exported graph.
        output_audio = []
        for t in range(re_im.shape[-1]):
            # Feed the current spectrogram frame; without this assignment
            # the model would only ever see the zero-initialised input.
            inputs[input_names[0]] = re_im[:, :, :, t:t + 1]
            out, prev_mag, predictor_state, mlp_state = session.run(output_names, inputs)
            # Carry the recurrent outputs over to the next frame.
            inputs[input_names[1]] = prev_mag
            inputs[input_names[2]] = predictor_state
            inputs[input_names[3]] = mlp_state
            output_audio.append(out)

        # Stack the per-frame outputs along axis 0, swap the first two axes
        # and reinterpret the trailing pair as complex values for istft.
        # Presumably each `out` is (1, freq, 2) - TODO confirm.
        output_audio = torch.tensor(np.concatenate(output_audio, 0))
        output_audio = output_audio.permute(1, 0, 2).contiguous()
        output_audio = torch.view_as_complex(output_audio)
        output_audio = torch.istft(output_audio, window, stride, window=hann)

        out_path = os.path.join(CONFIG.TEST.out_dir, os.path.basename(wav_path))
        sf.write(out_path, output_audio.numpy(), samplerate=sample_rate, subtype='PCM_16')
|