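"""Train a small WaveNet-style next-sample regressor on raw audio with
Keras 1.x, then generate a few seconds of new audio autoregressively.

Expects train.wav and validate.wav in the working directory, and existing
models/ and output/ directories for the saved model and generated wav.
"""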
import sys
import time
import numpy as np
from keras.activations import relu
from scipy.io.wavfile import read, write
from keras.models import Model
from keras.layers import AtrousConvolution1D, Convolution1D, Flatten, Dense, \
    Input, Lambda, merge


def wavenetBlock(n_atrous_filters, atrous_filter_size, atrous_rate):
    def f(input_):
        residual = input_
        # Gated activation: tanh and sigmoid branches over the same dilated
        # convolution, multiplied elementwise.
        tanh_out = AtrousConvolution1D(n_atrous_filters, atrous_filter_size,
                                       atrous_rate=atrous_rate,
                                       border_mode='same',
                                       activation='tanh')(input_)
        sigmoid_out = AtrousConvolution1D(n_atrous_filters, atrous_filter_size,
                                          atrous_rate=atrous_rate,
                                          border_mode='same',
                                          activation='sigmoid')(input_)
        merged = merge([tanh_out, sigmoid_out], mode='mul')
        # A 1x1 convolution gives the skip output; adding the block input
        # back gives the residual output fed to the next block.
        skip_out = Convolution1D(1, 1, activation='relu',
                                 border_mode='same')(merged)
        out = merge([skip_out, residual], mode='sum')
        return out, skip_out
    return f


def get_basic_generative_model(input_size):
    input_ = Input(shape=(input_size, 1))
    # Five dilated convolution blocks with doubling dilation rates
    # (2, 4, 8, 16, 32); each returns (residual, skip) outputs.
    l1a, l1b = wavenetBlock(10, 5, 2)(input_)
    l2a, l2b = wavenetBlock(1, 2, 4)(l1a)
    l3a, l3b = wavenetBlock(1, 2, 8)(l2a)
    l4a, l4b = wavenetBlock(1, 2, 16)(l3a)
    l5a, l5b = wavenetBlock(1, 2, 32)(l4a)
    # Sum the skip connections, post-process with 1x1 convolutions, and
    # regress a single next-sample value in [-1, 1].
    l6 = merge([l1b, l2b, l3b, l4b, l5b], mode='sum')
    l7 = Lambda(relu)(l6)
    l8 = Convolution1D(1, 1, activation='relu')(l7)
    l9 = Convolution1D(1, 1)(l8)
    l10 = Flatten()(l9)
    l11 = Dense(1, activation='tanh')(l10)
    model = Model(input=input_, output=l11)
    model.compile(loss='mse', optimizer='rmsprop')
    model.summary()
    return model


def get_audio(filename):
    # Read a wav file and rescale its samples to [-1, 1].
    sr, audio = read(filename)
    audio = audio.astype(float)
    audio = audio - audio.min()
    audio = audio / (audio.max() - audio.min())
    audio = (audio - 0.5) * 2
    return sr, audio


def frame_generator(sr, audio, frame_size, frame_shift):
    # Endlessly yield (frame, next_sample) pairs for fit_generator. The
    # range bound already guarantees a full frame plus a target sample.
    audio_len = len(audio)
    while 1:
        for i in range(0, audio_len - frame_size - 1, frame_shift):
            frame = audio[i:i+frame_size]
            temp = audio[i + frame_size]
            yield frame.reshape(1, frame_size, 1), temp.reshape(1, 1)


if __name__ == '__main__':
    n_epochs = 20
    frame_size = 2048
    frame_shift = 512
    # Use up to 240 s of training audio and 30 s of validation audio.
    sr_training, training_audio = get_audio('train.wav')
    training_audio = training_audio[:sr_training*240]
    sr_valid, valid_audio = get_audio('validate.wav')
    valid_audio = valid_audio[:sr_valid*30]
    assert sr_training == sr_valid, "Training, validation samplerate mismatch"
    n_training_examples = int((len(training_audio)-frame_size-1) / float(
        frame_shift))
    n_validation_examples = int((len(valid_audio)-frame_size-1) / float(
        frame_shift))
    model = get_basic_generative_model(frame_size)
    print('Total training examples: ' + str(n_training_examples))
    print('Total validation examples: ' + str(n_validation_examples))
    model.fit_generator(frame_generator(sr_training, training_audio,
                                        frame_size, frame_shift),
                        samples_per_epoch=n_training_examples,
                        nb_epoch=n_epochs,
                        validation_data=frame_generator(sr_valid, valid_audio,
                                                        frame_size,
                                                        frame_shift),
                        nb_val_samples=n_validation_examples,
                        verbose=1)
    print('Saving model...')
    str_timestamp = str(int(time.time()))
    model.save('models/model_'+str_timestamp+'_'+str(n_epochs)+'.h5')
    print('Generating audio...')
    new_audio = np.zeros((sr_training * 3))  # three seconds of output
    curr_sample_idx = 0
    # Seed the context window with the start of the validation audio; copy
    # so the sliding window does not overwrite valid_audio in place.
    audio_context = valid_audio[:frame_size].copy()
    while curr_sample_idx < new_audio.shape[0]:
        predicted_val = model.predict(
            audio_context.reshape(1, frame_size, 1))[0, 0]
        # Store the prediction as a 16-bit sample, then slide the
        # (still normalized) context window and append the new sample.
        new_audio[curr_sample_idx] = predicted_val * 2**15
        audio_context[:-1] = audio_context[1:]
        audio_context[-1] = predicted_val
        pc_str = str(round(100*curr_sample_idx/float(new_audio.shape[0]), 2))
        sys.stdout.write('Percent complete: ' + pc_str + '\r')
        sys.stdout.flush()
        curr_sample_idx += 1
    outfilepath = 'output/reg_generated_'+str_timestamp+'.wav'
    print('Writing generated audio to: ' + outfilepath)
    write(outfilepath, sr_training, new_audio.astype(np.int16))
    print('\nDone!')
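# A quick shape sanity check for the generator (a minimal sketch, assuming
# train.wav exists and the frame_size/frame_shift defaults above; not part
# of the original training flow):
#
#   sr, audio = get_audio('train.wav')
#   x, y = next(frame_generator(sr, audio, 2048, 512))
#   assert x.shape == (1, 2048, 1) and y.shape == (1, 1)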