scp4950 committed
Commit: 7573b08
Parent: 718b3d5

Update app.py

Files changed (1):
  1. app.py +149 -34
app.py CHANGED
@@ -1,34 +1,149 @@
- import gradio as gr
- from gradio.inputs import Dropdown
- import requests
-
- """
- Use the Gradio library to create a form with a textbox for the input text
- and an API_KEY textbox to store the key.
- Use api.openai.com to get the response.
- """
-
-
- def generate_text(input_text, api_key, temperature, max_tokens, frequency_penalty, presence_penalty, engine_id):
-     url = "https://api.openai.com/v1/engines/{engine_id}/completions".format(engine_id=engine_id)
-     data = {
-         "prompt": input_text,
-         "max_tokens": max_tokens,
-         "temperature": temperature,
-         "frequency_penalty": frequency_penalty,
-         "presence_penalty": presence_penalty
-     }
-     headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
-     response = requests.post(url=url, json=data, headers=headers)
-     if response.status_code == 200:
-         return response.json()["choices"][0]["text"]
-     else:
-         return "Error: " + response.text
-
-
- if __name__ == "__main__":
-     gr.Interface(
-         generate_text,  # function to be called
-         [gr.inputs.Textbox(lines=1), gr.inputs.Textbox(lines=1), gr.inputs.Slider(minimum=0, maximum=1, step=0.01),
-          gr.inputs.Slider(minimum=32, maximum=500, step=1), gr.inputs.Slider(minimum=0, maximum=1, step=0.01), gr.inputs.Slider(minimum=0, maximum=1, step=0.01), gr.inputs.Dropdown(["davinci-codex", "cushman-codex"])],  # input types
-         gr.outputs.Textbox()  # output type
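Note: the removed handler posted its arguments straight to the OpenAI completions endpoint; a direct call would look roughly like this (the key and all parameter values below are placeholders, not part of the commit):

```python
# Placeholder invocation of the removed handler; "sk-..." is not a real key.
result = generate_text(
    input_text="def fibonacci(n):",
    api_key="sk-...",
    temperature=0.7,
    max_tokens=128,
    frequency_penalty=0.0,
    presence_penalty=0.0,
    engine_id="davinci-codex",
)
print(result)
```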
+ import os
+ import sys
+ import time
+ import numpy as np
+ from keras.callbacks import Callback
+ from scipy.io.wavfile import read, write
+ from keras.models import Model, Sequential
+ from keras.layers import Convolution1D, AtrousConvolution1D, Flatten, Dense, \
+     Input, Lambda, merge, Activation
+
+
+ def wavenetBlock(n_atrous_filters, atrous_filter_size, atrous_rate):
+     def f(input_):
+         residual = input_
+         tanh_out = AtrousConvolution1D(n_atrous_filters, atrous_filter_size,
+                                        atrous_rate=atrous_rate,
+                                        border_mode='same',
+                                        activation='tanh')(input_)
+         sigmoid_out = AtrousConvolution1D(n_atrous_filters, atrous_filter_size,
+                                           atrous_rate=atrous_rate,
+                                           border_mode='same',
+                                           activation='sigmoid')(input_)
+         merged = merge([tanh_out, sigmoid_out], mode='mul')
+         skip_out = Convolution1D(1, 1, activation='relu', border_mode='same')(merged)
+         out = merge([skip_out, residual], mode='sum')
+         return out, skip_out
+     return f
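Note: the tanh/sigmoid pair merged with mode='mul' is WaveNet's gated activation unit, z = tanh(W_f * x) * sigmoid(W_g * x). A minimal NumPy sketch of just the gating step (shapes illustrative, not part of the commit):

```python
import numpy as np

def gated_activation(filter_out, gate_out):
    # Elementwise WaveNet gate: a tanh "filter" modulated by a sigmoid "gate".
    return np.tanh(filter_out) * (1.0 / (1.0 + np.exp(-gate_out)))

# Hypothetical conv activations shaped (timesteps, channels).
filter_out = np.random.randn(2048, 64)
gate_out = np.random.randn(2048, 64)
z = gated_activation(filter_out, gate_out)  # same shape, values in (-1, 1)
```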
+
+
+ def get_basic_generative_model(input_size):
+     input_ = Input(shape=(input_size, 1))
+     A, B = wavenetBlock(64, 2, 2)(input_)
+     skip_connections = [B]
+     for i in range(20):
+         A, B = wavenetBlock(64, 2, 2**((i+2)%9))(A)
+         skip_connections.append(B)
+     net = merge(skip_connections, mode='sum')
+     net = Activation('relu')(net)
+     net = Convolution1D(1, 1, activation='relu')(net)
+     net = Convolution1D(1, 1)(net)
+     net = Flatten()(net)
+     net = Dense(256, activation='softmax')(net)
+     model = Model(input=input_, output=net)
+     model.compile(loss='categorical_crossentropy', optimizer='sgd',
+                   metrics=['accuracy'])
+     model.summary()
+     return model
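Note: the dilation rate cycles through 2**((i+2)%9), i.e. 4, 8, ..., 256, then back to 1, 2, and so on. Each size-2 dilated convolution widens the receptive field by its rate, so a quick estimate of the stack's context window (assuming the rates as written) is:

```python
# Receptive-field estimate for the 21 blocks above (filter size 2):
# each conv of size k and dilation d adds (k - 1) * d samples of context.
rates = [2] + [2 ** ((i + 2) % 9) for i in range(20)]
receptive_field = 1 + sum((2 - 1) * d for d in rates)
print(receptive_field)  # 1037 -- well inside the 2048-sample frames
```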
+
+
+ def get_audio(filename):
+     sr, audio = read(filename)
+     audio = audio.astype(float)
+     audio = audio - audio.min()
+     audio = audio / (audio.max() - audio.min())
+     audio = (audio - 0.5) * 2
+     return sr, audio
+
+
+ def frame_generator(sr, audio, frame_size, frame_shift, minibatch_size=20):
+     audio_len = len(audio)
+     X = []
+     y = []
+     while 1:
+         for i in range(0, audio_len - frame_size - 1, frame_shift):
+             frame = audio[i:i+frame_size]
+             if len(frame) < frame_size:
+                 break
+             if i + frame_size >= audio_len:
+                 break
+             temp = audio[i + frame_size]
+             target_val = int((np.sign(temp) * (np.log(1 + 256*abs(temp)) / (
+                 np.log(1+256))) + 1)/2.0 * 255)
+             X.append(frame.reshape(frame_size, 1))
+             y.append((np.eye(256)[target_val]))
+             if len(X) == minibatch_size:
+                 yield np.array(X), np.array(y)
+                 X = []
+                 y = []
+
+
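Note: the target_val expression in frame_generator is mu-law companding followed by quantization into 256 classes (the script uses 256 where the textbook formula uses mu = 255); get_audio_from_model below applies the inverse. A standalone round-trip sketch of the two transforms, not part of the commit:

```python
import numpy as np

MU = 256.0  # matches the constant used in this script

def mu_law_encode(x):
    # Amplitude in [-1, 1] -> class index in [0, 255], as in frame_generator.
    compressed = np.sign(x) * np.log(1 + MU * np.abs(x)) / np.log(1 + MU)
    return ((compressed + 1) / 2.0 * 255).astype(int)

def mu_law_decode(val):
    # Class index in [0, 255] -> amplitude in [-1, 1], as in get_audio_from_model.
    y = (val / 255.0 - 0.5) * 2.0
    return np.sign(y) * (1.0 / MU) * ((1 + MU) ** np.abs(y) - 1)

x = np.linspace(-1, 1, 5)
print(mu_law_encode(x))                 # [0, 15, 127, 239, 255]
print(mu_law_decode(mu_law_encode(x)))  # approximately recovers x
```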
+ def get_audio_from_model(model, sr, duration, seed_audio):
+     print 'Generating audio...'
+     new_audio = np.zeros((sr * duration))
+     curr_sample_idx = 0
+     while curr_sample_idx < new_audio.shape[0]:
+         distribution = np.array(model.predict(seed_audio.reshape(1,
+                                                                  frame_size, 1)
+                                               ), dtype=float).reshape(256)
+         distribution /= distribution.sum().astype(float)
+         predicted_val = np.random.choice(range(256), p=distribution)
+         ampl_val_8 = ((((predicted_val) / 255.0) - 0.5) * 2.0)
+         ampl_val_16 = (np.sign(ampl_val_8) * (1/256.0) * ((1 + 256.0)**abs(
+             ampl_val_8) - 1)) * 2**15
+         new_audio[curr_sample_idx] = ampl_val_16
+         seed_audio[-1] = ampl_val_16
+         seed_audio[:-1] = seed_audio[1:]
+         pc_str = str(round(100*curr_sample_idx/float(new_audio.shape[0]), 2))
+         sys.stdout.write('Percent complete: ' + pc_str + '\r')
+         sys.stdout.flush()
+         curr_sample_idx += 1
+     print 'Audio generated.'
+     return new_audio.astype(np.int16)
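Note: generation here is autoregressive: sample a class from the softmax, decode it, slide it into the context window, and repeat. One step, sketched with the helpers above (model_predict stands in for model.predict; the sketch shifts the window before writing the new sample and keeps the seed in the normalized scale the network was trained on):

```python
import numpy as np

def generation_step(model_predict, seed, frame_size):
    # Predict a distribution over the 256 mu-law classes and sample from it.
    dist = model_predict(seed.reshape(1, frame_size, 1)).reshape(256)
    dist = dist / dist.sum()
    cls = np.random.choice(256, p=dist)
    ampl = mu_law_decode(cls)   # normalized amplitude; helper sketched earlier
    seed[:-1] = seed[1:]        # shift the window left first...
    seed[-1] = ampl             # ...then append the newest sample
    return ampl, seed
```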
+
+
+ class SaveAudioCallback(Callback):
+     def __init__(self, ckpt_freq, sr, seed_audio):
+         super(SaveAudioCallback, self).__init__()
+         self.ckpt_freq = ckpt_freq
+         self.sr = sr
+         self.seed_audio = seed_audio
+
+     def on_epoch_end(self, epoch, logs={}):
+         if (epoch+1)%self.ckpt_freq==0:
+             ts = str(int(time.time()))
+             filepath = os.path.join('output/', 'ckpt_'+ts+'.wav')
+             audio = get_audio_from_model(self.model, self.sr, 0.5, self.seed_audio)
+             write(filepath, self.sr, audio)
+
+
+ if __name__ == '__main__':
+     n_epochs = 2000
+     frame_size = 2048
+     frame_shift = 128
+     sr_training, training_audio = get_audio('train.wav')
+     # training_audio = training_audio[:sr_training*1200]
+     sr_valid, valid_audio = get_audio('validate.wav')
+     # valid_audio = valid_audio[:sr_valid*60]
+     assert sr_training == sr_valid, "Training, validation samplerate mismatch"
+     n_training_examples = int((len(training_audio)-frame_size-1) / float(
+         frame_shift))
+     n_validation_examples = int((len(valid_audio)-frame_size-1) / float(
+         frame_shift))
+     model = get_basic_generative_model(frame_size)
+     print 'Total training examples:', n_training_examples
+     print 'Total validation examples:', n_validation_examples
+     audio_context = valid_audio[:frame_size]
+     save_audio_clbk = SaveAudioCallback(100, sr_training, audio_context)
+     validation_data_gen = frame_generator(sr_valid, valid_audio, frame_size, frame_shift)
+     training_data_gen = frame_generator(sr_training, training_audio, frame_size, frame_shift)
+     model.fit_generator(training_data_gen, samples_per_epoch=3000, nb_epoch=n_epochs, validation_data=validation_data_gen, nb_val_samples=500, verbose=1, callbacks=[save_audio_clbk])
+     print 'Saving model...'
+     str_timestamp = str(int(time.time()))
+     model.save('models/model_'+str_timestamp+'_'+str(n_epochs)+'.h5')
+     print 'Generating audio...'
+     new_audio = get_audio_from_model(model, sr_training, 2, audio_context)
+     outfilepath = 'output/generated_'+str_timestamp+'.wav'
+     print 'Writing generated audio to:', outfilepath
+     write(outfilepath, sr_training, new_audio)
+     print '\nDone!'
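Note: this script targets the Keras 1 API (AtrousConvolution1D, merge, fit_generator with samples_per_epoch/nb_epoch) and Python 2 print statements. For readers on current tf.keras, a rough equivalent of wavenetBlock, not part of this commit, would be:

```python
from tensorflow.keras.layers import Conv1D, Multiply, Add

# Approximate tf.keras mapping of wavenetBlock above:
#   AtrousConvolution1D(n, k, atrous_rate=d, border_mode='same')
#       -> Conv1D(n, k, dilation_rate=d, padding='same')
#   merge([a, b], mode='mul') -> Multiply()([a, b])
#   merge([a, b], mode='sum') -> Add()([a, b])
def wavenet_block(n_filters, filter_size, dilation_rate):
    def f(x):
        tanh_out = Conv1D(n_filters, filter_size, dilation_rate=dilation_rate,
                          padding='same', activation='tanh')(x)
        sigm_out = Conv1D(n_filters, filter_size, dilation_rate=dilation_rate,
                          padding='same', activation='sigmoid')(x)
        gated = Multiply()([tanh_out, sigm_out])
        skip = Conv1D(1, 1, padding='same', activation='relu')(gated)
        out = Add()([skip, x])
        return out, skip
    return f
```

Likewise, model.fit_generator(..., samples_per_epoch=3000, nb_epoch=n_epochs, nb_val_samples=500) corresponds to model.fit(..., steps_per_epoch=..., epochs=n_epochs, validation_steps=...) in current Keras, with the caveat that samples_per_epoch counted samples while steps_per_epoch counts batches.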