Upload model.py
Browse files
model.py
ADDED
@@ -0,0 +1,294 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tensorflow as tf
|
2 |
+
import numpy as np
|
3 |
+
import miditoolkit
|
4 |
+
import modules
|
5 |
+
import pickle
|
6 |
+
import utils
|
7 |
+
import time
|
8 |
+
|
9 |
+
class PopMusicTransformer(object):
|
10 |
+
########################################
|
11 |
+
# initialize
|
12 |
+
########################################
|
13 |
+
def __init__(self, checkpoint, is_training=False):
|
14 |
+
# load dictionary
|
15 |
+
self.dictionary_path = '{}/dictionary.pkl'.format(checkpoint)
|
16 |
+
self.event2word, self.word2event = pickle.load(open(self.dictionary_path, 'rb'))
|
17 |
+
# model settings
|
18 |
+
self.x_len = 512
|
19 |
+
self.mem_len = 512
|
20 |
+
self.n_layer = 12
|
21 |
+
self.d_embed = 512
|
22 |
+
self.d_model = 512
|
23 |
+
self.dropout = 0.1
|
24 |
+
self.n_head = 8
|
25 |
+
self.d_head = self.d_model // self.n_head
|
26 |
+
self.d_ff = 2048
|
27 |
+
self.n_token = len(self.event2word)
|
28 |
+
self.learning_rate = 0.0002
|
29 |
+
# load model
|
30 |
+
self.is_training = is_training
|
31 |
+
if self.is_training:
|
32 |
+
self.batch_size = 4
|
33 |
+
else:
|
34 |
+
self.batch_size = 1
|
35 |
+
self.checkpoint_path = '{}/model'.format(checkpoint)
|
36 |
+
self.load_model()
|
37 |
+
|
38 |
+
########################################
|
39 |
+
# load model
|
40 |
+
########################################
|
41 |
+
def load_model(self):
|
42 |
+
# placeholders
|
43 |
+
self.x = tf.compat.v1.placeholder(tf.int32, shape=[self.batch_size, None])
|
44 |
+
self.y = tf.compat.v1.placeholder(tf.int32, shape=[self.batch_size, None])
|
45 |
+
self.mems_i = [tf.compat.v1.placeholder(tf.float32, [self.mem_len, self.batch_size, self.d_model]) for _ in range(self.n_layer)]
|
46 |
+
# model
|
47 |
+
self.global_step = tf.compat.v1.train.get_or_create_global_step()
|
48 |
+
initializer = tf.compat.v1.initializers.random_normal(stddev=0.02, seed=None)
|
49 |
+
proj_initializer = tf.compat.v1.initializers.random_normal(stddev=0.01, seed=None)
|
50 |
+
with tf.compat.v1.variable_scope(tf.compat.v1.get_variable_scope()):
|
51 |
+
xx = tf.transpose(self.x, [1, 0])
|
52 |
+
yy = tf.transpose(self.y, [1, 0])
|
53 |
+
loss, self.logits, self.new_mem = modules.transformer(
|
54 |
+
dec_inp=xx,
|
55 |
+
target=yy,
|
56 |
+
mems=self.mems_i,
|
57 |
+
n_token=self.n_token,
|
58 |
+
n_layer=self.n_layer,
|
59 |
+
d_model=self.d_model,
|
60 |
+
d_embed=self.d_embed,
|
61 |
+
n_head=self.n_head,
|
62 |
+
d_head=self.d_head,
|
63 |
+
d_inner=self.d_ff,
|
64 |
+
dropout=self.dropout,
|
65 |
+
dropatt=self.dropout,
|
66 |
+
initializer=initializer,
|
67 |
+
proj_initializer=proj_initializer,
|
68 |
+
is_training=self.is_training,
|
69 |
+
mem_len=self.mem_len,
|
70 |
+
cutoffs=[],
|
71 |
+
div_val=-1,
|
72 |
+
tie_projs=[],
|
73 |
+
same_length=False,
|
74 |
+
clamp_len=-1,
|
75 |
+
input_perms=None,
|
76 |
+
target_perms=None,
|
77 |
+
head_target=None,
|
78 |
+
untie_r=False,
|
79 |
+
proj_same_dim=True)
|
80 |
+
self.avg_loss = tf.reduce_mean(loss)
|
81 |
+
# vars
|
82 |
+
all_vars = tf.compat.v1.trainable_variables()
|
83 |
+
grads = tf.gradients(self.avg_loss, all_vars)
|
84 |
+
grads_and_vars = list(zip(grads, all_vars))
|
85 |
+
all_trainable_vars = tf.reduce_sum([tf.reduce_prod(v.shape) for v in tf.compat.v1.trainable_variables()])
|
86 |
+
# optimizer
|
87 |
+
decay_lr = tf.compat.v1.train.cosine_decay(
|
88 |
+
self.learning_rate,
|
89 |
+
global_step=self.global_step,
|
90 |
+
decay_steps=400000,
|
91 |
+
alpha=0.004)
|
92 |
+
optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=decay_lr)
|
93 |
+
self.train_op = optimizer.apply_gradients(grads_and_vars, self.global_step)
|
94 |
+
# saver
|
95 |
+
self.saver = tf.compat.v1.train.Saver()
|
96 |
+
config = tf.compat.v1.ConfigProto(allow_soft_placement=True)
|
97 |
+
config.gpu_options.allow_growth = True
|
98 |
+
self.sess = tf.compat.v1.Session(config=config)
|
99 |
+
self.saver.restore(self.sess, self.checkpoint_path)
|
100 |
+
|
101 |
+
########################################
|
102 |
+
# temperature sampling
|
103 |
+
########################################
|
104 |
+
def temperature_sampling(self, logits, temperature, topk):
|
105 |
+
probs = np.exp(logits / temperature) / np.sum(np.exp(logits / temperature))
|
106 |
+
if topk == 1:
|
107 |
+
prediction = np.argmax(probs)
|
108 |
+
else:
|
109 |
+
sorted_index = np.argsort(probs)[::-1]
|
110 |
+
candi_index = sorted_index[:topk]
|
111 |
+
candi_probs = [probs[i] for i in candi_index]
|
112 |
+
# normalize probs
|
113 |
+
candi_probs /= sum(candi_probs)
|
114 |
+
# choose by predicted probs
|
115 |
+
prediction = np.random.choice(candi_index, size=1, p=candi_probs)[0]
|
116 |
+
return prediction
|
117 |
+
|
118 |
+
########################################
|
119 |
+
# extract events for prompt continuation
|
120 |
+
########################################
|
121 |
+
def extract_events(self, input_path):
|
122 |
+
note_items, tempo_items = utils.read_items(input_path)
|
123 |
+
note_items = utils.quantize_items(note_items)
|
124 |
+
max_time = note_items[-1].end
|
125 |
+
if 'chord' in self.checkpoint_path:
|
126 |
+
chord_items = utils.extract_chords(note_items)
|
127 |
+
items = chord_items + tempo_items + note_items
|
128 |
+
else:
|
129 |
+
items = tempo_items + note_items
|
130 |
+
groups = utils.group_items(items, max_time)
|
131 |
+
events = utils.item2event(groups)
|
132 |
+
return events
|
133 |
+
|
134 |
+
########################################
|
135 |
+
# generate
|
136 |
+
########################################
|
137 |
+
def generate(self, n_target_bar, temperature, topk, output_path, prompt=None):
|
138 |
+
# if prompt, load it. Or, random start
|
139 |
+
if prompt:
|
140 |
+
events = self.extract_events(prompt)
|
141 |
+
words = [[self.event2word['{}_{}'.format(e.name, e.value)] for e in events]]
|
142 |
+
words[0].append(self.event2word['Bar_None'])
|
143 |
+
else:
|
144 |
+
words = []
|
145 |
+
for _ in range(self.batch_size):
|
146 |
+
ws = [self.event2word['Bar_None']]
|
147 |
+
if 'chord' in self.checkpoint_path:
|
148 |
+
tempo_classes = [v for k, v in self.event2word.items() if 'Tempo Class' in k]
|
149 |
+
tempo_values = [v for k, v in self.event2word.items() if 'Tempo Value' in k]
|
150 |
+
chords = [v for k, v in self.event2word.items() if 'Chord' in k]
|
151 |
+
ws.append(self.event2word['Position_1/16'])
|
152 |
+
ws.append(np.random.choice(chords))
|
153 |
+
ws.append(self.event2word['Position_1/16'])
|
154 |
+
ws.append(np.random.choice(tempo_classes))
|
155 |
+
ws.append(np.random.choice(tempo_values))
|
156 |
+
else:
|
157 |
+
tempo_classes = [v for k, v in self.event2word.items() if 'Tempo Class' in k]
|
158 |
+
tempo_values = [v for k, v in self.event2word.items() if 'Tempo Value' in k]
|
159 |
+
ws.append(self.event2word['Position_1/16'])
|
160 |
+
ws.append(np.random.choice(tempo_classes))
|
161 |
+
ws.append(np.random.choice(tempo_values))
|
162 |
+
words.append(ws)
|
163 |
+
# initialize mem
|
164 |
+
batch_m = [np.zeros((self.mem_len, self.batch_size, self.d_model), dtype=np.float32) for _ in range(self.n_layer)]
|
165 |
+
# generate
|
166 |
+
original_length = len(words[0])
|
167 |
+
initial_flag = 1
|
168 |
+
current_generated_bar = 0
|
169 |
+
while current_generated_bar < n_target_bar:
|
170 |
+
# input
|
171 |
+
if initial_flag:
|
172 |
+
temp_x = np.zeros((self.batch_size, original_length))
|
173 |
+
for b in range(self.batch_size):
|
174 |
+
for z, t in enumerate(words[b]):
|
175 |
+
temp_x[b][z] = t
|
176 |
+
initial_flag = 0
|
177 |
+
else:
|
178 |
+
temp_x = np.zeros((self.batch_size, 1))
|
179 |
+
for b in range(self.batch_size):
|
180 |
+
temp_x[b][0] = words[b][-1]
|
181 |
+
# prepare feed dict
|
182 |
+
feed_dict = {self.x: temp_x}
|
183 |
+
for m, m_np in zip(self.mems_i, batch_m):
|
184 |
+
feed_dict[m] = m_np
|
185 |
+
# model (prediction)
|
186 |
+
_logits, _new_mem = self.sess.run([self.logits, self.new_mem], feed_dict=feed_dict)
|
187 |
+
# sampling
|
188 |
+
_logit = _logits[-1, 0]
|
189 |
+
word = self.temperature_sampling(
|
190 |
+
logits=_logit,
|
191 |
+
temperature=temperature,
|
192 |
+
topk=topk)
|
193 |
+
words[0].append(word)
|
194 |
+
# if bar event (only work for batch_size=1)
|
195 |
+
if word == self.event2word['Bar_None']:
|
196 |
+
current_generated_bar += 1
|
197 |
+
# re-new mem
|
198 |
+
batch_m = _new_mem
|
199 |
+
# write
|
200 |
+
if prompt:
|
201 |
+
utils.write_midi(
|
202 |
+
words=words[0][original_length:],
|
203 |
+
word2event=self.word2event,
|
204 |
+
output_path=output_path,
|
205 |
+
prompt_path=prompt)
|
206 |
+
else:
|
207 |
+
utils.write_midi(
|
208 |
+
words=words[0],
|
209 |
+
word2event=self.word2event,
|
210 |
+
output_path=output_path,
|
211 |
+
prompt_path=None)
|
212 |
+
|
213 |
+
########################################
|
214 |
+
# prepare training data
|
215 |
+
########################################
|
216 |
+
def prepare_data(self, midi_paths):
|
217 |
+
# extract events
|
218 |
+
all_events = []
|
219 |
+
for path in midi_paths:
|
220 |
+
events = self.extract_events(path)
|
221 |
+
all_events.append(events)
|
222 |
+
# event to word
|
223 |
+
all_words = []
|
224 |
+
for events in all_events:
|
225 |
+
words = []
|
226 |
+
for event in events:
|
227 |
+
e = '{}_{}'.format(event.name, event.value)
|
228 |
+
if e in self.event2word:
|
229 |
+
words.append(self.event2word[e])
|
230 |
+
else:
|
231 |
+
# OOV
|
232 |
+
if event.name == 'Note Velocity':
|
233 |
+
# replace with max velocity based on our training data
|
234 |
+
words.append(self.event2word['Note Velocity_21'])
|
235 |
+
else:
|
236 |
+
# something is wrong
|
237 |
+
# you should handle it for your own purpose
|
238 |
+
print('something is wrong! {}'.format(e))
|
239 |
+
all_words.append(words)
|
240 |
+
# to training data
|
241 |
+
self.group_size = 5
|
242 |
+
segments = []
|
243 |
+
for words in all_words:
|
244 |
+
pairs = []
|
245 |
+
for i in range(0, len(words)-self.x_len-1, self.x_len):
|
246 |
+
x = words[i:i+self.x_len]
|
247 |
+
y = words[i+1:i+self.x_len+1]
|
248 |
+
pairs.append([x, y])
|
249 |
+
pairs = np.array(pairs)
|
250 |
+
# abandon the last
|
251 |
+
for i in np.arange(0, len(pairs)-self.group_size, self.group_size*2):
|
252 |
+
data = pairs[i:i+self.group_size]
|
253 |
+
if len(data) == self.group_size:
|
254 |
+
segments.append(data)
|
255 |
+
segments = np.array(segments)
|
256 |
+
return segments
|
257 |
+
|
258 |
+
########################################
|
259 |
+
# finetune
|
260 |
+
########################################
|
261 |
+
def finetune(self, training_data, output_checkpoint_folder):
|
262 |
+
# shuffle
|
263 |
+
index = np.arange(len(training_data))
|
264 |
+
np.random.shuffle(index)
|
265 |
+
training_data = training_data[index]
|
266 |
+
num_batches = len(training_data) // self.batch_size
|
267 |
+
st = time.time()
|
268 |
+
for e in range(200):
|
269 |
+
total_loss = []
|
270 |
+
for i in range(num_batches):
|
271 |
+
segments = training_data[self.batch_size*i:self.batch_size*(i+1)]
|
272 |
+
batch_m = [np.zeros((self.mem_len, self.batch_size, self.d_model), dtype=np.float32) for _ in range(self.n_layer)]
|
273 |
+
for j in range(self.group_size):
|
274 |
+
batch_x = segments[:, j, 0, :]
|
275 |
+
batch_y = segments[:, j, 1, :]
|
276 |
+
# prepare feed dict
|
277 |
+
feed_dict = {self.x: batch_x, self.y: batch_y}
|
278 |
+
for m, m_np in zip(self.mems_i, batch_m):
|
279 |
+
feed_dict[m] = m_np
|
280 |
+
# run
|
281 |
+
_, gs_, loss_, new_mem_ = self.sess.run([self.train_op, self.global_step, self.avg_loss, self.new_mem], feed_dict=feed_dict)
|
282 |
+
batch_m = new_mem_
|
283 |
+
total_loss.append(loss_)
|
284 |
+
print('>>> Epoch: {}, Step: {}, Loss: {:.5f}, Time: {:.2f}'.format(e, gs_, loss_, time.time()-st))
|
285 |
+
self.saver.save(self.sess, '{}/model-{:03d}-{:.3f}'.format(output_checkpoint_folder, e, np.mean(total_loss)))
|
286 |
+
# stop
|
287 |
+
if np.mean(total_loss) <= 0.1:
|
288 |
+
break
|
289 |
+
|
290 |
+
########################################
|
291 |
+
# close
|
292 |
+
########################################
|
293 |
+
def close(self):
|
294 |
+
self.sess.close()
|