misnaej committed on
Commit
88094df
1 Parent(s): 0ffb5f8

decoding utils update

Browse files
Files changed (1) hide show
  1. utils.py +165 -111
utils.py CHANGED
@@ -3,6 +3,7 @@ from miditok import Event, MIDILike
3
  import os
4
  import json
5
  from time import perf_counter
 
6
  from joblib import Parallel, delayed
7
  from zipfile import ZipFile, ZIP_DEFLATED
8
  from scipy.io.wavfile import write
@@ -10,29 +11,36 @@ import numpy as np
10
  from pydub import AudioSegment
11
  import shutil
12
 
 
13
 
14
- def writeToFile(path, content):
15
- if type(content) is dict:
16
- with open(f"{path}", "w") as json_file:
17
- json.dump(content, json_file)
18
- else:
19
- if type(content) is not str:
20
- content = str(content)
21
- os.makedirs(os.path.dirname(path), exist_ok=True)
22
- with open(path, "w") as f:
23
- f.write(content)
24
 
 
 
 
 
 
25
 
26
- # Function to read text (or JSON) from a file:
27
- def readFromFile(path, isJSON=False):
28
- with open(path, "r") as f:
29
- if isJSON:
30
- return json.load(f)
31
- else:
32
- return f.read()
 
 
 
 
 
 
 
 
 
 
33
 
34
 
35
  def chain(input, funcs, *params):
 
36
  res = input
37
  for func in funcs:
38
  try:
@@ -42,21 +50,8 @@ def chain(input, funcs, *params):
42
  return res
43
 
44
 
45
def to_beat_str(value, beat_res=8):
    """Format a beat count as an "integer.decimal.base" string.

    e.g. 0.5 with beat_res=8 -> "0.4.8" (0 whole beats + 4/8 of a beat)

    Args:
        - value: number of beats
        - beat_res: base used for the fractional part
    Returns:
        - "integer.decimal.base" (str)
    """
    # Truncate to a whole number of 1/beat_res steps first, then split it
    # into whole beats and leftover steps.
    ticks = int(value * beat_res)
    whole = int(ticks / beat_res)
    leftover = int(ticks % beat_res)
    return f"{whole}.{leftover}.{beat_res}"
52
-
53
-
54
def to_base10(beat_str):
    """Convert an "integer.decimal.base" string into a plain number of beats.

    e.g. "0.4.8" -> 0 + 4/8 = 0.5
    """
    whole, numerator, denominator = split_dots(beat_str)
    fraction = numerator / denominator
    return whole + fraction
57
-
58
-
59
  def split_dots(value):
 
60
  return list(map(int, value.split(".")))
61
 
62
 
@@ -68,7 +63,39 @@ def get_datetime():
68
  return datetime.now().strftime("%Y%m%d_%H%M%S")
69
 
70
 
71
- def get_text(event):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
  match event.type:
73
  case "Piece-Start":
74
  return "PIECE_START "
@@ -77,13 +104,18 @@ def get_text(event):
77
  case "Track-End":
78
  return "TRACK_END "
79
  case "Instrument":
80
- return f"INST={event.value} "
 
 
 
 
 
81
  case "Bar-Start":
82
  return "BAR_START "
83
  case "Bar-End":
84
  return "BAR_END "
85
  case "Time-Shift":
86
- return f"TIME_SHIFT={event.value} "
87
  case "Note-On":
88
  return f"NOTE_ON={event.value} "
89
  case "Note-Off":
@@ -92,15 +124,66 @@ def get_text(event):
92
  return ""
93
 
94
 
95
- def get_event(text, value=None):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
  match text:
97
  case "PIECE_START":
98
  return Event("Piece-Start", value)
99
  case "TRACK_START":
100
- return None
101
  case "TRACK_END":
102
- return None
103
  case "INST":
 
 
104
  return Event("Instrument", value)
105
  case "BAR_START":
106
  return Event("Bar-Start", value)
@@ -109,7 +192,8 @@ def get_event(text, value=None):
109
  case "TIME_SHIFT":
110
  return Event("Time-Shift", value)
111
  case "TIME_DELTA":
112
- return Event("Time-Shift", to_beat_str(int(value) / 4))
 
113
  case "NOTE_ON":
114
  return Event("Note-On", value)
115
  case "NOTE_OFF":
@@ -118,39 +202,27 @@ def get_event(text, value=None):
118
  return None
119
 
120
 
121
- # TODO: Make this singleton
122
- def get_miditok():
123
- pitch_range = range(0, 140) # was (21, 109)
124
- beat_res = {(0, 400): 8}
125
- return MIDILike(pitch_range, beat_res)
126
-
127
 
128
class WriteTextMidiToFile:
    """Save a generated text-MIDI piece, with its per-track settings, to a JSON file.

    The file is named after the current timestamp and placed under `output_path`.
    """

    def __init__(self, generate_midi, output_path):
        """Keep the generated piece and per-track data read from `generate_midi`.

        Args:
            - generate_midi: object exposing `generated_piece` and `piece_by_track`
            - output_path: directory (as used in an f-string path) for the JSON file
        """
        self.output_path = output_path
        self.generated_midi = generate_midi.generated_piece
        self.hyperparameter_and_bars = generate_midi.piece_by_track

    def hashing_seq(self):
        """Record the current timestamp and derive the output filename from it."""
        timestamp = get_datetime()
        self.current_time = timestamp
        self.output_path_filename = f"{self.output_path}/{timestamp}.json"

    def wrapping_seq_hyperparameters_in_dict(self):
        """Bundle the generated piece and its per-track data into one dict."""
        return dict(
            generate_midi=self.generated_midi,
            hyperparameters_and_bars=self.hyperparameter_and_bars,
        )

    def text_midi_to_file(self):
        """Write the bundle to a timestamped JSON file and return its path."""
        self.hashing_seq()
        payload = self.wrapping_seq_hyperparameters_in_dict()
        destination = self.output_path_filename
        print(f"Token generate_midi written: {destination}")
        writeToFile(destination, payload)
        return destination
154
 
155
 
156
  def get_files(directory, extension, recursive=False):
@@ -167,15 +239,33 @@ def get_files(directory, extension, recursive=False):
167
  return list(directory.glob(f"*.{extension}"))
168
 
169
 
170
- def timeit(func):
171
- def wrapper(*args, **kwargs):
172
- start = perf_counter()
173
- result = func(*args, **kwargs)
174
- end = perf_counter()
175
- print(f"{func.__name__} took {end - start:.2f} seconds to run.")
176
- return result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177
 
178
- return wrapper
 
 
 
 
179
 
180
 
181
  class FileCompressor:
@@ -208,39 +298,3 @@ class FileCompressor:
208
  """compress all text files in folder to new zip files and remove the text files"""
209
  files = get_files(self.output_directory, extension="txt")
210
  Parallel(n_jobs=self.n_jobs)(delayed(self.zip_file)(file) for file in files)
211
-
212
-
213
- def load_jsonl(filepath):
214
- """Load a jsonl file"""
215
- with open(filepath, "r") as f:
216
- data = [json.loads(line) for line in f]
217
- return data
218
-
219
-
220
- def write_mp3(waveform, output_path, bitrate="92k"):
221
- """
222
- Write a waveform to an mp3 file.
223
- output_path: Path object for the output mp3 file
224
- waveform: numpy array of the waveform
225
- bitrate: bitrate of the mp3 file (64k, 92k, 128k, 256k, 312k)
226
- """
227
- # write the wav file
228
- wav_path = output_path.with_suffix(".wav")
229
- write(wav_path, 44100, waveform.astype(np.float32))
230
- # compress the wav file as mp3
231
- AudioSegment.from_wav(wav_path).export(output_path, format="mp3", bitrate=bitrate)
232
- # remove the wav file
233
- wav_path.unlink()
234
-
235
-
236
- def copy_file(input_file, output_dir):
237
- """Copy an input file to the output_dir"""
238
- output_file = output_dir / input_file.name
239
- shutil.copy(input_file, output_file)
240
-
241
-
242
- def index_has_substring(list, substring):
243
- for i, s in enumerate(list):
244
- if substring in s:
245
- return i
246
- return -1
3
  import os
4
  import json
5
  from time import perf_counter
6
+ from constants import DRUMS_BEAT_QUANTIZATION, NONE_DRUMS_BEAT_QUANTIZATION
7
  from joblib import Parallel, delayed
8
  from zipfile import ZipFile, ZIP_DEFLATED
9
  from scipy.io.wavfile import write
11
  from pydub import AudioSegment
12
  import shutil
13
 
14
+ """ Diverse utils"""
15
 
 
 
 
 
 
 
 
 
 
 
16
 
17
def index_has_substring(list, substring):
    """Return the index of the first item of `list` that contains `substring`.

    Returns -1 when no item contains it (including when `list` is empty).
    """
    hits = (position for position, item in enumerate(list) if substring in item)
    return next(hits, -1)
22
 
23
+
24
+ # TODO: Make this singleton
25
def get_miditok():
    """Build a MIDILike tokenizer with this project's fixed settings.

    A new MIDILike instance is constructed on every call.
    """
    # Pitch range was previously (21, 109).
    # NOTE(review): beat_res presumably means 8 subdivisions per beat for
    # beats 0..400 - confirm against the miditok documentation.
    return MIDILike(range(0, 127), {(0, 400): 8})
29
+
30
+
31
def timeit(func):
    """Decorator that prints how long each call to `func` takes.

    The wrapped function keeps its own name and docstring, and its return
    value is passed through unchanged.

    Args:
        - func: callable to time
    Returns:
        - wrapper that calls `func`, prints the elapsed time, and returns the result
    """
    # Local import so the block does not depend on a new file-level import.
    from functools import wraps

    # Without wraps() every decorated function would report itself as
    # "wrapper" and lose its docstring.
    @wraps(func)
    def wrapper(*args, **kwargs):
        start = perf_counter()
        result = func(*args, **kwargs)
        end = perf_counter()
        print(f"{func.__name__} took {end - start:.2f} seconds to run.")
        return result

    return wrapper
40
 
41
 
42
  def chain(input, funcs, *params):
43
+ """Chain functions together, passing the output of one function as the input of the next."""
44
  res = input
45
  for func in funcs:
46
  try:
50
  return res
51
 
52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
def split_dots(value):
    """Split a dot-separated string "a.b.c" into the list of integers [a, b, c]."""
    return [int(part) for part in value.split(".")]
56
 
57
 
63
  return datetime.now().strftime("%Y%m%d_%H%M%S")
64
 
65
 
66
+ """ Encoding functions """
67
+
68
+
69
def int_dec_base_to_beat(beat_str):
    """
    Converts "integer.decimal.base" (str, from miditok) into beats
    e.g. "0.4.8" = 0 + 4/8 = 0.5
    Args:
        - beat_str: "integer.decimal.base"
    Returns:
        - beats: float
    """
    whole, numerator, denominator = split_dots(beat_str)
    fraction = numerator / denominator
    return whole + fraction
80
+
81
+
82
def int_dec_base_to_delta(beat_str, instrument="drums"):
    """Converts a miditok time shift into a TIME_DELTA according to Tristan's encoding scheme

    Drums are quantized with DRUMS_BEAT_QUANTIZATION, every other instrument
    with NONE_DRUMS_BEAT_QUANTIZATION.
    Args:
        - beat_str: "integer.decimal.base"
        - instrument: str, compared case-insensitively with "drums"
    Returns:
        - time_delta: int (the scaled beat value, truncated by int())
    """
    if instrument.lower() == "drums":
        resolution = DRUMS_BEAT_QUANTIZATION
    else:
        resolution = NONE_DRUMS_BEAT_QUANTIZATION
    return int(int_dec_base_to_beat(beat_str) * resolution)
95
+
96
+
97
+ def get_text(event, instrument="drums"):
98
+ """Converts an event into a string for the midi-text format"""
99
  match event.type:
100
  case "Piece-Start":
101
  return "PIECE_START "
104
  case "Track-End":
105
  return "TRACK_END "
106
  case "Instrument":
107
+ if str(event.value).lower() == "drums":
108
+ return f"INST=DRUMS "
109
+ else:
110
+ return f"INST={event.value} "
111
+ case "Density":
112
+ return f"DENSITY={event.value} "
113
  case "Bar-Start":
114
  return "BAR_START "
115
  case "Bar-End":
116
  return "BAR_END "
117
  case "Time-Shift":
118
+ return f"TIME_DELTA={int_dec_base_to_delta(event.value, instrument)} "
119
  case "Note-On":
120
  return f"NOTE_ON={event.value} "
121
  case "Note-Off":
124
  return ""
125
 
126
 
127
+ """ Decoding functions """
128
+
129
+
130
def time_delta_to_beat(time_delta, instrument="drums"):
    """
    Converts TIME_DELTA (from midi-text) to beats according to Tristan's encoding scheme
    Args:
        - time_delta: TIME_DELTA value (anything float() accepts)
        - instrument: str, compared case-insensitively with "drums"; selects the
          quantization resolution defined in constants.py
    Returns:
        - beats: float
    """
    if instrument.lower() == "drums":
        resolution = DRUMS_BEAT_QUANTIZATION
    else:
        resolution = NONE_DRUMS_BEAT_QUANTIZATION
    return float(time_delta) / resolution
146
+
147
+
148
def beat_to_int_dec_base(beat, beat_res=8):
    """
    Converts beats into "integer.decimal.base" (str) for miditok
    e.g. 0.5 with beat_res=8 -> "0.4.8" (0 + 4/8)
    Args:
        - beat: number of beats (int or float)
        - beat_res: base of the fractional part ("base" in the output string)
    Returns:
        - "integer.decimal.base" (str); whatever is left below 1/beat_res
          of a beat is dropped by the int() truncation
    """
    # divmod yields the same (floor quotient, remainder) pair as // and %.
    whole, fraction = divmod(beat * beat_res, beat_res)
    return f"{int(whole)}.{int(fraction)}.{beat_res}"
162
+
163
+
164
def time_delta_to_int_dec_base(time_delta, instrument="drums"):
    """Converts TIME_DELTA (from midi-text) into "integer.decimal.base" (str) for miditok

    Args:
        - time_delta: TIME_DELTA value
        - instrument: str, selects the quantization used to turn the delta into beats
    Returns:
        - "integer.decimal.base" (str), built with beat_to_int_dec_base's default base
    """
    # The two steps are called directly rather than through chain(): the
    # chained form handed `instrument` to chain() as an extra parameter for
    # the whole function list, but only time_delta_to_beat takes an
    # instrument. beat_to_int_dec_base's second parameter is beat_res.
    beats = time_delta_to_beat(time_delta, instrument)
    return beat_to_int_dec_base(beats)
173
+
174
+
175
+ def get_event(text, value=None, instrument="drums"):
176
+ """Converts a midi-text like event into a miditok like event"""
177
  match text:
178
  case "PIECE_START":
179
  return Event("Piece-Start", value)
180
  case "TRACK_START":
181
+ return Event("Track-Start", value)
182
  case "TRACK_END":
183
+ return Event("Track-End", value)
184
  case "INST":
185
+ if value == "DRUMS":
186
+ value = "Drums"
187
  return Event("Instrument", value)
188
  case "BAR_START":
189
  return Event("Bar-Start", value)
192
  case "TIME_SHIFT":
193
  return Event("Time-Shift", value)
194
  case "TIME_DELTA":
195
+ return Event("Time-Shift", time_delta_to_int_dec_base(value, instrument))
196
+ # return Event("Time-Shift", to_beat_str(int(value) / 4))
197
  case "NOTE_ON":
198
  return Event("Note-On", value)
199
  case "NOTE_OFF":
202
  return None
203
 
204
 
205
+ """ File utils"""
 
 
 
 
 
206
 
 
 
 
 
 
207
 
208
def writeToFile(path, content):
    """Write `content` to `path`, creating the parent directory when needed.

    A dict is serialized as JSON. Any other value is written as text:
    strings unchanged, everything else converted with str().

    Args:
        - path: destination file path (str or os.PathLike)
        - content: dict to dump as JSON, or any object to write as text
    """
    # Create the parent directory for both the JSON and the text branch.
    # A bare filename has an empty dirname, and os.makedirs("") raises
    # FileNotFoundError, so only create a directory when there is one.
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)

    with open(path, "w") as f:
        if isinstance(content, dict):
            json.dump(content, f)
        else:
            f.write(content if isinstance(content, str) else str(content))
218
 
 
 
 
 
 
 
 
 
 
219
 
220
def readFromFile(path, isJSON=False):
    """Read a file and return its content.

    Args:
        - path: file to read
        - isJSON: when true, parse the file as JSON and return the decoded object
    Returns:
        - the decoded JSON object if isJSON, otherwise the raw text (str)
    """
    with open(path, "r") as handle:
        return json.load(handle) if isJSON else handle.read()
226
 
227
 
228
  def get_files(directory, extension, recursive=False):
239
  return list(directory.glob(f"*.{extension}"))
240
 
241
 
242
def load_jsonl(filepath):
    """Load a jsonl file

    Each non-blank line is decoded as one JSON value. Blank or
    whitespace-only lines (e.g. a trailing empty line) are skipped instead
    of raising json.JSONDecodeError.

    Args:
        - filepath: path of the jsonl file
    Returns:
        - list of decoded objects, in file order
    """
    with open(filepath, "r") as f:
        return [json.loads(line) for line in f if line.strip()]
247
+
248
+
249
def write_mp3(waveform, output_path, bitrate="92k"):
    """
    Write a waveform to an mp3 file.

    The waveform is first written as a float32 wav file next to the target
    (same path with a ".wav" suffix), converted to mp3, and the wav file is
    then removed, even when the conversion fails.

    output_path: Path object for the output mp3 file
    waveform: numpy array of the waveform
    bitrate: bitrate of the mp3 file (64k, 92k, 128k, 256k, 312k)
    """
    wav_path = output_path.with_suffix(".wav")
    try:
        # write the wav file (sample rate fixed at 44100)
        write(wav_path, 44100, waveform.astype(np.float32))
        # compress the wav file as mp3
        AudioSegment.from_wav(wav_path).export(
            output_path, format="mp3", bitrate=bitrate
        )
    finally:
        # Always remove the intermediate wav. missing_ok keeps a failure in
        # write() from being masked by a FileNotFoundError here.
        wav_path.unlink(missing_ok=True)
263
 
264
+
265
def copy_file(input_file, output_dir):
    """Copy `input_file` into `output_dir`, keeping its file name.

    Both arguments are used as path objects: `input_file.name` is read and
    joined to `output_dir` with the `/` operator.
    """
    shutil.copy(input_file, output_dir / input_file.name)
269
 
270
 
271
  class FileCompressor:
298
  """compress all text files in folder to new zip files and remove the text files"""
299
  files = get_files(self.output_directory, extension="txt")
300
  Parallel(n_jobs=self.n_jobs)(delayed(self.zip_file)(file) for file in files)