danurahul committed on
Commit
1a21884
1 Parent(s): 4d95c77

Upload model.py

Browse files
Files changed (1) hide show
  1. model.py +294 -0
model.py ADDED
@@ -0,0 +1,294 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+ import numpy as np
3
+ import miditoolkit
4
+ import modules
5
+ import pickle
6
+ import utils
7
+ import time
8
+
9
class PopMusicTransformer(object):
    """Wrapper around the memory-augmented transformer built by ``modules.transformer``.

    The instance owns a TF1-style graph and session. It restores weights from
    ``<checkpoint>/model`` and a vocabulary from ``<checkpoint>/dictionary.pkl``,
    and offers event extraction from MIDI, generation, training-data
    preparation and finetuning.
    """

    ########################################
    # initialize
    ########################################
    def __init__(self, checkpoint, is_training=False):
        """Load the vocabulary, set hyper-parameters and restore the model.

        Args:
            checkpoint: folder containing ``dictionary.pkl`` and the ``model``
                checkpoint files.
            is_training: when True the graph is built with batch size 4,
                otherwise with batch size 1.
        """
        # load dictionary
        self.dictionary_path = '{}/dictionary.pkl'.format(checkpoint)
        # NOTE(review): pickle can execute arbitrary code while loading; only
        # point this at checkpoint folders from a trusted source.
        with open(self.dictionary_path, 'rb') as dictionary_file:
            self.event2word, self.word2event = pickle.load(dictionary_file)
        # model settings
        self.x_len = 512
        self.mem_len = 512
        self.n_layer = 12
        self.d_embed = 512
        self.d_model = 512
        self.dropout = 0.1
        self.n_head = 8
        self.d_head = self.d_model // self.n_head
        self.d_ff = 2048
        self.n_token = len(self.event2word)
        self.learning_rate = 0.0002
        # number of consecutive (x, y) windows per training segment; defined
        # here so finetune() also works on data prepared elsewhere
        self.group_size = 5
        # load model
        self.is_training = is_training
        if self.is_training:
            self.batch_size = 4
        else:
            self.batch_size = 1
        self.checkpoint_path = '{}/model'.format(checkpoint)
        self.load_model()

    ########################################
    # load model
    ########################################
    def load_model(self):
        """Build the graph, the optimizer and the session, then restore weights.

        Sets ``self.x``, ``self.y``, ``self.mems_i``, ``self.logits``,
        ``self.new_mem``, ``self.avg_loss``, ``self.train_op``, ``self.saver``
        and ``self.sess``.
        """
        # placeholders
        self.x = tf.compat.v1.placeholder(tf.int32, shape=[self.batch_size, None])
        self.y = tf.compat.v1.placeholder(tf.int32, shape=[self.batch_size, None])
        self.mems_i = [
            tf.compat.v1.placeholder(tf.float32, [self.mem_len, self.batch_size, self.d_model])
            for _ in range(self.n_layer)
        ]
        # model
        self.global_step = tf.compat.v1.train.get_or_create_global_step()
        initializer = tf.compat.v1.initializers.random_normal(stddev=0.02, seed=None)
        proj_initializer = tf.compat.v1.initializers.random_normal(stddev=0.01, seed=None)
        with tf.compat.v1.variable_scope(tf.compat.v1.get_variable_scope()):
            # placeholders are [batch, time]; the transformer is fed [time, batch]
            xx = tf.transpose(self.x, [1, 0])
            yy = tf.transpose(self.y, [1, 0])
            loss, self.logits, self.new_mem = modules.transformer(
                dec_inp=xx,
                target=yy,
                mems=self.mems_i,
                n_token=self.n_token,
                n_layer=self.n_layer,
                d_model=self.d_model,
                d_embed=self.d_embed,
                n_head=self.n_head,
                d_head=self.d_head,
                d_inner=self.d_ff,
                dropout=self.dropout,
                dropatt=self.dropout,
                initializer=initializer,
                proj_initializer=proj_initializer,
                is_training=self.is_training,
                mem_len=self.mem_len,
                cutoffs=[],
                div_val=-1,
                tie_projs=[],
                same_length=False,
                clamp_len=-1,
                input_perms=None,
                target_perms=None,
                head_target=None,
                untie_r=False,
                proj_same_dim=True)
            self.avg_loss = tf.reduce_mean(loss)
        # vars
        all_vars = tf.compat.v1.trainable_variables()
        grads = tf.gradients(self.avg_loss, all_vars)
        grads_and_vars = list(zip(grads, all_vars))
        # optimizer
        decay_lr = tf.compat.v1.train.cosine_decay(
            self.learning_rate,
            global_step=self.global_step,
            decay_steps=400000,
            alpha=0.004)
        optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=decay_lr)
        self.train_op = optimizer.apply_gradients(grads_and_vars, self.global_step)
        # saver (created after the optimizer, so it also covers optimizer variables)
        self.saver = tf.compat.v1.train.Saver()
        config = tf.compat.v1.ConfigProto(allow_soft_placement=True)
        config.gpu_options.allow_growth = True
        self.sess = tf.compat.v1.Session(config=config)
        self.saver.restore(self.sess, self.checkpoint_path)

    ########################################
    # temperature sampling
    ########################################
    def temperature_sampling(self, logits, temperature, topk):
        """Sample one token id from ``logits`` with temperature and top-k filtering.

        Args:
            logits: 1-D array-like of unnormalised scores, one per token.
            temperature: divisor applied to the logits before the softmax.
            topk: number of most likely tokens to sample from; ``1`` selects
                the most likely token deterministically.

        Returns:
            The index of the chosen token.
        """
        scaled = np.asarray(logits, dtype=np.float64) / temperature
        # shifting by the maximum leaves the softmax unchanged but keeps
        # np.exp from overflowing to inf (which would turn probs into NaN)
        scaled -= np.max(scaled)
        exp_scaled = np.exp(scaled)
        probs = exp_scaled / np.sum(exp_scaled)
        if topk == 1:
            return np.argmax(probs)
        # indices of the topk most likely tokens, most likely first
        candi_index = np.argsort(probs)[::-1][:topk]
        candi_probs = probs[candi_index]
        # normalize probs
        candi_probs = candi_probs / np.sum(candi_probs)
        # choose by predicted probs
        return np.random.choice(candi_index, size=1, p=candi_probs)[0]

    ########################################
    # extract events for prompt continuation
    ########################################
    def extract_events(self, input_path):
        """Read a MIDI file and convert it into a sequence of events.

        Chord items are included only when the checkpoint path contains the
        substring ``'chord'``.

        Args:
            input_path: path handed to ``utils.read_items``.

        Returns:
            Whatever ``utils.item2event`` returns for the grouped items
            (presumably a list of event objects with ``name``/``value``).
        """
        note_items, tempo_items = utils.read_items(input_path)
        note_items = utils.quantize_items(note_items)
        # NOTE(review): uses the end of the last note item, so this raises
        # IndexError when the file yields no notes.
        max_time = note_items[-1].end
        if 'chord' in self.checkpoint_path:
            chord_items = utils.extract_chords(note_items)
            items = chord_items + tempo_items + note_items
        else:
            items = tempo_items + note_items
        groups = utils.group_items(items, max_time)
        events = utils.item2event(groups)
        return events

    ########################################
    # event -> word conversion
    ########################################
    def _events_to_words(self, events):
        """Map events to vocabulary ids, tolerating out-of-vocabulary events.

        Unknown ``Note Velocity`` events are replaced by ``Note Velocity_21``;
        any other unknown event is reported on stdout and skipped.
        """
        words = []
        for event in events:
            e = '{}_{}'.format(event.name, event.value)
            if e in self.event2word:
                words.append(self.event2word[e])
            elif event.name == 'Note Velocity':
                # OOV: replace with max velocity based on our training data
                words.append(self.event2word['Note Velocity_21'])
            else:
                # something is wrong
                # you should handle it for your own purpose
                print('something is wrong! {}'.format(e))
        return words

    ########################################
    # generate
    ########################################
    def generate(self, n_target_bar, temperature, topk, output_path, prompt=None):
        """Generate ``n_target_bar`` bars and write them to ``output_path``.

        Only the first batch entry is sampled and extended, so this is meant
        to be used with ``batch_size == 1`` (i.e. ``is_training=False``).

        Args:
            n_target_bar: number of ``Bar_None`` events to generate before stopping.
            temperature: softmax temperature passed to ``temperature_sampling``.
            topk: top-k cut-off passed to ``temperature_sampling``.
            output_path: destination passed to ``utils.write_midi``.
            prompt: optional MIDI path to continue from; when falsy the
                sequence starts from a random tempo (and chord, for chord
                checkpoints).
        """
        bar_word = self.event2word['Bar_None']
        # if prompt, load it. Or, random start
        if prompt:
            events = self.extract_events(prompt)
            # shared OOV handling so an unseen event in the prompt does not
            # abort generation with a KeyError
            words = [self._events_to_words(events)]
            words[0].append(bar_word)
        else:
            # vocabulary scans do not depend on the batch entry: do them once
            position_word = self.event2word['Position_1/16']
            tempo_classes = [v for k, v in self.event2word.items() if 'Tempo Class' in k]
            tempo_values = [v for k, v in self.event2word.items() if 'Tempo Value' in k]
            use_chords = 'chord' in self.checkpoint_path
            chords = [v for k, v in self.event2word.items() if 'Chord' in k] if use_chords else []
            words = []
            for _ in range(self.batch_size):
                ws = [bar_word]
                if use_chords:
                    ws.append(position_word)
                    ws.append(np.random.choice(chords))
                ws.append(position_word)
                ws.append(np.random.choice(tempo_classes))
                ws.append(np.random.choice(tempo_values))
                words.append(ws)
        # initialize mem
        batch_m = [
            np.zeros((self.mem_len, self.batch_size, self.d_model), dtype=np.float32)
            for _ in range(self.n_layer)
        ]
        # generate
        original_length = len(words[0])
        is_first_step = True
        current_generated_bar = 0
        while current_generated_bar < n_target_bar:
            # input: the whole seed on the first step, afterwards only the
            # newest token (earlier context is carried by the memory)
            if is_first_step:
                temp_x = np.zeros((self.batch_size, original_length), dtype=np.int32)
                for b in range(self.batch_size):
                    for z, t in enumerate(words[b]):
                        temp_x[b][z] = t
                is_first_step = False
            else:
                temp_x = np.zeros((self.batch_size, 1), dtype=np.int32)
                for b in range(self.batch_size):
                    temp_x[b][0] = words[b][-1]
            # prepare feed dict
            feed_dict = {self.x: temp_x}
            for m, m_np in zip(self.mems_i, batch_m):
                feed_dict[m] = m_np
            # model (prediction)
            _logits, _new_mem = self.sess.run([self.logits, self.new_mem], feed_dict=feed_dict)
            # sampling: last position of the first batch entry
            # (assumes logits are indexed [time, batch, ...] - TODO confirm)
            _logit = _logits[-1, 0]
            word = self.temperature_sampling(
                logits=_logit,
                temperature=temperature,
                topk=topk)
            words[0].append(word)
            # if bar event (only work for batch_size=1)
            if word == bar_word:
                current_generated_bar += 1
            # re-new mem
            batch_m = _new_mem
        # write
        if prompt:
            # only the continuation is written; the prompt file is passed
            # through to utils.write_midi separately
            utils.write_midi(
                words=words[0][original_length:],
                word2event=self.word2event,
                output_path=output_path,
                prompt_path=prompt)
        else:
            utils.write_midi(
                words=words[0],
                word2event=self.word2event,
                output_path=output_path,
                prompt_path=None)

    ########################################
    # prepare training data
    ########################################
    def prepare_data(self, midi_paths):
        """Turn MIDI files into training segments.

        Each file is converted to word ids, cut into non-overlapping windows
        of ``x_len`` tokens (``y`` is ``x`` shifted by one token), and the
        windows are grouped into runs of ``group_size`` consecutive pairs.

        Args:
            midi_paths: iterable of MIDI file paths.

        Returns:
            ``np.ndarray`` of segments; when at least one segment exists its
            shape is ``(n_segments, group_size, 2, x_len)``.
        """
        # extract events and convert them to words
        all_words = [self._events_to_words(self.extract_events(path)) for path in midi_paths]
        # to training data
        segments = []
        for words in all_words:
            pairs = []
            for i in range(0, len(words) - self.x_len - 1, self.x_len):
                x = words[i:i + self.x_len]
                y = words[i + 1:i + self.x_len + 1]
                pairs.append([x, y])
            pairs = np.array(pairs)
            # abandon the last
            # NOTE(review): the stride is group_size*2, so every other group
            # of pairs is skipped - kept as in the original; confirm intent.
            for i in np.arange(0, len(pairs) - self.group_size, self.group_size * 2):
                data = pairs[i:i + self.group_size]
                if len(data) == self.group_size:
                    segments.append(data)
        segments = np.array(segments)
        return segments

    ########################################
    # finetune
    ########################################
    def finetune(self, training_data, output_checkpoint_folder, n_epochs=200, stop_loss=0.1):
        """Finetune the restored model and save one checkpoint per epoch.

        Args:
            training_data: array as returned by ``prepare_data``.
            output_checkpoint_folder: folder the per-epoch checkpoints are
                written to (``model-<epoch>-<mean loss>``).
            n_epochs: maximum number of epochs.
            stop_loss: training stops early once the mean epoch loss is at
                or below this value.

        Raises:
            ValueError: if ``training_data`` holds fewer segments than one batch.
        """
        num_batches = len(training_data) // self.batch_size
        if num_batches == 0:
            # without this guard every epoch would run zero steps and still
            # save a checkpoint named after a NaN loss
            raise ValueError(
                'training_data has {} segment(s); at least batch_size={} are required'.format(
                    len(training_data), self.batch_size))
        # shuffle
        index = np.arange(len(training_data))
        np.random.shuffle(index)
        training_data = training_data[index]
        st = time.time()
        for e in range(n_epochs):
            total_loss = []
            for i in range(num_batches):
                segments = training_data[self.batch_size * i:self.batch_size * (i + 1)]
                # memory is reset for every batch and carried across the
                # group_size consecutive windows of each segment
                batch_m = [
                    np.zeros((self.mem_len, self.batch_size, self.d_model), dtype=np.float32)
                    for _ in range(self.n_layer)
                ]
                for j in range(self.group_size):
                    batch_x = segments[:, j, 0, :]
                    batch_y = segments[:, j, 1, :]
                    # prepare feed dict
                    feed_dict = {self.x: batch_x, self.y: batch_y}
                    for m, m_np in zip(self.mems_i, batch_m):
                        feed_dict[m] = m_np
                    # run
                    _, gs_, loss_, new_mem_ = self.sess.run(
                        [self.train_op, self.global_step, self.avg_loss, self.new_mem],
                        feed_dict=feed_dict)
                    batch_m = new_mem_
                    total_loss.append(loss_)
                    print('>>> Epoch: {}, Step: {}, Loss: {:.5f}, Time: {:.2f}'.format(e, gs_, loss_, time.time()-st))
            epoch_loss = np.mean(total_loss)
            self.saver.save(self.sess, '{}/model-{:03d}-{:.3f}'.format(output_checkpoint_folder, e, epoch_loss))
            # stop
            if epoch_loss <= stop_loss:
                break

    ########################################
    # close
    ########################################
    def close(self):
        """Close the TensorFlow session held by this instance."""
        self.sess.close()