| | |
| | """ |
| | This File contains everything to train the DTLN model. |
| | |
| | For running the training see "run_training.py". |
| | To run evaluation with the provided pretrained model see "run_evaluation.py". |
| | |
| | Author: Nils L. Westhausen (nils.westhausen@uol.de) |
| | Version: 24.06.2020 |
| | |
| | This code is licensed under the terms of the MIT-license. |
| | """ |
| |
|
| |
|
| | import os, fnmatch |
| | import tensorflow.keras as keras |
| | from tensorflow.keras.models import Model |
| | from tensorflow.keras.layers import Activation, Dense, LSTM, Dropout, \ |
| | Lambda, Input, Multiply, Layer, Conv1D |
| | from tensorflow.keras.callbacks import ReduceLROnPlateau, CSVLogger, \ |
| | EarlyStopping, ModelCheckpoint |
| | import tensorflow as tf |
| | import soundfile as sf |
| | from wavinfo import WavInfoReader |
| | from random import shuffle, seed |
| | import numpy as np |
| |
|
| |
|
| |
|
class audio_generator():
    '''
    Class to create a Tensorflow dataset based on an iterator from a large scale
    audio dataset. This audio generator only supports single channel audio files.
    '''

    def __init__(self, path_to_input, path_to_s1, len_of_samples, fs, train_flag=False):
        '''
        Constructor of the audio generator class.
        Inputs:
            path_to_input    path to the mixtures
            path_to_s1       path to the target source data
            len_of_samples   length of audio snippets in samples
            fs               sampling rate
            train_flag       flag to activate shuffling of files
        '''
        # set inputs to properties
        self.path_to_input = path_to_input
        self.path_to_s1 = path_to_s1
        self.len_of_samples = len_of_samples
        self.fs = fs
        self.train_flag = train_flag
        # count the number of samples in the data set (header reads only,
        # but can still take a moment on slow storage)
        self.count_samples()
        # create the iterable tf.data.Dataset object
        self.create_tf_data_obj()

    def count_samples(self):
        '''
        Method to list the data of the dataset and count the number of samples.

        Only WAV headers are read (via WavInfoReader), so no audio data is
        loaded here. Sets self.file_names and self.total_samples.
        '''
        # list all .wav files in the mixture directory
        self.file_names = fnmatch.filter(os.listdir(self.path_to_input), '*.wav')
        # accumulate the number of whole snippets of len_of_samples per file;
        # the trailing remainder of each file is dropped
        self.total_samples = 0
        for file in self.file_names:
            info = WavInfoReader(os.path.join(self.path_to_input, file))
            self.total_samples += int(np.fix(info.data.frame_count / self.len_of_samples))

    def create_generator(self):
        '''
        Method to create the iterator.

        Yields (noisy, speech) snippet pairs, each a float32 array of length
        self.len_of_samples. Mixture and target files must share file names.

        Raises:
            ValueError: if a file's sampling rate does not match self.fs or
                the audio has more than one channel.
        '''
        # during training the file order is reshuffled on every pass
        if self.train_flag:
            shuffle(self.file_names)
        for file in self.file_names:
            # read mixture and corresponding target speech
            noisy, fs_1 = sf.read(os.path.join(self.path_to_input, file))
            speech, fs_2 = sf.read(os.path.join(self.path_to_s1, file))
            # sanity checks on sampling rate and channel count
            if fs_1 != self.fs or fs_2 != self.fs:
                raise ValueError('Sampling rates do not match.')
            if noisy.ndim != 1 or speech.ndim != 1:
                # BUGFIX: the original message used a backslash continuation
                # inside the string literal, embedding a run of indentation
                # whitespace in the runtime error text.
                raise ValueError('Too many audio channels. The DTLN audio_generator '
                                 'only supports single channel audio data.')
            # number of whole snippets in this file (remainder is dropped)
            num_samples = int(np.fix(noisy.shape[0] / self.len_of_samples))
            # iterate over the snippets of the current file
            for idx in range(num_samples):
                start = int(idx * self.len_of_samples)
                stop = int((idx + 1) * self.len_of_samples)
                # yield the snippet pair as float32, the expected dtype of
                # the tf.data pipeline
                yield (noisy[start:stop].astype('float32'),
                       speech[start:stop].astype('float32'))

    def create_tf_data_obj(self):
        '''
        Method to create the tf.data.Dataset from the generator.
        '''
        # creating the tf.data.Dataset from the iterator; both outputs are
        # 1-D float32 tensors of fixed length len_of_samples
        self.tf_data_set = tf.data.Dataset.from_generator(
            self.create_generator,
            (tf.float32, tf.float32),
            output_shapes=(tf.TensorShape([self.len_of_samples]),
                           tf.TensorShape([self.len_of_samples])),
            args=None
        )
| |
|
| | |
| | |
| |
|
| |
|
class DTLN_model():
    '''
    Class to create and train the DTLN model
    '''

    def __init__(self):
        '''
        Constructor. Sets the default hyper-parameters, seeds all random
        number generators for reproducibility and enables GPU memory growth.
        '''

        # defining default cost function
        self.cost_function = self.snr_cost
        # empty property for the model
        self.model = []
        # defining default parameters
        self.fs = 16000                 # sampling rate [Hz]
        self.batchsize = 32
        self.len_samples = 15           # length of training snippets [s]
        self.activation = 'sigmoid'     # activation of the mask output
        self.numUnits = 128             # LSTM units per layer
        self.numLayer = 2               # LSTM layers per separation core
        self.blockLen = 512             # frame length [samples]
        self.block_shift = 128          # frame shift [samples]
        self.dropout = 0.25
        self.lr = 1e-3
        self.max_epochs = 200
        self.encoder_size = 256         # filters of the learned 1D-Conv transform
        self.eps = 1e-7
        # reset all seeds to 42 for reproducibility across training runs
        os.environ['PYTHONHASHSEED']=str(42)
        seed(42)
        np.random.seed(42)
        tf.random.set_seed(42)
        # enable memory growth so TF does not pre-allocate all GPU memory
        physical_devices = tf.config.experimental.list_physical_devices('GPU')
        if len(physical_devices) > 0:
            for device in physical_devices:
                tf.config.experimental.set_memory_growth(device, enable=True)

    @staticmethod
    def snr_cost(s_estimate, s_true):
        '''
        Static Method defining the cost function.
        The negative signal to noise ratio is calculated here. The loss is
        always calculated over the last dimension.
        '''

        # calculating the SNR; the hard-coded 1e-7 guards against division
        # by zero (self.eps is not reachable from a static method)
        snr = tf.reduce_mean(tf.math.square(s_true), axis=-1, keepdims=True) / \
            (tf.reduce_mean(tf.math.square(s_true-s_estimate), axis=-1, keepdims=True)+1e-7)
        # TF has no log10, so log10(x) is computed as ln(x)/ln(10)
        num = tf.math.log(snr)
        denom = tf.math.log(tf.constant(10, dtype=num.dtype))
        loss = -10*(num / (denom))
        # returns the negative SNR in dB per sequence
        return loss

    def lossWrapper(self):
        '''
        A wrapper function which returns the loss function. This is done to
        enable additional arguments to the loss function if necessary.
        '''
        def lossFunction(y_true,y_pred):
            # calculating loss and squeezing single dimensions away
            loss = tf.squeeze(self.cost_function(y_pred,y_true))
            # calculate mean over batches
            loss = tf.reduce_mean(loss)
            # return the loss
            return loss

        # returning the loss function as handle for Keras
        return lossFunction

    '''
    In the following some helper layers are defined.
    '''

    def stftLayer(self, x):
        '''
        Method for an STFT helper layer used with a Lambda layer. The layer
        calculates the STFT on the last dimension and returns the magnitude and
        phase of the STFT.
        '''

        # creating frames from the continuous waveform
        frames = tf.signal.frame(x, self.blockLen, self.block_shift)
        # calculating the fft over the frames (rfft -> blockLen//2+1 bins)
        stft_dat = tf.signal.rfft(frames)
        # calculating magnitude and phase from the complex spectrum
        mag = tf.abs(stft_dat)
        phase = tf.math.angle(stft_dat)
        # returning magnitude and phase as list
        return [mag, phase]

    def fftLayer(self, x):
        '''
        Method for an fft helper layer used with a Lambda layer. The layer
        calculates the rFFT on the last dimension and returns the magnitude and
        phase of the STFT. Used by the stateful model, which processes one
        time-domain frame at a time.
        '''

        # expanding dimensions so a single frame forms a length-1 sequence
        frame = tf.expand_dims(x, axis=1)
        # calculating the fft over the frame
        stft_dat = tf.signal.rfft(frame)
        # calculating magnitude and phase from the complex spectrum
        mag = tf.abs(stft_dat)
        phase = tf.math.angle(stft_dat)
        # returning magnitude and phase as list
        return [mag, phase]

    def ifftLayer(self, x):
        '''
        Method for an inverse FFT layer used with an Lambda layer. This layer
        calculates time domain frames from magnitude and phase information.
        As input x a list with [mag,phase] is required.
        '''

        # reconstructing the complex spectrum as mag * exp(1j * phase)
        s1_stft = (tf.cast(x[0], tf.complex64) *
                   tf.exp( (1j * tf.cast(x[1], tf.complex64))))
        # returning the time domain frames
        return tf.signal.irfft(s1_stft)

    def overlapAddLayer(self, x):
        '''
        Method for an overlap and add helper layer used with a Lambda layer.
        This layer reconstructs the waveform from a framed signal.
        '''

        # overlap and add the frames with the model's frame shift
        return tf.signal.overlap_and_add(x, self.block_shift)

    def seperation_kernel(self, num_layer, mask_size, x, stateful=False):
        '''
        Method to create a separation kernel.
        !! Important !!: Do not use this layer with a Lambda layer. If used with
        a Lambda layer the gradients are not updated correctly.

        Inputs:
            num_layer       Number of LSTM layers
            mask_size       Output size of the mask and size of the Dense layer
            x               input tensor
            stateful        create stateful LSTMs (for frame-wise inference)

        Returns:
            mask tensor in [0, 1] (sigmoid by default) of size mask_size
        '''

        # stacking the LSTM layers
        for idx in range(num_layer):
            x = LSTM(self.numUnits, return_sequences=True, stateful=stateful)(x)
            # using dropout between the LSTM layers for regularization
            if idx<(num_layer-1):
                x = Dropout(self.dropout)(x)
        # creating the mask with a Dense and an Activation layer
        mask = Dense(mask_size)(x)
        mask = Activation(self.activation)(mask)
        # returning the mask
        return mask

    def seperation_kernel_with_states(self, num_layer, mask_size, x,
                                      in_states):
        '''
        Method to create a separation kernel, which returns the LSTM states.
        Used for the TF-lite conversion, where the state must be carried
        across calls explicitly.
        !! Important !!: Do not use this layer with a Lambda layer. If used with
        a Lambda layer the gradients are not updated correctly.

        Inputs:
            num_layer       Number of LSTM layers
            mask_size       Output size of the mask and size of the Dense layer
            x               input tensor
            in_states       tensor (1, num_layer, numUnits, 2) holding the
                            incoming h and c states of all LSTM layers

        Returns:
            (mask, out_states) with out_states shaped like in_states
        '''

        states_h = []
        states_c = []
        # stacking the LSTM layers, feeding each its own incoming state
        for idx in range(num_layer):
            # slice [h, c] of layer idx out of the combined state tensor
            in_state = [in_states[:,idx,:, 0], in_states[:,idx,:, 1]]
            x, h_state, c_state = LSTM(self.numUnits, return_sequences=True,
                              unroll=True, return_state=True)(x, initial_state=in_state)
            # using dropout between the LSTM layers for regularization
            if idx<(num_layer-1):
                x = Dropout(self.dropout)(x)
            states_h.append(h_state)
            states_c.append(c_state)
        # creating the mask with a Dense and an Activation layer
        mask = Dense(mask_size)(x)
        mask = Activation(self.activation)(mask)
        # repack the outgoing states into one (1, num_layer, numUnits, 2) tensor
        out_states_h = tf.reshape(tf.stack(states_h, axis=0),
                                  [1,num_layer,self.numUnits])
        out_states_c = tf.reshape(tf.stack(states_c, axis=0),
                                  [1,num_layer,self.numUnits])
        out_states = tf.stack([out_states_h, out_states_c], axis=-1)
        # returning the mask and the LSTM states
        return mask, out_states

    def build_DTLN_model(self, norm_stft=False):
        '''
        Method to build and compile the DTLN model. The model takes time domain
        batches of size (batchsize, len_in_samples) and returns enhanced clips
        in the same dimensions. As optimizer for the Training process the Adam
        optimizer with a gradient norm clipping of 3 is used.
        The model contains two separation cores. The first has an STFT signal
        transformation and the second a learned transformation based on 1D-Conv
        layer.
        '''

        # input layer for time signal
        time_dat = Input(batch_shape=(None, None))
        # calculate STFT
        mag,angle = Lambda(self.stftLayer)(time_dat)
        # normalizing log magnitude stfts to get more robust against level variations
        if norm_stft:
            mag_norm = InstantLayerNormalization()(tf.math.log(mag + 1e-7))
        else:
            # behaviour like original DTLN model
            mag_norm = mag
        # predicting mask with separation kernel 1 (STFT domain)
        mask_1 = self.seperation_kernel(self.numLayer, (self.blockLen//2+1), mag_norm)
        # multiply mask with magnitude
        estimated_mag = Multiply()([mag, mask_1])
        # transform frames back to time domain, reusing the noisy phase
        estimated_frames_1 = Lambda(self.ifftLayer)([estimated_mag,angle])
        # encode time domain frames to feature domain (learned transform)
        encoded_frames = Conv1D(self.encoder_size,1,strides=1,use_bias=False)(estimated_frames_1)
        # normalize the input to the second separation core
        encoded_frames_norm = InstantLayerNormalization()(encoded_frames)
        # predict mask with separation kernel 2 (learned feature domain)
        mask_2 = self.seperation_kernel(self.numLayer, self.encoder_size, encoded_frames_norm)
        # multiply encoded frames with the mask
        estimated = Multiply()([encoded_frames, mask_2])
        # decode the frames back to time domain
        decoded_frames = Conv1D(self.blockLen, 1, padding='causal',use_bias=False)(estimated)
        # create waveform with overlap and add procedure
        estimated_sig = Lambda(self.overlapAddLayer)(decoded_frames)

        # create the model
        self.model = Model(inputs=time_dat, outputs=estimated_sig)
        # show the model summary
        print(self.model.summary())

    def build_DTLN_model_stateful(self, norm_stft=False):
        '''
        Method to build stateful DTLN model for real time processing. The model
        takes one time domain frame of size (1, blockLen) and returns one
        enhanced frame. Mirrors build_DTLN_model, but uses the single-frame
        fftLayer, stateful LSTMs and no overlap-add (done outside the model).
        '''

        # input layer for one time frame
        time_dat = Input(batch_shape=(1, self.blockLen))
        # calculate FFT of the single frame
        mag,angle = Lambda(self.fftLayer)(time_dat)
        # normalizing log magnitude stfts to get more robust against level variations
        if norm_stft:
            mag_norm = InstantLayerNormalization()(tf.math.log(mag + 1e-7))
        else:
            # behaviour like original DTLN model
            mag_norm = mag
        # predicting mask with separation kernel 1 (STFT domain, stateful)
        mask_1 = self.seperation_kernel(self.numLayer, (self.blockLen//2+1), mag_norm, stateful=True)
        # multiply mask with magnitude
        estimated_mag = Multiply()([mag, mask_1])
        # transform frame back to time domain, reusing the noisy phase
        estimated_frames_1 = Lambda(self.ifftLayer)([estimated_mag,angle])
        # encode time domain frame to feature domain (learned transform)
        encoded_frames = Conv1D(self.encoder_size,1,strides=1,use_bias=False)(estimated_frames_1)
        # normalize the input to the second separation core
        encoded_frames_norm = InstantLayerNormalization()(encoded_frames)
        # predict mask with separation kernel 2 (learned feature domain, stateful)
        mask_2 = self.seperation_kernel(self.numLayer, self.encoder_size, encoded_frames_norm, stateful=True)
        # multiply encoded frame with the mask
        estimated = Multiply()([encoded_frames, mask_2])
        # decode the frame back to time domain
        decoded_frame = Conv1D(self.blockLen, 1, padding='causal',use_bias=False)(estimated)
        # create the model
        self.model = Model(inputs=time_dat, outputs=decoded_frame)
        # show the model summary
        print(self.model.summary())

    def compile_model(self):
        '''
        Method to compile the model for training. Uses Adam with the
        configured learning rate and a gradient norm clipping of 3.
        '''

        # NOTE(review): the `lr` keyword is deprecated in newer Keras in
        # favour of `learning_rate`; kept as-is for TF 2.3 compatibility.
        optimizerAdam = keras.optimizers.Adam(lr=self.lr, clipnorm=3.0)
        # compile model with loss function
        self.model.compile(loss=self.lossWrapper(), optimizer=optimizerAdam)

    def create_saved_model(self, weights_file, target_name):
        '''
        Method to create a saved model folder from a weights file.

        Inputs:
            weights_file    path to the .h5 weights file
            target_name     output path of the SavedModel folder
        '''

        # check for type of model encoded in the file name
        if weights_file.find('_norm_') != -1:
            norm_stft = True
        else:
            norm_stft = False
        # build the stateful inference model
        self.build_DTLN_model_stateful(norm_stft=norm_stft)
        # load weights
        self.model.load_weights(weights_file)
        # save model
        tf.saved_model.save(self.model, target_name)

    def create_tf_lite_model(self, weights_file, target_name, use_dynamic_range_quant=False):
        '''
        Method to create a tf lite model folder from a weights file.
        The conversion creates two models, one for each separation core.
        Tf lite does not support complex numbers yet. Some processing must be
        done outside the model.
        For further information and how real time processing can be
        implemented see "real_time_processing_tf_lite.py".

        The conversion only works with TF 2.3.

        Inputs:
            weights_file                path to the .h5 weights file
            target_name                 prefix for the two .tflite outputs
            use_dynamic_range_quant     enable dynamic range quantization
        '''

        # check for type of model; the weight count of the first core depends
        # on whether the optional normalization layer (gamma + beta = 2 extra
        # weight tensors) is present; each LSTM contributes 3 weight tensors
        # and the Dense mask layer 2
        if weights_file.find('_norm_') != -1:
            norm_stft = True
            num_elements_first_core = 2 + self.numLayer * 3 + 2
        else:
            norm_stft = False
            num_elements_first_core = self.numLayer * 3 + 2
        # build the stateful model to pull the trained weights from
        self.build_DTLN_model_stateful(norm_stft=norm_stft)
        # load weights
        self.model.load_weights(weights_file)

        #### Model 1 ##########################
        # inputs: magnitude frame and the combined LSTM states of core 1
        mag = Input(batch_shape=(1, 1, (self.blockLen//2+1)))
        states_in_1 = Input(batch_shape=(1, self.numLayer, self.numUnits, 2))
        # normalizing log magnitude stfts to get more robust against level variations
        if norm_stft:
            mag_norm = InstantLayerNormalization()(tf.math.log(mag + 1e-7))
        else:
            # behaviour like original DTLN model
            mag_norm = mag
        # predicting mask with separation kernel 1, carrying states explicitly
        mask_1, states_out_1 = self.seperation_kernel_with_states(self.numLayer,
                                                                  (self.blockLen//2+1),
                                                                  mag_norm, states_in_1)
        # creating model 1
        model_1 = Model(inputs=[mag, states_in_1], outputs=[mask_1, states_out_1])

        #### Model 2 ###########################
        # inputs: masked time-domain frame and the combined LSTM states of core 2
        estimated_frame_1 = Input(batch_shape=(1, 1, (self.blockLen)))
        states_in_2 = Input(batch_shape=(1, self.numLayer, self.numUnits, 2))
        # encode frame to feature domain
        encoded_frames = Conv1D(self.encoder_size,1,strides=1,
                                use_bias=False)(estimated_frame_1)
        # normalize the input to the second separation core
        encoded_frames_norm = InstantLayerNormalization()(encoded_frames)
        # predicting mask with separation kernel 2, carrying states explicitly
        mask_2, states_out_2 = self.seperation_kernel_with_states(self.numLayer,
                                                                  self.encoder_size,
                                                                  encoded_frames_norm,
                                                                  states_in_2)
        # multiply encoded frames with the mask
        estimated = Multiply()([encoded_frames, mask_2])
        # decode the frames back to time domain
        decoded_frame = Conv1D(self.blockLen, 1, padding='causal',
                               use_bias=False)(estimated)
        # create model 2
        model_2 = Model(inputs=[estimated_frame_1, states_in_2],
                        outputs=[decoded_frame, states_out_2])

        # set weights to submodels; the split is positional, so the layer
        # creation order above must match the stateful training model
        weights = self.model.get_weights()
        model_1.set_weights(weights[:num_elements_first_core])
        model_2.set_weights(weights[num_elements_first_core:])
        # convert model 1 to tf lite
        converter = tf.lite.TFLiteConverter.from_keras_model(model_1)
        if use_dynamic_range_quant:
            converter.optimizations = [tf.lite.Optimize.DEFAULT]
        tflite_model = converter.convert()
        with tf.io.gfile.GFile(target_name + '_1.tflite', 'wb') as f:
            f.write(tflite_model)
        # convert model 2 to tf lite
        converter = tf.lite.TFLiteConverter.from_keras_model(model_2)
        if use_dynamic_range_quant:
            converter.optimizations = [tf.lite.Optimize.DEFAULT]
        tflite_model = converter.convert()
        with tf.io.gfile.GFile(target_name + '_2.tflite', 'wb') as f:
            f.write(tflite_model)

        print('TF lite conversion complete!')

    def train_model(self, runName, path_to_train_mix, path_to_train_speech, \
                    path_to_val_mix, path_to_val_speech):
        '''
        Method to train the DTLN model.

        Inputs:
            runName                 name used for the model/checkpoint folder
            path_to_train_mix       folder with training mixtures
            path_to_train_speech    folder with training target speech
            path_to_val_mix         folder with validation mixtures
            path_to_val_speech      folder with validation target speech
        '''

        # create save path if not existent
        savePath = './models_'+ runName+'/'
        if not os.path.exists(savePath):
            os.makedirs(savePath)
        # create log file writer
        csv_logger = CSVLogger(savePath+ 'training_' +runName+ '.log')
        # create callback for the adaptive learning rate
        reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.5,
                                      patience=3, min_lr=10**(-10), cooldown=1)
        # create callback for early stopping
        early_stopping = EarlyStopping(monitor='val_loss', min_delta=0,
                                       patience=10, verbose=0, mode='auto', baseline=None)
        # create model check pointer to save the best model (weights only)
        checkpointer = ModelCheckpoint(savePath+runName+'.h5',
                                       monitor='val_loss',
                                       verbose=1,
                                       save_best_only=True,
                                       save_weights_only=True,
                                       mode='auto',
                                       save_freq='epoch'
                                       )

        # calculate length of audio chunks in samples, rounded down to a
        # whole number of block shifts
        len_in_samples = int(np.fix(self.fs * self.len_samples /
                             self.block_shift)*self.block_shift)
        # create data generator for training data
        generator_input = audio_generator(path_to_train_mix,
                                          path_to_train_speech,
                                          len_in_samples,
                                          self.fs, train_flag=True)
        dataset = generator_input.tf_data_set
        dataset = dataset.batch(self.batchsize, drop_remainder=True).repeat()
        # calculate number of training steps in one epoch
        steps_train = generator_input.total_samples//self.batchsize
        # create data generator for validation data
        generator_val = audio_generator(path_to_val_mix,
                                        path_to_val_speech,
                                        len_in_samples, self.fs)
        dataset_val = generator_val.tf_data_set
        dataset_val = dataset_val.batch(self.batchsize, drop_remainder=True).repeat()
        # calculate number of validation steps
        steps_val = generator_val.total_samples//self.batchsize
        # start the training of the model
        self.model.fit(
            x=dataset,
            batch_size=None,
            steps_per_epoch=steps_train,
            epochs=self.max_epochs,
            verbose=1,
            validation_data=dataset_val,
            validation_steps=steps_val,
            callbacks=[checkpointer, reduce_lr, csv_logger, early_stopping],
            max_queue_size=50,
            workers=4,
            use_multiprocessing=True)
        # clear out garbage
        tf.keras.backend.clear_session()
| |
|
| | |
| |
|
class InstantLayerNormalization(Layer):
    '''
    Instant (channel-wise) layer normalization, as proposed by
    Luo & Mesgarani (https://arxiv.org/abs/1809.07454v2). Normalizes each
    frame over its last axis and applies a learned scale and offset.
    '''

    def __init__(self, **kwargs):
        '''
        Constructor: set epsilon and placeholders for the trainable weights.
        '''
        super(InstantLayerNormalization, self).__init__(**kwargs)
        # small constant guarding against division by zero
        self.epsilon = 1e-7
        self.gamma = None
        self.beta = None

    def build(self, input_shape):
        '''
        Create the per-channel scale (gamma) and offset (beta) weights.
        '''
        weight_shape = input_shape[-1:]
        # gamma is created before beta on purpose: code elsewhere restores
        # weights positionally, so the creation order is load-bearing
        self.gamma = self.add_weight(shape=weight_shape,
                                     initializer='ones',
                                     trainable=True,
                                     name='gamma')
        self.beta = self.add_weight(shape=weight_shape,
                                    initializer='zeros',
                                    trainable=True,
                                    name='beta')

    def call(self, inputs):
        '''
        Normalize the input over its last axis, then scale and shift.
        '''
        # per-frame statistics over the channel axis
        mean = tf.math.reduce_mean(inputs, axis=[-1], keepdims=True)
        variance = tf.math.reduce_mean(tf.math.square(inputs - mean),
                                       axis=[-1], keepdims=True)
        # standardize, then apply the learned affine transform
        normalized = (inputs - mean) / tf.math.sqrt(variance + self.epsilon)
        return normalized * self.gamma + self.beta
| | |
| |
|
| | |
| |
|