danurahul committed on
Commit
a399c30
1 Parent(s): cfae8e2

Upload utils.py

Browse files
Files changed (1) hide show
  1. utils.py +348 -0
utils.py ADDED
@@ -0,0 +1,348 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import chord_recognition
2
+ import numpy as np
3
+ import miditoolkit
4
+ import copy
5
+
6
# parameters for input
# NOTE: the builtin ``int`` is used as dtype instead of the ``np.int`` alias,
# which was deprecated in NumPy 1.20 and removed in NumPy 1.24. The resulting
# arrays are identical.
# 33 velocity bin edges: 0, 4, 8, ..., 128
DEFAULT_VELOCITY_BINS = np.linspace(0, 128, 32+1, dtype=int)
# number of positions a bar is divided into
DEFAULT_FRACTION = 16
# 64 duration bins in ticks: 60, 120, ..., 3840
DEFAULT_DURATION_BINS = np.arange(60, 3841, 60, dtype=int)
# tempo classes, in order: slow, mid, fast
DEFAULT_TEMPO_INTERVALS = [range(30, 90), range(90, 150), range(150, 210)]

# parameters for output
# ticks per beat used for written MIDI files and for the beat/bar grids
DEFAULT_RESOLUTION = 480
14
+
15
# define "Item" for general storage
class Item:
    """Plain record for one timed musical item.

    Attributes:
        name: kind of item; this module uses 'Note', 'Tempo' and 'Chord'.
        start: start time in ticks.
        end: end time in ticks (``None`` for tempo items).
        velocity: note velocity (``None`` for tempo and chord items).
        pitch: note pitch for notes; tempo items store the tempo value here
            and chord items store the chord label.
    """

    def __init__(self, name, start, end, velocity, pitch):
        self.name, self.start, self.end = name, start, end
        self.velocity, self.pitch = velocity, pitch

    def __repr__(self):
        fields = (self.name, self.start, self.end, self.velocity, self.pitch)
        return 'Item(name={}, start={}, end={}, velocity={}, pitch={})'.format(*fields)
27
+
28
# read notes and tempo changes from midi (assume there is only one track)
def read_items(file_path):
    """Read note items and per-beat tempo items from a MIDI file.

    Only the first instrument of the file is read. Tempo changes are expanded
    onto a grid with one entry every ``DEFAULT_RESOLUTION`` ticks, from tick 0
    up to the last tempo change; a grid tick without a tempo change at exactly
    that tick repeats the previous grid value.

    Args:
        file_path: path of the MIDI file to parse.

    Returns:
        Tuple ``(note_items, tempo_items)`` of ``Item`` lists. Notes are
        ordered by ``(start, pitch)``; tempo items store the integer tempo in
        ``pitch``. ``tempo_items`` is empty when the file has no tempo change.
    """
    midi_obj = miditoolkit.midi.parser.MidiFile(file_path)
    # note
    notes = midi_obj.instruments[0].notes
    # sorts the parsed note list in place; ordering by (start, pitch) already
    # implies ordering by start, so no second sort is needed
    notes.sort(key=lambda x: (x.start, x.pitch))
    note_items = [
        Item(
            name='Note',
            start=note.start,
            end=note.end,
            velocity=note.velocity,
            pitch=note.pitch)
        for note in notes
    ]
    # tempo
    tempo_items = sorted(
        (Item(
            name='Tempo',
            start=tempo.time,
            end=None,
            velocity=None,
            pitch=int(tempo.tempo))
         for tempo in midi_obj.tempo_changes),
        key=lambda x: x.start)
    if not tempo_items:
        # nothing to expand; previously this raised IndexError
        return note_items, tempo_items
    # expand to all beat
    max_tick = tempo_items[-1].start
    existing_ticks = {item.start: item.pitch for item in tempo_items}
    # seed with the earliest tempo so a file whose first tempo change is not
    # at tick 0 no longer fails on the first grid tick
    current_tempo = tempo_items[0].pitch
    output = []
    for tick in np.arange(0, max_tick+1, DEFAULT_RESOLUTION):
        current_tempo = existing_ticks.get(tick, current_tempo)
        output.append(Item(
            name='Tempo',
            start=tick,
            end=None,
            velocity=None,
            pitch=current_tempo))
    return note_items, output
75
+
76
# quantize items
def quantize_items(items, ticks=120):
    """Snap every item's start to the nearest multiple of ``ticks``.

    The items are modified in place: ``start`` moves to the nearest grid tick
    (the lower one on an exact tie) and ``end`` is shifted by the same amount,
    so durations are preserved.

    Args:
        items: list of objects with numeric ``start`` and ``end`` attributes.
        ticks: grid spacing in ticks.

    Returns:
        The same ``items`` list (unchanged when it is empty).
    """
    if not items:
        return items
    # grid: extend one step past the latest start so late items can snap
    # forward as well as backward, and so the grid is never empty
    last_start = max(item.start for item in items)
    grids = np.arange(0, last_start + ticks, ticks, dtype=int)
    # process
    for item in items:
        index = np.argmin(abs(grids - item.start))
        shift = int(grids[index]) - item.start
        item.start += shift
        item.end += shift
    return items
87
+
88
# extract chord
def extract_chords(items):
    """Run chord recognition on note items and wrap the result as chord Items.

    Args:
        items: note items handed to ``chord_recognition.MIDIChord().extract``.

    Returns:
        List of ``Item`` named 'Chord'. ``start``/``end`` come from the first
        two fields of each recognised chord; ``pitch`` holds the third field
        with everything from the first '/' onwards removed.
    """
    recognizer = chord_recognition.MIDIChord()
    return [
        Item(
            name='Chord',
            start=found[0],
            end=found[1],
            velocity=None,
            pitch=found[2].split('/')[0])
        for found in recognizer.extract(notes=items)
    ]
101
+
102
# group items
def group_items(items, max_time, ticks_per_bar=DEFAULT_RESOLUTION*4):
    """Split items into consecutive bars.

    ``items`` is sorted by ``start`` in place. Bars begin at
    0, ``ticks_per_bar``, 2 * ``ticks_per_bar``, ... and cover ``max_time``.
    An item belongs to the bar with ``bar_start <= item.start < bar_end``;
    items starting outside every bar are left out.

    Args:
        items: list of objects with a ``start`` attribute (sorted in place).
        max_time: latest tick that has to be covered by a bar.
        ticks_per_bar: bar length in ticks.

    Returns:
        One list per bar, empty bars included, shaped
        ``[bar_start, item, ..., item, bar_end]``.
    """
    items.sort(key=lambda x: x.start)
    downbeats = np.arange(0, max_time+ticks_per_bar, ticks_per_bar)
    groups = []
    # items are sorted, so a single cursor walks them once across all bars
    # instead of rescanning the whole list for every bar
    cursor = 0
    n_items = len(items)
    for db1, db2 in zip(downbeats[:-1], downbeats[1:]):
        # skip items that start before this bar (only possible before bar 0)
        while cursor < n_items and items[cursor].start < db1:
            cursor += 1
        first = cursor
        while cursor < n_items and items[cursor].start < db2:
            cursor += 1
        groups.append([db1] + items[first:cursor] + [db2])
    return groups
115
+
116
# define "Event" for event storage
class Event:
    """Plain record for one event of the event sequence.

    Attributes:
        name: event type, e.g. 'Bar', 'Position', 'Note On'.
        time: tick the event refers to, or ``None``.
        value: event value, or ``None``.
        text: free-form annotation, or ``None``.
    """

    def __init__(self, name, time, value, text):
        self.name, self.time = name, time
        self.value, self.text = value, text

    def __repr__(self):
        fields = (self.name, self.time, self.value, self.text)
        return 'Event(name={}, time={}, value={}, text={})'.format(*fields)
127
+
128
# item to event
def item2event(groups):
    """Convert bar groups into a flat list of events.

    Bars that contain no 'Note' item are skipped entirely. Every kept bar
    emits a 'Bar' event; every item in it emits a 'Position' event followed by
      * 'Note Velocity', 'Note On', 'Note Duration' for a note,
      * 'Chord' for a chord,
      * 'Tempo Class', 'Tempo Value' for a tempo item.

    Args:
        groups: bars shaped ``[bar_start, item, ..., item, bar_end]``.

    Returns:
        List of ``Event``.
    """
    events = []
    n_downbeat = 0
    slow, mid, fast = DEFAULT_TEMPO_INTERVALS
    for group in groups:
        bar_items = group[1:-1]
        if not any(item.name == 'Note' for item in bar_items):
            continue
        bar_st, bar_et = group[0], group[-1]
        n_downbeat += 1
        events.append(Event(
            name='Bar',
            time=None,
            value=None,
            text='{}'.format(n_downbeat)))
        # the position grid depends only on the bar, so build it once per bar
        flags = np.linspace(bar_st, bar_et, DEFAULT_FRACTION, endpoint=False)
        for item in bar_items:
            # position
            index = np.argmin(abs(flags-item.start))
            events.append(Event(
                name='Position',
                time=item.start,
                value='{}/{}'.format(index+1, DEFAULT_FRACTION),
                text='{}'.format(item.start)))
            if item.name == 'Note':
                # velocity: index of the bin whose lower edge is <= velocity
                velocity_index = np.searchsorted(
                    DEFAULT_VELOCITY_BINS,
                    item.velocity,
                    side='right') - 1
                events.append(Event(
                    name='Note Velocity',
                    time=item.start,
                    value=velocity_index,
                    text='{}/{}'.format(item.velocity, DEFAULT_VELOCITY_BINS[velocity_index])))
                # pitch
                events.append(Event(
                    name='Note On',
                    time=item.start,
                    value=item.pitch,
                    text='{}'.format(item.pitch)))
                # duration: nearest duration bin
                duration = item.end - item.start
                duration_index = np.argmin(abs(DEFAULT_DURATION_BINS-duration))
                events.append(Event(
                    name='Note Duration',
                    time=item.start,
                    value=duration_index,
                    text='{}/{}'.format(duration, DEFAULT_DURATION_BINS[duration_index])))
            elif item.name == 'Chord':
                events.append(Event(
                    name='Chord',
                    time=item.start,
                    value=item.pitch,
                    text='{}'.format(item.pitch)))
            elif item.name == 'Tempo':
                tempo = item.pitch
                # The intervals are contiguous, so comparing against their
                # bounds classifies every tempo. A tempo equal to fast.stop
                # used to match no branch at all; it is now clamped like any
                # other tempo above the fast interval.
                if tempo < slow.start:
                    tempo_class, tempo_offset = 'slow', 0
                elif tempo < slow.stop:
                    tempo_class, tempo_offset = 'slow', tempo - slow.start
                elif tempo < mid.stop:
                    tempo_class, tempo_offset = 'mid', tempo - mid.start
                elif tempo < fast.stop:
                    tempo_class, tempo_offset = 'fast', tempo - fast.start
                else:
                    # clamp to the last value of the fast interval (59)
                    tempo_class, tempo_offset = 'fast', len(fast) - 1
                events.append(Event('Tempo Class', item.start, tempo_class, None))
                events.append(Event('Tempo Value', item.start, tempo_offset, None))
    return events
205
+
206
+ #############################################################################################
207
+ # WRITE MIDI
208
+ #############################################################################################
209
def word_to_event(words, word2event):
    """Translate vocabulary ids back into ``Event`` objects.

    Each entry of ``word2event`` is a string ``'<event name>_<value>'``.

    Args:
        words: iterable of vocabulary ids.
        word2event: mapping from vocabulary id to its event string.

    Returns:
        List of ``Event`` with ``name`` and ``value`` set (``value`` is kept
        as a string) and ``time``/``text`` set to ``None``.

    Raises:
        AttributeError: if a word is missing, because ``word2event.get``
            returns ``None``.
        ValueError: if an event string contains no underscore.
    """
    events = []
    for word in words:
        # split on the first underscore only, so a value that itself contains
        # an underscore no longer breaks the two-way unpacking
        event_name, event_value = word2event.get(word).split('_', 1)
        events.append(Event(event_name, None, event_value, None))
    return events
215
+
216
def _follows(events, start, names):
    """Return True when the events from index ``start`` carry exactly ``names``, in order."""
    window = events[start:start + len(names)]
    return len(window) == len(names) and all(
        event.name == name for event, name in zip(window, names))


def _collect_bar_relative(events):
    """Split the event stream into note, chord and tempo lists.

    Each list holds the marker ``'Bar'`` at every bar change (except a 'Bar'
    event at index 0) and, in between, entries whose first field is the
    0-based position inside the current bar.
    """
    temp_notes, temp_chords, temp_tempos = [], [], []
    tempo_base = {
        'slow': DEFAULT_TEMPO_INTERVALS[0].start,
        'mid': DEFAULT_TEMPO_INTERVALS[1].start,
        'fast': DEFAULT_TEMPO_INTERVALS[2].start,
    }
    for i, event in enumerate(events):
        if event.name == 'Bar':
            if i > 0:
                temp_notes.append('Bar')
                temp_chords.append('Bar')
                temp_tempos.append('Bar')
            continue
        if event.name != 'Position':
            continue
        # 'k/N' -> 0-based index k-1
        position = int(event.value.split('/')[0]) - 1
        if _follows(events, i + 1, ('Note Velocity', 'Note On', 'Note Duration')):
            velocity = int(DEFAULT_VELOCITY_BINS[int(events[i+1].value)])
            pitch = int(events[i+2].value)
            duration = DEFAULT_DURATION_BINS[int(events[i+3].value)]
            temp_notes.append([position, velocity, pitch, duration])
        elif _follows(events, i + 1, ('Chord',)):
            temp_chords.append([position, events[i+1].value])
        elif _follows(events, i + 1, ('Tempo Class', 'Tempo Value')):
            base = tempo_base.get(events[i+1].value)
            if base is None:
                # unknown tempo class: skip it rather than reuse a stale tempo
                continue
            temp_tempos.append([position, base + int(events[i+2].value)])
    return temp_notes, temp_chords, temp_tempos


def _to_absolute_ticks(entries, ticks_per_bar):
    """Replace the bar-relative position of each entry by an absolute tick.

    ``'Bar'`` markers advance the current bar and are dropped from the output;
    the remaining fields of every entry are kept unchanged.
    """
    resolved = []
    current_bar = 0
    for entry in entries:
        if entry == 'Bar':
            current_bar += 1
            continue
        position, *payload = entry
        bar_start = current_bar * ticks_per_bar
        flags = np.linspace(
            bar_start, bar_start + ticks_per_bar,
            DEFAULT_FRACTION, endpoint=False, dtype=int)
        resolved.append([int(flags[position]), *payload])
    return resolved


def write_midi(words, word2event, output_path, prompt_path=None):
    """Decode a word sequence into notes, tempo changes and chord markers and write a MIDI file.

    Recognised event groups are 'Position' followed by
    ('Note Velocity', 'Note On', 'Note Duration'), by 'Chord', or by
    ('Tempo Class', 'Tempo Value'). A 'Bar' event (other than the very first
    event) advances to the next bar. Bars are assumed to be 4/4 with
    ``DEFAULT_RESOLUTION`` ticks per beat.

    Unlike a fixed ``len(events) - 3`` scan, groups that end exactly at the
    end of the sequence (a trailing chord or tempo) are decoded as well.

    Args:
        words: sequence of vocabulary ids.
        word2event: mapping from vocabulary id to ``'<name>_<value>'``.
        output_path: destination MIDI path.
        prompt_path: optional MIDI file to continue. When given, the decoded
            content is appended to the prompt's first instrument, shifted by
            ``DEFAULT_RESOLUTION * 4 * 4`` ticks, and only the prompt's tempo
            changes before that tick are kept.
    """
    events = word_to_event(words, word2event)
    # get downbeat and note (no time)
    temp_notes, temp_chords, temp_tempos = _collect_bar_relative(events)
    # get specific time for notes, chords and tempos
    ticks_per_bar = DEFAULT_RESOLUTION * 4  # assume 4/4
    notes = [
        miditoolkit.Note(velocity, pitch, st, st + int(duration))
        for st, velocity, pitch, duration in _to_absolute_ticks(temp_notes, ticks_per_bar)
    ]
    chords = _to_absolute_ticks(temp_chords, ticks_per_bar)
    tempos = _to_absolute_ticks(temp_tempos, ticks_per_bar)
    # write
    if prompt_path:
        midi = miditoolkit.midi.parser.MidiFile(prompt_path)
        # NOTE(review): four 4/4 bars at DEFAULT_RESOLUTION ticks per beat,
        # presumably the prompt length; not checked against the prompt file
        last_time = DEFAULT_RESOLUTION * 4 * 4
        # note shift
        for note in notes:
            note.start += last_time
            note.end += last_time
        midi.instruments[0].notes.extend(notes)
        # tempo changes: keep the prompt's leading ones, then add the new ones
        kept_tempos = []
        for tempo in midi.tempo_changes:
            if tempo.time >= last_time:
                break
            kept_tempos.append(tempo)
        for st, bpm in tempos:
            kept_tempos.append(
                miditoolkit.midi.containers.TempoChange(bpm, st + last_time))
        midi.tempo_changes = kept_tempos
        # write chord into marker
        for st, text in chords:
            midi.markers.append(
                miditoolkit.midi.containers.Marker(text=text, time=st + last_time))
    else:
        midi = miditoolkit.midi.parser.MidiFile()
        midi.ticks_per_beat = DEFAULT_RESOLUTION
        # write instrument
        inst = miditoolkit.midi.containers.Instrument(0, is_drum=False)
        inst.notes = notes
        midi.instruments.append(inst)
        # write tempo
        midi.tempo_changes = [
            miditoolkit.midi.containers.TempoChange(bpm, st) for st, bpm in tempos
        ]
        # write chord into marker
        for st, text in chords:
            midi.markers.append(
                miditoolkit.midi.containers.Marker(text=text, time=st))
    # write
    midi.dump(output_path)