scp4950 committed on
Commit
59e33cf
1 Parent(s): 7573b08

Create simple-generative-model-regressor.py

Files changed (1)
  1. simple-generative-model-regressor.py +118 -0
simple-generative-model-regressor.py ADDED
@@ -0,0 +1,118 @@
import sys
import time
import numpy as np
from keras.activations import relu
from scipy.io.wavfile import read, write
from keras.models import Model
from keras.layers import Convolution1D, AtrousConvolution1D, Flatten, Dense, \
    Input, Lambda, merge


def wavenetBlock(n_atrous_filters, atrous_filter_size, atrous_rate,
                 n_conv_filters, conv_filter_size):
    # n_conv_filters and conv_filter_size are accepted but unused here; the
    # skip connection uses a fixed 1x1 convolution, as in WaveNet.
    def f(input_):
        residual = input_
        # Gated activation: tanh and sigmoid branches over the same dilated
        # (atrous) convolution, multiplied elementwise.
        tanh_out = AtrousConvolution1D(n_atrous_filters, atrous_filter_size,
                                       atrous_rate=atrous_rate,
                                       border_mode='same',
                                       activation='tanh')(input_)
        sigmoid_out = AtrousConvolution1D(n_atrous_filters, atrous_filter_size,
                                          atrous_rate=atrous_rate,
                                          border_mode='same',
                                          activation='sigmoid')(input_)
        merged = merge([tanh_out, sigmoid_out], mode='mul')
        skip_out = Convolution1D(1, 1, activation='relu',
                                 border_mode='same')(merged)
        # The residual connection feeds the next block; the skip output is
        # collected separately.
        out = merge([skip_out, residual], mode='sum')
        return out, skip_out
    return f


def get_basic_generative_model(input_size):
    # Input is a single-channel 1D frame of audio samples, so the whole
    # stack stays 1D through to the final dense layer.
    input_ = Input(shape=(input_size, 1))
    l1a, l1b = wavenetBlock(10, 5, 2, 1, 3)(input_)
    l2a, l2b = wavenetBlock(1, 2, 4, 1, 3)(l1a)
    l3a, l3b = wavenetBlock(1, 2, 8, 1, 3)(l2a)
    l4a, l4b = wavenetBlock(1, 2, 16, 1, 3)(l3a)
    l5a, l5b = wavenetBlock(1, 2, 32, 1, 3)(l4a)
    # Sum the skip outputs of every block, as in WaveNet.
    l6 = merge([l1b, l2b, l3b, l4b, l5b], mode='sum')
    l7 = Lambda(relu)(l6)
    l8 = Convolution1D(1, 1, activation='relu')(l7)
    l9 = Convolution1D(1, 1)(l8)
    l10 = Flatten()(l9)
    # Regress the next sample as a single value in [-1, 1].
    l11 = Dense(1, activation='tanh')(l10)
    model = Model(input=input_, output=l11)
    model.compile(loss='mse', optimizer='rmsprop', metrics=['accuracy'])
    model.summary()
    return model


def get_audio(filename):
    # Read a wav file and normalise its samples to the range [-1, 1].
    sr, audio = read(filename)
    audio = audio.astype(float)
    audio = audio - audio.min()
    audio = audio / (audio.max() - audio.min())
    audio = (audio - 0.5) * 2
    return sr, audio


def frame_generator(sr, audio, frame_size, frame_shift):
    # Yield (frame, next_sample) training pairs indefinitely, one per batch.
    audio_len = len(audio)
    while 1:
        for i in range(0, audio_len - frame_size - 1, frame_shift):
            frame = audio[i:i + frame_size]
            if len(frame) < frame_size:
                break
            if i + frame_size >= audio_len:
                break
            temp = audio[i + frame_size]
            yield frame.reshape(1, frame_size, 1), temp.reshape(1, 1)


if __name__ == '__main__':
    n_epochs = 20
    frame_size = 2048
    frame_shift = 512
    sr_training, training_audio = get_audio('train.wav')
    training_audio = training_audio[:sr_training * 240]  # first 240 seconds
    sr_valid, valid_audio = get_audio('validate.wav')
    valid_audio = valid_audio[:sr_valid * 30]  # first 30 seconds
    assert sr_training == sr_valid, "Training, validation samplerate mismatch"
    n_training_examples = int((len(training_audio) - frame_size - 1) /
                              float(frame_shift))
    n_validation_examples = int((len(valid_audio) - frame_size - 1) /
                                float(frame_shift))
    model = get_basic_generative_model(frame_size)
    print('Total training examples: ' + str(n_training_examples))
    print('Total validation examples: ' + str(n_validation_examples))
    model.fit_generator(frame_generator(sr_training, training_audio,
                                        frame_size, frame_shift),
                        samples_per_epoch=n_training_examples,
                        nb_epoch=n_epochs,
                        validation_data=frame_generator(sr_valid, valid_audio,
                                                        frame_size,
                                                        frame_shift),
                        nb_val_samples=n_validation_examples,
                        verbose=1)
    print('Saving model...')
    str_timestamp = str(int(time.time()))
    # Assumes the models/ directory already exists.
    model.save('models/model_' + str_timestamp + '_' + str(n_epochs) + '.h5')
    print('Generating audio...')
    # Generate three seconds of audio at the training sample rate.
    new_audio = np.zeros((sr_training * 3))
    curr_sample_idx = 0
    # Copy the seed frame so the rolling context never mutates valid_audio.
    audio_context = valid_audio[:frame_size].copy()
    while curr_sample_idx < new_audio.shape[0]:
        predicted_val = model.predict(audio_context.reshape(1, frame_size,
                                                            1))[0, 0]
        # Scale the [-1, 1] prediction up to 16-bit amplitude for the output.
        ampl_val_16 = predicted_val * 2 ** 15
        new_audio[curr_sample_idx] = ampl_val_16
        # Shift the context window left, then append the normalised
        # prediction so the model keeps seeing values in [-1, 1].
        audio_context[:-1] = audio_context[1:]
        audio_context[-1] = predicted_val
        pc_str = str(round(100 * curr_sample_idx / float(new_audio.shape[0]),
                           2))
        sys.stdout.write('Percent complete: ' + pc_str + '\r')
        sys.stdout.flush()
        curr_sample_idx += 1
    outfilepath = 'output/reg_generated_' + str_timestamp + '.wav'
    print('Writing generated audio to: ' + outfilepath)
    write(outfilepath, sr_training, new_audio.astype(np.int16))
    print('\nDone!')
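
The script expects train.wav and validate.wav beside it and saves into models/ and output/ without creating them. As a minimal smoke-test sketch, placeholder fixtures can be synthesised with the same scipy routines the script already imports; the 8 kHz sample rate and 440 Hz tone below are illustrative assumptions, not values from this commit (only the 240 s and 30 s durations match the slices the script takes).

import os
import numpy as np
from scipy.io.wavfile import write

sr = 8000  # assumed sample rate; any rate works if both files share it
for name, seconds in [('train.wav', 240), ('validate.wav', 30)]:
    t = np.arange(sr * seconds) / float(sr)
    tone = (0.5 * np.sin(2 * np.pi * 440.0 * t) * 2 ** 15).astype(np.int16)
    write(name, sr, tone)

# The training script writes into these directories without creating them.
for d in ('models', 'output'):
    if not os.path.isdir(d):
        os.makedirs(d)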