| """ |
| This is an example how to implement real time processing of the DTLN ONNX |
| model in python. |
| |
| Please change the name of the .wav file at line 49 before running the sript. |
| For the ONNX runtime call: $ pip install onnxruntime |
| |
| |
| |
| Author: Nils L. Westhausen (nils.westhausen@uol.de) |
| Version: 03.07.2020 |
| |
| This code is licensed under the terms of the MIT-license. |
| """ |
|
|
| import soundfile as sf |
| import numpy as np |
| import time |
| import onnxruntime |
|
|
|
|
|
|
| |
| |
| |
# Audio framing parameters: 512-sample analysis window, 128-sample hop.
block_len = 512
block_shift = 128


def _load_model(path):
    """Create an ONNX inference session for *path*.

    Returns a tuple ``(session, input_names, feed)`` where *feed* is a
    zero-initialized input dictionary. Dynamic dimensions in the declared
    input shapes (non-int entries) are replaced by 1, which matches the
    single-frame batches used for real-time inference.
    """
    session = onnxruntime.InferenceSession(path)
    input_names = [inp.name for inp in session.get_inputs()]
    feed = {
        inp.name: np.zeros(
            [dim if isinstance(dim, int) else 1 for dim in inp.shape],
            dtype=np.float32)
        for inp in session.get_inputs()}
    return session, input_names, feed


# Stage-1 (magnitude mask) and stage-2 (time-domain) models, each with
# its persistent input feed holding the recurrent state between frames.
interpreter_1, model_input_names_1, model_inputs_1 = _load_model('./model_1.onnx')
interpreter_2, model_input_names_2, model_inputs_2 = _load_model('./model_2.onnx')
|
|
| |
# Load the noisy input file; the model is trained for 16 kHz audio only.
audio, fs = sf.read('path/to/your/favorite.wav')
if fs != 16000:
    raise ValueError('This model only supports 16k sampling rate.')

# Output signal plus the two sliding buffers used for block processing.
out_file = np.zeros(len(audio))
in_buffer = np.zeros(block_len, dtype='float32')
out_buffer = np.zeros(block_len, dtype='float32')

# Number of full hop-sized blocks that fit into the signal.
num_blocks = (len(audio) - (block_len - block_shift)) // block_shift
| |
time_array = []
for blk in range(num_blocks):
    tic = time.time()
    offset = blk * block_shift

    # Slide the analysis buffer left by one hop and append new samples.
    in_buffer[:-block_shift] = in_buffer[block_shift:]
    in_buffer[-block_shift:] = audio[offset:offset + block_shift]

    # STFT of the current frame: magnitude feeds the network,
    # the noisy phase is kept for reconstruction.
    spectrum = np.fft.rfft(in_buffer)
    magnitude = np.reshape(np.abs(spectrum), (1, 1, -1)).astype('float32')
    phase = np.angle(spectrum)

    # Stage 1: estimate the spectral mask; stash the returned
    # recurrent state back into the feed for the next frame.
    model_inputs_1[model_input_names_1[0]] = magnitude
    stage1_out = interpreter_1.run(None, model_inputs_1)
    mask = stage1_out[0]
    model_inputs_1[model_input_names_1[1]] = stage1_out[1]

    # Back to the time domain using the masked magnitude + noisy phase.
    masked_spectrum = magnitude * mask * np.exp(1j * phase)
    frame = np.reshape(np.fft.irfft(masked_spectrum), (1, 1, -1)).astype('float32')

    # Stage 2: time-domain refinement, again carrying recurrent state.
    model_inputs_2[model_input_names_2[0]] = frame
    stage2_out = interpreter_2.run(None, model_inputs_2)
    model_inputs_2[model_input_names_2[1]] = stage2_out[1]

    # Overlap-add: shift the synthesis buffer, clear the tail, add the
    # new block, and emit the first hop's worth of finished samples.
    out_buffer[:-block_shift] = out_buffer[block_shift:]
    out_buffer[-block_shift:] = np.zeros((block_shift))
    out_buffer += np.squeeze(stage2_out[0])
    out_file[offset:offset + block_shift] = out_buffer[:block_shift]

    time_array.append(time.time() - tic)
| |
| |
# Persist the enhanced signal and report the mean per-block latency.
sf.write('out.wav', out_file, fs)
mean_block_ms = np.mean(np.stack(time_array)) * 1000
print('Processing Time [ms]:')
print(mean_block_ms)
print('Processing finished.')