SteveZerb committed (verified)
Commit 671fe47 · 1 Parent(s): da5ef74

Upload 10 files

Files changed (10)
  1. Dockerfile +43 -0
  2. MIDI.py +1735 -0
  3. README.md +8 -8
  4. app.py +533 -0
  5. app_onnx.py +625 -0
  6. midi_model.py +250 -0
  7. midi_synthesizer.py +81 -0
  8. midi_tokenizer.py +1196 -0
  9. packages.txt +1 -0
  10. requirements.txt +11 -0
Dockerfile ADDED
@@ -0,0 +1,43 @@
+ FROM nvidia/cuda:11.6.1-cudnn8-devel-ubuntu20.04
+
+ ARG DEBIAN_FRONTEND=noninteractive
+
+ ENV PYTHONUNBUFFERED=1
+
+ RUN apt-get update && apt-get install --no-install-recommends -y \
+     build-essential \
+     python3.9 \
+     python3-pip \
+     git \
+     ffmpeg \
+     fluidsynth \
+     && apt-get clean && rm -rf /var/lib/apt/lists/*
+
+ WORKDIR /code
+
+ COPY ./requirements.txt /code/requirements.txt
+
+ # Set up a new user named "user" with user ID 1000
+ RUN useradd -m -u 1000 user
+ # Switch to the "user" user
+ USER user
+ # Set home to the user's home directory
+ ENV HOME=/home/user \
+     PATH=/home/user/.local/bin:$PATH \
+     PYTHONPATH=$HOME/app \
+     PYTHONUNBUFFERED=1 \
+     GRADIO_ALLOW_FLAGGING=never \
+     GRADIO_NUM_PORTS=1 \
+     GRADIO_SERVER_NAME=0.0.0.0 \
+     GRADIO_THEME=huggingface \
+     SYSTEM=spaces
+
+ RUN pip3 install --no-cache-dir --upgrade -r /code/requirements.txt
+
+ # Set the working directory to the user's home directory
+ WORKDIR $HOME/app
+
+ # Copy the current directory contents into the container at $HOME/app setting the owner to the user
+ COPY --chown=user . $HOME/app
+
+ CMD ["python3", "app.py"]
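The GRADIO_* variables above configure the Gradio app that the closing CMD launches. As a hedged illustration of how those settings become visible to the Python process inside the container (this sketch is not taken from the uploaded app.py):

    import os
    # Values set by the ENV instruction above; Gradio itself picks up
    # GRADIO_SERVER_NAME at launch time.
    print(os.environ.get('GRADIO_SERVER_NAME'))  # expected: 0.0.0.0
    print(os.environ.get('GRADIO_THEME'))        # expected: huggingface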
MIDI.py ADDED
@@ -0,0 +1,1735 @@
+ #! /usr/bin/python3
+ # unsupported 20091104 ...
+ # ['set_sequence_number', dtime, sequence]
+ # ['raw_data', dtime, raw]
+
+ # 20150914 jimbo1qaz MIDI.py str/bytes bug report
+ # I found a MIDI file which had Shift-JIS titles. When midi.py decodes it as
+ # latin-1, it produces a string which cannot even be accessed without raising
+ # a UnicodeDecodeError. Maybe, when converting raw byte strings from MIDI,
+ # you should keep them as bytes, not improperly decode them. However, this
+ # would change the API. (ie: text = a "string" ? of 0 or more bytes). It
+ # could break compatibility, but there's not much else you can do to fix the bug
+ # https://en.wikipedia.org/wiki/Shift_JIS
+
+ r'''
+ This module offers functions: concatenate_scores(), grep(),
+ merge_scores(), mix_scores(), midi2opus(), midi2score(), opus2midi(),
+ opus2score(), play_score(), score2midi(), score2opus(), score2stats(),
+ score_type(), segment(), timeshift() and to_millisecs(),
+ where "midi" means the MIDI-file bytes (as can be put in a .mid file,
+ or piped into aplaymidi), and "opus" and "score" are list-structures
+ as inspired by Sean Burke's MIDI-Perl CPAN module.
+
+ Warning: Version 6.4 is not necessarily backward-compatible with
+ previous versions, in that text-data is now bytes, not strings.
+ This reflects the fact that many MIDI files have text data in
+ encodings other than ISO-8859-1, for example in Shift-JIS.
+
+ Download MIDI.py from http://www.pjb.com.au/midi/free/MIDI.py
+ and put it in your PYTHONPATH. MIDI.py depends on Python3.
+
+ There is also a call-compatible translation into Lua of this
+ module: see http://www.pjb.com.au/comp/lua/MIDI.html
+
+ The "opus" is a direct translation of the midi-file-events, where
+ the times are delta-times, in ticks, since the previous event.
+
+ The "score" is more human-centric; it uses absolute times, and
+ combines the separate note_on and note_off events into one "note"
+ event, with a duration:
+ ['note', start_time, duration, channel, note, velocity] # in a "score"
+
+ EVENTS (in an "opus" structure)
+ ['note_off', dtime, channel, note, velocity] # in an "opus"
+ ['note_on', dtime, channel, note, velocity] # in an "opus"
+ ['key_after_touch', dtime, channel, note, velocity]
+ ['control_change', dtime, channel, controller(0-127), value(0-127)]
+ ['patch_change', dtime, channel, patch]
+ ['channel_after_touch', dtime, channel, velocity]
+ ['pitch_wheel_change', dtime, channel, pitch_wheel]
+ ['text_event', dtime, text]
+ ['copyright_text_event', dtime, text]
+ ['track_name', dtime, text]
+ ['instrument_name', dtime, text]
+ ['lyric', dtime, text]
+ ['marker', dtime, text]
+ ['cue_point', dtime, text]
+ ['text_event_08', dtime, text]
+ ['text_event_09', dtime, text]
+ ['text_event_0a', dtime, text]
+ ['text_event_0b', dtime, text]
+ ['text_event_0c', dtime, text]
+ ['text_event_0d', dtime, text]
+ ['text_event_0e', dtime, text]
+ ['text_event_0f', dtime, text]
+ ['end_track', dtime]
+ ['set_tempo', dtime, tempo]
+ ['smpte_offset', dtime, hr, mn, se, fr, ff]
+ ['time_signature', dtime, nn, dd, cc, bb]
+ ['key_signature', dtime, sf, mi]
+ ['sequencer_specific', dtime, raw]
+ ['raw_meta_event', dtime, command(0-255), raw]
+ ['sysex_f0', dtime, raw]
+ ['sysex_f7', dtime, raw]
+ ['song_position', dtime, song_pos]
+ ['song_select', dtime, song_number]
+ ['tune_request', dtime]
+
+ DATA TYPES
+ channel = a value 0 to 15
+ controller = 0 to 127 (see http://www.pjb.com.au/muscript/gm.html#cc )
+ dtime = time measured in "ticks", 0 to 268435455
+ velocity = a value 0 (soft) to 127 (loud)
+ note = a value 0 to 127 (middle-C is 60)
+ patch = 0 to 127 (see http://www.pjb.com.au/muscript/gm.html )
+ pitch_wheel = a value -8192 to 8191 (0x1FFF)
+ raw = bytes, of length 0 or more (for sysex events see below)
+ sequence_number = a value 0 to 65,535 (0xFFFF)
+ song_pos = a value 0 to 16,383 (0x3FFF)
+ song_number = a value 0 to 127
+ tempo = microseconds per crochet (quarter-note), 0 to 16777215
+ text = bytes, of length 0 or more
+ ticks = the number of ticks per crochet (quarter-note)
+
+ In sysex_f0 events, the raw data must not start with a \xF0 byte,
+ since this gets added automatically;
+ but it must end with an explicit \xF7 byte!
+ In the very unlikely case that you ever need to split sysex data
+ into one sysex_f0 followed by one or more sysex_f7s, then only the
+ last of those sysex_f7 events must end with the explicit \xF7 byte
+ (again, the raw data of individual sysex_f7 events must not start
+ with any \xF7 byte, since this gets added automatically).
+
+ Since version 6.4, text data is in bytes, not in an ISO-8859-1 string.
+
+
+ GOING THROUGH A SCORE WITHIN A PYTHON PROGRAM
+ channels = {2,3,5,8,13}
+ itrack = 1   # skip 1st element which is ticks
+ while itrack < len(score):
+     for event in score[itrack]:
+         if event[0] == 'note':   # for example,
+             pass   # do something to all notes
+         # or, to work on events in only particular channels...
+         channel_index = MIDI.Event2channelindex.get(event[0], False)
+         if channel_index and (event[channel_index] in channels):
+             pass   # do something to channels 2,3,5,8 and 13
+     itrack += 1
+
+ '''
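# Illustrative round-trip sketch (not part of the uploaded file): parse a
# .mid file into a "score" and list its note events; 'example.mid' is an
# arbitrary file name.
#   with open('example.mid', 'rb') as f:
#       score = midi2score(f.read())
#   print('ticks per quarter-note:', score[0])
#   for track in score[1:]:
#       for event in track:
#           if event[0] == 'note':
#               print(event)  # ['note', start, duration, channel, pitch, velocity]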
+
+ import sys, struct, copy
+ # sys.stdout = os.fdopen(sys.stdout.fileno(), 'wb')
+ Version = '6.7'
+ VersionDate = '20201120'
+ # 20201120 6.7 call to bytest() removed, and protect _unshift_ber_int
+ # 20160702 6.6 to_millisecs() now handles set_tempo across multiple Tracks
+ # 20150921 6.5 segment restores controllers as well as patch and tempo
+ # 20150914 6.4 text data is bytes or bytearray, not ISO-8859-1 strings
+ # 20150628 6.3 absent any set_tempo, default is 120bpm (see MIDI file spec 1.1)
+ # 20150101 6.2 all text events can be 8-bit; let user get the right encoding
+ # 20141231 6.1 fix _some_text_event; sequencer_specific data can be 8-bit
+ # 20141230 6.0 synth_specific data can be 8-bit
+ # 20120504 5.9 add the contents of mid_opus_tracks()
+ # 20120208 5.8 fix num_notes_by_channel() ; should be a dict
+ # 20120129 5.7 _encode handles empty tracks; score2stats num_notes_by_channel
+ # 20111111 5.6 fix patch 45 and 46 in Number2patch, should be Harp
+ # 20110129 5.5 add mix_opus_tracks() and event2alsaseq()
+ # 20110126 5.4 "previous message repeated N times" to save space on stderr
+ # 20110125 5.2 opus2score terminates unended notes at the end of the track
+ # 20110124 5.1 the warnings in midi2opus display track_num
+ # 21110122 5.0 if garbage, midi2opus returns the opus so far
+ # 21110119 4.9 non-ascii chars stripped out of the text_events
+ # 21110110 4.8 note_on with velocity=0 treated as a note-off
+ # 21110108 4.6 unknown F-series event correctly eats just one byte
+ # 21011010 4.2 segment() uses start_time, end_time named params
+ # 21011005 4.1 timeshift() must not pad the set_tempo command
+ # 21011003 4.0 pitch2note_event must be chapitch2note_event
+ # 21010918 3.9 set_sequence_number supported, FWIW
+ # 20100913 3.7 many small bugfixes; passes all tests
+ # 20100910 3.6 concatenate_scores enforce ticks=1000, just like merge_scores
+ # 20100908 3.5 minor bugs fixed in score2stats
+ # 20091104 3.4 tune_request now supported
+ # 20091104 3.3 fixed bug in decoding song_position and song_select
+ # 20091104 3.2 unsupported: set_sequence_number tune_request raw_data
+ # 20091101 3.1 document how to traverse a score within Python
+ # 20091021 3.0 fixed bug in score2stats detecting GM-mode = 0
+ # 20091020 2.9 score2stats reports GM-mode and bank msb,lsb events
+ # 20091019 2.8 in merge_scores, channel 9 must remain channel 9 (in GM)
+ # 20091018 2.7 handles empty tracks gracefully
+ # 20091015 2.6 grep() selects channels
+ # 20091010 2.5 merge_scores reassigns channels to avoid conflicts
+ # 20091010 2.4 fixed bug in to_millisecs which now only does opusses
+ # 20091010 2.3 score2stats returns channels & patch_changes, by_track & total
+ # 20091010 2.2 score2stats() returns also pitches and percussion dicts
+ # 20091010 2.1 bugs: >= not > in segment, to notice patch_change at time 0
+ # 20091010 2.0 bugs: spurious pop(0) ( in _decode sysex
+ # 20091008 1.9 bugs: ISO decoding in sysex; str( not int( in note-off warning
+ # 20091008 1.8 add concatenate_scores()
+ # 20091006 1.7 score2stats() measures nticks and ticks_per_quarter
+ # 20091004 1.6 first mix_scores() and merge_scores()
+ # 20090424 1.5 timeshift() bugfix: earliest only sees events after from_time
+ # 20090330 1.4 timeshift() has also a from_time argument
+ # 20090322 1.3 timeshift() has also a start_time argument
+ # 20090319 1.2 add segment() and timeshift()
+ # 20090301 1.1 add to_millisecs()
+
+ _previous_warning = ''  # 5.4
+ _previous_times = 0  # 5.4
+ _no_warning = True
+ #------------------------------- Encoding stuff --------------------------
+
+ def opus2midi(opus=[]):
+     r'''The argument is a list: the first item in the list is the "ticks"
+     parameter, the others are the tracks. Each track is a list
+     of midi-events, and each event is itself a list; see above.
+     opus2midi() returns a bytestring of the MIDI, which can then be
+     written either to a file opened in binary mode (mode='wb'),
+     or to stdout by means of: sys.stdout.buffer.write()
+
+     my_opus = [
+         96,
+         [   # track 0:
+             ['patch_change', 0, 1, 8],   # and these are the events...
+             ['note_on', 5, 1, 25, 96],
+             ['note_off', 96, 1, 25, 0],
+             ['note_on', 0, 1, 29, 96],
+             ['note_off', 96, 1, 29, 0],
+         ],   # end of track 0
+     ]
+     my_midi = opus2midi(my_opus)
+     sys.stdout.buffer.write(my_midi)
+     '''
+     if len(opus) < 2:
+         opus=[1000, [],]
+     tracks = copy.deepcopy(opus)
+     ticks = int(tracks.pop(0))
+     ntracks = len(tracks)
+     if ntracks == 1:
+         format = 0
+     else:
+         format = 1
+
+     my_midi = b"MThd\x00\x00\x00\x06"+struct.pack('>HHH',format,ntracks,ticks)
+     for track in tracks:
+         events = _encode(track)
+         my_midi += b'MTrk' + struct.pack('>I',len(events)) + events
+     _clean_up_warnings()
+     return my_midi
+
+
+ def score2opus(score=None):
+     r'''
+     The argument is a list: the first item in the list is the "ticks"
+     parameter, the others are the tracks. Each track is a list
+     of score-events, and each event is itself a list. A score-event
+     is similar to an opus-event (see above), except that in a score:
+     1) the times are expressed as an absolute number of ticks
+        from the track's start time
+     2) the pairs of 'note_on' and 'note_off' events in an "opus"
+        are abstracted into a single 'note' event in a "score":
+        ['note', start_time, duration, channel, pitch, velocity]
+     score2opus() returns a list specifying the equivalent "opus".
+
+     my_score = [
+         96,
+         [   # track 0:
+             ['patch_change', 0, 1, 8],
+             ['note', 5, 96, 1, 25, 96],
+             ['note', 101, 96, 1, 29, 96]
+         ],   # end of track 0
+     ]
+     my_opus = score2opus(my_score)
+     '''
+     if len(score) < 2:
+         score=[1000, [],]
+     tracks = copy.deepcopy(score)
+     ticks = int(tracks.pop(0))
+     opus_tracks = []
+     for scoretrack in tracks:
+         time2events = dict([])
+         for scoreevent in scoretrack:
+             if scoreevent[0] == 'note':
+                 note_on_event = ['note_on',scoreevent[1],
+                     scoreevent[3],scoreevent[4],scoreevent[5]]
+                 note_off_event = ['note_off',scoreevent[1]+scoreevent[2],
+                     scoreevent[3],scoreevent[4],scoreevent[5]]
+                 if time2events.get(note_on_event[1]):
+                     time2events[note_on_event[1]].append(note_on_event)
+                 else:
+                     time2events[note_on_event[1]] = [note_on_event,]
+                 if time2events.get(note_off_event[1]):
+                     time2events[note_off_event[1]].append(note_off_event)
+                 else:
+                     time2events[note_off_event[1]] = [note_off_event,]
+                 continue
+             if time2events.get(scoreevent[1]):
+                 time2events[scoreevent[1]].append(scoreevent)
+             else:
+                 time2events[scoreevent[1]] = [scoreevent,]
+
+         sorted_times = []  # list of keys
+         for k in time2events.keys():
+             sorted_times.append(k)
+         sorted_times.sort()
+
+         sorted_events = []  # once-flattened list of values sorted by key
+         for time in sorted_times:
+             sorted_events.extend(time2events[time])
+
+         abs_time = 0
+         for event in sorted_events:  # convert abs times => delta times
+             delta_time = event[1] - abs_time
+             abs_time = event[1]
+             event[1] = delta_time
+         opus_tracks.append(sorted_events)
+     opus_tracks.insert(0,ticks)
+     _clean_up_warnings()
+     return opus_tracks
+
+ def score2midi(score=None):
+     r'''
+     Translates a "score" into MIDI, using score2opus() then opus2midi()
+     '''
+     return opus2midi(score2opus(score))
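# Illustrative sketch (not part of the uploaded file): build a tiny
# one-track score and write it to disk with score2midi(); the output
# file name is arbitrary.
#   my_score = [96, [['patch_change', 0, 1, 8], ['note', 5, 96, 1, 25, 96]]]
#   with open('example-out.mid', 'wb') as f:
#       f.write(score2midi(my_score))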
+
+ #--------------------------- Decoding stuff ------------------------
+
+ def midi2opus(midi=b''):
+     r'''Translates MIDI into a "opus". For a description of the
+     "opus" format, see opus2midi()
+     '''
+     my_midi=bytearray(midi)
+     if len(my_midi) < 4:
+         _clean_up_warnings()
+         return [1000,[],]
+     id = bytes(my_midi[0:4])
+     if id != b'MThd':
+         _warn("midi2opus: midi starts with "+str(id)+" instead of 'MThd'")
+         _clean_up_warnings()
+         return [1000,[],]
+     [length, format, tracks_expected, ticks] = struct.unpack(
+         '>IHHH', bytes(my_midi[4:14]))
+     if length != 6:
+         _warn("midi2opus: midi header length was "+str(length)+" instead of 6")
+         _clean_up_warnings()
+         return [1000,[],]
+     my_opus = [ticks,]
+     my_midi = my_midi[14:]
+     track_num = 1  # 5.1
+     while len(my_midi) >= 8:
+         track_type = bytes(my_midi[0:4])
+         if track_type != b'MTrk':
+             _warn('midi2opus: Warning: track #'+str(track_num)+' type is '+str(track_type)+" instead of b'MTrk'")
+         [track_length] = struct.unpack('>I', my_midi[4:8])
+         my_midi = my_midi[8:]
+         if track_length > len(my_midi):
+             _warn('midi2opus: track #'+str(track_num)+' length '+str(track_length)+' is too large')
+             _clean_up_warnings()
+             return my_opus  # 5.0
+         my_midi_track = my_midi[0:track_length]
+         my_track = _decode(my_midi_track)
+         my_opus.append(my_track)
+         my_midi = my_midi[track_length:]
+         track_num += 1  # 5.1
+     _clean_up_warnings()
+     return my_opus
+
+ def opus2score(opus=[]):
+     r'''For a description of the "opus" and "score" formats,
+     see opus2midi() and score2opus().
+     '''
+     if len(opus) < 2:
+         _clean_up_warnings()
+         return [1000,[],]
+     tracks = copy.deepcopy(opus)  # couple of slices probably quicker...
+     ticks = int(tracks.pop(0))
+     score = [ticks,]
+     for opus_track in tracks:
+         ticks_so_far = 0
+         score_track = []
+         chapitch2note_on_events = dict([])  # 4.0
+         for opus_event in opus_track:
+             ticks_so_far += opus_event[1]
+             if opus_event[0] == 'note_off' or (opus_event[0] == 'note_on' and opus_event[4] == 0):  # 4.8
+                 cha = opus_event[2]
+                 pitch = opus_event[3]
+                 key = cha*128 + pitch
+                 if chapitch2note_on_events.get(key):
+                     new_event = chapitch2note_on_events[key].pop(0)
+                     new_event[2] = ticks_so_far - new_event[1]
+                     score_track.append(new_event)
+                 elif pitch > 127:
+                     pass  #_warn('opus2score: note_off with no note_on, bad pitch='+str(pitch))
+                 else:
+                     pass  #_warn('opus2score: note_off with no note_on cha='+str(cha)+' pitch='+str(pitch))
+             elif opus_event[0] == 'note_on':
+                 cha = opus_event[2]
+                 pitch = opus_event[3]
+                 key = cha*128 + pitch
+                 new_event = ['note',ticks_so_far,0,cha,pitch, opus_event[4]]
+                 if chapitch2note_on_events.get(key):
+                     chapitch2note_on_events[key].append(new_event)
+                 else:
+                     chapitch2note_on_events[key] = [new_event,]
+             else:
+                 opus_event[1] = ticks_so_far
+                 score_track.append(opus_event)
+         # check for unterminated notes (Oisín) -- 5.2
+         for chapitch in chapitch2note_on_events:
+             note_on_events = chapitch2note_on_events[chapitch]
+             for new_e in note_on_events:
+                 new_e[2] = ticks_so_far - new_e[1]
+                 score_track.append(new_e)
+                 pass  #_warn("opus2score: note_on with no note_off cha="+str(new_e[3])+' pitch='+str(new_e[4])+'; adding note_off at end')
+         score.append(score_track)
+     _clean_up_warnings()
+     return score
+
+ def midi2score(midi=b''):
+     r'''
+     Translates MIDI into a "score", using midi2opus() then opus2score()
+     '''
+     return opus2score(midi2opus(midi))
+
+ def midi2ms_score(midi=b''):
+     r'''
+     Translates MIDI into a "score" with one beat per second and one
+     tick per millisecond, using midi2opus() then to_millisecs()
+     then opus2score()
+     '''
+     return opus2score(to_millisecs(midi2opus(midi)))
+
+ #------------------------ Other Transformations ---------------------
+
+ def to_millisecs(old_opus=None):
+     r'''Recalibrates all the times in an "opus" to use one beat
+     per second and one tick per millisecond. This makes it
+     hard to retrieve any information about beats or barlines,
+     but it does make it easy to mix different scores together.
+     '''
+     if old_opus == None:
+         return [1000,[],]
+     try:
+         old_tpq = int(old_opus[0])
+     except IndexError:  # 5.0
+         _warn('to_millisecs: the opus '+str(type(old_opus))+' has no elements')
+         return [1000,[],]
+     new_opus = [1000,]
+     # 6.7 first go through building a table of set_tempos by absolute-tick
+     ticks2tempo = {}
+     itrack = 1
+     while itrack < len(old_opus):
+         ticks_so_far = 0
+         for old_event in old_opus[itrack]:
+             if old_event[0] == 'note':
+                 raise TypeError('to_millisecs needs an opus, not a score')
+             ticks_so_far += old_event[1]
+             if old_event[0] == 'set_tempo':
+                 ticks2tempo[ticks_so_far] = old_event[2]
+         itrack += 1
+     # then get the sorted-array of their keys
+     tempo_ticks = []  # list of keys
+     for k in ticks2tempo.keys():
+         tempo_ticks.append(k)
+     tempo_ticks.sort()
+     # then go through converting to millisec, testing if the next
+     # set_tempo lies before the next track-event, and using it if so.
+     itrack = 1
+     while itrack < len(old_opus):
+         ms_per_old_tick = 500.0 / old_tpq  # float: will round later 6.3
+         i_tempo_ticks = 0
+         ticks_so_far = 0
+         ms_so_far = 0.0
+         previous_ms_so_far = 0.0
+         new_track = [['set_tempo',0,1000000],]  # new "crochet" is 1 sec
+         for old_event in old_opus[itrack]:
+             # detect if ticks2tempo has something before this event
+             # 20160702 if ticks2tempo is at the same time, leave it
+             event_delta_ticks = old_event[1]
+             if (i_tempo_ticks < len(tempo_ticks) and
+                     tempo_ticks[i_tempo_ticks] < (ticks_so_far + old_event[1])):
+                 delta_ticks = tempo_ticks[i_tempo_ticks] - ticks_so_far
+                 ms_so_far += (ms_per_old_tick * delta_ticks)
+                 ticks_so_far = tempo_ticks[i_tempo_ticks]
+                 ms_per_old_tick = ticks2tempo[ticks_so_far] / (1000.0*old_tpq)
+                 i_tempo_ticks += 1
+                 event_delta_ticks -= delta_ticks
+             new_event = copy.deepcopy(old_event)  # now handle the new event
+             ms_so_far += (ms_per_old_tick * old_event[1])
+             new_event[1] = round(ms_so_far - previous_ms_so_far)
+             if old_event[0] != 'set_tempo':
+                 previous_ms_so_far = ms_so_far
+                 new_track.append(new_event)
+             ticks_so_far += event_delta_ticks
+         new_opus.append(new_track)
+         itrack += 1
+     _clean_up_warnings()
+     return new_opus
+
+ def event2alsaseq(event=None):  # 5.5
+     r'''Converts an event into the format needed by the alsaseq module,
+     http://pp.com.mx/python/alsaseq
+     The type of track (opus or score) is autodetected.
+     '''
+     pass
+
+ def grep(score=None, channels=None):
+     r'''Returns a "score" containing only the channels specified
+     '''
+     if score == None:
+         return [1000,[],]
+     ticks = score[0]
+     new_score = [ticks,]
+     if channels == None:
+         return new_score
+     channels = set(channels)
+     global Event2channelindex
+     itrack = 1
+     while itrack < len(score):
+         new_score.append([])
+         for event in score[itrack]:
+             channel_index = Event2channelindex.get(event[0], False)
+             if channel_index:
+                 if event[channel_index] in channels:
+                     new_score[itrack].append(event)
+             else:
+                 new_score[itrack].append(event)
+         itrack += 1
+     return new_score
+
+ def play_score(score=None):
+     r'''Converts the "score" to midi, and feeds it into 'aplaymidi -'
+     '''
+     if score == None:
+         return
+     import subprocess
+     pipe = subprocess.Popen(['aplaymidi','-'], stdin=subprocess.PIPE)
+     if score_type(score) == 'opus':
+         pipe.stdin.write(opus2midi(score))
+     else:
+         pipe.stdin.write(score2midi(score))
+     pipe.stdin.close()
+
+ def timeshift(score=None, shift=None, start_time=None, from_time=0, tracks={0,1,2,3,4,5,6,7,8,10,12,13,14,15}):
+     r'''Returns a "score" shifted in time by "shift" ticks, or shifted
+     so that the first event starts at "start_time" ticks.
+
+     If "from_time" is specified, only those events in the score
+     that begin after it are shifted. If "start_time" is less than
+     "from_time" (or "shift" is negative), then the intermediate
+     notes are deleted, though patch-change events are preserved.
+
+     If "tracks" are specified, then only those tracks get shifted.
+     "tracks" can be a list, tuple or set; it gets converted to set
+     internally.
+
+     It is deprecated to specify both "shift" and "start_time".
+     If this does happen, timeshift() will print a warning to
+     stderr and ignore the "shift" argument.
+
+     If "shift" is negative and sufficiently large that it would
+     leave some event with a negative tick-value, then the score
+     is shifted so that the first event occurs at time 0. This
+     also occurs if "start_time" is negative, and is also the
+     default if neither "shift" nor "start_time" are specified.
+     '''
+     #_warn('tracks='+str(tracks))
+     if score == None or len(score) < 2:
+         return [1000, [],]
+     new_score = [score[0],]
+     my_type = score_type(score)
+     if my_type == '':
+         return new_score
+     if my_type == 'opus':
+         _warn("timeshift: opus format is not supported\n")
+         # _clean_up_scores()  6.2; doesn't exist! what was it supposed to do?
+         return new_score
+     if not (shift == None) and not (start_time == None):
+         _warn("timeshift: shift and start_time specified: ignoring shift\n")
+         shift = None
+     if shift == None:
+         if (start_time == None) or (start_time < 0):
+             start_time = 0
+         # shift = start_time - from_time
+
+     i = 1  # ignore first element (ticks)
+     tracks = set(tracks)  # defend against tuples and lists
+     earliest = 1000000000
+     if not (start_time == None) or shift < 0:  # first find the earliest event
+         while i < len(score):
+             if len(tracks) and not ((i-1) in tracks):
+                 i += 1
+                 continue
+             for event in score[i]:
+                 if event[1] < from_time:
+                     continue  # just inspect the to_be_shifted events
+                 if event[1] < earliest:
+                     earliest = event[1]
+             i += 1
+     if earliest > 999999999:
+         earliest = 0
+     if shift == None:
+         shift = start_time - earliest
+     elif (earliest + shift) < 0:
+         start_time = 0
+         shift = 0 - earliest
+
+     i = 1  # ignore first element (ticks)
+     while i < len(score):
+         if len(tracks) == 0 or not ((i-1) in tracks):  # 3.8
+             new_score.append(score[i])
+             i += 1
+             continue
+         new_track = []
+         for event in score[i]:
+             new_event = list(event)
+             #if new_event[1] == 0 and shift > 0 and new_event[0] != 'note':
+             #    pass
+             #elif new_event[1] >= from_time:
+             if new_event[1] >= from_time:
+                 # 4.1 must not rightshift set_tempo
+                 if new_event[0] != 'set_tempo' or shift<0:
+                     new_event[1] += shift
+             elif (shift < 0) and (new_event[1] >= (from_time+shift)):
+                 continue
+             new_track.append(new_event)
+         if len(new_track) > 0:
+             new_score.append(new_track)
+         i += 1
+     _clean_up_warnings()
+     return new_score
+
+ def segment(score=None, start_time=None, end_time=None, start=0, end=100000000,
+             tracks={0,1,2,3,4,5,6,7,8,10,11,12,13,14,15}):
+     r'''Returns a "score" which is a segment of the one supplied
+     as the argument, beginning at "start_time" ticks and ending
+     at "end_time" ticks (or at the end if "end_time" is not supplied).
+     If the set "tracks" is specified, only those tracks will
+     be returned.
+     '''
+     if score == None or len(score) < 2:
+         return [1000, [],]
+     if start_time == None:  # as of 4.2 start_time is recommended
+         start_time = start  # start is legacy usage
+     if end_time == None:    # likewise
+         end_time = end
+     new_score = [score[0],]
+     my_type = score_type(score)
+     if my_type == '':
+         return new_score
+     if my_type == 'opus':
+         # more difficult (disconnecting note_on's from their note_off's)...
+         _warn("segment: opus format is not supported\n")
+         _clean_up_warnings()
+         return new_score
+     i = 1  # ignore first element (ticks); we count in ticks anyway
+     tracks = set(tracks)  # defend against tuples and lists
+     while i < len(score):
+         if len(tracks) and not ((i-1) in tracks):
+             i += 1
+             continue
+         new_track = []
+         channel2cc_num = {}   # most recent controller change before start
+         channel2cc_val = {}
+         channel2cc_time = {}
+         channel2patch_num = {}   # keep most recent patch change before start
+         channel2patch_time = {}
+         set_tempo_num = 500000   # most recent tempo change before start 6.3
+         set_tempo_time = 0
+         earliest_note_time = end_time
+         for event in score[i]:
+             if event[0] == 'control_change':  # 6.5
+                 cc_time = channel2cc_time.get(event[2]) or 0
+                 if (event[1] <= start_time) and (event[1] >= cc_time):
+                     channel2cc_num[event[2]] = event[3]
+                     channel2cc_val[event[2]] = event[4]
+                     channel2cc_time[event[2]] = event[1]
+             elif event[0] == 'patch_change':
+                 patch_time = channel2patch_time.get(event[2]) or 0
+                 if (event[1]<=start_time) and (event[1] >= patch_time):  # 2.0
+                     channel2patch_num[event[2]] = event[3]
+                     channel2patch_time[event[2]] = event[1]
+             elif event[0] == 'set_tempo':
+                 if (event[1]<=start_time) and (event[1]>=set_tempo_time):  # 6.4
+                     set_tempo_num = event[2]
+                     set_tempo_time = event[1]
+             if (event[1] >= start_time) and (event[1] <= end_time):
+                 new_track.append(event)
+                 if (event[0] == 'note') and (event[1] < earliest_note_time):
+                     earliest_note_time = event[1]
+         if len(new_track) > 0:
+             new_track.append(['set_tempo', start_time, set_tempo_num])
+             for c in channel2patch_num:
+                 new_track.append(['patch_change',start_time,c,channel2patch_num[c]],)
+             for c in channel2cc_num:  # 6.5
+                 new_track.append(['control_change',start_time,c,channel2cc_num[c],channel2cc_val[c]])
+             new_score.append(new_track)
+         i += 1
+     _clean_up_warnings()
+     return new_score
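# Illustrative combination (not part of the uploaded file): excerpt the
# passage between ticks 1000 and 2000 of a score, then re-anchor it so it
# starts at tick 0; the tick values here are arbitrary.
#   excerpt = segment(score, start_time=1000, end_time=2000)
#   excerpt = timeshift(excerpt, start_time=0)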
+
+ def score_type(opus_or_score=None):
+     r'''Returns a string, either 'opus' or 'score' or ''
+     '''
+     if opus_or_score == None or str(type(opus_or_score)).find('list')<0 or len(opus_or_score) < 2:
+         return ''
+     i = 1  # ignore first element
+     while i < len(opus_or_score):
+         for event in opus_or_score[i]:
+             if event[0] == 'note':
+                 return 'score'
+             elif event[0] == 'note_on':
+                 return 'opus'
+         i += 1
+     return ''
+
+ def concatenate_scores(scores):
+     r'''Concatenates a list of scores into one score.
+     If the scores differ in their "ticks" parameter,
+     they will all get converted to millisecond-tick format.
+     '''
+     # the deepcopys are needed if the input_score's are refs to the same obj
+     # e.g. if invoked by midisox's repeat()
+     input_scores = _consistentise_ticks(scores)  # 3.7
+     output_score = copy.deepcopy(input_scores[0])
+     for input_score in input_scores[1:]:
+         output_stats = score2stats(output_score)
+         delta_ticks = output_stats['nticks']
+         itrack = 1
+         while itrack < len(input_score):
+             if itrack >= len(output_score):  # new output track if doesn't exist
+                 output_score.append([])
+             for event in input_score[itrack]:
+                 output_score[itrack].append(copy.deepcopy(event))
+                 output_score[itrack][-1][1] += delta_ticks
+             itrack += 1
+     return output_score
+
+ def merge_scores(scores):
+     r'''Merges a list of scores into one score. A merged score comprises
+     all of the tracks from all of the input scores; un-merging is possible
+     by selecting just some of the tracks. If the scores differ in their
+     "ticks" parameter, they will all get converted to millisecond-tick
+     format. merge_scores attempts to resolve channel-conflicts,
+     but there are of course only 15 available channels...
+     '''
+     input_scores = _consistentise_ticks(scores)  # 3.6
+     output_score = [1000]
+     channels_so_far = set()
+     all_channels = {0,1,2,3,4,5,6,7,8,10,11,12,13,14,15}
+     global Event2channelindex
+     for input_score in input_scores:
+         new_channels = set(score2stats(input_score).get('channels_total', []))
+         new_channels.discard(9)  # 2.8 cha9 must remain cha9 (in GM)
+         for channel in channels_so_far & new_channels:
+             # consistently choose lowest available, to ease testing
+             free_channels = list(all_channels - (channels_so_far|new_channels))
+             if len(free_channels) > 0:
+                 free_channels.sort()
+                 free_channel = free_channels[0]
+             else:
+                 free_channel = None
+                 break
+             itrack = 1
+             while itrack < len(input_score):
+                 for input_event in input_score[itrack]:
+                     channel_index=Event2channelindex.get(input_event[0],False)
+                     if channel_index and input_event[channel_index]==channel:
+                         input_event[channel_index] = free_channel
+                 itrack += 1
+             channels_so_far.add(free_channel)
+
+         channels_so_far |= new_channels
+         output_score.extend(input_score[1:])
+     return output_score
+
+ def _ticks(event):
+     return event[1]
+ def mix_opus_tracks(input_tracks):  # 5.5
+     r'''Mixes an array of tracks into one track. A mixed track
+     cannot be un-mixed. It is assumed that the tracks share the same
+     ticks parameter and the same tempo.
+     Mixing score-tracks is trivial (just insert all events into one array).
+     Mixing opus-tracks is only slightly harder, but it's common enough
+     that a dedicated function is useful.
+     '''
+     output_score = [1000, []]
+     for input_track in input_tracks:  # 5.8
+         input_score = opus2score([1000, input_track])
+         for event in input_score[1]:
+             output_score[1].append(event)
+     output_score[1].sort(key=_ticks)
+     output_opus = score2opus(output_score)
+     return output_opus[1]
+
+ def mix_scores(scores):
+     r'''Mixes a list of scores into one one-track score.
+     A mixed score cannot be un-mixed. Hopefully the scores
+     have no undesirable channel-conflicts between them.
+     If the scores differ in their "ticks" parameter,
+     they will all get converted to millisecond-tick format.
+     '''
+     input_scores = _consistentise_ticks(scores)  # 3.6
+     output_score = [1000, []]
+     for input_score in input_scores:
+         for input_track in input_score[1:]:
+             output_score[1].extend(input_track)
+     return output_score
+
+ def score2stats(opus_or_score=None):
+     r'''Returns a dict of some basic stats about the score, like
+     bank_select (list of tuples (msb,lsb)),
+     channels_by_track (list of lists), channels_total (set),
+     general_midi_mode (list),
+     ntracks, nticks, patch_changes_by_track (list of dicts),
+     num_notes_by_channel (list of numbers),
+     patch_changes_total (set),
+     percussion (dict histogram of channel 9 events),
+     pitches (dict histogram of pitches on channels other than 9),
+     pitch_range_by_track (list, by track, of two-member-tuples),
+     pitch_range_sum (sum over tracks of the pitch_ranges),
+     '''
+     bank_select_msb = -1
+     bank_select_lsb = -1
+     bank_select = []
+     channels_by_track = []
+     channels_total = set([])
+     general_midi_mode = []
+     num_notes_by_channel = dict([])
+     patches_used_by_track = []
+     patches_used_total = set([])
+     patch_changes_by_track = []
+     patch_changes_total = set([])
+     percussion = dict([])  # histogram of channel 9 "pitches"
+     pitches = dict([])  # histogram of pitch-occurrences channels 0-8,10-15
+     pitch_range_sum = 0  # sum over tracks of the pitch-ranges
+     pitch_range_by_track = []
+     is_a_score = True
+     if opus_or_score == None:
+         return {'bank_select':[], 'channels_by_track':[], 'channels_total':[],
+             'general_midi_mode':[], 'ntracks':0, 'nticks':0,
+             'num_notes_by_channel':dict([]),
+             'patch_changes_by_track':[], 'patch_changes_total':[],
+             'percussion':{}, 'pitches':{}, 'pitch_range_by_track':[],
+             'ticks_per_quarter':0, 'pitch_range_sum':0}
+     ticks_per_quarter = opus_or_score[0]
+     i = 1  # ignore first element, which is ticks
+     nticks = 0
+     while i < len(opus_or_score):
+         highest_pitch = 0
+         lowest_pitch = 128
+         channels_this_track = set([])
+         patch_changes_this_track = dict({})
+         for event in opus_or_score[i]:
+             if event[0] == 'note':
+                 num_notes_by_channel[event[3]] = num_notes_by_channel.get(event[3],0) + 1
+                 if event[3] == 9:
+                     percussion[event[4]] = percussion.get(event[4],0) + 1
+                 else:
+                     pitches[event[4]] = pitches.get(event[4],0) + 1
+                     if event[4] > highest_pitch:
+                         highest_pitch = event[4]
+                     if event[4] < lowest_pitch:
+                         lowest_pitch = event[4]
+                 channels_this_track.add(event[3])
+                 channels_total.add(event[3])
+                 finish_time = event[1] + event[2]
+                 if finish_time > nticks:
+                     nticks = finish_time
+             elif event[0] == 'note_off' or (event[0] == 'note_on' and event[4] == 0):  # 4.8
+                 finish_time = event[1]
+                 if finish_time > nticks:
+                     nticks = finish_time
+             elif event[0] == 'note_on':
+                 is_a_score = False
+                 num_notes_by_channel[event[2]] = num_notes_by_channel.get(event[2],0) + 1
+                 if event[2] == 9:
+                     percussion[event[3]] = percussion.get(event[3],0) + 1
+                 else:
+                     pitches[event[3]] = pitches.get(event[3],0) + 1
+                     if event[3] > highest_pitch:
+                         highest_pitch = event[3]
+                     if event[3] < lowest_pitch:
+                         lowest_pitch = event[3]
+                 channels_this_track.add(event[2])
+                 channels_total.add(event[2])
+             elif event[0] == 'patch_change':
+                 patch_changes_this_track[event[2]] = event[3]
+                 patch_changes_total.add(event[3])
+             elif event[0] == 'control_change':
+                 if event[3] == 0:  # bank select MSB
+                     bank_select_msb = event[4]
+                 elif event[3] == 32:  # bank select LSB
+                     bank_select_lsb = event[4]
+                 if bank_select_msb >= 0 and bank_select_lsb >= 0:
+                     bank_select.append((bank_select_msb,bank_select_lsb))
+                     bank_select_msb = -1
+                     bank_select_lsb = -1
+             elif event[0] == 'sysex_f0':
+                 if _sysex2midimode.get(event[2], -1) >= 0:
+                     general_midi_mode.append(_sysex2midimode.get(event[2]))
+             if is_a_score:
+                 if event[1] > nticks:
+                     nticks = event[1]
+             else:
+                 nticks += event[1]
+         if lowest_pitch == 128:
+             lowest_pitch = 0
+         channels_by_track.append(channels_this_track)
+         patch_changes_by_track.append(patch_changes_this_track)
+         pitch_range_by_track.append((lowest_pitch,highest_pitch))
+         pitch_range_sum += (highest_pitch-lowest_pitch)
+         i += 1
+
+     return {'bank_select':bank_select,
+         'channels_by_track':channels_by_track,
+         'channels_total':channels_total,
+         'general_midi_mode':general_midi_mode,
+         'ntracks':len(opus_or_score)-1,
+         'nticks':nticks,
+         'num_notes_by_channel':num_notes_by_channel,
+         'patch_changes_by_track':patch_changes_by_track,
+         'patch_changes_total':patch_changes_total,
+         'percussion':percussion,
+         'pitches':pitches,
+         'pitch_range_by_track':pitch_range_by_track,
+         'pitch_range_sum':pitch_range_sum,
+         'ticks_per_quarter':ticks_per_quarter}
900
+
901
+ #----------------------------- Event stuff --------------------------
902
+
903
+ _sysex2midimode = {
904
+ "\x7E\x7F\x09\x01\xF7": 1,
905
+ "\x7E\x7F\x09\x02\xF7": 0,
906
+ "\x7E\x7F\x09\x03\xF7": 2,
907
+ }
908
+
909
+ # Some public-access tuples:
910
+ MIDI_events = tuple('''note_off note_on key_after_touch
911
+ control_change patch_change channel_after_touch
912
+ pitch_wheel_change'''.split())
913
+
914
+ Text_events = tuple('''text_event copyright_text_event
915
+ track_name instrument_name lyric marker cue_point text_event_08
916
+ text_event_09 text_event_0a text_event_0b text_event_0c
917
+ text_event_0d text_event_0e text_event_0f'''.split())
918
+
919
+ Nontext_meta_events = tuple('''end_track set_tempo
920
+ smpte_offset time_signature key_signature sequencer_specific
921
+ raw_meta_event sysex_f0 sysex_f7 song_position song_select
922
+ tune_request'''.split())
923
+ # unsupported: raw_data
924
+
925
+ # Actually, 'tune_request' is is F-series event, not strictly a meta-event...
926
+ Meta_events = Text_events + Nontext_meta_events
927
+ All_events = MIDI_events + Meta_events
928
+
929
+ # And three dictionaries:
930
+ Number2patch = { # General MIDI patch numbers:
931
+ 0:'Acoustic Grand',
932
+ 1:'Bright Acoustic',
933
+ 2:'Electric Grand',
934
+ 3:'Honky-Tonk',
935
+ 4:'Electric Piano 1',
936
+ 5:'Electric Piano 2',
937
+ 6:'Harpsichord',
938
+ 7:'Clav',
939
+ 8:'Celesta',
940
+ 9:'Glockenspiel',
941
+ 10:'Music Box',
942
+ 11:'Vibraphone',
943
+ 12:'Marimba',
944
+ 13:'Xylophone',
945
+ 14:'Tubular Bells',
946
+ 15:'Dulcimer',
947
+ 16:'Drawbar Organ',
948
+ 17:'Percussive Organ',
949
+ 18:'Rock Organ',
950
+ 19:'Church Organ',
951
+ 20:'Reed Organ',
952
+ 21:'Accordion',
953
+ 22:'Harmonica',
954
+ 23:'Tango Accordion',
955
+ 24:'Acoustic Guitar(nylon)',
956
+ 25:'Acoustic Guitar(steel)',
957
+ 26:'Electric Guitar(jazz)',
958
+ 27:'Electric Guitar(clean)',
959
+ 28:'Electric Guitar(muted)',
960
+ 29:'Overdriven Guitar',
961
+ 30:'Distortion Guitar',
962
+ 31:'Guitar Harmonics',
963
+ 32:'Acoustic Bass',
964
+ 33:'Electric Bass(finger)',
965
+ 34:'Electric Bass(pick)',
966
+ 35:'Fretless Bass',
967
+ 36:'Slap Bass 1',
968
+ 37:'Slap Bass 2',
969
+ 38:'Synth Bass 1',
970
+ 39:'Synth Bass 2',
971
+ 40:'Violin',
972
+ 41:'Viola',
973
+ 42:'Cello',
974
+ 43:'Contrabass',
975
+ 44:'Tremolo Strings',
976
+ 45:'Pizzicato Strings',
977
+ 46:'Orchestral Harp',
978
+ 47:'Timpani',
979
+ 48:'String Ensemble 1',
980
+ 49:'String Ensemble 2',
981
+ 50:'SynthStrings 1',
982
+ 51:'SynthStrings 2',
983
+ 52:'Choir Aahs',
984
+ 53:'Voice Oohs',
985
+ 54:'Synth Voice',
986
+ 55:'Orchestra Hit',
987
+ 56:'Trumpet',
988
+ 57:'Trombone',
989
+ 58:'Tuba',
990
+ 59:'Muted Trumpet',
991
+ 60:'French Horn',
992
+ 61:'Brass Section',
993
+ 62:'SynthBrass 1',
994
+ 63:'SynthBrass 2',
995
+ 64:'Soprano Sax',
996
+ 65:'Alto Sax',
997
+ 66:'Tenor Sax',
998
+ 67:'Baritone Sax',
999
+ 68:'Oboe',
1000
+ 69:'English Horn',
1001
+ 70:'Bassoon',
1002
+ 71:'Clarinet',
1003
+ 72:'Piccolo',
1004
+ 73:'Flute',
1005
+ 74:'Recorder',
1006
+ 75:'Pan Flute',
1007
+ 76:'Blown Bottle',
1008
+ 77:'Skakuhachi',
1009
+ 78:'Whistle',
1010
+ 79:'Ocarina',
1011
+ 80:'Lead 1 (square)',
1012
+ 81:'Lead 2 (sawtooth)',
1013
+ 82:'Lead 3 (calliope)',
1014
+ 83:'Lead 4 (chiff)',
1015
+ 84:'Lead 5 (charang)',
1016
+ 85:'Lead 6 (voice)',
1017
+ 86:'Lead 7 (fifths)',
1018
+ 87:'Lead 8 (bass+lead)',
1019
+ 88:'Pad 1 (new age)',
1020
+ 89:'Pad 2 (warm)',
1021
+ 90:'Pad 3 (polysynth)',
1022
+ 91:'Pad 4 (choir)',
1023
+ 92:'Pad 5 (bowed)',
1024
+ 93:'Pad 6 (metallic)',
1025
+ 94:'Pad 7 (halo)',
1026
+ 95:'Pad 8 (sweep)',
1027
+ 96:'FX 1 (rain)',
1028
+ 97:'FX 2 (soundtrack)',
1029
+ 98:'FX 3 (crystal)',
1030
+ 99:'FX 4 (atmosphere)',
1031
+ 100:'FX 5 (brightness)',
1032
+ 101:'FX 6 (goblins)',
1033
+ 102:'FX 7 (echoes)',
1034
+ 103:'FX 8 (sci-fi)',
1035
+ 104:'Sitar',
1036
+ 105:'Banjo',
1037
+ 106:'Shamisen',
1038
+ 107:'Koto',
1039
+ 108:'Kalimba',
1040
+ 109:'Bagpipe',
1041
+ 110:'Fiddle',
1042
+ 111:'Shanai',
1043
+ 112:'Tinkle Bell',
1044
+ 113:'Agogo',
1045
+ 114:'Steel Drums',
1046
+ 115:'Woodblock',
1047
+ 116:'Taiko Drum',
1048
+ 117:'Melodic Tom',
1049
+ 118:'Synth Drum',
1050
+ 119:'Reverse Cymbal',
1051
+ 120:'Guitar Fret Noise',
1052
+ 121:'Breath Noise',
1053
+ 122:'Seashore',
1054
+ 123:'Bird Tweet',
1055
+ 124:'Telephone Ring',
1056
+ 125:'Helicopter',
1057
+ 126:'Applause',
1058
+ 127:'Gunshot',
1059
+ }
1060
+ Notenum2percussion = { # General MIDI Percussion (on Channel 9):
1061
+ 35:'Acoustic Bass Drum',
1062
+ 36:'Bass Drum 1',
1063
+ 37:'Side Stick',
1064
+ 38:'Acoustic Snare',
1065
+ 39:'Hand Clap',
1066
+ 40:'Electric Snare',
1067
+ 41:'Low Floor Tom',
1068
+ 42:'Closed Hi-Hat',
1069
+ 43:'High Floor Tom',
1070
+ 44:'Pedal Hi-Hat',
1071
+ 45:'Low Tom',
1072
+ 46:'Open Hi-Hat',
1073
+ 47:'Low-Mid Tom',
1074
+ 48:'Hi-Mid Tom',
1075
+ 49:'Crash Cymbal 1',
1076
+ 50:'High Tom',
1077
+ 51:'Ride Cymbal 1',
1078
+ 52:'Chinese Cymbal',
1079
+ 53:'Ride Bell',
1080
+ 54:'Tambourine',
1081
+ 55:'Splash Cymbal',
1082
+ 56:'Cowbell',
1083
+ 57:'Crash Cymbal 2',
1084
+ 58:'Vibraslap',
1085
+ 59:'Ride Cymbal 2',
1086
+ 60:'Hi Bongo',
1087
+ 61:'Low Bongo',
1088
+ 62:'Mute Hi Conga',
1089
+ 63:'Open Hi Conga',
1090
+ 64:'Low Conga',
1091
+ 65:'High Timbale',
1092
+ 66:'Low Timbale',
1093
+ 67:'High Agogo',
1094
+ 68:'Low Agogo',
1095
+ 69:'Cabasa',
1096
+ 70:'Maracas',
1097
+ 71:'Short Whistle',
1098
+ 72:'Long Whistle',
1099
+ 73:'Short Guiro',
1100
+ 74:'Long Guiro',
1101
+ 75:'Claves',
1102
+ 76:'Hi Wood Block',
1103
+ 77:'Low Wood Block',
1104
+ 78:'Mute Cuica',
1105
+ 79:'Open Cuica',
1106
+ 80:'Mute Triangle',
1107
+ 81:'Open Triangle',
1108
+ }
1109
+
1110
+ Event2channelindex = { 'note':3, 'note_off':2, 'note_on':2,
1111
+ 'key_after_touch':2, 'control_change':2, 'patch_change':2,
1112
+ 'channel_after_touch':2, 'pitch_wheel_change':2
1113
+ }
1114
+
1115
+ ################################################################
1116
+ # The code below this line is full of frightening things, all to
1117
+ # do with the actual encoding and decoding of binary MIDI data.
1118
+
1119
+ def _twobytes2int(byte_a):
1120
+ r'''decode a 16 bit quantity from two bytes,'''
1121
+ return (byte_a[1] | (byte_a[0] << 8))
1122
+
1123
+ def _int2twobytes(int_16bit):
1124
+ r'''encode a 16 bit quantity into two bytes,'''
1125
+ return bytes([(int_16bit>>8) & 0xFF, int_16bit & 0xFF])
1126
+
1127
+ def _read_14_bit(byte_a):
1128
+ r'''decode a 14 bit quantity from two bytes,'''
1129
+ return (byte_a[0] | (byte_a[1] << 7))
1130
+
1131
+ def _write_14_bit(int_14bit):
1132
+ r'''encode a 14 bit quantity into two bytes,'''
1133
+ return bytes([int_14bit & 0x7F, (int_14bit>>7) & 0x7F])
1134
+
1135
+ def _ber_compressed_int(integer):
1136
+ r'''BER compressed integer (not an ASN.1 BER, see perlpacktut for
1137
+ details). Its bytes represent an unsigned integer in base 128,
1138
+ most significant digit first, with as few digits as possible.
1139
+ Bit eight (the high bit) is set on each byte except the last.
1140
+ '''
1141
+ ber = bytearray(b'')
1142
+ seven_bits = 0x7F & integer
1143
+ ber.insert(0, seven_bits) # XXX surely should convert to a char ?
1144
+ integer >>= 7
1145
+ while integer > 0:
1146
+ seven_bits = 0x7F & integer
1147
+ ber.insert(0, 0x80|seven_bits) # XXX surely should convert to a char ?
1148
+ integer >>= 7
1149
+ return ber
1150
+
1151
+ def _unshift_ber_int(ba):
1152
+ r'''Given a bytearray, returns a tuple of (the ber-integer at the
1153
+ start, and the remainder of the bytearray).
1154
+ '''
1155
+ if not len(ba): # 6.7
1156
+ _warn('_unshift_ber_int: no integer found')
1157
+ return ((0, b""))
1158
+ byte = ba.pop(0)
1159
+ integer = 0
1160
+ while True:
1161
+ integer += (byte & 0x7F)
1162
+ if not (byte & 0x80):
1163
+ return ((integer, ba))
1164
+ if not len(ba):
1165
+ _warn('_unshift_ber_int: no end-of-integer found')
1166
+ return ((0, ba))
1167
+ byte = ba.pop(0)
1168
+ integer <<= 7
1169
+
1170
+ def _clean_up_warnings(): # 5.4
1171
+ # Call this before returning from any publicly callable function
1172
+ # whenever there's a possibility that a warning might have been printed
1173
+ # by the function, or by any private functions it might have called.
1174
+ if _no_warning:
1175
+ return
1176
+ global _previous_times
1177
+ global _previous_warning
1178
+ if _previous_times > 1:
1179
+ # E:1176, 0: invalid syntax (<string>, line 1176) (syntax-error) ???
1180
+ # print(' previous message repeated '+str(_previous_times)+' times', file=sys.stderr)
1181
+ # 6.7
1182
+ sys.stderr.write(' previous message repeated {0} times\n'.format(_previous_times))
1183
+ elif _previous_times > 0:
1184
+ sys.stderr.write(' previous message repeated\n')
1185
+ _previous_times = 0
1186
+ _previous_warning = ''
1187
+
1188
+ def _warn(s=''):
1189
+ if _no_warning:
1190
+ return
1191
+ global _previous_times
1192
+ global _previous_warning
1193
+ if s == _previous_warning: # 5.4
1194
+ _previous_times = _previous_times + 1
1195
+ else:
1196
+ _clean_up_warnings()
1197
+ sys.stderr.write(str(s)+"\n")
1198
+ _previous_warning = s
1199
+
1200
+ def _some_text_event(which_kind=0x01, text=b'some_text'):
1201
+ if str(type(text)).find("'str'") >= 0: # 6.4 test for back-compatibility
1202
+ data = bytes(text, encoding='ISO-8859-1')
1203
+ else:
1204
+ data = bytes(text)
1205
+ return b'\xFF'+bytes((which_kind,))+_ber_compressed_int(len(data))+data
1206
+
1207
+ def _consistentise_ticks(scores): # 3.6
1208
+ # used by mix_scores, merge_scores, concatenate_scores
1209
+ if len(scores) == 1:
1210
+ return copy.deepcopy(scores)
1211
+ are_consistent = True
1212
+ ticks = scores[0][0]
1213
+ iscore = 1
1214
+ while iscore < len(scores):
1215
+ if scores[iscore][0] != ticks:
1216
+ are_consistent = False
1217
+ break
1218
+ iscore += 1
1219
+ if are_consistent:
1220
+ return copy.deepcopy(scores)
1221
+ new_scores = []
1222
+ iscore = 0
1223
+ while iscore < len(scores):
1224
+ score = scores[iscore]
1225
+ new_scores.append(opus2score(to_millisecs(score2opus(score))))
1226
+ iscore += 1
1227
+ return new_scores
1228
+
1229
+
1230
+ ###########################################################################
1231
+
1232
+ def _decode(trackdata=b'', exclude=None, include=None,
1233
+ event_callback=None, exclusive_event_callback=None, no_eot_magic=False):
1234
+ r'''Decodes MIDI track data into an opus-style list of events.
1235
+ The options:
1236
+ 'exclude' is a list of event types which will be ignored SHOULD BE A SET
1237
+ 'include' (and no exclude), makes exclude a list
1238
+ of all possible events, /minus/ what include specifies
1239
+ 'event_callback' is a coderef
1240
+ 'exclusive_event_callback' is a coderef
1241
+ '''
1242
+ trackdata = bytearray(trackdata)
1243
+ if exclude == None:
1244
+ exclude = []
1245
+ if include == None:
1246
+ include = []
1247
+ if include and not exclude:
1248
+ exclude = All_events
1249
+ include = set(include)
1250
+ exclude = set(exclude)
1251
+
1252
+ # Pointer = 0; not used here; we eat through the bytearray instead.
1253
+ event_code = -1; # used for running status
1254
+ event_count = 0;
1255
+ events = []
1256
+
1257
+ while(len(trackdata)):
1258
+ # loop while there's anything to analyze ...
1259
+ eot = False # When True, the event registrar aborts this loop
1260
+ event_count += 1
1261
+
1262
+ E = []
1263
+ # E for events - we'll feed it to the event registrar at the end.
1264
+
1265
+ # Slice off the delta time code, and analyze it
1266
+ [time, remainder] = _unshift_ber_int(trackdata)
1267
+
1268
+ # Now let's see what we can make of the command
1269
+ first_byte = trackdata.pop(0) & 0xFF
1270
+
1271
+ if (first_byte < 0xF0): # It's a MIDI event
1272
+ if (first_byte & 0x80):
1273
+ event_code = first_byte
1274
+ else:
1275
+ # It wants running status; use last event_code value
1276
+ trackdata.insert(0, first_byte)
1277
+ if (event_code == -1):
1278
+ _warn("Running status not set; Aborting track.")
1279
+ return []
1280
+
1281
+ command = event_code & 0xF0
1282
+ channel = event_code & 0x0F
1283
+
1284
+ if (command == 0xF6): # 0-byte argument
1285
+ pass
1286
+ elif (command == 0xC0 or command == 0xD0): # 1-byte argument
1287
+ parameter = trackdata.pop(0) # could be B
1288
+ else: # 2-byte argument could be BB or 14-bit
1289
+ parameter = (trackdata.pop(0), trackdata.pop(0))
1290
+
1291
+ #################################################################
1292
+ # MIDI events
1293
+
1294
+ if (command == 0x80):
1295
+ if 'note_off' in exclude:
1296
+ continue
1297
+ E = ['note_off', time, channel, parameter[0], parameter[1]]
1298
+ elif (command == 0x90):
1299
+ if 'note_on' in exclude:
1300
+ continue
1301
+ E = ['note_on', time, channel, parameter[0], parameter[1]]
1302
+ elif (command == 0xA0):
1303
+ if 'key_after_touch' in exclude:
1304
+ continue
1305
+ E = ['key_after_touch',time,channel,parameter[0],parameter[1]]
1306
+ elif (command == 0xB0):
1307
+ if 'control_change' in exclude:
1308
+ continue
1309
+ E = ['control_change',time,channel,parameter[0],parameter[1]]
1310
+ elif (command == 0xC0):
1311
+ if 'patch_change' in exclude:
1312
+ continue
1313
+ E = ['patch_change', time, channel, parameter]
1314
+ elif (command == 0xD0):
1315
+ if 'channel_after_touch' in exclude:
1316
+ continue
1317
+ E = ['channel_after_touch', time, channel, parameter]
1318
+ elif (command == 0xE0):
1319
+ if 'pitch_wheel_change' in exclude:
1320
+ continue
1321
+ E = ['pitch_wheel_change', time, channel,
1322
+ _read_14_bit(parameter)-0x2000]
1323
+ else:
1324
+ _warn("Shouldn't get here; command="+hex(command))
1325
+
1326
+ elif (first_byte == 0xFF): # It's a Meta-Event! ##################
1327
+ #[command, length, remainder] =
1328
+ # unpack("xCwa*", substr(trackdata, $Pointer, 6));
1329
+ #Pointer += 6 - len(remainder);
1330
+ # # Move past JUST the length-encoded.
1331
+ command = trackdata.pop(0) & 0xFF
1332
+ [length, trackdata] = _unshift_ber_int(trackdata)
1333
+ if (command == 0x00):
1334
+ if (length == 2):
1335
+ E = ['set_sequence_number',time,_twobytes2int(trackdata)]
1336
+ else:
1337
+ _warn('set_sequence_number: length must be 2, not '+str(length))
1338
+ E = ['set_sequence_number', time, 0]
1339
+
1340
+ elif command >= 0x01 and command <= 0x0f: # Text events
1341
+ # 6.2 take it in bytes; let the user get the right encoding.
1342
+ # text_str = trackdata[0:length].decode('ascii','ignore')
1343
+ # text_str = trackdata[0:length].decode('ISO-8859-1')
1344
+ # 6.4 take it in bytes; let the user get the right encoding.
1345
+ text_data = bytes(trackdata[0:length]) # 6.4
1346
+ # Defined text events
1347
+ if (command == 0x01):
1348
+ E = ['text_event', time, text_data]
1349
+ elif (command == 0x02):
1350
+ E = ['copyright_text_event', time, text_data]
1351
+ elif (command == 0x03):
1352
+ E = ['track_name', time, text_data]
1353
+ elif (command == 0x04):
1354
+ E = ['instrument_name', time, text_data]
1355
+ elif (command == 0x05):
1356
+ E = ['lyric', time, text_data]
1357
+ elif (command == 0x06):
1358
+ E = ['marker', time, text_data]
1359
+ elif (command == 0x07):
1360
+ E = ['cue_point', time, text_data]
1361
+ # Reserved but apparently unassigned text events
1362
+ elif (command == 0x08):
1363
+ E = ['text_event_08', time, text_data]
1364
+ elif (command == 0x09):
1365
+ E = ['text_event_09', time, text_data]
1366
+ elif (command == 0x0a):
1367
+ E = ['text_event_0a', time, text_data]
1368
+ elif (command == 0x0b):
1369
+ E = ['text_event_0b', time, text_data]
1370
+ elif (command == 0x0c):
1371
+ E = ['text_event_0c', time, text_data]
1372
+ elif (command == 0x0d):
1373
+ E = ['text_event_0d', time, text_data]
1374
+ elif (command == 0x0e):
1375
+ E = ['text_event_0e', time, text_data]
1376
+ elif (command == 0x0f):
1377
+ E = ['text_event_0f', time, text_data]
1378
+
1379
+ # Now the sticky events -------------------------------------
1380
+ elif (command == 0x2F):
1381
+ E = ['end_track', time]
1382
+ # The code for handling this, oddly, comes LATER,
1383
+ # in the event registrar.
1384
+ elif (command == 0x51): # DTime, Microseconds/Crochet
1385
+ if length != 3:
1386
+ _warn('set_tempo event, but length='+str(length))
1387
+ E = ['set_tempo', time,
1388
+ struct.unpack(">I", b'\x00'+trackdata[0:3])[0]]
1389
+ elif (command == 0x54):
1390
+ if length != 5: # DTime, HR, MN, SE, FR, FF
1391
+ _warn('smpte_offset event, but length='+str(length))
1392
+ E = ['smpte_offset',time] + list(struct.unpack(">BBBBB",trackdata[0:5]))
1393
+ elif (command == 0x58):
1394
+ if length != 4: # DTime, NN, DD, CC, BB
1395
+ _warn('time_signature event, but length='+str(length))
1396
+ E = ['time_signature', time]+list(trackdata[0:4])
1397
+ elif (command == 0x59):
1398
+ if length != 2: # DTime, SF(signed), MI
1399
+ _warn('key_signature event, but length='+str(length))
1400
+ E = ['key_signature',time] + list(struct.unpack(">bB",trackdata[0:2]))
1401
+ elif (command == 0x7F): # 6.4
1402
+ E = ['sequencer_specific',time, bytes(trackdata[0:length])]
1403
+ else:
1404
+ E = ['raw_meta_event', time, command,
1405
+ bytes(trackdata[0:length])] # 6.0
1406
+ #"[uninterpretable meta-event command of length length]"
1407
+ # DTime, Command, Binary Data
1408
+ # It's uninterpretable; record it as raw_data.
1409
+
1410
+ # Pointer += length; # Now move Pointer
1411
+ trackdata = trackdata[length:]
1412
+
1413
+ ######################################################################
1414
+ elif (first_byte == 0xF0 or first_byte == 0xF7):
1415
+ # Note that sysexes in MIDI /files/ are different from sysexes
1416
+ # in MIDI transmissions!! The vast majority of system exclusive
1417
+ # messages will just use the F0 format. For instance, the
1418
+ # transmitted message F0 43 12 00 07 F7 would be stored in a
1419
+ # MIDI file as F0 05 43 12 00 07 F7. As mentioned above, it is
1420
+ # required to include the F7 at the end so that the reader of the
1421
+ # MIDI file knows that it has read the entire message. (But the F7
1422
+ # is omitted if this is a non-final block in a multiblock sysex;
1423
+ # but the F7 (if there) is counted in the message's declared
1424
+ # length, so we don't have to think about it anyway.)
1425
+ #command = trackdata.pop(0)
1426
+ [length, trackdata] = _unshift_ber_int(trackdata)
1427
+ if first_byte == 0xF0:
1428
+ # 20091008 added ISO-8859-1 to get an 8-bit str
1429
+ # 6.4 return bytes instead
1430
+ E = ['sysex_f0', time, bytes(trackdata[0:length])]
1431
+ else:
1432
+ E = ['sysex_f7', time, bytes(trackdata[0:length])]
1433
+ trackdata = trackdata[length:]
1434
+
1435
+ ######################################################################
1436
+ # Now, the MIDI file spec says:
1437
+ # <track data> = <MTrk event>+
1438
+ # <MTrk event> = <delta-time> <event>
1439
+ # <event> = <MIDI event> | <sysex event> | <meta-event>
1440
+ # I know that, on the wire, <MIDI event> can include note_on,
1441
+ # note_off, and all the other 8x to Ex events, AND Fx events
1442
+ # other than F0, F7, and FF -- namely, <song position msg>,
1443
+ # <song select msg>, and <tune request>.
1444
+ #
1445
+ # Whether these can occur in MIDI files is not clearly specified
1446
+ # by the MIDI file spec. So, I'm going to assume that
1447
+ # they CAN, in practice, occur. I don't know whether it's
1448
+ # proper for you to actually emit these into a MIDI file.
1449
+
1450
+ elif (first_byte == 0xF2): # DTime, Beats
1451
+ # <song position msg> ::= F2 <data pair>
1452
+ E = ['song_position', time, _read_14_bit(trackdata[:2])]
1453
+ trackdata = trackdata[2:]
1454
+
1455
+ elif (first_byte == 0xF3): # <song select msg> ::= F3 <data singlet>
1456
+ # E = ['song_select', time, struct.unpack('>B',trackdata.pop(0))[0]]
1457
+ E = ['song_select', time, trackdata[0]]
1458
+ trackdata = trackdata[1:]
1459
+ # DTime, Thing (what?! song number? whatever ...)
1460
+
1461
+ elif (first_byte == 0xF6): # DTime
1462
+ E = ['tune_request', time]
1463
+ # What would a tune request be doing in a MIDI /file/?
1464
+
1465
+ #########################################################
1466
+ # ADD MORE META-EVENTS HERE. TODO:
1467
+ # f1 -- MTC Quarter Frame Message. One data byte follows
1468
+ # the Status; it's the time code value, from 0 to 127.
1469
+ # f8 -- MIDI clock. no data.
1470
+ # fa -- MIDI start. no data.
1471
+ # fb -- MIDI continue. no data.
1472
+ # fc -- MIDI stop. no data.
1473
+ # fe -- Active sense. no data.
1474
+ # f4 f5 f9 fd -- unallocated
1475
+
1476
+ r'''
1477
+ elif (first_byte > 0xF0) { # Some unknown kinda F-series event ####
1478
+ # Here we only produce a one-byte piece of raw data.
1479
+ # But the encoder for 'raw_data' accepts any length of it.
1480
+ E = [ 'raw_data',
1481
+ time, substr(trackdata,Pointer,1) ]
1482
+ # DTime and the Data (in this case, the one Event-byte)
1483
+ ++Pointer; # itself
1484
+
1485
+ '''
1486
+ elif first_byte > 0xF0: # Some unknown F-series event
1487
+ # Here we only produce a one-byte piece of raw data.
1488
+ # E = ['raw_data', time, bytes(trackdata[0])] # 6.4
1489
+ E = ['raw_data', time, trackdata[0]] # 6.4 6.7
1490
+ trackdata = trackdata[1:]
1491
+ else: # Fallthru.
1492
+ _warn("Aborting track. Command-byte first_byte="+hex(first_byte))
1493
+ break
1494
+ # End of the big if-group
1495
+
1496
+
1497
+ ######################################################################
1498
+ # THE EVENT REGISTRAR...
1499
+ if E and (E[0] == 'end_track'):
1500
+ # This is the code for exceptional handling of the EOT event.
1501
+ eot = True
1502
+ if not no_eot_magic:
1503
+ if E[1] > 0: # a null text-event to carry the delta-time
1504
+ E = ['text_event', E[1], '']
1505
+ else:
1506
+ E = [] # EOT with a delta-time of 0; ignore it.
1507
+
1508
+ if E and not (E[0] in exclude):
1509
+ #if ( $exclusive_event_callback ):
1510
+ # &{ $exclusive_event_callback }( @E );
1511
+ #else:
1512
+ # &{ $event_callback }( @E ) if $event_callback;
1513
+ events.append(E)
1514
+ if eot:
1515
+ break
1516
+
1517
+ # End of the big "Event" while-block
1518
+
1519
+ return events
1520
+
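For orientation, each entry of the returned events list is [name, dtime, ...parameters]. A hypothetical decode of a short one-note track (values illustrative) would look like:

    [['patch_change', 0, 0, 40],      # dtime, channel, patch
     ['note_on', 0, 0, 60, 100],      # dtime, channel, note, velocity
     ['note_off', 96, 0, 60, 100],
     ['text_event', 12, '']]          # end_track with dtime 12, rewritten by the EOT magic above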
1521
+
1522
+ ###########################################################################
1523
+ def _encode(events_lol, unknown_callback=None, never_add_eot=False,
1524
+ no_eot_magic=False, no_running_status=False):
1525
+ # encode an event structure, presumably for writing to a file
1526
+ # Calling format:
1527
+ # $data_r = MIDI::Event::encode( \@event_lol, { options } );
1528
+ # Takes a REFERENCE to an event structure (a LoL)
1529
+ # Returns an (unblessed) REFERENCE to track data.
1530
+
1531
+ # If you want to use this to encode a /single/ event,
1532
+ # you still have to do it as a reference to an event structure (a LoL)
1533
+ # that just happens to have just one event. I.e.,
1534
+ # encode( [ $event ] ) or encode( [ [ 'note_on', 100, 5, 42, 64] ] )
1535
+ # If you're doing this, consider the never_add_eot track option, as in
1536
+ # print MIDI ${ encode( [ $event], { 'never_add_eot' => 1} ) };
1537
+
1538
+ data = [] # what I'll store the chunks of byte-data in
1539
+
1540
+ # This is so my end_track magic won't corrupt the original
1541
+ events = copy.deepcopy(events_lol)
1542
+
1543
+ if not never_add_eot:
1544
+ # One way or another, tack on an 'end_track'
1545
+ if events:
1546
+ last = events[-1]
1547
+ if not (last[0] == 'end_track'): # no end_track already
1548
+ if (last[0] == 'text_event' and len(last[2]) == 0):
1549
+ # 0-length text event at track-end.
1550
+ if no_eot_magic:
1551
+ # Exceptional case: don't mess with track-final
1552
+ # 0-length text_events; just peg on an end_track
1553
+ events.append(['end_track', 0])
1554
+ else:
1555
+ # NORMAL CASE: replace with an end_track, leaving DTime
1556
+ last[0] = 'end_track'
1557
+ else:
1558
+ # last event was neither 0-length text_event nor end_track
1559
+ events.append(['end_track', 0])
1560
+ else: # an eventless track!
1561
+ events = [['end_track', 0],]
1562
+
1563
+ # maybe_running_status = not no_running_status # unused? 4.7
1564
+ last_status = -1
1565
+
1566
+ for event_r in (events):
1567
+ E = copy.deepcopy(event_r)
1568
+ # otherwise the shifting'd corrupt the original
1569
+ if not E:
1570
+ continue
1571
+
1572
+ event = E.pop(0)
1573
+ if not len(event):
1574
+ continue
1575
+
1576
+ dtime = int(E.pop(0))
1577
+ # print('event='+str(event)+' dtime='+str(dtime))
1578
+
1579
+ event_data = ''
1580
+
1581
+ if ( # MIDI events -- eligible for running status
1582
+ event == 'note_on'
1583
+ or event == 'note_off'
1584
+ or event == 'control_change'
1585
+ or event == 'key_after_touch'
1586
+ or event == 'patch_change'
1587
+ or event == 'channel_after_touch'
1588
+ or event == 'pitch_wheel_change' ):
1589
+
1590
+ # This block is where we spend most of the time. Gotta be tight.
1591
+ if (event == 'note_off'):
1592
+ status = 0x80 | (int(E[0]) & 0x0F)
1593
+ parameters = struct.pack('>BB', int(E[1])&0x7F, int(E[2])&0x7F)
1594
+ elif (event == 'note_on'):
1595
+ status = 0x90 | (int(E[0]) & 0x0F)
1596
+ parameters = struct.pack('>BB', int(E[1])&0x7F, int(E[2])&0x7F)
1597
+ elif (event == 'key_after_touch'):
1598
+ status = 0xA0 | (int(E[0]) & 0x0F)
1599
+ parameters = struct.pack('>BB', int(E[1])&0x7F, int(E[2])&0x7F)
1600
+ elif (event == 'control_change'):
1601
+ status = 0xB0 | (int(E[0]) & 0x0F)
1602
+ parameters = struct.pack('>BB', int(E[1])&0xFF, int(E[2])&0xFF)
1603
+ elif (event == 'patch_change'):
1604
+ status = 0xC0 | (int(E[0]) & 0x0F)
1605
+ parameters = struct.pack('>B', int(E[1]) & 0xFF)
1606
+ elif (event == 'channel_after_touch'):
1607
+ status = 0xD0 | (int(E[0]) & 0x0F)
1608
+ parameters = struct.pack('>B', int(E[1]) & 0xFF)
1609
+ elif (event == 'pitch_wheel_change'):
1610
+ status = 0xE0 | (int(E[0]) & 0x0F)
1611
+ parameters = _write_14_bit(int(E[1]) + 0x2000)
1612
+ else:
1613
+ _warn("BADASS FREAKOUT ERROR 31415!")
1614
+
1615
+ # And now the encoding
1616
+ # w = BER compressed integer (not ASN.1 BER, see perlpacktut for
1617
+ # details). Its bytes represent an unsigned integer in base 128,
1618
+ # most significant digit first, with as few digits as possible.
1619
+ # Bit eight (the high bit) is set on each byte except the last.
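+ # (For example, a delta-time of 200 = 1*128 + 72 is encoded as the
+ # two bytes 0x81 0x48.)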
1620
+
1621
+ data.append(_ber_compressed_int(dtime))
1622
+ if (status != last_status) or no_running_status:
1623
+ data.append(struct.pack('>B', status))
1624
+ data.append(parameters)
1625
+
1626
+ last_status = status
1627
+ continue
1628
+ else:
1629
+ # Not a MIDI event.
1630
+ # All the code in this block could be more efficient,
1631
+ # but this is not where the code needs to be tight.
1632
+ # print "zaz $event\n";
1633
+ last_status = -1
1634
+
1635
+ if event == 'raw_meta_event':
1636
+ event_data = _some_text_event(int(E[0]), E[1])
1637
+ elif (event == 'set_sequence_number'): # 3.9
1638
+ event_data = b'\xFF\x00\x02'+_int2twobytes(E[0])
1639
+
1640
+ # Text meta-events...
1641
+ # a case for a dict, I think (pjb) ...
1642
+ elif (event == 'text_event'):
1643
+ event_data = _some_text_event(0x01, E[0])
1644
+ elif (event == 'copyright_text_event'):
1645
+ event_data = _some_text_event(0x02, E[0])
1646
+ elif (event == 'track_name'):
1647
+ event_data = _some_text_event(0x03, E[0])
1648
+ elif (event == 'instrument_name'):
1649
+ event_data = _some_text_event(0x04, E[0])
1650
+ elif (event == 'lyric'):
1651
+ event_data = _some_text_event(0x05, E[0])
1652
+ elif (event == 'marker'):
1653
+ event_data = _some_text_event(0x06, E[0])
1654
+ elif (event == 'cue_point'):
1655
+ event_data = _some_text_event(0x07, E[0])
1656
+ elif (event == 'text_event_08'):
1657
+ event_data = _some_text_event(0x08, E[0])
1658
+ elif (event == 'text_event_09'):
1659
+ event_data = _some_text_event(0x09, E[0])
1660
+ elif (event == 'text_event_0a'):
1661
+ event_data = _some_text_event(0x0A, E[0])
1662
+ elif (event == 'text_event_0b'):
1663
+ event_data = _some_text_event(0x0B, E[0])
1664
+ elif (event == 'text_event_0c'):
1665
+ event_data = _some_text_event(0x0C, E[0])
1666
+ elif (event == 'text_event_0d'):
1667
+ event_data = _some_text_event(0x0D, E[0])
1668
+ elif (event == 'text_event_0e'):
1669
+ event_data = _some_text_event(0x0E, E[0])
1670
+ elif (event == 'text_event_0f'):
1671
+ event_data = _some_text_event(0x0F, E[0])
1672
+ # End of text meta-events
1673
+
1674
+ elif (event == 'end_track'):
1675
+ event_data = b"\xFF\x2F\x00"
1676
+
1677
+ elif (event == 'set_tempo'):
1678
+ #event_data = struct.pack(">BBwa*", 0xFF, 0x51, 3,
1679
+ # substr( struct.pack('>I', E[0]), 1, 3))
1680
+ event_data = b'\xFF\x51\x03'+struct.pack('>I',E[0])[1:]
1681
+ elif (event == 'smpte_offset'):
1682
+ # event_data = struct.pack(">BBwBBBBB", 0xFF, 0x54, 5, E[0:5] )
1683
+ event_data = struct.pack(">BBBbBBBB", 0xFF,0x54,0x05,E[0],E[1],E[2],E[3],E[4])
1684
+ elif (event == 'time_signature'):
1685
+ # event_data = struct.pack(">BBwBBBB", 0xFF, 0x58, 4, E[0:4] )
1686
+ event_data = struct.pack(">BBBbBBB", 0xFF, 0x58, 0x04, E[0],E[1],E[2],E[3])
1687
+ elif (event == 'key_signature'):
1688
+ event_data = struct.pack(">BBBbB", 0xFF, 0x59, 0x02, E[0],E[1])
1689
+ elif (event == 'sequencer_specific'):
1690
+ # event_data = struct.pack(">BBwa*", 0xFF,0x7F, len(E[0]), E[0])
1691
+ event_data = _some_text_event(0x7F, E[0])
1692
+ # End of Meta-events
1693
+
1694
+ # Other Things...
1695
+ elif (event == 'sysex_f0'):
1696
+ #event_data = struct.pack(">Bwa*", 0xF0, len(E[0]), E[0])
1697
+ #B=bitstring w=BER-compressed-integer a=null-padded-ascii-str
1698
+ event_data = bytearray(b'\xF0')+_ber_compressed_int(len(E[0]))+bytearray(E[0])
1699
+ elif (event == 'sysex_f7'):
1700
+ #event_data = struct.pack(">Bwa*", 0xF7, len(E[0]), E[0])
1701
+ event_data = bytearray(b'\xF7')+_ber_compressed_int(len(E[0]))+bytearray(E[0])
1702
+
1703
+ elif (event == 'song_position'):
1704
+ event_data = b"\xF2" + _write_14_bit( E[0] )
1705
+ elif (event == 'song_select'):
1706
+ event_data = struct.pack('>BB', 0xF3, E[0] )
1707
+ elif (event == 'tune_request'):
1708
+ event_data = b"\xF6"
1709
+ elif (event == 'raw_data'):
1710
+ _warn("_encode: raw_data event not supported")
1711
+ # event_data = E[0]
1712
+ continue
1713
+ # End of Other Stuff
1714
+
1715
+ else:
1716
+ # The Big Fallthru
1717
+ if unknown_callback:
1718
+ # push(@data, &{ $unknown_callback }( @$event_r ))
1719
+ pass
1720
+ else:
1721
+ _warn("Unknown event: "+str(event))
1722
+ # To suppress complaint here, just set
1723
+ # 'unknown_callback' => sub { return () }
1724
+ continue
1725
+
1726
+ #print "Event $event encoded part 2\n"
1727
+ if str(type(event_data)).find("'str'") >= 0:
1728
+ event_data = bytearray(event_data.encode('Latin1', 'ignore'))
1729
+ if len(event_data): # how could $event_data be empty
1730
+ # data.append(struct.pack('>wa*', dtime, event_data))
1731
+ # print(' event_data='+str(event_data))
1732
+ data.append(_ber_compressed_int(dtime)+event_data)
1733
+
1734
+ return b''.join(data)
1735
+
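Taken together with _decode, _encode is the inverse direction: it consumes an event LoL and returns the byte body of one MTrk chunk, appending an end_track unless never_add_eot is set. A minimal sketch (event values illustrative; assumes this module is importable as MIDI):

    import MIDI

    track = [['patch_change', 0, 0, 40],    # dtime, channel, patch
             ['note_on', 0, 0, 60, 100],    # dtime, channel, note, velocity
             ['note_off', 96, 0, 60, 100]]
    body = MIDI._encode(track)              # running status is used unless no_running_status=True
    print(body.hex())                       # ends with ff2f00 (the appended end_track)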
README.md CHANGED
@@ -1,13 +1,13 @@
1
  ---
2
- title: Modified AI Midi Tool Space IAT 360
3
- emoji: 🐠
4
- colorFrom: gray
5
- colorTo: green
6
  sdk: gradio
7
- sdk_version: 5.8.0
8
- app_file: app.py
9
- pinned: false
10
- short_description: 'A modified version of the AI midi composer '
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: Midi Music Generator
3
+ emoji: 🎼🎶
4
+ colorFrom: red
5
+ colorTo: indigo
6
  sdk: gradio
7
+ sdk_version: 5.3.0
8
+ app_file: app_onnx.py
9
+ pinned: true
10
+ license: apache-2.0
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,533 @@
1
+ import spaces
2
+ import random
3
+ import argparse
4
+ import glob
5
+ import json
6
+ import os
7
+ import time
8
+ from concurrent.futures import ThreadPoolExecutor
9
+
10
+ import gradio as gr
11
+ import numpy as np
12
+ import torch
13
+ import torch.nn.functional as F
14
+ import tqdm
15
+ from huggingface_hub import hf_hub_download
16
+ from transformers import DynamicCache
17
+
18
+ import MIDI
19
+ from midi_model import MIDIModel, MIDIModelConfig
20
+ from midi_synthesizer import MidiSynthesizer
21
+
22
+ MAX_SEED = np.iinfo(np.int32).max
23
+ in_space = os.getenv("SYSTEM") == "spaces"
24
+
25
+
26
+ @torch.inference_mode()
27
+ def generate(model: MIDIModel, prompt=None, batch_size=1, max_len=512, temp=1.0, top_p=0.98, top_k=20,
28
+ disable_patch_change=False, disable_control_change=False, disable_channels=None, generator=None):
29
+ tokenizer = model.tokenizer
30
+ if disable_channels is not None:
31
+ disable_channels = [tokenizer.parameter_ids["channel"][c] for c in disable_channels]
32
+ else:
33
+ disable_channels = []
34
+ max_token_seq = tokenizer.max_token_seq
35
+ if prompt is None:
36
+ input_tensor = torch.full((1, max_token_seq), tokenizer.pad_id, dtype=torch.long, device=model.device)
37
+ input_tensor[0, 0] = tokenizer.bos_id # bos
38
+ input_tensor = input_tensor.unsqueeze(0)
39
+ input_tensor = torch.cat([input_tensor] * batch_size, dim=0)
40
+ else:
41
+ if len(prompt.shape) == 2:
42
+ prompt = prompt[None, :]
43
+ prompt = np.repeat(prompt, repeats=batch_size, axis=0)
44
+ elif prompt.shape[0] == 1:
45
+ prompt = np.repeat(prompt, repeats=batch_size, axis=0)
46
+ elif len(prompt.shape) != 3 or prompt.shape[0] != batch_size:
47
+ raise ValueError(f"invalid shape for prompt, {prompt.shape}")
48
+ prompt = prompt[..., :max_token_seq]
49
+ if prompt.shape[-1] < max_token_seq:
50
+ prompt = np.pad(prompt, ((0, 0), (0, 0), (0, max_token_seq - prompt.shape[-1])),
51
+ mode="constant", constant_values=tokenizer.pad_id)
52
+ input_tensor = torch.from_numpy(prompt).to(dtype=torch.long, device=model.device)
53
+ cur_len = input_tensor.shape[1]
54
+ bar = tqdm.tqdm(desc="generating", total=max_len - cur_len, disable=in_space)
55
+ cache1 = DynamicCache()
56
+ past_len = 0
57
+ with bar:
58
+ while cur_len < max_len:
59
+ end = [False] * batch_size
60
+ hidden = model.forward(input_tensor[:, past_len:], cache=cache1)[:, -1]
61
+ next_token_seq = None
62
+ event_names = [""] * batch_size
63
+ cache2 = DynamicCache()
64
+ for i in range(max_token_seq):
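+ # Position 0 of each event picks the event type (or EOS); subsequent
+ # positions are masked to that event's parameter vocabularies.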
65
+ mask = torch.zeros((batch_size, tokenizer.vocab_size), dtype=torch.int64, device=model.device)
66
+ for b in range(batch_size):
67
+ if end[b]:
68
+ mask[b, tokenizer.pad_id] = 1
69
+ continue
70
+ if i == 0:
71
+ mask_ids = list(tokenizer.event_ids.values()) + [tokenizer.eos_id]
72
+ if disable_patch_change:
73
+ mask_ids.remove(tokenizer.event_ids["patch_change"])
74
+ if disable_control_change:
75
+ mask_ids.remove(tokenizer.event_ids["control_change"])
76
+ mask[b, mask_ids] = 1
77
+ else:
78
+ param_names = tokenizer.events[event_names[b]]
79
+ if i > len(param_names):
80
+ mask[b, tokenizer.pad_id] = 1
81
+ continue
82
+ param_name = param_names[i - 1]
83
+ mask_ids = tokenizer.parameter_ids[param_name]
84
+ if param_name == "channel":
85
+ mask_ids = [i for i in mask_ids if i not in disable_channels]
86
+ mask[b, mask_ids] = 1
87
+ mask = mask.unsqueeze(1)
88
+ x = next_token_seq
89
+ if i != 0:
90
+ hidden = None
91
+ x = x[:, -1:]
92
+ logits = model.forward_token(hidden, x, cache=cache2)[:, -1:]
93
+ scores = torch.softmax(logits / temp, dim=-1) * mask
94
+ samples = model.sample_top_p_k(scores, top_p, top_k, generator=generator)
95
+ if i == 0:
96
+ next_token_seq = samples
97
+ for b in range(batch_size):
98
+ if end[b]:
99
+ continue
100
+ eid = samples[b].item()
101
+ if eid == tokenizer.eos_id:
102
+ end[b] = True
103
+ else:
104
+ event_names[b] = tokenizer.id_events[eid]
105
+ else:
106
+ next_token_seq = torch.cat([next_token_seq, samples], dim=1)
107
+ if all([len(tokenizer.events[event_names[b]]) == i for b in range(batch_size) if not end[b]]):
108
+ break
109
+ if next_token_seq.shape[1] < max_token_seq:
110
+ next_token_seq = F.pad(next_token_seq, (0, max_token_seq - next_token_seq.shape[1]),
111
+ "constant", value=tokenizer.pad_id)
112
+ next_token_seq = next_token_seq.unsqueeze(1)
113
+ input_tensor = torch.cat([input_tensor, next_token_seq], dim=1)
114
+ past_len = cur_len
115
+ cur_len += 1
116
+ bar.update(1)
117
+ yield next_token_seq[:, 0].cpu().numpy()
118
+ if all(end):
119
+ break
120
+
121
+
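A minimal sketch of driving generate directly, outside Gradio (assuming the model fits on CPU and that detokenize accepts the bos prompt row plus the generated rows, as in finish_run below):

    model = MIDIModel.from_pretrained("skytnt/midi-model")
    model.to(device="cpu", dtype=torch.float32)
    tk = model.tokenizer
    rows = [[tk.bos_id] + [tk.pad_id] * (tk.max_token_seq - 1)]  # bos prompt row
    for step in generate(model, batch_size=1, max_len=128):
        rows.append(step[0].tolist())                            # one event per yield
    with open("demo.mid", "wb") as f:
        f.write(MIDI.score2midi(tk.detokenize(rows)))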
122
+ def create_msg(name, data):
123
+ return {"name": name, "data": data}
124
+
125
+
126
+ def send_msgs(msgs):
127
+ return json.dumps(msgs)
128
+
129
+
130
+ def get_duration(model_name, tab, mid_seq, continuation_state, continuation_select, instruments, drum_kit, bpm,
131
+ time_sig, key_sig, mid, midi_events, reduce_cc_st, remap_track_channel, add_default_instr,
132
+ remove_empty_channels, seed, seed_rand, gen_events, temp, top_p, top_k, allow_cc):
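+ # Rough ZeroGPU time budget: throughput is roughly 23 events/s for the
+ # medium models and 14 events/s for the large one, plus 5 s of overhead.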
133
+ t = gen_events // 23
134
+ if "large" in model_name:
135
+ t = gen_events // 14
136
+ return t + 5
137
+
138
+
139
+ @spaces.GPU(duration=get_duration)
140
+ def run(model_name, tab, mid_seq, continuation_state, continuation_select, instruments, drum_kit, bpm, time_sig,
141
+ key_sig, mid, midi_events, reduce_cc_st, remap_track_channel, add_default_instr, remove_empty_channels,
142
+ seed, seed_rand, gen_events, temp, top_p, top_k, allow_cc):
143
+ model = models[model_name]
144
+ model.to(device=opt.device)
145
+ tokenizer = model.tokenizer
146
+ bpm = int(bpm)
147
+ if time_sig == "auto":
148
+ time_sig = None
149
+ time_sig_nn = 4
150
+ time_sig_dd = 2
151
+ else:
152
+ time_sig_nn, time_sig_dd = time_sig.split('/')
153
+ time_sig_nn = int(time_sig_nn)
154
+ time_sig_dd = {2: 1, 4: 2, 8: 3}[int(time_sig_dd)]
155
+ if key_sig == 0:
156
+ key_sig = None
157
+ key_sig_sf = 0
158
+ key_sig_mi = 0
159
+ else:
160
+ key_sig = (key_sig - 1)
161
+ key_sig_sf = key_sig // 2 - 7
162
+ key_sig_mi = key_sig % 2
163
+ gen_events = int(gen_events)
164
+ max_len = gen_events
165
+ if seed_rand:
166
+ seed = random.randint(0, MAX_SEED)
167
+ generator = torch.Generator(opt.device).manual_seed(seed)
168
+ disable_patch_change = False
169
+ disable_channels = None
170
+ if tab == 0:
171
+ i = 0
172
+ mid = [[tokenizer.bos_id] + [tokenizer.pad_id] * (tokenizer.max_token_seq - 1)]
173
+ if tokenizer.version == "v2":
174
+ if time_sig is not None:
175
+ mid.append(tokenizer.event2tokens(["time_signature", 0, 0, 0, time_sig_nn - 1, time_sig_dd - 1]))
176
+ if key_sig is not None:
177
+ mid.append(tokenizer.event2tokens(["key_signature", 0, 0, 0, key_sig_sf + 7, key_sig_mi]))
178
+ if bpm != 0:
179
+ mid.append(tokenizer.event2tokens(["set_tempo", 0, 0, 0, bpm]))
180
+ patches = {}
181
+ if instruments is None:
182
+ instruments = []
183
+ for instr in instruments:
184
+ patches[i] = patch2number[instr]
185
+ i = (i + 1) if i != 8 else 10
186
+ if drum_kit != "None":
187
+ patches[9] = drum_kits2number[drum_kit]
188
+ for i, (c, p) in enumerate(patches.items()):
189
+ mid.append(tokenizer.event2tokens(["patch_change", 0, 0, i + 1, c, p]))
190
+ mid = np.asarray([mid] * OUTPUT_BATCH_SIZE, dtype=np.int64)
191
+ mid_seq = mid.tolist()
192
+ if len(instruments) > 0:
193
+ disable_patch_change = True
194
+ disable_channels = [i for i in range(16) if i not in patches]
195
+ elif tab == 1 and mid is not None:
196
+ eps = 4 if reduce_cc_st else 0
197
+ mid = tokenizer.tokenize(MIDI.midi2score(mid), cc_eps=eps, tempo_eps=eps,
198
+ remap_track_channel=remap_track_channel,
199
+ add_default_instr=add_default_instr,
200
+ remove_empty_channels=remove_empty_channels)
201
+ mid = mid[:int(midi_events)]
202
+ mid = np.asarray([mid] * OUTPUT_BATCH_SIZE, dtype=np.int64)
203
+ mid_seq = mid.tolist()
204
+ elif tab == 2 and mid_seq is not None:
205
+ mid = np.asarray(mid_seq, dtype=np.int64)
206
+ if continuation_select > 0:
207
+ continuation_state.append(mid_seq)
208
+ mid = np.repeat(mid[continuation_select - 1:continuation_select], repeats=OUTPUT_BATCH_SIZE, axis=0)
209
+ mid_seq = mid.tolist()
210
+ else:
211
+ continuation_state.append(mid.shape[1])
212
+ else:
213
+ continuation_state = [0]
214
+ mid = [[tokenizer.bos_id] + [tokenizer.pad_id] * (tokenizer.max_token_seq - 1)]
215
+ mid = np.asarray([mid] * OUTPUT_BATCH_SIZE, dtype=np.int64)
216
+ mid_seq = mid.tolist()
217
+
218
+ if mid is not None:
219
+ max_len += mid.shape[1]
220
+
221
+ init_msgs = [create_msg("progress", [0, gen_events])]
222
+ if not (tab == 2 and continuation_select == 0):
223
+ for i in range(OUTPUT_BATCH_SIZE):
224
+ events = [tokenizer.tokens2event(tokens) for tokens in mid_seq[i]]
225
+ init_msgs += [create_msg("visualizer_clear", [i, tokenizer.version]),
226
+ create_msg("visualizer_append", [i, events])]
227
+ yield mid_seq, continuation_state, seed, send_msgs(init_msgs)
228
+ midi_generator = generate(model, mid, batch_size=OUTPUT_BATCH_SIZE, max_len=max_len, temp=temp,
229
+ top_p=top_p, top_k=top_k, disable_patch_change=disable_patch_change,
230
+ disable_control_change=not allow_cc, disable_channels=disable_channels,
231
+ generator=generator)
232
+ events = [list() for i in range(OUTPUT_BATCH_SIZE)]
233
+ t = time.time() + 1
234
+ for i, token_seqs in enumerate(midi_generator):
235
+ token_seqs = token_seqs.tolist()
236
+ for j in range(OUTPUT_BATCH_SIZE):
237
+ token_seq = token_seqs[j]
238
+ mid_seq[j].append(token_seq)
239
+ events[j].append(tokenizer.tokens2event(token_seq))
240
+ if time.time() - t > 0.5:
241
+ msgs = [create_msg("progress", [i + 1, gen_events])]
242
+ for j in range(OUTPUT_BATCH_SIZE):
243
+ msgs += [create_msg("visualizer_append", [j, events[j]])]
244
+ events[j] = list()
245
+ yield mid_seq, continuation_state, seed, send_msgs(msgs)
246
+ t = time.time()
247
+ yield mid_seq, continuation_state, seed, send_msgs([])
248
+
249
+
250
+ def finish_run(model_name, mid_seq):
251
+ if mid_seq is None:
252
+ outputs = [None] * OUTPUT_BATCH_SIZE
253
+ return *outputs, []
254
+ tokenizer = models[model_name].tokenizer
255
+ outputs = []
256
+ end_msgs = [create_msg("progress", [0, 0])]
257
+ if not os.path.exists("outputs"):
258
+ os.mkdir("outputs")
259
+ for i in range(OUTPUT_BATCH_SIZE):
260
+ events = [tokenizer.tokens2event(tokens) for tokens in mid_seq[i]]
261
+ mid = tokenizer.detokenize(mid_seq[i])
262
+ with open(f"outputs/output{i + 1}.mid", 'wb') as f:
263
+ f.write(MIDI.score2midi(mid))
264
+ outputs.append(f"outputs/output{i + 1}.mid")
265
+ end_msgs += [create_msg("visualizer_clear", [i, tokenizer.version]),
266
+ create_msg("visualizer_append", [i, events]),
267
+ create_msg("visualizer_end", i)]
268
+ return *outputs, send_msgs(end_msgs)
269
+
270
+
271
+ def synthesis_task(mid):
272
+ return synthesizer.synthesis(MIDI.score2opus(mid))
273
+
274
+ def render_audio(model_name, mid_seq, should_render_audio):
275
+ if (not should_render_audio) or mid_seq is None:
276
+ outputs = [None] * OUTPUT_BATCH_SIZE
277
+ return tuple(outputs)
278
+ tokenizer = models[model_name].tokenizer
279
+ outputs = []
280
+ if not os.path.exists("outputs"):
281
+ os.mkdir("outputs")
282
+ audio_futures = []
283
+ for i in range(OUTPUT_BATCH_SIZE):
284
+ mid = tokenizer.detokenize(mid_seq[i])
285
+ audio_future = thread_pool.submit(synthesis_task, mid)
286
+ audio_futures.append(audio_future)
287
+ for future in audio_futures:
288
+ outputs.append((44100, future.result()))
289
+ if OUTPUT_BATCH_SIZE == 1:
290
+ return outputs[0]
291
+ return tuple(outputs)
292
+
293
+
294
+ def undo_continuation(model_name, mid_seq, continuation_state):
295
+ if mid_seq is None or len(continuation_state) < 2:
296
+ return mid_seq, continuation_state, send_msgs([])
297
+ tokenizer = models[model_name].tokenizer
298
+ if isinstance(continuation_state[-1], list):
299
+ mid_seq = continuation_state[-1]
300
+ else:
301
+ mid_seq = [ms[:continuation_state[-1]] for ms in mid_seq]
302
+ continuation_state = continuation_state[:-1]
303
+ end_msgs = [create_msg("progress", [0, 0])]
304
+ for i in range(OUTPUT_BATCH_SIZE):
305
+ events = [tokenizer.tokens2event(tokens) for tokens in mid_seq[i]]
306
+ end_msgs += [create_msg("visualizer_clear", [i, tokenizer.version]),
307
+ create_msg("visualizer_append", [i, events]),
308
+ create_msg("visualizer_end", i)]
309
+ return mid_seq, continuation_state, send_msgs(end_msgs)
310
+
311
+
312
+ def load_javascript(dir="javascript"):
313
+ scripts_list = glob.glob(f"{dir}/*.js")
314
+ javascript = ""
315
+ for path in scripts_list:
316
+ with open(path, "r", encoding="utf8") as jsfile:
317
+ js_content = jsfile.read()
318
+ js_content = js_content.replace("const MIDI_OUTPUT_BATCH_SIZE=4;",
319
+ f"const MIDI_OUTPUT_BATCH_SIZE={OUTPUT_BATCH_SIZE};")
320
+ javascript += f"\n<!-- {path} --><script>{js_content}</script>"
321
+ template_response_ori = gr.routes.templates.TemplateResponse
322
+
323
+ def template_response(*args, **kwargs):
324
+ res = template_response_ori(*args, **kwargs)
325
+ res.body = res.body.replace(
326
+ b'</head>', f'{javascript}</head>'.encode("utf8"))
327
+ res.init_headers()
328
+ return res
329
+
330
+ gr.routes.templates.TemplateResponse = template_response
331
+
332
+
333
+ def hf_hub_download_retry(repo_id, filename):
334
+ print(f"downloading {repo_id} {filename}")
335
+ retry = 0
336
+ err = None
337
+ while retry < 30:
338
+ try:
339
+ return hf_hub_download(repo_id=repo_id, filename=filename)
340
+ except Exception as e:
341
+ err = e
342
+ retry += 1
343
+ if err:
344
+ raise err
345
+
346
+
347
+ number2drum_kits = {-1: "None", 0: "Standard", 8: "Room", 16: "Power", 24: "Electric", 25: "TR-808", 32: "Jazz",
348
+ 40: "Brush", 48: "Orchestra"}
349
+ patch2number = {v: k for k, v in MIDI.Number2patch.items()}
350
+ drum_kits2number = {v: k for k, v in number2drum_kits.items()}
351
+ key_signatures = ['C♭', 'A♭m', 'G♭', 'E♭m', 'D♭', 'B♭m', 'A♭', 'Fm', 'E♭', 'Cm', 'B♭', 'Gm', 'F', 'Dm',
352
+ 'C', 'Am', 'G', 'Em', 'D', 'Bm', 'A', 'F♯m', 'E', 'C♯m', 'B', 'G♯m', 'F♯', 'D♯m', 'C♯', 'A♯m']
353
+
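The 30 keys above are ordered so that the 1-based radio index k ("auto" occupies index 0) decodes in run() as sf = (k - 1) // 2 - 7 accidentals and mi = (k - 1) % 2 (0 = major, 1 = minor); for example:

    k = key_signatures.index('C') + 1   # 'C' sits at list index 14, so k = 15
    sf = (k - 1) // 2 - 7               # 0 sharps/flats
    mi = (k - 1) % 2                    # 0 -> major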
354
+ if __name__ == "__main__":
355
+ parser = argparse.ArgumentParser()
356
+ parser.add_argument("--share", action="store_true", default=False, help="share gradio app")
357
+ parser.add_argument("--port", type=int, default=7860, help="gradio server port")
358
+ parser.add_argument("--device", type=str, default="cuda", help="device to run model")
359
+ parser.add_argument("--batch", type=int, default=8, help="batch size")
360
+ parser.add_argument("--max-gen", type=int, default=1024, help="max number of midi events to generate")
361
+ opt = parser.parse_args()
362
+ OUTPUT_BATCH_SIZE = opt.batch
363
+ soundfont_path = hf_hub_download_retry(repo_id="skytnt/midi-model", filename="soundfont.sf2")
364
+ thread_pool = ThreadPoolExecutor(max_workers=OUTPUT_BATCH_SIZE)
365
+ synthesizer = MidiSynthesizer(soundfont_path)
366
+ models_info = {
367
+ "generic pretrain model (tv2o-medium) by skytnt": [
368
+ "skytnt/midi-model-tv2o-medium", {
369
+ "jpop": "skytnt/midi-model-tv2om-jpop-lora",
370
+ "touhou": "skytnt/midi-model-tv2om-touhou-lora"
371
+ }
372
+ ],
373
+ "generic pretrain model (tv2o-large) by asigalov61": [
374
+ "asigalov61/Music-Llama", {}
375
+ ],
376
+ "generic pretrain model (tv2o-medium) by asigalov61": [
377
+ "asigalov61/Music-Llama-Medium", {}
378
+ ],
379
+ "generic pretrain model (tv1-medium) by skytnt": [
380
+ "skytnt/midi-model", {}
381
+ ]
382
+ }
383
+ models = {}
384
+ if opt.device == "cuda":
385
+ torch.backends.cudnn.deterministic = True
386
+ torch.backends.cudnn.benchmark = False
387
+ torch.backends.cuda.matmul.allow_tf32 = True
388
+ torch.backends.cudnn.allow_tf32 = True
389
+ torch.backends.cuda.enable_mem_efficient_sdp(True)
390
+ torch.backends.cuda.enable_flash_sdp(True)
391
+ for name, (repo_id, loras) in models_info.items():
392
+ model = MIDIModel.from_pretrained(repo_id)
393
+ model.to(device="cpu", dtype=torch.float32)
394
+ models[name] = model
395
+ for lora_name, lora_repo in loras.items():
396
+ model = MIDIModel.from_pretrained(repo_id)
397
+ print(f"loading lora {lora_repo} for {name}")
398
+ model = model.load_merge_lora(lora_repo)
399
+ model.to(device="cpu", dtype=torch.float32)
400
+ models[f"{name} with {lora_name} lora"] = model
401
+
402
+ load_javascript()
403
+ app = gr.Blocks(theme=gr.themes.Soft())
404
+ with app:
405
+ gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>Midi Composer</h1>")
406
+ gr.Markdown("![Visitors](https://api.visitorbadge.io/api/visitors?path=skytnt.midi-composer&style=flat)\n\n"
407
+ "Midi event transformer for symbolic music generation\n\n"
408
+ "Demo for [SkyTNT/midi-model](https://github.com/SkyTNT/midi-model)\n\n"
409
+ "[Open In Colab]"
410
+ "(https://colab.research.google.com/github/SkyTNT/midi-model/blob/main/demo.ipynb)"
411
+ " or [download windows app](https://github.com/SkyTNT/midi-model/releases)"
412
+ " for unlimited generation\n\n"
413
+ "**Update v1.3**: MIDITokenizerV2 and new MidiVisualizer\n\n"
414
+ "The current **best** model: generic pretrain model (tv2o-medium) by skytnt"
415
+ )
416
+ js_msg = gr.Textbox(elem_id="msg_receiver", visible=False)
417
+ js_msg.change(None, [js_msg], [], js="""
418
+ (msg_json) =>{
419
+ let msgs = JSON.parse(msg_json);
420
+ executeCallbacks(msgReceiveCallbacks, msgs);
421
+ return [];
422
+ }
423
+ """)
424
+ input_model = gr.Dropdown(label="select model", choices=list(models.keys()),
425
+ type="value", value=list(models.keys())[0])
426
+ tab_select = gr.State(value=0)
427
+ with gr.Tabs():
428
+ with gr.TabItem("custom prompt") as tab1:
429
+ input_instruments = gr.Dropdown(label="🪗instruments (auto if empty)", choices=list(patch2number.keys()),
430
+ multiselect=True, max_choices=15, type="value")
431
+ input_drum_kit = gr.Dropdown(label="🥁drum kit", choices=list(drum_kits2number.keys()), type="value",
432
+ value="None")
433
+ input_bpm = gr.Slider(label="BPM (beats per minute, auto if 0)", minimum=0, maximum=255,
434
+ step=1,
435
+ value=0)
436
+ input_time_sig = gr.Radio(label="time signature (only for tv2 models)",
437
+ value="auto",
438
+ choices=["auto", "4/4", "2/4", "3/4", "6/4", "7/4",
439
+ "2/2", "3/2", "4/2", "3/8", "5/8", "6/8", "7/8", "9/8", "12/8"]
440
+ )
441
+ input_key_sig = gr.Radio(label="key signature (only for tv2 models)",
442
+ value="auto",
443
+ choices=["auto"] + key_signatures,
444
+ type="index"
445
+ )
446
+ example1 = gr.Examples([
447
+ [[], "None"],
448
+ [["Acoustic Grand"], "None"],
449
+ [['Acoustic Grand', 'SynthStrings 2', 'SynthStrings 1', 'Pizzicato Strings',
450
+ 'Pad 2 (warm)', 'Tremolo Strings', 'String Ensemble 1'], "Orchestra"],
451
+ [['Trumpet', 'Oboe', 'Trombone', 'String Ensemble 1', 'Clarinet',
452
+ 'French Horn', 'Pad 4 (choir)', 'Bassoon', 'Flute'], "None"],
453
+ [['Flute', 'French Horn', 'Clarinet', 'String Ensemble 2', 'English Horn', 'Bassoon',
454
+ 'Oboe', 'Pizzicato Strings'], "Orchestra"],
455
+ [['Electric Piano 2', 'Lead 5 (charang)', 'Electric Bass(pick)', 'Lead 2 (sawtooth)',
456
+ 'Pad 1 (new age)', 'Orchestra Hit', 'Cello', 'Electric Guitar(clean)'], "Standard"],
457
+ [["Electric Guitar(clean)", "Electric Guitar(muted)", "Overdriven Guitar", "Distortion Guitar",
458
+ "Electric Bass(finger)"], "Standard"]
459
+ ], [input_instruments, input_drum_kit])
460
+ with gr.TabItem("midi prompt") as tab2:
461
+ input_midi = gr.File(label="input midi", file_types=[".midi", ".mid"], type="binary")
462
+ input_midi_events = gr.Slider(label="use first n midi events as prompt", minimum=1, maximum=512,
463
+ step=1,
464
+ value=128)
465
+ input_reduce_cc_st = gr.Checkbox(label="reduce control_change and set_tempo events", value=True)
466
+ input_remap_track_channel = gr.Checkbox(
467
+ label="remap tracks and channels so each track has only one channel and in order", value=True)
468
+ input_add_default_instr = gr.Checkbox(
469
+ label="add a default instrument to channels that don't have an instrument", value=True)
470
+ input_remove_empty_channels = gr.Checkbox(label="remove channels without notes", value=False)
471
+ example2 = gr.Examples([[file, 128] for file in glob.glob("example/*.mid")],
472
+ [input_midi, input_midi_events])
473
+ with gr.TabItem("last output prompt") as tab3:
474
+ gr.Markdown("Continue generating on the last output.")
475
+ input_continuation_select = gr.Radio(label="select output to continue generating", value="all",
476
+ choices=["all"] + [f"output{i + 1}" for i in
477
+ range(OUTPUT_BATCH_SIZE)],
478
+ type="index"
479
+ )
480
+ undo_btn = gr.Button("undo the last continuation")
481
+
482
+ tab1.select(lambda: 0, None, tab_select, queue=False)
483
+ tab2.select(lambda: 1, None, tab_select, queue=False)
484
+ tab3.select(lambda: 2, None, tab_select, queue=False)
485
+ input_seed = gr.Slider(label="seed", minimum=0, maximum=2 ** 31 - 1,
486
+ step=1, value=0)
487
+ input_seed_rand = gr.Checkbox(label="random seed", value=True)
488
+ input_gen_events = gr.Slider(label="generate max n midi events", minimum=1, maximum=opt.max_gen,
489
+ step=1, value=opt.max_gen // 2)
490
+ with gr.Accordion("options", open=False):
491
+ input_temp = gr.Slider(label="temperature", minimum=0.1, maximum=1.2, step=0.01, value=1)
492
+ input_top_p = gr.Slider(label="top p", minimum=0.1, maximum=1, step=0.01, value=0.95)
493
+ input_top_k = gr.Slider(label="top k", minimum=1, maximum=128, step=1, value=20)
494
+ input_allow_cc = gr.Checkbox(label="allow midi cc event", value=True)
495
+ input_render_audio = gr.Checkbox(label="render audio after generation", value=True)
496
+ example3 = gr.Examples([[1, 0.94, 128], [1, 0.98, 20], [1, 0.98, 12]],
497
+ [input_temp, input_top_p, input_top_k])
498
+ run_btn = gr.Button("generate", variant="primary")
499
+ # stop_btn = gr.Button("stop and output")
500
+ output_midi_seq = gr.State()
501
+ output_continuation_state = gr.State([0])
502
+ midi_outputs = []
503
+ audio_outputs = []
504
+ with gr.Tabs(elem_id="output_tabs"):
505
+ for i in range(OUTPUT_BATCH_SIZE):
506
+ with gr.TabItem(f"output {i + 1}") as tab1:
507
+ output_midi_visualizer = gr.HTML(elem_id=f"midi_visualizer_container_{i}")
508
+ output_audio = gr.Audio(label="output audio", format="mp3", elem_id=f"midi_audio_{i}")
509
+ output_midi = gr.File(label="output midi", file_types=[".mid"])
510
+ midi_outputs.append(output_midi)
511
+ audio_outputs.append(output_audio)
512
+ run_event = run_btn.click(run, [input_model, tab_select, output_midi_seq, output_continuation_state,
513
+ input_continuation_select, input_instruments, input_drum_kit, input_bpm,
514
+ input_time_sig, input_key_sig, input_midi, input_midi_events,
515
+ input_reduce_cc_st, input_remap_track_channel,
516
+ input_add_default_instr, input_remove_empty_channels,
517
+ input_seed, input_seed_rand, input_gen_events, input_temp, input_top_p,
518
+ input_top_k, input_allow_cc],
519
+ [output_midi_seq, output_continuation_state, input_seed, js_msg], queue=True)
520
+ finish_run_event = run_event.then(fn=finish_run,
521
+ inputs=[input_model, output_midi_seq],
522
+ outputs=midi_outputs + [js_msg],
523
+ queue=False)
524
+ finish_run_event.then(fn=render_audio,
525
+ inputs=[input_model, output_midi_seq, input_render_audio],
526
+ outputs=audio_outputs,
527
+ queue=False)
528
+ # stop_btn.click(None, [], [], cancels=run_event,
529
+ # queue=False)
530
+ undo_btn.click(undo_continuation, [input_model, output_midi_seq, output_continuation_state],
531
+ [output_midi_seq, output_continuation_state, js_msg], queue=False)
532
+ app.queue().launch(server_port=opt.port, share=opt.share, inbrowser=True, ssr_mode=False)
533
+ thread_pool.shutdown()
app_onnx.py ADDED
@@ -0,0 +1,625 @@
1
+ import spaces
2
+ import random
3
+ import argparse
4
+ import glob
5
+ import json
6
+ import os
7
+ import time
8
+ from concurrent.futures import ThreadPoolExecutor
9
+
10
+ import gradio as gr
11
+ import numpy as np
12
+ import onnxruntime as rt
13
+ import tqdm
14
+ from huggingface_hub import hf_hub_download
15
+
16
+ import MIDI
17
+ from midi_synthesizer import MidiSynthesizer
18
+ from midi_tokenizer import MIDITokenizer
19
+
20
+ MAX_SEED = np.iinfo(np.int32).max
21
+ in_space = os.getenv("SYSTEM") == "spaces"
22
+
23
+
24
+ def softmax(x, axis):
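+ # Subtract the per-row max before exponentiating for numerical stability.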
25
+ x_max = np.amax(x, axis=axis, keepdims=True)
26
+ exp_x_shifted = np.exp(x - x_max)
27
+ return exp_x_shifted / np.sum(exp_x_shifted, axis=axis, keepdims=True)
28
+
29
+
30
+ def sample_top_p_k(probs, p, k, generator=None):
31
+ if generator is None:
32
+ generator = np.random
33
+ probs_idx = np.argsort(-probs, axis=-1)
34
+ probs_sort = np.take_along_axis(probs, probs_idx, -1)
35
+ probs_sum = np.cumsum(probs_sort, axis=-1)
36
+ mask = probs_sum - probs_sort > p
37
+ probs_sort[mask] = 0.0
38
+ mask = np.zeros(probs_sort.shape[-1])
39
+ mask[:k] = 1
40
+ probs_sort = probs_sort * mask
41
+ probs_sort /= np.sum(probs_sort, axis=-1, keepdims=True)
42
+ shape = probs_sort.shape
43
+ probs_sort_flat = probs_sort.reshape(-1, shape[-1])
44
+ probs_idx_flat = probs_idx.reshape(-1, shape[-1])
45
+ next_token = np.stack([generator.choice(idxs, p=pvals) for pvals, idxs in zip(probs_sort_flat, probs_idx_flat)])
46
+ next_token = next_token.reshape(*shape[:-1])
47
+ return next_token
48
+
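A quick sanity check of sample_top_p_k with hypothetical values:

    probs = softmax(np.array([[4.0, 3.0, 1.0, 0.5]]), axis=-1)
    token = sample_top_p_k(probs, p=0.9, k=2)   # only ids 0 or 1 survive the top-k cut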
49
+
50
+ def apply_io_binding(model: rt.InferenceSession, inputs, outputs, batch_size, past_len, cur_len):
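+ # Pre-binds all inputs/outputs to on-device OrtValues: each "past_key_values"
+ # input reuses the previous step's "present" output, so the KV cache never
+ # leaves the device between decoding steps.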
51
+ io_binding = model.io_binding()
52
+ for input_ in model.get_inputs():
53
+ name = input_.name
54
+ if name.startswith("past_key_values"):
55
+ present_name = name.replace("past_key_values", "present")
56
+ if present_name in outputs:
57
+ v = outputs[present_name]
58
+ else:
59
+ v = rt.OrtValue.ortvalue_from_shape_and_type(
60
+ (batch_size, input_.shape[1], past_len, input_.shape[3]),
61
+ element_type=np.float32,
62
+ device_type=device)
63
+ inputs[name] = v
64
+ else:
65
+ v = inputs[name]
66
+ io_binding.bind_ortvalue_input(name, v)
67
+
68
+ for output in model.get_outputs():
69
+ name = output.name
70
+ if name.startswith("present"):
71
+ v = rt.OrtValue.ortvalue_from_shape_and_type(
72
+ (batch_size, output.shape[1], cur_len, output.shape[3]),
73
+ element_type=np.float32,
74
+ device_type=device)
75
+ outputs[name] = v
76
+ else:
77
+ v = outputs[name]
78
+ io_binding.bind_ortvalue_output(name, v)
79
+ return io_binding
80
+
81
+ def generate(model, prompt=None, batch_size=1, max_len=512, temp=1.0, top_p=0.98, top_k=20,
82
+ disable_patch_change=False, disable_control_change=False, disable_channels=None, generator=None):
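+ # model is [base_session, token_session, tokenizer]: the base ONNX graph
+ # produces one hidden state per event, and the token graph decodes that
+ # event's tokens from the hidden state.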
83
+ tokenizer = model[2]
84
+ if disable_channels is not None:
85
+ disable_channels = [tokenizer.parameter_ids["channel"][c] for c in disable_channels]
86
+ else:
87
+ disable_channels = []
88
+ if generator is None:
89
+ generator = np.random
90
+ max_token_seq = tokenizer.max_token_seq
91
+ if prompt is None:
92
+ input_tensor = np.full((1, max_token_seq), tokenizer.pad_id, dtype=np.int64)
93
+ input_tensor[0, 0] = tokenizer.bos_id # bos
94
+ input_tensor = input_tensor[None, :, :]
95
+ input_tensor = np.repeat(input_tensor, repeats=batch_size, axis=0)
96
+ else:
97
+ if len(prompt.shape) == 2:
98
+ prompt = prompt[None, :]
99
+ prompt = np.repeat(prompt, repeats=batch_size, axis=0)
100
+ elif prompt.shape[0] == 1:
101
+ prompt = np.repeat(prompt, repeats=batch_size, axis=0)
102
+ elif len(prompt.shape) != 3 or prompt.shape[0] != batch_size:
103
+ raise ValueError(f"invalid shape for prompt, {prompt.shape}")
104
+ prompt = prompt[..., :max_token_seq]
105
+ if prompt.shape[-1] < max_token_seq:
106
+ prompt = np.pad(prompt, ((0, 0), (0, 0), (0, max_token_seq - prompt.shape[-1])),
107
+ mode="constant", constant_values=tokenizer.pad_id)
108
+ input_tensor = prompt
109
+ cur_len = input_tensor.shape[1]
110
+ bar = tqdm.tqdm(desc="generating", total=max_len - cur_len, disable=in_space)
111
+ model0_inputs = {}
112
+ model0_outputs = {}
113
+ emb_size = 1024
114
+ for output in model[0].get_outputs():
115
+ if output.name == "hidden":
116
+ emb_size = output.shape[2]
117
+ past_len = 0
118
+ with bar:
119
+ while cur_len < max_len:
120
+ end = [False] * batch_size
121
+ model0_inputs["x"] = rt.OrtValue.ortvalue_from_numpy(input_tensor[:, past_len:], device_type=device)
122
+ model0_outputs["hidden"] = rt.OrtValue.ortvalue_from_shape_and_type(
123
+ (batch_size, cur_len - past_len, emb_size),
124
+ element_type=np.float32,
125
+ device_type=device)
126
+ io_binding = apply_io_binding(model[0], model0_inputs, model0_outputs, batch_size, past_len, cur_len)
127
+ io_binding.synchronize_inputs()
128
+ model[0].run_with_iobinding(io_binding)
129
+ io_binding.synchronize_outputs()
130
+
131
+ hidden = model0_outputs["hidden"].numpy()[:, -1:]
132
+ next_token_seq = np.zeros((batch_size, 0), dtype=np.int64)
133
+ event_names = [""] * batch_size
134
+ model1_inputs = {"hidden": rt.OrtValue.ortvalue_from_numpy(hidden, device_type=device)}
135
+ model1_outputs = {}
136
+ for i in range(max_token_seq):
137
+ mask = np.zeros((batch_size, tokenizer.vocab_size), dtype=np.int64)
138
+ for b in range(batch_size):
139
+ if end[b]:
140
+ mask[b, tokenizer.pad_id] = 1
141
+ continue
142
+ if i == 0:
143
+ mask_ids = list(tokenizer.event_ids.values()) + [tokenizer.eos_id]
144
+ if disable_patch_change:
145
+ mask_ids.remove(tokenizer.event_ids["patch_change"])
146
+ if disable_control_change:
147
+ mask_ids.remove(tokenizer.event_ids["control_change"])
148
+ mask[b, mask_ids] = 1
149
+ else:
150
+ param_names = tokenizer.events[event_names[b]]
151
+ if i > len(param_names):
152
+ mask[b, tokenizer.pad_id] = 1
153
+ continue
154
+ param_name = param_names[i - 1]
155
+ mask_ids = tokenizer.parameter_ids[param_name]
156
+ if param_name == "channel":
157
+ mask_ids = [i for i in mask_ids if i not in disable_channels]
158
+ mask[b, mask_ids] = 1
159
+ mask = mask[:, None, :]
160
+ x = next_token_seq
161
+ if i != 0:
162
+ # cached
163
+ if i == 1:
164
+ hidden = np.zeros((batch_size, 0, emb_size), dtype=np.float32)
165
+ model1_inputs["hidden"] = rt.OrtValue.ortvalue_from_numpy(hidden, device_type=device)
166
+ x = x[:, -1:]
167
+ model1_inputs["x"] = rt.OrtValue.ortvalue_from_numpy(x, device_type=device)
168
+ model1_outputs["y"] = rt.OrtValue.ortvalue_from_shape_and_type(
169
+ (batch_size, 1, tokenizer.vocab_size),
170
+ element_type=np.float32,
171
+ device_type=device
172
+ )
173
+ io_binding = apply_io_binding(model[1], model1_inputs, model1_outputs, batch_size, i, i+1)
174
+ io_binding.synchronize_inputs()
175
+ model[1].run_with_iobinding(io_binding)
176
+ io_binding.synchronize_outputs()
177
+ logits = model1_outputs["y"].numpy()
178
+ scores = softmax(logits / temp, -1) * mask
179
+ samples = sample_top_p_k(scores, top_p, top_k, generator)
180
+ if i == 0:
181
+ next_token_seq = samples
182
+ for b in range(batch_size):
183
+ if end[b]:
184
+ continue
185
+ eid = samples[b].item()
186
+ if eid == tokenizer.eos_id:
187
+ end[b] = True
188
+ else:
189
+ event_names[b] = tokenizer.id_events[eid]
190
+ else:
191
+ next_token_seq = np.concatenate([next_token_seq, samples], axis=1)
192
+ if all([len(tokenizer.events[event_names[b]]) == i for b in range(batch_size) if not end[b]]):
193
+ break
194
+ if next_token_seq.shape[1] < max_token_seq:
195
+ next_token_seq = np.pad(next_token_seq,
196
+ ((0, 0), (0, max_token_seq - next_token_seq.shape[-1])),
197
+ mode="constant", constant_values=tokenizer.pad_id)
198
+ next_token_seq = next_token_seq[:, None, :]
199
+ input_tensor = np.concatenate([input_tensor, next_token_seq], axis=1)
200
+ past_len = cur_len
201
+ cur_len += 1
202
+ bar.update(1)
203
+ yield next_token_seq[:, 0]
204
+ if all(end):
205
+ break
206
+
207
+
208
+ def create_msg(name, data):
209
+ return {"name": name, "data": data}
210
+
211
+
212
+ def send_msgs(msgs):
213
+ return json.dumps(msgs)
214
+
215
+
216
+ def get_duration(model_name, tab, mid_seq, continuation_state, continuation_select, instruments, drum_kit, bpm,
217
+ time_sig, key_sig, mid, midi_events, reduce_cc_st, remap_track_channel, add_default_instr,
218
+ remove_empty_channels, seed, seed_rand, gen_events, temp, top_p, top_k, allow_cc):
219
+ t = gen_events // 28
220
+ if "large" in model_name:
221
+ t = gen_events // 20
222
+ return t + 10
223
+
224
+
225
+ @spaces.GPU(duration=get_duration)
226
+ def run(model_name, tab, mid_seq, continuation_state, continuation_select, instruments, drum_kit, bpm, time_sig,
227
+ key_sig, mid, midi_events, reduce_cc_st, remap_track_channel, add_default_instr, remove_empty_channels,
228
+ seed, seed_rand, gen_events, temp, top_p, top_k, allow_cc):
229
+ model = models[model_name]
230
+ model_base = rt.InferenceSession(model[0], providers=providers)
231
+ model_token = rt.InferenceSession(model[1], providers=providers)
232
+ tokenizer = model[2]
233
+ model = [model_base, model_token, tokenizer]
234
+ bpm = int(bpm)
235
+ if time_sig == "auto":
236
+ time_sig = None
237
+ time_sig_nn = 4
238
+ time_sig_dd = 2
239
+ else:
240
+ time_sig_nn, time_sig_dd = time_sig.split('/')
241
+ time_sig_nn = int(time_sig_nn)
242
+ time_sig_dd = {2: 1, 4: 2, 8: 3}[int(time_sig_dd)]
243
+ if key_sig == 0:
244
+ key_sig = None
245
+ key_sig_sf = 0
246
+ key_sig_mi = 0
247
+ else:
248
+ key_sig = (key_sig - 1)
249
+ key_sig_sf = key_sig // 2 - 7
250
+ key_sig_mi = key_sig % 2
251
+ gen_events = int(gen_events)
252
+ max_len = gen_events
253
+ if seed_rand:
254
+ seed = random.randint(0, MAX_SEED)
255
+ generator = np.random.RandomState(seed)
256
+ disable_patch_change = False
257
+ disable_channels = None
258
+ if tab == 0:
259
+ i = 0
260
+ mid = [[tokenizer.bos_id] + [tokenizer.pad_id] * (tokenizer.max_token_seq - 1)]
261
+ if tokenizer.version == "v2":
262
+ if time_sig is not None:
263
+ mid.append(tokenizer.event2tokens(["time_signature", 0, 0, 0, time_sig_nn - 1, time_sig_dd - 1]))
264
+ if key_sig is not None:
265
+ mid.append(tokenizer.event2tokens(["key_signature", 0, 0, 0, key_sig_sf + 7, key_sig_mi]))
266
+ if bpm != 0:
267
+ mid.append(tokenizer.event2tokens(["set_tempo", 0, 0, 0, bpm]))
268
+ patches = {}
269
+ if instruments is None:
270
+ instruments = []
271
+ for instr in instruments:
272
+ patches[i] = patch2number[instr]
273
+ i = (i + 1) if i != 8 else 10
274
+ if drum_kit != "None":
275
+ patches[9] = drum_kits2number[drum_kit]
276
+ for i, (c, p) in enumerate(patches.items()):
277
+ mid.append(tokenizer.event2tokens(["patch_change", 0, 0, i + 1, c, p]))
278
+ mid = np.asarray([mid] * OUTPUT_BATCH_SIZE, dtype=np.int64)
279
+ mid_seq = mid.tolist()
280
+ if len(instruments) > 0:
281
+ disable_patch_change = True
282
+ disable_channels = [i for i in range(16) if i not in patches]
283
+ elif tab == 1 and mid is not None:
284
+ eps = 4 if reduce_cc_st else 0
285
+ mid = tokenizer.tokenize(MIDI.midi2score(mid), cc_eps=eps, tempo_eps=eps,
286
+ remap_track_channel=remap_track_channel,
287
+ add_default_instr=add_default_instr,
288
+ remove_empty_channels=remove_empty_channels)
289
+ mid = mid[:int(midi_events)]
290
+ mid = np.asarray([mid] * OUTPUT_BATCH_SIZE, dtype=np.int64)
291
+ mid_seq = mid.tolist()
292
+ elif tab == 2 and mid_seq is not None:
293
+ mid = np.asarray(mid_seq, dtype=np.int64)
294
+ if continuation_select > 0:
295
+ continuation_state.append(mid_seq)
296
+ mid = np.repeat(mid[continuation_select - 1:continuation_select], repeats=OUTPUT_BATCH_SIZE, axis=0)
297
+ mid_seq = mid.tolist()
298
+ else:
299
+ continuation_state.append(mid.shape[1])
300
+ else:
301
+ continuation_state = [0]
302
+ mid = [[tokenizer.bos_id] + [tokenizer.pad_id] * (tokenizer.max_token_seq - 1)]
303
+ mid = np.asarray([mid] * OUTPUT_BATCH_SIZE, dtype=np.int64)
304
+ mid_seq = mid.tolist()
305
+
306
+ if mid is not None:
307
+ max_len += mid.shape[1]
308
+
309
+ init_msgs = [create_msg("progress", [0, gen_events])]
310
+ if not (tab == 2 and continuation_select == 0):
311
+ for i in range(OUTPUT_BATCH_SIZE):
312
+ events = [tokenizer.tokens2event(tokens) for tokens in mid_seq[i]]
313
+ init_msgs += [create_msg("visualizer_clear", [i, tokenizer.version]),
314
+ create_msg("visualizer_append", [i, events])]
315
+ yield mid_seq, continuation_state, seed, send_msgs(init_msgs)
316
+ midi_generator = generate(model, mid, batch_size=OUTPUT_BATCH_SIZE, max_len=max_len, temp=temp,
317
+ top_p=top_p, top_k=top_k, disable_patch_change=disable_patch_change,
318
+ disable_control_change=not allow_cc, disable_channels=disable_channels,
319
+ generator=generator)
320
+ events = [list() for i in range(OUTPUT_BATCH_SIZE)]
321
+ t = time.time() + 1
322
+ for i, token_seqs in enumerate(midi_generator):
323
+ token_seqs = token_seqs.tolist()
324
+ for j in range(OUTPUT_BATCH_SIZE):
325
+ token_seq = token_seqs[j]
326
+ mid_seq[j].append(token_seq)
327
+ events[j].append(tokenizer.tokens2event(token_seq))
328
+ if time.time() - t > 0.5:
329
+ msgs = [create_msg("progress", [i + 1, gen_events])]
330
+ for j in range(OUTPUT_BATCH_SIZE):
331
+ msgs += [create_msg("visualizer_append", [j, events[j]])]
332
+ events[j] = list()
333
+ yield mid_seq, continuation_state, seed, send_msgs(msgs)
334
+ t = time.time()
335
+ yield mid_seq, continuation_state, seed, send_msgs([])
336
+
337
+
338
+ def finish_run(model_name, mid_seq):
339
+ if mid_seq is None:
340
+ outputs = [None] * OUTPUT_BATCH_SIZE
341
+ return *outputs, []
342
+ tokenizer = models[model_name][2]
343
+ outputs = []
344
+ end_msgs = [create_msg("progress", [0, 0])]
345
+ if not os.path.exists("outputs"):
346
+ os.mkdir("outputs")
347
+ for i in range(OUTPUT_BATCH_SIZE):
348
+ events = [tokenizer.tokens2event(tokens) for tokens in mid_seq[i]]
349
+ mid = tokenizer.detokenize(mid_seq[i])
350
+ with open(f"outputs/output{i + 1}.mid", 'wb') as f:
351
+ f.write(MIDI.score2midi(mid))
352
+ outputs.append(f"outputs/output{i + 1}.mid")
353
+ end_msgs += [create_msg("visualizer_clear", [i, tokenizer.version]),
354
+ create_msg("visualizer_append", [i, events]),
355
+ create_msg("visualizer_end", i)]
356
+ return *outputs, send_msgs(end_msgs)
357
+
358
+
359
+ def synthesis_task(mid):
360
+ return synthesizer.synthesis(MIDI.score2opus(mid))
361
+
362
+ def render_audio(model_name, mid_seq, should_render_audio):
363
+ if (not should_render_audio) or mid_seq is None:
364
+ outputs = [None] * OUTPUT_BATCH_SIZE
365
+ return tuple(outputs)
366
+ tokenizer = models[model_name][2]
367
+ outputs = []
368
+ if not os.path.exists("outputs"):
369
+ os.mkdir("outputs")
370
+ audio_futures = []
371
+ for i in range(OUTPUT_BATCH_SIZE):
372
+ mid = tokenizer.detokenize(mid_seq[i])
373
+ audio_future = thread_pool.submit(synthesis_task, mid)
374
+ audio_futures.append(audio_future)
375
+ for future in audio_futures:
376
+ outputs.append((44100, future.result()))
377
+ if OUTPUT_BATCH_SIZE == 1:
378
+ return outputs[0]
379
+ return tuple(outputs)
380
+
381
+
382
+ def undo_continuation(model_name, mid_seq, continuation_state):
383
+ if mid_seq is None or len(continuation_state) < 2:
384
+ return mid_seq, continuation_state, send_msgs([])
385
+ tokenizer = models[model_name][2]
386
+ if isinstance(continuation_state[-1], list):
387
+ mid_seq = continuation_state[-1]
388
+ else:
389
+ mid_seq = [ms[:continuation_state[-1]] for ms in mid_seq]
390
+ continuation_state = continuation_state[:-1]
391
+ end_msgs = [create_msg("progress", [0, 0])]
392
+ for i in range(OUTPUT_BATCH_SIZE):
393
+ events = [tokenizer.tokens2event(tokens) for tokens in mid_seq[i]]
394
+ end_msgs += [create_msg("visualizer_clear", [i, tokenizer.version]),
395
+ create_msg("visualizer_append", [i, events]),
396
+ create_msg("visualizer_end", i)]
397
+ return mid_seq, continuation_state, send_msgs(end_msgs)
398
+
399
+
400
+ def load_javascript(dir="javascript"):
401
+ scripts_list = glob.glob(f"{dir}/*.js")
402
+ javascript = ""
403
+ for path in scripts_list:
404
+ with open(path, "r", encoding="utf8") as jsfile:
405
+ js_content = jsfile.read()
406
+ js_content = js_content.replace("const MIDI_OUTPUT_BATCH_SIZE=4;",
407
+ f"const MIDI_OUTPUT_BATCH_SIZE={OUTPUT_BATCH_SIZE};")
408
+ javascript += f"\n<!-- {path} --><script>{js_content}</script>"
409
+ template_response_ori = gr.routes.templates.TemplateResponse
410
+
411
+ def template_response(*args, **kwargs):
412
+ res = template_response_ori(*args, **kwargs)
413
+ res.body = res.body.replace(
414
+ b'</head>', f'{javascript}</head>'.encode("utf8"))
415
+ res.init_headers()
416
+ return res
417
+
418
+ gr.routes.templates.TemplateResponse = template_response
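+ # this monkey-patches gradio's TemplateResponse so every served page gets the visualizer
+ # scripts injected just before </head>, with OUTPUT_BATCH_SIZE patched into the JS at load time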
419
+
420
+
421
+ def hf_hub_download_retry(repo_id, filename):
422
+ print(f"downloading {repo_id} {filename}")
423
+ retry = 0
424
+ err = None
425
+ while retry < 30:
426
+ try:
427
+ return hf_hub_download(repo_id=repo_id, filename=filename)
428
+ except Exception as e:
429
+ err = e
430
+ retry += 1
431
+ if err:
432
+ raise err
433
+
434
+
435
+ def get_tokenizer(repo_id):
436
+ config_path = hf_hub_download_retry(repo_id=repo_id, filename="config.json")
437
+ with open(config_path, "r") as f:
438
+ config = json.load(f)
439
+ tokenizer = MIDITokenizer(config["tokenizer"]["version"])
440
+ tokenizer.set_optimise_midi(config["tokenizer"]["optimise_midi"])
441
+ return tokenizer
442
+
443
+
444
+ number2drum_kits = {-1: "None", 0: "Standard", 8: "Room", 16: "Power", 24: "Electric", 25: "TR-808", 32: "Jazz",
445
+ 40: "Brush", 48: "Orchestra"}
446
+ patch2number = {v: k for k, v in MIDI.Number2patch.items()}
447
+ drum_kits2number = {v: k for k, v in number2drum_kits.items()}
448
+ key_signatures = ['C♭', 'A♭m', 'G♭', 'E♭m', 'D♭', 'B♭m', 'A♭', 'Fm', 'E♭', 'Cm', 'B♭', 'Gm', 'F', 'Dm',
449
+ 'C', 'Am', 'G', 'Em', 'D', 'Bm', 'A', 'F♯m', 'E', 'C♯m', 'B', 'G♯m', 'F♯', 'D♯m', 'C♯', 'A♯m']
450
+
451
+ if __name__ == "__main__":
452
+ parser = argparse.ArgumentParser()
453
+ parser.add_argument("--share", action="store_true", default=False, help="share gradio app")
454
+ parser.add_argument("--port", type=int, default=7860, help="gradio server port")
455
+ parser.add_argument("--device", type=str, default="cuda", help="device to run model")
456
+ parser.add_argument("--batch", type=int, default=8, help="batch size")
457
+ parser.add_argument("--max-gen", type=int, default=1024, help="max number of midi events to generate")
458
+ opt = parser.parse_args()
459
+ OUTPUT_BATCH_SIZE = opt.batch
460
+ soundfont_path = hf_hub_download_retry(repo_id="skytnt/midi-model", filename="soundfont.sf2")
461
+ thread_pool = ThreadPoolExecutor(max_workers=OUTPUT_BATCH_SIZE)
462
+ synthesizer = MidiSynthesizer(soundfont_path)
463
+ models_info = {
464
+ "generic pretrain model (tv2o-medium) by skytnt": [
465
+ "skytnt/midi-model-tv2o-medium", "", {
466
+ "jpop": "skytnt/midi-model-tv2om-jpop-lora",
467
+ "touhou": "skytnt/midi-model-tv2om-touhou-lora"
468
+ }
469
+ ],
470
+ "generic pretrain model (tv2o-large) by asigalov61": [
471
+ "asigalov61/Music-Llama", "", {}
472
+ ],
473
+ "generic pretrain model (tv2o-medium) by asigalov61": [
474
+ "asigalov61/Music-Llama-Medium", "", {}
475
+ ],
476
+ "generic pretrain model (tv1-medium) by skytnt": [
477
+ "skytnt/midi-model", "", {}
478
+ ]
479
+ }
480
+ models = {}
481
+ providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
482
+ device = "cuda"
483
+
484
+ for name, (repo_id, path, loras) in models_info.items():
485
+ model_base_path = hf_hub_download_retry(repo_id=repo_id, filename=f"{path}onnx/model_base.onnx")
486
+ model_token_path = hf_hub_download_retry(repo_id=repo_id, filename=f"{path}onnx/model_token.onnx")
487
+ tokenizer = get_tokenizer(repo_id)
488
+ models[name] = [model_base_path, model_token_path, tokenizer]
489
+ for lora_name, lora_repo in loras.items():
490
+ model_base_path = hf_hub_download_retry(repo_id=lora_repo, filename="onnx/model_base.onnx")
491
+ model_token_path = hf_hub_download_retry(repo_id=lora_repo, filename="onnx/model_token.onnx")
492
+ models[f"{name} with {lora_name} lora"] = [model_base_path, model_token_path, tokenizer]
493
+
494
+ load_javascript()
495
+ app = gr.Blocks(theme=gr.themes.Soft())
496
+ with app:
497
+ gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>Midi Composer</h1>")
498
+ gr.Markdown("![Visitors](https://api.visitorbadge.io/api/visitors?path=skytnt.midi-composer&style=flat)\n\n"
499
+ "Midi event transformer for symbolic music generation\n\n"
500
+ "Demo for [SkyTNT/midi-model](https://github.com/SkyTNT/midi-model)\n\n"
501
+ "[Open In Colab]"
502
+ "(https://colab.research.google.com/github/SkyTNT/midi-model/blob/main/demo.ipynb)"
503
+ " or [download windows app](https://github.com/SkyTNT/midi-model/releases)"
504
+ " for unlimited generation\n\n"
505
+ "**Update v1.3**: MIDITokenizerV2 and new MidiVisualizer\n\n"
506
+ "The current **best** model: generic pretrain model (tv2o-medium) by skytnt"
507
+ )
508
+ js_msg = gr.Textbox(elem_id="msg_receiver", visible=False)
509
+ js_msg.change(None, [js_msg], [], js="""
510
+ (msg_json) =>{
511
+ let msgs = JSON.parse(msg_json);
512
+ executeCallbacks(msgReceiveCallbacks, msgs);
513
+ return [];
514
+ }
515
+ """)
516
+ input_model = gr.Dropdown(label="select model", choices=list(models.keys()),
517
+ type="value", value=list(models.keys())[0])
518
+ tab_select = gr.State(value=0)
519
+ with gr.Tabs():
520
+ with gr.TabItem("custom prompt") as tab1:
521
+ input_instruments = gr.Dropdown(label="🪗instruments (auto if empty)", choices=list(patch2number.keys()),
522
+ multiselect=True, max_choices=15, type="value")
523
+ input_drum_kit = gr.Dropdown(label="🥁drum kit", choices=list(drum_kits2number.keys()), type="value",
524
+ value="None")
525
+ input_bpm = gr.Slider(label="BPM (beats per minute, auto if 0)", minimum=0, maximum=255,
526
+ step=1,
527
+ value=0)
528
+ input_time_sig = gr.Radio(label="time signature (only for tv2 models)",
529
+ value="auto",
530
+ choices=["auto", "4/4", "2/4", "3/4", "6/4", "7/4",
531
+ "2/2", "3/2", "4/2", "3/8", "5/8", "6/8", "7/8", "9/8", "12/8"]
532
+ )
533
+ input_key_sig = gr.Radio(label="key signature (only for tv2 models)",
534
+ value="auto",
535
+ choices=["auto"] + key_signatures,
536
+ type="index"
537
+ )
538
+ example1 = gr.Examples([
539
+ [[], "None"],
540
+ [["Acoustic Grand"], "None"],
541
+ [['Acoustic Grand', 'SynthStrings 2', 'SynthStrings 1', 'Pizzicato Strings',
542
+ 'Pad 2 (warm)', 'Tremolo Strings', 'String Ensemble 1'], "Orchestra"],
543
+ [['Trumpet', 'Oboe', 'Trombone', 'String Ensemble 1', 'Clarinet',
544
+ 'French Horn', 'Pad 4 (choir)', 'Bassoon', 'Flute'], "None"],
545
+ [['Flute', 'French Horn', 'Clarinet', 'String Ensemble 2', 'English Horn', 'Bassoon',
546
+ 'Oboe', 'Pizzicato Strings'], "Orchestra"],
547
+ [['Electric Piano 2', 'Lead 5 (charang)', 'Electric Bass(pick)', 'Lead 2 (sawtooth)',
548
+ 'Pad 1 (new age)', 'Orchestra Hit', 'Cello', 'Electric Guitar(clean)'], "Standard"],
549
+ [["Electric Guitar(clean)", "Electric Guitar(muted)", "Overdriven Guitar", "Distortion Guitar",
550
+ "Electric Bass(finger)"], "Standard"]
551
+ ], [input_instruments, input_drum_kit])
552
+ with gr.TabItem("midi prompt") as tab2:
553
+ input_midi = gr.File(label="input midi", file_types=[".midi", ".mid"], type="binary")
554
+ input_midi_events = gr.Slider(label="use first n midi events as prompt", minimum=1, maximum=512,
555
+ step=1,
556
+ value=128)
557
+ input_reduce_cc_st = gr.Checkbox(label="reduce control_change and set_tempo events", value=True)
558
+ input_remap_track_channel = gr.Checkbox(
559
+ label="remap tracks and channels so that each track has a single channel, in order", value=True)
560
+ input_add_default_instr = gr.Checkbox(
561
+ label="add a default instrument to channels that don't have an instrument", value=True)
562
+ input_remove_empty_channels = gr.Checkbox(label="remove channels without notes", value=False)
563
+ example2 = gr.Examples([[file, 128] for file in glob.glob("example/*.mid")],
564
+ [input_midi, input_midi_events])
565
+ with gr.TabItem("last output prompt") as tab3:
566
+ gr.Markdown("Continue generating on the last output.")
567
+ input_continuation_select = gr.Radio(label="select output to continue generating", value="all",
568
+ choices=["all"] + [f"output{i + 1}" for i in
569
+ range(OUTPUT_BATCH_SIZE)],
570
+ type="index"
571
+ )
572
+ undo_btn = gr.Button("undo the last continuation")
573
+
574
+ tab1.select(lambda: 0, None, tab_select, queue=False)
575
+ tab2.select(lambda: 1, None, tab_select, queue=False)
576
+ tab3.select(lambda: 2, None, tab_select, queue=False)
577
+ input_seed = gr.Slider(label="seed", minimum=0, maximum=2 ** 31 - 1,
578
+ step=1, value=0)
579
+ input_seed_rand = gr.Checkbox(label="random seed", value=True)
580
+ input_gen_events = gr.Slider(label="generate max n midi events", minimum=1, maximum=opt.max_gen,
581
+ step=1, value=opt.max_gen // 2)
582
+ with gr.Accordion("options", open=False):
583
+ input_temp = gr.Slider(label="temperature", minimum=0.1, maximum=1.2, step=0.01, value=1)
584
+ input_top_p = gr.Slider(label="top p", minimum=0.1, maximum=1, step=0.01, value=0.95)
585
+ input_top_k = gr.Slider(label="top k", minimum=1, maximum=128, step=1, value=20)
586
+ input_allow_cc = gr.Checkbox(label="allow midi cc event", value=True)
587
+ input_render_audio = gr.Checkbox(label="render audio after generation", value=True)
588
+ example3 = gr.Examples([[1, 0.94, 128], [1, 0.98, 20], [1, 0.98, 12]],
589
+ [input_temp, input_top_p, input_top_k])
590
+ run_btn = gr.Button("generate", variant="primary")
591
+ # stop_btn = gr.Button("stop and output")
592
+ output_midi_seq = gr.State()
593
+ output_continuation_state = gr.State([0])
594
+ midi_outputs = []
595
+ audio_outputs = []
596
+ with gr.Tabs(elem_id="output_tabs"):
597
+ for i in range(OUTPUT_BATCH_SIZE):
598
+ with gr.TabItem(f"output {i + 1}") as tab1:
599
+ output_midi_visualizer = gr.HTML(elem_id=f"midi_visualizer_container_{i}")
600
+ output_audio = gr.Audio(label="output audio", format="mp3", elem_id=f"midi_audio_{i}")
601
+ output_midi = gr.File(label="output midi", file_types=[".mid"])
602
+ midi_outputs.append(output_midi)
603
+ audio_outputs.append(output_audio)
604
+ run_event = run_btn.click(run, [input_model, tab_select, output_midi_seq, output_continuation_state,
605
+ input_continuation_select, input_instruments, input_drum_kit, input_bpm,
606
+ input_time_sig, input_key_sig, input_midi, input_midi_events,
607
+ input_reduce_cc_st, input_remap_track_channel,
608
+ input_add_default_instr, input_remove_empty_channels,
609
+ input_seed, input_seed_rand, input_gen_events, input_temp, input_top_p,
610
+ input_top_k, input_allow_cc],
611
+ [output_midi_seq, output_continuation_state, input_seed, js_msg], queue=True)
612
+ finish_run_event = run_event.then(fn=finish_run,
613
+ inputs=[input_model, output_midi_seq],
614
+ outputs=midi_outputs + [js_msg],
615
+ queue=False)
616
+ finish_run_event.then(fn=render_audio,
617
+ inputs=[input_model, output_midi_seq, input_render_audio],
618
+ outputs=audio_outputs,
619
+ queue=False)
620
+ # stop_btn.click(None, [], [], cancels=run_event,
621
+ # queue=False)
622
+ undo_btn.click(undo_continuation, [input_model, output_midi_seq, output_continuation_state],
623
+ [output_midi_seq, output_continuation_state, js_msg], queue=False)
624
+ app.queue().launch(server_port=opt.port, share=opt.share, inbrowser=True, ssr_mode=False)
625
+ thread_pool.shutdown()
midi_model.py ADDED
@@ -0,0 +1,250 @@
1
+ import json
2
+ from typing import Union, Dict, Any
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ import tqdm
9
+ from peft import PeftConfig, LoraModel, load_peft_weights, set_peft_model_state_dict
10
+ from transformers import LlamaModel, LlamaConfig, DynamicCache, PretrainedConfig, PreTrainedModel
11
+
12
+ from midi_tokenizer import MIDITokenizerV1, MIDITokenizerV2, MIDITokenizer
13
+
14
+ config_name_list = ["tv1-medium", "tv2-medium", "tv2o-medium", "tv2-large", "tv2o-large"]
15
+
16
+
17
+ class MIDIModelConfig(PretrainedConfig):
18
+ model_type = "midi_model"
19
+
20
+ def __init__(self,
21
+ tokenizer: Union[MIDITokenizerV1, MIDITokenizerV2, Dict] = None,
22
+ net_config: Union[LlamaConfig, Dict] = None,
23
+ net_token_config: Union[LlamaConfig, Dict] = None,
24
+ **kwargs):
25
+ super().__init__(**kwargs)
26
+ if tokenizer:
27
+ if isinstance(tokenizer, dict):
28
+ self.tokenizer = MIDITokenizer(tokenizer["version"])
29
+ self.tokenizer.set_optimise_midi(tokenizer["optimise_midi"])
30
+ else:
31
+ self.tokenizer = tokenizer
32
+ else:
33
+ self.tokenizer = MIDITokenizer()
34
+ if net_config:
35
+ if isinstance(net_config, dict):
36
+ self.net_config = LlamaConfig(**net_config)
37
+ else:
38
+ self.net_config = net_config
39
+ else:
40
+ self.net_config = LlamaConfig()
41
+ if net_token_config:
42
+ if isinstance(net_token_config, dict):
43
+ self.net_token_config = LlamaConfig(**net_token_config)
44
+ else:
45
+ self.net_token_config = net_token_config
46
+ else:
47
+ self.net_token_config = LlamaConfig()
48
+ self.n_embd = self.net_token_config.hidden_size
49
+
50
+ def to_dict(self) -> Dict[str, Any]:
51
+ d = super().to_dict()
52
+ d["tokenizer"] = self.tokenizer.to_dict()
53
+ return d
54
+
55
+ def __str__(self):
56
+ d = {
57
+ "net": self.net_config.to_json_string(use_diff=False),
58
+ "net_token": self.net_token_config.to_json_string(use_diff=False)
59
+ }
60
+ return json.dumps(d, indent=4)
61
+
62
+ @staticmethod
63
+ def get_config(tokenizer_ver="v2", optimise_midi=True, n_layer=12, n_head=16, n_embd=1024, n_inner=4096):
64
+ tokenizer = MIDITokenizer(tokenizer_ver)
65
+ tokenizer.set_optimise_midi(optimise_midi)
66
+ net_config = LlamaConfig(vocab_size=tokenizer.vocab_size,
67
+ hidden_size=n_embd, num_attention_heads=n_head,
68
+ num_hidden_layers=n_layer, intermediate_size=n_inner,
69
+ pad_token_id=tokenizer.pad_id, max_position_embeddings=4096,
70
+ use_cache=False)
71
+ net_token_config = LlamaConfig(vocab_size=tokenizer.vocab_size,
72
+ hidden_size=n_embd, num_attention_heads=n_head // 4,
73
+ num_hidden_layers=n_layer // 4, intermediate_size=n_inner // 4,
74
+ pad_token_id=tokenizer.pad_id, max_position_embeddings=4096,
75
+ use_cache=False)
76
+ return MIDIModelConfig(tokenizer, net_config, net_token_config)
77
+
78
+ @staticmethod
79
+ def from_name(name="tv2o-medium"):
80
+ tv, size = name.split("-")
81
+ tv = tv[1:]
82
+ if tv[-1] == "o":
83
+ o = True
84
+ tv = tv[:-1]
85
+ else:
86
+ o = False
87
+ if tv not in ["v1", "v2"]:
88
+ raise ValueError(f"Unknown tokenizer version {tv}")
89
+ if size == "medium":
90
+ return MIDIModelConfig.get_config(tokenizer_ver=tv, optimise_midi=o,
91
+ n_layer=12, n_head=16, n_embd=1024, n_inner=4096)
92
+ elif size == "large":
93
+ return MIDIModelConfig.get_config(tokenizer_ver=tv, optimise_midi=o,
94
+ n_layer=24, n_head=16, n_embd=1024, n_inner=4096)
95
+ else:
96
+ raise ValueError(f"Unknown model size {size}")
97
+
98
+
99
+ class MIDIModel(PreTrainedModel):
100
+ config_class = MIDIModelConfig
101
+
102
+ def __init__(self, config: MIDIModelConfig, *args, **kwargs):
103
+ super(MIDIModel, self).__init__(config, *args, **kwargs)
104
+ self.tokenizer = config.tokenizer
105
+ self.net = LlamaModel(config.net_config)
106
+ self.net_token = LlamaModel(config.net_token_config)
107
+ self.lm_head = nn.Linear(config.n_embd, self.tokenizer.vocab_size, bias=False)
108
+
109
+ def load_merge_lora(self, model_id):
110
+ peft_config = PeftConfig.from_pretrained(model_id)
111
+ model = LoraModel(self, peft_config, adapter_name="default")
112
+ adapter_state_dict = load_peft_weights(model_id, device=str(self.device))
113
+ set_peft_model_state_dict(self, adapter_state_dict, "default")
114
+ return model.merge_and_unload()
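+ # wraps the model as a peft LoraModel, loads the adapter weights, then folds them into the
+ # base weights via merge_and_unload(), so inference afterwards needs no peft wrapper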
115
+
116
+ def forward_token(self, hidden_state=None, x=None, cache=None):
117
+ """
118
+
119
+ :param hidden_state: (batch_size, n_embd)
120
+ :param x: (batch_size, token_sequence_length)
121
+ :param cache: Cache
122
+ :return: (batch_size, 1 + token_sequence_length, vocab_size)
123
+ """
124
+ if hidden_state is not None:
125
+ # when a cache is used, hidden_state only needs to be passed on the first call
126
+ hidden_state = hidden_state.unsqueeze(1) # (batch_size, 1, n_embd)
127
+ if x is not None:
128
+ x = self.net_token.embed_tokens(x)
129
+ if hidden_state is not None:
130
+ x = torch.cat([hidden_state, x], dim=1)
131
+ hidden_state = x
132
+ hidden_state = self.net_token.forward(inputs_embeds=hidden_state,
133
+ past_key_values=cache,
134
+ use_cache=cache is not None).last_hidden_state
135
+ return self.lm_head(hidden_state)
136
+
137
+ def forward(self, x, cache = None):
138
+ """
139
+ :param x: (batch_size, midi_sequence_length, token_sequence_length)
140
+ :param cache: Cache
141
+ :return: hidden (batch_size, midi_sequence_length, n_embd)
142
+ """
143
+
144
+ # merge token sequence
145
+ x = self.net.embed_tokens(x)
146
+ x = x.sum(dim=-2)
147
+ x = self.net.forward(inputs_embeds=x,
148
+ past_key_values=cache,
149
+ use_cache=cache is not None)
150
+ return x.last_hidden_state
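+ # two-stage decoding: `net` runs over whole events (the sub-token embeddings of each event
+ # are summed into a single vector per position), while `net_token` expands each event hidden
+ # state back into its sub-token sequence autoregressively (see forward_token above)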
151
+
152
+ def sample_top_p_k(self, probs, p, k, generator=None):
153
+ probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
154
+ probs_sum = torch.cumsum(probs_sort, dim=-1)
155
+ mask = probs_sum - probs_sort > p
156
+ probs_sort[mask] = 0.0
157
+ mask = torch.zeros(probs_sort.shape[-1], device=probs_sort.device)
158
+ mask[:k] = 1
159
+ probs_sort = probs_sort * mask
160
+ probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
161
+ shape = probs_sort.shape
162
+ next_token = torch.multinomial(probs_sort.reshape(-1, shape[-1]),
163
+ num_samples=1, generator=generator).reshape(*shape[:-1], 1)
164
+ next_token = torch.gather(probs_idx, -1, next_token).reshape(*shape[:-1])
165
+ return next_token
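+ # worked example: sorted probs [0.5, 0.3, 0.1, 0.06, 0.04] with p = 0.8, k = 2:
+ # cumsum - probs_sort = [0, 0.5, 0.8, 0.9, 0.96], so top-p keeps the first three entries;
+ # the top-k mask then keeps [0.5, 0.3], renormalised to [0.625, 0.375] before multinomial sampling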
166
+
167
+ @torch.inference_mode()
168
+ def generate(self, prompt=None, batch_size=1, max_len=512, temp=1.0, top_p=0.98, top_k=20, generator=None):
169
+ tokenizer = self.tokenizer
170
+ max_token_seq = tokenizer.max_token_seq
171
+ if prompt is None:
172
+ input_tensor = torch.full((1, max_token_seq), tokenizer.pad_id, dtype=torch.long, device=self.device)
173
+ input_tensor[0, 0] = tokenizer.bos_id # bos
174
+ input_tensor = input_tensor.unsqueeze(0)
175
+ input_tensor = torch.cat([input_tensor] * batch_size, dim=0)
176
+ else:
177
+ if len(prompt.shape) == 2:
178
+ prompt = prompt[None, :]
179
+ prompt = np.repeat(prompt, repeats=batch_size, axis=0)
180
+ elif prompt.shape[0] == 1:
181
+ prompt = np.repeat(prompt, repeats=batch_size, axis=0)
182
+ elif len(prompt.shape) != 3 or prompt.shape[0] != batch_size:
183
+ raise ValueError(f"invalid shape for prompt, {prompt.shape}")
184
+ prompt = prompt[..., :max_token_seq]
185
+ if prompt.shape[-1] < max_token_seq:
186
+ prompt = np.pad(prompt, ((0, 0), (0, 0), (0, max_token_seq - prompt.shape[-1])),
187
+ mode="constant", constant_values=tokenizer.pad_id)
188
+ input_tensor = torch.from_numpy(prompt).to(dtype=torch.long, device=self.device)
189
+
190
+ cur_len = input_tensor.shape[1]
191
+ bar = tqdm.tqdm(desc="generating", total=max_len - cur_len)
192
+ cache1 = DynamicCache()
193
+ past_len = 0
194
+ with bar:
195
+ while cur_len < max_len:
196
+ end = [False] * batch_size
197
+ hidden = self.forward(input_tensor[:, past_len:], cache=cache1)[:, -1]
198
+ next_token_seq = None
199
+ event_names = [""] * batch_size
200
+ cache2 = DynamicCache()
201
+ for i in range(max_token_seq):
202
+ mask = torch.zeros((batch_size, tokenizer.vocab_size), dtype=torch.int64, device=self.device)
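+ # constrained decoding: the mask built below only allows tokens that are grammatical at
+ # position i: an event type (or eos) at i == 0, then the id range of the chosen event's
+ # i-th parameter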
203
+ for b in range(batch_size):
204
+ if end[b]:
205
+ mask[b, tokenizer.pad_id] = 1
206
+ continue
207
+ if i == 0:
208
+ mask[b, list(tokenizer.event_ids.values()) + [tokenizer.eos_id]] = 1
209
+ else:
210
+ param_names = tokenizer.events[event_names[b]]
211
+ if i > len(param_names):
212
+ mask[b, tokenizer.pad_id] = 1
213
+ continue
214
+ mask[b, tokenizer.parameter_ids[param_names[i - 1]]] = 1
215
+ mask = mask.unsqueeze(1)
216
+ x = next_token_seq
217
+ if i != 0:
218
+ # cached
219
+ hidden = None
220
+ x = x[:, -1:]
221
+ logits = self.forward_token(hidden, x, cache=cache2)[:, -1:]
222
+ scores = torch.softmax(logits / temp, dim=-1) * mask
223
+ samples = self.sample_top_p_k(scores, top_p, top_k, generator=generator)
224
+ if i == 0:
225
+ next_token_seq = samples
226
+ for b in range(batch_size):
227
+ if end[b]:
228
+ continue
229
+ eid = samples[b].item()
230
+ if eid == tokenizer.eos_id:
231
+ end[b] = True
232
+ else:
233
+ event_names[b] = tokenizer.id_events[eid]
234
+ else:
235
+ next_token_seq = torch.cat([next_token_seq, samples], dim=1)
236
+ if all([len(tokenizer.events[event_names[b]]) == i for b in range(batch_size) if not end[b]]):
237
+ break
238
+
239
+ if next_token_seq.shape[1] < max_token_seq:
240
+ next_token_seq = F.pad(next_token_seq, (0, max_token_seq - next_token_seq.shape[1]),
241
+ "constant", value=tokenizer.pad_id)
242
+ next_token_seq = next_token_seq.unsqueeze(1)
243
+ input_tensor = torch.cat([input_tensor, next_token_seq], dim=1)
244
+ past_len = cur_len
245
+ cur_len += 1
246
+ bar.update(1)
247
+
248
+ if all(end):
249
+ break
250
+ return input_tensor.cpu().numpy()
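+ # minimal usage sketch (illustrative only; assumes the hub repo hosts a checkpoint in this
+ # MIDIModel format):
+ # model = MIDIModel.from_pretrained("skytnt/midi-model-tv2o-medium").to("cuda").eval()
+ # seq = model.generate(batch_size=1, max_len=256, top_p=0.98, top_k=20)
+ # score = model.tokenizer.detokenize(seq[0].tolist()) # MIDI.score2midi(score) then gives bytes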
midi_synthesizer.py ADDED
@@ -0,0 +1,81 @@
1
+ from threading import Lock
2
+
3
+ import fluidsynth
4
+ import numpy as np
5
+
6
+
7
+ class MidiSynthesizer:
8
+ def __init__(self, soundfont_path, sample_rate=44100):
9
+ self.soundfont_path = soundfont_path
10
+ self.sample_rate = sample_rate
11
+ fl = fluidsynth.Synth(samplerate=float(sample_rate))
12
+ sfid = fl.sfload(soundfont_path)
13
+ self.devices = [[fl, sfid, False]]
14
+ self.devices_lock = Lock()
15
+
16
+ def get_fluidsynth(self):
17
+ with self.devices_lock:
18
+ for device in self.devices:
19
+ if not device[2]:
20
+ device[2] = True
21
+ return device
22
+ fl = fluidsynth.Synth(samplerate=float(self.sample_rate))
23
+ sfid = fl.sfload(self.soundfont_path)
24
+ device = [fl, sfid, True]
25
+ self.devices.append(device)
26
+ return device
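+ # simple synth pool: reuse an idle fluidsynth instance when possible, grow the pool otherwise;
+ # the lock guards the pool because callers may synthesize several outputs in parallel
+ # (see render_audio in the app)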
27
+
28
+ def release_fluidsynth(self, device):
29
+ device[0].system_reset()
30
+ device[0].get_samples(self.sample_rate * 5) # render 5 s of audio so lingering notes decay before reuse
31
+ device[2] = False
32
+
33
+ def synthesis(self, midi_opus):
34
+ ticks_per_beat = midi_opus[0]
35
+ event_list = []
36
+ for track_idx, track in enumerate(midi_opus[1:]):
37
+ abs_t = 0
38
+ for event in track:
39
+ abs_t += event[1]
40
+ event_new = [*event]
41
+ event_new[1] = abs_t
42
+ event_list.append(event_new)
43
+ event_list = sorted(event_list, key=lambda e: e[1])
44
+
45
+ tempo = int((60 / 120) * 10 ** 6) # default 120 bpm
46
+ ss = np.empty((0, 2), dtype=np.int16)
47
+ device = self.get_fluidsynth()
48
+ fl, sfid = device[:-1]
49
+ last_t = 0
50
+ for c in range(16):
51
+ fl.program_select(c, sfid, 128 if c == 9 else 0, 0)
52
+ for event in event_list:
53
+ name = event[0]
54
+ sample_len = int(((event[1] / ticks_per_beat) * tempo / (10 ** 6)) * self.sample_rate)
55
+ sample_len -= int(((last_t / ticks_per_beat) * tempo / (10 ** 6)) * self.sample_rate)
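+ # i.e. audio frames rendered since the previous event:
+ # frames = (delta_ticks / ticks_per_beat) * (tempo_us / 1e6) * sample_rate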
56
+ last_t = event[1]
57
+ if sample_len > 0:
58
+ sample = fl.get_samples(sample_len).reshape(sample_len, 2)
59
+ ss = np.concatenate([ss, sample])
60
+ if name == "set_tempo":
61
+ tempo = event[2]
62
+ elif name == "patch_change":
63
+ c, p = event[2:4]
64
+ fl.program_select(c, sfid, 128 if c == 9 else 0, p)
65
+ elif name == "control_change":
66
+ c, cc, v = event[2:5]
67
+ fl.cc(c, cc, v)
68
+ elif name == "note_on" and event[3] > 0:
69
+ c, p, v = event[2:5]
70
+ fl.noteon(c, p, v)
71
+ elif name == "note_off" or (name == "note_on" and event[3] == 0):
72
+ c, p = event[2:4]
73
+ fl.noteoff(c, p)
74
+
75
+ self.release_fluidsynth(device)
76
+ if ss.shape[0] > 0:
77
+ max_val = np.abs(ss).max()
78
+ if max_val != 0:
79
+ ss = (ss / max_val) * np.iinfo(np.int16).max
80
+ ss = ss.astype(np.int16)
81
+ return ss
midi_tokenizer.py ADDED
@@ -0,0 +1,1196 @@
1
+ import random
2
+ from typing import Dict, Any
3
+
4
+ import PIL.Image
5
+ import numpy as np
6
+
7
+
8
+ class MIDITokenizerV1:
9
+ def __init__(self):
10
+ self.version = "v1"
11
+ self.optimise_midi = False
12
+ self.vocab_size = 0
13
+
14
+ def allocate_ids(size):
15
+ ids = [self.vocab_size + i for i in range(size)]
16
+ self.vocab_size += size
17
+ return ids
18
+
19
+ self.pad_id = allocate_ids(1)[0]
20
+ self.bos_id = allocate_ids(1)[0]
21
+ self.eos_id = allocate_ids(1)[0]
22
+ self.events = {
23
+ "note": ["time1", "time2", "track", "duration", "channel", "pitch", "velocity"],
24
+ "patch_change": ["time1", "time2", "track", "channel", "patch"],
25
+ "control_change": ["time1", "time2", "track", "channel", "controller", "value"],
26
+ "set_tempo": ["time1", "time2", "track", "bpm"],
27
+ }
28
+ self.event_parameters = {
29
+ "time1": 128, "time2": 16, "duration": 2048, "track": 128, "channel": 16, "pitch": 128, "velocity": 128,
30
+ "patch": 128, "controller": 128, "value": 128, "bpm": 256
31
+ }
32
+ self.event_ids = {e: allocate_ids(1)[0] for e in self.events.keys()}
33
+ self.id_events = {i: e for e, i in self.event_ids.items()}
34
+ self.parameter_ids = {p: allocate_ids(s) for p, s in self.event_parameters.items()}
35
+ self.max_token_seq = max([len(ps) for ps in self.events.values()]) + 1
36
+
37
+ def to_dict(self) -> Dict[str, Any]:
38
+ d = {
39
+ "version":self.version,
40
+ "optimise_midi":self.optimise_midi,
41
+ "vocab_size": self.vocab_size,
42
+ "events": self.events,
43
+ "event_parameters": self.event_parameters,
44
+ "max_token_seq": self.max_token_seq,
45
+ "pad_id": self.pad_id,
46
+ "bos_id": self.bos_id,
47
+ "eos_id": self.eos_id,
48
+ }
49
+ return d
50
+
51
+ def set_optimise_midi(self, optimise_midi=True):
52
+ self.optimise_midi = optimise_midi
53
+
54
+ @staticmethod
55
+ def tempo2bpm(tempo):
56
+ tempo = tempo / 10 ** 6 # us to s
57
+ bpm = 60 / tempo
58
+ return bpm
59
+
60
+ @staticmethod
61
+ def bpm2tempo(bpm):
62
+ if bpm == 0:
63
+ bpm = 1
64
+ tempo = int((60 / bpm) * 10 ** 6)
65
+ return tempo
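+ # tempo is in microseconds per quarter note, so e.g. 120 bpm <-> 500000 us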
66
+
67
+ def tokenize(self, midi_score, add_bos_eos=True, cc_eps=4, tempo_eps=4,
68
+ remap_track_channel=None, add_default_instr=None, remove_empty_channels=None):
69
+ if remap_track_channel is None: # set default value
70
+ remap_track_channel = self.optimise_midi
71
+ if add_default_instr is None:
72
+ add_default_instr = self.optimise_midi
73
+ if remove_empty_channels is None:
74
+ remove_empty_channels = self.optimise_midi
75
+
76
+ ticks_per_beat = midi_score[0]
77
+ event_list = {}
78
+ track_idx_map = {i: dict() for i in range(16)}
79
+ track_idx_dict = {}
80
+ channels = []
81
+ patch_channels = []
82
+ empty_channels = [True] * 16
83
+ channel_note_tracks = {i: list() for i in range(16)}
84
+ for track_idx, track in enumerate(midi_score[1:129]):
85
+ last_notes = {}
86
+ patch_dict = {}
87
+ control_dict = {}
88
+ last_tempo = 0
89
+ for event in track:
90
+ if event[0] not in self.events:
91
+ continue
92
+ c = -1
93
+ t = round(16 * event[1] / ticks_per_beat) # quantization
94
+ new_event = [event[0], t // 16, t % 16, track_idx] + event[2:]
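+ # t is on a 1/16-of-a-beat grid: time1 = whole beats (t // 16), time2 = position within the beat (t % 16)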
95
+ if event[0] == "note":
96
+ c = event[3]
97
+ if c > 15 or c < 0:
98
+ continue
99
+ empty_channels[c] = False
100
+ track_idx_dict.setdefault(c, track_idx)
101
+ note_tracks = channel_note_tracks[c]
102
+ if track_idx not in note_tracks:
103
+ note_tracks.append(track_idx)
104
+ new_event[4] = max(1, round(16 * new_event[4] / ticks_per_beat))
105
+ elif event[0] == "set_tempo":
106
+ if new_event[4] == 0: # invalid tempo
107
+ continue
108
+ bpm = int(self.tempo2bpm(new_event[4]))
109
+ new_event[4] = min(bpm, 255)
110
+ if event[0] == "note":
111
+ key = tuple(new_event[:4] + new_event[5:-1])
112
+ else:
113
+ key = tuple(new_event[:-1])
114
+ if event[0] == "patch_change":
115
+ c, p = event[2:]
116
+ if c > 15 or c < 0:
117
+ continue
118
+ last_p = patch_dict.setdefault(c, None)
119
+ if last_p == p:
120
+ continue
121
+ patch_dict[c] = p
122
+ if c not in patch_channels:
123
+ patch_channels.append(c)
124
+ elif event[0] == "control_change":
125
+ c, cc, v = event[2:]
126
+ if c > 15 or c < 0:
127
+ continue
128
+ last_v = control_dict.setdefault((c, cc), 0)
129
+ if abs(last_v - v) < cc_eps:
130
+ continue
131
+ control_dict[(c, cc)] = v
132
+ elif event[0] == "set_tempo":
133
+ tempo = new_event[-1]
134
+ if abs(last_tempo - tempo) < tempo_eps:
135
+ continue
136
+ last_tempo = tempo
137
+
138
+ if c != -1:
139
+ if c not in channels:
140
+ channels.append(c)
141
+ tr_map = track_idx_map[c]
142
+ if track_idx not in tr_map:
143
+ tr_map[track_idx] = 0
144
+
145
+ if event[0] == "note": # to eliminate note overlap due to quantization
146
+ cp = tuple(new_event[5:7])
147
+ if cp in last_notes:
148
+ last_note_key, last_note = last_notes[cp]
149
+ last_t = last_note[1] * 16 + last_note[2]
150
+ last_note[4] = max(0, min(last_note[4], t - last_t))
151
+ if last_note[4] == 0:
152
+ event_list.pop(last_note_key)
153
+ last_notes[cp] = (key, new_event)
154
+ event_list[key] = new_event
155
+ event_list = list(event_list.values())
156
+
157
+ empty_channels = [c for c in channels if empty_channels[c]]
158
+
159
+ if remap_track_channel:
160
+ patch_channels = []
161
+ channels_count = 0
162
+ channels_map = {9: 9} if 9 in channels else {}
163
+ if remove_empty_channels:
164
+ channels = sorted(channels, key=lambda x: 1 if x in empty_channels else 0)
165
+ for c in channels:
166
+ if c == 9:
167
+ continue
168
+ channels_map[c] = channels_count
169
+ channels_count += 1
170
+ if channels_count == 9:
171
+ channels_count = 10
172
+ channels = list(channels_map.values())
173
+
174
+ track_count = 0
175
+ track_idx_map_order = [k for k, v in sorted(list(channels_map.items()), key=lambda x: x[1])]
176
+ for c in track_idx_map_order: # tracks not to remove
177
+ if remove_empty_channels and c in empty_channels:
178
+ continue
179
+ tr_map = track_idx_map[c]
180
+ for track_idx in tr_map:
181
+ note_tracks = channel_note_tracks[c]
182
+ if len(note_tracks) != 0 and track_idx not in note_tracks:
183
+ continue
184
+ track_count += 1
185
+ tr_map[track_idx] = track_count
186
+ for c in track_idx_map_order: # tracks to remove
187
+ if not (remove_empty_channels and c in empty_channels):
188
+ continue
189
+ tr_map = track_idx_map[c]
190
+ for track_idx in tr_map:
191
+ note_tracks = channel_note_tracks[c]
192
+ if not (len(note_tracks) != 0 and track_idx not in note_tracks):
193
+ continue
194
+ track_count += 1
195
+ tr_map[track_idx] = track_count
196
+
197
+ empty_channels = [channels_map[c] for c in empty_channels]
198
+ track_idx_dict = {}
199
+ for event in event_list:
200
+ name = event[0]
201
+ track_idx = event[3]
202
+ if name == "note":
203
+ c = event[5]
204
+ event[5] = channels_map[c]
205
+ event[3] = track_idx_map[c][track_idx]
206
+ track_idx_dict.setdefault(event[5], event[3])
207
+ # setdefault, so track_idx is the first track seen for this channel
208
+ elif name == "set_tempo":
209
+ event[3] = 0
210
+ elif name == "control_change" or name == "patch_change":
211
+ c = event[4]
212
+ event[4] = channels_map[c]
213
+ tr_map = track_idx_map[c]
214
+ # move the event to the first track of the channel if its original track has no notes
215
+ note_tracks = channel_note_tracks[c]
216
+ if len(note_tracks) != 0 and track_idx not in note_tracks:
217
+ track_idx = channel_note_tracks[c][0]
218
+ new_track_idx = tr_map[track_idx]
219
+ event[3] = new_track_idx
220
+ if name == "patch_change" and event[4] not in patch_channels:
221
+ patch_channels.append(event[4])
222
+
223
+ if add_default_instr:
224
+ for c in channels:
225
+ if c not in patch_channels and c in track_idx_dict:
226
+ event_list.append(["patch_change", 0, 0, track_idx_dict[c], c, 0])
227
+
228
+ events_name_order = {"set_tempo": 0, "patch_change": 1, "control_change": 2, "note": 3}
229
+ events_order = lambda e: e[1:4] + [events_name_order[e[0]]]
230
+ event_list = sorted(event_list, key=events_order)
231
+
232
+ setup_events = {}
233
+ notes_in_setup = False
234
+ for i, event in enumerate(event_list): # optimise setup
235
+ new_event = [*event]
236
+ if event[0] != "note":
237
+ new_event[1] = 0
238
+ new_event[2] = 0
239
+ has_next = False
240
+ has_pre = False
241
+ if i < len(event_list) - 1:
242
+ next_event = event_list[i + 1]
243
+ has_next = event[1] + event[2] == next_event[1] + next_event[2]
244
+ if notes_in_setup and i > 0:
245
+ pre_event = event_list[i - 1]
246
+ has_pre = event[1] + event[2] == pre_event[1] + pre_event[2]
247
+ if (event[0] == "note" and not has_next) or (notes_in_setup and not has_pre):
248
+ event_list = sorted(setup_events.values(), key=events_order) + event_list[i:]
249
+ break
250
+ else:
251
+ if event[0] == "note":
252
+ notes_in_setup = True
253
+ key = tuple([event[0]] + event[3:-2])
254
+ else:
255
+ key = tuple([event[0]] + event[3:-1])
256
+ setup_events[key] = new_event
257
+
258
+ last_t1 = 0
259
+ midi_seq = []
260
+ for event in event_list:
261
+ if remove_empty_channels and event[0] in ["control_change", "patch_change"] and event[4] in empty_channels:
262
+ continue
263
+ cur_t1 = event[1]
264
+ event[1] = event[1] - last_t1
265
+ tokens = self.event2tokens(event)
266
+ if not tokens:
267
+ continue
268
+ midi_seq.append(tokens)
269
+ last_t1 = cur_t1
270
+
271
+ if add_bos_eos:
272
+ bos = [self.bos_id] + [self.pad_id] * (self.max_token_seq - 1)
273
+ eos = [self.eos_id] + [self.pad_id] * (self.max_token_seq - 1)
274
+ midi_seq = [bos] + midi_seq + [eos]
275
+ return midi_seq
276
+
277
+ def event2tokens(self, event):
278
+ name = event[0]
279
+ params = event[1:]
280
+ if not all([0 <= params[i] < self.event_parameters[p] for i, p in enumerate(self.events[name])]):
281
+ return []
282
+ tokens = [self.event_ids[name]] + [self.parameter_ids[p][params[i]]
283
+ for i, p in enumerate(self.events[name])]
284
+ tokens += [self.pad_id] * (self.max_token_seq - len(tokens))
285
+ return tokens
286
+
287
+ def tokens2event(self, tokens):
288
+ if tokens[0] not in self.id_events:
289
+ return []
290
+ name = self.id_events[tokens[0]]
291
+ if len(tokens) <= len(self.events[name]):
292
+ return []
293
+ params = tokens[1:]
294
+ params = [params[i] - self.parameter_ids[p][0] for i, p in enumerate(self.events[name])]
295
+ if not all([0 <= params[i] < self.event_parameters[p] for i, p in enumerate(self.events[name])]):
296
+ return []
297
+ event = [name] + params
298
+ return event
299
+
300
+ def detokenize(self, midi_seq):
301
+ ticks_per_beat = 480
302
+ tracks_dict = {}
303
+ t1 = 0
304
+ for tokens in midi_seq:
305
+ if tokens[0] in self.id_events:
306
+ event = self.tokens2event(tokens)
307
+ if not event:
308
+ continue
309
+ name = event[0]
310
+ if name == "set_tempo":
311
+ event[4] = self.bpm2tempo(event[4])
312
+ if event[0] == "note":
313
+ event[4] = int(event[4] * ticks_per_beat / 16)
314
+ t1 += event[1]
315
+ t = t1 * 16 + event[2]
316
+ t = int(t * ticks_per_beat / 16)
317
+ track_idx = event[3]
318
+ if track_idx not in tracks_dict:
319
+ tracks_dict[track_idx] = []
320
+ tracks_dict[track_idx].append([event[0], t] + event[4:])
321
+ tracks = [tr for idx, tr in sorted(list(tracks_dict.items()), key=lambda it: it[0])]
322
+
323
+ for i in range(len(tracks)): # to eliminate note overlap
324
+ track = tracks[i]
325
+ track = sorted(track, key=lambda e: e[1])
326
+ last_note_t = {}
327
+ zero_len_notes = []
328
+ for e in reversed(track):
329
+ if e[0] == "note":
330
+ t, d, c, p = e[1:5]
331
+ key = (c, p)
332
+ if key in last_note_t:
333
+ d = min(d, max(last_note_t[key] - t, 0))
334
+ last_note_t[key] = t
335
+ e[2] = d
336
+ if d == 0:
337
+ zero_len_notes.append(e)
338
+ for e in zero_len_notes:
339
+ track.remove(e)
340
+ tracks[i] = track
341
+ return [ticks_per_beat, *tracks]
342
+
343
+ def midi2img(self, midi_score):
344
+ ticks_per_beat = midi_score[0]
345
+ notes = []
346
+ max_time = 1
347
+ track_num = len(midi_score[1:])
348
+ for track_idx, track in enumerate(midi_score[1:]):
349
+ for event in track:
350
+ t = round(16 * event[1] / ticks_per_beat)
351
+ if event[0] == "note":
352
+ d = max(1, round(16 * event[2] / ticks_per_beat))
353
+ c, p = event[3:5]
354
+ max_time = max(max_time, t + d + 1)
355
+ notes.append((track_idx, c, p, t, d))
356
+ img = np.zeros((128, max_time, 3), dtype=np.uint8)
357
+ colors = {(i, j): np.random.randint(50, 256, 3) for i in range(track_num) for j in range(16)}
358
+ for note in notes:
359
+ tr, c, p, t, d = note
360
+ img[p, t: t + d] = colors[(tr, c)]
361
+ img = PIL.Image.fromarray(np.flip(img, 0))
362
+ return img
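+ # renders a piano-roll image: rows are pitches, columns are 1/16-beat steps, one random colour
+ # per (track, channel) pair; handy for a quick visual check of a score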
363
+
364
+ def augment(self, midi_seq, max_pitch_shift=4, max_vel_shift=10, max_cc_val_shift=10, max_bpm_shift=10,
365
+ max_track_shift=0, max_channel_shift=16):
366
+ pitch_shift = random.randint(-max_pitch_shift, max_pitch_shift)
367
+ vel_shift = random.randint(-max_vel_shift, max_vel_shift)
368
+ cc_val_shift = random.randint(-max_cc_val_shift, max_cc_val_shift)
369
+ bpm_shift = random.randint(-max_bpm_shift, max_bpm_shift)
370
+ track_shift = random.randint(0, max_track_shift)
371
+ channel_shift = random.randint(0, max_channel_shift)
372
+ midi_seq_new = []
373
+ for tokens in midi_seq:
374
+ tokens_new = [*tokens]
375
+ if tokens[0] in self.id_events:
376
+ name = self.id_events[tokens[0]]
377
+ for i, pn in enumerate(self.events[name]):
378
+ if pn == "track":
379
+ tr = tokens[1 + i] - self.parameter_ids[pn][0]
380
+ tr += track_shift
381
+ tr = tr % self.event_parameters[pn]
382
+ tokens_new[1 + i] = self.parameter_ids[pn][tr]
383
+ elif pn == "channel":
384
+ c = tokens[1 + i] - self.parameter_ids[pn][0]
385
+ c0 = c
386
+ c += channel_shift
387
+ c = c % self.event_parameters[pn]
388
+ if c0 == 9:
389
+ c = 9
390
+ elif c == 9:
391
+ c = (9 + channel_shift) % self.event_parameters[pn]
392
+ tokens_new[1 + i] = self.parameter_ids[pn][c]
393
+
394
+ if name == "note":
395
+ c = tokens[5] - self.parameter_ids["channel"][0]
396
+ p = tokens[6] - self.parameter_ids["pitch"][0]
397
+ v = tokens[7] - self.parameter_ids["velocity"][0]
398
+ if c != 9: # no shift for drums
399
+ p += pitch_shift
400
+ if not 0 <= p < 128:
401
+ return midi_seq
402
+ v += vel_shift
403
+ v = max(1, min(127, v))
404
+ tokens_new[6] = self.parameter_ids["pitch"][p]
405
+ tokens_new[7] = self.parameter_ids["velocity"][v]
406
+ elif name == "control_change":
407
+ cc = tokens[5] - self.parameter_ids["controller"][0]
408
+ val = tokens[6] - self.parameter_ids["value"][0]
409
+ if cc in [1, 2, 7, 11]:
410
+ val += cc_val_shift
411
+ val = max(1, min(127, val))
412
+ tokens_new[6] = self.parameter_ids["value"][val]
413
+ elif name == "set_tempo":
414
+ bpm = tokens[4] - self.parameter_ids["bpm"][0]
415
+ bpm += bpm_shift
416
+ bpm = max(1, min(255, bpm))
417
+ tokens_new[4] = self.parameter_ids["bpm"][bpm]
418
+ midi_seq_new.append(tokens_new)
419
+ return midi_seq_new
420
+
421
+ def check_quality(self, midi_seq, alignment_min=0.3, tonality_min=0.8, piano_max=0.7, notes_bandwidth_min=3,
422
+ notes_density_max=50, notes_density_min=2.5, total_notes_max=20000, total_notes_min=256,
423
+ note_window_size=16):
424
+ total_notes = 0
425
+ channels = []
426
+ time_hist = [0] * 16
427
+ note_windows = {}
428
+ notes_sametime = []
429
+ notes_density_list = []
430
+ tonality_list = []
431
+ notes_bandwidth_list = []
432
+ instruments = {}
433
+ piano_channels = []
434
+ abs_t1 = 0
435
+ last_t = 0
436
+ for tsi, tokens in enumerate(midi_seq):
437
+ event = self.tokens2event(tokens)
438
+ if not event:
439
+ continue
440
+ t1, t2, tr = event[1:4]
441
+ abs_t1 += t1
442
+ t = abs_t1 * 16 + t2
443
+ c = None
444
+ if event[0] == "note":
445
+ d, c, p, v = event[4:]
446
+ total_notes += 1
447
+ time_hist[t2] += 1
448
+ if c != 9: # ignore drum channel
449
+ if c not in instruments:
450
+ instruments[c] = 0
451
+ if c not in piano_channels:
452
+ piano_channels.append(c)
453
+ note_windows.setdefault(abs_t1 // note_window_size, []).append(p)
454
+ if last_t != t:
455
+ notes_sametime = [(et, p_) for et, p_ in notes_sametime if et > last_t]
456
+ notes_sametime_p = [p_ for _, p_ in notes_sametime]
457
+ if len(notes_sametime) > 0:
458
+ notes_bandwidth_list.append(max(notes_sametime_p) - min(notes_sametime_p))
459
+ notes_sametime.append((t + d - 1, p))
460
+ elif event[0] == "patch_change":
461
+ c, p = event[4:]
462
+ instruments[c] = p
463
+ if p == 0 and c not in piano_channels:
464
+ piano_channels.append(c)
465
+ if c is not None and c not in channels:
466
+ channels.append(c)
467
+ last_t = t
468
+ reasons = []
469
+ if total_notes < total_notes_min:
470
+ reasons.append("total_min")
471
+ if total_notes > total_notes_max:
472
+ reasons.append("total_max")
473
+ if len(note_windows) == 0 and total_notes > 0:
474
+ reasons.append("drum_only")
475
+ if reasons:
476
+ return False, reasons
477
+ time_hist = sorted(time_hist, reverse=True)
478
+ alignment = sum(time_hist[:2]) / total_notes
479
+ for notes in note_windows.values():
480
+ key_hist = [0] * 12
481
+ for p in notes:
482
+ key_hist[p % 12] += 1
483
+ key_hist = sorted(key_hist, reverse=True)
484
+ tonality_list.append(sum(key_hist[:7]) / len(notes))
485
+ notes_density_list.append(len(notes) / note_window_size)
486
+ tonality_list = sorted(tonality_list)
487
+ tonality = sum(tonality_list) / len(tonality_list)
488
+ notes_bandwidth = sum(notes_bandwidth_list) / len(notes_bandwidth_list) if notes_bandwidth_list else 0
489
+ notes_density = max(notes_density_list) if notes_density_list else 0
490
+ piano_ratio = len(piano_channels) / len(channels)
491
+ if len(channels) <= 3: # ignore piano threshold if it is a piano solo midi
492
+ piano_max = 1
493
+ if alignment < alignment_min: # check whether the notes align to the bars (some midi files are recorded from live playing)
494
+ reasons.append("alignment")
495
+ if tonality < tonality_min: # check whether the music is tonal
496
+ reasons.append("tonality")
497
+ if notes_bandwidth < notes_bandwidth_min: # check whether the music is only a single melodic line
498
+ reasons.append("bandwidth")
499
+ if not notes_density_min < notes_density < notes_density_max:
500
+ reasons.append("density")
501
+ if piano_ratio > piano_max: # check whether most instruments are piano (some midi files don't have instruments assigned correctly)
502
+ reasons.append("piano")
503
+ return not reasons, reasons
504
+
505
+
506
+ class MIDITokenizerV2:
507
+ def __init__(self):
508
+ self.version = "v2"
509
+ self.optimise_midi = False
510
+ self.vocab_size = 0
511
+
512
+ def allocate_ids(size):
513
+ ids = [self.vocab_size + i for i in range(size)]
514
+ self.vocab_size += size
515
+ return ids
516
+
517
+ self.pad_id = allocate_ids(1)[0]
518
+ self.bos_id = allocate_ids(1)[0]
519
+ self.eos_id = allocate_ids(1)[0]
520
+ self.events = {
521
+ "note": ["time1", "time2", "track", "channel", "pitch", "velocity", "duration"],
522
+ "patch_change": ["time1", "time2", "track", "channel", "patch"],
523
+ "control_change": ["time1", "time2", "track", "channel", "controller", "value"],
524
+ "set_tempo": ["time1", "time2", "track", "bpm"],
525
+ "time_signature": ["time1", "time2", "track", "nn", "dd"],
526
+ "key_signature": ["time1", "time2", "track", "sf", "mi"],
527
+ }
528
+ self.event_parameters = {
529
+ "time1": 128, "time2": 16, "duration": 2048, "track": 128, "channel": 16, "pitch": 128, "velocity": 128,
530
+ "patch": 128, "controller": 128, "value": 128, "bpm": 384, "nn": 16, "dd": 4, "sf": 15, "mi": 2
531
+ }
532
+ self.event_ids = {e: allocate_ids(1)[0] for e in self.events.keys()}
533
+ self.id_events = {i: e for e, i in self.event_ids.items()}
534
+ self.parameter_ids = {p: allocate_ids(s) for p, s in self.event_parameters.items()}
535
+ self.max_token_seq = max([len(ps) for ps in self.events.values()]) + 1
536
+
537
+ def to_dict(self) -> Dict[str, Any]:
538
+ d = {
539
+ "version":self.version,
540
+ "optimise_midi":self.optimise_midi,
541
+ "vocab_size": self.vocab_size,
542
+ "events": self.events,
543
+ "event_parameters": self.event_parameters,
544
+ "max_token_seq": self.max_token_seq,
545
+ "pad_id": self.pad_id,
546
+ "bos_id": self.bos_id,
547
+ "eos_id": self.eos_id,
548
+ }
549
+ return d
550
+
551
+ def set_optimise_midi(self, optimise_midi=True):
552
+ self.optimise_midi = optimise_midi
553
+
554
+ @staticmethod
555
+ def tempo2bpm(tempo):
556
+ tempo = tempo / 10 ** 6 # us to s
557
+ bpm = 60 / tempo
558
+ return bpm
559
+
560
+ @staticmethod
561
+ def bpm2tempo(bpm):
562
+ if bpm == 0:
563
+ bpm = 1
564
+ tempo = int((60 / bpm) * 10 ** 6)
565
+ return tempo
566
+
567
+ @staticmethod
568
+ def sf2key(sf):
569
+ # converts sf (the sharp/flat count from a key_signature event) to a key,
570
+ # where key is the pitch-class index from C to B (12 in total)
571
+ return (sf * 7) % 12
572
+
573
+ @staticmethod
574
+ def key2sf(k, mi):
575
+ # key to sf
576
+ sf = (k * 7) % 12
577
+ if sf > 6 or (mi == 1 and sf >= 5):
578
+ sf -= 12
579
+ return sf
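+ # circle-of-fifths arithmetic: sf2key maps a sharp/flat count to a pitch class, e.g. sf = 2
+ # (D major) -> (2 * 7) % 12 = 2 = D; key2sf inverts it since 7 * 7 = 49 = 1 (mod 12), then
+ # folds values above 6 (or above 4 for minor keys) down to the equivalent flat spelling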
580
+
581
+ @staticmethod
582
+ def detect_key_signature(key_hist, threshold=0.7):
583
+ if len(key_hist) != 12:
584
+ return None
585
+ if sum(key_hist) == 0:
586
+ return None
587
+ p = sum(sorted(key_hist, reverse=True)[:7]) / sum(key_hist)
588
+ if p < threshold:
589
+ return None
590
+ keys = [x[1] for x in sorted(zip(key_hist, range(len(key_hist))), reverse=True, key=lambda x: x[0])[:7]]
591
+ keys = sorted(keys)
592
+ semitones = []
593
+ for i in range(len(keys)):
594
+ dis = keys[i] - keys[i - 1]
595
+ if dis == 1 or dis == -11:
596
+ semitones.append(keys[i])
597
+ if len(semitones) != 2:
598
+ return None
599
+ semitones_dis = semitones[1] - semitones[0]
600
+ if semitones_dis == 5:
601
+ root_key = semitones[0]
602
+ elif semitones_dis == 7:
603
+ root_key = semitones[1]
604
+ else:
605
+ return None
606
+ return root_key
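+ # heuristic: take the 7 most frequent pitch classes; a diatonic set has exactly two semitone
+ # steps (E-F and B-C in C major) whose upper notes are the 4th degree and the tonic. If those
+ # two notes are 5 semitones apart the lower one is the tonic; if 7 apart, the upper one is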
607
+
608
+ def tokenize(self, midi_score, add_bos_eos=True, cc_eps=4, tempo_eps=4,
609
+ remap_track_channel=None, add_default_instr=None, remove_empty_channels=None):
610
+ if remap_track_channel is None: # set default value
611
+ remap_track_channel = self.optimise_midi
612
+ if add_default_instr is None:
613
+ add_default_instr = self.optimise_midi
614
+ if remove_empty_channels is None:
615
+ remove_empty_channels = self.optimise_midi
616
+
617
+ ticks_per_beat = midi_score[0]
618
+ event_list = {}
619
+ track_idx_map = {i: dict() for i in range(16)}
620
+ track_idx_dict = {}
621
+ channels = []
622
+ patch_channels = []
623
+ empty_channels = [True] * 16
624
+ channel_note_tracks = {i: list() for i in range(16)}
625
+ note_key_hist = [0]*12
626
+ key_sigs = []
627
+ track_to_channels = {}
628
+ for track_idx, track in enumerate(midi_score[1:129]):
629
+ last_notes = {}
630
+ patch_dict = {}
631
+ control_dict = {}
632
+ last_bpm = 0
633
+ track_channels = []
634
+ track_to_channels.setdefault(track_idx, track_channels)
635
+ for event in track:
636
+ if event[0] not in self.events:
637
+ continue
638
+ name = event[0]
639
+ c = -1
640
+ t = round(16 * event[1] / ticks_per_beat) # quantization
641
+ new_event = [name, t // 16, t % 16, track_idx]
642
+ if name == "note":
643
+ d, c, p, v = event[2:]
644
+ if not (0 <= c <= 15):
645
+ continue
646
+ d = max(1, round(16 * d / ticks_per_beat))
647
+ new_event += [c, p, v, d]
648
+ empty_channels[c] = False
649
+ track_idx_dict.setdefault(c, track_idx)
650
+ note_tracks = channel_note_tracks[c]
651
+ if track_idx not in note_tracks:
652
+ note_tracks.append(track_idx)
653
+ if c != 9:
654
+ note_key_hist[p%12] += 1
655
+ if c not in track_channels:
656
+ track_channels.append(c)
657
+ elif name == "patch_change":
658
+ c, p = event[2:]
659
+ if not (0 <= c <= 15):
660
+ continue
661
+ new_event += [c, p]
662
+ last_p = patch_dict.setdefault(c, None)
663
+ if last_p == p:
664
+ continue
665
+ patch_dict[c] = p
666
+ if c not in patch_channels:
667
+ patch_channels.append(c)
668
+ elif name == "control_change":
669
+ c, cc, v = event[2:]
670
+ if not (0 <= c <= 15):
671
+ continue
672
+ new_event += [c, cc, v]
673
+ last_v = control_dict.setdefault((c, cc), 0)
674
+ if abs(last_v - v) < cc_eps:
675
+ continue
676
+ control_dict[(c, cc)] = v
677
+ elif name == "set_tempo":
678
+ tempo = event[2]
679
+ if tempo == 0: # invalid tempo
680
+ continue
681
+ bpm = min(int(self.tempo2bpm(tempo)), 383)
682
+ new_event += [bpm]
683
+ if abs(last_bpm - bpm) < tempo_eps:
684
+ continue
685
+ last_bpm = bpm
686
+ elif name == "time_signature":
687
+ nn, dd = event[2:4]
688
+ if not (1 <= nn <= 16 and 1 <= dd <= 4): # invalid
689
+ continue
690
+ nn -= 1 # make it start from 0
691
+ dd -= 1
692
+ new_event += [nn, dd]
693
+ elif name == "key_signature":
694
+ sf, mi = event[2:]
695
+ if not (-7 <= sf <= 7 and 0 <= mi <= 1): # invalid
696
+ continue
697
+ sf += 7
698
+ new_event += [sf, mi]
699
+ key_sigs.append(new_event)
700
+
701
+ if name in ["note", "time_signature", "key_signature"]:
702
+ key = tuple(new_event[:-2])
703
+ else:
704
+ key = tuple(new_event[:-1])
705
+
706
+ if c != -1:
707
+ if c not in channels:
708
+ channels.append(c)
709
+ tr_map = track_idx_map[c]
710
+ if track_idx not in tr_map:
711
+ tr_map[track_idx] = 0
712
+
713
+ if event[0] == "note": # to eliminate note overlap due to quantization
714
+ cp = tuple(new_event[4:6]) # channel pitch
715
+ if cp in last_notes:
716
+ last_note_key, last_note = last_notes[cp]
717
+ last_t = last_note[1] * 16 + last_note[2]
718
+ last_note[-1] = max(0, min(last_note[-1], t - last_t)) # modify duration
719
+ if last_note[-1] == 0:
720
+ event_list.pop(last_note_key)
721
+ last_notes[cp] = (key, new_event)
722
+ event_list[key] = new_event
723
+ event_list = list(event_list.values())
724
+
725
+ empty_channels = [c for c in channels if empty_channels[c]]
726
+
727
+ if remap_track_channel:
728
+ patch_channels = []
729
+ channels_count = 0
730
+ channels_map = {9: 9} if 9 in channels else {}
731
+ if remove_empty_channels:
732
+ channels = sorted(channels, key=lambda x: 1 if x in empty_channels else 0)
733
+ for c in channels:
734
+ if c == 9:
735
+ continue
736
+ channels_map[c] = channels_count
737
+ channels_count += 1
738
+ if channels_count == 9:
739
+ channels_count = 10
740
+ channels = list(channels_map.values())
741
+
742
+ track_count = 0
743
+ track_idx_map_order = [k for k, v in sorted(list(channels_map.items()), key=lambda x: x[1])]
744
+ for c in track_idx_map_order: # tracks not to remove
745
+ if remove_empty_channels and c in empty_channels:
746
+ continue
747
+ tr_map = track_idx_map[c]
748
+ for track_idx in tr_map:
749
+ note_tracks = channel_note_tracks[c]
750
+ if len(note_tracks) != 0 and track_idx not in note_tracks:
751
+ continue
752
+ track_count += 1
753
+ tr_map[track_idx] = track_count
754
+ for c in track_idx_map_order: # tracks to remove
755
+ if not (remove_empty_channels and c in empty_channels):
756
+ continue
757
+ tr_map = track_idx_map[c]
758
+ for track_idx in tr_map:
759
+ note_tracks = channel_note_tracks[c]
760
+ if not (len(note_tracks) != 0 and track_idx not in note_tracks):
761
+ continue
762
+ track_count += 1
763
+ tr_map[track_idx] = track_count
764
+
765
+             empty_channels = [channels_map[c] for c in empty_channels]
+             track_idx_dict = {}
+             key_sigs = []
+             key_signature_to_add = []
+             key_signature_to_remove = []
+             for event in event_list:
+                 name = event[0]
+                 track_idx = event[3]
+                 if name == "note":
+                     c = event[4]
+                     event[4] = channels_map[c]  # channel
+                     event[3] = track_idx_map[c][track_idx]  # track
+                     track_idx_dict.setdefault(event[4], event[3])
+                     # setdefault, so the track_idx is the first one of the channel
+                 elif name in ["set_tempo", "time_signature"]:
+                     event[3] = 0  # set track 0 for meta events
+                 elif name == "key_signature":
+                     new_channel_track_idxs = []
+                     for c, tr_map in track_idx_map.items():
+                         if track_idx in tr_map:
+                             new_track_idx = tr_map[track_idx]
+                             c = channels_map[c]
+                             new_channel_track_idx = (c, new_track_idx)
+                             if new_track_idx == 0:
+                                 continue
+                             if new_channel_track_idx not in new_channel_track_idxs:
+                                 new_channel_track_idxs.append(new_channel_track_idx)
+
+                     if len(new_channel_track_idxs) == 0:
+                         if event[3] == 0:  # keep key_signature on track 0 (meta)
+                             key_sigs.append(event)
+                             continue
+                         event[3] = -1  # avoid removing the same event twice
+                         key_signature_to_remove.append(event)  # empty track
+                         continue
+                     c, nt = new_channel_track_idxs[0]
+                     event[3] = nt
+                     key_sigs.append(event)
+                     if c == 9:
+                         event[4] = 7  # sf=0
+                     for c, nt in new_channel_track_idxs[1:]:
+                         new_event = [*event]
+                         new_event[3] = nt
+                         if c == 9:
+                             new_event[4] = 7  # sf=0
+                         key_sigs.append(new_event)
+                         key_signature_to_add.append(new_event)
+                 elif name == "control_change" or name == "patch_change":
+                     c = event[4]
+                     event[4] = channels_map[c]  # channel
+                     tr_map = track_idx_map[c]
+                     # move the event to the first track of the channel if its original track is empty
+                     note_tracks = channel_note_tracks[c]
+                     if len(note_tracks) != 0 and track_idx not in note_tracks:
+                         track_idx = channel_note_tracks[c][0]
+                     new_track_idx = tr_map[track_idx]
+                     event[3] = new_track_idx
+                     if name == "patch_change" and event[4] not in patch_channels:
+                         patch_channels.append(event[4])
+             for key_sig in key_signature_to_remove:
+                 event_list.remove(key_sig)
+             event_list += key_signature_to_add
+             track_to_channels = {}
+             for c, tr_map in track_idx_map.items():
+                 if c not in channels_map:
+                     continue
+                 c = channels_map[c]
+                 for _, track_idx in tr_map.items():
+                     track_to_channels.setdefault(track_idx, [])
+                     cs = track_to_channels[track_idx]
+                     if c not in cs:
+                         cs.append(c)
+
+         if add_default_instr:
+             for c in channels:
+                 if c not in patch_channels and c in track_idx_dict:
+                     event_list.append(["patch_change", 0, 0, track_idx_dict[c], c, 0])
+
+         if len(key_sigs) == 0 or all([key_sig[4] == 7 for key_sig in key_sigs]):
+             # detect the key signature, or fix the default key signature
+             root_key = self.detect_key_signature(note_key_hist)
+             if root_key is not None:
+                 sf = self.key2sf(root_key, 0)
+                 # print("detect_key_signature", sf)
+                 if len(key_sigs) == 0:
+                     for tr, cs in track_to_channels.items():
+                         if remap_track_channel and tr == 0:
+                             continue
+                         new_event = ["key_signature", 0, 0, tr, (0 if (len(cs) == 1 and cs[0] == 9) else sf) + 7, 0]
+                         event_list.append(new_event)
+                 else:
+                     for key_sig in key_sigs:
+                         tr = key_sig[3]
+                         if tr in track_to_channels:
+                             cs = track_to_channels[tr]
+                             if len(cs) == 1 and cs[0] == 9:
+                                 continue
+                         key_sig[4] = sf + 7
+                         key_sig[5] = 0
+             else:
+                 # remove the default key signature
+                 for key_sig in key_sigs:
+                     event_list.remove(key_sig)
+
+         events_name_order = ["time_signature", "key_signature", "set_tempo", "patch_change", "control_change", "note"]
+         events_name_order = {name: i for i, name in enumerate(events_name_order)}
+         events_order = lambda e: e[1:4] + [events_name_order[e[0]]]
+         event_list = sorted(event_list, key=events_order)
+
+         setup_events = {}
+         notes_in_setup = False
+         for i, event in enumerate(event_list):  # optimise setup
+             new_event = [*event]  # make a copy of the event
+             if event[0] not in ["note", "time_signature"]:
+                 new_event[1] = 0
+                 new_event[2] = 0
+             has_next = False
+             has_pre = False
+             if i < len(event_list) - 1:
+                 next_event = event_list[i + 1]
+                 has_next = event[1] + event[2] == next_event[1] + next_event[2]
+             if notes_in_setup and i > 0:
+                 pre_event = event_list[i - 1]
+                 has_pre = event[1] + event[2] == pre_event[1] + pre_event[2]
+             if (event[0] == "note" and not has_next) or (notes_in_setup and not has_pre):
+                 event_list = sorted(setup_events.values(), key=events_order) + event_list[i:]
+                 break
+             else:
+                 if event[0] == "note":
+                     notes_in_setup = True
+                 if event[0] in ["note", "time_signature", "key_signature"]:
+                     key = tuple([event[0]] + event[3:-2])
+                 else:
+                     key = tuple([event[0]] + event[3:-1])
+                 setup_events[key] = new_event
+
+         last_t1 = 0
+         midi_seq = []
+         for event in event_list:
+             if remove_empty_channels and event[0] in ["control_change", "patch_change"] and event[4] in empty_channels:
+                 continue
+             cur_t1 = event[1]
+             event[1] = event[1] - last_t1
+             tokens = self.event2tokens(event)
+             if not tokens:
+                 continue
+             midi_seq.append(tokens)
+             last_t1 = cur_t1
+
+         if add_bos_eos:
+             bos = [self.bos_id] + [self.pad_id] * (self.max_token_seq - 1)
+             eos = [self.eos_id] + [self.pad_id] * (self.max_token_seq - 1)
+             midi_seq = [bos] + midi_seq + [eos]
+         return midi_seq
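+
+     # A minimal round-trip sketch (hypothetical file path; midi2score and
+     # score2midi are provided by the bundled MIDI.py module):
+     #
+     #   tokenizer = MIDITokenizerV2()
+     #   score = MIDI.midi2score(open("song.mid", "rb").read())
+     #   seq = tokenizer.tokenize(score)       # list of fixed-width token rows
+     #   score_back = tokenizer.detokenize(seq)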
+
+     def event2tokens(self, event):
+         name = event[0]
+         params = event[1:]
+         if not all([0 <= params[i] < self.event_parameters[p] for i, p in enumerate(self.events[name])]):
+             return []
+         tokens = [self.event_ids[name]] + [self.parameter_ids[p][params[i]]
+                                            for i, p in enumerate(self.events[name])]
+         tokens += [self.pad_id] * (self.max_token_seq - len(tokens))
+         return tokens
+
+     def tokens2event(self, tokens):
+         if tokens[0] not in self.id_events:
+             return []
+         name = self.id_events[tokens[0]]
+         if len(tokens) <= len(self.events[name]):
+             return []
+         params = tokens[1:]
+         params = [params[i] - self.parameter_ids[p][0] for i, p in enumerate(self.events[name])]
+         if not all([0 <= params[i] < self.event_parameters[p] for i, p in enumerate(self.events[name])]):
+             return []
+         event = [name] + params
+         return event
+
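+     # Every row is [event_id, parameter tokens..., padding] with fixed width
+     # max_token_seq; tokens2event is the inverse and returns [] for rows whose
+     # type or parameters fall outside the vocabulary. A small sketch (parameter
+     # order here assumes self.events["set_tempo"] = time1, time2, track, bpm):
+     #
+     #   row = tokenizer.event2tokens(["set_tempo", 0, 0, 0, 120])
+     #   assert tokenizer.tokens2event(row)[0] == "set_tempo"
+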
+     def detokenize(self, midi_seq):
+         ticks_per_beat = 480
+         tracks_dict = {}
+         t1 = 0
+         for tokens in midi_seq:
+             if tokens[0] in self.id_events:
+                 event = self.tokens2event(tokens)
+                 if not event:
+                     continue
+                 name = event[0]
+                 t1 += event[1]
+                 t = t1 * 16 + event[2]
+                 t = int(t * ticks_per_beat / 16)
+                 track_idx = event[3]
+                 event_new = [name, t]
+                 if name == "note":
+                     c, p, v, d = event[4:]
+                     d = int(d * ticks_per_beat / 16)
+                     event_new += [d, c, p, v]
+                 elif name == "control_change" or name == "patch_change":
+                     event_new += event[4:]
+                 elif name == "set_tempo":
+                     event_new += [self.bpm2tempo(event[4])]
+                 elif name == "time_signature":
+                     nn, dd = event[4:]
+                     nn += 1
+                     dd += 1
+                     event_new += [nn, dd, 24, 8]  # usually cc, bb = 24, 8
+                 elif name == "key_signature":
+                     sf, mi = event[4:]
+                     sf -= 7
+                     event_new += [sf, mi]
+                 else:  # should not go here
+                     continue
+                 if track_idx not in tracks_dict:
+                     tracks_dict[track_idx] = []
+                 tracks_dict[track_idx].append(event_new)
+         tracks = [tr for idx, tr in sorted(list(tracks_dict.items()), key=lambda it: it[0])]
+
+         for i in range(len(tracks)):  # to eliminate note overlap
+             track = tracks[i]
+             track = sorted(track, key=lambda e: e[1])
+             last_note_t = {}
+             zero_len_notes = []
+             for e in reversed(track):
+                 if e[0] == "note":
+                     t, d, c, p = e[1:5]
+                     key = (c, p)
+                     if key in last_note_t:
+                         d = min(d, max(last_note_t[key] - t, 0))
+                     last_note_t[key] = t
+                     e[2] = d
+                     if d == 0:
+                         zero_len_notes.append(e)
+             for e in zero_len_notes:
+                 track.remove(e)
+             tracks[i] = track
+         return [ticks_per_beat, *tracks]
+
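+     # detokenize returns [ticks_per_beat, track_0, track_1, ...] in MIDI.py's
+     # score format, so it can be written straight back out, e.g. (hypothetical
+     # path):
+     #
+     #   with open("out.mid", "wb") as f:
+     #       f.write(MIDI.score2midi(tokenizer.detokenize(seq)))
+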
+     def midi2img(self, midi_score):
+         ticks_per_beat = midi_score[0]
+         notes = []
+         max_time = 1
+         track_num = len(midi_score[1:])
+         for track_idx, track in enumerate(midi_score[1:]):
+             for event in track:
+                 t = round(16 * event[1] / ticks_per_beat)
+                 if event[0] == "note":
+                     d = max(1, round(16 * event[2] / ticks_per_beat))
+                     c, p = event[3:5]
+                     max_time = max(max_time, t + d + 1)
+                     notes.append((track_idx, c, p, t, d))
+         img = np.zeros((128, max_time, 3), dtype=np.uint8)
+         colors = {(i, j): np.random.randint(50, 256, 3) for i in range(track_num) for j in range(16)}
+         for note in notes:
+             tr, c, p, t, d = note
+             img[p, t: t + d] = colors[(tr, c)]
+         img = PIL.Image.fromarray(np.flip(img, 0))
+         return img
+
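+     # Renders a quick piano-roll preview: the y axis is MIDI pitch (flipped so
+     # high notes are at the top), the x axis is time at 16 columns per beat,
+     # with one random colour per (track, channel) pair, e.g. (hypothetical
+     # path):
+     #
+     #   tokenizer.midi2img(score).save("roll.png")
+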
+     def augment(self, midi_seq, max_pitch_shift=4, max_vel_shift=10, max_cc_val_shift=10, max_bpm_shift=10,
+                 max_track_shift=0, max_channel_shift=16):
+         pitch_shift = random.randint(-max_pitch_shift, max_pitch_shift)
+         vel_shift = random.randint(-max_vel_shift, max_vel_shift)
+         cc_val_shift = random.randint(-max_cc_val_shift, max_cc_val_shift)
+         bpm_shift = random.randint(-max_bpm_shift, max_bpm_shift)
+         track_shift = random.randint(0, max_track_shift)
+         channel_shift = random.randint(0, max_channel_shift)
+         midi_seq_new = []
+         key_signature_tokens = []
+         track_to_channels = {}
+         for tokens in midi_seq:
+             tokens_new = [*tokens]
+             if tokens[0] in self.id_events:
+                 name = self.id_events[tokens[0]]
+                 for i, pn in enumerate(self.events[name]):
+                     if pn == "track":
+                         tr = tokens[1 + i] - self.parameter_ids[pn][0]
+                         tr += track_shift
+                         tr = tr % self.event_parameters[pn]
+                         tokens_new[1 + i] = self.parameter_ids[pn][tr]
+                     elif pn == "channel":
+                         c = tokens[1 + i] - self.parameter_ids[pn][0]
+                         c0 = c
+                         c += channel_shift
+                         c = c % self.event_parameters[pn]
+                         if c0 == 9:  # keep the drum channel in place
+                             c = 9
+                         elif c == 9:
+                             c = (9 + channel_shift) % self.event_parameters[pn]
+                         tokens_new[1 + i] = self.parameter_ids[pn][c]
+
+                 if name == "note":
+                     tr = tokens[3] - self.parameter_ids["track"][0]
+                     c = tokens[4] - self.parameter_ids["channel"][0]
+                     p = tokens[5] - self.parameter_ids["pitch"][0]
+                     v = tokens[6] - self.parameter_ids["velocity"][0]
+                     if c != 9:  # no pitch shift for drums
+                         p += pitch_shift
+                         if not 0 <= p < 128:
+                             return midi_seq
+                     v += vel_shift
+                     v = max(1, min(127, v))
+                     tokens_new[5] = self.parameter_ids["pitch"][p]
+                     tokens_new[6] = self.parameter_ids["velocity"][v]
+                     track_to_channels.setdefault(tr, [])
+                     cs = track_to_channels[tr]
+                     if c not in cs:
+                         cs.append(c)
+                 elif name == "control_change":
+                     cc = tokens[5] - self.parameter_ids["controller"][0]
+                     val = tokens[6] - self.parameter_ids["value"][0]
+                     if cc in [1, 2, 7, 11]:  # modulation, breath, volume, expression
+                         val += cc_val_shift
+                         val = max(1, min(127, val))
+                     tokens_new[6] = self.parameter_ids["value"][val]
+                 elif name == "set_tempo":
+                     bpm = tokens[4] - self.parameter_ids["bpm"][0]
+                     bpm += bpm_shift
+                     bpm = max(1, min(383, bpm))
+                     tokens_new[4] = self.parameter_ids["bpm"][bpm]
+                 elif name == "key_signature":
+                     sf = tokens[4] - self.parameter_ids["sf"][0]
+                     mi = tokens[5] - self.parameter_ids["mi"][0]
+                     sf -= 7
+                     k = self.sf2key(sf)
+                     k = (k + pitch_shift) % 12
+                     sf = self.key2sf(k, mi)
+                     sf += 7
+                     tokens_new[4] = self.parameter_ids["sf"][sf]
+                     tokens_new[5] = self.parameter_ids["mi"][mi]
+                     key_signature_tokens.append(tokens_new)
+             midi_seq_new.append(tokens_new)
+         for tokens in key_signature_tokens:
+             tr = tokens[3] - self.parameter_ids["track"][0]
+             if tr in track_to_channels:
+                 cs = track_to_channels[tr]
+                 if len(cs) == 1 and cs[0] == 9:
+                     tokens[4] = self.parameter_ids["sf"][7]  # sf=0
+         return midi_seq_new
+
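+     # Pitch, velocity, CC-value, tempo, track and channel offsets are drawn
+     # once per call, so one augmentation is applied consistently across the
+     # whole sequence; if the pitch shift would push any note outside 0..127,
+     # the original sequence is returned unchanged, e.g.
+     #
+     #   seq_aug = tokenizer.augment(seq, max_pitch_shift=2)
+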
+     def check_quality(self, midi_seq, alignment_min=0.3, tonality_min=0.8, piano_max=0.7, notes_bandwidth_min=3,
+                       notes_density_max=50, notes_density_min=2.5, total_notes_max=20000, total_notes_min=256,
+                       note_window_size=16):
+         total_notes = 0
+         channels = []
+         time_hist = [0] * 16
+         note_windows = {}
+         notes_sametime = []
+         notes_density_list = []
+         tonality_list = []
+         notes_bandwidth_list = []
+         instruments = {}
+         piano_channels = []
+         abs_t1 = 0
+         last_t = 0
+         for tokens in midi_seq:
+             event = self.tokens2event(tokens)
+             if not event:
+                 continue
+             t1, t2, tr = event[1:4]
+             abs_t1 += t1
+             t = abs_t1 * 16 + t2
+             c = None
+             if event[0] == "note":
+                 c, p, v, d = event[4:]
+                 total_notes += 1
+                 time_hist[t2] += 1
+                 if c != 9:  # ignore drum channel
+                     if c not in instruments:
+                         instruments[c] = 0
+                         if c not in piano_channels:
+                             piano_channels.append(c)
+                     note_windows.setdefault(abs_t1 // note_window_size, []).append(p)
+                     if last_t != t:
+                         notes_sametime = [(et, p_) for et, p_ in notes_sametime if et > last_t]
+                         notes_sametime_p = [p_ for _, p_ in notes_sametime]
+                         if len(notes_sametime) > 0:
+                             notes_bandwidth_list.append(max(notes_sametime_p) - min(notes_sametime_p))
+                     notes_sametime.append((t + d - 1, p))
+             elif event[0] == "patch_change":
+                 c, p = event[4:]
+                 instruments[c] = p
+                 if p == 0 and c not in piano_channels:
+                     piano_channels.append(c)
+             if c is not None and c not in channels:
+                 channels.append(c)
+             last_t = t
+         reasons = []
+         if total_notes < total_notes_min:
+             reasons.append("total_min")
+         if total_notes > total_notes_max:
+             reasons.append("total_max")
+         if len(note_windows) == 0 and total_notes > 0:
+             reasons.append("drum_only")
+         if reasons:
+             return False, reasons
+         time_hist = sorted(time_hist, reverse=True)
+         alignment = sum(time_hist[:2]) / total_notes
+         for notes in note_windows.values():
+             key_hist = [0] * 12
+             for p in notes:
+                 key_hist[p % 12] += 1
+             key_hist = sorted(key_hist, reverse=True)
+             tonality_list.append(sum(key_hist[:7]) / len(notes))
+             notes_density_list.append(len(notes) / note_window_size)
+         tonality_list = sorted(tonality_list)
+         tonality = sum(tonality_list) / len(tonality_list)
+         notes_bandwidth = sum(notes_bandwidth_list) / len(notes_bandwidth_list) if notes_bandwidth_list else 0
+         notes_density = max(notes_density_list) if notes_density_list else 0
+         piano_ratio = len(piano_channels) / len(channels)
+         if len(channels) <= 3:  # ignore the piano threshold for piano solo MIDIs
+             piano_max = 1
+         if alignment < alignment_min:  # check whether the notes align to the bars (some MIDI files are live recordings)
+             reasons.append("alignment")
+         if tonality < tonality_min:  # check whether the music is tonal
+             reasons.append("tonality")
+         if notes_bandwidth < notes_bandwidth_min:  # check whether the music is a melodic line only
+             reasons.append("bandwidth")
+         if not notes_density_min < notes_density < notes_density_max:
+             reasons.append("density")
+         if piano_ratio > piano_max:  # check whether most instruments are piano (some MIDI files lack correct instrument assignments)
+             reasons.append("piano")
+         return not reasons, reasons
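+
+     # check_quality returns (passed, reasons): `reasons` names every failed
+     # heuristic, e.g.
+     #
+     #   ok, reasons = tokenizer.check_quality(seq)
+     #   if not ok:
+     #       print("rejected:", reasons)  # e.g. ['alignment', 'density']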
+
+
+ class MIDITokenizer:
+     def __new__(cls, version="v2"):
+         if version == "v1":
+             return MIDITokenizerV1()
+         elif version == "v2":
+             return MIDITokenizerV2()
+         else:
+             raise ValueError(f"Unsupported version: {version}")
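+
+ # Factory sketch: the version string selects the concrete tokenizer, e.g.
+ #
+ #   tokenizer = MIDITokenizer("v2")
+ #   assert isinstance(tokenizer, MIDITokenizerV2)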
packages.txt ADDED
@@ -0,0 +1 @@
+ fluidsynth
requirements.txt ADDED
@@ -0,0 +1,11 @@
+ --extra-index-url https://download.pytorch.org/whl/cu124
+ Pillow
+ numpy
+ torch
+ onnxruntime-gpu
+ peft>=0.13.0
+ transformers>=4.36
+ gradio==5.3.0
+ pyfluidsynth
+ tqdm
+ huggingface_hub