TomRB22 commited on
Commit
ebc3f07
1 Parent(s): 86ff1f5

Fixed issue related to negative duration and step

Browse files
Files changed (1) hide show
  1. audio.py +11 -2
audio.py CHANGED
@@ -105,12 +105,14 @@ def map_to_wav(song_map: pd.DataFrame, out_file: str, velocity: int=50) -> prett
105
  Returns:
106
  pretty_midi.PrettyMIDI: PrettyMIDI object containing the song's representation.
107
  """
108
-
 
109
  contracted_map = tf.squeeze(song_map)
110
  song_map_T = contracted_map.numpy().T
111
  notes = pd.DataFrame(song_map_T, columns=["pitch", "step", "duration"]).mul(_SCALING_FACTORS, axis=1)
112
  notes["pitch"] = notes["pitch"].astype('int32').clip(1, 127)
113
 
 
114
  pm = pretty_midi.PrettyMIDI()
115
  instrument = pretty_midi.Instrument(
116
  program=pretty_midi.instrument_name_to_program(
@@ -118,6 +120,11 @@ def map_to_wav(song_map: pd.DataFrame, out_file: str, velocity: int=50) -> prett
118
 
119
  prev_start = 0
120
  for i, note in notes.iterrows():
 
 
 
 
 
121
  start = float(prev_start + note['step'])
122
  end = float(start + note['duration'])
123
  note = pretty_midi.Note(
@@ -130,7 +137,9 @@ def map_to_wav(song_map: pd.DataFrame, out_file: str, velocity: int=50) -> prett
130
  prev_start = start
131
 
132
  pm.instruments.append(instrument)
133
- if (out_file):
 
 
134
  pm.write(out_file)
135
  return pm
136
 
 
105
  Returns:
106
  pretty_midi.PrettyMIDI: PrettyMIDI object containing the song's representation.
107
  """
108
+
109
+ # Get song map as dataframe
110
  contracted_map = tf.squeeze(song_map)
111
  song_map_T = contracted_map.numpy().T
112
  notes = pd.DataFrame(song_map_T, columns=["pitch", "step", "duration"]).mul(_SCALING_FACTORS, axis=1)
113
  notes["pitch"] = notes["pitch"].astype('int32').clip(1, 127)
114
 
115
+ # Instantiate PrettyMIDI object and append notes
116
  pm = pretty_midi.PrettyMIDI()
117
  instrument = pretty_midi.Instrument(
118
  program=pretty_midi.instrument_name_to_program(
 
120
 
121
  prev_start = 0
122
  for i, note in notes.iterrows():
123
+ # The VAE might generate notes with negative step and duration,
124
+ # and we therefore need to make sure to skip these anomalies
125
+ if (note['step'] < 0 or note['duration'] < 0):
126
+ continue
127
+
128
  start = float(prev_start + note['step'])
129
  end = float(start + note['duration'])
130
  note = pretty_midi.Note(
 
137
  prev_start = start
138
 
139
  pm.instruments.append(instrument)
140
+
141
+ # If a path was specified, save as midi file
142
+ if out_file:
143
  pm.write(out_file)
144
  return pm
145