Spaces: Running on Zero
update MIDITokenizer
Browse files - midi_tokenizer.py +50 -14
midi_tokenizer.py
CHANGED
@@ -200,7 +200,7 @@ class MIDITokenizerV1:
|
|
200 |
note_tracks = channel_note_tracks[c]
|
201 |
if len(note_tracks) != 0 and track_idx not in note_tracks:
|
202 |
track_idx = channel_note_tracks[c][0]
|
203 |
-
new_track_idx = tr_map
|
204 |
event[3] = new_track_idx
|
205 |
if name == "patch_change" and event[4] not in patch_channels:
|
206 |
patch_channels.append(event[4])
|
@@ -235,7 +235,9 @@ class MIDITokenizerV1:
|
|
235 |
else:
|
236 |
if event[0] == "note":
|
237 |
notes_in_setup = True
|
238 |
-
|
|
|
|
|
239 |
setup_events[key] = new_event
|
240 |
|
241 |
last_t1 = 0
|
@@ -551,6 +553,8 @@ class MIDITokenizerV2:
|
|
551 |
def detect_key_signature(key_hist, threshold=0.7):
|
552 |
if len(key_hist) != 12:
|
553 |
return None
|
|
|
|
|
554 |
p = sum(sorted(key_hist, reverse=True)[:7]) / sum(key_hist)
|
555 |
if p < threshold:
|
556 |
return None
|
@@ -590,7 +594,7 @@ class MIDITokenizerV2:
|
|
590 |
empty_channels = [True] * 16
|
591 |
channel_note_tracks = {i: list() for i in range(16)}
|
592 |
note_key_hist = [0]*12
|
593 |
-
|
594 |
track_to_channels = {}
|
595 |
for track_idx, track in enumerate(midi_score[1:129]):
|
596 |
last_notes = {}
|
@@ -661,11 +665,11 @@ class MIDITokenizerV2:
|
|
661 |
sf, mi = event[2:]
|
662 |
if not (-7 <= sf <= 7 and 0 <= mi <= 1): # invalid
|
663 |
continue
|
664 |
-
key_sig_num += 1
|
665 |
sf += 7
|
666 |
new_event += [sf, mi]
|
|
|
667 |
|
668 |
-
if name
|
669 |
key = tuple(new_event[:-2])
|
670 |
else:
|
671 |
key = tuple(new_event[:-1])
|
@@ -731,7 +735,9 @@ class MIDITokenizerV2:
|
|
731 |
|
732 |
empty_channels = [channels_map[c] for c in empty_channels]
|
733 |
track_idx_dict = {}
|
|
|
734 |
key_signature_to_add = []
|
|
|
735 |
for event in event_list:
|
736 |
name = event[0]
|
737 |
track_idx = event[3]
|
@@ -748,14 +754,23 @@ class MIDITokenizerV2:
|
|
748 |
for c, tr_map in track_idx_map.items():
|
749 |
if track_idx in tr_map:
|
750 |
new_track_idx = tr_map[track_idx]
|
|
|
751 |
new_channel_track_idx = (c, new_track_idx)
|
|
|
|
|
752 |
if new_channel_track_idx not in new_channel_track_idxs:
|
753 |
new_channel_track_idxs.append(new_channel_track_idx)
|
|
|
754 |
if len(new_channel_track_idxs) == 0:
|
755 |
-
event[3]
|
|
|
|
|
|
|
|
|
756 |
continue
|
757 |
c, nt = new_channel_track_idxs[0]
|
758 |
event[3] = nt
|
|
|
759 |
if c == 9:
|
760 |
event[4] = 7 # sf=0
|
761 |
for c, nt in new_channel_track_idxs[1:]:
|
@@ -763,6 +778,7 @@ class MIDITokenizerV2:
|
|
763 |
new_event[3] = nt
|
764 |
if c == 9:
|
765 |
new_event[4] = 7 # sf=0
|
|
|
766 |
key_signature_to_add.append(new_event)
|
767 |
elif name == "control_change" or name == "patch_change":
|
768 |
c = event[4]
|
@@ -772,10 +788,12 @@ class MIDITokenizerV2:
|
|
772 |
note_tracks = channel_note_tracks[c]
|
773 |
if len(note_tracks) != 0 and track_idx not in note_tracks:
|
774 |
track_idx = channel_note_tracks[c][0]
|
775 |
-
new_track_idx = tr_map
|
776 |
event[3] = new_track_idx
|
777 |
if name == "patch_change" and event[4] not in patch_channels:
|
778 |
patch_channels.append(event[4])
|
|
|
|
|
779 |
event_list += key_signature_to_add
|
780 |
track_to_channels ={}
|
781 |
for c, tr_map in track_idx_map.items():
|
@@ -793,16 +811,31 @@ class MIDITokenizerV2:
|
|
793 |
if c not in patch_channels and c in track_idx_dict:
|
794 |
event_list.append(["patch_change", 0, 0, track_idx_dict[c], c, 0])
|
795 |
|
796 |
-
if
|
797 |
-
# detect key signature
|
798 |
root_key = self.detect_key_signature(note_key_hist)
|
799 |
if root_key is not None:
|
800 |
sf = self.key2sf(root_key, 0)
|
801 |
# print("detect_key_signature",sf)
|
802 |
-
|
803 |
-
|
804 |
-
|
805 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
806 |
|
807 |
events_name_order = ["time_signature", "key_signature", "set_tempo", "patch_change", "control_change", "note"]
|
808 |
events_name_order = {name: i for i, name in enumerate(events_name_order)}
|
@@ -830,7 +863,10 @@ class MIDITokenizerV2:
|
|
830 |
else:
|
831 |
if event[0] == "note":
|
832 |
notes_in_setup = True
|
833 |
-
|
|
|
|
|
|
|
834 |
setup_events[key] = new_event
|
835 |
|
836 |
last_t1 = 0
|
|
|
200 |
note_tracks = channel_note_tracks[c]
|
201 |
if len(note_tracks) != 0 and track_idx not in note_tracks:
|
202 |
track_idx = channel_note_tracks[c][0]
|
203 |
+
new_track_idx = tr_map[track_idx]
|
204 |
event[3] = new_track_idx
|
205 |
if name == "patch_change" and event[4] not in patch_channels:
|
206 |
patch_channels.append(event[4])
|
|
|
235 |
else:
|
236 |
if event[0] == "note":
|
237 |
notes_in_setup = True
|
238 |
+
key = tuple([event[0]] + event[3:-2])
|
239 |
+
else:
|
240 |
+
key = tuple([event[0]] + event[3:-1])
|
241 |
setup_events[key] = new_event
|
242 |
|
243 |
last_t1 = 0
|
|
|
553 |
def detect_key_signature(key_hist, threshold=0.7):
|
554 |
if len(key_hist) != 12:
|
555 |
return None
|
556 |
+
if sum(key_hist) == 0:
|
557 |
+
return None
|
558 |
p = sum(sorted(key_hist, reverse=True)[:7]) / sum(key_hist)
|
559 |
if p < threshold:
|
560 |
return None
|
|
|
594 |
empty_channels = [True] * 16
|
595 |
channel_note_tracks = {i: list() for i in range(16)}
|
596 |
note_key_hist = [0]*12
|
597 |
+
key_sigs = []
|
598 |
track_to_channels = {}
|
599 |
for track_idx, track in enumerate(midi_score[1:129]):
|
600 |
last_notes = {}
|
|
|
665 |
sf, mi = event[2:]
|
666 |
if not (-7 <= sf <= 7 and 0 <= mi <= 1): # invalid
|
667 |
continue
|
|
|
668 |
sf += 7
|
669 |
new_event += [sf, mi]
|
670 |
+
key_sigs.append(new_event)
|
671 |
|
672 |
+
if name in ["note", "time_signature", "key_signature"]:
|
673 |
key = tuple(new_event[:-2])
|
674 |
else:
|
675 |
key = tuple(new_event[:-1])
|
|
|
735 |
|
736 |
empty_channels = [channels_map[c] for c in empty_channels]
|
737 |
track_idx_dict = {}
|
738 |
+
key_sigs = []
|
739 |
key_signature_to_add = []
|
740 |
+
key_signature_to_remove = []
|
741 |
for event in event_list:
|
742 |
name = event[0]
|
743 |
track_idx = event[3]
|
|
|
754 |
for c, tr_map in track_idx_map.items():
|
755 |
if track_idx in tr_map:
|
756 |
new_track_idx = tr_map[track_idx]
|
757 |
+
c = channels_map[c]
|
758 |
new_channel_track_idx = (c, new_track_idx)
|
759 |
+
if new_track_idx == 0:
|
760 |
+
continue
|
761 |
if new_channel_track_idx not in new_channel_track_idxs:
|
762 |
new_channel_track_idxs.append(new_channel_track_idx)
|
763 |
+
|
764 |
if len(new_channel_track_idxs) == 0:
|
765 |
+
if event[3] == 0: # keep key_signature on track 0 (meta)
|
766 |
+
key_sigs.append(event)
|
767 |
+
continue
|
768 |
+
event[3] = -1 # avoid remove same event
|
769 |
+
key_signature_to_remove.append(event) # empty track
|
770 |
continue
|
771 |
c, nt = new_channel_track_idxs[0]
|
772 |
event[3] = nt
|
773 |
+
key_sigs.append(event)
|
774 |
if c == 9:
|
775 |
event[4] = 7 # sf=0
|
776 |
for c, nt in new_channel_track_idxs[1:]:
|
|
|
778 |
new_event[3] = nt
|
779 |
if c == 9:
|
780 |
new_event[4] = 7 # sf=0
|
781 |
+
key_sigs.append(new_event)
|
782 |
key_signature_to_add.append(new_event)
|
783 |
elif name == "control_change" or name == "patch_change":
|
784 |
c = event[4]
|
|
|
788 |
note_tracks = channel_note_tracks[c]
|
789 |
if len(note_tracks) != 0 and track_idx not in note_tracks:
|
790 |
track_idx = channel_note_tracks[c][0]
|
791 |
+
new_track_idx = tr_map[track_idx]
|
792 |
event[3] = new_track_idx
|
793 |
if name == "patch_change" and event[4] not in patch_channels:
|
794 |
patch_channels.append(event[4])
|
795 |
+
for key_sig in key_signature_to_remove:
|
796 |
+
event_list.remove(key_sig)
|
797 |
event_list += key_signature_to_add
|
798 |
track_to_channels ={}
|
799 |
for c, tr_map in track_idx_map.items():
|
|
|
811 |
if c not in patch_channels and c in track_idx_dict:
|
812 |
event_list.append(["patch_change", 0, 0, track_idx_dict[c], c, 0])
|
813 |
|
814 |
+
if len(key_sigs) == 0 or all([key_sig[4]==7 for key_sig in key_sigs]):
|
815 |
+
# detect key signature or fix the default key signature
|
816 |
root_key = self.detect_key_signature(note_key_hist)
|
817 |
if root_key is not None:
|
818 |
sf = self.key2sf(root_key, 0)
|
819 |
# print("detect_key_signature",sf)
|
820 |
+
if len(key_sigs) == 0:
|
821 |
+
for tr, cs in track_to_channels.items():
|
822 |
+
if remap_track_channel and tr == 0:
|
823 |
+
continue
|
824 |
+
new_event = ["key_signature", 0, 0, tr, (0 if (len(cs) == 1 and cs[0] == 9) else sf) + 7, 0]
|
825 |
+
event_list.append(new_event)
|
826 |
+
else:
|
827 |
+
for key_sig in key_sigs:
|
828 |
+
tr = key_sig[3]
|
829 |
+
if tr in track_to_channels:
|
830 |
+
cs = track_to_channels[tr]
|
831 |
+
if len(cs) == 1 and cs[0] == 9:
|
832 |
+
continue
|
833 |
+
key_sig[4] = sf + 7
|
834 |
+
key_sig[5] = 0
|
835 |
+
else:
|
836 |
+
# remove default key signature
|
837 |
+
for key_sig in key_sigs:
|
838 |
+
event_list.remove(key_sig)
|
839 |
|
840 |
events_name_order = ["time_signature", "key_signature", "set_tempo", "patch_change", "control_change", "note"]
|
841 |
events_name_order = {name: i for i, name in enumerate(events_name_order)}
|
|
|
863 |
else:
|
864 |
if event[0] == "note":
|
865 |
notes_in_setup = True
|
866 |
+
if event[0] in ["note", "time_signature", "key_signature"]:
|
867 |
+
key = tuple([event[0]]+event[3:-2])
|
868 |
+
else:
|
869 |
+
key = tuple([event[0]]+event[3:-1])
|
870 |
setup_events[key] = new_event
|
871 |
|
872 |
last_t1 = 0
|