Samuelblue's picture
Fixed print 'Generating audio...' problem
7980d55
raw
history blame
6.22 kB
import os
import sys
import time
import numpy as np
from keras.callbacks import Callback
from scipy.io.wavfile import read, write
from keras.models import Model, Sequential
from keras.layers import Convolution1D, AtrousConvolution1D, Flatten, Dense, \
Input, Lambda, merge, Activation
def wavenetBlock(n_atrous_filters, atrous_filter_size, atrous_rate):
def f(input_):
residual = input_
tanh_out = AtrousConvolution1D(n_atrous_filters, atrous_filter_size,
atrous_rate=atrous_rate,
border_mode='same',
activation='tanh')(input_)
sigmoid_out = AtrousConvolution1D(n_atrous_filters, atrous_filter_size,
atrous_rate=atrous_rate,
border_mode='same',
activation='sigmoid')(input_)
merged = merge([tanh_out, sigmoid_out], mode='mul')
skip_out = Convolution1D(1, 1, activation='relu', border_mode='same')(merged)
out = merge([skip_out, residual], mode='sum')
return out, skip_out
return f
def get_basic_generative_model(input_size):
input_ = Input(shape=(input_size, 1))
A, B = wavenetBlock(64, 2, 2)(input_)
skip_connections = [B]
for i in range(20):
A, B = wavenetBlock(64, 2, 2**((i+2)%9))(A)
skip_connections.append(B)
net = merge(skip_connections, mode='sum')
net = Activation('relu')(net)
net = Convolution1D(1, 1, activation='relu')(net)
net = Convolution1D(1, 1)(net)
net = Flatten()(net)
net = Dense(256, activation='softmax')(net)
model = Model(input=input_, output=net)
model.compile(loss='categorical_crossentropy', optimizer='sgd',
metrics=['accuracy'])
model.summary()
return model
def get_audio(filename):
sr, audio = read(filename)
audio = audio.astype(float)
audio = audio - audio.min()
audio = audio / (audio.max() - audio.min())
audio = (audio - 0.5) * 2
return sr, audio
def frame_generator(sr, audio, frame_size, frame_shift, minibatch_size=20):
audio_len = len(audio)
X = []
y = []
while 1:
for i in range(0, audio_len - frame_size - 1, frame_shift):
frame = audio[i:i+frame_size]
if len(frame) < frame_size:
break
if i + frame_size >= audio_len:
break
temp = audio[i + frame_size]
target_val = int((np.sign(temp) * (np.log(1 + 256*abs(temp)) / (
np.log(1+256))) + 1)/2.0 * 255)
X.append(frame.reshape(frame_size, 1))
y.append((np.eye(256)[target_val]))
if len(X) == minibatch_size:
yield np.array(X), np.array(y)
X = []
y = []
def get_audio_from_model(model, sr, duration, seed_audio):
print('Generating audio...')
new_audio = np.zeros((sr * duration))
curr_sample_idx = 0
while curr_sample_idx < new_audio.shape[0]:
distribution = np.array(model.predict(seed_audio.reshape(1,
frame_size, 1)
), dtype=float).reshape(256)
distribution /= distribution.sum().astype(float)
predicted_val = np.random.choice(range(256), p=distribution)
ampl_val_8 = ((((predicted_val) / 255.0) - 0.5) * 2.0)
ampl_val_16 = (np.sign(ampl_val_8) * (1/256.0) * ((1 + 256.0)**abs(
ampl_val_8) - 1)) * 2**15
new_audio[curr_sample_idx] = ampl_val_16
seed_audio[-1] = ampl_val_16
seed_audio[:-1] = seed_audio[1:]
pc_str = str(round(100*curr_sample_idx/float(new_audio.shape[0]), 2))
sys.stdout.write('Percent complete: ' + pc_str + '\r')
sys.stdout.flush()
curr_sample_idx += 1
print 'Audio generated.'
return new_audio.astype(np.int16)
class SaveAudioCallback(Callback):
def __init__(self, ckpt_freq, sr, seed_audio):
super(SaveAudioCallback, self).__init__()
self.ckpt_freq = ckpt_freq
self.sr = sr
self.seed_audio = seed_audio
def on_epoch_end(self, epoch, logs={}):
if (epoch+1)%self.ckpt_freq==0:
ts = str(int(time.time()))
filepath = os.path.join('output/', 'ckpt_'+ts+'.wav')
audio = get_audio_from_model(self.model, self.sr, 0.5, self.seed_audio)
write(filepath, self.sr, audio)
if __name__ == '__main__':
n_epochs = 2000
frame_size = 2048
frame_shift = 128
sr_training, training_audio = get_audio('train.wav')
# training_audio = training_audio[:sr_training*1200]
sr_valid, valid_audio = get_audio('validate.wav')
# valid_audio = valid_audio[:sr_valid*60]
assert sr_training == sr_valid, "Training, validation samplerate mismatch"
n_training_examples = int((len(training_audio)-frame_size-1) / float(
frame_shift))
n_validation_examples = int((len(valid_audio)-frame_size-1) / float(
frame_shift))
model = get_basic_generative_model(frame_size)
print 'Total training examples:', n_training_examples
print 'Total validation examples:', n_validation_examples
audio_context = valid_audio[:frame_size]
save_audio_clbk = SaveAudioCallback(100, sr_training, audio_context)
validation_data_gen = frame_generator(sr_valid, valid_audio, frame_size, frame_shift)
training_data_gen = frame_generator(sr_training, training_audio, frame_size, frame_shift)
model.fit_generator(training_data_gen, samples_per_epoch=3000, nb_epoch=n_epochs, validation_data=validation_data_gen,nb_val_samples=500, verbose=1, callbacks=[save_audio_clbk])
print('Saving model...')
str_timestamp = str(int(time.time()))
model.save('models/model_'+str_timestamp+'_'+str(n_epochs)+'.h5')
print('Generating audio...')
new_audio = get_audio_from_model(model, sr_training, 2, audio_context)
outfilepath = 'output/generated_'+str_timestamp+'.wav'
print 'Writing generated audio to:', outfilepath
write(outfilepath, sr_training, new_audio)
print '\nDone!'