skytnt committed on
Commit
ff0299c
1 Parent(s): 7406325

update MIDITokenizer

Browse files
Files changed (1) hide show
  1. midi_tokenizer.py +50 -14
midi_tokenizer.py CHANGED
@@ -200,7 +200,7 @@ class MIDITokenizerV1:
200
  note_tracks = channel_note_tracks[c]
201
  if len(note_tracks) != 0 and track_idx not in note_tracks:
202
  track_idx = channel_note_tracks[c][0]
203
- new_track_idx = tr_map.setdefault(track_idx, next(iter(tr_map.values())))
204
  event[3] = new_track_idx
205
  if name == "patch_change" and event[4] not in patch_channels:
206
  patch_channels.append(event[4])
@@ -235,7 +235,9 @@ class MIDITokenizerV1:
235
  else:
236
  if event[0] == "note":
237
  notes_in_setup = True
238
- key = tuple(event[3:-1])
 
 
239
  setup_events[key] = new_event
240
 
241
  last_t1 = 0
@@ -551,6 +553,8 @@ class MIDITokenizerV2:
551
  def detect_key_signature(key_hist, threshold=0.7):
552
  if len(key_hist) != 12:
553
  return None
 
 
554
  p = sum(sorted(key_hist, reverse=True)[:7]) / sum(key_hist)
555
  if p < threshold:
556
  return None
@@ -590,7 +594,7 @@ class MIDITokenizerV2:
590
  empty_channels = [True] * 16
591
  channel_note_tracks = {i: list() for i in range(16)}
592
  note_key_hist = [0]*12
593
- key_sig_num = 0
594
  track_to_channels = {}
595
  for track_idx, track in enumerate(midi_score[1:129]):
596
  last_notes = {}
@@ -661,11 +665,11 @@ class MIDITokenizerV2:
661
  sf, mi = event[2:]
662
  if not (-7 <= sf <= 7 and 0 <= mi <= 1): # invalid
663
  continue
664
- key_sig_num += 1
665
  sf += 7
666
  new_event += [sf, mi]
 
667
 
668
- if name == "note":
669
  key = tuple(new_event[:-2])
670
  else:
671
  key = tuple(new_event[:-1])
@@ -731,7 +735,9 @@ class MIDITokenizerV2:
731
 
732
  empty_channels = [channels_map[c] for c in empty_channels]
733
  track_idx_dict = {}
 
734
  key_signature_to_add = []
 
735
  for event in event_list:
736
  name = event[0]
737
  track_idx = event[3]
@@ -748,14 +754,23 @@ class MIDITokenizerV2:
748
  for c, tr_map in track_idx_map.items():
749
  if track_idx in tr_map:
750
  new_track_idx = tr_map[track_idx]
 
751
  new_channel_track_idx = (c, new_track_idx)
 
 
752
  if new_channel_track_idx not in new_channel_track_idxs:
753
  new_channel_track_idxs.append(new_channel_track_idx)
 
754
  if len(new_channel_track_idxs) == 0:
755
- event[3] = 0
 
 
 
 
756
  continue
757
  c, nt = new_channel_track_idxs[0]
758
  event[3] = nt
 
759
  if c == 9:
760
  event[4] = 7 # sf=0
761
  for c, nt in new_channel_track_idxs[1:]:
@@ -763,6 +778,7 @@ class MIDITokenizerV2:
763
  new_event[3] = nt
764
  if c == 9:
765
  new_event[4] = 7 # sf=0
 
766
  key_signature_to_add.append(new_event)
767
  elif name == "control_change" or name == "patch_change":
768
  c = event[4]
@@ -772,10 +788,12 @@ class MIDITokenizerV2:
772
  note_tracks = channel_note_tracks[c]
773
  if len(note_tracks) != 0 and track_idx not in note_tracks:
774
  track_idx = channel_note_tracks[c][0]
775
- new_track_idx = tr_map.setdefault(track_idx, next(iter(tr_map.values())))
776
  event[3] = new_track_idx
777
  if name == "patch_change" and event[4] not in patch_channels:
778
  patch_channels.append(event[4])
 
 
779
  event_list += key_signature_to_add
780
  track_to_channels ={}
781
  for c, tr_map in track_idx_map.items():
@@ -793,16 +811,31 @@ class MIDITokenizerV2:
793
  if c not in patch_channels and c in track_idx_dict:
794
  event_list.append(["patch_change", 0, 0, track_idx_dict[c], c, 0])
795
 
796
- if key_sig_num == 0:
797
- # detect key signature.
798
  root_key = self.detect_key_signature(note_key_hist)
799
  if root_key is not None:
800
  sf = self.key2sf(root_key, 0)
801
  # print("detect_key_signature",sf)
802
- for tr, cs in track_to_channels.items():
803
- if remap_track_channel and tr == 0:
804
- continue
805
- event_list.append(["key_signature", 0, 0, tr, (0 if (len(cs) == 1 and cs[0] == 9) else sf) + 7, 0])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
806
 
807
  events_name_order = ["time_signature", "key_signature", "set_tempo", "patch_change", "control_change", "note"]
808
  events_name_order = {name: i for i, name in enumerate(events_name_order)}
@@ -830,7 +863,10 @@ class MIDITokenizerV2:
830
  else:
831
  if event[0] == "note":
832
  notes_in_setup = True
833
- key = tuple(event[3:-1])
 
 
 
834
  setup_events[key] = new_event
835
 
836
  last_t1 = 0
 
200
  note_tracks = channel_note_tracks[c]
201
  if len(note_tracks) != 0 and track_idx not in note_tracks:
202
  track_idx = channel_note_tracks[c][0]
203
+ new_track_idx = tr_map[track_idx]
204
  event[3] = new_track_idx
205
  if name == "patch_change" and event[4] not in patch_channels:
206
  patch_channels.append(event[4])
 
235
  else:
236
  if event[0] == "note":
237
  notes_in_setup = True
238
+ key = tuple([event[0]] + event[3:-2])
239
+ else:
240
+ key = tuple([event[0]] + event[3:-1])
241
  setup_events[key] = new_event
242
 
243
  last_t1 = 0
 
553
  def detect_key_signature(key_hist, threshold=0.7):
554
  if len(key_hist) != 12:
555
  return None
556
+ if sum(key_hist) == 0:
557
+ return None
558
  p = sum(sorted(key_hist, reverse=True)[:7]) / sum(key_hist)
559
  if p < threshold:
560
  return None
 
594
  empty_channels = [True] * 16
595
  channel_note_tracks = {i: list() for i in range(16)}
596
  note_key_hist = [0]*12
597
+ key_sigs = []
598
  track_to_channels = {}
599
  for track_idx, track in enumerate(midi_score[1:129]):
600
  last_notes = {}
 
665
  sf, mi = event[2:]
666
  if not (-7 <= sf <= 7 and 0 <= mi <= 1): # invalid
667
  continue
 
668
  sf += 7
669
  new_event += [sf, mi]
670
+ key_sigs.append(new_event)
671
 
672
+ if name in ["note", "time_signature", "key_signature"]:
673
  key = tuple(new_event[:-2])
674
  else:
675
  key = tuple(new_event[:-1])
 
735
 
736
  empty_channels = [channels_map[c] for c in empty_channels]
737
  track_idx_dict = {}
738
+ key_sigs = []
739
  key_signature_to_add = []
740
+ key_signature_to_remove = []
741
  for event in event_list:
742
  name = event[0]
743
  track_idx = event[3]
 
754
  for c, tr_map in track_idx_map.items():
755
  if track_idx in tr_map:
756
  new_track_idx = tr_map[track_idx]
757
+ c = channels_map[c]
758
  new_channel_track_idx = (c, new_track_idx)
759
+ if new_track_idx == 0:
760
+ continue
761
  if new_channel_track_idx not in new_channel_track_idxs:
762
  new_channel_track_idxs.append(new_channel_track_idx)
763
+
764
  if len(new_channel_track_idxs) == 0:
765
+ if event[3] == 0: # keep key_signature on track 0 (meta)
766
+ key_sigs.append(event)
767
+ continue
768
+ event[3] = -1 # avoid remove same event
769
+ key_signature_to_remove.append(event) # empty track
770
  continue
771
  c, nt = new_channel_track_idxs[0]
772
  event[3] = nt
773
+ key_sigs.append(event)
774
  if c == 9:
775
  event[4] = 7 # sf=0
776
  for c, nt in new_channel_track_idxs[1:]:
 
778
  new_event[3] = nt
779
  if c == 9:
780
  new_event[4] = 7 # sf=0
781
+ key_sigs.append(new_event)
782
  key_signature_to_add.append(new_event)
783
  elif name == "control_change" or name == "patch_change":
784
  c = event[4]
 
788
  note_tracks = channel_note_tracks[c]
789
  if len(note_tracks) != 0 and track_idx not in note_tracks:
790
  track_idx = channel_note_tracks[c][0]
791
+ new_track_idx = tr_map[track_idx]
792
  event[3] = new_track_idx
793
  if name == "patch_change" and event[4] not in patch_channels:
794
  patch_channels.append(event[4])
795
+ for key_sig in key_signature_to_remove:
796
+ event_list.remove(key_sig)
797
  event_list += key_signature_to_add
798
  track_to_channels ={}
799
  for c, tr_map in track_idx_map.items():
 
811
  if c not in patch_channels and c in track_idx_dict:
812
  event_list.append(["patch_change", 0, 0, track_idx_dict[c], c, 0])
813
 
814
+ if len(key_sigs) == 0 or all([key_sig[4]==7 for key_sig in key_sigs]):
815
+ # detect key signature or fix the default key signature
816
  root_key = self.detect_key_signature(note_key_hist)
817
  if root_key is not None:
818
  sf = self.key2sf(root_key, 0)
819
  # print("detect_key_signature",sf)
820
+ if len(key_sigs) == 0:
821
+ for tr, cs in track_to_channels.items():
822
+ if remap_track_channel and tr == 0:
823
+ continue
824
+ new_event = ["key_signature", 0, 0, tr, (0 if (len(cs) == 1 and cs[0] == 9) else sf) + 7, 0]
825
+ event_list.append(new_event)
826
+ else:
827
+ for key_sig in key_sigs:
828
+ tr = key_sig[3]
829
+ if tr in track_to_channels:
830
+ cs = track_to_channels[tr]
831
+ if len(cs) == 1 and cs[0] == 9:
832
+ continue
833
+ key_sig[4] = sf + 7
834
+ key_sig[5] = 0
835
+ else:
836
+ # remove default key signature
837
+ for key_sig in key_sigs:
838
+ event_list.remove(key_sig)
839
 
840
  events_name_order = ["time_signature", "key_signature", "set_tempo", "patch_change", "control_change", "note"]
841
  events_name_order = {name: i for i, name in enumerate(events_name_order)}
 
863
  else:
864
  if event[0] == "note":
865
  notes_in_setup = True
866
+ if event[0] in ["note", "time_signature", "key_signature"]:
867
+ key = tuple([event[0]]+event[3:-2])
868
+ else:
869
+ key = tuple([event[0]]+event[3:-1])
870
  setup_events[key] = new_event
871
 
872
  last_t1 = 0