# Source: Openai-Codex-REAL / simple-generative-model-regressor.py
# Commit 59e33cf by scp4950: "Create simple-generative-model-regressor.py" (4.99 kB).
import sys
import time
import numpy as np
from keras.activations import relu
from scipy.io.wavfile import read, write
from keras.models import Model, Sequential
from keras.layers import Convolution2D, AtrousConvolution2D, Flatten, Dense, \
Input, Lambda, merge
def wavenetBlock(n_atrous_filters, atrous_filter_size, atrous_rate,
                 n_conv_filters, conv_filter_size):
    """Return a closure that applies one gated, dilated residual block.

    The block computes ``tanh(conv(x)) * sigmoid(conv(x))`` with two dilated
    convolutions. It projects the gated result to a single channel with a
    ReLU 1x1 convolution (the skip output), then adds the skip output to
    the block input (the residual output).

    Inputs are 4-D ``(batch, 1, time, channels)`` tensors, matching the
    ``Input(shape=(1, input_size, 1))`` used by the model builder. The time
    axis is therefore convolved with ``(1, k)`` kernels dilated by
    ``(1, atrous_rate)``.

    Args:
        n_atrous_filters: number of filters in each gated convolution.
        atrous_filter_size: kernel length along the time axis.
        atrous_rate: dilation factor along the time axis.
        n_conv_filters: unused; kept for interface compatibility.
        conv_filter_size: unused; kept for interface compatibility.

    Returns:
        A function ``f(input_) -> (out, skip_out)``, where ``out`` is the
        residual sum and ``skip_out`` the one-channel skip connection.
    """
    def f(input_):
        residual = input_
        # Only the 2-D layers are imported and the model feeds 4-D tensors;
        # the 1-D layer names used previously were never in scope. Channels
        # last is pinned so the (1, k) kernel always runs along the time
        # axis, whatever image ordering keras.json configures.
        tanh_out = AtrousConvolution2D(n_atrous_filters, 1, atrous_filter_size,
                                       atrous_rate=(1, atrous_rate),
                                       border_mode='same',
                                       activation='tanh',
                                       dim_ordering='tf')(input_)
        sigmoid_out = AtrousConvolution2D(n_atrous_filters, 1,
                                          atrous_filter_size,
                                          atrous_rate=(1, atrous_rate),
                                          border_mode='same',
                                          activation='sigmoid',
                                          dim_ordering='tf')(input_)
        # Gated activation unit: elementwise tanh * sigmoid.
        merged = merge([tanh_out, sigmoid_out], mode='mul')
        # Project back to one channel so it can be summed with the input.
        skip_out = Convolution2D(1, 1, 1, activation='relu',
                                 border_mode='same',
                                 dim_ordering='tf')(merged)
        out = merge([skip_out, residual], mode='sum')
        return out, skip_out
    return f
def get_basic_generative_model(input_size):
    """Build, compile and summarize the WaveNet-style next-sample regressor.

    Architecture:
    - Five gated dilated blocks (dilation rates 2, 4, 8, 16, 32) are chained
      through their residual outputs.
    - Their skip outputs are summed and passed through a ReLU and two 1x1
      convolutions.
    - The result is flattened and mapped to one tanh unit, which predicts
      the sample that follows the input frame.

    Args:
        input_size: number of audio samples in one input frame.

    Returns:
        A compiled Keras ``Model`` that takes ``(batch, 1, input_size, 1)``
        tensors and outputs ``(batch, 1)`` predictions.
    """
    # Named ``inputs`` rather than ``input`` to avoid shadowing the builtin.
    inputs = Input(shape=(1, input_size, 1))
    # Per-block arguments: (n_atrous_filters, atrous_filter_size,
    # atrous_rate, n_conv_filters, conv_filter_size).
    block_specs = [
        (10, 5, 2, 1, 3),
        (1, 2, 4, 1, 3),
        (1, 2, 8, 1, 3),
        (1, 2, 16, 1, 3),
        (1, 2, 32, 1, 3),
    ]
    x = inputs
    skip_outputs = []
    for spec in block_specs:
        # Each block consumes the previous block's residual output.
        x, skip = wavenetBlock(*spec)(x)
        skip_outputs.append(skip)
    summed_skips = merge(skip_outputs, mode='sum')
    hidden = Lambda(relu)(summed_skips)
    # 1x1 convolutions; channels-last pinned to match the gated blocks.
    hidden = Convolution2D(1, 1, 1, activation='relu',
                           dim_ordering='tf')(hidden)
    hidden = Convolution2D(1, 1, 1, dim_ordering='tf')(hidden)
    hidden = Flatten()(hidden)
    prediction = Dense(1, activation='tanh')(hidden)
    model = Model(input=inputs, output=prediction)
    # 'accuracy' is meaningless for a continuous regression target, so
    # report mean absolute error alongside the MSE loss instead.
    model.compile(loss='mse', optimizer='rmsprop',
                  metrics=['mean_absolute_error'])
    model.summary()
    return model
def get_audio(filename):
    """Read a WAV file and rescale its samples linearly onto [-1, 1].

    The smallest sample maps to -1 and the largest to +1. A constant signal
    (e.g. digital silence) has no range to stretch. It is returned as all
    zeros rather than the NaNs a 0/0 division would produce.

    Args:
        filename: path of the WAV file to read.

    Returns:
        ``(sr, audio)``: the sample rate reported by
        ``scipy.io.wavfile.read`` and the rescaled float samples, with the
        same shape as the file data.
    """
    sr, audio = read(filename)
    audio = audio.astype(float)
    lo = audio.min()
    hi = audio.max()
    if hi == lo:
        # Degenerate range: avoid 0/0 and centre the signal at zero.
        return sr, np.zeros_like(audio)
    audio = (audio - lo) / (hi - lo)   # -> [0, 1]
    audio = (audio - 0.5) * 2          # -> [-1, 1]
    return sr, audio
def frame_generator(sr, audio, frame_size, frame_shift):
    """Yield ``(frame, next_sample)`` training pairs from *audio* forever.

    Frames start at offsets ``0, frame_shift, 2*frame_shift, ...`` strictly
    below ``len(audio) - frame_size - 1``. After the last offset the
    generator starts over from 0, producing the endless stream that Keras
    ``fit_generator`` consumes.

    Args:
        sr: sample rate; not used, kept for interface compatibility.
        audio: sample array. It must be 1-D (mono), because each frame is
            reshaped from exactly ``frame_size`` elements.
        frame_size: number of samples per input frame.
        frame_shift: hop between consecutive frame starts.

    Yields:
        ``(frame, target)``: *frame* is shaped ``(1, 1, frame_size, 1)``
        and *target* is shaped ``(1, 1)``. *target* holds the sample
        immediately after the frame.

    Raises:
        ValueError: if *audio* is too short to produce a single frame.
            Without this check the outer loop would spin forever without
            yielding.
    """
    # The bound guarantees i + frame_size < len(audio) - 1, so every frame
    # is full length and its target index is always valid.
    starts = range(0, len(audio) - frame_size - 1, frame_shift)
    if len(starts) == 0:
        raise ValueError(
            'audio of length %d is too short for frame_size %d'
            % (len(audio), frame_size))
    while True:
        for i in starts:
            frame = audio[i:i + frame_size]
            target = audio[i + frame_size]
            yield frame.reshape(1, 1, frame_size, 1), target.reshape(1, 1)
if __name__ == '__main__':
    # Local import: only the script entry point needs filesystem helpers.
    import os

    n_epochs = 20
    frame_size = 2048
    frame_shift = 512

    # Create output directories up front so a missing one fails before the
    # (long) training run rather than after it.
    for directory in ('models', 'output'):
        if not os.path.isdir(directory):
            os.makedirs(directory)

    sr_training, training_audio = get_audio('train.wav')
    training_audio = training_audio[:sr_training*240]   # first 240 s
    sr_valid, valid_audio = get_audio('validate.wav')
    valid_audio = valid_audio[:sr_valid*30]             # first 30 s
    if sr_training != sr_valid:
        raise ValueError("Training, validation samplerate mismatch")

    # Count exactly the frames frame_generator yields per pass (a ceiling,
    # not the floor the plain division gave).
    n_training_examples = len(range(
        0, len(training_audio) - frame_size - 1, frame_shift))
    n_validation_examples = len(range(
        0, len(valid_audio) - frame_size - 1, frame_shift))
    model = get_basic_generative_model(frame_size)
    # Single-argument print(...) behaves the same on Python 2 and 3.
    print('Total training examples: %d' % n_training_examples)
    print('Total validation examples: %d' % n_validation_examples)
    model.fit_generator(frame_generator(sr_training, training_audio,
                                        frame_size, frame_shift),
                        samples_per_epoch=n_training_examples,
                        nb_epoch=n_epochs,
                        validation_data=frame_generator(sr_valid, valid_audio,
                                                        frame_size,
                                                        frame_shift),
                        nb_val_samples=n_validation_examples,
                        verbose=1)

    print('Saving model...')
    str_timestamp = str(int(time.time()))
    model.save('models/model_'+str_timestamp+'_'+str(n_epochs)+'.h5')

    print('Generating audio...')
    n_new_samples = sr_training * 3
    # Generated samples stay in the normalized [-1, 1] domain the model was
    # trained on; they are scaled to 16-bit only when writing the file.
    new_audio = np.zeros(n_new_samples)
    # Copy so sliding the context window does not overwrite valid_audio.
    audio_context = valid_audio[:frame_size].copy()
    for curr_sample_idx in range(n_new_samples):
        predicted_val = float(model.predict(
            audio_context.reshape(1, 1, frame_size, 1))[0, 0])
        new_audio[curr_sample_idx] = predicted_val
        # Shift the window left by one, THEN append the new sample. Assigning
        # first would duplicate it into the last two slots.
        audio_context[:-1] = audio_context[1:]
        audio_context[-1] = predicted_val
        pc_str = str(round(100*curr_sample_idx/float(n_new_samples), 2))
        sys.stdout.write('Percent complete: ' + pc_str + '\r')
        sys.stdout.flush()
    sys.stdout.write('\n')

    outfilepath = 'output/reg_generated_'+str_timestamp+'.wav'
    print('Writing generated audio to: %s' % outfilepath)
    # Clip before casting: +1.0 * 2**15 = 32768 does not fit in int16 and
    # would wrap around to -32768.
    pcm = np.clip(new_audio * 2**15, -2**15, 2**15 - 1).astype(np.int16)
    write(outfilepath, sr_training, pcm)
    print('\nDone!')