from datetime import datetime from miditok import Event, MIDILike import os import json from time import perf_counter from constants import DRUMS_BEAT_QUANTIZATION, NONE_DRUMS_BEAT_QUANTIZATION from joblib import Parallel, delayed from zipfile import ZipFile, ZIP_DEFLATED from scipy.io.wavfile import write import numpy as np from pydub import AudioSegment import shutil """ Diverse utils""" def index_has_substring(list, substring): for i, s in enumerate(list): if substring in s: return i return -1 # TODO: Make this singleton def get_miditok(): pitch_range = range(0, 127) # was (21, 109) beat_res = {(0, 400): 8} return MIDILike(pitch_range, beat_res) def timeit(func): def wrapper(*args, **kwargs): start = perf_counter() result = func(*args, **kwargs) end = perf_counter() print(f"{func.__name__} took {end - start:.2f} seconds to run.") return result return wrapper def chain(input, funcs, *params): """Chain functions together, passing the output of one function as the input of the next.""" res = input for func in funcs: try: res = func(res, *params) except TypeError: res = func(res) return res def split_dots(value): """Splits a string separated by dots "a.b.c" into a list of integers [a, b, c]""" return list(map(int, value.split("."))) def compute_list_average(l): return sum(l) / len(l) def get_datetime(): return datetime.now().strftime("%Y%m%d_%H%M%S") """ Encoding functions """ def int_dec_base_to_beat(beat_str): """ Converts "integer.decimal.base" (str, from miditok) into beats e.g. "0.4.8" = 0 + 4/8 = 0.5 Args: - beat_str: "integer.decimal.base" Returns: - beats: float """ integer, decimal, base = split_dots(beat_str) return integer + decimal / base def int_dec_base_to_delta(beat_str, instrument="drums"): """converts the time shift to time_delta according to Tristan's encoding scheme Drums TIME_DELTA are quantized according to DRUMS_BEAT_QUANTIZATION Other Instrument TIME_DELTA are quantized according to NONE_DRUMS_BEAT_QUANTIZATION """ beat_res = ( DRUMS_BEAT_QUANTIZATION if instrument.lower() == "drums" else NONE_DRUMS_BEAT_QUANTIZATION ) time_delta = int_dec_base_to_beat(beat_str) * beat_res return time_delta.__int__() def get_text(event, instrument="drums"): """Converts an event into a string for the midi-text format""" match event.type: case "Piece-Start": return "PIECE_START " case "Track-Start": return "TRACK_START " case "Track-End": return "TRACK_END " case "Instrument": if str(event.value).lower() == "drums": return f"INST=DRUMS " else: return f"INST={event.value} " case "Density": return f"DENSITY={event.value} " case "Bar-Start": return "BAR_START " case "Bar-End": return "BAR_END " case "Time-Shift": return f"TIME_DELTA={int_dec_base_to_delta(event.value, instrument)} " case "Note-On": return f"NOTE_ON={event.value} " case "Note-Off": return f"NOTE_OFF={event.value} " case _: return "" """ Decoding functions """ def time_delta_to_beat(time_delta, instrument="drums"): """ Converts TIME_DELTA (from midi-text) to beats according to Tristan's encoding scheme Args: - time_delta: int (TIME_DELTA) - instrument: str ("Drums" or other instrument): used to determine the quantization resolution defined on constants.py Returns: - beats: float """ beat_res = ( DRUMS_BEAT_QUANTIZATION if instrument.lower() == "drums" else NONE_DRUMS_BEAT_QUANTIZATION ) beats = float(time_delta) / beat_res return beats def beat_to_int_dec_base(beat, beat_res=8): """ Converts beats into "integer.decimal.base" (str) for miditok Args: - beat_str: "integer.decimal.base" Returns: - beats: float (e.g. "0.4.8" = 0 + 4/8 = 0.5) """ int_dec_base = [ int((beat * beat_res) // beat_res), int((beat * beat_res) % beat_res), beat_res, ] return ".".join(map(str, int_dec_base)) def time_delta_to_int_dec_base(time_delta, instrument="drums"): return chain( time_delta, [ time_delta_to_beat, beat_to_int_dec_base, ], instrument, ) def get_event(text, value=None, instrument="drums"): """Converts a midi-text like event into a miditok like event""" match text: case "PIECE_START": return Event("Piece-Start", value) case "TRACK_START": return Event("Track-Start", value) case "TRACK_END": return Event("Track-End", value) case "INST": if value == "DRUMS": value = "Drums" return Event("Instrument", value) case "BAR_START": return Event("Bar-Start", value) case "BAR_END": return Event("Bar-End", value) case "TIME_SHIFT": return Event("Time-Shift", value) case "TIME_DELTA": return Event("Time-Shift", time_delta_to_int_dec_base(value, instrument)) # return Event("Time-Shift", to_beat_str(int(value) / 4)) case "NOTE_ON": return Event("Note-On", value) case "NOTE_OFF": return Event("Note-Off", value) case _: return None """ File utils""" def writeToFile(path, content): if type(content) is dict: with open(f"{path}", "w") as json_file: json.dump(content, json_file) else: if type(content) is not str: content = str(content) os.makedirs(os.path.dirname(path), exist_ok=True) with open(path, "w") as f: f.write(content) def readFromFile(path, isJSON=False): with open(path, "r") as f: if isJSON: return json.load(f) else: return f.read() def get_files(directory, extension, recursive=False): """ Given a directory, get a list of the file paths of all files matching the specified file extension. directory: the directory to search as a Path object extension: the file extension to match as a string recursive: whether to search recursively in the directory or not """ if recursive: return list(directory.rglob(f"*.{extension}")) else: return list(directory.glob(f"*.{extension}")) def load_jsonl(filepath): """Load a jsonl file""" with open(filepath, "r") as f: data = [json.loads(line) for line in f] return data def write_mp3(waveform, output_path, bitrate="92k"): """ Write a waveform to an mp3 file. output_path: Path object for the output mp3 file waveform: numpy array of the waveform bitrate: bitrate of the mp3 file (64k, 92k, 128k, 256k, 312k) """ # write the wav file wav_path = output_path.with_suffix(".wav") write(wav_path, 44100, waveform.astype(np.float32)) # compress the wav file as mp3 AudioSegment.from_wav(wav_path).export(output_path, format="mp3", bitrate=bitrate) # remove the wav file wav_path.unlink() def copy_file(input_file, output_dir): """Copy an input file to the output_dir""" output_file = output_dir / input_file.name shutil.copy(input_file, output_file) class FileCompressor: def __init__(self, input_directory, output_directory, n_jobs=-1): self.input_directory = input_directory self.output_directory = output_directory self.n_jobs = n_jobs # File compression and decompression def unzip_file(self, file): """uncompress single zip file""" with ZipFile(file, "r") as zip_ref: zip_ref.extractall(self.output_directory) def zip_file(self, file): """compress a single text file to a new zip file and delete the original""" output_file = self.output_directory / (file.stem + ".zip") with ZipFile(output_file, "w") as zip_ref: zip_ref.write(file, arcname=file.name, compress_type=ZIP_DEFLATED) file.unlink() @timeit def unzip(self): """uncompress all zip files in folder""" files = get_files(self.input_directory, extension="zip") Parallel(n_jobs=self.n_jobs)(delayed(self.unzip_file)(file) for file in files) @timeit def zip(self): """compress all text files in folder to new zip files and remove the text files""" files = get_files(self.output_directory, extension="txt") Parallel(n_jobs=self.n_jobs)(delayed(self.zip_file)(file) for file in files)