import tensorflow as tf
import numpy as np
import miditoolkit
import modules
import pickle
import utils
import time

# placeholders and Sessions require graph mode when running under TF2's compat layer
tf.compat.v1.disable_eager_execution()


class PopMusicTransformer(object):
    ########################################
    # initialize
    ########################################
    def __init__(self, checkpoint, is_training=False):
        # load dictionary
        self.dictionary_path = '{}/dictionary.pkl'.format(checkpoint)
        self.event2word, self.word2event = pickle.load(open(self.dictionary_path, 'rb'))
        # model settings
        self.x_len = 512
        self.mem_len = 512
        self.n_layer = 12
        self.d_embed = 512
        self.d_model = 512
        self.dropout = 0.1
        self.n_head = 8
        self.d_head = self.d_model // self.n_head
        self.d_ff = 2048
        self.n_token = len(self.event2word)
        self.learning_rate = 0.0002
        # load model
        self.is_training = is_training
        if self.is_training:
            self.batch_size = 4
        else:
            self.batch_size = 1
        self.checkpoint_path = '{}/model'.format(checkpoint)
        self.load_model()

    ########################################
    # load model
    ########################################
    def load_model(self):
        # placeholders
        self.x = tf.compat.v1.placeholder(tf.int32, shape=[self.batch_size, None])
        self.y = tf.compat.v1.placeholder(tf.int32, shape=[self.batch_size, None])
        self.mems_i = [tf.compat.v1.placeholder(tf.float32, [self.mem_len, self.batch_size, self.d_model]) for _ in range(self.n_layer)]
        # model
        self.global_step = tf.compat.v1.train.get_or_create_global_step()
        initializer = tf.compat.v1.initializers.random_normal(stddev=0.02, seed=None)
        proj_initializer = tf.compat.v1.initializers.random_normal(stddev=0.01, seed=None)
        with tf.compat.v1.variable_scope(tf.compat.v1.get_variable_scope()):
            xx = tf.transpose(self.x, [1, 0])
            yy = tf.transpose(self.y, [1, 0])
            loss, self.logits, self.new_mem = modules.transformer(
                dec_inp=xx,
                target=yy,
                mems=self.mems_i,
                n_token=self.n_token,
                n_layer=self.n_layer,
                d_model=self.d_model,
                d_embed=self.d_embed,
                n_head=self.n_head,
                d_head=self.d_head,
                d_inner=self.d_ff,
                dropout=self.dropout,
                dropatt=self.dropout,
                initializer=initializer,
                proj_initializer=proj_initializer,
                is_training=self.is_training,
                mem_len=self.mem_len,
                cutoffs=[],
                div_val=-1,
                tie_projs=[],
                same_length=False,
                clamp_len=-1,
                input_perms=None,
                target_perms=None,
                head_target=None,
                untie_r=False,
                proj_same_dim=True)
        self.avg_loss = tf.reduce_mean(loss)
        # vars
        all_vars = tf.compat.v1.trainable_variables()
        grads = tf.gradients(self.avg_loss, all_vars)
        grads_and_vars = list(zip(grads, all_vars))
        # total parameter count (computed but not used further)
        all_trainable_vars = tf.reduce_sum([tf.reduce_prod(v.shape) for v in tf.compat.v1.trainable_variables()])
        # optimizer
        decay_lr = tf.compat.v1.train.cosine_decay(
            self.learning_rate,
            global_step=self.global_step,
            decay_steps=400000,
            alpha=0.004)
        optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=decay_lr)
        self.train_op = optimizer.apply_gradients(grads_and_vars, self.global_step)
        # saver
        self.saver = tf.compat.v1.train.Saver()
        config = tf.compat.v1.ConfigProto(allow_soft_placement=True)
        config.gpu_options.allow_growth = True
        self.sess = tf.compat.v1.Session(config=config)
        self.saver.restore(self.sess, self.checkpoint_path)

    ########################################
    # temperature sampling
    ########################################
    def temperature_sampling(self, logits, temperature, topk):
        # softmax with the max subtracted first, so exp() cannot overflow
        scaled = logits / temperature
        scaled = scaled - np.max(scaled)
        probs = np.exp(scaled) / np.sum(np.exp(scaled))
        if topk == 1:
            prediction = np.argmax(probs)
        else:
            sorted_index = np.argsort(probs)[::-1]
            candi_index = sorted_index[:topk]
            candi_probs = np.asarray([probs[i] for i in candi_index])
            # normalize probs over the top-k candidates
            candi_probs /= np.sum(candi_probs)
            # choose by predicted probs
            prediction = np.random.choice(candi_index, size=1, p=candi_probs)[0]
        return prediction
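    # Worked example (illustrative, not from the original source): with logits
    # [2.0, 1.0, 0.0] and temperature=1.0, softmax gives probs ~[0.665, 0.245, 0.090].
    # With topk=2 the two largest are kept and renormalized to ~[0.731, 0.269],
    # so only index 0 or 1 can ever be sampled. Raising the temperature flattens
    # the distribution: temperature=2.0 gives ~[0.506, 0.307, 0.186] before the cut.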
    ########################################
    # extract events for prompt continuation
    ########################################
    def extract_events(self, input_path):
        note_items, tempo_items = utils.read_items(input_path)
        note_items = utils.quantize_items(note_items)
        max_time = note_items[-1].end
        if 'chord' in self.checkpoint_path:
            chord_items = utils.extract_chords(note_items)
            items = chord_items + tempo_items + note_items
        else:
            items = tempo_items + note_items
        groups = utils.group_items(items, max_time)
        events = utils.item2event(groups)
        return events

    ########################################
    # generate
    ########################################
    def generate(self, n_target_bar, temperature, topk, output_path, prompt=None):
        # if a prompt is given, continue from it; otherwise start at random
        if prompt:
            events = self.extract_events(prompt)
            words = [[self.event2word['{}_{}'.format(e.name, e.value)] for e in events]]
            words[0].append(self.event2word['Bar_None'])
        else:
            words = []
            for _ in range(self.batch_size):
                ws = [self.event2word['Bar_None']]
                if 'chord' in self.checkpoint_path:
                    tempo_classes = [v for k, v in self.event2word.items() if 'Tempo Class' in k]
                    tempo_values = [v for k, v in self.event2word.items() if 'Tempo Value' in k]
                    chords = [v for k, v in self.event2word.items() if 'Chord' in k]
                    ws.append(self.event2word['Position_1/16'])
                    ws.append(np.random.choice(chords))
                    ws.append(self.event2word['Position_1/16'])
                    ws.append(np.random.choice(tempo_classes))
                    ws.append(np.random.choice(tempo_values))
                else:
                    tempo_classes = [v for k, v in self.event2word.items() if 'Tempo Class' in k]
                    tempo_values = [v for k, v in self.event2word.items() if 'Tempo Value' in k]
                    ws.append(self.event2word['Position_1/16'])
                    ws.append(np.random.choice(tempo_classes))
                    ws.append(np.random.choice(tempo_values))
                words.append(ws)
        # initialize mem
        batch_m = [np.zeros((self.mem_len, self.batch_size, self.d_model), dtype=np.float32) for _ in range(self.n_layer)]
        # generate
        original_length = len(words[0])
        initial_flag = 1
        current_generated_bar = 0
        while current_generated_bar < n_target_bar:
            # input: feed the whole prompt once, then only the latest token,
            # since the recurrence memory carries the earlier context
            if initial_flag:
                temp_x = np.zeros((self.batch_size, original_length), dtype=np.int32)
                for b in range(self.batch_size):
                    for z, t in enumerate(words[b]):
                        temp_x[b][z] = t
                initial_flag = 0
            else:
                temp_x = np.zeros((self.batch_size, 1), dtype=np.int32)
                for b in range(self.batch_size):
                    temp_x[b][0] = words[b][-1]
            # prepare feed dict
            feed_dict = {self.x: temp_x}
            for m, m_np in zip(self.mems_i, batch_m):
                feed_dict[m] = m_np
            # model (prediction)
            _logits, _new_mem = self.sess.run([self.logits, self.new_mem], feed_dict=feed_dict)
            # sampling
            _logit = _logits[-1, 0]
            word = self.temperature_sampling(
                logits=_logit,
                temperature=temperature,
                topk=topk)
            words[0].append(word)
            # count generated bars on each bar event (only works for batch_size=1)
            if word == self.event2word['Bar_None']:
                current_generated_bar += 1
            # renew mem
            batch_m = _new_mem
        # write
        if prompt:
            utils.write_midi(
                words=words[0][original_length:],
                word2event=self.word2event,
                output_path=output_path,
                prompt_path=prompt)
        else:
            utils.write_midi(
                words=words[0],
                word2event=self.word2event,
                output_path=output_path,
                prompt_path=None)
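    # Note on the generation loop above: `new_mem` returned by each session run
    # is the Transformer-XL recurrence memory, a list of n_layer arrays of shape
    # (mem_len, batch_size, d_model), matching the `mems_i` placeholders. Feeding
    # it back on the next step is what lets the model attend beyond the single
    # token in `temp_x`, so the effective context is mem_len plus the new input.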
    ########################################
    # prepare training data
    ########################################
    def prepare_data(self, midi_paths):
        # extract events
        all_events = []
        for path in midi_paths:
            events = self.extract_events(path)
            all_events.append(events)
        # event to word
        all_words = []
        for events in all_events:
            words = []
            for event in events:
                e = '{}_{}'.format(event.name, event.value)
                if e in self.event2word:
                    words.append(self.event2word[e])
                else:
                    # OOV
                    if event.name == 'Note Velocity':
                        # replace with max velocity based on our training data
                        words.append(self.event2word['Note Velocity_21'])
                    else:
                        # something is wrong
                        # you should handle it for your own purpose
                        print('something is wrong! {}'.format(e))
            all_words.append(words)
        # to training data: slice each song into (x, y) pairs of length x_len,
        # where y is x shifted right by one token, then bundle consecutive pairs
        # into groups of group_size so the memory can be carried across them
        self.group_size = 5
        segments = []
        for words in all_words:
            pairs = []
            for i in range(0, len(words) - self.x_len - 1, self.x_len):
                x = words[i:i + self.x_len]
                y = words[i + 1:i + self.x_len + 1]
                pairs.append([x, y])
            pairs = np.array(pairs)
            # abandon the last group if it is shorter than group_size
            for i in np.arange(0, len(pairs) - self.group_size, self.group_size * 2):
                data = pairs[i:i + self.group_size]
                if len(data) == self.group_size:
                    segments.append(data)
        segments = np.array(segments)
        return segments

    ########################################
    # finetune
    ########################################
    def finetune(self, training_data, output_checkpoint_folder):
        # shuffle
        index = np.arange(len(training_data))
        np.random.shuffle(index)
        training_data = training_data[index]
        num_batches = len(training_data) // self.batch_size
        st = time.time()
        for e in range(200):
            total_loss = []
            for i in range(num_batches):
                segments = training_data[self.batch_size * i:self.batch_size * (i + 1)]
                # reset memory at the start of each group of segments
                batch_m = [np.zeros((self.mem_len, self.batch_size, self.d_model), dtype=np.float32) for _ in range(self.n_layer)]
                for j in range(self.group_size):
                    batch_x = segments[:, j, 0, :]
                    batch_y = segments[:, j, 1, :]
                    # prepare feed dict
                    feed_dict = {self.x: batch_x, self.y: batch_y}
                    for m, m_np in zip(self.mems_i, batch_m):
                        feed_dict[m] = m_np
                    # run
                    _, gs_, loss_, new_mem_ = self.sess.run([self.train_op, self.global_step, self.avg_loss, self.new_mem], feed_dict=feed_dict)
                    batch_m = new_mem_
                    total_loss.append(loss_)
                    print('>>> Epoch: {}, Step: {}, Loss: {:.5f}, Time: {:.2f}'.format(e, gs_, loss_, time.time() - st))
            self.saver.save(self.sess, '{}/model-{:03d}-{:.3f}'.format(output_checkpoint_folder, e, np.mean(total_loss)))
            # stop early once the average loss is low enough
            if np.mean(total_loss) <= 0.1:
                break

    ########################################
    # close
    ########################################
    def close(self):
        self.sess.close()
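
# A minimal usage sketch (assumptions: a pretrained checkpoint folder named
# 'REMI-tempo-checkpoint' containing dictionary.pkl and the model files, and a
# prompt MIDI at './data/evaluation/000.midi' — adjust both paths to your setup):
if __name__ == '__main__':
    model = PopMusicTransformer(checkpoint='REMI-tempo-checkpoint', is_training=False)
    # generate 16 bars from scratch
    model.generate(
        n_target_bar=16,
        temperature=1.2,
        topk=5,
        output_path='./result/from_scratch.midi',
        prompt=None)
    # or continue a prompt MIDI instead:
    # model.generate(n_target_bar=16, temperature=1.2, topk=5,
    #                output_path='./result/continuation.midi',
    #                prompt='./data/evaluation/000.midi')
    model.close()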