Spaces:
Running
Running
update MIDITokenizer
Browse files- midi_tokenizer.py +50 -14
midi_tokenizer.py
CHANGED
|
@@ -200,7 +200,7 @@ class MIDITokenizerV1:
|
|
| 200 |
note_tracks = channel_note_tracks[c]
|
| 201 |
if len(note_tracks) != 0 and track_idx not in note_tracks:
|
| 202 |
track_idx = channel_note_tracks[c][0]
|
| 203 |
-
new_track_idx = tr_map
|
| 204 |
event[3] = new_track_idx
|
| 205 |
if name == "patch_change" and event[4] not in patch_channels:
|
| 206 |
patch_channels.append(event[4])
|
|
@@ -235,7 +235,9 @@ class MIDITokenizerV1:
|
|
| 235 |
else:
|
| 236 |
if event[0] == "note":
|
| 237 |
notes_in_setup = True
|
| 238 |
-
|
|
|
|
|
|
|
| 239 |
setup_events[key] = new_event
|
| 240 |
|
| 241 |
last_t1 = 0
|
|
@@ -551,6 +553,8 @@ class MIDITokenizerV2:
|
|
| 551 |
def detect_key_signature(key_hist, threshold=0.7):
|
| 552 |
if len(key_hist) != 12:
|
| 553 |
return None
|
|
|
|
|
|
|
| 554 |
p = sum(sorted(key_hist, reverse=True)[:7]) / sum(key_hist)
|
| 555 |
if p < threshold:
|
| 556 |
return None
|
|
@@ -590,7 +594,7 @@ class MIDITokenizerV2:
|
|
| 590 |
empty_channels = [True] * 16
|
| 591 |
channel_note_tracks = {i: list() for i in range(16)}
|
| 592 |
note_key_hist = [0]*12
|
| 593 |
-
|
| 594 |
track_to_channels = {}
|
| 595 |
for track_idx, track in enumerate(midi_score[1:129]):
|
| 596 |
last_notes = {}
|
|
@@ -661,11 +665,11 @@ class MIDITokenizerV2:
|
|
| 661 |
sf, mi = event[2:]
|
| 662 |
if not (-7 <= sf <= 7 and 0 <= mi <= 1): # invalid
|
| 663 |
continue
|
| 664 |
-
key_sig_num += 1
|
| 665 |
sf += 7
|
| 666 |
new_event += [sf, mi]
|
|
|
|
| 667 |
|
| 668 |
-
if name
|
| 669 |
key = tuple(new_event[:-2])
|
| 670 |
else:
|
| 671 |
key = tuple(new_event[:-1])
|
|
@@ -731,7 +735,9 @@ class MIDITokenizerV2:
|
|
| 731 |
|
| 732 |
empty_channels = [channels_map[c] for c in empty_channels]
|
| 733 |
track_idx_dict = {}
|
|
|
|
| 734 |
key_signature_to_add = []
|
|
|
|
| 735 |
for event in event_list:
|
| 736 |
name = event[0]
|
| 737 |
track_idx = event[3]
|
|
@@ -748,14 +754,23 @@ class MIDITokenizerV2:
|
|
| 748 |
for c, tr_map in track_idx_map.items():
|
| 749 |
if track_idx in tr_map:
|
| 750 |
new_track_idx = tr_map[track_idx]
|
|
|
|
| 751 |
new_channel_track_idx = (c, new_track_idx)
|
|
|
|
|
|
|
| 752 |
if new_channel_track_idx not in new_channel_track_idxs:
|
| 753 |
new_channel_track_idxs.append(new_channel_track_idx)
|
|
|
|
| 754 |
if len(new_channel_track_idxs) == 0:
|
| 755 |
-
event[3]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 756 |
continue
|
| 757 |
c, nt = new_channel_track_idxs[0]
|
| 758 |
event[3] = nt
|
|
|
|
| 759 |
if c == 9:
|
| 760 |
event[4] = 7 # sf=0
|
| 761 |
for c, nt in new_channel_track_idxs[1:]:
|
|
@@ -763,6 +778,7 @@ class MIDITokenizerV2:
|
|
| 763 |
new_event[3] = nt
|
| 764 |
if c == 9:
|
| 765 |
new_event[4] = 7 # sf=0
|
|
|
|
| 766 |
key_signature_to_add.append(new_event)
|
| 767 |
elif name == "control_change" or name == "patch_change":
|
| 768 |
c = event[4]
|
|
@@ -772,10 +788,12 @@ class MIDITokenizerV2:
|
|
| 772 |
note_tracks = channel_note_tracks[c]
|
| 773 |
if len(note_tracks) != 0 and track_idx not in note_tracks:
|
| 774 |
track_idx = channel_note_tracks[c][0]
|
| 775 |
-
new_track_idx = tr_map
|
| 776 |
event[3] = new_track_idx
|
| 777 |
if name == "patch_change" and event[4] not in patch_channels:
|
| 778 |
patch_channels.append(event[4])
|
|
|
|
|
|
|
| 779 |
event_list += key_signature_to_add
|
| 780 |
track_to_channels ={}
|
| 781 |
for c, tr_map in track_idx_map.items():
|
|
@@ -793,16 +811,31 @@ class MIDITokenizerV2:
|
|
| 793 |
if c not in patch_channels and c in track_idx_dict:
|
| 794 |
event_list.append(["patch_change", 0, 0, track_idx_dict[c], c, 0])
|
| 795 |
|
| 796 |
-
if
|
| 797 |
-
# detect key signature
|
| 798 |
root_key = self.detect_key_signature(note_key_hist)
|
| 799 |
if root_key is not None:
|
| 800 |
sf = self.key2sf(root_key, 0)
|
| 801 |
# print("detect_key_signature",sf)
|
| 802 |
-
|
| 803 |
-
|
| 804 |
-
|
| 805 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 806 |
|
| 807 |
events_name_order = ["time_signature", "key_signature", "set_tempo", "patch_change", "control_change", "note"]
|
| 808 |
events_name_order = {name: i for i, name in enumerate(events_name_order)}
|
|
@@ -830,7 +863,10 @@ class MIDITokenizerV2:
|
|
| 830 |
else:
|
| 831 |
if event[0] == "note":
|
| 832 |
notes_in_setup = True
|
| 833 |
-
|
|
|
|
|
|
|
|
|
|
| 834 |
setup_events[key] = new_event
|
| 835 |
|
| 836 |
last_t1 = 0
|
|
|
|
| 200 |
note_tracks = channel_note_tracks[c]
|
| 201 |
if len(note_tracks) != 0 and track_idx not in note_tracks:
|
| 202 |
track_idx = channel_note_tracks[c][0]
|
| 203 |
+
new_track_idx = tr_map[track_idx]
|
| 204 |
event[3] = new_track_idx
|
| 205 |
if name == "patch_change" and event[4] not in patch_channels:
|
| 206 |
patch_channels.append(event[4])
|
|
|
|
| 235 |
else:
|
| 236 |
if event[0] == "note":
|
| 237 |
notes_in_setup = True
|
| 238 |
+
key = tuple([event[0]] + event[3:-2])
|
| 239 |
+
else:
|
| 240 |
+
key = tuple([event[0]] + event[3:-1])
|
| 241 |
setup_events[key] = new_event
|
| 242 |
|
| 243 |
last_t1 = 0
|
|
|
|
| 553 |
def detect_key_signature(key_hist, threshold=0.7):
|
| 554 |
if len(key_hist) != 12:
|
| 555 |
return None
|
| 556 |
+
if sum(key_hist) == 0:
|
| 557 |
+
return None
|
| 558 |
p = sum(sorted(key_hist, reverse=True)[:7]) / sum(key_hist)
|
| 559 |
if p < threshold:
|
| 560 |
return None
|
|
|
|
| 594 |
empty_channels = [True] * 16
|
| 595 |
channel_note_tracks = {i: list() for i in range(16)}
|
| 596 |
note_key_hist = [0]*12
|
| 597 |
+
key_sigs = []
|
| 598 |
track_to_channels = {}
|
| 599 |
for track_idx, track in enumerate(midi_score[1:129]):
|
| 600 |
last_notes = {}
|
|
|
|
| 665 |
sf, mi = event[2:]
|
| 666 |
if not (-7 <= sf <= 7 and 0 <= mi <= 1): # invalid
|
| 667 |
continue
|
|
|
|
| 668 |
sf += 7
|
| 669 |
new_event += [sf, mi]
|
| 670 |
+
key_sigs.append(new_event)
|
| 671 |
|
| 672 |
+
if name in ["note", "time_signature", "key_signature"]:
|
| 673 |
key = tuple(new_event[:-2])
|
| 674 |
else:
|
| 675 |
key = tuple(new_event[:-1])
|
|
|
|
| 735 |
|
| 736 |
empty_channels = [channels_map[c] for c in empty_channels]
|
| 737 |
track_idx_dict = {}
|
| 738 |
+
key_sigs = []
|
| 739 |
key_signature_to_add = []
|
| 740 |
+
key_signature_to_remove = []
|
| 741 |
for event in event_list:
|
| 742 |
name = event[0]
|
| 743 |
track_idx = event[3]
|
|
|
|
| 754 |
for c, tr_map in track_idx_map.items():
|
| 755 |
if track_idx in tr_map:
|
| 756 |
new_track_idx = tr_map[track_idx]
|
| 757 |
+
c = channels_map[c]
|
| 758 |
new_channel_track_idx = (c, new_track_idx)
|
| 759 |
+
if new_track_idx == 0:
|
| 760 |
+
continue
|
| 761 |
if new_channel_track_idx not in new_channel_track_idxs:
|
| 762 |
new_channel_track_idxs.append(new_channel_track_idx)
|
| 763 |
+
|
| 764 |
if len(new_channel_track_idxs) == 0:
|
| 765 |
+
if event[3] == 0: # keep key_signature on track 0 (meta)
|
| 766 |
+
key_sigs.append(event)
|
| 767 |
+
continue
|
| 768 |
+
event[3] = -1 # avoid remove same event
|
| 769 |
+
key_signature_to_remove.append(event) # empty track
|
| 770 |
continue
|
| 771 |
c, nt = new_channel_track_idxs[0]
|
| 772 |
event[3] = nt
|
| 773 |
+
key_sigs.append(event)
|
| 774 |
if c == 9:
|
| 775 |
event[4] = 7 # sf=0
|
| 776 |
for c, nt in new_channel_track_idxs[1:]:
|
|
|
|
| 778 |
new_event[3] = nt
|
| 779 |
if c == 9:
|
| 780 |
new_event[4] = 7 # sf=0
|
| 781 |
+
key_sigs.append(new_event)
|
| 782 |
key_signature_to_add.append(new_event)
|
| 783 |
elif name == "control_change" or name == "patch_change":
|
| 784 |
c = event[4]
|
|
|
|
| 788 |
note_tracks = channel_note_tracks[c]
|
| 789 |
if len(note_tracks) != 0 and track_idx not in note_tracks:
|
| 790 |
track_idx = channel_note_tracks[c][0]
|
| 791 |
+
new_track_idx = tr_map[track_idx]
|
| 792 |
event[3] = new_track_idx
|
| 793 |
if name == "patch_change" and event[4] not in patch_channels:
|
| 794 |
patch_channels.append(event[4])
|
| 795 |
+
for key_sig in key_signature_to_remove:
|
| 796 |
+
event_list.remove(key_sig)
|
| 797 |
event_list += key_signature_to_add
|
| 798 |
track_to_channels ={}
|
| 799 |
for c, tr_map in track_idx_map.items():
|
|
|
|
| 811 |
if c not in patch_channels and c in track_idx_dict:
|
| 812 |
event_list.append(["patch_change", 0, 0, track_idx_dict[c], c, 0])
|
| 813 |
|
| 814 |
+
if len(key_sigs) == 0 or all([key_sig[4]==7 for key_sig in key_sigs]):
|
| 815 |
+
# detect key signature or fix the default key signature
|
| 816 |
root_key = self.detect_key_signature(note_key_hist)
|
| 817 |
if root_key is not None:
|
| 818 |
sf = self.key2sf(root_key, 0)
|
| 819 |
# print("detect_key_signature",sf)
|
| 820 |
+
if len(key_sigs) == 0:
|
| 821 |
+
for tr, cs in track_to_channels.items():
|
| 822 |
+
if remap_track_channel and tr == 0:
|
| 823 |
+
continue
|
| 824 |
+
new_event = ["key_signature", 0, 0, tr, (0 if (len(cs) == 1 and cs[0] == 9) else sf) + 7, 0]
|
| 825 |
+
event_list.append(new_event)
|
| 826 |
+
else:
|
| 827 |
+
for key_sig in key_sigs:
|
| 828 |
+
tr = key_sig[3]
|
| 829 |
+
if tr in track_to_channels:
|
| 830 |
+
cs = track_to_channels[tr]
|
| 831 |
+
if len(cs) == 1 and cs[0] == 9:
|
| 832 |
+
continue
|
| 833 |
+
key_sig[4] = sf + 7
|
| 834 |
+
key_sig[5] = 0
|
| 835 |
+
else:
|
| 836 |
+
# remove default key signature
|
| 837 |
+
for key_sig in key_sigs:
|
| 838 |
+
event_list.remove(key_sig)
|
| 839 |
|
| 840 |
events_name_order = ["time_signature", "key_signature", "set_tempo", "patch_change", "control_change", "note"]
|
| 841 |
events_name_order = {name: i for i, name in enumerate(events_name_order)}
|
|
|
|
| 863 |
else:
|
| 864 |
if event[0] == "note":
|
| 865 |
notes_in_setup = True
|
| 866 |
+
if event[0] in ["note", "time_signature", "key_signature"]:
|
| 867 |
+
key = tuple([event[0]]+event[3:-2])
|
| 868 |
+
else:
|
| 869 |
+
key = tuple([event[0]]+event[3:-1])
|
| 870 |
setup_events[key] = new_event
|
| 871 |
|
| 872 |
last_t1 = 0
|